Model: Auto detect model backend from config
* Use exllamav3 for exl3 models, exllamav2 otherwise
parent bc0a84241a
commit f8070e7707
6 changed files with 35 additions and 10 deletions
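
In short: the loader now inspects the model's config.json before choosing a backend. If quantization_config.quant_method is "exl3", it picks exllamav3; otherwise it falls back to exllamav2, and an explicitly configured backend still takes priority. A minimal standalone sketch of that rule (detect_backend is an illustrative helper, not a function from this repo):

import json
import pathlib
from typing import Optional


def detect_backend(model_path: pathlib.Path, configured: Optional[str] = None) -> str:
    """Sketch of the commit's auto-detection rule."""
    # An explicitly configured backend always wins over detection.
    if configured:
        return configured.lower()

    backend = "exllamav2"  # default when nothing indicates an exl3 quant
    config_file = model_path / "config.json"
    if config_file.exists():
        hf_config = json.loads(config_file.read_text())
        quant_config = hf_config.get("quantization_config") or {}
        if quant_config.get("quant_method") == "exl3":
            backend = "exllamav3"
    return backend
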
@@ -161,9 +161,7 @@ class ExllamaV3Container(BaseModelContainer):
             self.draft_model_dir = draft_model_path
             self.draft_config = Config.from_directory(str(draft_model_path.resolve()))
             self.draft_model = Model.from_config(self.draft_config)
-            logger.info(
-                f'Using draft model: {str(draft_model_path.resolve())}'
-            )
+            logger.info(f"Using draft model: {str(draft_model_path.resolve())}")
         else:
             self.draft_model = None
             self.craft_cache = None

@@ -223,7 +221,7 @@ class ExllamaV3Container(BaseModelContainer):
 
         # Draft cache
         if self.use_draft_model:
-            self.draft_cache = Cache(self.draft_model, max_num_tokens = self.cache_size)
+            self.draft_cache = Cache(self.draft_model, max_num_tokens=self.cache_size)
 
         # Max batch size
         self.max_batch_size = unwrap(kwargs.get("max_batch_size"), 256)

@@ -166,7 +166,7 @@ class ModelConfig(BaseConfigModel):
     backend: Optional[str] = Field(
         None,
         description=(
-            "Backend to use for this model (default: exllamav2)\n"
+            "Backend to use for this model (auto-detect if not specified)\n"
            "Options: exllamav2, exllamav3"
         ),
     )

@@ -17,6 +17,7 @@ from common.logger import get_loading_progress_bar
 from common.networking import handle_request_error
 from common.tabby_config import config
 from common.optional_dependencies import dependencies
+from common.transformers_utils import HuggingFaceConfig
 from common.utils import unwrap
 
 # Global variables for model container

@@ -123,8 +124,24 @@ async def load_model_gen(model_path: pathlib.Path, **kwargs):
     kwargs = {**config.model_defaults, **kwargs}
     kwargs = await apply_inline_overrides(model_path, **kwargs)
 
+    # Read config.json and detect the quant method
+    hf_config_path = model_path / "config.json"
+    if hf_config_path.exists():
+        try:
+            hf_config = await HuggingFaceConfig.from_file(model_path)
+        except Exception as exc:
+            raise ValueError(
+                "Failed to read the model's config.json. "
+                f"Please check your model directory at {model_path}."
+            ) from exc
+        quant_method = hf_config.quant_method()
+        if quant_method == "exl3":
+            backend_name = "exllamav3"
+        else:
+            backend_name = "exllamav2"
+
     # Create a new container and check if the right dependencies are installed
-    backend_name = unwrap(kwargs.get("backend"), "exllamav2").lower()
+    backend_name = unwrap(kwargs.get("backend"), backend_name).lower()
     container_class = _BACKEND_REGISTRY.get(backend_name)
 
     if not container_class:

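The changed fallback above keeps user intent on top: unwrap returns its first argument when that is not None and the default otherwise (assumed behavior of common.utils.unwrap), so an explicit backend option bypasses the detected value. A rough sketch of that precedence under that assumption:

from typing import Optional, TypeVar

T = TypeVar("T")


def unwrap(wrapped: Optional[T], default: T) -> T:
    # Assumed behavior of common.utils.unwrap: fall back only when unset.
    return wrapped if wrapped is not None else default


detected = "exllamav3"                        # result of the config.json probe
print(unwrap(None, detected).lower())         # -> "exllamav3" (auto-detection used)
print(unwrap("exllamav2", detected).lower())  # -> "exllamav2" (explicit setting wins)
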
@@ -1,7 +1,7 @@
 import aiofiles
 import json
 import pathlib
-from typing import List, Optional, Union
+from typing import Dict, List, Optional, Union
 from pydantic import BaseModel
 
 

@@ -42,6 +42,8 @@ class HuggingFaceConfig(BaseModel):
     Will be expanded as needed.
     """
 
+    quantization_config: Optional[Dict] = None
+
     @classmethod
     async def from_file(cls, model_directory: pathlib.Path):
         """Create an instance from a generation config file."""

@@ -54,6 +56,14 @@ class HuggingFaceConfig(BaseModel):
             hf_config_dict = json.loads(contents)
             return cls.model_validate(hf_config_dict)
 
+    def quant_method(self):
+        """Wrapper method to fetch quant type"""
+
+        if isinstance(self.quantization_config, Dict):
+            return self.quantization_config.get("quant_method")
+        else:
+            return None
+
 
 class TokenizerConfig(BaseModel):
     """

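For reference, here is how the new quantization_config field and quant_method() wrapper behave on a typical exl3 config.json. The class below is a trimmed stand-in for HuggingFaceConfig (only the new field), not the repo's full model:

from typing import Dict, Optional

from pydantic import BaseModel


class MiniHFConfig(BaseModel):
    # Stand-in with just the field added by this commit.
    quantization_config: Optional[Dict] = None

    def quant_method(self) -> Optional[str]:
        if isinstance(self.quantization_config, dict):
            return self.quantization_config.get("quant_method")
        return None


exl3 = MiniHFConfig.model_validate({"quantization_config": {"quant_method": "exl3"}})
plain = MiniHFConfig.model_validate({})
print(exl3.quant_method())   # -> "exl3"
print(plain.quant_method())  # -> None
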
@@ -74,9 +74,9 @@ model:
   # Example: ['max_seq_len', 'cache_mode'].
   use_as_default: []
 
-  # Backend to use for the model (default: exllamav2)
+  # Backend to use for this model (auto-detect if not specified)
   # Options: exllamav2, exllamav3
-  backend: exllamav2
+  backend:
 
   # Max sequence length (default: Empty).
   # Fetched from the model's base sequence length in config.json by default.

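Leaving the sample's backend key empty works because YAML parses a key with no value as null, which arrives in ModelConfig as None and lets auto-detection run. A quick check, assuming PyYAML is available:

import yaml  # PyYAML, assumed available for this check

parsed = yaml.safe_load("model:\n  backend:\n")
print(parsed["model"]["backend"])          # -> None
print(parsed["model"]["backend"] is None)  # -> True
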
@@ -84,7 +84,7 @@ class ChatCompletionRequest(CommonCompletionRequest):
 
     # Chat completions requests do not have a BOS token preference. Backend
     # respects the tokenization config for the individual model.
-    add_bos_token: Optional[bool] = Field(default = None)
+    add_bos_token: Optional[bool] = Field(default=None)
 
     @field_validator("add_bos_token", mode="after")
     def force_bos_token(cls, v):