Tree: Switch to asynchronous file handling
Using aiofiles, there's no longer a possiblity of blocking file operations that can hang up the event loop. In addition, partially migrate classes to use asynchronous init instead of the normal python magic method. The only exception is config, since that's handled in the synchonous init before the event loop starts. Signed-off-by: kingbri <bdashore3@proton.me>
This commit is contained in:
parent
54bfb770af
commit
2c3bc71afa
9 changed files with 63 additions and 36 deletions
|
|
@ -3,6 +3,7 @@ This method of authorization is pretty insecure, but since TabbyAPI is a local
|
|||
application, it should be fine.
|
||||
"""
|
||||
|
||||
import aiofiles
|
||||
import secrets
|
||||
import yaml
|
||||
from fastapi import Header, HTTPException, Request
|
||||
|
|
@ -40,7 +41,7 @@ AUTH_KEYS: Optional[AuthKeys] = None
|
|||
DISABLE_AUTH: bool = False
|
||||
|
||||
|
||||
def load_auth_keys(disable_from_config: bool):
|
||||
async def load_auth_keys(disable_from_config: bool):
|
||||
"""Load the authentication keys from api_tokens.yml. If the file does not
|
||||
exist, generate new keys and save them to api_tokens.yml."""
|
||||
global AUTH_KEYS
|
||||
|
|
@ -57,8 +58,9 @@ def load_auth_keys(disable_from_config: bool):
|
|||
return
|
||||
|
||||
try:
|
||||
with open("api_tokens.yml", "r", encoding="utf8") as auth_file:
|
||||
auth_keys_dict = yaml.safe_load(auth_file)
|
||||
async with aiofiles.open("api_tokens.yml", "r", encoding="utf8") as auth_file:
|
||||
contents = await auth_file.read()
|
||||
auth_keys_dict = yaml.safe_load(contents)
|
||||
AUTH_KEYS = AuthKeys.model_validate(auth_keys_dict)
|
||||
except FileNotFoundError:
|
||||
new_auth_keys = AuthKeys(
|
||||
|
|
@ -66,8 +68,11 @@ def load_auth_keys(disable_from_config: bool):
|
|||
)
|
||||
AUTH_KEYS = new_auth_keys
|
||||
|
||||
with open("api_tokens.yml", "w", encoding="utf8") as auth_file:
|
||||
yaml.safe_dump(AUTH_KEYS.model_dump(), auth_file, default_flow_style=False)
|
||||
async with aiofiles.open("api_tokens.yml", "w", encoding="utf8") as auth_file:
|
||||
new_auth_yaml = yaml.safe_dump(
|
||||
AUTH_KEYS.model_dump(), default_flow_style=False
|
||||
)
|
||||
await auth_file.write(new_auth_yaml)
|
||||
|
||||
logger.info(
|
||||
f"Your API key is: {AUTH_KEYS.api_key}\n"
|
||||
|
|
|
|||
|
|
@ -67,7 +67,7 @@ async def load_model_gen(model_path: pathlib.Path, **kwargs):
|
|||
logger.info("Unloading existing model.")
|
||||
await unload_model()
|
||||
|
||||
container = ExllamaV2Container(model_path.resolve(), False, **kwargs)
|
||||
container = await ExllamaV2Container.create(model_path.resolve(), False, **kwargs)
|
||||
|
||||
model_type = "draft" if container.draft_config else "model"
|
||||
load_status = container.load_gen(load_progress, **kwargs)
|
||||
|
|
|
|||
|
|
@ -1,5 +1,6 @@
|
|||
"""Common functions for sampling parameters"""
|
||||
|
||||
import aiofiles
|
||||
import json
|
||||
import pathlib
|
||||
import yaml
|
||||
|
|
@ -407,14 +408,15 @@ def overrides_from_dict(new_overrides: dict):
|
|||
raise TypeError("New sampler overrides must be a dict!")
|
||||
|
||||
|
||||
def overrides_from_file(preset_name: str):
|
||||
async def overrides_from_file(preset_name: str):
|
||||
"""Fetches an override preset from a file"""
|
||||
|
||||
preset_path = pathlib.Path(f"sampler_overrides/{preset_name}.yml")
|
||||
if preset_path.exists():
|
||||
overrides_container.selected_preset = preset_path.stem
|
||||
with open(preset_path, "r", encoding="utf8") as raw_preset:
|
||||
preset = yaml.safe_load(raw_preset)
|
||||
async with aiofiles.open(preset_path, "r", encoding="utf8") as raw_preset:
|
||||
contents = await raw_preset.read()
|
||||
preset = yaml.safe_load(contents)
|
||||
overrides_from_dict(preset)
|
||||
|
||||
logger.info("Applied sampler overrides from file.")
|
||||
|
|
|
|||
|
|
@ -17,7 +17,7 @@ class TabbyConfig:
|
|||
embeddings: dict = {}
|
||||
|
||||
def load(self, arguments: Optional[dict] = None):
|
||||
"""load the global application config"""
|
||||
"""Synchronously loads the global application config"""
|
||||
|
||||
# config is applied in order of items in the list
|
||||
configs = [
|
||||
|
|
|
|||
|
|
@ -1,5 +1,6 @@
|
|||
"""Small replication of AutoTokenizer's chat template system for efficiency"""
|
||||
|
||||
import aiofiles
|
||||
import json
|
||||
import pathlib
|
||||
from importlib.metadata import version as package_version
|
||||
|
|
@ -110,7 +111,7 @@ class PromptTemplate:
|
|||
self.template = self.compile(raw_template)
|
||||
|
||||
@classmethod
|
||||
def from_file(self, template_path: pathlib.Path):
|
||||
async def from_file(self, template_path: pathlib.Path):
|
||||
"""Get a template from a jinja file."""
|
||||
|
||||
# Add the jinja extension if it isn't provided
|
||||
|
|
@ -121,10 +122,13 @@ class PromptTemplate:
|
|||
template_path = template_path.with_suffix(".jinja")
|
||||
|
||||
if template_path.exists():
|
||||
with open(template_path, "r", encoding="utf8") as raw_template_stream:
|
||||
async with aiofiles.open(
|
||||
template_path, "r", encoding="utf8"
|
||||
) as raw_template_stream:
|
||||
contents = await raw_template_stream.read()
|
||||
return PromptTemplate(
|
||||
name=template_name,
|
||||
raw_template=raw_template_stream.read(),
|
||||
raw_template=contents,
|
||||
)
|
||||
else:
|
||||
# Let the user know if the template file isn't found
|
||||
|
|
@ -133,15 +137,16 @@ class PromptTemplate:
|
|||
)
|
||||
|
||||
@classmethod
|
||||
def from_model_json(
|
||||
async def from_model_json(
|
||||
self, json_path: pathlib.Path, key: str, name: Optional[str] = None
|
||||
):
|
||||
"""Get a template from a JSON file. Requires a key and template name"""
|
||||
if not json_path.exists():
|
||||
raise TemplateLoadError(f'Model JSON path "{json_path}" not found.')
|
||||
|
||||
with open(json_path, "r", encoding="utf8") as config_file:
|
||||
model_config = json.load(config_file)
|
||||
async with aiofiles.open(json_path, "r", encoding="utf8") as config_file:
|
||||
contents = await config_file.read()
|
||||
model_config = json.loads(contents)
|
||||
chat_template = model_config.get(key)
|
||||
|
||||
if not chat_template:
|
||||
|
|
|
|||
|
|
@ -1,3 +1,4 @@
|
|||
import aiofiles
|
||||
import json
|
||||
import pathlib
|
||||
from typing import List, Optional, Union
|
||||
|
|
@ -15,11 +16,11 @@ class GenerationConfig(BaseModel):
|
|||
bad_words_ids: Optional[List[List[int]]] = None
|
||||
|
||||
@classmethod
|
||||
def from_file(self, model_directory: pathlib.Path):
|
||||
async def from_file(self, model_directory: pathlib.Path):
|
||||
"""Create an instance from a generation config file."""
|
||||
|
||||
generation_config_path = model_directory / "generation_config.json"
|
||||
with open(
|
||||
async with aiofiles.open(
|
||||
generation_config_path, "r", encoding="utf8"
|
||||
) as generation_config_json:
|
||||
generation_config_dict = json.load(generation_config_json)
|
||||
|
|
@ -43,12 +44,15 @@ class HuggingFaceConfig(BaseModel):
|
|||
badwordsids: Optional[str] = None
|
||||
|
||||
@classmethod
|
||||
def from_file(self, model_directory: pathlib.Path):
|
||||
async def from_file(self, model_directory: pathlib.Path):
|
||||
"""Create an instance from a generation config file."""
|
||||
|
||||
hf_config_path = model_directory / "config.json"
|
||||
with open(hf_config_path, "r", encoding="utf8") as hf_config_json:
|
||||
hf_config_dict = json.load(hf_config_json)
|
||||
async with aiofiles.open(
|
||||
hf_config_path, "r", encoding="utf8"
|
||||
) as hf_config_json:
|
||||
contents = await hf_config_json.read()
|
||||
hf_config_dict = json.loads(contents)
|
||||
return self.model_validate(hf_config_dict)
|
||||
|
||||
def get_badwordsids(self):
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue