Merge pull request #102 from DocShotgun/main
Add support for min_tokens and banned_strings
Commit 5432f523cb, 3 changed files with 53 additions and 9 deletions.
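For context before the diffs: both new samplers are per-request options. A minimal client sketch of how they might be sent (the endpoint path and payload shape are assumptions based on the fields added below, not taken from this commit):

import requests

# Hypothetical OAI-style completions route served by the API
payload = {
    "prompt": "Write a short story.",
    "max_tokens": 200,
    "min_tokens": 50,                 # keep EOS banned until 50 tokens exist
    "banned_strings": ["As an AI"],   # strings the stream may never emit
}
response = requests.post("http://localhost:5000/v1/completions", json=payload)
print(response.json())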
File 1 of 3: the ExllamaV2 backend container.

@@ -791,6 +791,7 @@ class ExllamaV2Container:
         )

         stop_conditions: List[Union[str, int]] = unwrap(kwargs.get("stop"), [])
+        banned_strings: List[str] = unwrap(kwargs.get("banned_strings"), [])
         add_bos_token = unwrap(kwargs.get("add_bos_token"), True)
         ban_eos_token = unwrap(kwargs.get("ban_eos_token"), False)
         logit_bias = kwargs.get("logit_bias")
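unwrap is used throughout these hunks; a minimal sketch of its assumed behavior (a None-coalescing helper, not part of this diff):

from typing import Optional, TypeVar

T = TypeVar("T")

def unwrap(wrapped: Optional[T], default: T) -> T:
    # Fall back to the default when the request omitted the field
    return wrapped if wrapped is not None else default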
@@ -905,6 +906,9 @@ class ExllamaV2Container:
             kwargs.get("max_tokens"), self.config.max_seq_len - prompt_tokens
         )

+        # Set min_tokens to generate while keeping EOS banned
+        min_tokens = unwrap(kwargs.get("min_tokens"), 0)
+
         # This is an inverse of skip_special_tokens
         decode_special_tokens = unwrap(not kwargs.get("skip_special_tokens"), False)

@@ -925,26 +929,40 @@ class ExllamaV2Container:
             }
         )

-        # Check if decode_special_tokens is supported
-        # TODO: Remove when a new version of ExllamaV2 is released
-        if decode_special_tokens:
+        # MARK: Function signature checks. Not needed in newer ExllamaV2 versions
+
+        # Check if temporary token bans are supported
+        if min_tokens:
+            stream_signature = signature(self.generator.stream_ex)
+
+            try:
+                _bound_vars = stream_signature.bind_partial(ban_tokens=[])
+            except TypeError:
+                logger.warning(
+                    "min_tokens is not supported by the currently "
+                    "installed ExLlamaV2 version."
+                )
+                min_tokens = 0
+
+        # Check if banned_strings is supported
+        if banned_strings:
             begin_stream_signature = signature(self.generator.begin_stream_ex)

             try:
-                _bound_vars = begin_stream_signature.bind_partial(
-                    decode_special_tokens=True
-                )
-                begin_stream_args["decode_special_tokens"] = decode_special_tokens
+                _bound_vars = begin_stream_signature.bind_partial(banned_strings=[])
+                begin_stream_args["banned_strings"] = banned_strings
             except TypeError:
                 logger.warning(
-                    "skip_special_tokens is not supported by the currently "
+                    "banned_strings is not supported by the currently "
                     "installed ExLlamaV2 version."
                 )
+                banned_strings = []

         # Log generation options to console
         # Some options are too large, so log the args instead
         log_generation_params(
             max_tokens=max_tokens,
+            min_tokens=min_tokens,
             stream=kwargs.get("stream"),
             **gen_settings_log_dict,
             token_healing=token_healing,
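Both version checks above lean on the same trick: Signature.bind_partial raises TypeError when a keyword is not accepted, so it doubles as a feature probe without ever calling the generator. A standalone sketch of the pattern (the helper name and stand-in function are hypothetical):

from inspect import signature

def supports_kwarg(func, kwarg: str) -> bool:
    # bind_partial raises TypeError if `kwarg` is not in func's signature
    try:
        signature(func).bind_partial(**{kwarg: None})
        return True
    except TypeError:
        return False

def stream_ex(ban_tokens=None):  # stand-in for generator.stream_ex
    pass

assert supports_kwarg(stream_ex, "ban_tokens") is True
assert supports_kwarg(stream_ex, "banned_strings") is False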
@@ -959,6 +977,7 @@ class ExllamaV2Container:
             logprobs=request_logprobs,
             stop_conditions=stop_conditions,
             banned_tokens=banned_tokens,
+            banned_strings=banned_strings,
             logit_bias=logit_bias,
         )

@@ -997,7 +1016,10 @@ class ExllamaV2Container:

             # Run dict generation
             # Guarantees return of chunk, eos, and chunk_token_ids
-            raw_generation = self.generator.stream_ex()
+            if generated_tokens < min_tokens:
+                raw_generation = self.generator.stream_ex(ban_tokens=eos_tokens)
+            else:
+                raw_generation = self.generator.stream_ex()

             if token_healing:
                 # Extract healed token
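That branch runs once per streamed chunk: while fewer than min_tokens tokens have been produced, the EOS token ids are passed as a temporary ban so the model cannot stop early. A rough, self-contained sketch of the loop shape with a fake generator (not the exact container code):

eos_tokens = [2]
min_tokens = 5

class FakeGenerator:
    def __init__(self):
        self.step = 0

    def stream_ex(self, ban_tokens=None):
        self.step += 1
        # Pretend the model wants to stop immediately unless EOS is banned
        return {"chunk": "x", "eos": ban_tokens is None,
                "chunk_token_ids": [self.step]}

generator = FakeGenerator()
generated_tokens = 0
while True:
    if generated_tokens < min_tokens:
        # EOS stays banned only for this step
        raw = generator.stream_ex(ban_tokens=eos_tokens)
    else:
        raw = generator.stream_ex()
    generated_tokens += len(raw["chunk_token_ids"])
    if raw["eos"]:
        break

print(generated_tokens)  # reaches min_tokens before EOS is allowed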
File 2 of 3: the base sampler request model.

@@ -18,6 +18,11 @@ class BaseSamplerRequest(BaseModel):
         examples=[150],
     )

+    min_tokens: Optional[int] = Field(
+        default_factory=lambda: get_default_sampler_value("min_tokens", 0),
+        examples=[0],
+    )
+
     generate_window: Optional[int] = Field(
         default_factory=lambda: get_default_sampler_value("generate_window"),
         examples=[512],
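get_default_sampler_value supplies server-side defaults for each field; a plausible stand-in, assuming it is a simple keyed lookup with a fallback (not taken from this commit):

SAMPLER_DEFAULTS: dict = {}

def get_default_sampler_value(key: str, fallback=None):
    # Server-configured default for the sampler, else the hard-coded fallback
    return SAMPLER_DEFAULTS.get(key, fallback)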
@@ -27,6 +32,10 @@ class BaseSamplerRequest(BaseModel):
         default_factory=lambda: get_default_sampler_value("stop", [])
     )

+    banned_strings: Optional[Union[str, List[str]]] = Field(
+        default_factory=lambda: get_default_sampler_value("banned_strings", [])
+    )
+
     token_healing: Optional[bool] = Field(
         default_factory=lambda: get_default_sampler_value("token_healing", False)
     )
@@ -252,6 +261,10 @@ class BaseSamplerRequest(BaseModel):
         if self.stop and isinstance(self.stop, str):
             self.stop = [self.stop]

+        # Convert banned_strings to an array of strings
+        if self.banned_strings and isinstance(self.banned_strings, str):
+            self.banned_strings = [self.banned_strings]
+
         # Convert string banned tokens to an integer list
         if self.banned_tokens and isinstance(self.banned_tokens, str):
             self.banned_tokens = [
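This mirrors the existing stop handling: a bare string is promoted to a one-element list so downstream code can always iterate. A self-contained sketch (hypothetical model name; the validation hook is simplified into a plain method):

from typing import List, Optional, Union
from pydantic import BaseModel

class Req(BaseModel):
    banned_strings: Optional[Union[str, List[str]]] = None

    def normalize(self) -> "Req":
        if self.banned_strings and isinstance(self.banned_strings, str):
            self.banned_strings = [self.banned_strings]
        return self

assert Req(banned_strings="As an AI").normalize().banned_strings == ["As an AI"]
assert Req(banned_strings=["a", "b"]).normalize().banned_strings == ["a", "b"]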
@@ -260,8 +273,10 @@ class BaseSamplerRequest(BaseModel):
         gen_params = {
             "max_tokens": self.max_tokens,
+            "min_tokens": self.min_tokens,
             "generate_window": self.generate_window,
             "stop": self.stop,
+            "banned_strings": self.banned_strings,
             "add_bos_token": self.add_bos_token,
             "ban_eos_token": self.ban_eos_token,
             "skip_special_tokens": self.skip_special_tokens,
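These two new keys are what the container reads back as kwargs.get("min_tokens") and kwargs.get("banned_strings") in the first file's hunks, closing the loop from request model to generator.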
File 3 of 3: the sampler override preset (YAML).

@@ -11,10 +11,17 @@
 max_tokens:
   override: 150
   force: false
+min_tokens:
+  override: 0
+  force: false
 stop:
   override: []
   force: false
   additive: false
+banned_strings:
+  override: []
+  force: false
+  additive: false
 token_healing:
   override: false
   force: false
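The new entries mirror their neighbors: judging by the existing stop block, override supplies the default value, force applies it regardless of the request, and additive appears to merge list-valued overrides with request-supplied values, though this preset file does not define those semantics itself.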