From efc01d947bedc61487d9ecb3561aa218e23023f9 Mon Sep 17 00:00:00 2001
From: kingbri
Date: Wed, 13 Mar 2024 23:32:11 -0400
Subject: [PATCH] API + Model: Add speculative ngram decoding

Speculative ngram decoding is like speculative decoding without the
draft model: draft tokens are guessed by matching ngrams that already
appear in the context. It's not as broadly useful because it only
speeds up predictable sequences, but that depends on the use case.

Signed-off-by: kingbri
---
 backends/exllamav2/model.py         | 27 ++++++++++++++++++++++++---
 common/sampling.py                  |  5 +++++
 sampler_overrides/sample_preset.yml |  3 +++
 3 files changed, 32 insertions(+), 3 deletions(-)

diff --git a/backends/exllamav2/model.py b/backends/exllamav2/model.py
index c803c15..77908d8 100644
--- a/backends/exllamav2/model.py
+++ b/backends/exllamav2/model.py
@@ -20,7 +20,12 @@ from loguru import logger
 from typing import List, Optional, Union
 
 from backends.exllamav2.grammar import ExLlamaV2Grammar
-from common.gen_logging import log_generation_params, log_metrics, log_prompt, log_response
+from common.gen_logging import (
+    log_generation_params,
+    log_metrics,
+    log_prompt,
+    log_response,
+)
 from common.templating import (
     PromptTemplate,
     find_template_from_model,
@@ -598,7 +603,17 @@ class ExllamaV2Container:
     def check_unsupported_settings(self, **kwargs):
         """Check and warn the user if a sampler is unsupported. Meant for dev wheels!"""
 
-        pass
+        if unwrap(kwargs.get("speculative_ngram"), False) and not hasattr(
+            ExLlamaV2StreamingGenerator, "speculative_ngram"
+        ):
+            logger.warning(
+                "Speculative ngram is not supported by the currently "
+                "installed ExLlamaV2 version."
+            )
+
+            kwargs.pop("speculative_ngram")
+
+        return kwargs
 
     # pylint: disable=too-many-locals,too-many-branches,too-many-statements
     def generate_gen(self, prompt: str, **kwargs):
@@ -656,7 +671,7 @@ class ExllamaV2Container:
         gen_settings = ExLlamaV2Sampler.Settings()
 
         # Check unsupported settings for dev wheels
-        self.check_unsupported_settings(**kwargs)
+        kwargs = self.check_unsupported_settings(**kwargs)
 
         # Apply settings
         gen_settings.temperature = unwrap(kwargs.get("temperature"), 1.0)
@@ -758,6 +773,11 @@ class ExllamaV2Container:
         request_logprobs = unwrap(kwargs.get("logprobs"), 0)
         self.generator.return_top_tokens = request_logprobs
 
+        # Speculative Ngram
+        self.generator.speculative_ngram = unwrap(
+            kwargs.get("speculative_ngram"), False
+        )
+
         # Override sampler settings for temp = 0
         if gen_settings.temperature == 0:
             gen_settings.temperature = 1.0
@@ -775,6 +795,7 @@ class ExllamaV2Container:
             generate_window=generate_window,
             add_bos_token=add_bos_token,
             ban_eos_token=ban_eos_token,
+            speculative_ngram=self.generator.speculative_ngram,
             logprobs=request_logprobs,
             stop_conditions=stop_conditions,
             logit_bias=logit_bias,
diff --git a/common/sampling.py b/common/sampling.py
index aa7e7d0..5a4ea94 100644
--- a/common/sampling.py
+++ b/common/sampling.py
@@ -123,6 +123,10 @@ class BaseSamplerRequest(BaseModel):
         default_factory=lambda: get_default_sampler_value("grammar_string"),
     )
 
+    speculative_ngram: Optional[bool] = Field(
+        default_factory=lambda: get_default_sampler_value("speculative_ngram"),
+    )
+
     # Aliased variables
     typical: Optional[float] = Field(
         default_factory=lambda: get_default_sampler_value("typical", 1.0),
@@ -268,6 +272,7 @@ class BaseSamplerRequest(BaseModel):
             "negative_prompt": self.negative_prompt,
             "json_schema": self.json_schema,
             "grammar_string": self.grammar_string,
+            "speculative_ngram": self.speculative_ngram,
         }
 
         return {**gen_params, **kwargs}
diff --git a/sampler_overrides/sample_preset.yml b/sampler_overrides/sample_preset.yml
index da28dbd..88dee47 100644
--- a/sampler_overrides/sample_preset.yml
+++ b/sampler_overrides/sample_preset.yml
@@ -17,6 +17,9 @@ stop:
 token_healing:
     override: false
     force: false
+speculative_ngram:
+    override: false
+    force: false
 
 # Commented out because the default is dynamically scaled
 #generate_window:
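
Reviewer note, not part of the patch: for anyone unfamiliar with the
technique, below is a toy sketch of what ngram speculation does
conceptually. The trailing ngram of the sequence is looked up earlier
in the context, the tokens that followed it are proposed as a draft,
and the model only verifies the draft. All names here (ngram_draft,
speculative_step, next_token) are illustrative, not ExLlamaV2's API,
and a real engine verifies every draft position in one batched forward
pass instead of looping:

    from typing import Callable, List

    def ngram_draft(ids: List[int], n: int = 3, max_draft: int = 8) -> List[int]:
        """Guess upcoming tokens by matching the trailing n-gram earlier in ids."""
        if len(ids) < n:
            return []
        key = ids[-n:]
        for i in range(len(ids) - n - 1, -1, -1):  # newest match first
            if ids[i : i + n] == key:
                return ids[i + n : i + n + max_draft]
        return []

    def speculative_step(next_token: Callable[[List[int]], int], ids: List[int]) -> int:
        """Run one decode step in place; return how many tokens were appended."""
        appended = 0
        for tok in ngram_draft(ids):
            real = next_token(ids)   # stand-in for one model forward pass
            appended += 1
            if real == tok:
                ids.append(tok)      # draft confirmed, effectively free
            else:
                ids.append(real)     # first mismatch: keep the model's token
                break
        if appended == 0:            # no usable draft: ordinary decoding
            ids.append(next_token(ids))
            appended = 1
        return appended

    # A degenerate "model" that continues a repeating pattern: the whole
    # draft is accepted, the best case for ngram speculation.
    pattern = [1, 2, 3]
    seq = [1, 2, 3, 1, 2, 3]
    print(speculative_step(lambda ids: pattern[len(ids) % 3], seq), seq)

On unpredictable text the draft misses immediately and this degrades to
the usual one-token-per-step loop, which is why the commit message says
it only helps on predictable sequences.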
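
Once merged, a client opts in per request through the new sampler
field. A hypothetical call for illustration -- host, port, endpoint
path, and response shape are assumptions; only the speculative_ngram
field comes from this patch:

    import requests

    payload = {
        "prompt": "def fibonacci(n):",
        "max_tokens": 200,
        "speculative_ngram": True,  # the field added in common/sampling.py
    }
    resp = requests.post("http://localhost:5000/v1/completions", json=payload)
    print(resp.json()["choices"][0]["text"])

Server-side, the same knob can be defaulted or forced for every request
via the speculative_ngram block added to
sampler_overrides/sample_preset.yml above.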