Model: Auto detect model backend from config

* Use exllamav3 for exl3 models, exllamav2 otherwise
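
For reference, the detection keys off the quant_method field nested under quantization_config in the model's config.json. A minimal sketch of the rule, using an illustrative config dict rather than one read from a real model:

    # Sketch of the detection rule; the sample dict below is illustrative.
    sample_hf_config = {
        "model_type": "llama",
        "quantization_config": {"quant_method": "exl3"},
    }

    quant_method = sample_hf_config.get("quantization_config", {}).get("quant_method")
    backend_name = "exllamav3" if quant_method == "exl3" else "exllamav2"
    assert backend_name == "exllamav3"
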
DocShotgun 2025-05-06 18:51:58 -07:00
parent bc0a84241a
commit f8070e7707
6 changed files with 35 additions and 10 deletions


@@ -161,9 +161,7 @@ class ExllamaV3Container(BaseModelContainer):
             self.draft_model_dir = draft_model_path
             self.draft_config = Config.from_directory(str(draft_model_path.resolve()))
             self.draft_model = Model.from_config(self.draft_config)
-            logger.info(
-                f'Using draft model: {str(draft_model_path.resolve())}'
-            )
+            logger.info(f"Using draft model: {str(draft_model_path.resolve())}")
         else:
             self.draft_model = None
             self.draft_cache = None
@@ -223,7 +221,7 @@ class ExllamaV3Container(BaseModelContainer):
         # Draft cache
         if self.use_draft_model:
-            self.draft_cache = Cache(self.draft_model, max_num_tokens = self.cache_size)
+            self.draft_cache = Cache(self.draft_model, max_num_tokens=self.cache_size)

         # Max batch size
         self.max_batch_size = unwrap(kwargs.get("max_batch_size"), 256)


@@ -166,7 +166,7 @@ class ModelConfig(BaseConfigModel):
     backend: Optional[str] = Field(
         None,
         description=(
-            "Backend to use for this model (default: exllamav2)\n"
+            "Backend to use for this model (auto-detect if not specified)\n"
             "Options: exllamav2, exllamav3"
         ),
     )


@@ -17,6 +17,7 @@ from common.logger import get_loading_progress_bar
 from common.networking import handle_request_error
 from common.tabby_config import config
 from common.optional_dependencies import dependencies
+from common.transformers_utils import HuggingFaceConfig
 from common.utils import unwrap

 # Global variables for model container
@@ -123,8 +124,24 @@ async def load_model_gen(model_path: pathlib.Path, **kwargs):
     kwargs = {**config.model_defaults, **kwargs}
     kwargs = await apply_inline_overrides(model_path, **kwargs)

+    # Read config.json and detect the quant method
+    quant_method = None
+    hf_config_path = model_path / "config.json"
+    if hf_config_path.exists():
+        try:
+            hf_config = await HuggingFaceConfig.from_file(model_path)
+        except Exception as exc:
+            raise ValueError(
+                "Failed to read the model's config.json. "
+                f"Please check your model directory at {model_path}."
+            ) from exc
+
+        quant_method = hf_config.quant_method()
+
+    if quant_method == "exl3":
+        backend_name = "exllamav3"
+    else:
+        backend_name = "exllamav2"
+
     # Create a new container and check if the right dependencies are installed
-    backend_name = unwrap(kwargs.get("backend"), "exllamav2").lower()
+    backend_name = unwrap(kwargs.get("backend"), backend_name).lower()
     container_class = _BACKEND_REGISTRY.get(backend_name)

     if not container_class:
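
Note that an explicit backend setting still takes precedence: unwrap(value, default) is assumed here to return value when it is not None and default otherwise, so auto-detection only applies when backend is left unset. A small sketch of that precedence, with a local stand-in for common.utils.unwrap:

    # Local stand-in mirroring the assumed behavior of common.utils.unwrap:
    # return the value if set, otherwise fall back to the default.
    def unwrap(value, default):
        return value if value is not None else default

    detected = "exllamav3"  # result of config.json auto-detection
    assert unwrap(None, detected) == "exllamav3"         # backend unset
    assert unwrap("exllamav2", detected) == "exllamav2"  # explicit backend wins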


@@ -1,7 +1,7 @@
 import aiofiles
 import json
 import pathlib
-from typing import List, Optional, Union
+from typing import Dict, List, Optional, Union

 from pydantic import BaseModel
@@ -42,6 +42,8 @@ class HuggingFaceConfig(BaseModel):
     Will be expanded as needed.
     """

+    quantization_config: Optional[Dict] = None
+
     @classmethod
     async def from_file(cls, model_directory: pathlib.Path):
         """Create an instance from a generation config file."""
@@ -54,6 +56,14 @@
             hf_config_dict = json.loads(contents)
             return cls.model_validate(hf_config_dict)

+    def quant_method(self):
+        """Wrapper method to fetch quant type"""
+
+        if isinstance(self.quantization_config, Dict):
+            return self.quantization_config.get("quant_method")
+        else:
+            return None
+

 class TokenizerConfig(BaseModel):
     """


@@ -74,9 +74,9 @@ model:
   # Example: ['max_seq_len', 'cache_mode'].
   use_as_default: []

-  # Backend to use for the model (default: exllamav2)
+  # Backend to use for this model (auto-detect if not specified)
   # Options: exllamav2, exllamav3
-  backend: exllamav2
+  backend:

   # Max sequence length (default: Empty).
   # Fetched from the model's base sequence length in config.json by default.


@@ -84,7 +84,7 @@ class ChatCompletionRequest(CommonCompletionRequest):
     # Chat completions requests do not have a BOS token preference. Backend
     # respects the tokenization config for the individual model.
-    add_bos_token: Optional[bool] = Field(default = None)
+    add_bos_token: Optional[bool] = Field(default=None)

     @field_validator("add_bos_token", mode="after")
     def force_bos_token(cls, v):