Tree: Basic formatting and comments

Signed-off-by: kingbri <bdashore3@proton.me>
This commit is contained in:
kingbri 2023-11-16 11:48:30 -05:00
parent 5defb1b0b4
commit 60eb076b43
2 changed files with 21 additions and 3 deletions

View file

@@ -1,7 +1,7 @@
from uuid import uuid4
from time import time
from pydantic import BaseModel, Field
from typing import List, Optional, Dict, Union
from typing import List, Optional, Union
from OAI.types.common import LogProbs, UsageStats, CommonCompletionRequest
class CompletionRespChoice(BaseModel):

22
main.py
View file

@@ -9,7 +9,12 @@ from sse_starlette import EventSourceResponse
from OAI.types.completion import CompletionRequest
from OAI.types.chat_completion import ChatCompletionRequest
from OAI.types.model import ModelCard, ModelLoadRequest, ModelLoadResponse
from OAI.types.token import TokenEncodeRequest, TokenEncodeResponse, TokenDecodeRequest, TokenDecodeResponse
from OAI.types.token import (
TokenEncodeRequest,
TokenEncodeResponse,
TokenDecodeRequest,
TokenDecodeResponse
)
from OAI.utils import (
create_completion_response,
get_model_list,
@@ -27,6 +32,7 @@ app = FastAPI()
model_container: Optional[ModelContainer] = None
config: dict = {}
# Model list endpoint
@app.get("/v1/models", dependencies=[Depends(check_api_key)])
@app.get("/v1/model/list", dependencies=[Depends(check_api_key)])
async def list_models():
@@ -40,6 +46,7 @@ async def list_models():
return models.json()
# Currently loaded model endpoint
@app.get("/v1/model", dependencies=[Depends(check_api_key)])
async def get_current_model():
if model_container is None or model_container.model is None:
@@ -48,6 +55,7 @@ async def get_current_model():
model_card = ModelCard(id=model_container.get_model_path().name)
return model_card.json()
# Load model endpoint
@app.post("/v1/model/load", dependencies=[Depends(check_admin_key)])
async def load_model(data: ModelLoadRequest):
if model_container and model_container.model:
@@ -89,6 +97,7 @@ async def load_model(data: ModelLoadRequest):
return EventSourceResponse(generator())
# Unload model endpoint
@app.get("/v1/model/unload", dependencies=[Depends(check_admin_key)])
async def unload_model():
global model_container
@@ -99,6 +108,7 @@ async def unload_model():
model_container.unload()
model_container = None
# Encode tokens endpoint
@app.post("/v1/token/encode", dependencies=[Depends(check_api_key)])
async def encode_tokens(data: TokenEncodeRequest):
tokens = model_container.get_tokens(data.text, None, **data.get_params())[0].tolist()
@@ -106,6 +116,7 @@ async def encode_tokens(data: TokenEncodeRequest):
return response.json()
# Decode tokens endpoint
@app.post("/v1/token/decode", dependencies=[Depends(check_api_key)])
async def decode_tokens(data: TokenDecodeRequest):
message = model_container.get_tokens(None, data.tokens, **data.get_params())
@@ -113,6 +124,7 @@ async def decode_tokens(data: TokenDecodeRequest):
return response.json()
# Completions endpoint
@app.post("/v1/completions", dependencies=[Depends(check_api_key)])
async def generate_completion(request: Request, data: CompletionRequest):
model_path = model_container.get_model_path()
@@ -138,6 +150,7 @@ async def generate_completion(request: Request, data: CompletionRequest):
return response.json()
# Chat completions endpoint
@app.post("/v1/chat/completions", dependencies=[Depends(check_api_key)])
async def generate_chat_completion(request: Request, data: ChatCompletionRequest):
model_path = model_container.get_model_path()
@@ -200,4 +213,9 @@ if __name__ == "__main__":
loading_bar.next()
network_config = config.get("network", {})
uvicorn.run(app, host=network_config.get("host", "127.0.0.1"), port=network_config.get("port", 5000), log_level="debug")
uvicorn.run(
app,
host=network_config.get("host", "127.0.0.1"),
port=network_config.get("port", 5000),
log_level="debug"
)