fix model names

TerminalMan 2024-09-12 17:00:07 +01:00
parent 05f1c3e293
commit 8b48f00271
2 changed files with 38 additions and 35 deletions

View file

@@ -4,13 +4,13 @@ from typing import List, Optional, Union
 from common.utils import unwrap
 
 
-class config_config_model(BaseModel):
+class ConfigConfig(BaseModel):
     config: Optional[str] = Field(
         None, description=("Path to an overriding config.yml file")
     )
 
 
-class network_config_model(BaseModel):
+class NetworkConfig(BaseModel):
     host: Optional[str] = Field("127.0.0.1", description=("The IP to host on"))
     port: Optional[int] = Field(5000, description=("The port to host on"))
     disable_auth: Optional[bool] = Field(
@@ -28,7 +28,7 @@ class network_config_model(BaseModel):
     )
 
 
-class logging_config_model(BaseModel):
+class LoggingConfig(BaseModel):
     log_prompt: Optional[bool] = Field(False, description=("Enable prompt logging"))
     log_generation_params: Optional[bool] = Field(
         False, description=("Enable generation parameter logging")
@@ -36,7 +36,7 @@ class logging_config_model(BaseModel):
     log_requests: Optional[bool] = Field(False, description=("Enable request logging"))
 
 
-class model_config_model(BaseModel):
+class ModelConfig(BaseModel):
     model_dir: str = Field(
         "models",
         description=(
@@ -171,8 +171,10 @@ class model_config_model(BaseModel):
         ),
     )
 
+    model_config = ConfigDict(protected_namespaces=())
+
 
-class draft_model_config_model(BaseModel):
+class DraftModelConfig(BaseModel):
     draft_model_dir: Optional[str] = Field(
         "models",
         description=(
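A note on the ConfigDict line added in the hunk above: Pydantic v2 treats the model_ prefix as a protected namespace, so a field such as model_dir triggers a UserWarning unless the guard is cleared with protected_namespaces=(). A minimal sketch of that behaviour, trimmed to a single field (illustration only, not code from this commit):

# Sketch: why ModelConfig declares ConfigDict(protected_namespaces=()).
from pydantic import BaseModel, ConfigDict, Field


class ModelConfig(BaseModel):
    model_dir: str = Field("models", description="Directory to look for models")

    # Without this, Pydantic v2 warns that "model_dir" conflicts with the
    # protected "model_" namespace.
    model_config = ConfigDict(protected_namespaces=())


print(ModelConfig().model_dir)  # -> "models"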
@@ -209,18 +211,18 @@ class draft_model_config_model(BaseModel):
     )
 
 
-class lora_instance_model(BaseModel):
+class LoraInstanceModel(BaseModel):
     name: str = Field(..., description=("Name of the LoRA model"))
     scaling: float = Field(
         1.0, description=("Scaling factor for the LoRA model (default: 1.0)")
     )
 
 
-class lora_config_model(BaseModel):
+class LoraConfig(BaseModel):
     lora_dir: Optional[str] = Field(
         "loras", description=("Directory to look for LoRAs (default: 'loras')")
     )
-    loras: Optional[List[lora_instance_model]] = Field(
+    loras: Optional[List[LoraInstanceModel]] = Field(
         None,
         description=(
             "List of LoRAs to load and associated scaling factors (default scaling:"
@@ -229,13 +231,13 @@ class lora_config_model(BaseModel):
     )
 
 
-class sampling_config_model(BaseModel):
+class SamplingConfig(BaseModel):
     override_preset: Optional[str] = Field(
         None, description=("Select a sampler override preset")
     )
 
 
-class developer_config_model(BaseModel):
+class DeveloperConfig(BaseModel):
     unsafe_launch: Optional[bool] = Field(
         False, description=("Skip Exllamav2 version check")
     )
@@ -257,7 +259,7 @@ class developer_config_model(BaseModel):
     )
 
 
-class embeddings_config_model(BaseModel):
+class EmbeddingsConfig(BaseModel):
     embedding_model_dir: Optional[str] = Field(
         "models",
         description=(
@@ -276,18 +278,20 @@ class embeddings_config_model(BaseModel):
     )
 
 
-class tabby_config_model(BaseModel):
-    config: config_config_model = Field(default_factory=config_config_model)
-    network: network_config_model = Field(default_factory=network_config_model)
-    logging: logging_config_model = Field(default_factory=logging_config_model)
-    model: model_config_model = Field(default_factory=model_config_model)
-    draft_model: draft_model_config_model = Field(
-        default_factory=draft_model_config_model
+class TabbyConfigModel(BaseModel):
+    config: ConfigConfig = Field(default_factory=ConfigConfig.model_construct)
+    network: NetworkConfig = Field(default_factory=NetworkConfig.model_construct)
+    logging: LoggingConfig = Field(default_factory=LoggingConfig.model_construct)
+    model: ModelConfig = Field(default_factory=ModelConfig.model_construct)
+    draft_model: DraftModelConfig = Field(
+        default_factory=DraftModelConfig.model_construct
     )
-    lora: lora_config_model = Field(default_factory=lora_config_model)
-    sampling: sampling_config_model = Field(default_factory=sampling_config_model)
-    developer: developer_config_model = Field(default_factory=developer_config_model)
-    embeddings: embeddings_config_model = Field(default_factory=embeddings_config_model)
+    lora: LoraConfig = Field(default_factory=LoraConfig.model_construct)
+    sampling: SamplingConfig = Field(default_factory=SamplingConfig.model_construct)
+    developer: DeveloperConfig = Field(default_factory=DeveloperConfig.model_construct)
+    embeddings: EmbeddingsConfig = Field(
+        default_factory=EmbeddingsConfig.model_construct
+    )
 
     @model_validator(mode="before")
     def set_defaults(cls, values):
@@ -297,11 +301,11 @@ class tabby_config_model(BaseModel):
                 values[field_name] = cls.__annotations__[field_name](**default_instance)
         return values
 
-    model_config = ConfigDict(validate_assignment=True)
+    model_config = ConfigDict(validate_assignment=True, protected_namespaces=())
 
 
 def generate_config_file(filename="config_sample.yml", indentation=2):
-    schema = tabby_config_model.model_json_schema()
+    schema = TabbyConfigModel.model_json_schema()
 
     def dump_def(id: str, indent=2):
         yaml = ""

View file

@@ -5,11 +5,10 @@ from typing import Optional
 from os import getenv
 
 from common.utils import unwrap, merge_dicts
-from common.config_models import tabby_config_model
-import common.config_models
+from common.config_models import TabbyConfigModel
 
 
-class TabbyConfig(tabby_config_model):
+class TabbyConfig(TabbyConfigModel):
     # Persistent defaults
     # TODO: make this pydantic?
     model_defaults: dict = {}
@@ -26,11 +25,11 @@ class TabbyConfig(tabby_config_model):
         merged_config = merge_dicts(*configs)
 
-        for field in tabby_config_model.model_fields.keys():
-            value = unwrap(merged_config.get(field), {})
-            model = getattr(common.config_models, f"{field}_config_model")
-            setattr(self, field, model.parse_obj(value))
+        # validate and update config
+        merged_config_model = TabbyConfigModel.model_validate(merged_config)
+
+        for field in TabbyConfigModel.model_fields.keys():
+            value = getattr(merged_config_model, field)
+            setattr(self, field, value)
 
         # Set model defaults dict once to prevent on-demand reconstruction
         # TODO: clean this up a bit
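A note on the rewritten load path in the hunk above: instead of looking up each section model by its old snake_case name via getattr on common.config_models and calling parse_obj, the merged dict is validated once with TabbyConfigModel.model_validate and the resulting sub-models are copied onto the config object. A rough, self-contained sketch of that step (the sample values are hypothetical):

# Sketch: validating a merged config dict in one pass with Pydantic v2.
from typing import Optional

from pydantic import BaseModel, Field


class NetworkConfig(BaseModel):
    host: Optional[str] = Field("127.0.0.1")
    port: Optional[int] = Field(5000)


class TabbyConfigModel(BaseModel):
    network: NetworkConfig = Field(default_factory=NetworkConfig.model_construct)


# Hypothetical merge of config.yml, CLI arguments, and environment overrides.
merged_config = {"network": {"port": "5001"}}
merged_config_model = TabbyConfigModel.model_validate(merged_config)

for field in TabbyConfigModel.model_fields.keys():
    # The string "5001" is coerced to an int during validation.
    print(field, getattr(merged_config_model, field))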
@@ -71,7 +70,7 @@ class TabbyConfig(tabby_config_model):
             config = self._from_file(pathlib.Path(config_override))
             return config  # Return early if loading from file
 
-        for key in tabby_config_model.model_fields.keys():
+        for key in TabbyConfigModel.model_fields.keys():
             override = args.get(key)
             if override:
                 if key == "logging":
@@ -86,10 +85,10 @@ class TabbyConfig(tabby_config_model):
         config = {}
 
-        for field_name in tabby_config_model.model_fields.keys():
+        for field_name in TabbyConfigModel.model_fields.keys():
             section_config = {}
             for sub_field_name in getattr(
-                tabby_config_model(), field_name
+                TabbyConfigModel(), field_name
             ).model_fields.keys():
                 setting = getenv(f"TABBY_{field_name}_{sub_field_name}".upper(), None)
                 if setting is not None:
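The last hunk is truncated above, but the intent is visible: environment variables named TABBY_<SECTION>_<FIELD> are scanned for every field of every TabbyConfigModel section. A hedged reconstruction of that loop with hypothetical section data (not the repo's exact code):

# Sketch: collecting TABBY_<SECTION>_<FIELD> overrides into a nested dict.
from os import getenv

# Hypothetical stand-in for iterating TabbyConfigModel's sections and fields.
sections = {"network": ["host", "port"], "logging": ["log_prompt"]}

config = {}
for field_name, sub_fields in sections.items():
    section_config = {}
    for sub_field_name in sub_fields:
        # e.g. TABBY_NETWORK_PORT=5001
        setting = getenv(f"TABBY_{field_name}_{sub_field_name}".upper(), None)
        if setting is not None:
            section_config[sub_field_name] = setting
    if section_config:
        config[field_name] = section_config

print(config)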