diff --git a/endpoints/OAI/types/chat_completion.py b/endpoints/OAI/types/chat_completion.py
index 7d3138e..5c1151f 100644
--- a/endpoints/OAI/types/chat_completion.py
+++ b/endpoints/OAI/types/chat_completion.py
@@ -44,6 +44,7 @@ class ChatCompletionRequest(CommonCompletionRequest):
     messages: Union[str, List[Dict[str, str]]]
     prompt_template: Optional[str] = None
     add_generation_prompt: Optional[bool] = True
+    template_vars: Optional[dict] = {}
 
 
 class ChatCompletionResponse(BaseModel):
diff --git a/endpoints/OAI/utils/chat_completion.py b/endpoints/OAI/utils/chat_completion.py
index 4b3d39c..0ddaa94 100644
--- a/endpoints/OAI/utils/chat_completion.py
+++ b/endpoints/OAI/utils/chat_completion.py
@@ -141,14 +141,17 @@ def format_prompt_with_template(data: ChatCompletionRequest):
         unwrap(data.ban_eos_token, False),
     )
 
-    template_vars = {
-        "messages": data.messages,
-        "add_generation_prompt": data.add_generation_prompt,
-        **special_tokens_dict,
-    }
+    # Overwrite any protected vars with their values
+    data.template_vars.update(
+        {
+            "messages": data.messages,
+            "add_generation_prompt": data.add_generation_prompt,
+            **special_tokens_dict,
+        }
+    )
 
     prompt, template_stop_strings = get_prompt_from_template(
-        model.container.prompt_template, template_vars
+        model.container.prompt_template, data.template_vars
     )
 
     # Append template stop strings
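A minimal standalone sketch (not part of the patch) of the merge semantics the second hunk implements: caller-supplied template_vars pass through to the prompt template, while the protected keys (messages, add_generation_prompt, and the special tokens) are merged last via dict.update and therefore always win. All names and values below are illustrative only.

# Illustrative sketch of the precedence rule; values here are hypothetical.
user_template_vars = {
    "system_prompt": "You are a helpful assistant.",  # custom var, passed through
    "messages": "spoofed",                            # attempt to shadow a protected var
}

special_tokens_dict = {"bos_token": "<s>", "eos_token": "</s>"}  # stand-in values

# Protected vars are merged last, so they overwrite any caller-supplied values.
user_template_vars.update(
    {
        "messages": [{"role": "user", "content": "hi"}],
        "add_generation_prompt": True,
        **special_tokens_dict,
    }
)

assert user_template_vars["messages"] == [{"role": "user", "content": "hi"}]  # protected key wins
assert user_template_vars["system_prompt"] == "You are a helpful assistant."  # custom var survives

One note on the new field: the mutable {} default should be safe here because Pydantic copies mutable field defaults per model instance, so requests that omit template_vars do not share state.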