Skip to content

Commit

Permalink
chore: Define upper bound for decompose arithmetic_many, clean up som…
Browse files Browse the repository at this point in the history
…e code and prevent division by 0 in testcases
  • Loading branch information
rw0x0 committed Jan 8, 2025
1 parent a6e2fe7 commit 8608e24
Show file tree
Hide file tree
Showing 13 changed files with 178 additions and 168 deletions.
7 changes: 3 additions & 4 deletions co-noir/co-acvm/src/mpc.rs
Original file line number Diff line number Diff line change
Expand Up @@ -142,15 +142,14 @@ pub trait NoirWitnessExtensionProtocol<F: PrimeField> {
fn decompose_arithmetic(
&mut self,
input: Self::ArithmeticShare,
// io_context: &mut IoContext<N>,
total_bit_size_per_field: usize,
decompose_bit_size: usize,
) -> std::io::Result<Vec<Self::ArithmeticShare>>;
/// Decompose a shared value into a vector of shared values: \[a\] = a_1 + a_2 + ... + a_n. Each value a_i has at most decompose_bit_size bits, whereas the total bit size of the shares is total_bit_size_per_field. Thus, a_n, might have a smaller bitsize than the other chunks
fn decompose_arithmetic_many(

/// Decompose a shared value into a vector of shared values: \[a\] = a_1 + a_2 + ... + a_n. Each value a_i has at most decompose_bit_size bits, whereas the total bit size of the shares is total_bit_size_per_field. Thus, a_n, might have a smaller bitsize than the other chunks
fn decompose_arithmetic_many(
&mut self,
input: &[Self::ArithmeticShare],
// io_context: &mut IoContext<N>,
total_bit_size_per_field: usize,
decompose_bit_size: usize,
) -> std::io::Result<Vec<Vec<Self::ArithmeticShare>>>;
Expand Down
29 changes: 14 additions & 15 deletions co-noir/co-acvm/src/mpc/plain.rs
Original file line number Diff line number Diff line change
Expand Up @@ -188,6 +188,20 @@ impl<F: PrimeField> NoirWitnessExtensionProtocol<F> for PlainAcvmSolver<F> {
Ok(result)
}

fn decompose_arithmetic_many(
&mut self,
input: &[Self::ArithmeticShare],
total_bit_size_per_field: usize,
decompose_bit_size: usize,
) -> std::io::Result<Vec<Vec<Self::ArithmeticShare>>> {
input
.iter()
.map(|&inp| {
Self::decompose_arithmetic(self, inp, total_bit_size_per_field, decompose_bit_size)
})
.collect()
}

fn sort(
&mut self,
inputs: &[Self::ArithmeticShare],
Expand All @@ -203,19 +217,4 @@ impl<F: PrimeField> NoirWitnessExtensionProtocol<F> for PlainAcvmSolver<F> {
result.sort();
Ok(result)
}

fn decompose_arithmetic_many(
&mut self,
input: &[Self::ArithmeticShare],
// io_context: &mut IoContext<N>,
total_bit_size_per_field: usize,
decompose_bit_size: usize,
) -> std::io::Result<Vec<Vec<Self::ArithmeticShare>>> {
input
.iter()
.map(|&inp| {
Self::decompose_arithmetic(self, inp, total_bit_size_per_field, decompose_bit_size)
})
.collect()
}
}
43 changes: 24 additions & 19 deletions co-noir/co-acvm/src/mpc/rep3.rs
Original file line number Diff line number Diff line change
Expand Up @@ -491,32 +491,37 @@ impl<F: PrimeField, N: Rep3Network> NoirWitnessExtensionProtocol<F> for Rep3Acvm
)
}

fn sort(
&mut self,
inputs: &[Self::ArithmeticShare],
bitsize: usize,
) -> std::io::Result<Vec<Self::ArithmeticShare>> {
radix_sort_fields(inputs, &mut self.io_context, bitsize)
}

fn decompose_arithmetic_many(
&mut self,
input: &[Self::ArithmeticShare],
// io_context: &mut IoContext<N>,
total_bit_size_per_field: usize,
decompose_bit_size: usize,
) -> std::io::Result<Vec<Vec<Self::ArithmeticShare>>> {
// Defines an upper bound on the size of the input vector to keep the GC at a reasonable size (for RAM)
const BATCH_SIZE: usize = 512; // TODO adapt this if it requires too much RAM

let num_decomps_per_field = total_bit_size_per_field.div_ceil(decompose_bit_size);
let result = yao::decompose_arithmetic_many(
input,
&mut self.io_context,
total_bit_size_per_field,
decompose_bit_size,
)?;
let results = result
.chunks(num_decomps_per_field)
.map(|chunk| chunk.to_vec())
.collect();
let mut results = Vec::with_capacity(input.len());

for inp_chunk in input.chunks(BATCH_SIZE) {
let result = yao::decompose_arithmetic_many(
inp_chunk,
&mut self.io_context,
total_bit_size_per_field,
decompose_bit_size,
)?;
for chunk in result.chunks(num_decomps_per_field) {
results.push(chunk.to_vec());
}
}
Ok(results)
}

fn sort(
&mut self,
inputs: &[Self::ArithmeticShare],
bitsize: usize,
) -> std::io::Result<Vec<Self::ArithmeticShare>> {
radix_sort_fields(inputs, &mut self.io_context, bitsize)
}
}
18 changes: 8 additions & 10 deletions co-noir/co-acvm/src/mpc/shamir.rs
Original file line number Diff line number Diff line change
Expand Up @@ -407,22 +407,20 @@ impl<F: PrimeField, N: ShamirNetwork> NoirWitnessExtensionProtocol<F> for Shamir
) -> std::io::Result<Vec<Self::ArithmeticShare>> {
panic!("functionality decompose_arithmetic not feasible for Shamir")
}

fn sort(
&mut self,
_inputs: &[Self::ArithmeticShare],
_bitsize: usize,
) -> std::io::Result<Vec<Self::ArithmeticShare>> {
panic!("functionality sort not feasible for Shamir")
}

fn decompose_arithmetic_many(
&mut self,
_input: &[Self::ArithmeticShare],
// io_context: &mut IoContext<N>,
_total_bit_size_per_field: usize,
_decompose_bit_size: usize,
) -> std::io::Result<Vec<Vec<Self::ArithmeticShare>>> {
panic!("functionality decompose_arithmetic_many not feasible for Shamir")
}

fn sort(
&mut self,
_inputs: &[Self::ArithmeticShare],
_bitsize: usize,
) -> std::io::Result<Vec<Self::ArithmeticShare>> {
panic!("functionality sort not feasible for Shamir")
}
}
21 changes: 10 additions & 11 deletions co-noir/co-brillig/src/mpc/rep3.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
use super::{BrilligDriver, PlainBrilligDriver};
use ark_ff::{One as _, PrimeField};
use brillig::{BitSize, IntegerBitSize};
use mpc_core::protocols::rep3_ring::conversion::a2b;
use core::panic;
use mpc_core::protocols::rep3::network::{IoContext, Rep3Network};
use mpc_core::protocols::rep3::{self, Rep3PrimeFieldShare};
use mpc_core::protocols::rep3_ring::conversion::a2b;
use mpc_core::protocols::rep3_ring::ring::bit::Bit;
use mpc_core::protocols::rep3_ring::ring::int_ring::IntRing2k;
use mpc_core::protocols::rep3_ring::ring::ring_impl::RingElement;
Expand Down Expand Up @@ -1414,8 +1414,6 @@ impl<F: PrimeField, N: Rep3Network> BrilligDriver<F> for Rep3BrilligDriver<F, N>
(Rep3BrilligType::Shared(val), Rep3BrilligType::Public(radix)) => {
if let (Shared::Field(val), Public::Int(radix, IntegerBitSize::U32)) = (val, radix)
{


let radix = u32::try_from(radix).expect("must be u32");
let mut input = val;
assert!(radix <= 256, "radix is at most 256");
Expand Down Expand Up @@ -1447,14 +1445,15 @@ impl<F: PrimeField, N: Rep3Network> BrilligDriver<F> for Rep3BrilligDriver<F, N>
&mut self.io_context,
)?; //radix is at most 256, so should fit into u8, but is this necessary?
if bits {
let limb_2b = a2b(limb[0], &mut self.io_context)?;
let limb_bit = rep3_ring::conversion::bit_inject(&limb_2b, &mut self.io_context)?;
limbs[i] = Rep3BrilligType::Shared(Shared::<F>::Ring8(limb_bit));
}
else {

limbs[i] = Rep3BrilligType::Shared(Shared::<F>::Ring8(limb[0]));
};
let limb_2b = a2b(limb[0], &mut self.io_context)?;
let limb_bit = rep3_ring::conversion::bit_inject(
&limb_2b,
&mut self.io_context,
)?;
limbs[i] = Rep3BrilligType::Shared(Shared::<F>::Ring8(limb_bit));
} else {
limbs[i] = Rep3BrilligType::Shared(Shared::<F>::Ring8(limb[0]));
};
input = div;
}
limbs
Expand Down
49 changes: 25 additions & 24 deletions co-noir/co-builder/src/builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -631,6 +631,7 @@ impl<P: Pairing, T: NoirWitnessExtensionProtocol<P::ScalarField>> GenericUltraCi
}
}
}

// decomposes the shared values in batches, separated into the corresponding number of bits the values have
#[expect(clippy::type_complexity)]
fn prepare_for_range_decompose(
Expand All @@ -649,19 +650,15 @@ impl<P: Pairing, T: NoirWitnessExtensionProtocol<P::ScalarField>> GenericUltraCi
for constraint in range_constraints.iter() {
let val = &self.get_variable(constraint.witness as usize);

if constraint.num_bits > Self::DEFAULT_PLOOKUP_RANGE_BITNUM as u32 && T::is_shared(val)
{
let num_bits = constraint.num_bits;

let num_bits = constraint.num_bits;
if num_bits > Self::DEFAULT_PLOOKUP_RANGE_BITNUM as u32 && T::is_shared(val) {
let share_val = T::get_shared(val).expect("Already checked it is shared");
if let Some(&idx) = bits_locations.get(&num_bits) {
to_decompose[idx]
.push(T::get_shared(val).expect("Already checked it is shared"));
to_decompose[idx].push(share_val);
decompose_indices.push((true, to_decompose[idx].len() - 1));
} else {
let new_idx = to_decompose.len();
to_decompose.push(vec![
T::get_shared(val).expect("Already checked it is shared")
]);
to_decompose.push(vec![share_val]);
decompose_indices.push((true, 0));
bits_locations.insert(num_bits, new_idx);
}
Expand All @@ -670,17 +667,17 @@ impl<P: Pairing, T: NoirWitnessExtensionProtocol<P::ScalarField>> GenericUltraCi
}
}

let mut decomposed: Vec<Vec<Vec<_>>> = Vec::with_capacity(to_decompose.len());
let mut decomposed = Vec::with_capacity(to_decompose.len());

for (i, inp) in to_decompose.iter().enumerate() {
for (i, inp) in to_decompose.into_iter().enumerate() {
let num_bits = bits_locations
.iter()
.find_map(|(&key, &value)| if value == i { Some(key) } else { None })
.expect("Index not found in bitsloc");

decomposed.push(T::decompose_arithmetic_many(
driver,
inp,
&inp,
num_bits as usize,
Self::DEFAULT_PLOOKUP_RANGE_BITNUM,
)?);
Expand Down Expand Up @@ -775,23 +772,24 @@ impl<P: Pairing, T: NoirWitnessExtensionProtocol<P::ScalarField>> GenericUltraCi
// todo!("Logic gates");
// }

// We want to decompose all shared elements in parallel
let (bits_locations, decomposed, decompose_indices) =
self.prepare_for_range_decompose(driver, &constraint_system.range_constraints)?;

for (i, constraint) in constraint_system.range_constraints.iter().enumerate() {
if let Some(&idx) = bits_locations.get(&constraint.num_bits) {
if decompose_indices[i].0 {
self.decompose_into_default_range(
driver,
constraint.witness,
constraint.num_bits as u64,
Some(&decomposed[idx][decompose_indices[i].1]),
Self::DEFAULT_PLOOKUP_RANGE_BITNUM as u64,
)?;
} else {
self.create_range_constraint(driver, constraint.witness, constraint.num_bits)?;
}
let idx_option = bits_locations.get(&constraint.num_bits);
if idx_option.is_some() && decompose_indices[i].0 {
// Already decomposed
let idx = idx_option.unwrap().to_owned();
self.decompose_into_default_range(
driver,
constraint.witness,
constraint.num_bits as u64,
Some(&decomposed[idx][decompose_indices[i].1]),
Self::DEFAULT_PLOOKUP_RANGE_BITNUM as u64,
)?;
} else {
// Either we do not have to decompose or the value is public
self.create_range_constraint(driver, constraint.witness, constraint.num_bits)?;
}

Expand Down Expand Up @@ -2419,6 +2417,7 @@ impl<P: Pairing, T: NoirWitnessExtensionProtocol<P::ScalarField>> GenericUltraCi

self.create_new_range_constraint(variable_index, (1u64 << num_bits) - 1);
} else {
// The value must be public, otherwise it would have been batch decomposed already
self.decompose_into_default_range(
driver,
variable_index,
Expand Down Expand Up @@ -2529,11 +2528,13 @@ impl<P: Pairing, T: NoirWitnessExtensionProtocol<P::ScalarField>> GenericUltraCi

let mut sublimb_indices: Vec<u32> = Vec::with_capacity(num_limbs as usize);
let sublimbs: Vec<T::AcvmType> = match decompose {
// Already decomposed, i.e., we just take the values
Some(decomposed) => decomposed
.iter()
.map(|item| T::AcvmType::from(item.clone()))
.collect(),
None => {
// Not yet decomposed, i.e., it was a public value
let mut accumulator: BigUint = T::get_public(&val)
.expect("Already checked it is public")
.into();
Expand Down
2 changes: 1 addition & 1 deletion co-noir/co-noir/examples/test_vectors/ranges/Nargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,6 @@
name = "ranges"
type = "bin"
authors = [""]
compiler_version = ">=0.38.0"
compiler_version = ">=1.0.0"

[dependencies]
2 changes: 1 addition & 1 deletion co-noir/co-noir/examples/test_vectors/ranges/src/main.nr
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
fn main(x: u64,w: u16,u: u32, y: pub u64, z: pub u64, s: u32, t: u16) -> pub (u64, u32, u16) {
fn main(x: u64, w: u16, u: u32, y: pub u64, z: pub u64, s: u32, t: u16) -> pub (u64, u32, u16) {
(x + y + z, u * s, w * t)
}
Loading

0 comments on commit 8608e24

Please sign in to comment.