Model: Add proper jobs cleanup and fix var calls
Jobs should be registered when the generation stream starts and cleaned up immediately when it finishes or aborts. Expose a stream_generate function and add it to the base class, since it's more idiomatic than generate_gen. The exl2 container's generate_gen function is now internal.

Signed-off-by: kingbri <8082010+kingbri1@users.noreply.github.com>
parent 7e007f0761
commit f070587e9f

6 changed files with 45 additions and 26 deletions
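For context, a minimal caller-side sketch of the new pattern, assuming an already-loaded container; the collect_text helper and the chunk's "text" field are illustrative assumptions, not part of this commit:

    # Hypothetical caller of the new public wrapper. Job bookkeeping
    # (active_job_ids) now lives inside stream_generate's try/finally,
    # so callers only iterate the stream.
    async def collect_text(container, request_id: str, prompt: str, params):
        chunks = []
        async for generation in container.stream_generate(request_id, prompt, params):
            # Assumes each chunk is a dict that may carry a "text" field
            chunks.append(generation.get("text", ""))
        return "".join(chunks)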
@@ -163,7 +163,7 @@ class BaseModelContainer(abc.ABC):
         pass
 
     @abc.abstractmethod
-    async def generate_gen(
+    async def stream_generate(
         self,
         request_id: str,
         prompt: str,
@@ -879,7 +879,7 @@ class ExllamaV2Container(BaseModelContainer):
     ):
         """Generate a response to a prompt."""
         generations = []
-        async for generation in self.generate_gen(
+        async for generation in self.stream_generate(
             request_id,
             prompt,
             params,
@@ -931,6 +931,42 @@ class ExllamaV2Container(BaseModelContainer):
 
         return joined_generation
 
+    async def stream_generate(
+        self,
+        request_id: str,
+        prompt: str,
+        params: BaseSamplerRequest,
+        abort_event: Optional[asyncio.Event] = None,
+        mm_embeddings: Optional[MultimodalEmbeddingWrapper] = None,
+    ):
+        try:
+            # Wait for load lock to be freed before processing
+            # Mainly used for loras and other operations where the class is available
+            async with self.load_condition:
+                await self.load_condition.wait_for(lambda: not self.load_lock.locked())
+
+            # If the model is being unloaded, don't accept new requests
+            if not self.loaded:
+                raise RuntimeError(
+                    "Model is being unloaded. Cannot process new generation requests."
+                )
+
+            # Mark that the job is running
+            self.active_job_ids[request_id] = None
+
+            # Yield from the internal generator
+            async for generation_chunk in self.generate_gen(
+                request_id=request_id,
+                prompt=prompt,
+                params=params,
+                abort_event=abort_event,
+                mm_embeddings=mm_embeddings,
+            ):
+                yield generation_chunk
+        finally:
+            # Clean up and remove the job from active IDs
+            del self.active_job_ids[request_id]
+
     def check_unsupported_settings(self, params: BaseSamplerRequest):
         """
         Check and warn the user if a sampler is unsupported.
@@ -1165,20 +1201,6 @@ class ExllamaV2Container(BaseModelContainer):
         for kwargs, check common/sampling.py
         """
 
-        # Wait for load lock to be freed before processing
-        # Mainly used for loras and other operations where the class is available
-        async with self.load_condition:
-            await self.load_condition.wait_for(lambda: not self.load_lock.locked())
-
-        # If the model is being unloaded, don't accept new requests
-        if not self.loaded:
-            raise RuntimeError(
-                "Model is being unloaded. Cannot process new generation requests."
-            )
-
-        # Mark that the job is running
-        self.active_job_ids[request_id] = None
-
         prompts = [prompt]
         gen_settings = ExLlamaV2Sampler.Settings()
         grammar_handler = ExLlamaV2Grammar()
@@ -1421,6 +1443,3 @@ class ExllamaV2Container(BaseModelContainer):
             context_len,
             max_seq_len,
         )
-
-        # Remove the job from active IDs
-        del self.active_job_ids[request_id]
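A note on the cleanup moved above: async generator finally blocks run when the generator is closed, so the del in stream_generate fires even if the consumer stops iterating early or aborts. A self-contained sketch of that behavior (illustrative names, not repo code):

    import asyncio

    active_job_ids = {}

    async def stream(request_id: str):
        try:
            active_job_ids[request_id] = None
            for i in range(5):
                yield i
        finally:
            # Runs on normal exhaustion, on exceptions, and on aclose()
            del active_job_ids[request_id]

    async def main():
        agen = stream("req-1")
        async for chunk in agen:
            if chunk == 2:
                break  # consumer aborts early
        await agen.aclose()  # closing the generator drives the finally block
        assert "req-1" not in active_job_ids

    asyncio.run(main())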
@@ -14,7 +14,7 @@ if dependencies.extras:
 
 class InfinityContainer:
     model_dir: pathlib.Path
-    model_loaded: bool = False
+    loaded: bool = False
 
     # Use a runtime type hint here
     engine: Optional["AsyncEmbeddingEngine"] = None
@@ -37,7 +37,7 @@ class InfinityContainer:
         self.engine = AsyncEmbeddingEngine.from_args(engine_args)
         await self.engine.astart()
 
-        self.model_loaded = True
+        self.loaded = True
         logger.info("Embedding model successfully loaded.")
 
     async def unload(self):
@@ -92,7 +92,7 @@ async def load_model_gen(model_path: pathlib.Path, **kwargs):
     if container and container.model:
         loaded_model_name = container.model_dir.name
 
-        if loaded_model_name == model_path.name and container.model_loaded:
+        if loaded_model_name == model_path.name and container.loaded:
             raise ValueError(
                 f'Model "{loaded_model_name}" is already loaded! Aborting.'
             )
@@ -191,7 +191,7 @@ async def load_embedding_model(model_path: pathlib.Path, **kwargs):
     if embeddings_container and embeddings_container.engine:
         loaded_model_name = embeddings_container.model_dir.name
 
-        if loaded_model_name == model_path.name and embeddings_container.model_loaded:
+        if loaded_model_name == model_path.name and embeddings_container.loaded:
             raise ValueError(
                 f'Embeddings model "{loaded_model_name}" is already loaded! Aborting.'
             )
@@ -52,7 +52,7 @@ async def _stream_collector(data: GenerateRequest, request: Request):
     try:
         logger.info(f"Received Kobold generation request {data.genkey}")
 
-        generator = model.container.generate_gen(
+        generator = model.container.stream_generate(
            request_id=data.genkey, abort_event=abort_event, **data.model_dump()
        )
        async for generation in generator:
@@ -95,7 +95,7 @@ async def _stream_collector(
     """Collects a stream and places results in a common queue"""
 
     try:
-        new_generation = model.container.generate_gen(
+        new_generation = model.container.stream_generate(
             request_id,
             prompt,
             params,
@@ -120,7 +120,7 @@ async def load_inline_model(model_name: str, request: Request):
     if (
         model.container
         and model.container.model_dir.name == model_name
-        and model.container.model_loaded
+        and model.container.loaded
     ):
         return
 