diff --git a/bindings/rust/src/lib.rs b/bindings/rust/src/lib.rs index 6f68ccdc..49dac393 100644 --- a/bindings/rust/src/lib.rs +++ b/bindings/rust/src/lib.rs @@ -1971,6 +1971,13 @@ pub mod min_sig { ); } +pub trait MultiPoint { + type Output; + + fn mult(&self, scalars: &[u8], nbits: usize) -> Self::Output; + fn add(&self) -> Self::Output; +} + #[cfg(feature = "std")] include!("pippenger.rs"); diff --git a/bindings/rust/src/pippenger-no_std.rs b/bindings/rust/src/pippenger-no_std.rs index c316e87a..24249708 100644 --- a/bindings/rust/src/pippenger-no_std.rs +++ b/bindings/rust/src/pippenger-no_std.rs @@ -59,16 +59,29 @@ macro_rules! pippenger_mult_impl { ret } + #[inline] pub fn mult(&self, scalars: &[u8], nbits: usize) -> $point { - let npoints = self.points.len(); + self.as_slice().mult(scalars, nbits) + } + + #[inline] + pub fn add(&self) -> $point { + self.as_slice().add() + } + } + + impl MultiPoint for [$point_affine] { + type Output = $point; + + fn mult(&self, scalars: &[u8], nbits: usize) -> $point { + let npoints = self.len(); let nbytes = (nbits + 7) / 8; if scalars.len() < nbytes * npoints { panic!("scalars length mismatch"); } - let p: [*const $point_affine; 2] = - [&self.points[0], ptr::null()]; + let p: [*const $point_affine; 2] = [&self[0], ptr::null()]; let s: [*const u8; 2] = [&scalars[0], ptr::null()]; let mut ret = <$point>::default(); @@ -89,10 +102,10 @@ macro_rules! pippenger_mult_impl { ret } - pub fn add(&self) -> $point { - let npoints = self.points.len(); + fn add(&self) -> $point { + let npoints = self.len(); - let p: [*const _; 2] = [&self.points[0], ptr::null()]; + let p: [*const _; 2] = [&self[0], ptr::null()]; let mut ret = <$point>::default(); unsafe { $add(&mut ret, &p[0], npoints) }; @@ -125,7 +138,7 @@ pippenger_mult_impl!( blst_p1s_tile_pippenger, blst_p1_add_or_double, blst_p1_double, - p1_multi_scalar, + p1_multi_point, blst_p1_generator, blst_p1_mult, blst_p1s_add, @@ -141,7 +154,7 @@ pippenger_mult_impl!( blst_p2s_tile_pippenger, blst_p2_add_or_double, blst_p2_double, - p2_multi_scalar, + p2_multi_point, blst_p2_generator, blst_p2_mult, blst_p2s_add, diff --git a/bindings/rust/src/pippenger.rs b/bindings/rust/src/pippenger.rs index bdaec95a..2932e6a6 100644 --- a/bindings/rust/src/pippenger.rs +++ b/bindings/rust/src/pippenger.rs @@ -113,8 +113,22 @@ macro_rules! pippenger_mult_impl { ret } + #[inline] pub fn mult(&self, scalars: &[u8], nbits: usize) -> $point { - let npoints = self.points.len(); + self.as_slice().mult(scalars, nbits) + } + + #[inline] + pub fn add(&self) -> $point { + self.as_slice().add() + } + } + + impl MultiPoint for [$point_affine] { + type Output = $point; + + fn mult(&self, scalars: &[u8], nbits: usize) -> $point { + let npoints = self.len(); let nbytes = (nbits + 7) / 8; if scalars.len() < nbytes * npoints { @@ -124,8 +138,7 @@ macro_rules! pippenger_mult_impl { let pool = mt::da_pool(); let ncpus = pool.max_count(); if ncpus < 2 || npoints < 32 { - let p: [*const $point_affine; 2] = - [&self.points[0], ptr::null()]; + let p: [*const $point_affine; 2] = [&self[0], ptr::null()]; let s: [*const u8; 2] = [&scalars[0], ptr::null()]; unsafe { @@ -178,7 +191,7 @@ macro_rules! pippenger_mult_impl { } let grid = &grid[..]; - let points = &self.points[..]; + let points = &self[..]; let sz = unsafe { $scratch_sizeof(0) / 8 }; let mut row_sync: Vec = Vec::with_capacity(ny); @@ -262,13 +275,13 @@ macro_rules! pippenger_mult_impl { ret } - pub fn add(&self) -> $point { - let npoints = self.points.len(); + fn add(&self) -> $point { + let npoints = self.len(); let pool = mt::da_pool(); let ncpus = pool.max_count(); if ncpus < 2 || npoints < 384 { - let p: [*const _; 2] = [&self.points[0], ptr::null()]; + let p: [*const _; 2] = [&self[0], ptr::null()]; let mut ret = <$point>::default(); unsafe { $add(&mut ret, &p[0], npoints) }; return ret; @@ -295,7 +308,7 @@ macro_rules! pippenger_mult_impl { if work >= npoints { break; } - p[0] = &self.points[work]; + p[0] = &self[work]; if work + chunk > npoints { chunk = npoints - work; } @@ -345,7 +358,7 @@ pippenger_mult_impl!( blst_p1s_tile_pippenger, blst_p1_add_or_double, blst_p1_double, - p1_multi_scalar, + p1_multi_point, blst_p1_generator, blst_p1_mult, blst_p1s_add, @@ -361,7 +374,7 @@ pippenger_mult_impl!( blst_p2s_tile_pippenger, blst_p2_add_or_double, blst_p2_double, - p2_multi_scalar, + p2_multi_point, blst_p2_generator, blst_p2_mult, blst_p2s_add,