summaryrefslogtreecommitdiffstats
path: root/src/etc/test-float-parse/runtests.py
blob: cc5e31a051fc25e533178159fd614311259129fd (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
#!/usr/bin/env python3

"""
Testing dec2flt
===============
These are *really* extensive tests. Expect them to run for hours. Due to the
nature of the problem (the input is a string of arbitrary length), exhaustive
testing is not really possible. Instead, there are exhaustive tests for some
classes of inputs for which that is feasible and a bunch of deterministic and
random non-exhaustive tests for covering everything else.

The actual tests (generating decimal strings and feeding them to dec2flt) is
performed by a set of stand-along rust programs. This script compiles, runs,
and supervises them. The programs report the strings they generate and the
floating point numbers they converted those strings to, and this script
checks that the results are correct.

You can run specific tests rather than all of them by giving their names
(without .rs extension) as command line parameters.

Verification
------------
The tricky part is not generating those inputs but verifying the outputs.
Comparing with the result of Python's float() does not cut it because
(and this is apparently undocumented) although Python includes a version of
Martin Gay's code including the decimal-to-float part, it doesn't actually use
it for float() (only for round()) instead relying on the system scanf() which
is not necessarily completely accurate.

Instead, we take the input and compute the true value with bignum arithmetic
(as a fraction, using the ``fractions`` module).

Given an input string and the corresponding float computed via Rust, simply
decode the float into f * 2^k (for integers f, k) and the ULP.
We can now easily compute the error and check if it is within 0.5 ULP as it
should be. Zero and infinites are handled similarly:

- If the approximation is 0.0, the exact value should be *less or equal*
  half the smallest denormal float: the smallest denormal floating point
  number has an odd mantissa (00...001) and thus half of that is rounded
  to 00...00, i.e., zero.
- If the approximation is Inf, the exact value should be *greater or equal*
  to the largest finite float + 0.5 ULP: the largest finite float has an odd
  mantissa (11...11), so that plus half an ULP is rounded up to the nearest
  even number, which overflows.

Implementation details
----------------------
This directory contains a set of single-file Rust programs that perform
tests with a particular class of inputs. Each is compiled and run without
parameters, outputs (f64, f32, decimal) pairs to verify externally, and
in any case either exits gracefully or with a panic.

If a test binary writes *anything at all* to stderr or exits with an
exit code that's not 0, the test fails.
The output on stdout is treated as (f64, f32, decimal) record, encoded thusly:

- First, the bits of the f64 encoded as an ASCII hex string.
- Second, the bits of the f32 encoded as an ASCII hex string.
- Then the corresponding string input, in ASCII
- The record is terminated with a newline.

Incomplete records are an error. Not-a-Number bit patterns are invalid too.

The tests run serially but the validation for a single test is parallelized
with ``multiprocessing``. Each test is launched as a subprocess.
One thread supervises it: Accepts and enqueues records to validate, observe
stderr, and waits for the process to exit. A set of worker processes perform
the validation work for the outputs enqueued there. Another thread listens
for progress updates from the workers.

Known issues
------------
Some errors (e.g., NaN outputs) aren't handled very gracefully.
Also, if there is an exception or the process is interrupted (at least on
Windows) the worker processes are leaked and stick around forever.
They're only a few megabytes each, but still, this script should not be run
if you aren't prepared to manually kill a lot of orphaned processes.
"""
from __future__ import print_function
import sys
import os.path
import time
import struct
from fractions import Fraction
from collections import namedtuple
from subprocess import Popen, check_call, PIPE
from glob import glob
import multiprocessing
import threading
import ctypes
import binascii

try:  # Python 3
    import queue as Queue
except ImportError:  # Python 2
    import Queue

NUM_WORKERS = 2
UPDATE_EVERY_N = 50000
INF = namedtuple('INF', '')()
NEG_INF = namedtuple('NEG_INF', '')()
ZERO = namedtuple('ZERO', '')()
MAILBOX = None  # The queue for reporting errors to the main process.
STDOUT_LOCK = threading.Lock()
test_name = None
child_processes = []
exit_status = 0

def msg(*args):
    with STDOUT_LOCK:
        print("[" + test_name + "]", *args)
        sys.stdout.flush()


def write_errors():
    global exit_status
    f = open("errors.txt", 'w')
    have_seen_error = False
    while True:
        args = MAILBOX.get()
        if args is None:
            f.close()
            break
        print(*args, file=f)
        f.flush()
        if not have_seen_error:
            have_seen_error = True
            msg("Something is broken:", *args)
            msg("Future errors will be logged to errors.txt")
            exit_status = 101


def cargo():
    print("compiling tests")
    sys.stdout.flush()
    check_call(['cargo', 'build', '--release'])


def run(test):
    global test_name
    test_name = test

    t0 = time.perf_counter()
    msg("setting up supervisor")
    command = ['cargo', 'run', '--bin', test, '--release']
    proc = Popen(command, bufsize=1<<20 , stdin=PIPE, stdout=PIPE, stderr=PIPE)
    done = multiprocessing.Value(ctypes.c_bool)
    queue = multiprocessing.Queue(maxsize=5)#(maxsize=1024)
    workers = []
    for n in range(NUM_WORKERS):
        worker = multiprocessing.Process(name='Worker-' + str(n + 1),
                                         target=init_worker,
                                         args=[test, MAILBOX, queue, done])
        workers.append(worker)
        child_processes.append(worker)
    for worker in workers:
        worker.start()
    msg("running test")
    interact(proc, queue)
    with done.get_lock():
        done.value = True
    for worker in workers:
        worker.join()
    msg("python is done")
    assert queue.empty(), "did not validate everything"
    dt = time.perf_counter() - t0
    msg("took", round(dt, 3), "seconds")


def interact(proc, queue):
    n = 0
    while proc.poll() is None:
        line = proc.stdout.readline()
        if not line:
            continue
        assert line.endswith(b'\n'), "incomplete line: " + repr(line)
        queue.put(line)
        n += 1
        if n % UPDATE_EVERY_N == 0:
            msg("got", str(n // 1000) + "k", "records")
    msg("rust is done. exit code:", proc.returncode)
    rest, stderr = proc.communicate()
    if stderr:
        msg("rust stderr output:", stderr)
    for line in rest.split(b'\n'):
        if not line:
            continue
        queue.put(line)


def main():
    global MAILBOX
    files = glob('src/bin/*.rs')
    basenames = [os.path.basename(i) for i in files]
    all_tests = [os.path.splitext(f)[0] for f in basenames if not f.startswith('_')]
    args = sys.argv[1:]
    if args:
        tests = [test for test in all_tests if test in args]
    else:
        tests = all_tests
    if not tests:
        print("Error: No tests to run")
        sys.exit(1)
    # Compile first for quicker feedback
    cargo()
    # Set up mailbox once for all tests
    MAILBOX = multiprocessing.Queue()
    mailman = threading.Thread(target=write_errors)
    mailman.daemon = True
    mailman.start()
    for test in tests:
        run(test)
    MAILBOX.put(None)
    mailman.join()


# ---- Worker thread code ----


POW2 = { e: Fraction(2) ** e for e in range(-1100, 1100) }
HALF_ULP = { e: (Fraction(2) ** e)/2 for e in range(-1100, 1100) }
DONE_FLAG = None


def send_error_to_supervisor(*args):
    MAILBOX.put(args)


def init_worker(test, mailbox, queue, done):
    global test_name, MAILBOX, DONE_FLAG
    test_name = test
    MAILBOX = mailbox
    DONE_FLAG = done
    do_work(queue)


def is_done():
    with DONE_FLAG.get_lock():
        return DONE_FLAG.value


def do_work(queue):
    while True:
        try:
            line = queue.get(timeout=0.01)
        except Queue.Empty:
            if queue.empty() and is_done():
                return
            else:
                continue
        bin64, bin32, text = line.rstrip().split()
        validate(bin64, bin32, text.decode('utf-8'))


def decode_binary64(x):
    """
    Turn a IEEE 754 binary64 into (mantissa, exponent), except 0.0 and
    infinity (positive and negative), which return ZERO, INF, and NEG_INF
    respectively.
    """
    x = binascii.unhexlify(x)
    assert len(x) == 8, repr(x)
    [bits] = struct.unpack(b'>Q', x)
    if bits == 0:
        return ZERO
    exponent = (bits >> 52) & 0x7FF
    negative = bits >> 63
    low_bits = bits & 0xFFFFFFFFFFFFF
    if exponent == 0:
        mantissa = low_bits
        exponent += 1
        if mantissa == 0:
            return ZERO
    elif exponent == 0x7FF:
        assert low_bits == 0, "NaN"
        if negative:
            return NEG_INF
        else:
            return INF
    else:
        mantissa = low_bits | (1 << 52)
    exponent -= 1023 + 52
    if negative:
        mantissa = -mantissa
    return (mantissa, exponent)


def decode_binary32(x):
    """
    Turn a IEEE 754 binary32 into (mantissa, exponent), except 0.0 and
    infinity (positive and negative), which return ZERO, INF, and NEG_INF
    respectively.
    """
    x = binascii.unhexlify(x)
    assert len(x) == 4, repr(x)
    [bits] = struct.unpack(b'>I', x)
    if bits == 0:
        return ZERO
    exponent = (bits >> 23) & 0xFF
    negative = bits >> 31
    low_bits = bits & 0x7FFFFF
    if exponent == 0:
        mantissa = low_bits
        exponent += 1
        if mantissa == 0:
            return ZERO
    elif exponent == 0xFF:
        if negative:
            return NEG_INF
        else:
            return INF
    else:
        mantissa = low_bits | (1 << 23)
    exponent -= 127 + 23
    if negative:
        mantissa = -mantissa
    return (mantissa, exponent)


MIN_SUBNORMAL_DOUBLE = Fraction(2) ** -1074
MIN_SUBNORMAL_SINGLE = Fraction(2) ** -149  # XXX unsure
MAX_DOUBLE = (2 - Fraction(2) ** -52) * (2 ** 1023)
MAX_SINGLE = (2 - Fraction(2) ** -23) * (2 ** 127)
MAX_ULP_DOUBLE = 1023 - 52
MAX_ULP_SINGLE = 127 - 23
DOUBLE_ZERO_CUTOFF = MIN_SUBNORMAL_DOUBLE / 2
DOUBLE_INF_CUTOFF = MAX_DOUBLE + 2 ** (MAX_ULP_DOUBLE - 1)
SINGLE_ZERO_CUTOFF = MIN_SUBNORMAL_SINGLE / 2
SINGLE_INF_CUTOFF = MAX_SINGLE + 2 ** (MAX_ULP_SINGLE - 1)

def validate(bin64, bin32, text):
    try:
        double = decode_binary64(bin64)
    except AssertionError:
        print(bin64, bin32, text)
        raise
    single = decode_binary32(bin32)
    real = Fraction(text)

    if double is ZERO:
        if real > DOUBLE_ZERO_CUTOFF:
            record_special_error(text, "f64 zero")
    elif double is INF:
        if real < DOUBLE_INF_CUTOFF:
            record_special_error(text, "f64 inf")
    elif double is NEG_INF:
        if -real < DOUBLE_INF_CUTOFF:
            record_special_error(text, "f64 -inf")
    elif len(double) == 2:
        sig, k = double
        validate_normal(text, real, sig, k, "f64")
    else:
        assert 0, "didn't handle binary64"
    if single is ZERO:
        if real > SINGLE_ZERO_CUTOFF:
            record_special_error(text, "f32 zero")
    elif single is INF:
        if real < SINGLE_INF_CUTOFF:
            record_special_error(text, "f32 inf")
    elif single is NEG_INF:
        if -real < SINGLE_INF_CUTOFF:
            record_special_error(text, "f32 -inf")
    elif len(single) == 2:
        sig, k = single
        validate_normal(text, real, sig, k, "f32")
    else:
        assert 0, "didn't handle binary32"

def record_special_error(text, descr):
    send_error_to_supervisor(text.strip(), "wrongly rounded to", descr)


def validate_normal(text, real, sig, k, kind):
    approx = sig * POW2[k]
    error = abs(approx - real)
    if error > HALF_ULP[k]:
        record_normal_error(text, error, k, kind)


def record_normal_error(text, error, k, kind):
    one_ulp = HALF_ULP[k + 1]
    assert one_ulp == 2 * HALF_ULP[k]
    relative_error = error / one_ulp
    text = text.strip()
    try:
        err_repr = float(relative_error)
    except ValueError:
        err_repr = str(err_repr).replace('/', ' / ')
    send_error_to_supervisor(err_repr, "ULP error on", text, "(" + kind + ")")


if __name__ == '__main__':
    main()