API: Add template list endpoint
Fetches all template names that a user has in the templates directory for chat completions. Signed-off-by: kingbri <bdashore3@proton.me>
This commit is contained in:
parent
dce8c74edc
commit
79a57588d5
3 changed files with 28 additions and 3 deletions
9
OAI/types/template.py
Normal file
9
OAI/types/template.py
Normal file
|
|
@ -0,0 +1,9 @@
|
|||
from pydantic import BaseModel, Field
|
||||
from typing import List
|
||||
|
||||
|
||||
class TemplateList(BaseModel):
|
||||
"""Represents a list of templates."""
|
||||
|
||||
object: str = "list"
|
||||
data: List[str] = Field(default_factory=list)
|
||||
11
main.py
11
main.py
|
|
@ -32,6 +32,7 @@ from OAI.types.model import (
|
|||
ModelLoadResponse,
|
||||
ModelCardParameters,
|
||||
)
|
||||
from OAI.types.template import TemplateList
|
||||
from OAI.types.token import (
|
||||
TokenEncodeRequest,
|
||||
TokenEncodeResponse,
|
||||
|
|
@ -45,7 +46,7 @@ from OAI.utils_oai import (
|
|||
create_chat_completion_response,
|
||||
create_chat_completion_stream_chunk,
|
||||
)
|
||||
from templating import get_prompt_from_template
|
||||
from templating import get_all_templates, get_prompt_from_template
|
||||
from utils import get_generator_error, get_sse_packet, load_progress, unwrap
|
||||
from logger import init_logger
|
||||
|
||||
|
|
@ -244,6 +245,14 @@ async def unload_model():
|
|||
MODEL_CONTAINER = None
|
||||
|
||||
|
||||
@app.get("/v1/templates", dependencies=[Depends(check_api_key)])
|
||||
@app.get("/v1/template/list", dependencies=[Depends(check_api_key)])
|
||||
async def get_templates():
|
||||
templates = get_all_templates()
|
||||
template_strings = list(map(lambda template: template.stem, templates))
|
||||
return TemplateList(data=template_strings)
|
||||
|
||||
|
||||
# Lora list endpoint
|
||||
@app.get("/v1/loras", dependencies=[Depends(check_api_key)])
|
||||
@app.get("/v1/lora/list", dependencies=[Depends(check_api_key)])
|
||||
|
|
|
|||
|
|
@ -57,11 +57,18 @@ def _compile_template(template: str):
|
|||
return jinja_template
|
||||
|
||||
|
||||
def get_all_templates():
|
||||
"""Fetches all templates from the templates directory"""
|
||||
|
||||
template_directory = pathlib.Path("templates")
|
||||
return template_directory.glob("*.jinja")
|
||||
|
||||
|
||||
def find_template_from_model(model_path: pathlib.Path):
|
||||
"""Find a matching template name from a model path."""
|
||||
model_name = model_path.name
|
||||
template_directory = pathlib.Path("templates")
|
||||
for filepath in template_directory.glob("*.jinja"):
|
||||
template_files = get_all_templates()
|
||||
for filepath in template_files:
|
||||
template_name = filepath.stem.lower()
|
||||
|
||||
# Check if the template name is present in the model name
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue