diff --git a/backends/exllamav2/model.py b/backends/exllamav2/model.py
index 3cf4400..fc0ba0a 100644
--- a/backends/exllamav2/model.py
+++ b/backends/exllamav2/model.py
@@ -878,6 +878,7 @@ class ExllamaV2Container:
             encode_special_tokens=True,
             return_offsets=True,
         )
+        print(ids)
         mask = (
             self.tokenizer.padding_mask(ids)
             if self.use_cfg and gen_settings.cfg_scale not in [None, 1.0]
diff --git a/endpoints/OAI/types/chat_completion.py b/endpoints/OAI/types/chat_completion.py
index 5c1151f..92265a7 100644
--- a/endpoints/OAI/types/chat_completion.py
+++ b/endpoints/OAI/types/chat_completion.py
@@ -45,6 +45,7 @@ class ChatCompletionRequest(CommonCompletionRequest):
     prompt_template: Optional[str] = None
     add_generation_prompt: Optional[bool] = True
     template_vars: Optional[dict] = {}
+    response_prefix: Optional[str] = None
 
 
 class ChatCompletionResponse(BaseModel):
diff --git a/endpoints/OAI/utils/chat_completion.py b/endpoints/OAI/utils/chat_completion.py
index 155f806..df93447 100644
--- a/endpoints/OAI/utils/chat_completion.py
+++ b/endpoints/OAI/utils/chat_completion.py
@@ -8,6 +8,7 @@ from uuid import uuid4
 
 from fastapi import HTTPException
 from jinja2 import TemplateError
+from loguru import logger
 
 from common import model
 from common.networking import (
@@ -153,6 +154,22 @@ def format_prompt_with_template(data: ChatCompletionRequest):
         data.template_vars
     )
 
+    # Append response prefix if present
+    if data.response_prefix:
+        if data.add_generation_prompt:
+            prompt += data.response_prefix
+        else:
+            logger.warning(
+                "Could not add response prefix because "
+                "add_generation_prompt is False"
+            )
+
+    # Removes the starting BOS token if present
+    # This is to prevent add_bos_token from adding multiple bos tokens
+    bos_token = special_tokens_dict.get("bos_token")
+    if bos_token and prompt.startswith(bos_token):
+        prompt = prompt.removeprefix(bos_token)
+
     # Append template stop strings
     if isinstance(data.stop, str):
         data.stop = [data.stop] + template_stop_strings
diff --git a/endpoints/OAI/utils/completion.py b/endpoints/OAI/utils/completion.py
index f7e50af..24a3d12 100644
--- a/endpoints/OAI/utils/completion.py
+++ b/endpoints/OAI/utils/completion.py
@@ -94,8 +94,8 @@ async def generate_completion(data: CompletionRequest, model_path: pathlib.Path)
     try:
         generation = await model.container.generate(data.prompt, **data.to_gen_params())
-
         response = _create_response(generation, model_path.name)
+        return response
 
     except Exception as exc:
         error_message = handle_request_error(
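
For reference, a minimal sketch of how the new response_prefix field could be exercised once this diff is applied. The host, port, and model name below are assumptions, not part of the diff; the /v1/chat/completions route is the standard OAI-compatible path served by the endpoint layer this diff touches.

# A usage sketch, assuming a server listening on localhost:5000
# (host, port, and model name are assumptions, not part of the diff).
import requests

payload = {
    "model": "my-local-model",  # hypothetical model name
    "messages": [{"role": "user", "content": "Name three prime numbers."}],
    # New field from this diff: appended to the templated prompt so the
    # model's reply is forced to begin with this text.
    "response_prefix": "Sure, here are three primes:",
    # Per format_prompt_with_template, the prefix is only applied when
    # add_generation_prompt is True (the default); otherwise the server
    # logs a warning and skips it.
    "add_generation_prompt": True,
}

resp = requests.post("http://127.0.0.1:5000/v1/chat/completions", json=payload)
print(resp.json()["choices"][0]["message"]["content"])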