API: Don't include draft directory in response
The draft directory should be returned for a draft model request (TBD).

Signed-off-by: kingbri <bdashore3@proton.me>
This commit is contained in:
parent
13c9c09398
commit
d47c39da54
3 changed files with 19 additions and 7 deletions
13
OAI/utils.py
13
OAI/utils.py
|
|
@ -1,4 +1,4 @@
|
|||
import pathlib
|
||||
import os, pathlib
|
||||
from OAI.types.completion import CompletionResponse, CompletionRespChoice
|
||||
from OAI.types.chat_completion import (
|
||||
ChatCompletionMessage,
|
||||
|
|
@ -76,10 +76,17 @@ def create_chat_completion_stream_chunk(const_id: str, text: str, model_name: Op
|
|||
|
||||
return chunk
|
||||
|
||||
def get_model_list(model_path: pathlib.Path):
|
||||
def get_model_list(model_path: pathlib.Path, draft_model_path: Optional[str]):
|
||||
|
||||
# Convert the draft model path to a pathlib path for equality comparisons
|
||||
if draft_model_path:
|
||||
draft_model_path = pathlib.Path(draft_model_path).resolve()
|
||||
|
||||
model_card_list = ModelList()
|
||||
for path in model_path.iterdir():
|
||||
if path.is_dir():
|
||||
|
||||
# Don't include the draft models path
|
||||
if path.is_dir() and path != draft_model_path:
|
||||
model_card = ModelCard(id = path.name)
|
||||
model_card_list.data.append(model_card)
|
||||
|
||||
|
|
|
|||
11
main.py
11
main.py
|
|
@ -50,13 +50,16 @@ app.add_middleware(
|
|||
@app.get("/v1/models", dependencies=[Depends(check_api_key)])
|
||||
@app.get("/v1/model/list", dependencies=[Depends(check_api_key)])
|
||||
async def list_models():
|
||||
model_config = config.get("model", {})
|
||||
model_config = config.get("model") or {}
|
||||
if "model_dir" in model_config:
|
||||
model_path = pathlib.Path(model_config["model_dir"])
|
||||
else:
|
||||
model_path = pathlib.Path("models")
|
||||
|
||||
models = get_model_list(model_path.resolve())
|
||||
draft_config = model_config.get("draft") or {}
|
||||
draft_model_dir = draft_config.get("draft_model_dir")
|
||||
|
||||
models = get_model_list(model_path.resolve(), draft_model_dir)
|
||||
|
||||
return models
|
||||
|
||||
|
|
@ -64,7 +67,9 @@ async def list_models():
|
|||
@app.get("/v1/model", dependencies=[Depends(check_api_key), Depends(_check_model_container)])
|
||||
@app.get("/v1/internal/model/info", dependencies=[Depends(check_api_key), Depends(_check_model_container)])
|
||||
async def get_current_model():
|
||||
model_card = ModelCard(id=model_container.get_model_path().name)
|
||||
model_name = model_container.get_model_path().name
|
||||
model_card = ModelCard(id = model_name)
|
||||
|
||||
return model_card
|
||||
|
||||
# Load model endpoint
|
||||
|
|
|
|||
2
model.py
2
model.py
|
|
@ -82,7 +82,7 @@ class ModelContainer:
|
|||
self.config.max_input_len = chunk_size
|
||||
self.config.max_attn_size = chunk_size ** 2
|
||||
|
||||
draft_config = kwargs.get("draft") or {}
|
||||
draft_config = kwargs.get("draft") or {}
|
||||
draft_model_name = draft_config.get("draft_model_name")
|
||||
enable_draft = bool(draft_config) and draft_model_name is not None
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue