diff --git a/co-noir/co-brillig/Cargo.toml b/co-noir/co-brillig/Cargo.toml index 8010b0065..a4c4f1b7f 100644 --- a/co-noir/co-brillig/Cargo.toml +++ b/co-noir/co-brillig/Cargo.toml @@ -23,6 +23,7 @@ noirc-abi.workspace = true noirc-artifacts.workspace = true num-bigint.workspace = true num-traits.workspace = true +rand.workspace = true rayon.workspace = true serde.workspace = true thiserror.workspace = true diff --git a/co-noir/co-brillig/src/mpc/rep3.rs b/co-noir/co-brillig/src/mpc/rep3.rs index 5b88a3a94..3b75e9881 100644 --- a/co-noir/co-brillig/src/mpc/rep3.rs +++ b/co-noir/co-brillig/src/mpc/rep3.rs @@ -10,6 +10,8 @@ use mpc_core::protocols::rep3_ring::ring::int_ring::IntRing2k; use mpc_core::protocols::rep3_ring::ring::ring_impl::RingElement; use mpc_core::protocols::rep3_ring::{self, Rep3BitShare, Rep3RingShare}; use num_bigint::BigUint; +use num_traits::AsPrimitive; +use rand::distributions::{Distribution, Standard}; use std::marker::PhantomData; use super::PlainBrilligType as Public; @@ -145,6 +147,43 @@ macro_rules! bit_from_u128 { }}; } +fn cast_ring( + share: Rep3RingShare, + integer_bit_size: IntegerBitSize, + io_context: &mut IoContext, +) -> eyre::Result> +where + Standard: Distribution, + T: IntRing2k + + AsPrimitive + + AsPrimitive + + AsPrimitive + + AsPrimitive + + AsPrimitive + + AsPrimitive, +{ + match integer_bit_size { + IntegerBitSize::U1 => Ok(Rep3BrilligType::Shared(Shared::Ring1( + rep3_ring::casts::ring_cast_selector::<_, Bit, _>(share, io_context)?, + ))), + IntegerBitSize::U8 => Ok(Rep3BrilligType::Shared(Shared::Ring8( + rep3_ring::casts::ring_cast_selector::<_, u8, _>(share, io_context)?, + ))), + IntegerBitSize::U16 => Ok(Rep3BrilligType::Shared(Shared::Ring16( + rep3_ring::casts::ring_cast_selector::<_, u16, _>(share, io_context)?, + ))), + IntegerBitSize::U32 => Ok(Rep3BrilligType::Shared(Shared::Ring32( + rep3_ring::casts::ring_cast_selector::<_, u32, _>(share, io_context)?, + ))), + IntegerBitSize::U64 => Ok(Rep3BrilligType::Shared(Shared::Ring64( + rep3_ring::casts::ring_cast_selector::<_, u64, _>(share, io_context)?, + ))), + IntegerBitSize::U128 => Ok(Rep3BrilligType::Shared(Shared::Ring128( + rep3_ring::casts::ring_cast_selector::<_, u128, _>(share, io_context)?, + ))), + } +} + impl BrilligDriver for Rep3BrilligDriver { type BrilligType = Rep3BrilligType; @@ -235,234 +274,24 @@ impl BrilligDriver for Rep3BrilligDriver )?, ))), }, - Shared::Ring128(rep3_ring_share) => match integer_bit_size { - IntegerBitSize::U1 => Ok(Rep3BrilligType::Shared(Shared::Ring1( - rep3_ring::casts::ring_cast_selector( - rep3_ring_share, - &mut self.io_context, - )?, - ))), - IntegerBitSize::U8 => Ok(Rep3BrilligType::Shared(Shared::Ring8( - rep3_ring::casts::ring_cast_selector( - rep3_ring_share, - &mut self.io_context, - )?, - ))), - IntegerBitSize::U16 => Ok(Rep3BrilligType::Shared(Shared::Ring16( - rep3_ring::casts::ring_cast_selector( - rep3_ring_share, - &mut self.io_context, - )?, - ))), - IntegerBitSize::U32 => Ok(Rep3BrilligType::Shared(Shared::Ring32( - rep3_ring::casts::ring_cast_selector( - rep3_ring_share, - &mut self.io_context, - )?, - ))), - IntegerBitSize::U64 => Ok(Rep3BrilligType::Shared(Shared::Ring64( - rep3_ring::casts::ring_cast_selector( - rep3_ring_share, - &mut self.io_context, - )?, - ))), - IntegerBitSize::U128 => Ok(Rep3BrilligType::Shared(Shared::Ring128( - rep3_ring::casts::ring_cast_selector( - rep3_ring_share, - &mut self.io_context, - )?, - ))), - }, - Shared::Ring64(rep3_ring_share) => match integer_bit_size { - IntegerBitSize::U1 => Ok(Rep3BrilligType::Shared(Shared::Ring1( - rep3_ring::casts::ring_cast_selector( - rep3_ring_share, - &mut self.io_context, - )?, - ))), - IntegerBitSize::U8 => Ok(Rep3BrilligType::Shared(Shared::Ring8( - rep3_ring::casts::ring_cast_selector( - rep3_ring_share, - &mut self.io_context, - )?, - ))), - IntegerBitSize::U16 => Ok(Rep3BrilligType::Shared(Shared::Ring16( - rep3_ring::casts::ring_cast_selector( - rep3_ring_share, - &mut self.io_context, - )?, - ))), - IntegerBitSize::U32 => Ok(Rep3BrilligType::Shared(Shared::Ring32( - rep3_ring::casts::ring_cast_selector( - rep3_ring_share, - &mut self.io_context, - )?, - ))), - IntegerBitSize::U64 => Ok(Rep3BrilligType::Shared(Shared::Ring64( - rep3_ring::casts::ring_cast_selector( - rep3_ring_share, - &mut self.io_context, - )?, - ))), - IntegerBitSize::U128 => Ok(Rep3BrilligType::Shared(Shared::Ring128( - rep3_ring::casts::ring_cast_selector( - rep3_ring_share, - &mut self.io_context, - )?, - ))), - }, - Shared::Ring32(rep3_ring_share) => match integer_bit_size { - IntegerBitSize::U1 => Ok(Rep3BrilligType::Shared(Shared::Ring1( - rep3_ring::casts::ring_cast_selector( - rep3_ring_share, - &mut self.io_context, - )?, - ))), - IntegerBitSize::U8 => Ok(Rep3BrilligType::Shared(Shared::Ring8( - rep3_ring::casts::ring_cast_selector( - rep3_ring_share, - &mut self.io_context, - )?, - ))), - IntegerBitSize::U16 => Ok(Rep3BrilligType::Shared(Shared::Ring16( - rep3_ring::casts::ring_cast_selector( - rep3_ring_share, - &mut self.io_context, - )?, - ))), - IntegerBitSize::U32 => Ok(Rep3BrilligType::Shared(Shared::Ring32( - rep3_ring::casts::ring_cast_selector( - rep3_ring_share, - &mut self.io_context, - )?, - ))), - IntegerBitSize::U64 => Ok(Rep3BrilligType::Shared(Shared::Ring64( - rep3_ring::casts::ring_cast_selector( - rep3_ring_share, - &mut self.io_context, - )?, - ))), - IntegerBitSize::U128 => Ok(Rep3BrilligType::Shared(Shared::Ring128( - rep3_ring::casts::ring_cast_selector( - rep3_ring_share, - &mut self.io_context, - )?, - ))), - }, - Shared::Ring16(rep3_ring_share) => match integer_bit_size { - IntegerBitSize::U1 => Ok(Rep3BrilligType::Shared(Shared::Ring1( - rep3_ring::casts::ring_cast_selector( - rep3_ring_share, - &mut self.io_context, - )?, - ))), - IntegerBitSize::U8 => Ok(Rep3BrilligType::Shared(Shared::Ring8( - rep3_ring::casts::ring_cast_selector( - rep3_ring_share, - &mut self.io_context, - )?, - ))), - IntegerBitSize::U16 => Ok(Rep3BrilligType::Shared(Shared::Ring16( - rep3_ring::casts::ring_cast_selector( - rep3_ring_share, - &mut self.io_context, - )?, - ))), - IntegerBitSize::U32 => Ok(Rep3BrilligType::Shared(Shared::Ring32( - rep3_ring::casts::ring_cast_selector( - rep3_ring_share, - &mut self.io_context, - )?, - ))), - IntegerBitSize::U64 => Ok(Rep3BrilligType::Shared(Shared::Ring64( - rep3_ring::casts::ring_cast_selector( - rep3_ring_share, - &mut self.io_context, - )?, - ))), - IntegerBitSize::U128 => Ok(Rep3BrilligType::Shared(Shared::Ring128( - rep3_ring::casts::ring_cast_selector( - rep3_ring_share, - &mut self.io_context, - )?, - ))), - }, - Shared::Ring8(rep3_ring_share) => match integer_bit_size { - IntegerBitSize::U1 => Ok(Rep3BrilligType::Shared(Shared::Ring1( - rep3_ring::casts::ring_cast_selector( - rep3_ring_share, - &mut self.io_context, - )?, - ))), - IntegerBitSize::U8 => Ok(Rep3BrilligType::Shared(Shared::Ring8( - rep3_ring::casts::ring_cast_selector( - rep3_ring_share, - &mut self.io_context, - )?, - ))), - IntegerBitSize::U16 => Ok(Rep3BrilligType::Shared(Shared::Ring16( - rep3_ring::casts::ring_cast_selector( - rep3_ring_share, - &mut self.io_context, - )?, - ))), - IntegerBitSize::U32 => Ok(Rep3BrilligType::Shared(Shared::Ring32( - rep3_ring::casts::ring_cast_selector( - rep3_ring_share, - &mut self.io_context, - )?, - ))), - IntegerBitSize::U64 => Ok(Rep3BrilligType::Shared(Shared::Ring64( - rep3_ring::casts::ring_cast_selector( - rep3_ring_share, - &mut self.io_context, - )?, - ))), - IntegerBitSize::U128 => Ok(Rep3BrilligType::Shared(Shared::Ring128( - rep3_ring::casts::ring_cast_selector( - rep3_ring_share, - &mut self.io_context, - )?, - ))), - }, - Shared::Ring1(rep3_ring_share) => match integer_bit_size { - IntegerBitSize::U1 => Ok(Rep3BrilligType::Shared(Shared::Ring1( - rep3_ring::casts::ring_cast_selector( - rep3_ring_share, - &mut self.io_context, - )?, - ))), - IntegerBitSize::U8 => Ok(Rep3BrilligType::Shared(Shared::Ring8( - rep3_ring::casts::ring_cast_selector( - rep3_ring_share, - &mut self.io_context, - )?, - ))), - IntegerBitSize::U16 => Ok(Rep3BrilligType::Shared(Shared::Ring16( - rep3_ring::casts::ring_cast_selector( - rep3_ring_share, - &mut self.io_context, - )?, - ))), - IntegerBitSize::U32 => Ok(Rep3BrilligType::Shared(Shared::Ring32( - rep3_ring::casts::ring_cast_selector( - rep3_ring_share, - &mut self.io_context, - )?, - ))), - IntegerBitSize::U64 => Ok(Rep3BrilligType::Shared(Shared::Ring64( - rep3_ring::casts::ring_cast_selector( - rep3_ring_share, - &mut self.io_context, - )?, - ))), - IntegerBitSize::U128 => Ok(Rep3BrilligType::Shared(Shared::Ring128( - rep3_ring::casts::ring_cast_selector( - rep3_ring_share, - &mut self.io_context, - )?, - ))), - }, + Shared::Ring128(rep3_ring_share) => { + cast_ring(rep3_ring_share, integer_bit_size, &mut self.io_context) + } + Shared::Ring64(rep3_ring_share) => { + cast_ring(rep3_ring_share, integer_bit_size, &mut self.io_context) + } + Shared::Ring32(rep3_ring_share) => { + cast_ring(rep3_ring_share, integer_bit_size, &mut self.io_context) + } + Shared::Ring16(rep3_ring_share) => { + cast_ring(rep3_ring_share, integer_bit_size, &mut self.io_context) + } + Shared::Ring8(rep3_ring_share) => { + cast_ring(rep3_ring_share, integer_bit_size, &mut self.io_context) + } + Shared::Ring1(rep3_ring_share) => { + cast_ring(rep3_ring_share, integer_bit_size, &mut self.io_context) + } }, (Rep3BrilligType::Public(public), BitSize::Field) => { let casted = self.plain_driver.cast(public, bit_size)?;