API: Add timings to usage stats

It's useful for the client to know the tokens-per-second (T/s) and
total generation time for each request.

Works with both completions and chat completions.
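
As a rough illustration of what a client can now read back, here is a minimal
sketch (the base URL, request payload, and use of the requests library are
assumptions for the example; only the usage field names come from this change):

    import requests

    # Hypothetical OAI-style completion request against a running server.
    resp = requests.post(
        "http://127.0.0.1:5000/v1/completions",
        json={"prompt": "Hello", "max_tokens": 32},
    )
    usage = resp.json()["usage"]

    # Existing token counts plus the new timing fields.
    # The *_tokens_per_sec fields may be the string "Indeterminate" when the
    # measured time is zero.
    print(usage["prompt_tokens"], usage["prompt_time"], usage["prompt_tokens_per_sec"])
    print(usage["completion_tokens"], usage["completion_time"], usage["completion_tokens_per_sec"])
    print(usage["total_tokens"], usage["total_time"])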

Signed-off-by: kingbri <8082010+kingbri1@users.noreply.github.com>
kingbri 2025-06-17 22:54:51 -04:00
parent 5d94d4d022
commit 2913ce29fc
6 changed files with 110 additions and 63 deletions


@@ -913,7 +913,7 @@ class ExllamaV2Container(BaseModelContainer):
        joined_generation = {
            "text": "",
            "prompt_tokens": 0,
-            "generation_tokens": 0,
+            "gen_tokens": 0,
            "tool_calls": None,
            "offset": [],
            "token_probs": {},
@@ -923,11 +923,8 @@ class ExllamaV2Container(BaseModelContainer):
        if generations:
            # Get finish_reason first and then shift where -1 points to
            if "finish_reason" in generations[-1]:
-                finish_reason_gen = generations.pop()
-                joined_generation["finish_reason"] = finish_reason_gen.get(
-                    "finish_reason"
-                )
-                joined_generation["stop_str"] = finish_reason_gen.get("stop_str")
+                finish_chunk = generations.pop()
+                joined_generation = {**joined_generation, **finish_chunk}
            else:
                joined_generation["finish_reason"] = "stop"
@@ -1189,9 +1186,35 @@ class ExllamaV2Container(BaseModelContainer):
        elif eos_reason == "stop_string":
            stop_str = result.get("eos_triggering_string")

+        # Prompt
+        prompt_tokens = result.get("prompt_tokens")
+        cached_tokens = round(result.get("cached_tokens"), 2)
+        prompt_time = round(result.get("time_prefill"), 2)
+        prompt_ts = (
+            "Indeterminate"
+            if prompt_time == 0
+            else round((prompt_tokens - cached_tokens) / prompt_time, 2)
+        )
+
+        # Generated
+        gen_tokens = result.get("new_tokens")
+        gen_time = result.get("time_generate")
+        gen_ts = "Indeterminate" if gen_time == 0 else round(gen_tokens / gen_time, 2)
+
+        # Queue + Total
+        queue_time = result.get("time_enqueued")
+        total_time = round(queue_time + prompt_time + gen_time, 2)
+
        finish_chunk = {
-            "prompt_tokens": generation.get("prompt_tokens"),
-            "generated_tokens": generation.get("generated_tokens"),
+            "prompt_tokens": prompt_tokens,
+            "prompt_time": round(prompt_time, 2),
+            "prompt_tokens_per_sec": prompt_ts,
+            "gen_tokens": gen_tokens,
+            "gen_time": round(gen_time, 2),
+            "gen_tokens_per_sec": gen_ts,
+            "total_time": total_time,
+            "queue_time": round(queue_time, 2),
+            "cached_tokens": cached_tokens,
            "finish_reason": finish_reason,
            "stop_str": stop_str,
        }
@@ -1413,12 +1436,12 @@ class ExllamaV2Container(BaseModelContainer):
                if result.get("eos"):
                    log_response(request_id, full_response)

-                    generation = self.handle_finish_chunk(result, generation)
+                    finish_chunk = self.handle_finish_chunk(result, generation)

                    # Save the final result for metrics logging
-                    metrics_result = result
+                    metrics_result = finish_chunk

-                    yield generation
+                    yield finish_chunk
                    break
        except asyncio.CancelledError:
            await job.cancel()
@@ -1451,12 +1474,7 @@ class ExllamaV2Container(BaseModelContainer):
            if metrics_result:
                log_metrics(
                    request_id,
-                    metrics_result.get("time_enqueued"),
-                    metrics_result.get("prompt_tokens"),
-                    metrics_result.get("cached_tokens"),
-                    metrics_result.get("time_prefill"),
-                    metrics_result.get("new_tokens"),
-                    metrics_result.get("time_generate"),
+                    metrics_result,
                    context_len,
                    max_seq_len,
                )


@@ -649,11 +649,8 @@ class ExllamaV3Container(BaseModelContainer):
        if generations:
            # Get finish_reason first and then shift where -1 points to
            if "finish_reason" in generations[-1]:
-                finish_reason_gen = generations.pop()
-                joined_generation["finish_reason"] = finish_reason_gen.get(
-                    "finish_reason"
-                )
-                joined_generation["stop_str"] = finish_reason_gen.get("stop_str")
+                finish_chunk = generations.pop()
+                joined_generation = {**joined_generation, **finish_chunk}
            else:
                joined_generation["finish_reason"] = "stop"
@@ -743,9 +740,35 @@ class ExllamaV3Container(BaseModelContainer):
        elif eos_reason == "stop_string":
            stop_str = result.get("eos_triggering_string")

+        # Prompt
+        prompt_tokens = result.get("prompt_tokens")
+        cached_tokens = round(result.get("cached_tokens"), 2)
+        prompt_time = round(result.get("time_prefill"), 2)
+        prompt_ts = (
+            "Indeterminate"
+            if prompt_time == 0
+            else round((prompt_tokens - cached_tokens) / prompt_time, 2)
+        )
+
+        # Generated
+        gen_tokens = result.get("new_tokens")
+        gen_time = result.get("time_generate")
+        gen_ts = "Indeterminate" if gen_time == 0 else round(gen_tokens / gen_time, 2)
+
+        # Queue + Total
+        queue_time = result.get("time_enqueued")
+        total_time = round(queue_time + prompt_time + gen_time, 2)
+
        finish_chunk = {
-            "prompt_tokens": generation.get("prompt_tokens"),
-            "generated_tokens": generation.get("generated_tokens"),
+            "prompt_tokens": prompt_tokens,
+            "prompt_time": round(prompt_time, 2),
+            "prompt_tokens_per_sec": prompt_ts,
+            "gen_tokens": gen_tokens,
+            "gen_time": round(gen_time, 2),
+            "gen_tokens_per_sec": gen_ts,
+            "total_time": total_time,
+            "queue_time": round(queue_time, 2),
+            "cached_tokens": cached_tokens,
            "finish_reason": finish_reason,
            "stop_str": stop_str,
        }
@@ -921,12 +944,12 @@ class ExllamaV3Container(BaseModelContainer):
                yield generation

                if result.get("eos"):
-                    generation = self.handle_finish_chunk(result, generation)
+                    finish_chunk = self.handle_finish_chunk(result, generation)

                    # Save the final result for metrics logging
-                    metrics_result = result
+                    metrics_result = finish_chunk

-                    yield generation
+                    yield finish_chunk
                    break

        # Assign the active job to the request ID
        self.active_job_ids[request_id] = job
@@ -962,12 +985,7 @@ class ExllamaV3Container(BaseModelContainer):
            if metrics_result:
                log_metrics(
                    request_id,
-                    metrics_result.get("time_enqueued"),
-                    metrics_result.get("prompt_tokens"),
-                    metrics_result.get("cached_tokens"),
-                    metrics_result.get("time_prefill"),
-                    metrics_result.get("new_tokens"),
-                    metrics_result.get("time_generate"),
+                    metrics_result,
                    context_len,
                    self.max_seq_len,
                )


@@ -54,40 +54,29 @@ def log_response(request_id: str, response: str):
def log_metrics(
    request_id: str,
-    queue_time: float,
-    prompt_tokens: int,
-    cached_tokens: int,
-    prompt_time: float,
-    generated_tokens: int,
-    generate_time: float,
+    metrics: dict,
    context_len: Optional[int],
    max_seq_len: int,
):
    initial_response = (
-        f"Metrics (ID: {request_id}): {generated_tokens} tokens generated in "
-        f"{round(queue_time + prompt_time + generate_time, 2)} seconds"
+        f"Metrics (ID: {request_id}): {metrics.get('gen_tokens')} "
+        f"tokens generated in {metrics.get('total_time')} seconds"
    )
    itemization = []
    extra_parts = []

-    itemization.append(f"Queue: {round(queue_time, 2)} s")
+    itemization.append(f"Queue: {metrics.get('queue_time')} s")

-    prompt_ts = (
-        "Indeterminate"
-        if prompt_time == 0
-        else round((prompt_tokens - cached_tokens) / prompt_time, 2)
-    )
+    cached_tokens = metrics.get("cached_tokens")
+    prompt_tokens = metrics.get("prompt_tokens")

    itemization.append(
        f"Process: {cached_tokens} cached tokens and "
-        f"{prompt_tokens - cached_tokens} new tokens at {prompt_ts} T/s"
+        f"{prompt_tokens - cached_tokens} new tokens at "
+        f"{metrics.get('prompt_tokens_per_sec')} T/s"
    )

-    generate_ts = (
-        "Indeterminate"
-        if generate_time == 0
-        else round(generated_tokens / generate_time, 2)
-    )
-    itemization.append(f"Generate: {generate_ts} T/s")
+    itemization.append(f"Generate: {metrics.get('gen_tokens_per_sec')} T/s")

    # Add context (original token count)
    if context_len:


@@ -1,7 +1,7 @@
"""Common types for OAI."""

from pydantic import BaseModel, Field
-from typing import Optional
+from typing import Optional, Union

from common.sampling import BaseSamplerRequest, get_default_sampler_value
@@ -10,8 +10,13 @@ class UsageStats(BaseModel):
    """Represents usage stats."""

    prompt_tokens: int
+    prompt_time: Optional[float] = None
+    prompt_tokens_per_sec: Optional[Union[float, str]] = None
    completion_tokens: int
+    completion_time: Optional[float] = None
+    completion_tokens_per_sec: Optional[Union[float, str]] = None
    total_tokens: int
+    total_time: Optional[float] = None


class CompletionResponseFormat(BaseModel):


@@ -38,9 +38,6 @@ def _create_response(
):
    """Create a chat completion response from the provided text."""

-    prompt_tokens = unwrap(generations[-1].get("prompt_tokens"), 0)
-    completion_tokens = unwrap(generations[-1].get("generated_tokens"), 0)
-
    choices = []
    for index, generation in enumerate(generations):
        message = ChatCompletionMessage(
@@ -91,14 +88,23 @@ def _create_response(
        choices.append(choice)

+    final_generation = generations[-1]
+    prompt_tokens = unwrap(final_generation.get("prompt_tokens"), 0)
+    completion_tokens = unwrap(final_generation.get("gen_tokens"), 0)
+
    response = ChatCompletionResponse(
-        id=f"chatcmpl-{request_id}",
+        id=f"cmpl-{request_id}",
        choices=choices,
-        model=unwrap(model_name, ""),
+        model=model_name,
        usage=UsageStats(
            prompt_tokens=prompt_tokens,
+            prompt_time=final_generation.get("prompt_time"),
+            prompt_tokens_per_sec=final_generation.get("prompt_tokens_per_sec"),
            completion_tokens=completion_tokens,
+            completion_time=final_generation.get("gen_time"),
+            completion_tokens_per_sec=final_generation.get("gen_tokens_per_sec"),
            total_tokens=prompt_tokens + completion_tokens,
+            total_time=final_generation.get("total_time"),
        ),
    )
@@ -119,12 +125,17 @@ def _create_stream_chunk(
    if is_usage_chunk:
        prompt_tokens = unwrap(generation.get("prompt_tokens"), 0)
-        completion_tokens = unwrap(generation.get("generated_tokens"), 0)
+        completion_tokens = unwrap(generation.get("gen_tokens"), 0)

        usage_stats = UsageStats(
            prompt_tokens=prompt_tokens,
+            prompt_time=generation.get("prompt_time"),
+            prompt_tokens_per_sec=generation.get("prompt_tokens_per_sec"),
            completion_tokens=completion_tokens,
+            completion_time=generation.get("gen_time"),
+            completion_tokens_per_sec=generation.get("gen_tokens_per_sec"),
            total_tokens=prompt_tokens + completion_tokens,
+            total_time=generation.get("total_time"),
        )
    elif "finish_reason" in generation:
        # Get the finish reason from the generation


@@ -73,8 +73,9 @@ def _create_response(
        choices.append(choice)

-    prompt_tokens = unwrap(generations[-1].get("prompt_tokens"), 0)
-    completion_tokens = unwrap(generations[-1].get("generated_tokens"), 0)
+    final_generation = generations[-1]
+    prompt_tokens = unwrap(final_generation.get("prompt_tokens"), 0)
+    completion_tokens = unwrap(final_generation.get("gen_tokens"), 0)

    response = CompletionResponse(
        id=f"cmpl-{request_id}",
@@ -82,8 +83,13 @@ def _create_response(
        model=model_name,
        usage=UsageStats(
            prompt_tokens=prompt_tokens,
+            prompt_time=final_generation.get("prompt_time"),
+            prompt_tokens_per_sec=final_generation.get("prompt_tokens_per_sec"),
            completion_tokens=completion_tokens,
+            completion_time=final_generation.get("gen_time"),
+            completion_tokens_per_sec=final_generation.get("gen_tokens_per_sec"),
            total_tokens=prompt_tokens + completion_tokens,
+            total_time=final_generation.get("total_time"),
        ),
    )