improve config generation action

This commit is contained in:
TerminalMan 2024-09-15 17:50:37 +01:00
parent f05229bce4
commit 92af656705
2 changed files with 52 additions and 34 deletions

View file

@ -20,7 +20,7 @@ def branch_to_actions() -> bool:
)
elif config.actions.export_config:
generate_config_file(config.actions.config_export_path)
generate_config_file(filename=config.actions.config_export_path)
else:
# did not branch

View file

@ -1,11 +1,25 @@
from pydantic import AliasChoices, BaseModel, ConfigDict, Field
from pydantic import AliasChoices, BaseModel, ConfigDict, Field, PrivateAttr
from typing import List, Literal, Optional, Union
from pathlib import Path
from pydantic_core import PydanticUndefined
CACHE_SIZES = Literal["FP16", "Q8", "Q6", "Q4"]
class ConfigOverrideConfig(BaseModel):
class Metadata(BaseModel):
"""metadata model for config options"""
include_in_config: Optional[bool] = Field(True)
class BaseConfigModel(BaseModel):
"""Base model for config models with added metadata"""
_metadata: Metadata = PrivateAttr(Metadata())
class ConfigOverrideConfig(BaseConfigModel):
"""Model for overriding a provided config file."""
# TODO: convert this to a pathlib.path?
@ -13,8 +27,10 @@ class ConfigOverrideConfig(BaseModel):
None, description=("Path to an overriding config.yml file")
)
_metadata: Metadata = PrivateAttr(Metadata(include_in_config=False))
class UtilityActions(BaseModel):
class UtilityActions(BaseConfigModel):
"""Model used for arg actions."""
# YAML export options
@ -33,8 +49,10 @@ class UtilityActions(BaseModel):
"openapi.json", description="path to export openapi schema to"
)
_metadata: Metadata = PrivateAttr(Metadata(include_in_config=False))
class NetworkConfig(BaseModel):
class NetworkConfig(BaseConfigModel):
"""Model for network configuration."""
host: Optional[str] = Field("127.0.0.1", description=("The IP to host on"))
@ -54,7 +72,7 @@ class NetworkConfig(BaseModel):
# TODO: Migrate config.yml to have the log_ prefix
# This is a breaking change.
class LoggingConfig(BaseModel):
class LoggingConfig(BaseConfigModel):
"""Model for logging configuration."""
log_prompt: Optional[bool] = Field(
@ -74,7 +92,7 @@ class LoggingConfig(BaseModel):
)
class ModelConfig(BaseModel):
class ModelConfig(BaseConfigModel):
"""Model for LLM configuration."""
# TODO: convert this to a pathlib.path?
@ -219,10 +237,11 @@ class ModelConfig(BaseModel):
),
)
_metadata: Metadata = PrivateAttr(Metadata())
model_config = ConfigDict(protected_namespaces=())
class DraftModelConfig(BaseModel):
class DraftModelConfig(BaseConfigModel):
"""Model for draft LLM model configuration."""
# TODO: convert this to a pathlib.path?
@ -262,7 +281,7 @@ class DraftModelConfig(BaseModel):
)
class LoraInstanceModel(BaseModel):
class LoraInstanceModel(BaseConfigModel):
"""Model representing an instance of a Lora."""
name: str = Field(..., description=("Name of the LoRA model"))
@ -273,7 +292,7 @@ class LoraInstanceModel(BaseModel):
)
class LoraConfig(BaseModel):
class LoraConfig(BaseConfigModel):
"""Model for lora configuration."""
# TODO: convert this to a pathlib.path?
@ -289,7 +308,7 @@ class LoraConfig(BaseModel):
)
class SamplingConfig(BaseModel):
class SamplingConfig(BaseConfigModel):
"""Model for sampling (overrides) config."""
override_preset: Optional[str] = Field(
@ -297,7 +316,7 @@ class SamplingConfig(BaseModel):
)
class DeveloperConfig(BaseModel):
class DeveloperConfig(BaseConfigModel):
"""Model for developer settings configuration."""
unsafe_launch: Optional[bool] = Field(
@ -321,7 +340,7 @@ class DeveloperConfig(BaseModel):
)
class EmbeddingsConfig(BaseModel):
class EmbeddingsConfig(BaseConfigModel):
"""Model for embeddings configuration."""
# TODO: convert this to a pathlib.path?
@ -366,30 +385,29 @@ class TabbyConfigModel(BaseModel):
model_config = ConfigDict(validate_assignment=True, protected_namespaces=())
def generate_config_file(filename="config_sample.yml", indentation=2):
def generate_config_file(
model: BaseConfigModel = None,
filename: str = "config_sample.yml",
indentation: int = 2,
) -> None:
"""Creates a config.yml file from Pydantic models."""
schema = TabbyConfigModel.model_json_schema()
def dump_def(id: str, indent=2):
yaml = ""
indent = " " * indentation * indent
id = id.split("/")[-1]
section = schema["$defs"][id]["properties"]
for property in section.keys(): # get type
comment = section[property]["description"]
yaml += f"{indent}# {comment}\n"
value = section[property].get("default", "")
yaml += f"{indent}{property}: {value}\n\n"
return yaml + "\n"
schema = model if model else TabbyConfigModel()
yaml = ""
for section in schema["properties"].keys():
yaml += f"{section}:\n"
yaml += dump_def(schema["properties"][section]["$ref"])
for field, field_data in schema.model_fields.items():
subfield_model = field_data.default_factory()
if not subfield_model._metadata.include_in_config:
continue
yaml += f"# {subfield_model.__doc__}\n"
yaml += f"{field}:\n"
for subfield, subfield_data in subfield_model.model_fields.items():
value = subfield_data.default
value = value if value is not None else ""
value = value if value is not PydanticUndefined else ""
yaml += f"{' ' * indentation}# {subfield_data.description}\n"
yaml += f"{' ' * indentation}{subfield}: {value}\n"
yaml += "\n"
with open(filename, "w") as f: