Merge branch 'main' of https://github.com/theroyallab/tabbyapi into inline

This commit is contained in:
kingbri 2024-09-03 18:03:17 -04:00
commit dd30d6592a
25 changed files with 804 additions and 308 deletions

49
.github/workflows/docker-image.yml vendored Normal file
View file

@ -0,0 +1,49 @@
name: Build and publish a Docker image
# Configures this workflow to run every time a change is pushed to the following branches.
on:
push:
branches: ['main']
# Defines two custom environment variables for the workflow. These are used for the Container registry domain, and a name for the Docker image that this workflow builds.
env:
REGISTRY: ghcr.io
IMAGE_NAME: ${{ github.repository }}
# There is a single job in this workflow. It's configured to run on the latest available version of Ubuntu.
jobs:
build-and-push-image:
runs-on: ubuntu-latest
# Sets the permissions granted to the `GITHUB_TOKEN` for the actions in this job.
permissions:
contents: read
packages: write
#
steps:
- name: Checkout repository
uses: actions/checkout@v4
# Uses the `docker/login-action` action to log in to the Container registry registry using the account and password that will publish the packages. Once published, the packages are scoped to the account defined here.
- name: Log in to the Container registry
uses: docker/login-action@65b78e6e13532edd9afa3aa52ac7964289d1a9c1
with:
registry: ${{ env.REGISTRY }}
username: ${{ github.actor }}
password: ${{ secrets.GITHUB_TOKEN }}
# This step uses [docker/metadata-action](https://github.com/docker/metadata-action#about) to extract tags and labels that will be applied to the specified image. The `id` "meta" allows the output of this step to be referenced in a subsequent step. The `images` value provides the base name for the tags and labels.
- name: Extract metadata (tags, labels) for Docker
id: meta
uses: docker/metadata-action@9ec57ed1fcdbf14dcef7dfbe97b2010124a938b7
with:
images: ${{ env.REGISTRY }}/${{ env.IMAGE_NAME }}
# This is needed to add the `latest` tag to the newly built image
tags: type=raw,value=latest
# This step uses the `docker/build-push-action` action to build the image, based on your repository's `Dockerfile`. If the build succeeds, it pushes the image to GitHub Packages.
# It uses the `context` parameter to define the build's context as the set of files located in the specified path. For more information, see "[Usage](https://github.com/docker/build-push-action#usage)" in the README of the `docker/build-push-action` repository.
# It uses the `tags` and `labels` parameters to tag and label the image with the output from the "meta" step.
- name: Build and push Docker image
uses: docker/build-push-action@5cd11c3a4ced054e52742c5fd54dca954e0edd85
with:
file: ./docker/Dockerfile
# target: production # Might need it in the future when we have multistage builds
push: true
tags: ${{ steps.meta.outputs.tags }}
labels: ${{ steps.meta.outputs.labels }}

1
.gitignore vendored
View file

@ -192,6 +192,7 @@ templates/*
!templates/place_your_templates_here.txt
!templates/alpaca.jinja
!templates/chatml.jinja
!templates/chatml_with_headers_tool_calling.jinja
# Sampler overrides folder
sampler_overrides/*

View file

@ -64,6 +64,7 @@ Read the [Wiki](https://github.com/theroyallab/tabbyAPI/wiki/1.-Getting-Started)
- Utilizes modern python paradigms
- Continuous batching engine using paged attention
- Fast classifer-free guidance
- OAI style tool/function calling
And much more. If something is missing here, PR it in!

2
api_tokens_sample.yml Normal file
View file

@ -0,0 +1,2 @@
api_key: # Insert api key here
admin_key: # Insert admin key here

View file

@ -10,6 +10,7 @@ import uuid
from exllamav2 import (
ExLlamaV2,
ExLlamaV2Config,
ExLlamaV2CacheBase,
ExLlamaV2Cache,
ExLlamaV2Cache_Q4,
ExLlamaV2Cache_Q6,
@ -26,6 +27,8 @@ from itertools import zip_longest
from loguru import logger
from typing import List, Optional, Union
import yaml
from backends.exllamav2.grammar import (
ExLlamaV2Grammar,
clear_grammar_func_cache,
@ -50,10 +53,22 @@ from common.templating import (
from common.transformers_utils import GenerationConfig, HuggingFaceConfig
from common.utils import coalesce, unwrap
# Dynamic imports
try:
from exllamav2 import ExLlamaV2Cache_TP
has_tp = True
except ImportError:
has_tp = False
class ExllamaV2Container:
"""The model container class for ExLlamaV2 models."""
# Model directories
model_dir: pathlib.Path = pathlib.Path("models")
draft_model_dir: pathlib.Path = pathlib.Path("models")
# Exl2 vars
config: Optional[ExLlamaV2Config] = None
draft_config: Optional[ExLlamaV2Config] = None
@ -78,6 +93,7 @@ class ExllamaV2Container:
gpu_split: Optional[list] = None
gpu_split_auto: bool = True
autosplit_reserve: List[float] = [96 * 1024**2]
use_tp: bool = False
# Load state
model_is_loading: bool = False
@ -91,105 +107,129 @@ class ExllamaV2Container:
def __init__(self, model_directory: pathlib.Path, quiet=False, **kwargs):
"""
Create model container
Primary initializer for model container.
Args:
model_dir (int): Model directory containing config.json,
tokenizer.model etc.
quiet (bool): Suppress console output
load_progress_callback (function, optional): A function to call for
each module loaded. Prototype:
def progress(loaded_modules: int, total_modules: int,
loading_draft: bool)
**kwargs:
`cache_mode` (str): Sets cache mode: "FP16"/"Q8"/"Q6"/"Q4"
(default: "FP16")
'max_seq_len' (int): Override model's default max sequence
length (default: 4096)
'cache_size' (int): Num of tokens to allocate space for in the k/v cache
(default: max_seq_len)
'rope_scale' (float): Set RoPE scaling factor for model
(default: 1.0)
'rope_alpha' (float): Set RoPE alpha (NTK) factor for model
(default: 1.0)
'prompt_template' (str): Manually sets the prompt template for
this model (default: None)
'chunk_size' (int): Sets the maximum chunk size for the model
(default: 2048)
Inferencing in chunks reduces overall VRAM overhead by
processing very long sequences in smaller batches. This
limits the size of temporary buffers needed for the hidden
state and attention weights.
'draft_model_dir' (str): Draft model directory
'draft_rope_scale' (float): Set RoPE scaling factor for draft
model (default: 1.0)
'draft_rope_alpha' (float): RoPE alpha (NTK) factor for draft
model. By default, the draft model's alpha value is
calculated automatically to scale to the size of the
full model.
'draft_cache_mode' (str): Sets draft cache mode: "FP16"/"Q8"/"Q6"/"Q4"
(default: "FP16")
'lora_dir' (str): LoRA directory
'loras' (list[dict]): List of loras to be loaded, consisting of
'name' and 'scaling'
'gpu_split_auto' (bool): Automatically split model across
available devices (default: True)
'gpu_split' (list[float]): Allocation for weights and (some)
tensors, per device
Kwargs are located in config_sample.yml
"""
self.quiet = quiet
self.cache_mode = unwrap(kwargs.get("cache_mode"), "FP16")
# Turn off GPU split if the user is using 1 GPU
gpu_count = torch.cuda.device_count()
gpu_split_auto = unwrap(kwargs.get("gpu_split_auto"), True)
gpu_device_list = list(range(0, gpu_count))
if gpu_count > 1 and gpu_split_auto:
# Auto GPU split parameters
self.gpu_split_auto = gpu_split_auto
autosplit_reserve_megabytes = unwrap(kwargs.get("autosplit_reserve"), [96])
self.autosplit_reserve = [
int(math.ceil(value * 1024**2)) for value in autosplit_reserve_megabytes
]
elif gpu_count > 1:
# Manual GPU split
self.gpu_split = kwargs.get("gpu_split")
self.gpu_split_auto = False
gpu_device_list = [
device_idx
for device_idx, memory in enumerate(self.gpu_split)
if memory > 0
]
else:
# One GPU setup
self.gpu_split_auto = False
logger.info("Disabling GPU split because one GPU is in use.")
# Initialize config
self.config = ExLlamaV2Config()
self.model_dir = model_directory
self.config.model_dir = str(model_directory.resolve())
# Make the max seq len 4096 before preparing the config
# This is a better default than 2048
self.config.max_seq_len = 4096
# Hardcode max output length to 16
self.config.max_output_len = 16
self.config.prepare()
# Check if the model arch is compatible with various exl2 features
try:
self.config.arch_compat_overrides()
except AttributeError:
pass
self.config.arch_compat_overrides()
# Prepare the draft model config if necessary
draft_args = unwrap(kwargs.get("draft"), {})
draft_model_name = draft_args.get("draft_model_name")
enable_draft = draft_args and draft_model_name
# Always disable draft if params are incorrectly configured
if draft_args and draft_model_name is None:
logger.warning(
"Draft model is disabled because a model name "
"wasn't provided. Please check your config.yml!"
)
enable_draft = False
if enable_draft:
self.draft_config = ExLlamaV2Config()
self.draft_config.no_flash_attn = self.config.no_flash_attn
draft_model_path = pathlib.Path(
unwrap(draft_args.get("draft_model_dir"), "models")
)
draft_model_path = draft_model_path / draft_model_name
self.draft_model_dir = draft_model_path
self.draft_config.model_dir = str(draft_model_path.resolve())
self.draft_config.prepare()
# Create the hf_config
self.hf_config = HuggingFaceConfig.from_file(model_directory)
# Load generation config overrides
generation_config_path = model_directory / "generation_config.json"
if generation_config_path.exists():
try:
self.generation_config = GenerationConfig.from_file(
generation_config_path.parent
)
except Exception:
logger.error(traceback.format_exc())
logger.warning(
"Skipping generation config load because of an unexpected error."
)
# Apply a model's config overrides while respecting user settings
kwargs = self.set_model_overrides(**kwargs)
# MARK: User configuration
# Get cache mode
self.cache_mode = unwrap(kwargs.get("cache_mode"), "FP16")
# Turn off GPU split if the user is using 1 GPU
gpu_count = torch.cuda.device_count()
gpu_split_auto = unwrap(kwargs.get("gpu_split_auto"), True)
use_tp = unwrap(kwargs.get("tensor_parallel"), False)
gpu_split = kwargs.get("gpu_split")
gpu_device_list = list(range(0, gpu_count))
# Set GPU split options
if gpu_count == 1:
self.gpu_split_auto = False
logger.info("Disabling GPU split because one GPU is in use.")
else:
# Set tensor parallel
if use_tp:
if has_tp:
self.use_tp = True
# TP has its own autosplit loader
self.gpu_split_auto = False
else:
# TODO: Remove conditional with exl2 v0.1.9 release
logger.warning(
"Tensor parallelism is not supported in the "
"current ExllamaV2 version."
)
# Enable manual GPU split if provided
if gpu_split:
self.gpu_split_auto = False
self.gpu_split = gpu_split
gpu_device_list = [
device_idx
for device_idx, memory in enumerate(self.gpu_split)
if memory > 0
]
elif gpu_split_auto and not self.use_tp:
# Otherwise fallback to autosplit settings
self.gpu_split_auto = gpu_split_auto
autosplit_reserve_megabytes = unwrap(
kwargs.get("autosplit_reserve"), [96]
)
# Reserve VRAM for each GPU
self.autosplit_reserve = [
int(math.ceil(value * 1024**2))
for value in autosplit_reserve_megabytes
]
# Hardcode max output length to 16
self.config.max_output_len = 16
# Then override the base_seq_len if present
override_base_seq_len = kwargs.get("override_base_seq_len")
if override_base_seq_len:
@ -209,10 +249,13 @@ class ExllamaV2Container:
kwargs.get("rope_scale"), self.config.scale_pos_emb
)
# Automatically calculate rope alpha
self.config.scale_alpha_value = unwrap(
kwargs.get("rope_alpha"), self.calculate_rope_alpha(base_seq_len)
)
# Sets rope alpha value.
# Automatically calculate if unset or defined as an "auto" literal.
rope_alpha = unwrap(kwargs.get("rope_alpha"), "auto")
if rope_alpha == "auto":
self.config.scale_alpha_value = self.calculate_rope_alpha(base_seq_len)
else:
self.config.scale_alpha_value = rope_alpha
# Enable fasttensors loading if present
self.config.fasttensors = unwrap(kwargs.get("fasttensors"), False)
@ -275,19 +318,6 @@ class ExllamaV2Container:
else:
self.cache_size = self.config.max_seq_len
# Load generation config overrides
generation_config_path = model_directory / "generation_config.json"
if generation_config_path.exists():
try:
self.generation_config = GenerationConfig.from_file(
generation_config_path.parent
)
except Exception:
logger.error(traceback.format_exc())
logger.warning(
"Skipping generation config load because of an unexpected error."
)
# Try to set prompt template
self.prompt_template = self.find_prompt_template(
kwargs.get("prompt_template"), model_directory
@ -315,47 +345,55 @@ class ExllamaV2Container:
self.config.max_input_len = chunk_size
self.config.max_attention_size = chunk_size**2
draft_args = unwrap(kwargs.get("draft"), {})
draft_model_name = draft_args.get("draft_model_name")
enable_draft = draft_args and draft_model_name
# Always disable draft if params are incorrectly configured
if draft_args and draft_model_name is None:
logger.warning(
"Draft model is disabled because a model name "
"wasn't provided. Please check your config.yml!"
)
enable_draft = False
# Set user-configured draft model values
if enable_draft:
self.draft_config = ExLlamaV2Config()
self.draft_config.no_flash_attn = self.config.no_flash_attn
draft_model_path = pathlib.Path(
unwrap(draft_args.get("draft_model_dir"), "models")
)
draft_model_path = draft_model_path / draft_model_name
# Fetch from the updated kwargs
draft_args = unwrap(kwargs.get("draft"), {})
self.draft_config.model_dir = str(draft_model_path.resolve())
self.draft_config.prepare()
self.draft_config.max_seq_len = self.config.max_seq_len
self.draft_config.scale_pos_emb = unwrap(
draft_args.get("draft_rope_scale"), 1.0
)
# Automatically calculate draft rope alpha
self.draft_config.scale_alpha_value = unwrap(
draft_args.get("draft_rope_alpha"),
self.calculate_rope_alpha(self.draft_config.max_seq_len),
)
self.draft_config.max_seq_len = self.config.max_seq_len
# Set draft rope alpha. Follows same behavior as model rope alpha.
draft_rope_alpha = unwrap(draft_args.get("draft_rope_alpha"), "auto")
if draft_rope_alpha == "auto":
self.draft_config.scale_alpha_value = self.calculate_rope_alpha(
self.draft_config.max_seq_len
)
else:
self.draft_config.scale_alpha_value = draft_rope_alpha
# Set draft cache mode
self.draft_cache_mode = unwrap(draft_args.get("draft_cache_mode"), "FP16")
if chunk_size:
self.draft_config.max_input_len = chunk_size
self.draft_config.max_attention_size = chunk_size**2
def set_model_overrides(self, **kwargs):
"""Sets overrides from a model folder's config yaml."""
override_config_path = self.model_dir / "tabby_config.yml"
if not override_config_path.exists():
return kwargs
with open(override_config_path, "r", encoding="utf8") as override_config_file:
override_args = unwrap(yaml.safe_load(override_config_file), {})
# Merge draft overrides beforehand
draft_override_args = unwrap(override_args.get("draft"), {})
if self.draft_config and draft_override_args:
kwargs["draft"] = {**draft_override_args, **kwargs.get("draft")}
# Merge the override and model kwargs
merged_kwargs = {**override_args, **kwargs}
return merged_kwargs
def find_prompt_template(self, prompt_template_name, model_directory):
"""Tries to find a prompt template using various methods"""
"""Tries to find a prompt template using various methods."""
logger.info("Attempting to load a prompt template if present.")
@ -397,6 +435,7 @@ class ExllamaV2Container:
def calculate_rope_alpha(self, base_seq_len):
"""Calculate the rope alpha value for a given sequence length."""
ratio = self.config.max_seq_len / base_seq_len
# Default to a 1 alpha if the sequence length is ever less
@ -407,20 +446,9 @@ class ExllamaV2Container:
alpha = -0.13436 + 0.80541 * ratio + 0.28833 * ratio**2
return alpha
def get_model_path(self, is_draft: bool = False):
"""Get the path for this model."""
if is_draft and not self.draft_config:
return None
model_path = pathlib.Path(
self.draft_config.model_dir if is_draft else self.config.model_dir
)
return model_path
def get_model_parameters(self):
model_params = {
"name": self.get_model_path().name,
"name": self.model_dir.name,
"rope_scale": self.config.scale_pos_emb,
"rope_alpha": self.config.scale_alpha_value,
"max_seq_len": self.config.max_seq_len,
@ -435,7 +463,7 @@ class ExllamaV2Container:
if self.draft_config:
draft_model_params = {
"name": self.get_model_path(is_draft=True).name,
"name": self.draft_model_dir.name,
"rope_scale": self.draft_config.scale_pos_emb,
"rope_alpha": self.draft_config.scale_alpha_value,
"max_seq_len": self.draft_config.max_seq_len,
@ -473,7 +501,9 @@ class ExllamaV2Container:
Args:
progress_callback (function, optional): A function to call for each
module loaded. Prototype:
module loaded.
Prototype:
def progress(loaded_modules: int, total_modules: int)
"""
@ -518,11 +548,13 @@ class ExllamaV2Container:
@torch.inference_mode()
def load_model_sync(self, progress_callback=None):
"""
Load model, generator function
Synchronous generator for loading.
Args:
progress_callback (function, optional): A function to call for each
module loaded. Prototype:
module loaded.
Prototype:
def progress(loaded_modules: int, total_modules: int)
Runs under a shared inference mode context.
@ -548,30 +580,15 @@ class ExllamaV2Container:
if not self.quiet:
logger.info("Loading draft model: " + self.draft_config.model_dir)
if self.draft_cache_mode == "Q4":
self.draft_cache = ExLlamaV2Cache_Q4(
self.draft_model,
max_seq_len=self.cache_size,
lazy=True,
)
elif self.draft_cache_mode == "Q6":
self.draft_cache = ExLlamaV2Cache_Q6(
self.draft_model,
max_seq_len=self.cache_size,
lazy=True,
)
elif self.draft_cache_mode == "Q8":
self.draft_cache = ExLlamaV2Cache_Q8(
self.draft_model,
max_seq_len=self.cache_size,
lazy=True,
)
else:
self.draft_cache = ExLlamaV2Cache(
self.draft_model,
max_seq_len=self.cache_size,
lazy=True,
)
# Draft uses the autosplit loader, so create a cache that reflects this
draft_cache_class = self.get_cache_class(self.draft_cache_mode)
self.draft_cache = self.create_cache(
cache_class=draft_cache_class,
autosplit=True,
use_tp=False,
model=self.draft_model,
)
for value in self.draft_model.load_autosplit_gen(
self.draft_cache,
reserve_vram=autosplit_reserve,
@ -589,9 +606,23 @@ class ExllamaV2Container:
if not self.quiet:
logger.info("Loading model: " + self.config.model_dir)
# Get class of the model cache
cache_class = self.get_cache_class(self.cache_mode)
# Load model with manual split
# Entrypoint for single GPU users
if not self.gpu_split_auto:
if self.use_tp:
logger.info("Loading with tensor parallel")
for value in self.model.load_tp_gen(
self.gpu_split,
callback_gen=progress_callback,
expect_cache_base=cache_class,
expect_cache_tokens=self.cache_size,
):
if value:
yield value
elif not self.gpu_split_auto:
logger.info("Loading with a manual GPU split (or a one GPU setup)")
for value in self.model.load_gen(
@ -601,37 +632,16 @@ class ExllamaV2Container:
if value:
yield value
if self.cache_mode == "Q4":
self.cache = ExLlamaV2Cache_Q4(
self.model,
max_seq_len=self.cache_size,
lazy=self.gpu_split_auto,
batch_size=1,
)
elif self.cache_mode == "Q6":
self.cache = ExLlamaV2Cache_Q6(
self.model,
max_seq_len=self.cache_size,
lazy=self.gpu_split_auto,
batch_size=1,
)
elif self.cache_mode == "Q8":
self.cache = ExLlamaV2Cache_Q8(
self.model,
max_seq_len=self.cache_size,
lazy=self.gpu_split_auto,
batch_size=1,
)
else:
self.cache = ExLlamaV2Cache(
self.model,
max_seq_len=self.cache_size,
lazy=self.gpu_split_auto,
batch_size=1,
)
# Create the model cache
self.cache = self.create_cache(
cache_class=cache_class,
autosplit=self.gpu_split_auto,
use_tp=self.use_tp,
model=self.model,
)
# Load model with autosplit
if self.gpu_split_auto:
# Load model with autosplit (without TP)
if self.gpu_split_auto and not self.use_tp:
logger.info("Loading with autosplit")
for value in self.model.load_autosplit_gen(
@ -647,7 +657,47 @@ class ExllamaV2Container:
input_ids = torch.zeros((1, self.config.max_input_len), dtype=torch.long)
self.model.forward(input_ids, cache=self.cache, preprocess_only=True)
# TODO: Maybe make a wrapper class with an ID instead of a utility function
def get_cache_class(self, cache_mode: str):
"""Utility function to get a cache class based on user preference."""
match cache_mode:
case "Q4":
return ExLlamaV2Cache_Q4
case "Q6":
return ExLlamaV2Cache_Q6
case "Q8":
return ExLlamaV2Cache_Q8
case _:
return ExLlamaV2Cache
def create_cache(
self,
cache_class: ExLlamaV2CacheBase,
autosplit: bool,
use_tp: bool,
model: ExLlamaV2,
):
"""Utility function to create a model cache."""
if has_tp and use_tp:
return ExLlamaV2Cache_TP(
model,
base=cache_class,
max_seq_len=self.cache_size,
batch_size=1,
)
else:
return cache_class(
model,
max_seq_len=self.cache_size,
lazy=autosplit,
batch_size=1,
)
async def create_generator(self):
"""Create and save a Exllama generator class."""
try:
# Don't acquire locks unless a model is loaded
if self.model_loaded:
@ -681,9 +731,7 @@ class ExllamaV2Container:
return unwrap(self.generator.generator.current_loras, [])
async def load_loras(self, lora_directory: pathlib.Path, **kwargs):
"""
Load loras
"""
"""Load loras."""
loras = unwrap(kwargs.get("loras"), [])
@ -730,9 +778,7 @@ class ExllamaV2Container:
self.load_condition.notify_all()
async def unload(self, loras_only: bool = False, **kwargs):
"""
Free all VRAM resources used by this model
"""
"""Free all VRAM resources used by the model (and loras)."""
# Shutdown immediately unloads and bypasses all locks
do_shutdown = kwargs.get("shutdown")
@ -789,7 +835,7 @@ class ExllamaV2Container:
self.load_condition.notify_all()
def encode_tokens(self, text: str, **kwargs):
"""Wrapper to encode tokens from a text string"""
"""Wrapper to encode tokens from a text string."""
return (
self.tokenizer.encode(
@ -824,7 +870,7 @@ class ExllamaV2Container:
def get_logprobs(self, token_ids: torch.Tensor, token_probs: torch.Tensor):
top_tokens = [
self.tokenizer.extended_id_to_piece.get(
index, self.tokenizer.id_to_piece[index]
index, self.tokenizer.get_id_to_piece_list(True)[index]
)
for index in token_ids.flatten().tolist()
]
@ -841,7 +887,7 @@ class ExllamaV2Container:
async def generate(
self, prompt: str, request_id: str, abort_event: asyncio.Event = None, **kwargs
):
"""Generate a response to a prompt"""
"""Generate a response to a prompt."""
generations = []
async for generation in self.generate_gen(
prompt, request_id, abort_event, **kwargs
@ -852,6 +898,7 @@ class ExllamaV2Container:
"text": "",
"prompt_tokens": 0,
"generation_tokens": 0,
"tool_calls": None,
"offset": [],
"token_probs": {},
"logprobs": [],
@ -864,6 +911,7 @@ class ExllamaV2Container:
joined_generation["finish_reason"] = finish_reason_gen.get(
"finish_reason"
)
joined_generation["stop_str"] = finish_reason_gen.get("stop_str")
else:
joined_generation["finish_reason"] = "stop"
@ -890,7 +938,11 @@ class ExllamaV2Container:
return joined_generation
def check_unsupported_settings(self, **kwargs):
"""Check and warn the user if a sampler is unsupported. Meant for dev wheels!"""
"""
Check and warn the user if a sampler is unsupported.
Meant for dev wheels!
"""
return kwargs
@ -1068,6 +1120,15 @@ class ExllamaV2Container:
gen_settings.top_p = 0
gen_settings.typical = 0
logger.warning(
"".join(
[
"Temperature is set to 0. Overriding temp, ",
"top_k, top_p, and typical to 1.0, 1, 0, and 0.",
]
)
)
# Store the gen settings for logging purposes
gen_settings_log_dict = vars(gen_settings)
@ -1076,6 +1137,11 @@ class ExllamaV2Container:
if banned_tokens:
gen_settings.disallow_tokens(self.tokenizer, banned_tokens)
# Set allowed tokens
allowed_tokens = unwrap(kwargs.get("allowed_tokens"), [])
if allowed_tokens:
gen_settings.allow_tokens(self.tokenizer, allowed_tokens)
# Set logit bias
if logit_bias:
# Create a vocab tensor if it doesn't exist for token biasing
@ -1088,7 +1154,7 @@ class ExllamaV2Container:
# Map logits to the tensor with their biases
for token_id, bias in logit_bias.items():
if 0 <= token_id < len(self.tokenizer.id_to_piece):
if 0 <= token_id < len(self.tokenizer.get_id_to_piece_list(True)):
gen_settings.token_bias[token_id] = bias
else:
logger.warning(
@ -1140,8 +1206,12 @@ class ExllamaV2Container:
# This is an inverse of skip_special_tokens
decode_special_tokens = unwrap(not kwargs.get("skip_special_tokens"), False)
# Log prompt to console
log_prompt(prompt, request_id, negative_prompt)
# Log prompt to console. Add the BOS token if specified
log_prompt(
f"{self.tokenizer.bos_token if add_bos_token else ''}{prompt}",
request_id,
negative_prompt,
)
# Create and add a new job
# Don't use the request ID here as there can be multiple jobs per request
@ -1227,9 +1297,17 @@ class ExllamaV2Container:
log_response(request_id, full_response)
eos_reason = result.get("eos_reason")
finish_reason = (
"length" if eos_reason == "max_new_tokens" else "stop"
)
stop_str = None
if eos_reason == "max_new_tokens":
finish_reason = "length"
else:
finish_reason = "stop"
# Grab stop string if stop was the reason
if eos_reason == "stop_token":
stop_str = result.get("eos_triggering_token_str")
elif eos_reason == "stop_string":
stop_str = result.get("eos_triggering_string")
# Save the final result for metrics logging
metrics_result = result
@ -1239,6 +1317,7 @@ class ExllamaV2Container:
"prompt_tokens": generation.get("prompt_tokens"),
"generated_tokens": generation.get("generated_tokens"),
"finish_reason": finish_reason,
"stop_str": stop_str,
}
yield generation
@ -1277,6 +1356,7 @@ class ExllamaV2Container:
logprobs=request_logprobs,
stop_conditions=stop_conditions,
banned_tokens=banned_tokens,
allowed_tokens=allowed_tokens,
banned_strings=banned_strings,
logit_bias=logit_bias,
filters=grammar_handler.filters,

View file

@ -8,7 +8,7 @@ from loguru import logger
def check_exllama_version():
"""Verifies the exllama version"""
required_version = version.parse("0.1.7")
required_version = version.parse("0.1.9")
current_version = version.parse(package_version("exllamav2").split("+")[0])
unsupported_message = (

View file

@ -13,6 +13,24 @@ def str_to_bool(value):
raise ValueError(f"{value} is not a valid boolean value")
def argument_with_auto(value):
"""
Argparse type wrapper for any argument that has an automatic option.
Ex. rope_alpha
"""
if value == "auto":
return "auto"
try:
return float(value)
except ValueError as ex:
raise argparse.ArgumentTypeError(
'This argument only takes a type of float or "auto"'
) from ex
def init_argparser():
"""Creates an argument parser that any function can use"""
@ -107,6 +125,11 @@ def add_model_args(parser: argparse.ArgumentParser):
type=str_to_bool,
help="Overrides base model context length",
)
model_group.add_argument(
"--tensor-parallel",
type=str_to_bool,
help="Use tensor parallelism to load models",
)
model_group.add_argument(
"--gpu-split-auto",
type=str_to_bool,
@ -128,7 +151,11 @@ def add_model_args(parser: argparse.ArgumentParser):
model_group.add_argument(
"--rope-scale", type=float, help="Sets rope_scale or compress_pos_emb"
)
model_group.add_argument("--rope-alpha", type=float, help="Sets rope_alpha for NTK")
model_group.add_argument(
"--rope-alpha",
type=argument_with_auto,
help="Sets rope_alpha for NTK",
)
model_group.add_argument(
"--cache-mode",
type=str,

View file

@ -57,7 +57,7 @@ async def load_model_gen(model_path: pathlib.Path, **kwargs):
# Check if the model is already loaded
if container and container.model:
loaded_model_name = container.get_model_path().name
loaded_model_name = container.model_dir.name
if loaded_model_name == model_path.name and container.model_loaded:
raise ValueError(
@ -149,7 +149,8 @@ async def unload_embedding_model():
embeddings_container = None
def get_config_default(key: str, fallback=None, model_type: str = "model"):
# FIXME: Maybe make this a one-time function instead of a dynamic default
def get_config_default(key: str, model_type: str = "model"):
"""Fetches a default value from model config if allowed by the user."""
model_config = config.model_config()
@ -162,14 +163,12 @@ def get_config_default(key: str, fallback=None, model_type: str = "model"):
# Is this a draft model load parameter?
if model_type == "draft":
draft_config = config.draft_model_config()
return unwrap(draft_config.get(key), fallback)
return draft_config.get(key)
elif model_type == "embedding":
embeddings_config = config.embeddings_config()
return unwrap(embeddings_config.get(key), fallback)
return embeddings_config.get(key)
else:
return unwrap(model_config.get(key), fallback)
else:
return fallback
return model_config.get(key)
async def check_model_container():

View file

@ -33,7 +33,7 @@ class BaseSamplerRequest(BaseModel):
examples=[512],
)
stop: Optional[Union[str, List[str]]] = Field(
stop: Optional[Union[str, List[Union[str, int]]]] = Field(
default_factory=lambda: get_default_sampler_value("stop", []),
validation_alias=AliasChoices("stop", "stop_sequence"),
description="Aliases: stop_sequence",
@ -50,6 +50,13 @@ class BaseSamplerRequest(BaseModel):
examples=[[128, 330]],
)
allowed_tokens: Optional[Union[List[int], str]] = Field(
default_factory=lambda: get_default_sampler_value("allowed_tokens", []),
validation_alias=AliasChoices("allowed_tokens", "allowed_token_ids"),
description="Aliases: allowed_token_ids",
examples=[[128, 330]],
)
token_healing: Optional[bool] = Field(
default_factory=lambda: get_default_sampler_value("token_healing", False)
)
@ -287,12 +294,17 @@ class BaseSamplerRequest(BaseModel):
if self.banned_strings and isinstance(self.banned_strings, str):
self.banned_strings = [self.banned_strings]
# Convert string banned tokens to an integer list
# Convert string banned and allowed tokens to an integer list
if self.banned_tokens and isinstance(self.banned_tokens, str):
self.banned_tokens = [
int(x) for x in self.banned_tokens.split(",") if x.isdigit()
]
if self.allowed_tokens and isinstance(self.allowed_tokens, str):
self.allowed_tokens = [
int(x) for x in self.allowed_tokens.split(",") if x.isdigit()
]
gen_params = {
"max_tokens": self.max_tokens,
"min_tokens": self.min_tokens,
@ -305,6 +317,7 @@ class BaseSamplerRequest(BaseModel):
"token_healing": self.token_healing,
"logit_bias": self.logit_bias,
"banned_tokens": self.banned_tokens,
"allowed_tokens": self.allowed_tokens,
"temperature": self.temperature,
"temperature_last": self.temperature_last,
"min_temp": self.min_temp,

View file

@ -3,7 +3,7 @@
import json
import pathlib
from importlib.metadata import version as package_version
from typing import Optional
from typing import List, Optional
from jinja2 import Template, TemplateError
from jinja2.sandbox import ImmutableSandboxedEnvironment
from loguru import logger
@ -18,6 +18,13 @@ class TemplateLoadError(Exception):
pass
class TemplateMetadata:
"""Represents the parsed metadata from a template."""
stop_strings: List[str] = []
tool_starts: List[str] = []
class PromptTemplate:
"""A template for chat completion prompts."""
@ -25,27 +32,48 @@ class PromptTemplate:
raw_template: str
template: Template
environment: ImmutableSandboxedEnvironment = ImmutableSandboxedEnvironment(
trim_blocks=True, lstrip_blocks=True
trim_blocks=True, lstrip_blocks=True, enable_async=True
)
metadata: Optional[TemplateMetadata] = None
def stop_strings(self, template_vars: dict):
"""Appends extra stop strings if present in a chat template."""
async def extract_metadata(self, template_vars: dict):
"""
Returns deserialized template metadata from a chat template.
extra_stop_strings = []
template_module = self.template.make_module(template_vars)
NOTE: Requires all template vars to be passed in since the template
is run once to make a module and errors can result.
"""
# No need to extract new metadata if it already exists
# This might be removed if stored metadata becomes arbitrary
if self.metadata:
return self.metadata
template_metadata = TemplateMetadata()
template_module = await self.template.make_module_async(template_vars)
if hasattr(template_module, "stop_strings"):
if isinstance(template_module.stop_strings, list):
extra_stop_strings += template_module.stop_strings
template_metadata.stop_strings += template_module.stop_strings
else:
logger.warning(
"Skipping append of stopping strings from chat template "
"because stop_strings isn't a list."
)
return extra_stop_strings
if hasattr(template_module, "tool_start"):
if isinstance(template_module.tool_start, str):
template_metadata.tool_starts.append(template_module.tool_start)
def render(self, template_vars: dict):
if hasattr(template_module, "tool_start_token"):
if isinstance(template_module.tool_start_token, int):
template_metadata.tool_starts.append(template_module.tool_start_token)
self.metadata = template_metadata
return template_metadata
async def render(self, template_vars: dict):
"""Get a prompt from a template and a list of messages."""
if version.parse(package_version("jinja2")) < version.parse("3.0.0"):
raise ImportError(
@ -55,10 +83,9 @@ class PromptTemplate:
"pip install --upgrade jinja2"
)
rendered_template = self.template.render(**template_vars)
template_stop_strings = self.stop_strings(template_vars)
rendered_template = await self.template.render_async(**template_vars)
return rendered_template, template_stop_strings
return rendered_template
def compile(self, template_str: str):
"""Compiles and stores a jinja2 template"""

View file

@ -109,6 +109,11 @@ model:
# Only use this if the model's base sequence length in config.json is incorrect (ex. Mistral 7B)
#override_base_seq_len:
# Load model with tensor parallelism
# If a GPU split isn't provided, the TP loader will fallback to autosplit
# Enabling ignores the gpu_split_auto and autosplit_reserve values
#tensor_parallel: False
# Automatically allocate resources to GPUs (default: True)
# NOTE: Not parsed for single GPU users
#gpu_split_auto: True
@ -118,6 +123,7 @@ model:
#autosplit_reserve: [96]
# An integer array of GBs of vram to split between GPUs (default: [])
# Used with tensor parallelism
# NOTE: Not parsed for single GPU users
#gpu_split: [20.6, 24]
@ -129,7 +135,8 @@ model:
# Rope alpha (default: 1.0)
# Same thing as alpha_value
# Leave blank to automatically calculate alpha
# Set to "auto" to automatically calculate
# Leave blank to pull the value from the model
#rope_alpha: 1.0
# Enable different cache modes for VRAM savings (slight performance hit).

View file

@ -1,4 +1,6 @@
models/
loras/
.ruff_cache/
**/__pycache__/
**/__pycache__/
config.yml
api_tokens.yml

View file

@ -1,12 +1,5 @@
# Use an official CUDA runtime with Ubuntu as a parent image
FROM nvidia/cuda:12.4.1-runtime-ubuntu22.04
ARG GIT_REPO=https://github.com/theroyallab/tabbyAPI
ARG DO_PULL=true
ENV DO_PULL $DO_PULL
# Set the working directory in the container
WORKDIR /app
FROM nvidia/cuda:12.1.0-runtime-ubuntu22.04
# Install system dependencies
RUN apt-get update && apt-get install -y --no-install-recommends \
@ -15,20 +8,14 @@ RUN apt-get update && apt-get install -y --no-install-recommends \
ca-certificates \
python3.11 \
python3-pip \
git \
&& rm -rf /var/lib/apt/lists/*
# Update repo
RUN if [ ${DO_PULL} ]; then \
git init && \
git remote add origin $GIT_REPO && \
git fetch origin && \
git pull origin main && \
echo "Pull finished"; fi
# Upgrade pip
RUN pip3 install --no-cache-dir --upgrade pip
# Set the working directory in the container
WORKDIR /app
# Get requirements
COPY pyproject.toml .
@ -38,9 +25,6 @@ RUN pip3 install --no-cache-dir .[cu121]
# Copy the current directory contents into the container
COPY . .
# Create a config.yml
COPY config_sample.yml config.yml
# Make port 5000 available to the world outside this container
EXPOSE 5000

View file

@ -8,11 +8,18 @@ services:
- DO_PULL=true
ports:
- "5000:5000"
healthcheck:
test: ["CMD", "curl", "-f", "http://127.0.0.1:5000/health"]
interval: 30s
timeout: 10s
retries: 3
environment:
- NAME=TabbyAPI
- NVIDIA_VISIBLE_DEVICES=all
volumes:
- ./models:/app/models
- ./models:/app/models # Change me
# - /path/to/config.yml:/app/config.yml # Change me
# - /path/to/api_tokens.yml:/app/api_tokens.yml # Change me
deploy:
resources:
reservations:

View file

@ -54,7 +54,7 @@ async def completion_request(
If stream = true, this returns an SSE stream.
"""
model_path = model.container.get_model_path()
model_path = model.container.model_dir
if isinstance(data.prompt, list):
data.prompt = "\n".join(data.prompt)
@ -153,12 +153,12 @@ async def chat_completion_request(
raise HTTPException(422, error_message)
model_path = model.container.get_model_path()
model_path = model.container.model_dir
if isinstance(data.messages, str):
prompt = data.messages
else:
prompt = format_prompt_with_template(data)
prompt = await format_prompt_with_template(data)
# Set an empty JSON schema if the request wants a JSON response
if data.response_format.type == "json":

View file

@ -1,9 +1,11 @@
from pydantic import BaseModel, Field
from pydantic.json_schema import SkipJsonSchema
from time import time
from typing import Union, List, Optional, Dict
from uuid import uuid4
from endpoints.OAI.types.common import UsageStats, CommonCompletionRequest
from endpoints.OAI.types.tools import ToolSpec, ToolCall, tool_call_schema
class ChatCompletionLogprob(BaseModel):
@ -19,12 +21,16 @@ class ChatCompletionLogprobs(BaseModel):
class ChatCompletionMessage(BaseModel):
role: Optional[str] = None
content: Optional[str] = None
tool_calls: Optional[List[ToolCall]] = None
class ChatCompletionRespChoice(BaseModel):
# Index is 0 since we aren't using multiple choices
index: int = 0
finish_reason: Optional[str] = None
# let's us understand why it stopped and if we need to generate a tool_call
stop_str: Optional[str] = None
message: ChatCompletionMessage
logprobs: Optional[ChatCompletionLogprobs] = None
@ -42,13 +48,29 @@ class ChatCompletionRequest(CommonCompletionRequest):
# Messages
# Take in a string as well even though it's not part of the OAI spec
# support messages.content as a list of dict
messages: Union[str, List[Dict[str, Union[str, List[Dict[str, str]]]]]]
# WIP this can probably be tightened, or maybe match the OAI lib type
# in openai\types\chat\chat_completion_message_param.py
messages: Union[str, List[Dict]]
prompt_template: Optional[str] = None
add_generation_prompt: Optional[bool] = True
template_vars: Optional[dict] = {}
response_prefix: Optional[str] = None
model: Optional[str] = None
# tools is follows the format OAI schema, functions is more flexible
# both are available in the chat template.
tools: Optional[List[ToolSpec]] = None
functions: Optional[List[Dict]] = None
# Typically collected from Chat Template.
# Don't include this in the OpenAPI docs
# TODO: Use these custom parameters
tool_call_start: SkipJsonSchema[Optional[List[Union[str, int]]]] = None
tool_call_end: SkipJsonSchema[Optional[str]] = None
tool_call_schema: SkipJsonSchema[Optional[dict]] = tool_call_schema
class ChatCompletionResponse(BaseModel):
id: str = Field(default_factory=lambda: f"chatcmpl-{uuid4().hex}")

View file

@ -0,0 +1,58 @@
from pydantic import BaseModel
from typing import Dict, Literal
tool_call_schema = {
"$schema": "http://json-schema.org/draft-07/schema#",
"type": "array",
"items": {
"type": "object",
"properties": {
"id": {"type": "string"},
"function": {
"type": "object",
"properties": {
"name": {"type": "string"},
"arguments": {
# Converted to OAI's string in post process
"type": "object"
},
},
"required": ["name", "arguments"],
},
"type": {"type": "string", "enum": ["function"]},
},
"required": ["id", "function", "type"],
},
}
class Function(BaseModel):
"""Represents a description of a tool function."""
name: str
description: str
parameters: Dict[str, object]
class ToolSpec(BaseModel):
"""Wrapper for an inner tool function."""
function: Function
type: Literal["function"]
class Tool(BaseModel):
"""Represents an OAI tool description."""
name: str
# Makes more sense to be a dict, but OAI knows best
arguments: str
class ToolCall(BaseModel):
"""Represents an OAI tool description."""
id: str
function: Tool
type: Literal["function"]

View file

@ -5,6 +5,7 @@ import pathlib
from asyncio import CancelledError
from copy import deepcopy
from typing import List, Optional
import json
from fastapi import HTTPException, Request
from jinja2 import TemplateError
@ -30,6 +31,7 @@ from endpoints.OAI.types.chat_completion import (
)
from endpoints.OAI.types.common import UsageStats
from endpoints.OAI.utils.completion import _stream_collector
from endpoints.OAI.types.tools import ToolCall
def _create_response(
@ -46,6 +48,10 @@ def _create_response(
role="assistant", content=unwrap(generation.get("text"), "")
)
tool_calls = generation["tool_calls"]
if tool_calls:
message.tool_calls = postprocess_tool_call(tool_calls)
logprob_response = None
token_probs = unwrap(generation.get("token_probs"), {})
@ -72,6 +78,7 @@ def _create_response(
choice = ChatCompletionRespChoice(
index=index,
finish_reason=generation.get("finish_reason"),
stop_str=generation.get("stop_str"),
message=message,
logprobs=logprob_response,
)
@ -119,7 +126,16 @@ def _create_stream_chunk(
finish_reason=generation.get("finish_reason"),
)
# lets check if we have tool calls since we are at the end of the generation
if "tool_calls" in generation:
tool_calls = generation["tool_calls"]
message = ChatCompletionMessage(
tool_calls=postprocess_tool_call(tool_calls)
)
choice.delta = message
choices.append(choice)
else:
message = ChatCompletionMessage(
role="assistant", content=unwrap(generation.get("text"), "")
@ -162,7 +178,31 @@ def _create_stream_chunk(
return chunk
def format_prompt_with_template(data: ChatCompletionRequest):
async def _append_template_metadata(data: ChatCompletionRequest):
"""Adding metadata is a one-time process."""
template_metadata = await model.container.prompt_template.extract_metadata(
data.template_vars
)
# Stop strings
if isinstance(data.stop, str):
data.stop = [data.stop] + template_metadata.stop_strings
else:
data.stop += template_metadata.stop_strings
# Tool call start strings
if template_metadata.tool_starts:
if data.tool_call_start is None:
data.tool_call_start = template_metadata.tool_starts
# Append to stop strings to halt for a tool call generation
data.stop.extend(template_metadata.tool_starts)
async def format_prompt_with_template(
data: ChatCompletionRequest, tool_precursor: Optional[str] = None
):
"""
Compile the prompt and get any additional stop strings from the template.
Template stop strings can be overriden by sampler overrides if force is true.
@ -187,18 +227,22 @@ def format_prompt_with_template(data: ChatCompletionRequest):
"",
)
if "tool_calls" in message:
message["tool_calls_json"] = json.dumps(message["tool_calls"], indent=2)
# Overwrite any protected vars with their values
data.template_vars.update(
{
"messages": data.messages,
"add_generation_prompt": data.add_generation_prompt,
"tools_json": json.dumps(data.model_dump()["tools"], indent=2),
"functions_json": json.dumps(data.functions, indent=2),
"tool_precursor": tool_precursor,
**special_tokens_dict,
}
)
prompt, template_stop_strings = model.container.prompt_template.render(
data.template_vars
)
prompt = await model.container.prompt_template.render(data.template_vars)
# Append response prefix if present
if data.response_prefix:
@ -216,11 +260,8 @@ def format_prompt_with_template(data: ChatCompletionRequest):
if bos_token and prompt.startswith(bos_token):
prompt = prompt.removeprefix(bos_token)
# Append template stop strings
if isinstance(data.stop, str):
data.stop = [data.stop] + template_stop_strings
else:
data.stop += template_stop_strings
# Add template metadata
await _append_template_metadata(data)
return prompt
@ -271,6 +312,9 @@ async def stream_generate_chat_completion(
gen_tasks.append(gen_task)
# We need to keep track of the text generated so we can resume the tool calls
current_generation_text = ""
# Consumer loop
while True:
if disconnect_task.done():
@ -280,6 +324,19 @@ async def stream_generate_chat_completion(
)
generation = await gen_queue.get()
# lets only append the text if we need it for tool calls later
if data.tool_call_start and "text" in generation:
current_generation_text += generation["text"]
# check if we are running a tool model, and that we are at stop
if data.tool_call_start and "stop_str" in generation:
generations = await generate_tool_calls(
data,
[generation],
request,
current_generations=current_generation_text,
)
generation = generations[0] # We only have one generation in this case
# Stream collector will push an exception to the queue if it fails
if isinstance(generation, Exception):
@ -344,6 +401,11 @@ async def generate_chat_completion(
)
generations = await asyncio.gather(*gen_tasks)
# Let's not waste our time if we arn't running a tool model
if data.tool_call_start:
generations = await generate_tool_calls(data, generations, request)
response = _create_response(request.state.id, generations, model_path.name)
logger.info(f"Finished chat completion request {request.state.id}")
@ -358,3 +420,56 @@ async def generate_chat_completion(
# Server error if there's a generation exception
raise HTTPException(503, error_message) from exc
async def generate_tool_calls(
data: ChatCompletionRequest,
generations: List[str],
request: Request,
current_generations: str = None,
):
gen_tasks: List[asyncio.Task] = []
tool_idx: List[int] = []
# Copy to make sure the parent JSON schema doesn't get modified
# FIXME: May not be necessary depending on how the codebase evolves
tool_data = deepcopy(data)
tool_data.json_schema = tool_data.tool_call_schema
gen_params = tool_data.to_gen_params()
for idx, gen in enumerate(generations):
if gen["stop_str"] in tool_data.tool_call_start:
if "text" in gen:
# non streaming, all generations will have the text they generated
pre_tool_prompt = await format_prompt_with_template(data, gen["text"])
elif current_generations is not None:
# streaming, we wont have text in the generation,
# we'll have to use the current_generations
pre_tool_prompt = await format_prompt_with_template(
data, current_generations
)
gen_tasks.append(
asyncio.create_task(
model.container.generate(
pre_tool_prompt, request.state.id, **gen_params
)
)
)
tool_idx.append(idx)
tool_calls = await asyncio.gather(*gen_tasks)
for outer_idx in range(0, len(tool_idx)):
gen_idx = tool_idx[outer_idx]
generations[gen_idx]["tool_calls"] = tool_calls[outer_idx]["text"]
return generations
def postprocess_tool_call(call_str: str) -> List[ToolCall]:
tool_calls = json.loads(call_str)
for tool_call in tool_calls:
tool_call["function"]["arguments"] = json.dumps(
tool_call["function"]["arguments"]
)
return [ToolCall(**tool_call) for tool_call in tool_calls]

View file

@ -44,6 +44,13 @@ from endpoints.core.utils.model import (
router = APIRouter()
# Healthcheck endpoint
@router.get("/health")
async def healthcheck():
"""Get the current service health status"""
return {"status": "healthy"}
# Model list endpoint
@router.get("/v1/models", dependencies=[Depends(check_api_key)])
@router.get("/v1/model/list", dependencies=[Depends(check_api_key)])

View file

@ -2,7 +2,7 @@
from pydantic import BaseModel, Field, ConfigDict
from time import time
from typing import List, Optional
from typing import List, Literal, Optional, Union
from common.gen_logging import GenLogPreferences
from common.model import get_config_default
@ -53,19 +53,19 @@ class DraftModelLoadRequest(BaseModel):
# Config arguments
draft_rope_scale: Optional[float] = Field(
default_factory=lambda: get_config_default(
"draft_rope_scale", 1.0, model_type="draft"
"draft_rope_scale", model_type="draft"
)
)
draft_rope_alpha: Optional[float] = Field(
description="Automatically calculated if not present",
draft_rope_alpha: Optional[Union[float, Literal["auto"]]] = Field(
description='Automatically calculated if set to "auto"',
default_factory=lambda: get_config_default(
"draft_rope_alpha", None, model_type="draft"
"draft_rope_alpha", model_type="draft"
),
examples=[1.0],
)
draft_cache_mode: Optional[str] = Field(
default_factory=lambda: get_config_default(
"draft_cache_mode", "FP16", model_type="draft"
"draft_cache_mode", model_type="draft"
)
)
@ -96,14 +96,17 @@ class ModelLoadRequest(BaseModel):
default_factory=lambda: get_config_default("cache_size"),
examples=[4096],
)
tensor_parallel: Optional[bool] = Field(
default_factory=lambda: get_config_default("tensor_parallel")
)
gpu_split_auto: Optional[bool] = Field(
default_factory=lambda: get_config_default("gpu_split_auto", True)
default_factory=lambda: get_config_default("gpu_split_auto")
)
autosplit_reserve: Optional[List[float]] = Field(
default_factory=lambda: get_config_default("autosplit_reserve", [96])
default_factory=lambda: get_config_default("autosplit_reserve")
)
gpu_split: Optional[List[float]] = Field(
default_factory=lambda: get_config_default("gpu_split", []),
default_factory=lambda: get_config_default("gpu_split"),
examples=[[24.0, 20.0]],
)
rope_scale: Optional[float] = Field(
@ -111,16 +114,16 @@ class ModelLoadRequest(BaseModel):
default_factory=lambda: get_config_default("rope_scale"),
examples=[1.0],
)
rope_alpha: Optional[float] = Field(
description="Automatically calculated if not present",
rope_alpha: Optional[Union[float, Literal["auto"]]] = Field(
description='Automatically calculated if set to "auto"',
default_factory=lambda: get_config_default("rope_alpha"),
examples=[1.0],
)
cache_mode: Optional[str] = Field(
default_factory=lambda: get_config_default("cache_mode", "FP16")
default_factory=lambda: get_config_default("cache_mode")
)
chunk_size: Optional[int] = Field(
default_factory=lambda: get_config_default("chunk_size", 2048)
default_factory=lambda: get_config_default("chunk_size")
)
prompt_template: Optional[str] = Field(
default_factory=lambda: get_config_default("prompt_template")
@ -129,7 +132,7 @@ class ModelLoadRequest(BaseModel):
default_factory=lambda: get_config_default("num_experts_per_token")
)
fasttensors: Optional[bool] = Field(
default_factory=lambda: get_config_default("fasttensors", False)
default_factory=lambda: get_config_default("fasttensors")
)
# Non-config arguments

View file

@ -43,12 +43,16 @@ async def get_current_model_list(model_type: str = "model"):
model_path = None
# Make sure the model container exists
if model_type == "model" or model_type == "draft":
if model.container:
model_path = model.container.get_model_path(model_type == "draft")
elif model_type == "embedding":
if model.embeddings_container:
model_path = model.embeddings_container.model_dir
match model_type:
case "model":
if model.container:
model_path = model.container.model_dir
case "draft":
if model.container:
model_path = model.container.draft_model_dir
case "embedding":
if model.embeddings_container:
model_path = model.embeddings_container.model_dir
if model_path:
current_models.append(ModelCard(id=model_path.name))
@ -94,8 +98,10 @@ async def stream_model_load(
):
"""Request generation wrapper for the loading process."""
# Get trimmed load data
load_data = data.model_dump(exclude_none=True)
# Set the draft model path if it exists
load_data = data.model_dump()
if draft_model_path:
load_data["draft"]["draft_model_dir"] = draft_model_path

View file

@ -77,8 +77,6 @@ async def start_api(host: str, port: int):
# TODO: Move OAI API to a separate folder
logger.info(f"Developer documentation: http://{host}:{port}/redoc")
# logger.info(f"Completions: http://{host}:{port}/v1/completions")
# logger.info(f"Chat completions: http://{host}:{port}/v1/chat/completions")
# Setup app
app = setup_app(host, port)

View file

@ -68,12 +68,12 @@ cu121 = [
"torch @ https://download.pytorch.org/whl/cu121/torch-2.3.1%2Bcu121-cp310-cp310-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.10'",
# Exl2
"exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.1.8/exllamav2-0.1.8+cu121.torch2.3.1-cp312-cp312-win_amd64.whl ; platform_system == 'Windows' and python_version == '3.12'",
"exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.1.8/exllamav2-0.1.8+cu121.torch2.3.1-cp311-cp311-win_amd64.whl ; platform_system == 'Windows' and python_version == '3.11'",
"exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.1.8/exllamav2-0.1.8+cu121.torch2.3.1-cp310-cp310-win_amd64.whl ; platform_system == 'Windows' and python_version == '3.10'",
"exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.1.8/exllamav2-0.1.8+cu121.torch2.3.1-cp312-cp312-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.12'",
"exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.1.8/exllamav2-0.1.8+cu121.torch2.3.1-cp311-cp311-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.11'",
"exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.1.8/exllamav2-0.1.8+cu121.torch2.3.1-cp310-cp310-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.10'",
"exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.2.0/exllamav2-0.2.0+cu121.torch2.3.1-cp312-cp312-win_amd64.whl ; platform_system == 'Windows' and python_version == '3.12'",
"exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.2.0/exllamav2-0.2.0+cu121.torch2.3.1-cp311-cp311-win_amd64.whl ; platform_system == 'Windows' and python_version == '3.11'",
"exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.2.0/exllamav2-0.2.0+cu121.torch2.3.1-cp310-cp310-win_amd64.whl ; platform_system == 'Windows' and python_version == '3.10'",
"exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.2.0/exllamav2-0.2.0+cu121.torch2.3.1-cp312-cp312-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.12'",
"exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.2.0/exllamav2-0.2.0+cu121.torch2.3.1-cp311-cp311-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.11'",
"exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.2.0/exllamav2-0.2.0+cu121.torch2.3.1-cp310-cp310-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.10'",
# Windows FA2 from https://github.com/bdashore3/flash-attention/releases
"flash_attn @ https://github.com/bdashore3/flash-attention/releases/download/v2.6.3/flash_attn-2.6.3+cu123torch2.3.1cxx11abiFALSE-cp312-cp312-win_amd64.whl ; platform_system == 'Windows' and python_version == '3.12'",
@ -95,12 +95,12 @@ cu118 = [
"torch @ https://download.pytorch.org/whl/cu118/torch-2.3.1%2Bcu118-cp310-cp310-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.10'",
# Exl2
"exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.1.8/exllamav2-0.1.8+cu118.torch2.3.1-cp312-cp312-win_amd64.whl ; platform_system == 'Windows' and python_version == '3.12'",
"exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.1.8/exllamav2-0.1.8+cu118.torch2.3.1-cp311-cp311-win_amd64.whl ; platform_system == 'Windows' and python_version == '3.11'",
"exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.1.8/exllamav2-0.1.8+cu118.torch2.3.1-cp310-cp310-win_amd64.whl ; platform_system == 'Windows' and python_version == '3.10'",
"exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.1.8/exllamav2-0.1.8+cu118.torch2.3.1-cp312-cp312-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.12'",
"exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.1.8/exllamav2-0.1.8+cu118.torch2.3.1-cp311-cp311-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.11'",
"exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.1.8/exllamav2-0.1.8+cu118.torch2.3.1-cp310-cp310-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.10'",
"exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.2.0/exllamav2-0.2.0+cu118.torch2.3.1-cp312-cp312-win_amd64.whl ; platform_system == 'Windows' and python_version == '3.12'",
"exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.2.0/exllamav2-0.2.0+cu118.torch2.3.1-cp311-cp311-win_amd64.whl ; platform_system == 'Windows' and python_version == '3.11'",
"exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.2.0/exllamav2-0.2.0+cu118.torch2.3.1-cp310-cp310-win_amd64.whl ; platform_system == 'Windows' and python_version == '3.10'",
"exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.2.0/exllamav2-0.2.0+cu118.torch2.3.1-cp312-cp312-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.12'",
"exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.2.0/exllamav2-0.2.0+cu118.torch2.3.1-cp311-cp311-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.11'",
"exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.2.0/exllamav2-0.2.0+cu118.torch2.3.1-cp310-cp310-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.10'",
# Linux FA2 from https://github.com/Dao-AILab/flash-attention/releases
"flash_attn @ https://github.com/Dao-AILab/flash-attention/releases/download/v2.6.3/flash_attn-2.6.3+cu118torch2.3cxx11abiFALSE-cp312-cp312-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.12'",
@ -119,9 +119,9 @@ amd = [
"torch @ https://download.pytorch.org/whl/rocm6.0/torch-2.3.1%2Brocm6.0-cp310-cp310-linux_x86_64.whl ; python_version == '3.10'",
# Exl2
"exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.1.8/exllamav2-0.1.8+rocm6.0.torch2.3.1-cp312-cp312-linux_x86_64.whl ; python_version == '3.12'",
"exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.1.8/exllamav2-0.1.8+rocm6.0.torch2.3.1-cp311-cp311-linux_x86_64.whl ; python_version == '3.11'",
"exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.1.8/exllamav2-0.1.8+rocm6.0.torch2.3.1-cp310-cp310-linux_x86_64.whl ; python_version == '3.10'",
"exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.2.0/exllamav2-0.2.0+rocm6.0.torch2.3.1-cp312-cp312-linux_x86_64.whl ; python_version == '3.12'",
"exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.2.0/exllamav2-0.2.0+rocm6.0.torch2.3.1-cp311-cp311-linux_x86_64.whl ; python_version == '3.11'",
"exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.2.0/exllamav2-0.2.0+rocm6.0.torch2.3.1-cp310-cp310-linux_x86_64.whl ; python_version == '3.10'",
]
# MARK: Ruff options

View file

@ -126,6 +126,10 @@ banned_tokens:
override: []
force: false
additive: false
allowed_tokens:
override: []
force: false
additive: false
# MARK: CFG scale
cfg_scale:

View file

@ -0,0 +1,84 @@
{# Metadata #}
{% set stop_strings = ["<|im_start|>", "<|im_end|>"] %}
{% set message_roles = ['system', 'user', 'assistant', 'tool'] %}
{% set tool_start = "<|tool_start|>" %}
{% set tool_end = "<|tool_end|>" %}
{%- set start_header = "<|start_header_id|>" -%}
{%- set end_header = "<|end_header_id|>\n" -%}
{%- set example_tool_call -%}[
{
"id": "tool_id_1342",
"function": {
"arguments": "arg_name": 3,
"name": "tool_name"
},
"type": "function"
},
{
"id": "example_id_13f42",
"function": {
"arguments": "example_arg": 1.0, "another_example_arg": true,
"name": "another_tool_name"
},
"type": "function"
}
]
{%- endset -%}
{%- set inital_system_prompt -%}You are an assistant that has access to the following set of tools, to call a tool:
1. Prefix calls with '{{ tool_start }}' and end calls with '{{ tool_end }}'
2. Ensure you use the correct type for arguments. For example, if the argument is a string, ensure it is enclosed in quotes, otherwise, it should not be.
3. Generate all calls using the following json tool call format. Here is a multi tool call example:
{{ tool_start }}{{ example_tool_call }}{{ tool_end }}
Here are the tools available for you to call:
{{ tools_json }}
{%- endset -%}
{%- set tool_reminder -%}Available Tools:
{{ tools_json }}
Tool Call Format Example:
{{ tool_start }}{{ example_tool_call }}
Prefix & Suffix: Begin tool calls with {{ tool_start }} and end with {{ tool_end }}.
Argument Types: Use correct data types for arguments (e.g., strings in quotes, numbers without).
{%- endset -%}
{# Template #}
{%- for message in messages -%}
{%- set role = message['role'] | lower -%}
{%- if role not in message_roles -%}
{{ raise_exception('Invalid role ' + message['role'] + '. Only ' + message_roles | join(', ') + ' are supported.') }}
{%- endif -%}
{%- set content = message['content'] | default('', true) | trim -%}
{%- if loop.first -%}
{{ bos_token }}{{ start_header }}{{ role }}{{ end_header }}
{{ inital_system_prompt }}
{{ content }}{{ eos_token }}
{%- endif -%}
{%- if not loop.first -%}
{{ start_header }}{{ role }}{{ end_header }}
{{ content }}
{%- if 'tool_calls_json' in message and message['tool_calls_json'] -%}
{{ tool_start }}{{ message['tool_calls_json']}}{{ tool_end }}
{%- endif -%}
{{ eos_token }}
{%- endif -%}
{%- endfor -%}
{%- if tool_precursor -%}
{{ start_header }}system{{ end_header }}
{{ tool_reminder }}{{ eos_token }}
{{ start_header }}assistant{{ end_header }}
{{ tool_precursor }}{{ tool_start }}
{%- else -%}
{{ start_header }}assistant{{ end_header }}
{%- endif -%}