diff options
Diffstat (limited to 'third_party/prio/prio/server.c')
-rw-r--r-- | third_party/prio/prio/server.c | 509 |
1 files changed, 509 insertions, 0 deletions
diff --git a/third_party/prio/prio/server.c b/third_party/prio/prio/server.c new file mode 100644 index 0000000000..f2eede18da --- /dev/null +++ b/third_party/prio/prio/server.c @@ -0,0 +1,509 @@ +/* + * Copyright (c) 2018, Henry Corrigan-Gibbs + * + * This Source Code Form is subject to the terms of the Mozilla Public + * License, v. 2.0. If a copy of the MPL was not distributed with this + * file, You can obtain one at http://mozilla.org/MPL/2.0/. + */ + +#include <mpi.h> +#include <mprio.h> +#include <stdio.h> +#include <stdlib.h> + +#include "client.h" +#include "mparray.h" +#include "poly.h" +#include "prg.h" +#include "server.h" +#include "util.h" + +/* In `PrioTotalShare_final`, we need to be able to store + * an `mp_digit` in an `unsigned long long`. + */ +#if (MP_DIGIT_MAX > ULLONG_MAX) +#error "Unsigned long long is not long enough to hold an MP digit" +#endif + +PrioServer +PrioServer_new(const_PrioConfig cfg, PrioServerId server_idx, + PrivateKey server_priv, const PrioPRGSeed seed) +{ + SECStatus rv = SECSuccess; + PrioServer s = malloc(sizeof(*s)); + if (!s) + return NULL; + s->cfg = cfg; + s->idx = server_idx; + s->priv_key = server_priv; + s->data_shares = NULL; + s->prg = NULL; + + P_CHECKA(s->data_shares = MPArray_new(s->cfg->num_data_fields)); + P_CHECKA(s->prg = PRG_new(seed)); + +cleanup: + if (rv != SECSuccess) { + PrioServer_clear(s); + return NULL; + } + + return s; +} + +void +PrioServer_clear(PrioServer s) +{ + if (!s) + return; + + PRG_clear(s->prg); + MPArray_clear(s->data_shares); + free(s); +} + +SECStatus +PrioServer_aggregate(PrioServer s, PrioVerifier v) +{ + MPArray arr = NULL; + switch (s->idx) { + case PRIO_SERVER_A: + arr = v->clientp->shares.A.data_shares; + break; + case PRIO_SERVER_B: + arr = v->data_sharesB; + break; + default: + // Should never get here + return SECFailure; + } + + return MPArray_addmod(s->data_shares, arr, &s->cfg->modulus); +} + +PrioTotalShare +PrioTotalShare_new(void) +{ + PrioTotalShare t = malloc(sizeof(*t)); + if (!t) + return NULL; + + t->data_shares = MPArray_new(0); + if (!t->data_shares) { + free(t); + return NULL; + } + + return t; +} + +void +PrioTotalShare_clear(PrioTotalShare t) +{ + if (!t) + return; + MPArray_clear(t->data_shares); + free(t); +} + +SECStatus +PrioTotalShare_set_data(PrioTotalShare t, const_PrioServer s) +{ + t->idx = s->idx; + SECStatus rv = SECSuccess; + + P_CHECK(MPArray_resize(t->data_shares, s->data_shares->len)); + P_CHECK(MPArray_copy(t->data_shares, s->data_shares)); + + return rv; +} + +SECStatus +PrioTotalShare_final(const_PrioConfig cfg, unsigned long long* output, + const_PrioTotalShare tA, const_PrioTotalShare tB) +{ + if (tA->data_shares->len != cfg->num_data_fields) + return SECFailure; + if (tA->data_shares->len != tB->data_shares->len) + return SECFailure; + if (tA->idx != PRIO_SERVER_A || tB->idx != PRIO_SERVER_B) + return SECFailure; + + SECStatus rv = SECSuccess; + + mp_int tmp; + MP_DIGITS(&tmp) = NULL; + MP_CHECKC(mp_init(&tmp)); + + for (int i = 0; i < cfg->num_data_fields; i++) { + MP_CHECKC(mp_addmod(&tA->data_shares->data[i], &tB->data_shares->data[i], + &cfg->modulus, &tmp)); + + if (MP_USED(&tmp) > 1) { + P_CHECKCB(false); + } + output[i] = MP_DIGIT(&tmp, 0); + } + +cleanup: + mp_clear(&tmp); + return rv; +} + +inline static mp_int* +get_data_share(const_PrioVerifier v, int i) +{ + switch (v->s->idx) { + case PRIO_SERVER_A: + return &v->clientp->shares.A.data_shares->data[i]; + case PRIO_SERVER_B: + return &v->data_sharesB->data[i]; + } + // Should never get here + return NULL; +} + +inline static mp_int* +get_h_share(const_PrioVerifier v, int i) +{ + switch (v->s->idx) { + case PRIO_SERVER_A: + return &v->clientp->shares.A.h_points->data[i]; + case PRIO_SERVER_B: + return &v->h_pointsB->data[i]; + } + // Should never get here + return NULL; +} + +/* + * Build shares of the polynomials f, g, and h used in the Prio verification + * routine and evalute these polynomials at a random point determined + * by the shared secret. Store the evaluations in the verifier object. + */ +static SECStatus +compute_shares(PrioVerifier v, const_PrioPacketClient p) +{ + SECStatus rv; + const int n = v->s->cfg->num_data_fields + 1; + const int N = next_power_of_two(n); + mp_int eval_at; + mp_int lower; + MP_DIGITS(&eval_at) = NULL; + MP_DIGITS(&lower) = NULL; + + MPArray points_f = NULL; + MPArray points_g = NULL; + MPArray points_h = NULL; + + MP_CHECKC(mp_init(&eval_at)); + MP_CHECKC(mp_init(&lower)); + P_CHECKA(points_f = MPArray_new(N)); + P_CHECKA(points_g = MPArray_new(N)); + P_CHECKA(points_h = MPArray_new(2 * N)); + + // Use PRG to generate random point. Per Appendix D.2 of full version of + // Prio paper, this value must be in the range + // [n+1, modulus). + mp_set(&lower, n + 1); + P_CHECKC(PRG_get_int_range(v->s->prg, &eval_at, &lower, &v->s->cfg->modulus)); + + // Reduce value into the field we're using. This + // doesn't yield exactly a uniformly random point, + // but for values this large, it will be close + // enough. + MP_CHECKC(mp_mod(&eval_at, &v->s->cfg->modulus, &eval_at)); + + // Client sends us the values of f(0) and g(0) + MP_CHECKC(mp_copy(&p->f0_share, &points_f->data[0])); + MP_CHECKC(mp_copy(&p->g0_share, &points_g->data[0])); + MP_CHECKC(mp_copy(&p->h0_share, &points_h->data[0])); + + for (int i = 1; i < n; i++) { + // [f](i) = i-th data share + const mp_int* data_i_minus_1 = get_data_share(v, i - 1); + MP_CHECKC(mp_copy(data_i_minus_1, &points_f->data[i])); + + // [g](i) = i-th data share minus 1 + // Only need to shift the share for 0-th server + MP_CHECKC(mp_copy(&points_f->data[i], &points_g->data[i])); + if (!v->s->idx) { + MP_CHECKC(mp_sub_d(&points_g->data[i], 1, &points_g->data[i])); + MP_CHECKC( + mp_mod(&points_g->data[i], &v->s->cfg->modulus, &points_g->data[i])); + } + } + + int j = 0; + for (int i = 1; i < 2 * N; i += 2) { + const mp_int* h_point_j = get_h_share(v, j++); + MP_CHECKC(mp_copy(h_point_j, &points_h->data[i])); + } + + P_CHECKC(poly_interp_evaluate(&v->share_fR, points_f, &eval_at, v->s->cfg)); + P_CHECKC(poly_interp_evaluate(&v->share_gR, points_g, &eval_at, v->s->cfg)); + P_CHECKC(poly_interp_evaluate(&v->share_hR, points_h, &eval_at, v->s->cfg)); + +cleanup: + MPArray_clear(points_f); + MPArray_clear(points_g); + MPArray_clear(points_h); + mp_clear(&eval_at); + mp_clear(&lower); + return rv; +} + +PrioVerifier +PrioVerifier_new(PrioServer s) +{ + SECStatus rv = SECSuccess; + PrioVerifier v = malloc(sizeof *v); + if (!v) + return NULL; + + v->s = s; + v->clientp = NULL; + v->data_sharesB = NULL; + v->h_pointsB = NULL; + + MP_DIGITS(&v->share_fR) = NULL; + MP_DIGITS(&v->share_gR) = NULL; + MP_DIGITS(&v->share_hR) = NULL; + + MP_CHECKC(mp_init(&v->share_fR)); + MP_CHECKC(mp_init(&v->share_gR)); + MP_CHECKC(mp_init(&v->share_hR)); + + P_CHECKA(v->clientp = PrioPacketClient_new(s->cfg, s->idx)); + + const int N = next_power_of_two(s->cfg->num_data_fields + 1); + if (v->s->idx == PRIO_SERVER_B) { + P_CHECKA(v->data_sharesB = MPArray_new(v->s->cfg->num_data_fields)); + P_CHECKA(v->h_pointsB = MPArray_new(N)); + } + +cleanup: + if (rv != SECSuccess) { + PrioVerifier_clear(v); + return NULL; + } + + return v; +} + +SECStatus +PrioVerifier_set_data(PrioVerifier v, unsigned char* data, + unsigned int data_len) +{ + SECStatus rv = SECSuccess; + PRG prgB = NULL; + P_CHECKC(PrioPacketClient_decrypt(v->clientp, v->s->cfg, v->s->priv_key, data, + data_len)); + + PrioPacketClient p = v->clientp; + if (p->for_server != v->s->idx) + return SECFailure; + + const int N = next_power_of_two(v->s->cfg->num_data_fields + 1); + if (v->s->idx == PRIO_SERVER_A) { + // Check that packet has the correct number of data fields + if (p->shares.A.data_shares->len != v->s->cfg->num_data_fields) + return SECFailure; + if (p->shares.A.h_points->len != N) + return SECFailure; + } + + if (v->s->idx == PRIO_SERVER_B) { + P_CHECKA(prgB = PRG_new(v->clientp->shares.B.seed)); + P_CHECKC(PRG_get_array(prgB, v->data_sharesB, &v->s->cfg->modulus)); + P_CHECKC(PRG_get_array(prgB, v->h_pointsB, &v->s->cfg->modulus)); + } + + // TODO: This can be done much faster by using the combined + // interpolate-and-evaluate optimization described in the + // Prio paper. + // + // Compute share of f(r), g(r), h(r) + P_CHECKC(compute_shares(v, p)); + +cleanup: + + PRG_clear(prgB); + return rv; +} + +void +PrioVerifier_clear(PrioVerifier v) +{ + if (v == NULL) + return; + PrioPacketClient_clear(v->clientp); + MPArray_clear(v->data_sharesB); + MPArray_clear(v->h_pointsB); + mp_clear(&v->share_fR); + mp_clear(&v->share_gR); + mp_clear(&v->share_hR); + free(v); +} + +PrioPacketVerify1 +PrioPacketVerify1_new(void) +{ + SECStatus rv = SECSuccess; + PrioPacketVerify1 p = malloc(sizeof *p); + if (!p) + return NULL; + + MP_DIGITS(&p->share_d) = NULL; + MP_DIGITS(&p->share_e) = NULL; + + MP_CHECKC(mp_init(&p->share_d)); + MP_CHECKC(mp_init(&p->share_e)); + +cleanup: + if (rv != SECSuccess) { + PrioPacketVerify1_clear(p); + return NULL; + } + + return p; +} + +void +PrioPacketVerify1_clear(PrioPacketVerify1 p) +{ + if (!p) + return; + mp_clear(&p->share_d); + mp_clear(&p->share_e); + free(p); +} + +SECStatus +PrioPacketVerify1_set_data(PrioPacketVerify1 p1, const_PrioVerifier v) +{ + // See the Prio paper for details on how this works. + // Appendix C descrives the MPC protocol used here. + + SECStatus rv = SECSuccess; + + // Compute corrections. + // [d] = [f(r)] - [a] + MP_CHECK(mp_sub(&v->share_fR, &v->clientp->triple->a, &p1->share_d)); + MP_CHECK(mp_mod(&p1->share_d, &v->s->cfg->modulus, &p1->share_d)); + + // [e] = [g(r)] - [b] + MP_CHECK(mp_sub(&v->share_gR, &v->clientp->triple->b, &p1->share_e)); + MP_CHECK(mp_mod(&p1->share_e, &v->s->cfg->modulus, &p1->share_e)); + + return rv; +} + +PrioPacketVerify2 +PrioPacketVerify2_new(void) +{ + SECStatus rv = SECSuccess; + PrioPacketVerify2 p = malloc(sizeof *p); + if (!p) + return NULL; + + MP_DIGITS(&p->share_out) = NULL; + MP_CHECKC(mp_init(&p->share_out)); + +cleanup: + if (rv != SECSuccess) { + PrioPacketVerify2_clear(p); + return NULL; + } + return p; +} + +void +PrioPacketVerify2_clear(PrioPacketVerify2 p) +{ + if (!p) + return; + mp_clear(&p->share_out); + free(p); +} + +SECStatus +PrioPacketVerify2_set_data(PrioPacketVerify2 p2, const_PrioVerifier v, + const_PrioPacketVerify1 p1A, + const_PrioPacketVerify1 p1B) +{ + SECStatus rv = SECSuccess; + + mp_int d, e, tmp; + MP_DIGITS(&d) = NULL; + MP_DIGITS(&e) = NULL; + MP_DIGITS(&tmp) = NULL; + + MP_CHECKC(mp_init(&d)); + MP_CHECKC(mp_init(&e)); + MP_CHECKC(mp_init(&tmp)); + + const mp_int* mod = &v->s->cfg->modulus; + + // Compute share of f(r)*g(r) + // [f(r)*g(r)] = [d*e/2] + d[b] + e[a] + [c] + + // Compute d + MP_CHECKC(mp_addmod(&p1A->share_d, &p1B->share_d, mod, &d)); + // Compute e + MP_CHECKC(mp_addmod(&p1A->share_e, &p1B->share_e, mod, &e)); + + // Compute d*e + MP_CHECKC(mp_mulmod(&d, &e, mod, &p2->share_out)); + // out = d*e/2 + MP_CHECKC(mp_mulmod(&p2->share_out, &v->s->cfg->inv2, mod, &p2->share_out)); + + // Compute d[b] + MP_CHECKC(mp_mulmod(&d, &v->clientp->triple->b, mod, &tmp)); + // out = d*e/2 + d[b] + MP_CHECKC(mp_addmod(&p2->share_out, &tmp, mod, &p2->share_out)); + + // Compute e[a] + MP_CHECKC(mp_mulmod(&e, &v->clientp->triple->a, mod, &tmp)); + // out = d*e/2 + d[b] + e[a] + MP_CHECKC(mp_addmod(&p2->share_out, &tmp, mod, &p2->share_out)); + + // out = d*e/2 + d[b] + e[a] + [c] + MP_CHECKC( + mp_addmod(&p2->share_out, &v->clientp->triple->c, mod, &p2->share_out)); + + // We want to compute f(r)*g(r) - h(r), + // so subtract off [h(r)]: + // out = d*e/2 + d[b] + e[a] + [c] - [h(r)] + MP_CHECKC(mp_sub(&p2->share_out, &v->share_hR, &p2->share_out)); + MP_CHECKC(mp_mod(&p2->share_out, mod, &p2->share_out)); + +cleanup: + mp_clear(&d); + mp_clear(&e); + mp_clear(&tmp); + return rv; +} + +int +PrioVerifier_isValid(const_PrioVerifier v, const_PrioPacketVerify2 pA, + const_PrioPacketVerify2 pB) +{ + SECStatus rv = SECSuccess; + mp_int res; + MP_DIGITS(&res) = NULL; + MP_CHECKC(mp_init(&res)); + + // Add up the shares of the output wire value and + // ensure that the sum is equal to zero, which indicates + // that + // f(r) * g(r) == h(r). + MP_CHECKC( + mp_addmod(&pA->share_out, &pB->share_out, &v->s->cfg->modulus, &res)); + + rv = (mp_cmp_d(&res, 0) == 0) ? SECSuccess : SECFailure; + +cleanup: + mp_clear(&res); + return rv; +} |