Templates: Migrate to class

Having many utility functions for initialization doesn't make much sense.
Instead, handle anything regarding template creation inside the
class which reduces the amount of function imports.

Signed-off-by: kingbri <bdashore3@proton.me>
This commit is contained in:
kingbri 2024-04-21 23:28:14 -04:00
parent 9f93505bc1
commit cab789e685
4 changed files with 122 additions and 132 deletions

View file

@ -14,11 +14,7 @@ from common.concurrency import (
generate_with_semaphore,
)
from common.networking import handle_request_error, run_with_request_disconnect
from common.templating import (
get_all_templates,
get_prompt_from_template,
get_template_from_file,
)
from common.templating import PromptTemplate, get_all_templates
from common.utils import coalesce, unwrap
from endpoints.OAI.types.auth import AuthPermissionResponse
from endpoints.OAI.types.completion import CompletionRequest
@ -224,8 +220,7 @@ async def switch_template(data: TemplateSwitchRequest):
raise HTTPException(400, error_message)
try:
template = get_template_from_file(data.name)
model.container.prompt_template = template
model.container.prompt_template = PromptTemplate.from_file(data.name)
except FileNotFoundError as e:
error_message = handle_request_error(
f"The template name {data.name} doesn't exist. Check the spelling?",
@ -402,9 +397,7 @@ async def encode_tokens(data: TokenEncodeRequest):
**special_tokens_dict,
}
text, _ = get_prompt_from_template(
model.container.prompt_template, template_vars
)
text, _ = model.container.prompt_template.render(template_vars)
raw_tokens = model.container.encode_tokens(text, **data.get_params())
tokens = unwrap(raw_tokens, [])

View file

@ -15,7 +15,6 @@ from common.networking import (
handle_request_disconnect,
handle_request_error,
)
from common.templating import get_prompt_from_template
from common.utils import unwrap
from endpoints.OAI.types.chat_completion import (
ChatCompletionLogprobs,
@ -150,8 +149,8 @@ def format_prompt_with_template(data: ChatCompletionRequest):
}
)
prompt, template_stop_strings = get_prompt_from_template(
model.container.prompt_template, data.template_vars
prompt, template_stop_strings = model.container.prompt_template.render(
data.template_vars
)
# Append template stop strings