OAI: Add skip_special_tokens parameter

Allows the user to decode special tokens if they wish.

Signed-off-by: kingbri <bdashore3@proton.me>
kingbri 2024-04-21 00:37:46 -04:00
parent 67f061859d
commit 9f93505bc1
2 changed files with 48 additions and 35 deletions
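For reference, a minimal sketch of how a client could opt into decoding special tokens with the new parameter. Only skip_special_tokens itself (defaulting to true) comes from this commit; the URL, port, auth header, prompt, and response shape are assumptions about a typical OAI-compatible deployment.

# Hypothetical client call against an assumed OAI-style /v1/completions endpoint.
# Only "skip_special_tokens" is introduced by this commit; everything else is illustrative.
import requests

response = requests.post(
    "http://localhost:5000/v1/completions",  # assumed host/port
    headers={"Authorization": "Bearer <api-key>"},  # placeholder key
    json={
        "prompt": "<|im_start|>user\nHello!<|im_end|>\n<|im_start|>assistant\n",
        "max_tokens": 64,
        # False asks the backend to decode special tokens into the output
        # instead of skipping them; the default remains True.
        "skip_special_tokens": False,
    },
)
print(response.json()["choices"][0]["text"])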


@ -17,6 +17,7 @@ from exllamav2 import (
ExLlamaV2Lora,
)
from exllamav2.generator import ExLlamaV2StreamingGenerator, ExLlamaV2Sampler
from inspect import signature
from itertools import zip_longest
from loguru import logger
from typing import List, Optional, Union
@ -663,16 +664,6 @@ class ExllamaV2Container:
def check_unsupported_settings(self, **kwargs):
"""Check and warn the user if a sampler is unsupported. Meant for dev wheels!"""
if unwrap(kwargs.get("speculative_ngram"), False) and not hasattr(
ExLlamaV2StreamingGenerator, "speculative_ngram"
):
logger.warning(
"Speculative ngram is not supported by the currently "
"installed ExLlamaV2 version."
)
kwargs.pop("speculative_ngram")
return kwargs
async def generate_gen(
@ -907,6 +898,42 @@ class ExllamaV2Container:
kwargs.get("max_tokens"), self.config.max_seq_len - prompt_tokens
)
# This is an inverse of skip_special_tokens
decode_special_tokens = not unwrap(kwargs.get("skip_special_tokens"), True)
begin_stream_args = {
"token_healing": token_healing,
"loras": self.active_loras,
"return_probabilities": request_logprobs > 0,
"return_top_tokens": request_logprobs,
"return_logits": request_logprobs > 0,
"abort_event": abort_event,
}
if self.use_cfg:
begin_stream_args.update(
{
"input_mask": mask,
"position_offsets": offsets,
}
)
# Check if decode_special_tokens is supported
# TODO: Remove when a new version of ExllamaV2 is released
if decode_special_tokens:
begin_stream_signature = signature(self.generator.begin_stream_ex)
try:
_bound_vars = begin_stream_signature.bind_partial(
decode_special_tokens=True
)
begin_stream_args["decode_special_tokens"] = decode_special_tokens
except TypeError:
logger.warning(
"skip_special_tokens is not supported by the currently "
"installed ExLlamaV2 version."
)
# Log generation options to console
# Some options are too large, so log the args instead
log_generation_params(
@ -920,6 +947,7 @@ class ExllamaV2Container:
eos_token_id=eos_tokens,
add_bos_token=add_bos_token,
ban_eos_token=ban_eos_token,
skip_special_tokens=not decode_special_tokens,
speculative_ngram=self.generator.speculative_ngram,
logprobs=request_logprobs,
stop_conditions=stop_conditions,
@ -948,31 +976,10 @@ class ExllamaV2Container:
active_ids = ids[:, max(0, overflow) :]
chunk_tokens = self.config.max_seq_len - active_ids.shape[-1]
# Split for exllama versions that have CFG
if self.use_cfg:
self.generator.begin_stream_ex(
active_ids,
gen_settings,
token_healing=token_healing,
loras=self.active_loras,
input_mask=mask,
position_offsets=offsets,
return_probabilities=request_logprobs > 0,
return_top_tokens=request_logprobs,
return_logits=request_logprobs > 0,
abort_event=abort_event,
)
else:
self.generator.begin_stream_ex(
active_ids,
gen_settings,
token_healing=token_healing,
loras=self.active_loras,
return_probabilities=request_logprobs > 0,
return_top_tokens=request_logprobs,
return_logits=request_logprobs > 0,
abort_event=abort_event,
)
# Kick off the streaming generation
self.generator.begin_stream_ex(
active_ids, gen_settings, **begin_stream_args
)
# Reset offsets for subsequent passes if the context is truncated
offsets = None
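The TODO block in the hunk above feature-detects decode_special_tokens by probing the signature of begin_stream_ex rather than checking a version string. A self-contained sketch of that detection pattern follows, with placeholder generator functions standing in for old and new ExLlamaV2 builds (they are not real ExLlamaV2 APIs):

# Sketch: detect whether a callable accepts a given keyword argument.
from inspect import signature

def supports_kwarg(func, name: str) -> bool:
    # bind_partial raises TypeError when the callable has no parameter
    # matching the keyword, which is exactly what we want to catch.
    try:
        signature(func).bind_partial(**{name: None})
        return True
    except TypeError:
        return False

# Placeholder generators standing in for an old and a new build
def old_begin_stream(ids, settings):
    ...

def new_begin_stream(ids, settings, decode_special_tokens=False):
    ...

assert not supports_kwarg(old_begin_stream, "decode_special_tokens")
assert supports_kwarg(new_begin_stream, "decode_special_tokens")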


@ -106,6 +106,11 @@ class BaseSamplerRequest(BaseModel):
examples=[False],
)
skip_special_tokens: Optional[bool] = Field(
default_factory=lambda: get_default_sampler_value("ban_eos_token", True),
examples=[True],
)
logit_bias: Optional[Dict[int, float]] = Field(
default_factory=lambda: get_default_sampler_value("logit_bias"),
examples=[{"1": 10, "2": 50}],
@ -246,6 +251,7 @@ class BaseSamplerRequest(BaseModel):
"stop": self.stop,
"add_bos_token": self.add_bos_token,
"ban_eos_token": self.ban_eos_token,
"skip_special_tokens": self.skip_special_tokens,
"token_healing": self.token_healing,
"logit_bias": self.logit_bias,
"temperature": self.temperature,