diff --git a/backends/base_model_container.py b/backends/base_model_container.py
index 10eae0d..96393ab 100644
--- a/backends/base_model_container.py
+++ b/backends/base_model_container.py
@@ -9,11 +9,10 @@ from typing import (
     List,
     Optional,
 )
-
 from common.multimodal import MultimodalEmbeddingWrapper
 from common.sampling import BaseSamplerRequest
 from common.templating import PromptTemplate
-from common.transformers_utils import GenerationConfig
+from common.transformers_utils import HFModel
 from endpoints.core.types.model import ModelCard
 
 
@@ -23,7 +22,9 @@ class BaseModelContainer(abc.ABC):
     # Exposed model information
     model_dir: pathlib.Path = pathlib.Path("models")
     prompt_template: Optional[PromptTemplate] = None
-    generation_config: Optional[GenerationConfig] = None
+
+    # HF Model instance
+    hf_model: HFModel
 
     # Optional features
     use_draft_model: bool = False
@@ -41,7 +42,7 @@ class BaseModelContainer(abc.ABC):
     # Required methods
     @classmethod
     @abc.abstractmethod
-    async def create(cls, model_directory: pathlib.Path, **kwargs):
+    async def create(cls, model_directory: pathlib.Path, hf_model: HFModel, **kwargs):
         """
         Asynchronously creates and initializes a model container instance.
 
diff --git a/backends/exllamav2/model.py b/backends/exllamav2/model.py
index 9a7f30c..89fa005 100644
--- a/backends/exllamav2/model.py
+++ b/backends/exllamav2/model.py
@@ -4,7 +4,6 @@ import asyncio
 import gc
 import math
 import pathlib
-import traceback
 import torch
 from exllamav2 import (
     ExLlamaV2,
@@ -47,7 +46,7 @@ from common.health import HealthManager
 from common.multimodal import MultimodalEmbeddingWrapper
 from common.sampling import BaseSamplerRequest
 from common.templating import PromptTemplate, find_prompt_template
-from common.transformers_utils import GenerationConfig, TokenizerConfig
+from common.transformers_utils import HFModel
 from common.utils import calculate_rope_alpha, coalesce, unwrap
 from endpoints.core.types.model import ModelCard, ModelCardParameters
 
@@ -58,6 +57,10 @@ class ExllamaV2Container(BaseModelContainer):
     # Model directories
     model_dir: pathlib.Path = pathlib.Path("models")
     draft_model_dir: pathlib.Path = pathlib.Path("models")
+    prompt_template: Optional[PromptTemplate] = None
+
+    # HF model instance
+    hf_model: HFModel
 
     # Exl2 vars
     config: Optional[ExLlamaV2Config] = None
@@ -79,8 +82,6 @@ class ExllamaV2Container(BaseModelContainer):
     cache_mode: str = "FP16"
     draft_cache_mode: str = "FP16"
     max_batch_size: Optional[int] = None
-    generation_config: Optional[GenerationConfig] = None
-    tokenizer_config: Optional[TokenizerConfig] = None
 
     # GPU split vars
     gpu_split: List[float] = []
@@ -100,7 +101,7 @@ class ExllamaV2Container(BaseModelContainer):
     load_condition: asyncio.Condition = asyncio.Condition()
 
     @classmethod
-    async def create(cls, model_directory: pathlib.Path, **kwargs):
+    async def create(cls, model_directory: pathlib.Path, hf_model: HFModel, **kwargs):
         """
         Primary asynchronous initializer for model container.
 
@@ -114,6 +115,7 @@ class ExllamaV2Container(BaseModelContainer):
         self.config = ExLlamaV2Config()
         self.model_dir = model_directory
         self.config.model_dir = str(model_directory.resolve())
+        self.hf_model = hf_model
 
         # Make the max seq len 4096 before preparing the config
         # This is a better default than 2048
@@ -124,30 +126,6 @@ class ExllamaV2Container(BaseModelContainer):
         # Check if the model arch is compatible with various exl2 features
         self.config.arch_compat_overrides()
 
-        # Load generation config overrides
-        generation_config_path = model_directory / "generation_config.json"
-        if generation_config_path.exists():
-            try:
-                self.generation_config = await GenerationConfig.from_file(
-                    model_directory
-                )
-            except Exception:
-                logger.error(traceback.format_exc())
-                logger.warning(
-                    "Skipping generation config load because of an unexpected error."
-                )
-
-        # Load tokenizer config overrides
-        tokenizer_config_path = model_directory / "tokenizer_config.json"
-        if tokenizer_config_path.exists():
-            try:
-                self.tokenizer_config = await TokenizerConfig.from_file(model_directory)
-            except Exception:
-                logger.error(traceback.format_exc())
-                logger.warning(
-                    "Skipping tokenizer config load because of an unexpected error."
-                )
-
         # Set vision state and error if vision isn't supported on the current model
         self.use_vision = unwrap(kwargs.get("vision"), False)
         if self.use_vision and not self.config.vision_model_type:
@@ -864,7 +842,7 @@ class ExllamaV2Container(BaseModelContainer):
             self.tokenizer.encode(
                 text,
                 add_bos=unwrap(
-                    kwargs.get("add_bos_token"), self.tokenizer_config.add_bos_token
+                    kwargs.get("add_bos_token"), self.hf_model.add_bos_token()
                 ),
                 encode_special_tokens=unwrap(kwargs.get("encode_special_tokens"), True),
                 embeddings=mm_embeddings_content,
@@ -1282,16 +1260,10 @@ class ExllamaV2Container(BaseModelContainer):
             ban_eos_token = params.ban_eos_token
 
         # Set add_bos_token for generation
-        add_bos_token = unwrap(
-            params.add_bos_token, self.tokenizer_config.add_bos_token
-        )
+        add_bos_token = unwrap(params.add_bos_token, self.hf_model.add_bos_token())
 
-        # Fetch EOS tokens from generation_config if they exist
-        eos_tokens = (
-            self.generation_config.eos_tokens()
-            if self.generation_config
-            else [self.tokenizer.eos_token_id]
-        )
+        # Fetch EOS tokens from the HF model if they exist
+        eos_tokens = self.hf_model.eos_tokens() or [self.tokenizer.eos_token_id]
 
         # Ban the EOS token if specified. If not, append to stop conditions
         # as well.
diff --git a/backends/exllamav3/model.py b/backends/exllamav3/model.py
index c386a6e..52536fe 100644
--- a/backends/exllamav3/model.py
+++ b/backends/exllamav3/model.py
@@ -2,7 +2,6 @@ import asyncio
 import gc
 import pathlib
 import re
-import traceback
 from typing import (
     Any,
     AsyncIterator,
@@ -35,7 +34,7 @@ from common.health import HealthManager
 from common.multimodal import MultimodalEmbeddingWrapper
 from common.sampling import BaseSamplerRequest
 from common.templating import PromptTemplate, find_prompt_template
-from common.transformers_utils import GenerationConfig, TokenizerConfig
+from common.transformers_utils import HFModel
 from common.utils import coalesce, unwrap
 from endpoints.core.types.model import ModelCard, ModelCardParameters
 
@@ -46,7 +45,9 @@ class ExllamaV3Container(BaseModelContainer):
     # Exposed model information
     model_dir: pathlib.Path = pathlib.Path("models")
     prompt_template: Optional[PromptTemplate] = None
-    generation_config: Optional[GenerationConfig] = None
+
+    # HF Model instance
+    hf_model: HFModel
 
     # Load synchronization
     # The bool is a master switch for accepting requests
@@ -58,15 +59,14 @@ class ExllamaV3Container(BaseModelContainer):
     load_condition: asyncio.Condition = asyncio.Condition()
 
     # Exl3 vars
-    model: Optional[Model]
-    cache: Optional[Cache]
-    draft_model: Optional[Model]
-    draft_cache: Optional[Cache]
-    tokenizer: Optional[Tokenizer]
-    config: Optional[Config]
-    draft_config: Optional[Config]
-    generator: Optional[AsyncGenerator]
-    tokenizer_config: Optional[TokenizerConfig]
+    model: Optional[Model] = None
+    cache: Optional[Cache] = None
+    draft_model: Optional[Model] = None
+    draft_cache: Optional[Cache] = None
+    tokenizer: Optional[Tokenizer] = None
+    config: Optional[Config] = None
+    draft_config: Optional[Config] = None
+    generator: Optional[AsyncGenerator] = None
 
     # Class-specific vars
     gpu_split: List[float] | None = None
@@ -82,7 +82,7 @@ class ExllamaV3Container(BaseModelContainer):
 
     # Required methods
     @classmethod
-    async def create(cls, model_directory: pathlib.Path, **kwargs):
+    async def create(cls, model_directory: pathlib.Path, hf_model: HFModel, **kwargs):
         """
         Asynchronously creates and initializes a model container instance.
 
@@ -96,50 +96,17 @@ class ExllamaV3Container(BaseModelContainer):
 
         self = cls()
 
-        self.model = None
-        self.cache = None
-        self.draft_model = None
-        self.draft_cache = None
-        self.tokenizer = None
-        self.config = None
-        self.draft_config = None
-        self.generator = None
-        self.tokenizer_config = None
-
         logger.warning(
             "ExllamaV3 is currently in an alpha state. "
             "Please note that all config options may not work."
        )
 
         self.model_dir = model_directory
+        self.hf_model = hf_model
         self.config = Config.from_directory(str(model_directory.resolve()))
         self.model = Model.from_config(self.config)
         self.tokenizer = Tokenizer.from_config(self.config)
 
-        # Load generation config overrides
-        generation_config_path = model_directory / "generation_config.json"
-        if generation_config_path.exists():
-            try:
-                self.generation_config = await GenerationConfig.from_file(
-                    model_directory
-                )
-            except Exception:
-                logger.error(traceback.format_exc())
-                logger.warning(
-                    "Skipping generation config load because of an unexpected error."
-                )
-
-        # Load tokenizer config overrides
-        tokenizer_config_path = model_directory / "tokenizer_config.json"
-        if tokenizer_config_path.exists():
-            try:
-                self.tokenizer_config = await TokenizerConfig.from_file(model_directory)
-            except Exception:
-                logger.error(traceback.format_exc())
-                logger.warning(
-                    "Skipping tokenizer config load because of an unexpected error."
-                )
-
         # Fallback to 4096 since exl3 can't fetch from HF's config.json
         self.max_seq_len = unwrap(kwargs.get("max_seq_len"), 4096)
 
@@ -554,7 +521,9 @@ class ExllamaV3Container(BaseModelContainer):
         return (
             self.tokenizer.encode(
                 text,
-                add_bos=unwrap(kwargs.get("add_bos_token"), True),
+                add_bos=unwrap(
+                    kwargs.get("add_bos_token"), self.hf_model.add_bos_token()
+                ),
                 encode_special_tokens=unwrap(kwargs.get("encode_special_tokens"), True),
             )
             .flatten()
@@ -822,16 +791,10 @@ class ExllamaV3Container(BaseModelContainer):
             prompts = [prompt]
 
         stop_conditions = params.stop
-        add_bos_token = unwrap(
-            params.add_bos_token, self.tokenizer_config.add_bos_token
-        )
+        add_bos_token = unwrap(params.add_bos_token, self.hf_model.add_bos_token())
 
         # Fetch EOS tokens from generation_config if they exist
-        eos_tokens = (
-            self.generation_config.eos_tokens()
-            if self.generation_config
-            else [self.tokenizer.eos_token_id]
-        )
+        eos_tokens = self.hf_model.eos_tokens() or [self.tokenizer.eos_token_id]
 
         stop_conditions += eos_tokens
 
diff --git a/common/model.py b/common/model.py
index 9cdfdeb..0241d4a 100644
--- a/common/model.py
+++ b/common/model.py
@@ -17,7 +17,7 @@ from common.logger import get_loading_progress_bar
 from common.networking import handle_request_error
 from common.tabby_config import config
 from common.optional_dependencies import dependencies
-from common.transformers_utils import HuggingFaceConfig
+from common.transformers_utils import HFModel
 from common.utils import unwrap
 
 # Global variables for model container
@@ -57,22 +57,15 @@ def load_progress(module, modules):
     yield module, modules
 
 
-async def detect_backend(model_path: pathlib.Path) -> str:
+def detect_backend(hf_model: HFModel) -> str:
     """Determine the appropriate backend based on model files and configuration."""
 
-    try:
-        hf_config = await HuggingFaceConfig.from_directory(model_path)
-        quant_method = hf_config.quant_method()
+    quant_method = hf_model.quant_method()
 
-        if quant_method == "exl3":
-            return "exllamav3"
-        else:
-            return "exllamav2"
-    except Exception as exc:
-        raise ValueError(
-            "Failed to read the model's config.json. "
-            f"Please check your model directory at {model_path}."
-        ) from exc
+    if quant_method == "exl3":
+        return "exllamav3"
+    else:
+        return "exllamav2"
 
 
 async def apply_inline_overrides(model_dir: pathlib.Path, **kwargs):
@@ -142,28 +135,29 @@ async def load_model_gen(model_path: pathlib.Path, **kwargs):
     kwargs = {**config.model_defaults, **kwargs}
     kwargs = await apply_inline_overrides(model_path, **kwargs)
 
+    # Fetch the extra HF configuration options
+    hf_model = await HFModel.from_directory(model_path)
+
     # Create a new container and check if the right dependencies are installed
-    backend_name = unwrap(
-        kwargs.get("backend"), await detect_backend(model_path)
-    ).lower()
-    container_class = _BACKEND_REGISTRY.get(backend_name)
+    backend = unwrap(kwargs.get("backend"), detect_backend(hf_model))
+    container_class = _BACKEND_REGISTRY.get(backend)
 
     if not container_class:
         available_backends = list(_BACKEND_REGISTRY.keys())
-        if backend_name in available_backends:
+        if backend in available_backends:
             raise ValueError(
-                f"Backend '{backend_name}' selected, but required dependencies "
+                f"Backend '{backend}' selected, but required dependencies "
                 "are not installed."
             )
         else:
             raise ValueError(
-                f"Invalid backend '{backend_name}'. "
+                f"Invalid backend '{backend}'. "
                 f"Available backends: {available_backends}"
             )
 
-    logger.info(f"Using backend {backend_name}")
+    logger.info(f"Using backend {backend}")
     new_container: BaseModelContainer = await container_class.create(
-        model_path.resolve(), **kwargs
+        model_path.resolve(), hf_model, **kwargs
     )
 
     # Add possible types of models that can be loaded
diff --git a/common/transformers_utils.py b/common/transformers_utils.py
index d1e5ac1..6cac3b8 100644
--- a/common/transformers_utils.py
+++ b/common/transformers_utils.py
@@ -1,8 +1,9 @@
 import aiofiles
 import json
 import pathlib
-from typing import Dict, List, Optional, Union
+from loguru import logger
 from pydantic import BaseModel
+from typing import Dict, List, Optional, Set, Union
 
 
 class GenerationConfig(BaseModel):
@@ -14,7 +15,7 @@ class GenerationConfig(BaseModel):
     eos_token_id: Optional[Union[int, List[int]]] = None
 
     @classmethod
-    async def from_file(cls, model_directory: pathlib.Path):
+    async def from_directory(cls, model_directory: pathlib.Path):
         """Create an instance from a generation config file."""
 
         generation_config_path = model_directory / "generation_config.json"
@@ -28,10 +29,12 @@ class GenerationConfig(BaseModel):
     def eos_tokens(self):
         """Wrapper method to fetch EOS tokens."""
 
-        if isinstance(self.eos_token_id, int):
+        if isinstance(self.eos_token_id, list):
+            return self.eos_token_id
+        elif isinstance(self.eos_token_id, int):
             return [self.eos_token_id]
         else:
-            return self.eos_token_id
+            return []
 
 
 class HuggingFaceConfig(BaseModel):
@@ -42,6 +45,7 @@ class HuggingFaceConfig(BaseModel):
     """
     Model config subset from HF's model files.
     Will be expanded as needed.
""" + eos_token_id: Optional[Union[int, List[int]]] = None quantization_config: Optional[Dict] = None @classmethod @@ -64,6 +68,16 @@ class HuggingFaceConfig(BaseModel): else: return None + def eos_tokens(self): + """Wrapper method to fetch EOS tokens.""" + + if isinstance(self.eos_token_id, list): + return self.eos_token_id + elif isinstance(self.eos_token_id, int): + return [self.eos_token_id] + else: + return [] + class TokenizerConfig(BaseModel): """ @@ -73,7 +87,7 @@ class TokenizerConfig(BaseModel): add_bos_token: Optional[bool] = True @classmethod - async def from_file(cls, model_directory: pathlib.Path): + async def from_directory(cls, model_directory: pathlib.Path): """Create an instance from a tokenizer config file.""" tokenizer_config_path = model_directory / "tokenizer_config.json" @@ -83,3 +97,81 @@ class TokenizerConfig(BaseModel): contents = await tokenizer_config_json.read() tokenizer_config_dict = json.loads(contents) return cls.model_validate(tokenizer_config_dict) + + +class HFModel: + """ + Unified container for HuggingFace model configuration files. + These are abridged for hyper-specific model parameters not covered + by most backends. + + Includes: + - config.json + - generation_config.json + - tokenizer_config.json + """ + + hf_config: HuggingFaceConfig + tokenizer_config: Optional[TokenizerConfig] = None + generation_config: Optional[GenerationConfig] = None + + @classmethod + async def from_directory(cls, model_directory: pathlib.Path): + """Create an instance from a model directory""" + + self = cls() + + # A model must have an HF config + try: + self.hf_config = await HuggingFaceConfig.from_directory(model_directory) + except Exception as exc: + raise ValueError( + f"Failed to load config.json from {model_directory}" + ) from exc + + try: + self.generation_config = await GenerationConfig.from_directory( + model_directory + ) + except Exception: + logger.warning( + "Generation config file not found in model directory, skipping." + ) + + try: + self.tokenizer_config = await TokenizerConfig.from_directory( + model_directory + ) + except Exception: + logger.warning( + "Tokenizer config file not found in model directory, skipping." + ) + + return self + + def quant_method(self): + """Wrapper for quantization method""" + + return self.hf_config.quant_method() + + def eos_tokens(self): + """Combines and returns EOS tokens from various configs""" + + eos_ids: Set[int] = set() + + eos_ids.update(self.hf_config.eos_tokens()) + + if self.generation_config: + eos_ids.update(self.generation_config.eos_tokens()) + + # Convert back to a list + return list(eos_ids) + + def add_bos_token(self): + """Wrapper for tokenizer config""" + + if self.tokenizer_config: + return self.tokenizer_config.add_bos_token + + # Expected default + return True