From d47c39da5493eeb54b403a5a3b78db76088f99ec Mon Sep 17 00:00:00 2001 From: kingbri Date: Thu, 23 Nov 2023 00:07:56 -0500 Subject: [PATCH] API: Don't include draft directory in response The draft directory should be returned for a draft model request (TBD). Signed-off-by: kingbri --- OAI/utils.py | 13 ++++++++++--- main.py | 11 ++++++++--- model.py | 2 +- 3 files changed, 19 insertions(+), 7 deletions(-) diff --git a/OAI/utils.py b/OAI/utils.py index 769991b..e97f1b7 100644 --- a/OAI/utils.py +++ b/OAI/utils.py @@ -1,4 +1,4 @@ -import pathlib +import os, pathlib from OAI.types.completion import CompletionResponse, CompletionRespChoice from OAI.types.chat_completion import ( ChatCompletionMessage, @@ -76,10 +76,17 @@ def create_chat_completion_stream_chunk(const_id: str, text: str, model_name: Op return chunk -def get_model_list(model_path: pathlib.Path): +def get_model_list(model_path: pathlib.Path, draft_model_path: Optional[str]): + + # Convert the draft model path to a pathlib path for equality comparisons + if draft_model_path: + draft_model_path = pathlib.Path(draft_model_path).resolve() + model_card_list = ModelList() for path in model_path.iterdir(): - if path.is_dir(): + + # Don't include the draft models path + if path.is_dir() and path != draft_model_path: model_card = ModelCard(id = path.name) model_card_list.data.append(model_card) diff --git a/main.py b/main.py index 5634561..b04d684 100644 --- a/main.py +++ b/main.py @@ -50,13 +50,16 @@ app.add_middleware( @app.get("/v1/models", dependencies=[Depends(check_api_key)]) @app.get("/v1/model/list", dependencies=[Depends(check_api_key)]) async def list_models(): - model_config = config.get("model", {}) + model_config = config.get("model") or {} if "model_dir" in model_config: model_path = pathlib.Path(model_config["model_dir"]) else: model_path = pathlib.Path("models") - models = get_model_list(model_path.resolve()) + draft_config = model_config.get("draft") or {} + draft_model_dir = draft_config.get("draft_model_dir") + + models = get_model_list(model_path.resolve(), draft_model_dir) return models @@ -64,7 +67,9 @@ async def list_models(): @app.get("/v1/model", dependencies=[Depends(check_api_key), Depends(_check_model_container)]) @app.get("/v1/internal/model/info", dependencies=[Depends(check_api_key), Depends(_check_model_container)]) async def get_current_model(): - model_card = ModelCard(id=model_container.get_model_path().name) + model_name = model_container.get_model_path().name + model_card = ModelCard(id = model_name) + return model_card # Load model endpoint diff --git a/model.py b/model.py index 604a314..1ebbdf1 100644 --- a/model.py +++ b/model.py @@ -82,7 +82,7 @@ class ModelContainer: self.config.max_input_len = chunk_size self.config.max_attn_size = chunk_size ** 2 - draft_config = kwargs.get("draft") or {} + draft_config = kwargs.get("draft") or {} draft_model_name = draft_config.get("draft_model_name") enable_draft = bool(draft_config) and draft_model_name is not None