Model: Add support for Q4 cache

Add this in addition to 8bit cache and 16bit cache. Passing "Q4" with
the cache_mode request parameter will set this on model load.

Signed-off-by: kingbri <bdashore3@proton.me>
This commit is contained in:
kingbri 2024-03-06 00:57:41 -05:00
parent 0b25c208d6
commit 9a007c4707
2 changed files with 25 additions and 4 deletions

View file

@ -27,6 +27,14 @@ from common.templating import (
from common.utils import coalesce, unwrap
from common.logger import init_logger
# Optional imports for dependencies
try:
from exllamav2 import ExLlamaV2Cache_Q4
_exllamav2_has_int4 = True
except ImportError:
_exllamav2_has_int4 = False
logger = init_logger(__name__)
@ -46,7 +54,7 @@ class ExllamaV2Container:
active_loras: List[ExLlamaV2Lora] = []
# Internal config vars
cache_fp8: bool = False
cache_mode: str = "FP16"
use_cfg: bool = False
# GPU split vars
@ -109,7 +117,15 @@ class ExllamaV2Container:
self.quiet = quiet
self.cache_fp8 = "cache_mode" in kwargs and kwargs["cache_mode"] == "FP8"
cache_mode = unwrap(kwargs.get("cache_mode"), "FP16")
if cache_mode == "Q4" and not _exllamav2_has_int4:
logger.warning(
"Q4 cache is not available "
"in the currently installed ExllamaV2 version. Using FP16."
)
cache_mode = "FP16"
self.cache_mode = cache_mode
# Turn off GPU split if the user is using 1 GPU
gpu_count = torch.cuda.device_count()
@ -398,7 +414,12 @@ class ExllamaV2Container:
yield value
batch_size = 2 if self.use_cfg else 1
if self.cache_fp8:
if self.cache_mode == "Q4" and _exllamav2_has_int4:
self.cache = ExLlamaV2Cache_Q4(
self.model, lazy=self.gpu_split_auto, batch_size=batch_size
)
elif self.cache_mode == "FP8":
self.cache = ExLlamaV2Cache_8bit(
self.model, lazy=self.gpu_split_auto, batch_size=batch_size
)

View file

@ -149,7 +149,7 @@ async def get_current_model():
rope_scale=MODEL_CONTAINER.config.scale_pos_emb,
rope_alpha=MODEL_CONTAINER.config.scale_alpha_value,
max_seq_len=MODEL_CONTAINER.config.max_seq_len,
cache_mode="FP8" if MODEL_CONTAINER.cache_fp8 else "FP16",
cache_mode=MODEL_CONTAINER.cache_mode,
prompt_template=prompt_template.name if prompt_template else None,
num_experts_per_token=MODEL_CONTAINER.config.num_experts_per_token,
use_cfg=MODEL_CONTAINER.use_cfg,