diff options
author | Daniel Baumann <daniel.baumann@progress-linux.org> | 2024-04-07 09:22:09 +0000 |
---|---|---|
committer | Daniel Baumann <daniel.baumann@progress-linux.org> | 2024-04-07 09:22:09 +0000 |
commit | 43a97878ce14b72f0981164f87f2e35e14151312 (patch) | |
tree | 620249daf56c0258faa40cbdcf9cfba06de2a846 /third_party/prio | |
parent | Initial commit. (diff) | |
download | firefox-upstream.tar.xz firefox-upstream.zip |
Adding upstream version 110.0.1.upstream/110.0.1upstream
Signed-off-by: Daniel Baumann <daniel.baumann@progress-linux.org>
Diffstat (limited to 'third_party/prio')
28 files changed, 3989 insertions, 0 deletions
diff --git a/third_party/prio/README-mozilla b/third_party/prio/README-mozilla new file mode 100644 index 0000000000..f0090e3319 --- /dev/null +++ b/third_party/prio/README-mozilla @@ -0,0 +1,17 @@ +This directory contains the Prio source from the upstream repo: +https://github.com/mozilla/libprio + +Current version: 1.6 [commit 52643cefe6662b4099e16a40a057cb60651ab001] + +UPDATING: + +Our in-tree copy of Prio does not depend on any generated files from the +upstream build system. Therefore, it should be sufficient to simply overwrite +the in-tree files one the updated ones from upstream to perform updates. + +To simplify this, the in-tree copy can be updated by running + sh update.sh +from within the third_party/libprio directory. + +If the collection of source files changes, manual updates to moz.build may be +needed as we don't use the upstream makefiles. diff --git a/third_party/prio/include/mprio.h b/third_party/prio/include/mprio.h new file mode 100644 index 0000000000..7a53703e1d --- /dev/null +++ b/third_party/prio/include/mprio.h @@ -0,0 +1,304 @@ +/* + * Copyright (c) 2018, Henry Corrigan-Gibbs + * + * + * This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this + * file, You can obtain one at http://mozilla.org/MPL/2.0/. */ + +#ifndef __PRIO_H__ +#define __PRIO_H__ + +#ifdef __cplusplus +extern "C" { +#endif + +#include <blapit.h> +#include <msgpack.h> +#include <pk11pub.h> +#include <seccomon.h> +#include <stdbool.h> +#include <stddef.h> + +/* Seed for a pseudo-random generator (PRG). */ +#define PRG_SEED_LENGTH AES_128_KEY_LENGTH +typedef unsigned char PrioPRGSeed[PRG_SEED_LENGTH]; + +/* Length of a raw curve25519 public key, in bytes. */ +#define CURVE25519_KEY_LEN 32 + +/* Length of a hex-encoded curve25519 public key, in bytes. */ +#define CURVE25519_KEY_LEN_HEX 64 + +/* + * Type for each of the two servers. + */ +typedef enum { PRIO_SERVER_A, PRIO_SERVER_B } PrioServerId; + +/* + * Opaque types + */ +typedef struct prio_config* PrioConfig; +typedef const struct prio_config* const_PrioConfig; + +typedef struct prio_server* PrioServer; +typedef const struct prio_server* const_PrioServer; + +typedef struct prio_verifier* PrioVerifier; +typedef const struct prio_verifier* const_PrioVerifier; + +typedef struct prio_packet_verify1* PrioPacketVerify1; +typedef const struct prio_packet_verify1* const_PrioPacketVerify1; + +typedef struct prio_packet_verify2* PrioPacketVerify2; +typedef const struct prio_packet_verify2* const_PrioPacketVerify2; + +typedef struct prio_total_share* PrioTotalShare; +typedef const struct prio_total_share* const_PrioTotalShare; + +typedef SECKEYPublicKey* PublicKey; +typedef const SECKEYPublicKey* const_PublicKey; + +typedef SECKEYPrivateKey* PrivateKey; +typedef const SECKEYPrivateKey* const_PrivateKey; + +/* + * Initialize and clear random number generator state. + * You must call Prio_init() before using the library. + * To avoid memory leaks, call Prio_clear() afterwards. + */ +SECStatus Prio_init(); +void Prio_clear(); + +/* + * PrioConfig holds the system parameters. The two relevant things determined + * by the config object are: + * (1) the number of data fields we are collecting, and + * (2) the modulus we use for modular arithmetic. + * The default configuration uses an 87-bit modulus. + * + * The value `nFields` must be in the range `0 < nFields <= max`, where `max` + * is the value returned by the function `PrioConfig_maxDataFields()` below. + * + * The `batch_id` field specifies which "batch" of aggregate statistics we are + * computing. For example, if the aggregate statistics are computed every 24 + * hours, the `batch_id` might be set to an encoding of the date. The clients + * and servers must all use the same `batch_id` for each run of the protocol. + * Each set of aggregate statistics should use a different `batch_id`. + * + * `PrioConfig_new` does not keep a pointer to the `batch_id` string that the + * caller passes in, so you may free the `batch_id` string as soon as + * `PrioConfig_new` returns. + */ +PrioConfig PrioConfig_new(int nFields, PublicKey serverA, PublicKey serverB, + const unsigned char* batchId, + unsigned int batchIdLen); +void PrioConfig_clear(PrioConfig cfg); +int PrioConfig_numDataFields(const_PrioConfig cfg); + +/* + * Return the maximum number of data fields that the implementation supports. + */ +int PrioConfig_maxDataFields(void); + +/* + * Create a PrioConfig object with no encryption keys. This routine is + * useful for testing, but PrioClient_encode() will always fail when used with + * this config. + */ +PrioConfig PrioConfig_newTest(int nFields); + +/* + * We use the PublicKey and PrivateKey objects for public-key encryption. Each + * Prio server has an associated public key, and the clients use these keys to + * encrypt messages to the servers. + */ +SECStatus Keypair_new(PrivateKey* pvtkey, PublicKey* pubkey); + +/* + * Import a new curve25519 public/private key from the raw bytes given. When + * importing a private key, you must pass in the corresponding public key as + * well. The byte arrays given as input should be of length + * `CURVE25519_KEY_LEN`. + * + * These functions will allocate a new `PublicKey`/`PrivateKey` object, which + * the caller must free using `PublicKey_clear`/`PrivateKey_clear`. + */ +SECStatus PublicKey_import(PublicKey* pk, const unsigned char* data, + unsigned int dataLen); +SECStatus PrivateKey_import(PrivateKey* sk, const unsigned char* privData, + unsigned int privDataLen, + const unsigned char* pubData, + unsigned int pubDataLen); + +/* + * Import a new curve25519 public/private key from a hex string that contains + * only the characters 0-9a-fA-F. + * + * The hex strings passed in must each be of length `CURVE25519_KEY_LEN_HEX`. + * These functions will allocate a new `PublicKey`/`PrivateKey` object, which + * the caller must free using `PublicKey_clear`/`PrivateKey_clear`. + */ +SECStatus PublicKey_import_hex(PublicKey* pk, const unsigned char* hexData, + unsigned int dataLen); +SECStatus PrivateKey_import_hex(PrivateKey* sk, + const unsigned char* privHexData, + unsigned int privDataLen, + const unsigned char* pubHexData, + unsigned int pubDataLen); + +/* + * Export a curve25519 key as a raw byte-array. + * + * The output buffer `data` must have length exactly `CURVE25519_KEY_LEN`. + */ +SECStatus PublicKey_export(const_PublicKey pk, unsigned char* data, + unsigned int dataLen); +SECStatus PrivateKey_export(PrivateKey sk, unsigned char* data, + unsigned int dataLen); + +/* + * Export a curve25519 key as a NULL-terminated hex string. + * + * The output buffer `data` must have length exactly `CURVE25519_KEY_LEN_HEX + + * 1`. + */ +SECStatus PublicKey_export_hex(const_PublicKey pk, unsigned char* data, + unsigned int dataLen); +SECStatus PrivateKey_export_hex(PrivateKey sk, unsigned char* data, + unsigned int dataLen); + +void PublicKey_clear(PublicKey pubkey); +void PrivateKey_clear(PrivateKey pvtkey); + +/* + * PrioPacketClient_encode + * + * Takes as input a pointer to an array (`data_in`) of boolean values + * whose length is equal to the number of data fields specified in + * the config. It then encodes the data for servers A and B into a + * string. + * + * NOTE: The caller must free() the strings `for_server_a` and + * `for_server_b` to avoid memory leaks. + */ +SECStatus PrioClient_encode(const_PrioConfig cfg, const bool* data_in, + unsigned char** forServerA, unsigned int* aLen, + unsigned char** forServerB, unsigned int* bLen); + +/* + * Generate a new PRG seed using the NSS global randomness source. + * Use this routine to initialize the secret that the two Prio servers + * share. + */ +SECStatus PrioPRGSeed_randomize(PrioPRGSeed* seed); + +/* + * The PrioServer object holds the state of the Prio servers. + * Pass in the _same_ secret PRGSeed when initializing the two servers. + * The PRGSeed must remain secret to the two servers. + */ +PrioServer PrioServer_new(const_PrioConfig cfg, PrioServerId serverIdx, + PrivateKey serverPriv, + const PrioPRGSeed serverSharedSecret); +void PrioServer_clear(PrioServer s); + +/* + * After receiving a client packet, each of the servers generate + * a PrioVerifier object that they use to check whether the client's + * encoded packet is well formed. + */ +PrioVerifier PrioVerifier_new(PrioServer s); +void PrioVerifier_clear(PrioVerifier v); + +/* + * Read in encrypted data from the client, decrypt it, and prepare to check the + * request for validity. + */ +SECStatus PrioVerifier_set_data(PrioVerifier v, unsigned char* data, + unsigned int dataLen); + +/* + * Generate the first packet that servers need to exchange to verify the + * client's submission. This should be sent over a TLS connection between the + * servers. + */ +PrioPacketVerify1 PrioPacketVerify1_new(void); +void PrioPacketVerify1_clear(PrioPacketVerify1 p1); + +SECStatus PrioPacketVerify1_set_data(PrioPacketVerify1 p1, + const_PrioVerifier v); + +SECStatus PrioPacketVerify1_write(const_PrioPacketVerify1 p, + msgpack_packer* pk); +SECStatus PrioPacketVerify1_read(PrioPacketVerify1 p, msgpack_unpacker* upk, + const_PrioConfig cfg); + +/* + * Generate the second packet that the servers need to exchange to verify the + * client's submission. The routine takes as input the PrioPacketVerify1 + * packets from both server A and server B. + * + * This should be sent over a TLS connection between the servers. + */ +PrioPacketVerify2 PrioPacketVerify2_new(void); +void PrioPacketVerify2_clear(PrioPacketVerify2 p); + +SECStatus PrioPacketVerify2_set_data(PrioPacketVerify2 p2, const_PrioVerifier v, + const_PrioPacketVerify1 p1A, + const_PrioPacketVerify1 p1B); + +SECStatus PrioPacketVerify2_write(const_PrioPacketVerify2 p, + msgpack_packer* pk); +SECStatus PrioPacketVerify2_read(PrioPacketVerify2 p, msgpack_unpacker* upk, + const_PrioConfig cfg); + +/* + * Use the PrioPacketVerify2s from both servers to check whether + * the client's submission is well formed. + */ +SECStatus PrioVerifier_isValid(const_PrioVerifier v, const_PrioPacketVerify2 pA, + const_PrioPacketVerify2 pB); + +/* + * Each of the two servers calls this routine to aggregate the data + * submission from a client that is included in the PrioVerifier object. + * + * IMPORTANT: This routine does *not* check the validity of the client's + * data packet. The servers must execute the verification checks + * above before aggregating any client data. + */ +SECStatus PrioServer_aggregate(PrioServer s, PrioVerifier v); + +/* + * After the servers have aggregated data packets from "enough" clients + * (this determines the anonymity set size), each server runs this routine + * to get a share of the aggregate statistics. + */ +PrioTotalShare PrioTotalShare_new(void); +void PrioTotalShare_clear(PrioTotalShare t); + +SECStatus PrioTotalShare_set_data(PrioTotalShare t, const_PrioServer s); + +SECStatus PrioTotalShare_write(const_PrioTotalShare t, msgpack_packer* pk); +SECStatus PrioTotalShare_read(PrioTotalShare t, msgpack_unpacker* upk, + const_PrioConfig cfg); + +/* + * Read the output data into an array of unsigned longs. You should + * be sure that each data value can fit into a single `unsigned long` + * and that the pointer `output` points to a buffer large enough to + * store one long per data field. + * + * This function returns failure if some final data value is too + * long to fit in an `unsigned long`. + */ +SECStatus PrioTotalShare_final(const_PrioConfig cfg, unsigned long long* output, + const_PrioTotalShare tA, + const_PrioTotalShare tB); + +#endif /* __PRIO_H__ */ + +#ifdef __cplusplus +} +#endif diff --git a/third_party/prio/moz.build b/third_party/prio/moz.build new file mode 100644 index 0000000000..0a6e3c74a2 --- /dev/null +++ b/third_party/prio/moz.build @@ -0,0 +1,49 @@ +# vim: set filetype=python: +# This Source Code Form is subject to the terms of the Mozilla Public +# License, v. 2.0. If a copy of the MPL was not distributed with this +# file, You can obtain one at http://mozilla.org/MPL/2.0/. + +LOCAL_INCLUDES += [ + '/security/nss/lib/freebl/mpi', + '/third_party/msgpack/include', + 'include', +] + +EXPORTS += [ + 'include/mprio.h', +] + +# We allow warnings for third-party code that can be updated from upstream. +AllowCompilerWarnings() + +NoVisibilityFlags() + +SOURCES += [ + '/security/nss/lib/freebl/mpi/montmulf.c', + '/security/nss/lib/freebl/mpi/mp_gf2m.c', + '/security/nss/lib/freebl/mpi/mpcpucache.c', + '/security/nss/lib/freebl/mpi/mpi.c', + '/security/nss/lib/freebl/mpi/mplogic.c', + '/security/nss/lib/freebl/mpi/mpmontg.c', + '/security/nss/lib/freebl/mpi/mpprime.c', +] + +SOURCES += [ + 'prio/client.c', + 'prio/config.c', + 'prio/encrypt.c', + 'prio/mparray.c', + 'prio/poly.c', + 'prio/prg.c', + 'prio/rand.c', + 'prio/serial.c', + 'prio/server.c', + 'prio/share.c', +] + +FINAL_LIBRARY = 'xul' + +# Use PKCS11 v2 struct definitions for now, otherwise NSS requires +# CK_GCM_PARAMS.ulIvBits to be set. This workaround is only required +# until NSS 3.52 RTM and upstream correctly initializes the field. +DEFINES['NSS_PKCS11_2_0_COMPAT'] = True diff --git a/third_party/prio/prio/SConscript b/third_party/prio/prio/SConscript new file mode 100644 index 0000000000..36e32e56e2 --- /dev/null +++ b/third_party/prio/prio/SConscript @@ -0,0 +1,26 @@ +import os + +Import('env') + +penv = env.Clone() + +src = [ + "client.c", + "config.c", + "encrypt.c", + "mparray.c", + "poly.c", + "rand.c", + "prg.c", + "server.c", + "serial.c", + "share.c", +] + +libs = [ + "msgpackc" +] + +penv.Append(LIBS = libs) +penv.StaticLibrary("mprio", src) + diff --git a/third_party/prio/prio/client.c b/third_party/prio/prio/client.c new file mode 100644 index 0000000000..2e83515a72 --- /dev/null +++ b/third_party/prio/prio/client.c @@ -0,0 +1,362 @@ +/* + * Copyright (c) 2018, Henry Corrigan-Gibbs + * + * This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this + * file, You can obtain one at http://mozilla.org/MPL/2.0/. */ + +#include <mpi.h> +#include <mprio.h> +#include <stdio.h> +#include <stdlib.h> +#include <string.h> + +#include "client.h" +#include "config.h" +#include "encrypt.h" +#include "poly.h" +#include "rand.h" +#include "serial.h" +#include "share.h" +#include "util.h" + +// Let the points of data_in be [x1, x2, x3, ... ]. +// We construct the polynomial f such that +// (a) f(0) = random, +// (b) f(i) = x_i for all i >= 1, +// (c) degree(f)+1 is a power of two. +// We then evaluate f at the 2N-th roots of unity +// and we return these evaluations as `evals_out` +// and we return f(0) as `const_term`. +static SECStatus +data_polynomial_evals(const_PrioConfig cfg, const_MPArray data_in, + MPArray evals_out, mp_int* const_term) +{ + SECStatus rv = SECSuccess; + const mp_int* mod = &cfg->modulus; + MPArray points_f = NULL; + MPArray poly_f = NULL; + + // Number of multiplication gates in the Valid() circuit. + const int mul_gates = cfg->num_data_fields; + + // Little n is the number of points on the polynomials. + // The constant term is randomized, so it's (mul_gates + 1). + const int n = mul_gates + 1; + + // Big N is n rounded up to a power of two. + const int N = next_power_of_two(n); + + P_CHECKA(points_f = MPArray_new(N)); + P_CHECKA(poly_f = MPArray_new(N)); + + // Set constant term f(0) to random + P_CHECKC(rand_int(&points_f->data[0], mod)); + MP_CHECKC(mp_copy(&points_f->data[0], const_term)); + + // Set other values of f(x) + for (int i = 1; i < n; i++) { + MP_CHECKC(mp_copy(&data_in->data[i - 1], &points_f->data[i])); + } + + // Interpolate through the Nth roots of unity + P_CHECKC(poly_fft(poly_f, points_f, cfg, true)); + + // Evaluate at all 2N-th roots of unity. + // To do so, first resize the eval arrays and fill upper + // values with zeros. + P_CHECKC(MPArray_resize(poly_f, 2 * N)); + P_CHECKC(MPArray_resize(evals_out, 2 * N)); + + // Evaluate at the 2N-th roots of unity + P_CHECKC(poly_fft(evals_out, poly_f, cfg, false)); + +cleanup: + MPArray_clear(points_f); + MPArray_clear(poly_f); + + return rv; +} + +static SECStatus +share_polynomials(const_PrioConfig cfg, const_MPArray data_in, + PrioPacketClient pA, PrioPacketClient pB, PRG prgB) +{ + SECStatus rv = SECSuccess; + const mp_int* mod = &cfg->modulus; + const_MPArray points_f = data_in; + + mp_int f0, g0; + MP_DIGITS(&f0) = NULL; + MP_DIGITS(&g0) = NULL; + + MPArray points_g = NULL; + MPArray evals_f_2N = NULL; + MPArray evals_g_2N = NULL; + + P_CHECKA(points_g = MPArray_dup(points_f)); + P_CHECKA(evals_f_2N = MPArray_new(0)); + P_CHECKA(evals_g_2N = MPArray_new(0)); + MP_CHECKC(mp_init(&f0)); + MP_CHECKC(mp_init(&g0)); + + for (int i = 0; i < points_f->len; i++) { + // For each input value x_i, we compute x_i * (x_i-1). + // f(i) = x_i + // g(i) = x_i - 1 + MP_CHECKC(mp_sub_d(&points_g->data[i], 1, &points_g->data[i])); + MP_CHECKC(mp_mod(&points_g->data[i], mod, &points_g->data[i])); + } + + P_CHECKC(data_polynomial_evals(cfg, points_f, evals_f_2N, &f0)); + P_CHECKC(data_polynomial_evals(cfg, points_g, evals_g_2N, &g0)); + + // The values f(0) and g(0) are set to random values. + // We must send to each server a share of the points + // f(0), g(0), and h(0) = f(0)*g(0) + P_CHECKC(share_int(cfg, &f0, &pA->f0_share, &pB->f0_share)); + P_CHECKC(share_int(cfg, &g0, &pA->g0_share, &pB->g0_share)); + + // Compute h(0) = f(0)*g(0). + MP_CHECKC(mp_mulmod(&f0, &g0, mod, &f0)); + // Give one share of h(0) to each server. + P_CHECKC(share_int(cfg, &f0, &pA->h0_share, &pB->h0_share)); + + // const int lenN = (evals_f_2N->len/2); + // P_CHECKC (MPArray_resize (pA->shares.A.h_points, lenN)); + + // We need to send to the servers the evaluations of + // f(r) * g(r) + // for all 2N-th roots of unity r that are not also + // N-th roots of unity. + // + // For each such root r, compute h(r) = f(r)*g(r) and + // send a share of this value to each server. + int j = 0; + for (int i = 1; i < evals_f_2N->len; i += 2) { + MP_CHECKC(mp_mulmod(&evals_f_2N->data[i], &evals_g_2N->data[i], mod, &f0)); + P_CHECKC(PRG_share_int(prgB, &pA->shares.A.h_points->data[j], &f0, cfg)); + j++; + } + +cleanup: + MPArray_clear(evals_f_2N); + MPArray_clear(evals_g_2N); + MPArray_clear(points_g); + mp_clear(&f0); + mp_clear(&g0); + return rv; +} + +PrioPacketClient +PrioPacketClient_new(const_PrioConfig cfg, PrioServerId for_server) +{ + SECStatus rv = SECSuccess; + const int data_len = cfg->num_data_fields; + PrioPacketClient p = NULL; + p = malloc(sizeof(*p)); + if (!p) + return NULL; + + p->for_server = for_server; + p->triple = NULL; + MP_DIGITS(&p->f0_share) = NULL; + MP_DIGITS(&p->g0_share) = NULL; + MP_DIGITS(&p->h0_share) = NULL; + + switch (p->for_server) { + case PRIO_SERVER_A: + p->shares.A.data_shares = NULL; + p->shares.A.h_points = NULL; + break; + case PRIO_SERVER_B: + memset(p->shares.B.seed, 0, PRG_SEED_LENGTH); + break; + default: + // Should never get here + rv = SECFailure; + goto cleanup; + } + + MP_CHECKC(mp_init(&p->f0_share)); + MP_CHECKC(mp_init(&p->g0_share)); + MP_CHECKC(mp_init(&p->h0_share)); + P_CHECKA(p->triple = BeaverTriple_new()); + + if (p->for_server == PRIO_SERVER_A) { + const int num_h_points = PrioConfig_hPoints(cfg); + P_CHECKA(p->shares.A.data_shares = MPArray_new(data_len)); + P_CHECKA(p->shares.A.h_points = MPArray_new(num_h_points)); + } + +cleanup: + if (rv != SECSuccess) { + PrioPacketClient_clear(p); + return NULL; + } + + return p; +} + +SECStatus +PrioPacketClient_set_data(const_PrioConfig cfg, const bool* data_in, + PrioPacketClient pA, PrioPacketClient pB) +{ + MPArray client_data = NULL; + PRG prgB = NULL; + SECStatus rv = SECSuccess; + const int data_len = cfg->num_data_fields; + + if (!data_in) + return SECFailure; + + P_CHECKC(PrioPRGSeed_randomize(&pB->shares.B.seed)); + P_CHECKA(prgB = PRG_new(pB->shares.B.seed)); + + P_CHECKC(BeaverTriple_set_rand(cfg, pA->triple, pB->triple)); + P_CHECKA(client_data = MPArray_new_bool(data_len, data_in)); + P_CHECKC(PRG_share_array(prgB, pA->shares.A.data_shares, client_data, cfg)); + P_CHECKC(share_polynomials(cfg, client_data, pA, pB, prgB)); + +cleanup: + MPArray_clear(client_data); + PRG_clear(prgB); + + return rv; +} + +void +PrioPacketClient_clear(PrioPacketClient p) +{ + if (p == NULL) + return; + + if (p->for_server == PRIO_SERVER_A) { + MPArray_clear(p->shares.A.h_points); + MPArray_clear(p->shares.A.data_shares); + } + + BeaverTriple_clear(p->triple); + mp_clear(&p->f0_share); + mp_clear(&p->g0_share); + mp_clear(&p->h0_share); + free(p); +} + +bool +PrioPacketClient_areEqual(const_PrioPacketClient p1, const_PrioPacketClient p2) +{ + if (!BeaverTriple_areEqual(p1->triple, p2->triple)) + return false; + if (mp_cmp(&p1->f0_share, &p2->f0_share)) + return false; + if (mp_cmp(&p1->g0_share, &p2->g0_share)) + return false; + if (mp_cmp(&p1->h0_share, &p2->h0_share)) + return false; + if (p1->for_server != p2->for_server) + return false; + + switch (p1->for_server) { + case PRIO_SERVER_A: + if (!MPArray_areEqual(p1->shares.A.data_shares, p2->shares.A.data_shares)) + return false; + if (!MPArray_areEqual(p1->shares.A.h_points, p2->shares.A.h_points)) + return false; + break; + case PRIO_SERVER_B: + if (memcmp(p1->shares.B.seed, p2->shares.B.seed, PRG_SEED_LENGTH)) + return false; + break; + default: + // Should never get here. + return false; + } + + return true; +} + +SECStatus +PrioClient_encode(const_PrioConfig cfg, const bool* data_in, + unsigned char** for_server_a, unsigned int* aLen, + unsigned char** for_server_b, unsigned int* bLen) +{ + SECStatus rv = SECSuccess; + PrioPacketClient pA = NULL; + PrioPacketClient pB = NULL; + *for_server_a = NULL; + *for_server_b = NULL; + + msgpack_sbuffer sbufA, sbufB; + msgpack_packer packerA, packerB; + + msgpack_sbuffer_init(&sbufA); + msgpack_sbuffer_init(&sbufB); + msgpack_packer_init(&packerA, &sbufA, msgpack_sbuffer_write); + msgpack_packer_init(&packerB, &sbufB, msgpack_sbuffer_write); + + P_CHECKA(pA = PrioPacketClient_new(cfg, PRIO_SERVER_A)); + P_CHECKA(pB = PrioPacketClient_new(cfg, PRIO_SERVER_B)); + + P_CHECKC(PrioPacketClient_set_data(cfg, data_in, pA, pB)); + P_CHECKC(serial_write_packet_client(&packerA, pA, cfg)); + P_CHECKC(serial_write_packet_client(&packerB, pB, cfg)); + + P_CHECKC(PublicKey_encryptSize(sbufA.size, aLen)); + P_CHECKC(PublicKey_encryptSize(sbufB.size, bLen)); + + P_CHECKA(*for_server_a = malloc(*aLen)); + P_CHECKA(*for_server_b = malloc(*bLen)); + + unsigned int writtenA; + unsigned int writtenB; + P_CHECKC(PublicKey_encrypt(cfg->server_a_pub, *for_server_a, &writtenA, *aLen, + (unsigned char*)sbufA.data, sbufA.size)); + P_CHECKC(PublicKey_encrypt(cfg->server_b_pub, *for_server_b, &writtenB, *bLen, + (unsigned char*)sbufB.data, sbufB.size)); + + P_CHECKCB(writtenA == *aLen); + P_CHECKCB(writtenB == *bLen); + +cleanup: + if (rv != SECSuccess) { + if (*for_server_a) + free(*for_server_a); + if (*for_server_b) + free(*for_server_b); + *for_server_a = NULL; + *for_server_b = NULL; + } + + PrioPacketClient_clear(pA); + PrioPacketClient_clear(pB); + msgpack_sbuffer_destroy(&sbufA); + msgpack_sbuffer_destroy(&sbufB); + + return rv; +} + +SECStatus +PrioPacketClient_decrypt(PrioPacketClient p, const_PrioConfig cfg, + PrivateKey server_priv, const unsigned char* data_in, + unsigned int data_len) +{ + SECStatus rv = SECSuccess; + msgpack_unpacker upk; + if (!msgpack_unpacker_init(&upk, data_len)) { + return SECFailure; + } + + // Decrypt the ciphertext into dec_buf + unsigned int bytes_decrypted; + P_CHECKC(PrivateKey_decrypt(server_priv, + (unsigned char*)msgpack_unpacker_buffer(&upk), + &bytes_decrypted, data_len, data_in, data_len)); + msgpack_unpacker_buffer_consumed(&upk, bytes_decrypted); + + P_CHECKC(serial_read_packet_client(&upk, p, cfg)); + +cleanup: + msgpack_unpacker_destroy(&upk); + return rv; +} diff --git a/third_party/prio/prio/client.h b/third_party/prio/prio/client.h new file mode 100644 index 0000000000..cedd1048d7 --- /dev/null +++ b/third_party/prio/prio/client.h @@ -0,0 +1,77 @@ +/* + * Copyright (c) 2018, Henry Corrigan-Gibbs + * + * This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this + * file, You can obtain one at http://mozilla.org/MPL/2.0/. + */ + +#ifndef __CLIENT_H__ +#define __CLIENT_H__ + +#include "mparray.h" +#include "prg.h" +#include "share.h" + +/* + * The PrioPacketClient object holds the encoded client data. + * The client sends one packet to server A and one packet to + * server B. The `for_server` parameter determines which server + * the packet is for. + */ +typedef struct prio_packet_client* PrioPacketClient; +typedef const struct prio_packet_client* const_PrioPacketClient; + +struct server_a_data +{ + // These values are only set for server A. + MPArray data_shares; + MPArray h_points; +}; + +struct server_b_data +{ + // This value is only used for server B. + // + // We use a pseudo-random generator to compress the secret-shared data + // values. See Appendix I of the Prio paper (the paragraph starting + // "Optimization: PRG secret sharing.") for details on this. + PrioPRGSeed seed; +}; + +/* + * The data that a Prio client sends to each server. + */ +struct prio_packet_client +{ + // TODO: Can also use a PRG to avoid need for sending Beaver triple shares. + // Since this optimization only saves ~30 bytes of communication, we haven't + // bothered implementing it yet. + BeaverTriple triple; + + mp_int f0_share, g0_share, h0_share; + PrioServerId for_server; + + union + { + struct server_a_data A; + struct server_b_data B; + } shares; +}; + +PrioPacketClient PrioPacketClient_new(const_PrioConfig cfg, + PrioServerId for_server); +void PrioPacketClient_clear(PrioPacketClient p); +SECStatus PrioPacketClient_set_data(const_PrioConfig cfg, const bool* data_in, + PrioPacketClient for_server_a, + PrioPacketClient for_server_b); + +SECStatus PrioPacketClient_decrypt(PrioPacketClient p, const_PrioConfig cfg, + PrivateKey server_priv, + const unsigned char* data_in, + unsigned int data_len); + +bool PrioPacketClient_areEqual(const_PrioPacketClient p1, + const_PrioPacketClient p2); + +#endif /* __CLIENT_H__ */ diff --git a/third_party/prio/prio/config.c b/third_party/prio/prio/config.c new file mode 100644 index 0000000000..46d67554d2 --- /dev/null +++ b/third_party/prio/prio/config.c @@ -0,0 +1,113 @@ +/* + * Copyright (c) 2018, Henry Corrigan-Gibbs + * + * This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this + * file, You can obtain one at http://mozilla.org/MPL/2.0/. + */ + +#include <mprio.h> +#include <stdlib.h> + +#include "config.h" +#include "mparray.h" +#include "params.h" +#include "rand.h" +#include "util.h" + +int +PrioConfig_maxDataFields(void) +{ + const int n_roots = 1 << Generator2Order; + return (n_roots >> 1) - 1; +} + +PrioConfig +PrioConfig_new(int n_fields, PublicKey server_a, PublicKey server_b, + const unsigned char* batch_id, unsigned int batch_id_len) +{ + SECStatus rv = SECSuccess; + PrioConfig cfg = malloc(sizeof(*cfg)); + if (!cfg) + return NULL; + + cfg->batch_id = NULL; + cfg->batch_id_len = batch_id_len; + cfg->server_a_pub = server_a; + cfg->server_b_pub = server_b; + cfg->num_data_fields = n_fields; + cfg->n_roots = 1 << Generator2Order; + MP_DIGITS(&cfg->modulus) = NULL; + MP_DIGITS(&cfg->inv2) = NULL; + MP_DIGITS(&cfg->generator) = NULL; + + P_CHECKCB(cfg->n_roots > 1); + P_CHECKCB(cfg->num_data_fields <= PrioConfig_maxDataFields()); + + P_CHECKA(cfg->batch_id = malloc(batch_id_len)); + strncpy((char*)cfg->batch_id, (char*)batch_id, batch_id_len); + + MP_CHECKC(mp_init(&cfg->modulus)); + MP_CHECKC(mp_read_radix(&cfg->modulus, Modulus, 16)); + + MP_CHECKC(mp_init(&cfg->generator)); + MP_CHECKC(mp_read_radix(&cfg->generator, Generator, 16)); + + // Compute 2^{-1} modulo M + MP_CHECKC(mp_init(&cfg->inv2)); + mp_set(&cfg->inv2, 2); + MP_CHECKC(mp_invmod(&cfg->inv2, &cfg->modulus, &cfg->inv2)); + +cleanup: + if (rv != SECSuccess) { + PrioConfig_clear(cfg); + return NULL; + } + + return cfg; +} + +PrioConfig +PrioConfig_newTest(int nFields) +{ + return PrioConfig_new(nFields, NULL, NULL, (unsigned char*)"testBatch", 9); +} + +void +PrioConfig_clear(PrioConfig cfg) +{ + if (!cfg) + return; + if (cfg->batch_id) + free(cfg->batch_id); + mp_clear(&cfg->modulus); + mp_clear(&cfg->inv2); + mp_clear(&cfg->generator); + free(cfg); +} + +int +PrioConfig_numDataFields(const_PrioConfig cfg) +{ + return cfg->num_data_fields; +} + +SECStatus +Prio_init(void) +{ + return rand_init(); +} + +void +Prio_clear(void) +{ + rand_clear(); +} + +int +PrioConfig_hPoints(const_PrioConfig cfg) +{ + const int mul_gates = cfg->num_data_fields + 1; + const int N = next_power_of_two(mul_gates); + return N; +} diff --git a/third_party/prio/prio/config.h b/third_party/prio/prio/config.h new file mode 100644 index 0000000000..66ed95bddd --- /dev/null +++ b/third_party/prio/prio/config.h @@ -0,0 +1,34 @@ +/* + * Copyright (c) 2018, Henry Corrigan-Gibbs + * + * This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this + * file, You can obtain one at http://mozilla.org/MPL/2.0/. + */ + +#ifndef __CONFIG_H__ +#define __CONFIG_H__ + +#include <mpi.h> + +#include "mparray.h" + +struct prio_config +{ + int num_data_fields; + unsigned char* batch_id; + unsigned int batch_id_len; + + PublicKey server_a_pub; + PublicKey server_b_pub; + + mp_int modulus; + mp_int inv2; + + int n_roots; + mp_int generator; +}; + +int PrioConfig_hPoints(const_PrioConfig cfg); + +#endif /* __CONFIG_H__ */ diff --git a/third_party/prio/prio/debug.h b/third_party/prio/prio/debug.h new file mode 100644 index 0000000000..15f25c805c --- /dev/null +++ b/third_party/prio/prio/debug.h @@ -0,0 +1,23 @@ +/* + * Copyright (c) 2018, Henry Corrigan-Gibbs + * + * This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this + * file, You can obtain one at http://mozilla.org/MPL/2.0/. + */ + +#ifndef __DEBUG_H__ +#define __DEBUG_H__ + +#include <stdio.h> + +#ifdef DEBUG +#define PRIO_DEBUG(msg) \ + do { \ + fprintf(stderr, "Error: %s\n", msg); \ + } while (false); +#else +#define PRIO_DEBUG(msg) ; +#endif + +#endif /* __DEBUG_H__ */ diff --git a/third_party/prio/prio/encrypt.c b/third_party/prio/prio/encrypt.c new file mode 100644 index 0000000000..f4874c421b --- /dev/null +++ b/third_party/prio/prio/encrypt.c @@ -0,0 +1,566 @@ +/* + * Copyright (c) 2018, Henry Corrigan-Gibbs + * + * This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this + * file, You can obtain one at http://mozilla.org/MPL/2.0/. + */ + +#include <keyhi.h> +#include <keythi.h> +#include <pk11pub.h> +#include <prerror.h> + +#include "encrypt.h" +#include "prio/rand.h" +#include "prio/util.h" + +// Use curve25519 +#define CURVE_OID_TAG SEC_OID_CURVE25519 + +// Use 96-bit IV +#define GCM_IV_LEN_BYTES 12 +// Use 128-bit auth tag +#define GCM_TAG_LEN_BYTES 16 + +#define PRIO_TAG "PrioPacket" +#define AAD_LEN (sizeof(PRIO_TAG) - 1 + CURVE25519_KEY_LEN + GCM_IV_LEN_BYTES) + +// For an example of NSS curve25519 import/export code, see: +// https://searchfox.org/nss/rev/cfd5fcba7efbfe116e2c08848075240ec3a92718/gtests/pk11_gtest/pk11_curve25519_unittest.cc#66 + +// The all-zeros curve25519 public key, as DER-encoded SPKI blob. +static const uint8_t curve25519_spki_zeros[] = { + 0x30, 0x39, 0x30, 0x14, 0x06, 0x07, 0x2a, 0x86, 0x48, 0xce, 0x3d, 0x02, + 0x01, 0x06, 0x09, 0x2b, 0x06, 0x01, 0x04, 0x01, 0xda, 0x47, 0x0f, 0x01, + 0x03, 0x21, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, +}; + +// The all-zeros curve25519 private key, as a PKCS#8 blob. +static const uint8_t curve25519_priv_zeros[] = { + 0x30, 0x67, 0x02, 0x01, 0x00, 0x30, 0x14, 0x06, 0x07, 0x2a, 0x86, 0x48, 0xce, + 0x3d, 0x02, 0x01, 0x06, 0x09, 0x2b, 0x06, 0x01, 0x04, 0x01, 0xda, 0x47, 0x0f, + 0x01, 0x04, 0x4c, 0x30, 0x4a, 0x02, 0x01, 0x01, 0x04, 0x20, + + /* Byte index 36: 32 bytes of curve25519 private key. */ + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + + /* misc type fields */ + 0xa1, 0x23, 0x03, 0x21, + + /* Byte index 73: 32 bytes of curve25519 public key. */ + 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00 +}; + +// Index into `curve25519_priv_zeros` at which the private key begins. +static const size_t curve25519_priv_sk_offset = 36; +// Index into `curve25519_priv_zeros` at which the public key begins. +static const size_t curve25519_priv_pk_offset = 73; + +static SECStatus key_from_hex( + unsigned char key_out[CURVE25519_KEY_LEN], + const unsigned char hex_in[CURVE25519_KEY_LEN_HEX]); + +// Note that we do not use isxdigit because it is locale-dependent +// See: https://github.com/mozilla/libprio/issues/20 +static inline char +is_hex_digit(char c) +{ + return ('0' <= c && c <= '9') || ('a' <= c && c <= 'f') || + ('A' <= c && c <= 'F'); +} + +// Note that we do not use toupper because it is locale-dependent +// See: https://github.com/mozilla/libprio/issues/20 +static inline char +to_upper(char c) +{ + if (c >= 'a' && c <= 'z') { + return c - 0x20; + } else { + return c; + } +} + +static inline uint8_t +hex_to_int(char h) +{ + return (h > '9') ? to_upper(h) - 'A' + 10 : (h - '0'); +} + +static inline unsigned char +int_to_hex(uint8_t i) +{ + return (i > 0x09) ? ((i - 10) + 'A') : i + '0'; +} + +static SECStatus +derive_dh_secret(PK11SymKey** shared_secret, PrivateKey priv, PublicKey pub) +{ + if (priv == NULL) + return SECFailure; + if (pub == NULL) + return SECFailure; + if (shared_secret == NULL) + return SECFailure; + + SECStatus rv = SECSuccess; + *shared_secret = NULL; + + P_CHECKA(*shared_secret = PK11_PubDeriveWithKDF( + priv, pub, PR_FALSE, NULL, NULL, CKM_ECDH1_DERIVE, CKM_AES_GCM, + CKA_ENCRYPT | CKA_DECRYPT, 16, CKD_SHA256_KDF, NULL, NULL)); + +cleanup: + return rv; +} + +SECStatus +PublicKey_import(PublicKey* pk, const unsigned char* data, unsigned int dataLen) +{ + SECStatus rv = SECSuccess; + CERTSubjectPublicKeyInfo* pkinfo = NULL; + *pk = NULL; + unsigned char* key_bytes = NULL; + uint8_t* spki_data = NULL; + + if (dataLen != CURVE25519_KEY_LEN) + return SECFailure; + + P_CHECKA(key_bytes = calloc(dataLen, sizeof(unsigned char))); + memcpy(key_bytes, data, dataLen); + + const int spki_len = sizeof(curve25519_spki_zeros); + P_CHECKA(spki_data = calloc(spki_len, sizeof(uint8_t))); + + memcpy(spki_data, curve25519_spki_zeros, spki_len); + SECItem spki_item = { siBuffer, spki_data, spki_len }; + + // Import the all-zeros curve25519 public key. + P_CHECKA(pkinfo = SECKEY_DecodeDERSubjectPublicKeyInfo(&spki_item)); + P_CHECKA(*pk = SECKEY_ExtractPublicKey(pkinfo)); + + // Overwrite the all-zeros public key with the 32-byte curve25519 public key + // given as input. + memcpy((*pk)->u.ec.publicValue.data, data, CURVE25519_KEY_LEN); + +cleanup: + if (key_bytes) + free(key_bytes); + if (spki_data) + free(spki_data); + if (pkinfo) + SECKEY_DestroySubjectPublicKeyInfo(pkinfo); + + if (rv != SECSuccess) + PublicKey_clear(*pk); + return rv; +} + +SECStatus +PrivateKey_import(PrivateKey* sk, const unsigned char* sk_data, + unsigned int sk_data_len, const unsigned char* pk_data, + unsigned int pk_data_len) +{ + if (sk_data_len != CURVE25519_KEY_LEN || !sk_data) { + return SECFailure; + } + + if (pk_data_len != CURVE25519_KEY_LEN || !pk_data) { + return SECFailure; + } + + SECStatus rv = SECSuccess; + PK11SlotInfo* slot = NULL; + uint8_t* zero_priv_data = NULL; + *sk = NULL; + const int zero_priv_len = sizeof(curve25519_priv_zeros); + + P_CHECKA(slot = PK11_GetInternalSlot()); + + P_CHECKA(zero_priv_data = calloc(zero_priv_len, sizeof(uint8_t))); + SECItem zero_priv_item = { siBuffer, zero_priv_data, zero_priv_len }; + + // Copy the PKCS#8-encoded keypair into writable buffer. + memcpy(zero_priv_data, curve25519_priv_zeros, zero_priv_len); + // Copy private key into bytes beginning at index `curve25519_priv_sk_offset`. + memcpy(zero_priv_data + curve25519_priv_sk_offset, sk_data, sk_data_len); + // Copy private key into bytes beginning at index `curve25519_priv_pk_offset`. + memcpy(zero_priv_data + curve25519_priv_pk_offset, pk_data, pk_data_len); + + P_CHECKC(PK11_ImportDERPrivateKeyInfoAndReturnKey( + slot, &zero_priv_item, NULL, NULL, PR_FALSE, PR_FALSE, KU_ALL, sk, NULL)); + +cleanup: + if (slot) { + PK11_FreeSlot(slot); + } + if (zero_priv_data) { + free(zero_priv_data); + } + if (rv != SECSuccess) { + PrivateKey_clear(*sk); + } + return rv; +} + +SECStatus +PublicKey_import_hex(PublicKey* pk, const unsigned char* hexData, + unsigned int dataLen) +{ + unsigned char raw_bytes[CURVE25519_KEY_LEN]; + + if (dataLen != CURVE25519_KEY_LEN_HEX || !hexData) { + return SECFailure; + } + + if (key_from_hex(raw_bytes, hexData) != SECSuccess) { + return SECFailure; + } + + return PublicKey_import(pk, raw_bytes, CURVE25519_KEY_LEN); +} + +SECStatus +PrivateKey_import_hex(PrivateKey* sk, const unsigned char* privHexData, + unsigned int privDataLen, const unsigned char* pubHexData, + unsigned int pubDataLen) +{ + SECStatus rv = SECSuccess; + unsigned char raw_priv[CURVE25519_KEY_LEN]; + unsigned char raw_pub[CURVE25519_KEY_LEN]; + + if (privDataLen != CURVE25519_KEY_LEN_HEX || + pubDataLen != CURVE25519_KEY_LEN_HEX) { + return SECFailure; + } + + if (!privHexData || !pubHexData) { + return SECFailure; + } + + P_CHECK(key_from_hex(raw_priv, privHexData)); + P_CHECK(key_from_hex(raw_pub, pubHexData)); + + return PrivateKey_import(sk, raw_priv, CURVE25519_KEY_LEN, raw_pub, + CURVE25519_KEY_LEN); +} + +SECStatus +PublicKey_export(const_PublicKey pk, unsigned char* data, unsigned int dataLen) +{ + if (pk == NULL || dataLen != CURVE25519_KEY_LEN) { + return SECFailure; + } + + const SECItem* key = &pk->u.ec.publicValue; + if (key->len != CURVE25519_KEY_LEN) { + return SECFailure; + } + + memcpy(data, key->data, key->len); + return SECSuccess; +} + +SECStatus +PrivateKey_export(PrivateKey sk, unsigned char* data, unsigned int dataLen) +{ + if (sk == NULL || dataLen != CURVE25519_KEY_LEN) { + return SECFailure; + } + + SECStatus rv = SECSuccess; + SECItem item = { siBuffer, NULL, 0 }; + + P_CHECKC(PK11_ReadRawAttribute(PK11_TypePrivKey, sk, CKA_VALUE, &item)); + + // If the leading bytes of the key are '\0', then this string can be + // shorter than `CURVE25519_KEY_LEN` bytes. + memset(data, 0, CURVE25519_KEY_LEN); + P_CHECKCB(item.len <= CURVE25519_KEY_LEN); + + // Copy into the low-order bytes of the output. + const size_t leading_zeros = CURVE25519_KEY_LEN - item.len; + memcpy(data + leading_zeros, item.data, item.len); + +cleanup: + if (item.data != NULL) { + SECITEM_ZfreeItem(&item, PR_FALSE); + } + + return rv; +} + +static void +key_to_hex(const unsigned char key_in[CURVE25519_KEY_LEN], + unsigned char hex_out[(2 * CURVE25519_KEY_LEN) + 1]) +{ + const unsigned char* p = key_in; + for (unsigned int i = 0; i < CURVE25519_KEY_LEN; i++) { + unsigned char bytel = p[0] & 0x0f; + unsigned char byteu = (p[0] & 0xf0) >> 4; + hex_out[2 * i] = int_to_hex(byteu); + hex_out[2 * i + 1] = int_to_hex(bytel); + p++; + } + + hex_out[2 * CURVE25519_KEY_LEN] = '\0'; +} + +static SECStatus +key_from_hex(unsigned char key_out[CURVE25519_KEY_LEN], + const unsigned char hex_in[CURVE25519_KEY_LEN_HEX]) +{ + for (unsigned int i = 0; i < CURVE25519_KEY_LEN_HEX; i++) { + if (!is_hex_digit(hex_in[i])) + return SECFailure; + } + + const unsigned char* p = hex_in; + for (unsigned int i = 0; i < CURVE25519_KEY_LEN; i++) { + uint8_t d0 = hex_to_int(p[0]); + uint8_t d1 = hex_to_int(p[1]); + key_out[i] = (d0 << 4) | d1; + p += 2; + } + + return SECSuccess; +} + +SECStatus +PublicKey_export_hex(const_PublicKey pk, unsigned char* data, + unsigned int dataLen) +{ + if (dataLen != CURVE25519_KEY_LEN_HEX + 1) { + return SECFailure; + } + + unsigned char raw_data[CURVE25519_KEY_LEN]; + if (PublicKey_export(pk, raw_data, sizeof(raw_data)) != SECSuccess) { + return SECFailure; + } + + key_to_hex(raw_data, data); + return SECSuccess; +} + +SECStatus +PrivateKey_export_hex(PrivateKey sk, unsigned char* data, unsigned int dataLen) +{ + if (dataLen != CURVE25519_KEY_LEN_HEX + 1) { + return SECFailure; + } + + unsigned char raw_data[CURVE25519_KEY_LEN]; + if (PrivateKey_export(sk, raw_data, sizeof(raw_data)) != SECSuccess) { + return SECFailure; + } + + key_to_hex(raw_data, data); + return SECSuccess; +} + +SECStatus +Keypair_new(PrivateKey* pvtkey, PublicKey* pubkey) +{ + if (pvtkey == NULL) + return SECFailure; + if (pubkey == NULL) + return SECFailure; + + SECStatus rv = SECSuccess; + SECOidData* oid_data = NULL; + *pubkey = NULL; + *pvtkey = NULL; + + SECKEYECParams ecp; + ecp.data = NULL; + PK11SlotInfo* slot = NULL; + + P_CHECKA(oid_data = SECOID_FindOIDByTag(CURVE_OID_TAG)); + const int oid_struct_len = 2 + oid_data->oid.len; + + P_CHECKA(ecp.data = malloc(oid_struct_len)); + ecp.len = oid_struct_len; + + ecp.type = siDEROID; + + ecp.data[0] = SEC_ASN1_OBJECT_ID; + ecp.data[1] = oid_data->oid.len; + memcpy(&ecp.data[2], oid_data->oid.data, oid_data->oid.len); + + P_CHECKA(slot = PK11_GetInternalSlot()); + P_CHECKA(*pvtkey = PK11_GenerateKeyPair(slot, CKM_EC_KEY_PAIR_GEN, &ecp, + (SECKEYPublicKey**)pubkey, PR_FALSE, + PR_FALSE, NULL)); +cleanup: + if (slot) { + PK11_FreeSlot(slot); + } + if (ecp.data) { + free(ecp.data); + } + if (rv != SECSuccess) { + PublicKey_clear(*pubkey); + PrivateKey_clear(*pvtkey); + } + return rv; +} + +void +PublicKey_clear(PublicKey pubkey) +{ + if (pubkey) + SECKEY_DestroyPublicKey(pubkey); +} + +void +PrivateKey_clear(PrivateKey pvtkey) +{ + if (pvtkey) + SECKEY_DestroyPrivateKey(pvtkey); +} + +const SECItem* +PublicKey_toBytes(const_PublicKey pubkey) +{ + return &pubkey->u.ec.publicValue; +} + +SECStatus +PublicKey_encryptSize(unsigned int inputLen, unsigned int* outputLen) +{ + if (outputLen == NULL || inputLen >= MAX_ENCRYPT_LEN) + return SECFailure; + + // public key, IV, tag, and input + *outputLen = + CURVE25519_KEY_LEN + GCM_IV_LEN_BYTES + GCM_TAG_LEN_BYTES + inputLen; + return SECSuccess; +} + +static void +set_gcm_params(SECItem* paramItem, CK_GCM_PARAMS* param, unsigned char* nonce, + const_PublicKey pubkey, unsigned char* aadBuf) +{ + int offset = 0; + memcpy(aadBuf, PRIO_TAG, strlen(PRIO_TAG)); + offset += strlen(PRIO_TAG); + memcpy(aadBuf + offset, PublicKey_toBytes(pubkey)->data, CURVE25519_KEY_LEN); + offset += CURVE25519_KEY_LEN; + memcpy(aadBuf + offset, nonce, GCM_IV_LEN_BYTES); + + param->pIv = nonce; + param->ulIvLen = GCM_IV_LEN_BYTES; + param->pAAD = aadBuf; + param->ulAADLen = AAD_LEN; + param->ulTagBits = GCM_TAG_LEN_BYTES * 8; + + paramItem->type = siBuffer; + paramItem->data = (void*)param; + paramItem->len = sizeof(*param); +} + +SECStatus +PublicKey_encrypt(PublicKey pubkey, unsigned char* output, + unsigned int* outputLen, unsigned int maxOutputLen, + const unsigned char* input, unsigned int inputLen) +{ + if (pubkey == NULL) + return SECFailure; + + if (inputLen >= MAX_ENCRYPT_LEN) + return SECFailure; + + unsigned int needLen; + if (PublicKey_encryptSize(inputLen, &needLen) != SECSuccess) + return SECFailure; + + if (maxOutputLen < needLen) + return SECFailure; + + SECStatus rv = SECSuccess; + PublicKey eph_pub = NULL; + PrivateKey eph_priv = NULL; + PK11SymKey* aes_key = NULL; + + unsigned char nonce[GCM_IV_LEN_BYTES]; + unsigned char aadBuf[AAD_LEN]; + P_CHECKC(rand_bytes(nonce, GCM_IV_LEN_BYTES)); + + P_CHECKC(Keypair_new(&eph_priv, &eph_pub)); + P_CHECKC(derive_dh_secret(&aes_key, eph_priv, pubkey)); + + CK_GCM_PARAMS param; + SECItem paramItem; + set_gcm_params(¶mItem, ¶m, nonce, eph_pub, aadBuf); + + const SECItem* pk = PublicKey_toBytes(eph_pub); + P_CHECKCB(pk->len == CURVE25519_KEY_LEN); + memcpy(output, pk->data, pk->len); + memcpy(output + CURVE25519_KEY_LEN, param.pIv, param.ulIvLen); + + const int offset = CURVE25519_KEY_LEN + param.ulIvLen; + P_CHECKC(PK11_Encrypt(aes_key, CKM_AES_GCM, ¶mItem, output + offset, + outputLen, maxOutputLen - offset, input, inputLen)); + *outputLen = *outputLen + CURVE25519_KEY_LEN + GCM_IV_LEN_BYTES; + +cleanup: + PublicKey_clear(eph_pub); + PrivateKey_clear(eph_priv); + if (aes_key) + PK11_FreeSymKey(aes_key); + + return rv; +} + +SECStatus +PrivateKey_decrypt(PrivateKey privkey, unsigned char* output, + unsigned int* outputLen, unsigned int maxOutputLen, + const unsigned char* input, unsigned int inputLen) +{ + PK11SymKey* aes_key = NULL; + PublicKey eph_pub = NULL; + unsigned char aad_buf[AAD_LEN]; + + if (privkey == NULL) + return SECFailure; + + SECStatus rv = SECSuccess; + unsigned int headerLen; + if (PublicKey_encryptSize(0, &headerLen) != SECSuccess) + return SECFailure; + + if (inputLen < headerLen) + return SECFailure; + + const unsigned int msglen = inputLen - headerLen; + if (maxOutputLen < msglen || msglen >= MAX_ENCRYPT_LEN) + return SECFailure; + + P_CHECKC(PublicKey_import(&eph_pub, input, CURVE25519_KEY_LEN)); + unsigned char nonce[GCM_IV_LEN_BYTES]; + memcpy(nonce, input + CURVE25519_KEY_LEN, GCM_IV_LEN_BYTES); + + SECItem paramItem; + CK_GCM_PARAMS param; + set_gcm_params(¶mItem, ¶m, nonce, eph_pub, aad_buf); + + P_CHECKC(derive_dh_secret(&aes_key, privkey, eph_pub)); + + const int offset = CURVE25519_KEY_LEN + GCM_IV_LEN_BYTES; + P_CHECKC(PK11_Decrypt(aes_key, CKM_AES_GCM, ¶mItem, output, outputLen, + maxOutputLen, input + offset, inputLen - offset)); + +cleanup: + PublicKey_clear(eph_pub); + if (aes_key) + PK11_FreeSymKey(aes_key); + return rv; +} diff --git a/third_party/prio/prio/encrypt.h b/third_party/prio/prio/encrypt.h new file mode 100644 index 0000000000..bc4b1cf331 --- /dev/null +++ b/third_party/prio/prio/encrypt.h @@ -0,0 +1,68 @@ +/* + * Copyright (c) 2018, Henry Corrigan-Gibbs + * + * This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this + * file, You can obtain one at http://mozilla.org/MPL/2.0/. + */ + +#ifndef __ENCRYPT_H__ +#define __ENCRYPT_H__ + +#include <limits.h> +#include <mprio.h> + +/******* + * These functions attempt to implement CCA-secure public-key encryption using + * the NSS library. We use hashed-ElGamal encryption with Curve25519 as the + * underlying group and AES128-GCM as the bulk encryption mode of operation. + * + * I make no guarantees that I am using NSS correctly or that this encryption + * scheme is actually CCA secure. As far as I can tell, NSS does not provide + * any public-key hybrid encryption scheme out of the box, so I had to cook my + * own. If you want to be really safe, you should use the NaCl Box routines + * to implement these functions. + */ + +/* + * Messages encrypted using this library must be smaller than MAX_ENCRYPT_LEN. + * Enforcing this length limit helps avoid integer overflow. + */ +#define MAX_ENCRYPT_LEN (INT_MAX >> 3) + +/* + * Write the number of bytes needed to store a ciphertext that encrypts a + * plaintext message of length `inputLen` and authenticated data of length + * `adLen` into the variable pointed to by `outputLen`. If `inputLen` + * is too large (larger than `MAX_ENCRYPT_LEN`), this function returns + * an error. + */ +SECStatus PublicKey_encryptSize(unsigned int inputLen, unsigned int* outputLen); + +/* + * Generate a new keypair for public-key encryption. + */ +SECStatus Keypair_new(PrivateKey* pvtkey, PublicKey* pubkey); + +/* + * Encrypt an arbitrary bitstring to the specified public key. The buffer + * `output` should be large enough to store the ciphertext. Use the + * `PublicKey_encryptSize()` function above to figure out how large of a buffer + * you need. + * + * The value `inputLen` must be smaller than `MAX_ENCRYPT_LEN`. + */ +SECStatus PublicKey_encrypt(PublicKey pubkey, unsigned char* output, + unsigned int* outputLen, unsigned int maxOutputLen, + const unsigned char* input, unsigned int inputLen); + +/* + * Decrypt an arbitrary bitstring using the specified private key. The output + * buffer should be at least 16 bytes larger than the plaintext you expect. If + * `outputLen` >= `inputLen`, you should be safe. + */ +SECStatus PrivateKey_decrypt(PrivateKey privkey, unsigned char* output, + unsigned int* outputLen, unsigned int maxOutputLen, + const unsigned char* input, unsigned int inputLen); + +#endif /* __ENCRYPT_H__ */ diff --git a/third_party/prio/prio/mparray.c b/third_party/prio/prio/mparray.c new file mode 100644 index 0000000000..e7115457ab --- /dev/null +++ b/third_party/prio/prio/mparray.c @@ -0,0 +1,198 @@ +/* + * Copyright (c) 2018, Henry Corrigan-Gibbs + * + * This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this + * file, You can obtain one at http://mozilla.org/MPL/2.0/. + */ + +#include <mprio.h> +#include <stdlib.h> + +#include "config.h" +#include "mparray.h" +#include "share.h" +#include "util.h" + +MPArray +MPArray_new(int len) +{ + SECStatus rv = SECSuccess; + MPArray arr = malloc(sizeof *arr); + if (!arr) + return NULL; + + arr->data = NULL; + arr->len = len; + + P_CHECKA(arr->data = calloc(len, sizeof(mp_int))); + + // Initialize these to NULL so that we can figure + // out which allocations failed (if any) + for (int i = 0; i < len; i++) { + MP_DIGITS(&arr->data[i]) = NULL; + } + + for (int i = 0; i < len; i++) { + MP_CHECKC(mp_init(&arr->data[i])); + } + +cleanup: + if (rv != SECSuccess) { + MPArray_clear(arr); + return NULL; + } + + return arr; +} + +MPArray +MPArray_new_bool(int len, const bool* data_in) +{ + MPArray arr = MPArray_new(len); + if (arr == NULL) + return NULL; + + for (int i = 0; i < len; i++) { + mp_set(&arr->data[i], data_in[i]); + } + + return arr; +} + +SECStatus +MPArray_resize(MPArray arr, int newlen) +{ + SECStatus rv = SECSuccess; + const int oldlen = arr->len; + + if (oldlen == newlen) + return rv; + + // TODO: Use realloc for this? + mp_int* newdata = calloc(newlen, sizeof(mp_int)); + if (newdata == NULL) + return SECFailure; + + for (int i = 0; i < newlen; i++) { + MP_DIGITS(&newdata[i]) = NULL; + } + + // Initialize new array + for (int i = 0; i < newlen; i++) { + MP_CHECKC(mp_init(&newdata[i])); + } + + // Copy old data into new array + for (int i = 0; i < newlen && i < oldlen; i++) { + MP_CHECKC(mp_copy(&arr->data[i], &newdata[i])); + } + + // Free old data + for (int i = 0; i < oldlen; i++) { + mp_clear(&arr->data[i]); + } + free(arr->data); + arr->data = newdata; + arr->len = newlen; + +cleanup: + if (rv != SECSuccess) { + for (int i = 0; i < newlen; i++) { + mp_clear(&newdata[i]); + } + free(newdata); + } + + return rv; +} + +MPArray +MPArray_dup(const_MPArray src) +{ + MPArray dst = MPArray_new(src->len); + if (!dst) + return NULL; + + SECStatus rv = MPArray_copy(dst, src); + if (rv == SECSuccess) { + return dst; + } else { + MPArray_clear(dst); + return NULL; + } +} + +SECStatus +MPArray_copy(MPArray dst, const_MPArray src) +{ + if (dst->len != src->len) + return SECFailure; + + for (int i = 0; i < src->len; i++) { + if (mp_copy(&src->data[i], &dst->data[i]) != MP_OKAY) { + return SECFailure; + } + } + + return SECSuccess; +} + +SECStatus +MPArray_set_share(MPArray arrA, MPArray arrB, const_MPArray src, + const_PrioConfig cfg) +{ + SECStatus rv = SECSuccess; + if (arrA->len != src->len || arrB->len != src->len) + return SECFailure; + + const int len = src->len; + + for (int i = 0; i < len; i++) { + P_CHECK(share_int(cfg, &src->data[i], &arrA->data[i], &arrB->data[i])); + } + + return rv; +} + +void +MPArray_clear(MPArray arr) +{ + if (arr == NULL) + return; + + if (arr->data != NULL) { + for (int i = 0; i < arr->len; i++) { + mp_clear(&arr->data[i]); + } + free(arr->data); + } + free(arr); +} + +SECStatus +MPArray_addmod(MPArray dst, const_MPArray to_add, const mp_int* mod) +{ + if (dst->len != to_add->len) + return SECFailure; + + for (int i = 0; i < dst->len; i++) { + MP_CHECK(mp_addmod(&dst->data[i], &to_add->data[i], mod, &dst->data[i])); + } + + return SECSuccess; +} + +bool +MPArray_areEqual(const_MPArray arr1, const_MPArray arr2) +{ + if (arr1->len != arr2->len) + return false; + + for (int i = 0; i < arr1->len; i++) { + if (mp_cmp(&arr1->data[i], &arr2->data[i])) + return false; + } + + return true; +} diff --git a/third_party/prio/prio/mparray.h b/third_party/prio/prio/mparray.h new file mode 100644 index 0000000000..b268336de7 --- /dev/null +++ b/third_party/prio/prio/mparray.h @@ -0,0 +1,70 @@ +/* + * Copyright (c) 2018, Henry Corrigan-Gibbs + * + * This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this + * file, You can obtain one at http://mozilla.org/MPL/2.0/. + */ + +#ifndef __MPARRAY_H__ +#define __MPARRAY_H__ + +#include <mpi.h> +#include <mprio.h> + +struct mparray +{ + int len; + mp_int* data; +}; + +typedef struct mparray* MPArray; +typedef const struct mparray* const_MPArray; + +/* + * Initialize an array of `mp_int`s of the given length. + */ +MPArray MPArray_new(int len); +void MPArray_clear(MPArray arr); + +/* + * Copies secret sharing of data from src into arrays + * arrA and arrB. The lengths of the three input arrays + * must be identical. + */ +SECStatus MPArray_set_share(MPArray arrA, MPArray arrB, const_MPArray src, + const_PrioConfig cfg); + +/* + * Initializes array with 0/1 values specified in boolean array `data_in` + */ +MPArray MPArray_new_bool(int len, const bool* data_in); + +/* + * Expands or shrinks the MPArray to the desired size. If shrinking, + * will clear the values on the end of array. + */ +SECStatus MPArray_resize(MPArray arr, int newlen); + +/* + * Initializes dst and creates a duplicate of the array in src. + */ +MPArray MPArray_dup(const_MPArray src); + +/* + * Copies array from src to dst. Arrays must have the same length. + */ +SECStatus MPArray_copy(MPArray dst, const_MPArray src); + +/* For each index i into the array, set: + * dst[i] = dst[i] + to_add[i] (modulo mod) + */ +SECStatus MPArray_addmod(MPArray dst, const_MPArray to_add, const mp_int* mod); + +/* + * Return true iff the two arrays are equal in length + * and contents. This comparison is NOT constant time. + */ +bool MPArray_areEqual(const_MPArray arr1, const_MPArray arr2); + +#endif /* __MPARRAY_H__ */ diff --git a/third_party/prio/prio/params.h b/third_party/prio/prio/params.h new file mode 100644 index 0000000000..be814c67ab --- /dev/null +++ b/third_party/prio/prio/params.h @@ -0,0 +1,22 @@ +/* + * Copyright (c) 2018, Henry Corrigan-Gibbs + * + * This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this + * file, You can obtain one at http://mozilla.org/MPL/2.0/. + */ + +#ifndef __PARAMS_H__ +#define __PARAMS_H__ + +// A prime modulus p. +static const char Modulus[] = "8000000000000000080001"; + +// A generator g of a subgroup of Z*_p. +static const char Generator[] = "2597c14f48d5b65ed8dcca"; + +// The generator g generates a subgroup of +// order 2^Generator2Order in Z*_p. +static const int Generator2Order = 19; + +#endif /* __PARAMS_H__ */ diff --git a/third_party/prio/prio/poly.c b/third_party/prio/prio/poly.c new file mode 100644 index 0000000000..0b2f854548 --- /dev/null +++ b/third_party/prio/prio/poly.c @@ -0,0 +1,211 @@ +/* + * Copyright (c) 2018, Henry Corrigan-Gibbs + * + * This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this + * file, You can obtain one at http://mozilla.org/MPL/2.0/. + */ + +#include <mprio.h> + +#include "config.h" +#include "poly.h" +#include "util.h" + +/* + * A nice exposition of the recursive FFT/DFT algorithm we implement + * is in the book: + * + * "Modern Computer Algebra" + * by Von zur Gathen and Gerhard. + * Cambridge University Press, 2013. + * + * They present this algorithm as Algorithm 8.14. + */ + +static SECStatus +fft_recurse(mp_int* out, const mp_int* mod, int n, const mp_int* roots, + const mp_int* ys, mp_int* tmp, mp_int* ySub, mp_int* rootsSub) +{ + if (n == 1) { + MP_CHECK(mp_copy(&ys[0], &out[0])); + return SECSuccess; + } + + // Recurse on the first half + for (int i = 0; i < n / 2; i++) { + MP_CHECK(mp_addmod(&ys[i], &ys[i + (n / 2)], mod, &ySub[i])); + MP_CHECK(mp_copy(&roots[2 * i], &rootsSub[i])); + } + + MP_CHECK(fft_recurse(tmp, mod, n / 2, rootsSub, ySub, &tmp[n / 2], + &ySub[n / 2], &rootsSub[n / 2])); + for (int i = 0; i < n / 2; i++) { + MP_CHECK(mp_copy(&tmp[i], &out[2 * i])); + } + + // Recurse on the second half + for (int i = 0; i < n / 2; i++) { + MP_CHECK(mp_submod(&ys[i], &ys[i + (n / 2)], mod, &ySub[i])); + MP_CHECK(mp_mulmod(&ySub[i], &roots[i], mod, &ySub[i])); + } + + MP_CHECK(fft_recurse(tmp, mod, n / 2, rootsSub, ySub, &tmp[n / 2], + &ySub[n / 2], &rootsSub[n / 2])); + for (int i = 0; i < n / 2; i++) { + MP_CHECK(mp_copy(&tmp[i], &out[2 * i + 1])); + } + + return SECSuccess; +} + +static SECStatus +fft_interpolate_raw(mp_int* out, const mp_int* ys, int nPoints, + const_MPArray roots, const mp_int* mod, bool invert) +{ + SECStatus rv = SECSuccess; + MPArray tmp = NULL; + MPArray ySub = NULL; + MPArray rootsSub = NULL; + + P_CHECKA(tmp = MPArray_new(nPoints)); + P_CHECKA(ySub = MPArray_new(nPoints)); + P_CHECKA(rootsSub = MPArray_new(nPoints)); + + mp_int n_inverse; + MP_DIGITS(&n_inverse) = NULL; + + MP_CHECKC(fft_recurse(out, mod, nPoints, roots->data, ys, tmp->data, + ySub->data, rootsSub->data)); + + if (invert) { + MP_CHECKC(mp_init(&n_inverse)); + + mp_set(&n_inverse, nPoints); + MP_CHECKC(mp_invmod(&n_inverse, mod, &n_inverse)); + for (int i = 0; i < nPoints; i++) { + MP_CHECKC(mp_mulmod(&out[i], &n_inverse, mod, &out[i])); + } + } + +cleanup: + MPArray_clear(tmp); + MPArray_clear(ySub); + MPArray_clear(rootsSub); + mp_clear(&n_inverse); + + return rv; +} + +/* + * The PrioConfig object has a list of N-th roots of unity for large N. + * This routine returns the n-th roots of unity for n < N, where n is + * a power of two. If the `invert` flag is set, it returns the inverses + * of the n-th roots of unity. + */ +SECStatus +poly_fft_get_roots(MPArray roots_out, int n_points, const_PrioConfig cfg, + bool invert) +{ + if (n_points < 1) { + return SECFailure; + } + + if (n_points != roots_out->len) { + return SECFailure; + } + + if (n_points > cfg->n_roots) { + return SECFailure; + } + + mp_set(&roots_out->data[0], 1); + if (n_points == 1) { + return SECSuccess; + } + + const int step_size = cfg->n_roots / n_points; + mp_int* gen = &roots_out->data[1]; + + MP_CHECK(mp_copy(&cfg->generator, gen)); + + if (invert) { + MP_CHECK(mp_invmod(gen, &cfg->modulus, gen)); + } + + // Compute g' = g^step_size + // Now, g' generates a subgroup of order n_points. + MP_CHECK(mp_exptmod_d(gen, step_size, &cfg->modulus, gen)); + + for (int i = 2; i < n_points; i++) { + // Compute g^i for all i in {0,..., n-1} + MP_CHECK(mp_mulmod(gen, &roots_out->data[i - 1], &cfg->modulus, + &roots_out->data[i])); + } + + return SECSuccess; +} + +SECStatus +poly_fft(MPArray points_out, const_MPArray points_in, const_PrioConfig cfg, + bool invert) +{ + SECStatus rv = SECSuccess; + const int n_points = points_in->len; + MPArray scaled_roots = NULL; + + if (points_out->len != points_in->len) + return SECFailure; + if (n_points > cfg->n_roots) + return SECFailure; + if (cfg->n_roots % n_points != 0) + return SECFailure; + + P_CHECKA(scaled_roots = MPArray_new(n_points)); + P_CHECKC(poly_fft_get_roots(scaled_roots, n_points, cfg, invert)); + + P_CHECKC(fft_interpolate_raw(points_out->data, points_in->data, n_points, + scaled_roots, &cfg->modulus, invert)); + +cleanup: + MPArray_clear(scaled_roots); + + return SECSuccess; +} + +SECStatus +poly_eval(mp_int* value, const_MPArray coeffs, const mp_int* eval_at, + const_PrioConfig cfg) +{ + SECStatus rv = SECSuccess; + const int n = coeffs->len; + + // Use Horner's method to evaluate the polynomial at the point + // `eval_at` + MP_CHECK(mp_copy(&coeffs->data[n - 1], value)); + for (int i = n - 2; i >= 0; i--) { + MP_CHECK(mp_mulmod(value, eval_at, &cfg->modulus, value)); + MP_CHECK(mp_addmod(value, &coeffs->data[i], &cfg->modulus, value)); + } + + return rv; +} + +SECStatus +poly_interp_evaluate(mp_int* value, const_MPArray poly_points, + const mp_int* eval_at, const_PrioConfig cfg) +{ + SECStatus rv; + MPArray coeffs = NULL; + const int N = poly_points->len; + + P_CHECKA(coeffs = MPArray_new(N)); + + // Interpolate polynomial through roots of unity + P_CHECKC(poly_fft(coeffs, poly_points, cfg, true)) + P_CHECKC(poly_eval(value, coeffs, eval_at, cfg)); + +cleanup: + MPArray_clear(coeffs); + return rv; +} diff --git a/third_party/prio/prio/poly.h b/third_party/prio/prio/poly.h new file mode 100644 index 0000000000..c53e324622 --- /dev/null +++ b/third_party/prio/prio/poly.h @@ -0,0 +1,56 @@ +/* + * Copyright (c) 2018, Henry Corrigan-Gibbs + * + * This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this + * file, You can obtain one at http://mozilla.org/MPL/2.0/. + */ + +#ifndef _FFT__H +#define _FFT__H + +#include <mpi.h> +#include <mprio.h> +#include <stdbool.h> + +#include "mparray.h" + +/* + * Compute the FFT or inverse FFT of the array in `points_in`. + * The length of the input and output arrays must be a multiple + * of two and must be no longer than the number of precomputed + * roots in the PrioConfig object passed in. + */ +SECStatus poly_fft(MPArray points_out, const_MPArray points_in, + const_PrioConfig cfg, bool invert); + +/* + * Get an array + * (r^0, r^1, r^2, ... ) + * where r is an n-th root of unity, for n a power of two + * less than cfg->n_roots. + * + * Do NOT mp_clear() the mp_ints stored in roots_out. + * These are owned by the PrioConfig object. + */ +SECStatus poly_fft_get_roots(MPArray roots_out, int n_points, + const_PrioConfig cfg, bool invert); + +/* + * Evaluate the polynomial specified by the coefficients + * at the point `eval_at` and return the result as `value`. + */ +SECStatus poly_eval(mp_int* value, const_MPArray coeffs, const mp_int* eval_at, + const_PrioConfig cfg); + +/* + * Interpolate the polynomial through the points + * (x_1, y_1), ..., (x_N, y_N), + * where x_i is an N-th root of unity and the y_i values are + * specified by `poly_points`. Evaluate the resulting polynomial + * at the point `eval_at`. Return the result as `value`. + */ +SECStatus poly_interp_evaluate(mp_int* value, const_MPArray poly_points, + const mp_int* eval_at, const_PrioConfig cfg); + +#endif diff --git a/third_party/prio/prio/prg.c b/third_party/prio/prio/prg.c new file mode 100644 index 0000000000..0fde1a2288 --- /dev/null +++ b/third_party/prio/prio/prg.c @@ -0,0 +1,188 @@ +/* + * Copyright (c) 2018, Henry Corrigan-Gibbs + * + * This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this + * file, You can obtain one at http://mozilla.org/MPL/2.0/. + */ + +#include <blapit.h> +#include <mprio.h> +#include <pk11pub.h> +#include <string.h> + +#include "prg.h" +#include "rand.h" +#include "share.h" +#include "util.h" + +struct prg +{ + PK11SlotInfo* slot; + PK11SymKey* key; + PK11Context* ctx; +}; + +SECStatus +PrioPRGSeed_randomize(PrioPRGSeed* key) +{ + return rand_bytes((unsigned char*)key, PRG_SEED_LENGTH); +} + +PRG +PRG_new(const PrioPRGSeed key_in) +{ + PRG prg = malloc(sizeof(struct prg)); + if (!prg) + return NULL; + prg->slot = NULL; + prg->key = NULL; + prg->ctx = NULL; + + SECStatus rv = SECSuccess; + const CK_MECHANISM_TYPE cipher = CKM_AES_CTR; + + P_CHECKA(prg->slot = PK11_GetInternalSlot()); + + // Create a mutable copy of the key. + PrioPRGSeed key_mut; + memcpy(key_mut, key_in, PRG_SEED_LENGTH); + + SECItem keyItem = { siBuffer, key_mut, PRG_SEED_LENGTH }; + + // The IV can be all zeros since we only encrypt once with + // each AES key. + CK_AES_CTR_PARAMS param = { 128, {} }; + SECItem paramItem = { siBuffer, (void*)¶m, sizeof(CK_AES_CTR_PARAMS) }; + + P_CHECKA(prg->key = PK11_ImportSymKey(prg->slot, cipher, PK11_OriginUnwrap, + CKA_ENCRYPT, &keyItem, NULL)); + + P_CHECKA(prg->ctx = PK11_CreateContextBySymKey(cipher, CKA_ENCRYPT, prg->key, + ¶mItem)); + +cleanup: + if (rv != SECSuccess) { + PRG_clear(prg); + prg = NULL; + } + + return prg; +} + +void +PRG_clear(PRG prg) +{ + if (!prg) + return; + + if (prg->key) + PK11_FreeSymKey(prg->key); + if (prg->slot) + PK11_FreeSlot(prg->slot); + if (prg->ctx) + PK11_DestroyContext(prg->ctx, PR_TRUE); + + free(prg); +} + +static SECStatus +PRG_get_bytes_internal(void* prg_vp, unsigned char* bytes, size_t len) +{ + SECStatus rv = SECSuccess; + PRG prg = (PRG)prg_vp; + unsigned char* in = NULL; + + P_CHECKA(in = calloc(len, sizeof(unsigned char))); + + int outlen; + P_CHECKC(PK11_CipherOp(prg->ctx, bytes, &outlen, len, in, len)); + P_CHECKCB((size_t)outlen == len); + +cleanup: + if (in) + free(in); + + return rv; +} + +SECStatus +PRG_get_bytes(PRG prg, unsigned char* bytes, size_t len) +{ + return PRG_get_bytes_internal((void*)prg, bytes, len); +} + +SECStatus +PRG_get_int(PRG prg, mp_int* out, const mp_int* max) +{ + return rand_int_rng(out, max, &PRG_get_bytes_internal, (void*)prg); +} + +SECStatus +PRG_get_int_range(PRG prg, mp_int* out, const mp_int* lower, const mp_int* max) +{ + SECStatus rv; + mp_int width; + MP_DIGITS(&width) = NULL; + MP_CHECKC(mp_init(&width)); + + // Compute + // width = max - lower + MP_CHECKC(mp_sub(max, lower, &width)); + + // Get an integer x in the range [0, width) + P_CHECKC(PRG_get_int(prg, out, &width)); + + // Set + // out = lower + x + // which is in the range [lower, width+lower), + // which is [lower, max). + MP_CHECKC(mp_add(lower, out, out)); + +cleanup: + mp_clear(&width); + return rv; +} + +SECStatus +PRG_get_array(PRG prg, MPArray dst, const mp_int* mod) +{ + SECStatus rv; + for (int i = 0; i < dst->len; i++) { + P_CHECK(PRG_get_int(prg, &dst->data[i], mod)); + } + + return SECSuccess; +} + +SECStatus +PRG_share_int(PRG prgB, mp_int* shareA, const mp_int* src, const_PrioConfig cfg) +{ + SECStatus rv = SECSuccess; + mp_int tmp; + MP_DIGITS(&tmp) = NULL; + + MP_CHECKC(mp_init(&tmp)); + P_CHECKC(PRG_get_int(prgB, &tmp, &cfg->modulus)); + MP_CHECKC(mp_submod(src, &tmp, &cfg->modulus, shareA)); + +cleanup: + mp_clear(&tmp); + return rv; +} + +SECStatus +PRG_share_array(PRG prgB, MPArray arrA, const_MPArray src, const_PrioConfig cfg) +{ + SECStatus rv = SECSuccess; + if (arrA->len != src->len) + return SECFailure; + + const int len = src->len; + + for (int i = 0; i < len; i++) { + P_CHECK(PRG_share_int(prgB, &arrA->data[i], &src->data[i], cfg)); + } + + return rv; +} diff --git a/third_party/prio/prio/prg.h b/third_party/prio/prio/prg.h new file mode 100644 index 0000000000..f1a3b30b62 --- /dev/null +++ b/third_party/prio/prio/prg.h @@ -0,0 +1,67 @@ +/* + * Copyright (c) 2018, Henry Corrigan-Gibbs + * + * This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this + * file, You can obtain one at http://mozilla.org/MPL/2.0/. + */ + +#ifndef __PRG_H__ +#define __PRG_H__ + +#include <blapit.h> +#include <mpi.h> +#include <stdlib.h> + +#include "config.h" + +typedef struct prg* PRG; +typedef const struct prg* const_PRG; + +/* + * Initialize or destroy a pseudo-random generator. + */ +PRG PRG_new(const PrioPRGSeed key); +void PRG_clear(PRG prg); + +/* + * Produce the next bytes of output from the PRG. + */ +SECStatus PRG_get_bytes(PRG prg, unsigned char* bytes, size_t len); + +/* + * Use the PRG output to sample a big integer x in the range + * 0 <= x < max. + */ +SECStatus PRG_get_int(PRG prg, mp_int* out, const mp_int* max); + +/* + * Use the PRG output to sample a big integer x in the range + * lower <= x < max. + */ +SECStatus PRG_get_int_range(PRG prg, mp_int* out, const mp_int* lower, + const mp_int* max); + +/* + * Use secret sharing to split the int src into two shares. + * Use PRG to generate the value `shareB`. + * The mp_ints must be initialized. + */ +SECStatus PRG_share_int(PRG prg, mp_int* shareA, const mp_int* src, + const_PrioConfig cfg); + +/* + * Set each item in the array to a pseudorandom value in the range + * [0, mod), where the values are generated using the PRG. + */ +SECStatus PRG_get_array(PRG prg, MPArray arr, const mp_int* mod); + +/* + * Secret shares the array in `src` into `arrA` using randomness + * provided by `prgB`. The arrays `src` and `arrA` must be the same + * length. + */ +SECStatus PRG_share_array(PRG prgB, MPArray arrA, const_MPArray src, + const_PrioConfig cfg); + +#endif /* __PRG_H__ */ diff --git a/third_party/prio/prio/rand.c b/third_party/prio/prio/rand.c new file mode 100644 index 0000000000..7fa7d309b9 --- /dev/null +++ b/third_party/prio/prio/rand.c @@ -0,0 +1,137 @@ +/* + * Copyright (c) 2018, Henry Corrigan-Gibbs + * + * This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this + * file, You can obtain one at http://mozilla.org/MPL/2.0/. + */ + +#include <limits.h> +#include <mprio.h> +#include <nss.h> +#include <pk11pub.h> +#include <prinit.h> + +#include "debug.h" +#include "rand.h" +#include "util.h" + +#define CHUNK_SIZE 8192 + +static NSSInitContext* prioGlobalContext = NULL; + +SECStatus +rand_init(void) +{ + if (prioGlobalContext) + return SECSuccess; + + prioGlobalContext = + NSS_InitContext("", "", "", "", NULL, + NSS_INIT_READONLY | NSS_INIT_NOCERTDB | NSS_INIT_NOMODDB | + NSS_INIT_FORCEOPEN | NSS_INIT_NOROOTINIT); + + return (prioGlobalContext != NULL) ? SECSuccess : SECFailure; +} + +static SECStatus +rand_bytes_internal(void* user_data, unsigned char* out, size_t n_bytes) +{ + // No pointer should ever be passed in. + if (user_data != NULL) + return SECFailure; + if (!NSS_IsInitialized()) { + PRIO_DEBUG("NSS not initialized. Call rand_init() first."); + return SECFailure; + } + + SECStatus rv = SECFailure; + + int to_go = n_bytes; + unsigned char* cp = out; + while (to_go) { + int to_gen = MIN(CHUNK_SIZE, to_go); + if ((rv = PK11_GenerateRandom(cp, to_gen)) != SECSuccess) { + PRIO_DEBUG("Error calling PK11_GenerateRandom"); + return SECFailure; + } + + cp += CHUNK_SIZE; + to_go -= to_gen; + } + + return rv; +} + +SECStatus +rand_bytes(unsigned char* out, size_t n_bytes) +{ + return rand_bytes_internal(NULL, out, n_bytes); +} + +SECStatus +rand_int(mp_int* out, const mp_int* max) +{ + return rand_int_rng(out, max, &rand_bytes_internal, NULL); +} + +SECStatus +rand_int_rng(mp_int* out, const mp_int* max, RandBytesFunc rng_func, + void* user_data) +{ + SECStatus rv = SECSuccess; + unsigned char* max_bytes = NULL; + unsigned char* buf = NULL; + + // Ensure max value is > 0 + if (mp_cmp_z(max) == 0) + return SECFailure; + + // Compute max-1, which tells us the largest + // value we will ever need to generate. + MP_CHECKC(mp_sub_d(max, 1, out)); + + const int nbytes = mp_unsigned_octet_size(out); + + // Figure out how many MSBs we need to get in the + // most-significant byte. + P_CHECKA(max_bytes = calloc(nbytes, sizeof(unsigned char))); + MP_CHECKC(mp_to_fixlen_octets(out, max_bytes, nbytes)); + const unsigned char mask = msb_mask(max_bytes[0]); + + // Buffer to store the pseudo-random bytes + P_CHECKA(buf = calloc(nbytes, sizeof(unsigned char))); + + do { + // Use rejection sampling to find a value strictly less than max. + P_CHECKC(rng_func(user_data, buf, nbytes)); + + // Mask off high-order bits that we will never need. + P_CHECKC(rng_func(user_data, &buf[0], 1)); + if (mask) + buf[0] &= mask; + + MP_CHECKC(mp_read_unsigned_octets(out, buf, nbytes)); + } while (mp_cmp(out, max) != -1); + +cleanup: + if (max_bytes) + free(max_bytes); + if (buf) + free(buf); + + return rv; +} + +void +rand_clear(void) +{ + if (prioGlobalContext) { + NSS_ShutdownContext(prioGlobalContext); +#ifdef DO_PR_CLEANUP + PR_Cleanup(); +#endif + } + + prioGlobalContext = NULL; +} diff --git a/third_party/prio/prio/rand.h b/third_party/prio/prio/rand.h new file mode 100644 index 0000000000..ec610a19dc --- /dev/null +++ b/third_party/prio/prio/rand.h @@ -0,0 +1,54 @@ +/* + * Copyright (c) 2018, Henry Corrigan-Gibbs + * + * This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this + * file, You can obtain one at http://mozilla.org/MPL/2.0/. + */ + +#ifndef __RAND_H__ +#define __RAND_H__ + +#include <mpi.h> +#include <seccomon.h> +#include <stdlib.h> + +/* + * Typedef for function pointer. A function pointer of type RandBytesFunc + * points to a function that fills the buffer `out` of with `len` random bytes. + */ +typedef SECStatus (*RandBytesFunc)(void* user_data, unsigned char* out, + size_t len); + +/* + * Initialize or cleanup the global random number generator + * state that NSS uses. + */ +SECStatus rand_init(void); +void rand_clear(void); + +/* + * Generate the specified number of random bytes using the + * NSS random number generator. + */ +SECStatus rand_bytes(unsigned char* out, size_t n_bytes); + +/* + * Generate a random number x such that + * 0 <= x < max + * using the NSS random number generator. + */ +SECStatus rand_int(mp_int* out, const mp_int* max); + +/* + * Generate a random number x such that + * 0 <= x < max + * using the specified randomness generator. + * + * The pointer user_data is passed to RandBytesFung `rng` as a first + * argument. + */ +SECStatus rand_int_rng(mp_int* out, const mp_int* max, RandBytesFunc rng, + void* user_data); + +#endif /* __RAND_H__ */ diff --git a/third_party/prio/prio/serial.c b/third_party/prio/prio/serial.c new file mode 100644 index 0000000000..9d36192b07 --- /dev/null +++ b/third_party/prio/prio/serial.c @@ -0,0 +1,458 @@ +/* + * Copyright (c) 2018, Henry Corrigan-Gibbs + * + * This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this + * file, You can obtain one at http://mozilla.org/MPL/2.0/. + */ + +#include <mprio.h> +#include <msgpack.h> + +#include "client.h" +#include "serial.h" +#include "server.h" +#include "share.h" +#include "util.h" + +#define MSGPACK_OK 0 + +static SECStatus +serial_write_mp_int(msgpack_packer* pk, const mp_int* n) +{ + SECStatus rv = SECSuccess; + unsigned int n_size = mp_unsigned_octet_size(n); + unsigned char* data = NULL; + + P_CHECKA(data = calloc(n_size, sizeof(unsigned char))); + MP_CHECKC(mp_to_fixlen_octets(n, data, n_size)); + + P_CHECKC(msgpack_pack_str(pk, n_size)); + P_CHECKC(msgpack_pack_str_body(pk, data, n_size)); +cleanup: + if (data) + free(data); + return rv; +} + +static SECStatus +object_to_mp_int(msgpack_object* obj, mp_int* n, const mp_int* max) +{ + SECStatus rv = SECSuccess; + P_CHECKCB(obj != NULL); + P_CHECKCB(obj->type == MSGPACK_OBJECT_STR); + P_CHECKCB(n != NULL); + + msgpack_object_str s = obj->via.str; + P_CHECKCB(s.ptr != NULL); + MP_CHECKC(mp_read_unsigned_octets(n, (unsigned char*)s.ptr, s.size)); + + P_CHECKCB(mp_cmp_z(n) >= 0); + P_CHECKCB(mp_cmp(n, max) < 0); + +cleanup: + return rv; +} + +static SECStatus +serial_read_mp_int(msgpack_unpacker* upk, mp_int* n, const mp_int* max) +{ + SECStatus rv = SECSuccess; + + msgpack_unpacked res; + msgpack_unpacked_init(&res); + + P_CHECKCB(upk != NULL); + P_CHECKCB(n != NULL); + P_CHECKCB(max != NULL); + + UP_CHECKC(msgpack_unpacker_next(upk, &res)); + + msgpack_object obj = res.data; + P_CHECKC(object_to_mp_int(&obj, n, max)); + +cleanup: + msgpack_unpacked_destroy(&res); + + return rv; +} + +static SECStatus +serial_read_int(msgpack_unpacker* upk, int* n) +{ + SECStatus rv = SECSuccess; + + msgpack_unpacked res; + msgpack_unpacked_init(&res); + + P_CHECKCB(upk != NULL); + P_CHECKCB(n != NULL); + + UP_CHECKC(msgpack_unpacker_next(upk, &res)); + + msgpack_object obj = res.data; + P_CHECKCB(obj.type == MSGPACK_OBJECT_POSITIVE_INTEGER); + + *n = obj.via.i64; + +cleanup: + msgpack_unpacked_destroy(&res); + + return rv; +} + +static SECStatus +serial_write_mp_array(msgpack_packer* pk, const_MPArray arr) +{ + SECStatus rv = SECSuccess; + P_CHECKCB(pk != NULL); + P_CHECKCB(arr != NULL); + + P_CHECKC(msgpack_pack_array(pk, arr->len)); + for (int i = 0; i < arr->len; i++) { + P_CHECKC(serial_write_mp_int(pk, &arr->data[i])); + } + +cleanup: + return rv; +} + +static SECStatus +serial_read_mp_array(msgpack_unpacker* upk, MPArray arr, size_t len, + const mp_int* max) +{ + SECStatus rv = SECSuccess; + + msgpack_unpacked res; + msgpack_unpacked_init(&res); + + P_CHECKCB(upk != NULL); + P_CHECKCB(arr != NULL); + P_CHECKCB(max != NULL); + + UP_CHECKC(msgpack_unpacker_next(upk, &res)); + + msgpack_object obj = res.data; + P_CHECKCB(obj.type == MSGPACK_OBJECT_ARRAY); + + msgpack_object_array objarr = obj.via.array; + P_CHECKCB(objarr.size == len); + + P_CHECKC(MPArray_resize(arr, len)); + for (unsigned int i = 0; i < len; i++) { + P_CHECKC(object_to_mp_int(&objarr.ptr[i], &arr->data[i], max)); + } + +cleanup: + msgpack_unpacked_destroy(&res); + + return rv; +} + +static SECStatus +serial_write_beaver_triple(msgpack_packer* pk, const_BeaverTriple t) +{ + SECStatus rv = SECSuccess; + P_CHECKCB(pk != NULL); + P_CHECKCB(t != NULL); + + P_CHECKC(serial_write_mp_int(pk, &t->a)); + P_CHECKC(serial_write_mp_int(pk, &t->b)); + P_CHECKC(serial_write_mp_int(pk, &t->c)); + +cleanup: + return rv; +} + +static SECStatus +serial_read_beaver_triple(msgpack_unpacker* pk, BeaverTriple t, + const mp_int* max) +{ + SECStatus rv = SECSuccess; + P_CHECKCB(pk != NULL); + P_CHECKCB(t != NULL); + P_CHECKCB(max != NULL); + + P_CHECKC(serial_read_mp_int(pk, &t->a, max)); + P_CHECKC(serial_read_mp_int(pk, &t->b, max)); + P_CHECKC(serial_read_mp_int(pk, &t->c, max)); + +cleanup: + return rv; +} + +static SECStatus +serial_write_server_a_data(msgpack_packer* pk, const struct server_a_data* A) +{ + SECStatus rv = SECSuccess; + P_CHECKCB(pk != NULL); + P_CHECKCB(A != NULL); + + P_CHECKC(serial_write_mp_array(pk, A->data_shares)); + P_CHECKC(serial_write_mp_array(pk, A->h_points)); +cleanup: + return rv; +} + +static SECStatus +serial_read_server_a_data(msgpack_unpacker* upk, struct server_a_data* A, + const_PrioConfig cfg) +{ + SECStatus rv = SECSuccess; + P_CHECKCB(upk != NULL); + P_CHECKCB(A != NULL); + + P_CHECKC(serial_read_mp_array(upk, A->data_shares, cfg->num_data_fields, + &cfg->modulus)); + P_CHECKC(serial_read_mp_array(upk, A->h_points, PrioConfig_hPoints(cfg), + &cfg->modulus)); + +cleanup: + return rv; +} + +static SECStatus +serial_write_prg_seed(msgpack_packer* pk, const PrioPRGSeed* seed) +{ + SECStatus rv = SECSuccess; + P_CHECKCB(pk != NULL); + P_CHECKCB(seed != NULL); + + P_CHECKC(msgpack_pack_str(pk, PRG_SEED_LENGTH)); + P_CHECKC(msgpack_pack_str_body(pk, seed, PRG_SEED_LENGTH)); + +cleanup: + return rv; +} + +static SECStatus +serial_read_prg_seed(msgpack_unpacker* upk, PrioPRGSeed* seed) +{ + SECStatus rv = SECSuccess; + + msgpack_unpacked res; + msgpack_unpacked_init(&res); + + P_CHECKCB(upk != NULL); + P_CHECKCB(seed != NULL); + + UP_CHECKC(msgpack_unpacker_next(upk, &res)); + + msgpack_object obj = res.data; + P_CHECKCB(obj.type == MSGPACK_OBJECT_STR); + + msgpack_object_str s = obj.via.str; + P_CHECKCB(s.size == PRG_SEED_LENGTH); + memcpy(seed, s.ptr, PRG_SEED_LENGTH); + +cleanup: + msgpack_unpacked_destroy(&res); + + return rv; +} + +static SECStatus +serial_write_server_b_data(msgpack_packer* pk, const struct server_b_data* B) +{ + SECStatus rv = SECSuccess; + P_CHECKCB(pk != NULL); + P_CHECKCB(B != NULL); + + rv = serial_write_prg_seed(pk, &B->seed); +cleanup: + return rv; +} + +static SECStatus +serial_read_server_b_data(msgpack_unpacker* upk, struct server_b_data* B) +{ + SECStatus rv = SECSuccess; + P_CHECKCB(upk != NULL); + P_CHECKCB(B != NULL); + + rv = serial_read_prg_seed(upk, &B->seed); +cleanup: + return rv; +} + +SECStatus +serial_write_packet_client(msgpack_packer* pk, const_PrioPacketClient p, + const_PrioConfig cfg) +{ + SECStatus rv = SECSuccess; + P_CHECKCB(pk != NULL); + P_CHECKCB(p != NULL); + + P_CHECKC(msgpack_pack_str(pk, cfg->batch_id_len)); + P_CHECKC(msgpack_pack_str_body(pk, cfg->batch_id, cfg->batch_id_len)); + + P_CHECKC(serial_write_beaver_triple(pk, p->triple)); + + P_CHECKC(serial_write_mp_int(pk, &p->f0_share)); + P_CHECKC(serial_write_mp_int(pk, &p->g0_share)); + P_CHECKC(serial_write_mp_int(pk, &p->h0_share)); + + P_CHECKC(msgpack_pack_int(pk, p->for_server)); + + switch (p->for_server) { + case PRIO_SERVER_A: + P_CHECKC(serial_write_server_a_data(pk, &p->shares.A)); + break; + case PRIO_SERVER_B: + P_CHECKC(serial_write_server_b_data(pk, &p->shares.B)); + break; + default: + return SECFailure; + } + +cleanup: + return rv; +} + +SECStatus +serial_read_server_id(msgpack_unpacker* upk, PrioServerId* s) +{ + SECStatus rv = SECSuccess; + P_CHECKCB(upk != NULL); + P_CHECKCB(s != NULL); + + int serv; + P_CHECKC(serial_read_int(upk, &serv)); + P_CHECKCB(serv == PRIO_SERVER_A || serv == PRIO_SERVER_B); + *s = serv; + +cleanup: + return rv; +} + +SECStatus +serial_read_packet_client(msgpack_unpacker* upk, PrioPacketClient p, + const_PrioConfig cfg) +{ + SECStatus rv = SECSuccess; + + msgpack_unpacked res; + msgpack_unpacked_init(&res); + + P_CHECKCB(upk != NULL); + P_CHECKCB(p != NULL); + + UP_CHECKC(msgpack_unpacker_next(upk, &res)); + + msgpack_object obj = res.data; + P_CHECKCB(obj.type == MSGPACK_OBJECT_STR); + + msgpack_object_str s = obj.via.str; + P_CHECKCB(s.size == cfg->batch_id_len); + P_CHECKCB(!memcmp(s.ptr, (char*)cfg->batch_id, cfg->batch_id_len)); + + P_CHECKC(serial_read_beaver_triple(upk, p->triple, &cfg->modulus)); + + P_CHECKC(serial_read_mp_int(upk, &p->f0_share, &cfg->modulus)); + P_CHECKC(serial_read_mp_int(upk, &p->g0_share, &cfg->modulus)); + P_CHECKC(serial_read_mp_int(upk, &p->h0_share, &cfg->modulus)); + + PrioServerId remote_id; + P_CHECKC(serial_read_server_id(upk, &remote_id)); + P_CHECKCB(remote_id == p->for_server); + + switch (p->for_server) { + case PRIO_SERVER_A: + P_CHECKC(serial_read_server_a_data(upk, &p->shares.A, cfg)); + break; + case PRIO_SERVER_B: + P_CHECKC(serial_read_server_b_data(upk, &p->shares.B)); + break; + default: + rv = SECFailure; + goto cleanup; + } + +cleanup: + msgpack_unpacked_destroy(&res); + return rv; +} + +SECStatus +PrioPacketVerify1_write(const_PrioPacketVerify1 p, msgpack_packer* pk) +{ + SECStatus rv = SECSuccess; + P_CHECKCB(pk != NULL); + P_CHECKCB(p != NULL); + + P_CHECKC(serial_write_mp_int(pk, &p->share_d)); + P_CHECKC(serial_write_mp_int(pk, &p->share_e)); + +cleanup: + return rv; +} + +SECStatus +PrioPacketVerify1_read(PrioPacketVerify1 p, msgpack_unpacker* upk, + const_PrioConfig cfg) +{ + SECStatus rv = SECSuccess; + P_CHECKCB(upk != NULL); + P_CHECKCB(p != NULL); + + P_CHECKC(serial_read_mp_int(upk, &p->share_d, &cfg->modulus)); + P_CHECKC(serial_read_mp_int(upk, &p->share_e, &cfg->modulus)); + +cleanup: + return rv; +} + +SECStatus +PrioPacketVerify2_write(const_PrioPacketVerify2 p, msgpack_packer* pk) +{ + SECStatus rv = SECSuccess; + P_CHECKCB(pk != NULL); + P_CHECKCB(p != NULL); + + P_CHECKC(serial_write_mp_int(pk, &p->share_out)); + +cleanup: + return rv; +} + +SECStatus +PrioPacketVerify2_read(PrioPacketVerify2 p, msgpack_unpacker* upk, + const_PrioConfig cfg) +{ + SECStatus rv = SECSuccess; + P_CHECKCB(upk != NULL); + P_CHECKCB(p != NULL); + + P_CHECKC(serial_read_mp_int(upk, &p->share_out, &cfg->modulus)); + +cleanup: + return rv; +} + +SECStatus +PrioTotalShare_write(const_PrioTotalShare t, msgpack_packer* pk) +{ + SECStatus rv = SECSuccess; + P_CHECKCB(t != NULL); + P_CHECKCB(pk != NULL); + P_CHECKC(msgpack_pack_int(pk, t->idx)); + P_CHECKC(serial_write_mp_array(pk, t->data_shares)); + +cleanup: + return rv; +} + +SECStatus +PrioTotalShare_read(PrioTotalShare t, msgpack_unpacker* upk, + const_PrioConfig cfg) +{ + SECStatus rv = SECSuccess; + P_CHECKCB(t != NULL); + P_CHECKCB(upk != NULL); + P_CHECKC(serial_read_server_id(upk, &t->idx)); + P_CHECKC(serial_read_mp_array(upk, t->data_shares, cfg->num_data_fields, + &cfg->modulus)); + +cleanup: + return rv; +} diff --git a/third_party/prio/prio/serial.h b/third_party/prio/prio/serial.h new file mode 100644 index 0000000000..8d69f2205c --- /dev/null +++ b/third_party/prio/prio/serial.h @@ -0,0 +1,21 @@ +/* + * Copyright (c) 2018, Henry Corrigan-Gibbs + * + * This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this + * file, You can obtain one at http://mozilla.org/MPL/2.0/. + */ + +#ifndef __SERIAL_H__ +#define __SERIAL_H__ + +#include <mprio.h> + +SECStatus serial_write_packet_client(msgpack_packer* pk, + const_PrioPacketClient p, + const_PrioConfig cfg); + +SECStatus serial_read_packet_client(msgpack_unpacker* upk, PrioPacketClient p, + const_PrioConfig cfg); + +#endif /* __SERIAL_H__ */ diff --git a/third_party/prio/prio/server.c b/third_party/prio/prio/server.c new file mode 100644 index 0000000000..f2eede18da --- /dev/null +++ b/third_party/prio/prio/server.c @@ -0,0 +1,509 @@ +/* + * Copyright (c) 2018, Henry Corrigan-Gibbs + * + * This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this + * file, You can obtain one at http://mozilla.org/MPL/2.0/. + */ + +#include <mpi.h> +#include <mprio.h> +#include <stdio.h> +#include <stdlib.h> + +#include "client.h" +#include "mparray.h" +#include "poly.h" +#include "prg.h" +#include "server.h" +#include "util.h" + +/* In `PrioTotalShare_final`, we need to be able to store + * an `mp_digit` in an `unsigned long long`. + */ +#if (MP_DIGIT_MAX > ULLONG_MAX) +#error "Unsigned long long is not long enough to hold an MP digit" +#endif + +PrioServer +PrioServer_new(const_PrioConfig cfg, PrioServerId server_idx, + PrivateKey server_priv, const PrioPRGSeed seed) +{ + SECStatus rv = SECSuccess; + PrioServer s = malloc(sizeof(*s)); + if (!s) + return NULL; + s->cfg = cfg; + s->idx = server_idx; + s->priv_key = server_priv; + s->data_shares = NULL; + s->prg = NULL; + + P_CHECKA(s->data_shares = MPArray_new(s->cfg->num_data_fields)); + P_CHECKA(s->prg = PRG_new(seed)); + +cleanup: + if (rv != SECSuccess) { + PrioServer_clear(s); + return NULL; + } + + return s; +} + +void +PrioServer_clear(PrioServer s) +{ + if (!s) + return; + + PRG_clear(s->prg); + MPArray_clear(s->data_shares); + free(s); +} + +SECStatus +PrioServer_aggregate(PrioServer s, PrioVerifier v) +{ + MPArray arr = NULL; + switch (s->idx) { + case PRIO_SERVER_A: + arr = v->clientp->shares.A.data_shares; + break; + case PRIO_SERVER_B: + arr = v->data_sharesB; + break; + default: + // Should never get here + return SECFailure; + } + + return MPArray_addmod(s->data_shares, arr, &s->cfg->modulus); +} + +PrioTotalShare +PrioTotalShare_new(void) +{ + PrioTotalShare t = malloc(sizeof(*t)); + if (!t) + return NULL; + + t->data_shares = MPArray_new(0); + if (!t->data_shares) { + free(t); + return NULL; + } + + return t; +} + +void +PrioTotalShare_clear(PrioTotalShare t) +{ + if (!t) + return; + MPArray_clear(t->data_shares); + free(t); +} + +SECStatus +PrioTotalShare_set_data(PrioTotalShare t, const_PrioServer s) +{ + t->idx = s->idx; + SECStatus rv = SECSuccess; + + P_CHECK(MPArray_resize(t->data_shares, s->data_shares->len)); + P_CHECK(MPArray_copy(t->data_shares, s->data_shares)); + + return rv; +} + +SECStatus +PrioTotalShare_final(const_PrioConfig cfg, unsigned long long* output, + const_PrioTotalShare tA, const_PrioTotalShare tB) +{ + if (tA->data_shares->len != cfg->num_data_fields) + return SECFailure; + if (tA->data_shares->len != tB->data_shares->len) + return SECFailure; + if (tA->idx != PRIO_SERVER_A || tB->idx != PRIO_SERVER_B) + return SECFailure; + + SECStatus rv = SECSuccess; + + mp_int tmp; + MP_DIGITS(&tmp) = NULL; + MP_CHECKC(mp_init(&tmp)); + + for (int i = 0; i < cfg->num_data_fields; i++) { + MP_CHECKC(mp_addmod(&tA->data_shares->data[i], &tB->data_shares->data[i], + &cfg->modulus, &tmp)); + + if (MP_USED(&tmp) > 1) { + P_CHECKCB(false); + } + output[i] = MP_DIGIT(&tmp, 0); + } + +cleanup: + mp_clear(&tmp); + return rv; +} + +inline static mp_int* +get_data_share(const_PrioVerifier v, int i) +{ + switch (v->s->idx) { + case PRIO_SERVER_A: + return &v->clientp->shares.A.data_shares->data[i]; + case PRIO_SERVER_B: + return &v->data_sharesB->data[i]; + } + // Should never get here + return NULL; +} + +inline static mp_int* +get_h_share(const_PrioVerifier v, int i) +{ + switch (v->s->idx) { + case PRIO_SERVER_A: + return &v->clientp->shares.A.h_points->data[i]; + case PRIO_SERVER_B: + return &v->h_pointsB->data[i]; + } + // Should never get here + return NULL; +} + +/* + * Build shares of the polynomials f, g, and h used in the Prio verification + * routine and evalute these polynomials at a random point determined + * by the shared secret. Store the evaluations in the verifier object. + */ +static SECStatus +compute_shares(PrioVerifier v, const_PrioPacketClient p) +{ + SECStatus rv; + const int n = v->s->cfg->num_data_fields + 1; + const int N = next_power_of_two(n); + mp_int eval_at; + mp_int lower; + MP_DIGITS(&eval_at) = NULL; + MP_DIGITS(&lower) = NULL; + + MPArray points_f = NULL; + MPArray points_g = NULL; + MPArray points_h = NULL; + + MP_CHECKC(mp_init(&eval_at)); + MP_CHECKC(mp_init(&lower)); + P_CHECKA(points_f = MPArray_new(N)); + P_CHECKA(points_g = MPArray_new(N)); + P_CHECKA(points_h = MPArray_new(2 * N)); + + // Use PRG to generate random point. Per Appendix D.2 of full version of + // Prio paper, this value must be in the range + // [n+1, modulus). + mp_set(&lower, n + 1); + P_CHECKC(PRG_get_int_range(v->s->prg, &eval_at, &lower, &v->s->cfg->modulus)); + + // Reduce value into the field we're using. This + // doesn't yield exactly a uniformly random point, + // but for values this large, it will be close + // enough. + MP_CHECKC(mp_mod(&eval_at, &v->s->cfg->modulus, &eval_at)); + + // Client sends us the values of f(0) and g(0) + MP_CHECKC(mp_copy(&p->f0_share, &points_f->data[0])); + MP_CHECKC(mp_copy(&p->g0_share, &points_g->data[0])); + MP_CHECKC(mp_copy(&p->h0_share, &points_h->data[0])); + + for (int i = 1; i < n; i++) { + // [f](i) = i-th data share + const mp_int* data_i_minus_1 = get_data_share(v, i - 1); + MP_CHECKC(mp_copy(data_i_minus_1, &points_f->data[i])); + + // [g](i) = i-th data share minus 1 + // Only need to shift the share for 0-th server + MP_CHECKC(mp_copy(&points_f->data[i], &points_g->data[i])); + if (!v->s->idx) { + MP_CHECKC(mp_sub_d(&points_g->data[i], 1, &points_g->data[i])); + MP_CHECKC( + mp_mod(&points_g->data[i], &v->s->cfg->modulus, &points_g->data[i])); + } + } + + int j = 0; + for (int i = 1; i < 2 * N; i += 2) { + const mp_int* h_point_j = get_h_share(v, j++); + MP_CHECKC(mp_copy(h_point_j, &points_h->data[i])); + } + + P_CHECKC(poly_interp_evaluate(&v->share_fR, points_f, &eval_at, v->s->cfg)); + P_CHECKC(poly_interp_evaluate(&v->share_gR, points_g, &eval_at, v->s->cfg)); + P_CHECKC(poly_interp_evaluate(&v->share_hR, points_h, &eval_at, v->s->cfg)); + +cleanup: + MPArray_clear(points_f); + MPArray_clear(points_g); + MPArray_clear(points_h); + mp_clear(&eval_at); + mp_clear(&lower); + return rv; +} + +PrioVerifier +PrioVerifier_new(PrioServer s) +{ + SECStatus rv = SECSuccess; + PrioVerifier v = malloc(sizeof *v); + if (!v) + return NULL; + + v->s = s; + v->clientp = NULL; + v->data_sharesB = NULL; + v->h_pointsB = NULL; + + MP_DIGITS(&v->share_fR) = NULL; + MP_DIGITS(&v->share_gR) = NULL; + MP_DIGITS(&v->share_hR) = NULL; + + MP_CHECKC(mp_init(&v->share_fR)); + MP_CHECKC(mp_init(&v->share_gR)); + MP_CHECKC(mp_init(&v->share_hR)); + + P_CHECKA(v->clientp = PrioPacketClient_new(s->cfg, s->idx)); + + const int N = next_power_of_two(s->cfg->num_data_fields + 1); + if (v->s->idx == PRIO_SERVER_B) { + P_CHECKA(v->data_sharesB = MPArray_new(v->s->cfg->num_data_fields)); + P_CHECKA(v->h_pointsB = MPArray_new(N)); + } + +cleanup: + if (rv != SECSuccess) { + PrioVerifier_clear(v); + return NULL; + } + + return v; +} + +SECStatus +PrioVerifier_set_data(PrioVerifier v, unsigned char* data, + unsigned int data_len) +{ + SECStatus rv = SECSuccess; + PRG prgB = NULL; + P_CHECKC(PrioPacketClient_decrypt(v->clientp, v->s->cfg, v->s->priv_key, data, + data_len)); + + PrioPacketClient p = v->clientp; + if (p->for_server != v->s->idx) + return SECFailure; + + const int N = next_power_of_two(v->s->cfg->num_data_fields + 1); + if (v->s->idx == PRIO_SERVER_A) { + // Check that packet has the correct number of data fields + if (p->shares.A.data_shares->len != v->s->cfg->num_data_fields) + return SECFailure; + if (p->shares.A.h_points->len != N) + return SECFailure; + } + + if (v->s->idx == PRIO_SERVER_B) { + P_CHECKA(prgB = PRG_new(v->clientp->shares.B.seed)); + P_CHECKC(PRG_get_array(prgB, v->data_sharesB, &v->s->cfg->modulus)); + P_CHECKC(PRG_get_array(prgB, v->h_pointsB, &v->s->cfg->modulus)); + } + + // TODO: This can be done much faster by using the combined + // interpolate-and-evaluate optimization described in the + // Prio paper. + // + // Compute share of f(r), g(r), h(r) + P_CHECKC(compute_shares(v, p)); + +cleanup: + + PRG_clear(prgB); + return rv; +} + +void +PrioVerifier_clear(PrioVerifier v) +{ + if (v == NULL) + return; + PrioPacketClient_clear(v->clientp); + MPArray_clear(v->data_sharesB); + MPArray_clear(v->h_pointsB); + mp_clear(&v->share_fR); + mp_clear(&v->share_gR); + mp_clear(&v->share_hR); + free(v); +} + +PrioPacketVerify1 +PrioPacketVerify1_new(void) +{ + SECStatus rv = SECSuccess; + PrioPacketVerify1 p = malloc(sizeof *p); + if (!p) + return NULL; + + MP_DIGITS(&p->share_d) = NULL; + MP_DIGITS(&p->share_e) = NULL; + + MP_CHECKC(mp_init(&p->share_d)); + MP_CHECKC(mp_init(&p->share_e)); + +cleanup: + if (rv != SECSuccess) { + PrioPacketVerify1_clear(p); + return NULL; + } + + return p; +} + +void +PrioPacketVerify1_clear(PrioPacketVerify1 p) +{ + if (!p) + return; + mp_clear(&p->share_d); + mp_clear(&p->share_e); + free(p); +} + +SECStatus +PrioPacketVerify1_set_data(PrioPacketVerify1 p1, const_PrioVerifier v) +{ + // See the Prio paper for details on how this works. + // Appendix C descrives the MPC protocol used here. + + SECStatus rv = SECSuccess; + + // Compute corrections. + // [d] = [f(r)] - [a] + MP_CHECK(mp_sub(&v->share_fR, &v->clientp->triple->a, &p1->share_d)); + MP_CHECK(mp_mod(&p1->share_d, &v->s->cfg->modulus, &p1->share_d)); + + // [e] = [g(r)] - [b] + MP_CHECK(mp_sub(&v->share_gR, &v->clientp->triple->b, &p1->share_e)); + MP_CHECK(mp_mod(&p1->share_e, &v->s->cfg->modulus, &p1->share_e)); + + return rv; +} + +PrioPacketVerify2 +PrioPacketVerify2_new(void) +{ + SECStatus rv = SECSuccess; + PrioPacketVerify2 p = malloc(sizeof *p); + if (!p) + return NULL; + + MP_DIGITS(&p->share_out) = NULL; + MP_CHECKC(mp_init(&p->share_out)); + +cleanup: + if (rv != SECSuccess) { + PrioPacketVerify2_clear(p); + return NULL; + } + return p; +} + +void +PrioPacketVerify2_clear(PrioPacketVerify2 p) +{ + if (!p) + return; + mp_clear(&p->share_out); + free(p); +} + +SECStatus +PrioPacketVerify2_set_data(PrioPacketVerify2 p2, const_PrioVerifier v, + const_PrioPacketVerify1 p1A, + const_PrioPacketVerify1 p1B) +{ + SECStatus rv = SECSuccess; + + mp_int d, e, tmp; + MP_DIGITS(&d) = NULL; + MP_DIGITS(&e) = NULL; + MP_DIGITS(&tmp) = NULL; + + MP_CHECKC(mp_init(&d)); + MP_CHECKC(mp_init(&e)); + MP_CHECKC(mp_init(&tmp)); + + const mp_int* mod = &v->s->cfg->modulus; + + // Compute share of f(r)*g(r) + // [f(r)*g(r)] = [d*e/2] + d[b] + e[a] + [c] + + // Compute d + MP_CHECKC(mp_addmod(&p1A->share_d, &p1B->share_d, mod, &d)); + // Compute e + MP_CHECKC(mp_addmod(&p1A->share_e, &p1B->share_e, mod, &e)); + + // Compute d*e + MP_CHECKC(mp_mulmod(&d, &e, mod, &p2->share_out)); + // out = d*e/2 + MP_CHECKC(mp_mulmod(&p2->share_out, &v->s->cfg->inv2, mod, &p2->share_out)); + + // Compute d[b] + MP_CHECKC(mp_mulmod(&d, &v->clientp->triple->b, mod, &tmp)); + // out = d*e/2 + d[b] + MP_CHECKC(mp_addmod(&p2->share_out, &tmp, mod, &p2->share_out)); + + // Compute e[a] + MP_CHECKC(mp_mulmod(&e, &v->clientp->triple->a, mod, &tmp)); + // out = d*e/2 + d[b] + e[a] + MP_CHECKC(mp_addmod(&p2->share_out, &tmp, mod, &p2->share_out)); + + // out = d*e/2 + d[b] + e[a] + [c] + MP_CHECKC( + mp_addmod(&p2->share_out, &v->clientp->triple->c, mod, &p2->share_out)); + + // We want to compute f(r)*g(r) - h(r), + // so subtract off [h(r)]: + // out = d*e/2 + d[b] + e[a] + [c] - [h(r)] + MP_CHECKC(mp_sub(&p2->share_out, &v->share_hR, &p2->share_out)); + MP_CHECKC(mp_mod(&p2->share_out, mod, &p2->share_out)); + +cleanup: + mp_clear(&d); + mp_clear(&e); + mp_clear(&tmp); + return rv; +} + +int +PrioVerifier_isValid(const_PrioVerifier v, const_PrioPacketVerify2 pA, + const_PrioPacketVerify2 pB) +{ + SECStatus rv = SECSuccess; + mp_int res; + MP_DIGITS(&res) = NULL; + MP_CHECKC(mp_init(&res)); + + // Add up the shares of the output wire value and + // ensure that the sum is equal to zero, which indicates + // that + // f(r) * g(r) == h(r). + MP_CHECKC( + mp_addmod(&pA->share_out, &pB->share_out, &v->s->cfg->modulus, &res)); + + rv = (mp_cmp_d(&res, 0) == 0) ? SECSuccess : SECFailure; + +cleanup: + mp_clear(&res); + return rv; +} diff --git a/third_party/prio/prio/server.h b/third_party/prio/prio/server.h new file mode 100644 index 0000000000..ce549013ae --- /dev/null +++ b/third_party/prio/prio/server.h @@ -0,0 +1,64 @@ +/* + * Copyright (c) 2018, Henry Corrigan-Gibbs + * + * This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this + * file, You can obtain one at http://mozilla.org/MPL/2.0/. + */ + +#ifndef __SERVER_H__ +#define __SERVER_H__ + +#include "mparray.h" +#include "prg.h" +#include "share.h" + +struct prio_total_share +{ + PrioServerId idx; + MPArray data_shares; +}; + +struct prio_server +{ + const_PrioConfig cfg; + PrioServerId idx; + + // Sever's private decryption key + PrivateKey priv_key; + + // The accumulated data values from the clients. + MPArray data_shares; + + // PRG used to generate randomness for checking the client + // data packets. Both servers initialize this PRG with the + // same shared seed. + PRG prg; +}; + +struct prio_verifier +{ + PrioServer s; + + PrioPacketClient clientp; + MPArray data_sharesB; + MPArray h_pointsB; + + mp_int share_fR; + mp_int share_gR; + mp_int share_hR; + mp_int share_out; +}; + +struct prio_packet_verify1 +{ + mp_int share_d; + mp_int share_e; +}; + +struct prio_packet_verify2 +{ + mp_int share_out; +}; + +#endif /* __SERVER_H__ */ diff --git a/third_party/prio/prio/share.c b/third_party/prio/prio/share.c new file mode 100644 index 0000000000..6349e82329 --- /dev/null +++ b/third_party/prio/prio/share.c @@ -0,0 +1,106 @@ +/* + * Copyright (c) 2018, Henry Corrigan-Gibbs + * + * This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this + * file, You can obtain one at http://mozilla.org/MPL/2.0/. + */ + +#include <mprio.h> + +#include "rand.h" +#include "share.h" +#include "util.h" + +SECStatus +share_int(const struct prio_config* cfg, const mp_int* src, mp_int* shareA, + mp_int* shareB) +{ + SECStatus rv; + P_CHECK(rand_int(shareA, &cfg->modulus)); + MP_CHECK(mp_submod(src, shareA, &cfg->modulus, shareB)); + + return rv; +} + +BeaverTriple +BeaverTriple_new(void) +{ + BeaverTriple triple = malloc(sizeof *triple); + if (!triple) + return NULL; + + MP_DIGITS(&triple->a) = NULL; + MP_DIGITS(&triple->b) = NULL; + MP_DIGITS(&triple->c) = NULL; + + SECStatus rv = SECSuccess; + MP_CHECKC(mp_init(&triple->a)); + MP_CHECKC(mp_init(&triple->b)); + MP_CHECKC(mp_init(&triple->c)); + +cleanup: + if (rv != SECSuccess) { + BeaverTriple_clear(triple); + return NULL; + } + return triple; +} + +void +BeaverTriple_clear(BeaverTriple triple) +{ + if (!triple) + return; + mp_clear(&triple->a); + mp_clear(&triple->b); + mp_clear(&triple->c); + free(triple); +} + +SECStatus +BeaverTriple_set_rand(const struct prio_config* cfg, + struct beaver_triple* triple_1, + struct beaver_triple* triple_2) +{ + SECStatus rv = SECSuccess; + + // TODO: Can shorten this code using share_int() + + // We need that + // (a1 + a2)(b1 + b2) = c1 + c2 (mod p) + P_CHECK(rand_int(&triple_1->a, &cfg->modulus)); + P_CHECK(rand_int(&triple_1->b, &cfg->modulus)); + P_CHECK(rand_int(&triple_2->a, &cfg->modulus)); + P_CHECK(rand_int(&triple_2->b, &cfg->modulus)); + + // We are trying to be a little clever here to avoid the use of temp + // variables. + + // c1 = a1 + a2 + MP_CHECK(mp_addmod(&triple_1->a, &triple_2->a, &cfg->modulus, &triple_1->c)); + + // c2 = b1 + b2 + MP_CHECK(mp_addmod(&triple_1->b, &triple_2->b, &cfg->modulus, &triple_2->c)); + + // c1 = c1 * c2 = (a1 + a2) (b1 + b2) + MP_CHECK(mp_mulmod(&triple_1->c, &triple_2->c, &cfg->modulus, &triple_1->c)); + + // Set c2 to random blinding value + MP_CHECK(rand_int(&triple_2->c, &cfg->modulus)); + + // c1 = c1 - c2 + MP_CHECK(mp_submod(&triple_1->c, &triple_2->c, &cfg->modulus, &triple_1->c)); + + // Now we should have random tuples satisfying: + // (a1 + a2) (b1 + b2) = c1 + c2 + + return rv; +} + +bool +BeaverTriple_areEqual(const_BeaverTriple t1, const_BeaverTriple t2) +{ + return (mp_cmp(&t1->a, &t2->a) == 0 && mp_cmp(&t1->b, &t2->b) == 0 && + mp_cmp(&t1->c, &t2->c) == 0); +} diff --git a/third_party/prio/prio/share.h b/third_party/prio/prio/share.h new file mode 100644 index 0000000000..2a8ad3b4c5 --- /dev/null +++ b/third_party/prio/prio/share.h @@ -0,0 +1,47 @@ +/* + * Copyright (c) 2018, Henry Corrigan-Gibbs + * + * This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this + * file, You can obtain one at http://mozilla.org/MPL/2.0/. + */ + +#ifndef __SHARE_H__ +#define __SHARE_H__ + +#include <mpi.h> + +#include "config.h" + +struct beaver_triple +{ + mp_int a; + mp_int b; + mp_int c; +}; + +typedef struct beaver_triple* BeaverTriple; +typedef const struct beaver_triple* const_BeaverTriple; + +/* + * Use secret sharing to split the int src into two shares. + * The mp_ints must be initialized. + */ +SECStatus share_int(const_PrioConfig cfg, const mp_int* src, mp_int* shareA, + mp_int* shareB); + +/* + * Prio uses Beaver triples to implement one step of the + * client data validation routine. A Beaver triple is just + * a sharing of random values a, b, c such that + * a * b = c + */ +BeaverTriple BeaverTriple_new(void); +void BeaverTriple_clear(BeaverTriple t); + +SECStatus BeaverTriple_set_rand(const_PrioConfig cfg, BeaverTriple triple_a, + BeaverTriple triple_b); + +bool BeaverTriple_areEqual(const_BeaverTriple t1, const_BeaverTriple t2); + +#endif /* __SHARE_H__ */ diff --git a/third_party/prio/prio/util.h b/third_party/prio/prio/util.h new file mode 100644 index 0000000000..a9a3a4626f --- /dev/null +++ b/third_party/prio/prio/util.h @@ -0,0 +1,112 @@ +/* + * Copyright (c) 2018, Henry Corrigan-Gibbs + * + * This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this + * file, You can obtain one at http://mozilla.org/MPL/2.0/. + */ + +#ifndef __UTIL_H__ +#define __UTIL_H__ + +#include <mpi.h> +#include <mprio.h> + +// Minimum of two values +#define MIN(a, b) ((a) < (b) ? (a) : (b)) + +// Check a Prio error code and return failure if the call fails. +#define P_CHECK(s) \ + do { \ + if ((rv = (s)) != SECSuccess) \ + return rv; \ + } while (0); + +// Check an allocation that should not return NULL. If the allocation returns +// NULL, set the return value and jump to the cleanup label to free memory. +#define P_CHECKA(s) \ + do { \ + if ((s) == NULL) { \ + rv = SECFailure; \ + goto cleanup; \ + } \ + } while (0); + +// Check a Prio library call that should return SECSuccess. If it doesn't, +// jump to the cleanup label. +#define P_CHECKC(s) \ + do { \ + if ((rv = (s)) != SECSuccess) { \ + goto cleanup; \ + } \ + } while (0); + +// Check a boolean that should be true. If it not, +// jump to the cleanup label. +#define P_CHECKCB(s) \ + do { \ + if (!(s)) { \ + rv = SECFailure; \ + goto cleanup; \ + } \ + } while (0); + +// Check an MPI library call and return failure if it fails. +#define MP_CHECK(s) \ + do { \ + if ((s) != MP_OKAY) \ + return SECFailure; \ + } while (0); + +// Check a msgpack object unpacked correctly. If +// not, jump to the cleanup label. +#define UP_CHECKC(s) \ + do { \ + int r = (s); \ + if (r != MSGPACK_UNPACK_SUCCESS && r != MSGPACK_UNPACK_EXTRA_BYTES) { \ + rv = SECFailure; \ + goto cleanup; \ + } \ + } while (0); + +// Check an MPI library call. If it fails, set the return code and jump +// to the cleanup label. +#define MP_CHECKC(s) \ + do { \ + if ((s) != MP_OKAY) { \ + rv = SECFailure; \ + goto cleanup; \ + } \ + } while (0); + +static inline int +next_power_of_two(int val) +{ + int i = val; + int out = 0; + for (; i > 0; i >>= 1) { + out++; + } + + int pow = 1 << out; + return (pow > 1 && pow / 2 == val) ? val : pow; +} + +/* + * Return a mask that masks out all of the zero bits + */ +static inline unsigned char +msb_mask(unsigned char val) +{ + unsigned char mask; + for (mask = 0x00; (val & mask) != val; mask = (mask << 1) + 1) + ; + return mask; +} + +/* + * Specify that a parameter should be unused. + */ +#define UNUSED(x) (void)(x) + +#endif /* __UTIL_H__ */ diff --git a/third_party/prio/update.sh b/third_party/prio/update.sh new file mode 100644 index 0000000000..70fc327d24 --- /dev/null +++ b/third_party/prio/update.sh @@ -0,0 +1,30 @@ +#!/bin/sh + +# Script to update the mozilla in-tree copy of the libprio library. +# Run this within the /third_party/libprio directory of the source tree. + +MY_TEMP_DIR=`mktemp -d -t libprio_update.XXXXXX` || exit 1 + +COMMIT="52643cefe6662b4099e16a40a057cb60651ab001" + +git clone -n https://github.com/mozilla/libprio ${MY_TEMP_DIR}/libprio +git -C ${MY_TEMP_DIR}/libprio checkout ${COMMIT} + +FILES="include prio" +VERSION=$(git -C ${MY_TEMP_DIR}/libprio describe --tags) +perl -p -i -e "s/Current version: \S+ \[commit [0-9a-f]{40}\]/Current version: ${VERSION} [commit ${COMMIT}]/" README-mozilla + +for f in $FILES; do + rm -rf $f + mv ${MY_TEMP_DIR}/libprio/$f $f +done + +rm -rf ${MY_TEMP_DIR} + +hg revert -r . moz.build +hg addremove . + +echo "###" +echo "### Updated libprio to $COMMIT." +echo "### Remember to verify and commit the changes to source control!" +echo "###" |