diff --git a/backends/exllamav2/grammar.py b/backends/exllamav2/grammar.py index 3834c5f..e5c1d16 100644 --- a/backends/exllamav2/grammar.py +++ b/backends/exllamav2/grammar.py @@ -2,10 +2,7 @@ import traceback from exllamav2 import ExLlamaV2, ExLlamaV2Tokenizer from exllamav2.generator import ExLlamaV2Sampler from exllamav2.generator.filters import ExLlamaV2Filter, ExLlamaV2PrefixFilter - -from common.logger import init_logger - -logger = init_logger(__name__) +from loguru import logger class OutlinesTokenizerWrapper: diff --git a/backends/exllamav2/model.py b/backends/exllamav2/model.py index 18d4f63..9125aa3 100644 --- a/backends/exllamav2/model.py +++ b/backends/exllamav2/model.py @@ -15,6 +15,7 @@ from exllamav2 import ( ExLlamaV2Lora, ) from exllamav2.generator import ExLlamaV2StreamingGenerator, ExLlamaV2Sampler +from loguru import logger from typing import List, Optional, Union from backends.exllamav2.grammar import ExLlamaV2Grammar @@ -26,9 +27,6 @@ from common.templating import ( get_template_from_file, ) from common.utils import coalesce, unwrap -from common.logger import init_logger - -logger = init_logger(__name__) class ExllamaV2Container: diff --git a/backends/exllamav2/utils.py b/backends/exllamav2/utils.py index d957b1d..6d01ef7 100644 --- a/backends/exllamav2/utils.py +++ b/backends/exllamav2/utils.py @@ -1,9 +1,6 @@ from packaging import version from importlib.metadata import version as package_version - -from common.logger import init_logger - -logger = init_logger(__name__) +from loguru import logger def check_exllama_version(): diff --git a/common/auth.py b/common/auth.py index ded42b9..8c00d9a 100644 --- a/common/auth.py +++ b/common/auth.py @@ -6,12 +6,9 @@ import secrets import yaml from fastapi import Header, HTTPException from pydantic import BaseModel +from loguru import logger from typing import Optional -from common.logger import init_logger - -logger = init_logger(__name__) - class AuthKeys(BaseModel): """ diff --git a/common/config.py b/common/config.py index f02e48d..fc18c71 100644 --- a/common/config.py +++ b/common/config.py @@ -1,11 +1,9 @@ import yaml import pathlib +from loguru import logger -from common.logger import init_logger from common.utils import unwrap -logger = init_logger(__name__) - GLOBAL_CONFIG: dict = {} diff --git a/common/gen_logging.py b/common/gen_logging.py index a20e45c..1fe84e8 100644 --- a/common/gen_logging.py +++ b/common/gen_logging.py @@ -2,12 +2,9 @@ Functions for logging generation events. """ from pydantic import BaseModel +from loguru import logger from typing import Dict, Optional -from common.logger import init_logger - -logger = init_logger(__name__) - class LogPreferences(BaseModel): """Logging preference config.""" diff --git a/common/logger.py b/common/logger.py index 2c6248f..0975db4 100644 --- a/common/logger.py +++ b/common/logger.py @@ -1,71 +1,104 @@ """ -Logging utility. -https://github.com/PygmalionAI/aphrodite-engine/blob/main/aphrodite/common/logger.py +Internal logging utility. """ import logging -import sys -import colorlog +from loguru import logger +from rich.console import Console +from rich.progress import ( + Progress, + TextColumn, + BarColumn, + TimeRemainingColumn, + TaskProgressColumn, + MofNCompleteColumn, +) -_FORMAT = "%(log_color)s%(levelname)s: %(message)s" -_DATE_FORMAT = "%m-%d %H:%M:%S" +from common.utils import unwrap + +RICH_CONSOLE = Console() -class ColoredFormatter(colorlog.ColoredFormatter): - """Adds logging prefix to newlines to align multi-line messages.""" +def get_loading_progress_bar(): + """Gets a pre-made progress bar for loading tasks.""" - def __init__(self, fmt, datefmt=None, log_colors=None, reset=True, style="%"): - super().__init__( - fmt, datefmt=datefmt, log_colors=log_colors, reset=reset, style=style + return Progress( + TextColumn("[progress.description]{task.description}"), + BarColumn(), + TaskProgressColumn(), + MofNCompleteColumn(), + TimeRemainingColumn(), + console=RICH_CONSOLE, + ) + + +def _log_formatter(record: dict) -> str: + """Log message formatter.""" + + color_map = { + "TRACE": "dim blue", + "DEBUG": "cyan", + "INFO": "green", + "SUCCESS": "bold green", + "WARNING": "yellow", + "ERROR": "red", + "CRITICAL": "bold white on red", + } + level = record.get("level") + level_color = color_map.get(level.name, "cyan") + + message = unwrap(record.get("message"), "") + lines = message.splitlines() + + # Replace once loguru allows for turning off str.format + message = message.replace("{", "{{").replace("}", "}}") + + fmt = "" + if len(lines) > 1: + fmt = "\n".join( + [ + f"[{level_color}]{level.name + ':' :<10}[/{level_color}]{line}" + for line in lines + ] + ) + else: + fmt = f"[{level_color}]{level.name + ':' :<10}[/{level_color}]{message}" + + return fmt + + +# Uvicorn log handler +# Uvicorn log portions inspired from https://github.com/encode/uvicorn/discussions/2027#discussioncomment-6432362 +class UvicornLoggingHandler(logging.Handler): + def emit(self, record: logging.LogRecord) -> None: + logger.opt(exception=record.exc_info).log( + record.levelname, self.format(record).rstrip() ) - def format(self, record): - msg = super().format(record) - if record.message != "": - parts = msg.split(record.message) - msg = msg.replace("\n", "\r\n" + parts[0]) - return msg - -_root_logger = logging.getLogger("aphrodite") -_default_handler = None - - -def _setup_logger(): - _root_logger.setLevel(logging.DEBUG) - global _default_handler - if _default_handler is None: - _default_handler = logging.StreamHandler(sys.stdout) - _default_handler.flush = sys.stdout.flush # type: ignore - _default_handler.setLevel(logging.INFO) - _root_logger.addHandler(_default_handler) - fmt = ColoredFormatter( - _FORMAT, - datefmt=_DATE_FORMAT, - log_colors={ - "DEBUG": "cyan", - "INFO": "green", - "WARNING": "yellow", - "ERROR": "red", - "CRITICAL": "red,bg_white", +# Uvicorn config for logging. Passed into run when creating all loggers in server +UVICORN_LOG_CONFIG = { + "version": 1, + "disable_existing_loggers": False, + "handlers": { + "uvicorn": { + "class": ( + f"{UvicornLoggingHandler.__module__}.{UvicornLoggingHandler.__qualname__}", + ) }, - reset=True, + }, + "root": {"handlers": ["uvicorn"], "propagate": False, "level": "TRACE"}, +} + + +def setup_logger(): + """Bootstrap the logger.""" + + logger.remove() + + logger.add( + RICH_CONSOLE.print, + level="DEBUG", + format=_log_formatter, + colorize=True, ) - _default_handler.setFormatter(fmt) - # Setting this will avoid the message - # being propagated to the parent logger. - _root_logger.propagate = False - - -# The logger is initialized when the module is imported. -# This is thread-safe as the module is only imported once, -# guaranteed by the Python GIL. -_setup_logger() - - -def init_logger(name: str): - logger = logging.getLogger(name) - logger.setLevel(logging.DEBUG) - logger.addHandler(_default_handler) - logger.propagate = False - return logger diff --git a/common/sampling.py b/common/sampling.py index d62f87a..4acbbbb 100644 --- a/common/sampling.py +++ b/common/sampling.py @@ -1,17 +1,14 @@ """Common functions for sampling parameters""" import pathlib -from typing import Dict, List, Optional, Union -from pydantic import AliasChoices, BaseModel, Field import yaml +from loguru import logger +from pydantic import AliasChoices, BaseModel, Field +from typing import Dict, List, Optional, Union -from common.logger import init_logger from common.utils import unwrap, prune_dict -logger = init_logger(__name__) - - # Common class for sampler params class BaseSamplerRequest(BaseModel): """Common class for sampler params that are used in APIs""" diff --git a/common/utils.py b/common/utils.py index 5f70757..5e0ef78 100644 --- a/common/utils.py +++ b/common/utils.py @@ -1,21 +1,10 @@ """Common utility functions""" import traceback +from loguru import logger from pydantic import BaseModel -from rich.progress import ( - Progress, - TextColumn, - BarColumn, - TimeRemainingColumn, - TaskProgressColumn, - MofNCompleteColumn, -) from typing import Optional -from common.logger import init_logger - -logger = init_logger(__name__) - def load_progress(module, modules): """Wrapper callback for load progress.""" @@ -66,18 +55,6 @@ def get_sse_packet(json_data: str): return f"data: {json_data}\n\n" -def get_loading_progress_bar(): - """Gets a pre-made progress bar for loading tasks.""" - - return Progress( - TextColumn("[progress.description]{task.description}"), - BarColumn(), - TaskProgressColumn(), - MofNCompleteColumn(), - TimeRemainingColumn(), - ) - - def unwrap(wrapped, default=None): """Unwrap function for Optionals.""" if wrapped is None: diff --git a/main.py b/main.py index 0fec2cd..f2048bc 100644 --- a/main.py +++ b/main.py @@ -4,8 +4,8 @@ import pathlib import signal import sys import time -import uvicorn import threading +import uvicorn from asyncio import CancelledError from typing import Optional from uuid import uuid4 @@ -15,7 +15,9 @@ from fastapi.concurrency import run_in_threadpool from fastapi.middleware.cors import CORSMiddleware from fastapi.responses import StreamingResponse from functools import partial +from loguru import logger +from common.logger import setup_logger, get_loading_progress_bar import common.gen_logging as gen_logging from backends.exllamav2.model import ExllamaV2Container from backends.exllamav2.utils import check_exllama_version @@ -45,13 +47,11 @@ from common.templating import ( ) from common.utils import ( get_generator_error, - get_loading_progress_bar, get_sse_packet, handle_request_error, load_progress, unwrap, ) -from common.logger import init_logger from OAI.types.completion import CompletionRequest from OAI.types.chat_completion import ChatCompletionRequest from OAI.types.lora import LoraCard, LoraList, LoraLoadRequest, LoraLoadResponse @@ -77,8 +77,6 @@ from OAI.utils.completion import ( from OAI.utils.model import get_model_list from OAI.utils.lora import get_lora_list -logger = init_logger(__name__) - app = FastAPI( title="TabbyAPI", summary="An OAI compatible exllamav2 API that's both lightweight and fast", @@ -692,6 +690,8 @@ def entrypoint(args: Optional[dict] = None): global MODEL_CONTAINER + setup_logger() + # Set up signal aborting signal.signal(signal.SIGINT, signal_handler) signal.signal(signal.SIGTERM, signal_handler) diff --git a/requirements-amd.txt b/requirements-amd.txt index ac8518a..792429e 100644 --- a/requirements-amd.txt +++ b/requirements-amd.txt @@ -13,4 +13,4 @@ PyYAML rich uvicorn jinja2 >= 3.0.0 -colorlog +loguru diff --git a/requirements-cu118.txt b/requirements-cu118.txt index 2c46d16..b2476bb 100644 --- a/requirements-cu118.txt +++ b/requirements-cu118.txt @@ -19,7 +19,7 @@ PyYAML rich uvicorn jinja2 >= 3.0.0 -colorlog +loguru # Linux FA2 from https://github.com/Dao-AILab/flash-attention/releases https://github.com/Dao-AILab/flash-attention/releases/download/v2.5.2/flash_attn-2.5.2+cu118torch2.2cxx11abiFALSE-cp311-cp311-linux_x86_64.whl; platform_system == "Linux" and platform_machine == "x86_64" and python_version == "3.11" diff --git a/requirements-nowheel.txt b/requirements-nowheel.txt index d3b6721..a1e360b 100644 --- a/requirements-nowheel.txt +++ b/requirements-nowheel.txt @@ -5,4 +5,4 @@ PyYAML rich uvicorn jinja2 >= 3.0.0 -colorlog +loguru diff --git a/requirements.txt b/requirements.txt index b613116..acde85f 100644 --- a/requirements.txt +++ b/requirements.txt @@ -19,7 +19,7 @@ PyYAML rich uvicorn jinja2 >= 3.0.0 -colorlog +loguru # Flash attention v2