Model: Auto-round cache size on init
Cache size must be a multiple of 256 to work properly in ExllamaV2. Take the config value and set the cache size to one multiple above the remainder of the cache size divided by 256. This is because cache size can never be lower than max_seq_len. If max_seq_len isn't a multiple of 256, this method will never yield a number that's lower than max_seq_len since it's no longer a source of truth. Signed-off-by: kingbri <bdashore3@proton.me>
This commit is contained in:
parent
3dcae8b023
commit
116cf56c87
1 changed files with 25 additions and 6 deletions
|
|
@ -193,14 +193,33 @@ class ExllamaV2Container:
|
|||
)
|
||||
|
||||
# Set k/v cache size
|
||||
self.cache_size = unwrap(kwargs.get("cache_size"), self.config.max_seq_len)
|
||||
if self.cache_size < self.config.max_seq_len:
|
||||
cache_size = unwrap(kwargs.get("cache_size"), self.config.max_seq_len)
|
||||
|
||||
if cache_size < self.config.max_seq_len:
|
||||
logger.warning(
|
||||
"Your specified cache_size is smaller than your "
|
||||
"desired context length. \n"
|
||||
"Defaulting cache_size to max_seq_len."
|
||||
f"The given cache_size ({cache_size}) is smaller than the "
|
||||
"desired context length.\n"
|
||||
"Overriding cache_size to max_seq_len. "
|
||||
)
|
||||
self.cache_size = self.config.max_seq_len
|
||||
|
||||
cache_size = self.config.max_seq_len
|
||||
|
||||
# Enforce a multiple of 256 for cache size
|
||||
# Overestimate to ensure that the cache isn't below max_seq_len
|
||||
cache_remainder = cache_size % 256
|
||||
if cache_remainder != 0:
|
||||
rounded_cache_size = int(256 * ((cache_size - cache_remainder) / 256 + 1))
|
||||
|
||||
logger.warning(
|
||||
f"The given cache size ({cache_size}) is "
|
||||
"not a multiple of 256.\n"
|
||||
"Overriding cache_size with an overestimated value of "
|
||||
f"{rounded_cache_size} tokens."
|
||||
)
|
||||
|
||||
cache_size = rounded_cache_size
|
||||
|
||||
self.cache_size = cache_size
|
||||
|
||||
# Enable fasttensors loading if present
|
||||
self.config.fasttensors = unwrap(kwargs.get("fasttensors"), False)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue