diff --git a/backends/exllamav3/model.py b/backends/exllamav3/model.py
index fc10a3d..f04c218 100644
--- a/backends/exllamav3/model.py
+++ b/backends/exllamav3/model.py
@@ -161,9 +161,7 @@ class ExllamaV3Container(BaseModelContainer):
             self.draft_model_dir = draft_model_path
             self.draft_config = Config.from_directory(str(draft_model_path.resolve()))
             self.draft_model = Model.from_config(self.draft_config)
-            logger.info(
-                f'Using draft model: {str(draft_model_path.resolve())}'
-            )
+            logger.info(f"Using draft model: {str(draft_model_path.resolve())}")
         else:
             self.draft_model = None
             self.draft_cache = None
@@ -223,7 +221,7 @@ class ExllamaV3Container(BaseModelContainer):
 
         # Draft cache
         if self.use_draft_model:
-            self.draft_cache = Cache(self.draft_model, max_num_tokens = self.cache_size)
+            self.draft_cache = Cache(self.draft_model, max_num_tokens=self.cache_size)
 
         # Max batch size
         self.max_batch_size = unwrap(kwargs.get("max_batch_size"), 256)
diff --git a/common/config_models.py b/common/config_models.py
index 9d88eb1..de1f803 100644
--- a/common/config_models.py
+++ b/common/config_models.py
@@ -166,7 +166,7 @@ class ModelConfig(BaseConfigModel):
     backend: Optional[str] = Field(
         None,
         description=(
-            "Backend to use for this model (default: exllamav2)\n"
+            "Backend to use for this model (auto-detect if not specified)\n"
            "Options: exllamav2, exllamav3"
        ),
    )
diff --git a/common/model.py b/common/model.py
index 2c7bd65..161c7dc 100644
--- a/common/model.py
+++ b/common/model.py
@@ -17,6 +17,7 @@ from common.logger import get_loading_progress_bar
 from common.networking import handle_request_error
 from common.tabby_config import config
 from common.optional_dependencies import dependencies
+from common.transformers_utils import HuggingFaceConfig
 from common.utils import unwrap
 
 # Global variables for model container
@@ -123,8 +124,24 @@ async def load_model_gen(model_path: pathlib.Path, **kwargs):
     kwargs = {**config.model_defaults, **kwargs}
     kwargs = await apply_inline_overrides(model_path, **kwargs)
 
+    # Read config.json and detect the quant method
+    hf_config_path = model_path / "config.json"
+    if hf_config_path.exists():
+        try:
+            hf_config = await HuggingFaceConfig.from_file(model_path)
+        except Exception as exc:
+            raise ValueError(
+                "Failed to read the model's config.json. "
+                f"Please check your model directory at {model_path}."
+            ) from exc
+        quant_method = hf_config.quant_method()
+        if quant_method == "exl3":
+            backend_name = "exllamav3"
+        else:
+            backend_name = "exllamav2"
+
     # Create a new container and check if the right dependencies are installed
-    backend_name = unwrap(kwargs.get("backend"), "exllamav2").lower()
+    backend_name = unwrap(kwargs.get("backend"), backend_name).lower()
     container_class = _BACKEND_REGISTRY.get(backend_name)
 
     if not container_class:
diff --git a/common/transformers_utils.py b/common/transformers_utils.py
index 045312c..cd79f00 100644
--- a/common/transformers_utils.py
+++ b/common/transformers_utils.py
@@ -1,7 +1,7 @@
 import aiofiles
 import json
 import pathlib
-from typing import List, Optional, Union
+from typing import Dict, List, Optional, Union
 
 from pydantic import BaseModel
 
@@ -42,6 +42,8 @@ class HuggingFaceConfig(BaseModel):
     Will be expanded as needed.
""" + quantization_config: Optional[Dict] = None + @classmethod async def from_file(cls, model_directory: pathlib.Path): """Create an instance from a generation config file.""" @@ -54,6 +56,14 @@ class HuggingFaceConfig(BaseModel): hf_config_dict = json.loads(contents) return cls.model_validate(hf_config_dict) + def quant_method(self): + """Wrapper method to fetch quant type""" + + if isinstance(self.quantization_config, Dict): + return self.quantization_config.get("quant_method") + else: + return None + class TokenizerConfig(BaseModel): """ diff --git a/config_sample.yml b/config_sample.yml index a743c2c..025d2ee 100644 --- a/config_sample.yml +++ b/config_sample.yml @@ -74,9 +74,9 @@ model: # Example: ['max_seq_len', 'cache_mode']. use_as_default: [] - # Backend to use for the model (default: exllamav2) + # Backend to use for this model (auto-detect if not specified) # Options: exllamav2, exllamav3 - backend: exllamav2 + backend: # Max sequence length (default: Empty). # Fetched from the model's base sequence length in config.json by default. diff --git a/endpoints/OAI/types/chat_completion.py b/endpoints/OAI/types/chat_completion.py index 51695c2..fb73eb9 100644 --- a/endpoints/OAI/types/chat_completion.py +++ b/endpoints/OAI/types/chat_completion.py @@ -84,7 +84,7 @@ class ChatCompletionRequest(CommonCompletionRequest): # Chat completions requests do not have a BOS token preference. Backend # respects the tokenization config for the individual model. - add_bos_token: Optional[bool] = Field(default = None) + add_bos_token: Optional[bool] = Field(default=None) @field_validator("add_bos_token", mode="after") def force_bos_token(cls, v):