add legacy config converter

This commit is contained in:
TerminalMan 2024-09-16 14:12:47 +01:00
parent b6dd21f737
commit 564bdcf0a8
4 changed files with 71 additions and 19 deletions

View file

@ -4,7 +4,7 @@ import argparse
from pydantic import BaseModel
from common.config_models import TabbyConfigModel
from common.utils import is_list_type
from common.utils import is_list_type, unwrap_optional
def add_field_to_group(group, field_name, field_type, field) -> None:
@ -32,7 +32,7 @@ def init_argparser() -> argparse.ArgumentParser:
# Loop through each top-level field in the config
for field_name, field_info in TabbyConfigModel.model_fields.items():
field_type = field_info.annotation
field_type = unwrap_optional(field_info.annotation)
group = parser.add_argument_group(
field_name, description=f"Arguments for {field_name}"
)

View file

@ -439,22 +439,32 @@ class DeveloperConfig(BaseConfigModel):
class TabbyConfigModel(BaseModel):
"""Base model for a TabbyConfig."""
config: ConfigOverrideConfig = Field(
config: Optional[ConfigOverrideConfig] = Field(
default_factory=ConfigOverrideConfig.model_construct
)
network: NetworkConfig = Field(default_factory=NetworkConfig.model_construct)
logging: LoggingConfig = Field(default_factory=LoggingConfig.model_construct)
model: ModelConfig = Field(default_factory=ModelConfig.model_construct)
draft_model: DraftModelConfig = Field(
network: Optional[NetworkConfig] = Field(
default_factory=NetworkConfig.model_construct
)
logging: Optional[LoggingConfig] = Field(
default_factory=LoggingConfig.model_construct
)
model: Optional[ModelConfig] = Field(default_factory=ModelConfig.model_construct)
draft_model: Optional[DraftModelConfig] = Field(
default_factory=DraftModelConfig.model_construct
)
lora: LoraConfig = Field(default_factory=LoraConfig.model_construct)
embeddings: EmbeddingsConfig = Field(
lora: Optional[LoraConfig] = Field(default_factory=LoraConfig.model_construct)
embeddings: Optional[EmbeddingsConfig] = Field(
default_factory=EmbeddingsConfig.model_construct
)
sampling: SamplingConfig = Field(default_factory=SamplingConfig.model_construct)
developer: DeveloperConfig = Field(default_factory=DeveloperConfig.model_construct)
actions: UtilityActions = Field(default_factory=UtilityActions.model_construct)
sampling: Optional[SamplingConfig] = Field(
default_factory=SamplingConfig.model_construct
)
developer: Optional[DeveloperConfig] = Field(
default_factory=DeveloperConfig.model_construct
)
actions: Optional[UtilityActions] = Field(
default_factory=UtilityActions.model_construct
)
model_config = ConfigDict(validate_assignment=True, protected_namespaces=())

View file

@ -2,10 +2,10 @@ import yaml
import pathlib
from loguru import logger
from typing import Optional
from os import getenv
from os import getenv, replace
from common.utils import unwrap, merge_dicts
from common.config_models import TabbyConfigModel
from common.config_models import TabbyConfigModel, generate_config_file
class TabbyConfig(TabbyConfigModel):
@ -46,10 +46,25 @@ class TabbyConfig(TabbyConfigModel):
def _from_file(self, config_path: pathlib.Path):
"""loads config from a given file path"""
legacy = False
cfg = {}
# try loading from file
try:
with open(str(config_path.resolve()), "r", encoding="utf8") as config_file:
return unwrap(yaml.safe_load(config_file), {})
cfg = yaml.safe_load(config_file)
# FIXME: remove legacy config mapper
# load legacy config files
model = cfg.get("model", {})
if model.get("draft"):
legacy = True
cfg["draft"] = model["draft"]
if model.get("lora"):
legacy = True
cfg["lora"] = model["lora"]
except FileNotFoundError:
logger.info(f"The '{config_path.name}' file cannot be found")
except Exception as exc:
@ -58,8 +73,21 @@ class TabbyConfig(TabbyConfigModel):
f"the following error:\n\n{exc}"
)
# if no config file was loaded
return {}
if legacy:
logger.warning(
"legacy config.yml files are deprecated"
"Please upadte to the new version"
"Attempting auto migrationy"
)
new_cfg = TabbyConfigModel.model_validate(cfg)
try:
replace(config_path, f"{config_path}.bak")
generate_config_file(model=new_cfg, filename=config_path)
except Exception as e:
logger.error(f"Auto migration failed: {e}")
return unwrap(cfg, {})
def _from_args(self, args: dict):
"""loads config from the provided arguments"""

View file

@ -1,6 +1,7 @@
"""Common utility functions"""
from typing import get_args, get_origin
from types import NoneType
from typing import Optional, Type, Union, get_args, get_origin
def unwrap(wrapped, default=None):
@ -47,7 +48,7 @@ def flat_map(input_list):
return [item for sublist in input_list for item in sublist]
def is_list_type(type_hint):
def is_list_type(type_hint) -> bool:
"""Checks if a type contains a list."""
if get_origin(type_hint) is list:
@ -59,3 +60,16 @@ def is_list_type(type_hint):
return any(is_list_type(arg) for arg in type_args)
return False
def unwrap_optional(type_hint) -> Type:
"""unwrap Optional[type] annotations"""
if get_origin(type_hint) is Union:
args = get_args(type_hint)
if NoneType in args:
for arg in args:
if arg is not NoneType:
return arg
return type_hint