API: Add banned_strings
From exllamav2: List of strings that the generator will refuse to output. As soon as a partial match happens, a checkpoint is saved that the generator can rewind to if need be. Subsequent tokens are then held until the full string is resolved (match or no match) and either emitted or discarded, accordingly.
This commit is contained in:
parent
a1df22668b
commit
c0b631ba92
3 changed files with 31 additions and 0 deletions
|
|
@ -32,6 +32,10 @@ class BaseSamplerRequest(BaseModel):
|
|||
default_factory=lambda: get_default_sampler_value("stop", [])
|
||||
)
|
||||
|
||||
banned_strings: Optional[Union[str, List[str]]] = Field(
|
||||
default_factory=lambda: get_default_sampler_value("banned_strings", [])
|
||||
)
|
||||
|
||||
token_healing: Optional[bool] = Field(
|
||||
default_factory=lambda: get_default_sampler_value("token_healing", False)
|
||||
)
|
||||
|
|
@ -257,6 +261,10 @@ class BaseSamplerRequest(BaseModel):
|
|||
if self.stop and isinstance(self.stop, str):
|
||||
self.stop = [self.stop]
|
||||
|
||||
# Convert banned_strings to an array of strings
|
||||
if self.banned_strings and isinstance(self.banned_strings, str):
|
||||
self.banned_strings = [self.banned_strings]
|
||||
|
||||
# Convert string banned tokens to an integer list
|
||||
if self.banned_tokens and isinstance(self.banned_tokens, str):
|
||||
self.banned_tokens = [
|
||||
|
|
@ -268,6 +276,7 @@ class BaseSamplerRequest(BaseModel):
|
|||
"min_tokens": self.min_tokens,
|
||||
"generate_window": self.generate_window,
|
||||
"stop": self.stop,
|
||||
"banned_strings": self.banned_strings,
|
||||
"add_bos_token": self.add_bos_token,
|
||||
"ban_eos_token": self.ban_eos_token,
|
||||
"skip_special_tokens": self.skip_special_tokens,
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue