diff --git a/Cargo.toml b/Cargo.toml index d1e7c492..6557b9c9 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -17,6 +17,8 @@ members = [ "crates/clmul", "crates/mpz-ole-core", "crates/mpz-ole", + "crates/mpz-zk-core", + "crates/mpz-zk", ] resolver = "2" @@ -43,6 +45,8 @@ mpz-ole = { path = "crates/mpz-ole" } mpz-ole-core = { path = "crates/mpz-ole-core" } clmul = { path = "crates/clmul" } matrix-transpose = { path = "crates/matrix-transpose" } +mpz-zk-core = { path = "crates/mpz-zk-core" } +mpz-zk = { path = "crates/mpz-zk" } tlsn-utils = { git = "https://github.com/tlsnotary/tlsn-utils", rev = "6e0be94" } tlsn-utils-aio = { git = "https://github.com/tlsnotary/tlsn-utils", rev = "6e0be94" } diff --git a/crates/mpz-common/src/ideal.rs b/crates/mpz-common/src/ideal.rs index 804472ef..1b6b3181 100644 --- a/crates/mpz-common/src/ideal.rs +++ b/crates/mpz-common/src/ideal.rs @@ -18,7 +18,7 @@ struct Buffer { } /// The ideal functionality from the perspective of Alice. -#[derive(Debug)] +#[derive(Debug, Default)] pub struct Alice { f: Arc>, buffer: Arc>, @@ -35,7 +35,7 @@ impl Clone for Alice { impl Alice { /// Returns a lock to the ideal functionality. - pub fn get_mut(&mut self) -> MutexGuard<'_, F> { + pub fn lock(&self) -> MutexGuard<'_, F> { self.f.lock().unwrap() } @@ -79,7 +79,7 @@ impl Alice { } /// The ideal functionality from the perspective of Bob. -#[derive(Debug)] +#[derive(Debug, Default)] pub struct Bob { f: Arc>, buffer: Arc>, @@ -96,7 +96,7 @@ impl Clone for Bob { impl Bob { /// Returns a lock to the ideal functionality. - pub fn get_mut(&mut self) -> MutexGuard<'_, F> { + pub fn lock(&self) -> MutexGuard<'_, F> { self.f.lock().unwrap() } diff --git a/crates/mpz-core/src/block.rs b/crates/mpz-core/src/block.rs index 2f7a0105..b191eea9 100644 --- a/crates/mpz-core/src/block.rs +++ b/crates/mpz-core/src/block.rs @@ -7,6 +7,7 @@ use generic_array::{typenum::consts::U16, GenericArray}; use itybity::{BitIterable, BitLength, GetBit, Lsb0, Msb0}; use rand::{distributions::Standard, prelude::Distribution, CryptoRng, Rng}; use serde::{Deserialize, Serialize}; +use std::iter::successors; /// A block of 128 bits #[repr(transparent)] @@ -22,6 +23,11 @@ impl Block { pub const ONE: Self = Self([1u8, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]); /// A block with all bits set to 1 pub const ONES: Self = Self([0xff; 16]); + /// A block with all 1 bits excect the lsb. + pub const MINIS_ONE: Block = Self([ + 0xfe, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + 0xff, + ]); /// A length 2 array of zero and one blocks pub const SELECT_MASK: [Self; 2] = [Self::ZERO, Self::ONES]; @@ -123,6 +129,15 @@ impl Block { bytemuck::cast([x[1], x[0]]) } + /// Generate the powers of the seed. + /// Starting with seed. + #[inline(always)] + pub fn powers(seed: Self, size: usize) -> Vec { + successors(Some(seed), |pow| Some(pow.gfmul(seed))) + .take(size) + .collect() + } + /// Converts a block to a [`GenericArray`](cipher::generic_array::GenericArray) /// from the [`generic-array`](https://docs.rs/generic-array/latest/generic_array/) crate. #[allow(dead_code)] diff --git a/crates/mpz-core/src/ggm_tree.rs b/crates/mpz-core/src/ggm_tree.rs index 913fffb6..840efcc6 100644 --- a/crates/mpz-core/src/ggm_tree.rs +++ b/crates/mpz-core/src/ggm_tree.rs @@ -32,33 +32,35 @@ impl GgmTree { assert_eq!(k0.len(), self.depth); assert_eq!(k1.len(), self.depth); let mut buf = [Block::ZERO; 8]; - self.tkprp.expand_1to2(tree, seed); - k0[0] = tree[0]; - k1[0] = tree[1]; + if self.depth > 1 { + self.tkprp.expand_1to2(tree, seed); + k0[0] = tree[0]; + k1[0] = tree[1]; - self.tkprp.expand_2to4(&mut buf, tree); - k0[1] = buf[0] ^ buf[2]; - k1[1] = buf[1] ^ buf[3]; - tree[0..4].copy_from_slice(&buf[0..4]); - - for h in 2..self.depth { - k0[h] = Block::ZERO; - k1[h] = Block::ZERO; - - // How many nodes there are in this layer - let sz = 1 << h; - for i in (0..=sz - 4).rev().step_by(4) { - self.tkprp.expand_4to8(&mut buf, &tree[i..]); - k0[h] ^= buf[0]; - k0[h] ^= buf[2]; - k0[h] ^= buf[4]; - k0[h] ^= buf[6]; - k1[h] ^= buf[1]; - k1[h] ^= buf[3]; - k1[h] ^= buf[5]; - k1[h] ^= buf[7]; + self.tkprp.expand_2to4(&mut buf, tree); + k0[1] = buf[0] ^ buf[2]; + k1[1] = buf[1] ^ buf[3]; + tree[0..4].copy_from_slice(&buf[0..4]); - tree[2 * i..2 * i + 8].copy_from_slice(&buf); + for h in 2..self.depth { + k0[h] = Block::ZERO; + k1[h] = Block::ZERO; + + // How many nodes there are in this layer + let sz = 1 << h; + for i in (0..=sz - 4).rev().step_by(4) { + self.tkprp.expand_4to8(&mut buf, &tree[i..]); + k0[h] ^= buf[0]; + k0[h] ^= buf[2]; + k0[h] ^= buf[4]; + k0[h] ^= buf[6]; + k1[h] ^= buf[1]; + k1[h] ^= buf[3]; + k1[h] ^= buf[5]; + k1[h] ^= buf[7]; + + tree[2 * i..2 * i + 8].copy_from_slice(&buf); + } } } } diff --git a/crates/mpz-ot-core/src/chou_orlandi/receiver.rs b/crates/mpz-ot-core/src/chou_orlandi/receiver.rs index 403802f9..d9638951 100644 --- a/crates/mpz-ot-core/src/chou_orlandi/receiver.rs +++ b/crates/mpz-ot-core/src/chou_orlandi/receiver.rs @@ -153,7 +153,7 @@ impl Receiver { let SenderPayload { id, payload } = payload; // Check that the transfer id matches - let expected_id = current_id.next(); + let expected_id = current_id.next_id(); if id != expected_id { return Err(ReceiverError::IdMismatch(expected_id, id)); } diff --git a/crates/mpz-ot-core/src/chou_orlandi/sender.rs b/crates/mpz-ot-core/src/chou_orlandi/sender.rs index 09a8b5a6..328354eb 100644 --- a/crates/mpz-ot-core/src/chou_orlandi/sender.rs +++ b/crates/mpz-ot-core/src/chou_orlandi/sender.rs @@ -139,7 +139,7 @@ impl Sender { } = receiver_payload; // Check that the transfer id matches - let expected_id = current_id.next(); + let expected_id = current_id.next_id(); if id != expected_id { return Err(SenderError::IdMismatch(expected_id, id)); } diff --git a/crates/mpz-ot-core/src/ferret/mod.rs b/crates/mpz-ot-core/src/ferret/mod.rs index 3ad7701e..ac73c005 100644 --- a/crates/mpz-ot-core/src/ferret/mod.rs +++ b/crates/mpz-ot-core/src/ferret/mod.rs @@ -1,7 +1,4 @@ //! An implementation of the [`Ferret`](https://eprint.iacr.org/2020/924.pdf) protocol. - -use mpz_core::lpn::LpnParameters; - pub mod cuckoo; pub mod error; pub mod mpcot; @@ -19,28 +16,13 @@ pub const CUCKOO_HASH_NUM: usize = 3; /// Trial numbers in Cuckoo hash insertion. pub const CUCKOO_TRIAL_NUM: usize = 100; -/// LPN parameters with regular noise. -/// Derived from https://github.com/emp-toolkit/emp-ot/blob/master/emp-ot/ferret/constants.h -pub const LPN_PARAMETERS_REGULAR: LpnParameters = LpnParameters { - n: 10180608, - k: 124000, - t: 4971, -}; - -/// LPN parameters with uniform noise. -/// Derived from Table 2. -pub const LPN_PARAMETERS_UNIFORM: LpnParameters = LpnParameters { - n: 10616092, - k: 588160, - t: 1324, -}; - /// The type of Lpn parameters. -#[derive(Debug)] +#[derive(Debug, Clone, Copy, Default)] pub enum LpnType { /// Uniform error distribution. Uniform, /// Regular error distribution. + #[default] Regular, } @@ -48,15 +30,15 @@ pub enum LpnType { mod tests { use super::*; - use msgs::LpnMatrixSeed; use receiver::Receiver; use sender::Sender; - use crate::ideal::{cot::IdealCOT, mpcot::IdealMpcot}; - use crate::test::assert_cot; - use crate::{MPCOTReceiverOutput, MPCOTSenderOutput, RCOTReceiverOutput, RCOTSenderOutput}; + use crate::{ + ideal::{cot::IdealCOT, mpcot::IdealMpcot}, + test::assert_cot, + MPCOTReceiverOutput, MPCOTSenderOutput, RCOTReceiverOutput, RCOTSenderOutput, + }; use mpz_core::{lpn::LpnParameters, prg::Prg}; - use rand::SeedableRng; const LPN_PARAMETERS_TEST: LpnParameters = LpnParameters { n: 9600, @@ -66,7 +48,7 @@ mod tests { #[test] fn ferret_test() { - let mut prg = Prg::from_seed([1u8; 16].into()); + let mut prg = Prg::new(); let delta = prg.random_block(); let mut ideal_cot = IdealCOT::default(); let mut ideal_mpcot = IdealMpcot::default(); @@ -101,18 +83,8 @@ mod tests { ) .unwrap(); - let LpnMatrixSeed { - seed: lpn_matrix_seed, - } = seed; - let mut sender = sender - .setup( - delta, - LPN_PARAMETERS_TEST, - LpnType::Regular, - lpn_matrix_seed, - &v, - ) + .setup(delta, LPN_PARAMETERS_TEST, LpnType::Regular, seed, &v) .unwrap(); // extend once @@ -122,8 +94,15 @@ mod tests { let (MPCOTSenderOutput { s, .. }, MPCOTReceiverOutput { r, .. }) = ideal_mpcot.extend(&query.0, query.1); - let msgs = sender.extend(&s).unwrap(); - let (choices, received) = receiver.extend(&r).unwrap(); + sender.extend(s).unwrap(); + receiver.extend(r).unwrap(); + + let RCOTSenderOutput { msgs, .. } = sender.consume(2).unwrap(); + let RCOTReceiverOutput { + choices, + msgs: received, + .. + } = receiver.consume(2).unwrap(); assert_cot(delta, &choices, &msgs, &received); @@ -134,8 +113,15 @@ mod tests { let (MPCOTSenderOutput { s, .. }, MPCOTReceiverOutput { r, .. }) = ideal_mpcot.extend(&query.0, query.1); - let msgs = sender.extend(&s).unwrap(); - let (choices, received) = receiver.extend(&r).unwrap(); + sender.extend(s).unwrap(); + receiver.extend(r).unwrap(); + + let RCOTSenderOutput { msgs, .. } = sender.consume(sender.remaining()).unwrap(); + let RCOTReceiverOutput { + choices, + msgs: received, + .. + } = receiver.consume(receiver.remaining()).unwrap(); assert_cot(delta, &choices, &msgs, &received); } diff --git a/crates/mpz-ot-core/src/ferret/mpcot/mod.rs b/crates/mpz-ot-core/src/ferret/mpcot/mod.rs index e74dc38a..047780d4 100644 --- a/crates/mpz-ot-core/src/ferret/mpcot/mod.rs +++ b/crates/mpz-ot-core/src/ferret/mpcot/mod.rs @@ -16,11 +16,10 @@ mod tests { use crate::ideal::spcot::IdealSpcot; use crate::{SPCOTReceiverOutput, SPCOTSenderOutput}; use mpz_core::prg::Prg; - use rand::SeedableRng; #[test] fn mpcot_general_test() { - let mut prg = Prg::from_seed([1u8; 16].into()); + let mut prg = Prg::new(); let delta = prg.random_block(); let mut ideal_spcot = IdealSpcot::new_with_delta(delta); @@ -96,7 +95,7 @@ mod tests { #[test] fn mpcot_regular_test() { - let mut prg = Prg::from_seed([2u8; 16].into()); + let mut prg = Prg::new(); let delta = prg.random_block(); let mut ideal_spcot = IdealSpcot::new_with_delta(delta); diff --git a/crates/mpz-ot-core/src/ferret/mpcot/receiver.rs b/crates/mpz-ot-core/src/ferret/mpcot/receiver.rs index 0f8613af..e4d362da 100644 --- a/crates/mpz-ot-core/src/ferret/mpcot/receiver.rs +++ b/crates/mpz-ot-core/src/ferret/mpcot/receiver.rs @@ -32,11 +32,11 @@ impl Receiver { /// # Argument /// /// * `hash_seed` - Random seed to generate hashes, will be sent to the sender. - pub fn setup(self, hash_seed: Block) -> (Receiver, HashSeed) { + pub fn setup(self, hash_seed: Block) -> (Receiver, HashSeed) { let mut prg = Prg::from_seed(hash_seed); let hashes = std::array::from_fn(|_| AesEncryptor::new(prg.random_block())); let recv = Receiver { - state: state::PreExtension { + state: state::Extension { counter: 0, hashes: Arc::new(hashes), }, @@ -48,7 +48,7 @@ impl Receiver { } } -impl Receiver { +impl Receiver { /// Performs the hash procedure in MPCOT extension. /// Outputs the length of each bucket plus 1. /// @@ -63,7 +63,7 @@ impl Receiver { self, alphas: &[u32], n: u32, - ) -> Result<(Receiver, Vec<(usize, u32)>), ReceiverError> { + ) -> Result<(Receiver, Vec<(usize, u32)>), ReceiverError> { if alphas.len() as u32 > n { return Err(ReceiverError::InvalidInput( "length of alphas should not exceed n".to_string(), @@ -104,7 +104,7 @@ impl Receiver { } let receiver = Receiver { - state: state::Extension { + state: state::ExtensionInternal { counter: self.state.counter, m, n, @@ -117,7 +117,7 @@ impl Receiver { Ok((receiver, p)) } } -impl Receiver { +impl Receiver { /// Performs MPCOT extension. /// /// See Step 5 in Figure 7. @@ -128,7 +128,7 @@ impl Receiver { pub fn extend( self, rt: &[Vec], - ) -> Result<(Receiver, Vec), ReceiverError> { + ) -> Result<(Receiver, Vec), ReceiverError> { if rt.len() != self.state.m { return Err(ReceiverError::InvalidInput( "the length rt should be m".to_string(), @@ -165,7 +165,7 @@ impl Receiver { } let receiver = Receiver { - state: state::PreExtension { + state: state::Extension { counter: self.state.counter + 1, hashes: self.state.hashes, }, @@ -182,8 +182,8 @@ pub mod state { pub trait Sealed {} impl Sealed for super::Initialized {} - impl Sealed for super::PreExtension {} impl Sealed for super::Extension {} + impl Sealed for super::ExtensionInternal {} } /// The receiver's state. @@ -200,20 +200,20 @@ pub mod state { /// The receiver's state before extending. /// /// In this state the receiver performs pre extension in MPCOT (potentially multiple times). - pub struct PreExtension { + pub struct Extension { /// Current MPCOT counter pub(super) counter: usize, /// The hashes to generate Cuckoo hash table. pub(super) hashes: Arc<[AesEncryptor; CUCKOO_HASH_NUM]>, } - impl State for PreExtension {} + impl State for Extension {} - opaque_debug::implement!(PreExtension); + opaque_debug::implement!(Extension); /// The receiver's state of extension. /// /// In this state the receiver performs MPCOT extension (potentially multiple times). - pub struct Extension { + pub struct ExtensionInternal { /// Current MPCOT counter pub(super) counter: usize, /// Current length of Cuckoo hash table, will possibly be changed in each extension. @@ -228,7 +228,7 @@ pub mod state { pub(super) buckets_length: Vec, } - impl State for Extension {} + impl State for ExtensionInternal {} - opaque_debug::implement!(Extension); + opaque_debug::implement!(ExtensionInternal); } diff --git a/crates/mpz-ot-core/src/ferret/mpcot/receiver_regular.rs b/crates/mpz-ot-core/src/ferret/mpcot/receiver_regular.rs index 2b226108..e1e7edfe 100644 --- a/crates/mpz-ot-core/src/ferret/mpcot/receiver_regular.rs +++ b/crates/mpz-ot-core/src/ferret/mpcot/receiver_regular.rs @@ -19,13 +19,13 @@ impl Receiver { } /// Completes the setup phase of the protocol. - pub fn setup(self) -> Receiver { + pub fn setup(self) -> Receiver { Receiver { - state: state::PreExtension { counter: 0 }, + state: state::Extension { counter: 0 }, } } } -impl Receiver { +impl Receiver { /// Performs the prepare procedure in MPCOT extension. /// Outputs the indices for SPCOT. /// @@ -38,7 +38,7 @@ impl Receiver { self, alphas: &[u32], n: u32, - ) -> Result<(Receiver, Vec<(usize, u32)>), ReceiverError> { + ) -> Result<(Receiver, Vec<(usize, u32)>), ReceiverError> { let t = alphas.len() as u32; if t > n { return Err(ReceiverError::InvalidInput( @@ -91,7 +91,7 @@ impl Receiver { .collect(); let receiver = Receiver { - state: state::Extension { + state: state::ExtensionInternal { counter: self.state.counter, n, queries_length, @@ -103,7 +103,7 @@ impl Receiver { } } -impl Receiver { +impl Receiver { /// Performs MPCOT extension. /// /// # Arguments. @@ -112,7 +112,7 @@ impl Receiver { pub fn extend( self, rt: &[Vec], - ) -> Result<(Receiver, Vec), ReceiverError> { + ) -> Result<(Receiver, Vec), ReceiverError> { if rt .iter() .zip(self.state.queries_depth.iter()) @@ -130,7 +130,7 @@ impl Receiver { } let receiver = Receiver { - state: state::PreExtension { + state: state::Extension { counter: self.state.counter + 1, }, }; @@ -145,8 +145,8 @@ pub mod state { pub trait Sealed {} impl Sealed for super::Initialized {} - impl Sealed for super::PreExtension {} impl Sealed for super::Extension {} + impl Sealed for super::ExtensionInternal {} } /// The receiver's state. @@ -162,19 +162,19 @@ pub mod state { /// The receiver's state before extending. /// /// In this state the receiver performs pre extension in MPCOT (potentially multiple times). - pub struct PreExtension { + pub struct Extension { /// Current MPCOT counter pub(super) counter: usize, } - impl State for PreExtension {} + impl State for Extension {} - opaque_debug::implement!(PreExtension); + opaque_debug::implement!(Extension); /// The receiver's state after the setup phase. /// /// In this state the receiver performs MPCOT extension (potentially multiple times). - pub struct Extension { + pub struct ExtensionInternal { /// Current MPCOT counter #[allow(dead_code)] pub(super) counter: usize, @@ -186,7 +186,7 @@ pub mod state { pub(super) queries_depth: Vec, } - impl State for Extension {} + impl State for ExtensionInternal {} - opaque_debug::implement!(Extension); + opaque_debug::implement!(ExtensionInternal); } diff --git a/crates/mpz-ot-core/src/ferret/mpcot/sender.rs b/crates/mpz-ot-core/src/ferret/mpcot/sender.rs index f1e49105..ad025574 100644 --- a/crates/mpz-ot-core/src/ferret/mpcot/sender.rs +++ b/crates/mpz-ot-core/src/ferret/mpcot/sender.rs @@ -31,12 +31,12 @@ impl Sender { /// /// * `delta` - The sender's global secret. /// * `hash_seed` - The seed for Cuckoo hash sent by the receiver. - pub fn setup(self, delta: Block, hash_seed: HashSeed) -> Sender { + pub fn setup(self, delta: Block, hash_seed: HashSeed) -> Sender { let HashSeed { seed: hash_seed } = hash_seed; let mut prg = Prg::from_seed(hash_seed); let hashes = std::array::from_fn(|_| AesEncryptor::new(prg.random_block())); Sender { - state: state::PreExtension { + state: state::Extension { delta, counter: 0, hashes: Arc::new(hashes), @@ -45,7 +45,7 @@ impl Sender { } } -impl Sender { +impl Sender { /// Performs the hash procedure in MPCOT extension. /// Outputs the length of each bucket plus 1. /// @@ -59,7 +59,7 @@ impl Sender { self, t: u32, n: u32, - ) -> Result<(Sender, Vec), SenderError> { + ) -> Result<(Sender, Vec), SenderError> { if t > n { return Err(SenderError::InvalidInput( "t should not exceed n".to_string(), @@ -86,7 +86,7 @@ impl Sender { } let sender = Sender { - state: state::Extension { + state: state::ExtensionInternal { delta: self.state.delta, counter: self.state.counter, m, @@ -101,7 +101,7 @@ impl Sender { } } -impl Sender { +impl Sender { /// Performs MPCOT extension. /// /// See Step 5 in Figure 7. @@ -112,7 +112,7 @@ impl Sender { pub fn extend( self, st: &[Vec], - ) -> Result<(Sender, Vec), SenderError> { + ) -> Result<(Sender, Vec), SenderError> { if st.len() != self.state.m { return Err(SenderError::InvalidInput( "the length st should be m".to_string(), @@ -147,7 +147,7 @@ impl Sender { } let sender = Sender { - state: state::PreExtension { + state: state::Extension { delta: self.state.delta, counter: self.state.counter + 1, hashes: self.state.hashes, @@ -166,8 +166,8 @@ pub mod state { pub trait Sealed {} impl Sealed for super::Initialized {} - impl Sealed for super::PreExtension {} impl Sealed for super::Extension {} + impl Sealed for super::ExtensionInternal {} } /// The sender's state. @@ -184,7 +184,7 @@ pub mod state { /// The sender's state before extending. /// /// In this state the sender performs pre extension in MPCOT (potentially multiple times). - pub struct PreExtension { + pub struct Extension { /// Sender's global secret. pub(super) delta: Block, /// Current MPCOT counter @@ -193,13 +193,13 @@ pub mod state { pub(super) hashes: Arc<[AesEncryptor; CUCKOO_HASH_NUM]>, } - impl State for PreExtension {} - opaque_debug::implement!(PreExtension); + impl State for Extension {} + opaque_debug::implement!(Extension); /// The sender's state of extension. /// /// In this state the sender performs MPCOT extension (potentially multiple times). - pub struct Extension { + pub struct ExtensionInternal { /// Sender's global secret. pub(super) delta: Block, /// Current MPCOT counter @@ -217,7 +217,7 @@ pub mod state { pub(super) buckets_length: Vec, } - impl State for Extension {} + impl State for ExtensionInternal {} - opaque_debug::implement!(Extension); + opaque_debug::implement!(ExtensionInternal); } diff --git a/crates/mpz-ot-core/src/ferret/mpcot/sender_regular.rs b/crates/mpz-ot-core/src/ferret/mpcot/sender_regular.rs index db0646b6..7afa5106 100644 --- a/crates/mpz-ot-core/src/ferret/mpcot/sender_regular.rs +++ b/crates/mpz-ot-core/src/ferret/mpcot/sender_regular.rs @@ -23,14 +23,14 @@ impl Sender { /// # Argument. /// /// * `delta` - The sender's global secret. - pub fn setup(self, delta: Block) -> Sender { + pub fn setup(self, delta: Block) -> Sender { Sender { - state: state::PreExtension { delta, counter: 0 }, + state: state::Extension { delta, counter: 0 }, } } } -impl Sender { +impl Sender { /// Performs the prepare procedure in MPCOT extension. /// Outputs the information for SPCOT. /// @@ -42,7 +42,7 @@ impl Sender { self, t: u32, n: u32, - ) -> Result<(Sender, Vec), SenderError> { + ) -> Result<(Sender, Vec), SenderError> { if t > n { return Err(SenderError::InvalidInput( "t should not exceed n".to_string(), @@ -78,7 +78,7 @@ impl Sender { } let sender = Sender { - state: state::Extension { + state: state::ExtensionInternal { delta: self.state.delta, counter: self.state.counter, n, @@ -91,7 +91,7 @@ impl Sender { } } -impl Sender { +impl Sender { /// Performs MPCOT extension. /// /// # Arguments. @@ -100,7 +100,7 @@ impl Sender { pub fn extend( self, st: &[Vec], - ) -> Result<(Sender, Vec), SenderError> { + ) -> Result<(Sender, Vec), SenderError> { if st .iter() .zip(self.state.queries_depth.iter()) @@ -117,7 +117,7 @@ impl Sender { } let sender = Sender { - state: state::PreExtension { + state: state::Extension { delta: self.state.delta, counter: self.state.counter + 1, }, @@ -135,8 +135,8 @@ pub mod state { pub trait Sealed {} impl Sealed for super::Initialized {} - impl Sealed for super::PreExtension {} impl Sealed for super::Extension {} + impl Sealed for super::ExtensionInternal {} } /// The sender's state. @@ -153,20 +153,20 @@ pub mod state { /// The sender's state before extending. /// /// In this state the sender performs pre extension in MPCOT (potentially multiple times). - pub struct PreExtension { + pub struct Extension { /// Sender's global secret. pub(super) delta: Block, /// Current MPCOT counter pub(super) counter: usize, } - impl State for PreExtension {} - opaque_debug::implement!(PreExtension); + impl State for Extension {} + opaque_debug::implement!(Extension); /// The sender's state after the setup phase. /// /// In this state the sender performs MPCOT extension (potentially multiple times). - pub struct Extension { + pub struct ExtensionInternal { /// Sender's global secret. pub(super) delta: Block, /// Current MPCOT counter @@ -179,7 +179,7 @@ pub mod state { pub(super) queries_depth: Vec, } - impl State for Extension {} + impl State for ExtensionInternal {} - opaque_debug::implement!(Extension); + opaque_debug::implement!(ExtensionInternal); } diff --git a/crates/mpz-ot-core/src/ferret/receiver.rs b/crates/mpz-ot-core/src/ferret/receiver.rs index 4d08c69b..782d2b9e 100644 --- a/crates/mpz-ot-core/src/ferret/receiver.rs +++ b/crates/mpz-ot-core/src/ferret/receiver.rs @@ -1,10 +1,15 @@ //! Ferret receiver +use std::collections::VecDeque; + use mpz_core::{ lpn::{LpnEncoder, LpnParameters}, Block, }; -use crate::ferret::{error::ReceiverError, LpnType}; +use crate::{ + ferret::{error::ReceiverError, LpnType}, + RCOTReceiverOutput, TransferId, +}; use super::msgs::LpnMatrixSeed; @@ -59,6 +64,9 @@ impl Receiver { u: u.to_vec(), w: w.to_vec(), e: Vec::default(), + id: TransferId::default(), + choices_buffer: VecDeque::new(), + msgs_buffer: VecDeque::new(), }, }, LpnMatrixSeed { seed }, @@ -67,12 +75,18 @@ impl Receiver { } impl Receiver { + /// Returns the current transfer id. + pub fn id(&self) -> TransferId { + self.state.id + } + + /// Returns the number of remaining COTs. + pub fn remaining(&self) -> usize { + self.state.choices_buffer.len() + } + /// The prepare precedure of extension, sample error vectors and outputs information for MPCOT. /// See step 3 and 4. - /// - /// # Arguments. - /// - /// * `lpn_type` - The type of LPN parameters. pub fn get_mpcot_query(&mut self) -> (Vec, usize) { match self.state.lpn_type { LpnType::Uniform => { @@ -100,13 +114,15 @@ impl Receiver { /// # Arguments. /// /// * `r` - The vector received from the MPCOT protocol. - pub fn extend(&mut self, r: &[Block]) -> Result<(Vec, Vec), ReceiverError> { + pub fn extend(&mut self, r: Vec) -> Result<(), ReceiverError> { if r.len() != self.state.lpn_parameters.n { return Err(ReceiverError("the length of r should be n".to_string())); } + self.state.id.next_id(); + // Compute z = A * w + r. - let mut z = r.to_vec(); + let mut z = r; self.state.lpn_encoder.compute(&mut z, &self.state.w); // Compute x = A * u + e. @@ -131,7 +147,32 @@ impl Receiver { // Update counter self.state.counter += 1; - Ok((x_, z_)) + self.state.choices_buffer.extend(x_); + self.state.msgs_buffer.extend(z_); + + Ok(()) + } + + /// Consumes `count` COTs. + pub fn consume( + &mut self, + count: usize, + ) -> Result, ReceiverError> { + if count > self.state.choices_buffer.len() { + return Err(ReceiverError(format!( + "insufficient OTs: {} < {count}", + self.state.choices_buffer.len() + ))); + } + + let choices = self.state.choices_buffer.drain(0..count).collect(); + let msgs = self.state.msgs_buffer.drain(0..count).collect(); + + Ok(RCOTReceiverOutput { + id: self.state.id.next_id(), + choices, + msgs, + }) } } @@ -176,6 +217,12 @@ pub mod state { /// Receiver's lpn error vector. pub(super) e: Vec, + + /// TransferID + pub(super) id: TransferId, + /// Extended OTs buffers. + pub(super) choices_buffer: VecDeque, + pub(super) msgs_buffer: VecDeque, } impl State for Extension {} diff --git a/crates/mpz-ot-core/src/ferret/sender.rs b/crates/mpz-ot-core/src/ferret/sender.rs index 9e8db180..e6af6452 100644 --- a/crates/mpz-ot-core/src/ferret/sender.rs +++ b/crates/mpz-ot-core/src/ferret/sender.rs @@ -1,10 +1,17 @@ //! Ferret sender. +use std::collections::VecDeque; + use mpz_core::{ lpn::{LpnEncoder, LpnParameters}, Block, }; -use crate::ferret::{error::SenderError, LpnType}; +use crate::{ + ferret::{error::SenderError, LpnType}, + RCOTSenderOutput, TransferId, +}; + +use super::msgs::LpnMatrixSeed; /// Ferret sender. #[derive(Debug, Default)] @@ -36,7 +43,7 @@ impl Sender { delta: Block, lpn_parameters: LpnParameters, lpn_type: LpnType, - seed: Block, + seed: LpnMatrixSeed, v: &[Block], ) -> Result, SenderError> { if v.len() != lpn_parameters.k { @@ -44,6 +51,7 @@ impl Sender { "the length of v should be equal to k".to_string(), )); } + let LpnMatrixSeed { seed } = seed; let lpn_encoder = LpnEncoder::<10>::new(seed, lpn_parameters.k as u32); Ok(Sender { @@ -54,15 +62,33 @@ impl Sender { lpn_type, lpn_encoder, v: v.to_vec(), + id: TransferId::default(), + msgs_buffer: VecDeque::new(), }, }) } } impl Sender { + /// Returns the current transfer id. + pub fn id(&self) -> TransferId { + self.state.id + } + + /// Returns the number of remaining COTs. + pub fn remaining(&self) -> usize { + self.state.msgs_buffer.len() + } + + /// Returns the delta correlation. + pub fn delta(&self) -> Block { + self.state.delta + } + /// Outputs the information for MPCOT. /// /// See step 3 and 4. + #[inline] pub fn get_mpcot_query(&self) -> (u32, u32) { ( self.state.lpn_parameters.t as u32, @@ -78,13 +104,15 @@ impl Sender { /// # Arguments. /// /// * `s` - The vector received from the MPCOT protocol. - pub fn extend(&mut self, s: &[Block]) -> Result, SenderError> { + pub fn extend(&mut self, s: Vec) -> Result<(), SenderError> { if s.len() != self.state.lpn_parameters.n { return Err(SenderError("the length of s should be n".to_string())); } + self.state.id.next_id(); + // Compute y = A * v + s - let mut y = s.to_vec(); + let mut y = s; self.state.lpn_encoder.compute(&mut y, &self.state.v); let y_ = y.split_off(self.state.lpn_parameters.k); @@ -94,13 +122,33 @@ impl Sender { // Update counter self.state.counter += 1; + self.state.msgs_buffer.extend(y_); - Ok(y_) + Ok(()) + } + + /// Consumes `count` COTs. + pub fn consume(&mut self, count: usize) -> Result, SenderError> { + if count > self.state.msgs_buffer.len() { + return Err(SenderError(format!( + "insufficient OTs: {} < {count}", + self.state.msgs_buffer.len() + ))); + } + + let msgs = self.state.msgs_buffer.drain(0..count).collect(); + + Ok(RCOTSenderOutput { + id: self.state.id.next_id(), + msgs, + }) } } /// The sender's state. pub mod state { + use crate::TransferId; + use super::*; mod sealed { @@ -141,6 +189,11 @@ pub mod state { /// Sender's COT message in the setup phase. pub(super) v: Vec, + + /// Transfer ID. + pub(crate) id: TransferId, + /// COT messages buffer. + pub(super) msgs_buffer: VecDeque, } impl State for Extension {} diff --git a/crates/mpz-ot-core/src/ferret/spcot/mod.rs b/crates/mpz-ot-core/src/ferret/spcot/mod.rs index 802efb66..63ebea15 100644 --- a/crates/mpz-ot-core/src/ferret/spcot/mod.rs +++ b/crates/mpz-ot-core/src/ferret/spcot/mod.rs @@ -7,8 +7,6 @@ pub mod sender; #[cfg(test)] mod tests { - use mpz_core::prg::Prg; - use super::{receiver::Receiver as SpcotReceiver, sender::Sender as SpcotSender}; use crate::{ferret::CSP, ideal::cot::IdealCOT, RCOTReceiverOutput, RCOTSenderOutput}; @@ -18,49 +16,82 @@ mod tests { let sender = SpcotSender::new(); let receiver = SpcotReceiver::new(); - let mut prg = Prg::new(); - let sender_seed = prg.random_block(); let delta = ideal_cot.delta(); - let mut sender = sender.setup(delta, sender_seed); + let mut sender = sender.setup(delta); let mut receiver = receiver.setup(); - let h1 = 8; - let alpha1 = 3; + let hs = [8, 4, 10]; + let alphas = [3, 2, 4]; - // Extend once - let (msg_for_sender, msg_for_receiver) = ideal_cot.random_correlated(h1); + let h_sum = hs.iter().sum(); + // batch extension + let (msg_for_sender, msg_for_receiver) = ideal_cot.random_correlated(h_sum); let RCOTReceiverOutput { - choices: rs, - msgs: ts, + choices: rss, + msgs: tss, .. } = msg_for_receiver; - let RCOTSenderOutput { msgs: qs, .. } = msg_for_sender; - let maskbits = receiver.extend_mask_bits(h1, alpha1, &rs).unwrap(); - let msg_from_sender = sender.extend(h1, &qs, maskbits).unwrap(); + let RCOTSenderOutput { msgs: qss, .. } = msg_for_sender; + + let maskbits = receiver.extend_mask_bits(&hs, &alphas, &rss).unwrap(); + + let msg_from_sender = sender.extend(&hs, &qss, &maskbits).unwrap(); + + receiver + .extend(&hs, &alphas, &tss, &msg_from_sender) + .unwrap(); + + // Check + let (msg_for_sender, msg_for_receiver) = ideal_cot.random_correlated(CSP); + + let RCOTReceiverOutput { + choices: x_star, + msgs: z_star, + .. + } = msg_for_receiver; + + let RCOTSenderOutput { msgs: y_star, .. } = msg_for_sender; + + let check_from_receiver = receiver.check_pre(&x_star).unwrap(); - receiver.extend(h1, alpha1, &ts, msg_from_sender).unwrap(); + let (mut output_sender, check) = sender.check(&y_star, check_from_receiver).unwrap(); - // Extend twice - let h2 = 4; - let alpha2 = 2; + let output_receiver = receiver.check(&z_star, check).unwrap(); - let (msg_for_sender, msg_for_receiver) = ideal_cot.random_correlated(h2); + assert!(output_sender + .iter_mut() + .zip(output_receiver.iter()) + .all(|(vs, (ws, alpha))| { + vs[*alpha as usize] ^= delta; + vs == ws + })); + + // extend twice + let hs = [6, 9, 8]; + let alphas = [2, 1, 3]; + + let h_sum = hs.iter().sum(); + + let (msg_for_sender, msg_for_receiver) = ideal_cot.random_correlated(h_sum); let RCOTReceiverOutput { - choices: rs, - msgs: ts, + choices: rss, + msgs: tss, .. } = msg_for_receiver; - let RCOTSenderOutput { msgs: qs, .. } = msg_for_sender; - let maskbits = receiver.extend_mask_bits(h2, alpha2, &rs).unwrap(); + let RCOTSenderOutput { msgs: qss, .. } = msg_for_sender; + + let maskbits = receiver.extend_mask_bits(&hs, &alphas, &rss).unwrap(); - let msg_from_sender = sender.extend(h2, &qs, maskbits).unwrap(); + let msg_from_sender = sender.extend(&hs, &qss, &maskbits).unwrap(); - receiver.extend(h2, alpha2, &ts, msg_from_sender).unwrap(); + receiver + .extend(&hs, &alphas, &tss, &msg_from_sender) + .unwrap(); // Check let (msg_for_sender, msg_for_receiver) = ideal_cot.random_correlated(CSP); diff --git a/crates/mpz-ot-core/src/ferret/spcot/receiver.rs b/crates/mpz-ot-core/src/ferret/spcot/receiver.rs index 5e860f31..baf10ae2 100644 --- a/crates/mpz-ot-core/src/ferret/spcot/receiver.rs +++ b/crates/mpz-ot-core/src/ferret/spcot/receiver.rs @@ -6,6 +6,10 @@ use mpz_core::{ utils::blake3, Block, }; use rand_core::SeedableRng; +#[cfg(feature = "rayon")] +use rayon::iter::{ + IndexedParallelIterator, IntoParallelRefIterator, IntoParallelRefMutIterator, ParallelIterator, +}; use super::msgs::{CheckFromReceiver, CheckFromSender, ExtendFromSender, MaskBits}; @@ -43,71 +47,101 @@ impl Receiver { } impl Receiver { - /// Performs the mask bit step in extension. + /// Performs the mask bit step in batch in extension. /// /// See step 4 in Figure 6. /// /// # Arguments /// - /// * `h` - The depth of the GGM tree. - /// * `alpha` - The chosen position. - /// * `rs` - The message from COT ideal functionality for the receiver. Only the random bits are used. + /// * `hs` - The depths of the GGM trees. + /// * `alphas` - The vector of chosen positions. + /// * `rss` - The message from COT ideal functionality for the receiver for all the tress. Only the random bits are used. pub fn extend_mask_bits( &mut self, - h: usize, - alpha: u32, - rs: &[bool], - ) -> Result { + hs: &[usize], + alphas: &[u32], + rss: &[bool], + ) -> Result, ReceiverError> { if self.state.extended { return Err(ReceiverError::InvalidState( "extension is not allowed".to_string(), )); } - if alpha >= (1 << h) { + if alphas.len() != hs.len() { + return Err(ReceiverError::InvalidLength( + "the length of alphas should be the length of hs".to_string(), + )); + } + + if alphas + .iter() + .zip(hs.iter()) + .any(|(alpha, h)| *alpha >= (1 << h)) + { return Err(ReceiverError::InvalidInput( "the input pos should be no more than 2^h-1".to_string(), )); } - if rs.len() != h { + let h_sum = hs.iter().sum(); + + if rss.len() != h_sum { return Err(ReceiverError::InvalidLength( - "the length of r should be h".to_string(), + "the length of r should be the sum of h".to_string(), )); } - // Step 4 in Figure 6 + let mut rs_s = vec![Vec::::new(); hs.len()]; + let mut rss_vec = rss.to_vec(); + for (index, h) in hs.iter().enumerate() { + rs_s[index] = rss_vec.drain(0..*h).collect(); + } - let bs: Vec = alpha - .iter_msb0() - .skip(32 - h) - // Computes alpha_i XOR r_i XOR 1. - .zip(rs.iter()) - .map(|(alpha, &r)| alpha == r) - .collect(); + // Step 4 in Figure 6 + let mut bss = vec![Vec::::new(); hs.len()]; + + let iter = bss + .iter_mut() + .zip(alphas.iter()) + .zip(hs.iter()) + .zip(rs_s.iter()) + .map(|(((bs, alpha), h), rs)| (bs, alpha, h, rs)); + + for (bs, alpha, h, rs) in iter { + *bs = alpha + .iter_msb0() + .skip(32 - h) + // Computes alpha_i XOR r_i XOR 1. + .zip(rs.iter()) + .map(|(alpha, &r)| alpha == r) + .collect(); + } // Updates hasher. - self.state.hasher.update(&bs.to_bytes()); + self.state.hasher.update(&bss.to_bytes()); + + let res: Vec = bss.into_iter().map(|bs| MaskBits { bs }).collect(); - Ok(MaskBits { bs }) + Ok(res) } - /// Performs the GGM reconstruction step in extension. This function can be called multiple times before checking. + /// Performs the GGM reconstruction step in batch in extension. This function can be called multiple times before checking. /// /// See step 5 in Figure 6. /// /// # Arguments /// - /// * `h` - The depth of the GGM tree. - /// * `alpha` - The chosen position. - /// * `ts` - The message from COT ideal functionality for the receiver. Only the chosen blocks are used. - /// * `extendfs` - The message sent by the sender. + /// * `hs` - The depths of the GGM trees. + /// * `alphas` - The vector of chosen positions. + /// * `tss` - The message from COT ideal functionality for the receiver. Only the chosen blocks are used. + /// * `extendfss` - The vector of messages sent by the sender. pub fn extend( &mut self, - h: usize, - alpha: u32, - ts: &[Block], - extendfs: ExtendFromSender, + hs: &[usize], + alphas: &[u32], + tss: &[Block], + extendfss: &[ExtendFromSender], ) -> Result<(), ReceiverError> { if self.state.extended { return Err(ReceiverError::InvalidState( @@ -115,61 +149,122 @@ impl Receiver { )); } - if alpha >= (1 << h) { + if alphas.len() != hs.len() { + return Err(ReceiverError::InvalidLength( + "the length of alphas should be the length of hs".to_string(), + )); + } + + if alphas + .iter() + .zip(hs.iter()) + .any(|(alpha, h)| *alpha >= (1 << h)) + { return Err(ReceiverError::InvalidInput( "the input pos should be no more than 2^h-1".to_string(), )); } - let ExtendFromSender { ms, sum } = extendfs; - if ts.len() != h { + let h_sum = hs.iter().sum(); + + if tss.len() != h_sum { return Err(ReceiverError::InvalidLength( - "the length of t should be h".to_string(), + "the length of tss should be the sum of h".to_string(), )); } - if ms.len() != h { + let mut ts_s = vec![Vec::::new(); hs.len()]; + let mut tss_vec = tss.to_vec(); + for (index, h) in hs.iter().enumerate() { + ts_s[index] = tss_vec.drain(0..*h).collect(); + } + + if extendfss.len() != hs.len() { return Err(ReceiverError::InvalidLength( - "the length of M should be h".to_string(), + "the length of extendfss should be the length of hs".to_string(), )); } - // Updates hasher - self.state.hasher.update(&ms.to_bytes()); - self.state.hasher.update(&sum.to_bytes()); - - let alpha_bar_vec: Vec = alpha.iter_msb0().skip(32 - h).map(|a| !a).collect(); - - // Step 5 in Figure 6. - let k: Vec = ms - .into_iter() - .zip(ts) - .zip(alpha_bar_vec.iter()) - .enumerate() - .map(|(i, (([m0, m1], &t), &b))| { - let tweak: Block = bytemuck::cast([i, self.state.exec_counter]); - if !b { - // H(t, i|ell) ^ M0 - FIXED_KEY_AES.tccr(tweak, t) ^ m0 - } else { - // H(t, i|ell) ^ M1 - FIXED_KEY_AES.tccr(tweak, t) ^ m1 - } - }) - .collect(); + let mut ms_s = vec![Vec::<[Block; 2]>::new(); hs.len()]; + let mut sum_s = vec![Block::ZERO; hs.len()]; - // Reconstructs GGM tree except `ws[alpha]`. - let ggm_tree = GgmTree::new(h); - let mut tree = vec![Block::ZERO; 1 << h]; - ggm_tree.reconstruct(&mut tree, &k, &alpha_bar_vec); + for (index, extendfs) in extendfss.iter().enumerate() { + ms_s[index].clone_from(&extendfs.ms); + sum_s[index] = extendfs.sum; + } + + if ms_s.iter().zip(hs.iter()).any(|(ms, h)| ms.len() != *h) { + return Err(ReceiverError::InvalidLength( + "the length of ms should be h".to_string(), + )); + } + // Updates hasher + self.state.hasher.update(&ms_s.to_bytes()); + self.state.hasher.update(&sum_s.to_bytes()); + + let mut trees = vec![Vec::::new(); hs.len()]; + + cfg_if::cfg_if! { + if #[cfg(feature = "rayon")]{ + let iter = alphas + .par_iter() + .zip(ms_s.par_iter()) + .zip(sum_s.par_iter()) + .zip(hs.par_iter()) + .zip(ts_s.par_iter()) + .zip(trees.par_iter_mut()) + .map(|(((((alpha, ms), sum), h), ts), tree)| (alpha, ms, sum, h, ts, tree)); + }else{ + let iter = alphas + .iter() + .zip(ms_s.iter()) + .zip(sum_s.iter()) + .zip(hs.iter()) + .zip(ts_s.iter()) + .zip(trees.iter_mut()) + .map(|(((((alpha, ms), sum), h), ts), tree)| (alpha, ms, sum, h, ts, tree)); + } + } - // Sets `tree[alpha]`, which is `ws[alpha]`. - tree[alpha as usize] = tree.iter().fold(sum, |acc, &x| acc ^ x); + iter.for_each(|(alpha, ms, sum, h, ts, tree)| { + let alpha_bar_vec: Vec = alpha.iter_msb0().skip(32 - h).map(|a| !a).collect(); + + // Step 5 in Figure 6. + let k: Vec = ms + .iter() + .zip(ts) + .zip(alpha_bar_vec.iter()) + .enumerate() + .map(|(i, (([m0, m1], &t), &b))| { + let tweak: Block = bytemuck::cast([i, self.state.exec_counter]); + if !b { + // H(t, i|ell) ^ M0 + FIXED_KEY_AES.tccr(tweak, t) ^ *m0 + } else { + // H(t, i|ell) ^ M1 + FIXED_KEY_AES.tccr(tweak, t) ^ *m1 + } + }) + .collect(); + + // Reconstructs GGM tree except `ws[alpha]`. + let ggm_tree = GgmTree::new(*h); + *tree = vec![Block::ZERO; 1 << h]; + ggm_tree.reconstruct(tree, &k, &alpha_bar_vec); + + // Sets `tree[alpha]`, which is `ws[alpha]`. + tree[(*alpha) as usize] = tree.iter().fold(*sum, |acc, &x| acc ^ x); + }); + + for tree in trees { + self.state.unchecked_ws.extend_from_slice(&tree); + } - self.state.unchecked_ws.extend_from_slice(&tree); - self.state.alphas_and_length.push((alpha, 1 << h)); + for (alpha, h) in alphas.iter().zip(hs.iter()) { + self.state.alphas_and_length.push((*alpha, 1 << h)); + } - self.state.exec_counter += 1; + self.state.exec_counter += hs.len(); Ok(()) } @@ -248,7 +343,6 @@ impl Receiver { } self.state.cot_counter += self.state.unchecked_ws.len(); - self.state.extended = true; let mut res = Vec::new(); for (alpha, n) in &self.state.alphas_and_length { @@ -256,8 +350,19 @@ impl Receiver { res.push((tmp, *alpha)); } + self.state.hasher = blake3::Hasher::new(); + self.state.alphas_and_length.clear(); + self.state.chis.clear(); + self.state.unchecked_ws.clear(); + Ok(res) } + + /// Complete extension. + #[inline] + pub fn finalize(&mut self) { + self.state.extended = true; + } } /// The receiver's state. diff --git a/crates/mpz-ot-core/src/ferret/spcot/sender.rs b/crates/mpz-ot-core/src/ferret/spcot/sender.rs index fef1327e..a62ad3bb 100644 --- a/crates/mpz-ot-core/src/ferret/spcot/sender.rs +++ b/crates/mpz-ot-core/src/ferret/spcot/sender.rs @@ -5,6 +5,10 @@ use mpz_core::{ utils::blake3, Block, }; use rand_core::SeedableRng; +#[cfg(feature = "rayon")] +use rayon::iter::{ + IndexedParallelIterator, IntoParallelRefIterator, IntoParallelRefMutIterator, ParallelIterator, +}; use super::msgs::{CheckFromReceiver, CheckFromSender, ExtendFromSender, MaskBits}; @@ -29,8 +33,7 @@ impl Sender { /// # Arguments /// /// * `delta` - The sender's global secret. - /// * `seed` - The random seed to generate PRG. - pub fn setup(self, delta: Block, seed: Block) -> Sender { + pub fn setup(self, delta: Block) -> Sender { Sender { state: state::Extension { delta, @@ -39,7 +42,6 @@ impl Sender { cot_counter: 0, exec_counter: 0, extended: false, - prg: Prg::from_seed(seed), hasher: blake3::Hasher::new(), }, } @@ -47,85 +49,137 @@ impl Sender { } impl Sender { - /// Performs the SPCOT extension. + /// Performs batch SPCOT extension. /// /// See Step 1-5 in Figure 6. /// /// # Arguments /// - /// * `h` - The depth of the GGM tree. - /// * `qs`- The blocks received by calling the COT functionality. - /// * `mask`- The mask bits sent by the receiver. + /// * `hs` - The depths of the GGM trees. + /// * `qss`- The blocks received by calling the COT functionality for hs trees. + /// * `masks`- The vector of mask bits sent by the receiver. pub fn extend( &mut self, - h: usize, - qs: &[Block], - mask: MaskBits, - ) -> Result { + hs: &[usize], + qss: &[Block], + masks: &[MaskBits], + ) -> Result, SenderError> { if self.state.extended { return Err(SenderError::InvalidState( "extension is not allowed".to_string(), )); } - if qs.len() != h { + let h_sum = hs.iter().sum(); + + if qss.len() != h_sum { return Err(SenderError::InvalidLength( - "the length of q should be h".to_string(), + "the length of qss should be the sum of h".to_string(), )); } - let MaskBits { bs } = mask; + let mut qs_s = vec![Vec::::new(); hs.len()]; + let mut qss_vec = qss.to_vec(); + for (index, h) in hs.iter().enumerate() { + qs_s[index] = qss_vec.drain(0..*h).collect(); + } - if bs.len() != h { + if masks.len() != hs.len() { + return Err(SenderError::InvalidLength( + "the length of masks should be the length of hs".to_string(), + )); + } + + let bss: Vec> = masks.iter().map(|m| m.clone().bs).collect(); + + if bss.iter().zip(hs.iter()).any(|(b, h)| b.len() != *h) { return Err(SenderError::InvalidLength( "the length of b should be h".to_string(), )); } // Updates hasher. - self.state.hasher.update(&bs.to_bytes()); + self.state.hasher.update(&bss.to_bytes()); // Step 3-4, Figure 6. // Generates a GGM tree with depth h and seed s. - let s = self.state.prg.random_block(); - let ggm_tree = GgmTree::new(h); - let mut k0 = vec![Block::ZERO; h]; - let mut k1 = vec![Block::ZERO; h]; - let mut tree = vec![Block::ZERO; 1 << h]; - ggm_tree.gen(s, &mut tree, &mut k0, &mut k1); + let mut trees = vec![Vec::::new(); hs.len()]; + let mut ms_s = vec![Vec::<[Block; 2]>::new(); hs.len()]; + let mut sum_s = vec![Block::ZERO; hs.len()]; + + cfg_if::cfg_if! { + if #[cfg(feature = "rayon")]{ + let iter = trees + .par_iter_mut().zip(hs.par_iter()) + .zip(qs_s.par_iter()) + .zip(bss.par_iter()) + .zip(ms_s.par_iter_mut()) + .zip(sum_s.par_iter_mut()) + .map(|(((((tree, h), qs), bs), ms), sum)| (tree, h, qs, bs, ms, sum)); + }else{ + let iter = trees + .iter_mut() + .zip(hs.iter()) + .zip(qs_s.iter()) + .zip(bss.iter()) + .zip(ms_s.iter_mut()) + .zip(sum_s.iter_mut()) + .map(|(((((tree, h), qs), bs), ms), sum)| (tree, h, qs, bs, ms, sum)); + } + } + + iter.for_each(|(tree, h, qs, bs, ms, sum)| { + let s = Prg::new().random_block(); + let ggm_tree = GgmTree::new(*h); + let mut k0 = vec![Block::ZERO; *h]; + let mut k1 = vec![Block::ZERO; *h]; + *tree = vec![Block::ZERO; 1 << h]; + ggm_tree.gen(s, tree, &mut k0, &mut k1); + + // Computes the sum of the leaves and delta. + *sum = tree.iter().fold(self.state.delta, |acc, &x| acc ^ x); + + // Computes M0 and M1. + for (((i, &q), b), (k0, k1)) in + qs.iter().enumerate().zip(bs).zip(k0.into_iter().zip(k1)) + { + let mut m = if *b { + [q ^ self.state.delta, q] + } else { + [q, q ^ self.state.delta] + }; + let tweak: Block = bytemuck::cast([i, self.state.exec_counter]); + FIXED_KEY_AES.tccr_many(&[tweak, tweak], &mut m); + m[0] ^= k0; + m[1] ^= k1; + ms.push(m); + } + }); // Stores the tree, i.e., the possible output of sender. - self.state.unchecked_vs.extend_from_slice(&tree); + for tree in trees { + self.state.unchecked_vs.extend_from_slice(&tree); + } // Stores the length of this extension. - self.state.vs_length.push(1 << h); - - // Computes the sum of the leaves and delta. - let sum = tree.iter().fold(self.state.delta, |acc, &x| acc ^ x); - - // Computes M0 and M1. - let mut ms: Vec<[Block; 2]> = Vec::with_capacity(qs.len()); - for (((i, &q), b), (k0, k1)) in qs.iter().enumerate().zip(bs).zip(k0.into_iter().zip(k1)) { - let mut m = if b { - [q ^ self.state.delta, q] - } else { - [q, q ^ self.state.delta] - }; - let tweak: Block = bytemuck::cast([i, self.state.exec_counter]); - FIXED_KEY_AES.tccr_many(&[tweak, tweak], &mut m); - m[0] ^= k0; - m[1] ^= k1; - ms.push(m); + for h in hs { + self.state.vs_length.push(1 << h); } // Updates hasher - self.state.hasher.update(&ms.to_bytes()); - self.state.hasher.update(&sum.to_bytes()); + self.state.hasher.update(&ms_s.to_bytes()); + self.state.hasher.update(&sum_s.to_bytes()); - self.state.exec_counter += 1; + self.state.exec_counter += hs.len(); + + let res: Vec = ms_s + .into_iter() + .zip(sum_s.iter()) + .map(|(ms, &sum)| ExtendFromSender { ms, sum }) + .collect(); - Ok(ExtendFromSender { ms, sum }) + Ok(res) } /// Performs the consistency check for the resulting COTs. @@ -193,10 +247,18 @@ impl Sender { res.push(tmp); } - self.state.extended = true; + self.state.hasher = blake3::Hasher::new(); + self.state.unchecked_vs.clear(); + self.state.vs_length.clear(); Ok((res, CheckFromSender { hashed_v })) } + + /// Complete extension. + #[inline] + pub fn finalize(&mut self) { + self.state.extended = true; + } } /// The sender's state. @@ -239,8 +301,6 @@ pub mod state { /// This is to prevent the receiver from extending twice pub(super) extended: bool, - /// A PRG to generate random strings. - pub(super) prg: Prg, /// A hasher to generate chi seed. pub(super) hasher: blake3::Hasher, } diff --git a/crates/mpz-ot-core/src/ideal/cot.rs b/crates/mpz-ot-core/src/ideal/cot.rs index a28abef8..a842129d 100644 --- a/crates/mpz-ot-core/src/ideal/cot.rs +++ b/crates/mpz-ot-core/src/ideal/cot.rs @@ -76,7 +76,7 @@ impl IdealCOT { .collect(); self.counter += count; - let id = self.transfer_id.next(); + let id = self.transfer_id.next_id(); ( RCOTSenderOutput { id, msgs }, diff --git a/crates/mpz-ot-core/src/ideal/mpcot.rs b/crates/mpz-ot-core/src/ideal/mpcot.rs index 44a5595f..c038331b 100644 --- a/crates/mpz-ot-core/src/ideal/mpcot.rs +++ b/crates/mpz-ot-core/src/ideal/mpcot.rs @@ -60,7 +60,7 @@ impl IdealMpcot { self.counter += 1; } - let id = self.transfer_id.next(); + let id = self.transfer_id.next_id(); (MPCOTSenderOutput { id, s }, MPCOTReceiverOutput { id, r }) } diff --git a/crates/mpz-ot-core/src/ideal/ot.rs b/crates/mpz-ot-core/src/ideal/ot.rs index e389066e..76ebe630 100644 --- a/crates/mpz-ot-core/src/ideal/ot.rs +++ b/crates/mpz-ot-core/src/ideal/ot.rs @@ -55,7 +55,7 @@ impl IdealOT { self.counter += choices.len(); self.choices.extend(choices); - let id = self.transfer_id.next(); + let id = self.transfer_id.next_id(); (OTSenderOutput { id }, OTReceiverOutput { id, msgs: chosen }) } diff --git a/crates/mpz-ot-core/src/ideal/rot.rs b/crates/mpz-ot-core/src/ideal/rot.rs index 8a8b5d68..e29b9204 100644 --- a/crates/mpz-ot-core/src/ideal/rot.rs +++ b/crates/mpz-ot-core/src/ideal/rot.rs @@ -68,7 +68,7 @@ impl IdealROT { .collect(); self.counter += count; - let id = self.transfer_id.next(); + let id = self.transfer_id.next_id(); ( ROTSenderOutput { id, msgs }, @@ -103,7 +103,7 @@ impl IdealROT { .collect(); self.counter += choices.len(); - let id = self.transfer_id.next(); + let id = self.transfer_id.next_id(); ( ROTSenderOutput { id, msgs }, diff --git a/crates/mpz-ot-core/src/ideal/spcot.rs b/crates/mpz-ot-core/src/ideal/spcot.rs index 12c5f829..93b3c720 100644 --- a/crates/mpz-ot-core/src/ideal/spcot.rs +++ b/crates/mpz-ot-core/src/ideal/spcot.rs @@ -61,7 +61,7 @@ impl IdealSpcot { self.counter += n; } - let id = self.transfer_id.next(); + let id = self.transfer_id.next_id(); (SPCOTSenderOutput { id, v }, SPCOTReceiverOutput { id, w }) } diff --git a/crates/mpz-ot-core/src/kos/receiver.rs b/crates/mpz-ot-core/src/kos/receiver.rs index fdcad328..127c4f1d 100644 --- a/crates/mpz-ot-core/src/kos/receiver.rs +++ b/crates/mpz-ot-core/src/kos/receiver.rs @@ -330,7 +330,7 @@ impl Receiver { )); } - let id = self.state.transfer_id.next(); + let id = self.state.transfer_id.next_id(); let index = self.state.index - self.state.keys.len(); Ok(ReceiverKeys { diff --git a/crates/mpz-ot-core/src/kos/sender.rs b/crates/mpz-ot-core/src/kos/sender.rs index 24917940..23edff5c 100644 --- a/crates/mpz-ot-core/src/kos/sender.rs +++ b/crates/mpz-ot-core/src/kos/sender.rs @@ -294,7 +294,7 @@ impl Sender { return Err(SenderError::InsufficientSetup(count, self.state.keys.len())); } - let id = self.state.transfer_id.next(); + let id = self.state.transfer_id.next_id(); Ok(SenderKeys { id, diff --git a/crates/mpz-ot-core/src/lib.rs b/crates/mpz-ot-core/src/lib.rs index 8dd77287..b0b69260 100644 --- a/crates/mpz-ot-core/src/lib.rs +++ b/crates/mpz-ot-core/src/lib.rs @@ -45,7 +45,7 @@ impl std::fmt::Display for TransferId { impl TransferId { /// Returns the current transfer ID, incrementing `self` in-place. - pub(crate) fn next(&mut self) -> Self { + pub fn next_id(&mut self) -> Self { let id = *self; self.0 += 1; id diff --git a/crates/mpz-ot/examples/ferret.rs b/crates/mpz-ot/examples/ferret.rs new file mode 100644 index 00000000..f328e4d9 --- /dev/null +++ b/crates/mpz-ot/examples/ferret.rs @@ -0,0 +1 @@ +fn main() {} diff --git a/crates/mpz-ot/src/ferret/error.rs b/crates/mpz-ot/src/ferret/error.rs new file mode 100644 index 00000000..4e428a4b --- /dev/null +++ b/crates/mpz-ot/src/ferret/error.rs @@ -0,0 +1,342 @@ +use std::fmt::Display; + +/// Ferret sender error. +#[derive(Debug, thiserror::Error)] +pub struct SenderError { + kind: SenderErrorKind, + #[source] + source: Option>, +} + +impl SenderError { + pub(crate) fn state(msg: impl Into) -> Self { + Self { + kind: SenderErrorKind::State, + source: Some(msg.into().into()), + } + } + + pub(crate) fn io(msg: impl Into) -> Self { + Self { + kind: SenderErrorKind::Io, + source: Some(msg.into().into()), + } + } +} + +#[derive(Debug)] +enum SenderErrorKind { + Io, + State, + Core, + Rcot, + Mpcot, +} + +impl Display for SenderError { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match &self.kind { + SenderErrorKind::Io => f.write_str("io error")?, + SenderErrorKind::State => f.write_str("state error")?, + SenderErrorKind::Core => f.write_str("core error")?, + SenderErrorKind::Rcot => f.write_str("rcot error")?, + SenderErrorKind::Mpcot => f.write_str("mpcot error")?, + } + + if let Some(source) = &self.source { + write!(f, " caused by: {}", source) + } else { + Ok(()) + } + } +} + +impl From for SenderError { + fn from(err: std::io::Error) -> Self { + Self { + kind: SenderErrorKind::Io, + source: Some(Box::new(err)), + } + } +} + +impl From for SenderError { + fn from(err: mpz_ot_core::ferret::error::SenderError) -> Self { + Self { + kind: SenderErrorKind::Core, + source: Some(Box::new(err)), + } + } +} + +impl From for SenderError { + fn from(err: crate::OTError) -> Self { + Self { + kind: SenderErrorKind::Rcot, + source: Some(Box::new(err)), + } + } +} + +impl From for SenderError { + fn from(err: MPCOTError) -> Self { + Self { + kind: SenderErrorKind::Mpcot, + source: Some(Box::new(err)), + } + } +} + +impl From for crate::OTError { + fn from(err: SenderError) -> Self { + crate::OTError::SenderError(Box::new(err)) + } +} + +/// Ferret receiver error. +#[derive(Debug, thiserror::Error)] +pub struct ReceiverError { + kind: ReceiverErrorKind, + #[source] + source: Option>, +} + +impl ReceiverError { + pub(crate) fn state(msg: impl Into) -> Self { + Self { + kind: ReceiverErrorKind::State, + source: Some(msg.into().into()), + } + } + + pub(crate) fn io(msg: impl Into) -> Self { + Self { + kind: ReceiverErrorKind::Io, + source: Some(msg.into().into()), + } + } +} + +#[derive(Debug)] +enum ReceiverErrorKind { + Io, + State, + Core, + Rcot, + Mpcot, +} + +impl Display for ReceiverError { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match &self.kind { + ReceiverErrorKind::Io => f.write_str("io error")?, + ReceiverErrorKind::State => f.write_str("state error")?, + ReceiverErrorKind::Core => f.write_str("core error")?, + ReceiverErrorKind::Rcot => f.write_str("rcot error")?, + ReceiverErrorKind::Mpcot => f.write_str("mpcot error")?, + } + + if let Some(source) = &self.source { + write!(f, " caused by: {}", source) + } else { + Ok(()) + } + } +} + +impl From for ReceiverError { + fn from(err: std::io::Error) -> Self { + Self { + kind: ReceiverErrorKind::Io, + source: Some(Box::new(err)), + } + } +} + +impl From for ReceiverError { + fn from(err: mpz_ot_core::ferret::error::ReceiverError) -> Self { + Self { + kind: ReceiverErrorKind::Core, + source: Some(Box::new(err)), + } + } +} + +impl From for ReceiverError { + fn from(err: crate::OTError) -> Self { + Self { + kind: ReceiverErrorKind::Rcot, + source: Some(Box::new(err)), + } + } +} + +impl From for ReceiverError { + fn from(err: MPCOTError) -> Self { + Self { + kind: ReceiverErrorKind::Mpcot, + source: Some(Box::new(err)), + } + } +} + +impl From for crate::OTError { + fn from(err: ReceiverError) -> Self { + crate::OTError::ReceiverError(Box::new(err)) + } +} + +mod mpcot { + use super::*; + + /// MPCOT error. + #[derive(Debug, thiserror::Error)] + pub(crate) struct MPCOTError { + kind: ErrorKind, + #[source] + source: Option>, + } + + #[derive(Debug)] + enum ErrorKind { + Io, + Core, + Rcot, + Spcot, + } + + impl Display for MPCOTError { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match &self.kind { + ErrorKind::Io => f.write_str("io error")?, + ErrorKind::Core => f.write_str("core error")?, + ErrorKind::Rcot => f.write_str("rcot error")?, + ErrorKind::Spcot => f.write_str("spcot error")?, + } + + if let Some(source) = &self.source { + write!(f, " caused by: {}", source) + } else { + Ok(()) + } + } + } + + impl From for MPCOTError { + fn from(err: std::io::Error) -> Self { + Self { + kind: ErrorKind::Io, + source: Some(Box::new(err)), + } + } + } + + impl From for MPCOTError { + fn from(err: mpz_ot_core::ferret::mpcot::error::SenderError) -> Self { + Self { + kind: ErrorKind::Core, + source: Some(Box::new(err)), + } + } + } + + impl From for MPCOTError { + fn from(err: mpz_ot_core::ferret::mpcot::error::ReceiverError) -> Self { + Self { + kind: ErrorKind::Core, + source: Some(Box::new(err)), + } + } + } + + impl From for MPCOTError { + fn from(err: SPCOTError) -> Self { + Self { + kind: ErrorKind::Spcot, + source: Some(Box::new(err)), + } + } + } + + impl From for MPCOTError { + fn from(err: crate::OTError) -> Self { + Self { + kind: ErrorKind::Rcot, + source: Some(Box::new(err)), + } + } + } +} +pub(crate) use mpcot::MPCOTError; + +mod spcot { + use super::*; + + /// SPCOT error. + #[derive(Debug, thiserror::Error)] + pub(crate) struct SPCOTError { + kind: ErrorKind, + #[source] + source: Option>, + } + + #[derive(Debug)] + enum ErrorKind { + Io, + Core, + Rcot, + } + + impl Display for SPCOTError { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match &self.kind { + ErrorKind::Io => f.write_str("io error")?, + ErrorKind::Core => f.write_str("core error")?, + ErrorKind::Rcot => f.write_str("rcot error")?, + } + + if let Some(source) = &self.source { + write!(f, " caused by: {}", source) + } else { + Ok(()) + } + } + } + + impl From for SPCOTError { + fn from(err: std::io::Error) -> Self { + Self { + kind: ErrorKind::Io, + source: Some(Box::new(err)), + } + } + } + + impl From for SPCOTError { + fn from(err: mpz_ot_core::ferret::spcot::error::SenderError) -> Self { + Self { + kind: ErrorKind::Core, + source: Some(Box::new(err)), + } + } + } + + impl From for SPCOTError { + fn from(err: mpz_ot_core::ferret::spcot::error::ReceiverError) -> Self { + Self { + kind: ErrorKind::Core, + source: Some(Box::new(err)), + } + } + } + + impl From for SPCOTError { + fn from(err: crate::OTError) -> Self { + Self { + kind: ErrorKind::Rcot, + source: Some(Box::new(err)), + } + } + } +} +pub(crate) use spcot::SPCOTError; diff --git a/crates/mpz-ot/src/ferret/mod.rs b/crates/mpz-ot/src/ferret/mod.rs new file mode 100644 index 00000000..9d421885 --- /dev/null +++ b/crates/mpz-ot/src/ferret/mod.rs @@ -0,0 +1,256 @@ +//! An implementation of the [`Ferret`](https://eprint.iacr.org/2020/924.pdf) protocol. +mod error; +mod mpcot; +mod receiver; +mod sender; +mod spcot; + +pub use error::{ReceiverError, SenderError}; +pub use receiver::Receiver; +pub use sender::Sender; + +use mpz_core::lpn::LpnParameters; +use mpz_ot_core::ferret::LpnType; + +/// Configuration of Ferret. +#[derive(Debug, Clone)] +pub struct FerretConfig { + lpn_parameters: LpnParameters, + lpn_type: LpnType, +} + +impl FerretConfig { + /// Create a new instance. + /// + /// # Arguments. + /// + /// * `lpn_parameters` - The parameters of LPN. + /// * `lpn_type` - The type of LPN. + pub fn new(lpn_parameters: LpnParameters, lpn_type: LpnType) -> Self { + Self { + lpn_parameters, + lpn_type, + } + } + + /// Get the lpn type + pub fn lpn_type(&self) -> LpnType { + self.lpn_type + } + + /// Get the lpn parameters + pub fn lpn_parameters(&self) -> LpnParameters { + self.lpn_parameters + } +} + +/// Ferret config with regular LPN parameters. +/// Parameters for setup with small extension output. +pub const FERRET_REGULAR_SETUP_SMALL: FerretConfig = FerretConfig { + lpn_parameters: LpnParameters { + n: 102_400, + k: 6_750, + t: 1_600, + }, + lpn_type: LpnType::Regular, +}; + +/// Ferret config with regular LPN parameters. +/// Parameters for extension with small extension output. +pub const FERRET_REGULAR_EXTENSION_SMALL: FerretConfig = FerretConfig { + lpn_parameters: LpnParameters { + n: 1_740_800, + k: 66_400, + t: 1700, + }, + lpn_type: LpnType::Regular, +}; + +/// Ferret config with regular LPN parameters. +/// Parameters for setup with medium extension output. +pub const FERRET_REGULAR_SETUP_MEDIUM: FerretConfig = FerretConfig { + lpn_parameters: LpnParameters { + n: 283_648, + k: 18_584, + t: 1_108, + }, + lpn_type: LpnType::Regular, +}; + +/// Ferret config with regular LPN parameters. +/// Parameters for extension with medium extension output. +pub const FERRET_REGULAR_EXTENSION_MEDIUM: FerretConfig = FerretConfig { + lpn_parameters: LpnParameters { + n: 5_324_800, + k: 240_000, + t: 1_300, + }, + lpn_type: LpnType::Regular, +}; + +/// Ferret config with regular LPN parameters. +/// Parameters for setup with large extension output. +pub const FERRET_REGULAR_SETUP_LARGE: FerretConfig = FerretConfig { + lpn_parameters: LpnParameters { + n: 518_656, + k: 34_643, + t: 1_013, + }, + lpn_type: LpnType::Regular, +}; + +/// Ferret config with regular LPN parameters. +/// Parameters for extension with large extension output. +pub const FERRET_REGULAR_EXTENSION_LARGE: FerretConfig = FerretConfig { + lpn_parameters: LpnParameters { + n: 10_485_760, + k: 458_000, + t: 1280, + }, + lpn_type: LpnType::Regular, +}; + +/// Ferret config with uniform LPN parameters. +/// Parameters for setup with small extension output. +pub const FERRET_UNIFORM_SETUP_SMALL: FerretConfig = FerretConfig { + lpn_parameters: LpnParameters { + n: 98_000, + k: 4_450, + t: 1_600, + }, + lpn_type: LpnType::Uniform, +}; + +/// Ferret config with uniform LPN parameters. +/// Parameters for extension with small extension output. +pub const FERRET_UNIFORM_EXTENSION_SMALL: FerretConfig = FerretConfig { + lpn_parameters: LpnParameters { + n: 1_071_888, + k: 40_800, + t: 1720, + }, + lpn_type: LpnType::Uniform, +}; + +/// Ferret config with uniform LPN parameters. +/// Parameters for setup with medium extension output. +pub const FERRET_UNIFORM_SETUP_MEDIUM: FerretConfig = FerretConfig { + lpn_parameters: LpnParameters { + n: 283_648, + k: 18_584, + t: 1_108, + }, + lpn_type: LpnType::Uniform, +}; + +/// Ferret config with uniform LPN parameters. +/// Parameters for extension with medium extension output. +pub const FERRET_UNIFORM_EXTENSION_MEDIUM: FerretConfig = FerretConfig { + lpn_parameters: LpnParameters { + n: 5_324_800, + k: 240_000, + t: 1_300, + }, + lpn_type: LpnType::Uniform, +}; + +/// Ferret config with uniform LPN parameters. +/// Parameters for setup with large extension output. +pub const FERRET_UNIFORM_SETUP_LARGE: FerretConfig = FerretConfig { + lpn_parameters: LpnParameters { + n: 545_656, + k: 34_643, + t: 1_050, + }, + lpn_type: LpnType::Uniform, +}; + +/// Ferret config with uniform LPN parameters. +/// Parameters for extension with large extension output. +pub const FERRET_UNIFORM_EXTENSION_LARGE: FerretConfig = FerretConfig { + lpn_parameters: LpnParameters { + n: 10_488_928, + k: 458_000, + t: 1_280, + }, + lpn_type: LpnType::Uniform, +}; + +#[cfg(test)] +mod tests { + use super::*; + use futures::TryFutureExt as _; + use mpz_common::executor::test_st_executor; + use mpz_core::lpn::LpnParameters; + use mpz_ot_core::{ferret::LpnType, test::assert_cot, RCOTReceiverOutput, RCOTSenderOutput}; + use rstest::*; + + use crate::{ideal::cot::ideal_rcot, Correlation, OTError, RandomCOTReceiver, RandomCOTSender}; + + // l = n - k = 8380 + const LPN_PARAMETERS_TEST: LpnParameters = LpnParameters { + n: 9600, + k: 1220, + t: 600, + }; + + #[rstest] + #[case::uniform(LpnType::Uniform)] + #[case::regular(LpnType::Regular)] + #[tokio::test] + async fn test_ferret(#[case] lpn_type: LpnType) { + let (mut ctx_sender, mut ctx_receiver) = test_st_executor(8); + + let (rcot_sender, rcot_receiver) = ideal_rcot(); + + let config = FerretConfig::new(LPN_PARAMETERS_TEST, lpn_type); + + let mut sender = Sender::new(config.clone(), rcot_sender); + let mut receiver = Receiver::new(config, rcot_receiver); + + tokio::try_join!( + sender.setup(&mut ctx_sender).map_err(OTError::from), + receiver.setup(&mut ctx_receiver).map_err(OTError::from) + ) + .unwrap(); + + // extend once. + let count = LPN_PARAMETERS_TEST.k; + tokio::try_join!( + sender.extend(&mut ctx_sender, count).map_err(OTError::from), + receiver + .extend(&mut ctx_receiver, count) + .map_err(OTError::from) + ) + .unwrap(); + + // extend twice + let count = 10000; + tokio::try_join!( + sender.extend(&mut ctx_sender, count).map_err(OTError::from), + receiver + .extend(&mut ctx_receiver, count) + .map_err(OTError::from) + ) + .unwrap(); + + let ( + RCOTSenderOutput { + id: sender_id, + msgs: u, + }, + RCOTReceiverOutput { + id: receiver_id, + choices: b, + msgs: w, + }, + ) = tokio::try_join!( + sender.send_random_correlated(&mut ctx_sender, count), + receiver.receive_random_correlated(&mut ctx_receiver, count) + ) + .unwrap(); + + assert_eq!(sender_id, receiver_id); + assert_cot(sender.delta(), &b, &u, &w); + } +} diff --git a/crates/mpz-ot/src/ferret/mpcot.rs b/crates/mpz-ot/src/ferret/mpcot.rs new file mode 100644 index 00000000..be7de33a --- /dev/null +++ b/crates/mpz-ot/src/ferret/mpcot.rs @@ -0,0 +1,185 @@ +//! Implementation of the Multiple-Point COT (mpcot) protocol in the [`Ferret`](https://eprint.iacr.org/2020/924.pdf) paper. + +use mpz_common::{cpu::CpuBackend, Context}; +use mpz_core::{prg::Prg, Block}; +use mpz_ot_core::ferret::{ + mpcot::{ + msgs::HashSeed, receiver::Receiver as UniformReceiverCore, + receiver_regular::Receiver as RegularReceiverCore, sender::Sender as UniformSender, + sender_regular::Sender as RegularSender, + }, + LpnType, +}; +use serio::{stream::IoStreamExt as _, SinkExt as _}; + +use crate::{ + ferret::{error::MPCOTError as Error, spcot}, + RandomCOTReceiver, RandomCOTSender, +}; + +/// MPCOT send. +/// +/// # Arguments. +/// +/// * `ctx` - Thread context. +/// * `rcot` - Random COT sender. +/// * `delta` - Delta correlation. +/// * `lpn_type` - The type of LPN. +/// * `t` - The number of queried indices. +/// * `n` - The total number of indices. +pub(crate) async fn send>( + ctx: &mut Ctx, + rcot: &mut RandomCOT, + delta: Block, + lpn_type: LpnType, + t: u32, + n: u32, +) -> Result, Error> { + match lpn_type { + LpnType::Uniform => { + let hash_seed: HashSeed = ctx.io_mut().expect_next().await?; + + let (sender, hs) = CpuBackend::blocking(move || { + UniformSender::new() + .setup(delta, hash_seed) + .pre_extend(t, n) + }) + .await?; + + let st = spcot::send(ctx, rcot, delta, &hs).await?; + + let (_, output) = CpuBackend::blocking(move || sender.extend(&st)).await?; + + Ok(output) + } + LpnType::Regular => { + let (sender, hs) = + CpuBackend::blocking(move || RegularSender::new().setup(delta).pre_extend(t, n)) + .await?; + + let st = spcot::send(ctx, rcot, delta, &hs).await?; + + let (_, output) = CpuBackend::blocking(move || sender.extend(&st)).await?; + + Ok(output) + } + } +} + +/// MPCOT receive. +/// +/// # Arguments +/// +/// * `ctx` - Thread context. +/// * `rcot` - Random COT receiver. +/// * `lpn_type` - The type of LPN. +/// * `alphas` - The queried indices. +/// * `n` - The total number of indices. +pub(crate) async fn receive>( + ctx: &mut Ctx, + rcot: &mut RandomCOT, + lpn_type: LpnType, + alphas: Vec, + n: u32, +) -> Result, Error> { + match lpn_type { + LpnType::Uniform => { + let hash_seed = Prg::new().random_block(); + + let (receiver, hash_seed) = UniformReceiverCore::new().setup(hash_seed); + + ctx.io_mut().send(hash_seed).await?; + + let (receiver, h_and_pos) = + CpuBackend::blocking(move || receiver.pre_extend(&alphas, n)).await?; + + let mut hs = vec![0usize; h_and_pos.len()]; + + let mut pos = vec![0u32; h_and_pos.len()]; + for (index, (h, p)) in h_and_pos.iter().enumerate() { + hs[index] = *h; + pos[index] = *p; + } + + let rt = spcot::receive(ctx, rcot, &pos, &hs).await?; + let rt: Vec> = rt.into_iter().map(|(elem, _)| elem).collect(); + let (_, output) = CpuBackend::blocking(move || receiver.extend(&rt)).await?; + + Ok(output) + } + LpnType::Regular => { + let receiver = RegularReceiverCore::new().setup(); + + let (receiver, h_and_pos) = + CpuBackend::blocking(move || receiver.pre_extend(&alphas, n)).await?; + + let mut hs = vec![0usize; h_and_pos.len()]; + + let mut pos = vec![0u32; h_and_pos.len()]; + for (index, (h, p)) in h_and_pos.iter().enumerate() { + hs[index] = *h; + pos[index] = *p; + } + + let rt = spcot::receive(ctx, rcot, &pos, &hs).await?; + let rt: Vec> = rt.into_iter().map(|(elem, _)| elem).collect(); + let (_, output) = CpuBackend::blocking(move || receiver.extend(&rt)).await?; + + Ok(output) + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::ideal::cot::ideal_rcot; + use mpz_common::executor::test_st_executor; + use mpz_ot_core::ferret::LpnType; + use rstest::*; + + #[rstest] + #[case(LpnType::Uniform)] + #[case(LpnType::Regular)] + #[tokio::test] + async fn test_mpcot(#[case] lpn_type: LpnType) { + use crate::Correlation; + + let (mut ctx_sender, mut ctx_receiver) = test_st_executor(8); + let (mut rcot_sender, mut rcot_receiver) = ideal_rcot(); + + let alphas = match lpn_type { + LpnType::Uniform => vec![0, 1, 3, 4, 2], + LpnType::Regular => vec![0, 3, 4, 7, 9], + }; + + let t = alphas.len(); + let n = 10; + let delta = rcot_sender.delta(); + + let (mut output_sender, output_receiver) = tokio::try_join!( + send( + &mut ctx_sender, + &mut rcot_sender, + delta, + lpn_type, + t as u32, + n + ), + receive( + &mut ctx_receiver, + &mut rcot_receiver, + lpn_type, + alphas.clone(), + n + ) + ) + .unwrap(); + + for i in alphas { + output_sender[i as usize] ^= delta; + } + + assert_eq!(output_sender, output_receiver); + } +} diff --git a/crates/mpz-ot/src/ferret/receiver.rs b/crates/mpz-ot/src/ferret/receiver.rs new file mode 100644 index 00000000..fbbb38eb --- /dev/null +++ b/crates/mpz-ot/src/ferret/receiver.rs @@ -0,0 +1,253 @@ +use std::mem; + +use async_trait::async_trait; +use mpz_common::{cpu::CpuBackend, Allocate, Context, Preprocess}; +use mpz_core::{prg::Prg, Block}; +use mpz_ot_core::{ + ferret::{ + receiver::{state, Receiver as ReceiverCore}, + LpnType, CSP, CUCKOO_HASH_NUM, + }, + RCOTReceiverOutput, +}; +use serio::SinkExt; + +use crate::{ + ferret::{mpcot, FerretConfig, ReceiverError}, + OTError, RandomCOTReceiver, +}; + +#[derive(Debug)] +pub(crate) enum State { + Initialized(Box>), + Extension(Box>), + Error, +} + +impl State { + fn take(&mut self) -> Self { + std::mem::replace(self, State::Error) + } +} + +/// Ferret Receiver. +#[derive(Debug)] +pub struct Receiver { + state: State, + config: FerretConfig, + rcot: RandomCOT, + alloc: usize, + buffer: ReceiverBuffer, + buffer_len: usize, +} + +impl Receiver { + /// Creates a new Receiver. + /// + /// # Arguments. + /// + /// * `config` - The Ferret config. + /// * `rcot` - The random COT in setup. + pub fn new(config: FerretConfig, rcot: RandomCOT) -> Self { + Self { + state: State::Initialized(Box::new(ReceiverCore::new())), + config, + rcot, + alloc: 0, + buffer: Default::default(), + buffer_len: 0, + } + } + + /// Setup for receiver. + /// + /// # Arguments. + /// + /// * `ctx` - The channel context. + pub async fn setup(&mut self, ctx: &mut Ctx) -> Result<(), ReceiverError> + where + Ctx: Context, + RandomCOT: RandomCOTReceiver, + { + let State::Initialized(receiver) = self.state.take() else { + return Err(ReceiverError::state("receiver not in initialized state")); + }; + + let params = self.config.lpn_parameters(); + let lpn_type = self.config.lpn_type(); + + // Compute the number of buffered OTs. + self.buffer_len = match lpn_type { + // The number here is a rough estimation to ensure sufficient buffer. + // It is hard to precisely compute the number because of the Cuckoo hashes. + LpnType::Uniform => { + let m = (1.5 * (params.t as f32)).ceil() as usize; + m * ((2 * CUCKOO_HASH_NUM * params.n / m) + .checked_next_power_of_two() + .expect("The length should be less than usize::MAX / 2 - 1") + .ilog2() as usize) + + CSP + } + // In our chosen paramters, we always set n is divided by t and n/t is a power of 2. + LpnType::Regular => { + assert!(params.n % params.t == 0 && (params.n / params.t).is_power_of_two()); + params.t * ((params.n / params.t).ilog2() as usize) + CSP + } + }; + + // Get random blocks from ideal Random COT. + let RCOTReceiverOutput { + choices: mut u, + msgs: mut w, + id, + } = self + .rcot + .receive_random_correlated(ctx, params.k + self.buffer_len) + .await?; + + // Initiate buffer. + let buffer = RCOTReceiverOutput { + id, + choices: u.drain(0..self.buffer_len).collect(), + msgs: w.drain(0..self.buffer_len).collect(), + }; + self.buffer = ReceiverBuffer::new(buffer); + + let seed = Prg::new().random_block(); + + let (receiver, seed) = receiver.setup(params, lpn_type, seed, &u, &w)?; + + ctx.io_mut().send(seed).await?; + + self.state = State::Extension(Box::new(receiver)); + + Ok(()) + } + + /// Performs extension. + /// + /// # Arguments + /// + /// * `ctx` - Thread context. + /// * `count` - The number of OTs to extend. + pub async fn extend(&mut self, ctx: &mut Ctx, count: usize) -> Result<(), ReceiverError> + where + Ctx: Context, + RandomCOT: RandomCOTReceiver + Send, + { + let State::Extension(mut receiver) = self.state.take() else { + return Err(ReceiverError::state("receiver not in extension state")); + }; + + let lpn_type = self.config.lpn_type(); + let target = receiver.remaining() + count; + while receiver.remaining() < target { + let (alphas, n) = receiver.get_mpcot_query(); + + let r = mpcot::receive(ctx, &mut self.buffer, lpn_type, alphas, n as u32).await?; + + receiver = CpuBackend::blocking(move || receiver.extend(r).map(|()| receiver)).await?; + + // Update receiver buffer. + let buffer = receiver + .consume(self.buffer_len) + .map_err(ReceiverError::from) + .map_err(OTError::from)?; + + self.buffer = ReceiverBuffer::new(buffer); + } + + self.state = State::Extension(receiver); + + Ok(()) + } +} + +#[async_trait] +impl RandomCOTReceiver for Receiver +where + RandomCOT: Send, +{ + async fn receive_random_correlated( + &mut self, + _ctx: &mut Ctx, + count: usize, + ) -> Result, OTError> { + let State::Extension(receiver) = &mut self.state else { + return Err(ReceiverError::state("receiver not in extension state").into()); + }; + + receiver + .consume(count) + .map_err(ReceiverError::from) + .map_err(OTError::from) + } +} + +impl Allocate for Receiver { + fn alloc(&mut self, count: usize) { + self.alloc += count; + } +} + +#[async_trait] +impl Preprocess for Receiver +where + Ctx: Context, + RandomCOT: RandomCOTReceiver + Send, +{ + type Error = ReceiverError; + + async fn preprocess(&mut self, ctx: &mut Ctx) -> Result<(), Self::Error> { + let count = mem::take(&mut self.alloc); + self.extend(ctx, count).await + } +} + +#[derive(Debug)] +struct ReceiverBuffer { + buffer: RCOTReceiverOutput, +} + +impl ReceiverBuffer { + fn new(buffer: RCOTReceiverOutput) -> Self { + Self { buffer } + } +} + +impl Default for ReceiverBuffer { + fn default() -> Self { + ReceiverBuffer { + buffer: RCOTReceiverOutput { + id: Default::default(), + choices: Vec::new(), + msgs: Vec::new(), + }, + } + } +} + +#[async_trait] +impl RandomCOTReceiver for ReceiverBuffer { + async fn receive_random_correlated( + &mut self, + _ctx: &mut Ctx, + count: usize, + ) -> Result, OTError> { + if count > self.buffer.choices.len() { + return Err(ReceiverError::io(format!( + "insufficient OTs: {} < {count}", + self.buffer.choices.len() + )) + .into()); + } + + let choices = self.buffer.choices.drain(0..count).collect(); + let msgs = self.buffer.msgs.drain(0..count).collect(); + Ok(RCOTReceiverOutput { + id: self.buffer.id.next_id(), + choices, + msgs, + }) + } +} diff --git a/crates/mpz-ot/src/ferret/sender.rs b/crates/mpz-ot/src/ferret/sender.rs new file mode 100644 index 00000000..02884b2c --- /dev/null +++ b/crates/mpz-ot/src/ferret/sender.rs @@ -0,0 +1,294 @@ +use std::mem; + +use crate::{ferret::mpcot, Correlation, RandomCOTSender}; +use async_trait::async_trait; +use mpz_common::{cpu::CpuBackend, Allocate, Context, Preprocess}; +use mpz_core::Block; +use mpz_ot_core::{ + ferret::{ + sender::{state, Sender as SenderCore}, + LpnType, CSP, CUCKOO_HASH_NUM, + }, + RCOTSenderOutput, +}; +use serio::stream::IoStreamExt; + +use super::{FerretConfig, SenderError}; +use crate::OTError; + +#[derive(Debug)] +pub(crate) enum State { + Initialized(SenderCore), + Extension(SenderCore), + Error, +} + +impl State { + fn take(&mut self) -> Self { + std::mem::replace(self, State::Error) + } +} + +/// Ferret Sender. +#[derive(Debug)] +pub struct Sender { + state: State, + config: FerretConfig, + rcot: RandomCOT, + alloc: usize, + buffer: SenderBuffer, + buffer_len: usize, +} + +impl Sender { + /// Creates a new Sender. + /// + /// # Argument + /// + /// `config` - The Ferret config. + /// `rcot` - The random COT in setup. + pub fn new(config: FerretConfig, rcot: RandomCOT) -> Self { + Self { + state: State::Initialized(SenderCore::new()), + config, + rcot, + alloc: 0, + buffer: Default::default(), + buffer_len: 0, + } + } + + /// Setup with provided delta. + /// + /// # Argument + /// + /// * `ctx` - The channel context. + pub async fn setup(&mut self, ctx: &mut Ctx) -> Result<(), SenderError> + where + Ctx: Context, + RandomCOT: RandomCOTSender + Correlation, + { + let State::Initialized(sender) = self.state.take() else { + return Err(SenderError::state("sender not in initialized state")); + }; + + let params = self.config.lpn_parameters(); + let lpn_type = self.config.lpn_type(); + + // Compute the number of buffered OTs. + self.buffer_len = match lpn_type { + // The number here is a rough estimation to ensure sufficient buffer. + // It is hard to precisely compute the number because of the Cuckoo hashes. + LpnType::Uniform => { + let m = (1.5 * (params.t as f32)).ceil() as usize; + m * ((2 * CUCKOO_HASH_NUM * params.n / m) + .checked_next_power_of_two() + .expect("The length should be less than usize::MAX / 2 - 1") + .ilog2() as usize) + + CSP + } + // In our chosen paramters, we always set n is divided by t and n/t is a power of 2. + LpnType::Regular => { + assert!(params.n % params.t == 0 && (params.n / params.t).is_power_of_two()); + params.t * ((params.n / params.t).ilog2() as usize) + CSP + } + }; + + // Get random blocks from ideal Random COT. + let RCOTSenderOutput { msgs: mut v, id } = self + .rcot + .send_random_correlated(ctx, params.k + self.buffer_len) + .await?; + + // Initiate buffer. + let buffer = RCOTSenderOutput { + id, + msgs: v.drain(0..self.buffer_len).collect(), + }; + self.buffer = SenderBuffer::new(self.rcot.delta(), buffer); + + // Get seed for LPN matrix from receiver. + let seed = ctx.io_mut().expect_next().await?; + + // Ferret core setup. + let sender = sender.setup(self.rcot.delta(), params, lpn_type, seed, &v)?; + + self.state = State::Extension(sender); + + Ok(()) + } + + /// Performs extension. + /// + /// # Argument + /// + /// * `ctx` - Thread context. + /// * `count` - The number of OTs to extend. + pub async fn extend( + &mut self, + ctx: &mut Ctx, + count: usize, + ) -> Result<(), SenderError> + where + RandomCOT: RandomCOTSender + Send, + { + let State::Extension(mut sender) = self.state.take() else { + return Err(SenderError::state("sender not in extension state")); + }; + + let lpn_type = self.config.lpn_type(); + let delta = sender.delta(); + let target = sender.remaining() + count; + while sender.remaining() < target { + let (t, n) = sender.get_mpcot_query(); + + let s = mpcot::send(ctx, &mut self.buffer, delta, lpn_type, t, n).await?; + + sender = CpuBackend::blocking(move || sender.extend(s).map(|()| sender)).await?; + + // Update sender buffer. + let buffer = sender + .consume(self.buffer_len) + .map_err(SenderError::from) + .map_err(OTError::from)?; + + self.buffer = SenderBuffer::new(delta, buffer); + } + + self.state = State::Extension(sender); + + Ok(()) + } +} + +impl Correlation for Sender +where + RandomCOT: Correlation, +{ + type Correlation = Block; + + fn delta(&self) -> Self::Correlation { + self.rcot.delta() + } +} + +#[async_trait] +impl RandomCOTSender for Sender +where + RandomCOT: Correlation + Send, +{ + async fn send_random_correlated( + &mut self, + _ctx: &mut Ctx, + count: usize, + ) -> Result, OTError> { + let State::Extension(sender) = &mut self.state else { + return Err(SenderError::state("sender not in extension state").into()); + }; + + sender + .consume(count) + .map_err(SenderError::from) + .map_err(OTError::from) + } +} + +impl Allocate for Sender { + fn alloc(&mut self, count: usize) { + self.alloc += count; + } +} + +#[async_trait] +impl Preprocess for Sender +where + Ctx: Context, + RandomCOT: RandomCOTSender + Send, +{ + type Error = SenderError; + + async fn preprocess(&mut self, ctx: &mut Ctx) -> Result<(), Self::Error> { + let count = mem::take(&mut self.alloc); + self.extend(ctx, count).await + } +} + +#[derive(Debug)] +struct SenderBuffer { + delta: Block, + buffer: RCOTSenderOutput, +} + +impl SenderBuffer { + fn new(delta: Block, buffer: RCOTSenderOutput) -> Self { + Self { delta, buffer } + } +} + +impl Default for SenderBuffer { + fn default() -> Self { + let buffer = RCOTSenderOutput { + id: Default::default(), + msgs: Vec::new(), + }; + Self { + delta: Block::ZERO, + buffer, + } + } +} +impl Correlation for SenderBuffer { + type Correlation = Block; + + fn delta(&self) -> Self::Correlation { + self.delta + } +} + +#[async_trait] +impl RandomCOTSender for SenderBuffer { + async fn send_random_correlated( + &mut self, + _ctx: &mut Ctx, + count: usize, + ) -> Result, OTError> { + if count > self.buffer.msgs.len() { + return Err(SenderError::io(format!( + "insufficient OTs: {} < {count}", + self.buffer.msgs.len() + )) + .into()); + } + + let msgs = self.buffer.msgs.drain(0..count).collect(); + Ok(RCOTSenderOutput { + id: self.buffer.id.next_id(), + msgs, + }) + } +} + +#[derive(Debug)] +struct BootstrappedSender<'a>(&'a mut SenderCore); + +impl Correlation for BootstrappedSender<'_> { + type Correlation = Block; + + fn delta(&self) -> Block { + self.0.delta() + } +} + +#[async_trait] +impl RandomCOTSender for BootstrappedSender<'_> { + async fn send_random_correlated( + &mut self, + _ctx: &mut Ctx, + count: usize, + ) -> Result, OTError> { + self.0 + .consume(count) + .map_err(SenderError::from) + .map_err(OTError::from) + } +} diff --git a/crates/mpz-ot/src/ferret/spcot.rs b/crates/mpz-ot/src/ferret/spcot.rs new file mode 100644 index 00000000..e63a1aa9 --- /dev/null +++ b/crates/mpz-ot/src/ferret/spcot.rs @@ -0,0 +1,161 @@ +//! Implementation of the Single-Point COT (spcot) protocol in the [`Ferret`](https://eprint.iacr.org/2020/924.pdf) paper. + +use mpz_common::{cpu::CpuBackend, Context}; +use mpz_core::Block; +use mpz_ot_core::{ + ferret::{ + spcot::{ + msgs::{ExtendFromSender, MaskBits}, + receiver::Receiver as ReceiverCore, + sender::Sender as SenderCore, + }, + CSP, + }, + RCOTReceiverOutput, RCOTSenderOutput, +}; +use serio::{stream::IoStreamExt as _, SinkExt as _}; + +use crate::{ferret::error::SPCOTError as Error, RandomCOTReceiver, RandomCOTSender}; + +/// SPCOT send. +/// +/// # Arguments +/// +/// * `ctx` - Thread context. +/// * `rcot` - Random COT sender. +/// * `delta` - Delta correlation. +/// * `hs` - The depth of the GGM trees. +pub(crate) async fn send>( + ctx: &mut Ctx, + rcot: &mut RandomCOT, + delta: Block, + hs: &[usize], +) -> Result>, Error> { + let mut sender = SenderCore::new().setup(delta); + + let h = hs.iter().sum(); + let RCOTSenderOutput { msgs: qss, .. } = rcot.send_random_correlated(ctx, h).await?; + + let masks: Vec = ctx.io_mut().expect_next().await?; + + // extend + let h_in = hs.to_vec(); + let (mut sender, extend_msg) = CpuBackend::blocking(move || { + sender + .extend(&h_in, &qss, &masks) + .map(|extend_msg| (sender, extend_msg)) + }) + .await?; + + ctx.io_mut().send(extend_msg).await?; + + // batch check + let RCOTSenderOutput { msgs: y_star, .. } = rcot.send_random_correlated(ctx, CSP).await?; + + let checkfr = ctx.io_mut().expect_next().await?; + + let (output, check_msg) = CpuBackend::blocking(move || sender.check(&y_star, checkfr)).await?; + + ctx.io_mut().send(check_msg).await?; + + Ok(output) +} + +/// SPCOT receive. +/// +/// # Arguments +/// +/// * `ctx` - Thread context. +/// * `rcot` - Random COT receiver. +/// * `alphas` - Vector of chosen positions. +/// * `hs` - The depth of the GGM trees. +pub(crate) async fn receive>( + ctx: &mut Ctx, + rcot: &mut RandomCOT, + alphas: &[u32], + hs: &[usize], +) -> Result, u32)>, Error> { + let mut receiver = ReceiverCore::new().setup(); + + let h = hs.iter().sum(); + let RCOTReceiverOutput { + choices: rss, + msgs: tss, + .. + } = rcot.receive_random_correlated(ctx, h).await?; + + // extend + let h_in = hs.to_vec(); + let alphas_in = alphas.to_vec(); + let (mut receiver, masks) = CpuBackend::blocking(move || { + receiver + .extend_mask_bits(&h_in, &alphas_in, &rss) + .map(|mask| (receiver, mask)) + }) + .await?; + + ctx.io_mut().send(masks).await?; + + let extendfss: Vec = ctx.io_mut().expect_next().await?; + + let h_in = hs.to_vec(); + let alphas_in = alphas.to_vec(); + let mut receiver = CpuBackend::blocking(move || { + receiver + .extend(&h_in, &alphas_in, &tss, &extendfss) + .map(|_| receiver) + }) + .await?; + + // batch check + let RCOTReceiverOutput { + choices: x_star, + msgs: z_star, + .. + } = rcot.receive_random_correlated(ctx, CSP).await?; + + let (mut receiver, checkfr) = CpuBackend::blocking(move || { + receiver + .check_pre(&x_star) + .map(|checkfr| (receiver, checkfr)) + }) + .await?; + + ctx.io_mut().send(checkfr).await?; + let check = ctx.io_mut().expect_next().await?; + + let output = CpuBackend::blocking(move || receiver.check(&z_star, check)).await?; + + Ok(output) +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::{ideal::cot::ideal_rcot, Correlation}; + use mpz_common::executor::test_st_executor; + + #[tokio::test] + async fn test_spcot() { + let (mut ctx_sender, mut ctx_receiver) = test_st_executor(8); + let (mut rcot_sender, mut rcot_receiver) = ideal_rcot(); + + let hs = [8usize, 4]; + let alphas = [4u32, 2]; + let delta = rcot_sender.delta(); + + let (mut output_sender, output_receiver) = tokio::try_join!( + send(&mut ctx_sender, &mut rcot_sender, delta, &hs), + receive(&mut ctx_receiver, &mut rcot_receiver, &alphas, &hs) + ) + .unwrap(); + + assert!(output_sender + .iter_mut() + .zip(output_receiver.iter()) + .all(|(vs, (ws, alpha))| { + vs[*alpha as usize] ^= delta; + vs == ws + })); + } +} diff --git a/crates/mpz-ot/src/ideal/cot.rs b/crates/mpz-ot/src/ideal/cot.rs index b0084957..aa441c60 100644 --- a/crates/mpz-ot/src/ideal/cot.rs +++ b/crates/mpz-ot/src/ideal/cot.rs @@ -11,7 +11,9 @@ use mpz_ot_core::{ ideal::cot::IdealCOT, COTReceiverOutput, COTSenderOutput, RCOTReceiverOutput, RCOTSenderOutput, }; -use crate::{COTReceiver, COTSender, OTError, OTSetup, RandomCOTReceiver, RandomCOTSender}; +use crate::{ + COTReceiver, COTSender, Correlation, OTError, OTSetup, RandomCOTReceiver, RandomCOTSender, +}; fn cot( f: &mut IdealCOT, @@ -45,10 +47,26 @@ pub fn ideal_rcot() -> (IdealCOTSender, IdealCOTReceiver) { (IdealCOTSender(alice), IdealCOTReceiver(bob)) } +/// Returns an ideal random COT sender and receiver with a given delta. +pub fn ideal_rcot_with_delta(delta: Block) -> (IdealCOTSender, IdealCOTReceiver) { + let (alice, bob) = ideal_f2p(IdealCOT::new( + mpz_core::prg::Prg::new().random_block(), + delta, + )); + (IdealCOTSender(alice), IdealCOTReceiver(bob)) +} + /// Ideal COT sender. -#[derive(Debug, Clone)] +#[derive(Debug, Clone, Default)] pub struct IdealCOTSender(Alice); +impl IdealCOTSender { + /// Returns Alice. + pub fn alice(&mut self) -> &mut Alice { + &mut self.0 + } +} + #[async_trait] impl OTSetup for IdealCOTSender where @@ -75,6 +93,14 @@ where } } +impl Correlation for IdealCOTSender { + type Correlation = Block; + + fn delta(&self) -> Block { + self.0.lock().delta() + } +} + #[async_trait] impl COTSender for IdealCOTSender { async fn send_correlated( @@ -98,7 +124,7 @@ impl RandomCOTSender for IdealCOTSender { } /// Ideal COT receiver. -#[derive(Debug, Clone)] +#[derive(Debug, Clone, Default)] pub struct IdealCOTReceiver(Bob); #[async_trait] @@ -163,7 +189,7 @@ mod tests { let (mut ctx_a, mut ctx_b) = test_st_executor(8); let (mut alice, mut bob) = ideal_cot(); - let delta = alice.0.get_mut().delta(); + let delta = alice.delta(); let count = 10; let choices = (0..count).map(|_| rng.gen()).collect::>(); @@ -194,7 +220,7 @@ mod tests { let (mut ctx_a, mut ctx_b) = test_st_executor(8); let (mut alice, mut bob) = ideal_rcot(); - let delta = alice.0.get_mut().delta(); + let delta = alice.delta(); let count = 10; diff --git a/crates/mpz-ot/src/lib.rs b/crates/mpz-ot/src/lib.rs index b9871eab..667fe97b 100644 --- a/crates/mpz-ot/src/lib.rs +++ b/crates/mpz-ot/src/lib.rs @@ -10,7 +10,7 @@ )] pub mod chou_orlandi; -#[cfg(any(test, feature = "ideal"))] +pub mod ferret; pub mod ideal; pub mod kos; @@ -60,9 +60,18 @@ pub trait OTSender { async fn send(&mut self, ctx: &mut Ctx, msgs: &[T]) -> Result; } +/// Correlation of COT messages. +pub trait Correlation { + /// The type of the correlation. + type Correlation; + + /// Returns the correlation. + fn delta(&self) -> Self::Correlation; +} + /// A correlated oblivious transfer sender. #[async_trait] -pub trait COTSender { +pub trait COTSender: Correlation { /// Obliviously transfers the correlated messages to the receiver. /// /// Returns the `0`-bit messages that were obliviously transferred. @@ -96,7 +105,7 @@ pub trait RandomOTSender { /// A random correlated oblivious transfer sender. #[async_trait] -pub trait RandomCOTSender { +pub trait RandomCOTSender: Correlation { /// Obliviously transfers the correlated messages to the receiver. /// /// Returns the `0`-bit messages that were obliviously transferred. diff --git a/crates/mpz-zk-core/Cargo.toml b/crates/mpz-zk-core/Cargo.toml new file mode 100644 index 00000000..936f9dc8 --- /dev/null +++ b/crates/mpz-zk-core/Cargo.toml @@ -0,0 +1,39 @@ +[package] +name = "mpz-zk-core" +version = "0.1.0" +edition = "2021" + +# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html + +[lints] +workspace = true + +[lib] +name = "mpz_zk_core" + +[features] +default = ["rayon", "test-utils"] +rayon = ["dep:rayon", "itybity/rayon"] +test-utils = [] + +[dependencies] +mpz-core.workspace = true +mpz-ot-core.workspace = true +mpz-circuits.workspace = true +clmul.workspace = true +matrix-transpose.workspace = true + +tlsn-utils.workspace = true + +serde_arrays.workspace = true +rayon = { workspace = true, optional = true } +serde = { workspace = true, features = ["derive"] } +thiserror.workspace = true +derive_builder.workspace = true +itybity.workspace = true +opaque-debug.workspace = true +cfg-if.workspace = true +bytemuck = { workspace = true, features = ["derive"] } +enum-try-as-inner.workspace = true +blake3.workspace = true +rand_core.workspace = true diff --git a/crates/mpz-zk-core/src/ideal/mod.rs b/crates/mpz-zk-core/src/ideal/mod.rs new file mode 100644 index 00000000..b8436c2b --- /dev/null +++ b/crates/mpz-zk-core/src/ideal/mod.rs @@ -0,0 +1,3 @@ +//! Ideal functionalities. + +pub mod vope; diff --git a/crates/mpz-zk-core/src/ideal/vope.rs b/crates/mpz-zk-core/src/ideal/vope.rs new file mode 100644 index 00000000..14cc2a79 --- /dev/null +++ b/crates/mpz-zk-core/src/ideal/vope.rs @@ -0,0 +1,106 @@ +//! Ideal VOPE functionality. + +use std::iter::successors; + +use mpz_core::{prg::Prg, Block}; +use mpz_ot_core::TransferId; +use rand_core::SeedableRng; + +use crate::{VOPEReceiverOutput, VOPESenderOutput}; + +/// The ideal VOPE functionality. +#[derive(Debug)] +pub struct IdealVOPE { + delta: Block, + transfer_id: TransferId, + counter: usize, + prg: Prg, +} + +impl IdealVOPE { + /// Creates a new ideal VOPE functionality. + /// + /// # Arguments + /// + /// * `seed` - The seed for the PRG. + /// * `delta` - The correlation. + pub fn new(seed: Block, delta: Block) -> Self { + Self { + delta, + transfer_id: TransferId::default(), + counter: 0, + prg: Prg::from_seed(seed), + } + } + + /// Returns the correlation, delta. + pub fn delta(&self) -> Block { + self.delta + } + + /// Sets the correlation, delta. + pub fn set_delta(&mut self, delta: Block) { + self.delta = delta; + } + + /// Returns the current transfer id. + pub fn transfer_id(&self) -> TransferId { + self.transfer_id + } + + /// Returns the number of VOPE executed. + pub fn count(&self) -> usize { + self.counter + } + + /// Executes the VOPE. + /// + /// # Arguments + /// + /// * `degree` - The degree of the polynomnial. + pub fn random_correlated( + &mut self, + degree: usize, + ) -> (VOPESenderOutput, VOPEReceiverOutput) { + let mut coeff = vec![Block::ZERO; degree + 1]; + self.prg.random_blocks(&mut coeff); + + let powers: Vec = successors(Some(Block::ONE), |pow| Some(pow.gfmul(self.delta))) + .take(degree + 1) + .collect(); + + let eval = Block::inn_prdt_red(&coeff, &powers); + + self.counter += 1; + let id = self.transfer_id.next_id(); + + ( + VOPESenderOutput { id, eval }, + VOPEReceiverOutput { id, coeff }, + ) + } +} + +impl Default for IdealVOPE { + fn default() -> Self { + let mut rng = Prg::from_seed(Block::ZERO); + Self::new(rng.random_block(), rng.random_block()) + } +} + +#[cfg(test)] +mod tests { + use crate::{test::poly_check, VOPEReceiverOutput, VOPESenderOutput}; + + use super::IdealVOPE; + + #[test] + fn test_ideal_vope() { + let mut ideal = IdealVOPE::default(); + + let (VOPESenderOutput { eval, .. }, VOPEReceiverOutput { coeff, .. }) = + ideal.random_correlated(10); + + assert!(poly_check(&coeff, eval, ideal.delta())); + } +} diff --git a/crates/mpz-zk-core/src/lib.rs b/crates/mpz-zk-core/src/lib.rs new file mode 100644 index 00000000..de438fb6 --- /dev/null +++ b/crates/mpz-zk-core/src/lib.rs @@ -0,0 +1,45 @@ +//! Low-level crate containing core functionalities for zero-knowledge protocols. +//! +//! This crate is not intended to be used directly. Instead, use the higher-level APIs provided by +//! the `mpz-zk` crate. +//! +//! # ⚠️ Warning ⚠️ +//! +//! Some implementations make assumptions about invariants which may not be checked if using these +//! low-level APIs naively. Failing to uphold these invariants may result in security vulnerabilities. +//! +//! USE AT YOUR OWN RISK. + +#![deny( + unsafe_code, + missing_docs, + unused_imports, + unused_must_use, + unreachable_pub, + clippy::all +)] + +use mpz_ot_core::TransferId; + +pub mod ideal; +pub mod quicksilver; +pub mod test; +pub mod vope; + +/// The output the receiver receives from the VOPE functionality. +#[derive(Debug)] +pub struct VOPEReceiverOutput { + /// The transfer id. + pub id: TransferId, + /// The coefficients. + pub coeff: Vec, +} + +/// The output the sender receives from the VOPE functinality. +#[derive(Debug)] +pub struct VOPESenderOutput { + /// The transfer id. + pub id: TransferId, + /// The evaluation value. + pub eval: T, +} diff --git a/crates/mpz-zk-core/src/quicksilver/error.rs b/crates/mpz-zk-core/src/quicksilver/error.rs new file mode 100644 index 00000000..468a7d34 --- /dev/null +++ b/crates/mpz-zk-core/src/quicksilver/error.rs @@ -0,0 +1,11 @@ +//! Errors in QuickSilver. + +/// Errors that can occur during proving +#[derive(Debug, thiserror::Error)] +#[error("invalid inputs: expect {0}")] +pub struct QsProverError(pub String); + +/// Errors that can occur during verifying +#[derive(Debug, thiserror::Error)] +#[error("invalid inputs: expect {0}")] +pub struct QsVerifierError(pub String); diff --git a/crates/mpz-zk-core/src/quicksilver/mod.rs b/crates/mpz-zk-core/src/quicksilver/mod.rs new file mode 100644 index 00000000..89783a9d --- /dev/null +++ b/crates/mpz-zk-core/src/quicksilver/mod.rs @@ -0,0 +1,110 @@ +//! This is the implementation of QuickSilver (https://eprint.iacr.org/2021/076.pdf). + +mod error; +mod prover; +mod verifier; + +pub use error::*; +pub use prover::Prover; +pub use verifier::Verifier; + +/// Buffer size of each check. +pub(crate) const CHECK_BUFFER_SIZE: usize = 1024 * 1024; + +/// Convert bool vector to byte vector. +#[inline] +pub fn bools_to_bytes(bv: &[bool]) -> Vec { + let offset = if bv.len() % 8 == 0 { 0 } else { 1 }; + let mut v = vec![0u8; bv.len() / 8 + offset]; + for (i, b) in bv.iter().enumerate() { + v[i / 8] |= (*b as u8) << (7 - (i % 8)); + } + v +} + +/// Convert byte vector to bool vector. +#[inline] +pub fn bytes_to_bools(v: &[u8]) -> Vec { + let mut bv = Vec::with_capacity(v.len() * 8); + for byte in v.iter() { + for i in 0..8 { + bv.push(((byte >> (7 - i)) & 1) != 0); + } + } + bv +} + +#[cfg(test)] +mod tests { + use mpz_core::{prg::Prg, Block}; + use mpz_ot_core::{ + ideal::cot::IdealCOT, test::assert_cot, RCOTReceiverOutput, RCOTSenderOutput, + }; + + use crate::ideal::vope::IdealVOPE; + + use super::{Prover, Verifier}; + + #[test] + fn test_qs_core() { + const N: usize = 200; + let mut prg = Prg::new(); + let mut input = vec![false; N]; + prg.random_bools(&mut input); + let mut delta = prg.random_block(); + delta.set_lsb(); + + let mut ideal_cot = IdealCOT::new(Block::ZERO, delta); + let mut ideal_vope = IdealVOPE::new(Block::ZERO, delta); + + let mut prover = Prover::new(); + let mut verifier = Verifier::new(delta); + + let (cot_sender, cot_receiver) = ideal_cot.random_correlated(input.len()); + + let (masks, prover_labels) = prover.auth_input_bits(&input, cot_receiver).unwrap(); + + let verifier_labels = verifier.auth_input_bits(&masks, cot_sender).unwrap(); + let input_exp: Vec = prover_labels.iter().map(|x| x.lsb() == 1).collect(); + assert_eq!(input, input_exp); + + assert_cot(delta, &input, &prover_labels, &verifier_labels); + + let mut output_macs = vec![Block::ZERO; N]; + let mut output_keys = vec![Block::ZERO; N]; + prover_labels + .iter() + .zip(verifier_labels.iter()) + .zip(output_macs.iter_mut()) + .zip(output_keys.iter_mut()) + .for_each(|(((&mac, &key), output_mac), output_key)| { + let (cot_sender, cot_receiver) = ideal_cot.random_correlated(1); + + let RCOTReceiverOutput { + choices: s, + msgs: blks, + .. + } = cot_receiver; + + let (mask, tmp) = prover.auth_and_gate(mac, mac, (s[0], blks[0])); + *output_mac = tmp; + + let RCOTSenderOutput { msgs: blks, .. } = cot_sender; + + *output_key = verifier.auth_and_gate(key, key, mask, blks[0]); + }); + + assert_cot(delta, &input, &output_macs, &output_keys); + + let (vope_sender, vope_receiver) = ideal_vope.random_correlated(1); + + let (u, v) = prover.check_and_gates(vope_receiver); + + verifier.check_and_gates(vope_sender, u, v); + + let hash = prover.finish(&output_macs); + verifier.finish(hash, &output_keys, &input).unwrap(); + + assert!(verifier.checked()); + } +} diff --git a/crates/mpz-zk-core/src/quicksilver/prover.rs b/crates/mpz-zk-core/src/quicksilver/prover.rs new file mode 100644 index 00000000..0a095b9a --- /dev/null +++ b/crates/mpz-zk-core/src/quicksilver/prover.rs @@ -0,0 +1,186 @@ +use mpz_core::{hash::Hash, serialize::CanonicalSerialize, utils::blake3, Block}; +use mpz_ot_core::RCOTReceiverOutput; +use rayon::iter::{IndexedParallelIterator, IntoParallelRefIterator, ParallelIterator}; + +use crate::VOPEReceiverOutput; + +use super::{bools_to_bytes, QsProverError, CHECK_BUFFER_SIZE}; + +/// Internal QuickSilver Prover. +#[derive(Debug, Default)] +pub struct Prover { + /// Buffer for left wire label. + buf_left: Vec, + /// Buffer for right wire label. + buf_right: Vec, + /// Buffer for output wire label. + buf_out: Vec, + /// Counter for check. + check_counter: usize, + /// Hasher. + hasher: blake3::Hasher, + /// Hash buffer for the bools. + buf_hash: Vec, +} + +impl Prover { + /// Create a new instance + pub fn new() -> Self { + Self { + buf_left: vec![Block::ZERO; CHECK_BUFFER_SIZE], + buf_right: vec![Block::ZERO; CHECK_BUFFER_SIZE], + buf_out: vec![Block::ZERO; CHECK_BUFFER_SIZE], + check_counter: 0, + hasher: blake3::Hasher::new(), + buf_hash: vec![false; CHECK_BUFFER_SIZE], + } + } + + /// Compute authenticated bits for inputs. + /// See step 4 in Figure 5 + /// + /// # Arguments + /// + /// * `inputs` - The input bits. + /// * `cot` - The COT mask received from Ideal COT as the receiver. + pub fn auth_input_bits( + &mut self, + inputs: &[bool], + cot: RCOTReceiverOutput, + ) -> Result<(Vec, Vec), QsProverError> { + if cot.choices.len() != inputs.len() { + return Err(QsProverError("lengths not match".to_string())); + } + + let RCOTReceiverOutput { + choices: bits, + msgs: blks, + .. + } = cot; + + let res: (Vec, Vec) = bits + .iter() + .zip(inputs.iter()) + .zip(blks.iter()) + .map(|((mask, b), blk)| (b ^ mask, Self::set_value(*blk, *b))) + .unzip(); + + // Hash the bools. + self.hasher.update(&bools_to_bytes(&res.0)); + + Ok(res) + } + + /// Compute authenticated and gate. + /// See step 6 in Figure 5. + /// + /// # Arguments. + /// + /// * `ma` - The MAC of wire a. + /// * `mb` - The MAC of wire b. + /// * `cot` - The COT mask received from Ideal COT as the receiver. + pub fn auth_and_gate(&mut self, ma: Block, mb: Block, cot: (bool, Block)) -> (bool, Block) { + assert!(self.check_counter < CHECK_BUFFER_SIZE); + + self.buf_left[self.check_counter] = ma; + self.buf_right[self.check_counter] = mb; + + let s = cot.0; + let blk = cot.1; + + // Compute wa * wb + let v = ma.lsb() & mb.lsb() == 1; + // Compute the mask of v with s. + let d = v ^ s; + + let mc = Self::set_value(blk, v); + self.buf_out[self.check_counter] = mc; + self.buf_hash[self.check_counter] = d; + self.check_counter += 1; + + (d, mc) + } + + /// Check and gate. + /// See step 6, 7 in Figure 5. + /// + /// # Arguments. + /// + /// * `vope` - The mask blocks received from ideal VOPE. + pub fn check_and_gates(&mut self, vope: VOPEReceiverOutput) -> (Block, Block) { + assert!(self.check_counter <= CHECK_BUFFER_SIZE); + cfg_if::cfg_if! { + if #[cfg(feature = "rayon")]{ + let iter = self.buf_left[..self.check_counter] + .par_iter() + .zip(self.buf_right[..self.check_counter].par_iter()) + .zip(self.buf_out[..self.check_counter].par_iter()); + } else{ + let iter = self.buf_left[..self.check_counter] + .iter() + .zip(self.buf_right[..self.check_counter].iter()) + .zip(self.buf_out[..self.check_counter].iter()) + } + } + + // Compute A0 and A1. + let blocks: (Vec, Vec) = iter + .map(|((a, b), c)| { + let tmp0 = if a.lsb() == 1 { *b } else { Block::ZERO }; + let tmp1 = if b.lsb() == 1 { *a } else { Block::ZERO }; + + (a.gfmul(*b), tmp0 ^ tmp1 ^ *c) + }) + .unzip(); + + // Compute chi and powers. + self.hasher + .update(&bools_to_bytes(&self.buf_hash[..self.check_counter])); + let seed = *self.hasher.finalize().as_bytes(); + let seed = Block::try_from(&seed[0..16]).unwrap(); + let chis = Block::powers(seed, self.check_counter); + + // Compute the inner product. + let u = Block::inn_prdt_red(&blocks.0, &chis); + let v = Block::inn_prdt_red(&blocks.1, &chis); + + // Mask the results. + let u = u ^ vope.coeff[0]; + let v = v ^ vope.coeff[1]; + + // Update the hasher + self.hasher.update(&u.to_bytes()); + self.hasher.update(&v.to_bytes()); + self.check_counter = 0; + + (u, v) + } + + /// Enable and check or not. + /// If check_counter is set into the default number, + /// we enable the check protocol. + #[inline] + pub fn enable_check(&self) -> bool { + self.check_counter == CHECK_BUFFER_SIZE + } + + /// Enable the final check or not. + /// if check_counter is zero, then no need to check. + #[inline] + pub fn enable_final_check(&self) -> bool { + self.check_counter != 0 + } + + /// Hash the output macs + #[inline] + pub fn finish(&self, macs: &[Block]) -> Hash { + Hash::from(blake3(&macs.to_bytes())) + } + + // Set the LSB of the block to as the bit. + // This assumes the lsb of delta is 1. + #[inline] + fn set_value(block: Block, b: bool) -> Block { + (block & Block::MINIS_ONE) ^ (if b { Block::ONE } else { Block::ZERO }) + } +} diff --git a/crates/mpz-zk-core/src/quicksilver/verifier.rs b/crates/mpz-zk-core/src/quicksilver/verifier.rs new file mode 100644 index 00000000..7120c78d --- /dev/null +++ b/crates/mpz-zk-core/src/quicksilver/verifier.rs @@ -0,0 +1,205 @@ +use mpz_core::{hash::Hash, serialize::CanonicalSerialize, utils::blake3, Block}; +use mpz_ot_core::RCOTSenderOutput; +use rayon::iter::{IndexedParallelIterator, IntoParallelRefIterator, ParallelIterator}; + +use crate::VOPESenderOutput; + +use super::{bools_to_bytes, QsVerifierError, CHECK_BUFFER_SIZE}; + +/// QuickSilver Verifier. +#[derive(Debug, Default)] +pub struct Verifier { + /// Global secret. + delta: Block, + /// Buffer for left wire KEY. + buf_left: Vec, + /// Buffer for right wire KEY. + buf_right: Vec, + /// Buffer for output wire KEY. + buf_out: Vec, + /// Counter for check. + check_counter: usize, + /// Hasher. + hasher: blake3::Hasher, + /// Hash buffer for the bools. + buf_hash: Vec, + /// Indicate the checks pass or not. + checked: bool, +} + +impl Verifier { + /// Create a new instance + /// + /// # Arguments. + /// + /// * `delta` - The global secret. + pub fn new(delta: Block) -> Self { + Self { + delta, + buf_left: vec![Block::ZERO; CHECK_BUFFER_SIZE], + buf_right: vec![Block::ZERO; CHECK_BUFFER_SIZE], + buf_out: vec![Block::ZERO; CHECK_BUFFER_SIZE], + check_counter: 0, + hasher: blake3::Hasher::new(), + buf_hash: vec![false; CHECK_BUFFER_SIZE], + checked: true, + } + } + /// Compute authenticated bits for inputs. + /// See step 4 in Figure 5 + /// # Arguments + /// + /// * `masks` - The mask bits sent from the prover. + /// * `cot` - The COT mask received from Ideal COT as the sender. + pub fn auth_input_bits( + &mut self, + masks: &[bool], + cot: RCOTSenderOutput, + // The mask bits sent by prover. + ) -> Result, QsVerifierError> { + if masks.len() != cot.msgs.len() { + return Err(QsVerifierError("lengths not match".to_string())); + } + + // Hash the bools. + self.hasher.update(&bools_to_bytes(masks)); + + let RCOTSenderOutput { msgs: blks, .. } = cot; + + let res = blks + .iter() + .zip(masks.iter()) + .map(|(blk, mask)| { + let block = *blk ^ (if *mask { self.delta } else { Block::ZERO }); + Self::set_zero(block) + }) + .collect(); + + Ok(res) + } + + /// Compute authenticated and gate. + /// See step 6 in Figure 5. + /// + /// # Arguments. + /// + /// * `ka` - The KEY of wire a. + /// * `kb` - The KEY of wire b. + /// * `mask` - The mask sent by the prover. + /// * `cot` - The COT mask received from Ideal COT as the sender. + pub fn auth_and_gate(&mut self, ka: Block, kb: Block, mask: bool, cot: Block) -> Block { + assert!(self.check_counter < CHECK_BUFFER_SIZE); + + self.buf_left[self.check_counter] = ka; + self.buf_right[self.check_counter] = kb; + self.buf_hash[self.check_counter] = mask; + + let block = cot ^ if mask { self.delta } else { Block::ZERO }; + let kc = Self::set_zero(block); + self.buf_out[self.check_counter] = kc; + self.check_counter += 1; + + kc + } + + /// Check and gate. + /// See step 6, 7 in Figure 5. + /// + /// # Arguments. + /// + /// * `vope` - The mask block received from ideal VOPE. + /// * `u` - The block sent by the prover. + /// * `v` - The block sent by the prover. + pub fn check_and_gates(&mut self, vope: VOPESenderOutput, u: Block, v: Block) { + assert!(self.check_counter <= CHECK_BUFFER_SIZE); + cfg_if::cfg_if! { + if #[cfg(feature = "rayon")]{ + let iter = self.buf_left[..self.check_counter] + .par_iter() + .zip(self.buf_right[..self.check_counter].par_iter()) + .zip(self.buf_out[..self.check_counter].par_iter()); + } else{ + let iter = self.buf_left[..self.counter] + .iter() + .zip(self.buf_right[..self.counter].iter()) + .zip(self.buf_out[..self.counter].iter()) + } + } + + // Compute B. + let block: Vec = iter + .map(|((a, b), c)| a.gfmul(*b) ^ c.gfmul(self.delta)) + .collect(); + + // Compute chi and powers. + self.hasher + .update(&bools_to_bytes(&self.buf_hash[..self.check_counter])); + let seed = *self.hasher.finalize().as_bytes(); + let seed = Block::try_from(&seed[0..16]).unwrap(); + let chis = Block::powers(seed, self.check_counter); + + // Compute the inner product. + let w = Block::inn_prdt_red(&block, &chis); + self.checked &= (w ^ vope.eval) == u ^ v.gfmul(self.delta); + + self.hasher.update(&u.to_bytes()); + self.hasher.update(&v.to_bytes()); + self.check_counter = 0; + } + + /// Enable and check or not. + /// If check_counter is set to the default buffer size, + /// we enable the check protocol. + #[inline] + pub fn enable_check(&self) -> bool { + self.check_counter == CHECK_BUFFER_SIZE + } + + /// Enable the final check or not. + /// if check_counter is zero, then no need to check. + #[inline] + pub fn enable_final_check(&self) -> bool { + self.check_counter != 0 + } + + /// Hash the output keys with the outputs. + pub fn finish( + &mut self, + hash: Hash, + keys: &[Block], + outputs: &[bool], + ) -> Result<(), QsVerifierError> { + if keys.len() != outputs.len() { + return Err(QsVerifierError("lengths not match".to_string())); + } + + let pre_hash: Vec = keys + .iter() + .zip(outputs.iter()) + .map(|(&k, &o)| if o { k ^ self.delta } else { k }) + .collect(); + + let expected_hash = Hash::from(blake3(&pre_hash.to_bytes())); + self.checked &= hash == expected_hash; + + Ok(()) + } + + /// Returns the and_check results. + #[inline] + pub fn checked(&self) -> bool { + self.checked + } + + /// Returns delta. + pub fn delta(&self) -> Block { + self.delta + } + + // Set the lsb of the block to zero. + // This assumes the lsb of delta is 1. + #[inline] + fn set_zero(block: Block) -> Block { + block & Block::MINIS_ONE + } +} diff --git a/crates/mpz-zk-core/src/test.rs b/crates/mpz-zk-core/src/test.rs new file mode 100644 index 00000000..c02a81ee --- /dev/null +++ b/crates/mpz-zk-core/src/test.rs @@ -0,0 +1,24 @@ +//! test functions. + +use mpz_core::Block; + +use crate::{VOPEReceiverOutput, VOPESenderOutput}; + +/// Check polynomial relation. +pub fn poly_check(a: &[Block], b: Block, delta: Block) -> bool { + b == a + .iter() + .rev() + .fold(Block::ZERO, |acc, &x| x ^ (delta.gfmul(acc))) +} + +/// Assert VOPE relation. +pub fn assert_vope( + send: VOPESenderOutput, + recv: VOPEReceiverOutput, + delta: Block, +) -> bool { + let send = send.eval; + let recv = recv.coeff; + poly_check(&recv, send, delta) +} diff --git a/crates/mpz-zk-core/src/vope/error.rs b/crates/mpz-zk-core/src/vope/error.rs new file mode 100644 index 00000000..101a11b3 --- /dev/null +++ b/crates/mpz-zk-core/src/vope/error.rs @@ -0,0 +1,21 @@ +//! Errors that can occur when using VOPE. + +/// Errors that can occur when using VOPE sender (verifier). +#[derive(Debug, thiserror::Error)] +#[allow(missing_docs)] +pub enum SenderError { + #[error("invalid input: expected {0}")] + InvalidInput(String), + #[error("invalid length: expected {0}")] + InvalidLength(String), +} + +/// Errors that can occur when using VOPE receiver (prover). +#[derive(Debug, thiserror::Error)] +#[allow(missing_docs)] +pub enum ReceiverError { + #[error("invalid input: expected {0}")] + InvalidInput(String), + #[error("invalid length: expected {0}")] + InvalidLength(String), +} diff --git a/crates/mpz-zk-core/src/vope/mod.rs b/crates/mpz-zk-core/src/vope/mod.rs new file mode 100644 index 00000000..3da511be --- /dev/null +++ b/crates/mpz-zk-core/src/vope/mod.rs @@ -0,0 +1,65 @@ +//! This is the implementation of vector oblivious polynomial evaluation (VOPE) based on Figure 4 in https://eprint.iacr.org/2021/076.pdf + +pub mod error; +pub mod receiver; +pub mod sender; + +/// Security parameter +pub const CSP: usize = 128; + +#[cfg(test)] +mod tests { + use mpz_core::prg::Prg; + use mpz_ot_core::{ideal::cot::IdealCOT, RCOTReceiverOutput, RCOTSenderOutput}; + + use crate::test::poly_check; + + use super::{receiver::Receiver, sender::Sender, CSP}; + + #[test] + fn vope_test() { + let mut prg = Prg::new(); + let delta = prg.random_block(); + + let mut ideal_cot = IdealCOT::default(); + ideal_cot.set_delta(delta); + + let sender = Sender::new(); + let receiver = Receiver::new(); + + let mut sender = sender.setup(delta); + let mut receiver = receiver.setup(); + + let d = 1; + + let (sender_cot, receiver_cot) = ideal_cot.random_correlated((2 * d - 1) * CSP); + + let RCOTSenderOutput { msgs: ks, .. } = sender_cot; + let RCOTReceiverOutput { + msgs: ms, + choices: us, + .. + } = receiver_cot; + + let sender_out = sender.extend(&ks, d).unwrap(); + let receiver_out = receiver.extend(&ms, &us, d).unwrap(); + + assert!(poly_check(&receiver_out, sender_out, delta)); + + let d = 5; + + let (sender_cot, receiver_cot) = ideal_cot.random_correlated((2 * d - 1) * CSP); + + let RCOTSenderOutput { msgs: ks, .. } = sender_cot; + let RCOTReceiverOutput { + msgs: ms, + choices: us, + .. + } = receiver_cot; + + let sender_out = sender.extend(&ks, d).unwrap(); + let receiver_out = receiver.extend(&ms, &us, d).unwrap(); + + assert!(poly_check(&receiver_out, sender_out, delta)); + } +} diff --git a/crates/mpz-zk-core/src/vope/receiver.rs b/crates/mpz-zk-core/src/vope/receiver.rs new file mode 100644 index 00000000..9a0e4fce --- /dev/null +++ b/crates/mpz-zk-core/src/vope/receiver.rs @@ -0,0 +1,162 @@ +//! VOPE receiver. +use mpz_core::Block; + +use crate::vope::CSP; + +use super::error::ReceiverError; + +/// VOPE receiver +/// This is the prover in Figure 4. +#[derive(Debug, Default)] +pub struct Receiver { + state: T, +} + +impl Receiver { + /// Create a new receiver. + pub fn new() -> Self { + Receiver { + state: state::Initialized::default(), + } + } + + /// Completes the setup phase of the protocol. + /// + /// See Initialize in Figure 4. + pub fn setup(self) -> Receiver { + Receiver { + state: state::Extension { + vope_counter: 0, + exec_counter: 0, + }, + } + } +} + +impl Receiver { + /// Performs VOPE extension. + /// + /// See step 1-3 in Figure 4. + /// + /// # Arguments + /// + /// * `ms` - The blocks received by calling the COT ideal functionality. + /// * `us` - The bits received by calling the COT ideal functionality. + /// * `d` - The degree of the polynomial. + /// + /// Note that this functionality is only suitable for small d. + pub fn extend( + &mut self, + ms: &[Block], + us: &[bool], + d: usize, + ) -> Result, ReceiverError> { + if d == 0 { + return Err(ReceiverError::InvalidInput( + "the degree d should not be 0".to_string(), + )); + } + + if ms.len() != us.len() { + return Err(ReceiverError::InvalidLength( + "the length of ms and us should be equal".to_string(), + )); + } + + if ms.len() != (2 * d - 1) * CSP { + return Err(ReceiverError::InvalidLength( + "the length of ms and us should be (2 * d -1) * CSP".to_string(), + )); + } + + let mut h_ms = ms.to_vec(); + let mut h_us = us.to_vec(); + + let mut mi = vec![Block::ZERO; 2 * d - 1]; + let mut ui = vec![Block::ZERO; 2 * d - 1]; + + let base: Vec = (0..CSP) + .map(|x| bytemuck::cast((1_u128) << (CSP - 1 - x))) + .collect(); + + for i in 0..(2 * d - 1) { + let m = h_ms.split_off(CSP); + let u = h_us.split_off(CSP); + + mi[i] = Block::inn_prdt_red(&h_ms, &base); + + ui[i] = + h_us.iter().zip(base.iter()).fold( + Block::ZERO, + |acc, (b, base)| { + if *b { + acc ^ *base + } else { + acc + } + }, + ); + h_ms = m; + h_us = u; + } + + let mut gi = vec![Block::ZERO; d + 1]; + gi[0] = mi[0]; + gi[1] = ui[0]; + + for i in 0..d - 1 { + poly_update(&mut gi, mi[i + 1], ui[i + 1], i + 2); + gi[0] ^= mi[d + i]; + gi[1] ^= ui[d + i]; + } + + self.state.exec_counter += 1; + self.state.vope_counter += 1; + + Ok(gi) + } +} + +fn poly_update(g: &mut [Block], m: Block, u: Block, length: usize) { + let mut buffer = vec![Block::ZERO; length + 1]; + for i in 0..length { + buffer[i + 1] = g[i].gfmul(u); + g[i] = g[i].gfmul(m); + + g[i] ^= buffer[i]; + } + g[length] = buffer[length]; +} + +/// The receiver's state. +pub mod state { + mod sealed { + pub trait Sealed {} + impl Sealed for super::Initialized {} + impl Sealed for super::Extension {} + } + + /// The receiver's state. + pub trait State: sealed::Sealed {} + + /// The receiver's initial state. + #[derive(Default)] + pub struct Initialized {} + + impl State for Initialized {} + opaque_debug::implement!(Initialized); + + /// The receiver's state after the setup phase. + /// + /// In this state the sender performs VOPE extension. + pub struct Extension { + /// Current VOPE counter + pub(super) vope_counter: usize, + /// Current execution counter + pub(super) exec_counter: usize, + } + + impl State for Extension {} + + opaque_debug::implement!(Extension); +} diff --git a/crates/mpz-zk-core/src/vope/sender.rs b/crates/mpz-zk-core/src/vope/sender.rs new file mode 100644 index 00000000..2bca3fd6 --- /dev/null +++ b/crates/mpz-zk-core/src/vope/sender.rs @@ -0,0 +1,128 @@ +//! VOPE sender. +use mpz_core::Block; + +use crate::vope::CSP; + +use super::error::SenderError; + +/// VOPE sender +/// This is the verifier in Figure 4. +#[derive(Debug, Default)] +pub struct Sender { + state: T, +} + +impl Sender { + /// Creates a new sender. + pub fn new() -> Self { + Sender { + state: state::Initialized::default(), + } + } + + /// Completes the setup phase of the protocol. + /// + /// See Initialize in Figure 4. + /// + /// # Arguments. + /// + /// * `delta` - The sender's global secret. + pub fn setup(self, delta: Block) -> Sender { + Sender { + state: state::Extension { + delta, + vope_counter: 0, + exec_counter: 0, + }, + } + } +} + +impl Sender { + /// Performs VOPE extension. + /// + /// See step 1-3 in Figure 4. + /// + /// # Arguments + /// + /// * `ks` - The blocks received by calling the COT ideal functionality. + /// * `d` - The degree of the polynomial. + /// + /// Note that this functionality is only suitable for small d. + pub fn extend(&mut self, ks: &[Block], d: usize) -> Result { + if d == 0 { + return Err(SenderError::InvalidInput( + "the degree d should not be 0".to_string(), + )); + } + + if ks.len() != (2 * d - 1) * CSP { + return Err(SenderError::InvalidLength( + "the length of ks should be (2 * d -1) * CSP".to_string(), + )); + } + + let mut ki = vec![Block::ZERO; 2 * d - 1]; + + let base: Vec = (0..CSP) + .map(|x| bytemuck::cast((1_u128) << (CSP - 1 - x))) + .collect(); + + let mut h_ks = ks.to_vec(); + + for k in ki.iter_mut().take(2 * d - 1) { + let buf = h_ks.split_off(CSP); + *k = Block::inn_prdt_red(&h_ks, &base); + h_ks = buf; + } + + let mut b = ki[0]; + + for i in 0..d - 1 { + b = b.gfmul(ki[i + 1]) ^ ki[d + i] + } + + self.state.exec_counter += 1; + self.state.vope_counter += 1; + + Ok(b) + } +} +/// The sender's state. +pub mod state { + use super::*; + + mod sealed { + pub trait Sealed {} + impl Sealed for super::Initialized {} + impl Sealed for super::Extension {} + } + + /// The sender's state. + pub trait State: sealed::Sealed {} + + /// The sender's initial state. + #[derive(Default)] + pub struct Initialized {} + + impl State for Initialized {} + opaque_debug::implement!(Initialized); + + /// The sender's state after the setup phase. + /// + /// In this state the sender performs VOPE extension. + pub struct Extension { + /// Sender's global secret. + #[allow(dead_code)] + pub(crate) delta: Block, + + /// Current VOPE counter + pub(super) vope_counter: usize, + /// Current execution counter + pub(super) exec_counter: usize, + } + + impl State for Extension {} + + opaque_debug::implement!(Extension); +} diff --git a/crates/mpz-zk/Cargo.toml b/crates/mpz-zk/Cargo.toml new file mode 100644 index 00000000..cc147603 --- /dev/null +++ b/crates/mpz-zk/Cargo.toml @@ -0,0 +1,55 @@ +[package] +name = "mpz-zk" +version = "0.1.0" +edition = "2021" + +# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html + +[lints] +workspace = true + +[lib] +name = "mpz_zk" + +[features] +default = ["rayon"] +rayon = ["mpz-ot-core/rayon"] +ideal = ["mpz-common/ideal"] + +[dependencies] +mpz-core.workspace = true +mpz-zk-core.workspace = true +mpz-common.workspace = true +mpz-cointoss.workspace = true +mpz-ot-core.workspace = true +mpz-ot.workspace = true +mpz-circuits.workspace = true + +tlsn-utils-aio.workspace = true + +async-trait.workspace = true +futures.workspace = true +rand.workspace = true +rand_core.workspace = true +rand_chacha.workspace = true +thiserror.workspace = true +rayon = { workspace = true } +itybity.workspace = true +enum-try-as-inner.workspace = true +opaque-debug.workspace = true +serde = { workspace = true, optional = true } +serio.workspace = true +cfg-if.workspace = true + +[dev-dependencies] +mpz-common = { workspace = true, features = ["test-utils", "ideal"] } +mpz-ot-core = { workspace = true, features = ["test-utils"] } +rstest = { workspace = true } +criterion = { workspace = true, features = ["async_tokio"] } +tokio = { workspace = true, features = [ + "net", + "macros", + "rt", + "rt-multi-thread", +] } +aes = { workspace = true, features = [] } diff --git a/crates/mpz-zk/src/lib.rs b/crates/mpz-zk/src/lib.rs new file mode 100644 index 00000000..6a336b8e --- /dev/null +++ b/crates/mpz-zk/src/lib.rs @@ -0,0 +1,37 @@ +//! Implementations of zero-knowledge protocols. + +#![deny( + unsafe_code, + missing_docs, + unused_imports, + unused_must_use, + unreachable_pub, + clippy::all +)] + +pub mod quicksilver; +pub mod vope; + +/// A vope error. +#[derive(Debug, thiserror::Error)] +#[allow(missing_docs)] +pub enum VOPEError { + #[error(transparent)] + IOError(#[from] std::io::Error), + #[error("sender error: {0}")] + SenderError(Box), + #[error("receiver error: {0}")] + ReceiverError(Box), +} + +/// A zk error. +#[derive(Debug, thiserror::Error)] +#[allow(missing_docs)] +pub enum ZKError { + #[error(transparent)] + IOError(#[from] std::io::Error), + #[error("prover error: {0}")] + ProverError(Box), + #[error("verifier error: {0}")] + VerifierError(Box), +} diff --git a/crates/mpz-zk/src/quicksilver/error.rs b/crates/mpz-zk/src/quicksilver/error.rs new file mode 100644 index 00000000..294e6b6c --- /dev/null +++ b/crates/mpz-zk/src/quicksilver/error.rs @@ -0,0 +1,53 @@ +use mpz_circuits::CircuitError; + +use crate::ZKError; + +/// Prover error. +#[derive(Debug, thiserror::Error)] +#[allow(missing_docs)] +pub enum ProverError { + #[error(transparent)] + IOError(#[from] std::io::Error), + #[error(transparent)] + CoreError(#[from] mpz_zk_core::quicksilver::QsProverError), + #[error(transparent)] + OTError(#[from] mpz_ot::OTError), + #[error(transparent)] + CircuitError(#[from] CircuitError), + #[error(transparent)] + VOPEError(#[from] crate::vope::error::ReceiverError), +} + +/// Verifier error. +#[derive(Debug, thiserror::Error)] +#[allow(missing_docs)] +pub enum VerifierError { + #[error(transparent)] + IOError(#[from] std::io::Error), + #[error(transparent)] + CoreError(#[from] mpz_zk_core::quicksilver::QsVerifierError), + #[error(transparent)] + OTError(#[from] mpz_ot::OTError), + #[error(transparent)] + CircuitError(#[from] CircuitError), + #[error(transparent)] + VOPEError(#[from] crate::vope::error::SenderError), +} + +impl From for ZKError { + fn from(err: ProverError) -> Self { + match err { + ProverError::IOError(e) => e.into(), + e => ZKError::ProverError(Box::new(e)), + } + } +} + +impl From for ZKError { + fn from(err: VerifierError) -> Self { + match err { + VerifierError::IOError(e) => e.into(), + e => ZKError::VerifierError(Box::new(e)), + } + } +} diff --git a/crates/mpz-zk/src/quicksilver/mod.rs b/crates/mpz-zk/src/quicksilver/mod.rs new file mode 100644 index 00000000..eb26621d --- /dev/null +++ b/crates/mpz-zk/src/quicksilver/mod.rs @@ -0,0 +1,66 @@ +//! Implementation of QuickSilver (https://eprint.iacr.org/2021/076.pdf). + +mod error; +mod prover; +mod verifier; + +pub use error::{ProverError, VerifierError}; +pub use prover::Prover; +pub use verifier::Verifier; + +#[cfg(test)] +mod tests { + use crate::{ + quicksilver::{Prover, Verifier}, + ZKError, + }; + use aes::{ + cipher::{BlockEncrypt, KeyInit}, + Aes128, + }; + use futures::TryFutureExt; + use mpz_circuits::{circuits::AES128, evaluate}; + use mpz_common::executor::test_st_executor; + use mpz_core::prg::Prg; + use mpz_ot::ideal::cot::ideal_rcot_with_delta; + + #[tokio::test] + async fn test_qs() { + let (mut ctx_sender, mut ctx_receiver) = test_st_executor(8); + + let mut delta = Prg::new().random_block(); + delta.set_lsb(); + + let (mut rcot_sender, mut rcot_receiver) = ideal_rcot_with_delta(delta); + + let mut prover = Prover::new(); + let mut verifier = Verifier::new(delta); + + let key = [69u8; 16]; + let msg = [42u8; 16]; + + let output = evaluate!(AES128, fn(key, msg) -> [u8; 16]).unwrap(); + + let expected: [u8; 16] = { + let cipher = Aes128::new_from_slice(&key).unwrap(); + let mut out = msg.into(); + cipher.encrypt_block(&mut out); + out.into() + }; + + assert_eq!(output, expected); + + let input_value = [key, msg].concat(); + tokio::try_join!( + prover + .prove(&mut ctx_sender, &AES128, input_value, &mut rcot_receiver) + .map_err(ZKError::from), + verifier + .verify(&mut ctx_receiver, &AES128, output, &mut rcot_sender) + .map_err(ZKError::from) + ) + .unwrap(); + + assert!(verifier.checked()) + } +} diff --git a/crates/mpz-zk/src/quicksilver/prover.rs b/crates/mpz-zk/src/quicksilver/prover.rs new file mode 100644 index 00000000..e4a9dc86 --- /dev/null +++ b/crates/mpz-zk/src/quicksilver/prover.rs @@ -0,0 +1,188 @@ +use itybity::IntoBits; +use mpz_circuits::{types::Value, Circuit, CircuitError, Gate}; +use mpz_common::{cpu::CpuBackend, Context}; +use mpz_core::Block; +use mpz_ot::{RCOTReceiverOutput, RandomCOTReceiver}; +use mpz_zk_core::quicksilver::Prover as ProverCore; +use serio::SinkExt; + +use super::error::ProverError; + +/// QuickSilver Prover. +pub struct Prover { + macs: Vec, + prover_core: ProverCore, +} + +impl Prover { + /// Create a new instance. + pub fn new() -> Self { + Self { + macs: Vec::default(), + prover_core: ProverCore::new(), + } + } + + // Authenticate inputs. + async fn auth_inputs( + &mut self, + ctx: &mut Ctx, + inputs: &[bool], + rcot: &mut RCOT, + ) -> Result, ProverError> + where + Ctx: Context, + RCOT: RandomCOTReceiver, + { + let cot = rcot.receive_random_correlated(ctx, inputs.len()).await?; + + let (bits, macs) = self.prover_core.auth_input_bits(inputs, cot)?; + + // TODO: optimize sending bools. + ctx.io_mut().send(bits).await?; + + Ok(macs) + } + + /// Prove a circuit. + /// + /// # Arguments. + /// + /// * `ctx` - The context. + /// * `circ` - The circuit. + /// * `input_value` - The witness hold by the prover. + /// * `rcot` - The ideal RCOT functionality. + pub async fn prove( + &mut self, + ctx: &mut Ctx, + circ: &Circuit, + input_value: impl Into, + rcot: &mut RCOT, + ) -> Result<(), ProverError> + where + Ctx: Context, + RCOT: RandomCOTReceiver, + { + let len: usize = circ.inputs().iter().map(|v| v.len()).sum(); + + let input_value = input_value.into().into_lsb0_vec(); + if input_value.len() != len { + return Err(CircuitError::InvalidInputCount(len, input_value.len()))?; + } + + if circ.feed_count() > self.macs.len() { + self.macs.resize(circ.feed_count(), Default::default()); + } + + // Handle inputs. + let input_macs = self.auth_inputs(ctx, &input_value, rcot).await?; + + for (mac, node) in input_macs + .iter() + .zip(circ.inputs().iter().flat_map(|v| v.iter())) + { + self.macs[node.id()] = *mac; + } + + // Authenticate the circuit. + for gate in circ.gates() { + match gate { + Gate::Xor { + x: node_x, + y: node_y, + z: node_z, + } => { + let x_0 = self.macs[node_x.id()]; + let y_0 = self.macs[node_y.id()]; + self.macs[node_z.id()] = x_0 ^ y_0; + } + Gate::And { + x: node_x, + y: node_y, + z: node_z, + } => { + // Check the batched authenticated and gates. + if self.prover_core.enable_check() { + self.check_and_gates(ctx, rcot).await?; + } + + let x_0 = self.macs[node_x.id()]; + let y_0 = self.macs[node_y.id()]; + + let RCOTReceiverOutput { + choices: bit, + msgs: blk, + .. + } = rcot.receive_random_correlated(ctx, 1).await?; + + let (d, z_0) = self.prover_core.auth_and_gate(x_0, y_0, (bit[0], blk[0])); + + // TODO: optimize sending bool. + ctx.io_mut().send(d).await?; + + self.macs[node_z.id()] = z_0; + } + Gate::Inv { + x: node_x, + z: node_z, + } => { + let x_0 = self.macs[node_x.id()]; + self.macs[node_z.id()] = x_0 ^ Block::ONE; + } + } + } + + // Handle final check. + if self.prover_core.enable_final_check() { + self.check_and_gates(ctx, rcot).await?; + } + + // Handle outputs. + let output_macs: Vec = circ + .outputs() + .iter() + .flat_map(|v| v.iter()) + .map(|node| self.macs[node.id()]) + .collect(); + + // Send the hash of the output macs. + let hash = self.prover_core.finish(&output_macs); + ctx.io_mut().send(hash).await?; + + Ok(()) + } + + // Check the and gates. + async fn check_and_gates( + &mut self, + ctx: &mut Ctx, + rcot: &mut RCOT, + ) -> Result<(), ProverError> + where + Ctx: Context, + RCOT: RandomCOTReceiver, + { + let mut vope = crate::vope::receiver::Receiver::new(); + vope.setup()?; + + let v = vope.receive(ctx, rcot, 1).await?; + + let mut prover_core = std::mem::take(&mut self.prover_core); + + let (u, prover_core) = + CpuBackend::blocking(move || (prover_core.check_and_gates(v), prover_core)).await; + + // Send (U, V) + ctx.io_mut().send(u).await?; + + self.prover_core = prover_core; + Ok(()) + } +} + +impl Default for Prover { + #[inline] + fn default() -> Self { + Self::new() + } +} diff --git a/crates/mpz-zk/src/quicksilver/verifier.rs b/crates/mpz-zk/src/quicksilver/verifier.rs new file mode 100644 index 00000000..84f0ab65 --- /dev/null +++ b/crates/mpz-zk/src/quicksilver/verifier.rs @@ -0,0 +1,186 @@ +use itybity::IntoBits; +use mpz_circuits::{types::Value, Circuit, CircuitError, Gate}; +use mpz_common::{cpu::CpuBackend, Context}; +use mpz_core::Block; +use mpz_ot::{RCOTSenderOutput, RandomCOTSender}; +use mpz_zk_core::quicksilver::Verifier as VerifierCore; +use serio::stream::IoStreamExt; + +use super::error::VerifierError; + +/// QuickSilver Verifier. +pub struct Verifier { + keys: Vec, + verifier_core: VerifierCore, +} + +impl Verifier { + /// Create a new instance. + pub fn new(delta: Block) -> Self { + Self { + keys: Vec::default(), + verifier_core: VerifierCore::new(delta), + } + } + + // Authenticate inputs. + async fn auth_inputs( + &mut self, + ctx: &mut Ctx, + len: usize, + rcot: &mut RCOT, + ) -> Result, VerifierError> + where + Ctx: Context, + RCOT: RandomCOTSender, + { + let cot = rcot.send_random_correlated(ctx, len).await?; + + let masks: Vec = ctx.io_mut().expect_next().await?; + + assert_eq!(masks.len(), len); + + let keys = self.verifier_core.auth_input_bits(&masks, cot)?; + + Ok(keys) + } + + /// Verify a circuit. + /// + /// # Arguments. + /// + /// * `ctx` - The context. + /// * `circ` - The circuit. + /// * `output_value` - The public output value hold by the verifier and prover. + /// * `rcot` - The ideal RCOT functionality. + pub async fn verify( + &mut self, + ctx: &mut Ctx, + circ: &Circuit, + output_value: impl Into, + rcot: &mut RCOT, + ) -> Result<(), VerifierError> + where + Ctx: Context, + RCOT: RandomCOTSender, + { + let len: usize = circ.outputs().iter().map(|v| v.len()).sum(); + + let output_value = output_value.into().into_lsb0_vec(); + if output_value.len() != len { + return Err(CircuitError::InvalidOutputCount(len, output_value.len()))?; + } + + if circ.feed_count() > self.keys.len() { + self.keys.resize(circ.feed_count(), Default::default()); + } + + // Handle inputs. + let input_len: usize = circ.inputs().iter().map(|v| v.len()).sum(); + let input_keys = self.auth_inputs(ctx, input_len, rcot).await?; + + for (key, node) in input_keys + .iter() + .zip(circ.inputs().iter().flat_map(|v| v.iter())) + { + self.keys[node.id()] = *key; + } + + // Authenticate the circuit. + for gate in circ.gates() { + match gate { + Gate::Xor { + x: node_x, + y: node_y, + z: node_z, + } => { + let x_0 = self.keys[node_x.id()]; + let y_0 = self.keys[node_y.id()]; + self.keys[node_z.id()] = x_0 ^ y_0; + } + Gate::And { + x: node_x, + y: node_y, + z: node_z, + } => { + // Check the batched authenticated and gates. + if self.verifier_core.enable_check() { + self.check_and_gates(ctx, rcot).await?; + } + + let x_0 = self.keys[node_x.id()]; + let y_0 = self.keys[node_y.id()]; + + let RCOTSenderOutput { msgs: blk, .. } = + rcot.send_random_correlated(ctx, 1).await?; + + let mask = ctx.io_mut().expect_next().await?; + let z_0 = self.verifier_core.auth_and_gate(x_0, y_0, mask, blk[0]); + + self.keys[node_z.id()] = z_0; + } + Gate::Inv { + x: node_x, + z: node_z, + } => { + let x_0 = self.keys[node_x.id()]; + self.keys[node_z.id()] = x_0 ^ self.verifier_core.delta() ^ Block::ONE; + } + } + } + + // Handle final check. + if self.verifier_core.enable_final_check() { + self.check_and_gates(ctx, rcot).await?; + } + + // Handle outputs. + let output_keys: Vec = circ + .outputs() + .iter() + .flat_map(|v| v.iter()) + .map(|node| self.keys[node.id()]) + .collect(); + + // Receive the hash of output macs and verify. + let hash = ctx.io_mut().expect_next().await?; + self.verifier_core + .finish(hash, &output_keys, &output_value)?; + + Ok(()) + } + + // Check the and gates. + async fn check_and_gates( + &mut self, + ctx: &mut Ctx, + rcot: &mut RCOT, + ) -> Result<(), VerifierError> + where + Ctx: Context, + RCOT: RandomCOTSender, + { + let mut vope = crate::vope::sender::Sender::new(); + vope.setup(self.verifier_core.delta())?; + + let v = vope.send(ctx, rcot, 1).await?; + + let u: (Block, Block) = ctx.io_mut().expect_next().await?; + + let mut verifier_core = std::mem::take(&mut self.verifier_core); + + let (_, verifier_core) = CpuBackend::blocking(move || { + (verifier_core.check_and_gates(v, u.0, u.1), verifier_core) + }) + .await; + + self.verifier_core = verifier_core; + Ok(()) + } + + /// Returns checked or not. + #[inline] + pub fn checked(&self) -> bool { + self.verifier_core.checked() + } +} diff --git a/crates/mpz-zk/src/vope/error.rs b/crates/mpz-zk/src/vope/error.rs new file mode 100644 index 00000000..912efda1 --- /dev/null +++ b/crates/mpz-zk/src/vope/error.rs @@ -0,0 +1,61 @@ +//! Errors in VOPE + +use crate::VOPEError; + +/// A VOPE Sender error. +#[derive(Debug, thiserror::Error)] +#[allow(missing_docs)] +pub enum SenderError { + #[error(transparent)] + IOError(#[from] std::io::Error), + #[error(transparent)] + CoreError(#[from] mpz_zk_core::vope::error::SenderError), + #[error(transparent)] + RandomCOTError(#[from] mpz_ot::OTError), + #[error("{0}")] + StateError(String), +} + +/// A VOPE Receiver error. +#[derive(Debug, thiserror::Error)] +#[allow(missing_docs)] +pub enum ReceiverError { + #[error(transparent)] + IOError(#[from] std::io::Error), + #[error(transparent)] + CoreError(#[from] mpz_zk_core::vope::error::ReceiverError), + #[error(transparent)] + RandomCOTError(#[from] mpz_ot::OTError), + #[error("{0}")] + StateError(String), +} + +impl From for VOPEError { + fn from(err: SenderError) -> Self { + match err { + SenderError::IOError(e) => e.into(), + e => VOPEError::SenderError(Box::new(e)), + } + } +} + +impl From for SenderError { + fn from(err: crate::vope::sender::StateError) -> Self { + SenderError::StateError(err.to_string()) + } +} + +impl From for VOPEError { + fn from(err: ReceiverError) -> Self { + match err { + ReceiverError::IOError(e) => e.into(), + e => VOPEError::ReceiverError(Box::new(e)), + } + } +} + +impl From for ReceiverError { + fn from(err: crate::vope::receiver::StateError) -> Self { + ReceiverError::StateError(err.to_string()) + } +} diff --git a/crates/mpz-zk/src/vope/mod.rs b/crates/mpz-zk/src/vope/mod.rs new file mode 100644 index 00000000..e395fa0a --- /dev/null +++ b/crates/mpz-zk/src/vope/mod.rs @@ -0,0 +1,60 @@ +//! This is the implementation of vector oblivious polynomial evaluation (VOPE) based on Figure 4 in https://eprint.iacr.org/2021/076.pdf + +pub mod error; +pub mod receiver; +pub mod sender; + +#[cfg(test)] +mod tests { + use crate::{ + vope::{receiver::Receiver, sender::Sender}, + VOPEError, + }; + use futures::TryFutureExt; + use mpz_common::executor::test_st_executor; + use mpz_ot::{ideal::cot::ideal_rcot, Correlation}; + use mpz_zk_core::test::assert_vope; + + #[tokio::test] + async fn test_vope() { + let (mut ctx_sender, mut ctx_receiver) = test_st_executor(8); + + let (mut rcot_sender, mut rcot_receiver) = ideal_rcot(); + + let mut sender = Sender::new(); + let mut receiver = Receiver::new(); + + let delta = rcot_sender.delta(); + + sender.setup(delta).unwrap(); + receiver.setup().unwrap(); + + let d = 1; + + let (output_sender, output_receiver) = tokio::try_join!( + sender + .send(&mut ctx_sender, &mut rcot_sender, d) + .map_err(VOPEError::from), + receiver + .receive(&mut ctx_receiver, &mut rcot_receiver, d) + .map_err(VOPEError::from) + ) + .unwrap(); + + assert!(assert_vope(output_sender, output_receiver, delta)); + + let d = 5; + + let (output_sender, output_receiver) = tokio::try_join!( + sender + .send(&mut ctx_sender, &mut rcot_sender, d) + .map_err(VOPEError::from), + receiver + .receive(&mut ctx_receiver, &mut rcot_receiver, d) + .map_err(VOPEError::from) + ) + .unwrap(); + + assert!(assert_vope(output_sender, output_receiver, delta)); + } +} diff --git a/crates/mpz-zk/src/vope/receiver.rs b/crates/mpz-zk/src/vope/receiver.rs new file mode 100644 index 00000000..f34958a1 --- /dev/null +++ b/crates/mpz-zk/src/vope/receiver.rs @@ -0,0 +1,106 @@ +//! Implementation of VOPE receiver. + +use crate::vope::error::ReceiverError; +use enum_try_as_inner::EnumTryAsInner; +use mpz_common::Context; +use mpz_core::Block; +use mpz_ot::{RCOTReceiverOutput, RandomCOTReceiver, TransferId}; +use mpz_zk_core::{ + vope::{ + receiver::{state, Receiver as ReceiverCore}, + CSP, + }, + VOPEReceiverOutput, +}; +use utils_aio::non_blocking_backend::{Backend, NonBlockingBackend}; + +#[derive(Debug, EnumTryAsInner)] +#[derive_err(Debug)] +#[allow(missing_docs)] +pub enum State { + Initialized(ReceiverCore), + Extension(ReceiverCore), + Error, +} + +/// VOPE receiver (prover) +#[derive(Debug)] +pub struct Receiver { + state: State, + id: TransferId, +} + +impl Receiver { + /// Creates a new receiver. + pub fn new() -> Self { + Self { + state: State::Initialized(ReceiverCore::new()), + id: TransferId::default(), + } + } + + /// Performs setup for receiver. + pub fn setup(&mut self) -> Result<(), ReceiverError> { + let ext_receiver = + std::mem::replace(&mut self.state, State::Error).try_into_initialized()?; + + let ext_receiver = ext_receiver.setup(); + + self.state = State::Extension(ext_receiver); + + Ok(()) + } + + /// Performs VOPE extension for receiver. + /// + /// # Arguments + /// + /// * `ctx` - The context. + /// * `rcot` - The ideal random COT. + /// * `d` - The polynomial degree. + pub async fn receive( + &mut self, + ctx: &mut Ctx, + rcot: &mut RCOT, + d: usize, + ) -> Result, ReceiverError> + where + Ctx: Context, + RCOT: RandomCOTReceiver, + { + let mut ext_receiver = + std::mem::replace(&mut self.state, State::Error).try_into_extension()?; + + assert!(d > 0); + + let RCOTReceiverOutput { + msgs: ms, + choices: us, + .. + } = rcot + .receive_random_correlated(ctx, (2 * d - 1) * CSP) + .await?; + + // extend + let (ext_receiver, res) = Backend::spawn(move || { + ext_receiver + .extend(&ms, &us, d) + .map(|res| (ext_receiver, res)) + }) + .await?; + + self.state = State::Extension(ext_receiver); + + Ok(VOPEReceiverOutput { + id: self.id.next_id(), + coeff: res, + }) + } +} + +impl Default for Receiver { + #[inline] + fn default() -> Self { + Self::new() + } +} diff --git a/crates/mpz-zk/src/vope/sender.rs b/crates/mpz-zk/src/vope/sender.rs new file mode 100644 index 00000000..2643ad5a --- /dev/null +++ b/crates/mpz-zk/src/vope/sender.rs @@ -0,0 +1,99 @@ +//! Implementation of VOPE sender + +use crate::vope::error::SenderError; +use enum_try_as_inner::EnumTryAsInner; +use mpz_common::Context; +use mpz_core::Block; +use mpz_ot::{RCOTSenderOutput, RandomCOTSender, TransferId}; +use mpz_zk_core::{ + vope::{ + sender::{state, Sender as SenderCore}, + CSP, + }, + VOPESenderOutput, +}; +use utils_aio::non_blocking_backend::{Backend, NonBlockingBackend}; + +#[derive(Debug, EnumTryAsInner)] +#[derive_err(Debug)] +#[allow(missing_docs)] +pub enum State { + Initialized(SenderCore), + Extension(SenderCore), + Error, +} + +/// VOPE sender (verifier) +#[derive(Debug)] +pub struct Sender { + state: State, + id: TransferId, +} + +impl Sender { + /// Creates a new Sender. + pub fn new() -> Self { + Self { + state: State::Initialized(SenderCore::new()), + id: TransferId::default(), + } + } + + /// Performs setup with the provided delta. + /// + /// # Arguments + /// + /// * `delta` - The delta value to use for VOPE extension. + pub fn setup(&mut self, delta: Block) -> Result<(), SenderError> { + let ext_sender = std::mem::replace(&mut self.state, State::Error).try_into_initialized()?; + + let ext_sender = ext_sender.setup(delta); + + self.state = State::Extension(ext_sender); + + Ok(()) + } + + /// Performs VOPE extension for sender. + /// + /// # Arguments + /// + /// * `ctx` - The context. + /// * `rcot` - The ideal random COT. + /// * `d` - The polynomial degree. + pub async fn send( + &mut self, + ctx: &mut Ctx, + rcot: &mut RCOT, + d: usize, + ) -> Result, SenderError> + where + Ctx: Context, + RCOT: RandomCOTSender, + { + let mut ext_sender = + std::mem::replace(&mut self.state, State::Error).try_into_extension()?; + + assert!(d > 0); + + let RCOTSenderOutput { msgs: ks, .. } = + rcot.send_random_correlated(ctx, (2 * d - 1) * CSP).await?; + + let (ext_sender, res) = + Backend::spawn(move || ext_sender.extend(&ks, d).map(|res| (ext_sender, res))).await?; + + self.state = State::Extension(ext_sender); + + Ok(VOPESenderOutput { + id: self.id.next_id(), + eval: res, + }) + } +} + +impl Default for Sender { + #[inline] + fn default() -> Self { + Self::new() + } +}