Add health check monitoring for EXL2 errors (#206)

* Add health check monitoring for EXL2 errors

* Health: Format and change status code

A status code of 503 makes more sense to use.
---------
This commit is contained in:
TerminalMan 2024-09-23 02:40:36 +01:00 committed by GitHub
parent e0ffa90865
commit 2cda890deb
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
4 changed files with 73 additions and 3 deletions

View file

@ -32,6 +32,8 @@ from typing import List, Optional, Union
from ruamel.yaml import YAML from ruamel.yaml import YAML
from common.health import HealthManager
from backends.exllamav2.grammar import ( from backends.exllamav2.grammar import (
ExLlamaV2Grammar, ExLlamaV2Grammar,
clear_grammar_func_cache, clear_grammar_func_cache,
@ -1373,6 +1375,8 @@ class ExllamaV2Container:
) )
asyncio.ensure_future(self.create_generator()) asyncio.ensure_future(self.create_generator())
await HealthManager.add_unhealthy_event(ex)
raise ex raise ex
finally: finally:
# Log generation options to console # Log generation options to console

42
common/health.py Normal file
View file

@ -0,0 +1,42 @@
import asyncio
from collections import deque
from datetime import datetime, timezone
from functools import partial
from pydantic import BaseModel, Field
from typing import Union
class UnhealthyEvent(BaseModel):
    """A single recorded error that makes the system unhealthy"""

    # Generated per instance as a timezone-aware datetime in UTC
    time: datetime = Field(
        description="Time the error occurred in UTC time",
        default_factory=lambda: datetime.now(timezone.utc),
    )

    # Falls back to a generic message when none is supplied
    description: str = Field(
        default="Unknown error",
        description="The error message",
    )
class HealthManagerClass:
    """Tracks the errors that make up the global health state"""

    def __init__(self):
        # Every read and write of the issue history goes through this lock
        self._lock = asyncio.Lock()

        # Bounded history: with maxlen=100 the deque discards its oldest
        # entry when a new one arrives, which avoids a memory leak
        self.issues: deque[UnhealthyEvent] = deque(maxlen=100)

    async def add_unhealthy_event(self, error: Union[str, Exception]):
        """
        Record a new unhealthy event.

        Exceptions are stored as "<ClassName>: <message>";
        any other value is used as the description as-is.
        """

        async with self._lock:
            message = (
                f"{error.__class__.__name__}: {str(error)}"
                if isinstance(error, Exception)
                else error
            )
            self.issues.append(UnhealthyEvent(description=message))

    async def is_service_healthy(self) -> tuple[bool, list[UnhealthyEvent]]:
        """
        Report the service health as a (healthy, issues) pair.

        The service counts as healthy only while no events are stored.
        The returned list is a copy of the stored events.
        """

        async with self._lock:
            snapshot = list(self.issues)
            return not snapshot, snapshot
# Module-level instance of the health state manager, created at import time.
# Other modules use it via `from common.health import HealthManager`.
HealthManager = HealthManagerClass()

View file

@ -1,7 +1,7 @@
import asyncio import asyncio
import pathlib import pathlib
from sys import maxsize from sys import maxsize
from fastapi import APIRouter, Depends, HTTPException, Request from fastapi import APIRouter, Depends, HTTPException, Request, Response
from sse_starlette import EventSourceResponse from sse_starlette import EventSourceResponse
from common import model, sampling from common import model, sampling
@ -12,6 +12,7 @@ from common.networking import handle_request_error, run_with_request_disconnect
from common.tabby_config import config from common.tabby_config import config
from common.templating import PromptTemplate, get_all_templates from common.templating import PromptTemplate, get_all_templates
from common.utils import unwrap from common.utils import unwrap
from common.health import HealthManager
from endpoints.core.types.auth import AuthPermissionResponse from endpoints.core.types.auth import AuthPermissionResponse
from endpoints.core.types.download import DownloadRequest, DownloadResponse from endpoints.core.types.download import DownloadRequest, DownloadResponse
from endpoints.core.types.lora import LoraList, LoraLoadRequest, LoraLoadResponse from endpoints.core.types.lora import LoraList, LoraLoadRequest, LoraLoadResponse
@ -22,6 +23,7 @@ from endpoints.core.types.model import (
ModelLoadRequest, ModelLoadRequest,
ModelLoadResponse, ModelLoadResponse,
) )
from endpoints.core.types.health import HealthCheckResponse
from endpoints.core.types.sampler_overrides import ( from endpoints.core.types.sampler_overrides import (
SamplerOverrideListResponse, SamplerOverrideListResponse,
SamplerOverrideSwitchRequest, SamplerOverrideSwitchRequest,
@ -47,9 +49,16 @@ router = APIRouter()
# Healthcheck endpoint # Healthcheck endpoint
@router.get("/health") @router.get("/health")
async def healthcheck(): async def healthcheck(response: Response) -> HealthCheckResponse:
"""Get the current service health status""" """Get the current service health status"""
return {"status": "healthy"} healthy, issues = await HealthManager.is_service_healthy()
if not healthy:
response.status_code = 503
return HealthCheckResponse(
status="healthy" if healthy else "unhealthy", issues=issues
)
# Model list endpoint # Model list endpoint

View file

@ -0,0 +1,15 @@
from typing import Literal
from pydantic import BaseModel, Field
from common.health import UnhealthyEvent
class HealthCheckResponse(BaseModel):
    """Response model describing the system health status"""

    # Overall state; "healthy" unless another value is given
    status: Literal["healthy", "unhealthy"] = Field(
        default="healthy",
        description="System health status",
    )

    # Recorded unhealthy events; a fresh empty list by default
    issues: list[UnhealthyEvent] = Field(
        description="List of issues",
        default_factory=list,
    )