diff --git a/.webui_secret_key b/.webui_secret_key
new file mode 100644
index 0000000..145a017
--- /dev/null
+++ b/.webui_secret_key
@@ -0,0 +1 @@
+XkPmjle3dN2r0iZ3
\ No newline at end of file
diff --git a/endpoints/OAI/router.py b/endpoints/OAI/router.py
index 8f4e7a4..8947723 100644
--- a/endpoints/OAI/router.py
+++ b/endpoints/OAI/router.py
@@ -1,4 +1,8 @@
 import asyncio
+import json
+from fastapi import Response
+from datetime import datetime
+from fastapi.responses import StreamingResponse
 from fastapi import APIRouter, Depends, HTTPException, Request
 from sse_starlette import EventSourceResponse
 from sys import maxsize
@@ -13,11 +17,13 @@ from endpoints.OAI.types.chat_completion import (
     ChatCompletionRequest,
     ChatCompletionResponse,
 )
+import os
 from endpoints.OAI.types.embedding import EmbeddingsRequest, EmbeddingsResponse
 from endpoints.OAI.utils.chat_completion import (
     apply_chat_template,
     generate_chat_completion,
     stream_generate_chat_completion,
+    stream_generate_chat_completion_ollama,
 )
 from endpoints.OAI.utils.completion import (
     generate_completion,
@@ -165,3 +171,129 @@ async def embeddings(request: Request, data: EmbeddingsRequest) -> EmbeddingsRes
     )
 
     return response
+
+from pydantic import BaseModel, Field
+from typing import List, Optional
+import hashlib
+
+class ModelItem(BaseModel):
+    model: str
+    name: str
+    digest: str
+    urls: List[int]
+
+class ModelListResponse(BaseModel):
+    object: str = Field("list", description="Type of the response object.")
+    models: List[ModelItem]
+
+async def fetch_models():
+    models_dir = "models"
+    models = []
+    # Iterate over the entries in the models directory
+    if os.path.exists(models_dir):
+        for model in os.listdir(models_dir):
+            model_path = os.path.join(models_dir, model)
+            if os.path.isdir(model_path):  # Assuming each model is in its own directory
+                digest = hashlib.md5(model.encode()).hexdigest()
+                models.append({
+                    "model": f"{model}:latest",
+                    "name": f"{model}:latest",
+                    "digest": digest,
+                    "urls": [0],
+                })
+    else:
+        print(f"Models directory {models_dir} does not exist.")
+    return ModelListResponse(models=models)
+
+# Version endpoints probed by Ollama clients
+@router.get(
+    "/ollama/api/version",
+    dependencies=[Depends(check_api_key)]
+)
+async def get_ollama_version(request: Request):
+    return {"version": "1.0"}
+
+@router.get(
+    "/api/version",
+    dependencies=[Depends(check_api_key)]
+)
+async def get_api_version(request: Request):
+    return {"version": "1.0"}
+
+# Models endpoint
+@router.get(
+    "/api/tags",
+    dependencies=[Depends(check_api_key)]
+)
+async def get_all_models(request: Request) -> ModelListResponse:
+    print(f"Processing request for models {request.state.id}")
+
+    response = await run_with_request_disconnect(
+        request,
+        asyncio.create_task(fetch_models()),
+        disconnect_message="Model list request cancelled by user.",
+    )
+
+    return response
+
+
+@router.post(
+    "/api/chat",
+    dependencies=[Depends(check_api_key)],
+)
+async def chat_completion_request_ollama(
+    request: Request, data: ChatCompletionRequest
+):
+    """
+    Generates a chat completion from a prompt.
+
+    If stream = true, this returns a newline-delimited JSON (NDJSON) stream.
+    """
+
+    if data.model:
+        await load_inline_model(data.model, request)
+    else:
+        await check_model_container()
+
+    if model.container.prompt_template is None:
+        error_message = handle_request_error(
+            "Chat completions are disabled because a prompt template is not set.",
+            exc_info=False,
+        ).error.message
+
+        raise HTTPException(422, error_message)
+
+    model_path = model.container.model_dir
+
+    if isinstance(data.messages, str):
+        prompt = data.messages
+    else:
+        prompt = await format_prompt_with_template(data)
+
+    # Set an empty JSON schema if the request wants a JSON response
+    if data.response_format.type == "json":
+        data.json_schema = {"type": "object"}
+
+    disable_request_streaming = config.developer.disable_request_streaming
+
+    async def stream_response(request: Request):
+        async for chunk in stream_generate_chat_completion_ollama(
+            prompt, data, request, model_path
+        ):
+            yield json.dumps(chunk).encode('utf-8') + b'\n'
+
+    if data.stream and not disable_request_streaming:
+        return StreamingResponse(
+            stream_response(request), media_type="application/x-ndjson"
+        )
+    else:
+        generate_task = asyncio.create_task(
+            generate_chat_completion(prompt, data, request, model_path)
+        )
+
+        response = await run_with_request_disconnect(
+            request,
+            generate_task,
+            disconnect_message=f"Chat completion {request.state.id} cancelled by user.",
+        )
+        return response
diff --git a/endpoints/OAI/types/chat_completion.py b/endpoints/OAI/types/chat_completion.py
index 5252314..95f8a6b 100644
--- a/endpoints/OAI/types/chat_completion.py
+++ b/endpoints/OAI/types/chat_completion.py
@@ -98,3 +98,20 @@ class ChatCompletionStreamChunk(BaseModel):
     model: str
     object: str = "chat.completion.chunk"
     usage: Optional[UsageStats] = None
+
+class ChatCompletionResponseOllama(BaseModel):
+    id: str = Field(default_factory=lambda: f"chatcmpl-{uuid4().hex}")
+    choices: List[ChatCompletionRespChoice]
+    created: int = Field(default_factory=lambda: int(time()))
+    model: str
+    object: str = "chat.completion"
+    usage: Optional[UsageStats] = None
+
+
+class ChatCompletionStreamChunkOllama(BaseModel):
+    id: str = Field(default_factory=lambda: f"chatcmpl-{uuid4().hex}")
+    choices: List[ChatCompletionStreamChoice]
+    created: int = Field(default_factory=lambda: int(time()))
+    model: str
+    object: str = "chat.completion.chunk"
+    usage: Optional[UsageStats] = None
diff --git a/endpoints/OAI/types/common.py b/endpoints/OAI/types/common.py
index 16ef2ed..0f765de 100644
--- a/endpoints/OAI/types/common.py
+++ b/endpoints/OAI/types/common.py
@@ -25,7 +25,7 @@ class CompletionResponseFormat(BaseModel):
 
 class ChatCompletionStreamOptions(BaseModel):
     include_usage: Optional[bool] = False
-    
+
 
 class CommonCompletionRequest(BaseSamplerRequest):
     """Represents a common completion request."""
diff --git a/endpoints/OAI/utils/chat_completion.py b/endpoints/OAI/utils/chat_completion.py
index 4a6c210..efc7b34 100644
--- a/endpoints/OAI/utils/chat_completion.py
+++ b/endpoints/OAI/utils/chat_completion.py
@@ -4,6 +4,8 @@ import asyncio
 import pathlib
 from asyncio import CancelledError
 from typing import List, Optional
+import json
+from datetime import datetime
 from fastapi import HTTPException, Request
 from jinja2 import TemplateError
 from loguru import logger
@@ -109,6 +111,82 @@ def _create_response(
     return response
 
 
+def _create_stream_chunk_ollama(
+    request_id: str,
+    generation: Optional[dict] = None,
+    model_name: Optional[str] = None,
+    is_usage_chunk: bool = False,
+):
+    """Create a chat completion stream chunk from the provided text."""
+
+    index = generation.get("index")
+
+    if is_usage_chunk:
+        prompt_tokens = unwrap(generation.get("prompt_tokens"), 0)
+        completion_tokens = unwrap(generation.get("generated_tokens"), 0)
+
+        # Final Ollama-style chunk: report token usage and mark the stream as done
+        return {
+            "model": model_name,
+            "created_at": datetime.utcnow().isoformat(timespec="microseconds") + "Z",
+            "message": {"role": "assistant", "content": ""},
+            "done_reason": "stop",
+            "done": True,
+            "prompt_eval_count": prompt_tokens,
+            "eval_count": completion_tokens,
+        }
+    elif "finish_reason" in generation:
+        choice = ChatCompletionStreamChoice(
+            index=index,
+            finish_reason=generation.get("finish_reason"),
+        )
+
+        # Check for tool calls since we are at the end of the generation
+        if "tool_calls" in generation:
+            tool_calls = generation["tool_calls"]
+            message = ChatCompletionMessage(
+                tool_calls=postprocess_tool_call(tool_calls)
+            )
+            choice.delta = message
+
+    else:
+        message = ChatCompletionMessage(
+            role="assistant", content=unwrap(generation.get("text"), "")
+        )
+
+        logprob_response = None
+
+        token_probs = unwrap(generation.get("token_probs"), {})
+        if token_probs:
+            logprobs = unwrap(generation.get("logprobs"), {})
+            top_logprobs = [
+                ChatCompletionLogprob(token=token, logprob=logprob)
+                for token, logprob in logprobs.items()
+            ]
+
+            generated_token = next(iter(token_probs))
+            token_prob_response = ChatCompletionLogprob(
+                token=generated_token,
+                logprob=token_probs[generated_token],
+                top_logprobs=top_logprobs,
+            )
+
+            logprob_response = ChatCompletionLogprobs(content=[token_prob_response])
+
+        choice = ChatCompletionStreamChoice(
+            index=index,
+            delta=message,
+            logprobs=logprob_response,
+        )
+
+    # Translate the OpenAI-style choice into an Ollama NDJSON chunk
+    ollama_bit = {
+        "model": model_name,
+        "created_at": datetime.utcnow().isoformat(timespec="microseconds") + "Z",
+        "message": {
+            "role": getattr(choice.delta, "role", None) or "assistant",
+            "content": getattr(choice.delta, "content", None) or "",
+        },
+        "done_reason": choice.finish_reason,
+        "done": choice.finish_reason == "stop",
+    }
+    return ollama_bit
+
+
 def _create_stream_chunk(
     request_id: str,
     generation: Optional[dict] = None,
@@ -307,6 +385,101 @@ async def apply_chat_template(data: ChatCompletionRequest):
 
         raise HTTPException(400, error_message) from exc
 
+
+async def stream_generate_chat_completion_ollama(
+    prompt: str, data: ChatCompletionRequest, request: Request, model_path: pathlib.Path
+):
+    """Generator for the Ollama-format streaming generation process."""
+    abort_event = asyncio.Event()
+    gen_queue = asyncio.Queue()
+    gen_tasks: List[asyncio.Task] = []
+    disconnect_task = asyncio.create_task(request_disconnect_loop(request))
+
+    try:
+        logger.info(f"Received chat completion streaming request {request.state.id}")
+
+        gen_params = data.to_gen_params()
+
+        for n in range(0, data.n):
+            if n > 0:
+                task_gen_params = deepcopy(gen_params)
+            else:
+                task_gen_params = gen_params
+
+            gen_task = asyncio.create_task(
+                _stream_collector(
+                    n,
+                    gen_queue,
+                    prompt,
+                    request.state.id,
+                    abort_event,
+                    **task_gen_params,
+                )
+            )
+
+            gen_tasks.append(gen_task)
+
+        # We need to keep track of the text generated so we can resume the tool calls
+        current_generation_text = ""
+
+        # Consumer loop
+        while True:
+            if disconnect_task.done():
+                abort_event.set()
+                handle_request_disconnect(
+                    f"Chat completion generation {request.state.id} cancelled by user."
+ ) + + generation = await gen_queue.get() + # lets only append the text if we need it for tool calls later + if data.tool_call_start and "text" in generation: + current_generation_text += generation["text"] + + # check if we are running a tool model, and that we are at stop + if data.tool_call_start and "stop_str" in generation: + generations = await generate_tool_calls( + data, + [generation], + request, + current_generations=current_generation_text, + ) + generation = generations[0] # We only have one generation in this case + + # Stream collector will push an exception to the queue if it fails + if isinstance(generation, Exception): + raise generation + + chunk = _create_stream_chunk_ollama( + request.state.id, generation, model_path.name + ) + + yield chunk + + # Check if all tasks are completed + if all(task.done() for task in gen_tasks) and gen_queue.empty(): + # Send a usage chunk + if data.stream_options and data.stream_options.include_usage: + usage_chunk = _create_stream_chunk_ollama( + request.state.id, + generation, + model_path.name, + is_usage_chunk=True, + ) + yield usage_chunk + + logger.info( + f"Finished chat completion streaming request {request.state.id}" + ) + break + except CancelledError: + # Get out if the request gets disconnected + + if not disconnect_task.done(): + abort_event.set() + handle_request_disconnect("Chat completion generation cancelled by user.") + except Exception: + yield get_generator_error( + "Chat completion aborted. Please check the server console." + ) async def stream_generate_chat_completion( prompt: str, embeddings: MultimodalEmbeddingWrapper, diff --git a/endpoints/OAI/utils/completion.py b/endpoints/OAI/utils/completion.py index ca51c9c..98712d9 100644 --- a/endpoints/OAI/utils/completion.py +++ b/endpoints/OAI/utils/completion.py @@ -128,6 +128,7 @@ async def _stream_collector( async def load_inline_model(model_name: str, request: Request): """Load a model from the data.model parameter""" + model_name = model_name.split(":")[0] # Return if the model container already exists and the model is fully loaded if ( @@ -175,6 +176,14 @@ async def load_inline_model(model_name: str, request: Request): return + if model.container is not None: + if model.container.model_dir.name != model_name: + logger.info(f"New model requested: {model_name}. Unloading current model.") + await model.unload_model() + elif model.container.model_loaded: + logger.info(f"Model {model_name} is already loaded.") + return + model_path = pathlib.Path(config.model.model_dir) model_path = model_path / model_name