diff --git a/backends/exllamav3/model.py b/backends/exllamav3/model.py
index b4f6fdc..06d3b29 100644
--- a/backends/exllamav3/model.py
+++ b/backends/exllamav3/model.py
@@ -14,7 +14,6 @@ from exllamav3 import (
     AsyncGenerator,
     AsyncJob,
     Cache,
-    ComboSampler,
     Config,
     Model,
     Tokenizer,
@@ -22,6 +21,7 @@ from exllamav3 import (
 from loguru import logger

 from backends.base_model_container import BaseModelContainer
+from backends.exllamav3.sampler import ExllamaV3SamplerBuilder
 from common.concurrency import iterate_in_threadpool
 from common.gen_logging import (
     log_generation_params,
@@ -32,7 +32,7 @@ from common.multimodal import MultimodalEmbeddingWrapper
 from common.sampling import BaseSamplerRequest
 from common.templating import PromptTemplate, find_prompt_template
 from common.transformers_utils import GenerationConfig
-from common.utils import unwrap
+from common.utils import coalesce, unwrap
 from endpoints.core.types.model import ModelCard


@@ -58,6 +58,7 @@ class ExllamaV3Container(BaseModelContainer):
     cache: Cache
     tokenizer: Tokenizer
     config: Config
+    generator: Optional[AsyncGenerator] = None
     gpu_split: List[float] | None = None
     gpu_split_auto: bool = True
     autosplit_reserve: List[float] = [96 / 1024]
@@ -123,13 +124,47 @@ class ExllamaV3Container(BaseModelContainer):
         # Reserve VRAM for each GPU
         self.autosplit_reserve = [
-            value/1024
-            for value in autosplit_reserve_megabytes
+            value / 1024 for value in autosplit_reserve_megabytes
         ]

         # TODO: speculative decoding

         return self

+    def model_info(self) -> ModelCard:
+        """
+        Returns a dictionary of the current model's configuration parameters.
+
+        Returns:
+            Model parameters provided by the backend
+        """
+
+        pass
+
+    async def wait_for_jobs(self, skip_wait: bool = False):
+        """
+        Polling to wait for any active generation jobs to complete.
+
+        Args:
+            skip_wait: If True, cancel jobs immediately instead of waiting.
+        """
+
+        if not self.generator:
+            return
+
+        # Immediately abort all jobs if asked
+        if skip_wait:
+            logger.warning(
+                "Immediately terminating all jobs. "
+                "Clients will have their requests cancelled.\n"
+            )
+
+            for job in self.active_job_ids.values():
+                if job:
+                    await job.cancel()
+
+        while len(self.active_job_ids) > 0:
+            await asyncio.sleep(0.01)
+
     async def load(self, progress_callback=None, **kwargs):
         """
         Loads the model into memory.
@@ -161,8 +196,8 @@ class ExllamaV3Container(BaseModelContainer):
         await self.wait_for_jobs(kwargs.get("skip_wait"))

         generator = self.load_model_sync(progress_callback)
-        async for module, modules in iterate_in_threadpool(generator):
-            yield module, modules
+        async for value in iterate_in_threadpool(generator):
+            yield value

         # Clean up any extra vram usage from torch and cuda
         # (Helps reduce VRAM bottlenecking on Windows)
@@ -184,11 +219,42 @@ class ExllamaV3Container(BaseModelContainer):
         for value in self.model.load_gen(
             reserve_per_device=self.autosplit_reserve,
             use_per_device=self.gpu_split,
-            callback=progress_callback
+            callback=progress_callback,
         ):
             if value:
                 yield value

+    async def create_generator(self):
+        """Create and save an Exllama generator class."""
+
+        try:
+            # Don't acquire locks unless a model is loaded
+            if self.loaded:
+                await self.load_lock.acquire()
+
+                # Immediately cancel all jobs
+                await self.wait_for_jobs(skip_wait=True)
+
+            # Create new generator
+            self.generator = AsyncGenerator(
+                model=self.model,
+                cache=self.cache,
+                tokenizer=self.tokenizer,
+                max_batch_size=self.max_batch_size,
+            )
+
+            # Update the state of the container var
+            if self.max_batch_size is None:
+                self.max_batch_size = self.generator.generator.max_batch_size
+        finally:
+            # This means the generator is being recreated
+            # The load lock is already released in the load function
+            if self.loaded:
+                self.load_lock.release()
+
+                async with self.load_condition:
+                    self.load_condition.notify_all()
+
     async def unload(self, loras_only: bool = False, **kwargs):
         """
         Unloads the model and associated resources from memory.
@@ -198,11 +264,15 @@
             **kwargs: Additional unloading options (e.g., shutdown).
         """

-        try:
-            await self.load_lock.acquire()
+        # Used when shutting down the server
+        do_shutdown = kwargs.get("shutdown")

-            # Wait for other jobs to finish
-            await self.wait_for_jobs(kwargs.get("skip_wait"))
+        try:
+            if not do_shutdown:
+                await self.load_lock.acquire()
+
+                # Wait for other jobs to finish
+                await self.wait_for_jobs(kwargs.get("skip_wait"))

             self.model.unload()
             self.model = None
@@ -211,15 +281,21 @@
             self.cache = None
             self.tokenizer = None

+            # Cleanup the generator from any pending jobs
+            if self.generator is not None:
+                await self.generator.close()
+                self.generator = None
+
             gc.collect()
             torch.cuda.empty_cache()

             logger.info("Model unloaded.")
         finally:
-            self.load_lock.release()
+            if not do_shutdown:
+                self.load_lock.release()

-        async with self.load_condition:
-            self.load_condition.notify_all()
+                async with self.load_condition:
+                    self.load_condition.notify_all()

     def encode_tokens(self, text: str, **kwargs) -> List[int]:
         """
@@ -233,11 +309,15 @@
             A list of integer token IDs.
         """

-        return self.tokenizer.encode(
-            text,
-            add_bos=unwrap(kwargs.get("add_bos_token"), True),
-            encode_special_tokens=unwrap(kwargs.get("encode_special_tokens"), True),
-        ).flatten().tolist()
+        return (
+            self.tokenizer.encode(
+                text,
+                add_bos=unwrap(kwargs.get("add_bos_token"), True),
+                encode_special_tokens=unwrap(kwargs.get("encode_special_tokens"), True),
+            )
+            .flatten()
+            .tolist()
+        )

     def decode_tokens(self, ids: List[int], **kwargs) -> str:
         """
@@ -278,26 +358,6 @@ class ExllamaV3Container(BaseModelContainer):
             "unk_token": self.tokenizer.unk_token,
         }

-    def model_info(self) -> ModelCard:
-        """
-        Returns a dictionary of the current model's configuration parameters.
-
-        Returns:
-            Model parameters provided by the backend
-        """
-
-        pass
-
-    async def wait_for_jobs(self, skip_wait: bool = False):
-        """
-        Waits for any active generation jobs to complete.
-
-        Args:
-            skip_wait: If True, cancel jobs immediately instead of waiting.
-        """
-
-        pass
-
     async def generate(
         self,
         request_id: str,
@@ -446,37 +506,6 @@ class ExllamaV3Container(BaseModelContainer):

         return finish_chunk

-    async def create_generator(self):
-        """Create and save a Exllama generator class."""
-
-        try:
-            # Don't acquire locks unless a model is loaded
-            if self.loaded:
-                await self.load_lock.acquire()
-
-                # Immediately cancel all jobs
-                await self.wait_for_jobs(skip_wait=True)
-
-            # Create new generator
-            self.generator = AsyncGenerator(
-                model=self.model,
-                cache=self.cache,
-                tokenizer=self.tokenizer,
-                max_batch_size=self.max_batch_size,
-            )
-
-            # Update the state of the container var
-            if self.max_batch_size is None:
-                self.max_batch_size = self.generator.generator.max_batch_size
-        finally:
-            # This means the generator is being recreated
-            # The load lock is already released in the load function
-            if self.loaded:
-                self.load_lock.release()
-
-                async with self.load_condition:
-                    self.load_condition.notify_all()
-
     async def generate_gen(
         self,
         request_id: str,
@@ -492,25 +521,58 @@
         """

         chunk_tokens: torch.Tensor | tuple[torch.Tensor, torch.Tensor]
-        # FIXME: this is probably not right
-        base_sampler = BaseSamplerRequest()
-        sampler = ComboSampler(
-            rep_p=base_sampler.repetition_penalty,
-            pres_p=base_sampler.presence_penalty,
-            freq_p=base_sampler.frequency_penalty,
-            rep_sustain_range=base_sampler.penalty_range,
-            rep_decay_range=base_sampler.penalty_range,
-            temperature=base_sampler.temperature,
-            min_p=base_sampler.min_p,
-            top_k=base_sampler.top_k,
-            top_p=base_sampler.top_p,
-            temp_last=base_sampler.temperature_last,
+        sampler_builder = ExllamaV3SamplerBuilder()
+
+        # Penalties
+
+        # Set penalty range
+        penalty_range = unwrap(params.penalty_range, self.max_seq_len)
+
+        # Exl3's version of including the entire context
+        if penalty_range < 0:
+            penalty_range = 10e7
+
+        # Always make sure the fallback is 0 if range < 0
+        # It's technically fine to use -1, but this just validates the passed
+        # fallback
+        # Always default to 0 if something goes wrong
+        if params.penalty_range < 0:
+            fallback_decay = 0
+        else:
+            fallback_decay = params.penalty_range
+
+        repetition_decay = coalesce(params.repetition_decay, fallback_decay, 0)
+
+        # Apply penalties to builder
+        sampler_builder.penalties(
+            params.repetition_penalty,
+            params.frequency_penalty,
+            params.presence_penalty,
+            penalty_range,
+            repetition_decay,
         )
+        # Apply temperature first to builder
+        if not params.temperature_last:
+            sampler_builder.temperature(params.temperature)
+
+        # Apply alphabet samplers to builder
+        sampler_builder.top_k(params.top_k)
+        sampler_builder.top_p(params.top_p)
+        sampler_builder.min_p(params.min_p)
+
+        # Apply temperature last to builder
+        if params.temperature_last:
+            sampler_builder.temperature(params.temperature)
+
+        # Build the sampler
+        # Set greedy if temperature is 0
+        sampler = sampler_builder.build(params.temperature == 0)
+

         # Dynamically scale penalty range to output tokens
         # Only do this if freq/pres pen is enabled
         # and the repetition range is -1
-        # TODO:
+        # TODO: This currently does not work in exl3
         # auto_scale_penalty_range = (
         #     gen_settings.token_frequency_penalty != 0
         #     or gen_settings.token_presence_penalty != 0
@@ -576,6 +638,7 @@ class ExllamaV3Container(BaseModelContainer):
             input_ids=self.tokenizer.encode(prompt, add_bos=False),
             max_new_tokens=max_tokens,
             stop_conditions=stop_conditions,
+            banned_strings=params.banned_strings,
         )

         generated_tokens = 0
@@ -585,6 +648,11 @@ class ExllamaV3Container(BaseModelContainer):
         # Get the generation status once it's ready
         try:
             async for result in job:
+                # Abort if the event is set while streaming
+                if abort_event and abort_event.is_set():
+                    await job.cancel()
+                    break
+
                 chunk = unwrap(result.get("text"), "")
                 if chunk:
                     chunk_tokens = result.get("token_ids", self.tokenizer.encode(chunk))
diff --git a/common/config_models.py b/common/config_models.py
index d2af39e..2965fe0 100644
--- a/common/config_models.py
+++ b/common/config_models.py
@@ -163,8 +163,10 @@ class ModelConfig(BaseConfigModel):
             "Example: ['max_seq_len', 'cache_mode']."
         ),
     )
+
+    # Defaults to exllamav2 in common/model.py
     backend: Optional[str] = Field(
-        "exllamav2",
+        None,
         description=(
             "Backend to use for this model (default: exllamav2)\n"
             "Options: exllamav2, exllamav3",
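
Note: backends/exllamav3/sampler.py (the source of the new ExllamaV3SamplerBuilder import) is not part of this patch, so only its call surface is visible above: penalties(...), temperature(...), top_k(...), top_p(...), min_p(...) and build(greedy). The sketch below is a hypothetical builder with that surface, shown only to illustrate how generate_gen drives it; the method internals and the mapping onto exllamav3's real sampler classes inside build() are assumptions, not the actual implementation.

from dataclasses import dataclass, field
from typing import Any, Dict, List, Tuple


@dataclass
class ExllamaV3SamplerBuilder:
    """Hypothetical builder: records sampling steps in application order."""

    steps: List[Tuple[str, Dict[str, Any]]] = field(default_factory=list)

    def penalties(self, rep_p, freq_p, pres_p, penalty_range, repetition_decay):
        # Argument order mirrors the positional call in generate_gen above
        self.steps.append(
            (
                "penalties",
                {
                    "rep_p": rep_p,
                    "freq_p": freq_p,
                    "pres_p": pres_p,
                    "sustain_range": penalty_range,
                    "decay_range": repetition_decay,
                },
            )
        )

    def temperature(self, temperature: float):
        self.steps.append(("temperature", {"temperature": temperature}))

    def top_k(self, top_k: int):
        self.steps.append(("top_k", {"top_k": top_k}))

    def top_p(self, top_p: float):
        self.steps.append(("top_p", {"top_p": top_p}))

    def min_p(self, min_p: float):
        self.steps.append(("min_p", {"min_p": min_p}))

    def build(self, greedy: bool = False):
        # A real implementation would translate self.steps into exllamav3
        # sampler objects here; greedy (temperature == 0) bypasses the chain
        # and argmax-samples instead.
        if greedy:
            return ("greedy", [])
        return ("chain", list(self.steps))

The ordering decisions in the patch (temperature applied before or after the top-k/top-p/min-p steps depending on params.temperature_last, and a greedy build when temperature == 0) live in generate_gen itself, so any builder with this surface would pick them up unchanged.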
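Likewise, the switch from "from common.utils import unwrap" to "from common.utils import coalesce, unwrap" is only visible through the call coalesce(params.repetition_decay, fallback_decay, 0). Assuming SQL-style semantics (return the first argument that is not None), a minimal equivalent helper would be:

from typing import Any, Optional


def coalesce(*args: Optional[Any]) -> Any:
    """Return the first argument that is not None, or None if all are None."""
    return next((arg for arg in args if arg is not None), None)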