From 279e900ea5ec095eff97d2996a966a1ea0aa663f Mon Sep 17 00:00:00 2001 From: Colin Kealty Date: Tue, 4 Jun 2024 13:35:48 -0400 Subject: [PATCH 01/34] Add on the fly model loading to requests --- endpoints/OAI/router.py | 49 +++++++++++++++++++++++++- endpoints/OAI/types/chat_completion.py | 1 + 2 files changed, 49 insertions(+), 1 deletion(-) diff --git a/endpoints/OAI/router.py b/endpoints/OAI/router.py index 0e4f27b..f4cc516 100644 --- a/endpoints/OAI/router.py +++ b/endpoints/OAI/router.py @@ -1,5 +1,6 @@ import asyncio import pathlib +from loguru import logger from fastapi import APIRouter, Depends, HTTPException, Header, Request from sse_starlette import EventSourceResponse from sys import maxsize @@ -118,7 +119,7 @@ async def completion_request( # Chat completions endpoint @router.post( "/v1/chat/completions", - dependencies=[Depends(check_api_key), Depends(check_model_container)], + dependencies=[Depends(check_api_key)], ) async def chat_completion_request( request: Request, data: ChatCompletionRequest @@ -129,6 +130,52 @@ async def chat_completion_request( If stream = true, this returns an SSE stream. """ + if data.model is not None and ( + model.container is None or model.container.get_model_path().name != data.model + ): + adminValid = False + if "x_admin_key" in request.headers.keys(): + try: + await check_admin_key( + x_admin_key=request.headers.get("x_admin_key"), authorization=None + ) + adminValid = True + except HTTPException: + pass + + if not adminValid and "authorization" in request.headers.keys(): + try: + await check_admin_key( + x_admin_key=None, authorization=request.headers.get("authorization") + ) + adminValid = True + except HTTPException: + pass + + if adminValid: + logger.info( + f"New request for {data.model} which is not loaded, proper admin key provided, loading new model" + ) + + model_path = pathlib.Path( + unwrap(config.model_config().get("model_dir"), "models") + ) + model_path = model_path / data.model + + if not model_path.exists(): + error_message = handle_request_error( + "Could not find the model path for load. 
Check model name.", + exc_info=False, + ).error.message + + raise HTTPException(400, error_message) + + await model.load_model(model_path) + else: + logger.info(f"No valid admin key found to change loaded model, ignoring") + else: + await check_model_container() + if model.container.prompt_template is None: error_message = handle_request_error( "Chat completions are disabled because a prompt template is not set.", diff --git a/endpoints/OAI/types/chat_completion.py b/endpoints/OAI/types/chat_completion.py index be5cfea..b66277b 100644 --- a/endpoints/OAI/types/chat_completion.py +++ b/endpoints/OAI/types/chat_completion.py @@ -47,6 +47,7 @@ class ChatCompletionRequest(CommonCompletionRequest): add_generation_prompt: Optional[bool] = True template_vars: Optional[dict] = {} response_prefix: Optional[str] = None + model: Optional[str] = None class ChatCompletionResponse(BaseModel): From 48d7674316129b4dbf3385501652c2285e349b3c Mon Sep 17 00:00:00 2001 From: TerminalMan <84923604+SecretiveShell@users.noreply.github.com> Date: Thu, 29 Aug 2024 00:50:01 +0100 Subject: [PATCH 02/34] make docker-compose use prebuilt images - Docker compose uses the prebuilt images produced by the GitHub action added in 872eeed581380ae032dfc039479934ce9a55e6f3 --- docker/docker-compose.yml | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/docker/docker-compose.yml b/docker/docker-compose.yml index d55d857..8a46737 100644 --- a/docker/docker-compose.yml +++ b/docker/docker-compose.yml @@ -1,11 +1,7 @@ version: '3.8' services: tabbyapi: - build: - context: .. - dockerfile: ./docker/Dockerfile - args: - - DO_PULL=true + image: ghcr.io/theroyallab/tabbyapi:latest ports: - "5000:5000" healthcheck: From 43104e0d19c1978ec642f4e502c02ce536635f45 Mon Sep 17 00:00:00 2001 From: TerminalMan <84923604+SecretiveShell@users.noreply.github.com> Date: Sat, 31 Aug 2024 21:48:43 +0100 Subject: [PATCH 03/34] Complete conditional infinity import TODO - add logging - change declaration order --- backends/infinity/model.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/backends/infinity/model.py b/backends/infinity/model.py index 35a4df4..3704901 100644 --- a/backends/infinity/model.py +++ b/backends/infinity/model.py @@ -7,13 +7,13 @@ from typing import List, Optional from common.utils import unwrap # Conditionally import infinity to sidestep its logger -# TODO: Make this prettier +has_infinity_emb: bool = False try: from infinity_emb import EngineArgs, AsyncEmbeddingEngine - has_infinity_emb = True + logger.debug("Successfully imported infinity.") except ImportError: - has_infinity_emb = False + logger.debug("Failed to import infinity.") class InfinityContainer: From 21f14d431883009263bac40c17e4e634aff591e7 Mon Sep 17 00:00:00 2001 From: kingbri Date: Tue, 3 Sep 2024 23:37:28 -0400 Subject: [PATCH 04/34] API: Update inline load - Add a config flag - Migrate support to /v1/completions - Unify the load function Signed-off-by: kingbri --- config_sample.yml | 3 ++ endpoints/OAI/router.py | 55 +++++-------------------------- endpoints/OAI/utils/completion.py | 50 ++++++++++++++++++++++++++-- 3 files changed, 60 insertions(+), 48 deletions(-) diff --git a/config_sample.yml b/config_sample.yml index 85bb1df..3b4f247 100644 --- a/config_sample.yml +++ b/config_sample.yml @@ -83,6 +83,9 @@ model: # Enable this if the program is looking for a specific OAI model #use_dummy_models: False + # Allow direct loading of models from a completion or chat completion request + inline_model_loading: False + # 
An initial model to load. Make sure the model is located in the model directory! # A model can be loaded later via the API. # REQUIRED: This must be filled out to load a model on startup! diff --git a/endpoints/OAI/router.py b/endpoints/OAI/router.py index ca91b0a..1b98f41 100644 --- a/endpoints/OAI/router.py +++ b/endpoints/OAI/router.py @@ -1,6 +1,4 @@ import asyncio -import pathlib -from loguru import logger from fastapi import APIRouter, Depends, HTTPException, Request from sse_starlette import EventSourceResponse from sys import maxsize @@ -23,6 +21,7 @@ from endpoints.OAI.utils.chat_completion import ( ) from endpoints.OAI.utils.completion import ( generate_completion, + load_inline_model, stream_generate_completion, ) from endpoints.OAI.utils.embeddings import get_embeddings @@ -43,7 +42,7 @@ def setup(): # Completions endpoint @router.post( "/v1/completions", - dependencies=[Depends(check_api_key), Depends(check_model_container)], + dependencies=[Depends(check_api_key)], ) async def completion_request( request: Request, data: CompletionRequest @@ -54,6 +53,11 @@ async def completion_request( If stream = true, this returns an SSE stream. """ + if data.model: + await load_inline_model(data.model, request) + else: + await check_model_container() + model_path = model.container.model_dir if isinstance(data.prompt, list): @@ -99,49 +103,8 @@ async def chat_completion_request( If stream = true, this returns an SSE stream. """ - if data.model is not None and ( - model.container is None or model.container.get_model_path().name != data.model - ): - adminValid = False - if "x_admin_key" in request.headers.keys(): - try: - await check_admin_key( - x_admin_key=request.headers.get("x_admin_key"), authorization=None - ) - adminValid = True - except HTTPException: - pass - - if not adminValid and "authorization" in request.headers.keys(): - try: - await check_admin_key( - x_admin_key=None, authorization=request.headers.get("authorization") - ) - adminValid = True - except HTTPException: - pass - - if adminValid: - logger.info( - f"New request for {data.model} which is not loaded, proper admin key provided, loading new model" - ) - - model_path = pathlib.Path( - unwrap(config.model_config().get("model_dir"), "models") - ) - model_path = model_path / data.model - - if not model_path.exists(): - error_message = handle_request_error( - "Could not find the model path for load. Check model name.", - exc_info=False, - ).error.message - - raise HTTPException(400, error_message) - - await model.load_model(model_path) - else: - logger.info(f"No valid admin key found to change loaded model, ignoring") + if data.model: + await load_inline_model(data.model, request) else: await check_model_container() diff --git a/endpoints/OAI/utils/completion.py b/endpoints/OAI/utils/completion.py index 52c2bb4..5fdf81f 100644 --- a/endpoints/OAI/utils/completion.py +++ b/endpoints/OAI/utils/completion.py @@ -1,4 +1,8 @@ -"""Completion utilities for OAI server.""" +""" +Completion utilities for OAI server. + +Also serves as a common module for completions and chat completions. 
+""" import asyncio import pathlib @@ -9,7 +13,8 @@ from typing import List, Union from loguru import logger -from common import model +from common import config, model +from common.auth import get_key_permission from common.networking import ( get_generator_error, handle_request_disconnect, @@ -173,6 +178,47 @@ async def stream_generate_completion( ) +async def load_inline_model(model_name: str, request: Request): + """Load a model from the data.model parameter""" + + # Return if the model container already exists + if model.container and model.container.model_dir.name == model_name: + return + + model_config = config.model_config() + + # Inline model loading isn't enabled or the user isn't an admin + if not get_key_permission(request) == "admin": + logger.warning( + f"Unable to switch model to {model_name} " + "because an admin key isn't provided." + ) + + return + + if not unwrap(model_config.get("inline_model_loading"), False): + logger.warning( + f"Unable to switch model to {model_name} because " + '"inline_model_load" is not True in config.yml.' + ) + + return + + model_path = pathlib.Path(unwrap(model_config.get("model_dir"), "models")) + model_path = model_path / model_name + + # Model path doesn't exist + if not model_path.exists(): + logger.warning( + f"Could not find model path {str(model_path)}. Skipping inline model load." + ) + + return + + # Load the model + await model.load_model(model_path) + + async def generate_completion( data: CompletionRequest, request: Request, model_path: pathlib.Path ): From 42a42caf430fb28423dd7df31a2d070eba16892c Mon Sep 17 00:00:00 2001 From: Jake <84923604+SecretiveShell@users.noreply.github.com> Date: Wed, 4 Sep 2024 16:14:09 +0100 Subject: [PATCH 05/34] remove logging - remove logging statements - format code with ruff --- backends/infinity/model.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/backends/infinity/model.py b/backends/infinity/model.py index 3704901..c48a42c 100644 --- a/backends/infinity/model.py +++ b/backends/infinity/model.py @@ -10,10 +10,10 @@ from common.utils import unwrap has_infinity_emb: bool = False try: from infinity_emb import EngineArgs, AsyncEmbeddingEngine + has_infinity_emb = True - logger.debug("Successfully imported infinity.") except ImportError: - logger.debug("Failed to import infinity.") + pass class InfinityContainer: From 9c10789ca1095e0571c534e5e5535f3b337ec604 Mon Sep 17 00:00:00 2001 From: kingbri Date: Wed, 4 Sep 2024 21:44:14 -0400 Subject: [PATCH 06/34] API: Error on invalid key permissions and cleanup format If a user requesting a model change isn't admin, error. Better to place the load function before the generate functions. 
Signed-off-by: kingbri --- endpoints/OAI/utils/completion.py | 83 ++++++++++++++++--------------- 1 file changed, 42 insertions(+), 41 deletions(-) diff --git a/endpoints/OAI/utils/completion.py b/endpoints/OAI/utils/completion.py index 5fdf81f..cc752c5 100644 --- a/endpoints/OAI/utils/completion.py +++ b/endpoints/OAI/utils/completion.py @@ -108,6 +108,48 @@ async def _stream_collector( await gen_queue.put(e) +async def load_inline_model(model_name: str, request: Request): + """Load a model from the data.model parameter""" + + # Return if the model container already exists + if model.container and model.container.model_dir.name == model_name: + return + + model_config = config.model_config() + + # Inline model loading isn't enabled or the user isn't an admin + if not get_key_permission(request) == "admin": + error_message = handle_request_error( + f"Unable to switch model to {model_name} because " + + "an admin key isn't provided", + exc_info=False, + ).error.message + + raise HTTPException(401, error_message) + + if not unwrap(model_config.get("inline_model_loading"), False): + logger.warning( + f"Unable to switch model to {model_name} because " + '"inline_model_load" is not True in config.yml.' + ) + + return + + model_path = pathlib.Path(unwrap(model_config.get("model_dir"), "models")) + model_path = model_path / model_name + + # Model path doesn't exist + if not model_path.exists(): + logger.warning( + f"Could not find model path {str(model_path)}. Skipping inline model load." + ) + + return + + # Load the model + await model.load_model(model_path) + + async def stream_generate_completion( data: CompletionRequest, request: Request, model_path: pathlib.Path ): @@ -178,47 +220,6 @@ async def stream_generate_completion( ) -async def load_inline_model(model_name: str, request: Request): - """Load a model from the data.model parameter""" - - # Return if the model container already exists - if model.container and model.container.model_dir.name == model_name: - return - - model_config = config.model_config() - - # Inline model loading isn't enabled or the user isn't an admin - if not get_key_permission(request) == "admin": - logger.warning( - f"Unable to switch model to {model_name} " - "because an admin key isn't provided." - ) - - return - - if not unwrap(model_config.get("inline_model_loading"), False): - logger.warning( - f"Unable to switch model to {model_name} because " - '"inline_model_load" is not True in config.yml.' - ) - - return - - model_path = pathlib.Path(unwrap(model_config.get("model_dir"), "models")) - model_path = model_path / model_name - - # Model path doesn't exist - if not model_path.exists(): - logger.warning( - f"Could not find model path {str(model_path)}. Skipping inline model load." - ) - - return - - # Load the model - await model.load_model(model_path) - - async def generate_completion( data: CompletionRequest, request: Request, model_path: pathlib.Path ): From 98768bfa3076f4d9b9590974643d41a83aae3339 Mon Sep 17 00:00:00 2001 From: kingbri Date: Wed, 4 Sep 2024 23:39:06 -0400 Subject: [PATCH 07/34] Docker: Re-add build block If a user wants to build from source, let them. But the default should fetch from the package registry. 
Signed-off-by: kingbri --- docker/docker-compose.yml | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/docker/docker-compose.yml b/docker/docker-compose.yml index 8a46737..d27bf47 100644 --- a/docker/docker-compose.yml +++ b/docker/docker-compose.yml @@ -1,6 +1,12 @@ version: '3.8' services: tabbyapi: + # Uncomment this to build a docker image from source + #build: + # context: .. + # dockerfile: ./docker/Dockerfile + + # Comment this to build a docker image from source image: ghcr.io/theroyallab/tabbyapi:latest ports: - "5000:5000" From 1c9991f79ef69d87ac4a3324c97da13d3e0073a6 Mon Sep 17 00:00:00 2001 From: kingbri Date: Thu, 5 Sep 2024 17:59:18 -0400 Subject: [PATCH 08/34] Config: Format and organize Rename some methods and change comments. Signed-off-by: kingbri --- common/tabby_config.py | 7 ++----- main.py | 2 +- 2 files changed, 3 insertions(+), 6 deletions(-) diff --git a/common/tabby_config.py b/common/tabby_config.py index c0c9e58..f3a189f 100644 --- a/common/tabby_config.py +++ b/common/tabby_config.py @@ -16,10 +16,7 @@ class TabbyConfig: developer: dict = {} embeddings: dict = {} - def __init__(self): - pass - - def load_config(self, arguments: Optional[dict] = None): + def load(self, arguments: Optional[dict] = None): """load the global application config""" # config is applied in order of items in the list @@ -87,5 +84,5 @@ class TabbyConfig: pass -# Create an empty instance of the shared var to make sure nothing breaks +# Create an empty instance of the config class config: TabbyConfig = TabbyConfig() diff --git a/main.py b/main.py index f017ecc..587cd4b 100644 --- a/main.py +++ b/main.py @@ -110,7 +110,7 @@ def entrypoint(arguments: Optional[dict] = None): arguments = convert_args_to_dict(parser.parse_args(), parser) # load config - config.load_config(arguments) + config.load(arguments) if do_export_openapi: openapi_json = export_openapi() From 2f45e978c56a5bc3c60a10394a60ade74a04f546 Mon Sep 17 00:00:00 2001 From: kingbri Date: Thu, 5 Sep 2024 18:04:53 -0400 Subject: [PATCH 09/34] API: Fix merge overwrite The completions utils did not take the new imports. 
Signed-off-by: kingbri --- endpoints/OAI/utils/completion.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/endpoints/OAI/utils/completion.py b/endpoints/OAI/utils/completion.py index cc752c5..d279545 100644 --- a/endpoints/OAI/utils/completion.py +++ b/endpoints/OAI/utils/completion.py @@ -13,7 +13,7 @@ from typing import List, Union from loguru import logger -from common import config, model +from common import model from common.auth import get_key_permission from common.networking import ( get_generator_error, @@ -21,6 +21,7 @@ from common.networking import ( handle_request_error, request_disconnect_loop, ) +from common.tabby_config import config from common.utils import unwrap from endpoints.OAI.types.completion import ( CompletionRequest, @@ -115,8 +116,6 @@ async def load_inline_model(model_name: str, request: Request): if model.container and model.container.model_dir.name == model_name: return - model_config = config.model_config() - # Inline model loading isn't enabled or the user isn't an admin if not get_key_permission(request) == "admin": error_message = handle_request_error( @@ -127,7 +126,7 @@ async def load_inline_model(model_name: str, request: Request): raise HTTPException(401, error_message) - if not unwrap(model_config.get("inline_model_loading"), False): + if not unwrap(config.model.get("inline_model_loading"), False): logger.warning( f"Unable to switch model to {model_name} because " '"inline_model_load" is not True in config.yml.' @@ -135,7 +134,7 @@ async def load_inline_model(model_name: str, request: Request): return - model_path = pathlib.Path(unwrap(model_config.get("model_dir"), "models")) + model_path = pathlib.Path(unwrap(config.model.get("model_dir"), "models")) model_path = model_path / model_name # Model path doesn't exist From d34756dc98191ff898eb4662578a8e0a9215aeb2 Mon Sep 17 00:00:00 2001 From: kingbri Date: Thu, 5 Sep 2024 18:05:59 -0400 Subject: [PATCH 10/34] Tree: Format Signed-off-by: kingbri --- main.py | 1 - 1 file changed, 1 deletion(-) diff --git a/main.py b/main.py index 587cd4b..740e1d0 100644 --- a/main.py +++ b/main.py @@ -123,7 +123,6 @@ def entrypoint(arguments: Optional[dict] = None): # Check exllamav2 version and give a descriptive error if it's too old # Skip if launching unsafely - print(f"MAIN.PY {config=}") if unwrap(config.developer.get("unsafe_launch"), False): logger.warning( "UNSAFE: Skipping ExllamaV2 version check.\n" From 05c3f1194f11606776a6d130cea652f40319c2f6 Mon Sep 17 00:00:00 2001 From: kingbri Date: Sat, 7 Sep 2024 00:48:42 -0400 Subject: [PATCH 11/34] Sampling: Add rudimentary DRY support Adds DRY support based on the current exl2 dev API. Only change for optimization is dry_max_ngram instead of using a closed range. Currently, DRY range is aliased to dry_max_ngram. Signed-off-by: kingbri --- backends/exllamav2/model.py | 30 +++++++++++++++++++++++++++- common/sampling.py | 39 +++++++++++++++++++++++++++++++++++++ 2 files changed, 68 insertions(+), 1 deletion(-) diff --git a/backends/exllamav2/model.py b/backends/exllamav2/model.py index 7fe08db..1d80062 100644 --- a/backends/exllamav2/model.py +++ b/backends/exllamav2/model.py @@ -7,6 +7,7 @@ import pathlib import traceback import torch import uuid +from copy import deepcopy from exllamav2 import ( ExLlamaV2, ExLlamaV2Config, @@ -944,6 +945,14 @@ class ExllamaV2Container: Meant for dev wheels! 
""" + if unwrap(kwargs.get("dry_allowed_length"), 0) > 0 and not hasattr( + ExLlamaV2Sampler.Settings, "dry_multiplier" + ): + logger.warning( + "DRY sampling is not supported by the currently " + "installed ExLlamaV2 version." + ) + return kwargs async def generate_gen( @@ -1035,6 +1044,7 @@ class ExllamaV2Container: "Please use an ampere (30 series) or higher GPU for CFG support." ) + # Penalties gen_settings.token_repetition_penalty = unwrap( kwargs.get("repetition_penalty"), 1.0 ) @@ -1070,6 +1080,23 @@ class ExllamaV2Container: kwargs.get("repetition_decay"), fallback_decay, 0 ) + # DRY options + dry_allowed_length = unwrap(kwargs.get("dry_allowed_length"), 0) + + # 0 = disabled + if dry_allowed_length: + gen_settings.dry_allowed_length = dry_allowed_length + gen_settings.dry_base = unwrap(kwargs.get("dry_base"), 2.0) + gen_settings.dry_multiplier = unwrap(kwargs.get("dry_multiplier"), 2.0) + gen_settings.dry_max_ngram = unwrap(kwargs.get("dry_max_ngram"), 20) + + # Tokenize sequence breakers + dry_sequence_breakers_json = kwargs.get("dry_sequence_breakers") + if dry_sequence_breakers_json: + gen_settings.dry_sequence_breakers = { + self.encode_tokens(s)[-1] for s in dry_sequence_breakers_json + } + # Initialize grammar handler grammar_handler = ExLlamaV2Grammar() @@ -1130,7 +1157,8 @@ class ExllamaV2Container: ) # Store the gen settings for logging purposes - gen_settings_log_dict = vars(gen_settings) + # Deepcopy to save a snapshot of vars + gen_settings_log_dict = deepcopy(vars(gen_settings)) # Set banned tokens banned_tokens = unwrap(kwargs.get("banned_tokens"), []) diff --git a/common/sampling.py b/common/sampling.py index 56c5b34..a3bccb3 100644 --- a/common/sampling.py +++ b/common/sampling.py @@ -1,5 +1,6 @@ """Common functions for sampling parameters""" +import json import pathlib import yaml from copy import deepcopy @@ -140,6 +141,28 @@ class BaseSamplerRequest(BaseModel): default_factory=lambda: get_default_sampler_value("repetition_decay", 0) ) + dry_allowed_length: Optional[int] = Field( + default_factory=lambda: get_default_sampler_value("dry_allowed_length", 0) + ) + + dry_base: Optional[float] = Field( + default_factory=lambda: get_default_sampler_value("dry_base", 2.0) + ) + + dry_multiplier: Optional[float] = Field( + default_factory=lambda: get_default_sampler_value("dry_multiplier", 2.0) + ) + + # TODO: Remove these aliases + dry_max_ngram: Optional[int] = Field( + default_factory=lambda: get_default_sampler_value("dry_max_ngram", 20), + alias=AliasChoices("dry_max_ngram", "dry_penalty_last_n"), + ) + + dry_sequence_breakers: Optional[str] = Field( + default_factory=lambda: get_default_sampler_value("dry_sequence_breakers", []) + ) + mirostat_mode: Optional[int] = Field( default_factory=lambda: get_default_sampler_value("mirostat_mode", 0) ) @@ -305,6 +328,17 @@ class BaseSamplerRequest(BaseModel): int(x) for x in self.allowed_tokens.split(",") if x.isdigit() ] + # Convert sequence breakers into an array of strings + # NOTE: This sampler sucks to parse. 
+ if self.dry_sequence_breakers: + if not self.dry_sequence_breakers.startswith("["): + self.dry_sequence_breakers = f"[{self.dry_sequence_breakers}]" + + try: + self.dry_sequence_breakers = json.loads(self.dry_sequence_breakers) + except Exception: + self.dry_sequence_breakers = [] + gen_params = { "max_tokens": self.max_tokens, "min_tokens": self.min_tokens, @@ -335,6 +369,11 @@ class BaseSamplerRequest(BaseModel): "presence_penalty": self.presence_penalty, "repetition_penalty": self.repetition_penalty, "penalty_range": self.penalty_range, + "dry_allowed_length": self.dry_allowed_length, + "dry_base": self.dry_base, + "dry_max_ngram": self.dry_max_ngram, + "dry_multiplier": self.dry_multiplier, + "dry_sequence_breakers": self.dry_sequence_breakers, "repetition_decay": self.repetition_decay, "mirostat": self.mirostat_mode == 2, "mirostat_tau": self.mirostat_tau, From ae37f3f3326652e91363d4b809cae119a111e4d5 Mon Sep 17 00:00:00 2001 From: kingbri Date: Sat, 7 Sep 2024 12:39:14 -0400 Subject: [PATCH 12/34] Sampling: Update DRY Switch to new parameters and remove dry_max_ngram as that's not supposed to be changed. Signed-off-by: kingbri --- backends/exllamav2/model.py | 17 ++++++++++++----- common/sampling.py | 9 ++++----- 2 files changed, 16 insertions(+), 10 deletions(-) diff --git a/backends/exllamav2/model.py b/backends/exllamav2/model.py index 1d80062..fdb85a7 100644 --- a/backends/exllamav2/model.py +++ b/backends/exllamav2/model.py @@ -1081,14 +1081,21 @@ class ExllamaV2Container: ) # DRY options - dry_allowed_length = unwrap(kwargs.get("dry_allowed_length"), 0) + dry_multiplier = unwrap(kwargs.get("dry_multiplier"), 0.0) - # 0 = disabled - if dry_allowed_length: - gen_settings.dry_allowed_length = dry_allowed_length + # < 0 = disabled + if dry_multiplier > 0: + gen_settings.dry_allowed_length = unwrap( + kwargs.get("dry_allowed_length"), 0 + ) gen_settings.dry_base = unwrap(kwargs.get("dry_base"), 2.0) gen_settings.dry_multiplier = unwrap(kwargs.get("dry_multiplier"), 2.0) - gen_settings.dry_max_ngram = unwrap(kwargs.get("dry_max_ngram"), 20) + + # Exl2 has dry_range as 0 for unlimited unlike -1 for penalty_range + # Use max_seq_len as the fallback to stay consistent + gen_settings.dry_range = unwrap( + kwargs.get("dry_range"), self.config.max_seq_len + ) # Tokenize sequence breakers dry_sequence_breakers_json = kwargs.get("dry_sequence_breakers") diff --git a/common/sampling.py b/common/sampling.py index a3bccb3..de5b7dc 100644 --- a/common/sampling.py +++ b/common/sampling.py @@ -153,10 +153,10 @@ class BaseSamplerRequest(BaseModel): default_factory=lambda: get_default_sampler_value("dry_multiplier", 2.0) ) - # TODO: Remove these aliases - dry_max_ngram: Optional[int] = Field( - default_factory=lambda: get_default_sampler_value("dry_max_ngram", 20), - alias=AliasChoices("dry_max_ngram", "dry_penalty_last_n"), + dry_range: Optional[int] = Field( + default_factory=lambda: get_default_sampler_value("dry_range", 0), + alias=AliasChoices("dry_range", "dry_penalty_last_n"), + description=("Aliases: dry_penalty_last_n"), ) dry_sequence_breakers: Optional[str] = Field( @@ -371,7 +371,6 @@ class BaseSamplerRequest(BaseModel): "penalty_range": self.penalty_range, "dry_allowed_length": self.dry_allowed_length, "dry_base": self.dry_base, - "dry_max_ngram": self.dry_max_ngram, "dry_multiplier": self.dry_multiplier, "dry_sequence_breakers": self.dry_sequence_breakers, "repetition_decay": self.repetition_decay, From 4f5ca7a4c7972ea63528b557710377cd18eebac6 Mon Sep 17 00:00:00 2001 From: kingbri 
Date: Sat, 7 Sep 2024 12:48:59 -0400 Subject: [PATCH 13/34] Sampling: Update overrides and params Re-order to make more sense. Signed-off-by: kingbri --- backends/exllamav2/model.py | 6 ++++-- common/sampling.py | 14 +++++++------- sampler_overrides/sample_preset.yml | 14 ++++++++++++++ 3 files changed, 25 insertions(+), 9 deletions(-) diff --git a/backends/exllamav2/model.py b/backends/exllamav2/model.py index fdb85a7..044bce0 100644 --- a/backends/exllamav2/model.py +++ b/backends/exllamav2/model.py @@ -1085,11 +1085,13 @@ class ExllamaV2Container: # < 0 = disabled if dry_multiplier > 0: + gen_settings.dry_multiplier = dry_multiplier + + # TODO: Maybe set the "sane" defaults instead? gen_settings.dry_allowed_length = unwrap( kwargs.get("dry_allowed_length"), 0 ) - gen_settings.dry_base = unwrap(kwargs.get("dry_base"), 2.0) - gen_settings.dry_multiplier = unwrap(kwargs.get("dry_multiplier"), 2.0) + gen_settings.dry_base = unwrap(kwargs.get("dry_base"), 0.0) # Exl2 has dry_range as 0 for unlimited unlike -1 for penalty_range # Use max_seq_len as the fallback to stay consistent diff --git a/common/sampling.py b/common/sampling.py index de5b7dc..67b8925 100644 --- a/common/sampling.py +++ b/common/sampling.py @@ -141,16 +141,16 @@ class BaseSamplerRequest(BaseModel): default_factory=lambda: get_default_sampler_value("repetition_decay", 0) ) - dry_allowed_length: Optional[int] = Field( - default_factory=lambda: get_default_sampler_value("dry_allowed_length", 0) + dry_multiplier: Optional[float] = Field( + default_factory=lambda: get_default_sampler_value("dry_multiplier", 0.0) ) dry_base: Optional[float] = Field( - default_factory=lambda: get_default_sampler_value("dry_base", 2.0) + default_factory=lambda: get_default_sampler_value("dry_base", 0.0) ) - dry_multiplier: Optional[float] = Field( - default_factory=lambda: get_default_sampler_value("dry_multiplier", 2.0) + dry_allowed_length: Optional[int] = Field( + default_factory=lambda: get_default_sampler_value("dry_allowed_length", 0) ) dry_range: Optional[int] = Field( @@ -369,9 +369,9 @@ class BaseSamplerRequest(BaseModel): "presence_penalty": self.presence_penalty, "repetition_penalty": self.repetition_penalty, "penalty_range": self.penalty_range, - "dry_allowed_length": self.dry_allowed_length, - "dry_base": self.dry_base, "dry_multiplier": self.dry_multiplier, + "dry_base": self.dry_base, + "dry_allowed_length": self.dry_allowed_length, "dry_sequence_breakers": self.dry_sequence_breakers, "repetition_decay": self.repetition_decay, "mirostat": self.mirostat_mode == 2, diff --git a/sampler_overrides/sample_preset.yml b/sampler_overrides/sample_preset.yml index b20b042..ed2edba 100644 --- a/sampler_overrides/sample_preset.yml +++ b/sampler_overrides/sample_preset.yml @@ -97,6 +97,20 @@ penalty_range: override: -1 force: false +# MARK: DRY +dry_multiplier: + override: 0.0 + force: false +dry_base: + override: 0.0 + force: false +dry_allowed_length: + override: 0 + force: false +dry_range: + override: 0 + force: false + # MARK: Mirostat mirostat_mode: override: 0 From d57a3b459c15434a58c782a7eed6f13998b7eac6 Mon Sep 17 00:00:00 2001 From: TerminalMan <84923604+SecretiveShell@users.noreply.github.com> Date: Sat, 7 Sep 2024 18:27:10 +0100 Subject: [PATCH 14/34] fix function arguments for get_model_list --- endpoints/core/router.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/endpoints/core/router.py b/endpoints/core/router.py index 1f9d194..cc9af24 100644 --- a/endpoints/core/router.py +++ b/endpoints/core/router.py @@ 
-103,7 +103,7 @@ async def list_draft_models(request: Request) -> ModelList: models = get_model_list(draft_model_path.resolve()) else: - models = await get_current_model_list(is_draft=True) + models = await get_current_model_list(model_type="draft") return models From 4b11cabbec70f27b695a66ab6d5dccb0ad763171 Mon Sep 17 00:00:00 2001 From: TerminalMan <84923604+SecretiveShell@users.noreply.github.com> Date: Sun, 8 Sep 2024 00:02:00 +0100 Subject: [PATCH 15/34] debloat docker build --- .dockerignore | 20 ++++++++++++++++++++ docker/.dockerignore | 6 ------ docker/Dockerfile | 2 ++ 3 files changed, 22 insertions(+), 6 deletions(-) create mode 100644 .dockerignore delete mode 100644 docker/.dockerignore diff --git a/.dockerignore b/.dockerignore new file mode 100644 index 0000000..e8118ac --- /dev/null +++ b/.dockerignore @@ -0,0 +1,20 @@ +.ruff_cache/ +**/__pycache__/ +venv +.git +.gitignore +.github + +# Ignore specific application files +models/ +loras/ +config.yml +config_sample.yml +api_tokens.yml +api_tokens_sample.yml +*.bat +*.sh +update_scripts +readme.md +colab +start.py \ No newline at end of file diff --git a/docker/.dockerignore b/docker/.dockerignore deleted file mode 100644 index ae8a12c..0000000 --- a/docker/.dockerignore +++ /dev/null @@ -1,6 +0,0 @@ -models/ -loras/ -.ruff_cache/ -**/__pycache__/ -config.yml -api_tokens.yml \ No newline at end of file diff --git a/docker/Dockerfile b/docker/Dockerfile index 0b709b5..f3587cc 100644 --- a/docker/Dockerfile +++ b/docker/Dockerfile @@ -22,6 +22,8 @@ COPY pyproject.toml . # Install packages specified in pyproject.toml cu121 RUN pip3 install --no-cache-dir .[cu121] +RUN rm pyproject.toml + # Copy the current directory contents into the container COPY . . From 9c4a0e650f6c7499f559c4624bea228e325caf7c Mon Sep 17 00:00:00 2001 From: kingbri Date: Sat, 7 Sep 2024 21:32:46 -0400 Subject: [PATCH 16/34] Sampling: Fix override for DRY sequence breakers The common type should be an array of strings. Signed-off-by: kingbri --- common/sampling.py | 4 ++-- sampler_overrides/sample_preset.yml | 4 ++++ 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/common/sampling.py b/common/sampling.py index 67b8925..c8366b3 100644 --- a/common/sampling.py +++ b/common/sampling.py @@ -159,7 +159,7 @@ class BaseSamplerRequest(BaseModel): description=("Aliases: dry_penalty_last_n"), ) - dry_sequence_breakers: Optional[str] = Field( + dry_sequence_breakers: Optional[Union[str, List[str]]] = Field( default_factory=lambda: get_default_sampler_value("dry_sequence_breakers", []) ) @@ -330,7 +330,7 @@ class BaseSamplerRequest(BaseModel): # Convert sequence breakers into an array of strings # NOTE: This sampler sucks to parse. - if self.dry_sequence_breakers: + if self.dry_sequence_breakers and isinstance(self.dry_sequence_breakers, str): if not self.dry_sequence_breakers.startswith("["): self.dry_sequence_breakers = f"[{self.dry_sequence_breakers}]" diff --git a/sampler_overrides/sample_preset.yml b/sampler_overrides/sample_preset.yml index ed2edba..e01c2b7 100644 --- a/sampler_overrides/sample_preset.yml +++ b/sampler_overrides/sample_preset.yml @@ -110,6 +110,10 @@ dry_allowed_length: dry_range: override: 0 force: false +dry_sequence_breakers: + override: [] + force: false + additive: false # MARK: Mirostat mirostat_mode: From b576a2f11609bea85bedc385403518cb2877e5c7 Mon Sep 17 00:00:00 2001 From: kingbri Date: Sat, 7 Sep 2024 21:45:51 -0400 Subject: [PATCH 17/34] API: Bump sent koboldcpp version Unlock DRY on lite UI. 
Signed-off-by: kingbri --- endpoints/Kobold/router.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/endpoints/Kobold/router.py b/endpoints/Kobold/router.py index 334bae2..310a380 100644 --- a/endpoints/Kobold/router.py +++ b/endpoints/Kobold/router.py @@ -137,7 +137,7 @@ async def get_version(): async def get_extra_version(): """Impersonate Koboldcpp.""" - return {"result": "KoboldCpp", "version": "1.61"} + return {"result": "KoboldCpp", "version": "1.71"} @kai_router.get("/config/soft_prompts_list") From acd3eb1140f0280b7630cd40dc1524c6dcb68b18 Mon Sep 17 00:00:00 2001 From: kingbri Date: Sat, 7 Sep 2024 22:15:42 -0400 Subject: [PATCH 18/34] Model: Add model folder template support Like tabby_config.yml in the model's folder, a custom template can also be provided via tabby_template.yml in addition to the existing templates folder. The config.yml always takes priority. Signed-off-by: kingbri --- backends/exllamav2/model.py | 19 +++++++++++++++---- common/templating.py | 14 ++++++++++---- 2 files changed, 25 insertions(+), 8 deletions(-) diff --git a/backends/exllamav2/model.py b/backends/exllamav2/model.py index 044bce0..6e0a8cc 100644 --- a/backends/exllamav2/model.py +++ b/backends/exllamav2/model.py @@ -401,19 +401,30 @@ class ExllamaV2Container: find_template_functions = [ lambda: PromptTemplate.from_model_json( pathlib.Path(self.config.model_dir) / "tokenizer_config.json", - "chat_template", + key="chat_template", ), lambda: PromptTemplate.from_file(find_template_from_model(model_directory)), ] + # Find the template in the model directory if it exists + model_dir_template_path = ( + pathlib.Path(self.config.model_dir) / "tabby_template.jinja" + ) + if model_dir_template_path.exists(): + find_template_functions[:0] = [ + lambda: PromptTemplate.from_file(model_dir_template_path) + ] + # Add lookup from prompt template name if provided if prompt_template_name: find_template_functions[:0] = [ - lambda: PromptTemplate.from_file(prompt_template_name), + lambda: PromptTemplate.from_file( + pathlib.Path("templates") / prompt_template_name + ), lambda: PromptTemplate.from_model_json( pathlib.Path(self.config.model_dir) / "tokenizer_config.json", - "chat_template", - prompt_template_name, + key="chat_template", + name=prompt_template_name, ), ] diff --git a/common/templating.py b/common/templating.py index 47299ff..30abb38 100644 --- a/common/templating.py +++ b/common/templating.py @@ -106,20 +106,26 @@ class PromptTemplate: self.template = self.compile(raw_template) @classmethod - def from_file(self, prompt_template_name: str): + def from_file(self, template_path: pathlib.Path): """Get a template from a jinja file.""" - template_path = pathlib.Path(f"templates/{prompt_template_name}.jinja") + # Add the jinja extension if it isn't provided + if template_path.suffix.endswith(".jinja"): + template_name = template_path.name.split(".jinja")[0] + else: + template_name = template_path.name + template_path = template_path.with_suffix(".jinja") + if template_path.exists(): with open(template_path, "r", encoding="utf8") as raw_template_stream: return PromptTemplate( - name=prompt_template_name, + name=template_name, raw_template=raw_template_stream.read(), ) else: # Let the user know if the template file isn't found raise TemplateLoadError( - f'Chat template "{prompt_template_name}" not found in files.' + f'Chat template "{template_name}" not found in files.' 
) @classmethod From dffceab7776f59bed21c0c44782d2b0af9a37180 Mon Sep 17 00:00:00 2001 From: kingbri Date: Sun, 8 Sep 2024 01:55:00 -0400 Subject: [PATCH 19/34] Sampling: Link dry_range Was not linked in the gen params dict. Signed-off-by: kingbri --- common/sampling.py | 1 + 1 file changed, 1 insertion(+) diff --git a/common/sampling.py b/common/sampling.py index c8366b3..eab2a4c 100644 --- a/common/sampling.py +++ b/common/sampling.py @@ -373,6 +373,7 @@ class BaseSamplerRequest(BaseModel): "dry_base": self.dry_base, "dry_allowed_length": self.dry_allowed_length, "dry_sequence_breakers": self.dry_sequence_breakers, + "dry_range": self.dry_range, "repetition_decay": self.repetition_decay, "mirostat": self.mirostat_mode == 2, "mirostat_tau": self.mirostat_tau, From df118908510f8c835128f0d3488ead2de02b160c Mon Sep 17 00:00:00 2001 From: kingbri Date: Sun, 8 Sep 2024 12:21:42 -0400 Subject: [PATCH 20/34] Templating: Add loopcontrols extension Inbuilt jinja extension to allow for break and continue in loops. Signed-off-by: kingbri --- common/templating.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/common/templating.py b/common/templating.py index 30abb38..d515cf8 100644 --- a/common/templating.py +++ b/common/templating.py @@ -5,6 +5,7 @@ import pathlib from importlib.metadata import version as package_version from typing import List, Optional from jinja2 import Template, TemplateError +from jinja2.ext import loopcontrols from jinja2.sandbox import ImmutableSandboxedEnvironment from loguru import logger from packaging import version @@ -32,7 +33,10 @@ class PromptTemplate: raw_template: str template: Template environment: ImmutableSandboxedEnvironment = ImmutableSandboxedEnvironment( - trim_blocks=True, lstrip_blocks=True, enable_async=True + trim_blocks=True, + lstrip_blocks=True, + enable_async=True, + extensions=[loopcontrols], ) metadata: Optional[TemplateMetadata] = None From 776bfd817df7850bdd366f0a1d93c4d733630212 Mon Sep 17 00:00:00 2001 From: kingbri Date: Sun, 8 Sep 2024 12:28:38 -0400 Subject: [PATCH 21/34] Templates: Migrate tool calling templates to folder Mirrors the llm-prompt-templates repo Signed-off-by: kingbri --- .gitignore | 6 +++++- .../chatml_with_headers.jinja} | 0 2 files changed, 5 insertions(+), 1 deletion(-) rename templates/{chatml_with_headers_tool_calling.jinja => tool_calls/chatml_with_headers.jinja} (100%) diff --git a/.gitignore b/.gitignore index 49aa517..5e5c9ee 100644 --- a/.gitignore +++ b/.gitignore @@ -192,7 +192,11 @@ templates/* !templates/place_your_templates_here.txt !templates/alpaca.jinja !templates/chatml.jinja -!templates/chatml_with_headers_tool_calling.jinja + +# Tool calling templates folder +templates/tool_calls/* +!templates/tool_calls +!templates/tool_calls/chatml_with_headers.jinja # Sampler overrides folder sampler_overrides/* diff --git a/templates/chatml_with_headers_tool_calling.jinja b/templates/tool_calls/chatml_with_headers.jinja similarity index 100% rename from templates/chatml_with_headers_tool_calling.jinja rename to templates/tool_calls/chatml_with_headers.jinja From d6ad17097cbc85eef385cd9ada314a2b3e1a5b21 Mon Sep 17 00:00:00 2001 From: kingbri Date: Sun, 8 Sep 2024 12:36:36 -0400 Subject: [PATCH 22/34] Templates: Remove whitespace from metadata Apparently setting variables also adds extraneous whitespace before the template itself. Doing {%- set stop_strings = ["string1"] -%} fixes this issue. 
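For clarity, a small standalone sketch of what the trim markers change, using jinja2's default environment rather than TabbyAPI's template environment (so the exact whitespace handling differs slightly):

    from jinja2 import Template

    # Without trim markers the set tag renders to nothing, but the newline
    # that follows it survives into the rendered output.
    plain = Template('{% set stop_strings = ["a"] %}\nHello')
    print(repr(plain.render()))    # '\nHello'

    # The "-" markers strip the surrounding whitespace, so nothing leaks
    # into the output before the template body.
    trimmed = Template('{%- set stop_strings = ["a"] -%}\nHello')
    print(repr(trimmed.render()))  # 'Hello'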
Signed-off-by: kingbri --- templates/alpaca.jinja | 2 +- templates/chatml.jinja | 2 +- templates/tool_calls/chatml_with_headers.jinja | 8 ++++---- 3 files changed, 6 insertions(+), 6 deletions(-) diff --git a/templates/alpaca.jinja b/templates/alpaca.jinja index 5c7aa7c..41eb098 100644 --- a/templates/alpaca.jinja +++ b/templates/alpaca.jinja @@ -1,5 +1,5 @@ {# Metadata #} -{% set stop_strings = ["### Instruction:", "### Input:", "### Response:"] %} +{%- set stop_strings = ["### Instruction:", "### Input:", "### Response:"] -%} {# Template #} {{ (messages|selectattr('role', 'equalto', 'system')|list|last).content|trim if (messages|selectattr('role', 'equalto', 'system')|list) else '' }} diff --git a/templates/chatml.jinja b/templates/chatml.jinja index 750b82a..f7dd6f5 100644 --- a/templates/chatml.jinja +++ b/templates/chatml.jinja @@ -1,5 +1,5 @@ {# Metadata #} -{% set stop_strings = ["<|im_start|>", "<|im_end|>"] %} +{%- set stop_strings = ["<|im_start|>", "<|im_end|>"] -%} {# Template #} {% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content']}}{% if (loop.last and add_generation_prompt) or not loop.last %}{{ '<|im_end|>' + '\n'}}{% endif %}{% endfor %} diff --git a/templates/tool_calls/chatml_with_headers.jinja b/templates/tool_calls/chatml_with_headers.jinja index ecd1d2c..db0c0ac 100644 --- a/templates/tool_calls/chatml_with_headers.jinja +++ b/templates/tool_calls/chatml_with_headers.jinja @@ -1,8 +1,8 @@ {# Metadata #} -{% set stop_strings = ["<|im_start|>", "<|im_end|>"] %} -{% set message_roles = ['system', 'user', 'assistant', 'tool'] %} -{% set tool_start = "<|tool_start|>" %} -{% set tool_end = "<|tool_end|>" %} +{%- set stop_strings = ["<|im_start|>", "<|im_end|>"] -%} +{%- set message_roles = ['system', 'user', 'assistant', 'tool'] -%} +{%- set tool_start = "<|tool_start|>" -%} +{%- set tool_end = "<|tool_end|>" -%} {%- set start_header = "<|start_header_id|>" -%} {%- set end_header = "<|end_header_id|>\n" -%} From 63476041d1c9c7a6a673f3a375e6773cde52b427 Mon Sep 17 00:00:00 2001 From: Cohee <18619528+Cohee1207@users.noreply.github.com> Date: Sun, 8 Sep 2024 22:02:49 +0300 Subject: [PATCH 23/34] Properly specify config value in the error message --- endpoints/OAI/utils/completion.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/endpoints/OAI/utils/completion.py b/endpoints/OAI/utils/completion.py index d279545..2f51175 100644 --- a/endpoints/OAI/utils/completion.py +++ b/endpoints/OAI/utils/completion.py @@ -129,7 +129,7 @@ async def load_inline_model(model_name: str, request: Request): if not unwrap(config.model.get("inline_model_loading"), False): logger.warning( f"Unable to switch model to {model_name} because " - '"inline_model_load" is not True in config.yml.' + '"inline_model_loading" is not True in config.yml.' 
) return From cf97113868f78bb557860f115ab8dbe6e0cb08f8 Mon Sep 17 00:00:00 2001 From: kingbri Date: Sun, 8 Sep 2024 21:12:31 -0400 Subject: [PATCH 24/34] Dependencies: Update Exllamav2 v0.2.1 Signed-off-by: kingbri --- backends/exllamav2/utils.py | 2 +- pyproject.toml | 30 +++++++++++++++--------------- 2 files changed, 16 insertions(+), 16 deletions(-) diff --git a/backends/exllamav2/utils.py b/backends/exllamav2/utils.py index 5b1d042..4c192b2 100644 --- a/backends/exllamav2/utils.py +++ b/backends/exllamav2/utils.py @@ -8,7 +8,7 @@ from loguru import logger def check_exllama_version(): """Verifies the exllama version""" - required_version = version.parse("0.1.9") + required_version = version.parse("0.2.1") current_version = version.parse(package_version("exllamav2").split("+")[0]) unsupported_message = ( diff --git a/pyproject.toml b/pyproject.toml index b9e80fe..19fcbce 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -68,12 +68,12 @@ cu121 = [ "torch @ https://download.pytorch.org/whl/cu121/torch-2.3.1%2Bcu121-cp310-cp310-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.10'", # Exl2 - "exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.2.0/exllamav2-0.2.0+cu121.torch2.3.1-cp312-cp312-win_amd64.whl ; platform_system == 'Windows' and python_version == '3.12'", - "exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.2.0/exllamav2-0.2.0+cu121.torch2.3.1-cp311-cp311-win_amd64.whl ; platform_system == 'Windows' and python_version == '3.11'", - "exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.2.0/exllamav2-0.2.0+cu121.torch2.3.1-cp310-cp310-win_amd64.whl ; platform_system == 'Windows' and python_version == '3.10'", - "exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.2.0/exllamav2-0.2.0+cu121.torch2.3.1-cp312-cp312-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.12'", - "exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.2.0/exllamav2-0.2.0+cu121.torch2.3.1-cp311-cp311-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.11'", - "exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.2.0/exllamav2-0.2.0+cu121.torch2.3.1-cp310-cp310-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.10'", + "exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.2.1/exllamav2-0.2.1+cu121.torch2.3.1-cp312-cp312-win_amd64.whl ; platform_system == 'Windows' and python_version == '3.12'", + "exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.2.1/exllamav2-0.2.1+cu121.torch2.3.1-cp311-cp311-win_amd64.whl ; platform_system == 'Windows' and python_version == '3.11'", + "exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.2.1/exllamav2-0.2.1+cu121.torch2.3.1-cp310-cp310-win_amd64.whl ; platform_system == 'Windows' and python_version == '3.10'", + "exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.2.1/exllamav2-0.2.1+cu121.torch2.3.1-cp312-cp312-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.12'", + "exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.2.1/exllamav2-0.2.1+cu121.torch2.3.1-cp311-cp311-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.11'", + "exllamav2 @ 
https://github.com/turboderp/exllamav2/releases/download/v0.2.1/exllamav2-0.2.1+cu121.torch2.3.1-cp310-cp310-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.10'", # Windows FA2 from https://github.com/bdashore3/flash-attention/releases "flash_attn @ https://github.com/bdashore3/flash-attention/releases/download/v2.6.3/flash_attn-2.6.3+cu123torch2.3.1cxx11abiFALSE-cp312-cp312-win_amd64.whl ; platform_system == 'Windows' and python_version == '3.12'", @@ -95,12 +95,12 @@ cu118 = [ "torch @ https://download.pytorch.org/whl/cu118/torch-2.3.1%2Bcu118-cp310-cp310-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.10'", # Exl2 - "exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.2.0/exllamav2-0.2.0+cu118.torch2.3.1-cp312-cp312-win_amd64.whl ; platform_system == 'Windows' and python_version == '3.12'", - "exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.2.0/exllamav2-0.2.0+cu118.torch2.3.1-cp311-cp311-win_amd64.whl ; platform_system == 'Windows' and python_version == '3.11'", - "exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.2.0/exllamav2-0.2.0+cu118.torch2.3.1-cp310-cp310-win_amd64.whl ; platform_system == 'Windows' and python_version == '3.10'", - "exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.2.0/exllamav2-0.2.0+cu118.torch2.3.1-cp312-cp312-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.12'", - "exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.2.0/exllamav2-0.2.0+cu118.torch2.3.1-cp311-cp311-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.11'", - "exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.2.0/exllamav2-0.2.0+cu118.torch2.3.1-cp310-cp310-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.10'", + "exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.2.1/exllamav2-0.2.1+cu118.torch2.3.1-cp312-cp312-win_amd64.whl ; platform_system == 'Windows' and python_version == '3.12'", + "exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.2.1/exllamav2-0.2.1+cu118.torch2.3.1-cp311-cp311-win_amd64.whl ; platform_system == 'Windows' and python_version == '3.11'", + "exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.2.1/exllamav2-0.2.1+cu118.torch2.3.1-cp310-cp310-win_amd64.whl ; platform_system == 'Windows' and python_version == '3.10'", + "exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.2.1/exllamav2-0.2.1+cu118.torch2.3.1-cp312-cp312-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.12'", + "exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.2.1/exllamav2-0.2.1+cu118.torch2.3.1-cp311-cp311-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.11'", + "exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.2.1/exllamav2-0.2.1+cu118.torch2.3.1-cp310-cp310-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.10'", # Linux FA2 from https://github.com/Dao-AILab/flash-attention/releases "flash_attn @ 
https://github.com/Dao-AILab/flash-attention/releases/download/v2.6.3/flash_attn-2.6.3+cu118torch2.3cxx11abiFALSE-cp312-cp312-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.12'", @@ -119,9 +119,9 @@ amd = [ "torch @ https://download.pytorch.org/whl/rocm6.0/torch-2.3.1%2Brocm6.0-cp310-cp310-linux_x86_64.whl ; python_version == '3.10'", # Exl2 - "exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.2.0/exllamav2-0.2.0+rocm6.0.torch2.3.1-cp312-cp312-linux_x86_64.whl ; python_version == '3.12'", - "exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.2.0/exllamav2-0.2.0+rocm6.0.torch2.3.1-cp311-cp311-linux_x86_64.whl ; python_version == '3.11'", - "exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.2.0/exllamav2-0.2.0+rocm6.0.torch2.3.1-cp310-cp310-linux_x86_64.whl ; python_version == '3.10'", + "exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.2.1/exllamav2-0.2.1+rocm6.0.torch2.3.1-cp312-cp312-linux_x86_64.whl ; python_version == '3.12'", + "exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.2.1/exllamav2-0.2.1+rocm6.0.torch2.3.1-cp311-cp311-linux_x86_64.whl ; python_version == '3.11'", + "exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.2.1/exllamav2-0.2.1+rocm6.0.torch2.3.1-cp310-cp310-linux_x86_64.whl ; python_version == '3.10'", ] # MARK: Ruff options From a370aeb15fd82ea9ffe3096106e4d2470285f375 Mon Sep 17 00:00:00 2001 From: Ati Sharma Date: Mon, 9 Sep 2024 09:19:12 +0100 Subject: [PATCH 25/34] Fix tabby_config.py _from_file Update tabby_config.py to fix issue #196 --- common/tabby_config.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/common/tabby_config.py b/common/tabby_config.py index f3a189f..efde051 100644 --- a/common/tabby_config.py +++ b/common/tabby_config.py @@ -61,7 +61,7 @@ class TabbyConfig: config_override = unwrap(args.get("options", {}).get("config")) if config_override: logger.info("Config file override detected in args.") - config = self.from_file(pathlib.Path(config_override)) + config = self._from_file(pathlib.Path(config_override)) return config # Return early if loading from file for key in ["network", "model", "logging", "developer", "embeddings"]: From 810cd400166c0d8d28c3a2178e8c3d42dd7b7107 Mon Sep 17 00:00:00 2001 From: kingbri Date: Sun, 8 Sep 2024 21:31:18 -0400 Subject: [PATCH 26/34] Start: Broadcast start_options only on first-time run Prevents the save from occurring multiple times for no reason. Signed-off-by: kingbri --- start.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/start.py b/start.py index 490570e..844976f 100644 --- a/start.py +++ b/start.py @@ -234,15 +234,15 @@ if __name__ == "__main__": if first_run: start_options["first_run_done"] = True - # Save start options - with open("start_options.json", "w") as start_file: - start_file.write(json.dumps(start_options)) + # Save start options + with open("start_options.json", "w") as start_file: + start_file.write(json.dumps(start_options)) - print( - "Successfully wrote your start script options to `start_options.json`. \n" - "If something goes wrong, editing or deleting the file " - "will reinstall TabbyAPI as a first-time user." - ) + print( + "Successfully wrote your start script options to `start_options.json`. \n" + "If something goes wrong, editing or deleting the file " + "will reinstall TabbyAPI as a first-time user." 
+ ) # Import entrypoint after installing all requirements try: From 54bfb770af25ec65dc8160e3ce17e4cba0d7e92e Mon Sep 17 00:00:00 2001 From: kingbri Date: Tue, 10 Sep 2024 12:22:07 -0400 Subject: [PATCH 27/34] API: Fix template switch endpoint Forwards a Path instead of a string and adheres to the new pathfinding system. Signed-off-by: kingbri --- endpoints/core/router.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/endpoints/core/router.py b/endpoints/core/router.py index cc9af24..4f6b441 100644 --- a/endpoints/core/router.py +++ b/endpoints/core/router.py @@ -445,7 +445,8 @@ async def switch_template(data: TemplateSwitchRequest): raise HTTPException(400, error_message) try: - model.container.prompt_template = PromptTemplate.from_file(data.name) + template_path = pathlib.Path("templates") / data.name + model.container.prompt_template = PromptTemplate.from_file(template_path) except FileNotFoundError as e: error_message = handle_request_error( f"The template name {data.name} doesn't exist. Check the spelling?", From 2c3bc71afaf94fecc80680ddbdf4a289ae9b0944 Mon Sep 17 00:00:00 2001 From: kingbri Date: Tue, 10 Sep 2024 16:45:14 -0400 Subject: [PATCH 28/34] Tree: Switch to asynchronous file handling Using aiofiles, there's no longer a possiblity of blocking file operations that can hang up the event loop. In addition, partially migrate classes to use asynchronous init instead of the normal python magic method. The only exception is config, since that's handled in the synchonous init before the event loop starts. Signed-off-by: kingbri --- backends/exllamav2/model.py | 33 ++++++++++++++++++++++----------- common/auth.py | 15 ++++++++++----- common/model.py | 2 +- common/sampling.py | 8 +++++--- common/tabby_config.py | 2 +- common/templating.py | 17 +++++++++++------ common/transformers_utils.py | 14 +++++++++----- endpoints/core/router.py | 4 ++-- main.py | 4 ++-- 9 files changed, 63 insertions(+), 36 deletions(-) diff --git a/backends/exllamav2/model.py b/backends/exllamav2/model.py index 6e0a8cc..4aedf75 100644 --- a/backends/exllamav2/model.py +++ b/backends/exllamav2/model.py @@ -1,5 +1,6 @@ """The model container class for ExLlamaV2 models.""" +import aiofiles import asyncio import gc import math @@ -106,13 +107,17 @@ class ExllamaV2Container: load_lock: asyncio.Lock = asyncio.Lock() load_condition: asyncio.Condition = asyncio.Condition() - def __init__(self, model_directory: pathlib.Path, quiet=False, **kwargs): + @classmethod + async def create(cls, model_directory: pathlib.Path, quiet=False, **kwargs): """ - Primary initializer for model container. + Primary asynchronous initializer for model container. 
Kwargs are located in config_sample.yml """ + # Create a new instance as a "fake self" + self = cls() + self.quiet = quiet # Initialize config @@ -155,13 +160,13 @@ class ExllamaV2Container: self.draft_config.prepare() # Create the hf_config - self.hf_config = HuggingFaceConfig.from_file(model_directory) + self.hf_config = await HuggingFaceConfig.from_file(model_directory) # Load generation config overrides generation_config_path = model_directory / "generation_config.json" if generation_config_path.exists(): try: - self.generation_config = GenerationConfig.from_file( + self.generation_config = await GenerationConfig.from_file( generation_config_path.parent ) except Exception: @@ -171,7 +176,7 @@ class ExllamaV2Container: ) # Apply a model's config overrides while respecting user settings - kwargs = self.set_model_overrides(**kwargs) + kwargs = await self.set_model_overrides(**kwargs) # MARK: User configuration @@ -320,7 +325,7 @@ class ExllamaV2Container: self.cache_size = self.config.max_seq_len # Try to set prompt template - self.prompt_template = self.find_prompt_template( + self.prompt_template = await self.find_prompt_template( kwargs.get("prompt_template"), model_directory ) @@ -373,7 +378,10 @@ class ExllamaV2Container: self.draft_config.max_input_len = chunk_size self.draft_config.max_attention_size = chunk_size**2 - def set_model_overrides(self, **kwargs): + # Return the created instance + return self + + async def set_model_overrides(self, **kwargs): """Sets overrides from a model folder's config yaml.""" override_config_path = self.model_dir / "tabby_config.yml" @@ -381,8 +389,11 @@ class ExllamaV2Container: if not override_config_path.exists(): return kwargs - with open(override_config_path, "r", encoding="utf8") as override_config_file: - override_args = unwrap(yaml.safe_load(override_config_file), {}) + async with aiofiles.open( + override_config_path, "r", encoding="utf8" + ) as override_config_file: + contents = await override_config_file.read() + override_args = unwrap(yaml.safe_load(contents), {}) # Merge draft overrides beforehand draft_override_args = unwrap(override_args.get("draft"), {}) @@ -393,7 +404,7 @@ class ExllamaV2Container: merged_kwargs = {**override_args, **kwargs} return merged_kwargs - def find_prompt_template(self, prompt_template_name, model_directory): + async def find_prompt_template(self, prompt_template_name, model_directory): """Tries to find a prompt template using various methods.""" logger.info("Attempting to load a prompt template if present.") @@ -431,7 +442,7 @@ class ExllamaV2Container: # Continue on exception since functions are tried as they fail for template_func in find_template_functions: try: - prompt_template = template_func() + prompt_template = await template_func() if prompt_template is not None: return prompt_template except TemplateLoadError as e: diff --git a/common/auth.py b/common/auth.py index 174208d..6fcfec9 100644 --- a/common/auth.py +++ b/common/auth.py @@ -3,6 +3,7 @@ This method of authorization is pretty insecure, but since TabbyAPI is a local application, it should be fine. """ +import aiofiles import secrets import yaml from fastapi import Header, HTTPException, Request @@ -40,7 +41,7 @@ AUTH_KEYS: Optional[AuthKeys] = None DISABLE_AUTH: bool = False -def load_auth_keys(disable_from_config: bool): +async def load_auth_keys(disable_from_config: bool): """Load the authentication keys from api_tokens.yml. 
If the file does not exist, generate new keys and save them to api_tokens.yml.""" global AUTH_KEYS @@ -57,8 +58,9 @@ def load_auth_keys(disable_from_config: bool): return try: - with open("api_tokens.yml", "r", encoding="utf8") as auth_file: - auth_keys_dict = yaml.safe_load(auth_file) + async with aiofiles.open("api_tokens.yml", "r", encoding="utf8") as auth_file: + contents = await auth_file.read() + auth_keys_dict = yaml.safe_load(contents) AUTH_KEYS = AuthKeys.model_validate(auth_keys_dict) except FileNotFoundError: new_auth_keys = AuthKeys( @@ -66,8 +68,11 @@ def load_auth_keys(disable_from_config: bool): ) AUTH_KEYS = new_auth_keys - with open("api_tokens.yml", "w", encoding="utf8") as auth_file: - yaml.safe_dump(AUTH_KEYS.model_dump(), auth_file, default_flow_style=False) + async with aiofiles.open("api_tokens.yml", "w", encoding="utf8") as auth_file: + new_auth_yaml = yaml.safe_dump( + AUTH_KEYS.model_dump(), default_flow_style=False + ) + await auth_file.write(new_auth_yaml) logger.info( f"Your API key is: {AUTH_KEYS.api_key}\n" diff --git a/common/model.py b/common/model.py index a9ddfff..a1f29b5 100644 --- a/common/model.py +++ b/common/model.py @@ -67,7 +67,7 @@ async def load_model_gen(model_path: pathlib.Path, **kwargs): logger.info("Unloading existing model.") await unload_model() - container = ExllamaV2Container(model_path.resolve(), False, **kwargs) + container = await ExllamaV2Container.create(model_path.resolve(), False, **kwargs) model_type = "draft" if container.draft_config else "model" load_status = container.load_gen(load_progress, **kwargs) diff --git a/common/sampling.py b/common/sampling.py index eab2a4c..a7da3ca 100644 --- a/common/sampling.py +++ b/common/sampling.py @@ -1,5 +1,6 @@ """Common functions for sampling parameters""" +import aiofiles import json import pathlib import yaml @@ -407,14 +408,15 @@ def overrides_from_dict(new_overrides: dict): raise TypeError("New sampler overrides must be a dict!") -def overrides_from_file(preset_name: str): +async def overrides_from_file(preset_name: str): """Fetches an override preset from a file""" preset_path = pathlib.Path(f"sampler_overrides/{preset_name}.yml") if preset_path.exists(): overrides_container.selected_preset = preset_path.stem - with open(preset_path, "r", encoding="utf8") as raw_preset: - preset = yaml.safe_load(raw_preset) + async with aiofiles.open(preset_path, "r", encoding="utf8") as raw_preset: + contents = await raw_preset.read() + preset = yaml.safe_load(contents) overrides_from_dict(preset) logger.info("Applied sampler overrides from file.") diff --git a/common/tabby_config.py b/common/tabby_config.py index efde051..c49df91 100644 --- a/common/tabby_config.py +++ b/common/tabby_config.py @@ -17,7 +17,7 @@ class TabbyConfig: embeddings: dict = {} def load(self, arguments: Optional[dict] = None): - """load the global application config""" + """Synchronously loads the global application config""" # config is applied in order of items in the list configs = [ diff --git a/common/templating.py b/common/templating.py index d515cf8..2c0e5e2 100644 --- a/common/templating.py +++ b/common/templating.py @@ -1,5 +1,6 @@ """Small replication of AutoTokenizer's chat template system for efficiency""" +import aiofiles import json import pathlib from importlib.metadata import version as package_version @@ -110,7 +111,7 @@ class PromptTemplate: self.template = self.compile(raw_template) @classmethod - def from_file(self, template_path: pathlib.Path): + async def from_file(self, template_path: pathlib.Path): 
"""Get a template from a jinja file.""" # Add the jinja extension if it isn't provided @@ -121,10 +122,13 @@ class PromptTemplate: template_path = template_path.with_suffix(".jinja") if template_path.exists(): - with open(template_path, "r", encoding="utf8") as raw_template_stream: + async with aiofiles.open( + template_path, "r", encoding="utf8" + ) as raw_template_stream: + contents = await raw_template_stream.read() return PromptTemplate( name=template_name, - raw_template=raw_template_stream.read(), + raw_template=contents, ) else: # Let the user know if the template file isn't found @@ -133,15 +137,16 @@ class PromptTemplate: ) @classmethod - def from_model_json( + async def from_model_json( self, json_path: pathlib.Path, key: str, name: Optional[str] = None ): """Get a template from a JSON file. Requires a key and template name""" if not json_path.exists(): raise TemplateLoadError(f'Model JSON path "{json_path}" not found.') - with open(json_path, "r", encoding="utf8") as config_file: - model_config = json.load(config_file) + async with aiofiles.open(json_path, "r", encoding="utf8") as config_file: + contents = await config_file.read() + model_config = json.loads(contents) chat_template = model_config.get(key) if not chat_template: diff --git a/common/transformers_utils.py b/common/transformers_utils.py index 9db8ad2..386f543 100644 --- a/common/transformers_utils.py +++ b/common/transformers_utils.py @@ -1,3 +1,4 @@ +import aiofiles import json import pathlib from typing import List, Optional, Union @@ -15,11 +16,11 @@ class GenerationConfig(BaseModel): bad_words_ids: Optional[List[List[int]]] = None @classmethod - def from_file(self, model_directory: pathlib.Path): + async def from_file(self, model_directory: pathlib.Path): """Create an instance from a generation config file.""" generation_config_path = model_directory / "generation_config.json" - with open( + async with aiofiles.open( generation_config_path, "r", encoding="utf8" ) as generation_config_json: generation_config_dict = json.load(generation_config_json) @@ -43,12 +44,15 @@ class HuggingFaceConfig(BaseModel): badwordsids: Optional[str] = None @classmethod - def from_file(self, model_directory: pathlib.Path): + async def from_file(self, model_directory: pathlib.Path): """Create an instance from a generation config file.""" hf_config_path = model_directory / "config.json" - with open(hf_config_path, "r", encoding="utf8") as hf_config_json: - hf_config_dict = json.load(hf_config_json) + async with aiofiles.open( + hf_config_path, "r", encoding="utf8" + ) as hf_config_json: + contents = await hf_config_json.read() + hf_config_dict = json.loads(contents) return self.model_validate(hf_config_dict) def get_badwordsids(self): diff --git a/endpoints/core/router.py b/endpoints/core/router.py index 4f6b441..2d7a139 100644 --- a/endpoints/core/router.py +++ b/endpoints/core/router.py @@ -446,7 +446,7 @@ async def switch_template(data: TemplateSwitchRequest): try: template_path = pathlib.Path("templates") / data.name - model.container.prompt_template = PromptTemplate.from_file(template_path) + model.container.prompt_template = await PromptTemplate.from_file(template_path) except FileNotFoundError as e: error_message = handle_request_error( f"The template name {data.name} doesn't exist. 
Check the spelling?", @@ -495,7 +495,7 @@ async def switch_sampler_override(data: SamplerOverrideSwitchRequest): if data.preset: try: - sampling.overrides_from_file(data.preset) + await sampling.overrides_from_file(data.preset) except FileNotFoundError as e: error_message = handle_request_error( f"Sampler override preset with name {data.preset} does not exist. " diff --git a/main.py b/main.py index 740e1d0..7c20910 100644 --- a/main.py +++ b/main.py @@ -50,7 +50,7 @@ async def entrypoint_async(): port = fallback_port # Initialize auth keys - load_auth_keys(unwrap(config.network.get("disable_auth"), False)) + await load_auth_keys(unwrap(config.network.get("disable_auth"), False)) # Override the generation log options if given if config.logging: @@ -62,7 +62,7 @@ async def entrypoint_async(): sampling_override_preset = config.sampling.get("override_preset") if sampling_override_preset: try: - sampling.overrides_from_file(sampling_override_preset) + await sampling.overrides_from_file(sampling_override_preset) except FileNotFoundError as e: logger.warning(str(e)) From 5e8ff9a00435c76a5c742c590b04c1ad7a76706e Mon Sep 17 00:00:00 2001 From: kingbri Date: Tue, 10 Sep 2024 20:52:29 -0400 Subject: [PATCH 29/34] Tree: Fix classmethod usage Instead of self, use cls which passes a type of the class. Signed-off-by: kingbri --- common/templating.py | 8 ++++---- common/transformers_utils.py | 8 ++++---- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/common/templating.py b/common/templating.py index 2c0e5e2..1200e5c 100644 --- a/common/templating.py +++ b/common/templating.py @@ -111,7 +111,7 @@ class PromptTemplate: self.template = self.compile(raw_template) @classmethod - async def from_file(self, template_path: pathlib.Path): + async def from_file(cls, template_path: pathlib.Path): """Get a template from a jinja file.""" # Add the jinja extension if it isn't provided @@ -126,7 +126,7 @@ class PromptTemplate: template_path, "r", encoding="utf8" ) as raw_template_stream: contents = await raw_template_stream.read() - return PromptTemplate( + return cls( name=template_name, raw_template=contents, ) @@ -138,7 +138,7 @@ class PromptTemplate: @classmethod async def from_model_json( - self, json_path: pathlib.Path, key: str, name: Optional[str] = None + cls, json_path: pathlib.Path, key: str, name: Optional[str] = None ): """Get a template from a JSON file. 
Requires a key and template name""" if not json_path.exists(): @@ -177,7 +177,7 @@ class PromptTemplate: ) else: # Can safely assume the chat template is the old style - return PromptTemplate( + return cls( name="from_tokenizer_config", raw_template=chat_template, ) diff --git a/common/transformers_utils.py b/common/transformers_utils.py index 386f543..4fd848d 100644 --- a/common/transformers_utils.py +++ b/common/transformers_utils.py @@ -16,7 +16,7 @@ class GenerationConfig(BaseModel): bad_words_ids: Optional[List[List[int]]] = None @classmethod - async def from_file(self, model_directory: pathlib.Path): + async def from_file(cls, model_directory: pathlib.Path): """Create an instance from a generation config file.""" generation_config_path = model_directory / "generation_config.json" @@ -24,7 +24,7 @@ class GenerationConfig(BaseModel): generation_config_path, "r", encoding="utf8" ) as generation_config_json: generation_config_dict = json.load(generation_config_json) - return self.model_validate(generation_config_dict) + return cls.model_validate(generation_config_dict) def eos_tokens(self): """Wrapper method to fetch EOS tokens.""" @@ -44,7 +44,7 @@ class HuggingFaceConfig(BaseModel): badwordsids: Optional[str] = None @classmethod - async def from_file(self, model_directory: pathlib.Path): + async def from_file(cls, model_directory: pathlib.Path): """Create an instance from a generation config file.""" hf_config_path = model_directory / "config.json" @@ -53,7 +53,7 @@ class HuggingFaceConfig(BaseModel): ) as hf_config_json: contents = await hf_config_json.read() hf_config_dict = json.loads(contents) - return self.model_validate(hf_config_dict) + return cls.model_validate(hf_config_dict) def get_badwordsids(self): """Wrapper method to fetch badwordsids.""" From aa832b86276558e627a7feb2297ae9e01344b9ec Mon Sep 17 00:00:00 2001 From: kingbri Date: Tue, 10 Sep 2024 20:57:13 -0400 Subject: [PATCH 30/34] Tree: Format Signed-off-by: kingbri --- start.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/start.py b/start.py index 844976f..7e7776d 100644 --- a/start.py +++ b/start.py @@ -239,7 +239,8 @@ if __name__ == "__main__": start_file.write(json.dumps(start_options)) print( - "Successfully wrote your start script options to `start_options.json`. \n" + "Successfully wrote your start script options to " + "`start_options.json`. \n" "If something goes wrong, editing or deleting the file " "will reinstall TabbyAPI as a first-time user." 
) From 62beb2b1c87bf1745d184a2532dafebd285034c7 Mon Sep 17 00:00:00 2001 From: kingbri Date: Tue, 10 Sep 2024 21:30:53 -0400 Subject: [PATCH 31/34] Config: Fetch the correct dict for draft_model and lora Fixed fetching from the merged config instead of the sub-config Signed-off-by: kingbri --- common/tabby_config.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/common/tabby_config.py b/common/tabby_config.py index c49df91..215f9d0 100644 --- a/common/tabby_config.py +++ b/common/tabby_config.py @@ -30,8 +30,8 @@ class TabbyConfig: self.network = unwrap(merged_config.get("network"), {}) self.logging = unwrap(merged_config.get("logging"), {}) self.model = unwrap(merged_config.get("model"), {}) - self.draft_model = unwrap(merged_config.get("draft"), {}) - self.lora = unwrap(merged_config.get("draft"), {}) + self.draft_model = unwrap(self.model.get("draft"), {}) + self.lora = unwrap(self.model.get("lora"), {}) self.sampling = unwrap(merged_config.get("sampling"), {}) self.developer = unwrap(merged_config.get("developer"), {}) self.embeddings = unwrap(merged_config.get("embeddings"), {}) From 7baef05b491cd85e2faf4b06faae52d7c757d2c0 Mon Sep 17 00:00:00 2001 From: kingbri Date: Tue, 10 Sep 2024 22:41:39 -0400 Subject: [PATCH 32/34] Transformers Utils: Fix file read Use asynchronous JSON reading Signed-off-by: kingbri --- common/transformers_utils.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/common/transformers_utils.py b/common/transformers_utils.py index 4fd848d..c00fef4 100644 --- a/common/transformers_utils.py +++ b/common/transformers_utils.py @@ -23,7 +23,8 @@ class GenerationConfig(BaseModel): async with aiofiles.open( generation_config_path, "r", encoding="utf8" ) as generation_config_json: - generation_config_dict = json.load(generation_config_json) + contents = await generation_config_json.read() + generation_config_dict = json.loads(contents) return cls.model_validate(generation_config_dict) def eos_tokens(self): From b9e5693c1b53dce3043654cf20fff807475747eb Mon Sep 17 00:00:00 2001 From: kingbri Date: Tue, 10 Sep 2024 23:35:35 -0400 Subject: [PATCH 33/34] API + Model: Apply config.yml defaults for all load paths There are two ways to load a model: 1. Via the load endpoint 2. Inline with a completion The defaults were not applying on the inline load, so rewrite to fix that. However, while doing this, set up a defaults dictionary rather than comparing it at runtime and remove the pydantic default lambda on all the model load fields. This makes the code cleaner and establishes a clear config tree for loading models. 
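
As a rough illustration (not part of this patch, values are hypothetical),
the precedence is a plain dict merge where per-request kwargs win over the
defaults dictionary built once when the config loads:

    # Hypothetical values for illustration only
    config_model_defaults = {"cache_mode": "Q4", "chunk_size": 2048}
    request_kwargs = {"cache_mode": "FP16"}

    # Same merge order as load_model_gen: request fields override config defaults
    merged = {**config_model_defaults, **request_kwargs}
    assert merged == {"cache_mode": "FP16", "chunk_size": 2048}
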
Signed-off-by: kingbri --- common/model.py | 24 ++---------- common/tabby_config.py | 14 +++++++ endpoints/core/types/model.py | 70 ++++++++++++----------------------- 3 files changed, 41 insertions(+), 67 deletions(-) diff --git a/common/model.py b/common/model.py index a1f29b5..5fdfc5b 100644 --- a/common/model.py +++ b/common/model.py @@ -13,7 +13,6 @@ from typing import Optional from common.logger import get_loading_progress_bar from common.networking import handle_request_error from common.tabby_config import config -from common.utils import unwrap from endpoints.utils import do_export_openapi if not do_export_openapi: @@ -67,6 +66,10 @@ async def load_model_gen(model_path: pathlib.Path, **kwargs): logger.info("Unloading existing model.") await unload_model() + # Merge with config defaults + kwargs = {**config.model_defaults, **kwargs} + + # Create a new container container = await ExllamaV2Container.create(model_path.resolve(), False, **kwargs) model_type = "draft" if container.draft_config else "model" @@ -149,25 +152,6 @@ async def unload_embedding_model(): embeddings_container = None -# FIXME: Maybe make this a one-time function instead of a dynamic default -def get_config_default(key: str, model_type: str = "model"): - """Fetches a default value from model config if allowed by the user.""" - - default_keys = unwrap(config.model.get("use_as_default"), []) - - # Add extra keys to defaults - default_keys.append("embeddings_device") - - if key in default_keys: - # Is this a draft model load parameter? - if model_type == "draft": - return config.draft_model.get(key) - elif model_type == "embedding": - return config.embeddings.get(key) - else: - return config.model.get(key) - - async def check_model_container(): """FastAPI depends that checks if a model isn't loaded or currently loading.""" diff --git a/common/tabby_config.py b/common/tabby_config.py index 215f9d0..704e3ba 100644 --- a/common/tabby_config.py +++ b/common/tabby_config.py @@ -7,6 +7,9 @@ from common.utils import unwrap, merge_dicts class TabbyConfig: + """Common config class for TabbyAPI. 
Loaded into sub-dictionaries from YAML file.""" + + # Sub-blocks of yaml network: dict = {} logging: dict = {} model: dict = {} @@ -16,6 +19,9 @@ class TabbyConfig: developer: dict = {} embeddings: dict = {} + # Persistent defaults + model_defaults: dict = {} + def load(self, arguments: Optional[dict] = None): """Synchronously loads the global application config""" @@ -36,6 +42,14 @@ class TabbyConfig: self.developer = unwrap(merged_config.get("developer"), {}) self.embeddings = unwrap(merged_config.get("embeddings"), {}) + # Set model defaults dict once to prevent on-demand reconstruction + default_keys = unwrap(self.model.get("use_as_default"), []) + for key in default_keys: + if key in self.model: + self.model_defaults[key] = config.model[key] + elif key in self.draft_model: + self.model_defaults[key] = config.draft_model[key] + def _from_file(self, config_path: pathlib.Path): """loads config from a given file path""" diff --git a/endpoints/core/types/model.py b/endpoints/core/types/model.py index 154a906..e8bd882 100644 --- a/endpoints/core/types/model.py +++ b/endpoints/core/types/model.py @@ -5,7 +5,8 @@ from time import time from typing import List, Literal, Optional, Union from common.gen_logging import GenLogPreferences -from common.model import get_config_default +from common.tabby_config import config +from common.utils import unwrap class ModelCardParameters(BaseModel): @@ -51,23 +52,13 @@ class DraftModelLoadRequest(BaseModel): draft_model_name: str # Config arguments - draft_rope_scale: Optional[float] = Field( - default_factory=lambda: get_config_default( - "draft_rope_scale", model_type="draft" - ) - ) + draft_rope_scale: Optional[float] = None draft_rope_alpha: Optional[Union[float, Literal["auto"]]] = Field( description='Automatically calculated if set to "auto"', - default_factory=lambda: get_config_default( - "draft_rope_alpha", model_type="draft" - ), + default=None, examples=[1.0], ) - draft_cache_mode: Optional[str] = Field( - default_factory=lambda: get_config_default( - "draft_cache_mode", model_type="draft" - ) - ) + draft_cache_mode: Optional[str] = None class ModelLoadRequest(BaseModel): @@ -78,62 +69,45 @@ class ModelLoadRequest(BaseModel): # Config arguments - # Max seq len is fetched from config.json of the model by default max_seq_len: Optional[int] = Field( description="Leave this blank to use the model's base sequence length", - default_factory=lambda: get_config_default("max_seq_len"), + default=None, examples=[4096], ) override_base_seq_len: Optional[int] = Field( description=( "Overrides the model's base sequence length. 
" "Leave blank if unsure" ), - default_factory=lambda: get_config_default("override_base_seq_len"), + default=None, examples=[4096], ) cache_size: Optional[int] = Field( description=("Number in tokens, must be greater than or equal to max_seq_len"), - default_factory=lambda: get_config_default("cache_size"), + default=None, examples=[4096], ) - tensor_parallel: Optional[bool] = Field( - default_factory=lambda: get_config_default("tensor_parallel") - ) - gpu_split_auto: Optional[bool] = Field( - default_factory=lambda: get_config_default("gpu_split_auto") - ) - autosplit_reserve: Optional[List[float]] = Field( - default_factory=lambda: get_config_default("autosplit_reserve") - ) + tensor_parallel: Optional[bool] = None + gpu_split_auto: Optional[bool] = None + autosplit_reserve: Optional[List[float]] = None gpu_split: Optional[List[float]] = Field( - default_factory=lambda: get_config_default("gpu_split"), + default=None, examples=[[24.0, 20.0]], ) rope_scale: Optional[float] = Field( description="Automatically pulled from the model's config if not present", - default_factory=lambda: get_config_default("rope_scale"), + default=None, examples=[1.0], ) rope_alpha: Optional[Union[float, Literal["auto"]]] = Field( description='Automatically calculated if set to "auto"', - default_factory=lambda: get_config_default("rope_alpha"), + default=None, examples=[1.0], ) - cache_mode: Optional[str] = Field( - default_factory=lambda: get_config_default("cache_mode") - ) - chunk_size: Optional[int] = Field( - default_factory=lambda: get_config_default("chunk_size") - ) - prompt_template: Optional[str] = Field( - default_factory=lambda: get_config_default("prompt_template") - ) - num_experts_per_token: Optional[int] = Field( - default_factory=lambda: get_config_default("num_experts_per_token") - ) - fasttensors: Optional[bool] = Field( - default_factory=lambda: get_config_default("fasttensors") - ) + cache_mode: Optional[str] = None + chunk_size: Optional[int] = None + prompt_template: Optional[str] = None + num_experts_per_token: Optional[int] = None + fasttensors: Optional[bool] = None # Non-config arguments draft: Optional[DraftModelLoadRequest] = None @@ -142,9 +116,11 @@ class ModelLoadRequest(BaseModel): class EmbeddingModelLoadRequest(BaseModel): name: str + + # Set default from the config embeddings_device: Optional[str] = Field( - default_factory=lambda: get_config_default( - "embeddings_device", model_type="embedding" + default_factory=lambda: unwrap( + config.embeddings.get("embeddings_device"), "cpu" ) ) From e00eb09ef397c253286201da7ff6d1ef0fb431c4 Mon Sep 17 00:00:00 2001 From: kingbri Date: Wed, 11 Sep 2024 00:08:55 -0400 Subject: [PATCH 34/34] OAI: Add cancellation with inline load When the request is cancelled, cancel the load task. In addition, when checking if a model container exists, also check if the model is fully loaded. 
Signed-off-by: kingbri --- endpoints/OAI/router.py | 9 ++++++++- endpoints/OAI/utils/completion.py | 8 ++++++-- 2 files changed, 14 insertions(+), 3 deletions(-) diff --git a/endpoints/OAI/router.py b/endpoints/OAI/router.py index 12c95a2..0b7c1a6 100644 --- a/endpoints/OAI/router.py +++ b/endpoints/OAI/router.py @@ -55,7 +55,14 @@ async def completion_request( """ if data.model: - await load_inline_model(data.model, request) + inline_load_task = asyncio.create_task(load_inline_model(data.model, request)) + + await run_with_request_disconnect( + request, + inline_load_task, + disconnect_message=f"Model switch for generation {request.state.id} " + + "cancelled by user.", + ) else: await check_model_container() diff --git a/endpoints/OAI/utils/completion.py b/endpoints/OAI/utils/completion.py index 2f51175..df4bf19 100644 --- a/endpoints/OAI/utils/completion.py +++ b/endpoints/OAI/utils/completion.py @@ -112,8 +112,12 @@ async def _stream_collector( async def load_inline_model(model_name: str, request: Request): """Load a model from the data.model parameter""" - # Return if the model container already exists - if model.container and model.container.model_dir.name == model_name: + # Return if the model container already exists and the model is fully loaded + if ( + model.container + and model.container.model_dir.name == model_name + and model.container.model_loaded + ): return # Inline model loading isn't enabled or the user isn't an admin