API: Add banned_strings
From exllamav2: A list of strings that the generator will refuse to output. As soon as a partial match occurs, a checkpoint is saved that the generator can rewind to if needed. Subsequent tokens are then held until the partial match is resolved: on a full match the held tokens are discarded and the generator rewinds to the checkpoint; on no match they are emitted.
This commit is contained in:
parent
a1df22668b
commit
c0b631ba92
3 changed files with 31 additions and 0 deletions
|
|
@ -791,6 +791,7 @@ class ExllamaV2Container:
|
||||||
)
|
)
|
||||||
|
|
||||||
stop_conditions: List[Union[str, int]] = unwrap(kwargs.get("stop"), [])
|
stop_conditions: List[Union[str, int]] = unwrap(kwargs.get("stop"), [])
|
||||||
|
banned_strings: List[str] = unwrap(kwargs.get("banned_strings"), [])
|
||||||
add_bos_token = unwrap(kwargs.get("add_bos_token"), True)
|
add_bos_token = unwrap(kwargs.get("add_bos_token"), True)
|
||||||
ban_eos_token = unwrap(kwargs.get("ban_eos_token"), False)
|
ban_eos_token = unwrap(kwargs.get("ban_eos_token"), False)
|
||||||
logit_bias = kwargs.get("logit_bias")
|
logit_bias = kwargs.get("logit_bias")
|
||||||
|
|
@ -960,6 +961,22 @@ class ExllamaV2Container:
|
||||||
)
|
)
|
||||||
min_tokens = 0
|
min_tokens = 0
|
||||||
|
|
||||||
|
# Check if banned_strings is supported
|
||||||
|
# TODO: Remove when a new version of ExllamaV2 is released
|
||||||
|
if banned_strings:
|
||||||
|
begin_stream_signature = signature(self.generator.begin_stream_ex)
|
||||||
|
|
||||||
|
try:
|
||||||
|
_bound_vars = begin_stream_signature.bind_partial(
|
||||||
|
banned_strings=[]
|
||||||
|
)
|
||||||
|
begin_stream_args["banned_strings"] = banned_strings
|
||||||
|
except TypeError:
|
||||||
|
logger.warning(
|
||||||
|
"banned_strings is not supported by the currently "
|
||||||
|
"installed ExLlamaV2 version."
|
||||||
|
)
|
||||||
|
|
||||||
# Log generation options to console
|
# Log generation options to console
|
||||||
# Some options are too large, so log the args instead
|
# Some options are too large, so log the args instead
|
||||||
log_generation_params(
|
log_generation_params(
|
||||||
|
|
@ -979,6 +996,7 @@ class ExllamaV2Container:
|
||||||
logprobs=request_logprobs,
|
logprobs=request_logprobs,
|
||||||
stop_conditions=stop_conditions,
|
stop_conditions=stop_conditions,
|
||||||
banned_tokens=banned_tokens,
|
banned_tokens=banned_tokens,
|
||||||
|
banned_strings=banned_strings,
|
||||||
logit_bias=logit_bias,
|
logit_bias=logit_bias,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -32,6 +32,10 @@ class BaseSamplerRequest(BaseModel):
|
||||||
default_factory=lambda: get_default_sampler_value("stop", [])
|
default_factory=lambda: get_default_sampler_value("stop", [])
|
||||||
)
|
)
|
||||||
|
|
||||||
|
banned_strings: Optional[Union[str, List[str]]] = Field(
|
||||||
|
default_factory=lambda: get_default_sampler_value("banned_strings", [])
|
||||||
|
)
|
||||||
|
|
||||||
token_healing: Optional[bool] = Field(
|
token_healing: Optional[bool] = Field(
|
||||||
default_factory=lambda: get_default_sampler_value("token_healing", False)
|
default_factory=lambda: get_default_sampler_value("token_healing", False)
|
||||||
)
|
)
|
||||||
|
|
@ -257,6 +261,10 @@ class BaseSamplerRequest(BaseModel):
|
||||||
if self.stop and isinstance(self.stop, str):
|
if self.stop and isinstance(self.stop, str):
|
||||||
self.stop = [self.stop]
|
self.stop = [self.stop]
|
||||||
|
|
||||||
|
# Convert banned_strings to an array of strings
|
||||||
|
if self.banned_strings and isinstance(self.banned_strings, str):
|
||||||
|
self.banned_strings = [self.banned_strings]
|
||||||
|
|
||||||
# Convert string banned tokens to an integer list
|
# Convert string banned tokens to an integer list
|
||||||
if self.banned_tokens and isinstance(self.banned_tokens, str):
|
if self.banned_tokens and isinstance(self.banned_tokens, str):
|
||||||
self.banned_tokens = [
|
self.banned_tokens = [
|
||||||
|
|
@ -268,6 +276,7 @@ class BaseSamplerRequest(BaseModel):
|
||||||
"min_tokens": self.min_tokens,
|
"min_tokens": self.min_tokens,
|
||||||
"generate_window": self.generate_window,
|
"generate_window": self.generate_window,
|
||||||
"stop": self.stop,
|
"stop": self.stop,
|
||||||
|
"banned_strings": self.banned_strings,
|
||||||
"add_bos_token": self.add_bos_token,
|
"add_bos_token": self.add_bos_token,
|
||||||
"ban_eos_token": self.ban_eos_token,
|
"ban_eos_token": self.ban_eos_token,
|
||||||
"skip_special_tokens": self.skip_special_tokens,
|
"skip_special_tokens": self.skip_special_tokens,
|
||||||
|
|
|
||||||
|
|
@ -18,6 +18,10 @@ stop:
|
||||||
override: []
|
override: []
|
||||||
force: false
|
force: false
|
||||||
additive: false
|
additive: false
|
||||||
|
banned_strings:
|
||||||
|
override: []
|
||||||
|
force: false
|
||||||
|
additive: false
|
||||||
token_healing:
|
token_healing:
|
||||||
override: false
|
override: false
|
||||||
force: false
|
force: false
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue