summaryrefslogtreecommitdiffstats
path: root/testing/web-platform/tests/websockets/handlers/msg_channel_wsh.py
blob: 7a66646f2bcbc2860f6f3aff1bf74ba93cb1db9e (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
#!/usr/bin/python
import json
import logging
import threading
import traceback
import urllib
import urllib.parse
from queue import Empty

from mod_pywebsocket import msgutil, stream
from wptserve import stash as stashmod

# Root logger (getLogger() is called without a name); all channel logging
# below goes through log(), which writes to this logger.
logger = logging.getLogger()

# Shared stash holding one (queue, has_reader, writer_count) entry per
# channel uuid. The connection details come from the server environment;
# presumably this lets every handler invocation see the same channel state
# -- confirm against wptserve's stash documentation.
address, authkey = stashmod.load_env_config()
stash = stashmod.Stash("msg_channel", address=address, authkey=authkey)

# Backend for websocket based channels.
#
# Each socket connection has a uuid identifying the channel and a
# direction which is either "read" or "write".  There can be only 1
# "read" channel per uuid, but multiple "write" channels
# (i.e. multiple producer, single consumer).
#
# The websocket connection URL contains the uuid and the direction as
# named query parameters.
#
# Channels are backed by a queue which is stored in the stash (one
# queue per uuid).
#
# The representation of a queue in the stash is a tuple (queue,
# has_reader, writer_count).  The first field is the queue itself, the
# latter are effectively reference counts for reader channels (which
# is zero or one, represented by a bool) and writer channels.  Once
# both counts drop to zero and no entries remain on the queue, the
# queue can be deleted.
#
# Entries on the queue itself are formed of (command, data) pairs. The
# command can be either "close", which tells the reader loop to stop
# (after which the reader's reference count is decremented), or
# "message", whose data is sent to the reader.


def log(uuid, msg, level="debug"):
    """Write `msg` to the module logger, prefixed with the channel uuid.

    `level` names the logger method to call, e.g. "debug" or "warning".
    """
    line = f"{uuid}: {msg}"
    writer = getattr(logger, level)
    writer(line)


def web_socket_do_extra_handshake(request):
    """Intentionally a no-op: no extra validation of the opening handshake."""
    return None


def web_socket_transfer_data(request):
    """Handle opening a websocket connection.

    Registers the connection with the channel named by the `uuid` query
    parameter by updating the (queue, has_reader, writer_count) entry in
    the stash. A "read" connection sets has_reader; any other direction is
    counted as a writer. It then runs the read or write loop and, however
    that loop exits (normally or by raising), unregisters the connection
    again via close_channel.

    Raises:
        ValueError: if a second reader tries to attach to a channel that
            already has one. The existing stash entry is left intact.
    """
    uuid, direction = parse_request(request)
    log(uuid, f"Got web_socket_transfer_data {direction}")

    # Get or create the relevant queue from the stash and update the refcount.
    # stash.take removes the entry, so it must be put back before the lock
    # is released on every path, including the error path.
    with stash.lock:
        value = stash.take(uuid)
        if value is None:
            queue, has_reader, writer_count = stash.get_queue(), False, 0
        else:
            queue, has_reader, writer_count = value

        if direction == "read":
            if has_reader:
                # Only possible for an existing entry (value is not None).
                # Restore it so the connected reader and writers keep working.
                stash.put(uuid, value)
                raise ValueError("Tried to start multiple readers for the same queue")
            has_reader = True
        else:
            writer_count += 1

        stash.put(uuid, (queue, has_reader, writer_count))

    try:
        if direction == "read":
            run_read(request, uuid, queue)
        elif direction == "write":
            run_write(request, uuid, queue)
    finally:
        # Always drop this connection's reference, even if the loop raised,
        # so the channel's refcounts stay accurate.
        log(uuid, f"transfer_data loop exited {direction}")
        close_channel(uuid, direction)


def web_socket_passive_closing_handshake(request):
    """Handle a client initiated close.

    When the client closes a reader, put a ("close", None) command on the
    channel's queue so that the run_read loop stops. For a writer no
    special handling is needed here: run_write exits its loop once
    receive_message returns None.

    Returns:
        The (code, reason) pair taken from request.ws_close_code and
        request.ws_close_reason.
    """
    uuid, direction = parse_request(request)
    log(uuid, f"Got web_socket_passive_closing_handshake {direction}")

    if direction == "read":
        # Peek at the stash entry by taking it and putting it straight back.
        # The channel may already have been deleted (e.g. by a writer's
        # "delete" command); then there is nothing to restore, and putting
        # None back into the stash would be an error.
        with stash.lock:
            data = stash.take(uuid)
            if data is not None:
                stash.put(uuid, data)
        if data is not None:
            queue = data[0]
            queue.put(("close", None))

    return request.ws_close_code, request.ws_close_reason


def parse_request(request):
    """Extract the channel uuid and direction from the request URI.

    Both are carried as the `uuid` and `direction` query parameters of
    request.unparsed_uri.

    Returns:
        A (uuid, direction) tuple of strings.

    Raises:
        KeyError: if either parameter is missing (including when the URI
            has no query string at all).
    """
    # Keep everything after the first "?". split("?")[1] would raise
    # IndexError without a query string and would drop any text after a
    # second, unencoded "?".
    query = request.unparsed_uri.partition("?")[2]
    params = dict(urllib.parse.parse_qsl(query))
    return params["uuid"], params["direction"]


def wait_for_close(request, uuid, queue):
    """Listen for messages on the socket of a read connection to a channel.

    A reader socket is only expected to receive an explicit JSON
    ["close", data] message. Anything else, including payloads that are not
    valid [command, data] pairs, is logged as a warning and ignored.

    The loop ends when:
    - the client sends "close";
    - receive_message returns None;
    - receive_message raises. The exception is logged unless the
      connection was already terminated.

    Unless the server side has already terminated the connection, a
    ("close", None) command is then put on `queue` so that run_read stops
    too.
    """
    while True:
        try:
            msg = request.ws_stream.receive_message()
        except Exception:
            if not (request.server_terminated or request.client_terminated):
                log(uuid, f"Got exception in wait_for_close\n{traceback.format_exc()}")
            break
        if msg is None:
            break
        try:
            cmd, _ = json.loads(msg)
        except (ValueError, TypeError):
            # ValueError: invalid JSON or wrong number of items;
            # TypeError: a JSON scalar such as 5 or null cannot be unpacked.
            cmd = None
        if cmd == "close":
            log(uuid, "Got client initiated close")
            break
        log(uuid, f"Unexpected message on read socket {msg}", "warning")

    if not request.server_terminated:
        queue.put(("close", None))


def run_read(request, uuid, queue):
    """Main loop for a read-type connection.

    Waits on the channel queue for (command, data) entries and acts on them:
     message - send `data`, JSON-encoded, on the WebSocket
     close - stop reading and return
    Any other command is logged as a warning. Whenever the queue stays
    empty for one second, the loop checks whether either side has
    terminated the connection and returns if so.

    In parallel, a daemon thread running wait_for_close listens on the
    socket itself. The reader socket normally receives no messages, but
    the client may send an explicit "close" message asking for the socket
    to be disconnected.
    """
    listener = threading.Thread(
        target=wait_for_close,
        args=(request, uuid, queue),
        daemon=True,
    )
    listener.start()

    while True:
        try:
            cmd, body = queue.get(True, 1)
        except Empty:
            if request.server_terminated or request.client_terminated:
                break
            continue

        log(uuid, f"queue.get ({cmd}, {body})")
        if cmd == "close":
            break
        if cmd == "message":
            msgutil.send_message(request, json.dumps(body))
        else:
            log(uuid, f"Unknown queue command {cmd}", level="warning")


def run_write(request, uuid, queue):
    """Main loop for a write-type connection.

    Messages coming over the socket are JSON-encoded (command, data) pairs.
    The recognised commands are:
     message - Send the message `data` over the channel.
     disconnectReader - Close the reader connection for this channel.
     delete - Force-delete the entire channel and the underlying queue.

    Messages that are not valid [command, data] pairs, and unrecognised
    commands, are logged as warnings and otherwise ignored. The loop
    returns once receive_message returns None.
    """
    while True:
        msg = request.ws_stream.receive_message()
        if msg is None:
            break
        try:
            cmd, body = json.loads(msg)
        except (ValueError, TypeError):
            # Tolerate malformed input rather than tearing down the writer,
            # matching how wait_for_close treats unexpected messages.
            log(uuid, f"Malformed message on write socket {msg}", "warning")
            continue
        if cmd == "disconnectReader":
            queue.put(("close", None))
        elif cmd == "message":
            log(uuid, f"queue.put ({cmd}, {body})")
            queue.put((cmd, body))
        elif cmd == "delete":
            close_channel(uuid, None)
        else:
            log(uuid, f"Unknown write command {cmd}", "warning")


def close_channel(uuid, direction):
    """Drop one connection's reference to a channel in the stash.

    Params:
    uuid - the UUID of the channel whose connection closed.
    direction - "read" clears the reader flag, any other string decrements
                the writer count, and None removes the channel from the
                stash outright.

    The stash entry is put back while a reader is attached, any writers
    remain, or the queue still holds entries. Otherwise it is left out of
    the stash, which deletes the channel.
    """
    log(uuid, f"Got close_channel {direction}")
    with stash.lock:
        entry = stash.take(uuid)
        if entry is None:
            log(uuid, "Message queue already deleted")
            return
        if direction is None:
            # The entry has already been taken out of the stash; not
            # putting it back is what deletes the channel.
            log(uuid, "Force deleting message queue")
            return

        queue, has_reader, writer_count = entry
        if direction == "read":
            has_reader = False
        else:
            writer_count -= 1

        # NOTE(review): entries are also kept while the queue is non-empty,
        # so a leftover "close" command (e.g. from a writer's
        # disconnectReader with no reader attached) would be delivered to
        # the next reader -- confirm that is intended.
        still_needed = has_reader or writer_count > 0 or not queue.empty()
        if not still_needed:
            log(uuid, "Deleting message queue")
            return

        log(uuid, f"Updating refcount {has_reader}, {writer_count}")
        stash.put(uuid, (queue, has_reader, writer_count))