enabled api routes to work with open-webui

This commit is contained in:
John 2024-10-12 20:14:57 -04:00 committed by Jakob Lechner
parent 1d3a308709
commit 1fd4a9e119
6 changed files with 333 additions and 1 deletion
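
As context for the diff below, here is a minimal sketch (not part of this commit) of how a client such as open-webui would probe the new Ollama-compatible discovery routes. The base URL, port, and the "x-api-key" header name are assumptions, not taken from the diff:

# Minimal sketch, not part of this commit. Assumes the server is reachable at
# http://localhost:5000 and that check_api_key accepts an "x-api-key" header.
import requests

BASE_URL = "http://localhost:5000"        # placeholder host/port
HEADERS = {"x-api-key": "your-api-key"}   # placeholder key

# GET /api/version -> {"version": "1.0"}
print(requests.get(f"{BASE_URL}/api/version", headers=HEADERS).json())

# GET /api/tags -> {"object": "list", "models": [{"model": "<dir>:latest", ...}]}
tags = requests.get(f"{BASE_URL}/api/tags", headers=HEADERS).json()
for entry in tags["models"]:
    print(entry["name"], entry["digest"])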

.webui_secret_key Normal file
View file

@@ -0,0 +1 @@
XkPmjle3dN2r0iZ3

View file

@@ -1,4 +1,8 @@
import asyncio
import json
from fastapi import Response
from datetime import datetime
from fastapi.responses import StreamingResponse
from fastapi import APIRouter, Depends, HTTPException, Request
from sse_starlette import EventSourceResponse
from sys import maxsize
@@ -13,11 +17,13 @@ from endpoints.OAI.types.chat_completion import (
    ChatCompletionRequest,
    ChatCompletionResponse,
)
import os
from endpoints.OAI.types.embedding import EmbeddingsRequest, EmbeddingsResponse
from endpoints.OAI.utils.chat_completion import (
    apply_chat_template,
    generate_chat_completion,
    stream_generate_chat_completion,
    stream_generate_chat_completion_ollama,
)
from endpoints.OAI.utils.completion import (
    generate_completion,
@@ -165,3 +171,129 @@ async def embeddings(request: Request, data: EmbeddingsRequest) -> EmbeddingsRes
    )
    return response
from pydantic import BaseModel, Field
from typing import List, Optional
import hashlib


class ModelItem(BaseModel):
    model: str
    name: str
    digest: str
    urls: List[int]


class ModelListResponse(BaseModel):
    object: str = Field("list", description="Type of the response object.")
    models: List[ModelItem]


async def fetch_models():
    models_dir = "models"
    models = []

    # Iterate over the files in the models directory
    if os.path.exists(models_dir):
        for model in os.listdir(models_dir):
            model_path = os.path.join(models_dir, model)
            if os.path.isdir(model_path):  # Assuming each model is in its own directory
                digest = hashlib.md5(model.encode()).hexdigest()
                models.append({
                    "model": f"{model}:latest",
                    "name": f"{model}:latest",
                    "digest": digest,
                    "urls": [0],
                })
    else:
        print(f"Models directory {models_dir} does not exist.")

    return ModelListResponse(models=models)
@router.get(
    "/ollama/api/version",
    dependencies=[Depends(check_api_key)]
)
async def dummy2(request: Request):
    return {"version": "1.0"}


@router.get(
    "/api/version",
    dependencies=[Depends(check_api_key)]
)
async def dummy(request: Request):
    return {"version": "1.0"}


# Models endpoint
@router.get(
    "/api/tags",
    dependencies=[Depends(check_api_key)]
)
async def get_all_models(request: Request) -> ModelListResponse:
    print(f"Processing request for models {request.state.id}")
    response = await run_with_request_disconnect(
        request,
        asyncio.create_task(fetch_models()),
        disconnect_message="All models fetched",
    )
    return response
@router.post(
    "/api/chat",
    dependencies=[Depends(check_api_key)],
)
async def chat_completion_request_ollama(
    request: Request, data: ChatCompletionRequest
):
    """
    Generates a chat completion from a prompt.

    If stream = true, this returns a newline-delimited JSON (NDJSON) stream.
    """

    if data.model:
        await load_inline_model(data.model, request)
    else:
        await check_model_container()

    if model.container.prompt_template is None:
        error_message = handle_request_error(
            "Chat completions are disabled because a prompt template is not set.",
            exc_info=False,
        ).error.message

        raise HTTPException(422, error_message)

    model_path = model.container.model_dir

    if isinstance(data.messages, str):
        prompt = data.messages
    else:
        prompt = await format_prompt_with_template(data)

    # Set an empty JSON schema if the request wants a JSON response
    if data.response_format.type == "json":
        data.json_schema = {"type": "object"}

    disable_request_streaming = config.developer.disable_request_streaming

    async def stream_response(request: Request):
        async for chunk in stream_generate_chat_completion_ollama(
            prompt, data, request, model_path
        ):
            yield json.dumps(chunk).encode('utf-8') + b'\n'

    if data.stream and not disable_request_streaming:
        return StreamingResponse(
            stream_response(request), media_type="application/x-ndjson"
        )
    else:
        generate_task = asyncio.create_task(
            generate_chat_completion(prompt, data, request, model_path)
        )

        response = await run_with_request_disconnect(
            request,
            generate_task,
            disconnect_message=f"Chat completion {request.state.id} cancelled by user.",
        )
        return response
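
A hedged sketch of consuming the NDJSON stream that the new /api/chat route returns when stream is true; the host, key, and model name are placeholders:

# Minimal sketch, not part of this commit. Assumes a model directory named
# "my-model" under models/ and the same placeholder host/key as above.
import json
import requests

response = requests.post(
    "http://localhost:5000/api/chat",
    headers={"x-api-key": "your-api-key"},
    json={
        "model": "my-model:latest",
        "messages": [{"role": "user", "content": "Hello!"}],
        "stream": True,
    },
    stream=True,
)
for line in response.iter_lines():
    if not line:
        continue
    chunk = json.loads(line)  # one Ollama-style chunk per line
    print(chunk["message"]["content"], end="", flush=True)
    if chunk.get("done"):
        break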

View file

@@ -98,3 +98,20 @@ class ChatCompletionStreamChunk(BaseModel):
    model: str
    object: str = "chat.completion.chunk"
    usage: Optional[UsageStats] = None


class ChatCompletionResponseOllama(BaseModel):
    id: str = Field(default_factory=lambda: f"chatcmpl-{uuid4().hex}")
    choices: List[ChatCompletionRespChoice]
    created: int = Field(default_factory=lambda: int(time()))
    model: str
    object: str = "chat.completion"
    usage: Optional[UsageStats] = None


class ChatCompletionStreamChunkOllama(BaseModel):
    id: str = Field(default_factory=lambda: f"chatcmpl-{uuid4().hex}")
    choices: List[ChatCompletionStreamChoice]
    created: int = Field(default_factory=lambda: int(time()))
    model: str
    object: str = "chat.completion.chunk"
    usage: Optional[UsageStats] = None

View file

@@ -4,6 +4,8 @@ import asyncio
import pathlib
from asyncio import CancelledError
from typing import List, Optional
import json
from datetime import datetime
from fastapi import HTTPException, Request
from jinja2 import TemplateError
from loguru import logger
@@ -109,6 +111,82 @@ def _create_response(
    return response


def _create_stream_chunk_ollama(
    request_id: str,
    generation: Optional[dict] = None,
    model_name: Optional[str] = None,
    is_usage_chunk: bool = False,
):
    """Create a chat completion stream chunk from the provided text."""

    index = generation.get("index")
    choices = []
    usage_stats = None

    if is_usage_chunk:
        prompt_tokens = unwrap(generation.get("prompt_tokens"), 0)
        completion_tokens = unwrap(generation.get("generated_tokens"), 0)

        usage_stats = UsageStats(
            prompt_tokens=prompt_tokens,
            completion_tokens=completion_tokens,
            total_tokens=prompt_tokens + completion_tokens,
        )

        # Usage chunks carry no delta; emit a terminal choice so the shared
        # return path below still produces a valid "done" chunk.
        choice = ChatCompletionStreamChoice(index=index, finish_reason="stop")
        choices.append(choice)
    elif "finish_reason" in generation:
        choice = ChatCompletionStreamChoice(
            index=index,
            finish_reason=generation.get("finish_reason"),
        )

        # Check for tool calls since we are at the end of the generation
        if "tool_calls" in generation:
            tool_calls = generation["tool_calls"]
            message = ChatCompletionMessage(
                tool_calls=postprocess_tool_call(tool_calls)
            )
            choice.delta = message

        choices.append(choice)
    else:
        message = ChatCompletionMessage(
            role="assistant", content=unwrap(generation.get("text"), "")
        )

        logprob_response = None

        token_probs = unwrap(generation.get("token_probs"), {})
        if token_probs:
            logprobs = unwrap(generation.get("logprobs"), {})
            top_logprobs = [
                ChatCompletionLogprob(token=token, logprob=logprob)
                for token, logprob in logprobs.items()
            ]

            generated_token = next(iter(token_probs))
            token_prob_response = ChatCompletionLogprob(
                token=generated_token,
                logprob=token_probs[generated_token],
                top_logprobs=top_logprobs,
            )

            logprob_response = ChatCompletionLogprobs(content=[token_prob_response])

        choice = ChatCompletionStreamChoice(
            index=index,
            delta=message,
            logprobs=logprob_response,
        )

    ollama_bit = {
        "model": model_name,
        "created_at": datetime.utcnow().isoformat(timespec="microseconds") + "Z",
        "message": {
            "role": choice.delta.role if hasattr(choice.delta, "role") else "none",
            "content": choice.delta.content if hasattr(choice.delta, "content") else "none",
        },
        "done_reason": choice.finish_reason,
        "done": choice.finish_reason == "stop",
    }

    return ollama_bit
def _create_stream_chunk(
    request_id: str,
    generation: Optional[dict] = None,
@@ -307,6 +385,101 @@ async def apply_chat_template(data: ChatCompletionRequest):
        raise HTTPException(400, error_message) from exc


async def stream_generate_chat_completion_ollama(
    prompt: str, data: ChatCompletionRequest, request: Request, model_path: pathlib.Path
):
    """Generator for the generation process."""
    abort_event = asyncio.Event()
    gen_queue = asyncio.Queue()
    gen_tasks: List[asyncio.Task] = []
    disconnect_task = asyncio.create_task(request_disconnect_loop(request))

    try:
        logger.info(f"Received chat completion streaming request {request.state.id}")

        gen_params = data.to_gen_params()

        for n in range(0, data.n):
            if n > 0:
                task_gen_params = deepcopy(gen_params)
            else:
                task_gen_params = gen_params

            gen_task = asyncio.create_task(
                _stream_collector(
                    n,
                    gen_queue,
                    prompt,
                    request.state.id,
                    abort_event,
                    **task_gen_params,
                )
            )

            gen_tasks.append(gen_task)

        # We need to keep track of the text generated so we can resume the tool calls
        current_generation_text = ""

        # Consumer loop
        while True:
            if disconnect_task.done():
                abort_event.set()
                handle_request_disconnect(
                    f"Chat completion generation {request.state.id} cancelled by user."
                )

            generation = await gen_queue.get()

            # Only append the text if we need it for tool calls later
            if data.tool_call_start and "text" in generation:
                current_generation_text += generation["text"]

            # Check if we are running a tool model and that we are at a stop
            if data.tool_call_start and "stop_str" in generation:
                generations = await generate_tool_calls(
                    data,
                    [generation],
                    request,
                    current_generations=current_generation_text,
                )
                generation = generations[0]  # We only have one generation in this case

            # Stream collector will push an exception to the queue if it fails
            if isinstance(generation, Exception):
                raise generation

            chunk = _create_stream_chunk_ollama(
                request.state.id, generation, model_path.name
            )
            yield chunk

            # Check if all tasks are completed
            if all(task.done() for task in gen_tasks) and gen_queue.empty():
                # Send a usage chunk
                if data.stream_options and data.stream_options.include_usage:
                    usage_chunk = _create_stream_chunk_ollama(
                        request.state.id,
                        generation,
                        model_path.name,
                        is_usage_chunk=True,
                    )
                    yield usage_chunk

                logger.info(
                    f"Finished chat completion streaming request {request.state.id}"
                )

                break
    except CancelledError:
        # Get out if the request gets disconnected
        if not disconnect_task.done():
            abort_event.set()
            handle_request_disconnect("Chat completion generation cancelled by user.")
    except Exception:
        yield get_generator_error(
            "Chat completion aborted. Please check the server console."
        )
async def stream_generate_chat_completion(
    prompt: str,
    embeddings: MultimodalEmbeddingWrapper,
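
For reference, an illustrative pair of NDJSON chunks as _create_stream_chunk_ollama builds them, one intermediate and one final; the model name and timestamps are made up:

# Illustrative values only; not real output from this commit.
intermediate_chunk = {
    "model": "my-model",
    "created_at": "2024-10-12T20:15:01.123456Z",
    "message": {"role": "assistant", "content": "Hello"},
    "done_reason": None,
    "done": False,
}
final_chunk = {
    "model": "my-model",
    "created_at": "2024-10-12T20:15:02.654321Z",
    "message": {"role": "none", "content": "none"},  # no delta on the final chunk
    "done_reason": "stop",
    "done": True,
}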

View file

@@ -128,6 +128,7 @@ async def _stream_collector(
async def load_inline_model(model_name: str, request: Request):
    """Load a model from the data.model parameter"""
    model_name = model_name.split(":")[0]

    # Return if the model container already exists and the model is fully loaded
    if (
@@ -175,6 +176,14 @@ async def load_inline_model(model_name: str, request: Request):
        return
    if model.container is not None:
        if model.container.model_dir.name != model_name:
            logger.info(f"New model requested: {model_name}. Unloading current model.")
            await model.unload_model()
        elif model.container.model_loaded:
            logger.info(f"Model {model_name} is already loaded.")
            return
    model_path = pathlib.Path(config.model.model_dir)
    model_path = model_path / model_name
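
A small sketch of the model-name handling introduced above: open-webui sends Ollama-style names with a tag suffix, which load_inline_model now strips before resolving the model directory. The literal "models" path below stands in for the configured model_dir and is an assumption:

# Minimal sketch, not part of this commit.
import pathlib

requested = "my-model:latest"
model_name = requested.split(":")[0]                # -> "my-model"
model_path = pathlib.Path("models") / model_name    # -> models/my-model
print(model_path)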