diff --git a/common/config_models.py b/common/config_models.py
index 057e216..0958a8e 100644
--- a/common/config_models.py
+++ b/common/config_models.py
@@ -168,7 +168,7 @@ class ModelConfig(BaseConfigModel):
     backend: Optional[str] = Field(
         None,
         description=(
-            "Backend to use for this model (default: exllamav2)\n"
+            "Backend to use for this model (auto-detect if not specified)\n"
             "Options: exllamav2, exllamav3"
         ),
     )
diff --git a/common/model.py b/common/model.py
index 2c7bd65..9cdfdeb 100644
--- a/common/model.py
+++ b/common/model.py
@@ -17,6 +17,7 @@ from common.logger import get_loading_progress_bar
 from common.networking import handle_request_error
 from common.tabby_config import config
 from common.optional_dependencies import dependencies
+from common.transformers_utils import HuggingFaceConfig
 from common.utils import unwrap
 
 # Global variables for model container
@@ -56,6 +57,24 @@ def load_progress(module, modules):
     yield module, modules
 
 
+async def detect_backend(model_path: pathlib.Path) -> str:
+    """Determine the appropriate backend based on model files and configuration."""
+
+    try:
+        hf_config = await HuggingFaceConfig.from_directory(model_path)
+        quant_method = hf_config.quant_method()
+
+        if quant_method == "exl3":
+            return "exllamav3"
+        else:
+            return "exllamav2"
+    except Exception as exc:
+        raise ValueError(
+            "Failed to read the model's config.json. "
+            f"Please check your model directory at {model_path}."
+        ) from exc
+
+
 async def apply_inline_overrides(model_dir: pathlib.Path, **kwargs):
     """Sets overrides from a model folder's config yaml."""
 
@@ -124,7 +143,9 @@ async def load_model_gen(model_path: pathlib.Path, **kwargs):
     kwargs = await apply_inline_overrides(model_path, **kwargs)
 
     # Create a new container and check if the right dependencies are installed
-    backend_name = unwrap(kwargs.get("backend"), "exllamav2").lower()
+    backend_name = unwrap(
+        kwargs.get("backend"), await detect_backend(model_path)
+    ).lower()
     container_class = _BACKEND_REGISTRY.get(backend_name)
 
     if not container_class:
@@ -140,6 +161,7 @@
             f"Available backends: {available_backends}"
         )
 
+    logger.info(f"Using backend {backend_name}")
     new_container: BaseModelContainer = await container_class.create(
         model_path.resolve(), **kwargs
     )
diff --git a/common/transformers_utils.py b/common/transformers_utils.py
index 045312c..d1e5ac1 100644
--- a/common/transformers_utils.py
+++ b/common/transformers_utils.py
@@ -1,7 +1,7 @@
 import aiofiles
 import json
 import pathlib
-from typing import List, Optional, Union
+from typing import Dict, List, Optional, Union
 
 from pydantic import BaseModel
 
@@ -42,8 +42,10 @@ class HuggingFaceConfig(BaseModel):
     Will be expanded as needed.
     """
 
+    quantization_config: Optional[Dict] = None
+
     @classmethod
-    async def from_file(cls, model_directory: pathlib.Path):
+    async def from_directory(cls, model_directory: pathlib.Path):
         """Create an instance from a generation config file."""
 
         hf_config_path = model_directory / "config.json"
@@ -54,6 +56,14 @@ class HuggingFaceConfig(BaseModel):
         hf_config_dict = json.loads(contents)
         return cls.model_validate(hf_config_dict)
 
+    def quant_method(self):
+        """Wrapper method to fetch quant type"""
+
+        if isinstance(self.quantization_config, Dict):
+            return self.quantization_config.get("quant_method")
+        else:
+            return None
+
 
 class TokenizerConfig(BaseModel):
     """
diff --git a/config_sample.yml b/config_sample.yml
index 045db51..ffe2605 100644
--- a/config_sample.yml
+++ b/config_sample.yml
@@ -74,9 +74,9 @@ model:
   # Example: ['max_seq_len', 'cache_mode'].
   use_as_default: []
 
-  # Backend to use for the model (default: exllamav2)
+  # Backend to use for this model (auto-detect if not specified)
   # Options: exllamav2, exllamav3
-  backend: exllamav2
+  backend:
 
   # Max sequence length (default: Empty).
   # Fetched from the model's base sequence length in config.json by default.