diff --git a/backends/exllamav2/model.py b/backends/exllamav2/model.py
index 7554ae3..004f5e4 100644
--- a/backends/exllamav2/model.py
+++ b/backends/exllamav2/model.py
@@ -50,7 +50,7 @@ from common.health import HealthManager
 from common.multimodal import MultimodalEmbeddingWrapper
 from common.sampling import BaseSamplerRequest
 from common.templating import PromptTemplate, find_prompt_template
-from common.transformers_utils import GenerationConfig
+from common.transformers_utils import GenerationConfig, TokenizerConfig
 from common.utils import calculate_rope_alpha, coalesce, unwrap
 
 from endpoints.core.types.model import ModelCard, ModelCardParameters
@@ -80,6 +80,7 @@ class ExllamaV2Container(BaseModelContainer):
     draft_cache_mode: str = "FP16"
     max_batch_size: Optional[int] = None
     generation_config: Optional[GenerationConfig] = None
+    tokenizer_config: Optional[TokenizerConfig] = None
 
     # GPU split vars
     gpu_split: List[float] = []
@@ -130,7 +131,7 @@ class ExllamaV2Container(BaseModelContainer):
         if generation_config_path.exists():
             try:
                 self.generation_config = await GenerationConfig.from_file(
-                    generation_config_path.parent
+                    model_directory
                 )
             except Exception:
                 logger.error(traceback.format_exc())
@@ -138,6 +139,19 @@
                     "Skipping generation config load because of an unexpected error."
                 )
 
+        # Load tokenizer config overrides
+        tokenizer_config_path = model_directory / "tokenizer_config.json"
+        if tokenizer_config_path.exists():
+            try:
+                self.tokenizer_config = await TokenizerConfig.from_file(
+                    model_directory
+                )
+            except Exception:
+                logger.error(traceback.format_exc())
+                logger.warning(
+                    "Skipping tokenizer config load because of an unexpected error."
+                )
+
         # Set vision state and error if vision isn't supported on the current model
         self.use_vision = unwrap(kwargs.get("vision"), False)
         if self.use_vision and not self.config.vision_model_type:
@@ -1240,9 +1254,16 @@ class ExllamaV2Container(BaseModelContainer):
         ) and gen_settings.token_repetition_range == -1
 
         stop_conditions = params.stop
-        add_bos_token = unwrap(params.add_bos_token, True)
         ban_eos_token = params.ban_eos_token
+
+        # Set add_bos_token for generation, preferring the request value,
+        # then the model's tokenizer config, and finally True
+        add_bos_token = coalesce(
+            params.add_bos_token,
+            self.tokenizer_config.add_bos_token if self.tokenizer_config else None,
+            True,
+        )
 
         # Fetch EOS tokens from generation_config if they exist
         eos_tokens = (
             self.generation_config.eos_tokens()
diff --git a/common/templating.py b/common/templating.py
index d71de11..3a9347b 100644
--- a/common/templating.py
+++ b/common/templating.py
@@ -239,6 +239,7 @@ async def find_prompt_template(template_name, model_dir: pathlib.Path):
     ]
 
     # Add lookup from prompt template name if provided
+    # TODO: Possibly link to the TokenizerConfig class
     if template_name:
         find_template_functions[:0] = [
             lambda: PromptTemplate.from_file(pathlib.Path("templates") / template_name),
diff --git a/common/transformers_utils.py b/common/transformers_utils.py
index a765b9f..32ff52e 100644
--- a/common/transformers_utils.py
+++ b/common/transformers_utils.py
@@ -53,3 +53,23 @@ class HuggingFaceConfig(BaseModel):
         contents = await hf_config_json.read()
         hf_config_dict = json.loads(contents)
         return cls.model_validate(hf_config_dict)
+
+
+class TokenizerConfig(BaseModel):
+    """
+    An abridged version of HuggingFace's tokenizer config.
+ """ + + add_bos_token: Optional[bool] = None + + @classmethod + async def from_file(cls, model_directory: pathlib.Path): + """Create an instance from a tokenizer config file.""" + + tokenizer_config_path = model_directory / "tokenizer_config.json" + async with aiofiles.open( + tokenizer_config_path, "r", encoding="utf8" + ) as tokenizer_config_json: + contents = await tokenizer_config_json.read() + tokenizer_config_dict = json.loads(contents) + return cls.model_validate(tokenizer_config_dict)