diff --git a/backends/exllamav2/model.py b/backends/exllamav2/model.py index 3408c2f..0cf1076 100644 --- a/backends/exllamav2/model.py +++ b/backends/exllamav2/model.py @@ -49,11 +49,7 @@ from common.gen_logging import ( from common.health import HealthManager from common.multimodal import MultimodalEmbeddingWrapper from common.sampling import BaseSamplerRequest -from common.templating import ( - PromptTemplate, - TemplateLoadError, - find_template_from_model, -) +from common.templating import PromptTemplate, find_prompt_template from common.transformers_utils import GenerationConfig from common.utils import calculate_rope_alpha, coalesce, unwrap @@ -322,7 +318,7 @@ class ExllamaV2Container(BaseModelContainer): self.cache_size = self.config.max_seq_len # Try to set prompt template - self.prompt_template = await self.find_prompt_template( + self.prompt_template = await find_prompt_template( kwargs.get("prompt_template"), model_directory ) @@ -383,62 +379,6 @@ class ExllamaV2Container(BaseModelContainer): # Return the created instance return self - async def find_prompt_template(self, prompt_template_name, model_directory): - """Tries to find a prompt template using various methods.""" - - logger.info("Attempting to load a prompt template if present.") - - find_template_functions = [ - lambda: PromptTemplate.from_model_json( - pathlib.Path(self.config.model_dir) / "chat_template.json", - key="chat_template", - ), - lambda: PromptTemplate.from_model_json( - pathlib.Path(self.config.model_dir) / "tokenizer_config.json", - key="chat_template", - ), - lambda: PromptTemplate.from_file(find_template_from_model(model_directory)), - ] - - # Find the template in the model directory if it exists - model_dir_template_path = ( - pathlib.Path(self.config.model_dir) / "tabby_template.jinja" - ) - if model_dir_template_path.exists(): - find_template_functions[:0] = [ - lambda: PromptTemplate.from_file(model_dir_template_path) - ] - - # Add lookup from prompt template name if provided - if prompt_template_name: - find_template_functions[:0] = [ - lambda: PromptTemplate.from_file( - pathlib.Path("templates") / prompt_template_name - ), - lambda: PromptTemplate.from_model_json( - pathlib.Path(self.config.model_dir) / "tokenizer_config.json", - key="chat_template", - name=prompt_template_name, - ), - ] - - # Continue on exception since functions are tried as they fail - for template_func in find_template_functions: - try: - prompt_template = await template_func() - if prompt_template is not None: - return prompt_template - except TemplateLoadError as e: - logger.warning(f"TemplateLoadError: {str(e)}") - continue - except Exception: - logger.error(traceback.format_exc()) - logger.warning( - "An unexpected error happened when trying to load the template. " - "Trying other methods." - ) - continue - def get_model_parameters(self): model_params = { "name": self.model_dir.name, diff --git a/common/templating.py b/common/templating.py index 8ea4fd1..ff21300 100644 --- a/common/templating.py +++ b/common/templating.py @@ -1,5 +1,6 @@ """Small replication of AutoTokenizer's chat template system for efficiency""" +import traceback import aiofiles import json import pathlib @@ -211,3 +212,58 @@ def find_template_from_model(model_path: pathlib.Path): return template_name else: raise TemplateLoadError("Could not find template from model name.") + + +async def find_prompt_template(template_name, model_dir: pathlib.Path): + """Tries to find a prompt template using various methods.""" + + logger.info("Attempting to load a prompt template if present.") + + find_template_functions = [ + lambda: PromptTemplate.from_model_json( + model_dir / "chat_template.json", + key="chat_template", + ), + lambda: PromptTemplate.from_model_json( + model_dir / "tokenizer_config.json", + key="chat_template", + ), + lambda: PromptTemplate.from_file(find_template_from_model(model_dir)), + ] + + # Find the template in the model directory if it exists + model_dir_template_path = model_dir / "tabby_template.jinja" + if model_dir_template_path.exists(): + find_template_functions[:0] = [ + lambda: PromptTemplate.from_file(model_dir_template_path) + ] + + # Add lookup from prompt template name if provided + if template_name: + find_template_functions[:0] = [ + lambda: PromptTemplate.from_file( + pathlib.Path("templates") / template_name + ), + lambda: PromptTemplate.from_model_json( + model_dir / "tokenizer_config.json", + key="chat_template", + name=template_name, + ), + ] + + # Continue on exception since functions are tried as they fail + for template_func in find_template_functions: + try: + prompt_template = await template_func() + if prompt_template is not None: + return prompt_template + except TemplateLoadError as e: + logger.warning(f"TemplateLoadError: {str(e)}") + continue + except Exception: + logger.error(traceback.format_exc()) + logger.warning( + "An unexpected error happened when trying to load the template. " + "Trying other methods." + ) + continue