config is now backed by pydantic (WIP)

- add models for config options
- add function to regenerate config.yml
- replace references to config with pydantic compatible references
- remove unnecessary unwrap() statements

TODO:

- auto generate env vars
- auto generate argparse
- test loading a model
This commit is contained in:
Jake 2024-09-05 18:04:56 +01:00
parent cb91670c7a
commit 362b8d5818
11 changed files with 297 additions and 94 deletions

View file

@ -58,9 +58,7 @@ async def completion_request(
if isinstance(data.prompt, list):
data.prompt = "\n".join(data.prompt)
disable_request_streaming = unwrap(
config.developer.get("disable_request_streaming"), False
)
disable_request_streaming = config.developer.disable_request_streaming
# Set an empty JSON schema if the request wants a JSON response
if data.response_format.type == "json":
@ -117,9 +115,7 @@ async def chat_completion_request(
if data.response_format.type == "json":
data.json_schema = {"type": "object"}
disable_request_streaming = unwrap(
config.developer.get("disable_request_streaming"), False
)
disable_request_streaming = config.developer.disable_request_streaming
if data.stream and not disable_request_streaming:
return EventSourceResponse(

View file

@ -62,17 +62,17 @@ async def list_models(request: Request) -> ModelList:
Requires an admin key to see all models.
"""
model_dir = unwrap(config.model.get("model_dir"), "models")
model_dir = config.model.model_dir
model_path = pathlib.Path(model_dir)
draft_model_dir = config.draft_model.get("draft_model_dir")
draft_model_dir = config.draft_model.draft_model_dir
if get_key_permission(request) == "admin":
models = get_model_list(model_path.resolve(), draft_model_dir)
else:
models = await get_current_model_list()
if unwrap(config.model.get("use_dummy_models"), False):
if config.model.use_dummy_models:
models.data.insert(0, ModelCard(id="gpt-3.5-turbo"))
return models
@ -98,7 +98,7 @@ async def list_draft_models(request: Request) -> ModelList:
"""
if get_key_permission(request) == "admin":
draft_model_dir = unwrap(config.draft_model.get("draft_model_dir"), "models")
draft_model_dir = config.draft_model.draft_model_dir
draft_model_path = pathlib.Path(draft_model_dir)
models = get_model_list(draft_model_path.resolve())
@ -122,7 +122,7 @@ async def load_model(data: ModelLoadRequest) -> ModelLoadResponse:
raise HTTPException(400, error_message)
model_path = pathlib.Path(unwrap(config.model.get("model_dir"), "models"))
model_path = pathlib.Path(config.model.model_dir)
model_path = model_path / data.name
draft_model_path = None
@ -135,7 +135,7 @@ async def load_model(data: ModelLoadRequest) -> ModelLoadResponse:
raise HTTPException(400, error_message)
draft_model_path = unwrap(config.draft_model.get("draft_model_dir"), "models")
draft_model_path = config.draft_model.draft_model_dir
if not model_path.exists():
error_message = handle_request_error(
@ -192,7 +192,7 @@ async def list_all_loras(request: Request) -> LoraList:
"""
if get_key_permission(request) == "admin":
lora_path = pathlib.Path(unwrap(config.lora.get("lora_dir"), "loras"))
lora_path = pathlib.Path(config.lora.lora_dir)
loras = get_lora_list(lora_path.resolve())
else:
loras = get_active_loras()
@ -227,7 +227,7 @@ async def load_lora(data: LoraLoadRequest) -> LoraLoadResponse:
raise HTTPException(400, error_message)
lora_dir = pathlib.Path(unwrap(config.lora.get("lora_dir"), "loras"))
lora_dir = pathlib.Path(config.lora.lora_dir)
if not lora_dir.exists():
error_message = handle_request_error(
"A parent lora directory does not exist for load. Check your config.yml?",
@ -266,9 +266,7 @@ async def list_embedding_models(request: Request) -> ModelList:
"""
if get_key_permission(request) == "admin":
embedding_model_dir = unwrap(
config.embeddings.get("embedding_model_dir"), "models"
)
embedding_model_dir = config.embeddings.embedding_model_dir
embedding_model_path = pathlib.Path(embedding_model_dir)
models = get_model_list(embedding_model_path.resolve())
@ -302,9 +300,7 @@ async def load_embedding_model(
raise HTTPException(400, error_message)
embedding_model_dir = pathlib.Path(
unwrap(config.embeddings.get("embedding_model_dir"), "models")
)
embedding_model_dir = pathlib.Path(config.embeddings.embedding_model_dir)
embedding_model_path = embedding_model_dir / data.name
if not embedding_model_path.exists():

View file

@ -4,7 +4,7 @@ from pydantic import BaseModel, Field, ConfigDict
from time import time
from typing import List, Literal, Optional, Union
from common.gen_logging import GenLogPreferences
from common.config_models import logging_config_model
from common.model import get_config_default
@ -33,7 +33,7 @@ class ModelCard(BaseModel):
object: str = "model"
created: int = Field(default_factory=lambda: int(time()))
owned_by: str = "tabbyAPI"
logging: Optional[GenLogPreferences] = None
logging: Optional[logging_config_model] = None
parameters: Optional[ModelCardParameters] = None

View file

@ -36,7 +36,7 @@ def setup_app(host: Optional[str] = None, port: Optional[int] = None):
allow_headers=["*"],
)
api_servers = unwrap(config.network.get("api_servers"), [])
api_servers = config.network.api_servers
# Map for API id to server router
router_mapping = {"oai": OAIRouter, "kobold": KoboldRouter}