fixup: add sampler logs
This also passes the sampler to the job; no idea if this is correct.
parent b35c48da37
commit c744790f14
1 changed file with 52 additions and 1 deletion
@@ -10,12 +10,21 @@ from typing import (
 )

 import torch
-from exllamav3 import AsyncGenerator, AsyncJob, Cache, Config, Model, Tokenizer
+from exllamav3 import (
+    AsyncGenerator,
+    AsyncJob,
+    Cache,
+    ComboSampler,
+    Config,
+    Model,
+    Tokenizer,
+)
 from loguru import logger

 from backends.base_model_container import BaseModelContainer
 from common.concurrency import iterate_in_threadpool
 from common.gen_logging import (
     log_generation_params,
     log_metrics,
 )
 from common.health import HealthManager
@@ -483,6 +492,30 @@ class ExllamaV3Container(BaseModelContainer):
         """
         chunk_tokens: torch.Tensor | tuple[torch.Tensor, torch.Tensor]

+        # FIXME: this is probably not right
+        base_sampler = BaseSamplerRequest()
+        sampler = ComboSampler(
+            rep_p=base_sampler.repetition_penalty,
+            pres_p=base_sampler.presence_penalty,
+            freq_p=base_sampler.frequency_penalty,
+            rep_sustain_range=base_sampler.penalty_range,
+            rep_decay_range=base_sampler.penalty_range,
+            temperature=base_sampler.temperature,
+            min_p=base_sampler.min_p,
+            top_k=base_sampler.top_k,
+            top_p=base_sampler.top_p,
+            temp_last=base_sampler.temperature_last,
+        )
+
+        # Dynamically scale penalty range to output tokens
+        # Only do this if freq/pres pen is enabled
+        # and the repetition range is -1
+        # TODO:
+        # auto_scale_penalty_range = (
+        #     gen_settings.token_frequency_penalty != 0
+        #     or gen_settings.token_presence_penalty != 0
+        # ) and gen_settings.token_repetition_range == -1
+
         prompts = [prompt]
         stop_conditions = params.stop
         add_bos_token = params.add_bos_token
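Note on the FIXME above: one likely resolution is to build the ComboSampler from the request's own sampling settings rather than a default BaseSamplerRequest(). A minimal sketch, assuming `params` exposes the same sampling fields that BaseSamplerRequest does (the helper name is hypothetical):

# Sketch only: derive the sampler from the incoming request instead of defaults.
# Assumes `params` carries the same sampling fields as BaseSamplerRequest.
def sampler_from_request(params) -> ComboSampler:
    return ComboSampler(
        rep_p=params.repetition_penalty,
        pres_p=params.presence_penalty,
        freq_p=params.frequency_penalty,
        rep_sustain_range=params.penalty_range,
        rep_decay_range=params.penalty_range,
        temperature=params.temperature,
        min_p=params.min_p,
        top_k=params.top_k,
        top_p=params.top_p,
        temp_last=params.temperature_last,
    )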
@@ -539,6 +572,7 @@ class ExllamaV3Container(BaseModelContainer):
             generation = {}
             job = AsyncJob(
                 self.generator,
+                sampler=sampler,
                 input_ids=self.tokenizer.encode(prompt, add_bos=False),
                 max_new_tokens=max_tokens,
                 stop_conditions=stop_conditions,
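If the per-request helper sketched above were adopted, the job construction would presumably look like this (hypothetical, reusing the assumed sampler_from_request helper rather than the module-level sampler built from defaults):

job = AsyncJob(
    self.generator,
    sampler=sampler_from_request(params),  # hypothetical helper from the sketch above
    input_ids=self.tokenizer.encode(prompt, add_bos=False),
    max_new_tokens=max_tokens,
    stop_conditions=stop_conditions,
)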
@@ -557,6 +591,12 @@ class ExllamaV3Container(BaseModelContainer):
                 full_response += chunk
                 if isinstance(chunk_tokens, torch.Tensor):
                     generated_tokens += chunk_tokens.size(dim=0)
+
+                # Increase penalty range to generated token amount
+                # TODO:
+                # if auto_scale_penalty_range:
+                #     gen_settings.token_repetition_range = generated_tokens
+
                 generation = {
                     "text": chunk,
                     "prompt_tokens": context_len,
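For reference, the condition left as a TODO in this hunk and the earlier one reduces to the check below. This is a sketch only: the gen_settings names in the comments come from the older exllamav2-style backend and are not defined anywhere in this diff.

def should_auto_scale_penalty_range(
    frequency_penalty: float, presence_penalty: float, penalty_range: int
) -> bool:
    # Auto-scale only when a frequency/presence penalty is active and the
    # penalty range is unbounded (-1); the range would then be widened to the
    # number of tokens generated so far.
    return (frequency_penalty != 0 or presence_penalty != 0) and penalty_range == -1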
@@ -592,6 +632,17 @@ class ExllamaV3Container(BaseModelContainer):

             raise ex
         finally:
+            # Log generation options to console
+            # Some options are too large, so log the args instead
+            log_generation_params(
+                request_id=request_id,
+                bos_token_id=self.tokenizer.bos_token_id,
+                eos_token_id=eos_tokens,
+                prompt=prompt,
+                **params.model_dump(exclude={"prompt"}),
+                # auto_scale_penalty_range=auto_scale_penalty_range, # TODO
+            )
+
             # Log the metrics if present
             if metrics_result:
                 log_metrics(