Model: Auto-scale max_tokens by default

If max_tokens is None, it automatically scales to fill the remaining context.
This does not mean the generation will actually fill that context, since
EOS tokens and other stop conditions can end it earlier.
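
In practice, the new default is the context length minus the prompt length. A minimal sketch of the fallback, using illustrative names (max_seq_len and prompt_tokens stand in for the values used in the diff below):

    # Sketch only: resolve max_tokens when the request leaves it unset
    if max_tokens is None:
        max_tokens = max_seq_len - prompt_tokens  # remaining context
    # Generation can still end earlier on an EOS token or another stop condition.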

Originally suggested by #86

Signed-off-by: kingbri <bdashore3@proton.me>
kingbri 2024-03-18 22:54:59 -04:00
parent 8cbb59d6e1
commit 09a4c79847
2 changed files with 28 additions and 20 deletions

@@ -638,7 +638,6 @@ class ExllamaV2Container:
         """
         token_healing = unwrap(kwargs.get("token_healing"), False)
-        max_tokens = unwrap(kwargs.get("max_tokens"), 150)
         stream_interval = unwrap(kwargs.get("stream_interval"), 0)
         generate_window = max(
             unwrap(kwargs.get("generate_window"), 512), self.config.max_seq_len // 8
         )
@@ -761,24 +760,8 @@ class ExllamaV2Container:
             gen_settings.top_p = 0
             gen_settings.typical = 0

-        # Log generation options to console
-        # Some options are too large, so log the args instead
-        log_generation_params(
-            max_tokens=max_tokens,
-            **vars(gen_settings),
-            token_healing=token_healing,
-            auto_scale_penalty_range=auto_scale_penalty_range,
-            generate_window=generate_window,
-            add_bos_token=add_bos_token,
-            ban_eos_token=ban_eos_token,
-            speculative_ngram=self.generator.speculative_ngram,
-            logprobs=request_logprobs,
-            stop_conditions=stop_conditions,
-            logit_bias=logit_bias,
-        )
-
-        # Log prompt to console
-        log_prompt(prompt, negative_prompt)
+        # Store the gen settings for logging purposes
+        gen_settings_log_dict = vars(gen_settings)

         # Set logit bias
         if logit_bias:
@@ -854,6 +837,31 @@ class ExllamaV2Container:

         prompt_tokens = ids.shape[-1]

+        # Automatically set max_tokens to fill up the context
+        # This should be an OK default, but may be changed in the future
+        max_tokens = unwrap(
+            kwargs.get("max_tokens"), self.config.max_seq_len - prompt_tokens
+        )
+
+        # Log generation options to console
+        # Some options are too large, so log the args instead
+        log_generation_params(
+            max_tokens=max_tokens,
+            **gen_settings_log_dict,
+            token_healing=token_healing,
+            auto_scale_penalty_range=auto_scale_penalty_range,
+            generate_window=generate_window,
+            add_bos_token=add_bos_token,
+            ban_eos_token=ban_eos_token,
+            speculative_ngram=self.generator.speculative_ngram,
+            logprobs=request_logprobs,
+            stop_conditions=stop_conditions,
+            logit_bias=logit_bias,
+        )
+
+        # Log prompt to console
+        log_prompt(prompt, negative_prompt)
+
         # Begin
         generated_tokens = 0
         full_response = ""
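
As a concrete example, with max_seq_len = 4096 and a 1,000-token prompt, an omitted max_tokens now resolves to 4096 - 1000 = 3096. The log_generation_params and log_prompt calls move here, after tokenization, so the logged max_tokens is the resolved value.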

@@ -14,7 +14,7 @@ class BaseSamplerRequest(BaseModel):
     """Common class for sampler params that are used in APIs"""

     max_tokens: Optional[int] = Field(
-        default_factory=lambda: get_default_sampler_value("max_tokens", 150),
+        default_factory=lambda: get_default_sampler_value("max_tokens"),
         examples=[150],
     )

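With the hard-coded 150 gone, a request that omits max_tokens (and, presumably, has no configured sampler default for it) reaches the container as None, so unwrap falls through to the context-based default. A rough sketch of the assumed unwrap behavior, not the project's exact implementation:

    def unwrap(value, default=None):
        """Return the value if it is set, otherwise the default (assumed semantics)."""
        return default if value is None else value

    # With no explicit max_tokens:
    # unwrap(None, max_seq_len - prompt_tokens) == max_seq_len - prompt_tokens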