diff --git a/math/src/circle/cosets.rs b/math/src/circle/cosets.rs index 171350dca..9b3cbf346 100644 --- a/math/src/circle/cosets.rs +++ b/math/src/circle/cosets.rs @@ -53,7 +53,7 @@ impl Coset { let generator_n = CirclePoint::get_generator_of_subgroup(coset.log_2_size); let size: u8 = 1 << coset.log_2_size; core::iter::successors(Some(coset.shift.clone()), move |prev| { - Some(prev.clone() + generator_n.clone()) + Some(prev + &generator_n) }) .take(size.into()) .collect() @@ -82,7 +82,7 @@ mod tests { #[test] fn coset_generator_has_right_order() { - let coset = Coset::new(2, CirclePoint::GENERATOR.scalar_mul(3)); + let coset = Coset::new(2, CirclePoint::GENERATOR * 3); let generator_n = coset.get_generator(); assert_eq!(generator_n.repeated_double(2), CirclePoint::zero()); } diff --git a/math/src/circle/point.rs b/math/src/circle/point.rs index 0f1db689e..ac8824d8d 100644 --- a/math/src/circle/point.rs +++ b/math/src/circle/point.rs @@ -4,7 +4,7 @@ use crate::field::{ element::FieldElement, fields::mersenne31::{extensions::Degree4ExtensionField, field::Mersenne31Field}, }; -use core::ops::Add; +use core::ops::{Add, Mul}; #[derive(Debug, Clone)] pub struct CirclePoint { @@ -48,30 +48,31 @@ impl HasCircleParams for Degree4ExtensionField { const ORDER: u128 = 21267647892944572736998860269687930880; } -impl> CirclePoint { - pub fn new(x: FieldElement, y: FieldElement) -> Result { - if x.square() + y.square() == FieldElement::one() { - Ok(Self { x, y }) - } else { - Err(CircleError::InvalidValue) - } +/// Equality between two cricle points. +impl> PartialEq for CirclePoint { + fn eq(&self, other: &Self) -> bool { + self.x == other.x && self.y == other.y } +} - /// Neutral element of the Circle group (with additive notation). - pub fn zero() -> Self { - Self::new(FieldElement::one(), FieldElement::zero()).unwrap() - } +/// Addition (i.e. group operation with additive notation) between two points: +/// (a, b) + (c, d) = (a * c - b * d, a * d + b * c) +impl> Add for &CirclePoint { + type Output = CirclePoint; - /// Computes (a0, a1) + (b0, b1) = (a0 * b0 - a1 * b1, a0 * b1 + a1 * b0) - #[allow(clippy::should_implement_trait)] - pub fn add(a: Self, b: Self) -> Self { - let x = &a.x * &b.x - &a.y * &b.y; - let y = a.x * b.y + a.y * b.x; + fn add(self, other: Self) -> Self::Output { + let x = &self.x * &other.x - &self.y * &other.y; + let y = &self.x * &other.y + &self.y * &other.x; CirclePoint { x, y } } +} - /// Computes n * (x, y) = (x ,y) + ... + (x, y) n-times. - pub fn scalar_mul(self, mut scalar: u128) -> Self { +/// Multiplication between a point and a scalar (i.e. group operation repeatedly): +/// (x, y) * n = (x ,y) + ... + (x, y) n-times. +impl> Mul for CirclePoint { + type Output = CirclePoint; + + fn mul(self, mut scalar: u128) -> Self { let mut res = Self::zero(); let mut cur = self; loop { @@ -79,12 +80,27 @@ impl> CirclePoint { return res; } if scalar & 1 == 1 { - res = res + cur.clone(); + res = &res + &cur; } cur = cur.double(); scalar >>= 1; } } +} + +impl> CirclePoint { + pub fn new(x: FieldElement, y: FieldElement) -> Result { + if x.square() + y.square() == FieldElement::one() { + Ok(Self { x, y }) + } else { + Err(CircleError::InvalidValue) + } + } + + /// Neutral element of the Circle group (with additive notation). + pub fn zero() -> Self { + Self::new(FieldElement::one(), FieldElement::zero()).unwrap() + } /// Computes 2(x, y) = (2x^2 - 1, 2xy). pub fn double(self) -> Self { @@ -135,19 +151,6 @@ impl> CirclePoint { pub const ORDER: u128 = F::ORDER; } -impl> PartialEq for CirclePoint { - fn eq(&self, other: &Self) -> bool { - self.x == other.x && self.y == other.y - } -} - -impl> Add for CirclePoint { - type Output = CirclePoint; - fn add(self, other: Self) -> Self { - CirclePoint::add(self, other) - } -} - #[cfg(test)] mod tests { use super::*; @@ -195,26 +198,26 @@ mod tests { fn zero_plus_zero_is_zero() { let a = G::zero(); let b = G::zero(); - assert_eq!(a + b, G::zero()) + assert_eq!(&a + &b, G::zero()) } #[test] fn generator_plus_zero_is_generator() { let g = G::GENERATOR; let zero = G::zero(); - assert_eq!(g.clone() + zero, g) + assert_eq!(&g + &zero, g) } #[test] fn double_equals_mul_two() { let g = G::GENERATOR; - assert_eq!(g.clone().double(), G::scalar_mul(g, 2)) + assert_eq!(g.clone().double(), g * 2) } #[test] fn mul_eight_equals_double_three_times() { let g = G::GENERATOR; - assert_eq!(g.clone().repeated_double(3), G::scalar_mul(g, 8)) + assert_eq!(g.clone().repeated_double(3), g * 8) } #[test] @@ -227,13 +230,13 @@ mod tests { #[test] fn generator_g4_has_the_order_of_the_group() { let g = G4::GENERATOR; - assert_eq!(g.scalar_mul(G4::ORDER), G4::zero()) + assert_eq!(g * G4::ORDER, G4::zero()) } #[test] fn conjugation_is_inverse_operation() { let g = G::GENERATOR; - assert_eq!(g.clone() + g.conjugate(), G::zero()) + assert_eq!(&g.clone() + &g.conjugate(), G::zero()) } #[test]