API + Model: Add speculative ngram decoding
Speculative ngram decoding is like speculative decoding, but without a draft model: draft tokens are taken from repeated n-grams already present in the context rather than produced by a second model. It is less broadly useful because it only speeds up predictable, repetitive sequences, but whether that trade-off matters depends on the use case.

Signed-off-by: kingbri <bdashore3@proton.me>
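For intuition, here is a minimal, self-contained sketch of the n-gram proposal step (illustrative only, not ExLlamaV2's implementation; the function and variable names are made up):

# A sketch of ngram speculation: guess draft tokens by matching the
# trailing n-gram against an earlier occurrence in the context. The real
# model then verifies the draft in a single forward pass and keeps the
# longest correct prefix, so a wrong guess costs nothing extra.
from typing import List


def propose_ngram_draft(
    tokens: List[int], ngram_size: int = 3, max_draft: int = 8
) -> List[int]:
    """Return the tokens that followed the most recent earlier occurrence
    of the trailing n-gram, or [] when the context never repeats."""
    if len(tokens) <= ngram_size:
        return []
    tail = tokens[-ngram_size:]
    # Walk backwards, skipping the trailing n-gram itself
    for start in range(len(tokens) - ngram_size - 1, -1, -1):
        if tokens[start : start + ngram_size] == tail:
            return tokens[start + ngram_size : start + ngram_size + max_draft]
    return []  # no match: fall back to normal one-token decoding


context = [7, 1, 2, 3, 4, 5, 9, 1, 2, 3]
print(propose_ngram_draft(context))  # -> [4, 5, 9, 1, 2, 3]

In ExLlamaV2 the streaming generator handles both proposal and verification internally; the flag wired up in this commit just turns it on.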
parent 2ebefe8258
commit efc01d947b

3 changed files with 32 additions and 3 deletions
@@ -20,7 +20,12 @@ from loguru import logger
 from typing import List, Optional, Union
 
 from backends.exllamav2.grammar import ExLlamaV2Grammar
-from common.gen_logging import log_generation_params, log_metrics, log_prompt, log_response
+from common.gen_logging import (
+    log_generation_params,
+    log_metrics,
+    log_prompt,
+    log_response,
+)
 from common.templating import (
     PromptTemplate,
     find_template_from_model,
@@ -598,7 +603,17 @@ class ExllamaV2Container:
     def check_unsupported_settings(self, **kwargs):
         """Check and warn the user if a sampler is unsupported. Meant for dev wheels!"""
 
-        pass
+        if unwrap(kwargs.get("speculative_ngram"), False) and not hasattr(
+            ExLlamaV2StreamingGenerator, "speculative_ngram"
+        ):
+            logger.warning(
+                "Speculative ngram is not supported by the currently "
+                "installed ExLlamaV2 version."
+            )
+
+            kwargs.pop("speculative_ngram")
+
+        return kwargs
 
     # pylint: disable=too-many-locals,too-many-branches,too-many-statements
     def generate_gen(self, prompt: str, **kwargs):
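The net effect on a wheel that predates the feature is graceful degradation rather than a crash (sketch; `container` stands in for an ExllamaV2Container instance):

# Illustrative only: on an ExLlamaV2 build without ngram support,
# the unsupported key is warned about and stripped.
kwargs = {"temperature": 0.7, "speculative_ngram": True}
kwargs = container.check_unsupported_settings(**kwargs)
# A warning is logged and generation proceeds without speculation:
# kwargs == {"temperature": 0.7}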
@@ -656,7 +671,7 @@ class ExllamaV2Container:
         gen_settings = ExLlamaV2Sampler.Settings()
 
         # Check unsupported settings for dev wheels
-        self.check_unsupported_settings(**kwargs)
+        kwargs = self.check_unsupported_settings(**kwargs)
 
         # Apply settings
         gen_settings.temperature = unwrap(kwargs.get("temperature"), 1.0)
@@ -758,6 +773,11 @@ class ExllamaV2Container:
         request_logprobs = unwrap(kwargs.get("logprobs"), 0)
         self.generator.return_top_tokens = request_logprobs
 
+        # Speculative Ngram
+        self.generator.speculative_ngram = unwrap(
+            kwargs.get("speculative_ngram"), False
+        )
+
         # Override sampler settings for temp = 0
         if gen_settings.temperature == 0:
             gen_settings.temperature = 1.0
@@ -775,6 +795,7 @@ class ExllamaV2Container:
             generate_window=generate_window,
             add_bos_token=add_bos_token,
             ban_eos_token=ban_eos_token,
+            speculative_ngram=self.generator.speculative_ngram,
             logprobs=request_logprobs,
             stop_conditions=stop_conditions,
             logit_bias=logit_bias,
@@ -123,6 +123,10 @@ class BaseSamplerRequest(BaseModel):
         default_factory=lambda: get_default_sampler_value("grammar_string"),
     )
 
+    speculative_ngram: Optional[bool] = Field(
+        default_factory=lambda: get_default_sampler_value("speculative_ngram"),
+    )
+
     # Aliased variables
     typical: Optional[float] = Field(
         default_factory=lambda: get_default_sampler_value("typical", 1.0),
@@ -268,6 +272,7 @@ class BaseSamplerRequest(BaseModel):
             "negative_prompt": self.negative_prompt,
             "json_schema": self.json_schema,
             "grammar_string": self.grammar_string,
+            "speculative_ngram": self.speculative_ngram,
         }
 
         return {**gen_params, **kwargs}
@@ -17,6 +17,9 @@ stop:
 token_healing:
   override: false
   force: false
+speculative_ngram:
+  override: false
+  force: false
 
 # Commented out because the default is dynamically scaled
 #generate_window:
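With the API change in place, a client can opt in per request via the new `speculative_ngram` sampler field. A hypothetical call follows; the host, route, and auth header are assumptions about the deployment, not part of this commit:

import requests

# Hypothetical endpoint and key; only the "speculative_ngram" field is new.
response = requests.post(
    "http://localhost:5000/v1/completions",
    headers={"Authorization": "Bearer <api-key>"},
    json={
        "prompt": "def fibonacci(n):",
        "max_tokens": 200,
        "speculative_ngram": True,
    },
)
print(response.json())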