diff --git a/backends/exllamav2/model.py b/backends/exllamav2/model.py index 6e0a8cc..4aedf75 100644 --- a/backends/exllamav2/model.py +++ b/backends/exllamav2/model.py @@ -1,5 +1,6 @@ """The model container class for ExLlamaV2 models.""" +import aiofiles import asyncio import gc import math @@ -106,13 +107,17 @@ class ExllamaV2Container: load_lock: asyncio.Lock = asyncio.Lock() load_condition: asyncio.Condition = asyncio.Condition() - def __init__(self, model_directory: pathlib.Path, quiet=False, **kwargs): + @classmethod + async def create(cls, model_directory: pathlib.Path, quiet=False, **kwargs): """ - Primary initializer for model container. + Primary asynchronous initializer for model container. Kwargs are located in config_sample.yml """ + # Create a new instance as a "fake self" + self = cls() + self.quiet = quiet # Initialize config @@ -155,13 +160,13 @@ class ExllamaV2Container: self.draft_config.prepare() # Create the hf_config - self.hf_config = HuggingFaceConfig.from_file(model_directory) + self.hf_config = await HuggingFaceConfig.from_file(model_directory) # Load generation config overrides generation_config_path = model_directory / "generation_config.json" if generation_config_path.exists(): try: - self.generation_config = GenerationConfig.from_file( + self.generation_config = await GenerationConfig.from_file( generation_config_path.parent ) except Exception: @@ -171,7 +176,7 @@ class ExllamaV2Container: ) # Apply a model's config overrides while respecting user settings - kwargs = self.set_model_overrides(**kwargs) + kwargs = await self.set_model_overrides(**kwargs) # MARK: User configuration @@ -320,7 +325,7 @@ class ExllamaV2Container: self.cache_size = self.config.max_seq_len # Try to set prompt template - self.prompt_template = self.find_prompt_template( + self.prompt_template = await self.find_prompt_template( kwargs.get("prompt_template"), model_directory ) @@ -373,7 +378,10 @@ class ExllamaV2Container: self.draft_config.max_input_len = 
chunk_size self.draft_config.max_attention_size = chunk_size**2 - def set_model_overrides(self, **kwargs): + # Return the created instance + return self + + async def set_model_overrides(self, **kwargs): """Sets overrides from a model folder's config yaml.""" override_config_path = self.model_dir / "tabby_config.yml" @@ -381,8 +389,11 @@ class ExllamaV2Container: if not override_config_path.exists(): return kwargs - with open(override_config_path, "r", encoding="utf8") as override_config_file: - override_args = unwrap(yaml.safe_load(override_config_file), {}) + async with aiofiles.open( + override_config_path, "r", encoding="utf8" + ) as override_config_file: + contents = await override_config_file.read() + override_args = unwrap(yaml.safe_load(contents), {}) # Merge draft overrides beforehand draft_override_args = unwrap(override_args.get("draft"), {}) @@ -393,7 +404,7 @@ class ExllamaV2Container: merged_kwargs = {**override_args, **kwargs} return merged_kwargs - def find_prompt_template(self, prompt_template_name, model_directory): + async def find_prompt_template(self, prompt_template_name, model_directory): """Tries to find a prompt template using various methods.""" logger.info("Attempting to load a prompt template if present.") @@ -431,7 +442,7 @@ class ExllamaV2Container: # Continue on exception since functions are tried as they fail for template_func in find_template_functions: try: - prompt_template = template_func() + prompt_template = await template_func() if prompt_template is not None: return prompt_template except TemplateLoadError as e: diff --git a/common/auth.py b/common/auth.py index 174208d..6fcfec9 100644 --- a/common/auth.py +++ b/common/auth.py @@ -3,6 +3,7 @@ This method of authorization is pretty insecure, but since TabbyAPI is a local application, it should be fine. 
""" +import aiofiles import secrets import yaml from fastapi import Header, HTTPException, Request @@ -40,7 +41,7 @@ AUTH_KEYS: Optional[AuthKeys] = None DISABLE_AUTH: bool = False -def load_auth_keys(disable_from_config: bool): +async def load_auth_keys(disable_from_config: bool): """Load the authentication keys from api_tokens.yml. If the file does not exist, generate new keys and save them to api_tokens.yml.""" global AUTH_KEYS @@ -57,8 +58,9 @@ def load_auth_keys(disable_from_config: bool): return try: - with open("api_tokens.yml", "r", encoding="utf8") as auth_file: - auth_keys_dict = yaml.safe_load(auth_file) + async with aiofiles.open("api_tokens.yml", "r", encoding="utf8") as auth_file: + contents = await auth_file.read() + auth_keys_dict = yaml.safe_load(contents) AUTH_KEYS = AuthKeys.model_validate(auth_keys_dict) except FileNotFoundError: new_auth_keys = AuthKeys( @@ -66,8 +68,11 @@ def load_auth_keys(disable_from_config: bool): ) AUTH_KEYS = new_auth_keys - with open("api_tokens.yml", "w", encoding="utf8") as auth_file: - yaml.safe_dump(AUTH_KEYS.model_dump(), auth_file, default_flow_style=False) + async with aiofiles.open("api_tokens.yml", "w", encoding="utf8") as auth_file: + new_auth_yaml = yaml.safe_dump( + AUTH_KEYS.model_dump(), default_flow_style=False + ) + await auth_file.write(new_auth_yaml) logger.info( f"Your API key is: {AUTH_KEYS.api_key}\n" diff --git a/common/model.py b/common/model.py index a9ddfff..a1f29b5 100644 --- a/common/model.py +++ b/common/model.py @@ -67,7 +67,7 @@ async def load_model_gen(model_path: pathlib.Path, **kwargs): logger.info("Unloading existing model.") await unload_model() - container = ExllamaV2Container(model_path.resolve(), False, **kwargs) + container = await ExllamaV2Container.create(model_path.resolve(), False, **kwargs) model_type = "draft" if container.draft_config else "model" load_status = container.load_gen(load_progress, **kwargs) diff --git a/common/sampling.py b/common/sampling.py index 
eab2a4c..a7da3ca 100644 --- a/common/sampling.py +++ b/common/sampling.py @@ -1,5 +1,6 @@ """Common functions for sampling parameters""" +import aiofiles import json import pathlib import yaml @@ -407,14 +408,15 @@ def overrides_from_dict(new_overrides: dict): raise TypeError("New sampler overrides must be a dict!") -def overrides_from_file(preset_name: str): +async def overrides_from_file(preset_name: str): """Fetches an override preset from a file""" preset_path = pathlib.Path(f"sampler_overrides/{preset_name}.yml") if preset_path.exists(): overrides_container.selected_preset = preset_path.stem - with open(preset_path, "r", encoding="utf8") as raw_preset: - preset = yaml.safe_load(raw_preset) + async with aiofiles.open(preset_path, "r", encoding="utf8") as raw_preset: + contents = await raw_preset.read() + preset = yaml.safe_load(contents) overrides_from_dict(preset) logger.info("Applied sampler overrides from file.") diff --git a/common/tabby_config.py b/common/tabby_config.py index efde051..c49df91 100644 --- a/common/tabby_config.py +++ b/common/tabby_config.py @@ -17,7 +17,7 @@ class TabbyConfig: embeddings: dict = {} def load(self, arguments: Optional[dict] = None): - """load the global application config""" + """Synchronously loads the global application config""" # config is applied in order of items in the list configs = [ diff --git a/common/templating.py b/common/templating.py index d515cf8..2c0e5e2 100644 --- a/common/templating.py +++ b/common/templating.py @@ -1,5 +1,6 @@ """Small replication of AutoTokenizer's chat template system for efficiency""" +import aiofiles import json import pathlib from importlib.metadata import version as package_version @@ -110,7 +111,7 @@ class PromptTemplate: self.template = self.compile(raw_template) @classmethod - def from_file(self, template_path: pathlib.Path): + async def from_file(self, template_path: pathlib.Path): """Get a template from a jinja file.""" # Add the jinja extension if it isn't provided @@ -121,10 
+122,13 @@ class PromptTemplate: template_path = template_path.with_suffix(".jinja") if template_path.exists(): - with open(template_path, "r", encoding="utf8") as raw_template_stream: + async with aiofiles.open( + template_path, "r", encoding="utf8" + ) as raw_template_stream: + contents = await raw_template_stream.read() return PromptTemplate( name=template_name, - raw_template=raw_template_stream.read(), + raw_template=contents, ) else: # Let the user know if the template file isn't found @@ -133,15 +137,16 @@ class PromptTemplate: ) @classmethod - def from_model_json( + async def from_model_json( self, json_path: pathlib.Path, key: str, name: Optional[str] = None ): """Get a template from a JSON file. Requires a key and template name""" if not json_path.exists(): raise TemplateLoadError(f'Model JSON path "{json_path}" not found.') - with open(json_path, "r", encoding="utf8") as config_file: - model_config = json.load(config_file) + async with aiofiles.open(json_path, "r", encoding="utf8") as config_file: + contents = await config_file.read() + model_config = json.loads(contents) chat_template = model_config.get(key) if not chat_template: diff --git a/common/transformers_utils.py b/common/transformers_utils.py index 9db8ad2..386f543 100644 --- a/common/transformers_utils.py +++ b/common/transformers_utils.py @@ -1,3 +1,4 @@ +import aiofiles import json import pathlib from typing import List, Optional, Union @@ -15,11 +16,12 @@ class GenerationConfig(BaseModel): bad_words_ids: Optional[List[List[int]]] = None @classmethod - def from_file(self, model_directory: pathlib.Path): + async def from_file(self, model_directory: pathlib.Path): """Create an instance from a generation config file.""" generation_config_path = model_directory / "generation_config.json" - with open( + async with aiofiles.open( generation_config_path, "r", encoding="utf8" ) as generation_config_json: - generation_config_dict = json.load(generation_config_json) + contents = await generation_config_json.read() + generation_config_dict = json.loads(contents) @@ -43,12 +44,15 @@ class
HuggingFaceConfig(BaseModel): badwordsids: Optional[str] = None @classmethod - def from_file(self, model_directory: pathlib.Path): + async def from_file(self, model_directory: pathlib.Path): """Create an instance from a generation config file.""" hf_config_path = model_directory / "config.json" - with open(hf_config_path, "r", encoding="utf8") as hf_config_json: - hf_config_dict = json.load(hf_config_json) + async with aiofiles.open( + hf_config_path, "r", encoding="utf8" + ) as hf_config_json: + contents = await hf_config_json.read() + hf_config_dict = json.loads(contents) return self.model_validate(hf_config_dict) def get_badwordsids(self): diff --git a/endpoints/core/router.py b/endpoints/core/router.py index 4f6b441..2d7a139 100644 --- a/endpoints/core/router.py +++ b/endpoints/core/router.py @@ -446,7 +446,7 @@ async def switch_template(data: TemplateSwitchRequest): try: template_path = pathlib.Path("templates") / data.name - model.container.prompt_template = PromptTemplate.from_file(template_path) + model.container.prompt_template = await PromptTemplate.from_file(template_path) except FileNotFoundError as e: error_message = handle_request_error( f"The template name {data.name} doesn't exist. Check the spelling?", @@ -495,7 +495,7 @@ async def switch_sampler_override(data: SamplerOverrideSwitchRequest): if data.preset: try: - sampling.overrides_from_file(data.preset) + await sampling.overrides_from_file(data.preset) except FileNotFoundError as e: error_message = handle_request_error( f"Sampler override preset with name {data.preset} does not exist. 
" diff --git a/main.py b/main.py index 740e1d0..7c20910 100644 --- a/main.py +++ b/main.py @@ -50,7 +50,7 @@ async def entrypoint_async(): port = fallback_port # Initialize auth keys - load_auth_keys(unwrap(config.network.get("disable_auth"), False)) + await load_auth_keys(unwrap(config.network.get("disable_auth"), False)) # Override the generation log options if given if config.logging: @@ -62,7 +62,7 @@ async def entrypoint_async(): sampling_override_preset = config.sampling.get("override_preset") if sampling_override_preset: try: - sampling.overrides_from_file(sampling_override_preset) + await sampling.overrides_from_file(sampling_override_preset) except FileNotFoundError as e: logger.warning(str(e))