API: Add timings to usage stats
It's useful for the client to know the tokens-per-second (T/s) rate and the total generation time for each request. Works with both completions and chat completions.

Signed-off-by: kingbri <8082010+kingbri1@users.noreply.github.com>
parent 5d94d4d022
commit 2913ce29fc

6 changed files with 110 additions and 63 deletions
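As a quick illustration of what this buys a client, the sketch below reads the new timing fields out of the usage block of a non-streaming request. The base URL, port, endpoint path, and request body are placeholder assumptions for whatever OAI-compatible server is actually running; only the usage field names come from this commit.

import requests

# Placeholder server URL and request body; adjust for the actual deployment.
BASE_URL = "http://localhost:5000/v1"

resp = requests.post(
    f"{BASE_URL}/completions",
    json={"prompt": "Say hello.", "max_tokens": 32},
    timeout=120,
)
resp.raise_for_status()
usage = resp.json()["usage"]

# Timing fields added by this commit; they are optional and may be null.
print("prompt tokens:", usage["prompt_tokens"])
print("prompt T/s:", usage.get("prompt_tokens_per_sec"))
print("completion tokens:", usage["completion_tokens"])
print("completion T/s:", usage.get("completion_tokens_per_sec"))
print("total time:", usage.get("total_time"))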
@@ -1,7 +1,7 @@
 """Common types for OAI."""
 
 from pydantic import BaseModel, Field
-from typing import Optional
+from typing import Optional, Union
 
 from common.sampling import BaseSamplerRequest, get_default_sampler_value
 
@@ -10,8 +10,13 @@ class UsageStats(BaseModel):
     """Represents usage stats."""
 
     prompt_tokens: int
+    prompt_time: Optional[float] = None
+    prompt_tokens_per_sec: Optional[Union[float, str]] = None
     completion_tokens: int
+    completion_time: Optional[float] = None
+    completion_tokens_per_sec: Optional[Union[float, str]] = None
     total_tokens: int
+    total_time: Optional[float] = None
 
 
 class CompletionResponseFormat(BaseModel):
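Since the new attributes are plain Optional fields on the Pydantic model, older clients that ignore them are unaffected, and a backend that does not measure timings can simply leave them unset. Below is a minimal, self-contained sketch of the model with illustrative numbers showing how the per-second rates relate to token counts and elapsed times; the arithmetic and example values are assumptions for illustration, not the backend's real measurement keys.

from typing import Optional, Union

from pydantic import BaseModel


class UsageStats(BaseModel):
    """Represents usage stats (fields as added in this commit)."""

    prompt_tokens: int
    prompt_time: Optional[float] = None
    prompt_tokens_per_sec: Optional[Union[float, str]] = None
    completion_tokens: int
    completion_time: Optional[float] = None
    completion_tokens_per_sec: Optional[Union[float, str]] = None
    total_tokens: int
    total_time: Optional[float] = None


# Illustrative values: a rate is just tokens divided by elapsed time,
# assuming the times are reported in seconds.
prompt_tokens, prompt_time = 512, 0.8
completion_tokens, completion_time = 128, 3.2

usage = UsageStats(
    prompt_tokens=prompt_tokens,
    prompt_time=prompt_time,
    prompt_tokens_per_sec=prompt_tokens / prompt_time,
    completion_tokens=completion_tokens,
    completion_time=completion_time,
    completion_tokens_per_sec=completion_tokens / completion_time,
    total_tokens=prompt_tokens + completion_tokens,
    total_time=prompt_time + completion_time,
)
print(usage)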
@@ -38,9 +38,6 @@ def _create_response(
 ):
     """Create a chat completion response from the provided text."""
 
-    prompt_tokens = unwrap(generations[-1].get("prompt_tokens"), 0)
-    completion_tokens = unwrap(generations[-1].get("generated_tokens"), 0)
-
     choices = []
     for index, generation in enumerate(generations):
         message = ChatCompletionMessage(
@@ -91,14 +88,23 @@ def _create_response(
 
         choices.append(choice)
 
+    final_generation = generations[-1]
+    prompt_tokens = unwrap(final_generation.get("prompt_tokens"), 0)
+    completion_tokens = unwrap(final_generation.get("gen_tokens"), 0)
+
     response = ChatCompletionResponse(
-        id=f"chatcmpl-{request_id}",
+        id=f"cmpl-{request_id}",
         choices=choices,
-        model=unwrap(model_name, ""),
+        model=model_name,
         usage=UsageStats(
             prompt_tokens=prompt_tokens,
+            prompt_time=final_generation.get("prompt_time"),
+            prompt_tokens_per_sec=final_generation.get("prompt_tokens_per_sec"),
             completion_tokens=completion_tokens,
+            completion_time=final_generation.get("gen_time"),
+            completion_tokens_per_sec=final_generation.get("gen_tokens_per_sec"),
             total_tokens=prompt_tokens + completion_tokens,
+            total_time=final_generation.get("total_time"),
         ),
     )
 
@@ -119,12 +125,17 @@ def _create_stream_chunk(
 
     if is_usage_chunk:
         prompt_tokens = unwrap(generation.get("prompt_tokens"), 0)
-        completion_tokens = unwrap(generation.get("generated_tokens"), 0)
+        completion_tokens = unwrap(generation.get("gen_tokens"), 0)
 
         usage_stats = UsageStats(
             prompt_tokens=prompt_tokens,
+            prompt_time=generation.get("prompt_time"),
+            prompt_tokens_per_sec=generation.get("prompt_tokens_per_sec"),
             completion_tokens=completion_tokens,
+            completion_time=generation.get("gen_time"),
+            completion_tokens_per_sec=generation.get("gen_tokens_per_sec"),
             total_tokens=prompt_tokens + completion_tokens,
+            total_time=generation.get("total_time"),
         )
     elif "finish_reason" in generation:
         # Get the finish reason from the generation
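For streaming requests, the same stats ride on a dedicated usage chunk (the is_usage_chunk branch above) rather than on every delta. A rough client-side sketch follows; it assumes the server emits OAI-style "data:"-prefixed SSE lines, that streaming is enabled with the standard stream flag, and that the usage block arrives on one of the final chunks. How the server decides to emit that chunk is not shown here and may differ per deployment.

import json

import requests

# Placeholder URL and request body; streaming must be enabled to get chunks.
resp = requests.post(
    "http://localhost:5000/v1/chat/completions",
    json={
        "messages": [{"role": "user", "content": "Hello!"}],
        "stream": True,
    },
    stream=True,
    timeout=120,
)

for line in resp.iter_lines():
    if not line or not line.startswith(b"data: "):
        continue
    data = line[len(b"data: "):]
    if data == b"[DONE]":
        break
    chunk = json.loads(data)
    usage = chunk.get("usage")
    if usage:
        # Timing stats attached to the usage chunk by this commit.
        print("completion T/s:", usage.get("completion_tokens_per_sec"))
        print("total time:", usage.get("total_time"))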
@@ -73,8 +73,9 @@ def _create_response(
 
         choices.append(choice)
 
-    prompt_tokens = unwrap(generations[-1].get("prompt_tokens"), 0)
-    completion_tokens = unwrap(generations[-1].get("generated_tokens"), 0)
+    final_generation = generations[-1]
+    prompt_tokens = unwrap(final_generation.get("prompt_tokens"), 0)
+    completion_tokens = unwrap(final_generation.get("gen_tokens"), 0)
 
     response = CompletionResponse(
         id=f"cmpl-{request_id}",
@@ -82,8 +83,13 @@ def _create_response(
         model=model_name,
         usage=UsageStats(
             prompt_tokens=prompt_tokens,
+            prompt_time=final_generation.get("prompt_time"),
+            prompt_tokens_per_sec=final_generation.get("prompt_tokens_per_sec"),
             completion_tokens=completion_tokens,
+            completion_time=final_generation.get("gen_time"),
+            completion_tokens_per_sec=final_generation.get("gen_tokens_per_sec"),
             total_tokens=prompt_tokens + completion_tokens,
+            total_time=final_generation.get("total_time"),
         ),
     )
 