API: Add route for draft model list
Does the same thing as model list except with draft models. Signed-off-by: kingbri <bdashore3@proton.me>
This commit is contained in:
parent
ce2602df9a
commit
c9e43e51aa
2 changed files with 16 additions and 6 deletions
|
|
@ -77,9 +77,9 @@ def create_chat_completion_stream_chunk(const_id: str,
|
|||
|
||||
return chunk
|
||||
|
||||
def get_model_list(model_path: pathlib.Path, draft_model_path: Optional[str]):
|
||||
def get_model_list(model_path: pathlib.Path, draft_model_path: Optional[str] = None):
|
||||
|
||||
# Convert the draft model path to a pathlib path for equality comparisons
|
||||
# Convert the provided draft model path to a pathlib path for equality comparisons
|
||||
if draft_model_path:
|
||||
draft_model_path = pathlib.Path(draft_model_path).resolve()
|
||||
|
||||
|
|
|
|||
18
main.py
18
main.py
|
|
@ -57,10 +57,8 @@ app.add_middleware(
|
|||
@app.get("/v1/model/list", dependencies=[Depends(check_api_key)])
|
||||
async def list_models():
|
||||
model_config = unwrap(config.get("model"), {})
|
||||
if "model_dir" in model_config:
|
||||
model_path = pathlib.Path(model_config["model_dir"])
|
||||
else:
|
||||
model_path = pathlib.Path("models")
|
||||
model_dir = unwrap(model_config.get("model_dir"), "models")
|
||||
model_path = pathlib.Path(model_dir)
|
||||
|
||||
draft_config = unwrap(model_config.get("draft"), {})
|
||||
draft_model_dir = draft_config.get("draft_model_dir")
|
||||
|
|
@ -102,6 +100,18 @@ async def get_current_model():
|
|||
|
||||
return model_card
|
||||
|
||||
@app.get("/v1/model/draft/list")
|
||||
async def list_draft_models():
|
||||
model_config = unwrap(config.get("model"), {})
|
||||
draft_config = unwrap(model_config.get("draft"), {})
|
||||
draft_model_dir = unwrap(draft_config.get("draft_model_dir"), "models")
|
||||
draft_model_path = pathlib.Path(draft_model_dir)
|
||||
|
||||
models = get_model_list(draft_model_path.resolve())
|
||||
print(models)
|
||||
|
||||
return models
|
||||
|
||||
# Load model endpoint
|
||||
@app.post("/v1/model/load", dependencies=[Depends(check_admin_key)])
|
||||
async def load_model(request: Request, data: ModelLoadRequest):
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue