diff --git a/common/model.py b/common/model.py index 6b2d212..f916f5a 100644 --- a/common/model.py +++ b/common/model.py @@ -10,13 +10,16 @@ from typing import Optional from backends.exllamav2.model import ExllamaV2Container from common.logger import get_loading_progress_bar -from common.utils import load_progress - # Global model container container: Optional[ExllamaV2Container] = None +def load_progress(module, modules): + """Wrapper callback for load progress.""" + yield module, modules + + async def unload_model(): """Unloads a model""" global container diff --git a/common/networking.py b/common/networking.py new file mode 100644 index 0000000..fee7c28 --- /dev/null +++ b/common/networking.py @@ -0,0 +1,96 @@ +"""Common utility functions""" + +import asyncio +import socket +import traceback +from fastapi import Request +from loguru import logger +from pydantic import BaseModel +from typing import Optional + +from common.concurrency import release_semaphore + + +class TabbyRequestErrorMessage(BaseModel): + """Common request error type.""" + + message: str + trace: Optional[str] = None + + +class TabbyRequestError(BaseModel): + """Common request error type.""" + + error: TabbyRequestErrorMessage + + +def get_generator_error(message: str, exc_info: bool = True): + """Get a generator error.""" + + generator_error = handle_request_error(message, exc_info) + + return generator_error.model_dump_json() + + +def handle_request_error(message: str, exc_info: bool = True): + """Log a request error to the console.""" + + error_message = TabbyRequestErrorMessage( + message=message, trace=traceback.format_exc() + ) + + request_error = TabbyRequestError(error=error_message) + + # Log the error and provided message to the console + if error_message.trace and exc_info: + logger.error(error_message.trace) + + logger.error(f"Sent to request: {message}") + + return request_error + + +def handle_request_disconnect(message: str): + """Wrapper for handling for request disconnection.""" + + release_semaphore() + logger.error(message) + + +async def request_disconnect_loop(request: Request): + """Polls for a starlette request disconnect.""" + + while not await request.is_disconnected(): + await asyncio.sleep(0.5) + + +async def run_with_request_disconnect( + request: Request, call_task: asyncio.Task, disconnect_message: str +): + """Utility function to cancel if a request is disconnected.""" + + _, unfinished = await asyncio.wait( + [ + call_task, + asyncio.create_task(request_disconnect_loop(request)), + ], + return_when=asyncio.FIRST_COMPLETED, + ) + for task in unfinished: + task.cancel() + + try: + return call_task.result() + except (asyncio.CancelledError, asyncio.InvalidStateError): + handle_request_disconnect(disconnect_message) + + +def is_port_in_use(port: int) -> bool: + """ + Checks if a port is in use + + From https://stackoverflow.com/questions/2470971/fast-way-to-test-if-a-port-is-in-use-using-python + """ + + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: + return s.connect_ex(("localhost", port)) == 0 diff --git a/common/utils.py b/common/utils.py index bae567b..079a380 100644 --- a/common/utils.py +++ b/common/utils.py @@ -1,94 +1,5 @@ """Common utility functions""" -import asyncio -import socket -import traceback -from fastapi import Request -from loguru import logger -from pydantic import BaseModel -from typing import Optional - -from common.concurrency import release_semaphore - - -def load_progress(module, modules): - """Wrapper callback for load progress.""" - yield module, modules - - -class TabbyRequestErrorMessage(BaseModel): - """Common request error type.""" - - message: str - trace: Optional[str] = None - - -class TabbyRequestError(BaseModel): - """Common request error type.""" - - error: TabbyRequestErrorMessage - - -def get_generator_error(message: str, exc_info: bool = True): - """Get a generator error.""" - - generator_error = handle_request_error(message) - - return generator_error.model_dump_json() - - -def handle_request_error(message: str, exc_info: bool = True): - """Log a request error to the console.""" - - error_message = TabbyRequestErrorMessage( - message=message, trace=traceback.format_exc() - ) - - request_error = TabbyRequestError(error=error_message) - - # Log the error and provided message to the console - if error_message.trace and exc_info: - logger.error(error_message.trace) - - logger.error(f"Sent to request: {message}") - - return request_error - - -def handle_request_disconnect(message: str): - """Wrapper for handling for request disconnection.""" - - release_semaphore() - logger.error(message) - - -async def request_disconnect_loop(request: Request): - """Polls for a starlette request disconnect.""" - - while not await request.is_disconnected(): - await asyncio.sleep(0.5) - - -async def run_with_request_disconnect( - request: Request, call_task: asyncio.Task, disconnect_message: str -): - """Utility function to cancel if a request is disconnected.""" - - _, unfinished = await asyncio.wait( - [ - call_task, - asyncio.create_task(request_disconnect_loop(request)), - ], - return_when=asyncio.FIRST_COMPLETED, - ) - for task in unfinished: - task.cancel() - - try: - return call_task.result() - except (asyncio.CancelledError, asyncio.InvalidStateError): - handle_request_disconnect(disconnect_message) - def unwrap(wrapped, default=None): """Unwrap function for Optionals.""" @@ -107,14 +18,3 @@ def prune_dict(input_dict): """Trim out instances of None from a dictionary""" return {k: v for k, v in input_dict.items() if v is not None} - - -def is_port_in_use(port: int) -> bool: - """ - Checks if a port is in use - - From https://stackoverflow.com/questions/2470971/fast-way-to-test-if-a-port-is-in-use-using-python - """ - - with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: - return s.connect_ex(("localhost", port)) == 0 diff --git a/endpoints/OAI/app.py b/endpoints/OAI/app.py index 0f2e0e9..c448eb7 100644 --- a/endpoints/OAI/app.py +++ b/endpoints/OAI/app.py @@ -18,17 +18,13 @@ from common.concurrency import ( generate_with_semaphore, ) from common.logger import UVICORN_LOG_CONFIG +from common.networking import handle_request_error, run_with_request_disconnect from common.signals import uvicorn_signal_handler from common.templating import ( get_all_templates, get_template_from_file, ) -from common.utils import ( - coalesce, - handle_request_error, - run_with_request_disconnect, - unwrap, -) +from common.utils import coalesce, unwrap from endpoints.OAI.types.auth import AuthPermissionResponse from endpoints.OAI.types.completion import CompletionRequest from endpoints.OAI.types.chat_completion import ChatCompletionRequest diff --git a/endpoints/OAI/utils/chat_completion.py b/endpoints/OAI/utils/chat_completion.py index 562e736..d74db83 100644 --- a/endpoints/OAI/utils/chat_completion.py +++ b/endpoints/OAI/utils/chat_completion.py @@ -10,13 +10,13 @@ from fastapi import HTTPException from jinja2 import TemplateError from common import model -from common.templating import get_prompt_from_template -from common.utils import ( +from common.networking import ( get_generator_error, handle_request_disconnect, handle_request_error, - unwrap, ) +from common.templating import get_prompt_from_template +from common.utils import unwrap from endpoints.OAI.types.chat_completion import ( ChatCompletionLogprobs, ChatCompletionLogprob, diff --git a/endpoints/OAI/utils/completion.py b/endpoints/OAI/utils/completion.py index 02b7852..f7e50af 100644 --- a/endpoints/OAI/utils/completion.py +++ b/endpoints/OAI/utils/completion.py @@ -7,12 +7,12 @@ from fastapi import HTTPException from typing import Optional from common import model -from common.utils import ( +from common.networking import ( get_generator_error, handle_request_disconnect, handle_request_error, - unwrap, ) +from common.utils import unwrap from endpoints.OAI.types.completion import ( CompletionRequest, CompletionResponse, diff --git a/endpoints/OAI/utils/model.py b/endpoints/OAI/utils/model.py index 61d210f..66c7625 100644 --- a/endpoints/OAI/utils/model.py +++ b/endpoints/OAI/utils/model.py @@ -3,7 +3,7 @@ from asyncio import CancelledError from typing import Optional from common import model -from common.utils import get_generator_error, handle_request_disconnect +from common.networking import get_generator_error, handle_request_disconnect from endpoints.OAI.types.model import ( ModelCard, diff --git a/main.py b/main.py index f6fc52a..5b95d87 100644 --- a/main.py +++ b/main.py @@ -12,8 +12,9 @@ from common import config, gen_logging, sampling, model from common.args import convert_args_to_dict, init_argparser from common.auth import load_auth_keys from common.logger import setup_logger +from common.networking import is_port_in_use from common.signals import signal_handler -from common.utils import is_port_in_use, unwrap +from common.utils import unwrap from endpoints.OAI.app import start_api