diff options
Diffstat (limited to 'third_party/rust/prio')
59 files changed, 23048 insertions, 0 deletions
diff --git a/third_party/rust/prio/.cargo-checksum.json b/third_party/rust/prio/.cargo-checksum.json new file mode 100644 index 0000000000..2d8ae36c4a --- /dev/null +++ b/third_party/rust/prio/.cargo-checksum.json @@ -0,0 +1 @@ +{"files":{"Cargo.toml":"9ae63800fb2fb7ad25b5206e1c6aedba07ebeadce625fa3d3c214391c754c434","LICENSE":"5f5a5db8d4baa0eea0ff2d32a5a86c7a899a3343f1496f4477f42e2d651cc6dc","README.md":"8b36e999c5b65c83fcc77d85e78e690438c08afbb8292640cfc7dbc5164dcfcd","benches/cycle_counts.rs":"40d2ce27e6df99b9d7f4913b39d29a7e60c8efa3718218d125f2691b559b0b77","benches/speed_tests.rs":"2ef52920b677cb0292adf3d5f6bf06f57e678b510550acc99e51b33e8a1bdb44","documentation/releases.md":"14cfe917c88b69d557badc683b887c734254810402c7e19c9a45d815637480a9","src/benchmarked.rs":"a1ca26229967c55ba0d000ed5a594ba42ba2a27af0b1f5e37168f00cb62ad562","src/codec.rs":"2f5763612d19659c7f3f806b94795276c032da7a8e0b41687b355e32b99929da","src/dp.rs":"683de654a0a4bfd51741fdb03f17a3e2a15d11633857623dae3a5df199cfccb6","src/dp/distributions.rs":"8cedb5d44aa8bed9635d8791ec0f05a42fabff95f63d59f50bf9c7764d0c68bd","src/fft.rs":"5900d31263cf16eec9d462eb75458c2582d4c502312d478f87652317ec9f1e9f","src/field.rs":"6e29dcad7ce386e6a0f549214a6a14335118e177e8a2a9d6b1eafbf1fb865ac1","src/field/field255.rs":"6472c11354a1a75d733436607739e69051d5ee15a3f2477bee22d3dd8c80b686","src/flp.rs":"e961b9000e1d332a98975bfbfead93f26af56333338f6102ae09c0a5cdb87987","src/flp/gadgets.rs":"ddf641b0c06348b2a74bd579f994ce8f3ba387f4fa5ec3ca5e3a1260237a1615","src/flp/types.rs":"cb57d1f71b73d19893bfc5e86077b8bdaf66f1cc62b6a49d78ac15737580f45d","src/flp/types/fixedpoint_l2.rs":"22c96b2d0f3bf57155a6acddeee9b6164c3521e99cd12f6b2dd5c206543765d6","src/flp/types/fixedpoint_l2/compatible_float.rs":"8bcfc8ccb1e3ef32fdabbbac40dd7ffb9af10ab9ed10d9cd8c02cd55faa8814f","src/fp.rs":"a70f2961bd9df6a6b9ad683924ed27e4d44fc67408681d7b68fba2c4594edba4","src/idpf.rs":"6a81395351d8beca8f38774f883af23e161067f32ea9f9c177a135b00b139b3c","src/lib.rs":"bb5a25d5d1b2b0f916acd66e27bbcd16022965110f46292edf512973191a9497","src/polynomial.rs":"8e91a56de1a9eaf7eb1f01857e1e1834f768525cadc3bbe02f4f4ff44deeaef8","src/prng.rs":"bcf0125f0fe82c17fca417f068d49c5e08e0e7159a2d5e8a156bb52702d070cc","src/topology/mod.rs":"09a6d07c598d7bb52164bb5f82ab5ceddba5a192454828a789516d621e85c557","src/topology/ping_pong.rs":"89b4b4968d2c3b322ba8ca72101bb026426f5c4f10efb4a7895cb0c4a1488d03","src/vdaf.rs":"fe507b7644288e4912ac3c561a77c390cb85005b889fdcb471392a5886683607","src/vdaf/dummy.rs":"386e7e163625a1c53a1a95dd47c14324c206872080dbe76e90ec5d93f90d39ae","src/vdaf/poplar1.rs":"78af3fe04ef038ec263073522b1c93e0e14a1ab08d6ef1136dd4aa46cb19f667","src/vdaf/prio2.rs":"97d425fcb0a8a15f600ee202e4da7be68ee04e329e8e9f4d50cecd7557213662","src/vdaf/prio2/client.rs":"90c529400c3be77461b02d5cd78c52e98626e83e149b677da711097c9b17e806","src/vdaf/prio2/server.rs":"a5ebc32d92849d2a565888e18c63eac36d2e854f2e51c98ebc04983837465235","src/vdaf/prio2/test_vector.rs":"507cb6d05e40270786eb2c123866a18f7ee586e7c022161bf8fd47cb472507a9","src/vdaf/prio3.rs":"0c60f16188604013bd894e9dddfc8263f65d9b882f7e6831d43c14e05b389af9","src/vdaf/prio3_test.rs":"e2c7e1fd6f20ea3c886a2bfb5a63262f41c9a419e0e419d0b808f585b55b730e","src/vdaf/test_vec/07/IdpfPoplar_0.json":"f4f23df90774d7ac74c0bd984e25db7c1a2c53705adf30a693f388922f2c5d38","src/vdaf/test_vec/07/Poplar1_0.json":"8a5a44e85c5c08bf54a656a8692cad962d08bf9ef49e69f04ca4646de3cc1a40","src/vdaf/test_vec/07/Poplar1_1.json":"1caff76c31ce637baca4adc62f63b4d0d887a860bba8f59ffdef5409f5a0988e","src/vdaf/test_vec/07/Poplar1_2.json":"439eb1ea543929b127c6f77b1109f2455f649a9a84b0335d3dd5177e75fe521f","src/vdaf/test_vec/07/Poplar1_3.json":"2864fe4000934fa4035dcb5914da804ddbd3e7e125c459efc5f400b41b5b6b55","src/vdaf/test_vec/07/Prio3Count_0.json":"33e41769b1e11376276dbfcc6670100673ea3ed887baaaf5c37c6f389167c26c","src/vdaf/test_vec/07/Prio3Count_1.json":"483f53318116847bd43d4b6040ef2f00436f643c05eae153144285e8a032990c","src/vdaf/test_vec/07/Prio3Histogram_0.json":"f71a736c57eb0811b34e2806dc7c6005b67fc62fffb0779a263f1030a09ec747","src/vdaf/test_vec/07/Prio3Histogram_1.json":"3d67c8547fe69baa6ba8d90c3fd307003f9ebe04984c39343dcd8e794a6bf5d8","src/vdaf/test_vec/07/Prio3SumVec_0.json":"9d0ff391d954073ccbb798d382ba8701131687410e6d754c3bce36e9c8f650df","src/vdaf/test_vec/07/Prio3SumVec_1.json":"83f1abe06fc76a9f628be112efc8168243ce4713b878fe7d6ffe2b5b17a87382","src/vdaf/test_vec/07/Prio3Sum_0.json":"732093776a144bf9a8e71ae3179ae0fd74e5db3787b31603dceeb6e557d9d814","src/vdaf/test_vec/07/Prio3Sum_1.json":"a4250943b8a33d4d19652bd8a7312059ea92e3f8b1d541439b64bff996bc1bf4","src/vdaf/test_vec/07/XofFixedKeyAes128.json":"ff40fc42eec15a483bd02e798ddab6c38d81a335f1fe5703290ced88d7ceae26","src/vdaf/test_vec/07/XofShake128.json":"e68b4769f4fb9be9a8d9d5bf823636ed97c6e1f68bcd794c7576ba8b42eeba9a","src/vdaf/test_vec/prio2/fieldpriov2.json":"7ba82fcf068dfd5b04fc7fd01ebe5626ea9740be9cd6aa631525f23b80bcd027","src/vdaf/xof.rs":"175a4b0077bf295aeee4f42c70b55684079f4ef8c1b580b90ac0af8914f05fc9","tests/discrete_gauss.rs":"ddae145ac52ff8cd2b2134266733fee6fd9378bfb68c943a28e5a0a314a1ceb7","tests/test_vectors/discrete_gauss_100.json":"0f056accac870bf79539c82492f7b1f30d69835fe780bb83e0a15b0f0b318bc3","tests/test_vectors/discrete_gauss_2.342.json":"2ce62090c800c786b02ad409c4de18a94858105cd4859d7f595fbdf7ab79e8e1","tests/test_vectors/discrete_gauss_3.json":"b088841eef4cba2b287f04d3707a7efd33c11a2d40a71c8d953e8231150dbe6e","tests/test_vectors/discrete_gauss_41293847.json":"5896a922e313ce15d1353f0bdbf1f55413cc6a8116531cfc1fd9f3f7000e9002","tests/test_vectors/discrete_gauss_9.json":"4ae1c1195d752b6db5b3be47e7e662c2070aac1146e5b9ce9398c0c93cec21d2","tests/test_vectors/discrete_gauss_9999999999999999999999.json":"d65ec163c7bcbd8091c748e659eba460a9c56b37bc7612fd4a96394e58480b23"},"package":"b3163d19b7d8bc08c7ab6b74510f5e048c0937509d14c28b8919d2baf8cb9387"}
\ No newline at end of file diff --git a/third_party/rust/prio/Cargo.toml b/third_party/rust/prio/Cargo.toml new file mode 100644 index 0000000000..6e7bcd7e6e --- /dev/null +++ b/third_party/rust/prio/Cargo.toml @@ -0,0 +1,212 @@ +# THIS FILE IS AUTOMATICALLY GENERATED BY CARGO +# +# When uploading crates to the registry Cargo will automatically +# "normalize" Cargo.toml files for maximal compatibility +# with all versions of Cargo and also rewrite `path` dependencies +# to registry (e.g., crates.io) dependencies. +# +# If you are reading this file be aware that the original Cargo.toml +# will likely look very different (and much more reasonable). +# See Cargo.toml.orig for the original contents. + +[package] +edition = "2021" +rust-version = "1.64" +name = "prio" +version = "0.15.3" +authors = [ + "Josh Aas <jaas@kflag.net>", + "Tim Geoghegan <timg@letsencrypt.org>", + "Christopher Patton <cpatton@cloudflare.com", + "Karl Tarbe <tarbe@apple.com>", +] +exclude = ["/supply-chain"] +description = "Implementation of the Prio aggregation system core: https://crypto.stanford.edu/prio/" +readme = "README.md" +license = "MPL-2.0" +repository = "https://github.com/divviup/libprio-rs" +resolver = "2" + +[package.metadata.cargo-all-features] +skip_optional_dependencies = true + +[package.metadata.docs.rs] +all-features = true +rustdoc-args = [ + "--cfg", + "docsrs", +] + +[lib] +bench = false + +[[test]] +name = "discrete_gauss" +path = "tests/discrete_gauss.rs" +required-features = ["experimental"] + +[[bench]] +name = "speed_tests" +harness = false + +[[bench]] +name = "cycle_counts" +harness = false + +[dependencies.aes] +version = "0.8.3" +optional = true + +[dependencies.bitvec] +version = "1.0.1" +optional = true + +[dependencies.byteorder] +version = "1.4.3" + +[dependencies.ctr] +version = "0.9.2" +optional = true + +[dependencies.fiat-crypto] +version = "0.2.1" +optional = true + +[dependencies.fixed] +version = "1.23" +optional = true + +[dependencies.getrandom] +version = "0.2.10" +features = ["std"] + +[dependencies.hmac] +version = "0.12.1" +optional = true + +[dependencies.num-bigint] +version = "0.4.4" +features = [ + "rand", + "serde", +] +optional = true + +[dependencies.num-integer] +version = "0.1.45" +optional = true + +[dependencies.num-iter] +version = "0.1.43" +optional = true + +[dependencies.num-rational] +version = "0.4.1" +features = ["serde"] +optional = true + +[dependencies.num-traits] +version = "0.2.16" +optional = true + +[dependencies.rand] +version = "0.8" +optional = true + +[dependencies.rand_core] +version = "0.6.4" + +[dependencies.rayon] +version = "1.8.0" +optional = true + +[dependencies.serde] +version = "1.0" +features = ["derive"] + +[dependencies.sha2] +version = "0.10.7" +optional = true + +[dependencies.sha3] +version = "0.10.8" + +[dependencies.subtle] +version = "2.5.0" + +[dependencies.thiserror] +version = "1.0" + +[dev-dependencies.assert_matches] +version = "1.5.0" + +[dev-dependencies.base64] +version = "0.21.4" + +[dev-dependencies.cfg-if] +version = "1.0.0" + +[dev-dependencies.criterion] +version = "0.5" + +[dev-dependencies.fixed-macro] +version = "1.2.0" + +[dev-dependencies.hex] +version = "0.4.3" +features = ["serde"] + +[dev-dependencies.hex-literal] +version = "0.4.1" + +[dev-dependencies.iai] +version = "0.1" + +[dev-dependencies.itertools] +version = "0.11.0" + +[dev-dependencies.modinverse] +version = "0.1.0" + +[dev-dependencies.num-bigint] +version = "0.4.4" + +[dev-dependencies.once_cell] +version = "1.18.0" + +[dev-dependencies.rand] +version = "0.8" + +[dev-dependencies.serde_json] +version = "1.0" + +[dev-dependencies.statrs] +version = "0.16.0" + +[dev-dependencies.zipf] +version = "7.0.1" + +[features] +crypto-dependencies = [ + "aes", + "ctr", +] +default = ["crypto-dependencies"] +experimental = [ + "bitvec", + "fiat-crypto", + "fixed", + "num-bigint", + "num-rational", + "num-traits", + "num-integer", + "num-iter", + "rand", +] +multithreaded = ["rayon"] +prio2 = [ + "crypto-dependencies", + "hmac", + "sha2", +] +test-util = ["rand"] diff --git a/third_party/rust/prio/LICENSE b/third_party/rust/prio/LICENSE new file mode 100644 index 0000000000..0e880dfe44 --- /dev/null +++ b/third_party/rust/prio/LICENSE @@ -0,0 +1,375 @@ +Copyright 2021 ISRG, except where otherwise noted. All rights reserved. + +Mozilla Public License Version 2.0 +================================== + +1. Definitions +-------------- + +1.1. "Contributor" + means each individual or legal entity that creates, contributes to + the creation of, or owns Covered Software. + +1.2. "Contributor Version" + means the combination of the Contributions of others (if any) used + by a Contributor and that particular Contributor's Contribution. + +1.3. "Contribution" + means Covered Software of a particular Contributor. + +1.4. "Covered Software" + means Source Code Form to which the initial Contributor has attached + the notice in Exhibit A, the Executable Form of such Source Code + Form, and Modifications of such Source Code Form, in each case + including portions thereof. + +1.5. "Incompatible With Secondary Licenses" + means + + (a) that the initial Contributor has attached the notice described + in Exhibit B to the Covered Software; or + + (b) that the Covered Software was made available under the terms of + version 1.1 or earlier of the License, but not also under the + terms of a Secondary License. + +1.6. "Executable Form" + means any form of the work other than Source Code Form. + +1.7. "Larger Work" + means a work that combines Covered Software with other material, in + a separate file or files, that is not Covered Software. + +1.8. "License" + means this document. + +1.9. "Licensable" + means having the right to grant, to the maximum extent possible, + whether at the time of the initial grant or subsequently, any and + all of the rights conveyed by this License. + +1.10. "Modifications" + means any of the following: + + (a) any file in Source Code Form that results from an addition to, + deletion from, or modification of the contents of Covered + Software; or + + (b) any new file in Source Code Form that contains any Covered + Software. + +1.11. "Patent Claims" of a Contributor + means any patent claim(s), including without limitation, method, + process, and apparatus claims, in any patent Licensable by such + Contributor that would be infringed, but for the grant of the + License, by the making, using, selling, offering for sale, having + made, import, or transfer of either its Contributions or its + Contributor Version. + +1.12. "Secondary License" + means either the GNU General Public License, Version 2.0, the GNU + Lesser General Public License, Version 2.1, the GNU Affero General + Public License, Version 3.0, or any later versions of those + licenses. + +1.13. "Source Code Form" + means the form of the work preferred for making modifications. + +1.14. "You" (or "Your") + means an individual or a legal entity exercising rights under this + License. For legal entities, "You" includes any entity that + controls, is controlled by, or is under common control with You. For + purposes of this definition, "control" means (a) the power, direct + or indirect, to cause the direction or management of such entity, + whether by contract or otherwise, or (b) ownership of more than + fifty percent (50%) of the outstanding shares or beneficial + ownership of such entity. + +2. License Grants and Conditions +-------------------------------- + +2.1. Grants + +Each Contributor hereby grants You a world-wide, royalty-free, +non-exclusive license: + +(a) under intellectual property rights (other than patent or trademark) + Licensable by such Contributor to use, reproduce, make available, + modify, display, perform, distribute, and otherwise exploit its + Contributions, either on an unmodified basis, with Modifications, or + as part of a Larger Work; and + +(b) under Patent Claims of such Contributor to make, use, sell, offer + for sale, have made, import, and otherwise transfer either its + Contributions or its Contributor Version. + +2.2. Effective Date + +The licenses granted in Section 2.1 with respect to any Contribution +become effective for each Contribution on the date the Contributor first +distributes such Contribution. + +2.3. Limitations on Grant Scope + +The licenses granted in this Section 2 are the only rights granted under +this License. No additional rights or licenses will be implied from the +distribution or licensing of Covered Software under this License. +Notwithstanding Section 2.1(b) above, no patent license is granted by a +Contributor: + +(a) for any code that a Contributor has removed from Covered Software; + or + +(b) for infringements caused by: (i) Your and any other third party's + modifications of Covered Software, or (ii) the combination of its + Contributions with other software (except as part of its Contributor + Version); or + +(c) under Patent Claims infringed by Covered Software in the absence of + its Contributions. + +This License does not grant any rights in the trademarks, service marks, +or logos of any Contributor (except as may be necessary to comply with +the notice requirements in Section 3.4). + +2.4. Subsequent Licenses + +No Contributor makes additional grants as a result of Your choice to +distribute the Covered Software under a subsequent version of this +License (see Section 10.2) or under the terms of a Secondary License (if +permitted under the terms of Section 3.3). + +2.5. Representation + +Each Contributor represents that the Contributor believes its +Contributions are its original creation(s) or it has sufficient rights +to grant the rights to its Contributions conveyed by this License. + +2.6. Fair Use + +This License is not intended to limit any rights You have under +applicable copyright doctrines of fair use, fair dealing, or other +equivalents. + +2.7. Conditions + +Sections 3.1, 3.2, 3.3, and 3.4 are conditions of the licenses granted +in Section 2.1. + +3. Responsibilities +------------------- + +3.1. Distribution of Source Form + +All distribution of Covered Software in Source Code Form, including any +Modifications that You create or to which You contribute, must be under +the terms of this License. You must inform recipients that the Source +Code Form of the Covered Software is governed by the terms of this +License, and how they can obtain a copy of this License. You may not +attempt to alter or restrict the recipients' rights in the Source Code +Form. + +3.2. Distribution of Executable Form + +If You distribute Covered Software in Executable Form then: + +(a) such Covered Software must also be made available in Source Code + Form, as described in Section 3.1, and You must inform recipients of + the Executable Form how they can obtain a copy of such Source Code + Form by reasonable means in a timely manner, at a charge no more + than the cost of distribution to the recipient; and + +(b) You may distribute such Executable Form under the terms of this + License, or sublicense it under different terms, provided that the + license for the Executable Form does not attempt to limit or alter + the recipients' rights in the Source Code Form under this License. + +3.3. Distribution of a Larger Work + +You may create and distribute a Larger Work under terms of Your choice, +provided that You also comply with the requirements of this License for +the Covered Software. If the Larger Work is a combination of Covered +Software with a work governed by one or more Secondary Licenses, and the +Covered Software is not Incompatible With Secondary Licenses, this +License permits You to additionally distribute such Covered Software +under the terms of such Secondary License(s), so that the recipient of +the Larger Work may, at their option, further distribute the Covered +Software under the terms of either this License or such Secondary +License(s). + +3.4. Notices + +You may not remove or alter the substance of any license notices +(including copyright notices, patent notices, disclaimers of warranty, +or limitations of liability) contained within the Source Code Form of +the Covered Software, except that You may alter any license notices to +the extent required to remedy known factual inaccuracies. + +3.5. Application of Additional Terms + +You may choose to offer, and to charge a fee for, warranty, support, +indemnity or liability obligations to one or more recipients of Covered +Software. However, You may do so only on Your own behalf, and not on +behalf of any Contributor. You must make it absolutely clear that any +such warranty, support, indemnity, or liability obligation is offered by +You alone, and You hereby agree to indemnify every Contributor for any +liability incurred by such Contributor as a result of warranty, support, +indemnity or liability terms You offer. You may include additional +disclaimers of warranty and limitations of liability specific to any +jurisdiction. + +4. Inability to Comply Due to Statute or Regulation +--------------------------------------------------- + +If it is impossible for You to comply with any of the terms of this +License with respect to some or all of the Covered Software due to +statute, judicial order, or regulation then You must: (a) comply with +the terms of this License to the maximum extent possible; and (b) +describe the limitations and the code they affect. Such description must +be placed in a text file included with all distributions of the Covered +Software under this License. Except to the extent prohibited by statute +or regulation, such description must be sufficiently detailed for a +recipient of ordinary skill to be able to understand it. + +5. Termination +-------------- + +5.1. The rights granted under this License will terminate automatically +if You fail to comply with any of its terms. However, if You become +compliant, then the rights granted under this License from a particular +Contributor are reinstated (a) provisionally, unless and until such +Contributor explicitly and finally terminates Your grants, and (b) on an +ongoing basis, if such Contributor fails to notify You of the +non-compliance by some reasonable means prior to 60 days after You have +come back into compliance. Moreover, Your grants from a particular +Contributor are reinstated on an ongoing basis if such Contributor +notifies You of the non-compliance by some reasonable means, this is the +first time You have received notice of non-compliance with this License +from such Contributor, and You become compliant prior to 30 days after +Your receipt of the notice. + +5.2. If You initiate litigation against any entity by asserting a patent +infringement claim (excluding declaratory judgment actions, +counter-claims, and cross-claims) alleging that a Contributor Version +directly or indirectly infringes any patent, then the rights granted to +You by any and all Contributors for the Covered Software under Section +2.1 of this License shall terminate. + +5.3. In the event of termination under Sections 5.1 or 5.2 above, all +end user license agreements (excluding distributors and resellers) which +have been validly granted by You or Your distributors under this License +prior to termination shall survive termination. + +************************************************************************ +* * +* 6. Disclaimer of Warranty * +* ------------------------- * +* * +* Covered Software is provided under this License on an "as is" * +* basis, without warranty of any kind, either expressed, implied, or * +* statutory, including, without limitation, warranties that the * +* Covered Software is free of defects, merchantable, fit for a * +* particular purpose or non-infringing. The entire risk as to the * +* quality and performance of the Covered Software is with You. * +* Should any Covered Software prove defective in any respect, You * +* (not any Contributor) assume the cost of any necessary servicing, * +* repair, or correction. This disclaimer of warranty constitutes an * +* essential part of this License. No use of any Covered Software is * +* authorized under this License except under this disclaimer. * +* * +************************************************************************ + +************************************************************************ +* * +* 7. Limitation of Liability * +* -------------------------- * +* * +* Under no circumstances and under no legal theory, whether tort * +* (including negligence), contract, or otherwise, shall any * +* Contributor, or anyone who distributes Covered Software as * +* permitted above, be liable to You for any direct, indirect, * +* special, incidental, or consequential damages of any character * +* including, without limitation, damages for lost profits, loss of * +* goodwill, work stoppage, computer failure or malfunction, or any * +* and all other commercial damages or losses, even if such party * +* shall have been informed of the possibility of such damages. This * +* limitation of liability shall not apply to liability for death or * +* personal injury resulting from such party's negligence to the * +* extent applicable law prohibits such limitation. Some * +* jurisdictions do not allow the exclusion or limitation of * +* incidental or consequential damages, so this exclusion and * +* limitation may not apply to You. * +* * +************************************************************************ + +8. Litigation +------------- + +Any litigation relating to this License may be brought only in the +courts of a jurisdiction where the defendant maintains its principal +place of business and such litigation shall be governed by laws of that +jurisdiction, without reference to its conflict-of-law provisions. +Nothing in this Section shall prevent a party's ability to bring +cross-claims or counter-claims. + +9. Miscellaneous +---------------- + +This License represents the complete agreement concerning the subject +matter hereof. If any provision of this License is held to be +unenforceable, such provision shall be reformed only to the extent +necessary to make it enforceable. Any law or regulation which provides +that the language of a contract shall be construed against the drafter +shall not be used to construe this License against a Contributor. + +10. Versions of the License +--------------------------- + +10.1. New Versions + +Mozilla Foundation is the license steward. Except as provided in Section +10.3, no one other than the license steward has the right to modify or +publish new versions of this License. Each version will be given a +distinguishing version number. + +10.2. Effect of New Versions + +You may distribute the Covered Software under the terms of the version +of the License under which You originally received the Covered Software, +or under the terms of any subsequent version published by the license +steward. + +10.3. Modified Versions + +If you create software not governed by this License, and you want to +create a new license for such software, you may create and use a +modified version of this License if you rename the license and remove +any references to the name of the license steward (except to note that +such modified license differs from this License). + +10.4. Distributing Source Code Form that is Incompatible With Secondary +Licenses + +If You choose to distribute Source Code Form that is Incompatible With +Secondary Licenses under the terms of this version of the License, the +notice described in Exhibit B of this License must be attached. + +Exhibit A - Source Code Form License Notice +------------------------------------------- + + This Source Code Form is subject to the terms of the Mozilla Public + License, v. 2.0. If a copy of the MPL was not distributed with this + file, You can obtain one at http://mozilla.org/MPL/2.0/. + +If it is not possible or desirable to put the notice in a particular +file, then You may include the notice in a location (such as a LICENSE +file in a relevant directory) where a recipient would be likely to look +for such a notice. + +You may add additional accurate notices of copyright ownership. + +Exhibit B - "Incompatible With Secondary Licenses" Notice +--------------------------------------------------------- + + This Source Code Form is "Incompatible With Secondary Licenses", as + defined by the Mozilla Public License, v. 2.0. diff --git a/third_party/rust/prio/README.md b/third_party/rust/prio/README.md new file mode 100644 index 0000000000..b81a0150c2 --- /dev/null +++ b/third_party/rust/prio/README.md @@ -0,0 +1,76 @@ +# libprio-rs +[![Build Status]][actions] [![Latest Version]][crates.io] [![Docs badge]][docs.rs] + +[Build Status]: https://github.com/divviup/libprio-rs/workflows/ci-build/badge.svg +[actions]: https://github.com/divviup/libprio-rs/actions?query=branch%3Amain +[Latest Version]: https://img.shields.io/crates/v/prio.svg +[crates.io]: https://crates.io/crates/prio +[Docs badge]: https://img.shields.io/badge/docs.rs-rustdoc-green +[docs.rs]: https://docs.rs/prio/ + +Pure Rust implementation of [Prio](https://crypto.stanford.edu/prio/), a system for Private, Robust, +and Scalable Computation of Aggregate Statistics. + +## Exposure Notifications Private Analytics + +This crate was used in the [Exposure Notifications Private Analytics][enpa] +system. This is referred to in various places as Prio v2. See +[`prio-server`][prio-server] or the [ENPA whitepaper][enpa-whitepaper] for more +details. + +## Verifiable Distributed Aggregation Function + +This crate also implements a [Verifiable Distributed Aggregation Function +(VDAF)][vdaf] called "Prio3", implemented in the `vdaf` module, allowing Prio to +be used in the [Distributed Aggregation Protocol][dap] protocol being developed +in the PPM working group at the IETF. This support is still evolving along with +the DAP and VDAF specifications. + +### Draft versions and release branches + +The `main` branch is under continuous development and will usually be partway between VDAF drafts. +libprio uses stable release branches to maintain implementations of different VDAF draft versions. +Crate `prio` version `x.y.z` is released from a corresponding `release/x.y` branch. We try to +maintain [Rust SemVer][semver] compatibility, meaning that API breaks only happen on minor version +increases (e.g., 0.10 to 0.11). + +| Crate version | Git branch | VDAF draft version | DAP draft version | Conforms to specification? | Status | +| ----- | ---------- | ------------- | ------------- | --------------------- | ------ | +| 0.8 | `release/0.8` | [`draft-irtf-cfrg-vdaf-01`][vdaf-01] | [`draft-ietf-ppm-dap-01`][dap-01] | Yes | Unmaintained as of March 28, 2023 | +| 0.9 | `release/0.9` | [`draft-irtf-cfrg-vdaf-03`][vdaf-03] | [`draft-ietf-ppm-dap-02`][dap-02] and [`draft-ietf-ppm-dap-03`][dap-03] | Yes | Unmaintained as of September 22, 2022 | +| 0.10 | `release/0.10` | [`draft-irtf-cfrg-vdaf-03`][vdaf-03] | [`draft-ietf-ppm-dap-02`][dap-02] and [`draft-ietf-ppm-dap-03`][dap-03] | Yes | Supported | +| 0.11 | `release/0.11` | [`draft-irtf-cfrg-vdaf-04`][vdaf-04] | N/A | Yes | Unmaintained | +| 0.12 | `release/0.12` | [`draft-irtf-cfrg-vdaf-05`][vdaf-05] | [`draft-ietf-ppm-dap-04`][dap-04] | Yes | Supported | +| 0.13 | `release/0.13` | [`draft-irtf-cfrg-vdaf-06`][vdaf-06] | [`draft-ietf-ppm-dap-05`][dap-05] | Yes | Unmaintained | +| 0.14 | `release/0.14` | [`draft-irtf-cfrg-vdaf-06`][vdaf-06] | [`draft-ietf-ppm-dap-05`][dap-05] | Yes | Unmaintained | +| 0.15 | `main` | [`draft-irtf-cfrg-vdaf-07`][vdaf-07] | [`draft-ietf-ppm-dap-06`][dap-06] | Yes | Supported | + +[vdaf-01]: https://datatracker.ietf.org/doc/draft-irtf-cfrg-vdaf/01/ +[vdaf-03]: https://datatracker.ietf.org/doc/draft-irtf-cfrg-vdaf/03/ +[vdaf-04]: https://datatracker.ietf.org/doc/draft-irtf-cfrg-vdaf/04/ +[vdaf-05]: https://datatracker.ietf.org/doc/draft-irtf-cfrg-vdaf/05/ +[vdaf-06]: https://datatracker.ietf.org/doc/draft-irtf-cfrg-vdaf/06/ +[vdaf-07]: https://datatracker.ietf.org/doc/draft-irtf-cfrg-vdaf/07/ +[dap-01]: https://datatracker.ietf.org/doc/draft-ietf-ppm-dap/01/ +[dap-02]: https://datatracker.ietf.org/doc/draft-ietf-ppm-dap/02/ +[dap-03]: https://datatracker.ietf.org/doc/draft-ietf-ppm-dap/03/ +[dap-04]: https://datatracker.ietf.org/doc/draft-ietf-ppm-dap/04/ +[dap-05]: https://datatracker.ietf.org/doc/draft-ietf-ppm-dap/05/ +[dap-06]: https://datatracker.ietf.org/doc/draft-ietf-ppm-dap/06/ +[enpa]: https://www.abetterinternet.org/post/prio-services-for-covid-en/ +[enpa-whitepaper]: https://covid19-static.cdn-apple.com/applications/covid19/current/static/contact-tracing/pdf/ENPA_White_Paper.pdf +[prio-server]: https://github.com/divviup/prio-server +[vdaf]: https://datatracker.ietf.org/doc/draft-irtf-cfrg-vdaf/ +[dap]: https://datatracker.ietf.org/doc/draft-ietf-ppm-dap/ +[semver]: https://doc.rust-lang.org/cargo/reference/semver.html + +## Cargo Features + +This crate defines the following feature flags: + +|Name|Default feature?|Description| +|---|---|---| +|`crypto-dependencies`|Yes|Enables dependencies on various RustCrypto crates, and uses them to implement `XofShake128` to support VDAFs.| +|`experimental`|No|Certain experimental APIs are guarded by this feature. They may undergo breaking changes in future patch releases, as an exception to semantic versioning.| +|`multithreaded`|No|Enables certain Prio3 VDAF implementations that use `rayon` for parallelization of gadget evaluations.| +|`prio2`|No|Enables the Prio v2 API, and a VDAF based on the Prio2 system.| diff --git a/third_party/rust/prio/benches/cycle_counts.rs b/third_party/rust/prio/benches/cycle_counts.rs new file mode 100644 index 0000000000..43a4ccdad0 --- /dev/null +++ b/third_party/rust/prio/benches/cycle_counts.rs @@ -0,0 +1,341 @@ +#![cfg_attr(windows, allow(dead_code))] + +use cfg_if::cfg_if; +use iai::black_box; +#[cfg(feature = "experimental")] +use prio::{ + codec::{Decode, Encode, ParameterizedDecode}, + field::{Field255, FieldElement}, + idpf::{Idpf, IdpfInput, IdpfPublicShare, RingBufferCache}, + vdaf::{poplar1::Poplar1IdpfValue, xof::Seed}, +}; +#[cfg(feature = "prio2")] +use prio::{ + field::FieldPrio2, + vdaf::{ + prio2::{Prio2, Prio2PrepareShare}, + Aggregator, Share, + }, +}; +use prio::{ + field::{random_vector, Field128, Field64}, + vdaf::{ + prio3::{Prio3, Prio3InputShare}, + Client, + }, +}; + +fn prng(size: usize) -> Vec<Field128> { + random_vector(size).unwrap() +} + +fn prng_16() -> Vec<Field128> { + prng(16) +} + +fn prng_256() -> Vec<Field128> { + prng(256) +} + +fn prng_1024() -> Vec<Field128> { + prng(1024) +} + +fn prng_4096() -> Vec<Field128> { + prng(4096) +} + +#[cfg(feature = "prio2")] +fn prio2_client(size: usize) -> Vec<Share<FieldPrio2, 32>> { + let prio2 = Prio2::new(size).unwrap(); + let input = vec![0u32; size]; + let nonce = [0; 16]; + prio2.shard(&black_box(input), &black_box(nonce)).unwrap().1 +} + +#[cfg(feature = "prio2")] +fn prio2_client_10() -> Vec<Share<FieldPrio2, 32>> { + prio2_client(10) +} + +#[cfg(feature = "prio2")] +fn prio2_client_100() -> Vec<Share<FieldPrio2, 32>> { + prio2_client(100) +} + +#[cfg(feature = "prio2")] +fn prio2_client_1000() -> Vec<Share<FieldPrio2, 32>> { + prio2_client(1000) +} + +#[cfg(feature = "prio2")] +fn prio2_shard_and_prepare(size: usize) -> Prio2PrepareShare { + let prio2 = Prio2::new(size).unwrap(); + let input = vec![0u32; size]; + let nonce = [0; 16]; + let (public_share, input_shares) = prio2.shard(&black_box(input), &black_box(nonce)).unwrap(); + prio2 + .prepare_init(&[0; 32], 0, &(), &nonce, &public_share, &input_shares[0]) + .unwrap() + .1 +} + +#[cfg(feature = "prio2")] +fn prio2_shard_and_prepare_10() -> Prio2PrepareShare { + prio2_shard_and_prepare(10) +} + +#[cfg(feature = "prio2")] +fn prio2_shard_and_prepare_100() -> Prio2PrepareShare { + prio2_shard_and_prepare(100) +} + +#[cfg(feature = "prio2")] +fn prio2_shard_and_prepare_1000() -> Prio2PrepareShare { + prio2_shard_and_prepare(1000) +} + +fn prio3_client_count() -> Vec<Prio3InputShare<Field64, 16>> { + let prio3 = Prio3::new_count(2).unwrap(); + let measurement = 1; + let nonce = [0; 16]; + prio3 + .shard(&black_box(measurement), &black_box(nonce)) + .unwrap() + .1 +} + +fn prio3_client_histogram_10() -> Vec<Prio3InputShare<Field128, 16>> { + let prio3 = Prio3::new_histogram(2, 10, 3).unwrap(); + let measurement = 9; + let nonce = [0; 16]; + prio3 + .shard(&black_box(measurement), &black_box(nonce)) + .unwrap() + .1 +} + +fn prio3_client_sum_32() -> Vec<Prio3InputShare<Field128, 16>> { + let prio3 = Prio3::new_sum(2, 16).unwrap(); + let measurement = 1337; + let nonce = [0; 16]; + prio3 + .shard(&black_box(measurement), &black_box(nonce)) + .unwrap() + .1 +} + +fn prio3_client_count_vec_1000() -> Vec<Prio3InputShare<Field128, 16>> { + let len = 1000; + let prio3 = Prio3::new_sum_vec(2, 1, len, 31).unwrap(); + let measurement = vec![0; len]; + let nonce = [0; 16]; + prio3 + .shard(&black_box(measurement), &black_box(nonce)) + .unwrap() + .1 +} + +#[cfg(feature = "multithreaded")] +fn prio3_client_count_vec_multithreaded_1000() -> Vec<Prio3InputShare<Field128, 16>> { + let len = 1000; + let prio3 = Prio3::new_sum_vec_multithreaded(2, 1, len, 31).unwrap(); + let measurement = vec![0; len]; + let nonce = [0; 16]; + prio3 + .shard(&black_box(measurement), &black_box(nonce)) + .unwrap() + .1 +} + +#[cfg(feature = "experimental")] +fn idpf_poplar_gen( + input: &IdpfInput, + inner_values: Vec<Poplar1IdpfValue<Field64>>, + leaf_value: Poplar1IdpfValue<Field255>, +) { + let idpf = Idpf::new((), ()); + idpf.gen(input, inner_values, leaf_value, &[0; 16]).unwrap(); +} + +#[cfg(feature = "experimental")] +fn idpf_poplar_gen_8() { + let input = IdpfInput::from_bytes(b"A"); + let one = Field64::one(); + idpf_poplar_gen( + &input, + vec![Poplar1IdpfValue::new([one, one]); 7], + Poplar1IdpfValue::new([Field255::one(), Field255::one()]), + ); +} + +#[cfg(feature = "experimental")] +fn idpf_poplar_gen_128() { + let input = IdpfInput::from_bytes(b"AAAAAAAAAAAAAAAA"); + let one = Field64::one(); + idpf_poplar_gen( + &input, + vec![Poplar1IdpfValue::new([one, one]); 127], + Poplar1IdpfValue::new([Field255::one(), Field255::one()]), + ); +} + +#[cfg(feature = "experimental")] +fn idpf_poplar_gen_2048() { + let input = IdpfInput::from_bytes(&[0x41; 256]); + let one = Field64::one(); + idpf_poplar_gen( + &input, + vec![Poplar1IdpfValue::new([one, one]); 2047], + Poplar1IdpfValue::new([Field255::one(), Field255::one()]), + ); +} + +#[cfg(feature = "experimental")] +fn idpf_poplar_eval( + input: &IdpfInput, + public_share: &IdpfPublicShare<Poplar1IdpfValue<Field64>, Poplar1IdpfValue<Field255>>, + key: &Seed<16>, +) { + let mut cache = RingBufferCache::new(1); + let idpf = Idpf::new((), ()); + idpf.eval(0, public_share, key, input, &[0; 16], &mut cache) + .unwrap(); +} + +#[cfg(feature = "experimental")] +fn idpf_poplar_eval_8() { + let input = IdpfInput::from_bytes(b"A"); + let public_share = IdpfPublicShare::get_decoded_with_param(&8, &[0x7f; 306]).unwrap(); + let key = Seed::get_decoded(&[0xff; 16]).unwrap(); + idpf_poplar_eval(&input, &public_share, &key); +} + +#[cfg(feature = "experimental")] +fn idpf_poplar_eval_128() { + let input = IdpfInput::from_bytes(b"AAAAAAAAAAAAAAAA"); + let public_share = IdpfPublicShare::get_decoded_with_param(&128, &[0x7f; 4176]).unwrap(); + let key = Seed::get_decoded(&[0xff; 16]).unwrap(); + idpf_poplar_eval(&input, &public_share, &key); +} + +#[cfg(feature = "experimental")] +fn idpf_poplar_eval_2048() { + let input = IdpfInput::from_bytes(&[0x41; 256]); + let public_share = IdpfPublicShare::get_decoded_with_param(&2048, &[0x7f; 66096]).unwrap(); + let key = Seed::get_decoded(&[0xff; 16]).unwrap(); + idpf_poplar_eval(&input, &public_share, &key); +} + +#[cfg(feature = "experimental")] +fn idpf_codec() { + let data = hex::decode(concat!( + "9a", + "0000000000000000000000000000000000000000000000", + "01eb3a1bd6b5fa4a4500000000000000000000000000000000", + "ffffffff0000000022522c3fd5a33cac00000000000000000000000000000000", + "ffffffff0000000069f41eee46542b6900000000000000000000000000000000", + "00000000000000000000000000000000000000000000000000000000000000", + "017d1fd6df94280145a0dcc933ceb706e9219d50e7c4f92fd8ca9a0ffb7d819646", + )) + .unwrap(); + let bits = 4; + let public_share = IdpfPublicShare::<Poplar1IdpfValue<Field64>, Poplar1IdpfValue<Field255>>::get_decoded_with_param(&bits, &data).unwrap(); + let encoded = public_share.get_encoded(); + let _ = black_box(encoded.len()); +} + +macro_rules! main_base { + ( $( $func_name:ident ),* $(,)* ) => { + iai::main!( + prng_16, + prng_256, + prng_1024, + prng_4096, + prio3_client_count, + prio3_client_histogram_10, + prio3_client_sum_32, + prio3_client_count_vec_1000, + $( $func_name, )* + ); + }; +} + +#[cfg(feature = "prio2")] +macro_rules! main_add_prio2 { + ( $( $func_name:ident ),* $(,)* ) => { + main_base!( + prio2_client_10, + prio2_client_100, + prio2_client_1000, + prio2_shard_and_prepare_10, + prio2_shard_and_prepare_100, + prio2_shard_and_prepare_1000, + $( $func_name, )* + ); + }; +} + +#[cfg(not(feature = "prio2"))] +macro_rules! main_add_prio2 { + ( $( $func_name:ident ),* $(,)* ) => { + main_base!( + $( $func_name, )* + ); + }; +} + +#[cfg(feature = "multithreaded")] +macro_rules! main_add_multithreaded { + ( $( $func_name:ident ),* $(,)* ) => { + main_add_prio2!( + prio3_client_count_vec_multithreaded_1000, + $( $func_name, )* + ); + }; +} + +#[cfg(not(feature = "multithreaded"))] +macro_rules! main_add_multithreaded { + ( $( $func_name:ident ),* $(,)* ) => { + main_add_prio2!( + $( $func_name, )* + ); + }; +} + +#[cfg(feature = "experimental")] +macro_rules! main_add_experimental { + ( $( $func_name:ident ),* $(,)* ) => { + main_add_multithreaded!( + idpf_codec, + idpf_poplar_gen_8, + idpf_poplar_gen_128, + idpf_poplar_gen_2048, + idpf_poplar_eval_8, + idpf_poplar_eval_128, + idpf_poplar_eval_2048, + $( $func_name, )* + ); + }; +} + +#[cfg(not(feature = "experimental"))] +macro_rules! main_add_experimental { + ( $( $func_name:ident ),* $(,)* ) => { + main_add_multithreaded!( + $( $func_name, )* + ); + }; +} + +cfg_if! { + if #[cfg(windows)] { + fn main() { + eprintln!("Cycle count benchmarks are not supported on Windows."); + } + } + else { + main_add_experimental!(); + } +} diff --git a/third_party/rust/prio/benches/speed_tests.rs b/third_party/rust/prio/benches/speed_tests.rs new file mode 100644 index 0000000000..66458b1ada --- /dev/null +++ b/third_party/rust/prio/benches/speed_tests.rs @@ -0,0 +1,876 @@ +// SPDX-License-Identifier: MPL-2.0 + +use criterion::{black_box, criterion_group, criterion_main, BenchmarkId, Criterion}; +#[cfg(feature = "experimental")] +use criterion::{BatchSize, Throughput}; +#[cfg(feature = "experimental")] +use fixed::types::{I1F15, I1F31}; +#[cfg(feature = "experimental")] +use fixed_macro::fixed; +#[cfg(feature = "experimental")] +use num_bigint::BigUint; +#[cfg(feature = "experimental")] +use num_rational::Ratio; +#[cfg(feature = "experimental")] +use num_traits::ToPrimitive; +#[cfg(feature = "experimental")] +use prio::dp::distributions::DiscreteGaussian; +#[cfg(feature = "prio2")] +use prio::vdaf::prio2::Prio2; +use prio::{ + benchmarked::*, + field::{random_vector, Field128 as F, FieldElement}, + flp::gadgets::Mul, + vdaf::{prio3::Prio3, Aggregator, Client}, +}; +#[cfg(feature = "experimental")] +use prio::{ + field::{Field255, Field64}, + flp::types::fixedpoint_l2::FixedPointBoundedL2VecSum, + idpf::{Idpf, IdpfInput, RingBufferCache}, + vdaf::poplar1::{Poplar1, Poplar1AggregationParam, Poplar1IdpfValue}, +}; +#[cfg(feature = "experimental")] +use rand::prelude::*; +#[cfg(feature = "experimental")] +use std::iter; +use std::time::Duration; +#[cfg(feature = "experimental")] +use zipf::ZipfDistribution; + +/// Seed for generation of random benchmark inputs. +/// +/// A fixed RNG seed is used to generate inputs in order to minimize run-to-run variability. The +/// seed value may be freely changed to get a different set of inputs. +#[cfg(feature = "experimental")] +const RNG_SEED: u64 = 0; + +/// Speed test for generating a seed and deriving a pseudorandom sequence of field elements. +fn prng(c: &mut Criterion) { + let mut group = c.benchmark_group("rand"); + let test_sizes = [16, 256, 1024, 4096]; + for size in test_sizes { + group.bench_with_input(BenchmarkId::from_parameter(size), &size, |b, size| { + b.iter(|| random_vector::<F>(*size)) + }); + } + group.finish(); +} + +/// Speed test for generating samples from the discrete gaussian distribution using different +/// standard deviations. +#[cfg(feature = "experimental")] +pub fn dp_noise(c: &mut Criterion) { + let mut group = c.benchmark_group("dp_noise"); + let mut rng = StdRng::seed_from_u64(RNG_SEED); + + let test_stds = [ + Ratio::<BigUint>::from_integer(BigUint::from(u128::MAX)).pow(2), + Ratio::<BigUint>::from_integer(BigUint::from(u64::MAX)), + Ratio::<BigUint>::from_integer(BigUint::from(u32::MAX)), + Ratio::<BigUint>::from_integer(BigUint::from(5u8)), + Ratio::<BigUint>::new(BigUint::from(10000u32), BigUint::from(23u32)), + ]; + for std in test_stds { + let sampler = DiscreteGaussian::new(std.clone()).unwrap(); + group.bench_function( + BenchmarkId::new("discrete_gaussian", std.to_f64().unwrap_or(f64::INFINITY)), + |b| b.iter(|| sampler.sample(&mut rng)), + ); + } + group.finish(); +} + +/// The asymptotic cost of polynomial multiplication is `O(n log n)` using FFT and `O(n^2)` using +/// the naive method. This benchmark demonstrates that the latter has better concrete performance +/// for small polynomials. The result is used to pick the `FFT_THRESHOLD` constant in +/// `src/flp/gadgets.rs`. +fn poly_mul(c: &mut Criterion) { + let test_sizes = [1_usize, 30, 60, 90, 120, 150]; + + let mut group = c.benchmark_group("poly_mul"); + for size in test_sizes { + group.bench_with_input(BenchmarkId::new("fft", size), &size, |b, size| { + let m = (size + 1).next_power_of_two(); + let mut g: Mul<F> = Mul::new(*size); + let mut outp = vec![F::zero(); 2 * m]; + let inp = vec![random_vector(m).unwrap(); 2]; + + b.iter(|| { + benchmarked_gadget_mul_call_poly_fft(&mut g, &mut outp, &inp).unwrap(); + }) + }); + + group.bench_with_input(BenchmarkId::new("direct", size), &size, |b, size| { + let m = (size + 1).next_power_of_two(); + let mut g: Mul<F> = Mul::new(*size); + let mut outp = vec![F::zero(); 2 * m]; + let inp = vec![random_vector(m).unwrap(); 2]; + + b.iter(|| { + benchmarked_gadget_mul_call_poly_direct(&mut g, &mut outp, &inp).unwrap(); + }) + }); + } + group.finish(); +} + +/// Benchmark prio2. +#[cfg(feature = "prio2")] +fn prio2(c: &mut Criterion) { + let mut group = c.benchmark_group("prio2_shard"); + for input_length in [10, 100, 1_000] { + group.bench_with_input( + BenchmarkId::from_parameter(input_length), + &input_length, + |b, input_length| { + let vdaf = Prio2::new(*input_length).unwrap(); + let measurement = (0..u32::try_from(*input_length).unwrap()) + .map(|i| i & 1) + .collect::<Vec<_>>(); + let nonce = black_box([0u8; 16]); + b.iter(|| vdaf.shard(&measurement, &nonce).unwrap()); + }, + ); + } + group.finish(); + + let mut group = c.benchmark_group("prio2_prepare_init"); + for input_length in [10, 100, 1_000] { + group.bench_with_input( + BenchmarkId::from_parameter(input_length), + &input_length, + |b, input_length| { + let vdaf = Prio2::new(*input_length).unwrap(); + let measurement = (0..u32::try_from(*input_length).unwrap()) + .map(|i| i & 1) + .collect::<Vec<_>>(); + let nonce = black_box([0u8; 16]); + let verify_key = black_box([0u8; 32]); + let (public_share, input_shares) = vdaf.shard(&measurement, &nonce).unwrap(); + b.iter(|| { + vdaf.prepare_init(&verify_key, 0, &(), &nonce, &public_share, &input_shares[0]) + .unwrap(); + }); + }, + ); + } + group.finish(); +} + +/// Benchmark prio3. +fn prio3(c: &mut Criterion) { + let num_shares = 2; + + c.bench_function("prio3count_shard", |b| { + let vdaf = Prio3::new_count(num_shares).unwrap(); + let measurement = black_box(1); + let nonce = black_box([0u8; 16]); + b.iter(|| vdaf.shard(&measurement, &nonce).unwrap()); + }); + + c.bench_function("prio3count_prepare_init", |b| { + let vdaf = Prio3::new_count(num_shares).unwrap(); + let measurement = black_box(1); + let nonce = black_box([0u8; 16]); + let verify_key = black_box([0u8; 16]); + let (public_share, input_shares) = vdaf.shard(&measurement, &nonce).unwrap(); + b.iter(|| { + vdaf.prepare_init(&verify_key, 0, &(), &nonce, &public_share, &input_shares[0]) + .unwrap() + }); + }); + + let mut group = c.benchmark_group("prio3sum_shard"); + for bits in [8, 32] { + group.bench_with_input(BenchmarkId::from_parameter(bits), &bits, |b, bits| { + let vdaf = Prio3::new_sum(num_shares, *bits).unwrap(); + let measurement = (1 << bits) - 1; + let nonce = black_box([0u8; 16]); + b.iter(|| vdaf.shard(&measurement, &nonce).unwrap()); + }); + } + group.finish(); + + let mut group = c.benchmark_group("prio3sum_prepare_init"); + for bits in [8, 32] { + group.bench_with_input(BenchmarkId::from_parameter(bits), &bits, |b, bits| { + let vdaf = Prio3::new_sum(num_shares, *bits).unwrap(); + let measurement = (1 << bits) - 1; + let nonce = black_box([0u8; 16]); + let verify_key = black_box([0u8; 16]); + let (public_share, input_shares) = vdaf.shard(&measurement, &nonce).unwrap(); + b.iter(|| { + vdaf.prepare_init(&verify_key, 0, &(), &nonce, &public_share, &input_shares[0]) + .unwrap() + }); + }); + } + group.finish(); + + let mut group = c.benchmark_group("prio3sumvec_shard"); + for (input_length, chunk_length) in [(10, 3), (100, 10), (1_000, 31)] { + group.bench_with_input( + BenchmarkId::new("serial", input_length), + &(input_length, chunk_length), + |b, (input_length, chunk_length)| { + let vdaf = Prio3::new_sum_vec(num_shares, 1, *input_length, *chunk_length).unwrap(); + let measurement = (0..u128::try_from(*input_length).unwrap()) + .map(|i| i & 1) + .collect::<Vec<_>>(); + let nonce = black_box([0u8; 16]); + b.iter(|| vdaf.shard(&measurement, &nonce).unwrap()); + }, + ); + } + + #[cfg(feature = "multithreaded")] + { + for (input_length, chunk_length) in [(10, 3), (100, 10), (1_000, 31)] { + group.bench_with_input( + BenchmarkId::new("parallel", input_length), + &(input_length, chunk_length), + |b, (input_length, chunk_length)| { + let vdaf = Prio3::new_sum_vec_multithreaded( + num_shares, + 1, + *input_length, + *chunk_length, + ) + .unwrap(); + let measurement = (0..u128::try_from(*input_length).unwrap()) + .map(|i| i & 1) + .collect::<Vec<_>>(); + let nonce = black_box([0u8; 16]); + b.iter(|| vdaf.shard(&measurement, &nonce).unwrap()); + }, + ); + } + } + group.finish(); + + let mut group = c.benchmark_group("prio3sumvec_prepare_init"); + for (input_length, chunk_length) in [(10, 3), (100, 10), (1_000, 31)] { + group.bench_with_input( + BenchmarkId::new("serial", input_length), + &(input_length, chunk_length), + |b, (input_length, chunk_length)| { + let vdaf = Prio3::new_sum_vec(num_shares, 1, *input_length, *chunk_length).unwrap(); + let measurement = (0..u128::try_from(*input_length).unwrap()) + .map(|i| i & 1) + .collect::<Vec<_>>(); + let nonce = black_box([0u8; 16]); + let verify_key = black_box([0u8; 16]); + let (public_share, input_shares) = vdaf.shard(&measurement, &nonce).unwrap(); + b.iter(|| { + vdaf.prepare_init(&verify_key, 0, &(), &nonce, &public_share, &input_shares[0]) + .unwrap() + }); + }, + ); + } + + #[cfg(feature = "multithreaded")] + { + for (input_length, chunk_length) in [(10, 3), (100, 10), (1_000, 31)] { + group.bench_with_input( + BenchmarkId::new("parallel", input_length), + &(input_length, chunk_length), + |b, (input_length, chunk_length)| { + let vdaf = Prio3::new_sum_vec_multithreaded( + num_shares, + 1, + *input_length, + *chunk_length, + ) + .unwrap(); + let measurement = (0..u128::try_from(*input_length).unwrap()) + .map(|i| i & 1) + .collect::<Vec<_>>(); + let nonce = black_box([0u8; 16]); + let verify_key = black_box([0u8; 16]); + let (public_share, input_shares) = vdaf.shard(&measurement, &nonce).unwrap(); + b.iter(|| { + vdaf.prepare_init( + &verify_key, + 0, + &(), + &nonce, + &public_share, + &input_shares[0], + ) + .unwrap() + }); + }, + ); + } + } + group.finish(); + + let mut group = c.benchmark_group("prio3histogram_shard"); + for (input_length, chunk_length) in [ + (10, 3), + (100, 10), + (1_000, 31), + (10_000, 100), + (100_000, 316), + ] { + if input_length >= 100_000 { + group.measurement_time(Duration::from_secs(15)); + } + group.bench_with_input( + BenchmarkId::new("serial", input_length), + &(input_length, chunk_length), + |b, (input_length, chunk_length)| { + let vdaf = Prio3::new_histogram(num_shares, *input_length, *chunk_length).unwrap(); + let measurement = black_box(0); + let nonce = black_box([0u8; 16]); + b.iter(|| vdaf.shard(&measurement, &nonce).unwrap()); + }, + ); + } + + #[cfg(feature = "multithreaded")] + { + for (input_length, chunk_length) in [ + (10, 3), + (100, 10), + (1_000, 31), + (10_000, 100), + (100_000, 316), + ] { + if input_length >= 100_000 { + group.measurement_time(Duration::from_secs(15)); + } + group.bench_with_input( + BenchmarkId::new("parallel", input_length), + &(input_length, chunk_length), + |b, (input_length, chunk_length)| { + let vdaf = Prio3::new_histogram_multithreaded( + num_shares, + *input_length, + *chunk_length, + ) + .unwrap(); + let measurement = black_box(0); + let nonce = black_box([0u8; 16]); + b.iter(|| vdaf.shard(&measurement, &nonce).unwrap()); + }, + ); + } + } + group.finish(); + + let mut group = c.benchmark_group("prio3histogram_prepare_init"); + for (input_length, chunk_length) in [ + (10, 3), + (100, 10), + (1_000, 31), + (10_000, 100), + (100_000, 316), + ] { + if input_length >= 100_000 { + group.measurement_time(Duration::from_secs(15)); + } + group.bench_with_input( + BenchmarkId::new("serial", input_length), + &(input_length, chunk_length), + |b, (input_length, chunk_length)| { + let vdaf = Prio3::new_histogram(num_shares, *input_length, *chunk_length).unwrap(); + let measurement = black_box(0); + let nonce = black_box([0u8; 16]); + let verify_key = black_box([0u8; 16]); + let (public_share, input_shares) = vdaf.shard(&measurement, &nonce).unwrap(); + b.iter(|| { + vdaf.prepare_init(&verify_key, 0, &(), &nonce, &public_share, &input_shares[0]) + .unwrap() + }); + }, + ); + } + + #[cfg(feature = "multithreaded")] + { + for (input_length, chunk_length) in [ + (10, 3), + (100, 10), + (1_000, 31), + (10_000, 100), + (100_000, 316), + ] { + if input_length >= 100_000 { + group.measurement_time(Duration::from_secs(15)); + } + group.bench_with_input( + BenchmarkId::new("parallel", input_length), + &(input_length, chunk_length), + |b, (input_length, chunk_length)| { + let vdaf = Prio3::new_histogram_multithreaded( + num_shares, + *input_length, + *chunk_length, + ) + .unwrap(); + let measurement = black_box(0); + let nonce = black_box([0u8; 16]); + let verify_key = black_box([0u8; 16]); + let (public_share, input_shares) = vdaf.shard(&measurement, &nonce).unwrap(); + b.iter(|| { + vdaf.prepare_init( + &verify_key, + 0, + &(), + &nonce, + &public_share, + &input_shares[0], + ) + .unwrap() + }); + }, + ); + } + } + group.finish(); + + #[cfg(feature = "experimental")] + { + let mut group = c.benchmark_group("prio3fixedpointboundedl2vecsum_i1f15_shard"); + for dimension in [10, 100, 1_000] { + group.bench_with_input( + BenchmarkId::new("serial", dimension), + &dimension, + |b, dimension| { + let vdaf: Prio3<FixedPointBoundedL2VecSum<I1F15, _, _>, _, 16> = + Prio3::new_fixedpoint_boundedl2_vec_sum(num_shares, *dimension).unwrap(); + let mut measurement = vec![fixed!(0: I1F15); *dimension]; + measurement[0] = fixed!(0.5: I1F15); + let nonce = black_box([0u8; 16]); + b.iter(|| vdaf.shard(&measurement, &nonce).unwrap()); + }, + ); + } + + #[cfg(feature = "multithreaded")] + { + for dimension in [10, 100, 1_000] { + group.bench_with_input( + BenchmarkId::new("parallel", dimension), + &dimension, + |b, dimension| { + let vdaf: Prio3<FixedPointBoundedL2VecSum<I1F15, _, _>, _, 16> = + Prio3::new_fixedpoint_boundedl2_vec_sum_multithreaded( + num_shares, *dimension, + ) + .unwrap(); + let mut measurement = vec![fixed!(0: I1F15); *dimension]; + measurement[0] = fixed!(0.5: I1F15); + let nonce = black_box([0u8; 16]); + b.iter(|| vdaf.shard(&measurement, &nonce).unwrap()); + }, + ); + } + } + group.finish(); + + let mut group = c.benchmark_group("prio3fixedpointboundedl2vecsum_i1f15_prepare_init"); + for dimension in [10, 100, 1_000] { + group.bench_with_input( + BenchmarkId::new("series", dimension), + &dimension, + |b, dimension| { + let vdaf: Prio3<FixedPointBoundedL2VecSum<I1F15, _, _>, _, 16> = + Prio3::new_fixedpoint_boundedl2_vec_sum(num_shares, *dimension).unwrap(); + let mut measurement = vec![fixed!(0: I1F15); *dimension]; + measurement[0] = fixed!(0.5: I1F15); + let nonce = black_box([0u8; 16]); + let verify_key = black_box([0u8; 16]); + let (public_share, input_shares) = vdaf.shard(&measurement, &nonce).unwrap(); + b.iter(|| { + vdaf.prepare_init( + &verify_key, + 0, + &(), + &nonce, + &public_share, + &input_shares[0], + ) + .unwrap() + }); + }, + ); + } + + #[cfg(feature = "multithreaded")] + { + for dimension in [10, 100, 1_000] { + group.bench_with_input( + BenchmarkId::new("parallel", dimension), + &dimension, + |b, dimension| { + let vdaf: Prio3<FixedPointBoundedL2VecSum<I1F15, _, _>, _, 16> = + Prio3::new_fixedpoint_boundedl2_vec_sum_multithreaded( + num_shares, *dimension, + ) + .unwrap(); + let mut measurement = vec![fixed!(0: I1F15); *dimension]; + measurement[0] = fixed!(0.5: I1F15); + let nonce = black_box([0u8; 16]); + let verify_key = black_box([0u8; 16]); + let (public_share, input_shares) = + vdaf.shard(&measurement, &nonce).unwrap(); + b.iter(|| { + vdaf.prepare_init( + &verify_key, + 0, + &(), + &nonce, + &public_share, + &input_shares[0], + ) + .unwrap() + }); + }, + ); + } + } + group.finish(); + + let mut group = c.benchmark_group("prio3fixedpointboundedl2vecsum_i1f31_shard"); + for dimension in [10, 100, 1_000] { + group.bench_with_input( + BenchmarkId::new("serial", dimension), + &dimension, + |b, dimension| { + let vdaf: Prio3<FixedPointBoundedL2VecSum<I1F31, _, _>, _, 16> = + Prio3::new_fixedpoint_boundedl2_vec_sum(num_shares, *dimension).unwrap(); + let mut measurement = vec![fixed!(0: I1F31); *dimension]; + measurement[0] = fixed!(0.5: I1F31); + let nonce = black_box([0u8; 16]); + b.iter(|| vdaf.shard(&measurement, &nonce).unwrap()); + }, + ); + } + + #[cfg(feature = "multithreaded")] + { + for dimension in [10, 100, 1_000] { + group.bench_with_input( + BenchmarkId::new("parallel", dimension), + &dimension, + |b, dimension| { + let vdaf: Prio3<FixedPointBoundedL2VecSum<I1F31, _, _>, _, 16> = + Prio3::new_fixedpoint_boundedl2_vec_sum_multithreaded( + num_shares, *dimension, + ) + .unwrap(); + let mut measurement = vec![fixed!(0: I1F31); *dimension]; + measurement[0] = fixed!(0.5: I1F31); + let nonce = black_box([0u8; 16]); + b.iter(|| vdaf.shard(&measurement, &nonce).unwrap()); + }, + ); + } + } + group.finish(); + + let mut group = c.benchmark_group("prio3fixedpointboundedl2vecsum_i1f31_prepare_init"); + for dimension in [10, 100, 1_000] { + group.bench_with_input( + BenchmarkId::new("series", dimension), + &dimension, + |b, dimension| { + let vdaf: Prio3<FixedPointBoundedL2VecSum<I1F31, _, _>, _, 16> = + Prio3::new_fixedpoint_boundedl2_vec_sum(num_shares, *dimension).unwrap(); + let mut measurement = vec![fixed!(0: I1F31); *dimension]; + measurement[0] = fixed!(0.5: I1F31); + let nonce = black_box([0u8; 16]); + let verify_key = black_box([0u8; 16]); + let (public_share, input_shares) = vdaf.shard(&measurement, &nonce).unwrap(); + b.iter(|| { + vdaf.prepare_init( + &verify_key, + 0, + &(), + &nonce, + &public_share, + &input_shares[0], + ) + .unwrap() + }); + }, + ); + } + + #[cfg(feature = "multithreaded")] + { + for dimension in [10, 100, 1_000] { + group.bench_with_input( + BenchmarkId::new("parallel", dimension), + &dimension, + |b, dimension| { + let vdaf: Prio3<FixedPointBoundedL2VecSum<I1F31, _, _>, _, 16> = + Prio3::new_fixedpoint_boundedl2_vec_sum_multithreaded( + num_shares, *dimension, + ) + .unwrap(); + let mut measurement = vec![fixed!(0: I1F31); *dimension]; + measurement[0] = fixed!(0.5: I1F31); + let nonce = black_box([0u8; 16]); + let verify_key = black_box([0u8; 16]); + let (public_share, input_shares) = + vdaf.shard(&measurement, &nonce).unwrap(); + b.iter(|| { + vdaf.prepare_init( + &verify_key, + 0, + &(), + &nonce, + &public_share, + &input_shares[0], + ) + .unwrap() + }); + }, + ); + } + } + group.finish(); + } +} + +/// Benchmark IdpfPoplar performance. +#[cfg(feature = "experimental")] +fn idpf(c: &mut Criterion) { + let test_sizes = [8usize, 8 * 16, 8 * 256]; + + let mut group = c.benchmark_group("idpf_gen"); + for size in test_sizes.iter() { + group.throughput(Throughput::Bytes(*size as u64 / 8)); + group.bench_with_input(BenchmarkId::from_parameter(size), size, |b, &size| { + let bits = iter::repeat_with(random).take(size).collect::<Vec<bool>>(); + let input = IdpfInput::from_bools(&bits); + + let inner_values = random_vector::<Field64>(size - 1) + .unwrap() + .into_iter() + .map(|random_element| Poplar1IdpfValue::new([Field64::one(), random_element])) + .collect::<Vec<_>>(); + let leaf_value = Poplar1IdpfValue::new([Field255::one(), random_vector(1).unwrap()[0]]); + + let idpf = Idpf::new((), ()); + b.iter(|| { + idpf.gen(&input, inner_values.clone(), leaf_value, &[0; 16]) + .unwrap(); + }); + }); + } + group.finish(); + + let mut group = c.benchmark_group("idpf_eval"); + for size in test_sizes.iter() { + group.throughput(Throughput::Bytes(*size as u64 / 8)); + group.bench_with_input(BenchmarkId::from_parameter(size), size, |b, &size| { + let bits = iter::repeat_with(random).take(size).collect::<Vec<bool>>(); + let input = IdpfInput::from_bools(&bits); + + let inner_values = random_vector::<Field64>(size - 1) + .unwrap() + .into_iter() + .map(|random_element| Poplar1IdpfValue::new([Field64::one(), random_element])) + .collect::<Vec<_>>(); + let leaf_value = Poplar1IdpfValue::new([Field255::one(), random_vector(1).unwrap()[0]]); + + let idpf = Idpf::new((), ()); + let (public_share, keys) = idpf + .gen(&input, inner_values, leaf_value, &[0; 16]) + .unwrap(); + + b.iter(|| { + // This is an aggressively small cache, to minimize its impact on the benchmark. + // In this synthetic benchmark, we are only checking one candidate prefix per level + // (typically there are many candidate prefixes per level) so the cache hit rate + // will be unaffected. + let mut cache = RingBufferCache::new(1); + + for prefix_length in 1..=size { + let prefix = input[..prefix_length].to_owned().into(); + idpf.eval(0, &public_share, &keys[0], &prefix, &[0; 16], &mut cache) + .unwrap(); + } + }); + }); + } + group.finish(); +} + +/// Benchmark Poplar1. +#[cfg(feature = "experimental")] +fn poplar1(c: &mut Criterion) { + let test_sizes = [16_usize, 128, 256]; + + let mut group = c.benchmark_group("poplar1_shard"); + for size in test_sizes.iter() { + group.throughput(Throughput::Bytes(*size as u64 / 8)); + group.bench_with_input(BenchmarkId::from_parameter(size), size, |b, &size| { + let vdaf = Poplar1::new_shake128(size); + let mut rng = StdRng::seed_from_u64(RNG_SEED); + let nonce = rng.gen::<[u8; 16]>(); + + b.iter_batched( + || { + let bits = iter::repeat_with(|| rng.gen()) + .take(size) + .collect::<Vec<bool>>(); + IdpfInput::from_bools(&bits) + }, + |measurement| { + vdaf.shard(&measurement, &nonce).unwrap(); + }, + BatchSize::SmallInput, + ); + }); + } + group.finish(); + + let mut group = c.benchmark_group("poplar1_prepare_init"); + for size in test_sizes.iter() { + group.measurement_time(Duration::from_secs(30)); // slower benchmark + group.bench_with_input(BenchmarkId::from_parameter(size), size, |b, &size| { + let vdaf = Poplar1::new_shake128(size); + let mut rng = StdRng::seed_from_u64(RNG_SEED); + + b.iter_batched( + || { + let verify_key: [u8; 16] = rng.gen(); + let nonce: [u8; 16] = rng.gen(); + + // Parameters are chosen to match Chris Wood's experimental setup: + // https://github.com/chris-wood/heavy-hitter-comparison + let (measurements, prefix_tree) = poplar1_generate_zipf_distributed_batch( + &mut rng, // rng + size, // bits + 10, // threshold + 1000, // number of measurements + 128, // Zipf support + 1.03, // Zipf exponent + ); + + // We are benchmarking preparation of a single report. For this test, it doesn't matter + // which measurement we generate a report for, so pick the first measurement + // arbitrarily. + let (public_share, input_shares) = + vdaf.shard(&measurements[0], &nonce).unwrap(); + + // For the aggregation paramter, we use the candidate prefixes from the prefix tree + // for the sampled measurements. Run preparation for the last step, which ought to + // represent the worst-case performance. + let agg_param = + Poplar1AggregationParam::try_from_prefixes(prefix_tree[size - 1].clone()) + .unwrap(); + + ( + verify_key, + nonce, + agg_param, + public_share, + input_shares.into_iter().next().unwrap(), + ) + }, + |(verify_key, nonce, agg_param, public_share, input_share)| { + vdaf.prepare_init( + &verify_key, + 0, + &agg_param, + &nonce, + &public_share, + &input_share, + ) + .unwrap(); + }, + BatchSize::SmallInput, + ); + }); + } + group.finish(); +} + +/// Generate a set of Poplar1 measurements with the given bit length `bits`. They are sampled +/// according to the Zipf distribution with parameters `zipf_support` and `zipf_exponent`. Return +/// the measurements, along with the prefix tree for the desired threshold. +/// +/// The prefix tree consists of a sequence of candidate prefixes for each level. For a given level, +/// the candidate prefixes are computed from the hit counts of the prefixes at the previous level: +/// For any prefix `p` whose hit count is at least the desired threshold, add `p || 0` and `p || 1` +/// to the list. +#[cfg(feature = "experimental")] +fn poplar1_generate_zipf_distributed_batch( + rng: &mut impl Rng, + bits: usize, + threshold: usize, + measurement_count: usize, + zipf_support: usize, + zipf_exponent: f64, +) -> (Vec<IdpfInput>, Vec<Vec<IdpfInput>>) { + // Generate random inputs. + let mut inputs = Vec::with_capacity(zipf_support); + for _ in 0..zipf_support { + let bools: Vec<bool> = (0..bits).map(|_| rng.gen()).collect(); + inputs.push(IdpfInput::from_bools(&bools)); + } + + // Sample a number of inputs according to the Zipf distribution. + let mut samples = Vec::with_capacity(measurement_count); + let zipf = ZipfDistribution::new(zipf_support, zipf_exponent).unwrap(); + for _ in 0..measurement_count { + samples.push(inputs[zipf.sample(rng) - 1].clone()); + } + + // Compute the prefix tree for the desired threshold. + let mut prefix_tree = Vec::with_capacity(bits); + prefix_tree.push(vec![ + IdpfInput::from_bools(&[false]), + IdpfInput::from_bools(&[true]), + ]); + + for level in 0..bits - 1 { + // Compute the hit count of each prefix from the previous level. + let mut hit_counts = vec![0; prefix_tree[level].len()]; + for (hit_count, prefix) in hit_counts.iter_mut().zip(prefix_tree[level].iter()) { + for sample in samples.iter() { + let mut is_prefix = true; + for j in 0..prefix.len() { + if prefix[j] != sample[j] { + is_prefix = false; + break; + } + } + if is_prefix { + *hit_count += 1; + } + } + } + + // Compute the next set of candidate prefixes. + let mut next_prefixes = Vec::new(); + for (hit_count, prefix) in hit_counts.iter().zip(prefix_tree[level].iter()) { + if *hit_count >= threshold { + next_prefixes.push(prefix.clone_with_suffix(&[false])); + next_prefixes.push(prefix.clone_with_suffix(&[true])); + } + } + prefix_tree.push(next_prefixes); + } + + (samples, prefix_tree) +} + +#[cfg(all(feature = "prio2", feature = "experimental"))] +criterion_group!(benches, poplar1, prio3, prio2, poly_mul, prng, idpf, dp_noise); +#[cfg(all(not(feature = "prio2"), feature = "experimental"))] +criterion_group!(benches, poplar1, prio3, poly_mul, prng, idpf, dp_noise); +#[cfg(all(feature = "prio2", not(feature = "experimental")))] +criterion_group!(benches, prio3, prio2, prng, poly_mul); +#[cfg(all(not(feature = "prio2"), not(feature = "experimental")))] +criterion_group!(benches, prio3, prng, poly_mul); + +criterion_main!(benches); diff --git a/third_party/rust/prio/documentation/releases.md b/third_party/rust/prio/documentation/releases.md new file mode 100644 index 0000000000..db0b4c246f --- /dev/null +++ b/third_party/rust/prio/documentation/releases.md @@ -0,0 +1,13 @@ +# Releases + +We use a GitHub Action to publish a crate named `prio` to [crates.io](https://crates.io). To cut a +release and publish: + +- Bump the version number in `Cargo.toml` to e.g. `1.2.3` and merge that change to `main` +- Tag that commit on main as `v1.2.3`, either in `git` or in [GitHub's releases UI][releases]. +- Publish a release in [GitHub's releases UI][releases]. + +Publishing the release will automatically publish the updated [`prio` crate][crate]. + +[releases]: https://github.com/divviup/libprio-rs/releases/new +[crate]: https://crates.io/crates/prio diff --git a/third_party/rust/prio/src/benchmarked.rs b/third_party/rust/prio/src/benchmarked.rs new file mode 100644 index 0000000000..1882de91e7 --- /dev/null +++ b/third_party/rust/prio/src/benchmarked.rs @@ -0,0 +1,50 @@ +// SPDX-License-Identifier: MPL-2.0 + +#![doc(hidden)] + +//! This module provides wrappers around internal components of this crate that we want to +//! benchmark, but which we don't want to expose in the public API. + +use crate::fft::discrete_fourier_transform; +use crate::field::FftFriendlyFieldElement; +use crate::flp::gadgets::Mul; +use crate::flp::FlpError; +use crate::polynomial::{poly_fft, PolyAuxMemory}; + +/// Sets `outp` to the Discrete Fourier Transform (DFT) using an iterative FFT algorithm. +pub fn benchmarked_iterative_fft<F: FftFriendlyFieldElement>(outp: &mut [F], inp: &[F]) { + discrete_fourier_transform(outp, inp, inp.len()).unwrap(); +} + +/// Sets `outp` to the Discrete Fourier Transform (DFT) using a recursive FFT algorithm. +pub fn benchmarked_recursive_fft<F: FftFriendlyFieldElement>(outp: &mut [F], inp: &[F]) { + let mut mem = PolyAuxMemory::new(inp.len() / 2); + poly_fft( + outp, + inp, + &mem.roots_2n, + inp.len(), + false, + &mut mem.fft_memory, + ) +} + +/// Sets `outp` to `inp[0] * inp[1]`, where `inp[0]` and `inp[1]` are polynomials. This function +/// uses FFT for multiplication. +pub fn benchmarked_gadget_mul_call_poly_fft<F: FftFriendlyFieldElement>( + g: &mut Mul<F>, + outp: &mut [F], + inp: &[Vec<F>], +) -> Result<(), FlpError> { + g.call_poly_fft(outp, inp) +} + +/// Sets `outp` to `inp[0] * inp[1]`, where `inp[0]` and `inp[1]` are polynomials. This function +/// does the multiplication directly. +pub fn benchmarked_gadget_mul_call_poly_direct<F: FftFriendlyFieldElement>( + g: &mut Mul<F>, + outp: &mut [F], + inp: &[Vec<F>], +) -> Result<(), FlpError> { + g.call_poly_direct(outp, inp) +} diff --git a/third_party/rust/prio/src/codec.rs b/third_party/rust/prio/src/codec.rs new file mode 100644 index 0000000000..71f4f8ce5f --- /dev/null +++ b/third_party/rust/prio/src/codec.rs @@ -0,0 +1,734 @@ +// SPDX-License-Identifier: MPL-2.0 + +//! Support for encoding and decoding messages to or from the TLS wire encoding, as specified in +//! [RFC 8446, Section 3][1]. +//! +//! The [`Encode`], [`Decode`], [`ParameterizedEncode`] and [`ParameterizedDecode`] traits can be +//! implemented on values that need to be encoded or decoded. Utility functions are provided to +//! encode or decode sequences of values. +//! +//! [1]: https://datatracker.ietf.org/doc/html/rfc8446#section-3 + +use byteorder::{BigEndian, ReadBytesExt}; +use std::{ + convert::TryInto, + error::Error, + io::{Cursor, Read}, + mem::size_of, + num::TryFromIntError, +}; + +/// An error that occurred during decoding. +#[derive(Debug, thiserror::Error)] +pub enum CodecError { + /// An I/O error. + #[error("I/O error")] + Io(#[from] std::io::Error), + + /// Extra data remained in the input after decoding a value. + #[error("{0} bytes left in buffer after decoding value")] + BytesLeftOver(usize), + + /// The length prefix of an encoded vector exceeds the amount of remaining input. + #[error("length prefix of encoded vector overflows buffer: {0}")] + LengthPrefixTooBig(usize), + + /// Custom errors from [`Decode`] implementations. + #[error("other error: {0}")] + Other(#[source] Box<dyn Error + 'static + Send + Sync>), + + /// An invalid value was decoded. + #[error("unexpected value")] + UnexpectedValue, +} + +/// Describes how to decode an object from a byte sequence. +pub trait Decode: Sized { + /// Read and decode an encoded object from `bytes`. On success, the decoded value is returned + /// and `bytes` is advanced by the encoded size of the value. On failure, an error is returned + /// and no further attempt to read from `bytes` should be made. + fn decode(bytes: &mut Cursor<&[u8]>) -> Result<Self, CodecError>; + + /// Convenience method to get a decoded value. Returns an error if [`Self::decode`] fails, or if + /// there are any bytes left in `bytes` after decoding a value. + fn get_decoded(bytes: &[u8]) -> Result<Self, CodecError> { + Self::get_decoded_with_param(&(), bytes) + } +} + +/// Describes how to decode an object from a byte sequence and a decoding parameter that provides +/// additional context. +pub trait ParameterizedDecode<P>: Sized { + /// Read and decode an encoded object from `bytes`. `decoding_parameter` provides details of the + /// wire encoding such as lengths of different portions of the message. On success, the decoded + /// value is returned and `bytes` is advanced by the encoded size of the value. On failure, an + /// error is returned and no further attempt to read from `bytes` should be made. + fn decode_with_param( + decoding_parameter: &P, + bytes: &mut Cursor<&[u8]>, + ) -> Result<Self, CodecError>; + + /// Convenience method to get a decoded value. Returns an error if [`Self::decode_with_param`] + /// fails, or if there are any bytes left in `bytes` after decoding a value. + fn get_decoded_with_param(decoding_parameter: &P, bytes: &[u8]) -> Result<Self, CodecError> { + let mut cursor = Cursor::new(bytes); + let decoded = Self::decode_with_param(decoding_parameter, &mut cursor)?; + if cursor.position() as usize != bytes.len() { + return Err(CodecError::BytesLeftOver( + bytes.len() - cursor.position() as usize, + )); + } + + Ok(decoded) + } +} + +/// Provide a blanket implementation so that any [`Decode`] can be used as a +/// `ParameterizedDecode<T>` for any `T`. +impl<D: Decode + ?Sized, T> ParameterizedDecode<T> for D { + fn decode_with_param( + _decoding_parameter: &T, + bytes: &mut Cursor<&[u8]>, + ) -> Result<Self, CodecError> { + Self::decode(bytes) + } +} + +/// Describes how to encode objects into a byte sequence. +pub trait Encode { + /// Append the encoded form of this object to the end of `bytes`, growing the vector as needed. + fn encode(&self, bytes: &mut Vec<u8>); + + /// Convenience method to encode a value into a new `Vec<u8>`. + fn get_encoded(&self) -> Vec<u8> { + self.get_encoded_with_param(&()) + } + + /// Returns an optional hint indicating how many bytes will be required to encode this value, or + /// `None` by default. + fn encoded_len(&self) -> Option<usize> { + None + } +} + +/// Describes how to encode objects into a byte sequence. +pub trait ParameterizedEncode<P> { + /// Append the encoded form of this object to the end of `bytes`, growing the vector as needed. + /// `encoding_parameter` provides details of the wire encoding, used to control how the value + /// is encoded. + fn encode_with_param(&self, encoding_parameter: &P, bytes: &mut Vec<u8>); + + /// Convenience method to encode a value into a new `Vec<u8>`. + fn get_encoded_with_param(&self, encoding_parameter: &P) -> Vec<u8> { + let mut ret = if let Some(length) = self.encoded_len_with_param(encoding_parameter) { + Vec::with_capacity(length) + } else { + Vec::new() + }; + self.encode_with_param(encoding_parameter, &mut ret); + ret + } + + /// Returns an optional hint indicating how many bytes will be required to encode this value, or + /// `None` by default. + fn encoded_len_with_param(&self, _encoding_parameter: &P) -> Option<usize> { + None + } +} + +/// Provide a blanket implementation so that any [`Encode`] can be used as a +/// `ParameterizedEncode<T>` for any `T`. +impl<E: Encode + ?Sized, T> ParameterizedEncode<T> for E { + fn encode_with_param(&self, _encoding_parameter: &T, bytes: &mut Vec<u8>) { + self.encode(bytes) + } + + fn encoded_len_with_param(&self, _encoding_parameter: &T) -> Option<usize> { + <Self as Encode>::encoded_len(self) + } +} + +impl Decode for () { + fn decode(_bytes: &mut Cursor<&[u8]>) -> Result<Self, CodecError> { + Ok(()) + } +} + +impl Encode for () { + fn encode(&self, _bytes: &mut Vec<u8>) {} + + fn encoded_len(&self) -> Option<usize> { + Some(0) + } +} + +impl Decode for u8 { + fn decode(bytes: &mut Cursor<&[u8]>) -> Result<Self, CodecError> { + let mut value = [0u8; size_of::<u8>()]; + bytes.read_exact(&mut value)?; + Ok(value[0]) + } +} + +impl Encode for u8 { + fn encode(&self, bytes: &mut Vec<u8>) { + bytes.push(*self); + } + + fn encoded_len(&self) -> Option<usize> { + Some(1) + } +} + +impl Decode for u16 { + fn decode(bytes: &mut Cursor<&[u8]>) -> Result<Self, CodecError> { + Ok(bytes.read_u16::<BigEndian>()?) + } +} + +impl Encode for u16 { + fn encode(&self, bytes: &mut Vec<u8>) { + bytes.extend_from_slice(&u16::to_be_bytes(*self)); + } + + fn encoded_len(&self) -> Option<usize> { + Some(2) + } +} + +/// 24 bit integer, per +/// [RFC 8443, section 3.3](https://datatracker.ietf.org/doc/html/rfc8446#section-3.3) +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +struct U24(pub u32); + +impl Decode for U24 { + fn decode(bytes: &mut Cursor<&[u8]>) -> Result<Self, CodecError> { + Ok(U24(bytes.read_u24::<BigEndian>()?)) + } +} + +impl Encode for U24 { + fn encode(&self, bytes: &mut Vec<u8>) { + // Encode lower three bytes of the u32 as u24 + bytes.extend_from_slice(&u32::to_be_bytes(self.0)[1..]); + } + + fn encoded_len(&self) -> Option<usize> { + Some(3) + } +} + +impl Decode for u32 { + fn decode(bytes: &mut Cursor<&[u8]>) -> Result<Self, CodecError> { + Ok(bytes.read_u32::<BigEndian>()?) + } +} + +impl Encode for u32 { + fn encode(&self, bytes: &mut Vec<u8>) { + bytes.extend_from_slice(&u32::to_be_bytes(*self)); + } + + fn encoded_len(&self) -> Option<usize> { + Some(4) + } +} + +impl Decode for u64 { + fn decode(bytes: &mut Cursor<&[u8]>) -> Result<Self, CodecError> { + Ok(bytes.read_u64::<BigEndian>()?) + } +} + +impl Encode for u64 { + fn encode(&self, bytes: &mut Vec<u8>) { + bytes.extend_from_slice(&u64::to_be_bytes(*self)); + } + + fn encoded_len(&self) -> Option<usize> { + Some(8) + } +} + +/// Encode `items` into `bytes` as a [variable-length vector][1] with a maximum length of `0xff`. +/// +/// [1]: https://datatracker.ietf.org/doc/html/rfc8446#section-3.4 +pub fn encode_u8_items<P, E: ParameterizedEncode<P>>( + bytes: &mut Vec<u8>, + encoding_parameter: &P, + items: &[E], +) { + // Reserve space to later write length + let len_offset = bytes.len(); + bytes.push(0); + + for item in items { + item.encode_with_param(encoding_parameter, bytes); + } + + let len = bytes.len() - len_offset - 1; + assert!(len <= usize::from(u8::MAX)); + bytes[len_offset] = len as u8; +} + +/// Decode `bytes` into a vector of `D` values, treating `bytes` as a [variable-length vector][1] of +/// maximum length `0xff`. +/// +/// [1]: https://datatracker.ietf.org/doc/html/rfc8446#section-3.4 +pub fn decode_u8_items<P, D: ParameterizedDecode<P>>( + decoding_parameter: &P, + bytes: &mut Cursor<&[u8]>, +) -> Result<Vec<D>, CodecError> { + // Read one byte to get length of opaque byte vector + let length = usize::from(u8::decode(bytes)?); + + decode_items(length, decoding_parameter, bytes) +} + +/// Encode `items` into `bytes` as a [variable-length vector][1] with a maximum length of `0xffff`. +/// +/// [1]: https://datatracker.ietf.org/doc/html/rfc8446#section-3.4 +pub fn encode_u16_items<P, E: ParameterizedEncode<P>>( + bytes: &mut Vec<u8>, + encoding_parameter: &P, + items: &[E], +) { + // Reserve space to later write length + let len_offset = bytes.len(); + 0u16.encode(bytes); + + for item in items { + item.encode_with_param(encoding_parameter, bytes); + } + + let len = bytes.len() - len_offset - 2; + assert!(len <= usize::from(u16::MAX)); + for (offset, byte) in u16::to_be_bytes(len as u16).iter().enumerate() { + bytes[len_offset + offset] = *byte; + } +} + +/// Decode `bytes` into a vector of `D` values, treating `bytes` as a [variable-length vector][1] of +/// maximum length `0xffff`. +/// +/// [1]: https://datatracker.ietf.org/doc/html/rfc8446#section-3.4 +pub fn decode_u16_items<P, D: ParameterizedDecode<P>>( + decoding_parameter: &P, + bytes: &mut Cursor<&[u8]>, +) -> Result<Vec<D>, CodecError> { + // Read two bytes to get length of opaque byte vector + let length = usize::from(u16::decode(bytes)?); + + decode_items(length, decoding_parameter, bytes) +} + +/// Encode `items` into `bytes` as a [variable-length vector][1] with a maximum length of +/// `0xffffff`. +/// +/// [1]: https://datatracker.ietf.org/doc/html/rfc8446#section-3.4 +pub fn encode_u24_items<P, E: ParameterizedEncode<P>>( + bytes: &mut Vec<u8>, + encoding_parameter: &P, + items: &[E], +) { + // Reserve space to later write length + let len_offset = bytes.len(); + U24(0).encode(bytes); + + for item in items { + item.encode_with_param(encoding_parameter, bytes); + } + + let len = bytes.len() - len_offset - 3; + assert!(len <= 0xffffff); + for (offset, byte) in u32::to_be_bytes(len as u32)[1..].iter().enumerate() { + bytes[len_offset + offset] = *byte; + } +} + +/// Decode `bytes` into a vector of `D` values, treating `bytes` as a [variable-length vector][1] of +/// maximum length `0xffffff`. +/// +/// [1]: https://datatracker.ietf.org/doc/html/rfc8446#section-3.4 +pub fn decode_u24_items<P, D: ParameterizedDecode<P>>( + decoding_parameter: &P, + bytes: &mut Cursor<&[u8]>, +) -> Result<Vec<D>, CodecError> { + // Read three bytes to get length of opaque byte vector + let length = U24::decode(bytes)?.0 as usize; + + decode_items(length, decoding_parameter, bytes) +} + +/// Encode `items` into `bytes` as a [variable-length vector][1] with a maximum length of +/// `0xffffffff`. +/// +/// [1]: https://datatracker.ietf.org/doc/html/rfc8446#section-3.4 +pub fn encode_u32_items<P, E: ParameterizedEncode<P>>( + bytes: &mut Vec<u8>, + encoding_parameter: &P, + items: &[E], +) { + // Reserve space to later write length + let len_offset = bytes.len(); + 0u32.encode(bytes); + + for item in items { + item.encode_with_param(encoding_parameter, bytes); + } + + let len = bytes.len() - len_offset - 4; + let len: u32 = len.try_into().expect("Length too large"); + for (offset, byte) in len.to_be_bytes().iter().enumerate() { + bytes[len_offset + offset] = *byte; + } +} + +/// Decode `bytes` into a vector of `D` values, treating `bytes` as a [variable-length vector][1] of +/// maximum length `0xffffffff`. +/// +/// [1]: https://datatracker.ietf.org/doc/html/rfc8446#section-3.4 +pub fn decode_u32_items<P, D: ParameterizedDecode<P>>( + decoding_parameter: &P, + bytes: &mut Cursor<&[u8]>, +) -> Result<Vec<D>, CodecError> { + // Read four bytes to get length of opaque byte vector. + let len: usize = u32::decode(bytes)? + .try_into() + .map_err(|err: TryFromIntError| CodecError::Other(err.into()))?; + + decode_items(len, decoding_parameter, bytes) +} + +/// Decode the next `length` bytes from `bytes` into as many instances of `D` as possible. +fn decode_items<P, D: ParameterizedDecode<P>>( + length: usize, + decoding_parameter: &P, + bytes: &mut Cursor<&[u8]>, +) -> Result<Vec<D>, CodecError> { + let mut decoded = Vec::new(); + let initial_position = bytes.position() as usize; + + // Create cursor over specified portion of provided cursor to ensure we can't read past length. + let inner = bytes.get_ref(); + + // Make sure encoded length doesn't overflow usize or go past the end of provided byte buffer. + let (items_end, overflowed) = initial_position.overflowing_add(length); + if overflowed || items_end > inner.len() { + return Err(CodecError::LengthPrefixTooBig(length)); + } + + let mut sub = Cursor::new(&bytes.get_ref()[initial_position..items_end]); + + while sub.position() < length as u64 { + decoded.push(D::decode_with_param(decoding_parameter, &mut sub)?); + } + + // Advance outer cursor by the amount read in the inner cursor + bytes.set_position(initial_position as u64 + sub.position()); + + Ok(decoded) +} + +#[cfg(test)] +mod tests { + + use super::*; + use assert_matches::assert_matches; + + #[test] + fn encode_nothing() { + let mut bytes = vec![]; + ().encode(&mut bytes); + assert_eq!(bytes.len(), 0); + } + + #[test] + fn roundtrip_u8() { + let value = 100u8; + + let mut bytes = vec![]; + value.encode(&mut bytes); + assert_eq!(bytes.len(), 1); + + let decoded = u8::decode(&mut Cursor::new(&bytes)).unwrap(); + assert_eq!(value, decoded); + } + + #[test] + fn roundtrip_u16() { + let value = 1000u16; + + let mut bytes = vec![]; + value.encode(&mut bytes); + assert_eq!(bytes.len(), 2); + // Check endianness of encoding + assert_eq!(bytes, vec![3, 232]); + + let decoded = u16::decode(&mut Cursor::new(&bytes)).unwrap(); + assert_eq!(value, decoded); + } + + #[test] + fn roundtrip_u24() { + let value = U24(1_000_000u32); + + let mut bytes = vec![]; + value.encode(&mut bytes); + assert_eq!(bytes.len(), 3); + // Check endianness of encoding + assert_eq!(bytes, vec![15, 66, 64]); + + let decoded = U24::decode(&mut Cursor::new(&bytes)).unwrap(); + assert_eq!(value, decoded); + } + + #[test] + fn roundtrip_u32() { + let value = 134_217_728u32; + + let mut bytes = vec![]; + value.encode(&mut bytes); + assert_eq!(bytes.len(), 4); + // Check endianness of encoding + assert_eq!(bytes, vec![8, 0, 0, 0]); + + let decoded = u32::decode(&mut Cursor::new(&bytes)).unwrap(); + assert_eq!(value, decoded); + } + + #[test] + fn roundtrip_u64() { + let value = 137_438_953_472u64; + + let mut bytes = vec![]; + value.encode(&mut bytes); + assert_eq!(bytes.len(), 8); + // Check endianness of encoding + assert_eq!(bytes, vec![0, 0, 0, 32, 0, 0, 0, 0]); + + let decoded = u64::decode(&mut Cursor::new(&bytes)).unwrap(); + assert_eq!(value, decoded); + } + + #[derive(Debug, Eq, PartialEq)] + struct TestMessage { + field_u8: u8, + field_u16: u16, + field_u24: U24, + field_u32: u32, + field_u64: u64, + } + + impl Encode for TestMessage { + fn encode(&self, bytes: &mut Vec<u8>) { + self.field_u8.encode(bytes); + self.field_u16.encode(bytes); + self.field_u24.encode(bytes); + self.field_u32.encode(bytes); + self.field_u64.encode(bytes); + } + + fn encoded_len(&self) -> Option<usize> { + Some( + self.field_u8.encoded_len()? + + self.field_u16.encoded_len()? + + self.field_u24.encoded_len()? + + self.field_u32.encoded_len()? + + self.field_u64.encoded_len()?, + ) + } + } + + impl Decode for TestMessage { + fn decode(bytes: &mut Cursor<&[u8]>) -> Result<Self, CodecError> { + let field_u8 = u8::decode(bytes)?; + let field_u16 = u16::decode(bytes)?; + let field_u24 = U24::decode(bytes)?; + let field_u32 = u32::decode(bytes)?; + let field_u64 = u64::decode(bytes)?; + + Ok(TestMessage { + field_u8, + field_u16, + field_u24, + field_u32, + field_u64, + }) + } + } + + impl TestMessage { + fn encoded_length() -> usize { + // u8 field + 1 + + // u16 field + 2 + + // u24 field + 3 + + // u32 field + 4 + + // u64 field + 8 + } + } + + #[test] + fn roundtrip_message() { + let value = TestMessage { + field_u8: 0, + field_u16: 300, + field_u24: U24(1_000_000), + field_u32: 134_217_728, + field_u64: 137_438_953_472, + }; + + let mut bytes = vec![]; + value.encode(&mut bytes); + assert_eq!(bytes.len(), TestMessage::encoded_length()); + assert_eq!(value.encoded_len().unwrap(), TestMessage::encoded_length()); + + let decoded = TestMessage::decode(&mut Cursor::new(&bytes)).unwrap(); + assert_eq!(value, decoded); + } + + fn messages_vec() -> Vec<TestMessage> { + vec![ + TestMessage { + field_u8: 0, + field_u16: 300, + field_u24: U24(1_000_000), + field_u32: 134_217_728, + field_u64: 137_438_953_472, + }, + TestMessage { + field_u8: 0, + field_u16: 300, + field_u24: U24(1_000_000), + field_u32: 134_217_728, + field_u64: 137_438_953_472, + }, + TestMessage { + field_u8: 0, + field_u16: 300, + field_u24: U24(1_000_000), + field_u32: 134_217_728, + field_u64: 137_438_953_472, + }, + ] + } + + #[test] + fn roundtrip_variable_length_u8() { + let values = messages_vec(); + let mut bytes = vec![]; + encode_u8_items(&mut bytes, &(), &values); + + assert_eq!( + bytes.len(), + // Length of opaque vector + 1 + + // 3 TestMessage values + 3 * TestMessage::encoded_length() + ); + + let decoded = decode_u8_items(&(), &mut Cursor::new(&bytes)).unwrap(); + assert_eq!(values, decoded); + } + + #[test] + fn roundtrip_variable_length_u16() { + let values = messages_vec(); + let mut bytes = vec![]; + encode_u16_items(&mut bytes, &(), &values); + + assert_eq!( + bytes.len(), + // Length of opaque vector + 2 + + // 3 TestMessage values + 3 * TestMessage::encoded_length() + ); + + // Check endianness of encoded length + assert_eq!(bytes[0..2], [0, 3 * TestMessage::encoded_length() as u8]); + + let decoded = decode_u16_items(&(), &mut Cursor::new(&bytes)).unwrap(); + assert_eq!(values, decoded); + } + + #[test] + fn roundtrip_variable_length_u24() { + let values = messages_vec(); + let mut bytes = vec![]; + encode_u24_items(&mut bytes, &(), &values); + + assert_eq!( + bytes.len(), + // Length of opaque vector + 3 + + // 3 TestMessage values + 3 * TestMessage::encoded_length() + ); + + // Check endianness of encoded length + assert_eq!(bytes[0..3], [0, 0, 3 * TestMessage::encoded_length() as u8]); + + let decoded = decode_u24_items(&(), &mut Cursor::new(&bytes)).unwrap(); + assert_eq!(values, decoded); + } + + #[test] + fn roundtrip_variable_length_u32() { + let values = messages_vec(); + let mut bytes = Vec::new(); + encode_u32_items(&mut bytes, &(), &values); + + assert_eq!(bytes.len(), 4 + 3 * TestMessage::encoded_length()); + + // Check endianness of encoded length. + assert_eq!( + bytes[0..4], + [0, 0, 0, 3 * TestMessage::encoded_length() as u8] + ); + + let decoded = decode_u32_items(&(), &mut Cursor::new(&bytes)).unwrap(); + assert_eq!(values, decoded); + } + + #[test] + fn decode_items_overflow() { + let encoded = vec![1u8]; + + let mut cursor = Cursor::new(encoded.as_slice()); + cursor.set_position(1); + + assert_matches!( + decode_items::<(), u8>(usize::MAX, &(), &mut cursor).unwrap_err(), + CodecError::LengthPrefixTooBig(usize::MAX) + ); + } + + #[test] + fn decode_items_too_big() { + let encoded = vec![1u8]; + + let mut cursor = Cursor::new(encoded.as_slice()); + cursor.set_position(1); + + assert_matches!( + decode_items::<(), u8>(2, &(), &mut cursor).unwrap_err(), + CodecError::LengthPrefixTooBig(2) + ); + } + + #[test] + fn length_hint_correctness() { + assert_eq!(().encoded_len().unwrap(), ().get_encoded().len()); + assert_eq!(0u8.encoded_len().unwrap(), 0u8.get_encoded().len()); + assert_eq!(0u16.encoded_len().unwrap(), 0u16.get_encoded().len()); + assert_eq!(U24(0).encoded_len().unwrap(), U24(0).get_encoded().len()); + assert_eq!(0u32.encoded_len().unwrap(), 0u32.get_encoded().len()); + assert_eq!(0u64.encoded_len().unwrap(), 0u64.get_encoded().len()); + } +} diff --git a/third_party/rust/prio/src/dp.rs b/third_party/rust/prio/src/dp.rs new file mode 100644 index 0000000000..506676dbb9 --- /dev/null +++ b/third_party/rust/prio/src/dp.rs @@ -0,0 +1,127 @@ +// SPDX-License-Identifier: MPL-2.0 + +//! Differential privacy (DP) primitives. +//! +//! There are three main traits defined in this module: +//! +//! - `DifferentialPrivacyBudget`: Implementors should be types of DP-budgets, +//! i.e., methods to measure the amount of privacy provided by DP-mechanisms. +//! Examples: zCDP, ApproximateDP (Epsilon-Delta), PureDP +//! +//! - `DifferentialPrivacyDistribution`: Distribution from which noise is sampled. +//! Examples: DiscreteGaussian, DiscreteLaplace +//! +//! - `DifferentialPrivacyStrategy`: This is a combination of choices for budget and distribution. +//! Examples: zCDP-DiscreteGaussian, EpsilonDelta-DiscreteGaussian +//! +use num_bigint::{BigInt, BigUint, TryFromBigIntError}; +use num_rational::{BigRational, Ratio}; +use serde::{Deserialize, Serialize}; + +/// Errors propagated by methods in this module. +#[derive(Debug, thiserror::Error)] +pub enum DpError { + /// Tried to use an invalid float as privacy parameter. + #[error( + "DP error: input value was not a valid privacy parameter. \ + It should to be a non-negative, finite float." + )] + InvalidFloat, + + /// Tried to construct a rational number with zero denominator. + #[error("DP error: input denominator was zero.")] + ZeroDenominator, + + /// Tried to convert BigInt into something incompatible. + #[error("DP error: {0}")] + BigIntConversion(#[from] TryFromBigIntError<BigInt>), +} + +/// Positive arbitrary precision rational number to represent DP and noise distribution parameters in +/// protocol messages and manipulate them without rounding errors. +#[derive(Clone, Debug)] +pub struct Rational(Ratio<BigUint>); + +impl Rational { + /// Construct a [`Rational`] number from numerator `n` and denominator `d`. Errors if denominator is zero. + pub fn from_unsigned<T>(n: T, d: T) -> Result<Self, DpError> + where + T: Into<u128>, + { + // we don't want to expose BigUint in the public api, hence the Into<u128> bound + let d = d.into(); + if d == 0 { + Err(DpError::ZeroDenominator) + } else { + Ok(Rational(Ratio::<BigUint>::new(n.into().into(), d.into()))) + } + } +} + +impl TryFrom<f32> for Rational { + type Error = DpError; + /// Constructs a `Rational` from a given `f32` value. + /// + /// The special float values (NaN, positive and negative infinity) result in + /// an error. All other values are represented exactly, without rounding errors. + fn try_from(value: f32) -> Result<Self, DpError> { + match BigRational::from_float(value) { + Some(y) => Ok(Rational(Ratio::<BigUint>::new( + y.numer().clone().try_into()?, + y.denom().clone().try_into()?, + ))), + None => Err(DpError::InvalidFloat)?, + } + } +} + +/// Marker trait for differential privacy budgets (regardless of the specific accounting method). +pub trait DifferentialPrivacyBudget {} + +/// Marker trait for differential privacy scalar noise distributions. +pub trait DifferentialPrivacyDistribution {} + +/// Zero-concentrated differential privacy (ZCDP) budget as defined in [[BS16]]. +/// +/// [BS16]: https://arxiv.org/pdf/1605.02065.pdf +#[derive(Clone, Debug, Eq, PartialEq, Serialize, Deserialize, Ord, PartialOrd)] +pub struct ZCdpBudget { + epsilon: Ratio<BigUint>, +} + +impl ZCdpBudget { + /// Create a budget for parameter `epsilon`, using the notation from [[CKS20]] where `rho = (epsilon**2)/2` + /// for a `rho`-ZCDP budget. + /// + /// [CKS20]: https://arxiv.org/pdf/2004.00010.pdf + pub fn new(epsilon: Rational) -> Self { + Self { epsilon: epsilon.0 } + } +} + +impl DifferentialPrivacyBudget for ZCdpBudget {} + +/// Strategy to make aggregate results differentially private, e.g. by adding noise from a specific +/// type of distribution instantiated with a given DP budget. +pub trait DifferentialPrivacyStrategy { + /// The type of the DP budget, i.e. the variant of differential privacy that can be obtained + /// by using this strategy. + type Budget: DifferentialPrivacyBudget; + + /// The distribution type this strategy will use to generate the noise. + type Distribution: DifferentialPrivacyDistribution; + + /// The type the sensitivity used for privacy analysis has. + type Sensitivity; + + /// Create a strategy from a differential privacy budget. The distribution created with + /// `create_distribution` should provide the amount of privacy specified here. + fn from_budget(b: Self::Budget) -> Self; + + /// Create a new distribution parametrized s.t. adding samples to the result of a function + /// with sensitivity `s` will yield differential privacy of the DP variant given in the + /// `Budget` type. Can error upon invalid parameters. + fn create_distribution(&self, s: Self::Sensitivity) -> Result<Self::Distribution, DpError>; +} + +pub mod distributions; diff --git a/third_party/rust/prio/src/dp/distributions.rs b/third_party/rust/prio/src/dp/distributions.rs new file mode 100644 index 0000000000..ba0270df9c --- /dev/null +++ b/third_party/rust/prio/src/dp/distributions.rs @@ -0,0 +1,607 @@ +// Copyright (c) 2023 ISRG +// SPDX-License-Identifier: MPL-2.0 +// +// This file contains code covered by the following copyright and permission notice +// and has been modified by ISRG and collaborators. +// +// Copyright (c) 2022 President and Fellows of Harvard College +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in all +// copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +// SOFTWARE. +// +// This file incorporates work covered by the following copyright and +// permission notice: +// +// Copyright 2020 Thomas Steinke +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// The following code is adapted from the opendp implementation to reduce dependencies: +// https://github.com/opendp/opendp/blob/main/rust/src/traits/samplers/cks20 + +//! Implementation of a sampler from the Discrete Gaussian Distribution. +//! +//! Follows +//! Clément Canonne, Gautam Kamath, Thomas Steinke. The Discrete Gaussian for Differential Privacy. 2020. +//! <https://arxiv.org/pdf/2004.00010.pdf> + +use num_bigint::{BigInt, BigUint, UniformBigUint}; +use num_integer::Integer; +use num_iter::range_inclusive; +use num_rational::Ratio; +use num_traits::{One, Zero}; +use rand::{distributions::uniform::UniformSampler, distributions::Distribution, Rng}; +use serde::{Deserialize, Serialize}; + +use super::{ + DifferentialPrivacyBudget, DifferentialPrivacyDistribution, DifferentialPrivacyStrategy, + DpError, ZCdpBudget, +}; + +/// Sample from the Bernoulli(gamma) distribution, where $gamma /leq 1$. +/// +/// `sample_bernoulli(gamma, rng)` returns numbers distributed as $Bernoulli(gamma)$. +/// using the given random number generator for base randomness. The procedure is as described +/// on page 30 of [[CKS20]]. +/// +/// [CKS20]: https://arxiv.org/pdf/2004.00010.pdf +fn sample_bernoulli<R: Rng + ?Sized>(gamma: &Ratio<BigUint>, rng: &mut R) -> bool { + let d = gamma.denom(); + assert!(!d.is_zero()); + assert!(gamma <= &Ratio::<BigUint>::one()); + + // sample uniform biguint in {1,...,d} + // uses the implementation of rand::Uniform for num_bigint::BigUint + let s = UniformBigUint::sample_single_inclusive(BigUint::one(), d, rng); + + s <= *gamma.numer() +} + +/// Sample from the Bernoulli(exp(-gamma)) distribution where `gamma` is in `[0,1]`. +/// +/// `sample_bernoulli_exp1(gamma, rng)` returns numbers distributed as $Bernoulli(exp(-gamma))$, +/// using the given random number generator for base randomness. Follows Algorithm 1 of [[CKS20]], +/// splitting the branches into two non-recursive functions. This is the `gamma in [0,1]` branch. +/// +/// [CKS20]: https://arxiv.org/pdf/2004.00010.pdf +fn sample_bernoulli_exp1<R: Rng + ?Sized>(gamma: &Ratio<BigUint>, rng: &mut R) -> bool { + assert!(!gamma.denom().is_zero()); + assert!(gamma <= &Ratio::<BigUint>::one()); + + let mut k = BigUint::one(); + loop { + if sample_bernoulli(&(gamma / k.clone()), rng) { + k += 1u8; + } else { + return k.is_odd(); + } + } +} + +/// Sample from the Bernoulli(exp(-gamma)) distribution. +/// +/// `sample_bernoulli_exp(gamma, rng)` returns numbers distributed as $Bernoulli(exp(-gamma))$, +/// using the given random number generator for base randomness. Follows Algorithm 1 of [[CKS20]], +/// splitting the branches into two non-recursive functions. This is the `gamma > 1` branch. +/// +/// [CKS20]: https://arxiv.org/pdf/2004.00010.pdf +fn sample_bernoulli_exp<R: Rng + ?Sized>(gamma: &Ratio<BigUint>, rng: &mut R) -> bool { + assert!(!gamma.denom().is_zero()); + for _ in range_inclusive(BigUint::one(), gamma.floor().to_integer()) { + if !sample_bernoulli_exp1(&Ratio::<BigUint>::one(), rng) { + return false; + } + } + sample_bernoulli_exp1(&(gamma - gamma.floor()), rng) +} + +/// Sample from the geometric distribution with parameter 1 - exp(-gamma). +/// +/// `sample_geometric_exp(gamma, rng)` returns numbers distributed according to +/// $Geometric(1 - exp(-gamma))$, using the given random number generator for base randomness. +/// The code follows all but the last three lines of Algorithm 2 in [[CKS20]]. +/// +/// [CKS20]: https://arxiv.org/pdf/2004.00010.pdf +fn sample_geometric_exp<R: Rng + ?Sized>(gamma: &Ratio<BigUint>, rng: &mut R) -> BigUint { + let (s, t) = (gamma.numer(), gamma.denom()); + assert!(!t.is_zero()); + if gamma.is_zero() { + return BigUint::zero(); + } + + // sampler for uniform biguint in {0...t-1} + // uses the implementation of rand::Uniform for num_bigint::BigUint + let usampler = UniformBigUint::new(BigUint::zero(), t); + let mut u = usampler.sample(rng); + + while !sample_bernoulli_exp1(&Ratio::<BigUint>::new(u.clone(), t.clone()), rng) { + u = usampler.sample(rng); + } + + let mut v = BigUint::zero(); + loop { + if sample_bernoulli_exp1(&Ratio::<BigUint>::one(), rng) { + v += 1u8; + } else { + break; + } + } + + // we do integer division, so the following term equals floor((u + t*v)/s) + (u + t * v) / s +} + +/// Sample from the discrete Laplace distribution. +/// +/// `sample_discrete_laplace(scale, rng)` returns numbers distributed according to +/// $\mathcal{L}_\mathbb{Z}(0, scale)$, using the given random number generator for base randomness. +/// This follows Algorithm 2 of [[CKS20]], using a subfunction for geometric sampling. +/// +/// [CKS20]: https://arxiv.org/pdf/2004.00010.pdf +fn sample_discrete_laplace<R: Rng + ?Sized>(scale: &Ratio<BigUint>, rng: &mut R) -> BigInt { + let (s, t) = (scale.numer(), scale.denom()); + assert!(!t.is_zero()); + if s.is_zero() { + return BigInt::zero(); + } + + loop { + let negative = sample_bernoulli(&Ratio::<BigUint>::new(BigUint::one(), 2u8.into()), rng); + let y: BigInt = sample_geometric_exp(&scale.recip(), rng).into(); + if negative && y.is_zero() { + continue; + } else { + return if negative { -y } else { y }; + } + } +} + +/// Sample from the discrete Gaussian distribution. +/// +/// `sample_discrete_gaussian(sigma, rng)` returns `BigInt` numbers distributed as +/// $\mathcal{N}_\mathbb{Z}(0, sigma^2)$, using the given random number generator for base +/// randomness. Follows Algorithm 3 from [[CKS20]]. +/// +/// [CKS20]: https://arxiv.org/pdf/2004.00010.pdf +fn sample_discrete_gaussian<R: Rng + ?Sized>(sigma: &Ratio<BigUint>, rng: &mut R) -> BigInt { + assert!(!sigma.denom().is_zero()); + if sigma.is_zero() { + return 0.into(); + } + let t = sigma.floor() + BigUint::one(); + + // no need to compute these parts of the probability term every iteration + let summand = sigma.pow(2) / t.clone(); + // compute probability of accepting the laplace sample y + let prob = |term: Ratio<BigUint>| term.pow(2) * (sigma.pow(2) * BigUint::from(2u8)).recip(); + + loop { + let y = sample_discrete_laplace(&t, rng); + + // absolute value without type conversion + let y_abs: Ratio<BigUint> = BigUint::new(y.to_u32_digits().1).into(); + + // unsigned subtraction-followed-by-square + let prob: Ratio<BigUint> = if y_abs < summand { + prob(summand.clone() - y_abs) + } else { + prob(y_abs - summand.clone()) + }; + + if sample_bernoulli_exp(&prob, rng) { + return y; + } + } +} + +/// Samples `BigInt` numbers according to the discrete Gaussian distribution with mean zero. +/// The distribution is defined over the integers, represented by arbitrary-precision integers. +/// The sampling procedure follows [[CKS20]]. +/// +/// [CKS20]: https://arxiv.org/pdf/2004.00010.pdf +#[derive(Clone, Debug)] +pub struct DiscreteGaussian { + /// The standard deviation of the distribution. + std: Ratio<BigUint>, +} + +impl DiscreteGaussian { + /// Create a new sampler from the Discrete Gaussian Distribution with the given + /// standard deviation and mean zero. Errors if the input has denominator zero. + pub fn new(std: Ratio<BigUint>) -> Result<DiscreteGaussian, DpError> { + if std.denom().is_zero() { + return Err(DpError::ZeroDenominator); + } + Ok(DiscreteGaussian { std }) + } +} + +impl Distribution<BigInt> for DiscreteGaussian { + fn sample<R>(&self, rng: &mut R) -> BigInt + where + R: Rng + ?Sized, + { + sample_discrete_gaussian(&self.std, rng) + } +} + +impl DifferentialPrivacyDistribution for DiscreteGaussian {} + +/// A DP strategy using the discrete gaussian distribution. +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Ord, PartialOrd)] +pub struct DiscreteGaussianDpStrategy<B> +where + B: DifferentialPrivacyBudget, +{ + budget: B, +} + +/// A DP strategy using the discrete gaussian distribution providing zero-concentrated DP. +pub type ZCdpDiscreteGaussian = DiscreteGaussianDpStrategy<ZCdpBudget>; + +impl DifferentialPrivacyStrategy for DiscreteGaussianDpStrategy<ZCdpBudget> { + type Budget = ZCdpBudget; + type Distribution = DiscreteGaussian; + type Sensitivity = Ratio<BigUint>; + + fn from_budget(budget: ZCdpBudget) -> DiscreteGaussianDpStrategy<ZCdpBudget> { + DiscreteGaussianDpStrategy { budget } + } + + /// Create a new sampler from the Discrete Gaussian Distribution with a standard + /// deviation calibrated to provide `1/2 epsilon^2` zero-concentrated differential + /// privacy when added to the result of an integer-valued function with sensitivity + /// `sensitivity`, following Theorem 4 from [[CKS20]] + /// + /// [CKS20]: https://arxiv.org/pdf/2004.00010.pdf + fn create_distribution( + &self, + sensitivity: Ratio<BigUint>, + ) -> Result<DiscreteGaussian, DpError> { + DiscreteGaussian::new(sensitivity / self.budget.epsilon.clone()) + } +} + +#[cfg(test)] +mod tests { + + use super::*; + use crate::dp::Rational; + use crate::vdaf::xof::SeedStreamSha3; + + use num_bigint::{BigUint, Sign, ToBigInt, ToBigUint}; + use num_traits::{One, Signed, ToPrimitive}; + use rand::{distributions::Distribution, SeedableRng}; + use statrs::distribution::{ChiSquared, ContinuousCDF, Normal}; + use std::collections::HashMap; + + #[test] + fn test_discrete_gaussian() { + let sampler = + DiscreteGaussian::new(Ratio::<BigUint>::from_integer(BigUint::from(5u8))).unwrap(); + + // check samples are consistent + let mut rng = SeedStreamSha3::from_seed([0u8; 16]); + let samples: Vec<i8> = (0..10) + .map(|_| i8::try_from(sampler.sample(&mut rng)).unwrap()) + .collect(); + let samples1: Vec<i8> = (0..10) + .map(|_| i8::try_from(sampler.sample(&mut rng)).unwrap()) + .collect(); + assert_eq!(samples, vec![-3, -11, -3, 5, 1, 5, 2, 2, 1, 18]); + assert_eq!(samples1, vec![4, -4, -5, -2, 0, -5, -3, 1, 1, -2]); + } + + #[test] + /// Make sure that the distribution created by `create_distribution` + /// of `ZCdpDicreteGaussian` is the same one as manually creating one + /// by using the constructor of `DiscreteGaussian` directly. + fn test_zcdp_discrete_gaussian() { + // sample from a manually created distribution + let sampler1 = + DiscreteGaussian::new(Ratio::<BigUint>::from_integer(BigUint::from(4u8))).unwrap(); + let mut rng = SeedStreamSha3::from_seed([0u8; 16]); + let samples1: Vec<i8> = (0..10) + .map(|_| i8::try_from(sampler1.sample(&mut rng)).unwrap()) + .collect(); + + // sample from the distribution created by the `zcdp` strategy + let zcdp = ZCdpDiscreteGaussian { + budget: ZCdpBudget::new(Rational::try_from(0.25).unwrap()), + }; + let sampler2 = zcdp + .create_distribution(Ratio::<BigUint>::from_integer(1u8.into())) + .unwrap(); + let mut rng2 = SeedStreamSha3::from_seed([0u8; 16]); + let samples2: Vec<i8> = (0..10) + .map(|_| i8::try_from(sampler2.sample(&mut rng2)).unwrap()) + .collect(); + + assert_eq!(samples2, samples1); + } + + pub fn test_mean<FS: FnMut() -> BigInt>( + mut sampler: FS, + hyp_mean: f64, + hyp_var: f64, + alpha: f64, + n: u32, + ) -> bool { + // we test if the mean from our sampler is within the given error margin assuimng its + // normally distributed with mean hyp_mean and variance sqrt(hyp_var/n) + // this assumption is from the central limit theorem + + // inverse cdf (quantile function) is F s.t. P[X<=F(p)]=p for X ~ N(0,1) + // (i.e. X from the standard normal distribution) + let probit = |p| Normal::new(0.0, 1.0).unwrap().inverse_cdf(p); + + // x such that the probability of a N(0,1) variable attaining + // a value outside of (-x, x) is alpha + let z_stat = probit(alpha / 2.).abs(); + + // confidence interval for the mean + let abs_p_tol = Ratio::<BigInt>::from_float(z_stat * (hyp_var / n as f64).sqrt()).unwrap(); + + // take n samples from the distribution, compute empirical mean + let emp_mean = Ratio::<BigInt>::new((0..n).map(|_| sampler()).sum::<BigInt>(), n.into()); + + (emp_mean - Ratio::<BigInt>::from_float(hyp_mean).unwrap()).abs() < abs_p_tol + } + + fn histogram( + d: &Vec<BigInt>, + bin_bounds: &[Option<(BigInt, BigInt)>], + smallest: BigInt, + largest: BigInt, + ) -> HashMap<Option<(BigInt, BigInt)>, u64> { + // a binned histogram of the samples in `d` + // used for chi_square test + + fn insert<T>(hist: &mut HashMap<T, u64>, key: &T, val: u64) + where + T: Eq + std::hash::Hash + Clone, + { + *hist.entry(key.clone()).or_default() += val; + } + + // regular histogram + let mut hist = HashMap::<BigInt, u64>::new(); + //binned histogram + let mut bin_hist = HashMap::<Option<(BigInt, BigInt)>, u64>::new(); + + for val in d { + // throw outliers with bound bins + if val < &smallest || val > &largest { + insert(&mut bin_hist, &None, 1); + } else { + insert(&mut hist, val, 1); + } + } + // sort values into their bins + for (a, b) in bin_bounds.iter().flatten() { + for i in range_inclusive(a.clone(), b.clone()) { + if let Some(count) = hist.get(&i) { + insert(&mut bin_hist, &Some((a.clone(), b.clone())), *count); + } + } + } + bin_hist + } + + fn discrete_gauss_cdf_approx( + sigma: &BigUint, + bin_bounds: &[Option<(BigInt, BigInt)>], + ) -> HashMap<Option<(BigInt, BigInt)>, f64> { + // approximate bin probabilties from theoretical distribution + // formula is eq. (1) on page 3 of [[CKS20]] + // + // [CKS20]: https://arxiv.org/pdf/2004.00010.pdf + let sigma = BigInt::from_biguint(Sign::Plus, sigma.clone()); + let exp_sum = |lower: &BigInt, upper: &BigInt| { + range_inclusive(lower.clone(), upper.clone()) + .map(|x: BigInt| { + f64::exp( + Ratio::<BigInt>::new(-(x.pow(2)), 2 * sigma.pow(2)) + .to_f64() + .unwrap(), + ) + }) + .sum::<f64>() + }; + // denominator is approximate up to 10 times the variance + // outside of that probabilities should be very small + // so the error will be negligible for the test + let denom = exp_sum(&(-10i8 * sigma.pow(2)), &(10i8 * sigma.pow(2))); + + // compute probabilities for each bin + let mut cdf = HashMap::new(); + let mut p_outside = 1.0; // probability of not landing inside bin boundaries + for (a, b) in bin_bounds.iter().flatten() { + let entry = exp_sum(a, b) / denom; + assert!(!entry.is_zero() && entry.is_finite()); + cdf.insert(Some((a.clone(), b.clone())), entry); + p_outside -= entry; + } + cdf.insert(None, p_outside); + cdf + } + + fn chi_square(sigma: &BigUint, n_bins: usize, alpha: f64) -> bool { + // perform pearsons chi-squared test on the discrete gaussian sampler + + let sigma_signed = BigInt::from_biguint(Sign::Plus, sigma.clone()); + + // cut off at 3 times the std. and collect all outliers in a seperate bin + let global_bound = 3u8 * sigma_signed; + + // bounds of bins + let lower_bounds = range_inclusive(-global_bound.clone(), global_bound.clone()).step_by( + ((2u8 * global_bound.clone()) / BigInt::from(n_bins)) + .try_into() + .unwrap(), + ); + let mut bin_bounds: Vec<Option<(BigInt, BigInt)>> = std::iter::zip( + lower_bounds.clone().take(n_bins), + lower_bounds.map(|x: BigInt| x - 1u8).skip(1), + ) + .map(Some) + .collect(); + bin_bounds.push(None); // bin for outliers + + // approximate bin probabilities + let cdf = discrete_gauss_cdf_approx(sigma, &bin_bounds); + + // chi2 stat wants at least 5 expected entries per bin + // so we choose n_samples in a way that gives us that + let n_samples = cdf + .values() + .map(|val| f64::ceil(5.0 / *val) as u32) + .max() + .unwrap(); + + // collect that number of samples + let mut rng = SeedStreamSha3::from_seed([0u8; 16]); + let samples: Vec<BigInt> = (1..n_samples) + .map(|_| { + sample_discrete_gaussian(&Ratio::<BigUint>::from_integer(sigma.clone()), &mut rng) + }) + .collect(); + + // make a histogram from the samples + let hist = histogram(&samples, &bin_bounds, -global_bound.clone(), global_bound); + + // compute pearsons chi-squared test statistic + let stat: f64 = bin_bounds + .iter() + .map(|key| { + let expected = cdf.get(&(key.clone())).unwrap() * n_samples as f64; + if let Some(val) = hist.get(&(key.clone())) { + (*val as f64 - expected).powf(2.) / expected + } else { + 0.0 + } + }) + .sum::<f64>(); + + let chi2 = ChiSquared::new((cdf.len() - 1) as f64).unwrap(); + // the probability of observing X >= stat for X ~ chi-squared + // (the "p-value") + let p = 1.0 - chi2.cdf(stat); + + p > alpha + } + + #[test] + fn empirical_test_gauss() { + [100, 2000, 20000].iter().for_each(|p| { + let mut rng = SeedStreamSha3::from_seed([0u8; 16]); + let sampler = || { + sample_discrete_gaussian( + &Ratio::<BigUint>::from_integer((*p).to_biguint().unwrap()), + &mut rng, + ) + }; + let mean = 0.0; + let var = (p * p) as f64; + assert!( + test_mean(sampler, mean, var, 0.00001, 1000), + "Empirical evaluation of discrete Gaussian({:?}) sampler mean failed.", + p + ); + }); + // we only do chi square for std 100 because it's expensive + assert!(chi_square(&(100u8.to_biguint().unwrap()), 10, 0.05)); + } + + #[test] + fn empirical_test_bernoulli_mean() { + [2u8, 5u8, 7u8, 9u8].iter().for_each(|p| { + let mut rng = SeedStreamSha3::from_seed([0u8; 16]); + let sampler = || { + if sample_bernoulli( + &Ratio::<BigUint>::new(BigUint::one(), (*p).into()), + &mut rng, + ) { + BigInt::one() + } else { + BigInt::zero() + } + }; + let mean = 1. / (*p as f64); + let var = mean * (1. - mean); + assert!( + test_mean(sampler, mean, var, 0.00001, 1000), + "Empirical evaluation of the Bernoulli(1/{:?}) distribution mean failed", + p + ); + }) + } + + #[test] + fn empirical_test_geometric_mean() { + [2u8, 5u8, 7u8, 9u8].iter().for_each(|p| { + let mut rng = SeedStreamSha3::from_seed([0u8; 16]); + let sampler = || { + sample_geometric_exp( + &Ratio::<BigUint>::new(BigUint::one(), (*p).into()), + &mut rng, + ) + .to_bigint() + .unwrap() + }; + let p_prob = 1. - f64::exp(-(1. / *p as f64)); + let mean = (1. - p_prob) / p_prob; + let var = (1. - p_prob) / p_prob.powi(2); + assert!( + test_mean(sampler, mean, var, 0.0001, 1000), + "Empirical evaluation of the Geometric(1-exp(-1/{:?})) distribution mean failed", + p + ); + }) + } + + #[test] + fn empirical_test_laplace_mean() { + [2u8, 5u8, 7u8, 9u8].iter().for_each(|p| { + let mut rng = SeedStreamSha3::from_seed([0u8; 16]); + let sampler = || { + sample_discrete_laplace( + &Ratio::<BigUint>::new(BigUint::one(), (*p).into()), + &mut rng, + ) + }; + let mean = 0.0; + let var = (1. / *p as f64).powi(2); + assert!( + test_mean(sampler, mean, var, 0.0001, 1000), + "Empirical evaluation of the Laplace(0,1/{:?}) distribution mean failed", + p + ); + }) + } +} diff --git a/third_party/rust/prio/src/fft.rs b/third_party/rust/prio/src/fft.rs new file mode 100644 index 0000000000..cac59a89ea --- /dev/null +++ b/third_party/rust/prio/src/fft.rs @@ -0,0 +1,222 @@ +// SPDX-License-Identifier: MPL-2.0 + +//! This module implements an iterative FFT algorithm for computing the (inverse) Discrete Fourier +//! Transform (DFT) over a slice of field elements. + +use crate::field::FftFriendlyFieldElement; +use crate::fp::{log2, MAX_ROOTS}; + +use std::convert::TryFrom; + +/// An error returned by an FFT operation. +#[derive(Debug, PartialEq, Eq, thiserror::Error)] +pub enum FftError { + /// The output is too small. + #[error("output slice is smaller than specified size")] + OutputTooSmall, + /// The specified size is too large. + #[error("size is larger than than maximum permitted")] + SizeTooLarge, + /// The specified size is not a power of 2. + #[error("size is not a power of 2")] + SizeInvalid, +} + +/// Sets `outp` to the DFT of `inp`. +/// +/// Interpreting the input as the coefficients of a polynomial, the output is equal to the input +/// evaluated at points `p^0, p^1, ... p^(size-1)`, where `p` is the `2^size`-th principal root of +/// unity. +#[allow(clippy::many_single_char_names)] +pub fn discrete_fourier_transform<F: FftFriendlyFieldElement>( + outp: &mut [F], + inp: &[F], + size: usize, +) -> Result<(), FftError> { + let d = usize::try_from(log2(size as u128)).map_err(|_| FftError::SizeTooLarge)?; + + if size > outp.len() { + return Err(FftError::OutputTooSmall); + } + + if size > 1 << MAX_ROOTS { + return Err(FftError::SizeTooLarge); + } + + if size != 1 << d { + return Err(FftError::SizeInvalid); + } + + for (i, outp_val) in outp[..size].iter_mut().enumerate() { + let j = bitrev(d, i); + *outp_val = if j < inp.len() { inp[j] } else { F::zero() }; + } + + let mut w: F; + for l in 1..d + 1 { + w = F::one(); + let r = F::root(l).unwrap(); + let y = 1 << (l - 1); + let chunk = (size / y) >> 1; + + // unrolling first iteration of i-loop. + for j in 0..chunk { + let x = j << l; + let u = outp[x]; + let v = outp[x + y]; + outp[x] = u + v; + outp[x + y] = u - v; + } + + for i in 1..y { + w *= r; + for j in 0..chunk { + let x = (j << l) + i; + let u = outp[x]; + let v = w * outp[x + y]; + outp[x] = u + v; + outp[x + y] = u - v; + } + } + } + + Ok(()) +} + +/// Sets `outp` to the inverse of the DFT of `inp`. +#[cfg(test)] +pub(crate) fn discrete_fourier_transform_inv<F: FftFriendlyFieldElement>( + outp: &mut [F], + inp: &[F], + size: usize, +) -> Result<(), FftError> { + let size_inv = F::from(F::Integer::try_from(size).unwrap()).inv(); + discrete_fourier_transform(outp, inp, size)?; + discrete_fourier_transform_inv_finish(outp, size, size_inv); + Ok(()) +} + +/// An intermediate step in the computation of the inverse DFT. Exposing this function allows us to +/// amortize the cost the modular inverse across multiple inverse DFT operations. +pub(crate) fn discrete_fourier_transform_inv_finish<F: FftFriendlyFieldElement>( + outp: &mut [F], + size: usize, + size_inv: F, +) { + let mut tmp: F; + outp[0] *= size_inv; + outp[size >> 1] *= size_inv; + for i in 1..size >> 1 { + tmp = outp[i] * size_inv; + outp[i] = outp[size - i] * size_inv; + outp[size - i] = tmp; + } +} + +// bitrev returns the first d bits of x in reverse order. (Thanks, OEIS! https://oeis.org/A030109) +fn bitrev(d: usize, x: usize) -> usize { + let mut y = 0; + for i in 0..d { + y += ((x >> i) & 1) << (d - i); + } + y >> 1 +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::field::{random_vector, split_vector, Field128, Field64, FieldElement, FieldPrio2}; + use crate::polynomial::{poly_fft, PolyAuxMemory}; + + fn discrete_fourier_transform_then_inv_test<F: FftFriendlyFieldElement>() -> Result<(), FftError> + { + let test_sizes = [1, 2, 4, 8, 16, 256, 1024, 2048]; + + for size in test_sizes.iter() { + let mut tmp = vec![F::zero(); *size]; + let mut got = vec![F::zero(); *size]; + let want = random_vector(*size).unwrap(); + + discrete_fourier_transform(&mut tmp, &want, want.len())?; + discrete_fourier_transform_inv(&mut got, &tmp, tmp.len())?; + assert_eq!(got, want); + } + + Ok(()) + } + + #[test] + fn test_priov2_field32() { + discrete_fourier_transform_then_inv_test::<FieldPrio2>().expect("unexpected error"); + } + + #[test] + fn test_field64() { + discrete_fourier_transform_then_inv_test::<Field64>().expect("unexpected error"); + } + + #[test] + fn test_field128() { + discrete_fourier_transform_then_inv_test::<Field128>().expect("unexpected error"); + } + + #[test] + fn test_recursive_fft() { + let size = 128; + let mut mem = PolyAuxMemory::new(size / 2); + + let inp = random_vector(size).unwrap(); + let mut want = vec![FieldPrio2::zero(); size]; + let mut got = vec![FieldPrio2::zero(); size]; + + discrete_fourier_transform::<FieldPrio2>(&mut want, &inp, inp.len()).unwrap(); + + poly_fft( + &mut got, + &inp, + &mem.roots_2n, + size, + false, + &mut mem.fft_memory, + ); + + assert_eq!(got, want); + } + + // This test demonstrates a consequence of \[BBG+19, Fact 4.4\]: interpolating a polynomial + // over secret shares and summing up the coefficients is equivalent to interpolating a + // polynomial over the plaintext data. + #[test] + fn test_fft_linearity() { + let len = 16; + let num_shares = 3; + let x: Vec<Field64> = random_vector(len).unwrap(); + let mut x_shares = split_vector(&x, num_shares).unwrap(); + + // Just for fun, let's do something different with a subset of the inputs. For the first + // share, every odd element is set to the plaintext value. For all shares but the first, + // every odd element is set to 0. + for (i, x_val) in x.iter().enumerate() { + if i % 2 != 0 { + x_shares[0][i] = *x_val; + for x_share in x_shares[1..num_shares].iter_mut() { + x_share[i] = Field64::zero(); + } + } + } + + let mut got = vec![Field64::zero(); len]; + let mut buf = vec![Field64::zero(); len]; + for share in x_shares { + discrete_fourier_transform_inv(&mut buf, &share, len).unwrap(); + for i in 0..len { + got[i] += buf[i]; + } + } + + let mut want = vec![Field64::zero(); len]; + discrete_fourier_transform_inv(&mut want, &x, len).unwrap(); + + assert_eq!(got, want); + } +} diff --git a/third_party/rust/prio/src/field.rs b/third_party/rust/prio/src/field.rs new file mode 100644 index 0000000000..fb931de2d3 --- /dev/null +++ b/third_party/rust/prio/src/field.rs @@ -0,0 +1,1190 @@ +// Copyright (c) 2020 Apple Inc. +// SPDX-License-Identifier: MPL-2.0 + +//! Finite field arithmetic. +//! +//! Basic field arithmetic is captured in the [`FieldElement`] trait. Fields used in Prio implement +//! [`FftFriendlyFieldElement`], and have an associated element called the "generator" that +//! generates a multiplicative subgroup of order `2^n` for some `n`. + +#[cfg(feature = "crypto-dependencies")] +use crate::prng::{Prng, PrngError}; +use crate::{ + codec::{CodecError, Decode, Encode}, + fp::{FP128, FP32, FP64}, +}; +use serde::{ + de::{DeserializeOwned, Visitor}, + Deserialize, Deserializer, Serialize, Serializer, +}; +use std::{ + cmp::min, + convert::{TryFrom, TryInto}, + fmt::{self, Debug, Display, Formatter}, + hash::{Hash, Hasher}, + io::{Cursor, Read}, + marker::PhantomData, + ops::{ + Add, AddAssign, BitAnd, ControlFlow, Div, DivAssign, Mul, MulAssign, Neg, Shl, Shr, Sub, + SubAssign, + }, +}; +use subtle::{Choice, ConditionallyNegatable, ConditionallySelectable, ConstantTimeEq}; + +#[cfg(feature = "experimental")] +mod field255; + +#[cfg(feature = "experimental")] +pub use field255::Field255; + +/// Possible errors from finite field operations. +#[derive(Debug, thiserror::Error)] +pub enum FieldError { + /// Input sizes do not match. + #[error("input sizes do not match")] + InputSizeMismatch, + /// Returned when decoding a [`FieldElement`] from a too-short byte string. + #[error("short read from bytes")] + ShortRead, + /// Returned when decoding a [`FieldElement`] from a byte string that encodes an integer greater + /// than or equal to the field modulus. + #[error("read from byte slice exceeds modulus")] + ModulusOverflow, + /// Error while performing I/O. + #[error("I/O error")] + Io(#[from] std::io::Error), + /// Error encoding or decoding a field. + #[error("Codec error")] + Codec(#[from] CodecError), + /// Error converting to [`FieldElementWithInteger::Integer`]. + #[error("Integer TryFrom error")] + IntegerTryFrom, +} + +/// Objects with this trait represent an element of `GF(p)` for some prime `p`. +pub trait FieldElement: + Sized + + Debug + + Copy + + PartialEq + + Eq + + ConstantTimeEq + + ConditionallySelectable + + ConditionallyNegatable + + Add<Output = Self> + + AddAssign + + Sub<Output = Self> + + SubAssign + + Mul<Output = Self> + + MulAssign + + Div<Output = Self> + + DivAssign + + Neg<Output = Self> + + Display + + for<'a> TryFrom<&'a [u8], Error = FieldError> + // NOTE Ideally we would require `Into<[u8; Self::ENCODED_SIZE]>` instead of `Into<Vec<u8>>`, + // since the former avoids a heap allocation and can easily be converted into Vec<u8>, but that + // isn't possible yet[1]. However we can provide the impl on FieldElement implementations. + // [1]: https://github.com/rust-lang/rust/issues/60551 + + Into<Vec<u8>> + + Serialize + + DeserializeOwned + + Encode + + Decode + + 'static // NOTE This bound is needed for downcasting a `dyn Gadget<F>>` to a concrete type. +{ + /// Size in bytes of an encoded field element. + const ENCODED_SIZE: usize; + + /// Modular inversion, i.e., `self^-1 (mod p)`. If `self` is 0, then the output is undefined. + fn inv(&self) -> Self; + + /// Interprets the next [`Self::ENCODED_SIZE`] bytes from the input slice as an element of the + /// field. The `m` most significant bits are cleared, where `m` is equal to the length of + /// [`Self::Integer`] in bits minus the length of the modulus in bits. + /// + /// # Errors + /// + /// An error is returned if the provided slice is too small to encode a field element or if the + /// result encodes an integer larger than or equal to the field modulus. + /// + /// # Warnings + /// + /// This function should only be used within [`prng::Prng`] to convert a random byte string into + /// a field element. Use [`Self::decode`] to deserialize field elements. Use + /// [`field::rand`] or [`prng::Prng`] to randomly generate field elements. + #[doc(hidden)] + fn try_from_random(bytes: &[u8]) -> Result<Self, FieldError>; + + /// Returns the additive identity. + fn zero() -> Self; + + /// Returns the multiplicative identity. + fn one() -> Self; + + /// Convert a slice of field elements into a vector of bytes. + /// + /// # Notes + /// + /// Ideally we would implement `From<&[F: FieldElement]> for Vec<u8>` or the corresponding + /// `Into`, but the orphan rule and the stdlib's blanket implementations of `Into` make this + /// impossible. + fn slice_into_byte_vec(values: &[Self]) -> Vec<u8> { + let mut vec = Vec::with_capacity(values.len() * Self::ENCODED_SIZE); + encode_fieldvec(values, &mut vec); + vec + } + + /// Convert a slice of bytes into a vector of field elements. The slice is interpreted as a + /// sequence of [`Self::ENCODED_SIZE`]-byte sequences. + /// + /// # Errors + /// + /// Returns an error if the length of the provided byte slice is not a multiple of the size of a + /// field element, or if any of the values in the byte slice are invalid encodings of a field + /// element, because the encoded integer is larger than or equal to the field modulus. + /// + /// # Notes + /// + /// Ideally we would implement `From<&[u8]> for Vec<F: FieldElement>` or the corresponding + /// `Into`, but the orphan rule and the stdlib's blanket implementations of `Into` make this + /// impossible. + fn byte_slice_into_vec(bytes: &[u8]) -> Result<Vec<Self>, FieldError> { + if bytes.len() % Self::ENCODED_SIZE != 0 { + return Err(FieldError::ShortRead); + } + let mut vec = Vec::with_capacity(bytes.len() / Self::ENCODED_SIZE); + for chunk in bytes.chunks_exact(Self::ENCODED_SIZE) { + vec.push(Self::get_decoded(chunk)?); + } + Ok(vec) + } +} + +/// Extension trait for field elements that can be converted back and forth to an integer type. +/// +/// The `Integer` associated type is an integer (primitive or otherwise) that supports various +/// arithmetic operations. The order of the field is guaranteed to fit inside the range of the +/// integer type. This trait also defines methods on field elements, `pow` and `modulus`, that make +/// use of the associated integer type. +pub trait FieldElementWithInteger: FieldElement + From<Self::Integer> { + /// The error returned if converting `usize` to an `Integer` fails. + type IntegerTryFromError: std::error::Error; + + /// The error returned if converting an `Integer` to a `u64` fails. + type TryIntoU64Error: std::error::Error; + + /// The integer representation of a field element. + type Integer: Copy + + Debug + + Eq + + Ord + + BitAnd<Output = Self::Integer> + + Div<Output = Self::Integer> + + Shl<usize, Output = Self::Integer> + + Shr<usize, Output = Self::Integer> + + Add<Output = Self::Integer> + + Sub<Output = Self::Integer> + + From<Self> + + TryFrom<usize, Error = Self::IntegerTryFromError> + + TryInto<u64, Error = Self::TryIntoU64Error>; + + /// Modular exponentation, i.e., `self^exp (mod p)`. + fn pow(&self, exp: Self::Integer) -> Self; + + /// Returns the prime modulus `p`. + fn modulus() -> Self::Integer; +} + +/// Methods common to all `FieldElementWithInteger` implementations that are private to the crate. +pub(crate) trait FieldElementWithIntegerExt: FieldElementWithInteger { + /// Encode `input` as bitvector of elements of `Self`. Output is written into the `output` slice. + /// If `output.len()` is smaller than the number of bits required to respresent `input`, + /// an error is returned. + /// + /// # Arguments + /// + /// * `input` - The field element to encode + /// * `output` - The slice to write the encoded bits into. Least signicant bit comes first + fn fill_with_bitvector_representation( + input: &Self::Integer, + output: &mut [Self], + ) -> Result<(), FieldError> { + // Create a mutable copy of `input`. In each iteration of the following loop we take the + // least significant bit, and shift input to the right by one bit. + let mut i = *input; + + let one = Self::Integer::from(Self::one()); + for bit in output.iter_mut() { + let w = Self::from(i & one); + *bit = w; + i = i >> 1; + } + + // If `i` is still not zero, this means that it cannot be encoded by `bits` bits. + if i != Self::Integer::from(Self::zero()) { + return Err(FieldError::InputSizeMismatch); + } + + Ok(()) + } + + /// Encode `input` as `bits`-bit vector of elements of `Self` if it's small enough + /// to be represented with that many bits. + /// + /// # Arguments + /// + /// * `input` - The field element to encode + /// * `bits` - The number of bits to use for the encoding + fn encode_into_bitvector_representation( + input: &Self::Integer, + bits: usize, + ) -> Result<Vec<Self>, FieldError> { + let mut result = vec![Self::zero(); bits]; + Self::fill_with_bitvector_representation(input, &mut result)?; + Ok(result) + } + + /// Decode the bitvector-represented value `input` into a simple representation as a single + /// field element. + /// + /// # Errors + /// + /// This function errors if `2^input.len() - 1` does not fit into the field `Self`. + fn decode_from_bitvector_representation(input: &[Self]) -> Result<Self, FieldError> { + let fi_one = Self::Integer::from(Self::one()); + + if !Self::valid_integer_bitlength(input.len()) { + return Err(FieldError::ModulusOverflow); + } + + let mut decoded = Self::zero(); + for (l, bit) in input.iter().enumerate() { + let w = fi_one << l; + decoded += Self::from(w) * *bit; + } + Ok(decoded) + } + + /// Interpret `i` as [`Self::Integer`] if it's representable in that type and smaller than the + /// field modulus. + fn valid_integer_try_from<N>(i: N) -> Result<Self::Integer, FieldError> + where + Self::Integer: TryFrom<N>, + { + let i_int = Self::Integer::try_from(i).map_err(|_| FieldError::IntegerTryFrom)?; + if Self::modulus() <= i_int { + return Err(FieldError::ModulusOverflow); + } + Ok(i_int) + } + + /// Check if the largest number representable with `bits` bits (i.e. 2^bits - 1) is + /// representable in this field. + fn valid_integer_bitlength(bits: usize) -> bool { + if bits >= 8 * Self::ENCODED_SIZE { + return false; + } + if Self::modulus() >> bits != Self::Integer::from(Self::zero()) { + return true; + } + false + } +} + +impl<F: FieldElementWithInteger> FieldElementWithIntegerExt for F {} + +/// Methods common to all `FieldElement` implementations that are private to the crate. +pub(crate) trait FieldElementExt: FieldElement { + /// Try to interpret a slice of [`Self::ENCODED_SIZE`] random bytes as an element in the field. If + /// the input represents an integer greater than or equal to the field modulus, then + /// [`ControlFlow::Continue`] is returned instead, to indicate that an enclosing rejection sampling + /// loop should try again with different random bytes. + /// + /// # Panics + /// + /// Panics if `bytes` is not of length [`Self::ENCODED_SIZE`]. + fn from_random_rejection(bytes: &[u8]) -> ControlFlow<Self, ()> { + match Self::try_from_random(bytes) { + Ok(x) => ControlFlow::Break(x), + Err(FieldError::ModulusOverflow) => ControlFlow::Continue(()), + Err(err) => panic!("unexpected error: {err}"), + } + } +} + +impl<F: FieldElement> FieldElementExt for F {} + +/// serde Visitor implementation used to generically deserialize `FieldElement` +/// values from byte arrays. +pub(crate) struct FieldElementVisitor<F: FieldElement> { + pub(crate) phantom: PhantomData<F>, +} + +impl<'de, F: FieldElement> Visitor<'de> for FieldElementVisitor<F> { + type Value = F; + + fn expecting(&self, formatter: &mut Formatter) -> fmt::Result { + formatter.write_fmt(format_args!("an array of {} bytes", F::ENCODED_SIZE)) + } + + fn visit_bytes<E>(self, v: &[u8]) -> Result<Self::Value, E> + where + E: serde::de::Error, + { + Self::Value::try_from(v).map_err(E::custom) + } + + fn visit_seq<A>(self, mut seq: A) -> Result<Self::Value, A::Error> + where + A: serde::de::SeqAccess<'de>, + { + let mut bytes = vec![]; + while let Some(byte) = seq.next_element()? { + bytes.push(byte); + } + + self.visit_bytes(&bytes) + } +} + +/// Objects with this trait represent an element of `GF(p)`, where `p` is some prime and the +/// field's multiplicative group has a subgroup with an order that is a power of 2, and at least +/// `2^20`. +pub trait FftFriendlyFieldElement: FieldElementWithInteger { + /// Returns the size of the multiplicative subgroup generated by + /// [`FftFriendlyFieldElement::generator`]. + fn generator_order() -> Self::Integer; + + /// Returns the generator of the multiplicative subgroup of size + /// [`FftFriendlyFieldElement::generator_order`]. + fn generator() -> Self; + + /// Returns the `2^l`-th principal root of unity for any `l <= 20`. Note that the `2^0`-th + /// prinicpal root of unity is `1` by definition. + fn root(l: usize) -> Option<Self>; +} + +macro_rules! make_field { + ( + $(#[$meta:meta])* + $elem:ident, $int:ident, $fp:ident, $encoding_size:literal, + ) => { + $(#[$meta])* + /// + /// This structure represents a field element in a prime order field. The concrete + /// representation of the element is via the Montgomery domain. For an element `n` in + /// `GF(p)`, we store `n * R^-1 mod p` (where `R` is a given power of two). This + /// representation enables using a more efficient (and branchless) multiplication algorithm, + /// at the expense of having to convert elements between their Montgomery domain + /// representation and natural representation. For calculations with many multiplications or + /// exponentiations, this is worthwhile. + /// + /// As an invariant, this integer representing the field element in the Montgomery domain + /// must be less than the field modulus, `p`. + #[derive(Clone, Copy, PartialOrd, Ord, Default)] + pub struct $elem(u128); + + impl $elem { + /// Attempts to instantiate an `$elem` from the first `Self::ENCODED_SIZE` bytes in the + /// provided slice. The decoded value will be bitwise-ANDed with `mask` before reducing + /// it using the field modulus. + /// + /// # Errors + /// + /// An error is returned if the provided slice is not long enough to encode a field + /// element or if the decoded value is greater than the field prime. + /// + /// # Notes + /// + /// We cannot use `u128::from_le_bytes` or `u128::from_be_bytes` because those functions + /// expect inputs to be exactly 16 bytes long. Our encoding of most field elements is + /// more compact. + fn try_from_bytes(bytes: &[u8], mask: u128) -> Result<Self, FieldError> { + if Self::ENCODED_SIZE > bytes.len() { + return Err(FieldError::ShortRead); + } + + let mut int = 0; + for i in 0..Self::ENCODED_SIZE { + int |= (bytes[i] as u128) << (i << 3); + } + + int &= mask; + + if int >= $fp.p { + return Err(FieldError::ModulusOverflow); + } + // FieldParameters::montgomery() will return a value that has been fully reduced + // mod p, satisfying the invariant on Self. + Ok(Self($fp.montgomery(int))) + } + } + + impl PartialEq for $elem { + fn eq(&self, rhs: &Self) -> bool { + // The fields included in this comparison MUST match the fields + // used in Hash::hash + // https://doc.rust-lang.org/std/hash/trait.Hash.html#hash-and-eq + + // Check the invariant that the integer representation is fully reduced. + debug_assert!(self.0 < $fp.p); + debug_assert!(rhs.0 < $fp.p); + + self.0 == rhs.0 + } + } + + impl ConstantTimeEq for $elem { + fn ct_eq(&self, rhs: &Self) -> Choice { + self.0.ct_eq(&rhs.0) + } + } + + impl ConditionallySelectable for $elem { + fn conditional_select(a: &Self, b: &Self, choice: subtle::Choice) -> Self { + Self(u128::conditional_select(&a.0, &b.0, choice)) + } + } + + impl Hash for $elem { + fn hash<H: Hasher>(&self, state: &mut H) { + // The fields included in this hash MUST match the fields used + // in PartialEq::eq + // https://doc.rust-lang.org/std/hash/trait.Hash.html#hash-and-eq + + // Check the invariant that the integer representation is fully reduced. + debug_assert!(self.0 < $fp.p); + + self.0.hash(state); + } + } + + impl Eq for $elem {} + + impl Add for $elem { + type Output = $elem; + fn add(self, rhs: Self) -> Self { + // FieldParameters::add() returns a value that has been fully reduced + // mod p, satisfying the invariant on Self. + Self($fp.add(self.0, rhs.0)) + } + } + + impl Add for &$elem { + type Output = $elem; + fn add(self, rhs: Self) -> $elem { + *self + *rhs + } + } + + impl AddAssign for $elem { + fn add_assign(&mut self, rhs: Self) { + *self = *self + rhs; + } + } + + impl Sub for $elem { + type Output = $elem; + fn sub(self, rhs: Self) -> Self { + // We know that self.0 and rhs.0 are both less than p, thus FieldParameters::sub() + // returns a value less than p, satisfying the invariant on Self. + Self($fp.sub(self.0, rhs.0)) + } + } + + impl Sub for &$elem { + type Output = $elem; + fn sub(self, rhs: Self) -> $elem { + *self - *rhs + } + } + + impl SubAssign for $elem { + fn sub_assign(&mut self, rhs: Self) { + *self = *self - rhs; + } + } + + impl Mul for $elem { + type Output = $elem; + fn mul(self, rhs: Self) -> Self { + // FieldParameters::mul() always returns a value less than p, so the invariant on + // Self is satisfied. + Self($fp.mul(self.0, rhs.0)) + } + } + + impl Mul for &$elem { + type Output = $elem; + fn mul(self, rhs: Self) -> $elem { + *self * *rhs + } + } + + impl MulAssign for $elem { + fn mul_assign(&mut self, rhs: Self) { + *self = *self * rhs; + } + } + + impl Div for $elem { + type Output = $elem; + #[allow(clippy::suspicious_arithmetic_impl)] + fn div(self, rhs: Self) -> Self { + self * rhs.inv() + } + } + + impl Div for &$elem { + type Output = $elem; + fn div(self, rhs: Self) -> $elem { + *self / *rhs + } + } + + impl DivAssign for $elem { + fn div_assign(&mut self, rhs: Self) { + *self = *self / rhs; + } + } + + impl Neg for $elem { + type Output = $elem; + fn neg(self) -> Self { + // FieldParameters::neg() will return a value less than p because self.0 is less + // than p, and neg() dispatches to sub(). + Self($fp.neg(self.0)) + } + } + + impl Neg for &$elem { + type Output = $elem; + fn neg(self) -> $elem { + -(*self) + } + } + + impl From<$int> for $elem { + fn from(x: $int) -> Self { + // FieldParameters::montgomery() will return a value that has been fully reduced + // mod p, satisfying the invariant on Self. + Self($fp.montgomery(u128::try_from(x).unwrap())) + } + } + + impl From<$elem> for $int { + fn from(x: $elem) -> Self { + $int::try_from($fp.residue(x.0)).unwrap() + } + } + + impl PartialEq<$int> for $elem { + fn eq(&self, rhs: &$int) -> bool { + $fp.residue(self.0) == u128::try_from(*rhs).unwrap() + } + } + + impl<'a> TryFrom<&'a [u8]> for $elem { + type Error = FieldError; + + fn try_from(bytes: &[u8]) -> Result<Self, FieldError> { + Self::try_from_bytes(bytes, u128::MAX) + } + } + + impl From<$elem> for [u8; $elem::ENCODED_SIZE] { + fn from(elem: $elem) -> Self { + let int = $fp.residue(elem.0); + let mut slice = [0; $elem::ENCODED_SIZE]; + for i in 0..$elem::ENCODED_SIZE { + slice[i] = ((int >> (i << 3)) & 0xff) as u8; + } + slice + } + } + + impl From<$elem> for Vec<u8> { + fn from(elem: $elem) -> Self { + <[u8; $elem::ENCODED_SIZE]>::from(elem).to_vec() + } + } + + impl Display for $elem { + fn fmt(&self, f: &mut Formatter) -> std::fmt::Result { + write!(f, "{}", $fp.residue(self.0)) + } + } + + impl Debug for $elem { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + write!(f, "{}", $fp.residue(self.0)) + } + } + + // We provide custom [`serde::Serialize`] and [`serde::Deserialize`] implementations because + // the derived implementations would represent `FieldElement` values as the backing `u128`, + // which is not what we want because (1) we can be more efficient in all cases and (2) in + // some circumstances, [some serializers don't support `u128`](https://github.com/serde-rs/json/issues/625). + impl Serialize for $elem { + fn serialize<S: Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> { + let bytes: [u8; $elem::ENCODED_SIZE] = (*self).into(); + serializer.serialize_bytes(&bytes) + } + } + + impl<'de> Deserialize<'de> for $elem { + fn deserialize<D: Deserializer<'de>>(deserializer: D) -> Result<$elem, D::Error> { + deserializer.deserialize_bytes(FieldElementVisitor { phantom: PhantomData }) + } + } + + impl Encode for $elem { + fn encode(&self, bytes: &mut Vec<u8>) { + let slice = <[u8; $elem::ENCODED_SIZE]>::from(*self); + bytes.extend_from_slice(&slice); + } + + fn encoded_len(&self) -> Option<usize> { + Some(Self::ENCODED_SIZE) + } + } + + impl Decode for $elem { + fn decode(bytes: &mut Cursor<&[u8]>) -> Result<Self, CodecError> { + let mut value = [0u8; $elem::ENCODED_SIZE]; + bytes.read_exact(&mut value)?; + $elem::try_from_bytes(&value, u128::MAX).map_err(|e| { + CodecError::Other(Box::new(e) as Box<dyn std::error::Error + 'static + Send + Sync>) + }) + } + } + + impl FieldElement for $elem { + const ENCODED_SIZE: usize = $encoding_size; + fn inv(&self) -> Self { + // FieldParameters::inv() ultimately relies on mul(), and will always return a + // value less than p. + Self($fp.inv(self.0)) + } + + fn try_from_random(bytes: &[u8]) -> Result<Self, FieldError> { + $elem::try_from_bytes(bytes, $fp.bit_mask) + } + + fn zero() -> Self { + Self(0) + } + + fn one() -> Self { + Self($fp.roots[0]) + } + } + + impl FieldElementWithInteger for $elem { + type Integer = $int; + type IntegerTryFromError = <Self::Integer as TryFrom<usize>>::Error; + type TryIntoU64Error = <Self::Integer as TryInto<u64>>::Error; + + fn pow(&self, exp: Self::Integer) -> Self { + // FieldParameters::pow() relies on mul(), and will always return a value less + // than p. + Self($fp.pow(self.0, u128::try_from(exp).unwrap())) + } + + fn modulus() -> Self::Integer { + $fp.p as $int + } + } + + impl FftFriendlyFieldElement for $elem { + fn generator() -> Self { + Self($fp.g) + } + + fn generator_order() -> Self::Integer { + 1 << (Self::Integer::try_from($fp.num_roots).unwrap()) + } + + fn root(l: usize) -> Option<Self> { + if l < min($fp.roots.len(), $fp.num_roots+1) { + Some(Self($fp.roots[l])) + } else { + None + } + } + } + }; +} + +make_field!( + /// Same as Field32, but encoded in little endian for compatibility with Prio v2. + FieldPrio2, + u32, + FP32, + 4, +); + +make_field!( + /// `GF(18446744069414584321)`, a 64-bit field. + Field64, + u64, + FP64, + 8, +); + +make_field!( + /// `GF(340282366920938462946865773367900766209)`, a 128-bit field. + Field128, + u128, + FP128, + 16, +); + +/// Merge two vectors of fields by summing other_vector into accumulator. +/// +/// # Errors +/// +/// Fails if the two vectors do not have the same length. +pub(crate) fn merge_vector<F: FieldElement>( + accumulator: &mut [F], + other_vector: &[F], +) -> Result<(), FieldError> { + if accumulator.len() != other_vector.len() { + return Err(FieldError::InputSizeMismatch); + } + for (a, o) in accumulator.iter_mut().zip(other_vector.iter()) { + *a += *o; + } + + Ok(()) +} + +/// Outputs an additive secret sharing of the input. +#[cfg(all(feature = "crypto-dependencies", test))] +pub(crate) fn split_vector<F: FieldElement>( + inp: &[F], + num_shares: usize, +) -> Result<Vec<Vec<F>>, PrngError> { + if num_shares == 0 { + return Ok(vec![]); + } + + let mut outp = Vec::with_capacity(num_shares); + outp.push(inp.to_vec()); + + for _ in 1..num_shares { + let share: Vec<F> = random_vector(inp.len())?; + for (x, y) in outp[0].iter_mut().zip(&share) { + *x -= *y; + } + outp.push(share); + } + + Ok(outp) +} + +/// Generate a vector of uniformly distributed random field elements. +#[cfg(feature = "crypto-dependencies")] +#[cfg_attr(docsrs, doc(cfg(feature = "crypto-dependencies")))] +pub fn random_vector<F: FieldElement>(len: usize) -> Result<Vec<F>, PrngError> { + Ok(Prng::new()?.take(len).collect()) +} + +/// `encode_fieldvec` serializes a type that is equivalent to a vector of field elements. +#[inline(always)] +pub(crate) fn encode_fieldvec<F: FieldElement, T: AsRef<[F]>>(val: T, bytes: &mut Vec<u8>) { + for elem in val.as_ref() { + elem.encode(bytes); + } +} + +/// `decode_fieldvec` deserializes some number of field elements from a cursor, and advances the +/// cursor's position. +pub(crate) fn decode_fieldvec<F: FieldElement>( + count: usize, + input: &mut Cursor<&[u8]>, +) -> Result<Vec<F>, CodecError> { + let mut vec = Vec::with_capacity(count); + let mut buffer = [0u8; 64]; + assert!( + buffer.len() >= F::ENCODED_SIZE, + "field is too big for buffer" + ); + for _ in 0..count { + input.read_exact(&mut buffer[..F::ENCODED_SIZE])?; + vec.push( + F::try_from(&buffer[..F::ENCODED_SIZE]).map_err(|e| CodecError::Other(Box::new(e)))?, + ); + } + Ok(vec) +} + +#[cfg(test)] +pub(crate) mod test_utils { + use super::{FieldElement, FieldElementWithInteger}; + use crate::{codec::CodecError, field::FieldError, prng::Prng}; + use assert_matches::assert_matches; + use std::{ + collections::hash_map::DefaultHasher, + convert::{TryFrom, TryInto}, + fmt::Debug, + hash::{Hash, Hasher}, + io::Cursor, + ops::{Add, BitAnd, Div, Shl, Shr, Sub}, + }; + + /// A test-only copy of `FieldElementWithInteger`. + /// + /// This trait is only used in tests, and it is implemented on some fields that do not have + /// `FieldElementWithInteger` implementations. This separate trait is used in order to avoid + /// affecting trait resolution with conditional compilation. Additionally, this trait only + /// requires the `Integer` associated type satisfy `Clone`, not `Copy`, so that it may be used + /// with arbitrary precision integer implementations. + pub(crate) trait TestFieldElementWithInteger: + FieldElement + From<Self::Integer> + { + type IntegerTryFromError: std::error::Error; + type TryIntoU64Error: std::error::Error; + type Integer: Clone + + Debug + + Eq + + Ord + + BitAnd<Output = Self::Integer> + + Div<Output = Self::Integer> + + Shl<usize, Output = Self::Integer> + + Shr<usize, Output = Self::Integer> + + Add<Output = Self::Integer> + + Sub<Output = Self::Integer> + + From<Self> + + TryFrom<usize, Error = Self::IntegerTryFromError> + + TryInto<u64, Error = Self::TryIntoU64Error>; + + fn pow(&self, exp: Self::Integer) -> Self; + + fn modulus() -> Self::Integer; + } + + impl<F> TestFieldElementWithInteger for F + where + F: FieldElementWithInteger, + { + type IntegerTryFromError = <F as FieldElementWithInteger>::IntegerTryFromError; + type TryIntoU64Error = <F as FieldElementWithInteger>::TryIntoU64Error; + type Integer = <F as FieldElementWithInteger>::Integer; + + fn pow(&self, exp: Self::Integer) -> Self { + <F as FieldElementWithInteger>::pow(self, exp) + } + + fn modulus() -> Self::Integer { + <F as FieldElementWithInteger>::modulus() + } + } + + pub(crate) fn field_element_test_common<F: TestFieldElementWithInteger>() { + let mut prng: Prng<F, _> = Prng::new().unwrap(); + let int_modulus = F::modulus(); + let int_one = F::Integer::try_from(1).unwrap(); + let zero = F::zero(); + let one = F::one(); + let two = F::from(F::Integer::try_from(2).unwrap()); + let four = F::from(F::Integer::try_from(4).unwrap()); + + // add + assert_eq!(F::from(int_modulus.clone() - int_one.clone()) + one, zero); + assert_eq!(one + one, two); + assert_eq!(two + F::from(int_modulus.clone()), two); + + // add w/ assignment + let mut a = prng.get(); + let b = prng.get(); + let c = a + b; + a += b; + assert_eq!(a, c); + + // sub + assert_eq!(zero - one, F::from(int_modulus.clone() - int_one.clone())); + #[allow(clippy::eq_op)] + { + assert_eq!(one - one, zero); + } + assert_eq!(one + (-one), zero); + assert_eq!(two - F::from(int_modulus.clone()), two); + assert_eq!(one - F::from(int_modulus.clone() - int_one.clone()), two); + + // sub w/ assignment + let mut a = prng.get(); + let b = prng.get(); + let c = a - b; + a -= b; + assert_eq!(a, c); + + // add + sub + for _ in 0..100 { + let f = prng.get(); + let g = prng.get(); + assert_eq!(f + g - f - g, zero); + assert_eq!(f + g - g, f); + assert_eq!(f + g - f, g); + } + + // mul + assert_eq!(two * two, four); + assert_eq!(two * one, two); + assert_eq!(two * zero, zero); + assert_eq!(one * F::from(int_modulus.clone()), zero); + + // mul w/ assignment + let mut a = prng.get(); + let b = prng.get(); + let c = a * b; + a *= b; + assert_eq!(a, c); + + // integer conversion + assert_eq!(F::Integer::from(zero), F::Integer::try_from(0).unwrap()); + assert_eq!(F::Integer::from(one), F::Integer::try_from(1).unwrap()); + assert_eq!(F::Integer::from(two), F::Integer::try_from(2).unwrap()); + assert_eq!(F::Integer::from(four), F::Integer::try_from(4).unwrap()); + + // serialization + let test_inputs = vec![ + zero, + one, + prng.get(), + F::from(int_modulus.clone() - int_one.clone()), + ]; + for want in test_inputs.iter() { + let mut bytes = vec![]; + want.encode(&mut bytes); + + assert_eq!(bytes.len(), F::ENCODED_SIZE); + assert_eq!(want.encoded_len().unwrap(), F::ENCODED_SIZE); + + let got = F::get_decoded(&bytes).unwrap(); + assert_eq!(got, *want); + } + + let serialized_vec = F::slice_into_byte_vec(&test_inputs); + let deserialized = F::byte_slice_into_vec(&serialized_vec).unwrap(); + assert_eq!(deserialized, test_inputs); + + let test_input = prng.get(); + let json = serde_json::to_string(&test_input).unwrap(); + let deserialized = serde_json::from_str::<F>(&json).unwrap(); + assert_eq!(deserialized, test_input); + + let value = serde_json::from_str::<serde_json::Value>(&json).unwrap(); + let array = value.as_array().unwrap(); + for element in array { + element.as_u64().unwrap(); + } + + let err = F::byte_slice_into_vec(&[0]).unwrap_err(); + assert_matches!(err, FieldError::ShortRead); + + let err = F::byte_slice_into_vec(&vec![0xffu8; F::ENCODED_SIZE]).unwrap_err(); + assert_matches!(err, FieldError::Codec(CodecError::Other(err)) => { + assert_matches!(err.downcast_ref::<FieldError>(), Some(FieldError::ModulusOverflow)); + }); + + let insufficient = vec![0u8; F::ENCODED_SIZE - 1]; + let err = F::try_from(insufficient.as_ref()).unwrap_err(); + assert_matches!(err, FieldError::ShortRead); + let err = F::decode(&mut Cursor::new(&insufficient)).unwrap_err(); + assert_matches!(err, CodecError::Io(_)); + + let err = F::decode(&mut Cursor::new(&vec![0xffu8; F::ENCODED_SIZE])).unwrap_err(); + assert_matches!(err, CodecError::Other(err) => { + assert_matches!(err.downcast_ref::<FieldError>(), Some(FieldError::ModulusOverflow)); + }); + + // equality and hash: Generate many elements, confirm they are not equal, and confirm + // various products that should be equal have the same hash. Three is chosen as a generator + // here because it happens to generate fairly large subgroups of (Z/pZ)* for all four + // primes. + let three = F::from(F::Integer::try_from(3).unwrap()); + let mut powers_of_three = Vec::with_capacity(500); + let mut power = one; + for _ in 0..500 { + powers_of_three.push(power); + power *= three; + } + // Check all these elements are mutually not equal. + for i in 0..powers_of_three.len() { + let first = &powers_of_three[i]; + for second in &powers_of_three[0..i] { + assert_ne!(first, second); + } + } + + // Construct an element from a number that needs to be reduced, and test comparisons on it, + // confirming that it is reduced correctly. + let p = F::from(int_modulus.clone()); + assert_eq!(p, zero); + let p_plus_one = F::from(int_modulus + int_one); + assert_eq!(p_plus_one, one); + } + + pub(super) fn hash_helper<H: Hash>(input: H) -> u64 { + let mut hasher = DefaultHasher::new(); + input.hash(&mut hasher); + hasher.finish() + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::field::test_utils::{field_element_test_common, hash_helper}; + use crate::fp::MAX_ROOTS; + use crate::prng::Prng; + use assert_matches::assert_matches; + + #[test] + fn test_accumulate() { + let mut lhs = vec![FieldPrio2(1); 10]; + let rhs = vec![FieldPrio2(2); 10]; + + merge_vector(&mut lhs, &rhs).unwrap(); + + lhs.iter().for_each(|f| assert_eq!(*f, FieldPrio2(3))); + rhs.iter().for_each(|f| assert_eq!(*f, FieldPrio2(2))); + + let wrong_len = vec![FieldPrio2::zero(); 9]; + let result = merge_vector(&mut lhs, &wrong_len); + assert_matches!(result, Err(FieldError::InputSizeMismatch)); + } + + fn field_element_test<F: FftFriendlyFieldElement + Hash>() { + field_element_test_common::<F>(); + + let mut prng: Prng<F, _> = Prng::new().unwrap(); + let int_modulus = F::modulus(); + let int_one = F::Integer::try_from(1).unwrap(); + let zero = F::zero(); + let one = F::one(); + let two = F::from(F::Integer::try_from(2).unwrap()); + let four = F::from(F::Integer::try_from(4).unwrap()); + + // div + assert_eq!(four / two, two); + #[allow(clippy::eq_op)] + { + assert_eq!(two / two, one); + } + assert_eq!(zero / two, zero); + assert_eq!(two / zero, zero); // Undefined behavior + assert_eq!(zero.inv(), zero); // Undefined behavior + + // div w/ assignment + let mut a = prng.get(); + let b = prng.get(); + let c = a / b; + a /= b; + assert_eq!(a, c); + assert_eq!(hash_helper(a), hash_helper(c)); + + // mul + div + for _ in 0..100 { + let f = prng.get(); + if f == zero { + continue; + } + assert_eq!(f * f.inv(), one); + assert_eq!(f.inv() * f, one); + } + + // pow + assert_eq!(two.pow(F::Integer::try_from(0).unwrap()), one); + assert_eq!(two.pow(int_one), two); + assert_eq!(two.pow(F::Integer::try_from(2).unwrap()), four); + assert_eq!(two.pow(int_modulus - int_one), one); + assert_eq!(two.pow(int_modulus), two); + + // roots + let mut int_order = F::generator_order(); + for l in 0..MAX_ROOTS + 1 { + assert_eq!( + F::generator().pow(int_order), + F::root(l).unwrap(), + "failure for F::root({l})" + ); + int_order = int_order >> 1; + } + + // formatting + assert_eq!(format!("{zero}"), "0"); + assert_eq!(format!("{one}"), "1"); + assert_eq!(format!("{zero:?}"), "0"); + assert_eq!(format!("{one:?}"), "1"); + + let three = F::from(F::Integer::try_from(3).unwrap()); + let mut powers_of_three = Vec::with_capacity(500); + let mut power = one; + for _ in 0..500 { + powers_of_three.push(power); + power *= three; + } + + // Check that 3^i is the same whether it's calculated with pow() or repeated + // multiplication, with both equality and hash equality. + for (i, power) in powers_of_three.iter().enumerate() { + let result = three.pow(F::Integer::try_from(i).unwrap()); + assert_eq!(result, *power); + let hash1 = hash_helper(power); + let hash2 = hash_helper(result); + assert_eq!(hash1, hash2); + } + + // Check that 3^n = (3^i)*(3^(n-i)), via both equality and hash equality. + let expected_product = powers_of_three[powers_of_three.len() - 1]; + let expected_hash = hash_helper(expected_product); + for i in 0..powers_of_three.len() { + let a = powers_of_three[i]; + let b = powers_of_three[powers_of_three.len() - 1 - i]; + let product = a * b; + assert_eq!(product, expected_product); + assert_eq!(hash_helper(product), expected_hash); + } + } + + #[test] + fn test_field_prio2() { + field_element_test::<FieldPrio2>(); + } + + #[test] + fn test_field64() { + field_element_test::<Field64>(); + } + + #[test] + fn test_field128() { + field_element_test::<Field128>(); + } + + #[test] + fn test_encode_into_bitvector() { + let zero = Field128::zero(); + let one = Field128::one(); + let zero_enc = Field128::encode_into_bitvector_representation(&0, 4).unwrap(); + let one_enc = Field128::encode_into_bitvector_representation(&1, 4).unwrap(); + let fifteen_enc = Field128::encode_into_bitvector_representation(&15, 4).unwrap(); + assert_eq!(zero_enc, [zero; 4]); + assert_eq!(one_enc, [one, zero, zero, zero]); + assert_eq!(fifteen_enc, [one; 4]); + Field128::encode_into_bitvector_representation(&16, 4).unwrap_err(); + } + + #[test] + fn test_fill_bitvector() { + let zero = Field128::zero(); + let one = Field128::one(); + let mut output: Vec<Field128> = vec![zero; 6]; + Field128::fill_with_bitvector_representation(&9, &mut output[1..5]).unwrap(); + assert_eq!(output, [zero, one, zero, zero, one, zero]); + Field128::fill_with_bitvector_representation(&16, &mut output[1..5]).unwrap_err(); + } +} diff --git a/third_party/rust/prio/src/field/field255.rs b/third_party/rust/prio/src/field/field255.rs new file mode 100644 index 0000000000..fd06a6334a --- /dev/null +++ b/third_party/rust/prio/src/field/field255.rs @@ -0,0 +1,543 @@ +// Copyright (c) 2023 ISRG +// SPDX-License-Identifier: MPL-2.0 + +//! Finite field arithmetic for `GF(2^255 - 19)`. + +use crate::{ + codec::{CodecError, Decode, Encode}, + field::{FieldElement, FieldElementVisitor, FieldError}, +}; +use fiat_crypto::curve25519_64::{ + fiat_25519_add, fiat_25519_carry, fiat_25519_carry_mul, fiat_25519_from_bytes, + fiat_25519_loose_field_element, fiat_25519_opp, fiat_25519_relax, fiat_25519_selectznz, + fiat_25519_sub, fiat_25519_tight_field_element, fiat_25519_to_bytes, +}; +use serde::{Deserialize, Deserializer, Serialize, Serializer}; +use std::{ + convert::TryFrom, + fmt::{self, Debug, Display, Formatter}, + io::{Cursor, Read}, + marker::PhantomData, + mem::size_of, + ops::{Add, AddAssign, Div, DivAssign, Mul, MulAssign, Neg, Sub, SubAssign}, +}; +use subtle::{ + Choice, ConditionallySelectable, ConstantTimeEq, ConstantTimeGreater, ConstantTimeLess, +}; + +// `python3 -c "print(', '.join(hex(x) for x in (2**255-19).to_bytes(32, 'little')))"` +const MODULUS_LITTLE_ENDIAN: [u8; 32] = [ + 0xed, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0x7f, +]; + +/// `GF(2^255 - 19)`, a 255-bit field. +#[derive(Clone, Copy)] +#[cfg_attr(docsrs, doc(cfg(feature = "experimental")))] +pub struct Field255(fiat_25519_tight_field_element); + +impl Field255 { + /// Attempts to instantiate a `Field255` from the first `Self::ENCODED_SIZE` bytes in the + /// provided slice. + /// + /// # Errors + /// + /// An error is returned if the provided slice is not long enough to encode a field element or + /// if the decoded value is greater than the field prime. + fn try_from_bytes(bytes: &[u8], mask_top_bit: bool) -> Result<Self, FieldError> { + if Self::ENCODED_SIZE > bytes.len() { + return Err(FieldError::ShortRead); + } + + let mut value = [0u8; Self::ENCODED_SIZE]; + value.copy_from_slice(&bytes[..Self::ENCODED_SIZE]); + + if mask_top_bit { + value[31] &= 0b0111_1111; + } + + // Walk through the bytes of the provided value from most significant to least, + // and identify whether the first byte that differs from the field's modulus is less than + // the corresponding byte or greater than the corresponding byte, in constant time. (Or + // whether the provided value is equal to the field modulus.) + let mut less_than_modulus = Choice::from(0u8); + let mut greater_than_modulus = Choice::from(0u8); + for (value_byte, modulus_byte) in value.iter().rev().zip(MODULUS_LITTLE_ENDIAN.iter().rev()) + { + less_than_modulus |= value_byte.ct_lt(modulus_byte) & !greater_than_modulus; + greater_than_modulus |= value_byte.ct_gt(modulus_byte) & !less_than_modulus; + } + + if bool::from(!less_than_modulus) { + return Err(FieldError::ModulusOverflow); + } + + let mut output = fiat_25519_tight_field_element([0; 5]); + fiat_25519_from_bytes(&mut output, &value); + + Ok(Field255(output)) + } +} + +impl ConstantTimeEq for Field255 { + fn ct_eq(&self, rhs: &Self) -> Choice { + // The internal representation used by fiat-crypto is not 1-1 with the field, so it is + // necessary to compare field elements via their canonical encodings. + + let mut self_encoded = [0; 32]; + fiat_25519_to_bytes(&mut self_encoded, &self.0); + let mut rhs_encoded = [0; 32]; + fiat_25519_to_bytes(&mut rhs_encoded, &rhs.0); + + self_encoded.ct_eq(&rhs_encoded) + } +} + +impl ConditionallySelectable for Field255 { + fn conditional_select(a: &Self, b: &Self, choice: Choice) -> Self { + let mut output = [0; 5]; + fiat_25519_selectznz(&mut output, choice.unwrap_u8(), &(a.0).0, &(b.0).0); + Field255(fiat_25519_tight_field_element(output)) + } +} + +impl PartialEq for Field255 { + fn eq(&self, rhs: &Self) -> bool { + self.ct_eq(rhs).into() + } +} + +impl Eq for Field255 {} + +impl Add for Field255 { + type Output = Field255; + + fn add(self, rhs: Self) -> Field255 { + let mut loose_output = fiat_25519_loose_field_element([0; 5]); + fiat_25519_add(&mut loose_output, &self.0, &rhs.0); + let mut output = fiat_25519_tight_field_element([0; 5]); + fiat_25519_carry(&mut output, &loose_output); + Field255(output) + } +} + +impl AddAssign for Field255 { + fn add_assign(&mut self, rhs: Self) { + let mut loose_output = fiat_25519_loose_field_element([0; 5]); + fiat_25519_add(&mut loose_output, &self.0, &rhs.0); + fiat_25519_carry(&mut self.0, &loose_output); + } +} + +impl Sub for Field255 { + type Output = Field255; + + fn sub(self, rhs: Self) -> Field255 { + let mut loose_output = fiat_25519_loose_field_element([0; 5]); + fiat_25519_sub(&mut loose_output, &self.0, &rhs.0); + let mut output = fiat_25519_tight_field_element([0; 5]); + fiat_25519_carry(&mut output, &loose_output); + Field255(output) + } +} + +impl SubAssign for Field255 { + fn sub_assign(&mut self, rhs: Self) { + let mut loose_output = fiat_25519_loose_field_element([0; 5]); + fiat_25519_sub(&mut loose_output, &self.0, &rhs.0); + fiat_25519_carry(&mut self.0, &loose_output); + } +} + +impl Mul for Field255 { + type Output = Field255; + + fn mul(self, rhs: Self) -> Field255 { + let mut self_relaxed = fiat_25519_loose_field_element([0; 5]); + fiat_25519_relax(&mut self_relaxed, &self.0); + let mut rhs_relaxed = fiat_25519_loose_field_element([0; 5]); + fiat_25519_relax(&mut rhs_relaxed, &rhs.0); + let mut output = fiat_25519_tight_field_element([0; 5]); + fiat_25519_carry_mul(&mut output, &self_relaxed, &rhs_relaxed); + Field255(output) + } +} + +impl MulAssign for Field255 { + fn mul_assign(&mut self, rhs: Self) { + let mut self_relaxed = fiat_25519_loose_field_element([0; 5]); + fiat_25519_relax(&mut self_relaxed, &self.0); + let mut rhs_relaxed = fiat_25519_loose_field_element([0; 5]); + fiat_25519_relax(&mut rhs_relaxed, &rhs.0); + fiat_25519_carry_mul(&mut self.0, &self_relaxed, &rhs_relaxed); + } +} + +impl Div for Field255 { + type Output = Field255; + + fn div(self, _rhs: Self) -> Self::Output { + unimplemented!("Div is not implemented for Field255 because it's not needed yet") + } +} + +impl DivAssign for Field255 { + fn div_assign(&mut self, _rhs: Self) { + unimplemented!("DivAssign is not implemented for Field255 because it's not needed yet") + } +} + +impl Neg for Field255 { + type Output = Field255; + + fn neg(self) -> Field255 { + -&self + } +} + +impl<'a> Neg for &'a Field255 { + type Output = Field255; + + fn neg(self) -> Field255 { + let mut loose_output = fiat_25519_loose_field_element([0; 5]); + fiat_25519_opp(&mut loose_output, &self.0); + let mut output = fiat_25519_tight_field_element([0; 5]); + fiat_25519_carry(&mut output, &loose_output); + Field255(output) + } +} + +impl From<u64> for Field255 { + fn from(value: u64) -> Self { + let input_bytes = value.to_le_bytes(); + let mut field_bytes = [0u8; Self::ENCODED_SIZE]; + field_bytes[..input_bytes.len()].copy_from_slice(&input_bytes); + Self::try_from_bytes(&field_bytes, false).unwrap() + } +} + +impl<'a> TryFrom<&'a [u8]> for Field255 { + type Error = FieldError; + + fn try_from(bytes: &[u8]) -> Result<Self, FieldError> { + Self::try_from_bytes(bytes, false) + } +} + +impl From<Field255> for [u8; Field255::ENCODED_SIZE] { + fn from(element: Field255) -> Self { + let mut array = [0; Field255::ENCODED_SIZE]; + fiat_25519_to_bytes(&mut array, &element.0); + array + } +} + +impl From<Field255> for Vec<u8> { + fn from(elem: Field255) -> Vec<u8> { + <[u8; Field255::ENCODED_SIZE]>::from(elem).to_vec() + } +} + +impl Display for Field255 { + fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { + let encoded: [u8; Self::ENCODED_SIZE] = (*self).into(); + write!(f, "0x")?; + for byte in encoded { + write!(f, "{byte:02x}")?; + } + Ok(()) + } +} + +impl Debug for Field255 { + fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { + <Self as Display>::fmt(self, f) + } +} + +impl Serialize for Field255 { + fn serialize<S: Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> { + let bytes: [u8; Self::ENCODED_SIZE] = (*self).into(); + serializer.serialize_bytes(&bytes) + } +} + +impl<'de> Deserialize<'de> for Field255 { + fn deserialize<D: Deserializer<'de>>(deserializer: D) -> Result<Field255, D::Error> { + deserializer.deserialize_bytes(FieldElementVisitor { + phantom: PhantomData, + }) + } +} + +impl Encode for Field255 { + fn encode(&self, bytes: &mut Vec<u8>) { + bytes.extend_from_slice(&<[u8; Self::ENCODED_SIZE]>::from(*self)); + } + + fn encoded_len(&self) -> Option<usize> { + Some(Self::ENCODED_SIZE) + } +} + +impl Decode for Field255 { + fn decode(bytes: &mut Cursor<&[u8]>) -> Result<Self, CodecError> { + let mut value = [0u8; Self::ENCODED_SIZE]; + bytes.read_exact(&mut value)?; + Field255::try_from_bytes(&value, false).map_err(|e| { + CodecError::Other(Box::new(e) as Box<dyn std::error::Error + 'static + Send + Sync>) + }) + } +} + +impl FieldElement for Field255 { + const ENCODED_SIZE: usize = 32; + + fn inv(&self) -> Self { + unimplemented!("Field255::inv() is not implemented because it's not needed yet") + } + + fn try_from_random(bytes: &[u8]) -> Result<Self, FieldError> { + Field255::try_from_bytes(bytes, true) + } + + fn zero() -> Self { + Field255(fiat_25519_tight_field_element([0, 0, 0, 0, 0])) + } + + fn one() -> Self { + Field255(fiat_25519_tight_field_element([1, 0, 0, 0, 0])) + } +} + +impl Default for Field255 { + fn default() -> Self { + Field255::zero() + } +} + +impl TryFrom<Field255> for u64 { + type Error = FieldError; + + fn try_from(elem: Field255) -> Result<u64, FieldError> { + const PREFIX_LEN: usize = size_of::<u64>(); + let mut le_bytes = [0; 32]; + + fiat_25519_to_bytes(&mut le_bytes, &elem.0); + if !bool::from(le_bytes[PREFIX_LEN..].ct_eq(&[0_u8; 32 - PREFIX_LEN])) { + return Err(FieldError::IntegerTryFrom); + } + + Ok(u64::from_le_bytes( + le_bytes[..PREFIX_LEN].try_into().unwrap(), + )) + } +} + +#[cfg(test)] +mod tests { + use super::{Field255, MODULUS_LITTLE_ENDIAN}; + use crate::{ + codec::Encode, + field::{ + test_utils::{field_element_test_common, TestFieldElementWithInteger}, + FieldElement, FieldError, + }, + }; + use assert_matches::assert_matches; + use fiat_crypto::curve25519_64::{ + fiat_25519_from_bytes, fiat_25519_tight_field_element, fiat_25519_to_bytes, + }; + use num_bigint::BigUint; + use once_cell::sync::Lazy; + use std::convert::{TryFrom, TryInto}; + + static MODULUS: Lazy<BigUint> = Lazy::new(|| BigUint::from_bytes_le(&MODULUS_LITTLE_ENDIAN)); + + impl From<BigUint> for Field255 { + fn from(value: BigUint) -> Self { + let le_bytes_vec = (value % &*MODULUS).to_bytes_le(); + + let mut le_bytes_array = [0u8; 32]; + le_bytes_array[..le_bytes_vec.len()].copy_from_slice(&le_bytes_vec); + + let mut output = fiat_25519_tight_field_element([0; 5]); + fiat_25519_from_bytes(&mut output, &le_bytes_array); + Field255(output) + } + } + + impl From<Field255> for BigUint { + fn from(value: Field255) -> Self { + let mut le_bytes = [0u8; 32]; + fiat_25519_to_bytes(&mut le_bytes, &value.0); + BigUint::from_bytes_le(&le_bytes) + } + } + + impl TestFieldElementWithInteger for Field255 { + type Integer = BigUint; + type IntegerTryFromError = <Self::Integer as TryFrom<usize>>::Error; + type TryIntoU64Error = <Self::Integer as TryInto<u64>>::Error; + + fn pow(&self, _exp: Self::Integer) -> Self { + unimplemented!("Field255::pow() is not implemented because it's not needed yet") + } + + fn modulus() -> Self::Integer { + MODULUS.clone() + } + } + + #[test] + fn check_modulus() { + let modulus = Field255::modulus(); + let element = Field255::from(modulus); + // Note that these two objects represent the same field element, they encode to the same + // canonical value (32 zero bytes), but they do not have the same internal representation. + assert_eq!(element, Field255::zero()); + } + + #[test] + fn check_identities() { + let zero_bytes: [u8; 32] = Field255::zero().into(); + assert_eq!(zero_bytes, [0; 32]); + let one_bytes: [u8; 32] = Field255::one().into(); + assert_eq!( + one_bytes, + [ + 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0 + ] + ); + } + + #[test] + fn encode_endianness() { + let mut one_encoded = Vec::new(); + Field255::one().encode(&mut one_encoded); + assert_eq!( + one_encoded, + [ + 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0 + ] + ); + } + + #[test] + fn test_field255() { + field_element_test_common::<Field255>(); + } + + #[test] + fn try_from_bytes() { + assert_matches!( + Field255::try_from_bytes( + &[ + 0xed, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + 0xff, 0xff, 0xff, 0xff, 0xff, 0x7f, + ], + false, + ), + Err(FieldError::ModulusOverflow) + ); + assert_matches!( + Field255::try_from_bytes( + &[ + 0xee, 0xff, 0xff, 0xff, 0xff, 0xff, 0xfe, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xfe, + 0xff, 0xff, 0xff, 0xff, 0xff, 0x7f, + ], + false, + ), + Ok(_) + ); + assert_matches!( + Field255::try_from_bytes( + &[ + 0xec, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + 0xff, 0xff, 0xff, 0xff, 0xff, 0x7f, + ], + true, + ), + Ok(element) => assert_eq!(element + Field255::one(), Field255::zero()) + ); + assert_matches!( + Field255::try_from_bytes( + &[ + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x80, + ], + false + ), + Err(FieldError::ModulusOverflow) + ); + assert_matches!( + Field255::try_from_bytes( + &[ + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x80, + ], + true + ), + Ok(element) => assert_eq!(element, Field255::zero()) + ); + } + + #[test] + fn u64_conversion() { + assert_eq!(Field255::from(0u64), Field255::zero()); + assert_eq!(Field255::from(1u64), Field255::one()); + + let max_bytes: [u8; 32] = Field255::from(u64::MAX).into(); + assert_eq!( + max_bytes, + [ + 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00 + ] + ); + + let want: u64 = 0xffffffffffffffff; + assert_eq!(u64::try_from(Field255::from(want)).unwrap(), want); + + let want: u64 = 0x7000000000000001; + assert_eq!(u64::try_from(Field255::from(want)).unwrap(), want); + + let want: u64 = 0x1234123412341234; + assert_eq!(u64::try_from(Field255::from(want)).unwrap(), want); + + assert!(u64::try_from(Field255::try_from_bytes(&[1; 32], false).unwrap()).is_err()); + assert!(u64::try_from(Field255::try_from_bytes(&[2; 32], false).unwrap()).is_err()); + } + + #[test] + fn formatting() { + assert_eq!( + format!("{}", Field255::zero()), + "0x0000000000000000000000000000000000000000000000000000000000000000" + ); + assert_eq!( + format!("{}", Field255::one()), + "0x0100000000000000000000000000000000000000000000000000000000000000" + ); + assert_eq!( + format!("{}", -Field255::one()), + "0xecffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff7f" + ); + assert_eq!( + format!("{:?}", Field255::zero()), + "0x0000000000000000000000000000000000000000000000000000000000000000" + ); + assert_eq!( + format!("{:?}", Field255::one()), + "0x0100000000000000000000000000000000000000000000000000000000000000" + ); + } +} diff --git a/third_party/rust/prio/src/flp.rs b/third_party/rust/prio/src/flp.rs new file mode 100644 index 0000000000..1912ebab14 --- /dev/null +++ b/third_party/rust/prio/src/flp.rs @@ -0,0 +1,1059 @@ +// SPDX-License-Identifier: MPL-2.0 + +//! Implementation of the generic Fully Linear Proof (FLP) system specified in +//! [[draft-irtf-cfrg-vdaf-07]]. This is the main building block of [`Prio3`](crate::vdaf::prio3). +//! +//! The FLP is derived for any implementation of the [`Type`] trait. Such an implementation +//! specifies a validity circuit that defines the set of valid measurements, as well as the finite +//! field in which the validity circuit is evaluated. It also determines how raw measurements are +//! encoded as inputs to the validity circuit, and how aggregates are decoded from sums of +//! measurements. +//! +//! # Overview +//! +//! The proof system is comprised of three algorithms. The first, `prove`, is run by the prover in +//! order to generate a proof of a statement's validity. The second and third, `query` and +//! `decide`, are run by the verifier in order to check the proof. The proof asserts that the input +//! is an element of a language recognized by the arithmetic circuit. If an input is _not_ valid, +//! then the verification step will fail with high probability: +//! +//! ``` +//! use prio::flp::types::Count; +//! use prio::flp::Type; +//! use prio::field::{random_vector, FieldElement, Field64}; +//! +//! // The prover chooses a measurement. +//! let count = Count::new(); +//! let input: Vec<Field64> = count.encode_measurement(&0).unwrap(); +//! +//! // The prover and verifier agree on "joint randomness" used to generate and +//! // check the proof. The application needs to ensure that the prover +//! // "commits" to the input before this point. In Prio3, the joint +//! // randomness is derived from additive shares of the input. +//! let joint_rand = random_vector(count.joint_rand_len()).unwrap(); +//! +//! // The prover generates the proof. +//! let prove_rand = random_vector(count.prove_rand_len()).unwrap(); +//! let proof = count.prove(&input, &prove_rand, &joint_rand).unwrap(); +//! +//! // The verifier checks the proof. In the first step, the verifier "queries" +//! // the input and proof, getting the "verifier message" in response. It then +//! // inspects the verifier to decide if the input is valid. +//! let query_rand = random_vector(count.query_rand_len()).unwrap(); +//! let verifier = count.query(&input, &proof, &query_rand, &joint_rand, 1).unwrap(); +//! assert!(count.decide(&verifier).unwrap()); +//! ``` +//! +//! [draft-irtf-cfrg-vdaf-07]: https://datatracker.ietf.org/doc/draft-irtf-cfrg-vdaf/07/ + +#[cfg(feature = "experimental")] +use crate::dp::DifferentialPrivacyStrategy; +use crate::fft::{discrete_fourier_transform, discrete_fourier_transform_inv_finish, FftError}; +use crate::field::{FftFriendlyFieldElement, FieldElement, FieldElementWithInteger, FieldError}; +use crate::fp::log2; +use crate::polynomial::poly_eval; +use std::any::Any; +use std::convert::TryFrom; +use std::fmt::Debug; + +pub mod gadgets; +pub mod types; + +/// Errors propagated by methods in this module. +#[derive(Debug, thiserror::Error)] +pub enum FlpError { + /// Calling [`Type::prove`] returned an error. + #[error("prove error: {0}")] + Prove(String), + + /// Calling [`Type::query`] returned an error. + #[error("query error: {0}")] + Query(String), + + /// Calling [`Type::decide`] returned an error. + #[error("decide error: {0}")] + Decide(String), + + /// Calling a gadget returned an error. + #[error("gadget error: {0}")] + Gadget(String), + + /// Calling the validity circuit returned an error. + #[error("validity circuit error: {0}")] + Valid(String), + + /// Calling [`Type::encode_measurement`] returned an error. + #[error("value error: {0}")] + Encode(String), + + /// Calling [`Type::decode_result`] returned an error. + #[error("value error: {0}")] + Decode(String), + + /// Calling [`Type::truncate`] returned an error. + #[error("truncate error: {0}")] + Truncate(String), + + /// Generic invalid parameter. This may be returned when an FLP type cannot be constructed. + #[error("invalid paramter: {0}")] + InvalidParameter(String), + + /// Returned if an FFT operation propagates an error. + #[error("FFT error: {0}")] + Fft(#[from] FftError), + + /// Returned if a field operation encountered an error. + #[error("Field error: {0}")] + Field(#[from] FieldError), + + #[cfg(feature = "experimental")] + /// An error happened during noising. + #[error("differential privacy error: {0}")] + DifferentialPrivacy(#[from] crate::dp::DpError), + + /// Unit test error. + #[cfg(test)] + #[error("test failed: {0}")] + Test(String), +} + +/// A type. Implementations of this trait specify how a particular kind of measurement is encoded +/// as a vector of field elements and how validity of the encoded measurement is determined. +/// Validity is determined via an arithmetic circuit evaluated over the encoded measurement. +pub trait Type: Sized + Eq + Clone + Debug { + /// The Prio3 VDAF identifier corresponding to this type. + const ID: u32; + + /// The type of raw measurement to be encoded. + type Measurement: Clone + Debug; + + /// The type of aggregate result for this type. + type AggregateResult: Clone + Debug; + + /// The finite field used for this type. + type Field: FftFriendlyFieldElement; + + /// Encodes a measurement as a vector of [`Self::input_len`] field elements. + fn encode_measurement( + &self, + measurement: &Self::Measurement, + ) -> Result<Vec<Self::Field>, FlpError>; + + /// Decode an aggregate result. + fn decode_result( + &self, + data: &[Self::Field], + num_measurements: usize, + ) -> Result<Self::AggregateResult, FlpError>; + + /// Returns the sequence of gadgets associated with the validity circuit. + /// + /// # Notes + /// + /// The construction of [[BBCG+19], Theorem 4.3] uses a single gadget rather than many. The + /// idea to generalize the proof system to allow multiple gadgets is discussed briefly in + /// [[BBCG+19], Remark 4.5], but no construction is given. The construction implemented here + /// requires security analysis. + /// + /// [BBCG+19]: https://ia.cr/2019/188 + fn gadget(&self) -> Vec<Box<dyn Gadget<Self::Field>>>; + + /// Evaluates the validity circuit on an input and returns the output. + /// + /// # Parameters + /// + /// * `gadgets` is the sequence of gadgets, presumably output by [`Self::gadget`]. + /// * `input` is the input to be validated. + /// * `joint_rand` is the joint randomness shared by the prover and verifier. + /// * `num_shares` is the number of input shares. + /// + /// # Example usage + /// + /// Applications typically do not call this method directly. It is used internally by + /// [`Self::prove`] and [`Self::query`] to generate and verify the proof respectively. + /// + /// ``` + /// use prio::flp::types::Count; + /// use prio::flp::Type; + /// use prio::field::{random_vector, FieldElement, Field64}; + /// + /// let count = Count::new(); + /// let input: Vec<Field64> = count.encode_measurement(&1).unwrap(); + /// let joint_rand = random_vector(count.joint_rand_len()).unwrap(); + /// let v = count.valid(&mut count.gadget(), &input, &joint_rand, 1).unwrap(); + /// assert_eq!(v, Field64::zero()); + /// ``` + fn valid( + &self, + gadgets: &mut Vec<Box<dyn Gadget<Self::Field>>>, + input: &[Self::Field], + joint_rand: &[Self::Field], + num_shares: usize, + ) -> Result<Self::Field, FlpError>; + + /// Constructs an aggregatable output from an encoded input. Calling this method is only safe + /// once `input` has been validated. + fn truncate(&self, input: Vec<Self::Field>) -> Result<Vec<Self::Field>, FlpError>; + + /// The length in field elements of the encoded input returned by [`Self::encode_measurement`]. + fn input_len(&self) -> usize; + + /// The length in field elements of the proof generated for this type. + fn proof_len(&self) -> usize; + + /// The length in field elements of the verifier message constructed by [`Self::query`]. + fn verifier_len(&self) -> usize; + + /// The length of the truncated output (i.e., the output of [`Type::truncate`]). + fn output_len(&self) -> usize; + + /// The length of the joint random input. + fn joint_rand_len(&self) -> usize; + + /// The length in field elements of the random input consumed by the prover to generate a + /// proof. This is the same as the sum of the arity of each gadget in the validity circuit. + fn prove_rand_len(&self) -> usize; + + /// The length in field elements of the random input consumed by the verifier to make queries + /// against inputs and proofs. This is the same as the number of gadgets in the validity + /// circuit. + fn query_rand_len(&self) -> usize; + + /// Generate a proof of an input's validity. The return value is a sequence of + /// [`Self::proof_len`] field elements. + /// + /// # Parameters + /// + /// * `input` is the input. + /// * `prove_rand` is the prover' randomness. + /// * `joint_rand` is the randomness shared by the prover and verifier. + fn prove( + &self, + input: &[Self::Field], + prove_rand: &[Self::Field], + joint_rand: &[Self::Field], + ) -> Result<Vec<Self::Field>, FlpError> { + if input.len() != self.input_len() { + return Err(FlpError::Prove(format!( + "unexpected input length: got {}; want {}", + input.len(), + self.input_len() + ))); + } + + if prove_rand.len() != self.prove_rand_len() { + return Err(FlpError::Prove(format!( + "unexpected prove randomness length: got {}; want {}", + prove_rand.len(), + self.prove_rand_len() + ))); + } + + if joint_rand.len() != self.joint_rand_len() { + return Err(FlpError::Prove(format!( + "unexpected joint randomness length: got {}; want {}", + joint_rand.len(), + self.joint_rand_len() + ))); + } + + let mut prove_rand_len = 0; + let mut shims = self + .gadget() + .into_iter() + .map(|inner| { + let inner_arity = inner.arity(); + if prove_rand_len + inner_arity > prove_rand.len() { + return Err(FlpError::Prove(format!( + "short prove randomness: got {}; want at least {}", + prove_rand.len(), + prove_rand_len + inner_arity + ))); + } + + let gadget = Box::new(ProveShimGadget::new( + inner, + &prove_rand[prove_rand_len..prove_rand_len + inner_arity], + )?) as Box<dyn Gadget<Self::Field>>; + prove_rand_len += inner_arity; + + Ok(gadget) + }) + .collect::<Result<Vec<_>, FlpError>>()?; + assert_eq!(prove_rand_len, self.prove_rand_len()); + + // Create a buffer for storing the proof. The buffer is longer than the proof itself; the extra + // length is to accommodate the computation of each gadget polynomial. + let data_len = shims + .iter() + .map(|shim| { + let gadget_poly_len = gadget_poly_len(shim.degree(), wire_poly_len(shim.calls())); + + // Computing the gadget polynomial using FFT requires an amount of memory that is a + // power of 2. Thus we choose the smallest power of 2 that is at least as large as + // the gadget polynomial. The wire seeds are encoded in the proof, too, so we + // include the arity of the gadget to ensure there is always enough room at the end + // of the buffer to compute the next gadget polynomial. It's likely that the + // memory footprint here can be reduced, with a bit of care. + shim.arity() + gadget_poly_len.next_power_of_two() + }) + .sum(); + let mut proof = vec![Self::Field::zero(); data_len]; + + // Run the validity circuit with a sequence of "shim" gadgets that record the value of each + // input wire of each gadget evaluation. These values are used to construct the wire + // polynomials for each gadget in the next step. + let _ = self.valid(&mut shims, input, joint_rand, 1)?; + + // Construct the proof. + let mut proof_len = 0; + for shim in shims.iter_mut() { + let gadget = shim + .as_any() + .downcast_mut::<ProveShimGadget<Self::Field>>() + .unwrap(); + + // Interpolate the wire polynomials `f[0], ..., f[g_arity-1]` from the input wires of each + // evaluation of the gadget. + let m = wire_poly_len(gadget.calls()); + let m_inv = Self::Field::from( + <Self::Field as FieldElementWithInteger>::Integer::try_from(m).unwrap(), + ) + .inv(); + let mut f = vec![vec![Self::Field::zero(); m]; gadget.arity()]; + for ((coefficients, values), proof_val) in f[..gadget.arity()] + .iter_mut() + .zip(gadget.f_vals[..gadget.arity()].iter()) + .zip(proof[proof_len..proof_len + gadget.arity()].iter_mut()) + { + discrete_fourier_transform(coefficients, values, m)?; + discrete_fourier_transform_inv_finish(coefficients, m, m_inv); + + // The first point on each wire polynomial is a random value chosen by the prover. This + // point is stored in the proof so that the verifier can reconstruct the wire + // polynomials. + *proof_val = values[0]; + } + + // Construct the gadget polynomial `G(f[0], ..., f[g_arity-1])` and append it to `proof`. + let gadget_poly_len = gadget_poly_len(gadget.degree(), m); + let start = proof_len + gadget.arity(); + let end = start + gadget_poly_len.next_power_of_two(); + gadget.call_poly(&mut proof[start..end], &f)?; + proof_len += gadget.arity() + gadget_poly_len; + } + + // Truncate the buffer to the size of the proof. + assert_eq!(proof_len, self.proof_len()); + proof.truncate(proof_len); + Ok(proof) + } + + /// Query an input and proof and return the verifier message. The return value has length + /// [`Self::verifier_len`]. + /// + /// # Parameters + /// + /// * `input` is the input or input share. + /// * `proof` is the proof or proof share. + /// * `query_rand` is the verifier's randomness. + /// * `joint_rand` is the randomness shared by the prover and verifier. + /// * `num_shares` is the total number of input shares. + fn query( + &self, + input: &[Self::Field], + proof: &[Self::Field], + query_rand: &[Self::Field], + joint_rand: &[Self::Field], + num_shares: usize, + ) -> Result<Vec<Self::Field>, FlpError> { + if input.len() != self.input_len() { + return Err(FlpError::Query(format!( + "unexpected input length: got {}; want {}", + input.len(), + self.input_len() + ))); + } + + if proof.len() != self.proof_len() { + return Err(FlpError::Query(format!( + "unexpected proof length: got {}; want {}", + proof.len(), + self.proof_len() + ))); + } + + if query_rand.len() != self.query_rand_len() { + return Err(FlpError::Query(format!( + "unexpected query randomness length: got {}; want {}", + query_rand.len(), + self.query_rand_len() + ))); + } + + if joint_rand.len() != self.joint_rand_len() { + return Err(FlpError::Query(format!( + "unexpected joint randomness length: got {}; want {}", + joint_rand.len(), + self.joint_rand_len() + ))); + } + + let mut proof_len = 0; + let mut shims = self + .gadget() + .into_iter() + .enumerate() + .map(|(idx, gadget)| { + let gadget_degree = gadget.degree(); + let gadget_arity = gadget.arity(); + let m = (1 + gadget.calls()).next_power_of_two(); + let r = query_rand[idx]; + + // Make sure the query randomness isn't a root of unity. Evaluating the gadget + // polynomial at any of these points would be a privacy violation, since these points + // were used by the prover to construct the wire polynomials. + if r.pow(<Self::Field as FieldElementWithInteger>::Integer::try_from(m).unwrap()) + == Self::Field::one() + { + return Err(FlpError::Query(format!( + "invalid query randomness: encountered 2^{m}-th root of unity" + ))); + } + + // Compute the length of the sub-proof corresponding to the `idx`-th gadget. + let next_len = gadget_arity + gadget_degree * (m - 1) + 1; + let proof_data = &proof[proof_len..proof_len + next_len]; + proof_len += next_len; + + Ok(Box::new(QueryShimGadget::new(gadget, r, proof_data)?) + as Box<dyn Gadget<Self::Field>>) + }) + .collect::<Result<Vec<_>, _>>()?; + + // Create a buffer for the verifier data. This includes the output of the validity circuit and, + // for each gadget `shim[idx].inner`, the wire polynomials evaluated at the query randomness + // `query_rand[idx]` and the gadget polynomial evaluated at `query_rand[idx]`. + let data_len = 1 + shims.iter().map(|shim| shim.arity() + 1).sum::<usize>(); + let mut verifier = Vec::with_capacity(data_len); + + // Run the validity circuit with a sequence of "shim" gadgets that record the inputs to each + // wire for each gadget call. Record the output of the circuit and append it to the verifier + // message. + // + // NOTE The proof of [BBC+19, Theorem 4.3] assumes that the output of the validity circuit is + // equal to the output of the last gadget evaluation. Here we relax this assumption. This + // should be OK, since it's possible to transform any circuit into one for which this is true. + // (Needs security analysis.) + let validity = self.valid(&mut shims, input, joint_rand, num_shares)?; + verifier.push(validity); + + // Fill the buffer with the verifier message. + for (query_rand_val, shim) in query_rand[..shims.len()].iter().zip(shims.iter_mut()) { + let gadget = shim + .as_any() + .downcast_ref::<QueryShimGadget<Self::Field>>() + .unwrap(); + + // Reconstruct the wire polynomials `f[0], ..., f[g_arity-1]` and evaluate each wire + // polynomial at query randomness value. + let m = (1 + gadget.calls()).next_power_of_two(); + let m_inv = Self::Field::from( + <Self::Field as FieldElementWithInteger>::Integer::try_from(m).unwrap(), + ) + .inv(); + let mut f = vec![Self::Field::zero(); m]; + for wire in 0..gadget.arity() { + discrete_fourier_transform(&mut f, &gadget.f_vals[wire], m)?; + discrete_fourier_transform_inv_finish(&mut f, m, m_inv); + verifier.push(poly_eval(&f, *query_rand_val)); + } + + // Add the value of the gadget polynomial evaluated at the query randomness value. + verifier.push(gadget.p_at_r); + } + + assert_eq!(verifier.len(), self.verifier_len()); + Ok(verifier) + } + + /// Returns true if the verifier message indicates that the input from which it was generated is valid. + fn decide(&self, verifier: &[Self::Field]) -> Result<bool, FlpError> { + if verifier.len() != self.verifier_len() { + return Err(FlpError::Decide(format!( + "unexpected verifier length: got {}; want {}", + verifier.len(), + self.verifier_len() + ))); + } + + // Check if the output of the circuit is 0. + if verifier[0] != Self::Field::zero() { + return Ok(false); + } + + // Check that each of the proof polynomials are well-formed. + let mut gadgets = self.gadget(); + let mut verifier_len = 1; + for gadget in gadgets.iter_mut() { + let next_len = 1 + gadget.arity(); + + let e = gadget.call(&verifier[verifier_len..verifier_len + next_len - 1])?; + if e != verifier[verifier_len + next_len - 1] { + return Ok(false); + } + + verifier_len += next_len; + } + + Ok(true) + } + + /// Check whether `input` and `joint_rand` have the length expected by `self`, + /// return [`FlpError::Valid`] otherwise. + fn valid_call_check( + &self, + input: &[Self::Field], + joint_rand: &[Self::Field], + ) -> Result<(), FlpError> { + if input.len() != self.input_len() { + return Err(FlpError::Valid(format!( + "unexpected input length: got {}; want {}", + input.len(), + self.input_len(), + ))); + } + + if joint_rand.len() != self.joint_rand_len() { + return Err(FlpError::Valid(format!( + "unexpected joint randomness length: got {}; want {}", + joint_rand.len(), + self.joint_rand_len() + ))); + } + + Ok(()) + } + + /// Check if the length of `input` matches `self`'s `input_len()`, + /// return [`FlpError::Truncate`] otherwise. + fn truncate_call_check(&self, input: &[Self::Field]) -> Result<(), FlpError> { + if input.len() != self.input_len() { + return Err(FlpError::Truncate(format!( + "Unexpected input length: got {}; want {}", + input.len(), + self.input_len() + ))); + } + + Ok(()) + } +} + +/// A type which supports adding noise to aggregate shares for Server Differential Privacy. +#[cfg(feature = "experimental")] +pub trait TypeWithNoise<S>: Type +where + S: DifferentialPrivacyStrategy, +{ + /// Add noise to the aggregate share to obtain differential privacy. + fn add_noise_to_result( + &self, + dp_strategy: &S, + agg_result: &mut [Self::Field], + num_measurements: usize, + ) -> Result<(), FlpError>; +} + +/// A gadget, a non-affine arithmetic circuit that is called when evaluating a validity circuit. +pub trait Gadget<F: FftFriendlyFieldElement>: Debug { + /// Evaluates the gadget on input `inp` and returns the output. + fn call(&mut self, inp: &[F]) -> Result<F, FlpError>; + + /// Evaluate the gadget on input of a sequence of polynomials. The output is written to `outp`. + fn call_poly(&mut self, outp: &mut [F], inp: &[Vec<F>]) -> Result<(), FlpError>; + + /// Returns the arity of the gadget. This is the length of `inp` passed to `call` or + /// `call_poly`. + fn arity(&self) -> usize; + + /// Returns the circuit's arithmetic degree. This determines the minimum length the `outp` + /// buffer passed to `call_poly`. + fn degree(&self) -> usize; + + /// Returns the number of times the gadget is expected to be called. + fn calls(&self) -> usize; + + /// This call is used to downcast a `Box<dyn Gadget<F>>` to a concrete type. + fn as_any(&mut self) -> &mut dyn Any; +} + +// A "shim" gadget used during proof generation to record the input wires each time a gadget is +// evaluated. +#[derive(Debug)] +struct ProveShimGadget<F: FftFriendlyFieldElement> { + inner: Box<dyn Gadget<F>>, + + /// Points at which the wire polynomials are interpolated. + f_vals: Vec<Vec<F>>, + + /// The number of times the gadget has been called so far. + ct: usize, +} + +impl<F: FftFriendlyFieldElement> ProveShimGadget<F> { + fn new(inner: Box<dyn Gadget<F>>, prove_rand: &[F]) -> Result<Self, FlpError> { + let mut f_vals = vec![vec![F::zero(); 1 + inner.calls()]; inner.arity()]; + + for (prove_rand_val, wire_poly_vals) in + prove_rand[..f_vals.len()].iter().zip(f_vals.iter_mut()) + { + // Choose a random field element as the first point on the wire polynomial. + wire_poly_vals[0] = *prove_rand_val; + } + + Ok(Self { + inner, + f_vals, + ct: 1, + }) + } +} + +impl<F: FftFriendlyFieldElement> Gadget<F> for ProveShimGadget<F> { + fn call(&mut self, inp: &[F]) -> Result<F, FlpError> { + for (wire_poly_vals, inp_val) in self.f_vals[..inp.len()].iter_mut().zip(inp.iter()) { + wire_poly_vals[self.ct] = *inp_val; + } + self.ct += 1; + self.inner.call(inp) + } + + fn call_poly(&mut self, outp: &mut [F], inp: &[Vec<F>]) -> Result<(), FlpError> { + self.inner.call_poly(outp, inp) + } + + fn arity(&self) -> usize { + self.inner.arity() + } + + fn degree(&self) -> usize { + self.inner.degree() + } + + fn calls(&self) -> usize { + self.inner.calls() + } + + fn as_any(&mut self) -> &mut dyn Any { + self + } +} + +// A "shim" gadget used during proof verification to record the points at which the intermediate +// proof polynomials are evaluated. +#[derive(Debug)] +struct QueryShimGadget<F: FftFriendlyFieldElement> { + inner: Box<dyn Gadget<F>>, + + /// Points at which intermediate proof polynomials are interpolated. + f_vals: Vec<Vec<F>>, + + /// Points at which the gadget polynomial is interpolated. + p_vals: Vec<F>, + + /// The gadget polynomial evaluated on a random input `r`. + p_at_r: F, + + /// Used to compute an index into `p_val`. + step: usize, + + /// The number of times the gadget has been called so far. + ct: usize, +} + +impl<F: FftFriendlyFieldElement> QueryShimGadget<F> { + fn new(inner: Box<dyn Gadget<F>>, r: F, proof_data: &[F]) -> Result<Self, FlpError> { + let gadget_degree = inner.degree(); + let gadget_arity = inner.arity(); + let m = (1 + inner.calls()).next_power_of_two(); + let p = m * gadget_degree; + + // Each call to this gadget records the values at which intermediate proof polynomials were + // interpolated. The first point was a random value chosen by the prover and transmitted in + // the proof. + let mut f_vals = vec![vec![F::zero(); 1 + inner.calls()]; gadget_arity]; + for wire in 0..gadget_arity { + f_vals[wire][0] = proof_data[wire]; + } + + // Evaluate the gadget polynomial at roots of unity. + let size = p.next_power_of_two(); + let mut p_vals = vec![F::zero(); size]; + discrete_fourier_transform(&mut p_vals, &proof_data[gadget_arity..], size)?; + + // The step is used to compute the element of `p_val` that will be returned by a call to + // the gadget. + let step = (1 << (log2(p as u128) - log2(m as u128))) as usize; + + // Evaluate the gadget polynomial `p` at query randomness `r`. + let p_at_r = poly_eval(&proof_data[gadget_arity..], r); + + Ok(Self { + inner, + f_vals, + p_vals, + p_at_r, + step, + ct: 1, + }) + } +} + +impl<F: FftFriendlyFieldElement> Gadget<F> for QueryShimGadget<F> { + fn call(&mut self, inp: &[F]) -> Result<F, FlpError> { + for (wire_poly_vals, inp_val) in self.f_vals[..inp.len()].iter_mut().zip(inp.iter()) { + wire_poly_vals[self.ct] = *inp_val; + } + let outp = self.p_vals[self.ct * self.step]; + self.ct += 1; + Ok(outp) + } + + fn call_poly(&mut self, _outp: &mut [F], _inp: &[Vec<F>]) -> Result<(), FlpError> { + panic!("no-op"); + } + + fn arity(&self) -> usize { + self.inner.arity() + } + + fn degree(&self) -> usize { + self.inner.degree() + } + + fn calls(&self) -> usize { + self.inner.calls() + } + + fn as_any(&mut self) -> &mut dyn Any { + self + } +} + +/// Compute the length of the wire polynomial constructed from the given number of gadget calls. +#[inline] +pub(crate) fn wire_poly_len(num_calls: usize) -> usize { + (1 + num_calls).next_power_of_two() +} + +/// Compute the length of the gadget polynomial for a gadget with the given degree and from wire +/// polynomials of the given length. +#[inline] +pub(crate) fn gadget_poly_len(gadget_degree: usize, wire_poly_len: usize) -> usize { + gadget_degree * (wire_poly_len - 1) + 1 +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::field::{random_vector, split_vector, Field128}; + use crate::flp::gadgets::{Mul, PolyEval}; + use crate::polynomial::poly_range_check; + + use std::marker::PhantomData; + + // Simple integration test for the core FLP logic. You'll find more extensive unit tests for + // each implemented data type in src/types.rs. + #[test] + fn test_flp() { + const NUM_SHARES: usize = 2; + + let typ: TestType<Field128> = TestType::new(); + let input = typ.encode_measurement(&3).unwrap(); + assert_eq!(input.len(), typ.input_len()); + + let input_shares: Vec<Vec<Field128>> = split_vector(input.as_slice(), NUM_SHARES) + .unwrap() + .into_iter() + .collect(); + + let joint_rand = random_vector(typ.joint_rand_len()).unwrap(); + let prove_rand = random_vector(typ.prove_rand_len()).unwrap(); + let query_rand = random_vector(typ.query_rand_len()).unwrap(); + + let proof = typ.prove(&input, &prove_rand, &joint_rand).unwrap(); + assert_eq!(proof.len(), typ.proof_len()); + + let proof_shares: Vec<Vec<Field128>> = split_vector(&proof, NUM_SHARES) + .unwrap() + .into_iter() + .collect(); + + let verifier: Vec<Field128> = (0..NUM_SHARES) + .map(|i| { + typ.query( + &input_shares[i], + &proof_shares[i], + &query_rand, + &joint_rand, + NUM_SHARES, + ) + .unwrap() + }) + .reduce(|mut left, right| { + for (x, y) in left.iter_mut().zip(right.iter()) { + *x += *y; + } + left + }) + .unwrap(); + assert_eq!(verifier.len(), typ.verifier_len()); + + assert!(typ.decide(&verifier).unwrap()); + } + + /// A toy type used for testing multiple gadgets. Valid inputs of this type consist of a pair + /// of field elements `(x, y)` where `2 <= x < 5` and `x^3 == y`. + #[derive(Clone, Debug, PartialEq, Eq)] + struct TestType<F>(PhantomData<F>); + + impl<F> TestType<F> { + fn new() -> Self { + Self(PhantomData) + } + } + + impl<F: FftFriendlyFieldElement> Type for TestType<F> { + const ID: u32 = 0xFFFF0000; + type Measurement = F::Integer; + type AggregateResult = F::Integer; + type Field = F; + + fn valid( + &self, + g: &mut Vec<Box<dyn Gadget<F>>>, + input: &[F], + joint_rand: &[F], + _num_shares: usize, + ) -> Result<F, FlpError> { + let r = joint_rand[0]; + let mut res = F::zero(); + + // Check that `data[0]^3 == data[1]`. + let mut inp = [input[0], input[0]]; + inp[0] = g[0].call(&inp)?; + inp[0] = g[0].call(&inp)?; + let x3_diff = inp[0] - input[1]; + res += r * x3_diff; + + // Check that `data[0]` is in the correct range. + let x_checked = g[1].call(&[input[0]])?; + res += (r * r) * x_checked; + + Ok(res) + } + + fn input_len(&self) -> usize { + 2 + } + + fn proof_len(&self) -> usize { + // First chunk + let mul = 2 /* gadget arity */ + 2 /* gadget degree */ * ( + (1 + 2_usize /* gadget calls */).next_power_of_two() - 1) + 1; + + // Second chunk + let poly = 1 /* gadget arity */ + 3 /* gadget degree */ * ( + (1 + 1_usize /* gadget calls */).next_power_of_two() - 1) + 1; + + mul + poly + } + + fn verifier_len(&self) -> usize { + // First chunk + let mul = 1 + 2 /* gadget arity */; + + // Second chunk + let poly = 1 + 1 /* gadget arity */; + + 1 + mul + poly + } + + fn output_len(&self) -> usize { + self.input_len() + } + + fn joint_rand_len(&self) -> usize { + 1 + } + + fn prove_rand_len(&self) -> usize { + 3 + } + + fn query_rand_len(&self) -> usize { + 2 + } + + fn gadget(&self) -> Vec<Box<dyn Gadget<F>>> { + vec![ + Box::new(Mul::new(2)), + Box::new(PolyEval::new(poly_range_check(2, 5), 1)), + ] + } + + fn encode_measurement(&self, measurement: &F::Integer) -> Result<Vec<F>, FlpError> { + Ok(vec![ + F::from(*measurement), + F::from(*measurement).pow(F::Integer::try_from(3).unwrap()), + ]) + } + + fn truncate(&self, input: Vec<F>) -> Result<Vec<F>, FlpError> { + Ok(input) + } + + fn decode_result( + &self, + _data: &[F], + _num_measurements: usize, + ) -> Result<F::Integer, FlpError> { + panic!("not implemented"); + } + } + + // In https://github.com/divviup/libprio-rs/issues/254 an out-of-bounds bug was reported that + // gets triggered when the size of the buffer passed to `gadget.call_poly()` is larger than + // needed for computing the gadget polynomial. + #[test] + fn issue254() { + let typ: Issue254Type<Field128> = Issue254Type::new(); + let input = typ.encode_measurement(&0).unwrap(); + assert_eq!(input.len(), typ.input_len()); + let joint_rand = random_vector(typ.joint_rand_len()).unwrap(); + let prove_rand = random_vector(typ.prove_rand_len()).unwrap(); + let query_rand = random_vector(typ.query_rand_len()).unwrap(); + let proof = typ.prove(&input, &prove_rand, &joint_rand).unwrap(); + let verifier = typ + .query(&input, &proof, &query_rand, &joint_rand, 1) + .unwrap(); + assert_eq!(verifier.len(), typ.verifier_len()); + assert!(typ.decide(&verifier).unwrap()); + } + + #[derive(Clone, Debug, PartialEq, Eq)] + struct Issue254Type<F> { + num_gadget_calls: [usize; 2], + phantom: PhantomData<F>, + } + + impl<F> Issue254Type<F> { + fn new() -> Self { + Self { + // The bug is triggered when there are two gadgets, but it doesn't matter how many + // times the second gadget is called. + num_gadget_calls: [100, 0], + phantom: PhantomData, + } + } + } + + impl<F: FftFriendlyFieldElement> Type for Issue254Type<F> { + const ID: u32 = 0xFFFF0000; + type Measurement = F::Integer; + type AggregateResult = F::Integer; + type Field = F; + + fn valid( + &self, + g: &mut Vec<Box<dyn Gadget<F>>>, + input: &[F], + _joint_rand: &[F], + _num_shares: usize, + ) -> Result<F, FlpError> { + // This is a useless circuit, as it only accepts "0". Its purpose is to exercise the + // use of multiple gadgets, each of which is called an arbitrary number of times. + let mut res = F::zero(); + for _ in 0..self.num_gadget_calls[0] { + res += g[0].call(&[input[0]])?; + } + for _ in 0..self.num_gadget_calls[1] { + res += g[1].call(&[input[0]])?; + } + Ok(res) + } + + fn input_len(&self) -> usize { + 1 + } + + fn proof_len(&self) -> usize { + // First chunk + let first = 1 /* gadget arity */ + 2 /* gadget degree */ * ( + (1 + self.num_gadget_calls[0]).next_power_of_two() - 1) + 1; + + // Second chunk + let second = 1 /* gadget arity */ + 2 /* gadget degree */ * ( + (1 + self.num_gadget_calls[1]).next_power_of_two() - 1) + 1; + + first + second + } + + fn verifier_len(&self) -> usize { + // First chunk + let first = 1 + 1 /* gadget arity */; + + // Second chunk + let second = 1 + 1 /* gadget arity */; + + 1 + first + second + } + + fn output_len(&self) -> usize { + self.input_len() + } + + fn joint_rand_len(&self) -> usize { + 0 + } + + fn prove_rand_len(&self) -> usize { + // First chunk + let first = 1; // gadget arity + + // Second chunk + let second = 1; // gadget arity + + first + second + } + + fn query_rand_len(&self) -> usize { + 2 // number of gadgets + } + + fn gadget(&self) -> Vec<Box<dyn Gadget<F>>> { + let poly = poly_range_check(0, 2); // A polynomial with degree 2 + vec![ + Box::new(PolyEval::new(poly.clone(), self.num_gadget_calls[0])), + Box::new(PolyEval::new(poly, self.num_gadget_calls[1])), + ] + } + + fn encode_measurement(&self, measurement: &F::Integer) -> Result<Vec<F>, FlpError> { + Ok(vec![F::from(*measurement)]) + } + + fn truncate(&self, input: Vec<F>) -> Result<Vec<F>, FlpError> { + Ok(input) + } + + fn decode_result( + &self, + _data: &[F], + _num_measurements: usize, + ) -> Result<F::Integer, FlpError> { + panic!("not implemented"); + } + } +} diff --git a/third_party/rust/prio/src/flp/gadgets.rs b/third_party/rust/prio/src/flp/gadgets.rs new file mode 100644 index 0000000000..c2696665f4 --- /dev/null +++ b/third_party/rust/prio/src/flp/gadgets.rs @@ -0,0 +1,591 @@ +// SPDX-License-Identifier: MPL-2.0 + +//! A collection of gadgets. + +use crate::fft::{discrete_fourier_transform, discrete_fourier_transform_inv_finish}; +use crate::field::FftFriendlyFieldElement; +use crate::flp::{gadget_poly_len, wire_poly_len, FlpError, Gadget}; +use crate::polynomial::{poly_deg, poly_eval, poly_mul}; + +#[cfg(feature = "multithreaded")] +use rayon::prelude::*; + +use std::any::Any; +use std::convert::TryFrom; +use std::fmt::Debug; +use std::marker::PhantomData; + +/// For input polynomials larger than or equal to this threshold, gadgets will use FFT for +/// polynomial multiplication. Otherwise, the gadget uses direct multiplication. +const FFT_THRESHOLD: usize = 60; + +/// An arity-2 gadget that multiples its inputs. +#[derive(Clone, Debug, Eq, PartialEq)] +pub struct Mul<F: FftFriendlyFieldElement> { + /// Size of buffer for FFT operations. + n: usize, + /// Inverse of `n` in `F`. + n_inv: F, + /// The number of times this gadget will be called. + num_calls: usize, +} + +impl<F: FftFriendlyFieldElement> Mul<F> { + /// Return a new multiplier gadget. `num_calls` is the number of times this gadget will be + /// called by the validity circuit. + pub fn new(num_calls: usize) -> Self { + let n = gadget_poly_fft_mem_len(2, num_calls); + let n_inv = F::from(F::Integer::try_from(n).unwrap()).inv(); + Self { + n, + n_inv, + num_calls, + } + } + + // Multiply input polynomials directly. + pub(crate) fn call_poly_direct( + &mut self, + outp: &mut [F], + inp: &[Vec<F>], + ) -> Result<(), FlpError> { + let v = poly_mul(&inp[0], &inp[1]); + outp[..v.len()].clone_from_slice(&v); + Ok(()) + } + + // Multiply input polynomials using FFT. + pub(crate) fn call_poly_fft(&mut self, outp: &mut [F], inp: &[Vec<F>]) -> Result<(), FlpError> { + let n = self.n; + let mut buf = vec![F::zero(); n]; + + discrete_fourier_transform(&mut buf, &inp[0], n)?; + discrete_fourier_transform(outp, &inp[1], n)?; + + for i in 0..n { + buf[i] *= outp[i]; + } + + discrete_fourier_transform(outp, &buf, n)?; + discrete_fourier_transform_inv_finish(outp, n, self.n_inv); + Ok(()) + } +} + +impl<F: FftFriendlyFieldElement> Gadget<F> for Mul<F> { + fn call(&mut self, inp: &[F]) -> Result<F, FlpError> { + gadget_call_check(self, inp.len())?; + Ok(inp[0] * inp[1]) + } + + fn call_poly(&mut self, outp: &mut [F], inp: &[Vec<F>]) -> Result<(), FlpError> { + gadget_call_poly_check(self, outp, inp)?; + if inp[0].len() >= FFT_THRESHOLD { + self.call_poly_fft(outp, inp) + } else { + self.call_poly_direct(outp, inp) + } + } + + fn arity(&self) -> usize { + 2 + } + + fn degree(&self) -> usize { + 2 + } + + fn calls(&self) -> usize { + self.num_calls + } + + fn as_any(&mut self) -> &mut dyn Any { + self + } +} + +/// An arity-1 gadget that evaluates its input on some polynomial. +// +// TODO Make `poly` an array of length determined by a const generic. +#[derive(Clone, Debug, Eq, PartialEq)] +pub struct PolyEval<F: FftFriendlyFieldElement> { + poly: Vec<F>, + /// Size of buffer for FFT operations. + n: usize, + /// Inverse of `n` in `F`. + n_inv: F, + /// The number of times this gadget will be called. + num_calls: usize, +} + +impl<F: FftFriendlyFieldElement> PolyEval<F> { + /// Returns a gadget that evaluates its input on `poly`. `num_calls` is the number of times + /// this gadget is called by the validity circuit. + pub fn new(poly: Vec<F>, num_calls: usize) -> Self { + let n = gadget_poly_fft_mem_len(poly_deg(&poly), num_calls); + let n_inv = F::from(F::Integer::try_from(n).unwrap()).inv(); + Self { + poly, + n, + n_inv, + num_calls, + } + } +} + +impl<F: FftFriendlyFieldElement> PolyEval<F> { + // Multiply input polynomials directly. + fn call_poly_direct(&mut self, outp: &mut [F], inp: &[Vec<F>]) -> Result<(), FlpError> { + outp[0] = self.poly[0]; + let mut x = inp[0].to_vec(); + for i in 1..self.poly.len() { + for j in 0..x.len() { + outp[j] += self.poly[i] * x[j]; + } + + if i < self.poly.len() - 1 { + x = poly_mul(&x, &inp[0]); + } + } + Ok(()) + } + + // Multiply input polynomials using FFT. + fn call_poly_fft(&mut self, outp: &mut [F], inp: &[Vec<F>]) -> Result<(), FlpError> { + let n = self.n; + let inp = &inp[0]; + + let mut inp_vals = vec![F::zero(); n]; + discrete_fourier_transform(&mut inp_vals, inp, n)?; + + let mut x_vals = inp_vals.clone(); + let mut x = vec![F::zero(); n]; + x[..inp.len()].clone_from_slice(inp); + + outp[0] = self.poly[0]; + for i in 1..self.poly.len() { + for j in 0..n { + outp[j] += self.poly[i] * x[j]; + } + + if i < self.poly.len() - 1 { + for j in 0..n { + x_vals[j] *= inp_vals[j]; + } + + discrete_fourier_transform(&mut x, &x_vals, n)?; + discrete_fourier_transform_inv_finish(&mut x, n, self.n_inv); + } + } + Ok(()) + } +} + +impl<F: FftFriendlyFieldElement> Gadget<F> for PolyEval<F> { + fn call(&mut self, inp: &[F]) -> Result<F, FlpError> { + gadget_call_check(self, inp.len())?; + Ok(poly_eval(&self.poly, inp[0])) + } + + fn call_poly(&mut self, outp: &mut [F], inp: &[Vec<F>]) -> Result<(), FlpError> { + gadget_call_poly_check(self, outp, inp)?; + + for item in outp.iter_mut() { + *item = F::zero(); + } + + if inp[0].len() >= FFT_THRESHOLD { + self.call_poly_fft(outp, inp) + } else { + self.call_poly_direct(outp, inp) + } + } + + fn arity(&self) -> usize { + 1 + } + + fn degree(&self) -> usize { + poly_deg(&self.poly) + } + + fn calls(&self) -> usize { + self.num_calls + } + + fn as_any(&mut self) -> &mut dyn Any { + self + } +} + +/// Trait for abstracting over [`ParallelSum`]. +pub trait ParallelSumGadget<F: FftFriendlyFieldElement, G>: Gadget<F> + Debug { + /// Wraps `inner` into a sum gadget that calls it `chunks` many times, and adds the reuslts. + fn new(inner: G, chunks: usize) -> Self; +} + +/// A wrapper gadget that applies the inner gadget to chunks of input and returns the sum of the +/// outputs. The arity is equal to the arity of the inner gadget times the number of times it is +/// called. +#[derive(Clone, Debug, Eq, PartialEq)] +pub struct ParallelSum<F: FftFriendlyFieldElement, G: Gadget<F>> { + inner: G, + chunks: usize, + phantom: PhantomData<F>, +} + +impl<F: FftFriendlyFieldElement, G: 'static + Gadget<F>> ParallelSumGadget<F, G> + for ParallelSum<F, G> +{ + fn new(inner: G, chunks: usize) -> Self { + Self { + inner, + chunks, + phantom: PhantomData, + } + } +} + +impl<F: FftFriendlyFieldElement, G: 'static + Gadget<F>> Gadget<F> for ParallelSum<F, G> { + fn call(&mut self, inp: &[F]) -> Result<F, FlpError> { + gadget_call_check(self, inp.len())?; + let mut outp = F::zero(); + for chunk in inp.chunks(self.inner.arity()) { + outp += self.inner.call(chunk)?; + } + Ok(outp) + } + + fn call_poly(&mut self, outp: &mut [F], inp: &[Vec<F>]) -> Result<(), FlpError> { + gadget_call_poly_check(self, outp, inp)?; + + for x in outp.iter_mut() { + *x = F::zero(); + } + + let mut partial_outp = vec![F::zero(); outp.len()]; + + for chunk in inp.chunks(self.inner.arity()) { + self.inner.call_poly(&mut partial_outp, chunk)?; + for i in 0..outp.len() { + outp[i] += partial_outp[i] + } + } + + Ok(()) + } + + fn arity(&self) -> usize { + self.chunks * self.inner.arity() + } + + fn degree(&self) -> usize { + self.inner.degree() + } + + fn calls(&self) -> usize { + self.inner.calls() + } + + fn as_any(&mut self) -> &mut dyn Any { + self + } +} + +/// A wrapper gadget that applies the inner gadget to chunks of input and returns the sum of the +/// outputs. The arity is equal to the arity of the inner gadget times the number of chunks. The sum +/// evaluation is multithreaded. +#[cfg(feature = "multithreaded")] +#[cfg_attr(docsrs, doc(cfg(feature = "multithreaded")))] +#[derive(Clone, Debug, Eq, PartialEq)] +pub struct ParallelSumMultithreaded<F: FftFriendlyFieldElement, G: Gadget<F>> { + serial_sum: ParallelSum<F, G>, +} + +#[cfg(feature = "multithreaded")] +impl<F, G> ParallelSumGadget<F, G> for ParallelSumMultithreaded<F, G> +where + F: FftFriendlyFieldElement + Sync + Send, + G: 'static + Gadget<F> + Clone + Sync + Send, +{ + fn new(inner: G, chunks: usize) -> Self { + Self { + serial_sum: ParallelSum::new(inner, chunks), + } + } +} + +/// Data structures passed between fold operations in [`ParallelSumMultithreaded`]. +#[cfg(feature = "multithreaded")] +struct ParallelSumFoldState<F, G> { + /// Inner gadget. + inner: G, + /// Output buffer for `call_poly()`. + partial_output: Vec<F>, + /// Sum accumulator. + partial_sum: Vec<F>, +} + +#[cfg(feature = "multithreaded")] +impl<F, G> ParallelSumFoldState<F, G> { + fn new(gadget: &G, length: usize) -> ParallelSumFoldState<F, G> + where + G: Clone, + F: FftFriendlyFieldElement, + { + ParallelSumFoldState { + inner: gadget.clone(), + partial_output: vec![F::zero(); length], + partial_sum: vec![F::zero(); length], + } + } +} + +#[cfg(feature = "multithreaded")] +impl<F, G> Gadget<F> for ParallelSumMultithreaded<F, G> +where + F: FftFriendlyFieldElement + Sync + Send, + G: 'static + Gadget<F> + Clone + Sync + Send, +{ + fn call(&mut self, inp: &[F]) -> Result<F, FlpError> { + self.serial_sum.call(inp) + } + + fn call_poly(&mut self, outp: &mut [F], inp: &[Vec<F>]) -> Result<(), FlpError> { + gadget_call_poly_check(self, outp, inp)?; + + // Create a copy of the inner gadget and two working buffers on each thread. Evaluate the + // gadget on each input polynomial, using the first temporary buffer as an output buffer. + // Then accumulate that result into the second temporary buffer, which acts as a running + // sum. Then, discard everything but the partial sums, add them, and finally copy the sum + // to the output parameter. This is equivalent to the single threaded calculation in + // ParallelSum, since we only rearrange additions, and field addition is associative. + let res = inp + .par_chunks(self.serial_sum.inner.arity()) + .fold( + || ParallelSumFoldState::new(&self.serial_sum.inner, outp.len()), + |mut state, chunk| { + state + .inner + .call_poly(&mut state.partial_output, chunk) + .unwrap(); + for (sum_elem, output_elem) in state + .partial_sum + .iter_mut() + .zip(state.partial_output.iter()) + { + *sum_elem += *output_elem; + } + state + }, + ) + .map(|state| state.partial_sum) + .reduce( + || vec![F::zero(); outp.len()], + |mut x, y| { + for (xi, yi) in x.iter_mut().zip(y.iter()) { + *xi += *yi; + } + x + }, + ); + + outp.copy_from_slice(&res[..]); + Ok(()) + } + + fn arity(&self) -> usize { + self.serial_sum.arity() + } + + fn degree(&self) -> usize { + self.serial_sum.degree() + } + + fn calls(&self) -> usize { + self.serial_sum.calls() + } + + fn as_any(&mut self) -> &mut dyn Any { + self + } +} + +// Check that the input parameters of g.call() are well-formed. +fn gadget_call_check<F: FftFriendlyFieldElement, G: Gadget<F>>( + gadget: &G, + in_len: usize, +) -> Result<(), FlpError> { + if in_len != gadget.arity() { + return Err(FlpError::Gadget(format!( + "unexpected number of inputs: got {}; want {}", + in_len, + gadget.arity() + ))); + } + + if in_len == 0 { + return Err(FlpError::Gadget("can't call an arity-0 gadget".to_string())); + } + + Ok(()) +} + +// Check that the input parameters of g.call_poly() are well-formed. +fn gadget_call_poly_check<F: FftFriendlyFieldElement, G: Gadget<F>>( + gadget: &G, + outp: &[F], + inp: &[Vec<F>], +) -> Result<(), FlpError> +where + G: Gadget<F>, +{ + gadget_call_check(gadget, inp.len())?; + + for i in 1..inp.len() { + if inp[i].len() != inp[0].len() { + return Err(FlpError::Gadget( + "gadget called on wire polynomials with different lengths".to_string(), + )); + } + } + + let expected = gadget_poly_len(gadget.degree(), inp[0].len()).next_power_of_two(); + if outp.len() != expected { + return Err(FlpError::Gadget(format!( + "incorrect output length: got {}; want {}", + outp.len(), + expected + ))); + } + + Ok(()) +} + +#[inline] +fn gadget_poly_fft_mem_len(degree: usize, num_calls: usize) -> usize { + gadget_poly_len(degree, wire_poly_len(num_calls)).next_power_of_two() +} + +#[cfg(test)] +mod tests { + use super::*; + + #[cfg(feature = "multithreaded")] + use crate::field::FieldElement; + use crate::field::{random_vector, Field64 as TestField}; + use crate::prng::Prng; + + #[test] + fn test_mul() { + // Test the gadget with input polynomials shorter than `FFT_THRESHOLD`. This exercises the + // naive multiplication code path. + let num_calls = FFT_THRESHOLD / 2; + let mut g: Mul<TestField> = Mul::new(num_calls); + gadget_test(&mut g, num_calls); + + // Test the gadget with input polynomials longer than `FFT_THRESHOLD`. This exercises + // FFT-based polynomial multiplication. + let num_calls = FFT_THRESHOLD; + let mut g: Mul<TestField> = Mul::new(num_calls); + gadget_test(&mut g, num_calls); + } + + #[test] + fn test_poly_eval() { + let poly: Vec<TestField> = random_vector(10).unwrap(); + + let num_calls = FFT_THRESHOLD / 2; + let mut g: PolyEval<TestField> = PolyEval::new(poly.clone(), num_calls); + gadget_test(&mut g, num_calls); + + let num_calls = FFT_THRESHOLD; + let mut g: PolyEval<TestField> = PolyEval::new(poly, num_calls); + gadget_test(&mut g, num_calls); + } + + #[test] + fn test_parallel_sum() { + let num_calls = 10; + let chunks = 23; + + let mut g = ParallelSum::new(Mul::<TestField>::new(num_calls), chunks); + gadget_test(&mut g, num_calls); + } + + #[test] + #[cfg(feature = "multithreaded")] + fn test_parallel_sum_multithreaded() { + use std::iter; + + for num_calls in [1, 10, 100] { + let chunks = 23; + + let mut g = ParallelSumMultithreaded::new(Mul::new(num_calls), chunks); + gadget_test(&mut g, num_calls); + + // Test that the multithreaded version has the same output as the normal version. + let mut g_serial = ParallelSum::new(Mul::new(num_calls), chunks); + assert_eq!(g.arity(), g_serial.arity()); + assert_eq!(g.degree(), g_serial.degree()); + assert_eq!(g.calls(), g_serial.calls()); + + let arity = g.arity(); + let degree = g.degree(); + + // Test that both gadgets evaluate to the same value when run on scalar inputs. + let inp: Vec<TestField> = random_vector(arity).unwrap(); + let result = g.call(&inp).unwrap(); + let result_serial = g_serial.call(&inp).unwrap(); + assert_eq!(result, result_serial); + + // Test that both gadgets evaluate to the same value when run on polynomial inputs. + let mut poly_outp = + vec![TestField::zero(); (degree * num_calls + 1).next_power_of_two()]; + let mut poly_outp_serial = + vec![TestField::zero(); (degree * num_calls + 1).next_power_of_two()]; + let mut prng: Prng<TestField, _> = Prng::new().unwrap(); + let poly_inp: Vec<_> = iter::repeat_with(|| { + iter::repeat_with(|| prng.get()) + .take(1 + num_calls) + .collect::<Vec<_>>() + }) + .take(arity) + .collect(); + + g.call_poly(&mut poly_outp, &poly_inp).unwrap(); + g_serial + .call_poly(&mut poly_outp_serial, &poly_inp) + .unwrap(); + assert_eq!(poly_outp, poly_outp_serial); + } + } + + // Test that calling g.call_poly() and evaluating the output at a given point is equivalent + // to evaluating each of the inputs at the same point and applying g.call() on the results. + fn gadget_test<F: FftFriendlyFieldElement, G: Gadget<F>>(g: &mut G, num_calls: usize) { + let wire_poly_len = (1 + num_calls).next_power_of_two(); + let mut prng = Prng::new().unwrap(); + let mut inp = vec![F::zero(); g.arity()]; + let mut gadget_poly = vec![F::zero(); gadget_poly_fft_mem_len(g.degree(), num_calls)]; + let mut wire_polys = vec![vec![F::zero(); wire_poly_len]; g.arity()]; + + let r = prng.get(); + for i in 0..g.arity() { + for j in 0..wire_poly_len { + wire_polys[i][j] = prng.get(); + } + inp[i] = poly_eval(&wire_polys[i], r); + } + + g.call_poly(&mut gadget_poly, &wire_polys).unwrap(); + let got = poly_eval(&gadget_poly, r); + let want = g.call(&inp).unwrap(); + assert_eq!(got, want); + + // Repeat the call to make sure that the gadget's memory is reset properly between calls. + g.call_poly(&mut gadget_poly, &wire_polys).unwrap(); + let got = poly_eval(&gadget_poly, r); + assert_eq!(got, want); + } +} diff --git a/third_party/rust/prio/src/flp/types.rs b/third_party/rust/prio/src/flp/types.rs new file mode 100644 index 0000000000..18c290355c --- /dev/null +++ b/third_party/rust/prio/src/flp/types.rs @@ -0,0 +1,1415 @@ +// SPDX-License-Identifier: MPL-2.0 + +//! A collection of [`Type`] implementations. + +use crate::field::{FftFriendlyFieldElement, FieldElementWithIntegerExt}; +use crate::flp::gadgets::{Mul, ParallelSumGadget, PolyEval}; +use crate::flp::{FlpError, Gadget, Type}; +use crate::polynomial::poly_range_check; +use std::convert::TryInto; +use std::fmt::{self, Debug}; +use std::marker::PhantomData; +/// The counter data type. Each measurement is `0` or `1` and the aggregate result is the sum of the measurements (i.e., the total number of `1s`). +#[derive(Clone, PartialEq, Eq)] +pub struct Count<F> { + range_checker: Vec<F>, +} + +impl<F> Debug for Count<F> { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("Count").finish() + } +} + +impl<F: FftFriendlyFieldElement> Count<F> { + /// Return a new [`Count`] type instance. + pub fn new() -> Self { + Self { + range_checker: poly_range_check(0, 2), + } + } +} + +impl<F: FftFriendlyFieldElement> Default for Count<F> { + fn default() -> Self { + Self::new() + } +} + +impl<F: FftFriendlyFieldElement> Type for Count<F> { + const ID: u32 = 0x00000000; + type Measurement = F::Integer; + type AggregateResult = F::Integer; + type Field = F; + + fn encode_measurement(&self, value: &F::Integer) -> Result<Vec<F>, FlpError> { + let max = F::valid_integer_try_from(1)?; + if *value > max { + return Err(FlpError::Encode("Count value must be 0 or 1".to_string())); + } + + Ok(vec![F::from(*value)]) + } + + fn decode_result(&self, data: &[F], _num_measurements: usize) -> Result<F::Integer, FlpError> { + decode_result(data) + } + + fn gadget(&self) -> Vec<Box<dyn Gadget<F>>> { + vec![Box::new(Mul::new(1))] + } + + fn valid( + &self, + g: &mut Vec<Box<dyn Gadget<F>>>, + input: &[F], + joint_rand: &[F], + _num_shares: usize, + ) -> Result<F, FlpError> { + self.valid_call_check(input, joint_rand)?; + Ok(g[0].call(&[input[0], input[0]])? - input[0]) + } + + fn truncate(&self, input: Vec<F>) -> Result<Vec<F>, FlpError> { + self.truncate_call_check(&input)?; + Ok(input) + } + + fn input_len(&self) -> usize { + 1 + } + + fn proof_len(&self) -> usize { + 5 + } + + fn verifier_len(&self) -> usize { + 4 + } + + fn output_len(&self) -> usize { + self.input_len() + } + + fn joint_rand_len(&self) -> usize { + 0 + } + + fn prove_rand_len(&self) -> usize { + 2 + } + + fn query_rand_len(&self) -> usize { + 1 + } +} + +/// This sum type. Each measurement is a integer in `[0, 2^bits)` and the aggregate is the sum of +/// the measurements. +/// +/// The validity circuit is based on the SIMD circuit construction of [[BBCG+19], Theorem 5.3]. +/// +/// [BBCG+19]: https://ia.cr/2019/188 +#[derive(Clone, PartialEq, Eq)] +pub struct Sum<F: FftFriendlyFieldElement> { + bits: usize, + range_checker: Vec<F>, +} + +impl<F: FftFriendlyFieldElement> Debug for Sum<F> { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("Sum").field("bits", &self.bits).finish() + } +} + +impl<F: FftFriendlyFieldElement> Sum<F> { + /// Return a new [`Sum`] type parameter. Each value of this type is an integer in range `[0, + /// 2^bits)`. + pub fn new(bits: usize) -> Result<Self, FlpError> { + if !F::valid_integer_bitlength(bits) { + return Err(FlpError::Encode( + "invalid bits: number of bits exceeds maximum number of bits in this field" + .to_string(), + )); + } + Ok(Self { + bits, + range_checker: poly_range_check(0, 2), + }) + } +} + +impl<F: FftFriendlyFieldElement> Type for Sum<F> { + const ID: u32 = 0x00000001; + type Measurement = F::Integer; + type AggregateResult = F::Integer; + type Field = F; + + fn encode_measurement(&self, summand: &F::Integer) -> Result<Vec<F>, FlpError> { + let v = F::encode_into_bitvector_representation(summand, self.bits)?; + Ok(v) + } + + fn decode_result(&self, data: &[F], _num_measurements: usize) -> Result<F::Integer, FlpError> { + decode_result(data) + } + + fn gadget(&self) -> Vec<Box<dyn Gadget<F>>> { + vec![Box::new(PolyEval::new( + self.range_checker.clone(), + self.bits, + ))] + } + + fn valid( + &self, + g: &mut Vec<Box<dyn Gadget<F>>>, + input: &[F], + joint_rand: &[F], + _num_shares: usize, + ) -> Result<F, FlpError> { + self.valid_call_check(input, joint_rand)?; + call_gadget_on_vec_entries(&mut g[0], input, joint_rand[0]) + } + + fn truncate(&self, input: Vec<F>) -> Result<Vec<F>, FlpError> { + self.truncate_call_check(&input)?; + let res = F::decode_from_bitvector_representation(&input)?; + Ok(vec![res]) + } + + fn input_len(&self) -> usize { + self.bits + } + + fn proof_len(&self) -> usize { + 2 * ((1 + self.bits).next_power_of_two() - 1) + 2 + } + + fn verifier_len(&self) -> usize { + 3 + } + + fn output_len(&self) -> usize { + 1 + } + + fn joint_rand_len(&self) -> usize { + 1 + } + + fn prove_rand_len(&self) -> usize { + 1 + } + + fn query_rand_len(&self) -> usize { + 1 + } +} + +/// The average type. Each measurement is an integer in `[0,2^bits)` for some `0 < bits < 64` and the +/// aggregate is the arithmetic average. +#[derive(Clone, PartialEq, Eq)] +pub struct Average<F: FftFriendlyFieldElement> { + bits: usize, + range_checker: Vec<F>, +} + +impl<F: FftFriendlyFieldElement> Debug for Average<F> { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("Average").field("bits", &self.bits).finish() + } +} + +impl<F: FftFriendlyFieldElement> Average<F> { + /// Return a new [`Average`] type parameter. Each value of this type is an integer in range `[0, + /// 2^bits)`. + pub fn new(bits: usize) -> Result<Self, FlpError> { + if !F::valid_integer_bitlength(bits) { + return Err(FlpError::Encode( + "invalid bits: number of bits exceeds maximum number of bits in this field" + .to_string(), + )); + } + Ok(Self { + bits, + range_checker: poly_range_check(0, 2), + }) + } +} + +impl<F: FftFriendlyFieldElement> Type for Average<F> { + const ID: u32 = 0xFFFF0000; + type Measurement = F::Integer; + type AggregateResult = f64; + type Field = F; + + fn encode_measurement(&self, summand: &F::Integer) -> Result<Vec<F>, FlpError> { + let v = F::encode_into_bitvector_representation(summand, self.bits)?; + Ok(v) + } + + fn decode_result(&self, data: &[F], num_measurements: usize) -> Result<f64, FlpError> { + // Compute the average from the aggregated sum. + let data = decode_result(data)?; + let data: u64 = data.try_into().map_err(|err| { + FlpError::Decode(format!("failed to convert {data:?} to u64: {err}",)) + })?; + let result = (data as f64) / (num_measurements as f64); + Ok(result) + } + + fn gadget(&self) -> Vec<Box<dyn Gadget<F>>> { + vec![Box::new(PolyEval::new( + self.range_checker.clone(), + self.bits, + ))] + } + + fn valid( + &self, + g: &mut Vec<Box<dyn Gadget<F>>>, + input: &[F], + joint_rand: &[F], + _num_shares: usize, + ) -> Result<F, FlpError> { + self.valid_call_check(input, joint_rand)?; + call_gadget_on_vec_entries(&mut g[0], input, joint_rand[0]) + } + + fn truncate(&self, input: Vec<F>) -> Result<Vec<F>, FlpError> { + self.truncate_call_check(&input)?; + let res = F::decode_from_bitvector_representation(&input)?; + Ok(vec![res]) + } + + fn input_len(&self) -> usize { + self.bits + } + + fn proof_len(&self) -> usize { + 2 * ((1 + self.bits).next_power_of_two() - 1) + 2 + } + + fn verifier_len(&self) -> usize { + 3 + } + + fn output_len(&self) -> usize { + 1 + } + + fn joint_rand_len(&self) -> usize { + 1 + } + + fn prove_rand_len(&self) -> usize { + 1 + } + + fn query_rand_len(&self) -> usize { + 1 + } +} + +/// The histogram type. Each measurement is an integer in `[0, length)` and the aggregate is a +/// histogram counting the number of occurrences of each measurement. +#[derive(PartialEq, Eq)] +pub struct Histogram<F, S> { + length: usize, + chunk_length: usize, + gadget_calls: usize, + phantom: PhantomData<(F, S)>, +} + +impl<F: FftFriendlyFieldElement, S> Debug for Histogram<F, S> { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("Histogram") + .field("length", &self.length) + .field("chunk_length", &self.chunk_length) + .finish() + } +} + +impl<F: FftFriendlyFieldElement, S: ParallelSumGadget<F, Mul<F>>> Histogram<F, S> { + /// Return a new [`Histogram`] type with the given number of buckets. + pub fn new(length: usize, chunk_length: usize) -> Result<Self, FlpError> { + if length >= u32::MAX as usize { + return Err(FlpError::Encode( + "invalid length: number of buckets exceeds maximum permitted".to_string(), + )); + } + if length == 0 { + return Err(FlpError::InvalidParameter( + "length cannot be zero".to_string(), + )); + } + if chunk_length == 0 { + return Err(FlpError::InvalidParameter( + "chunk_length cannot be zero".to_string(), + )); + } + + let mut gadget_calls = length / chunk_length; + if length % chunk_length != 0 { + gadget_calls += 1; + } + + Ok(Self { + length, + chunk_length, + gadget_calls, + phantom: PhantomData, + }) + } +} + +impl<F, S> Clone for Histogram<F, S> { + fn clone(&self) -> Self { + Self { + length: self.length, + chunk_length: self.chunk_length, + gadget_calls: self.gadget_calls, + phantom: self.phantom, + } + } +} + +impl<F, S> Type for Histogram<F, S> +where + F: FftFriendlyFieldElement, + S: ParallelSumGadget<F, Mul<F>> + Eq + 'static, +{ + const ID: u32 = 0x00000003; + type Measurement = usize; + type AggregateResult = Vec<F::Integer>; + type Field = F; + + fn encode_measurement(&self, measurement: &usize) -> Result<Vec<F>, FlpError> { + let mut data = vec![F::zero(); self.length]; + + data[*measurement] = F::one(); + Ok(data) + } + + fn decode_result( + &self, + data: &[F], + _num_measurements: usize, + ) -> Result<Vec<F::Integer>, FlpError> { + decode_result_vec(data, self.length) + } + + fn gadget(&self) -> Vec<Box<dyn Gadget<F>>> { + vec![Box::new(S::new( + Mul::new(self.gadget_calls), + self.chunk_length, + ))] + } + + fn valid( + &self, + g: &mut Vec<Box<dyn Gadget<F>>>, + input: &[F], + joint_rand: &[F], + num_shares: usize, + ) -> Result<F, FlpError> { + self.valid_call_check(input, joint_rand)?; + + // Check that each element of `input` is a 0 or 1. + let range_check = parallel_sum_range_checks( + &mut g[0], + input, + joint_rand[0], + self.chunk_length, + num_shares, + )?; + + // Check that the elements of `input` sum to 1. + let mut sum_check = -(F::one() / F::from(F::valid_integer_try_from(num_shares)?)); + for val in input.iter() { + sum_check += *val; + } + + // Take a random linear combination of both checks. + let out = joint_rand[1] * range_check + (joint_rand[1] * joint_rand[1]) * sum_check; + Ok(out) + } + + fn truncate(&self, input: Vec<F>) -> Result<Vec<F>, FlpError> { + self.truncate_call_check(&input)?; + Ok(input) + } + + fn input_len(&self) -> usize { + self.length + } + + fn proof_len(&self) -> usize { + (self.chunk_length * 2) + 2 * ((1 + self.gadget_calls).next_power_of_two() - 1) + 1 + } + + fn verifier_len(&self) -> usize { + 2 + self.chunk_length * 2 + } + + fn output_len(&self) -> usize { + self.input_len() + } + + fn joint_rand_len(&self) -> usize { + 2 + } + + fn prove_rand_len(&self) -> usize { + self.chunk_length * 2 + } + + fn query_rand_len(&self) -> usize { + 1 + } +} + +/// A sequence of integers in range `[0, 2^bits)`. This type uses a neat trick from [[BBCG+19], +/// Corollary 4.9] to reduce the proof size to roughly the square root of the input size. +/// +/// [BBCG+19]: https://eprint.iacr.org/2019/188 +#[derive(PartialEq, Eq)] +pub struct SumVec<F: FftFriendlyFieldElement, S> { + len: usize, + bits: usize, + flattened_len: usize, + max: F::Integer, + chunk_length: usize, + gadget_calls: usize, + phantom: PhantomData<S>, +} + +impl<F: FftFriendlyFieldElement, S> Debug for SumVec<F, S> { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("SumVec") + .field("len", &self.len) + .field("bits", &self.bits) + .field("chunk_length", &self.chunk_length) + .finish() + } +} + +impl<F: FftFriendlyFieldElement, S: ParallelSumGadget<F, Mul<F>>> SumVec<F, S> { + /// Returns a new [`SumVec`] with the desired bit width and vector length. + /// + /// # Errors + /// + /// * The length of the encoded measurement, i.e., `bits * len`, overflows addressable memory. + /// * The bit width cannot be encoded, i.e., `bits` is larger than or equal to the number of + /// bits required to encode field elements. + /// * Any of `bits`, `len`, or `chunk_length` are zero. + pub fn new(bits: usize, len: usize, chunk_length: usize) -> Result<Self, FlpError> { + let flattened_len = bits.checked_mul(len).ok_or_else(|| { + FlpError::InvalidParameter("`bits*len` overflows addressable memory".into()) + })?; + + // Check if the bit width is too large. This limit is defined to be one bit less than the + // number of bits required to encode `F::Integer`. (One less so that we can compute `1 << + // bits` without overflowing.) + let limit = std::mem::size_of::<F::Integer>() * 8 - 1; + if bits > limit { + return Err(FlpError::InvalidParameter(format!( + "bit wdith exceeds limit of {limit}" + ))); + } + + // Check for degenerate parameters. + if bits == 0 { + return Err(FlpError::InvalidParameter( + "bits cannot be zero".to_string(), + )); + } + if len == 0 { + return Err(FlpError::InvalidParameter("len cannot be zero".to_string())); + } + if chunk_length == 0 { + return Err(FlpError::InvalidParameter( + "chunk_length cannot be zero".to_string(), + )); + } + + // Compute the largest encodable measurement. + let one = F::Integer::from(F::one()); + let max = (one << bits) - one; + + let mut gadget_calls = flattened_len / chunk_length; + if flattened_len % chunk_length != 0 { + gadget_calls += 1; + } + + Ok(Self { + len, + bits, + flattened_len, + max, + chunk_length, + gadget_calls, + phantom: PhantomData, + }) + } +} + +impl<F: FftFriendlyFieldElement, S> Clone for SumVec<F, S> { + fn clone(&self) -> Self { + Self { + len: self.len, + bits: self.bits, + flattened_len: self.flattened_len, + max: self.max, + chunk_length: self.chunk_length, + gadget_calls: self.gadget_calls, + phantom: PhantomData, + } + } +} + +impl<F, S> Type for SumVec<F, S> +where + F: FftFriendlyFieldElement, + S: ParallelSumGadget<F, Mul<F>> + Eq + 'static, +{ + const ID: u32 = 0x00000002; + type Measurement = Vec<F::Integer>; + type AggregateResult = Vec<F::Integer>; + type Field = F; + + fn encode_measurement(&self, measurement: &Vec<F::Integer>) -> Result<Vec<F>, FlpError> { + if measurement.len() != self.len { + return Err(FlpError::Encode(format!( + "unexpected measurement length: got {}; want {}", + measurement.len(), + self.len + ))); + } + + let mut flattened = vec![F::zero(); self.flattened_len]; + for (summand, chunk) in measurement + .iter() + .zip(flattened.chunks_exact_mut(self.bits)) + { + if summand > &self.max { + return Err(FlpError::Encode(format!( + "summand exceeds maximum of 2^{}-1", + self.bits + ))); + } + F::fill_with_bitvector_representation(summand, chunk)?; + } + + Ok(flattened) + } + + fn decode_result( + &self, + data: &[F], + _num_measurements: usize, + ) -> Result<Vec<F::Integer>, FlpError> { + decode_result_vec(data, self.len) + } + + fn gadget(&self) -> Vec<Box<dyn Gadget<F>>> { + vec![Box::new(S::new( + Mul::new(self.gadget_calls), + self.chunk_length, + ))] + } + + fn valid( + &self, + g: &mut Vec<Box<dyn Gadget<F>>>, + input: &[F], + joint_rand: &[F], + num_shares: usize, + ) -> Result<F, FlpError> { + self.valid_call_check(input, joint_rand)?; + + parallel_sum_range_checks( + &mut g[0], + input, + joint_rand[0], + self.chunk_length, + num_shares, + ) + } + + fn truncate(&self, input: Vec<F>) -> Result<Vec<F>, FlpError> { + self.truncate_call_check(&input)?; + let mut unflattened = Vec::with_capacity(self.len); + for chunk in input.chunks(self.bits) { + unflattened.push(F::decode_from_bitvector_representation(chunk)?); + } + Ok(unflattened) + } + + fn input_len(&self) -> usize { + self.flattened_len + } + + fn proof_len(&self) -> usize { + (self.chunk_length * 2) + 2 * ((1 + self.gadget_calls).next_power_of_two() - 1) + 1 + } + + fn verifier_len(&self) -> usize { + 2 + self.chunk_length * 2 + } + + fn output_len(&self) -> usize { + self.len + } + + fn joint_rand_len(&self) -> usize { + 1 + } + + fn prove_rand_len(&self) -> usize { + self.chunk_length * 2 + } + + fn query_rand_len(&self) -> usize { + 1 + } +} + +/// Compute a random linear combination of the result of calls of `g` on each element of `input`. +/// +/// # Arguments +/// +/// * `g` - The gadget to be applied elementwise +/// * `input` - The vector on whose elements to apply `g` +/// * `rnd` - The randomness used for the linear combination +pub(crate) fn call_gadget_on_vec_entries<F: FftFriendlyFieldElement>( + g: &mut Box<dyn Gadget<F>>, + input: &[F], + rnd: F, +) -> Result<F, FlpError> { + let mut range_check = F::zero(); + let mut r = rnd; + for chunk in input.chunks(1) { + range_check += r * g.call(chunk)?; + r *= rnd; + } + Ok(range_check) +} + +/// Given a vector `data` of field elements which should contain exactly one entry, return the +/// integer representation of that entry. +pub(crate) fn decode_result<F: FftFriendlyFieldElement>( + data: &[F], +) -> Result<F::Integer, FlpError> { + if data.len() != 1 { + return Err(FlpError::Decode("unexpected input length".into())); + } + Ok(F::Integer::from(data[0])) +} + +/// Given a vector `data` of field elements, return a vector containing the corresponding integer +/// representations, if the number of entries matches `expected_len`. +pub(crate) fn decode_result_vec<F: FftFriendlyFieldElement>( + data: &[F], + expected_len: usize, +) -> Result<Vec<F::Integer>, FlpError> { + if data.len() != expected_len { + return Err(FlpError::Decode("unexpected input length".into())); + } + Ok(data.iter().map(|elem| F::Integer::from(*elem)).collect()) +} + +/// This evaluates range checks on a slice of field elements, using a ParallelSum gadget evaluating +/// many multiplication gates. +/// +/// # Arguments +/// +/// * `gadget`: A `ParallelSumGadget<F, Mul<F>>` gadget, or a shim wrapping the same. +/// * `input`: A slice of inputs. This calculation will check that all inputs were zero or one +/// before secret sharing. +/// * `joint_randomness`: A joint randomness value, used to compute a random linear combination of +/// individual range checks. +/// * `chunk_length`: How many multiplication gates per ParallelSum gadget. This must match what the +/// gadget was constructed with. +/// * `num_shares`: The number of shares that the inputs were secret shared into. This is needed to +/// correct constant terms in the circuit. +/// +/// # Returns +/// +/// This returns (additive shares of) zero if all inputs were zero or one, and otherwise returns a +/// non-zero value with high probability. +pub(crate) fn parallel_sum_range_checks<F: FftFriendlyFieldElement>( + gadget: &mut Box<dyn Gadget<F>>, + input: &[F], + joint_randomness: F, + chunk_length: usize, + num_shares: usize, +) -> Result<F, FlpError> { + let f_num_shares = F::from(F::valid_integer_try_from::<usize>(num_shares)?); + let num_shares_inverse = f_num_shares.inv(); + + let mut output = F::zero(); + let mut r_power = joint_randomness; + let mut padded_chunk = vec![F::zero(); 2 * chunk_length]; + + for chunk in input.chunks(chunk_length) { + // Construct arguments for the Mul subcircuits. + for (input, args) in chunk.iter().zip(padded_chunk.chunks_exact_mut(2)) { + args[0] = r_power * *input; + args[1] = *input - num_shares_inverse; + r_power *= joint_randomness; + } + // If the chunk of the input is smaller than chunk_length, use zeros instead of measurement + // inputs for the remaining calls. + for args in padded_chunk[chunk.len() * 2..].chunks_exact_mut(2) { + args[0] = F::zero(); + args[1] = -num_shares_inverse; + // Skip updating r_power. This inner loop is only used during the last iteration of the + // outer loop, if the last input chunk is a partial chunk. Thus, r_power won't be + // accessed again before returning. + } + + output += gadget.call(&padded_chunk)?; + } + + Ok(output) +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::field::{random_vector, Field64 as TestField, FieldElement}; + use crate::flp::gadgets::ParallelSum; + #[cfg(feature = "multithreaded")] + use crate::flp::gadgets::ParallelSumMultithreaded; + use crate::flp::types::test_utils::{flp_validity_test, ValidityTestCase}; + use std::cmp; + + #[test] + fn test_count() { + let count: Count<TestField> = Count::new(); + let zero = TestField::zero(); + let one = TestField::one(); + + // Round trip + assert_eq!( + count + .decode_result( + &count + .truncate(count.encode_measurement(&1).unwrap()) + .unwrap(), + 1 + ) + .unwrap(), + 1, + ); + + // Test FLP on valid input. + flp_validity_test( + &count, + &count.encode_measurement(&1).unwrap(), + &ValidityTestCase::<TestField> { + expect_valid: true, + expected_output: Some(vec![one]), + num_shares: 3, + }, + ) + .unwrap(); + + flp_validity_test( + &count, + &count.encode_measurement(&0).unwrap(), + &ValidityTestCase::<TestField> { + expect_valid: true, + expected_output: Some(vec![zero]), + num_shares: 3, + }, + ) + .unwrap(); + + // Test FLP on invalid input. + flp_validity_test( + &count, + &[TestField::from(1337)], + &ValidityTestCase::<TestField> { + expect_valid: false, + expected_output: None, + num_shares: 3, + }, + ) + .unwrap(); + + // Try running the validity circuit on an input that's too short. + count.valid(&mut count.gadget(), &[], &[], 1).unwrap_err(); + count + .valid(&mut count.gadget(), &[1.into(), 2.into()], &[], 1) + .unwrap_err(); + } + + #[test] + fn test_sum() { + let sum = Sum::new(11).unwrap(); + let zero = TestField::zero(); + let one = TestField::one(); + let nine = TestField::from(9); + + // Round trip + assert_eq!( + sum.decode_result( + &sum.truncate(sum.encode_measurement(&27).unwrap()).unwrap(), + 1 + ) + .unwrap(), + 27, + ); + + // Test FLP on valid input. + flp_validity_test( + &sum, + &sum.encode_measurement(&1337).unwrap(), + &ValidityTestCase { + expect_valid: true, + expected_output: Some(vec![TestField::from(1337)]), + num_shares: 3, + }, + ) + .unwrap(); + + flp_validity_test( + &Sum::new(0).unwrap(), + &[], + &ValidityTestCase::<TestField> { + expect_valid: true, + expected_output: Some(vec![zero]), + num_shares: 3, + }, + ) + .unwrap(); + + flp_validity_test( + &Sum::new(2).unwrap(), + &[one, zero], + &ValidityTestCase { + expect_valid: true, + expected_output: Some(vec![one]), + num_shares: 3, + }, + ) + .unwrap(); + + flp_validity_test( + &Sum::new(9).unwrap(), + &[one, zero, one, one, zero, one, one, one, zero], + &ValidityTestCase::<TestField> { + expect_valid: true, + expected_output: Some(vec![TestField::from(237)]), + num_shares: 3, + }, + ) + .unwrap(); + + // Test FLP on invalid input. + flp_validity_test( + &Sum::new(3).unwrap(), + &[one, nine, zero], + &ValidityTestCase::<TestField> { + expect_valid: false, + expected_output: None, + num_shares: 3, + }, + ) + .unwrap(); + + flp_validity_test( + &Sum::new(5).unwrap(), + &[zero, zero, zero, zero, nine], + &ValidityTestCase::<TestField> { + expect_valid: false, + expected_output: None, + num_shares: 3, + }, + ) + .unwrap(); + } + + #[test] + fn test_average() { + let average = Average::new(11).unwrap(); + let zero = TestField::zero(); + let one = TestField::one(); + let ten = TestField::from(10); + + // Testing that average correctly quotients the sum of the measurements + // by the number of measurements. + assert_eq!(average.decode_result(&[zero], 1).unwrap(), 0.0); + assert_eq!(average.decode_result(&[one], 1).unwrap(), 1.0); + assert_eq!(average.decode_result(&[one], 2).unwrap(), 0.5); + assert_eq!(average.decode_result(&[one], 4).unwrap(), 0.25); + assert_eq!(average.decode_result(&[ten], 8).unwrap(), 1.25); + + // round trip of 12 with `num_measurements`=1 + assert_eq!( + average + .decode_result( + &average + .truncate(average.encode_measurement(&12).unwrap()) + .unwrap(), + 1 + ) + .unwrap(), + 12.0 + ); + + // round trip of 12 with `num_measurements`=24 + assert_eq!( + average + .decode_result( + &average + .truncate(average.encode_measurement(&12).unwrap()) + .unwrap(), + 24 + ) + .unwrap(), + 0.5 + ); + } + + fn test_histogram<F, S>(f: F) + where + F: Fn(usize, usize) -> Result<Histogram<TestField, S>, FlpError>, + S: ParallelSumGadget<TestField, Mul<TestField>> + Eq + 'static, + { + let hist = f(3, 2).unwrap(); + let zero = TestField::zero(); + let one = TestField::one(); + let nine = TestField::from(9); + + assert_eq!(&hist.encode_measurement(&0).unwrap(), &[one, zero, zero]); + assert_eq!(&hist.encode_measurement(&1).unwrap(), &[zero, one, zero]); + assert_eq!(&hist.encode_measurement(&2).unwrap(), &[zero, zero, one]); + + // Round trip + assert_eq!( + hist.decode_result( + &hist.truncate(hist.encode_measurement(&2).unwrap()).unwrap(), + 1 + ) + .unwrap(), + [0, 0, 1] + ); + + // Test valid inputs. + flp_validity_test( + &hist, + &hist.encode_measurement(&0).unwrap(), + &ValidityTestCase::<TestField> { + expect_valid: true, + expected_output: Some(vec![one, zero, zero]), + num_shares: 3, + }, + ) + .unwrap(); + + flp_validity_test( + &hist, + &hist.encode_measurement(&1).unwrap(), + &ValidityTestCase::<TestField> { + expect_valid: true, + expected_output: Some(vec![zero, one, zero]), + num_shares: 3, + }, + ) + .unwrap(); + + flp_validity_test( + &hist, + &hist.encode_measurement(&2).unwrap(), + &ValidityTestCase::<TestField> { + expect_valid: true, + expected_output: Some(vec![zero, zero, one]), + num_shares: 3, + }, + ) + .unwrap(); + + // Test invalid inputs. + flp_validity_test( + &hist, + &[zero, zero, nine], + &ValidityTestCase::<TestField> { + expect_valid: false, + expected_output: None, + num_shares: 3, + }, + ) + .unwrap(); + + flp_validity_test( + &hist, + &[zero, one, one], + &ValidityTestCase::<TestField> { + expect_valid: false, + expected_output: None, + num_shares: 3, + }, + ) + .unwrap(); + + flp_validity_test( + &hist, + &[one, one, one], + &ValidityTestCase::<TestField> { + expect_valid: false, + expected_output: None, + num_shares: 3, + }, + ) + .unwrap(); + + flp_validity_test( + &hist, + &[zero, zero, zero], + &ValidityTestCase::<TestField> { + expect_valid: false, + expected_output: None, + num_shares: 3, + }, + ) + .unwrap(); + } + + #[test] + fn test_histogram_serial() { + test_histogram(Histogram::<TestField, ParallelSum<TestField, Mul<TestField>>>::new); + } + + #[test] + #[cfg(feature = "multithreaded")] + fn test_histogram_parallel() { + test_histogram( + Histogram::<TestField, ParallelSumMultithreaded<TestField, Mul<TestField>>>::new, + ); + } + + fn test_sum_vec<F, S>(f: F) + where + F: Fn(usize, usize, usize) -> Result<SumVec<TestField, S>, FlpError>, + S: 'static + ParallelSumGadget<TestField, Mul<TestField>> + Eq, + { + let one = TestField::one(); + let nine = TestField::from(9); + + // Test on valid inputs. + for len in 1..10 { + let chunk_length = cmp::max((len as f64).sqrt() as usize, 1); + let sum_vec = f(1, len, chunk_length).unwrap(); + flp_validity_test( + &sum_vec, + &sum_vec.encode_measurement(&vec![1; len]).unwrap(), + &ValidityTestCase::<TestField> { + expect_valid: true, + expected_output: Some(vec![one; len]), + num_shares: 3, + }, + ) + .unwrap(); + } + + let len = 100; + let sum_vec = f(1, len, 10).unwrap(); + flp_validity_test( + &sum_vec, + &sum_vec.encode_measurement(&vec![1; len]).unwrap(), + &ValidityTestCase::<TestField> { + expect_valid: true, + expected_output: Some(vec![one; len]), + num_shares: 3, + }, + ) + .unwrap(); + + let len = 23; + let sum_vec = f(4, len, 4).unwrap(); + flp_validity_test( + &sum_vec, + &sum_vec.encode_measurement(&vec![9; len]).unwrap(), + &ValidityTestCase::<TestField> { + expect_valid: true, + expected_output: Some(vec![nine; len]), + num_shares: 3, + }, + ) + .unwrap(); + + // Test on invalid inputs. + for len in 1..10 { + let chunk_length = cmp::max((len as f64).sqrt() as usize, 1); + let sum_vec = f(1, len, chunk_length).unwrap(); + flp_validity_test( + &sum_vec, + &vec![nine; len], + &ValidityTestCase::<TestField> { + expect_valid: false, + expected_output: None, + num_shares: 3, + }, + ) + .unwrap(); + } + + let len = 23; + let sum_vec = f(2, len, 4).unwrap(); + flp_validity_test( + &sum_vec, + &vec![nine; 2 * len], + &ValidityTestCase::<TestField> { + expect_valid: false, + expected_output: None, + num_shares: 3, + }, + ) + .unwrap(); + + // Round trip + let want = vec![1; len]; + assert_eq!( + sum_vec + .decode_result( + &sum_vec + .truncate(sum_vec.encode_measurement(&want).unwrap()) + .unwrap(), + 1 + ) + .unwrap(), + want + ); + } + + #[test] + fn test_sum_vec_serial() { + test_sum_vec(SumVec::<TestField, ParallelSum<TestField, Mul<TestField>>>::new) + } + + #[test] + #[cfg(feature = "multithreaded")] + fn test_sum_vec_parallel() { + test_sum_vec(SumVec::<TestField, ParallelSumMultithreaded<TestField, Mul<TestField>>>::new) + } + + #[test] + fn sum_vec_serial_long() { + let typ: SumVec<TestField, ParallelSum<TestField, _>> = SumVec::new(1, 1000, 31).unwrap(); + let input = typ.encode_measurement(&vec![0; 1000]).unwrap(); + assert_eq!(input.len(), typ.input_len()); + let joint_rand = random_vector(typ.joint_rand_len()).unwrap(); + let prove_rand = random_vector(typ.prove_rand_len()).unwrap(); + let query_rand = random_vector(typ.query_rand_len()).unwrap(); + let proof = typ.prove(&input, &prove_rand, &joint_rand).unwrap(); + let verifier = typ + .query(&input, &proof, &query_rand, &joint_rand, 1) + .unwrap(); + assert_eq!(verifier.len(), typ.verifier_len()); + assert!(typ.decide(&verifier).unwrap()); + } + + #[test] + #[cfg(feature = "multithreaded")] + fn sum_vec_parallel_long() { + let typ: SumVec<TestField, ParallelSumMultithreaded<TestField, _>> = + SumVec::new(1, 1000, 31).unwrap(); + let input = typ.encode_measurement(&vec![0; 1000]).unwrap(); + assert_eq!(input.len(), typ.input_len()); + let joint_rand = random_vector(typ.joint_rand_len()).unwrap(); + let prove_rand = random_vector(typ.prove_rand_len()).unwrap(); + let query_rand = random_vector(typ.query_rand_len()).unwrap(); + let proof = typ.prove(&input, &prove_rand, &joint_rand).unwrap(); + let verifier = typ + .query(&input, &proof, &query_rand, &joint_rand, 1) + .unwrap(); + assert_eq!(verifier.len(), typ.verifier_len()); + assert!(typ.decide(&verifier).unwrap()); + } +} + +#[cfg(test)] +mod test_utils { + use super::*; + use crate::field::{random_vector, split_vector, FieldElement}; + + pub(crate) struct ValidityTestCase<F> { + pub(crate) expect_valid: bool, + pub(crate) expected_output: Option<Vec<F>>, + // Number of shares to split input and proofs into in `flp_test`. + pub(crate) num_shares: usize, + } + + pub(crate) fn flp_validity_test<T: Type>( + typ: &T, + input: &[T::Field], + t: &ValidityTestCase<T::Field>, + ) -> Result<(), FlpError> { + let mut gadgets = typ.gadget(); + + if input.len() != typ.input_len() { + return Err(FlpError::Test(format!( + "unexpected input length: got {}; want {}", + input.len(), + typ.input_len() + ))); + } + + if typ.query_rand_len() != gadgets.len() { + return Err(FlpError::Test(format!( + "query rand length: got {}; want {}", + typ.query_rand_len(), + gadgets.len() + ))); + } + + let joint_rand = random_vector(typ.joint_rand_len()).unwrap(); + let prove_rand = random_vector(typ.prove_rand_len()).unwrap(); + let query_rand = random_vector(typ.query_rand_len()).unwrap(); + + // Run the validity circuit. + let v = typ.valid(&mut gadgets, input, &joint_rand, 1)?; + if v != T::Field::zero() && t.expect_valid { + return Err(FlpError::Test(format!( + "expected valid input: valid() returned {v}" + ))); + } + if v == T::Field::zero() && !t.expect_valid { + return Err(FlpError::Test(format!( + "expected invalid input: valid() returned {v}" + ))); + } + + // Generate the proof. + let proof = typ.prove(input, &prove_rand, &joint_rand)?; + if proof.len() != typ.proof_len() { + return Err(FlpError::Test(format!( + "unexpected proof length: got {}; want {}", + proof.len(), + typ.proof_len() + ))); + } + + // Query the proof. + let verifier = typ.query(input, &proof, &query_rand, &joint_rand, 1)?; + if verifier.len() != typ.verifier_len() { + return Err(FlpError::Test(format!( + "unexpected verifier length: got {}; want {}", + verifier.len(), + typ.verifier_len() + ))); + } + + // Decide if the input is valid. + let res = typ.decide(&verifier)?; + if res != t.expect_valid { + return Err(FlpError::Test(format!( + "decision is {}; want {}", + res, t.expect_valid, + ))); + } + + // Run distributed FLP. + let input_shares: Vec<Vec<T::Field>> = split_vector(input, t.num_shares) + .unwrap() + .into_iter() + .collect(); + + let proof_shares: Vec<Vec<T::Field>> = split_vector(&proof, t.num_shares) + .unwrap() + .into_iter() + .collect(); + + let verifier: Vec<T::Field> = (0..t.num_shares) + .map(|i| { + typ.query( + &input_shares[i], + &proof_shares[i], + &query_rand, + &joint_rand, + t.num_shares, + ) + .unwrap() + }) + .reduce(|mut left, right| { + for (x, y) in left.iter_mut().zip(right.iter()) { + *x += *y; + } + left + }) + .unwrap(); + + let res = typ.decide(&verifier)?; + if res != t.expect_valid { + return Err(FlpError::Test(format!( + "distributed decision is {}; want {}", + res, t.expect_valid, + ))); + } + + // Try verifying various proof mutants. + for i in 0..proof.len() { + let mut mutated_proof = proof.clone(); + mutated_proof[i] += T::Field::one(); + let verifier = typ.query(input, &mutated_proof, &query_rand, &joint_rand, 1)?; + if typ.decide(&verifier)? { + return Err(FlpError::Test(format!( + "decision for proof mutant {} is {}; want {}", + i, true, false, + ))); + } + } + + // Try verifying a proof that is too short. + let mut mutated_proof = proof.clone(); + mutated_proof.truncate(gadgets[0].arity() - 1); + if typ + .query(input, &mutated_proof, &query_rand, &joint_rand, 1) + .is_ok() + { + return Err(FlpError::Test( + "query on short proof succeeded; want failure".to_string(), + )); + } + + // Try verifying a proof that is too long. + let mut mutated_proof = proof; + mutated_proof.extend_from_slice(&[T::Field::one(); 17]); + if typ + .query(input, &mutated_proof, &query_rand, &joint_rand, 1) + .is_ok() + { + return Err(FlpError::Test( + "query on long proof succeeded; want failure".to_string(), + )); + } + + if let Some(ref want) = t.expected_output { + let got = typ.truncate(input.to_vec())?; + + if got.len() != typ.output_len() { + return Err(FlpError::Test(format!( + "unexpected output length: got {}; want {}", + got.len(), + typ.output_len() + ))); + } + + if &got != want { + return Err(FlpError::Test(format!( + "unexpected output: got {got:?}; want {want:?}" + ))); + } + } + + Ok(()) + } +} + +#[cfg(feature = "experimental")] +#[cfg_attr(docsrs, doc(cfg(feature = "experimental")))] +pub mod fixedpoint_l2; diff --git a/third_party/rust/prio/src/flp/types/fixedpoint_l2.rs b/third_party/rust/prio/src/flp/types/fixedpoint_l2.rs new file mode 100644 index 0000000000..b5aa2fd116 --- /dev/null +++ b/third_party/rust/prio/src/flp/types/fixedpoint_l2.rs @@ -0,0 +1,899 @@ +// SPDX-License-Identifier: MPL-2.0 + +//! A [`Type`] for summing vectors of fixed point numbers where the +//! [L2 norm](https://en.wikipedia.org/wiki/Norm_(mathematics)#Euclidean_norm) +//! of each vector is bounded by `1` and adding [discrete Gaussian +//! noise](https://arxiv.org/abs/2004.00010) in order to achieve server +//! differential privacy. +//! +//! In the following a high level overview over the inner workings of this type +//! is given and implementation details are discussed. It is not necessary for +//! using the type, but it should be helpful when trying to understand the +//! implementation. +//! +//! ### Overview +//! +//! Clients submit a vector of numbers whose values semantically lie in `[-1,1)`, +//! together with a norm in the range `[0,1)`. The validation circuit checks that +//! the norm of the vector is equal to the submitted norm, while the encoding +//! guarantees that the submitted norm lies in the correct range. +//! +//! The bound on the L2 norm allows calibration of discrete Gaussian noise added +//! after aggregation, making the procedure differentially private. +//! +//! ### Submission layout +//! +//! The client submissions contain a share of their vector and the norm +//! they claim it has. +//! The submission is a vector of field elements laid out as follows: +//! ```text +//! |---- bits_per_entry * entries ----|---- bits_for_norm ----| +//! ^ ^ +//! \- the input vector entries | +//! \- the encoded norm +//! ``` +//! +//! ### Different number encodings +//! +//! Let `n` denote the number of bits of the chosen fixed-point type. +//! Numbers occur in 5 different representations: +//! 1. Clients have a vector whose entries are fixed point numbers. Only those +//! fixed point types are supported where the numbers lie in the range +//! `[-1,1)`. +//! 2. Because norm computation happens in the validation circuit, it is done +//! on entries encoded as field elements. That is, the same vector entries +//! are now represented by integers in the range `[0,2^n)`, where `-1` is +//! represented by `0` and `+1` by `2^n`. +//! 3. Because the field is not necessarily exactly of size `2^n`, but might be +//! larger, it is not enough to encode a vector entry as in (2.) and submit +//! it to the aggregator. Instead, in order to make sure that all submitted +//! values are in the correct range, they are bit-encoded. (This is the same +//! as what happens in the [`Sum`](crate::flp::types::Sum) type.) +//! This means that instead of sending a field element in the range `[0,2^n)`, +//! we send `n` field elements representing the bit encoding. The validation +//! circuit can verify that all submitted "bits" are indeed either `0` or `1`. +//! 4. The computed and submitted norms are treated similar to the vector +//! entries, but they have a different number of bits, namely `2n-2`. +//! 5. As the aggregation result is a pointwise sum of the client vectors, +//! the numbers no longer (semantically) lie in the range `[-1,1)`, and cannot +//! be represented by the same fixed point type as the input. Instead the +//! decoding happens directly into a vector of floats. +//! +//! ### Fixed point encoding +//! +//! Submissions consist of encoded fixed-point numbers in `[-1,1)` represented as +//! field elements in `[0,2^n)`, where n is the number of bits the fixed-point +//! representation has. Encoding and decoding is handled by the associated functions +//! of the [`CompatibleFloat`] trait. Semantically, the following function describes +//! how a fixed-point value `x` in range `[-1,1)` is converted to a field integer: +//! ```text +//! enc : [-1,1) -> [0,2^n) +//! enc(x) = 2^(n-1) * x + 2^(n-1) +//! ``` +//! The inverse is: +//! ```text +//! dec : [0,2^n) -> [-1,1) +//! dec(y) = (y - 2^(n-1)) * 2^(1-n) +//! ``` +//! Note that these functions only make sense when interpreting all occuring +//! numbers as real numbers. Since our signed fixed-point numbers are encoded as +//! two's complement integers, the computation that happens in +//! [`CompatibleFloat::to_field_integer`] is actually simpler. +//! +//! ### Value `1` +//! +//! We actually do not allow the submitted norm or vector entries to be +//! exactly `1`, but rather require them to be strictly less. Supporting `1` would +//! entail a more fiddly encoding and is not necessary for our usecase. +//! The largest representable vector entry can be computed by `dec(2^n-1)`. +//! For example, it is `0.999969482421875` for `n = 16`. +//! +//! ### Norm computation +//! +//! The L2 norm of a vector xs of numbers in `[-1,1)` is given by: +//! ```text +//! norm(xs) = sqrt(sum_{x in xs} x^2) +//! ``` +//! Instead of computing the norm, we make two simplifications: +//! 1. We ignore the square root, which means that we are actually computing +//! the square of the norm. +//! 2. We want our norm computation result to be integral and in the range `[0, 2^(2n-2))`, +//! so we can represent it in our field integers. We achieve this by multiplying with `2^(2n-2)`. +//! This means that what is actually computed in this type is the following: +//! ```text +//! our_norm(xs) = 2^(2n-2) * norm(xs)^2 +//! ``` +//! +//! Explained more visually, `our_norm()` is a composition of three functions: +//! +//! ```text +//! map of dec() norm() "mult with 2^(2n-2)" +//! vector of [0,2^n) -> vector of [-1,1) -> [0,1) -> [0,2^(2n-2)) +//! ^ ^ +//! | | +//! fractions with denom of 2^(n-1) fractions with denom of 2^(2n-2) +//! ``` +//! (Note that the ranges on the LHS and RHS of `"mult with 2^(2n-2)"` are stated +//! here for vectors with a norm less than `1`.) +//! +//! Given a vector `ys` of numbers in the field integer encoding (in `[0,2^n)`), +//! this gives the following equation: +//! ```text +//! our_norm_on_encoded(ys) = our_norm([dec(y) for y in ys]) +//! = 2^(2n-2) * sum_{y in ys} ((y - 2^(n-1)) * 2^(1-n))^2 +//! = 2^(2n-2) +//! * sum_{y in ys} y^2 - 2*y*2^(n-1) + (2^(n-1))^2 +//! * 2^(1-n)^2 +//! = sum_{y in ys} y^2 - (2^n)*y + 2^(2n-2) +//! ``` +//! +//! Let `d` denote the number of the vector entries. The maximal value the result +//! of `our_norm_on_encoded()` can take occurs in the case where all entries are +//! `2^n-1`, in which case `d * 2^(2n-2)` is an upper bound to the result. The +//! finite field used for encoding must be at least as large as this. +//! For validating that the norm of the submitted vector lies in the correct +//! range, consider the following: +//! - The result of `norm(xs)` should be in `[0,1)`. +//! - Thus, the result of `our_norm(xs)` should be in `[0,2^(2n-2))`. +//! - The result of `our_norm_on_encoded(ys)` should be in `[0,2^(2n-2))`. +//! This means that the valid norms are exactly those representable with `2n-2` +//! bits. +//! +//! ### Noise and Differential Privacy +//! +//! Bounding the submission norm bounds the impact that changing a single +//! client's submission can have on the aggregate. That is, the so-called +//! L2-sensitivity of the procedure is equal to two times the norm bound, namely +//! `2^n`. Therefore, adding discrete Gaussian noise with standard deviation +//! `sigma = `(2^n)/epsilon` for some `epsilon` will make the procedure [`(epsilon^2)/2` +//! zero-concentrated differentially private](https://arxiv.org/abs/2004.00010). +//! `epsilon` is given as a parameter to the `add_noise_to_result` function, as part of the +//! `dp_strategy` argument of type [`ZCdpDiscreteGaussian`]. +//! +//! ### Differences in the computation because of distribution +//! +//! In `decode_result()`, what is decoded are not the submitted entries of a +//! single client, but the sum of the the entries of all clients. We have to +//! take this into account, and cannot directly use the `dec()` function from +//! above. Instead we use: +//! ```text +//! dec'(y) = y * 2^(1-n) - c +//! ``` +//! Here, `c` is the number of clients. +//! +//! ### Naming in the implementation +//! +//! The following names are used: +//! - `self.bits_per_entry` is `n` +//! - `self.entries` is `d` +//! - `self.bits_for_norm` is `2n-2` +//! + +pub mod compatible_float; + +use crate::dp::{distributions::ZCdpDiscreteGaussian, DifferentialPrivacyStrategy, DpError}; +use crate::field::{Field128, FieldElement, FieldElementWithInteger, FieldElementWithIntegerExt}; +use crate::flp::gadgets::{Mul, ParallelSumGadget, PolyEval}; +use crate::flp::types::fixedpoint_l2::compatible_float::CompatibleFloat; +use crate::flp::types::parallel_sum_range_checks; +use crate::flp::{FlpError, Gadget, Type, TypeWithNoise}; +use crate::vdaf::xof::SeedStreamSha3; +use fixed::traits::Fixed; +use num_bigint::{BigInt, BigUint, TryFromBigIntError}; +use num_integer::Integer; +use num_rational::Ratio; +use rand::{distributions::Distribution, Rng}; +use rand_core::SeedableRng; +use std::{convert::TryFrom, convert::TryInto, fmt::Debug, marker::PhantomData}; + +/// The fixed point vector sum data type. Each measurement is a vector of fixed point numbers of +/// type `T`, and the aggregate result is the float vector of the sum of the measurements. +/// +/// The validity circuit verifies that the L2 norm of each measurement is bounded by 1. +/// +/// The [*fixed* crate](https://crates.io/crates/fixed) is used for fixed point numbers, in +/// particular, exactly the following types are supported: +/// `FixedI16<U15>`, `FixedI32<U31>` and `FixedI64<U63>`. +/// +/// The type implements the [`TypeWithNoise`] trait. The `add_noise_to_result` function adds +/// discrete Gaussian noise to an aggregate share, calibrated to the passed privacy budget. +/// This will result in the aggregate satisfying zero-concentrated differential privacy. +/// +/// Depending on the size of the vector that needs to be transmitted, a corresponding field type has +/// to be chosen for `F`. For a `n`-bit fixed point type and a `d`-dimensional vector, the field +/// modulus needs to be larger than `d * 2^(2n-2)` so there are no overflows during norm validity +/// computation. +#[derive(Clone, PartialEq, Eq)] +pub struct FixedPointBoundedL2VecSum< + T: Fixed, + SPoly: ParallelSumGadget<Field128, PolyEval<Field128>> + Clone, + SMul: ParallelSumGadget<Field128, Mul<Field128>> + Clone, +> { + bits_per_entry: usize, + entries: usize, + bits_for_norm: usize, + norm_summand_poly: Vec<Field128>, + phantom: PhantomData<(T, SPoly, SMul)>, + + // range/position constants + range_norm_begin: usize, + range_norm_end: usize, + + // configuration of parallel sum gadgets + gadget0_calls: usize, + gadget0_chunk_length: usize, + gadget1_calls: usize, + gadget1_chunk_length: usize, +} + +impl<T, SPoly, SMul> Debug for FixedPointBoundedL2VecSum<T, SPoly, SMul> +where + T: Fixed, + SPoly: ParallelSumGadget<Field128, PolyEval<Field128>> + Clone, + SMul: ParallelSumGadget<Field128, Mul<Field128>> + Clone, +{ + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("FixedPointBoundedL2VecSum") + .field("bits_per_entry", &self.bits_per_entry) + .field("entries", &self.entries) + .finish() + } +} + +impl<T, SPoly, SMul> FixedPointBoundedL2VecSum<T, SPoly, SMul> +where + T: Fixed, + SPoly: ParallelSumGadget<Field128, PolyEval<Field128>> + Clone, + SMul: ParallelSumGadget<Field128, Mul<Field128>> + Clone, +{ + /// Return a new [`FixedPointBoundedL2VecSum`] type parameter. Each value of this type is a + /// fixed point vector with `entries` entries. + pub fn new(entries: usize) -> Result<Self, FlpError> { + // (0) initialize constants + let fi_one = u128::from(Field128::one()); + + // (I) Check that the fixed type is compatible. + // + // We only support fixed types that encode values in [-1,1]. + // These have a single integer bit. + if <T as Fixed>::INT_NBITS != 1 { + return Err(FlpError::Encode(format!( + "Expected fixed point type with one integer bit, but got {}.", + <T as Fixed>::INT_NBITS, + ))); + } + + // Compute number of bits of an entry, and check that an entry fits + // into the field. + let bits_per_entry: usize = (<T as Fixed>::INT_NBITS + <T as Fixed>::FRAC_NBITS) + .try_into() + .map_err(|_| FlpError::Encode("Could not convert u32 into usize.".to_string()))?; + if !Field128::valid_integer_bitlength(bits_per_entry) { + return Err(FlpError::Encode(format!( + "fixed point type bit length ({bits_per_entry}) too large for field modulus", + ))); + } + + // (II) Check that the field is large enough for the norm. + // + // Valid norms encoded as field integers lie in [0,2^(2*bits - 2)). + let bits_for_norm = 2 * bits_per_entry - 2; + if !Field128::valid_integer_bitlength(bits_for_norm) { + return Err(FlpError::Encode(format!( + "maximal norm bit length ({bits_for_norm}) too large for field modulus", + ))); + } + + // In order to compare the actual norm of the vector with the claimed + // norm, the field needs to be able to represent all numbers that can + // occur during the computation of the norm of any submitted vector, + // even if its norm is not bounded by 1. Because of our encoding, an + // upper bound to that value is `entries * 2^(2*bits - 2)` (see docs of + // compute_norm_of_entries for details). It has to fit into the field. + let err = Err(FlpError::Encode(format!( + "number of entries ({entries}) not compatible with field size", + ))); + + if let Some(val) = (entries as u128).checked_mul(1 << bits_for_norm) { + if val >= Field128::modulus() { + return err; + } + } else { + return err; + } + + // Construct the polynomial that computes a part of the norm for a + // single vector entry. + // + // the linear part is 2^n, + // the constant part is 2^(2n-2), + // the polynomial is: + // p(y) = 2^(2n-2) + -(2^n) * y + 1 * y^2 + let linear_part = fi_one << bits_per_entry; + let constant_part = fi_one << (bits_per_entry + bits_per_entry - 2); + let norm_summand_poly = vec![ + Field128::from(constant_part), + -Field128::from(linear_part), + Field128::one(), + ]; + + // Compute chunk length and number of calls for parallel sum gadgets. + let len0 = bits_per_entry * entries + bits_for_norm; + let gadget0_chunk_length = std::cmp::max(1, (len0 as f64).sqrt() as usize); + let gadget0_calls = (len0 + gadget0_chunk_length - 1) / gadget0_chunk_length; + + let len1 = entries; + let gadget1_chunk_length = std::cmp::max(1, (len1 as f64).sqrt() as usize); + let gadget1_calls = (len1 + gadget1_chunk_length - 1) / gadget1_chunk_length; + + Ok(Self { + bits_per_entry, + entries, + bits_for_norm, + norm_summand_poly, + phantom: PhantomData, + + // range constants + range_norm_begin: entries * bits_per_entry, + range_norm_end: entries * bits_per_entry + bits_for_norm, + + // configuration of parallel sum gadgets + gadget0_calls, + gadget0_chunk_length, + gadget1_calls, + gadget1_chunk_length, + }) + } + + /// This noising function can be called on the aggregate share to make + /// the entire aggregation process differentially private. The noise is + /// calibrated to result in a guarantee of `1/2 * epsilon^2` zero-concentrated + /// differential privacy, where `epsilon` is given by `dp_strategy.budget`. + fn add_noise<R: Rng>( + &self, + dp_strategy: &ZCdpDiscreteGaussian, + agg_result: &mut [Field128], + rng: &mut R, + ) -> Result<(), FlpError> { + // generate and add discrete gaussian noise for each entry + + // 0. Compute sensitivity of aggregation, namely 2^n. + let sensitivity = BigUint::from(2u128).pow(self.bits_per_entry as u32); + // Also create a BigInt containing the field modulus. + let modulus = BigInt::from(Field128::modulus()); + + // 1. initialize sampler + let sampler = dp_strategy.create_distribution(Ratio::from_integer(sensitivity))?; + + // 2. Generate noise for each slice entry and apply it. + for entry in agg_result.iter_mut() { + // (a) Generate noise. + let noise: BigInt = sampler.sample(rng); + + // (b) Put noise into field. + // + // The noise is generated as BigInt, but has to fit into the Field128, + // which has modulus `Field128::modulus()`. Thus we use `BigInt::mod_floor()` + // to calculate `noise mod modulus`. This value fits into `u128`, and + // can be then put into the field. + // + // Note: we cannot use the operator `%` here, since it is not the mathematical + // modulus operation: for negative inputs and positive modulus it gives a + // negative result! + let noise: BigInt = noise.mod_floor(&modulus); + let noise: u128 = noise.try_into().map_err(|e: TryFromBigIntError<BigInt>| { + FlpError::DifferentialPrivacy(DpError::BigIntConversion(e)) + })?; + let f_noise = Field128::from(Field128::valid_integer_try_from::<u128>(noise)?); + + // (c) Apply noise to each entry of the aggregate share. + *entry += f_noise; + } + + Ok(()) + } +} + +impl<T, SPoly, SMul> Type for FixedPointBoundedL2VecSum<T, SPoly, SMul> +where + T: Fixed + CompatibleFloat, + SPoly: ParallelSumGadget<Field128, PolyEval<Field128>> + Eq + Clone + 'static, + SMul: ParallelSumGadget<Field128, Mul<Field128>> + Eq + Clone + 'static, +{ + const ID: u32 = 0xFFFF0000; + type Measurement = Vec<T>; + type AggregateResult = Vec<f64>; + type Field = Field128; + + fn encode_measurement(&self, fp_entries: &Vec<T>) -> Result<Vec<Field128>, FlpError> { + if fp_entries.len() != self.entries { + return Err(FlpError::Encode("unexpected input length".into())); + } + + // Convert the fixed-point encoded input values to field integers. We do + // this once here because we need them for encoding but also for + // computing the norm. + let integer_entries = fp_entries.iter().map(|x| x.to_field_integer()); + + // (I) Vector entries. + // Encode the integer entries bitwise, and write them into the `encoded` + // vector. + let mut encoded: Vec<Field128> = + vec![Field128::zero(); self.bits_per_entry * self.entries + self.bits_for_norm]; + for (l, entry) in integer_entries.clone().enumerate() { + Field128::fill_with_bitvector_representation( + &entry, + &mut encoded[l * self.bits_per_entry..(l + 1) * self.bits_per_entry], + )?; + } + + // (II) Vector norm. + // Compute the norm of the input vector. + let field_entries = integer_entries.map(Field128::from); + let norm = compute_norm_of_entries(field_entries, self.bits_per_entry)?; + let norm_int = u128::from(norm); + + // Write the norm into the `entries` vector. + Field128::fill_with_bitvector_representation( + &norm_int, + &mut encoded[self.range_norm_begin..self.range_norm_end], + )?; + + Ok(encoded) + } + + fn decode_result( + &self, + data: &[Field128], + num_measurements: usize, + ) -> Result<Vec<f64>, FlpError> { + if data.len() != self.entries { + return Err(FlpError::Decode("unexpected input length".into())); + } + let num_measurements = match u128::try_from(num_measurements) { + Ok(m) => m, + Err(_) => { + return Err(FlpError::Decode( + "number of clients is too large to fit into u128".into(), + )) + } + }; + let mut res = Vec::with_capacity(data.len()); + for d in data { + let decoded = <T as CompatibleFloat>::to_float(*d, num_measurements); + res.push(decoded); + } + Ok(res) + } + + fn gadget(&self) -> Vec<Box<dyn Gadget<Field128>>> { + // This gadget checks that a field element is zero or one. + // It is called for all the "bits" of the encoded entries + // and of the encoded norm. + let gadget0 = SMul::new(Mul::new(self.gadget0_calls), self.gadget0_chunk_length); + + // This gadget computes the square of a fixed point number, operating on + // its encoding as a field element. It is called on each entry during + // norm computation. + let gadget1 = SPoly::new( + PolyEval::new(self.norm_summand_poly.clone(), self.gadget1_calls), + self.gadget1_chunk_length, + ); + + vec![Box::new(gadget0), Box::new(gadget1)] + } + + fn valid( + &self, + g: &mut Vec<Box<dyn Gadget<Field128>>>, + input: &[Field128], + joint_rand: &[Field128], + num_shares: usize, + ) -> Result<Field128, FlpError> { + self.valid_call_check(input, joint_rand)?; + + let f_num_shares = Field128::from(Field128::valid_integer_try_from::<usize>(num_shares)?); + let num_shares_inverse = Field128::one() / f_num_shares; + + // Ensure that all submitted field elements are either 0 or 1. + // This is done for: + // (I) all vector entries (each of them encoded in `self.bits_per_entry` + // field elements) + // (II) the submitted norm (encoded in `self.bits_for_norm` field + // elements) + // + // Since all input vector entry (field-)bits, as well as the norm bits, + // are contiguous, we do the check directly for all bits from 0 to + // entries*bits_per_entry + bits_for_norm. + // + // In order to keep the proof size down, this is done using the + // `ParallelSum` gadget. For a similar application see the `SumVec` + // type. + let range_check = parallel_sum_range_checks( + &mut g[0], + &input[..self.range_norm_end], + joint_rand[0], + self.gadget0_chunk_length, + num_shares, + )?; + + // Compute the norm of the entries and ensure that it is the same as the + // submitted norm. There are exactly enough bits such that a submitted + // norm is always a valid norm (semantically in the range [0,1]). By + // comparing submitted with actual, we make sure the actual norm is + // valid. + // + // The function to compute here (see explanatory comment at the top) is + // norm(ys) = sum_{y in ys} y^2 - (2^n)*y + 2^(2n-2) + // + // This is done by the `ParallelSum` gadget `g[1]`, which evaluates the + // inner polynomial on each (decoded) vector entry, and then sums the + // results. Note that the gadget is not called on the whole vector at + // once, but sequentially on chunks of size `self.gadget1_chunk_length` of + // it. The results of these calls are accumulated in the `outp` variable. + // + // decode the bit-encoded entries into elements in the range [0,2^n): + let decoded_entries: Result<Vec<_>, _> = input[0..self.entries * self.bits_per_entry] + .chunks(self.bits_per_entry) + .map(Field128::decode_from_bitvector_representation) + .collect(); + + // run parallel sum gadget on the decoded entries + let computed_norm = { + let mut outp = Field128::zero(); + + // Chunks which are too short need to be extended with a share of the + // encoded zero value, that is: 1/num_shares * (2^(n-1)) + let fi_one = u128::from(Field128::one()); + let zero_enc = Field128::from(fi_one << (self.bits_per_entry - 1)); + let zero_enc_share = zero_enc * num_shares_inverse; + + for chunk in decoded_entries?.chunks(self.gadget1_chunk_length) { + let d = chunk.len(); + if d == self.gadget1_chunk_length { + outp += g[1].call(chunk)?; + } else { + // If the chunk is smaller than the chunk length, extend + // chunk with zeros. + let mut padded_chunk: Vec<_> = chunk.to_owned(); + padded_chunk.resize(self.gadget1_chunk_length, zero_enc_share); + outp += g[1].call(&padded_chunk)?; + } + } + + outp + }; + + // The submitted norm is also decoded from its bit-encoding, and + // compared with the computed norm. + let submitted_norm_enc = &input[self.range_norm_begin..self.range_norm_end]; + let submitted_norm = Field128::decode_from_bitvector_representation(submitted_norm_enc)?; + + let norm_check = computed_norm - submitted_norm; + + // Finally, we require both checks to be successful by computing a + // random linear combination of them. + let out = joint_rand[1] * range_check + (joint_rand[1] * joint_rand[1]) * norm_check; + Ok(out) + } + + fn truncate(&self, input: Vec<Field128>) -> Result<Vec<Self::Field>, FlpError> { + self.truncate_call_check(&input)?; + + let mut decoded_vector = vec![]; + + for i_entry in 0..self.entries { + let start = i_entry * self.bits_per_entry; + let end = (i_entry + 1) * self.bits_per_entry; + + let decoded = Field128::decode_from_bitvector_representation(&input[start..end])?; + decoded_vector.push(decoded); + } + Ok(decoded_vector) + } + + fn input_len(&self) -> usize { + self.bits_per_entry * self.entries + self.bits_for_norm + } + + fn proof_len(&self) -> usize { + // computed via + // `gadget.arity() + gadget.degree() + // * ((1 + gadget.calls()).next_power_of_two() - 1) + 1;` + let proof_gadget_0 = (self.gadget0_chunk_length * 2) + + 2 * ((1 + self.gadget0_calls).next_power_of_two() - 1) + + 1; + let proof_gadget_1 = (self.gadget1_chunk_length) + + 2 * ((1 + self.gadget1_calls).next_power_of_two() - 1) + + 1; + + proof_gadget_0 + proof_gadget_1 + } + + fn verifier_len(&self) -> usize { + self.gadget0_chunk_length * 2 + self.gadget1_chunk_length + 3 + } + + fn output_len(&self) -> usize { + self.entries + } + + fn joint_rand_len(&self) -> usize { + 2 + } + + fn prove_rand_len(&self) -> usize { + self.gadget0_chunk_length * 2 + self.gadget1_chunk_length + } + + fn query_rand_len(&self) -> usize { + 2 + } +} + +impl<T, SPoly, SMul> TypeWithNoise<ZCdpDiscreteGaussian> + for FixedPointBoundedL2VecSum<T, SPoly, SMul> +where + T: Fixed + CompatibleFloat, + SPoly: ParallelSumGadget<Field128, PolyEval<Field128>> + Eq + Clone + 'static, + SMul: ParallelSumGadget<Field128, Mul<Field128>> + Eq + Clone + 'static, +{ + fn add_noise_to_result( + &self, + dp_strategy: &ZCdpDiscreteGaussian, + agg_result: &mut [Self::Field], + _num_measurements: usize, + ) -> Result<(), FlpError> { + self.add_noise(dp_strategy, agg_result, &mut SeedStreamSha3::from_entropy()) + } +} + +/// Compute the square of the L2 norm of a vector of fixed-point numbers encoded as field elements. +/// +/// * `entries` - Iterator over the vector entries. +/// * `bits_per_entry` - Number of bits one entry has. +fn compute_norm_of_entries<Fs>(entries: Fs, bits_per_entry: usize) -> Result<Field128, FlpError> +where + Fs: IntoIterator<Item = Field128>, +{ + let fi_one = u128::from(Field128::one()); + + // The value that is computed here is: + // sum_{y in entries} 2^(2n-2) + -(2^n) * y + 1 * y^2 + // + // Check out the norm computation bit in the explanatory comment block for + // more information. + // + // Initialize `norm_accumulator`. + let mut norm_accumulator = Field128::zero(); + + // constants + let linear_part = fi_one << bits_per_entry; // = 2^(2n-2) + let constant_part = fi_one << (bits_per_entry + bits_per_entry - 2); // = 2^n + + // Add term for a given `entry` to `norm_accumulator`. + for entry in entries.into_iter() { + let summand = + entry * entry + Field128::from(constant_part) - Field128::from(linear_part) * (entry); + norm_accumulator += summand; + } + Ok(norm_accumulator) +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::dp::{Rational, ZCdpBudget}; + use crate::field::{random_vector, Field128, FieldElement}; + use crate::flp::gadgets::ParallelSum; + use crate::flp::types::test_utils::{flp_validity_test, ValidityTestCase}; + use crate::vdaf::xof::SeedStreamSha3; + use fixed::types::extra::{U127, U14, U63}; + use fixed::{FixedI128, FixedI16, FixedI64}; + use fixed_macro::fixed; + use rand::SeedableRng; + + #[test] + fn test_bounded_fpvec_sum_parallel_fp16() { + let fp16_4_inv = fixed!(0.25: I1F15); + let fp16_8_inv = fixed!(0.125: I1F15); + let fp16_16_inv = fixed!(0.0625: I1F15); + + let fp16_vec = vec![fp16_4_inv, fp16_8_inv, fp16_16_inv]; + + // the encoded vector has the following entries: + // enc(0.25) = 2^(n-1) * 0.25 + 2^(n-1) = 40960 + // enc(0.125) = 2^(n-1) * 0.125 + 2^(n-1) = 36864 + // enc(0.0625) = 2^(n-1) * 0.0625 + 2^(n-1) = 34816 + test_fixed(fp16_vec, vec![40960, 36864, 34816]); + } + + #[test] + fn test_bounded_fpvec_sum_parallel_fp32() { + let fp32_4_inv = fixed!(0.25: I1F31); + let fp32_8_inv = fixed!(0.125: I1F31); + let fp32_16_inv = fixed!(0.0625: I1F31); + + let fp32_vec = vec![fp32_4_inv, fp32_8_inv, fp32_16_inv]; + // computed as above but with n=32 + test_fixed(fp32_vec, vec![2684354560, 2415919104, 2281701376]); + } + + #[test] + fn test_bounded_fpvec_sum_parallel_fp64() { + let fp64_4_inv = fixed!(0.25: I1F63); + let fp64_8_inv = fixed!(0.125: I1F63); + let fp64_16_inv = fixed!(0.0625: I1F63); + + let fp64_vec = vec![fp64_4_inv, fp64_8_inv, fp64_16_inv]; + // computed as above but with n=64 + test_fixed( + fp64_vec, + vec![ + 11529215046068469760, + 10376293541461622784, + 9799832789158199296, + ], + ); + } + + fn test_fixed<F: Fixed>(fp_vec: Vec<F>, enc_vec: Vec<u128>) + where + F: CompatibleFloat, + { + let n: usize = (F::INT_NBITS + F::FRAC_NBITS).try_into().unwrap(); + + type Ps = ParallelSum<Field128, PolyEval<Field128>>; + type Psm = ParallelSum<Field128, Mul<Field128>>; + + let vsum: FixedPointBoundedL2VecSum<F, Ps, Psm> = + FixedPointBoundedL2VecSum::new(3).unwrap(); + let one = Field128::one(); + // Round trip + assert_eq!( + vsum.decode_result( + &vsum + .truncate(vsum.encode_measurement(&fp_vec).unwrap()) + .unwrap(), + 1 + ) + .unwrap(), + vec!(0.25, 0.125, 0.0625) + ); + + // Noise + let mut v = vsum + .truncate(vsum.encode_measurement(&fp_vec).unwrap()) + .unwrap(); + let strategy = ZCdpDiscreteGaussian::from_budget(ZCdpBudget::new( + Rational::from_unsigned(100u8, 3u8).unwrap(), + )); + vsum.add_noise(&strategy, &mut v, &mut SeedStreamSha3::from_seed([0u8; 16])) + .unwrap(); + assert_eq!( + vsum.decode_result(&v, 1).unwrap(), + match n { + // sensitivity depends on encoding so the noise differs + 16 => vec![0.150604248046875, 0.139373779296875, -0.03759765625], + 32 => vec![0.3051439793780446, 0.1226568529382348, 0.08595499861985445], + 64 => vec![0.2896077990915178, 0.16115188007715098, 0.0788390114728425], + _ => panic!("unsupported bitsize"), + } + ); + + // encoded norm does not match computed norm + let mut input: Vec<Field128> = vsum.encode_measurement(&fp_vec).unwrap(); + assert_eq!(input[0], Field128::zero()); + input[0] = one; // it was zero + flp_validity_test( + &vsum, + &input, + &ValidityTestCase::<Field128> { + expect_valid: false, + expected_output: Some(vec![ + Field128::from(enc_vec[0] + 1), // = enc(0.25) + 2^0 + Field128::from(enc_vec[1]), + Field128::from(enc_vec[2]), + ]), + num_shares: 3, + }, + ) + .unwrap(); + + // encoding contains entries that are not zero or one + let mut input2: Vec<Field128> = vsum.encode_measurement(&fp_vec).unwrap(); + input2[0] = one + one; + flp_validity_test( + &vsum, + &input2, + &ValidityTestCase::<Field128> { + expect_valid: false, + expected_output: Some(vec![ + Field128::from(enc_vec[0] + 2), // = enc(0.25) + 2*2^0 + Field128::from(enc_vec[1]), + Field128::from(enc_vec[2]), + ]), + num_shares: 3, + }, + ) + .unwrap(); + + // norm is too big + // 2^n - 1, the field element encoded by the all-1 vector + let one_enc = Field128::from(((2_u128) << (n - 1)) - 1); + flp_validity_test( + &vsum, + &vec![one; 3 * n + 2 * n - 2], // all vector entries and the norm are all-1-vectors + &ValidityTestCase::<Field128> { + expect_valid: false, + expected_output: Some(vec![one_enc; 3]), + num_shares: 3, + }, + ) + .unwrap(); + + // invalid submission length, should be 3n + (2*n - 2) for a + // 3-element n-bit vector. 3*n bits for 3 entries, (2*n-2) for norm. + let joint_rand = random_vector(vsum.joint_rand_len()).unwrap(); + vsum.valid( + &mut vsum.gadget(), + &vec![one; 3 * n + 2 * n - 1], + &joint_rand, + 1, + ) + .unwrap_err(); + + // test that the zero vector has correct norm, where zero is encoded as: + // enc(0) = 2^(n-1) * 0 + 2^(n-1) + let zero_enc = Field128::from((2_u128) << (n - 2)); + { + let entries = vec![zero_enc; 3]; + let norm = compute_norm_of_entries(entries, vsum.bits_per_entry).unwrap(); + let expected_norm = Field128::from(0); + assert_eq!(norm, expected_norm); + } + + // ensure that no overflow occurs with largest possible norm + { + // the largest possible entries (2^n-1) + let entries = vec![one_enc; 3]; + let norm = compute_norm_of_entries(entries, vsum.bits_per_entry).unwrap(); + let expected_norm = Field128::from(3 * (1 + (1 << (2 * n - 2)) - (1 << n))); + // = 3 * ((2^n-1)^2 - (2^n-1)*2^16 + 2^(2*n-2)) + assert_eq!(norm, expected_norm); + + // the smallest possible entries (0) + let entries = vec![Field128::from(0), Field128::from(0), Field128::from(0)]; + let norm = compute_norm_of_entries(entries, vsum.bits_per_entry).unwrap(); + let expected_norm = Field128::from(3 * (1 << (2 * n - 2))); + // = 3 * (0^2 - 0*2^n + 2^(2*n-2)) + assert_eq!(norm, expected_norm); + } + } + + #[test] + fn test_bounded_fpvec_sum_parallel_invalid_args() { + // invalid initialization + // fixed point too large + <FixedPointBoundedL2VecSum< + FixedI128<U127>, + ParallelSum<Field128, PolyEval<Field128>>, + ParallelSum<Field128, Mul<Field128>>, + >>::new(3) + .unwrap_err(); + // vector too large + <FixedPointBoundedL2VecSum< + FixedI64<U63>, + ParallelSum<Field128, PolyEval<Field128>>, + ParallelSum<Field128, Mul<Field128>>, + >>::new(3000000000) + .unwrap_err(); + // fixed point type has more than one int bit + <FixedPointBoundedL2VecSum< + FixedI16<U14>, + ParallelSum<Field128, PolyEval<Field128>>, + ParallelSum<Field128, Mul<Field128>>, + >>::new(3) + .unwrap_err(); + } +} diff --git a/third_party/rust/prio/src/flp/types/fixedpoint_l2/compatible_float.rs b/third_party/rust/prio/src/flp/types/fixedpoint_l2/compatible_float.rs new file mode 100644 index 0000000000..404bec125a --- /dev/null +++ b/third_party/rust/prio/src/flp/types/fixedpoint_l2/compatible_float.rs @@ -0,0 +1,93 @@ +// SPDX-License-Identifier: MPL-2.0 + +//! Implementations of encoding fixed point types as field elements and field elements as floats +//! for the [`FixedPointBoundedL2VecSum`](crate::flp::types::fixedpoint_l2::FixedPointBoundedL2VecSum) type. + +use crate::field::{Field128, FieldElementWithInteger}; +use fixed::types::extra::{U15, U31, U63}; +use fixed::{FixedI16, FixedI32, FixedI64}; + +/// Assign a `Float` type to this type and describe how to represent this type as an integer of the +/// given field, and how to represent a field element as the assigned `Float` type. +pub trait CompatibleFloat { + /// Represent a field element as `Float`, given the number of clients `c`. + fn to_float(t: Field128, c: u128) -> f64; + + /// Represent a value of this type as an integer in the given field. + fn to_field_integer(&self) -> <Field128 as FieldElementWithInteger>::Integer; +} + +impl CompatibleFloat for FixedI16<U15> { + fn to_float(d: Field128, c: u128) -> f64 { + to_float_bits(d, c, 16) + } + + fn to_field_integer(&self) -> <Field128 as FieldElementWithInteger>::Integer { + //signed two's complement integer representation + let i: i16 = self.to_bits(); + // reinterpret as unsigned + let u = i as u16; + // invert the left-most bit to de-two-complement + u128::from(u ^ (1 << 15)) + } +} + +impl CompatibleFloat for FixedI32<U31> { + fn to_float(d: Field128, c: u128) -> f64 { + to_float_bits(d, c, 32) + } + + fn to_field_integer(&self) -> <Field128 as FieldElementWithInteger>::Integer { + //signed two's complement integer representation + let i: i32 = self.to_bits(); + // reinterpret as unsigned + let u = i as u32; + // invert the left-most bit to de-two-complement + u128::from(u ^ (1 << 31)) + } +} + +impl CompatibleFloat for FixedI64<U63> { + fn to_float(d: Field128, c: u128) -> f64 { + to_float_bits(d, c, 64) + } + + fn to_field_integer(&self) -> <Field128 as FieldElementWithInteger>::Integer { + //signed two's complement integer representation + let i: i64 = self.to_bits(); + // reinterpret as unsigned + let u = i as u64; + // invert the left-most bit to de-two-complement + u128::from(u ^ (1 << 63)) + } +} + +/// Return an `f64` representation of the field element `s`, assuming it is the computation result +/// of a `c`-client fixed point vector summation with `n` fractional bits. +fn to_float_bits(s: Field128, c: u128, n: i32) -> f64 { + // get integer representation of field element + let s_int: u128 = <Field128 as FieldElementWithInteger>::Integer::from(s); + + // to decode a single integer, we'd use the function + // dec(y) = (y - 2^(n-1)) * 2^(1-n) = y * 2^(1-n) - 1 + // as s is the sum of c encoded vector entries where c is the number of + // clients, we have to compute instead + // s * 2^(1-n) - c + // + // Furthermore, for better numerical stability, we reformulate this as + // = (s - c*2^(n-1)) * 2^(1-n) + // where the subtraction of `c` is done on integers and only afterwards + // the conversion to floats is done. + // + // Since the RHS of the substraction may be larger than the LHS + // (when the number we are decoding is going to be negative), + // yet we are dealing with unsigned 128-bit integers, we manually + // check for the resulting sign while ensuring that the subtraction + // does not underflow. + let (a, b, sign) = match (s_int, c << (n - 1)) { + (x, y) if x < y => (y, x, -1.0f64), + (x, y) => (x, y, 1.0f64), + }; + + ((a - b) as f64) * sign * f64::powi(2.0, 1 - n) +} diff --git a/third_party/rust/prio/src/fp.rs b/third_party/rust/prio/src/fp.rs new file mode 100644 index 0000000000..d4c0dcdc2c --- /dev/null +++ b/third_party/rust/prio/src/fp.rs @@ -0,0 +1,533 @@ +// SPDX-License-Identifier: MPL-2.0 + +//! Finite field arithmetic for any field GF(p) for which p < 2^128. + +#[cfg(test)] +use rand::{prelude::*, Rng}; + +/// For each set of field parameters we pre-compute the 1st, 2nd, 4th, ..., 2^20-th principal roots +/// of unity. The largest of these is used to run the FFT algorithm on an input of size 2^20. This +/// is the largest input size we would ever need for the cryptographic applications in this crate. +pub(crate) const MAX_ROOTS: usize = 20; + +/// This structure represents the parameters of a finite field GF(p) for which p < 2^128. +#[derive(Debug, PartialEq, Eq)] +pub(crate) struct FieldParameters { + /// The prime modulus `p`. + pub p: u128, + /// `mu = -p^(-1) mod 2^64`. + pub mu: u64, + /// `r2 = (2^128)^2 mod p`. + pub r2: u128, + /// The `2^num_roots`-th -principal root of unity. This element is used to generate the + /// elements of `roots`. + pub g: u128, + /// The number of principal roots of unity in `roots`. + pub num_roots: usize, + /// Equal to `2^b - 1`, where `b` is the length of `p` in bits. + pub bit_mask: u128, + /// `roots[l]` is the `2^l`-th principal root of unity, i.e., `roots[l]` has order `2^l` in the + /// multiplicative group. `roots[0]` is equal to one by definition. + pub roots: [u128; MAX_ROOTS + 1], +} + +impl FieldParameters { + /// Addition. The result will be in [0, p), so long as both x and y are as well. + #[inline(always)] + pub fn add(&self, x: u128, y: u128) -> u128 { + // 0,x + // + 0,y + // ===== + // c,z + let (z, carry) = x.overflowing_add(y); + // c, z + // - 0, p + // ======== + // b1,s1,s0 + let (s0, b0) = z.overflowing_sub(self.p); + let (_s1, b1) = (carry as u128).overflowing_sub(b0 as u128); + // if b1 == 1: return z + // else: return s0 + let m = 0u128.wrapping_sub(b1 as u128); + (z & m) | (s0 & !m) + } + + /// Subtraction. The result will be in [0, p), so long as both x and y are as well. + #[inline(always)] + pub fn sub(&self, x: u128, y: u128) -> u128 { + // 0, x + // - 0, y + // ======== + // b1,z1,z0 + let (z0, b0) = x.overflowing_sub(y); + let (_z1, b1) = 0u128.overflowing_sub(b0 as u128); + let m = 0u128.wrapping_sub(b1 as u128); + // z1,z0 + // + 0, p + // ======== + // s1,s0 + z0.wrapping_add(m & self.p) + // if b1 == 1: return s0 + // else: return z0 + } + + /// Multiplication of field elements in the Montgomery domain. This uses the REDC algorithm + /// described + /// [here](https://www.ams.org/journals/mcom/1985-44-170/S0025-5718-1985-0777282-X/S0025-5718-1985-0777282-X.pdf). + /// The result will be in [0, p). + /// + /// # Example usage + /// ```text + /// assert_eq!(fp.residue(fp.mul(fp.montgomery(23), fp.montgomery(2))), 46); + /// ``` + #[inline(always)] + pub fn mul(&self, x: u128, y: u128) -> u128 { + let x = [lo64(x), hi64(x)]; + let y = [lo64(y), hi64(y)]; + let p = [lo64(self.p), hi64(self.p)]; + let mut zz = [0; 4]; + + // Integer multiplication + // z = x * y + + // x1,x0 + // * y1,y0 + // =========== + // z3,z2,z1,z0 + let mut result = x[0] * y[0]; + let mut carry = hi64(result); + zz[0] = lo64(result); + result = x[0] * y[1]; + let mut hi = hi64(result); + let mut lo = lo64(result); + result = lo + carry; + zz[1] = lo64(result); + let mut cc = hi64(result); + result = hi + cc; + zz[2] = lo64(result); + + result = x[1] * y[0]; + hi = hi64(result); + lo = lo64(result); + result = zz[1] + lo; + zz[1] = lo64(result); + cc = hi64(result); + result = hi + cc; + carry = lo64(result); + + result = x[1] * y[1]; + hi = hi64(result); + lo = lo64(result); + result = lo + carry; + lo = lo64(result); + cc = hi64(result); + result = hi + cc; + hi = lo64(result); + result = zz[2] + lo; + zz[2] = lo64(result); + cc = hi64(result); + result = hi + cc; + zz[3] = lo64(result); + + // Montgomery Reduction + // z = z + p * mu*(z mod 2^64), where mu = (-p)^(-1) mod 2^64. + + // z3,z2,z1,z0 + // + p1,p0 + // * w = mu*z0 + // =========== + // z3,z2,z1, 0 + let w = self.mu.wrapping_mul(zz[0] as u64); + result = p[0] * (w as u128); + hi = hi64(result); + lo = lo64(result); + result = zz[0] + lo; + zz[0] = lo64(result); + cc = hi64(result); + result = hi + cc; + carry = lo64(result); + + result = p[1] * (w as u128); + hi = hi64(result); + lo = lo64(result); + result = lo + carry; + lo = lo64(result); + cc = hi64(result); + result = hi + cc; + hi = lo64(result); + result = zz[1] + lo; + zz[1] = lo64(result); + cc = hi64(result); + result = zz[2] + hi + cc; + zz[2] = lo64(result); + cc = hi64(result); + result = zz[3] + cc; + zz[3] = lo64(result); + + // z3,z2,z1 + // + p1,p0 + // * w = mu*z1 + // =========== + // z3,z2, 0 + let w = self.mu.wrapping_mul(zz[1] as u64); + result = p[0] * (w as u128); + hi = hi64(result); + lo = lo64(result); + result = zz[1] + lo; + zz[1] = lo64(result); + cc = hi64(result); + result = hi + cc; + carry = lo64(result); + + result = p[1] * (w as u128); + hi = hi64(result); + lo = lo64(result); + result = lo + carry; + lo = lo64(result); + cc = hi64(result); + result = hi + cc; + hi = lo64(result); + result = zz[2] + lo; + zz[2] = lo64(result); + cc = hi64(result); + result = zz[3] + hi + cc; + zz[3] = lo64(result); + cc = hi64(result); + + // z = (z3,z2) + let prod = zz[2] | (zz[3] << 64); + + // Final subtraction + // If z >= p, then z = z - p + + // 0, z + // - 0, p + // ======== + // b1,s1,s0 + let (s0, b0) = prod.overflowing_sub(self.p); + let (_s1, b1) = cc.overflowing_sub(b0 as u128); + // if b1 == 1: return z + // else: return s0 + let mask = 0u128.wrapping_sub(b1 as u128); + (prod & mask) | (s0 & !mask) + } + + /// Modular exponentiation, i.e., `x^exp (mod p)` where `p` is the modulus. Note that the + /// runtime of this algorithm is linear in the bit length of `exp`. + pub fn pow(&self, x: u128, exp: u128) -> u128 { + let mut t = self.montgomery(1); + for i in (0..128 - exp.leading_zeros()).rev() { + t = self.mul(t, t); + if (exp >> i) & 1 != 0 { + t = self.mul(t, x); + } + } + t + } + + /// Modular inversion, i.e., x^-1 (mod p) where `p` is the modulus. Note that the runtime of + /// this algorithm is linear in the bit length of `p`. + #[inline(always)] + pub fn inv(&self, x: u128) -> u128 { + self.pow(x, self.p - 2) + } + + /// Negation, i.e., `-x (mod p)` where `p` is the modulus. + #[inline(always)] + pub fn neg(&self, x: u128) -> u128 { + self.sub(0, x) + } + + /// Maps an integer to its internal representation. Field elements are mapped to the Montgomery + /// domain in order to carry out field arithmetic. The result will be in [0, p). + /// + /// # Example usage + /// ```text + /// let integer = 1; // Standard integer representation + /// let elem = fp.montgomery(integer); // Internal representation in the Montgomery domain + /// assert_eq!(elem, 2564090464); + /// ``` + #[inline(always)] + pub fn montgomery(&self, x: u128) -> u128 { + modp(self.mul(x, self.r2), self.p) + } + + /// Returns a random field element mapped. + #[cfg(test)] + pub fn rand_elem<R: Rng + ?Sized>(&self, rng: &mut R) -> u128 { + let uniform = rand::distributions::Uniform::from(0..self.p); + self.montgomery(uniform.sample(rng)) + } + + /// Maps a field element to its representation as an integer. The result will be in [0, p). + /// + /// #Example usage + /// ```text + /// let elem = 2564090464; // Internal representation in the Montgomery domain + /// let integer = fp.residue(elem); // Standard integer representation + /// assert_eq!(integer, 1); + /// ``` + #[inline(always)] + pub fn residue(&self, x: u128) -> u128 { + modp(self.mul(x, 1), self.p) + } + + #[cfg(test)] + pub fn check(&self, p: u128, g: u128, order: u128) { + use modinverse::modinverse; + use num_bigint::{BigInt, ToBigInt}; + use std::cmp::max; + + assert_eq!(self.p, p, "p mismatch"); + + let mu = match modinverse((-(p as i128)).rem_euclid(1 << 64), 1 << 64) { + Some(mu) => mu as u64, + None => panic!("inverse of -p (mod 2^64) is undefined"), + }; + assert_eq!(self.mu, mu, "mu mismatch"); + + let big_p = &p.to_bigint().unwrap(); + let big_r: &BigInt = &(&(BigInt::from(1) << 128) % big_p); + let big_r2: &BigInt = &(&(big_r * big_r) % big_p); + let mut it = big_r2.iter_u64_digits(); + let mut r2 = 0; + r2 |= it.next().unwrap() as u128; + if let Some(x) = it.next() { + r2 |= (x as u128) << 64; + } + assert_eq!(self.r2, r2, "r2 mismatch"); + + assert_eq!(self.g, self.montgomery(g), "g mismatch"); + assert_eq!( + self.residue(self.pow(self.g, order)), + 1, + "g order incorrect" + ); + + let num_roots = log2(order) as usize; + assert_eq!(order, 1 << num_roots, "order not a power of 2"); + assert_eq!(self.num_roots, num_roots, "num_roots mismatch"); + + let mut roots = vec![0; max(num_roots, MAX_ROOTS) + 1]; + roots[num_roots] = self.montgomery(g); + for i in (0..num_roots).rev() { + roots[i] = self.mul(roots[i + 1], roots[i + 1]); + } + assert_eq!(&self.roots, &roots[..MAX_ROOTS + 1], "roots mismatch"); + assert_eq!(self.residue(self.roots[0]), 1, "first root is not one"); + + let bit_mask = (BigInt::from(1) << big_p.bits()) - BigInt::from(1); + assert_eq!( + self.bit_mask.to_bigint().unwrap(), + bit_mask, + "bit_mask mismatch" + ); + } +} + +#[inline(always)] +fn lo64(x: u128) -> u128 { + x & ((1 << 64) - 1) +} + +#[inline(always)] +fn hi64(x: u128) -> u128 { + x >> 64 +} + +#[inline(always)] +fn modp(x: u128, p: u128) -> u128 { + let (z, carry) = x.overflowing_sub(p); + let m = 0u128.wrapping_sub(carry as u128); + z.wrapping_add(m & p) +} + +pub(crate) const FP32: FieldParameters = FieldParameters { + p: 4293918721, // 32-bit prime + mu: 17302828673139736575, + r2: 1676699750, + g: 1074114499, + num_roots: 20, + bit_mask: 4294967295, + roots: [ + 2564090464, 1729828257, 306605458, 2294308040, 1648889905, 57098624, 2788941825, + 2779858277, 368200145, 2760217336, 594450960, 4255832533, 1372848488, 721329415, + 3873251478, 1134002069, 7138597, 2004587313, 2989350643, 725214187, 1074114499, + ], +}; + +pub(crate) const FP64: FieldParameters = FieldParameters { + p: 18446744069414584321, // 64-bit prime + mu: 18446744069414584319, + r2: 4294967295, + g: 959634606461954525, + num_roots: 32, + bit_mask: 18446744073709551615, + roots: [ + 18446744065119617025, + 4294967296, + 18446462594437939201, + 72057594037927936, + 1152921504338411520, + 16384, + 18446743519658770561, + 18446735273187346433, + 6519596376689022014, + 9996039020351967275, + 15452408553935940313, + 15855629130643256449, + 8619522106083987867, + 13036116919365988132, + 1033106119984023956, + 16593078884869787648, + 16980581328500004402, + 12245796497946355434, + 8709441440702798460, + 8611358103550827629, + 8120528636261052110, + ], +}; + +pub(crate) const FP128: FieldParameters = FieldParameters { + p: 340282366920938462946865773367900766209, // 128-bit prime + mu: 18446744073709551615, + r2: 403909908237944342183153, + g: 107630958476043550189608038630704257141, + num_roots: 66, + bit_mask: 340282366920938463463374607431768211455, + roots: [ + 516508834063867445247, + 340282366920938462430356939304033320962, + 129526470195413442198896969089616959958, + 169031622068548287099117778531474117974, + 81612939378432101163303892927894236156, + 122401220764524715189382260548353967708, + 199453575871863981432000940507837456190, + 272368408887745135168960576051472383806, + 24863773656265022616993900367764287617, + 257882853788779266319541142124730662203, + 323732363244658673145040701829006542956, + 57532865270871759635014308631881743007, + 149571414409418047452773959687184934208, + 177018931070866797456844925926211239962, + 268896136799800963964749917185333891349, + 244556960591856046954834420512544511831, + 118945432085812380213390062516065622346, + 202007153998709986841225284843501908420, + 332677126194796691532164818746739771387, + 258279638927684931537542082169183965856, + 148221243758794364405224645520862378432, + ], +}; + +// Compute the ceiling of the base-2 logarithm of `x`. +pub(crate) fn log2(x: u128) -> u128 { + let y = (127 - x.leading_zeros()) as u128; + y + ((x > 1 << y) as u128) +} + +#[cfg(test)] +mod tests { + use super::*; + use num_bigint::ToBigInt; + + #[test] + fn test_log2() { + assert_eq!(log2(1), 0); + assert_eq!(log2(2), 1); + assert_eq!(log2(3), 2); + assert_eq!(log2(4), 2); + assert_eq!(log2(15), 4); + assert_eq!(log2(16), 4); + assert_eq!(log2(30), 5); + assert_eq!(log2(32), 5); + assert_eq!(log2(1 << 127), 127); + assert_eq!(log2((1 << 127) + 13), 128); + } + + struct TestFieldParametersData { + fp: FieldParameters, // The paramters being tested + expected_p: u128, // Expected fp.p + expected_g: u128, // Expected fp.residue(fp.g) + expected_order: u128, // Expect fp.residue(fp.pow(fp.g, expected_order)) == 1 + } + + #[test] + fn test_fp() { + let test_fps = vec![ + TestFieldParametersData { + fp: FP32, + expected_p: 4293918721, + expected_g: 3925978153, + expected_order: 1 << 20, + }, + TestFieldParametersData { + fp: FP64, + expected_p: 18446744069414584321, + expected_g: 1753635133440165772, + expected_order: 1 << 32, + }, + TestFieldParametersData { + fp: FP128, + expected_p: 340282366920938462946865773367900766209, + expected_g: 145091266659756586618791329697897684742, + expected_order: 1 << 66, + }, + ]; + + for t in test_fps.into_iter() { + // Check that the field parameters have been constructed properly. + t.fp.check(t.expected_p, t.expected_g, t.expected_order); + + // Check that the generator has the correct order. + assert_eq!(t.fp.residue(t.fp.pow(t.fp.g, t.expected_order)), 1); + assert_ne!(t.fp.residue(t.fp.pow(t.fp.g, t.expected_order / 2)), 1); + + // Test arithmetic using the field parameters. + arithmetic_test(&t.fp); + } + } + + fn arithmetic_test(fp: &FieldParameters) { + let mut rng = rand::thread_rng(); + let big_p = &fp.p.to_bigint().unwrap(); + + for _ in 0..100 { + let x = fp.rand_elem(&mut rng); + let y = fp.rand_elem(&mut rng); + let big_x = &fp.residue(x).to_bigint().unwrap(); + let big_y = &fp.residue(y).to_bigint().unwrap(); + + // Test addition. + let got = fp.add(x, y); + let want = (big_x + big_y) % big_p; + assert_eq!(fp.residue(got).to_bigint().unwrap(), want); + + // Test subtraction. + let got = fp.sub(x, y); + let want = if big_x >= big_y { + big_x - big_y + } else { + big_p - big_y + big_x + }; + assert_eq!(fp.residue(got).to_bigint().unwrap(), want); + + // Test multiplication. + let got = fp.mul(x, y); + let want = (big_x * big_y) % big_p; + assert_eq!(fp.residue(got).to_bigint().unwrap(), want); + + // Test inversion. + let got = fp.inv(x); + let want = big_x.modpow(&(big_p - 2u128), big_p); + assert_eq!(fp.residue(got).to_bigint().unwrap(), want); + assert_eq!(fp.residue(fp.mul(got, x)), 1); + + // Test negation. + let got = fp.neg(x); + let want = (big_p - big_x) % big_p; + assert_eq!(fp.residue(got).to_bigint().unwrap(), want); + assert_eq!(fp.residue(fp.add(got, x)), 0); + } + } +} diff --git a/third_party/rust/prio/src/idpf.rs b/third_party/rust/prio/src/idpf.rs new file mode 100644 index 0000000000..2bb73f2159 --- /dev/null +++ b/third_party/rust/prio/src/idpf.rs @@ -0,0 +1,2200 @@ +//! This module implements the incremental distributed point function (IDPF) described in +//! [[draft-irtf-cfrg-vdaf-07]]. +//! +//! [draft-irtf-cfrg-vdaf-07]: https://datatracker.ietf.org/doc/draft-irtf-cfrg-vdaf/07/ + +use crate::{ + codec::{CodecError, Decode, Encode, ParameterizedDecode}, + field::{FieldElement, FieldElementExt}, + vdaf::{ + xof::{Seed, XofFixedKeyAes128Key}, + VdafError, VERSION, + }, +}; +use bitvec::{ + bitvec, + boxed::BitBox, + prelude::{Lsb0, Msb0}, + slice::BitSlice, + vec::BitVec, + view::BitView, +}; +use rand_core::RngCore; +use std::{ + collections::{HashMap, VecDeque}, + fmt::Debug, + io::{Cursor, Read}, + ops::{Add, AddAssign, ControlFlow, Index, Sub}, +}; +use subtle::{Choice, ConditionallyNegatable, ConditionallySelectable, ConstantTimeEq}; + +/// IDPF-related errors. +#[derive(Debug, thiserror::Error)] +pub enum IdpfError { + /// Error from incompatible shares at different levels. + #[error("tried to merge shares from incompatible levels")] + MismatchedLevel, + + /// Invalid parameter, indicates an invalid input to either [`Idpf::gen`] or [`Idpf::eval`]. + #[error("invalid parameter: {0}")] + InvalidParameter(String), +} + +/// An index used as the input to an IDPF evaluation. +#[derive(Debug, Clone, PartialEq, Eq, Hash, PartialOrd, Ord)] +pub struct IdpfInput { + /// The index as a boxed bit slice. + index: BitBox, +} + +impl IdpfInput { + /// Convert a slice of bytes into an IDPF input, where the bits of each byte are processed in + /// MSB-to-LSB order. (Subsequent bytes are processed in their natural order.) + pub fn from_bytes(bytes: &[u8]) -> IdpfInput { + let bit_slice_u8_storage = bytes.view_bits::<Msb0>(); + let mut bit_vec_usize_storage = bitvec![0; bit_slice_u8_storage.len()]; + bit_vec_usize_storage.clone_from_bitslice(bit_slice_u8_storage); + IdpfInput { + index: bit_vec_usize_storage.into_boxed_bitslice(), + } + } + + /// Convert a slice of booleans into an IDPF input. + pub fn from_bools(bools: &[bool]) -> IdpfInput { + let bits = bools.iter().collect::<BitVec>(); + IdpfInput { + index: bits.into_boxed_bitslice(), + } + } + + /// Create a new IDPF input by appending to this input. + pub fn clone_with_suffix(&self, suffix: &[bool]) -> IdpfInput { + let mut vec = BitVec::with_capacity(self.index.len() + suffix.len()); + vec.extend_from_bitslice(&self.index); + vec.extend(suffix); + IdpfInput { + index: vec.into_boxed_bitslice(), + } + } + + /// Get the length of the input in bits. + pub fn len(&self) -> usize { + self.index.len() + } + + /// Check if the input is empty, i.e. it does not contain any bits. + pub fn is_empty(&self) -> bool { + self.index.is_empty() + } + + /// Get an iterator over the bits that make up this input. + pub fn iter(&self) -> impl DoubleEndedIterator<Item = bool> + '_ { + self.index.iter().by_vals() + } + + /// Convert the IDPF into a byte slice. If the length of the underlying bit vector is not a + /// multiple of `8`, then the least significant bits of the last byte are `0`-padded. + pub fn to_bytes(&self) -> Vec<u8> { + let mut vec = BitVec::<u8, Msb0>::with_capacity(self.index.len()); + vec.extend_from_bitslice(&self.index); + vec.set_uninitialized(false); + vec.into_vec() + } + + /// Return the `level`-bit prefix of this IDPF input. + pub fn prefix(&self, level: usize) -> Self { + Self { + index: self.index[..=level].to_owned().into(), + } + } +} + +impl From<BitVec<usize, Lsb0>> for IdpfInput { + fn from(bit_vec: BitVec<usize, Lsb0>) -> Self { + IdpfInput { + index: bit_vec.into_boxed_bitslice(), + } + } +} + +impl From<BitBox<usize, Lsb0>> for IdpfInput { + fn from(bit_box: BitBox<usize, Lsb0>) -> Self { + IdpfInput { index: bit_box } + } +} + +impl<I> Index<I> for IdpfInput +where + BitSlice: Index<I>, +{ + type Output = <BitSlice as Index<I>>::Output; + + fn index(&self, index: I) -> &Self::Output { + &self.index[index] + } +} + +/// Trait for values to be programmed into an IDPF. +/// +/// Values must form an Abelian group, so that they can be secret-shared, and the group operation +/// must be represented by [`Add`]. Values must be encodable and decodable, without need for a +/// decoding parameter. Values can be pseudorandomly generated, with a uniform probability +/// distribution, from XOF output. +pub trait IdpfValue: + Add<Output = Self> + + AddAssign + + Sub<Output = Self> + + ConditionallyNegatable + + Encode + + Decode + + Sized +{ + /// Any run-time parameters needed to produce a value. + type ValueParameter; + + /// Generate a pseudorandom value from a seed stream. + fn generate<S>(seed_stream: &mut S, parameter: &Self::ValueParameter) -> Self + where + S: RngCore; + + /// Returns the additive identity. + fn zero(parameter: &Self::ValueParameter) -> Self; + + /// Conditionally select between two values. Implementations must perform this operation in + /// constant time. + /// + /// This is the same as in [`subtle::ConditionallySelectable`], but without the [`Copy`] bound. + fn conditional_select(a: &Self, b: &Self, choice: Choice) -> Self; +} + +impl<F> IdpfValue for F +where + F: FieldElement, +{ + type ValueParameter = (); + + fn generate<S>(seed_stream: &mut S, _: &()) -> Self + where + S: RngCore, + { + // This is analogous to `Prng::get()`, but does not make use of a persistent buffer of + // output. + let mut buffer = [0u8; 64]; + assert!( + buffer.len() >= F::ENCODED_SIZE, + "field is too big for buffer" + ); + loop { + seed_stream.fill_bytes(&mut buffer[..F::ENCODED_SIZE]); + match F::from_random_rejection(&buffer[..F::ENCODED_SIZE]) { + ControlFlow::Break(x) => return x, + ControlFlow::Continue(()) => continue, + } + } + } + + fn zero(_: &()) -> Self { + <Self as FieldElement>::zero() + } + + fn conditional_select(a: &Self, b: &Self, choice: Choice) -> Self { + <F as ConditionallySelectable>::conditional_select(a, b, choice) + } +} + +/// An output from evaluation of an IDPF at some level and index. +#[derive(Debug, PartialEq, Eq)] +pub enum IdpfOutputShare<VI, VL> { + /// An IDPF output share corresponding to an inner tree node. + Inner(VI), + /// An IDPF output share corresponding to a leaf tree node. + Leaf(VL), +} + +impl<VI, VL> IdpfOutputShare<VI, VL> +where + VI: IdpfValue, + VL: IdpfValue, +{ + /// Combine two output share values into one. + pub fn merge(self, other: Self) -> Result<IdpfOutputShare<VI, VL>, IdpfError> { + match (self, other) { + (IdpfOutputShare::Inner(mut self_value), IdpfOutputShare::Inner(other_value)) => { + self_value += other_value; + Ok(IdpfOutputShare::Inner(self_value)) + } + (IdpfOutputShare::Leaf(mut self_value), IdpfOutputShare::Leaf(other_value)) => { + self_value += other_value; + Ok(IdpfOutputShare::Leaf(self_value)) + } + (_, _) => Err(IdpfError::MismatchedLevel), + } + } +} + +fn extend(seed: &[u8; 16], xof_fixed_key: &XofFixedKeyAes128Key) -> ([[u8; 16]; 2], [Choice; 2]) { + let mut seed_stream = xof_fixed_key.with_seed(seed); + + let mut seeds = [[0u8; 16], [0u8; 16]]; + seed_stream.fill_bytes(&mut seeds[0]); + seed_stream.fill_bytes(&mut seeds[1]); + + let mut byte = [0u8]; + seed_stream.fill_bytes(&mut byte); + let control_bits = [(byte[0] & 1).into(), ((byte[0] >> 1) & 1).into()]; + + (seeds, control_bits) +} + +fn convert<V>( + seed: &[u8; 16], + xof_fixed_key: &XofFixedKeyAes128Key, + parameter: &V::ValueParameter, +) -> ([u8; 16], V) +where + V: IdpfValue, +{ + let mut seed_stream = xof_fixed_key.with_seed(seed); + + let mut next_seed = [0u8; 16]; + seed_stream.fill_bytes(&mut next_seed); + + (next_seed, V::generate(&mut seed_stream, parameter)) +} + +/// Helper method to update seeds, update control bits, and output the correction word for one level +/// of the IDPF key generation process. +fn generate_correction_word<V>( + input_bit: Choice, + value: V, + parameter: &V::ValueParameter, + keys: &mut [[u8; 16]; 2], + control_bits: &mut [Choice; 2], + extend_xof_fixed_key: &XofFixedKeyAes128Key, + convert_xof_fixed_key: &XofFixedKeyAes128Key, +) -> IdpfCorrectionWord<V> +where + V: IdpfValue, +{ + // Expand both keys into two seeds and two control bits each. + let (seed_0, control_bits_0) = extend(&keys[0], extend_xof_fixed_key); + let (seed_1, control_bits_1) = extend(&keys[1], extend_xof_fixed_key); + + let (keep, lose) = (input_bit, !input_bit); + + let cw_seed = xor_seeds( + &conditional_select_seed(lose, &seed_0), + &conditional_select_seed(lose, &seed_1), + ); + let cw_control_bits = [ + control_bits_0[0] ^ control_bits_1[0] ^ input_bit ^ Choice::from(1), + control_bits_0[1] ^ control_bits_1[1] ^ input_bit, + ]; + let cw_control_bits_keep = + Choice::conditional_select(&cw_control_bits[0], &cw_control_bits[1], keep); + + let previous_control_bits = *control_bits; + let control_bits_0_keep = + Choice::conditional_select(&control_bits_0[0], &control_bits_0[1], keep); + let control_bits_1_keep = + Choice::conditional_select(&control_bits_1[0], &control_bits_1[1], keep); + control_bits[0] = control_bits_0_keep ^ (cw_control_bits_keep & previous_control_bits[0]); + control_bits[1] = control_bits_1_keep ^ (cw_control_bits_keep & previous_control_bits[1]); + + let seed_0_keep = conditional_select_seed(keep, &seed_0); + let seed_1_keep = conditional_select_seed(keep, &seed_1); + let seeds_corrected = [ + conditional_xor_seeds(&seed_0_keep, &cw_seed, previous_control_bits[0]), + conditional_xor_seeds(&seed_1_keep, &cw_seed, previous_control_bits[1]), + ]; + + let (new_key_0, elements_0) = + convert::<V>(&seeds_corrected[0], convert_xof_fixed_key, parameter); + let (new_key_1, elements_1) = + convert::<V>(&seeds_corrected[1], convert_xof_fixed_key, parameter); + + keys[0] = new_key_0; + keys[1] = new_key_1; + + let mut cw_value = value - elements_0 + elements_1; + cw_value.conditional_negate(control_bits[1]); + + IdpfCorrectionWord { + seed: cw_seed, + control_bits: cw_control_bits, + value: cw_value, + } +} + +/// Helper function to evaluate one level of an IDPF. This updates the seed and control bit +/// arguments that are passed in. +#[allow(clippy::too_many_arguments)] +fn eval_next<V>( + is_leader: bool, + parameter: &V::ValueParameter, + key: &mut [u8; 16], + control_bit: &mut Choice, + correction_word: &IdpfCorrectionWord<V>, + input_bit: Choice, + extend_xof_fixed_key: &XofFixedKeyAes128Key, + convert_xof_fixed_key: &XofFixedKeyAes128Key, +) -> V +where + V: IdpfValue, +{ + let (mut seeds, mut control_bits) = extend(key, extend_xof_fixed_key); + + seeds[0] = conditional_xor_seeds(&seeds[0], &correction_word.seed, *control_bit); + control_bits[0] ^= correction_word.control_bits[0] & *control_bit; + seeds[1] = conditional_xor_seeds(&seeds[1], &correction_word.seed, *control_bit); + control_bits[1] ^= correction_word.control_bits[1] & *control_bit; + + let seed_corrected = conditional_select_seed(input_bit, &seeds); + *control_bit = Choice::conditional_select(&control_bits[0], &control_bits[1], input_bit); + + let (new_key, elements) = convert::<V>(&seed_corrected, convert_xof_fixed_key, parameter); + *key = new_key; + + let mut out = + elements + V::conditional_select(&V::zero(parameter), &correction_word.value, *control_bit); + out.conditional_negate(Choice::from((!is_leader) as u8)); + out +} + +/// This defines a family of IDPFs (incremental distributed point functions) with certain types of +/// values at inner tree nodes and at leaf tree nodes. +/// +/// IDPF keys can be generated by providing an input and programmed outputs for each tree level to +/// [`Idpf::gen`]. +pub struct Idpf<VI, VL> +where + VI: IdpfValue, + VL: IdpfValue, +{ + inner_node_value_parameter: VI::ValueParameter, + leaf_node_value_parameter: VL::ValueParameter, +} + +impl<VI, VL> Idpf<VI, VL> +where + VI: IdpfValue, + VL: IdpfValue, +{ + /// Construct an [`Idpf`] instance with the given run-time parameters needed for inner and leaf + /// values. + pub fn new( + inner_node_value_parameter: VI::ValueParameter, + leaf_node_value_parameter: VL::ValueParameter, + ) -> Self { + Self { + inner_node_value_parameter, + leaf_node_value_parameter, + } + } + + pub(crate) fn gen_with_random<M: IntoIterator<Item = VI>>( + &self, + input: &IdpfInput, + inner_values: M, + leaf_value: VL, + binder: &[u8], + random: &[[u8; 16]; 2], + ) -> Result<(IdpfPublicShare<VI, VL>, [Seed<16>; 2]), VdafError> { + let bits = input.len(); + + let initial_keys: [Seed<16>; 2] = + [Seed::from_bytes(random[0]), Seed::from_bytes(random[1])]; + + let extend_dst = [ + VERSION, 1, /* algorithm class */ + 0, 0, 0, 0, /* algorithm ID */ + 0, 0, /* usage */ + ]; + let convert_dst = [ + VERSION, 1, /* algorithm class */ + 0, 0, 0, 0, /* algorithm ID */ + 0, 1, /* usage */ + ]; + let extend_xof_fixed_key = XofFixedKeyAes128Key::new(&extend_dst, binder); + let convert_xof_fixed_key = XofFixedKeyAes128Key::new(&convert_dst, binder); + + let mut keys = [initial_keys[0].0, initial_keys[1].0]; + let mut control_bits = [Choice::from(0u8), Choice::from(1u8)]; + let mut inner_correction_words = Vec::with_capacity(bits - 1); + + for (level, value) in inner_values.into_iter().enumerate() { + if level >= bits - 1 { + return Err(IdpfError::InvalidParameter( + "too many values were supplied".to_string(), + ) + .into()); + } + inner_correction_words.push(generate_correction_word::<VI>( + Choice::from(input[level] as u8), + value, + &self.inner_node_value_parameter, + &mut keys, + &mut control_bits, + &extend_xof_fixed_key, + &convert_xof_fixed_key, + )); + } + if inner_correction_words.len() != bits - 1 { + return Err( + IdpfError::InvalidParameter("too few values were supplied".to_string()).into(), + ); + } + let leaf_correction_word = generate_correction_word::<VL>( + Choice::from(input[bits - 1] as u8), + leaf_value, + &self.leaf_node_value_parameter, + &mut keys, + &mut control_bits, + &extend_xof_fixed_key, + &convert_xof_fixed_key, + ); + let public_share = IdpfPublicShare { + inner_correction_words, + leaf_correction_word, + }; + + Ok((public_share, initial_keys)) + } + + /// The IDPF key generation algorithm. + /// + /// Generate and return a sequence of IDPF shares for `input`. The parameters `inner_values` + /// and `leaf_value` provide the output values for each successive level of the prefix tree. + pub fn gen<M>( + &self, + input: &IdpfInput, + inner_values: M, + leaf_value: VL, + binder: &[u8], + ) -> Result<(IdpfPublicShare<VI, VL>, [Seed<16>; 2]), VdafError> + where + M: IntoIterator<Item = VI>, + { + if input.is_empty() { + return Err( + IdpfError::InvalidParameter("invalid number of bits: 0".to_string()).into(), + ); + } + let mut random = [[0u8; 16]; 2]; + for random_seed in random.iter_mut() { + getrandom::getrandom(random_seed)?; + } + self.gen_with_random(input, inner_values, leaf_value, binder, &random) + } + + /// Evaluate an IDPF share on `prefix`, starting from a particular tree level with known + /// intermediate values. + #[allow(clippy::too_many_arguments)] + fn eval_from_node( + &self, + is_leader: bool, + public_share: &IdpfPublicShare<VI, VL>, + start_level: usize, + mut key: [u8; 16], + mut control_bit: Choice, + prefix: &IdpfInput, + binder: &[u8], + cache: &mut dyn IdpfCache, + ) -> Result<IdpfOutputShare<VI, VL>, IdpfError> { + let bits = public_share.inner_correction_words.len() + 1; + + let extend_dst = [ + VERSION, 1, /* algorithm class */ + 0, 0, 0, 0, /* algorithm ID */ + 0, 0, /* usage */ + ]; + let convert_dst = [ + VERSION, 1, /* algorithm class */ + 0, 0, 0, 0, /* algorithm ID */ + 0, 1, /* usage */ + ]; + let extend_xof_fixed_key = XofFixedKeyAes128Key::new(&extend_dst, binder); + let convert_xof_fixed_key = XofFixedKeyAes128Key::new(&convert_dst, binder); + + let mut last_inner_output = None; + for ((correction_word, input_bit), level) in public_share.inner_correction_words + [start_level..] + .iter() + .zip(prefix[start_level..].iter()) + .zip(start_level..) + { + last_inner_output = Some(eval_next( + is_leader, + &self.inner_node_value_parameter, + &mut key, + &mut control_bit, + correction_word, + Choice::from(*input_bit as u8), + &extend_xof_fixed_key, + &convert_xof_fixed_key, + )); + let cache_key = &prefix[..=level]; + cache.insert(cache_key, &(key, control_bit.unwrap_u8())); + } + + if prefix.len() == bits { + let leaf_output = eval_next( + is_leader, + &self.leaf_node_value_parameter, + &mut key, + &mut control_bit, + &public_share.leaf_correction_word, + Choice::from(prefix[bits - 1] as u8), + &extend_xof_fixed_key, + &convert_xof_fixed_key, + ); + // Note: there's no point caching this node's key, because we will always run the + // eval_next() call for the leaf level. + Ok(IdpfOutputShare::Leaf(leaf_output)) + } else { + Ok(IdpfOutputShare::Inner(last_inner_output.unwrap())) + } + } + + /// The IDPF key evaluation algorithm. + /// + /// Evaluate an IDPF share on `prefix`. + pub fn eval( + &self, + agg_id: usize, + public_share: &IdpfPublicShare<VI, VL>, + key: &Seed<16>, + prefix: &IdpfInput, + binder: &[u8], + cache: &mut dyn IdpfCache, + ) -> Result<IdpfOutputShare<VI, VL>, IdpfError> { + let bits = public_share.inner_correction_words.len() + 1; + if agg_id > 1 { + return Err(IdpfError::InvalidParameter(format!( + "invalid aggregator ID {agg_id}" + ))); + } + let is_leader = agg_id == 0; + if prefix.is_empty() { + return Err(IdpfError::InvalidParameter("empty prefix".to_string())); + } + if prefix.len() > bits { + return Err(IdpfError::InvalidParameter(format!( + "prefix length ({}) exceeds configured number of bits ({})", + prefix.len(), + bits, + ))); + } + + // Check for cached keys first, starting from the end of our desired path down the tree, and + // walking back up. If we get a hit, stop there and evaluate the remainder of the tree path + // going forward. + if prefix.len() > 1 { + // Skip checking for `prefix` in the cache, because we don't store field element + // values along with keys and control bits. Instead, start looking one node higher + // up, so we can recompute everything for the last level of `prefix`. + let mut cache_key = &prefix[..prefix.len() - 1]; + while !cache_key.is_empty() { + if let Some((key, control_bit)) = cache.get(cache_key) { + // Evaluate the IDPF starting from the cached data at a previously-computed + // node, and return the result. + return self.eval_from_node( + is_leader, + public_share, + /* start_level */ cache_key.len(), + key, + Choice::from(control_bit), + prefix, + binder, + cache, + ); + } + cache_key = &cache_key[..cache_key.len() - 1]; + } + } + // Evaluate starting from the root node. + self.eval_from_node( + is_leader, + public_share, + /* start_level */ 0, + key.0, + /* control_bit */ Choice::from((!is_leader) as u8), + prefix, + binder, + cache, + ) + } +} + +/// An IDPF public share. This contains the list of correction words used by all parties when +/// evaluating the IDPF. +#[derive(Debug, Clone)] +pub struct IdpfPublicShare<VI, VL> { + /// Correction words for each inner node level. + inner_correction_words: Vec<IdpfCorrectionWord<VI>>, + /// Correction word for the leaf node level. + leaf_correction_word: IdpfCorrectionWord<VL>, +} + +impl<VI, VL> ConstantTimeEq for IdpfPublicShare<VI, VL> +where + VI: ConstantTimeEq, + VL: ConstantTimeEq, +{ + fn ct_eq(&self, other: &Self) -> Choice { + self.inner_correction_words + .ct_eq(&other.inner_correction_words) + & self.leaf_correction_word.ct_eq(&other.leaf_correction_word) + } +} + +impl<VI, VL> PartialEq for IdpfPublicShare<VI, VL> +where + VI: ConstantTimeEq, + VL: ConstantTimeEq, +{ + fn eq(&self, other: &Self) -> bool { + self.ct_eq(other).into() + } +} + +impl<VI, VL> Eq for IdpfPublicShare<VI, VL> +where + VI: ConstantTimeEq, + VL: ConstantTimeEq, +{ +} + +impl<VI, VL> Encode for IdpfPublicShare<VI, VL> +where + VI: Encode, + VL: Encode, +{ + fn encode(&self, bytes: &mut Vec<u8>) { + // Control bits need to be written within each byte in LSB-to-MSB order, and assigned into + // bytes in big-endian order. Thus, the first four levels will have their control bits + // encoded in the last byte, and the last levels will have their control bits encoded in the + // first byte. + let mut control_bits: BitVec<u8, Lsb0> = + BitVec::with_capacity(self.inner_correction_words.len() * 2 + 2); + for correction_words in self.inner_correction_words.iter() { + control_bits.extend(correction_words.control_bits.iter().map(|x| bool::from(*x))); + } + control_bits.extend( + self.leaf_correction_word + .control_bits + .iter() + .map(|x| bool::from(*x)), + ); + control_bits.set_uninitialized(false); + let mut packed_control = control_bits.into_vec(); + bytes.append(&mut packed_control); + + for correction_words in self.inner_correction_words.iter() { + Seed(correction_words.seed).encode(bytes); + correction_words.value.encode(bytes); + } + Seed(self.leaf_correction_word.seed).encode(bytes); + self.leaf_correction_word.value.encode(bytes); + } + + fn encoded_len(&self) -> Option<usize> { + let control_bits_count = (self.inner_correction_words.len() + 1) * 2; + let mut len = (control_bits_count + 7) / 8 + (self.inner_correction_words.len() + 1) * 16; + for correction_words in self.inner_correction_words.iter() { + len += correction_words.value.encoded_len()?; + } + len += self.leaf_correction_word.value.encoded_len()?; + Some(len) + } +} + +impl<VI, VL> ParameterizedDecode<usize> for IdpfPublicShare<VI, VL> +where + VI: Decode, + VL: Decode, +{ + fn decode_with_param(bits: &usize, bytes: &mut Cursor<&[u8]>) -> Result<Self, CodecError> { + let packed_control_len = (bits + 3) / 4; + let mut packed = vec![0u8; packed_control_len]; + bytes.read_exact(&mut packed)?; + let unpacked_control_bits: BitVec<u8, Lsb0> = BitVec::from_vec(packed); + + let mut inner_correction_words = Vec::with_capacity(bits - 1); + for chunk in unpacked_control_bits[0..(bits - 1) * 2].chunks(2) { + let control_bits = [(chunk[0] as u8).into(), (chunk[1] as u8).into()]; + let seed = Seed::decode(bytes)?.0; + let value = VI::decode(bytes)?; + inner_correction_words.push(IdpfCorrectionWord { + seed, + control_bits, + value, + }) + } + + let control_bits = [ + (unpacked_control_bits[(bits - 1) * 2] as u8).into(), + (unpacked_control_bits[bits * 2 - 1] as u8).into(), + ]; + let seed = Seed::decode(bytes)?.0; + let value = VL::decode(bytes)?; + let leaf_correction_word = IdpfCorrectionWord { + seed, + control_bits, + value, + }; + + // Check that unused packed bits are zero. + if unpacked_control_bits[bits * 2..].any() { + return Err(CodecError::UnexpectedValue); + } + + Ok(IdpfPublicShare { + inner_correction_words, + leaf_correction_word, + }) + } +} + +#[derive(Debug, Clone)] +struct IdpfCorrectionWord<V> { + seed: [u8; 16], + control_bits: [Choice; 2], + value: V, +} + +impl<V> ConstantTimeEq for IdpfCorrectionWord<V> +where + V: ConstantTimeEq, +{ + fn ct_eq(&self, other: &Self) -> Choice { + self.seed.ct_eq(&other.seed) + & self.control_bits.ct_eq(&other.control_bits) + & self.value.ct_eq(&other.value) + } +} + +impl<V> PartialEq for IdpfCorrectionWord<V> +where + V: ConstantTimeEq, +{ + fn eq(&self, other: &Self) -> bool { + self.ct_eq(other).into() + } +} + +impl<V> Eq for IdpfCorrectionWord<V> where V: ConstantTimeEq {} + +fn xor_seeds(left: &[u8; 16], right: &[u8; 16]) -> [u8; 16] { + let mut seed = [0u8; 16]; + for (a, (b, c)) in left.iter().zip(right.iter().zip(seed.iter_mut())) { + *c = a ^ b; + } + seed +} + +fn and_seeds(left: &[u8; 16], right: &[u8; 16]) -> [u8; 16] { + let mut seed = [0u8; 16]; + for (a, (b, c)) in left.iter().zip(right.iter().zip(seed.iter_mut())) { + *c = a & b; + } + seed +} + +fn or_seeds(left: &[u8; 16], right: &[u8; 16]) -> [u8; 16] { + let mut seed = [0u8; 16]; + for (a, (b, c)) in left.iter().zip(right.iter().zip(seed.iter_mut())) { + *c = a | b; + } + seed +} + +/// Take a control bit, and fan it out into a byte array that can be used as a mask for XOF seeds, +/// without branching. If the control bit input is 0, all bytes will be equal to 0, and if the +/// control bit input is 1, all bytes will be equal to 255. +fn control_bit_to_seed_mask(control: Choice) -> [u8; 16] { + let mask = -(control.unwrap_u8() as i8) as u8; + [mask; 16] +} + +/// Take two seeds and a control bit, and return the first seed if the control bit is zero, or the +/// XOR of the two seeds if the control bit is one. This does not branch on the control bit. +fn conditional_xor_seeds( + normal_input: &[u8; 16], + switched_input: &[u8; 16], + control: Choice, +) -> [u8; 16] { + xor_seeds( + normal_input, + &and_seeds(switched_input, &control_bit_to_seed_mask(control)), + ) +} + +/// Returns one of two seeds, depending on the value of a selector bit. Does not branch on the +/// selector input or make selector-dependent memory accesses. +fn conditional_select_seed(select: Choice, seeds: &[[u8; 16]; 2]) -> [u8; 16] { + or_seeds( + &and_seeds(&control_bit_to_seed_mask(!select), &seeds[0]), + &and_seeds(&control_bit_to_seed_mask(select), &seeds[1]), + ) +} + +/// An interface that provides memoization of IDPF computations. +/// +/// Each instance of a type implementing `IdpfCache` should only be used with one IDPF key and +/// public share. +/// +/// In typical use, IDPFs will be evaluated repeatedly on inputs of increasing length, as part of a +/// protocol executed by multiple participants. Each IDPF evaluation computes keys and control +/// bits corresponding to tree nodes along a path determined by the input to the IDPF. Thus, the +/// values from nodes further up in the tree may be cached and reused in evaluations of subsequent +/// longer inputs. If one IDPF input is a prefix of another input, then the first input's path down +/// the tree is a prefix of the other input's path. +pub trait IdpfCache { + /// Fetch cached values for the node identified by the IDPF input. + fn get(&self, input: &BitSlice) -> Option<([u8; 16], u8)>; + + /// Store values corresponding to the node identified by the IDPF input. + fn insert(&mut self, input: &BitSlice, values: &([u8; 16], u8)); +} + +/// A no-op [`IdpfCache`] implementation that always reports a cache miss. +#[derive(Default)] +pub struct NoCache {} + +impl NoCache { + /// Construct a `NoCache` object. + pub fn new() -> NoCache { + NoCache::default() + } +} + +impl IdpfCache for NoCache { + fn get(&self, _: &BitSlice) -> Option<([u8; 16], u8)> { + None + } + + fn insert(&mut self, _: &BitSlice, _: &([u8; 16], u8)) {} +} + +/// A simple [`IdpfCache`] implementation that caches intermediate results in an in-memory hash map, +/// with no eviction. +#[derive(Default)] +pub struct HashMapCache { + map: HashMap<BitBox, ([u8; 16], u8)>, +} + +impl HashMapCache { + /// Create a new unpopulated `HashMapCache`. + pub fn new() -> HashMapCache { + HashMapCache::default() + } + + /// Create a new unpopulated `HashMapCache`, with a set pre-allocated capacity. + pub fn with_capacity(capacity: usize) -> HashMapCache { + Self { + map: HashMap::with_capacity(capacity), + } + } +} + +impl IdpfCache for HashMapCache { + fn get(&self, input: &BitSlice) -> Option<([u8; 16], u8)> { + self.map.get(input).cloned() + } + + fn insert(&mut self, input: &BitSlice, values: &([u8; 16], u8)) { + if !self.map.contains_key(input) { + self.map + .insert(input.to_owned().into_boxed_bitslice(), *values); + } + } +} + +/// A simple [`IdpfCache`] implementation that caches intermediate results in memory, with +/// first-in-first-out eviction, and lookups via linear probing. +pub struct RingBufferCache { + ring: VecDeque<(BitBox, [u8; 16], u8)>, +} + +impl RingBufferCache { + /// Create a new unpopulated `RingBufferCache`. + pub fn new(capacity: usize) -> RingBufferCache { + Self { + ring: VecDeque::with_capacity(std::cmp::max(capacity, 1)), + } + } +} + +impl IdpfCache for RingBufferCache { + fn get(&self, input: &BitSlice) -> Option<([u8; 16], u8)> { + // iterate back-to-front, so that we check the most recently pushed entry first. + for entry in self.ring.iter().rev() { + if input == entry.0 { + return Some((entry.1, entry.2)); + } + } + None + } + + fn insert(&mut self, input: &BitSlice, values: &([u8; 16], u8)) { + // evict first (to avoid growing the storage) + if self.ring.len() == self.ring.capacity() { + self.ring.pop_front(); + } + self.ring + .push_back((input.to_owned().into_boxed_bitslice(), values.0, values.1)); + } +} + +#[cfg(test)] +mod tests { + use std::{ + collections::HashMap, + convert::{TryFrom, TryInto}, + io::Cursor, + ops::{Add, AddAssign, Sub}, + str::FromStr, + sync::Mutex, + }; + + use assert_matches::assert_matches; + use bitvec::{ + bitbox, + prelude::{BitBox, Lsb0}, + slice::BitSlice, + vec::BitVec, + }; + use num_bigint::BigUint; + use rand::random; + use subtle::{Choice, ConditionallyNegatable, ConditionallySelectable}; + + use super::{ + HashMapCache, Idpf, IdpfCache, IdpfCorrectionWord, IdpfInput, IdpfOutputShare, + IdpfPublicShare, NoCache, RingBufferCache, + }; + use crate::{ + codec::{ + decode_u32_items, encode_u32_items, CodecError, Decode, Encode, ParameterizedDecode, + }, + field::{Field128, Field255, Field64, FieldElement}, + prng::Prng, + vdaf::{poplar1::Poplar1IdpfValue, xof::Seed}, + }; + + #[test] + fn idpf_input_conversion() { + let input_1 = IdpfInput::from_bools(&[ + false, true, false, false, false, false, false, true, false, true, false, false, false, + false, true, false, + ]); + let input_2 = IdpfInput::from_bytes(b"AB"); + assert_eq!(input_1, input_2); + let bits = bitbox![0, 1, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 0, 1, 0]; + assert_eq!(input_1[..], bits); + } + + /// A lossy IDPF cache, for testing purposes, that randomly returns cache misses. + #[derive(Default)] + struct LossyCache { + map: HashMap<BitBox, ([u8; 16], u8)>, + } + + impl LossyCache { + /// Create a new unpopulated `LossyCache`. + fn new() -> LossyCache { + LossyCache::default() + } + } + + impl IdpfCache for LossyCache { + fn get(&self, input: &BitSlice) -> Option<([u8; 16], u8)> { + if random() { + self.map.get(input).cloned() + } else { + None + } + } + + fn insert(&mut self, input: &BitSlice, values: &([u8; 16], u8)) { + if !self.map.contains_key(input) { + self.map + .insert(input.to_owned().into_boxed_bitslice(), *values); + } + } + } + + /// A wrapper [`IdpfCache`] implementation that records `get()` calls, for testing purposes. + struct SnoopingCache<T> { + inner: T, + get_calls: Mutex<Vec<BitBox>>, + insert_calls: Mutex<Vec<(BitBox, [u8; 16], u8)>>, + } + + impl<T> SnoopingCache<T> { + fn new(inner: T) -> SnoopingCache<T> { + SnoopingCache { + inner, + get_calls: Mutex::new(Vec::new()), + insert_calls: Mutex::new(Vec::new()), + } + } + } + + impl<T> IdpfCache for SnoopingCache<T> + where + T: IdpfCache, + { + fn get(&self, input: &BitSlice) -> Option<([u8; 16], u8)> { + self.get_calls + .lock() + .unwrap() + .push(input.to_owned().into_boxed_bitslice()); + self.inner.get(input) + } + + fn insert(&mut self, input: &BitSlice, values: &([u8; 16], u8)) { + self.insert_calls.lock().unwrap().push(( + input.to_owned().into_boxed_bitslice(), + values.0, + values.1, + )); + self.inner.insert(input, values) + } + } + + #[test] + fn test_idpf_poplar() { + let input = bitbox![0, 1, 1, 0, 1].into(); + let nonce: [u8; 16] = random(); + let idpf = Idpf::new((), ()); + let (public_share, keys) = idpf + .gen( + &input, + Vec::from([Poplar1IdpfValue::new([Field64::one(), Field64::one()]); 4]), + Poplar1IdpfValue::new([Field255::one(), Field255::one()]), + &nonce, + ) + .unwrap(); + + check_idpf_poplar_evaluation( + &public_share, + &keys, + &bitbox![0].into(), + &nonce, + &IdpfOutputShare::Inner(Poplar1IdpfValue::new([Field64::one(), Field64::one()])), + &mut NoCache::new(), + &mut NoCache::new(), + ); + check_idpf_poplar_evaluation( + &public_share, + &keys, + &bitbox![1].into(), + &nonce, + &IdpfOutputShare::Inner(Poplar1IdpfValue::new([Field64::zero(), Field64::zero()])), + &mut NoCache::new(), + &mut NoCache::new(), + ); + check_idpf_poplar_evaluation( + &public_share, + &keys, + &bitbox![0, 1].into(), + &nonce, + &IdpfOutputShare::Inner(Poplar1IdpfValue::new([Field64::one(), Field64::one()])), + &mut NoCache::new(), + &mut NoCache::new(), + ); + check_idpf_poplar_evaluation( + &public_share, + &keys, + &bitbox![0, 0].into(), + &nonce, + &IdpfOutputShare::Inner(Poplar1IdpfValue::new([Field64::zero(), Field64::zero()])), + &mut NoCache::new(), + &mut NoCache::new(), + ); + check_idpf_poplar_evaluation( + &public_share, + &keys, + &bitbox![1, 0].into(), + &nonce, + &IdpfOutputShare::Inner(Poplar1IdpfValue::new([Field64::zero(), Field64::zero()])), + &mut NoCache::new(), + &mut NoCache::new(), + ); + check_idpf_poplar_evaluation( + &public_share, + &keys, + &bitbox![1, 1].into(), + &nonce, + &IdpfOutputShare::Inner(Poplar1IdpfValue::new([Field64::zero(), Field64::zero()])), + &mut NoCache::new(), + &mut NoCache::new(), + ); + check_idpf_poplar_evaluation( + &public_share, + &keys, + &bitbox![0, 1, 1].into(), + &nonce, + &IdpfOutputShare::Inner(Poplar1IdpfValue::new([Field64::one(), Field64::one()])), + &mut NoCache::new(), + &mut NoCache::new(), + ); + check_idpf_poplar_evaluation( + &public_share, + &keys, + &bitbox![0, 1, 1, 0].into(), + &nonce, + &IdpfOutputShare::Inner(Poplar1IdpfValue::new([Field64::one(), Field64::one()])), + &mut NoCache::new(), + &mut NoCache::new(), + ); + check_idpf_poplar_evaluation( + &public_share, + &keys, + &bitbox![0, 1, 1, 0, 1].into(), + &nonce, + &IdpfOutputShare::Leaf(Poplar1IdpfValue::new([Field255::one(), Field255::one()])), + &mut NoCache::new(), + &mut NoCache::new(), + ); + check_idpf_poplar_evaluation( + &public_share, + &keys, + &bitbox![0, 1, 1, 0, 0].into(), + &nonce, + &IdpfOutputShare::Leaf(Poplar1IdpfValue::new([Field255::zero(), Field255::zero()])), + &mut NoCache::new(), + &mut NoCache::new(), + ); + check_idpf_poplar_evaluation( + &public_share, + &keys, + &bitbox![1, 0, 1, 0, 0].into(), + &nonce, + &IdpfOutputShare::Leaf(Poplar1IdpfValue::new([Field255::zero(), Field255::zero()])), + &mut NoCache::new(), + &mut NoCache::new(), + ); + } + + fn check_idpf_poplar_evaluation( + public_share: &IdpfPublicShare<Poplar1IdpfValue<Field64>, Poplar1IdpfValue<Field255>>, + keys: &[Seed<16>; 2], + prefix: &IdpfInput, + binder: &[u8], + expected_output: &IdpfOutputShare<Poplar1IdpfValue<Field64>, Poplar1IdpfValue<Field255>>, + cache_0: &mut dyn IdpfCache, + cache_1: &mut dyn IdpfCache, + ) { + let idpf = Idpf::new((), ()); + let share_0 = idpf + .eval(0, public_share, &keys[0], prefix, binder, cache_0) + .unwrap(); + let share_1 = idpf + .eval(1, public_share, &keys[1], prefix, binder, cache_1) + .unwrap(); + let output = share_0.merge(share_1).unwrap(); + assert_eq!(&output, expected_output); + } + + #[test] + fn test_idpf_poplar_medium() { + // This test on 40 byte inputs takes about a second in debug mode. (and ten milliseconds in + // release mode) + const INPUT_LEN: usize = 320; + let mut bits = bitbox![0; INPUT_LEN]; + for mut bit in bits.iter_mut() { + bit.set(random()); + } + let input = bits.clone().into(); + + let mut inner_values = Vec::with_capacity(INPUT_LEN - 1); + let mut prng = Prng::new().unwrap(); + for _ in 0..INPUT_LEN - 1 { + inner_values.push(Poplar1IdpfValue::new([ + Field64::one(), + prng.next().unwrap(), + ])); + } + let leaf_values = + Poplar1IdpfValue::new([Field255::one(), Prng::new().unwrap().next().unwrap()]); + + let nonce: [u8; 16] = random(); + let idpf = Idpf::new((), ()); + let (public_share, keys) = idpf + .gen(&input, inner_values.clone(), leaf_values, &nonce) + .unwrap(); + let mut cache_0 = RingBufferCache::new(3); + let mut cache_1 = RingBufferCache::new(3); + + for (level, values) in inner_values.iter().enumerate() { + let mut prefix = BitBox::from_bitslice(&bits[..=level]).into(); + check_idpf_poplar_evaluation( + &public_share, + &keys, + &prefix, + &nonce, + &IdpfOutputShare::Inner(*values), + &mut cache_0, + &mut cache_1, + ); + let flipped_bit = !prefix[level]; + prefix.index.set(level, flipped_bit); + check_idpf_poplar_evaluation( + &public_share, + &keys, + &prefix, + &nonce, + &IdpfOutputShare::Inner(Poplar1IdpfValue::new([Field64::zero(), Field64::zero()])), + &mut cache_0, + &mut cache_1, + ); + } + check_idpf_poplar_evaluation( + &public_share, + &keys, + &input, + &nonce, + &IdpfOutputShare::Leaf(leaf_values), + &mut cache_0, + &mut cache_1, + ); + let mut modified_bits = bits.clone(); + modified_bits.set(INPUT_LEN - 1, !bits[INPUT_LEN - 1]); + check_idpf_poplar_evaluation( + &public_share, + &keys, + &modified_bits.into(), + &nonce, + &IdpfOutputShare::Leaf(Poplar1IdpfValue::new([Field255::zero(), Field255::zero()])), + &mut cache_0, + &mut cache_1, + ); + } + + #[test] + fn idpf_poplar_cache_behavior() { + let bits = bitbox![0, 1, 1, 1, 0, 1, 0, 0]; + let input = bits.into(); + + let mut inner_values = Vec::with_capacity(7); + let mut prng = Prng::new().unwrap(); + for _ in 0..7 { + inner_values.push(Poplar1IdpfValue::new([ + Field64::one(), + prng.next().unwrap(), + ])); + } + let leaf_values = + Poplar1IdpfValue::new([Field255::one(), Prng::new().unwrap().next().unwrap()]); + + let nonce: [u8; 16] = random(); + let idpf = Idpf::new((), ()); + let (public_share, keys) = idpf + .gen(&input, inner_values.clone(), leaf_values, &nonce) + .unwrap(); + let mut cache_0 = SnoopingCache::new(HashMapCache::new()); + let mut cache_1 = HashMapCache::new(); + + check_idpf_poplar_evaluation( + &public_share, + &keys, + &bitbox![1, 1, 0, 0].into(), + &nonce, + &IdpfOutputShare::Inner(Poplar1IdpfValue::new([Field64::zero(), Field64::zero()])), + &mut cache_0, + &mut cache_1, + ); + assert_eq!( + cache_0 + .get_calls + .lock() + .unwrap() + .drain(..) + .collect::<Vec<_>>(), + vec![bitbox![1, 1, 0], bitbox![1, 1], bitbox![1]], + ); + assert_eq!( + cache_0 + .insert_calls + .lock() + .unwrap() + .drain(..) + .map(|(input, _, _)| input) + .collect::<Vec<_>>(), + vec![ + bitbox![1], + bitbox![1, 1], + bitbox![1, 1, 0], + bitbox![1, 1, 0, 0] + ], + ); + + check_idpf_poplar_evaluation( + &public_share, + &keys, + &bitbox![0].into(), + &nonce, + &IdpfOutputShare::Inner(inner_values[0]), + &mut cache_0, + &mut cache_1, + ); + assert_eq!( + cache_0 + .get_calls + .lock() + .unwrap() + .drain(..) + .collect::<Vec<BitBox>>(), + Vec::<BitBox>::new(), + ); + assert_eq!( + cache_0 + .insert_calls + .lock() + .unwrap() + .drain(..) + .map(|(input, _, _)| input) + .collect::<Vec<_>>(), + vec![bitbox![0]], + ); + + check_idpf_poplar_evaluation( + &public_share, + &keys, + &bitbox![0, 1].into(), + &nonce, + &IdpfOutputShare::Inner(inner_values[1]), + &mut cache_0, + &mut cache_1, + ); + assert_eq!( + cache_0 + .get_calls + .lock() + .unwrap() + .drain(..) + .collect::<Vec<_>>(), + vec![bitbox![0]], + ); + assert_eq!( + cache_0 + .insert_calls + .lock() + .unwrap() + .drain(..) + .map(|(input, _, _)| input) + .collect::<Vec<_>>(), + vec![bitbox![0, 1]], + ); + + check_idpf_poplar_evaluation( + &public_share, + &keys, + &input, + &nonce, + &IdpfOutputShare::Leaf(leaf_values), + &mut cache_0, + &mut cache_1, + ); + assert_eq!( + cache_0 + .get_calls + .lock() + .unwrap() + .drain(..) + .collect::<Vec<_>>(), + vec![ + bitbox![0, 1, 1, 1, 0, 1, 0], + bitbox![0, 1, 1, 1, 0, 1], + bitbox![0, 1, 1, 1, 0], + bitbox![0, 1, 1, 1], + bitbox![0, 1, 1], + bitbox![0, 1], + ], + ); + assert_eq!( + cache_0 + .insert_calls + .lock() + .unwrap() + .drain(..) + .map(|(input, _, _)| input) + .collect::<Vec<_>>(), + vec![ + bitbox![0, 1, 1], + bitbox![0, 1, 1, 1], + bitbox![0, 1, 1, 1, 0], + bitbox![0, 1, 1, 1, 0, 1], + bitbox![0, 1, 1, 1, 0, 1, 0], + ], + ); + + check_idpf_poplar_evaluation( + &public_share, + &keys, + &input, + &nonce, + &IdpfOutputShare::Leaf(leaf_values), + &mut cache_0, + &mut cache_1, + ); + assert_eq!( + cache_0 + .get_calls + .lock() + .unwrap() + .drain(..) + .collect::<Vec<_>>(), + vec![bitbox![0, 1, 1, 1, 0, 1, 0]], + ); + assert!(cache_0.insert_calls.lock().unwrap().is_empty()); + } + + #[test] + fn idpf_poplar_lossy_cache() { + let bits = bitbox![1, 0, 0, 1, 1, 0, 1, 0]; + let input = bits.into(); + + let mut inner_values = Vec::with_capacity(7); + let mut prng = Prng::new().unwrap(); + for _ in 0..7 { + inner_values.push(Poplar1IdpfValue::new([ + Field64::one(), + prng.next().unwrap(), + ])); + } + let leaf_values = + Poplar1IdpfValue::new([Field255::one(), Prng::new().unwrap().next().unwrap()]); + + let nonce: [u8; 16] = random(); + let idpf = Idpf::new((), ()); + let (public_share, keys) = idpf + .gen(&input, inner_values.clone(), leaf_values, &nonce) + .unwrap(); + let mut cache_0 = LossyCache::new(); + let mut cache_1 = LossyCache::new(); + + for (level, values) in inner_values.iter().enumerate() { + check_idpf_poplar_evaluation( + &public_share, + &keys, + &input[..=level].to_owned().into(), + &nonce, + &IdpfOutputShare::Inner(*values), + &mut cache_0, + &mut cache_1, + ); + } + check_idpf_poplar_evaluation( + &public_share, + &keys, + &input, + &nonce, + &IdpfOutputShare::Leaf(leaf_values), + &mut cache_0, + &mut cache_1, + ); + } + + #[test] + fn test_idpf_poplar_error_cases() { + let nonce: [u8; 16] = random(); + let idpf = Idpf::new((), ()); + // Zero bits does not make sense. + idpf.gen( + &bitbox![].into(), + Vec::<Poplar1IdpfValue<Field64>>::new(), + Poplar1IdpfValue::new([Field255::zero(); 2]), + &nonce, + ) + .unwrap_err(); + + let (public_share, keys) = idpf + .gen( + &bitbox![0;10].into(), + Vec::from([Poplar1IdpfValue::new([Field64::zero(); 2]); 9]), + Poplar1IdpfValue::new([Field255::zero(); 2]), + &nonce, + ) + .unwrap(); + + // Wrong number of values. + idpf.gen( + &bitbox![0; 10].into(), + Vec::from([Poplar1IdpfValue::new([Field64::zero(); 2]); 8]), + Poplar1IdpfValue::new([Field255::zero(); 2]), + &nonce, + ) + .unwrap_err(); + idpf.gen( + &bitbox![0; 10].into(), + Vec::from([Poplar1IdpfValue::new([Field64::zero(); 2]); 10]), + Poplar1IdpfValue::new([Field255::zero(); 2]), + &nonce, + ) + .unwrap_err(); + + // Evaluating with empty prefix. + assert!(idpf + .eval( + 0, + &public_share, + &keys[0], + &bitbox![].into(), + &nonce, + &mut NoCache::new(), + ) + .is_err()); + // Evaluating with too-long prefix. + assert!(idpf + .eval( + 0, + &public_share, + &keys[0], + &bitbox![0; 11].into(), + &nonce, + &mut NoCache::new(), + ) + .is_err()); + } + + #[test] + fn idpf_poplar_public_share_round_trip() { + let public_share = IdpfPublicShare { + inner_correction_words: Vec::from([ + IdpfCorrectionWord { + seed: [0xab; 16], + control_bits: [Choice::from(1), Choice::from(0)], + value: Poplar1IdpfValue::new([ + Field64::try_from(83261u64).unwrap(), + Field64::try_from(125159u64).unwrap(), + ]), + }, + IdpfCorrectionWord{ + seed: [0xcd;16], + control_bits: [Choice::from(0), Choice::from(1)], + value: Poplar1IdpfValue::new([ + Field64::try_from(17614120u64).unwrap(), + Field64::try_from(20674u64).unwrap(), + ]), + }, + ]), + leaf_correction_word: IdpfCorrectionWord { + seed: [0xff; 16], + control_bits: [Choice::from(1), Choice::from(1)], + value: Poplar1IdpfValue::new([ + Field255::one(), + Field255::get_decoded( + b"\xf0\xde\xbc\x9a\x78\x56\x34\x12\xf0\xde\xbc\x9a\x78\x56\x34\x12\xf0\xde\xbc\x9a\x78\x56\x34\x12\xf0\xde\xbc\x9a\x78\x56\x34\x12", // field element correction word, continued + ).unwrap(), + ]), + }, + }; + let message = hex::decode(concat!( + "39", // packed control bit correction words (0b00111001) + "abababababababababababababababab", // seed correction word, first level + "3d45010000000000", // field element correction word + "e7e8010000000000", // field element correction word, continued + "cdcdcdcdcdcdcdcdcdcdcdcdcdcdcdcd", // seed correction word, second level + "28c50c0100000000", // field element correction word + "c250000000000000", // field element correction word, continued + "ffffffffffffffffffffffffffffffff", // seed correction word, third level + "0100000000000000000000000000000000000000000000000000000000000000", // field element correction word, leaf field + "f0debc9a78563412f0debc9a78563412f0debc9a78563412f0debc9a78563412", // field element correction word, continued + )) + .unwrap(); + let encoded = public_share.get_encoded(); + let decoded = IdpfPublicShare::get_decoded_with_param(&3, &message).unwrap(); + assert_eq!(public_share, decoded); + assert_eq!(message, encoded); + assert_eq!(public_share.encoded_len().unwrap(), encoded.len()); + + // check serialization of packed control bits when they span multiple bytes: + let public_share = IdpfPublicShare { + inner_correction_words: Vec::from([ + IdpfCorrectionWord { + seed: [0; 16], + control_bits: [Choice::from(1), Choice::from(1)], + value: Poplar1IdpfValue::new([Field64::zero(), Field64::zero()]), + }, + IdpfCorrectionWord { + seed: [0; 16], + control_bits: [Choice::from(1), Choice::from(1)], + value: Poplar1IdpfValue::new([Field64::zero(), Field64::zero()]), + }, + IdpfCorrectionWord { + seed: [0; 16], + control_bits: [Choice::from(1), Choice::from(0)], + value: Poplar1IdpfValue::new([Field64::zero(), Field64::zero()]), + }, + IdpfCorrectionWord { + seed: [0; 16], + control_bits: [Choice::from(1), Choice::from(1)], + value: Poplar1IdpfValue::new([Field64::zero(), Field64::zero()]), + }, + IdpfCorrectionWord { + seed: [0; 16], + control_bits: [Choice::from(1), Choice::from(1)], + value: Poplar1IdpfValue::new([Field64::zero(), Field64::zero()]), + }, + IdpfCorrectionWord { + seed: [0; 16], + control_bits: [Choice::from(0), Choice::from(1)], + value: Poplar1IdpfValue::new([Field64::zero(), Field64::zero()]), + }, + IdpfCorrectionWord { + seed: [0; 16], + control_bits: [Choice::from(1), Choice::from(1)], + value: Poplar1IdpfValue::new([Field64::zero(), Field64::zero()]), + }, + IdpfCorrectionWord { + seed: [0; 16], + control_bits: [Choice::from(1), Choice::from(1)], + value: Poplar1IdpfValue::new([Field64::zero(), Field64::zero()]), + }, + ]), + leaf_correction_word: IdpfCorrectionWord { + seed: [0; 16], + control_bits: [Choice::from(0), Choice::from(1)], + value: Poplar1IdpfValue::new([Field255::zero(), Field255::zero()]), + }, + }; + let message = hex::decode(concat!( + "dffb02", // packed correction word control bits: 0b11011111, 0b11111011, 0b10 + "00000000000000000000000000000000", + "0000000000000000", + "0000000000000000", + "00000000000000000000000000000000", + "0000000000000000", + "0000000000000000", + "00000000000000000000000000000000", + "0000000000000000", + "0000000000000000", + "00000000000000000000000000000000", + "0000000000000000", + "0000000000000000", + "00000000000000000000000000000000", + "0000000000000000", + "0000000000000000", + "00000000000000000000000000000000", + "0000000000000000", + "0000000000000000", + "00000000000000000000000000000000", + "0000000000000000", + "0000000000000000", + "00000000000000000000000000000000", + "0000000000000000", + "0000000000000000", + "00000000000000000000000000000000", + "0000000000000000000000000000000000000000000000000000000000000000", + "0000000000000000000000000000000000000000000000000000000000000000", + )) + .unwrap(); + let encoded = public_share.get_encoded(); + let decoded = IdpfPublicShare::get_decoded_with_param(&9, &message).unwrap(); + assert_eq!(public_share, decoded); + assert_eq!(message, encoded); + } + + #[test] + fn idpf_poplar_public_share_control_bit_codec() { + let test_cases = [ + (&[false, true][..], &[0b10][..]), + ( + &[false, false, true, false, false, true][..], + &[0b10_0100u8][..], + ), + ( + &[ + true, true, false, true, false, false, false, false, true, true, + ][..], + &[0b0000_1011, 0b11][..], + ), + ( + &[ + true, true, false, true, false, true, true, true, false, true, false, true, + false, false, true, false, + ][..], + &[0b1110_1011, 0b0100_1010][..], + ), + ( + &[ + true, true, true, true, true, false, true, true, false, true, true, true, + false, true, false, true, false, false, true, false, true, true, + ][..], + &[0b1101_1111, 0b1010_1110, 0b11_0100][..], + ), + ]; + + for (control_bits, serialized_control_bits) in test_cases { + let public_share = IdpfPublicShare::< + Poplar1IdpfValue<Field64>, + Poplar1IdpfValue<Field255>, + > { + inner_correction_words: control_bits[..control_bits.len() - 2] + .chunks(2) + .map(|chunk| IdpfCorrectionWord { + seed: [0; 16], + control_bits: [Choice::from(chunk[0] as u8), Choice::from(chunk[1] as u8)], + value: Poplar1IdpfValue::new([Field64::zero(); 2]), + }) + .collect(), + leaf_correction_word: IdpfCorrectionWord { + seed: [0; 16], + control_bits: [ + Choice::from(control_bits[control_bits.len() - 2] as u8), + Choice::from(control_bits[control_bits.len() - 1] as u8), + ], + value: Poplar1IdpfValue::new([Field255::zero(); 2]), + }, + }; + + let mut serialized_public_share = serialized_control_bits.to_owned(); + let idpf_bits = control_bits.len() / 2; + let size_seeds = 16 * idpf_bits; + let size_field_vecs = + Field64::ENCODED_SIZE * 2 * (idpf_bits - 1) + Field255::ENCODED_SIZE * 2; + serialized_public_share.resize( + serialized_control_bits.len() + size_seeds + size_field_vecs, + 0, + ); + + assert_eq!(public_share.get_encoded(), serialized_public_share); + assert_eq!( + IdpfPublicShare::get_decoded_with_param(&idpf_bits, &serialized_public_share) + .unwrap(), + public_share + ); + } + } + + #[test] + fn idpf_poplar_public_share_unused_bits() { + let mut buf = vec![0u8; 4096]; + + buf[0] = 1 << 2; + let err = + IdpfPublicShare::<Field64, Field255>::decode_with_param(&1, &mut Cursor::new(&buf)) + .unwrap_err(); + assert_matches!(err, CodecError::UnexpectedValue); + + buf[0] = 1 << 4; + let err = + IdpfPublicShare::<Field64, Field255>::decode_with_param(&2, &mut Cursor::new(&buf)) + .unwrap_err(); + assert_matches!(err, CodecError::UnexpectedValue); + + buf[0] = 1 << 6; + let err = + IdpfPublicShare::<Field64, Field255>::decode_with_param(&3, &mut Cursor::new(&buf)) + .unwrap_err(); + assert_matches!(err, CodecError::UnexpectedValue); + + buf[0] = 0; + buf[1] = 1 << 2; + let err = + IdpfPublicShare::<Field64, Field255>::decode_with_param(&5, &mut Cursor::new(&buf)) + .unwrap_err(); + assert_matches!(err, CodecError::UnexpectedValue); + } + + /// Stores a test vector for the IDPF key generation algorithm. + struct IdpfTestVector { + /// The number of bits in IDPF inputs. + bits: usize, + /// The binder string used when generating and evaluating keys. + binder: Vec<u8>, + /// The IDPF input provided to the key generation algorithm. + alpha: IdpfInput, + /// The IDPF output values, at each inner level, provided to the key generation algorithm. + beta_inner: Vec<Poplar1IdpfValue<Field64>>, + /// The IDPF output values for the leaf level, provided to the key generation algorithm. + beta_leaf: Poplar1IdpfValue<Field255>, + /// The two keys returned by the key generation algorithm. + keys: [[u8; 16]; 2], + /// The public share returned by the key generation algorithm. + public_share: Vec<u8>, + } + + /// Load a test vector for Idpf key generation. + fn load_idpfpoplar_test_vector() -> IdpfTestVector { + let test_vec: serde_json::Value = + serde_json::from_str(include_str!("vdaf/test_vec/07/IdpfPoplar_0.json")).unwrap(); + let test_vec_obj = test_vec.as_object().unwrap(); + + let bits = test_vec_obj + .get("bits") + .unwrap() + .as_u64() + .unwrap() + .try_into() + .unwrap(); + + let alpha_str = test_vec_obj.get("alpha").unwrap().as_str().unwrap(); + let alpha_bignum = BigUint::from_str(alpha_str).unwrap(); + let zero_bignum = BigUint::from(0u8); + let one_bignum = BigUint::from(1u8); + let alpha_bits = (0..bits) + .map(|level| (&alpha_bignum >> (bits - level - 1)) & &one_bignum != zero_bignum) + .collect::<BitVec>(); + let alpha = alpha_bits.into(); + + let beta_inner_level_array = test_vec_obj.get("beta_inner").unwrap().as_array().unwrap(); + let beta_inner = beta_inner_level_array + .iter() + .map(|array| { + Poplar1IdpfValue::new([ + Field64::from(array[0].as_str().unwrap().parse::<u64>().unwrap()), + Field64::from(array[1].as_str().unwrap().parse::<u64>().unwrap()), + ]) + }) + .collect::<Vec<_>>(); + + let beta_leaf_array = test_vec_obj.get("beta_leaf").unwrap().as_array().unwrap(); + let beta_leaf = Poplar1IdpfValue::new([ + Field255::from( + beta_leaf_array[0] + .as_str() + .unwrap() + .parse::<BigUint>() + .unwrap(), + ), + Field255::from( + beta_leaf_array[1] + .as_str() + .unwrap() + .parse::<BigUint>() + .unwrap(), + ), + ]); + + let keys_array = test_vec_obj.get("keys").unwrap().as_array().unwrap(); + let keys = [ + hex::decode(keys_array[0].as_str().unwrap()) + .unwrap() + .try_into() + .unwrap(), + hex::decode(keys_array[1].as_str().unwrap()) + .unwrap() + .try_into() + .unwrap(), + ]; + + let public_share_hex = test_vec_obj.get("public_share").unwrap(); + let public_share = hex::decode(public_share_hex.as_str().unwrap()).unwrap(); + + let binder_hex = test_vec_obj.get("binder").unwrap(); + let binder = hex::decode(binder_hex.as_str().unwrap()).unwrap(); + + IdpfTestVector { + bits, + binder, + alpha, + beta_inner, + beta_leaf, + keys, + public_share, + } + } + + #[test] + fn idpf_poplar_generate_test_vector() { + let test_vector = load_idpfpoplar_test_vector(); + let idpf = Idpf::new((), ()); + let (public_share, keys) = idpf + .gen_with_random( + &test_vector.alpha, + test_vector.beta_inner, + test_vector.beta_leaf, + &test_vector.binder, + &test_vector.keys, + ) + .unwrap(); + + assert_eq!(keys[0].0, test_vector.keys[0]); + assert_eq!(keys[1].0, test_vector.keys[1]); + + let expected_public_share = + IdpfPublicShare::get_decoded_with_param(&test_vector.bits, &test_vector.public_share) + .unwrap(); + for (level, (correction_words, expected_correction_words)) in public_share + .inner_correction_words + .iter() + .zip(expected_public_share.inner_correction_words.iter()) + .enumerate() + { + assert_eq!( + correction_words, expected_correction_words, + "layer {level} did not match\n{correction_words:#x?}\n{expected_correction_words:#x?}" + ); + } + assert_eq!( + public_share.leaf_correction_word, + expected_public_share.leaf_correction_word + ); + + assert_eq!( + public_share, expected_public_share, + "public share did not match\n{public_share:#x?}\n{expected_public_share:#x?}" + ); + let encoded_public_share = public_share.get_encoded(); + assert_eq!(encoded_public_share, test_vector.public_share); + } + + #[test] + fn idpf_input_from_bytes_to_bytes() { + let test_cases: &[&[u8]] = &[b"hello", b"banana", &[1], &[127], &[1, 2, 3, 4], &[]]; + for test_case in test_cases { + assert_eq!(&IdpfInput::from_bytes(test_case).to_bytes(), test_case); + } + } + + #[test] + fn idpf_input_from_bools_to_bytes() { + let input = IdpfInput::from_bools(&[true; 7]); + assert_eq!(input.to_bytes(), &[254]); + let input = IdpfInput::from_bools(&[true; 9]); + assert_eq!(input.to_bytes(), &[255, 128]); + } + + /// Demonstrate use of an IDPF with values that need run-time parameters for random generation. + #[test] + fn idpf_with_value_parameters() { + use super::IdpfValue; + + /// A test-only type for use as an [`IdpfValue`]. + #[derive(Debug, Clone, Copy)] + struct MyUnit; + + impl IdpfValue for MyUnit { + type ValueParameter = (); + + fn generate<S>(_: &mut S, _: &Self::ValueParameter) -> Self + where + S: rand_core::RngCore, + { + MyUnit + } + + fn zero(_: &()) -> Self { + MyUnit + } + + fn conditional_select(_: &Self, _: &Self, _: Choice) -> Self { + MyUnit + } + } + + impl Encode for MyUnit { + fn encode(&self, _: &mut Vec<u8>) {} + } + + impl Decode for MyUnit { + fn decode(_: &mut Cursor<&[u8]>) -> Result<Self, CodecError> { + Ok(MyUnit) + } + } + + impl ConditionallySelectable for MyUnit { + fn conditional_select(_: &Self, _: &Self, _: Choice) -> Self { + MyUnit + } + } + + impl ConditionallyNegatable for MyUnit { + fn conditional_negate(&mut self, _: Choice) {} + } + + impl Add for MyUnit { + type Output = Self; + + fn add(self, _: Self) -> Self::Output { + MyUnit + } + } + + impl AddAssign for MyUnit { + fn add_assign(&mut self, _: Self) {} + } + + impl Sub for MyUnit { + type Output = Self; + + fn sub(self, _: Self) -> Self::Output { + MyUnit + } + } + + /// A test-only type for use as an [`IdpfValue`], representing a variable-length vector of + /// field elements. The length must be fixed before generating IDPF keys, but we assume it + /// is not known at compile time. + #[derive(Debug, Clone)] + struct MyVector(Vec<Field128>); + + impl IdpfValue for MyVector { + type ValueParameter = usize; + + fn generate<S>(seed_stream: &mut S, length: &Self::ValueParameter) -> Self + where + S: rand_core::RngCore, + { + let mut output = vec![<Field128 as FieldElement>::zero(); *length]; + for element in output.iter_mut() { + *element = <Field128 as IdpfValue>::generate(seed_stream, &()); + } + MyVector(output) + } + + fn zero(length: &usize) -> Self { + MyVector(vec![<Field128 as FieldElement>::zero(); *length]) + } + + fn conditional_select(a: &Self, b: &Self, choice: Choice) -> Self { + debug_assert_eq!(a.0.len(), b.0.len()); + let mut output = vec![<Field128 as FieldElement>::zero(); a.0.len()]; + for ((a_elem, b_elem), output_elem) in + a.0.iter().zip(b.0.iter()).zip(output.iter_mut()) + { + *output_elem = <Field128 as ConditionallySelectable>::conditional_select( + a_elem, b_elem, choice, + ); + } + MyVector(output) + } + } + + impl Encode for MyVector { + fn encode(&self, bytes: &mut Vec<u8>) { + encode_u32_items(bytes, &(), &self.0); + } + } + + impl Decode for MyVector { + fn decode(bytes: &mut Cursor<&[u8]>) -> Result<Self, CodecError> { + decode_u32_items(&(), bytes).map(MyVector) + } + } + + impl ConditionallyNegatable for MyVector { + fn conditional_negate(&mut self, choice: Choice) { + for element in self.0.iter_mut() { + element.conditional_negate(choice); + } + } + } + + impl Add for MyVector { + type Output = Self; + + fn add(self, rhs: Self) -> Self::Output { + debug_assert_eq!(self.0.len(), rhs.0.len()); + let mut output = vec![<Field128 as FieldElement>::zero(); self.0.len()]; + for ((left_elem, right_elem), output_elem) in + self.0.iter().zip(rhs.0.iter()).zip(output.iter_mut()) + { + *output_elem = left_elem + right_elem; + } + MyVector(output) + } + } + + impl AddAssign for MyVector { + fn add_assign(&mut self, rhs: Self) { + debug_assert_eq!(self.0.len(), rhs.0.len()); + for (self_elem, right_elem) in self.0.iter_mut().zip(rhs.0.iter()) { + *self_elem += *right_elem; + } + } + } + + impl Sub for MyVector { + type Output = Self; + + fn sub(self, rhs: Self) -> Self::Output { + debug_assert_eq!(self.0.len(), rhs.0.len()); + let mut output = vec![<Field128 as FieldElement>::zero(); self.0.len()]; + for ((left_elem, right_elem), output_elem) in + self.0.iter().zip(rhs.0.iter()).zip(output.iter_mut()) + { + *output_elem = left_elem - right_elem; + } + MyVector(output) + } + } + + // Use a unit type for inner nodes, thus emulating a DPF. Use a newtype around a `Vec` for + // the leaf nodes, to test out values that require runtime parameters. + let idpf = Idpf::new((), 3); + let binder = b"binder"; + let (public_share, [key_0, key_1]) = idpf + .gen( + &IdpfInput::from_bytes(b"ae"), + [MyUnit; 15], + MyVector(Vec::from([ + Field128::from(1), + Field128::from(2), + Field128::from(3), + ])), + binder, + ) + .unwrap(); + + let zero_share_0 = idpf + .eval( + 0, + &public_share, + &key_0, + &IdpfInput::from_bytes(b"ou"), + binder, + &mut NoCache::new(), + ) + .unwrap(); + let zero_share_1 = idpf + .eval( + 1, + &public_share, + &key_1, + &IdpfInput::from_bytes(b"ou"), + binder, + &mut NoCache::new(), + ) + .unwrap(); + let zero_output = zero_share_0.merge(zero_share_1).unwrap(); + assert_matches!(zero_output, IdpfOutputShare::Leaf(value) => { + assert_eq!(value.0.len(), 3); + assert_eq!(value.0[0], <Field128 as FieldElement>::zero()); + assert_eq!(value.0[1], <Field128 as FieldElement>::zero()); + assert_eq!(value.0[2], <Field128 as FieldElement>::zero()); + }); + + let programmed_share_0 = idpf + .eval( + 0, + &public_share, + &key_0, + &IdpfInput::from_bytes(b"ae"), + binder, + &mut NoCache::new(), + ) + .unwrap(); + let programmed_share_1 = idpf + .eval( + 1, + &public_share, + &key_1, + &IdpfInput::from_bytes(b"ae"), + binder, + &mut NoCache::new(), + ) + .unwrap(); + let programmed_output = programmed_share_0.merge(programmed_share_1).unwrap(); + assert_matches!(programmed_output, IdpfOutputShare::Leaf(value) => { + assert_eq!(value.0.len(), 3); + assert_eq!(value.0[0], Field128::from(1)); + assert_eq!(value.0[1], Field128::from(2)); + assert_eq!(value.0[2], Field128::from(3)); + }); + } +} diff --git a/third_party/rust/prio/src/lib.rs b/third_party/rust/prio/src/lib.rs new file mode 100644 index 0000000000..c9d4e22c49 --- /dev/null +++ b/third_party/rust/prio/src/lib.rs @@ -0,0 +1,34 @@ +// Copyright (c) 2020 Apple Inc. +// SPDX-License-Identifier: MPL-2.0 + +#![warn(missing_docs)] +#![cfg_attr(docsrs, feature(doc_cfg))] + +//! # libprio-rs +//! +//! Implementation of the [Prio](https://crypto.stanford.edu/prio/) private data aggregation +//! protocol. +//! +//! Prio3 is available in the `vdaf` module as part of an implementation of [Verifiable Distributed +//! Aggregation Functions][vdaf], along with an experimental implementation of Poplar1. +//! +//! [vdaf]: https://datatracker.ietf.org/doc/draft-irtf-cfrg-vdaf/05/ + +pub mod benchmarked; +pub mod codec; +#[cfg(feature = "experimental")] +pub mod dp; +mod fft; +pub mod field; +pub mod flp; +mod fp; +#[cfg(all(feature = "crypto-dependencies", feature = "experimental"))] +#[cfg_attr( + docsrs, + doc(cfg(all(feature = "crypto-dependencies", feature = "experimental"))) +)] +pub mod idpf; +mod polynomial; +mod prng; +pub mod topology; +pub mod vdaf; diff --git a/third_party/rust/prio/src/polynomial.rs b/third_party/rust/prio/src/polynomial.rs new file mode 100644 index 0000000000..89d8a91404 --- /dev/null +++ b/third_party/rust/prio/src/polynomial.rs @@ -0,0 +1,383 @@ +// Copyright (c) 2020 Apple Inc. +// SPDX-License-Identifier: MPL-2.0 + +//! Functions for polynomial interpolation and evaluation + +#[cfg(feature = "prio2")] +use crate::fft::{discrete_fourier_transform, discrete_fourier_transform_inv_finish}; +use crate::field::FftFriendlyFieldElement; + +use std::convert::TryFrom; + +/// Temporary memory used for FFT +#[derive(Clone, Debug)] +pub struct PolyFFTTempMemory<F> { + fft_tmp: Vec<F>, + fft_y_sub: Vec<F>, + fft_roots_sub: Vec<F>, +} + +impl<F: FftFriendlyFieldElement> PolyFFTTempMemory<F> { + fn new(length: usize) -> Self { + PolyFFTTempMemory { + fft_tmp: vec![F::zero(); length], + fft_y_sub: vec![F::zero(); length], + fft_roots_sub: vec![F::zero(); length], + } + } +} + +/// Auxiliary memory for polynomial interpolation and evaluation +#[derive(Clone, Debug)] +pub struct PolyAuxMemory<F> { + pub roots_2n: Vec<F>, + pub roots_2n_inverted: Vec<F>, + pub roots_n: Vec<F>, + pub roots_n_inverted: Vec<F>, + pub coeffs: Vec<F>, + pub fft_memory: PolyFFTTempMemory<F>, +} + +impl<F: FftFriendlyFieldElement> PolyAuxMemory<F> { + pub fn new(n: usize) -> Self { + PolyAuxMemory { + roots_2n: fft_get_roots(2 * n, false), + roots_2n_inverted: fft_get_roots(2 * n, true), + roots_n: fft_get_roots(n, false), + roots_n_inverted: fft_get_roots(n, true), + coeffs: vec![F::zero(); 2 * n], + fft_memory: PolyFFTTempMemory::new(2 * n), + } + } +} + +fn fft_recurse<F: FftFriendlyFieldElement>( + out: &mut [F], + n: usize, + roots: &[F], + ys: &[F], + tmp: &mut [F], + y_sub: &mut [F], + roots_sub: &mut [F], +) { + if n == 1 { + out[0] = ys[0]; + return; + } + + let half_n = n / 2; + + let (tmp_first, tmp_second) = tmp.split_at_mut(half_n); + let (y_sub_first, y_sub_second) = y_sub.split_at_mut(half_n); + let (roots_sub_first, roots_sub_second) = roots_sub.split_at_mut(half_n); + + // Recurse on the first half + for i in 0..half_n { + y_sub_first[i] = ys[i] + ys[i + half_n]; + roots_sub_first[i] = roots[2 * i]; + } + fft_recurse( + tmp_first, + half_n, + roots_sub_first, + y_sub_first, + tmp_second, + y_sub_second, + roots_sub_second, + ); + for i in 0..half_n { + out[2 * i] = tmp_first[i]; + } + + // Recurse on the second half + for i in 0..half_n { + y_sub_first[i] = ys[i] - ys[i + half_n]; + y_sub_first[i] *= roots[i]; + } + fft_recurse( + tmp_first, + half_n, + roots_sub_first, + y_sub_first, + tmp_second, + y_sub_second, + roots_sub_second, + ); + for i in 0..half_n { + out[2 * i + 1] = tmp[i]; + } +} + +/// Calculate `count` number of roots of unity of order `count` +fn fft_get_roots<F: FftFriendlyFieldElement>(count: usize, invert: bool) -> Vec<F> { + let mut roots = vec![F::zero(); count]; + let mut gen = F::generator(); + if invert { + gen = gen.inv(); + } + + roots[0] = F::one(); + let step_size = F::generator_order() / F::Integer::try_from(count).unwrap(); + // generator for subgroup of order count + gen = gen.pow(step_size); + + roots[1] = gen; + + for i in 2..count { + roots[i] = gen * roots[i - 1]; + } + + roots +} + +fn fft_interpolate_raw<F: FftFriendlyFieldElement>( + out: &mut [F], + ys: &[F], + n_points: usize, + roots: &[F], + invert: bool, + mem: &mut PolyFFTTempMemory<F>, +) { + fft_recurse( + out, + n_points, + roots, + ys, + &mut mem.fft_tmp, + &mut mem.fft_y_sub, + &mut mem.fft_roots_sub, + ); + if invert { + let n_inverse = F::from(F::Integer::try_from(n_points).unwrap()).inv(); + for out_val in out[0..n_points].iter_mut() { + *out_val *= n_inverse; + } + } +} + +pub fn poly_fft<F: FftFriendlyFieldElement>( + points_out: &mut [F], + points_in: &[F], + scaled_roots: &[F], + n_points: usize, + invert: bool, + mem: &mut PolyFFTTempMemory<F>, +) { + fft_interpolate_raw(points_out, points_in, n_points, scaled_roots, invert, mem) +} + +// Evaluate a polynomial using Horner's method. +pub fn poly_eval<F: FftFriendlyFieldElement>(poly: &[F], eval_at: F) -> F { + if poly.is_empty() { + return F::zero(); + } + + let mut result = poly[poly.len() - 1]; + for i in (0..poly.len() - 1).rev() { + result *= eval_at; + result += poly[i]; + } + + result +} + +// Returns the degree of polynomial `p`. +pub fn poly_deg<F: FftFriendlyFieldElement>(p: &[F]) -> usize { + let mut d = p.len(); + while d > 0 && p[d - 1] == F::zero() { + d -= 1; + } + d.saturating_sub(1) +} + +// Multiplies polynomials `p` and `q` and returns the result. +pub fn poly_mul<F: FftFriendlyFieldElement>(p: &[F], q: &[F]) -> Vec<F> { + let p_size = poly_deg(p) + 1; + let q_size = poly_deg(q) + 1; + let mut out = vec![F::zero(); p_size + q_size]; + for i in 0..p_size { + for j in 0..q_size { + out[i + j] += p[i] * q[j]; + } + } + out.truncate(poly_deg(&out) + 1); + out +} + +#[cfg(feature = "prio2")] +#[inline] +pub fn poly_interpret_eval<F: FftFriendlyFieldElement>( + points: &[F], + eval_at: F, + tmp_coeffs: &mut [F], +) -> F { + let size_inv = F::from(F::Integer::try_from(points.len()).unwrap()).inv(); + discrete_fourier_transform(tmp_coeffs, points, points.len()).unwrap(); + discrete_fourier_transform_inv_finish(tmp_coeffs, points.len(), size_inv); + poly_eval(&tmp_coeffs[..points.len()], eval_at) +} + +// Returns a polynomial that evaluates to `0` if the input is in range `[start, end)`. Otherwise, +// the output is not `0`. +pub(crate) fn poly_range_check<F: FftFriendlyFieldElement>(start: usize, end: usize) -> Vec<F> { + let mut p = vec![F::one()]; + let mut q = [F::zero(), F::one()]; + for i in start..end { + q[0] = -F::from(F::Integer::try_from(i).unwrap()); + p = poly_mul(&p, &q); + } + p +} + +#[cfg(test)] +mod tests { + use crate::{ + field::{ + FftFriendlyFieldElement, Field64, FieldElement, FieldElementWithInteger, FieldPrio2, + }, + polynomial::{ + fft_get_roots, poly_deg, poly_eval, poly_fft, poly_mul, poly_range_check, PolyAuxMemory, + }, + }; + use rand::prelude::*; + use std::convert::TryFrom; + + #[test] + fn test_roots() { + let count = 128; + let roots = fft_get_roots::<FieldPrio2>(count, false); + let roots_inv = fft_get_roots::<FieldPrio2>(count, true); + + for i in 0..count { + assert_eq!(roots[i] * roots_inv[i], 1); + assert_eq!(roots[i].pow(u32::try_from(count).unwrap()), 1); + assert_eq!(roots_inv[i].pow(u32::try_from(count).unwrap()), 1); + } + } + + #[test] + fn test_eval() { + let mut poly = [FieldPrio2::from(0); 4]; + poly[0] = 2.into(); + poly[1] = 1.into(); + poly[2] = 5.into(); + // 5*3^2 + 3 + 2 = 50 + assert_eq!(poly_eval(&poly[..3], 3.into()), 50); + poly[3] = 4.into(); + // 4*3^3 + 5*3^2 + 3 + 2 = 158 + assert_eq!(poly_eval(&poly[..4], 3.into()), 158); + } + + #[test] + fn test_poly_deg() { + let zero = FieldPrio2::zero(); + let one = FieldPrio2::root(0).unwrap(); + assert_eq!(poly_deg(&[zero]), 0); + assert_eq!(poly_deg(&[one]), 0); + assert_eq!(poly_deg(&[zero, one]), 1); + assert_eq!(poly_deg(&[zero, zero, one]), 2); + assert_eq!(poly_deg(&[zero, one, one]), 2); + assert_eq!(poly_deg(&[zero, one, one, one]), 3); + assert_eq!(poly_deg(&[zero, one, one, one, zero]), 3); + assert_eq!(poly_deg(&[zero, one, one, one, zero, zero]), 3); + } + + #[test] + fn test_poly_mul() { + let p = [ + Field64::from(u64::try_from(2).unwrap()), + Field64::from(u64::try_from(3).unwrap()), + ]; + + let q = [ + Field64::one(), + Field64::zero(), + Field64::from(u64::try_from(5).unwrap()), + ]; + + let want = [ + Field64::from(u64::try_from(2).unwrap()), + Field64::from(u64::try_from(3).unwrap()), + Field64::from(u64::try_from(10).unwrap()), + Field64::from(u64::try_from(15).unwrap()), + ]; + + let got = poly_mul(&p, &q); + assert_eq!(&got, &want); + } + + #[test] + fn test_poly_range_check() { + let start = 74; + let end = 112; + let p = poly_range_check(start, end); + + // Check each number in the range. + for i in start..end { + let x = Field64::from(i as u64); + let y = poly_eval(&p, x); + assert_eq!(y, Field64::zero(), "range check failed for {i}"); + } + + // Check the number below the range. + let x = Field64::from((start - 1) as u64); + let y = poly_eval(&p, x); + assert_ne!(y, Field64::zero()); + + // Check a number above the range. + let x = Field64::from(end as u64); + let y = poly_eval(&p, x); + assert_ne!(y, Field64::zero()); + } + + #[test] + fn test_fft() { + let count = 128; + let mut mem = PolyAuxMemory::new(count / 2); + + let mut poly = vec![FieldPrio2::from(0); count]; + let mut points2 = vec![FieldPrio2::from(0); count]; + + let points = (0..count) + .map(|_| FieldPrio2::from(random::<u32>())) + .collect::<Vec<FieldPrio2>>(); + + // From points to coeffs and back + poly_fft( + &mut poly, + &points, + &mem.roots_2n, + count, + false, + &mut mem.fft_memory, + ); + poly_fft( + &mut points2, + &poly, + &mem.roots_2n_inverted, + count, + true, + &mut mem.fft_memory, + ); + + assert_eq!(points, points2); + + // interpolation + poly_fft( + &mut poly, + &points, + &mem.roots_2n, + count, + false, + &mut mem.fft_memory, + ); + + for (poly_coeff, root) in poly[..count].iter().zip(mem.roots_2n[..count].iter()) { + let mut should_be = FieldPrio2::from(0); + for (j, point_j) in points[..count].iter().enumerate() { + should_be = root.pow(u32::try_from(j).unwrap()) * *point_j + should_be; + } + assert_eq!(should_be, *poly_coeff); + } + } +} diff --git a/third_party/rust/prio/src/prng.rs b/third_party/rust/prio/src/prng.rs new file mode 100644 index 0000000000..cb7d3a54c8 --- /dev/null +++ b/third_party/rust/prio/src/prng.rs @@ -0,0 +1,278 @@ +// Copyright (c) 2020 Apple Inc. +// SPDX-License-Identifier: MPL-2.0 + +//! Tool for generating pseudorandom field elements. +//! +//! NOTE: The public API for this module is a work in progress. + +use crate::field::{FieldElement, FieldElementExt}; +#[cfg(feature = "crypto-dependencies")] +use crate::vdaf::xof::SeedStreamAes128; +#[cfg(feature = "crypto-dependencies")] +use getrandom::getrandom; +use rand_core::RngCore; + +use std::marker::PhantomData; +use std::ops::ControlFlow; + +const BUFFER_SIZE_IN_ELEMENTS: usize = 32; + +/// Errors propagated by methods in this module. +#[derive(Debug, thiserror::Error)] +pub enum PrngError { + /// Failure when calling getrandom(). + #[error("getrandom: {0}")] + GetRandom(#[from] getrandom::Error), +} + +/// This type implements an iterator that generates a pseudorandom sequence of field elements. The +/// sequence is derived from a XOF's key stream. +#[derive(Debug)] +pub(crate) struct Prng<F, S> { + phantom: PhantomData<F>, + seed_stream: S, + buffer: Vec<u8>, + buffer_index: usize, +} + +#[cfg(feature = "crypto-dependencies")] +impl<F: FieldElement> Prng<F, SeedStreamAes128> { + /// Create a [`Prng`] from a seed for Prio 2. The first 16 bytes of the seed and the last 16 + /// bytes of the seed are used, respectively, for the key and initialization vector for AES128 + /// in CTR mode. + pub(crate) fn from_prio2_seed(seed: &[u8; 32]) -> Self { + let seed_stream = SeedStreamAes128::new(&seed[..16], &seed[16..]); + Self::from_seed_stream(seed_stream) + } + + /// Create a [`Prng`] from a randomly generated seed. + pub(crate) fn new() -> Result<Self, PrngError> { + let mut seed = [0; 32]; + getrandom(&mut seed)?; + Ok(Self::from_prio2_seed(&seed)) + } +} + +impl<F, S> Prng<F, S> +where + F: FieldElement, + S: RngCore, +{ + pub(crate) fn from_seed_stream(mut seed_stream: S) -> Self { + let mut buffer = vec![0; BUFFER_SIZE_IN_ELEMENTS * F::ENCODED_SIZE]; + seed_stream.fill_bytes(&mut buffer); + + Self { + phantom: PhantomData::<F>, + seed_stream, + buffer, + buffer_index: 0, + } + } + + pub(crate) fn get(&mut self) -> F { + loop { + // Seek to the next chunk of the buffer that encodes an element of F. + for i in (self.buffer_index..self.buffer.len()).step_by(F::ENCODED_SIZE) { + let j = i + F::ENCODED_SIZE; + + if j > self.buffer.len() { + break; + } + + self.buffer_index = j; + + match F::from_random_rejection(&self.buffer[i..j]) { + ControlFlow::Break(x) => return x, + ControlFlow::Continue(()) => continue, // reject this sample + } + } + + // Refresh buffer with the next chunk of XOF output, filling the front of the buffer + // with the leftovers. This ensures continuity of the seed stream after converting the + // `Prng` to a new field type via `into_new_field()`. + let left_over = self.buffer.len() - self.buffer_index; + self.buffer.copy_within(self.buffer_index.., 0); + self.seed_stream.fill_bytes(&mut self.buffer[left_over..]); + self.buffer_index = 0; + } + } + + /// Convert this object into a field element generator for a different field. + #[cfg(all(feature = "crypto-dependencies", feature = "experimental"))] + pub(crate) fn into_new_field<F1: FieldElement>(self) -> Prng<F1, S> { + Prng { + phantom: PhantomData, + seed_stream: self.seed_stream, + buffer: self.buffer, + buffer_index: self.buffer_index, + } + } +} + +impl<F, S> Iterator for Prng<F, S> +where + F: FieldElement, + S: RngCore, +{ + type Item = F; + + fn next(&mut self) -> Option<F> { + Some(self.get()) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::{ + codec::Decode, + field::{Field64, FieldPrio2}, + vdaf::xof::{Seed, SeedStreamSha3, Xof, XofShake128}, + }; + #[cfg(feature = "prio2")] + use base64::{engine::Engine, prelude::BASE64_STANDARD}; + #[cfg(feature = "prio2")] + use sha2::{Digest, Sha256}; + use std::convert::TryInto; + + #[test] + fn secret_sharing_interop() { + let seed = [ + 0xcd, 0x85, 0x5b, 0xd4, 0x86, 0x48, 0xa4, 0xce, 0x52, 0x5c, 0x36, 0xee, 0x5a, 0x71, + 0xf3, 0x0f, 0x66, 0x80, 0xd3, 0x67, 0x53, 0x9a, 0x39, 0x6f, 0x12, 0x2f, 0xad, 0x94, + 0x4d, 0x34, 0xcb, 0x58, + ]; + + let reference = [ + 0xd0056ec5, 0xe23f9c52, 0x47e4ddb4, 0xbe5dacf6, 0x4b130aba, 0x530c7a90, 0xe8fc4ee5, + 0xb0569cb7, 0x7774cd3c, 0x7f24e6a5, 0xcc82355d, 0xc41f4f13, 0x67fe193c, 0xc94d63a4, + 0x5d7b474c, 0xcc5c9f5f, 0xe368e1d5, 0x020fa0cf, 0x9e96aa2a, 0xe924137d, 0xfa026ab9, + 0x8ebca0cc, 0x26fc58a5, 0x10a7b173, 0xb9c97291, 0x53ef0e28, 0x069cfb8e, 0xe9383cae, + 0xacb8b748, 0x6f5b9d49, 0x887d061b, 0x86db0c58, + ]; + + let share2 = extract_share_from_seed::<FieldPrio2>(reference.len(), &seed); + + assert_eq!(share2, reference); + } + + /// takes a seed and hash as base64 encoded strings + #[cfg(feature = "prio2")] + fn random_data_interop(seed_base64: &str, hash_base64: &str, len: usize) { + let seed = BASE64_STANDARD.decode(seed_base64).unwrap(); + let random_data = extract_share_from_seed::<FieldPrio2>(len, &seed); + + let random_bytes = FieldPrio2::slice_into_byte_vec(&random_data); + + let mut hasher = Sha256::new(); + hasher.update(&random_bytes); + let digest = hasher.finalize(); + assert_eq!(BASE64_STANDARD.encode(digest), hash_base64); + } + + #[test] + #[cfg(feature = "prio2")] + fn test_hash_interop() { + random_data_interop( + "AAECAwQFBgcICQoLDA0ODxAREhMUFRYXGBkaGxwdHh8=", + "RtzeQuuiWdD6bW2ZTobRELDmClz1wLy3HUiKsYsITOI=", + 100_000, + ); + + // zero seed + random_data_interop( + "AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=", + "3wHQbSwAn9GPfoNkKe1qSzWdKnu/R+hPPyRwwz6Di+w=", + 100_000, + ); + // 0, 1, 2 ... seed + random_data_interop( + "AAECAwQFBgcICQoLDA0ODxAREhMUFRYXGBkaGxwdHh8=", + "RtzeQuuiWdD6bW2ZTobRELDmClz1wLy3HUiKsYsITOI=", + 100_000, + ); + // one arbirtary fixed seed + random_data_interop( + "rkLrnVcU8ULaiuXTvR3OKrfpMX0kQidqVzta1pleKKg=", + "b1fMXYrGUNR3wOZ/7vmUMmY51QHoPDBzwok0fz6xC0I=", + 100_000, + ); + // all bits set seed + random_data_interop( + "//////////////////////////////////////////8=", + "iBiDaqLrv7/rX/+vs6akPiprGgYfULdh/XhoD61HQXA=", + 100_000, + ); + } + + fn extract_share_from_seed<F: FieldElement>(length: usize, seed: &[u8]) -> Vec<F> { + assert_eq!(seed.len(), 32); + Prng::from_prio2_seed(seed.try_into().unwrap()) + .take(length) + .collect() + } + + #[test] + fn rejection_sampling_test_vector() { + // These constants were found in a brute-force search, and they test that the XOF performs + // rejection sampling correctly when the raw output exceeds the prime modulus. + let seed = Seed::get_decoded(&[ + 0x29, 0xb2, 0x98, 0x64, 0xb4, 0xaa, 0x4e, 0x07, 0x2a, 0x44, 0x49, 0x24, 0xf6, 0x74, + 0x0a, 0x3d, + ]) + .unwrap(); + let expected = Field64::from(2035552711764301796); + + let seed_stream = XofShake128::seed_stream(&seed, b"", b""); + let mut prng = Prng::<Field64, _>::from_seed_stream(seed_stream); + let actual = prng.nth(33236).unwrap(); + assert_eq!(actual, expected); + + #[cfg(all(feature = "crypto-dependencies", feature = "experimental"))] + { + let mut seed_stream = XofShake128::seed_stream(&seed, b"", b""); + let mut actual = <Field64 as FieldElement>::zero(); + for _ in 0..=33236 { + actual = <Field64 as crate::idpf::IdpfValue>::generate(&mut seed_stream, &()); + } + assert_eq!(actual, expected); + } + } + + // Test that the `Prng`'s internal buffer properly copies the end of the buffer to the front + // once it reaches the end. + #[test] + fn left_over_buffer_back_fill() { + let seed = Seed::generate().unwrap(); + + let mut prng: Prng<Field64, SeedStreamSha3> = + Prng::from_seed_stream(XofShake128::seed_stream(&seed, b"", b"")); + + // Construct a `Prng` with a longer-than-usual buffer. + let mut prng_weird_buffer_size: Prng<Field64, SeedStreamSha3> = + Prng::from_seed_stream(XofShake128::seed_stream(&seed, b"", b"")); + let mut extra = [0; 7]; + prng_weird_buffer_size.seed_stream.fill_bytes(&mut extra); + prng_weird_buffer_size.buffer.extend_from_slice(&extra); + + // Check that the next several outputs match. We need to check enough outputs to ensure + // that we have to refill the buffer. + for _ in 0..BUFFER_SIZE_IN_ELEMENTS * 2 { + assert_eq!(prng.next().unwrap(), prng_weird_buffer_size.next().unwrap()); + } + } + + #[cfg(feature = "experimental")] + #[test] + fn into_new_field() { + let seed = Seed::generate().unwrap(); + let want: Prng<Field64, SeedStreamSha3> = + Prng::from_seed_stream(XofShake128::seed_stream(&seed, b"", b"")); + let want_buffer = want.buffer.clone(); + + let got: Prng<FieldPrio2, _> = want.into_new_field(); + assert_eq!(got.buffer_index, 0); + assert_eq!(got.buffer, want_buffer); + } +} diff --git a/third_party/rust/prio/src/topology/mod.rs b/third_party/rust/prio/src/topology/mod.rs new file mode 100644 index 0000000000..fdce6d722a --- /dev/null +++ b/third_party/rust/prio/src/topology/mod.rs @@ -0,0 +1,7 @@ +// SPDX-License-Identifier: MPL-2.0 + +//! Implementations of some aggregator communication topologies specified in [VDAF]. +//! +//! [VDAF]: https://datatracker.ietf.org/doc/html/draft-irtf-cfrg-vdaf-06#section-5.7 + +pub mod ping_pong; diff --git a/third_party/rust/prio/src/topology/ping_pong.rs b/third_party/rust/prio/src/topology/ping_pong.rs new file mode 100644 index 0000000000..c55d4f638d --- /dev/null +++ b/third_party/rust/prio/src/topology/ping_pong.rs @@ -0,0 +1,968 @@ +// SPDX-License-Identifier: MPL-2.0 + +//! Implements the Ping-Pong Topology described in [VDAF]. This topology assumes there are exactly +//! two aggregators, designated "Leader" and "Helper". This topology is required for implementing +//! the [Distributed Aggregation Protocol][DAP]. +//! +//! [VDAF]: https://datatracker.ietf.org/doc/html/draft-irtf-cfrg-vdaf-07#section-5.8 +//! [DAP]: https://datatracker.ietf.org/doc/html/draft-ietf-ppm-dap + +use crate::{ + codec::{decode_u32_items, encode_u32_items, CodecError, Decode, Encode, ParameterizedDecode}, + vdaf::{Aggregator, PrepareTransition, VdafError}, +}; +use std::fmt::Debug; + +/// Errors emitted by this module. +#[derive(Debug, thiserror::Error)] +pub enum PingPongError { + /// Error running prepare_init + #[error("vdaf.prepare_init: {0}")] + VdafPrepareInit(VdafError), + + /// Error running prepare_shares_to_prepare_message + #[error("vdaf.prepare_shares_to_prepare_message {0}")] + VdafPrepareSharesToPrepareMessage(VdafError), + + /// Error running prepare_next + #[error("vdaf.prepare_next {0}")] + VdafPrepareNext(VdafError), + + /// Error decoding a prepare share + #[error("decode prep share {0}")] + CodecPrepShare(CodecError), + + /// Error decoding a prepare message + #[error("decode prep message {0}")] + CodecPrepMessage(CodecError), + + /// Host is in an unexpected state + #[error("host state mismatch: in {found} expected {expected}")] + HostStateMismatch { + /// The state the host is in. + found: &'static str, + /// The state the host expected to be in. + expected: &'static str, + }, + + /// Message from peer indicates it is in an unexpected state + #[error("peer message mismatch: message is {found} expected {expected}")] + PeerMessageMismatch { + /// The state in the message from the peer. + found: &'static str, + /// The message expected from the peer. + expected: &'static str, + }, + + /// Internal error + #[error("internal error: {0}")] + InternalError(&'static str), +} + +/// Corresponds to `struct Message` in [VDAF's Ping-Pong Topology][VDAF]. All of the fields of the +/// variants are opaque byte buffers. This is because the ping-pong routines take responsibility for +/// decoding preparation shares and messages, which usually requires having the preparation state. +/// +/// [VDAF]: https://datatracker.ietf.org/doc/html/draft-irtf-cfrg-vdaf-07#section-5.8 +#[derive(Clone, PartialEq, Eq)] +pub enum PingPongMessage { + /// Corresponds to MessageType.initialize. + Initialize { + /// The leader's initial preparation share. + prep_share: Vec<u8>, + }, + /// Corresponds to MessageType.continue. + Continue { + /// The current round's preparation message. + prep_msg: Vec<u8>, + /// The next round's preparation share. + prep_share: Vec<u8>, + }, + /// Corresponds to MessageType.finish. + Finish { + /// The current round's preparation message. + prep_msg: Vec<u8>, + }, +} + +impl PingPongMessage { + fn variant(&self) -> &'static str { + match self { + Self::Initialize { .. } => "Initialize", + Self::Continue { .. } => "Continue", + Self::Finish { .. } => "Finish", + } + } +} + +impl Debug for PingPongMessage { + // We want `PingPongMessage` to implement `Debug`, but we don't want that impl to print out + // prepare shares or messages, because (1) their contents are sensitive and (2) their contents + // are long and not intelligible to humans. For both reasons they generally shouldn't get + // logged. Normally, we'd use the `derivative` crate to customize a derived `Debug`, but that + // crate has not been audited (in the `cargo vet` sense) so we can't use it here unless we audit + // 8,000+ lines of proc macros. + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_tuple(self.variant()).finish() + } +} + +impl Encode for PingPongMessage { + fn encode(&self, bytes: &mut Vec<u8>) { + // The encoding includes an implicit discriminator byte, called MessageType in the VDAF + // spec. + match self { + Self::Initialize { prep_share } => { + 0u8.encode(bytes); + encode_u32_items(bytes, &(), prep_share); + } + Self::Continue { + prep_msg, + prep_share, + } => { + 1u8.encode(bytes); + encode_u32_items(bytes, &(), prep_msg); + encode_u32_items(bytes, &(), prep_share); + } + Self::Finish { prep_msg } => { + 2u8.encode(bytes); + encode_u32_items(bytes, &(), prep_msg); + } + } + } + + fn encoded_len(&self) -> Option<usize> { + match self { + Self::Initialize { prep_share } => Some(1 + 4 + prep_share.len()), + Self::Continue { + prep_msg, + prep_share, + } => Some(1 + 4 + prep_msg.len() + 4 + prep_share.len()), + Self::Finish { prep_msg } => Some(1 + 4 + prep_msg.len()), + } + } +} + +impl Decode for PingPongMessage { + fn decode(bytes: &mut std::io::Cursor<&[u8]>) -> Result<Self, CodecError> { + let message_type = u8::decode(bytes)?; + Ok(match message_type { + 0 => { + let prep_share = decode_u32_items(&(), bytes)?; + Self::Initialize { prep_share } + } + 1 => { + let prep_msg = decode_u32_items(&(), bytes)?; + let prep_share = decode_u32_items(&(), bytes)?; + Self::Continue { + prep_msg, + prep_share, + } + } + 2 => { + let prep_msg = decode_u32_items(&(), bytes)?; + Self::Finish { prep_msg } + } + _ => return Err(CodecError::UnexpectedValue), + }) + } +} + +/// A transition in the pong-pong topology. This represents the `ping_pong_transition` function +/// defined in [VDAF]. +/// +/// # Discussion +/// +/// The obvious implementation of `ping_pong_transition` would be a method on trait +/// [`PingPongTopology`] that returns `(State, Message)`, and then `ContinuedValue::WithMessage` +/// would contain those values. But then DAP implementations would have to store relatively large +/// VDAF prepare shares between rounds of input preparation. +/// +/// Instead, this structure stores just the previous round's prepare state and the current round's +/// preprocessed prepare message. Their encoding is much smaller than the `(State, Message)` tuple, +/// which can always be recomputed with [`Self::evaluate`]. +/// +/// [VDAF]: https://datatracker.ietf.org/doc/html/draft-irtf-cfrg-vdaf-07#section-5.8 +#[derive(Clone, Debug, Eq)] +pub struct PingPongTransition< + const VERIFY_KEY_SIZE: usize, + const NONCE_SIZE: usize, + A: Aggregator<VERIFY_KEY_SIZE, NONCE_SIZE>, +> { + previous_prepare_state: A::PrepareState, + current_prepare_message: A::PrepareMessage, +} + +impl< + const VERIFY_KEY_SIZE: usize, + const NONCE_SIZE: usize, + A: Aggregator<VERIFY_KEY_SIZE, NONCE_SIZE>, + > PingPongTransition<VERIFY_KEY_SIZE, NONCE_SIZE, A> +{ + /// Evaluate this transition to obtain a new [`PingPongState`] and a [`PingPongMessage`] which + /// should be transmitted to the peer. + #[allow(clippy::type_complexity)] + pub fn evaluate( + &self, + vdaf: &A, + ) -> Result< + ( + PingPongState<VERIFY_KEY_SIZE, NONCE_SIZE, A>, + PingPongMessage, + ), + PingPongError, + > { + let prep_msg = self.current_prepare_message.get_encoded(); + + vdaf.prepare_next( + self.previous_prepare_state.clone(), + self.current_prepare_message.clone(), + ) + .map(|transition| match transition { + PrepareTransition::Continue(prep_state, prep_share) => ( + PingPongState::Continued(prep_state), + PingPongMessage::Continue { + prep_msg, + prep_share: prep_share.get_encoded(), + }, + ), + PrepareTransition::Finish(output_share) => ( + PingPongState::Finished(output_share), + PingPongMessage::Finish { prep_msg }, + ), + }) + .map_err(PingPongError::VdafPrepareNext) + } +} + +impl< + const VERIFY_KEY_SIZE: usize, + const NONCE_SIZE: usize, + A: Aggregator<VERIFY_KEY_SIZE, NONCE_SIZE>, + > PartialEq for PingPongTransition<VERIFY_KEY_SIZE, NONCE_SIZE, A> +{ + fn eq(&self, other: &Self) -> bool { + self.previous_prepare_state == other.previous_prepare_state + && self.current_prepare_message == other.current_prepare_message + } +} + +impl<const VERIFY_KEY_SIZE: usize, const NONCE_SIZE: usize, A> Encode + for PingPongTransition<VERIFY_KEY_SIZE, NONCE_SIZE, A> +where + A: Aggregator<VERIFY_KEY_SIZE, NONCE_SIZE>, + A::PrepareState: Encode, +{ + fn encode(&self, bytes: &mut Vec<u8>) { + self.previous_prepare_state.encode(bytes); + self.current_prepare_message.encode(bytes); + } + + fn encoded_len(&self) -> Option<usize> { + Some( + self.previous_prepare_state.encoded_len()? + + self.current_prepare_message.encoded_len()?, + ) + } +} + +impl<const VERIFY_KEY_SIZE: usize, const NONCE_SIZE: usize, A, PrepareStateDecode> + ParameterizedDecode<PrepareStateDecode> for PingPongTransition<VERIFY_KEY_SIZE, NONCE_SIZE, A> +where + A: Aggregator<VERIFY_KEY_SIZE, NONCE_SIZE>, + A::PrepareState: ParameterizedDecode<PrepareStateDecode> + PartialEq, + A::PrepareMessage: PartialEq, +{ + fn decode_with_param( + decoding_param: &PrepareStateDecode, + bytes: &mut std::io::Cursor<&[u8]>, + ) -> Result<Self, CodecError> { + let previous_prepare_state = A::PrepareState::decode_with_param(decoding_param, bytes)?; + let current_prepare_message = + A::PrepareMessage::decode_with_param(&previous_prepare_state, bytes)?; + + Ok(Self { + previous_prepare_state, + current_prepare_message, + }) + } +} + +/// Corresponds to the `State` enumeration implicitly defined in [VDAF's Ping-Pong Topology][VDAF]. +/// VDAF describes `Start` and `Rejected` states, but the `Start` state is never instantiated in +/// code, and the `Rejected` state is represented as `std::result::Result::Err`, so this enum does +/// not include those variants. +/// +/// [VDAF]: https://datatracker.ietf.org/doc/html/draft-irtf-cfrg-vdaf-07#section-5.8 +#[derive(Clone, Debug, PartialEq, Eq)] +pub enum PingPongState< + const VERIFY_KEY_SIZE: usize, + const NONCE_SIZE: usize, + A: Aggregator<VERIFY_KEY_SIZE, NONCE_SIZE>, +> { + /// Preparation of the report will continue with the enclosed state. + Continued(A::PrepareState), + /// Preparation of the report is finished and has yielded the enclosed output share. + Finished(A::OutputShare), +} + +/// Values returned by [`PingPongTopology::leader_continued`] or +/// [`PingPongTopology::helper_continued`]. +#[derive(Clone, Debug)] +pub enum PingPongContinuedValue< + const VERIFY_KEY_SIZE: usize, + const NONCE_SIZE: usize, + A: Aggregator<VERIFY_KEY_SIZE, NONCE_SIZE>, +> { + /// The operation resulted in a new state and a message to transmit to the peer. + WithMessage { + /// The transition that will be executed. Call `PingPongTransition::evaluate` to obtain the + /// next + /// [`PingPongState`] and a [`PingPongMessage`] to transmit to the peer. + transition: PingPongTransition<VERIFY_KEY_SIZE, NONCE_SIZE, A>, + }, + /// The operation caused the host to finish preparation of the input share, yielding an output + /// share and no message for the peer. + FinishedNoMessage { + /// The output share which may now be accumulated. + output_share: A::OutputShare, + }, +} + +/// Extension trait on [`crate::vdaf::Aggregator`] which adds the [VDAF Ping-Pong Topology][VDAF]. +/// +/// [VDAF]: https://datatracker.ietf.org/doc/html/draft-irtf-cfrg-vdaf-07#section-5.8 +pub trait PingPongTopology<const VERIFY_KEY_SIZE: usize, const NONCE_SIZE: usize>: + Aggregator<VERIFY_KEY_SIZE, NONCE_SIZE> +{ + /// Specialization of [`PingPongState`] for this VDAF. + type State; + /// Specialization of [`PingPongContinuedValue`] for this VDAF. + type ContinuedValue; + /// Specializaton of [`PingPongTransition`] for this VDAF. + type Transition; + + /// Initialize leader state using the leader's input share. Corresponds to + /// `ping_pong_leader_init` in [VDAF]. + /// + /// If successful, the returned [`PingPongMessage`] (which will always be + /// `PingPongMessage::Initialize`) should be transmitted to the helper. The returned + /// [`PingPongState`] (which will always be `PingPongState::Continued`) should be used by the + /// leader along with the next [`PingPongMessage`] received from the helper as input to + /// [`Self::leader_continued`] to advance to the next round. + /// + /// [VDAF]: https://datatracker.ietf.org/doc/html/draft-irtf-cfrg-vdaf-07#section-5.8 + fn leader_initialized( + &self, + verify_key: &[u8; VERIFY_KEY_SIZE], + agg_param: &Self::AggregationParam, + nonce: &[u8; NONCE_SIZE], + public_share: &Self::PublicShare, + input_share: &Self::InputShare, + ) -> Result<(Self::State, PingPongMessage), PingPongError>; + + /// Initialize helper state using the helper's input share and the leader's first prepare share. + /// Corresponds to `ping_pong_helper_init` in the forthcoming `draft-irtf-cfrg-vdaf-07`. + /// + /// If successful, the returned [`PingPongTransition`] should be evaluated, yielding a + /// [`PingPongMessage`], which should be transmitted to the leader, and a [`PingPongState`]. + /// + /// If the state is `PingPongState::Continued`, then it should be used by the helper along with + /// the next `PingPongMessage` received from the leader as input to [`Self::helper_continued`] + /// to advance to the next round. The helper may store the `PingPongTransition` between rounds + /// of preparation instead of the `PingPongState` and `PingPongMessage`. + /// + /// If the state is `PingPongState::Finished`, then preparation is finished and the output share + /// may be accumulated. + /// + /// # Errors + /// + /// `inbound` must be `PingPongMessage::Initialize` or the function will fail. + fn helper_initialized( + &self, + verify_key: &[u8; VERIFY_KEY_SIZE], + agg_param: &Self::AggregationParam, + nonce: &[u8; NONCE_SIZE], + public_share: &Self::PublicShare, + input_share: &Self::InputShare, + inbound: &PingPongMessage, + ) -> Result<PingPongTransition<VERIFY_KEY_SIZE, NONCE_SIZE, Self>, PingPongError>; + + /// Continue preparation based on the leader's current state and an incoming [`PingPongMessage`] + /// from the helper. Corresponds to `ping_pong_leader_continued` in [VDAF]. + /// + /// If successful, the returned [`PingPongContinuedValue`] will either be: + /// + /// - `PingPongContinuedValue::WithMessage { transition }`: `transition` should be evaluated, + /// yielding a [`PingPongMessage`], which should be transmitted to the helper, and a + /// [`PingPongState`]. + /// + /// If the state is `PingPongState::Continued`, then it should be used by the leader along + /// with the next `PingPongMessage` received from the helper as input to + /// [`Self::leader_continued`] to advance to the next round. The leader may store the + /// `PingPongTransition` between rounds of preparation instead of of the `PingPongState` and + /// `PingPongMessage`. + /// + /// If the state is `PingPongState::Finished`, then preparation is finished and the output + /// share may be accumulated. + /// + /// - `PingPongContinuedValue::FinishedNoMessage`: preparation is finished and the output share + /// may be accumulated. No message needs to be sent to the helper. + /// + /// # Errors + /// + /// `leader_state` must be `PingPongState::Continued` or the function will fail. + /// + /// `inbound` must not be `PingPongMessage::Initialize` or the function will fail. + /// + /// # Notes + /// + /// The specification of this function in [VDAF] takes the aggregation parameter. This version + /// does not, because [`crate::vdaf::Aggregator::prepare_preprocess`] does not take the + /// aggregation parameter. This may change in the future if/when [#670][issue] is addressed. + /// + /// + /// [VDAF]: https://datatracker.ietf.org/doc/html/draft-irtf-cfrg-vdaf-07#section-5.8 + /// [issue]: https://github.com/divviup/libprio-rs/issues/670 + fn leader_continued( + &self, + leader_state: Self::State, + agg_param: &Self::AggregationParam, + inbound: &PingPongMessage, + ) -> Result<Self::ContinuedValue, PingPongError>; + + /// PingPongContinue preparation based on the helper's current state and an incoming + /// [`PingPongMessage`] from the leader. Corresponds to `ping_pong_helper_contnued` in [VDAF]. + /// + /// If successful, the returned [`PingPongContinuedValue`] will either be: + /// + /// - `PingPongContinuedValue::WithMessage { transition }`: `transition` should be evaluated, + /// yielding a [`PingPongMessage`], which should be transmitted to the leader, and a + /// [`PingPongState`]. + /// + /// If the state is `PingPongState::Continued`, then it should be used by the helper along + /// with the next `PingPongMessage` received from the leader as input to + /// [`Self::helper_continued`] to advance to the next round. The helper may store the + /// `PingPongTransition` between rounds of preparation instead of the `PingPongState` and + /// `PingPongMessage`. + /// + /// If the state is `PingPongState::Finished`, then preparation is finished and the output + /// share may be accumulated. + /// + /// - `PingPongContinuedValue::FinishedNoMessage`: preparation is finished and the output share + /// may be accumulated. No message needs to be sent to the leader. + /// + /// # Errors + /// + /// `helper_state` must be `PingPongState::Continued` or the function will fail. + /// + /// `inbound` must not be `PingPongMessage::Initialize` or the function will fail. + /// + /// # Notes + /// + /// The specification of this function in [VDAF] takes the aggregation parameter. This version + /// does not, because [`crate::vdaf::Aggregator::prepare_preprocess`] does not take the + /// aggregation parameter. This may change in the future if/when [#670][issue] is addressed. + /// + /// + /// [VDAF]: https://datatracker.ietf.org/doc/html/draft-irtf-cfrg-vdaf-07#section-5.8 + /// [issue]: https://github.com/divviup/libprio-rs/issues/670 + fn helper_continued( + &self, + helper_state: Self::State, + agg_param: &Self::AggregationParam, + inbound: &PingPongMessage, + ) -> Result<Self::ContinuedValue, PingPongError>; +} + +/// Private interfaces for implementing ping-pong +trait PingPongTopologyPrivate<const VERIFY_KEY_SIZE: usize, const NONCE_SIZE: usize>: + PingPongTopology<VERIFY_KEY_SIZE, NONCE_SIZE> +{ + fn continued( + &self, + is_leader: bool, + host_state: Self::State, + agg_param: &Self::AggregationParam, + inbound: &PingPongMessage, + ) -> Result<Self::ContinuedValue, PingPongError>; +} + +impl<const VERIFY_KEY_SIZE: usize, const NONCE_SIZE: usize, A> + PingPongTopology<VERIFY_KEY_SIZE, NONCE_SIZE> for A +where + A: Aggregator<VERIFY_KEY_SIZE, NONCE_SIZE>, +{ + type State = PingPongState<VERIFY_KEY_SIZE, NONCE_SIZE, Self>; + type ContinuedValue = PingPongContinuedValue<VERIFY_KEY_SIZE, NONCE_SIZE, Self>; + type Transition = PingPongTransition<VERIFY_KEY_SIZE, NONCE_SIZE, Self>; + + fn leader_initialized( + &self, + verify_key: &[u8; VERIFY_KEY_SIZE], + agg_param: &Self::AggregationParam, + nonce: &[u8; NONCE_SIZE], + public_share: &Self::PublicShare, + input_share: &Self::InputShare, + ) -> Result<(Self::State, PingPongMessage), PingPongError> { + self.prepare_init( + verify_key, + /* Leader */ 0, + agg_param, + nonce, + public_share, + input_share, + ) + .map(|(prep_state, prep_share)| { + ( + PingPongState::Continued(prep_state), + PingPongMessage::Initialize { + prep_share: prep_share.get_encoded(), + }, + ) + }) + .map_err(PingPongError::VdafPrepareInit) + } + + fn helper_initialized( + &self, + verify_key: &[u8; VERIFY_KEY_SIZE], + agg_param: &Self::AggregationParam, + nonce: &[u8; NONCE_SIZE], + public_share: &Self::PublicShare, + input_share: &Self::InputShare, + inbound: &PingPongMessage, + ) -> Result<Self::Transition, PingPongError> { + let (prep_state, prep_share) = self + .prepare_init( + verify_key, + /* Helper */ 1, + agg_param, + nonce, + public_share, + input_share, + ) + .map_err(PingPongError::VdafPrepareInit)?; + + let inbound_prep_share = if let PingPongMessage::Initialize { prep_share } = inbound { + Self::PrepareShare::get_decoded_with_param(&prep_state, prep_share) + .map_err(PingPongError::CodecPrepShare)? + } else { + return Err(PingPongError::PeerMessageMismatch { + found: inbound.variant(), + expected: "initialize", + }); + }; + + let current_prepare_message = self + .prepare_shares_to_prepare_message(agg_param, [inbound_prep_share, prep_share]) + .map_err(PingPongError::VdafPrepareSharesToPrepareMessage)?; + + Ok(PingPongTransition { + previous_prepare_state: prep_state, + current_prepare_message, + }) + } + + fn leader_continued( + &self, + leader_state: Self::State, + agg_param: &Self::AggregationParam, + inbound: &PingPongMessage, + ) -> Result<Self::ContinuedValue, PingPongError> { + self.continued(true, leader_state, agg_param, inbound) + } + + fn helper_continued( + &self, + helper_state: Self::State, + agg_param: &Self::AggregationParam, + inbound: &PingPongMessage, + ) -> Result<Self::ContinuedValue, PingPongError> { + self.continued(false, helper_state, agg_param, inbound) + } +} + +impl<const VERIFY_KEY_SIZE: usize, const NONCE_SIZE: usize, A> + PingPongTopologyPrivate<VERIFY_KEY_SIZE, NONCE_SIZE> for A +where + A: Aggregator<VERIFY_KEY_SIZE, NONCE_SIZE>, +{ + fn continued( + &self, + is_leader: bool, + host_state: Self::State, + agg_param: &Self::AggregationParam, + inbound: &PingPongMessage, + ) -> Result<Self::ContinuedValue, PingPongError> { + let host_prep_state = if let PingPongState::Continued(state) = host_state { + state + } else { + return Err(PingPongError::HostStateMismatch { + found: "finished", + expected: "continue", + }); + }; + + let (prep_msg, next_peer_prep_share) = match inbound { + PingPongMessage::Initialize { .. } => { + return Err(PingPongError::PeerMessageMismatch { + found: inbound.variant(), + expected: "continue", + }); + } + PingPongMessage::Continue { + prep_msg, + prep_share, + } => (prep_msg, Some(prep_share)), + PingPongMessage::Finish { prep_msg } => (prep_msg, None), + }; + + let prep_msg = Self::PrepareMessage::get_decoded_with_param(&host_prep_state, prep_msg) + .map_err(PingPongError::CodecPrepMessage)?; + let host_prep_transition = self + .prepare_next(host_prep_state, prep_msg) + .map_err(PingPongError::VdafPrepareNext)?; + + match (host_prep_transition, next_peer_prep_share) { + ( + PrepareTransition::Continue(next_prep_state, next_host_prep_share), + Some(next_peer_prep_share), + ) => { + let next_peer_prep_share = Self::PrepareShare::get_decoded_with_param( + &next_prep_state, + next_peer_prep_share, + ) + .map_err(PingPongError::CodecPrepShare)?; + let mut prep_shares = [next_peer_prep_share, next_host_prep_share]; + if is_leader { + prep_shares.reverse(); + } + let current_prepare_message = self + .prepare_shares_to_prepare_message(agg_param, prep_shares) + .map_err(PingPongError::VdafPrepareSharesToPrepareMessage)?; + + Ok(PingPongContinuedValue::WithMessage { + transition: PingPongTransition { + previous_prepare_state: next_prep_state, + current_prepare_message, + }, + }) + } + (PrepareTransition::Finish(output_share), None) => { + Ok(PingPongContinuedValue::FinishedNoMessage { output_share }) + } + (PrepareTransition::Continue(_, _), None) => { + return Err(PingPongError::PeerMessageMismatch { + found: inbound.variant(), + expected: "continue", + }) + } + (PrepareTransition::Finish(_), Some(_)) => { + return Err(PingPongError::PeerMessageMismatch { + found: inbound.variant(), + expected: "finish", + }) + } + } + } +} + +#[cfg(test)] +mod tests { + use std::io::Cursor; + + use super::*; + use crate::vdaf::dummy; + use assert_matches::assert_matches; + + #[test] + fn ping_pong_one_round() { + let verify_key = []; + let aggregation_param = dummy::AggregationParam(0); + let nonce = [0; 16]; + #[allow(clippy::let_unit_value)] + let public_share = (); + let input_share = dummy::InputShare(0); + + let leader = dummy::Vdaf::new(1); + let helper = dummy::Vdaf::new(1); + + // Leader inits into round 0 + let (leader_state, leader_message) = leader + .leader_initialized( + &verify_key, + &aggregation_param, + &nonce, + &public_share, + &input_share, + ) + .unwrap(); + + // Helper inits into round 1 + let (helper_state, helper_message) = helper + .helper_initialized( + &verify_key, + &aggregation_param, + &nonce, + &public_share, + &input_share, + &leader_message, + ) + .unwrap() + .evaluate(&helper) + .unwrap(); + + // 1 round VDAF: helper should finish immediately. + assert_matches!(helper_state, PingPongState::Finished(_)); + + let leader_state = leader + .leader_continued(leader_state, &aggregation_param, &helper_message) + .unwrap(); + // 1 round VDAF: leader should finish when it gets helper message and emit no message. + assert_matches!( + leader_state, + PingPongContinuedValue::FinishedNoMessage { .. } + ); + } + + #[test] + fn ping_pong_two_rounds() { + let verify_key = []; + let aggregation_param = dummy::AggregationParam(0); + let nonce = [0; 16]; + #[allow(clippy::let_unit_value)] + let public_share = (); + let input_share = dummy::InputShare(0); + + let leader = dummy::Vdaf::new(2); + let helper = dummy::Vdaf::new(2); + + // Leader inits into round 0 + let (leader_state, leader_message) = leader + .leader_initialized( + &verify_key, + &aggregation_param, + &nonce, + &public_share, + &input_share, + ) + .unwrap(); + + // Helper inits into round 1 + let (helper_state, helper_message) = helper + .helper_initialized( + &verify_key, + &aggregation_param, + &nonce, + &public_share, + &input_share, + &leader_message, + ) + .unwrap() + .evaluate(&helper) + .unwrap(); + + // 2 round VDAF, round 1: helper should continue. + assert_matches!(helper_state, PingPongState::Continued(_)); + + let leader_state = leader + .leader_continued(leader_state, &aggregation_param, &helper_message) + .unwrap(); + // 2 round VDAF, round 1: leader should finish and emit a finish message. + let leader_message = assert_matches!( + leader_state, PingPongContinuedValue::WithMessage { transition } => { + let (state, message) = transition.evaluate(&leader).unwrap(); + assert_matches!(state, PingPongState::Finished(_)); + message + } + ); + + let helper_state = helper + .helper_continued(helper_state, &aggregation_param, &leader_message) + .unwrap(); + // 2 round vdaf, round 1: helper should finish and emit no message. + assert_matches!( + helper_state, + PingPongContinuedValue::FinishedNoMessage { .. } + ); + } + + #[test] + fn ping_pong_three_rounds() { + let verify_key = []; + let aggregation_param = dummy::AggregationParam(0); + let nonce = [0; 16]; + #[allow(clippy::let_unit_value)] + let public_share = (); + let input_share = dummy::InputShare(0); + + let leader = dummy::Vdaf::new(3); + let helper = dummy::Vdaf::new(3); + + // Leader inits into round 0 + let (leader_state, leader_message) = leader + .leader_initialized( + &verify_key, + &aggregation_param, + &nonce, + &public_share, + &input_share, + ) + .unwrap(); + + // Helper inits into round 1 + let (helper_state, helper_message) = helper + .helper_initialized( + &verify_key, + &aggregation_param, + &nonce, + &public_share, + &input_share, + &leader_message, + ) + .unwrap() + .evaluate(&helper) + .unwrap(); + + // 3 round VDAF, round 1: helper should continue. + assert_matches!(helper_state, PingPongState::Continued(_)); + + let leader_state = leader + .leader_continued(leader_state, &aggregation_param, &helper_message) + .unwrap(); + // 3 round VDAF, round 1: leader should continue and emit a continue message. + let (leader_state, leader_message) = assert_matches!( + leader_state, PingPongContinuedValue::WithMessage { transition } => { + let (state, message) = transition.evaluate(&leader).unwrap(); + assert_matches!(state, PingPongState::Continued(_)); + (state, message) + } + ); + + let helper_state = helper + .helper_continued(helper_state, &aggregation_param, &leader_message) + .unwrap(); + // 3 round vdaf, round 2: helper should finish and emit a finish message. + let helper_message = assert_matches!( + helper_state, PingPongContinuedValue::WithMessage { transition } => { + let (state, message) = transition.evaluate(&helper).unwrap(); + assert_matches!(state, PingPongState::Finished(_)); + message + } + ); + + let leader_state = leader + .leader_continued(leader_state, &aggregation_param, &helper_message) + .unwrap(); + // 3 round VDAF, round 2: leader should finish and emit no message. + assert_matches!( + leader_state, + PingPongContinuedValue::FinishedNoMessage { .. } + ); + } + + #[test] + fn roundtrip_message() { + let messages = [ + ( + PingPongMessage::Initialize { + prep_share: Vec::from("prepare share"), + }, + concat!( + "00", // enum discriminant + concat!( + // prep_share + "0000000d", // length + "70726570617265207368617265", // contents + ), + ), + ), + ( + PingPongMessage::Continue { + prep_msg: Vec::from("prepare message"), + prep_share: Vec::from("prepare share"), + }, + concat!( + "01", // enum discriminant + concat!( + // prep_msg + "0000000f", // length + "70726570617265206d657373616765", // contents + ), + concat!( + // prep_share + "0000000d", // length + "70726570617265207368617265", // contents + ), + ), + ), + ( + PingPongMessage::Finish { + prep_msg: Vec::from("prepare message"), + }, + concat!( + "02", // enum discriminant + concat!( + // prep_msg + "0000000f", // length + "70726570617265206d657373616765", // contents + ), + ), + ), + ]; + + for (message, expected_hex) in messages { + let mut encoded_val = Vec::new(); + message.encode(&mut encoded_val); + let got_hex = hex::encode(&encoded_val); + assert_eq!( + &got_hex, expected_hex, + "Couldn't roundtrip (encoded value differs): {message:?}", + ); + let decoded_val = PingPongMessage::decode(&mut Cursor::new(&encoded_val)).unwrap(); + assert_eq!( + decoded_val, message, + "Couldn't roundtrip (decoded value differs): {message:?}" + ); + assert_eq!( + encoded_val.len(), + message.encoded_len().expect("No encoded length hint"), + "Encoded length hint is incorrect: {message:?}" + ) + } + } + + #[test] + fn roundtrip_transition() { + // VDAF implementations have tests for encoding/decoding their respective PrepareShare and + // PrepareMessage types, so we test here using the dummy VDAF. + let transition = PingPongTransition::<0, 16, dummy::Vdaf> { + previous_prepare_state: dummy::PrepareState::default(), + current_prepare_message: (), + }; + + let encoded = transition.get_encoded(); + let hex_encoded = hex::encode(&encoded); + + assert_eq!( + hex_encoded, + concat!( + concat!( + // previous_prepare_state + "00", // input_share + "00000000", // current_round + ), + // current_prepare_message (0 length encoding) + ) + ); + + let decoded = PingPongTransition::get_decoded_with_param(&(), &encoded).unwrap(); + assert_eq!(transition, decoded); + + assert_eq!( + encoded.len(), + transition.encoded_len().expect("No encoded length hint"), + ); + } +} diff --git a/third_party/rust/prio/src/vdaf.rs b/third_party/rust/prio/src/vdaf.rs new file mode 100644 index 0000000000..1a6c5f0315 --- /dev/null +++ b/third_party/rust/prio/src/vdaf.rs @@ -0,0 +1,757 @@ +// SPDX-License-Identifier: MPL-2.0 + +//! Verifiable Distributed Aggregation Functions (VDAFs) as described in +//! [[draft-irtf-cfrg-vdaf-07]]. +//! +//! [draft-irtf-cfrg-vdaf-07]: https://datatracker.ietf.org/doc/draft-irtf-cfrg-vdaf/07/ + +#[cfg(feature = "experimental")] +use crate::dp::DifferentialPrivacyStrategy; +#[cfg(all(feature = "crypto-dependencies", feature = "experimental"))] +use crate::idpf::IdpfError; +use crate::{ + codec::{CodecError, Decode, Encode, ParameterizedDecode}, + field::{encode_fieldvec, merge_vector, FieldElement, FieldError}, + flp::FlpError, + prng::PrngError, + vdaf::xof::Seed, +}; +use serde::{Deserialize, Serialize}; +use std::{fmt::Debug, io::Cursor}; +use subtle::{Choice, ConstantTimeEq}; + +/// A component of the domain-separation tag, used to bind the VDAF operations to the document +/// version. This will be revised with each draft with breaking changes. +pub(crate) const VERSION: u8 = 7; + +/// Errors emitted by this module. +#[derive(Debug, thiserror::Error)] +pub enum VdafError { + /// An error occurred. + #[error("vdaf error: {0}")] + Uncategorized(String), + + /// Field error. + #[error("field error: {0}")] + Field(#[from] FieldError), + + /// An error occured while parsing a message. + #[error("io error: {0}")] + IoError(#[from] std::io::Error), + + /// FLP error. + #[error("flp error: {0}")] + Flp(#[from] FlpError), + + /// PRNG error. + #[error("prng error: {0}")] + Prng(#[from] PrngError), + + /// Failure when calling getrandom(). + #[error("getrandom: {0}")] + GetRandom(#[from] getrandom::Error), + + /// IDPF error. + #[cfg(all(feature = "crypto-dependencies", feature = "experimental"))] + #[error("idpf error: {0}")] + Idpf(#[from] IdpfError), +} + +/// An additive share of a vector of field elements. +#[derive(Clone, Debug)] +pub enum Share<F, const SEED_SIZE: usize> { + /// An uncompressed share, typically sent to the leader. + Leader(Vec<F>), + + /// A compressed share, typically sent to the helper. + Helper(Seed<SEED_SIZE>), +} + +impl<F: Clone, const SEED_SIZE: usize> Share<F, SEED_SIZE> { + /// Truncate the Leader's share to the given length. If this is the Helper's share, then this + /// method clones the input without modifying it. + #[cfg(feature = "prio2")] + pub(crate) fn truncated(&self, len: usize) -> Self { + match self { + Self::Leader(ref data) => Self::Leader(data[..len].to_vec()), + Self::Helper(ref seed) => Self::Helper(seed.clone()), + } + } +} + +impl<F: ConstantTimeEq, const SEED_SIZE: usize> PartialEq for Share<F, SEED_SIZE> { + fn eq(&self, other: &Self) -> bool { + self.ct_eq(other).into() + } +} + +impl<F: ConstantTimeEq, const SEED_SIZE: usize> Eq for Share<F, SEED_SIZE> {} + +impl<F: ConstantTimeEq, const SEED_SIZE: usize> ConstantTimeEq for Share<F, SEED_SIZE> { + fn ct_eq(&self, other: &Self) -> subtle::Choice { + // We allow short-circuiting on the type (Leader vs Helper) of the value, but not the types' + // contents. + match (self, other) { + (Share::Leader(self_val), Share::Leader(other_val)) => self_val.ct_eq(other_val), + (Share::Helper(self_val), Share::Helper(other_val)) => self_val.ct_eq(other_val), + _ => Choice::from(0), + } + } +} + +/// Parameters needed to decode a [`Share`] +#[derive(Clone, Debug, PartialEq, Eq)] +pub(crate) enum ShareDecodingParameter<const SEED_SIZE: usize> { + Leader(usize), + Helper, +} + +impl<F: FieldElement, const SEED_SIZE: usize> ParameterizedDecode<ShareDecodingParameter<SEED_SIZE>> + for Share<F, SEED_SIZE> +{ + fn decode_with_param( + decoding_parameter: &ShareDecodingParameter<SEED_SIZE>, + bytes: &mut Cursor<&[u8]>, + ) -> Result<Self, CodecError> { + match decoding_parameter { + ShareDecodingParameter::Leader(share_length) => { + let mut data = Vec::with_capacity(*share_length); + for _ in 0..*share_length { + data.push(F::decode(bytes)?) + } + Ok(Self::Leader(data)) + } + ShareDecodingParameter::Helper => { + let seed = Seed::decode(bytes)?; + Ok(Self::Helper(seed)) + } + } + } +} + +impl<F: FieldElement, const SEED_SIZE: usize> Encode for Share<F, SEED_SIZE> { + fn encode(&self, bytes: &mut Vec<u8>) { + match self { + Share::Leader(share_data) => { + for x in share_data { + x.encode(bytes); + } + } + Share::Helper(share_seed) => { + share_seed.encode(bytes); + } + } + } + + fn encoded_len(&self) -> Option<usize> { + match self { + Share::Leader(share_data) => { + // Each element of the data vector has the same size. + Some(share_data.len() * F::ENCODED_SIZE) + } + Share::Helper(share_seed) => share_seed.encoded_len(), + } + } +} + +/// The base trait for VDAF schemes. This trait is inherited by traits [`Client`], [`Aggregator`], +/// and [`Collector`], which define the roles of the various parties involved in the execution of +/// the VDAF. +pub trait Vdaf: Clone + Debug { + /// Algorithm identifier for this VDAF. + const ID: u32; + + /// The type of Client measurement to be aggregated. + type Measurement: Clone + Debug; + + /// The aggregate result of the VDAF execution. + type AggregateResult: Clone + Debug; + + /// The aggregation parameter, used by the Aggregators to map their input shares to output + /// shares. + type AggregationParam: Clone + Debug + Decode + Encode; + + /// A public share sent by a Client. + type PublicShare: Clone + Debug + ParameterizedDecode<Self> + Encode; + + /// An input share sent by a Client. + type InputShare: Clone + Debug + for<'a> ParameterizedDecode<(&'a Self, usize)> + Encode; + + /// An output share recovered from an input share by an Aggregator. + type OutputShare: Clone + + Debug + + for<'a> ParameterizedDecode<(&'a Self, &'a Self::AggregationParam)> + + Encode; + + /// An Aggregator's share of the aggregate result. + type AggregateShare: Aggregatable<OutputShare = Self::OutputShare> + + for<'a> ParameterizedDecode<(&'a Self, &'a Self::AggregationParam)> + + Encode; + + /// The number of Aggregators. The Client generates as many input shares as there are + /// Aggregators. + fn num_aggregators(&self) -> usize; + + /// Generate the domain separation tag for this VDAF. The output is used for domain separation + /// by the XOF. + fn domain_separation_tag(usage: u16) -> [u8; 8] { + let mut dst = [0_u8; 8]; + dst[0] = VERSION; + dst[1] = 0; // algorithm class + dst[2..6].copy_from_slice(&(Self::ID).to_be_bytes()); + dst[6..8].copy_from_slice(&usage.to_be_bytes()); + dst + } +} + +/// The Client's role in the execution of a VDAF. +pub trait Client<const NONCE_SIZE: usize>: Vdaf { + /// Shards a measurement into a public share and a sequence of input shares, one for each + /// Aggregator. + /// + /// Implements `Vdaf::shard` from [VDAF]. + /// + /// [VDAF]: https://datatracker.ietf.org/doc/html/draft-irtf-cfrg-vdaf-07#section-5.1 + fn shard( + &self, + measurement: &Self::Measurement, + nonce: &[u8; NONCE_SIZE], + ) -> Result<(Self::PublicShare, Vec<Self::InputShare>), VdafError>; +} + +/// The Aggregator's role in the execution of a VDAF. +pub trait Aggregator<const VERIFY_KEY_SIZE: usize, const NONCE_SIZE: usize>: Vdaf { + /// State of the Aggregator during the Prepare process. + type PrepareState: Clone + Debug + PartialEq + Eq; + + /// The type of messages sent by each aggregator at each round of the Prepare Process. + /// + /// Decoding takes a [`Self::PrepareState`] as a parameter; this [`Self::PrepareState`] may be + /// associated with any aggregator involved in the execution of the VDAF. + type PrepareShare: Clone + Debug + ParameterizedDecode<Self::PrepareState> + Encode; + + /// Result of preprocessing a round of preparation shares. This is used by all aggregators as an + /// input to the next round of the Prepare Process. + /// + /// Decoding takes a [`Self::PrepareState`] as a parameter; this [`Self::PrepareState`] may be + /// associated with any aggregator involved in the execution of the VDAF. + type PrepareMessage: Clone + + Debug + + PartialEq + + Eq + + ParameterizedDecode<Self::PrepareState> + + Encode; + + /// Begins the Prepare process with the other Aggregators. The [`Self::PrepareState`] returned + /// is passed to [`Self::prepare_next`] to get this aggregator's first-round prepare message. + /// + /// Implements `Vdaf.prep_init` from [VDAF]. + /// + /// [VDAF]: https://datatracker.ietf.org/doc/html/draft-irtf-cfrg-vdaf-07#section-5.2 + fn prepare_init( + &self, + verify_key: &[u8; VERIFY_KEY_SIZE], + agg_id: usize, + agg_param: &Self::AggregationParam, + nonce: &[u8; NONCE_SIZE], + public_share: &Self::PublicShare, + input_share: &Self::InputShare, + ) -> Result<(Self::PrepareState, Self::PrepareShare), VdafError>; + + /// Preprocess a round of preparation shares into a single input to [`Self::prepare_next`]. + /// + /// Implements `Vdaf.prep_shares_to_prep` from [VDAF]. + /// + /// # Notes + /// + /// [`Self::prepare_shares_to_prepare_message`] is preferable since its name better matches the + /// specification. + /// + /// [VDAF]: https://datatracker.ietf.org/doc/html/draft-irtf-cfrg-vdaf-07#section-5.2 + #[deprecated( + since = "0.15.0", + note = "Use Vdaf::prepare_shares_to_prepare_message instead" + )] + fn prepare_preprocess<M: IntoIterator<Item = Self::PrepareShare>>( + &self, + agg_param: &Self::AggregationParam, + inputs: M, + ) -> Result<Self::PrepareMessage, VdafError> { + self.prepare_shares_to_prepare_message(agg_param, inputs) + } + + /// Preprocess a round of preparation shares into a single input to [`Self::prepare_next`]. + /// + /// Implements `Vdaf.prep_shares_to_prep` from [VDAF]. + /// + /// [VDAF]: https://datatracker.ietf.org/doc/html/draft-irtf-cfrg-vdaf-07#section-5.2 + fn prepare_shares_to_prepare_message<M: IntoIterator<Item = Self::PrepareShare>>( + &self, + agg_param: &Self::AggregationParam, + inputs: M, + ) -> Result<Self::PrepareMessage, VdafError>; + + /// Compute the next state transition from the current state and the previous round of input + /// messages. If this returns [`PrepareTransition::Continue`], then the returned + /// [`Self::PrepareShare`] should be combined with the other Aggregators' `PrepareShare`s from + /// this round and passed into another call to this method. This continues until this method + /// returns [`PrepareTransition::Finish`], at which point the returned output share may be + /// aggregated. If the method returns an error, the aggregator should consider its input share + /// invalid and not attempt to process it any further. + /// + /// Implements `Vdaf.prep_next` from [VDAF]. + /// + /// # Notes + /// + /// [`Self::prepare_next`] is preferable since its name better matches the specification. + /// + /// [VDAF]: https://datatracker.ietf.org/doc/html/draft-irtf-cfrg-vdaf-07#section-5.2 + #[deprecated(since = "0.15.0", note = "Use Vdaf::prepare_next")] + fn prepare_step( + &self, + state: Self::PrepareState, + input: Self::PrepareMessage, + ) -> Result<PrepareTransition<Self, VERIFY_KEY_SIZE, NONCE_SIZE>, VdafError> { + self.prepare_next(state, input) + } + + /// Compute the next state transition from the current state and the previous round of input + /// messages. If this returns [`PrepareTransition::Continue`], then the returned + /// [`Self::PrepareShare`] should be combined with the other Aggregators' `PrepareShare`s from + /// this round and passed into another call to this method. This continues until this method + /// returns [`PrepareTransition::Finish`], at which point the returned output share may be + /// aggregated. If the method returns an error, the aggregator should consider its input share + /// invalid and not attempt to process it any further. + /// + /// Implements `Vdaf.prep_next` from [VDAF]. + /// + /// [VDAF]: https://datatracker.ietf.org/doc/html/draft-irtf-cfrg-vdaf-07#section-5.2 + fn prepare_next( + &self, + state: Self::PrepareState, + input: Self::PrepareMessage, + ) -> Result<PrepareTransition<Self, VERIFY_KEY_SIZE, NONCE_SIZE>, VdafError>; + + /// Aggregates a sequence of output shares into an aggregate share. + fn aggregate<M: IntoIterator<Item = Self::OutputShare>>( + &self, + agg_param: &Self::AggregationParam, + output_shares: M, + ) -> Result<Self::AggregateShare, VdafError>; +} + +/// Aggregator that implements differential privacy with Aggregator-side noise addition. +#[cfg(feature = "experimental")] +pub trait AggregatorWithNoise< + const VERIFY_KEY_SIZE: usize, + const NONCE_SIZE: usize, + DPStrategy: DifferentialPrivacyStrategy, +>: Aggregator<VERIFY_KEY_SIZE, NONCE_SIZE> +{ + /// Adds noise to an aggregate share such that the aggregate result is differentially private + /// as long as one Aggregator is honest. + fn add_noise_to_agg_share( + &self, + dp_strategy: &DPStrategy, + agg_param: &Self::AggregationParam, + agg_share: &mut Self::AggregateShare, + num_measurements: usize, + ) -> Result<(), VdafError>; +} + +/// The Collector's role in the execution of a VDAF. +pub trait Collector: Vdaf { + /// Combines aggregate shares into the aggregate result. + fn unshard<M: IntoIterator<Item = Self::AggregateShare>>( + &self, + agg_param: &Self::AggregationParam, + agg_shares: M, + num_measurements: usize, + ) -> Result<Self::AggregateResult, VdafError>; +} + +/// A state transition of an Aggregator during the Prepare process. +#[derive(Clone, Debug)] +pub enum PrepareTransition< + V: Aggregator<VERIFY_KEY_SIZE, NONCE_SIZE>, + const VERIFY_KEY_SIZE: usize, + const NONCE_SIZE: usize, +> { + /// Continue processing. + Continue(V::PrepareState, V::PrepareShare), + + /// Finish processing and return the output share. + Finish(V::OutputShare), +} + +/// An aggregate share resulting from aggregating output shares together that +/// can merged with aggregate shares of the same type. +pub trait Aggregatable: Clone + Debug + From<Self::OutputShare> { + /// Type of output shares that can be accumulated into an aggregate share. + type OutputShare; + + /// Update an aggregate share by merging it with another (`agg_share`). + fn merge(&mut self, agg_share: &Self) -> Result<(), VdafError>; + + /// Update an aggregate share by adding `output_share`. + fn accumulate(&mut self, output_share: &Self::OutputShare) -> Result<(), VdafError>; +} + +/// An output share comprised of a vector of field elements. +#[derive(Clone)] +pub struct OutputShare<F>(Vec<F>); + +impl<F: ConstantTimeEq> PartialEq for OutputShare<F> { + fn eq(&self, other: &Self) -> bool { + self.ct_eq(other).into() + } +} + +impl<F: ConstantTimeEq> Eq for OutputShare<F> {} + +impl<F: ConstantTimeEq> ConstantTimeEq for OutputShare<F> { + fn ct_eq(&self, other: &Self) -> Choice { + self.0.ct_eq(&other.0) + } +} + +impl<F> AsRef<[F]> for OutputShare<F> { + fn as_ref(&self) -> &[F] { + &self.0 + } +} + +impl<F> From<Vec<F>> for OutputShare<F> { + fn from(other: Vec<F>) -> Self { + Self(other) + } +} + +impl<F: FieldElement> Encode for OutputShare<F> { + fn encode(&self, bytes: &mut Vec<u8>) { + encode_fieldvec(&self.0, bytes) + } + + fn encoded_len(&self) -> Option<usize> { + Some(F::ENCODED_SIZE * self.0.len()) + } +} + +impl<F> Debug for OutputShare<F> { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_tuple("OutputShare").finish() + } +} + +/// An aggregate share comprised of a vector of field elements. +/// +/// This is suitable for VDAFs where both output shares and aggregate shares are vectors of field +/// elements, and output shares need no special transformation to be merged into an aggregate share. +#[derive(Clone, Debug, Serialize, Deserialize)] + +pub struct AggregateShare<F>(Vec<F>); + +impl<F: ConstantTimeEq> PartialEq for AggregateShare<F> { + fn eq(&self, other: &Self) -> bool { + self.ct_eq(other).into() + } +} + +impl<F: ConstantTimeEq> Eq for AggregateShare<F> {} + +impl<F: ConstantTimeEq> ConstantTimeEq for AggregateShare<F> { + fn ct_eq(&self, other: &Self) -> subtle::Choice { + self.0.ct_eq(&other.0) + } +} + +impl<F: FieldElement> AsRef<[F]> for AggregateShare<F> { + fn as_ref(&self) -> &[F] { + &self.0 + } +} + +impl<F> From<OutputShare<F>> for AggregateShare<F> { + fn from(other: OutputShare<F>) -> Self { + Self(other.0) + } +} + +impl<F: FieldElement> Aggregatable for AggregateShare<F> { + type OutputShare = OutputShare<F>; + + fn merge(&mut self, agg_share: &Self) -> Result<(), VdafError> { + self.sum(agg_share.as_ref()) + } + + fn accumulate(&mut self, output_share: &Self::OutputShare) -> Result<(), VdafError> { + // For Poplar1, Prio2, and Prio3, no conversion is needed between output shares and + // aggregate shares. + self.sum(output_share.as_ref()) + } +} + +impl<F: FieldElement> AggregateShare<F> { + fn sum(&mut self, other: &[F]) -> Result<(), VdafError> { + merge_vector(&mut self.0, other).map_err(Into::into) + } +} + +impl<F: FieldElement> Encode for AggregateShare<F> { + fn encode(&self, bytes: &mut Vec<u8>) { + encode_fieldvec(&self.0, bytes) + } + + fn encoded_len(&self) -> Option<usize> { + Some(F::ENCODED_SIZE * self.0.len()) + } +} + +#[cfg(test)] +pub(crate) fn run_vdaf<V, M, const SEED_SIZE: usize>( + vdaf: &V, + agg_param: &V::AggregationParam, + measurements: M, +) -> Result<V::AggregateResult, VdafError> +where + V: Client<16> + Aggregator<SEED_SIZE, 16> + Collector, + M: IntoIterator<Item = V::Measurement>, +{ + use rand::prelude::*; + let mut rng = thread_rng(); + let mut verify_key = [0; SEED_SIZE]; + rng.fill(&mut verify_key[..]); + + let mut agg_shares: Vec<Option<V::AggregateShare>> = vec![None; vdaf.num_aggregators()]; + let mut num_measurements: usize = 0; + for measurement in measurements.into_iter() { + num_measurements += 1; + let nonce = rng.gen(); + let (public_share, input_shares) = vdaf.shard(&measurement, &nonce)?; + let out_shares = run_vdaf_prepare( + vdaf, + &verify_key, + agg_param, + &nonce, + public_share, + input_shares, + )?; + for (out_share, agg_share) in out_shares.into_iter().zip(agg_shares.iter_mut()) { + // Check serialization of output shares + let encoded_out_share = out_share.get_encoded(); + let round_trip_out_share = + V::OutputShare::get_decoded_with_param(&(vdaf, agg_param), &encoded_out_share) + .unwrap(); + assert_eq!(round_trip_out_share.get_encoded(), encoded_out_share); + + let this_agg_share = V::AggregateShare::from(out_share); + if let Some(ref mut inner) = agg_share { + inner.merge(&this_agg_share)?; + } else { + *agg_share = Some(this_agg_share); + } + } + } + + for agg_share in agg_shares.iter() { + // Check serialization of aggregate shares + let encoded_agg_share = agg_share.as_ref().unwrap().get_encoded(); + let round_trip_agg_share = + V::AggregateShare::get_decoded_with_param(&(vdaf, agg_param), &encoded_agg_share) + .unwrap(); + assert_eq!(round_trip_agg_share.get_encoded(), encoded_agg_share); + } + + let res = vdaf.unshard( + agg_param, + agg_shares.into_iter().map(|option| option.unwrap()), + num_measurements, + )?; + Ok(res) +} + +#[cfg(test)] +pub(crate) fn run_vdaf_prepare<V, M, const SEED_SIZE: usize>( + vdaf: &V, + verify_key: &[u8; SEED_SIZE], + agg_param: &V::AggregationParam, + nonce: &[u8; 16], + public_share: V::PublicShare, + input_shares: M, +) -> Result<Vec<V::OutputShare>, VdafError> +where + V: Client<16> + Aggregator<SEED_SIZE, 16> + Collector, + M: IntoIterator<Item = V::InputShare>, +{ + let input_shares = input_shares + .into_iter() + .map(|input_share| input_share.get_encoded()); + + let mut states = Vec::new(); + let mut outbound = Vec::new(); + for (agg_id, input_share) in input_shares.enumerate() { + let (state, msg) = vdaf.prepare_init( + verify_key, + agg_id, + agg_param, + nonce, + &public_share, + &V::InputShare::get_decoded_with_param(&(vdaf, agg_id), &input_share) + .expect("failed to decode input share"), + )?; + states.push(state); + outbound.push(msg.get_encoded()); + } + + let mut inbound = vdaf + .prepare_shares_to_prepare_message( + agg_param, + outbound.iter().map(|encoded| { + V::PrepareShare::get_decoded_with_param(&states[0], encoded) + .expect("failed to decode prep share") + }), + )? + .get_encoded(); + + let mut out_shares = Vec::new(); + loop { + let mut outbound = Vec::new(); + for state in states.iter_mut() { + match vdaf.prepare_next( + state.clone(), + V::PrepareMessage::get_decoded_with_param(state, &inbound) + .expect("failed to decode prep message"), + )? { + PrepareTransition::Continue(new_state, msg) => { + outbound.push(msg.get_encoded()); + *state = new_state + } + PrepareTransition::Finish(out_share) => { + out_shares.push(out_share); + } + } + } + + if outbound.len() == vdaf.num_aggregators() { + // Another round is required before output shares are computed. + inbound = vdaf + .prepare_shares_to_prepare_message( + agg_param, + outbound.iter().map(|encoded| { + V::PrepareShare::get_decoded_with_param(&states[0], encoded) + .expect("failed to decode prep share") + }), + )? + .get_encoded(); + } else if outbound.is_empty() { + // Each Aggregator recovered an output share. + break; + } else { + panic!("Aggregators did not finish the prepare phase at the same time"); + } + } + + Ok(out_shares) +} + +#[cfg(test)] +fn fieldvec_roundtrip_test<F, V, T>(vdaf: &V, agg_param: &V::AggregationParam, length: usize) +where + F: FieldElement, + V: Vdaf, + T: Encode, + for<'a> T: ParameterizedDecode<(&'a V, &'a V::AggregationParam)>, +{ + // Generate an arbitrary vector of field elements. + let g = F::one() + F::one(); + let vec: Vec<F> = itertools::iterate(F::one(), |&v| g * v) + .take(length) + .collect(); + + // Serialize the field element vector into a vector of bytes. + let mut bytes = Vec::with_capacity(vec.len() * F::ENCODED_SIZE); + encode_fieldvec(&vec, &mut bytes); + + // Deserialize the type of interest from those bytes. + let value = T::get_decoded_with_param(&(vdaf, agg_param), &bytes).unwrap(); + + // Round-trip the value back to a vector of bytes. + let encoded = value.get_encoded(); + + assert_eq!(encoded, bytes); +} + +#[cfg(test)] +fn equality_comparison_test<T>(values: &[T]) +where + T: Debug + PartialEq, +{ + use std::ptr; + + // This function expects that every value passed in `values` is distinct, i.e. should not + // compare as equal to any other element. We test both (i, j) and (j, i) to gain confidence that + // equality implementations are symmetric. + for (i, i_val) in values.iter().enumerate() { + for (j, j_val) in values.iter().enumerate() { + if i == j { + assert!(ptr::eq(i_val, j_val)); // sanity + assert_eq!( + i_val, j_val, + "Expected element at index {i} to be equal to itself, but it was not" + ); + } else { + assert_ne!( + i_val, j_val, + "Expected elements at indices {i} & {j} to not be equal, but they were" + ) + } + } + } +} + +#[cfg(test)] +mod tests { + use crate::vdaf::{equality_comparison_test, xof::Seed, AggregateShare, OutputShare, Share}; + + #[test] + fn share_equality_test() { + equality_comparison_test(&[ + Share::Leader(Vec::from([1, 2, 3])), + Share::Leader(Vec::from([3, 2, 1])), + Share::Helper(Seed([1, 2, 3])), + Share::Helper(Seed([3, 2, 1])), + ]) + } + + #[test] + fn output_share_equality_test() { + equality_comparison_test(&[ + OutputShare(Vec::from([1, 2, 3])), + OutputShare(Vec::from([3, 2, 1])), + ]) + } + + #[test] + fn aggregate_share_equality_test() { + equality_comparison_test(&[ + AggregateShare(Vec::from([1, 2, 3])), + AggregateShare(Vec::from([3, 2, 1])), + ]) + } +} + +#[cfg(feature = "test-util")] +pub mod dummy; +#[cfg(all(feature = "crypto-dependencies", feature = "experimental"))] +#[cfg_attr( + docsrs, + doc(cfg(all(feature = "crypto-dependencies", feature = "experimental"))) +)] +pub mod poplar1; +#[cfg(feature = "prio2")] +#[cfg_attr(docsrs, doc(cfg(feature = "prio2")))] +pub mod prio2; +pub mod prio3; +#[cfg(test)] +mod prio3_test; +pub mod xof; diff --git a/third_party/rust/prio/src/vdaf/dummy.rs b/third_party/rust/prio/src/vdaf/dummy.rs new file mode 100644 index 0000000000..507e7916bb --- /dev/null +++ b/third_party/rust/prio/src/vdaf/dummy.rs @@ -0,0 +1,316 @@ +// SPDX-License-Identifier: MPL-2.0 + +//! Implementation of a dummy VDAF which conforms to the specification in [draft-irtf-cfrg-vdaf-06] +//! but does nothing. Useful for testing. +//! +//! [draft-irtf-cfrg-vdaf-06]: https://datatracker.ietf.org/doc/draft-irtf-cfrg-vdaf/06/ + +use crate::{ + codec::{CodecError, Decode, Encode}, + vdaf::{self, Aggregatable, PrepareTransition, VdafError}, +}; +use rand::random; +use std::{fmt::Debug, io::Cursor, sync::Arc}; + +type ArcPrepInitFn = + Arc<dyn Fn(&AggregationParam) -> Result<(), VdafError> + 'static + Send + Sync>; +type ArcPrepStepFn = Arc< + dyn Fn(&PrepareState) -> Result<PrepareTransition<Vdaf, 0, 16>, VdafError> + + 'static + + Send + + Sync, +>; + +/// Dummy VDAF that does nothing. +#[derive(Clone)] +pub struct Vdaf { + prep_init_fn: ArcPrepInitFn, + prep_step_fn: ArcPrepStepFn, +} + +impl Debug for Vdaf { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("Vdaf") + .field("prep_init_fn", &"[redacted]") + .field("prep_step_fn", &"[redacted]") + .finish() + } +} + +impl Vdaf { + /// The length of the verify key parameter for fake VDAF instantiations. + pub const VERIFY_KEY_LEN: usize = 0; + + /// Construct a new instance of the dummy VDAF. + pub fn new(rounds: u32) -> Self { + Self { + prep_init_fn: Arc::new(|_| -> Result<(), VdafError> { Ok(()) }), + prep_step_fn: Arc::new( + move |state| -> Result<PrepareTransition<Self, 0, 16>, VdafError> { + let new_round = state.current_round + 1; + if new_round == rounds { + Ok(PrepareTransition::Finish(OutputShare(state.input_share))) + } else { + Ok(PrepareTransition::Continue( + PrepareState { + current_round: new_round, + ..*state + }, + (), + )) + } + }, + ), + } + } + + /// Provide an alternate implementation of [`vdaf::Aggregator::prepare_init`]. + pub fn with_prep_init_fn<F: Fn(&AggregationParam) -> Result<(), VdafError>>( + mut self, + f: F, + ) -> Self + where + F: 'static + Send + Sync, + { + self.prep_init_fn = Arc::new(f); + self + } + + /// Provide an alternate implementation of [`vdaf::Aggregator::prepare_step`]. + pub fn with_prep_step_fn< + F: Fn(&PrepareState) -> Result<PrepareTransition<Self, 0, 16>, VdafError>, + >( + mut self, + f: F, + ) -> Self + where + F: 'static + Send + Sync, + { + self.prep_step_fn = Arc::new(f); + self + } +} + +impl Default for Vdaf { + fn default() -> Self { + Self::new(1) + } +} + +impl vdaf::Vdaf for Vdaf { + const ID: u32 = 0xFFFF0000; + + type Measurement = u8; + type AggregateResult = u8; + type AggregationParam = AggregationParam; + type PublicShare = (); + type InputShare = InputShare; + type OutputShare = OutputShare; + type AggregateShare = AggregateShare; + + fn num_aggregators(&self) -> usize { + 2 + } +} + +impl vdaf::Aggregator<0, 16> for Vdaf { + type PrepareState = PrepareState; + type PrepareShare = (); + type PrepareMessage = (); + + fn prepare_init( + &self, + _verify_key: &[u8; 0], + _: usize, + aggregation_param: &Self::AggregationParam, + _nonce: &[u8; 16], + _: &Self::PublicShare, + input_share: &Self::InputShare, + ) -> Result<(Self::PrepareState, Self::PrepareShare), VdafError> { + (self.prep_init_fn)(aggregation_param)?; + Ok(( + PrepareState { + input_share: input_share.0, + current_round: 0, + }, + (), + )) + } + + fn prepare_shares_to_prepare_message<M: IntoIterator<Item = Self::PrepareShare>>( + &self, + _: &Self::AggregationParam, + _: M, + ) -> Result<Self::PrepareMessage, VdafError> { + Ok(()) + } + + fn prepare_next( + &self, + state: Self::PrepareState, + _: Self::PrepareMessage, + ) -> Result<PrepareTransition<Self, 0, 16>, VdafError> { + (self.prep_step_fn)(&state) + } + + fn aggregate<M: IntoIterator<Item = Self::OutputShare>>( + &self, + _: &Self::AggregationParam, + output_shares: M, + ) -> Result<Self::AggregateShare, VdafError> { + let mut aggregate_share = AggregateShare(0); + for output_share in output_shares { + aggregate_share.accumulate(&output_share)?; + } + Ok(aggregate_share) + } +} + +impl vdaf::Client<16> for Vdaf { + fn shard( + &self, + measurement: &Self::Measurement, + _nonce: &[u8; 16], + ) -> Result<(Self::PublicShare, Vec<Self::InputShare>), VdafError> { + let first_input_share = random(); + let (second_input_share, _) = measurement.overflowing_sub(first_input_share); + Ok(( + (), + Vec::from([ + InputShare(first_input_share), + InputShare(second_input_share), + ]), + )) + } +} + +/// A dummy input share. +#[derive(Copy, Clone, Debug, Default, PartialEq, Eq, PartialOrd, Ord)] +pub struct InputShare(pub u8); + +impl Encode for InputShare { + fn encode(&self, bytes: &mut Vec<u8>) { + self.0.encode(bytes) + } + + fn encoded_len(&self) -> Option<usize> { + self.0.encoded_len() + } +} + +impl Decode for InputShare { + fn decode(bytes: &mut Cursor<&[u8]>) -> Result<Self, CodecError> { + Ok(Self(u8::decode(bytes)?)) + } +} + +/// Dummy aggregation parameter. +#[derive(Copy, Clone, Debug, Default, PartialEq, Eq, PartialOrd, Ord, Hash)] +pub struct AggregationParam(pub u8); + +impl Encode for AggregationParam { + fn encode(&self, bytes: &mut Vec<u8>) { + self.0.encode(bytes) + } + + fn encoded_len(&self) -> Option<usize> { + self.0.encoded_len() + } +} + +impl Decode for AggregationParam { + fn decode(bytes: &mut Cursor<&[u8]>) -> Result<Self, CodecError> { + Ok(Self(u8::decode(bytes)?)) + } +} + +/// Dummy output share. +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +pub struct OutputShare(pub u8); + +impl Decode for OutputShare { + fn decode(bytes: &mut Cursor<&[u8]>) -> Result<Self, CodecError> { + Ok(Self(u8::decode(bytes)?)) + } +} + +impl Encode for OutputShare { + fn encode(&self, bytes: &mut Vec<u8>) { + self.0.encode(bytes); + } + + fn encoded_len(&self) -> Option<usize> { + self.0.encoded_len() + } +} + +/// Dummy prepare state. +#[derive(Clone, Copy, Debug, Default, PartialEq, Eq)] +pub struct PrepareState { + input_share: u8, + current_round: u32, +} + +impl Encode for PrepareState { + fn encode(&self, bytes: &mut Vec<u8>) { + self.input_share.encode(bytes); + self.current_round.encode(bytes); + } + + fn encoded_len(&self) -> Option<usize> { + Some(self.input_share.encoded_len()? + self.current_round.encoded_len()?) + } +} + +impl Decode for PrepareState { + fn decode(bytes: &mut Cursor<&[u8]>) -> Result<Self, CodecError> { + let input_share = u8::decode(bytes)?; + let current_round = u32::decode(bytes)?; + + Ok(Self { + input_share, + current_round, + }) + } +} + +/// Dummy aggregate share. +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +pub struct AggregateShare(pub u64); + +impl Aggregatable for AggregateShare { + type OutputShare = OutputShare; + + fn merge(&mut self, other: &Self) -> Result<(), VdafError> { + self.0 += other.0; + Ok(()) + } + + fn accumulate(&mut self, out_share: &Self::OutputShare) -> Result<(), VdafError> { + self.0 += u64::from(out_share.0); + Ok(()) + } +} + +impl From<OutputShare> for AggregateShare { + fn from(out_share: OutputShare) -> Self { + Self(u64::from(out_share.0)) + } +} + +impl Decode for AggregateShare { + fn decode(bytes: &mut Cursor<&[u8]>) -> Result<Self, CodecError> { + let val = u64::decode(bytes)?; + Ok(Self(val)) + } +} + +impl Encode for AggregateShare { + fn encode(&self, bytes: &mut Vec<u8>) { + self.0.encode(bytes) + } + + fn encoded_len(&self) -> Option<usize> { + self.0.encoded_len() + } +} diff --git a/third_party/rust/prio/src/vdaf/poplar1.rs b/third_party/rust/prio/src/vdaf/poplar1.rs new file mode 100644 index 0000000000..e8591f2049 --- /dev/null +++ b/third_party/rust/prio/src/vdaf/poplar1.rs @@ -0,0 +1,2465 @@ +// SPDX-License-Identifier: MPL-2.0 + +//! Implementation of Poplar1 as specified in [[draft-irtf-cfrg-vdaf-07]]. +//! +//! [draft-irtf-cfrg-vdaf-07]: https://datatracker.ietf.org/doc/draft-irtf-cfrg-vdaf/07/ + +use crate::{ + codec::{CodecError, Decode, Encode, ParameterizedDecode}, + field::{decode_fieldvec, merge_vector, Field255, Field64, FieldElement}, + idpf::{Idpf, IdpfInput, IdpfOutputShare, IdpfPublicShare, IdpfValue, RingBufferCache}, + prng::Prng, + vdaf::{ + xof::{Seed, Xof, XofShake128}, + Aggregatable, Aggregator, Client, Collector, PrepareTransition, Vdaf, VdafError, + }, +}; +use bitvec::{prelude::Lsb0, vec::BitVec}; +use rand_core::RngCore; +use std::{ + convert::TryFrom, + fmt::Debug, + io::{Cursor, Read}, + iter, + marker::PhantomData, + num::TryFromIntError, + ops::{Add, AddAssign, Sub}, +}; +use subtle::{Choice, ConditionallyNegatable, ConditionallySelectable, ConstantTimeEq}; + +const DST_SHARD_RANDOMNESS: u16 = 1; +const DST_CORR_INNER: u16 = 2; +const DST_CORR_LEAF: u16 = 3; +const DST_VERIFY_RANDOMNESS: u16 = 4; + +impl<P, const SEED_SIZE: usize> Poplar1<P, SEED_SIZE> { + /// Create an instance of [`Poplar1`]. The caller provides the bit length of each + /// measurement (`BITS` as defined in the [[draft-irtf-cfrg-vdaf-07]]). + /// + /// [draft-irtf-cfrg-vdaf-07]: https://datatracker.ietf.org/doc/draft-irtf-cfrg-vdaf/07/ + pub fn new(bits: usize) -> Self { + Self { + bits, + phantom: PhantomData, + } + } +} + +impl Poplar1<XofShake128, 16> { + /// Create an instance of [`Poplar1`] using [`XofShake128`]. The caller provides the bit length of + /// each measurement (`BITS` as defined in the [[draft-irtf-cfrg-vdaf-07]]). + /// + /// [draft-irtf-cfrg-vdaf-07]: https://datatracker.ietf.org/doc/draft-irtf-cfrg-vdaf/07/ + pub fn new_shake128(bits: usize) -> Self { + Poplar1::new(bits) + } +} + +/// The Poplar1 VDAF. +#[derive(Debug)] +pub struct Poplar1<P, const SEED_SIZE: usize> { + bits: usize, + phantom: PhantomData<P>, +} + +impl<P: Xof<SEED_SIZE>, const SEED_SIZE: usize> Poplar1<P, SEED_SIZE> { + /// Construct a `Prng` with the given seed and info-string suffix. + fn init_prng<I, B, F>( + seed: &[u8; SEED_SIZE], + usage: u16, + binder_chunks: I, + ) -> Prng<F, P::SeedStream> + where + I: IntoIterator<Item = B>, + B: AsRef<[u8]>, + P: Xof<SEED_SIZE>, + F: FieldElement, + { + let mut xof = P::init(seed, &Self::domain_separation_tag(usage)); + for binder_chunk in binder_chunks.into_iter() { + xof.update(binder_chunk.as_ref()); + } + Prng::from_seed_stream(xof.into_seed_stream()) + } +} + +impl<P, const SEED_SIZE: usize> Clone for Poplar1<P, SEED_SIZE> { + fn clone(&self) -> Self { + Self { + bits: self.bits, + phantom: PhantomData, + } + } +} + +/// Poplar1 public share. +/// +/// This is comprised of the correction words generated for the IDPF. +pub type Poplar1PublicShare = + IdpfPublicShare<Poplar1IdpfValue<Field64>, Poplar1IdpfValue<Field255>>; + +impl<P, const SEED_SIZE: usize> ParameterizedDecode<Poplar1<P, SEED_SIZE>> for Poplar1PublicShare { + fn decode_with_param( + poplar1: &Poplar1<P, SEED_SIZE>, + bytes: &mut Cursor<&[u8]>, + ) -> Result<Self, CodecError> { + Self::decode_with_param(&poplar1.bits, bytes) + } +} + +/// Poplar1 input share. +/// +/// This is comprised of an IDPF key share and the correlated randomness used to compute the sketch +/// during preparation. +#[derive(Debug, Clone)] +pub struct Poplar1InputShare<const SEED_SIZE: usize> { + /// IDPF key share. + idpf_key: Seed<16>, + + /// Seed used to generate the Aggregator's share of the correlated randomness used in the first + /// part of the sketch. + corr_seed: Seed<SEED_SIZE>, + + /// Aggregator's share of the correlated randomness used in the second part of the sketch. Used + /// for inner nodes of the IDPF tree. + corr_inner: Vec<[Field64; 2]>, + + /// Aggregator's share of the correlated randomness used in the second part of the sketch. Used + /// for leaf nodes of the IDPF tree. + corr_leaf: [Field255; 2], +} + +impl<const SEED_SIZE: usize> PartialEq for Poplar1InputShare<SEED_SIZE> { + fn eq(&self, other: &Self) -> bool { + self.ct_eq(other).into() + } +} + +impl<const SEED_SIZE: usize> Eq for Poplar1InputShare<SEED_SIZE> {} + +impl<const SEED_SIZE: usize> ConstantTimeEq for Poplar1InputShare<SEED_SIZE> { + fn ct_eq(&self, other: &Self) -> Choice { + // We short-circuit on the length of corr_inner being different. Only the content is + // protected. + if self.corr_inner.len() != other.corr_inner.len() { + return Choice::from(0); + } + + let mut res = self.idpf_key.ct_eq(&other.idpf_key) + & self.corr_seed.ct_eq(&other.corr_seed) + & self.corr_leaf.ct_eq(&other.corr_leaf); + for (x, y) in self.corr_inner.iter().zip(other.corr_inner.iter()) { + res &= x.ct_eq(y); + } + res + } +} + +impl<const SEED_SIZE: usize> Encode for Poplar1InputShare<SEED_SIZE> { + fn encode(&self, bytes: &mut Vec<u8>) { + self.idpf_key.encode(bytes); + self.corr_seed.encode(bytes); + for corr in self.corr_inner.iter() { + corr[0].encode(bytes); + corr[1].encode(bytes); + } + self.corr_leaf[0].encode(bytes); + self.corr_leaf[1].encode(bytes); + } + + fn encoded_len(&self) -> Option<usize> { + let mut len = 0; + len += SEED_SIZE; // idpf_key + len += SEED_SIZE; // corr_seed + len += self.corr_inner.len() * 2 * Field64::ENCODED_SIZE; // corr_inner + len += 2 * Field255::ENCODED_SIZE; // corr_leaf + Some(len) + } +} + +impl<'a, P, const SEED_SIZE: usize> ParameterizedDecode<(&'a Poplar1<P, SEED_SIZE>, usize)> + for Poplar1InputShare<SEED_SIZE> +{ + fn decode_with_param( + (poplar1, _agg_id): &(&'a Poplar1<P, SEED_SIZE>, usize), + bytes: &mut Cursor<&[u8]>, + ) -> Result<Self, CodecError> { + let idpf_key = Seed::decode(bytes)?; + let corr_seed = Seed::decode(bytes)?; + let mut corr_inner = Vec::with_capacity(poplar1.bits - 1); + for _ in 0..poplar1.bits - 1 { + corr_inner.push([Field64::decode(bytes)?, Field64::decode(bytes)?]); + } + let corr_leaf = [Field255::decode(bytes)?, Field255::decode(bytes)?]; + Ok(Self { + idpf_key, + corr_seed, + corr_inner, + corr_leaf, + }) + } +} + +/// Poplar1 preparation state. +#[derive(Clone, Debug)] +pub struct Poplar1PrepareState(PrepareStateVariant); + +impl PartialEq for Poplar1PrepareState { + fn eq(&self, other: &Self) -> bool { + self.ct_eq(other).into() + } +} + +impl Eq for Poplar1PrepareState {} + +impl ConstantTimeEq for Poplar1PrepareState { + fn ct_eq(&self, other: &Self) -> Choice { + self.0.ct_eq(&other.0) + } +} + +impl Encode for Poplar1PrepareState { + fn encode(&self, bytes: &mut Vec<u8>) { + self.0.encode(bytes) + } + + fn encoded_len(&self) -> Option<usize> { + self.0.encoded_len() + } +} + +impl<'a, P, const SEED_SIZE: usize> ParameterizedDecode<(&'a Poplar1<P, SEED_SIZE>, usize)> + for Poplar1PrepareState +{ + fn decode_with_param( + decoding_parameter: &(&'a Poplar1<P, SEED_SIZE>, usize), + bytes: &mut Cursor<&[u8]>, + ) -> Result<Self, CodecError> { + Ok(Self(PrepareStateVariant::decode_with_param( + decoding_parameter, + bytes, + )?)) + } +} + +#[derive(Clone, Debug)] +enum PrepareStateVariant { + Inner(PrepareState<Field64>), + Leaf(PrepareState<Field255>), +} + +impl PartialEq for PrepareStateVariant { + fn eq(&self, other: &Self) -> bool { + self.ct_eq(other).into() + } +} + +impl Eq for PrepareStateVariant {} + +impl ConstantTimeEq for PrepareStateVariant { + fn ct_eq(&self, other: &Self) -> Choice { + // We allow short-circuiting on the type (Inner vs Leaf). + match (self, other) { + (Self::Inner(self_val), Self::Inner(other_val)) => self_val.ct_eq(other_val), + (Self::Leaf(self_val), Self::Leaf(other_val)) => self_val.ct_eq(other_val), + _ => Choice::from(0), + } + } +} + +impl Encode for PrepareStateVariant { + fn encode(&self, bytes: &mut Vec<u8>) { + match self { + PrepareStateVariant::Inner(prep_state) => { + 0u8.encode(bytes); + prep_state.encode(bytes); + } + PrepareStateVariant::Leaf(prep_state) => { + 1u8.encode(bytes); + prep_state.encode(bytes); + } + } + } + + fn encoded_len(&self) -> Option<usize> { + Some( + 1 + match self { + PrepareStateVariant::Inner(prep_state) => prep_state.encoded_len()?, + PrepareStateVariant::Leaf(prep_state) => prep_state.encoded_len()?, + }, + ) + } +} + +impl<'a, P, const SEED_SIZE: usize> ParameterizedDecode<(&'a Poplar1<P, SEED_SIZE>, usize)> + for PrepareStateVariant +{ + fn decode_with_param( + decoding_parameter: &(&'a Poplar1<P, SEED_SIZE>, usize), + bytes: &mut Cursor<&[u8]>, + ) -> Result<Self, CodecError> { + match u8::decode(bytes)? { + 0 => { + let prep_state = PrepareState::decode_with_param(decoding_parameter, bytes)?; + Ok(Self::Inner(prep_state)) + } + 1 => { + let prep_state = PrepareState::decode_with_param(decoding_parameter, bytes)?; + Ok(Self::Leaf(prep_state)) + } + _ => Err(CodecError::UnexpectedValue), + } + } +} + +#[derive(Clone)] +struct PrepareState<F> { + sketch: SketchState<F>, + output_share: Vec<F>, +} + +impl<F: ConstantTimeEq> PartialEq for PrepareState<F> { + fn eq(&self, other: &Self) -> bool { + self.ct_eq(other).into() + } +} + +impl<F: ConstantTimeEq> Eq for PrepareState<F> {} + +impl<F: ConstantTimeEq> ConstantTimeEq for PrepareState<F> { + fn ct_eq(&self, other: &Self) -> Choice { + self.sketch.ct_eq(&other.sketch) & self.output_share.ct_eq(&other.output_share) + } +} + +impl<F> Debug for PrepareState<F> { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("PrepareState") + .field("sketch", &"[redacted]") + .field("output_share", &"[redacted]") + .finish() + } +} + +impl<F: FieldElement> Encode for PrepareState<F> { + fn encode(&self, bytes: &mut Vec<u8>) { + self.sketch.encode(bytes); + // `expect` safety: output_share's length is the same as the number of prefixes; the number + // of prefixes is capped at 2^32-1. + u32::try_from(self.output_share.len()) + .expect("Couldn't convert output_share length to u32") + .encode(bytes); + for elem in &self.output_share { + elem.encode(bytes); + } + } + + fn encoded_len(&self) -> Option<usize> { + Some(self.sketch.encoded_len()? + 4 + self.output_share.len() * F::ENCODED_SIZE) + } +} + +impl<'a, P, F: FieldElement, const SEED_SIZE: usize> + ParameterizedDecode<(&'a Poplar1<P, SEED_SIZE>, usize)> for PrepareState<F> +{ + fn decode_with_param( + decoding_parameter: &(&'a Poplar1<P, SEED_SIZE>, usize), + bytes: &mut Cursor<&[u8]>, + ) -> Result<Self, CodecError> { + let sketch = SketchState::<F>::decode_with_param(decoding_parameter, bytes)?; + let output_share_len = u32::decode(bytes)? + .try_into() + .map_err(|err: TryFromIntError| CodecError::Other(err.into()))?; + let output_share = iter::repeat_with(|| F::decode(bytes)) + .take(output_share_len) + .collect::<Result<_, _>>()?; + Ok(Self { + sketch, + output_share, + }) + } +} + +#[derive(Clone, Debug)] +enum SketchState<F> { + #[allow(non_snake_case)] + RoundOne { + A_share: F, + B_share: F, + is_leader: bool, + }, + RoundTwo, +} + +impl<F: ConstantTimeEq> PartialEq for SketchState<F> { + fn eq(&self, other: &Self) -> bool { + self.ct_eq(other).into() + } +} + +impl<F: ConstantTimeEq> Eq for SketchState<F> {} + +impl<F: ConstantTimeEq> ConstantTimeEq for SketchState<F> { + fn ct_eq(&self, other: &Self) -> Choice { + // We allow short-circuiting on the round (RoundOne vs RoundTwo), as well as is_leader for + // RoundOne comparisons. + match (self, other) { + ( + SketchState::RoundOne { + A_share: self_a_share, + B_share: self_b_share, + is_leader: self_is_leader, + }, + SketchState::RoundOne { + A_share: other_a_share, + B_share: other_b_share, + is_leader: other_is_leader, + }, + ) => { + if self_is_leader != other_is_leader { + return Choice::from(0); + } + + self_a_share.ct_eq(other_a_share) & self_b_share.ct_eq(other_b_share) + } + + (SketchState::RoundTwo, SketchState::RoundTwo) => Choice::from(1), + _ => Choice::from(0), + } + } +} + +impl<F: FieldElement> Encode for SketchState<F> { + fn encode(&self, bytes: &mut Vec<u8>) { + match self { + SketchState::RoundOne { + A_share, B_share, .. + } => { + 0u8.encode(bytes); + A_share.encode(bytes); + B_share.encode(bytes); + } + SketchState::RoundTwo => 1u8.encode(bytes), + } + } + + fn encoded_len(&self) -> Option<usize> { + Some( + 1 + match self { + SketchState::RoundOne { .. } => 2 * F::ENCODED_SIZE, + SketchState::RoundTwo => 0, + }, + ) + } +} + +impl<'a, P, F: FieldElement, const SEED_SIZE: usize> + ParameterizedDecode<(&'a Poplar1<P, SEED_SIZE>, usize)> for SketchState<F> +{ + #[allow(non_snake_case)] + fn decode_with_param( + (_, agg_id): &(&'a Poplar1<P, SEED_SIZE>, usize), + bytes: &mut Cursor<&[u8]>, + ) -> Result<Self, CodecError> { + match u8::decode(bytes)? { + 0 => { + let A_share = F::decode(bytes)?; + let B_share = F::decode(bytes)?; + let is_leader = agg_id == &0; + Ok(Self::RoundOne { + A_share, + B_share, + is_leader, + }) + } + 1 => Ok(Self::RoundTwo), + _ => Err(CodecError::UnexpectedValue), + } + } +} + +impl<F: FieldElement> SketchState<F> { + fn decode_sketch_share(&self, bytes: &mut Cursor<&[u8]>) -> Result<Vec<F>, CodecError> { + match self { + // The sketch share is three field elements. + Self::RoundOne { .. } => Ok(vec![ + F::decode(bytes)?, + F::decode(bytes)?, + F::decode(bytes)?, + ]), + // The sketch verifier share is one field element. + Self::RoundTwo => Ok(vec![F::decode(bytes)?]), + } + } + + fn decode_sketch(&self, bytes: &mut Cursor<&[u8]>) -> Result<Option<[F; 3]>, CodecError> { + match self { + // The sketch is three field elements. + Self::RoundOne { .. } => Ok(Some([ + F::decode(bytes)?, + F::decode(bytes)?, + F::decode(bytes)?, + ])), + // The sketch verifier should be zero if the sketch if valid. Instead of transmitting + // this zero over the wire, we just expect an empty message. + Self::RoundTwo => Ok(None), + } + } +} + +/// Poplar1 preparation message. +#[derive(Clone, Debug, PartialEq, Eq)] +pub struct Poplar1PrepareMessage(PrepareMessageVariant); + +#[derive(Clone, Debug, PartialEq, Eq)] +enum PrepareMessageVariant { + SketchInner([Field64; 3]), + SketchLeaf([Field255; 3]), + Done, +} + +impl Encode for Poplar1PrepareMessage { + fn encode(&self, bytes: &mut Vec<u8>) { + match self.0 { + PrepareMessageVariant::SketchInner(vec) => { + vec[0].encode(bytes); + vec[1].encode(bytes); + vec[2].encode(bytes); + } + PrepareMessageVariant::SketchLeaf(vec) => { + vec[0].encode(bytes); + vec[1].encode(bytes); + vec[2].encode(bytes); + } + PrepareMessageVariant::Done => (), + } + } + + fn encoded_len(&self) -> Option<usize> { + match self.0 { + PrepareMessageVariant::SketchInner(..) => Some(3 * Field64::ENCODED_SIZE), + PrepareMessageVariant::SketchLeaf(..) => Some(3 * Field255::ENCODED_SIZE), + PrepareMessageVariant::Done => Some(0), + } + } +} + +impl ParameterizedDecode<Poplar1PrepareState> for Poplar1PrepareMessage { + fn decode_with_param( + state: &Poplar1PrepareState, + bytes: &mut Cursor<&[u8]>, + ) -> Result<Self, CodecError> { + match state.0 { + PrepareStateVariant::Inner(ref state_variant) => Ok(Self( + state_variant + .sketch + .decode_sketch(bytes)? + .map_or(PrepareMessageVariant::Done, |sketch| { + PrepareMessageVariant::SketchInner(sketch) + }), + )), + PrepareStateVariant::Leaf(ref state_variant) => Ok(Self( + state_variant + .sketch + .decode_sketch(bytes)? + .map_or(PrepareMessageVariant::Done, |sketch| { + PrepareMessageVariant::SketchLeaf(sketch) + }), + )), + } + } +} + +/// A vector of field elements transmitted while evaluating Poplar1. +#[derive(Clone, Debug)] +pub enum Poplar1FieldVec { + /// Field type for inner nodes of the IDPF tree. + Inner(Vec<Field64>), + + /// Field type for leaf nodes of the IDPF tree. + Leaf(Vec<Field255>), +} + +impl Poplar1FieldVec { + fn zero(is_leaf: bool, len: usize) -> Self { + if is_leaf { + Self::Leaf(vec![<Field255 as FieldElement>::zero(); len]) + } else { + Self::Inner(vec![<Field64 as FieldElement>::zero(); len]) + } + } +} + +impl PartialEq for Poplar1FieldVec { + fn eq(&self, other: &Self) -> bool { + self.ct_eq(other).into() + } +} + +impl Eq for Poplar1FieldVec {} + +impl ConstantTimeEq for Poplar1FieldVec { + fn ct_eq(&self, other: &Self) -> Choice { + // We allow short-circuiting on the type (Inner vs Leaf). + match (self, other) { + (Poplar1FieldVec::Inner(self_val), Poplar1FieldVec::Inner(other_val)) => { + self_val.ct_eq(other_val) + } + (Poplar1FieldVec::Leaf(self_val), Poplar1FieldVec::Leaf(other_val)) => { + self_val.ct_eq(other_val) + } + _ => Choice::from(0), + } + } +} + +impl Encode for Poplar1FieldVec { + fn encode(&self, bytes: &mut Vec<u8>) { + match self { + Self::Inner(ref data) => { + for elem in data { + elem.encode(bytes); + } + } + Self::Leaf(ref data) => { + for elem in data { + elem.encode(bytes); + } + } + } + } + + fn encoded_len(&self) -> Option<usize> { + match self { + Self::Inner(ref data) => Some(Field64::ENCODED_SIZE * data.len()), + Self::Leaf(ref data) => Some(Field255::ENCODED_SIZE * data.len()), + } + } +} + +impl<'a, P: Xof<SEED_SIZE>, const SEED_SIZE: usize> + ParameterizedDecode<(&'a Poplar1<P, SEED_SIZE>, &'a Poplar1AggregationParam)> + for Poplar1FieldVec +{ + fn decode_with_param( + (poplar1, agg_param): &(&'a Poplar1<P, SEED_SIZE>, &'a Poplar1AggregationParam), + bytes: &mut Cursor<&[u8]>, + ) -> Result<Self, CodecError> { + if agg_param.level() == poplar1.bits - 1 { + decode_fieldvec(agg_param.prefixes().len(), bytes).map(Poplar1FieldVec::Leaf) + } else { + decode_fieldvec(agg_param.prefixes().len(), bytes).map(Poplar1FieldVec::Inner) + } + } +} + +impl ParameterizedDecode<Poplar1PrepareState> for Poplar1FieldVec { + fn decode_with_param( + state: &Poplar1PrepareState, + bytes: &mut Cursor<&[u8]>, + ) -> Result<Self, CodecError> { + match state.0 { + PrepareStateVariant::Inner(ref state_variant) => Ok(Poplar1FieldVec::Inner( + state_variant.sketch.decode_sketch_share(bytes)?, + )), + PrepareStateVariant::Leaf(ref state_variant) => Ok(Poplar1FieldVec::Leaf( + state_variant.sketch.decode_sketch_share(bytes)?, + )), + } + } +} + +impl Aggregatable for Poplar1FieldVec { + type OutputShare = Self; + + fn merge(&mut self, agg_share: &Self) -> Result<(), VdafError> { + match (self, agg_share) { + (Self::Inner(ref mut left), Self::Inner(right)) => Ok(merge_vector(left, right)?), + (Self::Leaf(ref mut left), Self::Leaf(right)) => Ok(merge_vector(left, right)?), + _ => Err(VdafError::Uncategorized( + "cannot merge leaf nodes wiith inner nodes".into(), + )), + } + } + + fn accumulate(&mut self, output_share: &Self) -> Result<(), VdafError> { + match (self, output_share) { + (Self::Inner(ref mut left), Self::Inner(right)) => Ok(merge_vector(left, right)?), + (Self::Leaf(ref mut left), Self::Leaf(right)) => Ok(merge_vector(left, right)?), + _ => Err(VdafError::Uncategorized( + "cannot accumulate leaf nodes with inner nodes".into(), + )), + } + } +} + +/// Poplar1 aggregation parameter. +/// +/// This includes an indication of what level of the IDPF tree is being evaluated and the set of +/// prefixes to evaluate at that level. +#[derive(Clone, Debug, Hash, PartialEq, Eq)] +pub struct Poplar1AggregationParam { + level: u16, + prefixes: Vec<IdpfInput>, +} + +impl Poplar1AggregationParam { + /// Construct an aggregation parameter from a set of candidate prefixes. + /// + /// # Errors + /// + /// * The list of prefixes is empty. + /// * The prefixes have different lengths (they must all be the same). + /// * The prefixes have length 0, or length longer than 2^16 bits. + /// * There are more than 2^32 - 1 prefixes. + /// * The prefixes are not unique. + /// * The prefixes are not in lexicographic order. + pub fn try_from_prefixes(prefixes: Vec<IdpfInput>) -> Result<Self, VdafError> { + if prefixes.is_empty() { + return Err(VdafError::Uncategorized( + "at least one prefix is required".into(), + )); + } + if u32::try_from(prefixes.len()).is_err() { + return Err(VdafError::Uncategorized("too many prefixes".into())); + } + + let len = prefixes[0].len(); + let mut last_prefix = None; + for prefix in prefixes.iter() { + if prefix.len() != len { + return Err(VdafError::Uncategorized( + "all prefixes must have the same length".into(), + )); + } + if let Some(last_prefix) = last_prefix { + if prefix <= last_prefix { + if prefix == last_prefix { + return Err(VdafError::Uncategorized( + "prefixes must be nonrepeating".into(), + )); + } else { + return Err(VdafError::Uncategorized( + "prefixes must be in lexicographic order".into(), + )); + } + } + } + last_prefix = Some(prefix); + } + + let level = len + .checked_sub(1) + .ok_or_else(|| VdafError::Uncategorized("prefixes are too short".into()))?; + let level = u16::try_from(level) + .map_err(|_| VdafError::Uncategorized("prefixes are too long".into()))?; + + Ok(Self { level, prefixes }) + } + + /// Return the level of the IDPF tree. + pub fn level(&self) -> usize { + usize::from(self.level) + } + + /// Return the prefixes. + pub fn prefixes(&self) -> &[IdpfInput] { + self.prefixes.as_ref() + } +} + +impl Encode for Poplar1AggregationParam { + fn encode(&self, bytes: &mut Vec<u8>) { + // Okay to unwrap because `try_from_prefixes()` checks this conversion succeeds. + let prefix_count = u32::try_from(self.prefixes.len()).unwrap(); + self.level.encode(bytes); + prefix_count.encode(bytes); + + // The encoding of the prefixes is defined by treating the IDPF indices as integers, + // shifting and ORing them together, and encoding the resulting arbitrary precision integer + // in big endian byte order. Thus, the first prefix will appear in the last encoded byte, + // aligned to its least significant bit. The last prefix will appear in the first encoded + // byte, not necessarily aligned to a byte boundary. If the highest bits in the first byte + // are unused, they will be set to zero. + + // When an IDPF index is treated as an integer, the first bit is the integer's most + // significant bit, and bits are subsequently processed in order of decreasing significance. + // Thus, setting aside the order of bytes, bits within each byte are ordered with the + // [`Msb0`](bitvec::prelude::Msb0) convention, not [`Lsb0`](bitvec::prelude::Msb0). Yet, + // the entire integer is aligned to the least significant bit of the last byte, so we + // could not use `Msb0` directly without padding adjustments. Instead, we use `Lsb0` + // throughout and reverse the bit order of each prefix. + + let mut packed = self + .prefixes + .iter() + .flat_map(|input| input.iter().rev()) + .collect::<BitVec<u8, Lsb0>>(); + packed.set_uninitialized(false); + let mut packed = packed.into_vec(); + packed.reverse(); + bytes.append(&mut packed); + } + + fn encoded_len(&self) -> Option<usize> { + let packed_bit_count = (usize::from(self.level) + 1) * self.prefixes.len(); + // 4 bytes for the number of prefixes, 2 bytes for the level, and a variable number of bytes + // for the packed prefixes themselves. + Some(6 + (packed_bit_count + 7) / 8) + } +} + +impl Decode for Poplar1AggregationParam { + fn decode(bytes: &mut Cursor<&[u8]>) -> Result<Self, CodecError> { + let level = u16::decode(bytes)?; + let prefix_count = + usize::try_from(u32::decode(bytes)?).map_err(|e| CodecError::Other(e.into()))?; + + let packed_bit_count = (usize::from(level) + 1) * prefix_count; + let mut packed = vec![0u8; (packed_bit_count + 7) / 8]; + bytes.read_exact(&mut packed)?; + if packed_bit_count % 8 != 0 { + let unused_bits = packed[0] >> (packed_bit_count % 8); + if unused_bits != 0 { + return Err(CodecError::UnexpectedValue); + } + } + packed.reverse(); + let bits = BitVec::<u8, Lsb0>::from_vec(packed); + + let prefixes = bits + .chunks_exact(usize::from(level) + 1) + .take(prefix_count) + .map(|chunk| IdpfInput::from(chunk.iter().rev().collect::<BitVec>())) + .collect::<Vec<IdpfInput>>(); + + Poplar1AggregationParam::try_from_prefixes(prefixes) + .map_err(|e| CodecError::Other(e.into())) + } +} + +impl<P: Xof<SEED_SIZE>, const SEED_SIZE: usize> Vdaf for Poplar1<P, SEED_SIZE> { + const ID: u32 = 0x00001000; + type Measurement = IdpfInput; + type AggregateResult = Vec<u64>; + type AggregationParam = Poplar1AggregationParam; + type PublicShare = Poplar1PublicShare; + type InputShare = Poplar1InputShare<SEED_SIZE>; + type OutputShare = Poplar1FieldVec; + type AggregateShare = Poplar1FieldVec; + + fn num_aggregators(&self) -> usize { + 2 + } +} + +impl<P: Xof<SEED_SIZE>, const SEED_SIZE: usize> Poplar1<P, SEED_SIZE> { + fn shard_with_random( + &self, + input: &IdpfInput, + nonce: &[u8; 16], + idpf_random: &[[u8; 16]; 2], + poplar_random: &[[u8; SEED_SIZE]; 3], + ) -> Result<(Poplar1PublicShare, Vec<Poplar1InputShare<SEED_SIZE>>), VdafError> { + if input.len() != self.bits { + return Err(VdafError::Uncategorized(format!( + "unexpected input length ({})", + input.len() + ))); + } + + // Generate the authenticator for each inner level of the IDPF tree. + let mut prng = + Self::init_prng::<_, _, Field64>(&poplar_random[2], DST_SHARD_RANDOMNESS, [&[]]); + let auth_inner: Vec<Field64> = (0..self.bits - 1).map(|_| prng.get()).collect(); + + // Generate the authenticator for the last level of the IDPF tree (i.e., the leaves). + // + // TODO(cjpatton) spec: Consider using a different XOF for the leaf and inner nodes. + // "Switching" the XOF between field types is awkward. + let mut prng = prng.into_new_field::<Field255>(); + let auth_leaf = prng.get(); + + // Generate the IDPF shares. + let idpf = Idpf::new((), ()); + let (public_share, [idpf_key_0, idpf_key_1]) = idpf.gen_with_random( + input, + auth_inner + .iter() + .map(|auth| Poplar1IdpfValue([Field64::one(), *auth])), + Poplar1IdpfValue([Field255::one(), auth_leaf]), + nonce, + idpf_random, + )?; + + // Generate the correlated randomness for the inner nodes. This includes additive shares of + // the random offsets `a, b, c` and additive shares of `A := -2*a + auth` and `B := a^2 + b + // - a*auth + c`, where `auth` is the authenticator for the level of the tree. These values + // are used, respectively, to compute and verify the sketch during the preparation phase. + // (See Section 4.2 of [BBCG+21].) + let corr_seed_0 = &poplar_random[0]; + let corr_seed_1 = &poplar_random[1]; + let mut prng = prng.into_new_field::<Field64>(); + let mut corr_prng_0 = Self::init_prng::<_, _, Field64>( + corr_seed_0, + DST_CORR_INNER, + [[0].as_slice(), nonce.as_slice()], + ); + let mut corr_prng_1 = Self::init_prng::<_, _, Field64>( + corr_seed_1, + DST_CORR_INNER, + [[1].as_slice(), nonce.as_slice()], + ); + let mut corr_inner_0 = Vec::with_capacity(self.bits - 1); + let mut corr_inner_1 = Vec::with_capacity(self.bits - 1); + for auth in auth_inner.into_iter() { + let (next_corr_inner_0, next_corr_inner_1) = + compute_next_corr_shares(&mut prng, &mut corr_prng_0, &mut corr_prng_1, auth); + corr_inner_0.push(next_corr_inner_0); + corr_inner_1.push(next_corr_inner_1); + } + + // Generate the correlated randomness for the leaf nodes. + let mut prng = prng.into_new_field::<Field255>(); + let mut corr_prng_0 = Self::init_prng::<_, _, Field255>( + corr_seed_0, + DST_CORR_LEAF, + [[0].as_slice(), nonce.as_slice()], + ); + let mut corr_prng_1 = Self::init_prng::<_, _, Field255>( + corr_seed_1, + DST_CORR_LEAF, + [[1].as_slice(), nonce.as_slice()], + ); + let (corr_leaf_0, corr_leaf_1) = + compute_next_corr_shares(&mut prng, &mut corr_prng_0, &mut corr_prng_1, auth_leaf); + + Ok(( + public_share, + vec![ + Poplar1InputShare { + idpf_key: idpf_key_0, + corr_seed: Seed::from_bytes(*corr_seed_0), + corr_inner: corr_inner_0, + corr_leaf: corr_leaf_0, + }, + Poplar1InputShare { + idpf_key: idpf_key_1, + corr_seed: Seed::from_bytes(*corr_seed_1), + corr_inner: corr_inner_1, + corr_leaf: corr_leaf_1, + }, + ], + )) + } +} + +impl<P: Xof<SEED_SIZE>, const SEED_SIZE: usize> Client<16> for Poplar1<P, SEED_SIZE> { + fn shard( + &self, + input: &IdpfInput, + nonce: &[u8; 16], + ) -> Result<(Self::PublicShare, Vec<Poplar1InputShare<SEED_SIZE>>), VdafError> { + let mut idpf_random = [[0u8; 16]; 2]; + let mut poplar_random = [[0u8; SEED_SIZE]; 3]; + for random_seed in idpf_random.iter_mut() { + getrandom::getrandom(random_seed)?; + } + for random_seed in poplar_random.iter_mut() { + getrandom::getrandom(random_seed)?; + } + self.shard_with_random(input, nonce, &idpf_random, &poplar_random) + } +} + +impl<P: Xof<SEED_SIZE>, const SEED_SIZE: usize> Aggregator<SEED_SIZE, 16> + for Poplar1<P, SEED_SIZE> +{ + type PrepareState = Poplar1PrepareState; + type PrepareShare = Poplar1FieldVec; + type PrepareMessage = Poplar1PrepareMessage; + + #[allow(clippy::type_complexity)] + fn prepare_init( + &self, + verify_key: &[u8; SEED_SIZE], + agg_id: usize, + agg_param: &Poplar1AggregationParam, + nonce: &[u8; 16], + public_share: &Poplar1PublicShare, + input_share: &Poplar1InputShare<SEED_SIZE>, + ) -> Result<(Poplar1PrepareState, Poplar1FieldVec), VdafError> { + let is_leader = match agg_id { + 0 => true, + 1 => false, + _ => { + return Err(VdafError::Uncategorized(format!( + "invalid aggregator ID ({agg_id})" + ))) + } + }; + + if usize::from(agg_param.level) < self.bits - 1 { + let mut corr_prng = Self::init_prng::<_, _, Field64>( + input_share.corr_seed.as_ref(), + DST_CORR_INNER, + [[agg_id as u8].as_slice(), nonce.as_slice()], + ); + // Fast-forward the correlated randomness XOF to the level of the tree that we are + // aggregating. + for _ in 0..3 * agg_param.level { + corr_prng.get(); + } + + let (output_share, sketch_share) = eval_and_sketch::<P, Field64, SEED_SIZE>( + verify_key, + agg_id, + nonce, + agg_param, + public_share, + &input_share.idpf_key, + &mut corr_prng, + )?; + + Ok(( + Poplar1PrepareState(PrepareStateVariant::Inner(PrepareState { + sketch: SketchState::RoundOne { + A_share: input_share.corr_inner[usize::from(agg_param.level)][0], + B_share: input_share.corr_inner[usize::from(agg_param.level)][1], + is_leader, + }, + output_share, + })), + Poplar1FieldVec::Inner(sketch_share), + )) + } else { + let corr_prng = Self::init_prng::<_, _, Field255>( + input_share.corr_seed.as_ref(), + DST_CORR_LEAF, + [[agg_id as u8].as_slice(), nonce.as_slice()], + ); + + let (output_share, sketch_share) = eval_and_sketch::<P, Field255, SEED_SIZE>( + verify_key, + agg_id, + nonce, + agg_param, + public_share, + &input_share.idpf_key, + &mut corr_prng.into_new_field(), + )?; + + Ok(( + Poplar1PrepareState(PrepareStateVariant::Leaf(PrepareState { + sketch: SketchState::RoundOne { + A_share: input_share.corr_leaf[0], + B_share: input_share.corr_leaf[1], + is_leader, + }, + output_share, + })), + Poplar1FieldVec::Leaf(sketch_share), + )) + } + } + + fn prepare_shares_to_prepare_message<M: IntoIterator<Item = Poplar1FieldVec>>( + &self, + _: &Poplar1AggregationParam, + inputs: M, + ) -> Result<Poplar1PrepareMessage, VdafError> { + let mut inputs = inputs.into_iter(); + let prep_share_0 = inputs + .next() + .ok_or_else(|| VdafError::Uncategorized("insufficient number of prep shares".into()))?; + let prep_share_1 = inputs + .next() + .ok_or_else(|| VdafError::Uncategorized("insufficient number of prep shares".into()))?; + if inputs.next().is_some() { + return Err(VdafError::Uncategorized( + "more prep shares than expected".into(), + )); + } + + match (prep_share_0, prep_share_1) { + (Poplar1FieldVec::Inner(share_0), Poplar1FieldVec::Inner(share_1)) => { + Ok(Poplar1PrepareMessage( + next_message(share_0, share_1)?.map_or(PrepareMessageVariant::Done, |sketch| { + PrepareMessageVariant::SketchInner(sketch) + }), + )) + } + (Poplar1FieldVec::Leaf(share_0), Poplar1FieldVec::Leaf(share_1)) => { + Ok(Poplar1PrepareMessage( + next_message(share_0, share_1)?.map_or(PrepareMessageVariant::Done, |sketch| { + PrepareMessageVariant::SketchLeaf(sketch) + }), + )) + } + _ => Err(VdafError::Uncategorized( + "received prep shares with mismatched field types".into(), + )), + } + } + + fn prepare_next( + &self, + state: Poplar1PrepareState, + msg: Poplar1PrepareMessage, + ) -> Result<PrepareTransition<Self, SEED_SIZE, 16>, VdafError> { + match (state.0, msg.0) { + // Round one + ( + PrepareStateVariant::Inner(PrepareState { + sketch: + SketchState::RoundOne { + A_share, + B_share, + is_leader, + }, + output_share, + }), + PrepareMessageVariant::SketchInner(sketch), + ) => Ok(PrepareTransition::Continue( + Poplar1PrepareState(PrepareStateVariant::Inner(PrepareState { + sketch: SketchState::RoundTwo, + output_share, + })), + Poplar1FieldVec::Inner(finish_sketch(sketch, A_share, B_share, is_leader)), + )), + ( + PrepareStateVariant::Leaf(PrepareState { + sketch: + SketchState::RoundOne { + A_share, + B_share, + is_leader, + }, + output_share, + }), + PrepareMessageVariant::SketchLeaf(sketch), + ) => Ok(PrepareTransition::Continue( + Poplar1PrepareState(PrepareStateVariant::Leaf(PrepareState { + sketch: SketchState::RoundTwo, + output_share, + })), + Poplar1FieldVec::Leaf(finish_sketch(sketch, A_share, B_share, is_leader)), + )), + + // Round two + ( + PrepareStateVariant::Inner(PrepareState { + sketch: SketchState::RoundTwo, + output_share, + }), + PrepareMessageVariant::Done, + ) => Ok(PrepareTransition::Finish(Poplar1FieldVec::Inner( + output_share, + ))), + ( + PrepareStateVariant::Leaf(PrepareState { + sketch: SketchState::RoundTwo, + output_share, + }), + PrepareMessageVariant::Done, + ) => Ok(PrepareTransition::Finish(Poplar1FieldVec::Leaf( + output_share, + ))), + + _ => Err(VdafError::Uncategorized( + "prep message field type does not match state".into(), + )), + } + } + + fn aggregate<M: IntoIterator<Item = Poplar1FieldVec>>( + &self, + agg_param: &Poplar1AggregationParam, + output_shares: M, + ) -> Result<Poplar1FieldVec, VdafError> { + aggregate( + usize::from(agg_param.level) == self.bits - 1, + agg_param.prefixes.len(), + output_shares, + ) + } +} + +impl<P: Xof<SEED_SIZE>, const SEED_SIZE: usize> Collector for Poplar1<P, SEED_SIZE> { + fn unshard<M: IntoIterator<Item = Poplar1FieldVec>>( + &self, + agg_param: &Poplar1AggregationParam, + agg_shares: M, + _num_measurements: usize, + ) -> Result<Vec<u64>, VdafError> { + let result = aggregate( + usize::from(agg_param.level) == self.bits - 1, + agg_param.prefixes.len(), + agg_shares, + )?; + + match result { + Poplar1FieldVec::Inner(vec) => Ok(vec.into_iter().map(u64::from).collect()), + Poplar1FieldVec::Leaf(vec) => Ok(vec + .into_iter() + .map(u64::try_from) + .collect::<Result<Vec<_>, _>>()?), + } + } +} + +impl From<IdpfOutputShare<Poplar1IdpfValue<Field64>, Poplar1IdpfValue<Field255>>> + for Poplar1IdpfValue<Field64> +{ + fn from( + out_share: IdpfOutputShare<Poplar1IdpfValue<Field64>, Poplar1IdpfValue<Field255>>, + ) -> Poplar1IdpfValue<Field64> { + match out_share { + IdpfOutputShare::Inner(array) => array, + IdpfOutputShare::Leaf(..) => panic!("tried to convert leaf share into inner field"), + } + } +} + +impl From<IdpfOutputShare<Poplar1IdpfValue<Field64>, Poplar1IdpfValue<Field255>>> + for Poplar1IdpfValue<Field255> +{ + fn from( + out_share: IdpfOutputShare<Poplar1IdpfValue<Field64>, Poplar1IdpfValue<Field255>>, + ) -> Poplar1IdpfValue<Field255> { + match out_share { + IdpfOutputShare::Inner(..) => panic!("tried to convert inner share into leaf field"), + IdpfOutputShare::Leaf(array) => array, + } + } +} + +/// Derive shares of the correlated randomness for the next level of the IDPF tree. +// +// TODO(cjpatton) spec: Consider deriving the shares of a, b, c for each level directly from the +// seed, rather than iteratively, as we do in Doplar. This would be more efficient for the +// Aggregators. As long as the Client isn't significantly slower, this should be a win. +#[allow(non_snake_case)] +fn compute_next_corr_shares<F: FieldElement + From<u64>, S: RngCore>( + prng: &mut Prng<F, S>, + corr_prng_0: &mut Prng<F, S>, + corr_prng_1: &mut Prng<F, S>, + auth: F, +) -> ([F; 2], [F; 2]) { + let two = F::from(2); + let a = corr_prng_0.get() + corr_prng_1.get(); + let b = corr_prng_0.get() + corr_prng_1.get(); + let c = corr_prng_0.get() + corr_prng_1.get(); + let A = -two * a + auth; + let B = a * a + b - a * auth + c; + let corr_1 = [prng.get(), prng.get()]; + let corr_0 = [A - corr_1[0], B - corr_1[1]]; + (corr_0, corr_1) +} + +/// Evaluate the IDPF at the given prefixes and compute the Aggregator's share of the sketch. +fn eval_and_sketch<P, F, const SEED_SIZE: usize>( + verify_key: &[u8; SEED_SIZE], + agg_id: usize, + nonce: &[u8; 16], + agg_param: &Poplar1AggregationParam, + public_share: &Poplar1PublicShare, + idpf_key: &Seed<16>, + corr_prng: &mut Prng<F, P::SeedStream>, +) -> Result<(Vec<F>, Vec<F>), VdafError> +where + P: Xof<SEED_SIZE>, + F: FieldElement, + Poplar1IdpfValue<F>: + From<IdpfOutputShare<Poplar1IdpfValue<Field64>, Poplar1IdpfValue<Field255>>>, +{ + // TODO(cjpatton) spec: Consider not encoding the prefixes here. + let mut verify_prng = Poplar1::<P, SEED_SIZE>::init_prng( + verify_key, + DST_VERIFY_RANDOMNESS, + [nonce.as_slice(), agg_param.level.to_be_bytes().as_slice()], + ); + + let mut out_share = Vec::with_capacity(agg_param.prefixes.len()); + let mut sketch_share = vec![ + corr_prng.get(), // a_share + corr_prng.get(), // b_share + corr_prng.get(), // c_share + ]; + + let mut idpf_eval_cache = RingBufferCache::new(agg_param.prefixes.len()); + let idpf = Idpf::<Poplar1IdpfValue<Field64>, Poplar1IdpfValue<Field255>>::new((), ()); + for prefix in agg_param.prefixes.iter() { + let share = Poplar1IdpfValue::<F>::from(idpf.eval( + agg_id, + public_share, + idpf_key, + prefix, + nonce, + &mut idpf_eval_cache, + )?); + + let r = verify_prng.get(); + let checked_data_share = share.0[0] * r; + sketch_share[0] += checked_data_share; + sketch_share[1] += checked_data_share * r; + sketch_share[2] += share.0[1] * r; + out_share.push(share.0[0]); + } + + Ok((out_share, sketch_share)) +} + +/// Compute the Aggregator's share of the sketch verifier. The shares should sum to zero. +#[allow(non_snake_case)] +fn finish_sketch<F: FieldElement>( + sketch: [F; 3], + A_share: F, + B_share: F, + is_leader: bool, +) -> Vec<F> { + let mut next_sketch_share = A_share * sketch[0] + B_share; + if !is_leader { + next_sketch_share += sketch[0] * sketch[0] - sketch[1] - sketch[2]; + } + vec![next_sketch_share] +} + +fn next_message<F: FieldElement>( + mut share_0: Vec<F>, + share_1: Vec<F>, +) -> Result<Option<[F; 3]>, VdafError> { + merge_vector(&mut share_0, &share_1)?; + + if share_0.len() == 1 { + if share_0[0] != F::zero() { + Err(VdafError::Uncategorized( + "sketch verification failed".into(), + )) // Invalid sketch + } else { + Ok(None) // Sketch verification succeeded + } + } else if share_0.len() == 3 { + Ok(Some([share_0[0], share_0[1], share_0[2]])) // Sketch verification continues + } else { + Err(VdafError::Uncategorized(format!( + "unexpected sketch length ({})", + share_0.len() + ))) + } +} + +fn aggregate<M: IntoIterator<Item = Poplar1FieldVec>>( + is_leaf: bool, + len: usize, + shares: M, +) -> Result<Poplar1FieldVec, VdafError> { + let mut result = Poplar1FieldVec::zero(is_leaf, len); + for share in shares.into_iter() { + result.accumulate(&share)?; + } + Ok(result) +} + +/// A vector of two field elements. +/// +/// This represents the values that Poplar1 programs into IDPFs while sharding. +#[derive(Debug, Clone, Copy)] +pub struct Poplar1IdpfValue<F>([F; 2]); + +impl<F> Poplar1IdpfValue<F> { + /// Create a new value from a pair of field elements. + pub fn new(array: [F; 2]) -> Self { + Self(array) + } +} + +impl<F> IdpfValue for Poplar1IdpfValue<F> +where + F: FieldElement, +{ + type ValueParameter = (); + + fn zero(_: &()) -> Self { + Self([F::zero(); 2]) + } + + fn generate<S: RngCore>(seed_stream: &mut S, _: &()) -> Self { + Self([F::generate(seed_stream, &()), F::generate(seed_stream, &())]) + } + + fn conditional_select(a: &Self, b: &Self, choice: Choice) -> Self { + ConditionallySelectable::conditional_select(a, b, choice) + } +} + +impl<F> Add for Poplar1IdpfValue<F> +where + F: Copy + Add<Output = F>, +{ + type Output = Self; + + fn add(self, rhs: Self) -> Self { + Self([self.0[0] + rhs.0[0], self.0[1] + rhs.0[1]]) + } +} + +impl<F> AddAssign for Poplar1IdpfValue<F> +where + F: Copy + AddAssign, +{ + fn add_assign(&mut self, rhs: Self) { + self.0[0] += rhs.0[0]; + self.0[1] += rhs.0[1]; + } +} + +impl<F> Sub for Poplar1IdpfValue<F> +where + F: Copy + Sub<Output = F>, +{ + type Output = Self; + + fn sub(self, rhs: Self) -> Self { + Self([self.0[0] - rhs.0[0], self.0[1] - rhs.0[1]]) + } +} + +impl<F> PartialEq for Poplar1IdpfValue<F> +where + F: PartialEq, +{ + fn eq(&self, other: &Self) -> bool { + self.0 == other.0 + } +} + +impl<F> ConstantTimeEq for Poplar1IdpfValue<F> +where + F: ConstantTimeEq, +{ + fn ct_eq(&self, other: &Self) -> Choice { + self.0.ct_eq(&other.0) + } +} + +impl<F> Encode for Poplar1IdpfValue<F> +where + F: FieldElement, +{ + fn encode(&self, bytes: &mut Vec<u8>) { + self.0[0].encode(bytes); + self.0[1].encode(bytes); + } + + fn encoded_len(&self) -> Option<usize> { + Some(F::ENCODED_SIZE * 2) + } +} + +impl<F> Decode for Poplar1IdpfValue<F> +where + F: Decode, +{ + fn decode(bytes: &mut Cursor<&[u8]>) -> Result<Self, CodecError> { + Ok(Self([F::decode(bytes)?, F::decode(bytes)?])) + } +} + +impl<F> ConditionallySelectable for Poplar1IdpfValue<F> +where + F: ConditionallySelectable, +{ + fn conditional_select(a: &Self, b: &Self, choice: subtle::Choice) -> Self { + Self([ + F::conditional_select(&a.0[0], &b.0[0], choice), + F::conditional_select(&a.0[1], &b.0[1], choice), + ]) + } +} + +impl<F> ConditionallyNegatable for Poplar1IdpfValue<F> +where + F: ConditionallyNegatable, +{ + fn conditional_negate(&mut self, choice: subtle::Choice) { + F::conditional_negate(&mut self.0[0], choice); + F::conditional_negate(&mut self.0[1], choice); + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::vdaf::{equality_comparison_test, run_vdaf_prepare}; + use assert_matches::assert_matches; + use rand::prelude::*; + use serde::Deserialize; + use std::collections::HashSet; + + fn test_prepare<P: Xof<SEED_SIZE>, const SEED_SIZE: usize>( + vdaf: &Poplar1<P, SEED_SIZE>, + verify_key: &[u8; SEED_SIZE], + nonce: &[u8; 16], + public_share: &Poplar1PublicShare, + input_shares: &[Poplar1InputShare<SEED_SIZE>], + agg_param: &Poplar1AggregationParam, + expected_result: Vec<u64>, + ) { + let out_shares = run_vdaf_prepare( + vdaf, + verify_key, + agg_param, + nonce, + public_share.clone(), + input_shares.to_vec(), + ) + .unwrap(); + + // Convert aggregate shares and unshard. + let agg_share_0 = vdaf.aggregate(agg_param, [out_shares[0].clone()]).unwrap(); + let agg_share_1 = vdaf.aggregate(agg_param, [out_shares[1].clone()]).unwrap(); + let result = vdaf + .unshard(agg_param, [agg_share_0, agg_share_1], 1) + .unwrap(); + assert_eq!( + result, expected_result, + "unexpected result (level={})", + agg_param.level + ); + } + + fn run_heavy_hitters<B: AsRef<[u8]>, P: Xof<SEED_SIZE>, const SEED_SIZE: usize>( + vdaf: &Poplar1<P, SEED_SIZE>, + verify_key: &[u8; SEED_SIZE], + threshold: usize, + measurements: impl IntoIterator<Item = B>, + expected_result: impl IntoIterator<Item = B>, + ) { + let mut rng = thread_rng(); + + // Sharding step + let reports: Vec<( + [u8; 16], + Poplar1PublicShare, + Vec<Poplar1InputShare<SEED_SIZE>>, + )> = measurements + .into_iter() + .map(|measurement| { + let nonce = rng.gen(); + let (public_share, input_shares) = vdaf + .shard(&IdpfInput::from_bytes(measurement.as_ref()), &nonce) + .unwrap(); + (nonce, public_share, input_shares) + }) + .collect(); + + let mut agg_param = Poplar1AggregationParam { + level: 0, + prefixes: vec![ + IdpfInput::from_bools(&[false]), + IdpfInput::from_bools(&[true]), + ], + }; + + let mut agg_result = Vec::new(); + for level in 0..vdaf.bits { + let mut out_shares_0 = Vec::with_capacity(reports.len()); + let mut out_shares_1 = Vec::with_capacity(reports.len()); + + // Preparation step + for (nonce, public_share, input_shares) in reports.iter() { + let out_shares = run_vdaf_prepare( + vdaf, + verify_key, + &agg_param, + nonce, + public_share.clone(), + input_shares.to_vec(), + ) + .unwrap(); + + out_shares_0.push(out_shares[0].clone()); + out_shares_1.push(out_shares[1].clone()); + } + + // Aggregation step + let agg_share_0 = vdaf.aggregate(&agg_param, out_shares_0).unwrap(); + let agg_share_1 = vdaf.aggregate(&agg_param, out_shares_1).unwrap(); + + // Unsharding step + agg_result = vdaf + .unshard(&agg_param, [agg_share_0, agg_share_1], reports.len()) + .unwrap(); + + agg_param.level += 1; + + // Unless this is the last level of the tree, construct the next set of candidate + // prefixes. + if level < vdaf.bits - 1 { + let mut next_prefixes = Vec::new(); + for (prefix, count) in agg_param.prefixes.into_iter().zip(agg_result.iter()) { + if *count >= threshold as u64 { + next_prefixes.push(prefix.clone_with_suffix(&[false])); + next_prefixes.push(prefix.clone_with_suffix(&[true])); + } + } + + agg_param.prefixes = next_prefixes; + } + } + + let got: HashSet<IdpfInput> = agg_param + .prefixes + .into_iter() + .zip(agg_result.iter()) + .filter(|(_prefix, count)| **count >= threshold as u64) + .map(|(prefix, _count)| prefix) + .collect(); + + let want: HashSet<IdpfInput> = expected_result + .into_iter() + .map(|bytes| IdpfInput::from_bytes(bytes.as_ref())) + .collect(); + + assert_eq!(got, want); + } + + #[test] + fn shard_prepare() { + let mut rng = thread_rng(); + let vdaf = Poplar1::new_shake128(64); + let verify_key = rng.gen(); + let input = IdpfInput::from_bytes(b"12341324"); + let nonce = rng.gen(); + let (public_share, input_shares) = vdaf.shard(&input, &nonce).unwrap(); + + test_prepare( + &vdaf, + &verify_key, + &nonce, + &public_share, + &input_shares, + &Poplar1AggregationParam { + level: 7, + prefixes: vec![ + IdpfInput::from_bytes(b"0"), + IdpfInput::from_bytes(b"1"), + IdpfInput::from_bytes(b"2"), + IdpfInput::from_bytes(b"f"), + ], + }, + vec![0, 1, 0, 0], + ); + + for level in 0..vdaf.bits { + test_prepare( + &vdaf, + &verify_key, + &nonce, + &public_share, + &input_shares, + &Poplar1AggregationParam { + level: level as u16, + prefixes: vec![input.prefix(level)], + }, + vec![1], + ); + } + } + + #[test] + fn heavy_hitters() { + let mut rng = thread_rng(); + let verify_key = rng.gen(); + let vdaf = Poplar1::new_shake128(8); + + run_heavy_hitters( + &vdaf, + &verify_key, + 2, // threshold + [ + "a", "b", "c", "d", "e", "f", "g", "g", "h", "i", "i", "i", "j", "j", "k", "l", + ], // measurements + ["g", "i", "j"], // heavy hitters + ); + } + + #[test] + fn encoded_len() { + // Input share + let input_share = Poplar1InputShare { + idpf_key: Seed::<16>::generate().unwrap(), + corr_seed: Seed::<16>::generate().unwrap(), + corr_inner: vec![ + [Field64::one(), <Field64 as FieldElement>::zero()], + [Field64::one(), <Field64 as FieldElement>::zero()], + [Field64::one(), <Field64 as FieldElement>::zero()], + ], + corr_leaf: [Field255::one(), <Field255 as FieldElement>::zero()], + }; + assert_eq!( + input_share.get_encoded().len(), + input_share.encoded_len().unwrap() + ); + + // Prepaare message variants + let prep_msg = Poplar1PrepareMessage(PrepareMessageVariant::SketchInner([ + Field64::one(), + Field64::one(), + Field64::one(), + ])); + assert_eq!( + prep_msg.get_encoded().len(), + prep_msg.encoded_len().unwrap() + ); + let prep_msg = Poplar1PrepareMessage(PrepareMessageVariant::SketchLeaf([ + Field255::one(), + Field255::one(), + Field255::one(), + ])); + assert_eq!( + prep_msg.get_encoded().len(), + prep_msg.encoded_len().unwrap() + ); + let prep_msg = Poplar1PrepareMessage(PrepareMessageVariant::Done); + assert_eq!( + prep_msg.get_encoded().len(), + prep_msg.encoded_len().unwrap() + ); + + // Field vector variants. + let field_vec = Poplar1FieldVec::Inner(vec![Field64::one(); 23]); + assert_eq!( + field_vec.get_encoded().len(), + field_vec.encoded_len().unwrap() + ); + let field_vec = Poplar1FieldVec::Leaf(vec![Field255::one(); 23]); + assert_eq!( + field_vec.get_encoded().len(), + field_vec.encoded_len().unwrap() + ); + + // Aggregation parameter. + let agg_param = Poplar1AggregationParam::try_from_prefixes(Vec::from([ + IdpfInput::from_bytes(b"ab"), + IdpfInput::from_bytes(b"cd"), + ])) + .unwrap(); + assert_eq!( + agg_param.get_encoded().len(), + agg_param.encoded_len().unwrap() + ); + let agg_param = Poplar1AggregationParam::try_from_prefixes(Vec::from([ + IdpfInput::from_bools(&[false]), + IdpfInput::from_bools(&[true]), + ])) + .unwrap(); + assert_eq!( + agg_param.get_encoded().len(), + agg_param.encoded_len().unwrap() + ); + } + + #[test] + fn round_trip_prepare_state() { + let vdaf = Poplar1::new_shake128(1); + for (agg_id, prep_state) in [ + ( + 0, + Poplar1PrepareState(PrepareStateVariant::Inner(PrepareState { + sketch: SketchState::RoundOne { + A_share: Field64::from(0), + B_share: Field64::from(1), + is_leader: true, + }, + output_share: Vec::from([Field64::from(2), Field64::from(3), Field64::from(4)]), + })), + ), + ( + 1, + Poplar1PrepareState(PrepareStateVariant::Inner(PrepareState { + sketch: SketchState::RoundOne { + A_share: Field64::from(5), + B_share: Field64::from(6), + is_leader: false, + }, + output_share: Vec::from([Field64::from(7), Field64::from(8), Field64::from(9)]), + })), + ), + ( + 0, + Poplar1PrepareState(PrepareStateVariant::Inner(PrepareState { + sketch: SketchState::RoundTwo, + output_share: Vec::from([ + Field64::from(10), + Field64::from(11), + Field64::from(12), + ]), + })), + ), + ( + 1, + Poplar1PrepareState(PrepareStateVariant::Inner(PrepareState { + sketch: SketchState::RoundTwo, + output_share: Vec::from([ + Field64::from(13), + Field64::from(14), + Field64::from(15), + ]), + })), + ), + ( + 0, + Poplar1PrepareState(PrepareStateVariant::Leaf(PrepareState { + sketch: SketchState::RoundOne { + A_share: Field255::from(16), + B_share: Field255::from(17), + is_leader: true, + }, + output_share: Vec::from([ + Field255::from(18), + Field255::from(19), + Field255::from(20), + ]), + })), + ), + ( + 1, + Poplar1PrepareState(PrepareStateVariant::Leaf(PrepareState { + sketch: SketchState::RoundOne { + A_share: Field255::from(21), + B_share: Field255::from(22), + is_leader: false, + }, + output_share: Vec::from([ + Field255::from(23), + Field255::from(24), + Field255::from(25), + ]), + })), + ), + ( + 0, + Poplar1PrepareState(PrepareStateVariant::Leaf(PrepareState { + sketch: SketchState::RoundTwo, + output_share: Vec::from([ + Field255::from(26), + Field255::from(27), + Field255::from(28), + ]), + })), + ), + ( + 1, + Poplar1PrepareState(PrepareStateVariant::Leaf(PrepareState { + sketch: SketchState::RoundTwo, + output_share: Vec::from([ + Field255::from(29), + Field255::from(30), + Field255::from(31), + ]), + })), + ), + ] { + let encoded_prep_state = prep_state.get_encoded(); + assert_eq!(prep_state.encoded_len(), Some(encoded_prep_state.len())); + let decoded_prep_state = + Poplar1PrepareState::get_decoded_with_param(&(&vdaf, agg_id), &encoded_prep_state) + .unwrap(); + assert_eq!(prep_state, decoded_prep_state); + } + } + + #[test] + fn round_trip_agg_param() { + // These test cases were generated using the reference Sage implementation. + // (https://github.com/cfrg/draft-irtf-cfrg-vdaf/tree/main/poc) Sage statements used to + // generate each test case are given in comments. + for (prefixes, reference_encoding) in [ + // poplar.encode_agg_param(0, [0]) + ( + Vec::from([IdpfInput::from_bools(&[false])]), + [0, 0, 0, 0, 0, 1, 0].as_slice(), + ), + // poplar.encode_agg_param(0, [1]) + ( + Vec::from([IdpfInput::from_bools(&[true])]), + [0, 0, 0, 0, 0, 1, 1].as_slice(), + ), + // poplar.encode_agg_param(0, [0, 1]) + ( + Vec::from([ + IdpfInput::from_bools(&[false]), + IdpfInput::from_bools(&[true]), + ]), + [0, 0, 0, 0, 0, 2, 2].as_slice(), + ), + // poplar.encode_agg_param(1, [0b00, 0b01, 0b10, 0b11]) + ( + Vec::from([ + IdpfInput::from_bools(&[false, false]), + IdpfInput::from_bools(&[false, true]), + IdpfInput::from_bools(&[true, false]), + IdpfInput::from_bools(&[true, true]), + ]), + [0, 1, 0, 0, 0, 4, 0xe4].as_slice(), + ), + // poplar.encode_agg_param(1, [0b00, 0b10, 0b11]) + ( + Vec::from([ + IdpfInput::from_bools(&[false, false]), + IdpfInput::from_bools(&[true, false]), + IdpfInput::from_bools(&[true, true]), + ]), + [0, 1, 0, 0, 0, 3, 0x38].as_slice(), + ), + // poplar.encode_agg_param(2, [0b000, 0b001, 0b010, 0b011, 0b100, 0b101, 0b110, 0b111]) + ( + Vec::from([ + IdpfInput::from_bools(&[false, false, false]), + IdpfInput::from_bools(&[false, false, true]), + IdpfInput::from_bools(&[false, true, false]), + IdpfInput::from_bools(&[false, true, true]), + IdpfInput::from_bools(&[true, false, false]), + IdpfInput::from_bools(&[true, false, true]), + IdpfInput::from_bools(&[true, true, false]), + IdpfInput::from_bools(&[true, true, true]), + ]), + [0, 2, 0, 0, 0, 8, 0xfa, 0xc6, 0x88].as_slice(), + ), + // poplar.encode_agg_param(9, [0b01_1011_0010, 0b10_1101_1010]) + ( + Vec::from([ + IdpfInput::from_bools(&[ + false, true, true, false, true, true, false, false, true, false, + ]), + IdpfInput::from_bools(&[ + true, false, true, true, false, true, true, false, true, false, + ]), + ]), + [0, 9, 0, 0, 0, 2, 0x0b, 0x69, 0xb2].as_slice(), + ), + // poplar.encode_agg_param(15, [0xcafe]) + ( + Vec::from([IdpfInput::from_bytes(b"\xca\xfe")]), + [0, 15, 0, 0, 0, 1, 0xca, 0xfe].as_slice(), + ), + ] { + let agg_param = Poplar1AggregationParam::try_from_prefixes(prefixes).unwrap(); + let encoded = agg_param.get_encoded(); + assert_eq!(encoded, reference_encoding); + let decoded = Poplar1AggregationParam::get_decoded(reference_encoding).unwrap(); + assert_eq!(decoded, agg_param); + } + } + + #[test] + fn agg_param_wrong_unused_bit() { + let err = Poplar1AggregationParam::get_decoded(&[0, 0, 0, 0, 0, 1, 2]).unwrap_err(); + assert_matches!(err, CodecError::UnexpectedValue); + } + + #[test] + fn agg_param_ordering() { + let err = Poplar1AggregationParam::get_decoded(&[0, 0, 0, 0, 0, 2, 1]).unwrap_err(); + assert_matches!(err, CodecError::Other(_)); + let err = Poplar1AggregationParam::get_decoded(&[0, 0, 0, 0, 0, 2, 0]).unwrap_err(); + assert_matches!(err, CodecError::Other(_)); + let err = Poplar1AggregationParam::get_decoded(&[0, 0, 0, 0, 0, 2, 3]).unwrap_err(); + assert_matches!(err, CodecError::Other(_)); + } + + #[derive(Debug, Deserialize)] + struct HexEncoded(#[serde(with = "hex")] Vec<u8>); + + impl AsRef<[u8]> for HexEncoded { + fn as_ref(&self) -> &[u8] { + &self.0 + } + } + + #[derive(Debug, Deserialize)] + struct PoplarTestVector { + agg_param: (usize, Vec<u64>), + agg_result: Vec<u64>, + agg_shares: Vec<HexEncoded>, + bits: usize, + prep: Vec<PreparationTestVector>, + verify_key: HexEncoded, + } + + #[derive(Debug, Deserialize)] + struct PreparationTestVector { + input_shares: Vec<HexEncoded>, + measurement: u64, + nonce: HexEncoded, + out_shares: Vec<Vec<HexEncoded>>, + prep_messages: Vec<HexEncoded>, + prep_shares: Vec<Vec<HexEncoded>>, + public_share: HexEncoded, + rand: HexEncoded, + } + + fn check_test_vec(input: &str) { + let test_vector: PoplarTestVector = serde_json::from_str(input).unwrap(); + assert_eq!(test_vector.prep.len(), 1); + let prep = &test_vector.prep[0]; + let measurement_bits = (0..test_vector.bits) + .rev() + .map(|i| (prep.measurement >> i) & 1 != 0) + .collect::<BitVec>(); + let measurement = IdpfInput::from(measurement_bits); + let (agg_param_level, agg_param_prefixes_int) = test_vector.agg_param; + let agg_param_prefixes = agg_param_prefixes_int + .iter() + .map(|int| { + let bits = (0..=agg_param_level) + .rev() + .map(|i| (*int >> i) & 1 != 0) + .collect::<BitVec>(); + bits.into() + }) + .collect::<Vec<IdpfInput>>(); + let agg_param = Poplar1AggregationParam::try_from_prefixes(agg_param_prefixes).unwrap(); + let verify_key = test_vector.verify_key.as_ref().try_into().unwrap(); + let nonce = prep.nonce.as_ref().try_into().unwrap(); + + let mut idpf_random = [[0u8; 16]; 2]; + let mut poplar_random = [[0u8; 16]; 3]; + for (input, output) in prep + .rand + .as_ref() + .chunks_exact(16) + .zip(idpf_random.iter_mut().chain(poplar_random.iter_mut())) + { + output.copy_from_slice(input); + } + + // Shard measurement. + let poplar = Poplar1::new_shake128(test_vector.bits); + let (public_share, input_shares) = poplar + .shard_with_random(&measurement, &nonce, &idpf_random, &poplar_random) + .unwrap(); + + // Run aggregation. + let (init_prep_state_0, init_prep_share_0) = poplar + .prepare_init( + &verify_key, + 0, + &agg_param, + &nonce, + &public_share, + &input_shares[0], + ) + .unwrap(); + let (init_prep_state_1, init_prep_share_1) = poplar + .prepare_init( + &verify_key, + 1, + &agg_param, + &nonce, + &public_share, + &input_shares[1], + ) + .unwrap(); + + let r1_prep_msg = poplar + .prepare_shares_to_prepare_message( + &agg_param, + [init_prep_share_0.clone(), init_prep_share_1.clone()], + ) + .unwrap(); + + let (r1_prep_state_0, r1_prep_share_0) = assert_matches!( + poplar + .prepare_next(init_prep_state_0.clone(), r1_prep_msg.clone()) + .unwrap(), + PrepareTransition::Continue(state, share) => (state, share) + ); + let (r1_prep_state_1, r1_prep_share_1) = assert_matches!( + poplar + .prepare_next(init_prep_state_1.clone(), r1_prep_msg.clone()) + .unwrap(), + PrepareTransition::Continue(state, share) => (state, share) + ); + + let r2_prep_msg = poplar + .prepare_shares_to_prepare_message( + &agg_param, + [r1_prep_share_0.clone(), r1_prep_share_1.clone()], + ) + .unwrap(); + + let out_share_0 = assert_matches!( + poplar + .prepare_next(r1_prep_state_0.clone(), r2_prep_msg.clone()) + .unwrap(), + PrepareTransition::Finish(out) => out + ); + let out_share_1 = assert_matches!( + poplar + .prepare_next(r1_prep_state_1, r2_prep_msg.clone()) + .unwrap(), + PrepareTransition::Finish(out) => out + ); + + let agg_share_0 = poplar.aggregate(&agg_param, [out_share_0.clone()]).unwrap(); + let agg_share_1 = poplar.aggregate(&agg_param, [out_share_1.clone()]).unwrap(); + + // Collect result. + let agg_result = poplar + .unshard(&agg_param, [agg_share_0.clone(), agg_share_1.clone()], 1) + .unwrap(); + + // Check all intermediate results against the test vector, and exercise both encoding and decoding. + assert_eq!( + public_share, + Poplar1PublicShare::get_decoded_with_param(&poplar, prep.public_share.as_ref()) + .unwrap() + ); + assert_eq!(&public_share.get_encoded(), prep.public_share.as_ref()); + assert_eq!( + input_shares[0], + Poplar1InputShare::get_decoded_with_param(&(&poplar, 0), prep.input_shares[0].as_ref()) + .unwrap() + ); + assert_eq!( + &input_shares[0].get_encoded(), + prep.input_shares[0].as_ref() + ); + assert_eq!( + input_shares[1], + Poplar1InputShare::get_decoded_with_param(&(&poplar, 1), prep.input_shares[1].as_ref()) + .unwrap() + ); + assert_eq!( + &input_shares[1].get_encoded(), + prep.input_shares[1].as_ref() + ); + assert_eq!( + init_prep_share_0, + Poplar1FieldVec::get_decoded_with_param( + &init_prep_state_0, + prep.prep_shares[0][0].as_ref() + ) + .unwrap() + ); + assert_eq!( + &init_prep_share_0.get_encoded(), + prep.prep_shares[0][0].as_ref() + ); + assert_eq!( + init_prep_share_1, + Poplar1FieldVec::get_decoded_with_param( + &init_prep_state_1, + prep.prep_shares[0][1].as_ref() + ) + .unwrap() + ); + assert_eq!( + &init_prep_share_1.get_encoded(), + prep.prep_shares[0][1].as_ref() + ); + assert_eq!( + r1_prep_msg, + Poplar1PrepareMessage::get_decoded_with_param( + &init_prep_state_0, + prep.prep_messages[0].as_ref() + ) + .unwrap() + ); + assert_eq!(&r1_prep_msg.get_encoded(), prep.prep_messages[0].as_ref()); + + assert_eq!( + r1_prep_share_0, + Poplar1FieldVec::get_decoded_with_param( + &r1_prep_state_0, + prep.prep_shares[1][0].as_ref() + ) + .unwrap() + ); + assert_eq!( + &r1_prep_share_0.get_encoded(), + prep.prep_shares[1][0].as_ref() + ); + assert_eq!( + r1_prep_share_1, + Poplar1FieldVec::get_decoded_with_param( + &r1_prep_state_0, + prep.prep_shares[1][1].as_ref() + ) + .unwrap() + ); + assert_eq!( + &r1_prep_share_1.get_encoded(), + prep.prep_shares[1][1].as_ref() + ); + assert_eq!( + r2_prep_msg, + Poplar1PrepareMessage::get_decoded_with_param( + &r1_prep_state_0, + prep.prep_messages[1].as_ref() + ) + .unwrap() + ); + assert_eq!(&r2_prep_msg.get_encoded(), prep.prep_messages[1].as_ref()); + for (out_share, expected_out_share) in [ + (out_share_0, &prep.out_shares[0]), + (out_share_1, &prep.out_shares[1]), + ] { + match out_share { + Poplar1FieldVec::Inner(vec) => { + assert_eq!(vec.len(), expected_out_share.len()); + for (element, expected) in vec.iter().zip(expected_out_share.iter()) { + assert_eq!(&element.get_encoded(), expected.as_ref()); + } + } + Poplar1FieldVec::Leaf(vec) => { + assert_eq!(vec.len(), expected_out_share.len()); + for (element, expected) in vec.iter().zip(expected_out_share.iter()) { + assert_eq!(&element.get_encoded(), expected.as_ref()); + } + } + }; + } + assert_eq!( + agg_share_0, + Poplar1FieldVec::get_decoded_with_param( + &(&poplar, &agg_param), + test_vector.agg_shares[0].as_ref() + ) + .unwrap() + ); + + assert_eq!( + &agg_share_0.get_encoded(), + test_vector.agg_shares[0].as_ref() + ); + assert_eq!( + agg_share_1, + Poplar1FieldVec::get_decoded_with_param( + &(&poplar, &agg_param), + test_vector.agg_shares[1].as_ref() + ) + .unwrap() + ); + assert_eq!( + &agg_share_1.get_encoded(), + test_vector.agg_shares[1].as_ref() + ); + assert_eq!(agg_result, test_vector.agg_result); + } + + #[test] + fn test_vec_poplar1_0() { + check_test_vec(include_str!("test_vec/07/Poplar1_0.json")); + } + + #[test] + fn test_vec_poplar1_1() { + check_test_vec(include_str!("test_vec/07/Poplar1_1.json")); + } + + #[test] + fn test_vec_poplar1_2() { + check_test_vec(include_str!("test_vec/07/Poplar1_2.json")); + } + + #[test] + fn test_vec_poplar1_3() { + check_test_vec(include_str!("test_vec/07/Poplar1_3.json")); + } + + #[test] + fn input_share_equality_test() { + equality_comparison_test(&[ + // Default. + Poplar1InputShare { + idpf_key: Seed([0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]), + corr_seed: Seed([16, 17, 18]), + corr_inner: Vec::from([ + [Field64::from(19), Field64::from(20)], + [Field64::from(21), Field64::from(22)], + [Field64::from(23), Field64::from(24)], + ]), + corr_leaf: [Field255::from(25), Field255::from(26)], + }, + // Modified idpf_key. + Poplar1InputShare { + idpf_key: Seed([15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0]), + corr_seed: Seed([16, 17, 18]), + corr_inner: Vec::from([ + [Field64::from(19), Field64::from(20)], + [Field64::from(21), Field64::from(22)], + [Field64::from(23), Field64::from(24)], + ]), + corr_leaf: [Field255::from(25), Field255::from(26)], + }, + // Modified corr_seed. + Poplar1InputShare { + idpf_key: Seed([0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]), + corr_seed: Seed([18, 17, 16]), + corr_inner: Vec::from([ + [Field64::from(19), Field64::from(20)], + [Field64::from(21), Field64::from(22)], + [Field64::from(23), Field64::from(24)], + ]), + corr_leaf: [Field255::from(25), Field255::from(26)], + }, + // Modified corr_inner. + Poplar1InputShare { + idpf_key: Seed([0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]), + corr_seed: Seed([16, 17, 18]), + corr_inner: Vec::from([ + [Field64::from(24), Field64::from(23)], + [Field64::from(22), Field64::from(21)], + [Field64::from(20), Field64::from(19)], + ]), + corr_leaf: [Field255::from(25), Field255::from(26)], + }, + // Modified corr_leaf. + Poplar1InputShare { + idpf_key: Seed([0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]), + corr_seed: Seed([16, 17, 18]), + corr_inner: Vec::from([ + [Field64::from(19), Field64::from(20)], + [Field64::from(21), Field64::from(22)], + [Field64::from(23), Field64::from(24)], + ]), + corr_leaf: [Field255::from(26), Field255::from(25)], + }, + ]) + } + + #[test] + fn prepare_state_equality_test() { + // This test effectively covers PrepareStateVariant, PrepareState, SketchState as well. + equality_comparison_test(&[ + // Inner, round one. (default) + Poplar1PrepareState(PrepareStateVariant::Inner(PrepareState { + sketch: SketchState::RoundOne { + A_share: Field64::from(0), + B_share: Field64::from(1), + is_leader: false, + }, + output_share: Vec::from([Field64::from(2), Field64::from(3)]), + })), + // Inner, round one, modified A_share. + Poplar1PrepareState(PrepareStateVariant::Inner(PrepareState { + sketch: SketchState::RoundOne { + A_share: Field64::from(100), + B_share: Field64::from(1), + is_leader: false, + }, + output_share: Vec::from([Field64::from(2), Field64::from(3)]), + })), + // Inner, round one, modified B_share. + Poplar1PrepareState(PrepareStateVariant::Inner(PrepareState { + sketch: SketchState::RoundOne { + A_share: Field64::from(0), + B_share: Field64::from(101), + is_leader: false, + }, + output_share: Vec::from([Field64::from(2), Field64::from(3)]), + })), + // Inner, round one, modified is_leader. + Poplar1PrepareState(PrepareStateVariant::Inner(PrepareState { + sketch: SketchState::RoundOne { + A_share: Field64::from(0), + B_share: Field64::from(1), + is_leader: true, + }, + output_share: Vec::from([Field64::from(2), Field64::from(3)]), + })), + // Inner, round one, modified output_share. + Poplar1PrepareState(PrepareStateVariant::Inner(PrepareState { + sketch: SketchState::RoundOne { + A_share: Field64::from(0), + B_share: Field64::from(1), + is_leader: false, + }, + output_share: Vec::from([Field64::from(3), Field64::from(2)]), + })), + // Inner, round two. (default) + Poplar1PrepareState(PrepareStateVariant::Inner(PrepareState { + sketch: SketchState::RoundTwo, + output_share: Vec::from([Field64::from(2), Field64::from(3)]), + })), + // Inner, round two, modified output_share. + Poplar1PrepareState(PrepareStateVariant::Inner(PrepareState { + sketch: SketchState::RoundTwo, + output_share: Vec::from([Field64::from(3), Field64::from(2)]), + })), + // Leaf, round one. (default) + Poplar1PrepareState(PrepareStateVariant::Leaf(PrepareState { + sketch: SketchState::RoundOne { + A_share: Field255::from(0), + B_share: Field255::from(1), + is_leader: false, + }, + output_share: Vec::from([Field255::from(2), Field255::from(3)]), + })), + // Leaf, round one, modified A_share. + Poplar1PrepareState(PrepareStateVariant::Leaf(PrepareState { + sketch: SketchState::RoundOne { + A_share: Field255::from(100), + B_share: Field255::from(1), + is_leader: false, + }, + output_share: Vec::from([Field255::from(2), Field255::from(3)]), + })), + // Leaf, round one, modified B_share. + Poplar1PrepareState(PrepareStateVariant::Leaf(PrepareState { + sketch: SketchState::RoundOne { + A_share: Field255::from(0), + B_share: Field255::from(101), + is_leader: false, + }, + output_share: Vec::from([Field255::from(2), Field255::from(3)]), + })), + // Leaf, round one, modified is_leader. + Poplar1PrepareState(PrepareStateVariant::Leaf(PrepareState { + sketch: SketchState::RoundOne { + A_share: Field255::from(0), + B_share: Field255::from(1), + is_leader: true, + }, + output_share: Vec::from([Field255::from(2), Field255::from(3)]), + })), + // Leaf, round one, modified output_share. + Poplar1PrepareState(PrepareStateVariant::Leaf(PrepareState { + sketch: SketchState::RoundOne { + A_share: Field255::from(0), + B_share: Field255::from(1), + is_leader: false, + }, + output_share: Vec::from([Field255::from(3), Field255::from(2)]), + })), + // Leaf, round two. (default) + Poplar1PrepareState(PrepareStateVariant::Leaf(PrepareState { + sketch: SketchState::RoundTwo, + output_share: Vec::from([Field255::from(2), Field255::from(3)]), + })), + // Leaf, round two, modified output_share. + Poplar1PrepareState(PrepareStateVariant::Leaf(PrepareState { + sketch: SketchState::RoundTwo, + output_share: Vec::from([Field255::from(3), Field255::from(2)]), + })), + ]) + } + + #[test] + fn field_vec_equality_test() { + equality_comparison_test(&[ + // Inner. (default) + Poplar1FieldVec::Inner(Vec::from([Field64::from(0), Field64::from(1)])), + // Inner, modified value. + Poplar1FieldVec::Inner(Vec::from([Field64::from(1), Field64::from(0)])), + // Leaf. (deafult) + Poplar1FieldVec::Leaf(Vec::from([Field255::from(0), Field255::from(1)])), + // Leaf, modified value. + Poplar1FieldVec::Leaf(Vec::from([Field255::from(1), Field255::from(0)])), + ]) + } +} diff --git a/third_party/rust/prio/src/vdaf/prio2.rs b/third_party/rust/prio/src/vdaf/prio2.rs new file mode 100644 index 0000000000..4669c47d00 --- /dev/null +++ b/third_party/rust/prio/src/vdaf/prio2.rs @@ -0,0 +1,543 @@ +// SPDX-License-Identifier: MPL-2.0 + +//! Backwards-compatible port of the ENPA Prio system to a VDAF. + +use crate::{ + codec::{CodecError, Decode, Encode, ParameterizedDecode}, + field::{ + decode_fieldvec, FftFriendlyFieldElement, FieldElement, FieldElementWithInteger, FieldPrio2, + }, + prng::Prng, + vdaf::{ + prio2::{ + client::{self as v2_client, proof_length}, + server as v2_server, + }, + xof::Seed, + Aggregatable, AggregateShare, Aggregator, Client, Collector, OutputShare, + PrepareTransition, Share, ShareDecodingParameter, Vdaf, VdafError, + }, +}; +use hmac::{Hmac, Mac}; +use rand_core::RngCore; +use sha2::Sha256; +use std::{convert::TryFrom, io::Cursor}; +use subtle::{Choice, ConstantTimeEq}; + +mod client; +mod server; +#[cfg(test)] +mod test_vector; + +/// The Prio2 VDAF. It supports the same measurement type as +/// [`Prio3SumVec`](crate::vdaf::prio3::Prio3SumVec) with `bits == 1` but uses the proof system and +/// finite field deployed in ENPA. +#[derive(Clone, Debug)] +pub struct Prio2 { + input_len: usize, +} + +impl Prio2 { + /// Returns an instance of the VDAF for the given input length. + pub fn new(input_len: usize) -> Result<Self, VdafError> { + let n = (input_len + 1).next_power_of_two(); + if let Ok(size) = u32::try_from(2 * n) { + if size > FieldPrio2::generator_order() { + return Err(VdafError::Uncategorized( + "input size exceeds field capacity".into(), + )); + } + } else { + return Err(VdafError::Uncategorized( + "input size exceeds memory capacity".into(), + )); + } + + Ok(Prio2 { input_len }) + } + + /// Prepare an input share for aggregation using the given field element `query_rand` to + /// compute the verifier share. + /// + /// In the [`Aggregator`] trait implementation for [`Prio2`], the query randomness is computed + /// jointly by the Aggregators. This method is designed to be used in applications, like ENPA, + /// in which the query randomness is instead chosen by a third-party. + pub fn prepare_init_with_query_rand( + &self, + query_rand: FieldPrio2, + input_share: &Share<FieldPrio2, 32>, + is_leader: bool, + ) -> Result<(Prio2PrepareState, Prio2PrepareShare), VdafError> { + let expanded_data: Option<Vec<FieldPrio2>> = match input_share { + Share::Leader(_) => None, + Share::Helper(ref seed) => { + let prng = Prng::from_prio2_seed(seed.as_ref()); + Some(prng.take(proof_length(self.input_len)).collect()) + } + }; + let data = match input_share { + Share::Leader(ref data) => data, + Share::Helper(_) => expanded_data.as_ref().unwrap(), + }; + + let verifier_share = v2_server::generate_verification_message( + self.input_len, + query_rand, + data, // Combined input and proof shares + is_leader, + ) + .map_err(|e| VdafError::Uncategorized(e.to_string()))?; + + Ok(( + Prio2PrepareState(input_share.truncated(self.input_len)), + Prio2PrepareShare(verifier_share), + )) + } + + /// Choose a random point for polynomial evaluation. + /// + /// The point returned is not one of the roots used for polynomial interpolation. + pub(crate) fn choose_eval_at<S>(&self, prng: &mut Prng<FieldPrio2, S>) -> FieldPrio2 + where + S: RngCore, + { + // Make sure the query randomness isn't a root of unity. Evaluating the proof at any of + // these points would be a privacy violation, since these points were used by the prover to + // construct the wire polynomials. + let n = (self.input_len + 1).next_power_of_two(); + let proof_length = 2 * n; + loop { + let eval_at: FieldPrio2 = prng.get(); + // Unwrap safety: the constructor checks that this conversion succeeds. + if eval_at.pow(u32::try_from(proof_length).unwrap()) != FieldPrio2::one() { + return eval_at; + } + } + } +} + +impl Vdaf for Prio2 { + const ID: u32 = 0xFFFF0000; + type Measurement = Vec<u32>; + type AggregateResult = Vec<u32>; + type AggregationParam = (); + type PublicShare = (); + type InputShare = Share<FieldPrio2, 32>; + type OutputShare = OutputShare<FieldPrio2>; + type AggregateShare = AggregateShare<FieldPrio2>; + + fn num_aggregators(&self) -> usize { + // Prio2 can easily be extended to support more than two Aggregators. + 2 + } +} + +impl Client<16> for Prio2 { + fn shard( + &self, + measurement: &Vec<u32>, + _nonce: &[u8; 16], + ) -> Result<(Self::PublicShare, Vec<Share<FieldPrio2, 32>>), VdafError> { + if measurement.len() != self.input_len { + return Err(VdafError::Uncategorized("incorrect input length".into())); + } + let mut input: Vec<FieldPrio2> = Vec::with_capacity(measurement.len()); + for int in measurement { + input.push((*int).into()); + } + + let mut mem = v2_client::ClientMemory::new(self.input_len)?; + let copy_data = |share_data: &mut [FieldPrio2]| { + share_data[..].clone_from_slice(&input); + }; + let mut leader_data = mem.prove_with(self.input_len, copy_data); + + let helper_seed = Seed::generate()?; + let helper_prng = Prng::from_prio2_seed(helper_seed.as_ref()); + for (s1, d) in leader_data.iter_mut().zip(helper_prng.into_iter()) { + *s1 -= d; + } + + Ok(( + (), + vec![Share::Leader(leader_data), Share::Helper(helper_seed)], + )) + } +} + +/// State of each [`Aggregator`] during the Preparation phase. +#[derive(Clone, Debug)] +pub struct Prio2PrepareState(Share<FieldPrio2, 32>); + +impl PartialEq for Prio2PrepareState { + fn eq(&self, other: &Self) -> bool { + self.ct_eq(other).into() + } +} + +impl Eq for Prio2PrepareState {} + +impl ConstantTimeEq for Prio2PrepareState { + fn ct_eq(&self, other: &Self) -> Choice { + self.0.ct_eq(&other.0) + } +} + +impl Encode for Prio2PrepareState { + fn encode(&self, bytes: &mut Vec<u8>) { + self.0.encode(bytes); + } + + fn encoded_len(&self) -> Option<usize> { + self.0.encoded_len() + } +} + +impl<'a> ParameterizedDecode<(&'a Prio2, usize)> for Prio2PrepareState { + fn decode_with_param( + (prio2, agg_id): &(&'a Prio2, usize), + bytes: &mut Cursor<&[u8]>, + ) -> Result<Self, CodecError> { + let share_decoder = if *agg_id == 0 { + ShareDecodingParameter::Leader(prio2.input_len) + } else { + ShareDecodingParameter::Helper + }; + let out_share = Share::decode_with_param(&share_decoder, bytes)?; + Ok(Self(out_share)) + } +} + +/// Message emitted by each [`Aggregator`] during the Preparation phase. +#[derive(Clone, Debug)] +pub struct Prio2PrepareShare(v2_server::VerificationMessage<FieldPrio2>); + +impl Encode for Prio2PrepareShare { + fn encode(&self, bytes: &mut Vec<u8>) { + self.0.f_r.encode(bytes); + self.0.g_r.encode(bytes); + self.0.h_r.encode(bytes); + } + + fn encoded_len(&self) -> Option<usize> { + Some(FieldPrio2::ENCODED_SIZE * 3) + } +} + +impl ParameterizedDecode<Prio2PrepareState> for Prio2PrepareShare { + fn decode_with_param( + _state: &Prio2PrepareState, + bytes: &mut Cursor<&[u8]>, + ) -> Result<Self, CodecError> { + Ok(Self(v2_server::VerificationMessage { + f_r: FieldPrio2::decode(bytes)?, + g_r: FieldPrio2::decode(bytes)?, + h_r: FieldPrio2::decode(bytes)?, + })) + } +} + +impl Aggregator<32, 16> for Prio2 { + type PrepareState = Prio2PrepareState; + type PrepareShare = Prio2PrepareShare; + type PrepareMessage = (); + + fn prepare_init( + &self, + agg_key: &[u8; 32], + agg_id: usize, + _agg_param: &Self::AggregationParam, + nonce: &[u8; 16], + _public_share: &Self::PublicShare, + input_share: &Share<FieldPrio2, 32>, + ) -> Result<(Prio2PrepareState, Prio2PrepareShare), VdafError> { + let is_leader = role_try_from(agg_id)?; + + // In the ENPA Prio system, the query randomness is generated by a third party and + // distributed to the Aggregators after they receive their input shares. In a VDAF, shared + // randomness is derived from a nonce selected by the client. For Prio2 we compute the + // query using HMAC-SHA256 evaluated over the nonce. + // + // Unwrap safety: new_from_slice() is infallible for Hmac. + let mut mac = Hmac::<Sha256>::new_from_slice(agg_key).unwrap(); + mac.update(nonce); + let hmac_tag = mac.finalize(); + let mut prng = Prng::from_prio2_seed(&hmac_tag.into_bytes().into()); + let query_rand = self.choose_eval_at(&mut prng); + + self.prepare_init_with_query_rand(query_rand, input_share, is_leader) + } + + fn prepare_shares_to_prepare_message<M: IntoIterator<Item = Prio2PrepareShare>>( + &self, + _: &Self::AggregationParam, + inputs: M, + ) -> Result<(), VdafError> { + let verifier_shares: Vec<v2_server::VerificationMessage<FieldPrio2>> = + inputs.into_iter().map(|msg| msg.0).collect(); + if verifier_shares.len() != 2 { + return Err(VdafError::Uncategorized( + "wrong number of verifier shares".into(), + )); + } + + if !v2_server::is_valid_share(&verifier_shares[0], &verifier_shares[1]) { + return Err(VdafError::Uncategorized( + "proof verifier check failed".into(), + )); + } + + Ok(()) + } + + fn prepare_next( + &self, + state: Prio2PrepareState, + _input: (), + ) -> Result<PrepareTransition<Self, 32, 16>, VdafError> { + let data = match state.0 { + Share::Leader(data) => data, + Share::Helper(seed) => { + let prng = Prng::from_prio2_seed(seed.as_ref()); + prng.take(self.input_len).collect() + } + }; + Ok(PrepareTransition::Finish(OutputShare::from(data))) + } + + fn aggregate<M: IntoIterator<Item = OutputShare<FieldPrio2>>>( + &self, + _agg_param: &Self::AggregationParam, + out_shares: M, + ) -> Result<AggregateShare<FieldPrio2>, VdafError> { + let mut agg_share = AggregateShare(vec![FieldPrio2::zero(); self.input_len]); + for out_share in out_shares.into_iter() { + agg_share.accumulate(&out_share)?; + } + + Ok(agg_share) + } +} + +impl Collector for Prio2 { + fn unshard<M: IntoIterator<Item = AggregateShare<FieldPrio2>>>( + &self, + _agg_param: &Self::AggregationParam, + agg_shares: M, + _num_measurements: usize, + ) -> Result<Vec<u32>, VdafError> { + let mut agg = AggregateShare(vec![FieldPrio2::zero(); self.input_len]); + for agg_share in agg_shares.into_iter() { + agg.merge(&agg_share)?; + } + + Ok(agg.0.into_iter().map(u32::from).collect()) + } +} + +impl<'a> ParameterizedDecode<(&'a Prio2, usize)> for Share<FieldPrio2, 32> { + fn decode_with_param( + (prio2, agg_id): &(&'a Prio2, usize), + bytes: &mut Cursor<&[u8]>, + ) -> Result<Self, CodecError> { + let is_leader = role_try_from(*agg_id).map_err(|e| CodecError::Other(Box::new(e)))?; + let decoder = if is_leader { + ShareDecodingParameter::Leader(proof_length(prio2.input_len)) + } else { + ShareDecodingParameter::Helper + }; + + Share::decode_with_param(&decoder, bytes) + } +} + +impl<'a, F> ParameterizedDecode<(&'a Prio2, &'a ())> for OutputShare<F> +where + F: FieldElement, +{ + fn decode_with_param( + (prio2, _): &(&'a Prio2, &'a ()), + bytes: &mut Cursor<&[u8]>, + ) -> Result<Self, CodecError> { + decode_fieldvec(prio2.input_len, bytes).map(Self) + } +} + +impl<'a, F> ParameterizedDecode<(&'a Prio2, &'a ())> for AggregateShare<F> +where + F: FieldElement, +{ + fn decode_with_param( + (prio2, _): &(&'a Prio2, &'a ()), + bytes: &mut Cursor<&[u8]>, + ) -> Result<Self, CodecError> { + decode_fieldvec(prio2.input_len, bytes).map(Self) + } +} + +fn role_try_from(agg_id: usize) -> Result<bool, VdafError> { + match agg_id { + 0 => Ok(true), + 1 => Ok(false), + _ => Err(VdafError::Uncategorized("unexpected aggregator id".into())), + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::vdaf::{ + equality_comparison_test, fieldvec_roundtrip_test, prio2::test_vector::Priov2TestVector, + run_vdaf, + }; + use assert_matches::assert_matches; + use rand::prelude::*; + + #[test] + fn run_prio2() { + let prio2 = Prio2::new(6).unwrap(); + + assert_eq!( + run_vdaf( + &prio2, + &(), + [ + vec![0, 0, 0, 0, 1, 0], + vec![0, 1, 0, 0, 0, 0], + vec![0, 1, 1, 0, 0, 0], + vec![1, 1, 1, 0, 0, 0], + vec![0, 0, 0, 0, 1, 1], + ] + ) + .unwrap(), + vec![1, 3, 2, 0, 2, 1], + ); + } + + #[test] + fn prepare_state_serialization() { + let mut rng = thread_rng(); + let verify_key = rng.gen::<[u8; 32]>(); + let nonce = rng.gen::<[u8; 16]>(); + let data = vec![0, 0, 1, 1, 0]; + let prio2 = Prio2::new(data.len()).unwrap(); + let (public_share, input_shares) = prio2.shard(&data, &nonce).unwrap(); + for (agg_id, input_share) in input_shares.iter().enumerate() { + let (prepare_state, prepare_share) = prio2 + .prepare_init( + &verify_key, + agg_id, + &(), + &[0; 16], + &public_share, + input_share, + ) + .unwrap(); + + let encoded_prepare_state = prepare_state.get_encoded(); + let decoded_prepare_state = Prio2PrepareState::get_decoded_with_param( + &(&prio2, agg_id), + &encoded_prepare_state, + ) + .expect("failed to decode prepare state"); + assert_eq!(decoded_prepare_state, prepare_state); + assert_eq!( + prepare_state.encoded_len().unwrap(), + encoded_prepare_state.len() + ); + + let encoded_prepare_share = prepare_share.get_encoded(); + let decoded_prepare_share = + Prio2PrepareShare::get_decoded_with_param(&prepare_state, &encoded_prepare_share) + .expect("failed to decode prepare share"); + assert_eq!(decoded_prepare_share.0.f_r, prepare_share.0.f_r); + assert_eq!(decoded_prepare_share.0.g_r, prepare_share.0.g_r); + assert_eq!(decoded_prepare_share.0.h_r, prepare_share.0.h_r); + assert_eq!( + prepare_share.encoded_len().unwrap(), + encoded_prepare_share.len() + ); + } + } + + #[test] + fn roundtrip_output_share() { + let vdaf = Prio2::new(31).unwrap(); + fieldvec_roundtrip_test::<FieldPrio2, Prio2, OutputShare<FieldPrio2>>(&vdaf, &(), 31); + } + + #[test] + fn roundtrip_aggregate_share() { + let vdaf = Prio2::new(31).unwrap(); + fieldvec_roundtrip_test::<FieldPrio2, Prio2, AggregateShare<FieldPrio2>>(&vdaf, &(), 31); + } + + #[test] + fn priov2_backward_compatibility() { + let test_vector: Priov2TestVector = + serde_json::from_str(include_str!("test_vec/prio2/fieldpriov2.json")).unwrap(); + let vdaf = Prio2::new(test_vector.dimension).unwrap(); + let mut leader_output_shares = Vec::new(); + let mut helper_output_shares = Vec::new(); + for (server_1_share, server_2_share) in test_vector + .server_1_decrypted_shares + .iter() + .zip(&test_vector.server_2_decrypted_shares) + { + let input_share_1 = Share::get_decoded_with_param(&(&vdaf, 0), server_1_share).unwrap(); + let input_share_2 = Share::get_decoded_with_param(&(&vdaf, 1), server_2_share).unwrap(); + let (prepare_state_1, prepare_share_1) = vdaf + .prepare_init(&[0; 32], 0, &(), &[0; 16], &(), &input_share_1) + .unwrap(); + let (prepare_state_2, prepare_share_2) = vdaf + .prepare_init(&[0; 32], 1, &(), &[0; 16], &(), &input_share_2) + .unwrap(); + vdaf.prepare_shares_to_prepare_message(&(), [prepare_share_1, prepare_share_2]) + .unwrap(); + let transition_1 = vdaf.prepare_next(prepare_state_1, ()).unwrap(); + let output_share_1 = + assert_matches!(transition_1, PrepareTransition::Finish(out) => out); + let transition_2 = vdaf.prepare_next(prepare_state_2, ()).unwrap(); + let output_share_2 = + assert_matches!(transition_2, PrepareTransition::Finish(out) => out); + leader_output_shares.push(output_share_1); + helper_output_shares.push(output_share_2); + } + + let leader_aggregate_share = vdaf.aggregate(&(), leader_output_shares).unwrap(); + let helper_aggregate_share = vdaf.aggregate(&(), helper_output_shares).unwrap(); + let aggregate_result = vdaf + .unshard( + &(), + [leader_aggregate_share, helper_aggregate_share], + test_vector.server_1_decrypted_shares.len(), + ) + .unwrap(); + let reconstructed = aggregate_result + .into_iter() + .map(FieldPrio2::from) + .collect::<Vec<_>>(); + + assert_eq!(reconstructed, test_vector.reference_sum); + } + + #[test] + fn prepare_state_equality_test() { + equality_comparison_test(&[ + Prio2PrepareState(Share::Leader(Vec::from([ + FieldPrio2::from(0), + FieldPrio2::from(1), + ]))), + Prio2PrepareState(Share::Leader(Vec::from([ + FieldPrio2::from(1), + FieldPrio2::from(0), + ]))), + Prio2PrepareState(Share::Helper(Seed( + (0..32).collect::<Vec<_>>().try_into().unwrap(), + ))), + Prio2PrepareState(Share::Helper(Seed( + (1..33).collect::<Vec<_>>().try_into().unwrap(), + ))), + ]) + } +} diff --git a/third_party/rust/prio/src/vdaf/prio2/client.rs b/third_party/rust/prio/src/vdaf/prio2/client.rs new file mode 100644 index 0000000000..dbce39ee3f --- /dev/null +++ b/third_party/rust/prio/src/vdaf/prio2/client.rs @@ -0,0 +1,306 @@ +// Copyright (c) 2020 Apple Inc. +// SPDX-License-Identifier: MPL-2.0 + +//! Primitives for the Prio2 client. + +use crate::{ + field::{FftFriendlyFieldElement, FieldError}, + polynomial::{poly_fft, PolyAuxMemory}, + prng::{Prng, PrngError}, + vdaf::{xof::SeedStreamAes128, VdafError}, +}; + +use std::convert::TryFrom; + +/// Errors that might be emitted by the client. +#[derive(Debug, thiserror::Error)] +pub(crate) enum ClientError { + /// PRNG error + #[error("prng error: {0}")] + Prng(#[from] PrngError), + /// VDAF error + #[error("vdaf error: {0}")] + Vdaf(#[from] VdafError), + /// failure when calling getrandom(). + #[error("getrandom: {0}")] + GetRandom(#[from] getrandom::Error), +} + +/// Serialization errors +#[derive(Debug, thiserror::Error)] +pub enum SerializeError { + /// Emitted by `unpack_proof[_mut]` if the serialized share+proof has the wrong length + #[error("serialized input has wrong length")] + UnpackInputSizeMismatch, + /// Finite field operation error. + #[error("finite field operation error")] + Field(#[from] FieldError), +} + +#[derive(Debug)] +pub(crate) struct ClientMemory<F> { + prng: Prng<F, SeedStreamAes128>, + points_f: Vec<F>, + points_g: Vec<F>, + evals_f: Vec<F>, + evals_g: Vec<F>, + poly_mem: PolyAuxMemory<F>, +} + +impl<F: FftFriendlyFieldElement> ClientMemory<F> { + pub(crate) fn new(dimension: usize) -> Result<Self, VdafError> { + let n = (dimension + 1).next_power_of_two(); + if let Ok(size) = F::Integer::try_from(2 * n) { + if size > F::generator_order() { + return Err(VdafError::Uncategorized( + "input size exceeds field capacity".into(), + )); + } + } else { + return Err(VdafError::Uncategorized( + "input size exceeds field capacity".into(), + )); + } + + Ok(Self { + prng: Prng::new()?, + points_f: vec![F::zero(); n], + points_g: vec![F::zero(); n], + evals_f: vec![F::zero(); 2 * n], + evals_g: vec![F::zero(); 2 * n], + poly_mem: PolyAuxMemory::new(n), + }) + } +} + +impl<F: FftFriendlyFieldElement> ClientMemory<F> { + pub(crate) fn prove_with<G>(&mut self, dimension: usize, init_function: G) -> Vec<F> + where + G: FnOnce(&mut [F]), + { + let mut proof = vec![F::zero(); proof_length(dimension)]; + // unpack one long vector to different subparts + let unpacked = unpack_proof_mut(&mut proof, dimension).unwrap(); + // initialize the data part + init_function(unpacked.data); + // fill in the rest + construct_proof( + unpacked.data, + dimension, + unpacked.f0, + unpacked.g0, + unpacked.h0, + unpacked.points_h_packed, + self, + ); + + proof + } +} + +/// Returns the number of field elements in the proof for given dimension of +/// data elements +/// +/// Proof is a vector, where the first `dimension` elements are the data +/// elements, the next 3 elements are the zero terms for polynomials f, g and h +/// and the remaining elements are non-zero points of h(x). +pub(crate) fn proof_length(dimension: usize) -> usize { + // number of data items + number of zero terms + N + dimension + 3 + (dimension + 1).next_power_of_two() +} + +/// Unpacked proof with subcomponents +#[derive(Debug)] +pub(crate) struct UnpackedProof<'a, F: FftFriendlyFieldElement> { + /// Data + pub data: &'a [F], + /// Zeroth coefficient of polynomial f + pub f0: &'a F, + /// Zeroth coefficient of polynomial g + pub g0: &'a F, + /// Zeroth coefficient of polynomial h + pub h0: &'a F, + /// Non-zero points of polynomial h + pub points_h_packed: &'a [F], +} + +/// Unpacked proof with mutable subcomponents +#[derive(Debug)] +pub(crate) struct UnpackedProofMut<'a, F: FftFriendlyFieldElement> { + /// Data + pub data: &'a mut [F], + /// Zeroth coefficient of polynomial f + pub f0: &'a mut F, + /// Zeroth coefficient of polynomial g + pub g0: &'a mut F, + /// Zeroth coefficient of polynomial h + pub h0: &'a mut F, + /// Non-zero points of polynomial h + pub points_h_packed: &'a mut [F], +} + +/// Unpacks the proof vector into subcomponents +pub(crate) fn unpack_proof<F: FftFriendlyFieldElement>( + proof: &[F], + dimension: usize, +) -> Result<UnpackedProof<F>, SerializeError> { + // check the proof length + if proof.len() != proof_length(dimension) { + return Err(SerializeError::UnpackInputSizeMismatch); + } + // split share into components + let (data, rest) = proof.split_at(dimension); + if let ([f0, g0, h0], points_h_packed) = rest.split_at(3) { + Ok(UnpackedProof { + data, + f0, + g0, + h0, + points_h_packed, + }) + } else { + Err(SerializeError::UnpackInputSizeMismatch) + } +} + +/// Unpacks a mutable proof vector into mutable subcomponents +pub(crate) fn unpack_proof_mut<F: FftFriendlyFieldElement>( + proof: &mut [F], + dimension: usize, +) -> Result<UnpackedProofMut<F>, SerializeError> { + // check the share length + if proof.len() != proof_length(dimension) { + return Err(SerializeError::UnpackInputSizeMismatch); + } + // split share into components + let (data, rest) = proof.split_at_mut(dimension); + if let ([f0, g0, h0], points_h_packed) = rest.split_at_mut(3) { + Ok(UnpackedProofMut { + data, + f0, + g0, + h0, + points_h_packed, + }) + } else { + Err(SerializeError::UnpackInputSizeMismatch) + } +} + +fn interpolate_and_evaluate_at_2n<F: FftFriendlyFieldElement>( + n: usize, + points_in: &[F], + evals_out: &mut [F], + mem: &mut PolyAuxMemory<F>, +) { + // interpolate through roots of unity + poly_fft( + &mut mem.coeffs, + points_in, + &mem.roots_n_inverted, + n, + true, + &mut mem.fft_memory, + ); + // evaluate at 2N roots of unity + poly_fft( + evals_out, + &mem.coeffs, + &mem.roots_2n, + 2 * n, + false, + &mut mem.fft_memory, + ); +} + +/// Proof construction +/// +/// Based on Theorem 2.3.3 from Henry Corrigan-Gibbs' dissertation +/// This constructs the output \pi by doing the necessesary calculations +fn construct_proof<F: FftFriendlyFieldElement>( + data: &[F], + dimension: usize, + f0: &mut F, + g0: &mut F, + h0: &mut F, + points_h_packed: &mut [F], + mem: &mut ClientMemory<F>, +) { + let n = (dimension + 1).next_power_of_two(); + + // set zero terms to random + *f0 = mem.prng.get(); + *g0 = mem.prng.get(); + mem.points_f[0] = *f0; + mem.points_g[0] = *g0; + + // set zero term for the proof polynomial + *h0 = *f0 * *g0; + + // set f_i = data_(i - 1) + // set g_i = f_i - 1 + for ((f_coeff, g_coeff), data_val) in mem.points_f[1..1 + dimension] + .iter_mut() + .zip(mem.points_g[1..1 + dimension].iter_mut()) + .zip(data[..dimension].iter()) + { + *f_coeff = *data_val; + *g_coeff = *data_val - F::one(); + } + + // interpolate and evaluate at roots of unity + interpolate_and_evaluate_at_2n(n, &mem.points_f, &mut mem.evals_f, &mut mem.poly_mem); + interpolate_and_evaluate_at_2n(n, &mem.points_g, &mut mem.evals_g, &mut mem.poly_mem); + + // calculate the proof polynomial as evals_f(r) * evals_g(r) + // only add non-zero points + let mut j: usize = 0; + let mut i: usize = 1; + while i < 2 * n { + points_h_packed[j] = mem.evals_f[i] * mem.evals_g[i]; + j += 1; + i += 2; + } +} + +#[cfg(test)] +mod tests { + use assert_matches::assert_matches; + + use crate::{ + field::{Field64, FieldPrio2}, + vdaf::prio2::client::{proof_length, unpack_proof, unpack_proof_mut, SerializeError}, + }; + + #[test] + fn test_unpack_share_mut() { + let dim = 15; + let len = proof_length(dim); + + let mut share = vec![FieldPrio2::from(0); len]; + let unpacked = unpack_proof_mut(&mut share, dim).unwrap(); + *unpacked.f0 = FieldPrio2::from(12); + assert_eq!(share[dim], 12); + + let mut short_share = vec![FieldPrio2::from(0); len - 1]; + assert_matches!( + unpack_proof_mut(&mut short_share, dim), + Err(SerializeError::UnpackInputSizeMismatch) + ); + } + + #[test] + fn test_unpack_share() { + let dim = 15; + let len = proof_length(dim); + + let share = vec![Field64::from(0); len]; + unpack_proof(&share, dim).unwrap(); + + let short_share = vec![Field64::from(0); len - 1]; + assert_matches!( + unpack_proof(&short_share, dim), + Err(SerializeError::UnpackInputSizeMismatch) + ); + } +} diff --git a/third_party/rust/prio/src/vdaf/prio2/server.rs b/third_party/rust/prio/src/vdaf/prio2/server.rs new file mode 100644 index 0000000000..11c161babf --- /dev/null +++ b/third_party/rust/prio/src/vdaf/prio2/server.rs @@ -0,0 +1,386 @@ +// Copyright (c) 2020 Apple Inc. +// SPDX-License-Identifier: MPL-2.0 + +//! Primitives for the Prio2 server. +use crate::{ + field::{FftFriendlyFieldElement, FieldError}, + polynomial::poly_interpret_eval, + prng::PrngError, + vdaf::prio2::client::{unpack_proof, SerializeError}, +}; +use serde::{Deserialize, Serialize}; + +/// Possible errors from server operations +#[derive(Debug, thiserror::Error)] +pub enum ServerError { + /// Unexpected Share Length + #[allow(unused)] + #[error("unexpected share length")] + ShareLength, + /// Finite field operation error + #[error("finite field operation error")] + Field(#[from] FieldError), + /// Serialization/deserialization error + #[error("serialization/deserialization error")] + Serialize(#[from] SerializeError), + /// Failure when calling getrandom(). + #[error("getrandom: {0}")] + GetRandom(#[from] getrandom::Error), + /// PRNG error. + #[error("prng error: {0}")] + Prng(#[from] PrngError), +} + +/// Verification message for proof validation +#[derive(Clone, Debug, Serialize, Deserialize)] +pub struct VerificationMessage<F> { + /// f evaluated at random point + pub f_r: F, + /// g evaluated at random point + pub g_r: F, + /// h evaluated at random point + pub h_r: F, +} + +/// Given a proof and evaluation point, this constructs the verification +/// message. +pub(crate) fn generate_verification_message<F: FftFriendlyFieldElement>( + dimension: usize, + eval_at: F, + proof: &[F], + is_first_server: bool, +) -> Result<VerificationMessage<F>, ServerError> { + let unpacked = unpack_proof(proof, dimension)?; + let n: usize = (dimension + 1).next_power_of_two(); + let proof_length = 2 * n; + let mut fft_in = vec![F::zero(); proof_length]; + let mut fft_mem = vec![F::zero(); proof_length]; + + // construct and evaluate polynomial f at the random point + fft_in[0] = *unpacked.f0; + fft_in[1..unpacked.data.len() + 1].copy_from_slice(unpacked.data); + let f_r = poly_interpret_eval(&fft_in[..n], eval_at, &mut fft_mem); + + // construct and evaluate polynomial g at the random point + fft_in[0] = *unpacked.g0; + if is_first_server { + for x in fft_in[1..unpacked.data.len() + 1].iter_mut() { + *x -= F::one(); + } + } + let g_r = poly_interpret_eval(&fft_in[..n], eval_at, &mut fft_mem); + + // construct and evaluate polynomial h at the random point + fft_in[0] = *unpacked.h0; + fft_in[1] = unpacked.points_h_packed[0]; + for (x, chunk) in unpacked.points_h_packed[1..] + .iter() + .zip(fft_in[2..proof_length].chunks_exact_mut(2)) + { + chunk[0] = F::zero(); + chunk[1] = *x; + } + let h_r = poly_interpret_eval(&fft_in, eval_at, &mut fft_mem); + + Ok(VerificationMessage { f_r, g_r, h_r }) +} + +/// Decides if the distributed proof is valid +pub(crate) fn is_valid_share<F: FftFriendlyFieldElement>( + v1: &VerificationMessage<F>, + v2: &VerificationMessage<F>, +) -> bool { + // reconstruct f_r, g_r, h_r + let f_r = v1.f_r + v2.f_r; + let g_r = v1.g_r + v2.g_r; + let h_r = v1.h_r + v2.h_r; + // validity check + f_r * g_r == h_r +} + +#[cfg(test)] +mod test_util { + use crate::{ + field::{merge_vector, FftFriendlyFieldElement}, + prng::Prng, + vdaf::prio2::client::proof_length, + }; + + use super::{generate_verification_message, is_valid_share, ServerError, VerificationMessage}; + + /// Main workhorse of the server. + #[derive(Debug)] + pub(crate) struct Server<F> { + dimension: usize, + is_first_server: bool, + accumulator: Vec<F>, + } + + impl<F: FftFriendlyFieldElement> Server<F> { + /// Construct a new server instance + /// + /// Params: + /// * `dimension`: the number of elements in the aggregation vector. + /// * `is_first_server`: only one of the servers should have this true. + pub fn new(dimension: usize, is_first_server: bool) -> Result<Server<F>, ServerError> { + Ok(Server { + dimension, + is_first_server, + accumulator: vec![F::zero(); dimension], + }) + } + + /// Deserialize + fn deserialize_share(&self, share: &[u8]) -> Result<Vec<F>, ServerError> { + let len = proof_length(self.dimension); + Ok(if self.is_first_server { + F::byte_slice_into_vec(share)? + } else { + if share.len() != 32 { + return Err(ServerError::ShareLength); + } + + Prng::from_prio2_seed(&share.try_into().unwrap()) + .take(len) + .collect() + }) + } + + /// Generate verification message from an encrypted share + /// + /// This decrypts the share of the proof and constructs the + /// [`VerificationMessage`](struct.VerificationMessage.html). + /// The `eval_at` field should be generate by + /// [choose_eval_at](#method.choose_eval_at). + pub fn generate_verification_message( + &mut self, + eval_at: F, + share: &[u8], + ) -> Result<VerificationMessage<F>, ServerError> { + let share_field = self.deserialize_share(share)?; + generate_verification_message( + self.dimension, + eval_at, + &share_field, + self.is_first_server, + ) + } + + /// Add the content of the encrypted share into the accumulator + /// + /// This only changes the accumulator if the verification messages `v1` and + /// `v2` indicate that the share passed validation. + pub fn aggregate( + &mut self, + share: &[u8], + v1: &VerificationMessage<F>, + v2: &VerificationMessage<F>, + ) -> Result<bool, ServerError> { + let share_field = self.deserialize_share(share)?; + let is_valid = is_valid_share(v1, v2); + if is_valid { + // Add to the accumulator. share_field also includes the proof + // encoding, so we slice off the first dimension fields, which are + // the actual data share. + merge_vector(&mut self.accumulator, &share_field[..self.dimension])?; + } + + Ok(is_valid) + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::{ + codec::Encode, + field::{FieldElement, FieldPrio2}, + prng::Prng, + vdaf::{ + prio2::{client::unpack_proof_mut, server::test_util::Server, Prio2}, + Client, + }, + }; + use rand::{random, Rng}; + + fn secret_share(share: &mut [FieldPrio2]) -> Vec<FieldPrio2> { + let mut rng = rand::thread_rng(); + let mut share2 = vec![FieldPrio2::zero(); share.len()]; + for (f1, f2) in share.iter_mut().zip(share2.iter_mut()) { + let f = FieldPrio2::from(rng.gen::<u32>()); + *f2 = f; + *f1 -= f; + } + share2 + } + + #[test] + fn test_validation() { + let dim = 8; + let proof_u32: Vec<u32> = vec![ + 1, 0, 0, 0, 0, 0, 0, 0, 2052337230, 3217065186, 1886032198, 2533724497, 397524722, + 3820138372, 1535223968, 4291254640, 3565670552, 2447741959, 163741941, 335831680, + 2567182742, 3542857140, 124017604, 4201373647, 431621210, 1618555683, 267689149, + ]; + + let mut proof: Vec<FieldPrio2> = proof_u32.iter().map(|x| FieldPrio2::from(*x)).collect(); + let share2 = secret_share(&mut proof); + let eval_at = FieldPrio2::from(12313); + + let v1 = generate_verification_message(dim, eval_at, &proof, true).unwrap(); + let v2 = generate_verification_message(dim, eval_at, &share2, false).unwrap(); + assert!(is_valid_share(&v1, &v2)); + } + + #[test] + fn test_verification_message_serde() { + let dim = 8; + let proof_u32: Vec<u32> = vec![ + 1, 0, 0, 0, 0, 0, 0, 0, 2052337230, 3217065186, 1886032198, 2533724497, 397524722, + 3820138372, 1535223968, 4291254640, 3565670552, 2447741959, 163741941, 335831680, + 2567182742, 3542857140, 124017604, 4201373647, 431621210, 1618555683, 267689149, + ]; + + let mut proof: Vec<FieldPrio2> = proof_u32.iter().map(|x| FieldPrio2::from(*x)).collect(); + let share2 = secret_share(&mut proof); + let eval_at = FieldPrio2::from(12313); + + let v1 = generate_verification_message(dim, eval_at, &proof, true).unwrap(); + let v2 = generate_verification_message(dim, eval_at, &share2, false).unwrap(); + + // serialize and deserialize the first verification message + let serialized = serde_json::to_string(&v1).unwrap(); + let deserialized: VerificationMessage<FieldPrio2> = + serde_json::from_str(&serialized).unwrap(); + + assert!(is_valid_share(&deserialized, &v2)); + } + + #[derive(Debug, Clone, Copy, PartialEq)] + enum Tweak { + None, + WrongInput, + DataPartOfShare, + ZeroTermF, + ZeroTermG, + ZeroTermH, + PointsH, + VerificationF, + VerificationG, + VerificationH, + } + + fn tweaks(tweak: Tweak) { + let dim = 123; + + let mut server1 = Server::<FieldPrio2>::new(dim, true).unwrap(); + let mut server2 = Server::new(dim, false).unwrap(); + + // all zero data + let mut data = vec![0; dim]; + + if let Tweak::WrongInput = tweak { + data[0] = 2; + } + + let vdaf = Prio2::new(dim).unwrap(); + let (_, shares) = vdaf.shard(&data, &[0; 16]).unwrap(); + let share1_original = shares[0].get_encoded(); + let share2 = shares[1].get_encoded(); + + let mut share1_field = FieldPrio2::byte_slice_into_vec(&share1_original).unwrap(); + let unpacked_share1 = unpack_proof_mut(&mut share1_field, dim).unwrap(); + + let one = FieldPrio2::from(1); + + match tweak { + Tweak::DataPartOfShare => unpacked_share1.data[0] += one, + Tweak::ZeroTermF => *unpacked_share1.f0 += one, + Tweak::ZeroTermG => *unpacked_share1.g0 += one, + Tweak::ZeroTermH => *unpacked_share1.h0 += one, + Tweak::PointsH => unpacked_share1.points_h_packed[0] += one, + _ => (), + }; + + // reserialize altered share1 + let share1_modified = FieldPrio2::slice_into_byte_vec(&share1_field); + + let mut prng = Prng::from_prio2_seed(&random()); + let eval_at = vdaf.choose_eval_at(&mut prng); + + let mut v1 = server1 + .generate_verification_message(eval_at, &share1_modified) + .unwrap(); + let v2 = server2 + .generate_verification_message(eval_at, &share2) + .unwrap(); + + match tweak { + Tweak::VerificationF => v1.f_r += one, + Tweak::VerificationG => v1.g_r += one, + Tweak::VerificationH => v1.h_r += one, + _ => (), + } + + let should_be_valid = matches!(tweak, Tweak::None); + assert_eq!( + server1.aggregate(&share1_modified, &v1, &v2).unwrap(), + should_be_valid + ); + assert_eq!( + server2.aggregate(&share2, &v1, &v2).unwrap(), + should_be_valid + ); + } + + #[test] + fn tweak_none() { + tweaks(Tweak::None); + } + + #[test] + fn tweak_input() { + tweaks(Tweak::WrongInput); + } + + #[test] + fn tweak_data() { + tweaks(Tweak::DataPartOfShare); + } + + #[test] + fn tweak_f_zero() { + tweaks(Tweak::ZeroTermF); + } + + #[test] + fn tweak_g_zero() { + tweaks(Tweak::ZeroTermG); + } + + #[test] + fn tweak_h_zero() { + tweaks(Tweak::ZeroTermH); + } + + #[test] + fn tweak_h_points() { + tweaks(Tweak::PointsH); + } + + #[test] + fn tweak_f_verif() { + tweaks(Tweak::VerificationF); + } + + #[test] + fn tweak_g_verif() { + tweaks(Tweak::VerificationG); + } + + #[test] + fn tweak_h_verif() { + tweaks(Tweak::VerificationH); + } +} diff --git a/third_party/rust/prio/src/vdaf/prio2/test_vector.rs b/third_party/rust/prio/src/vdaf/prio2/test_vector.rs new file mode 100644 index 0000000000..ae2b8b0f9d --- /dev/null +++ b/third_party/rust/prio/src/vdaf/prio2/test_vector.rs @@ -0,0 +1,83 @@ +// SPDX-License-Identifier: MPL-2.0 + +//! Test vectors of serialized Prio inputs, enabling backward compatibility testing. + +use crate::{field::FieldPrio2, vdaf::prio2::client::ClientError}; +use serde::{Deserialize, Serialize}; +use std::fmt::Debug; + +/// Errors propagated by functions in this module. +#[derive(Debug, thiserror::Error)] +pub(crate) enum TestVectorError { + /// Error from Prio client + #[error("Prio client error {0}")] + Client(#[from] ClientError), +} + +/// A test vector of serialized Priov2 inputs, along with a reference sum. The field is always +/// [`FieldPrio2`]. +#[derive(Debug, Eq, PartialEq, Serialize, Deserialize)] +pub(crate) struct Priov2TestVector { + /// Dimension (number of buckets) of the inputs + pub dimension: usize, + /// Decrypted shares of Priov2 format inputs for the "first" a.k.a. "PHA" + /// server. The inner `Vec`s are encrypted bytes. + #[serde( + serialize_with = "base64::serialize_bytes", + deserialize_with = "base64::deserialize_bytes" + )] + pub server_1_decrypted_shares: Vec<Vec<u8>>, + /// Decrypted share of Priov2 format inputs for the non-"first" a.k.a. + /// "facilitator" server. + #[serde( + serialize_with = "base64::serialize_bytes", + deserialize_with = "base64::deserialize_bytes" + )] + pub server_2_decrypted_shares: Vec<Vec<u8>>, + /// The sum over the inputs. + #[serde( + serialize_with = "base64::serialize_field", + deserialize_with = "base64::deserialize_field" + )] + pub reference_sum: Vec<FieldPrio2>, +} + +mod base64 { + //! Custom serialization module used for some members of struct + //! `Priov2TestVector` so that byte slices are serialized as base64 strings + //! instead of an array of an array of integers when serializing to JSON. + // + // Thank you, Alice! https://users.rust-lang.org/t/serialize-a-vec-u8-to-json-as-base64/57781/2 + use crate::field::{FieldElement, FieldPrio2}; + use base64::{engine::Engine, prelude::BASE64_STANDARD}; + use serde::{de::Error, Deserialize, Deserializer, Serialize, Serializer}; + + pub fn serialize_bytes<S: Serializer>(v: &[Vec<u8>], s: S) -> Result<S::Ok, S::Error> { + let base64_vec = v + .iter() + .map(|bytes| BASE64_STANDARD.encode(bytes)) + .collect(); + <Vec<String>>::serialize(&base64_vec, s) + } + + pub fn deserialize_bytes<'de, D: Deserializer<'de>>(d: D) -> Result<Vec<Vec<u8>>, D::Error> { + <Vec<String>>::deserialize(d)? + .iter() + .map(|s| BASE64_STANDARD.decode(s.as_bytes()).map_err(Error::custom)) + .collect() + } + + pub fn serialize_field<S: Serializer>(v: &[FieldPrio2], s: S) -> Result<S::Ok, S::Error> { + String::serialize( + &BASE64_STANDARD.encode(FieldPrio2::slice_into_byte_vec(v)), + s, + ) + } + + pub fn deserialize_field<'de, D: Deserializer<'de>>(d: D) -> Result<Vec<FieldPrio2>, D::Error> { + let bytes = BASE64_STANDARD + .decode(String::deserialize(d)?.as_bytes()) + .map_err(Error::custom)?; + FieldPrio2::byte_slice_into_vec(&bytes).map_err(Error::custom) + } +} diff --git a/third_party/rust/prio/src/vdaf/prio3.rs b/third_party/rust/prio/src/vdaf/prio3.rs new file mode 100644 index 0000000000..4a7cdefb84 --- /dev/null +++ b/third_party/rust/prio/src/vdaf/prio3.rs @@ -0,0 +1,2127 @@ +// SPDX-License-Identifier: MPL-2.0 + +//! Implementation of the Prio3 VDAF [[draft-irtf-cfrg-vdaf-07]]. +//! +//! **WARNING:** This code has not undergone significant security analysis. Use at your own risk. +//! +//! Prio3 is based on the Prio system desigend by Dan Boneh and Henry Corrigan-Gibbs and presented +//! at NSDI 2017 [[CGB17]]. However, it incorporates a few techniques from Boneh et al., CRYPTO +//! 2019 [[BBCG+19]], that lead to substantial improvements in terms of run time and communication +//! cost. The security of the construction was analyzed in [[DPRS23]]. +//! +//! Prio3 is a transformation of a Fully Linear Proof (FLP) system [[draft-irtf-cfrg-vdaf-07]] into +//! a VDAF. The base type, [`Prio3`], supports a wide variety of aggregation functions, some of +//! which are instantiated here: +//! +//! - [`Prio3Count`] for aggregating a counter (*) +//! - [`Prio3Sum`] for copmputing the sum of integers (*) +//! - [`Prio3SumVec`] for aggregating a vector of integers +//! - [`Prio3Histogram`] for estimating a distribution via a histogram (*) +//! +//! Additional types can be constructed from [`Prio3`] as needed. +//! +//! (*) denotes that the type is specified in [[draft-irtf-cfrg-vdaf-07]]. +//! +//! [BBCG+19]: https://ia.cr/2019/188 +//! [CGB17]: https://crypto.stanford.edu/prio/ +//! [DPRS23]: https://ia.cr/2023/130 +//! [draft-irtf-cfrg-vdaf-07]: https://datatracker.ietf.org/doc/draft-irtf-cfrg-vdaf/07/ + +use super::xof::XofShake128; +#[cfg(feature = "experimental")] +use super::AggregatorWithNoise; +use crate::codec::{CodecError, Decode, Encode, ParameterizedDecode}; +#[cfg(feature = "experimental")] +use crate::dp::DifferentialPrivacyStrategy; +use crate::field::{decode_fieldvec, FftFriendlyFieldElement, FieldElement}; +use crate::field::{Field128, Field64}; +#[cfg(feature = "multithreaded")] +use crate::flp::gadgets::ParallelSumMultithreaded; +#[cfg(feature = "experimental")] +use crate::flp::gadgets::PolyEval; +use crate::flp::gadgets::{Mul, ParallelSum}; +#[cfg(feature = "experimental")] +use crate::flp::types::fixedpoint_l2::{ + compatible_float::CompatibleFloat, FixedPointBoundedL2VecSum, +}; +use crate::flp::types::{Average, Count, Histogram, Sum, SumVec}; +use crate::flp::Type; +#[cfg(feature = "experimental")] +use crate::flp::TypeWithNoise; +use crate::prng::Prng; +use crate::vdaf::xof::{IntoFieldVec, Seed, Xof}; +use crate::vdaf::{ + Aggregatable, AggregateShare, Aggregator, Client, Collector, OutputShare, PrepareTransition, + Share, ShareDecodingParameter, Vdaf, VdafError, +}; +#[cfg(feature = "experimental")] +use fixed::traits::Fixed; +use std::convert::TryFrom; +use std::fmt::Debug; +use std::io::Cursor; +use std::iter::{self, IntoIterator}; +use std::marker::PhantomData; +use subtle::{Choice, ConstantTimeEq}; + +const DST_MEASUREMENT_SHARE: u16 = 1; +const DST_PROOF_SHARE: u16 = 2; +const DST_JOINT_RANDOMNESS: u16 = 3; +const DST_PROVE_RANDOMNESS: u16 = 4; +const DST_QUERY_RANDOMNESS: u16 = 5; +const DST_JOINT_RAND_SEED: u16 = 6; +const DST_JOINT_RAND_PART: u16 = 7; + +/// The count type. Each measurement is an integer in `[0,2)` and the aggregate result is the sum. +pub type Prio3Count = Prio3<Count<Field64>, XofShake128, 16>; + +impl Prio3Count { + /// Construct an instance of Prio3Count with the given number of aggregators. + pub fn new_count(num_aggregators: u8) -> Result<Self, VdafError> { + Prio3::new(num_aggregators, Count::new()) + } +} + +/// The count-vector type. Each measurement is a vector of integers in `[0,2^bits)` and the +/// aggregate is the element-wise sum. +pub type Prio3SumVec = + Prio3<SumVec<Field128, ParallelSum<Field128, Mul<Field128>>>, XofShake128, 16>; + +impl Prio3SumVec { + /// Construct an instance of Prio3SumVec with the given number of aggregators. `bits` defines + /// the bit width of each summand of the measurement; `len` defines the length of the + /// measurement vector. + pub fn new_sum_vec( + num_aggregators: u8, + bits: usize, + len: usize, + chunk_length: usize, + ) -> Result<Self, VdafError> { + Prio3::new(num_aggregators, SumVec::new(bits, len, chunk_length)?) + } +} + +/// Like [`Prio3SumVec`] except this type uses multithreading to improve sharding and preparation +/// time. Note that the improvement is only noticeable for very large input lengths. +#[cfg(feature = "multithreaded")] +#[cfg_attr(docsrs, doc(cfg(feature = "multithreaded")))] +pub type Prio3SumVecMultithreaded = + Prio3<SumVec<Field128, ParallelSumMultithreaded<Field128, Mul<Field128>>>, XofShake128, 16>; + +#[cfg(feature = "multithreaded")] +impl Prio3SumVecMultithreaded { + /// Construct an instance of Prio3SumVecMultithreaded with the given number of + /// aggregators. `bits` defines the bit width of each summand of the measurement; `len` defines + /// the length of the measurement vector. + pub fn new_sum_vec_multithreaded( + num_aggregators: u8, + bits: usize, + len: usize, + chunk_length: usize, + ) -> Result<Self, VdafError> { + Prio3::new(num_aggregators, SumVec::new(bits, len, chunk_length)?) + } +} + +/// The sum type. Each measurement is an integer in `[0,2^bits)` for some `0 < bits < 64` and the +/// aggregate is the sum. +pub type Prio3Sum = Prio3<Sum<Field128>, XofShake128, 16>; + +impl Prio3Sum { + /// Construct an instance of Prio3Sum with the given number of aggregators and required bit + /// length. The bit length must not exceed 64. + pub fn new_sum(num_aggregators: u8, bits: usize) -> Result<Self, VdafError> { + if bits > 64 { + return Err(VdafError::Uncategorized(format!( + "bit length ({bits}) exceeds limit for aggregate type (64)" + ))); + } + + Prio3::new(num_aggregators, Sum::new(bits)?) + } +} + +/// The fixed point vector sum type. Each measurement is a vector of fixed point numbers +/// and the aggregate is the sum represented as 64-bit floats. The preparation phase +/// ensures the L2 norm of the input vector is < 1. +/// +/// This is useful for aggregating gradients in a federated version of +/// [gradient descent](https://en.wikipedia.org/wiki/Gradient_descent) with +/// [differential privacy](https://en.wikipedia.org/wiki/Differential_privacy), +/// useful, e.g., for [differentially private deep learning](https://arxiv.org/pdf/1607.00133.pdf). +/// The bound on input norms is required for differential privacy. The fixed point representation +/// allows an easy conversion to the integer type used in internal computation, while leaving +/// conversion to the client. The model itself will have floating point parameters, so the output +/// sum has that type as well. +#[cfg(feature = "experimental")] +#[cfg_attr(docsrs, doc(cfg(feature = "experimental")))] +pub type Prio3FixedPointBoundedL2VecSum<Fx> = Prio3< + FixedPointBoundedL2VecSum< + Fx, + ParallelSum<Field128, PolyEval<Field128>>, + ParallelSum<Field128, Mul<Field128>>, + >, + XofShake128, + 16, +>; + +#[cfg(feature = "experimental")] +impl<Fx: Fixed + CompatibleFloat> Prio3FixedPointBoundedL2VecSum<Fx> { + /// Construct an instance of this VDAF with the given number of aggregators and number of + /// vector entries. + pub fn new_fixedpoint_boundedl2_vec_sum( + num_aggregators: u8, + entries: usize, + ) -> Result<Self, VdafError> { + check_num_aggregators(num_aggregators)?; + Prio3::new(num_aggregators, FixedPointBoundedL2VecSum::new(entries)?) + } +} + +/// The fixed point vector sum type. Each measurement is a vector of fixed point numbers +/// and the aggregate is the sum represented as 64-bit floats. The verification function +/// ensures the L2 norm of the input vector is < 1. +#[cfg(all(feature = "experimental", feature = "multithreaded"))] +#[cfg_attr( + docsrs, + doc(cfg(all(feature = "experimental", feature = "multithreaded"))) +)] +pub type Prio3FixedPointBoundedL2VecSumMultithreaded<Fx> = Prio3< + FixedPointBoundedL2VecSum< + Fx, + ParallelSumMultithreaded<Field128, PolyEval<Field128>>, + ParallelSumMultithreaded<Field128, Mul<Field128>>, + >, + XofShake128, + 16, +>; + +#[cfg(all(feature = "experimental", feature = "multithreaded"))] +impl<Fx: Fixed + CompatibleFloat> Prio3FixedPointBoundedL2VecSumMultithreaded<Fx> { + /// Construct an instance of this VDAF with the given number of aggregators and number of + /// vector entries. + pub fn new_fixedpoint_boundedl2_vec_sum_multithreaded( + num_aggregators: u8, + entries: usize, + ) -> Result<Self, VdafError> { + check_num_aggregators(num_aggregators)?; + Prio3::new(num_aggregators, FixedPointBoundedL2VecSum::new(entries)?) + } +} + +/// The histogram type. Each measurement is an integer in `[0, length)` and the result is a +/// histogram counting the number of occurrences of each measurement. +pub type Prio3Histogram = + Prio3<Histogram<Field128, ParallelSum<Field128, Mul<Field128>>>, XofShake128, 16>; + +impl Prio3Histogram { + /// Constructs an instance of Prio3Histogram with the given number of aggregators, + /// number of buckets, and parallel sum gadget chunk length. + pub fn new_histogram( + num_aggregators: u8, + length: usize, + chunk_length: usize, + ) -> Result<Self, VdafError> { + Prio3::new(num_aggregators, Histogram::new(length, chunk_length)?) + } +} + +/// Like [`Prio3Histogram`] except this type uses multithreading to improve sharding and preparation +/// time. Note that this improvement is only noticeable for very large input lengths. +#[cfg(feature = "multithreaded")] +#[cfg_attr(docsrs, doc(cfg(feature = "multithreaded")))] +pub type Prio3HistogramMultithreaded = + Prio3<Histogram<Field128, ParallelSumMultithreaded<Field128, Mul<Field128>>>, XofShake128, 16>; + +#[cfg(feature = "multithreaded")] +impl Prio3HistogramMultithreaded { + /// Construct an instance of Prio3HistogramMultithreaded with the given number of aggregators, + /// number of buckets, and parallel sum gadget chunk length. + pub fn new_histogram_multithreaded( + num_aggregators: u8, + length: usize, + chunk_length: usize, + ) -> Result<Self, VdafError> { + Prio3::new(num_aggregators, Histogram::new(length, chunk_length)?) + } +} + +/// The average type. Each measurement is an integer in `[0,2^bits)` for some `0 < bits < 64` and +/// the aggregate is the arithmetic average. +pub type Prio3Average = Prio3<Average<Field128>, XofShake128, 16>; + +impl Prio3Average { + /// Construct an instance of Prio3Average with the given number of aggregators and required bit + /// length. The bit length must not exceed 64. + pub fn new_average(num_aggregators: u8, bits: usize) -> Result<Self, VdafError> { + check_num_aggregators(num_aggregators)?; + + if bits > 64 { + return Err(VdafError::Uncategorized(format!( + "bit length ({bits}) exceeds limit for aggregate type (64)" + ))); + } + + Ok(Prio3 { + num_aggregators, + typ: Average::new(bits)?, + phantom: PhantomData, + }) + } +} + +/// The base type for Prio3. +/// +/// An instance of Prio3 is determined by: +/// +/// - a [`Type`] that defines the set of valid input measurements; and +/// - a [`Xof`] for deriving vectors of field elements from seeds. +/// +/// New instances can be defined by aliasing the base type. For example, [`Prio3Count`] is an alias +/// for `Prio3<Count<Field64>, XofShake128, 16>`. +/// +/// ``` +/// use prio::vdaf::{ +/// Aggregator, Client, Collector, PrepareTransition, +/// prio3::Prio3, +/// }; +/// use rand::prelude::*; +/// +/// let num_shares = 2; +/// let vdaf = Prio3::new_count(num_shares).unwrap(); +/// +/// let mut out_shares = vec![vec![]; num_shares.into()]; +/// let mut rng = thread_rng(); +/// let verify_key = rng.gen(); +/// let measurements = [0, 1, 1, 1, 0]; +/// for measurement in measurements { +/// // Shard +/// let nonce = rng.gen::<[u8; 16]>(); +/// let (public_share, input_shares) = vdaf.shard(&measurement, &nonce).unwrap(); +/// +/// // Prepare +/// let mut prep_states = vec![]; +/// let mut prep_shares = vec![]; +/// for (agg_id, input_share) in input_shares.iter().enumerate() { +/// let (state, share) = vdaf.prepare_init( +/// &verify_key, +/// agg_id, +/// &(), +/// &nonce, +/// &public_share, +/// input_share +/// ).unwrap(); +/// prep_states.push(state); +/// prep_shares.push(share); +/// } +/// let prep_msg = vdaf.prepare_shares_to_prepare_message(&(), prep_shares).unwrap(); +/// +/// for (agg_id, state) in prep_states.into_iter().enumerate() { +/// let out_share = match vdaf.prepare_step(state, prep_msg.clone()).unwrap() { +/// PrepareTransition::Finish(out_share) => out_share, +/// _ => panic!("unexpected transition"), +/// }; +/// out_shares[agg_id].push(out_share); +/// } +/// } +/// +/// // Aggregate +/// let agg_shares = out_shares.into_iter() +/// .map(|o| vdaf.aggregate(&(), o).unwrap()); +/// +/// // Unshard +/// let agg_res = vdaf.unshard(&(), agg_shares, measurements.len()).unwrap(); +/// assert_eq!(agg_res, 3); +/// ``` +#[derive(Clone, Debug)] +pub struct Prio3<T, P, const SEED_SIZE: usize> +where + T: Type, + P: Xof<SEED_SIZE>, +{ + num_aggregators: u8, + typ: T, + phantom: PhantomData<P>, +} + +impl<T, P, const SEED_SIZE: usize> Prio3<T, P, SEED_SIZE> +where + T: Type, + P: Xof<SEED_SIZE>, +{ + /// Construct an instance of this Prio3 VDAF with the given number of aggregators and the + /// underlying type. + pub fn new(num_aggregators: u8, typ: T) -> Result<Self, VdafError> { + check_num_aggregators(num_aggregators)?; + Ok(Self { + num_aggregators, + typ, + phantom: PhantomData, + }) + } + + /// The output length of the underlying FLP. + pub fn output_len(&self) -> usize { + self.typ.output_len() + } + + /// The verifier length of the underlying FLP. + pub fn verifier_len(&self) -> usize { + self.typ.verifier_len() + } + + fn derive_joint_rand_seed<'a>( + parts: impl Iterator<Item = &'a Seed<SEED_SIZE>>, + ) -> Seed<SEED_SIZE> { + let mut xof = P::init( + &[0; SEED_SIZE], + &Self::domain_separation_tag(DST_JOINT_RAND_SEED), + ); + for part in parts { + xof.update(part.as_ref()); + } + xof.into_seed() + } + + fn random_size(&self) -> usize { + if self.typ.joint_rand_len() == 0 { + // Two seeds per helper for measurement and proof shares, plus one seed for proving + // randomness. + (usize::from(self.num_aggregators - 1) * 2 + 1) * SEED_SIZE + } else { + ( + // Two seeds per helper for measurement and proof shares + usize::from(self.num_aggregators - 1) * 2 + // One seed for proving randomness + + 1 + // One seed per aggregator for joint randomness blinds + + usize::from(self.num_aggregators) + ) * SEED_SIZE + } + } + + #[allow(clippy::type_complexity)] + pub(crate) fn shard_with_random<const N: usize>( + &self, + measurement: &T::Measurement, + nonce: &[u8; N], + random: &[u8], + ) -> Result< + ( + Prio3PublicShare<SEED_SIZE>, + Vec<Prio3InputShare<T::Field, SEED_SIZE>>, + ), + VdafError, + > { + if random.len() != self.random_size() { + return Err(VdafError::Uncategorized( + "incorrect random input length".to_string(), + )); + } + let mut random_seeds = random.chunks_exact(SEED_SIZE); + let num_aggregators = self.num_aggregators; + let encoded_measurement = self.typ.encode_measurement(measurement)?; + + // Generate the measurement shares and compute the joint randomness. + let mut helper_shares = Vec::with_capacity(num_aggregators as usize - 1); + let mut helper_joint_rand_parts = if self.typ.joint_rand_len() > 0 { + Some(Vec::with_capacity(num_aggregators as usize - 1)) + } else { + None + }; + let mut leader_measurement_share = encoded_measurement.clone(); + for agg_id in 1..num_aggregators { + // The Option from the ChunksExact iterator is okay to unwrap because we checked that + // the randomness slice is long enough for this VDAF. The slice-to-array conversion + // Result is okay to unwrap because the ChunksExact iterator always returns slices of + // the correct length. + let measurement_share_seed = random_seeds.next().unwrap().try_into().unwrap(); + let proof_share_seed = random_seeds.next().unwrap().try_into().unwrap(); + let measurement_share_prng: Prng<T::Field, _> = Prng::from_seed_stream(P::seed_stream( + &Seed(measurement_share_seed), + &Self::domain_separation_tag(DST_MEASUREMENT_SHARE), + &[agg_id], + )); + let joint_rand_blind = + if let Some(helper_joint_rand_parts) = helper_joint_rand_parts.as_mut() { + let joint_rand_blind = random_seeds.next().unwrap().try_into().unwrap(); + let mut joint_rand_part_xof = P::init( + &joint_rand_blind, + &Self::domain_separation_tag(DST_JOINT_RAND_PART), + ); + joint_rand_part_xof.update(&[agg_id]); // Aggregator ID + joint_rand_part_xof.update(nonce); + + let mut encoding_buffer = Vec::with_capacity(T::Field::ENCODED_SIZE); + for (x, y) in leader_measurement_share + .iter_mut() + .zip(measurement_share_prng) + { + *x -= y; + y.encode(&mut encoding_buffer); + joint_rand_part_xof.update(&encoding_buffer); + encoding_buffer.clear(); + } + + helper_joint_rand_parts.push(joint_rand_part_xof.into_seed()); + + Some(joint_rand_blind) + } else { + for (x, y) in leader_measurement_share + .iter_mut() + .zip(measurement_share_prng) + { + *x -= y; + } + None + }; + let helper = + HelperShare::from_seeds(measurement_share_seed, proof_share_seed, joint_rand_blind); + helper_shares.push(helper); + } + + let mut leader_blind_opt = None; + let public_share = Prio3PublicShare { + joint_rand_parts: helper_joint_rand_parts + .as_ref() + .map(|helper_joint_rand_parts| { + let leader_blind_bytes = random_seeds.next().unwrap().try_into().unwrap(); + let leader_blind = Seed::from_bytes(leader_blind_bytes); + + let mut joint_rand_part_xof = P::init( + leader_blind.as_ref(), + &Self::domain_separation_tag(DST_JOINT_RAND_PART), + ); + joint_rand_part_xof.update(&[0]); // Aggregator ID + joint_rand_part_xof.update(nonce); + let mut encoding_buffer = Vec::with_capacity(T::Field::ENCODED_SIZE); + for x in leader_measurement_share.iter() { + x.encode(&mut encoding_buffer); + joint_rand_part_xof.update(&encoding_buffer); + encoding_buffer.clear(); + } + leader_blind_opt = Some(leader_blind); + + let leader_joint_rand_seed_part = joint_rand_part_xof.into_seed(); + + let mut vec = Vec::with_capacity(self.num_aggregators()); + vec.push(leader_joint_rand_seed_part); + vec.extend(helper_joint_rand_parts.iter().cloned()); + vec + }), + }; + + // Compute the joint randomness. + let joint_rand: Vec<T::Field> = public_share + .joint_rand_parts + .as_ref() + .map(|joint_rand_parts| { + let joint_rand_seed = Self::derive_joint_rand_seed(joint_rand_parts.iter()); + P::seed_stream( + &joint_rand_seed, + &Self::domain_separation_tag(DST_JOINT_RANDOMNESS), + &[], + ) + .into_field_vec(self.typ.joint_rand_len()) + }) + .unwrap_or_default(); + + // Run the proof-generation algorithm. + let prove_rand_seed = random_seeds.next().unwrap().try_into().unwrap(); + let prove_rand = P::seed_stream( + &Seed::from_bytes(prove_rand_seed), + &Self::domain_separation_tag(DST_PROVE_RANDOMNESS), + &[], + ) + .into_field_vec(self.typ.prove_rand_len()); + let mut leader_proof_share = + self.typ + .prove(&encoded_measurement, &prove_rand, &joint_rand)?; + + // Generate the proof shares and distribute the joint randomness seed hints. + for (j, helper) in helper_shares.iter_mut().enumerate() { + let proof_share_prng: Prng<T::Field, _> = Prng::from_seed_stream(P::seed_stream( + &helper.proof_share, + &Self::domain_separation_tag(DST_PROOF_SHARE), + &[j as u8 + 1], + )); + for (x, y) in leader_proof_share + .iter_mut() + .zip(proof_share_prng) + .take(self.typ.proof_len()) + { + *x -= y; + } + } + + // Prep the output messages. + let mut out = Vec::with_capacity(num_aggregators as usize); + out.push(Prio3InputShare { + measurement_share: Share::Leader(leader_measurement_share), + proof_share: Share::Leader(leader_proof_share), + joint_rand_blind: leader_blind_opt, + }); + + for helper in helper_shares.into_iter() { + out.push(Prio3InputShare { + measurement_share: Share::Helper(helper.measurement_share), + proof_share: Share::Helper(helper.proof_share), + joint_rand_blind: helper.joint_rand_blind, + }); + } + + Ok((public_share, out)) + } + + fn role_try_from(&self, agg_id: usize) -> Result<u8, VdafError> { + if agg_id >= self.num_aggregators as usize { + return Err(VdafError::Uncategorized("unexpected aggregator id".into())); + } + Ok(u8::try_from(agg_id).unwrap()) + } +} + +impl<T, P, const SEED_SIZE: usize> Vdaf for Prio3<T, P, SEED_SIZE> +where + T: Type, + P: Xof<SEED_SIZE>, +{ + const ID: u32 = T::ID; + type Measurement = T::Measurement; + type AggregateResult = T::AggregateResult; + type AggregationParam = (); + type PublicShare = Prio3PublicShare<SEED_SIZE>; + type InputShare = Prio3InputShare<T::Field, SEED_SIZE>; + type OutputShare = OutputShare<T::Field>; + type AggregateShare = AggregateShare<T::Field>; + + fn num_aggregators(&self) -> usize { + self.num_aggregators as usize + } +} + +/// Message broadcast by the [`Client`] to every [`Aggregator`] during the Sharding phase. +#[derive(Clone, Debug)] +pub struct Prio3PublicShare<const SEED_SIZE: usize> { + /// Contributions to the joint randomness from every aggregator's share. + joint_rand_parts: Option<Vec<Seed<SEED_SIZE>>>, +} + +impl<const SEED_SIZE: usize> Encode for Prio3PublicShare<SEED_SIZE> { + fn encode(&self, bytes: &mut Vec<u8>) { + if let Some(joint_rand_parts) = self.joint_rand_parts.as_ref() { + for part in joint_rand_parts.iter() { + part.encode(bytes); + } + } + } + + fn encoded_len(&self) -> Option<usize> { + if let Some(joint_rand_parts) = self.joint_rand_parts.as_ref() { + // Each seed has the same size. + Some(SEED_SIZE * joint_rand_parts.len()) + } else { + Some(0) + } + } +} + +impl<const SEED_SIZE: usize> PartialEq for Prio3PublicShare<SEED_SIZE> { + fn eq(&self, other: &Self) -> bool { + self.ct_eq(other).into() + } +} + +impl<const SEED_SIZE: usize> Eq for Prio3PublicShare<SEED_SIZE> {} + +impl<const SEED_SIZE: usize> ConstantTimeEq for Prio3PublicShare<SEED_SIZE> { + fn ct_eq(&self, other: &Self) -> Choice { + // We allow short-circuiting on the presence or absence of the joint_rand_parts. + option_ct_eq( + self.joint_rand_parts.as_deref(), + other.joint_rand_parts.as_deref(), + ) + } +} + +impl<T, P, const SEED_SIZE: usize> ParameterizedDecode<Prio3<T, P, SEED_SIZE>> + for Prio3PublicShare<SEED_SIZE> +where + T: Type, + P: Xof<SEED_SIZE>, +{ + fn decode_with_param( + decoding_parameter: &Prio3<T, P, SEED_SIZE>, + bytes: &mut Cursor<&[u8]>, + ) -> Result<Self, CodecError> { + if decoding_parameter.typ.joint_rand_len() > 0 { + let joint_rand_parts = iter::repeat_with(|| Seed::<SEED_SIZE>::decode(bytes)) + .take(decoding_parameter.num_aggregators.into()) + .collect::<Result<Vec<_>, _>>()?; + Ok(Self { + joint_rand_parts: Some(joint_rand_parts), + }) + } else { + Ok(Self { + joint_rand_parts: None, + }) + } + } +} + +/// Message sent by the [`Client`] to each [`Aggregator`] during the Sharding phase. +#[derive(Clone, Debug)] +pub struct Prio3InputShare<F, const SEED_SIZE: usize> { + /// The measurement share. + measurement_share: Share<F, SEED_SIZE>, + + /// The proof share. + proof_share: Share<F, SEED_SIZE>, + + /// Blinding seed used by the Aggregator to compute the joint randomness. This field is optional + /// because not every [`Type`] requires joint randomness. + joint_rand_blind: Option<Seed<SEED_SIZE>>, +} + +impl<F: ConstantTimeEq, const SEED_SIZE: usize> PartialEq for Prio3InputShare<F, SEED_SIZE> { + fn eq(&self, other: &Self) -> bool { + self.ct_eq(other).into() + } +} + +impl<F: ConstantTimeEq, const SEED_SIZE: usize> Eq for Prio3InputShare<F, SEED_SIZE> {} + +impl<F: ConstantTimeEq, const SEED_SIZE: usize> ConstantTimeEq for Prio3InputShare<F, SEED_SIZE> { + fn ct_eq(&self, other: &Self) -> Choice { + // We allow short-circuiting on the presence or absence of the joint_rand_blind. + option_ct_eq( + self.joint_rand_blind.as_ref(), + other.joint_rand_blind.as_ref(), + ) & self.measurement_share.ct_eq(&other.measurement_share) + & self.proof_share.ct_eq(&other.proof_share) + } +} + +impl<F: FftFriendlyFieldElement, const SEED_SIZE: usize> Encode for Prio3InputShare<F, SEED_SIZE> { + fn encode(&self, bytes: &mut Vec<u8>) { + if matches!( + (&self.measurement_share, &self.proof_share), + (Share::Leader(_), Share::Helper(_)) | (Share::Helper(_), Share::Leader(_)) + ) { + panic!("tried to encode input share with ambiguous encoding") + } + + self.measurement_share.encode(bytes); + self.proof_share.encode(bytes); + if let Some(ref blind) = self.joint_rand_blind { + blind.encode(bytes); + } + } + + fn encoded_len(&self) -> Option<usize> { + let mut len = self.measurement_share.encoded_len()? + self.proof_share.encoded_len()?; + if let Some(ref blind) = self.joint_rand_blind { + len += blind.encoded_len()?; + } + Some(len) + } +} + +impl<'a, T, P, const SEED_SIZE: usize> ParameterizedDecode<(&'a Prio3<T, P, SEED_SIZE>, usize)> + for Prio3InputShare<T::Field, SEED_SIZE> +where + T: Type, + P: Xof<SEED_SIZE>, +{ + fn decode_with_param( + (prio3, agg_id): &(&'a Prio3<T, P, SEED_SIZE>, usize), + bytes: &mut Cursor<&[u8]>, + ) -> Result<Self, CodecError> { + let agg_id = prio3 + .role_try_from(*agg_id) + .map_err(|e| CodecError::Other(Box::new(e)))?; + let (input_decoder, proof_decoder) = if agg_id == 0 { + ( + ShareDecodingParameter::Leader(prio3.typ.input_len()), + ShareDecodingParameter::Leader(prio3.typ.proof_len()), + ) + } else { + ( + ShareDecodingParameter::Helper, + ShareDecodingParameter::Helper, + ) + }; + + let measurement_share = Share::decode_with_param(&input_decoder, bytes)?; + let proof_share = Share::decode_with_param(&proof_decoder, bytes)?; + let joint_rand_blind = if prio3.typ.joint_rand_len() > 0 { + let blind = Seed::decode(bytes)?; + Some(blind) + } else { + None + }; + + Ok(Prio3InputShare { + measurement_share, + proof_share, + joint_rand_blind, + }) + } +} + +#[derive(Clone, Debug)] +/// Message broadcast by each [`Aggregator`] in each round of the Preparation phase. +pub struct Prio3PrepareShare<F, const SEED_SIZE: usize> { + /// A share of the FLP verifier message. (See [`Type`].) + verifier: Vec<F>, + + /// A part of the joint randomness seed. + joint_rand_part: Option<Seed<SEED_SIZE>>, +} + +impl<F: ConstantTimeEq, const SEED_SIZE: usize> PartialEq for Prio3PrepareShare<F, SEED_SIZE> { + fn eq(&self, other: &Self) -> bool { + self.ct_eq(other).into() + } +} + +impl<F: ConstantTimeEq, const SEED_SIZE: usize> Eq for Prio3PrepareShare<F, SEED_SIZE> {} + +impl<F: ConstantTimeEq, const SEED_SIZE: usize> ConstantTimeEq for Prio3PrepareShare<F, SEED_SIZE> { + fn ct_eq(&self, other: &Self) -> Choice { + // We allow short-circuiting on the presence or absence of the joint_rand_part. + option_ct_eq( + self.joint_rand_part.as_ref(), + other.joint_rand_part.as_ref(), + ) & self.verifier.ct_eq(&other.verifier) + } +} + +impl<F: FftFriendlyFieldElement, const SEED_SIZE: usize> Encode + for Prio3PrepareShare<F, SEED_SIZE> +{ + fn encode(&self, bytes: &mut Vec<u8>) { + for x in &self.verifier { + x.encode(bytes); + } + if let Some(ref seed) = self.joint_rand_part { + seed.encode(bytes); + } + } + + fn encoded_len(&self) -> Option<usize> { + // Each element of the verifier has the same size. + let mut len = F::ENCODED_SIZE * self.verifier.len(); + if let Some(ref seed) = self.joint_rand_part { + len += seed.encoded_len()?; + } + Some(len) + } +} + +impl<F: FftFriendlyFieldElement, const SEED_SIZE: usize> + ParameterizedDecode<Prio3PrepareState<F, SEED_SIZE>> for Prio3PrepareShare<F, SEED_SIZE> +{ + fn decode_with_param( + decoding_parameter: &Prio3PrepareState<F, SEED_SIZE>, + bytes: &mut Cursor<&[u8]>, + ) -> Result<Self, CodecError> { + let mut verifier = Vec::with_capacity(decoding_parameter.verifier_len); + for _ in 0..decoding_parameter.verifier_len { + verifier.push(F::decode(bytes)?); + } + + let joint_rand_part = if decoding_parameter.joint_rand_seed.is_some() { + Some(Seed::decode(bytes)?) + } else { + None + }; + + Ok(Prio3PrepareShare { + verifier, + joint_rand_part, + }) + } +} + +#[derive(Clone, Debug)] +/// Result of combining a round of [`Prio3PrepareShare`] messages. +pub struct Prio3PrepareMessage<const SEED_SIZE: usize> { + /// The joint randomness seed computed by the Aggregators. + joint_rand_seed: Option<Seed<SEED_SIZE>>, +} + +impl<const SEED_SIZE: usize> PartialEq for Prio3PrepareMessage<SEED_SIZE> { + fn eq(&self, other: &Self) -> bool { + self.ct_eq(other).into() + } +} + +impl<const SEED_SIZE: usize> Eq for Prio3PrepareMessage<SEED_SIZE> {} + +impl<const SEED_SIZE: usize> ConstantTimeEq for Prio3PrepareMessage<SEED_SIZE> { + fn ct_eq(&self, other: &Self) -> Choice { + // We allow short-circuiting on the presnce or absence of the joint_rand_seed. + option_ct_eq( + self.joint_rand_seed.as_ref(), + other.joint_rand_seed.as_ref(), + ) + } +} + +impl<const SEED_SIZE: usize> Encode for Prio3PrepareMessage<SEED_SIZE> { + fn encode(&self, bytes: &mut Vec<u8>) { + if let Some(ref seed) = self.joint_rand_seed { + seed.encode(bytes); + } + } + + fn encoded_len(&self) -> Option<usize> { + if let Some(ref seed) = self.joint_rand_seed { + seed.encoded_len() + } else { + Some(0) + } + } +} + +impl<F: FftFriendlyFieldElement, const SEED_SIZE: usize> + ParameterizedDecode<Prio3PrepareState<F, SEED_SIZE>> for Prio3PrepareMessage<SEED_SIZE> +{ + fn decode_with_param( + decoding_parameter: &Prio3PrepareState<F, SEED_SIZE>, + bytes: &mut Cursor<&[u8]>, + ) -> Result<Self, CodecError> { + let joint_rand_seed = if decoding_parameter.joint_rand_seed.is_some() { + Some(Seed::decode(bytes)?) + } else { + None + }; + + Ok(Prio3PrepareMessage { joint_rand_seed }) + } +} + +impl<T, P, const SEED_SIZE: usize> Client<16> for Prio3<T, P, SEED_SIZE> +where + T: Type, + P: Xof<SEED_SIZE>, +{ + #[allow(clippy::type_complexity)] + fn shard( + &self, + measurement: &T::Measurement, + nonce: &[u8; 16], + ) -> Result<(Self::PublicShare, Vec<Prio3InputShare<T::Field, SEED_SIZE>>), VdafError> { + let mut random = vec![0u8; self.random_size()]; + getrandom::getrandom(&mut random)?; + self.shard_with_random(measurement, nonce, &random) + } +} + +/// State of each [`Aggregator`] during the Preparation phase. +#[derive(Clone)] +pub struct Prio3PrepareState<F, const SEED_SIZE: usize> { + measurement_share: Share<F, SEED_SIZE>, + joint_rand_seed: Option<Seed<SEED_SIZE>>, + agg_id: u8, + verifier_len: usize, +} + +impl<F: ConstantTimeEq, const SEED_SIZE: usize> PartialEq for Prio3PrepareState<F, SEED_SIZE> { + fn eq(&self, other: &Self) -> bool { + self.ct_eq(other).into() + } +} + +impl<F: ConstantTimeEq, const SEED_SIZE: usize> Eq for Prio3PrepareState<F, SEED_SIZE> {} + +impl<F: ConstantTimeEq, const SEED_SIZE: usize> ConstantTimeEq for Prio3PrepareState<F, SEED_SIZE> { + fn ct_eq(&self, other: &Self) -> Choice { + // We allow short-circuiting on the presence or absence of the joint_rand_seed, as well as + // the aggregator ID & verifier length parameters. + if self.agg_id != other.agg_id || self.verifier_len != other.verifier_len { + return Choice::from(0); + } + + option_ct_eq( + self.joint_rand_seed.as_ref(), + other.joint_rand_seed.as_ref(), + ) & self.measurement_share.ct_eq(&other.measurement_share) + } +} + +impl<F, const SEED_SIZE: usize> Debug for Prio3PrepareState<F, SEED_SIZE> { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("Prio3PrepareState") + .field("measurement_share", &"[redacted]") + .field( + "joint_rand_seed", + match self.joint_rand_seed { + Some(_) => &"Some([redacted])", + None => &"None", + }, + ) + .field("agg_id", &self.agg_id) + .field("verifier_len", &self.verifier_len) + .finish() + } +} + +impl<F: FftFriendlyFieldElement, const SEED_SIZE: usize> Encode + for Prio3PrepareState<F, SEED_SIZE> +{ + /// Append the encoded form of this object to the end of `bytes`, growing the vector as needed. + fn encode(&self, bytes: &mut Vec<u8>) { + self.measurement_share.encode(bytes); + if let Some(ref seed) = self.joint_rand_seed { + seed.encode(bytes); + } + } + + fn encoded_len(&self) -> Option<usize> { + let mut len = self.measurement_share.encoded_len()?; + if let Some(ref seed) = self.joint_rand_seed { + len += seed.encoded_len()?; + } + Some(len) + } +} + +impl<'a, T, P, const SEED_SIZE: usize> ParameterizedDecode<(&'a Prio3<T, P, SEED_SIZE>, usize)> + for Prio3PrepareState<T::Field, SEED_SIZE> +where + T: Type, + P: Xof<SEED_SIZE>, +{ + fn decode_with_param( + (prio3, agg_id): &(&'a Prio3<T, P, SEED_SIZE>, usize), + bytes: &mut Cursor<&[u8]>, + ) -> Result<Self, CodecError> { + let agg_id = prio3 + .role_try_from(*agg_id) + .map_err(|e| CodecError::Other(Box::new(e)))?; + + let share_decoder = if agg_id == 0 { + ShareDecodingParameter::Leader(prio3.typ.input_len()) + } else { + ShareDecodingParameter::Helper + }; + let measurement_share = Share::decode_with_param(&share_decoder, bytes)?; + + let joint_rand_seed = if prio3.typ.joint_rand_len() > 0 { + Some(Seed::decode(bytes)?) + } else { + None + }; + + Ok(Self { + measurement_share, + joint_rand_seed, + agg_id, + verifier_len: prio3.typ.verifier_len(), + }) + } +} + +impl<T, P, const SEED_SIZE: usize> Aggregator<SEED_SIZE, 16> for Prio3<T, P, SEED_SIZE> +where + T: Type, + P: Xof<SEED_SIZE>, +{ + type PrepareState = Prio3PrepareState<T::Field, SEED_SIZE>; + type PrepareShare = Prio3PrepareShare<T::Field, SEED_SIZE>; + type PrepareMessage = Prio3PrepareMessage<SEED_SIZE>; + + /// Begins the Prep process with the other aggregators. The result of this process is + /// the aggregator's output share. + #[allow(clippy::type_complexity)] + fn prepare_init( + &self, + verify_key: &[u8; SEED_SIZE], + agg_id: usize, + _agg_param: &Self::AggregationParam, + nonce: &[u8; 16], + public_share: &Self::PublicShare, + msg: &Prio3InputShare<T::Field, SEED_SIZE>, + ) -> Result< + ( + Prio3PrepareState<T::Field, SEED_SIZE>, + Prio3PrepareShare<T::Field, SEED_SIZE>, + ), + VdafError, + > { + let agg_id = self.role_try_from(agg_id)?; + let mut query_rand_xof = P::init( + verify_key, + &Self::domain_separation_tag(DST_QUERY_RANDOMNESS), + ); + query_rand_xof.update(nonce); + let query_rand = query_rand_xof + .into_seed_stream() + .into_field_vec(self.typ.query_rand_len()); + + // Create a reference to the (expanded) measurement share. + let expanded_measurement_share: Option<Vec<T::Field>> = match msg.measurement_share { + Share::Leader(_) => None, + Share::Helper(ref seed) => Some( + P::seed_stream( + seed, + &Self::domain_separation_tag(DST_MEASUREMENT_SHARE), + &[agg_id], + ) + .into_field_vec(self.typ.input_len()), + ), + }; + let measurement_share = match msg.measurement_share { + Share::Leader(ref data) => data, + Share::Helper(_) => expanded_measurement_share.as_ref().unwrap(), + }; + + // Create a reference to the (expanded) proof share. + let expanded_proof_share: Option<Vec<T::Field>> = match msg.proof_share { + Share::Leader(_) => None, + Share::Helper(ref seed) => Some( + P::seed_stream( + seed, + &Self::domain_separation_tag(DST_PROOF_SHARE), + &[agg_id], + ) + .into_field_vec(self.typ.proof_len()), + ), + }; + let proof_share = match msg.proof_share { + Share::Leader(ref data) => data, + Share::Helper(_) => expanded_proof_share.as_ref().unwrap(), + }; + + // Compute the joint randomness. + let (joint_rand_seed, joint_rand_part, joint_rand) = if self.typ.joint_rand_len() > 0 { + let mut joint_rand_part_xof = P::init( + msg.joint_rand_blind.as_ref().unwrap().as_ref(), + &Self::domain_separation_tag(DST_JOINT_RAND_PART), + ); + joint_rand_part_xof.update(&[agg_id]); + joint_rand_part_xof.update(nonce); + let mut encoding_buffer = Vec::with_capacity(T::Field::ENCODED_SIZE); + for x in measurement_share { + x.encode(&mut encoding_buffer); + joint_rand_part_xof.update(&encoding_buffer); + encoding_buffer.clear(); + } + let own_joint_rand_part = joint_rand_part_xof.into_seed(); + + // Make an iterator over the joint randomness parts, but use this aggregator's + // contribution, computed from the input share, in lieu of the the corresponding part + // from the public share. + // + // The locally computed part should match the part from the public share for honestly + // generated reports. If they do not match, the joint randomness seed check during the + // next round of preparation should fail. + let corrected_joint_rand_parts = public_share + .joint_rand_parts + .iter() + .flatten() + .take(agg_id as usize) + .chain(iter::once(&own_joint_rand_part)) + .chain( + public_share + .joint_rand_parts + .iter() + .flatten() + .skip(agg_id as usize + 1), + ); + + let joint_rand_seed = Self::derive_joint_rand_seed(corrected_joint_rand_parts); + + let joint_rand = P::seed_stream( + &joint_rand_seed, + &Self::domain_separation_tag(DST_JOINT_RANDOMNESS), + &[], + ) + .into_field_vec(self.typ.joint_rand_len()); + (Some(joint_rand_seed), Some(own_joint_rand_part), joint_rand) + } else { + (None, None, Vec::new()) + }; + + // Run the query-generation algorithm. + let verifier_share = self.typ.query( + measurement_share, + proof_share, + &query_rand, + &joint_rand, + self.num_aggregators as usize, + )?; + + Ok(( + Prio3PrepareState { + measurement_share: msg.measurement_share.clone(), + joint_rand_seed, + agg_id, + verifier_len: verifier_share.len(), + }, + Prio3PrepareShare { + verifier: verifier_share, + joint_rand_part, + }, + )) + } + + fn prepare_shares_to_prepare_message< + M: IntoIterator<Item = Prio3PrepareShare<T::Field, SEED_SIZE>>, + >( + &self, + _: &Self::AggregationParam, + inputs: M, + ) -> Result<Prio3PrepareMessage<SEED_SIZE>, VdafError> { + let mut verifier = vec![T::Field::zero(); self.typ.verifier_len()]; + let mut joint_rand_parts = Vec::with_capacity(self.num_aggregators()); + let mut count = 0; + for share in inputs.into_iter() { + count += 1; + + if share.verifier.len() != verifier.len() { + return Err(VdafError::Uncategorized(format!( + "unexpected verifier share length: got {}; want {}", + share.verifier.len(), + verifier.len(), + ))); + } + + if self.typ.joint_rand_len() > 0 { + let joint_rand_seed_part = share.joint_rand_part.unwrap(); + joint_rand_parts.push(joint_rand_seed_part); + } + + for (x, y) in verifier.iter_mut().zip(share.verifier) { + *x += y; + } + } + + if count != self.num_aggregators { + return Err(VdafError::Uncategorized(format!( + "unexpected message count: got {}; want {}", + count, self.num_aggregators, + ))); + } + + // Check the proof verifier. + match self.typ.decide(&verifier) { + Ok(true) => (), + Ok(false) => { + return Err(VdafError::Uncategorized( + "proof verifier check failed".into(), + )) + } + Err(err) => return Err(VdafError::from(err)), + }; + + let joint_rand_seed = if self.typ.joint_rand_len() > 0 { + Some(Self::derive_joint_rand_seed(joint_rand_parts.iter())) + } else { + None + }; + + Ok(Prio3PrepareMessage { joint_rand_seed }) + } + + fn prepare_next( + &self, + step: Prio3PrepareState<T::Field, SEED_SIZE>, + msg: Prio3PrepareMessage<SEED_SIZE>, + ) -> Result<PrepareTransition<Self, SEED_SIZE, 16>, VdafError> { + if self.typ.joint_rand_len() > 0 { + // Check that the joint randomness was correct. + if step + .joint_rand_seed + .as_ref() + .unwrap() + .ct_ne(msg.joint_rand_seed.as_ref().unwrap()) + .into() + { + return Err(VdafError::Uncategorized( + "joint randomness mismatch".to_string(), + )); + } + } + + // Compute the output share. + let measurement_share = match step.measurement_share { + Share::Leader(data) => data, + Share::Helper(seed) => { + let dst = Self::domain_separation_tag(DST_MEASUREMENT_SHARE); + P::seed_stream(&seed, &dst, &[step.agg_id]).into_field_vec(self.typ.input_len()) + } + }; + + let output_share = match self.typ.truncate(measurement_share) { + Ok(data) => OutputShare(data), + Err(err) => { + return Err(VdafError::from(err)); + } + }; + + Ok(PrepareTransition::Finish(output_share)) + } + + /// Aggregates a sequence of output shares into an aggregate share. + fn aggregate<It: IntoIterator<Item = OutputShare<T::Field>>>( + &self, + _agg_param: &(), + output_shares: It, + ) -> Result<AggregateShare<T::Field>, VdafError> { + let mut agg_share = AggregateShare(vec![T::Field::zero(); self.typ.output_len()]); + for output_share in output_shares.into_iter() { + agg_share.accumulate(&output_share)?; + } + + Ok(agg_share) + } +} + +#[cfg(feature = "experimental")] +impl<T, P, S, const SEED_SIZE: usize> AggregatorWithNoise<SEED_SIZE, 16, S> + for Prio3<T, P, SEED_SIZE> +where + T: TypeWithNoise<S>, + P: Xof<SEED_SIZE>, + S: DifferentialPrivacyStrategy, +{ + fn add_noise_to_agg_share( + &self, + dp_strategy: &S, + _agg_param: &Self::AggregationParam, + agg_share: &mut Self::AggregateShare, + num_measurements: usize, + ) -> Result<(), VdafError> { + self.typ + .add_noise_to_result(dp_strategy, &mut agg_share.0, num_measurements)?; + Ok(()) + } +} + +impl<T, P, const SEED_SIZE: usize> Collector for Prio3<T, P, SEED_SIZE> +where + T: Type, + P: Xof<SEED_SIZE>, +{ + /// Combines aggregate shares into the aggregate result. + fn unshard<It: IntoIterator<Item = AggregateShare<T::Field>>>( + &self, + _agg_param: &Self::AggregationParam, + agg_shares: It, + num_measurements: usize, + ) -> Result<T::AggregateResult, VdafError> { + let mut agg = AggregateShare(vec![T::Field::zero(); self.typ.output_len()]); + for agg_share in agg_shares.into_iter() { + agg.merge(&agg_share)?; + } + + Ok(self.typ.decode_result(&agg.0, num_measurements)?) + } +} + +#[derive(Clone)] +struct HelperShare<const SEED_SIZE: usize> { + measurement_share: Seed<SEED_SIZE>, + proof_share: Seed<SEED_SIZE>, + joint_rand_blind: Option<Seed<SEED_SIZE>>, +} + +impl<const SEED_SIZE: usize> HelperShare<SEED_SIZE> { + fn from_seeds( + measurement_share: [u8; SEED_SIZE], + proof_share: [u8; SEED_SIZE], + joint_rand_blind: Option<[u8; SEED_SIZE]>, + ) -> Self { + HelperShare { + measurement_share: Seed::from_bytes(measurement_share), + proof_share: Seed::from_bytes(proof_share), + joint_rand_blind: joint_rand_blind.map(Seed::from_bytes), + } + } +} + +fn check_num_aggregators(num_aggregators: u8) -> Result<(), VdafError> { + if num_aggregators == 0 { + return Err(VdafError::Uncategorized(format!( + "at least one aggregator is required; got {num_aggregators}" + ))); + } else if num_aggregators > 254 { + return Err(VdafError::Uncategorized(format!( + "number of aggregators must not exceed 254; got {num_aggregators}" + ))); + } + + Ok(()) +} + +impl<'a, F, T, P, const SEED_SIZE: usize> ParameterizedDecode<(&'a Prio3<T, P, SEED_SIZE>, &'a ())> + for OutputShare<F> +where + F: FieldElement, + T: Type, + P: Xof<SEED_SIZE>, +{ + fn decode_with_param( + (vdaf, _): &(&'a Prio3<T, P, SEED_SIZE>, &'a ()), + bytes: &mut Cursor<&[u8]>, + ) -> Result<Self, CodecError> { + decode_fieldvec(vdaf.output_len(), bytes).map(Self) + } +} + +impl<'a, F, T, P, const SEED_SIZE: usize> ParameterizedDecode<(&'a Prio3<T, P, SEED_SIZE>, &'a ())> + for AggregateShare<F> +where + F: FieldElement, + T: Type, + P: Xof<SEED_SIZE>, +{ + fn decode_with_param( + (vdaf, _): &(&'a Prio3<T, P, SEED_SIZE>, &'a ()), + bytes: &mut Cursor<&[u8]>, + ) -> Result<Self, CodecError> { + decode_fieldvec(vdaf.output_len(), bytes).map(Self) + } +} + +// This function determines equality between two optional, constant-time comparable values. It +// short-circuits on the existence (but not contents) of the values -- a timing side-channel may +// reveal whether the values match on Some or None. +#[inline] +fn option_ct_eq<T>(left: Option<&T>, right: Option<&T>) -> Choice +where + T: ConstantTimeEq + ?Sized, +{ + match (left, right) { + (Some(left), Some(right)) => left.ct_eq(right), + (None, None) => Choice::from(1), + _ => Choice::from(0), + } +} + +/// This is a polyfill for `usize::ilog2()`, which is only available in Rust 1.67 and later. It is +/// based on the implementation in the standard library. It can be removed when the MSRV has been +/// advanced past 1.67. +/// +/// # Panics +/// +/// This function will panic if `input` is zero. +fn ilog2(input: usize) -> u32 { + if input == 0 { + panic!("Tried to take the logarithm of zero"); + } + (usize::BITS - 1) - input.leading_zeros() +} + +/// Finds the optimal choice of chunk length for [`Prio3Histogram`] or [`Prio3SumVec`], given its +/// encoded measurement length. For [`Prio3Histogram`], the measurement length is equal to the +/// length parameter. For [`Prio3SumVec`], the measurement length is equal to the product of the +/// length and bits parameters. +pub fn optimal_chunk_length(measurement_length: usize) -> usize { + if measurement_length <= 1 { + return 1; + } + + /// Candidate set of parameter choices for the parallel sum optimization. + struct Candidate { + gadget_calls: usize, + chunk_length: usize, + } + + let max_log2 = ilog2(measurement_length + 1); + let best_opt = (1..=max_log2) + .rev() + .map(|log2| { + let gadget_calls = (1 << log2) - 1; + let chunk_length = (measurement_length + gadget_calls - 1) / gadget_calls; + Candidate { + gadget_calls, + chunk_length, + } + }) + .min_by_key(|candidate| { + // Compute the proof length, in field elements, for either Prio3Histogram or Prio3SumVec + (candidate.chunk_length * 2) + + 2 * ((1 + candidate.gadget_calls).next_power_of_two() - 1) + }); + // Unwrap safety: max_log2 must be at least 1, because smaller measurement_length inputs are + // dealt with separately. Thus, the range iterator that the search is over will be nonempty, + // and min_by_key() will always return Some. + best_opt.unwrap().chunk_length +} + +#[cfg(test)] +mod tests { + use super::*; + #[cfg(feature = "experimental")] + use crate::flp::gadgets::ParallelSumGadget; + use crate::vdaf::{ + equality_comparison_test, fieldvec_roundtrip_test, run_vdaf, run_vdaf_prepare, + }; + use assert_matches::assert_matches; + #[cfg(feature = "experimental")] + use fixed::{ + types::extra::{U15, U31, U63}, + FixedI16, FixedI32, FixedI64, + }; + #[cfg(feature = "experimental")] + use fixed_macro::fixed; + use rand::prelude::*; + + #[test] + fn test_prio3_count() { + let prio3 = Prio3::new_count(2).unwrap(); + + assert_eq!(run_vdaf(&prio3, &(), [1, 0, 0, 1, 1]).unwrap(), 3); + + let mut nonce = [0; 16]; + let mut verify_key = [0; 16]; + thread_rng().fill(&mut verify_key[..]); + thread_rng().fill(&mut nonce[..]); + + let (public_share, input_shares) = prio3.shard(&0, &nonce).unwrap(); + run_vdaf_prepare(&prio3, &verify_key, &(), &nonce, public_share, input_shares).unwrap(); + + let (public_share, input_shares) = prio3.shard(&1, &nonce).unwrap(); + run_vdaf_prepare(&prio3, &verify_key, &(), &nonce, public_share, input_shares).unwrap(); + + test_serialization(&prio3, &1, &nonce).unwrap(); + + let prio3_extra_helper = Prio3::new_count(3).unwrap(); + assert_eq!( + run_vdaf(&prio3_extra_helper, &(), [1, 0, 0, 1, 1]).unwrap(), + 3, + ); + } + + #[test] + fn test_prio3_sum() { + let prio3 = Prio3::new_sum(3, 16).unwrap(); + + assert_eq!( + run_vdaf(&prio3, &(), [0, (1 << 16) - 1, 0, 1, 1]).unwrap(), + (1 << 16) + 1 + ); + + let mut verify_key = [0; 16]; + thread_rng().fill(&mut verify_key[..]); + let nonce = [0; 16]; + + let (public_share, mut input_shares) = prio3.shard(&1, &nonce).unwrap(); + input_shares[0].joint_rand_blind.as_mut().unwrap().0[0] ^= 255; + let result = run_vdaf_prepare(&prio3, &verify_key, &(), &nonce, public_share, input_shares); + assert_matches!(result, Err(VdafError::Uncategorized(_))); + + let (public_share, mut input_shares) = prio3.shard(&1, &nonce).unwrap(); + assert_matches!(input_shares[0].measurement_share, Share::Leader(ref mut data) => { + data[0] += Field128::one(); + }); + let result = run_vdaf_prepare(&prio3, &verify_key, &(), &nonce, public_share, input_shares); + assert_matches!(result, Err(VdafError::Uncategorized(_))); + + let (public_share, mut input_shares) = prio3.shard(&1, &nonce).unwrap(); + assert_matches!(input_shares[0].proof_share, Share::Leader(ref mut data) => { + data[0] += Field128::one(); + }); + let result = run_vdaf_prepare(&prio3, &verify_key, &(), &nonce, public_share, input_shares); + assert_matches!(result, Err(VdafError::Uncategorized(_))); + + test_serialization(&prio3, &1, &nonce).unwrap(); + } + + #[test] + fn test_prio3_sum_vec() { + let prio3 = Prio3::new_sum_vec(2, 2, 20, 4).unwrap(); + assert_eq!( + run_vdaf( + &prio3, + &(), + [ + vec![0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 1, 0, 1, 1, 1, 0, 1], + vec![0, 2, 0, 0, 1, 0, 0, 0, 1, 1, 1, 3, 0, 3, 0, 0, 0, 1, 0, 0], + vec![1, 1, 1, 0, 1, 1, 0, 1, 1, 1, 1, 0, 1, 1, 1, 1, 0, 1, 0, 1], + ] + ) + .unwrap(), + vec![1, 3, 1, 0, 3, 1, 0, 1, 2, 2, 3, 3, 1, 5, 1, 2, 1, 3, 0, 2], + ); + } + + #[test] + #[cfg(feature = "multithreaded")] + fn test_prio3_sum_vec_multithreaded() { + let prio3 = Prio3::new_sum_vec_multithreaded(2, 2, 20, 4).unwrap(); + assert_eq!( + run_vdaf( + &prio3, + &(), + [ + vec![0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 1, 0, 1, 1, 1, 0, 1], + vec![0, 2, 0, 0, 1, 0, 0, 0, 1, 1, 1, 3, 0, 3, 0, 0, 0, 1, 0, 0], + vec![1, 1, 1, 0, 1, 1, 0, 1, 1, 1, 1, 0, 1, 1, 1, 1, 0, 1, 0, 1], + ] + ) + .unwrap(), + vec![1, 3, 1, 0, 3, 1, 0, 1, 2, 2, 3, 3, 1, 5, 1, 2, 1, 3, 0, 2], + ); + } + + #[test] + #[cfg(feature = "experimental")] + fn test_prio3_bounded_fpvec_sum_unaligned() { + type P<Fx> = Prio3FixedPointBoundedL2VecSum<Fx>; + #[cfg(feature = "multithreaded")] + type PM<Fx> = Prio3FixedPointBoundedL2VecSumMultithreaded<Fx>; + let ctor_32 = P::<FixedI32<U31>>::new_fixedpoint_boundedl2_vec_sum; + #[cfg(feature = "multithreaded")] + let ctor_mt_32 = PM::<FixedI32<U31>>::new_fixedpoint_boundedl2_vec_sum_multithreaded; + + { + const SIZE: usize = 5; + let fp32_0 = fixed!(0: I1F31); + + // 32 bit fixedpoint, non-power-of-2 vector, single-threaded + { + let prio3_32 = ctor_32(2, SIZE).unwrap(); + test_fixed_vec::<_, _, _, SIZE>(fp32_0, prio3_32); + } + + // 32 bit fixedpoint, non-power-of-2 vector, multi-threaded + #[cfg(feature = "multithreaded")] + { + let prio3_mt_32 = ctor_mt_32(2, SIZE).unwrap(); + test_fixed_vec::<_, _, _, SIZE>(fp32_0, prio3_mt_32); + } + } + + fn test_fixed_vec<Fx, PE, M, const SIZE: usize>( + fp_0: Fx, + prio3: Prio3<FixedPointBoundedL2VecSum<Fx, PE, M>, XofShake128, 16>, + ) where + Fx: Fixed + CompatibleFloat + std::ops::Neg<Output = Fx>, + PE: Eq + ParallelSumGadget<Field128, PolyEval<Field128>> + Clone + 'static, + M: Eq + ParallelSumGadget<Field128, Mul<Field128>> + Clone + 'static, + { + let fp_vec = vec![fp_0; SIZE]; + + let measurements = [fp_vec.clone(), fp_vec]; + assert_eq!( + run_vdaf(&prio3, &(), measurements).unwrap(), + vec![0.0; SIZE] + ); + } + } + + #[test] + #[cfg(feature = "experimental")] + fn test_prio3_bounded_fpvec_sum() { + type P<Fx> = Prio3FixedPointBoundedL2VecSum<Fx>; + let ctor_16 = P::<FixedI16<U15>>::new_fixedpoint_boundedl2_vec_sum; + let ctor_32 = P::<FixedI32<U31>>::new_fixedpoint_boundedl2_vec_sum; + let ctor_64 = P::<FixedI64<U63>>::new_fixedpoint_boundedl2_vec_sum; + + #[cfg(feature = "multithreaded")] + type PM<Fx> = Prio3FixedPointBoundedL2VecSumMultithreaded<Fx>; + #[cfg(feature = "multithreaded")] + let ctor_mt_16 = PM::<FixedI16<U15>>::new_fixedpoint_boundedl2_vec_sum_multithreaded; + #[cfg(feature = "multithreaded")] + let ctor_mt_32 = PM::<FixedI32<U31>>::new_fixedpoint_boundedl2_vec_sum_multithreaded; + #[cfg(feature = "multithreaded")] + let ctor_mt_64 = PM::<FixedI64<U63>>::new_fixedpoint_boundedl2_vec_sum_multithreaded; + + { + // 16 bit fixedpoint + let fp16_4_inv = fixed!(0.25: I1F15); + let fp16_8_inv = fixed!(0.125: I1F15); + let fp16_16_inv = fixed!(0.0625: I1F15); + + // two aggregators, three entries per vector. + { + let prio3_16 = ctor_16(2, 3).unwrap(); + test_fixed(fp16_4_inv, fp16_8_inv, fp16_16_inv, prio3_16); + } + + #[cfg(feature = "multithreaded")] + { + let prio3_16_mt = ctor_mt_16(2, 3).unwrap(); + test_fixed(fp16_4_inv, fp16_8_inv, fp16_16_inv, prio3_16_mt); + } + } + + { + // 32 bit fixedpoint + let fp32_4_inv = fixed!(0.25: I1F31); + let fp32_8_inv = fixed!(0.125: I1F31); + let fp32_16_inv = fixed!(0.0625: I1F31); + + { + let prio3_32 = ctor_32(2, 3).unwrap(); + test_fixed(fp32_4_inv, fp32_8_inv, fp32_16_inv, prio3_32); + } + + #[cfg(feature = "multithreaded")] + { + let prio3_32_mt = ctor_mt_32(2, 3).unwrap(); + test_fixed(fp32_4_inv, fp32_8_inv, fp32_16_inv, prio3_32_mt); + } + } + + { + // 64 bit fixedpoint + let fp64_4_inv = fixed!(0.25: I1F63); + let fp64_8_inv = fixed!(0.125: I1F63); + let fp64_16_inv = fixed!(0.0625: I1F63); + + { + let prio3_64 = ctor_64(2, 3).unwrap(); + test_fixed(fp64_4_inv, fp64_8_inv, fp64_16_inv, prio3_64); + } + + #[cfg(feature = "multithreaded")] + { + let prio3_64_mt = ctor_mt_64(2, 3).unwrap(); + test_fixed(fp64_4_inv, fp64_8_inv, fp64_16_inv, prio3_64_mt); + } + } + + fn test_fixed<Fx, PE, M>( + fp_4_inv: Fx, + fp_8_inv: Fx, + fp_16_inv: Fx, + prio3: Prio3<FixedPointBoundedL2VecSum<Fx, PE, M>, XofShake128, 16>, + ) where + Fx: Fixed + CompatibleFloat + std::ops::Neg<Output = Fx>, + PE: Eq + ParallelSumGadget<Field128, PolyEval<Field128>> + Clone + 'static, + M: Eq + ParallelSumGadget<Field128, Mul<Field128>> + Clone + 'static, + { + let fp_vec1 = vec![fp_4_inv, fp_8_inv, fp_16_inv]; + let fp_vec2 = vec![fp_4_inv, fp_8_inv, fp_16_inv]; + + let fp_vec3 = vec![-fp_4_inv, -fp_8_inv, -fp_16_inv]; + let fp_vec4 = vec![-fp_4_inv, -fp_8_inv, -fp_16_inv]; + + let fp_vec5 = vec![fp_4_inv, -fp_8_inv, -fp_16_inv]; + let fp_vec6 = vec![fp_4_inv, fp_8_inv, fp_16_inv]; + + // positive entries + let fp_list = [fp_vec1, fp_vec2]; + assert_eq!( + run_vdaf(&prio3, &(), fp_list).unwrap(), + vec!(0.5, 0.25, 0.125), + ); + + // negative entries + let fp_list2 = [fp_vec3, fp_vec4]; + assert_eq!( + run_vdaf(&prio3, &(), fp_list2).unwrap(), + vec!(-0.5, -0.25, -0.125), + ); + + // both + let fp_list3 = [fp_vec5, fp_vec6]; + assert_eq!( + run_vdaf(&prio3, &(), fp_list3).unwrap(), + vec!(0.5, 0.0, 0.0), + ); + + let mut verify_key = [0; 16]; + let mut nonce = [0; 16]; + thread_rng().fill(&mut verify_key); + thread_rng().fill(&mut nonce); + + let (public_share, mut input_shares) = prio3 + .shard(&vec![fp_4_inv, fp_8_inv, fp_16_inv], &nonce) + .unwrap(); + input_shares[0].joint_rand_blind.as_mut().unwrap().0[0] ^= 255; + let result = + run_vdaf_prepare(&prio3, &verify_key, &(), &nonce, public_share, input_shares); + assert_matches!(result, Err(VdafError::Uncategorized(_))); + + let (public_share, mut input_shares) = prio3 + .shard(&vec![fp_4_inv, fp_8_inv, fp_16_inv], &nonce) + .unwrap(); + assert_matches!(input_shares[0].measurement_share, Share::Leader(ref mut data) => { + data[0] += Field128::one(); + }); + let result = + run_vdaf_prepare(&prio3, &verify_key, &(), &nonce, public_share, input_shares); + assert_matches!(result, Err(VdafError::Uncategorized(_))); + + let (public_share, mut input_shares) = prio3 + .shard(&vec![fp_4_inv, fp_8_inv, fp_16_inv], &nonce) + .unwrap(); + assert_matches!(input_shares[0].proof_share, Share::Leader(ref mut data) => { + data[0] += Field128::one(); + }); + let result = + run_vdaf_prepare(&prio3, &verify_key, &(), &nonce, public_share, input_shares); + assert_matches!(result, Err(VdafError::Uncategorized(_))); + + test_serialization(&prio3, &vec![fp_4_inv, fp_8_inv, fp_16_inv], &nonce).unwrap(); + } + } + + #[test] + fn test_prio3_histogram() { + let prio3 = Prio3::new_histogram(2, 4, 2).unwrap(); + + assert_eq!( + run_vdaf(&prio3, &(), [0, 1, 2, 3]).unwrap(), + vec![1, 1, 1, 1] + ); + assert_eq!(run_vdaf(&prio3, &(), [0]).unwrap(), vec![1, 0, 0, 0]); + assert_eq!(run_vdaf(&prio3, &(), [1]).unwrap(), vec![0, 1, 0, 0]); + assert_eq!(run_vdaf(&prio3, &(), [2]).unwrap(), vec![0, 0, 1, 0]); + assert_eq!(run_vdaf(&prio3, &(), [3]).unwrap(), vec![0, 0, 0, 1]); + test_serialization(&prio3, &3, &[0; 16]).unwrap(); + } + + #[test] + #[cfg(feature = "multithreaded")] + fn test_prio3_histogram_multithreaded() { + let prio3 = Prio3::new_histogram_multithreaded(2, 4, 2).unwrap(); + + assert_eq!( + run_vdaf(&prio3, &(), [0, 1, 2, 3]).unwrap(), + vec![1, 1, 1, 1] + ); + assert_eq!(run_vdaf(&prio3, &(), [0]).unwrap(), vec![1, 0, 0, 0]); + assert_eq!(run_vdaf(&prio3, &(), [1]).unwrap(), vec![0, 1, 0, 0]); + assert_eq!(run_vdaf(&prio3, &(), [2]).unwrap(), vec![0, 0, 1, 0]); + assert_eq!(run_vdaf(&prio3, &(), [3]).unwrap(), vec![0, 0, 0, 1]); + test_serialization(&prio3, &3, &[0; 16]).unwrap(); + } + + #[test] + fn test_prio3_average() { + let prio3 = Prio3::new_average(2, 64).unwrap(); + + assert_eq!(run_vdaf(&prio3, &(), [17, 8]).unwrap(), 12.5f64); + assert_eq!(run_vdaf(&prio3, &(), [1, 1, 1, 1]).unwrap(), 1f64); + assert_eq!(run_vdaf(&prio3, &(), [0, 0, 0, 1]).unwrap(), 0.25f64); + assert_eq!( + run_vdaf(&prio3, &(), [1, 11, 111, 1111, 3, 8]).unwrap(), + 207.5f64 + ); + } + + #[test] + fn test_prio3_input_share() { + let prio3 = Prio3::new_sum(5, 16).unwrap(); + let (_public_share, input_shares) = prio3.shard(&1, &[0; 16]).unwrap(); + + // Check that seed shares are distinct. + for (i, x) in input_shares.iter().enumerate() { + for (j, y) in input_shares.iter().enumerate() { + if i != j { + if let (Share::Helper(left), Share::Helper(right)) = + (&x.measurement_share, &y.measurement_share) + { + assert_ne!(left, right); + } + + if let (Share::Helper(left), Share::Helper(right)) = + (&x.proof_share, &y.proof_share) + { + assert_ne!(left, right); + } + + assert_ne!(x.joint_rand_blind, y.joint_rand_blind); + } + } + } + } + + fn test_serialization<T, P, const SEED_SIZE: usize>( + prio3: &Prio3<T, P, SEED_SIZE>, + measurement: &T::Measurement, + nonce: &[u8; 16], + ) -> Result<(), VdafError> + where + T: Type, + P: Xof<SEED_SIZE>, + { + let mut verify_key = [0; SEED_SIZE]; + thread_rng().fill(&mut verify_key[..]); + let (public_share, input_shares) = prio3.shard(measurement, nonce)?; + + let encoded_public_share = public_share.get_encoded(); + let decoded_public_share = + Prio3PublicShare::get_decoded_with_param(prio3, &encoded_public_share) + .expect("failed to decode public share"); + assert_eq!(decoded_public_share, public_share); + assert_eq!( + public_share.encoded_len().unwrap(), + encoded_public_share.len() + ); + + for (agg_id, input_share) in input_shares.iter().enumerate() { + let encoded_input_share = input_share.get_encoded(); + let decoded_input_share = + Prio3InputShare::get_decoded_with_param(&(prio3, agg_id), &encoded_input_share) + .expect("failed to decode input share"); + assert_eq!(&decoded_input_share, input_share); + assert_eq!( + input_share.encoded_len().unwrap(), + encoded_input_share.len() + ); + } + + let mut prepare_shares = Vec::new(); + let mut last_prepare_state = None; + for (agg_id, input_share) in input_shares.iter().enumerate() { + let (prepare_state, prepare_share) = + prio3.prepare_init(&verify_key, agg_id, &(), nonce, &public_share, input_share)?; + + let encoded_prepare_state = prepare_state.get_encoded(); + let decoded_prepare_state = + Prio3PrepareState::get_decoded_with_param(&(prio3, agg_id), &encoded_prepare_state) + .expect("failed to decode prepare state"); + assert_eq!(decoded_prepare_state, prepare_state); + assert_eq!( + prepare_state.encoded_len().unwrap(), + encoded_prepare_state.len() + ); + + let encoded_prepare_share = prepare_share.get_encoded(); + let decoded_prepare_share = + Prio3PrepareShare::get_decoded_with_param(&prepare_state, &encoded_prepare_share) + .expect("failed to decode prepare share"); + assert_eq!(decoded_prepare_share, prepare_share); + assert_eq!( + prepare_share.encoded_len().unwrap(), + encoded_prepare_share.len() + ); + + prepare_shares.push(prepare_share); + last_prepare_state = Some(prepare_state); + } + + let prepare_message = prio3 + .prepare_shares_to_prepare_message(&(), prepare_shares) + .unwrap(); + + let encoded_prepare_message = prepare_message.get_encoded(); + let decoded_prepare_message = Prio3PrepareMessage::get_decoded_with_param( + &last_prepare_state.unwrap(), + &encoded_prepare_message, + ) + .expect("failed to decode prepare message"); + assert_eq!(decoded_prepare_message, prepare_message); + assert_eq!( + prepare_message.encoded_len().unwrap(), + encoded_prepare_message.len() + ); + + Ok(()) + } + + #[test] + fn roundtrip_output_share() { + let vdaf = Prio3::new_count(2).unwrap(); + fieldvec_roundtrip_test::<Field64, Prio3Count, OutputShare<Field64>>(&vdaf, &(), 1); + + let vdaf = Prio3::new_sum(2, 17).unwrap(); + fieldvec_roundtrip_test::<Field128, Prio3Sum, OutputShare<Field128>>(&vdaf, &(), 1); + + let vdaf = Prio3::new_histogram(2, 12, 3).unwrap(); + fieldvec_roundtrip_test::<Field128, Prio3Histogram, OutputShare<Field128>>(&vdaf, &(), 12); + } + + #[test] + fn roundtrip_aggregate_share() { + let vdaf = Prio3::new_count(2).unwrap(); + fieldvec_roundtrip_test::<Field64, Prio3Count, AggregateShare<Field64>>(&vdaf, &(), 1); + + let vdaf = Prio3::new_sum(2, 17).unwrap(); + fieldvec_roundtrip_test::<Field128, Prio3Sum, AggregateShare<Field128>>(&vdaf, &(), 1); + + let vdaf = Prio3::new_histogram(2, 12, 3).unwrap(); + fieldvec_roundtrip_test::<Field128, Prio3Histogram, AggregateShare<Field128>>( + &vdaf, + &(), + 12, + ); + } + + #[test] + fn public_share_equality_test() { + equality_comparison_test(&[ + Prio3PublicShare { + joint_rand_parts: Some(Vec::from([Seed([0])])), + }, + Prio3PublicShare { + joint_rand_parts: Some(Vec::from([Seed([1])])), + }, + Prio3PublicShare { + joint_rand_parts: None, + }, + ]) + } + + #[test] + fn input_share_equality_test() { + equality_comparison_test(&[ + // Default. + Prio3InputShare { + measurement_share: Share::Leader(Vec::from([0])), + proof_share: Share::Leader(Vec::from([1])), + joint_rand_blind: Some(Seed([2])), + }, + // Modified measurement share. + Prio3InputShare { + measurement_share: Share::Leader(Vec::from([100])), + proof_share: Share::Leader(Vec::from([1])), + joint_rand_blind: Some(Seed([2])), + }, + // Modified proof share. + Prio3InputShare { + measurement_share: Share::Leader(Vec::from([0])), + proof_share: Share::Leader(Vec::from([101])), + joint_rand_blind: Some(Seed([2])), + }, + // Modified joint_rand_blind. + Prio3InputShare { + measurement_share: Share::Leader(Vec::from([0])), + proof_share: Share::Leader(Vec::from([1])), + joint_rand_blind: Some(Seed([102])), + }, + // Missing joint_rand_blind. + Prio3InputShare { + measurement_share: Share::Leader(Vec::from([0])), + proof_share: Share::Leader(Vec::from([1])), + joint_rand_blind: None, + }, + ]) + } + + #[test] + fn prepare_share_equality_test() { + equality_comparison_test(&[ + // Default. + Prio3PrepareShare { + verifier: Vec::from([0]), + joint_rand_part: Some(Seed([1])), + }, + // Modified verifier. + Prio3PrepareShare { + verifier: Vec::from([100]), + joint_rand_part: Some(Seed([1])), + }, + // Modified joint_rand_part. + Prio3PrepareShare { + verifier: Vec::from([0]), + joint_rand_part: Some(Seed([101])), + }, + // Missing joint_rand_part. + Prio3PrepareShare { + verifier: Vec::from([0]), + joint_rand_part: None, + }, + ]) + } + + #[test] + fn prepare_message_equality_test() { + equality_comparison_test(&[ + // Default. + Prio3PrepareMessage { + joint_rand_seed: Some(Seed([0])), + }, + // Modified joint_rand_seed. + Prio3PrepareMessage { + joint_rand_seed: Some(Seed([100])), + }, + // Missing joint_rand_seed. + Prio3PrepareMessage { + joint_rand_seed: None, + }, + ]) + } + + #[test] + fn prepare_state_equality_test() { + equality_comparison_test(&[ + // Default. + Prio3PrepareState { + measurement_share: Share::Leader(Vec::from([0])), + joint_rand_seed: Some(Seed([1])), + agg_id: 2, + verifier_len: 3, + }, + // Modified measurement share. + Prio3PrepareState { + measurement_share: Share::Leader(Vec::from([100])), + joint_rand_seed: Some(Seed([1])), + agg_id: 2, + verifier_len: 3, + }, + // Modified joint_rand_seed. + Prio3PrepareState { + measurement_share: Share::Leader(Vec::from([0])), + joint_rand_seed: Some(Seed([101])), + agg_id: 2, + verifier_len: 3, + }, + // Missing joint_rand_seed. + Prio3PrepareState { + measurement_share: Share::Leader(Vec::from([0])), + joint_rand_seed: None, + agg_id: 2, + verifier_len: 3, + }, + // Modified agg_id. + Prio3PrepareState { + measurement_share: Share::Leader(Vec::from([0])), + joint_rand_seed: Some(Seed([1])), + agg_id: 102, + verifier_len: 3, + }, + // Modified verifier_len. + Prio3PrepareState { + measurement_share: Share::Leader(Vec::from([0])), + joint_rand_seed: Some(Seed([1])), + agg_id: 2, + verifier_len: 103, + }, + ]) + } + + #[test] + fn test_optimal_chunk_length() { + // nonsense argument, but make sure it doesn't panic. + optimal_chunk_length(0); + + // edge cases on either side of power-of-two jumps + assert_eq!(optimal_chunk_length(1), 1); + assert_eq!(optimal_chunk_length(2), 2); + assert_eq!(optimal_chunk_length(3), 1); + assert_eq!(optimal_chunk_length(18), 6); + assert_eq!(optimal_chunk_length(19), 3); + + // additional arbitrary test cases + assert_eq!(optimal_chunk_length(40), 6); + assert_eq!(optimal_chunk_length(10_000), 79); + assert_eq!(optimal_chunk_length(100_000), 393); + + // confirm that the chunk lengths are truly optimal + for measurement_length in [2, 3, 4, 5, 18, 19, 40] { + let optimal_chunk_length = optimal_chunk_length(measurement_length); + let optimal_proof_length = Histogram::<Field128, ParallelSum<_, _>>::new( + measurement_length, + optimal_chunk_length, + ) + .unwrap() + .proof_len(); + for chunk_length in 1..=measurement_length { + let proof_length = + Histogram::<Field128, ParallelSum<_, _>>::new(measurement_length, chunk_length) + .unwrap() + .proof_len(); + assert!(proof_length >= optimal_proof_length); + } + } + } +} diff --git a/third_party/rust/prio/src/vdaf/prio3_test.rs b/third_party/rust/prio/src/vdaf/prio3_test.rs new file mode 100644 index 0000000000..372a2c8560 --- /dev/null +++ b/third_party/rust/prio/src/vdaf/prio3_test.rs @@ -0,0 +1,251 @@ +// SPDX-License-Identifier: MPL-2.0 + +use crate::{ + codec::{Encode, ParameterizedDecode}, + flp::Type, + vdaf::{ + prio3::{Prio3, Prio3InputShare, Prio3PrepareShare, Prio3PublicShare}, + xof::Xof, + Aggregator, Collector, OutputShare, PrepareTransition, Vdaf, + }, +}; +use serde::{Deserialize, Serialize}; +use std::{collections::HashMap, convert::TryInto, fmt::Debug}; + +#[derive(Debug, Deserialize, Serialize)] +struct TEncoded(#[serde(with = "hex")] Vec<u8>); + +impl AsRef<[u8]> for TEncoded { + fn as_ref(&self) -> &[u8] { + &self.0 + } +} + +#[derive(Deserialize, Serialize)] +struct TPrio3Prep<M> { + measurement: M, + #[serde(with = "hex")] + nonce: Vec<u8>, + #[serde(with = "hex")] + rand: Vec<u8>, + public_share: TEncoded, + input_shares: Vec<TEncoded>, + prep_shares: Vec<Vec<TEncoded>>, + prep_messages: Vec<TEncoded>, + out_shares: Vec<Vec<TEncoded>>, +} + +#[derive(Deserialize, Serialize)] +struct TPrio3<M> { + verify_key: TEncoded, + shares: u8, + prep: Vec<TPrio3Prep<M>>, + agg_shares: Vec<TEncoded>, + agg_result: serde_json::Value, + #[serde(flatten)] + other_params: HashMap<String, serde_json::Value>, +} + +macro_rules! err { + ( + $test_num:ident, + $error:expr, + $msg:expr + ) => { + panic!("test #{} failed: {} err: {}", $test_num, $msg, $error) + }; +} + +// TODO Generalize this method to work with any VDAF. To do so we would need to add +// `shard_with_random()` to traits. (There may be a less invasive alternative.) +fn check_prep_test_vec<M, T, P, const SEED_SIZE: usize>( + prio3: &Prio3<T, P, SEED_SIZE>, + verify_key: &[u8; SEED_SIZE], + test_num: usize, + t: &TPrio3Prep<M>, +) -> Vec<OutputShare<T::Field>> +where + T: Type<Measurement = M>, + P: Xof<SEED_SIZE>, +{ + let nonce = <[u8; 16]>::try_from(t.nonce.clone()).unwrap(); + let (public_share, input_shares) = prio3 + .shard_with_random(&t.measurement, &nonce, &t.rand) + .expect("failed to generate input shares"); + + assert_eq!( + public_share, + Prio3PublicShare::get_decoded_with_param(prio3, t.public_share.as_ref()) + .unwrap_or_else(|e| err!(test_num, e, "decode test vector (public share)")), + ); + for (agg_id, want) in t.input_shares.iter().enumerate() { + assert_eq!( + input_shares[agg_id], + Prio3InputShare::get_decoded_with_param(&(prio3, agg_id), want.as_ref()) + .unwrap_or_else(|e| err!(test_num, e, "decode test vector (input share)")), + "#{test_num}" + ); + assert_eq!( + input_shares[agg_id].get_encoded(), + want.as_ref(), + "#{test_num}" + ) + } + + let mut states = Vec::new(); + let mut prep_shares = Vec::new(); + for (agg_id, input_share) in input_shares.iter().enumerate() { + let (state, prep_share) = prio3 + .prepare_init(verify_key, agg_id, &(), &nonce, &public_share, input_share) + .unwrap_or_else(|e| err!(test_num, e, "prep state init")); + states.push(state); + prep_shares.push(prep_share); + } + + assert_eq!(1, t.prep_shares.len(), "#{test_num}"); + for (i, want) in t.prep_shares[0].iter().enumerate() { + assert_eq!( + prep_shares[i], + Prio3PrepareShare::get_decoded_with_param(&states[i], want.as_ref()) + .unwrap_or_else(|e| err!(test_num, e, "decode test vector (prep share)")), + "#{test_num}" + ); + assert_eq!(prep_shares[i].get_encoded(), want.as_ref(), "#{test_num}"); + } + + let inbound = prio3 + .prepare_shares_to_prepare_message(&(), prep_shares) + .unwrap_or_else(|e| err!(test_num, e, "prep preprocess")); + assert_eq!(t.prep_messages.len(), 1); + assert_eq!(inbound.get_encoded(), t.prep_messages[0].as_ref()); + + let mut out_shares = Vec::new(); + for state in states.iter_mut() { + match prio3.prepare_next(state.clone(), inbound.clone()).unwrap() { + PrepareTransition::Finish(out_share) => { + out_shares.push(out_share); + } + _ => panic!("unexpected transition"), + } + } + + for (got, want) in out_shares.iter().zip(t.out_shares.iter()) { + let got: Vec<Vec<u8>> = got.as_ref().iter().map(|x| x.get_encoded()).collect(); + assert_eq!(got.len(), want.len()); + for (got_elem, want_elem) in got.iter().zip(want.iter()) { + assert_eq!(got_elem.as_slice(), want_elem.as_ref()); + } + } + + out_shares +} + +#[must_use] +fn check_aggregate_test_vec<M, T, P, const SEED_SIZE: usize>( + prio3: &Prio3<T, P, SEED_SIZE>, + t: &TPrio3<M>, +) -> T::AggregateResult +where + T: Type<Measurement = M>, + P: Xof<SEED_SIZE>, +{ + let verify_key = t.verify_key.as_ref().try_into().unwrap(); + + let mut all_output_shares = vec![Vec::new(); prio3.num_aggregators()]; + for (test_num, p) in t.prep.iter().enumerate() { + let output_shares = check_prep_test_vec(prio3, verify_key, test_num, p); + for (aggregator_output_shares, output_share) in + all_output_shares.iter_mut().zip(output_shares.into_iter()) + { + aggregator_output_shares.push(output_share); + } + } + + let aggregate_shares = all_output_shares + .into_iter() + .map(|aggregator_output_shares| prio3.aggregate(&(), aggregator_output_shares).unwrap()) + .collect::<Vec<_>>(); + + for (got, want) in aggregate_shares.iter().zip(t.agg_shares.iter()) { + let got = got.get_encoded(); + assert_eq!(got.as_slice(), want.as_ref()); + } + + prio3.unshard(&(), aggregate_shares, 1).unwrap() +} + +#[test] +fn test_vec_prio3_count() { + for test_vector_str in [ + include_str!("test_vec/07/Prio3Count_0.json"), + include_str!("test_vec/07/Prio3Count_1.json"), + ] { + let t: TPrio3<u64> = serde_json::from_str(test_vector_str).unwrap(); + let prio3 = Prio3::new_count(t.shares).unwrap(); + + let aggregate_result = check_aggregate_test_vec(&prio3, &t); + assert_eq!(aggregate_result, t.agg_result.as_u64().unwrap()); + } +} + +#[test] +fn test_vec_prio3_sum() { + for test_vector_str in [ + include_str!("test_vec/07/Prio3Sum_0.json"), + include_str!("test_vec/07/Prio3Sum_1.json"), + ] { + let t: TPrio3<u128> = serde_json::from_str(test_vector_str).unwrap(); + let bits = t.other_params["bits"].as_u64().unwrap() as usize; + let prio3 = Prio3::new_sum(t.shares, bits).unwrap(); + + let aggregate_result = check_aggregate_test_vec(&prio3, &t); + assert_eq!(aggregate_result, t.agg_result.as_u64().unwrap() as u128); + } +} + +#[test] +fn test_vec_prio3_sum_vec() { + for test_vector_str in [ + include_str!("test_vec/07/Prio3SumVec_0.json"), + include_str!("test_vec/07/Prio3SumVec_1.json"), + ] { + let t: TPrio3<Vec<u128>> = serde_json::from_str(test_vector_str).unwrap(); + let bits = t.other_params["bits"].as_u64().unwrap() as usize; + let length = t.other_params["length"].as_u64().unwrap() as usize; + let chunk_length = t.other_params["chunk_length"].as_u64().unwrap() as usize; + let prio3 = Prio3::new_sum_vec(t.shares, bits, length, chunk_length).unwrap(); + + let aggregate_result = check_aggregate_test_vec(&prio3, &t); + let expected_aggregate_result = t + .agg_result + .as_array() + .unwrap() + .iter() + .map(|val| val.as_u64().unwrap() as u128) + .collect::<Vec<u128>>(); + assert_eq!(aggregate_result, expected_aggregate_result); + } +} + +#[test] +fn test_vec_prio3_histogram() { + for test_vector_str in [ + include_str!("test_vec/07/Prio3Histogram_0.json"), + include_str!("test_vec/07/Prio3Histogram_1.json"), + ] { + let t: TPrio3<usize> = serde_json::from_str(test_vector_str).unwrap(); + let length = t.other_params["length"].as_u64().unwrap() as usize; + let chunk_length = t.other_params["chunk_length"].as_u64().unwrap() as usize; + let prio3 = Prio3::new_histogram(t.shares, length, chunk_length).unwrap(); + + let aggregate_result = check_aggregate_test_vec(&prio3, &t); + let expected_aggregate_result = t + .agg_result + .as_array() + .unwrap() + .iter() + .map(|val| val.as_u64().unwrap() as u128) + .collect::<Vec<u128>>(); + assert_eq!(aggregate_result, expected_aggregate_result); + } +} diff --git a/third_party/rust/prio/src/vdaf/test_vec/07/IdpfPoplar_0.json b/third_party/rust/prio/src/vdaf/test_vec/07/IdpfPoplar_0.json new file mode 100644 index 0000000000..2ff7aa7ffd --- /dev/null +++ b/third_party/rust/prio/src/vdaf/test_vec/07/IdpfPoplar_0.json @@ -0,0 +1,52 @@ +{ + "alpha": "0", + "beta_inner": [ + [ + "0", + "0" + ], + [ + "1", + "1" + ], + [ + "2", + "2" + ], + [ + "3", + "3" + ], + [ + "4", + "4" + ], + [ + "5", + "5" + ], + [ + "6", + "6" + ], + [ + "7", + "7" + ], + [ + "8", + "8" + ] + ], + "beta_leaf": [ + "9", + "9" + ], + "binder": "736f6d65206e6f6e6365", + "bits": 10, + "keys": [ + "000102030405060708090a0b0c0d0e0f", + "101112131415161718191a1b1c1d1e1f" + ], + "public_share": "921909356f44964d29c537aeeaeba92e573e4298c88dcc35bd3ae6acb4367236226b1af3151d5814f308f04e208fde2110c72523338563bc1c5fb47d22b5c34ae102e1e82fa250c7e23b95e985f91d7d91887fa7fb301ec20a06b1d4408d9a594754dcd86ec00c91f40f17c1ff52ed99fcd59965fe243a6cec7e672fefc5e3a29e653d5dcca8917e8af2c4f19d122c6dd30a3e2a80fb809383ced9d24fcd86516025174f5183fddfc6d74dde3b78834391c785defc8e4fbff92214df4c8322ee433a8eaeed7369419e0d6037a536e081df333aaab9e8e4d207d846961f015d96d57e3b59e24927773d6e0d66108955c1da134baab4eacd363c8e452b8c3845d5fb5c0ff6c27d7423a73d32742ccc3c750a17cd1f6026dd98a2cf6d2bff2dd339017b25af23d6db00ae8975e3f7e6aaef4af71f3e8cd14eb5c4373db9c3a76fc04659b761e650a97cb873df894064ecb2043a4317ef237ffe8f130eb5c2ca2a132c16f14943cd7e462568c8544b82e29329eb2a" +} diff --git a/third_party/rust/prio/src/vdaf/test_vec/07/Poplar1_0.json b/third_party/rust/prio/src/vdaf/test_vec/07/Poplar1_0.json new file mode 100644 index 0000000000..79fadca3df --- /dev/null +++ b/third_party/rust/prio/src/vdaf/test_vec/07/Poplar1_0.json @@ -0,0 +1,56 @@ +{ + "agg_param": [ + 0, + [ + 0, + 1 + ] + ], + "agg_result": [ + 0, + 1 + ], + "agg_shares": [ + "70f1cb8dc03c9eea88d270d6211a8667", + "910e34723ec361157a2d8f29dde57998" + ], + "bits": 4, + "prep": [ + { + "input_shares": [ + "000102030405060708090a0b0c0d0e0f202122232425262728292a2b2c2d2e2f311e448ab125690bc3a084a34301982c6aa325ab3f268338c9f4db9bda518743ee75c6c8ef7655d2d167d5385213d3bd1be920fad83c1b35fd9239efb406370db4d0a8e97eb41413957e264ded24e074ca433b6b5d451d0f65ec1d4ac246a36da12ecb2a537a449aeeb9d70bd064e930", + "101112131415161718191a1b1c1d1e1f303132333435363738393a3b3c3d3e3f96998a1415c9876c2887f2dcc52f090ec64bcaeb3165e98f47d261a0bed49a156c054b063cad6277a526505ac31807288abfb796b6cee614e5c41ab75c2b9912cc246c66c5e248a7cfc30ad92eefec7e67c2da39726d5b7277eb1449c779f834b0ab75f383f07bd2f3747cb7b98f617c" + ], + "measurement": 13, + "nonce": "000102030405060708090a0b0c0d0e0f", + "out_shares": [ + [ + "70f1cb8dc03c9eea", + "88d270d6211a8667" + ], + [ + "910e34723ec36115", + "7a2d8f29dde57998" + ] + ], + "prep_messages": [ + "d4cd54eb29f676c2d10fab848e6e85ebd51804e3562cf23b", + "" + ], + "prep_shares": [ + [ + "bd68d28c9fff9a30f84122278759025501b83270bf27b41d", + "1765825e8af6db91d9cd885d07158396d460d17297043e1e" + ], + [ + "7c9659b7c681b4a4", + "8569a648387e4b5b" + ] + ], + "public_share": "b2c16aa5676c3188a74dff403c179dfb28c515d5a9892f38e5eb7be1c96bfdb0ebf761e6500e206a4ce363d09ab0a1d9b225e51798fd599f9dcd204058958c2d625646e5662534ff6650a9af834a248d46304f6d7a3b845f46c71433c833a86846147e264aaee1eb3e0bb19e53cd521e92ab9991265b731bfdb508fb164cd9d48d2c43953e7144a8b97e395bdd8aa2db7a1088f3bf8d245e15172e88764bba8271f6a19f70dc47a279e899394ea8658958", + "rand": "000102030405060708090a0b0c0d0e0f101112131415161718191a1b1c1d1e1f202122232425262728292a2b2c2d2e2f303132333435363738393a3b3c3d3e3f404142434445464748494a4b4c4d4e4f" + } + ], + "shares": 2, + "verify_key": "000102030405060708090a0b0c0d0e0f" +} diff --git a/third_party/rust/prio/src/vdaf/test_vec/07/Poplar1_1.json b/third_party/rust/prio/src/vdaf/test_vec/07/Poplar1_1.json new file mode 100644 index 0000000000..a566fe8b4d --- /dev/null +++ b/third_party/rust/prio/src/vdaf/test_vec/07/Poplar1_1.json @@ -0,0 +1,64 @@ +{ + "agg_param": [ + 1, + [ + 0, + 1, + 2, + 3 + ] + ], + "agg_result": [ + 0, + 0, + 0, + 1 + ], + "agg_shares": [ + "d83fbcbf13566502f5849058b8b089e568a4e8aab8565425f69a56f809fc4527", + "29c04340eba99afd0c7b6fa7464f761a995b175546a9abda0c65a907f503bad8" + ], + "bits": 4, + "prep": [ + { + "input_shares": [ + "000102030405060708090a0b0c0d0e0f202122232425262728292a2b2c2d2e2f311e448ab125690bc3a084a34301982c6aa325ab3f268338c9f4db9bda518743ee75c6c8ef7655d2d167d5385213d3bd1be920fad83c1b35fd9239efb406370db4d0a8e97eb41413957e264ded24e074ca433b6b5d451d0f65ec1d4ac246a36da12ecb2a537a449aeeb9d70bd064e930", + "101112131415161718191a1b1c1d1e1f303132333435363738393a3b3c3d3e3f96998a1415c9876c2887f2dcc52f090ec64bcaeb3165e98f47d261a0bed49a156c054b063cad6277a526505ac31807288abfb796b6cee614e5c41ab75c2b9912cc246c66c5e248a7cfc30ad92eefec7e67c2da39726d5b7277eb1449c779f834b0ab75f383f07bd2f3747cb7b98f617c" + ], + "measurement": 13, + "nonce": "000102030405060708090a0b0c0d0e0f", + "out_shares": [ + [ + "d83fbcbf13566502", + "f5849058b8b089e5", + "68a4e8aab8565425", + "f69a56f809fc4527" + ], + [ + "29c04340eba99afd", + "0c7b6fa7464f761a", + "995b175546a9abda", + "0c65a907f503bad8" + ] + ], + "prep_messages": [ + "d45c0eabcc906acfb8239f3d0ef2b69a0f465979b04e355c", + "" + ], + "prep_shares": [ + [ + "5d1b91841835491251436306076eaaa674d4b95b84b2a084", + "77417d26b45b21bd68e03b3706840cf49c719f1d2b9c94d7" + ], + [ + "6e703a28b5960604", + "938fc5d74969f9fb" + ] + ], + "public_share": "b2c16aa5676c3188a74dff403c179dfb28c515d5a9892f38e5eb7be1c96bfdb0ebf761e6500e206a4ce363d09ab0a1d9b225e51798fd599f9dcd204058958c2d625646e5662534ff6650a9af834a248d46304f6d7a3b845f46c71433c833a86846147e264aaee1eb3e0bb19e53cd521e92ab9991265b731bfdb508fb164cd9d48d2c43953e7144a8b97e395bdd8aa2db7a1088f3bf8d245e15172e88764bba8271f6a19f70dc47a279e899394ea8658958", + "rand": "000102030405060708090a0b0c0d0e0f101112131415161718191a1b1c1d1e1f202122232425262728292a2b2c2d2e2f303132333435363738393a3b3c3d3e3f404142434445464748494a4b4c4d4e4f" + } + ], + "shares": 2, + "verify_key": "000102030405060708090a0b0c0d0e0f" +} diff --git a/third_party/rust/prio/src/vdaf/test_vec/07/Poplar1_2.json b/third_party/rust/prio/src/vdaf/test_vec/07/Poplar1_2.json new file mode 100644 index 0000000000..8141bc942e --- /dev/null +++ b/third_party/rust/prio/src/vdaf/test_vec/07/Poplar1_2.json @@ -0,0 +1,64 @@ +{ + "agg_param": [ + 2, + [ + 0, + 2, + 4, + 6 + ] + ], + "agg_result": [ + 0, + 0, + 0, + 1 + ], + "agg_shares": [ + "7ea47022f22f6be9bce8e0ee2eb522bcbc2d246c17704beed7043426b646fe26", + "835b8fdd0cd0941645171f11d04add4345d2db93e78fb4112bfbcbd948b901d9" + ], + "bits": 4, + "prep": [ + { + "input_shares": [ + "000102030405060708090a0b0c0d0e0f202122232425262728292a2b2c2d2e2f311e448ab125690bc3a084a34301982c6aa325ab3f268338c9f4db9bda518743ee75c6c8ef7655d2d167d5385213d3bd1be920fad83c1b35fd9239efb406370db4d0a8e97eb41413957e264ded24e074ca433b6b5d451d0f65ec1d4ac246a36da12ecb2a537a449aeeb9d70bd064e930", + "101112131415161718191a1b1c1d1e1f303132333435363738393a3b3c3d3e3f96998a1415c9876c2887f2dcc52f090ec64bcaeb3165e98f47d261a0bed49a156c054b063cad6277a526505ac31807288abfb796b6cee614e5c41ab75c2b9912cc246c66c5e248a7cfc30ad92eefec7e67c2da39726d5b7277eb1449c779f834b0ab75f383f07bd2f3747cb7b98f617c" + ], + "measurement": 13, + "nonce": "000102030405060708090a0b0c0d0e0f", + "out_shares": [ + [ + "7ea47022f22f6be9", + "bce8e0ee2eb522bc", + "bc2d246c17704bee", + "d7043426b646fe26" + ], + [ + "835b8fdd0cd09416", + "45171f11d04add43", + "45d2db93e78fb411", + "2bfbcbd948b901d9" + ] + ], + "prep_messages": [ + "6fb240ce8b8a2a8ce62112240f676105e0398515599f04b4", + "" + ], + "prep_shares": [ + [ + "ca0f02c7c61655263bf76d954b8abd16eb6e5ce2b26911b2", + "a5a23e07c573d565ac2aa48ec2dca3eef5ca2833a635f301" + ], + [ + "f5171a3cc9d49422", + "0ce8e5c3352b6bdd" + ] + ], + "public_share": "b2c16aa5676c3188a74dff403c179dfb28c515d5a9892f38e5eb7be1c96bfdb0ebf761e6500e206a4ce363d09ab0a1d9b225e51798fd599f9dcd204058958c2d625646e5662534ff6650a9af834a248d46304f6d7a3b845f46c71433c833a86846147e264aaee1eb3e0bb19e53cd521e92ab9991265b731bfdb508fb164cd9d48d2c43953e7144a8b97e395bdd8aa2db7a1088f3bf8d245e15172e88764bba8271f6a19f70dc47a279e899394ea8658958", + "rand": "000102030405060708090a0b0c0d0e0f101112131415161718191a1b1c1d1e1f202122232425262728292a2b2c2d2e2f303132333435363738393a3b3c3d3e3f404142434445464748494a4b4c4d4e4f" + } + ], + "shares": 2, + "verify_key": "000102030405060708090a0b0c0d0e0f" +} diff --git a/third_party/rust/prio/src/vdaf/test_vec/07/Poplar1_3.json b/third_party/rust/prio/src/vdaf/test_vec/07/Poplar1_3.json new file mode 100644 index 0000000000..1741ec0ebc --- /dev/null +++ b/third_party/rust/prio/src/vdaf/test_vec/07/Poplar1_3.json @@ -0,0 +1,76 @@ +{ + "agg_param": [ + 3, + [ + 1, + 3, + 5, + 7, + 9, + 13, + 15 + ] + ], + "agg_result": [ + 0, + 0, + 0, + 0, + 0, + 1, + 0 + ], + "agg_shares": [ + "ec2be80f01fd1ded599b1a18d6ef112c400f421cca2c080d4ccc5cdd09562b3e556c1aaabe9dd47e8bc25979394c7bb5c61fd1db34b8dfdcc3eff4a5304fb7706b5462025bb400e644f2e0752f38098702491691494a2b498176ef41c4e6a962f716473c53087a3e80db0b9acb50cb15081b5ea4b50c48093f67a8c75875422dfd64ab2fa71fa3f3b55ec708ba4086672aff514d0cffe6f1c07f117c22af9b2c67b2a0c7ec1366ce474721174edb8b9eb33faef5f9c9d0c956e4407a86473120cfa46e8c634c1bc66c63a2009911f82c8426a45013e637aaba0e471b03f0a67a", + "01d417f0fe02e212a664e5e72910eed3bff0bde335d3f7f2b333a322f6a9d4419893e55541622b81743da686c6b3844a39e02e24cb4720233c100b5acfb0480f82ab9dfda44bff19bb0d1f8ad0c7f678fdb6e96eb6b5d4b67e8910be3b19561df6e8b8c3acf785c17f24f46534af34eaf7e4a15b4af3b7f6c0985738a78abd52f09a54d058e05c0c4aa138f745bf7998d500aeb2f300190e3f80ee83dd506453874d5f3813ec9931b8b8dee8b12474614cc0510a06362f36a91bbf8579b8ce5f1e5b91739cb3e439939c5dff66ee07d37bd95bafec19c85545f1b8e4fc0f5905" + ], + "bits": 4, + "prep": [ + { + "input_shares": [ + "000102030405060708090a0b0c0d0e0f202122232425262728292a2b2c2d2e2f311e448ab125690bc3a084a34301982c6aa325ab3f268338c9f4db9bda518743ee75c6c8ef7655d2d167d5385213d3bd1be920fad83c1b35fd9239efb406370db4d0a8e97eb41413957e264ded24e074ca433b6b5d451d0f65ec1d4ac246a36da12ecb2a537a449aeeb9d70bd064e930", + "101112131415161718191a1b1c1d1e1f303132333435363738393a3b3c3d3e3f96998a1415c9876c2887f2dcc52f090ec64bcaeb3165e98f47d261a0bed49a156c054b063cad6277a526505ac31807288abfb796b6cee614e5c41ab75c2b9912cc246c66c5e248a7cfc30ad92eefec7e67c2da39726d5b7277eb1449c779f834b0ab75f383f07bd2f3747cb7b98f617c" + ], + "measurement": 13, + "nonce": "000102030405060708090a0b0c0d0e0f", + "out_shares": [ + [ + "ec2be80f01fd1ded599b1a18d6ef112c400f421cca2c080d4ccc5cdd09562b3e", + "556c1aaabe9dd47e8bc25979394c7bb5c61fd1db34b8dfdcc3eff4a5304fb770", + "6b5462025bb400e644f2e0752f38098702491691494a2b498176ef41c4e6a962", + "f716473c53087a3e80db0b9acb50cb15081b5ea4b50c48093f67a8c75875422d", + "fd64ab2fa71fa3f3b55ec708ba4086672aff514d0cffe6f1c07f117c22af9b2c", + "67b2a0c7ec1366ce474721174edb8b9eb33faef5f9c9d0c956e4407a86473120", + "cfa46e8c634c1bc66c63a2009911f82c8426a45013e637aaba0e471b03f0a67a" + ], + [ + "01d417f0fe02e212a664e5e72910eed3bff0bde335d3f7f2b333a322f6a9d441", + "9893e55541622b81743da686c6b3844a39e02e24cb4720233c100b5acfb0480f", + "82ab9dfda44bff19bb0d1f8ad0c7f678fdb6e96eb6b5d4b67e8910be3b19561d", + "f6e8b8c3acf785c17f24f46534af34eaf7e4a15b4af3b7f6c0985738a78abd52", + "f09a54d058e05c0c4aa138f745bf7998d500aeb2f300190e3f80ee83dd506453", + "874d5f3813ec9931b8b8dee8b12474614cc0510a06362f36a91bbf8579b8ce5f", + "1e5b91739cb3e439939c5dff66ee07d37bd95bafec19c85545f1b8e4fc0f5905" + ] + ], + "prep_messages": [ + "4a2b97cf17e54b126a86c6791c50d6507ee8b74b3d9903bcf3881121bc6e0975c4efb2d8b8a132b8a6caa4eb39ac2bbb5bdc351604fa9e78d1a6f5a5f615bb0c8819f485d8b24a4e48da47d3b7458a9cfde1e85c66453319a3f6d43dc40a0135", + "" + ], + "prep_shares": [ + [ + "4e64e5ed76c69ef68d3e144918a719986e40ab82f34bd30298b0085a3265d16988b8f646731ef47cb2fb1598e4cb817747623f1cc70ee7843ce1a9d6e3cf5c456801c9a3ae0c7c7663349a3daaf8fb51d165085c751e5bdd4e800df9e1e0193e", + "fcc6b1e1a01ead1bdc47b23004a9bcb80fa80cc9494d30b95bd808c78909380b2937bc9145833e3bf4ce8e5355e0a943147af6f93cebb7f394c54bcf12465e470d182be229a6ced7e4a5ad950d4d8e4a2c7ce000f126d83b5476c744e229e776" + ], + [ + "003c39f76240f6f9bcc6065a247b4432a651d5d72a35aff45928eec28c8a9d07", + "edc3c6089dbf09064339f9a5db84bbcd59ae2a28d5ca500ba6d7113d73756278" + ] + ], + "public_share": "b2c16aa5676c3188a74dff403c179dfb28c515d5a9892f38e5eb7be1c96bfdb0ebf761e6500e206a4ce363d09ab0a1d9b225e51798fd599f9dcd204058958c2d625646e5662534ff6650a9af834a248d46304f6d7a3b845f46c71433c833a86846147e264aaee1eb3e0bb19e53cd521e92ab9991265b731bfdb508fb164cd9d48d2c43953e7144a8b97e395bdd8aa2db7a1088f3bf8d245e15172e88764bba8271f6a19f70dc47a279e899394ea8658958", + "rand": "000102030405060708090a0b0c0d0e0f101112131415161718191a1b1c1d1e1f202122232425262728292a2b2c2d2e2f303132333435363738393a3b3c3d3e3f404142434445464748494a4b4c4d4e4f" + } + ], + "shares": 2, + "verify_key": "000102030405060708090a0b0c0d0e0f" +} diff --git a/third_party/rust/prio/src/vdaf/test_vec/07/Prio3Count_0.json b/third_party/rust/prio/src/vdaf/test_vec/07/Prio3Count_0.json new file mode 100644 index 0000000000..c27ad93435 --- /dev/null +++ b/third_party/rust/prio/src/vdaf/test_vec/07/Prio3Count_0.json @@ -0,0 +1,39 @@ +{ + "agg_param": null, + "agg_result": 1, + "agg_shares": [ + "afead111dacc0c7e", + "53152eee2433f381" + ], + "prep": [ + { + "input_shares": [ + "afead111dacc0c7ec08c411babd6e2404df512ddfa0a81736b7607f4ccb3f39e414fdb4bc89a63569702c92aed6a6a96", + "000102030405060708090a0b0c0d0e0f101112131415161718191a1b1c1d1e1f" + ], + "measurement": 1, + "nonce": "000102030405060708090a0b0c0d0e0f", + "out_shares": [ + [ + "afead111dacc0c7e" + ], + [ + "53152eee2433f381" + ] + ], + "prep_messages": [ + "" + ], + "prep_shares": [ + [ + "123f23c117b7ed6099be9e6a31a42a9caa60882a3b4aa50303f8b588c9efe60b", + "efc0dc3ee748129f2da661f47a625a57d64a5b62ab38647c34bb161c7576d721" + ] + ], + "public_share": "", + "rand": "000102030405060708090a0b0c0d0e0f101112131415161718191a1b1c1d1e1f202122232425262728292a2b2c2d2e2f" + } + ], + "shares": 2, + "verify_key": "000102030405060708090a0b0c0d0e0f" +} diff --git a/third_party/rust/prio/src/vdaf/test_vec/07/Prio3Count_1.json b/third_party/rust/prio/src/vdaf/test_vec/07/Prio3Count_1.json new file mode 100644 index 0000000000..148fe6df58 --- /dev/null +++ b/third_party/rust/prio/src/vdaf/test_vec/07/Prio3Count_1.json @@ -0,0 +1,45 @@ +{ + "agg_param": null, + "agg_result": 1, + "agg_shares": [ + "c5647e016eea69f6", + "53152eee2433f381", + "eb8553106be2a287" + ], + "prep": [ + { + "input_shares": [ + "c5647e016eea69f6d10e90d05e2ad8b402b8580f394a719b371ae8f1a364b280d08ca7177946a1a0b9643e2469b0a2e9", + "000102030405060708090a0b0c0d0e0f101112131415161718191a1b1c1d1e1f", + "202122232425262728292a2b2c2d2e2f303132333435363738393a3b3c3d3e3f" + ], + "measurement": 1, + "nonce": "000102030405060708090a0b0c0d0e0f", + "out_shares": [ + [ + "c5647e016eea69f6" + ], + [ + "53152eee2433f381" + ], + [ + "eb8553106be2a287" + ] + ], + "prep_messages": [ + "" + ], + "prep_shares": [ + [ + "5c8d00fd24e449d375581d6adbeaf9cf4bdface6d368fd7b1562e5bf47b9fa68", + "efc0dc3ee748129f2da661f47a625a57d64a5b62ab38647c34bb161c7576d721", + "b7b122c4f1d2a38df764c623c266f02f7b5178c3d64735ec06037585d643f528" + ] + ], + "public_share": "", + "rand": "000102030405060708090a0b0c0d0e0f101112131415161718191a1b1c1d1e1f202122232425262728292a2b2c2d2e2f303132333435363738393a3b3c3d3e3f404142434445464748494a4b4c4d4e4f" + } + ], + "shares": 3, + "verify_key": "000102030405060708090a0b0c0d0e0f" +} diff --git a/third_party/rust/prio/src/vdaf/test_vec/07/Prio3Histogram_0.json b/third_party/rust/prio/src/vdaf/test_vec/07/Prio3Histogram_0.json new file mode 100644 index 0000000000..099f786669 --- /dev/null +++ b/third_party/rust/prio/src/vdaf/test_vec/07/Prio3Histogram_0.json @@ -0,0 +1,52 @@ +{ + "agg_param": null, + "agg_result": [ + 0, + 0, + 1, + 0 + ], + "agg_shares": [ + "14be9c4ef7a6e12e963fdeac21cebdd4d36e13f4bc25306322e56303c62c90afd73f6b4aa9fdf33cb0afb55426d645ff8cd7e78cebf9d4f1087f6d4a033c8eae", + "ed4163b108591ed14dc02153de31422b2e91ec0b43dacf9cc11a9cfc39d36f502bc094b556020cc333504aabd929ba007528187314062b0edb8092b5fcc37151" + ], + "chunk_length": 2, + "length": 4, + "prep": [ + { + "input_shares": [ + "14be9c4ef7a6e12e963fdeac21cebdd4d36e13f4bc25306322e56303c62c90afd73f6b4aa9fdf33cb0afb55426d645ff8cd7e78cebf9d4f1087f6d4a033c8eaeec786d3b212d968c939de66318dbacafe73c1f5aa3e9078ba2f63ec5179e6b4694612c36f5d4d539d46dab1ac20e43963978d9dd36f19f31c83e58c903c2cd94215c68b15f5d6071e9e19fa973829dc71b536351b0db1072e77b7570e3e06c65fac248d21dd970f29640050e901d06775f05a897850cab5707ac25543ed6ce7061b9cd70c783e0483727236d0cbb05dafefd78ec4e6419efe93d6f82cdadbfd4e860661238040229f60205bbba983790303132333435363738393a3b3c3d3e3f", + "000102030405060708090a0b0c0d0e0f101112131415161718191a1b1c1d1e1f202122232425262728292a2b2c2d2e2f" + ], + "measurement": 2, + "nonce": "000102030405060708090a0b0c0d0e0f", + "out_shares": [ + [ + "14be9c4ef7a6e12e963fdeac21cebdd4", + "d36e13f4bc25306322e56303c62c90af", + "d73f6b4aa9fdf33cb0afb55426d645ff", + "8cd7e78cebf9d4f1087f6d4a033c8eae" + ], + [ + "ed4163b108591ed14dc02153de31422b", + "2e91ec0b43dacf9cc11a9cfc39d36f50", + "2bc094b556020cc333504aabd929ba00", + "7528187314062b0edb8092b5fcc37151" + ] + ], + "prep_messages": [ + "7556ccbddbd14d509ee89124d31d1feb" + ], + "prep_shares": [ + [ + "806b1f8537500ce0b4b501b0ae5ed8f82679ba11ad995d6f605b000e32c41afb6d0070287fe7b99b8304d264cba1e3c6f4456e1c06f3b9d3d4947b2041c86b020d26c74d7663817e6a91960489806931b304fcd3755b43b96c806d2bbeb0166bbec7c61c35f886f3f539890522388f43", + "8194e07ac8aff31f2f4afe4f51a12707b692a56a1745315a1022b4eb257b2a8725c610416af7b0d1a296f409cdb3fbf4f4c0d488206d794254e4755fd124cdc9a67364ddc7865afe3554de5f52f1ac910f3f8e110cfbad4113861316dc73ec60de4f6c512adaa41de631eda8d6d8c189" + ] + ], + "public_share": "bec7c61c35f886f3f539890522388f43de4f6c512adaa41de631eda8d6d8c189", + "rand": "000102030405060708090a0b0c0d0e0f101112131415161718191a1b1c1d1e1f202122232425262728292a2b2c2d2e2f303132333435363738393a3b3c3d3e3f404142434445464748494a4b4c4d4e4f" + } + ], + "shares": 2, + "verify_key": "000102030405060708090a0b0c0d0e0f" +} diff --git a/third_party/rust/prio/src/vdaf/test_vec/07/Prio3Histogram_1.json b/third_party/rust/prio/src/vdaf/test_vec/07/Prio3Histogram_1.json new file mode 100644 index 0000000000..0b9a9b4d5d --- /dev/null +++ b/third_party/rust/prio/src/vdaf/test_vec/07/Prio3Histogram_1.json @@ -0,0 +1,89 @@ +{ + "agg_param": null, + "agg_result": [ + 0, + 0, + 1, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0 + ], + "agg_shares": [ + "a6d2b4756e63a10bf5f650e258c73b0ccb20bace98e225dea29d625d527fdd4ded86beb9a0a0ac5c4216f5add7a297cece34b479a568a327f1e259839f813df97b34de254be5b9b9c8d9e56dbff50b7a6bf1e5967686755a1dc42e0ab170add8c88f8ca68f945e768a5007c775fd27cfecb4495e257a2f2f94ca48830aa16ec0decaeee645e295c5dc2ebe491aae1a7f17b2807fcb33ee08127db466067bf84ec613dac9c93adbe73dd262c1859b2865", + "ed4163b108591ed14dc02153de31422b2e91ec0b43dacf9cc11a9cfc39d36f502bc094b556020cc333504aabd929ba007528187314062b0edb8092b5fcc3715126e16ce274ad58caaa14d22608269a4c41a256d3c9e847c0a6ac1a4fbaf6309e9ccbe74a9442ca956d843d6bd5adf9797a84557597d9cc81ddfa281ae5048d686bdb289ec2f3c96cdfa79b6974e6d15aec047748636d4358226283e11a78e045f59db2dda566162a56c85936ac0f4696", + "6eebe7d888434023a1488dcac80682c8084e592524430a857f4701a673adb261eab8ac90085d47e06d99c0a64e33ae30bfa23313469131cafb9b13c763ba50b560eab4f73f6ded7b7011486b38e45939566cc395bf9042e5038fb6a6949821899ea48b0edc28d7f3cf2abbcdb454deb69cc6602c43ac034f563a8e62105a04d7b859e87af729a0cd2729a64c716b1326fe480838d15ece9eaf20c8b7de0c276b464e7358905e0eee4f654308ce549104" + ], + "chunk_length": 3, + "length": 11, + "prep": [ + { + "input_shares": [ + "a6d2b4756e63a10bf5f650e258c73b0ccb20bace98e225dea29d625d527fdd4ded86beb9a0a0ac5c4216f5add7a297cece34b479a568a327f1e259839f813df97b34de254be5b9b9c8d9e56dbff50b7a6bf1e5967686755a1dc42e0ab170add8c88f8ca68f945e768a5007c775fd27cfecb4495e257a2f2f94ca48830aa16ec0decaeee645e295c5dc2ebe491aae1a7f17b2807fcb33ee08127db466067bf84ec613dac9c93adbe73dd262c1859b2865508d344dda6c4339e650c401324c31481780ef7e7dcc07120ac004c05ab75ee5d22e2d0eb229dcdd3755fab49a1c2916e17c8ed2d975cfe76d576569bf05233c07f94417fccaf73d1cc33e17dae74650badffdd639a9b9f9e89de4b9fd13e258b90fbb2b3817b607dc14e6e5327746ca20d1f1918bce9714b135ffe01eb4e6aefab92b0462f7e676e26007e8c2e5a66e16f32f7c8457a6dfba39d9082f640006d560b4d64e86e2e2358c84e03b857c980f51b1a78b53f7cb44343ed184d8dc87ebf8698609eeefae5d8882224ebd28b9531015badea8ae9fe01c7495cafecdc4f13389ea4eb0bbce0a5ab85aa6fc06aabd96d28c84ecf039bfeb4c350049485f8a4c706a109164ff4c640edaedd0ad50820b1d1ed7ab08fc69c48b39aff1eebc02ef1ea40bd70784bfa50511c3dd64b107f4297842280c3cff8d94be202a0e2cb0090f3adb2189f445fcf291f452f162606162636465666768696a6b6c6d6e6f", + "000102030405060708090a0b0c0d0e0f101112131415161718191a1b1c1d1e1f202122232425262728292a2b2c2d2e2f", + "303132333435363738393a3b3c3d3e3f404142434445464748494a4b4c4d4e4f505152535455565758595a5b5c5d5e5f" + ], + "measurement": 2, + "nonce": "000102030405060708090a0b0c0d0e0f", + "out_shares": [ + [ + "a6d2b4756e63a10bf5f650e258c73b0c", + "cb20bace98e225dea29d625d527fdd4d", + "ed86beb9a0a0ac5c4216f5add7a297ce", + "ce34b479a568a327f1e259839f813df9", + "7b34de254be5b9b9c8d9e56dbff50b7a", + "6bf1e5967686755a1dc42e0ab170add8", + "c88f8ca68f945e768a5007c775fd27cf", + "ecb4495e257a2f2f94ca48830aa16ec0", + "decaeee645e295c5dc2ebe491aae1a7f", + "17b2807fcb33ee08127db466067bf84e", + "c613dac9c93adbe73dd262c1859b2865" + ], + [ + "ed4163b108591ed14dc02153de31422b", + "2e91ec0b43dacf9cc11a9cfc39d36f50", + "2bc094b556020cc333504aabd929ba00", + "7528187314062b0edb8092b5fcc37151", + "26e16ce274ad58caaa14d22608269a4c", + "41a256d3c9e847c0a6ac1a4fbaf6309e", + "9ccbe74a9442ca956d843d6bd5adf979", + "7a84557597d9cc81ddfa281ae5048d68", + "6bdb289ec2f3c96cdfa79b6974e6d15a", + "ec047748636d4358226283e11a78e045", + "f59db2dda566162a56c85936ac0f4696" + ], + [ + "6eebe7d888434023a1488dcac80682c8", + "084e592524430a857f4701a673adb261", + "eab8ac90085d47e06d99c0a64e33ae30", + "bfa23313469131cafb9b13c763ba50b5", + "60eab4f73f6ded7b7011486b38e45939", + "566cc395bf9042e5038fb6a694982189", + "9ea48b0edc28d7f3cf2abbcdb454deb6", + "9cc6602c43ac034f563a8e62105a04d7", + "b859e87af729a0cd2729a64c716b1326", + "fe480838d15ece9eaf20c8b7de0c276b", + "464e7358905e0eee4f654308ce549104" + ] + ], + "prep_messages": [ + "4b7dc5c1b2a08aec5dcfc13de800559b" + ], + "prep_shares": [ + [ + "e80c098526d9321dd0801f97a648722016fa117f10cb2b062fc5fb1e55705894007f838333ef348c6306e141369bd88d123c66d2faeb132e330a73882c38765d425847bd86e5f784b3348ee4840c5df103b49f04c4dcca4667abb956187da58c91c946d9d5fdf496d95428f8a625dddfc8b7bb469397ebd4b177f902896febdaac39a8d9ec0aa1a24132036c2430929c", + "4f6d4137ddffd58b243b8f845a1684550b240ea3e91a68335f717b83056e9b45c5e62d7a24da54147fcb9260d023cb7f9c8d036f0100f5fea0ce22f49e3d7672bc83fd5c724f2684f3442e8c5291c41509151808d1da447cddc3fe11cf5cd8d7fe662cf035eff88b583f6b32499b332aa6dee37947ef482e15fcb3a7f04b20813162d162b9bf30eee4953b6fdabd10f1", + "ca85b543fc26f756ef4351e4fea0098a73787eeab8b613029af310882b7e87d28fc5502fd76bd704626d9f0f662e531feaf1fa912cc209de6541401d8508a4788d92e549f58241334cfa29abc1fb80a28ae61d7a4060d9582e3e20f182e4519ab1bb9547c545aafc21416e779856c80cf3690155119111aebf3800757989229e4966453f7aa269163b272848de80227f" + ] + ], + "public_share": "ac39a8d9ec0aa1a24132036c2430929c3162d162b9bf30eee4953b6fdabd10f14966453f7aa269163b272848de80227f", + "rand": "000102030405060708090a0b0c0d0e0f101112131415161718191a1b1c1d1e1f202122232425262728292a2b2c2d2e2f303132333435363738393a3b3c3d3e3f404142434445464748494a4b4c4d4e4f505152535455565758595a5b5c5d5e5f606162636465666768696a6b6c6d6e6f707172737475767778797a7b7c7d7e7f" + } + ], + "shares": 3, + "verify_key": "000102030405060708090a0b0c0d0e0f" +} diff --git a/third_party/rust/prio/src/vdaf/test_vec/07/Prio3SumVec_0.json b/third_party/rust/prio/src/vdaf/test_vec/07/Prio3SumVec_0.json new file mode 100644 index 0000000000..a7178fcfee --- /dev/null +++ b/third_party/rust/prio/src/vdaf/test_vec/07/Prio3SumVec_0.json @@ -0,0 +1,194 @@ +{ + "agg_param": null, + "agg_result": [ + 256, + 257, + 258, + 259, + 260, + 261, + 262, + 263, + 264, + 265 + ], + "agg_shares": [ + "cdeb52d4615d1c718ef21a6560939efcb5024c89ea9b0f0018302087c0e978b5b5c84c8c4f217b14584cc3939963f56a2718c17af81d1f85129347fc73b2548262e047a67c3e0583eeef7c4a98a5da4ccdb558491081a00a8de5a46b8c2f299bc69162a7411b3ecf2b47670db337083dc50b8fc6d5c5d21da52fc7f538166b0e4564edd8fbb75bc8c4fdd0c02be2b6d11bc4a159297b72e2c635a9250feb2445", + "3415ad2b9ea2e38e550de59a9f6c61034dfeb3761564f0ffcbcfdf783f16874a4e38b373b0de84eb8bb33c6c669c0a95dde83e8507e2e07ad16cb8038c4dab7da320b85983c1fa7cf50f83b5675a25b3394ba7b6ef7e5ff5561a5b9473d0d664416f9d58bee4c130b8b898f24cc8f7c243f570392a3a2de23ed0380ac7e994f1c49c12270448a4371f022f3fd41d492eef3c5ea6d6848d1d1dca56daf014dbba" + ], + "bits": 8, + "chunk_length": 9, + "length": 10, + "prep": [ + { + "input_shares": [ + "2451f59efba2edc493c9246ff1f0e0a7f8f6f22ee46e662c899e485d7ce288d6becdfee804a39618972fbaa595eeec25423e412cbe51d44a62747de627630302368ec3535a2545a2799e8a0b9a144c811158dda278865d834b34fbe77ad11dbb9fdcf0637c24e10d5ab36d03cdc5f6b95e400a0a81608d96c25733c376376de3927c3570e8ab671358a1686d0c44ac938368d5621cff66585454ef124daa5f18efd7e791a4bcb11caf74b378e2c4feff3e5bad16e7c3fab987eb4d4a0c675bb4f4e70e1373fb00a5dd30a1118355c20e2e4c3700be3d3c1cf25d3e4a729836ba564aa074f99be0d23d4cc0dc9f263c986988e0d16a3d28c262d34f220b1ed127cddea3e2a1bd075c653d4b6f1c3d35e25d2804e7960250dea42dc4a52c9545bedc182ee8391b4c6849366af8e15f30bd06872e5ed651ef7db0b0c442886de32eeeeacc5f2dfe87f9375b4774153fc9e442105b5f8e452e80874c84131400d4d588a1a5d94bac9e68dbf917ef6405b0bc13fa89daf46f84405aedb166ac93f6545256b1da6ac65e01d580bb26eef82c34b9d728fc0c96ff898ed46bd289abbec9917397552ebf6d1eb3f916f69ee9f80e9466512bff70af2d8f3a9ed599f24e33550a09304e1b4f51948e2d8cbf5a1bb14455b1786ae3af4670111bc3983293ad9ae029128efd86d0a05cb3f442b43f466cec5cc9c4989bf5a29eb5c2401bc8bba0d5b7487bc0bf010c968fe76e3a9924459dce6704528d56540081240ed0d2f301a8c9baca5c183b1b5c3a9c03dce5036926d06e1470c2e63d15fdc3a61056154fca9439c595098ff3794c7d7e62af5e3139b43e22a0f8864c254a069a083604762d77ea000177a7b908efe27f6e00db7ea25f573734af803044fb2dc333ed9ad589d7677e23614143fa6836d68ed311dcedf0ea031688b2cf8c5248f21be444f1c61b050a0ab7dca04992673afb737bd27526a72dda3b03dad4d3b0bb81f0887ee6f25ec4cf35d58ea5f085e97609cfb6a8e97d84fdf8755b8e81ff29614bf1b03bcd2b8d9ab06dc4d60785f83eb6ee4573859223214ecdad734d114e15e1971a8b82222910fd041a1123a4e792a9239f99252de3e3e8d5bb209e2c9bda506a79853c482546940364a8246392fb5e18e85847458445fe3a970b29db6d3d0e4a806cfec7c8538f24896d2d10669113f2b724161d2007ee75c0b651f4934046142b04b2015212997c609625bbeb81b9fa0249c167196557c08ae9ea2defcf7859eed4d35b2183c628cb82fe01255e558e7c8b13bded6ba43ed8b92f6ba7879b39932468260c5768ca0909aae899ad1252c5dbcd741d971f179bc36e88a0a10981f73202cb25db324da405fdd5ca5331431afe362c5f933b3c1216c3e19140cd27f7c2ef67898887856a46a518a3afec78ee0d9dce778289a38d2df906932c40019afadab12fe7d0695316e5a3c1e38aa630a44bc8cc01a5a8cae060b7de435e54963b9354182d64e340ec9dc3e37f8b2bbaaab23608b86827991df4367839f443c160c1eb77f41159f69592c3eb37c21a521afcd34036a13a145e9cb1039704b8e523359ea5c3a50f705118ea7d8b1063eb85bfcdc941e0235579e97856ee6f6bfe9c4d0d161b5662de26a2fbddc530ce918a98514903c63a3476d6ebe68e2503e6bf255691fbd8a006e9c77f5a4ad9e3e8d21a56bc4f7bc90d61ebb31eaa4dce48eb9a8069a584ae35266a4bc4af970860d2e9a0df7b87e8fc8b597e73a85d8eeb91def6057d7a77e8f859ee9ee07ef2fb2260660e59ee16458eafbd7bab979ed9bf72c1c27cadc9011f563aeed9a4b2f09ce5455857bcf3acbe0e1cf15537594469c777a885ad24ac5a8c894a8257c5212fb46184ad7f280bc25600129b25bf941460fcdd2e45a0216f1f2fec84d4792a15f877d15c649991ef998621a50a04251257ed6ccd803fdbc83ce4b5c4d7e8ee4487848768384e0b3970ec899dfb423560e755c4716deb3e188ca780cc7a97e8fb80076e9c44c6b7575725253ae4605844c3748bf90c14f17ee80a42fea450c05eb1f251eb09855f2ab368047d68cd7f4b8898d0113d52117375eed77707e7abff8554dc7da30cca674ac587198c0f165a47a5db79e9cc1c7bda9cdc36b94f14141a307850062f4d7208b8c612691699e3e3c16c3fc2dc6604e0fcb27acd6b5d326d2663a3f819589176c70423bc4d37d8c2a17e97468cf923a5388eb0d1de61358e9651a7e76b033d32d6c84e7ed5831a990b46e8228b6ef120643049645b82e100a7ed6ddd2ebfe2dcbd8b0e7ac1e5ee021d4279f164acc47875ade2c0acff5dbbf3a6eb0e8601632c926780be1660270420aa02c99fb39af1852b09904791e90cfa1f02aec0ab2de111524394819527819e52d495196ab3aff1e323dfec07af91e18b9a04e37552a23b13177bdcfc64ec7108e5e9b3679ccdf6b1e998e2bbcd5fbbbebf5ad8008e727cae6499cc06aa03809947e298683a4340f51d6eecad38d0a7a5437dd6e72bce6543b81fd3a438d71e232845cdb403f1011295f9aee5e33352b86e92343985884284c9646da13545f37b9d6da7d0cf902c19a5ca1f4f1818a2c2644807fcc54be35c29f96fb4fea5efdc88b270f1c5504bd8ba558834786020cc2f03ab5c56eaea38532b9faf6208f57d970b2e5ff92872713c9e0ad07b26e72dca6f9a9c02bad6c9db4d1d738f306292f14415d2856c2b073c5d8faf89e9713ceb375b6eefabc240bf6c6bf39cafb99993767dbaf5ee5f4b3f93e638e904fb55f443312c145b809fd203b5b3a16bd229b952e100bbfc0e49bbd05d54c3e5fa1a44fe55de16cfa52f3b169e0bfe95b1b8b6367f9309adfe3df079104fd720d46d772def3c0534d73615071fa22a79af875f796478d2f599dbb4c1ed303132333435363738393a3b3c3d3e3f", + "000102030405060708090a0b0c0d0e0f101112131415161718191a1b1c1d1e1f202122232425262728292a2b2c2d2e2f" + ], + "measurement": [ + 0, + 1, + 2, + 3, + 4, + 5, + 6, + 7, + 8, + 9 + ], + "nonce": "000102030405060708090a0b0c0d0e0f", + "out_shares": [ + [ + "9aa31b9c201fb42526a6b32120318aa9", + "e8551983a3deafaafe0f608295f8d291", + "3eed6ed96f607eb1be6e96868876fc78", + "65b295d3525f0ad74886c2fe7b3b1cd6", + "794a6d37d41457d6f04fd418888cf36e", + "483cc86d052be058d0a1e12384ba0d89", + "9c85cb376b5ebfeffb6c22af3bbd02bf", + "9c0385979cecf00983ba97fc12b2235a", + "c7cbf9f2533dc942eca94540b9a0e745", + "0f418bc80d2926f6ec11e3615a4e0c17" + ], + [ + "675ce463dfe04bdabd594cdedfce7556", + "1aaae67c5c215055e5ef9f7d6a072d6e", + "c5129126909f814e2591697977890387", + "9f4d6a2cada0f5289b793d0184c4e329", + "8cb592c82beba829f3af2be777730c91", + "bec33792fad41fa7135e1edc7b45f276", + "6b7a34c894a14010e892dd50c442fd40", + "6cfc7a6863130ff660456803ed4ddca5", + "4234060dacc236bdf755babf465f18ba", + "fbbe7437f2d6d909f7ed1c9ea5b1f3e8" + ] + ], + "prep_messages": [ + "db085315822777376b4d0f962d8f06d9" + ], + "prep_shares": [ + [ + "e6a4fe7264f95c384446bd51db14f78e7f4133afed0604aeb3b87125fc076c7447d795723adfe93d85f9fe2993c52420e45694fd2ec164a54a7267ee5efc8cb40b6659ac81f2e850786218bcec469ec4f7bb28e875d75ee98d54d566186c61c35448a50cb11e195d886622861a78bbb74325b7972e7b4c47f0e2e10a15d7a33c3daecb2dfc507b1b6676c1e9bfc52a4873f408a5788e7b77ce6943e67f3f457280544d93b81b08e427f699ba54adcbb0ffab83366d9b336846c0c989f0bc25bdd14683f1a85e844b9dbac26daae84cc8d57ef6b0c340798ac5ade63150e8d7a9673b64d798a97cf2715f399fd371e342c1ad50e28431f54180ef63ad7dd21f3e5d8d67159cacfd56f5d99c39d53047c8d7bf11ad83a2e3e569e1393b12d87d01701fa71b50b51e092ca6b797bb97890efb6327f1c4e488663dca5f00675c2af7368a9ab95b3c4e9e1a8dd5430d336833", + "1b5b018d9b06a3c79fb942ae24eb087131693625114418115cb3e54a45056eabd0f6e371501957e00796db78ea8f3388eb8345e938f5fbdd8b24c3a968276d7c457ff43ce93631942f823f5bb9c6d1335b8022e804072711cb8fa5fb3afb209e696c9cac47da44cdc3eb3874eb0d8c89692408b463df12bb2d6e8193d5829cce221e486a579b91cd10a0fb38fec7214a9008d574ba32615f6215aef827a2962a31df892814bd8b8f828d029f07f6acf490e7ecd3377f10b86ec2d5741ba37a3a9522b897e840e315a614a89d0bcf8296bdd45e330eaf3f34b3ce4e1dd41306eae92147fc6676eedff2cb239581f46750df341390e066dabb01ef6362694f923d5a65dbc5252a8da8702a979aca3e211af7124485c1c7dc68f6fb1bdfb9a4d0d993cece17c818d8fb1151f1ef1b32745f09c155f42259af588b424c1dafe8d0e90a3dfc6f55bb428773ce15071cb720fb" + ] + ], + "public_share": "368a9ab95b3c4e9e1a8dd5430d3368330a3dfc6f55bb428773ce15071cb720fb", + "rand": "000102030405060708090a0b0c0d0e0f101112131415161718191a1b1c1d1e1f202122232425262728292a2b2c2d2e2f303132333435363738393a3b3c3d3e3f404142434445464748494a4b4c4d4e4f" + }, + { + "input_shares": [ + "2551f59efba2edc493c9246ff1f0e0a7f8f6f22ee46e662c899e485d7ce288d6becdfee804a39618972fbaa595eeec25423e412cbe51d44a62747de627630302368ec3535a2545a2799e8a0b9a144c811158dda278865d834b34fbe77ad11dbb9fdcf0637c24e10d5ab36d03cdc5f6b95e400a0a81608d96c25733c376376de3927c3570e8ab671358a1686d0c44ac938368d5621cff66585454ef124daa5f18efd7e791a4bcb11caf74b378e2c4feff3e5bad16e7c3fab987eb4d4a0c675bb4f4e70e1373fb00a5dd30a1118355c20e2e4c3700be3d3c1cf25d3e4a729836ba564aa074f99be0d23d4cc0dc9f263c986988e0d16a3d28c262d34f220b1ed127cedea3e2a1bd075c653d4b6f1c3d35e25c2804e7960250dea42dc4a52c9545bedc182ee8391b4c6849366af8e15f30bd06872e5ed651ef7db0b0c442886de32eeeeacc5f2dfe87f9375b4774153fc9e442105b5f8e452e80874c84131400d4d588a1a5d94bac9e68dbf917ef6405b0bc13fa89daf46f84405aedb166ac93f6545256b1da6ac65e01d580bb26eef82c34b8d728fc0c96ff898ed46bd289abbec9917397552ebf6d1eb3f916f69ee9f80e9466512bff70af2d8f3a9ed599f24e33550a09304e1b4f51948e2d8cbf5a1bb14455b1786ae3af4670111bc3983293ad9ae029128efd86d0a05cb3f442b43f466cec5cc9c4989bf5a29eb5c2401bc8bba1d5b7487bc0bf010c968fe76e3a9924459dce6704528d56540081240ed0d2f300a8c9baca5c183b1b5c3a9c03dce5036926d06e1470c2e63d15fdc3a61056154fca9439c595098ff3794c7d7e62af5e3139b43e22a0f8864c254a069a083604762d77ea000177a7b908efe27f6e00db7ea25f573734af803044fb2dc333ed9ad589d7677e23614143fa6836d68ed311dcedf0ea031688b2cf8c5248f21be444f0c61b050a0ab7dca04992673afb737bd27526a72dda3b03dad4d3b0bb81f0887ee6f25ec4cf35d58ea5f085e97609cfb6a8e97d84fdf8755b8e81ff29614bf1b03bcd2b8d9ab06dc4d60785f83eb6ee4573859223214ecdad734d114e15e1971b8b82222910fd041a1123a4e792a9239e99252de3e3e8d5bb209e2c9bda506a78853c482546940364a8246392fb5e18e85847458445fe3a970b29db6d3d0e4a806cfec7c8538f24896d2d10669113f2b724161d2007ee75c0b651f4934046142b04b2015212997c609625bbeb81b9fa0249c167196557c08ae9ea2defcf7859eed4d35b2183c628cb82fe01255e558e7b8b13bded6ba43ed8b92f6ba7879b39922468260c5768ca0909aae899ad1252c5dbcd741d971f179bc36e88a0a10981f73202cb25db324da405fdd5ca5331431afe362c5f933b3c1216c3e19140cd27f7c2ef67898887856a46a518a3afec78ee0d9dce778289a38d2df906932c40019bfadab12fe7d0695316e5a3c1e38aa630a44bc8cc01a5a8cae060b7de435e54963b9354182d64e340ec9dc3e37f8b2bb9aab23608b86827991df4367839f443c160c1eb77f41159f69592c3eb37c21a521afcd34036a13a145e9cb1039704b8e523359ea5c3a50f705118ea7d8b1063eb85bfcdc941e0235579e97856ee6f6bfe9c4d0d161b5662de26a2fbddc530ce918a98514903c63a3476d6ebe68e2503e6bf255691fbd8a006e9c77f5a4ad9e3e7d21a56bc4f7bc90d61ebb31eaa4dce48eb9a8069a584ae35266a4bc4af970860d2e9a0df7b87e8fc8b597e73a85d8eeb91def6057d7a77e8f859ee9ee07ef2fb2260660e59ee16458eafbd7bab979ed9bf72c1c27cadc9011f563aeed9a4b2f09ce5455857bcf3acbe0e1cf15537594469c777a885ad24ac5a8c894a8257c5212fb46184ad7f280bc25600129b25bf941460fcdd2e45a0216f1f2fec84d4792a15f877d15c649991ef998621a50a04251257ed6ccd803fdbc83ce4b5c4d7e8ee4487848768384e0b3970ec899dfb423560e755c4716deb3e188ca780cc7a97e8fb80076e9c44c6b7575725253ae4605844c3748bf90c14f17ee80a42fea450c05eb1f251eb09855f2ab368047d68cd7f4b8898d0113d52117375eed77707e7abff8554dc7da30cca674ac587198c0f165a47a5db79e9cc1c7bda9cdc36b94f14141a307850062f4d7208b8c612691699e3e3c16c3fc2dc6604e0fcb27acd6b5d326d2663a3f819589176c70423bc4d42a21c0a30addfe4b4176740f9a418eca631cda9a8b94d20c7f9b834ed87751464dc5e5b446d920312003e4673b48c6c12b407af1ed90002507883f78d166f0b90bc14ed77d4aec6220cdd51948cdad29ab70513aadccd0e3c8c8108d3b0722602d9612aa6feb323a4ff3e8fe0e3d5701467491acdd3c71c34bc019047647779922216ccd61c47958461e3017adf446c4bd2ab7fbf70e41419679f6a9b3fa4c9aa5e9ef8469ace0d88bc35a3374f462573d2ba24b712359ef36e413006a9883bfa4fad43d89c7f1732725e3cad482d17a9499e1fb0f57d1ca93cafa7fd6d654a70cd7318bd7ace30e981217317105bcfe5e33352b86e92343985884284c9646d966beb8aca87d44e15dcce24aeb08312091cd98e6b52e2525409c58438f00131d33ce09fde0343f84db73369954f2d77a3a559189bc4dfd7e7c043b1364b36550595f624483c4eccccb1c4958a9284e43522dcc72ad9b01162d964605eab990dd1ddd25796f55991e1201f22526117662c0cad518f191effe5608b444b9e8973f5a11a8c154bc501bc47bb4fb5832d67f4c5ea5c221e64ff88ff5d5117aadac704a8b94beb036e87fc9a9462c355231b9bbe8a9122f12390073600f2d7f6262f1758eced79619900cad1910286c5bae553c6525c63c52ca97bf452957c5fc1a7d695dcb5ed0cf0004454c528d63960cd303132333435363738393a3b3c3d3e3f", + "000102030405060708090a0b0c0d0e0f101112131415161718191a1b1c1d1e1f202122232425262728292a2b2c2d2e2f" + ], + "measurement": [ + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1 + ], + "nonce": "000102030405060708090a0b0c0d0e0f", + "out_shares": [ + [ + "9ba31b9c201fb42526a6b32120318aa9", + "e8551983a3deafaafe0f608295f8d291", + "3ded6ed96f607eb1be6e96868876fc78", + "63b295d3525f0ad74886c2fe7b3b1cd6", + "764a6d37d41457d6f04fd418888cf36e", + "443cc86d052be058d0a1e12384ba0d89", + "9785cb376b5ebfeffb6c22af3bbd02bf", + "960385979cecf00983ba97fc12b2235a", + "c0cbf9f2533dc942eca94540b9a0e745", + "07418bc80d2926f6ec11e3615a4e0c17" + ], + [ + "675ce463dfe04bdabd594cdedfce7556", + "1aaae67c5c215055e5ef9f7d6a072d6e", + "c5129126909f814e2591697977890387", + "9f4d6a2cada0f5289b793d0184c4e329", + "8cb592c82beba829f3af2be777730c91", + "bec33792fad41fa7135e1edc7b45f276", + "6b7a34c894a14010e892dd50c442fd40", + "6cfc7a6863130ff660456803ed4ddca5", + "4234060dacc236bdf755babf465f18ba", + "fbbe7437f2d6d909f7ed1c9ea5b1f3e8" + ] + ], + "prep_messages": [ + "fdecd6d197e824492dd550fcdd0aa3c7" + ], + "prep_shares": [ + [ + "e6a4fe7264f95c384446bd51db14f78eec654b5beb7d60677c0b8385ca51a5bf896c79a069b04f339fdcb675885a26f533ec0d6e805beef6d7683a8ecdfa46cb87fdbc532efa5b1041347ed94f54fbca15041d013729d5eb78dcb1a1c293e0035448a50cb11e195d886622861a78bbb7bccc68d521fac002ece757cc5d82afb3b0f9f8766a1714047b50417a8e7f63eacf222c675bdf1d3e8362806ef5f9c16c446a1e5e06ce539aa300bc68c837c9207c5dbc7d85f896fb1be725e461ceacd303716a2cd8005e370a8ac1062d966b5c1499813d0dc63d697c4fd44dcfe81b3cfc37952de43649650c52ca9f2044fd3dfc42cd08bf0659e6a7facee2468ad1b07903a933ec3cbd06a5f81461e434449c32ef6678f3c8fd250c0b3318e80235144a96f07af2169dfddb503b38a2039f80f6bdfbbec6b3ad1025a43a90249221d20e627a73e2ed92bf1920574f42775675", + "1b5b018d9b06a3c79fb942ae24eb0871aba638b677efec2418d7921020c44ff8d0f6e371501957e00796db78ea8f338813d41fd058418e6056a93bb2785cec1d457ff43ce93631942f823f5bb9c6d133fbfbf9d10fa7922dccd1e570b0246775696c9cac47da44cdc3eb3874eb0d8c89fcdfd722a57116825749c8972129304c221e486a579b91cd10a0fb38fec7214af0fe63f4a24e9e189e180a8aaed3b22931df892814bd8b8f828d029f07f6acf402eb7c22aea07b55e871ebdf6475ae509522b897e840e315a614a89d0bcf8296c0b0d5b745f277689643e3da46c09a1ce92147fc6676eedff2cb239581f46750fa1947870c81f9d565f51ed7e09d86dc5a65dbc5252a8da8702a979aca3e211ae9def20e5f2fa85d16e3ccb0a917510793cece17c818d8fb1151f1ef1b32745f09c155f42259af588b424c1dafe8d0e90a3dfc6f55bb428773ce15071cb720fb" + ] + ], + "public_share": "0e627a73e2ed92bf1920574f427756750a3dfc6f55bb428773ce15071cb720fb", + "rand": "000102030405060708090a0b0c0d0e0f101112131415161718191a1b1c1d1e1f202122232425262728292a2b2c2d2e2f303132333435363738393a3b3c3d3e3f404142434445464748494a4b4c4d4e4f" + }, + { + "input_shares": [ + "2551f59efba2edc493c9246ff1f0e0a7f9f6f22ee46e662c899e485d7ce288d6bfcdfee804a39618972fbaa595eeec25433e412cbe51d44a62747de627630302378ec3535a2545a2799e8a0b9a144c811258dda278865d834b34fbe77ad11dbba0dcf0637c24e10d5ab36d03cdc5f6b95f400a0a81608d96c25733c376376de3927c3570e8ab671358a1686d0c44ac938468d5621cff66585454ef124daa5f18f0d7e791a4bcb11caf74b378e2c4feff3f5bad16e7c3fab987eb4d4a0c675bb4f5e70e1373fb00a5dd30a1118355c20e2f4c3700be3d3c1cf25d3e4a729836ba574aa074f99be0d23d4cc0dc9f263c986a88e0d16a3d28c262d34f220b1ed127cedea3e2a1bd075c653d4b6f1c3d35e25d2804e7960250dea42dc4a52c9545bedd182ee8391b4c6849366af8e15f30bd07872e5ed651ef7db0b0c442886de32eefeacc5f2dfe87f9375b4774153fc9e443105b5f8e452e80874c84131400d4d589a1a5d94bac9e68dbf917ef6405b0bc14fa89daf46f84405aedb166ac93f6545256b1da6ac65e01d580bb26eef82c34b9d728fc0c96ff898ed46bd289abbec9927397552ebf6d1eb3f916f69ee9f80e9566512bff70af2d8f3a9ed599f24e33560a09304e1b4f51948e2d8cbf5a1bb14555b1786ae3af4670111bc3983293ad9be029128efd86d0a05cb3f442b43f466dec5cc9c4989bf5a29eb5c2401bc8bba1d5b7487bc0bf010c968fe76e3a9924469dce6704528d56540081240ed0d2f301a8c9baca5c183b1b5c3a9c03dce5036a26d06e1470c2e63d15fdc3a610561550ca9439c595098ff3794c7d7e62af5e3239b43e22a0f8864c254a069a083604772d77ea000177a7b908efe27f6e00db7fa25f573734af803044fb2dc333ed9ad589d7677e23614143fa6836d68ed311ddedf0ea031688b2cf8c5248f21be444f1c61b050a0ab7dca04992673afb737bd37526a72dda3b03dad4d3b0bb81f0887fe6f25ec4cf35d58ea5f085e97609cfb7a8e97d84fdf8755b8e81ff29614bf1b13bcd2b8d9ab06dc4d60785f83eb6ee4673859223214ecdad734d114e15e1971b8b82222910fd041a1123a4e792a9239f99252de3e3e8d5bb209e2c9bda506a79853c482546940364a8246392fb5e18e95847458445fe3a970b29db6d3d0e4a816cfec7c8538f24896d2d10669113f2b824161d2007ee75c0b651f4934046142c04b2015212997c609625bbeb81b9fa0349c167196557c08ae9ea2defcf7859eed4d35b2183c628cb82fe01255e558e7c8b13bded6ba43ed8b92f6ba7879b39932468260c5768ca0909aae899ad1252c6dbcd741d971f179bc36e88a0a10981f83202cb25db324da405fdd5ca5331431bfe362c5f933b3c1216c3e19140cd27f8c2ef67898887856a46a518a3afec78ef0d9dce778289a38d2df906932c40019bfadab12fe7d0695316e5a3c1e38aa631a44bc8cc01a5a8cae060b7de435e54973b9354182d64e340ec9dc3e37f8b2bbaaab23608b86827991df4367839f443c260c1eb77f41159f69592c3eb37c21a531afcd34036a13a145e9cb1039704b8e623359ea5c3a50f705118ea7d8b1063ec85bfcdc941e0235579e97856ee6f6bfe9c4d0d161b5662de26a2fbddc530ce928a98514903c63a3476d6ebe68e2503e7bf255691fbd8a006e9c77f5a4ad9e3e8d21a56bc4f7bc90d61ebb31eaa4dce49eb9a8069a584ae35266a4bc4af970861d2e9a0df7b87e8fc8b597e73a85d8eec91def6057d7a77e8f859ee9ee07ef2fc2260660e59ee16458eafbd7bab979ed9bf72c1c27cadc9011f563aeed9a4b2f09ce5455857bcf3acbe0e1cf15537594469c777a885ad24ac5a8c894a8257c5212fb46184ad7f280bc25600129b25bf941460fcdd2e45a0216f1f2fec84d4792a15f877d15c649991ef998621a50a04251257ed6ccd803fdbc83ce4b5c4d7e8ee4487848768384e0b3970ec899dfb423560e755c4716deb3e188ca780cc7a97e8fb80076e9c44c6b7575725253ae4605844c3748bf90c14f17ee80a42fea450c05eb1f251eb09855f2ab368047d68cd7f4b8898d0113d52117375eed77707e7abff8554dc7da30cca674ac587198c0f165a47a5db79e9cc1c7bda9cdc36b94f14141a307850062f4d7208b8c612691699e3e3c16c3fc2dc6604e0fcb27acd6b5d326d2663a3f819589176c70423bc4dc97f402ea5053f2aa068d42c371d327339bd8637472eb9be88e409688a53bff392a8891c96c7804998d761fcf34dbf8da8cb17567f75ab4be29ecb6c33c85bf572b645bb7b226e4b99ccc0d959e6d809ec45b70cefe5ada611ea14962c9788d3e15100872669baf3c07a424cc205db97c0e2f72880a40a48d3820f89a16c7ee9dbc44bea18d787e0b730093a7b1fe4f9fd3e50128c61f6d4179fde88b02629629bf9f56f15abe3860ad2c99eaee9eba2310ca898fee5103f7bcb11ec9f4154007cf7d2b51a173331576526665d9f90879a2122b2bdef3a2cff68e57b146ae6d99a4e247d0daa693ef0a07aadc1a4b25ce5e33352b86e92343985884284c9646d0f8ec766552f75092a8b613870386a8b77901f01cddd76b4761e74519b24b851a570b5de8ca954b2c7df0fb314b6fa550e8e49713a28358e399afb3b9199496b229bc55644ee8e4772f1e00dc53886ade4932acee5cfd079707bd1d204c58360f26434fb158b53c1c4a51b65703f123f8090fe42dc48dbd3469a7d4bf1958203adffe46dd39084b66c789517b4438ed9415946ca552d523fa6c71e3302c3552f140d62d41cf3580e5e8500674cbb7d9ddd849d1ddb1d48ef7fd92f363e5e5b6a95b0c67b37e7e5e6a4dec9d8d56e577562eecec955cb6f9925c81cc165634018ab142c519ddd54f358356cee2ba50840303132333435363738393a3b3c3d3e3f", + "000102030405060708090a0b0c0d0e0f101112131415161718191a1b1c1d1e1f202122232425262728292a2b2c2d2e2f" + ], + "measurement": [ + 255, + 255, + 255, + 255, + 255, + 255, + 255, + 255, + 255, + 255 + ], + "nonce": "000102030405060708090a0b0c0d0e0f", + "out_shares": [ + [ + "99a41b9c201fb42526a6b32120318aa9", + "e6561983a3deafaafe0f608295f8d291", + "3bee6ed96f607eb1be6e96868876fc78", + "61b395d3525f0ad74886c2fe7b3b1cd6", + "744b6d37d41457d6f04fd418888cf36e", + "423dc86d052be058d0a1e12384ba0d89", + "9586cb376b5ebfeffb6c22af3bbd02bf", + "940485979cecf00983ba97fc12b2235a", + "beccf9f2533dc942eca94540b9a0e745", + "05428bc80d2926f6ec11e3615a4e0c17" + ], + [ + "675ce463dfe04bdabd594cdedfce7556", + "1aaae67c5c215055e5ef9f7d6a072d6e", + "c5129126909f814e2591697977890387", + "9f4d6a2cada0f5289b793d0184c4e329", + "8cb592c82beba829f3af2be777730c91", + "bec33792fad41fa7135e1edc7b45f276", + "6b7a34c894a14010e892dd50c442fd40", + "6cfc7a6863130ff660456803ed4ddca5", + "4234060dacc236bdf755babf465f18ba", + "fbbe7437f2d6d909f7ed1c9ea5b1f3e8" + ] + ], + "prep_messages": [ + "190c0ef07d6f2cbd1bed12d71b5f118d" + ], + "prep_shares": [ + [ + "e6a4fe7264f95c384446bd51db14f78ef1ffdfc96013014baff7bd0684b8ff3360f6a2a23b1d7b71796b845eeb21e7d1682128dc8b87837803fec0e3bd6c548d5cc9c041a482cbfc38120743a2f1a0985053eaeccc339c56c2edf76451e22c9a1bca6ab2181850f44d904964d1f227a970e70b59328028ddafdb1649a8c4f1f7cbeb366e42f8603a7b627f7519f25617a33726ede06ab438714b4bd3cda025dc2bcab64974eb02b9e2a23bf0cab4e5ef1e87bb1a72767098c768a20e1090a712ed38c1e8803fd18181cc0069355b40f5f98ff3cbd2b31022df719d9c660e7d5bab239819730ad9165a38641379444093ba148996166ccf5e2826bddb7a7dc8eda4dd5928b12e1ea715538788804b5ce231dfc98ff45027ff8cb1f92007f2339621201a7dc483c83bb6df082105cb5f5d9eda1f847b3d43b4e16502062d5a4a158f651439d7666c683a6283d308c91b25", + "1b5b018d9b06a3c79fb942ae24eb08714d70b996c916bc83063fda0f9ec82780d0f6e371501957e00796db78ea8f338874640cdcd2ca1496ea55d62aa3709498457ff43ce93631942f823f5bb9c6d133b1f04db02cc745245af6a31e4dfe912a696c9cac47da44cdc3eb3874eb0d8c89cc603d2ceec1678996fe311424b56e65221e486a579b91cd10a0fb38fec7214a46c9ce7f890afed4efa31fb5ca8766ae31df892814bd8b8f828d029f07f6acf47620063e9fa98ab947b376e068b9be689522b897e840e315a614a89d0bcf8296332affc2e6b52204d1d7994260e415c6e92147fc6676eedff2cb239581f4675066213f3b302497abbf10c1729131a0125a65dbc5252a8da8702a979aca3e211afb99a0edb76c92f2bd5941c6071db19d93cece17c818d8fb1151f1ef1b32745f09c155f42259af588b424c1dafe8d0e90a3dfc6f55bb428773ce15071cb720fb" + ] + ], + "public_share": "8f651439d7666c683a6283d308c91b250a3dfc6f55bb428773ce15071cb720fb", + "rand": "000102030405060708090a0b0c0d0e0f101112131415161718191a1b1c1d1e1f202122232425262728292a2b2c2d2e2f303132333435363738393a3b3c3d3e3f404142434445464748494a4b4c4d4e4f" + } + ], + "shares": 2, + "verify_key": "000102030405060708090a0b0c0d0e0f" +} diff --git a/third_party/rust/prio/src/vdaf/test_vec/07/Prio3SumVec_1.json b/third_party/rust/prio/src/vdaf/test_vec/07/Prio3SumVec_1.json new file mode 100644 index 0000000000..af95aac5a0 --- /dev/null +++ b/third_party/rust/prio/src/vdaf/test_vec/07/Prio3SumVec_1.json @@ -0,0 +1,146 @@ +{ + "agg_param": null, + "agg_result": [ + 45328, + 76286, + 26980 + ], + "agg_shares": [ + "598a4207618eab7bd108c110106e84cd9498b18b79cace4b383274cef0ab280eae74b07f9a7f087a89a4f51f3adb97ed", + "ea61abdf14b8477f6de1b47a18ac778ad0149cb235e666ccce92a9246a2858403e5903013ab179dcf6719d10fccdf589", + "cfc412198ab90c0589158a74d7e503a89b7cb3c1504fcae7dc3ae20ca52b7fb17a9b4c7f2bcf7da947e96ccfc9567288" + ], + "bits": 16, + "chunk_length": 7, + "length": 3, + "prep": [ + { + "input_shares": [ + "db3fe5357f56f6cfe9a0a34ed08e231765b423d5670741c151f7bc28cb15c4c8678df838b6ef52d74c5bb5e8c8c607b2674c024da705b558163e3127ae09f4a5c061dc6129d71755ec77b5d5a2fc088541fa612bf1273d94718e0d28b654ea9524148bf910a5c6d55b596e323d4f55ab62b0851e0985987ea39cf500bfd4d5faed37f519120b574aead76d706e149c54b6e790f7c35553f830f67553ff3f08db0f5d8a2e157f9c1ef6eedcb845a19fffd3f9c6e705b035c95f3ebe51219a91c6ffff33c52b8fd7e90417636bb443c98c3e6933b1599c09151095fff99d7af43d327fcf2afdc50144e5b8f5b1ff1f8cc8d442613dc76dd48100dd4637b38fc5fa239333c6d1e2bec2132a3e3cf5fd11f7e823fbf275071a71535740ac64de88b83ba002bd4490de5c7540e454331d46ebb825725b964ecfcc5c3bb076bf57fd819c0bd788b4495da456594503c8f5b618c4713e957a7589b151ead3a27ae80153b0fbc219f5524eec22dc91aa56ed87a71d2b6e6b857ee03a700fb5af5b23513fb73023c2764f74b40f00e44e2e7acd463e7f38c783a898015b51863659502645e7d7fdf14abf76c605d42e523a540f40e816db420d599cea502baa131d6b61b8aed86a09b9f7587e107c6fc9043d72644153b54234146535e9ec25223d293f4a2a4481b5649a7625116aef735ecbbdd9e99076985e1c424e91c29ae231c01f2b12007c81a9b3c72a84522014f4cdcb949bb9850b62b1c9c5d35f0fc3dbf4f637ae85d3082ffd15890041a455a5dd38efca02eb415afcd2c93a2b6ab46c88847f380affbc3a9af4745718044c9f23978026af66871e610e1beee21360fc5ae4f05dd32689328216db3c512fddacefb2713187f4f5069170db246ae0cb13b3b88f8c832d11ae6b26fe6f0f89b9a9056c28ac6160134ab1af4be606a2ca7d1256cfc223c823ae4921ba11aab9e4f8104885b2d17ac7826ad06b0fd0e54397e67179755b29c6d724f3e9d9a6934619ef0bfee72cdcf0031672bb6e55be53d3ef0fccf7beba7dea636eff7fdecee99f60a0bc5d5cb11e24683ca8085cf218a3607c3f5a051c33717f39834435c5c542f42bd9c9d9103173cb11436e2a626e228415924037ee4a4078f0d9a5bd0fef1dc9415740c4b32f51d199d912907a103c7e25928acadfb4dc1e8d32c34b4f0100bfe95feb850e0d7b359aae9a332913d1f8eca0d1cd93dee2f4b75076536037ca2b6bec5003ae89e03db9813d0dd2165bf50a836ca492f05750a3b5fff4d089fc54fdf225aa112d72ab021b0cef1b8d31c852133fae501406668db0cb4c3589934bde5ec218876b0dc4087b61d706234635d5d9065787a21ed895fb7c05902a554f879dae44d0905292815c5ae2338be1faa9249c7a0530623084f3a7c8bdddf0cf0123597127f9416dcad20f71bb98bf0078d9d20d4ae09c46fc1882cae650f48bebfcc443dfd7963c9944f0ce0a9bda91fe8b8f307fb3da48e573224324c35e867599262fa281ee6bc537928747dd4a1370440bda92e28eb5ba088481e476f19e2a03a7c7619eb79e2184afff7bb585319fa6a58dbfb8862f5d193ccfe0e4aeb58c1633d9e983861d4976615b11514160e5d77ed9a3e2179893c65d9de03813d27aec3485d96098764ee1e7779d47850e6b96ba064b4b913da8390416afb16719a38b725d5b27db351049bdb322cfc93a905a108d07e49764eb5f3f66fe8f3aec7580606162636465666768696a6b6c6d6e6f", + "000102030405060708090a0b0c0d0e0f101112131415161718191a1b1c1d1e1f202122232425262728292a2b2c2d2e2f", + "303132333435363738393a3b3c3d3e3f404142434445464748494a4b4c4d4e4f505152535455565758595a5b5c5d5e5f" + ], + "measurement": [ + 10000, + 32000, + 9 + ], + "nonce": "000102030405060708090a0b0c0d0e0f", + "out_shares": [ + [ + "7e6f6b02cb848e7e3c5840b05acfd699", + "dda1902ed398efc35ebb269aa58e0d5a", + "cd03902ade7fad281b8cfc5f1349ddf9" + ], + [ + "f9758e4a5c3d6d2a1b4b3c7e5d397d83", + "9bb1de90bc4c2244e630e3b6780dc86a", + "15735600bee57d499ed0890554ef5183" + ], + [ + "9b4106b3d83d0457705c83d147f7abe2", + "89299140701aeef79e13f6aee1632a3b", + "298919d5639ad48d0ea3799a98c7d082" + ] + ], + "prep_messages": [ + "429a80b8b2f73ce00d066d0c0e540cfb" + ], + "prep_shares": [ + [ + "574ce7d79da173f0652f7d1e699d5fe756c90210a093e0c893b2b7510374465c35014cdcef24d6044e63ece20db245a982c5a00f16227c6530ade1e5fd4e925f8505da8add8b97fa3216790b11b46ad76bec9b6b3f0c1e9b8284f81f0c83bb690033a6c4a8e962d34b51e27aeaa33a0d41e852a78af96c6dfba1d36364996df4eaaacc993e5633528fb0d106753112ecbfe7fec74ea75d2d1b16267ee1b3d6145e746f19d0840be93d38baa3a5c7d42c6a05311092b4fe05e38aa9e24fb9c99a5aa4bc485752d049703409ccb6656150eeb34048030f19b5353df1e23b376b1c4be2925034d8b7a847690ca7f0c4f338c30b461d9e8c1a1de094fea67efdc4d3da6e96a43cd51241f828f01028e5b0e9", + "21c1bcbdd38fa9e8d86272f8c66c2b16768d00a79945c6014bb55374174299e4e7967f80f8df630af6f5aac52647e36437d08d0aabfd406b0984557a1a69b28f3706033ed34cde0d6bfd39a81a0fca04117cd7ac37f06815b6dc1a356454d669b0e7871238066b280f7c09e0365214235dcd39dfdf6e46ad632d50101029ff490910f79b754d99088059cf1d5da8a791b1a51d8d96fe428bac8f547dc2e9632c4e00b3a6e95d36d03d83969de2989c120aa3c54040bc7f49f910df6c5e4003721a03a3f1d743c457d0cbadbc0bd25ca7f464f404a1f39f852872314ca44ea35a32c150fc25094cd2147739d1a1e03e8f6e636170867687f1c511dbb936fb521b6639ebc145f03b6d1751fcdf14c5108e", + "89f25b6a8ecee226a56d10e9cff57402c3198e7629d8c4152d9319172f0df07727e57ad2e170a0a32cf748f9dc0fe45597d26743a7254aebceb106f0f172a5f2aecf563787471334096bb202a07696b7e0cf22d7dda51e1cf0fcc3d0d5d0a9f4c7ea8999a3f4edb4bd7d3ecb26f2137fe9d443f34248265c25a12290866130c9ccd2dd263bce8b78c14ddbc0c94ed85da0c401ac614c789711c0e2db27ad6ef59aabfdbb6bf174fff2358c4a8d6b7e6cd2a0d41883f3563506755a920aaa7695cf35ffecb0b195d01d1d86ee1a015df5de2f145b61ef3bc5fff1634d911a69540c03444dcb256fa982b24cf064dd6e4dffe71f8d5648cf336acd9485f9d70e24f2ad6b4aa14511e7b660c0866f73672a" + ] + ], + "public_share": "da6e96a43cd51241f828f01028e5b0e96639ebc145f03b6d1751fcdf14c5108ef2ad6b4aa14511e7b660c0866f73672a", + "rand": "000102030405060708090a0b0c0d0e0f101112131415161718191a1b1c1d1e1f202122232425262728292a2b2c2d2e2f303132333435363738393a3b3c3d3e3f404142434445464748494a4b4c4d4e4f505152535455565758595a5b5c5d5e5f606162636465666768696a6b6c6d6e6f707172737475767778797a7b7c7d7e7f" + }, + { + "input_shares": [ + "db3fe5357f56f6cfe9a0a34ed08e231766b423d5670741c151f7bc28cb15c4c8688df838b6ef52d74c5bb5e8c8c607b2684c024da705b558163e3127ae09f4a5bf61dc6129d71755ec77b5d5a2fc088541fa612bf1273d94718e0d28b654ea9524148bf910a5c6d55b596e323d4f55ab63b0851e0985987ea39cf500bfd4d5faed37f519120b574aead76d706e149c54b6e790f7c35553f830f67553ff3f08db0e5d8a2e157f9c1ef6eedcb845a19fffd4f9c6e705b035c95f3ebe51219a91c6ffff33c52b8fd7e90417636bb443c98c3d6933b1599c09151095fff99d7af43d337fcf2afdc50144e5b8f5b1ff1f8cc8d442613dc76dd48100dd4637b38fc5fa249333c6d1e2bec2132a3e3cf5fd11f7e923fbf275071a71535740ac64de88b83ca002bd4490de5c7540e454331d46ebb925725b964ecfcc5c3bb076bf57fd819d0bd788b4495da456594503c8f5b618c4713e957a7589b151ead3a27ae80153b0fbc219f5524eec22dc91aa56ed87a71e2b6e6b857ee03a700fb5af5b23513fb63023c2764f74b40f00e44e2e7acd463e7f38c783a898015b51863659502645e7d7fdf14abf76c605d42e523a540f40e816db420d599cea502baa131d6b61b8add86a09b9f7587e107c6fc9043d72644053b54234146535e9ec25223d293f4a2a4481b5649a7625116aef735ecbbdd9e99076985e1c424e91c29ae231c01f2b12007c81a9b3c72a84522014f4cdcb949bb9850b62b1c9c5d35f0fc3dbf4f637af85d3082ffd15890041a455a5dd38efc902eb415afcd2c93a2b6ab46c88847f390affbc3a9af4745718044c9f23978027af66871e610e1beee21360fc5ae4f05ed32689328216db3c512fddacefb2713287f4f5069170db246ae0cb13b3b88f8d832d11ae6b26fe6f0f89b9a9056c28ad6160134ab1af4be606a2ca7d1256cfc223c823ae4921ba11aab9e4f8104885b3d17ac7826ad06b0fd0e54397e67179755b29c6d724f3e9d9a6934619ef0bfee72cdcf0031672bb6e55be53d3ef0fccf7beba7dea636eff7fdecee99f60a0bc5d5cb11e24683ca8085cf218a3607c3f5a051c33717f39834435c5c542f42bd9c9d9103173cb11436e2a626e228415924037ee4a4078f0d9a5bd0fef1dc9415740c4b32f51d199d912907a103c7e25928acadfb4dc1e8d32c34b4f0100bfe95feb850e0d7b359aae9a332913d1f8eca0d1cd93dee2f4b75076536037ca2b6bec5003ae89e03db9813d0dd2165bf50a836ca492f05750a3b5fff4d089fc54fdf225aa112d72ab021b0cef1b8d31c852133fae501406668db0cb4c3589934bde5ec218876b0dc4087b61d706234635d5d9065787a21ed895fb7c05902a554f879dae44d0905292815c5ae2338be1faa924d95673e4d904cc4215bfda1c9c0ca55fa38af6b696ef4f1a452d6870c1d82fa311cee716b00a11313a0ba7c4038e5e195b5e1d3502a44d9a9ab1636394821eaf157e6c77316f6e26b98ca81220752396890cce205581aead60bfcc3916a6c2d6d360c7e845e70c8c68e4acb5eb2877a9a7c7619eb79e2184afff7bb585319fa669b151040f5b15cab2d8c3a50379e9d9e8bf1ac6319bc32e489f64793f882d0e3e1906ac04d47eaec15c20c503d007d09d6a9b032d0f9a8b3d95447fcb1d4b7334b95d873a171f876dcc2a62a62af58e10802f88742027d3d27b9d72fea73dc84906d3dde03299dc3e033651406229da606162636465666768696a6b6c6d6e6f", + "000102030405060708090a0b0c0d0e0f101112131415161718191a1b1c1d1e1f202122232425262728292a2b2c2d2e2f", + "303132333435363738393a3b3c3d3e3f404142434445464748494a4b4c4d4e4f505152535455565758595a5b5c5d5e5f" + ], + "measurement": [ + 19342, + 19615, + 3061 + ], + "nonce": "000102030405060708090a0b0c0d0e0f", + "out_shares": [ + [ + "fc936b02cb848e7e3c5840b05acfd699", + "7c71902ed398efc35ebb269aa58e0d5a", + "b90f902ade7fad281b8cfc5f1349ddf9" + ], + [ + "f9758e4a5c3d6d2a1b4b3c7e5d397d83", + "9bb1de90bc4c2244e630e3b6780dc86a", + "15735600bee57d499ed0890554ef5183" + ], + [ + "9b4106b3d83d0457705c83d147f7abe2", + "89299140701aeef79e13f6aee1632a3b", + "298919d5639ad48d0ea3799a98c7d082" + ] + ], + "prep_messages": [ + "e129a25d4bb45ae8f5ff9d97d72f6d86" + ], + "prep_shares": [ + [ + "574ce7d79da173f0652f7d1e699d5fe7865e23ddbe20c3646ca1d41f8d0d123a831f3b90d5b917779d25ca7aafd5c903b4b996399d29d3af9fcb1c6549c36a639e45759c53b3bf6a1014fbdf7ea4d5d3726af036968d80aed52a7e48dc0a06ee66a6e34bf38ce1083cb56057cfbe98ae9138b635f2b64b60daf151ce214d33b99062b201463033142fc734e4c163fd8a5d2ea4503567c0a9739027883ab0b37ece24259b5ddcc88b3e272bbad93857a52a48abf5eb9fb135ccaf27fc0379f49f203cc0f5a4f14183ef3b935076893f0bc56dcd9b0ffbf8f711039bfd3fe894c0c007d8702db35ded99e2605fa3d9d4b70f6dfcdea9883c839e2a91aa358a38359e2d3cb20272a3ecb52e1edc17a48940", + "21c1bcbdd38fa9e8d86272f8c66c2b167ab67945b007b46536ece1073c4ac681e7967f80f8df630af6f5aac52647e36454b87d9c2d6aa2d4c2186b47a6bef8153706033ed34cde0d6bfd39a81a0fca04618cc495486eccfbcba29099fb101871b0e7871238066b280f7c09e0365214230a741f21625c730fae5e37a7a3cbe0c00910f79b754d99088059cf1d5da8a7911221c38f8a6f79877dd34cb4fd2990744e00b3a6e95d36d03d83969de2989c12e44eb18d466e8954631aa1dbe84e78111a03a3f1d743c457d0cbadbc0bd25ca7de998de032e28d1a8ca88d584ad7a64032c150fc25094cd2147739d1a1e03e8f6e636170867687f1c511dbb936fb521b6639ebc145f03b6d1751fcdf14c5108e", + "89f25b6a8ecee226a56d10e9cff57402a204a4e83b6b78a5aa32317f68443eac27e57ad2e170a0a32cf748f9dc0fe45524ab48c23361b4b327d498fd702f23b6aecf563787471334096bb202a07696b7eb8d79193777673f93547e14b15adb2ac7ea8999a3f4edb4bd7d3ecb26f2137f5e2881b52fd4d3ec538b978ebba54312ccd2dd263bce8b78c14ddbc0c94ed85d38fdf5fb7ee2eeb21ef066fdde8a94f89aabfdbb6bf174fff2358c4a8d6b7e6cdb080672d15b2ae06d6ca64387b56c00cf35ffecb0b195d01d1d86ee1a015df53e771738ec1faca3fcdfe4f033977dd50c03444dcb256fa982b24cf064dd6e4dffe71f8d5648cf336acd9485f9d70e24f2ad6b4aa14511e7b660c0866f73672a" + ] + ], + "public_share": "9e2d3cb20272a3ecb52e1edc17a489406639ebc145f03b6d1751fcdf14c5108ef2ad6b4aa14511e7b660c0866f73672a", + "rand": "000102030405060708090a0b0c0d0e0f101112131415161718191a1b1c1d1e1f202122232425262728292a2b2c2d2e2f303132333435363738393a3b3c3d3e3f404142434445464748494a4b4c4d4e4f505152535455565758595a5b5c5d5e5f606162636465666768696a6b6c6d6e6f707172737475767778797a7b7c7d7e7f" + }, + { + "input_shares": [ + "db3fe5357f56f6cfe9a0a34ed08e231766b423d5670741c151f7bc28cb15c4c8678df838b6ef52d74c5bb5e8c8c607b2674c024da705b558163e3127ae09f4a5c061dc6129d71755ec77b5d5a2fc088542fa612bf1273d94718e0d28b654ea9525148bf910a5c6d55b596e323d4f55ab62b0851e0985987ea39cf500bfd4d5faec37f519120b574aead76d706e149c54b6e790f7c35553f830f67553ff3f08db0f5d8a2e157f9c1ef6eedcb845a19fffd4f9c6e705b035c95f3ebe51219a91c6000034c52b8fd7e90417636bb443c98c3e6933b1599c09151095fff99d7af43d327fcf2afdc50144e5b8f5b1ff1f8cc8d442613dc76dd48100dd4637b38fc5fa249333c6d1e2bec2132a3e3cf5fd11f7e923fbf275071a71535740ac64de88b83ca002bd4490de5c7540e454331d46ebb925725b964ecfcc5c3bb076bf57fd819d0bd788b4495da456594503c8f5b618c4713e957a7589b151ead3a27ae80153b1fbc219f5524eec22dc91aa56ed87a71d2b6e6b857ee03a700fb5af5b23513fb63023c2764f74b40f00e44e2e7acd463e7f38c783a898015b51863659502645e6d7fdf14abf76c605d42e523a540f40e716db420d599cea502baa131d6b61b8add86a09b9f7587e107c6fc9043d72644153b54234146535e9ec25223d293f4a2a4481b5649a7625116aef735ecbbdd9e99076985e1c424e91c29ae231c01f2b11007c81a9b3c72a84522014f4cdcb949cb9850b62b1c9c5d35f0fc3dbf4f637af85d3082ffd15890041a455a5dd38efc902eb415afcd2c93a2b6ab46c88847f380affbc3a9af4745718044c9f23978027af66871e610e1beee21360fc5ae4f05ed32689328216db3c512fddacefb2713187f4f5069170db246ae0cb13b3b88f8d832d11ae6b26fe6f0f89b9a9056c28ac6160134ab1af4be606a2ca7d1256cfc323c823ae4921ba11aab9e4f8104885b3d17ac7826ad06b0fd0e54397e67179765b29c6d724f3e9d9a6934619ef0bfee72cdcf0031672bb6e55be53d3ef0fccf8beba7dea636eff7fdecee99f60a0bc5d5cb11e24683ca8085cf218a3607c3f5a051c33717f39834435c5c542f42bd9c9d9103173cb11436e2a626e228415924037ee4a4078f0d9a5bd0fef1dc9415740c4b32f51d199d912907a103c7e25928acadfb4dc1e8d32c34b4f0100bfe95feb850e0d7b359aae9a332913d1f8eca0d1cd93dee2f4b75076536037ca2b6bec5003ae89e03db9813d0dd2165bf50a836ca492f05750a3b5fff4d089fc54fdf225aa112d72ab021b0cef1b8d31c852133fae501406668db0cb4c3589934bde5ec218876b0dc4087b61d706234635d5d9065787a21ed895fb7c05902a554f879dae44d0905292815c5ae2338be1faa924c67468654e6880c69828616e9d011132b27aa457c6dcbd554ec4374ec882357626d6f323bfadc3b23ac6d390e40f8f2008e462458194a1f9d63e0b977f4be907f34aecb0a12ad13b95bfc858c412abab064106faf149e9b25cb557c12a54e6ae957097b8b11f174094d134cc213bf98da7c7619eb79e2184afff7bb585319fa67b935c839af760464b6f3d5402847d07d9cf6c2502ae55f33e08959b38de273b2911fa9ef530cc2cc1a1f3f8224ed7c8efe455f3ad1e462c1d089d4be054801a56ecdd4dca5bbc7191990a1c028d6d79934bf7aed757eccdd68512ebe9f919f087f6020e75fa8e281316ae3a0a50a7f5606162636465666768696a6b6c6d6e6f", + "000102030405060708090a0b0c0d0e0f101112131415161718191a1b1c1d1e1f202122232425262728292a2b2c2d2e2f", + "303132333435363738393a3b3c3d3e3f404142434445464748494a4b4c4d4e4f505152535455565758595a5b5c5d5e5f" + ], + "measurement": [ + 15986, + 24671, + 23910 + ], + "nonce": "000102030405060708090a0b0c0d0e0f", + "out_shares": [ + [ + "e0866b02cb848e7e3c5840b05acfd699", + "3c85902ed398efc35ebb269aa58e0d5a", + "2a61902ade7fad281b8cfc5f1349ddf9" + ], + [ + "f9758e4a5c3d6d2a1b4b3c7e5d397d83", + "9bb1de90bc4c2244e630e3b6780dc86a", + "15735600bee57d499ed0890554ef5183" + ], + [ + "9b4106b3d83d0457705c83d147f7abe2", + "89299140701aeef79e13f6aee1632a3b", + "298919d5639ad48d0ea3799a98c7d082" + ] + ], + "prep_messages": [ + "62b225fa36c2cbc5896a40b1360f0ce0" + ], + "prep_shares": [ + [ + "574ce7d79da173f0652f7d1e699d5fe74186663960507c6ac01fafa7046bc88a5c4146feb2861bcb9f3ad8bb294028812c631ce293cde53b6fc09c6bdbcb839ede24393da7141c26a1a594ccdab1fba7ed2b2e38faf952343a2871b854699ff9533b4f88f9bc7e935cc62e3dea6efff7e2d396d8276281db4e789dc30d3fe01c397b64d685e87e7dccae1b7636b1a1486293f29e5c3433ee211e1f66ec5a7a0f1d79ed261735f53fdb51cbd51c1a6b057bf2845e34bfa091f0f8fd7565242c8b247064dfac759c1c646d085bca6ad0845b9cc2093fd49d493ebf2d48a90b99815f0bd165f56453c6b27dff04d3a161d9fc0bc4177e423791d4064faccfbbd3218c6815a24e38fb752efc839e59043cbb", + "21c1bcbdd38fa9e8d86272f8c66c2b16ce028e4f460022aed5c3e47011081a9ce7967f80f8df630af6f5aac52647e36481e9f91e23071ea7d98af1e76a9dd5823706033ed34cde0d6bfd39a81a0fca048271f4696baf9fd2e8a37e7b5e1be75eb0e7871238066b280f7c09e036521423abafe75c716165e86a573bab2e46efda0910f79b754d99088059cf1d5da8a79146cedb83d4b00d96b9871a66d73147984e00b3a6e95d36d03d83969de2989c12e222e1015cad12a219308643f8e1c1511a03a3f1d743c457d0cbadbc0bd25ca7cb9908430d0604df4468ab54b36f81ef32c150fc25094cd2147739d1a1e03e8f6e636170867687f1c511dbb936fb521b6639ebc145f03b6d1751fcdf14c5108e", + "89f25b6a8ecee226a56d10e9cff5740294125c8939f32731654a11d38906f94a27e57ad2e170a0a32cf748f9dc0fe45503b71c971d945163a09158bc06295e77aecf563787471334096bb202a07696b7f7c98da9fca5e00ff5ce6eb4b9d52274c7ea8999a3f4edb4bd7d3ecb26f2137fbd6dc71ed450947219eef408d88af7dbccd2dd263bce8b78c14ddbc0c94ed85db3255ece317e82a0f3b4a873e9b72b949aabfdbb6bf174fff2358c4a8d6b7e6c545473085c7c9b5c2aa9d290b01670ffcf35ffecb0b195d01d1d86ee1a015df56465c0991dd00dfda6948958c3ff38510c03444dcb256fa982b24cf064dd6e4dffe71f8d5648cf336acd9485f9d70e24f2ad6b4aa14511e7b660c0866f73672a" + ] + ], + "public_share": "8c6815a24e38fb752efc839e59043cbb6639ebc145f03b6d1751fcdf14c5108ef2ad6b4aa14511e7b660c0866f73672a", + "rand": "000102030405060708090a0b0c0d0e0f101112131415161718191a1b1c1d1e1f202122232425262728292a2b2c2d2e2f303132333435363738393a3b3c3d3e3f404142434445464748494a4b4c4d4e4f505152535455565758595a5b5c5d5e5f606162636465666768696a6b6c6d6e6f707172737475767778797a7b7c7d7e7f" + } + ], + "shares": 3, + "verify_key": "000102030405060708090a0b0c0d0e0f" +} diff --git a/third_party/rust/prio/src/vdaf/test_vec/07/Prio3Sum_0.json b/third_party/rust/prio/src/vdaf/test_vec/07/Prio3Sum_0.json new file mode 100644 index 0000000000..4dd3798668 --- /dev/null +++ b/third_party/rust/prio/src/vdaf/test_vec/07/Prio3Sum_0.json @@ -0,0 +1,40 @@ +{ + "agg_param": null, + "agg_result": 100, + "agg_shares": [ + "0467d4fd6a7ec85386c9f3ef0790dc10", + "61992b02958137ac5d360c10f86f23ef" + ], + "bits": 8, + "prep": [ + { + "input_shares": [ + "1ac8a54e6804d575f85957d45c728487722ad069fc1ed413da970ea56ae482d81057f898a367319a14f402d072c24bb71aa13cf4f9cdcd731e779aaa4a5d561ff40e41797cf2d43308849ff9080f846c2e78e2736a1d3cdaa68f3ec890d633cc13f5bf96466d3f02f93612bc827ff53357d52ae00dd234b9758f2cbb7aa9682d06041da1507aa12446de554a945924f445d3715279ef00f4c55ae987cec4bb9f1316efdc8737b7f924d046a6b8ef222b0dc1205ce9b464076fa80d2dfe37af4d836d597ade7d51e18e9c95d13158942d249efd0a1a226759e4bc1d46d3a41bdb227703fe0a7554cf4769935bc99cd1f35b274ecec240816af4915c7abe3e16b7be5ab5e105f9ae7b2e683191c9400cf99ab0c687e4929f87e6e64f712ca02f07a1b29fcebdbfde7655797f9c1b6b3114420d8a19736ae614116782278b7a71f9ef6928ad44ce588644886523d6fbe0b7bbb47248edbaa0b5ce33f74a07005e2a6842eb2c05778e170112f6e6a5f206d7830aa122e29069dcb4a4c064e63c29b3c6e2b22dfb5ab344ca0f1be8e8ce36d26435413de2dc4f53e158ebb8478b4a98de014a688db9470106fd7e73a65c2e656b5a627b5584ca0594ba10cc39c5612bcef576625c37c5249ad5c04e42c66d6a9653c4ec47e2bcd860870bef64f812974654f17f77c08eaa395803d33bdf31db17d76dbb9d2407d7c4f9efbce274542ff6aa0dcf188803eb586108317db430ad517ce7cb0f56d225c835161eb348949ebe253bedc338c6b939ce837561f01d7f0304963eab2a28b38c36bb169a4ee0637635818bd5e4798a8319152a2678b0aa7b837cb0f24df6148ae2c84b78db8892f4415f90f3804e7a29cdcd32a0a8625fd20aca47ee0ef12ebd6138b3534a1b42303132333435363738393a3b3c3d3e3f", + "000102030405060708090a0b0c0d0e0f101112131415161718191a1b1c1d1e1f202122232425262728292a2b2c2d2e2f" + ], + "measurement": 100, + "nonce": "000102030405060708090a0b0c0d0e0f", + "out_shares": [ + [ + "0467d4fd6a7ec85386c9f3ef0790dc10" + ], + [ + "61992b02958137ac5d360c10f86f23ef" + ] + ], + "prep_messages": [ + "0fd2bb14ac123437f6520fdc4a817934" + ], + "prep_shares": [ + [ + "61428b8d7e326827ac832bc4074ad61652efcfdb8d95b6f06b83dd9f5d55ce9f142d1a1fd437eb8c84581ad15dcd9a57417942e63a1a46e6b0ffc8b6d6300f7d", + "a0bd747281cd97d8377cd43bf8b529e9eb5e4b1153111bd6cd06aa3a5493a6da4470f696b9afff52ec10fc00040e4538470fdb8d3e05e188aba2b16e24c71b69" + ] + ], + "public_share": "417942e63a1a46e6b0ffc8b6d6300f7d470fdb8d3e05e188aba2b16e24c71b69", + "rand": "000102030405060708090a0b0c0d0e0f101112131415161718191a1b1c1d1e1f202122232425262728292a2b2c2d2e2f303132333435363738393a3b3c3d3e3f404142434445464748494a4b4c4d4e4f" + } + ], + "shares": 2, + "verify_key": "000102030405060708090a0b0c0d0e0f" +} diff --git a/third_party/rust/prio/src/vdaf/test_vec/07/Prio3Sum_1.json b/third_party/rust/prio/src/vdaf/test_vec/07/Prio3Sum_1.json new file mode 100644 index 0000000000..8e7163ae2a --- /dev/null +++ b/third_party/rust/prio/src/vdaf/test_vec/07/Prio3Sum_1.json @@ -0,0 +1,46 @@ +{ + "agg_param": null, + "agg_result": 100, + "agg_shares": [ + "b3916e4086c52aa0439356b05082885a", + "61992b02958137ac5d360c10f86f23ef", + "52d565bde4b89db326369d3fb70d54b6" + ], + "bits": 8, + "prep": [ + { + "input_shares": [ + "ec863a16981eee70b71dc7d396a85ba412dcb4e86769d6db0c60f668456f5d6fbb844d9503580fb7b662bbb2ed7d92002e6ab4a2d31a6f611cbb5c48ca6df69811d536f74a3ff61eb29bd9b9f1b64c35eddd5c4ac97376057883a317f2989b545a682775f948f28f80f366f36b4eb90f931bca79e229eae377102295d9c46da2e239f74f045084747039c0a955726b4258bc0d14da7474bea6cd136eb5e55e9531e6a68703003a64943a5650b16674c82d9c4b526a7ed3d69f8f13ae83609cf056f3fed8d6593fdad7b367d2d248413072651073ea91b8162d42af168698f0f0928c8238b2df218e26d004d2bdb5f9f20d0a43c0286d08cfc26971f282992f82ff14d51cee3e0f3fc7411869c2176cabc6b1a68e33ff5eb217490de9f0d85cb84e9115bb7e208a190d25bf9cc138485892802a50b790ba6f45804de487a3353e54b5471adb5ab612d9ee6416649e136456215503637e0daab367149bc5cdf02a2dabc2790f84cadec1510263fe6aa27df5df395b7a241777a8ed28da27276b48f599dd895a005746cfd1f3c874e6f52407f4c417934d7091685c0b38b1d76b398ad263ec73f4f811aed38febf67a19a001a2c7ab8071f986939713cccd146c7a049c5129783359fcf86410765028fbfbbe62c2474a6b75de0ba49c037e07946deae971207f4f74b8b1d6a7b225eb0b66ed1f3878bc14d9d7a38b2162247b7ed9ac3df6fd2a98a3e4bf2855c8fb13f39487481fe03f5b5cb5123d11aaef180ff8ae69709322459a01a72e9304295ae5721d6eac6dae140677d0dd60f192f0475bacfd131d4ff3393238caa00fe0847c3a43c97a31f84f58b3c7487c5c0a09e85b39ed4b69fcdfa071da15216fd5f1fad125328e40689acce1a6cb113c2a16f599606162636465666768696a6b6c6d6e6f", + "000102030405060708090a0b0c0d0e0f101112131415161718191a1b1c1d1e1f202122232425262728292a2b2c2d2e2f", + "303132333435363738393a3b3c3d3e3f404142434445464748494a4b4c4d4e4f505152535455565758595a5b5c5d5e5f" + ], + "measurement": 100, + "nonce": "000102030405060708090a0b0c0d0e0f", + "out_shares": [ + [ + "b3916e4086c52aa0439356b05082885a" + ], + [ + "61992b02958137ac5d360c10f86f23ef" + ], + [ + "52d565bde4b89db326369d3fb70d54b6" + ] + ], + "prep_messages": [ + "e385da3bc2246be76ff12a7093ecb45e" + ], + "prep_shares": [ + [ + "7b6a9ad01449ec86dc6736dced3ecd24b47ab2a3768908b10696d537f2b02c98cf3314686f94ac37c7d81b14fea51f784e037bbdd56b2ee8486757acad61db1e", + "40a478cd7376c1e9ea339ddcf96ab1a7eb5e4b1153111bd6cd06aa3a5493a6da4470f696b9afff52ec10fc00040e4538470fdb8d3e05e188aba2b16e24c71b69", + "46f1ec617740528f1c642c47185681331adc83aace0cc5ddd256cf295c93c64d207d424f5056a8d59748ba9e423c4cf5f560fb4c6505c9a773629e12f21ee230" + ] + ], + "public_share": "4e037bbdd56b2ee8486757acad61db1e470fdb8d3e05e188aba2b16e24c71b69f560fb4c6505c9a773629e12f21ee230", + "rand": "000102030405060708090a0b0c0d0e0f101112131415161718191a1b1c1d1e1f202122232425262728292a2b2c2d2e2f303132333435363738393a3b3c3d3e3f404142434445464748494a4b4c4d4e4f505152535455565758595a5b5c5d5e5f606162636465666768696a6b6c6d6e6f707172737475767778797a7b7c7d7e7f" + } + ], + "shares": 3, + "verify_key": "000102030405060708090a0b0c0d0e0f" +} diff --git a/third_party/rust/prio/src/vdaf/test_vec/07/XofFixedKeyAes128.json b/third_party/rust/prio/src/vdaf/test_vec/07/XofFixedKeyAes128.json new file mode 100644 index 0000000000..ea76c50ff8 --- /dev/null +++ b/third_party/rust/prio/src/vdaf/test_vec/07/XofFixedKeyAes128.json @@ -0,0 +1,8 @@ +{ + "binder": "62696e64657220737472696e67", + "derived_seed": "9cb53deb2feda2f9a7f34fde29a833f4", + "dst": "646f6d61696e2073657061726174696f6e20746167", + "expanded_vec_field128": "9cb53deb2feda2f9a7f34fde29a833f44ade288b2f55f2cd257e5f40595b5069543b40b740dfcf8ab5c863924f4510716b625f633a2f7e55a50b24a5fec9155dec199170f9ebe46768e9d120f7e8f62840441ef53dd5d2ba2d3fd39032e2da99498f4abf815b09c667cef08f0882fa945ba3335d2c7407de1b1650a5f4fe52884caf3ef1f5802eabb8f238c4d9d419bed904dcf79b22da49f68fb547c287a9cd4a38d58017eb2031a6bf1b4defb8905c3b777e9287f62a7fb0f97e4d8a26c4e5b909958bc73a6f7512b5b845488d98a7fcedf711ada6972f4c06818d3c3e7a070a88af60dc0323b59f304935fbbbd3792e590c9b6bce7459deba3599c7f30fe64a638219dde4bde4b1a51df8d85c2f36604c44f5f188148e3ba1dca3fd8073240ee577ef322df19a13d9ffa486a6833f4eb2838a58746707b8bf531cc86098f43809276b5f02914b26cb75938ca16eafa73397920a2f5e607af30e62ff60b83e15699d4d0265affe185b307ed330941a41b2b628e44d9a19412f7d9513cacd7b1fd740b7708e3bc764a0cf2146bca7c94d1901c43f509d7dcc9dfec54476789284e53f3760610a0ac5fce205e9b9aa0355c29702a5c9395bf1de8c974c800e1037a6bf5e0bd2af7d96b7f000ff6ab93299966b6832c493b600f2595a3db99353d2f8889019cd3ec5a73fa457f5442ed5edf349e78c9cf0cbf4f65aea03754c381c3efc206b7f17447cc51ac68eceacab9d92b13b0bc700c99a26ce2b6c3271f7639aa72dc27bbd54984907abb10ef1047ef352d378ddae48bf381804c89aa1847f5027537cf6af1b30aa44cead6495e98ca5b3205d39beb49d2db6752a4e57158e8c83464002b0b4a9838bc381c1dbdc3e9a584554fb76671a15f907c0b395a5", + "length": 40, + "seed": "000102030405060708090a0b0c0d0e0f" +} diff --git a/third_party/rust/prio/src/vdaf/test_vec/07/XofShake128.json b/third_party/rust/prio/src/vdaf/test_vec/07/XofShake128.json new file mode 100644 index 0000000000..edafb1bd4d --- /dev/null +++ b/third_party/rust/prio/src/vdaf/test_vec/07/XofShake128.json @@ -0,0 +1,8 @@ +{ + "binder": "62696e64657220737472696e67", + "derived_seed": "87c4d0dd654bf8eec8805c68b5eb0182", + "dst": "646f6d61696e2073657061726174696f6e20746167", + "expanded_vec_field128": "87c4d0dd654bf8eec8805c68b5eb0182b1b8ede598cfb8d8b234038fd0492eb14268bbb2ac15a55d463c8227c2d4fae8607631d13157c4935c5d2d56e4b1e2bdfe0f80b286d82e631704acee29ab6f7acaa316d3623cc3371297604caf57bc2eafe72056143971345901b9fb9f95b6a7384c6a88143124ff693ce9e453675f87a6b6338a1e1c9f72d19e32b51f60a1d7469de1fbe25407cc338d896b34c5fc437d2551297027eeefca9aaccdb78d655a6c220cbc2d76cc4a64b04806ae893c952996abb91f6ec32b6de27fe51be59514352d31af4967c0a85c5823ff73be7f15b9c0769321e4b69cb931a4e88f9da1fde1c5df9d84a7eadb41cf25681fc64a84a1c4accded794c1e6fec1fb26a286712425bfc29521273dcfc76cbab9b3c3c2b840ab6a4f9fd73ea434fc1c22a91943ed38fef0136f0f18f680c191978ab77c750d577c3526a327564da05cfc7bb9ef52c140d9e63b1f39761648772eaa61e2efb15890aed8340e6854b428f16dff5654c8a0852d46e817b49bbe91db3c46620adbd009a0d7d40843c1b6b7786833d3c1ae097b4fa35815dbcfca78e00a34f15936ed6d0f5bf50fc25adbecd3adfa55ba6bc7052f0662595cf7a933dfcc3d0ad5d825ec3bc191586a1c36a037d1c9e73c24777825d6afe59774abdb2918c2147a0436b17bafd967e07c46c3d6240c771f4fd4f9b3fff38b294508b8af5a1b71385f90f407620b7aa636fd2b55435b3688fc26ad3c23b2ad48158c4c475c07eb58569a8d1a906452b82d582397c4c69f5e79d3082d03b4dd85b5277a8b44c933d52d168caae8c602376f5487670a172d138364cb975c569c9c2d79506746090ea8102907c91b66764fd8740ca7bd3acb59173df29b5fa7542e51bce67b97c9ee2", + "length": 40, + "seed": "000102030405060708090a0b0c0d0e0f" +} diff --git a/third_party/rust/prio/src/vdaf/test_vec/prio2/fieldpriov2.json b/third_party/rust/prio/src/vdaf/test_vec/prio2/fieldpriov2.json new file mode 100644 index 0000000000..674e682ac1 --- /dev/null +++ b/third_party/rust/prio/src/vdaf/test_vec/prio2/fieldpriov2.json @@ -0,0 +1,28 @@ +{ + "dimension": 10, + "server_1_decrypted_shares": [ + "x+y6F2RY3Y+toaLjU4a0WDqmNPYgHW1w9Z3svKirr+qcM9eDxhWSUPY4/N3A3PYGVFKa+i867MSiouE7Fq3iykBPKPMuNS4T8e1FA2uJ5PJHzPEobZNQWKG6ax5WYEbpmongi1kgK656OkATqMcXHnmkBC8=", + "u7Z0TCmpNSSrhCf27lo8qfmg7bUAAvcAbnntWm2lcoIa7yE8h1Mi4H7YDLv+t/9pzCHhyXC248VnyLWasFEfZ7wwzI2D+3U0XBjWZIXXufRvsfo8ZlUfbRtwYXYYwjxB5FV0jksJvZxaOFuNGP0IxIR1J1M=", + "iUazYtUedl8JNc8oSU+cLagEfJIBvQC7yUUjF5BhKyvDk54SWsCNbCTM8GigTWj5fobgebFqDlMbWaPceeO1F5S+8AwSpH8zSo20sckmYpeia8daPRDQu27e+ijLtxGkig82No/SniO+PPG96/xJ3e7wO74=", + "WUkmvNOpVZVb08GsZZuN7iqg6eUusF3wZY6U0NA8cAvzDiysSByNROP+cKCuy5l4r4JpkvP2n03WouKfyW4CWqQKC9vB+hCxVyWmNilyOTbACM0mxYU9lqvi43C3XEj4gflaOc7eUi8NTdAzvOThI/pmn3w=", + "Tp5mTet94KPUfzgs/XC8eal0iayw513glL8qUBAjAubTY97A/oO3Mevjm2Gn0uLFq4Add0A+O0Vo3l/Ar+etRUV9h4cEkKk5W7AEpY7JND224plOZ+3gsNiSUFm+hm6uOWXkanNNbPaSC3nyA81PF9fffAY=", + "2awOe4Bc1b7xsPvQxA9oTfadU3ueOw1EVbQ4poETuxmjPsvHbZhJXnYURfAvqfKCfz6eOUPhml7qcwWpht913JN9IQ5kpCM0So31xB99FkAuglwzZY7s4SLmWSrxjnvp9sjg8IDgiJAyCNQlIZtgjqq1TIE=", + "hLF/NXH1yZu4sf6m0jNfsZZS/cxVnoS5c1BtRSxkpeBuiXv+qnO3kko7A43CZ0b8C4kp1/Rfc5bnxn4+liYtsmu0ZGoW2AJdIfQTmkFfQlV/bi6Lcto121CDFY89/jrernWu3urzN1jeeNc1RmkpbmdtOrw=", + "xiufG50lmIHTVsvr6zpMXqTxhfVUlfe15eXfnwMkOEFfjMz/njV0HmqkvKtw9P/Z6gbN1Z0B0XnARknYz5OVNQRIaiu/AHTKWd19pnYu25VtFlHibbKhz32zHbfoHAoZeUXYr7x3j1vbJVNnkbwnz5WkNnc=", + "oEv/wE1+O6dVD/LKV0coeGq46zb8oDMFmL6GEEuZnC6REQdiJhHBa9fyWm1O+NtjUGu08r20R4r7f20dOqmYCOIMaNvejMTcXui9WXjrl8YQzQ86hkHSQyaXR0nkWOlCZsEF1kDX35TIGOWlCpjCjCwifPI=", + "OecAaoAKkmBgkmcVoCiq7NQkfXY/FS9M3zKZtE1pfbIwavWcqr9ucgFSs6zsas5aKVF/yJrNQwlbVE41YlClDuND6jD4NGuYKubzfT6I+saUAacdiDGp1fRA4RWcJFSxULG23XBP2b7D3l5wzeauGWDylkw=" + ], + "server_2_decrypted_shares": [ + "Kge/qYK/UnjrS8Q55v2K5Neg6iHORxEONcsaoLtX7wU=", + "rMt+kStRyxF2V7AWLLnWofVQPf6n7wrkhsTaMcqOZdo=", + "iMFcL/WOs264bO1dDOlyGjicF8/u6M5FdrPOFkkHuAE=", + "B1Dp6k0WFB252R4cAl2vtG7XGmvXmbDYM+iBn+U2QFU=", + "MO0aL0ZRKCfKkgaTmlsfn8Y/K6gRZoM9lOmOwBDxgPA=", + "Flz9wOjYAGLPjo+c89ve8fjaAhm+w1LqngA33ro+Q/M=", + "dfD3DWQfzkjubvNucxMRSbaaAz1O+4VLg2najHfPaxA=", + "ZgUbOVCMvhejv4LzOu3oSi+PgQBHUghfwxp4vJ5Ig9A=", + "Cqig9M1LE6iEl1PLB9Qc8cRfiF81TM4EqMCXVPT0SwE=", + "ZUpeNfk32vpkPoSh5ZNauB1pw5QeDJ5CMmvsDV0F6eA=" + ], + "reference_sum": "BgAAAAcAAAAEAAAABQAAAAkAAAAGAAAAAgAAAAYAAAAEAAAABQAAAA==" +}
\ No newline at end of file diff --git a/third_party/rust/prio/src/vdaf/xof.rs b/third_party/rust/prio/src/vdaf/xof.rs new file mode 100644 index 0000000000..b38d176467 --- /dev/null +++ b/third_party/rust/prio/src/vdaf/xof.rs @@ -0,0 +1,574 @@ +// SPDX-License-Identifier: MPL-2.0 + +//! Implementations of XOFs specified in [[draft-irtf-cfrg-vdaf-07]]. +//! +//! [draft-irtf-cfrg-vdaf-07]: https://datatracker.ietf.org/doc/draft-irtf-cfrg-vdaf/07/ + +use crate::{ + field::FieldElement, + prng::Prng, + vdaf::{CodecError, Decode, Encode}, +}; +#[cfg(all(feature = "crypto-dependencies", feature = "experimental"))] +use aes::{ + cipher::{generic_array::GenericArray, BlockEncrypt, KeyInit}, + Block, +}; +#[cfg(feature = "crypto-dependencies")] +use aes::{ + cipher::{KeyIvInit, StreamCipher}, + Aes128, +}; +#[cfg(feature = "crypto-dependencies")] +use ctr::Ctr64BE; +use rand_core::{ + impls::{next_u32_via_fill, next_u64_via_fill}, + RngCore, SeedableRng, +}; +use sha3::{ + digest::{ExtendableOutput, Update, XofReader}, + Shake128, Shake128Core, Shake128Reader, +}; +#[cfg(feature = "crypto-dependencies")] +use std::fmt::Formatter; +use std::{ + fmt::Debug, + io::{Cursor, Read}, +}; +use subtle::{Choice, ConstantTimeEq}; + +/// Input of [`Xof`]. +#[derive(Clone, Debug)] +pub struct Seed<const SEED_SIZE: usize>(pub(crate) [u8; SEED_SIZE]); + +impl<const SEED_SIZE: usize> Seed<SEED_SIZE> { + /// Generate a uniform random seed. + pub fn generate() -> Result<Self, getrandom::Error> { + let mut seed = [0; SEED_SIZE]; + getrandom::getrandom(&mut seed)?; + Ok(Self::from_bytes(seed)) + } + + /// Construct seed from a byte slice. + pub(crate) fn from_bytes(seed: [u8; SEED_SIZE]) -> Self { + Self(seed) + } +} + +impl<const SEED_SIZE: usize> AsRef<[u8; SEED_SIZE]> for Seed<SEED_SIZE> { + fn as_ref(&self) -> &[u8; SEED_SIZE] { + &self.0 + } +} + +impl<const SEED_SIZE: usize> PartialEq for Seed<SEED_SIZE> { + fn eq(&self, other: &Self) -> bool { + self.ct_eq(other).into() + } +} + +impl<const SEED_SIZE: usize> Eq for Seed<SEED_SIZE> {} + +impl<const SEED_SIZE: usize> ConstantTimeEq for Seed<SEED_SIZE> { + fn ct_eq(&self, other: &Self) -> Choice { + self.0.ct_eq(&other.0) + } +} + +impl<const SEED_SIZE: usize> Encode for Seed<SEED_SIZE> { + fn encode(&self, bytes: &mut Vec<u8>) { + bytes.extend_from_slice(&self.0[..]); + } + + fn encoded_len(&self) -> Option<usize> { + Some(SEED_SIZE) + } +} + +impl<const SEED_SIZE: usize> Decode for Seed<SEED_SIZE> { + fn decode(bytes: &mut Cursor<&[u8]>) -> Result<Self, CodecError> { + let mut seed = [0; SEED_SIZE]; + bytes.read_exact(&mut seed)?; + Ok(Seed(seed)) + } +} + +/// Trait for deriving a vector of field elements. +pub trait IntoFieldVec: RngCore + Sized { + /// Generate a finite field vector from the seed stream. + fn into_field_vec<F: FieldElement>(self, length: usize) -> Vec<F>; +} + +impl<S: RngCore> IntoFieldVec for S { + fn into_field_vec<F: FieldElement>(self, length: usize) -> Vec<F> { + Prng::from_seed_stream(self).take(length).collect() + } +} + +/// An extendable output function (XOF) with the interface specified in [[draft-irtf-cfrg-vdaf-07]]. +/// +/// [draft-irtf-cfrg-vdaf-07]: https://datatracker.ietf.org/doc/draft-irtf-cfrg-vdaf/07/ +pub trait Xof<const SEED_SIZE: usize>: Clone + Debug { + /// The type of stream produced by this XOF. + type SeedStream: RngCore + Sized; + + /// Construct an instance of [`Xof`] with the given seed. + fn init(seed_bytes: &[u8; SEED_SIZE], dst: &[u8]) -> Self; + + /// Update the XOF state by passing in the next fragment of the info string. The final info + /// string is assembled from the concatenation of sequence of fragments passed to this method. + fn update(&mut self, data: &[u8]); + + /// Finalize the XOF state, producing a seed stream. + fn into_seed_stream(self) -> Self::SeedStream; + + /// Finalize the XOF state, producing a seed. + fn into_seed(self) -> Seed<SEED_SIZE> { + let mut new_seed = [0; SEED_SIZE]; + let mut seed_stream = self.into_seed_stream(); + seed_stream.fill_bytes(&mut new_seed); + Seed(new_seed) + } + + /// Construct a seed stream from the given seed and info string. + fn seed_stream(seed: &Seed<SEED_SIZE>, dst: &[u8], binder: &[u8]) -> Self::SeedStream { + let mut xof = Self::init(seed.as_ref(), dst); + xof.update(binder); + xof.into_seed_stream() + } +} + +/// The key stream produced by AES128 in CTR-mode. +#[cfg(feature = "crypto-dependencies")] +#[cfg_attr(docsrs, doc(cfg(feature = "crypto-dependencies")))] +pub struct SeedStreamAes128(Ctr64BE<Aes128>); + +#[cfg(feature = "crypto-dependencies")] +impl SeedStreamAes128 { + pub(crate) fn new(key: &[u8], iv: &[u8]) -> Self { + SeedStreamAes128(<Ctr64BE<Aes128> as KeyIvInit>::new(key.into(), iv.into())) + } + + fn fill(&mut self, buf: &mut [u8]) { + buf.fill(0); + self.0.apply_keystream(buf); + } +} + +#[cfg(feature = "crypto-dependencies")] +impl RngCore for SeedStreamAes128 { + fn fill_bytes(&mut self, dest: &mut [u8]) { + self.fill(dest); + } + + fn try_fill_bytes(&mut self, dest: &mut [u8]) -> Result<(), rand_core::Error> { + self.fill(dest); + Ok(()) + } + + fn next_u32(&mut self) -> u32 { + next_u32_via_fill(self) + } + + fn next_u64(&mut self) -> u64 { + next_u64_via_fill(self) + } +} + +#[cfg(feature = "crypto-dependencies")] +impl Debug for SeedStreamAes128 { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + // Ctr64BE<Aes128> does not implement Debug, but [`ctr::CtrCore`][1] does, and we get that + // with [`cipher::StreamCipherCoreWrapper::get_core`][2]. + // + // [1]: https://docs.rs/ctr/latest/ctr/struct.CtrCore.html + // [2]: https://docs.rs/cipher/latest/cipher/struct.StreamCipherCoreWrapper.html + self.0.get_core().fmt(f) + } +} + +/// The XOF based on SHA-3 as specified in [[draft-irtf-cfrg-vdaf-07]]. +/// +/// [draft-irtf-cfrg-vdaf-07]: https://datatracker.ietf.org/doc/draft-irtf-cfrg-vdaf/07/ +#[derive(Clone, Debug)] +pub struct XofShake128(Shake128); + +impl Xof<16> for XofShake128 { + type SeedStream = SeedStreamSha3; + + fn init(seed_bytes: &[u8; 16], dst: &[u8]) -> Self { + let mut xof = Self(Shake128::from_core(Shake128Core::default())); + Update::update( + &mut xof.0, + &[dst.len().try_into().expect("dst must be at most 255 bytes")], + ); + Update::update(&mut xof.0, dst); + Update::update(&mut xof.0, seed_bytes); + xof + } + + fn update(&mut self, data: &[u8]) { + Update::update(&mut self.0, data); + } + + fn into_seed_stream(self) -> SeedStreamSha3 { + SeedStreamSha3::new(self.0.finalize_xof()) + } +} + +/// The seed stream produced by SHAKE128. +pub struct SeedStreamSha3(Shake128Reader); + +impl SeedStreamSha3 { + pub(crate) fn new(reader: Shake128Reader) -> Self { + Self(reader) + } +} + +impl RngCore for SeedStreamSha3 { + fn fill_bytes(&mut self, dest: &mut [u8]) { + XofReader::read(&mut self.0, dest); + } + + fn try_fill_bytes(&mut self, dest: &mut [u8]) -> Result<(), rand_core::Error> { + XofReader::read(&mut self.0, dest); + Ok(()) + } + + fn next_u32(&mut self) -> u32 { + next_u32_via_fill(self) + } + + fn next_u64(&mut self) -> u64 { + next_u64_via_fill(self) + } +} + +/// A `rand`-compatible interface to construct XofShake128 seed streams, with the domain separation tag +/// and binder string both fixed as the empty string. +impl SeedableRng for SeedStreamSha3 { + type Seed = [u8; 16]; + + fn from_seed(seed: Self::Seed) -> Self { + XofShake128::init(&seed, b"").into_seed_stream() + } +} + +/// Factory to produce multiple [`XofFixedKeyAes128`] instances with the same fixed key and +/// different seeds. +#[cfg(all(feature = "crypto-dependencies", feature = "experimental"))] +#[cfg_attr( + docsrs, + doc(cfg(all(feature = "crypto-dependencies", feature = "experimental"))) +)] +pub struct XofFixedKeyAes128Key { + cipher: Aes128, +} + +#[cfg(all(feature = "crypto-dependencies", feature = "experimental"))] +impl XofFixedKeyAes128Key { + /// Derive the fixed key from the domain separation tag and binder string. + pub fn new(dst: &[u8], binder: &[u8]) -> Self { + let mut fixed_key_deriver = Shake128::from_core(Shake128Core::default()); + Update::update( + &mut fixed_key_deriver, + &[dst.len().try_into().expect("dst must be at most 255 bytes")], + ); + Update::update(&mut fixed_key_deriver, dst); + Update::update(&mut fixed_key_deriver, binder); + let mut key = GenericArray::from([0; 16]); + XofReader::read(&mut fixed_key_deriver.finalize_xof(), key.as_mut()); + Self { + cipher: Aes128::new(&key), + } + } + + /// Combine a fixed key with a seed to produce a new stream of bytes. + pub fn with_seed(&self, seed: &[u8; 16]) -> SeedStreamFixedKeyAes128 { + SeedStreamFixedKeyAes128 { + cipher: self.cipher.clone(), + base_block: (*seed).into(), + length_consumed: 0, + } + } +} + +/// XofFixedKeyAes128 as specified in [[draft-irtf-cfrg-vdaf-07]]. This XOF is NOT RECOMMENDED for +/// general use; see Section 9 ("Security Considerations") for details. +/// +/// This XOF combines SHA-3 and a fixed-key mode of operation for AES-128. The key is "fixed" in +/// the sense that it is derived (using SHAKE128) from the domain separation tag and binder +/// strings, and depending on the application, these strings can be hard-coded. The seed is used to +/// construct each block of input passed to a hash function built from AES-128. +/// +/// [draft-irtf-cfrg-vdaf-07]: https://datatracker.ietf.org/doc/draft-irtf-cfrg-vdaf/07/ +#[derive(Clone, Debug)] +#[cfg(all(feature = "crypto-dependencies", feature = "experimental"))] +#[cfg_attr( + docsrs, + doc(cfg(all(feature = "crypto-dependencies", feature = "experimental"))) +)] +pub struct XofFixedKeyAes128 { + fixed_key_deriver: Shake128, + base_block: Block, +} + +#[cfg(all(feature = "crypto-dependencies", feature = "experimental"))] +impl Xof<16> for XofFixedKeyAes128 { + type SeedStream = SeedStreamFixedKeyAes128; + + fn init(seed_bytes: &[u8; 16], dst: &[u8]) -> Self { + let mut fixed_key_deriver = Shake128::from_core(Shake128Core::default()); + Update::update( + &mut fixed_key_deriver, + &[dst.len().try_into().expect("dst must be at most 255 bytes")], + ); + Update::update(&mut fixed_key_deriver, dst); + Self { + fixed_key_deriver, + base_block: (*seed_bytes).into(), + } + } + + fn update(&mut self, data: &[u8]) { + Update::update(&mut self.fixed_key_deriver, data); + } + + fn into_seed_stream(self) -> SeedStreamFixedKeyAes128 { + let mut fixed_key = GenericArray::from([0; 16]); + XofReader::read( + &mut self.fixed_key_deriver.finalize_xof(), + fixed_key.as_mut(), + ); + SeedStreamFixedKeyAes128 { + base_block: self.base_block, + cipher: Aes128::new(&fixed_key), + length_consumed: 0, + } + } +} + +/// Seed stream for [`XofFixedKeyAes128`]. +#[cfg(all(feature = "crypto-dependencies", feature = "experimental"))] +#[cfg_attr( + docsrs, + doc(cfg(all(feature = "crypto-dependencies", feature = "experimental"))) +)] +pub struct SeedStreamFixedKeyAes128 { + cipher: Aes128, + base_block: Block, + length_consumed: u64, +} + +#[cfg(all(feature = "crypto-dependencies", feature = "experimental"))] +impl SeedStreamFixedKeyAes128 { + fn hash_block(&self, block: &mut Block) { + let sigma = Block::from([ + // hi + block[8], + block[9], + block[10], + block[11], + block[12], + block[13], + block[14], + block[15], + // xor(hi, lo) + block[8] ^ block[0], + block[9] ^ block[1], + block[10] ^ block[2], + block[11] ^ block[3], + block[12] ^ block[4], + block[13] ^ block[5], + block[14] ^ block[6], + block[15] ^ block[7], + ]); + self.cipher.encrypt_block_b2b(&sigma, block); + for (b, s) in block.iter_mut().zip(sigma.iter()) { + *b ^= s; + } + } + + fn fill(&mut self, buf: &mut [u8]) { + let next_length_consumed = self.length_consumed + u64::try_from(buf.len()).unwrap(); + let mut offset = usize::try_from(self.length_consumed % 16).unwrap(); + let mut index = 0; + let mut block = Block::from([0; 16]); + + // NOTE(cjpatton) We might be able to speed this up by unrolling this loop and encrypting + // multiple blocks at the same time via `self.cipher.encrypt_blocks()`. + for block_counter in self.length_consumed / 16..(next_length_consumed + 15) / 16 { + block.clone_from(&self.base_block); + for (b, i) in block.iter_mut().zip(block_counter.to_le_bytes().iter()) { + *b ^= i; + } + self.hash_block(&mut block); + let read = std::cmp::min(16 - offset, buf.len() - index); + buf[index..index + read].copy_from_slice(&block[offset..offset + read]); + offset = 0; + index += read; + } + + self.length_consumed = next_length_consumed; + } +} + +#[cfg(all(feature = "crypto-dependencies", feature = "experimental"))] +impl RngCore for SeedStreamFixedKeyAes128 { + fn fill_bytes(&mut self, dest: &mut [u8]) { + self.fill(dest); + } + + fn try_fill_bytes(&mut self, dest: &mut [u8]) -> Result<(), rand_core::Error> { + self.fill(dest); + Ok(()) + } + + fn next_u32(&mut self) -> u32 { + next_u32_via_fill(self) + } + + fn next_u64(&mut self) -> u64 { + next_u64_via_fill(self) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::{field::Field128, vdaf::equality_comparison_test}; + use serde::{Deserialize, Serialize}; + use std::{convert::TryInto, io::Cursor}; + + #[derive(Deserialize, Serialize)] + struct XofTestVector { + #[serde(with = "hex")] + seed: Vec<u8>, + #[serde(with = "hex")] + dst: Vec<u8>, + #[serde(with = "hex")] + binder: Vec<u8>, + length: usize, + #[serde(with = "hex")] + derived_seed: Vec<u8>, + #[serde(with = "hex")] + expanded_vec_field128: Vec<u8>, + } + + // Test correctness of dervied methods. + fn test_xof<P, const SEED_SIZE: usize>() + where + P: Xof<SEED_SIZE>, + { + let seed = Seed::generate().unwrap(); + let dst = b"algorithm and usage"; + let binder = b"bind to artifact"; + + let mut xof = P::init(seed.as_ref(), dst); + xof.update(binder); + + let mut want = Seed([0; SEED_SIZE]); + xof.clone().into_seed_stream().fill_bytes(&mut want.0[..]); + let got = xof.clone().into_seed(); + assert_eq!(got, want); + + let mut want = [0; 45]; + xof.clone().into_seed_stream().fill_bytes(&mut want); + let mut got = [0; 45]; + P::seed_stream(&seed, dst, binder).fill_bytes(&mut got); + assert_eq!(got, want); + } + + #[test] + fn xof_shake128() { + let t: XofTestVector = + serde_json::from_str(include_str!("test_vec/07/XofShake128.json")).unwrap(); + let mut xof = XofShake128::init(&t.seed.try_into().unwrap(), &t.dst); + xof.update(&t.binder); + + assert_eq!( + xof.clone().into_seed(), + Seed(t.derived_seed.try_into().unwrap()) + ); + + let mut bytes = Cursor::new(t.expanded_vec_field128.as_slice()); + let mut want = Vec::with_capacity(t.length); + while (bytes.position() as usize) < t.expanded_vec_field128.len() { + want.push(Field128::decode(&mut bytes).unwrap()) + } + let got: Vec<Field128> = xof.clone().into_seed_stream().into_field_vec(t.length); + assert_eq!(got, want); + + test_xof::<XofShake128, 16>(); + } + + #[cfg(feature = "experimental")] + #[test] + fn xof_fixed_key_aes128() { + let t: XofTestVector = + serde_json::from_str(include_str!("test_vec/07/XofFixedKeyAes128.json")).unwrap(); + let mut xof = XofFixedKeyAes128::init(&t.seed.try_into().unwrap(), &t.dst); + xof.update(&t.binder); + + assert_eq!( + xof.clone().into_seed(), + Seed(t.derived_seed.try_into().unwrap()) + ); + + let mut bytes = Cursor::new(t.expanded_vec_field128.as_slice()); + let mut want = Vec::with_capacity(t.length); + while (bytes.position() as usize) < t.expanded_vec_field128.len() { + want.push(Field128::decode(&mut bytes).unwrap()) + } + let got: Vec<Field128> = xof.clone().into_seed_stream().into_field_vec(t.length); + assert_eq!(got, want); + + test_xof::<XofFixedKeyAes128, 16>(); + } + + #[cfg(feature = "experimental")] + #[test] + fn xof_fixed_key_aes128_incomplete_block() { + let seed = Seed::generate().unwrap(); + let mut expected = [0; 32]; + XofFixedKeyAes128::seed_stream(&seed, b"dst", b"binder").fill(&mut expected); + + for len in 0..=32 { + let mut buf = vec![0; len]; + XofFixedKeyAes128::seed_stream(&seed, b"dst", b"binder").fill(&mut buf); + assert_eq!(buf, &expected[..len]); + } + } + + #[cfg(feature = "experimental")] + #[test] + fn xof_fixed_key_aes128_alternate_apis() { + let dst = b"domain separation tag"; + let binder = b"AAAAAAAAAAAAAAAAAAAAAAAA"; + let seed_1 = Seed::generate().unwrap(); + let seed_2 = Seed::generate().unwrap(); + + let mut stream_1_trait_api = XofFixedKeyAes128::seed_stream(&seed_1, dst, binder); + let mut output_1_trait_api = [0u8; 32]; + stream_1_trait_api.fill(&mut output_1_trait_api); + let mut stream_2_trait_api = XofFixedKeyAes128::seed_stream(&seed_2, dst, binder); + let mut output_2_trait_api = [0u8; 32]; + stream_2_trait_api.fill(&mut output_2_trait_api); + + let fixed_key = XofFixedKeyAes128Key::new(dst, binder); + let mut stream_1_alternate_api = fixed_key.with_seed(seed_1.as_ref()); + let mut output_1_alternate_api = [0u8; 32]; + stream_1_alternate_api.fill(&mut output_1_alternate_api); + let mut stream_2_alternate_api = fixed_key.with_seed(seed_2.as_ref()); + let mut output_2_alternate_api = [0u8; 32]; + stream_2_alternate_api.fill(&mut output_2_alternate_api); + + assert_eq!(output_1_trait_api, output_1_alternate_api); + assert_eq!(output_2_trait_api, output_2_alternate_api); + } + + #[test] + fn seed_equality_test() { + equality_comparison_test(&[Seed([1, 2, 3]), Seed([3, 2, 1])]) + } +} diff --git a/third_party/rust/prio/tests/discrete_gauss.rs b/third_party/rust/prio/tests/discrete_gauss.rs new file mode 100644 index 0000000000..5b3ef4c5b3 --- /dev/null +++ b/third_party/rust/prio/tests/discrete_gauss.rs @@ -0,0 +1,70 @@ +// SPDX-License-Identifier: MPL-2.0 + +use num_bigint::{BigInt, BigUint}; +use num_rational::Ratio; +use num_traits::FromPrimitive; +use prio::dp::distributions::DiscreteGaussian; +use prio::vdaf::xof::SeedStreamSha3; +use rand::distributions::Distribution; +use rand::SeedableRng; +use serde::Deserialize; + +/// A test vector of discrete Gaussian samples, produced by the python reference +/// implementation for [[CKS20]]. The script used to generate the test vector can +/// be found in this gist: +/// https://gist.github.com/ooovi/529c00fc8a7eafd068cd076b78fc424e +/// The python reference implementation is here: +/// https://github.com/IBM/discrete-gaussian-differential-privacy +/// +/// [CKS20]: https://arxiv.org/pdf/2004.00010.pdf +#[derive(Debug, Eq, PartialEq, Deserialize)] +pub struct DiscreteGaussTestVector { + #[serde(with = "hex")] + seed: [u8; 16], + std_num: u128, + std_denom: u128, + samples: Vec<i128>, +} + +#[test] +fn discrete_gauss_reference() { + let test_vectors: Vec<DiscreteGaussTestVector> = vec![ + serde_json::from_str(include_str!(concat!("test_vectors/discrete_gauss_3.json"))).unwrap(), + serde_json::from_str(include_str!(concat!("test_vectors/discrete_gauss_9.json"))).unwrap(), + serde_json::from_str(include_str!(concat!( + "test_vectors/discrete_gauss_100.json" + ))) + .unwrap(), + serde_json::from_str(include_str!(concat!( + "test_vectors/discrete_gauss_41293847.json" + ))) + .unwrap(), + serde_json::from_str(include_str!(concat!( + "test_vectors/discrete_gauss_9999999999999999999999.json" + ))) + .unwrap(), + ]; + + for test_vector in test_vectors { + let sampler = DiscreteGaussian::new(Ratio::<BigUint>::new( + test_vector.std_num.into(), + test_vector.std_denom.into(), + )) + .unwrap(); + + // check samples are consistent + let mut rng = SeedStreamSha3::from_seed(test_vector.seed); + let samples: Vec<BigInt> = (0..test_vector.samples.len()) + .map(|_| sampler.sample(&mut rng)) + .collect(); + + assert_eq!( + samples, + test_vector + .samples + .iter() + .map(|&s| BigInt::from_i128(s).unwrap()) + .collect::<Vec::<BigInt>>() + ); + } +} diff --git a/third_party/rust/prio/tests/test_vectors/discrete_gauss_100.json b/third_party/rust/prio/tests/test_vectors/discrete_gauss_100.json new file mode 100644 index 0000000000..fe8114c258 --- /dev/null +++ b/third_party/rust/prio/tests/test_vectors/discrete_gauss_100.json @@ -0,0 +1,56 @@ +{ + "samples": [ + -74, + 68, + -15, + 175, + -120, + -2, + -73, + 108, + 40, + 69, + 81, + -135, + 247, + 32, + 107, + -61, + 164, + 22, + 118, + 37, + -58, + 147, + 65, + 53, + 9, + -96, + -130, + 100, + 48, + -30, + -2, + -115, + 56, + 95, + 119, + 28, + -101, + 50, + 39, + 21, + -6, + -70, + 131, + 66, + 81, + -18, + 94, + 55, + 20 + ], + "seed": "000102030405060708090a0b0c0d0e0f", + "std_denom": 1, + "std_num": 100 +} diff --git a/third_party/rust/prio/tests/test_vectors/discrete_gauss_2.342.json b/third_party/rust/prio/tests/test_vectors/discrete_gauss_2.342.json new file mode 100644 index 0000000000..7d9508c44e --- /dev/null +++ b/third_party/rust/prio/tests/test_vectors/discrete_gauss_2.342.json @@ -0,0 +1,56 @@ +{ + "samples": [ + -1, + 4, + 2, + 0, + 0, + -2, + 1, + -5, + 2, + -1, + 0, + 0, + 0, + 3, + -6, + 5, + 2, + 1, + -1, + -3, + 0, + 2, + -3, + -2, + 2, + -3, + 1, + 1, + 2, + -3, + -1, + -1, + 4, + 2, + -2, + 1, + -1, + 0, + -3, + 1, + 2, + 1, + -1, + 3, + 1, + 2, + -3, + -2, + 0 + ], + "seed": "000102030405060708090a0b0c0d0e0f", + "std_denom": 500, + "std_num": 1171 +} diff --git a/third_party/rust/prio/tests/test_vectors/discrete_gauss_3.json b/third_party/rust/prio/tests/test_vectors/discrete_gauss_3.json new file mode 100644 index 0000000000..d4a3486db0 --- /dev/null +++ b/third_party/rust/prio/tests/test_vectors/discrete_gauss_3.json @@ -0,0 +1,56 @@ +{ + "samples": [ + 1, + -1, + -2, + 1, + -1, + 0, + 1, + -1, + 0, + 2, + 3, + 4, + 1, + 1, + -2, + 4, + 6, + -5, + -3, + -1, + 4, + 0, + 6, + 2, + 2, + -4, + -2, + -5, + -3, + 2, + 1, + -3, + -2, + 1, + -2, + 1, + 0, + 3, + -4, + -4, + 1, + 3, + 2, + 1, + 0, + -1, + 1, + 4, + 1 + ], + "seed": "000102030405060708090a0b0c0d0e0f", + "std_denom": 1, + "std_num": 3 +} diff --git a/third_party/rust/prio/tests/test_vectors/discrete_gauss_41293847.json b/third_party/rust/prio/tests/test_vectors/discrete_gauss_41293847.json new file mode 100644 index 0000000000..213e919c6c --- /dev/null +++ b/third_party/rust/prio/tests/test_vectors/discrete_gauss_41293847.json @@ -0,0 +1,56 @@ +{ + "samples": [ + -10157810, + -7944688, + 80361481, + -2601121, + -2098394, + -61204295, + 250399, + 62702361, + -35117486, + 14804891, + -5043613, + 34131059, + -34448923, + -24176095, + 106518772, + 44056972, + 15910928, + 63338376, + 12839729, + 11701052, + -54254959, + -11306071, + -6005727, + 29738939, + -30284246, + -47672638, + 11549070, + -17580447, + -2973754, + -298465, + -15349002, + 56970396, + 35612502, + -78720214, + -6082493, + -2759887, + -11374460, + 177253, + -35234082, + 42256563, + 44219644, + 86984413, + -43711428, + -16031438, + 42572889, + 39464625, + -14433332, + -7735634, + 4403776 + ], + "seed": "000102030405060708090a0b0c0d0e0f", + "std_denom": 1, + "std_num": 41293847 +} diff --git a/third_party/rust/prio/tests/test_vectors/discrete_gauss_9.json b/third_party/rust/prio/tests/test_vectors/discrete_gauss_9.json new file mode 100644 index 0000000000..408c7489a7 --- /dev/null +++ b/third_party/rust/prio/tests/test_vectors/discrete_gauss_9.json @@ -0,0 +1,56 @@ +{ + "samples": [ + 2, + 14, + -12, + 10, + -13, + 1, + 6, + -10, + -1, + 4, + -5, + 1, + -20, + 1, + -14, + 2, + -5, + 10, + 3, + 9, + 10, + 12, + -4, + 12, + -4, + -13, + 7, + -10, + -6, + 2, + 6, + 18, + -7, + -11, + 20, + 1, + -4, + 14, + 11, + 5, + 6, + -1, + 2, + 5, + 3, + -8, + -1, + 6, + -8 + ], + "seed": "000102030405060708090a0b0c0d0e0f", + "std_denom": 1, + "std_num": 9 +} diff --git a/third_party/rust/prio/tests/test_vectors/discrete_gauss_9999999999999999999999.json b/third_party/rust/prio/tests/test_vectors/discrete_gauss_9999999999999999999999.json new file mode 100644 index 0000000000..396f5a4cf0 --- /dev/null +++ b/third_party/rust/prio/tests/test_vectors/discrete_gauss_9999999999999999999999.json @@ -0,0 +1,56 @@ +{ + "samples": [ + 7826646418794481373730, + 4044429794334089153683, + -13887062284591122240746, + 4816851335312673293131, + 1899078604295677453383, + -7819872990828024151405, + 1164017821807881486579, + -25360379570365624087817, + 5906637163630390455939, + 3730592807287609262846, + 1737147847266613603450, + 18593764766679058926154, + 22724295990478919946193, + -4396717652132313983045, + -7933138830987043774425, + 12204708418993708917398, + 10716232788607487693156, + -7423575920998904747964, + -274262846168968742506, + -24595460253341309777726, + 1880635641243101726137, + 10823060437484007979521, + -2525077352000184270857, + -1421364839539048815904, + -2648842672480402351562, + -7783156811031000955203, + -1831198454606609539077, + 905920470298728568753, + -8805882598094077859729, + 2949974625887521817722, + 13071000629486423981714, + 1311702736683393895126, + -14044034250430823347919, + 1421736709768854180193, + 14824744520414922652958, + 10752031849750698732804, + -522118577625103067952, + 2006618532306506057615, + -7573105805904097888275, + -14482966128638641002042, + -11408400475022385123481, + -17555433966245180572099, + -6120185353438187140929, + -4778266138627264471521, + -19325342657405318133711, + 3725950229250476135126, + -7977400544074383347686, + 2738166449787592433931, + -1321521809406566447071 + ], + "seed": "000102030405060708090a0b0c0d0e0f", + "std_denom": 1, + "std_num": 9999999999999999999999 +} |