commit 3674d7b9b5
19 changed files with 1273 additions and 213 deletions
@ -25,6 +25,10 @@ class BaseModelContainer(abc.ABC):
|
|||
prompt_template: Optional[PromptTemplate] = None
|
||||
generation_config: Optional[GenerationConfig] = None
|
||||
|
||||
# Optional features
|
||||
use_draft_model: bool = False
|
||||
use_vision: bool = False
|
||||
|
||||
# Load synchronization
|
||||
# The bool is a master switch for accepting requests
|
||||
# The lock keeps load tasks sequential
|
||||
|
|
@ -65,7 +69,7 @@ class BaseModelContainer(abc.ABC):
|
|||
|
||||
# NOTE: Might be an optional method
|
||||
@abc.abstractmethod
|
||||
async def load_gen(self, progress_callback=None, **kwargs) -> AsyncIterator[Any]:
|
||||
async def load_gen(self, progress_callback=None, **kwargs):
|
||||
"""
|
||||
Loads the model into memory, yielding progress updates.
|
||||
|
||||
|
|
@ -134,57 +138,6 @@ class BaseModelContainer(abc.ABC):
|
|||
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
async def generate(
|
||||
self,
|
||||
request_id: str,
|
||||
prompt: str,
|
||||
params: BaseSamplerRequest,
|
||||
abort_event: Optional[asyncio.Event] = None,
|
||||
mm_embeddings: Optional[MultimodalEmbeddingWrapper] = None,
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Generates a complete response for a given prompt and parameters.
|
||||
|
||||
Args:
|
||||
request_id: Unique identifier for the generation request.
|
||||
prompt: The input prompt string.
|
||||
params: Sampling and generation parameters.
|
||||
abort_event: An asyncio Event to signal cancellation.
|
||||
mm_embeddings: Optional multimodal embeddings.
|
||||
|
||||
Returns:
|
||||
A dictionary containing the generation info
|
||||
"""
|
||||
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
async def stream_generate(
|
||||
self,
|
||||
request_id: str,
|
||||
prompt: str,
|
||||
params: BaseSamplerRequest,
|
||||
abort_event: Optional[asyncio.Event] = None,
|
||||
mm_embeddings: Optional[MultimodalEmbeddingWrapper] = None,
|
||||
) -> AsyncIterator[Dict[str, Any]]:
|
||||
"""
|
||||
Generates a response iteratively (streaming) for a given prompt.
|
||||
|
||||
Args:
|
||||
request_id: Unique identifier for the generation request.
|
||||
prompt: The input prompt string.
|
||||
params: Sampling and generation parameters.
|
||||
abort_event: An asyncio Event to signal cancellation.
|
||||
mm_embeddings: Optional multimodal embeddings.
|
||||
|
||||
Yields:
|
||||
Generation chunks
|
||||
"""
|
||||
|
||||
if False:
|
||||
yield
|
||||
|
||||
@abc.abstractmethod
|
||||
def model_info(self) -> ModelCard:
|
||||
"""
|
||||
|
|
@ -239,3 +192,54 @@ class BaseModelContainer(abc.ABC):
|
|||
"""
|
||||
|
||||
return []
|
||||
|
||||
@abc.abstractmethod
|
||||
async def generate(
|
||||
self,
|
||||
request_id: str,
|
||||
prompt: str,
|
||||
params: BaseSamplerRequest,
|
||||
abort_event: Optional[asyncio.Event] = None,
|
||||
mm_embeddings: Optional[MultimodalEmbeddingWrapper] = None,
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Generates a complete response for a given prompt and parameters.
|
||||
|
||||
Args:
|
||||
request_id: Unique identifier for the generation request.
|
||||
prompt: The input prompt string.
|
||||
params: Sampling and generation parameters.
|
||||
abort_event: An asyncio Event to signal cancellation.
|
||||
mm_embeddings: Optional multimodal embeddings.
|
||||
|
||||
Returns:
|
||||
A dictionary containing the generation info
|
||||
"""
|
||||
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
async def stream_generate(
|
||||
self,
|
||||
request_id: str,
|
||||
prompt: str,
|
||||
params: BaseSamplerRequest,
|
||||
abort_event: Optional[asyncio.Event] = None,
|
||||
mm_embeddings: Optional[MultimodalEmbeddingWrapper] = None,
|
||||
) -> AsyncIterator[Dict[str, Any]]:
|
||||
"""
|
||||
Generates a response iteratively (streaming) for a given prompt.
|
||||
|
||||
Args:
|
||||
request_id: Unique identifier for the generation request.
|
||||
prompt: The input prompt string.
|
||||
params: Sampling and generation parameters.
|
||||
abort_event: An asyncio Event to signal cancellation.
|
||||
mm_embeddings: Optional multimodal embeddings.
|
||||
|
||||
Yields:
|
||||
Generation chunks
|
||||
"""
|
||||
|
||||
if False:
|
||||
yield
|
||||
|
|
|
|||
|
|
@ -33,11 +33,7 @@ from backends.exllamav2.grammar import (
|
|||
ExLlamaV2Grammar,
|
||||
clear_grammar_func_cache,
|
||||
)
|
||||
from backends.exllamav2.utils import (
|
||||
exllama_disabled_flash_attn,
|
||||
hardware_supports_flash_attn,
|
||||
supports_paged_attn,
|
||||
)
|
||||
from backends.exllamav2.utils import exllama_disabled_flash_attn
|
||||
from backends.exllamav2.vision import clear_image_embedding_cache
|
||||
from common.concurrency import iterate_in_threadpool
|
||||
from common.gen_logging import (
|
||||
|
|
@ -46,6 +42,7 @@ from common.gen_logging import (
|
|||
log_prompt,
|
||||
log_response,
|
||||
)
|
||||
from common.hardware import hardware_supports_flash_attn
|
||||
from common.health import HealthManager
|
||||
from common.multimodal import MultimodalEmbeddingWrapper
|
||||
from common.sampling import BaseSamplerRequest
|
||||
|
|
@ -64,16 +61,19 @@ class ExllamaV2Container(BaseModelContainer):
|
|||
|
||||
# Exl2 vars
|
||||
config: Optional[ExLlamaV2Config] = None
|
||||
draft_config: Optional[ExLlamaV2Config] = None
|
||||
model: Optional[ExLlamaV2] = None
|
||||
draft_model: Optional[ExLlamaV2] = None
|
||||
cache: Optional[ExLlamaV2Cache] = None
|
||||
draft_cache: Optional[ExLlamaV2Cache] = None
|
||||
tokenizer: Optional[ExLlamaV2Tokenizer] = None
|
||||
generator: Optional[ExLlamaV2DynamicGeneratorAsync] = None
|
||||
prompt_template: Optional[PromptTemplate] = None
|
||||
paged: bool = True
|
||||
|
||||
# Draft model vars
|
||||
use_draft_model: bool = False
|
||||
draft_config: Optional[ExLlamaV2Config] = None
|
||||
draft_model: Optional[ExLlamaV2] = None
|
||||
draft_cache: Optional[ExLlamaV2Cache] = None
|
||||
|
||||
# Internal config vars
|
||||
cache_size: int = None
|
||||
cache_mode: str = "FP16"
|
||||
|
|
@ -100,7 +100,7 @@ class ExllamaV2Container(BaseModelContainer):
|
|||
load_condition: asyncio.Condition = asyncio.Condition()
|
||||
|
||||
@classmethod
|
||||
async def create(cls, model_directory: pathlib.Path, quiet=False, **kwargs):
|
||||
async def create(cls, model_directory: pathlib.Path, **kwargs):
|
||||
"""
|
||||
Primary asynchronous initializer for model container.
|
||||
|
||||
|
|
@ -110,8 +110,6 @@ class ExllamaV2Container(BaseModelContainer):
|
|||
# Create a new instance as a "fake self"
|
||||
self = cls()
|
||||
|
||||
self.quiet = quiet
|
||||
|
||||
# Initialize config
|
||||
self.config = ExLlamaV2Config()
|
||||
self.model_dir = model_directory
|
||||
|
|
@ -162,7 +160,7 @@ class ExllamaV2Container(BaseModelContainer):
|
|||
# Prepare the draft model config if necessary
|
||||
draft_args = unwrap(kwargs.get("draft_model"), {})
|
||||
draft_model_name = draft_args.get("draft_model_name")
|
||||
enable_draft = draft_args and draft_model_name
|
||||
self.use_draft_model = draft_args and draft_model_name
|
||||
|
||||
# Always disable draft if params are incorrectly configured
|
||||
if draft_args and draft_model_name is None:
|
||||
|
|
@ -170,9 +168,9 @@ class ExllamaV2Container(BaseModelContainer):
|
|||
"Draft model is disabled because a model name "
|
||||
"wasn't provided. Please check your config.yml!"
|
||||
)
|
||||
enable_draft = False
|
||||
self.use_draft_model = False
|
||||
|
||||
if enable_draft:
|
||||
if self.use_draft_model:
|
||||
self.draft_config = ExLlamaV2Config()
|
||||
draft_model_path = pathlib.Path(
|
||||
unwrap(draft_args.get("draft_model_dir"), "models")
|
||||
|
|
@ -189,6 +187,15 @@ class ExllamaV2Container(BaseModelContainer):
|
|||
# Get cache mode
|
||||
self.cache_mode = unwrap(kwargs.get("cache_mode"), "FP16")
|
||||
|
||||
# Catch exllamav3 cache_mode
|
||||
if not self.cache_mode.startswith("Q"):
|
||||
logger.warning(
|
||||
f"Provided cache mode '{self.cache_mode}' is not a "
|
||||
"valid choice for exllamav2, please check your settings. "
|
||||
"Defaulting to FP16."
|
||||
)
|
||||
self.cache_mode = "FP16"
|
||||
|
||||
# Turn off GPU split if the user is using 1 GPU
|
||||
gpu_count = torch.cuda.device_count()
|
||||
gpu_split_auto = unwrap(kwargs.get("gpu_split_auto"), True)
|
||||
|
|
@ -276,11 +283,20 @@ class ExllamaV2Container(BaseModelContainer):
|
|||
|
||||
# Check whether the user's configuration supports flash/paged attention
|
||||
# Also check if exl2 has disabled flash attention
|
||||
if (
|
||||
exllama_disabled_flash_attn(self.config.no_flash_attn)
|
||||
or not hardware_supports_flash_attn(gpu_device_list)
|
||||
or not supports_paged_attn()
|
||||
):
|
||||
if exllama_disabled_flash_attn(
|
||||
self.config.no_flash_attn
|
||||
) or not hardware_supports_flash_attn(gpu_device_list):
|
||||
gpu_unsupported_message = (
|
||||
"An unsupported GPU is found in this configuration. "
|
||||
"Switching to compatibility mode. \n"
|
||||
"This disables parallel batching "
|
||||
"and features that rely on it (ex. CFG). \n"
|
||||
"To disable compatability mode, all GPUs must be ampere "
|
||||
"(30 series) or newer. AMD GPUs are not supported."
|
||||
)
|
||||
|
||||
logger.warning(gpu_unsupported_message)
|
||||
|
||||
self.config.no_flash_attn = True
|
||||
if self.draft_config:
|
||||
self.draft_config.no_flash_attn = True
|
||||
|
|
@ -365,7 +381,7 @@ class ExllamaV2Container(BaseModelContainer):
|
|||
self.config.max_attention_size = chunk_size**2
|
||||
|
||||
# Set user-configured draft model values
|
||||
if enable_draft:
|
||||
if self.use_draft_model:
|
||||
self.draft_config.max_seq_len = self.config.max_seq_len
|
||||
|
||||
self.draft_config.scale_pos_emb = unwrap(
|
||||
|
|
@ -385,6 +401,15 @@ class ExllamaV2Container(BaseModelContainer):
|
|||
# Set draft cache mode
|
||||
self.draft_cache_mode = unwrap(draft_args.get("draft_cache_mode"), "FP16")
|
||||
|
||||
# Catch exllamav3 draft_cache_mode
|
||||
if not self.draft_cache_mode.startswith("Q"):
|
||||
logger.warning(
|
||||
f"Provided draft cache mode '{self.draft_cache_mode}' is not a "
|
||||
"valid choice for exllamav2, please check your settings. "
|
||||
"Defaulting to FP16."
|
||||
)
|
||||
self.draft_cache_mode = "FP16"
|
||||
|
||||
# Edit the draft config size
|
||||
if chunk_size:
|
||||
self.draft_config.max_input_len = chunk_size
|
||||
|
|
@ -531,8 +556,7 @@ class ExllamaV2Container(BaseModelContainer):
|
|||
# Load draft model if a config is present
|
||||
if self.draft_config:
|
||||
self.draft_model = ExLlamaV2(self.draft_config)
|
||||
if not self.quiet:
|
||||
logger.info("Loading draft model: " + self.draft_config.model_dir)
|
||||
logger.info("Loading draft model: " + self.draft_config.model_dir)
|
||||
|
||||
# Draft uses the autosplit loader, so create a cache that reflects this
|
||||
draft_cache_class = self.get_cache_class(self.draft_cache_mode)
|
||||
|
|
@ -585,8 +609,7 @@ class ExllamaV2Container(BaseModelContainer):
|
|||
yield value
|
||||
|
||||
self.model = ExLlamaV2(self.config)
|
||||
if not self.quiet:
|
||||
logger.info("Loading model: " + self.config.model_dir)
|
||||
logger.info("Loading model: " + self.config.model_dir)
|
||||
|
||||
# Get class of the model cache
|
||||
cache_class = self.get_cache_class(self.cache_mode)
|
||||
|
|
@ -1350,7 +1373,7 @@ class ExllamaV2Container(BaseModelContainer):
|
|||
min_new_tokens=params.min_tokens,
|
||||
gen_settings=gen_settings,
|
||||
stop_conditions=stop_conditions,
|
||||
decode_special_tokens=not params.skip_special_tokens,
|
||||
decode_special_tokens=True,
|
||||
filters=grammar_handler.filters,
|
||||
filter_prefer_eos=bool(grammar_handler.filters),
|
||||
return_probs=params.logprobs > 0,
|
||||
|
|
|
|||
|
|
@ -1,74 +1,6 @@
|
|||
import platform
|
||||
import torch
|
||||
from packaging import version
|
||||
from importlib.metadata import PackageNotFoundError, version as package_version
|
||||
from loguru import logger
|
||||
|
||||
|
||||
def hardware_supports_flash_attn(gpu_device_list: list[int]):
|
||||
"""
|
||||
Check whether all GPUs in list support FA2
|
||||
|
||||
Compute capability < 8 is not supported by FA2
|
||||
AMD is also unsupported until ROCm updates its FA2 fork
|
||||
"""
|
||||
|
||||
# Logged message if unsupported
|
||||
unsupported_message = (
|
||||
"An unsupported GPU is found in this configuration. "
|
||||
"Switching to compatibility mode. \n"
|
||||
"This disables parallel batching "
|
||||
"and features that rely on it (ex. CFG). \n"
|
||||
"To disable compatability mode, all GPUs must be ampere "
|
||||
"(30 series) or newer. AMD GPUs are not supported."
|
||||
)
|
||||
|
||||
min_compute_capability = min(
|
||||
torch.cuda.get_device_capability(device=device_idx)[0]
|
||||
for device_idx in gpu_device_list
|
||||
)
|
||||
|
||||
if torch.version.hip or min_compute_capability < 8:
|
||||
logger.warning(unsupported_message)
|
||||
return False
|
||||
else:
|
||||
return True
|
||||
|
||||
|
||||
def supports_paged_attn():
|
||||
"""Check whether the user's flash-attn version supports paged mode"""
|
||||
|
||||
# Logged message if unsupported
|
||||
unsupported_message = (
|
||||
"Flash attention version >=2.5.7 "
|
||||
"is required to use paged attention. "
|
||||
"Switching to compatibility mode. \n"
|
||||
"This disables parallel batching "
|
||||
"and features that rely on it (ex. CFG). \n"
|
||||
"Please upgrade your environment by running an update script "
|
||||
"(update_scripts/"
|
||||
f"update_deps.{'bat' if platform.system() == 'Windows' else 'sh'})\n\n"
|
||||
"Or you can manually run a requirements update "
|
||||
"using the following command:\n\n"
|
||||
"For CUDA 12.1:\n"
|
||||
"pip install --upgrade .[cu121]\n\n"
|
||||
"NOTE: Windows users must use CUDA 12.x to use flash-attn."
|
||||
)
|
||||
|
||||
required_version = version.parse("2.5.7")
|
||||
try:
|
||||
current_version = version.parse(package_version("flash-attn").split("+")[0])
|
||||
except PackageNotFoundError:
|
||||
logger.warning(unsupported_message)
|
||||
return False
|
||||
|
||||
if current_version < required_version:
|
||||
logger.warning(unsupported_message)
|
||||
return False
|
||||
else:
|
||||
return True
|
||||
|
||||
|
||||
def exllama_disabled_flash_attn(no_flash_attn: bool):
|
||||
unsupported_message = (
|
||||
"ExllamaV2 has disabled Flash Attention. \n"
|
||||
|
|
|
|||
|
|
@ -1,37 +0,0 @@
|
|||
import platform
|
||||
from packaging import version
|
||||
from importlib.metadata import version as package_version
|
||||
from loguru import logger
|
||||
from common.optional_dependencies import dependencies
|
||||
|
||||
|
||||
def check_exllama_version():
|
||||
"""Verifies the exllama version"""
|
||||
|
||||
install_message = (
|
||||
"Please update your environment by running an update script "
|
||||
"(update_scripts/"
|
||||
f"update_deps.{'bat' if platform.system() == 'Windows' else 'sh'})\n\n"
|
||||
"Or you can manually run a requirements update "
|
||||
"using the following command:\n\n"
|
||||
"For CUDA 12.1:\n"
|
||||
"pip install --upgrade .[cu121]\n\n"
|
||||
"For ROCm:\n"
|
||||
"pip install --upgrade .[amd]\n\n"
|
||||
)
|
||||
|
||||
if not dependencies.exllamav2:
|
||||
raise SystemExit(("Exllamav2 is not installed.\n" + install_message))
|
||||
|
||||
required_version = version.parse("0.2.8")
|
||||
current_version = version.parse(package_version("exllamav2").split("+")[0])
|
||||
|
||||
unsupported_message = (
|
||||
f"ERROR: TabbyAPI requires ExLlamaV2 {required_version} "
|
||||
f"or greater. Your current version is {current_version}.\n" + install_message
|
||||
)
|
||||
|
||||
if current_version < required_version:
|
||||
raise SystemExit(unsupported_message)
|
||||
else:
|
||||
logger.info(f"ExllamaV2 version: {current_version}")
|
||||
964
backends/exllamav3/model.py
Normal file
|
|
@ -0,0 +1,964 @@
|
|||
import asyncio
|
||||
import gc
|
||||
import pathlib
|
||||
import re
|
||||
import traceback
|
||||
from typing import (
|
||||
Any,
|
||||
AsyncIterator,
|
||||
Dict,
|
||||
List,
|
||||
Optional,
|
||||
)
|
||||
|
||||
import torch
|
||||
from exllamav3 import (
|
||||
AsyncGenerator,
|
||||
AsyncJob,
|
||||
Cache,
|
||||
Config,
|
||||
Model,
|
||||
Tokenizer,
|
||||
)
|
||||
from exllamav3.cache import CacheLayer_quant
|
||||
from loguru import logger
|
||||
|
||||
from backends.base_model_container import BaseModelContainer
|
||||
from backends.exllamav3.sampler import ExllamaV3SamplerBuilder
|
||||
from common.concurrency import iterate_in_threadpool
|
||||
from common.gen_logging import (
|
||||
log_generation_params,
|
||||
log_metrics,
|
||||
)
|
||||
from common.hardware import hardware_supports_flash_attn
|
||||
from common.health import HealthManager
|
||||
from common.multimodal import MultimodalEmbeddingWrapper
|
||||
from common.sampling import BaseSamplerRequest
|
||||
from common.templating import PromptTemplate, find_prompt_template
|
||||
from common.transformers_utils import GenerationConfig, TokenizerConfig
|
||||
from common.utils import coalesce, unwrap
|
||||
from endpoints.core.types.model import ModelCard, ModelCardParameters
|
||||
|
||||
|
||||
class ExllamaV3Container(BaseModelContainer):
|
||||
"""Abstract base class for model containers."""
|
||||
|
||||
# Exposed model information
|
||||
model_dir: pathlib.Path = pathlib.Path("models")
|
||||
prompt_template: Optional[PromptTemplate] = None
|
||||
generation_config: Optional[GenerationConfig] = None
|
||||
|
||||
# Load synchronization
|
||||
# The bool is a master switch for accepting requests
|
||||
# The lock keeps load tasks sequential
|
||||
# The condition notifies any waiting tasks
|
||||
active_job_ids: Dict[str, Any] = {}
|
||||
loaded: bool = False
|
||||
load_lock: asyncio.Lock = asyncio.Lock()
|
||||
load_condition: asyncio.Condition = asyncio.Condition()
|
||||
|
||||
# Exl3 vars
|
||||
model: Optional[Model]
|
||||
cache: Optional[Cache]
|
||||
draft_model: Optional[Model]
|
||||
draft_cache: Optional[Cache]
|
||||
tokenizer: Optional[Tokenizer]
|
||||
config: Optional[Config]
|
||||
draft_config: Optional[Config]
|
||||
generator: Optional[AsyncGenerator]
|
||||
tokenizer_config: Optional[TokenizerConfig]
|
||||
|
||||
# Class-specific vars
|
||||
gpu_split: List[float] | None = None
|
||||
gpu_split_auto: bool = True
|
||||
autosplit_reserve: List[float] = [96 / 1024]
|
||||
use_tp: bool = False
|
||||
max_seq_len: int = 4096
|
||||
cache_size: int = 4096
|
||||
cache_mode: str = "FP16"
|
||||
draft_cache_mode: str = "FP16"
|
||||
chunk_size: int = 2048
|
||||
max_batch_size: Optional[int] = None
|
||||
|
||||
# Required methods
|
||||
@classmethod
|
||||
async def create(cls, model_directory: pathlib.Path, **kwargs):
|
||||
"""
|
||||
Asynchronously creates and initializes a model container instance.
|
||||
|
||||
Args:
|
||||
model_directory: Path to the model files.
|
||||
**kwargs: Backend-specific configuration options.
|
||||
|
||||
Returns:
|
||||
An instance of the implementing class.
|
||||
"""
|
||||
|
||||
self = cls()
|
||||
|
||||
self.model = None
|
||||
self.cache = None
|
||||
self.draft_model = None
|
||||
self.draft_cache = None
|
||||
self.tokenizer = None
|
||||
self.config = None
|
||||
self.draft_config = None
|
||||
self.generator = None
|
||||
self.tokenizer_config = None
|
||||
|
||||
logger.warning(
|
||||
"ExllamaV3 is currently in an alpha state. "
|
||||
"Please note that all config options may not work."
|
||||
)
|
||||
|
||||
self.model_dir = model_directory
|
||||
self.config = Config.from_directory(str(model_directory.resolve()))
|
||||
self.model = Model.from_config(self.config)
|
||||
self.tokenizer = Tokenizer.from_config(self.config)
|
||||
|
||||
# Load generation config overrides
|
||||
generation_config_path = model_directory / "generation_config.json"
|
||||
if generation_config_path.exists():
|
||||
try:
|
||||
self.generation_config = await GenerationConfig.from_file(
|
||||
model_directory
|
||||
)
|
||||
except Exception:
|
||||
logger.error(traceback.format_exc())
|
||||
logger.warning(
|
||||
"Skipping generation config load because of an unexpected error."
|
||||
)
|
||||
|
||||
# Load tokenizer config overrides
|
||||
tokenizer_config_path = model_directory / "tokenizer_config.json"
|
||||
if tokenizer_config_path.exists():
|
||||
try:
|
||||
self.tokenizer_config = await TokenizerConfig.from_file(model_directory)
|
||||
except Exception:
|
||||
logger.error(traceback.format_exc())
|
||||
logger.warning(
|
||||
"Skipping tokenizer config load because of an unexpected error."
|
||||
)
|
||||
|
||||
# Fallback to 4096 since exl3 can't fetch from HF's config.json
|
||||
self.max_seq_len = unwrap(kwargs.get("max_seq_len"), 4096)
|
||||
|
||||
# Prepare the draft model config if necessary
|
||||
draft_args = unwrap(kwargs.get("draft_model"), {})
|
||||
draft_model_name = draft_args.get("draft_model_name")
|
||||
self.use_draft_model = draft_args and draft_model_name
|
||||
|
||||
# Always disable draft if params are incorrectly configured
|
||||
if draft_args and draft_model_name is None:
|
||||
logger.warning(
|
||||
"Draft model is disabled because a model name "
|
||||
"wasn't provided. Please check your config.yml!"
|
||||
)
|
||||
self.use_draft_model = False
|
||||
|
||||
if self.use_draft_model:
|
||||
draft_model_path = pathlib.Path(
|
||||
unwrap(draft_args.get("draft_model_dir"), "models")
|
||||
)
|
||||
draft_model_path = draft_model_path / draft_model_name
|
||||
self.draft_gpu_split = unwrap(draft_args.get("draft_gpu_split"), [])
|
||||
self.draft_model_dir = draft_model_path
|
||||
self.draft_config = Config.from_directory(str(draft_model_path.resolve()))
|
||||
self.draft_model = Model.from_config(self.draft_config)
|
||||
logger.info(f"Using draft model: {str(draft_model_path.resolve())}")
|
||||
else:
|
||||
self.draft_model = None
|
||||
self.draft_cache = None
|
||||
|
||||
# Turn off GPU split if the user is using 1 GPU
|
||||
gpu_count = torch.cuda.device_count()
|
||||
gpu_split_auto = unwrap(kwargs.get("gpu_split_auto"), True)
|
||||
gpu_split = unwrap(kwargs.get("gpu_split"), None)
|
||||
gpu_device_list = list(range(0, gpu_count))
|
||||
|
||||
# Set GPU split options
|
||||
if gpu_count == 1:
|
||||
self.gpu_split_auto = False
|
||||
logger.info("Disabling GPU split because one GPU is in use.")
|
||||
else:
|
||||
# TODO: Set tensor parallel
|
||||
|
||||
# Set GPU split options
|
||||
# Enable manual GPU split if provided
|
||||
if gpu_split:
|
||||
self.gpu_split = gpu_split
|
||||
|
||||
gpu_device_list = [
|
||||
device_idx
|
||||
for device_idx, memory in enumerate(self.gpu_split)
|
||||
if memory > 0
|
||||
]
|
||||
elif gpu_split_auto and not self.use_tp:
|
||||
# Otherwise fallback to autosplit settings
|
||||
self.gpu_split_auto = gpu_split_auto
|
||||
|
||||
autosplit_reserve_megabytes = unwrap(
|
||||
kwargs.get("autosplit_reserve"), [96]
|
||||
)
|
||||
|
||||
# Reserve VRAM for each GPU
|
||||
self.autosplit_reserve = [
|
||||
value / 1024 for value in autosplit_reserve_megabytes
|
||||
]
|
||||
|
||||
if not hardware_supports_flash_attn(gpu_device_list):
|
||||
gpu_unsupported_message = (
|
||||
"Unable to run ExllamaV3 because an unsupported GPU is "
|
||||
"found in this configuration. \n"
|
||||
"All GPUs must be ampere "
|
||||
"(30 series) or newer. AMD GPUs are not supported."
|
||||
)
|
||||
|
||||
logger.warning(gpu_unsupported_message)
|
||||
|
||||
raise RuntimeError(gpu_unsupported_message)
|
||||
|
||||
# Cache
|
||||
user_cache_size = unwrap(kwargs.get("cache_size"), self.max_seq_len)
|
||||
self.cache_size = self.adjust_cache_size(user_cache_size)
|
||||
self.cache_mode = unwrap(kwargs.get("cache_mode"), "FP16")
|
||||
self.cache = self.create_cache(self.cache_mode, self.model)
|
||||
|
||||
# Draft cache
|
||||
if self.use_draft_model:
|
||||
# Set draft cache mode
|
||||
self.draft_cache_mode = unwrap(draft_args.get("draft_cache_mode"), "FP16")
|
||||
self.draft_cache = self.create_cache(
|
||||
self.draft_cache_mode, self.draft_model
|
||||
)
|
||||
|
||||
# Max batch size
|
||||
self.max_batch_size = unwrap(kwargs.get("max_batch_size"), 256)
|
||||
|
||||
# Make sure chunk size is >= 256, keep near or below max seq len
|
||||
user_chunk_size = unwrap(kwargs.get("chunk_size"), 2048)
|
||||
self.chunk_size = self.adjust_chunk_size(user_chunk_size)
|
||||
|
||||
# Template setup
|
||||
self.prompt_template = await find_prompt_template(
|
||||
kwargs.get("prompt_template"), model_directory
|
||||
)
|
||||
|
||||
# Catch all for template lookup errors
|
||||
if self.prompt_template:
|
||||
logger.info(
|
||||
f'Using template "{self.prompt_template.name}" for chat completions.'
|
||||
)
|
||||
else:
|
||||
logger.warning(
|
||||
"Chat completions are disabled because a prompt "
|
||||
"template wasn't provided or auto-detected."
|
||||
)
|
||||
|
||||
return self
|
||||
|
||||
def adjust_cache_size(self, cache_size):
|
||||
if cache_size < self.max_seq_len:
|
||||
logger.warning(
|
||||
f"The given cache_size ({cache_size}) is smaller than the "
|
||||
"desired context length.\n"
|
||||
"Overriding cache_size to max_seq_len. "
|
||||
)
|
||||
|
||||
cache_size = self.max_seq_len
|
||||
|
||||
# Enforce a multiple of 256 for cache size
|
||||
# Overestimate to ensure that the cache isn't below max_seq_len
|
||||
cache_remainder = cache_size % 256
|
||||
if cache_remainder != 0:
|
||||
rounded_cache_size = int(256 * ((cache_size - cache_remainder) / 256 + 1))
|
||||
|
||||
logger.warning(
|
||||
f"The given cache size ({cache_size}) is "
|
||||
"not a multiple of 256.\n"
|
||||
"Overriding cache_size with an overestimated value of "
|
||||
f"{rounded_cache_size} tokens."
|
||||
)
|
||||
|
||||
cache_size = rounded_cache_size
|
||||
|
||||
# Warn user if cache size may be inadequate for CFG
|
||||
if cache_size < 2 * self.max_seq_len:
|
||||
logger.warning(
|
||||
f"The given cache_size ({cache_size}) is less than 2 * max_seq_len "
|
||||
"and may be too small for requests using CFG. \n"
|
||||
"Ignore this warning if you do not plan on using CFG."
|
||||
)
|
||||
|
||||
return cache_size
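
A worked example of the rounding above, with assumed inputs (not code from the diff): a requested cache_size of 5000 against max_seq_len 4096 already covers the context, so it is only rounded up to the next multiple of 256.

# Sketch of the adjust_cache_size arithmetic for cache_size=5000, max_seq_len=4096
cache_size, max_seq_len = 5000, 4096
cache_size = max(cache_size, max_seq_len)  # stays 5000, already >= max_seq_len
remainder = cache_size % 256               # 136
if remainder != 0:
    cache_size = 256 * ((cache_size - remainder) // 256 + 1)  # rounds up to 5120
assert cache_size == 5120  # still < 2 * max_seq_len, so the CFG warning is also logged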
|
||||
|
||||
def adjust_chunk_size(self, user_chunk_size: int):
|
||||
chunk_size = sorted((256, user_chunk_size, self.max_seq_len))[1]
|
||||
chunk_remainder = chunk_size % 256
|
||||
if chunk_remainder != 0:
|
||||
rounded_chunk_size = int(256 * ((chunk_size - chunk_remainder) / 256 + 1))
|
||||
|
||||
logger.warning(
|
||||
f"The given chunk size ({chunk_size}) is "
|
||||
"not a multiple of 256.\n"
|
||||
"Overriding chunk_size with an overestimated value of "
|
||||
f"{rounded_chunk_size} tokens."
|
||||
)
|
||||
|
||||
chunk_size = rounded_chunk_size
|
||||
|
||||
return chunk_size
|
||||
|
||||
def create_cache(self, raw_cache_mode: str, model: Model):
|
||||
# Cast exl2 types to exl3
|
||||
match raw_cache_mode:
|
||||
case "Q4":
|
||||
raw_cache_mode = "4,4"
|
||||
case "Q6":
|
||||
raw_cache_mode = "6,6"
|
||||
case "Q8":
|
||||
raw_cache_mode = "8,8"
|
||||
|
||||
split_cache_mode = re.search(r"^([2-8])\s*,\s*([2-8])$", raw_cache_mode)
|
||||
|
||||
if split_cache_mode:
|
||||
draft_k_bits = int(split_cache_mode.group(1))
|
||||
draft_v_bits = int(split_cache_mode.group(2))
|
||||
cache = Cache(
|
||||
model,
|
||||
max_num_tokens=self.cache_size,
|
||||
layer_type=CacheLayer_quant,
|
||||
k_bits=draft_k_bits,
|
||||
v_bits=draft_v_bits,
|
||||
)
|
||||
else:
|
||||
cache = Cache(model, max_num_tokens=self.cache_size)
|
||||
|
||||
return cache
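
A minimal sketch of the cache-mode parsing above (assumed inputs, not code from the diff): exl2-style names are first cast to a k_bits,v_bits pair, then matched against the 2-8 bit pattern; anything that does not match falls back to an FP16 cache.

import re

for mode in ("Q6", "8,8", "3, 5", "FP16"):
    mode = {"Q4": "4,4", "Q6": "6,6", "Q8": "8,8"}.get(mode, mode)
    match = re.search(r"^([2-8])\s*,\s*([2-8])$", mode)
    if match:
        print(mode, "-> quantized cache, k_bits:", match.group(1), "v_bits:", match.group(2))
    else:
        print(mode, "-> FP16 cache")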
|
||||
|
||||
def model_info(self) -> ModelCard:
|
||||
"""
|
||||
Returns a dictionary of the current model's configuration parameters.
|
||||
|
||||
Returns:
|
||||
Model parameters provided by the backend
|
||||
"""
|
||||
|
||||
model_params = ModelCardParameters(
|
||||
max_seq_len=self.max_seq_len,
|
||||
cache_size=self.cache_size,
|
||||
max_batch_size=self.max_batch_size,
|
||||
cache_mode=self.cache_mode,
|
||||
chunk_size=self.chunk_size,
|
||||
use_vision=self.use_vision,
|
||||
)
|
||||
|
||||
if self.prompt_template:
|
||||
model_params.prompt_template = self.prompt_template.name
|
||||
model_params.prompt_template_content = self.prompt_template.raw_template
|
||||
|
||||
model_card = ModelCard(
|
||||
id=self.model_dir.name,
|
||||
parameters=model_params,
|
||||
)
|
||||
|
||||
return model_card
|
||||
|
||||
async def wait_for_jobs(self, skip_wait: bool = False):
|
||||
"""
|
||||
Polling to wait for any active generation jobs to complete.
|
||||
|
||||
Args:
|
||||
skip_wait: If True, cancel jobs immediately instead of waiting.
|
||||
"""
|
||||
|
||||
if not self.generator:
|
||||
return
|
||||
|
||||
# Immediately abort all jobs if asked
|
||||
if skip_wait:
|
||||
logger.warning(
|
||||
"Immediately terminating all jobs. "
|
||||
"Clients will have their requests cancelled.\n"
|
||||
)
|
||||
|
||||
for job in self.active_job_ids.values():
|
||||
if job:
|
||||
await job.cancel()
|
||||
|
||||
while len(self.active_job_ids) > 0:
|
||||
await asyncio.sleep(0.01)
|
||||
|
||||
async def load(self, progress_callback=None, **kwargs):
|
||||
"""
|
||||
Loads the model into memory.
|
||||
|
||||
Args:
|
||||
progress_callback: Optional callback for progress updates.
|
||||
**kwargs: Additional loading options.
|
||||
"""
|
||||
|
||||
async for _ in self.load_gen(progress_callback):
|
||||
pass
|
||||
|
||||
async def load_gen(self, progress_callback=None, **kwargs):
|
||||
"""
|
||||
Loads the model into memory, yielding progress updates.
|
||||
|
||||
Args:
|
||||
progress_callback: Optional callback for progress updates.
|
||||
**kwargs: Additional loading options.
|
||||
|
||||
Yields:
|
||||
Progress updates
|
||||
"""
|
||||
|
||||
try:
|
||||
await self.load_lock.acquire()
|
||||
|
||||
# Wait for existing generation jobs to finish
|
||||
await self.wait_for_jobs(kwargs.get("skip_wait"))
|
||||
|
||||
generator = self.load_model_sync(progress_callback)
|
||||
async for value in iterate_in_threadpool(generator):
|
||||
yield value
|
||||
|
||||
# Create async generator
|
||||
await self.create_generator()
|
||||
|
||||
# Clean up any extra vram usage from torch and cuda
|
||||
# (Helps reduce VRAM bottlenecking on Windows)
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
# Cleanup and update model load state
|
||||
self.loaded = True
|
||||
logger.info("Model successfully loaded.")
|
||||
finally:
|
||||
self.load_lock.release()
|
||||
|
||||
async with self.load_condition:
|
||||
self.load_condition.notify_all()
|
||||
|
||||
@torch.inference_mode()
|
||||
def load_model_sync(self, progress_callback=None):
|
||||
if self.use_draft_model:
|
||||
for value in self.draft_model.load_gen(
|
||||
reserve_per_device=self.autosplit_reserve,
|
||||
callback=progress_callback,
|
||||
):
|
||||
if value:
|
||||
yield value
|
||||
|
||||
for value in self.model.load_gen(
|
||||
reserve_per_device=self.autosplit_reserve,
|
||||
use_per_device=self.gpu_split,
|
||||
callback=progress_callback,
|
||||
):
|
||||
if value:
|
||||
yield value
|
||||
|
||||
async def create_generator(self):
|
||||
"""Create and save a Exllama generator class."""
|
||||
|
||||
try:
|
||||
# Don't acquire locks unless a model is loaded
|
||||
if self.loaded:
|
||||
await self.load_lock.acquire()
|
||||
|
||||
# Immediately cancel all jobs
|
||||
await self.wait_for_jobs(skip_wait=True)
|
||||
|
||||
# Create new generator
|
||||
self.generator = AsyncGenerator(
|
||||
model=self.model,
|
||||
cache=self.cache,
|
||||
draft_model=self.draft_model,
|
||||
draft_cache=self.draft_cache,
|
||||
tokenizer=self.tokenizer,
|
||||
max_batch_size=self.max_batch_size,
|
||||
max_chunk_size=self.chunk_size,
|
||||
)
|
||||
|
||||
# Update the state of the container var
|
||||
if self.max_batch_size is None:
|
||||
self.max_batch_size = self.generator.generator.max_batch_size
|
||||
finally:
|
||||
# This means the generator is being recreated
|
||||
# The load lock is already released in the load function
|
||||
if self.loaded:
|
||||
self.load_lock.release()
|
||||
|
||||
async with self.load_condition:
|
||||
self.load_condition.notify_all()
|
||||
|
||||
async def unload(self, loras_only: bool = False, **kwargs):
|
||||
"""
|
||||
Unloads the model and associated resources from memory.
|
||||
|
||||
Args:
|
||||
loras_only: If True, only unload LoRAs.
|
||||
**kwargs: Additional unloading options (e.g., shutdown).
|
||||
"""
|
||||
|
||||
# Used when shutting down the server
|
||||
do_shutdown = kwargs.get("shutdown")
|
||||
|
||||
try:
|
||||
if not do_shutdown:
|
||||
await self.load_lock.acquire()
|
||||
|
||||
# Wait for other jobs to finish
|
||||
await self.wait_for_jobs(kwargs.get("skip_wait"))
|
||||
|
||||
self.model.unload()
|
||||
self.model = None
|
||||
self.config = None
|
||||
self.cache = None
|
||||
self.tokenizer = None
|
||||
|
||||
if self.use_draft_model:
|
||||
self.draft_model.unload()
|
||||
self.draft_model = None
|
||||
self.draft_config = None
|
||||
self.draft_cache = None
|
||||
|
||||
# Cleanup the generator from any pending jobs
|
||||
if self.generator is not None:
|
||||
await self.generator.close()
|
||||
self.generator = None
|
||||
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
logger.info("Model unloaded.")
|
||||
finally:
|
||||
if not do_shutdown:
|
||||
self.load_lock.release()
|
||||
|
||||
async with self.load_condition:
|
||||
self.load_condition.notify_all()
|
||||
|
||||
def encode_tokens(self, text: str, **kwargs) -> List[int]:
|
||||
"""
|
||||
Encodes a string of text into a list of token IDs.
|
||||
|
||||
Args:
|
||||
text: The input text string.
|
||||
**kwargs: Backend-specific encoding options (e.g., add_bos_token).
|
||||
|
||||
Returns:
|
||||
A list of integer token IDs.
|
||||
"""
|
||||
|
||||
return (
|
||||
self.tokenizer.encode(
|
||||
text,
|
||||
add_bos=unwrap(kwargs.get("add_bos_token"), True),
|
||||
encode_special_tokens=unwrap(kwargs.get("encode_special_tokens"), True),
|
||||
)
|
||||
.flatten()
|
||||
.tolist()
|
||||
)
|
||||
|
||||
def decode_tokens(self, ids: List[int], **kwargs) -> str:
|
||||
"""
|
||||
Decodes a list of token IDs back into a string.
|
||||
|
||||
Args:
|
||||
ids: A list of integer token IDs.
|
||||
**kwargs: Backend-specific decoding options (e.g., decode_special_tokens).
|
||||
|
||||
Returns:
|
||||
The decoded text string.
|
||||
"""
|
||||
|
||||
ids = torch.tensor([ids])
|
||||
return self.tokenizer.decode(
|
||||
ids,
|
||||
decode_special_tokens=unwrap(kwargs.get("decode_special_tokens"), True),
|
||||
)[0]
|
||||
|
||||
def get_special_tokens(
|
||||
self, add_bos_token: bool = True, ban_eos_token: bool = False
|
||||
):
|
||||
"""
|
||||
Gets special tokens used by the model/tokenizer.
|
||||
|
||||
Args:
|
||||
**kwargs: Options like add_bos_token, ban_eos_token.
|
||||
|
||||
Returns:
|
||||
A dictionary mapping special token names (e.g., 'bos_token', 'eos_token')
|
||||
to their string or ID representation.
|
||||
"""
|
||||
|
||||
return {
|
||||
"bos_token": self.tokenizer.bos_token if add_bos_token else "",
|
||||
"eos_token": self.tokenizer.eos_token if not ban_eos_token else "",
|
||||
"pad_token": self.tokenizer.pad_token,
|
||||
"unk_token": self.tokenizer.unk_token,
|
||||
}
|
||||
|
||||
async def generate(
|
||||
self,
|
||||
request_id: str,
|
||||
prompt: str,
|
||||
params: BaseSamplerRequest,
|
||||
abort_event: Optional[asyncio.Event] = None,
|
||||
mm_embeddings: Optional[MultimodalEmbeddingWrapper] = None,
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Generates a complete response for a given prompt and parameters.
|
||||
|
||||
Args:
|
||||
request_id: Unique identifier for the generation request.
|
||||
prompt: The input prompt string.
|
||||
params: Sampling and generation parameters.
|
||||
abort_event: An asyncio Event to signal cancellation.
|
||||
mm_embeddings: Optional multimodal embeddings.
|
||||
|
||||
Returns:
|
||||
A dictionary containing the generation info
|
||||
"""
|
||||
|
||||
generations = []
|
||||
async for generation in self.stream_generate(
|
||||
request_id,
|
||||
prompt,
|
||||
params,
|
||||
abort_event,
|
||||
mm_embeddings,
|
||||
):
|
||||
generations.append(generation)
|
||||
|
||||
joined_generation = {
|
||||
"text": "",
|
||||
"prompt_tokens": 0,
|
||||
"generation_tokens": 0,
|
||||
"tool_calls": None,
|
||||
"offset": [],
|
||||
"token_probs": {},
|
||||
"logprobs": [],
|
||||
}
|
||||
|
||||
if generations:
|
||||
# Get finish_reason first and then shift where -1 points to
|
||||
if "finish_reason" in generations[-1]:
|
||||
finish_reason_gen = generations.pop()
|
||||
joined_generation["finish_reason"] = finish_reason_gen.get(
|
||||
"finish_reason"
|
||||
)
|
||||
joined_generation["stop_str"] = finish_reason_gen.get("stop_str")
|
||||
else:
|
||||
joined_generation["finish_reason"] = "stop"
|
||||
|
||||
if len(generations) > 0:
|
||||
for generation in generations:
|
||||
joined_generation["text"] += unwrap(generation.get("text"), "")
|
||||
joined_generation["offset"].append(unwrap(generation.get("offset"), -1))
|
||||
joined_generation["token_probs"].update(
|
||||
unwrap(generation.get("token_probs"), {})
|
||||
)
|
||||
|
||||
# Include empty logprob dicts for index preservation
|
||||
joined_generation["logprobs"].append(
|
||||
unwrap(generation.get("logprobs"), {})
|
||||
)
|
||||
|
||||
joined_generation["prompt_tokens"] = unwrap(
|
||||
generations[-1].get("prompt_tokens"), 0
|
||||
)
|
||||
joined_generation["generated_tokens"] = unwrap(
|
||||
generations[-1].get("generated_tokens"), 0
|
||||
)
|
||||
|
||||
return joined_generation
|
||||
|
||||
async def stream_generate(
|
||||
self,
|
||||
request_id: str,
|
||||
prompt: str,
|
||||
params: BaseSamplerRequest,
|
||||
abort_event: Optional[asyncio.Event] = None,
|
||||
mm_embeddings: Optional[MultimodalEmbeddingWrapper] = None,
|
||||
) -> AsyncIterator[Dict[str, Any]]:
|
||||
"""
|
||||
Generates a response iteratively (streaming) for a given prompt.
|
||||
|
||||
Args:
|
||||
request_id: Unique identifier for the generation request.
|
||||
prompt: The input prompt string.
|
||||
params: Sampling and generation parameters.
|
||||
abort_event: An asyncio Event to signal cancellation.
|
||||
mm_embeddings: Optional multimodal embeddings.
|
||||
|
||||
Yields:
|
||||
Generation chunks
|
||||
"""
|
||||
|
||||
try:
|
||||
# Wait for load lock to be freed before processing
|
||||
# Mainly used for loras and other operations where the class is available
|
||||
async with self.load_condition:
|
||||
await self.load_condition.wait_for(lambda: not self.load_lock.locked())
|
||||
|
||||
# If the model is being unloaded, don't accept new requests
|
||||
if not self.loaded:
|
||||
raise RuntimeError(
|
||||
"Model is being unloaded. Cannot process new generation requests."
|
||||
)
|
||||
|
||||
# Mark that the job is running
|
||||
self.active_job_ids[request_id] = None
|
||||
|
||||
# Yield from the internal generator
|
||||
async for generation_chunk in self.generate_gen(
|
||||
request_id=request_id,
|
||||
prompt=prompt,
|
||||
params=params,
|
||||
abort_event=abort_event,
|
||||
mm_embeddings=mm_embeddings,
|
||||
):
|
||||
yield generation_chunk
|
||||
finally:
|
||||
# Clean up and remove the job from active IDs
|
||||
del self.active_job_ids[request_id]
|
||||
|
||||
def handle_finish_chunk(self, result: dict, generation: dict):
|
||||
eos_reason = result.get("eos_reason")
|
||||
|
||||
stop_str = None
|
||||
if eos_reason == "max_new_tokens":
|
||||
finish_reason = "length"
|
||||
else:
|
||||
finish_reason = "stop"
|
||||
# Grab stop string if stop was the reason
|
||||
if eos_reason == "stop_token":
|
||||
stop_str = result.get("eos_triggering_token_str")
|
||||
elif eos_reason == "stop_string":
|
||||
stop_str = result.get("eos_triggering_string")
|
||||
|
||||
finish_chunk = {
|
||||
"prompt_tokens": generation.get("prompt_tokens"),
|
||||
"generated_tokens": generation.get("generated_tokens"),
|
||||
"finish_reason": finish_reason,
|
||||
"stop_str": stop_str,
|
||||
}
|
||||
|
||||
return finish_chunk
|
||||
|
||||
async def generate_gen(
|
||||
self,
|
||||
request_id: str,
|
||||
prompt: str,
|
||||
params: BaseSamplerRequest,
|
||||
abort_event: Optional[asyncio.Event] = None,
|
||||
mm_embeddings: Optional[MultimodalEmbeddingWrapper] = None,
|
||||
):
|
||||
"""
|
||||
Create generator function for prompt completion.
|
||||
|
||||
for kwargs, check common/sampling.py
|
||||
"""
|
||||
chunk_tokens: torch.Tensor | tuple[torch.Tensor, torch.Tensor]
|
||||
|
||||
sampler_builder = ExllamaV3SamplerBuilder()
|
||||
|
||||
# Penalties
|
||||
|
||||
# Set penalty range
|
||||
penalty_range = unwrap(params.penalty_range, self.max_seq_len)
|
||||
|
||||
# Exl3's version of including the entire context
|
||||
if penalty_range < 0:
|
||||
penalty_range = int(10e7)
|
||||
|
||||
# Always make sure the fallback is 0 if range < 0
|
||||
# It's technically fine to use -1, but this just validates the passed
|
||||
# fallback
|
||||
# Always default to 0 if something goes wrong
|
||||
if params.penalty_range < 0:
|
||||
fallback_decay = 0
|
||||
else:
|
||||
fallback_decay = params.penalty_range
|
||||
|
||||
repetition_decay = coalesce(params.repetition_decay, fallback_decay, 0)
|
||||
|
||||
# Apply penalties to builder
|
||||
sampler_builder.penalties(
|
||||
params.repetition_penalty,
|
||||
params.frequency_penalty,
|
||||
params.presence_penalty,
|
||||
penalty_range,
|
||||
repetition_decay,
|
||||
)
|
||||
|
||||
# Apply temperature first to builder
|
||||
if not params.temperature_last:
|
||||
sampler_builder.temperature(params.temperature)
|
||||
|
||||
# Apply alphabet samplers to builder
|
||||
sampler_builder.top_k(params.top_k)
|
||||
sampler_builder.top_p(params.top_p)
|
||||
sampler_builder.min_p(params.min_p)
|
||||
|
||||
# Apply temperature last to builder
|
||||
if params.temperature_last:
|
||||
sampler_builder.temperature(params.temperature)
|
||||
|
||||
# Build the sampler
|
||||
# Set greedy if temperature is 0
|
||||
sampler = sampler_builder.build(params.temperature == 0)
|
||||
|
||||
# Dynamically scale penalty range to output tokens
|
||||
# Only do this if freq/pres pen is enabled
|
||||
# and the repetition range is -1
|
||||
# TODO: This currently does not work in exl3
|
||||
# auto_scale_penalty_range = (
|
||||
# gen_settings.token_frequency_penalty != 0
|
||||
# or gen_settings.token_presence_penalty != 0
|
||||
# ) and gen_settings.token_repetition_range == -1
|
||||
|
||||
prompts = [prompt]
|
||||
stop_conditions = params.stop
|
||||
add_bos_token = unwrap(
|
||||
params.add_bos_token, self.tokenizer_config.add_bos_token
|
||||
)
|
||||
|
||||
# Fetch EOS tokens from generation_config if they exist
|
||||
eos_tokens = (
|
||||
self.generation_config.eos_tokens()
|
||||
if self.generation_config
|
||||
else [self.tokenizer.eos_token_id]
|
||||
)
|
||||
|
||||
stop_conditions += eos_tokens
|
||||
|
||||
input_ids = [
|
||||
self.tokenizer.encode(
|
||||
prompt,
|
||||
add_bos=add_bos_token,
|
||||
encode_special_tokens=True,
|
||||
)
|
||||
for prompt in prompts
|
||||
]
|
||||
|
||||
# The first index will always be the positive prompt
|
||||
context_len = input_ids[0].size(dim=-1)
|
||||
|
||||
# Automatically set max_tokens to fill up the context
|
||||
# This should be an OK default, but may be changed in the future
|
||||
max_tokens = unwrap(
|
||||
params.max_tokens,
|
||||
self.max_seq_len - context_len,
|
||||
)
|
||||
if max_tokens < 1:
|
||||
logger.warning("max_tokens must be a positive integer, setting to 1.")
|
||||
max_tokens = 1
|
||||
|
||||
# Determine if the negative context or the context length is bigger
|
||||
context_to_check = context_len
|
||||
|
||||
# Check total length of prompt against max context length
|
||||
if context_to_check > self.max_seq_len:
|
||||
preamble = "Prompt"
|
||||
|
||||
raise ValueError(
|
||||
f"{preamble} length {context_to_check} is greater than "
|
||||
f"max_seq_len {self.max_seq_len}"
|
||||
)
|
||||
|
||||
generation = {}
|
||||
job = AsyncJob(
|
||||
self.generator,
|
||||
sampler=sampler,
|
||||
input_ids=self.tokenizer.encode(prompt, add_bos=False),
|
||||
max_new_tokens=max_tokens,
|
||||
stop_conditions=stop_conditions,
|
||||
banned_strings=params.banned_strings,
|
||||
)
|
||||
|
||||
generated_tokens = 0
|
||||
full_response = ""
|
||||
metrics_result = {}
|
||||
|
||||
# Get the generation status once it's ready
|
||||
try:
|
||||
async for result in job:
|
||||
# Abort if the event is set while streaming
|
||||
if abort_event and abort_event.is_set():
|
||||
await job.cancel()
|
||||
break
|
||||
|
||||
chunk = unwrap(result.get("text"), "")
|
||||
if chunk:
|
||||
chunk_tokens = result.get("token_ids", self.tokenizer.encode(chunk))
|
||||
full_response += chunk
|
||||
if isinstance(chunk_tokens, torch.Tensor):
|
||||
generated_tokens += chunk_tokens.size(dim=0)
|
||||
|
||||
# Increase penalty range to generated token amount
|
||||
# TODO:
|
||||
# if auto_scale_penalty_range:
|
||||
# gen_settings.token_repetition_range = generated_tokens
|
||||
|
||||
generation = {
|
||||
"text": chunk,
|
||||
"prompt_tokens": context_len,
|
||||
"generated_tokens": generated_tokens,
|
||||
"offset": len(full_response),
|
||||
}
|
||||
yield generation
|
||||
|
||||
if result.get("eos"):
|
||||
generation = self.handle_finish_chunk(result, generation)
|
||||
|
||||
# Save the final result for metrics logging
|
||||
metrics_result = result
|
||||
|
||||
yield generation
|
||||
break
|
||||
# Assign the active job to the request ID
|
||||
self.active_job_ids[request_id] = job
|
||||
|
||||
except asyncio.CancelledError:
|
||||
await job.cancel()
|
||||
except Exception as ex:
|
||||
# Create a new generator since the current state is broken
|
||||
# No need to wait for this to finish
|
||||
logger.error(
|
||||
"FATAL ERROR with generation. "
|
||||
"Attempting to recreate the generator. "
|
||||
"If this fails, please restart the server.\n"
|
||||
)
|
||||
asyncio.ensure_future(self.create_generator())
|
||||
|
||||
await HealthManager.add_unhealthy_event(ex)
|
||||
|
||||
raise ex
|
||||
finally:
|
||||
# Log generation options to console
|
||||
# Some options are too large, so log the args instead
|
||||
log_generation_params(
|
||||
request_id=request_id,
|
||||
bos_token_id=self.tokenizer.bos_token_id,
|
||||
eos_token_id=eos_tokens,
|
||||
prompt=prompt,
|
||||
**params.model_dump(exclude={"prompt"}),
|
||||
# auto_scale_penalty_range=auto_scale_penalty_range, # TODO
|
||||
)
|
||||
|
||||
# Log the metrics if present
|
||||
if metrics_result:
|
||||
log_metrics(
|
||||
request_id,
|
||||
metrics_result.get("time_enqueued"),
|
||||
metrics_result.get("prompt_tokens"),
|
||||
metrics_result.get("cached_tokens"),
|
||||
metrics_result.get("time_prefill"),
|
||||
metrics_result.get("new_tokens"),
|
||||
metrics_result.get("time_generate"),
|
||||
context_len,
|
||||
self.max_seq_len,
|
||||
)
|
||||
54
backends/exllamav3/sampler.py
Normal file
|
|
@ -0,0 +1,54 @@
|
|||
from dataclasses import dataclass, field
|
||||
from typing import List
|
||||
from exllamav3.generator.sampler import (
|
||||
CustomSampler,
|
||||
SS_Temperature,
|
||||
SS_RepP,
|
||||
SS_PresFreqP,
|
||||
SS_Argmax,
|
||||
SS_MinP,
|
||||
SS_TopK,
|
||||
SS_TopP,
|
||||
SS_Sample,
|
||||
SS_Base,
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class ExllamaV3SamplerBuilder:
|
||||
"""
|
||||
Custom sampler chain/stack for TabbyAPI
|
||||
"""
|
||||
|
||||
stack: List[SS_Base] = field(default_factory=list)
|
||||
|
||||
def penalties(self, rep_p, freq_p, pres_p, penalty_range, rep_decay):
|
||||
self.stack += [
|
||||
SS_RepP(rep_p, penalty_range, rep_decay),
|
||||
SS_PresFreqP(pres_p, freq_p, penalty_range, rep_decay),
|
||||
]
|
||||
|
||||
def temperature(self, temp):
|
||||
self.stack.append(SS_Temperature(temp))
|
||||
|
||||
def top_k(self, top_k):
|
||||
self.stack.append(SS_TopK(top_k))
|
||||
|
||||
def top_p(self, top_p):
|
||||
self.stack.append(SS_TopP(top_p))
|
||||
|
||||
def min_p(self, min_p):
|
||||
self.stack.append(SS_MinP(min_p))
|
||||
|
||||
def greedy(self):
|
||||
self.stack.append(SS_Argmax())
|
||||
|
||||
def build(self, greedy):
|
||||
"""Builds the final sampler from stack."""
|
||||
|
||||
# Use greedy if temp is 0
|
||||
if greedy:
|
||||
return CustomSampler([SS_Argmax()])
|
||||
else:
|
||||
self.stack.append(SS_Sample())
|
||||
return CustomSampler(self.stack)
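
A usage sketch of the builder above (the sampler values are made-up assumptions, not taken from the diff): stages are pushed in the same order generate_gen applies them, and build(True) collapses the chain to pure argmax.

builder = ExllamaV3SamplerBuilder()
builder.penalties(rep_p=1.1, freq_p=0.0, pres_p=0.0, penalty_range=2048, rep_decay=0)
builder.temperature(0.8)  # temperature applied first, i.e. temperature_last is False
builder.top_k(40)
builder.top_p(0.9)
builder.min_p(0.05)
sampler = builder.build(greedy=False)  # pass True (temperature == 0) to get greedy decoding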
|
||||
|
|
@ -1,6 +1,7 @@
|
|||
from pydantic import (
|
||||
BaseModel,
|
||||
ConfigDict,
|
||||
constr,
|
||||
Field,
|
||||
PrivateAttr,
|
||||
field_validator,
|
||||
|
|
@ -9,6 +10,7 @@ from typing import List, Literal, Optional, Union
|
|||
|
||||
|
||||
CACHE_SIZES = Literal["FP16", "Q8", "Q6", "Q4"]
|
||||
CACHE_TYPE = Union[CACHE_SIZES, constr(pattern=r"^[2-8]\s*,\s*[2-8]$")]
|
||||
|
||||
|
||||
class Metadata(BaseModel):
|
||||
|
|
@ -163,6 +165,13 @@ class ModelConfig(BaseConfigModel):
|
|||
"Example: ['max_seq_len', 'cache_mode']."
|
||||
),
|
||||
)
|
||||
backend: Optional[str] = Field(
|
||||
None,
|
||||
description=(
|
||||
"Backend to use for this model (auto-detect if not specified)\n"
|
||||
"Options: exllamav2, exllamav3"
|
||||
),
|
||||
)
|
||||
max_seq_len: Optional[int] = Field(
|
||||
None,
|
||||
description=(
|
||||
|
|
@ -186,7 +195,7 @@ class ModelConfig(BaseConfigModel):
|
|||
"Not parsed for single GPU users."
|
||||
),
|
||||
)
|
||||
autosplit_reserve: List[int] = Field(
|
||||
autosplit_reserve: List[float] = Field(
|
||||
[96],
|
||||
description=(
|
||||
"Reserve VRAM used for autosplit loading (default: 96 MB on GPU 0).\n"
|
||||
|
|
@ -218,11 +227,13 @@ class ModelConfig(BaseConfigModel):
|
|||
"or auto-calculate."
|
||||
),
|
||||
)
|
||||
cache_mode: Optional[CACHE_SIZES] = Field(
|
||||
cache_mode: Optional[CACHE_TYPE] = Field(
|
||||
"FP16",
|
||||
description=(
|
||||
"Enable different cache modes for VRAM savings (default: FP16).\n"
|
||||
f"Possible values: {str(CACHE_SIZES)[15:-1]}."
|
||||
f"Possible values for exllamav2: {str(CACHE_SIZES)[15:-1]}.\n"
|
||||
"For exllamav3, specify the pair k_bits,v_bits where k_bits and v_bits "
|
||||
"are integers from 2-8 (i.e. 8,8)."
|
||||
),
|
||||
)
|
||||
cache_size: Optional[int] = Field(
|
||||
|
|
|
|||
20
common/hardware.py
Normal file
|
|
@ -0,0 +1,20 @@
|
|||
import torch
|
||||
|
||||
|
||||
def hardware_supports_flash_attn(gpu_device_list: list[int]):
|
||||
"""
|
||||
Check whether all GPUs in list support FA2
|
||||
|
||||
Compute capability < 8 is not supported by FA2
|
||||
AMD is also unsupported until ROCm updates its FA2 fork
|
||||
"""
|
||||
|
||||
min_compute_capability = min(
|
||||
torch.cuda.get_device_capability(device=device_idx)[0]
|
||||
for device_idx in gpu_device_list
|
||||
)
|
||||
|
||||
if torch.version.hip or min_compute_capability < 8:
|
||||
return False
|
||||
else:
|
||||
return True
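
A hedged usage sketch (not from the diff) showing how the helper is expected to be called, mirroring the gpu_device_list the containers build from the visible CUDA devices:

import torch
from common.hardware import hardware_supports_flash_attn

gpu_device_list = list(range(torch.cuda.device_count()))
if not hardware_supports_flash_attn(gpu_device_list):
    # A pre-Ampere (compute capability < 8) or ROCm GPU is present
    print("flash-attn 2 unavailable; compatibility mode required")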
|
||||
|
|
@ -10,23 +10,34 @@ from enum import Enum
|
|||
from fastapi import HTTPException
|
||||
from loguru import logger
|
||||
from ruamel.yaml import YAML
|
||||
from typing import Optional
|
||||
from typing import Dict, Optional
|
||||
|
||||
from backends.base_model_container import BaseModelContainer
|
||||
from common.logger import get_loading_progress_bar
|
||||
from common.networking import handle_request_error
|
||||
from common.tabby_config import config
|
||||
from common.optional_dependencies import dependencies
|
||||
from common.transformers_utils import HuggingFaceConfig
|
||||
from common.utils import unwrap
|
||||
|
||||
# Global variables for model container
|
||||
container: Optional[BaseModelContainer] = None
|
||||
embeddings_container = None
|
||||
|
||||
# FIXME: Possibly use this solely when creating the model
|
||||
|
||||
_BACKEND_REGISTRY: Dict[str, BaseModelContainer] = {}
|
||||
|
||||
if dependencies.exllamav2:
|
||||
from backends.exllamav2.model import ExllamaV2Container
|
||||
|
||||
_BACKEND_REGISTRY["exllamav2"] = ExllamaV2Container
|
||||
|
||||
|
||||
if dependencies.exllamav3:
|
||||
from backends.exllamav3.model import ExllamaV3Container
|
||||
|
||||
_BACKEND_REGISTRY["exllamav3"] = ExllamaV3Container
|
||||
|
||||
|
||||
if dependencies.extras:
|
||||
from backends.infinity.model import InfinityContainer
|
||||
|
|
@ -46,6 +57,24 @@ def load_progress(module, modules):
|
|||
yield module, modules
|
||||
|
||||
|
||||
async def detect_backend(model_path: pathlib.Path) -> str:
|
||||
"""Determine the appropriate backend based on model files and configuration."""
|
||||
|
||||
try:
|
||||
hf_config = await HuggingFaceConfig.from_directory(model_path)
|
||||
quant_method = hf_config.quant_method()
|
||||
|
||||
if quant_method == "exl3":
|
||||
return "exllamav3"
|
||||
else:
|
||||
return "exllamav2"
|
||||
except Exception as exc:
|
||||
raise ValueError(
|
||||
"Failed to read the model's config.json. "
|
||||
f"Please check your model directory at {model_path}."
|
||||
) from exc
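
As an illustration (not code from the diff), the detection above boils down to reading quantization_config.quant_method from the model's config.json; pick_backend is a hypothetical synchronous stand-in for the async helper.

import json
import pathlib

def pick_backend(model_path: pathlib.Path) -> str:
    # exl3 quants carry {"quantization_config": {"quant_method": "exl3"}} in config.json
    hf_config = json.loads((model_path / "config.json").read_text())
    quant_method = (hf_config.get("quantization_config") or {}).get("quant_method")
    return "exllamav3" if quant_method == "exl3" else "exllamav2"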
|
||||
|
||||
|
||||
async def apply_inline_overrides(model_dir: pathlib.Path, **kwargs):
|
||||
"""Sets overrides from a model folder's config yaml."""
|
||||
|
||||
|
|
@ -113,9 +142,28 @@ async def load_model_gen(model_path: pathlib.Path, **kwargs):
|
|||
kwargs = {**config.model_defaults, **kwargs}
|
||||
kwargs = await apply_inline_overrides(model_path, **kwargs)
|
||||
|
||||
# Create a new container
|
||||
new_container = await ExllamaV2Container.create(
|
||||
model_path.resolve(), False, **kwargs
|
||||
# Create a new container and check if the right dependencies are installed
|
||||
backend_name = unwrap(
|
||||
kwargs.get("backend"), await detect_backend(model_path)
|
||||
).lower()
|
||||
container_class = _BACKEND_REGISTRY.get(backend_name)
|
||||
|
||||
if not container_class:
|
||||
available_backends = list(_BACKEND_REGISTRY.keys())
|
||||
if backend_name in available_backends:
|
||||
raise ValueError(
|
||||
f"Backend '{backend_name}' selected, but required dependencies "
|
||||
"are not installed."
|
||||
)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Invalid backend '{backend_name}'. "
|
||||
f"Available backends: {available_backends}"
|
||||
)
|
||||
|
||||
logger.info(f"Using backend {backend_name}")
|
||||
new_container: BaseModelContainer = await container_class.create(
|
||||
model_path.resolve(), **kwargs
|
||||
)
|
||||
|
||||
# Add possible types of models that can be loaded
|
||||
|
|
@ -124,7 +172,7 @@ async def load_model_gen(model_path: pathlib.Path, **kwargs):
|
|||
if new_container.use_vision:
|
||||
model_type.insert(0, ModelType.VISION)
|
||||
|
||||
if new_container.draft_config:
|
||||
if new_container.use_draft_model:
|
||||
model_type.insert(0, ModelType.DRAFT)
|
||||
|
||||
load_status = new_container.load_gen(load_progress, **kwargs)
|
||||
|
|
|
|||
|
|
@ -13,6 +13,7 @@ class DependenciesModel(BaseModel):
|
|||
|
||||
torch: bool
|
||||
exllamav2: bool
|
||||
exllamav3: bool
|
||||
flash_attn: bool
|
||||
infinity_emb: bool
|
||||
sentence_transformers: bool
|
||||
|
|
@ -25,7 +26,7 @@ class DependenciesModel(BaseModel):
|
|||
@computed_field
|
||||
@property
|
||||
def inference(self) -> bool:
|
||||
return self.torch and self.exllamav2 and self.flash_attn
|
||||
return self.torch and (self.exllamav2 or (self.exllamav3 and self.flash_attn))
|
||||
|
||||
|
||||
def is_installed(package_name: str) -> bool:
|
||||
|
|
|
|||
|
|
@ -205,7 +205,7 @@ class BaseSamplerRequest(BaseModel):
|
|||
)
|
||||
|
||||
add_bos_token: Optional[bool] = Field(
|
||||
default_factory=lambda: get_default_sampler_value("add_bos_token", True)
|
||||
default_factory=lambda: get_default_sampler_value("add_bos_token")
|
||||
)
|
||||
|
||||
ban_eos_token: Optional[bool] = Field(
|
||||
|
|
@ -215,11 +215,6 @@ class BaseSamplerRequest(BaseModel):
|
|||
examples=[False],
|
||||
)
|
||||
|
||||
skip_special_tokens: Optional[bool] = Field(
|
||||
default_factory=lambda: get_default_sampler_value("skip_special_tokens", True),
|
||||
examples=[True],
|
||||
)
|
||||
|
||||
logit_bias: Optional[Dict[int, float]] = Field(
|
||||
default_factory=lambda: get_default_sampler_value("logit_bias"),
|
||||
examples=[{"1": 10, "2": 50}],
|
||||
|
|
|
|||
|
|
@ -1,7 +1,7 @@
|
|||
import aiofiles
|
||||
import json
|
||||
import pathlib
|
||||
from typing import List, Optional, Union
|
||||
from typing import Dict, List, Optional, Union
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
|
|
@@ -42,8 +42,10 @@ class HuggingFaceConfig(BaseModel):
    Will be expanded as needed.
    """

    quantization_config: Optional[Dict] = None

    @classmethod
    async def from_file(cls, model_directory: pathlib.Path):
    async def from_directory(cls, model_directory: pathlib.Path):
        """Create an instance from a generation config file."""

        hf_config_path = model_directory / "config.json"
@@ -54,6 +56,14 @@ class HuggingFaceConfig(BaseModel):
        hf_config_dict = json.loads(contents)
        return cls.model_validate(hf_config_dict)

    def quant_method(self):
        """Wrapper method to fetch quant type"""

        if isinstance(self.quantization_config, Dict):
            return self.quantization_config.get("quant_method")
        else:
            return None


class TokenizerConfig(BaseModel):
    """
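quant_method() gives detect_backend, which is referenced earlier in this diff but not shown, a clean way to pick a backend from config.json. A hypothetical sketch of that auto-detection; the quant_method value "exl3" and the mapping below are assumptions for illustration:

# Hypothetical backend auto-detection built on the quant_method() wrapper
# above; assumes HuggingFaceConfig from this module is importable.
import pathlib


async def detect_backend(model_path: pathlib.Path) -> str:
    hf_config = await HuggingFaceConfig.from_directory(model_path)
    quant_method = hf_config.quant_method() if hf_config else None

    # exl3 quants would need the exllamav3 backend; everything else keeps
    # the existing default
    if quant_method == "exl3":
        return "exllamav3"

    return "exllamav2"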
@@ -74,6 +74,10 @@ model:
  # Example: ['max_seq_len', 'cache_mode'].
  use_as_default: []

  # Backend to use for this model (auto-detect if not specified)
  # Options: exllamav2, exllamav3
  backend:

  # Max sequence length (default: Empty).
  # Fetched from the model's base sequence length in config.json by default.
  max_seq_len:
@@ -110,7 +114,8 @@ model:
  rope_alpha:

  # Enable different cache modes for VRAM savings (default: FP16).
  # Possible values: 'FP16', 'Q8', 'Q6', 'Q4'.
  # Possible values for exllamav2: 'FP16', 'Q8', 'Q6', 'Q4'.
  # For exllamav3, specify the pair k_bits,v_bits where k_bits and v_bits are integers from 2-8 (i.e. 8,8).
  cache_mode: FP16

  # Size of the prompt cache to allocate (default: max_seq_len).
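The exllamav3 cache_mode format is a comma-separated pair of bit widths rather than a named preset. An illustrative parser for that format; the backend's actual parsing code is not part of this excerpt:

# Illustrative parser for the "k_bits,v_bits" cache_mode string described
# above, where each value must be an integer from 2 to 8.
from typing import Tuple


def parse_exl3_cache_mode(cache_mode: str) -> Tuple[int, int]:
    k_str, v_str = cache_mode.split(",")
    k_bits, v_bits = int(k_str), int(v_str)

    for bits in (k_bits, v_bits):
        if not 2 <= bits <= 8:
            raise ValueError(f"Cache bit width {bits} must be between 2 and 8")

    return k_bits, v_bits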
@@ -160,7 +165,8 @@ draft_model:
  draft_rope_alpha:

  # Cache mode for draft models to save VRAM (default: FP16).
  # Possible values: 'FP16', 'Q8', 'Q6', 'Q4'.
  # Possible values for exllamav2: 'FP16', 'Q8', 'Q6', 'Q4'.
  # For exllamav3, specify the pair k_bits,v_bits where k_bits and v_bits are integers from 2-8 (i.e. 8,8).
  draft_cache_mode: FP16

  # An integer array of GBs of VRAM to split between GPUs (default: []).
@@ -53,8 +53,12 @@ async def _stream_collector(data: GenerateRequest, request: Request):
    logger.info(f"Received Kobold generation request {data.genkey}")

    generator = model.container.stream_generate(
        request_id=data.genkey, abort_event=abort_event, **data.model_dump()
        request_id=data.genkey,
        prompt=data.prompt,
        params=data,
        abort_event=abort_event,
    )

    async for generation in generator:
        if disconnect_task.done():
            abort_event.set()
@@ -82,10 +82,13 @@ class ChatCompletionRequest(CommonCompletionRequest):
    tool_call_end: SkipJsonSchema[Optional[str]] = None
    tool_call_schema: SkipJsonSchema[Optional[dict]] = tool_call_schema

    # Chat completions requests do not have a BOS token preference. Backend
    # respects the tokenization config for the individual model.
    add_bos_token: Optional[bool] = None

    @field_validator("add_bos_token", mode="after")
    def force_bos_token(cls, v):
        """Always disable add_bos_token with chat completions."""

        return None

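Because the validator runs in "after" mode, any client-supplied value is normalized away once the request is parsed. A self-contained stand-in model demonstrating the same pattern; this is an illustration only, not the project's class:

# Minimal stand-in for the pattern above, for illustration only.
from typing import Optional

from pydantic import BaseModel, field_validator


class _BosDemo(BaseModel):
    add_bos_token: Optional[bool] = None

    @field_validator("add_bos_token", mode="after")
    def force_bos_token(cls, v):
        """Always disable add_bos_token with chat completions."""

        return None


print(_BosDemo(add_bos_token=True).add_bos_token)  # prints: None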
@@ -81,7 +81,10 @@ class ModelLoadRequest(BaseModel):
    )

    # Config arguments
    backend: Optional[str] = Field(
        description="Backend to use",
        default="exllamav2",
    )
    max_seq_len: Optional[int] = Field(
        description="Leave this blank to use the model's base sequence length",
        default=None,
main.py
@@ -15,12 +15,11 @@ from common.auth import load_auth_keys
from common.actions import run_subcommand
from common.logger import setup_logger
from common.networking import is_port_in_use
from common.optional_dependencies import dependencies
from common.signals import signal_handler
from common.tabby_config import config
from endpoints.server import start_api

from backends.exllamav2.version import check_exllama_version


async def entrypoint_async():
    """Async entry function for program startup"""
@@ -139,8 +138,21 @@ def entrypoint(
            "UNSAFE: Skipping ExllamaV2 version check.\n"
            "If you aren't a developer, please keep this off!"
        )
    else:
        check_exllama_version()
    elif not dependencies.inference:
        install_message = (
            f"ERROR: Inference dependencies for TabbyAPI are not installed.\n"
            "Please update your environment by running an update script "
            "(update_scripts/"
            f"update_deps.{'bat' if platform.system() == 'Windows' else 'sh'})\n\n"
            "Or you can manually run a requirements update "
            "using the following command:\n\n"
            "For CUDA 12.1:\n"
            "pip install --upgrade .[cu121]\n\n"
            "For ROCm:\n"
            "pip install --upgrade .[amd]\n\n"
        )

        raise SystemExit(install_message)

    # Enable CUDA malloc backend
    if config.developer.cuda_malloc_backend:
@@ -77,6 +77,16 @@ cu121 = [
    "exllamav2 @ https://github.com/turboderp-org/exllamav2/releases/download/v0.2.9/exllamav2-0.2.9+cu128.torch2.7.0-cp311-cp311-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.11'",
    "exllamav2 @ https://github.com/turboderp-org/exllamav2/releases/download/v0.2.9/exllamav2-0.2.9+cu128.torch2.7.0-cp310-cp310-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.10'",

    # Exl3
    "exllamav3 @ https://github.com/turboderp-org/exllamav3/releases/download/v0.0.1/exllamav3-0.0.1+cu128.torch2.7.0-cp313-cp313-win_amd64.whl ; platform_system == 'Windows' and python_version == '3.13'",
    "exllamav3 @ https://github.com/turboderp-org/exllamav3/releases/download/v0.0.1/exllamav3-0.0.1+cu128.torch2.7.0-cp312-cp312-win_amd64.whl ; platform_system == 'Windows' and python_version == '3.12'",
    "exllamav3 @ https://github.com/turboderp-org/exllamav3/releases/download/v0.0.1/exllamav3-0.0.1+cu128.torch2.7.0-cp311-cp311-win_amd64.whl ; platform_system == 'Windows' and python_version == '3.11'",
    "exllamav3 @ https://github.com/turboderp-org/exllamav3/releases/download/v0.0.1/exllamav3-0.0.1+cu128.torch2.7.0-cp310-cp310-win_amd64.whl ; platform_system == 'Windows' and python_version == '3.10'",
    "exllamav3 @ https://github.com/turboderp-org/exllamav3/releases/download/v0.0.1/exllamav3-0.0.1+cu128.torch2.7.0-cp313-cp313-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.13'",
    "exllamav3 @ https://github.com/turboderp-org/exllamav3/releases/download/v0.0.1/exllamav3-0.0.1+cu128.torch2.7.0-cp312-cp312-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.12'",
    "exllamav3 @ https://github.com/turboderp-org/exllamav3/releases/download/v0.0.1/exllamav3-0.0.1+cu128.torch2.7.0-cp311-cp311-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.11'",
    "exllamav3 @ https://github.com/turboderp-org/exllamav3/releases/download/v0.0.1/exllamav3-0.0.1+cu128.torch2.7.0-cp310-cp310-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.10'",

    # Windows FA2 from https://github.com/kingbri1/flash-attention/releases
    "flash_attn @ https://github.com/kingbri1/flash-attention/releases/download/v2.7.4.post1/flash_attn-2.7.4.post1+cu128torch2.7.0cxx11abiFALSE-cp313-cp313-win_amd64.whl ; platform_system == 'Windows' and python_version == '3.13'",
    "flash_attn @ https://github.com/kingbri1/flash-attention/releases/download/v2.7.4.post1/flash_attn-2.7.4.post1+cu128torch2.7.0cxx11abiFALSE-cp312-cp312-win_amd64.whl ; platform_system == 'Windows' and python_version == '3.12'",
@@ -131,14 +131,11 @@ mirostat_eta:

# MARK: Token options
add_bos_token:
  override: true
  override:
  force: false
ban_eos_token:
  override: false
  force: false
skip_special_tokens:
  override: true
  force: false
logit_bias:
  override:
  force: false
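One plausible reading of the table format above: "override" supplies a value when a request leaves a sampler field unset, while "force" applies it unconditionally, so emptying add_bos_token's override defers BOS handling to the model. The application logic is not part of this diff, so the sketch below is hypothetical:

# Hypothetical application of the override table; names and semantics are
# assumptions based only on the YAML shown above.
from typing import Any, Dict


def resolve_sampler_value(name: str, requested: Any, overrides: Dict[str, Dict[str, Any]]) -> Any:
    entry = overrides.get(name, {})

    # "force" wins unconditionally
    if entry.get("force"):
        return entry.get("override")

    # Otherwise "override" only fills in a value the request left unset
    return requested if requested is not None else entry.get("override")


table = {"add_bos_token": {"override": None, "force": False}}
print(resolve_sampler_value("add_bos_token", None, table))  # None: backend decides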