API: Add banned_strings

From exllamav2: a list of strings that the generator will refuse to output. As soon as a partial match occurs, a checkpoint is saved that the generator can rewind to if needed. Subsequent tokens are held back until the full string is resolved (match or no match) and are then either emitted or discarded accordingly.
DocShotgun 2024-05-10 13:53:55 -07:00
parent a1df22668b
commit c0b631ba92
3 changed files with 31 additions and 0 deletions
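
The new parameter flows from the sampler request schema through to the generator call, as the diffs below show. As a usage illustration only (the base URL, endpoint path, and payload fields other than banned_strings are assumptions, not part of this commit), a client request might look like:

# Hypothetical client-side sketch; the base URL, endpoint, and payload fields
# other than banned_strings are assumptions rather than part of this commit.
import requests

payload = {
    "prompt": "Write a short story about a wizard.",
    "max_tokens": 200,
    # May be a single string or a list of strings; the request model in this
    # commit normalizes a lone string into a one-element list.
    "banned_strings": ["As an AI language model", "I'm sorry, but"],
}

response = requests.post(
    "http://localhost:5000/v1/completions",  # assumed OpenAI-style endpoint
    json=payload,
    timeout=120,
)
print(response.json())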


@@ -791,6 +791,7 @@ class ExllamaV2Container:
         )

         stop_conditions: List[Union[str, int]] = unwrap(kwargs.get("stop"), [])
+        banned_strings: List[str] = unwrap(kwargs.get("banned_strings"), [])
         add_bos_token = unwrap(kwargs.get("add_bos_token"), True)
         ban_eos_token = unwrap(kwargs.get("ban_eos_token"), False)
         logit_bias = kwargs.get("logit_bias")
@@ -960,6 +961,22 @@ class ExllamaV2Container:
         )
         min_tokens = 0

+        # Check if banned_strings is supported
+        # TODO: Remove when a new version of ExllamaV2 is released
+        if banned_strings:
+            begin_stream_signature = signature(self.generator.begin_stream_ex)
+
+            try:
+                _bound_vars = begin_stream_signature.bind_partial(
+                    banned_strings=[]
+                )
+                begin_stream_args["banned_strings"] = banned_strings
+            except TypeError:
+                logger.warning(
+                    "banned_strings is not supported by the currently "
+                    "installed ExLlamaV2 version."
+                )
+
         # Log generation options to console
         # Some options are too large, so log the args instead
         log_generation_params(
@@ -979,6 +996,7 @@
             logprobs=request_logprobs,
             stop_conditions=stop_conditions,
             banned_tokens=banned_tokens,
+            banned_strings=banned_strings,
             logit_bias=logit_bias,
         )
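
The compatibility shim above probes the installed ExLlamaV2 with inspect.signature().bind_partial() rather than checking a version number: binding the banned_strings keyword raises TypeError when the parameter does not exist. A minimal standalone sketch of the same pattern, using a hypothetical stand-in function rather than the real begin_stream_ex:

# Sketch of signature-based capability detection. `old_begin_stream` is a
# hypothetical stand-in for an older generator method without banned_strings.
from inspect import signature


def old_begin_stream(prompt, stop_conditions=None):
    """Pretend generator entry point that predates banned_strings."""


extra_args = {}
try:
    # bind_partial() raises TypeError if the callable has no matching
    # parameter (and no **kwargs catch-all to absorb it).
    signature(old_begin_stream).bind_partial(banned_strings=[])
    extra_args["banned_strings"] = ["forbidden phrase"]
except TypeError:
    print("banned_strings is not supported by this version")

Note that a callable with a **kwargs catch-all would also bind successfully, so the probe only guarantees the keyword is accepted, not that it is acted on.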


@@ -32,6 +32,10 @@ class BaseSamplerRequest(BaseModel):
         default_factory=lambda: get_default_sampler_value("stop", [])
     )

+    banned_strings: Optional[Union[str, List[str]]] = Field(
+        default_factory=lambda: get_default_sampler_value("banned_strings", [])
+    )
+
     token_healing: Optional[bool] = Field(
         default_factory=lambda: get_default_sampler_value("token_healing", False)
     )
@@ -257,6 +261,10 @@ class BaseSamplerRequest(BaseModel):
         if self.stop and isinstance(self.stop, str):
             self.stop = [self.stop]

+        # Convert banned_strings to an array of strings
+        if self.banned_strings and isinstance(self.banned_strings, str):
+            self.banned_strings = [self.banned_strings]
+
         # Convert string banned tokens to an integer list
         if self.banned_tokens and isinstance(self.banned_tokens, str):
             self.banned_tokens = [
@@ -268,6 +276,7 @@ class BaseSamplerRequest(BaseModel):
             "min_tokens": self.min_tokens,
             "generate_window": self.generate_window,
             "stop": self.stop,
+            "banned_strings": self.banned_strings,
             "add_bos_token": self.add_bos_token,
             "ban_eos_token": self.ban_eos_token,
             "skip_special_tokens": self.skip_special_tokens,


@@ -18,6 +18,10 @@ stop:
   override: []
   force: false
   additive: false
+banned_strings:
+  override: []
+  force: false
+  additive: false
 token_healing:
   override: false
   force: false
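
The banned_strings entry mirrors the shape of the existing override stanzas (override, force, additive). How those three fields combine is not spelled out in this commit, so the sketch below only illustrates one plausible interpretation inferred from the field names:

# Assumed semantics, inferred from the field names only: `force` replaces the
# client value, `additive` appends to it, and `override` fills in a default
# when the client sends nothing.
def apply_override(client_value, entry):
    override = entry.get("override", [])
    if entry.get("force"):
        return override
    if not client_value:
        return override
    if entry.get("additive") and isinstance(client_value, list):
        return client_value + override
    return client_value


entry = {"override": ["Chapter 2"], "force": False, "additive": True}
print(apply_override(["As an AI"], entry))  # -> ['As an AI', 'Chapter 2']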