Replace streamlit with FastAPI

This commit is contained in:
Jakob Lechner 2025-08-02 03:25:54 +02:00
parent 82263557af
commit 91f4d70a77
7 changed files with 228 additions and 82 deletions

75
main.py Normal file
View file

@ -0,0 +1,75 @@
from fastapi import FastAPI, Request
from fastapi.responses import HTMLResponse, StreamingResponse, PlainTextResponse
from tracker import start_all_streams, get_metrics, get_latest_frame, STREAMS
from prometheus_client import generate_latest, CONTENT_TYPE_LATEST
import time
app = FastAPI()
@app.on_event("startup")
def startup_event():
start_all_streams()
@app.get("/", response_class=HTMLResponse)
def index(request: Request):
current_stream = request.query_params.get("stream")
stream_options = "\n".join(
[
f'<option value="{stream_id}" {"selected" if stream_id == current_stream else ""}>{stream_id}</option>'
for stream_id in STREAMS.keys()
]
)
return f"""
<html>
<body>
<h1>People Counter Dashboard</h1>
<form action="/" method="get">
<label>Select Stream:</label>
<select name="stream" onchange="this.form.submit()">
{stream_options}
</select>
</form>
<img src="/video_feed/{current_stream}" />
<p><a href="/metrics">Prometheus Metrics</a></p>
</body>
</html>
"""
@app.get("/metrics")
def metrics():
return PlainTextResponse(
generate_latest(get_metrics()), media_type=CONTENT_TYPE_LATEST
)
@app.get("/video_feed/{stream_id}")
def video_feed(stream_id: str):
boundary = "--boundarydonotcross"
def generate():
last_version = time.time()
while True:
frame, version = get_latest_frame(stream_id)
if frame is not None and version > last_version:
last_version = version
print(f"yielding new frame @{version}")
yield (
boundary.encode()
+ b"\r\n"
+ b"Content-Type: image/jpeg\r\n\r\n"
+ frame
+ b"\r\n"
)
else:
# wait for new frame
time.sleep(0.05)
return StreamingResponse(
generate(), media_type=f"multipart/x-mixed-replace; boundary={boundary}"
)

View file

@ -1,23 +1,28 @@
import threading
from prometheus_client import Counter, start_http_server
from prometheus_client import Counter, CollectorRegistry
_lock = threading.Lock()
_registry = None
_metrics = None
def get_metrics():
global _metrics
global _registry, _metrics
with _lock:
if _metrics is None:
start_http_server(9110)
if _registry is None:
_registry = CollectorRegistry()
_metrics = {
"people_in": Counter(
"people_in_count",
"Number of people who entered",
["stream"],
registry=_registry,
),
"people_out": Counter(
"people_out_count",
"Number of people who exited",
["stream"],
registry=_registry,
),
}
return _metrics
return _registry, _metrics

View file

@ -4,4 +4,5 @@ numpy
filterpy
scikit-image
supervision
streamlit
uvicorn
fastapi

View file

@ -1,69 +0,0 @@
import streamlit as st
import cv2
from ultralytics import YOLO
from utils.counter import Counter
from utils.zones import draw_count_line_horizontal, draw_count_line_vertical
import tempfile
import time
from metrics import get_metrics
metrics = get_metrics()
st.set_page_config(layout="wide", page_title="People Counter")
STREAM_URL = "http://192.168.11.76:8080?action=stream"
start_button = st.button("Start Counter")
FRAME_SKIP = 2 # Reduziert Verarbeitungslast
if start_button and STREAM_URL:
line_orientation = 'vertical'
line_position = 640
stframe = st.empty()
model = YOLO("yolo_weights/yolo11n.pt")
counter = Counter(line_orientation, metrics)
cap = cv2.VideoCapture(STREAM_URL)
frame_idx = 0
while cap.isOpened():
ret, frame = cap.read()
if not ret:
st.warning("Kein Bild vom RTSP-Stream.")
break
frame_idx += 1
if frame_idx % FRAME_SKIP != 0:
continue
results = model.track(frame, persist=True, classes=0, tracker="bytetrack.yaml")
boxes = results[0].boxes
if boxes.id is not None:
for box, track_id in zip(boxes.xywh.cpu(), boxes.id.cpu()):
x, y, w, h = map(int, box)
cv2.rectangle(frame, (x-w//2, y-h//2), (x+w//2, y+h//2), (0, 255, 0), 2)
cv2.putText(frame, str(track_id), (x, y - 10), cv2.FONT_HERSHEY_SIMPLEX, 0.6, (0,255,0), 2)
tracks = [
type("Track", (), {"id": int(track_id), "bbox": box.numpy()})
for box, track_id in zip(boxes.xywh.cpu(), boxes.id.cpu())
]
counter.update(tracks, line_position)
if line_orientation == 'horizontal':
draw_count_line_horizontal(frame, line_position)
elif line_orientation == 'vertical':
draw_count_line_vertical(frame, line_position)
else:
raise NotImplementedError(f'Line orientation {line_orientation} is invalid!')
cv2.putText(frame, f"In: {counter.in_count}", (10, 40), cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 255, 255), 2)
cv2.putText(frame, f"Out: {counter.out_count}", (10, 80), cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 255, 255), 2)
frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
stframe.image(frame, channels="RGB")
cap.release()

130
tracker.py Normal file
View file

@ -0,0 +1,130 @@
import cv2
import threading
import time
from ultralytics import YOLO
from utils.counter import Counter
from utils.zones import draw_count_line_vertical
from collections import defaultdict
from metrics import get_metrics
registry, metrics = get_metrics()
STREAMS = {
"Kasse 1": "http://192.168.11.76:8080?action=stream",
"Kasse 2": "http://192.168.11.230:8080?action=stream",
}
FRAME_SKIP = 2
model = YOLO("yolo_weights/yolo11n.pt")
latest_frames = {}
def process_stream(stream_id, url):
print(f"PROCESS STREAM {stream_id} from {url}")
line_orientation = "vertical"
counter = Counter(stream_id, line_orientation, metrics)
cap = cv2.VideoCapture(url)
frame_idx = 0
while cap.isOpened():
ret, frame = cap.read()
if not ret:
time.sleep(1)
continue
height, width, _ = frame.shape
frame_idx += 1
if frame_idx % FRAME_SKIP != 0:
continue
try:
results = model.track(
frame, persist=True, classes=0, tracker="bytetrack.yaml"
)
except Exception as e:
print(e)
continue
if line_orientation == "horizontal":
line_position = int(height / 2)
elif line_orientation == "vertical":
line_position = int(width / 2)
else:
raise NotImplementedError(
f"Line orientation {line_orientation} is invalid!"
)
boxes = results[0].boxes
if boxes.id is not None:
for box, track_id in zip(boxes.xywh.cpu(), boxes.id.cpu()):
x, y, w, h = map(int, box)
cv2.rectangle(
frame,
(x - w // 2, y - h // 2),
(x + w // 2, y + h // 2),
(0, 255, 0),
2,
)
cv2.putText(
frame,
str(track_id),
(x, y - 10),
cv2.FONT_HERSHEY_SIMPLEX,
0.6,
(0, 255, 0),
2,
)
tracks = [
type("Track", (), {"id": int(track_id), "bbox": box.numpy()})
for box, track_id in zip(boxes.xywh.cpu(), boxes.id.cpu())
]
counter.update(tracks, line_position)
if line_orientation == "horizontal":
draw_count_line_horizontal(frame, line_position)
elif line_orientation == "vertical":
draw_count_line_vertical(frame, line_position)
cv2.putText(
frame,
f"In: {counter.in_count}",
(10, 40),
cv2.FONT_HERSHEY_SIMPLEX,
1,
(0, 255, 255),
2,
)
cv2.putText(
frame,
f"Out: {counter.out_count}",
(10, 80),
cv2.FONT_HERSHEY_SIMPLEX,
1,
(0, 255, 255),
2,
)
_, jpeg = cv2.imencode(".jpg", frame)
latest_frames[stream_id] = (jpeg.tobytes(), time.time())
print(f"[{stream_id}] Frame captured and stored")
cap.release()
def start_all_streams():
for stream_id, url in STREAMS.items():
threading.Thread(
target=process_stream, args=(stream_id, url), daemon=True
).start()
def get_metrics():
return registry
def get_latest_frame(stream_id):
return latest_frames.get(stream_id, (None, None))

View file

@ -1,23 +1,25 @@
class Counter:
def __init__(self, line_orientation, metrics):
def __init__(self, stream_id, line_orientation, metrics):
self.in_count = 0
self.out_count = 0
self.track_memory = {}
self.line_orientation = line_orientation
self.metrics = metrics
self.stream_id = stream_id
def update(self, tracks, line_position):
for track in tracks:
track_id = track.id
x, y, w, h = track.bbox
if self.line_orientation == 'horizontal':
if self.line_orientation == "horizontal":
center = int(y + h / 2)
elif self.line_orientation == 'vertical':
elif self.line_orientation == "vertical":
center = int(x + w / 2)
else:
raise NotImplementedError(f'Line orientation {self.line_orientation} is invalid!')
raise NotImplementedError(
f"Line orientation {self.line_orientation} is invalid!"
)
if track_id not in self.track_memory:
self.track_memory[track_id] = center
@ -28,7 +30,7 @@ class Counter:
if prev < line_position <= center:
self.in_count += 1
self.metrics["people_in"].inc()
self.metrics["people_in"].labels(stream=self.stream_id).inc()
elif prev > line_position >= center:
self.out_count += 1
self.metrics["people_out"].inc()
self.metrics["people_out"].labels(stream=self.stream_id).inc()

View file

@ -1,10 +1,12 @@
import cv2
def draw_count_line_horizontal(frame, y):
color = (0, 255, 255)
thickness = 2
cv2.line(frame, (0, y), (frame.shape[1], y), color, thickness)
def draw_count_line_vertical(frame, x):
color = (0, 255, 255)
thickness = 2