fixup: add sampler logs

Also passes the sampler to the job with this change; no idea if this is correct.
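
For context, a rough sketch of what the hunks below add, assuming it runs inside ExllamaV3Container's generation method (so self, prompt, max_tokens, and stop_conditions come from the surrounding code) and that BaseSamplerRequest is tabbyAPI's existing sampler request model; per the FIXME in the diff, building it from defaults is only a placeholder until the real request params are wired through:

# Map request-style sampler fields onto an exllamav3 ComboSampler
# (defaults only for now; see the FIXME in the diff below)
base_sampler = BaseSamplerRequest()
sampler = ComboSampler(
    rep_p=base_sampler.repetition_penalty,
    pres_p=base_sampler.presence_penalty,
    freq_p=base_sampler.frequency_penalty,
    rep_sustain_range=base_sampler.penalty_range,
    rep_decay_range=base_sampler.penalty_range,
    temperature=base_sampler.temperature,
    min_p=base_sampler.min_p,
    top_k=base_sampler.top_k,
    top_p=base_sampler.top_p,
    temp_last=base_sampler.temperature_last,
)

# Pass the sampler to the generation job alongside the existing arguments
job = AsyncJob(
    self.generator,  # the container's AsyncGenerator
    sampler=sampler,
    input_ids=self.tokenizer.encode(prompt, add_bos=False),
    max_new_tokens=max_tokens,
    stop_conditions=stop_conditions,
)
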
randoentity 2025-04-30 13:14:34 +02:00 committed by kingbri
parent b35c48da37
commit c744790f14

@@ -10,12 +10,21 @@ from typing import (
)
import torch
from exllamav3 import AsyncGenerator, AsyncJob, Cache, Config, Model, Tokenizer
from exllamav3 import (
    AsyncGenerator,
    AsyncJob,
    Cache,
    ComboSampler,
    Config,
    Model,
    Tokenizer,
)
from loguru import logger
from backends.base_model_container import BaseModelContainer
from common.concurrency import iterate_in_threadpool
from common.gen_logging import (
    log_generation_params,
    log_metrics,
)
from common.health import HealthManager
@@ -483,6 +492,30 @@ class ExllamaV3Container(BaseModelContainer):
"""
chunk_tokens: torch.Tensor | tuple[torch.Tensor, torch.Tensor]
# FIXME: this is probably not right
base_sampler = BaseSamplerRequest()
sampler = ComboSampler(
rep_p=base_sampler.repetition_penalty,
pres_p=base_sampler.presence_penalty,
freq_p=base_sampler.frequency_penalty,
rep_sustain_range=base_sampler.penalty_range,
rep_decay_range=base_sampler.penalty_range,
temperature=base_sampler.temperature,
min_p=base_sampler.min_p,
top_k=base_sampler.top_k,
top_p=base_sampler.top_p,
temp_last=base_sampler.temperature_last,
)

        # Dynamically scale penalty range to output tokens
        # Only do this if freq/pres pen is enabled
        # and the repetition range is -1
        # TODO:
        # auto_scale_penalty_range = (
        #     gen_settings.token_frequency_penalty != 0
        #     or gen_settings.token_presence_penalty != 0
        # ) and gen_settings.token_repetition_range == -1

        prompts = [prompt]
        stop_conditions = params.stop
        add_bos_token = params.add_bos_token
@@ -539,6 +572,7 @@ class ExllamaV3Container(BaseModelContainer):
            generation = {}
            job = AsyncJob(
                self.generator,
                sampler=sampler,
                input_ids=self.tokenizer.encode(prompt, add_bos=False),
                max_new_tokens=max_tokens,
                stop_conditions=stop_conditions,
@@ -557,6 +591,12 @@ class ExllamaV3Container(BaseModelContainer):
                full_response += chunk
                if isinstance(chunk_tokens, torch.Tensor):
                    generated_tokens += chunk_tokens.size(dim=0)

                # Increase penalty range to generated token amount
                # TODO:
                # if auto_scale_penalty_range:
                #     gen_settings.token_repetition_range = generated_tokens

                generation = {
                    "text": chunk,
                    "prompt_tokens": context_len,
@@ -592,6 +632,17 @@ class ExllamaV3Container(BaseModelContainer):
            raise ex
        finally:
            # Log generation options to console
            # Some options are too large, so log the args instead
            log_generation_params(
                request_id=request_id,
                bos_token_id=self.tokenizer.bos_token_id,
                eos_token_id=eos_tokens,
                prompt=prompt,
                **params.model_dump(exclude={"prompt"}),
                # auto_scale_penalty_range=auto_scale_penalty_range, # TODO
            )

            # Log the metrics if present
            if metrics_result:
                log_metrics(