Model: Fix inline loading and draft key (#225)

* Model: Fix inline loading and draft key

There was a lack of foresight between the new config.yml and how
it was structured. The "draft" key became "draft_model" without updating
both the API request and inline loading keys.

For the API requests, still support "draft" as legacy, but the "draft_model"
key is preferred.

Signed-off-by: kingbri <bdashore3@proton.me>

* OAI: Add draft model dir to inline load

Was not pushed before and caused errors of the kwargs being None.

Signed-off-by: kingbri <bdashore3@proton.me>

* Model: Fix draft args application

Draft model args weren't applying since there was a reset due to how
the old override behavior worked.

Signed-off-by: kingbri <bdashore3@proton.me>

* OAI: Change embedding model load params

Use embedding_model_name to be inline with the config.

Signed-off-by: kingbri <bdashore3@proton.me>

* API: Fix parameter for draft model load

Alias name to draft_model_name.

Signed-off-by: kingbri <bdashore3@proton.me>

* API: Fix parameter for template switch

Add prompt_template_name to be more descriptive.

Signed-off-by: kingbri <bdashore3@proton.me>

* API: Fix parameter for model load

Alias name to model_name for config parity.

Signed-off-by: kingbri <bdashore3@proton.me>

* API: Add alias documentation

Signed-off-by: kingbri <bdashore3@proton.me>

---------

Signed-off-by: kingbri <bdashore3@proton.me>
This commit is contained in:
Brian Dashore 2024-10-24 23:35:05 -04:00 committed by GitHub
parent f20857cb34
commit 6e48bb420a
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
7 changed files with 68 additions and 46 deletions

View file

@ -129,8 +129,27 @@ class ExllamaV2Container:
# Check if the model arch is compatible with various exl2 features
self.config.arch_compat_overrides()
# Create the hf_config
self.hf_config = await HuggingFaceConfig.from_file(model_directory)
# Load generation config overrides
generation_config_path = model_directory / "generation_config.json"
if generation_config_path.exists():
try:
self.generation_config = await GenerationConfig.from_file(
generation_config_path.parent
)
except Exception:
logger.error(traceback.format_exc())
logger.warning(
"Skipping generation config load because of an unexpected error."
)
# Apply a model's config overrides while respecting user settings
kwargs = await self.set_model_overrides(**kwargs)
# Prepare the draft model config if necessary
draft_args = unwrap(kwargs.get("draft"), {})
draft_args = unwrap(kwargs.get("draft_model"), {})
draft_model_name = draft_args.get("draft_model_name")
enable_draft = draft_args and draft_model_name
@ -154,25 +173,6 @@ class ExllamaV2Container:
self.draft_config.model_dir = str(draft_model_path.resolve())
self.draft_config.prepare()
# Create the hf_config
self.hf_config = await HuggingFaceConfig.from_file(model_directory)
# Load generation config overrides
generation_config_path = model_directory / "generation_config.json"
if generation_config_path.exists():
try:
self.generation_config = await GenerationConfig.from_file(
generation_config_path.parent
)
except Exception:
logger.error(traceback.format_exc())
logger.warning(
"Skipping generation config load because of an unexpected error."
)
# Apply a model's config overrides while respecting user settings
kwargs = await self.set_model_overrides(**kwargs)
# MARK: User configuration
# Get cache mode
@ -338,9 +338,6 @@ class ExllamaV2Container:
# Set user-configured draft model values
if enable_draft:
# Fetch from the updated kwargs
draft_args = unwrap(kwargs.get("draft"), {})
self.draft_config.max_seq_len = self.config.max_seq_len
self.draft_config.scale_pos_emb = unwrap(
@ -384,9 +381,12 @@ class ExllamaV2Container:
override_args = unwrap(yaml.load(contents), {})
# Merge draft overrides beforehand
draft_override_args = unwrap(override_args.get("draft"), {})
if self.draft_config and draft_override_args:
kwargs["draft"] = {**draft_override_args, **kwargs.get("draft")}
draft_override_args = unwrap(override_args.get("draft_model"), {})
if draft_override_args:
kwargs["draft_model"] = {
**draft_override_args,
**unwrap(kwargs.get("draft_model"), {}),
}
# Merge the override and model kwargs
merged_kwargs = {**override_args, **kwargs}

View file

@ -149,8 +149,11 @@ async def load_inline_model(model_name: str, request: Request):
return
# Load the model
await model.load_model(model_path)
# Load the model and also add draft dir
await model.load_model(
model_path,
draft_model=config.draft_model.model_dump(include={"draft_model_dir"}),
)
async def stream_generate_completion(

View file

@ -123,7 +123,7 @@ async def load_model(data: ModelLoadRequest) -> ModelLoadResponse:
"""Loads a model into the model container. This returns an SSE stream."""
# Verify request parameters
if not data.name:
if not data.model_name:
error_message = handle_request_error(
"A model name was not provided for load.",
exc_info=False,
@ -132,11 +132,11 @@ async def load_model(data: ModelLoadRequest) -> ModelLoadResponse:
raise HTTPException(400, error_message)
model_path = pathlib.Path(config.model.model_dir)
model_path = model_path / data.name
model_path = model_path / data.model_name
draft_model_path = None
if data.draft:
if not data.draft.draft_model_name:
if data.draft_model:
if not data.draft_model.draft_model_name:
error_message = handle_request_error(
"Could not find the draft model name for model load.",
exc_info=False,
@ -301,7 +301,7 @@ async def load_embedding_model(
request: Request, data: EmbeddingModelLoadRequest
) -> ModelLoadResponse:
# Verify request parameters
if not data.name:
if not data.embedding_model_name:
error_message = handle_request_error(
"A model name was not provided for load.",
exc_info=False,
@ -310,7 +310,7 @@ async def load_embedding_model(
raise HTTPException(400, error_message)
embedding_model_dir = pathlib.Path(config.embeddings.embedding_model_dir)
embedding_model_path = embedding_model_dir / data.name
embedding_model_path = embedding_model_dir / data.embedding_model_name
if not embedding_model_path.exists():
error_message = handle_request_error(
@ -441,7 +441,7 @@ async def list_templates(request: Request) -> TemplateList:
async def switch_template(data: TemplateSwitchRequest):
"""Switch the currently loaded template."""
if not data.name:
if not data.prompt_template_name:
error_message = handle_request_error(
"New template name not found.",
exc_info=False,
@ -450,11 +450,12 @@ async def switch_template(data: TemplateSwitchRequest):
raise HTTPException(400, error_message)
try:
template_path = pathlib.Path("templates") / data.name
template_path = pathlib.Path("templates") / data.prompt_template_name
model.container.prompt_template = await PromptTemplate.from_file(template_path)
except FileNotFoundError as e:
error_message = handle_request_error(
f"The template name {data.name} doesn't exist. Check the spelling?",
f"The template name {data.prompt_template_name} doesn't exist. "
+ "Check the spelling?",
exc_info=False,
).error.message

View file

@ -1,6 +1,6 @@
"""Contains model card types."""
from pydantic import BaseModel, Field, ConfigDict
from pydantic import AliasChoices, BaseModel, Field, ConfigDict
from time import time
from typing import List, Literal, Optional, Union
@ -48,7 +48,10 @@ class DraftModelLoadRequest(BaseModel):
"""Represents a draft model load request."""
# Required
draft_model_name: str
draft_model_name: str = Field(
alias=AliasChoices("draft_model_name", "name"),
description="Aliases: name",
)
# Config arguments
draft_rope_scale: Optional[float] = None
@ -63,8 +66,14 @@ class DraftModelLoadRequest(BaseModel):
class ModelLoadRequest(BaseModel):
"""Represents a model load request."""
# Avoids pydantic namespace warning
model_config = ConfigDict(protected_namespaces=[])
# Required
name: str
model_name: str = Field(
alias=AliasChoices("model_name", "name"),
description="Aliases: name",
)
# Config arguments
@ -108,12 +117,18 @@ class ModelLoadRequest(BaseModel):
num_experts_per_token: Optional[int] = None
# Non-config arguments
draft: Optional[DraftModelLoadRequest] = None
draft_model: Optional[DraftModelLoadRequest] = Field(
default=None,
alias=AliasChoices("draft_model", "draft"),
)
skip_queue: Optional[bool] = False
class EmbeddingModelLoadRequest(BaseModel):
name: str
embedding_model_name: str = Field(
alias=AliasChoices("embedding_model_name", "name"),
description="Aliases: name",
)
# Set default from the config
embeddings_device: Optional[str] = Field(config.embeddings.embeddings_device)

View file

@ -1,4 +1,4 @@
from pydantic import BaseModel, Field
from pydantic import AliasChoices, BaseModel, Field
from typing import List
@ -12,4 +12,7 @@ class TemplateList(BaseModel):
class TemplateSwitchRequest(BaseModel):
"""Request to switch a template."""
name: str
prompt_template_name: str = Field(
alias=AliasChoices("prompt_template_name", "name"),
description="Aliases: name",
)

View file

@ -104,7 +104,7 @@ async def stream_model_load(
# Set the draft model path if it exists
if draft_model_path:
load_data["draft"]["draft_model_dir"] = draft_model_path
load_data["draft_model"]["draft_model_dir"] = draft_model_path
load_status = model.load_model_gen(
model_path, skip_wait=data.skip_queue, **load_data

View file

@ -70,7 +70,7 @@ async def entrypoint_async():
await model.load_model(
model_path.resolve(),
**config.model.model_dump(exclude_none=True),
draft=config.draft_model.model_dump(exclude_none=True),
draft_model=config.draft_model.model_dump(exclude_none=True),
)
# Load loras after loading the model