-
Notifications
You must be signed in to change notification settings - Fork 145
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge branch 'refactor-air' of github.com:lambdaclass/lambdaworks int…
…o refactor-air
- Loading branch information
Showing
9 changed files
with
1,021 additions
and
2 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,220 @@ | ||
extern crate alloc; | ||
use crate::field::{element::FieldElement, fields::mersenne31::field::Mersenne31Field}; | ||
use alloc::vec::Vec; | ||
|
||
#[cfg(feature = "alloc")] | ||
/// fft in place algorithm used to evaluate a polynomial of degree 2^n - 1 in 2^n points. | ||
/// Input must be of size 2^n for some n. | ||
pub fn cfft( | ||
input: &mut [FieldElement<Mersenne31Field>], | ||
twiddles: Vec<Vec<FieldElement<Mersenne31Field>>>, | ||
) { | ||
// If the input size is 2^n, then log_2_size is n. | ||
let log_2_size = input.len().trailing_zeros(); | ||
|
||
// The cfft has n layers. | ||
(0..log_2_size).for_each(|i| { | ||
// In each layer i we split the current input in chunks of size 2^{i+1}. | ||
let chunk_size = 1 << (i + 1); | ||
let half_chunk_size = 1 << i; | ||
input.chunks_mut(chunk_size).for_each(|chunk| { | ||
// We split each chunk in half, calling the first half hi_part and the second hal low_part. | ||
let (hi_part, low_part) = chunk.split_at_mut(half_chunk_size); | ||
|
||
// We apply the corresponding butterfly for every element j of the high and low part. | ||
hi_part | ||
.iter_mut() | ||
.zip(low_part) | ||
.enumerate() | ||
.for_each(|(j, (hi, low))| { | ||
let temp = *low * twiddles[i as usize][j]; | ||
*low = *hi - temp; | ||
*hi += temp | ||
}); | ||
}); | ||
}); | ||
} | ||
|
||
#[cfg(feature = "alloc")] | ||
/// The inverse fft algorithm used to interpolate 2^n points. | ||
/// Input must be of size 2^n for some n. | ||
pub fn icfft( | ||
input: &mut [FieldElement<Mersenne31Field>], | ||
twiddles: Vec<Vec<FieldElement<Mersenne31Field>>>, | ||
) { | ||
// If the input size is 2^n, then log_2_size is n. | ||
let log_2_size = input.len().trailing_zeros(); | ||
|
||
// The icfft has n layers. | ||
(0..log_2_size).for_each(|i| { | ||
// In each layer i we split the current input in chunks of size 2^{n - i}. | ||
let chunk_size = 1 << (log_2_size - i); | ||
let half_chunk_size = chunk_size >> 1; | ||
input.chunks_mut(chunk_size).for_each(|chunk| { | ||
// We split each chunk in half, calling the first half hi_part and the second hal low_part. | ||
let (hi_part, low_part) = chunk.split_at_mut(half_chunk_size); | ||
|
||
// We apply the corresponding butterfly for every element j of the high and low part. | ||
hi_part | ||
.iter_mut() | ||
.zip(low_part) | ||
.enumerate() | ||
.for_each(|(j, (hi, low))| { | ||
let temp = *hi + *low; | ||
*low = (*hi - *low) * twiddles[i as usize][j]; | ||
*hi = temp; | ||
}); | ||
}); | ||
}); | ||
} | ||
|
||
/// This function permutes a slice of field elements to order the result of the cfft in the natural way. | ||
/// We call the natural order to [P(x0, y0), P(x1, y1), P(x2, y2), ...], | ||
/// where (x0, y0) is the first point of the corresponding coset. | ||
/// The cfft doesn't return the evaluations in the natural order. | ||
/// For example, if we apply the cfft to 8 coefficients of a polynomial of degree 7 we'll get the evaluations in this order: | ||
/// [P(x0, y0), P(x2, y2), P(x4, y4), P(x6, y6), P(x7, y7), P(x5, y5), P(x3, y3), P(x1, y1)], | ||
/// where the even indices are found first in ascending order and then the odd indices in descending order. | ||
/// This function permutes the slice [0, 2, 4, 6, 7, 5, 3, 1] into [0, 1, 2, 3, 4, 5, 6, 7]. | ||
/// TODO: This can be optimized by performing in-place value swapping (WIP). | ||
pub fn order_cfft_result_naive( | ||
input: &[FieldElement<Mersenne31Field>], | ||
) -> Vec<FieldElement<Mersenne31Field>> { | ||
let mut result = Vec::new(); | ||
let length = input.len(); | ||
for i in 0..length / 2 { | ||
result.push(input[i]); // We push the left index. | ||
result.push(input[length - i - 1]); // We push the right index. | ||
} | ||
result | ||
} | ||
|
||
/// This function permutes a slice of field elements to order the input of the icfft in a specific way. | ||
/// For example, if we want to interpolate 8 points we should input them in the icfft in this order: | ||
/// [(x0, y0), (x2, y2), (x4, y4), (x6, y6), (x7, y7), (x5, y5), (x3, y3), (x1, y1)], | ||
/// where the even indices are found first in ascending order and then the odd indices in descending order. | ||
/// This function permutes the slice [0, 1, 2, 3, 4, 5, 6, 7] into [0, 2, 4, 6, 7, 5, 3, 1]. | ||
/// TODO: This can be optimized by performing in-place value swapping (WIP). | ||
pub fn order_icfft_input_naive( | ||
input: &mut [FieldElement<Mersenne31Field>], | ||
) -> Vec<FieldElement<Mersenne31Field>> { | ||
let mut result = Vec::new(); | ||
|
||
// We push the even indices. | ||
(0..input.len()).step_by(2).for_each(|i| { | ||
result.push(input[i]); | ||
}); | ||
|
||
// We push the odd indices. | ||
(1..input.len()).step_by(2).rev().for_each(|i| { | ||
result.push(input[i]); | ||
}); | ||
result | ||
} | ||
|
||
#[cfg(test)] | ||
mod tests { | ||
use super::*; | ||
type FE = FieldElement<Mersenne31Field>; | ||
|
||
#[test] | ||
fn ordering_cfft_result_works_for_4_points() { | ||
let expected_slice = [FE::from(0), FE::from(1), FE::from(2), FE::from(3)]; | ||
|
||
let slice = [FE::from(0), FE::from(2), FE::from(3), FE::from(1)]; | ||
|
||
let res = order_cfft_result_naive(&slice); | ||
|
||
assert_eq!(res, expected_slice) | ||
} | ||
|
||
#[test] | ||
fn ordering_cfft_result_works_for_16_points() { | ||
let expected_slice = [ | ||
FE::from(0), | ||
FE::from(1), | ||
FE::from(2), | ||
FE::from(3), | ||
FE::from(4), | ||
FE::from(5), | ||
FE::from(6), | ||
FE::from(7), | ||
FE::from(8), | ||
FE::from(9), | ||
FE::from(10), | ||
FE::from(11), | ||
FE::from(12), | ||
FE::from(13), | ||
FE::from(14), | ||
FE::from(15), | ||
]; | ||
|
||
let slice = [ | ||
FE::from(0), | ||
FE::from(2), | ||
FE::from(4), | ||
FE::from(6), | ||
FE::from(8), | ||
FE::from(10), | ||
FE::from(12), | ||
FE::from(14), | ||
FE::from(15), | ||
FE::from(13), | ||
FE::from(11), | ||
FE::from(9), | ||
FE::from(7), | ||
FE::from(5), | ||
FE::from(3), | ||
FE::from(1), | ||
]; | ||
|
||
let res = order_cfft_result_naive(&slice); | ||
|
||
assert_eq!(res, expected_slice) | ||
} | ||
|
||
#[test] | ||
fn from_natural_to_icfft_input_order_works() { | ||
let mut slice = [ | ||
FE::from(0), | ||
FE::from(1), | ||
FE::from(2), | ||
FE::from(3), | ||
FE::from(4), | ||
FE::from(5), | ||
FE::from(6), | ||
FE::from(7), | ||
FE::from(8), | ||
FE::from(9), | ||
FE::from(10), | ||
FE::from(11), | ||
FE::from(12), | ||
FE::from(13), | ||
FE::from(14), | ||
FE::from(15), | ||
]; | ||
|
||
let expected_slice = [ | ||
FE::from(0), | ||
FE::from(2), | ||
FE::from(4), | ||
FE::from(6), | ||
FE::from(8), | ||
FE::from(10), | ||
FE::from(12), | ||
FE::from(14), | ||
FE::from(15), | ||
FE::from(13), | ||
FE::from(11), | ||
FE::from(9), | ||
FE::from(7), | ||
FE::from(5), | ||
FE::from(3), | ||
FE::from(1), | ||
]; | ||
|
||
let res = order_icfft_input_naive(&mut slice); | ||
|
||
assert_eq!(res, expected_slice) | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,95 @@ | ||
extern crate alloc; | ||
use crate::circle::point::CirclePoint; | ||
use crate::field::fields::mersenne31::field::Mersenne31Field; | ||
use alloc::vec::Vec; | ||
|
||
/// Given g_n, a generator of the subgroup of size n of the circle, i.e. <g_n>, | ||
/// and given a shift, that is a another point of the circle, | ||
/// we define the coset shift + <g_n> which is the set of all the points in | ||
/// <g_n> plus the shift. | ||
/// For example, if <g_4> = {p1, p2, p3, p4}, then g_8 + <g_4> = {g_8 + p1, g_8 + p2, g_8 + p3, g_8 + p4}. | ||
#[derive(Debug, Clone)] | ||
pub struct Coset { | ||
// Coset: shift + <g_n> where n = 2^{log_2_size}. | ||
// Example: g_16 + <g_8>, n = 8, log_2_size = 3, shift = g_16. | ||
pub log_2_size: u32, //TODO: Change log_2_size to u8 because log_2_size < 31. | ||
pub shift: CirclePoint<Mersenne31Field>, | ||
} | ||
|
||
impl Coset { | ||
pub fn new(log_2_size: u32, shift: CirclePoint<Mersenne31Field>) -> Self { | ||
Coset { log_2_size, shift } | ||
} | ||
|
||
/// Returns the coset g_2n + <g_n> | ||
pub fn new_standard(log_2_size: u32) -> Self { | ||
// shift is a generator of the subgroup of order 2n = 2^{log_2_size + 1}. | ||
let shift = CirclePoint::get_generator_of_subgroup(log_2_size + 1); | ||
Coset { log_2_size, shift } | ||
} | ||
|
||
/// Returns g_n, the generator of the subgroup of order n = 2^log_2_size. | ||
pub fn get_generator(&self) -> CirclePoint<Mersenne31Field> { | ||
CirclePoint::GENERATOR.repeated_double(31 - self.log_2_size) | ||
} | ||
|
||
/// Given a standard coset g_2n + <g_n>, returns the subcoset with half size g_2n + <g_{n/2}> | ||
pub fn half_coset(coset: Self) -> Self { | ||
Coset { | ||
log_2_size: coset.log_2_size - 1, | ||
shift: coset.shift, | ||
} | ||
} | ||
|
||
/// Given a coset shift + G returns the coset -shift + G. | ||
/// Note that (g_2n + <g_{n/2}>) U (-g_2n + <g_{n/2}>) = g_2n + <g_n>. | ||
pub fn conjugate(coset: Self) -> Self { | ||
Coset { | ||
log_2_size: coset.log_2_size, | ||
shift: coset.shift.conjugate(), | ||
} | ||
} | ||
|
||
/// Returns the vector of shift + g for every g in <g_n>. | ||
/// where g = i * g_n for i = 0, ..., n-1. | ||
#[cfg(feature = "alloc")] | ||
pub fn get_coset_points(coset: &Self) -> Vec<CirclePoint<Mersenne31Field>> { | ||
// g_n the generator of the subgroup of order n. | ||
let generator_n = CirclePoint::get_generator_of_subgroup(coset.log_2_size); | ||
let size: u8 = 1 << coset.log_2_size; | ||
core::iter::successors(Some(coset.shift.clone()), move |prev| { | ||
Some(prev + &generator_n) | ||
}) | ||
.take(size.into()) | ||
.collect() | ||
} | ||
} | ||
|
||
#[cfg(test)] | ||
mod tests { | ||
use super::*; | ||
|
||
#[test] | ||
fn coset_points_vector_has_right_size() { | ||
let coset = Coset::new_standard(3); | ||
let points = Coset::get_coset_points(&coset); | ||
assert_eq!(1 << coset.log_2_size, points.len()) | ||
} | ||
|
||
#[test] | ||
fn antipode_of_coset_point_is_in_coset() { | ||
let coset = Coset::new_standard(3); | ||
let points = Coset::get_coset_points(&coset); | ||
let point = points[2].clone(); | ||
let anitpode_point = points[6].clone(); | ||
assert_eq!(anitpode_point, point.antipode()) | ||
} | ||
|
||
#[test] | ||
fn coset_generator_has_right_order() { | ||
let coset = Coset::new(2, CirclePoint::GENERATOR * 3); | ||
let generator_n = coset.get_generator(); | ||
assert_eq!(generator_n.repeated_double(2), CirclePoint::zero()); | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,4 @@ | ||
#[derive(Debug)] | ||
pub enum CircleError { | ||
PointDoesntSatisfyCircleEquation, | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,6 @@ | ||
pub mod cfft; | ||
pub mod cosets; | ||
pub mod errors; | ||
pub mod point; | ||
pub mod polynomial; | ||
pub mod twiddles; |
Oops, something went wrong.