Merge remote-tracking branch 'upstream/main' into HEAD
commit e8fcecd56a
28 changed files with 386 additions and 171 deletions
20 .dockerignore Normal file
@@ -0,0 +1,20 @@
+.ruff_cache/
+**/__pycache__/
+venv
+.git
+.gitignore
+.github
+
+# Ignore specific application files
+models/
+loras/
+config.yml
+config_sample.yml
+api_tokens.yml
+api_tokens_sample.yml
+*.bat
+*.sh
+update_scripts
+readme.md
+colab
+start.py

6 .gitignore vendored
@@ -192,7 +192,11 @@ templates/*
 !templates/place_your_templates_here.txt
 !templates/alpaca.jinja
 !templates/chatml.jinja
-!templates/chatml_with_headers_tool_calling.jinja
+
+# Tool calling templates folder
+templates/tool_calls/*
+!templates/tool_calls
+!templates/tool_calls/chatml_with_headers.jinja
 
 # Sampler overrides folder
 sampler_overrides/*

@@ -1,5 +1,6 @@
 """The model container class for ExLlamaV2 models."""
 
+import aiofiles
 import asyncio
 import gc
 import math
@@ -7,6 +8,7 @@ import pathlib
 import traceback
 import torch
 import uuid
+from copy import deepcopy
 from exllamav2 import (
     ExLlamaV2,
     ExLlamaV2Config,
@@ -105,13 +107,17 @@ class ExllamaV2Container:
     load_lock: asyncio.Lock = asyncio.Lock()
     load_condition: asyncio.Condition = asyncio.Condition()
 
-    def __init__(self, model_directory: pathlib.Path, quiet=False, **kwargs):
+    @classmethod
+    async def create(cls, model_directory: pathlib.Path, quiet=False, **kwargs):
         """
-        Primary initializer for model container.
+        Primary asynchronous initializer for model container.
 
         Kwargs are located in config_sample.yml
         """
 
+        # Create a new instance as a "fake self"
+        self = cls()
+
         self.quiet = quiet
 
         # Initialize config
@@ -154,13 +160,13 @@ class ExllamaV2Container:
             self.draft_config.prepare()
 
         # Create the hf_config
-        self.hf_config = HuggingFaceConfig.from_file(model_directory)
+        self.hf_config = await HuggingFaceConfig.from_file(model_directory)
 
         # Load generation config overrides
         generation_config_path = model_directory / "generation_config.json"
         if generation_config_path.exists():
             try:
-                self.generation_config = GenerationConfig.from_file(
+                self.generation_config = await GenerationConfig.from_file(
                     generation_config_path.parent
                 )
             except Exception:
@@ -170,7 +176,7 @@ class ExllamaV2Container:
         )
 
         # Apply a model's config overrides while respecting user settings
-        kwargs = self.set_model_overrides(**kwargs)
+        kwargs = await self.set_model_overrides(**kwargs)
 
         # MARK: User configuration
 
@@ -319,7 +325,7 @@ class ExllamaV2Container:
             self.cache_size = self.config.max_seq_len
 
         # Try to set prompt template
-        self.prompt_template = self.find_prompt_template(
+        self.prompt_template = await self.find_prompt_template(
             kwargs.get("prompt_template"), model_directory
         )
 
@@ -372,7 +378,10 @@ class ExllamaV2Container:
             self.draft_config.max_input_len = chunk_size
             self.draft_config.max_attention_size = chunk_size**2
 
-    def set_model_overrides(self, **kwargs):
+        # Return the created instance
+        return self
+
+    async def set_model_overrides(self, **kwargs):
         """Sets overrides from a model folder's config yaml."""
 
         override_config_path = self.model_dir / "tabby_config.yml"
@@ -380,8 +389,11 @@ class ExllamaV2Container:
         if not override_config_path.exists():
             return kwargs
 
-        with open(override_config_path, "r", encoding="utf8") as override_config_file:
-            override_args = unwrap(yaml.safe_load(override_config_file), {})
+        async with aiofiles.open(
+            override_config_path, "r", encoding="utf8"
+        ) as override_config_file:
+            contents = await override_config_file.read()
+            override_args = unwrap(yaml.safe_load(contents), {})
 
         # Merge draft overrides beforehand
         draft_override_args = unwrap(override_args.get("draft"), {})
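
The hunk above swaps a blocking open() for aiofiles so YAML config reads no longer stall the event loop. A minimal standalone sketch of the same pattern, with an illustrative file name rather than the repo's actual paths:

    import asyncio

    import aiofiles
    import yaml


    async def read_yaml(path: str) -> dict:
        # Read the whole file without blocking the event loop,
        # then parse the text with the regular (sync) YAML loader.
        async with aiofiles.open(path, "r", encoding="utf8") as f:
            contents = await f.read()
        return yaml.safe_load(contents) or {}


    # Example usage (hypothetical file):
    # overrides = asyncio.run(read_yaml("tabby_config.yml"))
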
@@ -392,7 +404,7 @@ class ExllamaV2Container:
         merged_kwargs = {**override_args, **kwargs}
         return merged_kwargs
 
-    def find_prompt_template(self, prompt_template_name, model_directory):
+    async def find_prompt_template(self, prompt_template_name, model_directory):
         """Tries to find a prompt template using various methods."""
 
         logger.info("Attempting to load a prompt template if present.")
@@ -400,26 +412,37 @@ class ExllamaV2Container:
         find_template_functions = [
             lambda: PromptTemplate.from_model_json(
                 pathlib.Path(self.config.model_dir) / "tokenizer_config.json",
-                "chat_template",
+                key="chat_template",
             ),
             lambda: PromptTemplate.from_file(find_template_from_model(model_directory)),
         ]
 
+        # Find the template in the model directory if it exists
+        model_dir_template_path = (
+            pathlib.Path(self.config.model_dir) / "tabby_template.jinja"
+        )
+        if model_dir_template_path.exists():
+            find_template_functions[:0] = [
+                lambda: PromptTemplate.from_file(model_dir_template_path)
+            ]
+
         # Add lookup from prompt template name if provided
         if prompt_template_name:
             find_template_functions[:0] = [
-                lambda: PromptTemplate.from_file(prompt_template_name),
+                lambda: PromptTemplate.from_file(
+                    pathlib.Path("templates") / prompt_template_name
+                ),
                 lambda: PromptTemplate.from_model_json(
                     pathlib.Path(self.config.model_dir) / "tokenizer_config.json",
-                    "chat_template",
-                    prompt_template_name,
+                    key="chat_template",
+                    name=prompt_template_name,
                 ),
             ]
 
         # Continue on exception since functions are tried as they fail
         for template_func in find_template_functions:
             try:
-                prompt_template = template_func()
+                prompt_template = await template_func()
                 if prompt_template is not None:
                     return prompt_template
             except TemplateLoadError as e:
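
find_prompt_template now awaits each candidate loader and falls through on TemplateLoadError until one returns a template. A rough sketch of that fallback chain, using stand-in loader functions rather than the real PromptTemplate classmethods:

    import asyncio


    class TemplateLoadError(Exception):
        """Raised by a loader that cannot produce a template."""


    async def load_from_file(name):
        raise TemplateLoadError(f"{name} not found")


    async def load_builtin(name):
        return f"<template for {name}>"


    async def find_template(name):
        # Try loaders in priority order; the first non-None result wins.
        loaders = [lambda: load_from_file(name), lambda: load_builtin(name)]
        for loader in loaders:
            try:
                template = await loader()
                if template is not None:
                    return template
            except TemplateLoadError:
                continue
        return None


    # asyncio.run(find_template("chatml")) -> "<template for chatml>"
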
@@ -944,6 +967,14 @@ class ExllamaV2Container:
         Meant for dev wheels!
         """
 
+        if unwrap(kwargs.get("dry_allowed_length"), 0) > 0 and not hasattr(
+            ExLlamaV2Sampler.Settings, "dry_multiplier"
+        ):
+            logger.warning(
+                "DRY sampling is not supported by the currently "
+                "installed ExLlamaV2 version."
+            )
+
         return kwargs
 
     async def generate_gen(
@@ -1035,6 +1066,7 @@ class ExllamaV2Container:
                 "Please use an ampere (30 series) or higher GPU for CFG support."
             )
 
+        # Penalties
         gen_settings.token_repetition_penalty = unwrap(
             kwargs.get("repetition_penalty"), 1.0
         )
@@ -1070,6 +1102,32 @@ class ExllamaV2Container:
             kwargs.get("repetition_decay"), fallback_decay, 0
         )
 
+        # DRY options
+        dry_multiplier = unwrap(kwargs.get("dry_multiplier"), 0.0)
+
+        # < 0 = disabled
+        if dry_multiplier > 0:
+            gen_settings.dry_multiplier = dry_multiplier
+
+            # TODO: Maybe set the "sane" defaults instead?
+            gen_settings.dry_allowed_length = unwrap(
+                kwargs.get("dry_allowed_length"), 0
+            )
+            gen_settings.dry_base = unwrap(kwargs.get("dry_base"), 0.0)
+
+            # Exl2 has dry_range as 0 for unlimited unlike -1 for penalty_range
+            # Use max_seq_len as the fallback to stay consistent
+            gen_settings.dry_range = unwrap(
+                kwargs.get("dry_range"), self.config.max_seq_len
+            )
+
+            # Tokenize sequence breakers
+            dry_sequence_breakers_json = kwargs.get("dry_sequence_breakers")
+            if dry_sequence_breakers_json:
+                gen_settings.dry_sequence_breakers = {
+                    self.encode_tokens(s)[-1] for s in dry_sequence_breakers_json
+                }
+
         # Initialize grammar handler
         grammar_handler = ExLlamaV2Grammar()
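
The new DRY block hands the sampler a set containing the last token id of every sequence breaker (encode_tokens(s)[-1]). A toy illustration of that mapping with a fake one-character-per-token tokenizer, purely to show the shape of the data:

    def encode_tokens(text: str) -> list[int]:
        # Stand-in tokenizer: one "token id" per character, purely illustrative.
        return [ord(ch) for ch in text]


    def breaker_token_ids(breakers: list[str]) -> set[int]:
        # Keep only the final token id of each breaker, mirroring
        # encode_tokens(s)[-1] in the hunk above.
        return {encode_tokens(s)[-1] for s in breakers}


    print(breaker_token_ids(["\n", ":", '"']))  # {10, 58, 34}
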
@@ -1130,7 +1188,8 @@ class ExllamaV2Container:
         )
 
         # Store the gen settings for logging purposes
-        gen_settings_log_dict = vars(gen_settings)
+        # Deepcopy to save a snapshot of vars
+        gen_settings_log_dict = deepcopy(vars(gen_settings))
 
         # Set banned tokens
         banned_tokens = unwrap(kwargs.get("banned_tokens"), [])

@@ -8,7 +8,7 @@ from loguru import logger
 def check_exllama_version():
     """Verifies the exllama version"""
 
-    required_version = version.parse("0.1.9")
+    required_version = version.parse("0.2.1")
     current_version = version.parse(package_version("exllamav2").split("+")[0])
 
     unsupported_message = (
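
The minimum ExLlamaV2 version moves from 0.1.9 to 0.2.1, and the local build tag (e.g. "+cu121") is stripped before comparison. A small sketch of the same check using packaging.version, with a made-up installed version string:

    from packaging import version


    def check_min_version(installed: str, required: str = "0.2.1") -> bool:
        # Drop any local build tag such as "+cu121" before comparing.
        current = version.parse(installed.split("+")[0])
        return current >= version.parse(required)


    print(check_min_version("0.2.1+cu121.torch2.3.1"))  # True
    print(check_min_version("0.1.9"))                   # False
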

@@ -7,13 +7,13 @@ from typing import List, Optional
 from common.utils import unwrap
 
 # Conditionally import infinity to sidestep its logger
 # TODO: Make this prettier
+has_infinity_emb: bool = False
 try:
     from infinity_emb import EngineArgs, AsyncEmbeddingEngine
 
     has_infinity_emb = True
 except ImportError:
-    has_infinity_emb = False
+    pass
 
 
 class InfinityContainer:

@@ -3,6 +3,7 @@ This method of authorization is pretty insecure, but since TabbyAPI is a local
 application, it should be fine.
 """
 
+import aiofiles
 import secrets
 import yaml
 from fastapi import Header, HTTPException, Request
@@ -40,7 +41,7 @@ AUTH_KEYS: Optional[AuthKeys] = None
 DISABLE_AUTH: bool = False
 
 
-def load_auth_keys(disable_from_config: bool):
+async def load_auth_keys(disable_from_config: bool):
     """Load the authentication keys from api_tokens.yml. If the file does not
     exist, generate new keys and save them to api_tokens.yml."""
     global AUTH_KEYS
@@ -57,8 +58,9 @@ def load_auth_keys(disable_from_config: bool):
         return
 
     try:
-        with open("api_tokens.yml", "r", encoding="utf8") as auth_file:
-            auth_keys_dict = yaml.safe_load(auth_file)
+        async with aiofiles.open("api_tokens.yml", "r", encoding="utf8") as auth_file:
+            contents = await auth_file.read()
+            auth_keys_dict = yaml.safe_load(contents)
         AUTH_KEYS = AuthKeys.model_validate(auth_keys_dict)
     except FileNotFoundError:
         new_auth_keys = AuthKeys(
@@ -66,8 +68,11 @@ def load_auth_keys(disable_from_config: bool):
         )
         AUTH_KEYS = new_auth_keys
 
-        with open("api_tokens.yml", "w", encoding="utf8") as auth_file:
-            yaml.safe_dump(AUTH_KEYS.model_dump(), auth_file, default_flow_style=False)
+        async with aiofiles.open("api_tokens.yml", "w", encoding="utf8") as auth_file:
+            new_auth_yaml = yaml.safe_dump(
+                AUTH_KEYS.model_dump(), default_flow_style=False
+            )
+            await auth_file.write(new_auth_yaml)
 
     logger.info(
         f"Your API key is: {AUTH_KEYS.api_key}\n"
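
load_auth_keys becomes a coroutine: the YAML is read and written through aiofiles, and safe_dump renders to a string first because aiofiles writes text rather than accepting a stream argument. A condensed sketch of that round trip; the key names and file name here are placeholders:

    import asyncio
    import secrets

    import aiofiles
    import yaml


    async def load_or_create_keys(path: str = "api_tokens.yml") -> dict:
        try:
            async with aiofiles.open(path, "r", encoding="utf8") as f:
                return yaml.safe_load(await f.read())
        except FileNotFoundError:
            keys = {"api_key": secrets.token_hex(16), "admin_key": secrets.token_hex(16)}
            # Dump to a string first, then write it asynchronously.
            async with aiofiles.open(path, "w", encoding="utf8") as f:
                await f.write(yaml.safe_dump(keys, default_flow_style=False))
            return keys


    # keys = asyncio.run(load_or_create_keys())
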

@@ -13,7 +13,6 @@ from typing import Optional
 from common.logger import get_loading_progress_bar
 from common.networking import handle_request_error
 from common.tabby_config import config
-from common.utils import unwrap
 from endpoints.utils import do_export_openapi
 
 if not do_export_openapi:
@@ -67,7 +66,11 @@ async def load_model_gen(model_path: pathlib.Path, **kwargs):
         logger.info("Unloading existing model.")
         await unload_model()
 
-    container = ExllamaV2Container(model_path.resolve(), False, **kwargs)
+    # Merge with config defaults
+    kwargs = {**config.model_defaults, **kwargs}
+
+    # Create a new container
+    container = await ExllamaV2Container.create(model_path.resolve(), False, **kwargs)
 
     model_type = "draft" if container.draft_config else "model"
     load_status = container.load_gen(load_progress, **kwargs)
|
|||
embeddings_container = None
|
||||
|
||||
|
||||
# FIXME: Maybe make this a one-time function instead of a dynamic default
|
||||
def get_config_default(key: str, model_type: str = "model"):
|
||||
"""Fetches a default value from model config if allowed by the user."""
|
||||
|
||||
default_keys = unwrap(config.model.use_as_default, [])
|
||||
|
||||
# Add extra keys to defaults
|
||||
default_keys.append("embeddings_device")
|
||||
|
||||
if key in default_keys:
|
||||
# Is this a draft model load parameter?
|
||||
if model_type == "draft":
|
||||
return config.draft_model.get(key)
|
||||
elif model_type == "embedding":
|
||||
return config.embeddings.get(key)
|
||||
else:
|
||||
return config.model.get(key)
|
||||
|
||||
|
||||
async def check_model_container():
|
||||
"""FastAPI depends that checks if a model isn't loaded or currently loading."""
|
||||
|
||||
|
|
|

@@ -1,5 +1,7 @@
 """Common functions for sampling parameters"""
 
+import aiofiles
+import json
 import pathlib
 import yaml
 from copy import deepcopy
@@ -140,6 +142,28 @@ class BaseSamplerRequest(BaseModel):
         default_factory=lambda: get_default_sampler_value("repetition_decay", 0)
     )
 
+    dry_multiplier: Optional[float] = Field(
+        default_factory=lambda: get_default_sampler_value("dry_multiplier", 0.0)
+    )
+
+    dry_base: Optional[float] = Field(
+        default_factory=lambda: get_default_sampler_value("dry_base", 0.0)
+    )
+
+    dry_allowed_length: Optional[int] = Field(
+        default_factory=lambda: get_default_sampler_value("dry_allowed_length", 0)
+    )
+
+    dry_range: Optional[int] = Field(
+        default_factory=lambda: get_default_sampler_value("dry_range", 0),
+        alias=AliasChoices("dry_range", "dry_penalty_last_n"),
+        description=("Aliases: dry_penalty_last_n"),
+    )
+
+    dry_sequence_breakers: Optional[Union[str, List[str]]] = Field(
+        default_factory=lambda: get_default_sampler_value("dry_sequence_breakers", [])
+    )
+
     mirostat_mode: Optional[int] = Field(
         default_factory=lambda: get_default_sampler_value("mirostat_mode", 0)
     )
@@ -305,6 +329,17 @@ class BaseSamplerRequest(BaseModel):
                 int(x) for x in self.allowed_tokens.split(",") if x.isdigit()
             ]
 
+        # Convert sequence breakers into an array of strings
+        # NOTE: This sampler sucks to parse.
+        if self.dry_sequence_breakers and isinstance(self.dry_sequence_breakers, str):
+            if not self.dry_sequence_breakers.startswith("["):
+                self.dry_sequence_breakers = f"[{self.dry_sequence_breakers}]"
+
+            try:
+                self.dry_sequence_breakers = json.loads(self.dry_sequence_breakers)
+            except Exception:
+                self.dry_sequence_breakers = []
+
         gen_params = {
             "max_tokens": self.max_tokens,
             "min_tokens": self.min_tokens,
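
The validator added above accepts dry_sequence_breakers either as a list or as a bare comma-separated string of JSON strings, wrapping the latter in brackets before json.loads and falling back to an empty list on failure. A standalone sketch of that normalization:

    import json


    def parse_breakers(value):
        # Lists pass through untouched; strings are coerced into a JSON array.
        if isinstance(value, str):
            if not value.startswith("["):
                value = f"[{value}]"
            try:
                return json.loads(value)
            except Exception:
                return []
        return value


    print(parse_breakers('"\\n", ":"'))    # ['\n', ':']
    print(parse_breakers(["<|im_end|>"]))  # ['<|im_end|>']
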
@@ -335,6 +370,11 @@ class BaseSamplerRequest(BaseModel):
             "presence_penalty": self.presence_penalty,
             "repetition_penalty": self.repetition_penalty,
             "penalty_range": self.penalty_range,
+            "dry_multiplier": self.dry_multiplier,
+            "dry_base": self.dry_base,
+            "dry_allowed_length": self.dry_allowed_length,
+            "dry_sequence_breakers": self.dry_sequence_breakers,
+            "dry_range": self.dry_range,
             "repetition_decay": self.repetition_decay,
             "mirostat": self.mirostat_mode == 2,
             "mirostat_tau": self.mirostat_tau,
@@ -368,14 +408,15 @@ def overrides_from_dict(new_overrides: dict):
         raise TypeError("New sampler overrides must be a dict!")
 
 
-def overrides_from_file(preset_name: str):
+async def overrides_from_file(preset_name: str):
     """Fetches an override preset from a file"""
 
     preset_path = pathlib.Path(f"sampler_overrides/{preset_name}.yml")
     if preset_path.exists():
         overrides_container.selected_preset = preset_path.stem
-        with open(preset_path, "r", encoding="utf8") as raw_preset:
-            preset = yaml.safe_load(raw_preset)
+        async with aiofiles.open(preset_path, "r", encoding="utf8") as raw_preset:
+            contents = await raw_preset.read()
+            preset = yaml.safe_load(contents)
             overrides_from_dict(preset)
 
         logger.info("Applied sampler overrides from file.")

@@ -10,8 +10,13 @@ import common.config_models
 
 
 class TabbyConfig(tabby_config_model):
-    def load_config(self, arguments: Optional[dict] = None):
-        """load the global application config"""
+    # Persistent defaults
+    # TODO: make this pydantic?
+    model_defaults: dict = {}
+
+    def load(self, arguments: Optional[dict] = None):
+        """Synchronously loads the global application config"""
 
         # config is applied in order of items in the list
         configs = [
@@ -28,6 +33,17 @@ class TabbyConfig(tabby_config_model):
 
             setattr(self, field, model.parse_obj(value))
 
+        # Set model defaults dict once to prevent on-demand reconstruction
+        # TODO: clean this up a bit
+        for field in self.model.use_as_default:
+            if hasattr(self.model, field):
+                self.model_defaults[field] = getattr(config.model, field)
+            elif hasattr(self.draft_model, field):
+                self.model_defaults[field] = getattr(config.draft_model, field)
+            else:
+                # TODO: show an error
+                pass
+
     def _from_file(self, config_path: pathlib.Path):
         """loads config from a given file path"""
 
|
|||
config_override = unwrap(args.get("options", {}).get("config"))
|
||||
if config_override:
|
||||
logger.info("Config file override detected in args.")
|
||||
config = self.from_file(pathlib.Path(config_override))
|
||||
config = self._from_file(pathlib.Path(config_override))
|
||||
return config # Return early if loading from file
|
||||
|
||||
for key in tabby_config_model.model_fields.keys():
|
||||
|
|
@ -85,5 +101,5 @@ class TabbyConfig(tabby_config_model):
|
|||
return config
|
||||
|
||||
|
||||
# Create an empty instance of the shared var to make sure nothing breaks
|
||||
# Create an empty instance of the config class
|
||||
config: TabbyConfig = TabbyConfig()
|
||||
|
|
|
|||

@@ -1,10 +1,12 @@
 """Small replication of AutoTokenizer's chat template system for efficiency"""
 
+import aiofiles
 import json
 import pathlib
 from importlib.metadata import version as package_version
 from typing import List, Optional
 from jinja2 import Template, TemplateError
+from jinja2.ext import loopcontrols
 from jinja2.sandbox import ImmutableSandboxedEnvironment
 from loguru import logger
 from packaging import version
@@ -32,7 +34,10 @@ class PromptTemplate:
     raw_template: str
     template: Template
     environment: ImmutableSandboxedEnvironment = ImmutableSandboxedEnvironment(
-        trim_blocks=True, lstrip_blocks=True, enable_async=True
+        trim_blocks=True,
+        lstrip_blocks=True,
+        enable_async=True,
+        extensions=[loopcontrols],
     )
     metadata: Optional[TemplateMetadata] = None
 
@@ -106,32 +111,42 @@ class PromptTemplate:
         self.template = self.compile(raw_template)
 
     @classmethod
-    def from_file(self, prompt_template_name: str):
+    async def from_file(cls, template_path: pathlib.Path):
         """Get a template from a jinja file."""
 
-        template_path = pathlib.Path(f"templates/{prompt_template_name}.jinja")
+        # Add the jinja extension if it isn't provided
+        if template_path.suffix.endswith(".jinja"):
+            template_name = template_path.name.split(".jinja")[0]
+        else:
+            template_name = template_path.name
+            template_path = template_path.with_suffix(".jinja")
+
         if template_path.exists():
-            with open(template_path, "r", encoding="utf8") as raw_template_stream:
-                return PromptTemplate(
-                    name=prompt_template_name,
-                    raw_template=raw_template_stream.read(),
+            async with aiofiles.open(
+                template_path, "r", encoding="utf8"
+            ) as raw_template_stream:
+                contents = await raw_template_stream.read()
+                return cls(
+                    name=template_name,
+                    raw_template=contents,
                 )
         else:
             # Let the user know if the template file isn't found
             raise TemplateLoadError(
-                f'Chat template "{prompt_template_name}" not found in files.'
+                f'Chat template "{template_name}" not found in files.'
             )
 
     @classmethod
-    def from_model_json(
-        self, json_path: pathlib.Path, key: str, name: Optional[str] = None
+    async def from_model_json(
+        cls, json_path: pathlib.Path, key: str, name: Optional[str] = None
     ):
         """Get a template from a JSON file. Requires a key and template name"""
         if not json_path.exists():
             raise TemplateLoadError(f'Model JSON path "{json_path}" not found.')
 
-        with open(json_path, "r", encoding="utf8") as config_file:
-            model_config = json.load(config_file)
+        async with aiofiles.open(json_path, "r", encoding="utf8") as config_file:
+            contents = await config_file.read()
+            model_config = json.loads(contents)
         chat_template = model_config.get(key)
 
         if not chat_template:
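
from_file is now an async classmethod that takes a pathlib.Path and normalizes the .jinja suffix itself, deriving the template name from the file name. A simplified sketch of just that path handling (the real method also opens and compiles the file):

    import pathlib


    def normalize_template_path(template_path: pathlib.Path):
        # Mirror the suffix handling above: accept "chatml" or "chatml.jinja".
        if template_path.suffix == ".jinja":
            name = template_path.stem
        else:
            name = template_path.name
            template_path = template_path.with_suffix(".jinja")
        return name, template_path


    print(normalize_template_path(pathlib.Path("templates") / "chatml"))
    # ('chatml', PosixPath('templates/chatml.jinja')) on POSIX systems
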
@@ -162,7 +177,7 @@ class PromptTemplate:
             )
         else:
             # Can safely assume the chat template is the old style
-            return PromptTemplate(
+            return cls(
                 name="from_tokenizer_config",
                 raw_template=chat_template,
             )

@@ -1,3 +1,4 @@
+import aiofiles
 import json
 import pathlib
 from typing import List, Optional, Union
@@ -15,15 +16,16 @@ class GenerationConfig(BaseModel):
     bad_words_ids: Optional[List[List[int]]] = None
 
     @classmethod
-    def from_file(self, model_directory: pathlib.Path):
+    async def from_file(cls, model_directory: pathlib.Path):
         """Create an instance from a generation config file."""
 
         generation_config_path = model_directory / "generation_config.json"
-        with open(
+        async with aiofiles.open(
             generation_config_path, "r", encoding="utf8"
         ) as generation_config_json:
-            generation_config_dict = json.load(generation_config_json)
-            return self.model_validate(generation_config_dict)
+            contents = await generation_config_json.read()
+            generation_config_dict = json.loads(contents)
+            return cls.model_validate(generation_config_dict)
 
     def eos_tokens(self):
         """Wrapper method to fetch EOS tokens."""
@@ -43,13 +45,16 @@ class HuggingFaceConfig(BaseModel):
     badwordsids: Optional[str] = None
 
     @classmethod
-    def from_file(self, model_directory: pathlib.Path):
+    async def from_file(cls, model_directory: pathlib.Path):
         """Create an instance from a generation config file."""
 
         hf_config_path = model_directory / "config.json"
-        with open(hf_config_path, "r", encoding="utf8") as hf_config_json:
-            hf_config_dict = json.load(hf_config_json)
-            return self.model_validate(hf_config_dict)
+        async with aiofiles.open(
+            hf_config_path, "r", encoding="utf8"
+        ) as hf_config_json:
+            contents = await hf_config_json.read()
+            hf_config_dict = json.loads(contents)
+            return cls.model_validate(hf_config_dict)
 
     def get_badwordsids(self):
         """Wrapper method to fetch badwordsids."""

@@ -83,6 +83,9 @@ model:
   # Enable this if the program is looking for a specific OAI model
   #use_dummy_models: False
 
+  # Allow direct loading of models from a completion or chat completion request
+  inline_model_loading: False
+
   # An initial model to load. Make sure the model is located in the model directory!
   # A model can be loaded later via the API.
   # REQUIRED: This must be filled out to load a model on startup!

@@ -1,6 +0,0 @@
-models/
-loras/
-.ruff_cache/
-**/__pycache__/
-config.yml
-api_tokens.yml

@@ -22,6 +22,8 @@ COPY pyproject.toml .
 # Install packages specified in pyproject.toml cu121
 RUN pip3 install --no-cache-dir .[cu121]
 
+RUN rm pyproject.toml
+
 # Copy the current directory contents into the container
 COPY . .
 

@@ -1,11 +1,13 @@
 version: '3.8'
 services:
   tabbyapi:
-    build:
-      context: ..
-      dockerfile: ./docker/Dockerfile
-      args:
-        - DO_PULL=true
+    # Uncomment this to build a docker image from source
+    #build:
+    #  context: ..
+    #  dockerfile: ./docker/Dockerfile
+
+    # Comment this to build a docker image from source
+    image: ghcr.io/theroyallab/tabbyapi:latest
     ports:
       - "5000:5000"
     healthcheck:

@@ -137,7 +137,7 @@ async def get_version():
 async def get_extra_version():
     """Impersonate Koboldcpp."""
 
-    return {"result": "KoboldCpp", "version": "1.61"}
+    return {"result": "KoboldCpp", "version": "1.71"}
 
 
 @kai_router.get("/config/soft_prompts_list")

@@ -22,6 +22,7 @@ from endpoints.OAI.utils.chat_completion import (
 )
 from endpoints.OAI.utils.completion import (
     generate_completion,
+    load_inline_model,
     stream_generate_completion,
 )
 from endpoints.OAI.utils.embeddings import get_embeddings
@@ -42,7 +43,7 @@ def setup():
 # Completions endpoint
 @router.post(
     "/v1/completions",
-    dependencies=[Depends(check_api_key), Depends(check_model_container)],
+    dependencies=[Depends(check_api_key)],
 )
 async def completion_request(
     request: Request, data: CompletionRequest
@@ -53,6 +54,18 @@ async def completion_request(
     If stream = true, this returns an SSE stream.
     """
 
+    if data.model:
+        inline_load_task = asyncio.create_task(load_inline_model(data.model, request))
+
+        await run_with_request_disconnect(
+            request,
+            inline_load_task,
+            disconnect_message=f"Model switch for generation {request.state.id} "
+            + "cancelled by user.",
+        )
+    else:
+        await check_model_container()
+
     model_path = model.container.model_dir
 
     if isinstance(data.prompt, list):
@@ -85,7 +98,7 @@ async def completion_request(
 # Chat completions endpoint
 @router.post(
     "/v1/chat/completions",
-    dependencies=[Depends(check_api_key), Depends(check_model_container)],
+    dependencies=[Depends(check_api_key)],
 )
 async def chat_completion_request(
     request: Request, data: ChatCompletionRequest
@@ -96,6 +109,11 @@ async def chat_completion_request(
     If stream = true, this returns an SSE stream.
     """
 
+    if data.model:
+        await load_inline_model(data.model, request)
+    else:
+        await check_model_container()
+
     if model.container.prompt_template is None:
         error_message = handle_request_error(
             "Chat completions are disabled because a prompt template is not set.",

@@ -56,6 +56,7 @@ class ChatCompletionRequest(CommonCompletionRequest):
     add_generation_prompt: Optional[bool] = True
     template_vars: Optional[dict] = {}
     response_prefix: Optional[str] = None
+    model: Optional[str] = None
 
     # tools is follows the format OAI schema, functions is more flexible
     # both are available in the chat template.

@@ -1,4 +1,8 @@
-"""Completion utilities for OAI server."""
+"""
+Completion utilities for OAI server.
+
+Also serves as a common module for completions and chat completions.
+"""
 
 import asyncio
 import pathlib
@@ -10,12 +14,14 @@ from typing import List, Union
 from loguru import logger
 
 from common import model
+from common.auth import get_key_permission
 from common.networking import (
     get_generator_error,
     handle_request_disconnect,
     handle_request_error,
     request_disconnect_loop,
 )
+from common.tabby_config import config
 from common.utils import unwrap
 from endpoints.OAI.types.completion import (
     CompletionRequest,
@@ -103,6 +109,50 @@ async def _stream_collector(
         await gen_queue.put(e)
 
 
+async def load_inline_model(model_name: str, request: Request):
+    """Load a model from the data.model parameter"""
+
+    # Return if the model container already exists and the model is fully loaded
+    if (
+        model.container
+        and model.container.model_dir.name == model_name
+        and model.container.model_loaded
+    ):
+        return
+
+    # Inline model loading isn't enabled or the user isn't an admin
+    if not get_key_permission(request) == "admin":
+        error_message = handle_request_error(
+            f"Unable to switch model to {model_name} because "
+            + "an admin key isn't provided",
+            exc_info=False,
+        ).error.message
+
+        raise HTTPException(401, error_message)
+
+    if not unwrap(config.model.get("inline_model_loading"), False):
+        logger.warning(
+            f"Unable to switch model to {model_name} because "
+            '"inline_model_loading" is not True in config.yml.'
+        )
+
+        return
+
+    model_path = pathlib.Path(unwrap(config.model.get("model_dir"), "models"))
+    model_path = model_path / model_name
+
+    # Model path doesn't exist
+    if not model_path.exists():
+        logger.warning(
+            f"Could not find model path {str(model_path)}. Skipping inline model load."
+        )
+
+        return
+
+    # Load the model
+    await model.load_model(model_path)
+
+
 async def stream_generate_completion(
     data: CompletionRequest, request: Request, model_path: pathlib.Path
 ):
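
load_inline_model only switches models when three gates pass: the requested model isn't already loaded, the API key has admin permission, and inline_model_loading is enabled in config.yml. A stripped-down sketch of that gating logic with stand-in arguments in place of the real request and config objects:

    import pathlib


    def should_inline_load(model_name, loaded_name, is_admin, inline_enabled, model_dir="models"):
        # Already serving the requested model: nothing to do.
        if loaded_name == model_name:
            return False
        # Non-admin keys may not trigger a model switch.
        if not is_admin:
            raise PermissionError("an admin key is required to switch models")
        # The feature must be opted into via config.yml.
        if not inline_enabled:
            return False
        # The model folder must exist on disk.
        return (pathlib.Path(model_dir) / model_name).exists()


    # should_inline_load("Llama-3-8B-exl2", loaded_name=None, is_admin=True, inline_enabled=True)
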

@@ -103,7 +103,7 @@ async def list_draft_models(request: Request) -> ModelList:
 
         models = get_model_list(draft_model_path.resolve())
     else:
-        models = await get_current_model_list(is_draft=True)
+        models = await get_current_model_list(model_type="draft")
 
     return models
 
@@ -441,7 +441,8 @@ async def switch_template(data: TemplateSwitchRequest):
         raise HTTPException(400, error_message)
 
     try:
-        model.container.prompt_template = PromptTemplate.from_file(data.name)
+        template_path = pathlib.Path("templates") / data.name
+        model.container.prompt_template = await PromptTemplate.from_file(template_path)
     except FileNotFoundError as e:
         error_message = handle_request_error(
             f"The template name {data.name} doesn't exist. Check the spelling?",
@@ -490,7 +491,7 @@ async def switch_sampler_override(data: SamplerOverrideSwitchRequest):
 
     if data.preset:
         try:
-            sampling.overrides_from_file(data.preset)
+            await sampling.overrides_from_file(data.preset)
         except FileNotFoundError as e:
             error_message = handle_request_error(
                 f"Sampler override preset with name {data.preset} does not exist. "

@@ -5,7 +5,8 @@ from time import time
 from typing import List, Literal, Optional, Union
 
 from common.config_models import logging_config_model
-from common.model import get_config_default
 from common.tabby_config import config
 from common.utils import unwrap
 
 
 class ModelCardParameters(BaseModel):
@@ -51,23 +52,13 @@ class DraftModelLoadRequest(BaseModel):
     draft_model_name: str
 
     # Config arguments
-    draft_rope_scale: Optional[float] = Field(
-        default_factory=lambda: get_config_default(
-            "draft_rope_scale", model_type="draft"
-        )
-    )
+    draft_rope_scale: Optional[float] = None
     draft_rope_alpha: Optional[Union[float, Literal["auto"]]] = Field(
         description='Automatically calculated if set to "auto"',
-        default_factory=lambda: get_config_default(
-            "draft_rope_alpha", model_type="draft"
-        ),
+        default=None,
         examples=[1.0],
     )
-    draft_cache_mode: Optional[str] = Field(
-        default_factory=lambda: get_config_default(
-            "draft_cache_mode", model_type="draft"
-        )
-    )
+    draft_cache_mode: Optional[str] = None
 
 
 class ModelLoadRequest(BaseModel):
@@ -78,62 +69,45 @@ class ModelLoadRequest(BaseModel):
 
     # Config arguments
 
     # Max seq len is fetched from config.json of the model by default
     max_seq_len: Optional[int] = Field(
         description="Leave this blank to use the model's base sequence length",
-        default_factory=lambda: get_config_default("max_seq_len"),
+        default=None,
         examples=[4096],
     )
     override_base_seq_len: Optional[int] = Field(
         description=(
             "Overrides the model's base sequence length. " "Leave blank if unsure"
         ),
-        default_factory=lambda: get_config_default("override_base_seq_len"),
+        default=None,
        examples=[4096],
     )
     cache_size: Optional[int] = Field(
         description=("Number in tokens, must be greater than or equal to max_seq_len"),
-        default_factory=lambda: get_config_default("cache_size"),
+        default=None,
         examples=[4096],
     )
-    tensor_parallel: Optional[bool] = Field(
-        default_factory=lambda: get_config_default("tensor_parallel")
-    )
-    gpu_split_auto: Optional[bool] = Field(
-        default_factory=lambda: get_config_default("gpu_split_auto")
-    )
-    autosplit_reserve: Optional[List[float]] = Field(
-        default_factory=lambda: get_config_default("autosplit_reserve")
-    )
+    tensor_parallel: Optional[bool] = None
+    gpu_split_auto: Optional[bool] = None
+    autosplit_reserve: Optional[List[float]] = None
     gpu_split: Optional[List[float]] = Field(
-        default_factory=lambda: get_config_default("gpu_split"),
+        default=None,
         examples=[[24.0, 20.0]],
     )
     rope_scale: Optional[float] = Field(
         description="Automatically pulled from the model's config if not present",
-        default_factory=lambda: get_config_default("rope_scale"),
+        default=None,
         examples=[1.0],
     )
     rope_alpha: Optional[Union[float, Literal["auto"]]] = Field(
         description='Automatically calculated if set to "auto"',
-        default_factory=lambda: get_config_default("rope_alpha"),
+        default=None,
         examples=[1.0],
     )
-    cache_mode: Optional[str] = Field(
-        default_factory=lambda: get_config_default("cache_mode")
-    )
-    chunk_size: Optional[int] = Field(
-        default_factory=lambda: get_config_default("chunk_size")
-    )
-    prompt_template: Optional[str] = Field(
-        default_factory=lambda: get_config_default("prompt_template")
-    )
-    num_experts_per_token: Optional[int] = Field(
-        default_factory=lambda: get_config_default("num_experts_per_token")
-    )
-    fasttensors: Optional[bool] = Field(
-        default_factory=lambda: get_config_default("fasttensors")
-    )
+    cache_mode: Optional[str] = None
+    chunk_size: Optional[int] = None
+    prompt_template: Optional[str] = None
+    num_experts_per_token: Optional[int] = None
+    fasttensors: Optional[bool] = None
 
     # Non-config arguments
     draft: Optional[DraftModelLoadRequest] = None
@@ -142,9 +116,11 @@ class ModelLoadRequest(BaseModel):
 
 class EmbeddingModelLoadRequest(BaseModel):
     name: str
 
+    # Set default from the config
     embeddings_device: Optional[str] = Field(
-        default_factory=lambda: get_config_default(
-            "embeddings_device", model_type="embedding"
+        default_factory=lambda: unwrap(
+            config.embeddings.get("embeddings_device"), "cpu"
         )
     )
 

6 main.py
@@ -50,7 +50,7 @@ async def entrypoint_async():
         port = fallback_port
 
     # Initialize auth keys
-    load_auth_keys(config.network.disable_auth)
+    await load_auth_keys(config.network.disable_auth)
 
     gen_logging.broadcast_status()
 
@@ -58,7 +58,7 @@ async def entrypoint_async():
     sampling_override_preset = config.sampling.override_preset
     if sampling_override_preset:
         try:
-            sampling.overrides_from_file(sampling_override_preset)
+            await sampling.overrides_from_file(sampling_override_preset)
         except FileNotFoundError as e:
             logger.warning(str(e))
 
@@ -111,7 +111,7 @@ def entrypoint(arguments: Optional[dict] = None):
         arguments = convert_args_to_dict(parser.parse_args(), parser)
 
     # load config
-    config.load_config(arguments)
+    config.load(arguments)
 
     if do_export_openapi:
         openapi_json = export_openapi()

@@ -68,12 +68,12 @@ cu121 = [
     "torch @ https://download.pytorch.org/whl/cu121/torch-2.3.1%2Bcu121-cp310-cp310-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.10'",
 
     # Exl2
-    "exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.2.0/exllamav2-0.2.0+cu121.torch2.3.1-cp312-cp312-win_amd64.whl ; platform_system == 'Windows' and python_version == '3.12'",
-    "exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.2.0/exllamav2-0.2.0+cu121.torch2.3.1-cp311-cp311-win_amd64.whl ; platform_system == 'Windows' and python_version == '3.11'",
-    "exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.2.0/exllamav2-0.2.0+cu121.torch2.3.1-cp310-cp310-win_amd64.whl ; platform_system == 'Windows' and python_version == '3.10'",
-    "exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.2.0/exllamav2-0.2.0+cu121.torch2.3.1-cp312-cp312-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.12'",
-    "exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.2.0/exllamav2-0.2.0+cu121.torch2.3.1-cp311-cp311-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.11'",
-    "exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.2.0/exllamav2-0.2.0+cu121.torch2.3.1-cp310-cp310-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.10'",
+    "exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.2.1/exllamav2-0.2.1+cu121.torch2.3.1-cp312-cp312-win_amd64.whl ; platform_system == 'Windows' and python_version == '3.12'",
+    "exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.2.1/exllamav2-0.2.1+cu121.torch2.3.1-cp311-cp311-win_amd64.whl ; platform_system == 'Windows' and python_version == '3.11'",
+    "exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.2.1/exllamav2-0.2.1+cu121.torch2.3.1-cp310-cp310-win_amd64.whl ; platform_system == 'Windows' and python_version == '3.10'",
+    "exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.2.1/exllamav2-0.2.1+cu121.torch2.3.1-cp312-cp312-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.12'",
+    "exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.2.1/exllamav2-0.2.1+cu121.torch2.3.1-cp311-cp311-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.11'",
+    "exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.2.1/exllamav2-0.2.1+cu121.torch2.3.1-cp310-cp310-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.10'",
 
     # Windows FA2 from https://github.com/bdashore3/flash-attention/releases
     "flash_attn @ https://github.com/bdashore3/flash-attention/releases/download/v2.6.3/flash_attn-2.6.3+cu123torch2.3.1cxx11abiFALSE-cp312-cp312-win_amd64.whl ; platform_system == 'Windows' and python_version == '3.12'",
@@ -95,12 +95,12 @@ cu118 = [
     "torch @ https://download.pytorch.org/whl/cu118/torch-2.3.1%2Bcu118-cp310-cp310-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.10'",
 
     # Exl2
-    "exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.2.0/exllamav2-0.2.0+cu118.torch2.3.1-cp312-cp312-win_amd64.whl ; platform_system == 'Windows' and python_version == '3.12'",
-    "exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.2.0/exllamav2-0.2.0+cu118.torch2.3.1-cp311-cp311-win_amd64.whl ; platform_system == 'Windows' and python_version == '3.11'",
-    "exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.2.0/exllamav2-0.2.0+cu118.torch2.3.1-cp310-cp310-win_amd64.whl ; platform_system == 'Windows' and python_version == '3.10'",
-    "exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.2.0/exllamav2-0.2.0+cu118.torch2.3.1-cp312-cp312-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.12'",
-    "exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.2.0/exllamav2-0.2.0+cu118.torch2.3.1-cp311-cp311-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.11'",
-    "exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.2.0/exllamav2-0.2.0+cu118.torch2.3.1-cp310-cp310-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.10'",
+    "exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.2.1/exllamav2-0.2.1+cu118.torch2.3.1-cp312-cp312-win_amd64.whl ; platform_system == 'Windows' and python_version == '3.12'",
+    "exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.2.1/exllamav2-0.2.1+cu118.torch2.3.1-cp311-cp311-win_amd64.whl ; platform_system == 'Windows' and python_version == '3.11'",
+    "exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.2.1/exllamav2-0.2.1+cu118.torch2.3.1-cp310-cp310-win_amd64.whl ; platform_system == 'Windows' and python_version == '3.10'",
+    "exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.2.1/exllamav2-0.2.1+cu118.torch2.3.1-cp312-cp312-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.12'",
+    "exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.2.1/exllamav2-0.2.1+cu118.torch2.3.1-cp311-cp311-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.11'",
+    "exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.2.1/exllamav2-0.2.1+cu118.torch2.3.1-cp310-cp310-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.10'",
 
     # Linux FA2 from https://github.com/Dao-AILab/flash-attention/releases
     "flash_attn @ https://github.com/Dao-AILab/flash-attention/releases/download/v2.6.3/flash_attn-2.6.3+cu118torch2.3cxx11abiFALSE-cp312-cp312-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.12'",
@@ -119,9 +119,9 @@ amd = [
     "torch @ https://download.pytorch.org/whl/rocm6.0/torch-2.3.1%2Brocm6.0-cp310-cp310-linux_x86_64.whl ; python_version == '3.10'",
 
     # Exl2
-    "exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.2.0/exllamav2-0.2.0+rocm6.0.torch2.3.1-cp312-cp312-linux_x86_64.whl ; python_version == '3.12'",
-    "exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.2.0/exllamav2-0.2.0+rocm6.0.torch2.3.1-cp311-cp311-linux_x86_64.whl ; python_version == '3.11'",
-    "exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.2.0/exllamav2-0.2.0+rocm6.0.torch2.3.1-cp310-cp310-linux_x86_64.whl ; python_version == '3.10'",
+    "exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.2.1/exllamav2-0.2.1+rocm6.0.torch2.3.1-cp312-cp312-linux_x86_64.whl ; python_version == '3.12'",
+    "exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.2.1/exllamav2-0.2.1+rocm6.0.torch2.3.1-cp311-cp311-linux_x86_64.whl ; python_version == '3.11'",
+    "exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.2.1/exllamav2-0.2.1+rocm6.0.torch2.3.1-cp310-cp310-linux_x86_64.whl ; python_version == '3.10'",
 ]
 
 # MARK: Ruff options

@@ -97,6 +97,24 @@ penalty_range:
     override: -1
     force: false
 
+# MARK: DRY
+dry_multiplier:
+    override: 0.0
+    force: false
+dry_base:
+    override: 0.0
+    force: false
+dry_allowed_length:
+    override: 0
+    force: false
+dry_range:
+    override: 0
+    force: false
+dry_sequence_breakers:
+    override: []
+    force: false
+    additive: false
+
 # MARK: Mirostat
 mirostat_mode:
     override: 0

17 start.py
@@ -234,15 +234,16 @@ if __name__ == "__main__":
     if first_run:
         start_options["first_run_done"] = True
 
-    # Save start options
-    with open("start_options.json", "w") as start_file:
-        start_file.write(json.dumps(start_options))
+        # Save start options
+        with open("start_options.json", "w") as start_file:
+            start_file.write(json.dumps(start_options))
 
-    print(
-        "Successfully wrote your start script options to `start_options.json`. \n"
-        "If something goes wrong, editing or deleting the file "
-        "will reinstall TabbyAPI as a first-time user."
-    )
+        print(
+            "Successfully wrote your start script options to "
+            "`start_options.json`. \n"
+            "If something goes wrong, editing or deleting the file "
+            "will reinstall TabbyAPI as a first-time user."
+        )
 
     # Import entrypoint after installing all requirements
     try:

@@ -1,5 +1,5 @@
 {# Metadata #}
-{% set stop_strings = ["### Instruction:", "### Input:", "### Response:"] %}
+{%- set stop_strings = ["### Instruction:", "### Input:", "### Response:"] -%}
 
 {# Template #}
 {{ (messages|selectattr('role', 'equalto', 'system')|list|last).content|trim if (messages|selectattr('role', 'equalto', 'system')|list) else '' }}

@@ -1,5 +1,5 @@
 {# Metadata #}
-{% set stop_strings = ["<|im_start|>", "<|im_end|>"] %}
+{%- set stop_strings = ["<|im_start|>", "<|im_end|>"] -%}
 
 {# Template #}
 {% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content']}}{% if (loop.last and add_generation_prompt) or not loop.last %}{{ '<|im_end|>' + '\n'}}{% endif %}{% endfor %}

@@ -1,8 +1,8 @@
 {# Metadata #}
-{% set stop_strings = ["<|im_start|>", "<|im_end|>"] %}
-{% set message_roles = ['system', 'user', 'assistant', 'tool'] %}
-{% set tool_start = "<|tool_start|>" %}
-{% set tool_end = "<|tool_end|>" %}
+{%- set stop_strings = ["<|im_start|>", "<|im_end|>"] -%}
+{%- set message_roles = ['system', 'user', 'assistant', 'tool'] -%}
+{%- set tool_start = "<|tool_start|>" -%}
+{%- set tool_end = "<|tool_end|>" -%}
 {%- set start_header = "<|start_header_id|>" -%}
 {%- set end_header = "<|end_header_id|>\n" -%}
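
The template edits replace {% set ... %} with {%- set ... -%} so the metadata block stops leaking blank lines into the rendered prompt. A quick demonstration of Jinja2 whitespace control, independent of TabbyAPI's template loader:

    from jinja2 import Template

    plain = Template('{% set x = 1 %}\nHello').render()
    trimmed = Template('{%- set x = 1 -%}\nHello').render()

    print(repr(plain))    # '\nHello'
    print(repr(trimmed))  # 'Hello'
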