Tree: Switch to async generators
Switching to async generators removes many of the roadblocks that come with managing tasks on threads, and it opens the door to abortable requests and other modern async patterns.

NOTE: ExLlamaV2 itself is not an asynchronous library. Its generators are simply wrapped into tabby's async event loop to allow for a fast and concurrent API server. It is still being debated whether to run stream_ex in a separate thread or to manage it manually with asyncio.sleep(0).

Signed-off-by: kingbri <bdashore3@proton.me>
parent 33e2df50b7
commit 7fded4f183
10 changed files with 84 additions and 88 deletions
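The idea described in the commit message, wrapping a synchronous ExLlamaV2 generator in an async generator that hands control back to the event loop between iterations, can be sketched roughly as follows. This is only a minimal illustration; blocking_token_stream and async_token_stream are hypothetical names, not tabby or ExLlamaV2 APIs.

import asyncio


def blocking_token_stream(prompt: str):
    """Stand-in for a synchronous, blocking generator (e.g. a model's token stream)."""
    for token in prompt.split():
        yield token


async def async_token_stream(prompt: str):
    """Wrap the sync generator so other coroutines can run between iterations."""
    for token in blocking_token_stream(prompt):
        # Manually suspend the task to allow for other stuff to run,
        # mirroring the await asyncio.sleep(0) calls added in this commit.
        await asyncio.sleep(0)
        yield token


async def main():
    async for token in async_token_stream("hello async tabby"):
        print(token)


if __name__ == "__main__":
    asyncio.run(main())

The trade-off noted in the commit message still applies: asyncio.sleep(0) only yields between iterations, so a long step inside the wrapped generator still blocks the event loop, which is why running the stream in a separate thread remains under discussion.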
@@ -1,5 +1,6 @@
 """The model container class for ExLlamaV2 models."""

+import asyncio
 import gc
 from itertools import zip_longest
 import pathlib

@@ -325,7 +326,7 @@ class ExllamaV2Container:

         return model_params

-    def load(self, progress_callback=None):
+    async def load(self, progress_callback=None):
         """
         Load model

@@ -338,7 +339,7 @@ class ExllamaV2Container:
         for _ in self.load_gen(progress_callback):
             pass

-    def load_loras(self, lora_directory: pathlib.Path, **kwargs):
+    async def load_loras(self, lora_directory: pathlib.Path, **kwargs):
         """
         Load loras
         """

@@ -361,7 +362,7 @@ class ExllamaV2Container:
             logger.info(f"Loading lora: {lora_name} at scaling {lora_scaling}")
             lora_path = lora_directory / lora_name
             # FIXME(alpin): Does self.model need to be passed here?
             self.active_loras.append(
                 ExLlamaV2Lora.from_directory(self.model, lora_path, lora_scaling)
             )

@@ -371,7 +372,7 @@ class ExllamaV2Container:
         # Return success and failure names
         return {"success": success, "failure": failure}

-    def load_gen(self, progress_callback=None):
+    async def load_gen(self, progress_callback=None):
         """
         Load model, generator function

@@ -400,12 +401,16 @@ class ExllamaV2Container:
             logger.info("Loading draft model: " + self.draft_config.model_dir)

             self.draft_cache = ExLlamaV2Cache(self.draft_model, lazy=True)
-            yield from self.draft_model.load_autosplit_gen(
+            for value in self.draft_model.load_autosplit_gen(
                 self.draft_cache,
                 reserve_vram=autosplit_reserve,
                 last_id_only=True,
                 callback_gen=progress_callback,
-            )
+            ):
+                # Manually suspend the task to allow for other stuff to run
+                await asyncio.sleep(0)
+                if value:
+                    yield value

             # Test VRAM allocation with a full-length forward pass
             input_ids = torch.zeros((1, self.config.max_input_len), dtype=torch.long)

@@ -424,6 +429,8 @@ class ExllamaV2Container:
             self.gpu_split,
             callback_gen=progress_callback,
         ):
+            # Manually suspend the task to allow for other stuff to run
+            await asyncio.sleep(0)
             if value:
                 yield value

@@ -452,6 +459,8 @@ class ExllamaV2Container:
             last_id_only=True,
             callback_gen=progress_callback,
         ):
+            # Manually suspend the task to allow for other stuff to run
+            await asyncio.sleep(0)
             if value:
                 yield value

@@ -565,9 +574,11 @@ class ExllamaV2Container:

         return dict(zip_longest(top_tokens, cleaned_values))

-    def generate(self, prompt: str, **kwargs):
+    async def generate(self, prompt: str, **kwargs):
         """Generate a response to a prompt"""
-        generations = list(self.generate_gen(prompt, **kwargs))
+        generations = []
+        async for generation in self.generate_gen(prompt, **kwargs):
+            generations.append(generation)

         joined_generation = {
             "text": "",

@@ -615,8 +626,7 @@ class ExllamaV2Container:

         return kwargs

-    # pylint: disable=too-many-locals,too-many-branches,too-many-statements
-    def generate_gen(self, prompt: str, **kwargs):
+    async def generate_gen(self, prompt: str, **kwargs):
         """
         Create generator function for prompt completion

@@ -889,6 +899,9 @@ class ExllamaV2Container:
         chunk_tokens = 0

         while True:
+            # Manually suspend the task to allow for other stuff to run
+            await asyncio.sleep(0)
+
             # Ingest prompt
             if chunk_tokens == 0:
                 ids = torch.cat((ids, save_tokens), dim=-1)
@@ -1,4 +1,4 @@
-"""Generator handling"""
+"""Concurrency handling"""

 import asyncio
 import inspect

@@ -52,7 +52,7 @@ async def load_model_gen(model_path: pathlib.Path, **kwargs):
     progress.start()

     try:
-        for module, modules in load_status:
+        async for module, modules in load_status:
             if module == 0:
                 loading_task = progress.add_task(
                     f"[cyan]Loading {model_type} modules", total=modules

@@ -76,12 +76,12 @@ async def load_model(model_path: pathlib.Path, **kwargs):
         pass


-def load_loras(lora_dir, **kwargs):
+async def load_loras(lora_dir, **kwargs):
     """Wrapper to load loras."""
     if len(container.active_loras) > 0:
         unload_loras()

-    return container.load_loras(lora_dir, **kwargs)
+    return await container.load_loras(lora_dir, **kwargs)


 def unload_loras():
@@ -6,6 +6,8 @@ from loguru import logger
 from pydantic import BaseModel
 from typing import Optional

+from common.concurrency import release_semaphore
+

 def load_progress(module, modules):
     """Wrapper callback for load progress."""

@@ -51,6 +53,13 @@ def handle_request_error(message: str, exc_info: bool = True):
     return request_error


+def handle_request_disconnect(message: str):
+    """Wrapper for handling for request disconnection."""
+
+    release_semaphore()
+    logger.error(message)
+
+
 def unwrap(wrapped, default=None):
     """Unwrap function for Optionals."""
     if wrapped is None:
@@ -1,7 +1,6 @@
 import pathlib
 import uvicorn
 from fastapi import FastAPI, Depends, HTTPException, Request
-from fastapi.concurrency import run_in_threadpool
 from fastapi.middleware.cors import CORSMiddleware
 from functools import partial
 from loguru import logger

@@ -10,7 +9,7 @@ from sys import maxsize

 from common import config, model, gen_logging, sampling
 from common.auth import check_admin_key, check_api_key
-from common.generators import (
+from common.concurrency import (
     call_with_semaphore,
     generate_with_semaphore,
 )

@@ -181,9 +180,7 @@ async def load_model(request: Request, data: ModelLoadRequest):
     if not model_path.exists():
         raise HTTPException(400, "model_path does not exist. Check model_name?")

-    load_callback = partial(
-        stream_model_load, request, data, model_path, draft_model_path
-    )
+    load_callback = partial(stream_model_load, data, model_path, draft_model_path)

     # Wrap in a semaphore if the queue isn't being skipped
     if data.skip_queue:

@@ -333,9 +330,7 @@ async def load_lora(data: LoraLoadRequest):
             "A parent lora directory does not exist. Check your config.yml?",
         )

-    load_callback = partial(
-        run_in_threadpool, model.load_loras, lora_dir, **data.model_dump()
-    )
+    load_callback = partial(model.load_loras, lora_dir, **data.model_dump())

     # Wrap in a semaphore if the queue isn't being skipped
     if data.skip_queue:

@@ -409,9 +404,7 @@ async def completion_request(request: Request, data: CompletionRequest):
         )

     if data.stream and not disable_request_streaming:
-        generator_callback = partial(
-            stream_generate_completion, request, data, model_path
-        )
+        generator_callback = partial(stream_generate_completion, data, model_path)

         return EventSourceResponse(
             generate_with_semaphore(generator_callback),

@@ -452,7 +445,7 @@ async def chat_completion_request(request: Request, data: ChatCompletionRequest)

     if data.stream and not disable_request_streaming:
         generator_callback = partial(
-            stream_generate_chat_completion, prompt, request, data, model_path
+            stream_generate_chat_completion, prompt, data, model_path
         )

         return EventSourceResponse(

@@ -461,13 +454,13 @@ async def chat_completion_request(request: Request, data: ChatCompletionRequest)
         )
     else:
         response = await call_with_semaphore(
-            partial(generate_chat_completion, prompt, request, data, model_path)
+            partial(generate_chat_completion, prompt, data, model_path)
         )

         return response


-def start_api(host: str, port: int):
+async def start_api(host: str, port: int):
     """Isolated function to start the API server"""

     # TODO: Move OAI API to a separate folder

@@ -475,9 +468,12 @@ def start_api(host: str, port: int):
     logger.info(f"Completions: http://{host}:{port}/v1/completions")
     logger.info(f"Chat completions: http://{host}:{port}/v1/chat/completions")

-    uvicorn.run(
+    config = uvicorn.Config(
         app,
         host=host,
         port=port,
         log_config=UVICORN_LOG_CONFIG,
     )
+    server = uvicorn.Server(config)
+
+    await server.serve()
@@ -1,18 +1,21 @@
 """Chat completion utilities for OAI server."""

+from asyncio import CancelledError
 import pathlib
 from typing import Optional
 from uuid import uuid4

-from fastapi import HTTPException, Request
-from fastapi.concurrency import run_in_threadpool
+from fastapi import HTTPException
 from jinja2 import TemplateError
 from loguru import logger

 from common import model
-from common.generators import release_semaphore
 from common.templating import get_prompt_from_template
-from common.utils import get_generator_error, handle_request_error, unwrap
+from common.utils import (
+    get_generator_error,
+    handle_request_disconnect,
+    handle_request_error,
+    unwrap,
+)
 from endpoints.OAI.types.chat_completion import (
     ChatCompletionLogprobs,
     ChatCompletionLogprob,

@@ -150,20 +153,14 @@ def format_prompt_with_template(data: ChatCompletionRequest):


 async def stream_generate_chat_completion(
-    prompt: str, request: Request, data: ChatCompletionRequest, model_path: pathlib.Path
+    prompt: str, data: ChatCompletionRequest, model_path: pathlib.Path
 ):
     """Generator for the generation process."""
     try:
         const_id = f"chatcmpl-{uuid4().hex}"

         new_generation = model.container.generate_gen(prompt, **data.to_gen_params())
-        for generation in new_generation:
-            # Get out if the request gets disconnected
-            if await request.is_disconnected():
-                release_semaphore()
-                logger.error("Chat completion generation cancelled by user.")
-                return
-
+        async for generation in new_generation:
             response = _create_stream_chunk(const_id, generation, model_path.name)

             yield response.model_dump_json()

@@ -172,6 +169,10 @@ async def stream_generate_chat_completion(
         finish_response = _create_stream_chunk(const_id, finish_reason="stop")

         yield finish_response.model_dump_json()
+    except CancelledError:
+        # Get out if the request gets disconnected
+
+        handle_request_disconnect("Chat completion generation cancelled by user.")
     except Exception:
         yield get_generator_error(
             "Chat completion aborted. Please check the server console."

@@ -179,11 +180,10 @@ async def stream_generate_chat_completion(


 async def generate_chat_completion(
-    prompt: str, request: Request, data: ChatCompletionRequest, model_path: pathlib.Path
+    prompt: str, data: ChatCompletionRequest, model_path: pathlib.Path
 ):
     try:
-        generation = await run_in_threadpool(
-            model.container.generate,
+        generation = await model.container.generate(
             prompt,
             **data.to_gen_params(),
         )
@@ -1,14 +1,17 @@
 """Completion utilities for OAI server."""

+from asyncio import CancelledError
 import pathlib
-from fastapi import HTTPException, Request
-from fastapi.concurrency import run_in_threadpool
-from loguru import logger
+from fastapi import HTTPException
 from typing import Optional

 from common import model
-from common.generators import release_semaphore
-from common.utils import get_generator_error, handle_request_error, unwrap
+from common.utils import (
+    get_generator_error,
+    handle_request_disconnect,
+    handle_request_error,
+    unwrap,
+)
 from endpoints.OAI.types.completion import (
     CompletionRequest,
     CompletionResponse,

@@ -57,28 +60,24 @@ def _create_response(generation: dict, model_name: Optional[str]):
     return response


-async def stream_generate_completion(
-    request: Request, data: CompletionRequest, model_path: pathlib.Path
-):
+async def stream_generate_completion(data: CompletionRequest, model_path: pathlib.Path):
     """Streaming generation for completions."""

     try:
         new_generation = model.container.generate_gen(
             data.prompt, **data.to_gen_params()
         )
-        for generation in new_generation:
-            # Get out if the request gets disconnected
-            if await request.is_disconnected():
-                release_semaphore()
-                logger.error("Completion generation cancelled by user.")
-                return
-
+        async for generation in new_generation:
             response = _create_response(generation, model_path.name)

             yield response.model_dump_json()

         # Yield a finish response on successful generation
         yield "[DONE]"
+    except CancelledError:
+        # Get out if the request gets disconnected
+
+        handle_request_disconnect("Completion generation cancelled by user.")
     except Exception:
         yield get_generator_error(
             "Completion aborted. Please check the server console."

@@ -89,9 +88,7 @@ async def generate_completion(data: CompletionRequest, model_path: pathlib.Path)
     """Non-streaming generate for completions"""

     try:
-        generation = await run_in_threadpool(
-            model.container.generate, data.prompt, **data.to_gen_params()
-        )
+        generation = await model.container.generate(data.prompt, **data.to_gen_params())

         response = _create_response(generation, model_path.name)
         return response
@@ -9,6 +9,6 @@ def get_lora_list(lora_path: pathlib.Path):
     for path in lora_path.iterdir():
         if path.is_dir():
             lora_card = LoraCard(id=path.name)
-            lora_list.data.append(lora_card)  # pylint: disable=no-member
+            lora_list.data.append(lora_card)

     return lora_list
@@ -1,12 +1,9 @@
 import pathlib
+from asyncio import CancelledError
-from fastapi import Request
-from loguru import logger
 from typing import Optional

 from common import model
-from common.generators import release_semaphore
-from common.utils import get_generator_error
+from common.utils import get_generator_error, handle_request_disconnect

 from endpoints.OAI.types.model import (
     ModelCard,

@@ -35,7 +32,6 @@ def get_model_list(model_path: pathlib.Path, draft_model_path: Optional[str] = N


 async def stream_model_load(
-    request: Request,
     data: ModelLoadRequest,
     model_path: pathlib.Path,
     draft_model_path: str,

@@ -50,14 +46,6 @@ async def stream_model_load(
     load_status = model.load_model_gen(model_path, **load_data)
     try:
         async for module, modules, model_type in load_status:
-            if await request.is_disconnected():
-                release_semaphore()
-                logger.error(
-                    "Model load cancelled by user. "
-                    "Please make sure to run unload to free up resources."
-                )
-                return
-
             if module != 0:
                 response = ModelLoadResponse(
                     model_type=model_type,

@@ -78,7 +66,9 @@ async def stream_model_load(

             yield response.model_dump_json()
     except CancelledError:
-        logger.error(
+        # Get out if the request gets disconnected
+
+        handle_request_disconnect(
             "Model load cancelled by user. "
             "Please make sure to run unload to free up resources."
         )
main.py

@@ -5,9 +5,6 @@ import os
 import pathlib
 import signal
 import sys
-import threading
-import time
-from functools import partial
 from loguru import logger
 from typing import Optional

@@ -121,13 +118,7 @@ async def entrypoint(args: Optional[dict] = None):
         lora_dir = pathlib.Path(unwrap(lora_config.get("lora_dir"), "loras"))
         model.container.load_loras(lora_dir.resolve(), **lora_config)

-    # TODO: Replace this with abortables, async via producer consumer, or something else
-    api_thread = threading.Thread(target=partial(start_api, host, port), daemon=True)
-
-    api_thread.start()
-    # Keep the program alive
-    while api_thread.is_alive():
-        time.sleep(0.5)
+    await start_api(host, port)


 if __name__ == "__main__":