diff --git a/backends/base_model_container.py b/backends/base_model_container.py index 5c79867..10eae0d 100644 --- a/backends/base_model_container.py +++ b/backends/base_model_container.py @@ -25,6 +25,10 @@ class BaseModelContainer(abc.ABC): prompt_template: Optional[PromptTemplate] = None generation_config: Optional[GenerationConfig] = None + # Optional features + use_draft_model: bool = False + use_vision: bool = False + # Load synchronization # The bool is a master switch for accepting requests # The lock keeps load tasks sequential @@ -65,7 +69,7 @@ class BaseModelContainer(abc.ABC): # NOTE: Might be an optional method @abc.abstractmethod - async def load_gen(self, progress_callback=None, **kwargs) -> AsyncIterator[Any]: + async def load_gen(self, progress_callback=None, **kwargs): """ Loads the model into memory, yielding progress updates. @@ -134,57 +138,6 @@ class BaseModelContainer(abc.ABC): pass - @abc.abstractmethod - async def generate( - self, - request_id: str, - prompt: str, - params: BaseSamplerRequest, - abort_event: Optional[asyncio.Event] = None, - mm_embeddings: Optional[MultimodalEmbeddingWrapper] = None, - ) -> Dict[str, Any]: - """ - Generates a complete response for a given prompt and parameters. - - Args: - request_id: Unique identifier for the generation request. - prompt: The input prompt string. - params: Sampling and generation parameters. - abort_event: An asyncio Event to signal cancellation. - mm_embeddings: Optional multimodal embeddings. - - Returns: - A dictionary containing the generation info - """ - - pass - - @abc.abstractmethod - async def stream_generate( - self, - request_id: str, - prompt: str, - params: BaseSamplerRequest, - abort_event: Optional[asyncio.Event] = None, - mm_embeddings: Optional[MultimodalEmbeddingWrapper] = None, - ) -> AsyncIterator[Dict[str, Any]]: - """ - Generates a response iteratively (streaming) for a given prompt. - - Args: - request_id: Unique identifier for the generation request. - prompt: The input prompt string. - params: Sampling and generation parameters. - abort_event: An asyncio Event to signal cancellation. - mm_embeddings: Optional multimodal embeddings. - - Yields: - Generation chunks - """ - - if False: - yield - @abc.abstractmethod def model_info(self) -> ModelCard: """ @@ -239,3 +192,54 @@ class BaseModelContainer(abc.ABC): """ return [] + + @abc.abstractmethod + async def generate( + self, + request_id: str, + prompt: str, + params: BaseSamplerRequest, + abort_event: Optional[asyncio.Event] = None, + mm_embeddings: Optional[MultimodalEmbeddingWrapper] = None, + ) -> Dict[str, Any]: + """ + Generates a complete response for a given prompt and parameters. + + Args: + request_id: Unique identifier for the generation request. + prompt: The input prompt string. + params: Sampling and generation parameters. + abort_event: An asyncio Event to signal cancellation. + mm_embeddings: Optional multimodal embeddings. + + Returns: + A dictionary containing the generation info + """ + + pass + + @abc.abstractmethod + async def stream_generate( + self, + request_id: str, + prompt: str, + params: BaseSamplerRequest, + abort_event: Optional[asyncio.Event] = None, + mm_embeddings: Optional[MultimodalEmbeddingWrapper] = None, + ) -> AsyncIterator[Dict[str, Any]]: + """ + Generates a response iteratively (streaming) for a given prompt. + + Args: + request_id: Unique identifier for the generation request. + prompt: The input prompt string. + params: Sampling and generation parameters. 
+ abort_event: An asyncio Event to signal cancellation. + mm_embeddings: Optional multimodal embeddings. + + Yields: + Generation chunks + """ + + if False: + yield diff --git a/backends/exllamav2/model.py b/backends/exllamav2/model.py index b821d1a..65689f4 100644 --- a/backends/exllamav2/model.py +++ b/backends/exllamav2/model.py @@ -33,11 +33,7 @@ from backends.exllamav2.grammar import ( ExLlamaV2Grammar, clear_grammar_func_cache, ) -from backends.exllamav2.utils import ( - exllama_disabled_flash_attn, - hardware_supports_flash_attn, - supports_paged_attn, -) +from backends.exllamav2.utils import exllama_disabled_flash_attn from backends.exllamav2.vision import clear_image_embedding_cache from common.concurrency import iterate_in_threadpool from common.gen_logging import ( @@ -46,6 +42,7 @@ from common.gen_logging import ( log_prompt, log_response, ) +from common.hardware import hardware_supports_flash_attn from common.health import HealthManager from common.multimodal import MultimodalEmbeddingWrapper from common.sampling import BaseSamplerRequest @@ -64,16 +61,19 @@ class ExllamaV2Container(BaseModelContainer): # Exl2 vars config: Optional[ExLlamaV2Config] = None - draft_config: Optional[ExLlamaV2Config] = None model: Optional[ExLlamaV2] = None - draft_model: Optional[ExLlamaV2] = None cache: Optional[ExLlamaV2Cache] = None - draft_cache: Optional[ExLlamaV2Cache] = None tokenizer: Optional[ExLlamaV2Tokenizer] = None generator: Optional[ExLlamaV2DynamicGeneratorAsync] = None prompt_template: Optional[PromptTemplate] = None paged: bool = True + # Draft model vars + use_draft_model: bool = False + draft_config: Optional[ExLlamaV2Config] = None + draft_model: Optional[ExLlamaV2] = None + draft_cache: Optional[ExLlamaV2Cache] = None + # Internal config vars cache_size: int = None cache_mode: str = "FP16" @@ -100,7 +100,7 @@ class ExllamaV2Container(BaseModelContainer): load_condition: asyncio.Condition = asyncio.Condition() @classmethod - async def create(cls, model_directory: pathlib.Path, quiet=False, **kwargs): + async def create(cls, model_directory: pathlib.Path, **kwargs): """ Primary asynchronous initializer for model container. @@ -110,8 +110,6 @@ class ExllamaV2Container(BaseModelContainer): # Create a new instance as a "fake self" self = cls() - self.quiet = quiet - # Initialize config self.config = ExLlamaV2Config() self.model_dir = model_directory @@ -162,7 +160,7 @@ class ExllamaV2Container(BaseModelContainer): # Prepare the draft model config if necessary draft_args = unwrap(kwargs.get("draft_model"), {}) draft_model_name = draft_args.get("draft_model_name") - enable_draft = draft_args and draft_model_name + self.use_draft_model = draft_args and draft_model_name # Always disable draft if params are incorrectly configured if draft_args and draft_model_name is None: @@ -170,9 +168,9 @@ class ExllamaV2Container(BaseModelContainer): "Draft model is disabled because a model name " "wasn't provided. Please check your config.yml!" 
) - enable_draft = False + self.use_draft_model = False - if enable_draft: + if self.use_draft_model: self.draft_config = ExLlamaV2Config() draft_model_path = pathlib.Path( unwrap(draft_args.get("draft_model_dir"), "models") @@ -189,6 +187,15 @@ class ExllamaV2Container(BaseModelContainer): # Get cache mode self.cache_mode = unwrap(kwargs.get("cache_mode"), "FP16") + # Catch exllamav3 cache_mode + if not self.cache_mode.startswith("Q"): + logger.warning( + f"Provided cache mode '{self.cache_mode}' is not a " + "valid choice for exllamav2, please check your settings. " + "Defaulting to FP16." + ) + self.cache_mode = "FP16" + # Turn off GPU split if the user is using 1 GPU gpu_count = torch.cuda.device_count() gpu_split_auto = unwrap(kwargs.get("gpu_split_auto"), True) @@ -276,11 +283,20 @@ class ExllamaV2Container(BaseModelContainer): # Check whether the user's configuration supports flash/paged attention # Also check if exl2 has disabled flash attention - if ( - exllama_disabled_flash_attn(self.config.no_flash_attn) - or not hardware_supports_flash_attn(gpu_device_list) - or not supports_paged_attn() - ): + if exllama_disabled_flash_attn( + self.config.no_flash_attn + ) or not hardware_supports_flash_attn(gpu_device_list): + gpu_unsupported_message = ( + "An unsupported GPU is found in this configuration. " + "Switching to compatibility mode. \n" + "This disables parallel batching " + "and features that rely on it (ex. CFG). \n" + "To disable compatability mode, all GPUs must be ampere " + "(30 series) or newer. AMD GPUs are not supported." + ) + + logger.warning(gpu_unsupported_message) + self.config.no_flash_attn = True if self.draft_config: self.draft_config.no_flash_attn = True @@ -365,7 +381,7 @@ class ExllamaV2Container(BaseModelContainer): self.config.max_attention_size = chunk_size**2 # Set user-configured draft model values - if enable_draft: + if self.use_draft_model: self.draft_config.max_seq_len = self.config.max_seq_len self.draft_config.scale_pos_emb = unwrap( @@ -385,6 +401,15 @@ class ExllamaV2Container(BaseModelContainer): # Set draft cache mode self.draft_cache_mode = unwrap(draft_args.get("draft_cache_mode"), "FP16") + # Catch exllamav3 draft_cache_mode + if not self.draft_cache_mode.startswith("Q"): + logger.warning( + f"Provided draft cache mode '{self.draft_cache_mode}' is not a " + "valid choice for exllamav2, please check your settings. " + "Defaulting to FP16." 
+ ) + self.draft_cache_mode = "FP16" + # Edit the draft config size if chunk_size: self.draft_config.max_input_len = chunk_size @@ -531,8 +556,7 @@ class ExllamaV2Container(BaseModelContainer): # Load draft model if a config is present if self.draft_config: self.draft_model = ExLlamaV2(self.draft_config) - if not self.quiet: - logger.info("Loading draft model: " + self.draft_config.model_dir) + logger.info("Loading draft model: " + self.draft_config.model_dir) # Draft uses the autosplit loader, so create a cache that reflects this draft_cache_class = self.get_cache_class(self.draft_cache_mode) @@ -585,8 +609,7 @@ class ExllamaV2Container(BaseModelContainer): yield value self.model = ExLlamaV2(self.config) - if not self.quiet: - logger.info("Loading model: " + self.config.model_dir) + logger.info("Loading model: " + self.config.model_dir) # Get class of the model cache cache_class = self.get_cache_class(self.cache_mode) @@ -1350,7 +1373,7 @@ class ExllamaV2Container(BaseModelContainer): min_new_tokens=params.min_tokens, gen_settings=gen_settings, stop_conditions=stop_conditions, - decode_special_tokens=not params.skip_special_tokens, + decode_special_tokens=True, filters=grammar_handler.filters, filter_prefer_eos=bool(grammar_handler.filters), return_probs=params.logprobs > 0, diff --git a/backends/exllamav2/utils.py b/backends/exllamav2/utils.py index 0fd1fcc..1648c62 100644 --- a/backends/exllamav2/utils.py +++ b/backends/exllamav2/utils.py @@ -1,74 +1,6 @@ -import platform -import torch -from packaging import version -from importlib.metadata import PackageNotFoundError, version as package_version from loguru import logger -def hardware_supports_flash_attn(gpu_device_list: list[int]): - """ - Check whether all GPUs in list support FA2 - - Compute capability < 8 is not supported by FA2 - AMD is also unsupported until ROCm updates its FA2 fork - """ - - # Logged message if unsupported - unsupported_message = ( - "An unsupported GPU is found in this configuration. " - "Switching to compatibility mode. \n" - "This disables parallel batching " - "and features that rely on it (ex. CFG). \n" - "To disable compatability mode, all GPUs must be ampere " - "(30 series) or newer. AMD GPUs are not supported." - ) - - min_compute_capability = min( - torch.cuda.get_device_capability(device=device_idx)[0] - for device_idx in gpu_device_list - ) - - if torch.version.hip or min_compute_capability < 8: - logger.warning(unsupported_message) - return False - else: - return True - - -def supports_paged_attn(): - """Check whether the user's flash-attn version supports paged mode""" - - # Logged message if unsupported - unsupported_message = ( - "Flash attention version >=2.5.7 " - "is required to use paged attention. " - "Switching to compatibility mode. \n" - "This disables parallel batching " - "and features that rely on it (ex. CFG). \n" - "Please upgrade your environment by running an update script " - "(update_scripts/" - f"update_deps.{'bat' if platform.system() == 'Windows' else 'sh'})\n\n" - "Or you can manually run a requirements update " - "using the following command:\n\n" - "For CUDA 12.1:\n" - "pip install --upgrade .[cu121]\n\n" - "NOTE: Windows users must use CUDA 12.x to use flash-attn." 
- ) - - required_version = version.parse("2.5.7") - try: - current_version = version.parse(package_version("flash-attn").split("+")[0]) - except PackageNotFoundError: - logger.warning(unsupported_message) - return False - - if current_version < required_version: - logger.warning(unsupported_message) - return False - else: - return True - - def exllama_disabled_flash_attn(no_flash_attn: bool): unsupported_message = ( "ExllamaV2 has disabled Flash Attention. \n" diff --git a/backends/exllamav2/version.py b/backends/exllamav2/version.py deleted file mode 100644 index 08d0bda..0000000 --- a/backends/exllamav2/version.py +++ /dev/null @@ -1,37 +0,0 @@ -import platform -from packaging import version -from importlib.metadata import version as package_version -from loguru import logger -from common.optional_dependencies import dependencies - - -def check_exllama_version(): - """Verifies the exllama version""" - - install_message = ( - "Please update your environment by running an update script " - "(update_scripts/" - f"update_deps.{'bat' if platform.system() == 'Windows' else 'sh'})\n\n" - "Or you can manually run a requirements update " - "using the following command:\n\n" - "For CUDA 12.1:\n" - "pip install --upgrade .[cu121]\n\n" - "For ROCm:\n" - "pip install --upgrade .[amd]\n\n" - ) - - if not dependencies.exllamav2: - raise SystemExit(("Exllamav2 is not installed.\n" + install_message)) - - required_version = version.parse("0.2.8") - current_version = version.parse(package_version("exllamav2").split("+")[0]) - - unsupported_message = ( - f"ERROR: TabbyAPI requires ExLlamaV2 {required_version} " - f"or greater. Your current version is {current_version}.\n" + install_message - ) - - if current_version < required_version: - raise SystemExit(unsupported_message) - else: - logger.info(f"ExllamaV2 version: {current_version}") diff --git a/backends/exllamav3/model.py b/backends/exllamav3/model.py new file mode 100644 index 0000000..c386a6e --- /dev/null +++ b/backends/exllamav3/model.py @@ -0,0 +1,964 @@ +import asyncio +import gc +import pathlib +import re +import traceback +from typing import ( + Any, + AsyncIterator, + Dict, + List, + Optional, +) + +import torch +from exllamav3 import ( + AsyncGenerator, + AsyncJob, + Cache, + Config, + Model, + Tokenizer, +) +from exllamav3.cache import CacheLayer_quant +from loguru import logger + +from backends.base_model_container import BaseModelContainer +from backends.exllamav3.sampler import ExllamaV3SamplerBuilder +from common.concurrency import iterate_in_threadpool +from common.gen_logging import ( + log_generation_params, + log_metrics, +) +from common.hardware import hardware_supports_flash_attn +from common.health import HealthManager +from common.multimodal import MultimodalEmbeddingWrapper +from common.sampling import BaseSamplerRequest +from common.templating import PromptTemplate, find_prompt_template +from common.transformers_utils import GenerationConfig, TokenizerConfig +from common.utils import coalesce, unwrap +from endpoints.core.types.model import ModelCard, ModelCardParameters + + +class ExllamaV3Container(BaseModelContainer): + """Abstract base class for model containers.""" + + # Exposed model information + model_dir: pathlib.Path = pathlib.Path("models") + prompt_template: Optional[PromptTemplate] = None + generation_config: Optional[GenerationConfig] = None + + # Load synchronization + # The bool is a master switch for accepting requests + # The lock keeps load tasks sequential + # The condition notifies any waiting tasks + 
active_job_ids: Dict[str, Any] = {} + loaded: bool = False + load_lock: asyncio.Lock = asyncio.Lock() + load_condition: asyncio.Condition = asyncio.Condition() + + # Exl3 vars + model: Optional[Model] + cache: Optional[Cache] + draft_model: Optional[Model] + draft_cache: Optional[Cache] + tokenizer: Optional[Tokenizer] + config: Optional[Config] + draft_config: Optional[Config] + generator: Optional[AsyncGenerator] + tokenizer_config: Optional[TokenizerConfig] + + # Class-specific vars + gpu_split: List[float] | None = None + gpu_split_auto: bool = True + autosplit_reserve: List[float] = [96 / 1024] + use_tp: bool = False + max_seq_len: int = 4096 + cache_size: int = 4096 + cache_mode: str = "FP16" + draft_cache_mode: str = "FP16" + chunk_size: int = 2048 + max_batch_size: Optional[int] = None + + # Required methods + @classmethod + async def create(cls, model_directory: pathlib.Path, **kwargs): + """ + Asynchronously creates and initializes a model container instance. + + Args: + model_directory: Path to the model files. + **kwargs: Backend-specific configuration options. + + Returns: + An instance of the implementing class. + """ + + self = cls() + + self.model = None + self.cache = None + self.draft_model = None + self.draft_cache = None + self.tokenizer = None + self.config = None + self.draft_config = None + self.generator = None + self.tokenizer_config = None + + logger.warning( + "ExllamaV3 is currently in an alpha state. " + "Please note that all config options may not work." + ) + + self.model_dir = model_directory + self.config = Config.from_directory(str(model_directory.resolve())) + self.model = Model.from_config(self.config) + self.tokenizer = Tokenizer.from_config(self.config) + + # Load generation config overrides + generation_config_path = model_directory / "generation_config.json" + if generation_config_path.exists(): + try: + self.generation_config = await GenerationConfig.from_file( + model_directory + ) + except Exception: + logger.error(traceback.format_exc()) + logger.warning( + "Skipping generation config load because of an unexpected error." + ) + + # Load tokenizer config overrides + tokenizer_config_path = model_directory / "tokenizer_config.json" + if tokenizer_config_path.exists(): + try: + self.tokenizer_config = await TokenizerConfig.from_file(model_directory) + except Exception: + logger.error(traceback.format_exc()) + logger.warning( + "Skipping tokenizer config load because of an unexpected error." + ) + + # Fallback to 4096 since exl3 can't fetch from HF's config.json + self.max_seq_len = unwrap(kwargs.get("max_seq_len"), 4096) + + # Prepare the draft model config if necessary + draft_args = unwrap(kwargs.get("draft_model"), {}) + draft_model_name = draft_args.get("draft_model_name") + self.use_draft_model = draft_args and draft_model_name + + # Always disable draft if params are incorrectly configured + if draft_args and draft_model_name is None: + logger.warning( + "Draft model is disabled because a model name " + "wasn't provided. Please check your config.yml!" 
+ ) + self.use_draft_model = False + + if self.use_draft_model: + draft_model_path = pathlib.Path( + unwrap(draft_args.get("draft_model_dir"), "models") + ) + draft_model_path = draft_model_path / draft_model_name + self.draft_gpu_split = unwrap(draft_args.get("draft_gpu_split"), []) + self.draft_model_dir = draft_model_path + self.draft_config = Config.from_directory(str(draft_model_path.resolve())) + self.draft_model = Model.from_config(self.draft_config) + logger.info(f"Using draft model: {str(draft_model_path.resolve())}") + else: + self.draft_model = None + self.draft_cache = None + + # Turn off GPU split if the user is using 1 GPU + gpu_count = torch.cuda.device_count() + gpu_split_auto = unwrap(kwargs.get("gpu_split_auto"), True) + gpu_split = unwrap(kwargs.get("gpu_split"), None) + gpu_device_list = list(range(0, gpu_count)) + + # Set GPU split options + if gpu_count == 1: + self.gpu_split_auto = False + logger.info("Disabling GPU split because one GPU is in use.") + else: + # TODO: Set tensor parallel + + # Set GPU split options + # Enable manual GPU split if provided + if gpu_split: + self.gpu_split = gpu_split + + gpu_device_list = [ + device_idx + for device_idx, memory in enumerate(self.gpu_split) + if memory > 0 + ] + elif gpu_split_auto and not self.use_tp: + # Otherwise fallback to autosplit settings + self.gpu_split_auto = gpu_split_auto + + autosplit_reserve_megabytes = unwrap( + kwargs.get("autosplit_reserve"), [96] + ) + + # Reserve VRAM for each GPU + self.autosplit_reserve = [ + value / 1024 for value in autosplit_reserve_megabytes + ] + + if not hardware_supports_flash_attn(gpu_device_list): + gpu_unsupported_message = ( + "Unable to run ExllamaV3 because an unsupported GPU is " + "found in this configuration. \n" + "All GPUs must be ampere " + "(30 series) or newer. AMD GPUs are not supported." + ) + + logger.warning(gpu_unsupported_message) + + raise RuntimeError(gpu_unsupported_message) + + # Cache + user_cache_size = unwrap(kwargs.get("cache_size"), self.max_seq_len) + self.cache_size = self.adjust_cache_size(user_cache_size) + self.cache_mode = unwrap(kwargs.get("cache_mode"), "FP16") + self.cache = self.create_cache(self.cache_mode, self.model) + + # Draft cache + if self.use_draft_model: + # Set draft cache mode + self.draft_cache_mode = unwrap(draft_args.get("draft_cache_mode"), "FP16") + self.draft_cache = self.create_cache( + self.draft_cache_mode, self.draft_model + ) + + # Max batch size + self.max_batch_size = unwrap(kwargs.get("max_batch_size"), 256) + + # Make sure chunk size is >= 256, keep near or below max seq len + user_chunk_size = unwrap(kwargs.get("chunk_size"), 2048) + self.chunk_size = self.adjust_chunk_size(user_chunk_size) + + # Template setup + self.prompt_template = await find_prompt_template( + kwargs.get("prompt_template"), model_directory + ) + + # Catch all for template lookup errors + if self.prompt_template: + logger.info( + f'Using template "{self.prompt_template.name}" for chat completions.' + ) + else: + logger.warning( + "Chat completions are disabled because a prompt " + "template wasn't provided or auto-detected." + ) + + return self + + def adjust_cache_size(self, cache_size): + if cache_size < self.max_seq_len: + logger.warning( + f"The given cache_size ({cache_size}) is smaller than the " + "desired context length.\n" + "Overriding cache_size to max_seq_len. 
" + ) + + cache_size = self.max_seq_len + + # Enforce a multiple of 256 for cache size + # Overestimate to ensure that the cache isn't below max_seq_len + cache_remainder = cache_size % 256 + if cache_remainder != 0: + rounded_cache_size = int(256 * ((cache_size - cache_remainder) / 256 + 1)) + + logger.warning( + f"The given cache size ({cache_size}) is " + "not a multiple of 256.\n" + "Overriding cache_size with an overestimated value of " + f"{rounded_cache_size} tokens." + ) + + cache_size = rounded_cache_size + + # Warn user if cache size may be inadequate for CFG + if cache_size < 2 * self.max_seq_len: + logger.warning( + f"The given cache_size ({cache_size}) is less than 2 * max_seq_len " + "and may be too small for requests using CFG. \n" + "Ignore this warning if you do not plan on using CFG." + ) + + return cache_size + + def adjust_chunk_size(self, user_chunk_size: int): + chunk_size = sorted((256, user_chunk_size, self.max_seq_len))[1] + chunk_remainder = chunk_size % 256 + if chunk_remainder != 0: + rounded_chunk_size = int(256 * ((chunk_size - chunk_remainder) / 256 + 1)) + + logger.warning( + f"The given chunk size ({chunk_size}) is " + "not a multiple of 256.\n" + "Overriding chunk_size with an overestimated value of " + f"{rounded_chunk_size} tokens." + ) + + chunk_size = rounded_chunk_size + + return chunk_size + + def create_cache(self, raw_cache_mode: str, model: Model): + # Cast exl2 types to exl3 + match raw_cache_mode: + case "Q4": + raw_cache_mode = "4,4" + case "Q6": + raw_cache_mode = "6,6" + case "Q8": + raw_cache_mode = "8,8" + + split_cache_mode = re.search(r"^([2-8])\s*,\s*([2-8])$", raw_cache_mode) + + if split_cache_mode: + draft_k_bits = int(split_cache_mode.group(1)) + draft_v_bits = int(split_cache_mode.group(2)) + cache = Cache( + model, + max_num_tokens=self.cache_size, + layer_type=CacheLayer_quant, + k_bits=draft_k_bits, + v_bits=draft_v_bits, + ) + else: + cache = Cache(model, max_num_tokens=self.cache_size) + + return cache + + def model_info(self) -> ModelCard: + """ + Returns a dictionary of the current model's configuration parameters. + + Returns: + Model parameters provided by the backend + """ + + model_params = ModelCardParameters( + max_seq_len=self.max_seq_len, + cache_size=self.cache_size, + max_batch_size=self.max_batch_size, + cache_mode=self.cache_mode, + chunk_size=self.chunk_size, + use_vision=self.use_vision, + ) + + if self.prompt_template: + model_params.prompt_template = self.prompt_template.name + model_params.prompt_template_content = self.prompt_template.raw_template + + model_card = ModelCard( + id=self.model_dir.name, + parameters=model_params, + ) + + return model_card + + async def wait_for_jobs(self, skip_wait: bool = False): + """ + Polling to wait for any active generation jobs to complete. + + Args: + skip_wait: If True, cancel jobs immediately instead of waiting. + """ + + if not self.generator: + return + + # Immediately abort all jobs if asked + if skip_wait: + logger.warning( + "Immediately terminating all jobs. " + "Clients will have their requests cancelled.\n" + ) + + for job in self.active_job_ids.values(): + if job: + await job.cancel() + + while len(self.active_job_ids) > 0: + await asyncio.sleep(0.01) + + async def load(self, progress_callback=None, **kwargs): + """ + Loads the model into memory. + + Args: + progress_callback: Optional callback for progress updates. + **kwargs: Additional loading options. 
+ """ + + async for _ in self.load_gen(progress_callback): + pass + + async def load_gen(self, progress_callback=None, **kwargs): + """ + Loads the model into memory, yielding progress updates. + + Args: + progress_callback: Optional callback for progress updates. + **kwargs: Additional loading options. + + Yields: + Progress updates + """ + + try: + await self.load_lock.acquire() + + # Wait for existing generation jobs to finish + await self.wait_for_jobs(kwargs.get("skip_wait")) + + generator = self.load_model_sync(progress_callback) + async for value in iterate_in_threadpool(generator): + yield value + + # Create async generator + await self.create_generator() + + # Clean up any extra vram usage from torch and cuda + # (Helps reduce VRAM bottlenecking on Windows) + gc.collect() + torch.cuda.empty_cache() + + # Cleanup and update model load state + self.loaded = True + logger.info("Model successfully loaded.") + finally: + self.load_lock.release() + + async with self.load_condition: + self.load_condition.notify_all() + + @torch.inference_mode() + def load_model_sync(self, progress_callback=None): + if self.use_draft_model: + for value in self.draft_model.load_gen( + reserve_per_device=self.autosplit_reserve, + callback=progress_callback, + ): + if value: + yield value + + for value in self.model.load_gen( + reserve_per_device=self.autosplit_reserve, + use_per_device=self.gpu_split, + callback=progress_callback, + ): + if value: + yield value + + async def create_generator(self): + """Create and save a Exllama generator class.""" + + try: + # Don't acquire locks unless a model is loaded + if self.loaded: + await self.load_lock.acquire() + + # Immediately cancel all jobs + await self.wait_for_jobs(skip_wait=True) + + # Create new generator + self.generator = AsyncGenerator( + model=self.model, + cache=self.cache, + draft_model=self.draft_model, + draft_cache=self.draft_cache, + tokenizer=self.tokenizer, + max_batch_size=self.max_batch_size, + max_chunk_size=self.chunk_size, + ) + + # Update the state of the container var + if self.max_batch_size is None: + self.max_batch_size = self.generator.generator.max_batch_size + finally: + # This means the generator is being recreated + # The load lock is already released in the load function + if self.loaded: + self.load_lock.release() + + async with self.load_condition: + self.load_condition.notify_all() + + async def unload(self, loras_only: bool = False, **kwargs): + """ + Unloads the model and associated resources from memory. + + Args: + loras_only: If True, only unload LoRAs. + **kwargs: Additional unloading options (e.g., shutdown). 
+ """ + + # Used when shutting down the server + do_shutdown = kwargs.get("shutdown") + + try: + if not do_shutdown: + await self.load_lock.acquire() + + # Wait for other jobs to finish + await self.wait_for_jobs(kwargs.get("skip_wait")) + + self.model.unload() + self.model = None + self.config = None + self.cache = None + self.tokenizer = None + + if self.use_draft_model: + self.draft_model.unload() + self.draft_model = None + self.draft_config = None + self.draft_cache = None + + # Cleanup the generator from any pending jobs + if self.generator is not None: + await self.generator.close() + self.generator = None + + gc.collect() + torch.cuda.empty_cache() + + logger.info("Model unloaded.") + finally: + if not do_shutdown: + self.load_lock.release() + + async with self.load_condition: + self.load_condition.notify_all() + + def encode_tokens(self, text: str, **kwargs) -> List[int]: + """ + Encodes a string of text into a list of token IDs. + + Args: + text: The input text string. + **kwargs: Backend-specific encoding options (e.g., add_bos_token). + + Returns: + A list of integer token IDs. + """ + + return ( + self.tokenizer.encode( + text, + add_bos=unwrap(kwargs.get("add_bos_token"), True), + encode_special_tokens=unwrap(kwargs.get("encode_special_tokens"), True), + ) + .flatten() + .tolist() + ) + + def decode_tokens(self, ids: List[int], **kwargs) -> str: + """ + Decodes a list of token IDs back into a string. + + Args: + ids: A list of integer token IDs. + **kwargs: Backend-specific decoding options (e.g., decode_special_tokens). + + Returns: + The decoded text string. + """ + + ids = torch.tensor([ids]) + return self.tokenizer.decode( + ids, + decode_special_tokens=unwrap(kwargs.get("decode_special_tokens"), True), + )[0] + + def get_special_tokens( + self, add_bos_token: bool = True, ban_eos_token: bool = False + ): + """ + Gets special tokens used by the model/tokenizer. + + Args: + **kwargs: Options like add_bos_token, ban_eos_token. + + Returns: + A dictionary mapping special token names (e.g., 'bos_token', 'eos_token') + to their string or ID representation. + """ + + return { + "bos_token": self.tokenizer.bos_token if add_bos_token else "", + "eos_token": self.tokenizer.eos_token if not ban_eos_token else "", + "pad_token": self.tokenizer.pad_token, + "unk_token": self.tokenizer.unk_token, + } + + async def generate( + self, + request_id: str, + prompt: str, + params: BaseSamplerRequest, + abort_event: Optional[asyncio.Event] = None, + mm_embeddings: Optional[MultimodalEmbeddingWrapper] = None, + ) -> Dict[str, Any]: + """ + Generates a complete response for a given prompt and parameters. + + Args: + request_id: Unique identifier for the generation request. + prompt: The input prompt string. + params: Sampling and generation parameters. + abort_event: An asyncio Event to signal cancellation. + mm_embeddings: Optional multimodal embeddings. 
+ + Returns: + A dictionary containing the generation info + """ + + generations = [] + async for generation in self.stream_generate( + request_id, + prompt, + params, + abort_event, + mm_embeddings, + ): + generations.append(generation) + + joined_generation = { + "text": "", + "prompt_tokens": 0, + "generation_tokens": 0, + "tool_calls": None, + "offset": [], + "token_probs": {}, + "logprobs": [], + } + + if generations: + # Get finish_reason first and then shift where -1 points to + if "finish_reason" in generations[-1]: + finish_reason_gen = generations.pop() + joined_generation["finish_reason"] = finish_reason_gen.get( + "finish_reason" + ) + joined_generation["stop_str"] = finish_reason_gen.get("stop_str") + else: + joined_generation["finish_reason"] = "stop" + + if len(generations) > 0: + for generation in generations: + joined_generation["text"] += unwrap(generation.get("text"), "") + joined_generation["offset"].append(unwrap(generation.get("offset"), -1)) + joined_generation["token_probs"].update( + unwrap(generation.get("token_probs"), {}) + ) + + # Include empty logprob dicts for index preservation + joined_generation["logprobs"].append( + unwrap(generation.get("logprobs"), {}) + ) + + joined_generation["prompt_tokens"] = unwrap( + generations[-1].get("prompt_tokens"), 0 + ) + joined_generation["generated_tokens"] = unwrap( + generations[-1].get("generated_tokens"), 0 + ) + + return joined_generation + + async def stream_generate( + self, + request_id: str, + prompt: str, + params: BaseSamplerRequest, + abort_event: Optional[asyncio.Event] = None, + mm_embeddings: Optional[MultimodalEmbeddingWrapper] = None, + ) -> AsyncIterator[Dict[str, Any]]: + """ + Generates a response iteratively (streaming) for a given prompt. + + Args: + request_id: Unique identifier for the generation request. + prompt: The input prompt string. + params: Sampling and generation parameters. + abort_event: An asyncio Event to signal cancellation. + mm_embeddings: Optional multimodal embeddings. + + Yields: + Generation chunks + """ + + try: + # Wait for load lock to be freed before processing + # Mainly used for loras and other operations where the class is available + async with self.load_condition: + await self.load_condition.wait_for(lambda: not self.load_lock.locked()) + + # If the model is being unloaded, don't accept new requests + if not self.loaded: + raise RuntimeError( + "Model is being unloaded. Cannot process new generation requests." 
+ ) + + # Mark that the job is running + self.active_job_ids[request_id] = None + + # Yield from the internal generator + async for generation_chunk in self.generate_gen( + request_id=request_id, + prompt=prompt, + params=params, + abort_event=abort_event, + mm_embeddings=mm_embeddings, + ): + yield generation_chunk + finally: + # Clean up and remove the job from active IDs + del self.active_job_ids[request_id] + + def handle_finish_chunk(self, result: dict, generation: dict): + eos_reason = result.get("eos_reason") + + stop_str = None + if eos_reason == "max_new_tokens": + finish_reason = "length" + else: + finish_reason = "stop" + # Grab stop string if stop was the reason + if eos_reason == "stop_token": + stop_str = result.get("eos_triggering_token_str") + elif eos_reason == "stop_string": + stop_str = result.get("eos_triggering_string") + + finish_chunk = { + "prompt_tokens": generation.get("prompt_tokens"), + "generated_tokens": generation.get("generated_tokens"), + "finish_reason": finish_reason, + "stop_str": stop_str, + } + + return finish_chunk + + async def generate_gen( + self, + request_id: str, + prompt: str, + params: BaseSamplerRequest, + abort_event: Optional[asyncio.Event] = None, + mm_embeddings: Optional[MultimodalEmbeddingWrapper] = None, + ): + """ + Create generator function for prompt completion. + + for kwargs, check common/sampling.py + """ + chunk_tokens: torch.Tensor | tuple[torch.Tensor, torch.Tensor] + + sampler_builder = ExllamaV3SamplerBuilder() + + # Penalties + + # Set penalty range + penalty_range = unwrap(params.penalty_range, self.max_seq_len) + + # Exl3's version of including the entire context + if penalty_range < 0: + penalty_range = int(10e7) + + # Always make sure the fallback is 0 if range < 0 + # It's technically fine to use -1, but this just validates the passed + # fallback + # Always default to 0 if something goes wrong + if params.penalty_range < 0: + fallback_decay = 0 + else: + fallback_decay = params.penalty_range + + repetition_decay = coalesce(params.repetition_decay, fallback_decay, 0) + + # Apply penalties to builder + sampler_builder.penalties( + params.repetition_penalty, + params.frequency_penalty, + params.presence_penalty, + penalty_range, + repetition_decay, + ) + + # Apply temperature first to builder + if not params.temperature_last: + sampler_builder.temperature(params.temperature) + + # Apply alphabet samplers to builder + sampler_builder.top_k(params.top_k) + sampler_builder.top_p(params.top_p) + sampler_builder.min_p(params.min_p) + + # Apply temperature last to builder + if params.temperature_last: + sampler_builder.temperature(params.temperature) + + # Build the sampler + # Set greedy if temperature is 0 + sampler = sampler_builder.build(params.temperature == 0) + + # Dynamically scale penalty range to output tokens + # Only do this if freq/pres pen is enabled + # and the repetition range is -1 + # TODO: This currently does not work in exl3 + # auto_scale_penalty_range = ( + # gen_settings.token_frequency_penalty != 0 + # or gen_settings.token_presence_penalty != 0 + # ) and gen_settings.token_repetition_range == -1 + + prompts = [prompt] + stop_conditions = params.stop + add_bos_token = unwrap( + params.add_bos_token, self.tokenizer_config.add_bos_token + ) + + # Fetch EOS tokens from generation_config if they exist + eos_tokens = ( + self.generation_config.eos_tokens() + if self.generation_config + else [self.tokenizer.eos_token_id] + ) + + stop_conditions += eos_tokens + + input_ids = [ + self.tokenizer.encode( + 
prompt, + add_bos=add_bos_token, + encode_special_tokens=True, + ) + for prompt in prompts + ] + + # The first index will always be the positive prompt + context_len = input_ids[0].size(dim=-1) + + # Automatically set max_tokens to fill up the context + # This should be an OK default, but may be changed in the future + max_tokens = unwrap( + params.max_tokens, + self.max_seq_len - context_len, + ) + if max_tokens < 1: + logger.warning("max_tokens must be a positive integer, setting to 1.") + max_tokens = 1 + + # Determine if the negative context or the context length is bigger + context_to_check = context_len + + # Check total length of prompt against max context length + if context_to_check > self.max_seq_len: + preamble = "Prompt" + + raise ValueError( + f"{preamble} length {context_to_check} is greater than " + f"max_seq_len {self.max_seq_len}" + ) + + generation = {} + job = AsyncJob( + self.generator, + sampler=sampler, + input_ids=self.tokenizer.encode(prompt, add_bos=False), + max_new_tokens=max_tokens, + stop_conditions=stop_conditions, + banned_strings=params.banned_strings, + ) + + generated_tokens = 0 + full_response = "" + metrics_result = {} + + # Get the generation status once it's ready + try: + async for result in job: + # Abort if the event is set while streaming + if abort_event and abort_event.is_set(): + await job.cancel() + break + + chunk = unwrap(result.get("text"), "") + if chunk: + chunk_tokens = result.get("token_ids", self.tokenizer.encode(chunk)) + full_response += chunk + if isinstance(chunk_tokens, torch.Tensor): + generated_tokens += chunk_tokens.size(dim=0) + + # Increase penalty range to generated token amount + # TODO: + # if auto_scale_penalty_range: + # gen_settings.token_repetition_range = generated_tokens + + generation = { + "text": chunk, + "prompt_tokens": context_len, + "generated_tokens": generated_tokens, + "offset": len(full_response), + } + yield generation + + if result.get("eos"): + generation = self.handle_finish_chunk(result, generation) + + # Save the final result for metrics logging + metrics_result = result + + yield generation + break + # Assign the active job to the request ID + self.active_job_ids[request_id] = job + + except asyncio.CancelledError: + await job.cancel() + except Exception as ex: + # Create a new generator since the current state is broken + # No need to wait for this to finish + logger.error( + "FATAL ERROR with generation. " + "Attempting to recreate the generator. 
" + "If this fails, please restart the server.\n" + ) + asyncio.ensure_future(self.create_generator()) + + await HealthManager.add_unhealthy_event(ex) + + raise ex + finally: + # Log generation options to console + # Some options are too large, so log the args instead + log_generation_params( + request_id=request_id, + bos_token_id=self.tokenizer.bos_token_id, + eos_token_id=eos_tokens, + prompt=prompt, + **params.model_dump(exclude={"prompt"}), + # auto_scale_penalty_range=auto_scale_penalty_range, # TODO + ) + + # Log the metrics if present + if metrics_result: + log_metrics( + request_id, + metrics_result.get("time_enqueued"), + metrics_result.get("prompt_tokens"), + metrics_result.get("cached_tokens"), + metrics_result.get("time_prefill"), + metrics_result.get("new_tokens"), + metrics_result.get("time_generate"), + context_len, + self.max_seq_len, + ) diff --git a/backends/exllamav3/sampler.py b/backends/exllamav3/sampler.py new file mode 100644 index 0000000..7b08a9b --- /dev/null +++ b/backends/exllamav3/sampler.py @@ -0,0 +1,54 @@ +from dataclasses import dataclass, field +from typing import List +from exllamav3.generator.sampler import ( + CustomSampler, + SS_Temperature, + SS_RepP, + SS_PresFreqP, + SS_Argmax, + SS_MinP, + SS_TopK, + SS_TopP, + SS_Sample, + SS_Base, +) + + +@dataclass +class ExllamaV3SamplerBuilder: + """ + Custom sampler chain/stack for TabbyAPI + """ + + stack: List[SS_Base] = field(default_factory=list) + + def penalties(self, rep_p, freq_p, pres_p, penalty_range, rep_decay): + self.stack += [ + SS_RepP(rep_p, penalty_range, rep_decay), + SS_PresFreqP(pres_p, freq_p, penalty_range, rep_decay), + ] + + def temperature(self, temp): + self.stack.append(SS_Temperature(temp)) + + def top_k(self, top_k): + self.stack.append(SS_TopK(top_k)) + + def top_p(self, top_p): + self.stack.append(SS_TopP(top_p)) + + def min_p(self, min_p): + self.stack.append(SS_MinP(min_p)) + + def greedy(self): + self.stack.append(SS_Argmax()) + + def build(self, greedy): + """Builds the final sampler from stack.""" + + # Use greedy if temp is 0 + if greedy: + return CustomSampler([SS_Argmax()]) + else: + self.stack.append(SS_Sample()) + return CustomSampler(self.stack) diff --git a/common/config_models.py b/common/config_models.py index a31bd3e..0958a8e 100644 --- a/common/config_models.py +++ b/common/config_models.py @@ -1,6 +1,7 @@ from pydantic import ( BaseModel, ConfigDict, + constr, Field, PrivateAttr, field_validator, @@ -9,6 +10,7 @@ from typing import List, Literal, Optional, Union CACHE_SIZES = Literal["FP16", "Q8", "Q6", "Q4"] +CACHE_TYPE = Union[CACHE_SIZES, constr(pattern=r"^[2-8]\s*,\s*[2-8]$")] class Metadata(BaseModel): @@ -163,6 +165,13 @@ class ModelConfig(BaseConfigModel): "Example: ['max_seq_len', 'cache_mode']." ), ) + backend: Optional[str] = Field( + None, + description=( + "Backend to use for this model (auto-detect if not specified)\n" + "Options: exllamav2, exllamav3" + ), + ) max_seq_len: Optional[int] = Field( None, description=( @@ -186,7 +195,7 @@ class ModelConfig(BaseConfigModel): "Not parsed for single GPU users." ), ) - autosplit_reserve: List[int] = Field( + autosplit_reserve: List[float] = Field( [96], description=( "Reserve VRAM used for autosplit loading (default: 96 MB on GPU 0).\n" @@ -218,11 +227,13 @@ class ModelConfig(BaseConfigModel): "or auto-calculate." 
), ) - cache_mode: Optional[CACHE_SIZES] = Field( + cache_mode: Optional[CACHE_TYPE] = Field( "FP16", description=( "Enable different cache modes for VRAM savings (default: FP16).\n" - f"Possible values: {str(CACHE_SIZES)[15:-1]}." + f"Possible values for exllamav2: {str(CACHE_SIZES)[15:-1]}.\n" + "For exllamav3, specify the pair k_bits,v_bits where k_bits and v_bits " + "are integers from 2-8 (i.e. 8,8)." ), ) cache_size: Optional[int] = Field( diff --git a/common/hardware.py b/common/hardware.py new file mode 100644 index 0000000..10723c5 --- /dev/null +++ b/common/hardware.py @@ -0,0 +1,20 @@ +import torch + + +def hardware_supports_flash_attn(gpu_device_list: list[int]): + """ + Check whether all GPUs in list support FA2 + + Compute capability < 8 is not supported by FA2 + AMD is also unsupported until ROCm updates its FA2 fork + """ + + min_compute_capability = min( + torch.cuda.get_device_capability(device=device_idx)[0] + for device_idx in gpu_device_list + ) + + if torch.version.hip or min_compute_capability < 8: + return False + else: + return True diff --git a/common/model.py b/common/model.py index cc26b43..9cdfdeb 100644 --- a/common/model.py +++ b/common/model.py @@ -10,23 +10,34 @@ from enum import Enum from fastapi import HTTPException from loguru import logger from ruamel.yaml import YAML -from typing import Optional +from typing import Dict, Optional from backends.base_model_container import BaseModelContainer from common.logger import get_loading_progress_bar from common.networking import handle_request_error from common.tabby_config import config from common.optional_dependencies import dependencies +from common.transformers_utils import HuggingFaceConfig from common.utils import unwrap # Global variables for model container container: Optional[BaseModelContainer] = None embeddings_container = None -# FIXME: Possibly use this solely when creating the model + +_BACKEND_REGISTRY: Dict[str, BaseModelContainer] = {} + if dependencies.exllamav2: from backends.exllamav2.model import ExllamaV2Container + _BACKEND_REGISTRY["exllamav2"] = ExllamaV2Container + + +if dependencies.exllamav3: + from backends.exllamav3.model import ExllamaV3Container + + _BACKEND_REGISTRY["exllamav3"] = ExllamaV3Container + if dependencies.extras: from backends.infinity.model import InfinityContainer @@ -46,6 +57,24 @@ def load_progress(module, modules): yield module, modules +async def detect_backend(model_path: pathlib.Path) -> str: + """Determine the appropriate backend based on model files and configuration.""" + + try: + hf_config = await HuggingFaceConfig.from_directory(model_path) + quant_method = hf_config.quant_method() + + if quant_method == "exl3": + return "exllamav3" + else: + return "exllamav2" + except Exception as exc: + raise ValueError( + "Failed to read the model's config.json. " + f"Please check your model directory at {model_path}." 
+ ) from exc + + async def apply_inline_overrides(model_dir: pathlib.Path, **kwargs): """Sets overrides from a model folder's config yaml.""" @@ -113,9 +142,28 @@ async def load_model_gen(model_path: pathlib.Path, **kwargs): kwargs = {**config.model_defaults, **kwargs} kwargs = await apply_inline_overrides(model_path, **kwargs) - # Create a new container - new_container = await ExllamaV2Container.create( - model_path.resolve(), False, **kwargs + # Create a new container and check if the right dependencies are installed + backend_name = unwrap( + kwargs.get("backend"), await detect_backend(model_path) + ).lower() + container_class = _BACKEND_REGISTRY.get(backend_name) + + if not container_class: + available_backends = list(_BACKEND_REGISTRY.keys()) + if backend_name in available_backends: + raise ValueError( + f"Backend '{backend_name}' selected, but required dependencies " + "are not installed." + ) + else: + raise ValueError( + f"Invalid backend '{backend_name}'. " + f"Available backends: {available_backends}" + ) + + logger.info(f"Using backend {backend_name}") + new_container: BaseModelContainer = await container_class.create( + model_path.resolve(), **kwargs ) # Add possible types of models that can be loaded @@ -124,7 +172,7 @@ async def load_model_gen(model_path: pathlib.Path, **kwargs): if new_container.use_vision: model_type.insert(0, ModelType.VISION) - if new_container.draft_config: + if new_container.use_draft_model: model_type.insert(0, ModelType.DRAFT) load_status = new_container.load_gen(load_progress, **kwargs) diff --git a/common/optional_dependencies.py b/common/optional_dependencies.py index 06b1286..b449c2e 100644 --- a/common/optional_dependencies.py +++ b/common/optional_dependencies.py @@ -13,6 +13,7 @@ class DependenciesModel(BaseModel): torch: bool exllamav2: bool + exllamav3: bool flash_attn: bool infinity_emb: bool sentence_transformers: bool @@ -25,7 +26,7 @@ class DependenciesModel(BaseModel): @computed_field @property def inference(self) -> bool: - return self.torch and self.exllamav2 and self.flash_attn + return self.torch and (self.exllamav2 or (self.exllamav3 and self.flash_attn)) def is_installed(package_name: str) -> bool: diff --git a/common/sampling.py b/common/sampling.py index fc9f9bc..49be5b9 100644 --- a/common/sampling.py +++ b/common/sampling.py @@ -205,7 +205,7 @@ class BaseSamplerRequest(BaseModel): ) add_bos_token: Optional[bool] = Field( - default_factory=lambda: get_default_sampler_value("add_bos_token", True) + default_factory=lambda: get_default_sampler_value("add_bos_token") ) ban_eos_token: Optional[bool] = Field( @@ -215,11 +215,6 @@ class BaseSamplerRequest(BaseModel): examples=[False], ) - skip_special_tokens: Optional[bool] = Field( - default_factory=lambda: get_default_sampler_value("skip_special_tokens", True), - examples=[True], - ) - logit_bias: Optional[Dict[int, float]] = Field( default_factory=lambda: get_default_sampler_value("logit_bias"), examples=[{"1": 10, "2": 50}], diff --git a/common/transformers_utils.py b/common/transformers_utils.py index 045312c..d1e5ac1 100644 --- a/common/transformers_utils.py +++ b/common/transformers_utils.py @@ -1,7 +1,7 @@ import aiofiles import json import pathlib -from typing import List, Optional, Union +from typing import Dict, List, Optional, Union from pydantic import BaseModel @@ -42,8 +42,10 @@ class HuggingFaceConfig(BaseModel): Will be expanded as needed. 
""" + quantization_config: Optional[Dict] = None + @classmethod - async def from_file(cls, model_directory: pathlib.Path): + async def from_directory(cls, model_directory: pathlib.Path): """Create an instance from a generation config file.""" hf_config_path = model_directory / "config.json" @@ -54,6 +56,14 @@ class HuggingFaceConfig(BaseModel): hf_config_dict = json.loads(contents) return cls.model_validate(hf_config_dict) + def quant_method(self): + """Wrapper method to fetch quant type""" + + if isinstance(self.quantization_config, Dict): + return self.quantization_config.get("quant_method") + else: + return None + class TokenizerConfig(BaseModel): """ diff --git a/config_sample.yml b/config_sample.yml index b6f362d..ffe2605 100644 --- a/config_sample.yml +++ b/config_sample.yml @@ -74,6 +74,10 @@ model: # Example: ['max_seq_len', 'cache_mode']. use_as_default: [] + # Backend to use for this model (auto-detect if not specified) + # Options: exllamav2, exllamav3 + backend: + # Max sequence length (default: Empty). # Fetched from the model's base sequence length in config.json by default. max_seq_len: @@ -110,7 +114,8 @@ model: rope_alpha: # Enable different cache modes for VRAM savings (default: FP16). - # Possible values: 'FP16', 'Q8', 'Q6', 'Q4'. + # Possible values for exllamav2: 'FP16', 'Q8', 'Q6', 'Q4'. + # For exllamav3, specify the pair k_bits,v_bits where k_bits and v_bits are integers from 2-8 (i.e. 8,8). cache_mode: FP16 # Size of the prompt cache to allocate (default: max_seq_len). @@ -160,7 +165,8 @@ draft_model: draft_rope_alpha: # Cache mode for draft models to save VRAM (default: FP16). - # Possible values: 'FP16', 'Q8', 'Q6', 'Q4'. + # Possible values for exllamav2: 'FP16', 'Q8', 'Q6', 'Q4'. + # For exllamav3, specify the pair k_bits,v_bits where k_bits and v_bits are integers from 2-8 (i.e. 8,8). draft_cache_mode: FP16 # An integer array of GBs of VRAM to split between GPUs (default: []). diff --git a/endpoints/Kobold/utils/generation.py b/endpoints/Kobold/utils/generation.py index 39ddb9f..0d345d8 100644 --- a/endpoints/Kobold/utils/generation.py +++ b/endpoints/Kobold/utils/generation.py @@ -53,8 +53,12 @@ async def _stream_collector(data: GenerateRequest, request: Request): logger.info(f"Received Kobold generation request {data.genkey}") generator = model.container.stream_generate( - request_id=data.genkey, abort_event=abort_event, **data.model_dump() + request_id=data.genkey, + prompt=data.prompt, + params=data, + abort_event=abort_event, ) + async for generation in generator: if disconnect_task.done(): abort_event.set() diff --git a/endpoints/OAI/types/chat_completion.py b/endpoints/OAI/types/chat_completion.py index d1209a6..36934a8 100644 --- a/endpoints/OAI/types/chat_completion.py +++ b/endpoints/OAI/types/chat_completion.py @@ -82,10 +82,13 @@ class ChatCompletionRequest(CommonCompletionRequest): tool_call_end: SkipJsonSchema[Optional[str]] = None tool_call_schema: SkipJsonSchema[Optional[dict]] = tool_call_schema + # Chat completions requests do not have a BOS token preference. Backend + # respects the tokenization config for the individual model. 
+ add_bos_token: Optional[bool] = None + @field_validator("add_bos_token", mode="after") def force_bos_token(cls, v): """Always disable add_bos_token with chat completions.""" - return None diff --git a/endpoints/core/types/model.py b/endpoints/core/types/model.py index 02213f9..6855108 100644 --- a/endpoints/core/types/model.py +++ b/endpoints/core/types/model.py @@ -81,7 +81,10 @@ class ModelLoadRequest(BaseModel): ) # Config arguments - + backend: Optional[str] = Field( + description="Backend to use", + default="exllamav2", + ) max_seq_len: Optional[int] = Field( description="Leave this blank to use the model's base sequence length", default=None, diff --git a/main.py b/main.py index df4e472..2159508 100644 --- a/main.py +++ b/main.py @@ -15,12 +15,11 @@ from common.auth import load_auth_keys from common.actions import run_subcommand from common.logger import setup_logger from common.networking import is_port_in_use +from common.optional_dependencies import dependencies from common.signals import signal_handler from common.tabby_config import config from endpoints.server import start_api -from backends.exllamav2.version import check_exllama_version - async def entrypoint_async(): """Async entry function for program startup""" @@ -139,8 +138,21 @@ def entrypoint( "UNSAFE: Skipping ExllamaV2 version check.\n" "If you aren't a developer, please keep this off!" ) - else: - check_exllama_version() + elif not dependencies.inference: + install_message = ( + f"ERROR: Inference dependencies for TabbyAPI are not installed.\n" + "Please update your environment by running an update script " + "(update_scripts/" + f"update_deps.{'bat' if platform.system() == 'Windows' else 'sh'})\n\n" + "Or you can manually run a requirements update " + "using the following command:\n\n" + "For CUDA 12.1:\n" + "pip install --upgrade .[cu121]\n\n" + "For ROCm:\n" + "pip install --upgrade .[amd]\n\n" + ) + + raise SystemExit(install_message) # Enable CUDA malloc backend if config.developer.cuda_malloc_backend: diff --git a/pyproject.toml b/pyproject.toml index b16a22c..cee9673 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -77,6 +77,16 @@ cu121 = [ "exllamav2 @ https://github.com/turboderp-org/exllamav2/releases/download/v0.2.9/exllamav2-0.2.9+cu128.torch2.7.0-cp311-cp311-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.11'", "exllamav2 @ https://github.com/turboderp-org/exllamav2/releases/download/v0.2.9/exllamav2-0.2.9+cu128.torch2.7.0-cp310-cp310-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.10'", + # Exl3 + "exllamav3 @ https://github.com/turboderp-org/exllamav3/releases/download/v0.0.1/exllamav3-0.0.1+cu128.torch2.7.0-cp313-cp313-win_amd64.whl ; platform_system == 'Windows' and python_version == '3.13'", + "exllamav3 @ https://github.com/turboderp-org/exllamav3/releases/download/v0.0.1/exllamav3-0.0.1+cu128.torch2.7.0-cp312-cp312-win_amd64.whl ; platform_system == 'Windows' and python_version == '3.12'", + "exllamav3 @ https://github.com/turboderp-org/exllamav3/releases/download/v0.0.1/exllamav3-0.0.1+cu128.torch2.7.0-cp311-cp311-win_amd64.whl ; platform_system == 'Windows' and python_version == '3.11'", + "exllamav3 @ https://github.com/turboderp-org/exllamav3/releases/download/v0.0.1/exllamav3-0.0.1+cu128.torch2.7.0-cp310-cp310-win_amd64.whl ; platform_system == 'Windows' and python_version == '3.10'", + "exllamav3 @ 
https://github.com/turboderp-org/exllamav3/releases/download/v0.0.1/exllamav3-0.0.1+cu128.torch2.7.0-cp313-cp313-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.13'", + "exllamav3 @ https://github.com/turboderp-org/exllamav3/releases/download/v0.0.1/exllamav3-0.0.1+cu128.torch2.7.0-cp312-cp312-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.12'", + "exllamav3 @ https://github.com/turboderp-org/exllamav3/releases/download/v0.0.1/exllamav3-0.0.1+cu128.torch2.7.0-cp311-cp311-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.11'", + "exllamav3 @ https://github.com/turboderp-org/exllamav3/releases/download/v0.0.1/exllamav3-0.0.1+cu128.torch2.7.0-cp310-cp310-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.10'", + # Windows FA2 from https://github.com/kingbri1/flash-attention/releases "flash_attn @ https://github.com/kingbri1/flash-attention/releases/download/v2.7.4.post1/flash_attn-2.7.4.post1+cu128torch2.7.0cxx11abiFALSE-cp313-cp313-win_amd64.whl ; platform_system == 'Windows' and python_version == '3.13'", "flash_attn @ https://github.com/kingbri1/flash-attention/releases/download/v2.7.4.post1/flash_attn-2.7.4.post1+cu128torch2.7.0cxx11abiFALSE-cp312-cp312-win_amd64.whl ; platform_system == 'Windows' and python_version == '3.12'", diff --git a/sampler_overrides/sample_preset.yml b/sampler_overrides/sample_preset.yml index 9225976..0a2d05c 100644 --- a/sampler_overrides/sample_preset.yml +++ b/sampler_overrides/sample_preset.yml @@ -131,14 +131,11 @@ mirostat_eta: # MARK: Token options add_bos_token: - override: true + override: force: false ban_eos_token: override: false force: false -skip_special_tokens: - override: true - force: false logit_bias: override: force: false
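Note on the new cache_mode values introduced above (a minimal, standalone sketch; the helper name `parse_cache_mode` and the demo at the bottom are illustrative and not part of this diff): the exllamav3 backend accepts either the legacy exllamav2 names (`FP16`, `Q8`, `Q6`, `Q4`) or a `k_bits,v_bits` pair with bits from 2-8, as handled in `ExllamaV3Container.create_cache`. The sketch mirrors only that string handling, without constructing real exllamav3 objects.

import re
from typing import Optional, Tuple


def parse_cache_mode(raw_cache_mode: str) -> Optional[Tuple[int, int]]:
    """Return (k_bits, v_bits) for a quantized cache, or None for FP16.

    Mirrors the string handling in ExllamaV3Container.create_cache:
    legacy exl2 names are cast to bit pairs, a "k,v" pair (2-8) selects
    a quantized cache layer, and anything else falls back to FP16.
    """

    # Cast exl2 cache types to exl3-style bit pairs
    legacy = {"Q4": "4,4", "Q6": "6,6", "Q8": "8,8"}
    raw_cache_mode = legacy.get(raw_cache_mode, raw_cache_mode)

    split_cache_mode = re.search(r"^([2-8])\s*,\s*([2-8])$", raw_cache_mode)
    if split_cache_mode:
        return int(split_cache_mode.group(1)), int(split_cache_mode.group(2))

    # "FP16" (or anything unrecognized) -> no quantization kwargs
    return None


if __name__ == "__main__":
    for mode in ("FP16", "Q6", "8, 8", "6,6"):
        print(mode, "->", parse_cache_mode(mode))

In the backend code above, a (k_bits, v_bits) result corresponds to building Cache(model, max_num_tokens=..., layer_type=CacheLayer_quant, k_bits=..., v_bits=...), while None corresponds to a plain FP16 Cache(model, max_num_tokens=...).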