Model: Remove backwards compatibility hacks
Now that exllamav2 is required to be the latest, don't add attribute checks unless the feature is not in the release build.

Signed-off-by: kingbri <bdashore3@proton.me>
parent 6eeb62b82c
commit f1ea15d77e
1 changed file with 27 additions and 49 deletions
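For context, the pattern this commit removes probed the installed wheel with hasattr and warned when an optional attribute was missing; with exllamav2 pinned to the latest release, the attribute is guaranteed to exist and the value can be assigned directly. A minimal before/after sketch of that pattern (the helper names and the stdlib logger are illustrative stand-ins, not project code):

    import logging

    logger = logging.getLogger(__name__)

    # Before: probe the optional attribute and warn on older wheels
    def set_cfg_with_fallback(settings_cls, container, use_cfg: bool) -> None:
        if hasattr(settings_cls, "cfg_scale"):
            container.use_cfg = use_cfg
        elif use_cfg:
            logger.warning(
                "CFG is not supported by the currently installed ExLlamaV2 version."
            )

    # After: the pinned dependency guarantees the attribute, so assign directly
    def set_cfg(container, use_cfg: bool) -> None:
        container.use_cfg = use_cfg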
@@ -139,23 +139,10 @@ class ExllamaV2Container:
         )
 
         # Enable CFG if present
-        use_cfg = unwrap(kwargs.get("use_cfg"), False)
-        if hasattr(ExLlamaV2Sampler.Settings, "cfg_scale"):
-            self.use_cfg = use_cfg
-        elif use_cfg:
-            logger.warning(
-                "CFG is not supported by the currently installed ExLlamaV2 version."
-            )
+        self.use_cfg = unwrap(kwargs.get("use_cfg"), False)
 
         # Enable fasttensors loading if present
-        use_fasttensors = unwrap(kwargs.get("fasttensors"), False)
-        if hasattr(ExLlamaV2Config, "fasttensors"):
-            self.config.fasttensors = use_fasttensors
-        elif use_fasttensors:
-            logger.warning(
-                "fasttensors is not supported by "
-                "the currently installed ExllamaV2 version."
-            )
+        self.config.fasttensors = unwrap(kwargs.get("fasttensors"), False)
 
         # Turn off flash attention if CFG is on
         # Workaround until batched FA2 is fixed in exllamav2 upstream
@@ -189,13 +176,7 @@ class ExllamaV2Container:
         # Set num of experts per token if provided
         num_experts_override = kwargs.get("num_experts_per_token")
         if num_experts_override:
-            if hasattr(self.config, "num_experts_per_token"):
-                self.config.num_experts_per_token = num_experts_override
-            else:
-                logger.warning(
-                    "MoE experts per token override is not "
-                    "supported by the current ExLlamaV2 version."
-                )
+            self.config.num_experts_per_token = kwargs.get("num_experts_per_token")
 
         chunk_size = min(
             unwrap(kwargs.get("chunk_size"), 2048), self.config.max_seq_len
@@ -410,6 +391,10 @@ class ExllamaV2Container:
             self.draft_cache,
         )
 
+        # Always return logprobs and logits
+        self.generator.return_probabilities = True
+        self.generator.return_logits = True
+
         logger.info("Model successfully loaded.")
 
     def unload(self, loras_only: bool = False):
@@ -470,21 +455,7 @@ class ExllamaV2Container:
     def check_unsupported_settings(self, **kwargs):
         """Check and warn the user if a sampler is unsupported. Meant for dev wheels!"""
 
-        if kwargs.get("max_temp") > 0.0 and not hasattr(
-            ExLlamaV2Sampler.Settings, "max_temp"
-        ):
-            logger.warning(
-                "DynaTemp parameters are not supported by the currently "
-                "installed ExLlamaV2 version."
-            )
-
-        if (unwrap(kwargs.get("smoothing_factor"), 0.0)) > 0.0 and not hasattr(
-            ExLlamaV2Sampler.Settings, "smoothing_factor"
-        ):
-            logger.warning(
-                "Smoothing factor is not supported by the currently "
-                "installed ExLlamaV2 version."
-            )
+        pass
 
     def generate(self, prompt: str, **kwargs):
         """Generate a response to a prompt"""
@@ -566,20 +537,27 @@ class ExllamaV2Container:
         gen_settings.mirostat = unwrap(kwargs.get("mirostat"), False)
 
         # DynaTemp settings
-        if hasattr(gen_settings, "max_temp"):
-            max_temp = unwrap(kwargs.get("max_temp"), 0.0)
-            min_temp = unwrap(kwargs.get("min_temp"), 0.0)
-
-            if max_temp < min_temp or (
-                0 not in {min_temp, max_temp} and max_temp == min_temp
-            ):
-                logger.warning(
-                    "Max temp is less than or equal to min temp, skipping DynaTemp."
-                )
+        max_temp = unwrap(kwargs.get("max_temp"), 0.0)
+        min_temp = unwrap(kwargs.get("min_temp"), 0.0)
+
+        if max_temp > min_temp:
             gen_settings.max_temp = max_temp
             gen_settings.min_temp = min_temp
-            gen_settings.temp_exponent = kwargs.get("temp_exponent")
+            gen_settings.temp_exponent = unwrap(kwargs.get("temp_exponent"), 1.0)
+        else:
+            # Force to default values
+            gen_settings.max_temp = 0.0
+            gen_settings.min_temp = 0.0
+            gen_settings.temp_exponent = 1.0
 
+            # Warn if max/min temp values are > 0
+            # and if they're less than or equal to each other
+            if max_temp < min_temp or (
+                0 not in {min_temp, max_temp} and max_temp == min_temp
+            ):
+                logger.warning(
+                    "Max temp is less than or equal to min temp, skipping DynaTemp."
+                )
 
         # Default tau and eta fallbacks don't matter if mirostat is off
         gen_settings.mirostat_tau = unwrap(kwargs.get("mirostat_tau"), 1.5)
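The new DynaTemp branch above replaces the hasattr probe with a plain range check: the sampler fields are only populated when max_temp is strictly greater than min_temp; otherwise they are forced back to the disabled defaults and a misconfigured range produces a warning. A standalone sketch of that guard with a few example inputs (the helper function is illustrative, not project code):

    def resolve_dynatemp(max_temp: float, min_temp: float, temp_exponent: float):
        """Return (max_temp, min_temp, temp_exponent); fall back to the
        disabled defaults (0.0, 0.0, 1.0) unless max_temp > min_temp."""
        if max_temp > min_temp:
            return max_temp, min_temp, temp_exponent
        return 0.0, 0.0, 1.0

    assert resolve_dynatemp(2.0, 0.5, 1.3) == (2.0, 0.5, 1.3)  # valid range: DynaTemp engaged
    assert resolve_dynatemp(1.0, 1.0, 1.0) == (0.0, 0.0, 1.0)  # equal non-zero temps: skipped with a warning
    assert resolve_dynatemp(0.0, 0.0, 1.0) == (0.0, 0.0, 1.0)  # defaults: feature simply off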
@@ -760,7 +738,7 @@ class ExllamaV2Container:
             gen_settings.token_repetition_range = generated_tokens
 
         # Generate
-        chunk, eos, tokens, _, *extra_parts = self.generator.stream()
+        chunk, eos, tokens, _, _ = self.generator.stream()
 
         if token_healing:
            # Extract healed token
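Since the generator is now always created with return_probabilities and return_logits enabled (see the load hunk above), the call site can unpack a fixed five-value tuple instead of absorbing a variable tail with *extra_parts. A rough sketch of the resulting consumer loop, assuming that five-value shape (the run_stream helper is illustrative, not project code):

    def run_stream(generator) -> str:
        """Accumulate streamed text until the generator signals EOS."""
        text = ""
        while True:
            # chunk: decoded text, eos: stop flag, tokens: new token ids;
            # the last two slots carry the probabilities and logits enabled at load time
            chunk, eos, tokens, _, _ = generator.stream()
            text += chunk
            if eos:
                break
        return text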