From dc4124e96871c5c9a08e35b811a6f854f8424acd Mon Sep 17 00:00:00 2001 From: Joaquin Carletti Date: Tue, 15 Oct 2024 18:51:21 -0300 Subject: [PATCH] wip --- math/src/circle/cfft.rs | 222 +++++++++++++++++++++++++++++----- math/src/circle/polynomial.rs | 45 ++++--- math/src/circle/twiddles.rs | 14 +-- 3 files changed, 227 insertions(+), 54 deletions(-) diff --git a/math/src/circle/cfft.rs b/math/src/circle/cfft.rs index 791a8d7ab..1f680cbb8 100644 --- a/math/src/circle/cfft.rs +++ b/math/src/circle/cfft.rs @@ -2,39 +2,80 @@ extern crate alloc; use crate::field::{element::FieldElement, fields::mersenne31::field::Mersenne31Field}; #[cfg(feature = "alloc")] -pub fn inplace_cfft( +pub fn cfft( input: &mut [FieldElement], twiddles: Vec>>, ) { - let mut group_count = 1; - let mut group_size = input.len(); - let mut round = 0; + let log_2_size = input.len().trailing_zeros(); + + (0..log_2_size).for_each(|i| { + let chunk_size = 1 << i + 1; + let half_chunk_size = 1 << i; + input.chunks_mut(chunk_size).for_each(|chunk| { + let (hi_part, low_part) = chunk.split_at_mut(half_chunk_size); + hi_part.into_iter().zip(low_part).enumerate().for_each( |(j, (hi, low))| { + let temp = *low * twiddles[i as usize][j]; + *low = *hi - temp; + *hi = *hi + temp; + }); + }); + }); +} - while group_count < input.len() { - let round_twiddles = &twiddles[round]; - #[allow(clippy::needless_range_loop)] // the suggestion would obfuscate a bit the algorithm - for group in 0..group_count { - let first_in_group = group * group_size; - let first_in_next_group = first_in_group + group_size / 2; - let w = &round_twiddles[group]; // a twiddle factor is used per group +#[cfg(feature = "alloc")] +pub fn icfft( + input: &mut [FieldElement], + twiddles: Vec>>, +) { + let log_2_size = input.len().trailing_zeros(); + + println!("{:?}", twiddles); + + (0..log_2_size).for_each(|i| { + let chunk_size = 1 << log_2_size - i; + let half_chunk_size = chunk_size >> 1; + input.chunks_mut(chunk_size).for_each(|chunk| { + let (hi_part, low_part) = chunk.split_at_mut(half_chunk_size); + hi_part.into_iter().zip(low_part).enumerate().for_each( |(j, (hi, low))| { + let temp = *hi + *low; + *low = (*hi - *low) * twiddles[i as usize][j]; + *hi = temp; + }); + }); + }); +} - for i in first_in_group..first_in_next_group { - let wi = w * input[i + group_size / 2]; +pub fn order_cfft_result_naive(input: &mut [FieldElement]) -> Vec> { + let mut result = Vec::new(); + let length = input.len(); + for i in (0..length/2) { + result.push(input[i]); + result.push(input[length - i - 1]); + } + result +} - let y0 = input[i] + wi; - let y1 = input[i] - wi; +pub fn order_icfft_input_naive(input: &mut [FieldElement]) -> Vec> { + let mut result = Vec::new(); + (0..input.len()).step_by(2).for_each( |i| { + result.push(input[i]); + }); + (1..input.len()).step_by(2).rev().for_each( |i| { + result.push(input[i]); + }); + result +} - input[i] = y0; - input[i + group_size / 2] = y1; - } - } - group_count *= 2; - group_size /= 2; - round += 1; +pub fn reverse_cfft_index(index: usize, length: usize) -> usize { + if index < (length >> 1) { // index < length / 2 + index << 1 // index * 2 + } else { + (((length - 1) - index) << 1) + 1 } } + pub fn cfft_4( input: &mut [FieldElement], twiddles: Vec>>, @@ -104,19 +145,134 @@ pub fn cfft_8( stage3.into_iter().map(|elem| elem * f).collect() } -pub fn inplace_order_cfft_values(input: &mut [FieldElement]) { - for i in 0..input.len() { - let cfft_index = reverse_cfft_index(i, input.len().trailing_zeros()); - if cfft_index > i { - input.swap(i, cfft_index); + +#[cfg(test)] +mod tests { + use super::*; + type FE = FieldElement; + + #[test] + fn ordering_4() { + let expected_slice = [ + FE::from(0), + FE::from(1), + FE::from(2), + FE::from(3), + ]; + + let mut slice = [ + FE::from(0), + FE::from(2), + FE::from(3), + FE::from(1), + ]; + + let res = order_cfft_result_naive(&mut slice); + + assert_eq!(res, expected_slice) + } + + #[test] + fn ordering() { + let expected_slice = [ + FE::from(0), + FE::from(1), + FE::from(2), + FE::from(3), + FE::from(4), + FE::from(5), + FE::from(6), + FE::from(7), + FE::from(8), + FE::from(9), + FE::from(10), + FE::from(11), + FE::from(12), + FE::from(13), + FE::from(14), + FE::from(15), + ]; + + let mut slice = [ + FE::from(0), + FE::from(2), + FE::from(4), + FE::from(6), + FE::from(8), + FE::from(10), + FE::from(12), + FE::from(14), + FE::from(15), + FE::from(13), + FE::from(11), + FE::from(9), + FE::from(7), + FE::from(5), + FE::from(3), + FE::from(1), + ]; + + let res = order_cfft_result_naive(&mut slice); + + assert_eq!(res, expected_slice) + } + + #[test] + fn reverse_cfft_index_works() { + let mut reversed: Vec = Vec::with_capacity(16); + for i in 0..reversed.capacity() { + reversed.push(reverse_cfft_index(i, reversed.capacity())); } + assert_eq!( + reversed[..], + [0, 2, 4, 6, 8, 10, 12, 14, 15, 13, 11, 9, 7, 5, 3, 1] + ); } -} -pub fn reverse_cfft_index(index: usize, log_2_size: u32) -> usize { - let (mut new_index, lsb) = (index >> 1, index & 1); - if (lsb == 1) & (log_2_size > 1) { - new_index = (1 << log_2_size) - new_index - 1; + #[test] + fn from_natural_to_icfft_input_order() { + let mut slice = [ + FE::from(0), + FE::from(1), + FE::from(2), + FE::from(3), + FE::from(4), + FE::from(5), + FE::from(6), + FE::from(7), + FE::from(8), + FE::from(9), + FE::from(10), + FE::from(11), + FE::from(12), + FE::from(13), + FE::from(14), + FE::from(15), + ]; + + let expected_slice = [ + FE::from(0), + FE::from(2), + FE::from(4), + FE::from(6), + FE::from(8), + FE::from(10), + FE::from(12), + FE::from(14), + FE::from(15), + FE::from(13), + FE::from(11), + FE::from(9), + FE::from(7), + FE::from(5), + FE::from(3), + FE::from(1), + ]; + + let res = order_icfft_input_naive(&mut slice); + + assert_eq!(res, expected_slice) } - new_index.reverse_bits() >> (usize::BITS - log_2_size) + + } diff --git a/math/src/circle/polynomial.rs b/math/src/circle/polynomial.rs index e7c2ebaf4..b5f0e1877 100644 --- a/math/src/circle/polynomial.rs +++ b/math/src/circle/polynomial.rs @@ -1,7 +1,10 @@ -use crate::field::{element::FieldElement, fields::mersenne31::field::Mersenne31Field}; +use crate::{ + field::{element::FieldElement, fields::mersenne31::field::Mersenne31Field}, + fft::cpu::bit_reversing::in_place_bit_reverse_permute +}; use super::{ - cfft::{cfft_4, cfft_8, inplace_cfft, inplace_order_cfft_values}, + cfft::{cfft, icfft, cfft_4, cfft_8, order_cfft_result_naive, order_icfft_input_naive}, cosets::Coset, twiddles::{ get_twiddles, get_twiddles_itnerpolation_4, get_twiddles_itnerpolation_8, TwiddlesConfig, @@ -14,14 +17,15 @@ use super::{ pub fn evaluate_cfft( mut coeff: Vec>, ) -> Vec> { + in_place_bit_reverse_permute::>(&mut coeff); let domain_log_2_size: u32 = coeff.len().trailing_zeros(); let coset = Coset::new_standard(domain_log_2_size); let config = TwiddlesConfig::Evaluation; let twiddles = get_twiddles(coset, config); - inplace_cfft(&mut coeff, twiddles); - inplace_order_cfft_values(&mut coeff); - coeff + cfft(&mut coeff, twiddles); + let result = order_cfft_result_naive(&mut coeff); + result } /// Interpolates the 2^n evaluations of a two-variables polynomial on the points of the standard coset of size 2^n. @@ -30,14 +34,15 @@ pub fn evaluate_cfft( pub fn interpolate_cfft( mut eval: Vec>, ) -> Vec> { + let mut eval_ordered = order_icfft_input_naive(&mut eval); let domain_log_2_size: u32 = eval.len().trailing_zeros(); let coset = Coset::new_standard(domain_log_2_size); let config = TwiddlesConfig::Interpolation; let twiddles = get_twiddles(coset, config); - inplace_cfft(&mut eval, twiddles); - inplace_order_cfft_values(&mut eval); - eval + icfft(&mut eval_ordered, twiddles); + let result = order_cfft_result_naive(&mut eval); + result } pub fn interpolate_4( @@ -116,7 +121,7 @@ mod tests { // We create the coset points and evaluate them without the fft. let coset = Coset::new_standard(2); let points = Coset::get_coset_points(&coset); - let mut input = [FpE::from(1), FpE::from(2), FpE::from(3), FpE::from(4)]; + let input = [FpE::from(1), FpE::from(2), FpE::from(3), FpE::from(4)]; let mut expected_result: Vec = Vec::new(); for point in points { let point_eval = evaluate_poly_4(&input, point.x, point.y); @@ -133,7 +138,7 @@ mod tests { // We create the coset points and evaluate them without the fft. let coset = Coset::new_standard(3); let points = Coset::get_coset_points(&coset); - let mut input = [ + let input = [ FpE::from(1), FpE::from(2), FpE::from(3), @@ -158,7 +163,7 @@ mod tests { fn cfft_evaluation_16_points() { let coset = Coset::new_standard(4); let points = Coset::get_coset_points(&coset); - let mut input = [ + let input = [ FpE::from(1), FpE::from(2), FpE::from(3), @@ -231,8 +236,20 @@ mod tests { } #[test] - fn cuentas() { - println!("{:?}", FpE::from(32768).inv().unwrap()); // { value: 65536 } - println!("{:?}", FpE::from(2147450879).inv().unwrap()); // { value: 2147418111 } + fn evaluate_and_interpolate() { + let coeff = vec![ + FpE::from(1), + FpE::from(2), + FpE::from(3), + FpE::from(4), + FpE::from(5), + FpE::from(6), + FpE::from(7), + FpE::from(8), + ]; + let evals = evaluate_cfft(coeff.clone()); + let new_coeff = interpolate_cfft(evals); + + assert_eq!(coeff, new_coeff); } } diff --git a/math/src/circle/twiddles.rs b/math/src/circle/twiddles.rs index e3a892718..abe48581b 100644 --- a/math/src/circle/twiddles.rs +++ b/math/src/circle/twiddles.rs @@ -16,32 +16,32 @@ pub fn get_twiddles( domain: Coset, config: TwiddlesConfig, ) -> Vec>> { - let mut half_domain_points = Coset::get_coset_points(&Coset::half_coset(domain.clone())); - if config == TwiddlesConfig::Evaluation { - in_place_bit_reverse_permute::>(&mut half_domain_points[..]); - } + let half_domain_points = Coset::get_coset_points(&Coset::half_coset(domain.clone())); let mut twiddles: Vec>> = vec![half_domain_points.iter().map(|p| p.y).collect()]; if domain.log_2_size >= 2 { - twiddles.push(half_domain_points.iter().step_by(2).map(|p| p.x).collect()); + twiddles.push(half_domain_points.iter().take(half_domain_points.len() / 2 ).map(|p| p.x).collect()); for _ in 0..(domain.log_2_size - 2) { let prev = twiddles.last().unwrap(); let cur = prev .iter() - .step_by(2) + .take(prev.len() / 2 ) .map(|x| x.square().double() - FieldElement::::one()) .collect(); twiddles.push(cur); } } - twiddles.reverse(); if config == TwiddlesConfig::Interpolation { + // For the interpolation, we need to take the inverse element of each twiddle in the default order. twiddles.iter_mut().for_each(|x| { FieldElement::::inplace_batch_inverse(x).unwrap(); }); + } else { + // For the evaluation, we need the vector of twiddles but in the inverse order. + twiddles.reverse(); } twiddles }