Model: Initial dynamic generator support
Adds basic support for ExllamaV2's dynamic generator. It can generate both streaming and non-streaming completions.

Signed-off-by: kingbri <bdashore3@proton.me>
This commit is contained in:
parent c474076b22
commit 8ccd8fe5f8
1 changed file with 91 additions and 18 deletions
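In ExLlamaV2's dynamic generator API, a completion is a job: the request is wrapped in an ExLlamaV2DynamicJob, enqueued on the generator, and the queue is drained through iterate(), which reports per-job results. A minimal sketch of the workflow this commit adopts, with model, cache, and tokenizer setup elided:

    from exllamav2.generator import ExLlamaV2DynamicGenerator, ExLlamaV2DynamicJob

    generator = ExLlamaV2DynamicGenerator(
        model=model,
        cache=cache,
        tokenizer=tokenizer,
    )

    job = ExLlamaV2DynamicJob(
        input_ids=tokenizer.encode("Once upon a time", add_bos=True),
        max_new_tokens=100,
    )
    generator.enqueue(job)

    # Drain the queue; each iterate() call advances every active job one step
    while generator.num_remaining_jobs():
        for result in generator.iterate():
            if result["stage"] == "streaming":
                print(result.get("text", ""), end="")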
@@ -6,8 +6,8 @@ import pathlib
 import threading
 import time
 import traceback
-
 import torch
+import uuid
 from exllamav2 import (
     ExLlamaV2,
     ExLlamaV2Config,
@@ -17,7 +17,11 @@ from exllamav2 import (
     ExLlamaV2Tokenizer,
     ExLlamaV2Lora,
 )
-from exllamav2.generator import ExLlamaV2StreamingGenerator, ExLlamaV2Sampler
+from exllamav2.generator import (
+    ExLlamaV2Sampler,
+    ExLlamaV2DynamicGenerator,
+    ExLlamaV2DynamicJob,
+)
 from itertools import zip_longest
 from loguru import logger
 from typing import List, Optional, Union
@@ -50,7 +54,7 @@ class ExllamaV2Container:
     cache: Optional[ExLlamaV2Cache] = None
     draft_cache: Optional[ExLlamaV2Cache] = None
     tokenizer: Optional[ExLlamaV2Tokenizer] = None
-    generator: Optional[ExLlamaV2StreamingGenerator] = None
+    generator: Optional[ExLlamaV2DynamicGenerator] = None
     prompt_template: Optional[PromptTemplate] = None
     active_loras: List[ExLlamaV2Lora] = []
 
@@ -507,18 +511,29 @@ class ExllamaV2Container:
                     yield value

         # Test VRAM allocation with a full-length forward pass
+        """
         input_ids = torch.zeros((1, self.config.max_input_len), dtype=torch.long)
         self.model.forward(input_ids, cache=self.cache, preprocess_only=True)
+        """
+
+        # TODO: Change these!
+        max_batch_size = 1 if self.config.no_flash_attn else 20
+        paged = not self.config.no_flash_attn

         # Create generator
-        self.generator = ExLlamaV2StreamingGenerator(
-            self.model,
-            self.cache,
-            self.tokenizer,
-            self.draft_model,
-            self.draft_cache,
+        self.generator = ExLlamaV2DynamicGenerator(
+            model=self.model,
+            cache=self.cache,
+            draft_model=self.draft_model,
+            draft_cache=self.draft_cache,
+            tokenizer=self.tokenizer,
+            max_batch_size=max_batch_size,
+            paged=paged,
         )

+        # Warmup the generator
+        self.generator.warmup()
+
         # Clean up any extra vram usage from torch and cuda
         # (Helps reduce VRAM bottlenecking on Windows)
         gc.collect()
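Note on the two flags above: the dynamic generator's paged (batched) mode relies on flash-attn, so when no_flash_attn is set the commit falls back to unpaged mode with a single job at a time; the 20-job ceiling is a placeholder, per the TODO. The same guard written as a standalone helper, for illustration only (generator_limits is hypothetical, not part of this commit):

    def generator_limits(no_flash_attn: bool) -> tuple[int, bool]:
        """Pick a max batch size and paging mode from flash-attn availability."""
        if no_flash_attn:
            # Unpaged fallback: the generator runs one job at a time
            return 1, False
        # Paged attention available: allow concurrent jobs in a shared cache
        return 20, True

    max_batch_size, paged = generator_limits(config.no_flash_attn)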
@@ -879,9 +894,6 @@ class ExllamaV2Container:
         else:
             stop_conditions += eos_tokens

-        # Stop conditions
-        self.generator.set_stop_conditions(stop_conditions)
-
         # Tokenized context
         ids, offsets = self.tokenizer.encode(
             [prompt, negative_prompt]
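The set_stop_conditions call disappears because the dynamic generator keeps no generator-wide stop state: stop conditions travel with each job (see the stop_conditions argument to ExLlamaV2DynamicJob in the next hunk), so concurrent jobs can stop on different conditions. Roughly:

    # Before: stop conditions were generator-wide, mutable state
    generator.set_stop_conditions(stop_conditions)

    # After: a per-job argument, scoped to a single request
    job = ExLlamaV2DynamicJob(
        input_ids=ids,
        max_new_tokens=max_tokens,
        stop_conditions=stop_conditions,
    )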
@@ -965,15 +977,75 @@ class ExllamaV2Container:
         log_prompt(prompt, negative_prompt)

         # Begin
+        # generated_tokens = 0
+        # full_response = ""
+        # start_time = time.time()
+        # last_chunk_time = start_time
+
+        # save_tokens = torch.empty((ids.shape[0], 0), dtype=torch.bool)
+        # chunk_buffer = ""
+        # chunk_tokens = 0
+
+        # Create and add a new job
+        job_id = uuid.uuid4().hex
+        job = ExLlamaV2DynamicJob(
+            input_ids=ids,
+            max_new_tokens=max_tokens,
+            gen_settings=gen_settings,
+            stop_conditions=stop_conditions,
+            decode_special_tokens=decode_special_tokens,
+            return_probs=request_logprobs > 0,
+            return_top_tokens=request_logprobs,
+            return_logits=request_logprobs > 0,
+            banned_strings=banned_strings,
+            identifier=job_id,
+        )
+
+        self.generator.enqueue(job)
+
+        # Save generated tokens
         generated_tokens = 0
         full_response = ""
         start_time = time.time()
         last_chunk_time = start_time

         save_tokens = torch.empty((ids.shape[0], 0), dtype=torch.bool)
         chunk_buffer = ""
         chunk_tokens = 0
+
+        # Grab the next job and iterate through the results
+        while self.generator.num_remaining_jobs():
+            results = self.generator.iterate()
+            for raw_generation in results:
+                if (
+                    raw_generation["stage"] == "streaming"
+                    and raw_generation["identifier"] == job_id
+                ):
+                    chunk = unwrap(raw_generation.get("text"), "")
+                    eos = raw_generation.get("eos")
+
+                    chunk_tokens = raw_generation.get("token_ids")
+                    if chunk_tokens is not None:
+                        generated_tokens += chunk_tokens.size(dim=0)
+
+                    generation = {
+                        "text": chunk,
+                        "prompt_tokens": prompt_tokens,
+                        "generated_tokens": generated_tokens,
+                        # "offset": len(full_response),
+                    }
+
+                    yield generation
+
+                    # Second yield if eos is true
+                    if eos:
+                        log_response(raw_generation.get("full_completion"))
+
+                        eos_reason = raw_generation.get("eos_reason")
+                        finish_reason = (
+                            "length" if eos_reason == "max_new_tokens" else "stop"
+                        )
+
+                        # Remove the token text
+                        generation["text"] = None
+                        generation["finish_reason"] = finish_reason
+
+                        yield generation
+
+        """
         while True:
             # Ingest prompt
             if chunk_tokens == 0:
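Each yielded dict is one streaming chunk (incremental "text" plus running token counts); after eos the loop yields a final dict with "text" cleared and a "finish_reason" of "stop" or "length". A non-streaming completion can then be built by draining the same stream. A sketch, assuming the hunk above lives in a streaming method here called generate_gen (both method names are illustrative):

    def generate(self, prompt: str, **kwargs) -> dict:
        """Collect a streamed generation into a single response dict."""
        joined = {"text": "", "finish_reason": None}
        for generation in self.generate_gen(prompt, **kwargs):
            joined["text"] += generation.get("text") or ""
            if "finish_reason" in generation:
                joined["finish_reason"] = generation["finish_reason"]
        return joined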
@@ -1077,3 +1149,4 @@ class ExllamaV2Container:
                     yield generation

                     break
+        """