"""The model container class for ExLlamaV2 models.""" import asyncio import gc import math import pathlib import torch from exllamav2 import ( ExLlamaV2, ExLlamaV2Config, ExLlamaV2CacheBase, ExLlamaV2Cache, ExLlamaV2Cache_Q4, ExLlamaV2Cache_Q6, ExLlamaV2Cache_Q8, ExLlamaV2Cache_TP, ExLlamaV2Tokenizer, ExLlamaV2Lora, ExLlamaV2VisionTower, ) from exllamav2.generator import ( ExLlamaV2Sampler, ExLlamaV2DynamicGeneratorAsync, ExLlamaV2DynamicJobAsync, ) from itertools import zip_longest from loguru import logger from typing import Dict, List, Optional from backends.base_model_container import BaseModelContainer from backends.exllamav2.grammar import ( ExLlamaV2Grammar, clear_grammar_func_cache, ) from backends.exllamav2.utils import exllama_disabled_flash_attn from backends.exllamav2.vision import clear_image_embedding_cache from common.concurrency import iterate_in_threadpool from common.gen_logging import ( log_generation_params, log_metrics, log_prompt, log_response, ) from common.hardware import hardware_supports_flash_attn from common.health import HealthManager from common.multimodal import MultimodalEmbeddingWrapper from common.optional_dependencies import check_package_version from common.sampling import BaseSamplerRequest from common.templating import PromptTemplate, find_prompt_template from common.transformers_utils import HFModel from common.utils import calculate_rope_alpha, coalesce, unwrap from endpoints.core.types.model import ModelCard, ModelCardParameters class ExllamaV2Container(BaseModelContainer): """The model container class for ExLlamaV2 models.""" # Model directories model_dir: pathlib.Path = pathlib.Path("models") draft_model_dir: pathlib.Path = pathlib.Path("models") prompt_template: Optional[PromptTemplate] = None # HF model instance hf_model: HFModel # Exl2 vars config: Optional[ExLlamaV2Config] = None model: Optional[ExLlamaV2] = None cache: Optional[ExLlamaV2Cache] = None tokenizer: Optional[ExLlamaV2Tokenizer] = None generator: Optional[ExLlamaV2DynamicGeneratorAsync] = None prompt_template: Optional[PromptTemplate] = None paged: bool = True # Draft model vars use_draft_model: bool = False draft_config: Optional[ExLlamaV2Config] = None draft_model: Optional[ExLlamaV2] = None draft_cache: Optional[ExLlamaV2Cache] = None # Internal config vars cache_size: int = None cache_mode: str = "FP16" draft_cache_mode: str = "FP16" max_batch_size: Optional[int] = None # GPU split vars gpu_split: List[float] = [] draft_gpu_split: List[float] = [] gpu_split_auto: bool = True autosplit_reserve: List[float] = [96 * 1024**2] use_tp: bool = False # Vision vars use_vision: bool = False vision_model: Optional[ExLlamaV2VisionTower] = None # Load synchronization active_job_ids: Dict[str, Optional[ExLlamaV2DynamicJobAsync]] = {} loaded: bool = False load_lock: asyncio.Lock = asyncio.Lock() load_condition: asyncio.Condition = asyncio.Condition() @classmethod async def create(cls, model_directory: pathlib.Path, hf_model: HFModel, **kwargs): """ Primary asynchronous initializer for model container. 

        Kwargs are located in config_sample.yml
        """

        # Create a new instance as a "fake self"
        self = cls()

        # Make sure ExllamaV2 is up to date
        check_package_version("exllamav2", "0.3.1")

        # Initialize config
        self.config = ExLlamaV2Config()
        self.model_dir = model_directory
        self.config.model_dir = str(model_directory.resolve())
        self.hf_model = hf_model

        # Make the max seq len 4096 before preparing the config
        # This is a better default than 2048
        self.config.max_seq_len = 4096
        self.config.prepare()

        # Check if the model arch is compatible with various exl2 features
        self.config.arch_compat_overrides()

        # Set vision state and error if vision isn't supported on the current model
        self.use_vision = unwrap(kwargs.get("vision"), False)
        if self.use_vision and not self.config.vision_model_type:
            raise ValueError(
                "The provided model does not have vision capabilities that are "
                "supported by ExllamaV2. "
                "Please reload with vision disabled."
            )

        # Prepare the draft model config if necessary
        draft_args = unwrap(kwargs.get("draft_model"), {})
        draft_model_name = draft_args.get("draft_model_name")
        self.use_draft_model = draft_args and draft_model_name

        # Always disable draft if params are incorrectly configured
        if draft_args and draft_model_name is None:
            logger.warning(
                "Draft model is disabled because a model name "
                "wasn't provided. Please check your config.yml!"
            )
            self.use_draft_model = False

        if self.use_draft_model:
            self.draft_config = ExLlamaV2Config()
            draft_model_path = pathlib.Path(
                unwrap(draft_args.get("draft_model_dir"), "models")
            )
            draft_model_path = draft_model_path / draft_model_name

            self.draft_gpu_split = unwrap(draft_args.get("draft_gpu_split"), [])
            self.draft_model_dir = draft_model_path
            self.draft_config.model_dir = str(draft_model_path.resolve())
            self.draft_config.prepare()

        # MARK: User configuration

        # Get cache mode
        self.cache_mode = unwrap(kwargs.get("cache_mode"), "FP16")

        # Catch exllamav3 cache_mode
        if self.cache_mode != "FP16" and not self.cache_mode.startswith("Q"):
            logger.warning(
                f"Provided cache mode '{self.cache_mode}' is not a "
                "valid choice for exllamav2, please check your settings. "
                "Defaulting to FP16."
) self.cache_mode = "FP16" # Turn off GPU split if the user is using 1 GPU gpu_count = torch.cuda.device_count() gpu_split_auto = unwrap(kwargs.get("gpu_split_auto"), True) use_tp = unwrap(kwargs.get("tensor_parallel"), False) gpu_split = unwrap(kwargs.get("gpu_split"), []) gpu_device_list = list(range(0, gpu_count)) # Set GPU split options if gpu_count == 1: self.gpu_split_auto = False logger.info("Disabling GPU split because one GPU is in use.") else: # Set tensor parallel if use_tp: self.use_tp = True # TP has its own autosplit loader self.gpu_split_auto = False # Enable manual GPU split if provided if gpu_split: self.gpu_split_auto = False self.gpu_split = gpu_split gpu_device_list = [ device_idx for device_idx, memory in enumerate(self.gpu_split) if memory > 0 ] elif gpu_split_auto and not self.use_tp: # Otherwise fallback to autosplit settings self.gpu_split_auto = gpu_split_auto autosplit_reserve_megabytes = unwrap( kwargs.get("autosplit_reserve"), [96] ) # Reserve VRAM for each GPU self.autosplit_reserve = [ int(math.ceil(value * 1024**2)) for value in autosplit_reserve_megabytes ] # Change the GPU device list only if gpu_split's list is too small # This allows for an uneven list specification if self.draft_gpu_split and len(self.draft_gpu_split) > len(self.gpu_split): gpu_device_list = [ device_idx for device_idx, memory in enumerate(self.draft_gpu_split) if memory > 0 ] # Hardcode max output length to 16 self.config.max_output_len = 16 # Grab the base model's sequence length before overrides for # rope calculations base_seq_len = hf_model.hf_config.max_position_embeddings # Set the target seq len if present target_seq_len = unwrap(kwargs.get("max_seq_len"), 4096) # Set the rope scale self.config.scale_pos_emb = unwrap( kwargs.get("rope_scale"), self.config.scale_pos_emb ) # Sets rope alpha value. # Utilize the model's max_position_embeddings as a base value # Automatically calculate if unset or defined as an "auto" literal. rope_alpha = unwrap(kwargs.get("rope_alpha"), "auto") if rope_alpha == "auto": self.config.scale_alpha_value = calculate_rope_alpha( base_seq_len, target_seq_len ) else: self.config.scale_alpha_value = rope_alpha # Set the max seq len if specified if target_seq_len: self.config.max_seq_len = target_seq_len # Set max batch size to the config override self.max_batch_size = unwrap(kwargs.get("max_batch_size")) # Check whether the user's configuration supports flash/paged attention # Also check if exl2 has disabled flash attention if exllama_disabled_flash_attn( self.config.no_flash_attn ) or not hardware_supports_flash_attn(gpu_device_list): gpu_unsupported_message = ( "An unsupported GPU is found in this configuration. " "Switching to compatibility mode. \n" "This disables parallel batching " "and features that rely on it (ex. CFG). \n" "To disable compatability mode, all GPUs must be ampere " "(30 series) or newer. AMD GPUs are not supported." ) logger.warning(gpu_unsupported_message) self.config.no_flash_attn = True if self.draft_config: self.draft_config.no_flash_attn = True self.paged = False self.max_batch_size = 1 torch.backends.cuda.enable_flash_sdp(False) # Set k/v cache size # cache_size is only relevant when paged mode is enabled if self.paged: cache_size = unwrap(kwargs.get("cache_size"), self.config.max_seq_len) if cache_size < self.config.max_seq_len: logger.warning( f"The given cache_size ({cache_size}) is smaller than the " "desired context length.\n" "Overriding cache_size to max_seq_len. 
" ) cache_size = self.config.max_seq_len # Enforce a multiple of 256 for cache size # Overestimate to ensure that the cache isn't below max_seq_len cache_remainder = cache_size % 256 if cache_remainder != 0: rounded_cache_size = int( 256 * ((cache_size - cache_remainder) / 256 + 1) ) logger.warning( f"The given cache size ({cache_size}) is " "not a multiple of 256.\n" "Overriding cache_size with an overestimated value of " f"{rounded_cache_size} tokens." ) cache_size = rounded_cache_size # Warn user if cache size may be inadequate for CFG if cache_size < 2 * self.config.max_seq_len: logger.warning( f"The given cache_size ({cache_size}) is less than 2 * max_seq_len " "and may be too small for requests using CFG. \n" "Ignore this warning if you do not plan on using CFG." ) self.cache_size = cache_size else: self.cache_size = self.config.max_seq_len # Try to set prompt template self.prompt_template = await find_prompt_template( kwargs.get("prompt_template"), model_directory ) # Catch all for template lookup errors if self.prompt_template: logger.info( f'Using template "{self.prompt_template.name}" for chat completions.' ) else: logger.warning( "Chat completions are disabled because a prompt " "template wasn't provided or auto-detected." ) # Make sure chunk size is >= 256, keep near or below max seq len user_chunk_size = unwrap(kwargs.get("chunk_size"), 2048) chunk_size = sorted((256, user_chunk_size, self.config.max_seq_len))[1] chunk_remainder = chunk_size % 256 if chunk_remainder != 0: rounded_chunk_size = int(256 * ((chunk_size - chunk_remainder) / 256 + 1)) logger.warning( f"The given chunk size ({chunk_size}) is " "not a multiple of 256.\n" "Overriding chunk_size with an overestimated value of " f"{rounded_chunk_size} tokens." ) chunk_size = rounded_chunk_size self.config.max_input_len = chunk_size self.config.max_attention_size = chunk_size**2 # Set user-configured draft model values if self.use_draft_model: self.draft_config.max_seq_len = self.config.max_seq_len self.draft_config.scale_pos_emb = unwrap( draft_args.get("draft_rope_scale"), 1.0 ) # Set draft rope alpha. Follows same behavior as model rope alpha. # Use the max_position_embeddings of the model draft_rope_alpha = unwrap(draft_args.get("draft_rope_alpha"), "auto") if draft_rope_alpha == "auto": self.draft_config.scale_alpha_value = calculate_rope_alpha( base_seq_len, self.draft_config.max_seq_len ) else: self.draft_config.scale_alpha_value = draft_rope_alpha # Set draft cache mode self.draft_cache_mode = unwrap(draft_args.get("draft_cache_mode"), "FP16") # Catch exllamav3 draft_cache_mode if self.draft_cache_mode != "FP16" and not self.draft_cache_mode.startswith( "Q" ): logger.warning( f"Provided draft cache mode '{self.draft_cache_mode}' is not a " "valid choice for exllamav2, please check your settings. " "Defaulting to FP16." 
) self.draft_cache_mode = "FP16" # Edit the draft config size if chunk_size: self.draft_config.max_input_len = chunk_size self.draft_config.max_attention_size = chunk_size**2 print(self.config.max_seq_len) # Return the created instance return self def model_info(self): draft_model_card: ModelCard = None if self.draft_config: draft_model_params = ModelCardParameters( max_seq_len=self.draft_config.max_seq_len, rope_scale=self.draft_config.scale_pos_emb, rope_alpha=self.draft_config.scale_alpha_value, cache_mode=self.draft_cache_mode, ) draft_model_card = ModelCard( id=self.draft_model_dir.name, parameters=draft_model_params, ) model_params = ModelCardParameters( max_seq_len=self.config.max_seq_len, cache_size=self.cache_size, rope_scale=self.config.scale_pos_emb, rope_alpha=self.config.scale_alpha_value, max_batch_size=self.max_batch_size, cache_mode=self.cache_mode, chunk_size=self.config.max_input_len, use_vision=self.use_vision, draft=draft_model_card, ) if self.prompt_template: model_params.prompt_template = self.prompt_template.name model_params.prompt_template_content = self.prompt_template.raw_template model_card = ModelCard( id=self.model_dir.name, parameters=model_params, ) return model_card async def wait_for_jobs(self, skip_wait: bool = False): """Polling mechanism to wait for pending generation jobs.""" if not self.generator: return # Immediately abort all jobs if asked if skip_wait: logger.warning( "Immediately terminating all jobs. " "Clients will have their requests cancelled.\n" ) for job in self.active_job_ids.values(): if job: await job.cancel() while len(self.active_job_ids) > 0: await asyncio.sleep(0.01) async def load(self, progress_callback=None): """ Load model Args: progress_callback (function, optional): A function to call for each module loaded. Prototype: def progress(loaded_modules: int, total_modules: int) """ async for _ in self.load_gen(progress_callback): pass async def load_gen(self, progress_callback=None, **kwargs): """Loads a model and streams progress via a generator.""" # Indicate that model load has started # Do this operation under the load lock's context try: await self.load_lock.acquire() # Wait for existing generation jobs to finish await self.wait_for_jobs(kwargs.get("skip_wait")) # Streaming gen for model load progress model_load_generator = self.load_model_sync(progress_callback) async for value in iterate_in_threadpool(model_load_generator): yield value # Create async generator await self.create_generator() # Clean up any extra vram usage from torch and cuda # (Helps reduce VRAM bottlenecking on Windows) gc.collect() torch.cuda.empty_cache() # Cleanup and update model load state self.loaded = True logger.info("Model successfully loaded.") finally: self.load_lock.release() async with self.load_condition: self.load_condition.notify_all() @torch.inference_mode() def load_model_sync(self, progress_callback=None): """ Synchronous generator for loading. Args: progress_callback (function, optional): A function to call for each module loaded. Prototype: def progress(loaded_modules: int, total_modules: int) Runs under a shared inference mode context. 
""" # Reset tokenizer namespace vars and create a tokenizer ExLlamaV2Tokenizer.unspecial_piece_to_id = {} ExLlamaV2Tokenizer.unspecial_id_to_piece = {} ExLlamaV2Tokenizer.extended_id_to_piece = {} ExLlamaV2Tokenizer.extended_piece_to_id = {} self.tokenizer = ExLlamaV2Tokenizer(self.config) # Calculate autosplit reserve for all GPUs gpu_count = torch.cuda.device_count() autosplit_reserve = self.autosplit_reserve + [0] * ( gpu_count - len(self.autosplit_reserve) ) # Load draft model if a config is present if self.draft_config: self.draft_model = ExLlamaV2(self.draft_config) logger.info("Loading draft model: " + self.draft_config.model_dir) # Draft uses the autosplit loader, so create a cache that reflects this draft_cache_class = self.get_cache_class(self.draft_cache_mode) if self.draft_gpu_split: logger.info("Loading with a manual GPU split (or a one GPU setup)") for value in self.draft_model.load_gen( self.draft_gpu_split, callback_gen=progress_callback, ): if value: yield value self.draft_cache = self.create_cache( cache_class=draft_cache_class, autosplit=False, use_tp=False, model=self.draft_model, ) else: logger.info("Loading with autosplit") self.draft_cache = self.create_cache( cache_class=draft_cache_class, autosplit=True, use_tp=False, model=self.draft_model, ) for value in self.draft_model.load_autosplit_gen( self.draft_cache, reserve_vram=autosplit_reserve, last_id_only=True, callback_gen=progress_callback, ): if value: yield value # Test VRAM allocation with a full-length forward pass input_ids = torch.zeros((1, self.config.max_input_len), dtype=torch.long) self.draft_model.forward(input_ids, cache=self.cache, preprocess_only=True) # Load vision tower if it exists if self.use_vision: self.vision_model = ExLlamaV2VisionTower(self.config) for value in self.vision_model.load_gen(callback_gen=progress_callback): if value: yield value self.model = ExLlamaV2(self.config) logger.info("Loading model: " + self.config.model_dir) # Get class of the model cache cache_class = self.get_cache_class(self.cache_mode) # Load model with manual split # Entrypoint for single GPU users if self.use_tp: logger.info("Loading with tensor parallel") # GPU split must be None if the array is empty # Otherwise the TP loader fails for value in self.model.load_tp_gen( self.gpu_split or None, callback_gen=progress_callback, expect_cache_base=cache_class, expect_cache_tokens=self.cache_size, ): if value: yield value elif not self.gpu_split_auto: logger.info("Loading with a manual GPU split (or a one GPU setup)") for value in self.model.load_gen( self.gpu_split, callback_gen=progress_callback, ): if value: yield value # Create the model cache self.cache = self.create_cache( cache_class=cache_class, autosplit=self.gpu_split_auto, use_tp=self.use_tp, model=self.model, ) # Load model with autosplit (without TP) if self.gpu_split_auto and not self.use_tp: logger.info("Loading with autosplit") for value in self.model.load_autosplit_gen( self.cache, reserve_vram=autosplit_reserve, last_id_only=True, callback_gen=progress_callback, ): if value: yield value # Test VRAM allocation with a full-length forward pass input_ids = torch.zeros((1, self.config.max_input_len), dtype=torch.long) self.model.forward(input_ids, cache=self.cache, preprocess_only=True) # TODO: Maybe make a wrapper class with an ID instead of a utility function def get_cache_class(self, cache_mode: str): """Utility function to get a cache class based on user preference.""" match cache_mode: case "Q4": return ExLlamaV2Cache_Q4 case "Q6": return 
ExLlamaV2Cache_Q6 case "Q8": return ExLlamaV2Cache_Q8 case _: return ExLlamaV2Cache def create_cache( self, cache_class: ExLlamaV2CacheBase, autosplit: bool, use_tp: bool, model: ExLlamaV2, ): """Utility function to create a model cache.""" if use_tp: return ExLlamaV2Cache_TP( model, base=cache_class, max_seq_len=self.cache_size, batch_size=1, ) else: return cache_class( model, max_seq_len=self.cache_size, lazy=autosplit, batch_size=1, ) async def create_generator(self): """Create and save a Exllama generator class.""" try: # Don't acquire locks unless a model is loaded if self.loaded: await self.load_lock.acquire() # Immediately cancel all jobs await self.wait_for_jobs(skip_wait=True) # Create new generator self.generator = ExLlamaV2DynamicGeneratorAsync( model=self.model, cache=self.cache, draft_model=self.draft_model, draft_cache=self.draft_cache, tokenizer=self.tokenizer, max_batch_size=self.max_batch_size, paged=self.paged, ) # Update the state of the container var if self.max_batch_size is None: self.max_batch_size = self.generator.generator.max_batch_size finally: # This means the generator is being recreated # The load lock is already released in the load function if self.loaded: self.load_lock.release() async with self.load_condition: self.load_condition.notify_all() def get_loras(self): """Convenience function to get all loras.""" return unwrap(self.generator.generator.current_loras, []) async def load_loras(self, lora_directory: pathlib.Path, **kwargs): """Load loras.""" loras = unwrap(kwargs.get("loras"), []) try: await self.load_lock.acquire() # Wait for existing generation jobs to finish await self.wait_for_jobs(kwargs.get("skip_wait")) loras_to_load: List[ExLlamaV2Lora] = [] success: List[str] = [] failure: List[str] = [] for lora in loras: lora_name = lora.get("name") lora_scaling = unwrap(lora.get("scaling"), 1.0) if lora_name is None: logger.warning( "One of your loras does not have a name. Please check your " "config.yml! Skipping lora load." 
                    )
                    failure.append(lora_name)
                    continue

                logger.info(f"Adding lora: {lora_name} at scaling {lora_scaling}")
                lora_path = lora_directory / lora_name

                loras_to_load.append(
                    ExLlamaV2Lora.from_directory(self.model, lora_path, lora_scaling)
                )
                logger.info(f"Lora successfully added: {lora_name}")
                success.append(lora_name)

            self.generator.generator.set_loras(loras_to_load)
            logger.info("All loras successfully loaded")

            # Return success and failure names
            return {"success": success, "failure": failure}
        finally:
            self.load_lock.release()

            async with self.load_condition:
                self.load_condition.notify_all()

    async def unload(self, loras_only: bool = False, **kwargs):
        """Free all VRAM resources used by the model (and loras)."""

        # Shutdown immediately unloads and bypasses all locks
        do_shutdown = kwargs.get("shutdown")

        try:
            if not do_shutdown:
                await self.load_lock.acquire()

                # Wait for other jobs to finish
                await self.wait_for_jobs(kwargs.get("skip_wait"))

            # Delete references held in the grammar module
            clear_grammar_func_cache()

            # Clear the image embedding cache
            clear_image_embedding_cache()

            # Unload LoRAs
            if self.generator and self.generator.generator.current_loras:
                for lora in self.generator.generator.current_loras:
                    lora.unload()

                self.generator.generator.set_loras([])

            # Unload the entire model if not just unloading loras
            if not loras_only:
                if self.model:
                    self.model.unload()
                self.model = None

                if self.vision_model:
                    self.vision_model.unload()
                self.vision_model = None

                if self.draft_model:
                    self.draft_model.unload()
                self.draft_model = None

                self.config = None
                self.cache = None
                self.tokenizer = None

                # Cleanup the generator from any pending jobs
                if self.generator is not None:
                    await self.generator.close()
                    self.generator = None

                # Set all model state variables to False
                self.loaded = False

            gc.collect()
            torch.cuda.empty_cache()
            logger.info("Loras unloaded." if loras_only else "Model unloaded.")
        finally:
            if not do_shutdown:
                self.load_lock.release()

                async with self.load_condition:
                    self.load_condition.notify_all()

    def encode_tokens(self, text: str, **kwargs):
        """Wrapper to encode tokens from a text string."""

        mm_embeddings: MultimodalEmbeddingWrapper = kwargs.get("embeddings")
        mm_embeddings_content = mm_embeddings.content if mm_embeddings else []

        return (
            self.tokenizer.encode(
                text,
                add_bos=unwrap(
                    kwargs.get("add_bos_token"), self.hf_model.add_bos_token()
                ),
                encode_special_tokens=unwrap(kwargs.get("encode_special_tokens"), True),
                embeddings=mm_embeddings_content,
            )
            .flatten()
            .tolist()
        )

    def decode_tokens(self, ids: List[int], **kwargs):
        """Wrapper to decode tokens from a list of IDs."""

        ids = torch.tensor([ids])
        return self.tokenizer.decode(
            ids,
            decode_special_tokens=unwrap(kwargs.get("decode_special_tokens"), True),
        )[0]

    def get_special_tokens(self):
        """Return the tokenizer's special tokens."""

        return {
            "bos_token": self.tokenizer.bos_token,
            "eos_token": self.tokenizer.eos_token,
            "pad_token": self.tokenizer.pad_token,
            "unk_token": self.tokenizer.unk_token,
        }

    def get_logprobs(self, token_ids: torch.Tensor, token_probs: torch.Tensor):
        """Map top token pieces to their log probabilities."""

        top_tokens = [
            self.tokenizer.extended_id_to_piece.get(
                index, self.tokenizer.get_id_to_piece_list(True)[index]
            )
            for index in token_ids.flatten().tolist()
        ]

        top_values = torch.log(token_probs).flatten().tolist()

        # Cannot return -inf in JSON
        cleaned_values = [
            -1000 if value == float("-inf") else value for value in top_values
        ]

        return dict(zip_longest(top_tokens, cleaned_values))

    async def generate(
        self,
        request_id: str,
        prompt: str,
        params: BaseSamplerRequest,
        abort_event: Optional[asyncio.Event] = None,
        mm_embeddings: Optional[MultimodalEmbeddingWrapper] = None,
    ):
        """Generate a response to a prompt."""

        generations = []
        async for generation in self.stream_generate(
            request_id,
            prompt,
            params,
            abort_event,
            mm_embeddings,
        ):
            generations.append(generation)

        joined_generation = {
            "text": "",
            "prompt_tokens": 0,
            "gen_tokens": 0,
            "tool_calls": None,
            "offset": [],
            "token_probs": {},
            "logprobs": [],
        }

        if generations:
            # Get finish_reason first and then shift where -1 points to
            if "finish_reason" in generations[-1]:
                finish_chunk = generations.pop()
                joined_generation = {**joined_generation, **finish_chunk}
            else:
                joined_generation["finish_reason"] = "stop"

        if len(generations) > 0:
            for generation in generations:
                joined_generation["text"] += unwrap(generation.get("text"), "")
                joined_generation["offset"].append(unwrap(generation.get("offset"), -1))
                joined_generation["token_probs"].update(
                    unwrap(generation.get("token_probs"), {})
                )

                # Include empty logprob dicts for index preservation
                joined_generation["logprobs"].append(
                    unwrap(generation.get("logprobs"), {})
                )

            joined_generation["prompt_tokens"] = unwrap(
                generations[-1].get("prompt_tokens"), 0
            )
            joined_generation["generated_tokens"] = unwrap(
                generations[-1].get("generated_tokens"), 0
            )

        return joined_generation

    async def stream_generate(
        self,
        request_id: str,
        prompt: str,
        params: BaseSamplerRequest,
        abort_event: Optional[asyncio.Event] = None,
        mm_embeddings: Optional[MultimodalEmbeddingWrapper] = None,
    ):
        """Stream generation chunks for a prompt as they are produced."""

        try:
            # Wait for load lock to be freed before processing
            # Mainly used for loras and other operations where the class is available
            async with self.load_condition:
                await self.load_condition.wait_for(lambda: not self.load_lock.locked())

            # If the model is being unloaded, don't accept new requests
            if not self.loaded:
                raise RuntimeError(
                    "Model is being unloaded. Cannot process new generation requests."
                )

            # Mark that the job is running
            self.active_job_ids[request_id] = None

            # Yield from the internal generator
            async for generation_chunk in self.generate_gen(
                request_id=request_id,
                prompt=prompt,
                params=params,
                abort_event=abort_event,
                mm_embeddings=mm_embeddings,
            ):
                yield generation_chunk
        finally:
            # Clean up and remove the job from active IDs
            del self.active_job_ids[request_id]

    def check_unsupported_settings(self, params: BaseSamplerRequest):
        """
        Check and warn the user if a sampler is unsupported.

        Meant for dev wheels!
        """

        return params

    def assign_gen_params(
        self,
        params: BaseSamplerRequest,
        gen_settings: ExLlamaV2Sampler.Settings,
        grammar_handler: ExLlamaV2Grammar,
    ):
        """Map sampler request parameters onto ExLlamaV2 sampler settings."""

        # Apply settings
        gen_settings.temperature = params.temperature
        gen_settings.temperature_last = params.temperature_last
        gen_settings.smoothing_factor = params.smoothing_factor
        gen_settings.top_k = params.top_k
        gen_settings.top_p = params.top_p
        gen_settings.top_a = params.top_a
        gen_settings.min_p = params.min_p
        gen_settings.tfs = params.tfs
        gen_settings.typical = params.typical
        gen_settings.mirostat = params.mirostat_mode == 2
        gen_settings.skew = params.skew

        # XTC
        if params.xtc_probability > 0.0:
            gen_settings.xtc_probability = params.xtc_probability

            # 0.1 is the default for this value
            gen_settings.xtc_threshold = params.xtc_threshold

        # DynaTemp settings
        max_temp = params.max_temp
        min_temp = params.min_temp

        if params.max_temp > params.min_temp:
            gen_settings.max_temp = max_temp
            gen_settings.min_temp = min_temp
            gen_settings.temp_exponent = params.temp_exponent
        else:
            # Force to default values
            gen_settings.max_temp = 1.0
            gen_settings.min_temp = 1.0
            gen_settings.temp_exponent = 1.0

            # Warn if max/min temp values are > 0
            # and if they're less than or equal to each other
            if max_temp < min_temp or (
                1 not in {min_temp, max_temp} and max_temp == min_temp
            ):
                logger.warning(
                    "Max temp is less than or equal to min temp, skipping DynaTemp."
                )

        # Default tau and eta fallbacks don't matter if mirostat is off
        gen_settings.mirostat_tau = params.mirostat_tau
        gen_settings.mirostat_eta = params.mirostat_eta

        # Penalties
        gen_settings.token_repetition_penalty = params.repetition_penalty
        gen_settings.token_frequency_penalty = params.frequency_penalty
        gen_settings.token_presence_penalty = params.presence_penalty

        # Applies for all penalties despite being called token_repetition_range
        gen_settings.token_repetition_range = unwrap(
            params.penalty_range, self.config.max_seq_len
        )

        # Always make sure the fallback is 0 if range < 0
        # It's technically fine to use -1, but this just validates the passed
        # fallback
        # Always default to 0 if something goes wrong
        if gen_settings.token_repetition_range < 0:
            fallback_decay = 0
        else:
            fallback_decay = gen_settings.token_repetition_range

        gen_settings.token_repetition_decay = coalesce(
            params.repetition_decay, fallback_decay, 0
        )

        # DRY options
        dry_multiplier = params.dry_multiplier

        # < 0 = disabled
        if dry_multiplier > 0:
            gen_settings.dry_multiplier = dry_multiplier
            gen_settings.dry_allowed_length = params.dry_allowed_length
            gen_settings.dry_base = params.dry_base

            # Exl2 has dry_range as 0 for unlimited unlike -1 for penalty_range
            # Use max_seq_len as the fallback to stay consistent
            gen_settings.dry_range = unwrap(params.dry_range, self.config.max_seq_len)

            # Tokenize sequence breakers
            if params.dry_sequence_breakers:
                gen_settings.dry_sequence_breakers = {
                    self.encode_tokens(s)[-1] for s in params.dry_sequence_breakers
                }

        # Add JSON schema filter if it exists
        if params.json_schema:
            grammar_handler.add_json_schema_filter(
                params.json_schema, self.model, self.tokenizer
            )

        # Add regex filter if it exists
        if params.regex_pattern:
            grammar_handler.add_regex_filter(
                params.regex_pattern, self.model, self.tokenizer
            )

        # Add EBNF filter if it exists
        if params.grammar_string:
            grammar_handler.add_kbnf_filter(
                params.grammar_string, self.model, self.tokenizer
            )

        # Speculative Ngram
        self.generator.speculative_ngram = params.speculative_ngram

        # Override sampler settings for temp = 0
        if gen_settings.temperature == 0:
            gen_settings.temperature = 1.0
            gen_settings.top_k = 1
            gen_settings.top_p = 0
            gen_settings.typical = 0

            logger.warning(
                "Temperature is set to 0. Overriding temp, "
                "top_k, top_p, and typical to 1.0, 1, 0, and 0."
            )

        # Set banned tokens
        if params.banned_tokens:
            gen_settings.disallow_tokens(self.tokenizer, params.banned_tokens)

        # Set allowed tokens
        if params.allowed_tokens:
            gen_settings.allow_tokens(self.tokenizer, params.allowed_tokens)

        # Set logit bias
        if params.logit_bias:
            # Create a vocab tensor if it doesn't exist for token biasing
            if gen_settings.token_bias is None:
                padding = -self.tokenizer.config.vocab_size % 32
                gen_settings.token_bias = torch.zeros(
                    (self.tokenizer.config.vocab_size + padding,),
                    dtype=torch.float,
                )

            # Map logits to the tensor with their biases
            for token_id, bias in params.logit_bias.items():
                if 0 <= token_id < len(self.tokenizer.get_id_to_piece_list(True)):
                    gen_settings.token_bias[token_id] = bias
                else:
                    logger.warning(
                        f"Logit bias: Token {token_id} not present "
                        "in the model's vocab. Skipping."
                    )

    # Adds logprobs to a generation chunk
    def handle_logprobs(self, result: dict, generation: dict):
        top_tokens = unwrap(
            result.get("top_k_tokens"),
            torch.empty((1, 0, 1), dtype=torch.long),
        )

        top_probs = unwrap(
            result.get("top_k_probs"),
            torch.empty((1, 0, 1), dtype=torch.float),
        )

        if top_tokens.numel() > 0 and top_probs.numel() > 0:
            logprobs = self.get_logprobs(top_tokens, top_probs)
            generation["logprobs"] = logprobs

            # The first logprob is the selected token prob
            generation["token_probs"] = {
                token: logprobs[token] for token in list(logprobs.keys())[:1]
            }

    # Creates and returns a finish chunk
    def handle_finish_chunk(self, result: dict, generation: dict):
        eos_reason = result.get("eos_reason")
        stop_str = None

        if eos_reason == "max_new_tokens":
            finish_reason = "length"
        else:
            finish_reason = "stop"

            # Grab stop string if stop was the reason
            if eos_reason == "stop_token":
                stop_str = result.get("eos_triggering_token_str")
            elif eos_reason == "stop_string":
                stop_str = result.get("eos_triggering_string")

        # Prompt
        prompt_tokens = result.get("prompt_tokens")
        cached_tokens = round(result.get("cached_tokens"), 2)
        prompt_time = round(result.get("time_prefill"), 2)
        prompt_ts = (
            "Indeterminate"
            if prompt_time == 0
            else round((prompt_tokens - cached_tokens) / prompt_time, 2)
        )

        # Generated
        gen_tokens = result.get("new_tokens")
        gen_time = result.get("time_generate")
        gen_ts = "Indeterminate" if gen_time == 0 else round(gen_tokens / gen_time, 2)

        # Queue + Total
        queue_time = result.get("time_enqueued")
        total_time = round(queue_time + prompt_time + gen_time, 2)

        finish_chunk = {
            "prompt_tokens": prompt_tokens,
            "prompt_time": round(prompt_time, 2),
            "prompt_tokens_per_sec": prompt_ts,
            "gen_tokens": gen_tokens,
            "gen_time": round(gen_time, 2),
            "gen_tokens_per_sec": gen_ts,
            "total_time": total_time,
            "queue_time": round(queue_time, 2),
            "cached_tokens": cached_tokens,
            "finish_reason": finish_reason,
            "stop_str": stop_str,
        }

        return finish_chunk

    async def generate_gen(
        self,
        request_id: str,
        prompt: str,
        params: BaseSamplerRequest,
        abort_event: Optional[asyncio.Event] = None,
        mm_embeddings: Optional[MultimodalEmbeddingWrapper] = None,
    ):
        """
        Create generator function for prompt completion.

        For kwargs, check common/sampling.py
        """

        prompts = [prompt]
        gen_settings = ExLlamaV2Sampler.Settings()
        grammar_handler = ExLlamaV2Grammar()

        self.assign_gen_params(
            params,
            gen_settings,
            grammar_handler,
        )

        # Set banned strings
        banned_strings = params.banned_strings
        if banned_strings and len(grammar_handler.filters) > 0:
            logger.warning(
                "Disabling banned_strings because "
                "they cannot be used with grammar filters."
            )

            banned_strings = []

        # Set CFG scale and negative prompt
        cfg_scale = params.cfg_scale
        negative_prompt = None
        if cfg_scale not in [None, 1.0]:
            if self.paged:
                gen_settings.cfg_scale = cfg_scale

                # If the negative prompt is empty, use the BOS token
                negative_prompt = unwrap(
                    params.negative_prompt, self.tokenizer.bos_token
                )

                prompts.append(negative_prompt)
            else:
                logger.warning(
                    "CFG is currently disabled because paged mode is disabled. "
                    "Please use an Ampere (30 series) or higher GPU for CFG support."
                )

        # Dynamically scale penalty range to output tokens
        # Only do this if freq/pres pen is enabled
        # and the repetition range is -1
        auto_scale_penalty_range = (
            gen_settings.token_frequency_penalty != 0
            or gen_settings.token_presence_penalty != 0
        ) and gen_settings.token_repetition_range == -1

        stop_conditions = params.stop
        ban_eos_token = params.ban_eos_token

        # Set add_bos_token for generation
        add_bos_token = unwrap(params.add_bos_token, self.hf_model.add_bos_token())

        # Fetch EOS tokens from the HF model if they exist
        eos_tokens = self.hf_model.eos_tokens() or [self.tokenizer.eos_token_id]

        # Ban the EOS token if specified. If not, append to stop conditions
        # as well.
        # Set this below logging to avoid polluting the stop strings array
        if ban_eos_token:
            gen_settings.disallow_tokens(self.tokenizer, eos_tokens)
        else:
            stop_conditions += eos_tokens

        # Get multimodal embeddings if present
        mm_embeddings_content = mm_embeddings.content if mm_embeddings else []

        # Encode both positive and negative prompts
        input_ids = [
            self.tokenizer.encode(
                prompt,
                add_bos=add_bos_token,
                encode_special_tokens=True,
                embeddings=mm_embeddings_content,
            )
            for prompt in prompts
        ]

        # The first index will always be the positive prompt
        context_len = input_ids[0].size(dim=-1)

        # The second index will be the negative prompt if CFG is enabled
        negative_context_len = input_ids[1].size(dim=-1) if negative_prompt else 0

        # Automatically set max_tokens to fill up the context
        # This should be an OK default, but may be changed in the future
        max_tokens = unwrap(
            params.max_tokens,
            self.config.max_seq_len - max(context_len, negative_context_len),
        )
        if max_tokens < 1:
            logger.warning("max_tokens must be a positive integer, setting to 1.")
            max_tokens = 1

        # Determine if the negative context or the context length is bigger
        context_to_check = max(negative_context_len, context_len)

        # Check total length of prompt against max context length
        if context_to_check > self.config.max_seq_len:
            preamble = (
                "Negative prompt" if negative_context_len > context_len else "Prompt"
            )

            raise ValueError(
                f"{preamble} length {context_to_check} is greater than "
                f"max_seq_len {self.config.max_seq_len}"
            )

        # Check total required pages for CFG request to avoid overallocation
        if negative_prompt and (
            sum(
                256 * math.ceil((context + max_tokens) / 256)
                for context in (context_len, negative_context_len)
            )
            > self.cache_size
        ):
            raise ValueError(
                f"Total required page size for request "
                f"{context_len} + {negative_context_len} + {max_tokens} * 2 "
                f"is greater than cache_size {self.cache_size}"
            )

        # Log prompt to console. Add the BOS token if specified
        log_prompt(
            f"{self.tokenizer.bos_token if add_bos_token else ''}{prompt}",
            request_id,
            negative_prompt,
        )

        # Create and add a new job
        # Don't use the request ID here as there can be multiple jobs per request
        job = ExLlamaV2DynamicJobAsync(
            self.generator,
            input_ids=input_ids,
            max_new_tokens=max_tokens,
            min_new_tokens=params.min_tokens,
            gen_settings=gen_settings,
            stop_conditions=stop_conditions,
            decode_special_tokens=True,
            filters=grammar_handler.filters,
            filter_prefer_eos=bool(grammar_handler.filters),
            return_probs=params.logprobs > 0,
            return_top_tokens=params.logprobs,
            return_logits=params.logprobs > 0,
            banned_strings=banned_strings,
            token_healing=params.token_healing,
            identifier=request_id,
            embeddings=mm_embeddings_content,
        )

        # Assign the active job to the request ID
        self.active_job_ids[request_id] = job

        # Save generated tokens and full response
        # Copy over max seq len in case the model is unloaded and stored jobs complete
        # Full response is required for offset calculation
        max_seq_len = self.config.max_seq_len
        generated_tokens = 0
        full_response = ""
        metrics_result = {}

        # Get the generation status once it's ready
        try:
            async for result in job:
                # Abort if the event is set while streaming
                if abort_event and abort_event.is_set():
                    await job.cancel()
                    break

                stage = result.get("stage")
                result_id = result.get("identifier")

                if stage == "streaming" and result_id == request_id:
                    chunk = unwrap(result.get("text"), "")
                    full_response += chunk

                    chunk_tokens = result.get("token_ids")
                    if chunk_tokens is not None:
                        generated_tokens += chunk_tokens.size(dim=0)

                    generation = {
                        "text": chunk,
                        "prompt_tokens": context_len,
                        "generated_tokens": generated_tokens,
                        "offset": len(full_response),
                    }

                    # Increase penalty range to generated token amount
                    if auto_scale_penalty_range:
                        gen_settings.token_repetition_range = generated_tokens

                    # Handle logprobs
                    if params.logprobs > 0:
                        self.handle_logprobs(result, generation)

                    yield generation

                    # Yield a finish chunk when generation is finished
                    if result.get("eos"):
                        log_response(request_id, full_response)

                        finish_chunk = self.handle_finish_chunk(result, generation)

                        # Save the final result for metrics logging
                        metrics_result = finish_chunk

                        yield finish_chunk
                        break
        except asyncio.CancelledError:
            await job.cancel()
        except Exception as ex:
            # Create a new generator since the current state is broken
            # No need to wait for this to finish
            logger.error(
                "FATAL ERROR with generation. "
                "Attempting to recreate the generator. "
                "If this fails, please restart the server.\n"
            )
            asyncio.ensure_future(self.create_generator())

            await HealthManager.add_unhealthy_event(ex)

            raise ex
        finally:
            # Log generation options to console
            # Some options are too large, so log the args instead
            log_generation_params(
                request_id=request_id,
                bos_token_id=self.tokenizer.bos_token_id,
                eos_token_id=eos_tokens,
                prompt=prompt,
                **params.model_dump(exclude={"prompt"}),
                auto_scale_penalty_range=auto_scale_penalty_range,
            )

            # Log the metrics if present
            if metrics_result:
                log_metrics(
                    request_id,
                    metrics_result,
                    context_len,
                    max_seq_len,
                )
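

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only; kept commented out and not part of
# the container API). The model path, HFModel construction, and kwargs below
# are placeholders: adapt them to your own setup and to
# common.transformers_utils.HFModel, whose loading helpers are outside the
# scope of this module. The call order shown is create() -> load() ->
# generate() -> unload().
#
# async def _example():
#     model_dir = pathlib.Path("models/my-exl2-model")  # hypothetical path
#     hf_model = ...  # build an HFModel for model_dir (see common.transformers_utils)
#     container = await ExllamaV2Container.create(model_dir, hf_model, cache_mode="Q4")
#     await container.load()
#     result = await container.generate("req-1", "Hello!", BaseSamplerRequest())
#     print(result["text"])
#     await container.unload()
# ---------------------------------------------------------------------------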