Jobs should be started and immediately cleaned up when calling the generation stream. Expose a stream_generate function and append this to the base class since it's more idiomatic than generate_gen. The exl2 container's generate_gen function is now internal. Signed-off-by: kingbri <8082010+kingbri1@users.noreply.github.com>
55 lines
1.5 KiB
Python
55 lines
1.5 KiB
Python
import gc
|
|
import pathlib
|
|
import torch
|
|
from loguru import logger
|
|
from typing import List, Optional
|
|
|
|
from common.utils import unwrap
|
|
from common.optional_dependencies import dependencies
|
|
|
|
# Conditionally import infinity to sidestep its logger
|
|
if dependencies.extras:
|
|
from infinity_emb import EngineArgs, AsyncEmbeddingEngine
|
|
|
|
|
|
class InfinityContainer:
    """Lifecycle wrapper around infinity_emb's AsyncEmbeddingEngine.

    Handles loading an embedding model from a local directory, generating
    embeddings for batches of sentences, and tearing the engine down.
    """

    # Directory containing the embedding model's weights/config
    model_dir: pathlib.Path
    # True only while a model is successfully loaded
    loaded: bool = False

    # Use a runtime type hint here — infinity_emb is an optional dependency
    # and may not be importable at module load time.
    engine: Optional["AsyncEmbeddingEngine"] = None

    def __init__(self, model_directory: pathlib.Path):
        self.model_dir = model_directory

    async def load(self, **kwargs):
        """Create and start the async embedding engine.

        Keyword Args:
            embeddings_device: Device to run the model on. Defaults to "cpu".
        """

        # Use cpu by default
        device = unwrap(kwargs.get("embeddings_device"), "cpu")

        engine_args = EngineArgs(
            model_name_or_path=str(self.model_dir),
            engine="torch",
            device=device,
            bettertransformer=False,
            model_warmup=False,
        )

        self.engine = AsyncEmbeddingEngine.from_args(engine_args)
        await self.engine.astart()

        self.loaded = True
        logger.info("Embedding model successfully loaded.")

    async def unload(self):
        """Stop the engine and release its resources.

        Safe to call when no engine is loaded (idempotent).
        """

        # FIX: guard the engine so a double-unload (or an unload before a
        # successful load) is a no-op instead of an AttributeError on None.
        if self.engine is not None:
            await self.engine.astop()
            self.engine = None

        # FIX: reset the loaded flag — the original left it True after
        # unload, so the container kept reporting a model as loaded.
        self.loaded = False

        gc.collect()
        # empty_cache is internally a no-op when CUDA was never initialized,
        # so this is safe on CPU-only runs as well.
        torch.cuda.empty_cache()

        logger.info("Embedding model unloaded.")

    async def generate(self, sentence_input: List[str]):
        """Embed a batch of sentences.

        Args:
            sentence_input: Strings to embed.

        Returns:
            Dict with "embeddings" (one vector per input sentence) and
            "usage" (token usage reported by the engine).
        """

        result_embeddings, usage = await self.engine.embed(sentence_input)

        return {"embeddings": result_embeddings, "usage": usage}
|