diff --git a/OAI/models/models.py b/OAI/models/models.py
deleted file mode 100644
index 44a2f26..0000000
--- a/OAI/models/models.py
+++ /dev/null
@@ -1,13 +0,0 @@
-from pydantic import BaseModel, Field
-from time import time
-from typing import List
-
-class ModelCard(BaseModel):
-    id: str
-    object: str = "model"
-    created: int = Field(default_factory=lambda: int(time()))
-    owned_by: str = "tabbyAPI"
-
-class ModelList(BaseModel):
-    object: str = "list"
-    data: List[ModelCard] = Field(default_factory=list)
diff --git a/OAI/models/common.py b/OAI/types/common.py
similarity index 100%
rename from OAI/models/common.py
rename to OAI/types/common.py
diff --git a/OAI/models/completions.py b/OAI/types/completions.py
similarity index 98%
rename from OAI/models/completions.py
rename to OAI/types/completions.py
index ba107f6..cd95ff5 100644
--- a/OAI/models/completions.py
+++ b/OAI/types/completions.py
@@ -2,7 +2,7 @@ from uuid import uuid4
 from time import time
 from pydantic import BaseModel, Field
 from typing import List, Optional, Dict, Union
-from OAI.models.common import LogProbs, UsageStats
+from OAI.types.common import LogProbs, UsageStats
 
 class CompletionRespChoice(BaseModel):
     finish_reason: str
diff --git a/OAI/types/models.py b/OAI/types/models.py
new file mode 100644
index 0000000..1a2db7f
--- /dev/null
+++ b/OAI/types/models.py
@@ -0,0 +1,27 @@
+from pydantic import BaseModel, Field
+from time import time
+from typing import List, Optional
+
+class ModelCard(BaseModel):
+    id: str = "test"
+    object: str = "model"
+    created: int = Field(default_factory=lambda: int(time()))
+    owned_by: str = "tabbyAPI"
+
+class ModelList(BaseModel):
+    object: str = "list"
+    data: List[ModelCard] = Field(default_factory=list)
+
+class ModelLoadRequest(BaseModel):
+    name: str
+    max_seq_len: Optional[int] = 4096
+    gpu_split: Optional[str] = "auto"
+    rope_scale: Optional[float] = 1.0
+    rope_alpha: Optional[float] = 1.0
+    no_flash_attention: Optional[bool] = False
+    low_mem: Optional[bool] = False
+
+class ModelLoadResponse(BaseModel):
+    module: int
+    modules: int
+    status: str
diff --git a/OAI/utils.py b/OAI/utils.py
index 8b76d39..c609057 100644
--- a/OAI/utils.py
+++ b/OAI/utils.py
@@ -1,7 +1,7 @@
 import pathlib
-from OAI.models.completions import CompletionResponse, CompletionRespChoice
-from OAI.models.common import UsageStats
-from OAI.models.models import ModelList, ModelCard
+from OAI.types.completions import CompletionResponse, CompletionRespChoice
+from OAI.types.common import UsageStats
+from OAI.types.models import ModelList, ModelCard
 from typing import Optional
 
 def create_completion_response(text: str, index: int, model_name: Optional[str]):
diff --git a/api_tokens.yml b/api_tokens.yml
new file mode 100644
index 0000000..ed8a91f
--- /dev/null
+++ b/api_tokens.yml
@@ -0,0 +1,3 @@
+!!python/object:auth.AuthKeys
+admin_key: 5b9e30a4197557dcd6cf48445ee174dc
+api_key: 2261702e8a220c6c4671a264cd1236ce
diff --git a/auth.py b/auth.py
new file mode 100644
index 0000000..38e3a48
--- /dev/null
+++ b/auth.py
@@ -0,0 +1,54 @@
+import secrets
+import yaml
+from fastapi import Header, HTTPException
+from typing import Optional
+
+"""
+This method of authorization is pretty insecure, but since TabbyAPI is a local
+application, it should be fine.
+"""
+
+class AuthKeys:
+    api_key: str
+    admin_key: str
+
+    def __init__(self, api_key: str, admin_key: str):
+        self.api_key = api_key
+        self.admin_key = admin_key
+
+auth_keys: Optional[AuthKeys] = None
+
+def load_auth_keys():
+    global auth_keys
+    try:
+        with open("api_tokens.yml", "r") as auth_file:
+            auth_keys = yaml.safe_load(auth_file)
+    except:
+        new_auth_keys = AuthKeys(
+            api_key = secrets.token_hex(16),
+            admin_key = secrets.token_hex(16)
+        )
+        auth_keys = new_auth_keys
+
+        with open("api_tokens.yml", "w") as auth_file:
+            yaml.dump(auth_keys, auth_file)
+
+def check_api_key(x_api_key: str = Header(None), authorization: str = Header(None)):
+    if x_api_key and x_api_key == auth_keys.api_key:
+        return x_api_key
+    elif authorization:
+        split_key = authorization.split(" ")
+        if split_key[0].lower() == "bearer" and split_key[1] == auth_keys.api_key:
+            return authorization
+    else:
+        raise HTTPException(401, "Invalid API key")
+
+def check_admin_key(x_admin_key: str = Header(None), authorization: str = Header(None)):
+    if x_admin_key and x_admin_key == auth_keys.admin_key:
+        return x_admin_key
+    elif authorization:
+        split_key = authorization.split(" ")
+        if split_key[0].lower() == "bearer" and split_key[1] == auth_keys.admin_key:
+            return authorization
+    else:
+        raise HTTPException(401, "Invalid admin key")
diff --git a/config_sample.yml b/config_sample.yml
index c18197d..f397cd9 100644
--- a/config_sample.yml
+++ b/config_sample.yml
@@ -1,8 +1,14 @@
-model_dir: "D:/models"
-model_name: "this_is_a_exl2_model"
-max_seq_len: 4096
-gpu_split: "auto"
-rope_scale: 1.0
-rope_alpha: 1.0
-no_flash_attention: False
-low_mem: False
+# Network options
+network:
+  host: "0.0.0.0"
+  port: 8012
+# Only used if you want to initially load a model
+model:
+  model_dir: "D:/models"
+  model_name: "airoboros-mistral2.2-7b-exl2"
+  max_seq_len: 4096
+  gpu_split: "auto"
+  rope_scale: 1.0
+  rope_alpha: 1.0
+  no_flash_attention: False
+  low_mem: False
diff --git a/main.py b/main.py
index 38c7979..c0bb54c 100644
--- a/main.py
+++ b/main.py
@@ -1,30 +1,86 @@
 import uvicorn
 import yaml
-from fastapi import FastAPI, Request
+import pathlib
+from auth import check_admin_key, check_api_key, load_auth_keys
+from fastapi import FastAPI, Request, HTTPException, Depends
 from model import ModelContainer
 from progress.bar import IncrementalBar
 from sse_starlette import EventSourceResponse
-from OAI.models.completions import CompletionRequest, CompletionResponse
-from OAI.models.models import ModelCard, ModelList
+from OAI.types.completions import CompletionRequest, CompletionResponse
+from OAI.types.models import ModelCard, ModelList, ModelLoadRequest, ModelLoadResponse
 from OAI.utils import create_completion_response, get_model_list
+from typing import Optional
+from utils import load_progress
 
 app = FastAPI()
 
-# Initialize a model container. This can be undefined at any period of time
-model_container: ModelContainer = None
+# Globally scoped variables. Undefined until initialized in main
+model_container: Optional[ModelContainer] = None
+config: Optional[dict] = None
 
-@app.get("/v1/models")
-@app.get("/v1/model/list")
+@app.get("/v1/models", dependencies=[Depends(check_api_key)])
+@app.get("/v1/model/list", dependencies=[Depends(check_api_key)])
 async def list_models():
-    models = get_model_list(model_container.get_model_path())
+    model_config = config["model"]
+    models = get_model_list(pathlib.Path(model_config["model_dir"] or "models"))
 
     return models.model_dump_json()
 
-@app.get("/v1/model")
+@app.get("/v1/model", dependencies=[Depends(check_api_key)])
 async def get_current_model():
-    return ModelCard(id = model_container.get_model_path().name)
+    if model_container is None or model_container.model is None:
+        return HTTPException(400, "No models are loaded.")
 
-@app.post("/v1/completions", response_class=CompletionResponse)
+    model_card = ModelCard(id=model_container.get_model_path().name)
+    return model_card.model_dump_json()
+
+@app.post("/v1/model/load", response_class=ModelLoadResponse, dependencies=[Depends(check_admin_key)])
+async def load_model(data: ModelLoadRequest):
+    if model_container and model_container.model:
+        raise HTTPException(400, "A model is already loaded! Please unload it first.")
+
+    def generator():
+        global model_container
+        model_config = config["model"]
+        model_path = pathlib.Path(model_config["model_dir"] or "models")
+        model_path = model_path / data.name
+
+        model_container = ModelContainer(model_path, False, **data.model_dump())
+        load_status = model_container.load_gen(load_progress)
+        for (module, modules) in load_status:
+            if module == 0:
+                loading_bar: IncrementalBar = IncrementalBar("Modules", max = modules)
+            elif module == modules:
+                loading_bar.next()
+                loading_bar.finish()
+            else:
+                loading_bar.next()
+
+            yield ModelLoadResponse(
+                module=module,
+                modules=modules,
+                status="processing"
+            ).model_dump_json()
+
+        yield ModelLoadResponse(
+            module=module,
+            modules=modules,
+            status="finished"
+        ).model_dump_json()
+
+    return EventSourceResponse(generator())
+
+@app.get("/v1/model/unload", dependencies=[Depends(check_admin_key)])
+async def unload_model():
+    global model_container
+
+    if model_container is None:
+        raise HTTPException(400, "No models are loaded.")
+
+    model_container.unload()
+    model_container = None
+
+@app.post("/v1/completions", response_class=CompletionResponse, dependencies=[Depends(check_api_key)])
 async def generate_completion(request: Request, data: CompletionRequest):
     if data.stream:
         async def generator():
@@ -44,31 +100,32 @@ async def generate_completion(request: Request, data: CompletionRequest):
 
         return response.model_dump_json()
 
-
-# Wrapper callback for load progress
-def load_progress(module, modules):
-    yield module, modules
-
 if __name__ == "__main__":
+    # Initialize auth keys
+    load_auth_keys()
+
     # Load from YAML config. Possibly add a config -> kwargs conversion function
     with open('config.yml', 'r') as config_file:
         config = yaml.safe_load(config_file)
 
     # If an initial model name is specified, create a container and load the model
-    if config["model_name"]:
-        model_path = f"{config['model_dir']}/{config['model_name']}" if config['model_dir'] else f"models/{config['model_name']}"
-
-        model_container = ModelContainer(model_path, False, **config)
+    model_config = config["model"]
+    if model_config["model_name"]:
+        model_path = pathlib.Path(model_config["model_dir"] or "models")
+        model_path = model_path / model_config["model_name"]
+
+        model_container = ModelContainer(model_path, False, **model_config)
         load_status = model_container.load_gen(load_progress)
        for (module, modules) in load_status:
             if module == 0:
                 loading_bar: IncrementalBar = IncrementalBar("Modules", max = modules)
+            elif module == modules:
+                loading_bar.next()
+                loading_bar.finish()
             else:
                 loading_bar.next()
-
-                if module == modules:
-                    loading_bar.finish()
-
+
         print("Model successfully loaded.")
 
-    uvicorn.run(app, host="0.0.0.0", port=8012, log_level="debug")
+    network_config = config["network"]
+    uvicorn.run(app, host=network_config["host"] or "127.0.0.1", port=network_config["port"] or 8012, log_level="debug")
diff --git a/model.py b/model.py
index 0434c60..41b51c9 100644
--- a/model.py
+++ b/model.py
@@ -32,7 +32,7 @@ class ModelContainer:
     gpu_split_auto: bool = True
     gpu_split: list or None = None
 
-    def __init__(self, model_directory: str, quiet = False, **kwargs):
+    def __init__(self, model_directory: pathlib.Path, quiet = False, **kwargs):
         """
         Create model container
 
@@ -62,11 +62,11 @@ class ModelContainer:
         self.quiet = quiet
         self.cache_fp8 = "cache_mode" in kwargs and kwargs["cache_mode"] == "FP8"
 
-        self.gpu_split_auto = kwargs.get("gpu_split_auto", True)
         self.gpu_split = kwargs.get("gpu_split", None)
+        self.gpu_split_auto = self.gpu_split == "auto"
 
         self.config = ExLlamaV2Config()
-        self.config.model_dir = model_directory
+        self.config.model_dir = str(model_directory.resolve())
         self.config.prepare()
 
         if "max_seq_len" in kwargs: self.config.max_seq_len = kwargs["max_seq_len"]
@@ -85,7 +85,7 @@ class ModelContainer:
 
         if self.draft_enabled:
             self.draft_config = ExLlamaV2Config()
-            self.draft_config.model_dir = kwargs["draft_model_directory"]
+            self.draft_config.model_dir = kwargs["draft_model_dir"]
             self.draft_config.prepare()
 
             self.draft_config.max_seq_len = self.config.max_seq_len
@@ -103,7 +103,7 @@ class ModelContainer:
 
 
     def get_model_path(self):
-        model_path = pathlib.Path(self.draft_config.model_dir if self.draft_enabled else self.config.model_dir)
+        model_path = pathlib.Path(self.config.model_dir)
         return model_path
 
 
@@ -185,9 +185,12 @@ class ModelContainer:
 
         if self.model: self.model.unload()
         self.model = None
+
        if self.draft_model: self.draft_model.unload()
+        self.draft_model = None
         self.config = None
         self.cache = None
         self.tokenizer = None
+        self.generator = None
         gc.collect()
         torch.cuda.empty_cache()
diff --git a/utils.py b/utils.py
new file mode 100644
index 0000000..1fa6283
--- /dev/null
+++ b/utils.py
@@ -0,0 +1,3 @@
+# Wrapper callback for load progress
+def load_progress(module, modules):
+    yield module, modules
\ No newline at end of file
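A note on the key file format: load_auth_keys() persists the AuthKeys instance with yaml.dump, which is what produces the !!python/object:auth.AuthKeys tag seen in api_tokens.yml above, and yaml.safe_load refuses python/object tags. The sketch below is illustrative only, not part of the patch; it assumes PyYAML is installed and that auth.py is importable from the working directory, and reads the generated file back with an object-constructing loader.

import yaml

# Illustrative sketch: read the generated api_tokens.yml back outside the app.
# The file was written via yaml.dump(AuthKeys(...)), hence the
# !!python/object:auth.AuthKeys tag; an object-constructing loader is needed,
# and the auth module must be importable so the tag can be resolved.
with open("api_tokens.yml", "r") as auth_file:
    keys = yaml.load(auth_file, Loader=yaml.UnsafeLoader)

print(keys.api_key)
print(keys.admin_key)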
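For reference, a minimal client-side sketch of the two auth paths accepted by check_api_key and check_admin_key. It is not part of the patch and assumes the requests library, a server listening on the host/port from config_sample.yml, and keys copied out of api_tokens.yml; with FastAPI's Header(None) parameters, x_api_key and x_admin_key map to the X-Api-Key and X-Admin-Key headers, while the bearer form goes through the standard Authorization header.

import requests

BASE_URL = "http://127.0.0.1:8012"               # host/port assumed from config_sample.yml
API_KEY = "<api_key from api_tokens.yml>"        # placeholder
ADMIN_KEY = "<admin_key from api_tokens.yml>"    # placeholder

# check_api_key path 1: dedicated X-Api-Key header
resp = requests.get(f"{BASE_URL}/v1/models", headers={"X-Api-Key": API_KEY})
print(resp.status_code, resp.text)

# check_api_key path 2: "Authorization: Bearer <key>"
resp = requests.get(f"{BASE_URL}/v1/model", headers={"Authorization": f"Bearer {API_KEY}"})
print(resp.status_code, resp.text)

# Admin-only routes are gated by check_admin_key instead
resp = requests.get(f"{BASE_URL}/v1/model/unload", headers={"X-Admin-Key": ADMIN_KEY})
print(resp.status_code, resp.text)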
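The /v1/model/load route streams ModelLoadResponse payloads over SSE: one "processing" event per loaded module and a final "finished" event. A rough client sketch under the same assumptions as above (requests installed, server on 127.0.0.1:8012, admin key at hand); the SSE handling here is the bare minimum for the "data:" framing used by sse-starlette, and the model name is simply the one from config_sample.yml.

import json
import requests

BASE_URL = "http://127.0.0.1:8012"
HEADERS = {"X-Admin-Key": "<admin_key from api_tokens.yml>"}

# Body fields mirror ModelLoadRequest; unset fields fall back to its defaults.
payload = {"name": "airoboros-mistral2.2-7b-exl2", "max_seq_len": 4096, "gpu_split": "auto"}

with requests.post(f"{BASE_URL}/v1/model/load", json=payload, headers=HEADERS, stream=True) as resp:
    for line in resp.iter_lines():
        if not line or not line.startswith(b"data:"):
            continue  # skip blank separators and any non-data SSE fields
        event = json.loads(line[len(b"data:"):].strip())
        print(f"{event['module']}/{event['modules']} {event['status']}")

# Unloading goes through the same admin gate.
requests.get(f"{BASE_URL}/v1/model/unload", headers=HEADERS)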