Tree: Refactor code organization

Move common functions into their own folder and refactor the backends
to use their own folder as well.

Also clean up imports and alphabetize the import statements.

Finally, move colab and docker into their own folders as well.

Signed-off-by: kingbri <bdashore3@proton.me>
kingbri 2024-01-18 00:42:52 -05:00 committed by Brian Dashore
parent ee99349a78
commit 78f920eeda
22 changed files with 41 additions and 42 deletions
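
The hunks below all follow the reorganization described in the commit message: shared helpers move under a common package, the ExLlamaV2 backend moves under backends, and each file's import block is alphabetized. A minimal sketch of the resulting import style (the module paths are taken from this diff; the surrounding lines are illustrative only):

# Before this commit: flat, unordered top-level imports
# from utils import unwrap
# from logger import init_logger
# from model import ModelContainer

# After this commit: packaged and alphabetized imports
from backends.exllamav2.model import ExllamaV2Container
from common.logger import init_logger
from common.utils import unwrap

logger = init_logger(__name__)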


@@ -1,7 +1,8 @@
from uuid import uuid4
from time import time
from pydantic import BaseModel, Field
from time import time
from typing import Union, List, Optional, Dict
from uuid import uuid4
from OAI.types.common import UsageStats, CommonCompletionRequest


@@ -1,10 +1,9 @@
""" Completion API protocols """
from pydantic import BaseModel, Field
from time import time
from typing import List, Optional, Union
from uuid import uuid4
from pydantic import BaseModel, Field
from OAI.types.common import CommonCompletionRequest, LogProbs, UsageStats


@@ -1,9 +1,8 @@
""" Lora types """
from pydantic import BaseModel, Field
from time import time
from typing import Optional, List
from pydantic import BaseModel, Field
class LoraCard(BaseModel):
"""Represents a single Lora card."""


@@ -1,10 +1,9 @@
""" Contains model card types. """
from pydantic import BaseModel, Field, ConfigDict
from time import time
from typing import List, Optional
from pydantic import BaseModel, Field, ConfigDict
from gen_logging import LogPreferences
from common.gen_logging import LogPreferences
class ModelCardParameters(BaseModel):


@@ -1,7 +1,6 @@
""" Tokenization types """
from typing import List
from pydantic import BaseModel
from typing import List
class CommonTokenRequest(BaseModel):


@@ -2,6 +2,7 @@
import pathlib
from typing import Optional
from common.utils import unwrap
from OAI.types.chat_completion import (
ChatCompletionMessage,
ChatCompletionRespChoice,
@@ -14,8 +15,6 @@ from OAI.types.common import UsageStats
from OAI.types.lora import LoraList, LoraCard
from OAI.types.model import ModelList, ModelCard
from utils import unwrap
def create_completion_response(
text: str,


@@ -13,17 +13,17 @@ from exllamav2 import (
ExLlamaV2Lora,
)
from exllamav2.generator import ExLlamaV2StreamingGenerator, ExLlamaV2Sampler
from gen_logging import log_generation_params, log_prompt, log_response
from typing import List, Optional, Union
from templating import (
from common.gen_logging import log_generation_params, log_prompt, log_response
from common.templating import (
PromptTemplate,
find_template_from_model,
get_template_from_model_json,
get_template_from_file,
)
from utils import coalesce, unwrap
from logger import init_logger
from common.utils import coalesce, unwrap
from common.logger import init_logger
logger = init_logger(__name__)
@@ -31,7 +31,7 @@ logger = init_logger(__name__)
AUTO_SPLIT_RESERVE_BYTES = 96 * 1024**2
class ModelContainer:
class ExllamaV2Container:
"""The model container class for ExLlamaV2 models."""
config: Optional[ExLlamaV2Config] = None
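
With ModelContainer renamed to ExllamaV2Container and moved into backends/exllamav2/model.py, callers construct the backend as in the main.py hunks further down. A rough sketch of that call pattern (the load option keys and the meaning of the second positional argument are assumptions, not shown in this diff):

import pathlib

from backends.exllamav2.model import ExllamaV2Container

# Hypothetical load options; the real values come from TabbyAPI's config handling
load_options = {"max_seq_len": 4096}
model_path = pathlib.Path("models") / "example-model"

# The second positional argument mirrors the False flag passed in main.py below
container = ExllamaV2Container(model_path.resolve(), False, **load_options)

# Per the entrypoint hunk, load_gen takes a progress callback and yields
# (module, modules) tuples while the model loads
for module, modules in container.load_gen(lambda module, modules: None):
    pass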


@@ -3,13 +3,12 @@ This method of authorization is pretty insecure, but since TabbyAPI is a local
application, it should be fine.
"""
import secrets
from typing import Optional
import yaml
from fastapi import Header, HTTPException
from pydantic import BaseModel
import yaml
from typing import Optional
from logger import init_logger
from common.logger import init_logger
logger = init_logger(__name__)


@@ -1,8 +1,8 @@
import yaml
import pathlib
from logger import init_logger
from utils import unwrap
from common.logger import init_logger
from common.utils import unwrap
logger = init_logger(__name__)


@ -4,7 +4,7 @@ Functions for logging generation events.
from pydantic import BaseModel
from typing import Dict, Optional
from logger import init_logger
from common.logger import init_logger
logger = init_logger(__name__)


@@ -4,7 +4,7 @@ from typing import Optional
from pydantic import BaseModel
from logger import init_logger
from common.logger import init_logger
logger = init_logger(__name__)


@@ -2,7 +2,8 @@ version: '3.8'
services:
tabbyapi:
build:
context: .
context: ..
dockerfile: ./docker/Dockerfile
ports:
- "5000:5000"
environment:

main.py

@@ -11,10 +11,11 @@ from fastapi.responses import StreamingResponse
from functools import partial
from progress.bar import IncrementalBar
import gen_logging
from args import convert_args_to_dict, init_argparser
from auth import check_admin_key, check_api_key, load_auth_keys
from config import (
import common.gen_logging as gen_logging
from backends.exllamav2.model import ExllamaV2Container
from common.args import convert_args_to_dict, init_argparser
from common.auth import check_admin_key, check_api_key, load_auth_keys
from common.config import (
override_config_from_args,
read_config_from_file,
get_gen_logging_config,
@@ -23,8 +24,10 @@ from config import (
get_lora_config,
get_network_config,
)
from generators import call_with_semaphore, generate_with_semaphore
from model import ModelContainer
from common.generators import call_with_semaphore, generate_with_semaphore
from common.templating import get_all_templates, get_prompt_from_template
from common.utils import get_generator_error, get_sse_packet, load_progress, unwrap
from common.logger import init_logger
from OAI.types.completion import CompletionRequest
from OAI.types.chat_completion import ChatCompletionRequest
from OAI.types.lora import LoraCard, LoraList, LoraLoadRequest, LoraLoadResponse
@@ -48,9 +51,6 @@ from OAI.utils_oai import (
create_chat_completion_response,
create_chat_completion_stream_chunk,
)
from templating import get_all_templates, get_prompt_from_template
from utils import get_generator_error, get_sse_packet, load_progress, unwrap
from logger import init_logger
logger = init_logger(__name__)
@@ -64,7 +64,7 @@ app = FastAPI(
)
# Globally scoped variables. Undefined until initialized in main
MODEL_CONTAINER: Optional[ModelContainer] = None
MODEL_CONTAINER: Optional[ExllamaV2Container] = None
def _check_model_container():
@@ -182,7 +182,7 @@ async def load_model(request: Request, data: ModelLoadRequest):
if not model_path.exists():
raise HTTPException(400, "model_path does not exist. Check model_name?")
MODEL_CONTAINER = ModelContainer(model_path.resolve(), False, **load_data)
MODEL_CONTAINER = ExllamaV2Container(model_path.resolve(), False, **load_data)
async def generator():
"""Generator for the loading process."""
@@ -530,7 +530,9 @@ def entrypoint(args: Optional[dict] = None):
model_path = pathlib.Path(unwrap(model_config.get("model_dir"), "models"))
model_path = model_path / model_name
MODEL_CONTAINER = ModelContainer(model_path.resolve(), False, **model_config)
MODEL_CONTAINER = ExllamaV2Container(
model_path.resolve(), False, **model_config
)
load_status = MODEL_CONTAINER.load_gen(load_progress)
for module, modules in load_status:
if module == 0:
@@ -550,6 +552,7 @@ def entrypoint(args: Optional[dict] = None):
host = unwrap(network_config.get("host"), "127.0.0.1")
port = unwrap(network_config.get("port"), 5000)
# TODO: Move OAI API to a separate folder
logger.info(f"Developer documentation: http://{host}:{port}/docs")
logger.info(f"Completions: http://{host}:{port}/v1/completions")
logger.info(f"Chat completions: http://{host}:{port}/v1/chat/completions")


@@ -3,7 +3,7 @@ import argparse
import os
import pathlib
import subprocess
from args import convert_args_to_dict, init_argparser
from common.args import convert_args_to_dict, init_argparser
def get_requirements_file():


@@ -1,5 +1,5 @@
""" Test the model container. """
from model import ModelContainer
from backends.exllamav2.model import ModelContainer
def progress(module, modules):