diff --git a/OAI/types/chat_completion.py b/OAI/types/chat_completion.py index 5e0e80b..ba0a968 100644 --- a/OAI/types/chat_completion.py +++ b/OAI/types/chat_completion.py @@ -1,7 +1,8 @@ -from uuid import uuid4 -from time import time from pydantic import BaseModel, Field +from time import time from typing import Union, List, Optional, Dict +from uuid import uuid4 + from OAI.types.common import UsageStats, CommonCompletionRequest diff --git a/OAI/types/completion.py b/OAI/types/completion.py index 15e84a7..4fa380c 100644 --- a/OAI/types/completion.py +++ b/OAI/types/completion.py @@ -1,10 +1,9 @@ """ Completion API protocols """ +from pydantic import BaseModel, Field from time import time from typing import List, Optional, Union from uuid import uuid4 -from pydantic import BaseModel, Field - from OAI.types.common import CommonCompletionRequest, LogProbs, UsageStats diff --git a/OAI/types/lora.py b/OAI/types/lora.py index 841c3a8..018bf06 100644 --- a/OAI/types/lora.py +++ b/OAI/types/lora.py @@ -1,9 +1,8 @@ """ Lora types """ +from pydantic import BaseModel, Field from time import time from typing import Optional, List -from pydantic import BaseModel, Field - class LoraCard(BaseModel): """Represents a single Lora card.""" diff --git a/OAI/types/model.py b/OAI/types/model.py index 483c41f..9096d41 100644 --- a/OAI/types/model.py +++ b/OAI/types/model.py @@ -1,10 +1,9 @@ """ Contains model card types. """ +from pydantic import BaseModel, Field, ConfigDict from time import time from typing import List, Optional -from pydantic import BaseModel, Field, ConfigDict - -from gen_logging import LogPreferences +from common.gen_logging import LogPreferences class ModelCardParameters(BaseModel): diff --git a/OAI/types/token.py b/OAI/types/token.py index 98cbc98..8467aa6 100644 --- a/OAI/types/token.py +++ b/OAI/types/token.py @@ -1,7 +1,6 @@ """ Tokenization types """ -from typing import List - from pydantic import BaseModel +from typing import List class CommonTokenRequest(BaseModel): diff --git a/OAI/utils_oai.py b/OAI/utils_oai.py index b3c59d6..5ad2873 100644 --- a/OAI/utils_oai.py +++ b/OAI/utils_oai.py @@ -2,6 +2,7 @@ import pathlib from typing import Optional +from common.utils import unwrap from OAI.types.chat_completion import ( ChatCompletionMessage, ChatCompletionRespChoice, @@ -14,8 +15,6 @@ from OAI.types.common import UsageStats from OAI.types.lora import LoraList, LoraCard from OAI.types.model import ModelList, ModelCard -from utils import unwrap - def create_completion_response( text: str, diff --git a/model.py b/backends/exllamav2/model.py similarity index 99% rename from model.py rename to backends/exllamav2/model.py index cef53c3..6ee186b 100644 --- a/model.py +++ b/backends/exllamav2/model.py @@ -13,17 +13,17 @@ from exllamav2 import ( ExLlamaV2Lora, ) from exllamav2.generator import ExLlamaV2StreamingGenerator, ExLlamaV2Sampler - -from gen_logging import log_generation_params, log_prompt, log_response from typing import List, Optional, Union -from templating import ( + +from common.gen_logging import log_generation_params, log_prompt, log_response +from common.templating import ( PromptTemplate, find_template_from_model, get_template_from_model_json, get_template_from_file, ) -from utils import coalesce, unwrap -from logger import init_logger +from common.utils import coalesce, unwrap +from common.logger import init_logger logger = init_logger(__name__) @@ -31,7 +31,7 @@ logger = init_logger(__name__) AUTO_SPLIT_RESERVE_BYTES = 96 * 1024**2 -class ModelContainer: +class ExllamaV2Container: """The model container class for ExLlamaV2 models.""" config: Optional[ExLlamaV2Config] = None diff --git a/TabbyAPI_Colab_Example.ipynb b/colab/TabbyAPI_Colab_Example.ipynb similarity index 100% rename from TabbyAPI_Colab_Example.ipynb rename to colab/TabbyAPI_Colab_Example.ipynb diff --git a/args.py b/common/args.py similarity index 100% rename from args.py rename to common/args.py diff --git a/auth.py b/common/auth.py similarity index 99% rename from auth.py rename to common/auth.py index 4185ddb..ea42168 100644 --- a/auth.py +++ b/common/auth.py @@ -3,13 +3,12 @@ This method of authorization is pretty insecure, but since TabbyAPI is a local application, it should be fine. """ import secrets -from typing import Optional - +import yaml from fastapi import Header, HTTPException from pydantic import BaseModel -import yaml +from typing import Optional -from logger import init_logger +from common.logger import init_logger logger = init_logger(__name__) diff --git a/config.py b/common/config.py similarity index 97% rename from config.py rename to common/config.py index 178977b..e46be62 100644 --- a/config.py +++ b/common/config.py @@ -1,8 +1,8 @@ import yaml import pathlib -from logger import init_logger -from utils import unwrap +from common.logger import init_logger +from common.utils import unwrap logger = init_logger(__name__) diff --git a/gen_logging.py b/common/gen_logging.py similarity index 98% rename from gen_logging.py rename to common/gen_logging.py index a82cea3..a20e45c 100644 --- a/gen_logging.py +++ b/common/gen_logging.py @@ -4,7 +4,7 @@ Functions for logging generation events. from pydantic import BaseModel from typing import Dict, Optional -from logger import init_logger +from common.logger import init_logger logger = init_logger(__name__) diff --git a/generators.py b/common/generators.py similarity index 100% rename from generators.py rename to common/generators.py diff --git a/logger.py b/common/logger.py similarity index 100% rename from logger.py rename to common/logger.py diff --git a/templating.py b/common/templating.py similarity index 100% rename from templating.py rename to common/templating.py diff --git a/utils.py b/common/utils.py similarity index 97% rename from utils.py rename to common/utils.py index 529afe0..2db97e9 100644 --- a/utils.py +++ b/common/utils.py @@ -4,7 +4,7 @@ from typing import Optional from pydantic import BaseModel -from logger import init_logger +from common.logger import init_logger logger = init_logger(__name__) diff --git a/.dockerignore b/docker/.dockerignore similarity index 100% rename from .dockerignore rename to docker/.dockerignore diff --git a/Dockerfile b/docker/Dockerfile similarity index 100% rename from Dockerfile rename to docker/Dockerfile diff --git a/docker-compose.yml b/docker/docker-compose.yml similarity index 86% rename from docker-compose.yml rename to docker/docker-compose.yml index c553612..d50682e 100644 --- a/docker-compose.yml +++ b/docker/docker-compose.yml @@ -2,7 +2,8 @@ version: '3.8' services: tabbyapi: build: - context: . + context: .. + dockerfile: ./docker/Dockerfile ports: - "5000:5000" environment: diff --git a/main.py b/main.py index bbc7ad8..0ba9650 100644 --- a/main.py +++ b/main.py @@ -11,10 +11,11 @@ from fastapi.responses import StreamingResponse from functools import partial from progress.bar import IncrementalBar -import gen_logging -from args import convert_args_to_dict, init_argparser -from auth import check_admin_key, check_api_key, load_auth_keys -from config import ( +import common.gen_logging as gen_logging +from backends.exllamav2.model import ExllamaV2Container +from common.args import convert_args_to_dict, init_argparser +from common.auth import check_admin_key, check_api_key, load_auth_keys +from common.config import ( override_config_from_args, read_config_from_file, get_gen_logging_config, @@ -23,8 +24,10 @@ from config import ( get_lora_config, get_network_config, ) -from generators import call_with_semaphore, generate_with_semaphore -from model import ModelContainer +from common.generators import call_with_semaphore, generate_with_semaphore +from common.templating import get_all_templates, get_prompt_from_template +from common.utils import get_generator_error, get_sse_packet, load_progress, unwrap +from common.logger import init_logger from OAI.types.completion import CompletionRequest from OAI.types.chat_completion import ChatCompletionRequest from OAI.types.lora import LoraCard, LoraList, LoraLoadRequest, LoraLoadResponse @@ -48,9 +51,6 @@ from OAI.utils_oai import ( create_chat_completion_response, create_chat_completion_stream_chunk, ) -from templating import get_all_templates, get_prompt_from_template -from utils import get_generator_error, get_sse_packet, load_progress, unwrap -from logger import init_logger logger = init_logger(__name__) @@ -64,7 +64,7 @@ app = FastAPI( ) # Globally scoped variables. Undefined until initalized in main -MODEL_CONTAINER: Optional[ModelContainer] = None +MODEL_CONTAINER: Optional[ExllamaV2Container] = None def _check_model_container(): @@ -182,7 +182,7 @@ async def load_model(request: Request, data: ModelLoadRequest): if not model_path.exists(): raise HTTPException(400, "model_path does not exist. Check model_name?") - MODEL_CONTAINER = ModelContainer(model_path.resolve(), False, **load_data) + MODEL_CONTAINER = ExllamaV2Container(model_path.resolve(), False, **load_data) async def generator(): """Generator for the loading process.""" @@ -530,7 +530,9 @@ def entrypoint(args: Optional[dict] = None): model_path = pathlib.Path(unwrap(model_config.get("model_dir"), "models")) model_path = model_path / model_name - MODEL_CONTAINER = ModelContainer(model_path.resolve(), False, **model_config) + MODEL_CONTAINER = ExllamaV2Container( + model_path.resolve(), False, **model_config + ) load_status = MODEL_CONTAINER.load_gen(load_progress) for module, modules in load_status: if module == 0: @@ -550,6 +552,7 @@ def entrypoint(args: Optional[dict] = None): host = unwrap(network_config.get("host"), "127.0.0.1") port = unwrap(network_config.get("port"), 5000) + # TODO: Move OAI API to a separate folder logger.info(f"Developer documentation: http://{host}:{port}/docs") logger.info(f"Completions: http://{host}:{port}/v1/completions") logger.info(f"Chat completions: http://{host}:{port}/v1/chat/completions") diff --git a/start.py b/start.py index c4b9334..9bfae3e 100644 --- a/start.py +++ b/start.py @@ -3,7 +3,7 @@ import argparse import os import pathlib import subprocess -from args import convert_args_to_dict, init_argparser +from common.args import convert_args_to_dict, init_argparser def get_requirements_file(): diff --git a/tests/model_test.py b/tests/model_test.py index b4ac158..b47449e 100644 --- a/tests/model_test.py +++ b/tests/model_test.py @@ -1,5 +1,5 @@ """ Test the model container. """ -from model import ModelContainer +from backends.exllamav2.model import ModelContainer def progress(module, modules):