From 2c3bc71afaf94fecc80680ddbdf4a289ae9b0944 Mon Sep 17 00:00:00 2001 From: kingbri Date: Tue, 10 Sep 2024 16:45:14 -0400 Subject: [PATCH] Tree: Switch to asynchronous file handling Using aiofiles, there's no longer a possibility of blocking file operations that can hang up the event loop. In addition, partially migrate classes to use asynchronous init instead of the normal python magic method. The only exception is config, since that's handled in the synchronous init before the event loop starts. Signed-off-by: kingbri --- backends/exllamav2/model.py | 33 ++++++++++++++++++++++----------- common/auth.py | 15 ++++++++++----- common/model.py | 2 +- common/sampling.py | 8 +++++--- common/tabby_config.py | 2 +- common/templating.py | 17 +++++++++++------ common/transformers_utils.py | 14 +++++++++----- endpoints/core/router.py | 4 ++-- main.py | 4 ++-- 9 files changed, 63 insertions(+), 36 deletions(-) diff --git a/backends/exllamav2/model.py b/backends/exllamav2/model.py index 6e0a8cc..4aedf75 100644 --- a/backends/exllamav2/model.py +++ b/backends/exllamav2/model.py @@ -1,5 +1,6 @@ """The model container class for ExLlamaV2 models.""" +import aiofiles import asyncio import gc import math @@ -106,13 +107,17 @@ class ExllamaV2Container: load_lock: asyncio.Lock = asyncio.Lock() load_condition: asyncio.Condition = asyncio.Condition() - def __init__(self, model_directory: pathlib.Path, quiet=False, **kwargs): + @classmethod + async def create(cls, model_directory: pathlib.Path, quiet=False, **kwargs): """ - Primary initializer for model container. + Primary asynchronous initializer for model container. 
Kwargs are located in config_sample.yml """ + # Create a new instance as a "fake self" + self = cls() + self.quiet = quiet # Initialize config @@ -155,13 +160,13 @@ class ExllamaV2Container: self.draft_config.prepare() # Create the hf_config - self.hf_config = HuggingFaceConfig.from_file(model_directory) + self.hf_config = await HuggingFaceConfig.from_file(model_directory) # Load generation config overrides generation_config_path = model_directory / "generation_config.json" if generation_config_path.exists(): try: - self.generation_config = GenerationConfig.from_file( + self.generation_config = await GenerationConfig.from_file( generation_config_path.parent ) except Exception: @@ -171,7 +176,7 @@ class ExllamaV2Container: ) # Apply a model's config overrides while respecting user settings - kwargs = self.set_model_overrides(**kwargs) + kwargs = await self.set_model_overrides(**kwargs) # MARK: User configuration @@ -320,7 +325,7 @@ class ExllamaV2Container: self.cache_size = self.config.max_seq_len # Try to set prompt template - self.prompt_template = self.find_prompt_template( + self.prompt_template = await self.find_prompt_template( kwargs.get("prompt_template"), model_directory ) @@ -373,7 +378,10 @@ class ExllamaV2Container: self.draft_config.max_input_len = chunk_size self.draft_config.max_attention_size = chunk_size**2 - def set_model_overrides(self, **kwargs): + # Return the created instance + return self + + async def set_model_overrides(self, **kwargs): """Sets overrides from a model folder's config yaml.""" override_config_path = self.model_dir / "tabby_config.yml" @@ -381,8 +389,11 @@ class ExllamaV2Container: if not override_config_path.exists(): return kwargs - with open(override_config_path, "r", encoding="utf8") as override_config_file: - override_args = unwrap(yaml.safe_load(override_config_file), {}) + async with aiofiles.open( + override_config_path, "r", encoding="utf8" + ) as override_config_file: + contents = await override_config_file.read() + 
override_args = unwrap(yaml.safe_load(contents), {}) # Merge draft overrides beforehand draft_override_args = unwrap(override_args.get("draft"), {}) @@ -393,7 +404,7 @@ class ExllamaV2Container: merged_kwargs = {**override_args, **kwargs} return merged_kwargs - def find_prompt_template(self, prompt_template_name, model_directory): + async def find_prompt_template(self, prompt_template_name, model_directory): """Tries to find a prompt template using various methods.""" logger.info("Attempting to load a prompt template if present.") @@ -431,7 +442,7 @@ class ExllamaV2Container: # Continue on exception since functions are tried as they fail for template_func in find_template_functions: try: - prompt_template = template_func() + prompt_template = await template_func() if prompt_template is not None: return prompt_template except TemplateLoadError as e: diff --git a/common/auth.py b/common/auth.py index 174208d..6fcfec9 100644 --- a/common/auth.py +++ b/common/auth.py @@ -3,6 +3,7 @@ This method of authorization is pretty insecure, but since TabbyAPI is a local application, it should be fine. """ +import aiofiles import secrets import yaml from fastapi import Header, HTTPException, Request @@ -40,7 +41,7 @@ AUTH_KEYS: Optional[AuthKeys] = None DISABLE_AUTH: bool = False -def load_auth_keys(disable_from_config: bool): +async def load_auth_keys(disable_from_config: bool): """Load the authentication keys from api_tokens.yml. 
If the file does not exist, generate new keys and save them to api_tokens.yml.""" global AUTH_KEYS @@ -57,8 +58,9 @@ def load_auth_keys(disable_from_config: bool): return try: - with open("api_tokens.yml", "r", encoding="utf8") as auth_file: - auth_keys_dict = yaml.safe_load(auth_file) + async with aiofiles.open("api_tokens.yml", "r", encoding="utf8") as auth_file: + contents = await auth_file.read() + auth_keys_dict = yaml.safe_load(contents) AUTH_KEYS = AuthKeys.model_validate(auth_keys_dict) except FileNotFoundError: new_auth_keys = AuthKeys( @@ -66,8 +68,11 @@ def load_auth_keys(disable_from_config: bool): ) AUTH_KEYS = new_auth_keys - with open("api_tokens.yml", "w", encoding="utf8") as auth_file: - yaml.safe_dump(AUTH_KEYS.model_dump(), auth_file, default_flow_style=False) + async with aiofiles.open("api_tokens.yml", "w", encoding="utf8") as auth_file: + new_auth_yaml = yaml.safe_dump( + AUTH_KEYS.model_dump(), default_flow_style=False + ) + await auth_file.write(new_auth_yaml) logger.info( f"Your API key is: {AUTH_KEYS.api_key}\n" diff --git a/common/model.py b/common/model.py index a9ddfff..a1f29b5 100644 --- a/common/model.py +++ b/common/model.py @@ -67,7 +67,7 @@ async def load_model_gen(model_path: pathlib.Path, **kwargs): logger.info("Unloading existing model.") await unload_model() - container = ExllamaV2Container(model_path.resolve(), False, **kwargs) + container = await ExllamaV2Container.create(model_path.resolve(), False, **kwargs) model_type = "draft" if container.draft_config else "model" load_status = container.load_gen(load_progress, **kwargs) diff --git a/common/sampling.py b/common/sampling.py index eab2a4c..a7da3ca 100644 --- a/common/sampling.py +++ b/common/sampling.py @@ -1,5 +1,6 @@ """Common functions for sampling parameters""" +import aiofiles import json import pathlib import yaml @@ -407,14 +408,15 @@ def overrides_from_dict(new_overrides: dict): raise TypeError("New sampler overrides must be a dict!") -def 
overrides_from_file(preset_name: str): +async def overrides_from_file(preset_name: str): """Fetches an override preset from a file""" preset_path = pathlib.Path(f"sampler_overrides/{preset_name}.yml") if preset_path.exists(): overrides_container.selected_preset = preset_path.stem - with open(preset_path, "r", encoding="utf8") as raw_preset: - preset = yaml.safe_load(raw_preset) + async with aiofiles.open(preset_path, "r", encoding="utf8") as raw_preset: + contents = await raw_preset.read() + preset = yaml.safe_load(contents) overrides_from_dict(preset) logger.info("Applied sampler overrides from file.") diff --git a/common/tabby_config.py b/common/tabby_config.py index efde051..c49df91 100644 --- a/common/tabby_config.py +++ b/common/tabby_config.py @@ -17,7 +17,7 @@ class TabbyConfig: embeddings: dict = {} def load(self, arguments: Optional[dict] = None): - """load the global application config""" + """Synchronously loads the global application config""" # config is applied in order of items in the list configs = [ diff --git a/common/templating.py b/common/templating.py index d515cf8..2c0e5e2 100644 --- a/common/templating.py +++ b/common/templating.py @@ -1,5 +1,6 @@ """Small replication of AutoTokenizer's chat template system for efficiency""" +import aiofiles import json import pathlib from importlib.metadata import version as package_version @@ -110,7 +111,7 @@ class PromptTemplate: self.template = self.compile(raw_template) @classmethod - def from_file(self, template_path: pathlib.Path): + async def from_file(self, template_path: pathlib.Path): """Get a template from a jinja file.""" # Add the jinja extension if it isn't provided @@ -121,10 +122,13 @@ class PromptTemplate: template_path = template_path.with_suffix(".jinja") if template_path.exists(): - with open(template_path, "r", encoding="utf8") as raw_template_stream: + async with aiofiles.open( + template_path, "r", encoding="utf8" + ) as raw_template_stream: + contents = await 
raw_template_stream.read() return PromptTemplate( name=template_name, - raw_template=raw_template_stream.read(), + raw_template=contents, ) else: # Let the user know if the template file isn't found @@ -133,15 +137,16 @@ class PromptTemplate: ) @classmethod - def from_model_json( + async def from_model_json( self, json_path: pathlib.Path, key: str, name: Optional[str] = None ): """Get a template from a JSON file. Requires a key and template name""" if not json_path.exists(): raise TemplateLoadError(f'Model JSON path "{json_path}" not found.') - with open(json_path, "r", encoding="utf8") as config_file: - model_config = json.load(config_file) + async with aiofiles.open(json_path, "r", encoding="utf8") as config_file: + contents = await config_file.read() + model_config = json.loads(contents) chat_template = model_config.get(key) if not chat_template: diff --git a/common/transformers_utils.py b/common/transformers_utils.py index 9db8ad2..386f543 100644 --- a/common/transformers_utils.py +++ b/common/transformers_utils.py @@ -1,3 +1,4 @@ +import aiofiles import json import pathlib from typing import List, Optional, Union @@ -15,11 +16,11 @@ class GenerationConfig(BaseModel): bad_words_ids: Optional[List[List[int]]] = None @classmethod - def from_file(self, model_directory: pathlib.Path): + async def from_file(self, model_directory: pathlib.Path): """Create an instance from a generation config file.""" generation_config_path = model_directory / "generation_config.json" - with open( + async with aiofiles.open( generation_config_path, "r", encoding="utf8" ) as generation_config_json: generation_config_dict = json.load(generation_config_json) @@ -43,12 +44,15 @@ class HuggingFaceConfig(BaseModel): badwordsids: Optional[str] = None @classmethod - def from_file(self, model_directory: pathlib.Path): + async def from_file(self, model_directory: pathlib.Path): """Create an instance from a generation config file.""" hf_config_path = model_directory / "config.json" - with 
open(hf_config_path, "r", encoding="utf8") as hf_config_json: - hf_config_dict = json.load(hf_config_json) + async with aiofiles.open( + hf_config_path, "r", encoding="utf8" + ) as hf_config_json: + contents = await hf_config_json.read() + hf_config_dict = json.loads(contents) return self.model_validate(hf_config_dict) def get_badwordsids(self): diff --git a/endpoints/core/router.py b/endpoints/core/router.py index 4f6b441..2d7a139 100644 --- a/endpoints/core/router.py +++ b/endpoints/core/router.py @@ -446,7 +446,7 @@ async def switch_template(data: TemplateSwitchRequest): try: template_path = pathlib.Path("templates") / data.name - model.container.prompt_template = PromptTemplate.from_file(template_path) + model.container.prompt_template = await PromptTemplate.from_file(template_path) except FileNotFoundError as e: error_message = handle_request_error( f"The template name {data.name} doesn't exist. Check the spelling?", @@ -495,7 +495,7 @@ async def switch_sampler_override(data: SamplerOverrideSwitchRequest): if data.preset: try: - sampling.overrides_from_file(data.preset) + await sampling.overrides_from_file(data.preset) except FileNotFoundError as e: error_message = handle_request_error( f"Sampler override preset with name {data.preset} does not exist. " diff --git a/main.py b/main.py index 740e1d0..7c20910 100644 --- a/main.py +++ b/main.py @@ -50,7 +50,7 @@ async def entrypoint_async(): port = fallback_port # Initialize auth keys - load_auth_keys(unwrap(config.network.get("disable_auth"), False)) + await load_auth_keys(unwrap(config.network.get("disable_auth"), False)) # Override the generation log options if given if config.logging: @@ -62,7 +62,7 @@ async def entrypoint_async(): sampling_override_preset = config.sampling.get("override_preset") if sampling_override_preset: try: - sampling.overrides_from_file(sampling_override_preset) + await sampling.overrides_from_file(sampling_override_preset) except FileNotFoundError as e: logger.warning(str(e))