Skip to content

Commit

Permalink
Merge branch 'add_rescue_prime' of https://github.com/lambdaclass/lam…
Browse files Browse the repository at this point in the history
…bdaworks into add_rescue_prime
  • Loading branch information
jotabulacios committed Nov 4, 2024
2 parents c385cbb + 1c2db51 commit ea8267d
Show file tree
Hide file tree
Showing 9 changed files with 1,021 additions and 2 deletions.
220 changes: 220 additions & 0 deletions math/src/circle/cfft.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,220 @@
extern crate alloc;
use crate::field::{element::FieldElement, fields::mersenne31::field::Mersenne31Field};
use alloc::vec::Vec;

#[cfg(feature = "alloc")]
/// fft in place algorithm used to evaluate a polynomial of degree 2^n - 1 in 2^n points.
/// Input must be of size 2^n for some n.
pub fn cfft(
input: &mut [FieldElement<Mersenne31Field>],
twiddles: Vec<Vec<FieldElement<Mersenne31Field>>>,
) {
// If the input size is 2^n, then log_2_size is n.
let log_2_size = input.len().trailing_zeros();

// The cfft has n layers.
(0..log_2_size).for_each(|i| {
// In each layer i we split the current input in chunks of size 2^{i+1}.
let chunk_size = 1 << (i + 1);
let half_chunk_size = 1 << i;
input.chunks_mut(chunk_size).for_each(|chunk| {
// We split each chunk in half, calling the first half hi_part and the second hal low_part.
let (hi_part, low_part) = chunk.split_at_mut(half_chunk_size);

// We apply the corresponding butterfly for every element j of the high and low part.
hi_part
.iter_mut()
.zip(low_part)
.enumerate()
.for_each(|(j, (hi, low))| {
let temp = *low * twiddles[i as usize][j];
*low = *hi - temp;
*hi += temp
});
});
});
}

#[cfg(feature = "alloc")]
/// The inverse fft algorithm used to interpolate 2^n points.
/// Input must be of size 2^n for some n.
pub fn icfft(
input: &mut [FieldElement<Mersenne31Field>],
twiddles: Vec<Vec<FieldElement<Mersenne31Field>>>,
) {
// If the input size is 2^n, then log_2_size is n.
let log_2_size = input.len().trailing_zeros();

// The icfft has n layers.
(0..log_2_size).for_each(|i| {
// In each layer i we split the current input in chunks of size 2^{n - i}.
let chunk_size = 1 << (log_2_size - i);
let half_chunk_size = chunk_size >> 1;
input.chunks_mut(chunk_size).for_each(|chunk| {
// We split each chunk in half, calling the first half hi_part and the second hal low_part.
let (hi_part, low_part) = chunk.split_at_mut(half_chunk_size);

// We apply the corresponding butterfly for every element j of the high and low part.
hi_part
.iter_mut()
.zip(low_part)
.enumerate()
.for_each(|(j, (hi, low))| {
let temp = *hi + *low;
*low = (*hi - *low) * twiddles[i as usize][j];
*hi = temp;
});
});
});
}

/// This function permutes a slice of field elements to order the result of the cfft in the natural way.
/// We call the natural order to [P(x0, y0), P(x1, y1), P(x2, y2), ...],
/// where (x0, y0) is the first point of the corresponding coset.
/// The cfft doesn't return the evaluations in the natural order.
/// For example, if we apply the cfft to 8 coefficients of a polynomial of degree 7 we'll get the evaluations in this order:
/// [P(x0, y0), P(x2, y2), P(x4, y4), P(x6, y6), P(x7, y7), P(x5, y5), P(x3, y3), P(x1, y1)],
/// where the even indices are found first in ascending order and then the odd indices in descending order.
/// This function permutes the slice [0, 2, 4, 6, 7, 5, 3, 1] into [0, 1, 2, 3, 4, 5, 6, 7].
/// TODO: This can be optimized by performing in-place value swapping (WIP).
pub fn order_cfft_result_naive(
input: &[FieldElement<Mersenne31Field>],
) -> Vec<FieldElement<Mersenne31Field>> {
let mut result = Vec::new();
let length = input.len();
for i in 0..length / 2 {
result.push(input[i]); // We push the left index.
result.push(input[length - i - 1]); // We push the right index.
}
result
}

/// This function permutes a slice of field elements to order the input of the icfft in a specific way.
/// For example, if we want to interpolate 8 points we should input them in the icfft in this order:
/// [(x0, y0), (x2, y2), (x4, y4), (x6, y6), (x7, y7), (x5, y5), (x3, y3), (x1, y1)],
/// where the even indices are found first in ascending order and then the odd indices in descending order.
/// This function permutes the slice [0, 1, 2, 3, 4, 5, 6, 7] into [0, 2, 4, 6, 7, 5, 3, 1].
/// TODO: This can be optimized by performing in-place value swapping (WIP).
pub fn order_icfft_input_naive(
input: &mut [FieldElement<Mersenne31Field>],
) -> Vec<FieldElement<Mersenne31Field>> {
let mut result = Vec::new();

// We push the even indices.
(0..input.len()).step_by(2).for_each(|i| {
result.push(input[i]);
});

// We push the odd indices.
(1..input.len()).step_by(2).rev().for_each(|i| {
result.push(input[i]);
});
result
}

#[cfg(test)]
mod tests {
use super::*;
type FE = FieldElement<Mersenne31Field>;

#[test]
fn ordering_cfft_result_works_for_4_points() {
let expected_slice = [FE::from(0), FE::from(1), FE::from(2), FE::from(3)];

let slice = [FE::from(0), FE::from(2), FE::from(3), FE::from(1)];

let res = order_cfft_result_naive(&slice);

assert_eq!(res, expected_slice)
}

#[test]
fn ordering_cfft_result_works_for_16_points() {
let expected_slice = [
FE::from(0),
FE::from(1),
FE::from(2),
FE::from(3),
FE::from(4),
FE::from(5),
FE::from(6),
FE::from(7),
FE::from(8),
FE::from(9),
FE::from(10),
FE::from(11),
FE::from(12),
FE::from(13),
FE::from(14),
FE::from(15),
];

let slice = [
FE::from(0),
FE::from(2),
FE::from(4),
FE::from(6),
FE::from(8),
FE::from(10),
FE::from(12),
FE::from(14),
FE::from(15),
FE::from(13),
FE::from(11),
FE::from(9),
FE::from(7),
FE::from(5),
FE::from(3),
FE::from(1),
];

let res = order_cfft_result_naive(&slice);

assert_eq!(res, expected_slice)
}

#[test]
fn from_natural_to_icfft_input_order_works() {
let mut slice = [
FE::from(0),
FE::from(1),
FE::from(2),
FE::from(3),
FE::from(4),
FE::from(5),
FE::from(6),
FE::from(7),
FE::from(8),
FE::from(9),
FE::from(10),
FE::from(11),
FE::from(12),
FE::from(13),
FE::from(14),
FE::from(15),
];

let expected_slice = [
FE::from(0),
FE::from(2),
FE::from(4),
FE::from(6),
FE::from(8),
FE::from(10),
FE::from(12),
FE::from(14),
FE::from(15),
FE::from(13),
FE::from(11),
FE::from(9),
FE::from(7),
FE::from(5),
FE::from(3),
FE::from(1),
];

let res = order_icfft_input_naive(&mut slice);

assert_eq!(res, expected_slice)
}
}
95 changes: 95 additions & 0 deletions math/src/circle/cosets.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
extern crate alloc;
use crate::circle::point::CirclePoint;
use crate::field::fields::mersenne31::field::Mersenne31Field;
use alloc::vec::Vec;

/// Given g_n, a generator of the subgroup of size n of the circle, i.e. <g_n>,
/// and given a shift, that is a another point of the circle,
/// we define the coset shift + <g_n> which is the set of all the points in
/// <g_n> plus the shift.
/// For example, if <g_4> = {p1, p2, p3, p4}, then g_8 + <g_4> = {g_8 + p1, g_8 + p2, g_8 + p3, g_8 + p4}.
#[derive(Debug, Clone)]
pub struct Coset {
// Coset: shift + <g_n> where n = 2^{log_2_size}.
// Example: g_16 + <g_8>, n = 8, log_2_size = 3, shift = g_16.
pub log_2_size: u32, //TODO: Change log_2_size to u8 because log_2_size < 31.
pub shift: CirclePoint<Mersenne31Field>,
}

impl Coset {
pub fn new(log_2_size: u32, shift: CirclePoint<Mersenne31Field>) -> Self {
Coset { log_2_size, shift }
}

/// Returns the coset g_2n + <g_n>
pub fn new_standard(log_2_size: u32) -> Self {
// shift is a generator of the subgroup of order 2n = 2^{log_2_size + 1}.
let shift = CirclePoint::get_generator_of_subgroup(log_2_size + 1);
Coset { log_2_size, shift }
}

/// Returns g_n, the generator of the subgroup of order n = 2^log_2_size.
pub fn get_generator(&self) -> CirclePoint<Mersenne31Field> {
CirclePoint::GENERATOR.repeated_double(31 - self.log_2_size)
}

/// Given a standard coset g_2n + <g_n>, returns the subcoset with half size g_2n + <g_{n/2}>
pub fn half_coset(coset: Self) -> Self {
Coset {
log_2_size: coset.log_2_size - 1,
shift: coset.shift,
}
}

/// Given a coset shift + G returns the coset -shift + G.
/// Note that (g_2n + <g_{n/2}>) U (-g_2n + <g_{n/2}>) = g_2n + <g_n>.
pub fn conjugate(coset: Self) -> Self {
Coset {
log_2_size: coset.log_2_size,
shift: coset.shift.conjugate(),
}
}

/// Returns the vector of shift + g for every g in <g_n>.
/// where g = i * g_n for i = 0, ..., n-1.
#[cfg(feature = "alloc")]
pub fn get_coset_points(coset: &Self) -> Vec<CirclePoint<Mersenne31Field>> {
// g_n the generator of the subgroup of order n.
let generator_n = CirclePoint::get_generator_of_subgroup(coset.log_2_size);
let size: u8 = 1 << coset.log_2_size;
core::iter::successors(Some(coset.shift.clone()), move |prev| {
Some(prev + &generator_n)
})
.take(size.into())
.collect()
}
}

#[cfg(test)]
mod tests {
use super::*;

#[test]
fn coset_points_vector_has_right_size() {
let coset = Coset::new_standard(3);
let points = Coset::get_coset_points(&coset);
assert_eq!(1 << coset.log_2_size, points.len())
}

#[test]
fn antipode_of_coset_point_is_in_coset() {
let coset = Coset::new_standard(3);
let points = Coset::get_coset_points(&coset);
let point = points[2].clone();
let anitpode_point = points[6].clone();
assert_eq!(anitpode_point, point.antipode())
}

#[test]
fn coset_generator_has_right_order() {
let coset = Coset::new(2, CirclePoint::GENERATOR * 3);
let generator_n = coset.get_generator();
assert_eq!(generator_n.repeated_double(2), CirclePoint::zero());
}
}
4 changes: 4 additions & 0 deletions math/src/circle/errors.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
#[derive(Debug)]
pub enum CircleError {
PointDoesntSatisfyCircleEquation,
}
6 changes: 6 additions & 0 deletions math/src/circle/mod.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
pub mod cfft;
pub mod cosets;
pub mod errors;
pub mod point;
pub mod polynomial;
pub mod twiddles;
Loading

0 comments on commit ea8267d

Please sign in to comment.