Model: Cleanup logging and remove extraneous declarations

Log the parameters passed into the `generate_gen` function rather than
the generation settings to reduce complexity.

Signed-off-by: kingbri <8082010+kingbri1@users.noreply.github.com>
This commit is contained in:
kingbri 2025-04-15 23:31:12 -04:00
parent 436ce752da
commit 11ed3cf5ee
3 changed files with 63 additions and 45 deletions

View file

@ -1,5 +1,6 @@
"""The model container class for ExLlamaV2 models."""
from functools import partial
import aiofiles
import asyncio
import gc
@ -31,7 +32,7 @@ from exllamav2.generator import (
)
from itertools import zip_longest
from loguru import logger
from typing import List, Optional, Union
from typing import Dict, List, Optional, Union
from ruamel.yaml import YAML
@ -106,6 +107,7 @@ class ExllamaV2Container:
# Load synchronization
# The lock keeps load tasks sequential
# The condition notifies any waiting tasks
active_job_ids: Dict[str, ExLlamaV2DynamicJobAsync] = {}
load_lock: asyncio.Lock = asyncio.Lock()
load_condition: asyncio.Condition = asyncio.Condition()
@ -887,12 +889,7 @@ class ExllamaV2Container:
self.model = None
if self.vision_model:
# TODO: Remove this with newer exl2 versions
# Required otherwise unload function won't finish
try:
self.vision_model.unload()
except AttributeError:
pass
self.vision_model.unload()
self.vision_model = None
@ -950,7 +947,6 @@ class ExllamaV2Container:
decode_special_tokens=unwrap(kwargs.get("decode_special_tokens"), True),
)[0]
# TODO: Maybe support generation_config for eos_token
def get_special_tokens(
self, add_bos_token: bool = True, ban_eos_token: bool = False
):
@ -1042,13 +1038,6 @@ class ExllamaV2Container:
Meant for dev wheels!
"""
if unwrap(kwargs.get("xtc_probability"), 0.0) > 0.0 and not hasattr(
ExLlamaV2Sampler.Settings, "xtc_probability"
):
logger.warning(
"XTC is not supported by the currently " "installed ExLlamaV2 version."
)
return kwargs
async def generate_gen(
@ -1082,6 +1071,7 @@ class ExllamaV2Container:
kwargs = self.check_unsupported_settings(**kwargs)
# Apply settings
partial(gen_settings.temperature, 1.0)
gen_settings.temperature = unwrap(kwargs.get("temperature"), 1.0)
gen_settings.temperature_last = unwrap(kwargs.get("temperature_last"), False)
gen_settings.smoothing_factor = unwrap(kwargs.get("smoothing_factor"), 0.0)
@ -1191,7 +1181,6 @@ class ExllamaV2Container:
if dry_multiplier > 0:
gen_settings.dry_multiplier = dry_multiplier
# TODO: Maybe set the "sane" defaults instead?
gen_settings.dry_allowed_length = unwrap(
kwargs.get("dry_allowed_length"), 0
)
@ -1261,18 +1250,10 @@ class ExllamaV2Container:
gen_settings.typical = 0
logger.warning(
"".join(
[
"Temperature is set to 0. Overriding temp, ",
"top_k, top_p, and typical to 1.0, 1, 0, and 0.",
]
)
"Temperature is set to 0. Overriding temp, "
"top_k, top_p, and typical to 1.0, 1, 0, and 0."
)
# Store the gen settings for logging purposes
# Deepcopy to save a snapshot of vars
gen_settings_log_dict = deepcopy(vars(gen_settings))
# Set banned tokens
banned_tokens = unwrap(kwargs.get("banned_tokens"), [])
if banned_tokens:
@ -1522,26 +1503,11 @@ class ExllamaV2Container:
# Some options are too large, so log the args instead
log_generation_params(
request_id=request_id,
max_tokens=max_tokens,
min_tokens=min_tokens,
stream=kwargs.get("stream"),
**gen_settings_log_dict,
token_healing=token_healing,
auto_scale_penalty_range=auto_scale_penalty_range,
generate_window=generate_window,
bos_token_id=self.tokenizer.bos_token_id,
eos_token_id=eos_tokens,
add_bos_token=add_bos_token,
ban_eos_token=ban_eos_token,
skip_special_tokens=not decode_special_tokens,
speculative_ngram=self.generator.speculative_ngram,
logprobs=request_logprobs,
stop_conditions=stop_conditions,
banned_tokens=banned_tokens,
allowed_tokens=allowed_tokens,
banned_strings=banned_strings,
logit_bias=logit_bias,
filters=grammar_handler.filters,
**kwargs,
generate_window=generate_window,
auto_scale_penalty_range=auto_scale_penalty_range,
)
# Log the metrics if present

View file

@ -165,7 +165,7 @@ class BaseSamplerRequest(BaseModel):
"rep_pen_range",
),
description=(
"Aliases: repetition_range, repetition_penalty_range, " "rep_pen_range"
"Aliases: repetition_range, repetition_penalty_range, rep_pen_range"
),
)

View file

@ -1,5 +1,6 @@
"""Common utility functions"""
import inspect
from types import NoneType
from typing import Dict, Optional, Type, Union, get_args, get_origin, TypeVar
@ -85,3 +86,54 @@ def unwrap_optional_type(type_hint) -> Type:
return arg
return type_hint
def with_defer(func):
    """
    Decorator providing a go-style defer mechanism.

    Injects a ``defer`` keyword argument into the decorated function.
    Calls registered via ``defer(fn, *args, **kwargs)`` are executed in
    LIFO order after the function returns or raises, mirroring Go's
    defer semantics.
    """

    # wraps preserves the decorated function's name/docstring/metadata,
    # which the original implementation dropped
    @wraps(func)
    def wrapper(*args, **kwargs):
        deferred_calls = []

        # This 'defer' function is what you'll call inside your decorated function
        def defer(fn, *fn_args, **fn_kwargs):
            deferred_calls.append((fn, fn_args, fn_kwargs))

        try:
            # Inject 'defer' into the kwargs of the original function
            return func(*args, defer=defer, **kwargs)
        finally:
            # After the original function finishes (or raises), run deferred
            # calls in reverse registration order (LIFO, like Go)
            for fn, fn_args, fn_kwargs in reversed(deferred_calls):
                fn(*fn_args, **fn_kwargs)

    return wrapper
def with_defer_async(func):
    """
    Decorator providing a go-style defer mechanism for async functions.

    Injects a ``defer`` keyword argument into the decorated coroutine
    function. Registered calls run in LIFO order after the coroutine
    returns or raises; deferred coroutine functions and coroutine
    objects are awaited, plain callables are invoked synchronously.
    """

    # wraps preserves the decorated function's name/docstring/metadata,
    # which the original implementation dropped
    @wraps(func)
    async def wrapper(*args, **kwargs):
        deferred_calls = []

        # This 'defer' function is what you'll call inside your decorated function
        def defer(fn, *fn_args, **fn_kwargs):
            deferred_calls.append((fn, fn_args, fn_kwargs))

        try:
            # Inject 'defer' into the kwargs of the original function
            return await func(*args, defer=defer, **kwargs)
        finally:
            # After the original function finishes (or raises), run deferred
            # calls in reverse registration order (LIFO, like Go)
            for fn, fn_args, fn_kwargs in reversed(deferred_calls):
                if inspect.iscoroutinefunction(fn):
                    await fn(*fn_args, **fn_kwargs)
                elif inspect.iscoroutine(fn):
                    await fn
                else:
                    fn(*fn_args, **fn_kwargs)

    return wrapper