diff --git a/.dockerignore b/.dockerignore new file mode 100644 index 0000000..e8118ac --- /dev/null +++ b/.dockerignore @@ -0,0 +1,20 @@ +.ruff_cache/ +**/__pycache__/ +venv +.git +.gitignore +.github + +# Ignore specific application files +models/ +loras/ +config.yml +config_sample.yml +api_tokens.yml +api_tokens_sample.yml +*.bat +*.sh +update_scripts +readme.md +colab +start.py \ No newline at end of file diff --git a/.gitignore b/.gitignore index 49aa517..5e5c9ee 100644 --- a/.gitignore +++ b/.gitignore @@ -192,7 +192,11 @@ templates/* !templates/place_your_templates_here.txt !templates/alpaca.jinja !templates/chatml.jinja -!templates/chatml_with_headers_tool_calling.jinja + +# Tool calling templates folder +templates/tool_calls/* +!templates/tool_calls +!templates/tool_calls/chatml_with_headers.jinja # Sampler overrides folder sampler_overrides/* diff --git a/backends/exllamav2/model.py b/backends/exllamav2/model.py index 7fe08db..4aedf75 100644 --- a/backends/exllamav2/model.py +++ b/backends/exllamav2/model.py @@ -1,5 +1,6 @@ """The model container class for ExLlamaV2 models.""" +import aiofiles import asyncio import gc import math @@ -7,6 +8,7 @@ import pathlib import traceback import torch import uuid +from copy import deepcopy from exllamav2 import ( ExLlamaV2, ExLlamaV2Config, @@ -105,13 +107,17 @@ class ExllamaV2Container: load_lock: asyncio.Lock = asyncio.Lock() load_condition: asyncio.Condition = asyncio.Condition() - def __init__(self, model_directory: pathlib.Path, quiet=False, **kwargs): + @classmethod + async def create(cls, model_directory: pathlib.Path, quiet=False, **kwargs): """ - Primary initializer for model container. + Primary asynchronous initializer for model container. Kwargs are located in config_sample.yml """ + # Create a new instance as a "fake self" + self = cls() + self.quiet = quiet # Initialize config @@ -154,13 +160,13 @@ class ExllamaV2Container: self.draft_config.prepare() # Create the hf_config - self.hf_config = HuggingFaceConfig.from_file(model_directory) + self.hf_config = await HuggingFaceConfig.from_file(model_directory) # Load generation config overrides generation_config_path = model_directory / "generation_config.json" if generation_config_path.exists(): try: - self.generation_config = GenerationConfig.from_file( + self.generation_config = await GenerationConfig.from_file( generation_config_path.parent ) except Exception: @@ -170,7 +176,7 @@ class ExllamaV2Container: ) # Apply a model's config overrides while respecting user settings - kwargs = self.set_model_overrides(**kwargs) + kwargs = await self.set_model_overrides(**kwargs) # MARK: User configuration @@ -319,7 +325,7 @@ class ExllamaV2Container: self.cache_size = self.config.max_seq_len # Try to set prompt template - self.prompt_template = self.find_prompt_template( + self.prompt_template = await self.find_prompt_template( kwargs.get("prompt_template"), model_directory ) @@ -372,7 +378,10 @@ class ExllamaV2Container: self.draft_config.max_input_len = chunk_size self.draft_config.max_attention_size = chunk_size**2 - def set_model_overrides(self, **kwargs): + # Return the created instance + return self + + async def set_model_overrides(self, **kwargs): """Sets overrides from a model folder's config yaml.""" override_config_path = self.model_dir / "tabby_config.yml" @@ -380,8 +389,11 @@ class ExllamaV2Container: if not override_config_path.exists(): return kwargs - with open(override_config_path, "r", encoding="utf8") as override_config_file: - override_args = 
unwrap(yaml.safe_load(override_config_file), {}) + async with aiofiles.open( + override_config_path, "r", encoding="utf8" + ) as override_config_file: + contents = await override_config_file.read() + override_args = unwrap(yaml.safe_load(contents), {}) # Merge draft overrides beforehand draft_override_args = unwrap(override_args.get("draft"), {}) @@ -392,7 +404,7 @@ class ExllamaV2Container: merged_kwargs = {**override_args, **kwargs} return merged_kwargs - def find_prompt_template(self, prompt_template_name, model_directory): + async def find_prompt_template(self, prompt_template_name, model_directory): """Tries to find a prompt template using various methods.""" logger.info("Attempting to load a prompt template if present.") @@ -400,26 +412,37 @@ class ExllamaV2Container: find_template_functions = [ lambda: PromptTemplate.from_model_json( pathlib.Path(self.config.model_dir) / "tokenizer_config.json", - "chat_template", + key="chat_template", ), lambda: PromptTemplate.from_file(find_template_from_model(model_directory)), ] + # Find the template in the model directory if it exists + model_dir_template_path = ( + pathlib.Path(self.config.model_dir) / "tabby_template.jinja" + ) + if model_dir_template_path.exists(): + find_template_functions[:0] = [ + lambda: PromptTemplate.from_file(model_dir_template_path) + ] + # Add lookup from prompt template name if provided if prompt_template_name: find_template_functions[:0] = [ - lambda: PromptTemplate.from_file(prompt_template_name), + lambda: PromptTemplate.from_file( + pathlib.Path("templates") / prompt_template_name + ), lambda: PromptTemplate.from_model_json( pathlib.Path(self.config.model_dir) / "tokenizer_config.json", - "chat_template", - prompt_template_name, + key="chat_template", + name=prompt_template_name, ), ] # Continue on exception since functions are tried as they fail for template_func in find_template_functions: try: - prompt_template = template_func() + prompt_template = await template_func() if prompt_template is not None: return prompt_template except TemplateLoadError as e: @@ -944,6 +967,14 @@ class ExllamaV2Container: Meant for dev wheels! """ + if unwrap(kwargs.get("dry_allowed_length"), 0) > 0 and not hasattr( + ExLlamaV2Sampler.Settings, "dry_multiplier" + ): + logger.warning( + "DRY sampling is not supported by the currently " + "installed ExLlamaV2 version." + ) + return kwargs async def generate_gen( @@ -1035,6 +1066,7 @@ class ExllamaV2Container: "Please use an ampere (30 series) or higher GPU for CFG support." ) + # Penalties gen_settings.token_repetition_penalty = unwrap( kwargs.get("repetition_penalty"), 1.0 ) @@ -1070,6 +1102,32 @@ class ExllamaV2Container: kwargs.get("repetition_decay"), fallback_decay, 0 ) + # DRY options + dry_multiplier = unwrap(kwargs.get("dry_multiplier"), 0.0) + + # < 0 = disabled + if dry_multiplier > 0: + gen_settings.dry_multiplier = dry_multiplier + + # TODO: Maybe set the "sane" defaults instead? 
+ gen_settings.dry_allowed_length = unwrap( + kwargs.get("dry_allowed_length"), 0 + ) + gen_settings.dry_base = unwrap(kwargs.get("dry_base"), 0.0) + + # Exl2 has dry_range as 0 for unlimited unlike -1 for penalty_range + # Use max_seq_len as the fallback to stay consistent + gen_settings.dry_range = unwrap( + kwargs.get("dry_range"), self.config.max_seq_len + ) + + # Tokenize sequence breakers + dry_sequence_breakers_json = kwargs.get("dry_sequence_breakers") + if dry_sequence_breakers_json: + gen_settings.dry_sequence_breakers = { + self.encode_tokens(s)[-1] for s in dry_sequence_breakers_json + } + # Initialize grammar handler grammar_handler = ExLlamaV2Grammar() @@ -1130,7 +1188,8 @@ class ExllamaV2Container: ) # Store the gen settings for logging purposes - gen_settings_log_dict = vars(gen_settings) + # Deepcopy to save a snapshot of vars + gen_settings_log_dict = deepcopy(vars(gen_settings)) # Set banned tokens banned_tokens = unwrap(kwargs.get("banned_tokens"), []) diff --git a/backends/exllamav2/utils.py b/backends/exllamav2/utils.py index 5b1d042..4c192b2 100644 --- a/backends/exllamav2/utils.py +++ b/backends/exllamav2/utils.py @@ -8,7 +8,7 @@ from loguru import logger def check_exllama_version(): """Verifies the exllama version""" - required_version = version.parse("0.1.9") + required_version = version.parse("0.2.1") current_version = version.parse(package_version("exllamav2").split("+")[0]) unsupported_message = ( diff --git a/backends/infinity/model.py b/backends/infinity/model.py index 35a4df4..c48a42c 100644 --- a/backends/infinity/model.py +++ b/backends/infinity/model.py @@ -7,13 +7,13 @@ from typing import List, Optional from common.utils import unwrap # Conditionally import infinity to sidestep its logger -# TODO: Make this prettier +has_infinity_emb: bool = False try: from infinity_emb import EngineArgs, AsyncEmbeddingEngine has_infinity_emb = True except ImportError: - has_infinity_emb = False + pass class InfinityContainer: diff --git a/common/auth.py b/common/auth.py index 174208d..6fcfec9 100644 --- a/common/auth.py +++ b/common/auth.py @@ -3,6 +3,7 @@ This method of authorization is pretty insecure, but since TabbyAPI is a local application, it should be fine. """ +import aiofiles import secrets import yaml from fastapi import Header, HTTPException, Request @@ -40,7 +41,7 @@ AUTH_KEYS: Optional[AuthKeys] = None DISABLE_AUTH: bool = False -def load_auth_keys(disable_from_config: bool): +async def load_auth_keys(disable_from_config: bool): """Load the authentication keys from api_tokens.yml. 
If the file does not exist, generate new keys and save them to api_tokens.yml.""" global AUTH_KEYS @@ -57,8 +58,9 @@ def load_auth_keys(disable_from_config: bool): return try: - with open("api_tokens.yml", "r", encoding="utf8") as auth_file: - auth_keys_dict = yaml.safe_load(auth_file) + async with aiofiles.open("api_tokens.yml", "r", encoding="utf8") as auth_file: + contents = await auth_file.read() + auth_keys_dict = yaml.safe_load(contents) AUTH_KEYS = AuthKeys.model_validate(auth_keys_dict) except FileNotFoundError: new_auth_keys = AuthKeys( @@ -66,8 +68,11 @@ def load_auth_keys(disable_from_config: bool): ) AUTH_KEYS = new_auth_keys - with open("api_tokens.yml", "w", encoding="utf8") as auth_file: - yaml.safe_dump(AUTH_KEYS.model_dump(), auth_file, default_flow_style=False) + async with aiofiles.open("api_tokens.yml", "w", encoding="utf8") as auth_file: + new_auth_yaml = yaml.safe_dump( + AUTH_KEYS.model_dump(), default_flow_style=False + ) + await auth_file.write(new_auth_yaml) logger.info( f"Your API key is: {AUTH_KEYS.api_key}\n" diff --git a/common/model.py b/common/model.py index 4da1d90..5fdfc5b 100644 --- a/common/model.py +++ b/common/model.py @@ -13,7 +13,6 @@ from typing import Optional from common.logger import get_loading_progress_bar from common.networking import handle_request_error from common.tabby_config import config -from common.utils import unwrap from endpoints.utils import do_export_openapi if not do_export_openapi: @@ -67,7 +66,11 @@ async def load_model_gen(model_path: pathlib.Path, **kwargs): logger.info("Unloading existing model.") await unload_model() - container = ExllamaV2Container(model_path.resolve(), False, **kwargs) + # Merge with config defaults + kwargs = {**config.model_defaults, **kwargs} + + # Create a new container + container = await ExllamaV2Container.create(model_path.resolve(), False, **kwargs) model_type = "draft" if container.draft_config else "model" load_status = container.load_gen(load_progress, **kwargs) @@ -149,25 +152,6 @@ async def unload_embedding_model(): embeddings_container = None -# FIXME: Maybe make this a one-time function instead of a dynamic default -def get_config_default(key: str, model_type: str = "model"): - """Fetches a default value from model config if allowed by the user.""" - - default_keys = unwrap(config.model.use_as_default, []) - - # Add extra keys to defaults - default_keys.append("embeddings_device") - - if key in default_keys: - # Is this a draft model load parameter? 
- if model_type == "draft": - return config.draft_model.get(key) - elif model_type == "embedding": - return config.embeddings.get(key) - else: - return config.model.get(key) - - async def check_model_container(): """FastAPI depends that checks if a model isn't loaded or currently loading.""" diff --git a/common/sampling.py b/common/sampling.py index 56c5b34..a7da3ca 100644 --- a/common/sampling.py +++ b/common/sampling.py @@ -1,5 +1,7 @@ """Common functions for sampling parameters""" +import aiofiles +import json import pathlib import yaml from copy import deepcopy @@ -140,6 +142,28 @@ class BaseSamplerRequest(BaseModel): default_factory=lambda: get_default_sampler_value("repetition_decay", 0) ) + dry_multiplier: Optional[float] = Field( + default_factory=lambda: get_default_sampler_value("dry_multiplier", 0.0) + ) + + dry_base: Optional[float] = Field( + default_factory=lambda: get_default_sampler_value("dry_base", 0.0) + ) + + dry_allowed_length: Optional[int] = Field( + default_factory=lambda: get_default_sampler_value("dry_allowed_length", 0) + ) + + dry_range: Optional[int] = Field( + default_factory=lambda: get_default_sampler_value("dry_range", 0), + alias=AliasChoices("dry_range", "dry_penalty_last_n"), + description=("Aliases: dry_penalty_last_n"), + ) + + dry_sequence_breakers: Optional[Union[str, List[str]]] = Field( + default_factory=lambda: get_default_sampler_value("dry_sequence_breakers", []) + ) + mirostat_mode: Optional[int] = Field( default_factory=lambda: get_default_sampler_value("mirostat_mode", 0) ) @@ -305,6 +329,17 @@ class BaseSamplerRequest(BaseModel): int(x) for x in self.allowed_tokens.split(",") if x.isdigit() ] + # Convert sequence breakers into an array of strings + # NOTE: This sampler sucks to parse. + if self.dry_sequence_breakers and isinstance(self.dry_sequence_breakers, str): + if not self.dry_sequence_breakers.startswith("["): + self.dry_sequence_breakers = f"[{self.dry_sequence_breakers}]" + + try: + self.dry_sequence_breakers = json.loads(self.dry_sequence_breakers) + except Exception: + self.dry_sequence_breakers = [] + gen_params = { "max_tokens": self.max_tokens, "min_tokens": self.min_tokens, @@ -335,6 +370,11 @@ class BaseSamplerRequest(BaseModel): "presence_penalty": self.presence_penalty, "repetition_penalty": self.repetition_penalty, "penalty_range": self.penalty_range, + "dry_multiplier": self.dry_multiplier, + "dry_base": self.dry_base, + "dry_allowed_length": self.dry_allowed_length, + "dry_sequence_breakers": self.dry_sequence_breakers, + "dry_range": self.dry_range, "repetition_decay": self.repetition_decay, "mirostat": self.mirostat_mode == 2, "mirostat_tau": self.mirostat_tau, @@ -368,14 +408,15 @@ def overrides_from_dict(new_overrides: dict): raise TypeError("New sampler overrides must be a dict!") -def overrides_from_file(preset_name: str): +async def overrides_from_file(preset_name: str): """Fetches an override preset from a file""" preset_path = pathlib.Path(f"sampler_overrides/{preset_name}.yml") if preset_path.exists(): overrides_container.selected_preset = preset_path.stem - with open(preset_path, "r", encoding="utf8") as raw_preset: - preset = yaml.safe_load(raw_preset) + async with aiofiles.open(preset_path, "r", encoding="utf8") as raw_preset: + contents = await raw_preset.read() + preset = yaml.safe_load(contents) overrides_from_dict(preset) logger.info("Applied sampler overrides from file.") diff --git a/common/tabby_config.py b/common/tabby_config.py index a379ebb..cd7cb14 100644 --- a/common/tabby_config.py +++ 
b/common/tabby_config.py @@ -10,8 +10,13 @@ import common.config_models class TabbyConfig(tabby_config_model): - def load_config(self, arguments: Optional[dict] = None): - """load the global application config""" + + # Persistent defaults + # TODO: make this pydantic? + model_defaults: dict = {} + + def load(self, arguments: Optional[dict] = None): + """Synchronously loads the global application config""" # config is applied in order of items in the list configs = [ @@ -28,6 +33,17 @@ class TabbyConfig(tabby_config_model): setattr(self, field, model.parse_obj(value)) + # Set model defaults dict once to prevent on-demand reconstruction + # TODO: clean this up a bit + for field in self.model.use_as_default: + if hasattr(self.model, field): + self.model_defaults[field] = getattr(config.model, field) + elif hasattr(self.draft_model, field): + self.model_defaults[field] = getattr(config.draft_model, field) + else: + # TODO: show an error + pass + def _from_file(self, config_path: pathlib.Path): """loads config from a given file path""" @@ -53,7 +69,7 @@ class TabbyConfig(tabby_config_model): config_override = unwrap(args.get("options", {}).get("config")) if config_override: logger.info("Config file override detected in args.") - config = self.from_file(pathlib.Path(config_override)) + config = self._from_file(pathlib.Path(config_override)) return config # Return early if loading from file for key in tabby_config_model.model_fields.keys(): @@ -85,5 +101,5 @@ class TabbyConfig(tabby_config_model): return config -# Create an empty instance of the shared var to make sure nothing breaks +# Create an empty instance of the config class config: TabbyConfig = TabbyConfig() diff --git a/common/templating.py b/common/templating.py index 47299ff..1200e5c 100644 --- a/common/templating.py +++ b/common/templating.py @@ -1,10 +1,12 @@ """Small replication of AutoTokenizer's chat template system for efficiency""" +import aiofiles import json import pathlib from importlib.metadata import version as package_version from typing import List, Optional from jinja2 import Template, TemplateError +from jinja2.ext import loopcontrols from jinja2.sandbox import ImmutableSandboxedEnvironment from loguru import logger from packaging import version @@ -32,7 +34,10 @@ class PromptTemplate: raw_template: str template: Template environment: ImmutableSandboxedEnvironment = ImmutableSandboxedEnvironment( - trim_blocks=True, lstrip_blocks=True, enable_async=True + trim_blocks=True, + lstrip_blocks=True, + enable_async=True, + extensions=[loopcontrols], ) metadata: Optional[TemplateMetadata] = None @@ -106,32 +111,42 @@ class PromptTemplate: self.template = self.compile(raw_template) @classmethod - def from_file(self, prompt_template_name: str): + async def from_file(cls, template_path: pathlib.Path): """Get a template from a jinja file.""" - template_path = pathlib.Path(f"templates/{prompt_template_name}.jinja") + # Add the jinja extension if it isn't provided + if template_path.suffix.endswith(".jinja"): + template_name = template_path.name.split(".jinja")[0] + else: + template_name = template_path.name + template_path = template_path.with_suffix(".jinja") + if template_path.exists(): - with open(template_path, "r", encoding="utf8") as raw_template_stream: - return PromptTemplate( - name=prompt_template_name, - raw_template=raw_template_stream.read(), + async with aiofiles.open( + template_path, "r", encoding="utf8" + ) as raw_template_stream: + contents = await raw_template_stream.read() + return cls( + name=template_name, + 
raw_template=contents, ) else: # Let the user know if the template file isn't found raise TemplateLoadError( - f'Chat template "{prompt_template_name}" not found in files.' + f'Chat template "{template_name}" not found in files.' ) @classmethod - def from_model_json( - self, json_path: pathlib.Path, key: str, name: Optional[str] = None + async def from_model_json( + cls, json_path: pathlib.Path, key: str, name: Optional[str] = None ): """Get a template from a JSON file. Requires a key and template name""" if not json_path.exists(): raise TemplateLoadError(f'Model JSON path "{json_path}" not found.') - with open(json_path, "r", encoding="utf8") as config_file: - model_config = json.load(config_file) + async with aiofiles.open(json_path, "r", encoding="utf8") as config_file: + contents = await config_file.read() + model_config = json.loads(contents) chat_template = model_config.get(key) if not chat_template: @@ -162,7 +177,7 @@ class PromptTemplate: ) else: # Can safely assume the chat template is the old style - return PromptTemplate( + return cls( name="from_tokenizer_config", raw_template=chat_template, ) diff --git a/common/transformers_utils.py b/common/transformers_utils.py index 9db8ad2..c00fef4 100644 --- a/common/transformers_utils.py +++ b/common/transformers_utils.py @@ -1,3 +1,4 @@ +import aiofiles import json import pathlib from typing import List, Optional, Union @@ -15,15 +16,16 @@ class GenerationConfig(BaseModel): bad_words_ids: Optional[List[List[int]]] = None @classmethod - def from_file(self, model_directory: pathlib.Path): + async def from_file(cls, model_directory: pathlib.Path): """Create an instance from a generation config file.""" generation_config_path = model_directory / "generation_config.json" - with open( + async with aiofiles.open( generation_config_path, "r", encoding="utf8" ) as generation_config_json: - generation_config_dict = json.load(generation_config_json) - return self.model_validate(generation_config_dict) + contents = await generation_config_json.read() + generation_config_dict = json.loads(contents) + return cls.model_validate(generation_config_dict) def eos_tokens(self): """Wrapper method to fetch EOS tokens.""" @@ -43,13 +45,16 @@ class HuggingFaceConfig(BaseModel): badwordsids: Optional[str] = None @classmethod - def from_file(self, model_directory: pathlib.Path): + async def from_file(cls, model_directory: pathlib.Path): """Create an instance from a generation config file.""" hf_config_path = model_directory / "config.json" - with open(hf_config_path, "r", encoding="utf8") as hf_config_json: - hf_config_dict = json.load(hf_config_json) - return self.model_validate(hf_config_dict) + async with aiofiles.open( + hf_config_path, "r", encoding="utf8" + ) as hf_config_json: + contents = await hf_config_json.read() + hf_config_dict = json.loads(contents) + return cls.model_validate(hf_config_dict) def get_badwordsids(self): """Wrapper method to fetch badwordsids.""" diff --git a/config_sample.yml b/config_sample.yml index 85bb1df..3b4f247 100644 --- a/config_sample.yml +++ b/config_sample.yml @@ -83,6 +83,9 @@ model: # Enable this if the program is looking for a specific OAI model #use_dummy_models: False + # Allow direct loading of models from a completion or chat completion request + inline_model_loading: False + # An initial model to load. Make sure the model is located in the model directory! # A model can be loaded later via the API. # REQUIRED: This must be filled out to load a model on startup! 
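A usage sketch (not part of the diff): the new `inline_model_loading` option above, together with the DRY sampler fields added to `BaseSamplerRequest`, lets a single admin-keyed completion request switch models and tune DRY sampling. The endpoint and JSON field names come from this diff; the `x-admin-key` header name, base URL, and model folder name are illustrative assumptions.

import requests

# Hypothetical values for illustration only
BASE_URL = "http://127.0.0.1:5000"
ADMIN_KEY = "YOUR_ADMIN_KEY"

payload = {
    # Inline model load: needs an admin key and inline_model_loading: True in config.yml
    "model": "MyModel-exl2",
    "prompt": "Once upon a time",
    "max_tokens": 64,
    # DRY sampling (new fields): a multiplier of 0 leaves the sampler disabled
    "dry_multiplier": 0.8,
    "dry_base": 1.75,
    "dry_allowed_length": 2,
    "dry_range": 0,  # 0 = unlimited range in exllamav2
    # Also accepted as a JSON-style string, per the parsing added in common/sampling.py
    "dry_sequence_breakers": ["\n", ":", "\"", "*"],
}

response = requests.post(
    f"{BASE_URL}/v1/completions",
    headers={"x-admin-key": ADMIN_KEY},
    json=payload,
)
print(response.json())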
diff --git a/docker/.dockerignore b/docker/.dockerignore deleted file mode 100644 index ae8a12c..0000000 --- a/docker/.dockerignore +++ /dev/null @@ -1,6 +0,0 @@ -models/ -loras/ -.ruff_cache/ -**/__pycache__/ -config.yml -api_tokens.yml \ No newline at end of file diff --git a/docker/Dockerfile b/docker/Dockerfile index 0b709b5..f3587cc 100644 --- a/docker/Dockerfile +++ b/docker/Dockerfile @@ -22,6 +22,8 @@ COPY pyproject.toml . # Install packages specified in pyproject.toml cu121 RUN pip3 install --no-cache-dir .[cu121] +RUN rm pyproject.toml + # Copy the current directory contents into the container COPY . . diff --git a/docker/docker-compose.yml b/docker/docker-compose.yml index d55d857..d27bf47 100644 --- a/docker/docker-compose.yml +++ b/docker/docker-compose.yml @@ -1,11 +1,13 @@ version: '3.8' services: tabbyapi: - build: - context: .. - dockerfile: ./docker/Dockerfile - args: - - DO_PULL=true + # Uncomment this to build a docker image from source + #build: + # context: .. + # dockerfile: ./docker/Dockerfile + + # Comment this to build a docker image from source + image: ghcr.io/theroyallab/tabbyapi:latest ports: - "5000:5000" healthcheck: diff --git a/endpoints/Kobold/router.py b/endpoints/Kobold/router.py index 334bae2..310a380 100644 --- a/endpoints/Kobold/router.py +++ b/endpoints/Kobold/router.py @@ -137,7 +137,7 @@ async def get_version(): async def get_extra_version(): """Impersonate Koboldcpp.""" - return {"result": "KoboldCpp", "version": "1.61"} + return {"result": "KoboldCpp", "version": "1.71"} @kai_router.get("/config/soft_prompts_list") diff --git a/endpoints/OAI/router.py b/endpoints/OAI/router.py index 7cf08d7..f120e4d 100644 --- a/endpoints/OAI/router.py +++ b/endpoints/OAI/router.py @@ -22,6 +22,7 @@ from endpoints.OAI.utils.chat_completion import ( ) from endpoints.OAI.utils.completion import ( generate_completion, + load_inline_model, stream_generate_completion, ) from endpoints.OAI.utils.embeddings import get_embeddings @@ -42,7 +43,7 @@ def setup(): # Completions endpoint @router.post( "/v1/completions", - dependencies=[Depends(check_api_key), Depends(check_model_container)], + dependencies=[Depends(check_api_key)], ) async def completion_request( request: Request, data: CompletionRequest @@ -53,6 +54,18 @@ async def completion_request( If stream = true, this returns an SSE stream. """ + if data.model: + inline_load_task = asyncio.create_task(load_inline_model(data.model, request)) + + await run_with_request_disconnect( + request, + inline_load_task, + disconnect_message=f"Model switch for generation {request.state.id} " + + "cancelled by user.", + ) + else: + await check_model_container() + model_path = model.container.model_dir if isinstance(data.prompt, list): @@ -85,7 +98,7 @@ async def completion_request( # Chat completions endpoint @router.post( "/v1/chat/completions", - dependencies=[Depends(check_api_key), Depends(check_model_container)], + dependencies=[Depends(check_api_key)], ) async def chat_completion_request( request: Request, data: ChatCompletionRequest @@ -96,6 +109,11 @@ async def chat_completion_request( If stream = true, this returns an SSE stream. 
""" + if data.model: + await load_inline_model(data.model, request) + else: + await check_model_container() + if model.container.prompt_template is None: error_message = handle_request_error( "Chat completions are disabled because a prompt template is not set.", diff --git a/endpoints/OAI/types/chat_completion.py b/endpoints/OAI/types/chat_completion.py index 8977792..30ec769 100644 --- a/endpoints/OAI/types/chat_completion.py +++ b/endpoints/OAI/types/chat_completion.py @@ -56,6 +56,7 @@ class ChatCompletionRequest(CommonCompletionRequest): add_generation_prompt: Optional[bool] = True template_vars: Optional[dict] = {} response_prefix: Optional[str] = None + model: Optional[str] = None # tools is follows the format OAI schema, functions is more flexible # both are available in the chat template. diff --git a/endpoints/OAI/utils/completion.py b/endpoints/OAI/utils/completion.py index 52c2bb4..df4bf19 100644 --- a/endpoints/OAI/utils/completion.py +++ b/endpoints/OAI/utils/completion.py @@ -1,4 +1,8 @@ -"""Completion utilities for OAI server.""" +""" +Completion utilities for OAI server. + +Also serves as a common module for completions and chat completions. +""" import asyncio import pathlib @@ -10,12 +14,14 @@ from typing import List, Union from loguru import logger from common import model +from common.auth import get_key_permission from common.networking import ( get_generator_error, handle_request_disconnect, handle_request_error, request_disconnect_loop, ) +from common.tabby_config import config from common.utils import unwrap from endpoints.OAI.types.completion import ( CompletionRequest, @@ -103,6 +109,50 @@ async def _stream_collector( await gen_queue.put(e) +async def load_inline_model(model_name: str, request: Request): + """Load a model from the data.model parameter""" + + # Return if the model container already exists and the model is fully loaded + if ( + model.container + and model.container.model_dir.name == model_name + and model.container.model_loaded + ): + return + + # Inline model loading isn't enabled or the user isn't an admin + if not get_key_permission(request) == "admin": + error_message = handle_request_error( + f"Unable to switch model to {model_name} because " + + "an admin key isn't provided", + exc_info=False, + ).error.message + + raise HTTPException(401, error_message) + + if not unwrap(config.model.get("inline_model_loading"), False): + logger.warning( + f"Unable to switch model to {model_name} because " + '"inline_model_loading" is not True in config.yml.' + ) + + return + + model_path = pathlib.Path(unwrap(config.model.get("model_dir"), "models")) + model_path = model_path / model_name + + # Model path doesn't exist + if not model_path.exists(): + logger.warning( + f"Could not find model path {str(model_path)}. Skipping inline model load." 
+ ) + + return + + # Load the model + await model.load_model(model_path) + + async def stream_generate_completion( data: CompletionRequest, request: Request, model_path: pathlib.Path ): diff --git a/endpoints/core/router.py b/endpoints/core/router.py index 29a615c..325fbad 100644 --- a/endpoints/core/router.py +++ b/endpoints/core/router.py @@ -103,7 +103,7 @@ async def list_draft_models(request: Request) -> ModelList: models = get_model_list(draft_model_path.resolve()) else: - models = await get_current_model_list(is_draft=True) + models = await get_current_model_list(model_type="draft") return models @@ -441,7 +441,8 @@ async def switch_template(data: TemplateSwitchRequest): raise HTTPException(400, error_message) try: - model.container.prompt_template = PromptTemplate.from_file(data.name) + template_path = pathlib.Path("templates") / data.name + model.container.prompt_template = await PromptTemplate.from_file(template_path) except FileNotFoundError as e: error_message = handle_request_error( f"The template name {data.name} doesn't exist. Check the spelling?", @@ -490,7 +491,7 @@ async def switch_sampler_override(data: SamplerOverrideSwitchRequest): if data.preset: try: - sampling.overrides_from_file(data.preset) + await sampling.overrides_from_file(data.preset) except FileNotFoundError as e: error_message = handle_request_error( f"Sampler override preset with name {data.preset} does not exist. " diff --git a/endpoints/core/types/model.py b/endpoints/core/types/model.py index 6966359..f2560b3 100644 --- a/endpoints/core/types/model.py +++ b/endpoints/core/types/model.py @@ -5,7 +5,8 @@ from time import time from typing import List, Literal, Optional, Union from common.config_models import logging_config_model -from common.model import get_config_default +from common.tabby_config import config +from common.utils import unwrap class ModelCardParameters(BaseModel): @@ -51,23 +52,13 @@ class DraftModelLoadRequest(BaseModel): draft_model_name: str # Config arguments - draft_rope_scale: Optional[float] = Field( - default_factory=lambda: get_config_default( - "draft_rope_scale", model_type="draft" - ) - ) + draft_rope_scale: Optional[float] = None draft_rope_alpha: Optional[Union[float, Literal["auto"]]] = Field( description='Automatically calculated if set to "auto"', - default_factory=lambda: get_config_default( - "draft_rope_alpha", model_type="draft" - ), + default=None, examples=[1.0], ) - draft_cache_mode: Optional[str] = Field( - default_factory=lambda: get_config_default( - "draft_cache_mode", model_type="draft" - ) - ) + draft_cache_mode: Optional[str] = None class ModelLoadRequest(BaseModel): @@ -78,62 +69,45 @@ class ModelLoadRequest(BaseModel): # Config arguments - # Max seq len is fetched from config.json of the model by default max_seq_len: Optional[int] = Field( description="Leave this blank to use the model's base sequence length", - default_factory=lambda: get_config_default("max_seq_len"), + default=None, examples=[4096], ) override_base_seq_len: Optional[int] = Field( description=( "Overrides the model's base sequence length. 
" "Leave blank if unsure" ), - default_factory=lambda: get_config_default("override_base_seq_len"), + default=None, examples=[4096], ) cache_size: Optional[int] = Field( description=("Number in tokens, must be greater than or equal to max_seq_len"), - default_factory=lambda: get_config_default("cache_size"), + default=None, examples=[4096], ) - tensor_parallel: Optional[bool] = Field( - default_factory=lambda: get_config_default("tensor_parallel") - ) - gpu_split_auto: Optional[bool] = Field( - default_factory=lambda: get_config_default("gpu_split_auto") - ) - autosplit_reserve: Optional[List[float]] = Field( - default_factory=lambda: get_config_default("autosplit_reserve") - ) + tensor_parallel: Optional[bool] = None + gpu_split_auto: Optional[bool] = None + autosplit_reserve: Optional[List[float]] = None gpu_split: Optional[List[float]] = Field( - default_factory=lambda: get_config_default("gpu_split"), + default=None, examples=[[24.0, 20.0]], ) rope_scale: Optional[float] = Field( description="Automatically pulled from the model's config if not present", - default_factory=lambda: get_config_default("rope_scale"), + default=None, examples=[1.0], ) rope_alpha: Optional[Union[float, Literal["auto"]]] = Field( description='Automatically calculated if set to "auto"', - default_factory=lambda: get_config_default("rope_alpha"), + default=None, examples=[1.0], ) - cache_mode: Optional[str] = Field( - default_factory=lambda: get_config_default("cache_mode") - ) - chunk_size: Optional[int] = Field( - default_factory=lambda: get_config_default("chunk_size") - ) - prompt_template: Optional[str] = Field( - default_factory=lambda: get_config_default("prompt_template") - ) - num_experts_per_token: Optional[int] = Field( - default_factory=lambda: get_config_default("num_experts_per_token") - ) - fasttensors: Optional[bool] = Field( - default_factory=lambda: get_config_default("fasttensors") - ) + cache_mode: Optional[str] = None + chunk_size: Optional[int] = None + prompt_template: Optional[str] = None + num_experts_per_token: Optional[int] = None + fasttensors: Optional[bool] = None # Non-config arguments draft: Optional[DraftModelLoadRequest] = None @@ -142,9 +116,11 @@ class ModelLoadRequest(BaseModel): class EmbeddingModelLoadRequest(BaseModel): name: str + + # Set default from the config embeddings_device: Optional[str] = Field( - default_factory=lambda: get_config_default( - "embeddings_device", model_type="embedding" + default_factory=lambda: unwrap( + config.embeddings.get("embeddings_device"), "cpu" ) ) diff --git a/main.py b/main.py index 6254bf2..6e7943b 100644 --- a/main.py +++ b/main.py @@ -50,7 +50,7 @@ async def entrypoint_async(): port = fallback_port # Initialize auth keys - load_auth_keys(config.network.disable_auth) + await load_auth_keys(config.network.disable_auth) gen_logging.broadcast_status() @@ -58,7 +58,7 @@ async def entrypoint_async(): sampling_override_preset = config.sampling.override_preset if sampling_override_preset: try: - sampling.overrides_from_file(sampling_override_preset) + await sampling.overrides_from_file(sampling_override_preset) except FileNotFoundError as e: logger.warning(str(e)) @@ -111,7 +111,7 @@ def entrypoint(arguments: Optional[dict] = None): arguments = convert_args_to_dict(parser.parse_args(), parser) # load config - config.load_config(arguments) + config.load(arguments) if do_export_openapi: openapi_json = export_openapi() diff --git a/pyproject.toml b/pyproject.toml index b9e80fe..19fcbce 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -68,12 
+68,12 @@ cu121 = [ "torch @ https://download.pytorch.org/whl/cu121/torch-2.3.1%2Bcu121-cp310-cp310-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.10'", # Exl2 - "exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.2.0/exllamav2-0.2.0+cu121.torch2.3.1-cp312-cp312-win_amd64.whl ; platform_system == 'Windows' and python_version == '3.12'", - "exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.2.0/exllamav2-0.2.0+cu121.torch2.3.1-cp311-cp311-win_amd64.whl ; platform_system == 'Windows' and python_version == '3.11'", - "exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.2.0/exllamav2-0.2.0+cu121.torch2.3.1-cp310-cp310-win_amd64.whl ; platform_system == 'Windows' and python_version == '3.10'", - "exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.2.0/exllamav2-0.2.0+cu121.torch2.3.1-cp312-cp312-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.12'", - "exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.2.0/exllamav2-0.2.0+cu121.torch2.3.1-cp311-cp311-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.11'", - "exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.2.0/exllamav2-0.2.0+cu121.torch2.3.1-cp310-cp310-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.10'", + "exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.2.1/exllamav2-0.2.1+cu121.torch2.3.1-cp312-cp312-win_amd64.whl ; platform_system == 'Windows' and python_version == '3.12'", + "exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.2.1/exllamav2-0.2.1+cu121.torch2.3.1-cp311-cp311-win_amd64.whl ; platform_system == 'Windows' and python_version == '3.11'", + "exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.2.1/exllamav2-0.2.1+cu121.torch2.3.1-cp310-cp310-win_amd64.whl ; platform_system == 'Windows' and python_version == '3.10'", + "exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.2.1/exllamav2-0.2.1+cu121.torch2.3.1-cp312-cp312-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.12'", + "exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.2.1/exllamav2-0.2.1+cu121.torch2.3.1-cp311-cp311-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.11'", + "exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.2.1/exllamav2-0.2.1+cu121.torch2.3.1-cp310-cp310-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.10'", # Windows FA2 from https://github.com/bdashore3/flash-attention/releases "flash_attn @ https://github.com/bdashore3/flash-attention/releases/download/v2.6.3/flash_attn-2.6.3+cu123torch2.3.1cxx11abiFALSE-cp312-cp312-win_amd64.whl ; platform_system == 'Windows' and python_version == '3.12'", @@ -95,12 +95,12 @@ cu118 = [ "torch @ https://download.pytorch.org/whl/cu118/torch-2.3.1%2Bcu118-cp310-cp310-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.10'", # Exl2 - "exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.2.0/exllamav2-0.2.0+cu118.torch2.3.1-cp312-cp312-win_amd64.whl ; platform_system == 'Windows' and python_version == '3.12'", - "exllamav2 @ 
https://github.com/turboderp/exllamav2/releases/download/v0.2.0/exllamav2-0.2.0+cu118.torch2.3.1-cp311-cp311-win_amd64.whl ; platform_system == 'Windows' and python_version == '3.11'", - "exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.2.0/exllamav2-0.2.0+cu118.torch2.3.1-cp310-cp310-win_amd64.whl ; platform_system == 'Windows' and python_version == '3.10'", - "exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.2.0/exllamav2-0.2.0+cu118.torch2.3.1-cp312-cp312-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.12'", - "exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.2.0/exllamav2-0.2.0+cu118.torch2.3.1-cp311-cp311-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.11'", - "exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.2.0/exllamav2-0.2.0+cu118.torch2.3.1-cp310-cp310-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.10'", + "exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.2.1/exllamav2-0.2.1+cu118.torch2.3.1-cp312-cp312-win_amd64.whl ; platform_system == 'Windows' and python_version == '3.12'", + "exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.2.1/exllamav2-0.2.1+cu118.torch2.3.1-cp311-cp311-win_amd64.whl ; platform_system == 'Windows' and python_version == '3.11'", + "exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.2.1/exllamav2-0.2.1+cu118.torch2.3.1-cp310-cp310-win_amd64.whl ; platform_system == 'Windows' and python_version == '3.10'", + "exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.2.1/exllamav2-0.2.1+cu118.torch2.3.1-cp312-cp312-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.12'", + "exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.2.1/exllamav2-0.2.1+cu118.torch2.3.1-cp311-cp311-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.11'", + "exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.2.1/exllamav2-0.2.1+cu118.torch2.3.1-cp310-cp310-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.10'", # Linux FA2 from https://github.com/Dao-AILab/flash-attention/releases "flash_attn @ https://github.com/Dao-AILab/flash-attention/releases/download/v2.6.3/flash_attn-2.6.3+cu118torch2.3cxx11abiFALSE-cp312-cp312-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.12'", @@ -119,9 +119,9 @@ amd = [ "torch @ https://download.pytorch.org/whl/rocm6.0/torch-2.3.1%2Brocm6.0-cp310-cp310-linux_x86_64.whl ; python_version == '3.10'", # Exl2 - "exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.2.0/exllamav2-0.2.0+rocm6.0.torch2.3.1-cp312-cp312-linux_x86_64.whl ; python_version == '3.12'", - "exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.2.0/exllamav2-0.2.0+rocm6.0.torch2.3.1-cp311-cp311-linux_x86_64.whl ; python_version == '3.11'", - "exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.2.0/exllamav2-0.2.0+rocm6.0.torch2.3.1-cp310-cp310-linux_x86_64.whl ; python_version == '3.10'", + "exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.2.1/exllamav2-0.2.1+rocm6.0.torch2.3.1-cp312-cp312-linux_x86_64.whl ; python_version == 
'3.12'", + "exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.2.1/exllamav2-0.2.1+rocm6.0.torch2.3.1-cp311-cp311-linux_x86_64.whl ; python_version == '3.11'", + "exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.2.1/exllamav2-0.2.1+rocm6.0.torch2.3.1-cp310-cp310-linux_x86_64.whl ; python_version == '3.10'", ] # MARK: Ruff options diff --git a/sampler_overrides/sample_preset.yml b/sampler_overrides/sample_preset.yml index b20b042..e01c2b7 100644 --- a/sampler_overrides/sample_preset.yml +++ b/sampler_overrides/sample_preset.yml @@ -97,6 +97,24 @@ penalty_range: override: -1 force: false +# MARK: DRY +dry_multiplier: + override: 0.0 + force: false +dry_base: + override: 0.0 + force: false +dry_allowed_length: + override: 0 + force: false +dry_range: + override: 0 + force: false +dry_sequence_breakers: + override: [] + force: false + additive: false + # MARK: Mirostat mirostat_mode: override: 0 diff --git a/start.py b/start.py index 490570e..7e7776d 100644 --- a/start.py +++ b/start.py @@ -234,15 +234,16 @@ if __name__ == "__main__": if first_run: start_options["first_run_done"] = True - # Save start options - with open("start_options.json", "w") as start_file: - start_file.write(json.dumps(start_options)) + # Save start options + with open("start_options.json", "w") as start_file: + start_file.write(json.dumps(start_options)) - print( - "Successfully wrote your start script options to `start_options.json`. \n" - "If something goes wrong, editing or deleting the file " - "will reinstall TabbyAPI as a first-time user." - ) + print( + "Successfully wrote your start script options to " + "`start_options.json`. \n" + "If something goes wrong, editing or deleting the file " + "will reinstall TabbyAPI as a first-time user." 
+ ) # Import entrypoint after installing all requirements try: diff --git a/templates/alpaca.jinja b/templates/alpaca.jinja index 5c7aa7c..41eb098 100644 --- a/templates/alpaca.jinja +++ b/templates/alpaca.jinja @@ -1,5 +1,5 @@ {# Metadata #} -{% set stop_strings = ["### Instruction:", "### Input:", "### Response:"] %} +{%- set stop_strings = ["### Instruction:", "### Input:", "### Response:"] -%} {# Template #} {{ (messages|selectattr('role', 'equalto', 'system')|list|last).content|trim if (messages|selectattr('role', 'equalto', 'system')|list) else '' }} diff --git a/templates/chatml.jinja b/templates/chatml.jinja index 750b82a..f7dd6f5 100644 --- a/templates/chatml.jinja +++ b/templates/chatml.jinja @@ -1,5 +1,5 @@ {# Metadata #} -{% set stop_strings = ["<|im_start|>", "<|im_end|>"] %} +{%- set stop_strings = ["<|im_start|>", "<|im_end|>"] -%} {# Template #} {% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content']}}{% if (loop.last and add_generation_prompt) or not loop.last %}{{ '<|im_end|>' + '\n'}}{% endif %}{% endfor %} diff --git a/templates/chatml_with_headers_tool_calling.jinja b/templates/tool_calls/chatml_with_headers.jinja similarity index 92% rename from templates/chatml_with_headers_tool_calling.jinja rename to templates/tool_calls/chatml_with_headers.jinja index ecd1d2c..db0c0ac 100644 --- a/templates/chatml_with_headers_tool_calling.jinja +++ b/templates/tool_calls/chatml_with_headers.jinja @@ -1,8 +1,8 @@ {# Metadata #} -{% set stop_strings = ["<|im_start|>", "<|im_end|>"] %} -{% set message_roles = ['system', 'user', 'assistant', 'tool'] %} -{% set tool_start = "<|tool_start|>" %} -{% set tool_end = "<|tool_end|>" %} +{%- set stop_strings = ["<|im_start|>", "<|im_end|>"] -%} +{%- set message_roles = ['system', 'user', 'assistant', 'tool'] -%} +{%- set tool_start = "<|tool_start|>" -%} +{%- set tool_end = "<|tool_end|>" -%} {%- set start_header = "<|start_header_id|>" -%} {%- set end_header = "<|end_header_id|>\n" -%}
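Closing note: the common thread in this diff is the move from blocking constructors and `open()` calls to awaitable classmethod factories backed by `aiofiles` (`ExllamaV2Container.create`, `PromptTemplate.from_file`, `GenerationConfig.from_file`, `HuggingFaceConfig.from_file`). A minimal sketch of that "fake self" factory pattern, using a hypothetical class name and file path, looks like this:

import asyncio
import pathlib

import aiofiles
import yaml


class AsyncLoadedConfig:
    """Toy stand-in for the async factory-classmethod pattern used throughout this diff."""

    @classmethod
    async def create(cls, path: pathlib.Path):
        # Build the instance first, then await the file I/O before returning it
        self = cls()
        async with aiofiles.open(path, "r", encoding="utf8") as config_file:
            contents = await config_file.read()
        self.data = yaml.safe_load(contents) or {}
        return self


async def main():
    # Callers await the factory instead of calling the constructor directly
    # ("tabby_config.yml" here is a placeholder path for the example)
    cfg = await AsyncLoadedConfig.create(pathlib.Path("tabby_config.yml"))
    print(cfg.data)


if __name__ == "__main__":
    asyncio.run(main())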