Merge pull request #358 from theroyallab/breaking

Breaking changes for configuration
Brian 2025-06-17 23:10:16 -04:00 committed by GitHub
commit e362319a4d
12 changed files with 152 additions and 104 deletions


@@ -235,11 +235,10 @@ class ExllamaV2Container(BaseModelContainer):
         # Grab the base model's sequence length before overrides for
         # rope calculations
-        base_seq_len = self.config.max_seq_len
+        base_seq_len = hf_model.hf_config.max_position_embeddings

         # Set the target seq len if present
-        # Fallback to base_seq_len if not provided
-        target_seq_len = unwrap(kwargs.get("max_seq_len"), base_seq_len)
+        target_seq_len = unwrap(kwargs.get("max_seq_len"), 4096)

         # Set the rope scale
         self.config.scale_pos_emb = unwrap(
@@ -247,6 +246,7 @@ class ExllamaV2Container(BaseModelContainer):
         )

         # Sets rope alpha value.
+        # Utilize the model's max_position_embeddings as a base value
         # Automatically calculate if unset or defined as an "auto" literal.
         rope_alpha = unwrap(kwargs.get("rope_alpha"), "auto")
         if rope_alpha == "auto":
@@ -371,7 +371,7 @@ class ExllamaV2Container(BaseModelContainer):
             )

             # Set draft rope alpha. Follows same behavior as model rope alpha.
-            # Use the base sequence length of the model
+            # Use the max_position_embeddings of the model
             draft_rope_alpha = unwrap(draft_args.get("draft_rope_alpha"), "auto")
             if draft_rope_alpha == "auto":
                 self.draft_config.scale_alpha_value = calculate_rope_alpha(
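Note: both the model and draft "auto" paths now feed the model's max_position_embeddings into calculate_rope_alpha rather than the backend config's base sequence length. As a hedged illustration only (the real calculate_rope_alpha is not shown in this diff), a common NTK-aware heuristic looks like this; the project's exact formula may differ:

    # Illustrative only: a generic NTK-aware rope alpha heuristic, not
    # necessarily the formula calculate_rope_alpha() actually uses.
    def calculate_rope_alpha(base_seq_len: int, target_seq_len: int) -> float:
        ratio = target_seq_len / base_seq_len
        # No scaling needed when the target context fits the base context
        if ratio <= 1:
            return 1.0
        # Stretch the rope base so the longer context stays coherent
        return ratio ** (64 / 63)

    # Example: a model trained at 4096 tokens stretched to 16384 tokens
    print(calculate_rope_alpha(4096, 16384))  # ~4.09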
@@ -911,7 +911,7 @@ class ExllamaV2Container(BaseModelContainer):
         joined_generation = {
             "text": "",
             "prompt_tokens": 0,
-            "generation_tokens": 0,
+            "gen_tokens": 0,
             "tool_calls": None,
             "offset": [],
             "token_probs": {},
@@ -921,11 +921,8 @@ class ExllamaV2Container(BaseModelContainer):
         if generations:
             # Get finish_reason first and then shift where -1 points to
             if "finish_reason" in generations[-1]:
-                finish_reason_gen = generations.pop()
-                joined_generation["finish_reason"] = finish_reason_gen.get(
-                    "finish_reason"
-                )
-                joined_generation["stop_str"] = finish_reason_gen.get("stop_str")
+                finish_chunk = generations.pop()
+                joined_generation = {**joined_generation, **finish_chunk}
             else:
                 joined_generation["finish_reason"] = "stop"
@@ -1187,9 +1184,35 @@ class ExllamaV2Container(BaseModelContainer):
         elif eos_reason == "stop_string":
             stop_str = result.get("eos_triggering_string")

+        # Prompt
+        prompt_tokens = result.get("prompt_tokens")
+        cached_tokens = round(result.get("cached_tokens"), 2)
+        prompt_time = round(result.get("time_prefill"), 2)
+        prompt_ts = (
+            "Indeterminate"
+            if prompt_time == 0
+            else round((prompt_tokens - cached_tokens) / prompt_time, 2)
+        )
+
+        # Generated
+        gen_tokens = result.get("new_tokens")
+        gen_time = result.get("time_generate")
+        gen_ts = "Indeterminate" if gen_time == 0 else round(gen_tokens / gen_time, 2)
+
+        # Queue + Total
+        queue_time = result.get("time_enqueued")
+        total_time = round(queue_time + prompt_time + gen_time, 2)
+
         finish_chunk = {
-            "prompt_tokens": generation.get("prompt_tokens"),
-            "generated_tokens": generation.get("generated_tokens"),
+            "prompt_tokens": prompt_tokens,
+            "prompt_time": round(prompt_time, 2),
+            "prompt_tokens_per_sec": prompt_ts,
+            "gen_tokens": gen_tokens,
+            "gen_time": round(gen_time, 2),
+            "gen_tokens_per_sec": gen_ts,
+            "total_time": total_time,
+            "queue_time": round(queue_time, 2),
+            "cached_tokens": cached_tokens,
             "finish_reason": finish_reason,
             "stop_str": stop_str,
         }
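Note: handle_finish_chunk now computes all throughput metrics up front from the generator's raw result. A worked example with illustrative numbers (key names and arithmetic follow the hunk above):

    # Worked example; values are made up, the math mirrors the code above.
    result = {
        "prompt_tokens": 512,
        "cached_tokens": 256,
        "time_prefill": 0.8,    # seconds spent prefilling uncached prompt tokens
        "new_tokens": 128,
        "time_generate": 3.2,   # seconds spent generating
        "time_enqueued": 0.05,  # seconds spent waiting in the queue
    }

    # Only uncached prompt tokens count toward prefill throughput
    prompt_ts = (512 - 256) / 0.8            # 320.0 T/s
    gen_ts = 128 / 3.2                       # 40.0 T/s
    total_time = round(0.05 + 0.8 + 3.2, 2)  # 4.05 s

A prefill or generation time of zero yields the string "Indeterminate" instead of a division by zero.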
@@ -1411,12 +1434,12 @@ class ExllamaV2Container(BaseModelContainer):
                 if result.get("eos"):
                     log_response(request_id, full_response)

-                    generation = self.handle_finish_chunk(result, generation)
+                    finish_chunk = self.handle_finish_chunk(result, generation)

                     # Save the final result for metrics logging
-                    metrics_result = result
+                    metrics_result = finish_chunk

-                    yield generation
+                    yield finish_chunk
                     break
         except asyncio.CancelledError:
             await job.cancel()
@@ -1449,12 +1472,7 @@ class ExllamaV2Container(BaseModelContainer):
             if metrics_result:
                 log_metrics(
                     request_id,
-                    metrics_result.get("time_enqueued"),
-                    metrics_result.get("prompt_tokens"),
-                    metrics_result.get("cached_tokens"),
-                    metrics_result.get("time_prefill"),
-                    metrics_result.get("new_tokens"),
-                    metrics_result.get("time_generate"),
+                    metrics_result,
                     context_len,
                     max_seq_len,
                 )


@@ -649,11 +649,8 @@ class ExllamaV3Container(BaseModelContainer):
         if generations:
             # Get finish_reason first and then shift where -1 points to
             if "finish_reason" in generations[-1]:
-                finish_reason_gen = generations.pop()
-                joined_generation["finish_reason"] = finish_reason_gen.get(
-                    "finish_reason"
-                )
-                joined_generation["stop_str"] = finish_reason_gen.get("stop_str")
+                finish_chunk = generations.pop()
+                joined_generation = {**joined_generation, **finish_chunk}
             else:
                 joined_generation["finish_reason"] = "stop"
@@ -743,9 +740,35 @@ class ExllamaV3Container(BaseModelContainer):
         elif eos_reason == "stop_string":
             stop_str = result.get("eos_triggering_string")

+        # Prompt
+        prompt_tokens = result.get("prompt_tokens")
+        cached_tokens = round(result.get("cached_tokens"), 2)
+        prompt_time = round(result.get("time_prefill"), 2)
+        prompt_ts = (
+            "Indeterminate"
+            if prompt_time == 0
+            else round((prompt_tokens - cached_tokens) / prompt_time, 2)
+        )
+
+        # Generated
+        gen_tokens = result.get("new_tokens")
+        gen_time = result.get("time_generate")
+        gen_ts = "Indeterminate" if gen_time == 0 else round(gen_tokens / gen_time, 2)
+
+        # Queue + Total
+        queue_time = result.get("time_enqueued")
+        total_time = round(queue_time + prompt_time + gen_time, 2)
+
         finish_chunk = {
-            "prompt_tokens": generation.get("prompt_tokens"),
-            "generated_tokens": generation.get("generated_tokens"),
+            "prompt_tokens": prompt_tokens,
+            "prompt_time": round(prompt_time, 2),
+            "prompt_tokens_per_sec": prompt_ts,
+            "gen_tokens": gen_tokens,
+            "gen_time": round(gen_time, 2),
+            "gen_tokens_per_sec": gen_ts,
+            "total_time": total_time,
+            "queue_time": round(queue_time, 2),
+            "cached_tokens": cached_tokens,
             "finish_reason": finish_reason,
             "stop_str": stop_str,
         }
@@ -921,12 +944,12 @@ class ExllamaV3Container(BaseModelContainer):
                     yield generation

                 if result.get("eos"):
-                    generation = self.handle_finish_chunk(result, generation)
+                    finish_chunk = self.handle_finish_chunk(result, generation)

                     # Save the final result for metrics logging
-                    metrics_result = result
+                    metrics_result = finish_chunk

-                    yield generation
+                    yield finish_chunk
                     break

             # Assign the active job to the request ID
             self.active_job_ids[request_id] = job
@@ -962,12 +985,7 @@ class ExllamaV3Container(BaseModelContainer):
             if metrics_result:
                 log_metrics(
                     request_id,
-                    metrics_result.get("time_enqueued"),
-                    metrics_result.get("prompt_tokens"),
-                    metrics_result.get("cached_tokens"),
-                    metrics_result.get("time_prefill"),
-                    metrics_result.get("new_tokens"),
-                    metrics_result.get("time_generate"),
+                    metrics_result,
                     context_len,
                     self.max_seq_len,
                 )


@@ -175,10 +175,10 @@ class ModelConfig(BaseConfigModel):
     max_seq_len: Optional[int] = Field(
         None,
         description=(
-            "Max sequence length (default: Empty).\n"
-            "Fetched from the model's base sequence length in config.json by default."
+            "Max sequence length (default: 4096).\n"
+            "Set to -1 to fetch from the model's config.json"
         ),
-        ge=0,
+        ge=-1,
     )
     tensor_parallel: Optional[bool] = Field(
         False,


@@ -54,40 +54,29 @@ def log_response(request_id: str, response: str):
 def log_metrics(
     request_id: str,
-    queue_time: float,
-    prompt_tokens: int,
-    cached_tokens: int,
-    prompt_time: float,
-    generated_tokens: int,
-    generate_time: float,
+    metrics: dict,
     context_len: Optional[int],
     max_seq_len: int,
 ):
     initial_response = (
-        f"Metrics (ID: {request_id}): {generated_tokens} tokens generated in "
-        f"{round(queue_time + prompt_time + generate_time, 2)} seconds"
+        f"Metrics (ID: {request_id}): {metrics.get('gen_tokens')} "
+        f"tokens generated in {metrics.get('total_time')} seconds"
     )

     itemization = []
     extra_parts = []

-    itemization.append(f"Queue: {round(queue_time, 2)} s")
+    itemization.append(f"Queue: {metrics.get('queue_time')} s")

-    prompt_ts = (
-        "Indeterminate"
-        if prompt_time == 0
-        else round((prompt_tokens - cached_tokens) / prompt_time, 2)
-    )
+    cached_tokens = metrics.get("cached_tokens")
+    prompt_tokens = metrics.get("prompt_tokens")

     itemization.append(
         f"Process: {cached_tokens} cached tokens and "
-        f"{prompt_tokens - cached_tokens} new tokens at {prompt_ts} T/s"
+        f"{prompt_tokens - cached_tokens} new tokens at "
+        f"{metrics.get('prompt_tokens_per_sec')} T/s"
     )

-    generate_ts = (
-        "Indeterminate"
-        if generate_time == 0
-        else round(generated_tokens / generate_time, 2)
-    )
-
-    itemization.append(f"Generate: {generate_ts} T/s")
+    itemization.append(f"Generate: {metrics.get('gen_tokens_per_sec')} T/s")

     # Add context (original token count)
     if context_len:
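Note: call sites now pass the finish chunk straight through (see the container changes above). A minimal usage sketch of the new signature; the dict is the finish_chunk built by the backends, and the request ID, values, and keyword arguments here are illustrative:

    metrics_result = {
        "gen_tokens": 128,
        "total_time": 4.05,
        "queue_time": 0.05,
        "cached_tokens": 256,
        "prompt_tokens": 512,
        "prompt_tokens_per_sec": 320.0,
        "gen_tokens_per_sec": 40.0,
    }

    log_metrics("abc123", metrics_result, context_len=1024, max_seq_len=8192)
    # Logs: "Metrics (ID: abc123): 128 tokens generated in 4.05 seconds"
    # with itemized parts "Queue: 0.05 s",
    # "Process: 256 cached tokens and 256 new tokens at 320.0 T/s",
    # and "Generate: 40.0 T/s"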


@@ -76,6 +76,9 @@ async def apply_inline_overrides(model_dir: pathlib.Path, **kwargs):
     if not override_config_path.exists():
         return kwargs

+    # Initialize overrides dict
+    overrides = {}
+
     async with aiofiles.open(
         override_config_path, "r", encoding="utf8"
     ) as override_config_file:
@@ -83,18 +86,25 @@ async def apply_inline_overrides(model_dir: pathlib.Path, **kwargs):
         # Create a temporary YAML parser
         yaml = YAML(typ="safe")
-        override_args = unwrap(yaml.load(contents), {})
+        inline_config = unwrap(yaml.load(contents), {})
+
+        # Check for inline model overrides
+        model_inline_config = unwrap(inline_config.get("model"), {})
+        if model_inline_config:
+            overrides = {**model_inline_config}
+        else:
+            logger.warning(
+                "Cannot find inline model overrides. "
+                'Make sure they are nested under a "model:" key'
+            )

         # Merge draft overrides beforehand
-        draft_override_args = unwrap(override_args.get("draft_model"), {})
-        if draft_override_args:
-            kwargs["draft_model"] = {
-                **draft_override_args,
-                **unwrap(kwargs.get("draft_model"), {}),
-            }
+        draft_inline_config = unwrap(inline_config.get("draft_model"), {})
+        if draft_inline_config:
+            overrides["draft_model"] = {**draft_inline_config}

     # Merge the override and model kwargs
     merged_kwargs = {**overrides, **kwargs}
-    merged_kwargs = {**override_args, **kwargs}
     return merged_kwargs
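Note: this is the config-breaking part for existing model folders. Inline overrides must now be nested under a model: key, with draft overrides in a sibling draft_model: key; a file with no model: key logs the warning above and its model-level overrides are ignored. A sketch of the new layout, assuming the file read here is the per-model override YAML in the model directory (commonly tabby_config.yml; check how override_config_path is resolved in your setup):

    # Inline override file in the model folder (new layout); keys are examples
    model:
      max_seq_len: 8192
      rope_alpha: auto

    draft_model:
      draft_rope_alpha: auto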
@@ -138,6 +148,13 @@ async def load_model_gen(model_path: pathlib.Path, **kwargs):
     # Fetch the extra HF configuration options
     hf_model = await HFModel.from_directory(model_path)

+    # Override the max sequence length based on user
+    max_seq_len = kwargs.get("max_seq_len")
+    if max_seq_len == -1:
+        kwargs["max_seq_len"] = hf_model.hf_config.max_position_embeddings
+    elif max_seq_len is None:
+        kwargs["max_seq_len"] = 4096
+
     # Create a new container and check if the right dependencies are installed
     backend = unwrap(kwargs.get("backend"), detect_backend(hf_model))
     container_class = _BACKEND_REGISTRY.get(backend)
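Note: the effective sequence length is now resolved once at load time instead of falling back inside each backend. A self-contained sketch of the resolution order (32768 stands in for a model whose config.json reports that max_position_embeddings):

    # Sketch mirroring the kwargs handling above; values are illustrative.
    def resolve_max_seq_len(user_value, max_position_embeddings=32768):
        if user_value == -1:
            return max_position_embeddings  # fetch from the model's config.json
        if user_value is None:
            return 4096                     # new hard default
        return user_value                   # explicit values are used as-is

    assert resolve_max_seq_len(None) == 4096
    assert resolve_max_seq_len(-1) == 32768
    assert resolve_max_seq_len(16384) == 16384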


@@ -39,12 +39,11 @@ class GenerationConfig(BaseModel):
 class HuggingFaceConfig(BaseModel):
     """
-    DEPRECATED: Currently a stub and doesn't do anything.
-
     An abridged version of HuggingFace's model config.
     Will be expanded as needed.
     """

+    max_position_embeddings: int = 4096
     eos_token_id: Optional[Union[int, List[int]]] = None
     quantization_config: Optional[Dict] = None
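Note: HuggingFaceConfig stops being a deprecated stub; its max_position_embeddings field is what the -1 sequence-length option reads. A self-contained sketch of parsing a typical config.json fragment (field names and defaults come from the diff, the sample values are illustrative):

    from typing import Dict, List, Optional, Union
    from pydantic import BaseModel


    class HuggingFaceConfig(BaseModel):
        max_position_embeddings: int = 4096
        eos_token_id: Optional[Union[int, List[int]]] = None
        quantization_config: Optional[Dict] = None


    # Unknown config.json keys are ignored by pydantic's default settings
    hf_config = HuggingFaceConfig.model_validate(
        {"max_position_embeddings": 32768, "eos_token_id": 2, "hidden_size": 4096}
    )
    print(hf_config.max_position_embeddings)  # 32768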


@@ -78,8 +78,8 @@ model:
   # Options: exllamav2, exllamav3
   backend:

-  # Max sequence length (default: Empty).
-  # Fetched from the model's base sequence length in config.json by default.
+  # Max sequence length (default: 4096).
+  # Set to -1 to fetch from the model's config.json
   max_seq_len:

   # Load model with tensor parallelism.


@@ -1,7 +1,7 @@
 """Common types for OAI."""

 from pydantic import BaseModel, Field
-from typing import Optional
+from typing import Optional, Union

 from common.sampling import BaseSamplerRequest, get_default_sampler_value

@@ -10,8 +10,13 @@ class UsageStats(BaseModel):
     """Represents usage stats."""

     prompt_tokens: int
+    prompt_time: Optional[float] = None
+    prompt_tokens_per_sec: Optional[Union[float, str]] = None
     completion_tokens: int
+    completion_time: Optional[float] = None
+    completion_tokens_per_sec: Optional[Union[float, str]] = None
     total_tokens: int
+    total_time: Optional[float] = None


 class CompletionResponseFormat(BaseModel):
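Note: for API clients, the usage object in completion and chat completion responses gains optional timing fields; the per-second values can also be the string "Indeterminate" when a timer reads zero. An illustrative response fragment (numbers are examples):

    "usage": {
        "prompt_tokens": 512,
        "prompt_time": 0.8,
        "prompt_tokens_per_sec": 320.0,
        "completion_tokens": 128,
        "completion_time": 3.2,
        "completion_tokens_per_sec": 40.0,
        "total_tokens": 640,
        "total_time": 4.05
    }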


@@ -38,9 +38,6 @@ def _create_response(
 ):
     """Create a chat completion response from the provided text."""

-    prompt_tokens = unwrap(generations[-1].get("prompt_tokens"), 0)
-    completion_tokens = unwrap(generations[-1].get("generated_tokens"), 0)
-
     choices = []
     for index, generation in enumerate(generations):
         message = ChatCompletionMessage(
@@ -91,14 +88,23 @@ def _create_response(
         choices.append(choice)

+    final_generation = generations[-1]
+    prompt_tokens = unwrap(final_generation.get("prompt_tokens"), 0)
+    completion_tokens = unwrap(final_generation.get("gen_tokens"), 0)
+
     response = ChatCompletionResponse(
-        id=f"chatcmpl-{request_id}",
+        id=f"cmpl-{request_id}",
         choices=choices,
-        model=unwrap(model_name, ""),
+        model=model_name,
         usage=UsageStats(
             prompt_tokens=prompt_tokens,
+            prompt_time=final_generation.get("prompt_time"),
+            prompt_tokens_per_sec=final_generation.get("prompt_tokens_per_sec"),
             completion_tokens=completion_tokens,
+            completion_time=final_generation.get("gen_time"),
+            completion_tokens_per_sec=final_generation.get("gen_tokens_per_sec"),
             total_tokens=prompt_tokens + completion_tokens,
+            total_time=final_generation.get("total_time"),
         ),
     )
@@ -119,12 +125,17 @@ def _create_stream_chunk(
     if is_usage_chunk:
         prompt_tokens = unwrap(generation.get("prompt_tokens"), 0)
-        completion_tokens = unwrap(generation.get("generated_tokens"), 0)
+        completion_tokens = unwrap(generation.get("gen_tokens"), 0)

         usage_stats = UsageStats(
             prompt_tokens=prompt_tokens,
+            prompt_time=generation.get("prompt_time"),
+            prompt_tokens_per_sec=generation.get("prompt_tokens_per_sec"),
             completion_tokens=completion_tokens,
+            completion_time=generation.get("gen_time"),
+            completion_tokens_per_sec=generation.get("gen_tokens_per_sec"),
             total_tokens=prompt_tokens + completion_tokens,
+            total_time=generation.get("total_time"),
         )

     elif "finish_reason" in generation:
         # Get the finish reason from the generation


@@ -73,8 +73,9 @@ def _create_response(
         choices.append(choice)

-    prompt_tokens = unwrap(generations[-1].get("prompt_tokens"), 0)
-    completion_tokens = unwrap(generations[-1].get("generated_tokens"), 0)
+    final_generation = generations[-1]
+    prompt_tokens = unwrap(final_generation.get("prompt_tokens"), 0)
+    completion_tokens = unwrap(final_generation.get("gen_tokens"), 0)

     response = CompletionResponse(
         id=f"cmpl-{request_id}",
@@ -82,8 +83,13 @@ def _create_response(
         model=model_name,
         usage=UsageStats(
             prompt_tokens=prompt_tokens,
+            prompt_time=final_generation.get("prompt_time"),
+            prompt_tokens_per_sec=final_generation.get("prompt_tokens_per_sec"),
             completion_tokens=completion_tokens,
+            completion_time=final_generation.get("gen_time"),
+            completion_tokens_per_sec=final_generation.get("gen_tokens_per_sec"),
             total_tokens=prompt_tokens + completion_tokens,
+            total_time=final_generation.get("total_time"),
         ),
     )


@@ -1,6 +1,6 @@
 """Contains model card types."""

-from pydantic import AliasChoices, BaseModel, Field, ConfigDict
+from pydantic import BaseModel, Field, ConfigDict
 from time import time
 from typing import List, Literal, Optional, Union
@@ -50,10 +50,7 @@ class DraftModelLoadRequest(BaseModel):
     """Represents a draft model load request."""

     # Required
-    draft_model_name: str = Field(
-        alias=AliasChoices("draft_model_name", "name"),
-        description="Aliases: name",
-    )
+    draft_model_name: str

     # Config arguments
     draft_rope_scale: Optional[float] = None
@@ -75,10 +72,7 @@ class ModelLoadRequest(BaseModel):
     model_config = ConfigDict(protected_namespaces=[])

     # Required
-    model_name: str = Field(
-        alias=AliasChoices("model_name", "name"),
-        description="Aliases: name",
-    )
+    model_name: str

     # Config arguments
     backend: Optional[str] = Field(
@@ -118,18 +112,12 @@ class ModelLoadRequest(BaseModel):
     vision: Optional[bool] = None

     # Non-config arguments
-    draft_model: Optional[DraftModelLoadRequest] = Field(
-        default=None,
-        alias=AliasChoices("draft_model", "draft"),
-    )
+    draft_model: Optional[DraftModelLoadRequest] = None
     skip_queue: Optional[bool] = False


 class EmbeddingModelLoadRequest(BaseModel):
-    embedding_model_name: str = Field(
-        alias=AliasChoices("embedding_model_name", "name"),
-        description="Aliases: name",
-    )
+    embedding_model_name: str

     # Set default from the config
     embeddings_device: Optional[str] = Field(config.embeddings.embeddings_device)
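Note: this is breaking for API clients. The request types drop their name/draft aliases here (and TemplateSwitchRequest below loses its name alias as well), so request bodies must use the canonical field names. Illustrative load payloads, with example model names:

    # Previously accepted via AliasChoices
    {"name": "MyModel-exl2", "draft": {"name": "MyDraft-exl2"}}

    # Required after this change
    {"model_name": "MyModel-exl2", "draft_model": {"draft_model_name": "MyDraft-exl2"}}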


@@ -1,4 +1,4 @@
-from pydantic import AliasChoices, BaseModel, Field
+from pydantic import BaseModel, Field
 from typing import List
@@ -12,7 +12,4 @@ class TemplateList(BaseModel):
 class TemplateSwitchRequest(BaseModel):
     """Request to switch a template."""

-    prompt_template_name: str = Field(
-        alias=AliasChoices("prompt_template_name", "name"),
-        description="Aliases: name",
-    )
+    prompt_template_name: str