Tree: Switch to async generators

Async generators remove many of the roadblocks that come with managing
tasks using threads. They should allow for abortables and modern
concurrency paradigms.

NOTE: Exllamav2 itself is not an asynchronous library. It's simply
wrapped into tabby's async architecture to allow for a fast and
concurrent API server. It's still being debated whether to run
stream_ex in a separate thread or to manually manage it with
asyncio.sleep(0).
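
For illustration, a rough sketch of the two approaches under debate,
with a placeholder generator standing in for stream_ex (the helper
names below are illustrative only, not part of this commit):

import asyncio

def blocking_stream():
    """Stands in for a synchronous ExLlamaV2 generator such as stream_ex."""
    for i in range(3):
        yield {"chunk": i}

async def via_sleep_zero():
    # Option 1: iterate the sync generator on the event loop and hand
    # control back after every chunk with a zero-length sleep.
    for chunk in blocking_stream():
        await asyncio.sleep(0)
        yield chunk

async def via_worker_thread():
    # Option 2: pull each chunk in a worker thread so the event loop
    # never blocks while the generator is doing work.
    gen = blocking_stream()
    while (chunk := await asyncio.to_thread(next, gen, None)) is not None:
        yield chunk

async def main():
    async for chunk in via_sleep_zero():
        print("sleep(0):", chunk)
    async for chunk in via_worker_thread():
        print("thread:", chunk)

asyncio.run(main())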

Signed-off-by: kingbri <bdashore3@proton.me>
kingbri 2024-03-14 10:27:39 -04:00 committed by Brian Dashore
parent 33e2df50b7
commit 7fded4f183
10 changed files with 84 additions and 88 deletions

View file

@ -1,5 +1,6 @@
"""The model container class for ExLlamaV2 models."""
import asyncio
import gc
from itertools import zip_longest
import pathlib
@ -325,7 +326,7 @@ class ExllamaV2Container:
return model_params
def load(self, progress_callback=None):
async def load(self, progress_callback=None):
"""
Load model
@ -338,7 +339,7 @@ class ExllamaV2Container:
for _ in self.load_gen(progress_callback):
pass
def load_loras(self, lora_directory: pathlib.Path, **kwargs):
async def load_loras(self, lora_directory: pathlib.Path, **kwargs):
"""
Load loras
"""
@ -361,7 +362,7 @@ class ExllamaV2Container:
logger.info(f"Loading lora: {lora_name} at scaling {lora_scaling}")
lora_path = lora_directory / lora_name
# FIXME(alpin): Does self.model need to be passed here?
self.active_loras.append(
ExLlamaV2Lora.from_directory(self.model, lora_path, lora_scaling)
)
@ -371,7 +372,7 @@ class ExllamaV2Container:
# Return success and failure names
return {"success": success, "failure": failure}
def load_gen(self, progress_callback=None):
async def load_gen(self, progress_callback=None):
"""
Load model, generator function
@ -400,12 +401,16 @@ class ExllamaV2Container:
logger.info("Loading draft model: " + self.draft_config.model_dir)
self.draft_cache = ExLlamaV2Cache(self.draft_model, lazy=True)
yield from self.draft_model.load_autosplit_gen(
for value in self.draft_model.load_autosplit_gen(
self.draft_cache,
reserve_vram=autosplit_reserve,
last_id_only=True,
callback_gen=progress_callback,
)
):
# Manually suspend the task to allow for other stuff to run
await asyncio.sleep(0)
if value:
yield value
# Test VRAM allocation with a full-length forward pass
input_ids = torch.zeros((1, self.config.max_input_len), dtype=torch.long)
@ -424,6 +429,8 @@ class ExllamaV2Container:
self.gpu_split,
callback_gen=progress_callback,
):
# Manually suspend the task to allow for other stuff to run
await asyncio.sleep(0)
if value:
yield value
@ -452,6 +459,8 @@ class ExllamaV2Container:
last_id_only=True,
callback_gen=progress_callback,
):
# Manually suspend the task to allow for other stuff to run
await asyncio.sleep(0)
if value:
yield value
@ -565,9 +574,11 @@ class ExllamaV2Container:
return dict(zip_longest(top_tokens, cleaned_values))
def generate(self, prompt: str, **kwargs):
async def generate(self, prompt: str, **kwargs):
"""Generate a response to a prompt"""
generations = list(self.generate_gen(prompt, **kwargs))
generations = []
async for generation in self.generate_gen(prompt, **kwargs):
generations.append(generation)
joined_generation = {
"text": "",
@ -615,8 +626,7 @@ class ExllamaV2Container:
return kwargs
# pylint: disable=too-many-locals,too-many-branches,too-many-statements
def generate_gen(self, prompt: str, **kwargs):
async def generate_gen(self, prompt: str, **kwargs):
"""
Create generator function for prompt completion
@ -889,6 +899,9 @@ class ExllamaV2Container:
chunk_tokens = 0
while True:
# Manually suspend the task to allow for other stuff to run
await asyncio.sleep(0)
# Ingest prompt
if chunk_tokens == 0:
ids = torch.cat((ids, save_tokens), dim=-1)
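
The recurring change in this file, reduced to a minimal standalone
sketch: each async generator iterates ExLlamaV2's synchronous generator
and awaits asyncio.sleep(0) between values so other tasks get a turn.
Note that sleep(0) only yields control between chunks; the blocking
work itself still runs on the event loop thread, which is why a
separate thread remains under discussion. The stub below is a
placeholder, not the real ExLlamaV2 API:

import asyncio

def load_autosplit_stub(total_modules: int = 4):
    """Synchronous stand-in for load_autosplit_gen, yielding progress tuples."""
    for module in range(1, total_modules + 1):
        yield module, total_modules

async def load_gen_async():
    """Async wrapper: re-yield each progress value, suspending in between."""
    for value in load_autosplit_stub():
        # Manually suspend the task to allow for other stuff to run
        await asyncio.sleep(0)
        if value:
            yield value

async def main():
    async for module, modules in load_gen_async():
        print(f"loaded module {module}/{modules}")

asyncio.run(main())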

View file

@ -1,4 +1,4 @@
"""Generator handling"""
"""Concurrency handling"""
import asyncio
import inspect
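
The renamed common.concurrency module (formerly common.generators)
supplies the call_with_semaphore, generate_with_semaphore, and
release_semaphore helpers referenced by the endpoint changes below.
Its implementation is not part of this diff; the following is a
hypothetical sketch only, assuming a single global asyncio.Semaphore
guarding generation:

import asyncio

# Hypothetical sketch; the real module is not shown in this diff.
generate_semaphore = asyncio.Semaphore(1)

async def call_with_semaphore(callback):
    """Run an awaitable callback while holding the generation semaphore."""
    async with generate_semaphore:
        return await callback()

async def generate_with_semaphore(generator):
    """Hold the semaphore for the lifetime of an async generator."""
    async with generate_semaphore:
        async for result in generator():
            yield result

def release_semaphore():
    """Early release when a request is cancelled mid-generation.
    (A real implementation would also guard against double release.)"""
    if generate_semaphore.locked():
        generate_semaphore.release()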

View file

@ -52,7 +52,7 @@ async def load_model_gen(model_path: pathlib.Path, **kwargs):
progress.start()
try:
for module, modules in load_status:
async for module, modules in load_status:
if module == 0:
loading_task = progress.add_task(
f"[cyan]Loading {model_type} modules", total=modules
@ -76,12 +76,12 @@ async def load_model(model_path: pathlib.Path, **kwargs):
pass
def load_loras(lora_dir, **kwargs):
async def load_loras(lora_dir, **kwargs):
"""Wrapper to load loras."""
if len(container.active_loras) > 0:
unload_loras()
return container.load_loras(lora_dir, **kwargs)
return await container.load_loras(lora_dir, **kwargs)
def unload_loras():

View file

@ -6,6 +6,8 @@ from loguru import logger
from pydantic import BaseModel
from typing import Optional
from common.concurrency import release_semaphore
def load_progress(module, modules):
"""Wrapper callback for load progress."""
@ -51,6 +53,13 @@ def handle_request_error(message: str, exc_info: bool = True):
return request_error
def handle_request_disconnect(message: str):
"""Wrapper for handling for request disconnection."""
release_semaphore()
logger.error(message)
def unwrap(wrapped, default=None):
"""Unwrap function for Optionals."""
if wrapped is None:

View file

@ -1,7 +1,6 @@
import pathlib
import uvicorn
from fastapi import FastAPI, Depends, HTTPException, Request
from fastapi.concurrency import run_in_threadpool
from fastapi.middleware.cors import CORSMiddleware
from functools import partial
from loguru import logger
@ -10,7 +9,7 @@ from sys import maxsize
from common import config, model, gen_logging, sampling
from common.auth import check_admin_key, check_api_key
from common.generators import (
from common.concurrency import (
call_with_semaphore,
generate_with_semaphore,
)
@ -181,9 +180,7 @@ async def load_model(request: Request, data: ModelLoadRequest):
if not model_path.exists():
raise HTTPException(400, "model_path does not exist. Check model_name?")
load_callback = partial(
stream_model_load, request, data, model_path, draft_model_path
)
load_callback = partial(stream_model_load, data, model_path, draft_model_path)
# Wrap in a semaphore if the queue isn't being skipped
if data.skip_queue:
@ -333,9 +330,7 @@ async def load_lora(data: LoraLoadRequest):
"A parent lora directory does not exist. Check your config.yml?",
)
load_callback = partial(
run_in_threadpool, model.load_loras, lora_dir, **data.model_dump()
)
load_callback = partial(model.load_loras, lora_dir, **data.model_dump())
# Wrap in a semaphore if the queue isn't being skipped
if data.skip_queue:
@ -409,9 +404,7 @@ async def completion_request(request: Request, data: CompletionRequest):
)
if data.stream and not disable_request_streaming:
generator_callback = partial(
stream_generate_completion, request, data, model_path
)
generator_callback = partial(stream_generate_completion, data, model_path)
return EventSourceResponse(
generate_with_semaphore(generator_callback),
@ -452,7 +445,7 @@ async def chat_completion_request(request: Request, data: ChatCompletionRequest)
if data.stream and not disable_request_streaming:
generator_callback = partial(
stream_generate_chat_completion, prompt, request, data, model_path
stream_generate_chat_completion, prompt, data, model_path
)
return EventSourceResponse(
@ -461,13 +454,13 @@ async def chat_completion_request(request: Request, data: ChatCompletionRequest)
)
else:
response = await call_with_semaphore(
partial(generate_chat_completion, prompt, request, data, model_path)
partial(generate_chat_completion, prompt, data, model_path)
)
return response
def start_api(host: str, port: int):
async def start_api(host: str, port: int):
"""Isolated function to start the API server"""
# TODO: Move OAI API to a separate folder
@ -475,9 +468,12 @@ def start_api(host: str, port: int):
logger.info(f"Completions: http://{host}:{port}/v1/completions")
logger.info(f"Chat completions: http://{host}:{port}/v1/chat/completions")
uvicorn.run(
config = uvicorn.Config(
app,
host=host,
port=port,
log_config=UVICORN_LOG_CONFIG,
)
server = uvicorn.Server(config)
await server.serve()
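
uvicorn.run() starts and owns its own event loop, so it cannot be
called from inside an already-running loop like the async entrypoint.
The programmatic uvicorn.Config / uvicorn.Server API used above can
simply be awaited instead. A minimal standalone sketch of the same
pattern (the host, port, route, and log settings here are
placeholders):

import asyncio

import uvicorn
from fastapi import FastAPI

app = FastAPI()

@app.get("/health")
async def health():
    return {"status": "ok"}

async def start_api(host: str, port: int):
    # Server.serve() is a coroutine, so it runs on the caller's event
    # loop alongside every other task instead of spawning its own loop.
    config = uvicorn.Config(app, host=host, port=port, log_config=None)
    server = uvicorn.Server(config)
    await server.serve()

if __name__ == "__main__":
    asyncio.run(start_api("127.0.0.1", 8000))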

View file

@ -1,18 +1,21 @@
"""Chat completion utilities for OAI server."""
from asyncio import CancelledError
import pathlib
from typing import Optional
from uuid import uuid4
from fastapi import HTTPException, Request
from fastapi.concurrency import run_in_threadpool
from fastapi import HTTPException
from jinja2 import TemplateError
from loguru import logger
from common import model
from common.generators import release_semaphore
from common.templating import get_prompt_from_template
from common.utils import get_generator_error, handle_request_error, unwrap
from common.utils import (
get_generator_error,
handle_request_disconnect,
handle_request_error,
unwrap,
)
from endpoints.OAI.types.chat_completion import (
ChatCompletionLogprobs,
ChatCompletionLogprob,
@ -150,20 +153,14 @@ def format_prompt_with_template(data: ChatCompletionRequest):
async def stream_generate_chat_completion(
prompt: str, request: Request, data: ChatCompletionRequest, model_path: pathlib.Path
prompt: str, data: ChatCompletionRequest, model_path: pathlib.Path
):
"""Generator for the generation process."""
try:
const_id = f"chatcmpl-{uuid4().hex}"
new_generation = model.container.generate_gen(prompt, **data.to_gen_params())
for generation in new_generation:
# Get out if the request gets disconnected
if await request.is_disconnected():
release_semaphore()
logger.error("Chat completion generation cancelled by user.")
return
async for generation in new_generation:
response = _create_stream_chunk(const_id, generation, model_path.name)
yield response.model_dump_json()
@ -172,6 +169,10 @@ async def stream_generate_chat_completion(
finish_response = _create_stream_chunk(const_id, finish_reason="stop")
yield finish_response.model_dump_json()
except CancelledError:
# Get out if the request gets disconnected
handle_request_disconnect("Chat completion generation cancelled by user.")
except Exception:
yield get_generator_error(
"Chat completion aborted. Please check the server console."
@ -179,11 +180,10 @@ async def stream_generate_chat_completion(
async def generate_chat_completion(
prompt: str, request: Request, data: ChatCompletionRequest, model_path: pathlib.Path
prompt: str, data: ChatCompletionRequest, model_path: pathlib.Path
):
try:
generation = await run_in_threadpool(
model.container.generate,
generation = await model.container.generate(
prompt,
**data.to_gen_params(),
)
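
The streaming endpoints now rely on the same mechanism: when the
client disconnects, the task driving the generator is cancelled, so a
CancelledError surfaces at the generator's current await point and
replaces the old request.is_disconnected() polling. A minimal
standalone sketch of that shape (the stub generator and event strings
are placeholders):

import asyncio

async def token_stream():
    """Stands in for model.container.generate_gen()."""
    for token in ("Hello", ",", " world"):
        await asyncio.sleep(0)
        yield token

async def stream_events():
    """Streaming generator shaped like the completion endpoints."""
    try:
        async for token in token_stream():
            yield token  # the real endpoints yield response.model_dump_json()
        yield "[DONE]"
    except asyncio.CancelledError:
        # Raised here when the consuming task is cancelled, e.g. by a
        # client disconnect; release resources and stop generating.
        print("generation cancelled by client disconnect")

async def main():
    async def consume():
        async for event in stream_events():
            print(event)

    await consume()  # normal run to completion

    # Simulate a disconnect: cancel the consumer mid-stream.
    task = asyncio.create_task(consume())
    await asyncio.sleep(0)
    task.cancel()
    await asyncio.gather(task, return_exceptions=True)

asyncio.run(main())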

View file

@ -1,14 +1,17 @@
"""Completion utilities for OAI server."""
from asyncio import CancelledError
import pathlib
from fastapi import HTTPException, Request
from fastapi.concurrency import run_in_threadpool
from loguru import logger
from fastapi import HTTPException
from typing import Optional
from common import model
from common.generators import release_semaphore
from common.utils import get_generator_error, handle_request_error, unwrap
from common.utils import (
get_generator_error,
handle_request_disconnect,
handle_request_error,
unwrap,
)
from endpoints.OAI.types.completion import (
CompletionRequest,
CompletionResponse,
@ -57,28 +60,24 @@ def _create_response(generation: dict, model_name: Optional[str]):
return response
async def stream_generate_completion(
request: Request, data: CompletionRequest, model_path: pathlib.Path
):
async def stream_generate_completion(data: CompletionRequest, model_path: pathlib.Path):
"""Streaming generation for completions."""
try:
new_generation = model.container.generate_gen(
data.prompt, **data.to_gen_params()
)
for generation in new_generation:
# Get out if the request gets disconnected
if await request.is_disconnected():
release_semaphore()
logger.error("Completion generation cancelled by user.")
return
async for generation in new_generation:
response = _create_response(generation, model_path.name)
yield response.model_dump_json()
# Yield a finish response on successful generation
yield "[DONE]"
except CancelledError:
# Get out if the request gets disconnected
handle_request_disconnect("Completion generation cancelled by user.")
except Exception:
yield get_generator_error(
"Completion aborted. Please check the server console."
@ -89,9 +88,7 @@ async def generate_completion(data: CompletionRequest, model_path: pathlib.Path)
"""Non-streaming generate for completions"""
try:
generation = await run_in_threadpool(
model.container.generate, data.prompt, **data.to_gen_params()
)
generation = await model.container.generate(data.prompt, **data.to_gen_params())
response = _create_response(generation, model_path.name)
return response

View file

@ -9,6 +9,6 @@ def get_lora_list(lora_path: pathlib.Path):
for path in lora_path.iterdir():
if path.is_dir():
lora_card = LoraCard(id=path.name)
lora_list.data.append(lora_card) # pylint: disable=no-member
lora_list.data.append(lora_card)
return lora_list

View file

@ -1,12 +1,9 @@
import pathlib
from asyncio import CancelledError
from fastapi import Request
from loguru import logger
from typing import Optional
from common import model
from common.generators import release_semaphore
from common.utils import get_generator_error
from common.utils import get_generator_error, handle_request_disconnect
from endpoints.OAI.types.model import (
ModelCard,
@ -35,7 +32,6 @@ def get_model_list(model_path: pathlib.Path, draft_model_path: Optional[str] = N
async def stream_model_load(
request: Request,
data: ModelLoadRequest,
model_path: pathlib.Path,
draft_model_path: str,
@ -50,14 +46,6 @@ async def stream_model_load(
load_status = model.load_model_gen(model_path, **load_data)
try:
async for module, modules, model_type in load_status:
if await request.is_disconnected():
release_semaphore()
logger.error(
"Model load cancelled by user. "
"Please make sure to run unload to free up resources."
)
return
if module != 0:
response = ModelLoadResponse(
model_type=model_type,
@ -78,7 +66,9 @@ async def stream_model_load(
yield response.model_dump_json()
except CancelledError:
logger.error(
# Get out if the request gets disconnected
handle_request_disconnect(
"Model load cancelled by user. "
"Please make sure to run unload to free up resources."
)

main.py
View file

@ -5,9 +5,6 @@ import os
import pathlib
import signal
import sys
import threading
import time
from functools import partial
from loguru import logger
from typing import Optional
@ -121,13 +118,7 @@ async def entrypoint(args: Optional[dict] = None):
lora_dir = pathlib.Path(unwrap(lora_config.get("lora_dir"), "loras"))
model.container.load_loras(lora_dir.resolve(), **lora_config)
# TODO: Replace this with abortables, async via producer consumer, or something else
api_thread = threading.Thread(target=partial(start_api, host, port), daemon=True)
api_thread.start()
# Keep the program alive
while api_thread.is_alive():
time.sleep(0.5)
await start_api(host, port)
if __name__ == "__main__":
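
With start_api now async, the entrypoint in main.py can await the
server directly on one event loop; the daemon thread and the
is_alive() polling loop become unnecessary. A minimal standalone
sketch of that shape (serve_forever below is a trivial stand-in for
start_api):

import asyncio

async def serve_forever():
    """Trivial stand-in for start_api(); sleeps instead of serving requests."""
    while True:
        await asyncio.sleep(1)

async def entrypoint():
    # Config parsing and model loading would happen here, then the
    # server coroutine is awaited. No api_thread, no time.sleep(0.5)
    # keep-alive loop; the awaited coroutine keeps the process alive.
    await serve_forever()

if __name__ == "__main__":
    # Ctrl+C raises KeyboardInterrupt and asyncio.run() tears the loop down.
    asyncio.run(entrypoint())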