Templates: Migrate to class
Having many utility functions for initialization doesn't make much sense. Instead, handle anything regarding template creation inside the class which reduces the amount of function imports. Signed-off-by: kingbri <bdashore3@proton.me>
This commit is contained in:
parent
9f93505bc1
commit
cab789e685
4 changed files with 122 additions and 132 deletions
|
|
@ -14,11 +14,7 @@ from common.concurrency import (
|
|||
generate_with_semaphore,
|
||||
)
|
||||
from common.networking import handle_request_error, run_with_request_disconnect
|
||||
from common.templating import (
|
||||
get_all_templates,
|
||||
get_prompt_from_template,
|
||||
get_template_from_file,
|
||||
)
|
||||
from common.templating import PromptTemplate, get_all_templates
|
||||
from common.utils import coalesce, unwrap
|
||||
from endpoints.OAI.types.auth import AuthPermissionResponse
|
||||
from endpoints.OAI.types.completion import CompletionRequest
|
||||
|
|
@ -224,8 +220,7 @@ async def switch_template(data: TemplateSwitchRequest):
|
|||
raise HTTPException(400, error_message)
|
||||
|
||||
try:
|
||||
template = get_template_from_file(data.name)
|
||||
model.container.prompt_template = template
|
||||
model.container.prompt_template = PromptTemplate.from_file(data.name)
|
||||
except FileNotFoundError as e:
|
||||
error_message = handle_request_error(
|
||||
f"The template name {data.name} doesn't exist. Check the spelling?",
|
||||
|
|
@ -402,9 +397,7 @@ async def encode_tokens(data: TokenEncodeRequest):
|
|||
**special_tokens_dict,
|
||||
}
|
||||
|
||||
text, _ = get_prompt_from_template(
|
||||
model.container.prompt_template, template_vars
|
||||
)
|
||||
text, _ = model.container.prompt_template.render(template_vars)
|
||||
|
||||
raw_tokens = model.container.encode_tokens(text, **data.get_params())
|
||||
tokens = unwrap(raw_tokens, [])
|
||||
|
|
|
|||
|
|
@ -15,7 +15,6 @@ from common.networking import (
|
|||
handle_request_disconnect,
|
||||
handle_request_error,
|
||||
)
|
||||
from common.templating import get_prompt_from_template
|
||||
from common.utils import unwrap
|
||||
from endpoints.OAI.types.chat_completion import (
|
||||
ChatCompletionLogprobs,
|
||||
|
|
@ -150,8 +149,8 @@ def format_prompt_with_template(data: ChatCompletionRequest):
|
|||
}
|
||||
)
|
||||
|
||||
prompt, template_stop_strings = get_prompt_from_template(
|
||||
model.container.prompt_template, data.template_vars
|
||||
prompt, template_stop_strings = model.container.prompt_template.render(
|
||||
data.template_vars
|
||||
)
|
||||
|
||||
# Append template stop strings
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue