Model: Add override base seq len
Some models (such as Mistral and Mixtral) set their base sequence length to 32k due to assumptions of support for sliding window attention. Therefore, add this parameter to override the base sequence length of a model, which helps with auto-calculation of rope alpha. If auto-calculation of rope alpha isn't being used, the max_seq_len parameter works fine as is. Signed-off-by: kingbri <bdashore3@proton.me>
This commit is contained in:
parent
5368ed7b64
commit
ab10b263fd
3 changed files with 19 additions and 7 deletions
|
|
@ -34,8 +34,9 @@ class DraftModelLoadRequest(BaseModel):
|
|||
class ModelLoadRequest(BaseModel):
|
||||
name: str
|
||||
|
||||
# Max seq len is defaulted when loading the model itself
|
||||
max_seq_len: Optional[int] = None
|
||||
# Max seq len is fetched from config.json of the model by default
|
||||
max_seq_len: Optional[int] = Field(description = "Leave this blank to use the model's base sequence length", default = None)
|
||||
override_base_seq_len: Optional[int] = Field(description = "Overrides the model's base sequence length. Leave blank if unsure", default = None)
|
||||
gpu_split_auto: Optional[bool] = True
|
||||
gpu_split: Optional[List[float]] = Field(default_factory=list)
|
||||
rope_scale: Optional[float] = 1.0
|
||||
|
|
|
|||
|
|
@@ -37,9 +37,15 @@ model:
|
|||
|
||||
# The below parameters apply only if model_name is set
|
||||
|
||||
# Override maximum model context length (default: None)
|
||||
# Max sequence length (default: None)
|
||||
# Fetched from the model's base sequence length in config.json by default
|
||||
max_seq_len:
|
||||
|
||||
# Overrides base model context length (default: None)
|
||||
# WARNING: Don't set this unless you know what you're doing!
|
||||
# Only use this if the model's base sequence length in config.json is incorrect (ex. Mistral/Mixtral models)
|
||||
override_base_seq_len:
|
||||
|
||||
# Automatically allocate resources to GPUs (default: True)
|
||||
gpu_split_auto: True
|
||||
|
||||
|
|
|
|||
13
model.py
13
model.py
|
|
@@ -85,14 +85,19 @@ class ModelContainer:
|
|||
self.config.max_seq_len = 4096
|
||||
self.config.prepare()
|
||||
|
||||
# Then override the max_seq_len if present
|
||||
override_max_seq_len = kwargs.get("max_seq_len")
|
||||
if override_max_seq_len:
|
||||
self.config.max_seq_len = kwargs.get("max_seq_len")
|
||||
# Then override the base_seq_len if present
|
||||
override_base_seq_len = kwargs.get("override_base_seq_len")
|
||||
if override_base_seq_len:
|
||||
self.config.max_seq_len = override_base_seq_len
|
||||
|
||||
# Grab the base model's sequence length before overrides for rope calculations
|
||||
base_seq_len = self.config.max_seq_len
|
||||
|
||||
# Set the target seq len if present
|
||||
target_max_seq_len = kwargs.get("max_seq_len")
|
||||
if target_max_seq_len:
|
||||
self.config.max_seq_len = target_max_seq_len
|
||||
|
||||
# Set the rope scale
|
||||
self.config.scale_pos_emb = unwrap(kwargs.get("rope_scale"), 1.0)
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue