fix model names
This commit is contained in:
parent
05f1c3e293
commit
8b48f00271
2 changed files with 38 additions and 35 deletions
|
|
@ -4,13 +4,13 @@ from typing import List, Optional, Union
|
|||
from common.utils import unwrap
|
||||
|
||||
|
||||
class config_config_model(BaseModel):
|
||||
class ConfigConfig(BaseModel):
|
||||
config: Optional[str] = Field(
|
||||
None, description=("Path to an overriding config.yml file")
|
||||
)
|
||||
|
||||
|
||||
class network_config_model(BaseModel):
|
||||
class NetworkConfig(BaseModel):
|
||||
host: Optional[str] = Field("127.0.0.1", description=("The IP to host on"))
|
||||
port: Optional[int] = Field(5000, description=("The port to host on"))
|
||||
disable_auth: Optional[bool] = Field(
|
||||
|
|
@ -28,7 +28,7 @@ class network_config_model(BaseModel):
|
|||
)
|
||||
|
||||
|
||||
class logging_config_model(BaseModel):
|
||||
class LoggingConfig(BaseModel):
|
||||
log_prompt: Optional[bool] = Field(False, description=("Enable prompt logging"))
|
||||
log_generation_params: Optional[bool] = Field(
|
||||
False, description=("Enable generation parameter logging")
|
||||
|
|
@ -36,7 +36,7 @@ class logging_config_model(BaseModel):
|
|||
log_requests: Optional[bool] = Field(False, description=("Enable request logging"))
|
||||
|
||||
|
||||
class model_config_model(BaseModel):
|
||||
class ModelConfig(BaseModel):
|
||||
model_dir: str = Field(
|
||||
"models",
|
||||
description=(
|
||||
|
|
@ -171,8 +171,10 @@ class model_config_model(BaseModel):
|
|||
),
|
||||
)
|
||||
|
||||
model_config = ConfigDict(protected_namespaces=())
|
||||
|
||||
class draft_model_config_model(BaseModel):
|
||||
|
||||
class DraftModelConfig(BaseModel):
|
||||
draft_model_dir: Optional[str] = Field(
|
||||
"models",
|
||||
description=(
|
||||
|
|
@ -209,18 +211,18 @@ class draft_model_config_model(BaseModel):
|
|||
)
|
||||
|
||||
|
||||
class lora_instance_model(BaseModel):
|
||||
class LoraInstanceModel(BaseModel):
|
||||
name: str = Field(..., description=("Name of the LoRA model"))
|
||||
scaling: float = Field(
|
||||
1.0, description=("Scaling factor for the LoRA model (default: 1.0)")
|
||||
)
|
||||
|
||||
|
||||
class lora_config_model(BaseModel):
|
||||
class LoraConfig(BaseModel):
|
||||
lora_dir: Optional[str] = Field(
|
||||
"loras", description=("Directory to look for LoRAs (default: 'loras')")
|
||||
)
|
||||
loras: Optional[List[lora_instance_model]] = Field(
|
||||
loras: Optional[List[LoraInstanceModel]] = Field(
|
||||
None,
|
||||
description=(
|
||||
"List of LoRAs to load and associated scaling factors (default scaling:"
|
||||
|
|
@ -229,13 +231,13 @@ class lora_config_model(BaseModel):
|
|||
)
|
||||
|
||||
|
||||
class sampling_config_model(BaseModel):
|
||||
class SamplingConfig(BaseModel):
|
||||
override_preset: Optional[str] = Field(
|
||||
None, description=("Select a sampler override preset")
|
||||
)
|
||||
|
||||
|
||||
class developer_config_model(BaseModel):
|
||||
class DeveloperConfig(BaseModel):
|
||||
unsafe_launch: Optional[bool] = Field(
|
||||
False, description=("Skip Exllamav2 version check")
|
||||
)
|
||||
|
|
@ -257,7 +259,7 @@ class developer_config_model(BaseModel):
|
|||
)
|
||||
|
||||
|
||||
class embeddings_config_model(BaseModel):
|
||||
class EmbeddingsConfig(BaseModel):
|
||||
embedding_model_dir: Optional[str] = Field(
|
||||
"models",
|
||||
description=(
|
||||
|
|
@ -276,18 +278,20 @@ class embeddings_config_model(BaseModel):
|
|||
)
|
||||
|
||||
|
||||
class tabby_config_model(BaseModel):
|
||||
config: config_config_model = Field(default_factory=config_config_model)
|
||||
network: network_config_model = Field(default_factory=network_config_model)
|
||||
logging: logging_config_model = Field(default_factory=logging_config_model)
|
||||
model: model_config_model = Field(default_factory=model_config_model)
|
||||
draft_model: draft_model_config_model = Field(
|
||||
default_factory=draft_model_config_model
|
||||
class TabbyConfigModel(BaseModel):
|
||||
config: ConfigConfig = Field(default_factory=ConfigConfig.model_construct)
|
||||
network: NetworkConfig = Field(default_factory=NetworkConfig.model_construct)
|
||||
logging: LoggingConfig = Field(default_factory=LoggingConfig.model_construct)
|
||||
model: ModelConfig = Field(default_factory=ModelConfig.model_construct)
|
||||
draft_model: DraftModelConfig = Field(
|
||||
default_factory=DraftModelConfig.model_construct
|
||||
)
|
||||
lora: LoraConfig = Field(default_factory=LoraConfig.model_construct)
|
||||
sampling: SamplingConfig = Field(default_factory=SamplingConfig.model_construct)
|
||||
developer: DeveloperConfig = Field(default_factory=DeveloperConfig.model_construct)
|
||||
embeddings: EmbeddingsConfig = Field(
|
||||
default_factory=EmbeddingsConfig.model_construct
|
||||
)
|
||||
lora: lora_config_model = Field(default_factory=lora_config_model)
|
||||
sampling: sampling_config_model = Field(default_factory=sampling_config_model)
|
||||
developer: developer_config_model = Field(default_factory=developer_config_model)
|
||||
embeddings: embeddings_config_model = Field(default_factory=embeddings_config_model)
|
||||
|
||||
@model_validator(mode="before")
|
||||
def set_defaults(cls, values):
|
||||
|
|
@ -297,11 +301,11 @@ class tabby_config_model(BaseModel):
|
|||
values[field_name] = cls.__annotations__[field_name](**default_instance)
|
||||
return values
|
||||
|
||||
model_config = ConfigDict(validate_assignment=True)
|
||||
model_config = ConfigDict(validate_assignment=True, protected_namespaces=())
|
||||
|
||||
|
||||
def generate_config_file(filename="config_sample.yml", indentation=2):
|
||||
schema = tabby_config_model.model_json_schema()
|
||||
schema = TabbyConfigModel.model_json_schema()
|
||||
|
||||
def dump_def(id: str, indent=2):
|
||||
yaml = ""
|
||||
|
|
|
|||
|
|
@ -5,11 +5,10 @@ from typing import Optional
|
|||
from os import getenv
|
||||
|
||||
from common.utils import unwrap, merge_dicts
|
||||
from common.config_models import tabby_config_model
|
||||
import common.config_models
|
||||
from common.config_models import TabbyConfigModel
|
||||
|
||||
|
||||
class TabbyConfig(tabby_config_model):
|
||||
class TabbyConfig(TabbyConfigModel):
|
||||
# Persistent defaults
|
||||
# TODO: make this pydantic?
|
||||
model_defaults: dict = {}
|
||||
|
|
@ -26,11 +25,11 @@ class TabbyConfig(tabby_config_model):
|
|||
|
||||
merged_config = merge_dicts(*configs)
|
||||
|
||||
for field in tabby_config_model.model_fields.keys():
|
||||
value = unwrap(merged_config.get(field), {})
|
||||
model = getattr(common.config_models, f"{field}_config_model")
|
||||
|
||||
setattr(self, field, model.parse_obj(value))
|
||||
# validate and update config
|
||||
merged_config_model = TabbyConfigModel.model_validate(merged_config)
|
||||
for field in TabbyConfigModel.model_fields.keys():
|
||||
value = getattr(merged_config_model, field)
|
||||
setattr(self, field, value)
|
||||
|
||||
# Set model defaults dict once to prevent on-demand reconstruction
|
||||
# TODO: clean this up a bit
|
||||
|
|
@ -71,7 +70,7 @@ class TabbyConfig(tabby_config_model):
|
|||
config = self._from_file(pathlib.Path(config_override))
|
||||
return config # Return early if loading from file
|
||||
|
||||
for key in tabby_config_model.model_fields.keys():
|
||||
for key in TabbyConfigModel.model_fields.keys():
|
||||
override = args.get(key)
|
||||
if override:
|
||||
if key == "logging":
|
||||
|
|
@ -86,10 +85,10 @@ class TabbyConfig(tabby_config_model):
|
|||
|
||||
config = {}
|
||||
|
||||
for field_name in tabby_config_model.model_fields.keys():
|
||||
for field_name in TabbyConfigModel.model_fields.keys():
|
||||
section_config = {}
|
||||
for sub_field_name in getattr(
|
||||
tabby_config_model(), field_name
|
||||
TabbyConfigModel(), field_name
|
||||
).model_fields.keys():
|
||||
setting = getenv(f"TABBY_{field_name}_{sub_field_name}".upper(), None)
|
||||
if setting is not None:
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue