feat: logging (#39)
* add logging
* simplify the logger
* formatting
* final touches
* fix format
* Model: Add log to metrics

Signed-off-by: kingbri <bdashore3@proton.me>

---------

Authored-by: AlpinDale <52078762+AlpinDale@users.noreply.github.com>
parent f5314fcdad
commit 6a5bbd217c

11 changed files with 170 additions and 74 deletions
auth.py (15 lines changed)

@@ -9,6 +9,10 @@ from fastapi import Header, HTTPException
 from pydantic import BaseModel
 import yaml
 
+from logger import init_logger
+
+logger = init_logger(__name__)
+
 
 class AuthKeys(BaseModel):
     """
@@ -44,11 +48,10 @@ def load_auth_keys(disable_from_config: bool):
 
     DISABLE_AUTH = disable_from_config
     if disable_from_config:
-        print(
-            "!! Warning: Disabling authentication",
-            "makes your instance vulnerable.",
-            "Set the 'disable_auth' flag to False in config.yml",
-            "if you want to share this instance with others.",
+        logger.warning(
+            "Disabling authentication makes your instance vulnerable. "
+            "Set the `disable_auth` flag to False in config.yml if you "
+            "want to share this instance with others."
         )
 
         return
@@ -66,7 +69,7 @@ def load_auth_keys(disable_from_config: bool):
     with open("api_tokens.yml", "w", encoding="utf8") as auth_file:
         yaml.safe_dump(AUTH_KEYS.model_dump(), auth_file, default_flow_style=False)
 
-    print(
+    logger.info(
         f"Your API key is: {AUTH_KEYS.api_key}\n"
         f"Your admin key is: {AUTH_KEYS.admin_key}\n\n"
         "If these keys get compromised, make sure to delete api_tokens.yml "
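The recurring pattern in this commit is visible above: a print() call with comma-separated fragments becomes a single concatenated message passed to a leveled logger call. A minimal stdlib-only sketch of that pattern (basicConfig here is illustrative; the repo instead wires up colorlog through the new logger.py below):

import logging

# Illustrative handler setup; the project uses its own logger.py instead.
logging.basicConfig(format="%(levelname)s: %(message)s", level=logging.INFO)
logger = logging.getLogger(__name__)

# Before: print("!! Warning: Disabling authentication", "makes your instance vulnerable.", ...)
# After: one concatenated string, routed through the warning level.
logger.warning(
    "Disabling authentication makes your instance vulnerable. "
    "Set the `disable_auth` flag to False in config.yml if you "
    "want to share this instance with others."
)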
@@ -4,6 +4,10 @@ Functions for logging generation events.
 from typing import Dict
 from pydantic import BaseModel
 
+from logger import init_logger
+
+logger = init_logger(__name__)
+
 
 class LogConfig(BaseModel):
     """Logging preference config."""
@@ -38,24 +42,24 @@ def broadcast_status():
         enabled.append("generation params")
 
     if len(enabled) > 0:
-        print("Generation logging is enabled for: " + ", ".join(enabled))
+        logger.info("Generation logging is enabled for: " + ", ".join(enabled))
     else:
-        print("Generation logging is disabled")
+        logger.info("Generation logging is disabled")
 
 
 def log_generation_params(**kwargs):
     """Logs generation parameters to console."""
     if CONFIG.generation_params:
-        print(f"Generation options: {kwargs}\n")
+        logger.info(f"Generation options: {kwargs}\n")
 
 
 def log_prompt(prompt: str):
     """Logs the prompt to console."""
     if CONFIG.prompt:
-        print(f"Prompt: {prompt if prompt else 'Empty'}\n")
+        logger.info(f"Prompt: {prompt if prompt else 'Empty'}\n")
 
 
 def log_response(response: str):
     """Logs the response to console."""
     if CONFIG.prompt:
-        print(f"Response: {response if response else 'Empty'}\n")
+        logger.info(f"Response: {response if response else 'Empty'}\n")
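For orientation, a self-contained sketch of how these gated helpers behave. The module's file name is not shown in this view, and the LogConfig field names and defaults below are inferred from the checks above, so treat the details as illustrative; the real helpers now route output through logger.info rather than print.

from pydantic import BaseModel


class LogConfig(BaseModel):
    """Logging preference config (fields inferred from the checks above)."""

    prompt: bool = False
    generation_params: bool = False


# The repo populates CONFIG elsewhere; this instance is illustrative only.
CONFIG = LogConfig(prompt=True, generation_params=False)


def log_prompt(prompt: str):
    """Emit the prompt only when enabled in CONFIG."""
    if CONFIG.prompt:
        print(f"Prompt: {prompt if prompt else 'Empty'}\n")


log_prompt("Hello")  # emitted, because CONFIG.prompt is True
# A generation-params log would be skipped here, because generation_params is False.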
logger.py (new file, 71 lines)

@@ -0,0 +1,71 @@
+"""
+Logging utility.
+https://github.com/PygmalionAI/aphrodite-engine/blob/main/aphrodite/common/logger.py
+"""
+
+import logging
+import sys
+import colorlog
+
+_FORMAT = "%(log_color)s%(levelname)s: %(message)s"
+_DATE_FORMAT = "%m-%d %H:%M:%S"
+
+
+class ColoredFormatter(colorlog.ColoredFormatter):
+    """Adds logging prefix to newlines to align multi-line messages."""
+
+    def __init__(self, fmt, datefmt=None, log_colors=None, reset=True, style="%"):
+        super().__init__(
+            fmt, datefmt=datefmt, log_colors=log_colors, reset=reset, style=style
+        )
+
+    def format(self, record):
+        msg = super().format(record)
+        if record.message != "":
+            parts = msg.split(record.message)
+            msg = msg.replace("\n", "\r\n" + parts[0])
+        return msg
+
+
+_root_logger = logging.getLogger("aphrodite")
+_default_handler = None
+
+
+def _setup_logger():
+    _root_logger.setLevel(logging.DEBUG)
+    global _default_handler
+    if _default_handler is None:
+        _default_handler = logging.StreamHandler(sys.stdout)
+        _default_handler.flush = sys.stdout.flush  # type: ignore
+        _default_handler.setLevel(logging.INFO)
+        _root_logger.addHandler(_default_handler)
+    fmt = ColoredFormatter(
+        _FORMAT,
+        datefmt=_DATE_FORMAT,
+        log_colors={
+            "DEBUG": "cyan",
+            "INFO": "green",
+            "WARNING": "yellow",
+            "ERROR": "red",
+            "CRITICAL": "red,bg_white",
+        },
+        reset=True,
+    )
+    _default_handler.setFormatter(fmt)
+    # Setting this will avoid the message
+    # being propagated to the parent logger.
+    _root_logger.propagate = False
+
+
+# The logger is initialized when the module is imported.
+# This is thread-safe as the module is only imported once,
+# guaranteed by the Python GIL.
+_setup_logger()
+
+
+def init_logger(name: str):
+    logger = logging.getLogger(name)
+    logger.setLevel(logging.DEBUG)
+    logger.addHandler(_default_handler)
+    logger.propagate = False
+    return logger
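A hedged usage sketch of the new helper: every touched module obtains its logger the same way. This assumes logger.py above is importable from the working directory and that colorlog is installed (see the requirements changes below).

from logger import init_logger

logger = init_logger(__name__)

logger.info("Model successfully loaded.")

# Multi-line messages keep their level prefix on continuation lines because
# ColoredFormatter.format() re-inserts the formatted prefix after each newline.
logger.warning("first line\nsecond line")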
main.py (19 lines changed)

@@ -41,6 +41,9 @@ from OAI.utils_oai import (
 )
 from templating import get_prompt_from_template
 from utils import get_generator_error, get_sse_packet, load_progress, unwrap
+from logger import init_logger
+
+logger = init_logger(__name__)
 
 app = FastAPI()
 
@@ -210,8 +213,8 @@ async def load_model(request: Request, data: ModelLoadRequest):
 
             yield get_sse_packet(response.model_dump_json())
         except CancelledError:
-            print(
-                "\nError: Model load cancelled by user. "
+            logger.error(
+                "Model load cancelled by user. "
                 "Please make sure to run unload to free up resources."
             )
         except Exception as exc:
@@ -369,7 +372,7 @@ async def generate_completion(request: Request, data: CompletionRequest):
 
             yield get_sse_packet(response.model_dump_json())
         except CancelledError:
-            print("Error: Completion request cancelled by user.")
+            logger.error("Completion request cancelled by user.")
         except Exception as exc:
             yield get_generator_error(str(exc))
 
@@ -456,7 +459,7 @@ async def generate_chat_completion(request: Request, data: ChatCompletionRequest
 
             yield get_sse_packet(finish_response.model_dump_json())
         except CancelledError:
-            print("Error: Chat completion cancelled by user.")
+            logger.error("Chat completion cancelled by user.")
         except Exception as exc:
             yield get_generator_error(str(exc))
 
@@ -481,10 +484,10 @@ if __name__ == "__main__":
         with open("config.yml", "r", encoding="utf8") as config_file:
             config = unwrap(yaml.safe_load(config_file), {})
     except Exception as exc:
-        print(
-            "The YAML config couldn't load because of the following error:",
-            f"\n\n{exc}",
-            "\n\nTabbyAPI will start anyway and not parse this config file.",
+        logger.error(
+            "The YAML config couldn't load because of the following error: "
+            f"\n\n{exc}"
+            "\n\nTabbyAPI will start anyway and not parse this config file."
         )
         config = {}
 
model.py (76 lines changed)

@@ -23,6 +23,9 @@ from templating import (
     get_template_from_file,
 )
 from utils import coalesce, unwrap
+from logger import init_logger
+
+logger = init_logger(__name__)
 
 # Bytes to reserve on first device when loading with auto split
 AUTO_SPLIT_RESERVE_BYTES = 96 * 1024**2
@@ -143,10 +146,7 @@ class ModelContainer:
         # Set prompt template override if provided
         prompt_template_name = kwargs.get("prompt_template")
        if prompt_template_name:
-            print(
-                "Attempting to load prompt template with name",
-                {prompt_template_name},
-            )
+            logger.info("Loading prompt template with name " f"{prompt_template_name}")
             # Read the template
             self.prompt_template = get_template_from_file(prompt_template_name)
         else:
@@ -175,13 +175,13 @@
 
         # Catch all for template lookup errors
         if self.prompt_template:
-            print(
-                f"Using template {self.prompt_template.name} for chat " "completions."
+            logger.info(
+                f"Using template {self.prompt_template.name} " "for chat completions."
             )
         else:
-            print(
-                "Chat completions are disabled because a prompt template",
-                "wasn't provided or auto-detected.",
+            logger.warning(
+                "Chat completions are disabled because a prompt "
+                "template wasn't provided or auto-detected."
             )
 
         # Set num of experts per token if provided
@@ -190,9 +190,9 @@
             if hasattr(self.config, "num_experts_per_token"):
                 self.config.num_experts_per_token = num_experts_override
             else:
-                print(
-                    " !! Warning: Currently installed ExLlamaV2 does not "
-                    "support overriding MoE experts"
+                logger.warning(
+                    "MoE experts per token override is not "
+                    "supported by the current ExLlamaV2 version."
                 )
 
         chunk_size = min(
@@ -207,9 +207,9 @@
 
         # Always disable draft if params are incorrectly configured
         if draft_args and draft_model_name is None:
-            print(
-                "A draft config was found but a model name was not given. "
-                "Please check your config.yml! Skipping draft load."
+            logger.warning(
+                "Draft model is disabled because a model name "
+                "wasn't provided. Please check your config.yml!"
             )
             enable_draft = False
 
@@ -283,20 +283,20 @@
             lora_scaling = unwrap(lora.get("scaling"), 1.0)
 
             if lora_name is None:
-                print(
+                logger.warning(
                     "One of your loras does not have a name. Please check your "
                     "config.yml! Skipping lora load."
                 )
                 failure.append(lora_name)
                 continue
 
-            print(f"Loading lora: {lora_name} at scaling {lora_scaling}")
+            logger.info(f"Loading lora: {lora_name} at scaling {lora_scaling}")
             lora_path = lora_directory / lora_name
             # FIXME(alpin): Does self.model need to be passed here?
             self.active_loras.append(
                 ExLlamaV2Lora.from_directory(self.model, lora_path, lora_scaling)
             )
-            print("Lora successfully loaded.")
+            logger.info(f"Lora successfully loaded: {lora_name}")
             success.append(lora_name)
 
         # Return success and failure names
@@ -319,7 +319,7 @@
         if self.draft_config:
             self.draft_model = ExLlamaV2(self.draft_config)
             if not self.quiet:
-                print("Loading draft model: " + self.draft_config.model_dir)
+                logger.info("Loading draft model: " + self.draft_config.model_dir)
 
             self.draft_cache = ExLlamaV2Cache(self.draft_model, lazy=True)
             reserve = [AUTO_SPLIT_RESERVE_BYTES] + [0] * 16
@@ -337,7 +337,7 @@
         # Load model
         self.model = ExLlamaV2(self.config)
         if not self.quiet:
-            print("Loading model: " + self.config.model_dir)
+            logger.info("Loading model: " + self.config.model_dir)
 
         if not self.gpu_split_auto:
             for value in self.model.load_gen(
@@ -373,7 +373,7 @@
                 self.draft_cache,
             )
 
-        print("Model successfully loaded.")
+        logger.info("Model successfully loaded.")
 
     def unload(self, loras_only: bool = False):
         """
@@ -494,33 +494,33 @@
         if (unwrap(kwargs.get("mirostat"), False)) and not hasattr(
             gen_settings, "mirostat"
         ):
-            print(
-                " !! Warning: Currently installed ExLlamaV2 does not support "
-                "Mirostat sampling"
+            logger.warning(
+                "Mirostat sampling is not supported by the currently "
+                "installed ExLlamaV2 version."
             )
 
         if (unwrap(kwargs.get("min_p"), 0.0)) not in [0.0, 1.0] and not hasattr(
             gen_settings, "min_p"
         ):
-            print(
-                " !! Warning: Currently installed ExLlamaV2 does not "
-                "support min-P sampling"
+            logger.warning(
+                "Min-P sampling is not supported by the currently "
+                "installed ExLlamaV2 version."
             )
 
         if (unwrap(kwargs.get("tfs"), 0.0)) not in [0.0, 1.0] and not hasattr(
             gen_settings, "tfs"
        ):
-            print(
-                " !! Warning: Currently installed ExLlamaV2 does not support "
-                "tail-free sampling (TFS)"
+            logger.warning(
+                "Tail-free sampling (TFS) is not supported by the currently "
+                "installed ExLlamaV2 version."
             )
 
         if (unwrap(kwargs.get("temperature_last"), False)) and not hasattr(
             gen_settings, "temperature_last"
         ):
-            print(
-                " !! Warning: Currently installed ExLlamaV2 does not support "
-                "temperature_last"
+            logger.warning(
+                "Temperature last is not supported by the currently "
+                "installed ExLlamaV2 version."
             )
 
         # Apply settings
@@ -614,10 +614,10 @@
         context_len = len(ids[0])
 
         if context_len > self.config.max_seq_len:
-            print(
-                f"WARNING: The context length {context_len} is greater than "
-                f"the max_seq_len {self.config.max_seq_len}.",
-                "Generation is truncated and metrics may not be accurate.",
+            logger.warning(
+                f"Context length {context_len} is greater than max_seq_len "
+                f"{self.config.max_seq_len}. Generation is truncated and "
+                "metrics may not be accurate."
             )
 
         prompt_tokens = ids.shape[-1]
@@ -705,7 +705,7 @@
             extra_parts.append("<-- Not accurate (truncated)")
 
         # Print output
-        print(
+        logger.info(
             initial_response
             + " ("
             + ", ".join(itemization)
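The ExLlamaV2 capability checks above all follow one shape: warn only when the caller asked for a sampler feature and the installed sampler settings object lacks the corresponding attribute. A hedged, self-contained sketch of that shape; the _Settings class and unwrap helper below are stand-ins for the repo's objects, with unwrap assumed to fall back to the default only when the value is None.

import logging

logging.basicConfig(format="%(levelname)s: %(message)s", level=logging.INFO)
logger = logging.getLogger(__name__)


def unwrap(value, default):
    """Stand-in for the repo's utils.unwrap: use default only when value is None."""
    return default if value is None else value


class _Settings:
    """Illustrative stand-in for a sampler settings object without mirostat support."""

    temperature = 1.0


gen_settings = _Settings()
kwargs = {"mirostat": True}

# Requested but unsupported -> a single warning instead of a multi-arg print.
if unwrap(kwargs.get("mirostat"), False) and not hasattr(gen_settings, "mirostat"):
    logger.warning(
        "Mirostat sampling is not supported by the currently "
        "installed ExLlamaV2 version."
    )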
@@ -13,3 +13,4 @@ PyYAML
 progress
 uvicorn
 jinja2 >= 3.0.0
+colorlog

@@ -19,6 +19,7 @@ PyYAML
 progress
 uvicorn
 jinja2 >= 3.0.0
+colorlog
 
 # Linux FA2 from https://github.com/Dao-AILab/flash-attention/releases
 https://github.com/Dao-AILab/flash-attention/releases/download/v2.3.6/flash_attn-2.3.6+cu118torch2.1cxx11abiFALSE-cp310-cp310-linux_x86_64.whl; platform_system == "Linux" and platform_machine == "x86_64" and python_version == "3.10"

@@ -5,3 +5,4 @@ PyYAML
 progress
 uvicorn
 jinja2 >= 3.0.0
+colorlog

@@ -19,6 +19,7 @@ PyYAML
 progress
 uvicorn
 jinja2 >= 3.0.0
+colorlog
 
 # Flash attention v2
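colorlog is the only new dependency introduced by these requirements hunks. A standalone sketch of what it provides; the handler setup and color choices below mirror the new logger.py rather than documenting a separate API.

import logging
import sys

import colorlog

# Attach a colorized formatter to a plain stdout handler.
handler = logging.StreamHandler(sys.stdout)
handler.setFormatter(
    colorlog.ColoredFormatter(
        "%(log_color)s%(levelname)s: %(message)s",
        log_colors={"INFO": "green", "WARNING": "yellow", "ERROR": "red"},
    )
)

logger = logging.getLogger("demo")
logger.setLevel(logging.INFO)
logger.addHandler(handler)
logger.info("colorlog formats this record with a green INFO prefix")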
@@ -2,45 +2,52 @@
 from importlib.metadata import version
 from importlib.util import find_spec
 
+from logger import init_logger
+
+logger = init_logger(__name__)
+
 successful_packages = []
 errored_packages = []
 
 if find_spec("flash_attn") is not None:
-    print(f"Flash attention on version {version('flash_attn')} successfully imported")
+    logger.info(
+        f"Flash attention on version {version('flash_attn')} " "successfully imported"
+    )
     successful_packages.append("flash_attn")
 else:
-    print("Flash attention 2 is not found in your environment.")
+    logger.error("Flash attention 2 is not found in your environment.")
     errored_packages.append("flash_attn")
 
 if find_spec("exllamav2") is not None:
-    print(f"Exllamav2 on version {version('exllamav2')} successfully imported")
+    logger.info(f"Exllamav2 on version {version('exllamav2')} " "successfully imported")
     successful_packages.append("exllamav2")
 else:
-    print("Exllamav2 is not found in your environment.")
+    logger.error("Exllamav2 is not found in your environment.")
     errored_packages.append("exllamav2")
 
 if find_spec("torch") is not None:
-    print(f"Torch on version {version('torch')} successfully imported")
+    logger.info(f"Torch on version {version('torch')} successfully imported")
     successful_packages.append("torch")
 else:
-    print("Torch is not found in your environment.")
+    logger.error("Torch is not found in your environment.")
     errored_packages.append("torch")
 
 if find_spec("jinja2") is not None:
-    print(f"Jinja2 on version {version('jinja2')} successfully imported")
+    logger.info(f"Jinja2 on version {version('jinja2')} successfully imported")
     successful_packages.append("jinja2")
 else:
-    print("Jinja2 is not found in your environment.")
+    logger.error("Jinja2 is not found in your environment.")
     errored_packages.append("jinja2")
 
-print(
-    f"\nSuccessful imports: {', '.join(successful_packages)}",
-    f"\nErrored imports: {''.join(errored_packages)}",
-)
+logger.info(f"\nSuccessful imports: {', '.join(successful_packages)}")
+logger.error(f"Errored imports: {''.join(errored_packages)}")
 
 if len(errored_packages) > 0:
-    print(
-        "\nIf packages are installed, but not found on this test, please "
-        "check the wheel versions for the correct python version and CUDA "
-        "version (if applicable)."
+    logger.warning(
+        "If all packages are installed, but not found "
+        "on this test, please check the wheel versions for the "
+        "correct python version and CUDA version (if "
+        "applicable)."
     )
+else:
+    logger.info("All wheels are installed correctly.")
utils.py (6 lines changed)

@@ -4,6 +4,10 @@ from typing import Optional
 
 from pydantic import BaseModel
 
+from logger import init_logger
+
+logger = init_logger(__name__)
+
 
 def load_progress(module, modules):
     """Wrapper callback for load progress."""
@@ -32,7 +36,7 @@ def get_generator_error(message: str):
     generator_error = TabbyGeneratorError(error=error_message)
 
     # Log and send the exception
-    print(f"\n{generator_error.error.trace}")
+    logger.error(generator_error.error.message)
     return get_sse_packet(generator_error.model_dump_json())