Merge branch 'main' of https://github.com/theroyallab/tabbyapi into inline
Commit dd30d6592a
25 changed files with 804 additions and 308 deletions
.github/workflows/docker-image.yml (vendored, new file, 49 lines)

@ -0,0 +1,49 @@
name: Build and publish a Docker image

# Configures this workflow to run every time a change is pushed to the following branches.
on:
  push:
    branches: ['main']

# Defines two custom environment variables for the workflow. These are used for the Container registry domain, and a name for the Docker image that this workflow builds.
env:
  REGISTRY: ghcr.io
  IMAGE_NAME: ${{ github.repository }}

# There is a single job in this workflow. It's configured to run on the latest available version of Ubuntu.
jobs:
  build-and-push-image:
    runs-on: ubuntu-latest
    # Sets the permissions granted to the `GITHUB_TOKEN` for the actions in this job.
    permissions:
      contents: read
      packages: write
    #
    steps:
      - name: Checkout repository
        uses: actions/checkout@v4
      # Uses the `docker/login-action` action to log in to the Container registry using the account and password that will publish the packages. Once published, the packages are scoped to the account defined here.
      - name: Log in to the Container registry
        uses: docker/login-action@65b78e6e13532edd9afa3aa52ac7964289d1a9c1
        with:
          registry: ${{ env.REGISTRY }}
          username: ${{ github.actor }}
          password: ${{ secrets.GITHUB_TOKEN }}
      # This step uses [docker/metadata-action](https://github.com/docker/metadata-action#about) to extract tags and labels that will be applied to the specified image. The `id` "meta" allows the output of this step to be referenced in a subsequent step. The `images` value provides the base name for the tags and labels.
      - name: Extract metadata (tags, labels) for Docker
        id: meta
        uses: docker/metadata-action@9ec57ed1fcdbf14dcef7dfbe97b2010124a938b7
        with:
          images: ${{ env.REGISTRY }}/${{ env.IMAGE_NAME }}
          # This is needed to add the `latest` tag to the newly built image
          tags: type=raw,value=latest
      # This step uses the `docker/build-push-action` action to build the image, based on your repository's `Dockerfile`. If the build succeeds, it pushes the image to GitHub Packages.
      # It uses the `context` parameter to define the build's context as the set of files located in the specified path. For more information, see "[Usage](https://github.com/docker/build-push-action#usage)" in the README of the `docker/build-push-action` repository.
      # It uses the `tags` and `labels` parameters to tag and label the image with the output from the "meta" step.
      - name: Build and push Docker image
        uses: docker/build-push-action@5cd11c3a4ced054e52742c5fd54dca954e0edd85
        with:
          file: ./docker/Dockerfile
          # target: production # Might need it in the future when we have multistage builds
          push: true
          tags: ${{ steps.meta.outputs.tags }}
          labels: ${{ steps.meta.outputs.labels }}
.gitignore (vendored, 1 line changed)

@ -192,6 +192,7 @@ templates/*
!templates/place_your_templates_here.txt
!templates/alpaca.jinja
!templates/chatml.jinja
!templates/chatml_with_headers_tool_calling.jinja

# Sampler overrides folder
sampler_overrides/*

@ -64,6 +64,7 @@ Read the [Wiki](https://github.com/theroyallab/tabbyAPI/wiki/1.-Getting-Started)
- Utilizes modern python paradigms
- Continuous batching engine using paged attention
- Fast classifier-free guidance
- OAI style tool/function calling

And much more. If something is missing here, PR it in!
api_tokens_sample.yml (new file, 2 lines)

@ -0,0 +1,2 @@
api_key: # Insert api key here
admin_key: # Insert admin key here
|
|
@ -10,6 +10,7 @@ import uuid
|
|||
from exllamav2 import (
|
||||
ExLlamaV2,
|
||||
ExLlamaV2Config,
|
||||
ExLlamaV2CacheBase,
|
||||
ExLlamaV2Cache,
|
||||
ExLlamaV2Cache_Q4,
|
||||
ExLlamaV2Cache_Q6,
|
||||
|
|
@ -26,6 +27,8 @@ from itertools import zip_longest
|
|||
from loguru import logger
|
||||
from typing import List, Optional, Union
|
||||
|
||||
import yaml
|
||||
|
||||
from backends.exllamav2.grammar import (
|
||||
ExLlamaV2Grammar,
|
||||
clear_grammar_func_cache,
|
||||
|
|
@ -50,10 +53,22 @@ from common.templating import (
|
|||
from common.transformers_utils import GenerationConfig, HuggingFaceConfig
|
||||
from common.utils import coalesce, unwrap
|
||||
|
||||
# Dynamic imports
|
||||
try:
|
||||
from exllamav2 import ExLlamaV2Cache_TP
|
||||
|
||||
has_tp = True
|
||||
except ImportError:
|
||||
has_tp = False
|
||||
|
||||
|
||||
class ExllamaV2Container:
|
||||
"""The model container class for ExLlamaV2 models."""
|
||||
|
||||
# Model directories
|
||||
model_dir: pathlib.Path = pathlib.Path("models")
|
||||
draft_model_dir: pathlib.Path = pathlib.Path("models")
|
||||
|
||||
# Exl2 vars
|
||||
config: Optional[ExLlamaV2Config] = None
|
||||
draft_config: Optional[ExLlamaV2Config] = None
|
||||
|
|
@ -78,6 +93,7 @@ class ExllamaV2Container:
|
|||
gpu_split: Optional[list] = None
|
||||
gpu_split_auto: bool = True
|
||||
autosplit_reserve: List[float] = [96 * 1024**2]
|
||||
use_tp: bool = False
|
||||
|
||||
# Load state
|
||||
model_is_loading: bool = False
|
||||
|
|
@ -91,105 +107,129 @@ class ExllamaV2Container:
|
|||
|
||||
def __init__(self, model_directory: pathlib.Path, quiet=False, **kwargs):
|
||||
"""
|
||||
Create model container
|
||||
Primary initializer for model container.
|
||||
|
||||
Args:
|
||||
model_dir (pathlib.Path): Model directory containing config.json,
|
||||
tokenizer.model etc.
|
||||
quiet (bool): Suppress console output
|
||||
load_progress_callback (function, optional): A function to call for
|
||||
each module loaded. Prototype:
|
||||
def progress(loaded_modules: int, total_modules: int,
|
||||
loading_draft: bool)
|
||||
**kwargs:
|
||||
`cache_mode` (str): Sets cache mode: "FP16"/"Q8"/"Q6"/"Q4"
|
||||
(default: "FP16")
|
||||
'max_seq_len' (int): Override model's default max sequence
|
||||
length (default: 4096)
|
||||
'cache_size' (int): Num of tokens to allocate space for in the k/v cache
|
||||
(default: max_seq_len)
|
||||
'rope_scale' (float): Set RoPE scaling factor for model
|
||||
(default: 1.0)
|
||||
'rope_alpha' (float): Set RoPE alpha (NTK) factor for model
|
||||
(default: 1.0)
|
||||
'prompt_template' (str): Manually sets the prompt template for
|
||||
this model (default: None)
|
||||
'chunk_size' (int): Sets the maximum chunk size for the model
|
||||
(default: 2048)
|
||||
Inferencing in chunks reduces overall VRAM overhead by
|
||||
processing very long sequences in smaller batches. This
|
||||
limits the size of temporary buffers needed for the hidden
|
||||
state and attention weights.
|
||||
'draft_model_dir' (str): Draft model directory
|
||||
'draft_rope_scale' (float): Set RoPE scaling factor for draft
|
||||
model (default: 1.0)
|
||||
'draft_rope_alpha' (float): RoPE alpha (NTK) factor for draft
|
||||
model. By default, the draft model's alpha value is
|
||||
calculated automatically to scale to the size of the
|
||||
full model.
|
||||
'draft_cache_mode' (str): Sets draft cache mode: "FP16"/"Q8"/"Q6"/"Q4"
|
||||
(default: "FP16")
|
||||
'lora_dir' (str): LoRA directory
|
||||
'loras' (list[dict]): List of loras to be loaded, consisting of
|
||||
'name' and 'scaling'
|
||||
'gpu_split_auto' (bool): Automatically split model across
|
||||
available devices (default: True)
|
||||
'gpu_split' (list[float]): Allocation for weights and (some)
|
||||
tensors, per device
|
||||
Kwargs are located in config_sample.yml
|
||||
"""
|
||||
|
||||
self.quiet = quiet
|
||||
self.cache_mode = unwrap(kwargs.get("cache_mode"), "FP16")
|
||||
|
||||
# Turn off GPU split if the user is using 1 GPU
|
||||
gpu_count = torch.cuda.device_count()
|
||||
gpu_split_auto = unwrap(kwargs.get("gpu_split_auto"), True)
|
||||
gpu_device_list = list(range(0, gpu_count))
|
||||
|
||||
if gpu_count > 1 and gpu_split_auto:
|
||||
# Auto GPU split parameters
|
||||
self.gpu_split_auto = gpu_split_auto
|
||||
|
||||
autosplit_reserve_megabytes = unwrap(kwargs.get("autosplit_reserve"), [96])
|
||||
self.autosplit_reserve = [
|
||||
int(math.ceil(value * 1024**2)) for value in autosplit_reserve_megabytes
|
||||
]
|
||||
elif gpu_count > 1:
|
||||
# Manual GPU split
|
||||
self.gpu_split = kwargs.get("gpu_split")
|
||||
self.gpu_split_auto = False
|
||||
|
||||
gpu_device_list = [
|
||||
device_idx
|
||||
for device_idx, memory in enumerate(self.gpu_split)
|
||||
if memory > 0
|
||||
]
|
||||
else:
|
||||
# One GPU setup
|
||||
self.gpu_split_auto = False
|
||||
logger.info("Disabling GPU split because one GPU is in use.")
|
||||
|
||||
# Initialize config
|
||||
self.config = ExLlamaV2Config()
|
||||
self.model_dir = model_directory
|
||||
self.config.model_dir = str(model_directory.resolve())
|
||||
|
||||
# Make the max seq len 4096 before preparing the config
|
||||
# This is a better default than 2048
|
||||
self.config.max_seq_len = 4096
|
||||
|
||||
# Hardcode max output length to 16
|
||||
self.config.max_output_len = 16
|
||||
|
||||
self.config.prepare()
|
||||
|
||||
# Check if the model arch is compatible with various exl2 features
|
||||
try:
|
||||
self.config.arch_compat_overrides()
|
||||
except AttributeError:
|
||||
pass
|
||||
self.config.arch_compat_overrides()
|
||||
|
||||
# Prepare the draft model config if necessary
|
||||
draft_args = unwrap(kwargs.get("draft"), {})
|
||||
draft_model_name = draft_args.get("draft_model_name")
|
||||
enable_draft = draft_args and draft_model_name
|
||||
|
||||
# Always disable draft if params are incorrectly configured
|
||||
if draft_args and draft_model_name is None:
|
||||
logger.warning(
|
||||
"Draft model is disabled because a model name "
|
||||
"wasn't provided. Please check your config.yml!"
|
||||
)
|
||||
enable_draft = False
|
||||
|
||||
if enable_draft:
|
||||
self.draft_config = ExLlamaV2Config()
|
||||
self.draft_config.no_flash_attn = self.config.no_flash_attn
|
||||
draft_model_path = pathlib.Path(
|
||||
unwrap(draft_args.get("draft_model_dir"), "models")
|
||||
)
|
||||
draft_model_path = draft_model_path / draft_model_name
|
||||
|
||||
self.draft_model_dir = draft_model_path
|
||||
self.draft_config.model_dir = str(draft_model_path.resolve())
|
||||
self.draft_config.prepare()
|
||||
|
||||
# Create the hf_config
|
||||
self.hf_config = HuggingFaceConfig.from_file(model_directory)
|
||||
|
||||
# Load generation config overrides
|
||||
generation_config_path = model_directory / "generation_config.json"
|
||||
if generation_config_path.exists():
|
||||
try:
|
||||
self.generation_config = GenerationConfig.from_file(
|
||||
generation_config_path.parent
|
||||
)
|
||||
except Exception:
|
||||
logger.error(traceback.format_exc())
|
||||
logger.warning(
|
||||
"Skipping generation config load because of an unexpected error."
|
||||
)
|
||||
|
||||
# Apply a model's config overrides while respecting user settings
|
||||
kwargs = self.set_model_overrides(**kwargs)
|
||||
|
||||
# MARK: User configuration
|
||||
|
||||
# Get cache mode
|
||||
self.cache_mode = unwrap(kwargs.get("cache_mode"), "FP16")
|
||||
|
||||
# Turn off GPU split if the user is using 1 GPU
|
||||
gpu_count = torch.cuda.device_count()
|
||||
gpu_split_auto = unwrap(kwargs.get("gpu_split_auto"), True)
|
||||
use_tp = unwrap(kwargs.get("tensor_parallel"), False)
|
||||
gpu_split = kwargs.get("gpu_split")
|
||||
gpu_device_list = list(range(0, gpu_count))
|
||||
|
||||
# Set GPU split options
|
||||
if gpu_count == 1:
|
||||
self.gpu_split_auto = False
|
||||
logger.info("Disabling GPU split because one GPU is in use.")
|
||||
else:
|
||||
# Set tensor parallel
|
||||
if use_tp:
|
||||
if has_tp:
|
||||
self.use_tp = True
|
||||
|
||||
# TP has its own autosplit loader
|
||||
self.gpu_split_auto = False
|
||||
else:
|
||||
# TODO: Remove conditional with exl2 v0.1.9 release
|
||||
logger.warning(
|
||||
"Tensor parallelism is not supported in the "
|
||||
"current ExllamaV2 version."
|
||||
)
|
||||
|
||||
# Enable manual GPU split if provided
|
||||
if gpu_split:
|
||||
self.gpu_split_auto = False
|
||||
self.gpu_split = gpu_split
|
||||
|
||||
gpu_device_list = [
|
||||
device_idx
|
||||
for device_idx, memory in enumerate(self.gpu_split)
|
||||
if memory > 0
|
||||
]
|
||||
elif gpu_split_auto and not self.use_tp:
|
||||
# Otherwise fallback to autosplit settings
|
||||
self.gpu_split_auto = gpu_split_auto
|
||||
|
||||
autosplit_reserve_megabytes = unwrap(
|
||||
kwargs.get("autosplit_reserve"), [96]
|
||||
)
|
||||
|
||||
# Reserve VRAM for each GPU
|
||||
self.autosplit_reserve = [
|
||||
int(math.ceil(value * 1024**2))
|
||||
for value in autosplit_reserve_megabytes
|
||||
]
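Note: the autosplit reserve is taken in megabytes from the config and converted to bytes before being handed to the autosplit loader. A minimal standalone sketch of that conversion (the values below are illustrative):

import math

# Illustrative values: reserve 96 MB on GPU 0 and 512 MB on GPU 1
autosplit_reserve_megabytes = [96, 512]

# Same conversion as above: megabytes -> bytes, rounded up to whole bytes
autosplit_reserve = [
    int(math.ceil(value * 1024**2)) for value in autosplit_reserve_megabytes
]

print(autosplit_reserve)  # [100663296, 536870912]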
|
||||
|
||||
# Hardcode max output length to 16
|
||||
self.config.max_output_len = 16
|
||||
|
||||
# Then override the base_seq_len if present
|
||||
override_base_seq_len = kwargs.get("override_base_seq_len")
|
||||
if override_base_seq_len:
|
||||
|
|
@ -209,10 +249,13 @@ class ExllamaV2Container:
|
|||
kwargs.get("rope_scale"), self.config.scale_pos_emb
|
||||
)
|
||||
|
||||
# Automatically calculate rope alpha
|
||||
self.config.scale_alpha_value = unwrap(
|
||||
kwargs.get("rope_alpha"), self.calculate_rope_alpha(base_seq_len)
|
||||
)
|
||||
# Sets rope alpha value.
|
||||
# Automatically calculate if unset or defined as an "auto" literal.
|
||||
rope_alpha = unwrap(kwargs.get("rope_alpha"), "auto")
|
||||
if rope_alpha == "auto":
|
||||
self.config.scale_alpha_value = self.calculate_rope_alpha(base_seq_len)
|
||||
else:
|
||||
self.config.scale_alpha_value = rope_alpha
|
||||
|
||||
# Enable fasttensors loading if present
|
||||
self.config.fasttensors = unwrap(kwargs.get("fasttensors"), False)
|
||||
|
|
@ -275,19 +318,6 @@ class ExllamaV2Container:
|
|||
else:
|
||||
self.cache_size = self.config.max_seq_len
|
||||
|
||||
# Load generation config overrides
|
||||
generation_config_path = model_directory / "generation_config.json"
|
||||
if generation_config_path.exists():
|
||||
try:
|
||||
self.generation_config = GenerationConfig.from_file(
|
||||
generation_config_path.parent
|
||||
)
|
||||
except Exception:
|
||||
logger.error(traceback.format_exc())
|
||||
logger.warning(
|
||||
"Skipping generation config load because of an unexpected error."
|
||||
)
|
||||
|
||||
# Try to set prompt template
|
||||
self.prompt_template = self.find_prompt_template(
|
||||
kwargs.get("prompt_template"), model_directory
|
||||
|
|
@ -315,47 +345,55 @@ class ExllamaV2Container:
|
|||
self.config.max_input_len = chunk_size
|
||||
self.config.max_attention_size = chunk_size**2
|
||||
|
||||
draft_args = unwrap(kwargs.get("draft"), {})
|
||||
draft_model_name = draft_args.get("draft_model_name")
|
||||
enable_draft = draft_args and draft_model_name
|
||||
|
||||
# Always disable draft if params are incorrectly configured
|
||||
if draft_args and draft_model_name is None:
|
||||
logger.warning(
|
||||
"Draft model is disabled because a model name "
|
||||
"wasn't provided. Please check your config.yml!"
|
||||
)
|
||||
enable_draft = False
|
||||
|
||||
# Set user-configured draft model values
|
||||
if enable_draft:
|
||||
self.draft_config = ExLlamaV2Config()
|
||||
self.draft_config.no_flash_attn = self.config.no_flash_attn
|
||||
draft_model_path = pathlib.Path(
|
||||
unwrap(draft_args.get("draft_model_dir"), "models")
|
||||
)
|
||||
draft_model_path = draft_model_path / draft_model_name
|
||||
# Fetch from the updated kwargs
|
||||
draft_args = unwrap(kwargs.get("draft"), {})
|
||||
|
||||
self.draft_config.model_dir = str(draft_model_path.resolve())
|
||||
self.draft_config.prepare()
|
||||
self.draft_config.max_seq_len = self.config.max_seq_len
|
||||
|
||||
self.draft_config.scale_pos_emb = unwrap(
|
||||
draft_args.get("draft_rope_scale"), 1.0
|
||||
)
|
||||
|
||||
# Automatically calculate draft rope alpha
|
||||
self.draft_config.scale_alpha_value = unwrap(
|
||||
draft_args.get("draft_rope_alpha"),
|
||||
self.calculate_rope_alpha(self.draft_config.max_seq_len),
|
||||
)
|
||||
self.draft_config.max_seq_len = self.config.max_seq_len
|
||||
# Set draft rope alpha. Follows same behavior as model rope alpha.
|
||||
draft_rope_alpha = unwrap(draft_args.get("draft_rope_alpha"), "auto")
|
||||
if draft_rope_alpha == "auto":
|
||||
self.draft_config.scale_alpha_value = self.calculate_rope_alpha(
|
||||
self.draft_config.max_seq_len
|
||||
)
|
||||
else:
|
||||
self.draft_config.scale_alpha_value = draft_rope_alpha
|
||||
|
||||
# Set draft cache mode
|
||||
self.draft_cache_mode = unwrap(draft_args.get("draft_cache_mode"), "FP16")
|
||||
|
||||
if chunk_size:
|
||||
self.draft_config.max_input_len = chunk_size
|
||||
self.draft_config.max_attention_size = chunk_size**2
|
||||
|
||||
def set_model_overrides(self, **kwargs):
|
||||
"""Sets overrides from a model folder's config yaml."""
|
||||
|
||||
override_config_path = self.model_dir / "tabby_config.yml"
|
||||
|
||||
if not override_config_path.exists():
|
||||
return kwargs
|
||||
|
||||
with open(override_config_path, "r", encoding="utf8") as override_config_file:
|
||||
override_args = unwrap(yaml.safe_load(override_config_file), {})
|
||||
|
||||
# Merge draft overrides beforehand
|
||||
draft_override_args = unwrap(override_args.get("draft"), {})
|
||||
if self.draft_config and draft_override_args:
|
||||
kwargs["draft"] = {**draft_override_args, **kwargs.get("draft")}
|
||||
|
||||
# Merge the override and model kwargs
|
||||
merged_kwargs = {**override_args, **kwargs}
|
||||
return merged_kwargs
|
||||
|
||||
def find_prompt_template(self, prompt_template_name, model_directory):
|
||||
"""Tries to find a prompt template using various methods"""
|
||||
"""Tries to find a prompt template using various methods."""
|
||||
|
||||
logger.info("Attempting to load a prompt template if present.")
|
||||
|
||||
|
|
@ -397,6 +435,7 @@ class ExllamaV2Container:
|
|||
|
||||
def calculate_rope_alpha(self, base_seq_len):
|
||||
"""Calculate the rope alpha value for a given sequence length."""
|
||||
|
||||
ratio = self.config.max_seq_len / base_seq_len
|
||||
|
||||
# Default to a 1 alpha if the sequence length is ever less
|
||||
|
|
@ -407,20 +446,9 @@ class ExllamaV2Container:
|
|||
alpha = -0.13436 + 0.80541 * ratio + 0.28833 * ratio**2
|
||||
return alpha
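Note: a quick worked example of the alpha curve above, assuming (purely for illustration) a model with a 4096-token base length stretched to 8192, so the ratio is 2.0:

# Standalone sketch of calculate_rope_alpha for ratio = 8192 / 4096
ratio = 8192 / 4096
alpha = -0.13436 + 0.80541 * ratio + 0.28833 * ratio**2
print(round(alpha, 3))  # ~2.63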
|
||||
|
||||
def get_model_path(self, is_draft: bool = False):
|
||||
"""Get the path for this model."""
|
||||
|
||||
if is_draft and not self.draft_config:
|
||||
return None
|
||||
|
||||
model_path = pathlib.Path(
|
||||
self.draft_config.model_dir if is_draft else self.config.model_dir
|
||||
)
|
||||
return model_path
|
||||
|
||||
def get_model_parameters(self):
|
||||
model_params = {
|
||||
"name": self.get_model_path().name,
|
||||
"name": self.model_dir.name,
|
||||
"rope_scale": self.config.scale_pos_emb,
|
||||
"rope_alpha": self.config.scale_alpha_value,
|
||||
"max_seq_len": self.config.max_seq_len,
|
||||
|
|
@ -435,7 +463,7 @@ class ExllamaV2Container:
|
|||
|
||||
if self.draft_config:
|
||||
draft_model_params = {
|
||||
"name": self.get_model_path(is_draft=True).name,
|
||||
"name": self.draft_model_dir.name,
|
||||
"rope_scale": self.draft_config.scale_pos_emb,
|
||||
"rope_alpha": self.draft_config.scale_alpha_value,
|
||||
"max_seq_len": self.draft_config.max_seq_len,
|
||||
|
|
@ -473,7 +501,9 @@ class ExllamaV2Container:
|
|||
|
||||
Args:
|
||||
progress_callback (function, optional): A function to call for each
|
||||
module loaded. Prototype:
|
||||
module loaded.
|
||||
|
||||
Prototype:
|
||||
def progress(loaded_modules: int, total_modules: int)
|
||||
"""
|
||||
|
||||
|
|
@ -518,11 +548,13 @@ class ExllamaV2Container:
|
|||
@torch.inference_mode()
|
||||
def load_model_sync(self, progress_callback=None):
|
||||
"""
|
||||
Load model, generator function
|
||||
Synchronous generator for loading.
|
||||
|
||||
Args:
|
||||
progress_callback (function, optional): A function to call for each
|
||||
module loaded. Prototype:
|
||||
module loaded.
|
||||
|
||||
Prototype:
|
||||
def progress(loaded_modules: int, total_modules: int)
|
||||
|
||||
Runs under a shared inference mode context.
|
||||
|
|
@ -548,30 +580,15 @@ class ExllamaV2Container:
|
|||
if not self.quiet:
|
||||
logger.info("Loading draft model: " + self.draft_config.model_dir)
|
||||
|
||||
if self.draft_cache_mode == "Q4":
|
||||
self.draft_cache = ExLlamaV2Cache_Q4(
|
||||
self.draft_model,
|
||||
max_seq_len=self.cache_size,
|
||||
lazy=True,
|
||||
)
|
||||
elif self.draft_cache_mode == "Q6":
|
||||
self.draft_cache = ExLlamaV2Cache_Q6(
|
||||
self.draft_model,
|
||||
max_seq_len=self.cache_size,
|
||||
lazy=True,
|
||||
)
|
||||
elif self.draft_cache_mode == "Q8":
|
||||
self.draft_cache = ExLlamaV2Cache_Q8(
|
||||
self.draft_model,
|
||||
max_seq_len=self.cache_size,
|
||||
lazy=True,
|
||||
)
|
||||
else:
|
||||
self.draft_cache = ExLlamaV2Cache(
|
||||
self.draft_model,
|
||||
max_seq_len=self.cache_size,
|
||||
lazy=True,
|
||||
)
|
||||
# Draft uses the autosplit loader, so create a cache that reflects this
|
||||
draft_cache_class = self.get_cache_class(self.draft_cache_mode)
|
||||
self.draft_cache = self.create_cache(
|
||||
cache_class=draft_cache_class,
|
||||
autosplit=True,
|
||||
use_tp=False,
|
||||
model=self.draft_model,
|
||||
)
|
||||
|
||||
for value in self.draft_model.load_autosplit_gen(
|
||||
self.draft_cache,
|
||||
reserve_vram=autosplit_reserve,
|
||||
|
|
@ -589,9 +606,23 @@ class ExllamaV2Container:
|
|||
if not self.quiet:
|
||||
logger.info("Loading model: " + self.config.model_dir)
|
||||
|
||||
# Get class of the model cache
|
||||
cache_class = self.get_cache_class(self.cache_mode)
|
||||
|
||||
# Load model with manual split
|
||||
# Entrypoint for single GPU users
|
||||
if not self.gpu_split_auto:
|
||||
if self.use_tp:
|
||||
logger.info("Loading with tensor parallel")
|
||||
|
||||
for value in self.model.load_tp_gen(
|
||||
self.gpu_split,
|
||||
callback_gen=progress_callback,
|
||||
expect_cache_base=cache_class,
|
||||
expect_cache_tokens=self.cache_size,
|
||||
):
|
||||
if value:
|
||||
yield value
|
||||
elif not self.gpu_split_auto:
|
||||
logger.info("Loading with a manual GPU split (or a one GPU setup)")
|
||||
|
||||
for value in self.model.load_gen(
|
||||
|
|
@ -601,37 +632,16 @@ class ExllamaV2Container:
|
|||
if value:
|
||||
yield value
|
||||
|
||||
if self.cache_mode == "Q4":
|
||||
self.cache = ExLlamaV2Cache_Q4(
|
||||
self.model,
|
||||
max_seq_len=self.cache_size,
|
||||
lazy=self.gpu_split_auto,
|
||||
batch_size=1,
|
||||
)
|
||||
elif self.cache_mode == "Q6":
|
||||
self.cache = ExLlamaV2Cache_Q6(
|
||||
self.model,
|
||||
max_seq_len=self.cache_size,
|
||||
lazy=self.gpu_split_auto,
|
||||
batch_size=1,
|
||||
)
|
||||
elif self.cache_mode == "Q8":
|
||||
self.cache = ExLlamaV2Cache_Q8(
|
||||
self.model,
|
||||
max_seq_len=self.cache_size,
|
||||
lazy=self.gpu_split_auto,
|
||||
batch_size=1,
|
||||
)
|
||||
else:
|
||||
self.cache = ExLlamaV2Cache(
|
||||
self.model,
|
||||
max_seq_len=self.cache_size,
|
||||
lazy=self.gpu_split_auto,
|
||||
batch_size=1,
|
||||
)
|
||||
# Create the model cache
|
||||
self.cache = self.create_cache(
|
||||
cache_class=cache_class,
|
||||
autosplit=self.gpu_split_auto,
|
||||
use_tp=self.use_tp,
|
||||
model=self.model,
|
||||
)
|
||||
|
||||
# Load model with autosplit
|
||||
if self.gpu_split_auto:
|
||||
# Load model with autosplit (without TP)
|
||||
if self.gpu_split_auto and not self.use_tp:
|
||||
logger.info("Loading with autosplit")
|
||||
|
||||
for value in self.model.load_autosplit_gen(
|
||||
|
|
@ -647,7 +657,47 @@ class ExllamaV2Container:
|
|||
input_ids = torch.zeros((1, self.config.max_input_len), dtype=torch.long)
|
||||
self.model.forward(input_ids, cache=self.cache, preprocess_only=True)
|
||||
|
||||
# TODO: Maybe make a wrapper class with an ID instead of a utility function
|
||||
def get_cache_class(self, cache_mode: str):
|
||||
"""Utility function to get a cache class based on user preference."""
|
||||
|
||||
match cache_mode:
|
||||
case "Q4":
|
||||
return ExLlamaV2Cache_Q4
|
||||
case "Q6":
|
||||
return ExLlamaV2Cache_Q6
|
||||
case "Q8":
|
||||
return ExLlamaV2Cache_Q8
|
||||
case _:
|
||||
return ExLlamaV2Cache
|
||||
|
||||
def create_cache(
|
||||
self,
|
||||
cache_class: ExLlamaV2CacheBase,
|
||||
autosplit: bool,
|
||||
use_tp: bool,
|
||||
model: ExLlamaV2,
|
||||
):
|
||||
"""Utility function to create a model cache."""
|
||||
|
||||
if has_tp and use_tp:
|
||||
return ExLlamaV2Cache_TP(
|
||||
model,
|
||||
base=cache_class,
|
||||
max_seq_len=self.cache_size,
|
||||
batch_size=1,
|
||||
)
|
||||
else:
|
||||
return cache_class(
|
||||
model,
|
||||
max_seq_len=self.cache_size,
|
||||
lazy=autosplit,
|
||||
batch_size=1,
|
||||
)
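Note: a hedged sketch of the cache-mode mapping that get_cache_class implements, assuming the quantized cache classes exported by exllamav2 0.1.9 (Q4/Q6/Q8), as used elsewhere in this file; unknown modes fall back to the FP16 cache:

from exllamav2 import (
    ExLlamaV2Cache,
    ExLlamaV2Cache_Q4,
    ExLlamaV2Cache_Q6,
    ExLlamaV2Cache_Q8,
)

def cache_class_for(mode: str):
    """Mirror of get_cache_class: unknown modes fall back to the FP16 cache."""
    return {
        "Q4": ExLlamaV2Cache_Q4,
        "Q6": ExLlamaV2Cache_Q6,
        "Q8": ExLlamaV2Cache_Q8,
    }.get(mode, ExLlamaV2Cache)

print(cache_class_for("Q6").__name__)    # ExLlamaV2Cache_Q6
print(cache_class_for("FP16").__name__)  # ExLlamaV2Cache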
|
||||
|
||||
async def create_generator(self):
|
||||
"""Create and save a Exllama generator class."""
|
||||
|
||||
try:
|
||||
# Don't acquire locks unless a model is loaded
|
||||
if self.model_loaded:
|
||||
|
|
@ -681,9 +731,7 @@ class ExllamaV2Container:
|
|||
return unwrap(self.generator.generator.current_loras, [])
|
||||
|
||||
async def load_loras(self, lora_directory: pathlib.Path, **kwargs):
|
||||
"""
|
||||
Load loras
|
||||
"""
|
||||
"""Load loras."""
|
||||
|
||||
loras = unwrap(kwargs.get("loras"), [])
|
||||
|
||||
|
|
@ -730,9 +778,7 @@ class ExllamaV2Container:
|
|||
self.load_condition.notify_all()
|
||||
|
||||
async def unload(self, loras_only: bool = False, **kwargs):
|
||||
"""
|
||||
Free all VRAM resources used by this model
|
||||
"""
|
||||
"""Free all VRAM resources used by the model (and loras)."""
|
||||
|
||||
# Shutdown immediately unloads and bypasses all locks
|
||||
do_shutdown = kwargs.get("shutdown")
|
||||
|
|
@ -789,7 +835,7 @@ class ExllamaV2Container:
|
|||
self.load_condition.notify_all()
|
||||
|
||||
def encode_tokens(self, text: str, **kwargs):
|
||||
"""Wrapper to encode tokens from a text string"""
|
||||
"""Wrapper to encode tokens from a text string."""
|
||||
|
||||
return (
|
||||
self.tokenizer.encode(
|
||||
|
|
@ -824,7 +870,7 @@ class ExllamaV2Container:
|
|||
def get_logprobs(self, token_ids: torch.Tensor, token_probs: torch.Tensor):
|
||||
top_tokens = [
|
||||
self.tokenizer.extended_id_to_piece.get(
|
||||
index, self.tokenizer.id_to_piece[index]
|
||||
index, self.tokenizer.get_id_to_piece_list(True)[index]
|
||||
)
|
||||
for index in token_ids.flatten().tolist()
|
||||
]
|
||||
|
|
@ -841,7 +887,7 @@ class ExllamaV2Container:
|
|||
async def generate(
|
||||
self, prompt: str, request_id: str, abort_event: asyncio.Event = None, **kwargs
|
||||
):
|
||||
"""Generate a response to a prompt"""
|
||||
"""Generate a response to a prompt."""
|
||||
generations = []
|
||||
async for generation in self.generate_gen(
|
||||
prompt, request_id, abort_event, **kwargs
|
||||
|
|
@ -852,6 +898,7 @@ class ExllamaV2Container:
|
|||
"text": "",
|
||||
"prompt_tokens": 0,
|
||||
"generation_tokens": 0,
|
||||
"tool_calls": None,
|
||||
"offset": [],
|
||||
"token_probs": {},
|
||||
"logprobs": [],
|
||||
|
|
@ -864,6 +911,7 @@ class ExllamaV2Container:
|
|||
joined_generation["finish_reason"] = finish_reason_gen.get(
|
||||
"finish_reason"
|
||||
)
|
||||
joined_generation["stop_str"] = finish_reason_gen.get("stop_str")
|
||||
else:
|
||||
joined_generation["finish_reason"] = "stop"
|
||||
|
||||
|
|
@ -890,7 +938,11 @@ class ExllamaV2Container:
|
|||
return joined_generation
|
||||
|
||||
def check_unsupported_settings(self, **kwargs):
|
||||
"""Check and warn the user if a sampler is unsupported. Meant for dev wheels!"""
|
||||
"""
|
||||
Check and warn the user if a sampler is unsupported.
|
||||
|
||||
Meant for dev wheels!
|
||||
"""
|
||||
|
||||
return kwargs
|
||||
|
||||
|
|
@ -1068,6 +1120,15 @@ class ExllamaV2Container:
|
|||
gen_settings.top_p = 0
|
||||
gen_settings.typical = 0
|
||||
|
||||
logger.warning(
|
||||
"".join(
|
||||
[
|
||||
"Temperature is set to 0. Overriding temp, ",
|
||||
"top_k, top_p, and typical to 1.0, 1, 0, and 0.",
|
||||
]
|
||||
)
|
||||
)
|
||||
|
||||
# Store the gen settings for logging purposes
|
||||
gen_settings_log_dict = vars(gen_settings)
|
||||
|
||||
|
|
@ -1076,6 +1137,11 @@ class ExllamaV2Container:
|
|||
if banned_tokens:
|
||||
gen_settings.disallow_tokens(self.tokenizer, banned_tokens)
|
||||
|
||||
# Set allowed tokens
|
||||
allowed_tokens = unwrap(kwargs.get("allowed_tokens"), [])
|
||||
if allowed_tokens:
|
||||
gen_settings.allow_tokens(self.tokenizer, allowed_tokens)
|
||||
|
||||
# Set logit bias
|
||||
if logit_bias:
|
||||
# Create a vocab tensor if it doesn't exist for token biasing
|
||||
|
|
@ -1088,7 +1154,7 @@ class ExllamaV2Container:
|
|||
|
||||
# Map logits to the tensor with their biases
|
||||
for token_id, bias in logit_bias.items():
|
||||
if 0 <= token_id < len(self.tokenizer.id_to_piece):
|
||||
if 0 <= token_id < len(self.tokenizer.get_id_to_piece_list(True)):
|
||||
gen_settings.token_bias[token_id] = bias
|
||||
else:
|
||||
logger.warning(
|
||||
|
|
@ -1140,8 +1206,12 @@ class ExllamaV2Container:
|
|||
# This is an inverse of skip_special_tokens
|
||||
decode_special_tokens = unwrap(not kwargs.get("skip_special_tokens"), False)
|
||||
|
||||
# Log prompt to console
|
||||
log_prompt(prompt, request_id, negative_prompt)
|
||||
# Log prompt to console. Add the BOS token if specified
|
||||
log_prompt(
|
||||
f"{self.tokenizer.bos_token if add_bos_token else ''}{prompt}",
|
||||
request_id,
|
||||
negative_prompt,
|
||||
)
|
||||
|
||||
# Create and add a new job
|
||||
# Don't use the request ID here as there can be multiple jobs per request
|
||||
|
|
@ -1227,9 +1297,17 @@ class ExllamaV2Container:
|
|||
log_response(request_id, full_response)
|
||||
|
||||
eos_reason = result.get("eos_reason")
|
||||
finish_reason = (
|
||||
"length" if eos_reason == "max_new_tokens" else "stop"
|
||||
)
|
||||
|
||||
stop_str = None
|
||||
if eos_reason == "max_new_tokens":
|
||||
finish_reason = "length"
|
||||
else:
|
||||
finish_reason = "stop"
|
||||
# Grab stop string if stop was the reason
|
||||
if eos_reason == "stop_token":
|
||||
stop_str = result.get("eos_triggering_token_str")
|
||||
elif eos_reason == "stop_string":
|
||||
stop_str = result.get("eos_triggering_string")
|
||||
|
||||
# Save the final result for metrics logging
|
||||
metrics_result = result
|
||||
|
|
@ -1239,6 +1317,7 @@ class ExllamaV2Container:
|
|||
"prompt_tokens": generation.get("prompt_tokens"),
|
||||
"generated_tokens": generation.get("generated_tokens"),
|
||||
"finish_reason": finish_reason,
|
||||
"stop_str": stop_str,
|
||||
}
|
||||
|
||||
yield generation
|
||||
|
|
@ -1277,6 +1356,7 @@ class ExllamaV2Container:
|
|||
logprobs=request_logprobs,
|
||||
stop_conditions=stop_conditions,
|
||||
banned_tokens=banned_tokens,
|
||||
allowed_tokens=allowed_tokens,
|
||||
banned_strings=banned_strings,
|
||||
logit_bias=logit_bias,
|
||||
filters=grammar_handler.filters,
|
||||
|
|
|
|||
|
|
@ -8,7 +8,7 @@ from loguru import logger
|
|||
def check_exllama_version():
|
||||
"""Verifies the exllama version"""
|
||||
|
||||
required_version = version.parse("0.1.7")
|
||||
required_version = version.parse("0.1.9")
|
||||
current_version = version.parse(package_version("exllamav2").split("+")[0])
|
||||
|
||||
unsupported_message = (
|
||||
|
|
|
|||
|
|
@ -13,6 +13,24 @@ def str_to_bool(value):
|
|||
raise ValueError(f"{value} is not a valid boolean value")
|
||||
|
||||
|
||||
def argument_with_auto(value):
|
||||
"""
|
||||
Argparse type wrapper for any argument that has an automatic option.
|
||||
|
||||
Ex. rope_alpha
|
||||
"""
|
||||
|
||||
if value == "auto":
|
||||
return "auto"
|
||||
|
||||
try:
|
||||
return float(value)
|
||||
except ValueError as ex:
|
||||
raise argparse.ArgumentTypeError(
|
||||
'This argument only takes a type of float or "auto"'
|
||||
) from ex
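Note: a small usage sketch for the new argparse type; the parser below is illustrative, though --rope-alpha is wired up the same way further down in this diff. The helper is restated so the snippet runs on its own:

import argparse

def argument_with_auto(value):
    # Same behavior as the helper above: pass "auto" through, otherwise parse a float
    if value == "auto":
        return "auto"
    try:
        return float(value)
    except ValueError as ex:
        raise argparse.ArgumentTypeError(
            'This argument only takes a type of float or "auto"'
        ) from ex

parser = argparse.ArgumentParser()
parser.add_argument("--rope-alpha", type=argument_with_auto)

print(parser.parse_args(["--rope-alpha", "auto"]).rope_alpha)  # auto
print(parser.parse_args(["--rope-alpha", "2.5"]).rope_alpha)   # 2.5
# "--rope-alpha fast" would exit with the ArgumentTypeError message above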
|
||||
|
||||
|
||||
def init_argparser():
|
||||
"""Creates an argument parser that any function can use"""
|
||||
|
||||
|
|
@ -107,6 +125,11 @@ def add_model_args(parser: argparse.ArgumentParser):
|
|||
type=str_to_bool,
|
||||
help="Overrides base model context length",
|
||||
)
|
||||
model_group.add_argument(
|
||||
"--tensor-parallel",
|
||||
type=str_to_bool,
|
||||
help="Use tensor parallelism to load models",
|
||||
)
|
||||
model_group.add_argument(
|
||||
"--gpu-split-auto",
|
||||
type=str_to_bool,
|
||||
|
|
@ -128,7 +151,11 @@ def add_model_args(parser: argparse.ArgumentParser):
|
|||
model_group.add_argument(
|
||||
"--rope-scale", type=float, help="Sets rope_scale or compress_pos_emb"
|
||||
)
|
||||
model_group.add_argument("--rope-alpha", type=float, help="Sets rope_alpha for NTK")
|
||||
model_group.add_argument(
|
||||
"--rope-alpha",
|
||||
type=argument_with_auto,
|
||||
help="Sets rope_alpha for NTK",
|
||||
)
|
||||
model_group.add_argument(
|
||||
"--cache-mode",
|
||||
type=str,
|
||||
|
|
|
|||
|
|
@ -57,7 +57,7 @@ async def load_model_gen(model_path: pathlib.Path, **kwargs):
|
|||
|
||||
# Check if the model is already loaded
|
||||
if container and container.model:
|
||||
loaded_model_name = container.get_model_path().name
|
||||
loaded_model_name = container.model_dir.name
|
||||
|
||||
if loaded_model_name == model_path.name and container.model_loaded:
|
||||
raise ValueError(
|
||||
|
|
@ -149,7 +149,8 @@ async def unload_embedding_model():
|
|||
embeddings_container = None
|
||||
|
||||
|
||||
def get_config_default(key: str, fallback=None, model_type: str = "model"):
|
||||
# FIXME: Maybe make this a one-time function instead of a dynamic default
|
||||
def get_config_default(key: str, model_type: str = "model"):
|
||||
"""Fetches a default value from model config if allowed by the user."""
|
||||
|
||||
model_config = config.model_config()
|
||||
|
|
@ -162,14 +163,12 @@ def get_config_default(key: str, fallback=None, model_type: str = "model"):
|
|||
# Is this a draft model load parameter?
|
||||
if model_type == "draft":
|
||||
draft_config = config.draft_model_config()
|
||||
return unwrap(draft_config.get(key), fallback)
|
||||
return draft_config.get(key)
|
||||
elif model_type == "embedding":
|
||||
embeddings_config = config.embeddings_config()
|
||||
return unwrap(embeddings_config.get(key), fallback)
|
||||
return embeddings_config.get(key)
|
||||
else:
|
||||
return unwrap(model_config.get(key), fallback)
|
||||
else:
|
||||
return fallback
|
||||
return model_config.get(key)
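Note: with the fallback parameter removed, any defaulting now happens at the call site (the backend already does this with unwrap(kwargs.get(...), default)). A hedged sketch of combining the two helpers; the key and fallback value below are illustrative, not taken from this diff:

from common.model import get_config_default
from common.utils import unwrap

# Illustrative call site: pull a draft-model key from config, defaulting at the caller
draft_cache_mode = unwrap(
    get_config_default("draft_cache_mode", model_type="draft"), "FP16"
)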
|
||||
|
||||
|
||||
async def check_model_container():
|
||||
|
|
|
|||
|
|
@ -33,7 +33,7 @@ class BaseSamplerRequest(BaseModel):
|
|||
examples=[512],
|
||||
)
|
||||
|
||||
stop: Optional[Union[str, List[str]]] = Field(
|
||||
stop: Optional[Union[str, List[Union[str, int]]]] = Field(
|
||||
default_factory=lambda: get_default_sampler_value("stop", []),
|
||||
validation_alias=AliasChoices("stop", "stop_sequence"),
|
||||
description="Aliases: stop_sequence",
|
||||
|
|
@ -50,6 +50,13 @@ class BaseSamplerRequest(BaseModel):
|
|||
examples=[[128, 330]],
|
||||
)
|
||||
|
||||
allowed_tokens: Optional[Union[List[int], str]] = Field(
|
||||
default_factory=lambda: get_default_sampler_value("allowed_tokens", []),
|
||||
validation_alias=AliasChoices("allowed_tokens", "allowed_token_ids"),
|
||||
description="Aliases: allowed_token_ids",
|
||||
examples=[[128, 330]],
|
||||
)
|
||||
|
||||
token_healing: Optional[bool] = Field(
|
||||
default_factory=lambda: get_default_sampler_value("token_healing", False)
|
||||
)
|
||||
|
|
@ -287,12 +294,17 @@ class BaseSamplerRequest(BaseModel):
|
|||
if self.banned_strings and isinstance(self.banned_strings, str):
|
||||
self.banned_strings = [self.banned_strings]
|
||||
|
||||
# Convert string banned tokens to an integer list
|
||||
# Convert string banned and allowed tokens to an integer list
|
||||
if self.banned_tokens and isinstance(self.banned_tokens, str):
|
||||
self.banned_tokens = [
|
||||
int(x) for x in self.banned_tokens.split(",") if x.isdigit()
|
||||
]
|
||||
|
||||
if self.allowed_tokens and isinstance(self.allowed_tokens, str):
|
||||
self.allowed_tokens = [
|
||||
int(x) for x in self.allowed_tokens.split(",") if x.isdigit()
|
||||
]
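Note: a worked example of the comma-separated parsing above (inputs are illustrative); entries that aren't pure digit strings, including ones with stray whitespace, are silently dropped:

# Same parsing as above, applied to a comma-separated string
tokens = "128,330,eos,42"
print([int(x) for x in tokens.split(",") if x.isdigit()])  # [128, 330, 42]

# Whitespace around a comma makes isdigit() False, so that entry is dropped
tokens = "128, 330"
print([int(x) for x in tokens.split(",") if x.isdigit()])  # [128]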
|
||||
|
||||
gen_params = {
|
||||
"max_tokens": self.max_tokens,
|
||||
"min_tokens": self.min_tokens,
|
||||
|
|
@ -305,6 +317,7 @@ class BaseSamplerRequest(BaseModel):
|
|||
"token_healing": self.token_healing,
|
||||
"logit_bias": self.logit_bias,
|
||||
"banned_tokens": self.banned_tokens,
|
||||
"allowed_tokens": self.allowed_tokens,
|
||||
"temperature": self.temperature,
|
||||
"temperature_last": self.temperature_last,
|
||||
"min_temp": self.min_temp,
|
||||
|
|
|
|||
|
|
@ -3,7 +3,7 @@
|
|||
import json
|
||||
import pathlib
|
||||
from importlib.metadata import version as package_version
|
||||
from typing import Optional
|
||||
from typing import List, Optional
|
||||
from jinja2 import Template, TemplateError
|
||||
from jinja2.sandbox import ImmutableSandboxedEnvironment
|
||||
from loguru import logger
|
||||
|
|
@ -18,6 +18,13 @@ class TemplateLoadError(Exception):
|
|||
pass
|
||||
|
||||
|
||||
class TemplateMetadata:
|
||||
"""Represents the parsed metadata from a template."""
|
||||
|
||||
stop_strings: List[str] = []
|
||||
tool_starts: List[str] = []
|
||||
|
||||
|
||||
class PromptTemplate:
|
||||
"""A template for chat completion prompts."""
|
||||
|
||||
|
|
@ -25,27 +32,48 @@ class PromptTemplate:
|
|||
raw_template: str
|
||||
template: Template
|
||||
environment: ImmutableSandboxedEnvironment = ImmutableSandboxedEnvironment(
|
||||
trim_blocks=True, lstrip_blocks=True
|
||||
trim_blocks=True, lstrip_blocks=True, enable_async=True
|
||||
)
|
||||
metadata: Optional[TemplateMetadata] = None
|
||||
|
||||
def stop_strings(self, template_vars: dict):
|
||||
"""Appends extra stop strings if present in a chat template."""
|
||||
async def extract_metadata(self, template_vars: dict):
|
||||
"""
|
||||
Returns deserialized template metadata from a chat template.
|
||||
|
||||
extra_stop_strings = []
|
||||
template_module = self.template.make_module(template_vars)
|
||||
NOTE: Requires all template vars to be passed in since the template
|
||||
is run once to make a module and errors can result.
|
||||
"""
|
||||
|
||||
# No need to extract new metadata if it already exists
|
||||
# This might be removed if stored metadata becomes arbitrary
|
||||
if self.metadata:
|
||||
return self.metadata
|
||||
|
||||
template_metadata = TemplateMetadata()
|
||||
|
||||
template_module = await self.template.make_module_async(template_vars)
|
||||
|
||||
if hasattr(template_module, "stop_strings"):
|
||||
if isinstance(template_module.stop_strings, list):
|
||||
extra_stop_strings += template_module.stop_strings
|
||||
template_metadata.stop_strings += template_module.stop_strings
|
||||
else:
|
||||
logger.warning(
|
||||
"Skipping append of stopping strings from chat template "
|
||||
"because stop_strings isn't a list."
|
||||
)
|
||||
|
||||
return extra_stop_strings
|
||||
if hasattr(template_module, "tool_start"):
|
||||
if isinstance(template_module.tool_start, str):
|
||||
template_metadata.tool_starts.append(template_module.tool_start)
|
||||
|
||||
def render(self, template_vars: dict):
|
||||
if hasattr(template_module, "tool_start_token"):
|
||||
if isinstance(template_module.tool_start_token, int):
|
||||
template_metadata.tool_starts.append(template_module.tool_start_token)
|
||||
|
||||
self.metadata = template_metadata
|
||||
return template_metadata
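Note: extract_metadata reads module-level variables that a chat template can export. A minimal, hypothetical template showing the names it looks for (stop_strings, tool_start); the template string and values below are invented for illustration:

import asyncio
from jinja2.sandbox import ImmutableSandboxedEnvironment

# Hypothetical template front-matter; the exported names are what tabbyAPI reads
template_str = (
    '{% set stop_strings = ["<|im_end|>"] %}'
    '{% set tool_start = "<tool_call>" %}'
    "{{ messages | length }} message(s)"
)

async def main():
    env = ImmutableSandboxedEnvironment(enable_async=True)
    template = env.from_string(template_str)
    module = await template.make_module_async({"messages": []})
    print(module.stop_strings, module.tool_start)  # ['<|im_end|>'] <tool_call>

asyncio.run(main())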
|
||||
|
||||
async def render(self, template_vars: dict):
|
||||
"""Get a prompt from a template and a list of messages."""
|
||||
if version.parse(package_version("jinja2")) < version.parse("3.0.0"):
|
||||
raise ImportError(
|
||||
|
|
@ -55,10 +83,9 @@ class PromptTemplate:
|
|||
"pip install --upgrade jinja2"
|
||||
)
|
||||
|
||||
rendered_template = self.template.render(**template_vars)
|
||||
template_stop_strings = self.stop_strings(template_vars)
|
||||
rendered_template = await self.template.render_async(**template_vars)
|
||||
|
||||
return rendered_template, template_stop_strings
|
||||
return rendered_template
|
||||
|
||||
def compile(self, template_str: str):
|
||||
"""Compiles and stores a jinja2 template"""
|
||||
|
|
|
|||
|
|
@ -109,6 +109,11 @@ model:
|
|||
# Only use this if the model's base sequence length in config.json is incorrect (ex. Mistral 7B)
|
||||
#override_base_seq_len:
|
||||
|
||||
# Load model with tensor parallelism
|
||||
# If a GPU split isn't provided, the TP loader will fall back to autosplit
|
||||
# Enabling ignores the gpu_split_auto and autosplit_reserve values
|
||||
#tensor_parallel: False
|
||||
|
||||
# Automatically allocate resources to GPUs (default: True)
|
||||
# NOTE: Not parsed for single GPU users
|
||||
#gpu_split_auto: True
|
||||
|
|
@ -118,6 +123,7 @@ model:
|
|||
#autosplit_reserve: [96]
|
||||
|
||||
# An integer array of GBs of vram to split between GPUs (default: [])
|
||||
# Used with tensor parallelism
|
||||
# NOTE: Not parsed for single GPU users
|
||||
#gpu_split: [20.6, 24]
|
||||
|
||||
|
|
@ -129,7 +135,8 @@ model:
|
|||
|
||||
# Rope alpha (default: 1.0)
|
||||
# Same thing as alpha_value
|
||||
# Leave blank to automatically calculate alpha
|
||||
# Set to "auto" to automatically calculate
|
||||
# Leave blank to pull the value from the model
|
||||
#rope_alpha: 1.0
|
||||
|
||||
# Enable different cache modes for VRAM savings (slight performance hit).
|
||||
|
|
|
|||
|
|
@ -1,4 +1,6 @@
|
|||
models/
|
||||
loras/
|
||||
.ruff_cache/
|
||||
**/__pycache__/
|
||||
**/__pycache__/
|
||||
config.yml
|
||||
api_tokens.yml
|
||||
|
|
@ -1,12 +1,5 @@
|
|||
# Use an official CUDA runtime with Ubuntu as a parent image
|
||||
FROM nvidia/cuda:12.4.1-runtime-ubuntu22.04
|
||||
|
||||
ARG GIT_REPO=https://github.com/theroyallab/tabbyAPI
|
||||
ARG DO_PULL=true
|
||||
ENV DO_PULL $DO_PULL
|
||||
|
||||
# Set the working directory in the container
|
||||
WORKDIR /app
|
||||
FROM nvidia/cuda:12.1.0-runtime-ubuntu22.04
|
||||
|
||||
# Install system dependencies
|
||||
RUN apt-get update && apt-get install -y --no-install-recommends \
|
||||
|
|
@ -15,20 +8,14 @@ RUN apt-get update && apt-get install -y --no-install-recommends \
|
|||
ca-certificates \
|
||||
python3.11 \
|
||||
python3-pip \
|
||||
git \
|
||||
&& rm -rf /var/lib/apt/lists/*
|
||||
|
||||
# Update repo
|
||||
RUN if [ ${DO_PULL} ]; then \
|
||||
git init && \
|
||||
git remote add origin $GIT_REPO && \
|
||||
git fetch origin && \
|
||||
git pull origin main && \
|
||||
echo "Pull finished"; fi
|
||||
|
||||
# Upgrade pip
|
||||
RUN pip3 install --no-cache-dir --upgrade pip
|
||||
|
||||
# Set the working directory in the container
|
||||
WORKDIR /app
|
||||
|
||||
# Get requirements
|
||||
COPY pyproject.toml .
|
||||
|
||||
|
|
@ -38,9 +25,6 @@ RUN pip3 install --no-cache-dir .[cu121]
|
|||
# Copy the current directory contents into the container
|
||||
COPY . .
|
||||
|
||||
# Create a config.yml
|
||||
COPY config_sample.yml config.yml
|
||||
|
||||
# Make port 5000 available to the world outside this container
|
||||
EXPOSE 5000
|
||||
|
||||
|
|
|
|||
|
|
@ -8,11 +8,18 @@ services:
|
|||
- DO_PULL=true
|
||||
ports:
|
||||
- "5000:5000"
|
||||
healthcheck:
|
||||
test: ["CMD", "curl", "-f", "http://127.0.0.1:5000/health"]
|
||||
interval: 30s
|
||||
timeout: 10s
|
||||
retries: 3
|
||||
environment:
|
||||
- NAME=TabbyAPI
|
||||
- NVIDIA_VISIBLE_DEVICES=all
|
||||
volumes:
|
||||
- ./models:/app/models
|
||||
- ./models:/app/models # Change me
|
||||
# - /path/to/config.yml:/app/config.yml # Change me
|
||||
# - /path/to/api_tokens.yml:/app/api_tokens.yml # Change me
|
||||
deploy:
|
||||
resources:
|
||||
reservations:
|
||||
|
|
|
|||
|
|
@ -54,7 +54,7 @@ async def completion_request(
|
|||
If stream = true, this returns an SSE stream.
|
||||
"""
|
||||
|
||||
model_path = model.container.get_model_path()
|
||||
model_path = model.container.model_dir
|
||||
|
||||
if isinstance(data.prompt, list):
|
||||
data.prompt = "\n".join(data.prompt)
|
||||
|
|
@ -153,12 +153,12 @@ async def chat_completion_request(
|
|||
|
||||
raise HTTPException(422, error_message)
|
||||
|
||||
model_path = model.container.get_model_path()
|
||||
model_path = model.container.model_dir
|
||||
|
||||
if isinstance(data.messages, str):
|
||||
prompt = data.messages
|
||||
else:
|
||||
prompt = format_prompt_with_template(data)
|
||||
prompt = await format_prompt_with_template(data)
|
||||
|
||||
# Set an empty JSON schema if the request wants a JSON response
|
||||
if data.response_format.type == "json":
|
||||
|
|
|
|||
|
|
@ -1,9 +1,11 @@
|
|||
from pydantic import BaseModel, Field
|
||||
from pydantic.json_schema import SkipJsonSchema
|
||||
from time import time
|
||||
from typing import Union, List, Optional, Dict
|
||||
from uuid import uuid4
|
||||
|
||||
from endpoints.OAI.types.common import UsageStats, CommonCompletionRequest
|
||||
from endpoints.OAI.types.tools import ToolSpec, ToolCall, tool_call_schema
|
||||
|
||||
|
||||
class ChatCompletionLogprob(BaseModel):
|
||||
|
|
@ -19,12 +21,16 @@ class ChatCompletionLogprobs(BaseModel):
|
|||
class ChatCompletionMessage(BaseModel):
|
||||
role: Optional[str] = None
|
||||
content: Optional[str] = None
|
||||
tool_calls: Optional[List[ToolCall]] = None
|
||||
|
||||
|
||||
class ChatCompletionRespChoice(BaseModel):
|
||||
# Index is 0 since we aren't using multiple choices
|
||||
index: int = 0
|
||||
finish_reason: Optional[str] = None
|
||||
|
||||
# Lets us understand why it stopped and if we need to generate a tool_call
|
||||
stop_str: Optional[str] = None
|
||||
message: ChatCompletionMessage
|
||||
logprobs: Optional[ChatCompletionLogprobs] = None
|
||||
|
||||
|
|
@ -42,13 +48,29 @@ class ChatCompletionRequest(CommonCompletionRequest):
|
|||
# Messages
|
||||
# Take in a string as well even though it's not part of the OAI spec
|
||||
# support messages.content as a list of dict
|
||||
messages: Union[str, List[Dict[str, Union[str, List[Dict[str, str]]]]]]
|
||||
|
||||
# WIP this can probably be tightened, or maybe match the OAI lib type
|
||||
# in openai\types\chat\chat_completion_message_param.py
|
||||
messages: Union[str, List[Dict]]
|
||||
prompt_template: Optional[str] = None
|
||||
add_generation_prompt: Optional[bool] = True
|
||||
template_vars: Optional[dict] = {}
|
||||
response_prefix: Optional[str] = None
|
||||
model: Optional[str] = None
|
||||
|
||||
# tools follows the OAI schema format; functions is more flexible.
|
||||
# both are available in the chat template.
|
||||
|
||||
tools: Optional[List[ToolSpec]] = None
|
||||
functions: Optional[List[Dict]] = None
|
||||
|
||||
# Typically collected from Chat Template.
|
||||
# Don't include this in the OpenAPI docs
|
||||
# TODO: Use these custom parameters
|
||||
tool_call_start: SkipJsonSchema[Optional[List[Union[str, int]]]] = None
|
||||
tool_call_end: SkipJsonSchema[Optional[str]] = None
|
||||
tool_call_schema: SkipJsonSchema[Optional[dict]] = tool_call_schema
|
||||
|
||||
|
||||
class ChatCompletionResponse(BaseModel):
|
||||
id: str = Field(default_factory=lambda: f"chatcmpl-{uuid4().hex}")
|
||||
|
|
|
|||
endpoints/OAI/types/tools.py (new file, 58 lines)

@ -0,0 +1,58 @@
|
|||
from pydantic import BaseModel
|
||||
from typing import Dict, Literal
|
||||
|
||||
tool_call_schema = {
|
||||
"$schema": "http://json-schema.org/draft-07/schema#",
|
||||
"type": "array",
|
||||
"items": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"id": {"type": "string"},
|
||||
"function": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"name": {"type": "string"},
|
||||
"arguments": {
|
||||
# Converted to OAI's string in post process
|
||||
"type": "object"
|
||||
},
|
||||
},
|
||||
"required": ["name", "arguments"],
|
||||
},
|
||||
"type": {"type": "string", "enum": ["function"]},
|
||||
},
|
||||
"required": ["id", "function", "type"],
|
||||
},
|
||||
}
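Note: an example document that satisfies tool_call_schema, checked here with the third-party jsonschema package (not a project dependency, used purely for illustration); the id, function name, and arguments are invented:

from jsonschema import validate  # assumption: pip install jsonschema

example_tool_calls = [
    {
        "id": "call_0",
        "type": "function",
        "function": {
            "name": "get_weather",
            # Still an object here; converted to an OAI-style string in post-processing
            "arguments": {"city": "Boston", "unit": "celsius"},
        },
    }
]

validate(example_tool_calls, tool_call_schema)  # raises ValidationError if malformed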
|
||||
|
||||
|
||||
class Function(BaseModel):
|
||||
"""Represents a description of a tool function."""
|
||||
|
||||
name: str
|
||||
description: str
|
||||
parameters: Dict[str, object]
|
||||
|
||||
|
||||
class ToolSpec(BaseModel):
|
||||
"""Wrapper for an inner tool function."""
|
||||
|
||||
function: Function
|
||||
type: Literal["function"]
|
||||
|
||||
|
||||
class Tool(BaseModel):
|
||||
"""Represents an OAI tool description."""
|
||||
|
||||
name: str
|
||||
|
||||
# Makes more sense to be a dict, but OAI knows best
|
||||
arguments: str
|
||||
|
||||
|
||||
class ToolCall(BaseModel):
|
||||
"""Represents an OAI tool description."""
|
||||
|
||||
id: str
|
||||
function: Tool
|
||||
type: Literal["function"]
|
||||
|
|
@ -5,6 +5,7 @@ import pathlib
|
|||
from asyncio import CancelledError
|
||||
from copy import deepcopy
|
||||
from typing import List, Optional
|
||||
import json
|
||||
|
||||
from fastapi import HTTPException, Request
|
||||
from jinja2 import TemplateError
|
||||
|
|
@ -30,6 +31,7 @@ from endpoints.OAI.types.chat_completion import (
|
|||
)
|
||||
from endpoints.OAI.types.common import UsageStats
|
||||
from endpoints.OAI.utils.completion import _stream_collector
|
||||
from endpoints.OAI.types.tools import ToolCall
|
||||
|
||||
|
||||
def _create_response(
|
||||
|
|
@ -46,6 +48,10 @@ def _create_response(
|
|||
role="assistant", content=unwrap(generation.get("text"), "")
|
||||
)
|
||||
|
||||
tool_calls = generation["tool_calls"]
|
||||
if tool_calls:
|
||||
message.tool_calls = postprocess_tool_call(tool_calls)
|
||||
|
||||
logprob_response = None
|
||||
|
||||
token_probs = unwrap(generation.get("token_probs"), {})
|
||||
|
|
@ -72,6 +78,7 @@ def _create_response(
|
|||
choice = ChatCompletionRespChoice(
|
||||
index=index,
|
||||
finish_reason=generation.get("finish_reason"),
|
||||
stop_str=generation.get("stop_str"),
|
||||
message=message,
|
||||
logprobs=logprob_response,
|
||||
)
|
||||
|
|
@ -119,7 +126,16 @@ def _create_stream_chunk(
|
|||
finish_reason=generation.get("finish_reason"),
|
||||
)
|
||||
|
||||
# Check if we have tool calls since we are at the end of the generation
|
||||
if "tool_calls" in generation:
|
||||
tool_calls = generation["tool_calls"]
|
||||
message = ChatCompletionMessage(
|
||||
tool_calls=postprocess_tool_call(tool_calls)
|
||||
)
|
||||
choice.delta = message
|
||||
|
||||
choices.append(choice)
|
||||
|
||||
else:
|
||||
message = ChatCompletionMessage(
|
||||
role="assistant", content=unwrap(generation.get("text"), "")
|
||||
|
|
@ -162,7 +178,31 @@ def _create_stream_chunk(
|
|||
return chunk
|
||||
|
||||
|
||||
def format_prompt_with_template(data: ChatCompletionRequest):
|
||||
async def _append_template_metadata(data: ChatCompletionRequest):
|
||||
"""Adding metadata is a one-time process."""
|
||||
|
||||
template_metadata = await model.container.prompt_template.extract_metadata(
|
||||
data.template_vars
|
||||
)
|
||||
|
||||
# Stop strings
|
||||
if isinstance(data.stop, str):
|
||||
data.stop = [data.stop] + template_metadata.stop_strings
|
||||
else:
|
||||
data.stop += template_metadata.stop_strings
|
||||
|
||||
# Tool call start strings
|
||||
if template_metadata.tool_starts:
|
||||
if data.tool_call_start is None:
|
||||
data.tool_call_start = template_metadata.tool_starts
|
||||
|
||||
# Append to stop strings to halt for a tool call generation
|
||||
data.stop.extend(template_metadata.tool_starts)
|
||||
|
||||
|
||||
async def format_prompt_with_template(
|
||||
data: ChatCompletionRequest, tool_precursor: Optional[str] = None
|
||||
):
|
||||
"""
|
||||
Compile the prompt and get any additional stop strings from the template.
|
||||
Template stop strings can be overridden by sampler overrides if force is true.
|
||||
|
|
@ -187,18 +227,22 @@ def format_prompt_with_template(data: ChatCompletionRequest):
|
|||
"",
|
||||
)
|
||||
|
||||
if "tool_calls" in message:
|
||||
message["tool_calls_json"] = json.dumps(message["tool_calls"], indent=2)
|
||||
|
||||
# Overwrite any protected vars with their values
|
||||
data.template_vars.update(
|
||||
{
|
||||
"messages": data.messages,
|
||||
"add_generation_prompt": data.add_generation_prompt,
|
||||
"tools_json": json.dumps(data.model_dump()["tools"], indent=2),
|
||||
"functions_json": json.dumps(data.functions, indent=2),
|
||||
"tool_precursor": tool_precursor,
|
||||
**special_tokens_dict,
|
||||
}
|
||||
)
|
||||
|
||||
prompt, template_stop_strings = model.container.prompt_template.render(
|
||||
data.template_vars
|
||||
)
|
||||
prompt = await model.container.prompt_template.render(data.template_vars)
|
||||
|
||||
# Append response prefix if present
|
||||
if data.response_prefix:
|
||||
|
|
@ -216,11 +260,8 @@ def format_prompt_with_template(data: ChatCompletionRequest):
|
|||
if bos_token and prompt.startswith(bos_token):
|
||||
prompt = prompt.removeprefix(bos_token)
|
||||
|
||||
# Append template stop strings
|
||||
if isinstance(data.stop, str):
|
||||
data.stop = [data.stop] + template_stop_strings
|
||||
else:
|
||||
data.stop += template_stop_strings
|
||||
# Add template metadata
|
||||
await _append_template_metadata(data)
|
||||
|
||||
return prompt
|
||||
|
||||
|
|
@ -271,6 +312,9 @@ async def stream_generate_chat_completion(
|
|||
|
||||
gen_tasks.append(gen_task)
|
||||
|
||||
# We need to keep track of the text generated so we can resume the tool calls
|
||||
current_generation_text = ""
|
||||
|
||||
# Consumer loop
|
||||
while True:
|
||||
if disconnect_task.done():
|
||||
|
|
@ -280,6 +324,19 @@ async def stream_generate_chat_completion(
|
|||
)
|
||||
|
||||
generation = await gen_queue.get()
|
||||
# Only append the text if we need it for tool calls later
|
||||
if data.tool_call_start and "text" in generation:
|
||||
current_generation_text += generation["text"]
|
||||
|
||||
# check if we are running a tool model, and that we are at stop
|
||||
if data.tool_call_start and "stop_str" in generation:
|
||||
generations = await generate_tool_calls(
|
||||
data,
|
||||
[generation],
|
||||
request,
|
||||
current_generations=current_generation_text,
|
||||
)
|
||||
generation = generations[0] # We only have one generation in this case
|
||||
|
||||
# Stream collector will push an exception to the queue if it fails
|
||||
if isinstance(generation, Exception):
|
||||
|
|
@ -344,6 +401,11 @@ async def generate_chat_completion(
|
|||
)
|
||||
|
||||
generations = await asyncio.gather(*gen_tasks)
|
||||
|
||||
# Don't waste time if we aren't running a tool model
|
||||
if data.tool_call_start:
|
||||
generations = await generate_tool_calls(data, generations, request)
|
||||
|
||||
response = _create_response(request.state.id, generations, model_path.name)
|
||||
|
||||
logger.info(f"Finished chat completion request {request.state.id}")
|
||||
|
|
@ -358,3 +420,56 @@ async def generate_chat_completion(
|
|||
|
||||
# Server error if there's a generation exception
|
||||
raise HTTPException(503, error_message) from exc
|
||||
|
||||
|
||||
async def generate_tool_calls(
|
||||
data: ChatCompletionRequest,
|
||||
generations: List[str],
|
||||
request: Request,
|
||||
current_generations: str = None,
|
||||
):
|
||||
gen_tasks: List[asyncio.Task] = []
|
||||
tool_idx: List[int] = []
|
||||
|
||||
# Copy to make sure the parent JSON schema doesn't get modified
|
||||
# FIXME: May not be necessary depending on how the codebase evolves
|
||||
tool_data = deepcopy(data)
|
||||
tool_data.json_schema = tool_data.tool_call_schema
|
||||
gen_params = tool_data.to_gen_params()
|
||||
|
||||
for idx, gen in enumerate(generations):
|
||||
if gen["stop_str"] in tool_data.tool_call_start:
|
||||
if "text" in gen:
|
||||
# Non-streaming: all generations will have the text they generated
|
||||
pre_tool_prompt = await format_prompt_with_template(data, gen["text"])
|
||||
elif current_generations is not None:
|
||||
# Streaming: we won't have text in the generation,
|
||||
# we'll have to use the current_generations
|
||||
pre_tool_prompt = await format_prompt_with_template(
|
||||
data, current_generations
|
||||
)
|
||||
|
||||
gen_tasks.append(
|
||||
asyncio.create_task(
|
||||
model.container.generate(
|
||||
pre_tool_prompt, request.state.id, **gen_params
|
||||
)
|
||||
)
|
||||
)
|
||||
tool_idx.append(idx)
|
||||
|
||||
tool_calls = await asyncio.gather(*gen_tasks)
|
||||
for outer_idx in range(0, len(tool_idx)):
|
||||
gen_idx = tool_idx[outer_idx]
|
||||
generations[gen_idx]["tool_calls"] = tool_calls[outer_idx]["text"]
|
||||
|
||||
return generations
|
||||
|
||||
|
||||
def postprocess_tool_call(call_str: str) -> List[ToolCall]:
|
||||
tool_calls = json.loads(call_str)
|
||||
for tool_call in tool_calls:
|
||||
tool_call["function"]["arguments"] = json.dumps(
|
||||
tool_call["function"]["arguments"]
|
||||
)
|
||||
return [ToolCall(**tool_call) for tool_call in tool_calls]
|
||||
|
|
|
|||
|
|
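A hedged usage sketch of postprocess_tool_call above: the constrained generation returns a JSON array whose arguments are objects, while the OpenAI-style response expects arguments as a JSON-encoded string, so each one is re-dumped. The tool name and arguments below are made up, and the attribute access assumes ToolCall exposes a nested function model, as the loop above suggests.

raw_call_str = (
    '[{"id": "tool_id_1342", "type": "function", '
    '"function": {"name": "get_weather", "arguments": {"city": "Tokyo"}}}]'
)

calls = postprocess_tool_call(raw_call_str)
print(calls[0].function.arguments)  # '{"city": "Tokyo"}' -- now a string, not a dict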
@@ -44,6 +44,13 @@ from endpoints.core.utils.model import (
router = APIRouter()


# Healthcheck endpoint
@router.get("/health")
async def healthcheck():
    """Get the current service health status"""
    return {"status": "healthy"}


# Model list endpoint
@router.get("/v1/models", dependencies=[Depends(check_api_key)])
@router.get("/v1/model/list", dependencies=[Depends(check_api_key)])
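A quick way to exercise the new healthcheck endpoint from Python, assuming a local instance on port 5000 (adjust to your config) and the requests package; no API key is needed since the route has no auth dependency.

import requests

resp = requests.get("http://127.0.0.1:5000/health")
print(resp.status_code, resp.json())  # expected: 200 {'status': 'healthy'}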
@@ -2,7 +2,7 @@

from pydantic import BaseModel, Field, ConfigDict
from time import time
from typing import List, Optional
from typing import List, Literal, Optional, Union

from common.gen_logging import GenLogPreferences
from common.model import get_config_default
@@ -53,19 +53,19 @@ class DraftModelLoadRequest(BaseModel):
    # Config arguments
    draft_rope_scale: Optional[float] = Field(
        default_factory=lambda: get_config_default(
            "draft_rope_scale", 1.0, model_type="draft"
            "draft_rope_scale", model_type="draft"
        )
    )
    draft_rope_alpha: Optional[float] = Field(
        description="Automatically calculated if not present",
    draft_rope_alpha: Optional[Union[float, Literal["auto"]]] = Field(
        description='Automatically calculated if set to "auto"',
        default_factory=lambda: get_config_default(
            "draft_rope_alpha", None, model_type="draft"
            "draft_rope_alpha", model_type="draft"
        ),
        examples=[1.0],
    )
    draft_cache_mode: Optional[str] = Field(
        default_factory=lambda: get_config_default(
            "draft_cache_mode", "FP16", model_type="draft"
            "draft_cache_mode", model_type="draft"
        )
    )
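A minimal sketch of the Union[float, Literal["auto"]] pattern introduced for the rope alpha fields: either a float or the literal string "auto" validates, anything else is rejected. RopeSettings is an illustrative model, not one from the codebase.

from typing import Literal, Optional, Union
from pydantic import BaseModel, ValidationError

class RopeSettings(BaseModel):
    rope_alpha: Optional[Union[float, Literal["auto"]]] = None

print(RopeSettings(rope_alpha=1.75).rope_alpha)    # 1.75
print(RopeSettings(rope_alpha="auto").rope_alpha)  # auto

try:
    RopeSettings(rope_alpha="fast")
except ValidationError:
    print("rejected: not a float or 'auto'")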
@@ -96,14 +96,17 @@ class ModelLoadRequest(BaseModel):
        default_factory=lambda: get_config_default("cache_size"),
        examples=[4096],
    )
    tensor_parallel: Optional[bool] = Field(
        default_factory=lambda: get_config_default("tensor_parallel")
    )
    gpu_split_auto: Optional[bool] = Field(
        default_factory=lambda: get_config_default("gpu_split_auto", True)
        default_factory=lambda: get_config_default("gpu_split_auto")
    )
    autosplit_reserve: Optional[List[float]] = Field(
        default_factory=lambda: get_config_default("autosplit_reserve", [96])
        default_factory=lambda: get_config_default("autosplit_reserve")
    )
    gpu_split: Optional[List[float]] = Field(
        default_factory=lambda: get_config_default("gpu_split", []),
        default_factory=lambda: get_config_default("gpu_split"),
        examples=[[24.0, 20.0]],
    )
    rope_scale: Optional[float] = Field(
@@ -111,16 +114,16 @@ class ModelLoadRequest(BaseModel):
        default_factory=lambda: get_config_default("rope_scale"),
        examples=[1.0],
    )
    rope_alpha: Optional[float] = Field(
        description="Automatically calculated if not present",
    rope_alpha: Optional[Union[float, Literal["auto"]]] = Field(
        description='Automatically calculated if set to "auto"',
        default_factory=lambda: get_config_default("rope_alpha"),
        examples=[1.0],
    )
    cache_mode: Optional[str] = Field(
        default_factory=lambda: get_config_default("cache_mode", "FP16")
        default_factory=lambda: get_config_default("cache_mode")
    )
    chunk_size: Optional[int] = Field(
        default_factory=lambda: get_config_default("chunk_size", 2048)
        default_factory=lambda: get_config_default("chunk_size")
    )
    prompt_template: Optional[str] = Field(
        default_factory=lambda: get_config_default("prompt_template")
@@ -129,7 +132,7 @@ class ModelLoadRequest(BaseModel):
        default_factory=lambda: get_config_default("num_experts_per_token")
    )
    fasttensors: Optional[bool] = Field(
        default_factory=lambda: get_config_default("fasttensors", False)
        default_factory=lambda: get_config_default("fasttensors")
    )

    # Non-config arguments
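Putting the schema changes together, a hedged example of a load-request payload using only fields visible in this diff; values are illustrative, and the model-name field plus endpoint details are omitted here.

load_request = {
    "cache_size": 4096,
    "gpu_split_auto": True,
    "rope_scale": 1.0,
    "rope_alpha": "auto",  # now accepted thanks to Union[float, Literal["auto"]]
    "cache_mode": "FP16",
    "chunk_size": 2048,
}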
@@ -43,12 +43,16 @@ async def get_current_model_list(model_type: str = "model"):
    model_path = None

    # Make sure the model container exists
    if model_type == "model" or model_type == "draft":
        if model.container:
            model_path = model.container.get_model_path(model_type == "draft")
    elif model_type == "embedding":
        if model.embeddings_container:
            model_path = model.embeddings_container.model_dir
    match model_type:
        case "model":
            if model.container:
                model_path = model.container.model_dir
        case "draft":
            if model.container:
                model_path = model.container.draft_model_dir
        case "embedding":
            if model.embeddings_container:
                model_path = model.embeddings_container.model_dir

    if model_path:
        current_models.append(ModelCard(id=model_path.name))
@@ -94,8 +98,10 @@ async def stream_model_load(
):
    """Request generation wrapper for the loading process."""

    # Get trimmed load data
    load_data = data.model_dump(exclude_none=True)

    # Set the draft model path if it exists
    load_data = data.model_dump()
    if draft_model_path:
        load_data["draft"]["draft_model_dir"] = draft_model_path
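The loading wrapper above now calls data.model_dump() instead of data.model_dump(exclude_none=True). A small pydantic v2 illustration of the difference, using a stand-in model: the plain dump keeps None-valued keys, so downstream code can rely on every field being present.

from typing import Optional
from pydantic import BaseModel

class DummyLoadRequest(BaseModel):
    cache_mode: Optional[str] = None
    chunk_size: Optional[int] = 2048

req = DummyLoadRequest()
print(req.model_dump(exclude_none=True))  # {'chunk_size': 2048}
print(req.model_dump())                   # {'cache_mode': None, 'chunk_size': 2048}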
@@ -77,8 +77,6 @@ async def start_api(host: str, port: int):

    # TODO: Move OAI API to a separate folder
    logger.info(f"Developer documentation: http://{host}:{port}/redoc")
    # logger.info(f"Completions: http://{host}:{port}/v1/completions")
    # logger.info(f"Chat completions: http://{host}:{port}/v1/chat/completions")

    # Setup app
    app = setup_app(host, port)
@@ -68,12 +68,12 @@ cu121 = [
    "torch @ https://download.pytorch.org/whl/cu121/torch-2.3.1%2Bcu121-cp310-cp310-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.10'",

    # Exl2
    "exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.1.8/exllamav2-0.1.8+cu121.torch2.3.1-cp312-cp312-win_amd64.whl ; platform_system == 'Windows' and python_version == '3.12'",
    "exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.1.8/exllamav2-0.1.8+cu121.torch2.3.1-cp311-cp311-win_amd64.whl ; platform_system == 'Windows' and python_version == '3.11'",
    "exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.1.8/exllamav2-0.1.8+cu121.torch2.3.1-cp310-cp310-win_amd64.whl ; platform_system == 'Windows' and python_version == '3.10'",
    "exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.1.8/exllamav2-0.1.8+cu121.torch2.3.1-cp312-cp312-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.12'",
    "exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.1.8/exllamav2-0.1.8+cu121.torch2.3.1-cp311-cp311-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.11'",
    "exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.1.8/exllamav2-0.1.8+cu121.torch2.3.1-cp310-cp310-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.10'",
    "exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.2.0/exllamav2-0.2.0+cu121.torch2.3.1-cp312-cp312-win_amd64.whl ; platform_system == 'Windows' and python_version == '3.12'",
    "exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.2.0/exllamav2-0.2.0+cu121.torch2.3.1-cp311-cp311-win_amd64.whl ; platform_system == 'Windows' and python_version == '3.11'",
    "exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.2.0/exllamav2-0.2.0+cu121.torch2.3.1-cp310-cp310-win_amd64.whl ; platform_system == 'Windows' and python_version == '3.10'",
    "exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.2.0/exllamav2-0.2.0+cu121.torch2.3.1-cp312-cp312-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.12'",
    "exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.2.0/exllamav2-0.2.0+cu121.torch2.3.1-cp311-cp311-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.11'",
    "exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.2.0/exllamav2-0.2.0+cu121.torch2.3.1-cp310-cp310-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.10'",

    # Windows FA2 from https://github.com/bdashore3/flash-attention/releases
    "flash_attn @ https://github.com/bdashore3/flash-attention/releases/download/v2.6.3/flash_attn-2.6.3+cu123torch2.3.1cxx11abiFALSE-cp312-cp312-win_amd64.whl ; platform_system == 'Windows' and python_version == '3.12'",
@@ -95,12 +95,12 @@ cu118 = [
    "torch @ https://download.pytorch.org/whl/cu118/torch-2.3.1%2Bcu118-cp310-cp310-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.10'",

    # Exl2
    "exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.1.8/exllamav2-0.1.8+cu118.torch2.3.1-cp312-cp312-win_amd64.whl ; platform_system == 'Windows' and python_version == '3.12'",
    "exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.1.8/exllamav2-0.1.8+cu118.torch2.3.1-cp311-cp311-win_amd64.whl ; platform_system == 'Windows' and python_version == '3.11'",
    "exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.1.8/exllamav2-0.1.8+cu118.torch2.3.1-cp310-cp310-win_amd64.whl ; platform_system == 'Windows' and python_version == '3.10'",
    "exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.1.8/exllamav2-0.1.8+cu118.torch2.3.1-cp312-cp312-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.12'",
    "exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.1.8/exllamav2-0.1.8+cu118.torch2.3.1-cp311-cp311-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.11'",
    "exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.1.8/exllamav2-0.1.8+cu118.torch2.3.1-cp310-cp310-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.10'",
    "exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.2.0/exllamav2-0.2.0+cu118.torch2.3.1-cp312-cp312-win_amd64.whl ; platform_system == 'Windows' and python_version == '3.12'",
    "exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.2.0/exllamav2-0.2.0+cu118.torch2.3.1-cp311-cp311-win_amd64.whl ; platform_system == 'Windows' and python_version == '3.11'",
    "exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.2.0/exllamav2-0.2.0+cu118.torch2.3.1-cp310-cp310-win_amd64.whl ; platform_system == 'Windows' and python_version == '3.10'",
    "exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.2.0/exllamav2-0.2.0+cu118.torch2.3.1-cp312-cp312-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.12'",
    "exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.2.0/exllamav2-0.2.0+cu118.torch2.3.1-cp311-cp311-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.11'",
    "exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.2.0/exllamav2-0.2.0+cu118.torch2.3.1-cp310-cp310-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.10'",

    # Linux FA2 from https://github.com/Dao-AILab/flash-attention/releases
    "flash_attn @ https://github.com/Dao-AILab/flash-attention/releases/download/v2.6.3/flash_attn-2.6.3+cu118torch2.3cxx11abiFALSE-cp312-cp312-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.12'",
@@ -119,9 +119,9 @@ amd = [
    "torch @ https://download.pytorch.org/whl/rocm6.0/torch-2.3.1%2Brocm6.0-cp310-cp310-linux_x86_64.whl ; python_version == '3.10'",

    # Exl2
    "exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.1.8/exllamav2-0.1.8+rocm6.0.torch2.3.1-cp312-cp312-linux_x86_64.whl ; python_version == '3.12'",
    "exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.1.8/exllamav2-0.1.8+rocm6.0.torch2.3.1-cp311-cp311-linux_x86_64.whl ; python_version == '3.11'",
    "exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.1.8/exllamav2-0.1.8+rocm6.0.torch2.3.1-cp310-cp310-linux_x86_64.whl ; python_version == '3.10'",
    "exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.2.0/exllamav2-0.2.0+rocm6.0.torch2.3.1-cp312-cp312-linux_x86_64.whl ; python_version == '3.12'",
    "exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.2.0/exllamav2-0.2.0+rocm6.0.torch2.3.1-cp311-cp311-linux_x86_64.whl ; python_version == '3.11'",
    "exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.2.0/exllamav2-0.2.0+rocm6.0.torch2.3.1-cp310-cp310-linux_x86_64.whl ; python_version == '3.10'",
]

# MARK: Ruff options
@@ -126,6 +126,10 @@ banned_tokens:
    override: []
    force: false
    additive: false
allowed_tokens:
    override: []
    force: false
    additive: false

# MARK: CFG scale
cfg_scale:
84
templates/chatml_with_headers_tool_calling.jinja
Normal file
@@ -0,0 +1,84 @@
{# Metadata #}
{% set stop_strings = ["<|im_start|>", "<|im_end|>"] %}
{% set message_roles = ['system', 'user', 'assistant', 'tool'] %}
{% set tool_start = "<|tool_start|>" %}
{% set tool_end = "<|tool_end|>" %}
{%- set start_header = "<|start_header_id|>" -%}
{%- set end_header = "<|end_header_id|>\n" -%}

{%- set example_tool_call -%}[
{
"id": "tool_id_1342",
"function": {
"arguments": "arg_name": 3,
"name": "tool_name"
},
"type": "function"
},
{
"id": "example_id_13f42",
"function": {
"arguments": "example_arg": 1.0, "another_example_arg": true,
"name": "another_tool_name"
},
"type": "function"
}
]
{%- endset -%}

{%- set inital_system_prompt -%}You are an assistant that has access to the following set of tools, to call a tool:
1. Prefix calls with '{{ tool_start }}' and end calls with '{{ tool_end }}'
2. Ensure you use the correct type for arguments. For example, if the argument is a string, ensure it is enclosed in quotes, otherwise, it should not be.
3. Generate all calls using the following json tool call format. Here is a multi tool call example:

{{ tool_start }}{{ example_tool_call }}{{ tool_end }}

Here are the tools available for you to call:
{{ tools_json }}
{%- endset -%}

{%- set tool_reminder -%}Available Tools:
{{ tools_json }}

Tool Call Format Example:
{{ tool_start }}{{ example_tool_call }}

Prefix & Suffix: Begin tool calls with {{ tool_start }} and end with {{ tool_end }}.
Argument Types: Use correct data types for arguments (e.g., strings in quotes, numbers without).
{%- endset -%}

{# Template #}

{%- for message in messages -%}
{%- set role = message['role'] | lower -%}
{%- if role not in message_roles -%}
{{ raise_exception('Invalid role ' + message['role'] + '. Only ' + message_roles | join(', ') + ' are supported.') }}
{%- endif -%}

{%- set content = message['content'] | default('', true) | trim -%}
{%- if loop.first -%}
{{ bos_token }}{{ start_header }}{{ role }}{{ end_header }}
{{ inital_system_prompt }}

{{ content }}{{ eos_token }}
{%- endif -%}

{%- if not loop.first -%}
{{ start_header }}{{ role }}{{ end_header }}
{{ content }}
{%- if 'tool_calls_json' in message and message['tool_calls_json'] -%}
{{ tool_start }}{{ message['tool_calls_json'] }}{{ tool_end }}
{%- endif -%}
{{ eos_token }}

{%- endif -%}
{%- endfor -%}

{%- if tool_precursor -%}
{{ start_header }}system{{ end_header }}
{{ tool_reminder }}{{ eos_token }}
{{ start_header }}assistant{{ end_header }}
{{ tool_precursor }}{{ tool_start }}
{%- else -%}
{{ start_header }}assistant{{ end_header }}
{%- endif -%}
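A hedged sketch of rendering the new template outside the server with plain Jinja2. The context variables (messages, bos_token, eos_token, tools_json, tool_precursor, raise_exception) and their shapes are assumptions inferred from the template body, not the exact values TabbyAPI injects.

from jinja2 import Environment, FileSystemLoader

def raise_exception(message: str):
    # Stand-in for the callback the server exposes to the template
    raise ValueError(message)

env = Environment(loader=FileSystemLoader("templates"))
template = env.get_template("chatml_with_headers_tool_calling.jinja")

prompt = template.render(
    messages=[
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "What's the weather in Tokyo?"},
    ],
    bos_token="<s>",
    eos_token="</s>",
    tools_json='[{"name": "get_weather", "parameters": {"city": "string"}}]',
    tool_precursor=None,
    raise_exception=raise_exception,
)
print(prompt)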