API: Add route for draft model list

Does the same thing as model list except with draft models.

Signed-off-by: kingbri <bdashore3@proton.me>
This commit is contained in:
kingbri 2023-12-19 23:45:53 -05:00
parent ce2602df9a
commit c9e43e51aa
2 changed files with 16 additions and 6 deletions

View file

@ -77,9 +77,9 @@ def create_chat_completion_stream_chunk(const_id: str,
return chunk
def get_model_list(model_path: pathlib.Path, draft_model_path: Optional[str]):
def get_model_list(model_path: pathlib.Path, draft_model_path: Optional[str] = None):
# Convert the draft model path to a pathlib path for equality comparisons
# Convert the provided draft model path to a pathlib path for equality comparisons
if draft_model_path:
draft_model_path = pathlib.Path(draft_model_path).resolve()

18
main.py
View file

@ -57,10 +57,8 @@ app.add_middleware(
@app.get("/v1/model/list", dependencies=[Depends(check_api_key)])
async def list_models():
model_config = unwrap(config.get("model"), {})
if "model_dir" in model_config:
model_path = pathlib.Path(model_config["model_dir"])
else:
model_path = pathlib.Path("models")
model_dir = unwrap(model_config.get("model_dir"), "models")
model_path = pathlib.Path(model_dir)
draft_config = unwrap(model_config.get("draft"), {})
draft_model_dir = draft_config.get("draft_model_dir")
@ -102,6 +100,18 @@ async def get_current_model():
return model_card
@app.get("/v1/model/draft/list")
async def list_draft_models():
model_config = unwrap(config.get("model"), {})
draft_config = unwrap(model_config.get("draft"), {})
draft_model_dir = unwrap(draft_config.get("draft_model_dir"), "models")
draft_model_path = pathlib.Path(draft_model_dir)
models = get_model_list(draft_model_path.resolve())
print(models)
return models
# Load model endpoint
@app.post("/v1/model/load", dependencies=[Depends(check_admin_key)])
async def load_model(request: Request, data: ModelLoadRequest):