Model: Auto-scale max_tokens by default
If max_tokens is None, it automatically scales to fill the remaining context. This does not mean the generation will fill that context, since EOS stops still apply.

Originally suggested by #86

Signed-off-by: kingbri <bdashore3@proton.me>
parent 8cbb59d6e1
commit 09a4c79847
2 changed files with 28 additions and 20 deletions
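In effect, the default generation budget becomes whatever context remains after the prompt. A minimal sketch of that rule, assuming a hypothetical resolve_max_tokens helper (the name is illustrative only; the actual change lives in ExllamaV2Container below):

from typing import Optional


def resolve_max_tokens(
    requested: Optional[int], max_seq_len: int, prompt_tokens: int
) -> int:
    """Use the requested budget, or fill the remaining context when it is None."""
    if requested is not None:
        return requested
    # Auto-scale: everything left after the prompt. EOS tokens and stop
    # conditions can still end generation well before this budget is reached.
    return max_seq_len - prompt_tokens


assert resolve_max_tokens(150, 4096, 1000) == 150    # explicit request wins
assert resolve_max_tokens(None, 4096, 1000) == 3096  # omitted -> fill the context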
@@ -638,7 +638,6 @@ class ExllamaV2Container:
         """

         token_healing = unwrap(kwargs.get("token_healing"), False)
-        max_tokens = unwrap(kwargs.get("max_tokens"), 150)
         stream_interval = unwrap(kwargs.get("stream_interval"), 0)
         generate_window = max(
             unwrap(kwargs.get("generate_window"), 512), self.config.max_seq_len // 8
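The hunk above removes the hardcoded 150-token fallback, so a request that omits max_tokens now reaches the generator as None. For reference, a minimal stand-in for the unwrap helper, inferred from how it is called here (value if present, otherwise the supplied default); the project's real implementation may differ:

# Illustrative stand-in for `unwrap`, inferred from its call sites above.
def unwrap(wrapped, default=None):
    """Return `wrapped` unless it is None, in which case return `default`."""
    return wrapped if wrapped is not None else default


# Before this commit: unwrap(kwargs.get("max_tokens"), 150) -> 150 when omitted.
# After this commit: the fallback is computed later from the free context instead.
print(unwrap(None, 512))  # 512
print(unwrap(200, 512))   # 200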
@@ -761,24 +760,8 @@ class ExllamaV2Container:
             gen_settings.top_p = 0
             gen_settings.typical = 0

-        # Log generation options to console
-        # Some options are too large, so log the args instead
-        log_generation_params(
-            max_tokens=max_tokens,
-            **vars(gen_settings),
-            token_healing=token_healing,
-            auto_scale_penalty_range=auto_scale_penalty_range,
-            generate_window=generate_window,
-            add_bos_token=add_bos_token,
-            ban_eos_token=ban_eos_token,
-            speculative_ngram=self.generator.speculative_ngram,
-            logprobs=request_logprobs,
-            stop_conditions=stop_conditions,
-            logit_bias=logit_bias,
-        )
-
-        # Log prompt to console
-        log_prompt(prompt, negative_prompt)
+        # Store the gen settings for logging purposes
+        gen_settings_log_dict = vars(gen_settings)

         # Set logit bias
         if logit_bias:
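This hunk only captures vars(gen_settings) into gen_settings_log_dict; the logging call itself moves below, after max_tokens is known. A small sketch of the vars() + ** splat pattern being reused there (the class and logger names here are placeholders, not the project's):

# Placeholder names for illustration only.
class FakeGenSettings:
    def __init__(self):
        self.temperature = 0.8
        self.top_p = 0.9


def log_generation_params(**kwargs):
    print(" | ".join(f"{key}={value}" for key, value in kwargs.items()))


gen_settings = FakeGenSettings()
gen_settings_log_dict = vars(gen_settings)  # {'temperature': 0.8, 'top_p': 0.9}

# Later, the stored dict is expanded alongside explicit keyword arguments:
log_generation_params(max_tokens=3096, **gen_settings_log_dict, token_healing=False)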
@@ -854,6 +837,31 @@ class ExllamaV2Container:
         prompt_tokens = ids.shape[-1]

+        # Automatically set max_tokens to fill up the context
+        # This should be an OK default, but may be changed in the future
+        max_tokens = unwrap(
+            kwargs.get("max_tokens"), self.config.max_seq_len - prompt_tokens
+        )
+
+        # Log generation options to console
+        # Some options are too large, so log the args instead
+        log_generation_params(
+            max_tokens=max_tokens,
+            **gen_settings_log_dict,
+            token_healing=token_healing,
+            auto_scale_penalty_range=auto_scale_penalty_range,
+            generate_window=generate_window,
+            add_bos_token=add_bos_token,
+            ban_eos_token=ban_eos_token,
+            speculative_ngram=self.generator.speculative_ngram,
+            logprobs=request_logprobs,
+            stop_conditions=stop_conditions,
+            logit_bias=logit_bias,
+        )
+
+        # Log prompt to console
+        log_prompt(prompt, negative_prompt)
+
         # Begin
         generated_tokens = 0
         full_response = ""
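One arithmetic consequence of the fallback as written above, shown with illustrative numbers (this is only the math of max_seq_len - prompt_tokens, not a claim about how the server handles oversized prompts):

def default_max_tokens(max_seq_len: int, prompt_tokens: int) -> int:
    # Mirrors the fallback computed in the hunk above.
    return max_seq_len - prompt_tokens


print(default_max_tokens(2048, 512))   # 1536 tokens available to generate
print(default_max_tokens(2048, 2048))  # 0: the prompt already fills the context
# A prompt longer than max_seq_len would make this negative; prompt-length
# validation is assumed to happen elsewhere.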
@@ -14,7 +14,7 @@ class BaseSamplerRequest(BaseModel):
     """Common class for sampler params that are used in APIs"""

     max_tokens: Optional[int] = Field(
-        default_factory=lambda: get_default_sampler_value("max_tokens", 150),
+        default_factory=lambda: get_default_sampler_value("max_tokens"),
         examples=[150],
     )
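On the API side, max_tokens now resolves to None unless a sampler default is configured, letting the backend apply the auto-scale fallback. A hedged sketch of that Field pattern; get_default_sampler_value is stubbed here and merely stands in for the project's lookup of configured sampler defaults:

from typing import Optional

from pydantic import BaseModel, Field

SAMPLER_DEFAULTS = {}  # pretend no user-configured overrides exist


def get_default_sampler_value(key, fallback=None):
    # Stub: without the hardcoded 150 fallback, an unconfigured key yields None.
    return SAMPLER_DEFAULTS.get(key, fallback)


class BaseSamplerRequest(BaseModel):
    max_tokens: Optional[int] = Field(
        default_factory=lambda: get_default_sampler_value("max_tokens"),
        examples=[150],
    )


print(BaseSamplerRequest().max_tokens)                # None -> backend auto-scales
print(BaseSamplerRequest(max_tokens=256).max_tokens)  # 256: explicit request wins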