Model: Remove backwards compatibility hacks

Now that the latest exllamav2 release is required, don't add attribute
checks unless the feature is missing from the release build.

Signed-off-by: kingbri <bdashore3@proton.me>
Author: kingbri <bdashore3@proton.me>
Date:   2024-02-02 23:40:10 -05:00
Parent: 6eeb62b82c
Commit: f1ea15d77e


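For context, the pattern removed throughout this diff is the hasattr-based feature gate: probe the installed ExLlamaV2 build for an attribute, use it when present, and warn otherwise. A minimal before/after sketch of that pattern, using placeholder objects rather than the real ExllamaV2Container / ExLlamaV2 classes:

    import logging
    from types import SimpleNamespace

    logger = logging.getLogger(__name__)

    # Placeholders standing in for ExLlamaV2Sampler.Settings and a requested value.
    settings = SimpleNamespace()
    cfg_scale = 1.5

    # Old pattern (removed): only touch the attribute if this build exposes it.
    if hasattr(settings, "cfg_scale"):
        settings.cfg_scale = cfg_scale
    else:
        logger.warning(
            "CFG is not supported by the currently installed ExLlamaV2 version."
        )

    # New pattern: the latest exllamav2 release is a hard requirement,
    # so the attribute is assigned unconditionally.
    settings.cfg_scale = cfg_scale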
@@ -139,23 +139,10 @@ class ExllamaV2Container:
         )
 
         # Enable CFG if present
-        use_cfg = unwrap(kwargs.get("use_cfg"), False)
-        if hasattr(ExLlamaV2Sampler.Settings, "cfg_scale"):
-            self.use_cfg = use_cfg
-        elif use_cfg:
-            logger.warning(
-                "CFG is not supported by the currently installed ExLlamaV2 version."
-            )
+        self.use_cfg = unwrap(kwargs.get("use_cfg"), False)
 
         # Enable fasttensors loading if present
-        use_fasttensors = unwrap(kwargs.get("fasttensors"), False)
-        if hasattr(ExLlamaV2Config, "fasttensors"):
-            self.config.fasttensors = use_fasttensors
-        elif use_fasttensors:
-            logger.warning(
-                "fasttensors is not supported by "
-                "the currently installed ExllamaV2 version."
-            )
+        self.config.fasttensors = unwrap(kwargs.get("fasttensors"), False)
 
         # Turn off flash attention if CFG is on
         # Workaround until batched FA2 is fixed in exllamav2 upstream
@@ -189,13 +176,7 @@ class ExllamaV2Container:
         # Set num of experts per token if provided
         num_experts_override = kwargs.get("num_experts_per_token")
         if num_experts_override:
-            if hasattr(self.config, "num_experts_per_token"):
-                self.config.num_experts_per_token = num_experts_override
-            else:
-                logger.warning(
-                    "MoE experts per token override is not "
-                    "supported by the current ExLlamaV2 version."
-                )
+            self.config.num_experts_per_token = kwargs.get("num_experts_per_token")
 
         chunk_size = min(
             unwrap(kwargs.get("chunk_size"), 2048), self.config.max_seq_len
@@ -410,6 +391,10 @@ class ExllamaV2Container:
             self.draft_cache,
         )
 
+        # Always return logprobs and logits
+        self.generator.return_probabilities = True
+        self.generator.return_logits = True
+
         logger.info("Model successfully loaded.")
 
     def unload(self, loras_only: bool = False):
@@ -470,21 +455,7 @@ class ExllamaV2Container:
     def check_unsupported_settings(self, **kwargs):
         """Check and warn the user if a sampler is unsupported. Meant for dev wheels!"""
 
-        if kwargs.get("max_temp") > 0.0 and not hasattr(
-            ExLlamaV2Sampler.Settings, "max_temp"
-        ):
-            logger.warning(
-                "DynaTemp parameters are not supported by the currently "
-                "installed ExLlamaV2 version."
-            )
-
-        if (unwrap(kwargs.get("smoothing_factor"), 0.0)) > 0.0 and not hasattr(
-            ExLlamaV2Sampler.Settings, "smoothing_factor"
-        ):
-            logger.warning(
-                "Smoothing factor is not supported by the currently "
-                "installed ExLlamaV2 version."
-            )
+        pass
 
     def generate(self, prompt: str, **kwargs):
         """Generate a response to a prompt"""
@@ -566,20 +537,27 @@ class ExllamaV2Container:
         gen_settings.mirostat = unwrap(kwargs.get("mirostat"), False)
 
         # DynaTemp settings
-        if hasattr(gen_settings, "max_temp"):
-            max_temp = unwrap(kwargs.get("max_temp"), 0.0)
-            min_temp = unwrap(kwargs.get("min_temp"), 0.0)
-
-            if max_temp < min_temp or (
-                0 not in {min_temp, max_temp} and max_temp == min_temp
-            ):
-                logger.warning(
-                    "Max temp is less than or equal to min temp, skipping DynaTemp."
-                )
+        max_temp = unwrap(kwargs.get("max_temp"), 0.0)
+        min_temp = unwrap(kwargs.get("min_temp"), 0.0)
 
+        if max_temp > min_temp:
             gen_settings.max_temp = max_temp
             gen_settings.min_temp = min_temp
-            gen_settings.temp_exponent = kwargs.get("temp_exponent")
+            gen_settings.temp_exponent = unwrap(kwargs.get("temp_exponent"), 1.0)
+        else:
+            # Force to default values
+            gen_settings.max_temp = 0.0
+            gen_settings.min_temp = 0.0
+            gen_settings.temp_exponent = 1.0
+
+        # Warn if max/min temp values are > 0
+        # and if they're less than or equal to each other
+        if max_temp < min_temp or (
+            0 not in {min_temp, max_temp} and max_temp == min_temp
+        ):
+            logger.warning(
+                "Max temp is less than or equal to min temp, skipping DynaTemp."
+            )
 
         # Default tau and eta fallbacks don't matter if mirostat is off
         gen_settings.mirostat_tau = unwrap(kwargs.get("mirostat_tau"), 1.5)
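A note on the DynaTemp hunk above: the warning condition treats min_temp = max_temp = 0 as "DynaTemp disabled" rather than an error, and only warns when the range is inverted or collapsed at a non-zero value. A standalone check of that logic (illustrative values only, not part of the commit):

    def dynatemp_skipped(min_temp: float, max_temp: float) -> bool:
        # Mirrors the warning condition from the hunk above: warn when the
        # range is inverted, or collapsed at a non-zero temperature.
        return max_temp < min_temp or (
            0 not in {min_temp, max_temp} and max_temp == min_temp
        )

    print(dynatemp_skipped(0.0, 0.0))  # False: DynaTemp disabled, no warning
    print(dynatemp_skipped(1.0, 1.0))  # True: collapsed non-zero range, warning fires
    print(dynatemp_skipped(2.0, 1.0))  # True: max below min, warning fires
    print(dynatemp_skipped(1.0, 2.0))  # False: valid range, DynaTemp applied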
@@ -760,7 +738,7 @@ class ExllamaV2Container:
                 gen_settings.token_repetition_range = generated_tokens
 
             # Generate
-            chunk, eos, tokens, _, *extra_parts = self.generator.stream()
+            chunk, eos, tokens, _, _ = self.generator.stream()
 
             if token_healing:
                 # Extract healed token
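Related detail: the *extra_parts catch-all in the last hunk tolerated older generator builds whose stream() returned fewer values; with return_probabilities and return_logits now forced on at load time (see the @@ -410,6 +391,10 @@ hunk), the commit relies on a fixed five-element return, so plain positional unpacking is used. A stand-in illustration of the two unpacking styles (fake_stream is hypothetical, not the ExLlamaV2 API):

    def fake_stream(with_extras: bool) -> tuple:
        """Stand-in generator step; NOT the real ExLlamaV2 stream() signature."""
        base = ("chunk", False, [[1, 2, 3]], None)
        return (base + (None,)) if with_extras else base

    # Old compatibility pattern: tolerate either tuple length.
    chunk, eos, tokens, _, *extra_parts = fake_stream(with_extras=False)

    # New pattern: the length is known up front, so unpack positionally.
    chunk, eos, tokens, _, _ = fake_stream(with_extras=True)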