Common: Migrate request utils to networking

Helps organize the project better. Utils is meant to be for simple
functions like unwrap.

Signed-off-by: kingbri <bdashore3@proton.me>
This commit is contained in:
kingbri 2024-03-21 23:21:57 -04:00
parent 2961c5f3f9
commit 6dfcbbd813
8 changed files with 111 additions and 115 deletions

View file

@ -10,13 +10,16 @@ from typing import Optional
from backends.exllamav2.model import ExllamaV2Container
from common.logger import get_loading_progress_bar
from common.utils import load_progress
# Global model container
container: Optional[ExllamaV2Container] = None
def load_progress(module, modules):
    """Yield one ``(module, modules)`` pair for use as a load-progress callback.

    This is a generator: calling it returns an iterator that emits exactly
    one tuple built from the two arguments, unchanged.
    """
    progress = (module, modules)
    yield progress
async def unload_model():
"""Unloads a model"""
global container

96
common/networking.py Normal file
View file

@ -0,0 +1,96 @@
"""Common networking and request utility functions"""
import asyncio
import socket
import traceback
from fastapi import Request
from loguru import logger
from pydantic import BaseModel
from typing import Optional
from common.concurrency import release_semaphore
class TabbyRequestErrorMessage(BaseModel):
    """Common request error type."""

    # Error text reported for the request (required).
    message: str

    # Traceback text associated with the error; None when not provided.
    trace: Optional[str] = None
class TabbyRequestError(BaseModel):
    """Common request error type."""

    # Nested payload holding the message and optional trace.
    error: TabbyRequestErrorMessage
def get_generator_error(message: str, exc_info: bool = True):
    """Build a request error and return it serialized as a JSON string.

    The error object is created (and logged) by ``handle_request_error``;
    both arguments are forwarded to it unchanged.
    """
    return handle_request_error(message, exc_info).model_dump_json()
def handle_request_error(message: str, exc_info: bool = True):
    """Build a request error payload and log it to the console.

    Args:
        message: Error text placed in the payload and written to the log.
        exc_info: When true, the active traceback (if any) is logged too.

    Returns:
        A ``TabbyRequestError`` wrapping the message and, when an exception
        is currently being handled, its formatted traceback. The trace is
        ``None`` when no exception is active.
    """
    import sys

    # traceback.format_exc() returns the placeholder "NoneType: None" when no
    # exception is being handled; only capture a trace when one is active so
    # the placeholder is neither logged nor sent back in the payload.
    trace = traceback.format_exc() if sys.exc_info()[0] is not None else None

    error_message = TabbyRequestErrorMessage(message=message, trace=trace)
    request_error = TabbyRequestError(error=error_message)

    # Log the error and provided message to the console
    if trace and exc_info:
        logger.error(trace)

    logger.error(f"Sent to request: {message}")

    return request_error
def handle_request_disconnect(message: str):
    """Handle a disconnected request.

    Releases the semaphore first, then logs ``message`` at error level.
    """
    release_semaphore()
    logger.error(message)
async def request_disconnect_loop(request: Request):
    """Return once the given request reports that it has disconnected.

    Checks ``request.is_disconnected()`` repeatedly, sleeping half a second
    between checks.
    """
    while True:
        if await request.is_disconnected():
            return
        await asyncio.sleep(0.5)
async def run_with_request_disconnect(
    request: Request, call_task: asyncio.Task, disconnect_message: str
):
    """Run ``call_task`` but stop waiting if the request disconnects.

    Waits until either ``call_task`` or a disconnect-polling task completes,
    then cancels whichever is still pending. Returns the result of
    ``call_task``; if that result cannot be read because the task was
    cancelled or has not finished, the disconnect is handled with
    ``disconnect_message`` and ``None`` is returned.
    """
    disconnect_watcher = asyncio.create_task(request_disconnect_loop(request))
    _done, still_pending = await asyncio.wait(
        [call_task, disconnect_watcher],
        return_when=asyncio.FIRST_COMPLETED,
    )

    # Whichever task lost the race is asked to stop.
    for pending_task in still_pending:
        pending_task.cancel()

    try:
        return call_task.result()
    except (asyncio.CancelledError, asyncio.InvalidStateError):
        handle_request_disconnect(disconnect_message)
def is_port_in_use(port: int) -> bool:
    """
    Report whether a TCP connection to ``localhost`` on ``port`` succeeds.

    From https://stackoverflow.com/questions/2470971/fast-way-to-test-if-a-port-is-in-use-using-python
    """
    probe = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    with probe:
        # connect_ex returns 0 on success instead of raising on failure.
        status = probe.connect_ex(("localhost", port))
        return status == 0

View file

@ -1,94 +1,5 @@
"""Common utility functions"""
import asyncio
import socket
import traceback
from fastapi import Request
from loguru import logger
from pydantic import BaseModel
from typing import Optional
from common.concurrency import release_semaphore
def load_progress(module, modules):
    """Yield one ``(module, modules)`` pair for use as a load-progress callback.

    This is a generator: calling it returns an iterator that emits exactly
    one tuple built from the two arguments, unchanged.
    """
    progress = (module, modules)
    yield progress
class TabbyRequestErrorMessage(BaseModel):
    """Common request error type."""

    # Error text reported for the request (required).
    message: str

    # Traceback text associated with the error; None when not provided.
    trace: Optional[str] = None
class TabbyRequestError(BaseModel):
    """Common request error type."""

    # Nested payload holding the message and optional trace.
    error: TabbyRequestErrorMessage
def get_generator_error(message: str, exc_info: bool = True):
    """Build a request error and return it serialized as a JSON string.

    Args:
        message: Error text for the payload and the log.
        exc_info: Passed through to ``handle_request_error`` to control
            whether the traceback is logged.

    Returns:
        The JSON dump of the error model produced by
        ``handle_request_error``.
    """
    # Forward exc_info; previously it was accepted but dropped, so callers
    # could not turn off traceback logging through this helper.
    generator_error = handle_request_error(message, exc_info)

    return generator_error.model_dump_json()
def handle_request_error(message: str, exc_info: bool = True):
    """Build a request error payload and log it to the console.

    Args:
        message: Error text placed in the payload and written to the log.
        exc_info: When true, the active traceback (if any) is logged too.

    Returns:
        A ``TabbyRequestError`` wrapping the message and, when an exception
        is currently being handled, its formatted traceback. The trace is
        ``None`` when no exception is active.
    """
    import sys

    # traceback.format_exc() returns the placeholder "NoneType: None" when no
    # exception is being handled; only capture a trace when one is active so
    # the placeholder is neither logged nor sent back in the payload.
    trace = traceback.format_exc() if sys.exc_info()[0] is not None else None

    error_message = TabbyRequestErrorMessage(message=message, trace=trace)
    request_error = TabbyRequestError(error=error_message)

    # Log the error and provided message to the console
    if trace and exc_info:
        logger.error(trace)

    logger.error(f"Sent to request: {message}")

    return request_error
def handle_request_disconnect(message: str):
    """Handle a disconnected request.

    Releases the semaphore first, then logs ``message`` at error level.
    """
    release_semaphore()
    logger.error(message)
async def request_disconnect_loop(request: Request):
    """Return once the given request reports that it has disconnected.

    Checks ``request.is_disconnected()`` repeatedly, sleeping half a second
    between checks.
    """
    while True:
        if await request.is_disconnected():
            return
        await asyncio.sleep(0.5)
async def run_with_request_disconnect(
    request: Request, call_task: asyncio.Task, disconnect_message: str
):
    """Run ``call_task`` but stop waiting if the request disconnects.

    Waits until either ``call_task`` or a disconnect-polling task completes,
    then cancels whichever is still pending. Returns the result of
    ``call_task``; if that result cannot be read because the task was
    cancelled or has not finished, the disconnect is handled with
    ``disconnect_message`` and ``None`` is returned.
    """
    disconnect_watcher = asyncio.create_task(request_disconnect_loop(request))
    _done, still_pending = await asyncio.wait(
        [call_task, disconnect_watcher],
        return_when=asyncio.FIRST_COMPLETED,
    )

    # Whichever task lost the race is asked to stop.
    for pending_task in still_pending:
        pending_task.cancel()

    try:
        return call_task.result()
    except (asyncio.CancelledError, asyncio.InvalidStateError):
        handle_request_disconnect(disconnect_message)
def unwrap(wrapped, default=None):
"""Unwrap function for Optionals."""
@ -107,14 +18,3 @@ def prune_dict(input_dict):
"""Trim out instances of None from a dictionary"""
return {k: v for k, v in input_dict.items() if v is not None}
def is_port_in_use(port: int) -> bool:
    """
    Report whether a TCP connection to ``localhost`` on ``port`` succeeds.

    From https://stackoverflow.com/questions/2470971/fast-way-to-test-if-a-port-is-in-use-using-python
    """
    probe = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    with probe:
        # connect_ex returns 0 on success instead of raising on failure.
        status = probe.connect_ex(("localhost", port))
        return status == 0

View file

@ -18,17 +18,13 @@ from common.concurrency import (
generate_with_semaphore,
)
from common.logger import UVICORN_LOG_CONFIG
from common.networking import handle_request_error, run_with_request_disconnect
from common.signals import uvicorn_signal_handler
from common.templating import (
get_all_templates,
get_template_from_file,
)
from common.utils import (
coalesce,
handle_request_error,
run_with_request_disconnect,
unwrap,
)
from common.utils import coalesce, unwrap
from endpoints.OAI.types.auth import AuthPermissionResponse
from endpoints.OAI.types.completion import CompletionRequest
from endpoints.OAI.types.chat_completion import ChatCompletionRequest

View file

@ -10,13 +10,13 @@ from fastapi import HTTPException
from jinja2 import TemplateError
from common import model
from common.templating import get_prompt_from_template
from common.utils import (
from common.networking import (
get_generator_error,
handle_request_disconnect,
handle_request_error,
unwrap,
)
from common.templating import get_prompt_from_template
from common.utils import unwrap
from endpoints.OAI.types.chat_completion import (
ChatCompletionLogprobs,
ChatCompletionLogprob,

View file

@ -7,12 +7,12 @@ from fastapi import HTTPException
from typing import Optional
from common import model
from common.utils import (
from common.networking import (
get_generator_error,
handle_request_disconnect,
handle_request_error,
unwrap,
)
from common.utils import unwrap
from endpoints.OAI.types.completion import (
CompletionRequest,
CompletionResponse,

View file

@ -3,7 +3,7 @@ from asyncio import CancelledError
from typing import Optional
from common import model
from common.utils import get_generator_error, handle_request_disconnect
from common.networking import get_generator_error, handle_request_disconnect
from endpoints.OAI.types.model import (
ModelCard,

View file

@ -12,8 +12,9 @@ from common import config, gen_logging, sampling, model
from common.args import convert_args_to_dict, init_argparser
from common.auth import load_auth_keys
from common.logger import setup_logger
from common.networking import is_port_in_use
from common.signals import signal_handler
from common.utils import is_port_in_use, unwrap
from common.utils import unwrap
from endpoints.OAI.app import start_api