Model: Move calculate_rope_alpha from backend

Makes more sense to use as a utility function. Also clarify how the
max_seq_len and rope alpha vars are set.

Signed-off-by: kingbri <8082010+kingbri1@users.noreply.github.com>
kingbri 2025-04-20 18:20:19 -04:00
parent 027ffce05d
commit 8e238fa8f6
3 changed files with 33 additions and 27 deletions

backends/exllamav2/model.py

@@ -1,6 +1,5 @@
"""The model container class for ExLlamaV2 models."""
import aiofiles
import asyncio
import gc
import math
@@ -29,11 +28,7 @@ from itertools import zip_longest
from loguru import logger
from typing import Dict, List, Optional
from ruamel.yaml import YAML
from backends.base_model_container import BaseModelContainer
from common.health import HealthManager
from backends.exllamav2.grammar import (
    ExLlamaV2Grammar,
    clear_grammar_func_cache,
@@ -51,6 +46,7 @@ from common.gen_logging import (
    log_prompt,
    log_response,
)
from common.health import HealthManager
from common.multimodal import MultimodalEmbeddingWrapper
from common.sampling import BaseSamplerRequest
from common.templating import (
@@ -59,7 +55,7 @@ from common.templating import (
    find_template_from_model,
)
from common.transformers_utils import GenerationConfig
from common.utils import coalesce, unwrap
from common.utils import calculate_rope_alpha, coalesce, unwrap


class ExllamaV2Container(BaseModelContainer):
@@ -244,9 +240,7 @@ class ExllamaV2Container(BaseModelContainer):
        base_seq_len = self.config.max_seq_len

        # Set the target seq len if present
        target_max_seq_len = kwargs.get("max_seq_len")
        if target_max_seq_len:
            self.config.max_seq_len = target_max_seq_len
        target_seq_len = kwargs.get("max_seq_len")

        # Set the rope scale
        self.config.scale_pos_emb = unwrap(
@@ -257,10 +251,16 @@
        # Automatically calculate if unset or defined as an "auto" literal.
        rope_alpha = unwrap(kwargs.get("rope_alpha"), "auto")
        if rope_alpha == "auto":
            self.config.scale_alpha_value = self.calculate_rope_alpha(base_seq_len)
            self.config.scale_alpha_value = calculate_rope_alpha(
                base_seq_len, target_seq_len
            )
        else:
            self.config.scale_alpha_value = rope_alpha

        # Set the max seq len if specified
        if target_seq_len:
            self.config.max_seq_len = target_seq_len

        # Set max batch size to the config override
        self.max_batch_size = unwrap(kwargs.get("max_batch_size"))
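
The ordering matters here: rope alpha has to be derived from the model's
configured length before max_seq_len is overwritten with the user's target.
A minimal sketch of the resulting flow, using a hypothetical standalone
function for illustration only (the real logic lives inside the container):

    from common.utils import calculate_rope_alpha, unwrap

    def apply_context_settings(config, kwargs):
        # Capture the model's configured length before any override
        base_seq_len = config.max_seq_len
        target_seq_len = kwargs.get("max_seq_len")

        # Derive alpha from the base length while it is still intact
        rope_alpha = unwrap(kwargs.get("rope_alpha"), "auto")
        if rope_alpha == "auto":
            config.scale_alpha_value = calculate_rope_alpha(
                base_seq_len, target_seq_len
            )
        else:
            config.scale_alpha_value = rope_alpha

        # Only now is the user's target allowed to replace max_seq_len
        if target_seq_len:
            config.max_seq_len = target_seq_len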
@@ -363,10 +363,11 @@
        )

        # Set draft rope alpha. Follows same behavior as model rope alpha.
        # Use the base sequence length of the model
        draft_rope_alpha = unwrap(draft_args.get("draft_rope_alpha"), "auto")
        if draft_rope_alpha == "auto":
            self.draft_config.scale_alpha_value = self.calculate_rope_alpha(
                self.draft_config.max_seq_len
            self.draft_config.scale_alpha_value = calculate_rope_alpha(
                base_seq_len, self.draft_config.max_seq_len
            )
        else:
            self.draft_config.scale_alpha_value = draft_rope_alpha
@@ -438,19 +439,6 @@
                )
                continue

    def calculate_rope_alpha(self, base_seq_len):
        """Calculate the rope alpha value for a given sequence length."""
        ratio = self.config.max_seq_len / base_seq_len

        # Default to a 1 alpha if the sequence length is ever less
        # than or equal to 1
        if ratio <= 1.0:
            alpha = 1
        else:
            alpha = -0.13436 + 0.80541 * ratio + 0.28833 * ratio**2

        return alpha

    def get_model_parameters(self):
        model_params = {
            "name": self.model_dir.name,

common/model.py

@@ -112,8 +112,6 @@ async def load_model_gen(model_path: pathlib.Path, **kwargs):
    kwargs = {**config.model_defaults, **kwargs}
    kwargs = await apply_inline_overrides(model_path, **kwargs)
    print(kwargs)

    # Create a new container
    new_container = await ExllamaV2Container.create(
        model_path.resolve(), False, **kwargs

common/utils.py

@@ -87,3 +87,23 @@ def unwrap_optional_type(type_hint) -> Type:
    return type_hint


def calculate_rope_alpha(base_seq_len: int, target_seq_len: int):
    """
    Converts a given max sequence length to a rope alpha value.

    Args:
        base_seq_len: The model's configured sequence length.
        target_seq_len: The user-specified max sequence length.
    """

    # Get the ratio of the target sequence length to the model's base,
    # falling back to the base if no target was specified
    ratio = unwrap(target_seq_len, base_seq_len) / base_seq_len

    # Default to a 1 alpha if the ratio is ever less
    # than or equal to 1
    if ratio <= 1.0:
        alpha = 1
    else:
        alpha = -0.13436 + 0.80541 * ratio + 0.28833 * ratio**2

    return alpha
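
As a quick sanity check of the quadratic fit (a standalone sketch with
made-up lengths, not part of the commit): a ratio of 1 keeps alpha at 1,
doubling the context gives roughly 2.63, and quadrupling gives roughly 7.70.

    # Standalone check of calculate_rope_alpha; the example lengths are
    # made up, and the import assumes the utility lives in common.utils
    # as shown in the diff above.
    from common.utils import calculate_rope_alpha

    print(calculate_rope_alpha(4096, None))    # 1     (no target, falls back to base)
    print(calculate_rope_alpha(4096, 4096))    # 1     (ratio 1, no scaling)
    print(calculate_rope_alpha(4096, 8192))    # ~2.63 (ratio 2)
    print(calculate_rope_alpha(4096, 16384))   # ~7.70 (ratio 4)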