Tree: Use safe loader for YAML
Loaders that read use a safe type while loaders that write use both round-trip and safe options. Also don't create module-level parsers where they're not needed. Signed-off-by: kingbri <bdashore3@proton.me>
This commit is contained in:
parent
6c7542de9f
commit
24ea85b3c5
4 changed files with 18 additions and 12 deletions
|
|
@ -56,8 +56,6 @@ from common.templating import (
|
|||
from common.transformers_utils import GenerationConfig, HuggingFaceConfig
|
||||
from common.utils import coalesce, unwrap
|
||||
|
||||
yaml = YAML()
|
||||
|
||||
|
||||
class ExllamaV2Container:
|
||||
"""The model container class for ExLlamaV2 models."""
|
||||
|
|
@ -381,7 +379,10 @@ class ExllamaV2Container:
|
|||
override_config_path, "r", encoding="utf8"
|
||||
) as override_config_file:
|
||||
contents = await override_config_file.read()
|
||||
override_args = unwrap(yaml.safe_load(contents), {})
|
||||
|
||||
# Create a temporary YAML parser
|
||||
yaml = YAML(typ="safe")
|
||||
override_args = unwrap(yaml.load(contents), {})
|
||||
|
||||
# Merge draft overrides beforehand
|
||||
draft_override_args = unwrap(override_args.get("draft"), {})
|
||||
|
|
|
|||
|
|
@ -4,6 +4,7 @@ application, it should be fine.
|
|||
"""
|
||||
|
||||
import aiofiles
|
||||
import io
|
||||
import secrets
|
||||
from ruamel.yaml import YAML
|
||||
from fastapi import Header, HTTPException, Request
|
||||
|
|
@ -13,8 +14,6 @@ from typing import Optional
|
|||
|
||||
from common.utils import coalesce
|
||||
|
||||
yaml = YAML()
|
||||
|
||||
|
||||
class AuthKeys(BaseModel):
|
||||
"""
|
||||
|
|
@ -59,6 +58,9 @@ async def load_auth_keys(disable_from_config: bool):
|
|||
|
||||
return
|
||||
|
||||
# Create a temporary YAML parser
|
||||
yaml = YAML(typ=["rt", "safe"])
|
||||
|
||||
try:
|
||||
async with aiofiles.open("api_tokens.yml", "r", encoding="utf8") as auth_file:
|
||||
contents = await auth_file.read()
|
||||
|
|
@ -71,10 +73,12 @@ async def load_auth_keys(disable_from_config: bool):
|
|||
AUTH_KEYS = new_auth_keys
|
||||
|
||||
async with aiofiles.open("api_tokens.yml", "w", encoding="utf8") as auth_file:
|
||||
new_auth_yaml = yaml.safe_dump(
|
||||
AUTH_KEYS.model_dump(), default_flow_style=False
|
||||
string_stream = io.StringIO()
|
||||
yaml.dump(
|
||||
AUTH_KEYS.model_dump(), string_stream
|
||||
)
|
||||
await auth_file.write(new_auth_yaml)
|
||||
|
||||
await auth_file.write(string_stream.getvalue())
|
||||
|
||||
logger.info(
|
||||
f"Your API key is: {AUTH_KEYS.api_key}\n"
|
||||
|
|
|
|||
|
|
@ -11,8 +11,6 @@ from typing import Dict, List, Optional, Union
|
|||
|
||||
from common.utils import unwrap, prune_dict
|
||||
|
||||
yaml = YAML()
|
||||
|
||||
|
||||
# Common class for sampler params
|
||||
class BaseSamplerRequest(BaseModel):
|
||||
|
|
@ -418,7 +416,10 @@ async def overrides_from_file(preset_name: str):
|
|||
overrides_container.selected_preset = preset_path.stem
|
||||
async with aiofiles.open(preset_path, "r", encoding="utf8") as raw_preset:
|
||||
contents = await raw_preset.read()
|
||||
preset = yaml.safe_load(contents)
|
||||
|
||||
# Create a temporary YAML parser
|
||||
yaml = YAML(typ="safe")
|
||||
preset = yaml.load(contents)
|
||||
overrides_from_dict(preset)
|
||||
|
||||
logger.info("Applied sampler overrides from file.")
|
||||
|
|
|
|||
|
|
@ -13,7 +13,7 @@ from ruamel.yaml.scalarstring import PreservedScalarString
|
|||
from common.config_models import BaseConfigModel, TabbyConfigModel
|
||||
from common.utils import merge_dicts, unwrap
|
||||
|
||||
yaml = YAML()
|
||||
yaml = YAML(typ=["rt", "safe"])
|
||||
|
||||
|
||||
class TabbyConfig(TabbyConfigModel):
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue