Merge pull request #102 from DocShotgun/main

Add support for min_tokens and banned_strings
Brian Dashore · 2024-05-10 21:21:57 -04:00 · committed by GitHub
commit 5432f523cb
3 changed files with 53 additions and 9 deletions
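
Functionally, min_tokens keeps the model's EOS token banned until at least that many tokens have been generated, and banned_strings prevents the listed strings from ever appearing in the output. A minimal sketch of a client request using the new fields (the server URL, port, endpoint path, and auth header are assumptions for illustration, not part of this diff):

import requests

payload = {
    "prompt": "Write a short story about a lighthouse keeper.",
    "max_tokens": 200,
    "min_tokens": 50,  # EOS stays banned for the first 50 generated tokens
    "banned_strings": ["As an AI"],  # a bare string is also accepted
}

# Hypothetical local tabbyAPI instance; adjust URL and key for your setup
response = requests.post(
    "http://127.0.0.1:5000/v1/completions",
    json=payload,
    headers={"x-api-key": "<your-api-key>"},
)
print(response.json())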

First changed file: the ExllamaV2 backend container.

@@ -791,6 +791,7 @@ class ExllamaV2Container:
         )
         stop_conditions: List[Union[str, int]] = unwrap(kwargs.get("stop"), [])
+        banned_strings: List[str] = unwrap(kwargs.get("banned_strings"), [])
         add_bos_token = unwrap(kwargs.get("add_bos_token"), True)
         ban_eos_token = unwrap(kwargs.get("ban_eos_token"), False)
         logit_bias = kwargs.get("logit_bias")
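
unwrap here is tabbyAPI's small None-coalescing helper; a sketch of its behavior as implied by its usage in this diff (the actual implementation may differ):

def unwrap(value, default=None):
    """Return value unless it is None, in which case return the default."""
    return value if value is not None else default

# A request that omits "banned_strings" therefore falls back to []:
# unwrap(kwargs.get("banned_strings"), []) -> []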
@@ -905,6 +906,9 @@
             kwargs.get("max_tokens"), self.config.max_seq_len - prompt_tokens
         )

+        # Set min_tokens to generate while keeping EOS banned
+        min_tokens = unwrap(kwargs.get("min_tokens"), 0)
+
         # This is an inverse of skip_special_tokens
         decode_special_tokens = unwrap(not kwargs.get("skip_special_tokens"), False)
@@ -925,26 +929,40 @@
             }
         )

-        # Check if decode_special_tokens is supported
-        if decode_special_tokens:
+        # MARK: Function signature checks. Not needed in newer ExllamaV2 versions
+        # TODO: Remove when a new version of ExllamaV2 is released
+
+        # Check if temporary token bans are supported
+        if min_tokens:
+            stream_signature = signature(self.generator.stream_ex)
+            try:
+                _bound_vars = stream_signature.bind_partial(ban_tokens=[])
+            except TypeError:
+                logger.warning(
+                    "min_tokens is not supported by the currently "
+                    "installed ExLlamaV2 version."
+                )
+                min_tokens = 0
+
+        # Check if banned_strings is supported
+        if banned_strings:
             begin_stream_signature = signature(self.generator.begin_stream_ex)
             try:
-                _bound_vars = begin_stream_signature.bind_partial(
-                    decode_special_tokens=True
-                )
-                begin_stream_args["decode_special_tokens"] = decode_special_tokens
+                _bound_vars = begin_stream_signature.bind_partial(banned_strings=[])
+                begin_stream_args["banned_strings"] = banned_strings
             except TypeError:
                 logger.warning(
-                    "skip_special_tokens is not supported by the currently "
+                    "banned_strings is not supported by the currently "
                     "installed ExLlamaV2 version."
                 )
+                banned_strings = []

         # Log generation options to console
         # Some options are too large, so log the args instead
         log_generation_params(
             max_tokens=max_tokens,
+            min_tokens=min_tokens,
             stream=kwargs.get("stream"),
             **gen_settings_log_dict,
             token_healing=token_healing,
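
Rather than pinning an ExLlamaV2 version, the block above feature-detects at runtime: signature(...).bind_partial(new_kwarg=...) raises TypeError when the installed generator does not accept the keyword. The same idea as a standalone helper (the helper itself is illustrative, not from the diff):

from inspect import signature

def supports_kwarg(func, name: str) -> bool:
    """Check whether func accepts a given keyword argument, without calling it."""
    try:
        signature(func).bind_partial(**{name: None})
        return True
    except TypeError:
        return False

# Example: disable min_tokens on ExLlamaV2 builds without temporary token bans
# if not supports_kwarg(self.generator.stream_ex, "ban_tokens"):
#     min_tokens = 0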
@@ -959,6 +977,7 @@
             logprobs=request_logprobs,
             stop_conditions=stop_conditions,
             banned_tokens=banned_tokens,
+            banned_strings=banned_strings,
             logit_bias=logit_bias,
         )
@@ -997,7 +1016,10 @@
             # Run dict generation
             # Guarantees return of chunk, eos, and chunk_token_ids
-            raw_generation = self.generator.stream_ex()
+            if generated_tokens < min_tokens:
+                raw_generation = self.generator.stream_ex(ban_tokens=eos_tokens)
+            else:
+                raw_generation = self.generator.stream_ex()

             if token_healing:
                 # Extract healed token
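
The effect of that branch, distilled: every streaming step checks how many tokens have been produced so far, and only passes the temporary EOS ban while under the minimum. A simplified sketch of the surrounding loop (assuming eos_tokens holds the tokenizer's EOS token ids and stream_ex returns the chunk/eos/chunk_token_ids dict noted above):

generated_tokens = 0
while generated_tokens < max_tokens:
    if generated_tokens < min_tokens:
        # EOS cannot be sampled until the minimum length is reached
        raw_generation = self.generator.stream_ex(ban_tokens=eos_tokens)
    else:
        raw_generation = self.generator.stream_ex()

    generated_tokens += len(raw_generation["chunk_token_ids"])
    if raw_generation["eos"]:
        break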

Second changed file: the sampler request model.

@@ -18,6 +18,11 @@ class BaseSamplerRequest(BaseModel):
         examples=[150],
     )

+    min_tokens: Optional[int] = Field(
+        default_factory=lambda: get_default_sampler_value("min_tokens", 0),
+        examples=[0],
+    )
+
     generate_window: Optional[int] = Field(
         default_factory=lambda: get_default_sampler_value("generate_window"),
         examples=[512],
@@ -27,6 +32,10 @@
         default_factory=lambda: get_default_sampler_value("stop", [])
     )

+    banned_strings: Optional[Union[str, List[str]]] = Field(
+        default_factory=lambda: get_default_sampler_value("banned_strings", [])
+    )
+
     token_healing: Optional[bool] = Field(
         default_factory=lambda: get_default_sampler_value("token_healing", False)
     )
@@ -252,6 +261,10 @@
         if self.stop and isinstance(self.stop, str):
             self.stop = [self.stop]

+        # Convert banned_strings to an array of strings
+        if self.banned_strings and isinstance(self.banned_strings, str):
+            self.banned_strings = [self.banned_strings]
+
         # Convert string banned tokens to an integer list
         if self.banned_tokens and isinstance(self.banned_tokens, str):
             self.banned_tokens = [
@@ -260,8 +273,10 @@
         gen_params = {
             "max_tokens": self.max_tokens,
+            "min_tokens": self.min_tokens,
             "generate_window": self.generate_window,
             "stop": self.stop,
+            "banned_strings": self.banned_strings,
             "add_bos_token": self.add_bos_token,
             "ban_eos_token": self.ban_eos_token,
             "skip_special_tokens": self.skip_special_tokens,

Third changed file: the sampler overrides preset (YAML).

@@ -11,10 +11,17 @@
 max_tokens:
   override: 150
   force: false
+min_tokens:
+  override: 0
+  force: false
 stop:
   override: []
   force: false
   additive: false
+banned_strings:
+  override: []
+  force: false
+  additive: false
 token_healing:
   override: false
   force: false
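
The new preset entries follow the existing override schema: override supplies a server-side value, force (as the schema reads) applies it even when the request provides its own, and additive merges list values rather than replacing them. A hypothetical sketch of that resolution logic, under those assumed semantics:

def resolve_override(name: str, request_value, preset: dict):
    """Hypothetical resolution of one sampler override entry."""
    entry = preset.get(name, {})
    override = entry.get("override")

    if entry.get("force"):
        return override  # preset always wins
    if request_value is None:
        return override  # preset only fills in a missing value
    if entry.get("additive") and isinstance(override, list):
        return override + list(request_value)  # merge preset and request lists
    return request_value

# e.g. banned_strings with {override: ["foo"], force: False, additive: True}
# and a request sending ["bar"] would resolve to ["foo", "bar"]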