diff options
Diffstat (limited to 'webserver.py')
| -rw-r--r-- | webserver.py | 214 |
1 files changed, 151 insertions, 63 deletions
diff --git a/webserver.py b/webserver.py index d1c0a1e..162badc 100644 --- a/webserver.py +++ b/webserver.py @@ -1,42 +1,43 @@ #!/usr/bin/env python - """ Module serving video from zmq to a webserver. """ - __author__ = "Franoosh Corporation" - import os -from collections import defaultdict +from collections import defaultdict, deque import json import logging +import zmq.asyncio import asyncio -from threading import Thread +from collections import defaultdict +from contextlib import asynccontextmanager import uvicorn +import zmq from fastapi import ( FastAPI, Request, HTTPException, WebSocket, + WebSocketDisconnect, templating, ) from fastapi.responses import HTMLResponse from fastapi.staticfiles import StaticFiles from helpers import CustomLoggingFormatter -import zmq CLIENTS_JSON_FILE = os.path.join(os.getcwd(), 'clients.json') +CLIENTS_DICT = defaultdict(list) # CLIENTS_DICT[client_id] = [camera_id1, camera_id2, ...] LOGFILE = 'webserver.log' -LOGLEVEL = logging.INFO +LOGLEVEL = logging.DEBUG HOST = "127.0.0.1" -ZMQPORT = "9979" -WSPORT = "8008" -ZMQ_BACKEND_ADDR = f"tcp://{HOST}:{ZMQPORT}" -WS_BACKEND_ADDR = f"tcp://{HOST}:{WSPORT}" +ZMQ_PORT = "9979" +WEB_PORT = "8008" +CTRL_BACKEND_ADDR = f"tcp://{HOST}:{ZMQ_PORT}" +WEB_BACKEND_ADDR = f"tcp://{HOST}:{WEB_PORT}" log_formatter = CustomLoggingFormatter() handler = logging.FileHandler(LOGFILE, encoding='utf-8', mode='a') @@ -50,88 +51,175 @@ logging.basicConfig( level=LOGLEVEL, ) - - -app = FastAPI() +# Track websocket connections by (client_id, camera_id): +ws_connections = defaultdict(dict) # ws_connections[client_id][camera_id] = websocket +ws_queues = defaultdict(dict) +ctrl_msg_que = asyncio.Queue() +# Create ZMQ context and socket: +zmq_context = zmq.asyncio.Context() +zmq_socket = zmq_context.socket(zmq.DEALER) +# Connect to ZMQ backend: +zmq_socket.connect(WEB_BACKEND_ADDR) + + +async def zmq_bridge(): + """Bridge between ZMQ backend and websocket clients.""" + while True: + try: + data = await zmq_socket.recv_multipart() + topic, frame_data = None, None + if len(data) == 2: + topic, frame_data = data + elif len(data) == 3: + topic, _, _ = data + else: + logger.warning("Received invalid ZMQ message: %r", data) + continue + if topic: + client_id, camera_id = topic.decode('utf-8').split(':', 1) + if frame_data: + # Add client and camera to CLIENTS_DICT if new: + if not camera_id in CLIENTS_DICT[client_id]: + CLIENTS_DICT[client_id].append(camera_id) + else: + # No frame data means a notification to remove camera: + try: + CLIENTS_DICT[client_id].remove(camera_id) + except ValueError: + pass + + queue = ws_queues.get(client_id, {}).get(camera_id) + if queue and frame_data: + if queue.full(): + _ = queue.get_nowait() # Discard oldest frame + await queue.put(frame_data) + if not ctrl_msg_que.empty(): + client_id, camera_id, command, args_list = await ctrl_msg_que.get() + zmq_socket.send_multipart([ + client_id.encode('utf-8'), + camera_id.encode('utf-8'), + command.encode('utf-8'), + ] + args_list) + logger.info("Sent control command '%s' with args: %r for '/ws/%s/%s' to backend.", command, args_list, client_id, camera_id) + except Exception as e: + logger.error("Error in ZMQ bridge: %r", e) + # TODO: Check if this loop can be optimized to avoid busy waiting. + # Alternative implementation using zmq Poller: + # poll = zmq.asyncio.Poller() + # poll.register(zmq_socket, zmq.POLLIN) + # while True: + # try: + # sockets = dict(await poll.poll()) + # if zmq_socket in sockets: + # topic, frame_data = await zmq_socket.recv_multipart() + # client_id, camera_id = topic.decode('utf-8').split(':', 1) + # set_clients(client_id, camera_id) + # queue = ws_queues.get(client_id, {}).get(camera_id) + # if queue: + # if queue.full(): + # _ = queue.get_nowait() # Discard oldest frame + # await queue.put(frame_data) + # if not ctrl_msg_que.empty(): + # client_id, camera_id, command, args_list = await ctrl_msg_que.get() + # zmq_socket.send_multipart([ + # client_id.encode('utf-8'), + # camera_id.encode('utf-8'), + # command.encode('utf-8'), + # ] + args_list) + # logger.info("Sent control command '%s' with args: %r for '/ws/%s/%s' to backend.", command, args_list, client_id, camera_id) + # except Exception as e: + # logger.error("Error in ZMQ bridge: %r", e) + +@asynccontextmanager +async def lifespan(app: FastAPI): + """Create lifespan context for FastAPI app.""" + asyncio.create_task(zmq_bridge()) + yield + +app = FastAPI(lifespan=lifespan) app.mount("/static", StaticFiles(directory="static"), name="static") templates = templating.Jinja2Templates(directory='templates') -# Track websocket connections by (client_id, camera_id) -ws_connections = defaultdict(dict) # ws_connections[client_id][camera_id] = websocket - -# Set up a single ZMQ SUB socket for all websocket connections -zmq_context = zmq.Context() -zmq_socket = zmq_context.socket(zmq.SUB) -zmq_socket.bind(WS_BACKEND_ADDR) -zmq_socket.setsockopt(zmq.SUBSCRIBE, b"") # Subscribe to all topics -poller = zmq.Poller() -poller.register(zmq_socket, zmq.POLLIN) - -def load_clients(): - try: - with open(CLIENTS_JSON_FILE) as f: - clients_dict = json.load(f) - except FileNotFoundError: - clients_dict = {} - return clients_dict @app.get("/") async def main_route(request: Request): - logger.error("DEBUG: main route visited") - clients = load_clients() + """Serve main page.""" + logger.debug("Main route visited") return templates.TemplateResponse( "main.html", { "request": request, - "clients": clients, + "clients": CLIENTS_DICT, } ) @app.get("/clients/{client_id}", response_class=HTMLResponse) async def client_route(request: Request, client_id: str): """Serve client page.""" - clients_dict = load_clients() - logger.debug("Checking client_id: '%s' in clients_dict: %r.", client_id, clients_dict) - if not client_id in clients_dict: - return HTTPException(status_code=404, detail="No such client ID.") + logger.debug("Checking client_id: '%s' in clients_dict: %r.", client_id, CLIENTS_DICT) + if not client_id in CLIENTS_DICT: + raise HTTPException(status_code=404, detail="No such client ID.") return templates.TemplateResponse( "client.html", { "request": request, "client_id": client_id, - "camera_ids": clients_dict[client_id], + "camera_ids": CLIENTS_DICT[client_id], }, ) - @app.websocket("/ws/{client_id}/{camera_id}") async def camera_route(websocket: WebSocket, client_id: str, camera_id: str): """Serve a particular camera page.""" logger.info("Accepting websocket connection for '/ws/%s/%s'.", client_id, camera_id) await websocket.accept() - if client_id not in ws_connections: - ws_connections[client_id] = {} - ws_connections[client_id][camera_id] = websocket - try: + ws_connections[client_id][camera_id] = {'ws': websocket} + ws_queues[client_id][camera_id] = asyncio.Queue(maxsize=10) + queue = ws_queues[client_id][camera_id] + + async def send_frames(): + while True: + frame_data = await queue.get() + await websocket.send_bytes(frame_data) + + async def receive_control(): while True: - # Wait for a frame for this client/camera - sockets = dict(poller.poll(1000)) - if zmq_socket in sockets: - msg = zmq_socket.recv_multipart() - if len(msg) == 3: - recv_client_id, recv_camera_id, content = msg - recv_client_id = recv_client_id.decode("utf-8") - recv_camera_id = recv_camera_id.decode("utf-8") - # Only send to the websocket for this client/camera - if recv_client_id == client_id and recv_camera_id == camera_id: - await websocket.send_bytes(content) - except Exception as exc: - logger.warning("Connection closed: %r", exc) + try: + data = await websocket.receive_text() + logger.info("Received control message from '/ws/%s/%s': %s", client_id, camera_id, data) + # Handle control messages from the client: + frontend_message = json.loads(data) + for command, args in frontend_message.items(): + args_list = [str(arg).encode('utf-8') for arg in args] + ctrl_msg_que.put_nowait((client_id, camera_id, command, args_list)) + # zmq_control_socket.send_multipart([ + # client_id.encode('utf-8'), + # camera_id.encode('utf-8'), + # command.encode('utf-8'), + # b'', + # ] + args_list) + # logger.info("Sent control command '%s' with args: %r for '/ws/%s/%s' to backend.", command, args_list, client_id, camera_id) + logger.info("Put control command '%s' with args: %r for '/ws/%s/%s' on queue to backend.", command, args_list, client_id, camera_id) + except json.JSONDecodeError: + logger.warning("Received invalid JSON from '/ws/%s/%s': %s", client_id, camera_id, data) + except WebSocketDisconnect: + logger.info("WebSocket disconnected for '/ws/%s/%s'.", client_id, camera_id) + break + except Exception as exc: + logger.warning("Error receiving control message: %r", exc) + send_task = asyncio.create_task(send_frames()) + receive_task = asyncio.create_task(receive_control()) + try: + # await asyncio.gather(send_task, receive_task) + await asyncio.wait( + [send_task, receive_task], + return_when=asyncio.FIRST_COMPLETED, + ) finally: - if client_id in ws_connections and camera_id in ws_connections[client_id]: - del ws_connections[client_id][camera_id] - await websocket.close() - + send_task.cancel() + receive_task.cancel() + ws_connections[client_id].pop(camera_id, None) + ws_queues[client_id].pop(camera_id, None) if __name__ == "__main__": uvicorn.run( @@ -139,4 +227,4 @@ if __name__ == "__main__": port=8007, host='127.0.0.1', log_level='info', - )
\ No newline at end of file + ) |
