From b625bface9c2c752b6722e24e23562c331eb5d19 Mon Sep 17 00:00:00 2001 From: kingbri Date: Tue, 14 Nov 2023 01:17:19 -0500 Subject: [PATCH] OAI: Add API-based model loading/unloading and auth routes Models can be loaded and unloaded via the API. Also add authentication to use the API and for administrator tasks. Both types of authorization use different keys. Also fix the unload function to properly free all used vram. Signed-off-by: kingbri --- OAI/models/models.py | 13 ---- OAI/{models => types}/common.py | 0 OAI/{models => types}/completions.py | 2 +- OAI/types/models.py | 27 +++++++ OAI/utils.py | 6 +- api_tokens.yml | 3 + auth.py | 54 ++++++++++++++ config_sample.yml | 22 ++++-- main.py | 107 ++++++++++++++++++++------- model.py | 13 ++-- utils.py | 3 + 11 files changed, 195 insertions(+), 55 deletions(-) delete mode 100644 OAI/models/models.py rename OAI/{models => types}/common.py (100%) rename OAI/{models => types}/completions.py (98%) create mode 100644 OAI/types/models.py create mode 100644 api_tokens.yml create mode 100644 auth.py create mode 100644 utils.py diff --git a/OAI/models/models.py b/OAI/models/models.py deleted file mode 100644 index 44a2f26..0000000 --- a/OAI/models/models.py +++ /dev/null @@ -1,13 +0,0 @@ -from pydantic import BaseModel, Field -from time import time -from typing import List - -class ModelCard(BaseModel): - id: str - object: str = "model" - created: int = Field(default_factory=lambda: int(time())) - owned_by: str = "tabbyAPI" - -class ModelList(BaseModel): - object: str = "list" - data: List[ModelCard] = Field(default_factory=list) diff --git a/OAI/models/common.py b/OAI/types/common.py similarity index 100% rename from OAI/models/common.py rename to OAI/types/common.py diff --git a/OAI/models/completions.py b/OAI/types/completions.py similarity index 98% rename from OAI/models/completions.py rename to OAI/types/completions.py index ba107f6..cd95ff5 100644 --- a/OAI/models/completions.py +++ b/OAI/types/completions.py @@ -2,7 +2,7 @@ from uuid import uuid4 from time import time from pydantic import BaseModel, Field from typing import List, Optional, Dict, Union -from OAI.models.common import LogProbs, UsageStats +from OAI.types.common import LogProbs, UsageStats class CompletionRespChoice(BaseModel): finish_reason: str diff --git a/OAI/types/models.py b/OAI/types/models.py new file mode 100644 index 0000000..1a2db7f --- /dev/null +++ b/OAI/types/models.py @@ -0,0 +1,27 @@ +from pydantic import BaseModel, Field +from time import time +from typing import List, Optional + +class ModelCard(BaseModel): + id: str = "test" + object: str = "model" + created: int = Field(default_factory=lambda: int(time())) + owned_by: str = "tabbyAPI" + +class ModelList(BaseModel): + object: str = "list" + data: List[ModelCard] = Field(default_factory=list) + +class ModelLoadRequest(BaseModel): + name: str + max_seq_len: Optional[int] = 4096 + gpu_split: Optional[str] = "auto" + rope_scale: Optional[float] = 1.0 + rope_alpha: Optional[float] = 1.0 + no_flash_attention: Optional[bool] = False + low_mem: Optional[bool] = False + +class ModelLoadResponse(BaseModel): + module: int + modules: int + status: str diff --git a/OAI/utils.py b/OAI/utils.py index 8b76d39..c609057 100644 --- a/OAI/utils.py +++ b/OAI/utils.py @@ -1,7 +1,7 @@ import pathlib -from OAI.models.completions import CompletionResponse, CompletionRespChoice -from OAI.models.common import UsageStats -from OAI.models.models import ModelList, ModelCard +from OAI.types.completions import CompletionResponse, CompletionRespChoice +from OAI.types.common import UsageStats +from OAI.types.models import ModelList, ModelCard from typing import Optional def create_completion_response(text: str, index: int, model_name: Optional[str]): diff --git a/api_tokens.yml b/api_tokens.yml new file mode 100644 index 0000000..ed8a91f --- /dev/null +++ b/api_tokens.yml @@ -0,0 +1,3 @@ +!!python/object:auth.AuthKeys +admin_key: 5b9e30a4197557dcd6cf48445ee174dc +api_key: 2261702e8a220c6c4671a264cd1236ce diff --git a/auth.py b/auth.py new file mode 100644 index 0000000..38e3a48 --- /dev/null +++ b/auth.py @@ -0,0 +1,54 @@ +import secrets +import yaml +from fastapi import Header, HTTPException +from typing import Optional + +""" +This method of authorization is pretty insecure, but since TabbyAPI is a local +application, it should be fine. +""" + +class AuthKeys: + api_key: str + admin_key: str + + def __init__(self, api_key: str, admin_key: str): + self.api_key = api_key + self.admin_key = admin_key + +auth_keys: Optional[AuthKeys] = None + +def load_auth_keys(): + global auth_keys + try: + with open("api_tokens.yml", "r") as auth_file: + auth_keys = yaml.safe_load(auth_file) + except: + new_auth_keys = AuthKeys( + api_key = secrets.token_hex(16), + admin_key = secrets.token_hex(16) + ) + auth_keys = new_auth_keys + + with open("api_tokens.yml", "w") as auth_file: + yaml.dump(auth_keys, auth_file) + +def check_api_key(x_api_key: str = Header(None), authorization: str = Header(None)): + if x_api_key and x_api_key == auth_keys.api_key: + return x_api_key + elif authorization: + split_key = authorization.split(" ") + if split_key[0].lower() == "bearer" and split_key[1] == auth_keys.api_key: + return authorization + else: + raise HTTPException(401, "Invalid API key") + +def check_admin_key(x_admin_key: str = Header(None), authorization: str = Header(None)): + if x_admin_key and x_admin_key == auth_keys.admin_key: + return x_admin_key + elif authorization: + split_key = authorization.split(" ") + if split_key[0].lower() == "bearer" and split_key[1] == auth_keys.admin_key: + return authorization + else: + raise HTTPException(401, "Invalid admin key") diff --git a/config_sample.yml b/config_sample.yml index c18197d..f397cd9 100644 --- a/config_sample.yml +++ b/config_sample.yml @@ -1,8 +1,14 @@ -model_dir: "D:/models" -model_name: "this_is_a_exl2_model" -max_seq_len: 4096 -gpu_split: "auto" -rope_scale: 1.0 -rope_alpha: 1.0 -no_flash_attention: False -low_mem: False +# Network options +network: + host: "0.0.0.0" + port: 8012 +# Only used if you want to initially load a model +model: + model_dir: "D:/models" + model_name: "airoboros-mistral2.2-7b-exl2" + max_seq_len: 4096 + gpu_split: "auto" + rope_scale: 1.0 + rope_alpha: 1.0 + no_flash_attention: False + low_mem: False diff --git a/main.py b/main.py index 38c7979..c0bb54c 100644 --- a/main.py +++ b/main.py @@ -1,30 +1,86 @@ import uvicorn import yaml -from fastapi import FastAPI, Request +import pathlib +from auth import check_admin_key, check_api_key, load_auth_keys +from fastapi import FastAPI, Request, HTTPException, Depends from model import ModelContainer from progress.bar import IncrementalBar from sse_starlette import EventSourceResponse -from OAI.models.completions import CompletionRequest, CompletionResponse -from OAI.models.models import ModelCard, ModelList +from OAI.types.completions import CompletionRequest, CompletionResponse +from OAI.types.models import ModelCard, ModelList, ModelLoadRequest, ModelLoadResponse from OAI.utils import create_completion_response, get_model_list +from typing import Optional +from utils import load_progress app = FastAPI() -# Initialize a model container. This can be undefined at any period of time -model_container: ModelContainer = None +# Globally scoped variables. Undefined until initalized in main +model_container: Optional[ModelContainer] = None +config: Optional[dict] = None -@app.get("/v1/models") -@app.get("/v1/model/list") +@app.get("/v1/models", dependencies=[Depends(check_api_key)]) +@app.get("/v1/model/list", dependencies=[Depends(check_api_key)]) async def list_models(): - models = get_model_list(model_container.get_model_path()) + model_config = config["model"] + models = get_model_list(pathlib.Path(model_config["model_dir"] or "models")) return models.model_dump_json() -@app.get("/v1/model") +@app.get("/v1/model", dependencies=[Depends(check_api_key)]) async def get_current_model(): - return ModelCard(id = model_container.get_model_path().name) + if model_container is None or model_container.model is None: + return HTTPException(400, "No models are loaded.") -@app.post("/v1/completions", response_class=CompletionResponse) + model_card = ModelCard(id=model_container.get_model_path().name) + return model_card.model_dump_json() + +@app.post("/v1/model/load", response_class=ModelLoadResponse, dependencies=[Depends(check_admin_key)]) +async def load_model(data: ModelLoadRequest): + if model_container and model_container.model: + raise HTTPException(400, "A model is already loaded! Please unload it first.") + + def generator(): + global model_container + model_config = config["model"] + model_path = pathlib.Path(model_config["model_dir"] or "models") + model_path = model_path / data.name + + model_container = ModelContainer(model_path, False, **data.model_dump()) + load_status = model_container.load_gen(load_progress) + for (module, modules) in load_status: + if module == 0: + loading_bar: IncrementalBar = IncrementalBar("Modules", max = modules) + elif module == modules: + loading_bar.next() + loading_bar.finish() + else: + loading_bar.next() + + yield ModelLoadResponse( + module=module, + modules=modules, + status="processing" + ).model_dump_json() + + yield ModelLoadResponse( + module=module, + modules=modules, + status="finished" + ).model_dump_json() + + return EventSourceResponse(generator()) + +@app.get("/v1/model/unload", dependencies=[Depends(check_admin_key)]) +async def unload_model(): + global model_container + + if model_container is None: + raise HTTPException(400, "No models are loaded.") + + model_container.unload() + model_container = None + +@app.post("/v1/completions", response_class=CompletionResponse, dependencies=[Depends(check_api_key)]) async def generate_completion(request: Request, data: CompletionRequest): if data.stream: async def generator(): @@ -44,31 +100,32 @@ async def generate_completion(request: Request, data: CompletionRequest): return response.model_dump_json() - -# Wrapper callback for load progress -def load_progress(module, modules): - yield module, modules - if __name__ == "__main__": + # Initialize auth keys + load_auth_keys() + # Load from YAML config. Possibly add a config -> kwargs conversion function with open('config.yml', 'r') as config_file: config = yaml.safe_load(config_file) # If an initial model name is specified, create a container and load the model - if config["model_name"]: - model_path = f"{config['model_dir']}/{config['model_name']}" if config['model_dir'] else f"models/{config['model_name']}" - - model_container = ModelContainer(model_path, False, **config) + model_config = config["model"] + if model_config["model_name"]: + model_path = pathlib.Path(model_config["model_dir"] or "models") + model_path = model_path / model_config["model_name"] + + model_container = ModelContainer(model_path, False, **model_config) load_status = model_container.load_gen(load_progress) for (module, modules) in load_status: if module == 0: loading_bar: IncrementalBar = IncrementalBar("Modules", max = modules) + elif module == modules: + loading_bar.next() + loading_bar.finish() else: loading_bar.next() - - if module == modules: - loading_bar.finish() - + print("Model successfully loaded.") - uvicorn.run(app, host="0.0.0.0", port=8012, log_level="debug") + network_config = config["network"] + uvicorn.run(app, host=network_config["host"] or "127.0.0.1", port=network_config["port"] or 8012, log_level="debug") diff --git a/model.py b/model.py index 0434c60..41b51c9 100644 --- a/model.py +++ b/model.py @@ -32,7 +32,7 @@ class ModelContainer: gpu_split_auto: bool = True gpu_split: list or None = None - def __init__(self, model_directory: str, quiet = False, **kwargs): + def __init__(self, model_directory: pathlib.Path, quiet = False, **kwargs): """ Create model container @@ -62,11 +62,11 @@ class ModelContainer: self.quiet = quiet self.cache_fp8 = "cache_mode" in kwargs and kwargs["cache_mode"] == "FP8" - self.gpu_split_auto = kwargs.get("gpu_split_auto", True) self.gpu_split = kwargs.get("gpu_split", None) + self.gpu_split_auto = self.gpu_split == "auto" self.config = ExLlamaV2Config() - self.config.model_dir = model_directory + self.config.model_dir = str(model_directory.resolve()) self.config.prepare() if "max_seq_len" in kwargs: self.config.max_seq_len = kwargs["max_seq_len"] @@ -85,7 +85,7 @@ class ModelContainer: if self.draft_enabled: self.draft_config = ExLlamaV2Config() - self.draft_config.model_dir = kwargs["draft_model_directory"] + self.draft_config.model_dir = kwargs["draft_model_dir"] self.draft_config.prepare() self.draft_config.max_seq_len = self.config.max_seq_len @@ -103,7 +103,7 @@ class ModelContainer: def get_model_path(self): - model_path = pathlib.Path(self.draft_config.model_dir if self.draft_enabled else self.config.model_dir) + model_path = pathlib.Path(self.config.model_dir) return model_path @@ -185,9 +185,12 @@ class ModelContainer: if self.model: self.model.unload() self.model = None + if self.draft_model: self.draft_model.unload() + self.draft_model = None self.config = None self.cache = None self.tokenizer = None + self.generator = None gc.collect() torch.cuda.empty_cache() diff --git a/utils.py b/utils.py new file mode 100644 index 0000000..1fa6283 --- /dev/null +++ b/utils.py @@ -0,0 +1,3 @@ +# Wrapper callback for load progress +def load_progress(module, modules): + yield module, modules \ No newline at end of file