OAI: Add response_prefix and fix BOS token issues in chat completions
response_prefix is used to add a prefix before generating the next message. This is useful in many cases, such as continuing a prompt (see #96).

Also, if a template has a BOS token specified, add_bos_token will result in two BOS tokens in the prompt. Add a check which strips a starting BOS token from the prompt if it exists.

Signed-off-by: kingbri <bdashore3@proton.me>
This commit is contained in:
parent
ed7cd3cb59
commit
fb1d2f34c1
4 changed files with 20 additions and 1 deletion
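For context, a sketch of how a client might use the new field. The base URL, endpoint path, messages field, and response shape are assumptions based on the OpenAI-style API this server exposes; only response_prefix and add_generation_prompt come from this diff.

import requests  # any HTTP client works; requests is just for illustration

# response_prefix only takes effect when add_generation_prompt is True
# (see the warning branch in the diff below).
payload = {
    "messages": [{"role": "user", "content": "Continue: Once upon a time"}],
    "add_generation_prompt": True,
    "response_prefix": "Once upon a time",  # seeds the start of the reply
}

# Assumed local server URL and OpenAI-style route
resp = requests.post("http://localhost:5000/v1/chat/completions", json=payload)
print(resp.json()["choices"][0]["message"]["content"])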
@@ -45,6 +45,7 @@ class ChatCompletionRequest(CommonCompletionRequest):
     prompt_template: Optional[str] = None
     add_generation_prompt: Optional[bool] = True
     template_vars: Optional[dict] = {}
+    response_prefix: Optional[str] = None
 
 
 class ChatCompletionResponse(BaseModel):
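A minimal sketch of what the schema change above implies, using a hypothetical stand-in class (pydantic is assumed from the BaseModel usage in the hunk; the real class also inherits from CommonCompletionRequest). The new field defaults to None, so existing clients are unaffected:

from typing import Optional

from pydantic import BaseModel

# Hypothetical stand-in mirroring the fields shown in the hunk above
class ChatCompletionRequestSketch(BaseModel):
    prompt_template: Optional[str] = None
    add_generation_prompt: Optional[bool] = True
    template_vars: Optional[dict] = {}
    response_prefix: Optional[str] = None

print(ChatCompletionRequestSketch().response_prefix)  # None -> backward compatible
print(ChatCompletionRequestSketch(response_prefix="Hi,").response_prefix)  # "Hi,"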
@@ -8,6 +8,7 @@ from uuid import uuid4
 
 from fastapi import HTTPException
 from jinja2 import TemplateError
+from loguru import logger
 
 from common import model
 from common.networking import (
@@ -153,6 +154,22 @@ def format_prompt_with_template(data: ChatCompletionRequest):
         data.template_vars
     )
 
+    # Append response prefix if present
+    if data.response_prefix:
+        if data.add_generation_prompt:
+            prompt += data.response_prefix
+        else:
+            logger.warning(
+                "Could not add response prefix because "
+                "add_generation_prompt is False"
+            )
+
+    # Removes the starting BOS token if present
+    # This is to prevent add_bos_token from adding multiple bos tokens
+    bos_token = special_tokens_dict.get("bos_token")
+    if bos_token and prompt.startswith(bos_token):
+        prompt = prompt.removeprefix(bos_token)
+
     # Append template stop strings
     if isinstance(data.stop, str):
         data.stop = [data.stop] + template_stop_strings
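To make the BOS fix concrete, a standalone sketch of the strip logic above. The BOS string and template output here are assumed example values, not taken from this repo:

# If the chat template already emits the BOS token and the tokenizer later
# runs with add_bos_token=True, the prompt would start with two BOS tokens.
bos_token = "<s>"  # assumed value; the real one comes from special_tokens_dict
prompt = "<s>[INST] Hello [/INST]"  # assumed template output that starts with BOS

if bos_token and prompt.startswith(bos_token):
    prompt = prompt.removeprefix(bos_token)  # str.removeprefix needs Python 3.9+

print(prompt)  # "[INST] Hello [/INST]" -- the tokenizer re-adds a single BOS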
@@ -94,8 +94,8 @@ async def generate_completion(data: CompletionRequest, model_path: pathlib.Path)
 
     try:
         generation = await model.container.generate(data.prompt, **data.to_gen_params())
 
         response = _create_response(generation, model_path.name)
 
         return response
     except Exception as exc:
         error_message = handle_request_error(