fix model names
This commit is contained in:
parent
05f1c3e293
commit
8b48f00271
2 changed files with 38 additions and 35 deletions
|
|
@ -4,13 +4,13 @@ from typing import List, Optional, Union
|
||||||
from common.utils import unwrap
|
from common.utils import unwrap
|
||||||
|
|
||||||
|
|
||||||
class config_config_model(BaseModel):
|
class ConfigConfig(BaseModel):
|
||||||
config: Optional[str] = Field(
|
config: Optional[str] = Field(
|
||||||
None, description=("Path to an overriding config.yml file")
|
None, description=("Path to an overriding config.yml file")
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class network_config_model(BaseModel):
|
class NetworkConfig(BaseModel):
|
||||||
host: Optional[str] = Field("127.0.0.1", description=("The IP to host on"))
|
host: Optional[str] = Field("127.0.0.1", description=("The IP to host on"))
|
||||||
port: Optional[int] = Field(5000, description=("The port to host on"))
|
port: Optional[int] = Field(5000, description=("The port to host on"))
|
||||||
disable_auth: Optional[bool] = Field(
|
disable_auth: Optional[bool] = Field(
|
||||||
|
|
@ -28,7 +28,7 @@ class network_config_model(BaseModel):
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class logging_config_model(BaseModel):
|
class LoggingConfig(BaseModel):
|
||||||
log_prompt: Optional[bool] = Field(False, description=("Enable prompt logging"))
|
log_prompt: Optional[bool] = Field(False, description=("Enable prompt logging"))
|
||||||
log_generation_params: Optional[bool] = Field(
|
log_generation_params: Optional[bool] = Field(
|
||||||
False, description=("Enable generation parameter logging")
|
False, description=("Enable generation parameter logging")
|
||||||
|
|
@ -36,7 +36,7 @@ class logging_config_model(BaseModel):
|
||||||
log_requests: Optional[bool] = Field(False, description=("Enable request logging"))
|
log_requests: Optional[bool] = Field(False, description=("Enable request logging"))
|
||||||
|
|
||||||
|
|
||||||
class model_config_model(BaseModel):
|
class ModelConfig(BaseModel):
|
||||||
model_dir: str = Field(
|
model_dir: str = Field(
|
||||||
"models",
|
"models",
|
||||||
description=(
|
description=(
|
||||||
|
|
@ -171,8 +171,10 @@ class model_config_model(BaseModel):
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
model_config = ConfigDict(protected_namespaces=())
|
||||||
|
|
||||||
class draft_model_config_model(BaseModel):
|
|
||||||
|
class DraftModelConfig(BaseModel):
|
||||||
draft_model_dir: Optional[str] = Field(
|
draft_model_dir: Optional[str] = Field(
|
||||||
"models",
|
"models",
|
||||||
description=(
|
description=(
|
||||||
|
|
@ -209,18 +211,18 @@ class draft_model_config_model(BaseModel):
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class lora_instance_model(BaseModel):
|
class LoraInstanceModel(BaseModel):
|
||||||
name: str = Field(..., description=("Name of the LoRA model"))
|
name: str = Field(..., description=("Name of the LoRA model"))
|
||||||
scaling: float = Field(
|
scaling: float = Field(
|
||||||
1.0, description=("Scaling factor for the LoRA model (default: 1.0)")
|
1.0, description=("Scaling factor for the LoRA model (default: 1.0)")
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class lora_config_model(BaseModel):
|
class LoraConfig(BaseModel):
|
||||||
lora_dir: Optional[str] = Field(
|
lora_dir: Optional[str] = Field(
|
||||||
"loras", description=("Directory to look for LoRAs (default: 'loras')")
|
"loras", description=("Directory to look for LoRAs (default: 'loras')")
|
||||||
)
|
)
|
||||||
loras: Optional[List[lora_instance_model]] = Field(
|
loras: Optional[List[LoraInstanceModel]] = Field(
|
||||||
None,
|
None,
|
||||||
description=(
|
description=(
|
||||||
"List of LoRAs to load and associated scaling factors (default scaling:"
|
"List of LoRAs to load and associated scaling factors (default scaling:"
|
||||||
|
|
@ -229,13 +231,13 @@ class lora_config_model(BaseModel):
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class sampling_config_model(BaseModel):
|
class SamplingConfig(BaseModel):
|
||||||
override_preset: Optional[str] = Field(
|
override_preset: Optional[str] = Field(
|
||||||
None, description=("Select a sampler override preset")
|
None, description=("Select a sampler override preset")
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class developer_config_model(BaseModel):
|
class DeveloperConfig(BaseModel):
|
||||||
unsafe_launch: Optional[bool] = Field(
|
unsafe_launch: Optional[bool] = Field(
|
||||||
False, description=("Skip Exllamav2 version check")
|
False, description=("Skip Exllamav2 version check")
|
||||||
)
|
)
|
||||||
|
|
@ -257,7 +259,7 @@ class developer_config_model(BaseModel):
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class embeddings_config_model(BaseModel):
|
class EmbeddingsConfig(BaseModel):
|
||||||
embedding_model_dir: Optional[str] = Field(
|
embedding_model_dir: Optional[str] = Field(
|
||||||
"models",
|
"models",
|
||||||
description=(
|
description=(
|
||||||
|
|
@ -276,18 +278,20 @@ class embeddings_config_model(BaseModel):
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class tabby_config_model(BaseModel):
|
class TabbyConfigModel(BaseModel):
|
||||||
config: config_config_model = Field(default_factory=config_config_model)
|
config: ConfigConfig = Field(default_factory=ConfigConfig.model_construct)
|
||||||
network: network_config_model = Field(default_factory=network_config_model)
|
network: NetworkConfig = Field(default_factory=NetworkConfig.model_construct)
|
||||||
logging: logging_config_model = Field(default_factory=logging_config_model)
|
logging: LoggingConfig = Field(default_factory=LoggingConfig.model_construct)
|
||||||
model: model_config_model = Field(default_factory=model_config_model)
|
model: ModelConfig = Field(default_factory=ModelConfig.model_construct)
|
||||||
draft_model: draft_model_config_model = Field(
|
draft_model: DraftModelConfig = Field(
|
||||||
default_factory=draft_model_config_model
|
default_factory=DraftModelConfig.model_construct
|
||||||
|
)
|
||||||
|
lora: LoraConfig = Field(default_factory=LoraConfig.model_construct)
|
||||||
|
sampling: SamplingConfig = Field(default_factory=SamplingConfig.model_construct)
|
||||||
|
developer: DeveloperConfig = Field(default_factory=DeveloperConfig.model_construct)
|
||||||
|
embeddings: EmbeddingsConfig = Field(
|
||||||
|
default_factory=EmbeddingsConfig.model_construct
|
||||||
)
|
)
|
||||||
lora: lora_config_model = Field(default_factory=lora_config_model)
|
|
||||||
sampling: sampling_config_model = Field(default_factory=sampling_config_model)
|
|
||||||
developer: developer_config_model = Field(default_factory=developer_config_model)
|
|
||||||
embeddings: embeddings_config_model = Field(default_factory=embeddings_config_model)
|
|
||||||
|
|
||||||
@model_validator(mode="before")
|
@model_validator(mode="before")
|
||||||
def set_defaults(cls, values):
|
def set_defaults(cls, values):
|
||||||
|
|
@ -297,11 +301,11 @@ class tabby_config_model(BaseModel):
|
||||||
values[field_name] = cls.__annotations__[field_name](**default_instance)
|
values[field_name] = cls.__annotations__[field_name](**default_instance)
|
||||||
return values
|
return values
|
||||||
|
|
||||||
model_config = ConfigDict(validate_assignment=True)
|
model_config = ConfigDict(validate_assignment=True, protected_namespaces=())
|
||||||
|
|
||||||
|
|
||||||
def generate_config_file(filename="config_sample.yml", indentation=2):
|
def generate_config_file(filename="config_sample.yml", indentation=2):
|
||||||
schema = tabby_config_model.model_json_schema()
|
schema = TabbyConfigModel.model_json_schema()
|
||||||
|
|
||||||
def dump_def(id: str, indent=2):
|
def dump_def(id: str, indent=2):
|
||||||
yaml = ""
|
yaml = ""
|
||||||
|
|
|
||||||
|
|
@ -5,11 +5,10 @@ from typing import Optional
|
||||||
from os import getenv
|
from os import getenv
|
||||||
|
|
||||||
from common.utils import unwrap, merge_dicts
|
from common.utils import unwrap, merge_dicts
|
||||||
from common.config_models import tabby_config_model
|
from common.config_models import TabbyConfigModel
|
||||||
import common.config_models
|
|
||||||
|
|
||||||
|
|
||||||
class TabbyConfig(tabby_config_model):
|
class TabbyConfig(TabbyConfigModel):
|
||||||
# Persistent defaults
|
# Persistent defaults
|
||||||
# TODO: make this pydantic?
|
# TODO: make this pydantic?
|
||||||
model_defaults: dict = {}
|
model_defaults: dict = {}
|
||||||
|
|
@ -26,11 +25,11 @@ class TabbyConfig(tabby_config_model):
|
||||||
|
|
||||||
merged_config = merge_dicts(*configs)
|
merged_config = merge_dicts(*configs)
|
||||||
|
|
||||||
for field in tabby_config_model.model_fields.keys():
|
# validate and update config
|
||||||
value = unwrap(merged_config.get(field), {})
|
merged_config_model = TabbyConfigModel.model_validate(merged_config)
|
||||||
model = getattr(common.config_models, f"{field}_config_model")
|
for field in TabbyConfigModel.model_fields.keys():
|
||||||
|
value = getattr(merged_config_model, field)
|
||||||
setattr(self, field, model.parse_obj(value))
|
setattr(self, field, value)
|
||||||
|
|
||||||
# Set model defaults dict once to prevent on-demand reconstruction
|
# Set model defaults dict once to prevent on-demand reconstruction
|
||||||
# TODO: clean this up a bit
|
# TODO: clean this up a bit
|
||||||
|
|
@ -71,7 +70,7 @@ class TabbyConfig(tabby_config_model):
|
||||||
config = self._from_file(pathlib.Path(config_override))
|
config = self._from_file(pathlib.Path(config_override))
|
||||||
return config # Return early if loading from file
|
return config # Return early if loading from file
|
||||||
|
|
||||||
for key in tabby_config_model.model_fields.keys():
|
for key in TabbyConfigModel.model_fields.keys():
|
||||||
override = args.get(key)
|
override = args.get(key)
|
||||||
if override:
|
if override:
|
||||||
if key == "logging":
|
if key == "logging":
|
||||||
|
|
@ -86,10 +85,10 @@ class TabbyConfig(tabby_config_model):
|
||||||
|
|
||||||
config = {}
|
config = {}
|
||||||
|
|
||||||
for field_name in tabby_config_model.model_fields.keys():
|
for field_name in TabbyConfigModel.model_fields.keys():
|
||||||
section_config = {}
|
section_config = {}
|
||||||
for sub_field_name in getattr(
|
for sub_field_name in getattr(
|
||||||
tabby_config_model(), field_name
|
TabbyConfigModel(), field_name
|
||||||
).model_fields.keys():
|
).model_fields.keys():
|
||||||
setting = getenv(f"TABBY_{field_name}_{sub_field_name}".upper(), None)
|
setting = getenv(f"TABBY_{field_name}_{sub_field_name}".upper(), None)
|
||||||
if setting is not None:
|
if setting is not None:
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue