Merge pull request #336 from DocShotgun/backend-detect
Automatically select model backend based on config.json
Commit f26ca23f1a
4 changed files with 38 additions and 6 deletions
@@ -168,7 +168,7 @@ class ModelConfig(BaseConfigModel):
     backend: Optional[str] = Field(
         None,
         description=(
-            "Backend to use for this model (default: exllamav2)\n"
+            "Backend to use for this model (auto-detect if not specified)\n"
            "Options: exllamav2, exllamav3"
         ),
     )
@@ -17,6 +17,7 @@ from common.logger import get_loading_progress_bar
 from common.networking import handle_request_error
 from common.tabby_config import config
 from common.optional_dependencies import dependencies
+from common.transformers_utils import HuggingFaceConfig
 from common.utils import unwrap

 # Global variables for model container
@@ -56,6 +57,24 @@ def load_progress(module, modules):
     yield module, modules


+async def detect_backend(model_path: pathlib.Path) -> str:
+    """Determine the appropriate backend based on model files and configuration."""
+
+    try:
+        hf_config = await HuggingFaceConfig.from_directory(model_path)
+        quant_method = hf_config.quant_method()
+
+        if quant_method == "exl3":
+            return "exllamav3"
+        else:
+            return "exllamav2"
+    except Exception as exc:
+        raise ValueError(
+            "Failed to read the model's config.json. "
+            f"Please check your model directory at {model_path}."
+        ) from exc
+
+
 async def apply_inline_overrides(model_dir: pathlib.Path, **kwargs):
     """Sets overrides from a model folder's config yaml."""
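In short, the new helper maps the quant_method declared in the model's config.json onto a backend: "exl3" selects exllamav3, and anything else (including no quantization_config at all) falls back to exllamav2. A minimal standalone sketch of that rule, using a hypothetical config.json payload:

    import json

    # Hypothetical config.json contents from an exl3-quantized model.
    contents = '{"model_type": "llama", "quantization_config": {"quant_method": "exl3"}}'

    config = json.loads(contents)
    quant_method = (config.get("quantization_config") or {}).get("quant_method")
    backend = "exllamav3" if quant_method == "exl3" else "exllamav2"
    print(backend)  # exllamav3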
@@ -124,7 +143,9 @@ async def load_model_gen(model_path: pathlib.Path, **kwargs):
     kwargs = await apply_inline_overrides(model_path, **kwargs)

     # Create a new container and check if the right dependencies are installed
-    backend_name = unwrap(kwargs.get("backend"), "exllamav2").lower()
+    backend_name = unwrap(
+        kwargs.get("backend"), await detect_backend(model_path)
+    ).lower()
     container_class = _BACKEND_REGISTRY.get(backend_name)

     if not container_class:
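Note the fallback semantics here: unwrap presumably returns its first argument unless that is None, so an explicitly configured backend always takes precedence over the detected one. Also, because Python evaluates call arguments eagerly, detect_backend still runs (and must find a readable config.json) even when a backend is set explicitly. A sketch of the precedence rule, assuming unwrap behaves as described:

    def unwrap(wrapped, default=None):
        """Return wrapped unless it is None, otherwise the default."""
        return wrapped if wrapped is not None else default

    # An explicit backend wins over the auto-detected value...
    assert unwrap("exllamav2", "exllamav3") == "exllamav2"
    # ...and detection only decides when no backend was configured.
    assert unwrap(None, "exllamav3") == "exllamav3"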
@@ -140,6 +161,7 @@ async def load_model_gen(model_path: pathlib.Path, **kwargs):
             f"Available backends: {available_backends}"
         )

+    logger.info(f"Using backend {backend_name}")
     new_container: BaseModelContainer = await container_class.create(
         model_path.resolve(), **kwargs
     )
@@ -1,7 +1,7 @@
 import aiofiles
 import json
 import pathlib
-from typing import List, Optional, Union
+from typing import Dict, List, Optional, Union
 from pydantic import BaseModel
@@ -42,8 +42,10 @@ class HuggingFaceConfig(BaseModel):
     Will be expanded as needed.
     """

+    quantization_config: Optional[Dict] = None
+
     @classmethod
-    async def from_file(cls, model_directory: pathlib.Path):
+    async def from_directory(cls, model_directory: pathlib.Path):
         """Create an instance from a generation config file."""

         hf_config_path = model_directory / "config.json"
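The rename from from_file to from_directory matches what the classmethod actually receives: a model directory, inside which it locates config.json itself. A hypothetical call site:

    import asyncio
    import pathlib

    async def main():
        # "models/my-model" is a placeholder directory containing config.json.
        hf_config = await HuggingFaceConfig.from_directory(
            pathlib.Path("models/my-model")
        )
        print(hf_config.quant_method())

    asyncio.run(main())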
@@ -54,6 +56,14 @@ class HuggingFaceConfig(BaseModel):
             hf_config_dict = json.loads(contents)
             return cls.model_validate(hf_config_dict)

+    def quant_method(self):
+        """Wrapper method to fetch quant type"""
+
+        if isinstance(self.quantization_config, Dict):
+            return self.quantization_config.get("quant_method")
+        else:
+            return None
+

 class TokenizerConfig(BaseModel):
     """
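Since quantization_config is optional and its shape is not enforced, quant_method() simply degrades to None for unquantized models instead of raising. An illustration of both cases, validating dicts directly rather than reading a file:

    # exl3-quantized model: quant_method comes from quantization_config.
    exl3_cfg = HuggingFaceConfig.model_validate(
        {"quantization_config": {"quant_method": "exl3"}}
    )
    assert exl3_cfg.quant_method() == "exl3"

    # Unquantized model: no quantization_config key at all.
    plain_cfg = HuggingFaceConfig.model_validate({})
    assert plain_cfg.quant_method() is None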
@@ -74,9 +74,9 @@ model:
   # Example: ['max_seq_len', 'cache_mode'].
   use_as_default: []

-  # Backend to use for the model (default: exllamav2)
+  # Backend to use for this model (auto-detect if not specified)
   # Options: exllamav2, exllamav3
-  backend: exllamav2
+  backend:

   # Max sequence length (default: Empty).
   # Fetched from the model's base sequence length in config.json by default.
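With the sample updated, leaving the key empty defers the choice to detect_backend, while a value pins the backend regardless of what config.json says. For example (values shown are illustrative):

    # Leave empty to auto-detect from the model's config.json:
    backend:

    # ...or pin a backend explicitly, overriding auto-detection:
    backend: exllamav3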