Model: Create universal HFModel class

The HFModel class coalesces all of the HuggingFace config files
whose assorted keys are required for model usage.

Adding this base class lets us adapt as HuggingFace changes its
JSON schemas over time, reducing the burden on backend devs when
their next model isn't supported.
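
As a rough sketch of the idea (the fields, defaults, and loading
strategy below are illustrative assumptions, not the actual
implementation), the class loads whichever config files exist and
exposes uniform accessors such as add_bos_token() and eos_tokens():

    import json
    import pathlib
    from typing import List


    class HFModel:
        """Coalesce the per-model HF JSON configs behind one interface."""

        def __init__(self, model_directory: pathlib.Path):
            # Every file is optional: a missing or malformed config
            # degrades to an empty dict instead of failing the load.
            self.generation_config = self._load_json(
                model_directory / "generation_config.json"
            )
            self.tokenizer_config = self._load_json(
                model_directory / "tokenizer_config.json"
            )

        @staticmethod
        def _load_json(path: pathlib.Path) -> dict:
            try:
                return json.loads(path.read_text())
            except (OSError, json.JSONDecodeError):
                return {}

        def add_bos_token(self) -> bool:
            # tokenizer_config.json key; assume True when absent
            return self.tokenizer_config.get("add_bos_token", True)

        def eos_tokens(self) -> List[int]:
            # generation_config.json allows an int or a list of ints here
            eos = self.generation_config.get("eos_token_id")
            if eos is None:
                return []
            return eos if isinstance(eos, list) else [eos]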

Signed-off-by: kingbri <8082010+kingbri1@users.noreply.github.com>
kingbri 2025-05-13 18:12:38 -04:00
parent 7900b72848
commit 390daeb92f
5 changed files with 149 additions and 127 deletions

View file

@@ -9,11 +9,10 @@ from typing import (
List,
Optional,
)
from common.multimodal import MultimodalEmbeddingWrapper
from common.sampling import BaseSamplerRequest
from common.templating import PromptTemplate
from common.transformers_utils import GenerationConfig
from common.transformers_utils import HFModel
from endpoints.core.types.model import ModelCard
@@ -23,7 +22,9 @@ class BaseModelContainer(abc.ABC):
# Exposed model information
model_dir: pathlib.Path = pathlib.Path("models")
prompt_template: Optional[PromptTemplate] = None
generation_config: Optional[GenerationConfig] = None
# HF Model instance
hf_model: HFModel
# Optional features
use_draft_model: bool = False
@@ -41,7 +42,7 @@ class BaseModelContainer(abc.ABC):
# Required methods
@classmethod
@abc.abstractmethod
async def create(cls, model_directory: pathlib.Path, **kwargs):
async def create(cls, model_directory: pathlib.Path, hf_model: HFModel, **kwargs):
"""
Asynchronously creates and initializes a model container instance.

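With the new create() signature, the caller parses the HF config
files once and shares the result across backends. A hypothetical
call site built on the sketch above (the loader name and flow are
assumptions, not the repo's API):

    import pathlib


    async def load_container(container_cls, model_directory: pathlib.Path, **kwargs):
        # Build the shared HFModel once, then hand it to the chosen
        # BaseModelContainer subclass instead of letting each backend
        # re-read generation_config.json and tokenizer_config.json.
        hf_model = HFModel(model_directory)
        return await container_cls.create(model_directory, hf_model, **kwargs)
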
View file

@@ -4,7 +4,6 @@ import asyncio
import gc
import math
import pathlib
import traceback
import torch
from exllamav2 import (
ExLlamaV2,
@@ -47,7 +46,7 @@ from common.health import HealthManager
from common.multimodal import MultimodalEmbeddingWrapper
from common.sampling import BaseSamplerRequest
from common.templating import PromptTemplate, find_prompt_template
from common.transformers_utils import GenerationConfig, TokenizerConfig
from common.transformers_utils import HFModel
from common.utils import calculate_rope_alpha, coalesce, unwrap
from endpoints.core.types.model import ModelCard, ModelCardParameters
@@ -58,6 +57,10 @@ class ExllamaV2Container(BaseModelContainer):
# Model directories
model_dir: pathlib.Path = pathlib.Path("models")
draft_model_dir: pathlib.Path = pathlib.Path("models")
prompt_template: Optional[PromptTemplate] = None
# HF model instance
hf_model: HFModel
# Exl2 vars
config: Optional[ExLlamaV2Config] = None
@@ -79,8 +82,6 @@ class ExllamaV2Container(BaseModelContainer):
cache_mode: str = "FP16"
draft_cache_mode: str = "FP16"
max_batch_size: Optional[int] = None
generation_config: Optional[GenerationConfig] = None
tokenizer_config: Optional[TokenizerConfig] = None
# GPU split vars
gpu_split: List[float] = []
@@ -100,7 +101,7 @@ class ExllamaV2Container(BaseModelContainer):
load_condition: asyncio.Condition = asyncio.Condition()
@classmethod
async def create(cls, model_directory: pathlib.Path, **kwargs):
async def create(cls, model_directory: pathlib.Path, hf_model: HFModel, **kwargs):
"""
Primary asynchronous initializer for model container.
@@ -114,6 +115,7 @@
self.config = ExLlamaV2Config()
self.model_dir = model_directory
self.config.model_dir = str(model_directory.resolve())
self.hf_model = hf_model
# Make the max seq len 4096 before preparing the config
# This is a better default than 2048
@@ -124,30 +126,6 @@
# Check if the model arch is compatible with various exl2 features
self.config.arch_compat_overrides()
# Load generation config overrides
generation_config_path = model_directory / "generation_config.json"
if generation_config_path.exists():
try:
self.generation_config = await GenerationConfig.from_file(
model_directory
)
except Exception:
logger.error(traceback.format_exc())
logger.warning(
"Skipping generation config load because of an unexpected error."
)
# Load tokenizer config overrides
tokenizer_config_path = model_directory / "tokenizer_config.json"
if tokenizer_config_path.exists():
try:
self.tokenizer_config = await TokenizerConfig.from_file(model_directory)
except Exception:
logger.error(traceback.format_exc())
logger.warning(
"Skipping tokenizer config load because of an unexpected error."
)
# Set vision state and error if vision isn't supported on the current model
self.use_vision = unwrap(kwargs.get("vision"), False)
if self.use_vision and not self.config.vision_model_type:
@@ -864,7 +842,7 @@ class ExllamaV2Container(BaseModelContainer):
self.tokenizer.encode(
text,
add_bos=unwrap(
kwargs.get("add_bos_token"), self.tokenizer_config.add_bos_token
kwargs.get("add_bos_token"), self.hf_model.add_bos_token()
),
encode_special_tokens=unwrap(kwargs.get("encode_special_tokens"), True),
embeddings=mm_embeddings_content,
@@ -1282,16 +1260,10 @@
ban_eos_token = params.ban_eos_token
# Set add_bos_token for generation
add_bos_token = unwrap(
params.add_bos_token, self.tokenizer_config.add_bos_token
)
add_bos_token = unwrap(params.add_bos_token, self.hf_model.add_bos_token())
# Fetch EOS tokens from generation_config if they exist
eos_tokens = (
self.generation_config.eos_tokens()
if self.generation_config
else [self.tokenizer.eos_token_id]
)
# Fetch EOS tokens from the HF model if they exist
eos_tokens = self.hf_model.eos_tokens() or [self.tokenizer.eos_token_id]
# Ban the EOS token if specified. If not, append to stop conditions
# as well.

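The rewritten EOS lookup relies on an empty list being falsy, so a
model without a usable generation_config.json falls back to the
tokenizer's own id. A quick illustration of those semantics (the
token ids are made up):

    hf_eos = []                  # eos_tokens() result with no config on disk
    tokenizer_eos_id = 2         # illustrative tokenizer fallback id

    assert (hf_eos or [tokenizer_eos_id]) == [2]

    hf_eos = [128001, 128009]    # a model that declares several EOS ids
    assert (hf_eos or [tokenizer_eos_id]) == [128001, 128009]
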
View file

@@ -2,7 +2,6 @@ import asyncio
import gc
import pathlib
import re
import traceback
from typing import (
Any,
AsyncIterator,
@@ -35,7 +34,7 @@ from common.health import HealthManager
from common.multimodal import MultimodalEmbeddingWrapper
from common.sampling import BaseSamplerRequest
from common.templating import PromptTemplate, find_prompt_template
from common.transformers_utils import GenerationConfig, TokenizerConfig
from common.transformers_utils import HFModel
from common.utils import coalesce, unwrap
from endpoints.core.types.model import ModelCard, ModelCardParameters
@@ -46,7 +45,9 @@ class ExllamaV3Container(BaseModelContainer):
# Exposed model information
model_dir: pathlib.Path = pathlib.Path("models")
prompt_template: Optional[PromptTemplate] = None
generation_config: Optional[GenerationConfig] = None
# HF Model instance
hf_model: HFModel
# Load synchronization
# The bool is a master switch for accepting requests
@@ -58,15 +59,14 @@ class ExllamaV3Container(BaseModelContainer):
load_condition: asyncio.Condition = asyncio.Condition()
# Exl3 vars
model: Optional[Model]
cache: Optional[Cache]
draft_model: Optional[Model]
draft_cache: Optional[Cache]
tokenizer: Optional[Tokenizer]
config: Optional[Config]
draft_config: Optional[Config]
generator: Optional[AsyncGenerator]
tokenizer_config: Optional[TokenizerConfig]
model: Optional[Model] = None
cache: Optional[Cache] = None
draft_model: Optional[Model] = None
draft_cache: Optional[Cache] = None
tokenizer: Optional[Tokenizer] = None
config: Optional[Config] = None
draft_config: Optional[Config] = None
generator: Optional[AsyncGenerator] = None
# Class-specific vars
gpu_split: List[float] | None = None
@@ -82,7 +82,7 @@ class ExllamaV3Container(BaseModelContainer):
# Required methods
@classmethod
async def create(cls, model_directory: pathlib.Path, **kwargs):
async def create(cls, model_directory: pathlib.Path, hf_model: HFModel, **kwargs):
"""
Asynchronously creates and initializes a model container instance.
@@ -96,50 +96,17 @@ class ExllamaV3Container(BaseModelContainer):
self = cls()
self.model = None
self.cache = None
self.draft_model = None
self.draft_cache = None
self.tokenizer = None
self.config = None
self.draft_config = None
self.generator = None
self.tokenizer_config = None
logger.warning(
"ExllamaV3 is currently in an alpha state. "
"Please note that all config options may not work."
)
self.model_dir = model_directory
self.hf_model = hf_model
self.config = Config.from_directory(str(model_directory.resolve()))
self.model = Model.from_config(self.config)
self.tokenizer = Tokenizer.from_config(self.config)
# Load generation config overrides
generation_config_path = model_directory / "generation_config.json"
if generation_config_path.exists():
try:
self.generation_config = await GenerationConfig.from_file(
model_directory
)
except Exception:
logger.error(traceback.format_exc())
logger.warning(
"Skipping generation config load because of an unexpected error."
)
# Load tokenizer config overrides
tokenizer_config_path = model_directory / "tokenizer_config.json"
if tokenizer_config_path.exists():
try:
self.tokenizer_config = await TokenizerConfig.from_file(model_directory)
except Exception:
logger.error(traceback.format_exc())
logger.warning(
"Skipping tokenizer config load because of an unexpected error."
)
# Fallback to 4096 since exl3 can't fetch from HF's config.json
self.max_seq_len = unwrap(kwargs.get("max_seq_len"), 4096)
@@ -554,7 +521,9 @@ class ExllamaV3Container(BaseModelContainer):
return (
self.tokenizer.encode(
text,
add_bos=unwrap(kwargs.get("add_bos_token"), True),
add_bos=unwrap(
kwargs.get("add_bos_token"), self.hf_model.add_bos_token()
),
encode_special_tokens=unwrap(kwargs.get("encode_special_tokens"), True),
)
.flatten()
@@ -822,16 +791,10 @@ class ExllamaV3Container(BaseModelContainer):
prompts = [prompt]
stop_conditions = params.stop
add_bos_token = unwrap(
params.add_bos_token, self.tokenizer_config.add_bos_token
)
add_bos_token = unwrap(params.add_bos_token, self.hf_model.add_bos_token())
# Fetch EOS tokens from generation_config if they exist
eos_tokens = (
self.generation_config.eos_tokens()
if self.generation_config
else [self.tokenizer.eos_token_id]
)
eos_tokens = self.hf_model.eos_tokens() or [self.tokenizer.eos_token_id]
stop_conditions += eos_tokens
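
The add_bos_token plumbing in both backends assumes unwrap() treats
only None as unset. A minimal sketch of that assumed behavior (not
the actual common.utils implementation):

    def unwrap(value, default):
        # Assumed behavior: fall back only when value is None, so an
        # explicit False from the request is respected.
        return default if value is None else value


    assert unwrap(None, True) is True    # no request override -> HF default
    assert unwrap(False, True) is False  # explicit False wins over the default

Using unwrap rather than a plain `or` matters here: an explicit
add_bos_token=False in the request must not be clobbered by the
model default.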