tabbyAPI-ollama/backends/infinity/model.py
kingbri f070587e9f Model: Add proper jobs cleanup and fix var calls
Jobs should be started and immediately cleaned up when calling the
generation stream. Expose a stream_generate function and add it
to the base class, since it's more idiomatic than generate_gen.

The exl2 container's generate_gen function is now internal.

Signed-off-by: kingbri <8082010+kingbri1@users.noreply.github.com>
2025-04-24 21:30:55 -04:00
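To make the pattern described in the commit message concrete, here is a minimal sketch of a base-class stream_generate wrapper that starts a job, streams results, and cleans the job up as soon as the stream ends. This is an illustration only: BaseModelContainer, _start_job, _cancel_job, and _generate_gen are hypothetical names, not tabbyAPI's actual API.

# Hypothetical sketch; the method and class names below are assumptions.
from typing import Any, AsyncIterator


class BaseModelContainer:
    async def _start_job(self, prompt: str, **kwargs) -> Any:
        # Backend-specific: allocate a generation job handle
        raise NotImplementedError

    async def _cancel_job(self, job: Any) -> None:
        # Backend-specific: release the job and its resources
        raise NotImplementedError

    def _generate_gen(self, job: Any, **kwargs) -> AsyncIterator[dict]:
        # Backend-specific internal generator (kept private,
        # as the exl2 container's generate_gen now is)
        raise NotImplementedError

    async def stream_generate(self, prompt: str, **kwargs) -> AsyncIterator[dict]:
        # Public entrypoint: start the job, stream chunks, and always
        # clean up, even if the consumer stops iterating early
        job = await self._start_job(prompt, **kwargs)
        try:
            async for chunk in self._generate_gen(job, **kwargs):
                yield chunk
        finally:
            await self._cancel_job(job)

Wrapping the internal generator in try/finally is what guarantees the "immediately cleaned up" behavior: the job is released whether the stream completes, errors, or is cancelled mid-iteration.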

import gc
import pathlib
import torch
from loguru import logger
from typing import List, Optional

from common.utils import unwrap
from common.optional_dependencies import dependencies

# Conditionally import infinity to sidestep its logger
if dependencies.extras:
    from infinity_emb import EngineArgs, AsyncEmbeddingEngine


class InfinityContainer:
    model_dir: pathlib.Path
    loaded: bool = False

    # Use a runtime type hint here
    engine: Optional["AsyncEmbeddingEngine"] = None

    def __init__(self, model_directory: pathlib.Path):
        self.model_dir = model_directory

    async def load(self, **kwargs):
        # Use cpu by default
        device = unwrap(kwargs.get("embeddings_device"), "cpu")

        engine_args = EngineArgs(
            model_name_or_path=str(self.model_dir),
            engine="torch",
            device=device,
            bettertransformer=False,
            model_warmup=False,
        )

        self.engine = AsyncEmbeddingEngine.from_args(engine_args)
        await self.engine.astart()

        self.loaded = True
        logger.info("Embedding model successfully loaded.")

    async def unload(self):
        await self.engine.astop()
        self.engine = None

        # Reset the loaded flag so the container reports its real state
        self.loaded = False

        # Release any cached GPU memory held by the embedding engine
        gc.collect()
        torch.cuda.empty_cache()

        logger.info("Embedding model unloaded.")

    async def generate(self, sentence_input: List[str]):
        result_embeddings, usage = await self.engine.embed(sentence_input)

        return {"embeddings": result_embeddings, "usage": usage}