Model: Fetch from generation_config and tokenizer_config in Exl3

Signed-off-by: kingbri <8082010+kingbri1@users.noreply.github.com>
This commit is contained in:
kingbri 2025-05-02 00:16:11 -04:00
parent 59d081fe83
commit e8f00412f6
2 changed files with 31 additions and 3 deletions

View file

@ -120,7 +120,6 @@ class ExllamaV2Container(BaseModelContainer):
self.config.max_seq_len = 4096
self.config.prepare()
print(self.config.max_seq_len)
# Check if the model arch is compatible with various exl2 features
self.config.arch_compat_overrides()

View file

@ -1,6 +1,7 @@
import asyncio
import gc
import pathlib
import traceback
from typing import (
Any,
AsyncIterator,
@ -32,7 +33,7 @@ from common.health import HealthManager
from common.multimodal import MultimodalEmbeddingWrapper
from common.sampling import BaseSamplerRequest
from common.templating import PromptTemplate, find_prompt_template
from common.transformers_utils import GenerationConfig
from common.transformers_utils import GenerationConfig, TokenizerConfig
from common.utils import coalesce, unwrap
from endpoints.core.types.model import ModelCard, ModelCardParameters
@ -60,6 +61,8 @@ class ExllamaV3Container(BaseModelContainer):
tokenizer: Tokenizer
config: Config
generator: Optional[AsyncGenerator] = None
generation_config: Optional[GenerationConfig] = None
tokenizer_config: Optional[TokenizerConfig] = None
# Class-specific vars
gpu_split: List[float] | None = None
@ -96,6 +99,30 @@ class ExllamaV3Container(BaseModelContainer):
self.model = Model.from_config(self.config)
self.tokenizer = Tokenizer.from_config(self.config)
# Load generation config overrides
generation_config_path = model_directory / "generation_config.json"
if generation_config_path.exists():
try:
self.generation_config = await GenerationConfig.from_file(
model_directory
)
except Exception:
logger.error(traceback.format_exc())
logger.warning(
"Skipping generation config load because of an unexpected error."
)
# Load tokenizer config overrides
tokenizer_config_path = model_directory / "tokenizer_config.json"
if tokenizer_config_path.exists():
try:
self.tokenizer_config = await TokenizerConfig.from_file(model_directory)
except Exception:
logger.error(traceback.format_exc())
logger.warning(
"Skipping tokenizer config load because of an unexpected error."
)
# Fallback to 4096 since exl3 can't fetch from HF's config.json
self.max_seq_len = unwrap(kwargs.get("max_seq_len"), 4096)
@ -689,7 +716,9 @@ class ExllamaV3Container(BaseModelContainer):
prompts = [prompt]
stop_conditions = params.stop
add_bos_token = params.add_bos_token
add_bos_token = unwrap(
params.add_bos_token, self.tokenizer_config.add_bos_token, True
)
# Fetch EOS tokens from generation_config if they exist
eos_tokens = (