diff --git a/co-noir/co-acvm/src/mpc.rs b/co-noir/co-acvm/src/mpc.rs index 63f359f4..6f880a43 100644 --- a/co-noir/co-acvm/src/mpc.rs +++ b/co-noir/co-acvm/src/mpc.rs @@ -142,15 +142,14 @@ pub trait NoirWitnessExtensionProtocol { fn decompose_arithmetic( &mut self, input: Self::ArithmeticShare, - // io_context: &mut IoContext, total_bit_size_per_field: usize, decompose_bit_size: usize, ) -> std::io::Result>; - /// Decompose a shared value into a vector of shared values: \[a\] = a_1 + a_2 + ... + a_n. Each value a_i has at most decompose_bit_size bits, whereas the total bit size of the shares is total_bit_size_per_field. Thus, a_n, might have a smaller bitsize than the other chunks - fn decompose_arithmetic_many( + + /// Decompose a shared value into a vector of shared values: \[a\] = a_1 + a_2 + ... + a_n. Each value a_i has at most decompose_bit_size bits, whereas the total bit size of the shares is total_bit_size_per_field. Thus, a_n, might have a smaller bitsize than the other chunks + fn decompose_arithmetic_many( &mut self, input: &[Self::ArithmeticShare], - // io_context: &mut IoContext, total_bit_size_per_field: usize, decompose_bit_size: usize, ) -> std::io::Result>>; diff --git a/co-noir/co-acvm/src/mpc/plain.rs b/co-noir/co-acvm/src/mpc/plain.rs index 75c5d8c4..05276e80 100644 --- a/co-noir/co-acvm/src/mpc/plain.rs +++ b/co-noir/co-acvm/src/mpc/plain.rs @@ -188,6 +188,20 @@ impl NoirWitnessExtensionProtocol for PlainAcvmSolver { Ok(result) } + fn decompose_arithmetic_many( + &mut self, + input: &[Self::ArithmeticShare], + total_bit_size_per_field: usize, + decompose_bit_size: usize, + ) -> std::io::Result>> { + input + .iter() + .map(|&inp| { + Self::decompose_arithmetic(self, inp, total_bit_size_per_field, decompose_bit_size) + }) + .collect() + } + fn sort( &mut self, inputs: &[Self::ArithmeticShare], @@ -203,19 +217,4 @@ impl NoirWitnessExtensionProtocol for PlainAcvmSolver { result.sort(); Ok(result) } - - fn decompose_arithmetic_many( - &mut self, - input: &[Self::ArithmeticShare], - // io_context: &mut IoContext, - total_bit_size_per_field: usize, - decompose_bit_size: usize, - ) -> std::io::Result>> { - input - .iter() - .map(|&inp| { - Self::decompose_arithmetic(self, inp, total_bit_size_per_field, decompose_bit_size) - }) - .collect() - } } diff --git a/co-noir/co-acvm/src/mpc/rep3.rs b/co-noir/co-acvm/src/mpc/rep3.rs index d92a80c9..2249556a 100644 --- a/co-noir/co-acvm/src/mpc/rep3.rs +++ b/co-noir/co-acvm/src/mpc/rep3.rs @@ -491,32 +491,37 @@ impl NoirWitnessExtensionProtocol for Rep3Acvm ) } - fn sort( - &mut self, - inputs: &[Self::ArithmeticShare], - bitsize: usize, - ) -> std::io::Result> { - radix_sort_fields(inputs, &mut self.io_context, bitsize) - } - fn decompose_arithmetic_many( &mut self, input: &[Self::ArithmeticShare], - // io_context: &mut IoContext, total_bit_size_per_field: usize, decompose_bit_size: usize, ) -> std::io::Result>> { + // Defines an upper bound on the size of the input vector to keep the GC at a reasonable size (for RAM) + const BATCH_SIZE: usize = 512; // TODO adapt this if it requires too much RAM + let num_decomps_per_field = total_bit_size_per_field.div_ceil(decompose_bit_size); - let result = yao::decompose_arithmetic_many( - input, - &mut self.io_context, - total_bit_size_per_field, - decompose_bit_size, - )?; - let results = result - .chunks(num_decomps_per_field) - .map(|chunk| chunk.to_vec()) - .collect(); + let mut results = Vec::with_capacity(input.len()); + + for inp_chunk in input.chunks(BATCH_SIZE) { + let result = yao::decompose_arithmetic_many( + inp_chunk, + &mut self.io_context, + total_bit_size_per_field, + decompose_bit_size, + )?; + for chunk in result.chunks(num_decomps_per_field) { + results.push(chunk.to_vec()); + } + } Ok(results) } + + fn sort( + &mut self, + inputs: &[Self::ArithmeticShare], + bitsize: usize, + ) -> std::io::Result> { + radix_sort_fields(inputs, &mut self.io_context, bitsize) + } } diff --git a/co-noir/co-acvm/src/mpc/shamir.rs b/co-noir/co-acvm/src/mpc/shamir.rs index c4531e02..aee4768b 100644 --- a/co-noir/co-acvm/src/mpc/shamir.rs +++ b/co-noir/co-acvm/src/mpc/shamir.rs @@ -407,22 +407,20 @@ impl NoirWitnessExtensionProtocol for Shamir ) -> std::io::Result> { panic!("functionality decompose_arithmetic not feasible for Shamir") } - - fn sort( - &mut self, - _inputs: &[Self::ArithmeticShare], - _bitsize: usize, - ) -> std::io::Result> { - panic!("functionality sort not feasible for Shamir") - } - fn decompose_arithmetic_many( &mut self, _input: &[Self::ArithmeticShare], - // io_context: &mut IoContext, _total_bit_size_per_field: usize, _decompose_bit_size: usize, ) -> std::io::Result>> { panic!("functionality decompose_arithmetic_many not feasible for Shamir") } + + fn sort( + &mut self, + _inputs: &[Self::ArithmeticShare], + _bitsize: usize, + ) -> std::io::Result> { + panic!("functionality sort not feasible for Shamir") + } } diff --git a/co-noir/co-brillig/src/mpc/rep3.rs b/co-noir/co-brillig/src/mpc/rep3.rs index 42ccab2a..4e285f23 100644 --- a/co-noir/co-brillig/src/mpc/rep3.rs +++ b/co-noir/co-brillig/src/mpc/rep3.rs @@ -1,10 +1,10 @@ use super::{BrilligDriver, PlainBrilligDriver}; use ark_ff::{One as _, PrimeField}; use brillig::{BitSize, IntegerBitSize}; -use mpc_core::protocols::rep3_ring::conversion::a2b; use core::panic; use mpc_core::protocols::rep3::network::{IoContext, Rep3Network}; use mpc_core::protocols::rep3::{self, Rep3PrimeFieldShare}; +use mpc_core::protocols::rep3_ring::conversion::a2b; use mpc_core::protocols::rep3_ring::ring::bit::Bit; use mpc_core::protocols::rep3_ring::ring::int_ring::IntRing2k; use mpc_core::protocols::rep3_ring::ring::ring_impl::RingElement; @@ -1414,8 +1414,6 @@ impl BrilligDriver for Rep3BrilligDriver (Rep3BrilligType::Shared(val), Rep3BrilligType::Public(radix)) => { if let (Shared::Field(val), Public::Int(radix, IntegerBitSize::U32)) = (val, radix) { - - let radix = u32::try_from(radix).expect("must be u32"); let mut input = val; assert!(radix <= 256, "radix is at most 256"); @@ -1447,14 +1445,15 @@ impl BrilligDriver for Rep3BrilligDriver &mut self.io_context, )?; //radix is at most 256, so should fit into u8, but is this necessary? if bits { - let limb_2b = a2b(limb[0], &mut self.io_context)?; - let limb_bit = rep3_ring::conversion::bit_inject(&limb_2b, &mut self.io_context)?; - limbs[i] = Rep3BrilligType::Shared(Shared::::Ring8(limb_bit)); - } - else { - - limbs[i] = Rep3BrilligType::Shared(Shared::::Ring8(limb[0])); - }; + let limb_2b = a2b(limb[0], &mut self.io_context)?; + let limb_bit = rep3_ring::conversion::bit_inject( + &limb_2b, + &mut self.io_context, + )?; + limbs[i] = Rep3BrilligType::Shared(Shared::::Ring8(limb_bit)); + } else { + limbs[i] = Rep3BrilligType::Shared(Shared::::Ring8(limb[0])); + }; input = div; } limbs diff --git a/co-noir/co-builder/src/builder.rs b/co-noir/co-builder/src/builder.rs index 6e5d9ade..f0309a50 100644 --- a/co-noir/co-builder/src/builder.rs +++ b/co-noir/co-builder/src/builder.rs @@ -631,6 +631,7 @@ impl> GenericUltraCi } } } + // decomposes the shared values in batches, separated into the corresponding number of bits the values have #[expect(clippy::type_complexity)] fn prepare_for_range_decompose( @@ -649,19 +650,15 @@ impl> GenericUltraCi for constraint in range_constraints.iter() { let val = &self.get_variable(constraint.witness as usize); - if constraint.num_bits > Self::DEFAULT_PLOOKUP_RANGE_BITNUM as u32 && T::is_shared(val) - { - let num_bits = constraint.num_bits; - + let num_bits = constraint.num_bits; + if num_bits > Self::DEFAULT_PLOOKUP_RANGE_BITNUM as u32 && T::is_shared(val) { + let share_val = T::get_shared(val).expect("Already checked it is shared"); if let Some(&idx) = bits_locations.get(&num_bits) { - to_decompose[idx] - .push(T::get_shared(val).expect("Already checked it is shared")); + to_decompose[idx].push(share_val); decompose_indices.push((true, to_decompose[idx].len() - 1)); } else { let new_idx = to_decompose.len(); - to_decompose.push(vec![ - T::get_shared(val).expect("Already checked it is shared") - ]); + to_decompose.push(vec![share_val]); decompose_indices.push((true, 0)); bits_locations.insert(num_bits, new_idx); } @@ -670,9 +667,9 @@ impl> GenericUltraCi } } - let mut decomposed: Vec>> = Vec::with_capacity(to_decompose.len()); + let mut decomposed = Vec::with_capacity(to_decompose.len()); - for (i, inp) in to_decompose.iter().enumerate() { + for (i, inp) in to_decompose.into_iter().enumerate() { let num_bits = bits_locations .iter() .find_map(|(&key, &value)| if value == i { Some(key) } else { None }) @@ -680,7 +677,7 @@ impl> GenericUltraCi decomposed.push(T::decompose_arithmetic_many( driver, - inp, + &inp, num_bits as usize, Self::DEFAULT_PLOOKUP_RANGE_BITNUM, )?); @@ -775,23 +772,24 @@ impl> GenericUltraCi // todo!("Logic gates"); // } + // We want to decompose all shared elements in parallel let (bits_locations, decomposed, decompose_indices) = self.prepare_for_range_decompose(driver, &constraint_system.range_constraints)?; for (i, constraint) in constraint_system.range_constraints.iter().enumerate() { - if let Some(&idx) = bits_locations.get(&constraint.num_bits) { - if decompose_indices[i].0 { - self.decompose_into_default_range( - driver, - constraint.witness, - constraint.num_bits as u64, - Some(&decomposed[idx][decompose_indices[i].1]), - Self::DEFAULT_PLOOKUP_RANGE_BITNUM as u64, - )?; - } else { - self.create_range_constraint(driver, constraint.witness, constraint.num_bits)?; - } + let idx_option = bits_locations.get(&constraint.num_bits); + if idx_option.is_some() && decompose_indices[i].0 { + // Already decomposed + let idx = idx_option.unwrap().to_owned(); + self.decompose_into_default_range( + driver, + constraint.witness, + constraint.num_bits as u64, + Some(&decomposed[idx][decompose_indices[i].1]), + Self::DEFAULT_PLOOKUP_RANGE_BITNUM as u64, + )?; } else { + // Either we do not have to decompose or the value is public self.create_range_constraint(driver, constraint.witness, constraint.num_bits)?; } @@ -2419,6 +2417,7 @@ impl> GenericUltraCi self.create_new_range_constraint(variable_index, (1u64 << num_bits) - 1); } else { + // The value must be public, otherwise it would have been batch decomposed already self.decompose_into_default_range( driver, variable_index, @@ -2529,11 +2528,13 @@ impl> GenericUltraCi let mut sublimb_indices: Vec = Vec::with_capacity(num_limbs as usize); let sublimbs: Vec = match decompose { + // Already decomposed, i.e., we just take the values Some(decomposed) => decomposed .iter() .map(|item| T::AcvmType::from(item.clone())) .collect(), None => { + // Not yet decomposed, i.e., it was a public value let mut accumulator: BigUint = T::get_public(&val) .expect("Already checked it is public") .into(); diff --git a/co-noir/co-noir/examples/test_vectors/ranges/Nargo.toml b/co-noir/co-noir/examples/test_vectors/ranges/Nargo.toml index 684d7044..0eca26e1 100644 --- a/co-noir/co-noir/examples/test_vectors/ranges/Nargo.toml +++ b/co-noir/co-noir/examples/test_vectors/ranges/Nargo.toml @@ -2,6 +2,6 @@ name = "ranges" type = "bin" authors = [""] -compiler_version = ">=0.38.0" +compiler_version = ">=1.0.0" [dependencies] diff --git a/co-noir/co-noir/examples/test_vectors/ranges/src/main.nr b/co-noir/co-noir/examples/test_vectors/ranges/src/main.nr index a1eeb26f..7bbc3fc8 100644 --- a/co-noir/co-noir/examples/test_vectors/ranges/src/main.nr +++ b/co-noir/co-noir/examples/test_vectors/ranges/src/main.nr @@ -1,3 +1,3 @@ -fn main(x: u64,w: u16,u: u32, y: pub u64, z: pub u64, s: u32, t: u16) -> pub (u64, u32, u16) { +fn main(x: u64, w: u16, u: u32, y: pub u64, z: pub u64, s: u32, t: u16) -> pub (u64, u32, u16) { (x + y + z, u * s, w * t) } diff --git a/mpc-core/src/protocols/rep3/yao.rs b/mpc-core/src/protocols/rep3/yao.rs index 4774a9ab..a0e4563b 100644 --- a/mpc-core/src/protocols/rep3/yao.rs +++ b/mpc-core/src/protocols/rep3/yao.rs @@ -235,7 +235,7 @@ impl GCUtils { Ok(F::from(res)) } - fn biguint_to_bits(input: BigUint, n_bits: usize) -> Vec { + fn biguint_to_bits(input: &BigUint, n_bits: usize) -> Vec { let mut res = Vec::with_capacity(n_bits); let mut bits = 0; for mut el in input.to_u64_digits() { @@ -252,6 +252,13 @@ impl GCUtils { res } + fn field_to_bits(field: F) -> Vec { + let n_bits = F::MODULUS_BIT_SIZE as usize; + let bigint: BigUint = field.into(); + + Self::biguint_to_bits(&bigint, n_bits) + } + fn field_to_bits_as_u16(field: F) -> Vec { let n_bits = F::MODULUS_BIT_SIZE as usize; let bigint: BigUint = field.into(); @@ -674,6 +681,7 @@ pub fn decompose_arithmetic( decompose_bit_size, ) } + /// Divides a vector of field elements by a power of 2, rounding down. pub fn field_int_div_power_2_many( inputs: &[Rep3PrimeFieldShare], @@ -715,13 +723,8 @@ pub fn field_int_div_many( io_context: &mut IoContext, ) -> IoResult>> { let num_inputs = input1.len(); + debug_assert_eq!(input1.len(), input2.len()); - // if divisor_bit == 0 { - // return Ok(inputs.to_owned()); - // } - // if divisor_bit >= F::MODULUS_BIT_SIZE as usize { - // return Ok(vec![Rep3PrimeFieldShare::zero_share(); num_inputs]); - // } let mut combined_inputs = Vec::with_capacity(input1.len() + input2.len()); combined_inputs.extend_from_slice(input1); combined_inputs.extend_from_slice(input2); @@ -734,7 +737,7 @@ pub fn field_int_div_many( ) } -/// Divides a field element by a power of 2, rounding down. +/// Divides a field element by another, rounding down. pub fn field_int_div( input1: Rep3PrimeFieldShare, input2: Rep3PrimeFieldShare, @@ -743,6 +746,7 @@ pub fn field_int_div( let res = field_int_div_many(&[input1], &[input2], io_context)?; Ok(res[0]) } + /// Divides a vector of field elements by another, rounding down. pub fn field_int_div_by_public_many( input: &[Rep3PrimeFieldShare], @@ -750,18 +754,13 @@ pub fn field_int_div_by_public_many( io_context: &mut IoContext, ) -> IoResult>> { let num_inputs = input.len(); + debug_assert_eq!(input.len(), divisors.len()); - // if divisor_bit == 0 { - // return Ok(inputs.to_owned()); - // } - // if divisor_bit >= F::MODULUS_BIT_SIZE as usize { - // return Ok(vec![Rep3PrimeFieldShare::zero_share(); num_inputs]); - // } field_to_bits_as_u16 let mut divisors_as_bits = Vec::with_capacity(F::MODULUS_BIT_SIZE as usize * num_inputs); divisors .iter() - .for_each(|y| divisors_as_bits.extend(GCUtils::field_to_bits_as_u16::(*y))); - let divisors_as_bits = divisors_as_bits.iter().map(|&x| x != 0).collect(); // rfield_to_bits_as_u16 returns a 0-1 vec + .for_each(|y| divisors_as_bits.extend(GCUtils::field_to_bits::(*y))); + decompose_circuit_compose_blueprint!( &input, io_context, @@ -771,7 +770,7 @@ pub fn field_int_div_by_public_many( ) } -/// Divides a field element by a power of 2, rounding down. +/// Divides a field element by another, rounding down. pub fn field_int_div_by_public( input: Rep3PrimeFieldShare, divisor: F, @@ -780,6 +779,7 @@ pub fn field_int_div_by_public( let res = field_int_div_by_public_many(&[input], &[divisor], io_context)?; Ok(res[0]) } + /// Divides a vector of field elements by another, rounding down. pub fn field_int_div_by_shared_many( input: &[F], @@ -787,28 +787,23 @@ pub fn field_int_div_by_shared_many( io_context: &mut IoContext, ) -> IoResult>> { let num_inputs = input.len(); + debug_assert_eq!(input.len(), divisors.len()); - // if divisor_bit == 0 { - // return Ok(inputs.to_owned()); - // } - // if divisor_bit >= F::MODULUS_BIT_SIZE as usize { - // return Ok(vec![Rep3PrimeFieldShare::zero_share(); num_inputs]); - // } field_to_bits_as_u16 let mut inputs_as_bits = Vec::with_capacity(F::MODULUS_BIT_SIZE as usize * num_inputs); input .iter() - .for_each(|y| inputs_as_bits.extend(GCUtils::field_to_bits_as_u16::(*y))); - let divisors_as_bits = inputs_as_bits.iter().map(|&x| x != 0).collect(); // rfield_to_bits_as_u16 returns a 0-1 vec + .for_each(|y| inputs_as_bits.extend(GCUtils::field_to_bits::(*y))); + decompose_circuit_compose_blueprint!( &divisors, io_context, num_inputs, GarbledCircuits::field_int_div_by_shared_many::<_, F>, - (divisors_as_bits) + (inputs_as_bits) ) } -/// Divides a field element by a power of 2, rounding down. +/// Divides a field element by another, rounding down. pub fn field_int_div_by_shared( input: F, divisor: Rep3PrimeFieldShare, diff --git a/mpc-core/src/protocols/rep3/yao/circuits.rs b/mpc-core/src/protocols/rep3/yao/circuits.rs index d0b2c8bf..eac2ee03 100644 --- a/mpc-core/src/protocols/rep3/yao/circuits.rs +++ b/mpc-core/src/protocols/rep3/yao/circuits.rs @@ -483,7 +483,7 @@ impl GarbledCircuits { // Prepare p for subtraction let new_bitlen = bitlen + 1; let p_ = (BigUint::from(1u64) << new_bitlen) - F::MODULUS.into(); - let p_bits = GCUtils::biguint_to_bits(p_, new_bitlen); + let p_bits = GCUtils::biguint_to_bits(&p_, new_bitlen); // manual_rca: let mut subtracted = Vec::with_capacity(bitlen); diff --git a/mpc-core/src/protocols/rep3_ring/yao.rs b/mpc-core/src/protocols/rep3_ring/yao.rs index 5f000991..3c4aa03a 100644 --- a/mpc-core/src/protocols/rep3_ring/yao.rs +++ b/mpc-core/src/protocols/rep3_ring/yao.rs @@ -63,6 +63,16 @@ impl GCUtils { res } + fn ring_to_bits(input: RingElement) -> Vec { + let mut res = Vec::with_capacity(T::K); + let mut el = input; + for _ in 0..T::K { + res.push((el & RingElement::one()) == RingElement::one()); + el >>= 1; + } + res + } + /// This puts the X_0 values into garbler_wires and X_c values into evaluator_wires pub fn encode_ring( ring: RingElement, @@ -740,7 +750,7 @@ where Standard: Distribution, { let num_inputs = input1.len(); - assert_eq!(input1.len(), input2.len()); + debug_assert_eq!(input1.len(), input2.len()); let mut combined_inputs = Vec::with_capacity(input1.len() + input2.len()); combined_inputs.extend_from_slice(input1); @@ -769,6 +779,7 @@ where let res = ring_div_many(&[input1], &[input2], io_context)?; Ok(res[0]) } + /// Divides a vector of ring elements by another public. pub fn ring_div_by_public_many( input: &[Rep3RingShare], @@ -779,12 +790,11 @@ where Standard: Distribution, { let num_inputs = input.len(); - assert_eq!(input.len(), divisors.len()); + debug_assert_eq!(input.len(), divisors.len()); let mut divisors_as_bits = Vec::with_capacity(T::K * num_inputs); divisors .iter() - .for_each(|y| divisors_as_bits.extend(GCUtils::ring_to_bits_as_u16::(*y))); - let divisors_as_bits = divisors_as_bits.iter().map(|&x| x != 0).collect(); // ring_to_bits_as_u16 returns a 0-1 vec + .for_each(|y| divisors_as_bits.extend(GCUtils::ring_to_bits::(*y))); decompose_circuit_compose_blueprint!( &input, io_context, @@ -808,7 +818,8 @@ where let res = ring_div_by_public_many(&[input], &[divisor], io_context)?; Ok(res[0]) } -/// Divides a vector of ring elements by another public. + +/// Divides a public vector of ring elements by another. pub fn ring_div_by_shared_many( input: &[RingElement], divisors: &[Rep3RingShare], @@ -822,8 +833,7 @@ where let mut input_as_bits = Vec::with_capacity(T::K * num_inputs); input .iter() - .for_each(|y| input_as_bits.extend(GCUtils::ring_to_bits_as_u16::(*y))); - let input_as_bits = input_as_bits.iter().map(|&x| x != 0).collect(); // ring_to_bits_as_u16 returns a 0-1 vec + .for_each(|y| input_as_bits.extend(GCUtils::ring_to_bits::(*y))); decompose_circuit_compose_blueprint!( &divisors, io_context, @@ -834,7 +844,7 @@ where ) } -/// Divides a ring element by another public. +/// Divides a public ring element by another. pub fn ring_div_by_shared( input: RingElement, divisor: Rep3RingShare, diff --git a/tests/tests/mpc/rep3.rs b/tests/tests/mpc/rep3.rs index 4d9b663e..61a7e192 100644 --- a/tests/tests/mpc/rep3.rs +++ b/tests/tests/mpc/rep3.rs @@ -1550,12 +1550,8 @@ mod field_share { let y = (0..VEC_SIZE) .map(|_| ark_bn254::Fr::rand(&mut rng)) .collect_vec(); - let y_1 = y.clone(); - let y_2 = y.clone(); - let y_3 = y.clone(); - let ys = [y_1, y_2, y_3]; let mut should_result = Vec::with_capacity(VEC_SIZE); - for (x, y) in x.into_iter().zip(y.into_iter()) { + for (x, y) in x.into_iter().zip(y.iter().cloned()) { let x: BigUint = x.into(); let y: BigUint = y.into(); @@ -1566,16 +1562,15 @@ mod field_share { let (tx2, rx2) = mpsc::channel(); let (tx3, rx3) = mpsc::channel(); - for (net, tx, x, y_c) in izip!( + for (net, tx, x) in izip!( test_network.get_party_networks().into_iter(), [tx1, tx2, tx3], x_shares.into_iter(), - ys.into_iter() ) { + let y_ = y.to_owned(); thread::spawn(move || { let mut rep3 = IoContext::init(net).unwrap(); - let decomposed = - yao::field_int_div_by_public_many(&x, y_c.as_ref(), &mut rep3).unwrap(); + let decomposed = yao::field_int_div_by_public_many(&x, &y_, &mut rep3).unwrap(); tx.send(decomposed) }); } @@ -1600,11 +1595,9 @@ mod field_share { .map(|_| ark_bn254::Fr::rand(&mut rng)) .collect_vec(); let y_shares = rep3::share_field_elements(&y, &mut rng); - let x_1 = x.clone(); - let x_2 = x.clone(); - let x_3 = x.clone(); + let mut should_result = Vec::with_capacity(VEC_SIZE); - for (x, y) in x.into_iter().zip(y.into_iter()) { + for (x, y) in x.iter().cloned().zip(y.into_iter()) { let x: BigUint = x.into(); let y: BigUint = y.into(); @@ -1615,16 +1608,16 @@ mod field_share { let (tx2, rx2) = mpsc::channel(); let (tx3, rx3) = mpsc::channel(); - for (net, tx, y_c, x_c) in izip!( + for (net, tx, y_c) in izip!( test_network.get_party_networks().into_iter(), [tx1, tx2, tx3], y_shares.into_iter(), - [x_1, x_2, x_3] ) { + let x_ = x.to_owned(); thread::spawn(move || { let mut rep3 = IoContext::init(net).unwrap(); - let div = yao::field_int_div_by_shared_many(x_c.as_ref(), &y_c, &mut rep3).unwrap(); + let div = yao::field_int_div_by_shared_many(&x_, &y_c, &mut rep3).unwrap(); tx.send(div) }); } diff --git a/tests/tests/mpc/rep3_ring.rs b/tests/tests/mpc/rep3_ring.rs index 4aff6cfe..34156942 100644 --- a/tests/tests/mpc/rep3_ring.rs +++ b/tests/tests/mpc/rep3_ring.rs @@ -51,6 +51,18 @@ mod ring_share { }; } + fn gen_non_zero(rng: &mut R) -> RingElement + where + Standard: Distribution, + { + loop { + let el = rng.gen::>(); + if !el.is_zero() { + return el; + } + } + } + // TODO we dont need channels, we can just join fn rep3_add_t() @@ -60,7 +72,6 @@ mod ring_share { let mut rng = thread_rng(); let x = rng.gen::>(); let y = rng.gen::>(); - println!("x {x} y {x}"); let x_shares = rep3_ring::share_ring_element(x, &mut rng); let y_shares = rep3_ring::share_ring_element(y, &mut rng); let should_result = x + y; @@ -1728,6 +1739,12 @@ mod ring_share { let is_result = rep3_ring::combine_ring_elements(&result1, &result2, &result3); assert_eq!(is_result, should_result); } + + #[test] + fn rep3_div_power_2_via_yao() { + apply_to_all!(rep3_div_power_2_via_yao_t, [Bit, u8, u16, u32, u64, u128]); + } + fn rep3_bin_div_via_yao_t() where Standard: Distribution, @@ -1740,7 +1757,7 @@ mod ring_share { .map(|_| rng.gen::>()) .collect_vec(); let y = (0..VEC_SIZE) - .map(|_| rng.gen::>()) + .map(|_| gen_non_zero::(&mut rng)) .collect_vec(); let x_shares = rep3_ring::share_ring_elements(&x, &mut rng); let y_shares = rep3_ring::share_ring_elements(&y, &mut rng); @@ -1774,6 +1791,12 @@ mod ring_share { let is_result = rep3_ring::combine_ring_elements(&result1, &result2, &result3); assert_eq!(is_result, should_result); } + + #[test] + fn rep3_bin_div_via_yao() { + apply_to_all!(rep3_bin_div_via_yao_t, [u8, u16, u32, u64, u128]); + } + fn rep3_bin_div_by_public_via_yao_t() where Standard: Distribution, @@ -1786,14 +1809,11 @@ mod ring_share { .map(|_| rng.gen::>()) .collect_vec(); let y = (0..VEC_SIZE) - .map(|_| rng.gen::>()) + .map(|_| gen_non_zero::(&mut rng)) .collect_vec(); - let y_1 = y.clone(); - let y_2 = y.clone(); - let y_3 = y.clone(); let x_shares = rep3_ring::share_ring_elements(&x, &mut rng); let mut should_result: Vec> = Vec::with_capacity(VEC_SIZE); - for (x, y) in x.into_iter().zip(y.into_iter()) { + for (x, y) in x.into_iter().zip(y.iter()) { should_result.push(RingElement(T::cast_from_biguint( &(x.0.cast_to_biguint() / y.0.cast_to_biguint()), ))); @@ -1802,16 +1822,16 @@ mod ring_share { let (tx2, rx2) = mpsc::channel(); let (tx3, rx3) = mpsc::channel(); - for (net, tx, x, y_c) in izip!( + for (net, tx, x) in izip!( test_network.get_party_networks().into_iter(), [tx1, tx2, tx3], x_shares.into_iter(), - [y_1, y_2, y_3] ) { + let y_ = y.to_owned(); thread::spawn(move || { let mut rep3 = IoContext::init(net).unwrap(); - let div = yao::ring_div_by_public_many(&x, &y_c, &mut rep3).unwrap(); + let div = yao::ring_div_by_public_many(&x, &y_, &mut rep3).unwrap(); tx.send(div) }); } @@ -1822,6 +1842,12 @@ mod ring_share { let is_result = rep3_ring::combine_ring_elements(&result1, &result2, &result3); assert_eq!(is_result, should_result); } + + #[test] + fn rep3_bin_div_by_public_via_yao() { + apply_to_all!(rep3_bin_div_by_public_via_yao_t, [u8, u16, u32, u64, u128]); + } + fn rep3_bin_div_by_shared_via_yao_t() where Standard: Distribution, @@ -1834,14 +1860,11 @@ mod ring_share { .map(|_| rng.gen::>()) .collect_vec(); let y = (0..VEC_SIZE) - .map(|_| rng.gen::>()) + .map(|_| gen_non_zero::(&mut rng)) .collect_vec(); - let x_1 = x.clone(); - let x_2 = x.clone(); - let x_3 = x.clone(); let y_shares = rep3_ring::share_ring_elements(&y, &mut rng); let mut should_result: Vec> = Vec::with_capacity(VEC_SIZE); - for (x, y) in x.into_iter().zip(y.into_iter()) { + for (x, y) in x.iter().zip(y.into_iter()) { should_result.push(RingElement(T::cast_from_biguint( &(x.0.cast_to_biguint() / y.0.cast_to_biguint()), ))); @@ -1850,16 +1873,16 @@ mod ring_share { let (tx2, rx2) = mpsc::channel(); let (tx3, rx3) = mpsc::channel(); - for (net, tx, y_c, x_c) in izip!( + for (net, tx, y_c) in izip!( test_network.get_party_networks().into_iter(), [tx1, tx2, tx3], y_shares.into_iter(), - [x_1, x_2, x_3] ) { + let x_ = x.to_owned(); thread::spawn(move || { let mut rep3 = IoContext::init(net).unwrap(); - let div = yao::ring_div_by_shared_many(&x_c, &y_c, &mut rep3).unwrap(); + let div = yao::ring_div_by_shared_many(&x_, &y_c, &mut rep3).unwrap(); tx.send(div) }); } @@ -1870,21 +1893,9 @@ mod ring_share { let is_result = rep3_ring::combine_ring_elements(&result1, &result2, &result3); assert_eq!(is_result, should_result); } + #[test] fn rep3_bin_div_by_shared_via_yao() { apply_to_all!(rep3_bin_div_by_shared_via_yao_t, [u8, u16, u32, u64, u128]); } - #[test] - fn rep3_bin_div_by_public_via_yao() { - apply_to_all!(rep3_bin_div_by_public_via_yao_t, [u8, u16, u32, u64, u128]); - } - #[test] - fn rep3_bin_div_via_yao() { - apply_to_all!(rep3_bin_div_via_yao_t, [u8, u16, u32, u64, u128]); - } - - #[test] - fn rep3_div_power_2_via_yao() { - apply_to_all!(rep3_div_power_2_via_yao_t, [Bit, u8, u16, u32, u64, u128]); - } }