From 9f93505bc1c1b738cdd29f0c93b515bf3d7df148 Mon Sep 17 00:00:00 2001
From: kingbri
Date: Sun, 21 Apr 2024 00:37:46 -0400
Subject: [PATCH] OAI: Add skip_special_tokens parameter

Adds the ability to decode special tokens if the user wishes.

Signed-off-by: kingbri
---
 backends/exllamav2/model.py | 77 ++++++++++++++++++++-----------------
 common/sampling.py          |  6 +++
 2 files changed, 48 insertions(+), 35 deletions(-)

diff --git a/backends/exllamav2/model.py b/backends/exllamav2/model.py
index 2934292..d6b86ea 100644
--- a/backends/exllamav2/model.py
+++ b/backends/exllamav2/model.py
@@ -17,6 +17,7 @@ from exllamav2 import (
     ExLlamaV2Lora,
 )
 from exllamav2.generator import ExLlamaV2StreamingGenerator, ExLlamaV2Sampler
+from inspect import signature
 from itertools import zip_longest
 from loguru import logger
 from typing import List, Optional, Union
@@ -663,16 +664,6 @@ class ExllamaV2Container:
     def check_unsupported_settings(self, **kwargs):
         """Check and warn the user if a sampler is unsupported. Meant for dev wheels!"""
 
-        if unwrap(kwargs.get("speculative_ngram"), False) and not hasattr(
-            ExLlamaV2StreamingGenerator, "speculative_ngram"
-        ):
-            logger.warning(
-                "Speculative ngram is not supported by the currently "
-                "installed ExLlamaV2 version."
-            )
-
-            kwargs.pop("speculative_ngram")
-
         return kwargs
 
     async def generate_gen(
@@ -907,6 +898,42 @@ class ExllamaV2Container:
             kwargs.get("max_tokens"), self.config.max_seq_len - prompt_tokens
         )
 
+        # decode_special_tokens is the inverse of skip_special_tokens
+        decode_special_tokens = not unwrap(kwargs.get("skip_special_tokens"), True)
+
+        begin_stream_args = {
+            "token_healing": token_healing,
+            "loras": self.active_loras,
+            "return_probabilities": request_logprobs > 0,
+            "return_top_tokens": request_logprobs,
+            "return_logits": request_logprobs > 0,
+            "abort_event": abort_event,
+        }
+
+        if self.use_cfg:
+            begin_stream_args.update(
+                {
+                    "input_mask": mask,
+                    "position_offsets": offsets,
+                }
+            )
+
+        # Check if decode_special_tokens is supported
+        # TODO: Remove when a new version of ExllamaV2 is released
+        if decode_special_tokens:
+            begin_stream_signature = signature(self.generator.begin_stream_ex)
+
+            try:
+                _bound_vars = begin_stream_signature.bind_partial(
+                    decode_special_tokens=True
+                )
+                begin_stream_args["decode_special_tokens"] = decode_special_tokens
+            except TypeError:
+                logger.warning(
+                    "skip_special_tokens is not supported by the currently "
+                    "installed ExLlamaV2 version."
+                )
+
         # Log generation options to console
         # Some options are too large, so log the args instead
         log_generation_params(
@@ -920,6 +947,7 @@ class ExllamaV2Container:
             eos_token_id=eos_tokens,
             add_bos_token=add_bos_token,
             ban_eos_token=ban_eos_token,
+            skip_special_tokens=not decode_special_tokens,
             speculative_ngram=self.generator.speculative_ngram,
             logprobs=request_logprobs,
             stop_conditions=stop_conditions,
@@ -948,31 +976,10 @@ class ExllamaV2Container:
                 active_ids = ids[:, max(0, overflow) :]
                 chunk_tokens = self.config.max_seq_len - active_ids.shape[-1]
 
-                # Split for exllama versions that have CFG
-                if self.use_cfg:
-                    self.generator.begin_stream_ex(
-                        active_ids,
-                        gen_settings,
-                        token_healing=token_healing,
-                        loras=self.active_loras,
-                        input_mask=mask,
-                        position_offsets=offsets,
-                        return_probabilities=request_logprobs > 0,
-                        return_top_tokens=request_logprobs,
-                        return_logits=request_logprobs > 0,
-                        abort_event=abort_event,
-                    )
-                else:
-                    self.generator.begin_stream_ex(
-                        active_ids,
-                        gen_settings,
-                        token_healing=token_healing,
-                        loras=self.active_loras,
-                        return_probabilities=request_logprobs > 0,
-                        return_top_tokens=request_logprobs,
-                        return_logits=request_logprobs > 0,
-                        abort_event=abort_event,
-                    )
+                # Kick off the streaming generation
+                self.generator.begin_stream_ex(
+                    active_ids, gen_settings, **begin_stream_args
+                )
 
                 # Reset offsets for subsequent passes if the context is truncated
                 offsets = None
diff --git a/common/sampling.py b/common/sampling.py
index 022030c..e8be21b 100644
--- a/common/sampling.py
+++ b/common/sampling.py
@@ -106,6 +106,11 @@ class BaseSamplerRequest(BaseModel):
         examples=[False],
     )
 
+    skip_special_tokens: Optional[bool] = Field(
+        default_factory=lambda: get_default_sampler_value("skip_special_tokens", True),
+        examples=[True],
+    )
+
     logit_bias: Optional[Dict[int, float]] = Field(
         default_factory=lambda: get_default_sampler_value("logit_bias"),
         examples=[{"1": 10, "2": 50}],
@@ -246,6 +251,7 @@ class BaseSamplerRequest(BaseModel):
             "stop": self.stop,
             "add_bos_token": self.add_bos_token,
             "ban_eos_token": self.ban_eos_token,
+            "skip_special_tokens": self.skip_special_tokens,
             "token_healing": self.token_healing,
             "logit_bias": self.logit_bias,
             "temperature": self.temperature,
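
Background on the compatibility probe in the model.py hunk: inspect.signature(...).bind_partial(...) raises TypeError when asked to bind a keyword the callable does not accept, which makes it a cheap way to feature-detect a parameter without actually calling the function. Below is a minimal, self-contained sketch of the same pattern; old_api, new_api, and supports_kwarg are illustrative names, not part of ExLlamaV2 or this codebase.

from inspect import signature


def old_api(ids, *, token_healing=False):
    """Stand-in for an older begin_stream_ex without the new kwarg."""
    return ids


def new_api(ids, *, token_healing=False, decode_special_tokens=False):
    """Stand-in for a newer begin_stream_ex that accepts the kwarg."""
    return ids


def supports_kwarg(func, name):
    """Return True if func accepts a keyword argument called name."""
    try:
        # bind_partial() validates against the signature without
        # invoking func, so probing has no side effects
        signature(func).bind_partial(**{name: True})
        return True
    except TypeError:
        return False


assert not supports_kwarg(old_api, "decode_special_tokens")
assert supports_kwarg(new_api, "decode_special_tokens")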
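From the client side, the new field rides along with the other sampler options in the request body. A hedged usage sketch, assuming a local OAI-style completions endpoint: the URL, port, prompt, and special-token markers are placeholders, and any authentication headers your deployment requires are omitted. Requires the third-party "requests" package.

import requests

response = requests.post(
    "http://127.0.0.1:5000/v1/completions",  # placeholder host/port
    json={
        "prompt": "<|im_start|>user\nHello!<|im_end|>",
        "max_tokens": 32,
        # False turns off special-token skipping, so markers such as
        # <|im_end|> appear verbatim in the returned text
        "skip_special_tokens": False,
    },
)
print(response.json())

Omitting the field keeps the default of True, which preserves the previous behavior of stripping special tokens from the output.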