Tree: Refactor code organization
Move common functions into their own folder and refactor the backends to use their own folder as well. Also cleanup imports and alphabetize import statments themselves. Finally, move colab and docker into their own folders as well. Signed-off-by: kingbri <bdashore3@proton.me>
This commit is contained in:
parent
ee99349a78
commit
78f920eeda
22 changed files with 41 additions and 42 deletions
|
|
@ -1,7 +1,8 @@
|
|||
from uuid import uuid4
|
||||
from time import time
|
||||
from pydantic import BaseModel, Field
|
||||
from time import time
|
||||
from typing import Union, List, Optional, Dict
|
||||
from uuid import uuid4
|
||||
|
||||
from OAI.types.common import UsageStats, CommonCompletionRequest
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -1,10 +1,9 @@
|
|||
""" Completion API protocols """
|
||||
from pydantic import BaseModel, Field
|
||||
from time import time
|
||||
from typing import List, Optional, Union
|
||||
from uuid import uuid4
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from OAI.types.common import CommonCompletionRequest, LogProbs, UsageStats
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -1,9 +1,8 @@
|
|||
""" Lora types """
|
||||
from pydantic import BaseModel, Field
|
||||
from time import time
|
||||
from typing import Optional, List
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class LoraCard(BaseModel):
|
||||
"""Represents a single Lora card."""
|
||||
|
|
|
|||
|
|
@ -1,10 +1,9 @@
|
|||
""" Contains model card types. """
|
||||
from pydantic import BaseModel, Field, ConfigDict
|
||||
from time import time
|
||||
from typing import List, Optional
|
||||
|
||||
from pydantic import BaseModel, Field, ConfigDict
|
||||
|
||||
from gen_logging import LogPreferences
|
||||
from common.gen_logging import LogPreferences
|
||||
|
||||
|
||||
class ModelCardParameters(BaseModel):
|
||||
|
|
|
|||
|
|
@ -1,7 +1,6 @@
|
|||
""" Tokenization types """
|
||||
from typing import List
|
||||
|
||||
from pydantic import BaseModel
|
||||
from typing import List
|
||||
|
||||
|
||||
class CommonTokenRequest(BaseModel):
|
||||
|
|
|
|||
|
|
@ -2,6 +2,7 @@
|
|||
import pathlib
|
||||
from typing import Optional
|
||||
|
||||
from common.utils import unwrap
|
||||
from OAI.types.chat_completion import (
|
||||
ChatCompletionMessage,
|
||||
ChatCompletionRespChoice,
|
||||
|
|
@ -14,8 +15,6 @@ from OAI.types.common import UsageStats
|
|||
from OAI.types.lora import LoraList, LoraCard
|
||||
from OAI.types.model import ModelList, ModelCard
|
||||
|
||||
from utils import unwrap
|
||||
|
||||
|
||||
def create_completion_response(
|
||||
text: str,
|
||||
|
|
|
|||
|
|
@ -13,17 +13,17 @@ from exllamav2 import (
|
|||
ExLlamaV2Lora,
|
||||
)
|
||||
from exllamav2.generator import ExLlamaV2StreamingGenerator, ExLlamaV2Sampler
|
||||
|
||||
from gen_logging import log_generation_params, log_prompt, log_response
|
||||
from typing import List, Optional, Union
|
||||
from templating import (
|
||||
|
||||
from common.gen_logging import log_generation_params, log_prompt, log_response
|
||||
from common.templating import (
|
||||
PromptTemplate,
|
||||
find_template_from_model,
|
||||
get_template_from_model_json,
|
||||
get_template_from_file,
|
||||
)
|
||||
from utils import coalesce, unwrap
|
||||
from logger import init_logger
|
||||
from common.utils import coalesce, unwrap
|
||||
from common.logger import init_logger
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
|
@ -31,7 +31,7 @@ logger = init_logger(__name__)
|
|||
AUTO_SPLIT_RESERVE_BYTES = 96 * 1024**2
|
||||
|
||||
|
||||
class ModelContainer:
|
||||
class ExllamaV2Container:
|
||||
"""The model container class for ExLlamaV2 models."""
|
||||
|
||||
config: Optional[ExLlamaV2Config] = None
|
||||
|
|
@ -3,13 +3,12 @@ This method of authorization is pretty insecure, but since TabbyAPI is a local
|
|||
application, it should be fine.
|
||||
"""
|
||||
import secrets
|
||||
from typing import Optional
|
||||
|
||||
import yaml
|
||||
from fastapi import Header, HTTPException
|
||||
from pydantic import BaseModel
|
||||
import yaml
|
||||
from typing import Optional
|
||||
|
||||
from logger import init_logger
|
||||
from common.logger import init_logger
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
|
@ -1,8 +1,8 @@
|
|||
import yaml
|
||||
import pathlib
|
||||
|
||||
from logger import init_logger
|
||||
from utils import unwrap
|
||||
from common.logger import init_logger
|
||||
from common.utils import unwrap
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
|
@ -4,7 +4,7 @@ Functions for logging generation events.
|
|||
from pydantic import BaseModel
|
||||
from typing import Dict, Optional
|
||||
|
||||
from logger import init_logger
|
||||
from common.logger import init_logger
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
|
@ -4,7 +4,7 @@ from typing import Optional
|
|||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from logger import init_logger
|
||||
from common.logger import init_logger
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
|
@ -2,7 +2,8 @@ version: '3.8'
|
|||
services:
|
||||
tabbyapi:
|
||||
build:
|
||||
context: .
|
||||
context: ..
|
||||
dockerfile: ./docker/Dockerfile
|
||||
ports:
|
||||
- "5000:5000"
|
||||
environment:
|
||||
27
main.py
27
main.py
|
|
@ -11,10 +11,11 @@ from fastapi.responses import StreamingResponse
|
|||
from functools import partial
|
||||
from progress.bar import IncrementalBar
|
||||
|
||||
import gen_logging
|
||||
from args import convert_args_to_dict, init_argparser
|
||||
from auth import check_admin_key, check_api_key, load_auth_keys
|
||||
from config import (
|
||||
import common.gen_logging as gen_logging
|
||||
from backends.exllamav2.model import ExllamaV2Container
|
||||
from common.args import convert_args_to_dict, init_argparser
|
||||
from common.auth import check_admin_key, check_api_key, load_auth_keys
|
||||
from common.config import (
|
||||
override_config_from_args,
|
||||
read_config_from_file,
|
||||
get_gen_logging_config,
|
||||
|
|
@ -23,8 +24,10 @@ from config import (
|
|||
get_lora_config,
|
||||
get_network_config,
|
||||
)
|
||||
from generators import call_with_semaphore, generate_with_semaphore
|
||||
from model import ModelContainer
|
||||
from common.generators import call_with_semaphore, generate_with_semaphore
|
||||
from common.templating import get_all_templates, get_prompt_from_template
|
||||
from common.utils import get_generator_error, get_sse_packet, load_progress, unwrap
|
||||
from common.logger import init_logger
|
||||
from OAI.types.completion import CompletionRequest
|
||||
from OAI.types.chat_completion import ChatCompletionRequest
|
||||
from OAI.types.lora import LoraCard, LoraList, LoraLoadRequest, LoraLoadResponse
|
||||
|
|
@ -48,9 +51,6 @@ from OAI.utils_oai import (
|
|||
create_chat_completion_response,
|
||||
create_chat_completion_stream_chunk,
|
||||
)
|
||||
from templating import get_all_templates, get_prompt_from_template
|
||||
from utils import get_generator_error, get_sse_packet, load_progress, unwrap
|
||||
from logger import init_logger
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
|
@ -64,7 +64,7 @@ app = FastAPI(
|
|||
)
|
||||
|
||||
# Globally scoped variables. Undefined until initalized in main
|
||||
MODEL_CONTAINER: Optional[ModelContainer] = None
|
||||
MODEL_CONTAINER: Optional[ExllamaV2Container] = None
|
||||
|
||||
|
||||
def _check_model_container():
|
||||
|
|
@ -182,7 +182,7 @@ async def load_model(request: Request, data: ModelLoadRequest):
|
|||
if not model_path.exists():
|
||||
raise HTTPException(400, "model_path does not exist. Check model_name?")
|
||||
|
||||
MODEL_CONTAINER = ModelContainer(model_path.resolve(), False, **load_data)
|
||||
MODEL_CONTAINER = ExllamaV2Container(model_path.resolve(), False, **load_data)
|
||||
|
||||
async def generator():
|
||||
"""Generator for the loading process."""
|
||||
|
|
@ -530,7 +530,9 @@ def entrypoint(args: Optional[dict] = None):
|
|||
model_path = pathlib.Path(unwrap(model_config.get("model_dir"), "models"))
|
||||
model_path = model_path / model_name
|
||||
|
||||
MODEL_CONTAINER = ModelContainer(model_path.resolve(), False, **model_config)
|
||||
MODEL_CONTAINER = ExllamaV2Container(
|
||||
model_path.resolve(), False, **model_config
|
||||
)
|
||||
load_status = MODEL_CONTAINER.load_gen(load_progress)
|
||||
for module, modules in load_status:
|
||||
if module == 0:
|
||||
|
|
@ -550,6 +552,7 @@ def entrypoint(args: Optional[dict] = None):
|
|||
host = unwrap(network_config.get("host"), "127.0.0.1")
|
||||
port = unwrap(network_config.get("port"), 5000)
|
||||
|
||||
# TODO: Move OAI API to a separate folder
|
||||
logger.info(f"Developer documentation: http://{host}:{port}/docs")
|
||||
logger.info(f"Completions: http://{host}:{port}/v1/completions")
|
||||
logger.info(f"Chat completions: http://{host}:{port}/v1/chat/completions")
|
||||
|
|
|
|||
2
start.py
2
start.py
|
|
@ -3,7 +3,7 @@ import argparse
|
|||
import os
|
||||
import pathlib
|
||||
import subprocess
|
||||
from args import convert_args_to_dict, init_argparser
|
||||
from common.args import convert_args_to_dict, init_argparser
|
||||
|
||||
|
||||
def get_requirements_file():
|
||||
|
|
|
|||
|
|
@ -1,5 +1,5 @@
|
|||
""" Test the model container. """
|
||||
from model import ModelContainer
|
||||
from backends.exllamav2.model import ModelContainer
|
||||
|
||||
|
||||
def progress(module, modules):
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue