Model: Create universal HFModel class
The HFModel class serves to coalesce all config files that contain random keys which are required for model usage. Adding this base class allows us to expand as HuggingFace randomly changes their JSON schemas over time, reducing the brunt that backend devs need to feel when their next model isn't supported. Signed-off-by: kingbri <8082010+kingbri1@users.noreply.github.com>
This commit is contained in:
parent
7900b72848
commit
390daeb92f
5 changed files with 149 additions and 127 deletions
|
|
@ -9,11 +9,10 @@ from typing import (
|
|||
List,
|
||||
Optional,
|
||||
)
|
||||
|
||||
from common.multimodal import MultimodalEmbeddingWrapper
|
||||
from common.sampling import BaseSamplerRequest
|
||||
from common.templating import PromptTemplate
|
||||
from common.transformers_utils import GenerationConfig
|
||||
from common.transformers_utils import HFModel
|
||||
from endpoints.core.types.model import ModelCard
|
||||
|
||||
|
||||
|
|
@ -23,7 +22,9 @@ class BaseModelContainer(abc.ABC):
|
|||
# Exposed model information
|
||||
model_dir: pathlib.Path = pathlib.Path("models")
|
||||
prompt_template: Optional[PromptTemplate] = None
|
||||
generation_config: Optional[GenerationConfig] = None
|
||||
|
||||
# HF Model instance
|
||||
hf_model: HFModel
|
||||
|
||||
# Optional features
|
||||
use_draft_model: bool = False
|
||||
|
|
@ -41,7 +42,7 @@ class BaseModelContainer(abc.ABC):
|
|||
# Required methods
|
||||
@classmethod
|
||||
@abc.abstractmethod
|
||||
async def create(cls, model_directory: pathlib.Path, **kwargs):
|
||||
async def create(cls, model_directory: pathlib.Path, hf_model: HFModel, **kwargs):
|
||||
"""
|
||||
Asynchronously creates and initializes a model container instance.
|
||||
|
||||
|
|
|
|||
|
|
@ -4,7 +4,6 @@ import asyncio
|
|||
import gc
|
||||
import math
|
||||
import pathlib
|
||||
import traceback
|
||||
import torch
|
||||
from exllamav2 import (
|
||||
ExLlamaV2,
|
||||
|
|
@ -47,7 +46,7 @@ from common.health import HealthManager
|
|||
from common.multimodal import MultimodalEmbeddingWrapper
|
||||
from common.sampling import BaseSamplerRequest
|
||||
from common.templating import PromptTemplate, find_prompt_template
|
||||
from common.transformers_utils import GenerationConfig, TokenizerConfig
|
||||
from common.transformers_utils import HFModel
|
||||
from common.utils import calculate_rope_alpha, coalesce, unwrap
|
||||
from endpoints.core.types.model import ModelCard, ModelCardParameters
|
||||
|
||||
|
|
@ -58,6 +57,10 @@ class ExllamaV2Container(BaseModelContainer):
|
|||
# Model directories
|
||||
model_dir: pathlib.Path = pathlib.Path("models")
|
||||
draft_model_dir: pathlib.Path = pathlib.Path("models")
|
||||
prompt_template: Optional[PromptTemplate] = None
|
||||
|
||||
# HF model instance
|
||||
hf_model: HFModel
|
||||
|
||||
# Exl2 vars
|
||||
config: Optional[ExLlamaV2Config] = None
|
||||
|
|
@ -79,8 +82,6 @@ class ExllamaV2Container(BaseModelContainer):
|
|||
cache_mode: str = "FP16"
|
||||
draft_cache_mode: str = "FP16"
|
||||
max_batch_size: Optional[int] = None
|
||||
generation_config: Optional[GenerationConfig] = None
|
||||
tokenizer_config: Optional[TokenizerConfig] = None
|
||||
|
||||
# GPU split vars
|
||||
gpu_split: List[float] = []
|
||||
|
|
@ -100,7 +101,7 @@ class ExllamaV2Container(BaseModelContainer):
|
|||
load_condition: asyncio.Condition = asyncio.Condition()
|
||||
|
||||
@classmethod
|
||||
async def create(cls, model_directory: pathlib.Path, **kwargs):
|
||||
async def create(cls, model_directory: pathlib.Path, hf_model: HFModel, **kwargs):
|
||||
"""
|
||||
Primary asynchronous initializer for model container.
|
||||
|
||||
|
|
@ -114,6 +115,7 @@ class ExllamaV2Container(BaseModelContainer):
|
|||
self.config = ExLlamaV2Config()
|
||||
self.model_dir = model_directory
|
||||
self.config.model_dir = str(model_directory.resolve())
|
||||
self.hf_model = hf_model
|
||||
|
||||
# Make the max seq len 4096 before preparing the config
|
||||
# This is a better default than 2048
|
||||
|
|
@ -124,30 +126,6 @@ class ExllamaV2Container(BaseModelContainer):
|
|||
# Check if the model arch is compatible with various exl2 features
|
||||
self.config.arch_compat_overrides()
|
||||
|
||||
# Load generation config overrides
|
||||
generation_config_path = model_directory / "generation_config.json"
|
||||
if generation_config_path.exists():
|
||||
try:
|
||||
self.generation_config = await GenerationConfig.from_file(
|
||||
model_directory
|
||||
)
|
||||
except Exception:
|
||||
logger.error(traceback.format_exc())
|
||||
logger.warning(
|
||||
"Skipping generation config load because of an unexpected error."
|
||||
)
|
||||
|
||||
# Load tokenizer config overrides
|
||||
tokenizer_config_path = model_directory / "tokenizer_config.json"
|
||||
if tokenizer_config_path.exists():
|
||||
try:
|
||||
self.tokenizer_config = await TokenizerConfig.from_file(model_directory)
|
||||
except Exception:
|
||||
logger.error(traceback.format_exc())
|
||||
logger.warning(
|
||||
"Skipping tokenizer config load because of an unexpected error."
|
||||
)
|
||||
|
||||
# Set vision state and error if vision isn't supported on the current model
|
||||
self.use_vision = unwrap(kwargs.get("vision"), False)
|
||||
if self.use_vision and not self.config.vision_model_type:
|
||||
|
|
@ -864,7 +842,7 @@ class ExllamaV2Container(BaseModelContainer):
|
|||
self.tokenizer.encode(
|
||||
text,
|
||||
add_bos=unwrap(
|
||||
kwargs.get("add_bos_token"), self.tokenizer_config.add_bos_token
|
||||
kwargs.get("add_bos_token"), self.hf_model.add_bos_token()
|
||||
),
|
||||
encode_special_tokens=unwrap(kwargs.get("encode_special_tokens"), True),
|
||||
embeddings=mm_embeddings_content,
|
||||
|
|
@ -1282,16 +1260,10 @@ class ExllamaV2Container(BaseModelContainer):
|
|||
ban_eos_token = params.ban_eos_token
|
||||
|
||||
# Set add_bos_token for generation
|
||||
add_bos_token = unwrap(
|
||||
params.add_bos_token, self.tokenizer_config.add_bos_token
|
||||
)
|
||||
add_bos_token = unwrap(params.add_bos_token, self.hf_model.add_bos_token())
|
||||
|
||||
# Fetch EOS tokens from generation_config if they exist
|
||||
eos_tokens = (
|
||||
self.generation_config.eos_tokens()
|
||||
if self.generation_config
|
||||
else [self.tokenizer.eos_token_id]
|
||||
)
|
||||
# Fetch EOS tokens from the HF model if they exist
|
||||
eos_tokens = self.hf_model.eos_tokens() or [self.tokenizer.eos_token_id]
|
||||
|
||||
# Ban the EOS token if specified. If not, append to stop conditions
|
||||
# as well.
|
||||
|
|
|
|||
|
|
@ -2,7 +2,6 @@ import asyncio
|
|||
import gc
|
||||
import pathlib
|
||||
import re
|
||||
import traceback
|
||||
from typing import (
|
||||
Any,
|
||||
AsyncIterator,
|
||||
|
|
@ -35,7 +34,7 @@ from common.health import HealthManager
|
|||
from common.multimodal import MultimodalEmbeddingWrapper
|
||||
from common.sampling import BaseSamplerRequest
|
||||
from common.templating import PromptTemplate, find_prompt_template
|
||||
from common.transformers_utils import GenerationConfig, TokenizerConfig
|
||||
from common.transformers_utils import HFModel
|
||||
from common.utils import coalesce, unwrap
|
||||
from endpoints.core.types.model import ModelCard, ModelCardParameters
|
||||
|
||||
|
|
@ -46,7 +45,9 @@ class ExllamaV3Container(BaseModelContainer):
|
|||
# Exposed model information
|
||||
model_dir: pathlib.Path = pathlib.Path("models")
|
||||
prompt_template: Optional[PromptTemplate] = None
|
||||
generation_config: Optional[GenerationConfig] = None
|
||||
|
||||
# HF Model instance
|
||||
hf_model: HFModel
|
||||
|
||||
# Load synchronization
|
||||
# The bool is a master switch for accepting requests
|
||||
|
|
@ -58,15 +59,14 @@ class ExllamaV3Container(BaseModelContainer):
|
|||
load_condition: asyncio.Condition = asyncio.Condition()
|
||||
|
||||
# Exl3 vars
|
||||
model: Optional[Model]
|
||||
cache: Optional[Cache]
|
||||
draft_model: Optional[Model]
|
||||
draft_cache: Optional[Cache]
|
||||
tokenizer: Optional[Tokenizer]
|
||||
config: Optional[Config]
|
||||
draft_config: Optional[Config]
|
||||
generator: Optional[AsyncGenerator]
|
||||
tokenizer_config: Optional[TokenizerConfig]
|
||||
model: Optional[Model] = None
|
||||
cache: Optional[Cache] = None
|
||||
draft_model: Optional[Model] = None
|
||||
draft_cache: Optional[Cache] = None
|
||||
tokenizer: Optional[Tokenizer] = None
|
||||
config: Optional[Config] = None
|
||||
draft_config: Optional[Config] = None
|
||||
generator: Optional[AsyncGenerator] = None
|
||||
|
||||
# Class-specific vars
|
||||
gpu_split: List[float] | None = None
|
||||
|
|
@ -82,7 +82,7 @@ class ExllamaV3Container(BaseModelContainer):
|
|||
|
||||
# Required methods
|
||||
@classmethod
|
||||
async def create(cls, model_directory: pathlib.Path, **kwargs):
|
||||
async def create(cls, model_directory: pathlib.Path, hf_model: HFModel, **kwargs):
|
||||
"""
|
||||
Asynchronously creates and initializes a model container instance.
|
||||
|
||||
|
|
@ -96,50 +96,17 @@ class ExllamaV3Container(BaseModelContainer):
|
|||
|
||||
self = cls()
|
||||
|
||||
self.model = None
|
||||
self.cache = None
|
||||
self.draft_model = None
|
||||
self.draft_cache = None
|
||||
self.tokenizer = None
|
||||
self.config = None
|
||||
self.draft_config = None
|
||||
self.generator = None
|
||||
self.tokenizer_config = None
|
||||
|
||||
logger.warning(
|
||||
"ExllamaV3 is currently in an alpha state. "
|
||||
"Please note that all config options may not work."
|
||||
)
|
||||
|
||||
self.model_dir = model_directory
|
||||
self.hf_model = hf_model
|
||||
self.config = Config.from_directory(str(model_directory.resolve()))
|
||||
self.model = Model.from_config(self.config)
|
||||
self.tokenizer = Tokenizer.from_config(self.config)
|
||||
|
||||
# Load generation config overrides
|
||||
generation_config_path = model_directory / "generation_config.json"
|
||||
if generation_config_path.exists():
|
||||
try:
|
||||
self.generation_config = await GenerationConfig.from_file(
|
||||
model_directory
|
||||
)
|
||||
except Exception:
|
||||
logger.error(traceback.format_exc())
|
||||
logger.warning(
|
||||
"Skipping generation config load because of an unexpected error."
|
||||
)
|
||||
|
||||
# Load tokenizer config overrides
|
||||
tokenizer_config_path = model_directory / "tokenizer_config.json"
|
||||
if tokenizer_config_path.exists():
|
||||
try:
|
||||
self.tokenizer_config = await TokenizerConfig.from_file(model_directory)
|
||||
except Exception:
|
||||
logger.error(traceback.format_exc())
|
||||
logger.warning(
|
||||
"Skipping tokenizer config load because of an unexpected error."
|
||||
)
|
||||
|
||||
# Fallback to 4096 since exl3 can't fetch from HF's config.json
|
||||
self.max_seq_len = unwrap(kwargs.get("max_seq_len"), 4096)
|
||||
|
||||
|
|
@ -554,7 +521,9 @@ class ExllamaV3Container(BaseModelContainer):
|
|||
return (
|
||||
self.tokenizer.encode(
|
||||
text,
|
||||
add_bos=unwrap(kwargs.get("add_bos_token"), True),
|
||||
add_bos=unwrap(
|
||||
kwargs.get("add_bos_token"), self.hf_model.add_bos_token()
|
||||
),
|
||||
encode_special_tokens=unwrap(kwargs.get("encode_special_tokens"), True),
|
||||
)
|
||||
.flatten()
|
||||
|
|
@ -822,16 +791,10 @@ class ExllamaV3Container(BaseModelContainer):
|
|||
|
||||
prompts = [prompt]
|
||||
stop_conditions = params.stop
|
||||
add_bos_token = unwrap(
|
||||
params.add_bos_token, self.tokenizer_config.add_bos_token
|
||||
)
|
||||
add_bos_token = unwrap(params.add_bos_token, self.hf_model.add_bos_token())
|
||||
|
||||
# Fetch EOS tokens from generation_config if they exist
|
||||
eos_tokens = (
|
||||
self.generation_config.eos_tokens()
|
||||
if self.generation_config
|
||||
else [self.tokenizer.eos_token_id]
|
||||
)
|
||||
eos_tokens = self.hf_model.eos_tokens() or [self.tokenizer.eos_token_id]
|
||||
|
||||
stop_conditions += eos_tokens
|
||||
|
||||
|
|
|
|||
|
|
@ -17,7 +17,7 @@ from common.logger import get_loading_progress_bar
|
|||
from common.networking import handle_request_error
|
||||
from common.tabby_config import config
|
||||
from common.optional_dependencies import dependencies
|
||||
from common.transformers_utils import HuggingFaceConfig
|
||||
from common.transformers_utils import HFModel
|
||||
from common.utils import unwrap
|
||||
|
||||
# Global variables for model container
|
||||
|
|
@ -57,22 +57,15 @@ def load_progress(module, modules):
|
|||
yield module, modules
|
||||
|
||||
|
||||
async def detect_backend(model_path: pathlib.Path) -> str:
|
||||
def detect_backend(hf_model: HFModel) -> str:
|
||||
"""Determine the appropriate backend based on model files and configuration."""
|
||||
|
||||
try:
|
||||
hf_config = await HuggingFaceConfig.from_directory(model_path)
|
||||
quant_method = hf_config.quant_method()
|
||||
quant_method = hf_model.quant_method()
|
||||
|
||||
if quant_method == "exl3":
|
||||
return "exllamav3"
|
||||
else:
|
||||
return "exllamav2"
|
||||
except Exception as exc:
|
||||
raise ValueError(
|
||||
"Failed to read the model's config.json. "
|
||||
f"Please check your model directory at {model_path}."
|
||||
) from exc
|
||||
if quant_method == "exl3":
|
||||
return "exllamav3"
|
||||
else:
|
||||
return "exllamav2"
|
||||
|
||||
|
||||
async def apply_inline_overrides(model_dir: pathlib.Path, **kwargs):
|
||||
|
|
@ -142,28 +135,29 @@ async def load_model_gen(model_path: pathlib.Path, **kwargs):
|
|||
kwargs = {**config.model_defaults, **kwargs}
|
||||
kwargs = await apply_inline_overrides(model_path, **kwargs)
|
||||
|
||||
# Fetch the extra HF configuration options
|
||||
hf_model = await HFModel.from_directory(model_path)
|
||||
|
||||
# Create a new container and check if the right dependencies are installed
|
||||
backend_name = unwrap(
|
||||
kwargs.get("backend"), await detect_backend(model_path)
|
||||
).lower()
|
||||
container_class = _BACKEND_REGISTRY.get(backend_name)
|
||||
backend = unwrap(kwargs.get("backend"), detect_backend(hf_model))
|
||||
container_class = _BACKEND_REGISTRY.get(backend)
|
||||
|
||||
if not container_class:
|
||||
available_backends = list(_BACKEND_REGISTRY.keys())
|
||||
if backend_name in available_backends:
|
||||
if backend in available_backends:
|
||||
raise ValueError(
|
||||
f"Backend '{backend_name}' selected, but required dependencies "
|
||||
f"Backend '{backend}' selected, but required dependencies "
|
||||
"are not installed."
|
||||
)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Invalid backend '{backend_name}'. "
|
||||
f"Invalid backend '{backend}'. "
|
||||
f"Available backends: {available_backends}"
|
||||
)
|
||||
|
||||
logger.info(f"Using backend {backend_name}")
|
||||
logger.info(f"Using backend {backend}")
|
||||
new_container: BaseModelContainer = await container_class.create(
|
||||
model_path.resolve(), **kwargs
|
||||
model_path.resolve(), hf_model, **kwargs
|
||||
)
|
||||
|
||||
# Add possible types of models that can be loaded
|
||||
|
|
|
|||
|
|
@ -1,8 +1,9 @@
|
|||
import aiofiles
|
||||
import json
|
||||
import pathlib
|
||||
from typing import Dict, List, Optional, Union
|
||||
from loguru import logger
|
||||
from pydantic import BaseModel
|
||||
from typing import Dict, List, Optional, Set, Union
|
||||
|
||||
|
||||
class GenerationConfig(BaseModel):
|
||||
|
|
@ -14,7 +15,7 @@ class GenerationConfig(BaseModel):
|
|||
eos_token_id: Optional[Union[int, List[int]]] = None
|
||||
|
||||
@classmethod
|
||||
async def from_file(cls, model_directory: pathlib.Path):
|
||||
async def from_directory(cls, model_directory: pathlib.Path):
|
||||
"""Create an instance from a generation config file."""
|
||||
|
||||
generation_config_path = model_directory / "generation_config.json"
|
||||
|
|
@ -28,10 +29,12 @@ class GenerationConfig(BaseModel):
|
|||
def eos_tokens(self):
|
||||
"""Wrapper method to fetch EOS tokens."""
|
||||
|
||||
if isinstance(self.eos_token_id, int):
|
||||
if isinstance(self.eos_token_id, list):
|
||||
return self.eos_token_id
|
||||
elif isinstance(self.eos_token_id, int):
|
||||
return [self.eos_token_id]
|
||||
else:
|
||||
return self.eos_token_id
|
||||
return []
|
||||
|
||||
|
||||
class HuggingFaceConfig(BaseModel):
|
||||
|
|
@ -42,6 +45,7 @@ class HuggingFaceConfig(BaseModel):
|
|||
Will be expanded as needed.
|
||||
"""
|
||||
|
||||
eos_token_id: Optional[Union[int, List[int]]] = None
|
||||
quantization_config: Optional[Dict] = None
|
||||
|
||||
@classmethod
|
||||
|
|
@ -64,6 +68,16 @@ class HuggingFaceConfig(BaseModel):
|
|||
else:
|
||||
return None
|
||||
|
||||
def eos_tokens(self):
|
||||
"""Wrapper method to fetch EOS tokens."""
|
||||
|
||||
if isinstance(self.eos_token_id, list):
|
||||
return self.eos_token_id
|
||||
elif isinstance(self.eos_token_id, int):
|
||||
return [self.eos_token_id]
|
||||
else:
|
||||
return []
|
||||
|
||||
|
||||
class TokenizerConfig(BaseModel):
|
||||
"""
|
||||
|
|
@ -73,7 +87,7 @@ class TokenizerConfig(BaseModel):
|
|||
add_bos_token: Optional[bool] = True
|
||||
|
||||
@classmethod
|
||||
async def from_file(cls, model_directory: pathlib.Path):
|
||||
async def from_directory(cls, model_directory: pathlib.Path):
|
||||
"""Create an instance from a tokenizer config file."""
|
||||
|
||||
tokenizer_config_path = model_directory / "tokenizer_config.json"
|
||||
|
|
@ -83,3 +97,81 @@ class TokenizerConfig(BaseModel):
|
|||
contents = await tokenizer_config_json.read()
|
||||
tokenizer_config_dict = json.loads(contents)
|
||||
return cls.model_validate(tokenizer_config_dict)
|
||||
|
||||
|
||||
class HFModel:
|
||||
"""
|
||||
Unified container for HuggingFace model configuration files.
|
||||
These are abridged for hyper-specific model parameters not covered
|
||||
by most backends.
|
||||
|
||||
Includes:
|
||||
- config.json
|
||||
- generation_config.json
|
||||
- tokenizer_config.json
|
||||
"""
|
||||
|
||||
hf_config: HuggingFaceConfig
|
||||
tokenizer_config: Optional[TokenizerConfig] = None
|
||||
generation_config: Optional[GenerationConfig] = None
|
||||
|
||||
@classmethod
|
||||
async def from_directory(cls, model_directory: pathlib.Path):
|
||||
"""Create an instance from a model directory"""
|
||||
|
||||
self = cls()
|
||||
|
||||
# A model must have an HF config
|
||||
try:
|
||||
self.hf_config = await HuggingFaceConfig.from_directory(model_directory)
|
||||
except Exception as exc:
|
||||
raise ValueError(
|
||||
f"Failed to load config.json from {model_directory}"
|
||||
) from exc
|
||||
|
||||
try:
|
||||
self.generation_config = await GenerationConfig.from_directory(
|
||||
model_directory
|
||||
)
|
||||
except Exception:
|
||||
logger.warning(
|
||||
"Generation config file not found in model directory, skipping."
|
||||
)
|
||||
|
||||
try:
|
||||
self.tokenizer_config = await TokenizerConfig.from_directory(
|
||||
model_directory
|
||||
)
|
||||
except Exception:
|
||||
logger.warning(
|
||||
"Tokenizer config file not found in model directory, skipping."
|
||||
)
|
||||
|
||||
return self
|
||||
|
||||
def quant_method(self):
|
||||
"""Wrapper for quantization method"""
|
||||
|
||||
return self.hf_config.quant_method()
|
||||
|
||||
def eos_tokens(self):
|
||||
"""Combines and returns EOS tokens from various configs"""
|
||||
|
||||
eos_ids: Set[int] = set()
|
||||
|
||||
eos_ids.update(self.hf_config.eos_tokens())
|
||||
|
||||
if self.generation_config:
|
||||
eos_ids.update(self.generation_config.eos_tokens())
|
||||
|
||||
# Convert back to a list
|
||||
return list(eos_ids)
|
||||
|
||||
def add_bos_token(self):
|
||||
"""Wrapper for tokenizer config"""
|
||||
|
||||
if self.tokenizer_config:
|
||||
return self.tokenizer_config.add_bos_token
|
||||
|
||||
# Expected default
|
||||
return True
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue