summaryrefslogtreecommitdiffstats
path: root/testing/web-platform/tests/tools/third_party/websockets/src/websockets/sync/messages.py
blob: 67a22313ca163da162cedabc29bc082a396727f7 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
from __future__ import annotations

import codecs
import queue
import threading
from typing import Iterator, List, Optional, cast

from ..frames import Frame, Opcode
from ..typing import Data


# Public API of this module.
__all__ = ["Assembler"]

# Factory for incremental UTF-8 decoders. Assembler.put() instantiates a new
# decoder for each text message so that text can be decoded frame by frame.
UTF8Decoder = codecs.getincrementaldecoder("utf-8")


class Assembler:
    """
    Assemble messages from frames.

    Frames are added with :meth:`put`. Messages are read either in one piece
    with :meth:`get` or frame by frame with :meth:`get_iter`.

    :meth:`put` blocks after adding the final frame of a message until that
    message has been fetched. :meth:`close` ends the stream and unblocks any
    call that is currently waiting.

    """

    def __init__(self) -> None:
        # Serialize reads and writes -- except for reads via synchronization
        # primitives provided by the threading and queue modules.
        self.mutex = threading.Lock()

        # We create a latch with two events to ensure proper interleaving of
        # writing and reading messages.
        # put() sets this event to tell get() that a message can be fetched.
        self.message_complete = threading.Event()
        # get() sets this event to let put() know that the message was fetched.
        self.message_fetched = threading.Event()

        # This flag prevents concurrent calls to get() by user code.
        self.get_in_progress = False
        # This flag prevents concurrent calls to put() by library code.
        self.put_in_progress = False

        # Decoder for text frames, None for binary frames.
        self.decoder: Optional[codecs.IncrementalDecoder] = None

        # Buffer of frames belonging to the same message.
        self.chunks: List[Data] = []

        # When switching from "buffering" to "streaming", we use a thread-safe
        # queue for transferring frames from the writing thread (library code)
        # to the reading thread (user code). We're buffering when chunks_queue
        # is None and streaming when it's a SimpleQueue. None is a sentinel
        # value marking the end of the stream, superseding message_complete.

        # Stream data from frames belonging to the same message.
        # Remove quotes around type when dropping Python < 3.9.
        self.chunks_queue: Optional["queue.SimpleQueue[Optional[Data]]"] = None

        # This flag marks the end of the stream.
        self.closed = False

    def get(self, timeout: Optional[float] = None) -> Data:
        """
        Read the next message.

        :meth:`get` returns a single :class:`str` or :class:`bytes`.

        If the message is fragmented, :meth:`get` waits until the last frame is
        received, then it reassembles the message and returns it. To receive
        messages frame by frame, use :meth:`get_iter` instead.

        Args:
            timeout: If a timeout is provided and elapses before a complete
                message is received, :meth:`get` raises :exc:`TimeoutError`.

        Raises:
            EOFError: If the stream of frames has ended.
            RuntimeError: If two threads run :meth:`get` or :meth:`get_iter`
                concurrently.
            TimeoutError: If ``timeout`` elapses before a complete message is
                received.

        """
        with self.mutex:
            if self.closed:
                raise EOFError("stream of frames ended")

            if self.get_in_progress:
                raise RuntimeError("get or get_iter is already running")

            self.get_in_progress = True

        # If the message_complete event isn't set yet, release the lock to
        # allow put() to run and eventually set it.
        # Locking with get_in_progress ensures only one thread can get here.
        completed = self.message_complete.wait(timeout)

        with self.mutex:
            # Reset the flag first so that every exit path below, including
            # the exceptions, leaves the assembler available to readers.
            self.get_in_progress = False

            # Waiting for a complete message timed out.
            # Event.wait() only returns False when a timeout was given, so
            # timeout is a number when the message below is formatted.
            if not completed:
                raise TimeoutError(f"timed out in {timeout:.1f}s")

            # get() was unblocked by close() rather than put().
            if self.closed:
                raise EOFError("stream of frames ended")

            assert self.message_complete.is_set()
            self.message_complete.clear()

            # The decoder is set for text messages and None for binary ones,
            # which tells whether the chunks are str or bytes.
            joiner: Data = b"" if self.decoder is None else ""
            # mypy cannot figure out that chunks have the proper type.
            message: Data = joiner.join(self.chunks)  # type: ignore

            # Release put(), which is waiting for the message to be fetched.
            assert not self.message_fetched.is_set()
            self.message_fetched.set()

            self.chunks = []
            # get() only runs in "buffering" mode; get_iter() is the only
            # method that creates the queue and it resets it when it finishes.
            assert self.chunks_queue is None

            return message

    def get_iter(self) -> Iterator[Data]:
        """
        Stream the next message.

        Iterating the return value of :meth:`get_iter` yields a :class:`str` or
        :class:`bytes` for each frame in the message.

        The iterator must be fully consumed before calling :meth:`get_iter` or
        :meth:`get` again. Else, :exc:`RuntimeError` is raised.

        This method only makes sense for fragmented messages. If messages aren't
        fragmented, use :meth:`get` instead.

        Because this is a generator, the exceptions below are raised while
        iterating the return value rather than when :meth:`get_iter` is called.

        Raises:
            EOFError: If the stream of frames has ended.
            RuntimeError: If two threads run :meth:`get` or :meth:`get_iter`
                concurrently.

        """
        with self.mutex:
            if self.closed:
                raise EOFError("stream of frames ended")

            if self.get_in_progress:
                raise RuntimeError("get or get_iter is already running")

            # Switch from "buffering" to "streaming": take the frames buffered
            # so far and have put() send subsequent frames through the queue.
            chunks = self.chunks
            self.chunks = []
            self.chunks_queue = cast(
                # Remove quotes around type when dropping Python < 3.9.
                "queue.SimpleQueue[Optional[Data]]",
                queue.SimpleQueue(),
            )

            # Sending None in chunks_queue supersedes setting message_complete
            # when switching to "streaming". If message is already complete
            # when the switch happens, put() didn't send None, so we have to.
            if self.message_complete.is_set():
                self.chunks_queue.put(None)

            self.get_in_progress = True

        # Locking with get_in_progress ensures only one thread can get here.
        # Yield the frames that were buffered before the switch, then block on
        # the queue for the remaining ones until the None sentinel arrives.
        yield from chunks
        while True:
            chunk = self.chunks_queue.get()
            if chunk is None:
                break
            yield chunk

        with self.mutex:
            self.get_in_progress = False

            # Both put() and close() set message_complete before or when they
            # send the None sentinel that ended the loop above.
            assert self.message_complete.is_set()
            self.message_complete.clear()

            # get_iter() was unblocked by close() rather than put().
            if self.closed:
                raise EOFError("stream of frames ended")

            # Release put(), which is waiting for the message to be fetched.
            assert not self.message_fetched.is_set()
            self.message_fetched.set()

            # Switch back to "buffering" for the next message.
            assert self.chunks == []
            self.chunks_queue = None

    def put(self, frame: Frame) -> None:
        """
        Add ``frame`` to the next message.

        When ``frame`` is the final frame in a message, :meth:`put` waits until
        the message is fetched, either by calling :meth:`get` or by fully
        consuming the return value of :meth:`get_iter`.

        Frames whose opcode isn't text, binary, or continuation are ignored.

        :meth:`put` assumes that the stream of frames respects the protocol. If
        it doesn't, the behavior is undefined.

        Raises:
            EOFError: If the stream of frames has ended.
            RuntimeError: If two threads run :meth:`put` concurrently.

        """
        with self.mutex:
            if self.closed:
                raise EOFError("stream of frames ended")

            if self.put_in_progress:
                raise RuntimeError("put is already running")

            # The first frame of a message selects text or binary mode.
            # Continuation frames keep the mode of the current message.
            if frame.opcode is Opcode.TEXT:
                self.decoder = UTF8Decoder(errors="strict")
            elif frame.opcode is Opcode.BINARY:
                self.decoder = None
            elif frame.opcode is Opcode.CONT:
                pass
            else:
                # Ignore control frames.
                return

            data: Data
            if self.decoder is not None:
                # The second argument is the decoder's ``final`` flag; it is
                # true only for the last frame of the message.
                data = self.decoder.decode(frame.data, frame.fin)
            else:
                data = frame.data

            # Buffer the data, or hand it to the reader when get_iter() has
            # switched the assembler to "streaming".
            if self.chunks_queue is None:
                self.chunks.append(data)
            else:
                self.chunks_queue.put(data)

            if not frame.fin:
                return

            # Message is complete. Wait until it's fetched to return.

            assert not self.message_complete.is_set()
            self.message_complete.set()

            # In "streaming" mode, the None sentinel tells get_iter() that the
            # message is complete.
            if self.chunks_queue is not None:
                self.chunks_queue.put(None)

            assert not self.message_fetched.is_set()

            self.put_in_progress = True

        # Release the lock to allow get() to run and eventually set the event.
        self.message_fetched.wait()

        with self.mutex:
            self.put_in_progress = False

            assert self.message_fetched.is_set()
            self.message_fetched.clear()

            # put() was unblocked by close() rather than get() or get_iter().
            if self.closed:
                raise EOFError("stream of frames ended")

            self.decoder = None

    def close(self) -> None:
        """
        End the stream of frames.

        Calling :meth:`close` concurrently with :meth:`get`, :meth:`get_iter`,
        or :meth:`put` is safe. They will raise :exc:`EOFError`.

        Calling :meth:`close` again after the stream has ended does nothing.

        """
        with self.mutex:
            if self.closed:
                return

            self.closed = True

            # Unblock get or get_iter.
            # get() waits on message_complete, while get_iter() waits on the
            # queue and is released by the None sentinel.
            if self.get_in_progress:
                self.message_complete.set()
                if self.chunks_queue is not None:
                    self.chunks_queue.put(None)

            # Unblock put().
            if self.put_in_progress:
                self.message_fetched.set()