aboutsummaryrefslogtreecommitdiff
path: root/webserver.py
blob: 162badc149a34476c6160d5b79c74c25e78da5dd (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
#!/usr/bin/env python
"""
Web server that relays camera video frames from a ZeroMQ backend to browser websocket clients.
"""
__author__ = "Franoosh Corporation"

import os
from collections import defaultdict, deque
import json
import logging
import zmq.asyncio
import asyncio
from collections import defaultdict
from contextlib import asynccontextmanager
import uvicorn
import zmq
from fastapi import (
    FastAPI,
    Request,
    HTTPException,
    WebSocket,
    WebSocketDisconnect,
    templating,
)
from fastapi.responses import HTMLResponse
from fastapi.staticfiles import StaticFiles

from helpers import CustomLoggingFormatter


def _tcp_endpoint(host, port):
    """Return a ``tcp://<host>:<port>`` endpoint string (the form ZeroMQ connect/bind take)."""
    return f"tcp://{host}:{port}"


# Path of the clients JSON file, resolved against the current working directory.
CLIENTS_JSON_FILE = os.path.join(os.getcwd(), 'clients.json')
# Registry of known cameras: CLIENTS_DICT[client_id] -> [camera_id, ...].
CLIENTS_DICT = defaultdict(list)
LOGFILE = 'webserver.log'
LOGLEVEL = logging.DEBUG

# Backend endpoints, all on the local host.
HOST = "127.0.0.1"
ZMQ_PORT = "9979"
WEB_PORT = "8008"
CTRL_BACKEND_ADDR = _tcp_endpoint(HOST, ZMQ_PORT)
WEB_BACKEND_ADDR = _tcp_endpoint(HOST, WEB_PORT)

# Route all log records through the root logger to LOGFILE, formatted by the
# project's CustomLoggingFormatter.
log_formatter = CustomLoggingFormatter()
handler = logging.FileHandler(LOGFILE, encoding='utf-8', mode='a')
handler.setFormatter(log_formatter)
logging.root.addHandler(handler)
logging.root.setLevel(LOGLEVEL)
logger = logging.getLogger(__name__)
# NOTE: do not add logging.basicConfig() here. It does nothing once the root
# logger already has a handler (added just above), so options such as
# filename/datefmt/level passed to it would silently never take effect.
# Configure date formatting on the formatter instead.

# Per-connection websocket state, keyed by client_id and then camera_id.
# ws_connections[client_id][camera_id] = {'ws': websocket}  (set in camera_route)
ws_connections = defaultdict(dict)
# ws_queues[client_id][camera_id] = bounded asyncio.Queue of frame bytes that
# zmq_bridge fills and camera_route drains.
ws_queues = defaultdict(dict)
# Control commands from websocket clients, queued as
# (client_id, camera_id, command, args_list) tuples and sent on by zmq_bridge.
# NOTE(review): created at import time; on Python < 3.10 asyncio.Queue binds to
# the event loop current at construction, which may differ from uvicorn's loop.
# Confirm the target Python version.
ctrl_msg_que = asyncio.Queue()
# Create ZMQ context and a DEALER socket (connected at import time):
zmq_context = zmq.asyncio.Context()
zmq_socket = zmq_context.socket(zmq.DEALER)
# Connect to ZMQ backend:
zmq_socket.connect(WEB_BACKEND_ADDR)


async def zmq_bridge():
    """Bridge between the ZMQ backend and websocket clients.

    Runs two loops concurrently until cancelled:

    * messages received on ``zmq_socket`` (frames and camera-removal
      notifications) are applied by ``_dispatch_backend_message``;
    * control commands put on ``ctrl_msg_que`` by websocket handlers are sent
      to the backend by ``_forward_control_commands``.

    Because the control loop waits on its own queue, a command is forwarded as
    soon as it is queued rather than only after the next incoming frame.
    """
    await asyncio.gather(_receive_backend_messages(), _forward_control_commands())


async def _receive_backend_messages():
    """Receive multipart messages from the backend forever and dispatch each one."""
    while True:
        try:
            data = await zmq_socket.recv_multipart()
            _dispatch_backend_message(data)
        except Exception as e:
            logger.error("Error in ZMQ bridge: %r", e)


def _dispatch_backend_message(data):
    """Apply one multipart message received from the backend.

    ``data[0]`` is a topic of the form ``b"<client_id>:<camera_id>"``.

    * A 2-part ``[topic, frame]`` message with a non-empty frame registers the
      camera in CLIENTS_DICT (if new). It then queues the frame for the
      websocket watching that camera, if any.
    * A 3-part message, or a 2-part one with an empty frame, is a notification
      that the camera was removed.
    * Messages of any other length, or whose topic cannot be split into
      client and camera ids, are logged and ignored.
    """
    if len(data) == 2:
        topic, frame_data = data
    elif len(data) == 3:
        topic, frame_data = data[0], None
    else:
        logger.warning("Received invalid ZMQ message: %r", data)
        return

    client_id, sep, camera_id = topic.decode('utf-8').partition(':')
    if not sep:
        # Empty topic or no ':' separator: there is no camera to route to.
        logger.warning("Received invalid ZMQ message: %r", data)
        return

    if not frame_data:
        # No frame data means a notification to remove the camera. Use .get()
        # so an unknown client does not get an empty registry entry.
        cameras = CLIENTS_DICT.get(client_id)
        if cameras and camera_id in cameras:
            cameras.remove(camera_id)
        return

    # Add client and camera to CLIENTS_DICT if new:
    cameras = CLIENTS_DICT[client_id]
    if camera_id not in cameras:
        cameras.append(camera_id)

    queue = ws_queues.get(client_id, {}).get(camera_id)
    if queue is not None:
        if queue.full():
            queue.get_nowait()  # Discard the oldest frame so the stream stays live.
        # Cannot raise QueueFull: a slot is guaranteed free at this point.
        queue.put_nowait(frame_data)


async def _forward_control_commands():
    """Send control commands from ``ctrl_msg_que`` to the backend as they arrive."""
    while True:
        client_id, camera_id, command, args_list = await ctrl_msg_que.get()
        try:
            # Await the send so failures surface here instead of being lost.
            await zmq_socket.send_multipart([
                client_id.encode('utf-8'),
                camera_id.encode('utf-8'),
                command.encode('utf-8'),
            ] + args_list)
        except Exception as e:
            logger.error("Error in ZMQ bridge: %r", e)
            continue
        logger.info("Sent control command '%s' with args: %r for '/ws/%s/%s' to backend.", command, args_list, client_id, camera_id)

@asynccontextmanager
async def lifespan(app: FastAPI):
    """Run ``zmq_bridge`` as a background task for the lifetime of the app.

    The task is held in a local while the app runs, because the event loop
    keeps only weak references to tasks. On shutdown it is cancelled and
    awaited so it finishes cleanly.
    """
    bridge_task = asyncio.create_task(zmq_bridge())
    try:
        yield
    finally:
        bridge_task.cancel()
        try:
            await bridge_task
        except asyncio.CancelledError:
            pass

# HTML pages are rendered from ./templates; files under ./static are served
# at the /static URL prefix.
templates = templating.Jinja2Templates(directory='templates')
app = FastAPI(lifespan=lifespan)
app.mount("/static", StaticFiles(directory="static"), name="static")


@app.get("/")
async def main_route(request: Request):
    """Render ``main.html`` with the registry of known clients and their cameras."""
    logger.debug("Main route visited")
    context = {"request": request, "clients": CLIENTS_DICT}
    return templates.TemplateResponse("main.html", context)

@app.get("/clients/{client_id}", response_class=HTMLResponse)
async def client_route(request: Request, client_id: str):
    """Render ``client.html`` listing the cameras registered for one client.

    Raises:
        HTTPException: 404 when ``client_id`` is not present in CLIENTS_DICT.
    """
    logger.debug("Checking client_id: '%s' in clients_dict: %r.", client_id, CLIENTS_DICT)
    # Membership test on the defaultdict does not create an entry.
    if client_id not in CLIENTS_DICT:
        raise HTTPException(status_code=404, detail="No such client ID.")
    context = {
        "request": request,
        "client_id": client_id,
        "camera_ids": CLIENTS_DICT[client_id],
    }
    return templates.TemplateResponse("client.html", context)

@app.websocket("/ws/{client_id}/{camera_id}")
async def camera_route(websocket: WebSocket, client_id: str, camera_id: str):
    """Stream one camera's frames over a websocket and relay control commands.

    Frames for (client_id, camera_id) are put on a bounded per-connection
    queue by ``zmq_bridge``; this handler sends them to the browser as binary
    messages. Text messages from the browser are parsed as a JSON object
    mapping command names to argument sequences, and each command is queued on
    ``ctrl_msg_que`` for the backend. The handler returns as soon as either
    direction stops, then removes its own registry entries.
    """
    logger.info("Accepting websocket connection for '/ws/%s/%s'.", client_id, camera_id)
    await websocket.accept()
    connection = {'ws': websocket}
    ws_connections[client_id][camera_id] = connection
    queue = asyncio.Queue(maxsize=10)
    ws_queues[client_id][camera_id] = queue

    async def send_frames():
        while True:
            frame_data = await queue.get()
            await websocket.send_bytes(frame_data)

    async def receive_control():
        while True:
            try:
                data = await websocket.receive_text()
            except WebSocketDisconnect:
                logger.info("WebSocket disconnected for '/ws/%s/%s'.", client_id, camera_id)
                break
            except RuntimeError as exc:
                # Starlette raises RuntimeError once the socket can no longer be
                # read. It may do so without yielding to the event loop, so
                # retrying would spin forever; stop instead.
                logger.warning("Error receiving control message: %r", exc)
                break
            except Exception as exc:
                # Transient per-message failure (presumably e.g. a binary message
                # with no text payload); keep listening.
                logger.warning("Error receiving control message: %r", exc)
                continue

            logger.info("Received control message from '/ws/%s/%s': %s", client_id, camera_id, data)
            try:
                frontend_message = json.loads(data)
            except json.JSONDecodeError:
                logger.warning("Received invalid JSON from '/ws/%s/%s': %s", client_id, camera_id, data)
                continue
            try:
                # Handle control messages from the client: {command: [arg, ...], ...}
                for command, args in frontend_message.items():
                    args_list = [str(arg).encode('utf-8') for arg in args]
                    ctrl_msg_que.put_nowait((client_id, camera_id, command, args_list))
                    logger.info("Put control command '%s' with args: %r for '/ws/%s/%s' on queue to backend.", command, args_list, client_id, camera_id)
            except Exception as exc:
                # Payload is valid JSON but not the expected shape.
                logger.warning("Error receiving control message: %r", exc)

    send_task = asyncio.create_task(send_frames())
    receive_task = asyncio.create_task(receive_control())
    try:
        await asyncio.wait(
            [send_task, receive_task],
            return_when=asyncio.FIRST_COMPLETED,
        )
    finally:
        send_task.cancel()
        receive_task.cancel()
        # Remove only our own entries: a newer connection for the same camera
        # may have replaced them, and must keep receiving frames.
        if ws_connections[client_id].get(camera_id) is connection:
            del ws_connections[client_id][camera_id]
        if ws_queues[client_id].get(camera_id) is queue:
            del ws_queues[client_id][camera_id]
        # Collect task outcomes so errors/cancellations are not reported as
        # "never retrieved".
        await asyncio.gather(send_task, receive_task, return_exceptions=True)

if __name__ == "__main__":
    # Run the web app directly with uvicorn on localhost:8007.
    uvicorn.run(app, host='127.0.0.1', port=8007, log_level='info')