OAI: Add models support

The models endpoint fetches all the models that OAI has to offer. However, since this is an OAI clone, just list the models inside the user's configured model directory instead.

Signed-off-by: kingbri <bdashore3@proton.me>

parent: eee8b642bd
commit: 47343e2f1a

4 changed files with 47 additions and 12 deletions
OAI/models/models.py (new file, 13 changed lines)

@@ -0,0 +1,13 @@
+from pydantic import BaseModel, Field
+from time import time
+from typing import List
+
+class ModelCard(BaseModel):
+    id: str
+    object: str = "model"
+    created: int = Field(default_factory=lambda: int(time()))
+    owned_by: str = "tabbyAPI"
+
+class ModelList(BaseModel):
+    object: str = "list"
+    data: List[ModelCard] = Field(default_factory=list)
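As a side note (not part of the diff), these two pydantic models mirror the shape of OpenAI's /v1/models response. A minimal sketch of the JSON they emit, assuming pydantic v2's model_dump_json as used elsewhere in this commit; the model id is made up:

    from OAI.models.models import ModelCard, ModelList

    # "MyModel-GPTQ" is a hypothetical model directory name.
    models = ModelList(data=[ModelCard(id="MyModel-GPTQ")])
    print(models.model_dump_json())
    # {"object":"list","data":[{"id":"MyModel-GPTQ","object":"model",
    #  "created":<current unix time>,"owned_by":"tabbyAPI"}]}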
OAI/utils.py (11 changed lines)

@@ -1,5 +1,7 @@
+import pathlib
 from OAI.models.completions import CompletionResponse, CompletionRespChoice
 from OAI.models.common import UsageStats
+from OAI.models.models import ModelList, ModelCard
 from typing import Optional

 def create_completion_response(text: str, index: int, model_name: Optional[str]):
@@ -17,3 +19,12 @@ def create_completion_response(text: str, index: int, model_name: Optional[str])
     )

     return response
+
+def get_model_list(model_path: pathlib.Path):
+    model_card_list = ModelList()
+    for path in model_path.parent.iterdir():
+        if path.is_dir():
+            model_card = ModelCard(id = path.name)
+            model_card_list.data.append(model_card)
+
+    return model_card_list
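Note that get_model_list iterates over the parent of the path it is given, so passing the loaded model's directory enumerates that model's sibling directories. A rough usage sketch under that assumption, with throwaway directory names:

    import pathlib
    import tempfile

    from OAI.utils import get_model_list

    with tempfile.TemporaryDirectory() as root:
        # Stand-in model folders; only directories become ModelCards.
        pathlib.Path(root, "Llama-13B-GPTQ").mkdir()
        pathlib.Path(root, "Mistral-7B-exl2").mkdir()

        loaded = pathlib.Path(root, "Llama-13B-GPTQ")
        print(get_model_list(loaded).model_dump_json())  # lists both folders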
main.py (22 changed lines)

@@ -4,15 +4,27 @@ from fastapi import FastAPI, Request
 from model import ModelContainer
 from progress.bar import IncrementalBar
 from sse_starlette import EventSourceResponse
-from OAI.models.completions import CompletionRequest, CompletionResponse, CompletionRespChoice
-from OAI.utils import create_completion_response
+from OAI.models.completions import CompletionRequest, CompletionResponse
+from OAI.models.models import ModelCard, ModelList
+from OAI.utils import create_completion_response, get_model_list

 app = FastAPI()

 # Initialize a model container. This can be undefined at any period of time
 model_container: ModelContainer = None

-@app.post("/v1/completions")
+@app.get("/v1/models")
+@app.get("/v1/model/list")
+async def list_models():
+    models = get_model_list(model_container.get_model_path())
+
+    return models.model_dump_json()
+
+@app.get("/v1/model")
+async def get_current_model():
+    return ModelCard(id = model_container.get_model_path().name)
+
+@app.post("/v1/completions", response_class=CompletionResponse)
 async def generate_completion(request: Request, data: CompletionRequest):
     if data.stream:
         async def generator():
@@ -21,14 +33,14 @@ async def generate_completion(request: Request, data: CompletionRequest):
                 if await request.is_disconnected():
                     break

-                response = create_completion_response(part, index, model_container.get_model_name())
+                response = create_completion_response(part, index, model_container.get_model_path().name)

                 yield response.model_dump_json()

         return EventSourceResponse(generator())
     else:
         response_text = model_container.generate(**data.to_gen_params())
-        response = create_completion_response(response_text, 0, model_container.get_model_name())
+        response = create_completion_response(response_text, 0, model_container.get_model_path().name)

         return response.model_dump_json()
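With the server running, the new routes can be exercised like any OpenAI-style models endpoint. A hedged sketch using requests; the host and port are assumptions, and since list_models returns a model_dump_json() string the raw body is read with .text:

    import requests

    BASE = "http://127.0.0.1:8000"  # assumed local address for the API

    print(requests.get(f"{BASE}/v1/models").text)  # directories beside the loaded model
    print(requests.get(f"{BASE}/v1/model").text)   # card for the currently loaded model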
model.py (13 changed lines)

@@ -1,4 +1,4 @@
-import gc, time
+import gc, time, pathlib
 import torch
 from exllamav2 import(
     ExLlamaV2,
@@ -11,7 +11,6 @@ from exllamav2.generator import(
     ExLlamaV2StreamingGenerator,
     ExLlamaV2Sampler
 )
-from os import path
 from typing import Optional

 # Bytes to reserve on first device when loading with auto split
@@ -102,11 +101,11 @@ class ModelContainer:
             self.draft_config.max_input_len = kwargs["chunk_size"]
             self.draft_config.max_attn_size = kwargs["chunk_size"] ** 2

-    def get_model_name(self):
-        if self.draft_enabled:
-            return path.basename(path.normpath(self.draft_config.model_dir))
-        else:
-            return path.basename(path.normpath(self.config.model_dir))
+    def get_model_path(self):
+        model_path = pathlib.Path(self.draft_config.model_dir if self.draft_enabled else self.config.model_dir)
+
+        return model_path
+

     def load(self, progress_callback = None):
         """
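The pathlib switch makes get_model_name redundant: callers now read .name from the returned path, which matches the old basename(normpath(...)) result even when model_dir carries a trailing slash. A quick equivalence check with a made-up model_dir:

    import pathlib
    from os import path

    model_dir = "/home/user/models/MyModel-GPTQ/"  # hypothetical config value
    assert path.basename(path.normpath(model_dir)) == pathlib.Path(model_dir).name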