Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feat/batch decompose #299

Merged
merged 17 commits into from
Jan 9, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 8 additions & 1 deletion co-noir/co-acvm/src/mpc.rs
Original file line number Diff line number Diff line change
Expand Up @@ -142,11 +142,18 @@ pub trait NoirWitnessExtensionProtocol<F: PrimeField> {
/// Decompose a shared value into a vector of shared values: \[a\] = a_1 + a_2 + ... + a_n.
/// Each value a_i has at most `decompose_bit_size` bits, whereas the total bit size of the
/// shares is `total_bit_size_per_field`. Thus, a_n might have a smaller bit size than the
/// other chunks.
fn decompose_arithmetic(
    &mut self,
    input: Self::ArithmeticShare,
    total_bit_size_per_field: usize,
    decompose_bit_size: usize,
) -> std::io::Result<Vec<Self::ArithmeticShare>>;

/// Batched version of [`Self::decompose_arithmetic`]: decomposes each entry of `input`
/// into a vector of shared values: \[a\] = a_1 + a_2 + ... + a_n. Each value a_i has at
/// most `decompose_bit_size` bits, whereas the total bit size of the shares is
/// `total_bit_size_per_field`. Thus, a_n might have a smaller bit size than the other
/// chunks. Returns one vector of chunks per input element, in input order.
fn decompose_arithmetic_many(
    &mut self,
    input: &[Self::ArithmeticShare],
    total_bit_size_per_field: usize,
    decompose_bit_size: usize,
) -> std::io::Result<Vec<Vec<Self::ArithmeticShare>>>;

/// Sorts a vector of shared values in ascending order, only considering the first bitsize bits.
fn sort(
&mut self,
Expand Down
14 changes: 14 additions & 0 deletions co-noir/co-acvm/src/mpc/plain.rs
Original file line number Diff line number Diff line change
Expand Up @@ -188,6 +188,20 @@ impl<F: PrimeField> NoirWitnessExtensionProtocol<F> for PlainAcvmSolver<F> {
Ok(result)
}

fn decompose_arithmetic_many(
    &mut self,
    input: &[Self::ArithmeticShare],
    total_bit_size_per_field: usize,
    decompose_bit_size: usize,
) -> std::io::Result<Vec<Vec<Self::ArithmeticShare>>> {
    // In the plain (non-MPC) solver the batched decomposition is just the
    // element-wise decomposition of every input value, in input order.
    let mut decomposed = Vec::with_capacity(input.len());
    for &value in input {
        decomposed.push(Self::decompose_arithmetic(
            self,
            value,
            total_bit_size_per_field,
            decompose_bit_size,
        )?);
    }
    Ok(decomposed)
}

fn sort(
&mut self,
inputs: &[Self::ArithmeticShare],
Expand Down
26 changes: 26 additions & 0 deletions co-noir/co-acvm/src/mpc/rep3.rs
Original file line number Diff line number Diff line change
Expand Up @@ -491,6 +491,32 @@ impl<F: PrimeField, N: Rep3Network> NoirWitnessExtensionProtocol<F> for Rep3Acvm
)
}

fn decompose_arithmetic_many(
    &mut self,
    input: &[Self::ArithmeticShare],
    total_bit_size_per_field: usize,
    decompose_bit_size: usize,
) -> std::io::Result<Vec<Vec<Self::ArithmeticShare>>> {
    // Upper bound on how many inputs go into one garbled-circuit invocation,
    // to keep its RAM footprint reasonable.
    const BATCH_SIZE: usize = 512; // TODO adapt this if it requires too much RAM

    // Number of chunks each field element decomposes into.
    let num_decomps_per_field = total_bit_size_per_field.div_ceil(decompose_bit_size);

    let mut results = Vec::with_capacity(input.len());
    for inp_chunk in input.chunks(BATCH_SIZE) {
        // The GC returns all chunks of the batch flattened; regroup them so
        // that every input element gets its own vector of chunks.
        let flat = yao::decompose_arithmetic_many(
            inp_chunk,
            &mut self.io_context,
            total_bit_size_per_field,
            decompose_bit_size,
        )?;
        results.extend(
            flat.chunks(num_decomps_per_field)
                .map(|field_chunks| field_chunks.to_vec()),
        );
    }
    Ok(results)
}

fn sort(
&mut self,
inputs: &[Self::ArithmeticShare],
Expand Down
8 changes: 8 additions & 0 deletions co-noir/co-acvm/src/mpc/shamir.rs
Original file line number Diff line number Diff line change
Expand Up @@ -407,6 +407,14 @@ impl<F: PrimeField, N: ShamirNetwork> NoirWitnessExtensionProtocol<F> for Shamir
) -> std::io::Result<Vec<Self::ArithmeticShare>> {
panic!("functionality decompose_arithmetic not feasible for Shamir")
}
/// Batched variant of `decompose_arithmetic`. Bit decomposition is not implemented
/// for the Shamir protocol, so this always panics; callers must not reach this path
/// when running with Shamir sharing.
fn decompose_arithmetic_many(
    &mut self,
    _input: &[Self::ArithmeticShare],
    _total_bit_size_per_field: usize,
    _decompose_bit_size: usize,
) -> std::io::Result<Vec<Vec<Self::ArithmeticShare>>> {
    panic!("functionality decompose_arithmetic_many not feasible for Shamir")
}

fn sort(
&mut self,
Expand Down
238 changes: 215 additions & 23 deletions co-noir/co-brillig/src/mpc/rep3.rs

Large diffs are not rendered by default.

123 changes: 99 additions & 24 deletions co-noir/co-builder/src/builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,8 @@ use crate::{
AddQuad, AddTriple, AggregationObjectIndices, AggregationObjectPubInputIndices,
AuxSelectors, BlockConstraint, BlockType, CachedPartialNonNativeFieldMultiplication,
ColumnIdx, FieldCT, GateCounter, MulQuad, PlookupBasicTable, PolyTriple, RamTranscript,
RangeList, ReadData, RomRecord, RomTable, RomTranscript, UltraTraceBlock,
UltraTraceBlocks, NUM_WIRES,
RangeConstraint, RangeList, ReadData, RomRecord, RomTable, RomTranscript,
UltraTraceBlock, UltraTraceBlocks, NUM_WIRES,
},
},
utils::Utils,
Expand All @@ -23,7 +23,7 @@ use ark_ec::pairing::Pairing;
use ark_ff::{One, Zero};
use co_acvm::{mpc::NoirWitnessExtensionProtocol, PlainAcvmSolver};
use num_bigint::BigUint;
use std::collections::BTreeMap;
use std::collections::{BTreeMap, HashMap};

type GateBlocks<F> = UltraTraceBlocks<UltraTraceBlock<F>>;

Expand Down Expand Up @@ -632,6 +632,59 @@ impl<P: Pairing, T: NoirWitnessExtensionProtocol<P::ScalarField>> GenericUltraCi
}
}

// Decomposes the shared values of the given range constraints in batches, one batch
// per distinct bit width.
//
// Returns:
// - a map from bit width -> index of its batch in the decomposed output,
// - the decomposed shares, grouped as [batch][element][chunk],
// - for every constraint, `(was_decomposed, index_within_its_batch)`.
#[expect(clippy::type_complexity)]
fn prepare_for_range_decompose(
    &mut self,
    driver: &mut T,
    range_constraints: &[RangeConstraint],
) -> std::io::Result<(
    HashMap<u32, usize>,
    Vec<Vec<Vec<T::ArithmeticShare>>>,
    Vec<(bool, usize)>,
)> {
    let mut to_decompose: Vec<Vec<T::ArithmeticShare>> = vec![];
    // Bit width of each batch in `to_decompose`, kept in lockstep so we do not
    // need an O(batches) reverse lookup through `bits_locations` later.
    let mut batch_bits: Vec<u32> = vec![];
    let mut decompose_indices: Vec<(bool, usize)> = vec![];
    let mut bits_locations: HashMap<u32, usize> = HashMap::new();

    for constraint in range_constraints.iter() {
        let val = &self.get_variable(constraint.witness as usize);

        let num_bits = constraint.num_bits;
        // Only shared values above the default plookup range need an explicit
        // decomposition; public/small values are handled elsewhere.
        if num_bits > Self::DEFAULT_PLOOKUP_RANGE_BITNUM as u32 && T::is_shared(val) {
            let share_val = T::get_shared(val).expect("Already checked it is shared");
            if let Some(&idx) = bits_locations.get(&num_bits) {
                // Existing batch for this bit width: append and remember the slot.
                to_decompose[idx].push(share_val);
                decompose_indices.push((true, to_decompose[idx].len() - 1));
            } else {
                // First value with this bit width: open a new batch.
                let new_idx = to_decompose.len();
                to_decompose.push(vec![share_val]);
                batch_bits.push(num_bits);
                decompose_indices.push((true, 0));
                bits_locations.insert(num_bits, new_idx);
            }
        } else {
            decompose_indices.push((false, 0));
        }
    }

    let mut decomposed = Vec::with_capacity(to_decompose.len());

    // Each batch shares one bit width, so one MPC call decomposes the whole batch.
    for (inp, num_bits) in to_decompose.into_iter().zip(batch_bits) {
        decomposed.push(T::decompose_arithmetic_many(
            driver,
            &inp,
            num_bits as usize,
            Self::DEFAULT_PLOOKUP_RANGE_BITNUM,
        )?);
    }
    Ok((bits_locations, decomposed, decompose_indices))
}

fn build_constraints(
&mut self,
driver: &mut T,
Expand Down Expand Up @@ -719,8 +772,27 @@ impl<P: Pairing, T: NoirWitnessExtensionProtocol<P::ScalarField>> GenericUltraCi
// todo!("Logic gates");
// }

// We want to decompose all shared elements in parallel
let (bits_locations, decomposed, decompose_indices) =
self.prepare_for_range_decompose(driver, &constraint_system.range_constraints)?;

for (i, constraint) in constraint_system.range_constraints.iter().enumerate() {
self.create_range_constraint(driver, constraint.witness, constraint.num_bits)?;
let idx_option = bits_locations.get(&constraint.num_bits);
if idx_option.is_some() && decompose_indices[i].0 {
// Already decomposed
let idx = idx_option.unwrap().to_owned();
self.decompose_into_default_range(
driver,
constraint.witness,
constraint.num_bits as u64,
Some(&decomposed[idx][decompose_indices[i].1]),
Self::DEFAULT_PLOOKUP_RANGE_BITNUM as u64,
)?;
} else {
// Either we do not have to decompose or the value is public
self.create_range_constraint(driver, constraint.witness, constraint.num_bits)?;
}

gate_counter.track_diff(
self,
&mut constraint_system.gates_per_opcode,
Expand Down Expand Up @@ -2345,10 +2417,12 @@ impl<P: Pairing, T: NoirWitnessExtensionProtocol<P::ScalarField>> GenericUltraCi

self.create_new_range_constraint(variable_index, (1u64 << num_bits) - 1);
} else {
// The value must be public, otherwise it would have been batch decomposed already
self.decompose_into_default_range(
driver,
variable_index,
num_bits as u64,
None,
Self::DEFAULT_PLOOKUP_RANGE_BITNUM as u64,
)?;
}
Expand Down Expand Up @@ -2420,12 +2494,14 @@ impl<P: Pairing, T: NoirWitnessExtensionProtocol<P::ScalarField>> GenericUltraCi
driver: &mut T,
variable_index: u32,
num_bits: u64,
decompose: Option<&[T::ArithmeticShare]>,
target_range_bitnum: u64,
) -> std::io::Result<Vec<u32>> {
assert!(self.is_valid_variable(variable_index as usize));

assert!(num_bits > 0);
let val = self.get_variable(variable_index as usize);

// We cannot check that easily in MPC:
// If the value is out of range, set the composer error to the given msg.
// if val.msb() >= num_bits && !self.failed() {
Expand All @@ -2451,28 +2527,27 @@ impl<P: Pairing, T: NoirWitnessExtensionProtocol<P::ScalarField>> GenericUltraCi
let last_limb_range = (1u64 << last_limb_size) - 1;

let mut sublimb_indices: Vec<u32> = Vec::with_capacity(num_limbs as usize);
let sublimbs = if T::is_shared(&val) {
let decomp = T::decompose_arithmetic(
driver,
T::get_shared(&val).expect("Already checked it is shared"),
num_bits as usize,
target_range_bitnum as usize,
)?;
decomp.into_iter().map(T::AcvmType::from).collect()
} else {
let mut sublimbs = Vec::with_capacity(num_limbs as usize);
let mut accumulator: BigUint = T::get_public(&val)
.expect("Already checked it is public")
.into();
for _ in 0..num_limbs {
let sublimb_value = P::ScalarField::from(&accumulator & &sublimb_mask.into());
sublimbs.push(T::AcvmType::from(sublimb_value));
accumulator >>= target_range_bitnum;
let sublimbs: Vec<T::AcvmType> = match decompose {
// Already decomposed, i.e., we just take the values
Some(decomposed) => decomposed
.iter()
.map(|item| T::AcvmType::from(item.clone()))
.collect(),
None => {
// Not yet decomposed, i.e., it was a public value
let mut accumulator: BigUint = T::get_public(&val)
.expect("Already checked it is public")
.into();
let sublimb_mask: BigUint = sublimb_mask.into();
(0..num_limbs)
.map(|_| {
let sublimb_value = P::ScalarField::from(&accumulator & &sublimb_mask);
accumulator >>= target_range_bitnum;
T::AcvmType::from(sublimb_value)
})
.collect()
}

sublimbs
};

for (i, sublimb) in sublimbs.iter().enumerate() {
let limb_idx = self.add_variable(sublimb.clone());

Expand Down
16 changes: 16 additions & 0 deletions co-noir/co-noir/examples/run_full_ranges.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
# End-to-end example for the `ranges` test vector: secret-share the prover input,
# run the witness extension and proof generation in 3-party REP3 MPC (parties run
# concurrently in the background; `wait` joins them), then create the verification
# key and verify the proof. Run from the examples directory.

# split input into shares
cargo run --release --bin co-noir -- split-input --circuit test_vectors/ranges/ranges.json --input test_vectors/ranges/Prover.toml --protocol REP3 --out-dir test_vectors/ranges
# run witness extension in MPC
cargo run --release --bin co-noir -- generate-witness --input test_vectors/ranges/Prover.toml.0.shared --circuit test_vectors/ranges/ranges.json --protocol REP3 --config configs/party1.toml --out test_vectors/ranges/ranges.gz.0.shared &
cargo run --release --bin co-noir -- generate-witness --input test_vectors/ranges/Prover.toml.1.shared --circuit test_vectors/ranges/ranges.json --protocol REP3 --config configs/party2.toml --out test_vectors/ranges/ranges.gz.1.shared &
cargo run --release --bin co-noir -- generate-witness --input test_vectors/ranges/Prover.toml.2.shared --circuit test_vectors/ranges/ranges.json --protocol REP3 --config configs/party3.toml --out test_vectors/ranges/ranges.gz.2.shared
wait $(jobs -p)
# run proving in MPC
cargo run --release --bin co-noir -- build-and-generate-proof --witness test_vectors/ranges/ranges.gz.0.shared --circuit test_vectors/ranges/ranges.json --crs test_vectors/bn254_g1.dat --protocol REP3 --hasher KECCAK --config configs/party1.toml --out proof.0.proof --public-input public_input.json &
cargo run --release --bin co-noir -- build-and-generate-proof --witness test_vectors/ranges/ranges.gz.1.shared --circuit test_vectors/ranges/ranges.json --crs test_vectors/bn254_g1.dat --protocol REP3 --hasher KECCAK --config configs/party2.toml --out proof.1.proof &
cargo run --release --bin co-noir -- build-and-generate-proof --witness test_vectors/ranges/ranges.gz.2.shared --circuit test_vectors/ranges/ranges.json --crs test_vectors/bn254_g1.dat --protocol REP3 --hasher KECCAK --config configs/party3.toml --out proof.2.proof
wait $(jobs -p)
# Create verification key
cargo run --release --bin co-noir -- create-vk --circuit test_vectors/ranges/ranges.json --crs test_vectors/bn254_g1.dat --hasher KECCAK --vk test_vectors/ranges/verification_key
# verify proof
cargo run --release --bin co-noir -- verify --proof proof.0.proof --vk test_vectors/ranges/verification_key --hasher KECCAK --crs test_vectors/bn254_g2.dat
7 changes: 7 additions & 0 deletions co-noir/co-noir/examples/test_vectors/ranges/Nargo.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
[package]
name = "ranges"
type = "bin"
authors = [""]
compiler_version = ">=1.0.0"

[dependencies]
7 changes: 7 additions & 0 deletions co-noir/co-noir/examples/test_vectors/ranges/Prover.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
x = "1"
w = "30"
u = "32"
y = "2"
z = "3"
s = "12"
t = "121"
Binary file not shown.
1 change: 1 addition & 0 deletions co-noir/co-noir/examples/test_vectors/ranges/ranges.json
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
{"noir_version":"1.0.0-beta.0+7311d8ca566c3b3e0744389fc5e4163741927767","hash":13046917000996441616,"abi":{"parameters":[{"name":"x","type":{"kind":"integer","sign":"unsigned","width":64},"visibility":"private"},{"name":"w","type":{"kind":"integer","sign":"unsigned","width":16},"visibility":"private"},{"name":"u","type":{"kind":"integer","sign":"unsigned","width":32},"visibility":"private"},{"name":"y","type":{"kind":"integer","sign":"unsigned","width":64},"visibility":"public"},{"name":"z","type":{"kind":"integer","sign":"unsigned","width":64},"visibility":"public"},{"name":"s","type":{"kind":"integer","sign":"unsigned","width":32},"visibility":"private"},{"name":"t","type":{"kind":"integer","sign":"unsigned","width":16},"visibility":"private"}],"return_type":{"abi_type":{"kind":"tuple","fields":[{"kind":"integer","sign":"unsigned","width":64},{"kind":"integer","sign":"unsigned","width":32},{"kind":"integer","sign":"unsigned","width":16}]},"visibility":"public"},"error_types":{"5019202896831570965":{"error_kind":"string","string":"attempt to add with overflow"},"7233212735005103307":{"error_kind":"string","string":"attempt to multiply with overflow"}}},"bytecode":"H4sIAAAAAAAA/81WTc7CIBC1pfp9/saYuHDnEaCAhZ0uPIiNde8tPIB30JO40IWX8CB2dJpMSHUjJH3JhOZNeYVXBohabwzLmOBzVAbDFrB0OIixw8VlzB2O1fRNarh2Td8OfoOCkfEA+G8QPrUY0ZJ8oVSRpYWQYsNTmxvNlc4XRhihjd6mRsrCKJPZ3GbcCiULsdNW7lCs529cnHpKdUN5mnjU6nn0tB/IU6obefYyxtpwdX/1YhDIC9CdB/Ki2hN8ezEM5AXouvtX7NmTv4bWR8g5/ze0DkLOudvg9Z6gVpvMv7oPANchfjB8vzpzYP3C/+wSnRZygOljNdtfT2uSep1hn3IDbG/H++V82GxpbvQlB3gCOAMWOQgJAAA=","debug_symbols":"ldFNCoMwEAXgu8zahYlttV6lFIkaJRCSkJ9CCd69iTRFglCymzePbzbjYaajWwcmFmmgf3jgciKWSRGSh3ZfGUVETMYSbaFHNa6AijlO7VbBwjiF/tZtzwq6UnD/AxDKAKqLBToVqPmJLhe4WDSnAtdJ4GsuLmUihFEzztk6HF8U1i+iGRk5/cbFienQ2rdKTfJKy4nOTtN4ae/C+Q8=","file_map":{"68":{"source":"fn main(x: u64,w: u16,u: u32, y: pub u64, z: pub u64, s: u32, t: u16) -> pub (u64, u32, u16) {\n (x + y + z, u * s, w * t)\n}\n","path":"/home/fabsits/co-snarks/co-noir/co-noir/examples/test_vectors/ranges/src/main.nr"}},"names":["main"],"brillig_names":[]}
3 changes: 3 additions & 0 deletions co-noir/co-noir/examples/test_vectors/ranges/src/main.nr
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
// Test circuit for batched range decomposition: mixes private (x, w, u, s, t) and
// public (y, z) unsigned inputs of different bit widths (u16/u32/u64). The checked
// additions/multiplications implicitly introduce range constraints of several sizes.
fn main(x: u64, w: u16, u: u32, y: pub u64, z: pub u64, s: u32, t: u16) -> pub (u64, u32, u16) {
    (x + y + z, u * s, w * t)
}
Loading
Loading