// SPDX-License-Identifier: MPL-2.0

use crate::{
    codec::{Encode, ParameterizedDecode},
    flp::Type,
    vdaf::{
        prg::Prg,
        prio3::{Prio3, Prio3InputShare, Prio3PrepareShare},
        Aggregator, PrepareTransition,
    },
};
use serde::{Deserialize, Serialize};
use std::{convert::TryInto, fmt::Debug};

/// An encoded byte string taken from a test vector file.
///
/// The wrapped bytes are (de)serialized through the `hex` module (see the `serde(with = "hex")`
/// attribute), so the JSON representation is presumably a hex string rather than a byte array.
#[derive(Debug, Deserialize, Serialize)]
struct TEncoded(#[serde(with = "hex")] Vec<u8>);

impl AsRef<[u8]> for TEncoded {
    fn as_ref(&self) -> &[u8] {
        &self.0
    }
}

/// One preparation test case of a Prio3 test vector, generic over the measurement type `M`.
///
/// The field descriptions below reflect how `check_prep_test_vec()` consumes them.
#[derive(Deserialize, Serialize)]
struct TPrio3Prep<M> {
    /// The measurement that gets sharded into input shares.
    measurement: M,
    /// The nonce passed to `prepare_init()`; (de)serialized through the `hex` module.
    #[serde(with = "hex")]
    nonce: Vec<u8>,
    /// Expected encoded input shares, indexed by aggregator ID. The checker requires exactly two.
    input_shares: Vec<TEncoded>,
    /// Expected encoded prepare shares. The checker requires exactly one outer entry and compares
    /// its elements against the prepare share of each aggregator, in aggregator order.
    prep_shares: Vec<Vec<TEncoded>>,
    /// Expected encoded prepare messages. The checker requires exactly one entry.
    prep_messages: Vec<TEncoded>,
    /// Expected output shares, one list per aggregator, with each element expressed as an `M`.
    out_shares: Vec<Vec<M>>,
}

/// Top-level layout of a Prio3 test vector file, generic over the measurement type `M`.
#[derive(Deserialize, Serialize)]
struct TPrio3<M> {
    /// The verification key; the tests convert it to a fixed-size byte array before use.
    verify_key: TEncoded,
    /// The preparation test cases; the tests run `check_prep_test_vec()` on each one in order.
    prep: Vec<TPrio3Prep<M>>,
}

/// Fails the running test with a uniformly formatted panic message.
///
/// Arguments, in order: the identifier of the variable holding the test number, the error that
/// was encountered, and a short description of the step that failed. The message printed is
/// `test #<number> failed: <description> err: <error>`.
macro_rules! err {
    ($vector_num:ident, $cause:expr, $step:expr) => {
        panic!("test #{} failed: {} err: {}", $vector_num, $step, $cause)
    };
}

// TODO Generalize this method to work with any VDAF. To do so we would need to add
// `test_vec_setup()` and `test_vec_shard()` to traits. (There may be a less invasive alternative.)
/// Replays one Prio3 preparation test case and compares every intermediate value against it.
///
/// The measurement in `t` is sharded with `prio3.test_vec_shard()`, each resulting input share is
/// passed to `prepare_init()` together with `verify_key` and the test case's nonce, the prepare
/// shares are combined with `prepare_preprocess()`, and finally every aggregator state is advanced
/// with `prepare_step()`. Input shares, prepare shares, the prepare message and the output shares
/// are each checked against the expected values stored in `t`.
///
/// `test_num` is only used to label failure messages.
///
/// # Panics
///
/// Panics if any step returns an error, if a computed value differs from the test case, if a
/// state does not finish after one `prepare_step()`, or if the test case does not have the
/// expected shape: exactly two input shares, exactly one round of prepare shares and prepare
/// messages, and as many prepare shares and output shares as were computed.
fn check_prep_test_vec<M, T, P, const L: usize>(
    prio3: &Prio3<T, P, L>,
    verify_key: &[u8; L],
    test_num: usize,
    t: &TPrio3Prep<M>,
) where
    T: Type<Measurement = M>,
    P: Prg<L>,
    M: From<<T as Type>::Field> + Debug + PartialEq,
{
    let input_shares = prio3
        .test_vec_shard(&t.measurement)
        .expect("failed to generate input shares");

    assert_eq!(2, t.input_shares.len(), "#{}", test_num);
    // Guard the pairwise comparison below: `zip` would otherwise stop at the shorter side.
    assert_eq!(input_shares.len(), t.input_shares.len(), "#{}", test_num);
    for (agg_id, (got, want)) in input_shares.iter().zip(t.input_shares.iter()).enumerate() {
        // Check both directions: decoding the expected bytes and encoding the computed share.
        assert_eq!(
            *got,
            Prio3InputShare::get_decoded_with_param(&(prio3, agg_id), want.as_ref())
                .unwrap_or_else(|e| err!(test_num, e, "decode test vector (input share)")),
            "#{}",
            test_num
        );
        assert_eq!(got.get_encoded(), want.as_ref(), "#{}", test_num);
    }

    let mut states = Vec::new();
    let mut prep_shares = Vec::new();
    for (agg_id, input_share) in input_shares.iter().enumerate() {
        let (state, prep_share) = prio3
            .prepare_init(verify_key, agg_id, &(), &t.nonce, &(), input_share)
            .unwrap_or_else(|e| err!(test_num, e, "prep state init"));
        states.push(state);
        prep_shares.push(prep_share);
    }

    assert_eq!(1, t.prep_shares.len(), "#{}", test_num);
    // Every computed prepare share must have a counterpart in the test case, and vice versa.
    assert_eq!(prep_shares.len(), t.prep_shares[0].len(), "#{}", test_num);
    for ((got, state), want) in prep_shares
        .iter()
        .zip(states.iter())
        .zip(t.prep_shares[0].iter())
    {
        assert_eq!(
            *got,
            Prio3PrepareShare::get_decoded_with_param(state, want.as_ref())
                .unwrap_or_else(|e| err!(test_num, e, "decode test vector (prep share)")),
            "#{}",
            test_num
        );
        assert_eq!(got.get_encoded(), want.as_ref(), "#{}", test_num);
    }

    let inbound = prio3
        .prepare_preprocess(prep_shares)
        .unwrap_or_else(|e| err!(test_num, e, "prep preprocess"));
    assert_eq!(t.prep_messages.len(), 1, "#{}", test_num);
    assert_eq!(
        inbound.get_encoded(),
        t.prep_messages[0].as_ref(),
        "#{}",
        test_num
    );

    let mut out_shares = Vec::new();
    // The states are not needed afterwards, so consume them instead of cloning each one.
    for state in states {
        match prio3
            .prepare_step(state, inbound.clone())
            .unwrap_or_else(|e| err!(test_num, e, "prep step"))
        {
            PrepareTransition::Finish(out_share) => {
                out_shares.push(out_share);
            }
            _ => panic!("test #{} failed: unexpected transition", test_num),
        }
    }

    // Without this check a test case with missing (or extra) output shares would pass silently,
    // because `zip` stops at the shorter of the two sequences.
    assert_eq!(out_shares.len(), t.out_shares.len(), "#{}", test_num);
    for (got, want) in out_shares.iter().zip(t.out_shares.iter()) {
        let got: Vec<M> = got.as_ref().iter().map(|x| M::from(*x)).collect();
        assert_eq!(&got, want, "#{}", test_num);
    }
}

/// Runs every preparation test case in `test_vec/03/Prio3Aes128Count_0.json` against a
/// two-aggregator `Prio3::new_aes128_count` instance.
#[test]
fn test_vec_prio3_count() {
    let json = include_str!("test_vec/03/Prio3Aes128Count_0.json");
    let test_vec: TPrio3<u64> = serde_json::from_str(json).unwrap();
    let vdaf = Prio3::new_aes128_count(2).unwrap();
    // The array length is inferred from the VDAF's verification key size.
    let key = test_vec.verify_key.as_ref().try_into().unwrap();

    test_vec
        .prep
        .iter()
        .enumerate()
        .for_each(|(test_num, prep)| check_prep_test_vec(&vdaf, &key, test_num, prep));
}

/// Runs every preparation test case in `test_vec/03/Prio3Aes128Sum_0.json` against a
/// `Prio3::new_aes128_sum(2, 8)` instance.
#[test]
fn test_vec_prio3_sum() {
    let json = include_str!("test_vec/03/Prio3Aes128Sum_0.json");
    let test_vec: TPrio3<u128> = serde_json::from_str(json).unwrap();
    let vdaf = Prio3::new_aes128_sum(2, 8).unwrap();
    // The array length is inferred from the VDAF's verification key size.
    let key = test_vec.verify_key.as_ref().try_into().unwrap();

    test_vec
        .prep
        .iter()
        .enumerate()
        .for_each(|(test_num, prep)| check_prep_test_vec(&vdaf, &key, test_num, prep));
}

/// Runs every preparation test case in `test_vec/03/Prio3Aes128Histogram_0.json` against a
/// `Prio3::new_aes128_histogram(2, &[1, 10, 100])` instance.
#[test]
fn test_vec_prio3_histogram() {
    let json = include_str!("test_vec/03/Prio3Aes128Histogram_0.json");
    let test_vec: TPrio3<u128> = serde_json::from_str(json).unwrap();
    let vdaf = Prio3::new_aes128_histogram(2, &[1, 10, 100]).unwrap();
    // The array length is inferred from the VDAF's verification key size.
    let key = test_vec.verify_key.as_ref().try_into().unwrap();

    test_vec
        .prep
        .iter()
        .enumerate()
        .for_each(|(test_num, prep)| check_prep_test_vec(&vdaf, &key, test_num, prep));
}