diff --git a/mpc-core/src/protocols/rep3/arithmetic.rs b/mpc-core/src/protocols/rep3/arithmetic.rs index 36101adf..99d64c52 100644 --- a/mpc-core/src/protocols/rep3/arithmetic.rs +++ b/mpc-core/src/protocols/rep3/arithmetic.rs @@ -683,3 +683,69 @@ pub(crate) fn arithmetic_xor_many( .collect(); Ok(res) } + +/// Reshares the shared valuse from two parties to one other +/// Assumes seeds are set up correctly already +pub fn reshare_from_2_to_3_parties( + input: Option>>, + len: usize, + recipient: PartyID, + io_context: &mut IoContext, +) -> IoResult>> { + if io_context.id == recipient { + let mut result = Vec::with_capacity(len); + for _ in 0..len { + let (a, b) = io_context.random_fes::(); + result.push(Rep3PrimeFieldShare::new(a, b)); + } + return Ok(result); + } + + if input.is_none() { + return Err(std::io::Error::new( + std::io::ErrorKind::InvalidInput, + "During execution of reshare_from_2_to_3_parties in MPC: input is None", + )); + } + + let input = input.unwrap(); + if input.len() != len { + return Err(std::io::Error::new( + std::io::ErrorKind::InvalidInput, + "During execution of reshare_from_2_to_3_parties in MPC: input length does not match", + )); + } + + let mut rand = Vec::with_capacity(len); + let mut result = Vec::with_capacity(len); + if io_context.id == recipient.next_id() { + for inp in input { + let beta = inp.a + inp.b; + let b = io_context.rngs.rand.random_field_element_rng2(); + let r = beta - b; + rand.push(r); + result.push(Rep3PrimeFieldShare::new(r, b)); + } + let comm_id = io_context.id.next_id(); + io_context.network.send_many(comm_id, &rand)?; + let rcv = io_context.network.recv_many::(comm_id)?; + for (res, r) in result.iter_mut().zip(rcv) { + res.a += r; + } + } else { + for inp in input { + let beta = inp.a; + let a = io_context.rngs.rand.random_field_element_rng1(); + let r = beta - a; + rand.push(r); + result.push(Rep3PrimeFieldShare::new(a, r)); + } + let comm_id = io_context.id.prev_id(); + io_context.network.send_many(comm_id, &rand)?; + let rcv = io_context.network.recv_many::(comm_id)?; + for (res, r) in result.iter_mut().zip(rcv) { + res.b += r; + } + } + Ok(result) +} diff --git a/tests/tests/mpc/rep3.rs b/tests/tests/mpc/rep3.rs index cfd821e7..5b5a9868 100644 --- a/tests/tests/mpc/rep3.rs +++ b/tests/tests/mpc/rep3.rs @@ -1488,6 +1488,53 @@ mod field_share { let is_result = rep3::combine_field_elements(&result1, &result2, &result3); assert_eq!(is_result, should_result); } + + fn reshare_from_2_to_3_parties_test_internal(recipient: PartyID) { + const VEC_SIZE: usize = 10; + + let test_network = Rep3TestNetwork::default(); + let mut rng = thread_rng(); + let x = (0..VEC_SIZE) + .map(|_| ark_bn254::Fr::rand(&mut rng)) + .collect_vec(); + let x_shares = rep3::share_field_elements(&x, &mut rng); + + let (tx1, rx1) = mpsc::channel(); + let (tx2, rx2) = mpsc::channel(); + let (tx3, rx3) = mpsc::channel(); + + for (net, tx, x) in izip!( + test_network.get_party_networks().into_iter(), + [tx1, tx2, tx3], + x_shares.into_iter() + ) { + thread::spawn(move || { + let mut rep3 = IoContext::init(net).unwrap(); + + let decomposed = arithmetic::reshare_from_2_to_3_parties( + Some(x), + VEC_SIZE, + recipient, + &mut rep3, + ) + .unwrap(); + tx.send(decomposed) + }); + } + + let result1 = rx1.recv().unwrap(); + let result2 = rx2.recv().unwrap(); + let result3 = rx3.recv().unwrap(); + let is_result = rep3::combine_field_elements(&result1, &result2, &result3); + assert_eq!(is_result, x); + } + + #[test] + fn reshare_from_2_to_3_parties_test() { + reshare_from_2_to_3_parties_test_internal(PartyID::ID0); + reshare_from_2_to_3_parties_test_internal(PartyID::ID1); + reshare_from_2_to_3_parties_test_internal(PartyID::ID2); + } } mod curve_share {