diff --git a/backends/exllamav2/model.py b/backends/exllamav2/model.py
index 7ffc9ce..64f571f 100644
--- a/backends/exllamav2/model.py
+++ b/backends/exllamav2/model.py
@@ -193,14 +193,33 @@ class ExllamaV2Container:
         )

         # Set k/v cache size
-        self.cache_size = unwrap(kwargs.get("cache_size"), self.config.max_seq_len)
-        if self.cache_size < self.config.max_seq_len:
+        cache_size = unwrap(kwargs.get("cache_size"), self.config.max_seq_len)
+
+        if cache_size < self.config.max_seq_len:
             logger.warning(
-                "Your specified cache_size is smaller than your "
-                "desired context length. \n"
-                "Defaulting cache_size to max_seq_len."
+                f"The given cache_size ({cache_size}) is smaller than the "
+                "desired context length.\n"
+                "Overriding cache_size to max_seq_len."
             )
-            self.cache_size = self.config.max_seq_len
+
+            cache_size = self.config.max_seq_len
+
+        # Enforce a multiple of 256 for cache size
+        # Overestimate to ensure that the cache isn't below max_seq_len
+        cache_remainder = cache_size % 256
+        if cache_remainder != 0:
+            rounded_cache_size = int(256 * ((cache_size - cache_remainder) / 256 + 1))
+
+            logger.warning(
+                f"The given cache_size ({cache_size}) is "
+                "not a multiple of 256.\n"
+                "Overriding cache_size with an overestimated value of "
+                f"{rounded_cache_size} tokens."
+            )
+
+            cache_size = rounded_cache_size
+
+        self.cache_size = cache_size

         # Enable fasttensors loading if present
         self.config.fasttensors = unwrap(kwargs.get("fasttensors"), False)
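
For reference, the overestimation step in this patch is equivalent to rounding up to the nearest multiple of 256 (a ceiling division). A minimal standalone sketch of that arithmetic follows; the helper name `round_up_to_multiple` is hypothetical and not part of the patch:

```python
def round_up_to_multiple(value: int, multiple: int = 256) -> int:
    """Round value up to the nearest multiple, mirroring the
    patch's (value - remainder) / multiple + 1 overestimation
    when value is not already an exact multiple."""
    return -(-value // multiple) * multiple


# An exact multiple is left unchanged; anything else is bumped up,
# so the cache never ends up smaller than max_seq_len.
assert round_up_to_multiple(4096) == 4096
assert round_up_to_multiple(4097) == 4352
```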