diff --git a/backends/exllamav2/model.py b/backends/exllamav2/model.py
index fcf4f3c..09da9a2 100644
--- a/backends/exllamav2/model.py
+++ b/backends/exllamav2/model.py
@@ -533,8 +533,7 @@ class ExllamaV2Container(BaseModelContainer):
         # Load draft model if a config is present
         if self.draft_config:
             self.draft_model = ExLlamaV2(self.draft_config)
-            if not self.quiet:
-                logger.info("Loading draft model: " + self.draft_config.model_dir)
+            logger.info("Loading draft model: " + self.draft_config.model_dir)
 
             # Draft uses the autosplit loader, so create a cache that reflects this
             draft_cache_class = self.get_cache_class(self.draft_cache_mode)
@@ -587,8 +586,7 @@ class ExllamaV2Container(BaseModelContainer):
                 yield value
 
         self.model = ExLlamaV2(self.config)
-        if not self.quiet:
-            logger.info("Loading model: " + self.config.model_dir)
+        logger.info("Loading model: " + self.config.model_dir)
 
         # Get class of the model cache
         cache_class = self.get_cache_class(self.cache_mode)
diff --git a/backends/exllamav3/model.py b/backends/exllamav3/model.py
index a9781ec..fc4f198 100644
--- a/backends/exllamav3/model.py
+++ b/backends/exllamav3/model.py
@@ -16,12 +16,12 @@ from backends.base_model_container import BaseModelContainer
 from common.concurrency import iterate_in_threadpool
 from common.multimodal import MultimodalEmbeddingWrapper
 from common.sampling import BaseSamplerRequest
-from common.templating import PromptTemplate
+from common.templating import PromptTemplate, find_prompt_template
 from common.transformers_utils import GenerationConfig
 from common.utils import unwrap
 from endpoints.core.types.model import ModelCard
 
-from exllamav3 import Config, Model, Cache, Tokenizer
+from exllamav3 import AsyncGenerator, AsyncJob, Config, Model, Cache, Tokenizer
 
 
 class ExllamaV3Container(BaseModelContainer):
@@ -46,6 +46,8 @@ class ExllamaV3Container(BaseModelContainer):
     cache: Cache
     tokenizer: Tokenizer
     config: Config
+    gpu_split: List[float] = []
+    max_seq_len: int = 2048
 
     # Required methods
     @classmethod
@@ -74,6 +76,16 @@ class ExllamaV3Container(BaseModelContainer):
         max_seq_len = kwargs.get("max_seq_len")
         self.cache = Cache(self.model, max_num_tokens=max_seq_len)
 
+        gpu_split = unwrap(kwargs.get("gpu_split"), [])
+
+        # Set GPU split options
+        # Enable manual GPU split if provided
+        if gpu_split:
+            self.gpu_split = gpu_split
+        # Try to set prompt template
+        self.prompt_template = await find_prompt_template(
+            kwargs.get("prompt_template"), model_directory
+        )
 
         return self
 
@@ -128,7 +140,10 @@ class ExllamaV3Container(BaseModelContainer):
     # TODO: Add draft loading
     @torch.inference_mode()
     def load_model_sync(self, progress_callback=None):
-        for value in self.model.load_gen(callback=progress_callback):
+        for value in self.model.load_gen(
+            use_per_device=self.gpu_split,
+            callback=progress_callback
+        ):
             if value:
                 yield value
 
@@ -263,7 +278,58 @@ class ExllamaV3Container(BaseModelContainer):
             A dictionary containing the generation info
         """
 
-        pass
+        generations = []
+        async for generation in self.stream_generate(
+            request_id,
+            prompt,
+            params,
+            abort_event,
+            mm_embeddings,
+        ):
+            generations.append(generation)
+
+        joined_generation = {
+            "text": "",
+            "prompt_tokens": 0,
+            "generation_tokens": 0,
+            "tool_calls": None,
+            "offset": [],
+            "token_probs": {},
+            "logprobs": [],
+        }
+
+        if generations:
+            # Get finish_reason first and then shift where -1 points to
+            if "finish_reason" in generations[-1]:
+                finish_reason_gen = generations.pop()
+                joined_generation["finish_reason"] = finish_reason_gen.get(
+                    "finish_reason"
+                )
+                joined_generation["stop_str"] = finish_reason_gen.get("stop_str")
+            else:
+                joined_generation["finish_reason"] = "stop"
+
+        if len(generations) > 0:
+            for generation in generations:
+                joined_generation["text"] += unwrap(generation.get("text"), "")
+                joined_generation["offset"].append(unwrap(generation.get("offset"), -1))
+                joined_generation["token_probs"].update(
+                    unwrap(generation.get("token_probs"), {})
+                )
+
+                # Include empty logprob dicts for index preservation
+                joined_generation["logprobs"].append(
+                    unwrap(generation.get("logprobs"), {})
+                )
+
+            joined_generation["prompt_tokens"] = unwrap(
+                generations[-1].get("prompt_tokens"), 0
+            )
+            joined_generation["generated_tokens"] = unwrap(
+                generations[-1].get("generated_tokens"), 0
+            )
+
+        return joined_generation
 
     async def stream_generate(
         self,
@@ -287,5 +353,152 @@ class ExllamaV3Container(BaseModelContainer):
             Generation chunks
         """
 
-        if False:
-            yield
+        try:
+            # Wait for load lock to be freed before processing
+            # Mainly used for loras and other operations where the class is available
+            async with self.load_condition:
+                await self.load_condition.wait_for(lambda: not self.load_lock.locked())
+
+            # If the model is being unloaded, don't accept new requests
+            if not self.loaded:
+                raise RuntimeError(
+                    "Model is being unloaded. Cannot process new generation requests."
+                )
+
+            # Mark that the job is running
+            self.active_job_ids[request_id] = None
+
+            # Yield from the internal generator
+            async for generation_chunk in self.generate_gen(
+                request_id=request_id,
+                prompt=prompt,
+                params=params,
+                abort_event=abort_event,
+                mm_embeddings=mm_embeddings,
+            ):
+                yield generation_chunk
+        finally:
+            # Clean up and remove the job from active IDs
+            del self.active_job_ids[request_id]
+
+    def handle_finish_chunk(self, result: dict, generation: dict):
+        eos_reason = result.get("eos_reason")
+
+        stop_str = None
+        if eos_reason == "max_new_tokens":
+            finish_reason = "length"
+        else:
+            finish_reason = "stop"
+            # Grab stop string if stop was the reason
+            if eos_reason == "stop_token":
+                stop_str = result.get("eos_triggering_token_str")
+            elif eos_reason == "stop_string":
+                stop_str = result.get("eos_triggering_string")
+
+        finish_chunk = {
+            "prompt_tokens": generation.get("prompt_tokens"),
+            "generated_tokens": generation.get("generated_tokens"),
+            "finish_reason": finish_reason,
+            "stop_str": stop_str,
+        }
+
+        return finish_chunk
+
+    async def generate_gen(
+        self,
+        request_id: str,
+        prompt: str,
+        params: BaseSamplerRequest,
+        abort_event: Optional[asyncio.Event] = None,
+        mm_embeddings: Optional[MultimodalEmbeddingWrapper] = None,
+    ):
+        """
+        Create generator function for prompt completion.
+
+        for kwargs, check common/sampling.py
+        """
+        chunk_tokens: torch.Tensor | tuple[torch.Tensor, torch.Tensor]
+
+        prompts = [prompt]
+        stop_conditions = params.stop
+        add_bos_token = params.add_bos_token
+
+        # Fetch EOS tokens from generation_config if they exist
+        eos_tokens = (
+            self.generation_config.eos_tokens()
+            if self.generation_config
+            else [self.tokenizer.eos_token_id]
+        )
+
+        stop_conditions += eos_tokens
+
+        input_ids = [
+            self.tokenizer.encode(
+                prompt,
+                add_bos=add_bos_token,
+                encode_special_tokens=True,
+            )
+            for prompt in prompts
+        ]
+
+        # The first index will always be the positive prompt
+        context_len = input_ids[0].size(dim=-1)
+
+        # Automatically set max_tokens to fill up the context
+        # This should be an OK default, but may be changed in the future
+        max_tokens = unwrap(
+            params.max_tokens,
+            self.max_seq_len - context_len,
+        )
+        if max_tokens < 1:
+            logger.warning("max_tokens must be a positive integer, setting to 1.")
+            max_tokens = 1
+
+        # Determine if the negative context or the context length is bigger
+        context_to_check = context_len
+
+        # Check total length of prompt against max context length
+        if context_to_check > self.max_seq_len:
+            preamble = "Prompt"
+
+            raise ValueError(
+                f"{preamble} length {context_to_check} is greater than "
+                f"max_seq_len {self.max_seq_len}"
+            )
+
+        self.generator = AsyncGenerator(
+            model=self.model,
+            cache=self.cache,
+            tokenizer=self.tokenizer,
+        )
+
+        generation = {}
+        # Create the async job, reusing the tokenized prompt from above
+        job = AsyncJob(
+            self.generator,
+            input_ids=input_ids[0],
+            max_new_tokens=max_tokens,
+            stop_conditions=stop_conditions,
+        )
+        # Assign the active job to the request ID
+        self.active_job_ids[request_id] = job
+        generated_tokens = 0
+        full_response = ""
+        async for result in job:
+            chunk = unwrap(result.get("text"), "")
+            if chunk:
+                chunk_tokens = result.get("token_ids", self.tokenizer.encode(chunk))
+                full_response += chunk
+                if isinstance(chunk_tokens, torch.Tensor):
+                    generated_tokens += chunk_tokens.size(dim=0)
+                generation = {
+                    "text": chunk,
+                    "prompt_tokens": context_len,
+                    "generated_tokens": generated_tokens,
+                    "offset": len(full_response),
+                }
+                yield generation
+
+            if result.get("eos"):
+                generation = self.handle_finish_chunk(result, generation)
+                yield generation
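Usage sketch (not part of the patch): the snippet below is a minimal illustration of how the async generation path added to ExllamaV3Container might be exercised once a container is created and loaded. The request IDs, the keyword-argument names on stream_generate/generate, and the BaseSamplerRequest construction are assumptions for illustration; the diff only shows these methods being called positionally, and the real sampler fields live in common/sampling.py.

# Illustrative sketch only, not part of the diff above.
# Assumes a loaded ExllamaV3Container and that stream_generate/generate accept
# (request_id, prompt, params) as shown; parameter names are assumptions.
import asyncio

from backends.exllamav3.model import ExllamaV3Container
from common.sampling import BaseSamplerRequest


async def demo(container: ExllamaV3Container) -> None:
    # max_tokens is referenced in the diff as params.max_tokens; constructing
    # BaseSamplerRequest this way is an assumption about its fields.
    params = BaseSamplerRequest(max_tokens=64)

    # Streaming path: each chunk carries "text", "prompt_tokens",
    # "generated_tokens", and "offset"; the last chunk adds "finish_reason".
    async for chunk in container.stream_generate(
        "demo-stream-request",  # hypothetical request ID
        "Hello, world",
        params,
    ):
        print(chunk.get("text", ""), end="", flush=True)

    # Non-streaming path: generate() collects the streamed chunks into one dict.
    result = await container.generate(
        "demo-full-request",  # hypothetical request ID
        "Hello again",
        params,
    )
    print(result["text"], result.get("finish_reason"))

# In practice this would run on the server's event loop, e.g.:
# asyncio.run(demo(container))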