fixup: some metrics

randoentity 2025-04-30 11:56:24 +02:00 committed by kingbri
parent c0f268f33e
commit b35c48da37

@@ -18,6 +18,7 @@ from common.concurrency import iterate_in_threadpool
 from common.gen_logging import (
     log_metrics,
 )
+from common.health import HealthManager
 from common.multimodal import MultimodalEmbeddingWrapper
 from common.sampling import BaseSamplerRequest
 from common.templating import PromptTemplate, find_prompt_template
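
The new HealthManager import feeds the error path added to generate_gen below, which records fatal generation failures via HealthManager.add_unhealthy_event(ex). The diff doesn't show common.health itself, so purely as a hedged illustration, a minimal tracker with that call shape might look like the following (every name besides add_unhealthy_event is invented here):

# Hypothetical sketch only: the real common.health module is not shown in
# this diff, so this is just one plausible shape for the API being called.
import asyncio
from datetime import datetime, timezone


class _HealthManagerSketch:
    def __init__(self):
        self._lock = asyncio.Lock()
        self._unhealthy_events: list[dict] = []

    async def add_unhealthy_event(self, ex: Exception):
        # Record the failure so a health endpoint can report degraded state
        async with self._lock:
            self._unhealthy_events.append(
                {"time": datetime.now(timezone.utc), "error": repr(ex)}
            )

    async def is_healthy(self) -> bool:
        async with self._lock:
            return not self._unhealthy_events


HealthManager = _HealthManagerSketch()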
@@ -436,6 +437,37 @@ class ExllamaV3Container(BaseModelContainer):
 
         return finish_chunk
 
+    async def create_generator(self):
+        """Create and save an ExLlama generator instance."""
+
+        try:
+            # Don't acquire locks unless a model is loaded
+            if self.loaded:
+                await self.load_lock.acquire()
+
+                # Immediately cancel all jobs
+                await self.wait_for_jobs(skip_wait=True)
+
+            # Create new generator
+            self.generator = AsyncGenerator(
+                model=self.model,
+                cache=self.cache,
+                tokenizer=self.tokenizer,
+                max_batch_size=self.max_batch_size,
+            )
+
+            # Update the state of the container var
+            if self.max_batch_size is None:
+                self.max_batch_size = self.generator.generator.max_batch_size
+        finally:
+            # A loaded model means the generator is being recreated here;
+            # otherwise the load lock is released in the load function
+            if self.loaded:
+                self.load_lock.release()
+
+                async with self.load_condition:
+                    self.load_condition.notify_all()
+
     async def generate_gen(
         self,
         request_id: str,
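
The acquire/release asymmetry above is deliberate: when a model is already loaded, create_generator is recreating a broken generator and must take load_lock itself; when it runs as part of the load path, the load function already holds and releases the lock. A minimal sketch of that discipline, assuming load_lock is an asyncio.Lock and load_condition an asyncio.Condition (types implied but not shown in the diff; ContainerSketch and the placeholder generator are invented for illustration):

import asyncio


class ContainerSketch:
    """Toy stand-in for the container, showing only the lock discipline."""

    def __init__(self):
        self.loaded = False
        self.load_lock = asyncio.Lock()
        self.load_condition = asyncio.Condition()
        self.generator = None

    async def create_generator(self):
        try:
            # Recreation path: a loaded model means no caller holds the
            # lock, so take it here to fence out concurrent (re)loads
            if self.loaded:
                await self.load_lock.acquire()

            self.generator = object()  # placeholder for AsyncGenerator(...)
        finally:
            # Only release (and wake waiters) when we acquired the lock
            # above; on the load path the load function releases it itself
            if self.loaded:
                self.load_lock.release()

                async with self.load_condition:
                    self.load_condition.notify_all()


asyncio.run(ContainerSketch().create_generator())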
@@ -516,42 +548,60 @@ class ExllamaV3Container(BaseModelContainer):
         full_response = ""
         metrics_result = {}
 
-        async for result in job:
-            chunk = unwrap(result.get("text"), "")
-
-            if chunk:
-                chunk_tokens = result.get("token_ids", self.tokenizer.encode(chunk))
-                full_response += chunk
-
-                if isinstance(chunk_tokens, torch.Tensor):
-                    generated_tokens += chunk_tokens.size(dim=0)
-
-                generation = {
-                    "text": chunk,
-                    "prompt_tokens": context_len,
-                    "generated_tokens": generated_tokens,
-                    "offset": len(full_response),
-                }
-
-                yield generation
+        # Get the generation status once it's ready
+        try:
+            async for result in job:
+                chunk = unwrap(result.get("text"), "")
+
+                if chunk:
+                    chunk_tokens = result.get("token_ids", self.tokenizer.encode(chunk))
+                    full_response += chunk
+
+                    if isinstance(chunk_tokens, torch.Tensor):
+                        generated_tokens += chunk_tokens.size(dim=0)
+
+                    generation = {
+                        "text": chunk,
+                        "prompt_tokens": context_len,
+                        "generated_tokens": generated_tokens,
+                        "offset": len(full_response),
+                    }
+
+                    yield generation
 
-            if result.get("eos"):
-                generation = self.handle_finish_chunk(result, generation)
+                if result.get("eos"):
+                    generation = self.handle_finish_chunk(result, generation)
 
-                # Save the final result for metrics logging
-                metrics_result = result
+                    # Save the final result for metrics logging
+                    metrics_result = result
 
-                yield generation
-                break
-
-        # Assign the active job to the request ID
-        self.active_job_ids[request_id] = job
+                    yield generation
+                    break
+
+            # Assign the active job to the request ID
+            self.active_job_ids[request_id] = job
 
-        # Log the metrics if present
-        if metrics_result:
-            log_metrics(
-                request_id,
-                metrics_result.get("time_enqueued"),
-                metrics_result.get("prompt_tokens"),
-                metrics_result.get("cached_tokens"),
-                metrics_result.get("time_prefill"),
-                metrics_result.get("new_tokens"),
-                metrics_result.get("time_generate"),
-                context_len,
-                self.max_seq_len,
-            )
+        except asyncio.CancelledError:
+            await job.cancel()
+        except Exception as ex:
+            # Create a new generator since the current state is broken
+            # No need to wait for this to finish
+            logger.error(
+                "FATAL ERROR with generation. "
+                "Attempting to recreate the generator. "
+                "If this fails, please restart the server.\n"
+            )
+            asyncio.ensure_future(self.create_generator())
+
+            await HealthManager.add_unhealthy_event(ex)
+
+            raise ex
+        finally:
+            # Log the metrics if present
+            if metrics_result:
+                log_metrics(
+                    request_id,
+                    metrics_result.get("time_enqueued"),
+                    metrics_result.get("prompt_tokens"),
+                    metrics_result.get("cached_tokens"),
+                    metrics_result.get("time_prefill"),
+                    metrics_result.get("new_tokens"),
+                    metrics_result.get("time_generate"),
+                    context_len,
+                    self.max_seq_len,
+                )
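
The point of the restructure is that log_metrics now runs in a finally block, so metrics are emitted whether the stream finishes normally, the client disconnects (asyncio.CancelledError), or generation fails outright. A toy demonstration of that control flow, with the hypothetical names generate and main standing in for the job loop and its consumer:

import asyncio


async def generate(metrics: dict):
    """Stand-in for the job loop inside generate_gen."""
    try:
        for i in range(100):
            await asyncio.sleep(0.01)  # pretend to produce a token
            metrics["new_tokens"] = i + 1
    except asyncio.CancelledError:
        # Mirrors `await job.cancel()`: tidy up, then let the
        # cancellation propagate to the caller
        raise
    finally:
        # Mirrors the relocated log_metrics call: this runs whether the
        # request completed, failed, or was cancelled mid-stream
        print("metrics logged:", metrics)


async def main():
    metrics = {}
    task = asyncio.create_task(generate(metrics))
    await asyncio.sleep(0.05)  # simulate a client that disconnects early
    task.cancel()
    try:
        await task
    except asyncio.CancelledError:
        pass


asyncio.run(main())

Running this prints the partial token count even though the task was cancelled mid-stream, which is exactly the behavior the fixup is after: in the old layout, the log_metrics call sat after the loop and was skipped on any cancellation or error.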