Model: Add exl3 and associated load functions
Initial exl3 compat and loading functionality.

Signed-off-by: kingbri <8082010+kingbri1@users.noreply.github.com>

parent 7c6a053747 · commit 0c1d794390

5 changed files with 357 additions and 67 deletions
backends/base_model_container.py

@@ -25,6 +25,10 @@ class BaseModelContainer(abc.ABC):
     prompt_template: Optional[PromptTemplate] = None
     generation_config: Optional[GenerationConfig] = None

+    # Optional features
+    use_draft_model: bool = False
+    use_vision: bool = False
+
     # Load synchronization
     # The bool is a master switch for accepting requests
     # The lock keeps load tasks sequential
@@ -65,7 +69,7 @@ class BaseModelContainer(abc.ABC):

     # NOTE: Might be an optional method
     @abc.abstractmethod
-    async def load_gen(self, progress_callback=None, **kwargs) -> AsyncIterator[Any]:
+    async def load_gen(self, progress_callback=None, **kwargs):
         """
         Loads the model into memory, yielding progress updates.
@@ -134,57 +138,6 @@ class BaseModelContainer(abc.ABC):
         pass

-    @abc.abstractmethod
-    async def generate(
-        self,
-        request_id: str,
-        prompt: str,
-        params: BaseSamplerRequest,
-        abort_event: Optional[asyncio.Event] = None,
-        mm_embeddings: Optional[MultimodalEmbeddingWrapper] = None,
-    ) -> Dict[str, Any]:
-        """
-        Generates a complete response for a given prompt and parameters.
-
-        Args:
-            request_id: Unique identifier for the generation request.
-            prompt: The input prompt string.
-            params: Sampling and generation parameters.
-            abort_event: An asyncio Event to signal cancellation.
-            mm_embeddings: Optional multimodal embeddings.
-
-        Returns:
-            A dictionary containing the generation info
-        """
-
-        pass
-
-    @abc.abstractmethod
-    async def stream_generate(
-        self,
-        request_id: str,
-        prompt: str,
-        params: BaseSamplerRequest,
-        abort_event: Optional[asyncio.Event] = None,
-        mm_embeddings: Optional[MultimodalEmbeddingWrapper] = None,
-    ) -> AsyncIterator[Dict[str, Any]]:
-        """
-        Generates a response iteratively (streaming) for a given prompt.
-
-        Args:
-            request_id: Unique identifier for the generation request.
-            prompt: The input prompt string.
-            params: Sampling and generation parameters.
-            abort_event: An asyncio Event to signal cancellation.
-            mm_embeddings: Optional multimodal embeddings.
-
-        Yields:
-            Generation chunks
-        """
-
-        if False:
-            yield
-
     @abc.abstractmethod
     def model_info(self) -> ModelCard:
         """
@@ -239,3 +192,54 @@ class BaseModelContainer(abc.ABC):
         """

         return []
+
+    @abc.abstractmethod
+    async def generate(
+        self,
+        request_id: str,
+        prompt: str,
+        params: BaseSamplerRequest,
+        abort_event: Optional[asyncio.Event] = None,
+        mm_embeddings: Optional[MultimodalEmbeddingWrapper] = None,
+    ) -> Dict[str, Any]:
+        """
+        Generates a complete response for a given prompt and parameters.
+
+        Args:
+            request_id: Unique identifier for the generation request.
+            prompt: The input prompt string.
+            params: Sampling and generation parameters.
+            abort_event: An asyncio Event to signal cancellation.
+            mm_embeddings: Optional multimodal embeddings.
+
+        Returns:
+            A dictionary containing the generation info
+        """
+
+        pass
+
+    @abc.abstractmethod
+    async def stream_generate(
+        self,
+        request_id: str,
+        prompt: str,
+        params: BaseSamplerRequest,
+        abort_event: Optional[asyncio.Event] = None,
+        mm_embeddings: Optional[MultimodalEmbeddingWrapper] = None,
+    ) -> AsyncIterator[Dict[str, Any]]:
+        """
+        Generates a response iteratively (streaming) for a given prompt.
+
+        Args:
+            request_id: Unique identifier for the generation request.
+            prompt: The input prompt string.
+            params: Sampling and generation parameters.
+            abort_event: An asyncio Event to signal cancellation.
+            mm_embeddings: Optional multimodal embeddings.
+
+        Yields:
+            Generation chunks
+        """
+
+        if False:
+            yield
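For orientation, this interface is consumed as an async workflow: create() builds the container, load_gen() is an async generator of progress updates, and generate()/stream_generate() produce results (the bare "if False: yield" in the abstract stream_generate body is the idiom that marks it as an async generator). Below is a minimal driver sketch; the container class, request id, and sampler params are placeholders, and only the call pattern comes from the abstract methods in this commit.

# Hedged sketch: drives any BaseModelContainer implementation end to end.
# "container_cls", "model_dir", and "sampler_params" are hypothetical inputs.
import asyncio


async def run_backend(container_cls, model_dir, sampler_params):
    # Build the container asynchronously
    container = await container_cls.create(model_dir)

    # Stream progress updates while the model loads
    async for progress in container.load_gen():
        print("load progress:", progress)

    # Streaming generation yields chunks until the backend finishes
    async for chunk in container.stream_generate(
        request_id="req-1",
        prompt="Hello",
        params=sampler_params,
        abort_event=asyncio.Event(),
    ):
        print(chunk)

    # Release model resources when done
    await container.unload()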
backends/exllamav2/model.py

@@ -64,16 +64,19 @@ class ExllamaV2Container(BaseModelContainer):

     # Exl2 vars
     config: Optional[ExLlamaV2Config] = None
-    draft_config: Optional[ExLlamaV2Config] = None
     model: Optional[ExLlamaV2] = None
-    draft_model: Optional[ExLlamaV2] = None
     cache: Optional[ExLlamaV2Cache] = None
-    draft_cache: Optional[ExLlamaV2Cache] = None
     tokenizer: Optional[ExLlamaV2Tokenizer] = None
     generator: Optional[ExLlamaV2DynamicGeneratorAsync] = None
     prompt_template: Optional[PromptTemplate] = None
     paged: bool = True

+    # Draft model vars
+    use_draft_model: bool = False
+    draft_config: Optional[ExLlamaV2Config] = None
+    draft_model: Optional[ExLlamaV2] = None
+    draft_cache: Optional[ExLlamaV2Cache] = None
+
     # Internal config vars
     cache_size: int = None
     cache_mode: str = "FP16"
@@ -100,7 +103,7 @@ class ExllamaV2Container(BaseModelContainer):
     load_condition: asyncio.Condition = asyncio.Condition()

     @classmethod
-    async def create(cls, model_directory: pathlib.Path, quiet=False, **kwargs):
+    async def create(cls, model_directory: pathlib.Path, **kwargs):
         """
         Primary asynchronous initializer for model container.
@@ -110,8 +113,6 @@ class ExllamaV2Container(BaseModelContainer):
         # Create a new instance as a "fake self"
         self = cls()

-        self.quiet = quiet
-
         # Initialize config
         self.config = ExLlamaV2Config()
         self.model_dir = model_directory
@@ -122,6 +123,7 @@ class ExllamaV2Container(BaseModelContainer):
         self.config.max_seq_len = 4096

         self.config.prepare()
+        print(self.config.max_seq_len)

         # Check if the model arch is compatible with various exl2 features
         self.config.arch_compat_overrides()
@@ -162,7 +164,7 @@ class ExllamaV2Container(BaseModelContainer):
         # Prepare the draft model config if necessary
         draft_args = unwrap(kwargs.get("draft_model"), {})
         draft_model_name = draft_args.get("draft_model_name")
-        enable_draft = draft_args and draft_model_name
+        self.use_draft_model = draft_args and draft_model_name

         # Always disable draft if params are incorrectly configured
         if draft_args and draft_model_name is None:
@@ -170,9 +172,9 @@ class ExllamaV2Container(BaseModelContainer):
                 "Draft model is disabled because a model name "
                 "wasn't provided. Please check your config.yml!"
             )
-            enable_draft = False
+            self.use_draft_model = False

-        if enable_draft:
+        if self.use_draft_model:
             self.draft_config = ExLlamaV2Config()
             draft_model_path = pathlib.Path(
                 unwrap(draft_args.get("draft_model_dir"), "models")
@@ -365,7 +367,7 @@ class ExllamaV2Container(BaseModelContainer):
             self.config.max_attention_size = chunk_size**2

         # Set user-configured draft model values
-        if enable_draft:
+        if self.use_draft_model:
             self.draft_config.max_seq_len = self.config.max_seq_len

             self.draft_config.scale_pos_emb = unwrap(
backends/exllamav3/model.py (new file, +275 lines)

@@ -0,0 +1,275 @@
import asyncio
import gc
import pathlib
from loguru import logger
from typing import (
    Any,
    AsyncIterator,
    Dict,
    List,
    Optional,
)

import torch

from backends.base_model_container import BaseModelContainer
from common.concurrency import iterate_in_threadpool
from common.multimodal import MultimodalEmbeddingWrapper
from common.sampling import BaseSamplerRequest
from common.templating import PromptTemplate
from common.transformers_utils import GenerationConfig
from endpoints.core.types.model import ModelCard

from exllamav3 import Config, Model, Cache, Tokenizer

class ExllamaV3Container(BaseModelContainer):
    """Abstract base class for model containers."""

    # Exposed model information
    model_dir: pathlib.Path = pathlib.Path("models")
    prompt_template: Optional[PromptTemplate] = None
    generation_config: Optional[GenerationConfig] = None

    # Load synchronization
    # The bool is a master switch for accepting requests
    # The lock keeps load tasks sequential
    # The condition notifies any waiting tasks
    active_job_ids: Dict[str, Any] = {}
    loaded: bool = False
    load_lock: asyncio.Lock = asyncio.Lock()
    load_condition: asyncio.Condition = asyncio.Condition()

    # Exl3 vars
    model: Model
    cache: Cache
    tokenizer: Tokenizer
    config: Config

    # Required methods
    @classmethod
    async def create(cls, model_directory: pathlib.Path, **kwargs):
        """
        Asynchronously creates and initializes a model container instance.

        Args:
            model_directory: Path to the model files.
            **kwargs: Backend-specific configuration options.

        Returns:
            An instance of the implementing class.
        """

        self = cls()

        logger.warning(
            "ExllamaV3 is currently in an alpha state. "
            "Please note that all config options may not work."
        )

        self.config = Config.from_directory(model_directory.resolve())
        self.model = Model.from_config(self.config)
        self.tokenizer = Tokenizer.from_config(self.config)

        max_seq_len = kwargs.get("max_seq_len")
        self.cache = Cache(self.model, max_num_tokens=max_seq_len)

        return self
    async def load(self, progress_callback=None, **kwargs):
        """
        Loads the model into memory.

        Args:
            progress_callback: Optional callback for progress updates.
            **kwargs: Additional loading options.
        """

        async for _ in self.load_gen(progress_callback):
            pass
    async def load_gen(self, progress_callback=None, **kwargs):
        """
        Loads the model into memory, yielding progress updates.

        Args:
            progress_callback: Optional callback for progress updates.
            **kwargs: Additional loading options.

        Yields:
            Progress updates
        """

        try:
            await self.load_lock.acquire()

            # Wait for existing generation jobs to finish
            await self.wait_for_jobs(kwargs.get("skip_wait"))

            generator = self.load_model_sync(progress_callback)
            async for module, modules in iterate_in_threadpool(generator):
                yield module, modules

            # Clean up any extra vram usage from torch and cuda
            # (Helps reduce VRAM bottlenecking on Windows)
            gc.collect()
            torch.cuda.empty_cache()

            # Cleanup and update model load state
            self.loaded = True
            logger.info("Model successfully loaded.")
        finally:
            self.load_lock.release()

            async with self.load_condition:
                self.load_condition.notify_all()

    # TODO: Add draft loading
    @torch.inference_mode()
    def load_model_sync(self, progress_callback=None):
        for value in self.model.load_gen(callback=progress_callback):
            if value:
                yield value
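load_gen() keeps the event loop responsive by running the synchronous exllamav3 load generator on a worker thread through iterate_in_threadpool. That helper's implementation is not part of this diff; a minimal sketch of the same idea, assuming only the standard library and not representing the project's actual helper, could look like this:

# Assumed, simplified stand-in for iterate_in_threadpool: iterate a blocking
# generator from async code without stalling the event loop.
import asyncio
from typing import AsyncIterator, Iterator, TypeVar

T = TypeVar("T")


async def iterate_in_threadpool(generator: Iterator[T]) -> AsyncIterator[T]:
    loop = asyncio.get_running_loop()
    # Sentinel marks exhaustion so StopIteration never crosses the thread boundary
    sentinel = object()

    while True:
        # next(generator, sentinel) runs in the default executor thread
        item = await loop.run_in_executor(None, next, generator, sentinel)
        if item is sentinel:
            break
        yield item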
    async def unload(self, loras_only: bool = False, **kwargs):
        """
        Unloads the model and associated resources from memory.

        Args:
            loras_only: If True, only unload LoRAs.
            **kwargs: Additional unloading options (e.g., shutdown).
        """

        try:
            await self.load_lock.acquire()

            # Wait for other jobs to finish
            await self.wait_for_jobs(kwargs.get("skip_wait"))

            self.model.unload()
            self.model = None

            self.config = None
            self.cache = None
            self.tokenizer = None

            gc.collect()
            torch.cuda.empty_cache()

            logger.info("Model unloaded.")
        finally:
            self.load_lock.release()

            async with self.load_condition:
                self.load_condition.notify_all()
    def encode_tokens(self, text: str, **kwargs) -> List[int]:
        """
        Encodes a string of text into a list of token IDs.

        Args:
            text: The input text string.
            **kwargs: Backend-specific encoding options (e.g., add_bos_token).

        Returns:
            A list of integer token IDs.
        """

        pass

    def decode_tokens(self, ids: List[int], **kwargs) -> str:
        """
        Decodes a list of token IDs back into a string.

        Args:
            ids: A list of integer token IDs.
            **kwargs: Backend-specific decoding options (e.g., decode_special_tokens).

        Returns:
            The decoded text string.
        """

        pass

    def get_special_tokens(self, **kwargs) -> Dict[str, Any]:
        """
        Gets special tokens used by the model/tokenizer.

        Args:
            **kwargs: Options like add_bos_token, ban_eos_token.

        Returns:
            A dictionary mapping special token names (e.g., 'bos_token', 'eos_token')
            to their string or ID representation.
        """

        pass

    def model_info(self) -> ModelCard:
        """
        Returns a dictionary of the current model's configuration parameters.

        Returns:
            Model parameters provided by the backend
        """

        pass

    async def wait_for_jobs(self, skip_wait: bool = False):
        """
        Waits for any active generation jobs to complete.

        Args:
            skip_wait: If True, cancel jobs immediately instead of waiting.
        """

        pass
    async def generate(
        self,
        request_id: str,
        prompt: str,
        params: BaseSamplerRequest,
        abort_event: Optional[asyncio.Event] = None,
        mm_embeddings: Optional[MultimodalEmbeddingWrapper] = None,
    ) -> Dict[str, Any]:
        """
        Generates a complete response for a given prompt and parameters.

        Args:
            request_id: Unique identifier for the generation request.
            prompt: The input prompt string.
            params: Sampling and generation parameters.
            abort_event: An asyncio Event to signal cancellation.
            mm_embeddings: Optional multimodal embeddings.

        Returns:
            A dictionary containing the generation info
        """

        pass

    async def stream_generate(
        self,
        request_id: str,
        prompt: str,
        params: BaseSamplerRequest,
        abort_event: Optional[asyncio.Event] = None,
        mm_embeddings: Optional[MultimodalEmbeddingWrapper] = None,
    ) -> AsyncIterator[Dict[str, Any]]:
        """
        Generates a response iteratively (streaming) for a given prompt.

        Args:
            request_id: Unique identifier for the generation request.
            prompt: The input prompt string.
            params: Sampling and generation parameters.
            abort_event: An asyncio Event to signal cancellation.
            mm_embeddings: Optional multimodal embeddings.

        Yields:
            Generation chunks
        """

        if False:
            yield
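A hedged usage sketch for the new container follows; the model path and sequence length are placeholders, the create(), load_gen(), and unload() signatures come from the file above, and generation is omitted because generate()/stream_generate() are still stubs in this commit.

# Hypothetical driver for ExllamaV3Container; paths and values are placeholders.
import asyncio
import pathlib

from backends.exllamav3.model import ExllamaV3Container


async def main():
    model_dir = pathlib.Path("models/my-exl3-model")  # placeholder path

    # create() builds Config/Model/Tokenizer and sizes the cache from max_seq_len
    container = await ExllamaV3Container.create(model_dir, max_seq_len=4096)

    # load_gen() yields (module, modules) progress tuples while weights load
    async for module, modules in container.load_gen():
        print(f"loaded module {module}/{modules}")

    # Release weights, cache, and tokenizer
    await container.unload()


asyncio.run(main())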
Model loading / backend registry module:

@@ -10,7 +10,7 @@ from enum import Enum
 from fastapi import HTTPException
 from loguru import logger
 from ruamel.yaml import YAML
-from typing import Optional
+from typing import Dict, Optional

 from backends.base_model_container import BaseModelContainer
 from common.logger import get_loading_progress_bar
@@ -24,7 +24,7 @@ container: Optional[BaseModelContainer] = None
 embeddings_container = None


-_BACKEND_REGISTRY = {}
+_BACKEND_REGISTRY: Dict[str, BaseModelContainer] = {}

 if dependencies.exllamav2:
     from backends.exllamav2.model import ExllamaV2Container
@@ -32,6 +32,12 @@ if dependencies.exllamav2:

     _BACKEND_REGISTRY["exllamav2"] = ExllamaV2Container

+
+if dependencies.exllamav3:
+    from backends.exllamav3.model import ExllamaV3Container
+
+    _BACKEND_REGISTRY["exllamav3"] = ExllamaV3Container
+

 if dependencies.extras:
     from backends.infinity.model import InfinityContainer
@@ -134,7 +140,9 @@ async def load_model_gen(model_path: pathlib.Path, **kwargs):
             "Available backends: {available_backends}"
         )

-    new_container = await container_class.create(model_path.resolve(), False, **kwargs)
+    new_container: BaseModelContainer = await container_class.create(
+        model_path.resolve(), **kwargs
+    )

     # Add possible types of models that can be loaded
     model_type = [ModelType.MODEL]
@@ -142,7 +150,7 @@ async def load_model_gen(model_path: pathlib.Path, **kwargs):
     if new_container.use_vision:
         model_type.insert(0, ModelType.VISION)

-    if new_container.draft_config:
+    if new_container.use_draft_model:
         model_type.insert(0, ModelType.DRAFT)

     load_status = new_container.load_gen(load_progress, **kwargs)
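The registry maps a backend name to a container class, and load_model_gen awaits that class's create() before streaming load progress. A condensed sketch of that dispatch follows; the backend-selection argument and error text are illustrative, and only the registry layout and the create()/load_gen() calls appear in this diff.

# Assumed, condensed sketch of the registry dispatch pattern.
import pathlib
from typing import Dict

_BACKEND_REGISTRY: Dict[str, type] = {
    # Populated at import time, e.g.:
    # "exllamav2": ExllamaV2Container,
    # "exllamav3": ExllamaV3Container,
}


async def load_model_sketch(model_path: pathlib.Path, backend: str, **kwargs):
    # Resolve the container class for the requested backend
    container_class = _BACKEND_REGISTRY.get(backend)
    if container_class is None:
        raise ValueError(
            f"Unknown backend '{backend}'. "
            f"Available backends: {list(_BACKEND_REGISTRY)}"
        )

    # Build the container, then stream load progress to the caller
    container = await container_class.create(model_path.resolve(), **kwargs)
    async for progress in container.load_gen():
        yield progress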
Optional dependency checks (DependenciesModel):

@@ -13,6 +13,7 @@ class DependenciesModel(BaseModel):

     torch: bool
     exllamav2: bool
+    exllamav3: bool
     flash_attn: bool
     infinity_emb: bool
     sentence_transformers: bool
@@ -25,7 +26,7 @@ class DependenciesModel(BaseModel):
     @computed_field
     @property
     def inference(self) -> bool:
-        return self.torch and self.exllamav2 and self.flash_attn
+        return self.torch and (self.exllamav2 or self.exllamav3) and self.flash_attn


 def is_installed(package_name: str) -> bool:
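The computed inference flag now accepts either ExLlamaV2 or ExLlamaV3 as the loader backend. is_installed() is cut off in this view; one common way to implement such a check, shown here purely as an assumption rather than the project's actual code, is an importlib probe:

# Assumed sketch of a package-presence check like is_installed(); the real
# implementation is truncated in this diff view.
import importlib.util


def is_installed(package_name: str) -> bool:
    # find_spec returns None when the package cannot be imported
    return importlib.util.find_spec(package_name) is not None


# With flags built from such checks, the computed property above is True when
# torch, flash_attn, and at least one ExLlama generation are present:
inference = (
    is_installed("torch")
    and (is_installed("exllamav2") or is_installed("exllamav3"))
    and is_installed("flash_attn")
)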