"""The model container class for ExLlamaV2 models."""
|
|
|
|
import asyncio
|
|
import gc
|
|
import math
|
|
import pathlib
|
|
import torch
|
|
from exllamav2 import (
|
|
ExLlamaV2,
|
|
ExLlamaV2Config,
|
|
ExLlamaV2CacheBase,
|
|
ExLlamaV2Cache,
|
|
ExLlamaV2Cache_Q4,
|
|
ExLlamaV2Cache_Q6,
|
|
ExLlamaV2Cache_Q8,
|
|
ExLlamaV2Cache_TP,
|
|
ExLlamaV2Tokenizer,
|
|
ExLlamaV2Lora,
|
|
ExLlamaV2VisionTower,
|
|
)
|
|
from exllamav2.generator import (
|
|
ExLlamaV2Sampler,
|
|
ExLlamaV2DynamicGeneratorAsync,
|
|
ExLlamaV2DynamicJobAsync,
|
|
)
|
|
from itertools import zip_longest
|
|
from loguru import logger
|
|
from typing import Dict, List, Optional
|
|
|
|
from backends.base_model_container import BaseModelContainer
|
|
from backends.exllamav2.grammar import (
|
|
ExLlamaV2Grammar,
|
|
clear_grammar_func_cache,
|
|
)
|
|
from backends.exllamav2.utils import exllama_disabled_flash_attn
|
|
from backends.exllamav2.vision import clear_image_embedding_cache
|
|
from common.concurrency import iterate_in_threadpool
|
|
from common.gen_logging import (
|
|
log_generation_params,
|
|
log_metrics,
|
|
log_prompt,
|
|
log_response,
|
|
)
|
|
from common.hardware import hardware_supports_flash_attn
|
|
from common.health import HealthManager
|
|
from common.multimodal import MultimodalEmbeddingWrapper
|
|
from common.optional_dependencies import check_package_version
|
|
from common.sampling import BaseSamplerRequest
|
|
from common.templating import PromptTemplate, find_prompt_template
|
|
from common.transformers_utils import HFModel
|
|
from common.utils import calculate_rope_alpha, coalesce, unwrap
|
|
from endpoints.core.types.model import ModelCard, ModelCardParameters
|
|
|
|
|
|
class ExllamaV2Container(BaseModelContainer):
|
|
"""The model container class for ExLlamaV2 models."""
|
|
|
|
# Model directories
|
|
model_dir: pathlib.Path = pathlib.Path("models")
|
|
draft_model_dir: pathlib.Path = pathlib.Path("models")
|
|
prompt_template: Optional[PromptTemplate] = None
|
|
|
|
# HF model instance
|
|
hf_model: HFModel
|
|
|
|
# Exl2 vars
|
|
config: Optional[ExLlamaV2Config] = None
|
|
model: Optional[ExLlamaV2] = None
|
|
cache: Optional[ExLlamaV2Cache] = None
|
|
tokenizer: Optional[ExLlamaV2Tokenizer] = None
|
|
generator: Optional[ExLlamaV2DynamicGeneratorAsync] = None
|
|
prompt_template: Optional[PromptTemplate] = None
|
|
paged: bool = True
|
|
|
|
# Draft model vars
|
|
use_draft_model: bool = False
|
|
draft_config: Optional[ExLlamaV2Config] = None
|
|
draft_model: Optional[ExLlamaV2] = None
|
|
draft_cache: Optional[ExLlamaV2Cache] = None
|
|
|
|
# Internal config vars
|
|
cache_size: int = None
|
|
cache_mode: str = "FP16"
|
|
draft_cache_mode: str = "FP16"
|
|
max_batch_size: Optional[int] = None
|
|
|
|
# GPU split vars
|
|
gpu_split: List[float] = []
|
|
draft_gpu_split: List[float] = []
|
|
gpu_split_auto: bool = True
|
|
autosplit_reserve: List[float] = [96 * 1024**2]
|
|
use_tp: bool = False
|
|
|
|
# Vision vars
|
|
use_vision: bool = False
|
|
vision_model: Optional[ExLlamaV2VisionTower] = None
|
|
|
|
# Load synchronization
|
|
active_job_ids: Dict[str, Optional[ExLlamaV2DynamicJobAsync]] = {}
|
|
loaded: bool = False
|
|
load_lock: asyncio.Lock = asyncio.Lock()
|
|
load_condition: asyncio.Condition = asyncio.Condition()
|
|
|
|
@classmethod
|
|
    async def create(cls, model_directory: pathlib.Path, hf_model: HFModel, **kwargs):
        """
        Primary asynchronous initializer for model container.

        Kwargs are located in config_sample.yml
        """

        # Create a new instance as a "fake self"
        self = cls()

        # Make sure ExllamaV2 is up to date
        check_package_version("exllamav2", "0.3.1")

        # Initialize config
        self.config = ExLlamaV2Config()
        self.model_dir = model_directory
        self.config.model_dir = str(model_directory.resolve())
        self.hf_model = hf_model

        # Make the max seq len 4096 before preparing the config
        # This is a better default than 2048
        self.config.max_seq_len = 4096

        self.config.prepare()

        # Check if the model arch is compatible with various exl2 features
        self.config.arch_compat_overrides()

        # Set vision state and error if vision isn't supported on the current model
        self.use_vision = unwrap(kwargs.get("vision"), False)
        if self.use_vision and not self.config.vision_model_type:
            raise ValueError(
                "The provided model does not have vision capabilities that are "
                "supported by ExllamaV2. "
                "Please reload with vision disabled."
            )

        # Prepare the draft model config if necessary
        draft_args = unwrap(kwargs.get("draft_model"), {})
        draft_model_name = draft_args.get("draft_model_name")
        self.use_draft_model = draft_args and draft_model_name

        # Always disable draft if params are incorrectly configured
        if draft_args and draft_model_name is None:
            logger.warning(
                "Draft model is disabled because a model name "
                "wasn't provided. Please check your config.yml!"
            )
            self.use_draft_model = False

        if self.use_draft_model:
            self.draft_config = ExLlamaV2Config()
            draft_model_path = pathlib.Path(
                unwrap(draft_args.get("draft_model_dir"), "models")
            )
            draft_model_path = draft_model_path / draft_model_name

            self.draft_gpu_split = unwrap(draft_args.get("draft_gpu_split"), [])
            self.draft_model_dir = draft_model_path
            self.draft_config.model_dir = str(draft_model_path.resolve())
            self.draft_config.prepare()

        # MARK: User configuration

        # Get cache mode
        self.cache_mode = unwrap(kwargs.get("cache_mode"), "FP16")

        # Catch exllamav3 cache_mode
        if self.cache_mode != "FP16" and not self.cache_mode.startswith("Q"):
            logger.warning(
                f"Provided cache mode '{self.cache_mode}' is not a "
                "valid choice for exllamav2, please check your settings. "
                "Defaulting to FP16."
            )
            self.cache_mode = "FP16"

        # Turn off GPU split if the user is using 1 GPU
        gpu_count = torch.cuda.device_count()
        gpu_split_auto = unwrap(kwargs.get("gpu_split_auto"), True)
        use_tp = unwrap(kwargs.get("tensor_parallel"), False)
        gpu_split = unwrap(kwargs.get("gpu_split"), [])
        gpu_device_list = list(range(0, gpu_count))

        # Set GPU split options
        if gpu_count == 1:
            self.gpu_split_auto = False
            logger.info("Disabling GPU split because one GPU is in use.")
        else:
            # Set tensor parallel
            if use_tp:
                self.use_tp = True

                # TP has its own autosplit loader
                self.gpu_split_auto = False

            # Enable manual GPU split if provided
            if gpu_split:
                self.gpu_split_auto = False
                self.gpu_split = gpu_split

                gpu_device_list = [
                    device_idx
                    for device_idx, memory in enumerate(self.gpu_split)
                    if memory > 0
                ]
            elif gpu_split_auto and not self.use_tp:
                # Otherwise fallback to autosplit settings
                self.gpu_split_auto = gpu_split_auto

                autosplit_reserve_megabytes = unwrap(
                    kwargs.get("autosplit_reserve"), [96]
                )

                # Reserve VRAM for each GPU
                self.autosplit_reserve = [
                    int(math.ceil(value * 1024**2))
                    for value in autosplit_reserve_megabytes
                ]

        # Change the GPU device list only if gpu_split's list is too small
        # This allows for an uneven list specification
        if self.draft_gpu_split and len(self.draft_gpu_split) > len(self.gpu_split):
            gpu_device_list = [
                device_idx
                for device_idx, memory in enumerate(self.draft_gpu_split)
                if memory > 0
            ]

        # Hardcode max output length to 16
        self.config.max_output_len = 16

        # Grab the base model's sequence length before overrides for
        # rope calculations
        base_seq_len = hf_model.hf_config.max_position_embeddings

        # Set the target seq len if present
        target_seq_len = unwrap(kwargs.get("max_seq_len"), 4096)

        # Set the rope scale
        self.config.scale_pos_emb = unwrap(
            kwargs.get("rope_scale"), self.config.scale_pos_emb
        )

        # Sets rope alpha value.
        # Utilize the model's max_position_embeddings as a base value
        # Automatically calculate if unset or defined as an "auto" literal.
        rope_alpha = unwrap(kwargs.get("rope_alpha"), "auto")
        if rope_alpha == "auto":
            self.config.scale_alpha_value = calculate_rope_alpha(
                base_seq_len, target_seq_len
            )
        else:
            self.config.scale_alpha_value = rope_alpha

        # Set the max seq len if specified
        if target_seq_len:
            self.config.max_seq_len = target_seq_len

        # Set max batch size to the config override
        self.max_batch_size = unwrap(kwargs.get("max_batch_size"))

        # Check whether the user's configuration supports flash/paged attention
        # Also check if exl2 has disabled flash attention
        if exllama_disabled_flash_attn(
            self.config.no_flash_attn
        ) or not hardware_supports_flash_attn(gpu_device_list):
            gpu_unsupported_message = (
                "An unsupported GPU is found in this configuration. "
                "Switching to compatibility mode. \n"
                "This disables parallel batching "
                "and features that rely on it (ex. CFG). \n"
                "To disable compatibility mode, all GPUs must be ampere "
                "(30 series) or newer. AMD GPUs are not supported."
            )

            logger.warning(gpu_unsupported_message)

            self.config.no_flash_attn = True
            if self.draft_config:
                self.draft_config.no_flash_attn = True
            self.paged = False
            self.max_batch_size = 1
            torch.backends.cuda.enable_flash_sdp(False)

        # Set k/v cache size
        # cache_size is only relevant when paged mode is enabled
        if self.paged:
            cache_size = unwrap(kwargs.get("cache_size"), self.config.max_seq_len)

            if cache_size < self.config.max_seq_len:
                logger.warning(
                    f"The given cache_size ({cache_size}) is smaller than the "
                    "desired context length.\n"
                    "Overriding cache_size to max_seq_len. "
                )

                cache_size = self.config.max_seq_len

            # Enforce a multiple of 256 for cache size
            # Overestimate to ensure that the cache isn't below max_seq_len
            cache_remainder = cache_size % 256
            if cache_remainder != 0:
                rounded_cache_size = int(
                    256 * ((cache_size - cache_remainder) / 256 + 1)
                )

                logger.warning(
                    f"The given cache size ({cache_size}) is "
                    "not a multiple of 256.\n"
                    "Overriding cache_size with an overestimated value of "
                    f"{rounded_cache_size} tokens."
                )

                cache_size = rounded_cache_size
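                # Worked example (hypothetical numbers): a requested cache_size
                # of 10000 has remainder 10000 % 256 == 16, so it is rounded up
                # to 256 * (39 + 1) = 10240 tokens.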

            # Warn user if cache size may be inadequate for CFG
            if cache_size < 2 * self.config.max_seq_len:
                logger.warning(
                    f"The given cache_size ({cache_size}) is less than 2 * max_seq_len "
                    "and may be too small for requests using CFG. \n"
                    "Ignore this warning if you do not plan on using CFG."
                )

            self.cache_size = cache_size
        else:
            self.cache_size = self.config.max_seq_len

        # Try to set prompt template
        self.prompt_template = await find_prompt_template(
            kwargs.get("prompt_template"), model_directory
        )

        # Catch all for template lookup errors
        if self.prompt_template:
            logger.info(
                f'Using template "{self.prompt_template.name}" for chat completions.'
            )
        else:
            logger.warning(
                "Chat completions are disabled because a prompt "
                "template wasn't provided or auto-detected."
            )

        # Make sure chunk size is >= 256, keep near or below max seq len
        user_chunk_size = unwrap(kwargs.get("chunk_size"), 2048)
        chunk_size = sorted((256, user_chunk_size, self.config.max_seq_len))[1]
        chunk_remainder = chunk_size % 256
        if chunk_remainder != 0:
            rounded_chunk_size = int(256 * ((chunk_size - chunk_remainder) / 256 + 1))

            logger.warning(
                f"The given chunk size ({chunk_size}) is "
                "not a multiple of 256.\n"
                "Overriding chunk_size with an overestimated value of "
                f"{rounded_chunk_size} tokens."
            )
            chunk_size = rounded_chunk_size
        self.config.max_input_len = chunk_size
        self.config.max_attention_size = chunk_size**2
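        # Rough scale (hypothetical numbers): the default 2048-token chunk caps
        # the attention matrix at 2048**2 = 4,194,304 positions per prefill chunk.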

        # Set user-configured draft model values
        if self.use_draft_model:
            self.draft_config.max_seq_len = self.config.max_seq_len

            self.draft_config.scale_pos_emb = unwrap(
                draft_args.get("draft_rope_scale"), 1.0
            )

            # Set draft rope alpha. Follows same behavior as model rope alpha.
            # Use the max_position_embeddings of the model
            draft_rope_alpha = unwrap(draft_args.get("draft_rope_alpha"), "auto")
            if draft_rope_alpha == "auto":
                self.draft_config.scale_alpha_value = calculate_rope_alpha(
                    base_seq_len, self.draft_config.max_seq_len
                )
            else:
                self.draft_config.scale_alpha_value = draft_rope_alpha

            # Set draft cache mode
            self.draft_cache_mode = unwrap(draft_args.get("draft_cache_mode"), "FP16")

            # Catch exllamav3 draft_cache_mode
            if self.draft_cache_mode != "FP16" and not self.draft_cache_mode.startswith(
                "Q"
            ):
                logger.warning(
                    f"Provided draft cache mode '{self.draft_cache_mode}' is not a "
                    "valid choice for exllamav2, please check your settings. "
                    "Defaulting to FP16."
                )
                self.draft_cache_mode = "FP16"

            # Edit the draft config size
            if chunk_size:
                self.draft_config.max_input_len = chunk_size
                self.draft_config.max_attention_size = chunk_size**2

        # Return the created instance
        return self

    def model_info(self):
        draft_model_card: ModelCard = None
        if self.draft_config:
            draft_model_params = ModelCardParameters(
                max_seq_len=self.draft_config.max_seq_len,
                rope_scale=self.draft_config.scale_pos_emb,
                rope_alpha=self.draft_config.scale_alpha_value,
                cache_mode=self.draft_cache_mode,
            )

            draft_model_card = ModelCard(
                id=self.draft_model_dir.name,
                parameters=draft_model_params,
            )

        model_params = ModelCardParameters(
            max_seq_len=self.config.max_seq_len,
            cache_size=self.cache_size,
            rope_scale=self.config.scale_pos_emb,
            rope_alpha=self.config.scale_alpha_value,
            max_batch_size=self.max_batch_size,
            cache_mode=self.cache_mode,
            chunk_size=self.config.max_input_len,
            use_vision=self.use_vision,
            draft=draft_model_card,
        )

        if self.prompt_template:
            model_params.prompt_template = self.prompt_template.name
            model_params.prompt_template_content = self.prompt_template.raw_template

        model_card = ModelCard(
            id=self.model_dir.name,
            parameters=model_params,
        )

        return model_card

    async def wait_for_jobs(self, skip_wait: bool = False):
        """Polling mechanism to wait for pending generation jobs."""

        if not self.generator:
            return

        # Immediately abort all jobs if asked
        if skip_wait:
            logger.warning(
                "Immediately terminating all jobs. "
                "Clients will have their requests cancelled.\n"
            )

            for job in self.active_job_ids.values():
                if job:
                    await job.cancel()

        while len(self.active_job_ids) > 0:
            await asyncio.sleep(0.01)

    async def load(self, progress_callback=None):
        """
        Load model

        Args:
            progress_callback (function, optional): A function to call for each
                module loaded.

        Prototype:
            def progress(loaded_modules: int, total_modules: int)
        """

        async for _ in self.load_gen(progress_callback):
            pass

    async def load_gen(self, progress_callback=None, **kwargs):
        """Loads a model and streams progress via a generator."""

        # Indicate that model load has started
        # Do this operation under the load lock's context
        try:
            await self.load_lock.acquire()

            # Wait for existing generation jobs to finish
            await self.wait_for_jobs(kwargs.get("skip_wait"))

            # Streaming gen for model load progress
            model_load_generator = self.load_model_sync(progress_callback)
            async for value in iterate_in_threadpool(model_load_generator):
                yield value

            # Create async generator
            await self.create_generator()

            # Clean up any extra vram usage from torch and cuda
            # (Helps reduce VRAM bottlenecking on Windows)
            gc.collect()
            torch.cuda.empty_cache()

            # Cleanup and update model load state
            self.loaded = True
            logger.info("Model successfully loaded.")
        finally:
            self.load_lock.release()

            async with self.load_condition:
                self.load_condition.notify_all()

    @torch.inference_mode()
    def load_model_sync(self, progress_callback=None):
        """
        Synchronous generator for loading.

        Args:
            progress_callback (function, optional): A function to call for each
                module loaded.

        Prototype:
            def progress(loaded_modules: int, total_modules: int)

        Runs under a shared inference mode context.
        """

        # Reset tokenizer namespace vars and create a tokenizer
        ExLlamaV2Tokenizer.unspecial_piece_to_id = {}
        ExLlamaV2Tokenizer.unspecial_id_to_piece = {}
        ExLlamaV2Tokenizer.extended_id_to_piece = {}
        ExLlamaV2Tokenizer.extended_piece_to_id = {}

        self.tokenizer = ExLlamaV2Tokenizer(self.config)

        # Calculate autosplit reserve for all GPUs
        gpu_count = torch.cuda.device_count()
        autosplit_reserve = self.autosplit_reserve + [0] * (
            gpu_count - len(self.autosplit_reserve)
        )

        # Load draft model if a config is present
        if self.draft_config:
            self.draft_model = ExLlamaV2(self.draft_config)
            logger.info("Loading draft model: " + self.draft_config.model_dir)

            # Draft uses the autosplit loader, so create a cache that reflects this
            draft_cache_class = self.get_cache_class(self.draft_cache_mode)

            if self.draft_gpu_split:
                logger.info("Loading with a manual GPU split (or a one GPU setup)")

                for value in self.draft_model.load_gen(
                    self.draft_gpu_split,
                    callback_gen=progress_callback,
                ):
                    if value:
                        yield value

                self.draft_cache = self.create_cache(
                    cache_class=draft_cache_class,
                    autosplit=False,
                    use_tp=False,
                    model=self.draft_model,
                )
            else:
                logger.info("Loading with autosplit")

                self.draft_cache = self.create_cache(
                    cache_class=draft_cache_class,
                    autosplit=True,
                    use_tp=False,
                    model=self.draft_model,
                )

                for value in self.draft_model.load_autosplit_gen(
                    self.draft_cache,
                    reserve_vram=autosplit_reserve,
                    last_id_only=True,
                    callback_gen=progress_callback,
                ):
                    if value:
                        yield value

            # Test VRAM allocation with a full-length forward pass
            input_ids = torch.zeros((1, self.config.max_input_len), dtype=torch.long)
            self.draft_model.forward(
                input_ids, cache=self.draft_cache, preprocess_only=True
            )

        # Load vision tower if it exists
        if self.use_vision:
            self.vision_model = ExLlamaV2VisionTower(self.config)

            for value in self.vision_model.load_gen(callback_gen=progress_callback):
                if value:
                    yield value

        self.model = ExLlamaV2(self.config)
        logger.info("Loading model: " + self.config.model_dir)

        # Get class of the model cache
        cache_class = self.get_cache_class(self.cache_mode)

        # Load model with manual split
        # Entrypoint for single GPU users
        if self.use_tp:
            logger.info("Loading with tensor parallel")

            # GPU split must be None if the array is empty
            # Otherwise the TP loader fails
            for value in self.model.load_tp_gen(
                self.gpu_split or None,
                callback_gen=progress_callback,
                expect_cache_base=cache_class,
                expect_cache_tokens=self.cache_size,
            ):
                if value:
                    yield value
        elif not self.gpu_split_auto:
            logger.info("Loading with a manual GPU split (or a one GPU setup)")

            for value in self.model.load_gen(
                self.gpu_split,
                callback_gen=progress_callback,
            ):
                if value:
                    yield value

        # Create the model cache
        self.cache = self.create_cache(
            cache_class=cache_class,
            autosplit=self.gpu_split_auto,
            use_tp=self.use_tp,
            model=self.model,
        )

        # Load model with autosplit (without TP)
        if self.gpu_split_auto and not self.use_tp:
            logger.info("Loading with autosplit")

            for value in self.model.load_autosplit_gen(
                self.cache,
                reserve_vram=autosplit_reserve,
                last_id_only=True,
                callback_gen=progress_callback,
            ):
                if value:
                    yield value

        # Test VRAM allocation with a full-length forward pass
        input_ids = torch.zeros((1, self.config.max_input_len), dtype=torch.long)
        self.model.forward(input_ids, cache=self.cache, preprocess_only=True)

    # TODO: Maybe make a wrapper class with an ID instead of a utility function
    def get_cache_class(self, cache_mode: str):
        """Utility function to get a cache class based on user preference."""

        match cache_mode:
            case "Q4":
                return ExLlamaV2Cache_Q4
            case "Q6":
                return ExLlamaV2Cache_Q6
            case "Q8":
                return ExLlamaV2Cache_Q8
            case _:
                return ExLlamaV2Cache

    def create_cache(
        self,
        cache_class: ExLlamaV2CacheBase,
        autosplit: bool,
        use_tp: bool,
        model: ExLlamaV2,
    ):
        """Utility function to create a model cache."""

        if use_tp:
            return ExLlamaV2Cache_TP(
                model,
                base=cache_class,
                max_seq_len=self.cache_size,
                batch_size=1,
            )
        else:
            return cache_class(
                model,
                max_seq_len=self.cache_size,
                lazy=autosplit,
                batch_size=1,
            )

    async def create_generator(self):
        """Create and save an ExLlama generator class."""

        try:
            # Don't acquire locks unless a model is loaded
            if self.loaded:
                await self.load_lock.acquire()

                # Immediately cancel all jobs
                await self.wait_for_jobs(skip_wait=True)

            # Create new generator
            self.generator = ExLlamaV2DynamicGeneratorAsync(
                model=self.model,
                cache=self.cache,
                draft_model=self.draft_model,
                draft_cache=self.draft_cache,
                tokenizer=self.tokenizer,
                max_batch_size=self.max_batch_size,
                paged=self.paged,
            )

            # Update the state of the container var
            if self.max_batch_size is None:
                self.max_batch_size = self.generator.generator.max_batch_size
        finally:
            # This means the generator is being recreated
            # The load lock is already released in the load function
            if self.loaded:
                self.load_lock.release()

                async with self.load_condition:
                    self.load_condition.notify_all()

    def get_loras(self):
        """Convenience function to get all loras."""

        return unwrap(self.generator.generator.current_loras, [])

    async def load_loras(self, lora_directory: pathlib.Path, **kwargs):
        """Load loras."""

        loras = unwrap(kwargs.get("loras"), [])

        try:
            await self.load_lock.acquire()

            # Wait for existing generation jobs to finish
            await self.wait_for_jobs(kwargs.get("skip_wait"))

            loras_to_load: List[ExLlamaV2Lora] = []
            success: List[str] = []
            failure: List[str] = []

            for lora in loras:
                lora_name = lora.get("name")
                lora_scaling = unwrap(lora.get("scaling"), 1.0)

                if lora_name is None:
                    logger.warning(
                        "One of your loras does not have a name. Please check your "
                        "config.yml! Skipping lora load."
                    )
                    failure.append(lora_name)
                    continue

                logger.info(f"Adding lora: {lora_name} at scaling {lora_scaling}")
                lora_path = lora_directory / lora_name

                loras_to_load.append(
                    ExLlamaV2Lora.from_directory(self.model, lora_path, lora_scaling)
                )
                logger.info(f"Lora successfully added: {lora_name}")
                success.append(lora_name)

            self.generator.generator.set_loras(loras_to_load)
            logger.info("All loras successfully loaded")

            # Return success and failure names
            return {"success": success, "failure": failure}
        finally:
            self.load_lock.release()

            async with self.load_condition:
                self.load_condition.notify_all()

    async def unload(self, loras_only: bool = False, **kwargs):
        """Free all VRAM resources used by the model (and loras)."""

        # Shutdown immediately unloads and bypasses all locks
        do_shutdown = kwargs.get("shutdown")

        try:
            if not do_shutdown:
                await self.load_lock.acquire()

                # Wait for other jobs to finish
                await self.wait_for_jobs(kwargs.get("skip_wait"))

            # Delete references held in the grammar module
            clear_grammar_func_cache()

            # Clear the image embedding cache
            clear_image_embedding_cache()

            # Unload LoRAs
            if self.generator and self.generator.generator.current_loras:
                for lora in self.generator.generator.current_loras:
                    lora.unload()

                self.generator.generator.set_loras([])

            # Unload the entire model if not just unloading loras
            if not loras_only:
                if self.model:
                    self.model.unload()
                self.model = None

                if self.vision_model:
                    self.vision_model.unload()

                self.vision_model = None

                if self.draft_model:
                    self.draft_model.unload()
                self.draft_model = None

                self.config = None
                self.cache = None
                self.tokenizer = None

                # Cleanup the generator from any pending jobs
                if self.generator is not None:
                    await self.generator.close()
                self.generator = None

                # Set all model state variables to False
                self.loaded = False

            gc.collect()
            torch.cuda.empty_cache()

            logger.info("Loras unloaded." if loras_only else "Model unloaded.")
        finally:
            if not do_shutdown:
                self.load_lock.release()

                async with self.load_condition:
                    self.load_condition.notify_all()

    def encode_tokens(self, text: str, **kwargs):
        """Wrapper to encode tokens from a text string."""

        mm_embeddings: MultimodalEmbeddingWrapper = kwargs.get("embeddings")
        mm_embeddings_content = mm_embeddings.content if mm_embeddings else []

        return (
            self.tokenizer.encode(
                text,
                add_bos=unwrap(
                    kwargs.get("add_bos_token"), self.hf_model.add_bos_token()
                ),
                encode_special_tokens=unwrap(kwargs.get("encode_special_tokens"), True),
                embeddings=mm_embeddings_content,
            )
            .flatten()
            .tolist()
        )

    def decode_tokens(self, ids: List[int], **kwargs):
        """Wrapper to decode tokens from a list of IDs"""

        ids = torch.tensor([ids])
        return self.tokenizer.decode(
            ids,
            decode_special_tokens=unwrap(kwargs.get("decode_special_tokens"), True),
        )[0]
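
    # Example (hypothetical usage, assumes a loaded container instance):
    #   ids = container.encode_tokens("Hello world", add_bos_token=True)
    #   text = container.decode_tokens(ids)
    # encode_tokens returns a flat list of token IDs and decode_tokens turns a
    # list of IDs back into a string.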

    def get_special_tokens(self):
        return {
            "bos_token": self.tokenizer.bos_token,
            "eos_token": self.tokenizer.eos_token,
            "pad_token": self.tokenizer.pad_token,
            "unk_token": self.tokenizer.unk_token,
        }

    def get_logprobs(self, token_ids: torch.Tensor, token_probs: torch.Tensor):
        top_tokens = [
            self.tokenizer.extended_id_to_piece.get(
                index, self.tokenizer.get_id_to_piece_list(True)[index]
            )
            for index in token_ids.flatten().tolist()
        ]

        top_values = torch.log(token_probs).flatten().tolist()

        # Cannot return -inf in JSON
        cleaned_values = [
            -1000 if value == float("-inf") else value for value in top_values
        ]
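        # The result maps token piece -> logprob, e.g. (hypothetical values)
        # {"Hello": -0.01, " there": -3.2}; zip_longest pads with None if the
        # two lists ever differ in length.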
        return dict(zip_longest(top_tokens, cleaned_values))

    async def generate(
        self,
        request_id: str,
        prompt: str,
        params: BaseSamplerRequest,
        abort_event: Optional[asyncio.Event] = None,
        mm_embeddings: Optional[MultimodalEmbeddingWrapper] = None,
    ):
        """Generate a response to a prompt."""
        generations = []
        async for generation in self.stream_generate(
            request_id,
            prompt,
            params,
            abort_event,
            mm_embeddings,
        ):
            generations.append(generation)

        joined_generation = {
            "text": "",
            "prompt_tokens": 0,
            "gen_tokens": 0,
            "tool_calls": None,
            "offset": [],
            "token_probs": {},
            "logprobs": [],
        }

        if generations:
            # Get finish_reason first and then shift where -1 points to
            if "finish_reason" in generations[-1]:
                finish_chunk = generations.pop()
                joined_generation = {**joined_generation, **finish_chunk}
            else:
                joined_generation["finish_reason"] = "stop"

        if len(generations) > 0:
            for generation in generations:
                joined_generation["text"] += unwrap(generation.get("text"), "")
                joined_generation["offset"].append(unwrap(generation.get("offset"), -1))
                joined_generation["token_probs"].update(
                    unwrap(generation.get("token_probs"), {})
                )

                # Include empty logprob dicts for index preservation
                joined_generation["logprobs"].append(
                    unwrap(generation.get("logprobs"), {})
                )

            joined_generation["prompt_tokens"] = unwrap(
                generations[-1].get("prompt_tokens"), 0
            )
            joined_generation["generated_tokens"] = unwrap(
                generations[-1].get("generated_tokens"), 0
            )

        return joined_generation

    async def stream_generate(
        self,
        request_id: str,
        prompt: str,
        params: BaseSamplerRequest,
        abort_event: Optional[asyncio.Event] = None,
        mm_embeddings: Optional[MultimodalEmbeddingWrapper] = None,
    ):
        try:
            # Wait for load lock to be freed before processing
            # Mainly used for loras and other operations where the class is available
            async with self.load_condition:
                await self.load_condition.wait_for(lambda: not self.load_lock.locked())

            # If the model is being unloaded, don't accept new requests
            if not self.loaded:
                raise RuntimeError(
                    "Model is being unloaded. Cannot process new generation requests."
                )

            # Mark that the job is running
            self.active_job_ids[request_id] = None

            # Yield from the internal generator
            async for generation_chunk in self.generate_gen(
                request_id=request_id,
                prompt=prompt,
                params=params,
                abort_event=abort_event,
                mm_embeddings=mm_embeddings,
            ):
                yield generation_chunk
        finally:
            # Clean up and remove the job from active IDs
            # Use pop so an early exit (e.g. the unloaded-model error above)
            # doesn't raise a KeyError before the job was ever registered
            self.active_job_ids.pop(request_id, None)

    def check_unsupported_settings(self, params: BaseSamplerRequest):
        """
        Check and warn the user if a sampler is unsupported.

        Meant for dev wheels!
        """

        return params

    def assign_gen_params(
        self,
        params: BaseSamplerRequest,
        gen_settings: ExLlamaV2Sampler.Settings,
        grammar_handler: ExLlamaV2Grammar,
    ):
        # Apply settings
        gen_settings.temperature = params.temperature
        gen_settings.temperature_last = params.temperature_last
        gen_settings.smoothing_factor = params.smoothing_factor
        gen_settings.top_k = params.top_k
        gen_settings.top_p = params.top_p
        gen_settings.top_a = params.top_a
        gen_settings.min_p = params.min_p
        gen_settings.tfs = params.tfs
        gen_settings.typical = params.typical
        gen_settings.mirostat = params.mirostat_mode == 2
        gen_settings.skew = params.skew

        # XTC
        if params.xtc_probability > 0.0:
            gen_settings.xtc_probability = params.xtc_probability

            # 0.1 is the default for this value
            gen_settings.xtc_threshold = params.xtc_threshold

        # DynaTemp settings
        max_temp = params.max_temp
        min_temp = params.min_temp

        if params.max_temp > params.min_temp:
            gen_settings.max_temp = max_temp
            gen_settings.min_temp = min_temp
            gen_settings.temp_exponent = params.temp_exponent
        else:
            # Force to default values
            gen_settings.max_temp = 1.0
            gen_settings.min_temp = 1.0
            gen_settings.temp_exponent = 1.0

            # Warn if max/min temp values are > 0
            # and if they're less than or equal to each other
            if max_temp < min_temp or (
                1 not in {min_temp, max_temp} and max_temp == min_temp
            ):
                logger.warning(
                    "Max temp is less than or equal to min temp, skipping DynaTemp."
                )

        # Default tau and eta fallbacks don't matter if mirostat is off
        gen_settings.mirostat_tau = params.mirostat_tau
        gen_settings.mirostat_eta = params.mirostat_eta

        # Penalties
        gen_settings.token_repetition_penalty = params.repetition_penalty
        gen_settings.token_frequency_penalty = params.frequency_penalty
        gen_settings.token_presence_penalty = params.presence_penalty

        # Applies for all penalties despite being called token_repetition_range
        gen_settings.token_repetition_range = unwrap(
            params.penalty_range, self.config.max_seq_len
        )

        # Always make sure the fallback is 0 if range < 0
        # It's technically fine to use -1, but this just validates the passed
        # fallback
        # Always default to 0 if something goes wrong
        if gen_settings.token_repetition_range < 0:
            fallback_decay = 0
        else:
            fallback_decay = gen_settings.token_repetition_range
        gen_settings.token_repetition_decay = coalesce(
            params.repetition_decay, fallback_decay, 0
        )

        # DRY options
        dry_multiplier = params.dry_multiplier

        # < 0 = disabled
        if dry_multiplier > 0:
            gen_settings.dry_multiplier = dry_multiplier
            gen_settings.dry_allowed_length = params.dry_allowed_length
            gen_settings.dry_base = params.dry_base

            # Exl2 has dry_range as 0 for unlimited unlike -1 for penalty_range
            # Use max_seq_len as the fallback to stay consistent
            gen_settings.dry_range = unwrap(params.dry_range, self.config.max_seq_len)

            # Tokenize sequence breakers
            if params.dry_sequence_breakers:
                gen_settings.dry_sequence_breakers = {
                    self.encode_tokens(s)[-1] for s in params.dry_sequence_breakers
                }

        # Add JSON schema filter if it exists
        if params.json_schema:
            grammar_handler.add_json_schema_filter(
                params.json_schema, self.model, self.tokenizer
            )

        # Add regex filter if it exists
        if params.regex_pattern:
            grammar_handler.add_regex_filter(
                params.regex_pattern, self.model, self.tokenizer
            )

        # Add EBNF filter if it exists
        if params.grammar_string:
            grammar_handler.add_kbnf_filter(
                params.grammar_string, self.model, self.tokenizer
            )

        # Speculative Ngram
        self.generator.speculative_ngram = params.speculative_ngram

        # Override sampler settings for temp = 0
        if gen_settings.temperature == 0:
            gen_settings.temperature = 1.0
            gen_settings.top_k = 1
            gen_settings.top_p = 0
            gen_settings.typical = 0

            logger.warning(
                "Temperature is set to 0. Overriding temp, "
                "top_k, top_p, and typical to 1.0, 1, 0, and 0."
            )

        # Set banned tokens
        if params.banned_tokens:
            gen_settings.disallow_tokens(self.tokenizer, params.banned_tokens)

        # Set allowed tokens
        if params.allowed_tokens:
            gen_settings.allow_tokens(self.tokenizer, params.allowed_tokens)

        # Set logit bias
        if params.logit_bias:
            # Create a vocab tensor if it doesn't exist for token biasing
            if gen_settings.token_bias is None:
                padding = -self.tokenizer.config.vocab_size % 32
                gen_settings.token_bias = torch.zeros(
                    (self.tokenizer.config.vocab_size + padding,),
                    dtype=torch.float,
                )

            # Map logits to the tensor with their biases
            for token_id, bias in params.logit_bias.items():
                if 0 <= token_id < len(self.tokenizer.get_id_to_piece_list(True)):
                    gen_settings.token_bias[token_id] = bias
                else:
                    logger.warning(
                        f"Logit bias: Token {token_id} not present "
                        "in the model's vocab. Skipping."
                    )

    # Adds logprobs to a generation chunk
    def handle_logprobs(self, result: dict, generation: dict):
        top_tokens = unwrap(
            result.get("top_k_tokens"),
            torch.empty((1, 0, 1), dtype=torch.long),
        )

        top_probs = unwrap(
            result.get("top_k_probs"),
            torch.empty((1, 0, 1), dtype=torch.float),
        )

        if top_tokens.numel() > 0 and top_probs.numel() > 0:
            logprobs = self.get_logprobs(top_tokens, top_probs)
            generation["logprobs"] = logprobs

            # The first logprob is the selected token prob
            generation["token_probs"] = {
                token: logprobs[token] for token in list(logprobs.keys())[:1]
            }

    # Creates and returns a finish chunk
    def handle_finish_chunk(self, result: dict, generation: dict):
        eos_reason = result.get("eos_reason")

        stop_str = None
        if eos_reason == "max_new_tokens":
            finish_reason = "length"
        else:
            finish_reason = "stop"
            # Grab stop string if stop was the reason
            if eos_reason == "stop_token":
                stop_str = result.get("eos_triggering_token_str")
            elif eos_reason == "stop_string":
                stop_str = result.get("eos_triggering_string")

        # Prompt
        prompt_tokens = result.get("prompt_tokens")
        cached_tokens = round(result.get("cached_tokens"), 2)
        prompt_time = round(result.get("time_prefill"), 2)
        prompt_ts = (
            "Indeterminate"
            if prompt_time == 0
            else round((prompt_tokens - cached_tokens) / prompt_time, 2)
        )

        # Generated
        gen_tokens = result.get("new_tokens")
        gen_time = result.get("time_generate")
        gen_ts = "Indeterminate" if gen_time == 0 else round(gen_tokens / gen_time, 2)

        # Queue + Total
        queue_time = result.get("time_enqueued")
        total_time = round(queue_time + prompt_time + gen_time, 2)

        finish_chunk = {
            "prompt_tokens": prompt_tokens,
            "prompt_time": round(prompt_time, 2),
            "prompt_tokens_per_sec": prompt_ts,
            "gen_tokens": gen_tokens,
            "gen_time": round(gen_time, 2),
            "gen_tokens_per_sec": gen_ts,
            "total_time": total_time,
            "queue_time": round(queue_time, 2),
            "cached_tokens": cached_tokens,
            "finish_reason": finish_reason,
            "stop_str": stop_str,
        }

        return finish_chunk

    async def generate_gen(
        self,
        request_id: str,
        prompt: str,
        params: BaseSamplerRequest,
        abort_event: Optional[asyncio.Event] = None,
        mm_embeddings: Optional[MultimodalEmbeddingWrapper] = None,
    ):
        """
        Create generator function for prompt completion.

        for kwargs, check common/sampling.py
        """

        prompts = [prompt]
        gen_settings = ExLlamaV2Sampler.Settings()
        grammar_handler = ExLlamaV2Grammar()

        self.assign_gen_params(
            params,
            gen_settings,
            grammar_handler,
        )

        # Set banned strings
        banned_strings = params.banned_strings
        if banned_strings and len(grammar_handler.filters) > 0:
            logger.warning(
                "Disabling banned_strings because "
                "they cannot be used with grammar filters."
            )

            banned_strings = []

        # Set CFG scale and negative prompt
        cfg_scale = params.cfg_scale
        negative_prompt = None
        if cfg_scale not in [None, 1.0]:
            if self.paged:
                gen_settings.cfg_scale = cfg_scale

                # If the negative prompt is empty, use the BOS token
                negative_prompt = unwrap(
                    params.negative_prompt, self.tokenizer.bos_token
                )

                prompts.append(negative_prompt)
            else:
                logger.warning(
                    "CFG is currently disabled because paged mode is disabled. "
                    "Please use an ampere (30 series) or higher GPU for CFG support."
                )

        # Dynamically scale penalty range to output tokens
        # Only do this if freq/pres pen is enabled
        # and the repetition range is -1
        auto_scale_penalty_range = (
            gen_settings.token_frequency_penalty != 0
            or gen_settings.token_presence_penalty != 0
        ) and gen_settings.token_repetition_range == -1

        stop_conditions = params.stop
        ban_eos_token = params.ban_eos_token

        # Set add_bos_token for generation
        add_bos_token = unwrap(params.add_bos_token, self.hf_model.add_bos_token())

        # Fetch EOS tokens from the HF model if they exist
        eos_tokens = self.hf_model.eos_tokens() or [self.tokenizer.eos_token_id]

        # Ban the EOS token if specified. If not, append to stop conditions
        # as well.
        # Set this below logging to avoid polluting the stop strings array
        if ban_eos_token:
            gen_settings.disallow_tokens(self.tokenizer, eos_tokens)
        else:
            stop_conditions += eos_tokens

        # Get multimodal embeddings if present
        mm_embeddings_content = mm_embeddings.content if mm_embeddings else []

        # Encode both positive and negative prompts
        input_ids = [
            self.tokenizer.encode(
                prompt,
                add_bos=add_bos_token,
                encode_special_tokens=True,
                embeddings=mm_embeddings_content,
            )
            for prompt in prompts
        ]

        # The first index will always be the positive prompt
        context_len = input_ids[0].size(dim=-1)

        # The second index will be the negative prompt if CFG is enabled
        negative_context_len = input_ids[1].size(dim=-1) if negative_prompt else 0

        # Automatically set max_tokens to fill up the context
        # This should be an OK default, but may be changed in the future
        max_tokens = unwrap(
            params.max_tokens,
            self.config.max_seq_len - max(context_len, negative_context_len),
        )
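        # For example (hypothetical numbers): with max_seq_len = 4096 and a
        # 1000-token prompt, an unset max_tokens falls back to 4096 - 1000 = 3096.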
        if max_tokens < 1:
            logger.warning("max_tokens must be a positive integer, setting to 1.")
            max_tokens = 1

        # Determine if the negative context or the context length is bigger
        context_to_check = max(negative_context_len, context_len)

        # Check total length of prompt against max context length
        if context_to_check > self.config.max_seq_len:
            preamble = (
                "Negative prompt" if negative_context_len > context_len else "Prompt"
            )

            raise ValueError(
                f"{preamble} length {context_to_check} is greater than "
                f"max_seq_len {self.config.max_seq_len}"
            )

        # Check total required pages for CFG request to avoid overallocation
        if negative_prompt and (
            sum(
                256 * math.ceil((context + max_tokens) / 256)
                for context in (context_len, negative_context_len)
            )
            > self.cache_size
        ):
            raise ValueError(
                f"Total required page size for request "
                f"{context_len} + {negative_context_len} + {max_tokens} * 2 "
                f"is greater than cache_size {self.cache_size}"
            )

        # Log prompt to console. Add the BOS token if specified
        log_prompt(
            f"{self.tokenizer.bos_token if add_bos_token else ''}{prompt}",
            request_id,
            negative_prompt,
        )

        # Create and add a new job
        # Don't use the request ID here as there can be multiple jobs per request
        job = ExLlamaV2DynamicJobAsync(
            self.generator,
            input_ids=input_ids,
            max_new_tokens=max_tokens,
            min_new_tokens=params.min_tokens,
            gen_settings=gen_settings,
            stop_conditions=stop_conditions,
            decode_special_tokens=True,
            filters=grammar_handler.filters,
            filter_prefer_eos=bool(grammar_handler.filters),
            return_probs=params.logprobs > 0,
            return_top_tokens=params.logprobs,
            return_logits=params.logprobs > 0,
            banned_strings=banned_strings,
            token_healing=params.token_healing,
            identifier=request_id,
            embeddings=mm_embeddings_content,
        )

        # Assign the active job to the request ID
        self.active_job_ids[request_id] = job

        # Save generated tokens and full response
        # Copy over max seq len in case model is unloaded and stored jobs can complete
        # Full response is required for offset calculation
        max_seq_len = self.config.max_seq_len
        generated_tokens = 0
        full_response = ""
        metrics_result = {}

        # Get the generation status once it's ready
        try:
            async for result in job:
                # Abort if the event is set while streaming
                if abort_event and abort_event.is_set():
                    await job.cancel()
                    break

                stage = result.get("stage")
                result_id = result.get("identifier")

                if stage == "streaming" and result_id == request_id:
                    chunk = unwrap(result.get("text"), "")
                    full_response += chunk

                    chunk_tokens = result.get("token_ids")
                    if chunk_tokens is not None:
                        generated_tokens += chunk_tokens.size(dim=0)

                    generation = {
                        "text": chunk,
                        "prompt_tokens": context_len,
                        "generated_tokens": generated_tokens,
                        "offset": len(full_response),
                    }

                    # Increase penalty range to generated token amount
                    if auto_scale_penalty_range:
                        gen_settings.token_repetition_range = generated_tokens

                    # Handle logprobs
                    if params.logprobs > 0:
                        self.handle_logprobs(result, generation)

                    yield generation

                    # Yield a finish chunk when generation is finished
                    if result.get("eos"):
                        log_response(request_id, full_response)

                        finish_chunk = self.handle_finish_chunk(result, generation)

                        # Save the final result for metrics logging
                        metrics_result = finish_chunk

                        yield finish_chunk
                        break
        except asyncio.CancelledError:
            await job.cancel()
        except Exception as ex:
            # Create a new generator since the current state is broken
            # No need to wait for this to finish
            logger.error(
                "FATAL ERROR with generation. "
                "Attempting to recreate the generator. "
                "If this fails, please restart the server.\n"
            )
            asyncio.ensure_future(self.create_generator())

            await HealthManager.add_unhealthy_event(ex)

            raise ex
        finally:
            # Log generation options to console
            # Some options are too large, so log the args instead
            log_generation_params(
                request_id=request_id,
                bos_token_id=self.tokenizer.bos_token_id,
                eos_token_id=eos_tokens,
                prompt=prompt,
                **params.model_dump(exclude={"prompt"}),
                auto_scale_penalty_range=auto_scale_penalty_range,
            )

            # Log the metrics if present
            if metrics_result:
                log_metrics(
                    request_id,
                    metrics_result,
                    context_len,
                    max_seq_len,
                )