API: Add ban_eos_token and add_bos_token support

Adds the ability for the client to specify whether to add the BOS
token to the prompt and whether to ban the EOS token during generation.

Signed-off-by: kingbri <bdashore3@proton.me>
kingbri 2023-11-14 23:05:47 -05:00
parent 8fea5391a8
commit ea91d17a11
3 changed files with 19 additions and 4 deletions
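As a quick illustration of the client-facing change, a completion request can now opt into either flag. A minimal sketch using requests (the host, port, and endpoint path are assumptions for illustration, not part of this commit; the field names match the schema added below):

    import requests

    # Endpoint path and port are assumed for illustration
    response = requests.post(
        "http://localhost:8000/v1/completions",
        json={
            "prompt": "Once upon a time",
            "max_tokens": 150,
            "add_bos_token": True,   # prepend BOS to the prompt (the default)
            "ban_eos_token": False,  # leave EOS available as a stop token (the default)
        },
    )
    print(response.json())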

@@ -54,6 +54,8 @@ class CompletionRequest(BaseModel):
     mirostat_mode: Optional[int] = 0
     mirostat_tau: Optional[float] = 1.5
     mirostat_eta: Optional[float] = 0.1
+    add_bos_token: Optional[bool] = True
+    ban_eos_token: Optional[bool] = False
 
     # Converts to internal generation parameters
     def to_gen_params(self):
@@ -73,6 +75,8 @@ class CompletionRequest(BaseModel):
             "prompt": self.prompt,
             "stop": self.stop,
             "max_tokens": self.max_tokens,
+            "add_bos_token": self.add_bos_token,
+            "ban_eos_token": self.ban_eos_token,
             "token_healing": self.token_healing,
             "temperature": self.temperature,
             "top_k": self.top_k,

@@ -2,13 +2,13 @@ from pydantic import BaseModel
 from typing import List
 
 class CommonTokenRequest(BaseModel):
-    add_bos: bool = True
+    add_bos_token: bool = True
     encode_special_tokens: bool = True
     decode_special_tokens: bool = True
 
     def get_params(self):
         return {
-            "add_bos": self.add_bos,
+            "add_bos_token": self.add_bos_token,
             "encode_special_tokens": self.encode_special_tokens,
             "decode_special_tokens": self.decode_special_tokens
         }
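Renaming add_bos to add_bos_token keeps the tokenization schema consistent with the new completion field. A sketch of the resulting params dict, using the defaults (assumes CommonTokenRequest is imported from the module above):

    req = CommonTokenRequest()
    print(req.get_params())
    # {'add_bos_token': True, 'encode_special_tokens': True, 'decode_special_tokens': True}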

@@ -200,7 +200,8 @@ class ModelContainer:
         if text:
             # Assume token encoding
             return self.tokenizer.encode(
-                text, add_bos = kwargs.get("add_bos", True),
+                text,
+                add_bos = kwargs.get("add_bos_token", True),
                 encode_special_tokens = kwargs.get("encode_special_tokens", True)
             )
         if ids:
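Note the naming split here: the API-facing kwarg is add_bos_token, while the underlying tokenizer parameter stays add_bos. The kwargs.get(..., True) default means callers that never pass the key are unaffected:

    kwargs = {}
    add_bos = kwargs.get("add_bos_token", True)    # True: old behavior preserved
    kwargs = {"add_bos_token": False}
    add_bos = kwargs.get("add_bos_token", True)    # False: client override applied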
@@ -236,6 +237,8 @@ class ModelContainer:
             'repetition_decay' (int): Repetition penalty range (default: same as range)
             'stop' (list): List of stop strings/tokens to end response (default: [EOS])
             'max_tokens' (int): Max no. tokens in response (default: 150)
+            'add_bos_token' (bool): Adds the BOS token to the start of the prompt (default: True)
+            'ban_eos_token' (bool): Bans the EOS token from generation (default: False)
             'stream_interval' (float): Interval in seconds between each output chunk (default: immediate)
             'generate_window' (int): Space to reserve at the end of the model's context when generating.
                 Rolls context window by the same amount if context length is exceeded to allow generating past
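For example, a kwargs dict that exercises both new entries; with ban_eos_token set, generation only ends via a stop string or the max_tokens cap:

    gen_params = {
        "max_tokens": 150,
        "add_bos_token": True,   # BOS prepended to the prompt
        "ban_eos_token": True,   # EOS never sampled; rely on stop/max_tokens
    }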
@@ -266,6 +269,10 @@ class ModelContainer:
         gen_settings.token_repetition_range = kwargs.get("repetition_range", self.config.max_seq_len)
         gen_settings.token_repetition_decay = kwargs.get("repetition_decay", gen_settings.token_repetition_range)
 
+        # Ban the EOS token if specified
+        if kwargs.get("ban_eos_token", False):
+            gen_settings.disallow_tokens(self.tokenizer, [self.tokenizer.eos_token_id])
+
         # Override sampler settings for temp = 0
         if gen_settings.temperature == 0:
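gen_settings.disallow_tokens comes from the sampler settings object used here and prevents the listed token ids from ever being sampled. Conceptually this amounts to masking those logits, as in this illustrative sketch (not the library's implementation):

    import torch

    def ban_tokens(logits: torch.Tensor, token_ids: list[int]) -> torch.Tensor:
        # Force the banned ids to -inf so softmax gives them zero probability
        logits[..., token_ids] = float("-inf")
        return logits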
@@ -280,7 +287,11 @@
         # Tokenized context
-        ids = self.tokenizer.encode(prompt, encode_special_tokens = True)
+        ids = self.tokenizer.encode(
+            prompt,
+            add_bos = kwargs.get("add_bos_token", True),
+            encode_special_tokens = True
+        )
 
         # Begin
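With this last change, add_bos_token also controls the main prompt encoding. An illustrative sketch of the effect (the ids are made up; exllamav2's tokenizer returns a tensor, and real ids depend on the model's vocabulary):

    ids = tokenizer.encode("Hello", add_bos=True)    # e.g. [[1, 15043]]
    ids = tokenizer.encode("Hello", add_bos=False)   # e.g. [[15043]]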