Model: Fix inline loading and draft key (#225)
* Model: Fix inline loading and draft key

  The new config.yml restructure wasn't coordinated with the code that reads it: the "draft" key became "draft_model" without updating both the API request and inline loading keys. For API requests, "draft" is still supported as a legacy alias, but "draft_model" is preferred.

  Signed-off-by: kingbri <bdashore3@proton.me>

* OAI: Add draft model dir to inline load

  This wasn't pushed before and caused errors because the kwargs were None.

  Signed-off-by: kingbri <bdashore3@proton.me>

* Model: Fix draft args application

  Draft model args weren't applying because the old override behavior reset them.

  Signed-off-by: kingbri <bdashore3@proton.me>

* OAI: Change embedding model load params

  Use embedding_model_name to be in line with the config.

  Signed-off-by: kingbri <bdashore3@proton.me>

* API: Fix parameter for draft model load

  Alias name to draft_model_name.

  Signed-off-by: kingbri <bdashore3@proton.me>

* API: Fix parameter for template switch

  Add prompt_template_name to be more descriptive.

  Signed-off-by: kingbri <bdashore3@proton.me>

* API: Fix parameter for model load

  Alias name to model_name for config parity.

  Signed-off-by: kingbri <bdashore3@proton.me>

* API: Add alias documentation

  Signed-off-by: kingbri <bdashore3@proton.me>

---------

Signed-off-by: kingbri <bdashore3@proton.me>
This commit is contained in:
parent f20857cb34 · commit 6e48bb420a
7 changed files with 68 additions and 46 deletions
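Note: the rename and legacy-key support below lean on pydantic's alias machinery. A minimal sketch of the pattern, assuming pydantic v2 (the diff itself passes AliasChoices to Field's alias parameter; this sketch uses the documented validation_alias instead, and the model name "llama-13b" is hypothetical):

```python
from pydantic import AliasChoices, BaseModel, Field


class ModelLoadRequest(BaseModel):
    # Prefer the config-parity key "model_name", but keep accepting the
    # legacy "name" key from older API clients.
    model_name: str = Field(
        validation_alias=AliasChoices("model_name", "name"),
        description="Aliases: name",
    )


# Both payloads validate to the same field:
print(ModelLoadRequest.model_validate({"model_name": "llama-13b"}).model_name)
print(ModelLoadRequest.model_validate({"name": "llama-13b"}).model_name)
```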
@@ -129,8 +129,27 @@ class ExllamaV2Container:
         # Check if the model arch is compatible with various exl2 features
         self.config.arch_compat_overrides()
 
+        # Create the hf_config
+        self.hf_config = await HuggingFaceConfig.from_file(model_directory)
+
+        # Load generation config overrides
+        generation_config_path = model_directory / "generation_config.json"
+        if generation_config_path.exists():
+            try:
+                self.generation_config = await GenerationConfig.from_file(
+                    generation_config_path.parent
+                )
+            except Exception:
+                logger.error(traceback.format_exc())
+                logger.warning(
+                    "Skipping generation config load because of an unexpected error."
+                )
+
+        # Apply a model's config overrides while respecting user settings
+        kwargs = await self.set_model_overrides(**kwargs)
+
         # Prepare the draft model config if necessary
-        draft_args = unwrap(kwargs.get("draft"), {})
+        draft_args = unwrap(kwargs.get("draft_model"), {})
         draft_model_name = draft_args.get("draft_model_name")
         enable_draft = draft_args and draft_model_name
@@ -154,25 +173,6 @@ class ExllamaV2Container:
             self.draft_config.model_dir = str(draft_model_path.resolve())
             self.draft_config.prepare()
 
-        # Create the hf_config
-        self.hf_config = await HuggingFaceConfig.from_file(model_directory)
-
-        # Load generation config overrides
-        generation_config_path = model_directory / "generation_config.json"
-        if generation_config_path.exists():
-            try:
-                self.generation_config = await GenerationConfig.from_file(
-                    generation_config_path.parent
-                )
-            except Exception:
-                logger.error(traceback.format_exc())
-                logger.warning(
-                    "Skipping generation config load because of an unexpected error."
-                )
-
-        # Apply a model's config overrides while respecting user settings
-        kwargs = await self.set_model_overrides(**kwargs)
-
         # MARK: User configuration
 
         # Get cache mode
@@ -338,9 +338,6 @@ class ExllamaV2Container:
 
         # Set user-configured draft model values
         if enable_draft:
-            # Fetch from the updated kwargs
-            draft_args = unwrap(kwargs.get("draft"), {})
-
             self.draft_config.max_seq_len = self.config.max_seq_len
 
             self.draft_config.scale_pos_emb = unwrap(
@@ -384,9 +381,12 @@ class ExllamaV2Container:
             override_args = unwrap(yaml.load(contents), {})
 
             # Merge draft overrides beforehand
-            draft_override_args = unwrap(override_args.get("draft"), {})
-            if self.draft_config and draft_override_args:
-                kwargs["draft"] = {**draft_override_args, **kwargs.get("draft")}
+            draft_override_args = unwrap(override_args.get("draft_model"), {})
+            if draft_override_args:
+                kwargs["draft_model"] = {
+                    **draft_override_args,
+                    **unwrap(kwargs.get("draft_model"), {}),
+                }
 
             # Merge the override and model kwargs
             merged_kwargs = {**override_args, **kwargs}
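The dict merge above determines precedence: keys spread later win, so user-supplied kwargs override values from the model's override file. A toy illustration with made-up values:

```python
# Values pulled from a model's override file (hypothetical).
draft_override_args = {"draft_model_name": "tiny-draft", "draft_rope_scale": 1.0}
# Values the user passed at load time (hypothetical).
kwargs = {"draft_model": {"draft_rope_scale": 2.0}}

# Later spreads overwrite earlier ones, so the user's 2.0 survives.
kwargs["draft_model"] = {
    **draft_override_args,
    **kwargs.get("draft_model", {}),
}
print(kwargs["draft_model"])
# {'draft_model_name': 'tiny-draft', 'draft_rope_scale': 2.0}
```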
@@ -149,8 +149,11 @@ async def load_inline_model(model_name: str, request: Request):
         return
 
-    # Load the model
-    await model.load_model(model_path)
+    # Load the model and also add draft dir
+    await model.load_model(
+        model_path,
+        draft_model=config.draft_model.model_dump(include={"draft_model_dir"}),
+    )
 
 
 async def stream_generate_completion(
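model_dump(include={...}) is what keeps the inline load lightweight: only the whitelisted field is forwarded. A sketch assuming pydantic v2, with a stand-in for the real draft config class:

```python
from typing import Optional

from pydantic import BaseModel


class DraftModelConfig(BaseModel):
    # Stand-in for the real config class; field names match the diff.
    draft_model_dir: str = "models"
    draft_model_name: Optional[str] = None


config = DraftModelConfig()
# Only the directory is forwarded, so the inline loader learns where to
# search without being forced onto a specific draft model.
print(config.model_dump(include={"draft_model_dir"}))  # {'draft_model_dir': 'models'}
```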
@@ -123,7 +123,7 @@ async def load_model(data: ModelLoadRequest) -> ModelLoadResponse:
     """Loads a model into the model container. This returns an SSE stream."""
 
     # Verify request parameters
-    if not data.name:
+    if not data.model_name:
         error_message = handle_request_error(
             "A model name was not provided for load.",
             exc_info=False,
@@ -132,11 +132,11 @@ async def load_model(data: ModelLoadRequest) -> ModelLoadResponse:
         raise HTTPException(400, error_message)
 
     model_path = pathlib.Path(config.model.model_dir)
-    model_path = model_path / data.name
+    model_path = model_path / data.model_name
 
     draft_model_path = None
-    if data.draft:
-        if not data.draft.draft_model_name:
+    if data.draft_model:
+        if not data.draft_model.draft_model_name:
             error_message = handle_request_error(
                 "Could not find the draft model name for model load.",
                 exc_info=False,
@@ -301,7 +301,7 @@ async def load_embedding_model(
     request: Request, data: EmbeddingModelLoadRequest
 ) -> ModelLoadResponse:
     # Verify request parameters
-    if not data.name:
+    if not data.embedding_model_name:
         error_message = handle_request_error(
             "A model name was not provided for load.",
             exc_info=False,
@@ -310,7 +310,7 @@ async def load_embedding_model(
         raise HTTPException(400, error_message)
 
     embedding_model_dir = pathlib.Path(config.embeddings.embedding_model_dir)
-    embedding_model_path = embedding_model_dir / data.name
+    embedding_model_path = embedding_model_dir / data.embedding_model_name
 
     if not embedding_model_path.exists():
         error_message = handle_request_error(
@@ -441,7 +441,7 @@ async def list_templates(request: Request) -> TemplateList:
 async def switch_template(data: TemplateSwitchRequest):
     """Switch the currently loaded template."""
 
-    if not data.name:
+    if not data.prompt_template_name:
         error_message = handle_request_error(
             "New template name not found.",
             exc_info=False,
@@ -450,11 +450,12 @@ async def switch_template(data: TemplateSwitchRequest):
         raise HTTPException(400, error_message)
 
     try:
-        template_path = pathlib.Path("templates") / data.name
+        template_path = pathlib.Path("templates") / data.prompt_template_name
         model.container.prompt_template = await PromptTemplate.from_file(template_path)
     except FileNotFoundError as e:
         error_message = handle_request_error(
-            f"The template name {data.name} doesn't exist. Check the spelling?",
+            f"The template name {data.prompt_template_name} doesn't exist. "
+            + "Check the spelling?",
             exc_info=False,
         ).error.message
@@ -1,6 +1,6 @@
 """Contains model card types."""
 
-from pydantic import BaseModel, Field, ConfigDict
+from pydantic import AliasChoices, BaseModel, Field, ConfigDict
 from time import time
 from typing import List, Literal, Optional, Union
 
@@ -48,7 +48,10 @@ class DraftModelLoadRequest(BaseModel):
     """Represents a draft model load request."""
 
     # Required
-    draft_model_name: str
+    draft_model_name: str = Field(
+        alias=AliasChoices("draft_model_name", "name"),
+        description="Aliases: name",
+    )
 
     # Config arguments
     draft_rope_scale: Optional[float] = None
@@ -63,8 +66,14 @@ class DraftModelLoadRequest(BaseModel):
 class ModelLoadRequest(BaseModel):
     """Represents a model load request."""
 
+    # Avoids pydantic namespace warning
+    model_config = ConfigDict(protected_namespaces=[])
+
     # Required
-    name: str
+    model_name: str = Field(
+        alias=AliasChoices("model_name", "name"),
+        description="Aliases: name",
+    )
 
     # Config arguments
@@ -108,12 +117,18 @@ class ModelLoadRequest(BaseModel):
     num_experts_per_token: Optional[int] = None
 
     # Non-config arguments
-    draft: Optional[DraftModelLoadRequest] = None
+    draft_model: Optional[DraftModelLoadRequest] = Field(
+        default=None,
+        alias=AliasChoices("draft_model", "draft"),
+    )
     skip_queue: Optional[bool] = False
 
 
 class EmbeddingModelLoadRequest(BaseModel):
-    name: str
+    embedding_model_name: str = Field(
+        alias=AliasChoices("embedding_model_name", "name"),
+        description="Aliases: name",
+    )
 
     # Set default from the config
     embeddings_device: Optional[str] = Field(config.embeddings.embeddings_device)
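Taken together, the request models above accept both the legacy and the preferred payload shapes. A runnable sketch of that behavior (again using validation_alias rather than the diff's alias=, with hypothetical model names):

```python
from typing import Optional

from pydantic import AliasChoices, BaseModel, Field


class DraftModelLoadRequest(BaseModel):
    draft_model_name: str = Field(
        validation_alias=AliasChoices("draft_model_name", "name"),
    )


class ModelLoadRequest(BaseModel):
    model_name: str = Field(validation_alias=AliasChoices("model_name", "name"))
    draft_model: Optional[DraftModelLoadRequest] = Field(
        default=None,
        validation_alias=AliasChoices("draft_model", "draft"),
    )


# A legacy payload and the preferred payload parse to the same request.
legacy = {"name": "llama-13b", "draft": {"name": "tiny-draft"}}
preferred = {
    "model_name": "llama-13b",
    "draft_model": {"draft_model_name": "tiny-draft"},
}
assert ModelLoadRequest.model_validate(legacy) == ModelLoadRequest.model_validate(
    preferred
)
```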
@@ -1,4 +1,4 @@
-from pydantic import BaseModel, Field
+from pydantic import AliasChoices, BaseModel, Field
 from typing import List
 
 
@@ -12,4 +12,7 @@ class TemplateList(BaseModel):
 class TemplateSwitchRequest(BaseModel):
     """Request to switch a template."""
 
-    name: str
+    prompt_template_name: str = Field(
+        alias=AliasChoices("prompt_template_name", "name"),
+        description="Aliases: name",
+    )
@@ -104,7 +104,7 @@ async def stream_model_load(
 
     # Set the draft model path if it exists
     if draft_model_path:
-        load_data["draft"]["draft_model_dir"] = draft_model_path
+        load_data["draft_model"]["draft_model_dir"] = draft_model_path
 
     load_status = model.load_model_gen(
         model_path, skip_wait=data.skip_queue, **load_data
main.py (2 changes)
@@ -70,7 +70,7 @@ async def entrypoint_async():
     await model.load_model(
         model_path.resolve(),
         **config.model.model_dump(exclude_none=True),
-        draft=config.draft_model.model_dump(exclude_none=True),
+        draft_model=config.draft_model.model_dump(exclude_none=True),
     )
 
     # Load loras after loading the model
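One reason the entrypoint dumps the config with exclude_none=True: unset keys drop out of the kwargs instead of clobbering loader defaults with None. A small sketch with a stand-in config class:

```python
from typing import Optional

from pydantic import BaseModel


class ModelConfig(BaseModel):
    # Stand-in for the real model config; two illustrative fields only.
    max_seq_len: Optional[int] = None
    cache_mode: str = "FP16"


# max_seq_len is omitted entirely rather than passed as None, so the
# loader's own default for it still applies.
print(ModelConfig().model_dump(exclude_none=True))  # {'cache_mode': 'FP16'}
```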