fixup: some metrics
parent c0f268f33e
commit b35c48da37
1 changed file with 84 additions and 34 deletions
@@ -18,6 +18,7 @@ from common.concurrency import iterate_in_threadpool
 from common.gen_logging import (
+    log_metrics,
 )
 from common.health import HealthManager
 from common.multimodal import MultimodalEmbeddingWrapper
 from common.sampling import BaseSamplerRequest
 from common.templating import PromptTemplate, find_prompt_template
@@ -436,6 +437,37 @@ class ExllamaV3Container(BaseModelContainer):
 
         return finish_chunk
 
+    async def create_generator(self):
+        """Create and save a Exllama generator class."""
+
+        try:
+            # Don't acquire locks unless a model is loaded
+            if self.loaded:
+                await self.load_lock.acquire()
+
+                # Immediately cancel all jobs
+                await self.wait_for_jobs(skip_wait=True)
+
+            # Create new generator
+            self.generator = AsyncGenerator(
+                model=self.model,
+                cache=self.cache,
+                tokenizer=self.tokenizer,
+                max_batch_size=self.max_batch_size,
+            )
+
+            # Update the state of the container var
+            if self.max_batch_size is None:
+                self.max_batch_size = self.generator.generator.max_batch_size
+        finally:
+            # This means the generator is being recreated
+            # The load lock is already released in the load function
+            if self.loaded:
+                self.load_lock.release()
+
+                async with self.load_condition:
+                    self.load_condition.notify_all()
+
     async def generate_gen(
         self,
         request_id: str,
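
For context on the new method: create_generator rebuilds the backend generator while holding the container's load lock, then notifies anything waiting on the load condition once the lock is released. The following is a minimal, hypothetical sketch of that lock-and-notify shape in plain asyncio; FakeGenerator and Container are illustrative stand-ins, not tabbyAPI or exllamav3 classes.

# Hypothetical, minimal sketch of the recreate-under-lock pattern in create_generator.
# FakeGenerator and Container are illustrative names, not tabbyAPI or exllamav3 classes.
import asyncio


class FakeGenerator:
    # Stand-in for the backend's AsyncGenerator
    def __init__(self, max_batch_size=None):
        self.max_batch_size = max_batch_size or 8


class Container:
    def __init__(self):
        self.loaded = True
        self.load_lock = asyncio.Lock()
        self.load_condition = asyncio.Condition()
        self.generator = None
        self.max_batch_size = None

    async def create_generator(self):
        try:
            # Only serialize recreation when a model is already loaded
            if self.loaded:
                await self.load_lock.acquire()

            # Build a fresh generator; in the real code, pending jobs are cancelled first
            self.generator = FakeGenerator(max_batch_size=self.max_batch_size)

            # Adopt the generator's default batch size if none was configured
            if self.max_batch_size is None:
                self.max_batch_size = self.generator.max_batch_size
        finally:
            if self.loaded:
                self.load_lock.release()

                # Wake any coroutine waiting on the load condition
                async with self.load_condition:
                    self.load_condition.notify_all()


asyncio.run(Container().create_generator())
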
@@ -516,42 +548,60 @@ class ExllamaV3Container(BaseModelContainer):
         full_response = ""
         metrics_result = {}
 
-        async for result in job:
-            chunk = unwrap(result.get("text"), "")
-            if chunk:
-                chunk_tokens = result.get("token_ids", self.tokenizer.encode(chunk))
-                full_response += chunk
-                if isinstance(chunk_tokens, torch.Tensor):
-                    generated_tokens += chunk_tokens.size(dim=0)
-                generation = {
-                    "text": chunk,
-                    "prompt_tokens": context_len,
-                    "generated_tokens": generated_tokens,
-                    "offset": len(full_response),
-                }
-                yield generation
+        # Get the generation status once it's ready
+        try:
+            async for result in job:
+                chunk = unwrap(result.get("text"), "")
+                if chunk:
+                    chunk_tokens = result.get("token_ids", self.tokenizer.encode(chunk))
+                    full_response += chunk
+                    if isinstance(chunk_tokens, torch.Tensor):
+                        generated_tokens += chunk_tokens.size(dim=0)
+                    generation = {
+                        "text": chunk,
+                        "prompt_tokens": context_len,
+                        "generated_tokens": generated_tokens,
+                        "offset": len(full_response),
+                    }
+                    yield generation
 
-            if result.get("eos"):
-                generation = self.handle_finish_chunk(result, generation)
+                if result.get("eos"):
+                    generation = self.handle_finish_chunk(result, generation)
 
-                # Save the final result for metrics logging
-                metrics_result = result
+                    # Save the final result for metrics logging
+                    metrics_result = result
 
-                yield generation
-                break
-        # Assign the active job to the request ID
-        self.active_job_ids[request_id] = job
+                    yield generation
+                    break
+            # Assign the active job to the request ID
+            self.active_job_ids[request_id] = job
 
-        # Log the metrics if present
-        if metrics_result:
-            log_metrics(
-                request_id,
-                metrics_result.get("time_enqueued"),
-                metrics_result.get("prompt_tokens"),
-                metrics_result.get("cached_tokens"),
-                metrics_result.get("time_prefill"),
-                metrics_result.get("new_tokens"),
-                metrics_result.get("time_generate"),
-                context_len,
-                self.max_seq_len,
-            )
+        except asyncio.CancelledError:
+            await job.cancel()
+        except Exception as ex:
+            # Create a new generator since the current state is broken
+            # No need to wait for this to finish
+            logger.error(
+                "FATAL ERROR with generation. "
+                "Attempting to recreate the generator. "
+                "If this fails, please restart the server.\n"
+            )
+            asyncio.ensure_future(self.create_generator())
+
+            await HealthManager.add_unhealthy_event(ex)
+
+            raise ex
+        finally:
+            # Log the metrics if present
+            if metrics_result:
+                log_metrics(
+                    request_id,
+                    metrics_result.get("time_enqueued"),
+                    metrics_result.get("prompt_tokens"),
+                    metrics_result.get("cached_tokens"),
+                    metrics_result.get("time_prefill"),
+                    metrics_result.get("new_tokens"),
+                    metrics_result.get("time_generate"),
+                    context_len,
+                    self.max_seq_len,
+                )
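
The rework above wraps job consumption in try/except/finally: a client cancellation cancels the job, a fatal error schedules create_generator to rebuild the broken generator and records an unhealthy event, and the metrics logging moves into finally so it runs on every exit path instead of only after a clean loop exit. A hypothetical reduction of that log-in-finally shape is sketched below; fake_stream and log_metrics_stub are stand-ins for the real job iterator and common.gen_logging.log_metrics.

# Hypothetical reduction of the "log metrics in finally" shape used above:
# the finally block runs whether the stream completes, raises, or is cancelled.
# fake_stream and log_metrics_stub are illustrative stand-ins, not tabbyAPI APIs.
import asyncio


def log_metrics_stub(request_id, new_tokens):
    # Stand-in for common.gen_logging.log_metrics
    print(f"request {request_id}: {new_tokens} new tokens")


async def fake_stream(n):
    # Stand-in for the async job iterator; yields incremental results
    for i in range(n):
        await asyncio.sleep(0)
        yield {"text": f"tok{i}", "new_tokens": i + 1}


async def generate(request_id: str):
    metrics_result = {}
    try:
        async for result in fake_stream(5):
            # Keep the latest status so finally can log it
            metrics_result = result
    except asyncio.CancelledError:
        # The real code awaits job.cancel() here; this sketch just re-raises
        raise
    finally:
        if metrics_result:
            log_metrics_stub(request_id, metrics_result.get("new_tokens"))


asyncio.run(generate("req-1"))
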