diff --git a/backends/exllamav2/model.py b/backends/exllamav2/model.py
index fd71ba3..30f8a2c 100644
--- a/backends/exllamav2/model.py
+++ b/backends/exllamav2/model.py
@@ -1,5 +1,6 @@
 """The model container class for ExLlamaV2 models."""

+import asyncio
 import gc
 import math
 import pathlib
@@ -54,7 +55,6 @@ class ExllamaV2Container:
     tokenizer: Optional[ExLlamaV2Tokenizer] = None
     generator: Optional[ExLlamaV2DynamicGeneratorAsync] = None
     prompt_template: Optional[PromptTemplate] = None
-    active_loras: List[ExLlamaV2Lora] = []
     paged: bool = True

     # Internal config vars
@@ -71,6 +71,12 @@ class ExllamaV2Container:
     model_is_loading: bool = False
     model_loaded: bool = False

+    # Load synchronization
+    # The lock keeps load tasks sequential
+    # The condition notifies any waiting tasks
+    load_lock: asyncio.Lock = asyncio.Lock()
+    load_condition: asyncio.Condition = asyncio.Condition()
+
     def __init__(self, model_directory: pathlib.Path, quiet=False, **kwargs):
         """
         Create model container
@@ -348,6 +354,22 @@ class ExllamaV2Container:
         return model_params

+    async def wait_for_jobs(self, skip_wait: bool = False):
+        """Polling mechanism to wait for pending generation jobs."""
+
+        if not self.generator:
+            return
+
+        # Immediately abort all jobs if asked
+        if skip_wait:
+            # Requires a copy to avoid errors during iteration
+            jobs_copy = self.generator.jobs.copy()
+            for job in jobs_copy.values():
+                await job.cancel()
+
+        while self.generator.jobs:
+            await asyncio.sleep(0.01)
+
     async def load(self, progress_callback=None):
         """
         Load model

@@ -361,89 +383,67 @@ class ExllamaV2Container:
         async for _ in self.load_gen(progress_callback):
             pass

-    async def load_loras(self, lora_directory: pathlib.Path, **kwargs):
-        """
-        Load loras
-        """
-
-        loras = unwrap(kwargs.get("loras"), [])
-        success: List[str] = []
-        failure: List[str] = []
-
-        for lora in loras:
-            lora_name = lora.get("name")
-            lora_scaling = unwrap(lora.get("scaling"), 1.0)
-
-            if lora_name is None:
-                logger.warning(
-                    "One of your loras does not have a name. Please check your "
-                    "config.yml! Skipping lora load."
-                )
-                failure.append(lora_name)
-                continue
-
-            logger.info(f"Loading lora: {lora_name} at scaling {lora_scaling}")
-            lora_path = lora_directory / lora_name
-
-            self.active_loras.append(
-                ExLlamaV2Lora.from_directory(self.model, lora_path, lora_scaling)
-            )
-            logger.info(f"Lora successfully loaded: {lora_name}")
-            success.append(lora_name)
-
-        # Return success and failure names
-        return {"success": success, "failure": failure}
-
-    async def load_gen(self, progress_callback=None):
+    async def load_gen(self, progress_callback=None, **kwargs):
         """Loads a model and streams progress via a generator."""

         # Indicate that model load has started
-        self.model_is_loading = True
+        # Do this operation under the load lock's context
+        try:
+            await self.load_lock.acquire()
+            self.model_is_loading = True

-        # Streaming gen for model load progress
-        model_load_generator = self.load_model_sync(progress_callback)
-        async for value in iterate_in_threadpool(model_load_generator):
-            yield value
+            # Wait for existing generation jobs to finish
+            await self.wait_for_jobs(kwargs.get("skip_wait"))

-        # Disable paged mode if the user's min GPU is supported (ampere and above)
-        min_compute_capability = min(
-            set(
-                [
-                    torch.cuda.get_device_capability(device=module.device_idx)[0]
-                    for module in self.model.modules
-                    if module.device_idx >= 0
-                ]
+            # Streaming gen for model load progress
+            model_load_generator = self.load_model_sync(progress_callback)
+            async for value in iterate_in_threadpool(model_load_generator):
+                yield value
+
+            # Disable paged mode if the user's min GPU is supported (ampere and above)
+            min_compute_capability = min(
+                set(
+                    [
+                        torch.cuda.get_device_capability(device=module.device_idx)[0]
+                        for module in self.model.modules
+                        if module.device_idx >= 0
+                    ]
+                )
             )
-        )

-        if torch.version.hip or min_compute_capability < 8:
-            logger.warning(
-                "An unsupported GPU is found in this configuration. "
-                "Switching to compatibility mode. This disables parallel batching."
+            if torch.version.hip or min_compute_capability < 8:
+                logger.warning(
+                    "An unsupported GPU is found in this configuration. "
+                    "Switching to compatibility mode. This disables parallel batching."
+                )
+                self.paged = False
+                self.max_batch_size = 1
+
+            # Create async generator
+            self.generator = ExLlamaV2DynamicGeneratorAsync(
+                model=self.model,
+                cache=self.cache,
+                draft_model=self.draft_model,
+                draft_cache=self.draft_cache,
+                tokenizer=self.tokenizer,
+                max_batch_size=self.max_batch_size,
+                paged=self.paged,
             )
-            self.paged = False
-            self.max_batch_size = 1

-        # Create async generator
-        self.generator = ExLlamaV2DynamicGeneratorAsync(
-            model=self.model,
-            cache=self.cache,
-            draft_model=self.draft_model,
-            draft_cache=self.draft_cache,
-            tokenizer=self.tokenizer,
-            max_batch_size=self.max_batch_size,
-            paged=self.paged,
-        )
+            # Clean up any extra vram usage from torch and cuda
+            # (Helps reduce VRAM bottlenecking on Windows)
+            gc.collect()
+            torch.cuda.empty_cache()

-        # Clean up any extra vram usage from torch and cuda
-        # (Helps reduce VRAM bottlenecking on Windows)
-        gc.collect()
-        torch.cuda.empty_cache()
+            # Cleanup and update model load state
+            self.model_loaded = True
+            logger.info("Model successfully loaded.")
+        finally:
+            self.load_lock.release()
+            self.model_is_loading = False

-        # Cleanup and update model load state
-        self.model_is_loading = False
-        self.model_loaded = True
-        logger.info("Model successfully loaded.")
+            async with self.load_condition:
+                self.load_condition.notify_all()

     @torch.inference_mode()
     def load_model_sync(self, progress_callback=None):
@@ -538,39 +538,108 @@ class ExllamaV2Container:
         input_ids = torch.zeros((1, self.config.max_input_len), dtype=torch.long)
         self.model.forward(input_ids, cache=self.cache, preprocess_only=True)

-    def unload(self, loras_only: bool = False):
+    def get_loras(self):
+        """Convenience function to get all loras."""
+
+        return unwrap(self.generator.generator.current_loras, [])
+
+    async def load_loras(self, lora_directory: pathlib.Path, **kwargs):
+        """
+        Load loras
+        """
+
+        loras = unwrap(kwargs.get("loras"), [])
+
+        try:
+            await self.load_lock.acquire()
+
+            # Wait for existing generation jobs to finish
+            await self.wait_for_jobs(kwargs.get("skip_wait"))
+
+            loras_to_load: List[ExLlamaV2Lora] = []
+            success: List[str] = []
+            failure: List[str] = []
+
+            for lora in loras:
+                lora_name = lora.get("name")
+                lora_scaling = unwrap(lora.get("scaling"), 1.0)
+
+                if lora_name is None:
+                    logger.warning(
+                        "One of your loras does not have a name. Please check your "
+                        "config.yml! Skipping lora load."
+                    )
+                    failure.append(lora_name)
+                    continue
+
+                logger.info(f"Adding lora: {lora_name} at scaling {lora_scaling}")
+                lora_path = lora_directory / lora_name
+
+                loras_to_load.append(
+                    ExLlamaV2Lora.from_directory(self.model, lora_path, lora_scaling)
+                )
+                logger.info(f"Lora successfully added: {lora_name}")
+                success.append(lora_name)
+
+            self.generator.generator.set_loras(loras_to_load)
+            logger.info("All loras successfully loaded")
+
+            # Return success and failure names
+            return {"success": success, "failure": failure}
+        finally:
+            self.load_lock.release()
+
+            async with self.load_condition:
+                self.load_condition.notify_all()
+
+    async def unload(self, loras_only: bool = False, **kwargs):
         """
         Free all VRAM resources used by this model
         """

-        for lora in self.active_loras:
-            lora.unload()
+        try:
+            await self.load_lock.acquire()

-        self.active_loras = []
+            # Wait for other jobs to finish
+            await self.wait_for_jobs(kwargs.get("skip_wait"))

-        # Unload the entire model if not just unloading loras
-        if not loras_only:
-            if self.model:
-                self.model.unload()
-            self.model = None
+            if self.generator and self.generator.generator.current_loras:
+                for lora in self.generator.generator.current_loras:
+                    lora.unload()

-            if self.draft_model:
-                self.draft_model.unload()
-            self.draft_model = None
+                self.generator.generator.set_loras([])

-            self.config = None
-            self.cache = None
-            self.tokenizer = None
-            self.generator = None
+            # Unload the entire model if not just unloading loras
+            if not loras_only:
+                if self.model:
+                    self.model.unload()
+                self.model = None

-        # Set all model state variables to False
-        self.model_is_loading = False
-        self.model_loaded = False
+                if self.draft_model:
+                    self.draft_model.unload()
+                self.draft_model = None

-        gc.collect()
-        torch.cuda.empty_cache()
+                self.config = None
+                self.cache = None
+                self.tokenizer = None

-        logger.info("Loras unloaded." if loras_only else "Model unloaded.")
+                # Cleanup the generator from any pending jobs
+                await self.generator.close()
+                self.generator = None
+
+                # Set all model state variables to False
+                self.model_is_loading = False
+                self.model_loaded = False
+
+            gc.collect()
+            torch.cuda.empty_cache()
+
+            logger.info("Loras unloaded." if loras_only else "Model unloaded.")
+        finally:
+            self.load_lock.release()
+
+            async with self.load_condition:
+                self.load_condition.notify_all()

     def encode_tokens(self, text: str, **kwargs):
         """Wrapper to encode tokens from a text string"""
@@ -683,6 +752,10 @@ class ExllamaV2Container:
         for kwargs, check common/sampling.py
         """

+        # Wait for load lock to be freed before processing
+        async with self.load_condition:
+            await self.load_condition.wait_for(lambda: not self.load_lock.locked())
+
         prompts = [prompt]

         token_healing = unwrap(kwargs.get("token_healing"), False)
@@ -951,79 +1024,84 @@
         )

         # Save generated tokens and full response
+        # Copy over max seq len in case model is unloaded and stored jobs can complete
         # Full response is required for offset calculation
+        max_seq_len = self.config.max_seq_len
         generated_tokens = 0
         full_response = ""

-        # Get the generation status once it's ready
-        async for result in job:
-            stage = result.get("stage")
-            result_id = result.get("identifier")
+        try:
+            # Get the generation status once it's ready
+            async for result in job:
+                stage = result.get("stage")
+                result_id = result.get("identifier")

-            if stage == "streaming" and result_id == job_id:
-                chunk = unwrap(result.get("text"), "")
-                full_response += chunk
+                if stage == "streaming" and result_id == job_id:
+                    chunk = unwrap(result.get("text"), "")
+                    full_response += chunk

-                chunk_tokens = result.get("token_ids")
-                if chunk_tokens is not None:
-                    generated_tokens += chunk_tokens.size(dim=0)
+                    chunk_tokens = result.get("token_ids")
+                    if chunk_tokens is not None:
+                        generated_tokens += chunk_tokens.size(dim=0)

-                generation = {
-                    "text": chunk,
-                    "prompt_tokens": context_len,
-                    "generated_tokens": generated_tokens,
-                    "offset": len(full_response),
-                }
-
-                if request_logprobs > 0:
-                    # Get top tokens and probs
-                    top_tokens = unwrap(
-                        result.get("top_k_tokens"),
-                        torch.empty((1, 0, 1), dtype=torch.long),
-                    )
-
-                    top_probs = unwrap(
-                        result.get("top_k_probs"),
-                        torch.empty((1, 0, 1), dtype=torch.float),
-                    )
-
-                    if top_tokens.numel() > 0 and top_probs.numel() > 0:
-                        logprobs = self.get_logprobs(top_tokens, top_probs)
-                        generation["logprobs"] = logprobs
-
-                        # The first logprob is the selected token prob
-                        generation["token_probs"] = {
-                            token: logprobs[token]
-                            for token in list(logprobs.keys())[:1]
-                        }
-
-                yield generation
-
-                # Second yield if eos is true
-                if result.get("eos"):
-                    log_response(full_response)
-
-                    eos_reason = result.get("eos_reason")
-                    finish_reason = (
-                        "length" if eos_reason == "max_new_tokens" else "stop"
-                    )
-
-                    log_metrics(
-                        result.get("time_enqueued"),
-                        result.get("prompt_tokens"),
-                        result.get("time_prefill"),
-                        result.get("new_tokens"),
-                        result.get("time_generate"),
-                        context_len,
-                        self.config.max_seq_len,
-                    )
-
-                    # Remove the token text
                     generation = {
-                        "prompt_tokens": generation.get("prompt_tokens"),
-                        "generated_tokens": generation.get("generated_tokens"),
-                        "finish_reason": finish_reason,
+                        "text": chunk,
+                        "prompt_tokens": context_len,
+                        "generated_tokens": generated_tokens,
+                        "offset": len(full_response),
                     }

+                    if request_logprobs > 0:
+                        # Get top tokens and probs
+                        top_tokens = unwrap(
+                            result.get("top_k_tokens"),
+                            torch.empty((1, 0, 1), dtype=torch.long),
+                        )
+
+                        top_probs = unwrap(
+                            result.get("top_k_probs"),
+                            torch.empty((1, 0, 1), dtype=torch.float),
+                        )
+
+                        if top_tokens.numel() > 0 and top_probs.numel() > 0:
+                            logprobs = self.get_logprobs(top_tokens, top_probs)
+                            generation["logprobs"] = logprobs
+
+                            # The first logprob is the selected token prob
+ generation["token_probs"] = { + token: logprobs[token] + for token in list(logprobs.keys())[:1] + } + yield generation - break + + # Second yield if eos is true + if result.get("eos"): + log_response(full_response) + + eos_reason = result.get("eos_reason") + finish_reason = ( + "length" if eos_reason == "max_new_tokens" else "stop" + ) + + log_metrics( + result.get("time_enqueued"), + result.get("prompt_tokens"), + result.get("time_prefill"), + result.get("new_tokens"), + result.get("time_generate"), + context_len, + max_seq_len, + ) + + # Remove the token text + generation = { + "prompt_tokens": generation.get("prompt_tokens"), + "generated_tokens": generation.get("generated_tokens"), + "finish_reason": finish_reason, + } + + yield generation + break + except asyncio.CancelledError: + await job.cancel() diff --git a/common/concurrency.py b/common/concurrency.py index c1eefc6..8939432 100644 --- a/common/concurrency.py +++ b/common/concurrency.py @@ -1,12 +1,8 @@ """Concurrency handling""" import asyncio -import inspect from fastapi.concurrency import run_in_threadpool # noqa -from functools import partialmethod -from typing import AsyncGenerator, Generator, Union - -generate_semaphore = asyncio.Semaphore(1) +from typing import AsyncGenerator, Generator # Originally from https://github.com/encode/starlette/blob/master/starlette/concurrency.py @@ -34,24 +30,3 @@ async def iterate_in_threadpool(generator: Generator) -> AsyncGenerator: yield await asyncio.to_thread(gen_next, generator) except _StopIteration: break - - -async def generate_with_semaphore(generator: Union[AsyncGenerator, Generator]): - """Generate with a semaphore.""" - - async with generate_semaphore: - if not inspect.isasyncgenfunction: - generator = iterate_in_threadpool(generator()) - - async for result in generator(): - yield result - - -async def call_with_semaphore(callback: partialmethod): - """Call with a semaphore.""" - - async with generate_semaphore: - if not inspect.iscoroutinefunction: - callback = run_in_threadpool(callback) - - return await callback() diff --git a/common/model.py b/common/model.py index f916f5a..3ec2614 100644 --- a/common/model.py +++ b/common/model.py @@ -20,11 +20,11 @@ def load_progress(module, modules): yield module, modules -async def unload_model(): +async def unload_model(skip_wait: bool = False): """Unloads a model""" global container - container.unload() + await container.unload(skip_wait=skip_wait) container = None @@ -49,7 +49,7 @@ async def load_model_gen(model_path: pathlib.Path, **kwargs): container = ExllamaV2Container(model_path.resolve(), False, **kwargs) model_type = "draft" if container.draft_config else "model" - load_status = container.load_gen(load_progress) + load_status = container.load_gen(load_progress, **kwargs) progress = get_loading_progress_bar() progress.start() @@ -81,12 +81,12 @@ async def load_model(model_path: pathlib.Path, **kwargs): async def load_loras(lora_dir, **kwargs): """Wrapper to load loras.""" - if len(container.active_loras) > 0: - unload_loras() + if len(container.get_loras()) > 0: + await unload_loras() return await container.load_loras(lora_dir, **kwargs) -def unload_loras(): +async def unload_loras(): """Wrapper to unload loras""" - container.unload(loras_only=True) + await container.unload(loras_only=True) diff --git a/endpoints/OAI/router.py b/endpoints/OAI/router.py index a09bf03..b7c594f 100644 --- a/endpoints/OAI/router.py +++ b/endpoints/OAI/router.py @@ -1,18 +1,12 @@ import asyncio import pathlib from fastapi import APIRouter, Depends, 
 from fastapi import APIRouter, Depends, HTTPException, Header, Request
-from functools import partial
-from loguru import logger
 from sse_starlette import EventSourceResponse
 from sys import maxsize
 from typing import Optional

 from common import config, model, gen_logging, sampling
 from common.auth import check_admin_key, check_api_key, validate_key_permission
-from common.concurrency import (
-    call_with_semaphore,
-    generate_with_semaphore,
-)
 from common.downloader import hf_repo_download
 from common.networking import handle_request_error, run_with_request_disconnect
 from common.templating import PromptTemplate, get_all_templates
@@ -141,7 +135,7 @@ async def list_draft_models():

 # Load model endpoint
 @router.post("/v1/model/load", dependencies=[Depends(check_admin_key)])
-async def load_model(request: Request, data: ModelLoadRequest):
+async def load_model(data: ModelLoadRequest):
     """Loads a model into the model container."""

     # Verify request parameters
@@ -178,18 +172,9 @@ async def load_model(request: Request, data: ModelLoadRequest):

         raise HTTPException(400, error_message)

-    load_callback = partial(stream_model_load, data, model_path, draft_model_path)
-
-    # Wrap in a semaphore if the queue isn't being skipped
-    if data.skip_queue:
-        logger.warning(
-            "Model load request is skipping the completions queue. "
-            "Unexpected results may occur."
-        )
-    else:
-        load_callback = partial(generate_with_semaphore, load_callback)
-
-    return EventSourceResponse(load_callback(), ping=maxsize)
+    return EventSourceResponse(
+        stream_model_load(data, model_path, draft_model_path), ping=maxsize
+    )


 # Unload model endpoint
@@ -199,7 +184,7 @@ async def load_model(request: Request, data: ModelLoadRequest):
 )
 async def unload_model():
     """Unloads the currently loaded model."""
-    await model.unload_model()
+    await model.unload_model(skip_wait=True)


 @router.get("/v1/templates", dependencies=[Depends(check_api_key)])
@@ -335,15 +320,13 @@ async def get_all_loras():
 async def get_active_loras():
     """Returns the currently loaded loras."""
     active_loras = LoraList(
-        data=list(
-            map(
-                lambda lora: LoraCard(
-                    id=pathlib.Path(lora.lora_path).parent.name,
-                    scaling=lora.lora_scaling * lora.lora_r / lora.lora_alpha,
-                ),
-                model.container.active_loras,
+        data=[
+            LoraCard(
+                id=pathlib.Path(lora.lora_path).parent.name,
+                scaling=lora.lora_scaling * lora.lora_r / lora.lora_alpha,
             )
-        )
+            for lora in model.container.get_loras()
+        ]
     )

     return active_loras
@@ -374,18 +357,9 @@ async def load_lora(data: LoraLoadRequest):

         raise HTTPException(400, error_message)

-    load_callback = partial(model.load_loras, lora_dir, **data.model_dump())
-
-    # Wrap in a semaphore if the queue isn't being skipped
-    if data.skip_queue:
-        logger.warning(
-            "Lora load request is skipping the completions queue. "
-            "Unexpected results may occur."
-        )
-    else:
-        load_callback = partial(call_with_semaphore, load_callback)
-
-    load_result = await load_callback()
+    load_result = await model.load_loras(
+        lora_dir, **data.model_dump(), skip_wait=data.skip_queue
+    )

     return LoraLoadResponse(
         success=unwrap(load_result.get("success"), []),
@@ -401,7 +375,7 @@ async def load_lora(data: LoraLoadRequest):
 async def unload_loras():
     """Unloads the currently loaded loras."""

-    model.unload_loras()
+    await model.unload_loras()


 # Encode tokens endpoint
@@ -494,16 +468,12 @@ async def completion_request(request: Request, data: CompletionRequest):
         data.json_schema = {"type": "object"}

     if data.stream and not disable_request_streaming:
-        generator_callback = partial(stream_generate_completion, data, model_path)
-
         return EventSourceResponse(
-            generate_with_semaphore(generator_callback),
+            stream_generate_completion(data, model_path),
             ping=maxsize,
         )
     else:
-        generate_task = asyncio.create_task(
-            call_with_semaphore(partial(generate_completion, data, model_path))
-        )
+        generate_task = asyncio.create_task(generate_completion(data, model_path))

         response = await run_with_request_disconnect(
             request,
@@ -545,19 +515,13 @@ async def chat_completion_request(request: Request, data: ChatCompletionRequest)
         )

     if data.stream and not disable_request_streaming:
-        generator_callback = partial(
-            stream_generate_chat_completion, prompt, data, model_path
-        )
-
         return EventSourceResponse(
-            generate_with_semaphore(generator_callback),
+            stream_generate_chat_completion(prompt, data, model_path),
             ping=maxsize,
         )
     else:
         generate_task = asyncio.create_task(
-            call_with_semaphore(
-                partial(generate_chat_completion, prompt, data, model_path)
-            )
+            generate_chat_completion(prompt, data, model_path)
         )

         response = await run_with_request_disconnect(
diff --git a/endpoints/OAI/utils/model.py b/endpoints/OAI/utils/model.py
index 66c7625..0502193 100644
--- a/endpoints/OAI/utils/model.py
+++ b/endpoints/OAI/utils/model.py
@@ -43,7 +43,9 @@ async def stream_model_load(
     if draft_model_path:
         load_data["draft"]["draft_model_dir"] = draft_model_path

-    load_status = model.load_model_gen(model_path, **load_data)
+    load_status = model.load_model_gen(
+        model_path, skip_wait=data.skip_queue, **load_data
+    )
     try:
         async for module, modules, model_type in load_status:
             if module != 0: