diff --git a/auth.py b/auth.py index 451ba0e..4185ddb 100644 --- a/auth.py +++ b/auth.py @@ -9,6 +9,10 @@ from fastapi import Header, HTTPException from pydantic import BaseModel import yaml +from logger import init_logger + +logger = init_logger(__name__) + class AuthKeys(BaseModel): """ @@ -44,11 +48,10 @@ def load_auth_keys(disable_from_config: bool): DISABLE_AUTH = disable_from_config if disable_from_config: - print( - "!! Warning: Disabling authentication", - "makes your instance vulnerable.", - "Set the 'disable_auth' flag to False in config.yml", - "if you want to share this instance with others.", + logger.warning( + "Disabling authentication makes your instance vulnerable. " + "Set the `disable_auth` flag to False in config.yml if you " + "want to share this instance with others." ) return @@ -66,7 +69,7 @@ def load_auth_keys(disable_from_config: bool): with open("api_tokens.yml", "w", encoding="utf8") as auth_file: yaml.safe_dump(AUTH_KEYS.model_dump(), auth_file, default_flow_style=False) - print( + logger.info( f"Your API key is: {AUTH_KEYS.api_key}\n" f"Your admin key is: {AUTH_KEYS.admin_key}\n\n" "If these keys get compromised, make sure to delete api_tokens.yml " diff --git a/gen_logging.py b/gen_logging.py index c8d3620..7731357 100644 --- a/gen_logging.py +++ b/gen_logging.py @@ -4,6 +4,10 @@ Functions for logging generation events. from typing import Dict from pydantic import BaseModel +from logger import init_logger + +logger = init_logger(__name__) + class LogConfig(BaseModel): """Logging preference config.""" @@ -38,24 +42,24 @@ def broadcast_status(): enabled.append("generation params") if len(enabled) > 0: - print("Generation logging is enabled for: " + ", ".join(enabled)) + logger.info("Generation logging is enabled for: " + ", ".join(enabled)) else: - print("Generation logging is disabled") + logger.info("Generation logging is disabled") def log_generation_params(**kwargs): """Logs generation parameters to console.""" if CONFIG.generation_params: - print(f"Generation options: {kwargs}\n") + logger.info(f"Generation options: {kwargs}\n") def log_prompt(prompt: str): """Logs the prompt to console.""" if CONFIG.prompt: - print(f"Prompt: {prompt if prompt else 'Empty'}\n") + logger.info(f"Prompt: {prompt if prompt else 'Empty'}\n") def log_response(response: str): """Logs the response to console.""" if CONFIG.prompt: - print(f"Response: {response if response else 'Empty'}\n") + logger.info(f"Response: {response if response else 'Empty'}\n") diff --git a/logger.py b/logger.py new file mode 100644 index 0000000..2c6248f --- /dev/null +++ b/logger.py @@ -0,0 +1,71 @@ +""" +Logging utility. +https://github.com/PygmalionAI/aphrodite-engine/blob/main/aphrodite/common/logger.py +""" + +import logging +import sys +import colorlog + +_FORMAT = "%(log_color)s%(levelname)s: %(message)s" +_DATE_FORMAT = "%m-%d %H:%M:%S" + + +class ColoredFormatter(colorlog.ColoredFormatter): + """Adds logging prefix to newlines to align multi-line messages.""" + + def __init__(self, fmt, datefmt=None, log_colors=None, reset=True, style="%"): + super().__init__( + fmt, datefmt=datefmt, log_colors=log_colors, reset=reset, style=style + ) + + def format(self, record): + msg = super().format(record) + if record.message != "": + parts = msg.split(record.message) + msg = msg.replace("\n", "\r\n" + parts[0]) + return msg + + +_root_logger = logging.getLogger("aphrodite") +_default_handler = None + + +def _setup_logger(): + _root_logger.setLevel(logging.DEBUG) + global _default_handler + if _default_handler is None: + _default_handler = logging.StreamHandler(sys.stdout) + _default_handler.flush = sys.stdout.flush # type: ignore + _default_handler.setLevel(logging.INFO) + _root_logger.addHandler(_default_handler) + fmt = ColoredFormatter( + _FORMAT, + datefmt=_DATE_FORMAT, + log_colors={ + "DEBUG": "cyan", + "INFO": "green", + "WARNING": "yellow", + "ERROR": "red", + "CRITICAL": "red,bg_white", + }, + reset=True, + ) + _default_handler.setFormatter(fmt) + # Setting this will avoid the message + # being propagated to the parent logger. + _root_logger.propagate = False + + +# The logger is initialized when the module is imported. +# This is thread-safe as the module is only imported once, +# guaranteed by the Python GIL. +_setup_logger() + + +def init_logger(name: str): + logger = logging.getLogger(name) + logger.setLevel(logging.DEBUG) + logger.addHandler(_default_handler) + logger.propagate = False + return logger diff --git a/main.py b/main.py index fb1fe26..6d5d304 100644 --- a/main.py +++ b/main.py @@ -41,6 +41,9 @@ from OAI.utils_oai import ( ) from templating import get_prompt_from_template from utils import get_generator_error, get_sse_packet, load_progress, unwrap +from logger import init_logger + +logger = init_logger(__name__) app = FastAPI() @@ -210,8 +213,8 @@ async def load_model(request: Request, data: ModelLoadRequest): yield get_sse_packet(response.model_dump_json()) except CancelledError: - print( - "\nError: Model load cancelled by user. " + logger.error( + "Model load cancelled by user. " "Please make sure to run unload to free up resources." ) except Exception as exc: @@ -369,7 +372,7 @@ async def generate_completion(request: Request, data: CompletionRequest): yield get_sse_packet(response.model_dump_json()) except CancelledError: - print("Error: Completion request cancelled by user.") + logger.error("Completion request cancelled by user.") except Exception as exc: yield get_generator_error(str(exc)) @@ -456,7 +459,7 @@ async def generate_chat_completion(request: Request, data: ChatCompletionRequest yield get_sse_packet(finish_response.model_dump_json()) except CancelledError: - print("Error: Chat completion cancelled by user.") + logger.error("Chat completion cancelled by user.") except Exception as exc: yield get_generator_error(str(exc)) @@ -481,10 +484,10 @@ if __name__ == "__main__": with open("config.yml", "r", encoding="utf8") as config_file: config = unwrap(yaml.safe_load(config_file), {}) except Exception as exc: - print( - "The YAML config couldn't load because of the following error:", - f"\n\n{exc}", - "\n\nTabbyAPI will start anyway and not parse this config file.", + logger.error( + "The YAML config couldn't load because of the following error: " + f"\n\n{exc}" + "\n\nTabbyAPI will start anyway and not parse this config file." ) config = {} diff --git a/model.py b/model.py index e4dde49..08bc23f 100644 --- a/model.py +++ b/model.py @@ -23,6 +23,9 @@ from templating import ( get_template_from_file, ) from utils import coalesce, unwrap +from logger import init_logger + +logger = init_logger(__name__) # Bytes to reserve on first device when loading with auto split AUTO_SPLIT_RESERVE_BYTES = 96 * 1024**2 @@ -143,10 +146,7 @@ class ModelContainer: # Set prompt template override if provided prompt_template_name = kwargs.get("prompt_template") if prompt_template_name: - print( - "Attempting to load prompt template with name", - {prompt_template_name}, - ) + logger.info("Loading prompt template with name " f"{prompt_template_name}") # Read the template self.prompt_template = get_template_from_file(prompt_template_name) else: @@ -175,13 +175,13 @@ class ModelContainer: # Catch all for template lookup errors if self.prompt_template: - print( - f"Using template {self.prompt_template.name} for chat " "completions." + logger.info( + f"Using template {self.prompt_template.name} " "for chat completions." ) else: - print( - "Chat completions are disabled because a prompt template", - "wasn't provided or auto-detected.", + logger.warning( + "Chat completions are disabled because a prompt " + "template wasn't provided or auto-detected." ) # Set num of experts per token if provided @@ -190,9 +190,9 @@ class ModelContainer: if hasattr(self.config, "num_experts_per_token"): self.config.num_experts_per_token = num_experts_override else: - print( - " !! Warning: Currently installed ExLlamaV2 does not " - "support overriding MoE experts" + logger.warning( + "MoE experts per token override is not " + "supported by the current ExLlamaV2 version." ) chunk_size = min( @@ -207,9 +207,9 @@ class ModelContainer: # Always disable draft if params are incorrectly configured if draft_args and draft_model_name is None: - print( - "A draft config was found but a model name was not given. " - "Please check your config.yml! Skipping draft load." + logger.warning( + "Draft model is disabled because a model name " + "wasn't provided. Please check your config.yml!" ) enable_draft = False @@ -283,20 +283,20 @@ class ModelContainer: lora_scaling = unwrap(lora.get("scaling"), 1.0) if lora_name is None: - print( + logger.warning( "One of your loras does not have a name. Please check your " "config.yml! Skipping lora load." ) failure.append(lora_name) continue - print(f"Loading lora: {lora_name} at scaling {lora_scaling}") + logger.info(f"Loading lora: {lora_name} at scaling {lora_scaling}") lora_path = lora_directory / lora_name # FIXME(alpin): Does self.model need to be passed here? self.active_loras.append( ExLlamaV2Lora.from_directory(self.model, lora_path, lora_scaling) ) - print("Lora successfully loaded.") + logger.info(f"Lora successfully loaded: {lora_name}") success.append(lora_name) # Return success and failure names @@ -319,7 +319,7 @@ class ModelContainer: if self.draft_config: self.draft_model = ExLlamaV2(self.draft_config) if not self.quiet: - print("Loading draft model: " + self.draft_config.model_dir) + logger.info("Loading draft model: " + self.draft_config.model_dir) self.draft_cache = ExLlamaV2Cache(self.draft_model, lazy=True) reserve = [AUTO_SPLIT_RESERVE_BYTES] + [0] * 16 @@ -337,7 +337,7 @@ class ModelContainer: # Load model self.model = ExLlamaV2(self.config) if not self.quiet: - print("Loading model: " + self.config.model_dir) + logger.info("Loading model: " + self.config.model_dir) if not self.gpu_split_auto: for value in self.model.load_gen( @@ -373,7 +373,7 @@ class ModelContainer: self.draft_cache, ) - print("Model successfully loaded.") + logger.info("Model successfully loaded.") def unload(self, loras_only: bool = False): """ @@ -494,33 +494,33 @@ class ModelContainer: if (unwrap(kwargs.get("mirostat"), False)) and not hasattr( gen_settings, "mirostat" ): - print( - " !! Warning: Currently installed ExLlamaV2 does not support " - "Mirostat sampling" + logger.warning( + "Mirostat sampling is not supported by the currently " + "installed ExLlamaV2 version." ) if (unwrap(kwargs.get("min_p"), 0.0)) not in [0.0, 1.0] and not hasattr( gen_settings, "min_p" ): - print( - " !! Warning: Currently installed ExLlamaV2 does not " - "support min-P sampling" + logger.warning( + "Min-P sampling is not supported by the currently " + "installed ExLlamaV2 version." ) if (unwrap(kwargs.get("tfs"), 0.0)) not in [0.0, 1.0] and not hasattr( gen_settings, "tfs" ): - print( - " !! Warning: Currently installed ExLlamaV2 does not support " - "tail-free sampling (TFS)" + logger.warning( + "Tail-free sampling (TFS) is not supported by the currently " + "installed ExLlamaV2 version." ) if (unwrap(kwargs.get("temperature_last"), False)) and not hasattr( gen_settings, "temperature_last" ): - print( - " !! Warning: Currently installed ExLlamaV2 does not support " - "temperature_last" + logger.warning( + "Temperature last is not supported by the currently " + "installed ExLlamaV2 version." ) # Apply settings @@ -614,10 +614,10 @@ class ModelContainer: context_len = len(ids[0]) if context_len > self.config.max_seq_len: - print( - f"WARNING: The context length {context_len} is greater than " - f"the max_seq_len {self.config.max_seq_len}.", - "Generation is truncated and metrics may not be accurate.", + logger.warning( + f"Context length {context_len} is greater than max_seq_len " + f"{self.config.max_seq_len}. Generation is truncated and " + "metrics may not be accurate." ) prompt_tokens = ids.shape[-1] @@ -705,7 +705,7 @@ class ModelContainer: extra_parts.append("<-- Not accurate (truncated)") # Print output - print( + logger.info( initial_response + " (" + ", ".join(itemization) diff --git a/requirements-amd.txt b/requirements-amd.txt index 24e0b0b..d3239e9 100644 --- a/requirements-amd.txt +++ b/requirements-amd.txt @@ -13,3 +13,4 @@ PyYAML progress uvicorn jinja2 >= 3.0.0 +colorlog \ No newline at end of file diff --git a/requirements-cu118.txt b/requirements-cu118.txt index d5eac3f..d804226 100644 --- a/requirements-cu118.txt +++ b/requirements-cu118.txt @@ -19,6 +19,7 @@ PyYAML progress uvicorn jinja2 >= 3.0.0 +colorlog # Linux FA2 from https://github.com/Dao-AILab/flash-attention/releases https://github.com/Dao-AILab/flash-attention/releases/download/v2.3.6/flash_attn-2.3.6+cu118torch2.1cxx11abiFALSE-cp310-cp310-linux_x86_64.whl; platform_system == "Linux" and platform_machine == "x86_64" and python_version == "3.10" diff --git a/requirements-nowheel.txt b/requirements-nowheel.txt index 1f19530..c36272e 100644 --- a/requirements-nowheel.txt +++ b/requirements-nowheel.txt @@ -5,3 +5,4 @@ PyYAML progress uvicorn jinja2 >= 3.0.0 +colorlog diff --git a/requirements.txt b/requirements.txt index f87bc74..a0bd32f 100644 --- a/requirements.txt +++ b/requirements.txt @@ -19,6 +19,7 @@ PyYAML progress uvicorn jinja2 >= 3.0.0 +colorlog # Flash attention v2 diff --git a/tests/wheel_test.py b/tests/wheel_test.py index c343ed5..f9ee15e 100644 --- a/tests/wheel_test.py +++ b/tests/wheel_test.py @@ -2,45 +2,52 @@ from importlib.metadata import version from importlib.util import find_spec +from logger import init_logger + +logger = init_logger(__name__) + successful_packages = [] errored_packages = [] if find_spec("flash_attn") is not None: - print(f"Flash attention on version {version('flash_attn')} successfully imported") + logger.info( + f"Flash attention on version {version('flash_attn')} " "successfully imported" + ) successful_packages.append("flash_attn") else: - print("Flash attention 2 is not found in your environment.") + logger.error("Flash attention 2 is not found in your environment.") errored_packages.append("flash_attn") if find_spec("exllamav2") is not None: - print(f"Exllamav2 on version {version('exllamav2')} successfully imported") + logger.info(f"Exllamav2 on version {version('exllamav2')} " "successfully imported") successful_packages.append("exllamav2") else: - print("Exllamav2 is not found in your environment.") + logger.error("Exllamav2 is not found in your environment.") errored_packages.append("exllamav2") if find_spec("torch") is not None: - print(f"Torch on version {version('torch')} successfully imported") + logger.info(f"Torch on version {version('torch')} successfully imported") successful_packages.append("torch") else: - print("Torch is not found in your environment.") + logger.error("Torch is not found in your environment.") errored_packages.append("torch") if find_spec("jinja2") is not None: - print(f"Jinja2 on version {version('jinja2')} successfully imported") + logger.info(f"Jinja2 on version {version('jinja2')} successfully imported") successful_packages.append("jinja2") else: - print("Jinja2 is not found in your environment.") + logger.error("Jinja2 is not found in your environment.") errored_packages.append("jinja2") -print( - f"\nSuccessful imports: {', '.join(successful_packages)}", - f"\nErrored imports: {''.join(errored_packages)}", -) +logger.info(f"\nSuccessful imports: {', '.join(successful_packages)}") +logger.error(f"Errored imports: {''.join(errored_packages)}") if len(errored_packages) > 0: - print( - "\nIf packages are installed, but not found on this test, please " - "check the wheel versions for the correct python version and CUDA " - "version (if applicable)." + logger.warning( + "If all packages are installed, but not found " + "on this test, please check the wheel versions for the " + "correct python version and CUDA version (if " + "applicable)." ) +else: + logger.info("All wheels are installed correctly.") diff --git a/utils.py b/utils.py index 6f00d3e..529afe0 100644 --- a/utils.py +++ b/utils.py @@ -4,6 +4,10 @@ from typing import Optional from pydantic import BaseModel +from logger import init_logger + +logger = init_logger(__name__) + def load_progress(module, modules): """Wrapper callback for load progress.""" @@ -32,7 +36,7 @@ def get_generator_error(message: str): generator_error = TabbyGeneratorError(error=error_message) # Log and send the exception - print(f"\n{generator_error.error.trace}") + logger.error(generator_error.error.message) return get_sse_packet(generator_error.model_dump_json())