Api: Add ban_eos_token and add_bos_token support
Adds the ability for the client to specify whether to add the BOS token and whether to ban the EOS token. Signed-off-by: kingbri <bdashore3@proton.me>
This commit is contained in:
parent
8fea5391a8
commit
ea91d17a11
3 changed files with 19 additions and 4 deletions
|
|
@ -54,6 +54,8 @@ class CompletionRequest(BaseModel):
|
|||
mirostat_mode: Optional[int] = 0
|
||||
mirostat_tau: Optional[float] = 1.5
|
||||
mirostat_eta: Optional[float] = 0.1
|
||||
add_bos_token: Optional[bool] = True
|
||||
ban_eos_token: Optional[bool] = False
|
||||
|
||||
# Converts to internal generation parameters
|
||||
def to_gen_params(self):
|
||||
|
|
@ -73,6 +75,8 @@ class CompletionRequest(BaseModel):
|
|||
"prompt": self.prompt,
|
||||
"stop": self.stop,
|
||||
"max_tokens": self.max_tokens,
|
||||
"add_bos_token": self.add_bos_token,
|
||||
"ban_eos_token": self.ban_eos_token,
|
||||
"token_healing": self.token_healing,
|
||||
"temperature": self.temperature,
|
||||
"top_k": self.top_k,
|
||||
|
|
|
|||
|
|
@ -2,13 +2,13 @@ from pydantic import BaseModel
|
|||
from typing import List
|
||||
|
||||
class CommonTokenRequest(BaseModel):
|
||||
add_bos: bool = True
|
||||
add_bos_token: bool = True
|
||||
encode_special_tokens: bool = True
|
||||
decode_special_tokens: bool = True
|
||||
|
||||
def get_params(self):
|
||||
return {
|
||||
"add_bos": self.add_bos,
|
||||
"add_bos_token": self.add_bos_token,
|
||||
"encode_special_tokens": self.encode_special_tokens,
|
||||
"decode_special_tokens": self.decode_special_tokens
|
||||
}
|
||||
|
|
|
|||
15
model.py
15
model.py
|
|
@ -200,7 +200,8 @@ class ModelContainer:
|
|||
if text:
|
||||
# Assume token encoding
|
||||
return self.tokenizer.encode(
|
||||
text, add_bos = kwargs.get("add_bos", True),
|
||||
text,
|
||||
add_bos = kwargs.get("add_bos_token", True),
|
||||
encode_special_tokens = kwargs.get("encode_special_tokens", True)
|
||||
)
|
||||
if ids:
|
||||
|
|
@ -236,6 +237,8 @@ class ModelContainer:
|
|||
'repetition_decay' (int): Repetition penalty decay (default: same as range)
|
||||
'stop' (list): List of stop strings/tokens to end response (default: [EOS])
|
||||
'max_tokens' (int): Max no. tokens in response (default: 150)
|
||||
'add_bos_token' (bool): Adds the BOS token to the start of the prompt (default: True)
|
||||
'ban_eos_token' (bool): Bans the EOS token from generation (default: False)
|
||||
'stream_interval' (float): Interval in seconds between each output chunk (default: immediate)
|
||||
'generate_window' (int): Space to reserve at the end of the model's context when generating.
|
||||
Rolls context window by the same amount if context length is exceeded to allow generating past
|
||||
|
|
@ -266,6 +269,10 @@ class ModelContainer:
|
|||
gen_settings.token_repetition_range = kwargs.get("repetition_range", self.config.max_seq_len)
|
||||
gen_settings.token_repetition_decay = kwargs.get("repetition_decay", gen_settings.token_repetition_range)
|
||||
|
||||
# Ban the EOS token if specified
|
||||
if kwargs.get("ban_eos_token", False):
|
||||
gen_settings.disallow_tokens(self.tokenizer, [self.tokenizer.eos_token_id])
|
||||
|
||||
# Override sampler settings for temp = 0
|
||||
|
||||
if gen_settings.temperature == 0:
|
||||
|
|
@ -280,7 +287,11 @@ class ModelContainer:
|
|||
|
||||
# Tokenized context
|
||||
|
||||
ids = self.tokenizer.encode(prompt, encode_special_tokens = True)
|
||||
ids = self.tokenizer.encode(
|
||||
prompt,
|
||||
add_bos=kwargs.get("add_bos_token", True),
|
||||
encode_special_tokens = True
|
||||
)
|
||||
|
||||
# Begin
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue