OAI: Add ability to pass extra vars in jinja templates

A chat completion can now declare extra template_vars to pass when
a template is rendered, opening up the possibility of using state
outside of huggingface's parameters.

Signed-off-by: kingbri <bdashore3@proton.me>
This commit is contained in:
kingbri 2024-04-11 00:55:32 -04:00
parent b1f3baad74
commit 2a0aaa2e8a
2 changed files with 10 additions and 6 deletions

View file

@ -44,6 +44,7 @@ class ChatCompletionRequest(CommonCompletionRequest):
messages: Union[str, List[Dict[str, str]]]
prompt_template: Optional[str] = None
add_generation_prompt: Optional[bool] = True
template_vars: Optional[dict] = {}
class ChatCompletionResponse(BaseModel):

View file

@ -141,14 +141,17 @@ def format_prompt_with_template(data: ChatCompletionRequest):
unwrap(data.ban_eos_token, False),
)
template_vars = {
"messages": data.messages,
"add_generation_prompt": data.add_generation_prompt,
**special_tokens_dict,
}
# Overwrite any protected vars with their values
data.template_vars.update(
{
"messages": data.messages,
"add_generation_prompt": data.add_generation_prompt,
**special_tokens_dict,
}
)
prompt, template_stop_strings = get_prompt_from_template(
model.container.prompt_template, template_vars
model.container.prompt_template, data.template_vars
)
# Append template stop strings