Tree: Switch to asynchronous file handling

Using aiofiles, there's no longer a possibility of blocking file operations
that can hang up the event loop. In addition, partially migrate
classes to use asynchronous init instead of the normal python magic method.

The only exception is config, since that's handled in the synchronous
init before the event loop starts.

Signed-off-by: kingbri <bdashore3@proton.me>
This commit is contained in:
kingbri 2024-09-10 16:45:14 -04:00
parent 54bfb770af
commit 2c3bc71afa
9 changed files with 63 additions and 36 deletions

View file

@ -1,5 +1,6 @@
"""The model container class for ExLlamaV2 models."""
import aiofiles
import asyncio
import gc
import math
@ -106,13 +107,17 @@ class ExllamaV2Container:
load_lock: asyncio.Lock = asyncio.Lock()
load_condition: asyncio.Condition = asyncio.Condition()
def __init__(self, model_directory: pathlib.Path, quiet=False, **kwargs):
@classmethod
async def create(cls, model_directory: pathlib.Path, quiet=False, **kwargs):
"""
Primary initializer for model container.
Primary asynchronous initializer for model container.
Kwargs are located in config_sample.yml
"""
# Create a new instance as a "fake self"
self = cls()
self.quiet = quiet
# Initialize config
@ -155,13 +160,13 @@ class ExllamaV2Container:
self.draft_config.prepare()
# Create the hf_config
self.hf_config = HuggingFaceConfig.from_file(model_directory)
self.hf_config = await HuggingFaceConfig.from_file(model_directory)
# Load generation config overrides
generation_config_path = model_directory / "generation_config.json"
if generation_config_path.exists():
try:
self.generation_config = GenerationConfig.from_file(
self.generation_config = await GenerationConfig.from_file(
generation_config_path.parent
)
except Exception:
@ -171,7 +176,7 @@ class ExllamaV2Container:
)
# Apply a model's config overrides while respecting user settings
kwargs = self.set_model_overrides(**kwargs)
kwargs = await self.set_model_overrides(**kwargs)
# MARK: User configuration
@ -320,7 +325,7 @@ class ExllamaV2Container:
self.cache_size = self.config.max_seq_len
# Try to set prompt template
self.prompt_template = self.find_prompt_template(
self.prompt_template = await self.find_prompt_template(
kwargs.get("prompt_template"), model_directory
)
@ -373,7 +378,10 @@ class ExllamaV2Container:
self.draft_config.max_input_len = chunk_size
self.draft_config.max_attention_size = chunk_size**2
def set_model_overrides(self, **kwargs):
# Return the created instance
return self
async def set_model_overrides(self, **kwargs):
"""Sets overrides from a model folder's config yaml."""
override_config_path = self.model_dir / "tabby_config.yml"
@ -381,8 +389,11 @@ class ExllamaV2Container:
if not override_config_path.exists():
return kwargs
with open(override_config_path, "r", encoding="utf8") as override_config_file:
override_args = unwrap(yaml.safe_load(override_config_file), {})
async with aiofiles.open(
override_config_path, "r", encoding="utf8"
) as override_config_file:
contents = await override_config_file.read()
override_args = unwrap(yaml.safe_load(contents), {})
# Merge draft overrides beforehand
draft_override_args = unwrap(override_args.get("draft"), {})
@ -393,7 +404,7 @@ class ExllamaV2Container:
merged_kwargs = {**override_args, **kwargs}
return merged_kwargs
def find_prompt_template(self, prompt_template_name, model_directory):
async def find_prompt_template(self, prompt_template_name, model_directory):
"""Tries to find a prompt template using various methods."""
logger.info("Attempting to load a prompt template if present.")
@ -431,7 +442,7 @@ class ExllamaV2Container:
# Continue on exception since functions are tried as they fail
for template_func in find_template_functions:
try:
prompt_template = template_func()
prompt_template = await template_func()
if prompt_template is not None:
return prompt_template
except TemplateLoadError as e:

View file

@ -3,6 +3,7 @@ This method of authorization is pretty insecure, but since TabbyAPI is a local
application, it should be fine.
"""
import aiofiles
import secrets
import yaml
from fastapi import Header, HTTPException, Request
@ -40,7 +41,7 @@ AUTH_KEYS: Optional[AuthKeys] = None
DISABLE_AUTH: bool = False
def load_auth_keys(disable_from_config: bool):
async def load_auth_keys(disable_from_config: bool):
"""Load the authentication keys from api_tokens.yml. If the file does not
exist, generate new keys and save them to api_tokens.yml."""
global AUTH_KEYS
@ -57,8 +58,9 @@ def load_auth_keys(disable_from_config: bool):
return
try:
with open("api_tokens.yml", "r", encoding="utf8") as auth_file:
auth_keys_dict = yaml.safe_load(auth_file)
async with aiofiles.open("api_tokens.yml", "r", encoding="utf8") as auth_file:
contents = await auth_file.read()
auth_keys_dict = yaml.safe_load(contents)
AUTH_KEYS = AuthKeys.model_validate(auth_keys_dict)
except FileNotFoundError:
new_auth_keys = AuthKeys(
@ -66,8 +68,11 @@ def load_auth_keys(disable_from_config: bool):
)
AUTH_KEYS = new_auth_keys
with open("api_tokens.yml", "w", encoding="utf8") as auth_file:
yaml.safe_dump(AUTH_KEYS.model_dump(), auth_file, default_flow_style=False)
async with aiofiles.open("api_tokens.yml", "w", encoding="utf8") as auth_file:
new_auth_yaml = yaml.safe_dump(
AUTH_KEYS.model_dump(), default_flow_style=False
)
await auth_file.write(new_auth_yaml)
logger.info(
f"Your API key is: {AUTH_KEYS.api_key}\n"

View file

@ -67,7 +67,7 @@ async def load_model_gen(model_path: pathlib.Path, **kwargs):
logger.info("Unloading existing model.")
await unload_model()
container = ExllamaV2Container(model_path.resolve(), False, **kwargs)
container = await ExllamaV2Container.create(model_path.resolve(), False, **kwargs)
model_type = "draft" if container.draft_config else "model"
load_status = container.load_gen(load_progress, **kwargs)

View file

@ -1,5 +1,6 @@
"""Common functions for sampling parameters"""
import aiofiles
import json
import pathlib
import yaml
@ -407,14 +408,15 @@ def overrides_from_dict(new_overrides: dict):
raise TypeError("New sampler overrides must be a dict!")
def overrides_from_file(preset_name: str):
async def overrides_from_file(preset_name: str):
"""Fetches an override preset from a file"""
preset_path = pathlib.Path(f"sampler_overrides/{preset_name}.yml")
if preset_path.exists():
overrides_container.selected_preset = preset_path.stem
with open(preset_path, "r", encoding="utf8") as raw_preset:
preset = yaml.safe_load(raw_preset)
async with aiofiles.open(preset_path, "r", encoding="utf8") as raw_preset:
contents = await raw_preset.read()
preset = yaml.safe_load(contents)
overrides_from_dict(preset)
logger.info("Applied sampler overrides from file.")

View file

@ -17,7 +17,7 @@ class TabbyConfig:
embeddings: dict = {}
def load(self, arguments: Optional[dict] = None):
"""load the global application config"""
"""Synchronously loads the global application config"""
# config is applied in order of items in the list
configs = [

View file

@ -1,5 +1,6 @@
"""Small replication of AutoTokenizer's chat template system for efficiency"""
import aiofiles
import json
import pathlib
from importlib.metadata import version as package_version
@ -110,7 +111,7 @@ class PromptTemplate:
self.template = self.compile(raw_template)
@classmethod
def from_file(self, template_path: pathlib.Path):
async def from_file(self, template_path: pathlib.Path):
"""Get a template from a jinja file."""
# Add the jinja extension if it isn't provided
@ -121,10 +122,13 @@ class PromptTemplate:
template_path = template_path.with_suffix(".jinja")
if template_path.exists():
with open(template_path, "r", encoding="utf8") as raw_template_stream:
async with aiofiles.open(
template_path, "r", encoding="utf8"
) as raw_template_stream:
contents = await raw_template_stream.read()
return PromptTemplate(
name=template_name,
raw_template=raw_template_stream.read(),
raw_template=contents,
)
else:
# Let the user know if the template file isn't found
@ -133,15 +137,16 @@ class PromptTemplate:
)
@classmethod
def from_model_json(
async def from_model_json(
self, json_path: pathlib.Path, key: str, name: Optional[str] = None
):
"""Get a template from a JSON file. Requires a key and template name"""
if not json_path.exists():
raise TemplateLoadError(f'Model JSON path "{json_path}" not found.')
with open(json_path, "r", encoding="utf8") as config_file:
model_config = json.load(config_file)
async with aiofiles.open(json_path, "r", encoding="utf8") as config_file:
contents = await config_file.read()
model_config = json.loads(contents)
chat_template = model_config.get(key)
if not chat_template:

View file

@ -1,3 +1,4 @@
import aiofiles
import json
import pathlib
from typing import List, Optional, Union
@ -15,11 +16,11 @@ class GenerationConfig(BaseModel):
bad_words_ids: Optional[List[List[int]]] = None
@classmethod
def from_file(self, model_directory: pathlib.Path):
async def from_file(self, model_directory: pathlib.Path):
"""Create an instance from a generation config file."""
generation_config_path = model_directory / "generation_config.json"
with open(
async with aiofiles.open(
generation_config_path, "r", encoding="utf8"
) as generation_config_json:
generation_config_dict = json.load(generation_config_json)
@ -43,12 +44,15 @@ class HuggingFaceConfig(BaseModel):
badwordsids: Optional[str] = None
@classmethod
def from_file(self, model_directory: pathlib.Path):
async def from_file(self, model_directory: pathlib.Path):
"""Create an instance from a generation config file."""
hf_config_path = model_directory / "config.json"
with open(hf_config_path, "r", encoding="utf8") as hf_config_json:
hf_config_dict = json.load(hf_config_json)
async with aiofiles.open(
hf_config_path, "r", encoding="utf8"
) as hf_config_json:
contents = await hf_config_json.read()
hf_config_dict = json.loads(contents)
return self.model_validate(hf_config_dict)
def get_badwordsids(self):

View file

@ -446,7 +446,7 @@ async def switch_template(data: TemplateSwitchRequest):
try:
template_path = pathlib.Path("templates") / data.name
model.container.prompt_template = PromptTemplate.from_file(template_path)
model.container.prompt_template = await PromptTemplate.from_file(template_path)
except FileNotFoundError as e:
error_message = handle_request_error(
f"The template name {data.name} doesn't exist. Check the spelling?",
@ -495,7 +495,7 @@ async def switch_sampler_override(data: SamplerOverrideSwitchRequest):
if data.preset:
try:
sampling.overrides_from_file(data.preset)
await sampling.overrides_from_file(data.preset)
except FileNotFoundError as e:
error_message = handle_request_error(
f"Sampler override preset with name {data.preset} does not exist. "

View file

@ -50,7 +50,7 @@ async def entrypoint_async():
port = fallback_port
# Initialize auth keys
load_auth_keys(unwrap(config.network.get("disable_auth"), False))
await load_auth_keys(unwrap(config.network.get("disable_auth"), False))
# Override the generation log options if given
if config.logging:
@ -62,7 +62,7 @@ async def entrypoint_async():
sampling_override_preset = config.sampling.get("override_preset")
if sampling_override_preset:
try:
sampling.overrides_from_file(sampling_override_preset)
await sampling.overrides_from_file(sampling_override_preset)
except FileNotFoundError as e:
logger.warning(str(e))