Merge pull request #336 from DocShotgun/backend-detect

Automatically select model backend based on config.json
Brian 2025-05-09 01:56:44 -04:00 committed by GitHub
commit f26ca23f1a
4 changed files with 38 additions and 6 deletions

@@ -168,7 +168,7 @@ class ModelConfig(BaseConfigModel):
     backend: Optional[str] = Field(
         None,
         description=(
-            "Backend to use for this model (default: exllamav2)\n"
+            "Backend to use for this model (auto-detect if not specified)\n"
             "Options: exllamav2, exllamav3"
         ),
     )

@@ -17,6 +17,7 @@ from common.logger import get_loading_progress_bar
 from common.networking import handle_request_error
 from common.tabby_config import config
 from common.optional_dependencies import dependencies
+from common.transformers_utils import HuggingFaceConfig
 from common.utils import unwrap
 
 # Global variables for model container
@@ -56,6 +57,24 @@ def load_progress(module, modules):
     yield module, modules
 
 
+async def detect_backend(model_path: pathlib.Path) -> str:
+    """Determine the appropriate backend based on model files and configuration."""
+
+    try:
+        hf_config = await HuggingFaceConfig.from_directory(model_path)
+        quant_method = hf_config.quant_method()
+
+        if quant_method == "exl3":
+            return "exllamav3"
+        else:
+            return "exllamav2"
+    except Exception as exc:
+        raise ValueError(
+            "Failed to read the model's config.json. "
+            f"Please check your model directory at {model_path}."
+        ) from exc
+
+
 async def apply_inline_overrides(model_dir: pathlib.Path, **kwargs):
     """Sets overrides from a model folder's config yaml."""
@@ -124,7 +143,9 @@ async def load_model_gen(model_path: pathlib.Path, **kwargs):
     kwargs = await apply_inline_overrides(model_path, **kwargs)
 
     # Create a new container and check if the right dependencies are installed
-    backend_name = unwrap(kwargs.get("backend"), "exllamav2").lower()
+    backend_name = unwrap(
+        kwargs.get("backend"), await detect_backend(model_path)
+    ).lower()
     container_class = _BACKEND_REGISTRY.get(backend_name)
 
     if not container_class:
@@ -140,6 +161,7 @@ async def load_model_gen(model_path: pathlib.Path, **kwargs):
             f"Available backends: {available_backends}"
         )
 
+    logger.info(f"Using backend {backend_name}")
     new_container: BaseModelContainer = await container_class.create(
         model_path.resolve(), **kwargs
     )
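Resolution order, as a sketch (the unwrap semantics below are assumed from its call sites here, not quoted from common/utils):

def unwrap(value, default):
    # Assumed behavior of common.utils.unwrap: keep an explicit value,
    # fall back to the default only when the value is None.
    return value if value is not None else default

# An explicit `backend` kwarg still takes precedence. Note that Python
# evaluates arguments eagerly, so detect_backend() reads config.json even
# when a backend is specified; unwrap simply discards its result then.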

@@ -1,7 +1,7 @@
 import aiofiles
 import json
 import pathlib
-from typing import List, Optional, Union
+from typing import Dict, List, Optional, Union
 
 from pydantic import BaseModel
 
@@ -42,8 +42,10 @@ class HuggingFaceConfig(BaseModel):
     Will be expanded as needed.
     """
 
+    quantization_config: Optional[Dict] = None
+
     @classmethod
-    async def from_file(cls, model_directory: pathlib.Path):
+    async def from_directory(cls, model_directory: pathlib.Path):
         """Create an instance from a generation config file."""
 
         hf_config_path = model_directory / "config.json"
@@ -54,6 +56,14 @@ class HuggingFaceConfig(BaseModel):
         hf_config_dict = json.loads(contents)
         return cls.model_validate(hf_config_dict)
 
+    def quant_method(self):
+        """Wrapper method to fetch quant type"""
+
+        if isinstance(self.quantization_config, Dict):
+            return self.quantization_config.get("quant_method")
+        else:
+            return None
+
 
 class TokenizerConfig(BaseModel):
     """

@@ -74,9 +74,9 @@ model:
   # Example: ['max_seq_len', 'cache_mode'].
   use_as_default: []
 
-  # Backend to use for the model (default: exllamav2)
+  # Backend to use for this model (auto-detect if not specified)
   # Options: exllamav2, exllamav3
-  backend: exllamav2
+  backend:
 
   # Max sequence length (default: Empty).
   # Fetched from the model's base sequence length in config.json by default.
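Both modes in the sample config's terms (a sketch; an explicit value always overrides detection):

  # Leave empty to auto-detect from the model's config.json:
  backend:

  # ...or pin a backend explicitly:
  # backend: exllamav3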