enabled api routes to work with open-webui
This commit is contained in:
parent 1d3a308709
commit 1fd4a9e119
6 changed files with 333 additions and 1 deletion
.webui_secret_key (new file, 1 addition)
@@ -0,0 +1 @@
+XkPmjle3dN2r0iZ3
@@ -1,4 +1,8 @@
 import asyncio
+import json
+from fastapi import Response
+from datetime import datetime
+from fastapi.responses import StreamingResponse
 from fastapi import APIRouter, Depends, HTTPException, Request
 from sse_starlette import EventSourceResponse
 from sys import maxsize
@@ -13,11 +17,13 @@ from endpoints.OAI.types.chat_completion import (
     ChatCompletionRequest,
     ChatCompletionResponse,
 )
+import os
 from endpoints.OAI.types.embedding import EmbeddingsRequest, EmbeddingsResponse
 from endpoints.OAI.utils.chat_completion import (
     apply_chat_template,
     generate_chat_completion,
     stream_generate_chat_completion,
+    stream_generate_chat_completion_ollama,
 )
 from endpoints.OAI.utils.completion import (
     generate_completion,
@@ -165,3 +171,129 @@ async def embeddings(request: Request, data: EmbeddingsRequest) -> EmbeddingsRes
     )

     return response
+
+
+from pydantic import BaseModel, Field
+from typing import List, Optional
+import hashlib
+
+
+class ModelItem(BaseModel):
+    model: str
+    name: str
+    digest: str
+    urls: List[int]
+
+
+class ModelListResponse(BaseModel):
+    object: str = Field("list", description="Type of the response object.")
+    models: List[ModelItem]
+
+
+async def fetch_models():
+    models_dir = "models"
+    models = []
+    # Iterate over the files in the models directory
+    if os.path.exists(models_dir):
+        for model in os.listdir(models_dir):
+            model_path = os.path.join(models_dir, model)
+            if os.path.isdir(model_path):  # Assuming each model is in its own directory
+                digest = hashlib.md5(model.encode()).hexdigest()
+                models.append({
+                    "model": f"{model}:latest",
+                    "name": f"{model}:latest",
+                    "digest": digest,
+                    "urls": [0]
+                })
+    else:
+        print(f"Models directory {models_dir} does not exist.")
+    return ModelListResponse(models=models)
+
+
+@router.get(
+    "/ollama/api/version",
+    dependencies=[Depends(check_api_key)]
+)
+async def dummy2(request: Request):
+    return {"version": "1.0"}
+
+
+@router.get(
+    "/api/version",
+    dependencies=[Depends(check_api_key)]
+)
+async def dummy(request: Request):
+    return {"version": "1.0"}
+
+
+# Models endpoint
+@router.get(
+    "/api/tags",
+    dependencies=[Depends(check_api_key)]
+)
+async def get_all_models(request: Request) -> ModelListResponse:
+    print(f"Processing request for models {request.state.id}")
+
+    response = await run_with_request_disconnect(
+        request,
+        asyncio.create_task(fetch_models()),
+        disconnect_message="All models fetched",
+    )
+
+    return response
+
+
+@router.post(
+    "/api/chat",
+    dependencies=[Depends(check_api_key)],
+)
+async def chat_completion_request_ollama(
+    request: Request, data: ChatCompletionRequest
+):
+    """
+    Generates a chat completion from a prompt.
+
+    If stream = true, this returns an SSE stream.
+    """
+
+    if data.model:
+        await load_inline_model(data.model, request)
+    else:
+        await check_model_container()
+
+    if model.container.prompt_template is None:
+        error_message = handle_request_error(
+            "Chat completions are disabled because a prompt template is not set.",
+            exc_info=False,
+        ).error.message
+
+        raise HTTPException(422, error_message)
+
+    model_path = model.container.model_dir
+
+    if isinstance(data.messages, str):
+        prompt = data.messages
+    else:
+        prompt = await format_prompt_with_template(data)
+
+    # Set an empty JSON schema if the request wants a JSON response
+    if data.response_format.type == "json":
+        data.json_schema = {"type": "object"}
+
+    disable_request_streaming = config.developer.disable_request_streaming
+
+    async def stream_response(request: Request):
+        async for chunk in stream_generate_chat_completion_ollama(prompt, data, request, model_path):
+            yield json.dumps(chunk).encode('utf-8') + b'\n'
+
+    if data.stream and not disable_request_streaming:
+        return StreamingResponse(stream_response(request), media_type="application/x-ndjson")
+    else:
+        generate_task = asyncio.create_task(
+            generate_chat_completion(prompt, data, request, model_path)
+        )
+
+        response = await run_with_request_disconnect(
+            request,
+            generate_task,
+            disconnect_message=f"Chat completion {request.state.id} cancelled by user.",
+        )
+        return response
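For reference, here is a minimal client sketch (not part of the commit) showing how the new Ollama-compatible routes added above can be exercised, roughly the way Open WebUI would. The base URL, port, and authorization header are assumptions; all three routes are guarded by check_api_key, so some form of API key must be supplied.

# Minimal sketch, assuming the server listens on localhost:5000 and accepts a
# bearer token; both are assumptions, not taken from this commit.
import requests

BASE_URL = "http://localhost:5000"
HEADERS = {"Authorization": "Bearer your-api-key"}  # header name is an assumption

# /api/version (and /ollama/api/version) return a static version payload
print(requests.get(f"{BASE_URL}/api/version", headers=HEADERS).json())

# /api/tags lists each directory under "models" as "<name>:latest"
tags = requests.get(f"{BASE_URL}/api/tags", headers=HEADERS).json()
for entry in tags["models"]:
    print(entry["name"], entry["digest"])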
@@ -98,3 +98,20 @@ class ChatCompletionStreamChunk(BaseModel):
     model: str
     object: str = "chat.completion.chunk"
     usage: Optional[UsageStats] = None
+
+
+class ChatCompletionResponseOllama(BaseModel):
+    id: str = Field(default_factory=lambda: f"chatcmpl-{uuid4().hex}")
+    choices: List[ChatCompletionRespChoice]
+    created: int = Field(default_factory=lambda: int(time()))
+    model: str
+    object: str = "chat.completion"
+    usage: Optional[UsageStats] = None
+
+
+class ChatCompletionStreamChunkOllama(BaseModel):
+    id: str = Field(default_factory=lambda: f"chatcmpl-{uuid4().hex}")
+    choices: List[ChatCompletionStreamChoice]
+    created: int = Field(default_factory=lambda: int(time()))
+    model: str
+    object: str = "chat.completion.chunk"
+    usage: Optional[UsageStats] = None
@@ -25,7 +25,7 @@ class CompletionResponseFormat(BaseModel):
 class ChatCompletionStreamOptions(BaseModel):
     include_usage: Optional[bool] = False


 class CommonCompletionRequest(BaseSamplerRequest):
     """Represents a common completion request."""
@@ -4,6 +4,8 @@ import asyncio
 import pathlib
 from asyncio import CancelledError
 from typing import List, Optional
+import json
+from datetime import datetime
 from fastapi import HTTPException, Request
 from jinja2 import TemplateError
 from loguru import logger
@@ -109,6 +111,82 @@ def _create_response(
     return response


+def _create_stream_chunk_ollama(
+    request_id: str,
+    generation: Optional[dict] = None,
+    model_name: Optional[str] = None,
+    is_usage_chunk: bool = False,
+):
+    """Create a chat completion stream chunk from the provided text."""
+
+    index = generation.get("index")
+    choices = []
+    usage_stats = None
+
+    if is_usage_chunk:
+        prompt_tokens = unwrap(generation.get("prompt_tokens"), 0)
+        completion_tokens = unwrap(generation.get("generated_tokens"), 0)
+
+        usage_stats = UsageStats(
+            prompt_tokens=prompt_tokens,
+            completion_tokens=completion_tokens,
+            total_tokens=prompt_tokens + completion_tokens,
+        )
+    elif "finish_reason" in generation:
+        choice = ChatCompletionStreamChoice(
+            index=index,
+            finish_reason=generation.get("finish_reason"),
+        )
+
+        # lets check if we have tool calls since we are at the end of the generation
+        if "tool_calls" in generation:
+            tool_calls = generation["tool_calls"]
+            message = ChatCompletionMessage(
+                tool_calls=postprocess_tool_call(tool_calls)
+            )
+            choice.delta = message
+
+        choices.append(choice)
+
+    else:
+        message = ChatCompletionMessage(
+            role="assistant", content=unwrap(generation.get("text"), "")
+        )
+
+        logprob_response = None
+
+        token_probs = unwrap(generation.get("token_probs"), {})
+        if token_probs:
+            logprobs = unwrap(generation.get("logprobs"), {})
+            top_logprobs = [
+                ChatCompletionLogprob(token=token, logprob=logprob)
+                for token, logprob in logprobs.items()
+            ]
+
+            generated_token = next(iter(token_probs))
+            token_prob_response = ChatCompletionLogprob(
+                token=generated_token,
+                logprob=token_probs[generated_token],
+                top_logprobs=top_logprobs,
+            )
+
+            logprob_response = ChatCompletionLogprobs(content=[token_prob_response])
+
+        choice = ChatCompletionStreamChoice(
+            index=index,
+            delta=message,
+            logprobs=logprob_response,
+        )
+
+    ollama_bit = {
+        "model": model_name,
+        "created_at": datetime.utcnow().isoformat(timespec='microseconds') + "Z",
+        "message": {"role": choice.delta.role if hasattr(choice.delta, 'role') else 'none',
+                    "content": choice.delta.content if hasattr(choice.delta, 'content') else 'none'},
+        "done_reason": choice.finish_reason,
+        "done": choice.finish_reason == "stop",
+    }
+    return ollama_bit
+
+
 def _create_stream_chunk(
     request_id: str,
     generation: Optional[dict] = None,
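As a reading aid (not part of the commit), the snippet below shows roughly the shape of the dict the helper above returns for an ordinary token and how the router's stream_response generator turns it into a single NDJSON line; the model name and message content are made-up example values.

# Illustrative only: a chunk shaped like the ollama_bit dict above, serialized
# one JSON object per line the way stream_response does. Values are examples.
import json
from datetime import datetime

chunk = {
    "model": "my-model",  # assumed model directory name
    "created_at": datetime.utcnow().isoformat(timespec="microseconds") + "Z",
    "message": {"role": "assistant", "content": "Hello"},
    "done_reason": None,
    "done": False,
}

ndjson_line = json.dumps(chunk).encode("utf-8") + b"\n"
print(ndjson_line.decode("utf-8"), end="")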
@@ -307,6 +385,101 @@ async def apply_chat_template(data: ChatCompletionRequest):
         raise HTTPException(400, error_message) from exc


+async def stream_generate_chat_completion_ollama(
+    prompt: str, data: ChatCompletionRequest, request: Request, model_path: pathlib.Path
+):
+    """Generator for the generation process."""
+    abort_event = asyncio.Event()
+    gen_queue = asyncio.Queue()
+    gen_tasks: List[asyncio.Task] = []
+    disconnect_task = asyncio.create_task(request_disconnect_loop(request))
+
+    try:
+        logger.info(f"Received chat completion streaming request {request.state.id}")
+
+        gen_params = data.to_gen_params()
+
+        for n in range(0, data.n):
+            if n > 0:
+                task_gen_params = deepcopy(gen_params)
+            else:
+                task_gen_params = gen_params
+
+            gen_task = asyncio.create_task(
+                _stream_collector(
+                    n,
+                    gen_queue,
+                    prompt,
+                    request.state.id,
+                    abort_event,
+                    **task_gen_params,
+                )
+            )
+
+            gen_tasks.append(gen_task)
+
+        # We need to keep track of the text generated so we can resume the tool calls
+        current_generation_text = ""
+
+        # Consumer loop
+        while True:
+            if disconnect_task.done():
+                abort_event.set()
+                handle_request_disconnect(
+                    f"Chat completion generation {request.state.id} cancelled by user."
+                )
+
+            generation = await gen_queue.get()
+            # lets only append the text if we need it for tool calls later
+            if data.tool_call_start and "text" in generation:
+                current_generation_text += generation["text"]
+
+            # check if we are running a tool model, and that we are at stop
+            if data.tool_call_start and "stop_str" in generation:
+                generations = await generate_tool_calls(
+                    data,
+                    [generation],
+                    request,
+                    current_generations=current_generation_text,
+                )
+                generation = generations[0]  # We only have one generation in this case
+
+            # Stream collector will push an exception to the queue if it fails
+            if isinstance(generation, Exception):
+                raise generation
+
+            chunk = _create_stream_chunk_ollama(
+                request.state.id, generation, model_path.name
+            )
+
+            yield chunk
+
+            # Check if all tasks are completed
+            if all(task.done() for task in gen_tasks) and gen_queue.empty():
+                # Send a usage chunk
+                if data.stream_options and data.stream_options.include_usage:
+                    usage_chunk = _create_stream_chunk_ollama(
+                        request.state.id,
+                        generation,
+                        model_path.name,
+                        is_usage_chunk=True,
+                    )
+                    yield usage_chunk
+
+                logger.info(
+                    f"Finished chat completion streaming request {request.state.id}"
+                )
+                break
+    except CancelledError:
+        # Get out if the request gets disconnected
+
+        if not disconnect_task.done():
+            abort_event.set()
+            handle_request_disconnect("Chat completion generation cancelled by user.")
+    except Exception:
+        yield get_generator_error(
+            "Chat completion aborted. Please check the server console."
+        )
+
+
 async def stream_generate_chat_completion(
     prompt: str,
     embeddings: MultimodalEmbeddingWrapper,
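A hedged client-side sketch (not part of the commit) for consuming the NDJSON stream that POST /api/chat returns when stream is true. The URL, authorization header, and model name are assumptions.

# Sketch only: stream the /api/chat NDJSON response line by line.
import json
import requests

resp = requests.post(
    "http://localhost:5000/api/chat",                   # assumed host/port
    headers={"Authorization": "Bearer your-api-key"},   # assumed header name
    json={
        "model": "my-model:latest",                     # assumed model name
        "messages": [{"role": "user", "content": "Hello"}],
        "stream": True,
    },
    stream=True,
)

for line in resp.iter_lines():
    if not line:
        continue
    chunk = json.loads(line)
    print(chunk["message"]["content"], end="", flush=True)
    if chunk.get("done"):
        break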
@@ -128,6 +128,7 @@ async def _stream_collector(

 async def load_inline_model(model_name: str, request: Request):
     """Load a model from the data.model parameter"""
+    model_name = model_name.split(":")[0]

     # Return if the model container already exists and the model is fully loaded
     if (
@@ -175,6 +176,14 @@ async def load_inline_model(model_name: str, request: Request):
         return

+    if model.container is not None:
+        if model.container.model_dir.name != model_name:
+            logger.info(f"New model requested: {model_name}. Unloading current model.")
+            await model.unload_model()
+        elif model.container.model_loaded:
+            logger.info(f"Model {model_name} is already loaded.")
+            return
+
     model_path = pathlib.Path(config.model.model_dir)
     model_path = model_path / model_name
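A small sketch (not from the commit) of the name normalization these last two hunks rely on: an Ollama-style tag such as <dir>:latest is reduced to the bare directory name before the reload check runs. The example name is made up.

# Illustration of the tag stripping added to load_inline_model.
requested = "my-model:latest"    # assumed Ollama-style name sent by Open WebUI
directory = requested.split(":")[0]
print(directory)                 # -> "my-model"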