tabbyAPI-ollama/backends/exllamav2/model.py
kingbri 2913ce29fc API: Add timings to usage stats
It's useful for the client to know what the T/s and total time for
generation are per-request.

Works with both completions and chat completions.

Signed-off-by: kingbri <8082010+kingbri1@users.noreply.github.com>
2025-06-17 22:54:51 -04:00

1480 lines
52 KiB
Python

"""The model container class for ExLlamaV2 models."""
import asyncio
import gc
import math
import pathlib
import torch
from exllamav2 import (
ExLlamaV2,
ExLlamaV2Config,
ExLlamaV2CacheBase,
ExLlamaV2Cache,
ExLlamaV2Cache_Q4,
ExLlamaV2Cache_Q6,
ExLlamaV2Cache_Q8,
ExLlamaV2Cache_TP,
ExLlamaV2Tokenizer,
ExLlamaV2Lora,
ExLlamaV2VisionTower,
)
from exllamav2.generator import (
ExLlamaV2Sampler,
ExLlamaV2DynamicGeneratorAsync,
ExLlamaV2DynamicJobAsync,
)
from itertools import zip_longest
from loguru import logger
from typing import Dict, List, Optional
from backends.base_model_container import BaseModelContainer
from backends.exllamav2.grammar import (
ExLlamaV2Grammar,
clear_grammar_func_cache,
)
from backends.exllamav2.utils import exllama_disabled_flash_attn
from backends.exllamav2.vision import clear_image_embedding_cache
from common.concurrency import iterate_in_threadpool
from common.gen_logging import (
log_generation_params,
log_metrics,
log_prompt,
log_response,
)
from common.hardware import hardware_supports_flash_attn
from common.health import HealthManager
from common.multimodal import MultimodalEmbeddingWrapper
from common.optional_dependencies import check_package_version
from common.sampling import BaseSamplerRequest
from common.templating import PromptTemplate, find_prompt_template
from common.transformers_utils import HFModel
from common.utils import calculate_rope_alpha, coalesce, unwrap
from endpoints.core.types.model import ModelCard, ModelCardParameters
class ExllamaV2Container(BaseModelContainer):
    """The model container class for ExLlamaV2 models."""

    # NOTE(review): every attribute below is declared at class level, so the
    # mutable defaults (lists, dict, Lock, Condition) are shared between
    # instances until an instance rebinds them - confirm only one container
    # is expected to exist at a time.

    # Model directories
    model_dir: pathlib.Path = pathlib.Path("models")
    draft_model_dir: pathlib.Path = pathlib.Path("models")
    prompt_template: Optional[PromptTemplate] = None

    # HF model instance
    hf_model: HFModel

    # Exl2 vars
    config: Optional[ExLlamaV2Config] = None
    model: Optional[ExLlamaV2] = None
    cache: Optional[ExLlamaV2Cache] = None
    tokenizer: Optional[ExLlamaV2Tokenizer] = None
    generator: Optional[ExLlamaV2DynamicGeneratorAsync] = None
    # NOTE(review): prompt_template is already declared above; this second
    # declaration is redundant.
    prompt_template: Optional[PromptTemplate] = None
    # Cleared by create() when flash attention is unavailable
    paged: bool = True

    # Draft model vars
    use_draft_model: bool = False
    draft_config: Optional[ExLlamaV2Config] = None
    draft_model: Optional[ExLlamaV2] = None
    draft_cache: Optional[ExLlamaV2Cache] = None

    # Internal config vars
    # cache_size stays None until create() computes it
    cache_size: Optional[int] = None
    cache_mode: str = "FP16"
    draft_cache_mode: str = "FP16"
    max_batch_size: Optional[int] = None

    # GPU split vars
    gpu_split: List[float] = []
    draft_gpu_split: List[float] = []
    gpu_split_auto: bool = True
    # Bytes reserved per GPU for the autosplit loader (default: 96 MiB on GPU 0)
    autosplit_reserve: List[float] = [96 * 1024**2]
    use_tp: bool = False

    # Vision vars
    use_vision: bool = False
    vision_model: Optional[ExLlamaV2VisionTower] = None

    # Load synchronization
    # Maps request IDs to their job; stream_generate registers a request with
    # a None value before the job exists
    active_job_ids: Dict[str, Optional[ExLlamaV2DynamicJobAsync]] = {}
    loaded: bool = False
    load_lock: asyncio.Lock = asyncio.Lock()
    load_condition: asyncio.Condition = asyncio.Condition()
@classmethod
async def create(cls, model_directory: pathlib.Path, hf_model: HFModel, **kwargs):
    """
    Primary asynchronous initializer for model container.

    Kwargs are located in config_sample.yml

    This only prepares configuration (model/draft configs, GPU split,
    rope, cache and chunk sizes, prompt template). Nothing is loaded onto
    a device here; that is done by load_gen.

    Args:
        model_directory: Directory containing the model to configure.
        hf_model: HF model wrapper; its hf_config.max_position_embeddings
            is used as the base length for rope alpha calculation.
        **kwargs: User configuration values (vision, draft_model,
            cache_mode, gpu_split, max_seq_len, cache_size, chunk_size, ...).

    Returns:
        The configured container instance.

    Raises:
        ValueError: If vision is requested but the model config reports no
            vision model type.
    """

    # Create a new instance as a "fake self"
    self = cls()

    # Make sure ExllamaV2 is up to date
    check_package_version("exllamav2", "0.3.1")

    # Initialize config
    self.config = ExLlamaV2Config()
    self.model_dir = model_directory
    self.config.model_dir = str(model_directory.resolve())
    self.hf_model = hf_model

    # Make the max seq len 4096 before preparing the config
    # This is a better default than 2048
    self.config.max_seq_len = 4096

    self.config.prepare()

    # Check if the model arch is compatible with various exl2 features
    self.config.arch_compat_overrides()

    # Set vision state and error if vision isn't supported on the current model
    self.use_vision = unwrap(kwargs.get("vision"), False)
    if self.use_vision and not self.config.vision_model_type:
        raise ValueError(
            "The provided model does not have vision capabilities that are "
            "supported by ExllamaV2. "
            "Please reload with vision disabled."
        )

    # Prepare the draft model config if necessary
    draft_args = unwrap(kwargs.get("draft_model"), {})
    draft_model_name = draft_args.get("draft_model_name")

    # Coerce to a real bool so the flag never holds the args dict or the name
    self.use_draft_model = bool(draft_args and draft_model_name)

    # Always disable draft if params are incorrectly configured
    if draft_args and draft_model_name is None:
        logger.warning(
            "Draft model is disabled because a model name "
            "wasn't provided. Please check your config.yml!"
        )
        self.use_draft_model = False

    if self.use_draft_model:
        self.draft_config = ExLlamaV2Config()
        draft_model_path = pathlib.Path(
            unwrap(draft_args.get("draft_model_dir"), "models")
        )
        draft_model_path = draft_model_path / draft_model_name

        self.draft_gpu_split = unwrap(draft_args.get("draft_gpu_split"), [])
        self.draft_model_dir = draft_model_path
        self.draft_config.model_dir = str(draft_model_path.resolve())
        self.draft_config.prepare()

    # MARK: User configuration

    # Get cache mode
    self.cache_mode = unwrap(kwargs.get("cache_mode"), "FP16")

    # Catch exllamav3 cache_mode
    if self.cache_mode != "FP16" and not self.cache_mode.startswith("Q"):
        logger.warning(
            f"Provided cache mode '{self.cache_mode}' is not a "
            "valid choice for exllamav2, please check your settings. "
            "Defaulting to FP16."
        )
        self.cache_mode = "FP16"

    # Turn off GPU split if the user is using 1 GPU
    gpu_count = torch.cuda.device_count()
    gpu_split_auto = unwrap(kwargs.get("gpu_split_auto"), True)
    use_tp = unwrap(kwargs.get("tensor_parallel"), False)
    gpu_split = unwrap(kwargs.get("gpu_split"), [])
    gpu_device_list = list(range(0, gpu_count))

    # Set GPU split options
    if gpu_count == 1:
        self.gpu_split_auto = False
        logger.info("Disabling GPU split because one GPU is in use.")
    else:
        # Set tensor parallel
        if use_tp:
            self.use_tp = True

            # TP has its own autosplit loader
            self.gpu_split_auto = False

        # Enable manual GPU split if provided
        if gpu_split:
            self.gpu_split_auto = False
            self.gpu_split = gpu_split

            gpu_device_list = [
                device_idx
                for device_idx, memory in enumerate(self.gpu_split)
                if memory > 0
            ]
        elif gpu_split_auto and not self.use_tp:
            # Otherwise fallback to autosplit settings
            self.gpu_split_auto = gpu_split_auto

            autosplit_reserve_megabytes = unwrap(
                kwargs.get("autosplit_reserve"), [96]
            )

            # Reserve VRAM for each GPU
            self.autosplit_reserve = [
                int(math.ceil(value * 1024**2))
                for value in autosplit_reserve_megabytes
            ]

    # Change the GPU device list only if gpu_split's list is too small
    # This allows for an uneven list specification
    if self.draft_gpu_split and len(self.draft_gpu_split) > len(self.gpu_split):
        gpu_device_list = [
            device_idx
            for device_idx, memory in enumerate(self.draft_gpu_split)
            if memory > 0
        ]

    # Hardcode max output length to 16
    self.config.max_output_len = 16

    # Grab the base model's sequence length before overrides for
    # rope calculations
    base_seq_len = hf_model.hf_config.max_position_embeddings

    # Set the target seq len if present
    target_seq_len = unwrap(kwargs.get("max_seq_len"), 4096)

    # Set the rope scale
    self.config.scale_pos_emb = unwrap(
        kwargs.get("rope_scale"), self.config.scale_pos_emb
    )

    # Sets rope alpha value.
    # Utilize the model's max_position_embeddings as a base value
    # Automatically calculate if unset or defined as an "auto" literal.
    rope_alpha = unwrap(kwargs.get("rope_alpha"), "auto")
    if rope_alpha == "auto":
        self.config.scale_alpha_value = calculate_rope_alpha(
            base_seq_len, target_seq_len
        )
    else:
        self.config.scale_alpha_value = rope_alpha

    # Set the max seq len if specified
    if target_seq_len:
        self.config.max_seq_len = target_seq_len

    # Set max batch size to the config override
    self.max_batch_size = unwrap(kwargs.get("max_batch_size"))

    # Check whether the user's configuration supports flash/paged attention
    # Also check if exl2 has disabled flash attention
    if exllama_disabled_flash_attn(
        self.config.no_flash_attn
    ) or not hardware_supports_flash_attn(gpu_device_list):
        gpu_unsupported_message = (
            "An unsupported GPU is found in this configuration. "
            "Switching to compatibility mode. \n"
            "This disables parallel batching "
            "and features that rely on it (ex. CFG). \n"
            "To disable compatability mode, all GPUs must be ampere "
            "(30 series) or newer. AMD GPUs are not supported."
        )

        logger.warning(gpu_unsupported_message)

        self.config.no_flash_attn = True
        if self.draft_config:
            self.draft_config.no_flash_attn = True
        self.paged = False
        self.max_batch_size = 1
        torch.backends.cuda.enable_flash_sdp(False)

    # Set k/v cache size
    # cache_size is only relevant when paged mode is enabled
    if self.paged:
        cache_size = unwrap(kwargs.get("cache_size"), self.config.max_seq_len)

        if cache_size < self.config.max_seq_len:
            logger.warning(
                f"The given cache_size ({cache_size}) is smaller than the "
                "desired context length.\n"
                "Overriding cache_size to max_seq_len. "
            )

            cache_size = self.config.max_seq_len

        # Enforce a multiple of 256 for cache size
        # Overestimate to ensure that the cache isn't below max_seq_len
        cache_remainder = cache_size % 256
        if cache_remainder != 0:
            rounded_cache_size = int(
                256 * ((cache_size - cache_remainder) / 256 + 1)
            )

            logger.warning(
                f"The given cache size ({cache_size}) is "
                "not a multiple of 256.\n"
                "Overriding cache_size with an overestimated value of "
                f"{rounded_cache_size} tokens."
            )

            cache_size = rounded_cache_size

        # Warn user if cache size may be inadequate for CFG
        if cache_size < 2 * self.config.max_seq_len:
            logger.warning(
                f"The given cache_size ({cache_size}) is less than 2 * max_seq_len "
                "and may be too small for requests using CFG. \n"
                "Ignore this warning if you do not plan on using CFG."
            )

        self.cache_size = cache_size
    else:
        self.cache_size = self.config.max_seq_len

    # Try to set prompt template
    self.prompt_template = await find_prompt_template(
        kwargs.get("prompt_template"), model_directory
    )

    # Catch all for template lookup errors
    if self.prompt_template:
        logger.info(
            f'Using template "{self.prompt_template.name}" for chat completions.'
        )
    else:
        logger.warning(
            "Chat completions are disabled because a prompt "
            "template wasn't provided or auto-detected."
        )

    # Make sure chunk size is >= 256, keep near or below max seq len
    # (the median of the three values clamps the user's choice)
    user_chunk_size = unwrap(kwargs.get("chunk_size"), 2048)
    chunk_size = sorted((256, user_chunk_size, self.config.max_seq_len))[1]
    chunk_remainder = chunk_size % 256
    if chunk_remainder != 0:
        rounded_chunk_size = int(256 * ((chunk_size - chunk_remainder) / 256 + 1))

        logger.warning(
            f"The given chunk size ({chunk_size}) is "
            "not a multiple of 256.\n"
            "Overriding chunk_size with an overestimated value of "
            f"{rounded_chunk_size} tokens."
        )

        chunk_size = rounded_chunk_size

    self.config.max_input_len = chunk_size
    self.config.max_attention_size = chunk_size**2

    # Set user-configured draft model values
    if self.use_draft_model:
        self.draft_config.max_seq_len = self.config.max_seq_len

        self.draft_config.scale_pos_emb = unwrap(
            draft_args.get("draft_rope_scale"), 1.0
        )

        # Set draft rope alpha. Follows same behavior as model rope alpha.
        # Use the max_position_embeddings of the model
        draft_rope_alpha = unwrap(draft_args.get("draft_rope_alpha"), "auto")
        if draft_rope_alpha == "auto":
            self.draft_config.scale_alpha_value = calculate_rope_alpha(
                base_seq_len, self.draft_config.max_seq_len
            )
        else:
            self.draft_config.scale_alpha_value = draft_rope_alpha

        # Set draft cache mode
        self.draft_cache_mode = unwrap(draft_args.get("draft_cache_mode"), "FP16")

        # Catch exllamav3 draft_cache_mode
        if self.draft_cache_mode != "FP16" and not self.draft_cache_mode.startswith(
            "Q"
        ):
            logger.warning(
                f"Provided draft cache mode '{self.draft_cache_mode}' is not a "
                "valid choice for exllamav2, please check your settings. "
                "Defaulting to FP16."
            )
            self.draft_cache_mode = "FP16"

        # Edit the draft config size
        if chunk_size:
            self.draft_config.max_input_len = chunk_size
            self.draft_config.max_attention_size = chunk_size**2

    # Keep the resolved context length available for debugging without
    # writing to stdout
    logger.debug(f"Resolved max_seq_len: {self.config.max_seq_len}")

    # Return the created instance
    return self
def model_info(self):
    """
    Build the model card describing the current configuration.

    Returns:
        A ModelCard for the main model. Its parameters carry a nested
        draft ModelCard when a draft config exists, and the prompt
        template name/content when a template is set.
    """

    # The draft card is only built when a draft config has been prepared
    draft_card: ModelCard = None
    if self.draft_config:
        draft_card = ModelCard(
            id=self.draft_model_dir.name,
            parameters=ModelCardParameters(
                max_seq_len=self.draft_config.max_seq_len,
                rope_scale=self.draft_config.scale_pos_emb,
                rope_alpha=self.draft_config.scale_alpha_value,
                cache_mode=self.draft_cache_mode,
            ),
        )

    main_params = ModelCardParameters(
        max_seq_len=self.config.max_seq_len,
        cache_size=self.cache_size,
        rope_scale=self.config.scale_pos_emb,
        rope_alpha=self.config.scale_alpha_value,
        max_batch_size=self.max_batch_size,
        cache_mode=self.cache_mode,
        chunk_size=self.config.max_input_len,
        use_vision=self.use_vision,
        draft=draft_card,
    )

    template = self.prompt_template
    if template:
        main_params.prompt_template = template.name
        main_params.prompt_template_content = template.raw_template

    return ModelCard(
        id=self.model_dir.name,
        parameters=main_params,
    )
async def wait_for_jobs(self, skip_wait: bool = False):
    """
    Polling mechanism to wait for pending generation jobs.

    Args:
        skip_wait: When truthy, cancel every registered job first instead
            of letting them run to completion.

    Returns immediately if no generator exists. Otherwise returns once
    active_job_ids is empty.
    """

    if not self.generator:
        return

    # Immediately abort all jobs if asked
    if skip_wait:
        logger.warning(
            "Immediately terminating all jobs. "
            "Clients will have their requests cancelled.\n"
        )

        # Iterate over a snapshot: cancel() yields to the event loop, and
        # requests that finish meanwhile remove their own entry from
        # active_job_ids, which would otherwise break iteration with
        # "dictionary changed size during iteration".
        for job in list(self.active_job_ids.values()):
            if job:
                await job.cancel()

    while len(self.active_job_ids) > 0:
        await asyncio.sleep(0.01)
async def load(self, progress_callback=None):
    """
    Load the model, discarding intermediate progress values.

    Drives load_gen to completion for callers that don't need streaming
    progress.

    Args:
        progress_callback (function, optional): A function to call for each
            module loaded.

            Prototype:
            def progress(loaded_modules: int, total_modules: int)
    """

    progress_stream = self.load_gen(progress_callback)
    async for _progress in progress_stream:
        continue
async def load_gen(self, progress_callback=None, **kwargs):
    """
    Loads a model and streams progress via a generator.

    Args:
        progress_callback (function, optional): Passed through to
            load_model_sync for per-module progress.
        **kwargs: "skip_wait" is forwarded to wait_for_jobs to cancel
            pending jobs instead of waiting for them.

    Yields:
        Progress values produced by load_model_sync.
    """

    # Indicate that model load has started
    # Acquire before entering the try block: asyncio locks have no owner,
    # so releasing in `finally` after a cancelled/failed acquire would
    # either raise or release a lock that another task is holding.
    await self.load_lock.acquire()

    try:
        # Wait for existing generation jobs to finish
        await self.wait_for_jobs(kwargs.get("skip_wait"))

        # Streaming gen for model load progress
        model_load_generator = self.load_model_sync(progress_callback)
        async for value in iterate_in_threadpool(model_load_generator):
            yield value

        # Create async generator
        await self.create_generator()

        # Clean up any extra vram usage from torch and cuda
        # (Helps reduce VRAM bottlenecking on Windows)
        gc.collect()
        torch.cuda.empty_cache()

        # Cleanup and update model load state
        self.loaded = True
        logger.info("Model successfully loaded.")
    finally:
        self.load_lock.release()

        # Wake up requests waiting on the load lock in stream_generate
        async with self.load_condition:
            self.load_condition.notify_all()
@torch.inference_mode()
def load_model_sync(self, progress_callback=None):
    """
    Synchronous generator for loading.

    Loads, in order: the tokenizer, the draft model (if a draft config
    exists), the vision tower (if vision is enabled) and the main model.
    Each model is followed by a forward pass of max_input_len tokens to
    test VRAM allocation.

    Args:
        progress_callback (function, optional): A function to call for each
            module loaded.

            Prototype:
            def progress(loaded_modules: int, total_modules: int)

    Yields:
        Truthy progress values emitted by the underlying loaders.

    Runs under a shared inference mode context.
    """

    # Reset tokenizer namespace vars and create a tokenizer
    ExLlamaV2Tokenizer.unspecial_piece_to_id = {}
    ExLlamaV2Tokenizer.unspecial_id_to_piece = {}
    ExLlamaV2Tokenizer.extended_id_to_piece = {}
    ExLlamaV2Tokenizer.extended_piece_to_id = {}

    self.tokenizer = ExLlamaV2Tokenizer(self.config)

    # Calculate autosplit reserve for all GPUs
    # (pads the configured list with zeros up to the GPU count)
    gpu_count = torch.cuda.device_count()
    autosplit_reserve = self.autosplit_reserve + [0] * (
        gpu_count - len(self.autosplit_reserve)
    )

    # Load draft model if a config is present
    if self.draft_config:
        self.draft_model = ExLlamaV2(self.draft_config)
        logger.info("Loading draft model: " + self.draft_config.model_dir)

        # Draft uses the autosplit loader, so create a cache that reflects this
        draft_cache_class = self.get_cache_class(self.draft_cache_mode)

        if self.draft_gpu_split:
            logger.info("Loading with a manual GPU split (or a one GPU setup)")

            for value in self.draft_model.load_gen(
                self.draft_gpu_split,
                callback_gen=progress_callback,
            ):
                if value:
                    yield value

            self.draft_cache = self.create_cache(
                cache_class=draft_cache_class,
                autosplit=False,
                use_tp=False,
                model=self.draft_model,
            )
        else:
            logger.info("Loading with autosplit")

            # The autosplit loader needs the (lazy) cache up front
            self.draft_cache = self.create_cache(
                cache_class=draft_cache_class,
                autosplit=True,
                use_tp=False,
                model=self.draft_model,
            )

            for value in self.draft_model.load_autosplit_gen(
                self.draft_cache,
                reserve_vram=autosplit_reserve,
                last_id_only=True,
                callback_gen=progress_callback,
            ):
                if value:
                    yield value

        # Test VRAM allocation with a full-length forward pass
        # Use the draft model's own cache: the main cache has not been
        # created yet at this point in the load.
        input_ids = torch.zeros((1, self.config.max_input_len), dtype=torch.long)
        self.draft_model.forward(
            input_ids, cache=self.draft_cache, preprocess_only=True
        )

    # Load vision tower if it exists
    if self.use_vision:
        self.vision_model = ExLlamaV2VisionTower(self.config)

        for value in self.vision_model.load_gen(callback_gen=progress_callback):
            if value:
                yield value

    self.model = ExLlamaV2(self.config)
    logger.info("Loading model: " + self.config.model_dir)

    # Get class of the model cache
    cache_class = self.get_cache_class(self.cache_mode)

    # Load model with manual split
    # Entrypoint for single GPU users
    if self.use_tp:
        logger.info("Loading with tensor parallel")

        # GPU split must be None if the array is empty
        # Otherwise the TP loader fails
        for value in self.model.load_tp_gen(
            self.gpu_split or None,
            callback_gen=progress_callback,
            expect_cache_base=cache_class,
            expect_cache_tokens=self.cache_size,
        ):
            if value:
                yield value
    elif not self.gpu_split_auto:
        logger.info("Loading with a manual GPU split (or a one GPU setup)")

        for value in self.model.load_gen(
            self.gpu_split,
            callback_gen=progress_callback,
        ):
            if value:
                yield value

    # Create the model cache
    self.cache = self.create_cache(
        cache_class=cache_class,
        autosplit=self.gpu_split_auto,
        use_tp=self.use_tp,
        model=self.model,
    )

    # Load model with autosplit (without TP)
    if self.gpu_split_auto and not self.use_tp:
        logger.info("Loading with autosplit")

        for value in self.model.load_autosplit_gen(
            self.cache,
            reserve_vram=autosplit_reserve,
            last_id_only=True,
            callback_gen=progress_callback,
        ):
            if value:
                yield value

    # Test VRAM allocation with a full-length forward pass
    input_ids = torch.zeros((1, self.config.max_input_len), dtype=torch.long)
    self.model.forward(input_ids, cache=self.cache, preprocess_only=True)
# TODO: Maybe make a wrapper class with an ID instead of a utility function
def get_cache_class(self, cache_mode: str):
    """
    Resolve the cache class for a cache mode string.

    "Q4", "Q6" and "Q8" map to their quantized cache classes; any other
    value falls back to the unquantized ExLlamaV2Cache.
    """

    if cache_mode == "Q4":
        return ExLlamaV2Cache_Q4
    if cache_mode == "Q6":
        return ExLlamaV2Cache_Q6
    if cache_mode == "Q8":
        return ExLlamaV2Cache_Q8

    return ExLlamaV2Cache
def create_cache(
    self,
    cache_class: ExLlamaV2CacheBase,
    autosplit: bool,
    use_tp: bool,
    model: ExLlamaV2,
):
    """
    Build a cache of self.cache_size tokens (batch size 1) for a model.

    Args:
        cache_class: Cache class to instantiate, or the base class handed
            to the tensor-parallel cache.
        autosplit: Passed as `lazy` to the non-TP cache.
        use_tp: Wrap cache_class in ExLlamaV2Cache_TP when true.
        model: Model the cache belongs to.

    Returns:
        The constructed cache instance.
    """

    if not use_tp:
        return cache_class(
            model,
            max_seq_len=self.cache_size,
            lazy=autosplit,
            batch_size=1,
        )

    return ExLlamaV2Cache_TP(
        model,
        base=cache_class,
        max_seq_len=self.cache_size,
        batch_size=1,
    )
async def create_generator(self):
    """
    Create and save a Exllama generator class.

    When a model is already loaded the generator is being recreated: the
    load lock is taken and all running jobs are cancelled first. During
    an initial load nothing is locked here.
    """

    # Track the acquisition explicitly so the lock is only released when
    # this call actually took it (a cancelled acquire must not release a
    # lock held by another task).
    lock_acquired = False

    try:
        # Don't acquire locks unless a model is loaded
        if self.loaded:
            await self.load_lock.acquire()
            lock_acquired = True

            # Immediately cancel all jobs
            await self.wait_for_jobs(skip_wait=True)

        # Create new generator
        self.generator = ExLlamaV2DynamicGeneratorAsync(
            model=self.model,
            cache=self.cache,
            draft_model=self.draft_model,
            draft_cache=self.draft_cache,
            tokenizer=self.tokenizer,
            max_batch_size=self.max_batch_size,
            paged=self.paged,
        )

        # Update the state of the container var
        if self.max_batch_size is None:
            self.max_batch_size = self.generator.generator.max_batch_size
    finally:
        # This means the generator is being recreated
        # The load lock is already released in the load function
        if lock_acquired:
            self.load_lock.release()

            async with self.load_condition:
                self.load_condition.notify_all()
def get_loras(self):
    """
    Convenience function to get all loras.

    Returns:
        The generator's current loras, or an empty list when no loras are
        set or no generator exists yet.
    """

    # No generator means nothing can be loaded; avoid an AttributeError
    if self.generator is None:
        return []

    return unwrap(self.generator.generator.current_loras, [])
async def load_loras(self, lora_directory: pathlib.Path, **kwargs):
    """
    Load loras.

    Args:
        lora_directory: Directory that contains one sub-directory per lora.
        **kwargs: "loras" is a list of dicts with a "name" and an optional
            "scaling" (defaults to 1.0). "skip_wait" is forwarded to
            wait_for_jobs.

    Returns:
        A dict with "success" and "failure" lists of lora names. Entries
        without a name are skipped and recorded in "failure" as None.
    """

    loras = unwrap(kwargs.get("loras"), [])

    # Acquire before entering the try block so the release in `finally`
    # only ever runs when this call holds the lock.
    await self.load_lock.acquire()

    try:
        # Wait for existing generation jobs to finish
        await self.wait_for_jobs(kwargs.get("skip_wait"))

        loras_to_load: List[ExLlamaV2Lora] = []
        success: List[str] = []
        failure: List[str] = []

        for lora in loras:
            lora_name = lora.get("name")
            lora_scaling = unwrap(lora.get("scaling"), 1.0)

            if lora_name is None:
                logger.warning(
                    "One of your loras does not have a name. Please check your "
                    "config.yml! Skipping lora load."
                )
                failure.append(lora_name)
                continue

            logger.info(f"Adding lora: {lora_name} at scaling {lora_scaling}")
            lora_path = lora_directory / lora_name

            loras_to_load.append(
                ExLlamaV2Lora.from_directory(self.model, lora_path, lora_scaling)
            )
            logger.info(f"Lora successfully added: {lora_name}")
            success.append(lora_name)

        # Replace the generator's lora set in one step
        self.generator.generator.set_loras(loras_to_load)
        logger.info("All loras successfully loaded")

        # Return success and failure names
        return {"success": success, "failure": failure}
    finally:
        self.load_lock.release()

        async with self.load_condition:
            self.load_condition.notify_all()
async def unload(self, loras_only: bool = False, **kwargs):
    """
    Free all VRAM resources used by the model (and loras).

    Args:
        loras_only: Only unload loras and keep the model loaded.
        **kwargs: "shutdown" skips the load lock and the wait for running
            jobs. "skip_wait" is forwarded to wait_for_jobs.
    """

    # Shutdown immediately unloads and bypasses all locks
    do_shutdown = kwargs.get("shutdown")

    # Only release the lock in `finally` if this call really acquired it
    lock_acquired = False

    try:
        if not do_shutdown:
            await self.load_lock.acquire()
            lock_acquired = True

            # Wait for other jobs to finish
            await self.wait_for_jobs(kwargs.get("skip_wait"))

        # Delete references held in the grammar module
        clear_grammar_func_cache()

        # Clear the image embedding cache
        clear_image_embedding_cache()

        # Unload LoRAs
        if self.generator and self.generator.generator.current_loras:
            for lora in self.generator.generator.current_loras:
                lora.unload()

            self.generator.generator.set_loras([])

        # Unload the entire model if not just unloading loras
        if not loras_only:
            if self.model:
                self.model.unload()
            self.model = None

            if self.vision_model:
                self.vision_model.unload()
            self.vision_model = None

            if self.draft_model:
                self.draft_model.unload()
            self.draft_model = None

            self.config = None
            self.cache = None
            # Drop the draft cache as well, otherwise it stays referenced
            # by the container after the draft model is unloaded
            self.draft_cache = None
            self.tokenizer = None

            # Cleanup the generator from any pending jobs
            if self.generator is not None:
                await self.generator.close()
                self.generator = None

            # Set all model state variables to False
            self.loaded = False

        gc.collect()
        torch.cuda.empty_cache()

        logger.info("Loras unloaded." if loras_only else "Model unloaded.")
    finally:
        if lock_acquired:
            self.load_lock.release()

            async with self.load_condition:
                self.load_condition.notify_all()
def encode_tokens(self, text: str, **kwargs):
    """
    Tokenize a text string into a flat list of token IDs.

    Recognized kwargs:
        embeddings: Multimodal embedding wrapper; its content is passed to
            the tokenizer (an empty list when absent).
        add_bos_token: Overrides the HF model's BOS default.
        encode_special_tokens: Defaults to True.
    """

    embedding_wrapper: MultimodalEmbeddingWrapper = kwargs.get("embeddings")
    if embedding_wrapper:
        embedding_content = embedding_wrapper.content
    else:
        embedding_content = []

    encoder = self.tokenizer.encode
    use_bos = unwrap(kwargs.get("add_bos_token"), self.hf_model.add_bos_token())
    use_special = unwrap(kwargs.get("encode_special_tokens"), True)

    encoded = encoder(
        text,
        add_bos=use_bos,
        encode_special_tokens=use_special,
        embeddings=embedding_content,
    )

    return encoded.flatten().tolist()
def decode_tokens(self, ids: List[int], **kwargs):
    """
    Decode a list of token IDs back into text.

    The IDs are wrapped into a single-row tensor and the first decoded
    entry is returned. Pass decode_special_tokens=False to override the
    default of True.
    """

    id_batch = torch.tensor([ids])
    with_special = unwrap(kwargs.get("decode_special_tokens"), True)

    decoded_batch = self.tokenizer.decode(
        id_batch,
        decode_special_tokens=with_special,
    )

    return decoded_batch[0]
def get_special_tokens(self):
    """Return the tokenizer's bos/eos/pad/unk tokens keyed by attribute name."""

    tokenizer = self.tokenizer
    token_names = ("bos_token", "eos_token", "pad_token", "unk_token")

    return {name: getattr(tokenizer, name) for name in token_names}
def get_logprobs(self, token_ids: torch.Tensor, token_probs: torch.Tensor):
    """
    Map token IDs to their log probabilities.

    Args:
        token_ids: Tensor of token IDs; flattened before use.
        token_probs: Tensor of probabilities; flattened before use.

    Returns:
        A dict of token piece -> log probability, in input order.
        -inf is replaced with -1000.
    """

    # Fetch both lookup tables once instead of once per token
    extended_pieces = self.tokenizer.extended_id_to_piece
    base_pieces = self.tokenizer.get_id_to_piece_list(True)

    # Extended pieces take priority; only index the base list when needed
    top_tokens = [
        extended_pieces[index] if index in extended_pieces else base_pieces[index]
        for index in token_ids.flatten().tolist()
    ]

    top_values = torch.log(token_probs).flatten().tolist()

    # Cannot return -inf in JSON
    cleaned_values = [
        -1000 if value == float("-inf") else value for value in top_values
    ]

    return dict(zip_longest(top_tokens, cleaned_values))
async def generate(
    self,
    request_id: str,
    prompt: str,
    params: BaseSamplerRequest,
    abort_event: Optional[asyncio.Event] = None,
    mm_embeddings: Optional[MultimodalEmbeddingWrapper] = None,
):
    """
    Generate a response to a prompt.

    Collects every chunk from stream_generate and merges them into one
    dict: text is concatenated, offsets and logprobs are collected per
    chunk, and token_probs are merged. A trailing chunk that carries a
    "finish_reason" is merged in as-is; otherwise the finish reason
    defaults to "stop".
    """

    chunks = [
        chunk
        async for chunk in self.stream_generate(
            request_id,
            prompt,
            params,
            abort_event,
            mm_embeddings,
        )
    ]

    merged = {
        "text": "",
        "prompt_tokens": 0,
        "gen_tokens": 0,
        "tool_calls": None,
        "offset": [],
        "token_probs": {},
        "logprobs": [],
    }

    # Nothing was produced at all
    if not chunks:
        return merged

    # Get finish_reason first and then shift where -1 points to
    if "finish_reason" in chunks[-1]:
        merged = {**merged, **chunks.pop()}
    else:
        merged["finish_reason"] = "stop"

    # Only the finish chunk was produced
    if not chunks:
        return merged

    for chunk in chunks:
        merged["text"] += unwrap(chunk.get("text"), "")
        merged["offset"].append(unwrap(chunk.get("offset"), -1))
        merged["token_probs"].update(unwrap(chunk.get("token_probs"), {}))

        # Include empty logprob dicts for index preservation
        merged["logprobs"].append(unwrap(chunk.get("logprobs"), {}))

    # Token counts are taken from the last regular chunk
    last_chunk = chunks[-1]
    merged["prompt_tokens"] = unwrap(last_chunk.get("prompt_tokens"), 0)
    merged["generated_tokens"] = unwrap(last_chunk.get("generated_tokens"), 0)

    return merged
async def stream_generate(
    self,
    request_id: str,
    prompt: str,
    params: BaseSamplerRequest,
    abort_event: Optional[asyncio.Event] = None,
    mm_embeddings: Optional[MultimodalEmbeddingWrapper] = None,
):
    """
    Stream generation chunks for a prompt.

    Waits until the load lock is free, registers the request in
    active_job_ids for the duration of the generation and yields every
    chunk produced by generate_gen.

    Raises:
        RuntimeError: If the model is not loaded once the load lock is free.
    """

    # Only remove the request on exit if this call registered it
    job_registered = False

    try:
        # Wait for load lock to be freed before processing
        # Mainly used for loras and other operations where the class is available
        async with self.load_condition:
            await self.load_condition.wait_for(lambda: not self.load_lock.locked())

        # If the model is being unloaded, don't accept new requests
        if not self.loaded:
            raise RuntimeError(
                "Model is being unloaded. Cannot process new generation requests."
            )

        # Mark that the job is running
        self.active_job_ids[request_id] = None
        job_registered = True

        # Yield from the internal generator
        async for generation_chunk in self.generate_gen(
            request_id=request_id,
            prompt=prompt,
            params=params,
            abort_event=abort_event,
            mm_embeddings=mm_embeddings,
        ):
            yield generation_chunk
    finally:
        # Clean up and remove the job from active IDs
        # A plain `del` would raise KeyError for requests rejected before
        # registration and hide the original exception.
        if job_registered:
            self.active_job_ids.pop(request_id, None)
def check_unsupported_settings(self, params: BaseSamplerRequest):
    """
    Hook for warning about samplers the installed backend doesn't support.

    Intended for dev wheels. Currently a passthrough: the params are
    returned unchanged.
    """

    return params
def assign_gen_params(
    self,
    params: BaseSamplerRequest,
    gen_settings: ExLlamaV2Sampler.Settings,
    grammar_handler: ExLlamaV2Grammar,
):
    """
    Copy request sampler params onto ExLlamaV2 sampler settings.

    Mutates gen_settings and grammar_handler in place and also sets
    speculative_ngram on the shared generator. Nothing is returned.

    Args:
        params: Sampler values of the request.
        gen_settings: Sampler settings object to populate.
        grammar_handler: Receives JSON schema / regex / grammar filters
            when the request specifies them.
    """

    # Apply settings
    gen_settings.temperature = params.temperature
    gen_settings.temperature_last = params.temperature_last
    gen_settings.smoothing_factor = params.smoothing_factor
    gen_settings.top_k = params.top_k
    gen_settings.top_p = params.top_p
    gen_settings.top_a = params.top_a
    gen_settings.min_p = params.min_p
    gen_settings.tfs = params.tfs
    gen_settings.typical = params.typical
    # Only mirostat mode 2 turns mirostat on
    gen_settings.mirostat = params.mirostat_mode == 2
    gen_settings.skew = params.skew

    # XTC
    if params.xtc_probability > 0.0:
        gen_settings.xtc_probability = params.xtc_probability

        # 0.1 is the default for this value
        gen_settings.xtc_threshold = params.xtc_threshold

    # DynaTemp settings
    max_temp = params.max_temp
    min_temp = params.min_temp

    if params.max_temp > params.min_temp:
        gen_settings.max_temp = max_temp
        gen_settings.min_temp = min_temp
        gen_settings.temp_exponent = params.temp_exponent
    else:
        # Force to default values
        gen_settings.max_temp = 1.0
        gen_settings.min_temp = 1.0
        gen_settings.temp_exponent = 1.0

    # Warn if max/min temp values are > 0
    # and if they're less than or equal to each other
    # (equal values only warn when neither of them is 1)
    if max_temp < min_temp or (
        1 not in {min_temp, max_temp} and max_temp == min_temp
    ):
        logger.warning(
            "Max temp is less than or equal to min temp, skipping DynaTemp."
        )

    # Default tau and eta fallbacks don't matter if mirostat is off
    gen_settings.mirostat_tau = params.mirostat_tau
    gen_settings.mirostat_eta = params.mirostat_eta

    # Penalties
    gen_settings.token_repetition_penalty = params.repetition_penalty
    gen_settings.token_frequency_penalty = params.frequency_penalty
    gen_settings.token_presence_penalty = params.presence_penalty

    # Applies for all penalties despite being called token_repetition_range
    gen_settings.token_repetition_range = unwrap(
        params.penalty_range, self.config.max_seq_len
    )

    # Always make sure the fallback is 0 if range < 0
    # It's technically fine to use -1, but this just validates the passed
    # fallback
    # Always default to 0 if something goes wrong
    if gen_settings.token_repetition_range < 0:
        fallback_decay = 0
    else:
        fallback_decay = gen_settings.token_repetition_range
    gen_settings.token_repetition_decay = coalesce(
        params.repetition_decay, fallback_decay, 0
    )

    # DRY options
    dry_multiplier = params.dry_multiplier

    # < 0 = disabled
    if dry_multiplier > 0:
        gen_settings.dry_multiplier = dry_multiplier
        gen_settings.dry_allowed_length = params.dry_allowed_length
        gen_settings.dry_base = params.dry_base

        # Exl2 has dry_range as 0 for unlimited unlike -1 for penalty_range
        # Use max_seq_len as the fallback to stay consistent
        gen_settings.dry_range = unwrap(params.dry_range, self.config.max_seq_len)

        # Tokenize sequence breakers
        # Only the last token ID of each encoded breaker is kept
        if params.dry_sequence_breakers:
            gen_settings.dry_sequence_breakers = {
                self.encode_tokens(s)[-1] for s in params.dry_sequence_breakers
            }

    # Add JSON schema filter if it exists
    if params.json_schema:
        grammar_handler.add_json_schema_filter(
            params.json_schema, self.model, self.tokenizer
        )

    # Add regex filter if it exists
    if params.regex_pattern:
        grammar_handler.add_regex_filter(
            params.regex_pattern, self.model, self.tokenizer
        )

    # Add EBNF filter if it exists
    if params.grammar_string:
        grammar_handler.add_kbnf_filter(
            params.grammar_string, self.model, self.tokenizer
        )

    # Speculative Ngram
    # NOTE(review): this is set on the shared generator, not on the
    # per-request settings - confirm concurrent requests may override it.
    self.generator.speculative_ngram = params.speculative_ngram

    # Override sampler settings for temp = 0
    if gen_settings.temperature == 0:
        gen_settings.temperature = 1.0
        gen_settings.top_k = 1
        gen_settings.top_p = 0
        gen_settings.typical = 0

        logger.warning(
            "Temperature is set to 0. Overriding temp, "
            "top_k, top_p, and typical to 1.0, 1, 0, and 0."
        )

    # Set banned tokens
    if params.banned_tokens:
        gen_settings.disallow_tokens(self.tokenizer, params.banned_tokens)

    # Set allowed tokens
    if params.allowed_tokens:
        gen_settings.allow_tokens(self.tokenizer, params.allowed_tokens)

    # Set logit bias
    if params.logit_bias:
        # Create a vocab tensor if it doesn't exist for token biasing
        if gen_settings.token_bias is None:
            # Pad the vocab size up to the next multiple of 32
            padding = -self.tokenizer.config.vocab_size % 32
            gen_settings.token_bias = torch.zeros(
                (self.tokenizer.config.vocab_size + padding,),
                dtype=torch.float,
            )

        # Map logits to the tensor with their biases
        for token_id, bias in params.logit_bias.items():
            if 0 <= token_id < len(self.tokenizer.get_id_to_piece_list(True)):
                gen_settings.token_bias[token_id] = bias
            else:
                logger.warning(
                    f"Logit bias: Token {token_id} not present "
                    "in the model's vocab. Skipping."
                )
# Adds logprobs to a generation chunk
def handle_logprobs(self, result: dict, generation: dict):
    """
    Attach logprob data from a generator result to a generation chunk.

    Reads "top_k_tokens" and "top_k_probs" from result. When both hold
    data, generation receives "logprobs" (all candidates) and
    "token_probs" (the first candidate only). Otherwise generation is
    left untouched.
    """

    candidate_ids = unwrap(
        result.get("top_k_tokens"),
        torch.empty((1, 0, 1), dtype=torch.long),
    )
    candidate_probs = unwrap(
        result.get("top_k_probs"),
        torch.empty((1, 0, 1), dtype=torch.float),
    )

    # Nothing to report without both candidates and probabilities
    if candidate_ids.numel() == 0 or candidate_probs.numel() == 0:
        return

    logprobs = self.get_logprobs(candidate_ids, candidate_probs)
    generation["logprobs"] = logprobs

    # The first logprob is the selected token prob
    generation["token_probs"] = dict(list(logprobs.items())[:1])
# Creates and returns a finish chunk
def handle_finish_chunk(self, result: dict, generation: dict):
    """
    Build the final chunk of a generation, including usage timings.

    Args:
        result: Final result dict emitted by the generator job. The keys
            read are "eos_reason", "eos_triggering_token_str",
            "eos_triggering_string", "prompt_tokens", "cached_tokens",
            "new_tokens", "time_prefill", "time_generate" and
            "time_enqueued". A missing (or None) metric counts as 0.
        generation: The last streamed chunk. Not read here; kept so the
            call signature stays unchanged.

    Returns:
        A dict with the token counts, the prefill/generation/queue/total
        times rounded to 2 decimals, the tokens-per-second rates (the
        string "Indeterminate" when the matching time is zero),
        "finish_reason" ("length" or "stop") and "stop_str".
    """

    def metric(key: str):
        # Treat an absent metric as zero instead of failing on None
        value = result.get(key)
        return 0 if value is None else value

    eos_reason = result.get("eos_reason")
    finish_reason = "length" if eos_reason == "max_new_tokens" else "stop"

    # Grab stop string if stop was the reason
    stop_str = None
    if eos_reason == "stop_token":
        stop_str = result.get("eos_triggering_token_str")
    elif eos_reason == "stop_string":
        stop_str = result.get("eos_triggering_string")

    # Prompt
    # Keep the raw prefill time for the math below. Rounding it first
    # turned any prefill under 5ms into 0 ("Indeterminate") and skewed
    # the rate for short prefills. Rounding happens only on output.
    prompt_tokens = metric("prompt_tokens")
    cached_tokens = round(metric("cached_tokens"), 2)
    prompt_time = metric("time_prefill")
    if prompt_time == 0:
        prompt_ts = "Indeterminate"
    else:
        prompt_ts = round((prompt_tokens - cached_tokens) / prompt_time, 2)

    # Generated
    gen_tokens = metric("new_tokens")
    gen_time = metric("time_generate")
    if gen_time == 0:
        gen_ts = "Indeterminate"
    else:
        gen_ts = round(gen_tokens / gen_time, 2)

    # Queue + Total
    queue_time = metric("time_enqueued")
    total_time = round(queue_time + prompt_time + gen_time, 2)

    return {
        "prompt_tokens": prompt_tokens,
        "prompt_time": round(prompt_time, 2),
        "prompt_tokens_per_sec": prompt_ts,
        "gen_tokens": gen_tokens,
        "gen_time": round(gen_time, 2),
        "gen_tokens_per_sec": gen_ts,
        "total_time": total_time,
        "queue_time": round(queue_time, 2),
        "cached_tokens": cached_tokens,
        "finish_reason": finish_reason,
        "stop_str": stop_str,
    }
async def generate_gen(
    self,
    request_id: str,
    prompt: str,
    params: BaseSamplerRequest,
    abort_event: Optional[asyncio.Event] = None,
    mm_embeddings: Optional[MultimodalEmbeddingWrapper] = None,
):
    """
    Create generator function for prompt completion.

    for kwargs, check common/sampling.py

    Args:
        request_id: Identifier used for logging, for the job identifier
            and as the key in ``self.active_job_ids``.
        prompt: The positive prompt text.
        params: Sampler request options. This object is not modified here.
        abort_event: When set while streaming, the job is cancelled and
            iteration stops.
        mm_embeddings: Optional multimodal embeddings; their ``content``
            is passed to both the tokenizer and the job.

    Yields:
        One dict per streamed chunk ("text", "prompt_tokens",
        "generated_tokens", "offset", plus logprob fields when
        ``params.logprobs > 0``), followed by the finish chunk from
        ``handle_finish_chunk`` once a result reports "eos".

    Raises:
        ValueError: If the prompt (or negative prompt) is longer than
            ``max_seq_len``, or if a CFG request needs more cache pages
            than ``cache_size`` allows.
    """

    prompts = [prompt]
    gen_settings = ExLlamaV2Sampler.Settings()
    grammar_handler = ExLlamaV2Grammar()

    self.assign_gen_params(
        params,
        gen_settings,
        grammar_handler,
    )

    # Set banned strings
    banned_strings = params.banned_strings
    if banned_strings and len(grammar_handler.filters) > 0:
        logger.warning(
            "Disabling banned_strings because "
            "they cannot be used with grammar filters."
        )
        banned_strings = []

    # Set CFG scale and negative prompt
    # use_cfg records whether a second (negative) prompt was really added,
    # so later checks don't depend on the negative prompt's truthiness
    cfg_scale = params.cfg_scale
    negative_prompt = None
    use_cfg = False
    if cfg_scale not in [None, 1.0]:
        if self.paged:
            gen_settings.cfg_scale = cfg_scale

            # If the negative prompt is empty, use the BOS token
            negative_prompt = unwrap(
                params.negative_prompt, self.tokenizer.bos_token
            )
            prompts.append(negative_prompt)
            use_cfg = True
        else:
            logger.warning(
                "CFG is currently disabled because paged mode is disabled. "
                "Please use an ampere (30 series) or higher GPU for CFG support."
            )

    # Dynamically scale penalty range to output tokens
    # Only do this if freq/pres pen is enabled
    # and the repetition range is -1
    auto_scale_penalty_range = (
        gen_settings.token_frequency_penalty != 0
        or gen_settings.token_presence_penalty != 0
    ) and gen_settings.token_repetition_range == -1

    # Work on a copy of the stop list. Extending params.stop in place
    # would leak the EOS token IDs into the request object, which is
    # dumped to the log in the finally block below.
    requested_stop = unwrap(params.stop, [])
    if isinstance(requested_stop, str):
        stop_conditions = [requested_stop]
    else:
        stop_conditions = list(requested_stop)

    ban_eos_token = params.ban_eos_token

    # Set add_bos_token for generation
    add_bos_token = unwrap(params.add_bos_token, self.hf_model.add_bos_token())

    # Fetch EOS tokens from the HF model if they exist
    eos_tokens = self.hf_model.eos_tokens() or [self.tokenizer.eos_token_id]

    # Ban the EOS token if specified. If not, append to stop conditions
    # as well.
    if ban_eos_token:
        gen_settings.disallow_tokens(self.tokenizer, eos_tokens)
    else:
        stop_conditions = [*stop_conditions, *eos_tokens]

    # Get multimodal embeddings if present
    mm_embeddings_content = mm_embeddings.content if mm_embeddings else []

    # Encode both positive and negative prompts
    input_ids = [
        self.tokenizer.encode(
            text,
            add_bos=add_bos_token,
            encode_special_tokens=True,
            embeddings=mm_embeddings_content,
        )
        for text in prompts
    ]

    # The first index will always be the positive prompt
    context_len = input_ids[0].size(dim=-1)

    # The second index will be the negative prompt if CFG is enabled
    negative_context_len = input_ids[1].size(dim=-1) if use_cfg else 0

    # Determine if the negative context or the context length is bigger
    context_to_check = max(negative_context_len, context_len)

    # Check total length of prompt against max context length
    # Done before the max_tokens clamp so an oversized prompt fails
    # without first logging an unrelated max_tokens warning
    if context_to_check > self.config.max_seq_len:
        preamble = (
            "Negative prompt" if negative_context_len > context_len else "Prompt"
        )

        raise ValueError(
            f"{preamble} length {context_to_check} is greater than "
            f"max_seq_len {self.config.max_seq_len}"
        )

    # Automatically set max_tokens to fill up the context
    # This should be an OK default, but may be changed in the future
    max_tokens = unwrap(
        params.max_tokens,
        self.config.max_seq_len - context_to_check,
    )
    if max_tokens < 1:
        logger.warning("max_tokens must be a positive integer, setting to 1.")
        max_tokens = 1

    # Check total required pages for CFG request to avoid overallocation
    if use_cfg and (
        sum(
            256 * math.ceil((context + max_tokens) / 256)
            for context in (context_len, negative_context_len)
        )
        > self.cache_size
    ):
        raise ValueError(
            f"Total required page size for request "
            f"{context_len} + {negative_context_len} + {max_tokens} * 2 "
            f"is greater than cache_size {self.cache_size}"
        )

    # Log prompt to console. Add the BOS token if specified
    log_prompt(
        f"{self.tokenizer.bos_token if add_bos_token else ''}{prompt}",
        request_id,
        negative_prompt,
    )

    # Create and add a new job
    # Don't use the request ID here as there can be multiple jobs per request
    job = ExLlamaV2DynamicJobAsync(
        self.generator,
        input_ids=input_ids,
        max_new_tokens=max_tokens,
        min_new_tokens=params.min_tokens,
        gen_settings=gen_settings,
        stop_conditions=stop_conditions,
        decode_special_tokens=True,
        filters=grammar_handler.filters,
        filter_prefer_eos=bool(grammar_handler.filters),
        return_probs=params.logprobs > 0,
        return_top_tokens=params.logprobs,
        return_logits=params.logprobs > 0,
        banned_strings=banned_strings,
        token_healing=params.token_healing,
        identifier=request_id,
        embeddings=mm_embeddings_content,
    )

    # Assign the active job to the request ID
    self.active_job_ids[request_id] = job

    # Save generated tokens and full response
    # Copy over max seq len incase model is unloaded and stored jobs can complete
    # Full response is required for offset calculation
    max_seq_len = self.config.max_seq_len
    generated_tokens = 0
    full_response = ""
    metrics_result = {}

    # Get the generation status once it's ready
    try:
        async for result in job:
            # Abort if the event is set while streaming
            if abort_event and abort_event.is_set():
                await job.cancel()
                break

            stage = result.get("stage")
            result_id = result.get("identifier")

            if stage == "streaming" and result_id == request_id:
                chunk = unwrap(result.get("text"), "")
                full_response += chunk

                chunk_tokens = result.get("token_ids")
                if chunk_tokens is not None:
                    generated_tokens += chunk_tokens.size(dim=0)

                generation = {
                    "text": chunk,
                    "prompt_tokens": context_len,
                    "generated_tokens": generated_tokens,
                    "offset": len(full_response),
                }

                # Increase penalty range to generated token amount
                if auto_scale_penalty_range:
                    gen_settings.token_repetition_range = generated_tokens

                # Handle logprobs
                if params.logprobs > 0:
                    self.handle_logprobs(result, generation)

                yield generation

                # Yield a finish chunk when generation is finished
                if result.get("eos"):
                    log_response(request_id, full_response)

                    finish_chunk = self.handle_finish_chunk(result, generation)

                    # Save the final result for metrics logging
                    metrics_result = finish_chunk

                    yield finish_chunk
                    break
    except asyncio.CancelledError:
        await job.cancel()
    except Exception as ex:
        # Create a new generator since the current state is broken
        # No need to wait for this to finish
        logger.error(
            "FATAL ERROR with generation. "
            "Attempting to recreate the generator. "
            "If this fails, please restart the server.\n"
        )
        asyncio.ensure_future(self.create_generator())

        await HealthManager.add_unhealthy_event(ex)

        # Re-raise the active exception as-is
        raise
    finally:
        # Log generation options to console
        # Some options are too large, so log the args instead
        log_generation_params(
            request_id=request_id,
            bos_token_id=self.tokenizer.bos_token_id,
            eos_token_id=eos_tokens,
            prompt=prompt,
            **params.model_dump(exclude={"prompt"}),
            auto_scale_penalty_range=auto_scale_penalty_range,
        )

        # Log the metrics if present
        if metrics_result:
            log_metrics(
                request_id,
                metrics_result,
                context_len,
                max_seq_len,
            )