diff --git a/plonk-hashing/Cargo.toml b/plonk-hashing/Cargo.toml index d3532515..c8fcb942 100644 --- a/plonk-hashing/Cargo.toml +++ b/plonk-hashing/Cargo.toml @@ -22,3 +22,17 @@ std = [] [dependencies] plonk-core = { path = "../plonk-core" } +ark-ec = { version = "0.3", features = ["std"] } +ark-ff = { version = "0.3", features = ["std"] } +ark-serialize = { version = "0.3", features = ["derive"] } +ark-poly = "0.3" +ark-poly-commit = "0.3" +ark-crypto-primitives = { version = "^0.3.0", features = ["r1cs"], default-features = false } +ark-std = { version = "^0.3.0", default-features = false } +itertools = { version = "0.10.1", default-features = false } +num-traits = "0.2.14" +derivative = { version = "2.2.0", default-features = false, features = ["use_core"] } +hashbrown = { version = "0.11.2", default-features = false, features = ["ahash"] } +ark-relations = "0.3.0" +ark-r1cs-std = "0.3.1" +thiserror = "1.0.30" diff --git a/plonk-hashing/src/lib.rs b/plonk-hashing/src/lib.rs index 3dabf609..b30d1a61 100644 --- a/plonk-hashing/src/lib.rs +++ b/plonk-hashing/src/lib.rs @@ -6,7 +6,11 @@ //! PLONK Hashing Library -#![cfg_attr(not(any(feature = "std", test)), no_std)] +// #![cfg_attr(not(any(feature = "std", test)), no_std)] #![cfg_attr(doc_cfg, feature(doc_cfg))] #![forbid(rustdoc::broken_intra_doc_links)] -#![forbid(missing_docs)] +// #![forbid(missing_docs)] + +pub extern crate alloc; + +pub mod poseidon; diff --git a/plonk-hashing/src/poseidon/constants.rs b/plonk-hashing/src/poseidon/constants.rs new file mode 100644 index 00000000..5ae428f6 --- /dev/null +++ b/plonk-hashing/src/poseidon/constants.rs @@ -0,0 +1,79 @@ +use crate::poseidon::{ + matrix::Matrix, + mds::{factor_to_sparse_matrixes, MdsMatrices, SparseMatrix}, + preprocessing::compress_round_constants, + round_constant::generate_constants, + round_numbers::calc_round_numbers, +}; +use ark_ff::PrimeField; + +#[derive(Clone, Debug, PartialEq)] +pub struct PoseidonConstants { + pub mds_matrices: MdsMatrices, + pub round_constants: Vec, + pub compressed_round_constants: Vec, + pub pre_sparse_matrix: Matrix, + pub sparse_matrixes: Vec>, + pub domain_tag: F, + pub full_rounds: usize, + pub half_full_rounds: usize, + pub partial_rounds: usize, +} + +impl PoseidonConstants { + /// Generate all constants needed for poseidon hash of specified + /// width. Note that WIDTH = ARITY + 1 + pub fn generate() -> Self { + let arity = WIDTH - 1; + let mds_matrices = MdsMatrices::new(WIDTH); + let (num_full_rounds, num_partial_rounds) = + calc_round_numbers(WIDTH, true); + + debug_assert_eq!(num_full_rounds % 2, 0); + let num_half_full_rounds = num_full_rounds / 2; + let round_constants = generate_constants( + 1, // prime field + 1, // sbox + F::size_in_bits() as u16, + WIDTH.try_into().expect("WIDTH is too large"), + num_full_rounds + .try_into() + .expect("num_full_rounds is too large"), + num_partial_rounds + .try_into() + .expect("num_partial_rounds is too large"), + ); + let domain_tag = F::from(((1 << arity) - 1) as u64); + + let compressed_round_constants = compress_round_constants( + WIDTH, + num_full_rounds, + num_partial_rounds, + &round_constants, + &mds_matrices, + ); + + let (pre_sparse_matrix, sparse_matrixes) = factor_to_sparse_matrixes( + mds_matrices.m.clone(), + num_partial_rounds, + ); + + assert!( + WIDTH * (num_full_rounds + num_partial_rounds) + <= round_constants.len(), + "Not enough round constants" + ); + + PoseidonConstants { + mds_matrices, + round_constants, + domain_tag, + full_rounds: num_full_rounds, + half_full_rounds: num_half_full_rounds, + partial_rounds: num_partial_rounds, + compressed_round_constants, + pre_sparse_matrix, + sparse_matrixes, + } + } +} diff --git a/plonk-hashing/src/poseidon/constraints.rs b/plonk-hashing/src/poseidon/constraints.rs new file mode 100644 index 00000000..43decc9f --- /dev/null +++ b/plonk-hashing/src/poseidon/constraints.rs @@ -0,0 +1,473 @@ +//! Library independent specification for field trait. +//! `COM` is `()` when the field is in native, and is constraint synthesizer +//! when the field is a variable. +//! +//! This file is adapted from manta-rs, but is tweaked to better support PLONK proving system. +//! Some essential features (like allocating unknown variable) are missing. +//! +//! The initial goal is to simultaneous support PLONK, R1CS-std, and native. + +/// Compiler constant. +pub trait Constant +where COM: ?Sized +{ + type Type; + fn new_constant(c: &mut COM, value: &Self::Type) -> Self; +} +/// Constant Type Alias +pub type Const = >::Type; + +/// Compiler variable. +pub trait Variable +where COM: ?Sized { + type Type; + // we can only make it unknown (the witness), because PLONK does not support arbitrary public input. + fn new_variable(c: &mut COM, value: &Self::Type) -> Self; +} +/// Variable Type Alias +pub type Var = >::Type; + +/// Value-source auto-trait +pub trait ValueSource +where COM: ?Sized +{ + /// Allocates `self` as a constant in `compiler`. + fn as_constant(&self, compiler: &mut COM) -> C + where + C: Constant, + { + C::new_constant(compiler, self) + } + + /// Allocates `self` as a known value in `compiler`. + fn as_variable(&self, compiler: &mut COM) -> V + where + V: Variable, + { + V::new_variable(compiler, self) + } +} + +impl ValueSource for T where T: ?Sized {} + +/// Allocator Auto-Trait +pub trait Allocator { + /// Allocates a constant with the given `value` into `self`. + #[inline] + fn allocate_constant(&mut self, value: &C::Type) -> C + where + C: Constant, + { + C::new_constant(self, value) + } + + /// Allocates a known variable with the given `value` into `self`. + #[inline] + fn allocate_variable(&mut self, value: &V::Type) -> V + where + V: Variable, + { + V::new_variable(self, value) + } +} + +impl Allocator for COM where COM: ?Sized {} + +/// Addition trait for variable +pub trait COMAdd: Variable + where + COM: ?Sized, +{ + type Constant: Constant; + /// Adds `self` and `rhs` inside of `compiler`. + fn com_add(&self, rhs: &Self, compiler: &mut COM) -> Self; +} + +/// Subtraction +pub trait COMSub: Sized + where + COM: ?Sized, +{ + type Output; + /// Subtracts `rhs` from `self` inside of `compiler`. + fn com_sub(&self, rhs: &Self, compiler: &mut COM) -> Self; +} + +/// Constraint System trait for compiler +pub trait ConstraintSystem { + /// Boolean Variable type + type Bool; + + /// Asserts that `b == 1`. + fn assert(&mut self, b: Self::Bool); + + /// Asserts that all the booleans in `iter` are equal to `1`. + #[inline] + fn assert_all(&mut self, iter: I) + where + I: IntoIterator, + { + iter.into_iter().for_each(move |b| self.assert(b)); + } + + /// Generates a boolean that represents the fact that `lhs` and `rhs` may be equal. + #[inline] + fn eq(&mut self, lhs: &V, rhs: &V) -> Self::Bool + where + V: Equal, + { + V::eq(lhs, rhs, self) + } + + /// Asserts that `lhs` and `rhs` are equal. + #[inline] + fn assert_eq(&mut self, lhs: &V, rhs: &V) + where + V: Equal, + { + V::assert_eq(lhs, rhs, self); + } +} + +/// Equality Trait +pub trait Equal + where + COM: ConstraintSystem + ?Sized, +{ + /// Generates a boolean that represents the fact that `lhs` and `rhs` may be equal. + fn eq(lhs: &Self, rhs: &Self, compiler: &mut COM) -> COM::Bool; + + /// Asserts that `lhs` and `rhs` are equal. + #[inline] + fn assert_eq(lhs: &Self, rhs: &Self, compiler: &mut COM) { + let boolean = Self::eq(lhs, rhs, compiler); + compiler.assert(boolean); + } + + /// Asserts that all the elements in `iter` are equal to some `base` element. + #[inline] + fn assert_all_eq_to_base<'t, I>(base: &'t Self, iter: I, compiler: &mut COM) + where + I: IntoIterator, + { + for item in iter { + Self::assert_eq(base, item, compiler); + } + } + + /// Asserts that all the elements in `iter` are equal. + #[inline] + fn assert_all_eq<'t, I>(iter: I, compiler: &mut COM) + where + Self: 't, + I: IntoIterator, + { + let mut iter = iter.into_iter(); + if let Some(base) = iter.next() { + Self::assert_all_eq_to_base(base, iter, compiler); + } + } + +} + + + + +// /// Basic Arithmetic Operations. Both constants and variables should implement this trait. +// pub trait COMArith: Sized + Clone + Debug { +// // I added `com_` prefix here to avoid conflict with num_traits. Any suggestion +// // is welcome! +// // TODO: do we want to split `add`, `neg`, `mul` to different traits? +// /// additive identity +// fn com_zero(c: &mut COM) -> Self; +// fn zeros(c: &mut COM) -> [Self; SIZE]; +// /// add two field elements +// fn com_add(&self, c: &mut COM, b: &Self) -> Self; +// /// the additive inverse of a field element +// fn com_neg(&self, c: &mut COM) -> Self; +// /// multiply two field elements +// fn com_mul(&self, c: &mut COM, other: &Self) -> Self; +// fn com_square(&self, c: &mut COM) -> Self { +// self.com_mul(c, self) +// } +// fn com_add_assign(&mut self, c: &mut COM, other: &Self) { +// *self = self.com_add(c, other); +// } +// fn com_mul_assign(&mut self, c: &mut COM, other: &Self) { +// *self = self.com_mul(c, other); +// } +// } +// +// pub trait COMFromInt: COMArith { +// /// multiplicative identity +// fn com_one(c: &mut COM) -> Self; +// fn com_from_const_int(c: &mut COM, v: u64) -> Self; +// } +// +// pub trait COMPower: COMArith { +// type Scalar; +// fn com_pow(&self, c: &mut COM, exp: &Self::Scalar) -> Self; +// } +// +// pub trait COMArithExt: COMArith + Sized { +// fn __make_arith_gate(c: &mut COM, config: ArithExtBuilder) -> Self; +// /// `(w_l * w_r) * q_m + a * q_l + b * q_r + w_4 * q_4 + q_c + PI + q_o * c = 0` +// /// where output is `c` +// fn com_arith(c: &mut COM) -> ArithExtBuilder { +// ArithExtBuilder::new(c) +// } +// } +// +// pub type NativeField = >::Native; +// +// /// `(w_l * w_r) * q_m + a * q_l + b * q_r + w_4 * q_4 + q_c + PI + q_o * c = 0` +// /// where output is `c` +// pub struct ArithExtBuilder, COM = ()> { +// w_l: F, +// w_r: F, +// q_m: F::Native, +// q_l: F::Native, +// q_r: F::Native, +// q_c: F::Native, +// q_o: F::Native, +// q_4_w_4: Option<(F::Native, F)>, +// pi: Option, +// _compiler: PhantomData, +// } +// +// impl, COM> ArithExtBuilder { +// pub(crate) fn new(c: &mut COM) -> Self { +// Self { +// w_l: F::com_zero(c), +// w_r: F::com_zero(c), +// q_m: F::Native::com_zero(&mut ()), +// q_l: F::Native::com_zero(&mut ()), +// q_r: F::Native::com_zero(&mut ()), +// q_c: F::Native::com_zero(&mut ()), +// q_o: F::Native::com_one(&mut ()).com_neg(&mut ()), +// q_4_w_4: None, +// pi: None, +// _compiler: PhantomData, +// } +// } +// +// pub fn w_l(mut self, w_l: F) -> Self { +// self.w_l = w_l; +// self +// } +// +// pub fn w_r(mut self, w_r: F) -> Self { +// self.w_r = w_r; +// self +// } +// +// pub fn witness(mut self, w_l: F, w_r: F) -> Self { +// self.w_l = w_l; +// self.w_r = w_r; +// self +// } +// +// pub fn q_m(mut self, q_m: F::Native) -> Self { +// self.q_m = q_m; +// self +// } +// +// pub fn mul(mut self) -> Self { +// self.q_m(F::Native::com_one(&mut ())) +// } +// +// pub fn q_l(mut self, q_l: F::Native) -> Self { +// self.q_l = q_l; +// self +// } +// +// pub fn q_r(mut self, q_r: F::Native) -> Self { +// self.q_r = q_r; +// self +// } +// +// pub fn q_c(mut self, q_c: F::Native) -> Self { +// self.q_c = q_c; +// self +// } +// +// pub fn q4w4(mut self, q_4_w_4: (F::Native, F)) -> Self { +// self.q_4_w_4 = Some(q_4_w_4); +// self +// } +// +// pub fn pi(mut self, pi: F::PublicInput) -> Self { +// self.pi = Some(pi); +// self +// } +// +// pub fn q_o(mut self, q_o: F::Native) -> Self { +// self.q_o = q_o; +// self +// } +// +// pub fn build(self, c: &mut COM) -> F { +// F::__make_arith_gate(c, self) +// } +// } +// +// impl COMArith<()> for F { +// type Native = F; +// type PublicInput = F; +// fn com_zero(_c: &mut ()) -> Self { +// F::zero() +// } +// +// fn com_alloc(_c: &mut (), v: Self::Native) -> Self { +// v +// } +// +// fn zeros(_c: &mut ()) -> [Self; SIZE] { +// [F::zero(); SIZE] +// } +// +// fn com_add(&self, _c: &mut (), b: &Self) -> Self { +// *self + *b +// } +// +// fn com_addi(&self, _c: &mut (), b: &Self::Native) -> Self { +// *self + *b +// } +// +// fn com_neg(&self, _c: &mut ()) -> Self { +// -*self +// } +// +// fn com_mul(&self, _c: &mut (), other: &Self) -> Self { +// *self * *other +// } +// +// fn com_muli(&self, _c: &mut (), other: &Self::Native) -> Self { +// *self * *other +// } +// } +// +// impl COMFromInt<()> for F { +// fn com_one(_c: &mut ()) -> Self { +// F::one() +// } +// +// fn com_from_const_int(c: &mut (), v: u64) -> Self { +// F::from(v) +// } +// } +// +// impl COMArithExt<()> for F { +// fn __make_arith_gate(_c: &mut (), config: ArithExtBuilder) -> Self { +// let mut result = F::zero(); +// result += (config.w_l * config.w_r) * config.q_m; +// result += config.q_l * config.w_l; +// result += config.q_r * config.w_r; +// result += config.q_4_w_4.map_or(F::zero(), |(q_4, w_4)| q_4 * w_4); +// result += config.q_c; +// result += config.pi.unwrap_or(F::zero()); +// +// // now result = - q_o * c, we want c = (-result) / q_o +// let q_o_inv = F::inverse(&config.q_o).unwrap(); +// (-result) * q_o_inv +// } +// } +// +// impl COMArith> for Variable +// where +// E: PairingEngine, +// P: TEModelParameters, +// { +// type Native = E::Fr; +// type PublicInput = E::Fr; +// +// fn com_zero(c: &mut StandardComposer) -> Self { +// c.zero_var() +// } +// +// fn com_alloc(c: &mut StandardComposer, v: Self::Native) -> Self { +// c.add_input(v) +// } +// +// fn zeros(c: &mut StandardComposer) -> [Self; SIZE] { +// [c.zero_var(); SIZE] +// } +// +// fn com_add(&self, c: &mut StandardComposer, b: &Self) -> Self { +// c.arithmetic_gate(|g| g.witness(*self, *b, None).add(E::Fr::one(), E::Fr::one())) +// } +// +// fn com_addi(&self, c: &mut StandardComposer, b: &Self::Native) -> Self { +// let zero = c.zero_var(); +// c.arithmetic_gate(|g| { +// g.witness(*self, zero, None) +// .add(E::Fr::one(), E::Fr::zero()) +// .constant(*b) +// }) +// } +// +// fn com_neg(&self, c: &mut StandardComposer) -> Self { +// let zero = c.zero_var(); +// c.arithmetic_gate(|g| g.witness(*self, zero, None).out(E::Fr::one())) +// } +// +// fn com_mul(&self, c: &mut StandardComposer, other: &Self) -> Self { +// c.arithmetic_gate(|g| g.witness(*self, *other, None).mul(E::Fr::one())) +// } +// +// fn com_muli(&self, c: &mut StandardComposer, other: &Self::Native) -> Self { +// let zero = c.zero_var(); +// c.arithmetic_gate(|g| g.witness(*self, zero, None).add(*other, E::Fr::zero())) +// } +// } +// +// impl COMArithExt> for Variable +// where +// E: PairingEngine, +// P: TEModelParameters, +// { +// fn __make_arith_gate( +// c: &mut StandardComposer, +// config: ArithExtBuilder>, +// ) -> Self { +// c.arithmetic_gate(|g| { +// g.witness(config.w_l, config.w_r, None).mul(config.q_m); +// g.add(config.q_l, config.q_r); +// if let Some((q_4, w_4)) = config.q_4_w_4 { +// g.fan_in_3(q_4, w_4); +// }; +// g.constant(config.q_c); +// g.out(config.q_o); +// if let Some(pi) = config.pi { +// g.constant(pi); +// } +// g +// }) +// } +// } +// +// #[cfg(test)] +// mod tests { +// use super::*; +// use ark_bls12_381::Fr; +// use ark_std::{test_rng, UniformRand}; +// +// #[test] +// fn sanity_check_on_native() { +// // calculate 3xy + 2x + y + 1 +// let mut rng = test_rng(); +// let x = Fr::rand(&mut rng); +// let y = Fr::rand(&mut rng); +// let expected = (Fr::from(3u64) * x * y) + (Fr::from(2u64) * x) + y + Fr::one(); +// let actual = Fr::com_arith(&mut ()) +// .w_l(x) +// .w_r(y) +// .q_m(3u64.into()) +// .q_l(2u64.into()) +// .q_r(Fr::one()) +// .q_c(Fr::one()) +// .build(&mut ()); +// +// assert_eq!(expected, actual); +// } +// } diff --git a/plonk-hashing/src/poseidon/matrix.rs b/plonk-hashing/src/poseidon/matrix.rs new file mode 100644 index 00000000..34d4a720 --- /dev/null +++ b/plonk-hashing/src/poseidon/matrix.rs @@ -0,0 +1,693 @@ +//! acknowledgement: adapted from FileCoin Project: https://github.com/filecoin-project/neptune/blob/master/src/matrix.rs + +use ark_ff::PrimeField; +use core::ops::{Index, IndexMut}; + +#[derive(Clone, Eq, PartialEq, Debug)] +pub struct Matrix(pub Vec>); + +impl From>> for Matrix { + fn from(v: Vec>) -> Self { + Matrix(v) + } +} + +impl Matrix { + pub fn num_rows(&self) -> usize { + self.0.len() + } + + pub fn num_columns(&self) -> usize { + if self.0.is_empty() { + 0 + } else { + let column_length = self.0[0].len(); + for row in &self.0 { + if row.len() != column_length { + panic!("not a matrix"); + } + } + column_length + } + } + + pub fn iter_rows<'a>(&'a self) -> impl Iterator> { + self.0.iter() + } + + pub fn column(&self, column: usize) -> impl Iterator { + self.0.iter().map(move |row| &row[column]) + } + + pub fn is_square(&self) -> bool { + self.num_rows() == self.num_columns() + } + + pub fn transpose(&self) -> Matrix { + let size = self.num_rows(); + let mut new = Vec::with_capacity(size); + for j in 0..size { + let mut row = Vec::with_capacity(size); + for i in 0..size { + row.push(self.0[i][j].clone()) + } + new.push(row); + } + Matrix(new) + } +} + +impl Index for Matrix { + type Output = Vec; + + fn index(&self, index: usize) -> &Self::Output { + &self.0[index] + } +} + +impl IndexMut for Matrix { + fn index_mut(&mut self, index: usize) -> &mut Self::Output { + &mut self.0[index] + } +} + +// from iterator rows +impl FromIterator> for Matrix { + fn from_iter>>(iter: T) -> Self { + let rows = iter.into_iter().collect::>(); + Self(rows) + } +} + +impl Matrix { + /// return an identity matrix of size `n*n` + pub fn identity(n: usize) -> Matrix { + let mut m = Matrix(vec![vec![F::zero(); n]; n]); + for i in 0..n { + m.0[i][i] = F::one(); + } + m + } + + pub fn is_identity(&self) -> bool { + if !self.is_square() { + return false; + } + for i in 0..self.num_rows() { + for j in 0..self.num_columns() { + if self.0[i][j] != kronecker_delta(i, j) { + return false; + } + } + } + true + } + + /// check if `self` is square and `self[1..][1..]` is identity + pub fn is_sparse(&self) -> bool { + self.is_square() && self.minor(0, 0).is_identity() + } + + pub fn mul_by_scalar(&self, scalar: F) -> Self { + let res = self + .0 + .iter() + .map(|row| { + row.iter() + .map(|val| { + let mut prod = scalar; + prod.mul_assign(val); + prod + }) + .collect::>() + }) + .collect::>(); + Matrix(res) + } + + /// return `self @ vec`, treating `vec` as a column vector. + pub fn mul_col_vec(&self, v: &[F]) -> Vec { + assert!( + self.is_square(), + "Only square matrix can be applied to vector." + ); + assert_eq!( + self.num_rows(), + v.len(), + "Matrix can only be applied to vector of same size." + ); + + let mut result = vec![F::zero(); v.len()]; + + for (result, row) in result.iter_mut().zip(self.0.iter()) { + for (mat_val, vec_val) in row.iter().zip(v) { + let mut tmp = *mat_val; + tmp.mul_assign(vec_val); + result.add_assign(&tmp); + } + } + result + } + + /// return `vec @ self`, treat `vec` as a row vector. + pub fn right_apply(&self, v: &[F]) -> Vec { + self.mul_row_vec_at_left(v) + } + + /// return `self @ vec`, treating `vec` as a column vector. + pub fn left_apply(&self, v: &[F]) -> Vec { + self.mul_col_vec(v) + } + + /// return `vec @ self`, treating `vec` as a row vector. + pub fn mul_row_vec_at_left(&self, v: &[F]) -> Vec { + assert!( + self.is_square(), + "Only square matrix can be applied to vector." + ); + assert_eq!( + self.num_rows(), + v.len(), + "Matrix can only be applied to vector of same size." + ); + + let mut result = vec![F::zero(); v.len()]; + for (j, val) in result.iter_mut().enumerate() { + for (i, row) in self.0.iter().enumerate() { + let mut tmp = row[j]; + tmp.mul_assign(&v[i]); + val.add_assign(&tmp); + } + } + result + } + + /// return `self @ other` + pub fn matmul(&self, other: &Self) -> Option { + if self.num_rows() != other.num_columns() { + return None; + }; + + let other_t = other.transpose(); + + let res = self + .0 + .iter() + .map(|input_row| { + other_t + .iter_rows() + .map(|transposed_column| { + inner_product(&input_row, &transposed_column) + }) + .collect() + }) + .collect(); + Some(Matrix(res)) + } + + pub fn invert(&self) -> Option { + let mut shadow = Self::identity(self.num_columns()); + let ut = self.upper_triangular(&mut shadow); + + ut.and_then(|x| x.reduce_to_identity(&mut shadow)) + .and(Some(shadow)) + } + + pub fn is_invertible(&self) -> bool { + self.is_square() && self.invert().is_some() + } + + pub fn minor(&self, i: usize, j: usize) -> Self { + assert!(self.is_square()); + let size = self.num_rows(); + assert!(size > 0); + let new: Vec> = self + .0 + .iter() + .enumerate() + .filter_map(|(ii, row)| { + if ii == i { + None + } else { + let mut new_row = row.clone(); + new_row.remove(j); + Some(new_row) + } + }) + .collect(); + let res = Matrix(new); + assert!(res.is_square()); + res + } + + /// Assumes matrix is partially reduced to upper triangular. `column` is the + /// column to eliminate from all rows. Returns `None` if either: + /// - no non-zero pivot can be found for `column` + /// - `column` is not the first + pub fn eliminate(&self, column: usize, shadow: &mut Self) -> Option { + let zero = F::zero(); + let pivot_index = (0..self.num_rows()).find(|&i| { + self[i][column] != zero && (0..column).all(|j| self[i][j] == zero) + })?; + + let pivot = &self[pivot_index]; + let pivot_val = pivot[column]; + + // This should never fail since we have a non-zero `pivot_val` if we got + // here. + let inv_pivot = pivot_val.inverse()?; + let mut result = Vec::with_capacity(self.num_rows()); + result.push(pivot.clone()); + + for (i, row) in self.iter_rows().enumerate() { + if i == pivot_index { + continue; + }; + + let val = row[column]; + if val == zero { + result.push(row.to_vec()); + } else { + let factor = val * inv_pivot; + let scaled_pivot = scalar_vec_mul(factor, &pivot); + let eliminated = vec_sub(row, &scaled_pivot); + result.push(eliminated); + + let shadow_pivot = &shadow[pivot_index]; + let scaled_shadow_pivot = scalar_vec_mul(factor, shadow_pivot); + let shadow_row = &shadow[i]; + shadow[i] = vec_sub(shadow_row, &scaled_shadow_pivot); + } + } + + let pivot_row = shadow.0.remove(pivot_index); + shadow.0.insert(0, pivot_row); + + Some(result.into()) + } + + /// Performs row operations to put a matrix in upper triangular form. + /// Each row operation is performed on `shadow` as well to keep track + /// of their cumulative effect. In other words, row operations are + /// performed on the augmented matrix [self | shadow ]. + pub fn upper_triangular(&self, shadow: &mut Self) -> Option { + assert!(self.is_square()); + let mut result = Vec::with_capacity(self.num_rows()); + let mut shadow_result = Vec::with_capacity(self.num_rows()); + + let mut curr = self.clone(); + let mut column = 0; + while curr.num_rows() > 1 { + let initial_rows = curr.num_rows(); + + curr = curr.eliminate(column, shadow)?; + result.push(curr[0].clone()); + shadow_result.push(shadow[0].clone()); + column += 1; + + curr = Matrix(curr.0[1..].to_vec()); + *shadow = Matrix(shadow.0[1..].to_vec()); + assert_eq!(curr.num_rows(), initial_rows - 1); + } + result.push(curr[0].clone()); + shadow_result.push(shadow[0].clone()); + + *shadow = Matrix(shadow_result); + + Some(Matrix(result)) + } + + /// Perform row operations to reduce `self` to the + /// identity matrix. `self` must be upper triangular. + /// All operations are performed also on `shadow` to track + /// their cumulative effect. + pub fn reduce_to_identity(&self, shadow: &mut Self) -> Option { + let size = self.num_rows(); + let mut result: Vec> = Vec::new(); + let mut shadow_result: Vec> = Vec::new(); + + for i in 0..size { + let idx = size - i - 1; + let row = &self.0[idx]; + let shadow_row = &shadow[idx]; + + let val = row[idx]; + let inv = val.inverse()?; + + let mut normalized = scalar_vec_mul(inv, &row); + let mut shadow_normalized = scalar_vec_mul(inv, &shadow_row); + + for j in 0..i { + let idx = size - j - 1; + let val = normalized[idx]; + let subtracted = scalar_vec_mul(val, &result[j]); + let result_subtracted = scalar_vec_mul(val, &shadow_result[j]); + + normalized = vec_sub(&normalized, &subtracted); + shadow_normalized = + vec_sub(&shadow_normalized, &result_subtracted); + } + + result.push(normalized); + shadow_result.push(shadow_normalized); + } + + result.reverse(); + shadow_result.reverse(); + + *shadow = Matrix(shadow_result); + Some(Matrix(result)) + } +} + +pub fn inner_product(a: &[F], b: &[F]) -> F { + a.iter().zip(b).fold(F::zero(), |mut acc, (v1, v2)| { + let mut tmp = *v1; + tmp.mul_assign(v2); + acc.add_assign(&tmp); + acc + }) +} + +pub fn vec_add(a: &[F], b: &[F]) -> Vec { + a.iter() + .zip(b.iter()) + .map(|(a, b)| { + let mut res = *a; + res.add_assign(b); + res + }) + .collect::>() +} + +pub fn vec_sub(a: &[F], b: &[F]) -> Vec { + a.iter() + .zip(b.iter()) + .map(|(a, b)| { + let mut res = *a; + res.sub_assign(b); + res + }) + .collect::>() +} + +fn scalar_vec_mul(scalar: F, v: &[F]) -> Vec { + v.iter() + .map(|val| { + let mut prod = scalar; + prod.mul_assign(val); + prod + }) + .collect::>() +} + +pub fn kronecker_delta(i: usize, j: usize) -> F { + if i == j { + F::one() + } else { + F::zero() + } +} + +#[cfg(test)] +mod tests { + use super::*; + use ark_ff::Zero; + + type Fr = ark_bls12_381::Fr; + + #[test] + fn test_minor() { + let one = Fr::from(1u64); + let two = Fr::from(2u64); + let three = Fr::from(3u64); + let four = Fr::from(4u64); + let five = Fr::from(5u64); + let six = Fr::from(6u64); + let seven = Fr::from(7u64); + let eight = Fr::from(8u64); + let nine = Fr::from(9u64); + + let m: Matrix<_> = vec![ + vec![one, two, three], + vec![four, five, six], + vec![seven, eight, nine], + ] + .into(); + + let cases = [ + (0, 0, Matrix(vec![vec![five, six], vec![eight, nine]])), + (0, 1, Matrix(vec![vec![four, six], vec![seven, nine]])), + (0, 2, Matrix(vec![vec![four, five], vec![seven, eight]])), + (1, 0, Matrix(vec![vec![two, three], vec![eight, nine]])), + (1, 1, Matrix(vec![vec![one, three], vec![seven, nine]])), + (1, 2, Matrix(vec![vec![one, two], vec![seven, eight]])), + (2, 0, Matrix(vec![vec![two, three], vec![five, six]])), + (2, 1, Matrix(vec![vec![one, three], vec![four, six]])), + (2, 2, Matrix(vec![vec![one, two], vec![four, five]])), + ]; + for (i, j, expected) in &cases { + let result = m.minor(*i, *j); + + assert_eq!(*expected, result); + } + } + + #[test] + fn test_scalar_mul() { + let zero = Fr::from(0u64); + let one = Fr::from(1u64); + let two = Fr::from(2u64); + let three = Fr::from(3u64); + let four = Fr::from(4u64); + let six = Fr::from(6u64); + + let m = Matrix(vec![vec![zero, one], vec![two, three]]); + let res = m.mul_by_scalar(two); + + let expected = Matrix(vec![vec![zero, two], vec![four, six]]); + + assert_eq!(expected, res); + } + + #[test] + fn test_vec_mul() { + let one = Fr::from(1u64); + let two = Fr::from(2u64); + let three = Fr::from(3u64); + let four = Fr::from(4u64); + let five = Fr::from(5u64); + let six = Fr::from(6u64); + + let a = vec![one, two, three]; + let b = vec![four, five, six]; + let res = inner_product(&a, &b); + + let expected = Fr::from(32u64); + + assert_eq!(expected, res); + } + + #[test] + fn test_transpose() { + let one = Fr::from(1u64); + let two = Fr::from(2u64); + let three = Fr::from(3u64); + let four = Fr::from(4u64); + let five = Fr::from(5u64); + let six = Fr::from(6u64); + let seven = Fr::from(7u64); + let eight = Fr::from(8u64); + let nine = Fr::from(9u64); + + let m: Matrix<_> = vec![ + vec![one, two, three], + vec![four, five, six], + vec![seven, eight, nine], + ] + .into(); + + let expected: Matrix<_> = vec![ + vec![one, four, seven], + vec![two, five, eight], + vec![three, six, nine], + ] + .into(); + + let res = m.transpose(); + assert_eq!(expected, res); + } + + #[test] + fn test_inverse() { + let zero = Fr::from(0u64); + let one = Fr::from(1u64); + let two = Fr::from(2u64); + let three = Fr::from(3u64); + let four = Fr::from(4u64); + let five = Fr::from(5u64); + let six = Fr::from(6u64); + let seven = Fr::from(7u64); + let eight = Fr::from(8u64); + let nine = Fr::from(9u64); + + let m = Matrix(vec![ + vec![one, two, three], + vec![four, three, six], + vec![five, eight, seven], + ]); + + let m1 = Matrix(vec![ + vec![one, two, three], + vec![four, five, six], + vec![seven, eight, nine], + ]); + + assert!(!m1.is_invertible()); + assert!(m.is_invertible()); + + let m_inv = m.invert().unwrap(); + + let computed_identity = m.matmul(&m_inv).unwrap(); + assert!(computed_identity.is_identity()); + + // S + let some_vec = vec![six, five, four]; + + // M^-1(S) + let inverse_applied = m_inv.right_apply(&some_vec); + + // M(M^-1(S)) + let m_applied_after_inverse = m.right_apply(&inverse_applied); + + // S = M(M^-1(S)) + assert_eq!( + some_vec, m_applied_after_inverse, + "M(M^-1(V))) = V did not hold" + ); + + // panic!(); + // B + let base_vec = vec![eight, two, five]; + + // S + M(B) + let add_after_apply = vec_add(&some_vec, &m.right_apply(&base_vec)); + + // M(B + M^-1(S)) + let apply_after_add = + m.right_apply(&vec_add(&base_vec, &inverse_applied)); + + // S + M(B) = M(B + M^-1(S)) + assert_eq!(add_after_apply, apply_after_add, "breakin' the law"); + + let m = Matrix(vec![vec![zero, one], vec![one, zero]]); + let m_inv = m.invert().unwrap(); + let computed_identity = m.matmul(&m_inv).unwrap(); + assert!(computed_identity.is_identity()); + let computed_identity = m_inv.matmul(&m).unwrap(); + assert!(computed_identity.is_identity()); + } + + #[test] + fn test_eliminate() { + let two = Fr::from(2u64); + let three = Fr::from(3u64); + let four = Fr::from(4u64); + let five = Fr::from(5u64); + let six = Fr::from(6u64); + let seven = Fr::from(7u64); + let eight = Fr::from(8u64); + let m = Matrix(vec![ + vec![two, three, four], + vec![four, five, six], + vec![seven, eight, eight], + ]); + + for i in 0..m.num_rows() { + let mut shadow = Matrix::identity(m.num_columns()); + let res = m.eliminate(i, &mut shadow); + if i > 0 { + assert!(res.is_none()); + continue; + } else { + assert!(res.is_some()); + } + + assert_eq!( + 1, + res.unwrap() + .iter_rows() + .filter(|&row| row[i] != Fr::zero()) + .count() + ); + } + } + + #[test] + fn test_upper_triangular() { + let zero = Fr::zero(); + let two = Fr::from(2u64); + let three = Fr::from(3u64); + let four = Fr::from(4u64); + let five = Fr::from(5u64); + let six = Fr::from(6u64); + let seven = Fr::from(7u64); + let eight = Fr::from(8u64); + + let m = Matrix(vec![ + vec![two, three, four], + vec![four, five, six], + vec![seven, eight, eight], + ]); + + // let expected = Matrix(vec![ + // vec![two, three, four], + // vec![zero, five, six], + // vec![zero, zero, eight], + // ]); + + // let mut shadow = make_identity(columns(&m)); + // let _res = upper_triangular(&m, &mut shadow); + let mut shadow = Matrix::identity(m.num_columns()); + let res = m.upper_triangular(&mut shadow).unwrap(); + + // Actually assert things. + assert!(res[0][0] != zero); + assert!(res[0][1] != zero); + assert!(res[0][2] != zero); + assert!(res[1][0] == zero); + assert!(res[1][1] != zero); + assert!(res[1][2] != zero); + assert!(res[2][0] == zero); + assert!(res[2][1] == zero); + assert!(res[2][2] != zero); + } + + #[test] + fn test_reduce_to_identity() { + let two = Fr::from(2u64); + let three = Fr::from(3u64); + let four = Fr::from(4u64); + let five = Fr::from(5u64); + let six = Fr::from(6u64); + let seven = Fr::from(7u64); + let eight = Fr::from(8u64); + + let m = Matrix(vec![ + vec![two, three, four], + vec![four, five, six], + vec![seven, eight, eight], + ]); + + let mut shadow = Matrix::identity(m.num_columns()); + let ut = m.upper_triangular(&mut shadow); + + let res = ut.and_then(|x| x.reduce_to_identity(&mut shadow)).unwrap(); + + assert!(res.is_identity()); + + let prod = m.matmul(&shadow).unwrap(); + + assert!(prod.is_identity()); + } +} diff --git a/plonk-hashing/src/poseidon/mds.rs b/plonk-hashing/src/poseidon/mds.rs new file mode 100644 index 00000000..ba3e6fe4 --- /dev/null +++ b/plonk-hashing/src/poseidon/mds.rs @@ -0,0 +1,257 @@ +// adapted from https://github.com/filecoin-project/neptune/blob/master/src/mds.rs +use crate::poseidon::matrix::Matrix; +use ark_ff::vec::*; +use ark_ff::PrimeField; + +#[derive(Clone, Debug, PartialEq)] +pub struct MdsMatrices { + pub m: Matrix, + pub m_inv: Matrix, + pub m_hat: Matrix, + pub m_hat_inv: Matrix, + pub m_prime: Matrix, + pub m_double_prime: Matrix, +} + +impl MdsMatrices { + /// Derive MDS matrix of size `dim*dim` and relevant things + pub fn new(dim: usize) -> Self { + let m = Self::generate_mds(dim); + Self::derive_mds_matrices(m) + } + + /// Given an MDS matrix `m`, compute all its associated matrices. + pub(crate) fn derive_mds_matrices(m: Matrix) -> Self { + let m_inv = m.invert().expect("Derived MDS matrix is not invertible"); + let m_hat = m.minor(0, 0); + let m_hat_inv = + m_hat.invert().expect("Derived MDS matrix is not correct"); + let m_prime = Self::make_prime(&m); + let m_double_prime = Self::make_double_prime(&m, &m_hat_inv); + MdsMatrices { + m, + m_inv, + m_hat, + m_hat_inv, + m_prime, + m_double_prime, + } + } + + fn generate_mds(t: usize) -> Matrix { + let xs: Vec = (0..t as u64).map(F::from).collect(); + let ys: Vec = (t as u64..2 * t as u64).map(F::from).collect(); + + let matrix = xs + .iter() + .map(|xs_item| { + ys.iter() + .map(|ys_item| { + // Generate the entry at (i,j) + let mut tmp = *xs_item; + tmp.add_assign(ys_item); + tmp.inverse().unwrap() + }) + .collect() + }) + .collect::>(); + + debug_assert!(matrix.is_invertible()); + debug_assert_eq!(matrix, matrix.transpose()); + matrix + } + + /// Returns a matrix associated to `m` in the optimization of + /// MDS matrices. + fn make_prime(m: &Matrix) -> Matrix { + m.iter_rows() + .enumerate() + .map(|(i, row)| match i { + 0 => { + let mut new_row = vec![F::zero(); row.len()]; + new_row[0] = F::one(); + new_row + } + _ => { + let mut new_row = vec![F::zero(); row.len()]; + new_row[1..].copy_from_slice(&row[1..]); + new_row + } + }) + .collect() + } + + /// Returns a matrix associated to `m` in the optimization of + /// MDS matrices. + fn make_double_prime(m: &Matrix, m_hat_inv: &Matrix) -> Matrix { + let (v, w) = Self::make_v_w(m); + let w_hat = m_hat_inv.right_apply(&w); + + m.iter_rows() + .enumerate() + .map(|(i, row)| match i { + 0 => { + let mut new_row = Vec::with_capacity(row.len()); + new_row.push(row[0]); + new_row.extend(&v); + new_row + } + _ => { + let mut new_row = vec![F::zero(); row.len()]; + new_row[0] = w_hat[i - 1]; + new_row[i] = F::one(); + new_row + } + }) + .collect() + } + + /// Returns two vectors associated to `m` in the optimization of + /// MDS matrices. + fn make_v_w(m: &Matrix) -> (Vec, Vec) { + let v = m[0][1..].to_vec(); + let w = m.iter_rows().skip(1).map(|column| column[0]).collect(); + (v, w) + } +} + +/// A `SparseMatrix` is specifically one of the form of M''. +/// This means its first row and column are each dense, and the interior matrix +/// (minor to the element in both the row and column) is the identity. +/// We will pluralize this compact structure `sparse_matrixes` to distinguish +/// from `sparse_matrices` from which they are created. +#[derive(Debug, Clone, PartialEq)] +pub struct SparseMatrix { + /// `w_hat` is the first column of the M'' matrix. It will be directly + /// multiplied (scalar product) with a row of state elements. + pub w_hat: Vec, + /// `v_rest` contains all but the first (already included in `w_hat`). + pub v_rest: Vec, +} + +impl SparseMatrix { + pub fn new(m_double_prime: &Matrix) -> Self { + assert!(m_double_prime.is_sparse()); + + let w_hat = m_double_prime.iter_rows().map(|r| r[0]).collect(); + let v_rest = m_double_prime[0][1..].to_vec(); + Self { w_hat, v_rest } + } + + pub fn size(&self) -> usize { + self.w_hat.len() + } + + pub fn to_matrix(&self) -> Matrix { + let mut m = Matrix::identity(self.size()); + for (j, elt) in self.w_hat.iter().enumerate() { + m[j][0] = *elt; + } + for (i, elt) in self.v_rest.iter().enumerate() { + m[0][i + 1] = *elt; + } + m + } +} + +// TODO: naming is from https://github.com/filecoin-project/neptune/blob/master/src/mds.rs +// TODO: which is little difficult to follow... We need to change it + +pub fn factor_to_sparse_matrixes( + base_matrix: Matrix, + n: usize, +) -> (Matrix, Vec>) { + let (pre_sparse, mut sparse_matrices) = + (0..n).fold((base_matrix.clone(), Vec::new()), |(curr, mut acc), _| { + let derived = MdsMatrices::derive_mds_matrices(curr); + acc.push(derived.m_double_prime); + let new = base_matrix.matmul(&derived.m_prime).unwrap(); + (new, acc) + }); + sparse_matrices.reverse(); + let sparse_matrixes = sparse_matrices + .iter() + .map(|m| SparseMatrix::::new(m)) + .collect::>(); + + (pre_sparse, sparse_matrixes) +} + +#[cfg(test)] +mod tests { + use crate::poseidon::mds::MdsMatrices; + use ark_bls12_381::Fr; + use ark_std::{test_rng, UniformRand}; + + #[test] + fn test_mds_matrices_creation() { + for i in 2..5 { + test_mds_matrices_creation_aux(i); + } + } + + fn test_mds_matrices_creation_aux(width: usize) { + let MdsMatrices { + m, + m_inv, + m_hat, + m_hat_inv: _, + m_prime, + m_double_prime, + } = MdsMatrices::::new(width); + + for i in 0..m_hat.num_rows() { + for j in 0..m_hat.num_columns() { + assert_eq!( + m[i + 1][j + 1], + m_hat[i][j], + "MDS minor has wrong value." + ); + } + } + + // M^-1 x M = I + assert!(m_inv.matmul(&m).unwrap().is_identity()); + + // M' x M'' = M + assert_eq!(m, m_prime.matmul(&m_double_prime).unwrap()); + } + + #[test] + fn test_swapping() { + test_swapping_aux(3) + } + + fn test_swapping_aux(width: usize) { + let mut rng = test_rng(); + let mds = MdsMatrices::::new(width); + + let base = (0..width).map(|_| Fr::rand(&mut rng)).collect::>(); + let x = { + let mut x = base.clone(); + x[0] = Fr::rand(&mut rng); + x + }; + let y = { + let mut y = base.clone(); + y[0] = Fr::rand(&mut rng); + y + }; + + let qx = mds.m_prime.right_apply(&x); + let qy = mds.m_prime.right_apply(&y); + assert_eq!(qx[0], x[0]); + assert_eq!(qy[0], y[0]); + assert_eq!(qx[1..], qy[1..]); + + let mx = mds.m.left_apply(&x); + let m1_m2_x = + mds.m_prime.left_apply(&mds.m_double_prime.left_apply(&x)); + assert_eq!(mx, m1_m2_x); + + let xm = mds.m.right_apply(&x); + let x_m1_m2 = + mds.m_double_prime.right_apply(&mds.m_prime.right_apply(&x)); + assert_eq!(xm, x_m1_m2); + } +} diff --git a/plonk-hashing/src/poseidon/mod.rs b/plonk-hashing/src/poseidon/mod.rs new file mode 100644 index 00000000..b95c9751 --- /dev/null +++ b/plonk-hashing/src/poseidon/mod.rs @@ -0,0 +1,17 @@ +pub mod constants; +pub mod matrix; +pub mod mds; +pub mod poseidon; +pub mod poseidon_ref; +pub mod preprocessing; +pub mod round_constant; +pub mod round_numbers; +// pub mod constraints; + +use thiserror::Error; + +#[derive(Error, Debug)] +pub enum PoseidonError { + #[error("Buffer is full")] + FullBuffer, +} diff --git a/plonk-hashing/src/poseidon/poseidon.rs b/plonk-hashing/src/poseidon/poseidon.rs new file mode 100644 index 00000000..18e6725f --- /dev/null +++ b/plonk-hashing/src/poseidon/poseidon.rs @@ -0,0 +1,582 @@ +//! optimized poseidon + +use crate::poseidon::constants::PoseidonConstants; +use crate::poseidon::matrix::Matrix; +use crate::poseidon::mds::SparseMatrix; +use crate::poseidon::PoseidonError; +use ark_ec::TEModelParameters; +use ark_ff::PrimeField; +use core::{fmt::Debug, marker::PhantomData}; +use derivative::Derivative; +use plonk_core::constraint_system::{StandardComposer, Variable}; +use plonk_core::prelude as plonk; + +// TODO: reduce duplicate code with `poseidon_ref` +pub trait PoseidonSpec { + type Field: Debug + Clone; + type ParameterField: PrimeField; + + fn output_hash( + c: &mut COM, + elements: &[Self::Field; WIDTH], + constants: &PoseidonConstants, + ) -> Self::Field { + // Counters + let mut constants_offset = 0usize; + let mut current_round = 0usize; + // State vector to modify + let mut state = elements.clone(); + Self::add_round_constants( + c, + &mut state, + constants, + &mut constants_offset, + ); + + for _ in 0..constants.half_full_rounds { + Self::full_round( + c, + constants, + &mut current_round, + &mut constants_offset, + false, + &mut state, + ) + } + + for _ in 0..constants.partial_rounds { + Self::partial_round( + c, + constants, + &mut current_round, + &mut constants_offset, + &mut state, + ); + } + + // All but last full round + for _ in 1..constants.half_full_rounds { + Self::full_round( + c, + constants, + &mut current_round, + &mut constants_offset, + false, + &mut state, + ); + } + Self::full_round( + c, + constants, + &mut current_round, + &mut constants_offset, + true, + &mut state, + ); + + assert_eq!( + constants_offset, + constants.compressed_round_constants.len(), + "Constants consumed ({}) must equal preprocessed constants provided ({}).", + constants_offset, + constants.compressed_round_constants.len() + ); + + state[1].clone() + } + + fn full_round( + c: &mut COM, + constants: &PoseidonConstants, + current_round: &mut usize, + const_offset: &mut usize, + last_round: bool, + state: &mut [Self::Field; WIDTH], + ) { + let to_take = WIDTH; + let post_round_keys = constants + .compressed_round_constants + .iter() + .skip(*const_offset) + .take(to_take); + + if !last_round { + let needed = *const_offset + to_take; + assert!( + needed <= constants.compressed_round_constants.len(), + "Not enough preprocessed round constants ({}), need {}.", + constants.compressed_round_constants.len(), + needed + ); + } + + state.iter_mut().zip(post_round_keys).for_each(|(l, post)| { + // Be explicit that no round key is added after last round of S-boxes. + let post_key = if last_round { + panic!( + "Trying to skip last full round, but there is a key here! ({:?})", + post + ); + } else { + Some(post.clone()) + }; + *l = Self::quintic_s_box(c, l.clone(), None, post_key); + }); + + if last_round { + state.iter_mut().for_each(|l| { + *l = Self::quintic_s_box(c, l.clone(), None, None) + }) + } else { + *const_offset += to_take; + } + Self::round_product_mds(c, constants, current_round, state); + } + + fn partial_round( + c: &mut COM, + constants: &PoseidonConstants, + current_round: &mut usize, + const_offset: &mut usize, + state: &mut [Self::Field; WIDTH], + ) { + let post_round_key = + constants.compressed_round_constants[*const_offset]; + + state[0] = Self::quintic_s_box( + c, + state[0].clone(), + None, + Some(post_round_key), + ); + *const_offset += 1; + + Self::round_product_mds(c, constants, current_round, state); + } + + fn add_round_constants( + c: &mut COM, + state: &mut [Self::Field; WIDTH], + constants: &PoseidonConstants, + const_offset: &mut usize, + ) { + for (element, round_constant) in state.iter_mut().zip( + constants + .compressed_round_constants + .iter() + .skip(*const_offset), + ) { + *element = Self::addi(c, element, round_constant); + } + *const_offset += WIDTH; + } + + fn round_product_mds( + c: &mut COM, + constants: &PoseidonConstants, + current_round: &mut usize, + state: &mut [Self::Field; WIDTH], + ) { + let full_half = constants.half_full_rounds; + let sparse_offset = full_half - 1; + if *current_round == sparse_offset { + Self::product_mds_with_matrix( + c, + state, + &constants.pre_sparse_matrix, + ) + } else { + if (*current_round > sparse_offset) + && (*current_round < full_half + constants.partial_rounds) + { + let index = *current_round - sparse_offset - 1; + let sparse_matrix = &constants.sparse_matrixes[index]; + + Self::product_mds_with_sparse_matrix(c, state, sparse_matrix) + } else { + Self::product_mds(c, constants, state) + } + }; + + *current_round += 1; + } + + fn product_mds( + c: &mut COM, + constants: &PoseidonConstants, + state: &mut [Self::Field; WIDTH], + ) { + Self::product_mds_with_matrix(c, state, &constants.mds_matrices.m) + } + + fn linear_combination( + c: &mut COM, + state: &[Self::Field; WIDTH], + coeff: impl IntoIterator, + ) -> Self::Field { + state.iter().zip(coeff).fold(Self::zero(c), |acc, (x, y)| { + let tmp = Self::muli(c, x, &y); + Self::add(c, &tmp, &acc) + }) + } + + /// compute state @ Mat where `state` is a row vector + fn product_mds_with_matrix( + c: &mut COM, + state: &mut [Self::Field; WIDTH], + matrix: &Matrix, + ) { + let mut result = Self::zeros::(c); + for (col_index, val) in result.iter_mut().enumerate() { + // for (i, row) in matrix.iter_rows().enumerate() { + // // *val += row[j] * state[i]; + // let tmp = Self::muli(c, &state[i], &row[j]); + // *val = Self::add(c, val, &tmp); + // } + *val = Self::linear_combination( + c, + state, + matrix.column(col_index).cloned(), + ); + } + + *state = result; + } + + fn product_mds_with_sparse_matrix( + c: &mut COM, + state: &mut [Self::Field; WIDTH], + matrix: &SparseMatrix, + ) { + let mut result = Self::zeros::(c); + + // First column is dense. + // for (i, val) in matrix.w_hat.iter().enumerate() { + // // result[0] += w_hat[i] * state[i]; + // let tmp = Self::muli(c, &state[i], &val); + // result[0] = Self::add(c, &result[0], &tmp); + // } + result[0] = + Self::linear_combination(c, state, matrix.w_hat.iter().cloned()); + + for (j, val) in result.iter_mut().enumerate().skip(1) { + // for each j, result[j] = state[j] + state[0] * v_rest[j-1] + + // Except for first row/column, diagonals are one. + *val = Self::add(c, val, &state[j]); + // + // // First row is dense. + let tmp = Self::muli(c, &state[0], &matrix.v_rest[j - 1]); + *val = Self::add(c, val, &tmp); + } + *state = result; + } + + /// return (x + pre_add)^5 + post_add + fn quintic_s_box( + c: &mut COM, + x: Self::Field, + pre_add: Option, + post_add: Option, + ) -> Self::Field { + let mut tmp = match pre_add { + Some(a) => Self::addi(c, &x, &a), + None => x.clone(), + }; + tmp = Self::power_of_5(c, &tmp); + match post_add { + Some(a) => Self::addi(c, &tmp, &a), + None => tmp, + } + } + + fn power_of_5(c: &mut COM, x: &Self::Field) -> Self::Field { + let mut tmp = Self::mul(c, x, x); // x^2 + tmp = Self::mul(c, &tmp, &tmp); // x^4 + Self::mul(c, &tmp, x) // x^5 + } + + fn alloc(c: &mut COM, v: Self::ParameterField) -> Self::Field; + fn zeros(c: &mut COM) -> [Self::Field; W]; + fn zero(c: &mut COM) -> Self::Field { + Self::zeros::<1>(c)[0].clone() + } + fn add(c: &mut COM, x: &Self::Field, y: &Self::Field) -> Self::Field; + fn addi( + c: &mut COM, + a: &Self::Field, + b: &Self::ParameterField, + ) -> Self::Field; + fn mul(c: &mut COM, x: &Self::Field, y: &Self::Field) -> Self::Field; + fn muli( + c: &mut COM, + x: &Self::Field, + y: &Self::ParameterField, + ) -> Self::Field; +} + +#[derive(Derivative)] +#[derivative(Debug(bound = ""))] +pub struct Poseidon, const WIDTH: usize> +where + S: ?Sized, +{ + pub(crate) constants: PoseidonConstants, +} + +impl, const WIDTH: usize> + Poseidon +where + S: ?Sized, +{ + pub fn new(constants: PoseidonConstants) -> Self { + Poseidon { constants } + } + + pub fn arity(&self) -> usize { + WIDTH - 1 + } + + /// Hash an array of ARITY-many elements. The size of elements could be + /// specified as WIDTH - 1 when const generic expressions are allowed. + /// Function will panic if elements does not have length ARITY. + pub fn output_hash(&self, elements: &[S::Field], c: &mut COM) -> S::Field { + // Inputs should have domain_tag as its leading entry + let mut inputs = S::zeros(c); + // clone_from_slice will panic unless we provided ARITY-many elements + inputs[1..WIDTH].clone_from_slice(&elements[..(WIDTH - 1)]); + + S::output_hash(c, &inputs, &self.constants) + } +} + +pub struct NativeSpec { + _field: PhantomData, +} + +impl PoseidonSpec<(), WIDTH> + for NativeSpec +{ + type Field = F; + type ParameterField = F; + + fn alloc(_c: &mut (), v: Self::ParameterField) -> Self::Field { + v + } + + fn zeros(_c: &mut ()) -> [Self::Field; W] { + [F::zero(); W] + } + + fn add(_c: &mut (), x: &Self::Field, y: &Self::Field) -> Self::Field { + *x + *y + } + + fn addi( + _c: &mut (), + a: &Self::Field, + b: &Self::ParameterField, + ) -> Self::Field { + *a + *b + } + + fn mul(_c: &mut (), x: &Self::Field, y: &Self::Field) -> Self::Field { + *x * *y + } + + fn muli( + _c: &mut (), + x: &Self::Field, + y: &Self::ParameterField, + ) -> Self::Field { + *x * *y + } +} + +pub struct PlonkSpec; + +impl + PoseidonSpec, WIDTH> for PlonkSpec +where + F: PrimeField, + P: TEModelParameters, +{ + type Field = plonk::Variable; + type ParameterField = F; + + fn alloc( + c: &mut StandardComposer, + v: Self::ParameterField, + ) -> Self::Field { + c.add_input(v) + } + + fn zeros( + c: &mut StandardComposer, + ) -> [Self::Field; W] { + [c.zero_var(); W] + } + + fn add( + c: &mut StandardComposer, + x: &Self::Field, + y: &Self::Field, + ) -> Self::Field { + c.arithmetic_gate(|g| g.witness(*x, *y, None).add(F::one(), F::one())) + } + + fn addi( + c: &mut StandardComposer, + a: &Self::Field, + b: &Self::ParameterField, + ) -> Self::Field { + let zero = c.zero_var(); + c.arithmetic_gate(|g| { + g.witness(*a, zero, None) + .add(F::one(), F::zero()) + .constant(*b) + }) + } + + fn mul( + c: &mut StandardComposer, + x: &Self::Field, + y: &Self::Field, + ) -> Self::Field { + c.arithmetic_gate(|q| q.witness(*x, *y, None).mul(F::one())) + } + + fn muli( + c: &mut StandardComposer, + x: &Self::Field, + y: &Self::ParameterField, + ) -> Self::Field { + let zero = c.zero_var(); + c.arithmetic_gate(|g| g.witness(*x, zero, None).add(*y, F::zero())) + } + + #[cfg(not(feature = "no-optimize"))] + fn quintic_s_box( + c: &mut StandardComposer, + x: Self::Field, + pre_add: Option, + post_add: Option, + ) -> Self::Field { + match (pre_add, post_add) { + (None, None) => Self::power_of_5(c, &x), + (Some(_), None) => { + unreachable!("currently no one is using this") + } + (None, Some(post_add)) => { + let x_2 = Self::mul(c, &x, &x); + let x_4 = Self::mul(c, &x_2, &x_2); + c.arithmetic_gate(|g| { + g.witness(x_4, x, None).mul(F::one()).constant(post_add) + }) + } + (Some(_), Some(_)) => { + unreachable!("currently no one is using this") + } + } + } + + #[cfg(not(feature = "no-optimize"))] + fn linear_combination( + c: &mut StandardComposer, + state: &[Self::Field; WIDTH], + coeff: impl IntoIterator, + ) -> Self::Field { + let coeffs = coeff.into_iter().collect::>(); + let mut remaining = WIDTH; + let mut index = 0; + let mut result: Self::Field; + // the first time you have no accumulated result yet, so you can take 3 + // inputs + if remaining < 3 { + // this is unlikely, WIDTH is usually at least 3 + result = c.arithmetic_gate(|g| { + g.witness(state[0], state[1], None) + .add(coeffs[0], coeffs[1]) + }); + remaining -= 2; + } else { + result = c.arithmetic_gate(|g| { + g.witness(state[index], state[index + 1], None) + .add(coeffs[index], coeffs[index + 1]) + .fan_in_3(coeffs[index + 2], state[index + 2]) + }); + index += 3; + remaining -= 3; + } + + // Now you have an accumulated result to carry, so can only take 2 + // inputs at a time + while remaining > 0 { + if remaining < 2 { + // Accumulate remaining one + result = c.arithmetic_gate(|g| { + g.witness(state[index], result, None) + .add(coeffs[index], Self::ParameterField::one()) + }); + remaining -= 1; + } else { + // Accumulate next two + result = c.arithmetic_gate(|g| { + g.witness(state[index], state[index + 1], None) + .add(coeffs[index], coeffs[index + 1]) + .fan_in_3(Self::ParameterField::one(), result) + }); + index += 2; + remaining -= 2; + } + } + result + } + + #[cfg(not(feature = "no-optimize"))] + fn product_mds_with_sparse_matrix( + c: &mut StandardComposer, + state: &mut [Self::Field; WIDTH], + matrix: &SparseMatrix, + ) { + let mut result = Self::zeros::(c); + + result[0] = + Self::linear_combination(c, state, matrix.w_hat.iter().cloned()); + for (j, val) in result.iter_mut().enumerate().skip(1) { + // for each j, result[j] = state[j] + state[0] * v_rest[j-1] + *val = c.arithmetic_gate(|g| { + g.witness(state[0], state[j], None) + .add(matrix.v_rest[j - 1], F::one()) + }); + } + *state = result; + } +} + +trait HashFunction { + type Input; + type Output; + + // The input ought to have size ARITY, but const generic expressions aren't + // allowed yet + fn hash(&self, input: &[Self::Input], compiler: &mut COM) -> Self::Output; +} + +impl HashFunction> + for Poseidon, PlonkSpec, WIDTH> +where + F: PrimeField, + P: TEModelParameters, +{ + type Input = Variable; + type Output = Variable; + + fn hash( + &self, + input: &[Self::Input], + compiler: &mut StandardComposer, + ) -> Self::Output { + self.output_hash(input, compiler) + } +} diff --git a/plonk-hashing/src/poseidon/poseidon_ref.rs b/plonk-hashing/src/poseidon/poseidon_ref.rs new file mode 100644 index 00000000..5a2a6497 --- /dev/null +++ b/plonk-hashing/src/poseidon/poseidon_ref.rs @@ -0,0 +1,525 @@ +//! Correct, Naive, reference implementation of Poseidon hash function. + +use crate::poseidon::PoseidonError; + +use crate::poseidon::constants::PoseidonConstants; +use ark_ec::TEModelParameters; +use ark_ff::PrimeField; +use core::{fmt::Debug, marker::PhantomData}; +use derivative::Derivative; +use plonk_core::{constraint_system::StandardComposer, prelude as plonk}; + +pub trait PoseidonRefSpec { + /// Field used as state + type Field: Debug + Clone; + /// Field used as constant paramater + type ParameterField: PrimeField; // TODO: for now, only prime field is supported. Can be used for arkplonk + // and arkworks which uses the same + // PrimeField. For other field, we are not + // supporting yet. + + fn full_round( + c: &mut COM, + constants: &PoseidonConstants, + constants_offset: &mut usize, + state: &mut [Self::Field; WIDTH], + ) { + let pre_round_keys = constants + .round_constants + .iter() + .skip(*constants_offset) + .map(Some); + + state.iter_mut().zip(pre_round_keys).for_each(|(l, pre)| { + *l = Self::quintic_s_box(c, l.clone(), pre.map(|x| *x), None); + }); + + *constants_offset += WIDTH; + + Self::product_mds(c, constants, state); + } + + fn partial_round( + c: &mut COM, + constants: &PoseidonConstants, + constants_offset: &mut usize, + state: &mut [Self::Field; WIDTH], + ) { + // TODO: we can combine add_round_constants and s_box using fewer + // constraints + Self::add_round_constants(c, state, constants, constants_offset); + + // apply quintic s-box to the first element + state[0] = Self::quintic_s_box(c, state[0].clone(), None, None); + + // Multiply by MDS + Self::product_mds(c, constants, state); + } + + fn add_round_constants( + c: &mut COM, + state: &mut [Self::Field; WIDTH], + constants: &PoseidonConstants, + constants_offset: &mut usize, + ) { + for (element, round_constant) in state + .iter_mut() + .zip(constants.round_constants.iter().skip(*constants_offset)) + { + // element.com_addi(c, round_constant); + *element = Self::addi(c, element, round_constant) + } + + *constants_offset += WIDTH; + } + + fn product_mds( + c: &mut COM, + constants: &PoseidonConstants, + state: &mut [Self::Field; WIDTH], + ) { + let matrix = &constants.mds_matrices.m; + let mut result = Self::zeros::(c); + for (j, val) in result.iter_mut().enumerate() { + for (i, row) in matrix.iter_rows().enumerate() { + // *val += row[j] * state[i]; + let tmp = Self::muli(c, &state[i], &row[j]); + *val = Self::add(c, val, &tmp); + } + } + *state = result; + } + + /// return (x + pre_add)^5 + post_add + fn quintic_s_box( + c: &mut COM, + x: Self::Field, + pre_add: Option, + post_add: Option, + ) -> Self::Field { + let mut tmp = match pre_add { + Some(a) => Self::addi(c, &x, &a), + None => x.clone(), + }; + tmp = Self::power_of_5(c, &tmp); + match post_add { + Some(a) => Self::addi(c, &tmp, &a), + None => tmp, + } + } + + fn power_of_5(c: &mut COM, x: &Self::Field) -> Self::Field { + let mut tmp = Self::mul(c, x, x); // x^2 + tmp = Self::mul(c, &tmp, &tmp); // x^4 + Self::mul(c, &tmp, x) // x^5 + } + + fn alloc(c: &mut COM, v: Self::ParameterField) -> Self::Field; + fn zeros(c: &mut COM) -> [Self::Field; W]; + fn zero(c: &mut COM) -> Self::Field { + Self::zeros::<1>(c)[0].clone() + } + fn add(c: &mut COM, x: &Self::Field, y: &Self::Field) -> Self::Field; + fn addi( + c: &mut COM, + a: &Self::Field, + b: &Self::ParameterField, + ) -> Self::Field; + fn mul(c: &mut COM, x: &Self::Field, y: &Self::Field) -> Self::Field; + fn muli( + c: &mut COM, + x: &Self::Field, + y: &Self::ParameterField, + ) -> Self::Field; +} + +#[derive(Derivative)] +#[derivative(Debug(bound = ""))] +pub struct PoseidonRef, const WIDTH: usize> +where + S: ?Sized, +{ + pub(crate) constants_offset: usize, + pub(crate) current_round: usize, + pub elements: [S::Field; WIDTH], + pos: usize, + pub(crate) constants: PoseidonConstants, +} + +impl, const WIDTH: usize> + PoseidonRef +{ + pub fn new( + c: &mut COM, + constants: PoseidonConstants, + ) -> Self { + let mut elements = S::zeros(c); + elements[0] = S::alloc(c, constants.domain_tag); + PoseidonRef { + constants_offset: 0, + current_round: 0, + elements, + pos: 1, + constants, + } + } + + pub fn arity(&self) -> usize { + WIDTH - 1 + } + + pub fn reset(&mut self, c: &mut COM) { + self.constants_offset = 0; + self.current_round = 0; + self.elements[1..].iter_mut().for_each(|l| *l = S::zero(c)); + self.elements[0] = S::alloc(c, self.constants.domain_tag); + self.pos = 1; + } + + /// input one field element to Poseidon. Return the position of the element + /// in state. + pub fn input(&mut self, input: S::Field) -> Result { + // Cannot input more elements than the defined constant width + if self.pos >= WIDTH { + return Err(PoseidonError::FullBuffer); + } + + // Set current element, and increase the pointer + self.elements[self.pos] = input; + self.pos += 1; + + Ok(self.pos - 1) + } + + /// Output the hash + pub fn output_hash(&mut self, c: &mut COM) -> S::Field { + S::full_round( + c, + &self.constants, + &mut self.constants_offset, + &mut self.elements, + ); + + for _ in 1..self.constants.half_full_rounds { + S::full_round( + c, + &self.constants, + &mut self.constants_offset, + &mut self.elements, + ); + } + + S::partial_round( + c, + &self.constants, + &mut self.constants_offset, + &mut self.elements, + ); + + for _ in 1..self.constants.partial_rounds { + S::partial_round( + c, + &self.constants, + &mut self.constants_offset, + &mut self.elements, + ); + } + + for _ in 0..self.constants.half_full_rounds { + S::full_round( + c, + &self.constants, + &mut self.constants_offset, + &mut self.elements, + ) + } + + self.elements[1].clone() + } +} + +pub struct NativeSpecRef { + _field: PhantomData, +} + +impl PoseidonRefSpec<(), WIDTH> + for NativeSpecRef +{ + type Field = F; + type ParameterField = F; + + fn alloc(_c: &mut (), v: Self::ParameterField) -> Self::Field { + v + } + + fn zeros(_c: &mut ()) -> [Self::Field; W] { + [F::zero(); W] + } + + fn add(_c: &mut (), x: &Self::Field, y: &Self::Field) -> Self::Field { + *x + *y + } + + fn addi( + _c: &mut (), + a: &Self::Field, + b: &Self::ParameterField, + ) -> Self::Field { + *a + *b + } + + fn mul(_c: &mut (), x: &Self::Field, y: &Self::Field) -> Self::Field { + *x * *y + } + + fn muli( + _c: &mut (), + x: &Self::Field, + y: &Self::ParameterField, + ) -> Self::Field { + *x * *y + } +} + +pub struct PlonkSpecRef; + +impl + PoseidonRefSpec, WIDTH> for PlonkSpecRef +where + F: PrimeField, + P: TEModelParameters, +{ + type Field = plonk::Variable; + type ParameterField = F; + + fn alloc( + c: &mut StandardComposer, + v: Self::ParameterField, + ) -> Self::Field { + c.add_input(v) + } + + fn zeros( + c: &mut StandardComposer, + ) -> [Self::Field; W] { + [c.zero_var(); W] + } + + fn add( + c: &mut StandardComposer, + x: &Self::Field, + y: &Self::Field, + ) -> Self::Field { + c.arithmetic_gate(|g| g.witness(*x, *y, None).add(F::one(), F::one())) + } + + fn addi( + c: &mut StandardComposer, + a: &Self::Field, + b: &Self::ParameterField, + ) -> Self::Field { + let zero = c.zero_var(); + c.arithmetic_gate(|g| { + g.witness(*a, zero, None) + .add(F::one(), F::zero()) + .constant(*b) + }) + } + + fn mul( + c: &mut StandardComposer, + x: &Self::Field, + y: &Self::Field, + ) -> Self::Field { + c.arithmetic_gate(|q| q.witness(*x, *y, None).mul(F::one())) + } + + fn muli( + c: &mut StandardComposer, + x: &Self::Field, + y: &Self::ParameterField, + ) -> Self::Field { + let zero = c.zero_var(); + c.arithmetic_gate(|g| g.witness(*x, zero, None).add(*y, F::zero())) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use ark_ec::PairingEngine; + type E = ark_bls12_381::Bls12_381; + type P = ark_ed_on_bls12_381::EdwardsParameters; + type Fr = ::Fr; + use ark_ff::Field; + use ark_std::{test_rng, UniformRand}; + + #[test] + // poseidon should output something if num_inputs = arity + fn sanity_test() { + const ARITY: usize = 4; + const WIDTH: usize = ARITY + 1; + let mut rng = test_rng(); + + let param = PoseidonConstants::generate::(); + let mut poseidon = PoseidonRef::<(), NativeSpecRef, WIDTH>::new( + &mut (), + param.clone(), + ); + let inputs = (0..ARITY).map(|_| Fr::rand(&mut rng)).collect::>(); + + inputs.iter().for_each(|x| { + let _ = poseidon.input(*x).unwrap(); + }); + let native_hash: Fr = poseidon.output_hash(&mut ()); + + let mut c = StandardComposer::::new(); + let inputs_var = + inputs.iter().map(|x| c.add_input(*x)).collect::>(); + let mut poseidon_circuit = + PoseidonRef::<_, PlonkSpecRef, WIDTH>::new(&mut c, param); + inputs_var.iter().for_each(|x| { + let _ = poseidon_circuit.input(*x).unwrap(); + }); + let plonk_hash = poseidon_circuit.output_hash(&mut c); + + c.check_circuit_satisfied(); + + let expected = c.add_input(native_hash); + c.assert_equal(expected, plonk_hash); + + c.check_circuit_satisfied(); + println!( + "circuit size for WIDTH {} poseidon: {}", + WIDTH, + c.circuit_size() + ) + } + + #[test] + // poseidon should output something if num_inputs = arity + fn sanity_test_r1cs() { + const ARITY: usize = 2; + const WIDTH: usize = ARITY + 1; + let mut rng = test_rng(); + + let param = PoseidonConstants::generate::(); + let mut poseidon = PoseidonRef::<(), NativeSpecRef, WIDTH>::new( + &mut (), + param.clone(), + ); + let inputs = (0..ARITY).map(|_| Fr::rand(&mut rng)).collect::>(); + + inputs.iter().for_each(|x| { + let _ = poseidon.input(*x).unwrap(); + }); + let native_hash: Fr = poseidon.output_hash(&mut ()); + + let mut cs = ConstraintSystem::new_ref(); + let mut poseidon_var = + PoseidonRef::<_, R1csSpecRef, WIDTH>::new( + &mut cs, + param.clone(), + ); + let inputs_var = inputs + .iter() + .map(|x| R1csSpecRef::<_, WIDTH>::alloc(&mut cs, *x)) + .collect::>(); + inputs_var.iter().for_each(|x| { + let _ = poseidon_var.input(x.clone()).unwrap(); + }); + + let hash_var = poseidon_var.output_hash(&mut cs); + + assert!(cs.is_satisfied().unwrap()); + assert_eq!(hash_var.value().unwrap(), native_hash); + println!( + "circuit size for WIDTH {} r1cs: {}", + WIDTH, + cs.num_constraints() + ) + } + + #[test] + #[should_panic] + // poseidon should output something if num_inputs > arity + fn sanity_test_failure() { + const ARITY: usize = 4; + const WIDTH: usize = ARITY + 1; + let mut rng = test_rng(); + + let param = PoseidonConstants::generate::(); + let mut poseidon = + PoseidonRef::<(), NativeSpecRef, WIDTH>::new(&mut (), param); + (0..(ARITY + 1)).for_each(|_| { + let _ = poseidon.input(Fr::rand(&mut rng)).unwrap(); + }); + let _ = poseidon.output_hash(&mut ()); + } + + use crate::tests::{ + conversion::cast_field, + neptune_hyper_parameter::collect_neptune_constants, + }; + use neptune::{ + poseidon::{HashMode, PoseidonConstants as NeptunePoseidonConstants}, + Strength, + }; + // let constants = NeptunePoseidonConstants::::new_with_strength(strength); let mut p = NeptunePoseidon::::new(&constants); let mut p2 = NeptunePoseidon::::new(&constants); let mut p3 = NeptunePoseidon::::new(&constants); let mut p4 = NeptunePoseidon::::new(&constants); + + // let test_arity = constants.arity(); + // for n in 0..test_arity { + // let scalar = Fr::from(n as u64); + // p.input(scalar).unwrap(); + // p2.input(scalar).unwrap(); + // p3.input(scalar).unwrap(); + // p4.input(scalar).unwrap(); + // } + + // let digest = p.hash(); + // let digest2 = p2.hash_in_mode(Correct); + // let digest3 = p3.hash_in_mode(OptimizedStatic); + // let digest4 = p4.hash_in_mode(OptimizedDynamic); + + #[test] + fn compare_with_neptune() { + const ARITY: usize = 2; + const WIDTH: usize = ARITY + 1; + type NepArity = generic_array::typenum::U2; + + let (nep_consts, ark_consts) = + collect_neptune_constants::(Strength::Standard); + + let mut rng = test_rng(); + let inputs_ff = (0..ARITY) + .map(|_| blstrs::Scalar::random(&mut rng)) + .collect::>(); + let inputs = + inputs_ff.iter().map(|&x| cast_field(x)).collect::>(); + + let mut neptune_poseidon = + neptune::Poseidon::::new(&nep_consts); + let mut ark_poseidon = PoseidonRef::<(), NativeSpecRef, WIDTH>::new( + &mut (), + ark_consts, + ); + + inputs_ff.iter().for_each(|x| { + neptune_poseidon.input(*x).unwrap(); + }); + inputs.iter().for_each(|x| { + ark_poseidon.input(*x).unwrap(); + }); + + let digest_expected = + cast_field(neptune_poseidon.hash_in_mode(HashMode::Correct)); + let digest_actual = ark_poseidon.output_hash(&mut ()); + + assert_eq!(digest_expected, digest_actual); + } +} diff --git a/plonk-hashing/src/poseidon/preprocessing.rs b/plonk-hashing/src/poseidon/preprocessing.rs new file mode 100644 index 00000000..642213f9 --- /dev/null +++ b/plonk-hashing/src/poseidon/preprocessing.rs @@ -0,0 +1,88 @@ +//! acknowledgement: adapted from FileCoin Project: https://github.com/filecoin-project/neptune/blob/master/src/preprocessing.rs + +use super::{matrix::vec_add, mds::MdsMatrices}; +use ark_ff::vec::Vec; +use ark_ff::PrimeField; + +// - Compress constants by pushing them back through linear layers and through +// the identity components of partial layers. +// - As a result, constants need only be added after each S-box. +pub(crate) fn compress_round_constants( + width: usize, + full_rounds: usize, + partial_rounds: usize, + round_constants: &Vec, + mds_matrices: &MdsMatrices, +) -> Vec { + let inverse_matrix = &mds_matrices.m_inv; + + let mut res: Vec = Vec::new(); + + let round_keys = |r: usize| &round_constants[r * width..(r + 1) * width]; + + // This is half full-rounds. + let half_full_rounds = full_rounds / 2; + + // First round constants are unchanged. + res.extend(round_keys(0)); + + // Post S-box adds for the first set of full rounds should be 'inverted' + // from next round. The final round is skipped when fully preprocessing + // because that value must be obtained from the result of preprocessing + // the partial rounds. + let end = half_full_rounds - 1; + for i in 0..end { + let next_round = round_keys(i + 1); + let inverted = inverse_matrix.right_apply(next_round); + res.extend(inverted); + } + + // The plan: + // - Work backwards from last row in this group + // - Invert the row. + // - Save first constant (corresponding to the one S-box performed). + // - Add inverted result to previous row. + // - Repeat until all partial round key rows have been consumed. + // - Extend the preprocessed result by the final resultant row. + // - Move the accumulated list of single round keys to the preprocesed + // result. + // - (Last produced should be first applied, so either pop until empty, or + // reverse and extend, etc.) + + // 'partial_keys' will accumulated the single post-S-box constant for each + // partial-round, in reverse order. + let mut partial_keys: Vec = Vec::new(); + + let final_round = half_full_rounds + partial_rounds; + let final_round_key = round_keys(final_round).to_vec(); + + // 'round_acc' holds the accumulated result of inverting and adding + // subsequent round constants (in reverse). + let round_acc = (0..partial_rounds) + .map(|i| round_keys(final_round - i - 1)) + .fold(final_round_key, |acc, previous_round_keys| { + let mut inverted = inverse_matrix.right_apply(&acc); + + partial_keys.push(inverted[0]); + inverted[0] = F::zero(); + + vec_add(&previous_round_keys, &inverted) + }); + + res.extend(inverse_matrix.right_apply(&round_acc)); + + while let Some(x) = partial_keys.pop() { + res.push(x) + } + + // Post S-box adds for the first set of full rounds should be 'inverted' + // from next round. + for i in 1..(half_full_rounds) { + let start = half_full_rounds + partial_rounds; + let next_round = round_keys(i + start); + let inverted = inverse_matrix.right_apply(next_round); + res.extend(inverted); + } + + res +} diff --git a/plonk-hashing/src/poseidon/round_constant.rs b/plonk-hashing/src/poseidon/round_constant.rs new file mode 100644 index 00000000..754b5f5a --- /dev/null +++ b/plonk-hashing/src/poseidon/round_constant.rs @@ -0,0 +1,171 @@ +use alloc::collections::vec_deque::VecDeque; +use ark_ff::{BigInteger, PrimeField}; +/// From the paper +/// THe parameter describes the initial state of constant generation (80-bits) +/// * `field`: description of field. b0, b1 +/// * `sbox`: description of s-box. b2..=b5 +/// * `field_size`: binary representation of field size. b6..=b17 +/// * `t`: binary representation of t. b18..=b29 +/// * `rf`: binary representation of rf. b30..=b39 +/// * `rp`: binary representation of rp. b40..=b49 +/// * `ones`: set to 1. b50..=b79 +pub fn generate_constants( + field: u8, + sbox: u8, + field_size: u16, + t: u16, + r_f: u16, + r_p: u16, +) -> Vec { + let n_bytes = (F::size_in_bits() + 8 - 1) / 8; + if n_bytes != 32 { + unimplemented!("neptune currently supports 32-byte fields exclusively"); + }; + assert_eq!((field_size as f32 / 8.0).ceil() as usize, n_bytes); + + // r_f here is 2* number of *half* full rounds. + let num_constants = (r_f + r_p) * t; + let mut init_sequence: VecDeque = VecDeque::new(); + append_bits(&mut init_sequence, 2, field); // Bits 0-1 + append_bits(&mut init_sequence, 4, sbox); // Bits 2-5 + append_bits(&mut init_sequence, 12, field_size); // Bits 6-17 + append_bits(&mut init_sequence, 12, t); // Bits 18-29 + append_bits(&mut init_sequence, 10, r_f); // Bits 30-39 + append_bits(&mut init_sequence, 10, r_p); // Bits 40-49 + append_bits(&mut init_sequence, 30, 0b111111111111111111111111111111u128); // Bits 50-79 + + let mut grain = GrainLFSR::new(init_sequence, field_size); + let mut round_constants: Vec = Vec::new(); + + match field { + 1 => { + for _ in 0..num_constants { + while { + // TODO: Please review this part. May be different from + // neptune. + + // Generate 32 bytes and interpret them as a big-endian + // integer. Bytes are big-endian to + // agree with the integers generated by grain_random_bits in + // the reference implementation: + // + // def grain_random_bits(num_bits): + // random_bits = [grain_gen.next() for i in range(0, + // num_bits)] random_int = + // int("".join(str(i) for i in random_bits), 2) + // return random_int + let mut repr = F::default().into_repr().to_bytes_be(); + grain.get_next_bytes(repr.as_mut()); + repr.reverse(); + if let Some(f) = F::from_random_bytes(&repr) { + round_constants.push(f); + false + } else { + true + } + } {} + } + } + _ => { + panic!("Only prime fields are supported."); + } + } + round_constants +} + +fn append_bits>(vec: &mut VecDeque, n: usize, from: T) { + let val = from.into() as u128; + for i in (0..n).rev() { + vec.push_back((val >> i) & 1 != 0); + } +} + +// adapted from: https://github.com/filecoin-project/neptune/blob/master/src/round_constants.rs +struct GrainLFSR { + state: VecDeque, + field_size: u16, +} + +impl GrainLFSR { + pub fn new(initial_sequence: VecDeque, field_size: u16) -> Self { + assert_eq!( + initial_sequence.len(), + 80, + "Initial Sequence must be 80 bits" + ); + let mut g = GrainLFSR { + state: initial_sequence, + field_size, + }; + (0..160).for_each(|_| { + g.generate_new_bit(); + }); + assert_eq!(80, g.state.len()); + g + } + + fn generate_new_bit(&mut self) -> bool { + let new_bit = self.bit(62) + ^ self.bit(51) + ^ self.bit(38) + ^ self.bit(23) + ^ self.bit(13) + ^ self.bit(0); + self.state.pop_front(); + self.state.push_back(new_bit); + new_bit + } + + fn bit(&self, index: usize) -> bool { + self.state[index] + } + + fn next_byte(&mut self, bit_count: usize) -> u8 { + // Accumulate bits from most to least significant, so the most + // significant bit is the one generated first by the bit stream + let mut acc: u8 = 0; + self.take(bit_count).for_each(|bit| { + acc <<= 1; + if bit { + acc += 1; + } + }); + + acc + } + + fn get_next_bytes(&mut self, result: &mut [u8]) { + let remainder_bits = self.field_size as usize % 8; + // Prime fields will always have remainder bits, + // but other field types could be supported in the future. + if remainder_bits > 0 { + // If there is an unfull byte, it should be the first. + // Subsequent bytes are packed into result in the order generated. + result[0] = self.next_byte(remainder_bits); + } else { + result[0] = self.next_byte(8); + } + + // First byte is already set + for item in result.iter_mut().skip(1) { + *item = self.next_byte(8) + } + } +} + +impl Iterator for GrainLFSR { + type Item = bool; + + // TO BE checked + fn next(&mut self) -> Option { + let mut new_bit = self.generate_new_bit(); + while !new_bit { + let _new_bit = self.generate_new_bit(); + new_bit = self.generate_new_bit(); + } + new_bit = self.generate_new_bit(); + Some(new_bit) + } +} + +// TODO: TO BE TESTED! diff --git a/plonk-hashing/src/poseidon/round_numbers.rs b/plonk-hashing/src/poseidon/round_numbers.rs new file mode 100644 index 00000000..c448e27f --- /dev/null +++ b/plonk-hashing/src/poseidon/round_numbers.rs @@ -0,0 +1,200 @@ +// Adapted from https://github.com/filecoin-project/neptune/blob/master/src/round_numbers.rs + +// The number of bits of the Poseidon prime field modulus. Denoted `n` in the +// Poseidon paper (where `n = ceil(log2(p))`). Note that BLS12-381's scalar +// field modulus is 255 bits, however we use 256 bits for simplicity when +// operating on bytes as the single bit difference does not affect +// the round number security properties. +const PRIME_BITLEN: usize = 256; + +// Security level (in bits), denoted `M` in the Poseidon paper. +const M: usize = 128; + +/// The number of S-boxes (also called the "cost") given by equation (14) in the +/// Poseidon paper: `cost = t * R_F + R_P`. +fn n_sboxes(t: usize, rf: usize, rp: usize) -> usize { + t * rf + rp +} + +/// Returns the round numbers for a given arity `(R_F, R_P)`. +pub(crate) fn round_numbers_base(arity: usize) -> (usize, usize) { + let t = arity + 1; + calc_round_numbers(t, true) +} + +/// In case of newly-discovered attacks, we may need stronger security. +/// This option exists so we can preemptively create circuits in order to switch +/// to them quickly if needed. +/// +/// "A realistic alternative is to increase the number of partial rounds by 25%. +/// Then it is unlikely that a new attack breaks through this number, +/// but even if this happens then the complexity is almost surely above 2^64, +/// and you will be safe." +/// - D Khovratovich +pub(crate) fn round_numbers_strengthened(arity: usize) -> (usize, usize) { + let (full_round, partial_rounds) = round_numbers_base(arity); + + // Increase by 25%, rounding up. + let strengthened_partial_rounds = + f64::ceil(partial_rounds as f64 * 1.25) as usize; + + (full_round, strengthened_partial_rounds) +} + +/// Returns the round numbers for a given width `t`. Here, the `security_margin` +/// parameter does not indicate that we are calculating `R_F` and `R_P` for the +/// "strengthened" round numbers, done in the function +/// `round_numbers_strengthened()`. +pub(crate) fn calc_round_numbers( + t: usize, + security_margin: bool, +) -> (usize, usize) { + let mut rf = 0; + let mut rp = 0; + let mut n_sboxes_min = usize::MAX; + + for mut rf_test in (2..=1000).step_by(2) { + for mut rp_test in 4..200 { + if round_numbers_are_secure(t, rf_test, rp_test) { + if security_margin { + rf_test += 2; + rp_test = (1.075 * rp_test as f32).ceil() as usize; + } + let n_sboxes = n_sboxes(t, rf_test, rp_test); + if n_sboxes < n_sboxes_min + || (n_sboxes == n_sboxes_min && rf_test < rf) + { + rf = rf_test; + rp = rp_test; + n_sboxes_min = n_sboxes; + } + } + } + } + + (rf, rp) +} + +/// Returns `true` if the provided round numbers satisfy the security +/// inequalities specified in the Poseidon paper. +fn round_numbers_are_secure(t: usize, rf: usize, rp: usize) -> bool { + let (rp, t, n, m) = (rp as f32, t as f32, PRIME_BITLEN as f32, M as f32); + let rf_stat = if m <= (n - 3.0) * (t + 1.0) { + 6.0 + } else { + 10.0 + }; + let rf_interp = 0.43 * m + t.log2() - rp; + let rf_grob_1 = 0.21 * n - rp; + let rf_grob_2 = (0.14 * n - 1.0 - rp) / (t - 1.0); + let rf_max = [rf_stat, rf_interp, rf_grob_1, rf_grob_2] + .iter() + .map(|rf| rf.ceil() as usize) + .max() + .unwrap(); + rf >= rf_max +} + +#[cfg(test)] +mod tests { + use super::*; + + use std::fs; + + #[test] + fn test_round_numbers_against_known_values() { + // Each case contains a `t` (where `t = arity + 1`) and the `R_P` + // expected for that `t`. + let cases = [ + (2usize, 55usize), + (3, 55), + (4, 56), + (5, 56), + (6, 56), + (7, 56), + (8, 57), + (9, 57), + (10, 57), + (11, 57), + (12, 57), + (13, 57), + (14, 57), + (15, 57), + (16, 59), + (17, 59), + (25, 59), + (37, 60), + (65, 61), + ]; + for (t, rp_expected) in cases.iter() { + let (rf, rp) = calc_round_numbers(*t, true); + assert_eq!(rf, 8); + assert_eq!(rp, *rp_expected); + } + } + + #[ignore] + #[test] + fn test_round_numbers_against_python_script() { + // A parsed line from `parameters/round_numbers.txt`. + struct Line { + t: usize, + rf: usize, + rp: usize, + sbox_cost: usize, + size_cost: usize, + } + + let lines: Vec = fs::read_to_string( + "parameters/round_numbers.txt", + ) + .expect( + "failed to read round numbers file: `parameters/round_numbers.txt`", + ) + .lines() + .skip_while(|line| line.starts_with('#')) + .map(|line| { + let nums: Vec = line + .split(' ') + .map(|s| { + s.parse().unwrap_or_else(|_| { + panic!("failed to parse line as `usize`s: {}", line) + }) + }) + .collect(); + assert_eq!( + nums.len(), + 5, + "line in does not contain 5 values: {}", + line + ); + Line { + t: nums[0], + rf: nums[1], + rp: nums[2], + sbox_cost: nums[3], + size_cost: nums[4], + } + }) + .collect(); + + assert!( + !lines.is_empty(), + "no lines were parsed from `round_numbers.txt`", + ); + + for line in lines { + let (rf, rp) = calc_round_numbers(line.t, true); + let sbox_cost = n_sboxes(line.t, rf, rp); + let size_cost = sbox_cost * PRIME_BITLEN; + + assert_eq!(rf, line.rf, "full rounds differ from script"); + assert_eq!(rp, line.rp, "partial rounds differ from script"); + assert_eq!(sbox_cost, line.sbox_cost, "cost differs from script"); + assert_eq!( + size_cost, line.size_cost, + "size-cost differs from script" + ); + } + } +}