Merge pull request #358 from theroyallab/breaking
Breaking changes for configuration
Commit e362319a4d
12 changed files with 152 additions and 104 deletions
@@ -235,11 +235,10 @@ class ExllamaV2Container(BaseModelContainer):

         # Grab the base model's sequence length before overrides for
         # rope calculations
-        base_seq_len = self.config.max_seq_len
+        base_seq_len = hf_model.hf_config.max_position_embeddings

         # Set the target seq len if present
-        # Fallback to base_seq_len if not provided
-        target_seq_len = unwrap(kwargs.get("max_seq_len"), base_seq_len)
+        target_seq_len = unwrap(kwargs.get("max_seq_len"), 4096)

         # Set the rope scale
         self.config.scale_pos_emb = unwrap(
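A note on the helper used above: unwrap behaves like a null-coalescing fallback, so this change swaps the default context length from the model's native value to a fixed 4096. A minimal sketch of that behavior, assuming unwrap(value, default) returns the default only when the value is None:

    def unwrap(value, default):
        # Null-coalescing helper: keep the caller's value, otherwise use the default
        return default if value is None else value

    # Old fallback: the model's own sequence length
    # target_seq_len = unwrap(kwargs.get("max_seq_len"), base_seq_len)
    # New fallback: a fixed 4096-token context
    # target_seq_len = unwrap(kwargs.get("max_seq_len"), 4096)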
@@ -247,6 +246,7 @@ class ExllamaV2Container(BaseModelContainer):
         )

         # Sets rope alpha value.
+        # Utilize the model's max_position_embeddings as a base value
         # Automatically calculate if unset or defined as an "auto" literal.
         rope_alpha = unwrap(kwargs.get("rope_alpha"), "auto")
         if rope_alpha == "auto":
@@ -371,7 +371,7 @@ class ExllamaV2Container(BaseModelContainer):
         )

         # Set draft rope alpha. Follows same behavior as model rope alpha.
-        # Use the base sequence length of the model
+        # Use the max_position_embeddings of the model
         draft_rope_alpha = unwrap(draft_args.get("draft_rope_alpha"), "auto")
         if draft_rope_alpha == "auto":
             self.draft_config.scale_alpha_value = calculate_rope_alpha(
@@ -911,7 +911,7 @@ class ExllamaV2Container(BaseModelContainer):
         joined_generation = {
             "text": "",
             "prompt_tokens": 0,
-            "generation_tokens": 0,
+            "gen_tokens": 0,
             "tool_calls": None,
             "offset": [],
             "token_probs": {},
@@ -921,11 +921,8 @@ class ExllamaV2Container(BaseModelContainer):
         if generations:
             # Get finish_reason first and then shift where -1 points to
             if "finish_reason" in generations[-1]:
-                finish_reason_gen = generations.pop()
-                joined_generation["finish_reason"] = finish_reason_gen.get(
-                    "finish_reason"
-                )
-                joined_generation["stop_str"] = finish_reason_gen.get("stop_str")
+                finish_chunk = generations.pop()
+                joined_generation = {**joined_generation, **finish_chunk}
             else:
                 joined_generation["finish_reason"] = "stop"
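The per-field copy is replaced by a dict merge, so every key on the popped finish chunk (finish_reason, stop_str, and the metric fields added later in this PR) lands on joined_generation in one step. A minimal sketch with hypothetical values:

    finish_chunk = {"finish_reason": "stop", "stop_str": "</s>", "gen_tokens": 250}
    joined_generation = {**{"text": "Hello", "prompt_tokens": 600}, **finish_chunk}
    # joined_generation now carries finish_reason, stop_str, and gen_tokens as well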
@@ -1187,9 +1184,35 @@ class ExllamaV2Container(BaseModelContainer):
         elif eos_reason == "stop_string":
             stop_str = result.get("eos_triggering_string")

+        # Prompt
+        prompt_tokens = result.get("prompt_tokens")
+        cached_tokens = round(result.get("cached_tokens"), 2)
+        prompt_time = round(result.get("time_prefill"), 2)
+        prompt_ts = (
+            "Indeterminate"
+            if prompt_time == 0
+            else round((prompt_tokens - cached_tokens) / prompt_time, 2)
+        )
+
+        # Generated
+        gen_tokens = result.get("new_tokens")
+        gen_time = result.get("time_generate")
+        gen_ts = "Indeterminate" if gen_time == 0 else round(gen_tokens / gen_time, 2)
+
+        # Queue + Total
+        queue_time = result.get("time_enqueued")
+        total_time = round(queue_time + prompt_time + gen_time, 2)
+
         finish_chunk = {
-            "prompt_tokens": generation.get("prompt_tokens"),
-            "generated_tokens": generation.get("generated_tokens"),
+            "prompt_tokens": prompt_tokens,
+            "prompt_time": round(prompt_time, 2),
+            "prompt_tokens_per_sec": prompt_ts,
+            "gen_tokens": gen_tokens,
+            "gen_time": round(gen_time, 2),
+            "gen_tokens_per_sec": gen_ts,
+            "total_time": total_time,
+            "queue_time": round(queue_time, 2),
+            "cached_tokens": cached_tokens,
             "finish_reason": finish_reason,
             "stop_str": stop_str,
         }
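As a worked example of the throughput math introduced here (hypothetical numbers, not taken from this PR):

    # 600 prompt tokens, 100 of them cached, prefilled in 0.5 s
    prompt_tokens, cached_tokens, prompt_time = 600, 100, 0.5
    # 250 new tokens generated in 5 s, after 0.1 s in the queue
    gen_tokens, gen_time, queue_time = 250, 5.0, 0.1

    prompt_ts = "Indeterminate" if prompt_time == 0 else round((prompt_tokens - cached_tokens) / prompt_time, 2)
    gen_ts = "Indeterminate" if gen_time == 0 else round(gen_tokens / gen_time, 2)
    total_time = round(queue_time + prompt_time + gen_time, 2)
    # prompt_ts == 1000.0, gen_ts == 50.0, total_time == 5.6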
@@ -1411,12 +1434,12 @@ class ExllamaV2Container(BaseModelContainer):
                     if result.get("eos"):
                         log_response(request_id, full_response)

-                        generation = self.handle_finish_chunk(result, generation)
+                        finish_chunk = self.handle_finish_chunk(result, generation)

                         # Save the final result for metrics logging
-                        metrics_result = result
+                        metrics_result = finish_chunk

-                        yield generation
+                        yield finish_chunk
                         break
             except asyncio.CancelledError:
                 await job.cancel()
@@ -1449,12 +1472,7 @@ class ExllamaV2Container(BaseModelContainer):
             if metrics_result:
                 log_metrics(
                     request_id,
-                    metrics_result.get("time_enqueued"),
-                    metrics_result.get("prompt_tokens"),
-                    metrics_result.get("cached_tokens"),
-                    metrics_result.get("time_prefill"),
-                    metrics_result.get("new_tokens"),
-                    metrics_result.get("time_generate"),
+                    metrics_result,
                     context_len,
                     max_seq_len,
                 )
@@ -649,11 +649,8 @@ class ExllamaV3Container(BaseModelContainer):
         if generations:
             # Get finish_reason first and then shift where -1 points to
             if "finish_reason" in generations[-1]:
-                finish_reason_gen = generations.pop()
-                joined_generation["finish_reason"] = finish_reason_gen.get(
-                    "finish_reason"
-                )
-                joined_generation["stop_str"] = finish_reason_gen.get("stop_str")
+                finish_chunk = generations.pop()
+                joined_generation = {**joined_generation, **finish_chunk}
             else:
                 joined_generation["finish_reason"] = "stop"
@@ -743,9 +740,35 @@ class ExllamaV3Container(BaseModelContainer):
         elif eos_reason == "stop_string":
             stop_str = result.get("eos_triggering_string")

+        # Prompt
+        prompt_tokens = result.get("prompt_tokens")
+        cached_tokens = round(result.get("cached_tokens"), 2)
+        prompt_time = round(result.get("time_prefill"), 2)
+        prompt_ts = (
+            "Indeterminate"
+            if prompt_time == 0
+            else round((prompt_tokens - cached_tokens) / prompt_time, 2)
+        )
+
+        # Generated
+        gen_tokens = result.get("new_tokens")
+        gen_time = result.get("time_generate")
+        gen_ts = "Indeterminate" if gen_time == 0 else round(gen_tokens / gen_time, 2)
+
+        # Queue + Total
+        queue_time = result.get("time_enqueued")
+        total_time = round(queue_time + prompt_time + gen_time, 2)
+
         finish_chunk = {
-            "prompt_tokens": generation.get("prompt_tokens"),
-            "generated_tokens": generation.get("generated_tokens"),
+            "prompt_tokens": prompt_tokens,
+            "prompt_time": round(prompt_time, 2),
+            "prompt_tokens_per_sec": prompt_ts,
+            "gen_tokens": gen_tokens,
+            "gen_time": round(gen_time, 2),
+            "gen_tokens_per_sec": gen_ts,
+            "total_time": total_time,
+            "queue_time": round(queue_time, 2),
+            "cached_tokens": cached_tokens,
             "finish_reason": finish_reason,
             "stop_str": stop_str,
         }
@@ -921,12 +944,12 @@ class ExllamaV3Container(BaseModelContainer):
                     yield generation

                     if result.get("eos"):
-                        generation = self.handle_finish_chunk(result, generation)
+                        finish_chunk = self.handle_finish_chunk(result, generation)

                         # Save the final result for metrics logging
-                        metrics_result = result
+                        metrics_result = finish_chunk

-                        yield generation
+                        yield finish_chunk
                         break
             # Assign the active job to the request ID
             self.active_job_ids[request_id] = job
@@ -962,12 +985,7 @@ class ExllamaV3Container(BaseModelContainer):
             if metrics_result:
                 log_metrics(
                     request_id,
-                    metrics_result.get("time_enqueued"),
-                    metrics_result.get("prompt_tokens"),
-                    metrics_result.get("cached_tokens"),
-                    metrics_result.get("time_prefill"),
-                    metrics_result.get("new_tokens"),
-                    metrics_result.get("time_generate"),
+                    metrics_result,
                     context_len,
                     self.max_seq_len,
                 )
@@ -175,10 +175,10 @@ class ModelConfig(BaseConfigModel):
     max_seq_len: Optional[int] = Field(
         None,
         description=(
-            "Max sequence length (default: Empty).\n"
-            "Fetched from the model's base sequence length in config.json by default."
+            "Max sequence length (default: 4096).\n"
+            "Set to -1 to fetch from the model's config.json"
         ),
-        ge=0,
+        ge=-1,
     )
     tensor_parallel: Optional[bool] = Field(
         False,
@@ -54,40 +54,29 @@ def log_response(request_id: str, response: str):

 def log_metrics(
     request_id: str,
-    queue_time: float,
-    prompt_tokens: int,
-    cached_tokens: int,
-    prompt_time: float,
-    generated_tokens: int,
-    generate_time: float,
+    metrics: dict,
     context_len: Optional[int],
     max_seq_len: int,
 ):
     initial_response = (
-        f"Metrics (ID: {request_id}): {generated_tokens} tokens generated in "
-        f"{round(queue_time + prompt_time + generate_time, 2)} seconds"
+        f"Metrics (ID: {request_id}): {metrics.get('gen_tokens')} "
+        f"tokens generated in {metrics.get('total_time')} seconds"
     )
     itemization = []
     extra_parts = []

-    itemization.append(f"Queue: {round(queue_time, 2)} s")
+    itemization.append(f"Queue: {metrics.get('queue_time')} s")

-    prompt_ts = (
-        "Indeterminate"
-        if prompt_time == 0
-        else round((prompt_tokens - cached_tokens) / prompt_time, 2)
-    )
+    cached_tokens = metrics.get("cached_tokens")
+    prompt_tokens = metrics.get("prompt_tokens")
+
     itemization.append(
         f"Process: {cached_tokens} cached tokens and "
-        f"{prompt_tokens - cached_tokens} new tokens at {prompt_ts} T/s"
+        f"{prompt_tokens - cached_tokens} new tokens at "
+        f"{metrics.get('prompt_tokens_per_sec')} T/s"
     )

-    generate_ts = (
-        "Indeterminate"
-        if generate_time == 0
-        else round(generated_tokens / generate_time, 2)
-    )
-    itemization.append(f"Generate: {generate_ts} T/s")
+    itemization.append(f"Generate: {metrics.get('gen_tokens_per_sec')} T/s")

     # Add context (original token count)
     if context_len:
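With the new signature, callers hand the finish chunk dict to log_metrics instead of unpacking individual fields. A sketch of a call site, reusing the hypothetical numbers from the earlier example:

    metrics_result = {
        "prompt_tokens": 600, "cached_tokens": 100, "prompt_tokens_per_sec": 1000.0,
        "gen_tokens": 250, "gen_tokens_per_sec": 50.0,
        "queue_time": 0.1, "total_time": 5.6,
    }
    log_metrics("request-123", metrics_result, context_len=600, max_seq_len=4096)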
@@ -76,6 +76,9 @@ async def apply_inline_overrides(model_dir: pathlib.Path, **kwargs):
     if not override_config_path.exists():
         return kwargs

+    # Initialize overrides dict
+    overrides = {}
+
     async with aiofiles.open(
         override_config_path, "r", encoding="utf8"
     ) as override_config_file:
@@ -83,18 +86,25 @@ async def apply_inline_overrides(model_dir: pathlib.Path, **kwargs):

         # Create a temporary YAML parser
         yaml = YAML(typ="safe")
-        override_args = unwrap(yaml.load(contents), {})
+        inline_config = unwrap(yaml.load(contents), {})
+
+        # Check for inline model overrides
+        model_inline_config = unwrap(inline_config.get("model"), {})
+        if model_inline_config:
+            overrides = {**model_inline_config}
+        else:
+            logger.warning(
+                "Cannot find inline model overrides. "
+                'Make sure they are nested under a "model:" key'
+            )

         # Merge draft overrides beforehand
-        draft_override_args = unwrap(override_args.get("draft_model"), {})
-        if draft_override_args:
-            kwargs["draft_model"] = {
-                **draft_override_args,
-                **unwrap(kwargs.get("draft_model"), {}),
-            }
+        draft_inline_config = unwrap(inline_config.get("draft_model"), {})
+        if draft_inline_config:
+            overrides["draft_model"] = {**draft_inline_config}

     # Merge the override and model kwargs
-    merged_kwargs = {**override_args, **kwargs}
+    merged_kwargs = {**overrides, **kwargs}
     return merged_kwargs
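The breaking part here is the nesting: inline overrides are now read from a model: key (with draft_model: beside it) rather than from the top level of the file. A sketch of the merge semantics, assuming the override YAML parses to the dict below (hypothetical values):

    inline_config = {
        "model": {"max_seq_len": 16384, "rope_alpha": "auto"},
        "draft_model": {"draft_rope_alpha": "auto"},
    }
    overrides = {**inline_config.get("model", {})}
    overrides["draft_model"] = {**inline_config.get("draft_model", {})}

    # Explicit load arguments still take priority over inline overrides
    merged_kwargs = {**overrides, **{"max_seq_len": 8192}}
    assert merged_kwargs["max_seq_len"] == 8192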
@@ -138,6 +148,13 @@ async def load_model_gen(model_path: pathlib.Path, **kwargs):
     # Fetch the extra HF configuration options
     hf_model = await HFModel.from_directory(model_path)

+    # Override the max sequence length based on user
+    max_seq_len = kwargs.get("max_seq_len")
+    if max_seq_len == -1:
+        kwargs["max_seq_len"] = hf_model.hf_config.max_position_embeddings
+    elif max_seq_len is None:
+        kwargs["max_seq_len"] = 4096
+
     # Create a new container and check if the right dependencies are installed
     backend = unwrap(kwargs.get("backend"), detect_backend(hf_model))
     container_class = _BACKEND_REGISTRY.get(backend)
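In short, the new resolution rule is: an explicit max_seq_len is kept, -1 defers to the model's config.json, and an unset value now means 4096 instead of the model's native length. A standalone sketch of that rule, using a hypothetical native length of 32768:

    def resolve_max_seq_len(user_value, max_position_embeddings):
        # -1: use the model's native length from config.json
        if user_value == -1:
            return max_position_embeddings
        # Unset: fixed 4096 default (previously the native length was used)
        if user_value is None:
            return 4096
        return user_value

    assert resolve_max_seq_len(-1, 32768) == 32768
    assert resolve_max_seq_len(None, 32768) == 4096
    assert resolve_max_seq_len(16384, 32768) == 16384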
@@ -39,12 +39,11 @@ class GenerationConfig(BaseModel):

 class HuggingFaceConfig(BaseModel):
     """
-    DEPRECATED: Currently a stub and doesn't do anything.
-
     An abridged version of HuggingFace's model config.
     Will be expanded as needed.
     """

+    max_position_embeddings: int = 4096
     eos_token_id: Optional[Union[int, List[int]]] = None
     quantization_config: Optional[Dict] = None
@@ -78,8 +78,8 @@ model:
   # Options: exllamav2, exllamav3
   backend:

-  # Max sequence length (default: Empty).
-  # Fetched from the model's base sequence length in config.json by default.
+  # Max sequence length (default: 4096).
+  # Set to -1 to fetch from the model's config.json
   max_seq_len:

   # Load model with tensor parallelism.
@@ -1,7 +1,7 @@
 """Common types for OAI."""

 from pydantic import BaseModel, Field
-from typing import Optional
+from typing import Optional, Union

 from common.sampling import BaseSamplerRequest, get_default_sampler_value

@@ -10,8 +10,13 @@ class UsageStats(BaseModel):
     """Represents usage stats."""

     prompt_tokens: int
+    prompt_time: Optional[float] = None
+    prompt_tokens_per_sec: Optional[Union[float, str]] = None
     completion_tokens: int
+    completion_time: Optional[float] = None
+    completion_tokens_per_sec: Optional[Union[float, str]] = None
     total_tokens: int
+    total_time: Optional[float] = None


 class CompletionResponseFormat(BaseModel):
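The expanded usage payload then looks roughly like this (hypothetical values; the per-second fields may also carry the string "Indeterminate" when a duration is zero):

    usage = UsageStats(
        prompt_tokens=600,
        prompt_time=0.5,
        prompt_tokens_per_sec=1000.0,
        completion_tokens=250,
        completion_time=5.0,
        completion_tokens_per_sec=50.0,
        total_tokens=850,
        total_time=5.6,
    )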
@@ -38,9 +38,6 @@ def _create_response(
 ):
     """Create a chat completion response from the provided text."""

-    prompt_tokens = unwrap(generations[-1].get("prompt_tokens"), 0)
-    completion_tokens = unwrap(generations[-1].get("generated_tokens"), 0)
-
     choices = []
     for index, generation in enumerate(generations):
         message = ChatCompletionMessage(
@@ -91,14 +88,23 @@ _create_response

         choices.append(choice)

+    final_generation = generations[-1]
+    prompt_tokens = unwrap(final_generation.get("prompt_tokens"), 0)
+    completion_tokens = unwrap(final_generation.get("gen_tokens"), 0)
+
     response = ChatCompletionResponse(
-        id=f"chatcmpl-{request_id}",
+        id=f"cmpl-{request_id}",
         choices=choices,
-        model=unwrap(model_name, ""),
+        model=model_name,
         usage=UsageStats(
             prompt_tokens=prompt_tokens,
+            prompt_time=final_generation.get("prompt_time"),
+            prompt_tokens_per_sec=final_generation.get("prompt_tokens_per_sec"),
             completion_tokens=completion_tokens,
+            completion_time=final_generation.get("gen_time"),
+            completion_tokens_per_sec=final_generation.get("gen_tokens_per_sec"),
             total_tokens=prompt_tokens + completion_tokens,
+            total_time=final_generation.get("total_time"),
         ),
     )
@@ -119,12 +125,17 @@ def _create_stream_chunk(

     if is_usage_chunk:
         prompt_tokens = unwrap(generation.get("prompt_tokens"), 0)
-        completion_tokens = unwrap(generation.get("generated_tokens"), 0)
+        completion_tokens = unwrap(generation.get("gen_tokens"), 0)

         usage_stats = UsageStats(
             prompt_tokens=prompt_tokens,
+            prompt_time=generation.get("prompt_time"),
+            prompt_tokens_per_sec=generation.get("prompt_tokens_per_sec"),
             completion_tokens=completion_tokens,
+            completion_time=generation.get("gen_time"),
+            completion_tokens_per_sec=generation.get("gen_tokens_per_sec"),
             total_tokens=prompt_tokens + completion_tokens,
+            total_time=generation.get("total_time"),
         )
     elif "finish_reason" in generation:
         # Get the finish reason from the generation
@@ -73,8 +73,9 @@ def _create_response(

         choices.append(choice)

-    prompt_tokens = unwrap(generations[-1].get("prompt_tokens"), 0)
-    completion_tokens = unwrap(generations[-1].get("generated_tokens"), 0)
+    final_generation = generations[-1]
+    prompt_tokens = unwrap(final_generation.get("prompt_tokens"), 0)
+    completion_tokens = unwrap(final_generation.get("gen_tokens"), 0)

     response = CompletionResponse(
         id=f"cmpl-{request_id}",
@@ -82,8 +83,13 @@
         model=model_name,
         usage=UsageStats(
             prompt_tokens=prompt_tokens,
+            prompt_time=final_generation.get("prompt_time"),
+            prompt_tokens_per_sec=final_generation.get("prompt_tokens_per_sec"),
             completion_tokens=completion_tokens,
+            completion_time=final_generation.get("gen_time"),
+            completion_tokens_per_sec=final_generation.get("gen_tokens_per_sec"),
             total_tokens=prompt_tokens + completion_tokens,
+            total_time=final_generation.get("total_time"),
         ),
     )
@@ -1,6 +1,6 @@
 """Contains model card types."""

-from pydantic import AliasChoices, BaseModel, Field, ConfigDict
+from pydantic import BaseModel, Field, ConfigDict
 from time import time
 from typing import List, Literal, Optional, Union

@@ -50,10 +50,7 @@ class DraftModelLoadRequest(BaseModel):
     """Represents a draft model load request."""

     # Required
-    draft_model_name: str = Field(
-        alias=AliasChoices("draft_model_name", "name"),
-        description="Aliases: name",
-    )
+    draft_model_name: str

     # Config arguments
     draft_rope_scale: Optional[float] = None
@@ -75,10 +72,7 @@ class ModelLoadRequest(BaseModel):
     model_config = ConfigDict(protected_namespaces=[])

     # Required
-    model_name: str = Field(
-        alias=AliasChoices("model_name", "name"),
-        description="Aliases: name",
-    )
+    model_name: str

     # Config arguments
     backend: Optional[str] = Field(
@@ -118,18 +112,12 @@ class ModelLoadRequest(BaseModel):
     vision: Optional[bool] = None

     # Non-config arguments
-    draft_model: Optional[DraftModelLoadRequest] = Field(
-        default=None,
-        alias=AliasChoices("draft_model", "draft"),
-    )
+    draft_model: Optional[DraftModelLoadRequest] = None
     skip_queue: Optional[bool] = False


 class EmbeddingModelLoadRequest(BaseModel):
-    embedding_model_name: str = Field(
-        alias=AliasChoices("embedding_model_name", "name"),
-        description="Aliases: name",
-    )
+    embedding_model_name: str

     # Set default from the config
     embeddings_device: Optional[str] = Field(config.embeddings.embeddings_device)
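With the AliasChoices aliases removed across these request models, clients must send the full field names; the short name/draft aliases are no longer accepted. A sketch of the request shape after this change (hypothetical model names, and assuming the remaining fields keep their defaults):

    # Accepted after this change
    request = ModelLoadRequest(
        model_name="my-model",
        draft_model=DraftModelLoadRequest(draft_model_name="my-draft"),
    )

    # No longer accepted (aliases were dropped):
    # {"name": "my-model", "draft": {"name": "my-draft"}}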
@@ -1,4 +1,4 @@
-from pydantic import AliasChoices, BaseModel, Field
+from pydantic import BaseModel, Field
 from typing import List

@@ -12,7 +12,4 @@ class TemplateList(BaseModel):
 class TemplateSwitchRequest(BaseModel):
     """Request to switch a template."""

-    prompt_template_name: str = Field(
-        alias=AliasChoices("prompt_template_name", "name"),
-        description="Aliases: name",
-    )
+    prompt_template_name: str