Merge pull request #358 from theroyallab/breaking

Breaking changes for configuration
This commit is contained in:
Brian 2025-06-17 23:10:16 -04:00 committed by GitHub
commit e362319a4d
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
12 changed files with 152 additions and 104 deletions

View file

@ -235,11 +235,10 @@ class ExllamaV2Container(BaseModelContainer):
# Grab the base model's sequence length before overrides for
# rope calculations
base_seq_len = self.config.max_seq_len
base_seq_len = hf_model.hf_config.max_position_embeddings
# Set the target seq len if present
# Fallback to base_seq_len if not provided
target_seq_len = unwrap(kwargs.get("max_seq_len"), base_seq_len)
target_seq_len = unwrap(kwargs.get("max_seq_len"), 4096)
# Set the rope scale
self.config.scale_pos_emb = unwrap(
@ -247,6 +246,7 @@ class ExllamaV2Container(BaseModelContainer):
)
# Sets rope alpha value.
# Utilize the model's max_position_embeddings as a base value
# Automatically calculate if unset or defined as an "auto" literal.
rope_alpha = unwrap(kwargs.get("rope_alpha"), "auto")
if rope_alpha == "auto":
@ -371,7 +371,7 @@ class ExllamaV2Container(BaseModelContainer):
)
# Set draft rope alpha. Follows same behavior as model rope alpha.
# Use the base sequence length of the model
# Use the max_position_embeddings of the model
draft_rope_alpha = unwrap(draft_args.get("draft_rope_alpha"), "auto")
if draft_rope_alpha == "auto":
self.draft_config.scale_alpha_value = calculate_rope_alpha(
@ -911,7 +911,7 @@ class ExllamaV2Container(BaseModelContainer):
joined_generation = {
"text": "",
"prompt_tokens": 0,
"generation_tokens": 0,
"gen_tokens": 0,
"tool_calls": None,
"offset": [],
"token_probs": {},
@ -921,11 +921,8 @@ class ExllamaV2Container(BaseModelContainer):
if generations:
# Get finish_reason first and then shift where -1 points to
if "finish_reason" in generations[-1]:
finish_reason_gen = generations.pop()
joined_generation["finish_reason"] = finish_reason_gen.get(
"finish_reason"
)
joined_generation["stop_str"] = finish_reason_gen.get("stop_str")
finish_chunk = generations.pop()
joined_generation = {**joined_generation, **finish_chunk}
else:
joined_generation["finish_reason"] = "stop"
@ -1187,9 +1184,35 @@ class ExllamaV2Container(BaseModelContainer):
elif eos_reason == "stop_string":
stop_str = result.get("eos_triggering_string")
# Prompt
prompt_tokens = result.get("prompt_tokens")
cached_tokens = round(result.get("cached_tokens"), 2)
prompt_time = round(result.get("time_prefill"), 2)
prompt_ts = (
"Indeterminate"
if prompt_time == 0
else round((prompt_tokens - cached_tokens) / prompt_time, 2)
)
# Generated
gen_tokens = result.get("new_tokens")
gen_time = result.get("time_generate")
gen_ts = "Indeterminate" if gen_time == 0 else round(gen_tokens / gen_time, 2)
# Queue + Total
queue_time = result.get("time_enqueued")
total_time = round(queue_time + prompt_time + gen_time, 2)
finish_chunk = {
"prompt_tokens": generation.get("prompt_tokens"),
"generated_tokens": generation.get("generated_tokens"),
"prompt_tokens": prompt_tokens,
"prompt_time": round(prompt_time, 2),
"prompt_tokens_per_sec": prompt_ts,
"gen_tokens": gen_tokens,
"gen_time": round(gen_time, 2),
"gen_tokens_per_sec": gen_ts,
"total_time": total_time,
"queue_time": round(queue_time, 2),
"cached_tokens": cached_tokens,
"finish_reason": finish_reason,
"stop_str": stop_str,
}
@ -1411,12 +1434,12 @@ class ExllamaV2Container(BaseModelContainer):
if result.get("eos"):
log_response(request_id, full_response)
generation = self.handle_finish_chunk(result, generation)
finish_chunk = self.handle_finish_chunk(result, generation)
# Save the final result for metrics logging
metrics_result = result
metrics_result = finish_chunk
yield generation
yield finish_chunk
break
except asyncio.CancelledError:
await job.cancel()
@ -1449,12 +1472,7 @@ class ExllamaV2Container(BaseModelContainer):
if metrics_result:
log_metrics(
request_id,
metrics_result.get("time_enqueued"),
metrics_result.get("prompt_tokens"),
metrics_result.get("cached_tokens"),
metrics_result.get("time_prefill"),
metrics_result.get("new_tokens"),
metrics_result.get("time_generate"),
metrics_result,
context_len,
max_seq_len,
)

View file

@ -649,11 +649,8 @@ class ExllamaV3Container(BaseModelContainer):
if generations:
# Get finish_reason first and then shift where -1 points to
if "finish_reason" in generations[-1]:
finish_reason_gen = generations.pop()
joined_generation["finish_reason"] = finish_reason_gen.get(
"finish_reason"
)
joined_generation["stop_str"] = finish_reason_gen.get("stop_str")
finish_chunk = generations.pop()
joined_generation = {**joined_generation, **finish_chunk}
else:
joined_generation["finish_reason"] = "stop"
@ -743,9 +740,35 @@ class ExllamaV3Container(BaseModelContainer):
elif eos_reason == "stop_string":
stop_str = result.get("eos_triggering_string")
# Prompt
prompt_tokens = result.get("prompt_tokens")
cached_tokens = round(result.get("cached_tokens"), 2)
prompt_time = round(result.get("time_prefill"), 2)
prompt_ts = (
"Indeterminate"
if prompt_time == 0
else round((prompt_tokens - cached_tokens) / prompt_time, 2)
)
# Generated
gen_tokens = result.get("new_tokens")
gen_time = result.get("time_generate")
gen_ts = "Indeterminate" if gen_time == 0 else round(gen_tokens / gen_time, 2)
# Queue + Total
queue_time = result.get("time_enqueued")
total_time = round(queue_time + prompt_time + gen_time, 2)
finish_chunk = {
"prompt_tokens": generation.get("prompt_tokens"),
"generated_tokens": generation.get("generated_tokens"),
"prompt_tokens": prompt_tokens,
"prompt_time": round(prompt_time, 2),
"prompt_tokens_per_sec": prompt_ts,
"gen_tokens": gen_tokens,
"gen_time": round(gen_time, 2),
"gen_tokens_per_sec": gen_ts,
"total_time": total_time,
"queue_time": round(queue_time, 2),
"cached_tokens": cached_tokens,
"finish_reason": finish_reason,
"stop_str": stop_str,
}
@ -921,12 +944,12 @@ class ExllamaV3Container(BaseModelContainer):
yield generation
if result.get("eos"):
generation = self.handle_finish_chunk(result, generation)
finish_chunk = self.handle_finish_chunk(result, generation)
# Save the final result for metrics logging
metrics_result = result
metrics_result = finish_chunk
yield generation
yield finish_chunk
break
# Assign the active job to the request ID
self.active_job_ids[request_id] = job
@ -962,12 +985,7 @@ class ExllamaV3Container(BaseModelContainer):
if metrics_result:
log_metrics(
request_id,
metrics_result.get("time_enqueued"),
metrics_result.get("prompt_tokens"),
metrics_result.get("cached_tokens"),
metrics_result.get("time_prefill"),
metrics_result.get("new_tokens"),
metrics_result.get("time_generate"),
metrics_result,
context_len,
self.max_seq_len,
)

View file

@ -175,10 +175,10 @@ class ModelConfig(BaseConfigModel):
max_seq_len: Optional[int] = Field(
None,
description=(
"Max sequence length (default: Empty).\n"
"Fetched from the model's base sequence length in config.json by default."
"Max sequence length (default: 4096).\n"
"Set to -1 to fetch from the model's config.json"
),
ge=0,
ge=-1,
)
tensor_parallel: Optional[bool] = Field(
False,

View file

@ -54,40 +54,29 @@ def log_response(request_id: str, response: str):
def log_metrics(
request_id: str,
queue_time: float,
prompt_tokens: int,
cached_tokens: int,
prompt_time: float,
generated_tokens: int,
generate_time: float,
metrics: dict,
context_len: Optional[int],
max_seq_len: int,
):
initial_response = (
f"Metrics (ID: {request_id}): {generated_tokens} tokens generated in "
f"{round(queue_time + prompt_time + generate_time, 2)} seconds"
f"Metrics (ID: {request_id}): {metrics.get('gen_tokens')} "
f"tokens generated in {metrics.get('total_time')} seconds"
)
itemization = []
extra_parts = []
itemization.append(f"Queue: {round(queue_time, 2)} s")
itemization.append(f"Queue: {metrics.get('queue_time')} s")
cached_tokens = metrics.get("cached_tokens")
prompt_tokens = metrics.get("prompt_tokens")
prompt_ts = (
"Indeterminate"
if prompt_time == 0
else round((prompt_tokens - cached_tokens) / prompt_time, 2)
)
itemization.append(
f"Process: {cached_tokens} cached tokens and "
f"{prompt_tokens - cached_tokens} new tokens at {prompt_ts} T/s"
f"{prompt_tokens - cached_tokens} new tokens at "
f"{metrics.get('prompt_tokens_per_sec')} T/s"
)
generate_ts = (
"Indeterminate"
if generate_time == 0
else round(generated_tokens / generate_time, 2)
)
itemization.append(f"Generate: {generate_ts} T/s")
itemization.append(f"Generate: {metrics.get('gen_tokens_per_sec')} T/s")
# Add context (original token count)
if context_len:

View file

@ -76,6 +76,9 @@ async def apply_inline_overrides(model_dir: pathlib.Path, **kwargs):
if not override_config_path.exists():
return kwargs
# Initialize overrides dict
overrides = {}
async with aiofiles.open(
override_config_path, "r", encoding="utf8"
) as override_config_file:
@ -83,18 +86,25 @@ async def apply_inline_overrides(model_dir: pathlib.Path, **kwargs):
# Create a temporary YAML parser
yaml = YAML(typ="safe")
override_args = unwrap(yaml.load(contents), {})
inline_config = unwrap(yaml.load(contents), {})
# Check for inline model overrides
model_inline_config = unwrap(inline_config.get("model"), {})
if model_inline_config:
overrides = {**model_inline_config}
else:
logger.warning(
"Cannot find inline model overrides. "
'Make sure they are nested under a "model:" key'
)
# Merge draft overrides beforehand
draft_override_args = unwrap(override_args.get("draft_model"), {})
if draft_override_args:
kwargs["draft_model"] = {
**draft_override_args,
**unwrap(kwargs.get("draft_model"), {}),
}
draft_inline_config = unwrap(inline_config.get("draft_model"), {})
if draft_inline_config:
overrides["draft_model"] = {**draft_inline_config}
# Merge the override and model kwargs
merged_kwargs = {**override_args, **kwargs}
merged_kwargs = {**overrides, **kwargs}
return merged_kwargs
@ -138,6 +148,13 @@ async def load_model_gen(model_path: pathlib.Path, **kwargs):
# Fetch the extra HF configuration options
hf_model = await HFModel.from_directory(model_path)
# Override the max sequence length based on user
max_seq_len = kwargs.get("max_seq_len")
if max_seq_len == -1:
kwargs["max_seq_len"] = hf_model.hf_config.max_position_embeddings
elif max_seq_len is None:
kwargs["max_seq_len"] = 4096
# Create a new container and check if the right dependencies are installed
backend = unwrap(kwargs.get("backend"), detect_backend(hf_model))
container_class = _BACKEND_REGISTRY.get(backend)

View file

@ -39,12 +39,11 @@ class GenerationConfig(BaseModel):
class HuggingFaceConfig(BaseModel):
"""
DEPRECATED: Currently a stub and doesn't do anything.
An abridged version of HuggingFace's model config.
Will be expanded as needed.
"""
max_position_embeddings: int = 4096
eos_token_id: Optional[Union[int, List[int]]] = None
quantization_config: Optional[Dict] = None

View file

@ -78,8 +78,8 @@ model:
# Options: exllamav2, exllamav3
backend:
# Max sequence length (default: Empty).
# Fetched from the model's base sequence length in config.json by default.
# Max sequence length (default: 4096).
# Set to -1 to fetch from the model's config.json
max_seq_len:
# Load model with tensor parallelism.

View file

@ -1,7 +1,7 @@
"""Common types for OAI."""
from pydantic import BaseModel, Field
from typing import Optional
from typing import Optional, Union
from common.sampling import BaseSamplerRequest, get_default_sampler_value
@ -10,8 +10,13 @@ class UsageStats(BaseModel):
"""Represents usage stats."""
prompt_tokens: int
prompt_time: Optional[float] = None
prompt_tokens_per_sec: Optional[Union[float, str]] = None
completion_tokens: int
completion_time: Optional[float] = None
completion_tokens_per_sec: Optional[Union[float, str]] = None
total_tokens: int
total_time: Optional[float] = None
class CompletionResponseFormat(BaseModel):

View file

@ -38,9 +38,6 @@ def _create_response(
):
"""Create a chat completion response from the provided text."""
prompt_tokens = unwrap(generations[-1].get("prompt_tokens"), 0)
completion_tokens = unwrap(generations[-1].get("generated_tokens"), 0)
choices = []
for index, generation in enumerate(generations):
message = ChatCompletionMessage(
@ -91,14 +88,23 @@ def _create_response(
choices.append(choice)
final_generation = generations[-1]
prompt_tokens = unwrap(final_generation.get("prompt_tokens"), 0)
completion_tokens = unwrap(final_generation.get("gen_tokens"), 0)
response = ChatCompletionResponse(
id=f"chatcmpl-{request_id}",
id=f"cmpl-{request_id}",
choices=choices,
model=unwrap(model_name, ""),
model=model_name,
usage=UsageStats(
prompt_tokens=prompt_tokens,
prompt_time=final_generation.get("prompt_time"),
prompt_tokens_per_sec=final_generation.get("prompt_tokens_per_sec"),
completion_tokens=completion_tokens,
completion_time=final_generation.get("gen_time"),
completion_tokens_per_sec=final_generation.get("gen_tokens_per_sec"),
total_tokens=prompt_tokens + completion_tokens,
total_time=final_generation.get("total_time"),
),
)
@ -119,12 +125,17 @@ def _create_stream_chunk(
if is_usage_chunk:
prompt_tokens = unwrap(generation.get("prompt_tokens"), 0)
completion_tokens = unwrap(generation.get("generated_tokens"), 0)
completion_tokens = unwrap(generation.get("gen_tokens"), 0)
usage_stats = UsageStats(
prompt_tokens=prompt_tokens,
prompt_time=generation.get("prompt_time"),
prompt_tokens_per_sec=generation.get("prompt_tokens_per_sec"),
completion_tokens=completion_tokens,
completion_time=generation.get("gen_time"),
completion_tokens_per_sec=generation.get("gen_tokens_per_sec"),
total_tokens=prompt_tokens + completion_tokens,
total_time=generation.get("total_time"),
)
elif "finish_reason" in generation:
# Get the finish reason from the generation

View file

@ -73,8 +73,9 @@ def _create_response(
choices.append(choice)
prompt_tokens = unwrap(generations[-1].get("prompt_tokens"), 0)
completion_tokens = unwrap(generations[-1].get("generated_tokens"), 0)
final_generation = generations[-1]
prompt_tokens = unwrap(final_generation.get("prompt_tokens"), 0)
completion_tokens = unwrap(final_generation.get("gen_tokens"), 0)
response = CompletionResponse(
id=f"cmpl-{request_id}",
@ -82,8 +83,13 @@ def _create_response(
model=model_name,
usage=UsageStats(
prompt_tokens=prompt_tokens,
prompt_time=final_generation.get("prompt_time"),
prompt_tokens_per_sec=final_generation.get("prompt_tokens_per_sec"),
completion_tokens=completion_tokens,
completion_time=final_generation.get("gen_time"),
completion_tokens_per_sec=final_generation.get("gen_tokens_per_sec"),
total_tokens=prompt_tokens + completion_tokens,
total_time=final_generation.get("total_time"),
),
)

View file

@ -1,6 +1,6 @@
"""Contains model card types."""
from pydantic import AliasChoices, BaseModel, Field, ConfigDict
from pydantic import BaseModel, Field, ConfigDict
from time import time
from typing import List, Literal, Optional, Union
@ -50,10 +50,7 @@ class DraftModelLoadRequest(BaseModel):
"""Represents a draft model load request."""
# Required
draft_model_name: str = Field(
alias=AliasChoices("draft_model_name", "name"),
description="Aliases: name",
)
draft_model_name: str
# Config arguments
draft_rope_scale: Optional[float] = None
@ -75,10 +72,7 @@ class ModelLoadRequest(BaseModel):
model_config = ConfigDict(protected_namespaces=[])
# Required
model_name: str = Field(
alias=AliasChoices("model_name", "name"),
description="Aliases: name",
)
model_name: str
# Config arguments
backend: Optional[str] = Field(
@ -118,18 +112,12 @@ class ModelLoadRequest(BaseModel):
vision: Optional[bool] = None
# Non-config arguments
draft_model: Optional[DraftModelLoadRequest] = Field(
default=None,
alias=AliasChoices("draft_model", "draft"),
)
draft_model: Optional[DraftModelLoadRequest] = None
skip_queue: Optional[bool] = False
class EmbeddingModelLoadRequest(BaseModel):
embedding_model_name: str = Field(
alias=AliasChoices("embedding_model_name", "name"),
description="Aliases: name",
)
embedding_model_name: str
# Set default from the config
embeddings_device: Optional[str] = Field(config.embeddings.embeddings_device)

View file

@ -1,4 +1,4 @@
from pydantic import AliasChoices, BaseModel, Field
from pydantic import BaseModel, Field
from typing import List
@ -12,7 +12,4 @@ class TemplateList(BaseModel):
class TemplateSwitchRequest(BaseModel):
"""Request to switch a template."""
prompt_template_name: str = Field(
alias=AliasChoices("prompt_template_name", "name"),
description="Aliases: name",
)
prompt_template_name: str