tabbyAPI-ollama/main.py
kingbri eee8b642bd OAI: Implement completion API endpoint
Add support for /v1/completions with the option to use streaming
if needed. Also rewrite API endpoints to use async when possible
since that improves request performance.

Model container parameter names also needed rewrites, and
fallback cases were set to their disabled values.

Signed-off-by: kingbri <bdashore3@proton.me>
2023-11-13 18:31:26 -05:00

62 lines
2.3 KiB
Python

import uvicorn
import yaml
from fastapi import FastAPI, Request
from model import ModelContainer
from progress.bar import IncrementalBar
from sse_starlette import EventSourceResponse
from OAI.models.completions import CompletionRequest, CompletionResponse, CompletionRespChoice
from OAI.utils import create_completion_response
app = FastAPI()
# Initialize a model container. This can be undefined at any period of time
model_container: ModelContainer = None
@app.post("/v1/completions")
async def generate_completion(request: Request, data: CompletionRequest):
if data.stream:
async def generator():
new_generation = model_container.generate_gen(**data.to_gen_params())
for index, part in enumerate(new_generation):
if await request.is_disconnected():
break
response = create_completion_response(part, index, model_container.get_model_name())
yield response.model_dump_json()
return EventSourceResponse(generator())
else:
response_text = model_container.generate(**data.to_gen_params())
response = create_completion_response(response_text, 0, model_container.get_model_name())
return response.model_dump_json()
# Wrapper callback for load progress
def load_progress(module, modules):
yield module, modules
if __name__ == "__main__":
# Load from YAML config. Possibly add a config -> kwargs conversion function
with open('config.yml', 'r') as config_file:
config = yaml.safe_load(config_file)
# If an initial model name is specified, create a container and load the model
if config["model_name"]:
model_path = f"{config['model_dir']}/{config['model_name']}" if config['model_dir'] else f"models/{config['model_name']}"
model_container = ModelContainer(model_path, False, **config)
load_status = model_container.load_gen(load_progress)
for (module, modules) in load_status:
if module == 0:
loading_bar: IncrementalBar = IncrementalBar("Modules", max = modules)
else:
loading_bar.next()
if module == modules:
loading_bar.finish()
print("Model successfully loaded.")
uvicorn.run(app, host="0.0.0.0", port=8012, log_level="debug")