# tabbyAPI-ollama/endpoints/OAI/router.py
import asyncio
import hashlib
import json
import os
from datetime import datetime
from sys import maxsize

from fastapi import APIRouter, Depends, HTTPException, Request, Response
from fastapi.responses import StreamingResponse
from sse_starlette import EventSourceResponse

from common import model
from common.auth import check_api_key
from common.model import check_embeddings_container, check_model_container
from common.networking import handle_request_error, run_with_request_disconnect
from common.tabby_config import config
from endpoints.OAI.types.chat_completion import (
    ChatCompletionRequest,
    ChatCompletionResponse,
)
from endpoints.OAI.types.common import ModelItem, ModelListResponse
from endpoints.OAI.types.completion import CompletionRequest, CompletionResponse
from endpoints.OAI.types.embedding import EmbeddingsRequest, EmbeddingsResponse
from endpoints.OAI.utils.chat_completion import (
    apply_chat_template,
    generate_chat_completion,
    stream_generate_chat_completion,
    stream_generate_chat_completion_ollama,
)
from endpoints.OAI.utils.completion import (
    generate_completion,
    load_inline_model,
    stream_generate_completion,
)
from endpoints.OAI.utils.embeddings import get_embeddings

api_name = "OAI"

router = APIRouter()
urls = {
    "Completions": "http://{host}:{port}/v1/completions",
    "Chat completions": "http://{host}:{port}/v1/chat/completions",
}


def setup():
    return router


# Completions endpoint
@router.post(
    "/v1/completions",
    dependencies=[Depends(check_api_key)],
)
async def completion_request(
    request: Request, data: CompletionRequest
) -> CompletionResponse:
    """
    Generates a completion from a prompt.

    If stream = true, this returns an SSE stream.
    """

    if data.model:
        inline_load_task = asyncio.create_task(load_inline_model(data.model, request))

        await run_with_request_disconnect(
            request,
            inline_load_task,
            disconnect_message=f"Model switch for generation {request.state.id} "
            + "cancelled by user.",
        )
    else:
        await check_model_container()

    model_path = model.container.model_dir

    if isinstance(data.prompt, list):
        data.prompt = "\n".join(data.prompt)

    disable_request_streaming = config.developer.disable_request_streaming

    # Set an empty JSON schema if the request wants a JSON response
    if data.response_format.type == "json":
        data.json_schema = {"type": "object"}

    if data.stream and not disable_request_streaming:
        return EventSourceResponse(
            stream_generate_completion(data, request, model_path),
            ping=maxsize,
        )
    else:
        generate_task = asyncio.create_task(
            generate_completion(data, request, model_path)
        )

        response = await run_with_request_disconnect(
            request,
            generate_task,
            disconnect_message=f"Completion {request.state.id} cancelled by user.",
        )

        return response
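
# Example (illustrative, not part of the router): a minimal non-streaming call to
# this endpoint with httpx, assuming the standard OpenAI-style response shape.
# The host, port, and API key are placeholders; adjust them to your tabbyAPI
# configuration.
#
#   import httpx
#
#   resp = httpx.post(
#       "http://127.0.0.1:5000/v1/completions",
#       headers={"Authorization": "Bearer <api-key>"},
#       json={"prompt": "Once upon a time", "max_tokens": 32, "stream": False},
#   )
#   print(resp.json()["choices"][0]["text"])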


# Chat completions endpoint
@router.post(
    "/v1/chat/completions",
    dependencies=[Depends(check_api_key)],
)
async def chat_completion_request(
    request: Request, data: ChatCompletionRequest
) -> ChatCompletionResponse:
    """
    Generates a chat completion from a prompt.

    If stream = true, this returns an SSE stream.
    """

    if data.model:
        await load_inline_model(data.model, request)
    else:
        await check_model_container()

    if model.container.prompt_template is None:
        error_message = handle_request_error(
            "Chat completions are disabled because a prompt template is not set.",
            exc_info=False,
        ).error.message

        raise HTTPException(422, error_message)

    model_path = model.container.model_dir

    prompt, embeddings = await apply_chat_template(data)

    # Set an empty JSON schema if the request wants a JSON response
    if data.response_format.type == "json":
        data.json_schema = {"type": "object"}

    disable_request_streaming = config.developer.disable_request_streaming

    if data.stream and not disable_request_streaming:
        return EventSourceResponse(
            stream_generate_chat_completion(
                prompt, embeddings, data, request, model_path
            ),
            ping=maxsize,
        )
    else:
        generate_task = asyncio.create_task(
            generate_chat_completion(prompt, embeddings, data, request, model_path)
        )

        response = await run_with_request_disconnect(
            request,
            generate_task,
            disconnect_message=f"Chat completion {request.state.id} cancelled by user.",
        )

        return response
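
# Example (illustrative, not part of the router): streaming from this endpoint.
# With stream=true the response is SSE, so the OpenAI client is the simplest
# consumer. Host, port, API key, and model name are placeholders.
#
#   from openai import OpenAI
#
#   client = OpenAI(base_url="http://127.0.0.1:5000/v1", api_key="<api-key>")
#   stream = client.chat.completions.create(
#       model="<loaded-model>",
#       messages=[{"role": "user", "content": "Hello!"}],
#       stream=True,
#   )
#   for chunk in stream:
#       print(chunk.choices[0].delta.content or "", end="")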


# Embeddings endpoint
@router.post(
    "/v1/embeddings",
    dependencies=[Depends(check_api_key), Depends(check_embeddings_container)],
)
async def embeddings(request: Request, data: EmbeddingsRequest) -> EmbeddingsResponse:
    embeddings_task = asyncio.create_task(get_embeddings(data, request))

    response = await run_with_request_disconnect(
        request,
        embeddings_task,
        f"Embeddings request {request.state.id} cancelled by user.",
    )

    return response
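
# Example (illustrative, not part of the router): requesting embeddings. The
# payload follows the OpenAI embeddings format; host, port, model name, and
# credentials are placeholders.
#
#   import httpx
#
#   resp = httpx.post(
#       "http://127.0.0.1:5000/v1/embeddings",
#       headers={"Authorization": "Bearer <api-key>"},
#       json={"input": ["hello world"], "model": "<embedding-model>"},
#   )
#   print(len(resp.json()["data"][0]["embedding"]))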


async def fetch_models():
    """Builds an Ollama-style model list from the local models directory."""

    models_dir = "models"
    models = []

    if os.path.exists(models_dir):
        for model_name in os.listdir(models_dir):
            model_path = os.path.join(models_dir, model_name)

            if os.path.isdir(model_path):
                # Ollama clients expect a digest; derive a stable one
                # from the folder name
                digest = hashlib.md5(model_name.encode()).hexdigest()
                models.append(
                    {
                        "model": f"{model_name}:latest",
                        "name": f"{model_name}:latest",
                        "digest": digest,
                        "urls": [0],
                    }
                )
    else:
        print(f"Models directory {models_dir} does not exist.")

    return ModelListResponse(models=models)


@router.get(
    "/ollama/api/version",
    dependencies=[Depends(check_api_key)],
)
async def dummy2(request: Request):
    """Dummy version endpoint for Ollama client compatibility."""

    return {"version": "1.0"}


@router.get(
    "/api/version",
    dependencies=[Depends(check_api_key)],
)
async def dummy(request: Request):
    """Dummy version endpoint for Ollama client compatibility."""

    return {"version": "1.0"}


# Models endpoint
@router.get(
    "/api/tags",
    dependencies=[Depends(check_api_key)],
)
async def get_all_models(request: Request) -> ModelListResponse:
    response = await run_with_request_disconnect(
        request,
        asyncio.create_task(fetch_models()),
        disconnect_message="Model list request cancelled by user.",
    )

    return response
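
# Example (illustrative): roughly the JSON shape /api/tags returns for a models/
# directory containing a single folder named "mistral-7b". The digest is just an
# MD5 of the folder name, as computed in fetch_models above; the exact fields
# depend on how ModelListResponse serializes.
#
#   {
#       "models": [
#           {
#               "model": "mistral-7b:latest",
#               "name": "mistral-7b:latest",
#               "digest": "<md5-of-folder-name>",
#               "urls": [0]
#           }
#       ]
#   }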


@router.post(
    "/api/chat",
    dependencies=[Depends(check_api_key)],
)
async def chat_completion_request_ollama(
    request: Request, data: ChatCompletionRequest
):
    """
    Generates a chat completion from a prompt for Ollama clients.

    If stream = true, this returns a newline-delimited JSON (NDJSON) stream.
    """

    if data.model:
        await load_inline_model(data.model, request)
    else:
        await check_model_container()

    if model.container.prompt_template is None:
        error_message = handle_request_error(
            "Chat completions are disabled because a prompt template is not set.",
            exc_info=False,
        ).error.message

        raise HTTPException(422, error_message)

    model_path = model.container.model_dir

    if isinstance(data.messages, str):
        # Raw string prompts bypass the chat template and carry no
        # multimodal embeddings
        prompt = data.messages
        embeddings = None
    else:
        # Reuse the same template helper as the OpenAI chat endpoint above
        prompt, embeddings = await apply_chat_template(data)

    # Set an empty JSON schema if the request wants a JSON response
    if data.response_format.type == "json":
        data.json_schema = {"type": "object"}

    disable_request_streaming = config.developer.disable_request_streaming

    async def stream_response(request: Request):
        async for chunk in stream_generate_chat_completion_ollama(
            prompt, data, request, model_path
        ):
            yield json.dumps(chunk).encode("utf-8") + b"\n"

    if data.stream and not disable_request_streaming:
        return StreamingResponse(
            stream_response(request), media_type="application/x-ndjson"
        )
    else:
        generate_task = asyncio.create_task(
            generate_chat_completion(prompt, embeddings, data, request, model_path)
        )

        response = await run_with_request_disconnect(
            request,
            generate_task,
            disconnect_message=f"Chat completion {request.state.id} cancelled by user.",
        )

        return response
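
# Example (illustrative, not part of the router): consuming the NDJSON stream
# from /api/chat line by line with httpx. Host, port, API key, and model name
# are placeholders; each line is a standalone JSON chunk.
#
#   import httpx, json
#
#   with httpx.stream(
#       "POST",
#       "http://127.0.0.1:5000/api/chat",
#       headers={"Authorization": "Bearer <api-key>"},
#       json={
#           "model": "<loaded-model>",
#           "messages": [{"role": "user", "content": "Hello!"}],
#           "stream": True,
#       },
#       timeout=None,
#   ) as resp:
#       for line in resp.iter_lines():
#           if line:
#               print(json.loads(line))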