OAI: Add skip_special_tokens parameter
Allows decoding special tokens if the user wishes.

Signed-off-by: kingbri <bdashore3@proton.me>
parent 67f061859d
commit 9f93505bc1
2 changed files with 48 additions and 35 deletions
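The parameter follows the OAI-style sampler fields: when a caller sets skip_special_tokens to false, special tokens (BOS/EOS and similar markers) are decoded into the output instead of being stripped. A minimal client-side sketch; the endpoint path, host, and the other payload fields are illustrative assumptions, not taken from this commit:

# Hypothetical request against an OAI-compatible completions route.
# Only skip_special_tokens is introduced by this commit; everything else is assumed.
import requests

payload = {
    "prompt": "Hello",              # assumed field for the example
    "max_tokens": 32,               # assumed field for the example
    "skip_special_tokens": False,   # new parameter: False keeps special tokens in the output
}

response = requests.post("http://localhost:5000/v1/completions", json=payload)
print(response.json())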
@@ -17,6 +17,7 @@ from exllamav2 import (
     ExLlamaV2Lora,
 )
 from exllamav2.generator import ExLlamaV2StreamingGenerator, ExLlamaV2Sampler
+from inspect import signature
 from itertools import zip_longest
 from loguru import logger
 from typing import List, Optional, Union
@@ -663,16 +664,6 @@ class ExllamaV2Container:
     def check_unsupported_settings(self, **kwargs):
         """Check and warn the user if a sampler is unsupported. Meant for dev wheels!"""

-        if unwrap(kwargs.get("speculative_ngram"), False) and not hasattr(
-            ExLlamaV2StreamingGenerator, "speculative_ngram"
-        ):
-            logger.warning(
-                "Speculative ngram is not supported by the currently "
-                "installed ExLlamaV2 version."
-            )
-
-            kwargs.pop("speculative_ngram")
-
         return kwargs

     async def generate_gen(
@@ -907,6 +898,42 @@ class ExllamaV2Container:
             kwargs.get("max_tokens"), self.config.max_seq_len - prompt_tokens
         )

+        # This is an inverse of skip_special_tokens
+        decode_special_tokens = unwrap(not kwargs.get("skip_special_tokens"), False)
+
+        begin_stream_args = {
+            "token_healing": token_healing,
+            "loras": self.active_loras,
+            "return_probabilities": request_logprobs > 0,
+            "return_top_tokens": request_logprobs,
+            "return_logits": request_logprobs > 0,
+            "abort_event": abort_event,
+        }
+
+        if self.use_cfg:
+            begin_stream_args.update(
+                {
+                    "input_mask": mask,
+                    "position_offsets": offsets,
+                }
+            )
+
+        # Check if decode_special_tokens is supported
+        # TODO: Remove when a new version of ExllamaV2 is released
+        if decode_special_tokens:
+            begin_stream_signature = signature(self.generator.begin_stream_ex)
+
+            try:
+                _bound_vars = begin_stream_signature.bind_partial(
+                    decode_special_tokens=True
+                )
+                begin_stream_args["decode_special_tokens"] = decode_special_tokens
+            except TypeError:
+                logger.warning(
+                    "skip_special_tokens is not supported by the currently "
+                    "installed ExLlamaV2 version."
+                )
+
         # Log generation options to console
         # Some options are too large, so log the args instead
         log_generation_params(
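The support check above leans on the standard-library inspect module: Signature.bind_partial raises TypeError when asked to bind a keyword the callable does not accept, which makes it a cheap feature probe against older ExLlamaV2 builds. A standalone sketch of the same technique with illustrative functions (not ExLlamaV2's actual API):

from inspect import signature

def old_begin_stream(ids, settings, token_healing=False):
    """Stand-in for a version without the new keyword."""

def new_begin_stream(ids, settings, token_healing=False, decode_special_tokens=False):
    """Stand-in for a version that accepts decode_special_tokens."""

def supports_kwarg(func, name):
    # bind_partial succeeds only if `name` is a real parameter of `func`
    try:
        signature(func).bind_partial(**{name: True})
        return True
    except TypeError:
        return False

print(supports_kwarg(old_begin_stream, "decode_special_tokens"))  # False
print(supports_kwarg(new_begin_stream, "decode_special_tokens"))  # True

Note that this probe would also succeed for any callable that takes **kwargs, so it assumes begin_stream_ex declares its keywords explicitly.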
@@ -920,6 +947,7 @@ class ExllamaV2Container:
             eos_token_id=eos_tokens,
             add_bos_token=add_bos_token,
             ban_eos_token=ban_eos_token,
+            skip_special_tokens=not decode_special_tokens,
             speculative_ngram=self.generator.speculative_ngram,
             logprobs=request_logprobs,
             stop_conditions=stop_conditions,
@@ -948,31 +976,10 @@ class ExllamaV2Container:
             active_ids = ids[:, max(0, overflow) :]
             chunk_tokens = self.config.max_seq_len - active_ids.shape[-1]

-            # Split for exllama versions that have CFG
-            if self.use_cfg:
-                self.generator.begin_stream_ex(
-                    active_ids,
-                    gen_settings,
-                    token_healing=token_healing,
-                    loras=self.active_loras,
-                    input_mask=mask,
-                    position_offsets=offsets,
-                    return_probabilities=request_logprobs > 0,
-                    return_top_tokens=request_logprobs,
-                    return_logits=request_logprobs > 0,
-                    abort_event=abort_event,
-                )
-            else:
-                self.generator.begin_stream_ex(
-                    active_ids,
-                    gen_settings,
-                    token_healing=token_healing,
-                    loras=self.active_loras,
-                    return_probabilities=request_logprobs > 0,
-                    return_top_tokens=request_logprobs,
-                    return_logits=request_logprobs > 0,
-                    abort_event=abort_event,
-                )
+            # Kick off the streaming generation
+            self.generator.begin_stream_ex(
+                active_ids, gen_settings, **begin_stream_args
+            )

             # Reset offsets for subsequent passes if the context is truncated
             offsets = None
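The hunk above also shows the design choice behind begin_stream_args: instead of duplicating the begin_stream_ex call for the CFG and non-CFG paths, optional arguments are collected in a dict and unpacked into a single call. A generic sketch of that pattern with a stand-in function (nothing here is ExLlamaV2 API):

def begin_stream(ids, settings, *, loras=None, input_mask=None, position_offsets=None):
    # Stand-in for a streaming call with several optional keyword arguments
    print(f"loras={loras}, mask={input_mask}, offsets={position_offsets}")

use_cfg = True                       # assumed flag for the example
mask, offsets = "mask", "offsets"    # placeholder values

stream_args = {"loras": []}
if use_cfg:
    # CFG-only arguments are added exactly once, in one place
    stream_args.update({"input_mask": mask, "position_offsets": offsets})

begin_stream("ids", "settings", **stream_args)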
@@ -106,6 +106,11 @@ class BaseSamplerRequest(BaseModel):
         examples=[False],
     )

+    skip_special_tokens: Optional[bool] = Field(
+        default_factory=lambda: get_default_sampler_value("ban_eos_token", True),
+        examples=[True],
+    )
+
     logit_bias: Optional[Dict[int, float]] = Field(
         default_factory=lambda: get_default_sampler_value("logit_bias"),
         examples=[{"1": 10, "2": 50}],
@@ -246,6 +251,7 @@ class BaseSamplerRequest(BaseModel):
             "stop": self.stop,
             "add_bos_token": self.add_bos_token,
             "ban_eos_token": self.ban_eos_token,
+            "skip_special_tokens": self.skip_special_tokens,
             "token_healing": self.token_healing,
             "logit_bias": self.logit_bias,
             "temperature": self.temperature,
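In the second file, the parameter is declared on the request schema and then forwarded into the sampler-parameter dict that the backend consumes. A simplified, self-contained sketch of that flow using Pydantic; the class and method below are stand-ins, not the project's actual BaseSamplerRequest:

from typing import Optional
from pydantic import BaseModel, Field

class SamplerRequestSketch(BaseModel):
    # Defaults to True: special tokens are stripped unless the caller opts out
    skip_special_tokens: Optional[bool] = Field(default=True, examples=[True])

    def to_gen_params(self):
        # Mirrors how the request field is passed along with the other sampler options
        return {"skip_special_tokens": self.skip_special_tokens}

request = SamplerRequestSketch.model_validate({"skip_special_tokens": False})
print(request.to_gen_params())  # {'skip_special_tokens': False}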