Model: Use true async jobs and add logprobs

The new async dynamic job allows for native async support without the
need of threading. Also add logprobs and metrics back to responses.

Signed-off-by: kingbri <bdashore3@proton.me>
This commit is contained in:
kingbri 2024-05-23 21:37:50 -04:00 committed by Brian Dashore
parent 32ae62feac
commit 06ff47e2b4
4 changed files with 102 additions and 217 deletions

View file

@ -19,8 +19,8 @@ from exllamav2 import (
)
from exllamav2.generator import (
ExLlamaV2Sampler,
ExLlamaV2DynamicGenerator,
ExLlamaV2DynamicJob,
ExLlamaV2DynamicGeneratorAsync,
ExLlamaV2DynamicJobAsync,
)
from itertools import zip_longest
from loguru import logger
@ -54,7 +54,7 @@ class ExllamaV2Container:
cache: Optional[ExLlamaV2Cache] = None
draft_cache: Optional[ExLlamaV2Cache] = None
tokenizer: Optional[ExLlamaV2Tokenizer] = None
generator: Optional[ExLlamaV2DynamicGenerator] = None
generator: Optional[ExLlamaV2DynamicGeneratorAsync] = None
prompt_template: Optional[PromptTemplate] = None
active_loras: List[ExLlamaV2Lora] = []
@ -410,14 +410,44 @@ class ExllamaV2Container:
return {"success": success, "failure": failure}
async def load_gen(self, progress_callback=None):
"""Basic async wrapper around the loading generator"""
"""Loads a model and streams progress via a generator."""
load_generator = self.load_gen_sync(progress_callback)
async for value in iterate_in_threadpool(load_generator):
# Indicate that model load has started
self.model_is_loading = True
# Streaming gen for model load progress
model_load_generator = self.load_model_sync(progress_callback)
async for value in iterate_in_threadpool(model_load_generator):
yield value
# TODO: Change these!
# Set the max batch size and check if paged support is available
max_batch_size = 1 if self.config.no_flash_attn else 20
paged = not self.config.no_flash_attn
# Create async generator
self.generator = ExLlamaV2DynamicGeneratorAsync(
model=self.model,
cache=self.cache,
draft_model=self.draft_model,
draft_cache=self.draft_cache,
tokenizer=self.tokenizer,
max_batch_size=max_batch_size,
paged=paged,
)
# Clean up any extra vram usage from torch and cuda
# (Helps reduce VRAM bottlenecking on Windows)
gc.collect()
torch.cuda.empty_cache()
# Cleanup and update model load state
self.model_is_loading = False
self.model_loaded = True
logger.info("Model successfully loaded.")
@torch.inference_mode()
def load_gen_sync(self, progress_callback=None):
def load_model_sync(self, progress_callback=None):
"""
Load model, generator function
@ -429,9 +459,6 @@ class ExllamaV2Container:
Runs under a shared inference mode context.
"""
# Notify that the model is being loaded
self.model_is_loading = True
# Reset tokenizer namespace vars and create a tokenizer
ExLlamaV2Tokenizer.unspecial_piece_to_id = {}
ExLlamaV2Tokenizer.unspecial_id_to_piece = {}
@ -511,38 +538,8 @@ class ExllamaV2Container:
yield value
# Test VRAM allocation with a full-length forward pass
"""
input_ids = torch.zeros((1, self.config.max_input_len), dtype=torch.long)
self.model.forward(input_ids, cache=self.cache, preprocess_only=True)
"""
# TODO: Change these!
max_batch_size = 1 if self.config.no_flash_attn else 20
paged = not self.config.no_flash_attn
# Create generator
self.generator = ExLlamaV2DynamicGenerator(
model=self.model,
cache=self.cache,
draft_model=self.draft_model,
draft_cache=self.draft_cache,
tokenizer=self.tokenizer,
max_batch_size=max_batch_size,
paged=paged,
)
# Warmup the generator
self.generator.warmup()
# Clean up any extra vram usage from torch and cuda
# (Helps reduce VRAM bottlenecking on Windows)
gc.collect()
torch.cuda.empty_cache()
# Update model load state
self.model_is_loading = False
self.model_loaded = True
logger.info("Model successfully loaded.")
def unload(self, loras_only: bool = False):
"""
@ -682,19 +679,7 @@ class ExllamaV2Container:
return kwargs
async def generate_gen(
self, prompt: str, abort_event: Optional[threading.Event] = None, **kwargs
):
"""Basic async wrapper for completion generator"""
sync_generator = self.generate_gen_sync(prompt, abort_event, **kwargs)
async for value in iterate_in_threadpool(sync_generator):
yield value
@torch.inference_mode()
def generate_gen_sync(
self, prompt: str, abort_event: Optional[threading.Event] = None, **kwargs
):
async def generate_gen(self, prompt: str, **kwargs):
"""
Create generator function for prompt completion.
@ -702,7 +687,6 @@ class ExllamaV2Container:
"""
token_healing = unwrap(kwargs.get("token_healing"), False)
stream_interval = unwrap(kwargs.get("stream_interval"), 0)
generate_window = max(
unwrap(kwargs.get("generate_window"), 512), self.config.max_seq_len // 8
)
@ -926,25 +910,6 @@ class ExllamaV2Container:
# This is an inverse of skip_special_tokens
decode_special_tokens = unwrap(not kwargs.get("skip_special_tokens"), False)
begin_stream_args = {
"token_healing": token_healing,
"loras": self.active_loras,
"return_probabilities": request_logprobs > 0,
"return_top_tokens": request_logprobs,
"return_logits": request_logprobs > 0,
"abort_event": abort_event,
"banned_strings": banned_strings,
"decode_special_tokens": decode_special_tokens,
}
if self.use_cfg:
begin_stream_args.update(
{
"input_mask": mask,
"position_offsets": offsets,
}
)
# Log generation options to console
# Some options are too large, so log the args instead
log_generation_params(
@ -972,19 +937,10 @@ class ExllamaV2Container:
# Log prompt to console
log_prompt(prompt, negative_prompt)
# Begin
# generated_tokens = 0
# full_response = ""
# start_time = time.time()
# last_chunk_time = start_time
# save_tokens = torch.empty((ids.shape[0], 0), dtype=torch.bool)
# chunk_buffer = ""
# chunk_tokens = 0
# Create and add a new job
job_id = uuid.uuid4().hex
job = ExLlamaV2DynamicJob(
job = ExLlamaV2DynamicJobAsync(
self.generator,
input_ids=ids,
max_new_tokens=max_tokens,
gen_settings=gen_settings,
@ -996,108 +952,30 @@ class ExllamaV2Container:
return_top_tokens=request_logprobs,
return_logits=request_logprobs > 0,
banned_strings=banned_strings,
token_healing=token_healing,
identifier=job_id,
)
self.generator.enqueue(job)
# Save generated tokens
# Save generated tokens and full response
# Full response is required for offset calculation
generated_tokens = 0
full_response = ""
# Grab the next job and iterate through the results
while self.generator.num_remaining_jobs():
results = self.generator.iterate()
for raw_generation in results:
if (
raw_generation["stage"] == "streaming"
and raw_generation["identifier"] == job_id
):
chunk = unwrap(raw_generation.get("text"), "")
eos = raw_generation.get("eos")
# Get the generation status once it's ready
async for result in job:
stage = result.get("stage")
result_id = result.get("identifier")
chunk_tokens = raw_generation.get("token_ids")
if chunk_tokens is not None:
generated_tokens += chunk_tokens.size(dim=0)
if stage == "streaming" and result_id == job_id:
chunk = unwrap(result.get("text"), "")
full_response += chunk
generation = {
"text": chunk,
"prompt_tokens": prompt_tokens,
"generated_tokens": generated_tokens,
# "offset": len(full_response),
}
chunk_tokens = result.get("token_ids")
if chunk_tokens is not None:
generated_tokens += chunk_tokens.size(dim=0)
yield generation
# Second yield if eos is true
if eos:
log_response(raw_generation.get("full_completion"))
eos_reason = raw_generation.get("eos_reason")
finish_reason = (
"length" if eos_reason == "max_new_tokens" else "stop"
)
# Remove the token text
generation["text"] = None
generation["finish_reason"] = finish_reason
yield generation
"""
while True:
# Ingest prompt
if chunk_tokens == 0:
ids = torch.cat((ids, save_tokens), dim=-1)
save_tokens = torch.empty((ids.shape[0], 0), dtype=torch.bool)
overflow = ids.shape[-1] + generate_window - self.config.max_seq_len
active_ids = ids[:, max(0, overflow) :]
chunk_tokens = self.config.max_seq_len - active_ids.shape[-1]
# Kick off the streaming generation
self.generator.begin_stream_ex(
active_ids, gen_settings, **begin_stream_args
)
# Reset offsets for subsequent passes if the context is truncated
offsets = None
if auto_scale_penalty_range:
gen_settings.token_repetition_range = generated_tokens
# Run dict generation
# Guarantees return of chunk, eos, and chunk_token_ids
if generated_tokens < min_tokens:
raw_generation = self.generator.stream_ex(ban_tokens=eos_tokens)
else:
raw_generation = self.generator.stream_ex()
if token_healing:
# Extract healed token
ids[:, -1] = self.generator.sequence_ids[:, -2]
token_healing = False
# Get parameters that will always exist
chunk = raw_generation["chunk"]
eos = raw_generation["eos"]
tokens = raw_generation["chunk_token_ids"]
save_tokens = torch.cat(
(save_tokens, tokens.expand(save_tokens.shape[0], -1)), dim=-1
)
chunk_buffer += chunk
generated_tokens += 1
chunk_tokens -= 1
# Yield output
now = time.time()
elapsed = now - last_chunk_time
if chunk_buffer != "" and (
elapsed > stream_interval or eos or generated_tokens == max_tokens
):
generation = {
"text": chunk_buffer,
"text": chunk,
"prompt_tokens": prompt_tokens,
"generated_tokens": generated_tokens,
"offset": len(full_response),
@ -1106,12 +984,12 @@ class ExllamaV2Container:
if request_logprobs > 0:
# Get top tokens and probs
top_tokens = unwrap(
raw_generation.get("top_tokens"),
result.get("top_k_tokens"),
torch.empty((1, 0, 1), dtype=torch.long),
)
top_probs = unwrap(
raw_generation.get("top_probs"),
result.get("top_k_probs"),
torch.empty((1, 0, 1), dtype=torch.float),
)
@ -1126,25 +1004,32 @@ class ExllamaV2Container:
}
yield generation
full_response += chunk_buffer
chunk_buffer = ""
last_chunk_time = now
if eos or generated_tokens == max_tokens:
# Print response
log_response(full_response)
# Second yield if eos is true
if result.get("eos"):
log_response(full_response)
# Print metrics
elapsed_time = last_chunk_time - start_time
context_len = None if ids is None else context_len
eos_reason = result.get("eos_reason")
finish_reason = (
"length" if eos_reason == "max_new_tokens" else "stop"
)
log_metrics(
generated_tokens, elapsed_time, context_len, self.config.max_seq_len
)
log_metrics(
result.get("time_enqueued"),
result.get("prompt_tokens"),
result.get("time_prefill"),
result.get("new_tokens"),
result.get("time_generate"),
context_len,
self.config.max_seq_len,
)
finish_reason = "length" if generated_tokens == max_tokens else "stop"
generation = {"finish_reason": finish_reason}
yield generation
# Remove the token text
generation = {
"prompt_tokens": generation.get("prompt_tokens"),
"generated_tokens": generation.get("generated_tokens"),
"finish_reason": finish_reason,
}
break
"""
yield generation
break

View file

@ -70,29 +70,38 @@ def log_response(response: str):
def log_metrics(
queue_time: float,
prompt_tokens: int,
prompt_time: float,
generated_tokens: int,
elapsed_time: float,
generate_time: float,
context_len: Optional[int],
max_seq_len: int,
):
initial_response = (
f"Metrics: {generated_tokens} tokens generated in "
f"{round(elapsed_time, 2)} seconds"
f"{round(queue_time + prompt_time + generate_time, 2)} seconds"
)
itemization = []
extra_parts = []
# Add tokens per second
tokens_per_second = (
"Indeterminate"
if elapsed_time == 0
else round(generated_tokens / elapsed_time, 2)
itemization.append(f"Queue: {round(queue_time, 2)} s")
prompt_ts = (
"Indeterminate" if prompt_time == 0 else round(prompt_tokens / prompt_time, 2)
)
itemization.append(f"{tokens_per_second} T/s")
itemization.append(f"Process: {prompt_ts} T/s")
generate_ts = (
"Indeterminate"
if generate_time == 0
else round(generated_tokens / generate_time, 2)
)
itemization.append(f"Generate: {generate_ts} T/s")
# Add context (original token count)
if context_len:
itemization.append(f"context {context_len} tokens")
itemization.append(f"Context: {context_len} tokens")
if context_len > max_seq_len:
extra_parts.append("<-- Not accurate (truncated)")

View file

@ -1,8 +1,7 @@
"""Chat completion utilities for OAI server."""
from asyncio import CancelledError
import pathlib
import threading
from asyncio import CancelledError
from typing import Optional
from uuid import uuid4
@ -198,11 +197,8 @@ async def stream_generate_chat_completion(
"""Generator for the generation process."""
try:
const_id = f"chatcmpl-{uuid4().hex}"
abort_event = threading.Event()
new_generation = model.container.generate_gen(
prompt, abort_event, **data.to_gen_params()
)
new_generation = model.container.generate_gen(prompt, **data.to_gen_params())
async for generation in new_generation:
response = _create_stream_chunk(const_id, generation, model_path.name)
@ -214,7 +210,6 @@ async def stream_generate_chat_completion(
except CancelledError:
# Get out if the request gets disconnected
abort_event.set()
handle_request_disconnect("Chat completion generation cancelled by user.")
except Exception:
yield get_generator_error(

View file

@ -2,7 +2,6 @@
import pathlib
from asyncio import CancelledError
import threading
from fastapi import HTTPException
from typing import Optional
@ -65,10 +64,8 @@ async def stream_generate_completion(data: CompletionRequest, model_path: pathli
"""Streaming generation for completions."""
try:
abort_event = threading.Event()
new_generation = model.container.generate_gen(
data.prompt, abort_event, **data.to_gen_params()
data.prompt, **data.to_gen_params()
)
async for generation in new_generation:
response = _create_response(generation, model_path.name)
@ -81,7 +78,6 @@ async def stream_generate_completion(data: CompletionRequest, model_path: pathli
except CancelledError:
# Get out if the request gets disconnected
abort_event.set()
handle_request_disconnect("Completion generation cancelled by user.")
except Exception:
yield get_generator_error(