diff --git a/backends/exllamav2/model.py b/backends/exllamav2/model.py
index cb12a15..3408c2f 100644
--- a/backends/exllamav2/model.py
+++ b/backends/exllamav2/model.py
@@ -1,6 +1,5 @@
 """The model container class for ExLlamaV2 models."""
 
-import aiofiles
 import asyncio
 import gc
 import math
@@ -29,11 +28,7 @@ from itertools import zip_longest
 from loguru import logger
 from typing import Dict, List, Optional
 
-from ruamel.yaml import YAML
-
 from backends.base_model_container import BaseModelContainer
-from common.health import HealthManager
-
 from backends.exllamav2.grammar import (
     ExLlamaV2Grammar,
     clear_grammar_func_cache,
@@ -51,6 +46,7 @@ from common.gen_logging import (
     log_prompt,
     log_response,
 )
+from common.health import HealthManager
 from common.multimodal import MultimodalEmbeddingWrapper
 from common.sampling import BaseSamplerRequest
 from common.templating import (
@@ -59,7 +55,7 @@ from common.templating import (
     find_template_from_model,
 )
 from common.transformers_utils import GenerationConfig
-from common.utils import coalesce, unwrap
+from common.utils import calculate_rope_alpha, coalesce, unwrap
 
 
 class ExllamaV2Container(BaseModelContainer):
@@ -244,9 +240,7 @@ class ExllamaV2Container(BaseModelContainer):
         base_seq_len = self.config.max_seq_len
 
         # Set the target seq len if present
-        target_max_seq_len = kwargs.get("max_seq_len")
-        if target_max_seq_len:
-            self.config.max_seq_len = target_max_seq_len
+        target_seq_len = kwargs.get("max_seq_len")
 
         # Set the rope scale
         self.config.scale_pos_emb = unwrap(
@@ -257,10 +251,16 @@ class ExllamaV2Container(BaseModelContainer):
         # Automatically calculate if unset or defined as an "auto" literal.
         rope_alpha = unwrap(kwargs.get("rope_alpha"), "auto")
         if rope_alpha == "auto":
-            self.config.scale_alpha_value = self.calculate_rope_alpha(base_seq_len)
+            self.config.scale_alpha_value = calculate_rope_alpha(
+                base_seq_len, target_seq_len
+            )
         else:
             self.config.scale_alpha_value = rope_alpha
 
+        # Set the max seq len if specified
+        if target_seq_len:
+            self.config.max_seq_len = target_seq_len
+
         # Set max batch size to the config override
         self.max_batch_size = unwrap(kwargs.get("max_batch_size"))
 
@@ -363,10 +363,11 @@ class ExllamaV2Container(BaseModelContainer):
         )
 
         # Set draft rope alpha. Follows same behavior as model rope alpha.
+        # Use the base sequence length of the model
         draft_rope_alpha = unwrap(draft_args.get("draft_rope_alpha"), "auto")
         if draft_rope_alpha == "auto":
-            self.draft_config.scale_alpha_value = self.calculate_rope_alpha(
-                self.draft_config.max_seq_len
+            self.draft_config.scale_alpha_value = calculate_rope_alpha(
+                base_seq_len, self.draft_config.max_seq_len
             )
         else:
             self.draft_config.scale_alpha_value = draft_rope_alpha
@@ -438,19 +439,6 @@ class ExllamaV2Container(BaseModelContainer):
             )
             continue
 
-    def calculate_rope_alpha(self, base_seq_len):
-        """Calculate the rope alpha value for a given sequence length."""
-
-        ratio = self.config.max_seq_len / base_seq_len
-
-        # Default to a 1 alpha if the sequence length is ever less
-        # than or equal to 1
-        if ratio <= 1.0:
-            alpha = 1
-        else:
-            alpha = -0.13436 + 0.80541 * ratio + 0.28833 * ratio**2
-        return alpha
-
     def get_model_parameters(self):
         model_params = {
             "name": self.model_dir.name,
diff --git a/common/model.py b/common/model.py
index 2908fc3..a5a6272 100644
--- a/common/model.py
+++ b/common/model.py
@@ -112,8 +112,6 @@ async def load_model_gen(model_path: pathlib.Path, **kwargs):
     kwargs = {**config.model_defaults, **kwargs}
     kwargs = await apply_inline_overrides(model_path, **kwargs)
 
-    print(kwargs)
-
     # Create a new container
     new_container = await ExllamaV2Container.create(
         model_path.resolve(), False, **kwargs
diff --git a/common/utils.py b/common/utils.py
index 52077dd..c56966d 100644
--- a/common/utils.py
+++ b/common/utils.py
@@ -87,3 +87,23 @@ def unwrap_optional_type(type_hint) -> Type:
 
     return type_hint
+
+
+def calculate_rope_alpha(base_seq_len: int, target_seq_len: int):
+    """
+    Converts a given max sequence length to a rope alpha value.
+
+    Args:
+        base_seq_len: The model's configured sequence length.
+        target_seq_len: The user-specified max sequence length.
+    """
+
+    # Get the ratio of the target sequence length to the base
+    ratio = target_seq_len / base_seq_len
+
+    # Default to an alpha of 1 if the ratio is ever less
+    # than or equal to 1
+    if ratio <= 1.0:
+        alpha = 1
+    else:
+        alpha = -0.13436 + 0.80541 * ratio + 0.28833 * ratio**2
+    return alpha
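For reviewers, a quick sanity check of the relocated helper. This is an illustrative snippet, not part of the diff; it assumes `common.utils` is importable from the repo root, and the sample sequence lengths (4096/8192) are arbitrary.

```python
# Illustrative check of the relocated calculate_rope_alpha helper.
# Assumes common.utils is importable from the repo root.
from common.utils import calculate_rope_alpha

# No context extension (ratio <= 1) clamps alpha to 1.
assert calculate_rope_alpha(4096, 4096) == 1

# Doubling the context (ratio = 2) gives
# -0.13436 + 0.80541 * 2 + 0.28833 * 2**2 ≈ 2.63.
print(calculate_rope_alpha(4096, 8192))  # ~2.6298
```

Note that with the ratio computed as `target_seq_len / base_seq_len`, alpha grows as the requested context exceeds the model's base length, matching the behavior of the removed `ExllamaV2Container.calculate_rope_alpha` method.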