Model: Add override base seq len

Some models (such as Mistral and Mixtral) set their base sequence
length to 32k in config.json because they assume support for sliding
window attention.

Therefore, add a parameter that overrides a model's base sequence
length, which helps with auto-calculation of rope alpha.

If auto-calculation of rope alpha isn't being used, the max_seq_len
parameter works fine as is.

Signed-off-by: kingbri <bdashore3@proton.me>
Author: kingbri
Date: 2023-12-20 00:43:19 -05:00
parent 5368ed7b64
commit ab10b263fd
3 changed files with 19 additions and 7 deletions

@@ -34,8 +34,9 @@ class DraftModelLoadRequest(BaseModel):
 class ModelLoadRequest(BaseModel):
     name: str
-    # Max seq len is defaulted when loading the model itself
-    max_seq_len: Optional[int] = None
+    # Max seq len is fetched from config.json of the model by default
+    max_seq_len: Optional[int] = Field(description="Leave this blank to use the model's base sequence length", default=None)
+    override_base_seq_len: Optional[int] = Field(description="Overrides the model's base sequence length. Leave blank if unsure", default=None)
     gpu_split_auto: Optional[bool] = True
     gpu_split: Optional[List[float]] = Field(default_factory=list)
     rope_scale: Optional[float] = 1.0
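As a usage sketch, the new field rides alongside max_seq_len in the
load request. The endpoint path, port, and values below are assumptions
for illustration, not taken from this commit:

    import requests

    payload = {
        "name": "my-mistral-model",     # hypothetical model name
        "max_seq_len": 16384,           # target context to run at
        "override_base_seq_len": 8192,  # corrected base for rope alpha
    }
    requests.post("http://127.0.0.1:5000/v1/model/load", json=payload)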

@@ -37,9 +37,15 @@ model:
   # The below parameters apply only if model_name is set
-  # Override maximum model context length (default: None)
+  # Max sequence length (default: None)
+  # Fetched from the model's base sequence length in config.json by default
   max_seq_len:
+
+  # Overrides base model context length (default: None)
+  # WARNING: Don't set this unless you know what you're doing!
+  # Only use this if the model's base sequence length in config.json is incorrect (ex. Mistral/Mixtral models)
+  override_base_seq_len:
   # Automatically allocate resources to GPUs (default: True)
   gpu_split_auto: True
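A filled-in sketch of the two settings together (the numbers are
illustrative assumptions, not recommendations):

    # Run at 16k context, but tell rope alpha calculation the true base is 8k
    max_seq_len: 16384
    override_base_seq_len: 8192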

@@ -85,14 +85,19 @@ class ModelContainer:
             self.config.max_seq_len = 4096
         self.config.prepare()
-        # Then override the max_seq_len if present
-        override_max_seq_len = kwargs.get("max_seq_len")
-        if override_max_seq_len:
-            self.config.max_seq_len = kwargs.get("max_seq_len")
+        # Then override the base_seq_len if present
+        override_base_seq_len = kwargs.get("override_base_seq_len")
+        if override_base_seq_len:
+            self.config.max_seq_len = override_base_seq_len
+        # Grab the base model's sequence length before overrides for rope calculations
+        base_seq_len = self.config.max_seq_len
+        # Set the target seq len if present
+        target_max_seq_len = kwargs.get("max_seq_len")
+        if target_max_seq_len:
+            self.config.max_seq_len = target_max_seq_len
         # Set the rope scale
         self.config.scale_pos_emb = unwrap(kwargs.get("rope_scale"), 1.0)
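For context on why base_seq_len is captured before the max_seq_len
override: rope alpha auto-calculation typically derives its factor from
the ratio of the target length to the base length. A minimal sketch,
assuming the quadratic NTK-alpha fit circulated in exllama-based
projects (the coefficients are an assumption, not necessarily this
repo's exact formula):

    def calculate_rope_alpha(base_seq_len: int, target_seq_len: int) -> float:
        """Hypothetical rope alpha heuristic; coefficients assumed."""
        ratio = target_seq_len / base_seq_len
        if ratio <= 1.0:
            return 1.0  # context fits within the base length; no scaling
        # Quadratic fit seen in exllama-based projects (assumption)
        return -0.13436 + 0.80541 * ratio + 0.28833 * ratio ** 2

    # With base 8192 and target 16384: ratio = 2.0, alpha ~ 2.63
    # With the inflated base 32768: ratio = 0.5, alpha = 1.0 (no scaling)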