Model: Add proper jobs cleanup and fix var calls
Jobs should be registered when the generation stream starts and cleaned up immediately once it finishes. Expose a stream_generate function and add it to the base class, since it is more idiomatic than generate_gen. The exl2 container's generate_gen function is now internal.

Signed-off-by: kingbri <8082010+kingbri1@users.noreply.github.com>
parent 7e007f0761
commit f070587e9f

6 changed files with 45 additions and 26 deletions
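The heart of the change is the new stream_generate wrapper in the exl2 container: it registers the request in active_job_ids, delegates to the now-internal generate_gen, and relies on try/finally so the job entry is removed even when the consumer stops early or an error is raised. Below is a minimal, self-contained sketch of that pattern; the ContainerSketch class, the _generate_gen stand-in, and the main driver are hypothetical names used only for illustration, not the project's actual code.

import asyncio
from typing import AsyncGenerator, Dict, Optional


class ContainerSketch:
    """Minimal sketch of the job bookkeeping pattern described in this commit."""

    def __init__(self) -> None:
        # request_id -> job handle (None until the backend assigns one)
        self.active_job_ids: Dict[str, Optional[object]] = {}

    async def _generate_gen(self, prompt: str) -> AsyncGenerator[str, None]:
        # Stand-in for the real internal generator (generate_gen in the exl2 container)
        for token in prompt.split():
            await asyncio.sleep(0)
            yield token

    async def stream_generate(self, request_id: str, prompt: str) -> AsyncGenerator[str, None]:
        try:
            # Mark that the job is running
            self.active_job_ids[request_id] = None

            # Yield from the internal generator
            async for chunk in self._generate_gen(prompt):
                yield chunk
        finally:
            # Runs when the generator finishes, errors out, or is closed early,
            # so the job is always removed from the active IDs
            del self.active_job_ids[request_id]


async def main() -> None:
    container = ContainerSketch()
    async for chunk in container.stream_generate("req-1", "hello streaming world"):
        print(chunk)
    # The finally block has already removed the finished job
    assert not container.active_job_ids


asyncio.run(main())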
@@ -163,7 +163,7 @@ class BaseModelContainer(abc.ABC):
         pass

     @abc.abstractmethod
-    async def generate_gen(
+    async def stream_generate(
         self,
         request_id: str,
         prompt: str,
@@ -879,7 +879,7 @@ class ExllamaV2Container(BaseModelContainer):
     ):
         """Generate a response to a prompt."""
         generations = []
-        async for generation in self.generate_gen(
+        async for generation in self.stream_generate(
             request_id,
             prompt,
             params,
@@ -931,6 +931,42 @@ class ExllamaV2Container(BaseModelContainer):

         return joined_generation

+    async def stream_generate(
+        self,
+        request_id: str,
+        prompt: str,
+        params: BaseSamplerRequest,
+        abort_event: Optional[asyncio.Event] = None,
+        mm_embeddings: Optional[MultimodalEmbeddingWrapper] = None,
+    ):
+        try:
+            # Wait for load lock to be freed before processing
+            # Mainly used for loras and other operations where the class is available
+            async with self.load_condition:
+                await self.load_condition.wait_for(lambda: not self.load_lock.locked())
+
+            # If the model is being unloaded, don't accept new requests
+            if not self.loaded:
+                raise RuntimeError(
+                    "Model is being unloaded. Cannot process new generation requests."
+                )
+
+            # Mark that the job is running
+            self.active_job_ids[request_id] = None
+
+            # Yield from the internal generator
+            async for generation_chunk in self.generate_gen(
+                request_id=request_id,
+                prompt=prompt,
+                params=params,
+                abort_event=abort_event,
+                mm_embeddings=mm_embeddings,
+            ):
+                yield generation_chunk
+        finally:
+            # Clean up and remove the job from active IDs
+            del self.active_job_ids[request_id]
+
     def check_unsupported_settings(self, params: BaseSamplerRequest):
         """
         Check and warn the user if a sampler is unsupported.
@@ -1165,20 +1201,6 @@ class ExllamaV2Container(BaseModelContainer):
         for kwargs, check common/sampling.py
         """

-        # Wait for load lock to be freed before processing
-        # Mainly used for loras and other operations where the class is available
-        async with self.load_condition:
-            await self.load_condition.wait_for(lambda: not self.load_lock.locked())
-
-        # If the model is being unloaded, don't accept new requests
-        if not self.loaded:
-            raise RuntimeError(
-                "Model is being unloaded. Cannot process new generation requests."
-            )
-
-        # Mark that the job is running
-        self.active_job_ids[request_id] = None
-
         prompts = [prompt]
         gen_settings = ExLlamaV2Sampler.Settings()
         grammar_handler = ExLlamaV2Grammar()
@@ -1421,6 +1443,3 @@ class ExllamaV2Container(BaseModelContainer):
             context_len,
             max_seq_len,
         )
-
-        # Remove the job from active IDs
-        del self.active_job_ids[request_id]
@@ -14,7 +14,7 @@ if dependencies.extras:

 class InfinityContainer:
     model_dir: pathlib.Path
-    model_loaded: bool = False
+    loaded: bool = False

     # Use a runtime type hint here
     engine: Optional["AsyncEmbeddingEngine"] = None
@@ -37,7 +37,7 @@ class InfinityContainer:
         self.engine = AsyncEmbeddingEngine.from_args(engine_args)
         await self.engine.astart()

-        self.model_loaded = True
+        self.loaded = True
         logger.info("Embedding model successfully loaded.")

     async def unload(self):
@@ -92,7 +92,7 @@ async def load_model_gen(model_path: pathlib.Path, **kwargs):
     if container and container.model:
         loaded_model_name = container.model_dir.name

-        if loaded_model_name == model_path.name and container.model_loaded:
+        if loaded_model_name == model_path.name and container.loaded:
             raise ValueError(
                 f'Model "{loaded_model_name}" is already loaded! Aborting.'
             )
@@ -191,7 +191,7 @@ async def load_embedding_model(model_path: pathlib.Path, **kwargs):
     if embeddings_container and embeddings_container.engine:
         loaded_model_name = embeddings_container.model_dir.name

-        if loaded_model_name == model_path.name and embeddings_container.model_loaded:
+        if loaded_model_name == model_path.name and embeddings_container.loaded:
             raise ValueError(
                 f'Embeddings model "{loaded_model_name}" is already loaded! Aborting.'
             )
@@ -52,7 +52,7 @@ async def _stream_collector(data: GenerateRequest, request: Request):
     try:
         logger.info(f"Received Kobold generation request {data.genkey}")

-        generator = model.container.generate_gen(
+        generator = model.container.stream_generate(
             request_id=data.genkey, abort_event=abort_event, **data.model_dump()
         )
         async for generation in generator:
@@ -95,7 +95,7 @@ async def _stream_collector(
     """Collects a stream and places results in a common queue"""

     try:
-        new_generation = model.container.generate_gen(
+        new_generation = model.container.stream_generate(
             request_id,
             prompt,
             params,
@@ -120,7 +120,7 @@ async def load_inline_model(model_name: str, request: Request):
     if (
         model.container
         and model.container.model_dir.name == model_name
-        and model.container.model_loaded
+        and model.container.loaded
     ):
         return