Model: Clean up logging and remove extraneous declarations
Log the parameters passed into the `generate_gen` function rather than the generation settings to reduce complexity. Signed-off-by: kingbri <8082010+kingbri1@users.noreply.github.com>
This commit is contained in:
parent
436ce752da
commit
11ed3cf5ee
3 changed files with 63 additions and 45 deletions
|
|
@ -1,5 +1,6 @@
|
|||
"""The model container class for ExLlamaV2 models."""
|
||||
|
||||
from functools import partial
|
||||
import aiofiles
|
||||
import asyncio
|
||||
import gc
|
||||
|
|
@ -31,7 +32,7 @@ from exllamav2.generator import (
|
|||
)
|
||||
from itertools import zip_longest
|
||||
from loguru import logger
|
||||
from typing import List, Optional, Union
|
||||
from typing import Dict, List, Optional, Union
|
||||
|
||||
from ruamel.yaml import YAML
|
||||
|
||||
|
|
@ -106,6 +107,7 @@ class ExllamaV2Container:
|
|||
# Load synchronization
|
||||
# The lock keeps load tasks sequential
|
||||
# The condition notifies any waiting tasks
|
||||
active_job_ids: Dict[str, ExLlamaV2DynamicJobAsync] = {}
|
||||
load_lock: asyncio.Lock = asyncio.Lock()
|
||||
load_condition: asyncio.Condition = asyncio.Condition()
|
||||
|
||||
|
|
@ -887,12 +889,7 @@ class ExllamaV2Container:
|
|||
self.model = None
|
||||
|
||||
if self.vision_model:
|
||||
# TODO: Remove this with newer exl2 versions
|
||||
# Required otherwise unload function won't finish
|
||||
try:
|
||||
self.vision_model.unload()
|
||||
except AttributeError:
|
||||
pass
|
||||
self.vision_model.unload()
|
||||
|
||||
self.vision_model = None
|
||||
|
||||
|
|
@ -950,7 +947,6 @@ class ExllamaV2Container:
|
|||
decode_special_tokens=unwrap(kwargs.get("decode_special_tokens"), True),
|
||||
)[0]
|
||||
|
||||
# TODO: Maybe support generation_config for eos_token
|
||||
def get_special_tokens(
|
||||
self, add_bos_token: bool = True, ban_eos_token: bool = False
|
||||
):
|
||||
|
|
@ -1042,13 +1038,6 @@ class ExllamaV2Container:
|
|||
Meant for dev wheels!
|
||||
"""
|
||||
|
||||
if unwrap(kwargs.get("xtc_probability"), 0.0) > 0.0 and not hasattr(
|
||||
ExLlamaV2Sampler.Settings, "xtc_probability"
|
||||
):
|
||||
logger.warning(
|
||||
"XTC is not supported by the currently " "installed ExLlamaV2 version."
|
||||
)
|
||||
|
||||
return kwargs
|
||||
|
||||
async def generate_gen(
|
||||
|
|
@ -1082,6 +1071,7 @@ class ExllamaV2Container:
|
|||
kwargs = self.check_unsupported_settings(**kwargs)
|
||||
|
||||
# Apply settings
|
||||
partial(gen_settings.temperature, 1.0)
|
||||
gen_settings.temperature = unwrap(kwargs.get("temperature"), 1.0)
|
||||
gen_settings.temperature_last = unwrap(kwargs.get("temperature_last"), False)
|
||||
gen_settings.smoothing_factor = unwrap(kwargs.get("smoothing_factor"), 0.0)
|
||||
|
|
@ -1191,7 +1181,6 @@ class ExllamaV2Container:
|
|||
if dry_multiplier > 0:
|
||||
gen_settings.dry_multiplier = dry_multiplier
|
||||
|
||||
# TODO: Maybe set the "sane" defaults instead?
|
||||
gen_settings.dry_allowed_length = unwrap(
|
||||
kwargs.get("dry_allowed_length"), 0
|
||||
)
|
||||
|
|
@ -1261,18 +1250,10 @@ class ExllamaV2Container:
|
|||
gen_settings.typical = 0
|
||||
|
||||
logger.warning(
|
||||
"".join(
|
||||
[
|
||||
"Temperature is set to 0. Overriding temp, ",
|
||||
"top_k, top_p, and typical to 1.0, 1, 0, and 0.",
|
||||
]
|
||||
)
|
||||
"Temperature is set to 0. Overriding temp, "
|
||||
"top_k, top_p, and typical to 1.0, 1, 0, and 0."
|
||||
)
|
||||
|
||||
# Store the gen settings for logging purposes
|
||||
# Deepcopy to save a snapshot of vars
|
||||
gen_settings_log_dict = deepcopy(vars(gen_settings))
|
||||
|
||||
# Set banned tokens
|
||||
banned_tokens = unwrap(kwargs.get("banned_tokens"), [])
|
||||
if banned_tokens:
|
||||
|
|
@ -1522,26 +1503,11 @@ class ExllamaV2Container:
|
|||
# Some options are too large, so log the args instead
|
||||
log_generation_params(
|
||||
request_id=request_id,
|
||||
max_tokens=max_tokens,
|
||||
min_tokens=min_tokens,
|
||||
stream=kwargs.get("stream"),
|
||||
**gen_settings_log_dict,
|
||||
token_healing=token_healing,
|
||||
auto_scale_penalty_range=auto_scale_penalty_range,
|
||||
generate_window=generate_window,
|
||||
bos_token_id=self.tokenizer.bos_token_id,
|
||||
eos_token_id=eos_tokens,
|
||||
add_bos_token=add_bos_token,
|
||||
ban_eos_token=ban_eos_token,
|
||||
skip_special_tokens=not decode_special_tokens,
|
||||
speculative_ngram=self.generator.speculative_ngram,
|
||||
logprobs=request_logprobs,
|
||||
stop_conditions=stop_conditions,
|
||||
banned_tokens=banned_tokens,
|
||||
allowed_tokens=allowed_tokens,
|
||||
banned_strings=banned_strings,
|
||||
logit_bias=logit_bias,
|
||||
filters=grammar_handler.filters,
|
||||
**kwargs,
|
||||
generate_window=generate_window,
|
||||
auto_scale_penalty_range=auto_scale_penalty_range,
|
||||
)
|
||||
|
||||
# Log the metrics if present
|
||||
|
|
|
|||
|
|
@ -165,7 +165,7 @@ class BaseSamplerRequest(BaseModel):
|
|||
"rep_pen_range",
|
||||
),
|
||||
description=(
|
||||
"Aliases: repetition_range, repetition_penalty_range, " "rep_pen_range"
|
||||
"Aliases: repetition_range, repetition_penalty_range, rep_pen_range"
|
||||
),
|
||||
)
|
||||
|
||||
|
|
|
|||
|
|
@ -1,5 +1,6 @@
|
|||
"""Common utility functions"""
|
||||
|
||||
import inspect
|
||||
from types import NoneType
|
||||
from typing import Dict, Optional, Type, Union, get_args, get_origin, TypeVar
|
||||
|
||||
|
|
@ -85,3 +86,54 @@ def unwrap_optional_type(type_hint) -> Type:
|
|||
return arg
|
||||
|
||||
return type_hint
|
||||
|
||||
|
||||
def with_defer(func):
    """
    Decorator for a go-style defer.

    The decorated function is called with an extra ``defer`` keyword
    argument. Calling ``defer(fn, *args, **kwargs)`` inside the function
    registers ``fn`` to be invoked after the function returns or raises.
    Registered calls run in reverse registration order (last in, first out).

    The returned wrapper keeps the decorated function's metadata
    (``__name__``, ``__doc__``, ``__wrapped__``).
    """

    # Local import keeps this block independent of the module's import list
    from functools import wraps

    @wraps(func)
    def wrapper(*args, **kwargs):
        deferred_calls = []

        # This 'defer' function is what you'll call inside your decorated function
        def defer(fn, *fn_args, **fn_kwargs):
            deferred_calls.append((fn, fn_args, fn_kwargs))

        try:
            # Inject 'defer' into the kwargs of the original function
            return func(*args, defer=defer, **kwargs)
        finally:
            # After the original function finishes (or raises), run deferred calls
            for fn, fn_args, fn_kwargs in reversed(deferred_calls):
                fn(*fn_args, **fn_kwargs)

    return wrapper
|
||||
|
||||
|
||||
def with_defer_async(func):
    """
    Decorator for running async functions in go-style defer blocks.

    The decorated coroutine function is awaited with an extra ``defer``
    keyword argument. Calling ``defer(fn, *args, **kwargs)`` inside it
    registers ``fn`` to run after the function returns or raises. Registered
    items run in reverse registration order (last in, first out):

    - a coroutine function is called with its arguments and awaited,
    - an already-created coroutine object is awaited directly
      (any arguments registered with it are ignored),
    - anything else is called synchronously.

    The returned wrapper keeps the decorated function's metadata
    (``__name__``, ``__doc__``, ``__wrapped__``).
    """

    # Local import keeps this block independent of the module's import list
    from functools import wraps

    @wraps(func)
    async def wrapper(*args, **kwargs):
        deferred_calls = []

        # This 'defer' function is what you'll call inside your decorated function
        def defer(fn, *fn_args, **fn_kwargs):
            deferred_calls.append((fn, fn_args, fn_kwargs))

        try:
            # Inject 'defer' into the kwargs of the original function
            return await func(*args, defer=defer, **kwargs)
        finally:
            # After the original function finishes (or raises), run deferred calls
            for fn, fn_args, fn_kwargs in reversed(deferred_calls):
                if inspect.iscoroutinefunction(fn):
                    await fn(*fn_args, **fn_kwargs)
                elif inspect.iscoroutine(fn):
                    # A coroutine object was deferred; it only needs awaiting
                    await fn
                else:
                    fn(*fn_args, **fn_kwargs)

    return wrapper
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue