144 lines
5.5 KiB
Python
144 lines
5.5 KiB
Python
import uvicorn
|
|
import yaml
|
|
import pathlib
|
|
from auth import check_admin_key, check_api_key, load_auth_keys
|
|
from fastapi import FastAPI, Request, HTTPException, Depends
|
|
from model import ModelContainer
|
|
from progress.bar import IncrementalBar
|
|
from sse_starlette import EventSourceResponse
|
|
from OAI.types.completion import CompletionRequest
|
|
from OAI.types.model import ModelCard, ModelLoadRequest, ModelLoadResponse
|
|
from OAI.types.token import TokenEncodeRequest, TokenEncodeResponse, TokenDecodeRequest, TokenDecodeResponse
|
|
from OAI.utils import create_completion_response, get_model_list
|
|
from typing import Optional
|
|
from utils import load_progress
|
|
|
|
app = FastAPI()
|
|
|
|
# Globally scoped variables. Undefined until initalized in main
|
|
model_container: Optional[ModelContainer] = None
|
|
config: Optional[dict] = None
|
|
|
|
@app.get("/v1/models", dependencies=[Depends(check_api_key)])
|
|
@app.get("/v1/model/list", dependencies=[Depends(check_api_key)])
|
|
async def list_models():
|
|
model_config = config["model"]
|
|
models = get_model_list(pathlib.Path(model_config["model_dir"] or "models"))
|
|
|
|
return models.model_dump_json()
|
|
|
|
@app.get("/v1/model", dependencies=[Depends(check_api_key)])
|
|
async def get_current_model():
|
|
if model_container is None or model_container.model is None:
|
|
return HTTPException(400, "No models are loaded.")
|
|
|
|
model_card = ModelCard(id=model_container.get_model_path().name)
|
|
return model_card.model_dump_json()
|
|
|
|
@app.post("/v1/model/load", dependencies=[Depends(check_admin_key)])
|
|
async def load_model(data: ModelLoadRequest):
|
|
if model_container and model_container.model:
|
|
raise HTTPException(400, "A model is already loaded! Please unload it first.")
|
|
|
|
def generator():
|
|
global model_container
|
|
model_config = config["model"]
|
|
model_path = pathlib.Path(model_config["model_dir"] or "models")
|
|
model_path = model_path / data.name
|
|
|
|
model_container = ModelContainer(model_path, False, **data.model_dump())
|
|
load_status = model_container.load_gen(load_progress)
|
|
for (module, modules) in load_status:
|
|
if module == 0:
|
|
loading_bar: IncrementalBar = IncrementalBar("Modules", max = modules)
|
|
elif module == modules:
|
|
loading_bar.next()
|
|
loading_bar.finish()
|
|
else:
|
|
loading_bar.next()
|
|
|
|
yield ModelLoadResponse(
|
|
module=module,
|
|
modules=modules,
|
|
status="processing"
|
|
).model_dump_json()
|
|
|
|
yield ModelLoadResponse(
|
|
module=module,
|
|
modules=modules,
|
|
status="finished"
|
|
).model_dump_json()
|
|
|
|
return EventSourceResponse(generator())
|
|
|
|
@app.get("/v1/model/unload", dependencies=[Depends(check_admin_key)])
|
|
async def unload_model():
|
|
global model_container
|
|
|
|
if model_container is None:
|
|
raise HTTPException(400, "No models are loaded.")
|
|
|
|
model_container.unload()
|
|
model_container = None
|
|
|
|
@app.post("/v1/token/encode", dependencies=[Depends(check_api_key)])
|
|
async def encode_tokens(data: TokenEncodeRequest):
|
|
tokens = model_container.get_tokens(data.text, None, **data.get_params())[0].tolist()
|
|
response = TokenEncodeResponse(tokens=tokens, length=len(tokens))
|
|
|
|
return response.model_dump_json()
|
|
|
|
@app.post("/v1/token/decode", dependencies=[Depends(check_api_key)])
|
|
async def decode_tokens(data: TokenDecodeRequest):
|
|
message = model_container.get_tokens(None, data.tokens, **data.get_params())
|
|
response = TokenDecodeResponse(text=message)
|
|
|
|
return response.model_dump_json()
|
|
|
|
@app.post("/v1/completions", dependencies=[Depends(check_api_key)])
|
|
async def generate_completion(request: Request, data: CompletionRequest):
|
|
if data.stream:
|
|
async def generator():
|
|
new_generation = model_container.generate_gen(**data.to_gen_params())
|
|
for index, part in enumerate(new_generation):
|
|
if await request.is_disconnected():
|
|
break
|
|
|
|
response = create_completion_response(part, index, model_container.get_model_path().name)
|
|
|
|
yield response.model_dump_json()
|
|
|
|
return EventSourceResponse(generator())
|
|
else:
|
|
response_text = model_container.generate(**data.to_gen_params())
|
|
response = create_completion_response(response_text, 0, model_container.get_model_path().name)
|
|
|
|
return response.model_dump_json()
|
|
|
|
if __name__ == "__main__":
|
|
# Initialize auth keys
|
|
load_auth_keys()
|
|
|
|
# Load from YAML config. Possibly add a config -> kwargs conversion function
|
|
with open('config.yml', 'r') as config_file:
|
|
config = yaml.safe_load(config_file)
|
|
|
|
# If an initial model name is specified, create a container and load the model
|
|
model_config = config["model"]
|
|
if model_config["model_name"]:
|
|
model_path = pathlib.Path(model_config["model_dir"] or "models")
|
|
model_path = model_path / model_config["model_name"]
|
|
|
|
model_container = ModelContainer(model_path, False, **model_config)
|
|
load_status = model_container.load_gen(load_progress)
|
|
for (module, modules) in load_status:
|
|
if module == 0:
|
|
loading_bar: IncrementalBar = IncrementalBar("Modules", max = modules)
|
|
elif module == modules:
|
|
loading_bar.next()
|
|
loading_bar.finish()
|
|
else:
|
|
loading_bar.next()
|
|
|
|
network_config = config["network"]
|
|
uvicorn.run(app, host=network_config["host"] or "127.0.0.1", port=network_config["port"] or 8012, log_level="debug")
|