API: Add template list endpoint

Fetches all template names that a user has in the templates directory
for chat completions.

Signed-off-by: kingbri <bdashore3@proton.me>
This commit is contained in:
kingbri 2023-12-29 22:56:47 -05:00
parent dce8c74edc
commit 79a57588d5
3 changed files with 28 additions and 3 deletions

9
OAI/types/template.py Normal file
View file

@ -0,0 +1,9 @@
from pydantic import BaseModel, Field
from typing import List
class TemplateList(BaseModel):
"""Represents a list of templates."""
object: str = "list"
data: List[str] = Field(default_factory=list)

11
main.py
View file

@ -32,6 +32,7 @@ from OAI.types.model import (
ModelLoadResponse,
ModelCardParameters,
)
from OAI.types.template import TemplateList
from OAI.types.token import (
TokenEncodeRequest,
TokenEncodeResponse,
@ -45,7 +46,7 @@ from OAI.utils_oai import (
create_chat_completion_response,
create_chat_completion_stream_chunk,
)
from templating import get_prompt_from_template
from templating import get_all_templates, get_prompt_from_template
from utils import get_generator_error, get_sse_packet, load_progress, unwrap
from logger import init_logger
@ -244,6 +245,14 @@ async def unload_model():
MODEL_CONTAINER = None
@app.get("/v1/templates", dependencies=[Depends(check_api_key)])
@app.get("/v1/template/list", dependencies=[Depends(check_api_key)])
async def get_templates():
templates = get_all_templates()
template_strings = list(map(lambda template: template.stem, templates))
return TemplateList(data=template_strings)
# Lora list endpoint
@app.get("/v1/loras", dependencies=[Depends(check_api_key)])
@app.get("/v1/lora/list", dependencies=[Depends(check_api_key)])

View file

@ -57,11 +57,18 @@ def _compile_template(template: str):
return jinja_template
def get_all_templates():
"""Fetches all templates from the templates directory"""
template_directory = pathlib.Path("templates")
return template_directory.glob("*.jinja")
def find_template_from_model(model_path: pathlib.Path):
"""Find a matching template name from a model path."""
model_name = model_path.name
template_directory = pathlib.Path("templates")
for filepath in template_directory.glob("*.jinja"):
template_files = get_all_templates()
for filepath in template_files:
template_name = filepath.stem.lower()
# Check if the template name is present in the model name