API: Don't fallback to default values on model load request
It's best to pass these values down the config stack: API/user config.yml -> model config.yml -> model config.json -> fallback. Doing this allows for a seamless flow and yields control to each member of the stack.

Signed-off-by: kingbri <bdashore3@proton.me>
parent 4452d6f665
commit a96fa5f138

4 changed files with 19 additions and 17 deletions
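Roughly, the resolution order the message describes looks like the sketch below. The helper name and the per-layer dicts are illustrative only, not code from this commit: the point is that each layer yields to the one above it, and the hard-coded fallback only applies at the very bottom.

    from typing import Any, Optional


    def resolve_load_value(
        key: str,
        api_request: dict,
        user_config: dict,
        model_config_yml: dict,
        model_config_json: dict,
        fallback: Any = None,
    ) -> Optional[Any]:
        """Walk the stack: API request -> user config.yml -> model config.yml
        -> model config.json -> fallback. The first layer that explicitly sets
        the key wins."""
        for layer in (api_request, user_config, model_config_yml, model_config_json):
            value = layer.get(key)
            if value is not None:
                return value
        return fallback


    # e.g. resolve_load_value("chunk_size", {}, {"chunk_size": 4096}, {}, {}) -> 4096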
@@ -361,6 +361,8 @@ class ExllamaV2Container:
         self.draft_config.max_attention_size = chunk_size**2
 
     def set_model_overrides(self, **kwargs):
         """Sets overrides from a model folder's config yaml."""
 
         override_config_path = self.model_dir / "tabby_config.yml"
 
         if not override_config_path.exists():
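The set_model_overrides context above is the "model config.yml" layer of the stack: a tabby_config.yml sitting next to the model weights. A rough sketch of that layer, assuming a plain YAML mapping of load keys; the function name, the yaml handling, and the merge order here are illustrative, not the project's exact code:

    import pathlib

    import yaml


    def load_model_overrides(model_dir: pathlib.Path, **kwargs) -> dict:
        """Merge overrides from a model folder's tabby_config.yml under the
        caller-supplied kwargs, so explicit load arguments still win."""
        override_config_path = model_dir / "tabby_config.yml"
        if not override_config_path.exists():
            return kwargs

        with open(override_config_path, "r", encoding="utf8") as f:
            overrides = yaml.safe_load(f) or {}

        # kwargs (API/user-config values) take priority over the model folder's yaml
        return {**overrides, **kwargs}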
@@ -149,7 +149,7 @@ async def unload_embedding_model():
     embeddings_container = None
 
 
-def get_config_default(key: str, fallback=None, model_type: str = "model"):
+def get_config_default(key: str, model_type: str = "model"):
     """Fetches a default value from model config if allowed by the user."""
 
     model_config = config.model_config()
@@ -162,14 +162,12 @@ def get_config_default(key: str, fallback=None, model_type: str = "model"):
         # Is this a draft model load parameter?
         if model_type == "draft":
             draft_config = config.draft_model_config()
-            return unwrap(draft_config.get(key), fallback)
+            return draft_config.get(key)
         elif model_type == "embedding":
             embeddings_config = config.embeddings_config()
-            return unwrap(embeddings_config.get(key), fallback)
+            return embeddings_config.get(key)
         else:
-            return unwrap(model_config.get(key), fallback)
-    else:
-        return fallback
+            return model_config.get(key)
 
 
 async def check_model_container():
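With the fallback argument gone, get_config_default returns whatever the user config holds for the key, or None when nothing is set, and the caller decides what None means. A hypothetical caller illustrating the difference (the import path is an assumption for illustration):

    from common.model import get_config_default  # assumed import path, illustration only

    # None now simply means "the user config doesn't pin this key",
    # so resolution can continue further down the stack.
    chunk_size = get_config_default("chunk_size")
    if chunk_size is None:
        chunk_size = 2048  # the hard-coded fallback now lives at the bottom of the stack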
@@ -53,19 +53,19 @@ class DraftModelLoadRequest(BaseModel):
     # Config arguments
     draft_rope_scale: Optional[float] = Field(
         default_factory=lambda: get_config_default(
-            "draft_rope_scale", 1.0, model_type="draft"
+            "draft_rope_scale", model_type="draft"
         )
     )
     draft_rope_alpha: Optional[float] = Field(
         description="Automatically calculated if not present",
         default_factory=lambda: get_config_default(
-            "draft_rope_alpha", None, model_type="draft"
+            "draft_rope_alpha", model_type="draft"
         ),
         examples=[1.0],
     )
     draft_cache_mode: Optional[str] = Field(
         default_factory=lambda: get_config_default(
-            "draft_cache_mode", "FP16", model_type="draft"
+            "draft_cache_mode", model_type="draft"
         )
     )
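These request fields use Pydantic's default_factory, so a field the client omits is filled from the user config at request time, and now stays None when the config doesn't pin it. A stripped-down sketch of the pattern; the tiny config dict and class names here are made up for illustration:

    from typing import Optional

    from pydantic import BaseModel, Field

    _user_config = {"draft_rope_scale": 1.0}  # stand-in for the parsed config.yml


    def config_default(key: str):
        # Mirrors the new get_config_default: no baked-in fallback, just the user's value or None
        return _user_config.get(key)


    class DraftLoadSketch(BaseModel):
        draft_rope_scale: Optional[float] = Field(
            default_factory=lambda: config_default("draft_rope_scale")
        )
        draft_rope_alpha: Optional[float] = Field(
            default_factory=lambda: config_default("draft_rope_alpha")
        )


    print(DraftLoadSketch())  # draft_rope_scale=1.0 draft_rope_alpha=None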
@@ -97,16 +97,16 @@ class ModelLoadRequest(BaseModel):
         examples=[4096],
     )
     tensor_parallel: Optional[bool] = Field(
-        default_factory=lambda: get_config_default("tensor_parallel", False)
+        default_factory=lambda: get_config_default("tensor_parallel")
     )
     gpu_split_auto: Optional[bool] = Field(
-        default_factory=lambda: get_config_default("gpu_split_auto", True)
+        default_factory=lambda: get_config_default("gpu_split_auto")
     )
     autosplit_reserve: Optional[List[float]] = Field(
-        default_factory=lambda: get_config_default("autosplit_reserve", [96])
+        default_factory=lambda: get_config_default("autosplit_reserve")
     )
     gpu_split: Optional[List[float]] = Field(
-        default_factory=lambda: get_config_default("gpu_split", []),
+        default_factory=lambda: get_config_default("gpu_split"),
         examples=[[24.0, 20.0]],
     )
     rope_scale: Optional[float] = Field(
@@ -120,10 +120,10 @@ class ModelLoadRequest(BaseModel):
         examples=[1.0],
     )
     cache_mode: Optional[str] = Field(
-        default_factory=lambda: get_config_default("cache_mode", "FP16")
+        default_factory=lambda: get_config_default("cache_mode")
     )
     chunk_size: Optional[int] = Field(
-        default_factory=lambda: get_config_default("chunk_size", 2048)
+        default_factory=lambda: get_config_default("chunk_size")
     )
     prompt_template: Optional[str] = Field(
         default_factory=lambda: get_config_default("prompt_template")
@@ -132,7 +132,7 @@ class ModelLoadRequest(BaseModel):
         default_factory=lambda: get_config_default("num_experts_per_token")
     )
     fasttensors: Optional[bool] = Field(
-        default_factory=lambda: get_config_default("fasttensors", False)
+        default_factory=lambda: get_config_default("fasttensors")
     )
 
     # Non-config arguments
@@ -95,8 +95,10 @@ async def stream_model_load(
 ):
     """Request generation wrapper for the loading process."""
 
-    # Get trimmed load data
-    load_data = data.model_dump(exclude_none=True)
+    load_data = data.model_dump()
 
+    # Set the draft model path if it exists
+    if draft_model_path:
+        load_data["draft"]["draft_model_dir"] = draft_model_path
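Dropping exclude_none=True matters here: the dumped dict now carries None for every field the client left out, so the layers below can see which keys were omitted and apply their own defaults. A small illustration of the difference, using made-up field names:

    from typing import Optional

    from pydantic import BaseModel


    class LoadSketch(BaseModel):
        chunk_size: Optional[int] = None
        cache_mode: Optional[str] = None


    data = LoadSketch(chunk_size=4096)

    print(data.model_dump(exclude_none=True))  # {'chunk_size': 4096} - cache_mode vanishes
    print(data.model_dump())                   # {'chunk_size': 4096, 'cache_mode': None}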