API + Model: Add blocks and checks for various load requests

Add a sequential lock and wait until jobs are completed before executing
any load request that directly alters the model. However, we also
need to block any new requests that come in until the load is finished,
so add a condition that notifies waiting tasks once the lock is freed.

Signed-off-by: kingbri <bdashore3@proton.me>
kingbri 2024-05-25 18:24:11 -04:00 committed by Brian Dashore
parent 408c66a1f2
commit 43cd7f57e8
5 changed files with 268 additions and 249 deletions
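Schematically, the pattern this commit introduces looks like the minimal sketch below. It is a standalone illustration: the Container class and its load/generate methods are invented stand-ins, not the project's actual API.

import asyncio

class Container:
    # The lock keeps load tasks sequential; the condition lets request
    # handlers sleep until the lock is free again.
    load_lock: asyncio.Lock = asyncio.Lock()
    load_condition: asyncio.Condition = asyncio.Condition()

    async def load(self):
        try:
            await self.load_lock.acquire()
            ...  # wait for in-flight jobs, then swap model state
        finally:
            self.load_lock.release()
            # Notify only after releasing, so woken waiters see a free lock
            async with self.load_condition:
                self.load_condition.notify_all()

    async def generate(self):
        # New requests block here until no load task holds the lock
        async with self.load_condition:
            await self.load_condition.wait_for(
                lambda: not self.load_lock.locked()
            )
        ...  # safe to run the request against the loaded model

Requests that are already generating are handled separately: the loader waits for them via the wait_for_jobs polling helper shown in the diff below.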


@@ -1,5 +1,6 @@
"""The model container class for ExLlamaV2 models."""
import asyncio
import gc
import math
import pathlib
@@ -54,7 +55,6 @@ class ExllamaV2Container:
tokenizer: Optional[ExLlamaV2Tokenizer] = None
generator: Optional[ExLlamaV2DynamicGeneratorAsync] = None
prompt_template: Optional[PromptTemplate] = None
active_loras: List[ExLlamaV2Lora] = []
paged: bool = True
# Internal config vars
@@ -71,6 +71,12 @@ class ExllamaV2Container:
model_is_loading: bool = False
model_loaded: bool = False
# Load synchronization
# The lock keeps load tasks sequential
# The condition notifies any waiting tasks
load_lock: asyncio.Lock = asyncio.Lock()
load_condition: asyncio.Condition = asyncio.Condition()
def __init__(self, model_directory: pathlib.Path, quiet=False, **kwargs):
"""
Create model container
@@ -348,6 +354,22 @@ class ExllamaV2Container:
return model_params
async def wait_for_jobs(self, skip_wait: bool = False):
"""Polling mechanism to wait for pending generation jobs."""
if not self.generator:
return
# Immediately abort all jobs if asked
if skip_wait:
# Requires a copy to avoid errors during iteration
jobs_copy = self.generator.jobs.copy()
for job in jobs_copy.values():
await job.cancel()
while self.generator.jobs:
await asyncio.sleep(0.01)
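The copy-before-iteration comment above guards against a standard Python pitfall: a dict cannot change size while it is being iterated, and cancelling a job presumably removes it from the generator's jobs table. A minimal illustration with a toy dict (not the generator's real job objects):

jobs = {"a": 1, "b": 2}

# for job_id in jobs:
#     jobs.pop(job_id)  # RuntimeError: dictionary changed size during iteration

for job_id in jobs.copy():  # iterate a snapshot, mutate the original safely
    jobs.pop(job_id)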
async def load(self, progress_callback=None):
"""
Load model
@@ -361,45 +383,18 @@ class ExllamaV2Container:
async for _ in self.load_gen(progress_callback):
pass
async def load_loras(self, lora_directory: pathlib.Path, **kwargs):
"""
Load loras
"""
loras = unwrap(kwargs.get("loras"), [])
success: List[str] = []
failure: List[str] = []
for lora in loras:
lora_name = lora.get("name")
lora_scaling = unwrap(lora.get("scaling"), 1.0)
if lora_name is None:
logger.warning(
"One of your loras does not have a name. Please check your "
"config.yml! Skipping lora load."
)
failure.append(lora_name)
continue
logger.info(f"Loading lora: {lora_name} at scaling {lora_scaling}")
lora_path = lora_directory / lora_name
self.active_loras.append(
ExLlamaV2Lora.from_directory(self.model, lora_path, lora_scaling)
)
logger.info(f"Lora successfully loaded: {lora_name}")
success.append(lora_name)
# Return success and failure names
return {"success": success, "failure": failure}
async def load_gen(self, progress_callback=None):
async def load_gen(self, progress_callback=None, **kwargs):
"""Loads a model and streams progress via a generator."""
# Indicate that model load has started
# Do this operation under the load lock's context
try:
await self.load_lock.acquire()
self.model_is_loading = True
# Wait for existing generation jobs to finish
await self.wait_for_jobs(kwargs.get("skip_wait"))
# Streaming gen for model load progress
model_load_generator = self.load_model_sync(progress_callback)
async for value in iterate_in_threadpool(model_load_generator):
@@ -441,9 +436,14 @@ class ExllamaV2Container:
torch.cuda.empty_cache()
# Cleanup and update model load state
self.model_is_loading = False
self.model_loaded = True
logger.info("Model successfully loaded.")
finally:
self.load_lock.release()
self.model_is_loading = False
async with self.load_condition:
self.load_condition.notify_all()
@torch.inference_mode()
def load_model_sync(self, progress_callback=None):
@@ -538,15 +538,76 @@ class ExllamaV2Container:
input_ids = torch.zeros((1, self.config.max_input_len), dtype=torch.long)
self.model.forward(input_ids, cache=self.cache, preprocess_only=True)
def unload(self, loras_only: bool = False):
def get_loras(self):
"""Convenience function to get all loras."""
return unwrap(self.generator.generator.current_loras, [])
async def load_loras(self, lora_directory: pathlib.Path, **kwargs):
"""
Load loras
"""
loras = unwrap(kwargs.get("loras"), [])
try:
await self.load_lock.acquire()
# Wait for existing generation jobs to finish
await self.wait_for_jobs(kwargs.get("skip_wait"))
loras_to_load: List[ExLlamaV2Lora] = []
success: List[str] = []
failure: List[str] = []
for lora in loras:
lora_name = lora.get("name")
lora_scaling = unwrap(lora.get("scaling"), 1.0)
if lora_name is None:
logger.warning(
"One of your loras does not have a name. Please check your "
"config.yml! Skipping lora load."
)
failure.append(lora_name)
continue
logger.info(f"Adding lora: {lora_name} at scaling {lora_scaling}")
lora_path = lora_directory / lora_name
loras_to_load.append(
ExLlamaV2Lora.from_directory(self.model, lora_path, lora_scaling)
)
logger.info(f"Lora successfully added: {lora_name}")
success.append(lora_name)
self.generator.generator.set_loras(loras_to_load)
logger.info("All loras successfully loaded")
# Return success and failure names
return {"success": success, "failure": failure}
finally:
self.load_lock.release()
async with self.load_condition:
self.load_condition.notify_all()
async def unload(self, loras_only: bool = False, **kwargs):
"""
Free all VRAM resources used by this model
"""
for lora in self.active_loras:
try:
await self.load_lock.acquire()
# Wait for other jobs to finish
await self.wait_for_jobs(kwargs.get("skip_wait"))
if self.generator and self.generator.generator.current_loras:
for lora in self.generator.generator.current_loras:
lora.unload()
self.active_loras = []
self.generator.generator.set_loras([])
# Unload the entire model if not just unloading loras
if not loras_only:
@@ -561,6 +622,9 @@ class ExllamaV2Container:
self.config = None
self.cache = None
self.tokenizer = None
# Clean up the generator and any pending jobs
await self.generator.close()
self.generator = None
# Set all model state variables to False
@@ -571,6 +635,11 @@ class ExllamaV2Container:
torch.cuda.empty_cache()
logger.info("Loras unloaded." if loras_only else "Model unloaded.")
finally:
self.load_lock.release()
async with self.load_condition:
self.load_condition.notify_all()
def encode_tokens(self, text: str, **kwargs):
"""Wrapper to encode tokens from a text string"""
@@ -683,6 +752,10 @@ class ExllamaV2Container:
for kwargs, check common/sampling.py
"""
# Wait for load lock to be freed before processing
async with self.load_condition:
await self.load_condition.wait_for(lambda: not self.load_lock.locked())
prompts = [prompt]
token_healing = unwrap(kwargs.get("token_healing"), False)
@@ -951,10 +1024,13 @@
)
# Save generated tokens and full response
# Copy over max seq len in case the model is unloaded and stored jobs can complete
# Full response is required for offset calculation
max_seq_len = self.config.max_seq_len
generated_tokens = 0
full_response = ""
try:
# Get the generation status once it's ready
async for result in job:
stage = result.get("stage")
@@ -1015,7 +1091,7 @@
result.get("new_tokens"),
result.get("time_generate"),
context_len,
self.config.max_seq_len,
max_seq_len,
)
# Remove the token text
@@ -1027,3 +1103,5 @@ class ExllamaV2Container:
yield generation
break
except asyncio.CancelledError:
await job.cancel()


@@ -1,12 +1,8 @@
"""Concurrency handling"""
import asyncio
import inspect
from fastapi.concurrency import run_in_threadpool # noqa
from functools import partialmethod
from typing import AsyncGenerator, Generator, Union
generate_semaphore = asyncio.Semaphore(1)
from typing import AsyncGenerator, Generator
# Originally from https://github.com/encode/starlette/blob/master/starlette/concurrency.py
@@ -34,24 +30,3 @@ async def iterate_in_threadpool(generator: Generator) -> AsyncGenerator:
yield await asyncio.to_thread(gen_next, generator)
except _StopIteration:
break
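For reference, iterate_in_threadpool adapts a blocking generator for async iteration by running each next() call in a worker thread via asyncio.to_thread. A minimal usage sketch; blocking_progress is an invented example generator:

import asyncio
import time

def blocking_progress():
    for step in range(3):
        time.sleep(0.1)  # stands in for blocking model-load work
        yield step

async def consume():
    # The event loop stays responsive while each value is produced
    async for step in iterate_in_threadpool(blocking_progress()):
        print(step)

asyncio.run(consume())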
async def generate_with_semaphore(generator: Union[AsyncGenerator, Generator]):
"""Generate with a semaphore."""
async with generate_semaphore:
if not inspect.isasyncgenfunction:
generator = iterate_in_threadpool(generator())
async for result in generator():
yield result
async def call_with_semaphore(callback: partialmethod):
"""Call with a semaphore."""
async with generate_semaphore:
if not inspect.iscoroutinefunction:
callback = run_in_threadpool(callback)
return await callback()


@@ -20,11 +20,11 @@ def load_progress(module, modules):
yield module, modules
async def unload_model():
async def unload_model(skip_wait: bool = False):
"""Unloads a model"""
global container
container.unload()
await container.unload(skip_wait=skip_wait)
container = None
@@ -49,7 +49,7 @@ async def load_model_gen(model_path: pathlib.Path, **kwargs):
container = ExllamaV2Container(model_path.resolve(), False, **kwargs)
model_type = "draft" if container.draft_config else "model"
load_status = container.load_gen(load_progress)
load_status = container.load_gen(load_progress, **kwargs)
progress = get_loading_progress_bar()
progress.start()
@@ -81,12 +81,12 @@ async def load_model(model_path: pathlib.Path, **kwargs):
async def load_loras(lora_dir, **kwargs):
"""Wrapper to load loras."""
if len(container.active_loras) > 0:
unload_loras()
if len(container.get_loras()) > 0:
await unload_loras()
return await container.load_loras(lora_dir, **kwargs)
def unload_loras():
async def unload_loras():
"""Wrapper to unload loras"""
container.unload(loras_only=True)
await container.unload(loras_only=True)


@@ -1,18 +1,12 @@
import asyncio
import pathlib
from fastapi import APIRouter, Depends, HTTPException, Header, Request
from functools import partial
from loguru import logger
from sse_starlette import EventSourceResponse
from sys import maxsize
from typing import Optional
from common import config, model, gen_logging, sampling
from common.auth import check_admin_key, check_api_key, validate_key_permission
from common.concurrency import (
call_with_semaphore,
generate_with_semaphore,
)
from common.downloader import hf_repo_download
from common.networking import handle_request_error, run_with_request_disconnect
from common.templating import PromptTemplate, get_all_templates
@@ -141,7 +135,7 @@ async def list_draft_models():
# Load model endpoint
@router.post("/v1/model/load", dependencies=[Depends(check_admin_key)])
async def load_model(request: Request, data: ModelLoadRequest):
async def load_model(data: ModelLoadRequest):
"""Loads a model into the model container."""
# Verify request parameters
@@ -178,18 +172,9 @@ async def load_model(request: Request, data: ModelLoadRequest):
raise HTTPException(400, error_message)
load_callback = partial(stream_model_load, data, model_path, draft_model_path)
# Wrap in a semaphore if the queue isn't being skipped
if data.skip_queue:
    logger.warning(
        "Model load request is skipping the completions queue. "
        "Unexpected results may occur."
    )
else:
    load_callback = partial(generate_with_semaphore, load_callback)
return EventSourceResponse(load_callback(), ping=maxsize)
return EventSourceResponse(
    stream_model_load(data, model_path, draft_model_path), ping=maxsize
)
# Unload model endpoint
@@ -199,7 +184,7 @@ async def load_model(request: Request, data: ModelLoadRequest):
)
async def unload_model():
"""Unloads the currently loaded model."""
await model.unload_model()
await model.unload_model(skip_wait=True)
@router.get("/v1/templates", dependencies=[Depends(check_api_key)])
@@ -335,15 +320,13 @@ async def get_all_loras():
async def get_active_loras():
"""Returns the currently loaded loras."""
active_loras = LoraList(
    data=list(
        map(
            lambda lora: LoraCard(
                id=pathlib.Path(lora.lora_path).parent.name,
                scaling=lora.lora_scaling * lora.lora_r / lora.lora_alpha,
            ),
            model.container.active_loras,
        )
    )
)
active_loras = LoraList(
    data=[
        LoraCard(
            id=pathlib.Path(lora.lora_path).parent.name,
            scaling=lora.lora_scaling * lora.lora_r / lora.lora_alpha,
        )
        for lora in model.container.get_loras()
    ]
)
return active_loras
@@ -374,18 +357,9 @@ async def load_lora(data: LoraLoadRequest):
raise HTTPException(400, error_message)
load_callback = partial(model.load_loras, lora_dir, **data.model_dump())
# Wrap in a semaphore if the queue isn't being skipped
if data.skip_queue:
    logger.warning(
        "Lora load request is skipping the completions queue. "
        "Unexpected results may occur."
    )
else:
    load_callback = partial(call_with_semaphore, load_callback)
load_result = await load_callback()
load_result = await model.load_loras(
    lora_dir, **data.model_dump(), skip_wait=data.skip_queue
)
return LoraLoadResponse(
success=unwrap(load_result.get("success"), []),
@@ -401,7 +375,7 @@ async def load_lora(data: LoraLoadRequest):
async def unload_loras():
"""Unloads the currently loaded loras."""
model.unload_loras()
await model.unload_loras()
# Encode tokens endpoint
@@ -494,16 +468,12 @@ async def completion_request(request: Request, data: CompletionRequest):
data.json_schema = {"type": "object"}
if data.stream and not disable_request_streaming:
generator_callback = partial(stream_generate_completion, data, model_path)
return EventSourceResponse(
generate_with_semaphore(generator_callback),
stream_generate_completion(data, model_path),
ping=maxsize,
)
else:
generate_task = asyncio.create_task(
call_with_semaphore(partial(generate_completion, data, model_path))
)
generate_task = asyncio.create_task(generate_completion(data, model_path))
response = await run_with_request_disconnect(
request,
@@ -545,19 +515,13 @@ async def chat_completion_request(request: Request, data: ChatCompletionRequest)
)
if data.stream and not disable_request_streaming:
generator_callback = partial(
stream_generate_chat_completion, prompt, data, model_path
)
return EventSourceResponse(
generate_with_semaphore(generator_callback),
stream_generate_chat_completion(prompt, data, model_path),
ping=maxsize,
)
else:
generate_task = asyncio.create_task(
call_with_semaphore(
partial(generate_chat_completion, prompt, data, model_path)
)
generate_chat_completion(prompt, data, model_path)
)
response = await run_with_request_disconnect(


@@ -43,7 +43,9 @@ async def stream_model_load(
if draft_model_path:
load_data["draft"]["draft_model_dir"] = draft_model_path
load_status = model.load_model_gen(model_path, **load_data)
load_status = model.load_model_gen(
model_path, skip_wait=data.skip_queue, **load_data
)
try:
async for module, modules, model_type in load_status:
if module != 0:
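End to end, a client consumes the load endpoint's SSE stream. A hedged sketch using httpx; the host, port, auth header name, and every payload field other than skip_queue are assumptions for illustration, not confirmed by this diff:

import json

import httpx

async def stream_model_load_request():
    payload = {"name": "example-model", "skip_queue": False}  # assumed fields
    headers = {"x-admin-key": "example-key"}  # assumed header name
    async with httpx.AsyncClient(timeout=None) as client:
        async with client.stream(
            "POST",
            "http://127.0.0.1:5000/v1/model/load",  # assumed host/port
            json=payload,
            headers=headers,
        ) as response:
            # Print each SSE progress event as it arrives
            async for line in response.aiter_lines():
                if line.startswith("data:"):
                    print(json.loads(line.removeprefix("data:").strip()))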