Model: Fix max seq len handling

Previously, the max sequence length was always overridden by the
user's config and never took the model's config.json into account.

Now, default to 4096 but include config.prepare() when selecting
the max sequence length. The YAML config and API request values
now serve as overrides rather than primary parameters.

Signed-off-by: kingbri <bdashore3@proton.me>
commit ce2602df9a (parent d3246747c0)
Author: kingbri, 2023-12-19 23:37:52 -05:00
3 changed files with 17 additions and 6 deletions
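
In effect, the selection order is now: start from a 4096 default, let
config.prepare() pull the value from the model's config.json, then apply
the user's YAML/API value only if one was given. A rough sketch of that
precedence (select_max_seq_len is illustrative, not a function in this repo):

    def select_max_seq_len(config, user_max_seq_len=None):
        # Fallback default (better than the library's own default)
        config.max_seq_len = 4096
        # prepare() reads the model's config.json and can set max_seq_len from it
        config.prepare()
        # A value from the YAML config or API request wins if present
        if user_max_seq_len:
            config.max_seq_len = user_max_seq_len
        return config.max_seq_len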


@@ -4,7 +4,8 @@ from typing import List, Optional
 from gen_logging import LogConfig

 class ModelCardParameters(BaseModel):
-    max_seq_len: Optional[int] = 4096
+    # Safe to do this since it's guaranteed to fetch a max seq len from model_container
+    max_seq_len: Optional[int] = None
     rope_scale: Optional[float] = 1.0
     rope_alpha: Optional[float] = 1.0
     cache_mode: Optional[str] = "FP16"
@@ -32,7 +33,9 @@ class DraftModelLoadRequest(BaseModel):
 # TODO: Unify this with ModelCardParams
 class ModelLoadRequest(BaseModel):
     name: str
-    max_seq_len: Optional[int] = 4096
+    # Max seq len is defaulted when loading the model itself
+    max_seq_len: Optional[int] = None
     gpu_split_auto: Optional[bool] = True
     gpu_split: Optional[List[float]] = Field(default_factory=list)
     rope_scale: Optional[float] = 1.0
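
With max_seq_len now defaulting to None, omitting it in a load request means
"defer to the model", while setting it still acts as an override. A small,
self-contained illustration using a trimmed copy of the request model above
(the model name and values here are made up):

    from typing import Optional
    from pydantic import BaseModel

    class ModelLoadRequest(BaseModel):
        name: str
        # None means "let the loader pick": model config.json, else the 4096 fallback
        max_seq_len: Optional[int] = None

    print(ModelLoadRequest(name="my-model").max_seq_len)                    # None -> loader decides
    print(ModelLoadRequest(name="my-model", max_seq_len=8192).max_seq_len)  # 8192 -> explicit override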


@@ -37,8 +37,8 @@ model:
   # The below parameters apply only if model_name is set

-  # Maximum model context length (default: 4096)
-  max_seq_len: 4096
+  # Override maximum model context length (default: None)
+  max_seq_len:

   # Automatically allocate resources to GPUs (default: True)
   gpu_split_auto: True


@@ -79,13 +79,21 @@ class ModelContainer:
         self.config = ExLlamaV2Config()
         self.config.model_dir = str(model_directory.resolve())
+        # Make the max seq len 4096 before preparing the config
+        # This is a better default than 2048
+        self.config.max_seq_len = 4096
         self.config.prepare()

+        # Then override the max_seq_len if present
+        override_max_seq_len = kwargs.get("max_seq_len")
+        if override_max_seq_len:
+            self.config.max_seq_len = kwargs.get("max_seq_len")
+
         # Grab the base model's sequence length before overrides for rope calculations
         base_seq_len = self.config.max_seq_len

-        # Then override the max_seq_len if present
-        self.config.max_seq_len = unwrap(kwargs.get("max_seq_len"), 4096)

         # Set the rope scale
         self.config.scale_pos_emb = unwrap(kwargs.get("rope_scale"), 1.0)

         # Automatically calculate rope alpha