Model: Add support for num_experts_per_token

New parameter that is safe to set when running exllamav2 v0.0.11. Only recommended
for users who know what they're doing.

Signed-off-by: kingbri <bdashore3@proton.me>
kingbri 2023-12-17 18:03:01 -05:00
parent 70fbee3edd
commit ad8807a830
3 changed files with 16 additions and 1 deletion

@@ -7,8 +7,9 @@ class ModelCardParameters(BaseModel):
     max_seq_len: Optional[int] = 4096
     rope_scale: Optional[float] = 1.0
     rope_alpha: Optional[float] = 1.0
-    prompt_template: Optional[str] = None
     cache_mode: Optional[str] = "FP16"
+    prompt_template: Optional[str] = None
+    num_experts_per_token: Optional[int] = None
     draft: Optional['ModelCard'] = None
 
 class ModelCard(BaseModel):
@@ -40,6 +41,7 @@ class ModelLoadRequest(BaseModel):
     # low_mem: Optional[bool] = False
     cache_mode: Optional[str] = "FP16"
     prompt_template: Optional[str] = None
+    num_experts_per_token: Optional[int] = None
     draft: Optional[DraftModelLoadRequest] = None
 
 class ModelLoadResponse(BaseModel):
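
For illustration, a minimal sketch of a load request body that uses the new field. Only num_experts_per_token comes from this commit; the model name, the other values, and the omission of the remaining request fields are placeholder assumptions.

import json

# Example body for a model load request; "name" and "cache_mode" values are
# placeholders, and any other required fields are omitted for brevity.
load_request = {
    "name": "Mixtral-8x7B-exl2",      # assumed: an MoE model folder
    "cache_mode": "FP16",
    "num_experts_per_token": 2,       # new optional override from this commit
}

print(json.dumps(load_request, indent=2))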

@@ -60,6 +60,11 @@ model:
   # NOTE: Only works with chat completion message lists!
   prompt_template:
 
+  # Number of experts to use per token. Loads from the model's config.json if not specified (default: None)
+  # WARNING: Don't set this unless you know what you're doing!
+  # NOTE: For MoE models (ex. Mixtral) only!
+  num_experts_per_token:
+
   # Options for draft models (speculative decoding). This will use more VRAM!
   draft:
     # Overrides the directory to look for draft (default: models)
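
The sample config leaves the new key blank on purpose. A small sketch of how a blank YAML value behaves once the config is parsed; the inline snippet and loader code are illustrative assumptions and require PyYAML.

import yaml  # requires PyYAML

# Inline stand-in for the sample config above; in practice this would be
# loaded from the server's config.yml.
sample = """
model:
  prompt_template:
  num_experts_per_token:
"""

model_cfg = (yaml.safe_load(sample) or {}).get("model") or {}

# A blank value under "num_experts_per_token:" parses as None, which means
# "fall back to the value in the model's own config.json".
num_experts = model_cfg.get("num_experts_per_token")
if num_experts is None:
    print("num_experts_per_token unset; using the model's config.json value")
else:
    print(f"Overriding num_experts_per_token to {num_experts}")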

@@ -105,6 +105,14 @@ class ModelContainer:
         # Set prompt template override if provided
         self.prompt_template = kwargs.get("prompt_template")
 
+        # Set num of experts per token if provided
+        num_experts_override = kwargs.get("num_experts_per_token")
+        if num_experts_override:
+            if hasattr(self.config, "num_experts_per_token"):
+                self.config.num_experts_per_token = num_experts_override
+            else:
+                print(" !! Warning: Currently installed ExLlamaV2 does not support overriding MoE experts")
+
         chunk_size = min(unwrap(kwargs.get("chunk_size"), 2048), self.config.max_seq_len)
         self.config.max_input_len = chunk_size
         self.config.max_attn_size = chunk_size ** 2
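
The hasattr() probe above is what keeps the override harmless on older ExLlamaV2 builds. Below is a standalone sketch of the same guard applied to ExLlamaV2Config directly, assuming an exllamav2 install of roughly the v0.0.11 era; the model path and expert count are placeholders, not values from this commit.

from exllamav2 import ExLlamaV2Config

config = ExLlamaV2Config()
config.model_dir = "/path/to/Mixtral-8x7B-exl2"  # placeholder model directory
config.prepare()  # populates defaults from the model's config.json

# exllamav2 v0.0.11 exposes num_experts_per_token on its config object; older
# versions lack the attribute, so probe for it before overriding.
if hasattr(config, "num_experts_per_token"):
    config.num_experts_per_token = 2  # placeholder override
else:
    print(" !! Warning: Currently installed ExLlamaV2 does not support overriding MoE experts")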