OAI: Add models support

Upstream, the models endpoint fetches all the models that OAI has
to offer. Since this is an OAI clone, list the models inside the
user's configured model directory instead.

Signed-off-by: kingbri <bdashore3@proton.me>
kingbri 2023-11-13 21:38:34 -05:00
parent eee8b642bd
commit 47343e2f1a
4 changed files with 47 additions and 12 deletions
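
In other words, the clone's catalog is simply the subdirectories of the user's model folder. A minimal sketch of that idea (the folder and model names are hypothetical):

    # the "available models" are just the subdirectories of the model folder
    # (assumes a models/ folder exists next to the script)
    import pathlib
    print([p.name for p in pathlib.Path("models").iterdir() if p.is_dir()])
    # e.g. ['Llama2-13B-exl2', 'Mistral-7B-exl2']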

13  OAI/models/models.py (new file)

@@ -0,0 +1,13 @@
+from pydantic import BaseModel, Field
+from time import time
+from typing import List
+
+class ModelCard(BaseModel):
+    id: str
+    object: str = "model"
+    created: int = Field(default_factory=lambda: int(time()))
+    owned_by: str = "tabbyAPI"
+
+class ModelList(BaseModel):
+    object: str = "list"
+    data: List[ModelCard] = Field(default_factory=list)
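
As a quick serialization check (a sketch, not part of the commit; the model id is hypothetical), these two pydantic models already produce the OAI-compatible wire shape:

    from OAI.models.models import ModelCard, ModelList

    models = ModelList(data=[ModelCard(id="Llama2-13B-exl2")])
    print(models.model_dump_json())
    # {"object":"list","data":[{"id":"Llama2-13B-exl2","object":"model",
    #  "created":<unix time at construction>,"owned_by":"tabbyAPI"}]}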

OAI/utils.py

@@ -1,5 +1,7 @@
+import pathlib
 from OAI.models.completions import CompletionResponse, CompletionRespChoice
 from OAI.models.common import UsageStats
+from OAI.models.models import ModelList, ModelCard
 from typing import Optional
 
 def create_completion_response(text: str, index: int, model_name: Optional[str]):
@@ -17,3 +19,12 @@ def create_completion_response(text: str, index: int, model_name: Optional[str]):
     )
 
     return response
+
+def get_model_list(model_path: pathlib.Path):
+    model_card_list = ModelList()
+    for path in model_path.parent.iterdir():
+        if path.is_dir():
+            model_card = ModelCard(id = path.name)
+            model_card_list.data.append(model_card)
+
+    return model_card_list
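
One subtlety worth a sketch (standalone and hypothetical, assuming OAI.utils is importable): the function receives the currently loaded model's path and iterates model_path.parent, so every sibling directory, including the loaded model itself, becomes a card:

    import pathlib, tempfile
    from OAI.utils import get_model_list

    root = pathlib.Path(tempfile.mkdtemp())            # stand-in for the model folder
    for name in ("Llama2-13B-exl2", "Mistral-7B-exl2"):
        (root / name).mkdir()

    cards = get_model_list(root / "Llama2-13B-exl2")   # pass the loaded model's path
    print(sorted(card.id for card in cards.data))
    # ['Llama2-13B-exl2', 'Mistral-7B-exl2']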

22  main.py

@@ -4,15 +4,27 @@ from fastapi import FastAPI, Request
 from model import ModelContainer
 from progress.bar import IncrementalBar
 from sse_starlette import EventSourceResponse
-from OAI.models.completions import CompletionRequest, CompletionResponse, CompletionRespChoice
-from OAI.utils import create_completion_response
+from OAI.models.completions import CompletionRequest, CompletionResponse
+from OAI.models.models import ModelCard, ModelList
+from OAI.utils import create_completion_response, get_model_list
 
 app = FastAPI()
 
 # Initialize a model container. This can be undefined at any period of time
 model_container: ModelContainer = None
 
-@app.post("/v1/completions")
+@app.get("/v1/models")
+@app.get("/v1/model/list")
+async def list_models():
+    models = get_model_list(model_container.get_model_path())
+
+    return models.model_dump_json()
+
+@app.get("/v1/model")
+async def get_current_model():
+    return ModelCard(id = model_container.get_model_path().name)
+
+@app.post("/v1/completions", response_class=CompletionResponse)
 async def generate_completion(request: Request, data: CompletionRequest):
if data.stream:
async def generator():
@@ -21,14 +33,14 @@ async def generate_completion(request: Request, data: CompletionRequest):
                 if await request.is_disconnected():
                     break
 
-                response = create_completion_response(part, index, model_container.get_model_name())
+                response = create_completion_response(part, index, model_container.get_model_path().name)
 
                 yield response.model_dump_json()
 
         return EventSourceResponse(generator())
     else:
         response_text = model_container.generate(**data.to_gen_params())
-        response = create_completion_response(response_text, 0, model_container.get_model_name())
+        response = create_completion_response(response_text, 0, model_container.get_model_path().name)
 
         return response.model_dump_json()
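
A smoke test for the new routes might look like this (the bind address is an assumption, not something this commit sets; note also that list_models returns model_dump_json(), a str, so the /v1/models body arrives as a JSON-encoded string rather than a bare object):

    import requests

    base = "http://localhost:8000"                     # assumed uvicorn host/port
    print(requests.get(f"{base}/v1/models").json())    # every model in the model folder
    print(requests.get(f"{base}/v1/model").json())     # the currently loaded model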

model.py

@@ -1,4 +1,4 @@
-import gc, time
+import gc, time, pathlib
 import torch
 from exllamav2 import(
     ExLlamaV2,
@@ -11,7 +11,6 @@ from exllamav2.generator import(
     ExLlamaV2StreamingGenerator,
     ExLlamaV2Sampler
 )
-from os import path
 from typing import Optional
 
 # Bytes to reserve on first device when loading with auto split
@@ -102,11 +101,11 @@ class ModelContainer:
             self.draft_config.max_input_len = kwargs["chunk_size"]
             self.draft_config.max_attn_size = kwargs["chunk_size"] ** 2
 
-    def get_model_name(self):
-        if self.draft_enabled:
-            return path.basename(path.normpath(self.draft_config.model_dir))
-        else:
-            return path.basename(path.normpath(self.config.model_dir))
+    def get_model_path(self):
+        model_path = pathlib.Path(self.draft_config.model_dir if self.draft_enabled else self.config.model_dir)
+
+        return model_path
 
     def load(self, progress_callback = None):
         """