Tree: Switch to asynchronous file handling
Using aiofiles, there's no longer a possibility of blocking file operations hanging up the event loop. In addition, partially migrate classes to use an asynchronous initializer instead of the normal Python magic method. The only exception is config, since that's handled in the synchronous init before the event loop starts.

Signed-off-by: kingbri <bdashore3@proton.me>
This commit is contained in: parent 54bfb770af · commit 2c3bc71afa

9 changed files with 63 additions and 36 deletions
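For context, here is a minimal sketch of the initializer pattern this commit adopts. It is illustrative only (`Container`, `create`, and `config.yml` are not names from the project): since `__init__` cannot be declared `async`, a classmethod constructs a bare instance, awaits its setup work, and returns it.

```python
import asyncio

import aiofiles


class Container:
    """Sketch of the async-initializer pattern (illustrative, not project code)."""

    @classmethod
    async def create(cls, path: str):
        # __init__ can't be awaited, so build a bare instance first
        # and perform awaitable setup on it before returning.
        self = cls()
        async with aiofiles.open(path, "r", encoding="utf8") as f:
            self.raw_config = await f.read()
        return self


async def main():
    container = await Container.create("config.yml")
    print(container.raw_config)


asyncio.run(main())
```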
```diff
@@ -1,5 +1,6 @@
 """The model container class for ExLlamaV2 models."""
 
+import aiofiles
 import asyncio
 import gc
 import math
@@ -106,13 +107,17 @@ class ExllamaV2Container:
     load_lock: asyncio.Lock = asyncio.Lock()
     load_condition: asyncio.Condition = asyncio.Condition()
 
-    def __init__(self, model_directory: pathlib.Path, quiet=False, **kwargs):
+    @classmethod
+    async def create(cls, model_directory: pathlib.Path, quiet=False, **kwargs):
         """
-        Primary initializer for model container.
+        Primary asynchronous initializer for model container.
 
         Kwargs are located in config_sample.yml
         """
 
+        # Create a new instance as a "fake self"
+        self = cls()
+
         self.quiet = quiet
 
         # Initialize config
@@ -155,13 +160,13 @@ class ExllamaV2Container:
         self.draft_config.prepare()
 
         # Create the hf_config
-        self.hf_config = HuggingFaceConfig.from_file(model_directory)
+        self.hf_config = await HuggingFaceConfig.from_file(model_directory)
 
         # Load generation config overrides
         generation_config_path = model_directory / "generation_config.json"
         if generation_config_path.exists():
             try:
-                self.generation_config = GenerationConfig.from_file(
+                self.generation_config = await GenerationConfig.from_file(
                     generation_config_path.parent
                 )
             except Exception:
@@ -171,7 +176,7 @@ class ExllamaV2Container:
             )
 
         # Apply a model's config overrides while respecting user settings
-        kwargs = self.set_model_overrides(**kwargs)
+        kwargs = await self.set_model_overrides(**kwargs)
 
         # MARK: User configuration
 
@@ -320,7 +325,7 @@ class ExllamaV2Container:
             self.cache_size = self.config.max_seq_len
 
         # Try to set prompt template
-        self.prompt_template = self.find_prompt_template(
+        self.prompt_template = await self.find_prompt_template(
             kwargs.get("prompt_template"), model_directory
         )
 
@@ -373,7 +378,10 @@ class ExllamaV2Container:
             self.draft_config.max_input_len = chunk_size
             self.draft_config.max_attention_size = chunk_size**2
 
-    def set_model_overrides(self, **kwargs):
+        # Return the created instance
+        return self
+
+    async def set_model_overrides(self, **kwargs):
         """Sets overrides from a model folder's config yaml."""
 
         override_config_path = self.model_dir / "tabby_config.yml"
@@ -381,8 +389,11 @@ class ExllamaV2Container:
         if not override_config_path.exists():
             return kwargs
 
-        with open(override_config_path, "r", encoding="utf8") as override_config_file:
-            override_args = unwrap(yaml.safe_load(override_config_file), {})
+        async with aiofiles.open(
+            override_config_path, "r", encoding="utf8"
+        ) as override_config_file:
+            contents = await override_config_file.read()
+            override_args = unwrap(yaml.safe_load(contents), {})
 
         # Merge draft overrides beforehand
         draft_override_args = unwrap(override_args.get("draft"), {})
@@ -393,7 +404,7 @@ class ExllamaV2Container:
         merged_kwargs = {**override_args, **kwargs}
         return merged_kwargs
 
-    def find_prompt_template(self, prompt_template_name, model_directory):
+    async def find_prompt_template(self, prompt_template_name, model_directory):
         """Tries to find a prompt template using various methods."""
 
         logger.info("Attempting to load a prompt template if present.")
@@ -431,7 +442,7 @@ class ExllamaV2Container:
         # Continue on exception since functions are tried as they fail
         for template_func in find_template_functions:
             try:
-                prompt_template = template_func()
+                prompt_template = await template_func()
                 if prompt_template is not None:
                     return prompt_template
             except TemplateLoadError as e:
```
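A pattern worth noting throughout these hunks: aiofiles handles are not ordinary file objects, so `yaml.safe_load` cannot consume them directly. The file is awaited into a string first, then parsed in memory. A minimal sketch, using a hypothetical `read_yaml` helper that is not part of the project:

```python
import aiofiles
import yaml


async def read_yaml(path: str) -> dict:
    # Await the whole file into memory, then parse synchronously;
    # yaml.safe_load accepts a plain string.
    async with aiofiles.open(path, "r", encoding="utf8") as f:
        contents = await f.read()
    return yaml.safe_load(contents)
```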
```diff
@@ -3,6 +3,7 @@ This method of authorization is pretty insecure, but since TabbyAPI is a local
 application, it should be fine.
 """
 
+import aiofiles
 import secrets
 import yaml
 from fastapi import Header, HTTPException, Request
@@ -40,7 +41,7 @@ AUTH_KEYS: Optional[AuthKeys] = None
 DISABLE_AUTH: bool = False
 
 
-def load_auth_keys(disable_from_config: bool):
+async def load_auth_keys(disable_from_config: bool):
     """Load the authentication keys from api_tokens.yml. If the file does not
     exist, generate new keys and save them to api_tokens.yml."""
     global AUTH_KEYS
@@ -57,8 +58,9 @@ def load_auth_keys(disable_from_config: bool):
         return
 
     try:
-        with open("api_tokens.yml", "r", encoding="utf8") as auth_file:
-            auth_keys_dict = yaml.safe_load(auth_file)
+        async with aiofiles.open("api_tokens.yml", "r", encoding="utf8") as auth_file:
+            contents = await auth_file.read()
+            auth_keys_dict = yaml.safe_load(contents)
         AUTH_KEYS = AuthKeys.model_validate(auth_keys_dict)
     except FileNotFoundError:
         new_auth_keys = AuthKeys(
@@ -66,8 +68,11 @@ def load_auth_keys(disable_from_config: bool):
         )
         AUTH_KEYS = new_auth_keys
 
-        with open("api_tokens.yml", "w", encoding="utf8") as auth_file:
-            yaml.safe_dump(AUTH_KEYS.model_dump(), auth_file, default_flow_style=False)
+        async with aiofiles.open("api_tokens.yml", "w", encoding="utf8") as auth_file:
+            new_auth_yaml = yaml.safe_dump(
+                AUTH_KEYS.model_dump(), default_flow_style=False
+            )
+            await auth_file.write(new_auth_yaml)
 
     logger.info(
         f"Your API key is: {AUTH_KEYS.api_key}\n"
```
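The write path in the hunk above mirrors the read path: `yaml.safe_dump` is called without a stream argument so it returns the serialized string, which is then written with an awaited `write`. A sketch under the same assumptions, with a hypothetical `write_yaml` helper:

```python
import aiofiles
import yaml


async def write_yaml(path: str, data: dict) -> None:
    # safe_dump with no stream argument returns a string instead of
    # writing to a file, so the actual I/O can be awaited.
    serialized = yaml.safe_dump(data, default_flow_style=False)
    async with aiofiles.open(path, "w", encoding="utf8") as f:
        await f.write(serialized)
```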
```diff
@@ -67,7 +67,7 @@ async def load_model_gen(model_path: pathlib.Path, **kwargs):
         logger.info("Unloading existing model.")
         await unload_model()
 
-    container = ExllamaV2Container(model_path.resolve(), False, **kwargs)
+    container = await ExllamaV2Container.create(model_path.resolve(), False, **kwargs)
 
     model_type = "draft" if container.draft_config else "model"
     load_status = container.load_gen(load_progress, **kwargs)
```
```diff
@@ -1,5 +1,6 @@
 """Common functions for sampling parameters"""
 
+import aiofiles
 import json
 import pathlib
 import yaml
@@ -407,14 +408,15 @@ def overrides_from_dict(new_overrides: dict):
         raise TypeError("New sampler overrides must be a dict!")
 
 
-def overrides_from_file(preset_name: str):
+async def overrides_from_file(preset_name: str):
     """Fetches an override preset from a file"""
 
     preset_path = pathlib.Path(f"sampler_overrides/{preset_name}.yml")
     if preset_path.exists():
         overrides_container.selected_preset = preset_path.stem
-        with open(preset_path, "r", encoding="utf8") as raw_preset:
-            preset = yaml.safe_load(raw_preset)
+        async with aiofiles.open(preset_path, "r", encoding="utf8") as raw_preset:
+            contents = await raw_preset.read()
+            preset = yaml.safe_load(contents)
             overrides_from_dict(preset)
 
         logger.info("Applied sampler overrides from file.")
```
```diff
@@ -17,7 +17,7 @@ class TabbyConfig:
     embeddings: dict = {}
 
     def load(self, arguments: Optional[dict] = None):
-        """load the global application config"""
+        """Synchronously loads the global application config"""
 
         # config is applied in order of items in the list
         configs = [
```
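This docstring tweak reflects the one exception called out in the commit message: config loads before the event loop exists, so blocking I/O cannot stall anything. A rough sketch of that startup ordering (the function names here are illustrative, not the project's):

```python
import asyncio


def load_config() -> dict:
    # Stand-in for the synchronous config load; blocking I/O is fine
    # here because no event loop is running yet.
    return {"port": 5000}


async def entrypoint_async(config: dict):
    # Once this coroutine runs, the loop is live and file I/O
    # should go through aiofiles instead.
    print(f"serving on port {config['port']}")


config = load_config()                  # synchronous, pre-loop
asyncio.run(entrypoint_async(config))   # the event loop starts here
```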
```diff
@@ -1,5 +1,6 @@
 """Small replication of AutoTokenizer's chat template system for efficiency"""
 
+import aiofiles
 import json
 import pathlib
 from importlib.metadata import version as package_version
@@ -110,7 +111,7 @@ class PromptTemplate:
         self.template = self.compile(raw_template)
 
     @classmethod
-    def from_file(self, template_path: pathlib.Path):
+    async def from_file(self, template_path: pathlib.Path):
         """Get a template from a jinja file."""
 
         # Add the jinja extension if it isn't provided
@@ -121,10 +122,13 @@ class PromptTemplate:
             template_path = template_path.with_suffix(".jinja")
 
         if template_path.exists():
-            with open(template_path, "r", encoding="utf8") as raw_template_stream:
+            async with aiofiles.open(
+                template_path, "r", encoding="utf8"
+            ) as raw_template_stream:
+                contents = await raw_template_stream.read()
                 return PromptTemplate(
                     name=template_name,
-                    raw_template=raw_template_stream.read(),
+                    raw_template=contents,
                 )
         else:
             # Let the user know if the template file isn't found
@@ -133,15 +137,16 @@ class PromptTemplate:
             )
 
     @classmethod
-    def from_model_json(
+    async def from_model_json(
         self, json_path: pathlib.Path, key: str, name: Optional[str] = None
     ):
         """Get a template from a JSON file. Requires a key and template name"""
         if not json_path.exists():
             raise TemplateLoadError(f'Model JSON path "{json_path}" not found.')
 
-        with open(json_path, "r", encoding="utf8") as config_file:
-            model_config = json.load(config_file)
+        async with aiofiles.open(json_path, "r", encoding="utf8") as config_file:
+            contents = await config_file.read()
+            model_config = json.loads(contents)
         chat_template = model_config.get(key)
 
         if not chat_template:
```
```diff
@@ -1,3 +1,4 @@
+import aiofiles
 import json
 import pathlib
 from typing import List, Optional, Union
@@ -15,11 +16,11 @@ class GenerationConfig(BaseModel):
     bad_words_ids: Optional[List[List[int]]] = None
 
     @classmethod
-    def from_file(self, model_directory: pathlib.Path):
+    async def from_file(self, model_directory: pathlib.Path):
         """Create an instance from a generation config file."""
 
         generation_config_path = model_directory / "generation_config.json"
-        with open(
+        async with aiofiles.open(
             generation_config_path, "r", encoding="utf8"
         ) as generation_config_json:
             generation_config_dict = json.load(generation_config_json)
@@ -43,12 +44,15 @@ class HuggingFaceConfig(BaseModel):
     badwordsids: Optional[str] = None
 
     @classmethod
-    def from_file(self, model_directory: pathlib.Path):
+    async def from_file(self, model_directory: pathlib.Path):
         """Create an instance from a generation config file."""
 
         hf_config_path = model_directory / "config.json"
-        with open(hf_config_path, "r", encoding="utf8") as hf_config_json:
-            hf_config_dict = json.load(hf_config_json)
+        async with aiofiles.open(
+            hf_config_path, "r", encoding="utf8"
+        ) as hf_config_json:
+            contents = await hf_config_json.read()
+            hf_config_dict = json.loads(contents)
         return self.model_validate(hf_config_dict)
 
     def get_badwordsids(self):
```
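The JSON configs follow the same read-then-parse shape as the YAML ones: `json.loads` parses the awaited string where `json.load` previously consumed the synchronous file handle. A minimal sketch, again with a hypothetical `read_json` helper:

```python
import json

import aiofiles


async def read_json(path: str) -> dict:
    async with aiofiles.open(path, "r", encoding="utf8") as f:
        contents = await f.read()
    # json.load needs a synchronous file object, which aiofiles does
    # not provide; json.loads takes the in-memory string instead.
    return json.loads(contents)
```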
```diff
@@ -446,7 +446,7 @@ async def switch_template(data: TemplateSwitchRequest):
 
     try:
         template_path = pathlib.Path("templates") / data.name
-        model.container.prompt_template = PromptTemplate.from_file(template_path)
+        model.container.prompt_template = await PromptTemplate.from_file(template_path)
     except FileNotFoundError as e:
         error_message = handle_request_error(
             f"The template name {data.name} doesn't exist. Check the spelling?",
@@ -495,7 +495,7 @@ async def switch_sampler_override(data: SamplerOverrideSwitchRequest):
 
     if data.preset:
         try:
-            sampling.overrides_from_file(data.preset)
+            await sampling.overrides_from_file(data.preset)
         except FileNotFoundError as e:
             error_message = handle_request_error(
                 f"Sampler override preset with name {data.preset} does not exist. "
```
main.py (4 changes)

```diff
@@ -50,7 +50,7 @@ async def entrypoint_async():
         port = fallback_port
 
     # Initialize auth keys
-    load_auth_keys(unwrap(config.network.get("disable_auth"), False))
+    await load_auth_keys(unwrap(config.network.get("disable_auth"), False))
 
     # Override the generation log options if given
     if config.logging:
@@ -62,7 +62,7 @@ async def entrypoint_async():
     sampling_override_preset = config.sampling.get("override_preset")
     if sampling_override_preset:
         try:
-            sampling.overrides_from_file(sampling_override_preset)
+            await sampling.overrides_from_file(sampling_override_preset)
         except FileNotFoundError as e:
             logger.warning(str(e))
 
```