API: Add timings to usage stats
It's useful for the client to know the T/s and total generation time for each request. Works with both completions and chat completions.

Signed-off-by: kingbri <8082010+kingbri1@users.noreply.github.com>
This commit is contained in:
parent 5d94d4d022
commit 2913ce29fc
6 changed files with 110 additions and 63 deletions
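Note: for reference, a finish chunk carrying the new timing fields might look roughly like the sketch below. The values are invented; only the key names are taken from handle_finish_chunk in the diff that follows.

# Illustrative only -- made-up numbers, key names from the diff below.
example_finish_chunk = {
    "prompt_tokens": 512,
    "prompt_time": 0.41,              # prefill seconds
    "prompt_tokens_per_sec": 1248.78,
    "gen_tokens": 128,
    "gen_time": 3.2,                  # generation seconds
    "gen_tokens_per_sec": 40.0,
    "total_time": 3.65,               # queue + prefill + generate
    "queue_time": 0.04,
    "cached_tokens": 0,
    "finish_reason": "stop",
    "stop_str": None,
}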
@@ -913,7 +913,7 @@ class ExllamaV2Container(BaseModelContainer):
         joined_generation = {
             "text": "",
             "prompt_tokens": 0,
-            "generation_tokens": 0,
+            "gen_tokens": 0,
             "tool_calls": None,
             "offset": [],
             "token_probs": {},
@@ -923,11 +923,8 @@ class ExllamaV2Container(BaseModelContainer):
         if generations:
             # Get finish_reason first and then shift where -1 points to
             if "finish_reason" in generations[-1]:
-                finish_reason_gen = generations.pop()
-                joined_generation["finish_reason"] = finish_reason_gen.get(
-                    "finish_reason"
-                )
-                joined_generation["stop_str"] = finish_reason_gen.get("stop_str")
+                finish_chunk = generations.pop()
+                joined_generation = {**joined_generation, **finish_chunk}
             else:
                 joined_generation["finish_reason"] = "stop"
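Note: a minimal standalone sketch of what the dict-unpacking merge above does; later keys win, so the finish chunk's fields overwrite the accumulator's defaults (toy values, not project code).

# Toy example of {**a, **b}: keys from the finish chunk override the defaults.
joined_generation = {"text": "Hello", "prompt_tokens": 0, "gen_tokens": 0}
finish_chunk = {"finish_reason": "stop", "stop_str": None, "gen_tokens": 128, "total_time": 3.65}
joined_generation = {**joined_generation, **finish_chunk}
print(joined_generation["gen_tokens"], joined_generation["finish_reason"])  # 128 stop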
@@ -1189,9 +1186,35 @@ class ExllamaV2Container(BaseModelContainer):
         elif eos_reason == "stop_string":
             stop_str = result.get("eos_triggering_string")

+        # Prompt
+        prompt_tokens = result.get("prompt_tokens")
+        cached_tokens = round(result.get("cached_tokens"), 2)
+        prompt_time = round(result.get("time_prefill"), 2)
+        prompt_ts = (
+            "Indeterminate"
+            if prompt_time == 0
+            else round((prompt_tokens - cached_tokens) / prompt_time, 2)
+        )
+
+        # Generated
+        gen_tokens = result.get("new_tokens")
+        gen_time = result.get("time_generate")
+        gen_ts = "Indeterminate" if gen_time == 0 else round(gen_tokens / gen_time, 2)
+
+        # Queue + Total
+        queue_time = result.get("time_enqueued")
+        total_time = round(queue_time + prompt_time + gen_time, 2)
+
         finish_chunk = {
-            "prompt_tokens": generation.get("prompt_tokens"),
-            "generated_tokens": generation.get("generated_tokens"),
+            "prompt_tokens": prompt_tokens,
+            "prompt_time": round(prompt_time, 2),
+            "prompt_tokens_per_sec": prompt_ts,
+            "gen_tokens": gen_tokens,
+            "gen_time": round(gen_time, 2),
+            "gen_tokens_per_sec": gen_ts,
+            "total_time": total_time,
+            "queue_time": round(queue_time, 2),
+            "cached_tokens": cached_tokens,
             "finish_reason": finish_reason,
             "stop_str": stop_str,
         }
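Note: the tokens-per-second math above, pulled out into a standalone sketch with the zero-division guard (hypothetical helper, not part of the commit).

# Hypothetical helper mirroring the guard above: "Indeterminate" when the
# measured time is zero, otherwise tokens divided by seconds, rounded to 2 places.
def tokens_per_sec(tokens: int, seconds: float):
    return "Indeterminate" if seconds == 0 else round(tokens / seconds, 2)

prompt_ts = tokens_per_sec(512 - 0, 0.41)  # (prompt_tokens - cached_tokens) / time_prefill
gen_ts = tokens_per_sec(128, 3.2)          # new_tokens / time_generate
print(prompt_ts, gen_ts)                   # 1248.78 40.0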
@@ -1413,12 +1436,12 @@ class ExllamaV2Container(BaseModelContainer):
                 if result.get("eos"):
                     log_response(request_id, full_response)

-                    generation = self.handle_finish_chunk(result, generation)
+                    finish_chunk = self.handle_finish_chunk(result, generation)

                     # Save the final result for metrics logging
-                    metrics_result = result
+                    metrics_result = finish_chunk

-                    yield generation
+                    yield finish_chunk
                     break
         except asyncio.CancelledError:
             await job.cancel()
@@ -1451,12 +1474,7 @@ class ExllamaV2Container(BaseModelContainer):
             if metrics_result:
                 log_metrics(
                     request_id,
-                    metrics_result.get("time_enqueued"),
-                    metrics_result.get("prompt_tokens"),
-                    metrics_result.get("cached_tokens"),
-                    metrics_result.get("time_prefill"),
-                    metrics_result.get("new_tokens"),
-                    metrics_result.get("time_generate"),
+                    metrics_result,
                     context_len,
                     max_seq_len,
                 )
@@ -649,11 +649,8 @@ class ExllamaV3Container(BaseModelContainer):
         if generations:
             # Get finish_reason first and then shift where -1 points to
             if "finish_reason" in generations[-1]:
-                finish_reason_gen = generations.pop()
-                joined_generation["finish_reason"] = finish_reason_gen.get(
-                    "finish_reason"
-                )
-                joined_generation["stop_str"] = finish_reason_gen.get("stop_str")
+                finish_chunk = generations.pop()
+                joined_generation = {**joined_generation, **finish_chunk}
             else:
                 joined_generation["finish_reason"] = "stop"
@@ -743,9 +740,35 @@ class ExllamaV3Container(BaseModelContainer):
         elif eos_reason == "stop_string":
             stop_str = result.get("eos_triggering_string")

+        # Prompt
+        prompt_tokens = result.get("prompt_tokens")
+        cached_tokens = round(result.get("cached_tokens"), 2)
+        prompt_time = round(result.get("time_prefill"), 2)
+        prompt_ts = (
+            "Indeterminate"
+            if prompt_time == 0
+            else round((prompt_tokens - cached_tokens) / prompt_time, 2)
+        )
+
+        # Generated
+        gen_tokens = result.get("new_tokens")
+        gen_time = result.get("time_generate")
+        gen_ts = "Indeterminate" if gen_time == 0 else round(gen_tokens / gen_time, 2)
+
+        # Queue + Total
+        queue_time = result.get("time_enqueued")
+        total_time = round(queue_time + prompt_time + gen_time, 2)
+
         finish_chunk = {
-            "prompt_tokens": generation.get("prompt_tokens"),
-            "generated_tokens": generation.get("generated_tokens"),
+            "prompt_tokens": prompt_tokens,
+            "prompt_time": round(prompt_time, 2),
+            "prompt_tokens_per_sec": prompt_ts,
+            "gen_tokens": gen_tokens,
+            "gen_time": round(gen_time, 2),
+            "gen_tokens_per_sec": gen_ts,
+            "total_time": total_time,
+            "queue_time": round(queue_time, 2),
+            "cached_tokens": cached_tokens,
             "finish_reason": finish_reason,
             "stop_str": stop_str,
         }
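Note: the ExllamaV3 container now emits the same finish-chunk shape as the ExllamaV2 one; a hypothetical sanity check of the shared key set (names taken from the diff).

# Hypothetical check: the keys both backends are expected to put in a finish chunk.
EXPECTED_FINISH_KEYS = {
    "prompt_tokens", "prompt_time", "prompt_tokens_per_sec",
    "gen_tokens", "gen_time", "gen_tokens_per_sec",
    "total_time", "queue_time", "cached_tokens",
    "finish_reason", "stop_str",
}

def has_timing_fields(chunk: dict) -> bool:
    return EXPECTED_FINISH_KEYS.issubset(chunk.keys())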
@@ -921,12 +944,12 @@ class ExllamaV3Container(BaseModelContainer):
                 yield generation

                 if result.get("eos"):
-                    generation = self.handle_finish_chunk(result, generation)
+                    finish_chunk = self.handle_finish_chunk(result, generation)

                     # Save the final result for metrics logging
-                    metrics_result = result
+                    metrics_result = finish_chunk

-                    yield generation
+                    yield finish_chunk
                     break
         # Assign the active job to the request ID
         self.active_job_ids[request_id] = job
@@ -962,12 +985,7 @@ class ExllamaV3Container(BaseModelContainer):
             if metrics_result:
                 log_metrics(
                     request_id,
-                    metrics_result.get("time_enqueued"),
-                    metrics_result.get("prompt_tokens"),
-                    metrics_result.get("cached_tokens"),
-                    metrics_result.get("time_prefill"),
-                    metrics_result.get("new_tokens"),
-                    metrics_result.get("time_generate"),
+                    metrics_result,
                     context_len,
                     self.max_seq_len,
                 )
@@ -54,40 +54,29 @@ def log_response(request_id: str, response: str):

 def log_metrics(
     request_id: str,
-    queue_time: float,
-    prompt_tokens: int,
-    cached_tokens: int,
-    prompt_time: float,
-    generated_tokens: int,
-    generate_time: float,
+    metrics: dict,
     context_len: Optional[int],
     max_seq_len: int,
 ):
     initial_response = (
-        f"Metrics (ID: {request_id}): {generated_tokens} tokens generated in "
-        f"{round(queue_time + prompt_time + generate_time, 2)} seconds"
+        f"Metrics (ID: {request_id}): {metrics.get('gen_tokens')} "
+        f"tokens generated in {metrics.get('total_time')} seconds"
     )
     itemization = []
     extra_parts = []

-    itemization.append(f"Queue: {round(queue_time, 2)} s")
+    itemization.append(f"Queue: {metrics.get('queue_time')} s")

+    cached_tokens = metrics.get("cached_tokens")
+    prompt_tokens = metrics.get("prompt_tokens")
+
-    prompt_ts = (
-        "Indeterminate"
-        if prompt_time == 0
-        else round((prompt_tokens - cached_tokens) / prompt_time, 2)
-    )
     itemization.append(
         f"Process: {cached_tokens} cached tokens and "
-        f"{prompt_tokens - cached_tokens} new tokens at {prompt_ts} T/s"
+        f"{prompt_tokens - cached_tokens} new tokens at "
+        f"{metrics.get('prompt_tokens_per_sec')} T/s"
     )

-    generate_ts = (
-        "Indeterminate"
-        if generate_time == 0
-        else round(generated_tokens / generate_time, 2)
-    )
-    itemization.append(f"Generate: {generate_ts} T/s")
+    itemization.append(f"Generate: {metrics.get('gen_tokens_per_sec')} T/s")

     # Add context (original token count)
     if context_len:
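Note: a standalone sketch of how the log text above is assembled once everything is read from the metrics dict (request ID and numbers invented).

# Standalone sketch; mirrors the f-strings above with made-up values.
metrics = {
    "gen_tokens": 128, "total_time": 3.65, "queue_time": 0.04,
    "cached_tokens": 0, "prompt_tokens": 512,
    "prompt_tokens_per_sec": 1248.78, "gen_tokens_per_sec": 40.0,
}
initial_response = (
    f"Metrics (ID: example-request): {metrics.get('gen_tokens')} "
    f"tokens generated in {metrics.get('total_time')} seconds"
)
itemization = [
    f"Queue: {metrics.get('queue_time')} s",
    f"Process: {metrics.get('cached_tokens')} cached tokens and "
    f"{metrics.get('prompt_tokens') - metrics.get('cached_tokens')} new tokens at "
    f"{metrics.get('prompt_tokens_per_sec')} T/s",
    f"Generate: {metrics.get('gen_tokens_per_sec')} T/s",
]
print(initial_response, "(" + ", ".join(itemization) + ")")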
@@ -1,7 +1,7 @@
 """Common types for OAI."""

 from pydantic import BaseModel, Field
-from typing import Optional
+from typing import Optional, Union

 from common.sampling import BaseSamplerRequest, get_default_sampler_value
@@ -10,8 +10,13 @@ class UsageStats(BaseModel):
     """Represents usage stats."""

     prompt_tokens: int
+    prompt_time: Optional[float] = None
+    prompt_tokens_per_sec: Optional[Union[float, str]] = None
     completion_tokens: int
+    completion_time: Optional[float] = None
+    completion_tokens_per_sec: Optional[Union[float, str]] = None
     total_tokens: int
+    total_time: Optional[float] = None


 class CompletionResponseFormat(BaseModel):
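Note: the new fields are optional so existing clients keep working, and the per-second fields accept either a float or the literal string "Indeterminate". A construction sketch follows; the model is repeated only to keep the snippet self-contained, and pydantic v2 is assumed for model_dump.

from typing import Optional, Union
from pydantic import BaseModel

class UsageStats(BaseModel):
    """Represents usage stats."""

    prompt_tokens: int
    prompt_time: Optional[float] = None
    prompt_tokens_per_sec: Optional[Union[float, str]] = None
    completion_tokens: int
    completion_time: Optional[float] = None
    completion_tokens_per_sec: Optional[Union[float, str]] = None
    total_tokens: int
    total_time: Optional[float] = None

# Invented numbers; "Indeterminate" stands in for a rate when the timed phase was 0 s.
usage = UsageStats(
    prompt_tokens=512,
    prompt_time=0.41,
    prompt_tokens_per_sec=1248.78,
    completion_tokens=128,
    completion_time=3.2,
    completion_tokens_per_sec="Indeterminate",
    total_tokens=640,
    total_time=3.65,
)
print(usage.model_dump(exclude_none=True))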
@@ -38,9 +38,6 @@ def _create_response(
 ):
     """Create a chat completion response from the provided text."""

-    prompt_tokens = unwrap(generations[-1].get("prompt_tokens"), 0)
-    completion_tokens = unwrap(generations[-1].get("generated_tokens"), 0)
-
     choices = []
     for index, generation in enumerate(generations):
         message = ChatCompletionMessage(
@@ -91,14 +88,23 @@ def _create_response(

         choices.append(choice)

+    final_generation = generations[-1]
+    prompt_tokens = unwrap(final_generation.get("prompt_tokens"), 0)
+    completion_tokens = unwrap(final_generation.get("gen_tokens"), 0)
+
     response = ChatCompletionResponse(
-        id=f"chatcmpl-{request_id}",
+        id=f"cmpl-{request_id}",
         choices=choices,
-        model=unwrap(model_name, ""),
+        model=model_name,
         usage=UsageStats(
             prompt_tokens=prompt_tokens,
+            prompt_time=final_generation.get("prompt_time"),
+            prompt_tokens_per_sec=final_generation.get("prompt_tokens_per_sec"),
             completion_tokens=completion_tokens,
+            completion_time=final_generation.get("gen_time"),
+            completion_tokens_per_sec=final_generation.get("gen_tokens_per_sec"),
             total_tokens=prompt_tokens + completion_tokens,
+            total_time=final_generation.get("total_time"),
         ),
     )
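Note: the correspondence between the backend's finish-chunk keys and the UsageStats fields used above, written out as a reference dict (hypothetical helper; key and field names taken from the diff).

# Hypothetical reference mapping; mirrors the assignments in _create_response above.
# total_tokens is not mapped because it is computed as prompt_tokens + completion_tokens.
FINISH_CHUNK_TO_USAGE = {
    "prompt_tokens": "prompt_tokens",
    "prompt_time": "prompt_time",
    "prompt_tokens_per_sec": "prompt_tokens_per_sec",
    "gen_tokens": "completion_tokens",
    "gen_time": "completion_time",
    "gen_tokens_per_sec": "completion_tokens_per_sec",
    "total_time": "total_time",
}

def usage_kwargs(finish_chunk: dict) -> dict:
    # Build keyword arguments for UsageStats from a finish chunk.
    return {usage_key: finish_chunk.get(chunk_key)
            for chunk_key, usage_key in FINISH_CHUNK_TO_USAGE.items()}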
@@ -119,12 +125,17 @@ def _create_stream_chunk(

     if is_usage_chunk:
         prompt_tokens = unwrap(generation.get("prompt_tokens"), 0)
-        completion_tokens = unwrap(generation.get("generated_tokens"), 0)
+        completion_tokens = unwrap(generation.get("gen_tokens"), 0)

         usage_stats = UsageStats(
             prompt_tokens=prompt_tokens,
+            prompt_time=generation.get("prompt_time"),
+            prompt_tokens_per_sec=generation.get("prompt_tokens_per_sec"),
             completion_tokens=completion_tokens,
+            completion_time=generation.get("gen_time"),
+            completion_tokens_per_sec=generation.get("gen_tokens_per_sec"),
             total_tokens=prompt_tokens + completion_tokens,
+            total_time=generation.get("total_time"),
         )
     elif "finish_reason" in generation:
         # Get the finish reason from the generation
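Note: in the streaming path the same stats ride on the final usage chunk; a hedged client-side sketch of summarizing them (field names from the diff, helper name invented).

# Hypothetical consumer-side helper; the per-second field may be the string
# "Indeterminate", so it is printed as-is rather than formatted as a number.
def summarize_usage(usage: dict) -> str:
    return (
        f"{usage.get('completion_tokens')} tokens in {usage.get('total_time')} s "
        f"({usage.get('completion_tokens_per_sec')} T/s)"
    )

print(summarize_usage({"completion_tokens": 128, "total_time": 3.65,
                       "completion_tokens_per_sec": 40.0}))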
@@ -73,8 +73,9 @@ def _create_response(

         choices.append(choice)

-    prompt_tokens = unwrap(generations[-1].get("prompt_tokens"), 0)
-    completion_tokens = unwrap(generations[-1].get("generated_tokens"), 0)
+    final_generation = generations[-1]
+    prompt_tokens = unwrap(final_generation.get("prompt_tokens"), 0)
+    completion_tokens = unwrap(final_generation.get("gen_tokens"), 0)

     response = CompletionResponse(
         id=f"cmpl-{request_id}",
@@ -82,8 +83,13 @@ def _create_response(
         model=model_name,
         usage=UsageStats(
             prompt_tokens=prompt_tokens,
+            prompt_time=final_generation.get("prompt_time"),
+            prompt_tokens_per_sec=final_generation.get("prompt_tokens_per_sec"),
             completion_tokens=completion_tokens,
+            completion_time=final_generation.get("gen_time"),
+            completion_tokens_per_sec=final_generation.get("gen_tokens_per_sec"),
             total_tokens=prompt_tokens + completion_tokens,
+            total_time=final_generation.get("total_time"),
         ),
     )
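Note: a small client-side sanity check that the returned usage block is internally consistent (hypothetical; tolerates the "Indeterminate" string and rounding).

# Hypothetical check: token totals add up, and the reported T/s roughly matches
# completion_tokens / completion_time when both are numeric.
def check_usage(usage: dict) -> bool:
    ok = usage["total_tokens"] == usage["prompt_tokens"] + usage["completion_tokens"]
    rate = usage.get("completion_tokens_per_sec")
    secs = usage.get("completion_time")
    if isinstance(rate, (int, float)) and secs:
        ok = ok and abs(usage["completion_tokens"] / secs - rate) < 1.0
    return ok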