Model: Create universal HFModel class

The HFModel class coalesces all of the config files that contain
miscellaneous keys required for model usage.

Adding this base class gives us room to expand as HuggingFace randomly
changes its JSON schemas over time, reducing the burden on backend
devs when their next model isn't supported.
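
In practice, the loader can now pull everything it needs from one
object. A rough usage sketch (the model path and printed values are
illustrative, not from this repo):

    import asyncio
    import pathlib

    from common.transformers_utils import HFModel

    async def main():
        # One object wraps config.json, generation_config.json, and
        # tokenizer_config.json for a given model directory
        hf_model = await HFModel.from_directory(pathlib.Path("models/my-model"))

        print(hf_model.quant_method())   # e.g. "exl3", read from quantization_config
        print(hf_model.eos_tokens())     # EOS ids merged across config files
        print(hf_model.add_bos_token())  # from tokenizer_config.json, defaults to True

    asyncio.run(main())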

Signed-off-by: kingbri <8082010+kingbri1@users.noreply.github.com>
kingbri 2025-05-13 18:12:38 -04:00
parent 7900b72848
commit 390daeb92f
5 changed files with 149 additions and 127 deletions

View file

@@ -17,7 +17,7 @@ from common.logger import get_loading_progress_bar
 from common.networking import handle_request_error
 from common.tabby_config import config
 from common.optional_dependencies import dependencies
-from common.transformers_utils import HuggingFaceConfig
+from common.transformers_utils import HFModel
 from common.utils import unwrap

 # Global variables for model container
@@ -57,22 +57,15 @@ def load_progress(module, modules):
         yield module, modules


-async def detect_backend(model_path: pathlib.Path) -> str:
+def detect_backend(hf_model: HFModel) -> str:
     """Determine the appropriate backend based on model files and configuration."""

-    try:
-        hf_config = await HuggingFaceConfig.from_directory(model_path)
-        quant_method = hf_config.quant_method()
+    quant_method = hf_model.quant_method()

-        if quant_method == "exl3":
-            return "exllamav3"
-        else:
-            return "exllamav2"
-    except Exception as exc:
-        raise ValueError(
-            "Failed to read the model's config.json. "
-            f"Please check your model directory at {model_path}."
-        ) from exc
+    if quant_method == "exl3":
+        return "exllamav3"
+    else:
+        return "exllamav2"


 async def apply_inline_overrides(model_dir: pathlib.Path, **kwargs):
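
The detection rule above reduces to a single comparison. A minimal
standalone sketch, under the assumption that quant_method() reads
quantization_config["quant_method"] out of config.json:

    from typing import Optional

    def detect_backend_sketch(quant_method: Optional[str]) -> str:
        # "exl3" quants route to the exllamav3 backend; everything else
        # (exl2, unquantized, unknown) falls through to exllamav2
        return "exllamav3" if quant_method == "exl3" else "exllamav2"

    assert detect_backend_sketch("exl3") == "exllamav3"
    assert detect_backend_sketch("exl2") == "exllamav2"
    assert detect_backend_sketch(None) == "exllamav2"

Making detect_backend synchronous also fits the new flow: the async file
read now happens once in HFModel.from_directory, so detection becomes a
pure lookup.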
@@ -142,28 +135,29 @@ async def load_model_gen(model_path: pathlib.Path, **kwargs):
     kwargs = {**config.model_defaults, **kwargs}
     kwargs = await apply_inline_overrides(model_path, **kwargs)

+    # Fetch the extra HF configuration options
+    hf_model = await HFModel.from_directory(model_path)
+
     # Create a new container and check if the right dependencies are installed
-    backend_name = unwrap(
-        kwargs.get("backend"), await detect_backend(model_path)
-    ).lower()
-    container_class = _BACKEND_REGISTRY.get(backend_name)
+    backend = unwrap(kwargs.get("backend"), detect_backend(hf_model))
+    container_class = _BACKEND_REGISTRY.get(backend)

     if not container_class:
         available_backends = list(_BACKEND_REGISTRY.keys())

-        if backend_name in available_backends:
+        if backend in available_backends:
             raise ValueError(
-                f"Backend '{backend_name}' selected, but required dependencies "
+                f"Backend '{backend}' selected, but required dependencies "
                 "are not installed."
             )
         else:
             raise ValueError(
-                f"Invalid backend '{backend_name}'. "
+                f"Invalid backend '{backend}'. "
                 f"Available backends: {available_backends}"
             )

-    logger.info(f"Using backend {backend_name}")
+    logger.info(f"Using backend {backend}")

     new_container: BaseModelContainer = await container_class.create(
-        model_path.resolve(), **kwargs
+        model_path.resolve(), hf_model, **kwargs
     )

     # Add possible types of models that can be loaded
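
A side note on the kwargs handling at the top of this hunk:
{**config.model_defaults, **kwargs} means per-request kwargs win over
config defaults, since later entries in a dict merge overwrite earlier
ones. A quick illustration with hypothetical keys:

    defaults = {"backend": "exllamav2", "cache_mode": "FP16"}
    request_kwargs = {"backend": "exllamav3"}

    merged = {**defaults, **request_kwargs}
    # merged == {"backend": "exllamav3", "cache_mode": "FP16"}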

View file

@@ -1,8 +1,9 @@
 import aiofiles
 import json
 import pathlib
-from typing import Dict, List, Optional, Union
+from loguru import logger
 from pydantic import BaseModel
+from typing import Dict, List, Optional, Set, Union


 class GenerationConfig(BaseModel):
@@ -14,7 +15,7 @@ class GenerationConfig(BaseModel):
     eos_token_id: Optional[Union[int, List[int]]] = None

     @classmethod
-    async def from_file(cls, model_directory: pathlib.Path):
+    async def from_directory(cls, model_directory: pathlib.Path):
         """Create an instance from a generation config file."""

         generation_config_path = model_directory / "generation_config.json"
@ -28,10 +29,12 @@ class GenerationConfig(BaseModel):
def eos_tokens(self):
"""Wrapper method to fetch EOS tokens."""
if isinstance(self.eos_token_id, int):
if isinstance(self.eos_token_id, list):
return self.eos_token_id
elif isinstance(self.eos_token_id, int):
return [self.eos_token_id]
else:
return self.eos_token_id
return []
class HuggingFaceConfig(BaseModel):
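
The reworked eos_tokens() also fixes a latent bug: the old else branch
returned self.eos_token_id itself, i.e. None when the key was absent.
The method now always returns a list:

    GenerationConfig(eos_token_id=2).eos_tokens()           # [2]
    GenerationConfig(eos_token_id=[2, 32000]).eos_tokens()  # [2, 32000]
    GenerationConfig().eos_tokens()                         # [] (previously None)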
@@ -42,6 +45,7 @@ class HuggingFaceConfig(BaseModel):
     Will be expanded as needed.
     """

+    eos_token_id: Optional[Union[int, List[int]]] = None
     quantization_config: Optional[Dict] = None

     @classmethod
@@ -64,6 +68,16 @@ class HuggingFaceConfig(BaseModel):
         else:
             return None

+    def eos_tokens(self):
+        """Wrapper method to fetch EOS tokens."""
+
+        if isinstance(self.eos_token_id, list):
+            return self.eos_token_id
+        elif isinstance(self.eos_token_id, int):
+            return [self.eos_token_id]
+        else:
+            return []
+

 class TokenizerConfig(BaseModel):
     """
@@ -73,7 +87,7 @@ class TokenizerConfig(BaseModel):
     add_bos_token: Optional[bool] = True

     @classmethod
-    async def from_file(cls, model_directory: pathlib.Path):
+    async def from_directory(cls, model_directory: pathlib.Path):
         """Create an instance from a tokenizer config file."""

         tokenizer_config_path = model_directory / "tokenizer_config.json"
@@ -83,3 +97,81 @@ class TokenizerConfig(BaseModel):
         contents = await tokenizer_config_json.read()
         tokenizer_config_dict = json.loads(contents)
         return cls.model_validate(tokenizer_config_dict)
+
+
+class HFModel:
+    """
+    Unified container for HuggingFace model configuration files.
+
+    These are abridged for hyper-specific model parameters not covered
+    by most backends.
+
+    Includes:
+    - config.json
+    - generation_config.json
+    - tokenizer_config.json
+    """
+
+    hf_config: HuggingFaceConfig
+    tokenizer_config: Optional[TokenizerConfig] = None
+    generation_config: Optional[GenerationConfig] = None
+
+    @classmethod
+    async def from_directory(cls, model_directory: pathlib.Path):
+        """Create an instance from a model directory"""
+
+        self = cls()
+
+        # A model must have an HF config
+        try:
+            self.hf_config = await HuggingFaceConfig.from_directory(model_directory)
+        except Exception as exc:
+            raise ValueError(
+                f"Failed to load config.json from {model_directory}"
+            ) from exc
+
+        try:
+            self.generation_config = await GenerationConfig.from_directory(
+                model_directory
+            )
+        except Exception:
+            logger.warning(
+                "Generation config file not found in model directory, skipping."
+            )
+
+        try:
+            self.tokenizer_config = await TokenizerConfig.from_directory(
+                model_directory
+            )
+        except Exception:
+            logger.warning(
+                "Tokenizer config file not found in model directory, skipping."
+            )
+
+        return self
+
+    def quant_method(self):
+        """Wrapper for quantization method"""
+
+        return self.hf_config.quant_method()
+
+    def eos_tokens(self):
+        """Combines and returns EOS tokens from various configs"""
+
+        eos_ids: Set[int] = set()
+
+        eos_ids.update(self.hf_config.eos_tokens())
+
+        if self.generation_config:
+            eos_ids.update(self.generation_config.eos_tokens())
+
+        # Convert back to a list
+        return list(eos_ids)
+
+    def add_bos_token(self):
+        """Wrapper for tokenizer config"""
+
+        if self.tokenizer_config:
+            return self.tokenizer_config.add_bos_token
+
+        # Expected default
+        return True
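
Because eos_tokens() unions ids across configs through a set, duplicates
collapse and ordering is not guaranteed. A small sketch with made-up
token ids (HFModel is a plain class, so it can be assembled by hand):

    hf_model = HFModel()
    hf_model.hf_config = HuggingFaceConfig(eos_token_id=2)
    hf_model.generation_config = GenerationConfig(eos_token_id=[2, 32000])

    sorted(hf_model.eos_tokens())  # [2, 32000]; the duplicate 2 appears once
    hf_model.add_bos_token()       # True, since no tokenizer_config was attached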