Tree: Switch to asynchronous file handling
Using aiofiles, there's no longer a possibility of blocking file operations that can hang up the event loop. In addition, partially migrate classes to use asynchronous init instead of the normal python magic method. The only exception is config, since that's handled in the synchronous init before the event loop starts. Signed-off-by: kingbri <bdashore3@proton.me>
This commit is contained in:
parent
54bfb770af
commit
2c3bc71afa
9 changed files with 63 additions and 36 deletions
|
|
@ -1,5 +1,6 @@
|
||||||
"""The model container class for ExLlamaV2 models."""
|
"""The model container class for ExLlamaV2 models."""
|
||||||
|
|
||||||
|
import aiofiles
|
||||||
import asyncio
|
import asyncio
|
||||||
import gc
|
import gc
|
||||||
import math
|
import math
|
||||||
|
|
@ -106,13 +107,17 @@ class ExllamaV2Container:
|
||||||
load_lock: asyncio.Lock = asyncio.Lock()
|
load_lock: asyncio.Lock = asyncio.Lock()
|
||||||
load_condition: asyncio.Condition = asyncio.Condition()
|
load_condition: asyncio.Condition = asyncio.Condition()
|
||||||
|
|
||||||
def __init__(self, model_directory: pathlib.Path, quiet=False, **kwargs):
|
@classmethod
|
||||||
|
async def create(cls, model_directory: pathlib.Path, quiet=False, **kwargs):
|
||||||
"""
|
"""
|
||||||
Primary initializer for model container.
|
Primary asynchronous initializer for model container.
|
||||||
|
|
||||||
Kwargs are located in config_sample.yml
|
Kwargs are located in config_sample.yml
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
# Create a new instance as a "fake self"
|
||||||
|
self = cls()
|
||||||
|
|
||||||
self.quiet = quiet
|
self.quiet = quiet
|
||||||
|
|
||||||
# Initialize config
|
# Initialize config
|
||||||
|
|
@ -155,13 +160,13 @@ class ExllamaV2Container:
|
||||||
self.draft_config.prepare()
|
self.draft_config.prepare()
|
||||||
|
|
||||||
# Create the hf_config
|
# Create the hf_config
|
||||||
self.hf_config = HuggingFaceConfig.from_file(model_directory)
|
self.hf_config = await HuggingFaceConfig.from_file(model_directory)
|
||||||
|
|
||||||
# Load generation config overrides
|
# Load generation config overrides
|
||||||
generation_config_path = model_directory / "generation_config.json"
|
generation_config_path = model_directory / "generation_config.json"
|
||||||
if generation_config_path.exists():
|
if generation_config_path.exists():
|
||||||
try:
|
try:
|
||||||
self.generation_config = GenerationConfig.from_file(
|
self.generation_config = await GenerationConfig.from_file(
|
||||||
generation_config_path.parent
|
generation_config_path.parent
|
||||||
)
|
)
|
||||||
except Exception:
|
except Exception:
|
||||||
|
|
@ -171,7 +176,7 @@ class ExllamaV2Container:
|
||||||
)
|
)
|
||||||
|
|
||||||
# Apply a model's config overrides while respecting user settings
|
# Apply a model's config overrides while respecting user settings
|
||||||
kwargs = self.set_model_overrides(**kwargs)
|
kwargs = await self.set_model_overrides(**kwargs)
|
||||||
|
|
||||||
# MARK: User configuration
|
# MARK: User configuration
|
||||||
|
|
||||||
|
|
@ -320,7 +325,7 @@ class ExllamaV2Container:
|
||||||
self.cache_size = self.config.max_seq_len
|
self.cache_size = self.config.max_seq_len
|
||||||
|
|
||||||
# Try to set prompt template
|
# Try to set prompt template
|
||||||
self.prompt_template = self.find_prompt_template(
|
self.prompt_template = await self.find_prompt_template(
|
||||||
kwargs.get("prompt_template"), model_directory
|
kwargs.get("prompt_template"), model_directory
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -373,7 +378,10 @@ class ExllamaV2Container:
|
||||||
self.draft_config.max_input_len = chunk_size
|
self.draft_config.max_input_len = chunk_size
|
||||||
self.draft_config.max_attention_size = chunk_size**2
|
self.draft_config.max_attention_size = chunk_size**2
|
||||||
|
|
||||||
def set_model_overrides(self, **kwargs):
|
# Return the created instance
|
||||||
|
return self
|
||||||
|
|
||||||
|
async def set_model_overrides(self, **kwargs):
|
||||||
"""Sets overrides from a model folder's config yaml."""
|
"""Sets overrides from a model folder's config yaml."""
|
||||||
|
|
||||||
override_config_path = self.model_dir / "tabby_config.yml"
|
override_config_path = self.model_dir / "tabby_config.yml"
|
||||||
|
|
@ -381,8 +389,11 @@ class ExllamaV2Container:
|
||||||
if not override_config_path.exists():
|
if not override_config_path.exists():
|
||||||
return kwargs
|
return kwargs
|
||||||
|
|
||||||
with open(override_config_path, "r", encoding="utf8") as override_config_file:
|
async with aiofiles.open(
|
||||||
override_args = unwrap(yaml.safe_load(override_config_file), {})
|
override_config_path, "r", encoding="utf8"
|
||||||
|
) as override_config_file:
|
||||||
|
contents = await override_config_file.read()
|
||||||
|
override_args = unwrap(yaml.safe_load(contents), {})
|
||||||
|
|
||||||
# Merge draft overrides beforehand
|
# Merge draft overrides beforehand
|
||||||
draft_override_args = unwrap(override_args.get("draft"), {})
|
draft_override_args = unwrap(override_args.get("draft"), {})
|
||||||
|
|
@ -393,7 +404,7 @@ class ExllamaV2Container:
|
||||||
merged_kwargs = {**override_args, **kwargs}
|
merged_kwargs = {**override_args, **kwargs}
|
||||||
return merged_kwargs
|
return merged_kwargs
|
||||||
|
|
||||||
def find_prompt_template(self, prompt_template_name, model_directory):
|
async def find_prompt_template(self, prompt_template_name, model_directory):
|
||||||
"""Tries to find a prompt template using various methods."""
|
"""Tries to find a prompt template using various methods."""
|
||||||
|
|
||||||
logger.info("Attempting to load a prompt template if present.")
|
logger.info("Attempting to load a prompt template if present.")
|
||||||
|
|
@ -431,7 +442,7 @@ class ExllamaV2Container:
|
||||||
# Continue on exception since functions are tried as they fail
|
# Continue on exception since functions are tried as they fail
|
||||||
for template_func in find_template_functions:
|
for template_func in find_template_functions:
|
||||||
try:
|
try:
|
||||||
prompt_template = template_func()
|
prompt_template = await template_func()
|
||||||
if prompt_template is not None:
|
if prompt_template is not None:
|
||||||
return prompt_template
|
return prompt_template
|
||||||
except TemplateLoadError as e:
|
except TemplateLoadError as e:
|
||||||
|
|
|
||||||
|
|
@ -3,6 +3,7 @@ This method of authorization is pretty insecure, but since TabbyAPI is a local
|
||||||
application, it should be fine.
|
application, it should be fine.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
import aiofiles
|
||||||
import secrets
|
import secrets
|
||||||
import yaml
|
import yaml
|
||||||
from fastapi import Header, HTTPException, Request
|
from fastapi import Header, HTTPException, Request
|
||||||
|
|
@ -40,7 +41,7 @@ AUTH_KEYS: Optional[AuthKeys] = None
|
||||||
DISABLE_AUTH: bool = False
|
DISABLE_AUTH: bool = False
|
||||||
|
|
||||||
|
|
||||||
def load_auth_keys(disable_from_config: bool):
|
async def load_auth_keys(disable_from_config: bool):
|
||||||
"""Load the authentication keys from api_tokens.yml. If the file does not
|
"""Load the authentication keys from api_tokens.yml. If the file does not
|
||||||
exist, generate new keys and save them to api_tokens.yml."""
|
exist, generate new keys and save them to api_tokens.yml."""
|
||||||
global AUTH_KEYS
|
global AUTH_KEYS
|
||||||
|
|
@ -57,8 +58,9 @@ def load_auth_keys(disable_from_config: bool):
|
||||||
return
|
return
|
||||||
|
|
||||||
try:
|
try:
|
||||||
with open("api_tokens.yml", "r", encoding="utf8") as auth_file:
|
async with aiofiles.open("api_tokens.yml", "r", encoding="utf8") as auth_file:
|
||||||
auth_keys_dict = yaml.safe_load(auth_file)
|
contents = await auth_file.read()
|
||||||
|
auth_keys_dict = yaml.safe_load(contents)
|
||||||
AUTH_KEYS = AuthKeys.model_validate(auth_keys_dict)
|
AUTH_KEYS = AuthKeys.model_validate(auth_keys_dict)
|
||||||
except FileNotFoundError:
|
except FileNotFoundError:
|
||||||
new_auth_keys = AuthKeys(
|
new_auth_keys = AuthKeys(
|
||||||
|
|
@ -66,8 +68,11 @@ def load_auth_keys(disable_from_config: bool):
|
||||||
)
|
)
|
||||||
AUTH_KEYS = new_auth_keys
|
AUTH_KEYS = new_auth_keys
|
||||||
|
|
||||||
with open("api_tokens.yml", "w", encoding="utf8") as auth_file:
|
async with aiofiles.open("api_tokens.yml", "w", encoding="utf8") as auth_file:
|
||||||
yaml.safe_dump(AUTH_KEYS.model_dump(), auth_file, default_flow_style=False)
|
new_auth_yaml = yaml.safe_dump(
|
||||||
|
AUTH_KEYS.model_dump(), default_flow_style=False
|
||||||
|
)
|
||||||
|
await auth_file.write(new_auth_yaml)
|
||||||
|
|
||||||
logger.info(
|
logger.info(
|
||||||
f"Your API key is: {AUTH_KEYS.api_key}\n"
|
f"Your API key is: {AUTH_KEYS.api_key}\n"
|
||||||
|
|
|
||||||
|
|
@ -67,7 +67,7 @@ async def load_model_gen(model_path: pathlib.Path, **kwargs):
|
||||||
logger.info("Unloading existing model.")
|
logger.info("Unloading existing model.")
|
||||||
await unload_model()
|
await unload_model()
|
||||||
|
|
||||||
container = ExllamaV2Container(model_path.resolve(), False, **kwargs)
|
container = await ExllamaV2Container.create(model_path.resolve(), False, **kwargs)
|
||||||
|
|
||||||
model_type = "draft" if container.draft_config else "model"
|
model_type = "draft" if container.draft_config else "model"
|
||||||
load_status = container.load_gen(load_progress, **kwargs)
|
load_status = container.load_gen(load_progress, **kwargs)
|
||||||
|
|
|
||||||
|
|
@ -1,5 +1,6 @@
|
||||||
"""Common functions for sampling parameters"""
|
"""Common functions for sampling parameters"""
|
||||||
|
|
||||||
|
import aiofiles
|
||||||
import json
|
import json
|
||||||
import pathlib
|
import pathlib
|
||||||
import yaml
|
import yaml
|
||||||
|
|
@ -407,14 +408,15 @@ def overrides_from_dict(new_overrides: dict):
|
||||||
raise TypeError("New sampler overrides must be a dict!")
|
raise TypeError("New sampler overrides must be a dict!")
|
||||||
|
|
||||||
|
|
||||||
def overrides_from_file(preset_name: str):
|
async def overrides_from_file(preset_name: str):
|
||||||
"""Fetches an override preset from a file"""
|
"""Fetches an override preset from a file"""
|
||||||
|
|
||||||
preset_path = pathlib.Path(f"sampler_overrides/{preset_name}.yml")
|
preset_path = pathlib.Path(f"sampler_overrides/{preset_name}.yml")
|
||||||
if preset_path.exists():
|
if preset_path.exists():
|
||||||
overrides_container.selected_preset = preset_path.stem
|
overrides_container.selected_preset = preset_path.stem
|
||||||
with open(preset_path, "r", encoding="utf8") as raw_preset:
|
async with aiofiles.open(preset_path, "r", encoding="utf8") as raw_preset:
|
||||||
preset = yaml.safe_load(raw_preset)
|
contents = await raw_preset.read()
|
||||||
|
preset = yaml.safe_load(contents)
|
||||||
overrides_from_dict(preset)
|
overrides_from_dict(preset)
|
||||||
|
|
||||||
logger.info("Applied sampler overrides from file.")
|
logger.info("Applied sampler overrides from file.")
|
||||||
|
|
|
||||||
|
|
@ -17,7 +17,7 @@ class TabbyConfig:
|
||||||
embeddings: dict = {}
|
embeddings: dict = {}
|
||||||
|
|
||||||
def load(self, arguments: Optional[dict] = None):
|
def load(self, arguments: Optional[dict] = None):
|
||||||
"""load the global application config"""
|
"""Synchronously loads the global application config"""
|
||||||
|
|
||||||
# config is applied in order of items in the list
|
# config is applied in order of items in the list
|
||||||
configs = [
|
configs = [
|
||||||
|
|
|
||||||
|
|
@ -1,5 +1,6 @@
|
||||||
"""Small replication of AutoTokenizer's chat template system for efficiency"""
|
"""Small replication of AutoTokenizer's chat template system for efficiency"""
|
||||||
|
|
||||||
|
import aiofiles
|
||||||
import json
|
import json
|
||||||
import pathlib
|
import pathlib
|
||||||
from importlib.metadata import version as package_version
|
from importlib.metadata import version as package_version
|
||||||
|
|
@ -110,7 +111,7 @@ class PromptTemplate:
|
||||||
self.template = self.compile(raw_template)
|
self.template = self.compile(raw_template)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_file(self, template_path: pathlib.Path):
|
async def from_file(self, template_path: pathlib.Path):
|
||||||
"""Get a template from a jinja file."""
|
"""Get a template from a jinja file."""
|
||||||
|
|
||||||
# Add the jinja extension if it isn't provided
|
# Add the jinja extension if it isn't provided
|
||||||
|
|
@ -121,10 +122,13 @@ class PromptTemplate:
|
||||||
template_path = template_path.with_suffix(".jinja")
|
template_path = template_path.with_suffix(".jinja")
|
||||||
|
|
||||||
if template_path.exists():
|
if template_path.exists():
|
||||||
with open(template_path, "r", encoding="utf8") as raw_template_stream:
|
async with aiofiles.open(
|
||||||
|
template_path, "r", encoding="utf8"
|
||||||
|
) as raw_template_stream:
|
||||||
|
contents = await raw_template_stream.read()
|
||||||
return PromptTemplate(
|
return PromptTemplate(
|
||||||
name=template_name,
|
name=template_name,
|
||||||
raw_template=raw_template_stream.read(),
|
raw_template=contents,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
# Let the user know if the template file isn't found
|
# Let the user know if the template file isn't found
|
||||||
|
|
@ -133,15 +137,16 @@ class PromptTemplate:
|
||||||
)
|
)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_model_json(
|
async def from_model_json(
|
||||||
self, json_path: pathlib.Path, key: str, name: Optional[str] = None
|
self, json_path: pathlib.Path, key: str, name: Optional[str] = None
|
||||||
):
|
):
|
||||||
"""Get a template from a JSON file. Requires a key and template name"""
|
"""Get a template from a JSON file. Requires a key and template name"""
|
||||||
if not json_path.exists():
|
if not json_path.exists():
|
||||||
raise TemplateLoadError(f'Model JSON path "{json_path}" not found.')
|
raise TemplateLoadError(f'Model JSON path "{json_path}" not found.')
|
||||||
|
|
||||||
with open(json_path, "r", encoding="utf8") as config_file:
|
async with aiofiles.open(json_path, "r", encoding="utf8") as config_file:
|
||||||
model_config = json.load(config_file)
|
contents = await config_file.read()
|
||||||
|
model_config = json.loads(contents)
|
||||||
chat_template = model_config.get(key)
|
chat_template = model_config.get(key)
|
||||||
|
|
||||||
if not chat_template:
|
if not chat_template:
|
||||||
|
|
|
||||||
|
|
@ -1,3 +1,4 @@
|
||||||
|
import aiofiles
|
||||||
import json
|
import json
|
||||||
import pathlib
|
import pathlib
|
||||||
from typing import List, Optional, Union
|
from typing import List, Optional, Union
|
||||||
|
|
@ -15,11 +16,11 @@ class GenerationConfig(BaseModel):
|
||||||
bad_words_ids: Optional[List[List[int]]] = None
|
bad_words_ids: Optional[List[List[int]]] = None
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_file(self, model_directory: pathlib.Path):
|
async def from_file(self, model_directory: pathlib.Path):
|
||||||
"""Create an instance from a generation config file."""
|
"""Create an instance from a generation config file."""
|
||||||
|
|
||||||
generation_config_path = model_directory / "generation_config.json"
|
generation_config_path = model_directory / "generation_config.json"
|
||||||
with open(
|
async with aiofiles.open(
|
||||||
generation_config_path, "r", encoding="utf8"
|
generation_config_path, "r", encoding="utf8"
|
||||||
) as generation_config_json:
|
) as generation_config_json:
|
||||||
generation_config_dict = json.load(generation_config_json)
|
generation_config_dict = json.load(generation_config_json)
|
||||||
|
|
@ -43,12 +44,15 @@ class HuggingFaceConfig(BaseModel):
|
||||||
badwordsids: Optional[str] = None
|
badwordsids: Optional[str] = None
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_file(self, model_directory: pathlib.Path):
|
async def from_file(self, model_directory: pathlib.Path):
|
||||||
"""Create an instance from a generation config file."""
|
"""Create an instance from a generation config file."""
|
||||||
|
|
||||||
hf_config_path = model_directory / "config.json"
|
hf_config_path = model_directory / "config.json"
|
||||||
with open(hf_config_path, "r", encoding="utf8") as hf_config_json:
|
async with aiofiles.open(
|
||||||
hf_config_dict = json.load(hf_config_json)
|
hf_config_path, "r", encoding="utf8"
|
||||||
|
) as hf_config_json:
|
||||||
|
contents = await hf_config_json.read()
|
||||||
|
hf_config_dict = json.loads(contents)
|
||||||
return self.model_validate(hf_config_dict)
|
return self.model_validate(hf_config_dict)
|
||||||
|
|
||||||
def get_badwordsids(self):
|
def get_badwordsids(self):
|
||||||
|
|
|
||||||
|
|
@ -446,7 +446,7 @@ async def switch_template(data: TemplateSwitchRequest):
|
||||||
|
|
||||||
try:
|
try:
|
||||||
template_path = pathlib.Path("templates") / data.name
|
template_path = pathlib.Path("templates") / data.name
|
||||||
model.container.prompt_template = PromptTemplate.from_file(template_path)
|
model.container.prompt_template = await PromptTemplate.from_file(template_path)
|
||||||
except FileNotFoundError as e:
|
except FileNotFoundError as e:
|
||||||
error_message = handle_request_error(
|
error_message = handle_request_error(
|
||||||
f"The template name {data.name} doesn't exist. Check the spelling?",
|
f"The template name {data.name} doesn't exist. Check the spelling?",
|
||||||
|
|
@ -495,7 +495,7 @@ async def switch_sampler_override(data: SamplerOverrideSwitchRequest):
|
||||||
|
|
||||||
if data.preset:
|
if data.preset:
|
||||||
try:
|
try:
|
||||||
sampling.overrides_from_file(data.preset)
|
await sampling.overrides_from_file(data.preset)
|
||||||
except FileNotFoundError as e:
|
except FileNotFoundError as e:
|
||||||
error_message = handle_request_error(
|
error_message = handle_request_error(
|
||||||
f"Sampler override preset with name {data.preset} does not exist. "
|
f"Sampler override preset with name {data.preset} does not exist. "
|
||||||
|
|
|
||||||
4
main.py
4
main.py
|
|
@ -50,7 +50,7 @@ async def entrypoint_async():
|
||||||
port = fallback_port
|
port = fallback_port
|
||||||
|
|
||||||
# Initialize auth keys
|
# Initialize auth keys
|
||||||
load_auth_keys(unwrap(config.network.get("disable_auth"), False))
|
await load_auth_keys(unwrap(config.network.get("disable_auth"), False))
|
||||||
|
|
||||||
# Override the generation log options if given
|
# Override the generation log options if given
|
||||||
if config.logging:
|
if config.logging:
|
||||||
|
|
@ -62,7 +62,7 @@ async def entrypoint_async():
|
||||||
sampling_override_preset = config.sampling.get("override_preset")
|
sampling_override_preset = config.sampling.get("override_preset")
|
||||||
if sampling_override_preset:
|
if sampling_override_preset:
|
||||||
try:
|
try:
|
||||||
sampling.overrides_from_file(sampling_override_preset)
|
await sampling.overrides_from_file(sampling_override_preset)
|
||||||
except FileNotFoundError as e:
|
except FileNotFoundError as e:
|
||||||
logger.warning(str(e))
|
logger.warning(str(e))
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue