tabbyAPI-ollama/backends/exllamav3/model.py
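
"""ExllamaV3 backend model container.

Wraps exllamav3's Model, Cache, Tokenizer, and AsyncGenerator behind
tabbyAPI's BaseModelContainer interface. Model-specific config keys are
read through the universal HFModel class, which coalesces the assorted
HuggingFace JSON config files needed for model usage.
"""
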
import asyncio
import gc
import pathlib
import re
from typing import (
    Any,
    AsyncIterator,
    Dict,
    List,
    Optional,
)

import torch
from exllamav3 import (
    AsyncGenerator,
    AsyncJob,
    Cache,
    Config,
    Model,
    Tokenizer,
)
from exllamav3.cache import CacheLayer_quant
from loguru import logger

from backends.base_model_container import BaseModelContainer
from backends.exllamav3.sampler import ExllamaV3SamplerBuilder
from common.concurrency import iterate_in_threadpool
from common.gen_logging import (
    log_generation_params,
    log_metrics,
)
from common.hardware import hardware_supports_flash_attn
from common.health import HealthManager
from common.multimodal import MultimodalEmbeddingWrapper
from common.sampling import BaseSamplerRequest
from common.templating import PromptTemplate, find_prompt_template
from common.transformers_utils import HFModel
from common.utils import coalesce, unwrap
from endpoints.core.types.model import ModelCard, ModelCardParameters


class ExllamaV3Container(BaseModelContainer):
    """Model container for ExllamaV3 models."""

    # Exposed model information
    model_dir: pathlib.Path = pathlib.Path("models")
    prompt_template: Optional[PromptTemplate] = None

    # HF Model instance
    hf_model: HFModel

    # Load synchronization
    # The bool is a master switch for accepting requests
    # The lock keeps load tasks sequential
    # The condition notifies any waiting tasks
    active_job_ids: Dict[str, Any] = {}
    loaded: bool = False
    load_lock: asyncio.Lock = asyncio.Lock()
    load_condition: asyncio.Condition = asyncio.Condition()

    # Exl3 vars
    model: Optional[Model] = None
    cache: Optional[Cache] = None
    draft_model: Optional[Model] = None
    draft_cache: Optional[Cache] = None
    tokenizer: Optional[Tokenizer] = None
    config: Optional[Config] = None
    draft_config: Optional[Config] = None
    generator: Optional[AsyncGenerator] = None

    # Class-specific vars
    gpu_split: List[float] | None = None
    gpu_split_auto: bool = True
    autosplit_reserve: List[float] = [96 / 1024]
    use_tp: bool = False
    max_seq_len: int = 4096
    cache_size: int = 4096
    cache_mode: str = "FP16"
    draft_cache_mode: str = "FP16"
    chunk_size: int = 2048
    max_batch_size: Optional[int] = None

    # Required methods
    @classmethod
    async def create(cls, model_directory: pathlib.Path, hf_model: HFModel, **kwargs):
        """
        Asynchronously creates and initializes a model container instance.

        Args:
            model_directory: Path to the model files.
            hf_model: Coalesced HuggingFace model config.
            **kwargs: Backend-specific configuration options.

        Returns:
            An instance of the implementing class.
        """

        self = cls()

        logger.warning(
            "ExllamaV3 is currently in an alpha state. "
            "Please note that not all config options may work."
        )

        self.model_dir = model_directory
        self.hf_model = hf_model

        self.config = Config.from_directory(str(model_directory.resolve()))
        self.model = Model.from_config(self.config)
        self.tokenizer = Tokenizer.from_config(self.config)

        # Fall back to 4096 since exl3 can't fetch max_seq_len from HF's config.json
        self.max_seq_len = unwrap(kwargs.get("max_seq_len"), 4096)

        # Prepare the draft model config if necessary
        draft_args = unwrap(kwargs.get("draft_model"), {})
        draft_model_name = draft_args.get("draft_model_name")
        self.use_draft_model = draft_args and draft_model_name

        # Always disable draft if params are incorrectly configured
        if draft_args and draft_model_name is None:
            logger.warning(
                "Draft model is disabled because a model name "
                "wasn't provided. Please check your config.yml!"
            )
            self.use_draft_model = False

        if self.use_draft_model:
            draft_model_path = pathlib.Path(
                unwrap(draft_args.get("draft_model_dir"), "models")
            )
            draft_model_path = draft_model_path / draft_model_name

            self.draft_gpu_split = unwrap(draft_args.get("draft_gpu_split"), [])
            self.draft_model_dir = draft_model_path
            self.draft_config = Config.from_directory(str(draft_model_path.resolve()))
            self.draft_model = Model.from_config(self.draft_config)

            logger.info(f"Using draft model: {str(draft_model_path.resolve())}")
        else:
            self.draft_model = None
            self.draft_cache = None

        # Turn off GPU split if the user is using 1 GPU
        gpu_count = torch.cuda.device_count()
        gpu_split_auto = unwrap(kwargs.get("gpu_split_auto"), True)
        gpu_split = unwrap(kwargs.get("gpu_split"), None)
        gpu_device_list = list(range(0, gpu_count))

        # Set GPU split options
        if gpu_count == 1:
            self.gpu_split_auto = False
            logger.info("Disabling GPU split because one GPU is in use.")
        else:
            # TODO: Set tensor parallel

            # Enable manual GPU split if provided
            if gpu_split:
                self.gpu_split = gpu_split
                gpu_device_list = [
                    device_idx
                    for device_idx, memory in enumerate(self.gpu_split)
                    if memory > 0
                ]
            elif gpu_split_auto and not self.use_tp:
                # Otherwise fall back to autosplit settings
                self.gpu_split_auto = gpu_split_auto

                autosplit_reserve_megabytes = unwrap(
                    kwargs.get("autosplit_reserve"), [96]
                )

                # Reserve VRAM for each GPU
                self.autosplit_reserve = [
                    value / 1024 for value in autosplit_reserve_megabytes
                ]

        if not hardware_supports_flash_attn(gpu_device_list):
            gpu_unsupported_message = (
                "Unable to run ExllamaV3 because an unsupported GPU was "
                "found in this configuration.\n"
                "All GPUs must be Ampere (30 series) or newer. "
                "AMD GPUs are not supported."
            )

            logger.warning(gpu_unsupported_message)
            raise RuntimeError(gpu_unsupported_message)

        # Cache
        user_cache_size = unwrap(kwargs.get("cache_size"), self.max_seq_len)
        self.cache_size = self.adjust_cache_size(user_cache_size)
        self.cache_mode = unwrap(kwargs.get("cache_mode"), "FP16")
        self.cache = self.create_cache(self.cache_mode, self.model)

        # Draft cache
        if self.use_draft_model:
            # Set draft cache mode
            self.draft_cache_mode = unwrap(draft_args.get("draft_cache_mode"), "FP16")
            self.draft_cache = self.create_cache(
                self.draft_cache_mode, self.draft_model
            )

        # Max batch size
        self.max_batch_size = unwrap(kwargs.get("max_batch_size"), 256)

        # Make sure chunk size is >= 256 and stays near or below max seq len
        user_chunk_size = unwrap(kwargs.get("chunk_size"), 2048)
        self.chunk_size = self.adjust_chunk_size(user_chunk_size)

        # Template setup
        self.prompt_template = await find_prompt_template(
            kwargs.get("prompt_template"), model_directory
        )

        # Catch-all for template lookup errors
        if self.prompt_template:
            logger.info(
                f'Using template "{self.prompt_template.name}" for chat completions.'
            )
        else:
            logger.warning(
                "Chat completions are disabled because a prompt "
                "template wasn't provided or auto-detected."
            )

        return self
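
    # Minimal usage sketch (hypothetical kwargs and paths; hf_model is
    # constructed by the caller, and real call sites live in the server's
    # model management code):
    #
    #   container = await ExllamaV3Container.create(
    #       pathlib.Path("models/my-exl3-model"),
    #       hf_model,
    #       max_seq_len=8192,
    #       cache_mode="Q4",
    #   )
    #   await container.load()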

    def adjust_cache_size(self, cache_size):
        if cache_size < self.max_seq_len:
            logger.warning(
                f"The given cache_size ({cache_size}) is smaller than the "
                "desired context length.\n"
                "Overriding cache_size to max_seq_len."
            )

            cache_size = self.max_seq_len

        # Enforce a multiple of 256 for the cache size
        # Overestimate to ensure that the cache isn't below max_seq_len
        cache_remainder = cache_size % 256
        if cache_remainder != 0:
            rounded_cache_size = int(256 * ((cache_size - cache_remainder) / 256 + 1))

            logger.warning(
                f"The given cache size ({cache_size}) is "
                "not a multiple of 256.\n"
                "Overriding cache_size with an overestimated value of "
                f"{rounded_cache_size} tokens."
            )

            cache_size = rounded_cache_size

        # Warn the user if the cache size may be inadequate for CFG
        if cache_size < 2 * self.max_seq_len:
            logger.warning(
                f"The given cache_size ({cache_size}) is less than 2 * max_seq_len "
                "and may be too small for requests using CFG.\n"
                "Ignore this warning if you do not plan on using CFG."
            )

        return cache_size
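
    # Worked example (illustrative numbers): with max_seq_len=4096 and a
    # requested cache_size of 5000, the remainder is 5000 % 256 = 136, so the
    # size is rounded up to 256 * 20 = 5120. The CFG warning still fires
    # because 5120 < 2 * 4096.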

    def adjust_chunk_size(self, user_chunk_size: int):
        # Taking the median of the three values clamps the user's chunk size
        # between 256 and max_seq_len
        chunk_size = sorted((256, user_chunk_size, self.max_seq_len))[1]

        chunk_remainder = chunk_size % 256
        if chunk_remainder != 0:
            rounded_chunk_size = int(256 * ((chunk_size - chunk_remainder) / 256 + 1))

            logger.warning(
                f"The given chunk size ({chunk_size}) is "
                "not a multiple of 256.\n"
                "Overriding chunk_size with an overestimated value of "
                f"{rounded_chunk_size} tokens."
            )

            chunk_size = rounded_chunk_size

        return chunk_size
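
    # Worked example (illustrative numbers): user_chunk_size=3000 with
    # max_seq_len=4096 clamps to 3000, which is then rounded up to 3072.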

    def create_cache(self, raw_cache_mode: str, model: Model):
        # Cast exl2 cache modes to exl3 quant bit widths
        match raw_cache_mode:
            case "Q4":
                raw_cache_mode = "4,4"
            case "Q6":
                raw_cache_mode = "6,6"
            case "Q8":
                raw_cache_mode = "8,8"

        split_cache_mode = re.search(r"^([2-8])\s*,\s*([2-8])$", raw_cache_mode)
        if split_cache_mode:
            k_bits = int(split_cache_mode.group(1))
            v_bits = int(split_cache_mode.group(2))

            cache = Cache(
                model,
                max_num_tokens=self.cache_size,
                layer_type=CacheLayer_quant,
                k_bits=k_bits,
                v_bits=v_bits,
            )
        else:
            cache = Cache(model, max_num_tokens=self.cache_size)

        return cache
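
    # Cache mode examples: "Q4" maps to 4-bit keys and values, "6,8" gives
    # 6-bit keys with 8-bit values, and anything that doesn't match the
    # "<k>,<v>" pattern falls back to an unquantized FP16 cache.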

    def model_info(self) -> ModelCard:
        """
        Returns a ModelCard describing the current model's configuration.

        Returns:
            Model parameters provided by the backend
        """

        model_params = ModelCardParameters(
            max_seq_len=self.max_seq_len,
            cache_size=self.cache_size,
            max_batch_size=self.max_batch_size,
            cache_mode=self.cache_mode,
            chunk_size=self.chunk_size,
            use_vision=self.use_vision,
        )

        if self.prompt_template:
            model_params.prompt_template = self.prompt_template.name
            model_params.prompt_template_content = self.prompt_template.raw_template

        model_card = ModelCard(
            id=self.model_dir.name,
            parameters=model_params,
        )

        return model_card

    async def wait_for_jobs(self, skip_wait: bool = False):
        """
        Polls until all active generation jobs have completed.

        Args:
            skip_wait: If True, cancel jobs immediately instead of waiting.
        """

        if not self.generator:
            return

        # Immediately abort all jobs if asked
        if skip_wait:
            logger.warning(
                "Immediately terminating all jobs. "
                "Clients will have their requests cancelled.\n"
            )

            for job in self.active_job_ids.values():
                if job:
                    await job.cancel()

        while len(self.active_job_ids) > 0:
            await asyncio.sleep(0.01)

    async def load(self, progress_callback=None, **kwargs):
        """
        Loads the model into memory.

        Args:
            progress_callback: Optional callback for progress updates.
            **kwargs: Additional loading options.
        """

        async for _ in self.load_gen(progress_callback):
            pass

    async def load_gen(self, progress_callback=None, **kwargs):
        """
        Loads the model into memory, yielding progress updates.

        Args:
            progress_callback: Optional callback for progress updates.
            **kwargs: Additional loading options.

        Yields:
            Progress updates
        """

        try:
            await self.load_lock.acquire()

            # Wait for existing generation jobs to finish
            await self.wait_for_jobs(kwargs.get("skip_wait"))

            generator = self.load_model_sync(progress_callback)
            async for value in iterate_in_threadpool(generator):
                yield value

            # Create the async generator
            await self.create_generator()

            # Clean up any extra VRAM usage from torch and CUDA
            # (Helps reduce VRAM bottlenecking on Windows)
            gc.collect()
            torch.cuda.empty_cache()

            # Clean up and update the model load state
            self.loaded = True
            logger.info("Model successfully loaded.")
        finally:
            self.load_lock.release()

            async with self.load_condition:
                self.load_condition.notify_all()

    @torch.inference_mode()
    def load_model_sync(self, progress_callback=None):
        if self.use_draft_model:
            for value in self.draft_model.load_gen(
                reserve_per_device=self.autosplit_reserve,
                callback=progress_callback,
            ):
                if value:
                    yield value

        for value in self.model.load_gen(
            reserve_per_device=self.autosplit_reserve,
            use_per_device=self.gpu_split,
            callback=progress_callback,
        ):
            if value:
                yield value

    async def create_generator(self):
        """Create and save an ExllamaV3 generator instance."""

        try:
            # Don't acquire the lock unless a model is loaded
            if self.loaded:
                await self.load_lock.acquire()

                # Immediately cancel all jobs
                await self.wait_for_jobs(skip_wait=True)

            # Create a new generator
            self.generator = AsyncGenerator(
                model=self.model,
                cache=self.cache,
                draft_model=self.draft_model,
                draft_cache=self.draft_cache,
                tokenizer=self.tokenizer,
                max_batch_size=self.max_batch_size,
                max_chunk_size=self.chunk_size,
            )

            # Update the state of the container var
            if self.max_batch_size is None:
                self.max_batch_size = self.generator.generator.max_batch_size
        finally:
            # If the model is already loaded, the generator is being recreated,
            # so release the lock here. Otherwise, the load function owns and
            # releases the lock.
            if self.loaded:
                self.load_lock.release()

                async with self.load_condition:
                    self.load_condition.notify_all()

    async def unload(self, loras_only: bool = False, **kwargs):
        """
        Unloads the model and associated resources from memory.

        Args:
            loras_only: If True, only unload LoRAs.
            **kwargs: Additional unloading options (e.g., shutdown).
        """

        # Used when shutting down the server
        do_shutdown = kwargs.get("shutdown")

        try:
            if not do_shutdown:
                await self.load_lock.acquire()

                # Wait for other jobs to finish
                await self.wait_for_jobs(kwargs.get("skip_wait"))

            self.model.unload()
            self.model = None
            self.config = None
            self.cache = None
            self.tokenizer = None

            if self.use_draft_model:
                self.draft_model.unload()
                self.draft_model = None
                self.draft_config = None
                self.draft_cache = None

            # Clean up the generator from any pending jobs
            if self.generator is not None:
                await self.generator.close()
                self.generator = None

            gc.collect()
            torch.cuda.empty_cache()

            logger.info("Model unloaded.")
        finally:
            if not do_shutdown:
                self.load_lock.release()

                async with self.load_condition:
                    self.load_condition.notify_all()

    def encode_tokens(self, text: str, **kwargs) -> List[int]:
        """
        Encodes a string of text into a list of token IDs.

        Args:
            text: The input text string.
            **kwargs: Backend-specific encoding options (e.g., add_bos_token).

        Returns:
            A list of integer token IDs.
        """

        return (
            self.tokenizer.encode(
                text,
                add_bos=unwrap(
                    kwargs.get("add_bos_token"), self.hf_model.add_bos_token()
                ),
                encode_special_tokens=unwrap(kwargs.get("encode_special_tokens"), True),
            )
            .flatten()
            .tolist()
        )

    def decode_tokens(self, ids: List[int], **kwargs) -> str:
        """
        Decodes a list of token IDs back into a string.

        Args:
            ids: A list of integer token IDs.
            **kwargs: Backend-specific decoding options (e.g., decode_special_tokens).

        Returns:
            The decoded text string.
        """

        ids = torch.tensor([ids])
        return self.tokenizer.decode(
            ids,
            decode_special_tokens=unwrap(kwargs.get("decode_special_tokens"), True),
        )[0]
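
    # Round-trip sketch (illustrative only; exact IDs depend on the tokenizer):
    #   ids = container.encode_tokens("Hello", add_bos_token=False)
    #   text = container.decode_tokens(ids)  # -> "Hello"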

    def get_special_tokens(
        self, add_bos_token: bool = True, ban_eos_token: bool = False
    ):
        """
        Gets special tokens used by the model/tokenizer.

        Args:
            add_bos_token: If False, return an empty BOS token.
            ban_eos_token: If True, return an empty EOS token.

        Returns:
            A dictionary mapping special token names (e.g., 'bos_token', 'eos_token')
            to their string or ID representation.
        """

        return {
            "bos_token": self.tokenizer.bos_token if add_bos_token else "",
            "eos_token": self.tokenizer.eos_token if not ban_eos_token else "",
            "pad_token": self.tokenizer.pad_token,
            "unk_token": self.tokenizer.unk_token,
        }

    async def generate(
        self,
        request_id: str,
        prompt: str,
        params: BaseSamplerRequest,
        abort_event: Optional[asyncio.Event] = None,
        mm_embeddings: Optional[MultimodalEmbeddingWrapper] = None,
    ) -> Dict[str, Any]:
        """
        Generates a complete response for a given prompt and parameters.

        Args:
            request_id: Unique identifier for the generation request.
            prompt: The input prompt string.
            params: Sampling and generation parameters.
            abort_event: An asyncio Event to signal cancellation.
            mm_embeddings: Optional multimodal embeddings.

        Returns:
            A dictionary containing the generation info
        """

        generations = []
        async for generation in self.stream_generate(
            request_id,
            prompt,
            params,
            abort_event,
            mm_embeddings,
        ):
            generations.append(generation)

        joined_generation = {
            "text": "",
            "prompt_tokens": 0,
            "generated_tokens": 0,
            "tool_calls": None,
            "offset": [],
            "token_probs": {},
            "logprobs": [],
        }

        if generations:
            # Get finish_reason first and then shift where -1 points to
            if "finish_reason" in generations[-1]:
                finish_reason_gen = generations.pop()
                joined_generation["finish_reason"] = finish_reason_gen.get(
                    "finish_reason"
                )
                joined_generation["stop_str"] = finish_reason_gen.get("stop_str")
            else:
                joined_generation["finish_reason"] = "stop"

            if len(generations) > 0:
                for generation in generations:
                    joined_generation["text"] += unwrap(generation.get("text"), "")
                    joined_generation["offset"].append(
                        unwrap(generation.get("offset"), -1)
                    )
                    joined_generation["token_probs"].update(
                        unwrap(generation.get("token_probs"), {})
                    )

                    # Include empty logprob dicts for index preservation
                    joined_generation["logprobs"].append(
                        unwrap(generation.get("logprobs"), {})
                    )

                joined_generation["prompt_tokens"] = unwrap(
                    generations[-1].get("prompt_tokens"), 0
                )
                joined_generation["generated_tokens"] = unwrap(
                    generations[-1].get("generated_tokens"), 0
                )

        return joined_generation
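
    # Example joined result for a two-chunk stream "Hel" + "lo" (illustrative
    # values only):
    #   {"text": "Hello", "prompt_tokens": 5, "generated_tokens": 2,
    #    "offset": [3, 5], "finish_reason": "stop", ...}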

    async def stream_generate(
        self,
        request_id: str,
        prompt: str,
        params: BaseSamplerRequest,
        abort_event: Optional[asyncio.Event] = None,
        mm_embeddings: Optional[MultimodalEmbeddingWrapper] = None,
    ) -> AsyncIterator[Dict[str, Any]]:
        """
        Generates a response iteratively (streaming) for a given prompt.

        Args:
            request_id: Unique identifier for the generation request.
            prompt: The input prompt string.
            params: Sampling and generation parameters.
            abort_event: An asyncio Event to signal cancellation.
            mm_embeddings: Optional multimodal embeddings.

        Yields:
            Generation chunks
        """

        try:
            # Wait for the load lock to be freed before processing
            # Mainly used for loras and other operations where the class is available
            async with self.load_condition:
                await self.load_condition.wait_for(lambda: not self.load_lock.locked())

            # If the model is being unloaded, don't accept new requests
            if not self.loaded:
                raise RuntimeError(
                    "Model is being unloaded. Cannot process new generation requests."
                )

            # Mark that the job is running
            self.active_job_ids[request_id] = None

            # Yield from the internal generator
            async for generation_chunk in self.generate_gen(
                request_id=request_id,
                prompt=prompt,
                params=params,
                abort_event=abort_event,
                mm_embeddings=mm_embeddings,
            ):
                yield generation_chunk
        finally:
            # Clean up and remove the job from active IDs
            del self.active_job_ids[request_id]

    def handle_finish_chunk(self, result: dict, generation: dict):
        eos_reason = result.get("eos_reason")

        stop_str = None
        if eos_reason == "max_new_tokens":
            finish_reason = "length"
        else:
            finish_reason = "stop"

            # Grab the stop string if a stop condition was the reason
            if eos_reason == "stop_token":
                stop_str = result.get("eos_triggering_token_str")
            elif eos_reason == "stop_string":
                stop_str = result.get("eos_triggering_string")

        finish_chunk = {
            "prompt_tokens": generation.get("prompt_tokens"),
            "generated_tokens": generation.get("generated_tokens"),
            "finish_reason": finish_reason,
            "stop_str": stop_str,
        }

        return finish_chunk
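
    # Mapping summary: an eos_reason of "max_new_tokens" becomes finish_reason
    # "length"; all other reasons become "stop", with stop_str populated when a
    # stop token or stop string triggered the EOS.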

    async def generate_gen(
        self,
        request_id: str,
        prompt: str,
        params: BaseSamplerRequest,
        abort_event: Optional[asyncio.Event] = None,
        mm_embeddings: Optional[MultimodalEmbeddingWrapper] = None,
    ):
        """
        Create a generator function for prompt completion.

        For kwargs, check common/sampling.py
        """

        chunk_tokens: torch.Tensor | tuple[torch.Tensor, torch.Tensor]

        sampler_builder = ExllamaV3SamplerBuilder()

        # Penalties

        # Set the penalty range
        penalty_range = unwrap(params.penalty_range, self.max_seq_len)

        # Exl3's version of including the entire context
        if penalty_range < 0:
            penalty_range = int(10e7)

        # Validate the fallback for repetition decay: a negative range is
        # technically fine, but always default to 0 in that case
        if params.penalty_range < 0:
            fallback_decay = 0
        else:
            fallback_decay = params.penalty_range
        repetition_decay = coalesce(params.repetition_decay, fallback_decay, 0)

        # Apply penalties to the builder
        sampler_builder.penalties(
            params.repetition_penalty,
            params.frequency_penalty,
            params.presence_penalty,
            penalty_range,
            repetition_decay,
        )

        # Apply temperature first unless it should be applied last
        if not params.temperature_last:
            sampler_builder.temperature(params.temperature)

        # Apply alphabet samplers to the builder
        sampler_builder.top_k(params.top_k)
        sampler_builder.top_p(params.top_p)
        sampler_builder.min_p(params.min_p)

        # Apply temperature last if requested
        if params.temperature_last:
            sampler_builder.temperature(params.temperature)

        # Build the sampler
        # Use greedy sampling if temperature is 0
        sampler = sampler_builder.build(params.temperature == 0)

        # Dynamically scale penalty range to output tokens
        # Only do this if freq/pres pen is enabled
        # and the repetition range is -1
        # TODO: This currently does not work in exl3
        # auto_scale_penalty_range = (
        #     gen_settings.token_frequency_penalty != 0
        #     or gen_settings.token_presence_penalty != 0
        # ) and gen_settings.token_repetition_range == -1

        prompts = [prompt]
        stop_conditions = params.stop
        add_bos_token = unwrap(params.add_bos_token, self.hf_model.add_bos_token())

        # Fetch EOS tokens from generation_config if they exist
        eos_tokens = self.hf_model.eos_tokens() or [self.tokenizer.eos_token_id]
        stop_conditions += eos_tokens

        input_ids = [
            self.tokenizer.encode(
                prompt,
                add_bos=add_bos_token,
                encode_special_tokens=True,
            )
            for prompt in prompts
        ]

        # The first index will always be the positive prompt
        context_len = input_ids[0].size(dim=-1)

        # Automatically set max_tokens to fill up the context
        # This should be an OK default, but may be changed in the future
        max_tokens = unwrap(
            params.max_tokens,
            self.max_seq_len - context_len,
        )
        if max_tokens < 1:
            logger.warning("max_tokens must be a positive integer, setting to 1.")
            max_tokens = 1

        # Determine if the negative context or the context length is bigger
        context_to_check = context_len

        # Check the total prompt length against the max context length
        if context_to_check > self.max_seq_len:
            raise ValueError(
                f"Prompt length {context_to_check} is greater than "
                f"max_seq_len {self.max_seq_len}"
            )

        generation = {}
        job = AsyncJob(
            self.generator,
            sampler=sampler,
            # Reuse the prompt encoded above so the reported context length
            # matches the tokens actually fed to the job
            input_ids=input_ids[0],
            max_new_tokens=max_tokens,
            stop_conditions=stop_conditions,
            banned_strings=params.banned_strings,
        )

        generated_tokens = 0
        full_response = ""
        metrics_result = {}

        # Get the generation status once it's ready
        try:
            # Assign the active job to the request ID so it can be
            # cancelled while streaming
            self.active_job_ids[request_id] = job

            async for result in job:
                # Abort if the event is set while streaming
                if abort_event and abort_event.is_set():
                    await job.cancel()
                    break

                chunk = unwrap(result.get("text"), "")
                if chunk:
                    chunk_tokens = result.get(
                        "token_ids", self.tokenizer.encode(chunk)
                    )
                    full_response += chunk

                    if isinstance(chunk_tokens, torch.Tensor):
                        generated_tokens += chunk_tokens.size(dim=0)

                    # Increase penalty range to generated token amount
                    # TODO:
                    # if auto_scale_penalty_range:
                    #     gen_settings.token_repetition_range = generated_tokens

                    generation = {
                        "text": chunk,
                        "prompt_tokens": context_len,
                        "generated_tokens": generated_tokens,
                        "offset": len(full_response),
                    }

                    yield generation

                if result.get("eos"):
                    generation = self.handle_finish_chunk(result, generation)

                    # Save the final result for metrics logging
                    metrics_result = result

                    yield generation
                    break
        except asyncio.CancelledError:
            await job.cancel()
        except Exception as ex:
            # Create a new generator since the current state is broken
            # No need to wait for this to finish
            logger.error(
                "FATAL ERROR with generation. "
                "Attempting to recreate the generator. "
                "If this fails, please restart the server.\n"
            )
            asyncio.ensure_future(self.create_generator())

            await HealthManager.add_unhealthy_event(ex)

            raise ex
        finally:
            # Log generation options to console
            # Some options are too large, so log the args instead
            log_generation_params(
                request_id=request_id,
                bos_token_id=self.tokenizer.bos_token_id,
                eos_token_id=eos_tokens,
                prompt=prompt,
                **params.model_dump(exclude={"prompt"}),
                # auto_scale_penalty_range=auto_scale_penalty_range,  # TODO
            )

            # Log the metrics if present
            if metrics_result:
                log_metrics(
                    request_id,
                    metrics_result.get("time_enqueued"),
                    metrics_result.get("prompt_tokens"),
                    metrics_result.get("cached_tokens"),
                    metrics_result.get("time_prefill"),
                    metrics_result.get("new_tokens"),
                    metrics_result.get("time_generate"),
                    context_len,
                    self.max_seq_len,
                )