add legacy config converter
This commit is contained in:
parent
b6dd21f737
commit
564bdcf0a8
4 changed files with 71 additions and 19 deletions
|
|
@ -4,7 +4,7 @@ import argparse
|
|||
from pydantic import BaseModel
|
||||
|
||||
from common.config_models import TabbyConfigModel
|
||||
from common.utils import is_list_type
|
||||
from common.utils import is_list_type, unwrap_optional
|
||||
|
||||
|
||||
def add_field_to_group(group, field_name, field_type, field) -> None:
|
||||
|
|
@ -32,7 +32,7 @@ def init_argparser() -> argparse.ArgumentParser:
|
|||
|
||||
# Loop through each top-level field in the config
|
||||
for field_name, field_info in TabbyConfigModel.model_fields.items():
|
||||
field_type = field_info.annotation
|
||||
field_type = unwrap_optional(field_info.annotation)
|
||||
group = parser.add_argument_group(
|
||||
field_name, description=f"Arguments for {field_name}"
|
||||
)
|
||||
|
|
|
|||
|
|
@ -439,22 +439,32 @@ class DeveloperConfig(BaseConfigModel):
|
|||
class TabbyConfigModel(BaseModel):
|
||||
"""Base model for a TabbyConfig."""
|
||||
|
||||
config: ConfigOverrideConfig = Field(
|
||||
config: Optional[ConfigOverrideConfig] = Field(
|
||||
default_factory=ConfigOverrideConfig.model_construct
|
||||
)
|
||||
network: NetworkConfig = Field(default_factory=NetworkConfig.model_construct)
|
||||
logging: LoggingConfig = Field(default_factory=LoggingConfig.model_construct)
|
||||
model: ModelConfig = Field(default_factory=ModelConfig.model_construct)
|
||||
draft_model: DraftModelConfig = Field(
|
||||
network: Optional[NetworkConfig] = Field(
|
||||
default_factory=NetworkConfig.model_construct
|
||||
)
|
||||
logging: Optional[LoggingConfig] = Field(
|
||||
default_factory=LoggingConfig.model_construct
|
||||
)
|
||||
model: Optional[ModelConfig] = Field(default_factory=ModelConfig.model_construct)
|
||||
draft_model: Optional[DraftModelConfig] = Field(
|
||||
default_factory=DraftModelConfig.model_construct
|
||||
)
|
||||
lora: LoraConfig = Field(default_factory=LoraConfig.model_construct)
|
||||
embeddings: EmbeddingsConfig = Field(
|
||||
lora: Optional[LoraConfig] = Field(default_factory=LoraConfig.model_construct)
|
||||
embeddings: Optional[EmbeddingsConfig] = Field(
|
||||
default_factory=EmbeddingsConfig.model_construct
|
||||
)
|
||||
sampling: SamplingConfig = Field(default_factory=SamplingConfig.model_construct)
|
||||
developer: DeveloperConfig = Field(default_factory=DeveloperConfig.model_construct)
|
||||
actions: UtilityActions = Field(default_factory=UtilityActions.model_construct)
|
||||
sampling: Optional[SamplingConfig] = Field(
|
||||
default_factory=SamplingConfig.model_construct
|
||||
)
|
||||
developer: Optional[DeveloperConfig] = Field(
|
||||
default_factory=DeveloperConfig.model_construct
|
||||
)
|
||||
actions: Optional[UtilityActions] = Field(
|
||||
default_factory=UtilityActions.model_construct
|
||||
)
|
||||
|
||||
model_config = ConfigDict(validate_assignment=True, protected_namespaces=())
|
||||
|
||||
|
|
|
|||
|
|
@ -2,10 +2,10 @@ import yaml
|
|||
import pathlib
|
||||
from loguru import logger
|
||||
from typing import Optional
|
||||
from os import getenv
|
||||
from os import getenv, replace
|
||||
|
||||
from common.utils import unwrap, merge_dicts
|
||||
from common.config_models import TabbyConfigModel
|
||||
from common.config_models import TabbyConfigModel, generate_config_file
|
||||
|
||||
|
||||
class TabbyConfig(TabbyConfigModel):
|
||||
|
|
@ -46,10 +46,25 @@ class TabbyConfig(TabbyConfigModel):
|
|||
def _from_file(self, config_path: pathlib.Path):
|
||||
"""loads config from a given file path"""
|
||||
|
||||
legacy = False
|
||||
cfg = {}
|
||||
|
||||
# try loading from file
|
||||
try:
|
||||
with open(str(config_path.resolve()), "r", encoding="utf8") as config_file:
|
||||
return unwrap(yaml.safe_load(config_file), {})
|
||||
cfg = yaml.safe_load(config_file)
|
||||
|
||||
# FIXME: remove legacy config mapper
|
||||
# load legacy config files
|
||||
model = cfg.get("model", {})
|
||||
|
||||
if model.get("draft"):
|
||||
legacy = True
|
||||
cfg["draft"] = model["draft"]
|
||||
if model.get("lora"):
|
||||
legacy = True
|
||||
cfg["lora"] = model["lora"]
|
||||
|
||||
except FileNotFoundError:
|
||||
logger.info(f"The '{config_path.name}' file cannot be found")
|
||||
except Exception as exc:
|
||||
|
|
@ -58,8 +73,21 @@ class TabbyConfig(TabbyConfigModel):
|
|||
f"the following error:\n\n{exc}"
|
||||
)
|
||||
|
||||
# if no config file was loaded
|
||||
return {}
|
||||
if legacy:
|
||||
logger.warning(
|
||||
"legacy config.yml files are deprecated"
|
||||
"Please upadte to the new version"
|
||||
"Attempting auto migrationy"
|
||||
)
|
||||
new_cfg = TabbyConfigModel.model_validate(cfg)
|
||||
|
||||
try:
|
||||
replace(config_path, f"{config_path}.bak")
|
||||
generate_config_file(model=new_cfg, filename=config_path)
|
||||
except Exception as e:
|
||||
logger.error(f"Auto migration failed: {e}")
|
||||
|
||||
return unwrap(cfg, {})
|
||||
|
||||
def _from_args(self, args: dict):
|
||||
"""loads config from the provided arguments"""
|
||||
|
|
|
|||
|
|
@ -1,6 +1,7 @@
|
|||
"""Common utility functions"""
|
||||
|
||||
from typing import get_args, get_origin
|
||||
from types import NoneType
|
||||
from typing import Optional, Type, Union, get_args, get_origin
|
||||
|
||||
|
||||
def unwrap(wrapped, default=None):
|
||||
|
|
@ -47,7 +48,7 @@ def flat_map(input_list):
|
|||
return [item for sublist in input_list for item in sublist]
|
||||
|
||||
|
||||
def is_list_type(type_hint):
|
||||
def is_list_type(type_hint) -> bool:
|
||||
"""Checks if a type contains a list."""
|
||||
|
||||
if get_origin(type_hint) is list:
|
||||
|
|
@ -59,3 +60,16 @@ def is_list_type(type_hint):
|
|||
return any(is_list_type(arg) for arg in type_args)
|
||||
|
||||
return False
|
||||
|
||||
|
||||
def unwrap_optional(type_hint) -> Type:
|
||||
"""unwrap Optional[type] annotations"""
|
||||
|
||||
if get_origin(type_hint) is Union:
|
||||
args = get_args(type_hint)
|
||||
if NoneType in args:
|
||||
for arg in args:
|
||||
if arg is not NoneType:
|
||||
return arg
|
||||
|
||||
return type_hint
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue