From 43cd7f57e83e575973e02d5d7fa160b9832e56bd Mon Sep 17 00:00:00 2001 From: kingbri Date: Sat, 25 May 2024 18:24:11 -0400 Subject: [PATCH] API + Model: Add blocks and checks for various load requests Add a sequential lock and wait until jobs are completed before executing any loading requests that directly alter the model. However, we also need to block any new requests that come in until the load is finished, so add a condition that triggers once the lock is free. Signed-off-by: kingbri --- backends/exllamav2/model.py | 398 +++++++++++++++++++++-------------- common/concurrency.py | 27 +-- common/model.py | 14 +- endpoints/OAI/router.py | 74 ++----- endpoints/OAI/utils/model.py | 4 +- 5 files changed, 268 insertions(+), 249 deletions(-) diff --git a/backends/exllamav2/model.py b/backends/exllamav2/model.py index fd71ba3..30f8a2c 100644 --- a/backends/exllamav2/model.py +++ b/backends/exllamav2/model.py @@ -1,5 +1,6 @@ """The model container class for ExLlamaV2 models.""" +import asyncio import gc import math import pathlib @@ -54,7 +55,6 @@ class ExllamaV2Container: tokenizer: Optional[ExLlamaV2Tokenizer] = None generator: Optional[ExLlamaV2DynamicGeneratorAsync] = None prompt_template: Optional[PromptTemplate] = None - active_loras: List[ExLlamaV2Lora] = [] paged: bool = True # Internal config vars @@ -71,6 +71,12 @@ class ExllamaV2Container: model_is_loading: bool = False model_loaded: bool = False + # Load synchronization + # The lock keeps load tasks sequential + # The condition notifies any waiting tasks + load_lock: asyncio.Lock = asyncio.Lock() + load_condition: asyncio.Condition = asyncio.Condition() + def __init__(self, model_directory: pathlib.Path, quiet=False, **kwargs): """ Create model container @@ -348,6 +354,22 @@ class ExllamaV2Container: return model_params + async def wait_for_jobs(self, skip_wait: bool = False): + """Polling mechanism to wait for pending generation jobs.""" + + if not self.generator: + return + + # Immediately abort all jobs if asked + if skip_wait: + # Requires a copy to avoid errors during iteration + jobs_copy = self.generator.jobs.copy() + for job in jobs_copy.values(): + await job.cancel() + + while self.generator.jobs: + await asyncio.sleep(0.01) + async def load(self, progress_callback=None): """ Load model @@ -361,89 +383,67 @@ class ExllamaV2Container: async for _ in self.load_gen(progress_callback): pass - async def load_loras(self, lora_directory: pathlib.Path, **kwargs): - """ - Load loras - """ - - loras = unwrap(kwargs.get("loras"), []) - success: List[str] = [] - failure: List[str] = [] - - for lora in loras: - lora_name = lora.get("name") - lora_scaling = unwrap(lora.get("scaling"), 1.0) - - if lora_name is None: - logger.warning( - "One of your loras does not have a name. Please check your " - "config.yml! Skipping lora load." - ) - failure.append(lora_name) - continue - - logger.info(f"Loading lora: {lora_name} at scaling {lora_scaling}") - lora_path = lora_directory / lora_name - - self.active_loras.append( - ExLlamaV2Lora.from_directory(self.model, lora_path, lora_scaling) - ) - logger.info(f"Lora successfully loaded: {lora_name}") - success.append(lora_name) - - # Return success and failure names - return {"success": success, "failure": failure} - - async def load_gen(self, progress_callback=None): + async def load_gen(self, progress_callback=None, **kwargs): """Loads a model and streams progress via a generator.""" # Indicate that model load has started - self.model_is_loading = True + # Do this operation under the load lock's context + try: + await self.load_lock.acquire() + self.model_is_loading = True - # Streaming gen for model load progress - model_load_generator = self.load_model_sync(progress_callback) - async for value in iterate_in_threadpool(model_load_generator): - yield value + # Wait for existing generation jobs to finish + await self.wait_for_jobs(kwargs.get("skip_wait")) - # Disable paged mode if the user's min GPU is supported (ampere and above) - min_compute_capability = min( - set( - [ - torch.cuda.get_device_capability(device=module.device_idx)[0] - for module in self.model.modules - if module.device_idx >= 0 - ] + # Streaming gen for model load progress + model_load_generator = self.load_model_sync(progress_callback) + async for value in iterate_in_threadpool(model_load_generator): + yield value + + # Disable paged mode if the user's min GPU is supported (ampere and above) + min_compute_capability = min( + set( + [ + torch.cuda.get_device_capability(device=module.device_idx)[0] + for module in self.model.modules + if module.device_idx >= 0 + ] + ) ) - ) - if torch.version.hip or min_compute_capability < 8: - logger.warning( - "An unsupported GPU is found in this configuration. " - "Switching to compatibility mode. This disables parallel batching." + if torch.version.hip or min_compute_capability < 8: + logger.warning( + "An unsupported GPU is found in this configuration. " + "Switching to compatibility mode. This disables parallel batching." + ) + self.paged = False + self.max_batch_size = 1 + + # Create async generator + self.generator = ExLlamaV2DynamicGeneratorAsync( + model=self.model, + cache=self.cache, + draft_model=self.draft_model, + draft_cache=self.draft_cache, + tokenizer=self.tokenizer, + max_batch_size=self.max_batch_size, + paged=self.paged, ) - self.paged = False - self.max_batch_size = 1 - # Create async generator - self.generator = ExLlamaV2DynamicGeneratorAsync( - model=self.model, - cache=self.cache, - draft_model=self.draft_model, - draft_cache=self.draft_cache, - tokenizer=self.tokenizer, - max_batch_size=self.max_batch_size, - paged=self.paged, - ) + # Clean up any extra vram usage from torch and cuda + # (Helps reduce VRAM bottlenecking on Windows) + gc.collect() + torch.cuda.empty_cache() - # Clean up any extra vram usage from torch and cuda - # (Helps reduce VRAM bottlenecking on Windows) - gc.collect() - torch.cuda.empty_cache() + # Cleanup and update model load state + self.model_loaded = True + logger.info("Model successfully loaded.") + finally: + self.load_lock.release() + self.model_is_loading = False - # Cleanup and update model load state - self.model_is_loading = False - self.model_loaded = True - logger.info("Model successfully loaded.") + async with self.load_condition: + self.load_condition.notify_all() @torch.inference_mode() def load_model_sync(self, progress_callback=None): @@ -538,39 +538,108 @@ class ExllamaV2Container: input_ids = torch.zeros((1, self.config.max_input_len), dtype=torch.long) self.model.forward(input_ids, cache=self.cache, preprocess_only=True) - def unload(self, loras_only: bool = False): + def get_loras(self): + """Convenience function to get all loras.""" + + return unwrap(self.generator.generator.current_loras, []) + + async def load_loras(self, lora_directory: pathlib.Path, **kwargs): + """ + Load loras + """ + + loras = unwrap(kwargs.get("loras"), []) + + try: + await self.load_lock.acquire() + + # Wait for existing generation jobs to finish + await self.wait_for_jobs(kwargs.get("skip_wait")) + + loras_to_load: List[ExLlamaV2Lora] = [] + success: List[str] = [] + failure: List[str] = [] + + for lora in loras: + lora_name = lora.get("name") + lora_scaling = unwrap(lora.get("scaling"), 1.0) + + if lora_name is None: + logger.warning( + "One of your loras does not have a name. Please check your " + "config.yml! Skipping lora load." + ) + failure.append(lora_name) + continue + + logger.info(f"Adding lora: {lora_name} at scaling {lora_scaling}") + lora_path = lora_directory / lora_name + + loras_to_load.append( + ExLlamaV2Lora.from_directory(self.model, lora_path, lora_scaling) + ) + logger.info(f"Lora successfully added: {lora_name}") + success.append(lora_name) + + self.generator.generator.set_loras(loras_to_load) + logger.info("All loras successfully loaded") + + # Return success and failure names + return {"success": success, "failure": failure} + finally: + self.load_lock.release() + + async with self.load_condition: + self.load_condition.notify_all() + + async def unload(self, loras_only: bool = False, **kwargs): """ Free all VRAM resources used by this model """ - for lora in self.active_loras: - lora.unload() + try: + await self.load_lock.acquire() - self.active_loras = [] + # Wait for other jobs to finish + await self.wait_for_jobs(kwargs.get("skip_wait")) - # Unload the entire model if not just unloading loras - if not loras_only: - if self.model: - self.model.unload() - self.model = None + if self.generator and self.generator.generator.current_loras: + for lora in self.generator.generator.current_loras: + lora.unload() - if self.draft_model: - self.draft_model.unload() - self.draft_model = None + self.generator.generator.set_loras([]) - self.config = None - self.cache = None - self.tokenizer = None - self.generator = None + # Unload the entire model if not just unloading loras + if not loras_only: + if self.model: + self.model.unload() + self.model = None - # Set all model state variables to False - self.model_is_loading = False - self.model_loaded = False + if self.draft_model: + self.draft_model.unload() + self.draft_model = None - gc.collect() - torch.cuda.empty_cache() + self.config = None + self.cache = None + self.tokenizer = None - logger.info("Loras unloaded." if loras_only else "Model unloaded.") + # Cleanup the generator from any pending jobs + await self.generator.close() + self.generator = None + + # Set all model state variables to False + self.model_is_loading = False + self.model_loaded = False + + gc.collect() + torch.cuda.empty_cache() + + logger.info("Loras unloaded." if loras_only else "Model unloaded.") + finally: + self.load_lock.release() + + async with self.load_condition: + self.load_condition.notify_all() def encode_tokens(self, text: str, **kwargs): """Wrapper to encode tokens from a text string""" @@ -683,6 +752,10 @@ class ExllamaV2Container: for kwargs, check common/sampling.py """ + # Wait for load lock to be freed before processing + async with self.load_condition: + await self.load_condition.wait_for(lambda: not self.load_lock.locked()) + prompts = [prompt] token_healing = unwrap(kwargs.get("token_healing"), False) @@ -951,79 +1024,84 @@ class ExllamaV2Container: ) # Save generated tokens and full response + # Copy over max seq len incase model is unloaded and stored jobs can complete # Full response is required for offset calculation + max_seq_len = self.config.max_seq_len generated_tokens = 0 full_response = "" - # Get the generation status once it's ready - async for result in job: - stage = result.get("stage") - result_id = result.get("identifier") + try: + # Get the generation status once it's ready + async for result in job: + stage = result.get("stage") + result_id = result.get("identifier") - if stage == "streaming" and result_id == job_id: - chunk = unwrap(result.get("text"), "") - full_response += chunk + if stage == "streaming" and result_id == job_id: + chunk = unwrap(result.get("text"), "") + full_response += chunk - chunk_tokens = result.get("token_ids") - if chunk_tokens is not None: - generated_tokens += chunk_tokens.size(dim=0) + chunk_tokens = result.get("token_ids") + if chunk_tokens is not None: + generated_tokens += chunk_tokens.size(dim=0) - generation = { - "text": chunk, - "prompt_tokens": context_len, - "generated_tokens": generated_tokens, - "offset": len(full_response), - } - - if request_logprobs > 0: - # Get top tokens and probs - top_tokens = unwrap( - result.get("top_k_tokens"), - torch.empty((1, 0, 1), dtype=torch.long), - ) - - top_probs = unwrap( - result.get("top_k_probs"), - torch.empty((1, 0, 1), dtype=torch.float), - ) - - if top_tokens.numel() > 0 and top_probs.numel() > 0: - logprobs = self.get_logprobs(top_tokens, top_probs) - generation["logprobs"] = logprobs - - # The first logprob is the selected token prob - generation["token_probs"] = { - token: logprobs[token] - for token in list(logprobs.keys())[:1] - } - - yield generation - - # Second yield if eos is true - if result.get("eos"): - log_response(full_response) - - eos_reason = result.get("eos_reason") - finish_reason = ( - "length" if eos_reason == "max_new_tokens" else "stop" - ) - - log_metrics( - result.get("time_enqueued"), - result.get("prompt_tokens"), - result.get("time_prefill"), - result.get("new_tokens"), - result.get("time_generate"), - context_len, - self.config.max_seq_len, - ) - - # Remove the token text generation = { - "prompt_tokens": generation.get("prompt_tokens"), - "generated_tokens": generation.get("generated_tokens"), - "finish_reason": finish_reason, + "text": chunk, + "prompt_tokens": context_len, + "generated_tokens": generated_tokens, + "offset": len(full_response), } + if request_logprobs > 0: + # Get top tokens and probs + top_tokens = unwrap( + result.get("top_k_tokens"), + torch.empty((1, 0, 1), dtype=torch.long), + ) + + top_probs = unwrap( + result.get("top_k_probs"), + torch.empty((1, 0, 1), dtype=torch.float), + ) + + if top_tokens.numel() > 0 and top_probs.numel() > 0: + logprobs = self.get_logprobs(top_tokens, top_probs) + generation["logprobs"] = logprobs + + # The first logprob is the selected token prob + generation["token_probs"] = { + token: logprobs[token] + for token in list(logprobs.keys())[:1] + } + yield generation - break + + # Second yield if eos is true + if result.get("eos"): + log_response(full_response) + + eos_reason = result.get("eos_reason") + finish_reason = ( + "length" if eos_reason == "max_new_tokens" else "stop" + ) + + log_metrics( + result.get("time_enqueued"), + result.get("prompt_tokens"), + result.get("time_prefill"), + result.get("new_tokens"), + result.get("time_generate"), + context_len, + max_seq_len, + ) + + # Remove the token text + generation = { + "prompt_tokens": generation.get("prompt_tokens"), + "generated_tokens": generation.get("generated_tokens"), + "finish_reason": finish_reason, + } + + yield generation + break + except asyncio.CancelledError: + await job.cancel() diff --git a/common/concurrency.py b/common/concurrency.py index c1eefc6..8939432 100644 --- a/common/concurrency.py +++ b/common/concurrency.py @@ -1,12 +1,8 @@ """Concurrency handling""" import asyncio -import inspect from fastapi.concurrency import run_in_threadpool # noqa -from functools import partialmethod -from typing import AsyncGenerator, Generator, Union - -generate_semaphore = asyncio.Semaphore(1) +from typing import AsyncGenerator, Generator # Originally from https://github.com/encode/starlette/blob/master/starlette/concurrency.py @@ -34,24 +30,3 @@ async def iterate_in_threadpool(generator: Generator) -> AsyncGenerator: yield await asyncio.to_thread(gen_next, generator) except _StopIteration: break - - -async def generate_with_semaphore(generator: Union[AsyncGenerator, Generator]): - """Generate with a semaphore.""" - - async with generate_semaphore: - if not inspect.isasyncgenfunction: - generator = iterate_in_threadpool(generator()) - - async for result in generator(): - yield result - - -async def call_with_semaphore(callback: partialmethod): - """Call with a semaphore.""" - - async with generate_semaphore: - if not inspect.iscoroutinefunction: - callback = run_in_threadpool(callback) - - return await callback() diff --git a/common/model.py b/common/model.py index f916f5a..3ec2614 100644 --- a/common/model.py +++ b/common/model.py @@ -20,11 +20,11 @@ def load_progress(module, modules): yield module, modules -async def unload_model(): +async def unload_model(skip_wait: bool = False): """Unloads a model""" global container - container.unload() + await container.unload(skip_wait=skip_wait) container = None @@ -49,7 +49,7 @@ async def load_model_gen(model_path: pathlib.Path, **kwargs): container = ExllamaV2Container(model_path.resolve(), False, **kwargs) model_type = "draft" if container.draft_config else "model" - load_status = container.load_gen(load_progress) + load_status = container.load_gen(load_progress, **kwargs) progress = get_loading_progress_bar() progress.start() @@ -81,12 +81,12 @@ async def load_model(model_path: pathlib.Path, **kwargs): async def load_loras(lora_dir, **kwargs): """Wrapper to load loras.""" - if len(container.active_loras) > 0: - unload_loras() + if len(container.get_loras()) > 0: + await unload_loras() return await container.load_loras(lora_dir, **kwargs) -def unload_loras(): +async def unload_loras(): """Wrapper to unload loras""" - container.unload(loras_only=True) + await container.unload(loras_only=True) diff --git a/endpoints/OAI/router.py b/endpoints/OAI/router.py index a09bf03..b7c594f 100644 --- a/endpoints/OAI/router.py +++ b/endpoints/OAI/router.py @@ -1,18 +1,12 @@ import asyncio import pathlib from fastapi import APIRouter, Depends, HTTPException, Header, Request -from functools import partial -from loguru import logger from sse_starlette import EventSourceResponse from sys import maxsize from typing import Optional from common import config, model, gen_logging, sampling from common.auth import check_admin_key, check_api_key, validate_key_permission -from common.concurrency import ( - call_with_semaphore, - generate_with_semaphore, -) from common.downloader import hf_repo_download from common.networking import handle_request_error, run_with_request_disconnect from common.templating import PromptTemplate, get_all_templates @@ -141,7 +135,7 @@ async def list_draft_models(): # Load model endpoint @router.post("/v1/model/load", dependencies=[Depends(check_admin_key)]) -async def load_model(request: Request, data: ModelLoadRequest): +async def load_model(data: ModelLoadRequest): """Loads a model into the model container.""" # Verify request parameters @@ -178,18 +172,9 @@ async def load_model(request: Request, data: ModelLoadRequest): raise HTTPException(400, error_message) - load_callback = partial(stream_model_load, data, model_path, draft_model_path) - - # Wrap in a semaphore if the queue isn't being skipped - if data.skip_queue: - logger.warning( - "Model load request is skipping the completions queue. " - "Unexpected results may occur." - ) - else: - load_callback = partial(generate_with_semaphore, load_callback) - - return EventSourceResponse(load_callback(), ping=maxsize) + return EventSourceResponse( + stream_model_load(data, model_path, draft_model_path), ping=maxsize + ) # Unload model endpoint @@ -199,7 +184,7 @@ async def load_model(request: Request, data: ModelLoadRequest): ) async def unload_model(): """Unloads the currently loaded model.""" - await model.unload_model() + await model.unload_model(skip_wait=True) @router.get("/v1/templates", dependencies=[Depends(check_api_key)]) @@ -335,15 +320,13 @@ async def get_all_loras(): async def get_active_loras(): """Returns the currently loaded loras.""" active_loras = LoraList( - data=list( - map( - lambda lora: LoraCard( - id=pathlib.Path(lora.lora_path).parent.name, - scaling=lora.lora_scaling * lora.lora_r / lora.lora_alpha, - ), - model.container.active_loras, + data=[ + LoraCard( + id=pathlib.Path(lora.lora_path).parent.name, + scaling=lora.lora_scaling * lora.lora_r / lora.lora_alpha, ) - ) + for lora in model.container.get_loras() + ] ) return active_loras @@ -374,18 +357,9 @@ async def load_lora(data: LoraLoadRequest): raise HTTPException(400, error_message) - load_callback = partial(model.load_loras, lora_dir, **data.model_dump()) - - # Wrap in a semaphore if the queue isn't being skipped - if data.skip_queue: - logger.warning( - "Lora load request is skipping the completions queue. " - "Unexpected results may occur." - ) - else: - load_callback = partial(call_with_semaphore, load_callback) - - load_result = await load_callback() + load_result = await model.load_loras( + lora_dir, **data.model_dump(), skip_wait=data.skip_queue + ) return LoraLoadResponse( success=unwrap(load_result.get("success"), []), @@ -401,7 +375,7 @@ async def load_lora(data: LoraLoadRequest): async def unload_loras(): """Unloads the currently loaded loras.""" - model.unload_loras() + await model.unload_loras() # Encode tokens endpoint @@ -494,16 +468,12 @@ async def completion_request(request: Request, data: CompletionRequest): data.json_schema = {"type": "object"} if data.stream and not disable_request_streaming: - generator_callback = partial(stream_generate_completion, data, model_path) - return EventSourceResponse( - generate_with_semaphore(generator_callback), + stream_generate_completion(data, model_path), ping=maxsize, ) else: - generate_task = asyncio.create_task( - call_with_semaphore(partial(generate_completion, data, model_path)) - ) + generate_task = asyncio.create_task(generate_completion(data, model_path)) response = await run_with_request_disconnect( request, @@ -545,19 +515,13 @@ async def chat_completion_request(request: Request, data: ChatCompletionRequest) ) if data.stream and not disable_request_streaming: - generator_callback = partial( - stream_generate_chat_completion, prompt, data, model_path - ) - return EventSourceResponse( - generate_with_semaphore(generator_callback), + stream_generate_chat_completion(prompt, data, model_path), ping=maxsize, ) else: generate_task = asyncio.create_task( - call_with_semaphore( - partial(generate_chat_completion, prompt, data, model_path) - ) + generate_chat_completion(prompt, data, model_path) ) response = await run_with_request_disconnect( diff --git a/endpoints/OAI/utils/model.py b/endpoints/OAI/utils/model.py index 66c7625..0502193 100644 --- a/endpoints/OAI/utils/model.py +++ b/endpoints/OAI/utils/model.py @@ -43,7 +43,9 @@ async def stream_model_load( if draft_model_path: load_data["draft"]["draft_model_dir"] = draft_model_path - load_status = model.load_model_gen(model_path, **load_data) + load_status = model.load_model_gen( + model_path, skip_wait=data.skip_queue, **load_data + ) try: async for module, modules, model_type in load_status: if module != 0: