Model: Create universal HFModel class

The HFModel class coalesces the HuggingFace config files that contain
assorted keys required for model usage.

Adding this base class gives us room to expand as HuggingFace changes
its JSON schemas over time, reducing the burden on backend devs when
their next model isn't supported.
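
A rough usage sketch (method names are taken from the diff below; the
load() wrapper and exact call sites are illustrative only):

    # Hypothetical usage sketch; not a verbatim excerpt from the backends
    import pathlib

    from common.transformers_utils import HFModel

    async def load(model_dir: pathlib.Path):
        # config.json is required; generation_config.json and
        # tokenizer_config.json are merged in when present
        hf_model = await HFModel.from_directory(model_dir)

        # Backends read merged values through wrapper methods
        quant = hf_model.quant_method()     # e.g. "exl3" selects the exllamav3 backend
        eos = hf_model.eos_tokens()         # EOS ids from config.json + generation_config.json
        add_bos = hf_model.add_bos_token()  # from tokenizer_config.json, defaults to True
        return quant, eos, add_bos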

Signed-off-by: kingbri <8082010+kingbri1@users.noreply.github.com>
kingbri 2025-05-13 18:12:38 -04:00
parent 7900b72848
commit 390daeb92f
5 changed files with 149 additions and 127 deletions


@@ -9,11 +9,10 @@ from typing import (
     List,
     Optional,
 )

 from common.multimodal import MultimodalEmbeddingWrapper
 from common.sampling import BaseSamplerRequest
 from common.templating import PromptTemplate
-from common.transformers_utils import GenerationConfig
+from common.transformers_utils import HFModel
 from endpoints.core.types.model import ModelCard
@@ -23,7 +22,9 @@ class BaseModelContainer(abc.ABC):
     # Exposed model information
     model_dir: pathlib.Path = pathlib.Path("models")
     prompt_template: Optional[PromptTemplate] = None
-    generation_config: Optional[GenerationConfig] = None
+
+    # HF Model instance
+    hf_model: HFModel

     # Optional features
     use_draft_model: bool = False
@@ -41,7 +42,7 @@ class BaseModelContainer(abc.ABC):
     # Required methods
     @classmethod
     @abc.abstractmethod
-    async def create(cls, model_directory: pathlib.Path, **kwargs):
+    async def create(cls, model_directory: pathlib.Path, hf_model: HFModel, **kwargs):
         """
         Asynchronously creates and initializes a model container instance.


@@ -4,7 +4,6 @@ import asyncio
 import gc
 import math
 import pathlib
-import traceback
 import torch
 from exllamav2 import (
     ExLlamaV2,
@@ -47,7 +46,7 @@ from common.health import HealthManager
 from common.multimodal import MultimodalEmbeddingWrapper
 from common.sampling import BaseSamplerRequest
 from common.templating import PromptTemplate, find_prompt_template
-from common.transformers_utils import GenerationConfig, TokenizerConfig
+from common.transformers_utils import HFModel
 from common.utils import calculate_rope_alpha, coalesce, unwrap
 from endpoints.core.types.model import ModelCard, ModelCardParameters
@@ -58,6 +57,10 @@ class ExllamaV2Container(BaseModelContainer):
     # Model directories
     model_dir: pathlib.Path = pathlib.Path("models")
     draft_model_dir: pathlib.Path = pathlib.Path("models")
+    prompt_template: Optional[PromptTemplate] = None
+
+    # HF model instance
+    hf_model: HFModel

     # Exl2 vars
     config: Optional[ExLlamaV2Config] = None
@@ -79,8 +82,6 @@ class ExllamaV2Container(BaseModelContainer):
     cache_mode: str = "FP16"
     draft_cache_mode: str = "FP16"
     max_batch_size: Optional[int] = None
-    generation_config: Optional[GenerationConfig] = None
-    tokenizer_config: Optional[TokenizerConfig] = None

     # GPU split vars
     gpu_split: List[float] = []
@@ -100,7 +101,7 @@ class ExllamaV2Container(BaseModelContainer):
     load_condition: asyncio.Condition = asyncio.Condition()

     @classmethod
-    async def create(cls, model_directory: pathlib.Path, **kwargs):
+    async def create(cls, model_directory: pathlib.Path, hf_model: HFModel, **kwargs):
         """
         Primary asynchronous initializer for model container.
@@ -114,6 +115,7 @@ class ExllamaV2Container(BaseModelContainer):
         self.config = ExLlamaV2Config()
         self.model_dir = model_directory
         self.config.model_dir = str(model_directory.resolve())
+        self.hf_model = hf_model

         # Make the max seq len 4096 before preparing the config
         # This is a better default than 2048
@@ -124,30 +126,6 @@ class ExllamaV2Container(BaseModelContainer):
         # Check if the model arch is compatible with various exl2 features
         self.config.arch_compat_overrides()

-        # Load generation config overrides
-        generation_config_path = model_directory / "generation_config.json"
-        if generation_config_path.exists():
-            try:
-                self.generation_config = await GenerationConfig.from_file(
-                    model_directory
-                )
-            except Exception:
-                logger.error(traceback.format_exc())
-                logger.warning(
-                    "Skipping generation config load because of an unexpected error."
-                )
-
-        # Load tokenizer config overrides
-        tokenizer_config_path = model_directory / "tokenizer_config.json"
-        if tokenizer_config_path.exists():
-            try:
-                self.tokenizer_config = await TokenizerConfig.from_file(model_directory)
-            except Exception:
-                logger.error(traceback.format_exc())
-                logger.warning(
-                    "Skipping tokenizer config load because of an unexpected error."
-                )
-
         # Set vision state and error if vision isn't supported on the current model
         self.use_vision = unwrap(kwargs.get("vision"), False)
         if self.use_vision and not self.config.vision_model_type:
@@ -864,7 +842,7 @@ class ExllamaV2Container(BaseModelContainer):
             self.tokenizer.encode(
                 text,
                 add_bos=unwrap(
-                    kwargs.get("add_bos_token"), self.tokenizer_config.add_bos_token
+                    kwargs.get("add_bos_token"), self.hf_model.add_bos_token()
                 ),
                 encode_special_tokens=unwrap(kwargs.get("encode_special_tokens"), True),
                 embeddings=mm_embeddings_content,
@@ -1282,16 +1260,10 @@ class ExllamaV2Container(BaseModelContainer):
             ban_eos_token = params.ban_eos_token

         # Set add_bos_token for generation
-        add_bos_token = unwrap(
-            params.add_bos_token, self.tokenizer_config.add_bos_token
-        )
+        add_bos_token = unwrap(params.add_bos_token, self.hf_model.add_bos_token())

-        # Fetch EOS tokens from generation_config if they exist
-        eos_tokens = (
-            self.generation_config.eos_tokens()
-            if self.generation_config
-            else [self.tokenizer.eos_token_id]
-        )
+        # Fetch EOS tokens from the HF model if they exist
+        eos_tokens = self.hf_model.eos_tokens() or [self.tokenizer.eos_token_id]

         # Ban the EOS token if specified. If not, append to stop conditions
         # as well.


@@ -2,7 +2,6 @@ import asyncio
 import gc
 import pathlib
 import re
-import traceback
 from typing import (
     Any,
     AsyncIterator,
@@ -35,7 +34,7 @@ from common.health import HealthManager
 from common.multimodal import MultimodalEmbeddingWrapper
 from common.sampling import BaseSamplerRequest
 from common.templating import PromptTemplate, find_prompt_template
-from common.transformers_utils import GenerationConfig, TokenizerConfig
+from common.transformers_utils import HFModel
 from common.utils import coalesce, unwrap
 from endpoints.core.types.model import ModelCard, ModelCardParameters
@@ -46,7 +45,9 @@ class ExllamaV3Container(BaseModelContainer):
     # Exposed model information
     model_dir: pathlib.Path = pathlib.Path("models")
     prompt_template: Optional[PromptTemplate] = None
-    generation_config: Optional[GenerationConfig] = None
+
+    # HF Model instance
+    hf_model: HFModel

     # Load synchronization
     # The bool is a master switch for accepting requests
@@ -58,15 +59,14 @@ class ExllamaV3Container(BaseModelContainer):
     load_condition: asyncio.Condition = asyncio.Condition()

     # Exl3 vars
-    model: Optional[Model]
-    cache: Optional[Cache]
-    draft_model: Optional[Model]
-    draft_cache: Optional[Cache]
-    tokenizer: Optional[Tokenizer]
-    config: Optional[Config]
-    draft_config: Optional[Config]
-    generator: Optional[AsyncGenerator]
-    tokenizer_config: Optional[TokenizerConfig]
+    model: Optional[Model] = None
+    cache: Optional[Cache] = None
+    draft_model: Optional[Model] = None
+    draft_cache: Optional[Cache] = None
+    tokenizer: Optional[Tokenizer] = None
+    config: Optional[Config] = None
+    draft_config: Optional[Config] = None
+    generator: Optional[AsyncGenerator] = None

     # Class-specific vars
     gpu_split: List[float] | None = None
@@ -82,7 +82,7 @@ class ExllamaV3Container(BaseModelContainer):
     # Required methods
     @classmethod
-    async def create(cls, model_directory: pathlib.Path, **kwargs):
+    async def create(cls, model_directory: pathlib.Path, hf_model: HFModel, **kwargs):
         """
         Asynchronously creates and initializes a model container instance.
@@ -96,50 +96,17 @@ class ExllamaV3Container(BaseModelContainer):
         self = cls()

-        self.model = None
-        self.cache = None
-        self.draft_model = None
-        self.draft_cache = None
-        self.tokenizer = None
-        self.config = None
-        self.draft_config = None
-        self.generator = None
-        self.tokenizer_config = None
-
         logger.warning(
             "ExllamaV3 is currently in an alpha state. "
             "Please note that all config options may not work."
         )

         self.model_dir = model_directory
+        self.hf_model = hf_model

         self.config = Config.from_directory(str(model_directory.resolve()))
         self.model = Model.from_config(self.config)
         self.tokenizer = Tokenizer.from_config(self.config)

-        # Load generation config overrides
-        generation_config_path = model_directory / "generation_config.json"
-        if generation_config_path.exists():
-            try:
-                self.generation_config = await GenerationConfig.from_file(
-                    model_directory
-                )
-            except Exception:
-                logger.error(traceback.format_exc())
-                logger.warning(
-                    "Skipping generation config load because of an unexpected error."
-                )
-
-        # Load tokenizer config overrides
-        tokenizer_config_path = model_directory / "tokenizer_config.json"
-        if tokenizer_config_path.exists():
-            try:
-                self.tokenizer_config = await TokenizerConfig.from_file(model_directory)
-            except Exception:
-                logger.error(traceback.format_exc())
-                logger.warning(
-                    "Skipping tokenizer config load because of an unexpected error."
-                )
-
         # Fallback to 4096 since exl3 can't fetch from HF's config.json
         self.max_seq_len = unwrap(kwargs.get("max_seq_len"), 4096)
@@ -554,7 +521,9 @@ class ExllamaV3Container(BaseModelContainer):
         return (
             self.tokenizer.encode(
                 text,
-                add_bos=unwrap(kwargs.get("add_bos_token"), True),
+                add_bos=unwrap(
+                    kwargs.get("add_bos_token"), self.hf_model.add_bos_token()
+                ),
                 encode_special_tokens=unwrap(kwargs.get("encode_special_tokens"), True),
             )
             .flatten()
@@ -822,16 +791,10 @@ class ExllamaV3Container(BaseModelContainer):
            prompts = [prompt]
         stop_conditions = params.stop

-        add_bos_token = unwrap(
-            params.add_bos_token, self.tokenizer_config.add_bos_token
-        )
+        add_bos_token = unwrap(params.add_bos_token, self.hf_model.add_bos_token())

         # Fetch EOS tokens from generation_config if they exist
-        eos_tokens = (
-            self.generation_config.eos_tokens()
-            if self.generation_config
-            else [self.tokenizer.eos_token_id]
-        )
+        eos_tokens = self.hf_model.eos_tokens() or [self.tokenizer.eos_token_id]

         stop_conditions += eos_tokens


@@ -17,7 +17,7 @@ from common.logger import get_loading_progress_bar
 from common.networking import handle_request_error
 from common.tabby_config import config
 from common.optional_dependencies import dependencies
-from common.transformers_utils import HuggingFaceConfig
+from common.transformers_utils import HFModel
 from common.utils import unwrap

 # Global variables for model container
@@ -57,22 +57,15 @@ def load_progress(module, modules):
     yield module, modules


-async def detect_backend(model_path: pathlib.Path) -> str:
+def detect_backend(hf_model: HFModel) -> str:
     """Determine the appropriate backend based on model files and configuration."""

-    try:
-        hf_config = await HuggingFaceConfig.from_directory(model_path)
-        quant_method = hf_config.quant_method()
+    quant_method = hf_model.quant_method()

-        if quant_method == "exl3":
-            return "exllamav3"
-        else:
-            return "exllamav2"
-    except Exception as exc:
-        raise ValueError(
-            "Failed to read the model's config.json. "
-            f"Please check your model directory at {model_path}."
-        ) from exc
+    if quant_method == "exl3":
+        return "exllamav3"
+    else:
+        return "exllamav2"


 async def apply_inline_overrides(model_dir: pathlib.Path, **kwargs):
@@ -142,28 +135,29 @@ async def load_model_gen(model_path: pathlib.Path, **kwargs):
     kwargs = {**config.model_defaults, **kwargs}
     kwargs = await apply_inline_overrides(model_path, **kwargs)

+    # Fetch the extra HF configuration options
+    hf_model = await HFModel.from_directory(model_path)
+
     # Create a new container and check if the right dependencies are installed
-    backend_name = unwrap(
-        kwargs.get("backend"), await detect_backend(model_path)
-    ).lower()
-    container_class = _BACKEND_REGISTRY.get(backend_name)
+    backend = unwrap(kwargs.get("backend"), detect_backend(hf_model))
+    container_class = _BACKEND_REGISTRY.get(backend)

     if not container_class:
         available_backends = list(_BACKEND_REGISTRY.keys())
-        if backend_name in available_backends:
+        if backend in available_backends:
             raise ValueError(
-                f"Backend '{backend_name}' selected, but required dependencies "
+                f"Backend '{backend}' selected, but required dependencies "
                 "are not installed."
             )
         else:
             raise ValueError(
-                f"Invalid backend '{backend_name}'. "
+                f"Invalid backend '{backend}'. "
                 f"Available backends: {available_backends}"
             )

-    logger.info(f"Using backend {backend_name}")
+    logger.info(f"Using backend {backend}")

     new_container: BaseModelContainer = await container_class.create(
-        model_path.resolve(), **kwargs
+        model_path.resolve(), hf_model, **kwargs
     )

     # Add possible types of models that can be loaded


@@ -1,8 +1,9 @@
 import aiofiles
 import json
 import pathlib
-from typing import Dict, List, Optional, Union
+from loguru import logger
 from pydantic import BaseModel
+from typing import Dict, List, Optional, Set, Union


 class GenerationConfig(BaseModel):
@@ -14,7 +15,7 @@ class GenerationConfig(BaseModel):
     eos_token_id: Optional[Union[int, List[int]]] = None

     @classmethod
-    async def from_file(cls, model_directory: pathlib.Path):
+    async def from_directory(cls, model_directory: pathlib.Path):
         """Create an instance from a generation config file."""

         generation_config_path = model_directory / "generation_config.json"
@@ -28,10 +29,12 @@ class GenerationConfig(BaseModel):
     def eos_tokens(self):
         """Wrapper method to fetch EOS tokens."""

-        if isinstance(self.eos_token_id, int):
+        if isinstance(self.eos_token_id, list):
+            return self.eos_token_id
+        elif isinstance(self.eos_token_id, int):
             return [self.eos_token_id]
         else:
-            return self.eos_token_id
+            return []


 class HuggingFaceConfig(BaseModel):
@@ -42,6 +45,7 @@ class HuggingFaceConfig(BaseModel):
     Will be expanded as needed.
     """

+    eos_token_id: Optional[Union[int, List[int]]] = None
     quantization_config: Optional[Dict] = None

     @classmethod
@@ -64,6 +68,16 @@ class HuggingFaceConfig(BaseModel):
         else:
             return None

+    def eos_tokens(self):
+        """Wrapper method to fetch EOS tokens."""
+
+        if isinstance(self.eos_token_id, list):
+            return self.eos_token_id
+        elif isinstance(self.eos_token_id, int):
+            return [self.eos_token_id]
+        else:
+            return []
+

 class TokenizerConfig(BaseModel):
     """
@@ -73,7 +87,7 @@ class TokenizerConfig(BaseModel):
     add_bos_token: Optional[bool] = True

     @classmethod
-    async def from_file(cls, model_directory: pathlib.Path):
+    async def from_directory(cls, model_directory: pathlib.Path):
         """Create an instance from a tokenizer config file."""

         tokenizer_config_path = model_directory / "tokenizer_config.json"
@@ -83,3 +97,81 @@ class TokenizerConfig(BaseModel):
         contents = await tokenizer_config_json.read()
         tokenizer_config_dict = json.loads(contents)
         return cls.model_validate(tokenizer_config_dict)
+
+
+class HFModel:
+    """
+    Unified container for HuggingFace model configuration files.
+    These are abridged for hyper-specific model parameters not covered
+    by most backends.
+
+    Includes:
+    - config.json
+    - generation_config.json
+    - tokenizer_config.json
+    """
+
+    hf_config: HuggingFaceConfig
+    tokenizer_config: Optional[TokenizerConfig] = None
+    generation_config: Optional[GenerationConfig] = None
+
+    @classmethod
+    async def from_directory(cls, model_directory: pathlib.Path):
+        """Create an instance from a model directory"""
+
+        self = cls()
+
+        # A model must have an HF config
+        try:
+            self.hf_config = await HuggingFaceConfig.from_directory(model_directory)
+        except Exception as exc:
+            raise ValueError(
+                f"Failed to load config.json from {model_directory}"
+            ) from exc
+
+        try:
+            self.generation_config = await GenerationConfig.from_directory(
+                model_directory
+            )
+        except Exception:
+            logger.warning(
+                "Generation config file not found in model directory, skipping."
+            )
+
+        try:
+            self.tokenizer_config = await TokenizerConfig.from_directory(
+                model_directory
+            )
+        except Exception:
+            logger.warning(
+                "Tokenizer config file not found in model directory, skipping."
+            )
+
+        return self
+
+    def quant_method(self):
+        """Wrapper for quantization method"""
+
+        return self.hf_config.quant_method()
+
+    def eos_tokens(self):
+        """Combines and returns EOS tokens from various configs"""
+
+        eos_ids: Set[int] = set()
+
+        eos_ids.update(self.hf_config.eos_tokens())
+
+        if self.generation_config:
+            eos_ids.update(self.generation_config.eos_tokens())
+
+        # Convert back to a list
+        return list(eos_ids)
+
+    def add_bos_token(self):
+        """Wrapper for tokenizer config"""
+
+        if self.tokenizer_config:
+            return self.tokenizer_config.add_bos_token
+
+        # Expected default
+        return True