Merge remote-tracking branch 'upstream/main' into HEAD

TerminalMan 2024-09-11 15:57:18 +01:00
commit e8fcecd56a
28 changed files with 386 additions and 171 deletions

.dockerignore Normal file (20 lines changed)

@@ -0,0 +1,20 @@
.ruff_cache/
**/__pycache__/
venv
.git
.gitignore
.github
# Ignore specific application files
models/
loras/
config.yml
config_sample.yml
api_tokens.yml
api_tokens_sample.yml
*.bat
*.sh
update_scripts
readme.md
colab
start.py

.gitignore vendored (6 lines changed)

@@ -192,7 +192,11 @@ templates/*
!templates/place_your_templates_here.txt
!templates/alpaca.jinja
!templates/chatml.jinja
!templates/chatml_with_headers_tool_calling.jinja
# Tool calling templates folder
templates/tool_calls/*
!templates/tool_calls
!templates/tool_calls/chatml_with_headers.jinja
# Sampler overrides folder
sampler_overrides/*


@@ -1,5 +1,6 @@
"""The model container class for ExLlamaV2 models."""
import aiofiles
import asyncio
import gc
import math
@@ -7,6 +8,7 @@ import pathlib
import traceback
import torch
import uuid
from copy import deepcopy
from exllamav2 import (
ExLlamaV2,
ExLlamaV2Config,
@@ -105,13 +107,17 @@ class ExllamaV2Container:
load_lock: asyncio.Lock = asyncio.Lock()
load_condition: asyncio.Condition = asyncio.Condition()
def __init__(self, model_directory: pathlib.Path, quiet=False, **kwargs):
@classmethod
async def create(cls, model_directory: pathlib.Path, quiet=False, **kwargs):
"""
Primary initializer for model container.
Primary asynchronous initializer for model container.
Kwargs are located in config_sample.yml
"""
# Create a new instance as a "fake self"
self = cls()
self.quiet = quiet
# Initialize config
@@ -154,13 +160,13 @@ class ExllamaV2Container:
self.draft_config.prepare()
# Create the hf_config
self.hf_config = HuggingFaceConfig.from_file(model_directory)
self.hf_config = await HuggingFaceConfig.from_file(model_directory)
# Load generation config overrides
generation_config_path = model_directory / "generation_config.json"
if generation_config_path.exists():
try:
self.generation_config = GenerationConfig.from_file(
self.generation_config = await GenerationConfig.from_file(
generation_config_path.parent
)
except Exception:
@@ -170,7 +176,7 @@
)
# Apply a model's config overrides while respecting user settings
kwargs = self.set_model_overrides(**kwargs)
kwargs = await self.set_model_overrides(**kwargs)
# MARK: User configuration
@@ -319,7 +325,7 @@
self.cache_size = self.config.max_seq_len
# Try to set prompt template
self.prompt_template = self.find_prompt_template(
self.prompt_template = await self.find_prompt_template(
kwargs.get("prompt_template"), model_directory
)
@@ -372,7 +378,10 @@
self.draft_config.max_input_len = chunk_size
self.draft_config.max_attention_size = chunk_size**2
def set_model_overrides(self, **kwargs):
# Return the created instance
return self
async def set_model_overrides(self, **kwargs):
"""Sets overrides from a model folder's config yaml."""
override_config_path = self.model_dir / "tabby_config.yml"
@@ -380,8 +389,11 @@
if not override_config_path.exists():
return kwargs
with open(override_config_path, "r", encoding="utf8") as override_config_file:
override_args = unwrap(yaml.safe_load(override_config_file), {})
async with aiofiles.open(
override_config_path, "r", encoding="utf8"
) as override_config_file:
contents = await override_config_file.read()
override_args = unwrap(yaml.safe_load(contents), {})
# Merge draft overrides beforehand
draft_override_args = unwrap(override_args.get("draft"), {})
@@ -392,7 +404,7 @@
merged_kwargs = {**override_args, **kwargs}
return merged_kwargs
def find_prompt_template(self, prompt_template_name, model_directory):
async def find_prompt_template(self, prompt_template_name, model_directory):
"""Tries to find a prompt template using various methods."""
logger.info("Attempting to load a prompt template if present.")
@@ -400,26 +412,37 @@
find_template_functions = [
lambda: PromptTemplate.from_model_json(
pathlib.Path(self.config.model_dir) / "tokenizer_config.json",
"chat_template",
key="chat_template",
),
lambda: PromptTemplate.from_file(find_template_from_model(model_directory)),
]
# Find the template in the model directory if it exists
model_dir_template_path = (
pathlib.Path(self.config.model_dir) / "tabby_template.jinja"
)
if model_dir_template_path.exists():
find_template_functions[:0] = [
lambda: PromptTemplate.from_file(model_dir_template_path)
]
# Add lookup from prompt template name if provided
if prompt_template_name:
find_template_functions[:0] = [
lambda: PromptTemplate.from_file(prompt_template_name),
lambda: PromptTemplate.from_file(
pathlib.Path("templates") / prompt_template_name
),
lambda: PromptTemplate.from_model_json(
pathlib.Path(self.config.model_dir) / "tokenizer_config.json",
"chat_template",
prompt_template_name,
key="chat_template",
name=prompt_template_name,
),
]
# Continue on exception since functions are tried as they fail
for template_func in find_template_functions:
try:
prompt_template = template_func()
prompt_template = await template_func()
if prompt_template is not None:
return prompt_template
except TemplateLoadError as e:
@@ -944,6 +967,14 @@ class ExllamaV2Container:
Meant for dev wheels!
"""
if unwrap(kwargs.get("dry_allowed_length"), 0) > 0 and not hasattr(
ExLlamaV2Sampler.Settings, "dry_multiplier"
):
logger.warning(
"DRY sampling is not supported by the currently "
"installed ExLlamaV2 version."
)
return kwargs
async def generate_gen(
@@ -1035,6 +1066,7 @@
"Please use an ampere (30 series) or higher GPU for CFG support."
)
# Penalties
gen_settings.token_repetition_penalty = unwrap(
kwargs.get("repetition_penalty"), 1.0
)
@@ -1070,6 +1102,32 @@
kwargs.get("repetition_decay"), fallback_decay, 0
)
# DRY options
dry_multiplier = unwrap(kwargs.get("dry_multiplier"), 0.0)
# < 0 = disabled
if dry_multiplier > 0:
gen_settings.dry_multiplier = dry_multiplier
# TODO: Maybe set the "sane" defaults instead?
gen_settings.dry_allowed_length = unwrap(
kwargs.get("dry_allowed_length"), 0
)
gen_settings.dry_base = unwrap(kwargs.get("dry_base"), 0.0)
# Exl2 has dry_range as 0 for unlimited unlike -1 for penalty_range
# Use max_seq_len as the fallback to stay consistent
gen_settings.dry_range = unwrap(
kwargs.get("dry_range"), self.config.max_seq_len
)
# Tokenize sequence breakers
dry_sequence_breakers_json = kwargs.get("dry_sequence_breakers")
if dry_sequence_breakers_json:
gen_settings.dry_sequence_breakers = {
self.encode_tokens(s)[-1] for s in dry_sequence_breakers_json
}
# Initialize grammar handler
grammar_handler = ExLlamaV2Grammar()
@@ -1130,7 +1188,8 @@
)
# Store the gen settings for logging purposes
gen_settings_log_dict = vars(gen_settings)
# Deepcopy to save a snapshot of vars
gen_settings_log_dict = deepcopy(vars(gen_settings))
# Set banned tokens
banned_tokens = unwrap(kwargs.get("banned_tokens"), [])
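
The hunk above also shows why the container now exposes an async create() classmethod instead of doing everything in __init__: a constructor cannot await coroutines such as HuggingFaceConfig.from_file or the async prompt template lookup. The following is a minimal sketch of that async-factory pattern with illustrative names, not TabbyAPI's actual class:

import asyncio


class Container:
    """Minimal async-factory sketch; names here are illustrative only."""

    @classmethod
    async def create(cls, model_dir: str, **kwargs):
        # Build a bare instance first (the "fake self"), since __init__ cannot await
        self = cls()
        self.model_dir = model_dir
        # Any awaitable setup runs here instead of inside __init__
        self.config = await self._load_config(model_dir)
        return self

    async def _load_config(self, model_dir: str) -> dict:
        await asyncio.sleep(0)  # stand-in for async file I/O (e.g. aiofiles)
        return {"model_dir": model_dir}


asyncio.run(Container.create("models/example"))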


@@ -8,7 +8,7 @@ from loguru import logger
def check_exllama_version():
"""Verifies the exllama version"""
required_version = version.parse("0.1.9")
required_version = version.parse("0.2.1")
current_version = version.parse(package_version("exllamav2").split("+")[0])
unsupported_message = (


@@ -7,13 +7,13 @@ from typing import List, Optional
from common.utils import unwrap
# Conditionally import infinity to sidestep its logger
# TODO: Make this prettier
has_infinity_emb: bool = False
try:
from infinity_emb import EngineArgs, AsyncEmbeddingEngine
has_infinity_emb = True
except ImportError:
has_infinity_emb = False
pass
class InfinityContainer:


@@ -3,6 +3,7 @@ This method of authorization is pretty insecure, but since TabbyAPI is a local
application, it should be fine.
"""
import aiofiles
import secrets
import yaml
from fastapi import Header, HTTPException, Request
@@ -40,7 +41,7 @@ AUTH_KEYS: Optional[AuthKeys] = None
DISABLE_AUTH: bool = False
def load_auth_keys(disable_from_config: bool):
async def load_auth_keys(disable_from_config: bool):
"""Load the authentication keys from api_tokens.yml. If the file does not
exist, generate new keys and save them to api_tokens.yml."""
global AUTH_KEYS
@@ -57,8 +58,9 @@ def load_auth_keys(disable_from_config: bool):
return
try:
with open("api_tokens.yml", "r", encoding="utf8") as auth_file:
auth_keys_dict = yaml.safe_load(auth_file)
async with aiofiles.open("api_tokens.yml", "r", encoding="utf8") as auth_file:
contents = await auth_file.read()
auth_keys_dict = yaml.safe_load(contents)
AUTH_KEYS = AuthKeys.model_validate(auth_keys_dict)
except FileNotFoundError:
new_auth_keys = AuthKeys(
@@ -66,8 +68,11 @@ def load_auth_keys(disable_from_config: bool):
)
AUTH_KEYS = new_auth_keys
with open("api_tokens.yml", "w", encoding="utf8") as auth_file:
yaml.safe_dump(AUTH_KEYS.model_dump(), auth_file, default_flow_style=False)
async with aiofiles.open("api_tokens.yml", "w", encoding="utf8") as auth_file:
new_auth_yaml = yaml.safe_dump(
AUTH_KEYS.model_dump(), default_flow_style=False
)
await auth_file.write(new_auth_yaml)
logger.info(
f"Your API key is: {AUTH_KEYS.api_key}\n"


@@ -13,7 +13,6 @@ from typing import Optional
from common.logger import get_loading_progress_bar
from common.networking import handle_request_error
from common.tabby_config import config
from common.utils import unwrap
from endpoints.utils import do_export_openapi
if not do_export_openapi:
@@ -67,7 +66,11 @@ async def load_model_gen(model_path: pathlib.Path, **kwargs):
logger.info("Unloading existing model.")
await unload_model()
container = ExllamaV2Container(model_path.resolve(), False, **kwargs)
# Merge with config defaults
kwargs = {**config.model_defaults, **kwargs}
# Create a new container
container = await ExllamaV2Container.create(model_path.resolve(), False, **kwargs)
model_type = "draft" if container.draft_config else "model"
load_status = container.load_gen(load_progress, **kwargs)
@@ -149,25 +152,6 @@ async def unload_embedding_model():
embeddings_container = None
# FIXME: Maybe make this a one-time function instead of a dynamic default
def get_config_default(key: str, model_type: str = "model"):
"""Fetches a default value from model config if allowed by the user."""
default_keys = unwrap(config.model.use_as_default, [])
# Add extra keys to defaults
default_keys.append("embeddings_device")
if key in default_keys:
# Is this a draft model load parameter?
if model_type == "draft":
return config.draft_model.get(key)
elif model_type == "embedding":
return config.embeddings.get(key)
else:
return config.model.get(key)
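
The get_config_default helper removed above is superseded by the config.model_defaults dict that load_model_gen now merges into the load kwargs. Because the merge is written as {**config.model_defaults, **kwargs}, arguments supplied explicitly in a request still take precedence over config-derived defaults. A small illustration with made-up values:

# Illustrative values only; the real defaults come from the use_as_default list in config.yml
model_defaults = {"cache_mode": "Q4", "chunk_size": 2048}
request_kwargs = {"chunk_size": 4096}  # explicitly supplied by the caller

merged = {**model_defaults, **request_kwargs}
assert merged == {"cache_mode": "Q4", "chunk_size": 4096}  # the caller's value wins
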
async def check_model_container():
"""FastAPI depends that checks if a model isn't loaded or currently loading."""


@@ -1,5 +1,7 @@
"""Common functions for sampling parameters"""
import aiofiles
import json
import pathlib
import yaml
from copy import deepcopy
@@ -140,6 +142,28 @@ class BaseSamplerRequest(BaseModel):
default_factory=lambda: get_default_sampler_value("repetition_decay", 0)
)
dry_multiplier: Optional[float] = Field(
default_factory=lambda: get_default_sampler_value("dry_multiplier", 0.0)
)
dry_base: Optional[float] = Field(
default_factory=lambda: get_default_sampler_value("dry_base", 0.0)
)
dry_allowed_length: Optional[int] = Field(
default_factory=lambda: get_default_sampler_value("dry_allowed_length", 0)
)
dry_range: Optional[int] = Field(
default_factory=lambda: get_default_sampler_value("dry_range", 0),
alias=AliasChoices("dry_range", "dry_penalty_last_n"),
description=("Aliases: dry_penalty_last_n"),
)
dry_sequence_breakers: Optional[Union[str, List[str]]] = Field(
default_factory=lambda: get_default_sampler_value("dry_sequence_breakers", [])
)
mirostat_mode: Optional[int] = Field(
default_factory=lambda: get_default_sampler_value("mirostat_mode", 0)
)
@@ -305,6 +329,17 @@ class BaseSamplerRequest(BaseModel):
int(x) for x in self.allowed_tokens.split(",") if x.isdigit()
]
# Convert sequence breakers into an array of strings
# NOTE: This sampler sucks to parse.
if self.dry_sequence_breakers and isinstance(self.dry_sequence_breakers, str):
if not self.dry_sequence_breakers.startswith("["):
self.dry_sequence_breakers = f"[{self.dry_sequence_breakers}]"
try:
self.dry_sequence_breakers = json.loads(self.dry_sequence_breakers)
except Exception:
self.dry_sequence_breakers = []
gen_params = {
"max_tokens": self.max_tokens,
"min_tokens": self.min_tokens,
@@ -335,6 +370,11 @@ class BaseSamplerRequest(BaseModel):
"presence_penalty": self.presence_penalty,
"repetition_penalty": self.repetition_penalty,
"penalty_range": self.penalty_range,
"dry_multiplier": self.dry_multiplier,
"dry_base": self.dry_base,
"dry_allowed_length": self.dry_allowed_length,
"dry_sequence_breakers": self.dry_sequence_breakers,
"dry_range": self.dry_range,
"repetition_decay": self.repetition_decay,
"mirostat": self.mirostat_mode == 2,
"mirostat_tau": self.mirostat_tau,
@@ -368,14 +408,15 @@ def overrides_from_dict(new_overrides: dict):
raise TypeError("New sampler overrides must be a dict!")
def overrides_from_file(preset_name: str):
async def overrides_from_file(preset_name: str):
"""Fetches an override preset from a file"""
preset_path = pathlib.Path(f"sampler_overrides/{preset_name}.yml")
if preset_path.exists():
overrides_container.selected_preset = preset_path.stem
with open(preset_path, "r", encoding="utf8") as raw_preset:
preset = yaml.safe_load(raw_preset)
async with aiofiles.open(preset_path, "r", encoding="utf8") as raw_preset:
contents = await raw_preset.read()
preset = yaml.safe_load(contents)
overrides_from_dict(preset)
logger.info("Applied sampler overrides from file.")
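
The new dry_* request fields above accept dry_sequence_breakers either as a JSON array or as a bare comma-separated string of quoted strings; the validator wraps such a string in brackets before handing it to json.loads. A standalone sketch of that normalization (field names match the diff, sample values are arbitrary):

import json

def parse_sequence_breakers(raw):
    """Normalize dry_sequence_breakers into a list of strings (mirrors the logic above)."""
    if isinstance(raw, str):
        # Accept a bare '"\\n", ":"' style string by wrapping it in brackets first
        if not raw.startswith("["):
            raw = f"[{raw}]"
        try:
            return json.loads(raw)
        except Exception:
            return []
    return raw or []

# Example sampler arguments a client might send (values are arbitrary, not recommendations)
sampler_args = {
    "dry_multiplier": 0.8,
    "dry_base": 1.75,
    "dry_allowed_length": 2,
    "dry_range": 0,  # 0 means unlimited range in exllamav2, per the comment in the model hunk
    "dry_sequence_breakers": '"\\n", ":", "\\"", "*"',
}
print(parse_sequence_breakers(sampler_args["dry_sequence_breakers"]))  # ['\n', ':', '"', '*']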


@@ -10,8 +10,13 @@ import common.config_models
class TabbyConfig(tabby_config_model):
def load_config(self, arguments: Optional[dict] = None):
"""load the global application config"""
# Persistent defaults
# TODO: make this pydantic?
model_defaults: dict = {}
def load(self, arguments: Optional[dict] = None):
"""Synchronously loads the global application config"""
# config is applied in order of items in the list
configs = [
@@ -28,6 +33,17 @@ class TabbyConfig(tabby_config_model):
setattr(self, field, model.parse_obj(value))
# Set model defaults dict once to prevent on-demand reconstruction
# TODO: clean this up a bit
for field in self.model.use_as_default:
if hasattr(self.model, field):
self.model_defaults[field] = getattr(config.model, field)
elif hasattr(self.draft_model, field):
self.model_defaults[field] = getattr(config.draft_model, field)
else:
# TODO: show an error
pass
def _from_file(self, config_path: pathlib.Path):
"""loads config from a given file path"""
@@ -53,7 +69,7 @@ class TabbyConfig(tabby_config_model):
config_override = unwrap(args.get("options", {}).get("config"))
if config_override:
logger.info("Config file override detected in args.")
config = self.from_file(pathlib.Path(config_override))
config = self._from_file(pathlib.Path(config_override))
return config # Return early if loading from file
for key in tabby_config_model.model_fields.keys():
@@ -85,5 +101,5 @@ class TabbyConfig(tabby_config_model):
return config
# Create an empty instance of the shared var to make sure nothing breaks
# Create an empty instance of the config class
config: TabbyConfig = TabbyConfig()


@@ -1,10 +1,12 @@
"""Small replication of AutoTokenizer's chat template system for efficiency"""
import aiofiles
import json
import pathlib
from importlib.metadata import version as package_version
from typing import List, Optional
from jinja2 import Template, TemplateError
from jinja2.ext import loopcontrols
from jinja2.sandbox import ImmutableSandboxedEnvironment
from loguru import logger
from packaging import version
@@ -32,7 +34,10 @@ class PromptTemplate:
raw_template: str
template: Template
environment: ImmutableSandboxedEnvironment = ImmutableSandboxedEnvironment(
trim_blocks=True, lstrip_blocks=True, enable_async=True
trim_blocks=True,
lstrip_blocks=True,
enable_async=True,
extensions=[loopcontrols],
)
metadata: Optional[TemplateMetadata] = None
@@ -106,32 +111,42 @@ class PromptTemplate:
self.template = self.compile(raw_template)
@classmethod
def from_file(self, prompt_template_name: str):
async def from_file(cls, template_path: pathlib.Path):
"""Get a template from a jinja file."""
template_path = pathlib.Path(f"templates/{prompt_template_name}.jinja")
# Add the jinja extension if it isn't provided
if template_path.suffix.endswith(".jinja"):
template_name = template_path.name.split(".jinja")[0]
else:
template_name = template_path.name
template_path = template_path.with_suffix(".jinja")
if template_path.exists():
with open(template_path, "r", encoding="utf8") as raw_template_stream:
return PromptTemplate(
name=prompt_template_name,
raw_template=raw_template_stream.read(),
async with aiofiles.open(
template_path, "r", encoding="utf8"
) as raw_template_stream:
contents = await raw_template_stream.read()
return cls(
name=template_name,
raw_template=contents,
)
else:
# Let the user know if the template file isn't found
raise TemplateLoadError(
f'Chat template "{prompt_template_name}" not found in files.'
f'Chat template "{template_name}" not found in files.'
)
@classmethod
def from_model_json(
self, json_path: pathlib.Path, key: str, name: Optional[str] = None
async def from_model_json(
cls, json_path: pathlib.Path, key: str, name: Optional[str] = None
):
"""Get a template from a JSON file. Requires a key and template name"""
if not json_path.exists():
raise TemplateLoadError(f'Model JSON path "{json_path}" not found.')
with open(json_path, "r", encoding="utf8") as config_file:
model_config = json.load(config_file)
async with aiofiles.open(json_path, "r", encoding="utf8") as config_file:
contents = await config_file.read()
model_config = json.loads(contents)
chat_template = model_config.get(key)
if not chat_template:
@@ -162,7 +177,7 @@ class PromptTemplate:
)
else:
# Can safely assume the chat template is the old style
return PromptTemplate(
return cls(
name="from_tokenizer_config",
raw_template=chat_template,
)
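
PromptTemplate.from_file now takes a pathlib.Path instead of a bare template name and adds the .jinja suffix itself when it is missing, so callers build the path themselves (for example pathlib.Path("templates") / name, as the core router change further down does). A short usage sketch; the common.templating import path is an assumption about this repo's layout rather than something shown in this diff:

import asyncio
import pathlib

from common.templating import PromptTemplate  # import path assumed, not confirmed by this diff


async def demo():
    # The ".jinja" suffix is optional; from_file appends it when absent
    template = await PromptTemplate.from_file(pathlib.Path("templates") / "chatml")
    print(template.name)  # "chatml"


asyncio.run(demo())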


@@ -1,3 +1,4 @@
import aiofiles
import json
import pathlib
from typing import List, Optional, Union
@@ -15,15 +16,16 @@ class GenerationConfig(BaseModel):
bad_words_ids: Optional[List[List[int]]] = None
@classmethod
def from_file(self, model_directory: pathlib.Path):
async def from_file(cls, model_directory: pathlib.Path):
"""Create an instance from a generation config file."""
generation_config_path = model_directory / "generation_config.json"
with open(
async with aiofiles.open(
generation_config_path, "r", encoding="utf8"
) as generation_config_json:
generation_config_dict = json.load(generation_config_json)
return self.model_validate(generation_config_dict)
contents = await generation_config_json.read()
generation_config_dict = json.loads(contents)
return cls.model_validate(generation_config_dict)
def eos_tokens(self):
"""Wrapper method to fetch EOS tokens."""
@@ -43,13 +45,16 @@ class HuggingFaceConfig(BaseModel):
badwordsids: Optional[str] = None
@classmethod
def from_file(self, model_directory: pathlib.Path):
async def from_file(cls, model_directory: pathlib.Path):
"""Create an instance from a generation config file."""
hf_config_path = model_directory / "config.json"
with open(hf_config_path, "r", encoding="utf8") as hf_config_json:
hf_config_dict = json.load(hf_config_json)
return self.model_validate(hf_config_dict)
async with aiofiles.open(
hf_config_path, "r", encoding="utf8"
) as hf_config_json:
contents = await hf_config_json.read()
hf_config_dict = json.loads(contents)
return cls.model_validate(hf_config_dict)
def get_badwordsids(self):
"""Wrapper method to fetch badwordsids."""


@@ -83,6 +83,9 @@ model:
# Enable this if the program is looking for a specific OAI model
#use_dummy_models: False
# Allow direct loading of models from a completion or chat completion request
inline_model_loading: False
# An initial model to load. Make sure the model is located in the model directory!
# A model can be loaded later via the API.
# REQUIRED: This must be filled out to load a model on startup!


@@ -1,6 +0,0 @@
models/
loras/
.ruff_cache/
**/__pycache__/
config.yml
api_tokens.yml


@@ -22,6 +22,8 @@ COPY pyproject.toml .
# Install packages specified in pyproject.toml cu121
RUN pip3 install --no-cache-dir .[cu121]
RUN rm pyproject.toml
# Copy the current directory contents into the container
COPY . .


@@ -1,11 +1,13 @@
version: '3.8'
services:
tabbyapi:
build:
context: ..
dockerfile: ./docker/Dockerfile
args:
- DO_PULL=true
# Uncomment this to build a docker image from source
#build:
# context: ..
# dockerfile: ./docker/Dockerfile
# Comment this to build a docker image from source
image: ghcr.io/theroyallab/tabbyapi:latest
ports:
- "5000:5000"
healthcheck:


@@ -137,7 +137,7 @@ async def get_version():
async def get_extra_version():
"""Impersonate Koboldcpp."""
return {"result": "KoboldCpp", "version": "1.61"}
return {"result": "KoboldCpp", "version": "1.71"}
@kai_router.get("/config/soft_prompts_list")


@@ -22,6 +22,7 @@ from endpoints.OAI.utils.chat_completion import (
)
from endpoints.OAI.utils.completion import (
generate_completion,
load_inline_model,
stream_generate_completion,
)
from endpoints.OAI.utils.embeddings import get_embeddings
@@ -42,7 +43,7 @@ def setup():
# Completions endpoint
@router.post(
"/v1/completions",
dependencies=[Depends(check_api_key), Depends(check_model_container)],
dependencies=[Depends(check_api_key)],
)
async def completion_request(
request: Request, data: CompletionRequest
@@ -53,6 +54,18 @@ async def completion_request(
If stream = true, this returns an SSE stream.
"""
if data.model:
inline_load_task = asyncio.create_task(load_inline_model(data.model, request))
await run_with_request_disconnect(
request,
inline_load_task,
disconnect_message=f"Model switch for generation {request.state.id} "
+ "cancelled by user.",
)
else:
await check_model_container()
model_path = model.container.model_dir
if isinstance(data.prompt, list):
@@ -85,7 +98,7 @@
# Chat completions endpoint
@router.post(
"/v1/chat/completions",
dependencies=[Depends(check_api_key), Depends(check_model_container)],
dependencies=[Depends(check_api_key)],
)
async def chat_completion_request(
request: Request, data: ChatCompletionRequest
@@ -96,6 +109,11 @@ async def chat_completion_request(
If stream = true, this returns an SSE stream.
"""
if data.model:
await load_inline_model(data.model, request)
else:
await check_model_container()
if model.container.prompt_template is None:
error_message = handle_request_error(
"Chat completions are disabled because a prompt template is not set.",


@@ -56,6 +56,7 @@ class ChatCompletionRequest(CommonCompletionRequest):
add_generation_prompt: Optional[bool] = True
template_vars: Optional[dict] = {}
response_prefix: Optional[str] = None
model: Optional[str] = None
# tools is follows the format OAI schema, functions is more flexible
# both are available in the chat template.


@@ -1,4 +1,8 @@
"""Completion utilities for OAI server."""
"""
Completion utilities for OAI server.
Also serves as a common module for completions and chat completions.
"""
import asyncio
import pathlib
@@ -10,12 +14,14 @@ from typing import List, Union
from loguru import logger
from common import model
from common.auth import get_key_permission
from common.networking import (
get_generator_error,
handle_request_disconnect,
handle_request_error,
request_disconnect_loop,
)
from common.tabby_config import config
from common.utils import unwrap
from endpoints.OAI.types.completion import (
CompletionRequest,
@@ -103,6 +109,50 @@ async def _stream_collector(
await gen_queue.put(e)
async def load_inline_model(model_name: str, request: Request):
"""Load a model from the data.model parameter"""
# Return if the model container already exists and the model is fully loaded
if (
model.container
and model.container.model_dir.name == model_name
and model.container.model_loaded
):
return
# Inline model loading isn't enabled or the user isn't an admin
if not get_key_permission(request) == "admin":
error_message = handle_request_error(
f"Unable to switch model to {model_name} because "
+ "an admin key isn't provided",
exc_info=False,
).error.message
raise HTTPException(401, error_message)
if not unwrap(config.model.get("inline_model_loading"), False):
logger.warning(
f"Unable to switch model to {model_name} because "
'"inline_model_loading" is not True in config.yml.'
)
return
model_path = pathlib.Path(unwrap(config.model.get("model_dir"), "models"))
model_path = model_path / model_name
# Model path doesn't exist
if not model_path.exists():
logger.warning(
f"Could not find model path {str(model_path)}. Skipping inline model load."
)
return
# Load the model
await model.load_model(model_path)
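
With inline_model_loading set to True in config.yml and an admin-level key, a completion or chat completion request can now name a model directly and the server switches to it before generating; non-admin keys receive a 401, and a missing model directory or a disabled config flag is logged and skipped. A hedged sketch of such a request; the port, header name, and model folder are assumptions about a typical deployment, not values taken from this diff:

import requests  # any HTTP client works; requests keeps the sketch short

payload = {
    "model": "Meta-Llama-3.1-8B-Instruct-exl2",  # example folder name under the configured model_dir
    "prompt": "Hello, world",
    "max_tokens": 32,
}

response = requests.post(
    "http://127.0.0.1:5000/v1/completions",
    headers={"x-api-key": "YOUR_ADMIN_KEY"},  # assumed header name; inline loading requires admin permission
    json=payload,
    timeout=600,  # the first call may block while the new model loads
)
print(response.json())
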
async def stream_generate_completion(
data: CompletionRequest, request: Request, model_path: pathlib.Path
):


@@ -103,7 +103,7 @@ async def list_draft_models(request: Request) -> ModelList:
models = get_model_list(draft_model_path.resolve())
else:
models = await get_current_model_list(is_draft=True)
models = await get_current_model_list(model_type="draft")
return models
@@ -441,7 +441,8 @@ async def switch_template(data: TemplateSwitchRequest):
raise HTTPException(400, error_message)
try:
model.container.prompt_template = PromptTemplate.from_file(data.name)
template_path = pathlib.Path("templates") / data.name
model.container.prompt_template = await PromptTemplate.from_file(template_path)
except FileNotFoundError as e:
error_message = handle_request_error(
f"The template name {data.name} doesn't exist. Check the spelling?",
@@ -490,7 +491,7 @@ async def switch_sampler_override(data: SamplerOverrideSwitchRequest):
if data.preset:
try:
sampling.overrides_from_file(data.preset)
await sampling.overrides_from_file(data.preset)
except FileNotFoundError as e:
error_message = handle_request_error(
f"Sampler override preset with name {data.preset} does not exist. "


@@ -5,7 +5,8 @@ from time import time
from typing import List, Literal, Optional, Union
from common.config_models import logging_config_model
from common.model import get_config_default
from common.tabby_config import config
from common.utils import unwrap
class ModelCardParameters(BaseModel):
@@ -51,23 +52,13 @@ class DraftModelLoadRequest(BaseModel):
draft_model_name: str
# Config arguments
draft_rope_scale: Optional[float] = Field(
default_factory=lambda: get_config_default(
"draft_rope_scale", model_type="draft"
)
)
draft_rope_scale: Optional[float] = None
draft_rope_alpha: Optional[Union[float, Literal["auto"]]] = Field(
description='Automatically calculated if set to "auto"',
default_factory=lambda: get_config_default(
"draft_rope_alpha", model_type="draft"
),
default=None,
examples=[1.0],
)
draft_cache_mode: Optional[str] = Field(
default_factory=lambda: get_config_default(
"draft_cache_mode", model_type="draft"
)
)
draft_cache_mode: Optional[str] = None
class ModelLoadRequest(BaseModel):
@@ -78,62 +69,45 @@ class ModelLoadRequest(BaseModel):
# Config arguments
# Max seq len is fetched from config.json of the model by default
max_seq_len: Optional[int] = Field(
description="Leave this blank to use the model's base sequence length",
default_factory=lambda: get_config_default("max_seq_len"),
default=None,
examples=[4096],
)
override_base_seq_len: Optional[int] = Field(
description=(
"Overrides the model's base sequence length. " "Leave blank if unsure"
),
default_factory=lambda: get_config_default("override_base_seq_len"),
default=None,
examples=[4096],
)
cache_size: Optional[int] = Field(
description=("Number in tokens, must be greater than or equal to max_seq_len"),
default_factory=lambda: get_config_default("cache_size"),
default=None,
examples=[4096],
)
tensor_parallel: Optional[bool] = Field(
default_factory=lambda: get_config_default("tensor_parallel")
)
gpu_split_auto: Optional[bool] = Field(
default_factory=lambda: get_config_default("gpu_split_auto")
)
autosplit_reserve: Optional[List[float]] = Field(
default_factory=lambda: get_config_default("autosplit_reserve")
)
tensor_parallel: Optional[bool] = None
gpu_split_auto: Optional[bool] = None
autosplit_reserve: Optional[List[float]] = None
gpu_split: Optional[List[float]] = Field(
default_factory=lambda: get_config_default("gpu_split"),
default=None,
examples=[[24.0, 20.0]],
)
rope_scale: Optional[float] = Field(
description="Automatically pulled from the model's config if not present",
default_factory=lambda: get_config_default("rope_scale"),
default=None,
examples=[1.0],
)
rope_alpha: Optional[Union[float, Literal["auto"]]] = Field(
description='Automatically calculated if set to "auto"',
default_factory=lambda: get_config_default("rope_alpha"),
default=None,
examples=[1.0],
)
cache_mode: Optional[str] = Field(
default_factory=lambda: get_config_default("cache_mode")
)
chunk_size: Optional[int] = Field(
default_factory=lambda: get_config_default("chunk_size")
)
prompt_template: Optional[str] = Field(
default_factory=lambda: get_config_default("prompt_template")
)
num_experts_per_token: Optional[int] = Field(
default_factory=lambda: get_config_default("num_experts_per_token")
)
fasttensors: Optional[bool] = Field(
default_factory=lambda: get_config_default("fasttensors")
)
cache_mode: Optional[str] = None
chunk_size: Optional[int] = None
prompt_template: Optional[str] = None
num_experts_per_token: Optional[int] = None
fasttensors: Optional[bool] = None
# Non-config arguments
draft: Optional[DraftModelLoadRequest] = None
@@ -142,9 +116,11 @@ class ModelLoadRequest(BaseModel):
class EmbeddingModelLoadRequest(BaseModel):
name: str
# Set default from the config
embeddings_device: Optional[str] = Field(
default_factory=lambda: get_config_default(
"embeddings_device", model_type="embedding"
default_factory=lambda: unwrap(
config.embeddings.get("embeddings_device"), "cpu"
)
)
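
The load-request fields above now default to None instead of reaching into the config at schema-definition time; the only remaining lookup is the embeddings device, guarded with unwrap. For readers unfamiliar with that helper, unwrap(value, default) acts as a None-coalescing fallback. A functionally equivalent sketch, not necessarily the repo's exact implementation:

def unwrap(wrapped, default=None):
    """Return wrapped unless it is None, otherwise fall back to default."""
    return wrapped if wrapped is not None else default

# Mirrors the embeddings_device default above: fall back to "cpu" when the config value is unset
assert unwrap(None, "cpu") == "cpu"
assert unwrap("cuda", "cpu") == "cuda"
assert unwrap(False, "cpu") is False  # falsy-but-not-None values are preserved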


@@ -50,7 +50,7 @@ async def entrypoint_async():
port = fallback_port
# Initialize auth keys
load_auth_keys(config.network.disable_auth)
await load_auth_keys(config.network.disable_auth)
gen_logging.broadcast_status()
@@ -58,7 +58,7 @@ async def entrypoint_async():
sampling_override_preset = config.sampling.override_preset
if sampling_override_preset:
try:
sampling.overrides_from_file(sampling_override_preset)
await sampling.overrides_from_file(sampling_override_preset)
except FileNotFoundError as e:
logger.warning(str(e))
@@ -111,7 +111,7 @@ def entrypoint(arguments: Optional[dict] = None):
arguments = convert_args_to_dict(parser.parse_args(), parser)
# load config
config.load_config(arguments)
config.load(arguments)
if do_export_openapi:
openapi_json = export_openapi()


@@ -68,12 +68,12 @@ cu121 = [
"torch @ https://download.pytorch.org/whl/cu121/torch-2.3.1%2Bcu121-cp310-cp310-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.10'",
# Exl2
"exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.2.0/exllamav2-0.2.0+cu121.torch2.3.1-cp312-cp312-win_amd64.whl ; platform_system == 'Windows' and python_version == '3.12'",
"exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.2.0/exllamav2-0.2.0+cu121.torch2.3.1-cp311-cp311-win_amd64.whl ; platform_system == 'Windows' and python_version == '3.11'",
"exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.2.0/exllamav2-0.2.0+cu121.torch2.3.1-cp310-cp310-win_amd64.whl ; platform_system == 'Windows' and python_version == '3.10'",
"exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.2.0/exllamav2-0.2.0+cu121.torch2.3.1-cp312-cp312-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.12'",
"exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.2.0/exllamav2-0.2.0+cu121.torch2.3.1-cp311-cp311-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.11'",
"exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.2.0/exllamav2-0.2.0+cu121.torch2.3.1-cp310-cp310-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.10'",
"exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.2.1/exllamav2-0.2.1+cu121.torch2.3.1-cp312-cp312-win_amd64.whl ; platform_system == 'Windows' and python_version == '3.12'",
"exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.2.1/exllamav2-0.2.1+cu121.torch2.3.1-cp311-cp311-win_amd64.whl ; platform_system == 'Windows' and python_version == '3.11'",
"exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.2.1/exllamav2-0.2.1+cu121.torch2.3.1-cp310-cp310-win_amd64.whl ; platform_system == 'Windows' and python_version == '3.10'",
"exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.2.1/exllamav2-0.2.1+cu121.torch2.3.1-cp312-cp312-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.12'",
"exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.2.1/exllamav2-0.2.1+cu121.torch2.3.1-cp311-cp311-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.11'",
"exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.2.1/exllamav2-0.2.1+cu121.torch2.3.1-cp310-cp310-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.10'",
# Windows FA2 from https://github.com/bdashore3/flash-attention/releases
"flash_attn @ https://github.com/bdashore3/flash-attention/releases/download/v2.6.3/flash_attn-2.6.3+cu123torch2.3.1cxx11abiFALSE-cp312-cp312-win_amd64.whl ; platform_system == 'Windows' and python_version == '3.12'",
@@ -95,12 +95,12 @@ cu118 = [
"torch @ https://download.pytorch.org/whl/cu118/torch-2.3.1%2Bcu118-cp310-cp310-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.10'",
# Exl2
"exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.2.0/exllamav2-0.2.0+cu118.torch2.3.1-cp312-cp312-win_amd64.whl ; platform_system == 'Windows' and python_version == '3.12'",
"exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.2.0/exllamav2-0.2.0+cu118.torch2.3.1-cp311-cp311-win_amd64.whl ; platform_system == 'Windows' and python_version == '3.11'",
"exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.2.0/exllamav2-0.2.0+cu118.torch2.3.1-cp310-cp310-win_amd64.whl ; platform_system == 'Windows' and python_version == '3.10'",
"exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.2.0/exllamav2-0.2.0+cu118.torch2.3.1-cp312-cp312-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.12'",
"exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.2.0/exllamav2-0.2.0+cu118.torch2.3.1-cp311-cp311-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.11'",
"exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.2.0/exllamav2-0.2.0+cu118.torch2.3.1-cp310-cp310-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.10'",
"exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.2.1/exllamav2-0.2.1+cu118.torch2.3.1-cp312-cp312-win_amd64.whl ; platform_system == 'Windows' and python_version == '3.12'",
"exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.2.1/exllamav2-0.2.1+cu118.torch2.3.1-cp311-cp311-win_amd64.whl ; platform_system == 'Windows' and python_version == '3.11'",
"exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.2.1/exllamav2-0.2.1+cu118.torch2.3.1-cp310-cp310-win_amd64.whl ; platform_system == 'Windows' and python_version == '3.10'",
"exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.2.1/exllamav2-0.2.1+cu118.torch2.3.1-cp312-cp312-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.12'",
"exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.2.1/exllamav2-0.2.1+cu118.torch2.3.1-cp311-cp311-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.11'",
"exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.2.1/exllamav2-0.2.1+cu118.torch2.3.1-cp310-cp310-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.10'",
# Linux FA2 from https://github.com/Dao-AILab/flash-attention/releases
"flash_attn @ https://github.com/Dao-AILab/flash-attention/releases/download/v2.6.3/flash_attn-2.6.3+cu118torch2.3cxx11abiFALSE-cp312-cp312-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.12'",
@@ -119,9 +119,9 @@ amd = [
"torch @ https://download.pytorch.org/whl/rocm6.0/torch-2.3.1%2Brocm6.0-cp310-cp310-linux_x86_64.whl ; python_version == '3.10'",
# Exl2
"exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.2.0/exllamav2-0.2.0+rocm6.0.torch2.3.1-cp312-cp312-linux_x86_64.whl ; python_version == '3.12'",
"exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.2.0/exllamav2-0.2.0+rocm6.0.torch2.3.1-cp311-cp311-linux_x86_64.whl ; python_version == '3.11'",
"exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.2.0/exllamav2-0.2.0+rocm6.0.torch2.3.1-cp310-cp310-linux_x86_64.whl ; python_version == '3.10'",
"exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.2.1/exllamav2-0.2.1+rocm6.0.torch2.3.1-cp312-cp312-linux_x86_64.whl ; python_version == '3.12'",
"exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.2.1/exllamav2-0.2.1+rocm6.0.torch2.3.1-cp311-cp311-linux_x86_64.whl ; python_version == '3.11'",
"exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.2.1/exllamav2-0.2.1+rocm6.0.torch2.3.1-cp310-cp310-linux_x86_64.whl ; python_version == '3.10'",
]
# MARK: Ruff options


@@ -97,6 +97,24 @@ penalty_range:
override: -1
force: false
# MARK: DRY
dry_multiplier:
override: 0.0
force: false
dry_base:
override: 0.0
force: false
dry_allowed_length:
override: 0
force: false
dry_range:
override: 0
force: false
dry_sequence_breakers:
override: []
force: false
additive: false
# MARK: Mirostat
mirostat_mode:
override: 0


@@ -234,15 +234,16 @@ if __name__ == "__main__":
if first_run:
start_options["first_run_done"] = True
# Save start options
with open("start_options.json", "w") as start_file:
start_file.write(json.dumps(start_options))
# Save start options
with open("start_options.json", "w") as start_file:
start_file.write(json.dumps(start_options))
print(
"Successfully wrote your start script options to `start_options.json`. \n"
"If something goes wrong, editing or deleting the file "
"will reinstall TabbyAPI as a first-time user."
)
print(
"Successfully wrote your start script options to "
"`start_options.json`. \n"
"If something goes wrong, editing or deleting the file "
"will reinstall TabbyAPI as a first-time user."
)
# Import entrypoint after installing all requirements
try:


@@ -1,5 +1,5 @@
{# Metadata #}
{% set stop_strings = ["### Instruction:", "### Input:", "### Response:"] %}
{%- set stop_strings = ["### Instruction:", "### Input:", "### Response:"] -%}
{# Template #}
{{ (messages|selectattr('role', 'equalto', 'system')|list|last).content|trim if (messages|selectattr('role', 'equalto', 'system')|list) else '' }}


@@ -1,5 +1,5 @@
{# Metadata #}
{% set stop_strings = ["<|im_start|>", "<|im_end|>"] %}
{%- set stop_strings = ["<|im_start|>", "<|im_end|>"] -%}
{# Template #}
{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content']}}{% if (loop.last and add_generation_prompt) or not loop.last %}{{ '<|im_end|>' + '\n'}}{% endif %}{% endfor %}


@@ -1,8 +1,8 @@
{# Metadata #}
{% set stop_strings = ["<|im_start|>", "<|im_end|>"] %}
{% set message_roles = ['system', 'user', 'assistant', 'tool'] %}
{% set tool_start = "<|tool_start|>" %}
{% set tool_end = "<|tool_end|>" %}
{%- set stop_strings = ["<|im_start|>", "<|im_end|>"] -%}
{%- set message_roles = ['system', 'user', 'assistant', 'tool'] -%}
{%- set tool_start = "<|tool_start|>" -%}
{%- set tool_end = "<|tool_end|>" -%}
{%- set start_header = "<|start_header_id|>" -%}
{%- set end_header = "<|end_header_id|>\n" -%}