958 lines
32 KiB
Python
958 lines
32 KiB
Python
import asyncio
|
|
import gc
|
|
import pathlib
|
|
import re
|
|
import traceback
|
|
from typing import (
|
|
Any,
|
|
AsyncIterator,
|
|
Dict,
|
|
List,
|
|
Optional,
|
|
)
|
|
|
|
import torch
|
|
from exllamav3 import (
|
|
AsyncGenerator,
|
|
AsyncJob,
|
|
Cache,
|
|
Config,
|
|
Model,
|
|
Tokenizer,
|
|
)
|
|
from exllamav3.cache import CacheLayer_fp16, CacheLayer_quant
|
|
from loguru import logger
|
|
|
|
from backends.base_model_container import BaseModelContainer
|
|
from backends.exllamav3.sampler import ExllamaV3SamplerBuilder
|
|
from common.concurrency import iterate_in_threadpool
|
|
from common.gen_logging import (
|
|
log_generation_params,
|
|
log_metrics,
|
|
)
|
|
from common.hardware import hardware_supports_flash_attn
|
|
from common.health import HealthManager
|
|
from common.multimodal import MultimodalEmbeddingWrapper
|
|
from common.sampling import BaseSamplerRequest
|
|
from common.templating import PromptTemplate, find_prompt_template
|
|
from common.transformers_utils import GenerationConfig, TokenizerConfig
|
|
from common.utils import coalesce, unwrap
|
|
from endpoints.core.types.model import ModelCard, ModelCardParameters
|
|
|
|
|
|
class ExllamaV3Container(BaseModelContainer):
    """Model container implementation backed by the ExllamaV3 library."""

    # Exposed model information
    model_dir: pathlib.Path = pathlib.Path("models")
    prompt_template: Optional[PromptTemplate] = None
    generation_config: Optional[GenerationConfig] = None

    # Load synchronization
    # The bool is a master switch for accepting requests
    # The lock keeps load tasks sequential
    # The condition notifies any waiting tasks
    # NOTE(review): the dict, lock and condition below are created once at
    # class-definition time, so every instance shares them unless an instance
    # rebinds the attribute - confirm that sharing is intended.
    active_job_ids: Dict[str, Any] = {}
    loaded: bool = False
    load_lock: asyncio.Lock = asyncio.Lock()
    load_condition: asyncio.Condition = asyncio.Condition()

    # Exl3 vars
    # Annotation-only declarations (no class-level value); create() assigns
    # None to each of them on the new instance before filling them in.
    model: Optional[Model]
    cache: Optional[Cache]
    draft_model: Optional[Model]
    draft_cache: Optional[Cache]
    tokenizer: Optional[Tokenizer]
    config: Optional[Config]
    draft_config: Optional[Config]
    generator: Optional[AsyncGenerator]
    tokenizer_config: Optional[TokenizerConfig]

    # Class-specific vars
    # Defaults only; create() overrides most of these from its kwargs.
    gpu_split: List[float] | None = None
    gpu_split_auto: bool = True
    # create() stores megabyte values divided by 1024 here, so this is
    # presumably in gigabytes - TODO confirm against exllamav3's load_gen.
    autosplit_reserve: List[float] = [96 / 1024]
    use_tp: bool = False
    max_seq_len: int = 4096
    cache_size: int = 4096
    cache_mode: str = "FP16"
    chunk_size: int = 2048
    max_batch_size: Optional[int] = None
|
|
|
|
# Required methods
|
|
@classmethod
|
|
async def create(cls, model_directory: pathlib.Path, **kwargs):
|
|
"""
|
|
Asynchronously creates and initializes a model container instance.
|
|
|
|
Args:
|
|
model_directory: Path to the model files.
|
|
**kwargs: Backend-specific configuration options.
|
|
|
|
Returns:
|
|
An instance of the implementing class.
|
|
"""
|
|
|
|
self = cls()
|
|
|
|
self.model = None
|
|
self.cache = None
|
|
self.draft_model = None
|
|
self.draft_cache = None
|
|
self.tokenizer = None
|
|
self.config = None
|
|
self.draft_config = None
|
|
self.generator = None
|
|
self.tokenizer_config = None
|
|
|
|
logger.warning(
|
|
"ExllamaV3 is currently in an alpha state. "
|
|
"Please note that all config options may not work."
|
|
)
|
|
|
|
self.model_dir = model_directory
|
|
self.config = Config.from_directory(str(model_directory.resolve()))
|
|
self.model = Model.from_config(self.config)
|
|
self.tokenizer = Tokenizer.from_config(self.config)
|
|
|
|
# Load generation config overrides
|
|
generation_config_path = model_directory / "generation_config.json"
|
|
if generation_config_path.exists():
|
|
try:
|
|
self.generation_config = await GenerationConfig.from_file(
|
|
model_directory
|
|
)
|
|
except Exception:
|
|
logger.error(traceback.format_exc())
|
|
logger.warning(
|
|
"Skipping generation config load because of an unexpected error."
|
|
)
|
|
|
|
# Load tokenizer config overrides
|
|
tokenizer_config_path = model_directory / "tokenizer_config.json"
|
|
if tokenizer_config_path.exists():
|
|
try:
|
|
self.tokenizer_config = await TokenizerConfig.from_file(model_directory)
|
|
except Exception:
|
|
logger.error(traceback.format_exc())
|
|
logger.warning(
|
|
"Skipping tokenizer config load because of an unexpected error."
|
|
)
|
|
|
|
# Fallback to 4096 since exl3 can't fetch from HF's config.json
|
|
self.max_seq_len = unwrap(kwargs.get("max_seq_len"), 4096)
|
|
|
|
# Prepare the draft model config if necessary
|
|
draft_args = unwrap(kwargs.get("draft_model"), {})
|
|
draft_model_name = draft_args.get("draft_model_name")
|
|
self.use_draft_model = draft_args and draft_model_name
|
|
|
|
# Always disable draft if params are incorrectly configured
|
|
if draft_args and draft_model_name is None:
|
|
logger.warning(
|
|
"Draft model is disabled because a model name "
|
|
"wasn't provided. Please check your config.yml!"
|
|
)
|
|
self.use_draft_model = False
|
|
|
|
if self.use_draft_model:
|
|
draft_model_path = pathlib.Path(
|
|
unwrap(draft_args.get("draft_model_dir"), "models")
|
|
)
|
|
draft_model_path = draft_model_path / draft_model_name
|
|
self.draft_gpu_split = unwrap(draft_args.get("draft_gpu_split"), [])
|
|
self.draft_model_dir = draft_model_path
|
|
self.draft_config = Config.from_directory(str(draft_model_path.resolve()))
|
|
self.draft_model = Model.from_config(self.draft_config)
|
|
logger.info(
|
|
f'Using draft model: {str(draft_model_path.resolve())}'
|
|
)
|
|
else:
|
|
self.draft_model = None
|
|
self.craft_cache = None
|
|
|
|
# Turn off GPU split if the user is using 1 GPU
|
|
gpu_count = torch.cuda.device_count()
|
|
gpu_split_auto = unwrap(kwargs.get("gpu_split_auto"), True)
|
|
gpu_split = unwrap(kwargs.get("gpu_split"), None)
|
|
gpu_device_list = list(range(0, gpu_count))
|
|
|
|
# Set GPU split options
|
|
if gpu_count == 1:
|
|
self.gpu_split_auto = False
|
|
logger.info("Disabling GPU split because one GPU is in use.")
|
|
else:
|
|
# TODO: Set tensor parallel
|
|
|
|
# Set GPU split options
|
|
# Enable manual GPU split if provided
|
|
if gpu_split:
|
|
self.gpu_split = gpu_split
|
|
|
|
gpu_device_list = [
|
|
device_idx
|
|
for device_idx, memory in enumerate(self.gpu_split)
|
|
if memory > 0
|
|
]
|
|
elif gpu_split_auto and not self.use_tp:
|
|
# Otherwise fallback to autosplit settings
|
|
self.gpu_split_auto = gpu_split_auto
|
|
|
|
autosplit_reserve_megabytes = unwrap(
|
|
kwargs.get("autosplit_reserve"), [96]
|
|
)
|
|
|
|
# Reserve VRAM for each GPU
|
|
self.autosplit_reserve = [
|
|
value / 1024 for value in autosplit_reserve_megabytes
|
|
]
|
|
|
|
if not hardware_supports_flash_attn(gpu_device_list):
|
|
gpu_unsupported_message = (
|
|
"Unable to run ExllamaV3 because an unsupported GPU is "
|
|
"found in this configuration. \n"
|
|
"All GPUs must be ampere "
|
|
"(30 series) or newer. AMD GPUs are not supported."
|
|
)
|
|
|
|
logger.warning(gpu_unsupported_message)
|
|
|
|
raise RuntimeError(gpu_unsupported_message)
|
|
|
|
# Cache
|
|
user_cache_size = unwrap(kwargs.get("cache_size"), self.max_seq_len)
|
|
self.cache_size = self.adjust_cache_size(user_cache_size)
|
|
self.cache_mode = unwrap(kwargs.get("cache_mode"), "FP16")
|
|
|
|
# Alias Exl2 q-cache settings
|
|
match self.cache_mode:
|
|
case "Q4":
|
|
self.cache_mode = "4,4"
|
|
case "Q6":
|
|
self.cache_mode = "6,6"
|
|
case "Q8":
|
|
self.cache_mode = "8,8"
|
|
|
|
split_cache_mode = re.search(r"^([2-8])\s*,\s*([2-8])$", self.cache_mode)
|
|
if split_cache_mode:
|
|
k_bits = int(split_cache_mode.group(1))
|
|
v_bits = int(split_cache_mode.group(2))
|
|
self.cache = Cache(
|
|
self.model,
|
|
max_num_tokens=self.cache_size,
|
|
layer_type=CacheLayer_quant,
|
|
k_bits=k_bits,
|
|
v_bits=v_bits,
|
|
)
|
|
else:
|
|
self.cache = Cache(
|
|
self.model, max_num_tokens=self.cache_size, layer_type=CacheLayer_fp16
|
|
)
|
|
|
|
# Draft cache
|
|
if self.use_draft_model:
|
|
self.draft_cache = Cache(self.draft_model, max_num_tokens = self.cache_size)
|
|
|
|
# Max batch size
|
|
self.max_batch_size = unwrap(kwargs.get("max_batch_size"), 256)
|
|
|
|
# Make sure chunk size is >= 256, keep near or below max seq len
|
|
user_chunk_size = unwrap(kwargs.get("chunk_size"), 2048)
|
|
self.chunk_size = self.adjust_chunk_size(user_chunk_size)
|
|
|
|
# Template setup
|
|
self.prompt_template = await find_prompt_template(
|
|
kwargs.get("prompt_template"), model_directory
|
|
)
|
|
|
|
# Catch all for template lookup errors
|
|
if self.prompt_template:
|
|
logger.info(
|
|
f'Using template "{self.prompt_template.name}" for chat completions.'
|
|
)
|
|
else:
|
|
logger.warning(
|
|
"Chat completions are disabled because a prompt "
|
|
"template wasn't provided or auto-detected."
|
|
)
|
|
|
|
return self
|
|
|
|
def adjust_cache_size(self, cache_size):
    """
    Normalize a requested cache size (in tokens).

    The value is raised to ``self.max_seq_len`` when it is smaller, then
    rounded up to the next multiple of 256. Every adjustment is logged as a
    warning, and a further warning is logged when the final size is below
    ``2 * self.max_seq_len``.

    Args:
        cache_size: The requested cache size.

    Returns:
        The adjusted cache size.
    """

    context_length = self.max_seq_len

    # Never allow a cache smaller than the context length
    if cache_size < context_length:
        logger.warning(
            f"The given cache_size ({cache_size}) is smaller than the "
            "desired context length.\n"
            "Overriding cache_size to max_seq_len. "
        )
        cache_size = context_length

    # Round up (never down) so the cache can't end up below max_seq_len
    if cache_size % 256:
        padded_cache_size = int((cache_size // 256 + 1) * 256)
        logger.warning(
            f"The given cache size ({cache_size}) is "
            "not a multiple of 256.\n"
            "Overriding cache_size with an overestimated value of "
            f"{padded_cache_size} tokens."
        )
        cache_size = padded_cache_size

    # Warn user if cache size may be inadequate for CFG
    if cache_size < 2 * context_length:
        logger.warning(
            f"The given cache_size ({cache_size}) is less than 2 * max_seq_len "
            "and may be too small for requests using CFG. \n"
            "Ignore this warning if you do not plan on using CFG."
        )

    return cache_size
|
|
|
|
def adjust_chunk_size(self, user_chunk_size: int):
    """
    Normalize a requested chunk size.

    Picks the middle value of ``256``, ``user_chunk_size`` and
    ``self.max_seq_len``, then rounds it up to the next multiple of 256
    (logging a warning when rounding was needed).

    Args:
        user_chunk_size: The requested chunk size.

    Returns:
        The adjusted chunk size.
    """

    # The median of the three values acts as a clamp on the user's request
    _, chunk_size, _ = sorted((256, user_chunk_size, self.max_seq_len))

    if chunk_size % 256 == 0:
        return chunk_size

    padded_chunk_size = int((chunk_size // 256 + 1) * 256)
    logger.warning(
        f"The given chunk size ({chunk_size}) is "
        "not a multiple of 256.\n"
        "Overriding chunk_size with an overestimated value of "
        f"{padded_chunk_size} tokens."
    )

    return padded_chunk_size
|
|
|
|
def model_info(self) -> ModelCard:
    """
    Describes the current model as a ModelCard.

    Returns:
        A ModelCard whose id is the model directory name and whose
        parameters mirror this container's current settings (including
        the prompt template name and content when a template is set).
    """

    template = self.prompt_template

    parameters = ModelCardParameters(
        max_seq_len=self.max_seq_len,
        cache_size=self.cache_size,
        max_batch_size=self.max_batch_size,
        cache_mode=self.cache_mode,
        chunk_size=self.chunk_size,
        use_vision=self.use_vision,
    )

    # Template fields are only filled in when a template is available
    if template:
        parameters.prompt_template = template.name
        parameters.prompt_template_content = template.raw_template

    return ModelCard(id=self.model_dir.name, parameters=parameters)
|
|
|
|
async def wait_for_jobs(self, skip_wait: bool = False):
    """
    Polling to wait for any active generation jobs to complete.

    Returns immediately when no generator exists. Otherwise it polls
    ``self.active_job_ids`` every 10 ms until the dict is empty.

    Args:
        skip_wait: If True, cancel jobs immediately instead of waiting.
    """

    if not self.generator:
        return

    # Immediately abort all jobs if asked
    if skip_wait:
        logger.warning(
            "Immediately terminating all jobs. "
            "Clients will have their requests cancelled.\n"
        )

        # Iterate over a snapshot: awaiting cancel() yields to other tasks,
        # which may remove their entry from active_job_ids. Mutating the
        # dict while iterating it directly would raise a RuntimeError.
        for job in list(self.active_job_ids.values()):
            # Entries are None until a job object has been registered
            if job:
                await job.cancel()

    while self.active_job_ids:
        await asyncio.sleep(0.01)
|
|
|
|
async def load(self, progress_callback=None, **kwargs):
    """
    Loads the model into memory.

    Drives ``load_gen`` to completion and discards its progress updates.

    Args:
        progress_callback: Optional callback for progress updates.
        **kwargs: Additional loading options, forwarded to ``load_gen``
            (e.g. skip_wait).
    """

    # Forward kwargs so options such as skip_wait are not silently dropped
    async for _ in self.load_gen(progress_callback, **kwargs):
        pass
|
|
|
|
async def load_gen(self, progress_callback=None, **kwargs):
    """
    Loads the model into memory, yielding progress updates.

    Holds ``self.load_lock`` for the whole load, waits for (or cancels)
    active jobs, runs ``load_model_sync`` in a threadpool, creates the
    generator and finally marks the container as loaded.

    Args:
        progress_callback: Optional callback for progress updates.
        **kwargs: Additional loading options. ``skip_wait`` is passed to
            ``wait_for_jobs``.

    Yields:
        Progress updates
    """

    # Acquire outside the try block: if the acquire itself is cancelled or
    # fails, the finally clause must not release a lock this task never held.
    await self.load_lock.acquire()

    try:
        # Wait for existing generation jobs to finish
        await self.wait_for_jobs(kwargs.get("skip_wait"))

        generator = self.load_model_sync(progress_callback)
        async for value in iterate_in_threadpool(generator):
            yield value

        # Create async generator
        await self.create_generator()

        # Clean up any extra vram usage from torch and cuda
        # (Helps reduce VRAM bottlenecking on Windows)
        gc.collect()
        torch.cuda.empty_cache()

        # Cleanup and update model load state
        self.loaded = True
        logger.info("Model successfully loaded.")
    finally:
        self.load_lock.release()

        # Wake up requests waiting for the lock to be free
        async with self.load_condition:
            self.load_condition.notify_all()
|
|
|
|
@torch.inference_mode()
def load_model_sync(self, progress_callback=None):
    """
    Synchronous generator that loads the draft model (if enabled) and then
    the main model.

    Args:
        progress_callback: Passed to each model's ``load_gen`` as ``callback``.

    Yields:
        Every truthy value produced by the models' ``load_gen`` generators.
    """

    # Draft model first; it is loaded without a manual per-device split
    if self.use_draft_model:
        draft_loader = self.draft_model.load_gen(
            reserve_per_device=self.autosplit_reserve,
            callback=progress_callback,
        )
        for progress in draft_loader:
            if not progress:
                continue
            yield progress

    main_loader = self.model.load_gen(
        reserve_per_device=self.autosplit_reserve,
        use_per_device=self.gpu_split,
        callback=progress_callback,
    )
    for progress in main_loader:
        if not progress:
            continue
        yield progress
|
|
|
|
async def create_generator(self):
    """
    Create and save a Exllama generator class.

    When the container is already loaded (i.e. the generator is being
    recreated) the load lock is taken and all active jobs are cancelled
    first. During the initial load the caller already holds the lock, so
    no locking happens here.
    """

    # Don't acquire locks unless a model is loaded.
    # Track the acquisition explicitly so the release below never depends on
    # re-reading self.loaded, and never runs if the acquire did not complete.
    lock_acquired = False
    if self.loaded:
        await self.load_lock.acquire()
        lock_acquired = True

    try:
        if lock_acquired:
            # Immediately cancel all jobs
            await self.wait_for_jobs(skip_wait=True)

        # Create new generator
        self.generator = AsyncGenerator(
            model=self.model,
            cache=self.cache,
            draft_model=self.draft_model,
            draft_cache=self.draft_cache,
            tokenizer=self.tokenizer,
            max_batch_size=self.max_batch_size,
            max_chunk_size=self.chunk_size,
        )

        # Update the state of the container var
        if self.max_batch_size is None:
            self.max_batch_size = self.generator.generator.max_batch_size
    finally:
        # This means the generator is being recreated
        # The load lock is already released in the load function
        if lock_acquired:
            self.load_lock.release()

            async with self.load_condition:
                self.load_condition.notify_all()
|
|
|
|
async def unload(self, loras_only: bool = False, **kwargs):
    """
    Unloads the model and associated resources from memory.

    Args:
        loras_only: If True, only unload LoRAs. Not used by this backend:
            the whole model is unloaded regardless.
        **kwargs: Additional unloading options. ``shutdown`` skips locking
            and job draining; ``skip_wait`` is passed to ``wait_for_jobs``.
    """

    # Used when shutting down the server
    do_shutdown = kwargs.get("shutdown")

    # Acquire outside the try block so a cancelled/failed acquire never
    # reaches the release in the finally clause.
    if not do_shutdown:
        await self.load_lock.acquire()

    try:
        if not do_shutdown:
            # Wait for other jobs to finish
            await self.wait_for_jobs(kwargs.get("skip_wait"))

        # Flip the master switch so stream_generate rejects new requests
        # instead of running against a model that no longer exists
        self.loaded = False

        # Guard against a repeated unload on an already-empty container
        if self.model is not None:
            self.model.unload()
        self.model = None
        self.config = None
        self.cache = None
        self.tokenizer = None

        if self.use_draft_model:
            if self.draft_model is not None:
                self.draft_model.unload()
            self.draft_model = None
            self.draft_config = None
            self.draft_cache = None

        # Cleanup the generator from any pending jobs
        if self.generator is not None:
            await self.generator.close()
            self.generator = None

        gc.collect()
        torch.cuda.empty_cache()

        logger.info("Model unloaded.")
    finally:
        if not do_shutdown:
            self.load_lock.release()

            async with self.load_condition:
                self.load_condition.notify_all()
|
|
|
|
def encode_tokens(self, text: str, **kwargs) -> List[int]:
    """
    Tokenizes text into a flat list of token IDs.

    Args:
        text: The input text string.
        **kwargs: Encoding options. ``add_bos_token`` and
            ``encode_special_tokens`` are read; both default to True.

    Returns:
        A list of integer token IDs.
    """

    add_bos = unwrap(kwargs.get("add_bos_token"), True)
    encode_special = unwrap(kwargs.get("encode_special_tokens"), True)

    token_tensor = self.tokenizer.encode(
        text,
        add_bos=add_bos,
        encode_special_tokens=encode_special,
    )

    # Collapse the tokenizer output into a plain 1-D Python list
    return token_tensor.flatten().tolist()
|
|
|
|
def decode_tokens(self, ids: List[int], **kwargs) -> str:
    """
    Converts a list of token IDs back into text.

    Args:
        ids: A list of integer token IDs.
        **kwargs: Decoding options. ``decode_special_tokens`` is read and
            defaults to True.

    Returns:
        The decoded text string.
    """

    keep_special = unwrap(kwargs.get("decode_special_tokens"), True)

    # Wrap the ids in a single-row batch; the first decoded entry is returned
    batch = torch.tensor([ids])
    decoded = self.tokenizer.decode(batch, decode_special_tokens=keep_special)

    return decoded[0]
|
|
|
|
def get_special_tokens(
    self, add_bos_token: bool = True, ban_eos_token: bool = False
):
    """
    Collects the tokenizer's special tokens.

    Args:
        add_bos_token: When False, "bos_token" is reported as "".
        ban_eos_token: When True, "eos_token" is reported as "".

    Returns:
        A dictionary with the keys "bos_token", "eos_token", "pad_token"
        and "unk_token", taken from the tokenizer's attributes of the
        same names.
    """

    tokenizer = self.tokenizer

    bos = tokenizer.bos_token if add_bos_token else ""
    eos = "" if ban_eos_token else tokenizer.eos_token

    special_tokens = {"bos_token": bos, "eos_token": eos}
    special_tokens["pad_token"] = tokenizer.pad_token
    special_tokens["unk_token"] = tokenizer.unk_token

    return special_tokens
|
|
|
|
async def generate(
    self,
    request_id: str,
    prompt: str,
    params: BaseSamplerRequest,
    abort_event: Optional[asyncio.Event] = None,
    mm_embeddings: Optional[MultimodalEmbeddingWrapper] = None,
) -> Dict[str, Any]:
    """
    Generates a complete response for a given prompt and parameters.

    Collects every chunk from ``stream_generate`` and merges them into a
    single dictionary.

    Args:
        request_id: Unique identifier for the generation request.
        prompt: The input prompt string.
        params: Sampling and generation parameters.
        abort_event: An asyncio Event to signal cancellation.
        mm_embeddings: Optional multimodal embeddings.

    Returns:
        A dictionary containing the generation info. "finish_reason" (and
        "stop_str" when a finish chunk was received) are only present when
        at least one chunk was produced.
    """

    generations = []
    async for generation in self.stream_generate(
        request_id,
        prompt,
        params,
        abort_event,
        mm_embeddings,
    ):
        generations.append(generation)

    joined_generation = {
        "text": "",
        "prompt_tokens": 0,
        # Kept for backward compatibility; the streamed chunks use the
        # "generated_tokens" key, which is the one filled in below
        "generation_tokens": 0,
        # Default so the key exists even when no chunk was produced
        "generated_tokens": 0,
        "tool_calls": None,
        "offset": [],
        "token_probs": {},
        "logprobs": [],
    }

    if generations:
        # Get finish_reason first and then shift where -1 points to
        if "finish_reason" in generations[-1]:
            finish_reason_gen = generations.pop()
            joined_generation["finish_reason"] = finish_reason_gen.get(
                "finish_reason"
            )
            joined_generation["stop_str"] = finish_reason_gen.get("stop_str")
        else:
            joined_generation["finish_reason"] = "stop"

    # Re-check: popping the finish chunk may have emptied the list
    if generations:
        for generation in generations:
            joined_generation["text"] += unwrap(generation.get("text"), "")
            joined_generation["offset"].append(unwrap(generation.get("offset"), -1))
            joined_generation["token_probs"].update(
                unwrap(generation.get("token_probs"), {})
            )

            # Include empty logprob dicts for index preservation
            joined_generation["logprobs"].append(
                unwrap(generation.get("logprobs"), {})
            )

        # Token counts are cumulative, so the last chunk holds the totals
        joined_generation["prompt_tokens"] = unwrap(
            generations[-1].get("prompt_tokens"), 0
        )
        joined_generation["generated_tokens"] = unwrap(
            generations[-1].get("generated_tokens"), 0
        )

    return joined_generation
|
|
|
|
async def stream_generate(
    self,
    request_id: str,
    prompt: str,
    params: BaseSamplerRequest,
    abort_event: Optional[asyncio.Event] = None,
    mm_embeddings: Optional[MultimodalEmbeddingWrapper] = None,
) -> AsyncIterator[Dict[str, Any]]:
    """
    Generates a response iteratively (streaming) for a given prompt.

    Waits until no load/unload holds the load lock, registers the request
    in ``self.active_job_ids`` and relays the chunks of ``generate_gen``.

    Args:
        request_id: Unique identifier for the generation request.
        prompt: The input prompt string.
        params: Sampling and generation parameters.
        abort_event: An asyncio Event to signal cancellation.
        mm_embeddings: Optional multimodal embeddings.

    Yields:
        Generation chunks

    Raises:
        RuntimeError: If the container is not marked as loaded.
    """

    try:
        # Wait for load lock to be freed before processing
        # Mainly used for loras and other operations where the class is available
        async with self.load_condition:
            await self.load_condition.wait_for(lambda: not self.load_lock.locked())

        # If the model is being unloaded, don't accept new requests
        if not self.loaded:
            raise RuntimeError(
                "Model is being unloaded. Cannot process new generation requests."
            )

        # Mark that the job is running
        self.active_job_ids[request_id] = None

        # Yield from the internal generator
        async for generation_chunk in self.generate_gen(
            request_id=request_id,
            prompt=prompt,
            params=params,
            abort_event=abort_event,
            mm_embeddings=mm_embeddings,
        ):
            yield generation_chunk
    finally:
        # Clean up and remove the job from active IDs.
        # pop() instead of del: the id is not registered yet when the wait
        # above is cancelled or the "not loaded" error is raised, and a
        # KeyError here would hide that original exception.
        self.active_job_ids.pop(request_id, None)
|
|
|
|
def handle_finish_chunk(self, result: dict, generation: dict):
    """
    Builds the final chunk that reports why a generation ended.

    Args:
        result: The job result; "eos_reason" and the matching
            "eos_triggering_*" entry are read from it.
        generation: The most recent generation chunk, used for its
            token counts.

    Returns:
        A dict with "prompt_tokens", "generated_tokens", "finish_reason"
        ("length" when the token limit was hit, otherwise "stop") and
        "stop_str" (the triggering stop token/string, or None).
    """

    eos_reason = result.get("eos_reason")

    # Only stop tokens and stop strings carry a triggering string
    if eos_reason == "stop_token":
        triggered_by = result.get("eos_triggering_token_str")
    elif eos_reason == "stop_string":
        triggered_by = result.get("eos_triggering_string")
    else:
        triggered_by = None

    return {
        "prompt_tokens": generation.get("prompt_tokens"),
        "generated_tokens": generation.get("generated_tokens"),
        "finish_reason": "length" if eos_reason == "max_new_tokens" else "stop",
        "stop_str": triggered_by,
    }
|
|
|
|
async def generate_gen(
    self,
    request_id: str,
    prompt: str,
    params: BaseSamplerRequest,
    abort_event: Optional[asyncio.Event] = None,
    mm_embeddings: Optional[MultimodalEmbeddingWrapper] = None,
):
    """
    Create generator function for prompt completion.

    for kwargs, check common/sampling.py

    Builds a sampler from ``params``, tokenizes the prompt, submits an
    AsyncJob to the generator and yields one dict per text chunk, followed
    by a finish chunk (see ``handle_finish_chunk``) once the job reports EOS.

    Args:
        request_id: Unique identifier for the generation request.
        prompt: The input prompt string.
        params: Sampling and generation parameters.
        abort_event: When set, the job is cancelled and streaming stops.
        mm_embeddings: Accepted for interface compatibility; not used here.

    Yields:
        Generation chunk dictionaries.

    Raises:
        ValueError: If the tokenized prompt is longer than max_seq_len.
    """

    sampler_builder = ExllamaV3SamplerBuilder()

    # Penalties

    # Set penalty range
    # Resolve None once so the comparisons below can't fail on a missing value
    requested_range = unwrap(params.penalty_range, self.max_seq_len)

    # Exl3's version of including the entire context
    penalty_range = int(10e7) if requested_range < 0 else requested_range

    # Always make sure the fallback is 0 if range < 0
    # It's technically fine to use -1, but this just validates the passed
    # fallback
    # Always default to 0 if something goes wrong
    fallback_decay = 0 if requested_range < 0 else requested_range

    repetition_decay = coalesce(params.repetition_decay, fallback_decay, 0)

    # Apply penalties to builder
    sampler_builder.penalties(
        params.repetition_penalty,
        params.frequency_penalty,
        params.presence_penalty,
        penalty_range,
        repetition_decay,
    )

    # Apply temperature first to builder
    if not params.temperature_last:
        sampler_builder.temperature(params.temperature)

    # Apply alphabet samplers to builder
    sampler_builder.top_k(params.top_k)
    sampler_builder.top_p(params.top_p)
    sampler_builder.min_p(params.min_p)

    # Apply temperature last to builder
    if params.temperature_last:
        sampler_builder.temperature(params.temperature)

    # Build the sampler
    # Set greedy if temperature is 0
    sampler = sampler_builder.build(params.temperature == 0)

    # Dynamically scale penalty range to output tokens
    # Only do this if freq/pres pen is enabled
    # and the repetition range is -1
    # TODO: This currently does not work in exl3
    # auto_scale_penalty_range = (
    #     gen_settings.token_frequency_penalty != 0
    #     or gen_settings.token_presence_penalty != 0
    # ) and gen_settings.token_repetition_range == -1

    prompts = [prompt]

    # The tokenizer config is optional (create() may leave it as None), so
    # fall back to adding BOS, matching encode_tokens' default
    config_add_bos = (
        self.tokenizer_config.add_bos_token
        if self.tokenizer_config is not None
        else None
    )
    add_bos_token = unwrap(params.add_bos_token, unwrap(config_add_bos, True))

    # Fetch EOS tokens from generation_config if they exist
    eos_tokens = (
        self.generation_config.eos_tokens()
        if self.generation_config
        else [self.tokenizer.eos_token_id]
    )

    # Build a new list rather than extending params.stop in place, which
    # would leak the EOS tokens into the request object (and the log below)
    user_stop = unwrap(params.stop, [])
    if isinstance(user_stop, str):
        user_stop = [user_stop]
    stop_conditions = [*user_stop, *eos_tokens]

    input_ids = [
        self.tokenizer.encode(
            text,
            add_bos=add_bos_token,
            encode_special_tokens=True,
        )
        for text in prompts
    ]

    # The first index will always be the positive prompt
    context_len = input_ids[0].size(dim=-1)

    # Automatically set max_tokens to fill up the context
    # This should be an OK default, but may be changed in the future
    max_tokens = unwrap(
        params.max_tokens,
        self.max_seq_len - context_len,
    )
    if max_tokens < 1:
        logger.warning("max_tokens must be a positive integer, setting to 1.")
        max_tokens = 1

    # Determine if the negative context or the context length is bigger
    context_to_check = context_len

    # Check total length of prompt against max context length
    if context_to_check > self.max_seq_len:
        preamble = "Prompt"

        raise ValueError(
            f"{preamble} length {context_to_check} is greater than "
            f"max_seq_len {self.max_seq_len}"
        )

    generation = {}
    job = AsyncJob(
        self.generator,
        sampler=sampler,
        # Reuse the ids that were length-checked above so the job sees the
        # same tokenization (BOS and special tokens) that context_len counts
        input_ids=input_ids[0],
        max_new_tokens=max_tokens,
        stop_conditions=stop_conditions,
        banned_strings=params.banned_strings,
    )

    # Assign the active job to the request ID before consuming any results,
    # so wait_for_jobs(skip_wait=True) can cancel it from the very start
    self.active_job_ids[request_id] = job

    generated_tokens = 0
    full_response = ""
    metrics_result = {}

    # Get the generation status once it's ready
    try:
        async for result in job:
            # Abort if the event is set while streaming
            if abort_event and abort_event.is_set():
                await job.cancel()
                break

            chunk = unwrap(result.get("text"), "")
            if chunk:
                # Only re-encode the chunk when the job did not report ids
                if "token_ids" in result:
                    chunk_tokens = result["token_ids"]
                else:
                    chunk_tokens = self.tokenizer.encode(chunk)

                full_response += chunk
                if isinstance(chunk_tokens, torch.Tensor):
                    # NOTE(review): counts along dim 0 - confirm the shape of
                    # token_ids; tokenizer output is flattened elsewhere in
                    # this class, which suggests a leading batch dimension
                    generated_tokens += chunk_tokens.size(dim=0)

                # Increase penalty range to generated token amount
                # TODO:
                # if auto_scale_penalty_range:
                #     gen_settings.token_repetition_range = generated_tokens

                generation = {
                    "text": chunk,
                    "prompt_tokens": context_len,
                    "generated_tokens": generated_tokens,
                    "offset": len(full_response),
                }
                yield generation

            if result.get("eos"):
                generation = self.handle_finish_chunk(result, generation)

                # Save the final result for metrics logging
                metrics_result = result

                yield generation
                break

    except asyncio.CancelledError:
        await job.cancel()
    except Exception as ex:
        # Create a new generator since the current state is broken
        # No need to wait for this to finish
        logger.error(
            "FATAL ERROR with generation. "
            "Attempting to recreate the generator. "
            "If this fails, please restart the server.\n"
        )
        asyncio.ensure_future(self.create_generator())

        await HealthManager.add_unhealthy_event(ex)

        # Bare raise keeps the original traceback intact
        raise
    finally:
        # Log generation options to console
        # Some options are too large, so log the args instead
        log_generation_params(
            request_id=request_id,
            bos_token_id=self.tokenizer.bos_token_id,
            eos_token_id=eos_tokens,
            prompt=prompt,
            **params.model_dump(exclude={"prompt"}),
            # auto_scale_penalty_range=auto_scale_penalty_range,  # TODO
        )

        # Log the metrics if present
        if metrics_result:
            log_metrics(
                request_id,
                metrics_result.get("time_enqueued"),
                metrics_result.get("prompt_tokens"),
                metrics_result.get("cached_tokens"),
                metrics_result.get("time_prefill"),
                metrics_result.get("new_tokens"),
                metrics_result.get("time_generate"),
                context_len,
                self.max_seq_len,
            )
|