Model: Add proper job cleanup and fix variable references

Jobs should be registered when the generation stream starts and cleaned
up as soon as it finishes. Expose a stream_generate function and add it
to the base class, since the name is more idiomatic than generate_gen.

The exl2 container's generate_gen function is now internal.

Signed-off-by: kingbri <8082010+kingbri1@users.noreply.github.com>
kingbri 2025-04-24 21:30:55 -04:00
parent 7e007f0761
commit f070587e9f
6 changed files with 45 additions and 26 deletions
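In practice the change works like this: the job is registered right before the stream starts producing chunks and removed again in a finally block, so the active job table is cleaned up whether the stream finishes normally, raises, or is closed early by the caller. A minimal sketch of that pattern, assuming a simplified container (the method and attribute names mirror the diff below, but this class is illustrative only, not the actual ExllamaV2Container):

import asyncio
from typing import AsyncGenerator, Dict, Optional


class ContainerSketch:
    """Illustrative stand-in for the model container touched by this commit."""

    def __init__(self) -> None:
        # request_id -> job handle (None until the backend assigns one)
        self.active_job_ids: Dict[str, Optional[object]] = {}

    async def generate_gen(self, request_id: str, prompt: str) -> AsyncGenerator[str, None]:
        # Stand-in for the internal generator that talks to the backend
        for token in prompt.split():
            await asyncio.sleep(0)
            yield token

    async def stream_generate(self, request_id: str, prompt: str) -> AsyncGenerator[str, None]:
        try:
            # Mark that the job is running before any chunks are produced
            self.active_job_ids[request_id] = None

            async for chunk in self.generate_gen(request_id, prompt):
                yield chunk
        finally:
            # Runs on completion, error, or early generator close; pop with a
            # default keeps cleanup safe even if registration never happened
            self.active_job_ids.pop(request_id, None)


async def _demo() -> None:
    container = ContainerSketch()
    async for chunk in container.stream_generate("req-1", "a short prompt"):
        print(chunk)
    assert not container.active_job_ids  # job removed after the stream ends


if __name__ == "__main__":
    asyncio.run(_demo())

The finally clause also runs when a consumer abandons the stream, because closing the async generator triggers it, which is what makes the cleanup reliable for disconnecting clients. Using pop with a default instead of del is a defensive choice in this sketch only; the committed code below uses del.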

@@ -163,7 +163,7 @@ class BaseModelContainer(abc.ABC):
         pass
 
     @abc.abstractmethod
-    async def generate_gen(
+    async def stream_generate(
         self,
         request_id: str,
         prompt: str,

@@ -879,7 +879,7 @@ class ExllamaV2Container(BaseModelContainer):
     ):
         """Generate a response to a prompt."""
         generations = []
-        async for generation in self.generate_gen(
+        async for generation in self.stream_generate(
             request_id,
             prompt,
             params,
@@ -931,6 +931,42 @@ class ExllamaV2Container(BaseModelContainer):
 
         return joined_generation
 
+    async def stream_generate(
+        self,
+        request_id: str,
+        prompt: str,
+        params: BaseSamplerRequest,
+        abort_event: Optional[asyncio.Event] = None,
+        mm_embeddings: Optional[MultimodalEmbeddingWrapper] = None,
+    ):
+        try:
+            # Wait for load lock to be freed before processing
+            # Mainly used for loras and other operations where the class is available
+            async with self.load_condition:
+                await self.load_condition.wait_for(lambda: not self.load_lock.locked())
+
+            # If the model is being unloaded, don't accept new requests
+            if not self.loaded:
+                raise RuntimeError(
+                    "Model is being unloaded. Cannot process new generation requests."
+                )
+
+            # Mark that the job is running
+            self.active_job_ids[request_id] = None
+
+            # Yield from the internal generator
+            async for generation_chunk in self.generate_gen(
+                request_id=request_id,
+                prompt=prompt,
+                params=params,
+                abort_event=abort_event,
+                mm_embeddings=mm_embeddings,
+            ):
+                yield generation_chunk
+        finally:
+            # Clean up and remove the job from active IDs
+            del self.active_job_ids[request_id]
+
     def check_unsupported_settings(self, params: BaseSamplerRequest):
         """
         Check and warn the user if a sampler is unsupported.
@@ -1165,20 +1201,6 @@ class ExllamaV2Container(BaseModelContainer):
         for kwargs, check common/sampling.py
         """
 
-        # Wait for load lock to be freed before processing
-        # Mainly used for loras and other operations where the class is available
-        async with self.load_condition:
-            await self.load_condition.wait_for(lambda: not self.load_lock.locked())
-
-        # If the model is being unloaded, don't accept new requests
-        if not self.loaded:
-            raise RuntimeError(
-                "Model is being unloaded. Cannot process new generation requests."
-            )
-
-        # Mark that the job is running
-        self.active_job_ids[request_id] = None
-
         prompts = [prompt]
         gen_settings = ExLlamaV2Sampler.Settings()
         grammar_handler = ExLlamaV2Grammar()
@@ -1421,6 +1443,3 @@ class ExllamaV2Container(BaseModelContainer):
             context_len,
             max_seq_len,
         )
-
-        # Remove the job from active IDs
-        del self.active_job_ids[request_id]
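For context on the block that moved into stream_generate above: the asyncio.Condition wait simply blocks new generation requests while the load lock is held. A rough sketch of how a loader could pair with that wait is below; only the wait_for side appears in this commit, so the load method and the notify_all call are assumptions for illustration:

import asyncio


class LoadGateSketch:
    """Illustrative pairing for the load_condition / load_lock wait above."""

    def __init__(self) -> None:
        self.load_lock = asyncio.Lock()
        self.load_condition = asyncio.Condition()

    async def load(self) -> None:
        # Hypothetical loader: hold the lock while (re)loading weights or loras
        async with self.load_lock:
            await asyncio.sleep(0)  # placeholder for the actual load work

        # Wake any generation requests waiting on the condition
        # (assumed here; the notify side is not part of this diff)
        async with self.load_condition:
            self.load_condition.notify_all()

    async def wait_until_idle(self) -> None:
        # Same wait that stream_generate performs before accepting a request
        async with self.load_condition:
            await self.load_condition.wait_for(lambda: not self.load_lock.locked())

Since wait_for checks the predicate before sleeping, the common case with no load in progress returns immediately.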

@@ -14,7 +14,7 @@ if dependencies.extras:
 
 class InfinityContainer:
     model_dir: pathlib.Path
-    model_loaded: bool = False
+    loaded: bool = False
 
     # Use a runtime type hint here
     engine: Optional["AsyncEmbeddingEngine"] = None
@@ -37,7 +37,7 @@ class InfinityContainer:
         self.engine = AsyncEmbeddingEngine.from_args(engine_args)
         await self.engine.astart()
 
-        self.model_loaded = True
+        self.loaded = True
         logger.info("Embedding model successfully loaded.")
 
     async def unload(self):

@@ -92,7 +92,7 @@ async def load_model_gen(model_path: pathlib.Path, **kwargs):
     if container and container.model:
         loaded_model_name = container.model_dir.name
 
-        if loaded_model_name == model_path.name and container.model_loaded:
+        if loaded_model_name == model_path.name and container.loaded:
            raise ValueError(
                f'Model "{loaded_model_name}" is already loaded! Aborting.'
            )
@@ -191,7 +191,7 @@ async def load_embedding_model(model_path: pathlib.Path, **kwargs):
     if embeddings_container and embeddings_container.engine:
         loaded_model_name = embeddings_container.model_dir.name
 
-        if loaded_model_name == model_path.name and embeddings_container.model_loaded:
+        if loaded_model_name == model_path.name and embeddings_container.loaded:
            raise ValueError(
                f'Embeddings model "{loaded_model_name}" is already loaded! Aborting.'
            )

@@ -52,7 +52,7 @@ async def _stream_collector(data: GenerateRequest, request: Request):
     try:
         logger.info(f"Received Kobold generation request {data.genkey}")
 
-        generator = model.container.generate_gen(
+        generator = model.container.stream_generate(
             request_id=data.genkey, abort_event=abort_event, **data.model_dump()
         )
         async for generation in generator:

@@ -95,7 +95,7 @@ async def _stream_collector(
     """Collects a stream and places results in a common queue"""
 
     try:
-        new_generation = model.container.generate_gen(
+        new_generation = model.container.stream_generate(
            request_id,
            prompt,
            params,
@@ -120,7 +120,7 @@ async def load_inline_model(model_name: str, request: Request):
     if (
         model.container
         and model.container.model_dir.name == model_name
-        and model.container.model_loaded
+        and model.container.loaded
     ):
         return
 