Merge pull request #341 from theroyallab/exl3

Exl3
This commit is contained in:
Brian 2025-05-10 23:43:02 -04:00 committed by GitHub
commit 3674d7b9b5
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
19 changed files with 1273 additions and 213 deletions

View file

@ -25,6 +25,10 @@ class BaseModelContainer(abc.ABC):
prompt_template: Optional[PromptTemplate] = None
generation_config: Optional[GenerationConfig] = None
# Optional features
use_draft_model: bool = False
use_vision: bool = False
# Load synchronization
# The bool is a master switch for accepting requests
# The lock keeps load tasks sequential
@ -65,7 +69,7 @@ class BaseModelContainer(abc.ABC):
# NOTE: Might be an optional method
@abc.abstractmethod
async def load_gen(self, progress_callback=None, **kwargs) -> AsyncIterator[Any]:
async def load_gen(self, progress_callback=None, **kwargs):
"""
Loads the model into memory, yielding progress updates.
@ -134,57 +138,6 @@ class BaseModelContainer(abc.ABC):
pass
@abc.abstractmethod
async def generate(
self,
request_id: str,
prompt: str,
params: BaseSamplerRequest,
abort_event: Optional[asyncio.Event] = None,
mm_embeddings: Optional[MultimodalEmbeddingWrapper] = None,
) -> Dict[str, Any]:
"""
Generates a complete response for a given prompt and parameters.
Args:
request_id: Unique identifier for the generation request.
prompt: The input prompt string.
params: Sampling and generation parameters.
abort_event: An asyncio Event to signal cancellation.
mm_embeddings: Optional multimodal embeddings.
Returns:
A dictionary containing the generation info
"""
pass
@abc.abstractmethod
async def stream_generate(
self,
request_id: str,
prompt: str,
params: BaseSamplerRequest,
abort_event: Optional[asyncio.Event] = None,
mm_embeddings: Optional[MultimodalEmbeddingWrapper] = None,
) -> AsyncIterator[Dict[str, Any]]:
"""
Generates a response iteratively (streaming) for a given prompt.
Args:
request_id: Unique identifier for the generation request.
prompt: The input prompt string.
params: Sampling and generation parameters.
abort_event: An asyncio Event to signal cancellation.
mm_embeddings: Optional multimodal embeddings.
Yields:
Generation chunks
"""
if False:
yield
@abc.abstractmethod
def model_info(self) -> ModelCard:
"""
@ -239,3 +192,54 @@ class BaseModelContainer(abc.ABC):
"""
return []
@abc.abstractmethod
async def generate(
self,
request_id: str,
prompt: str,
params: BaseSamplerRequest,
abort_event: Optional[asyncio.Event] = None,
mm_embeddings: Optional[MultimodalEmbeddingWrapper] = None,
) -> Dict[str, Any]:
"""
Generates a complete response for a given prompt and parameters.
Args:
request_id: Unique identifier for the generation request.
prompt: The input prompt string.
params: Sampling and generation parameters.
abort_event: An asyncio Event to signal cancellation.
mm_embeddings: Optional multimodal embeddings.
Returns:
A dictionary containing the generation info
"""
pass
@abc.abstractmethod
async def stream_generate(
self,
request_id: str,
prompt: str,
params: BaseSamplerRequest,
abort_event: Optional[asyncio.Event] = None,
mm_embeddings: Optional[MultimodalEmbeddingWrapper] = None,
) -> AsyncIterator[Dict[str, Any]]:
"""
Generates a response iteratively (streaming) for a given prompt.
Args:
request_id: Unique identifier for the generation request.
prompt: The input prompt string.
params: Sampling and generation parameters.
abort_event: An asyncio Event to signal cancellation.
mm_embeddings: Optional multimodal embeddings.
Yields:
Generation chunks
"""
if False:
yield

View file

@ -33,11 +33,7 @@ from backends.exllamav2.grammar import (
ExLlamaV2Grammar,
clear_grammar_func_cache,
)
from backends.exllamav2.utils import (
exllama_disabled_flash_attn,
hardware_supports_flash_attn,
supports_paged_attn,
)
from backends.exllamav2.utils import exllama_disabled_flash_attn
from backends.exllamav2.vision import clear_image_embedding_cache
from common.concurrency import iterate_in_threadpool
from common.gen_logging import (
@ -46,6 +42,7 @@ from common.gen_logging import (
log_prompt,
log_response,
)
from common.hardware import hardware_supports_flash_attn
from common.health import HealthManager
from common.multimodal import MultimodalEmbeddingWrapper
from common.sampling import BaseSamplerRequest
@ -64,16 +61,19 @@ class ExllamaV2Container(BaseModelContainer):
# Exl2 vars
config: Optional[ExLlamaV2Config] = None
draft_config: Optional[ExLlamaV2Config] = None
model: Optional[ExLlamaV2] = None
draft_model: Optional[ExLlamaV2] = None
cache: Optional[ExLlamaV2Cache] = None
draft_cache: Optional[ExLlamaV2Cache] = None
tokenizer: Optional[ExLlamaV2Tokenizer] = None
generator: Optional[ExLlamaV2DynamicGeneratorAsync] = None
prompt_template: Optional[PromptTemplate] = None
paged: bool = True
# Draft model vars
use_draft_model: bool = False
draft_config: Optional[ExLlamaV2Config] = None
draft_model: Optional[ExLlamaV2] = None
draft_cache: Optional[ExLlamaV2Cache] = None
# Internal config vars
cache_size: int = None
cache_mode: str = "FP16"
@ -100,7 +100,7 @@ class ExllamaV2Container(BaseModelContainer):
load_condition: asyncio.Condition = asyncio.Condition()
@classmethod
async def create(cls, model_directory: pathlib.Path, quiet=False, **kwargs):
async def create(cls, model_directory: pathlib.Path, **kwargs):
"""
Primary asynchronous initializer for model container.
@ -110,8 +110,6 @@ class ExllamaV2Container(BaseModelContainer):
# Create a new instance as a "fake self"
self = cls()
self.quiet = quiet
# Initialize config
self.config = ExLlamaV2Config()
self.model_dir = model_directory
@ -162,7 +160,7 @@ class ExllamaV2Container(BaseModelContainer):
# Prepare the draft model config if necessary
draft_args = unwrap(kwargs.get("draft_model"), {})
draft_model_name = draft_args.get("draft_model_name")
enable_draft = draft_args and draft_model_name
self.use_draft_model = draft_args and draft_model_name
# Always disable draft if params are incorrectly configured
if draft_args and draft_model_name is None:
@ -170,9 +168,9 @@ class ExllamaV2Container(BaseModelContainer):
"Draft model is disabled because a model name "
"wasn't provided. Please check your config.yml!"
)
enable_draft = False
self.use_draft_model = False
if enable_draft:
if self.use_draft_model:
self.draft_config = ExLlamaV2Config()
draft_model_path = pathlib.Path(
unwrap(draft_args.get("draft_model_dir"), "models")
@ -189,6 +187,15 @@ class ExllamaV2Container(BaseModelContainer):
# Get cache mode
self.cache_mode = unwrap(kwargs.get("cache_mode"), "FP16")
# Catch exllamav3 cache_mode
if not self.cache_mode.startswith("Q"):
logger.warning(
f"Provided cache mode '{self.cache_mode}' is not a "
"valid choice for exllamav2, please check your settings. "
"Defaulting to FP16."
)
self.cache_mode = "FP16"
# Turn off GPU split if the user is using 1 GPU
gpu_count = torch.cuda.device_count()
gpu_split_auto = unwrap(kwargs.get("gpu_split_auto"), True)
@ -276,11 +283,20 @@ class ExllamaV2Container(BaseModelContainer):
# Check whether the user's configuration supports flash/paged attention
# Also check if exl2 has disabled flash attention
if (
exllama_disabled_flash_attn(self.config.no_flash_attn)
or not hardware_supports_flash_attn(gpu_device_list)
or not supports_paged_attn()
):
if exllama_disabled_flash_attn(
self.config.no_flash_attn
) or not hardware_supports_flash_attn(gpu_device_list):
gpu_unsupported_message = (
"An unsupported GPU is found in this configuration. "
"Switching to compatibility mode. \n"
"This disables parallel batching "
"and features that rely on it (ex. CFG). \n"
"To disable compatability mode, all GPUs must be ampere "
"(30 series) or newer. AMD GPUs are not supported."
)
logger.warning(gpu_unsupported_message)
self.config.no_flash_attn = True
if self.draft_config:
self.draft_config.no_flash_attn = True
@ -365,7 +381,7 @@ class ExllamaV2Container(BaseModelContainer):
self.config.max_attention_size = chunk_size**2
# Set user-configured draft model values
if enable_draft:
if self.use_draft_model:
self.draft_config.max_seq_len = self.config.max_seq_len
self.draft_config.scale_pos_emb = unwrap(
@ -385,6 +401,15 @@ class ExllamaV2Container(BaseModelContainer):
# Set draft cache mode
self.draft_cache_mode = unwrap(draft_args.get("draft_cache_mode"), "FP16")
# Catch exllamav3 draft_cache_mode
if not self.draft_cache_mode.startswith("Q"):
logger.warning(
f"Provided draft cache mode '{self.draft_cache_mode}' is not a "
"valid choice for exllamav2, please check your settings. "
"Defaulting to FP16."
)
self.draft_cache_mode = "FP16"
# Edit the draft config size
if chunk_size:
self.draft_config.max_input_len = chunk_size
@ -531,8 +556,7 @@ class ExllamaV2Container(BaseModelContainer):
# Load draft model if a config is present
if self.draft_config:
self.draft_model = ExLlamaV2(self.draft_config)
if not self.quiet:
logger.info("Loading draft model: " + self.draft_config.model_dir)
logger.info("Loading draft model: " + self.draft_config.model_dir)
# Draft uses the autosplit loader, so create a cache that reflects this
draft_cache_class = self.get_cache_class(self.draft_cache_mode)
@ -585,8 +609,7 @@ class ExllamaV2Container(BaseModelContainer):
yield value
self.model = ExLlamaV2(self.config)
if not self.quiet:
logger.info("Loading model: " + self.config.model_dir)
logger.info("Loading model: " + self.config.model_dir)
# Get class of the model cache
cache_class = self.get_cache_class(self.cache_mode)
@ -1350,7 +1373,7 @@ class ExllamaV2Container(BaseModelContainer):
min_new_tokens=params.min_tokens,
gen_settings=gen_settings,
stop_conditions=stop_conditions,
decode_special_tokens=not params.skip_special_tokens,
decode_special_tokens=True,
filters=grammar_handler.filters,
filter_prefer_eos=bool(grammar_handler.filters),
return_probs=params.logprobs > 0,

View file

@ -1,74 +1,6 @@
import platform
import torch
from packaging import version
from importlib.metadata import PackageNotFoundError, version as package_version
from loguru import logger
def hardware_supports_flash_attn(gpu_device_list: list[int]):
"""
Check whether all GPUs in list support FA2
Compute capability < 8 is not supported by FA2
AMD is also unsupported until ROCm updates its FA2 fork
"""
# Logged message if unsupported
unsupported_message = (
"An unsupported GPU is found in this configuration. "
"Switching to compatibility mode. \n"
"This disables parallel batching "
"and features that rely on it (ex. CFG). \n"
"To disable compatability mode, all GPUs must be ampere "
"(30 series) or newer. AMD GPUs are not supported."
)
min_compute_capability = min(
torch.cuda.get_device_capability(device=device_idx)[0]
for device_idx in gpu_device_list
)
if torch.version.hip or min_compute_capability < 8:
logger.warning(unsupported_message)
return False
else:
return True
def supports_paged_attn():
"""Check whether the user's flash-attn version supports paged mode"""
# Logged message if unsupported
unsupported_message = (
"Flash attention version >=2.5.7 "
"is required to use paged attention. "
"Switching to compatibility mode. \n"
"This disables parallel batching "
"and features that rely on it (ex. CFG). \n"
"Please upgrade your environment by running an update script "
"(update_scripts/"
f"update_deps.{'bat' if platform.system() == 'Windows' else 'sh'})\n\n"
"Or you can manually run a requirements update "
"using the following command:\n\n"
"For CUDA 12.1:\n"
"pip install --upgrade .[cu121]\n\n"
"NOTE: Windows users must use CUDA 12.x to use flash-attn."
)
required_version = version.parse("2.5.7")
try:
current_version = version.parse(package_version("flash-attn").split("+")[0])
except PackageNotFoundError:
logger.warning(unsupported_message)
return False
if current_version < required_version:
logger.warning(unsupported_message)
return False
else:
return True
def exllama_disabled_flash_attn(no_flash_attn: bool):
unsupported_message = (
"ExllamaV2 has disabled Flash Attention. \n"

View file

@ -1,37 +0,0 @@
import platform
from packaging import version
from importlib.metadata import version as package_version
from loguru import logger
from common.optional_dependencies import dependencies
def check_exllama_version():
"""Verifies the exllama version"""
install_message = (
"Please update your environment by running an update script "
"(update_scripts/"
f"update_deps.{'bat' if platform.system() == 'Windows' else 'sh'})\n\n"
"Or you can manually run a requirements update "
"using the following command:\n\n"
"For CUDA 12.1:\n"
"pip install --upgrade .[cu121]\n\n"
"For ROCm:\n"
"pip install --upgrade .[amd]\n\n"
)
if not dependencies.exllamav2:
raise SystemExit(("Exllamav2 is not installed.\n" + install_message))
required_version = version.parse("0.2.8")
current_version = version.parse(package_version("exllamav2").split("+")[0])
unsupported_message = (
f"ERROR: TabbyAPI requires ExLlamaV2 {required_version} "
f"or greater. Your current version is {current_version}.\n" + install_message
)
if current_version < required_version:
raise SystemExit(unsupported_message)
else:
logger.info(f"ExllamaV2 version: {current_version}")

964
backends/exllamav3/model.py Normal file
View file

@ -0,0 +1,964 @@
import asyncio
import gc
import pathlib
import re
import traceback
from typing import (
Any,
AsyncIterator,
Dict,
List,
Optional,
)
import torch
from exllamav3 import (
AsyncGenerator,
AsyncJob,
Cache,
Config,
Model,
Tokenizer,
)
from exllamav3.cache import CacheLayer_quant
from loguru import logger
from backends.base_model_container import BaseModelContainer
from backends.exllamav3.sampler import ExllamaV3SamplerBuilder
from common.concurrency import iterate_in_threadpool
from common.gen_logging import (
log_generation_params,
log_metrics,
)
from common.hardware import hardware_supports_flash_attn
from common.health import HealthManager
from common.multimodal import MultimodalEmbeddingWrapper
from common.sampling import BaseSamplerRequest
from common.templating import PromptTemplate, find_prompt_template
from common.transformers_utils import GenerationConfig, TokenizerConfig
from common.utils import coalesce, unwrap
from endpoints.core.types.model import ModelCard, ModelCardParameters
class ExllamaV3Container(BaseModelContainer):
"""Abstract base class for model containers."""
# Exposed model information
model_dir: pathlib.Path = pathlib.Path("models")
prompt_template: Optional[PromptTemplate] = None
generation_config: Optional[GenerationConfig] = None
# Load synchronization
# The bool is a master switch for accepting requests
# The lock keeps load tasks sequential
# The condition notifies any waiting tasks
active_job_ids: Dict[str, Any] = {}
loaded: bool = False
load_lock: asyncio.Lock = asyncio.Lock()
load_condition: asyncio.Condition = asyncio.Condition()
# Exl3 vars
model: Optional[Model]
cache: Optional[Cache]
draft_model: Optional[Model]
draft_cache: Optional[Cache]
tokenizer: Optional[Tokenizer]
config: Optional[Config]
draft_config: Optional[Config]
generator: Optional[AsyncGenerator]
tokenizer_config: Optional[TokenizerConfig]
# Class-specific vars
gpu_split: List[float] | None = None
gpu_split_auto: bool = True
autosplit_reserve: List[float] = [96 / 1024]
use_tp: bool = False
max_seq_len: int = 4096
cache_size: int = 4096
cache_mode: str = "FP16"
draft_cache_mode: str = "FP16"
chunk_size: int = 2048
max_batch_size: Optional[int] = None
# Required methods
@classmethod
async def create(cls, model_directory: pathlib.Path, **kwargs):
"""
Asynchronously creates and initializes a model container instance.
Args:
model_directory: Path to the model files.
**kwargs: Backend-specific configuration options.
Returns:
An instance of the implementing class.
"""
self = cls()
self.model = None
self.cache = None
self.draft_model = None
self.draft_cache = None
self.tokenizer = None
self.config = None
self.draft_config = None
self.generator = None
self.tokenizer_config = None
logger.warning(
"ExllamaV3 is currently in an alpha state. "
"Please note that all config options may not work."
)
self.model_dir = model_directory
self.config = Config.from_directory(str(model_directory.resolve()))
self.model = Model.from_config(self.config)
self.tokenizer = Tokenizer.from_config(self.config)
# Load generation config overrides
generation_config_path = model_directory / "generation_config.json"
if generation_config_path.exists():
try:
self.generation_config = await GenerationConfig.from_file(
model_directory
)
except Exception:
logger.error(traceback.format_exc())
logger.warning(
"Skipping generation config load because of an unexpected error."
)
# Load tokenizer config overrides
tokenizer_config_path = model_directory / "tokenizer_config.json"
if tokenizer_config_path.exists():
try:
self.tokenizer_config = await TokenizerConfig.from_file(model_directory)
except Exception:
logger.error(traceback.format_exc())
logger.warning(
"Skipping tokenizer config load because of an unexpected error."
)
# Fallback to 4096 since exl3 can't fetch from HF's config.json
self.max_seq_len = unwrap(kwargs.get("max_seq_len"), 4096)
# Prepare the draft model config if necessary
draft_args = unwrap(kwargs.get("draft_model"), {})
draft_model_name = draft_args.get("draft_model_name")
self.use_draft_model = draft_args and draft_model_name
# Always disable draft if params are incorrectly configured
if draft_args and draft_model_name is None:
logger.warning(
"Draft model is disabled because a model name "
"wasn't provided. Please check your config.yml!"
)
self.use_draft_model = False
if self.use_draft_model:
draft_model_path = pathlib.Path(
unwrap(draft_args.get("draft_model_dir"), "models")
)
draft_model_path = draft_model_path / draft_model_name
self.draft_gpu_split = unwrap(draft_args.get("draft_gpu_split"), [])
self.draft_model_dir = draft_model_path
self.draft_config = Config.from_directory(str(draft_model_path.resolve()))
self.draft_model = Model.from_config(self.draft_config)
logger.info(f"Using draft model: {str(draft_model_path.resolve())}")
else:
self.draft_model = None
self.draft_cache = None
# Turn off GPU split if the user is using 1 GPU
gpu_count = torch.cuda.device_count()
gpu_split_auto = unwrap(kwargs.get("gpu_split_auto"), True)
gpu_split = unwrap(kwargs.get("gpu_split"), None)
gpu_device_list = list(range(0, gpu_count))
# Set GPU split options
if gpu_count == 1:
self.gpu_split_auto = False
logger.info("Disabling GPU split because one GPU is in use.")
else:
# TODO: Set tensor parallel
# Set GPU split options
# Enable manual GPU split if provided
if gpu_split:
self.gpu_split = gpu_split
gpu_device_list = [
device_idx
for device_idx, memory in enumerate(self.gpu_split)
if memory > 0
]
elif gpu_split_auto and not self.use_tp:
# Otherwise fallback to autosplit settings
self.gpu_split_auto = gpu_split_auto
autosplit_reserve_megabytes = unwrap(
kwargs.get("autosplit_reserve"), [96]
)
# Reserve VRAM for each GPU
self.autosplit_reserve = [
value / 1024 for value in autosplit_reserve_megabytes
]
if not hardware_supports_flash_attn(gpu_device_list):
gpu_unsupported_message = (
"Unable to run ExllamaV3 because an unsupported GPU is "
"found in this configuration. \n"
"All GPUs must be ampere "
"(30 series) or newer. AMD GPUs are not supported."
)
logger.warning(gpu_unsupported_message)
raise RuntimeError(gpu_unsupported_message)
# Cache
user_cache_size = unwrap(kwargs.get("cache_size"), self.max_seq_len)
self.cache_size = self.adjust_cache_size(user_cache_size)
self.cache_mode = unwrap(kwargs.get("cache_mode"), "FP16")
self.cache = self.create_cache(self.cache_mode, self.model)
# Draft cache
if self.use_draft_model:
# Set draft cache mode
self.draft_cache_mode = unwrap(draft_args.get("draft_cache_mode"), "FP16")
self.draft_cache = self.create_cache(
self.draft_cache_mode, self.draft_model
)
# Max batch size
self.max_batch_size = unwrap(kwargs.get("max_batch_size"), 256)
# Make sure chunk size is >= 256, keep near or below max seq len
user_chunk_size = unwrap(kwargs.get("chunk_size"), 2048)
self.chunk_size = self.adjust_chunk_size(user_chunk_size)
# Template setup
self.prompt_template = await find_prompt_template(
kwargs.get("prompt_template"), model_directory
)
# Catch all for template lookup errors
if self.prompt_template:
logger.info(
f'Using template "{self.prompt_template.name}" for chat completions.'
)
else:
logger.warning(
"Chat completions are disabled because a prompt "
"template wasn't provided or auto-detected."
)
return self
def adjust_cache_size(self, cache_size):
if cache_size < self.max_seq_len:
logger.warning(
f"The given cache_size ({cache_size}) is smaller than the "
"desired context length.\n"
"Overriding cache_size to max_seq_len. "
)
cache_size = self.max_seq_len
# Enforce a multiple of 256 for cache size
# Overestimate to ensure that the cache isn't below max_seq_len
cache_remainder = cache_size % 256
if cache_remainder != 0:
rounded_cache_size = int(256 * ((cache_size - cache_remainder) / 256 + 1))
logger.warning(
f"The given cache size ({cache_size}) is "
"not a multiple of 256.\n"
"Overriding cache_size with an overestimated value of "
f"{rounded_cache_size} tokens."
)
cache_size = rounded_cache_size
# Warn user if cache size may be inadequate for CFG
if cache_size < 2 * self.max_seq_len:
logger.warning(
f"The given cache_size ({cache_size}) is less than 2 * max_seq_len "
"and may be too small for requests using CFG. \n"
"Ignore this warning if you do not plan on using CFG."
)
return cache_size
def adjust_chunk_size(self, user_chunk_size: int):
chunk_size = sorted((256, user_chunk_size, self.max_seq_len))[1]
chunk_remainder = chunk_size % 256
if chunk_remainder != 0:
rounded_chunk_size = int(256 * ((chunk_size - chunk_remainder) / 256 + 1))
logger.warning(
f"The given chunk size ({chunk_size}) is "
"not a multiple of 256.\n"
"Overriding chunk_size with an overestimated value of "
f"{rounded_chunk_size} tokens."
)
chunk_size = rounded_chunk_size
return chunk_size
def create_cache(self, raw_cache_mode: str, model: Model):
# Cast exl2 types to exl3
match raw_cache_mode:
case "Q4":
raw_cache_mode = "4,4"
case "Q6":
raw_cache_mode = "6,6"
case "Q8":
raw_cache_mode = "8,8"
split_cache_mode = re.search(r"^([2-8])\s*,\s*([2-8])$", raw_cache_mode)
if split_cache_mode:
draft_k_bits = int(split_cache_mode.group(1))
draft_v_bits = int(split_cache_mode.group(2))
cache = Cache(
model,
max_num_tokens=self.cache_size,
layer_type=CacheLayer_quant,
k_bits=draft_k_bits,
v_bits=draft_v_bits,
)
else:
cache = Cache(model, max_num_tokens=self.cache_size)
return cache
def model_info(self) -> ModelCard:
"""
Returns a dictionary of the current model's configuration parameters.
Returns:
Model parameters provided by the backend
"""
model_params = ModelCardParameters(
max_seq_len=self.max_seq_len,
cache_size=self.cache_size,
max_batch_size=self.max_batch_size,
cache_mode=self.cache_mode,
chunk_size=self.chunk_size,
use_vision=self.use_vision,
)
if self.prompt_template:
model_params.prompt_template = self.prompt_template.name
model_params.prompt_template_content = self.prompt_template.raw_template
model_card = ModelCard(
id=self.model_dir.name,
parameters=model_params,
)
return model_card
async def wait_for_jobs(self, skip_wait: bool = False):
"""
Polling to wait for any active generation jobs to complete.
Args:
skip_wait: If True, cancel jobs immediately instead of waiting.
"""
if not self.generator:
return
# Immediately abort all jobs if asked
if skip_wait:
logger.warning(
"Immediately terminating all jobs. "
"Clients will have their requests cancelled.\n"
)
for job in self.active_job_ids.values():
if job:
await job.cancel()
while len(self.active_job_ids) > 0:
await asyncio.sleep(0.01)
async def load(self, progress_callback=None, **kwargs):
"""
Loads the model into memory.
Args:
progress_callback: Optional callback for progress updates.
**kwargs: Additional loading options.
"""
async for _ in self.load_gen(progress_callback):
pass
async def load_gen(self, progress_callback=None, **kwargs):
"""
Loads the model into memory, yielding progress updates.
Args:
progress_callback: Optional callback for progress updates.
**kwargs: Additional loading options.
Yields:
Progress updates
"""
try:
await self.load_lock.acquire()
# Wait for existing generation jobs to finish
await self.wait_for_jobs(kwargs.get("skip_wait"))
generator = self.load_model_sync(progress_callback)
async for value in iterate_in_threadpool(generator):
yield value
# Create async generator
await self.create_generator()
# Clean up any extra vram usage from torch and cuda
# (Helps reduce VRAM bottlenecking on Windows)
gc.collect()
torch.cuda.empty_cache()
# Cleanup and update model load state
self.loaded = True
logger.info("Model successfully loaded.")
finally:
self.load_lock.release()
async with self.load_condition:
self.load_condition.notify_all()
@torch.inference_mode()
def load_model_sync(self, progress_callback=None):
if self.use_draft_model:
for value in self.draft_model.load_gen(
reserve_per_device=self.autosplit_reserve,
callback=progress_callback,
):
if value:
yield value
for value in self.model.load_gen(
reserve_per_device=self.autosplit_reserve,
use_per_device=self.gpu_split,
callback=progress_callback,
):
if value:
yield value
async def create_generator(self):
"""Create and save a Exllama generator class."""
try:
# Don't acquire locks unless a model is loaded
if self.loaded:
await self.load_lock.acquire()
# Immediately cancel all jobs
await self.wait_for_jobs(skip_wait=True)
# Create new generator
self.generator = AsyncGenerator(
model=self.model,
cache=self.cache,
draft_model=self.draft_model,
draft_cache=self.draft_cache,
tokenizer=self.tokenizer,
max_batch_size=self.max_batch_size,
max_chunk_size=self.chunk_size,
)
# Update the state of the container var
if self.max_batch_size is None:
self.max_batch_size = self.generator.generator.max_batch_size
finally:
# This means the generator is being recreated
# The load lock is already released in the load function
if self.loaded:
self.load_lock.release()
async with self.load_condition:
self.load_condition.notify_all()
async def unload(self, loras_only: bool = False, **kwargs):
"""
Unloads the model and associated resources from memory.
Args:
loras_only: If True, only unload LoRAs.
**kwargs: Additional unloading options (e.g., shutdown).
"""
# Used when shutting down the server
do_shutdown = kwargs.get("shutdown")
try:
if not do_shutdown:
await self.load_lock.acquire()
# Wait for other jobs to finish
await self.wait_for_jobs(kwargs.get("skip_wait"))
self.model.unload()
self.model = None
self.config = None
self.cache = None
self.tokenizer = None
if self.use_draft_model:
self.draft_model.unload()
self.draft_model = None
self.draft_config = None
self.draft_cache = None
# Cleanup the generator from any pending jobs
if self.generator is not None:
await self.generator.close()
self.generator = None
gc.collect()
torch.cuda.empty_cache()
logger.info("Model unloaded.")
finally:
if not do_shutdown:
self.load_lock.release()
async with self.load_condition:
self.load_condition.notify_all()
def encode_tokens(self, text: str, **kwargs) -> List[int]:
"""
Encodes a string of text into a list of token IDs.
Args:
text: The input text string.
**kwargs: Backend-specific encoding options (e.g., add_bos_token).
Returns:
A list of integer token IDs.
"""
return (
self.tokenizer.encode(
text,
add_bos=unwrap(kwargs.get("add_bos_token"), True),
encode_special_tokens=unwrap(kwargs.get("encode_special_tokens"), True),
)
.flatten()
.tolist()
)
def decode_tokens(self, ids: List[int], **kwargs) -> str:
"""
Decodes a list of token IDs back into a string.
Args:
ids: A list of integer token IDs.
**kwargs: Backend-specific decoding options (e.g., decode_special_tokens).
Returns:
The decoded text string.
"""
ids = torch.tensor([ids])
return self.tokenizer.decode(
ids,
decode_special_tokens=unwrap(kwargs.get("decode_special_tokens"), True),
)[0]
def get_special_tokens(
self, add_bos_token: bool = True, ban_eos_token: bool = False
):
"""
Gets special tokens used by the model/tokenizer.
Args:
**kwargs: Options like add_bos_token, ban_eos_token.
Returns:
A dictionary mapping special token names (e.g., 'bos_token', 'eos_token')
to their string or ID representation.
"""
return {
"bos_token": self.tokenizer.bos_token if add_bos_token else "",
"eos_token": self.tokenizer.eos_token if not ban_eos_token else "",
"pad_token": self.tokenizer.pad_token,
"unk_token": self.tokenizer.unk_token,
}
async def generate(
self,
request_id: str,
prompt: str,
params: BaseSamplerRequest,
abort_event: Optional[asyncio.Event] = None,
mm_embeddings: Optional[MultimodalEmbeddingWrapper] = None,
) -> Dict[str, Any]:
"""
Generates a complete response for a given prompt and parameters.
Args:
request_id: Unique identifier for the generation request.
prompt: The input prompt string.
params: Sampling and generation parameters.
abort_event: An asyncio Event to signal cancellation.
mm_embeddings: Optional multimodal embeddings.
Returns:
A dictionary containing the generation info
"""
generations = []
async for generation in self.stream_generate(
request_id,
prompt,
params,
abort_event,
mm_embeddings,
):
generations.append(generation)
joined_generation = {
"text": "",
"prompt_tokens": 0,
"generation_tokens": 0,
"tool_calls": None,
"offset": [],
"token_probs": {},
"logprobs": [],
}
if generations:
# Get finish_reason first and then shift where -1 points to
if "finish_reason" in generations[-1]:
finish_reason_gen = generations.pop()
joined_generation["finish_reason"] = finish_reason_gen.get(
"finish_reason"
)
joined_generation["stop_str"] = finish_reason_gen.get("stop_str")
else:
joined_generation["finish_reason"] = "stop"
if len(generations) > 0:
for generation in generations:
joined_generation["text"] += unwrap(generation.get("text"), "")
joined_generation["offset"].append(unwrap(generation.get("offset"), -1))
joined_generation["token_probs"].update(
unwrap(generation.get("token_probs"), {})
)
# Include empty logprob dicts for index preservation
joined_generation["logprobs"].append(
unwrap(generation.get("logprobs"), {})
)
joined_generation["prompt_tokens"] = unwrap(
generations[-1].get("prompt_tokens"), 0
)
joined_generation["generated_tokens"] = unwrap(
generations[-1].get("generated_tokens"), 0
)
return joined_generation
async def stream_generate(
self,
request_id: str,
prompt: str,
params: BaseSamplerRequest,
abort_event: Optional[asyncio.Event] = None,
mm_embeddings: Optional[MultimodalEmbeddingWrapper] = None,
) -> AsyncIterator[Dict[str, Any]]:
"""
Generates a response iteratively (streaming) for a given prompt.
Args:
request_id: Unique identifier for the generation request.
prompt: The input prompt string.
params: Sampling and generation parameters.
abort_event: An asyncio Event to signal cancellation.
mm_embeddings: Optional multimodal embeddings.
Yields:
Generation chunks
"""
try:
# Wait for load lock to be freed before processing
# Mainly used for loras and other operations where the class is available
async with self.load_condition:
await self.load_condition.wait_for(lambda: not self.load_lock.locked())
# If the model is being unloaded, don't accept new requests
if not self.loaded:
raise RuntimeError(
"Model is being unloaded. Cannot process new generation requests."
)
# Mark that the job is running
self.active_job_ids[request_id] = None
# Yield from the internal generator
async for generation_chunk in self.generate_gen(
request_id=request_id,
prompt=prompt,
params=params,
abort_event=abort_event,
mm_embeddings=mm_embeddings,
):
yield generation_chunk
finally:
# Clean up and remove the job from active IDs
del self.active_job_ids[request_id]
def handle_finish_chunk(self, result: dict, generation: dict):
eos_reason = result.get("eos_reason")
stop_str = None
if eos_reason == "max_new_tokens":
finish_reason = "length"
else:
finish_reason = "stop"
# Grab stop string if stop was the reason
if eos_reason == "stop_token":
stop_str = result.get("eos_triggering_token_str")
elif eos_reason == "stop_string":
stop_str = result.get("eos_triggering_string")
finish_chunk = {
"prompt_tokens": generation.get("prompt_tokens"),
"generated_tokens": generation.get("generated_tokens"),
"finish_reason": finish_reason,
"stop_str": stop_str,
}
return finish_chunk
async def generate_gen(
self,
request_id: str,
prompt: str,
params: BaseSamplerRequest,
abort_event: Optional[asyncio.Event] = None,
mm_embeddings: Optional[MultimodalEmbeddingWrapper] = None,
):
"""
Create generator function for prompt completion.
for kwargs, check common/sampling.py
"""
chunk_tokens: torch.Tensor | tuple[torch.Tensor, torch.Tensor]
sampler_builder = ExllamaV3SamplerBuilder()
# Penalties
# Set penalty range
penalty_range = unwrap(params.penalty_range, self.max_seq_len)
# Exl3's version of including the entire context
if penalty_range < 0:
penalty_range = int(10e7)
# Always make sure the fallback is 0 if range < 0
# It's technically fine to use -1, but this just validates the passed
# fallback
# Always default to 0 if something goes wrong
if params.penalty_range < 0:
fallback_decay = 0
else:
fallback_decay = params.penalty_range
repetition_decay = coalesce(params.repetition_decay, fallback_decay, 0)
# Apply penalties to builder
sampler_builder.penalties(
params.repetition_penalty,
params.frequency_penalty,
params.presence_penalty,
penalty_range,
repetition_decay,
)
# Apply temperature first to builder
if not params.temperature_last:
sampler_builder.temperature(params.temperature)
# Apply alphabet samplers to builder
sampler_builder.top_k(params.top_k)
sampler_builder.top_p(params.top_p)
sampler_builder.min_p(params.min_p)
# Apply temperature last to builder
if params.temperature_last:
sampler_builder.temperature(params.temperature)
# Build the sampler
# Set greedy if temperature is 0
sampler = sampler_builder.build(params.temperature == 0)
# Dynamically scale penalty range to output tokens
# Only do this if freq/pres pen is enabled
# and the repetition range is -1
# TODO: This currently does not work in exl3
# auto_scale_penalty_range = (
# gen_settings.token_frequency_penalty != 0
# or gen_settings.token_presence_penalty != 0
# ) and gen_settings.token_repetition_range == -1
prompts = [prompt]
stop_conditions = params.stop
add_bos_token = unwrap(
params.add_bos_token, self.tokenizer_config.add_bos_token
)
# Fetch EOS tokens from generation_config if they exist
eos_tokens = (
self.generation_config.eos_tokens()
if self.generation_config
else [self.tokenizer.eos_token_id]
)
stop_conditions += eos_tokens
input_ids = [
self.tokenizer.encode(
prompt,
add_bos=add_bos_token,
encode_special_tokens=True,
)
for prompt in prompts
]
# The first index will always be the positive prompt
context_len = input_ids[0].size(dim=-1)
# Automatically set max_tokens to fill up the context
# This should be an OK default, but may be changed in the future
max_tokens = unwrap(
params.max_tokens,
self.max_seq_len - context_len,
)
if max_tokens < 1:
logger.warning("max_tokens must be a positive integer, setting to 1.")
max_tokens = 1
# Determine if the negative context or the context length is bigger
context_to_check = context_len
# Check total length of prompt against max context length
if context_to_check > self.max_seq_len:
preamble = "Prompt"
raise ValueError(
f"{preamble} length {context_to_check} is greater than "
f"max_seq_len {self.max_seq_len}"
)
generation = {}
job = AsyncJob(
self.generator,
sampler=sampler,
input_ids=self.tokenizer.encode(prompt, add_bos=False),
max_new_tokens=max_tokens,
stop_conditions=stop_conditions,
banned_strings=params.banned_strings,
)
generated_tokens = 0
full_response = ""
metrics_result = {}
# Get the generation status once it's ready
try:
async for result in job:
# Abort if the event is set while streaming
if abort_event and abort_event.is_set():
await job.cancel()
break
chunk = unwrap(result.get("text"), "")
if chunk:
chunk_tokens = result.get("token_ids", self.tokenizer.encode(chunk))
full_response += chunk
if isinstance(chunk_tokens, torch.Tensor):
generated_tokens += chunk_tokens.size(dim=0)
# Increase penalty range to generated token amount
# TODO:
# if auto_scale_penalty_range:
# gen_settings.token_repetition_range = generated_tokens
generation = {
"text": chunk,
"prompt_tokens": context_len,
"generated_tokens": generated_tokens,
"offset": len(full_response),
}
yield generation
if result.get("eos"):
generation = self.handle_finish_chunk(result, generation)
# Save the final result for metrics logging
metrics_result = result
yield generation
break
# Assign the active job to the request ID
self.active_job_ids[request_id] = job
except asyncio.CancelledError:
await job.cancel()
except Exception as ex:
# Create a new generator since the current state is broken
# No need to wait for this to finish
logger.error(
"FATAL ERROR with generation. "
"Attempting to recreate the generator. "
"If this fails, please restart the server.\n"
)
asyncio.ensure_future(self.create_generator())
await HealthManager.add_unhealthy_event(ex)
raise ex
finally:
# Log generation options to console
# Some options are too large, so log the args instead
log_generation_params(
request_id=request_id,
bos_token_id=self.tokenizer.bos_token_id,
eos_token_id=eos_tokens,
prompt=prompt,
**params.model_dump(exclude={"prompt"}),
# auto_scale_penalty_range=auto_scale_penalty_range, # TODO
)
# Log the metrics if present
if metrics_result:
log_metrics(
request_id,
metrics_result.get("time_enqueued"),
metrics_result.get("prompt_tokens"),
metrics_result.get("cached_tokens"),
metrics_result.get("time_prefill"),
metrics_result.get("new_tokens"),
metrics_result.get("time_generate"),
context_len,
self.max_seq_len,
)

View file

@ -0,0 +1,54 @@
from dataclasses import dataclass, field
from typing import List
from exllamav3.generator.sampler import (
CustomSampler,
SS_Temperature,
SS_RepP,
SS_PresFreqP,
SS_Argmax,
SS_MinP,
SS_TopK,
SS_TopP,
SS_Sample,
SS_Base,
)
@dataclass
class ExllamaV3SamplerBuilder:
"""
Custom sampler chain/stack for TabbyAPI
"""
stack: List[SS_Base] = field(default_factory=list)
def penalties(self, rep_p, freq_p, pres_p, penalty_range, rep_decay):
self.stack += [
SS_RepP(rep_p, penalty_range, rep_decay),
SS_PresFreqP(pres_p, freq_p, penalty_range, rep_decay),
]
def temperature(self, temp):
self.stack.append(SS_Temperature(temp))
def top_k(self, top_k):
self.stack.append(SS_TopK(top_k))
def top_p(self, top_p):
self.stack.append(SS_TopP(top_p))
def min_p(self, min_p):
self.stack.append(SS_MinP(min_p))
def greedy(self):
self.stack.append(SS_Argmax())
def build(self, greedy):
"""Builds the final sampler from stack."""
# Use greedy if temp is 0
if greedy:
return CustomSampler([SS_Argmax()])
else:
self.stack.append(SS_Sample())
return CustomSampler(self.stack)

View file

@ -1,6 +1,7 @@
from pydantic import (
BaseModel,
ConfigDict,
constr,
Field,
PrivateAttr,
field_validator,
@ -9,6 +10,7 @@ from typing import List, Literal, Optional, Union
CACHE_SIZES = Literal["FP16", "Q8", "Q6", "Q4"]
CACHE_TYPE = Union[CACHE_SIZES, constr(pattern=r"^[2-8]\s*,\s*[2-8]$")]
class Metadata(BaseModel):
@ -163,6 +165,13 @@ class ModelConfig(BaseConfigModel):
"Example: ['max_seq_len', 'cache_mode']."
),
)
backend: Optional[str] = Field(
None,
description=(
"Backend to use for this model (auto-detect if not specified)\n"
"Options: exllamav2, exllamav3"
),
)
max_seq_len: Optional[int] = Field(
None,
description=(
@ -186,7 +195,7 @@ class ModelConfig(BaseConfigModel):
"Not parsed for single GPU users."
),
)
autosplit_reserve: List[int] = Field(
autosplit_reserve: List[float] = Field(
[96],
description=(
"Reserve VRAM used for autosplit loading (default: 96 MB on GPU 0).\n"
@ -218,11 +227,13 @@ class ModelConfig(BaseConfigModel):
"or auto-calculate."
),
)
cache_mode: Optional[CACHE_SIZES] = Field(
cache_mode: Optional[CACHE_TYPE] = Field(
"FP16",
description=(
"Enable different cache modes for VRAM savings (default: FP16).\n"
f"Possible values: {str(CACHE_SIZES)[15:-1]}."
f"Possible values for exllamav2: {str(CACHE_SIZES)[15:-1]}.\n"
"For exllamav3, specify the pair k_bits,v_bits where k_bits and v_bits "
"are integers from 2-8 (i.e. 8,8)."
),
)
cache_size: Optional[int] = Field(

20
common/hardware.py Normal file
View file

@ -0,0 +1,20 @@
import torch
def hardware_supports_flash_attn(gpu_device_list: list[int]):
"""
Check whether all GPUs in list support FA2
Compute capability < 8 is not supported by FA2
AMD is also unsupported until ROCm updates its FA2 fork
"""
min_compute_capability = min(
torch.cuda.get_device_capability(device=device_idx)[0]
for device_idx in gpu_device_list
)
if torch.version.hip or min_compute_capability < 8:
return False
else:
return True

View file

@ -10,23 +10,34 @@ from enum import Enum
from fastapi import HTTPException
from loguru import logger
from ruamel.yaml import YAML
from typing import Optional
from typing import Dict, Optional
from backends.base_model_container import BaseModelContainer
from common.logger import get_loading_progress_bar
from common.networking import handle_request_error
from common.tabby_config import config
from common.optional_dependencies import dependencies
from common.transformers_utils import HuggingFaceConfig
from common.utils import unwrap
# Global variables for model container
container: Optional[BaseModelContainer] = None
embeddings_container = None
# FIXME: Possibly use this solely when creating the model
_BACKEND_REGISTRY: Dict[str, BaseModelContainer] = {}
if dependencies.exllamav2:
from backends.exllamav2.model import ExllamaV2Container
_BACKEND_REGISTRY["exllamav2"] = ExllamaV2Container
if dependencies.exllamav3:
from backends.exllamav3.model import ExllamaV3Container
_BACKEND_REGISTRY["exllamav3"] = ExllamaV3Container
if dependencies.extras:
from backends.infinity.model import InfinityContainer
@ -46,6 +57,24 @@ def load_progress(module, modules):
yield module, modules
async def detect_backend(model_path: pathlib.Path) -> str:
"""Determine the appropriate backend based on model files and configuration."""
try:
hf_config = await HuggingFaceConfig.from_directory(model_path)
quant_method = hf_config.quant_method()
if quant_method == "exl3":
return "exllamav3"
else:
return "exllamav2"
except Exception as exc:
raise ValueError(
"Failed to read the model's config.json. "
f"Please check your model directory at {model_path}."
) from exc
async def apply_inline_overrides(model_dir: pathlib.Path, **kwargs):
"""Sets overrides from a model folder's config yaml."""
@ -113,9 +142,28 @@ async def load_model_gen(model_path: pathlib.Path, **kwargs):
kwargs = {**config.model_defaults, **kwargs}
kwargs = await apply_inline_overrides(model_path, **kwargs)
# Create a new container
new_container = await ExllamaV2Container.create(
model_path.resolve(), False, **kwargs
# Create a new container and check if the right dependencies are installed
backend_name = unwrap(
kwargs.get("backend"), await detect_backend(model_path)
).lower()
container_class = _BACKEND_REGISTRY.get(backend_name)
if not container_class:
available_backends = list(_BACKEND_REGISTRY.keys())
if backend_name in available_backends:
raise ValueError(
f"Backend '{backend_name}' selected, but required dependencies "
"are not installed."
)
else:
raise ValueError(
f"Invalid backend '{backend_name}'. "
f"Available backends: {available_backends}"
)
logger.info(f"Using backend {backend_name}")
new_container: BaseModelContainer = await container_class.create(
model_path.resolve(), **kwargs
)
# Add possible types of models that can be loaded
@ -124,7 +172,7 @@ async def load_model_gen(model_path: pathlib.Path, **kwargs):
if new_container.use_vision:
model_type.insert(0, ModelType.VISION)
if new_container.draft_config:
if new_container.use_draft_model:
model_type.insert(0, ModelType.DRAFT)
load_status = new_container.load_gen(load_progress, **kwargs)

View file

@ -13,6 +13,7 @@ class DependenciesModel(BaseModel):
torch: bool
exllamav2: bool
exllamav3: bool
flash_attn: bool
infinity_emb: bool
sentence_transformers: bool
@ -25,7 +26,7 @@ class DependenciesModel(BaseModel):
@computed_field
@property
def inference(self) -> bool:
return self.torch and self.exllamav2 and self.flash_attn
return self.torch and (self.exllamav2 or (self.exllamav3 and self.flash_attn))
def is_installed(package_name: str) -> bool:

View file

@ -205,7 +205,7 @@ class BaseSamplerRequest(BaseModel):
)
add_bos_token: Optional[bool] = Field(
default_factory=lambda: get_default_sampler_value("add_bos_token", True)
default_factory=lambda: get_default_sampler_value("add_bos_token")
)
ban_eos_token: Optional[bool] = Field(
@ -215,11 +215,6 @@ class BaseSamplerRequest(BaseModel):
examples=[False],
)
skip_special_tokens: Optional[bool] = Field(
default_factory=lambda: get_default_sampler_value("skip_special_tokens", True),
examples=[True],
)
logit_bias: Optional[Dict[int, float]] = Field(
default_factory=lambda: get_default_sampler_value("logit_bias"),
examples=[{"1": 10, "2": 50}],

View file

@ -1,7 +1,7 @@
import aiofiles
import json
import pathlib
from typing import List, Optional, Union
from typing import Dict, List, Optional, Union
from pydantic import BaseModel
@ -42,8 +42,10 @@ class HuggingFaceConfig(BaseModel):
Will be expanded as needed.
"""
quantization_config: Optional[Dict] = None
@classmethod
async def from_file(cls, model_directory: pathlib.Path):
async def from_directory(cls, model_directory: pathlib.Path):
"""Create an instance from a generation config file."""
hf_config_path = model_directory / "config.json"
@ -54,6 +56,14 @@ class HuggingFaceConfig(BaseModel):
hf_config_dict = json.loads(contents)
return cls.model_validate(hf_config_dict)
def quant_method(self):
"""Wrapper method to fetch quant type"""
if isinstance(self.quantization_config, Dict):
return self.quantization_config.get("quant_method")
else:
return None
class TokenizerConfig(BaseModel):
"""

View file

@ -74,6 +74,10 @@ model:
# Example: ['max_seq_len', 'cache_mode'].
use_as_default: []
# Backend to use for this model (auto-detect if not specified)
# Options: exllamav2, exllamav3
backend:
# Max sequence length (default: Empty).
# Fetched from the model's base sequence length in config.json by default.
max_seq_len:
@ -110,7 +114,8 @@ model:
rope_alpha:
# Enable different cache modes for VRAM savings (default: FP16).
# Possible values: 'FP16', 'Q8', 'Q6', 'Q4'.
# Possible values for exllamav2: 'FP16', 'Q8', 'Q6', 'Q4'.
# For exllamav3, specify the pair k_bits,v_bits where k_bits and v_bits are integers from 2-8 (i.e. 8,8).
cache_mode: FP16
# Size of the prompt cache to allocate (default: max_seq_len).
@ -160,7 +165,8 @@ draft_model:
draft_rope_alpha:
# Cache mode for draft models to save VRAM (default: FP16).
# Possible values: 'FP16', 'Q8', 'Q6', 'Q4'.
# Possible values for exllamav2: 'FP16', 'Q8', 'Q6', 'Q4'.
# For exllamav3, specify the pair k_bits,v_bits where k_bits and v_bits are integers from 2-8 (i.e. 8,8).
draft_cache_mode: FP16
# An integer array of GBs of VRAM to split between GPUs (default: []).

View file

@ -53,8 +53,12 @@ async def _stream_collector(data: GenerateRequest, request: Request):
logger.info(f"Received Kobold generation request {data.genkey}")
generator = model.container.stream_generate(
request_id=data.genkey, abort_event=abort_event, **data.model_dump()
request_id=data.genkey,
prompt=data.prompt,
params=data,
abort_event=abort_event,
)
async for generation in generator:
if disconnect_task.done():
abort_event.set()

View file

@ -82,10 +82,13 @@ class ChatCompletionRequest(CommonCompletionRequest):
tool_call_end: SkipJsonSchema[Optional[str]] = None
tool_call_schema: SkipJsonSchema[Optional[dict]] = tool_call_schema
# Chat completions requests do not have a BOS token preference. Backend
# respects the tokenization config for the individual model.
add_bos_token: Optional[bool] = None
@field_validator("add_bos_token", mode="after")
def force_bos_token(cls, v):
"""Always disable add_bos_token with chat completions."""
return None

View file

@ -81,7 +81,10 @@ class ModelLoadRequest(BaseModel):
)
# Config arguments
backend: Optional[str] = Field(
description="Backend to use",
default="exllamav2",
)
max_seq_len: Optional[int] = Field(
description="Leave this blank to use the model's base sequence length",
default=None,

20
main.py
View file

@ -15,12 +15,11 @@ from common.auth import load_auth_keys
from common.actions import run_subcommand
from common.logger import setup_logger
from common.networking import is_port_in_use
from common.optional_dependencies import dependencies
from common.signals import signal_handler
from common.tabby_config import config
from endpoints.server import start_api
from backends.exllamav2.version import check_exllama_version
async def entrypoint_async():
"""Async entry function for program startup"""
@ -139,8 +138,21 @@ def entrypoint(
"UNSAFE: Skipping ExllamaV2 version check.\n"
"If you aren't a developer, please keep this off!"
)
else:
check_exllama_version()
elif not dependencies.inference:
install_message = (
f"ERROR: Inference dependencies for TabbyAPI are not installed.\n"
"Please update your environment by running an update script "
"(update_scripts/"
f"update_deps.{'bat' if platform.system() == 'Windows' else 'sh'})\n\n"
"Or you can manually run a requirements update "
"using the following command:\n\n"
"For CUDA 12.1:\n"
"pip install --upgrade .[cu121]\n\n"
"For ROCm:\n"
"pip install --upgrade .[amd]\n\n"
)
raise SystemExit(install_message)
# Enable CUDA malloc backend
if config.developer.cuda_malloc_backend:

View file

@ -77,6 +77,16 @@ cu121 = [
"exllamav2 @ https://github.com/turboderp-org/exllamav2/releases/download/v0.2.9/exllamav2-0.2.9+cu128.torch2.7.0-cp311-cp311-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.11'",
"exllamav2 @ https://github.com/turboderp-org/exllamav2/releases/download/v0.2.9/exllamav2-0.2.9+cu128.torch2.7.0-cp310-cp310-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.10'",
# Exl3
"exllamav3 @ https://github.com/turboderp-org/exllamav3/releases/download/v0.0.1/exllamav3-0.0.1+cu128.torch2.7.0-cp313-cp313-win_amd64.whl ; platform_system == 'Windows' and python_version == '3.13'",
"exllamav3 @ https://github.com/turboderp-org/exllamav3/releases/download/v0.0.1/exllamav3-0.0.1+cu128.torch2.7.0-cp312-cp312-win_amd64.whl ; platform_system == 'Windows' and python_version == '3.12'",
"exllamav3 @ https://github.com/turboderp-org/exllamav3/releases/download/v0.0.1/exllamav3-0.0.1+cu128.torch2.7.0-cp311-cp311-win_amd64.whl ; platform_system == 'Windows' and python_version == '3.11'",
"exllamav3 @ https://github.com/turboderp-org/exllamav3/releases/download/v0.0.1/exllamav3-0.0.1+cu128.torch2.7.0-cp310-cp310-win_amd64.whl ; platform_system == 'Windows' and python_version == '3.10'",
"exllamav3 @ https://github.com/turboderp-org/exllamav3/releases/download/v0.0.1/exllamav3-0.0.1+cu128.torch2.7.0-cp313-cp313-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.13'",
"exllamav3 @ https://github.com/turboderp-org/exllamav3/releases/download/v0.0.1/exllamav3-0.0.1+cu128.torch2.7.0-cp312-cp312-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.12'",
"exllamav3 @ https://github.com/turboderp-org/exllamav3/releases/download/v0.0.1/exllamav3-0.0.1+cu128.torch2.7.0-cp311-cp311-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.11'",
"exllamav3 @ https://github.com/turboderp-org/exllamav3/releases/download/v0.0.1/exllamav3-0.0.1+cu128.torch2.7.0-cp310-cp310-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.10'",
# Windows FA2 from https://github.com/kingbri1/flash-attention/releases
"flash_attn @ https://github.com/kingbri1/flash-attention/releases/download/v2.7.4.post1/flash_attn-2.7.4.post1+cu128torch2.7.0cxx11abiFALSE-cp313-cp313-win_amd64.whl ; platform_system == 'Windows' and python_version == '3.13'",
"flash_attn @ https://github.com/kingbri1/flash-attention/releases/download/v2.7.4.post1/flash_attn-2.7.4.post1+cu128torch2.7.0cxx11abiFALSE-cp312-cp312-win_amd64.whl ; platform_system == 'Windows' and python_version == '3.12'",

View file

@ -131,14 +131,11 @@ mirostat_eta:
# MARK: Token options
add_bos_token:
override: true
override:
force: false
ban_eos_token:
override: false
force: false
skip_special_tokens:
override: true
force: false
logit_bias:
override:
force: false