API: Add timings to usage stats

It's useful for the client to know what the T/s and total time for
generation are per-request.

Works with both completions and chat completions.

Signed-off-by: kingbri <8082010+kingbri1@users.noreply.github.com>
This commit is contained in:
kingbri 2025-06-17 22:54:51 -04:00
parent 5d94d4d022
commit 2913ce29fc
6 changed files with 110 additions and 63 deletions

View file

@ -913,7 +913,7 @@ class ExllamaV2Container(BaseModelContainer):
joined_generation = {
"text": "",
"prompt_tokens": 0,
"generation_tokens": 0,
"gen_tokens": 0,
"tool_calls": None,
"offset": [],
"token_probs": {},
@ -923,11 +923,8 @@ class ExllamaV2Container(BaseModelContainer):
if generations:
# Get finish_reason first and then shift where -1 points to
if "finish_reason" in generations[-1]:
finish_reason_gen = generations.pop()
joined_generation["finish_reason"] = finish_reason_gen.get(
"finish_reason"
)
joined_generation["stop_str"] = finish_reason_gen.get("stop_str")
finish_chunk = generations.pop()
joined_generation = {**joined_generation, **finish_chunk}
else:
joined_generation["finish_reason"] = "stop"
@ -1189,9 +1186,35 @@ class ExllamaV2Container(BaseModelContainer):
elif eos_reason == "stop_string":
stop_str = result.get("eos_triggering_string")
# Prompt
prompt_tokens = result.get("prompt_tokens")
cached_tokens = round(result.get("cached_tokens"), 2)
prompt_time = round(result.get("time_prefill"), 2)
prompt_ts = (
"Indeterminate"
if prompt_time == 0
else round((prompt_tokens - cached_tokens) / prompt_time, 2)
)
# Generated
gen_tokens = result.get("new_tokens")
gen_time = result.get("time_generate")
gen_ts = "Indeterminate" if gen_time == 0 else round(gen_tokens / gen_time, 2)
# Queue + Total
queue_time = result.get("time_enqueued")
total_time = round(queue_time + prompt_time + gen_time, 2)
finish_chunk = {
"prompt_tokens": generation.get("prompt_tokens"),
"generated_tokens": generation.get("generated_tokens"),
"prompt_tokens": prompt_tokens,
"prompt_time": round(prompt_time, 2),
"prompt_tokens_per_sec": prompt_ts,
"gen_tokens": gen_tokens,
"gen_time": round(gen_time, 2),
"gen_tokens_per_sec": gen_ts,
"total_time": total_time,
"queue_time": round(queue_time, 2),
"cached_tokens": cached_tokens,
"finish_reason": finish_reason,
"stop_str": stop_str,
}
@ -1413,12 +1436,12 @@ class ExllamaV2Container(BaseModelContainer):
if result.get("eos"):
log_response(request_id, full_response)
generation = self.handle_finish_chunk(result, generation)
finish_chunk = self.handle_finish_chunk(result, generation)
# Save the final result for metrics logging
metrics_result = result
metrics_result = finish_chunk
yield generation
yield finish_chunk
break
except asyncio.CancelledError:
await job.cancel()
@ -1451,12 +1474,7 @@ class ExllamaV2Container(BaseModelContainer):
if metrics_result:
log_metrics(
request_id,
metrics_result.get("time_enqueued"),
metrics_result.get("prompt_tokens"),
metrics_result.get("cached_tokens"),
metrics_result.get("time_prefill"),
metrics_result.get("new_tokens"),
metrics_result.get("time_generate"),
metrics_result,
context_len,
max_seq_len,
)

View file

@ -649,11 +649,8 @@ class ExllamaV3Container(BaseModelContainer):
if generations:
# Get finish_reason first and then shift where -1 points to
if "finish_reason" in generations[-1]:
finish_reason_gen = generations.pop()
joined_generation["finish_reason"] = finish_reason_gen.get(
"finish_reason"
)
joined_generation["stop_str"] = finish_reason_gen.get("stop_str")
finish_chunk = generations.pop()
joined_generation = {**joined_generation, **finish_chunk}
else:
joined_generation["finish_reason"] = "stop"
@ -743,9 +740,35 @@ class ExllamaV3Container(BaseModelContainer):
elif eos_reason == "stop_string":
stop_str = result.get("eos_triggering_string")
# Prompt
prompt_tokens = result.get("prompt_tokens")
cached_tokens = round(result.get("cached_tokens"), 2)
prompt_time = round(result.get("time_prefill"), 2)
prompt_ts = (
"Indeterminate"
if prompt_time == 0
else round((prompt_tokens - cached_tokens) / prompt_time, 2)
)
# Generated
gen_tokens = result.get("new_tokens")
gen_time = result.get("time_generate")
gen_ts = "Indeterminate" if gen_time == 0 else round(gen_tokens / gen_time, 2)
# Queue + Total
queue_time = result.get("time_enqueued")
total_time = round(queue_time + prompt_time + gen_time, 2)
finish_chunk = {
"prompt_tokens": generation.get("prompt_tokens"),
"generated_tokens": generation.get("generated_tokens"),
"prompt_tokens": prompt_tokens,
"prompt_time": round(prompt_time, 2),
"prompt_tokens_per_sec": prompt_ts,
"gen_tokens": gen_tokens,
"gen_time": round(gen_time, 2),
"gen_tokens_per_sec": gen_ts,
"total_time": total_time,
"queue_time": round(queue_time, 2),
"cached_tokens": cached_tokens,
"finish_reason": finish_reason,
"stop_str": stop_str,
}
@ -921,12 +944,12 @@ class ExllamaV3Container(BaseModelContainer):
yield generation
if result.get("eos"):
generation = self.handle_finish_chunk(result, generation)
finish_chunk = self.handle_finish_chunk(result, generation)
# Save the final result for metrics logging
metrics_result = result
metrics_result = finish_chunk
yield generation
yield finish_chunk
break
# Assign the active job to the request ID
self.active_job_ids[request_id] = job
@ -962,12 +985,7 @@ class ExllamaV3Container(BaseModelContainer):
if metrics_result:
log_metrics(
request_id,
metrics_result.get("time_enqueued"),
metrics_result.get("prompt_tokens"),
metrics_result.get("cached_tokens"),
metrics_result.get("time_prefill"),
metrics_result.get("new_tokens"),
metrics_result.get("time_generate"),
metrics_result,
context_len,
self.max_seq_len,
)

View file

@ -54,40 +54,29 @@ def log_response(request_id: str, response: str):
def log_metrics(
request_id: str,
queue_time: float,
prompt_tokens: int,
cached_tokens: int,
prompt_time: float,
generated_tokens: int,
generate_time: float,
metrics: dict,
context_len: Optional[int],
max_seq_len: int,
):
initial_response = (
f"Metrics (ID: {request_id}): {generated_tokens} tokens generated in "
f"{round(queue_time + prompt_time + generate_time, 2)} seconds"
f"Metrics (ID: {request_id}): {metrics.get('gen_tokens')} "
f"tokens generated in {metrics.get('total_time')} seconds"
)
itemization = []
extra_parts = []
itemization.append(f"Queue: {round(queue_time, 2)} s")
itemization.append(f"Queue: {metrics.get('queue_time')} s")
cached_tokens = metrics.get("cached_tokens")
prompt_tokens = metrics.get("prompt_tokens")
prompt_ts = (
"Indeterminate"
if prompt_time == 0
else round((prompt_tokens - cached_tokens) / prompt_time, 2)
)
itemization.append(
f"Process: {cached_tokens} cached tokens and "
f"{prompt_tokens - cached_tokens} new tokens at {prompt_ts} T/s"
f"{prompt_tokens - cached_tokens} new tokens at "
f"{metrics.get('prompt_tokens_per_sec')} T/s"
)
generate_ts = (
"Indeterminate"
if generate_time == 0
else round(generated_tokens / generate_time, 2)
)
itemization.append(f"Generate: {generate_ts} T/s")
itemization.append(f"Generate: {metrics.get('gen_tokens_per_sec')} T/s")
# Add context (original token count)
if context_len:

View file

@ -1,7 +1,7 @@
"""Common types for OAI."""
from pydantic import BaseModel, Field
from typing import Optional
from typing import Optional, Union
from common.sampling import BaseSamplerRequest, get_default_sampler_value
@ -10,8 +10,13 @@ class UsageStats(BaseModel):
"""Represents usage stats."""
prompt_tokens: int
prompt_time: Optional[float] = None
prompt_tokens_per_sec: Optional[Union[float, str]] = None
completion_tokens: int
completion_time: Optional[float] = None
completion_tokens_per_sec: Optional[Union[float, str]] = None
total_tokens: int
total_time: Optional[float] = None
class CompletionResponseFormat(BaseModel):

View file

@ -38,9 +38,6 @@ def _create_response(
):
"""Create a chat completion response from the provided text."""
prompt_tokens = unwrap(generations[-1].get("prompt_tokens"), 0)
completion_tokens = unwrap(generations[-1].get("generated_tokens"), 0)
choices = []
for index, generation in enumerate(generations):
message = ChatCompletionMessage(
@ -91,14 +88,23 @@ def _create_response(
choices.append(choice)
final_generation = generations[-1]
prompt_tokens = unwrap(final_generation.get("prompt_tokens"), 0)
completion_tokens = unwrap(final_generation.get("gen_tokens"), 0)
response = ChatCompletionResponse(
id=f"chatcmpl-{request_id}",
id=f"cmpl-{request_id}",
choices=choices,
model=unwrap(model_name, ""),
model=model_name,
usage=UsageStats(
prompt_tokens=prompt_tokens,
prompt_time=final_generation.get("prompt_time"),
prompt_tokens_per_sec=final_generation.get("prompt_tokens_per_sec"),
completion_tokens=completion_tokens,
completion_time=final_generation.get("gen_time"),
completion_tokens_per_sec=final_generation.get("gen_tokens_per_sec"),
total_tokens=prompt_tokens + completion_tokens,
total_time=final_generation.get("total_time"),
),
)
@ -119,12 +125,17 @@ def _create_stream_chunk(
if is_usage_chunk:
prompt_tokens = unwrap(generation.get("prompt_tokens"), 0)
completion_tokens = unwrap(generation.get("generated_tokens"), 0)
completion_tokens = unwrap(generation.get("gen_tokens"), 0)
usage_stats = UsageStats(
prompt_tokens=prompt_tokens,
prompt_time=generation.get("prompt_time"),
prompt_tokens_per_sec=generation.get("prompt_tokens_per_sec"),
completion_tokens=completion_tokens,
completion_time=generation.get("gen_time"),
completion_tokens_per_sec=generation.get("gen_tokens_per_sec"),
total_tokens=prompt_tokens + completion_tokens,
total_time=generation.get("total_time"),
)
elif "finish_reason" in generation:
# Get the finish reason from the generation

View file

@ -73,8 +73,9 @@ def _create_response(
choices.append(choice)
prompt_tokens = unwrap(generations[-1].get("prompt_tokens"), 0)
completion_tokens = unwrap(generations[-1].get("generated_tokens"), 0)
final_generation = generations[-1]
prompt_tokens = unwrap(final_generation.get("prompt_tokens"), 0)
completion_tokens = unwrap(final_generation.get("gen_tokens"), 0)
response = CompletionResponse(
id=f"cmpl-{request_id}",
@ -82,8 +83,13 @@ def _create_response(
model=model_name,
usage=UsageStats(
prompt_tokens=prompt_tokens,
prompt_time=final_generation.get("prompt_time"),
prompt_tokens_per_sec=final_generation.get("prompt_tokens_per_sec"),
completion_tokens=completion_tokens,
completion_time=final_generation.get("gen_time"),
completion_tokens_per_sec=final_generation.get("gen_tokens_per_sec"),
total_tokens=prompt_tokens + completion_tokens,
total_time=final_generation.get("total_time"),
),
)