Model: Initial dynamic generator support

Adds basic support for ExllamaV2's dynamic generator. Can generate
streaming and non-streaming completions.

Signed-off-by: kingbri <bdashore3@proton.me>
Author: kingbri, 2024-05-23 00:13:31 -04:00 (committed by Brian Dashore)
Parent: c474076b22
Commit: 8ccd8fe5f8


@@ -6,8 +6,8 @@ import pathlib
import threading
import time
import traceback
import torch
import uuid
from exllamav2 import (
    ExLlamaV2,
    ExLlamaV2Config,
@@ -17,7 +17,11 @@ from exllamav2 import (
    ExLlamaV2Tokenizer,
    ExLlamaV2Lora,
)
from exllamav2.generator import ExLlamaV2StreamingGenerator, ExLlamaV2Sampler
from exllamav2.generator import (
    ExLlamaV2Sampler,
    ExLlamaV2DynamicGenerator,
    ExLlamaV2DynamicJob,
)
from itertools import zip_longest
from loguru import logger
from typing import List, Optional, Union
@@ -50,7 +54,7 @@ class ExllamaV2Container:
    cache: Optional[ExLlamaV2Cache] = None
    draft_cache: Optional[ExLlamaV2Cache] = None
    tokenizer: Optional[ExLlamaV2Tokenizer] = None
    generator: Optional[ExLlamaV2StreamingGenerator] = None
    generator: Optional[ExLlamaV2DynamicGenerator] = None
    prompt_template: Optional[PromptTemplate] = None
    active_loras: List[ExLlamaV2Lora] = []
@@ -507,18 +511,29 @@ class ExllamaV2Container:
                yield value

        # Test VRAM allocation with a full-length forward pass
        """
        input_ids = torch.zeros((1, self.config.max_input_len), dtype=torch.long)
        self.model.forward(input_ids, cache=self.cache, preprocess_only=True)
        """

        # TODO: Change these!
        max_batch_size = 1 if self.config.no_flash_attn else 20
        paged = not self.config.no_flash_attn

        # Create generator
        self.generator = ExLlamaV2StreamingGenerator(
            self.model,
            self.cache,
            self.tokenizer,
            self.draft_model,
            self.draft_cache,
        self.generator = ExLlamaV2DynamicGenerator(
            model=self.model,
            cache=self.cache,
            draft_model=self.draft_model,
            draft_cache=self.draft_cache,
            tokenizer=self.tokenizer,
            max_batch_size=max_batch_size,
            paged=paged,
        )

        # Warmup the generator
        self.generator.warmup()

        # Clean up any extra vram usage from torch and cuda
        # (Helps reduce VRAM bottlenecking on Windows)
        gc.collect()
@@ -879,9 +894,6 @@ class ExllamaV2Container:
        else:
            stop_conditions += eos_tokens

        # Stop conditions
        self.generator.set_stop_conditions(stop_conditions)

        # Tokenized context
        ids, offsets = self.tokenizer.encode(
            [prompt, negative_prompt]
@@ -965,15 +977,75 @@ class ExllamaV2Container:
        log_prompt(prompt, negative_prompt)

        # Begin
        # generated_tokens = 0
        # full_response = ""
        # start_time = time.time()
        # last_chunk_time = start_time

        # save_tokens = torch.empty((ids.shape[0], 0), dtype=torch.bool)
        # chunk_buffer = ""
        # chunk_tokens = 0

        # Create and add a new job
        job_id = uuid.uuid4().hex
        job = ExLlamaV2DynamicJob(
            input_ids=ids,
            max_new_tokens=max_tokens,
            gen_settings=gen_settings,
            stop_conditions=stop_conditions,
            decode_special_tokens=decode_special_tokens,
            return_probs=request_logprobs > 0,
            return_top_tokens=request_logprobs,
            return_logits=request_logprobs > 0,
            banned_strings=banned_strings,
            identifier=job_id,
        )
        self.generator.enqueue(job)

        # Save generated tokens
        generated_tokens = 0
        full_response = ""
        start_time = time.time()
        last_chunk_time = start_time
        save_tokens = torch.empty((ids.shape[0], 0), dtype=torch.bool)
        chunk_buffer = ""
        chunk_tokens = 0

        # Grab the next job and iterate through the results
        while self.generator.num_remaining_jobs():
            results = self.generator.iterate()
            for raw_generation in results:
                if (
                    raw_generation["stage"] == "streaming"
                    and raw_generation["identifier"] == job_id
                ):
                    chunk = unwrap(raw_generation.get("text"), "")
                    eos = raw_generation.get("eos")
                    chunk_tokens = raw_generation.get("token_ids")

                    if chunk_tokens is not None:
                        generated_tokens += chunk_tokens.size(dim=0)

                    generation = {
                        "text": chunk,
                        "prompt_tokens": prompt_tokens,
                        "generated_tokens": generated_tokens,
                        # "offset": len(full_response),
                    }

                    yield generation

                    # Second yield if eos is true
                    if eos:
                        log_response(raw_generation.get("full_completion"))

                        eos_reason = raw_generation.get("eos_reason")
                        finish_reason = (
                            "length" if eos_reason == "max_new_tokens" else "stop"
                        )

                        # Remove the token text
                        generation["text"] = None
                        generation["finish_reason"] = finish_reason

                        yield generation

        """
        while True:
            # Ingest prompt
            if chunk_tokens == 0:
@@ -1077,3 +1149,4 @@ class ExllamaV2Container:
                yield generation
                break
        """