feat: logging (#39)

* add logging

* simplify the logger

* formatting

* final touches

* fix format

* Model: Add log to metrics

Signed-off-by: kingbri <bdashore3@proton.me>

---------

Authored-by: AlpinDale <52078762+AlpinDale@users.noreply.github.com>
AlpinDale 2023-12-23 04:33:31 +00:00 committed by GitHub
parent f5314fcdad
commit 6a5bbd217c
11 changed files with 170 additions and 74 deletions

auth.py (15 changed lines)

@@ -9,6 +9,10 @@ from fastapi import Header, HTTPException
 from pydantic import BaseModel
 import yaml
 
+from logger import init_logger
+
+logger = init_logger(__name__)
+
 
 class AuthKeys(BaseModel):
     """
@@ -44,11 +48,10 @@ def load_auth_keys(disable_from_config: bool):
     DISABLE_AUTH = disable_from_config
 
     if disable_from_config:
-        print(
-            "!! Warning: Disabling authentication",
-            "makes your instance vulnerable.",
-            "Set the 'disable_auth' flag to False in config.yml",
-            "if you want to share this instance with others.",
+        logger.warning(
+            "Disabling authentication makes your instance vulnerable. "
+            "Set the `disable_auth` flag to False in config.yml if you "
+            "want to share this instance with others."
         )
         return
@@ -66,7 +69,7 @@ def load_auth_keys(disable_from_config: bool):
     with open("api_tokens.yml", "w", encoding="utf8") as auth_file:
         yaml.safe_dump(AUTH_KEYS.model_dump(), auth_file, default_flow_style=False)
 
-    print(
+    logger.info(
        f"Your API key is: {AUTH_KEYS.api_key}\n"
        f"Your admin key is: {AUTH_KEYS.admin_key}\n\n"
        "If these keys get compromised, make sure to delete api_tokens.yml "

View file

@@ -4,6 +4,10 @@ Functions for logging generation events.
 from typing import Dict
 from pydantic import BaseModel
 
+from logger import init_logger
+
+logger = init_logger(__name__)
+
 
 class LogConfig(BaseModel):
     """Logging preference config."""
@@ -38,24 +42,24 @@ def broadcast_status():
         enabled.append("generation params")
 
     if len(enabled) > 0:
-        print("Generation logging is enabled for: " + ", ".join(enabled))
+        logger.info("Generation logging is enabled for: " + ", ".join(enabled))
     else:
-        print("Generation logging is disabled")
+        logger.info("Generation logging is disabled")
 
 
 def log_generation_params(**kwargs):
     """Logs generation parameters to console."""
     if CONFIG.generation_params:
-        print(f"Generation options: {kwargs}\n")
+        logger.info(f"Generation options: {kwargs}\n")
 
 
 def log_prompt(prompt: str):
     """Logs the prompt to console."""
     if CONFIG.prompt:
-        print(f"Prompt: {prompt if prompt else 'Empty'}\n")
+        logger.info(f"Prompt: {prompt if prompt else 'Empty'}\n")
 
 
 def log_response(response: str):
     """Logs the response to console."""
     if CONFIG.prompt:
-        print(f"Response: {response if response else 'Empty'}\n")
+        logger.info(f"Response: {response if response else 'Empty'}\n")
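For context, a minimal sketch of how a caller might use these helpers once this commit lands. This is not part of the diff: the module name `gen_logging` and the call site below are assumptions (the file header was not captured in this view); the helper names and the fact that each one checks its LogConfig flag internally come from the hunk above.

# Hypothetical call site (not part of this commit's diff).
# Assumes the module shown above is importable as `gen_logging`.
from gen_logging import log_generation_params, log_prompt, log_response


def run_generation(prompt: str) -> str:
    # Each helper checks CONFIG internally, so callers can invoke them
    # unconditionally and let the user's logging preferences decide.
    log_generation_params(max_tokens=200, temperature=0.7)
    log_prompt(prompt)

    response = "..."  # placeholder for the actual model call
    log_response(response)
    return response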

logger.py (new file, 71 lines)

@@ -0,0 +1,71 @@
"""
Logging utility.
https://github.com/PygmalionAI/aphrodite-engine/blob/main/aphrodite/common/logger.py
"""

import logging
import sys

import colorlog

_FORMAT = "%(log_color)s%(levelname)s: %(message)s"
_DATE_FORMAT = "%m-%d %H:%M:%S"


class ColoredFormatter(colorlog.ColoredFormatter):
    """Adds logging prefix to newlines to align multi-line messages."""

    def __init__(self, fmt, datefmt=None, log_colors=None, reset=True, style="%"):
        super().__init__(
            fmt, datefmt=datefmt, log_colors=log_colors, reset=reset, style=style
        )

    def format(self, record):
        msg = super().format(record)
        if record.message != "":
            parts = msg.split(record.message)
            msg = msg.replace("\n", "\r\n" + parts[0])
        return msg


_root_logger = logging.getLogger("aphrodite")
_default_handler = None


def _setup_logger():
    _root_logger.setLevel(logging.DEBUG)
    global _default_handler
    if _default_handler is None:
        _default_handler = logging.StreamHandler(sys.stdout)
        _default_handler.flush = sys.stdout.flush  # type: ignore
        _default_handler.setLevel(logging.INFO)
        _root_logger.addHandler(_default_handler)
    fmt = ColoredFormatter(
        _FORMAT,
        datefmt=_DATE_FORMAT,
        log_colors={
            "DEBUG": "cyan",
            "INFO": "green",
            "WARNING": "yellow",
            "ERROR": "red",
            "CRITICAL": "red,bg_white",
        },
        reset=True,
    )
    _default_handler.setFormatter(fmt)
    # Setting this will avoid the message
    # being propagated to the parent logger.
    _root_logger.propagate = False


# The logger is initialized when the module is imported.
# This is thread-safe as the module is only imported once,
# guaranteed by the Python GIL.
_setup_logger()


def init_logger(name: str):
    logger = logging.getLogger(name)
    logger.setLevel(logging.DEBUG)
    logger.addHandler(_default_handler)
    logger.propagate = False
    return logger
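The rest of the commit consumes this module with the same three-line pattern. Below is a minimal usage sketch assembled only from calls that appear elsewhere in this diff, nothing new is introduced: init_logger returns a per-module logger that shares the colored stdout handler, and because ColoredFormatter re-applies the level prefix after each newline, multi-line messages stay aligned under their level tag.

# Usage sketch mirroring the pattern used across the changed files.
from logger import init_logger

logger = init_logger(__name__)

logger.info("Model successfully loaded.")
logger.warning(
    "Draft model is disabled because a model name "
    "wasn't provided. Please check your config.yml!"
)
logger.error("Completion request cancelled by user.")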

main.py (19 changed lines)

@@ -41,6 +41,9 @@ from OAI.utils_oai import (
 )
 from templating import get_prompt_from_template
 from utils import get_generator_error, get_sse_packet, load_progress, unwrap
+from logger import init_logger
+
+logger = init_logger(__name__)
 
 app = FastAPI()
@@ -210,8 +213,8 @@ async def load_model(request: Request, data: ModelLoadRequest):
             yield get_sse_packet(response.model_dump_json())
         except CancelledError:
-            print(
-                "\nError: Model load cancelled by user. "
+            logger.error(
+                "Model load cancelled by user. "
                 "Please make sure to run unload to free up resources."
             )
         except Exception as exc:
@@ -369,7 +372,7 @@ async def generate_completion(request: Request, data: CompletionRequest):
             yield get_sse_packet(response.model_dump_json())
         except CancelledError:
-            print("Error: Completion request cancelled by user.")
+            logger.error("Completion request cancelled by user.")
         except Exception as exc:
             yield get_generator_error(str(exc))
@@ -456,7 +459,7 @@ async def generate_chat_completion(request: Request, data: ChatCompletionRequest
             yield get_sse_packet(finish_response.model_dump_json())
         except CancelledError:
-            print("Error: Chat completion cancelled by user.")
+            logger.error("Chat completion cancelled by user.")
         except Exception as exc:
             yield get_generator_error(str(exc))
@@ -481,10 +484,10 @@ if __name__ == "__main__":
         with open("config.yml", "r", encoding="utf8") as config_file:
             config = unwrap(yaml.safe_load(config_file), {})
     except Exception as exc:
-        print(
-            "The YAML config couldn't load because of the following error:",
-            f"\n\n{exc}",
-            "\n\nTabbyAPI will start anyway and not parse this config file.",
+        logger.error(
+            "The YAML config couldn't load because of the following error: "
+            f"\n\n{exc}"
+            "\n\nTabbyAPI will start anyway and not parse this config file."
         )
         config = {}

View file

@@ -23,6 +23,9 @@ from templating import (
     get_template_from_file,
 )
 from utils import coalesce, unwrap
+from logger import init_logger
+
+logger = init_logger(__name__)
 
 # Bytes to reserve on first device when loading with auto split
 AUTO_SPLIT_RESERVE_BYTES = 96 * 1024**2
@@ -143,10 +146,7 @@ class ModelContainer:
         # Set prompt template override if provided
         prompt_template_name = kwargs.get("prompt_template")
         if prompt_template_name:
-            print(
-                "Attempting to load prompt template with name",
-                {prompt_template_name},
-            )
+            logger.info("Loading prompt template with name " f"{prompt_template_name}")
             # Read the template
             self.prompt_template = get_template_from_file(prompt_template_name)
         else:
@@ -175,13 +175,13 @@ class ModelContainer:
         # Catch all for template lookup errors
         if self.prompt_template:
-            print(
-                f"Using template {self.prompt_template.name} for chat " "completions."
+            logger.info(
+                f"Using template {self.prompt_template.name} " "for chat completions."
             )
         else:
-            print(
-                "Chat completions are disabled because a prompt template",
-                "wasn't provided or auto-detected.",
+            logger.warning(
+                "Chat completions are disabled because a prompt "
+                "template wasn't provided or auto-detected."
             )
 
         # Set num of experts per token if provided
@@ -190,9 +190,9 @@ class ModelContainer:
             if hasattr(self.config, "num_experts_per_token"):
                 self.config.num_experts_per_token = num_experts_override
             else:
-                print(
-                    " !! Warning: Currently installed ExLlamaV2 does not "
-                    "support overriding MoE experts"
+                logger.warning(
+                    "MoE experts per token override is not "
+                    "supported by the current ExLlamaV2 version."
                 )
 
         chunk_size = min(
@@ -207,9 +207,9 @@ class ModelContainer:
         # Always disable draft if params are incorrectly configured
         if draft_args and draft_model_name is None:
-            print(
-                "A draft config was found but a model name was not given. "
-                "Please check your config.yml! Skipping draft load."
+            logger.warning(
+                "Draft model is disabled because a model name "
+                "wasn't provided. Please check your config.yml!"
             )
             enable_draft = False
@@ -283,20 +283,20 @@ class ModelContainer:
             lora_scaling = unwrap(lora.get("scaling"), 1.0)
 
             if lora_name is None:
-                print(
+                logger.warning(
                     "One of your loras does not have a name. Please check your "
                     "config.yml! Skipping lora load."
                 )
                 failure.append(lora_name)
                 continue
 
-            print(f"Loading lora: {lora_name} at scaling {lora_scaling}")
+            logger.info(f"Loading lora: {lora_name} at scaling {lora_scaling}")
             lora_path = lora_directory / lora_name
 
             # FIXME(alpin): Does self.model need to be passed here?
             self.active_loras.append(
                 ExLlamaV2Lora.from_directory(self.model, lora_path, lora_scaling)
             )
-            print("Lora successfully loaded.")
+            logger.info(f"Lora successfully loaded: {lora_name}")
             success.append(lora_name)
 
         # Return success and failure names
@@ -319,7 +319,7 @@ class ModelContainer:
         if self.draft_config:
             self.draft_model = ExLlamaV2(self.draft_config)
             if not self.quiet:
-                print("Loading draft model: " + self.draft_config.model_dir)
+                logger.info("Loading draft model: " + self.draft_config.model_dir)
 
             self.draft_cache = ExLlamaV2Cache(self.draft_model, lazy=True)
             reserve = [AUTO_SPLIT_RESERVE_BYTES] + [0] * 16
@@ -337,7 +337,7 @@ class ModelContainer:
         # Load model
         self.model = ExLlamaV2(self.config)
         if not self.quiet:
-            print("Loading model: " + self.config.model_dir)
+            logger.info("Loading model: " + self.config.model_dir)
 
         if not self.gpu_split_auto:
             for value in self.model.load_gen(
@@ -373,7 +373,7 @@ class ModelContainer:
                 self.draft_cache,
             )
 
-        print("Model successfully loaded.")
+        logger.info("Model successfully loaded.")
 
     def unload(self, loras_only: bool = False):
         """
@@ -494,33 +494,33 @@ class ModelContainer:
         if (unwrap(kwargs.get("mirostat"), False)) and not hasattr(
             gen_settings, "mirostat"
         ):
-            print(
-                " !! Warning: Currently installed ExLlamaV2 does not support "
-                "Mirostat sampling"
+            logger.warning(
+                "Mirostat sampling is not supported by the currently "
+                "installed ExLlamaV2 version."
             )
 
         if (unwrap(kwargs.get("min_p"), 0.0)) not in [0.0, 1.0] and not hasattr(
             gen_settings, "min_p"
         ):
-            print(
-                " !! Warning: Currently installed ExLlamaV2 does not "
-                "support min-P sampling"
+            logger.warning(
+                "Min-P sampling is not supported by the currently "
+                "installed ExLlamaV2 version."
             )
 
         if (unwrap(kwargs.get("tfs"), 0.0)) not in [0.0, 1.0] and not hasattr(
             gen_settings, "tfs"
         ):
-            print(
-                " !! Warning: Currently installed ExLlamaV2 does not support "
-                "tail-free sampling (TFS)"
+            logger.warning(
+                "Tail-free sampling (TFS) is not supported by the currently "
+                "installed ExLlamaV2 version."
            )
 
         if (unwrap(kwargs.get("temperature_last"), False)) and not hasattr(
             gen_settings, "temperature_last"
         ):
-            print(
-                " !! Warning: Currently installed ExLlamaV2 does not support "
-                "temperature_last"
+            logger.warning(
+                "Temperature last is not supported by the currently "
+                "installed ExLlamaV2 version."
             )
 
         # Apply settings
@@ -614,10 +614,10 @@ class ModelContainer:
         context_len = len(ids[0])
 
         if context_len > self.config.max_seq_len:
-            print(
-                f"WARNING: The context length {context_len} is greater than "
-                f"the max_seq_len {self.config.max_seq_len}.",
-                "Generation is truncated and metrics may not be accurate.",
+            logger.warning(
+                f"Context length {context_len} is greater than max_seq_len "
+                f"{self.config.max_seq_len}. Generation is truncated and "
+                "metrics may not be accurate."
             )
 
         prompt_tokens = ids.shape[-1]
@@ -705,7 +705,7 @@ class ModelContainer:
             extra_parts.append("<-- Not accurate (truncated)")
 
         # Print output
-        print(
+        logger.info(
             initial_response
             + " ("
             + ", ".join(itemization)

View file

@@ -13,3 +13,4 @@ PyYAML
 progress
 uvicorn
 jinja2 >= 3.0.0
+colorlog

View file

@@ -19,6 +19,7 @@ PyYAML
 progress
 uvicorn
 jinja2 >= 3.0.0
+colorlog
 
 # Linux FA2 from https://github.com/Dao-AILab/flash-attention/releases
 https://github.com/Dao-AILab/flash-attention/releases/download/v2.3.6/flash_attn-2.3.6+cu118torch2.1cxx11abiFALSE-cp310-cp310-linux_x86_64.whl; platform_system == "Linux" and platform_machine == "x86_64" and python_version == "3.10"

View file

@@ -5,3 +5,4 @@ PyYAML
 progress
 uvicorn
 jinja2 >= 3.0.0
+colorlog

View file

@@ -19,6 +19,7 @@ PyYAML
 progress
 uvicorn
 jinja2 >= 3.0.0
+colorlog
 
 # Flash attention v2

View file

@@ -2,45 +2,52 @@
 from importlib.metadata import version
 from importlib.util import find_spec
+from logger import init_logger
+
+logger = init_logger(__name__)
+
 
 successful_packages = []
 errored_packages = []
 
 if find_spec("flash_attn") is not None:
-    print(f"Flash attention on version {version('flash_attn')} successfully imported")
+    logger.info(
+        f"Flash attention on version {version('flash_attn')} " "successfully imported"
+    )
     successful_packages.append("flash_attn")
 else:
-    print("Flash attention 2 is not found in your environment.")
+    logger.error("Flash attention 2 is not found in your environment.")
     errored_packages.append("flash_attn")
 
 if find_spec("exllamav2") is not None:
-    print(f"Exllamav2 on version {version('exllamav2')} successfully imported")
+    logger.info(f"Exllamav2 on version {version('exllamav2')} " "successfully imported")
     successful_packages.append("exllamav2")
 else:
-    print("Exllamav2 is not found in your environment.")
+    logger.error("Exllamav2 is not found in your environment.")
     errored_packages.append("exllamav2")
 
 if find_spec("torch") is not None:
-    print(f"Torch on version {version('torch')} successfully imported")
+    logger.info(f"Torch on version {version('torch')} successfully imported")
     successful_packages.append("torch")
 else:
-    print("Torch is not found in your environment.")
+    logger.error("Torch is not found in your environment.")
     errored_packages.append("torch")
 
 if find_spec("jinja2") is not None:
-    print(f"Jinja2 on version {version('jinja2')} successfully imported")
+    logger.info(f"Jinja2 on version {version('jinja2')} successfully imported")
     successful_packages.append("jinja2")
 else:
-    print("Jinja2 is not found in your environment.")
+    logger.error("Jinja2 is not found in your environment.")
     errored_packages.append("jinja2")
 
-print(
-    f"\nSuccessful imports: {', '.join(successful_packages)}",
-    f"\nErrored imports: {''.join(errored_packages)}",
-)
+logger.info(f"\nSuccessful imports: {', '.join(successful_packages)}")
+logger.error(f"Errored imports: {''.join(errored_packages)}")
 
 if len(errored_packages) > 0:
-    print(
-        "\nIf packages are installed, but not found on this test, please "
-        "check the wheel versions for the correct python version and CUDA "
-        "version (if applicable)."
+    logger.warning(
+        "If all packages are installed, but not found "
+        "on this test, please check the wheel versions for the "
+        "correct python version and CUDA version (if "
+        "applicable)."
     )
+else:
+    logger.info("All wheels are installed correctly.")

View file

@@ -4,6 +4,10 @@ from typing import Optional
 from pydantic import BaseModel
 
+from logger import init_logger
+
+logger = init_logger(__name__)
+
 
 def load_progress(module, modules):
     """Wrapper callback for load progress."""
@@ -32,7 +36,7 @@ def get_generator_error(message: str):
     generator_error = TabbyGeneratorError(error=error_message)
 
     # Log and send the exception
-    print(f"\n{generator_error.error.trace}")
+    logger.error(generator_error.error.message)
     return get_sse_packet(generator_error.model_dump_json())