From 79a57588d5080da97960469f7442e2aff28ffed5 Mon Sep 17 00:00:00 2001 From: kingbri Date: Fri, 29 Dec 2023 22:56:47 -0500 Subject: [PATCH] API: Add template list endpoint Fetches all template names that a user has in the templates directory for chat completions. Signed-off-by: kingbri --- OAI/types/template.py | 9 +++++++++ main.py | 11 ++++++++++- templating.py | 11 +++++++++-- 3 files changed, 28 insertions(+), 3 deletions(-) create mode 100644 OAI/types/template.py diff --git a/OAI/types/template.py b/OAI/types/template.py new file mode 100644 index 0000000..0374547 --- /dev/null +++ b/OAI/types/template.py @@ -0,0 +1,9 @@ +from pydantic import BaseModel, Field +from typing import List + + +class TemplateList(BaseModel): + """Represents a list of templates.""" + + object: str = "list" + data: List[str] = Field(default_factory=list) diff --git a/main.py b/main.py index 2f8484e..fc7c5d3 100644 --- a/main.py +++ b/main.py @@ -32,6 +32,7 @@ from OAI.types.model import ( ModelLoadResponse, ModelCardParameters, ) +from OAI.types.template import TemplateList from OAI.types.token import ( TokenEncodeRequest, TokenEncodeResponse, @@ -45,7 +46,7 @@ from OAI.utils_oai import ( create_chat_completion_response, create_chat_completion_stream_chunk, ) -from templating import get_prompt_from_template +from templating import get_all_templates, get_prompt_from_template from utils import get_generator_error, get_sse_packet, load_progress, unwrap from logger import init_logger @@ -244,6 +245,14 @@ async def unload_model(): MODEL_CONTAINER = None +@app.get("/v1/templates", dependencies=[Depends(check_api_key)]) +@app.get("/v1/template/list", dependencies=[Depends(check_api_key)]) +async def get_templates(): + templates = get_all_templates() + template_strings = list(map(lambda template: template.stem, templates)) + return TemplateList(data=template_strings) + + # Lora list endpoint @app.get("/v1/loras", dependencies=[Depends(check_api_key)]) @app.get("/v1/lora/list", dependencies=[Depends(check_api_key)]) diff --git a/templating.py b/templating.py index fb4f030..ddc0ca1 100644 --- a/templating.py +++ b/templating.py @@ -57,11 +57,18 @@ def _compile_template(template: str): return jinja_template +def get_all_templates(): + """Fetches all templates from the templates directory""" + + template_directory = pathlib.Path("templates") + return template_directory.glob("*.jinja") + + def find_template_from_model(model_path: pathlib.Path): """Find a matching template name from a model path.""" model_name = model_path.name - template_directory = pathlib.Path("templates") - for filepath in template_directory.glob("*.jinja"): + template_files = get_all_templates() + for filepath in template_files: template_name = filepath.stem.lower() # Check if the template name is present in the model name