diff --git a/backends/exllamav2/model.py b/backends/exllamav2/model.py index 5d783d1..8582b76 100644 --- a/backends/exllamav2/model.py +++ b/backends/exllamav2/model.py @@ -32,6 +32,8 @@ from typing import List, Optional, Union from ruamel.yaml import YAML +from common.health import HealthManager + from backends.exllamav2.grammar import ( ExLlamaV2Grammar, clear_grammar_func_cache, @@ -1373,6 +1375,8 @@ class ExllamaV2Container: ) asyncio.ensure_future(self.create_generator()) + await HealthManager.add_unhealthy_event(ex) + raise ex finally: # Log generation options to console diff --git a/common/health.py b/common/health.py new file mode 100644 index 0000000..4d21d6a --- /dev/null +++ b/common/health.py @@ -0,0 +1,42 @@ +import asyncio +from collections import deque +from datetime import datetime, timezone +from functools import partial +from pydantic import BaseModel, Field +from typing import Union + + +class UnhealthyEvent(BaseModel): + """Represents an error that makes the system unhealthy""" + + time: datetime = Field( + default_factory=partial(datetime.now, timezone.utc), + description="Time the error occurred in UTC time", + ) + description: str = Field("Unknown error", description="The error message") + + +class HealthManagerClass: + """Class to manage the health global state""" + + def __init__(self): + # limit the max stored errors to 100 to avoid a memory leak + self.issues: deque[UnhealthyEvent] = deque(maxlen=100) + self._lock = asyncio.Lock() + + async def add_unhealthy_event(self, error: Union[str, Exception]): + """Add a new unhealthy event""" + async with self._lock: + if isinstance(error, Exception): + error = f"{error.__class__.__name__}: {str(error)}" + self.issues.append(UnhealthyEvent(description=error)) + + async def is_service_healthy(self) -> tuple[bool, list[UnhealthyEvent]]: + """Check if the service is healthy""" + async with self._lock: + healthy = len(self.issues) == 0 + return healthy, list(self.issues) + + +# Create an instance of the global state manager +HealthManager = HealthManagerClass() diff --git a/endpoints/core/router.py b/endpoints/core/router.py index 325fbad..2c60cd7 100644 --- a/endpoints/core/router.py +++ b/endpoints/core/router.py @@ -1,7 +1,7 @@ import asyncio import pathlib from sys import maxsize -from fastapi import APIRouter, Depends, HTTPException, Request +from fastapi import APIRouter, Depends, HTTPException, Request, Response from sse_starlette import EventSourceResponse from common import model, sampling @@ -12,6 +12,7 @@ from common.networking import handle_request_error, run_with_request_disconnect from common.tabby_config import config from common.templating import PromptTemplate, get_all_templates from common.utils import unwrap +from common.health import HealthManager from endpoints.core.types.auth import AuthPermissionResponse from endpoints.core.types.download import DownloadRequest, DownloadResponse from endpoints.core.types.lora import LoraList, LoraLoadRequest, LoraLoadResponse @@ -22,6 +23,7 @@ from endpoints.core.types.model import ( ModelLoadRequest, ModelLoadResponse, ) +from endpoints.core.types.health import HealthCheckResponse from endpoints.core.types.sampler_overrides import ( SamplerOverrideListResponse, SamplerOverrideSwitchRequest, @@ -47,9 +49,16 @@ router = APIRouter() # Healthcheck endpoint @router.get("/health") -async def healthcheck(): +async def healthcheck(response: Response) -> HealthCheckResponse: """Get the current service health status""" - return {"status": "healthy"} + healthy, issues = await HealthManager.is_service_healthy() + + if not healthy: + response.status_code = 503 + + return HealthCheckResponse( + status="healthy" if healthy else "unhealthy", issues=issues + ) # Model list endpoint diff --git a/endpoints/core/types/health.py b/endpoints/core/types/health.py new file mode 100644 index 0000000..ad5fffe --- /dev/null +++ b/endpoints/core/types/health.py @@ -0,0 +1,15 @@ +from typing import Literal +from pydantic import BaseModel, Field + +from common.health import UnhealthyEvent + + +class HealthCheckResponse(BaseModel): + """System health status""" + + status: Literal["healthy", "unhealthy"] = Field( + "healthy", description="System health status" + ) + issues: list[UnhealthyEvent] = Field( + default_factory=list, description="List of issues" + )