Model: Auto-detect model backend from config

* Use exllamav3 for exl3 models, exllamav2 otherwise
DocShotgun 2025-05-06 18:51:58 -07:00
parent bc0a84241a
commit f8070e7707
6 changed files with 35 additions and 10 deletions
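
For context: the auto-detection keys on the "quant_method" field inside "quantization_config" in the model's config.json. A minimal sketch of the marker it looks for, assuming a hypothetical EXL3 quant (every field other than quant_method is illustrative, not taken from a real model):

    import json

    # Only quantization_config.quant_method matters for backend selection;
    # any other keys in the file are ignored by this commit's detection.
    example_config = json.loads("""
    {
        "architectures": ["LlamaForCausalLM"],
        "quantization_config": {"quant_method": "exl3"}
    }
    """)
    assert example_config["quantization_config"]["quant_method"] == "exl3"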

View file

@@ -161,9 +161,7 @@ class ExllamaV3Container(BaseModelContainer):
             self.draft_model_dir = draft_model_path
             self.draft_config = Config.from_directory(str(draft_model_path.resolve()))
             self.draft_model = Model.from_config(self.draft_config)
-            logger.info(
-                f'Using draft model: {str(draft_model_path.resolve())}'
-            )
+            logger.info(f"Using draft model: {str(draft_model_path.resolve())}")
         else:
             self.draft_model = None
             self.draft_cache = None
@@ -223,7 +221,7 @@ class ExllamaV3Container(BaseModelContainer):
         # Draft cache
         if self.use_draft_model:
-            self.draft_cache = Cache(self.draft_model, max_num_tokens = self.cache_size)
+            self.draft_cache = Cache(self.draft_model, max_num_tokens=self.cache_size)

         # Max batch size
         self.max_batch_size = unwrap(kwargs.get("max_batch_size"), 256)

View file

@@ -166,7 +166,7 @@ class ModelConfig(BaseConfigModel):
     backend: Optional[str] = Field(
         None,
         description=(
-            "Backend to use for this model (default: exllamav2)\n"
+            "Backend to use for this model (auto-detect if not specified)\n"
             "Options: exllamav2, exllamav3"
         ),
     )

View file

@@ -17,6 +17,7 @@ from common.logger import get_loading_progress_bar
 from common.networking import handle_request_error
 from common.tabby_config import config
 from common.optional_dependencies import dependencies
+from common.transformers_utils import HuggingFaceConfig
 from common.utils import unwrap

 # Global variables for model container
@@ -123,8 +124,24 @@ async def load_model_gen(model_path: pathlib.Path, **kwargs):
     kwargs = {**config.model_defaults, **kwargs}
     kwargs = await apply_inline_overrides(model_path, **kwargs)

+    # Read config.json and detect the quant method
+    quant_method = None
+    hf_config_path = model_path / "config.json"
+    if hf_config_path.exists():
+        try:
+            hf_config = await HuggingFaceConfig.from_file(model_path)
+        except Exception as exc:
+            raise ValueError(
+                "Failed to read the model's config.json. "
+                f"Please check your model directory at {model_path}."
+            ) from exc
+
+        quant_method = hf_config.quant_method()
+
+    if quant_method == "exl3":
+        backend_name = "exllamav3"
+    else:
+        backend_name = "exllamav2"
+
     # Create a new container and check if the right dependencies are installed
-    backend_name = unwrap(kwargs.get("backend"), "exllamav2").lower()
+    backend_name = unwrap(kwargs.get("backend"), backend_name).lower()
     container_class = _BACKEND_REGISTRY.get(backend_name)
     if not container_class:
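
The selection order above, as a standalone sketch: an explicit backend from the request or config wins, and detection only fills the default. unwrap here is a stand-in with the assumed behavior of common.utils.unwrap (return the value unless it is None); select_backend is a hypothetical helper, not part of the commit:

    def unwrap(value, default):
        """Assumed behavior of common.utils.unwrap: prefer an explicit value."""
        return value if value is not None else default

    def select_backend(requested, quant_method):
        """Explicit backend config wins; otherwise fall back to quant detection."""
        detected = "exllamav3" if quant_method == "exl3" else "exllamav2"
        return unwrap(requested, detected).lower()

    assert select_backend(None, "exl3") == "exllamav3"         # auto-detected EXL3
    assert select_backend(None, None) == "exllamav2"           # default backend
    assert select_backend("exllamav2", "exl3") == "exllamav2"  # user override wins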

View file

@@ -1,7 +1,7 @@
 import aiofiles
 import json
 import pathlib
-from typing import List, Optional, Union
+from typing import Dict, List, Optional, Union

 from pydantic import BaseModel
@@ -42,6 +42,8 @@ class HuggingFaceConfig(BaseModel):
     Will be expanded as needed.
     """

+    quantization_config: Optional[Dict] = None
+
     @classmethod
     async def from_file(cls, model_directory: pathlib.Path):
         """Create an instance from a generation config file."""
@@ -54,6 +56,14 @@
         hf_config_dict = json.loads(contents)
         return cls.model_validate(hf_config_dict)

+    def quant_method(self):
+        """Wrapper method to fetch quant type"""
+
+        if isinstance(self.quantization_config, Dict):
+            return self.quantization_config.get("quant_method")
+        else:
+            return None
+

 class TokenizerConfig(BaseModel):
     """

View file

@@ -74,9 +74,9 @@ model:
   # Example: ['max_seq_len', 'cache_mode'].
   use_as_default: []

-  # Backend to use for the model (default: exllamav2)
+  # Backend to use for this model (auto-detect if not specified)
   # Options: exllamav2, exllamav3
-  backend: exllamav2
+  backend:

   # Max sequence length (default: Empty).
   # Fetched from the model's base sequence length in config.json by default.

View file

@@ -84,7 +84,7 @@ class ChatCompletionRequest(CommonCompletionRequest):
     # Chat completions requests do not have a BOS token preference. Backend
     # respects the tokenization config for the individual model.
-    add_bos_token: Optional[bool] = Field(default = None)
+    add_bos_token: Optional[bool] = Field(default=None)

     @field_validator("add_bos_token", mode="after")
     def force_bos_token(cls, v):