Skip to content

Commit

Permalink
Math: Implement subfield logic (#709)
Browse files Browse the repository at this point in the history
* add subfield trait

* implement IsSubfieldOf for Degree2ExtensionField struct in BLS12381

* fix mul

* avoid using field elements in issubfield impl. Add tests

* clippy and fmt

* simplify trait bounds

* add explicit type

* change iter to into_iter

* fix metal code

* remove explicit type
  • Loading branch information
schouhy authored Dec 6, 2023
1 parent b9b3118 commit 7a84ab5
Show file tree
Hide file tree
Showing 10 changed files with 258 additions and 110 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ use crate::field::{
quadratic::{HasQuadraticNonResidue, QuadraticExtensionField},
},
fields::montgomery_backed_prime_fields::{IsModulus, MontgomeryBackendPrimeField},
traits::IsField,
traits::{IsField, IsSubFieldOf},
};
use crate::traits::ByteConversion;
use crate::unsigned_integer::element::U384;
Expand Down Expand Up @@ -71,7 +71,7 @@ impl IsField for Degree2ExtensionField {

/// Returns the division of `a` and `b`
fn div(a: &Self::BaseType, b: &Self::BaseType) -> Self::BaseType {
Self::mul(a, &Self::inv(b).unwrap())
<Self as IsField>::mul(a, &Self::inv(b).unwrap())
}

/// Returns a boolean indicating whether `a` and `b` are equal component wise.
Expand Down Expand Up @@ -103,6 +103,47 @@ impl IsField for Degree2ExtensionField {
}
}

impl IsSubFieldOf<Degree2ExtensionField> for BLS12381PrimeField {
fn mul(
a: &Self::BaseType,
b: &<Degree2ExtensionField as IsField>::BaseType,
) -> <Degree2ExtensionField as IsField>::BaseType {
let c0 = FieldElement::from_raw(<Self as IsField>::mul(a, b[0].value()));
let c1 = FieldElement::from_raw(<Self as IsField>::mul(a, b[1].value()));
[c0, c1]
}

fn add(
a: &Self::BaseType,
b: &<Degree2ExtensionField as IsField>::BaseType,
) -> <Degree2ExtensionField as IsField>::BaseType {
let c0 = FieldElement::from_raw(<Self as IsField>::add(a, b[0].value()));
let c1 = FieldElement::from_raw(*b[1].value());
[c0, c1]
}

fn div(
a: &Self::BaseType,
b: &<Degree2ExtensionField as IsField>::BaseType,
) -> <Degree2ExtensionField as IsField>::BaseType {
let b_inv = Degree2ExtensionField::inv(b).unwrap();
<Self as IsSubFieldOf<Degree2ExtensionField>>::mul(a, &b_inv)
}

fn sub(
a: &Self::BaseType,
b: &<Degree2ExtensionField as IsField>::BaseType,
) -> <Degree2ExtensionField as IsField>::BaseType {
let c0 = FieldElement::from_raw(<Self as IsField>::sub(a, b[0].value()));
let c1 = FieldElement::from_raw(<Self as IsField>::neg(b[1].value()));
[c0, c1]
}

fn embed(a: Self::BaseType) -> <Degree2ExtensionField as IsField>::BaseType {
[FieldElement::from_raw(a), FieldElement::zero()]
}
}

impl ByteConversion for FieldElement<Degree2ExtensionField> {
#[cfg(feature = "std")]
fn to_bytes_be(&self) -> Vec<u8> {
Expand Down Expand Up @@ -328,4 +369,43 @@ mod tests {
assert_eq!(g_to_fp12_x, expectedx);
assert_eq!(g_to_fp12_y, expectedy);
}

#[test]
fn add_base_field_with_degree_2_extension() {
let a = FieldElement::<BLS12381PrimeField>::from(3);
let a_extension = FieldElement::<Degree2ExtensionField>::from(3);
let b = FieldElement::<Degree2ExtensionField>::from(2);
assert_eq!(a + &b, a_extension + b);
}

#[test]
fn mul_base_field_with_degree_2_extension() {
let a = FieldElement::<BLS12381PrimeField>::from(3);
let a_extension = FieldElement::<Degree2ExtensionField>::from(3);
let b = FieldElement::<Degree2ExtensionField>::from(2);
assert_eq!(a * &b, a_extension * b);
}

#[test]
fn sub_base_field_with_degree_2_extension() {
let a = FieldElement::<BLS12381PrimeField>::from(3);
let a_extension = FieldElement::<Degree2ExtensionField>::from(3);
let b = FieldElement::<Degree2ExtensionField>::from(2);
assert_eq!(a - &b, a_extension - b);
}

#[test]
fn div_base_field_with_degree_2_extension() {
let a = FieldElement::<BLS12381PrimeField>::from(3);
let a_extension = FieldElement::<Degree2ExtensionField>::from(3);
let b = FieldElement::<Degree2ExtensionField>::from(2);
assert_eq!(a / &b, a_extension / b);
}

#[test]
fn embed_base_field_with_degree_2_extension() {
let a = FieldElement::<BLS12381PrimeField>::from(3);
let a_extension = FieldElement::<Degree2ExtensionField>::from(3);
assert_eq!(a.to_extension::<Degree2ExtensionField>(), a_extension);
}
}
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use super::{
curve::{BLS12381Curve, MILLER_LOOP_CONSTANT},
field_extension::{Degree12ExtensionField, Degree2ExtensionField},
field_extension::{BLS12381PrimeField, Degree12ExtensionField, Degree2ExtensionField},
twist::BLS12381TwistCurve,
};
use crate::{
Expand Down Expand Up @@ -56,22 +56,23 @@ fn double_accumulate_line(
let [px, py, _] = p.coordinates();
let residue = LevelTwoResidue::residue();
let two_inv = FieldElement::<Degree2ExtensionField>::new_base("d0088f51cbff34d258dd3db21a5d66bb23ba5c279c2895fb39869507b587b120f55ffff58a9ffffdcff7fffffffd556");
let three = FieldElement::<BLS12381PrimeField>::from(3);

let a = &two_inv * x1 * y1;
let b = y1.square();
let c = z1.square();
let d = FieldElement::from(3) * &c;
let d = &three * &c;
let e = BLS12381TwistCurve::b() * d;
let f = FieldElement::from(3) * &e;
let f = &three * &e;
let g = two_inv * (&b + &f);
let h = (y1 + z1).square() - (&b + &c);

let x3 = &a * (&b - &f);
let y3 = g.square() - (FieldElement::from(3) * e.square());
let y3 = g.square() - (&three * e.square());
let z3 = &b * &h;

let [h0, h1] = h.value();
let x1_sq_3 = FieldElement::from(3) * x1.square();
let x1_sq_3 = three * x1.square();
let [x1_sq_30, x1_sq_31] = x1_sq_3.value();

t.0.value = [x3, y3, z3];
Expand Down Expand Up @@ -120,7 +121,7 @@ fn add_accumulate_line(
let e = &lambda * &d;
let f = z1 * c;
let g = x1 * d;
let h = &e + f - FieldElement::from(2) * &g;
let h = &e + f - FieldElement::<BLS12381PrimeField>::from(2) * &g;
let i = y1 * &e;

let x3 = &lambda * &h;
Expand Down Expand Up @@ -195,7 +196,7 @@ fn frobenius_square(
let f0 = FieldElement::new([a0.clone(), a1 * &omega_3, a2 * &omega_3_squared]);
let f1 = FieldElement::new([b0.clone(), b1 * omega_3, b2 * omega_3_squared]);

FieldElement::new([f0, f1 * w_raised_to_p_squared_minus_one])
FieldElement::new([f0, w_raised_to_p_squared_minus_one * f1])
}

// To understand more about how to reduce the final exponentiation
Expand Down
6 changes: 3 additions & 3 deletions math/src/fft/gpu/metal/ops.rs
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ pub fn fft<F: IsFFTField>(

let result = MetalState::retrieve_contents(&input_buffer);
let result = bitrev_permutation::<F, _>(&result, state)?;
Ok(result.iter().map(FieldElement::from_raw).collect())
Ok(result.into_iter().map(FieldElement::from_raw).collect())
}

/// Generates 2^{`order-1`} twiddle factors in parallel, with a certain `config`, in Metal.
Expand Down Expand Up @@ -89,7 +89,7 @@ pub fn gen_twiddles<F: IsFFTField>(
let (command_buffer, command_encoder) =
state.setup_command(&pipeline, Some(&[(0, &result_buffer)]));

let root = F::get_primitive_root_of_unity::<F>(order).unwrap();
let root = F::get_primitive_root_of_unity(order).unwrap();
command_encoder.set_bytes(1, mem::size_of::<F::BaseType>() as u64, void_ptr(&root));

let grid_size = MTLSize::new(len as u64, 1, 1);
Expand All @@ -103,7 +103,7 @@ pub fn gen_twiddles<F: IsFFTField>(
});

let result = MetalState::retrieve_contents(&result_buffer);
Ok(result.iter().map(FieldElement::from_raw).collect())
Ok(result.into_iter().map(FieldElement::from_raw).collect())
}

/// Executes a parallel bit-reverse permutation with the elements of `input`, in Metal.
Expand Down
Loading

0 comments on commit 7a84ab5

Please sign in to comment.