From f070587e9f263691491501a272d6e50465ce4797 Mon Sep 17 00:00:00 2001
From: kingbri <8082010+kingbri1@users.noreply.github.com>
Date: Thu, 24 Apr 2025 21:30:55 -0400
Subject: [PATCH] Model: Add proper job cleanup and fix variable references

Jobs should be registered when the generation stream starts and cleaned
up as soon as it finishes. Expose a stream_generate function and add it
to the base class since it's more idiomatic than generate_gen. The exl2
container's generate_gen function is now internal.

Signed-off-by: kingbri <8082010+kingbri1@users.noreply.github.com>
---
 backends/base_model_container.py     |  2 +-
 backends/exllamav2/model.py          | 55 +++++++++++++++++++---------
 backends/infinity/model.py           |  4 +-
 common/model.py                      |  4 +-
 endpoints/Kobold/utils/generation.py |  2 +-
 endpoints/OAI/utils/completion.py    |  4 +-
 6 files changed, 45 insertions(+), 26 deletions(-)

diff --git a/backends/base_model_container.py b/backends/base_model_container.py
index aca3a77..6336d4d 100644
--- a/backends/base_model_container.py
+++ b/backends/base_model_container.py
@@ -163,7 +163,7 @@ class BaseModelContainer(abc.ABC):
         pass
 
     @abc.abstractmethod
-    async def generate_gen(
+    async def stream_generate(
         self,
         request_id: str,
         prompt: str,
diff --git a/backends/exllamav2/model.py b/backends/exllamav2/model.py
index 55aa497..d1d364f 100644
--- a/backends/exllamav2/model.py
+++ b/backends/exllamav2/model.py
@@ -879,7 +879,7 @@ class ExllamaV2Container(BaseModelContainer):
     ):
         """Generate a response to a prompt."""
         generations = []
-        async for generation in self.generate_gen(
+        async for generation in self.stream_generate(
             request_id,
             prompt,
             params,
@@ -931,6 +931,42 @@ class ExllamaV2Container(BaseModelContainer):
 
         return joined_generation
 
+    async def stream_generate(
+        self,
+        request_id: str,
+        prompt: str,
+        params: BaseSamplerRequest,
+        abort_event: Optional[asyncio.Event] = None,
+        mm_embeddings: Optional[MultimodalEmbeddingWrapper] = None,
+    ):
+        try:
+            # Wait for load lock to be freed before processing
+            # Mainly used for loras and other operations where the class is available
+            async with self.load_condition:
+                await self.load_condition.wait_for(lambda: not self.load_lock.locked())
+
+            # If the model is being unloaded, don't accept new requests
+            if not self.loaded:
+                raise RuntimeError(
+                    "Model is being unloaded. Cannot process new generation requests."
+                )
+
+            # Mark that the job is running
+            self.active_job_ids[request_id] = None
+
+            # Yield from the internal generator
+            async for generation_chunk in self.generate_gen(
+                request_id=request_id,
+                prompt=prompt,
+                params=params,
+                abort_event=abort_event,
+                mm_embeddings=mm_embeddings,
+            ):
+                yield generation_chunk
+        finally:
+            # Clean up and remove the job from active IDs
+            del self.active_job_ids[request_id]
+
     def check_unsupported_settings(self, params: BaseSamplerRequest):
         """
         Check and warn the user if a sampler is unsupported.
@@ -1165,20 +1201,6 @@ class ExllamaV2Container(BaseModelContainer):
         for kwargs, check common/sampling.py
         """
 
-        # Wait for load lock to be freed before processing
-        # Mainly used for loras and other operations where the class is available
-        async with self.load_condition:
-            await self.load_condition.wait_for(lambda: not self.load_lock.locked())
-
-        # If the model is being unloaded, don't accept new requests
-        if not self.loaded:
-            raise RuntimeError(
-                "Model is being unloaded. Cannot process new generation requests."
-            )
-
-        # Mark that the job is running
-        self.active_job_ids[request_id] = None
-
         prompts = [prompt]
         gen_settings = ExLlamaV2Sampler.Settings()
         grammar_handler = ExLlamaV2Grammar()
@@ -1421,6 +1443,3 @@ class ExllamaV2Container(BaseModelContainer):
             context_len,
             max_seq_len,
         )
-
-        # Remove the job from active IDs
-        del self.active_job_ids[request_id]
diff --git a/backends/infinity/model.py b/backends/infinity/model.py
index 04698e5..c131e3c 100644
--- a/backends/infinity/model.py
+++ b/backends/infinity/model.py
@@ -14,7 +14,7 @@ if dependencies.extras:
 
 class InfinityContainer:
     model_dir: pathlib.Path
-    model_loaded: bool = False
+    loaded: bool = False
 
     # Use a runtime type hint here
     engine: Optional["AsyncEmbeddingEngine"] = None
@@ -37,7 +37,7 @@ class InfinityContainer:
         self.engine = AsyncEmbeddingEngine.from_args(engine_args)
         await self.engine.astart()
 
-        self.model_loaded = True
+        self.loaded = True
         logger.info("Embedding model successfully loaded.")
 
     async def unload(self):
diff --git a/common/model.py b/common/model.py
index 19e8d3c..18dd960 100644
--- a/common/model.py
+++ b/common/model.py
@@ -92,7 +92,7 @@ async def load_model_gen(model_path: pathlib.Path, **kwargs):
     if container and container.model:
         loaded_model_name = container.model_dir.name
 
-        if loaded_model_name == model_path.name and container.model_loaded:
+        if loaded_model_name == model_path.name and container.loaded:
             raise ValueError(
                 f'Model "{loaded_model_name}" is already loaded! Aborting.'
             )
@@ -191,7 +191,7 @@ async def load_embedding_model(model_path: pathlib.Path, **kwargs):
     if embeddings_container and embeddings_container.engine:
         loaded_model_name = embeddings_container.model_dir.name
 
-        if loaded_model_name == model_path.name and embeddings_container.model_loaded:
+        if loaded_model_name == model_path.name and embeddings_container.loaded:
             raise ValueError(
                 f'Embeddings model "{loaded_model_name}" is already loaded! Aborting.'
             )
diff --git a/endpoints/Kobold/utils/generation.py b/endpoints/Kobold/utils/generation.py
index 2086788..39ddb9f 100644
--- a/endpoints/Kobold/utils/generation.py
+++ b/endpoints/Kobold/utils/generation.py
@@ -52,7 +52,7 @@ async def _stream_collector(data: GenerateRequest, request: Request):
     try:
         logger.info(f"Received Kobold generation request {data.genkey}")
 
-        generator = model.container.generate_gen(
+        generator = model.container.stream_generate(
             request_id=data.genkey, abort_event=abort_event, **data.model_dump()
         )
         async for generation in generator:
diff --git a/endpoints/OAI/utils/completion.py b/endpoints/OAI/utils/completion.py
index 1d706d4..8be249c 100644
--- a/endpoints/OAI/utils/completion.py
+++ b/endpoints/OAI/utils/completion.py
@@ -95,7 +95,7 @@ async def _stream_collector(
     """Collects a stream and places results in a common queue"""
 
     try:
-        new_generation = model.container.generate_gen(
+        new_generation = model.container.stream_generate(
             request_id,
             prompt,
             params,
@@ -120,7 +120,7 @@ async def load_inline_model(model_name: str, request: Request):
 
     if (
         model.container
         and model.container.model_dir.name == model_name
-        and model.container.model_loaded
+        and model.container.loaded
     ):
         return
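
Note (not part of the commit): the diff above moves job bookkeeping out of the internal generator and into the public stream_generate wrapper, so the active job ID is always released even if the consumer aborts or errors mid-stream. Below is a minimal, self-contained sketch of that register-then-finally-cleanup pattern; DemoContainer and _generate_chunks are hypothetical stand-ins, not identifiers from this codebase.

# sketch_stream_cleanup.py -- illustrative only; DemoContainer and
# _generate_chunks are hypothetical stand-ins for the patched container.
import asyncio
from typing import AsyncGenerator, Dict, Optional


class DemoContainer:
    def __init__(self) -> None:
        self.loaded = True
        self.active_job_ids: Dict[str, Optional[object]] = {}

    async def _generate_chunks(self, prompt: str) -> AsyncGenerator[str, None]:
        # Stand-in for the internal generator (generate_gen in the patch)
        for token in prompt.split():
            await asyncio.sleep(0)
            yield token

    async def stream_generate(
        self, request_id: str, prompt: str
    ) -> AsyncGenerator[str, None]:
        if not self.loaded:
            raise RuntimeError("Model is being unloaded. Cannot process new requests.")

        # Register the job before streaming so other tasks can see or cancel it
        self.active_job_ids[request_id] = None
        try:
            async for chunk in self._generate_chunks(prompt):
                yield chunk
        finally:
            # Runs on normal completion, errors, and consumer cancellation
            del self.active_job_ids[request_id]


async def main() -> None:
    container = DemoContainer()
    async for chunk in container.stream_generate("req-1", "hello streaming world"):
        print(chunk)
    assert not container.active_job_ids, "job should have been cleaned up"


if __name__ == "__main__":
    asyncio.run(main())

The key point mirrored from the commit is that the finally clause, not the generator body, owns the cleanup, so an aborted or failed stream cannot leave a stale entry in active_job_ids.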