Model: Use true async jobs and add logprobs
The new async dynamic job allows for native async support without the need for threading. This change also adds logprobs and metrics back to responses.

Signed-off-by: kingbri <bdashore3@proton.me>
parent 32ae62feac
commit 06ff47e2b4

4 changed files with 102 additions and 217 deletions
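For context, the heart of this change is swapping the thread-pooled sync generator for exllamav2's async dynamic generator and job, which are driven directly on the event loop. The sketch below is illustrative only: the generator and job keyword arguments are trimmed down, and the surrounding model/cache/tokenizer setup is assumed rather than taken from this diff.

from exllamav2.generator import (
    ExLlamaV2DynamicGeneratorAsync,
    ExLlamaV2DynamicJobAsync,
)


async def stream_one_prompt(model, cache, tokenizer, prompt: str):
    # The async generator schedules work on the running event loop, so no
    # threading.Event or thread-pool iteration is needed (hypothetical setup;
    # model, cache, and tokenizer are assumed to be loaded ExLlamaV2 objects).
    generator = ExLlamaV2DynamicGeneratorAsync(
        model=model,
        cache=cache,
        tokenizer=tokenizer,
    )

    job = ExLlamaV2DynamicJobAsync(
        generator,
        input_ids=tokenizer.encode(prompt),
        max_new_tokens=64,
    )

    # Each result is a dict; streaming-stage results carry a "text" chunk
    async for result in job:
        if result.get("stage") == "streaming":
            print(result.get("text", ""), end="", flush=True)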
@@ -19,8 +19,8 @@ from exllamav2 import (
 )
 from exllamav2.generator import (
     ExLlamaV2Sampler,
-    ExLlamaV2DynamicGenerator,
-    ExLlamaV2DynamicJob,
+    ExLlamaV2DynamicGeneratorAsync,
+    ExLlamaV2DynamicJobAsync,
 )
 from itertools import zip_longest
 from loguru import logger
@@ -54,7 +54,7 @@ class ExllamaV2Container:
     cache: Optional[ExLlamaV2Cache] = None
     draft_cache: Optional[ExLlamaV2Cache] = None
     tokenizer: Optional[ExLlamaV2Tokenizer] = None
-    generator: Optional[ExLlamaV2DynamicGenerator] = None
+    generator: Optional[ExLlamaV2DynamicGeneratorAsync] = None
     prompt_template: Optional[PromptTemplate] = None
     active_loras: List[ExLlamaV2Lora] = []
@@ -410,14 +410,44 @@ class ExllamaV2Container:
         return {"success": success, "failure": failure}

     async def load_gen(self, progress_callback=None):
-        """Basic async wrapper around the loading generator"""
+        """Loads a model and streams progress via a generator."""

-        load_generator = self.load_gen_sync(progress_callback)
-        async for value in iterate_in_threadpool(load_generator):
+        # Indicate that model load has started
+        self.model_is_loading = True
+
+        # Streaming gen for model load progress
+        model_load_generator = self.load_model_sync(progress_callback)
+        async for value in iterate_in_threadpool(model_load_generator):
             yield value

+        # TODO: Change these!
+        # Set the max batch size and check if paged support is available
+        max_batch_size = 1 if self.config.no_flash_attn else 20
+        paged = not self.config.no_flash_attn
+
+        # Create async generator
+        self.generator = ExLlamaV2DynamicGeneratorAsync(
+            model=self.model,
+            cache=self.cache,
+            draft_model=self.draft_model,
+            draft_cache=self.draft_cache,
+            tokenizer=self.tokenizer,
+            max_batch_size=max_batch_size,
+            paged=paged,
+        )
+
+        # Clean up any extra vram usage from torch and cuda
+        # (Helps reduce VRAM bottlenecking on Windows)
+        gc.collect()
+        torch.cuda.empty_cache()
+
+        # Cleanup and update model load state
+        self.model_is_loading = False
+        self.model_loaded = True
+        logger.info("Model successfully loaded.")

     @torch.inference_mode()
-    def load_gen_sync(self, progress_callback=None):
+    def load_model_sync(self, progress_callback=None):
         """
         Load model, generator function
@@ -429,9 +459,6 @@ class ExllamaV2Container:
         Runs under a shared inference mode context.
         """

-        # Notify that the model is being loaded
-        self.model_is_loading = True
-
         # Reset tokenizer namespace vars and create a tokenizer
         ExLlamaV2Tokenizer.unspecial_piece_to_id = {}
         ExLlamaV2Tokenizer.unspecial_id_to_piece = {}
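As a usage sketch (helper name assumed, not from this diff), a caller can now drive the coroutine-based load path directly:

async def load_and_report(container, progress_callback=None):
    # load_gen now yields progress values natively on the event loop; their
    # exact shape is defined elsewhere in the container, not by this diff.
    async for progress in container.load_gen(progress_callback):
        print(progress)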
@@ -511,38 +538,8 @@ class ExllamaV2Container:
             yield value

-        # Test VRAM allocation with a full-length forward pass
-        """
-        input_ids = torch.zeros((1, self.config.max_input_len), dtype=torch.long)
-        self.model.forward(input_ids, cache=self.cache, preprocess_only=True)
-        """
-
-        # TODO: Change these!
-        max_batch_size = 1 if self.config.no_flash_attn else 20
-        paged = not self.config.no_flash_attn
-
-        # Create generator
-        self.generator = ExLlamaV2DynamicGenerator(
-            model=self.model,
-            cache=self.cache,
-            draft_model=self.draft_model,
-            draft_cache=self.draft_cache,
-            tokenizer=self.tokenizer,
-            max_batch_size=max_batch_size,
-            paged=paged,
-        )
-
-        # Warmup the generator
-        self.generator.warmup()
-
-        # Clean up any extra vram usage from torch and cuda
-        # (Helps reduce VRAM bottlenecking on Windows)
-        gc.collect()
-        torch.cuda.empty_cache()
-
-        # Update model load state
-        self.model_is_loading = False
-        self.model_loaded = True
-        logger.info("Model successfully loaded.")

     def unload(self, loras_only: bool = False):
         """
@@ -682,19 +679,7 @@ class ExllamaV2Container:

         return kwargs

-    async def generate_gen(
-        self, prompt: str, abort_event: Optional[threading.Event] = None, **kwargs
-    ):
-        """Basic async wrapper for completion generator"""
-
-        sync_generator = self.generate_gen_sync(prompt, abort_event, **kwargs)
-        async for value in iterate_in_threadpool(sync_generator):
-            yield value
-
-    @torch.inference_mode()
-    def generate_gen_sync(
-        self, prompt: str, abort_event: Optional[threading.Event] = None, **kwargs
-    ):
+    async def generate_gen(self, prompt: str, **kwargs):
         """
         Create generator function for prompt completion.
@@ -702,7 +687,6 @@ class ExllamaV2Container:
         """

         token_healing = unwrap(kwargs.get("token_healing"), False)
-        stream_interval = unwrap(kwargs.get("stream_interval"), 0)
         generate_window = max(
             unwrap(kwargs.get("generate_window"), 512), self.config.max_seq_len // 8
         )
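Likewise, completions can now be consumed with a plain async for instead of wrapping a sync generator in a thread pool. collect_text below is a hypothetical helper; only generate_gen and the "text" key of the streamed dicts come from this diff.

async def collect_text(container, prompt: str, **params) -> str:
    # Accumulate streamed chunks; the final yield omits "text" and carries
    # finish_reason instead, so a falsy check is enough here.
    pieces = []
    async for generation in container.generate_gen(prompt, **params):
        text = generation.get("text")
        if text:
            pieces.append(text)
    return "".join(pieces)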
@@ -926,25 +910,6 @@ class ExllamaV2Container:
         # This is an inverse of skip_special_tokens
         decode_special_tokens = unwrap(not kwargs.get("skip_special_tokens"), False)

-        begin_stream_args = {
-            "token_healing": token_healing,
-            "loras": self.active_loras,
-            "return_probabilities": request_logprobs > 0,
-            "return_top_tokens": request_logprobs,
-            "return_logits": request_logprobs > 0,
-            "abort_event": abort_event,
-            "banned_strings": banned_strings,
-            "decode_special_tokens": decode_special_tokens,
-        }
-
-        if self.use_cfg:
-            begin_stream_args.update(
-                {
-                    "input_mask": mask,
-                    "position_offsets": offsets,
-                }
-            )
-
         # Log generation options to console
         # Some options are too large, so log the args instead
         log_generation_params(
@@ -972,19 +937,10 @@ class ExllamaV2Container:
         # Log prompt to console
         log_prompt(prompt, negative_prompt)

-        # Begin
-        # generated_tokens = 0
-        # full_response = ""
-        # start_time = time.time()
-        # last_chunk_time = start_time
-
-        # save_tokens = torch.empty((ids.shape[0], 0), dtype=torch.bool)
-        # chunk_buffer = ""
-        # chunk_tokens = 0
-
         # Create and add a new job
         job_id = uuid.uuid4().hex
-        job = ExLlamaV2DynamicJob(
+        job = ExLlamaV2DynamicJobAsync(
+            self.generator,
             input_ids=ids,
             max_new_tokens=max_tokens,
             gen_settings=gen_settings,
@@ -996,108 +952,30 @@ class ExllamaV2Container:
             return_top_tokens=request_logprobs,
             return_logits=request_logprobs > 0,
             banned_strings=banned_strings,
             token_healing=token_healing,
             identifier=job_id,
         )

-        self.generator.enqueue(job)
-
-        # Save generated tokens
+        # Save generated tokens and full response
+        # Full response is required for offset calculation
         generated_tokens = 0
         full_response = ""

-        # Grab the next job and iterate through the results
-        while self.generator.num_remaining_jobs():
-            results = self.generator.iterate()
-            for raw_generation in results:
-                if (
-                    raw_generation["stage"] == "streaming"
-                    and raw_generation["identifier"] == job_id
-                ):
-                    chunk = unwrap(raw_generation.get("text"), "")
-                    eos = raw_generation.get("eos")
+        # Get the generation status once it's ready
+        async for result in job:
+            stage = result.get("stage")
+            result_id = result.get("identifier")

-                    chunk_tokens = raw_generation.get("token_ids")
-                    if chunk_tokens is not None:
-                        generated_tokens += chunk_tokens.size(dim=0)
+            if stage == "streaming" and result_id == job_id:
+                chunk = unwrap(result.get("text"), "")
+                full_response += chunk

-                    generation = {
-                        "text": chunk,
-                        "prompt_tokens": prompt_tokens,
-                        "generated_tokens": generated_tokens,
-                        # "offset": len(full_response),
-                    }
+                chunk_tokens = result.get("token_ids")
+                if chunk_tokens is not None:
+                    generated_tokens += chunk_tokens.size(dim=0)

-                    yield generation
-
-                    # Second yield if eos is true
-                    if eos:
-                        log_response(raw_generation.get("full_completion"))
-
-                        eos_reason = raw_generation.get("eos_reason")
-                        finish_reason = (
-                            "length" if eos_reason == "max_new_tokens" else "stop"
-                        )
-
-                        # Remove the token text
-                        generation["text"] = None
-                        generation["finish_reason"] = finish_reason
-
-                        yield generation
-
-        """
-        while True:
-            # Ingest prompt
-            if chunk_tokens == 0:
-                ids = torch.cat((ids, save_tokens), dim=-1)
-                save_tokens = torch.empty((ids.shape[0], 0), dtype=torch.bool)
-                overflow = ids.shape[-1] + generate_window - self.config.max_seq_len
-                active_ids = ids[:, max(0, overflow) :]
-                chunk_tokens = self.config.max_seq_len - active_ids.shape[-1]
-
-                # Kick off the streaming generation
-                self.generator.begin_stream_ex(
-                    active_ids, gen_settings, **begin_stream_args
-                )
-
-                # Reset offsets for subsequent passes if the context is truncated
-                offsets = None
-
-            if auto_scale_penalty_range:
-                gen_settings.token_repetition_range = generated_tokens
-
-            # Run dict generation
-            # Guarantees return of chunk, eos, and chunk_token_ids
-            if generated_tokens < min_tokens:
-                raw_generation = self.generator.stream_ex(ban_tokens=eos_tokens)
-            else:
-                raw_generation = self.generator.stream_ex()
-
-            if token_healing:
-                # Extract healed token
-                ids[:, -1] = self.generator.sequence_ids[:, -2]
-                token_healing = False
-
-            # Get parameters that will always exist
-            chunk = raw_generation["chunk"]
-            eos = raw_generation["eos"]
-            tokens = raw_generation["chunk_token_ids"]
-
-            save_tokens = torch.cat(
-                (save_tokens, tokens.expand(save_tokens.shape[0], -1)), dim=-1
-            )
-            chunk_buffer += chunk
-
-            generated_tokens += 1
-            chunk_tokens -= 1
-
-            # Yield output
-            now = time.time()
-            elapsed = now - last_chunk_time
-
-            if chunk_buffer != "" and (
-                elapsed > stream_interval or eos or generated_tokens == max_tokens
-            ):
                generation = {
-                    "text": chunk_buffer,
+                    "text": chunk,
                     "prompt_tokens": prompt_tokens,
                     "generated_tokens": generated_tokens,
                     "offset": len(full_response),
@@ -1106,12 +984,12 @@ class ExllamaV2Container:
                 if request_logprobs > 0:
                     # Get top tokens and probs
                     top_tokens = unwrap(
-                        raw_generation.get("top_tokens"),
+                        result.get("top_k_tokens"),
                         torch.empty((1, 0, 1), dtype=torch.long),
                     )

                     top_probs = unwrap(
-                        raw_generation.get("top_probs"),
+                        result.get("top_k_probs"),
                         torch.empty((1, 0, 1), dtype=torch.float),
                     )
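The top_k_tokens / top_k_probs tensors above are shaped (batch, positions, k). A minimal sketch of turning one streaming result into per-position logprob maps follows; this is not the repository's implementation, and only the tensor layout and empty-tensor defaults are taken from the diff.

import math

import torch


def top_logprobs(result: dict, k: int) -> list:
    # Sketch only: pair each of the k candidate token ids at every generated
    # position with the natural log of its probability.
    top_tokens = result.get("top_k_tokens", torch.empty((1, 0, 1), dtype=torch.long))
    top_probs = result.get("top_k_probs", torch.empty((1, 0, 1), dtype=torch.float))

    positions = []
    for ids, probs in zip(top_tokens[0], top_probs[0]):
        positions.append(
            {int(tok): math.log(max(float(p), 1e-10)) for tok, p in zip(ids[:k], probs[:k])}
        )
    return positions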
@@ -1126,25 +1004,32 @@ class ExllamaV2Container:
                 }

                 yield generation
-                full_response += chunk_buffer
-                chunk_buffer = ""
-                last_chunk_time = now

-            if eos or generated_tokens == max_tokens:
-                # Print response
-                log_response(full_response)
+                # Second yield if eos is true
+                if result.get("eos"):
+                    log_response(full_response)

-                # Print metrics
-                elapsed_time = last_chunk_time - start_time
-                context_len = None if ids is None else context_len
+                    eos_reason = result.get("eos_reason")
+                    finish_reason = (
+                        "length" if eos_reason == "max_new_tokens" else "stop"
+                    )

-                log_metrics(
-                    generated_tokens, elapsed_time, context_len, self.config.max_seq_len
-                )
+                    # Print metrics
+                    log_metrics(
+                        result.get("time_enqueued"),
+                        result.get("prompt_tokens"),
+                        result.get("time_prefill"),
+                        result.get("new_tokens"),
+                        result.get("time_generate"),
+                        context_len,
+                        self.config.max_seq_len,
+                    )

-                finish_reason = "length" if generated_tokens == max_tokens else "stop"
-                generation = {"finish_reason": finish_reason}
-                yield generation
+                    # Remove the token text
+                    generation = {
+                        "prompt_tokens": generation.get("prompt_tokens"),
+                        "generated_tokens": generation.get("generated_tokens"),
+                        "finish_reason": finish_reason,
+                    }

-                break
-        """
+                    yield generation
+                    break
@@ -70,29 +70,38 @@ def log_response(response: str):


 def log_metrics(
+    queue_time: float,
+    prompt_tokens: int,
+    prompt_time: float,
     generated_tokens: int,
-    elapsed_time: float,
+    generate_time: float,
     context_len: Optional[int],
     max_seq_len: int,
 ):
     initial_response = (
         f"Metrics: {generated_tokens} tokens generated in "
-        f"{round(elapsed_time, 2)} seconds"
+        f"{round(queue_time + prompt_time + generate_time, 2)} seconds"
     )
     itemization = []
     extra_parts = []

-    # Add tokens per second
-    tokens_per_second = (
-        "Indeterminate"
-        if elapsed_time == 0
-        else round(generated_tokens / elapsed_time, 2)
+    itemization.append(f"Queue: {round(queue_time, 2)} s")
+
+    prompt_ts = (
+        "Indeterminate" if prompt_time == 0 else round(prompt_tokens / prompt_time, 2)
     )
-    itemization.append(f"{tokens_per_second} T/s")
+    itemization.append(f"Process: {prompt_ts} T/s")
+
+    generate_ts = (
+        "Indeterminate"
+        if generate_time == 0
+        else round(generated_tokens / generate_time, 2)
+    )
+    itemization.append(f"Generate: {generate_ts} T/s")

     # Add context (original token count)
     if context_len:
-        itemization.append(f"context {context_len} tokens")
+        itemization.append(f"Context: {context_len} tokens")

         if context_len > max_seq_len:
             extra_parts.append("<-- Not accurate (truncated)")
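For illustration, a call with made-up numbers (the values normally come from the finished job's result dict, as in the hunk above). The exact joined log line is assembled further down in this function and is not shown here.

log_metrics(
    queue_time=0.03,        # result.get("time_enqueued")
    prompt_tokens=512,      # result.get("prompt_tokens")
    prompt_time=0.41,       # result.get("time_prefill")
    generated_tokens=128,   # result.get("new_tokens")
    generate_time=3.2,      # result.get("time_generate")
    context_len=512,
    max_seq_len=4096,
)
# initial_response: "Metrics: 128 tokens generated in 3.64 seconds"
# itemization:      ["Queue: 0.03 s", "Process: 1248.78 T/s",
#                    "Generate: 40.0 T/s", "Context: 512 tokens"]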
@@ -1,8 +1,7 @@
 """Chat completion utilities for OAI server."""

+from asyncio import CancelledError
 import pathlib
-import threading
-from asyncio import CancelledError
 from typing import Optional
 from uuid import uuid4
@@ -198,11 +197,8 @@ async def stream_generate_chat_completion(
     """Generator for the generation process."""
     try:
         const_id = f"chatcmpl-{uuid4().hex}"
-        abort_event = threading.Event()
-
-        new_generation = model.container.generate_gen(
-            prompt, abort_event, **data.to_gen_params()
-        )
+        new_generation = model.container.generate_gen(prompt, **data.to_gen_params())
         async for generation in new_generation:
             response = _create_stream_chunk(const_id, generation, model_path.name)
@@ -214,7 +210,6 @@ async def stream_generate_chat_completion(
     except CancelledError:
         # Get out if the request gets disconnected

-        abort_event.set()
         handle_request_disconnect("Chat completion generation cancelled by user.")
     except Exception:
         yield get_generator_error(
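Since the abort flag is gone, request cancellation now rides on asyncio itself: cancelling the task that drives the async generator raises CancelledError at its await point, which the handlers above catch. A small sketch follows; the helper names are assumed, and how the underlying exllamav2 job reacts to cancellation is up to the library.

import asyncio


async def generate_with_timeout(container, prompt: str, timeout: float, **params):
    async def consume():
        async for generation in container.generate_gen(prompt, **params):
            pass  # forward chunks to the client here

    task = asyncio.create_task(consume())
    try:
        await asyncio.wait_for(task, timeout)
    except asyncio.TimeoutError:
        # wait_for cancels the task; CancelledError then surfaces inside
        # generate_gen, so no shared threading.Event is required.
        pass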
@@ -2,7 +2,6 @@

 import pathlib
 from asyncio import CancelledError
-import threading
 from fastapi import HTTPException
 from typing import Optional
@@ -65,10 +64,8 @@ async def stream_generate_completion(data: CompletionRequest, model_path: pathlib.Path):
     """Streaming generation for completions."""

     try:
-        abort_event = threading.Event()
-
         new_generation = model.container.generate_gen(
-            data.prompt, abort_event, **data.to_gen_params()
+            data.prompt, **data.to_gen_params()
         )
         async for generation in new_generation:
             response = _create_response(generation, model_path.name)
@@ -81,7 +78,6 @@ async def stream_generate_completion(data: CompletionRequest, model_path: pathlib.Path):
     except CancelledError:
         # Get out if the request gets disconnected

-        abort_event.set()
         handle_request_disconnect("Completion generation cancelled by user.")
     except Exception:
         yield get_generator_error(