improve config generation action

This commit is contained in:
TerminalMan 2024-09-15 17:50:37 +01:00
parent f05229bce4
commit 92af656705
2 changed files with 52 additions and 34 deletions

View file

@ -20,7 +20,7 @@ def branch_to_actions() -> bool:
) )
elif config.actions.export_config: elif config.actions.export_config:
generate_config_file(config.actions.config_export_path) generate_config_file(filename=config.actions.config_export_path)
else: else:
# did not branch # did not branch

View file

@ -1,11 +1,25 @@
from pydantic import AliasChoices, BaseModel, ConfigDict, Field from pydantic import AliasChoices, BaseModel, ConfigDict, Field, PrivateAttr
from typing import List, Literal, Optional, Union from typing import List, Literal, Optional, Union
from pathlib import Path from pathlib import Path
from pydantic_core import PydanticUndefined
CACHE_SIZES = Literal["FP16", "Q8", "Q6", "Q4"] CACHE_SIZES = Literal["FP16", "Q8", "Q6", "Q4"]
class ConfigOverrideConfig(BaseModel): class Metadata(BaseModel):
"""metadata model for config options"""
include_in_config: Optional[bool] = Field(True)
class BaseConfigModel(BaseModel):
"""Base model for config models with added metadata"""
_metadata: Metadata = PrivateAttr(Metadata())
class ConfigOverrideConfig(BaseConfigModel):
"""Model for overriding a provided config file.""" """Model for overriding a provided config file."""
# TODO: convert this to a pathlib.path? # TODO: convert this to a pathlib.path?
@ -13,8 +27,10 @@ class ConfigOverrideConfig(BaseModel):
None, description=("Path to an overriding config.yml file") None, description=("Path to an overriding config.yml file")
) )
_metadata: Metadata = PrivateAttr(Metadata(include_in_config=False))
class UtilityActions(BaseModel):
class UtilityActions(BaseConfigModel):
"""Model used for arg actions.""" """Model used for arg actions."""
# YAML export options # YAML export options
@ -33,8 +49,10 @@ class UtilityActions(BaseModel):
"openapi.json", description="path to export openapi schema to" "openapi.json", description="path to export openapi schema to"
) )
_metadata: Metadata = PrivateAttr(Metadata(include_in_config=False))
class NetworkConfig(BaseModel):
class NetworkConfig(BaseConfigModel):
"""Model for network configuration.""" """Model for network configuration."""
host: Optional[str] = Field("127.0.0.1", description=("The IP to host on")) host: Optional[str] = Field("127.0.0.1", description=("The IP to host on"))
@ -54,7 +72,7 @@ class NetworkConfig(BaseModel):
# TODO: Migrate config.yml to have the log_ prefix # TODO: Migrate config.yml to have the log_ prefix
# This is a breaking change. # This is a breaking change.
class LoggingConfig(BaseModel): class LoggingConfig(BaseConfigModel):
"""Model for logging configuration.""" """Model for logging configuration."""
log_prompt: Optional[bool] = Field( log_prompt: Optional[bool] = Field(
@ -74,7 +92,7 @@ class LoggingConfig(BaseModel):
) )
class ModelConfig(BaseModel): class ModelConfig(BaseConfigModel):
"""Model for LLM configuration.""" """Model for LLM configuration."""
# TODO: convert this to a pathlib.path? # TODO: convert this to a pathlib.path?
@ -219,10 +237,11 @@ class ModelConfig(BaseModel):
), ),
) )
_metadata: Metadata = PrivateAttr(Metadata())
model_config = ConfigDict(protected_namespaces=()) model_config = ConfigDict(protected_namespaces=())
class DraftModelConfig(BaseModel): class DraftModelConfig(BaseConfigModel):
"""Model for draft LLM model configuration.""" """Model for draft LLM model configuration."""
# TODO: convert this to a pathlib.path? # TODO: convert this to a pathlib.path?
@ -262,7 +281,7 @@ class DraftModelConfig(BaseModel):
) )
class LoraInstanceModel(BaseModel): class LoraInstanceModel(BaseConfigModel):
"""Model representing an instance of a Lora.""" """Model representing an instance of a Lora."""
name: str = Field(..., description=("Name of the LoRA model")) name: str = Field(..., description=("Name of the LoRA model"))
@ -273,7 +292,7 @@ class LoraInstanceModel(BaseModel):
) )
class LoraConfig(BaseModel): class LoraConfig(BaseConfigModel):
"""Model for lora configuration.""" """Model for lora configuration."""
# TODO: convert this to a pathlib.path? # TODO: convert this to a pathlib.path?
@ -289,7 +308,7 @@ class LoraConfig(BaseModel):
) )
class SamplingConfig(BaseModel): class SamplingConfig(BaseConfigModel):
"""Model for sampling (overrides) config.""" """Model for sampling (overrides) config."""
override_preset: Optional[str] = Field( override_preset: Optional[str] = Field(
@ -297,7 +316,7 @@ class SamplingConfig(BaseModel):
) )
class DeveloperConfig(BaseModel): class DeveloperConfig(BaseConfigModel):
"""Model for developer settings configuration.""" """Model for developer settings configuration."""
unsafe_launch: Optional[bool] = Field( unsafe_launch: Optional[bool] = Field(
@ -321,7 +340,7 @@ class DeveloperConfig(BaseModel):
) )
class EmbeddingsConfig(BaseModel): class EmbeddingsConfig(BaseConfigModel):
"""Model for embeddings configuration.""" """Model for embeddings configuration."""
# TODO: convert this to a pathlib.path? # TODO: convert this to a pathlib.path?
@ -366,30 +385,29 @@ class TabbyConfigModel(BaseModel):
model_config = ConfigDict(validate_assignment=True, protected_namespaces=()) model_config = ConfigDict(validate_assignment=True, protected_namespaces=())
def generate_config_file(filename="config_sample.yml", indentation=2): def generate_config_file(
model: BaseConfigModel = None,
filename: str = "config_sample.yml",
indentation: int = 2,
) -> None:
"""Creates a config.yml file from Pydantic models.""" """Creates a config.yml file from Pydantic models."""
schema = TabbyConfigModel.model_json_schema() schema = model if model else TabbyConfigModel()
def dump_def(id: str, indent=2):
yaml = ""
indent = " " * indentation * indent
id = id.split("/")[-1]
section = schema["$defs"][id]["properties"]
for property in section.keys(): # get type
comment = section[property]["description"]
yaml += f"{indent}# {comment}\n"
value = section[property].get("default", "")
yaml += f"{indent}{property}: {value}\n\n"
return yaml + "\n"
yaml = "" yaml = ""
for section in schema["properties"].keys():
yaml += f"{section}:\n" for field, field_data in schema.model_fields.items():
yaml += dump_def(schema["properties"][section]["$ref"]) subfield_model = field_data.default_factory()
if not subfield_model._metadata.include_in_config:
continue
yaml += f"# {subfield_model.__doc__}\n"
yaml += f"{field}:\n"
for subfield, subfield_data in subfield_model.model_fields.items():
value = subfield_data.default
value = value if value is not None else ""
value = value if value is not PydanticUndefined else ""
yaml += f"{' ' * indentation}# {subfield_data.description}\n"
yaml += f"{' ' * indentation}{subfield}: {value}\n"
yaml += "\n" yaml += "\n"
with open(filename, "w") as f: with open(filename, "w") as f: