diff --git a/backends/exllamav2/model.py b/backends/exllamav2/model.py
index d42e2c0..c1275df 100644
--- a/backends/exllamav2/model.py
+++ b/backends/exllamav2/model.py
@@ -6,8 +6,8 @@ import pathlib
 import threading
 import time
 import traceback
-
 import torch
+import uuid
 from exllamav2 import (
     ExLlamaV2,
     ExLlamaV2Config,
@@ -17,7 +17,11 @@ from exllamav2 import (
     ExLlamaV2Tokenizer,
     ExLlamaV2Lora,
 )
-from exllamav2.generator import ExLlamaV2StreamingGenerator, ExLlamaV2Sampler
+from exllamav2.generator import (
+    ExLlamaV2Sampler,
+    ExLlamaV2DynamicGenerator,
+    ExLlamaV2DynamicJob,
+)
 from itertools import zip_longest
 from loguru import logger
 from typing import List, Optional, Union
@@ -50,7 +54,7 @@ class ExllamaV2Container:
     cache: Optional[ExLlamaV2Cache] = None
     draft_cache: Optional[ExLlamaV2Cache] = None
     tokenizer: Optional[ExLlamaV2Tokenizer] = None
-    generator: Optional[ExLlamaV2StreamingGenerator] = None
+    generator: Optional[ExLlamaV2DynamicGenerator] = None
     prompt_template: Optional[PromptTemplate] = None
     active_loras: List[ExLlamaV2Lora] = []
 
@@ -507,18 +511,29 @@ class ExllamaV2Container:
             yield value
 
         # Test VRAM allocation with a full-length forward pass
+        """
         input_ids = torch.zeros((1, self.config.max_input_len), dtype=torch.long)
         self.model.forward(input_ids, cache=self.cache, preprocess_only=True)
+        """
+
+        # TODO: Change these!
+        max_batch_size = 1 if self.config.no_flash_attn else 20
+        paged = not self.config.no_flash_attn
 
         # Create generator
-        self.generator = ExLlamaV2StreamingGenerator(
-            self.model,
-            self.cache,
-            self.tokenizer,
-            self.draft_model,
-            self.draft_cache,
+        self.generator = ExLlamaV2DynamicGenerator(
+            model=self.model,
+            cache=self.cache,
+            draft_model=self.draft_model,
+            draft_cache=self.draft_cache,
+            tokenizer=self.tokenizer,
+            max_batch_size=max_batch_size,
+            paged=paged,
         )
 
+        # Warmup the generator
+        self.generator.warmup()
+
         # Clean up any extra vram usage from torch and cuda
         # (Helps reduce VRAM bottlenecking on Windows)
         gc.collect()
@@ -879,9 +894,6 @@ class ExllamaV2Container:
         else:
            stop_conditions += eos_tokens
 
-        # Stop conditions
-        self.generator.set_stop_conditions(stop_conditions)
-
        # Tokenized context
        ids, offsets = self.tokenizer.encode(
            [prompt, negative_prompt]
@@ -965,15 +977,75 @@ class ExllamaV2Container:
             log_prompt(prompt, negative_prompt)
 
         # Begin
+        # generated_tokens = 0
+        # full_response = ""
+        # start_time = time.time()
+        # last_chunk_time = start_time
+
+        # save_tokens = torch.empty((ids.shape[0], 0), dtype=torch.bool)
+        # chunk_buffer = ""
+        # chunk_tokens = 0
+
+        # Create and add a new job
+        job_id = uuid.uuid4().hex
+        job = ExLlamaV2DynamicJob(
+            input_ids=ids,
+            max_new_tokens=max_tokens,
+            gen_settings=gen_settings,
+            stop_conditions=stop_conditions,
+            decode_special_tokens=decode_special_tokens,
+            return_probs=request_logprobs > 0,
+            return_top_tokens=request_logprobs,
+            return_logits=request_logprobs > 0,
+            banned_strings=banned_strings,
+            identifier=job_id,
+        )
+
+        self.generator.enqueue(job)
+
+        # Save generated tokens
         generated_tokens = 0
-        full_response = ""
-        start_time = time.time()
-        last_chunk_time = start_time
 
-        save_tokens = torch.empty((ids.shape[0], 0), dtype=torch.bool)
-        chunk_buffer = ""
-        chunk_tokens = 0
+        # Grab the next job and iterate through the results
+        while self.generator.num_remaining_jobs():
+            results = self.generator.iterate()
+            for raw_generation in results:
+                if (
+                    raw_generation["stage"] == "streaming"
+                    and raw_generation["identifier"] == job_id
+                ):
+                    chunk = unwrap(raw_generation.get("text"), "")
+                    eos = raw_generation.get("eos")
+                    chunk_tokens = raw_generation.get("token_ids")
+                    if chunk_tokens is not None:
+                        generated_tokens += chunk_tokens.size(dim=0)
+
+                    generation = {
+                        "text": chunk,
+                        "prompt_tokens": prompt_tokens,
+                        "generated_tokens": generated_tokens,
+                        # "offset": len(full_response),
+                    }
+
+                    yield generation
+
+                    # Second yield if eos is true
+                    if eos:
+                        log_response(raw_generation.get("full_completion"))
+
+                        eos_reason = raw_generation.get("eos_reason")
+                        finish_reason = (
+                            "length" if eos_reason == "max_new_tokens" else "stop"
+                        )
+
+                        # Remove the token text
+                        generation["text"] = None
+                        generation["finish_reason"] = finish_reason
+
+                        yield generation
+
+        """
         while True:
             # Ingest prompt
             if chunk_tokens == 0:
@@ -1077,3 +1149,4 @@ class ExllamaV2Container:
 
                 yield generation
                 break
+        """
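
Below is a minimal, self-contained sketch of the dynamic-generator flow this patch adopts, for trying the new API outside of the container. It is not part of the diff; the model directory, prompt, and token budget are placeholders, and it assumes an exllamav2 build that ships ExLlamaV2DynamicGenerator (0.1.x or later), kept in unpaged mode so flash-attn is not required.

# sketch: standalone use of the dynamic generator API (illustrative only)
from exllamav2 import ExLlamaV2, ExLlamaV2Cache, ExLlamaV2Config, ExLlamaV2Tokenizer
from exllamav2.generator import (
    ExLlamaV2Sampler,
    ExLlamaV2DynamicGenerator,
    ExLlamaV2DynamicJob,
)

model_dir = "/path/to/exl2/model"  # placeholder

# Load model, cache, and tokenizer (the container wires up the same pieces with more options)
config = ExLlamaV2Config(model_dir)
model = ExLlamaV2(config)
cache = ExLlamaV2Cache(model, lazy=True)
model.load_autosplit(cache)
tokenizer = ExLlamaV2Tokenizer(config)

# The dynamic generator replaces ExLlamaV2StreamingGenerator; paged attention
# needs flash-attn, so this sketch stays unpaged with a batch size of 1
generator = ExLlamaV2DynamicGenerator(
    model=model,
    cache=cache,
    tokenizer=tokenizer,
    max_batch_size=1,
    paged=False,
)
generator.warmup()

# Enqueue one job and stream its results, mirroring the loop added above
job = ExLlamaV2DynamicJob(
    input_ids=tokenizer.encode("Once upon a time,", add_bos=True),
    max_new_tokens=64,
    gen_settings=ExLlamaV2Sampler.Settings(),
    identifier="demo",
)
generator.enqueue(job)

while generator.num_remaining_jobs():
    for result in generator.iterate():
        if result["stage"] == "streaming" and result["identifier"] == "demo":
            print(result.get("text", ""), end="", flush=True)
            if result.get("eos"):
                print(f"\n[finished: {result.get('eos_reason')}]")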