OAI: Add models support
The models endpoint fetches all the models that OAI has to offer. However, since this is an OAI clone, just list the models inside the user's configured model directory instead. Signed-off-by: kingbri <bdashore3@proton.me>
This commit is contained in:
parent
eee8b642bd
commit
47343e2f1a
4 changed files with 47 additions and 12 deletions
22
main.py
22
main.py
|
|
@ -4,15 +4,27 @@ from fastapi import FastAPI, Request
|
|||
from model import ModelContainer
|
||||
from progress.bar import IncrementalBar
|
||||
from sse_starlette import EventSourceResponse
|
||||
from OAI.models.completions import CompletionRequest, CompletionResponse, CompletionRespChoice
|
||||
from OAI.utils import create_completion_response
|
||||
from OAI.models.completions import CompletionRequest, CompletionResponse
|
||||
from OAI.models.models import ModelCard, ModelList
|
||||
from OAI.utils import create_completion_response, get_model_list
|
||||
|
||||
app = FastAPI()
|
||||
|
||||
# Initialize a model container. This can be undefined at any period of time
|
||||
model_container: ModelContainer = None
|
||||
|
||||
@app.post("/v1/completions")
|
||||
@app.get("/v1/models")
|
||||
@app.get("/v1/model/list")
|
||||
async def list_models():
|
||||
models = get_model_list(model_container.get_model_path())
|
||||
|
||||
return models.model_dump_json()
|
||||
|
||||
@app.get("/v1/model")
|
||||
async def get_current_model():
|
||||
return ModelCard(id = model_container.get_model_path().name)
|
||||
|
||||
@app.post("/v1/completions", response_class=CompletionResponse)
|
||||
async def generate_completion(request: Request, data: CompletionRequest):
|
||||
if data.stream:
|
||||
async def generator():
|
||||
|
|
@ -21,14 +33,14 @@ async def generate_completion(request: Request, data: CompletionRequest):
|
|||
if await request.is_disconnected():
|
||||
break
|
||||
|
||||
response = create_completion_response(part, index, model_container.get_model_name())
|
||||
response = create_completion_response(part, index, model_container.get_model_path().name)
|
||||
|
||||
yield response.model_dump_json()
|
||||
|
||||
return EventSourceResponse(generator())
|
||||
else:
|
||||
response_text = model_container.generate(**data.to_gen_params())
|
||||
response = create_completion_response(response_text, 0, model_container.get_model_name())
|
||||
response = create_completion_response(response_text, 0, model_container.get_model_path().name)
|
||||
|
||||
return response.model_dump_json()
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue