From aeae31557c6b372e899373d51c2d6096447d757e Mon Sep 17 00:00:00 2001 From: Roman Walch <9820846+rw0x0@users.noreply.github.com> Date: Thu, 28 Nov 2024 09:05:09 +0100 Subject: [PATCH] chore: Cleanup some code in the co-brillig MPC trait --- co-noir/co-brillig/src/mpc.rs | 6 +- co-noir/co-brillig/src/mpc/plain.rs | 23 -- co-noir/co-brillig/src/mpc/rep3.rs | 328 +-------------------------- co-noir/co-brillig/src/mpc/shamir.rs | 30 --- 4 files changed, 4 insertions(+), 383 deletions(-) diff --git a/co-noir/co-brillig/src/mpc.rs b/co-noir/co-brillig/src/mpc.rs index f5fa86855..061a3a8c5 100644 --- a/co-noir/co-brillig/src/mpc.rs +++ b/co-noir/co-brillig/src/mpc.rs @@ -180,8 +180,7 @@ pub trait BrilligDriver { lhs: Self::BrilligType, rhs: Self::BrilligType, ) -> eyre::Result { - let gt = self.lt(lhs, rhs)?; - self.not(gt) + self.lt(rhs, lhs) } /// Checks whether `lhs >= rhs`. The result @@ -197,8 +196,7 @@ pub trait BrilligDriver { lhs: Self::BrilligType, rhs: Self::BrilligType, ) -> eyre::Result { - let gt = self.lt(lhs, rhs)?; - self.not(gt) + self.le(rhs, lhs) } /// Converts the provided value to a binary representation, depending diff --git a/co-noir/co-brillig/src/mpc/plain.rs b/co-noir/co-brillig/src/mpc/plain.rs index 4a806259d..730d3d7a2 100644 --- a/co-noir/co-brillig/src/mpc/plain.rs +++ b/co-noir/co-brillig/src/mpc/plain.rs @@ -329,29 +329,6 @@ impl BrilligDriver for PlainBrilligDriver { } } - fn gt( - &mut self, - lhs: Self::BrilligType, - rhs: Self::BrilligType, - ) -> eyre::Result { - match (lhs, rhs) { - (PlainBrilligType::Field(lhs), PlainBrilligType::Field(rhs)) => { - let result = u128::from(lhs > rhs); - Ok(PlainBrilligType::Int(result, IntegerBitSize::U1)) - } - ( - PlainBrilligType::Int(lhs, lhs_bit_size), - PlainBrilligType::Int(rhs, rhs_bit_size), - ) if lhs_bit_size == rhs_bit_size => { - let result = u128::from(lhs > rhs); - Ok(PlainBrilligType::Int(result, IntegerBitSize::U1)) - } - x => eyre::bail!( - "type mismatch! Can only do bin ops on same types, but tried with {x:?}" - ), - } - } - fn to_radix( &mut self, val: Self::BrilligType, diff --git a/co-noir/co-brillig/src/mpc/rep3.rs b/co-noir/co-brillig/src/mpc/rep3.rs index 3b75e9881..7ac15c4a6 100644 --- a/co-noir/co-brillig/src/mpc/rep3.rs +++ b/co-noir/co-brillig/src/mpc/rep3.rs @@ -293,14 +293,8 @@ impl BrilligDriver for Rep3BrilligDriver cast_ring(rep3_ring_share, integer_bit_size, &mut self.io_context) } }, - (Rep3BrilligType::Public(public), BitSize::Field) => { - let casted = self.plain_driver.cast(public, bit_size)?; - Ok(Rep3BrilligType::Public(casted)) - } - (Rep3BrilligType::Public(public), BitSize::Integer(integer_bit_size)) => { - let casted = self - .plain_driver - .cast(public, BitSize::Integer(integer_bit_size))?; + (Rep3BrilligType::Public(public), bits) => { + let casted = self.plain_driver.cast(public, bits)?; Ok(Rep3BrilligType::Public(casted)) } } @@ -1249,324 +1243,6 @@ impl BrilligDriver for Rep3BrilligDriver Ok(result) } - fn gt( - &mut self, - lhs: Self::BrilligType, - rhs: Self::BrilligType, - ) -> eyre::Result { - let result = match (lhs, rhs) { - (Rep3BrilligType::Public(lhs), Rep3BrilligType::Public(rhs)) => { - let result = self.plain_driver.gt(lhs, rhs)?; - Rep3BrilligType::Public(result) - } - (Rep3BrilligType::Public(public), Rep3BrilligType::Shared(shared)) => { - match (shared, public) { - (Shared::Field(rhs), Public::Field(lhs)) => { - let le = rep3::arithmetic::ge_public_bit(rhs, lhs, &mut self.io_context)?; - let result = !Rep3RingShare::new( - Bit::cast_from_biguint(&le.a), - Bit::cast_from_biguint(&le.b), - ); - Rep3BrilligType::shared_u1(result) - } - (Shared::Ring128(rhs), Public::Int(lhs, IntegerBitSize::U128)) => { - Rep3BrilligType::shared_u1(rep3_ring::arithmetic::lt_public( - rhs, - lhs.into(), - &mut self.io_context, - )?) - } - (Shared::Ring64(rhs), Public::Int(lhs, IntegerBitSize::U64)) => { - Rep3BrilligType::shared_u1(rep3_ring::arithmetic::lt_public( - rhs, - u64::try_from(lhs).expect("must be u64").into(), - &mut self.io_context, - )?) - } - (Shared::Ring32(rhs), Public::Int(lhs, IntegerBitSize::U32)) => { - Rep3BrilligType::shared_u1(rep3_ring::arithmetic::lt_public( - rhs, - u32::try_from(lhs).expect("must be u32").into(), - &mut self.io_context, - )?) - } - (Shared::Ring16(rhs), Public::Int(lhs, IntegerBitSize::U16)) => { - Rep3BrilligType::shared_u1(rep3_ring::arithmetic::lt_public( - rhs, - u16::try_from(lhs).expect("must be u16").into(), - &mut self.io_context, - )?) - } - (Shared::Ring8(rhs), Public::Int(lhs, IntegerBitSize::U8)) => { - Rep3BrilligType::shared_u1(rep3_ring::arithmetic::lt_public( - rhs, - u8::try_from(lhs).expect("must be u8").into(), - &mut self.io_context, - )?) - } - (Shared::Ring1(rhs), Public::Int(lhs, IntegerBitSize::U1)) => { - Rep3BrilligType::shared_u1(rep3_ring::arithmetic::lt_public( - rhs, - bit_from_u128!(lhs), - &mut self.io_context, - )?) - } - x => eyre::bail!( - "type mismatch! Can only do bin ops on same types, but tried with {x:?}" - ), - } - } - (Rep3BrilligType::Shared(shared), Rep3BrilligType::Public(public)) => { - match (shared, public) { - (Shared::Field(lhs), Public::Field(rhs)) => { - let le = rep3::arithmetic::le_public_bit(lhs, rhs, &mut self.io_context)?; - let result = !Rep3RingShare::new( - Bit::cast_from_biguint(&le.a), - Bit::cast_from_biguint(&le.b), - ); - Rep3BrilligType::shared_u1(result) - } - (Shared::Ring128(lhs), Public::Int(rhs, IntegerBitSize::U128)) => { - Rep3BrilligType::shared_u1(rep3_ring::arithmetic::gt_public( - lhs, - rhs.into(), - &mut self.io_context, - )?) - } - (Shared::Ring64(lhs), Public::Int(rhs, IntegerBitSize::U64)) => { - Rep3BrilligType::shared_u1(rep3_ring::arithmetic::gt_public( - lhs, - u64::try_from(rhs).expect("must be u64").into(), - &mut self.io_context, - )?) - } - (Shared::Ring32(lhs), Public::Int(rhs, IntegerBitSize::U32)) => { - Rep3BrilligType::shared_u1(rep3_ring::arithmetic::gt_public( - lhs, - u32::try_from(rhs).expect("must be u32").into(), - &mut self.io_context, - )?) - } - (Shared::Ring16(lhs), Public::Int(rhs, IntegerBitSize::U16)) => { - Rep3BrilligType::shared_u1(rep3_ring::arithmetic::gt_public( - lhs, - u16::try_from(rhs).expect("must be u16").into(), - &mut self.io_context, - )?) - } - (Shared::Ring8(lhs), Public::Int(rhs, IntegerBitSize::U8)) => { - Rep3BrilligType::shared_u1(rep3_ring::arithmetic::gt_public( - lhs, - u8::try_from(rhs).expect("must be u8").into(), - &mut self.io_context, - )?) - } - (Shared::Ring1(lhs), Public::Int(rhs, IntegerBitSize::U1)) => { - Rep3BrilligType::shared_u1(rep3_ring::arithmetic::gt_public( - lhs, - bit_from_u128!(rhs), - &mut self.io_context, - )?) - } - x => eyre::bail!( - "type mismatch! Can only do bin ops on same types, but tried with {x:?}" - ), - } - } - (Rep3BrilligType::Shared(s1), Rep3BrilligType::Shared(s2)) => match (s1, s2) { - (Shared::Field(s1), Shared::Field(s2)) => { - let le = rep3::arithmetic::ge_bit(s2, s1, &mut self.io_context)?; - let result = !Rep3RingShare::new( - Bit::cast_from_biguint(&le.a), - Bit::cast_from_biguint(&le.b), - ); - Rep3BrilligType::shared_u1(result) - } - (Shared::Ring128(s1), Shared::Ring128(s2)) => Rep3BrilligType::shared_u1( - rep3_ring::arithmetic::gt(s1, s2, &mut self.io_context)?, - ), - (Shared::Ring64(s1), Shared::Ring64(s2)) => Rep3BrilligType::shared_u1( - rep3_ring::arithmetic::gt(s1, s2, &mut self.io_context)?, - ), - (Shared::Ring32(s1), Shared::Ring32(s2)) => Rep3BrilligType::shared_u1( - rep3_ring::arithmetic::gt(s1, s2, &mut self.io_context)?, - ), - (Shared::Ring16(s1), Shared::Ring16(s2)) => Rep3BrilligType::shared_u1( - rep3_ring::arithmetic::gt(s1, s2, &mut self.io_context)?, - ), - (Shared::Ring8(s1), Shared::Ring8(s2)) => Rep3BrilligType::shared_u1( - rep3_ring::arithmetic::gt(s1, s2, &mut self.io_context)?, - ), - (Shared::Ring1(s1), Shared::Ring1(s2)) => Rep3BrilligType::shared_u1( - rep3_ring::arithmetic::gt(s1, s2, &mut self.io_context)?, - ), - x => eyre::bail!( - "type mismatch! Can only do bin ops on same types, but tried with {x:?}" - ), - }, - }; - Ok(result) - } - - fn ge( - &mut self, - lhs: Self::BrilligType, - rhs: Self::BrilligType, - ) -> eyre::Result { - let result = match (lhs, rhs) { - (Rep3BrilligType::Public(lhs), Rep3BrilligType::Public(rhs)) => { - let result = self.plain_driver.ge(lhs, rhs)?; - Rep3BrilligType::Public(result) - } - (Rep3BrilligType::Public(public), Rep3BrilligType::Shared(shared)) => { - match (shared, public) { - (Shared::Field(rhs), Public::Field(lhs)) => { - let ge = rep3::arithmetic::le_public_bit(rhs, lhs, &mut self.io_context)?; - let result = Rep3RingShare::new( - Bit::cast_from_biguint(&ge.a), - Bit::cast_from_biguint(&ge.b), - ); - Rep3BrilligType::shared_u1(result) - } - (Shared::Ring128(rhs), Public::Int(lhs, IntegerBitSize::U128)) => { - Rep3BrilligType::shared_u1(rep3_ring::arithmetic::le_public( - rhs, - lhs.into(), - &mut self.io_context, - )?) - } - (Shared::Ring64(rhs), Public::Int(lhs, IntegerBitSize::U64)) => { - Rep3BrilligType::shared_u1(rep3_ring::arithmetic::le_public( - rhs, - u64::try_from(lhs).expect("must be u64").into(), - &mut self.io_context, - )?) - } - (Shared::Ring32(rhs), Public::Int(lhs, IntegerBitSize::U32)) => { - Rep3BrilligType::shared_u1(rep3_ring::arithmetic::le_public( - rhs, - u32::try_from(lhs).expect("must be u32").into(), - &mut self.io_context, - )?) - } - (Shared::Ring16(rhs), Public::Int(lhs, IntegerBitSize::U16)) => { - Rep3BrilligType::shared_u1(rep3_ring::arithmetic::le_public( - rhs, - u16::try_from(lhs).expect("must be u16").into(), - &mut self.io_context, - )?) - } - (Shared::Ring8(rhs), Public::Int(lhs, IntegerBitSize::U8)) => { - Rep3BrilligType::shared_u1(rep3_ring::arithmetic::le_public( - rhs, - u8::try_from(lhs).expect("must be u8").into(), - &mut self.io_context, - )?) - } - (Shared::Ring1(rhs), Public::Int(lhs, IntegerBitSize::U1)) => { - Rep3BrilligType::shared_u1(rep3_ring::arithmetic::le_public( - rhs, - bit_from_u128!(lhs), - &mut self.io_context, - )?) - } - x => eyre::bail!( - "type mismatch! Can only do bin ops on same types, but tried with {x:?}" - ), - } - } - (Rep3BrilligType::Shared(shared), Rep3BrilligType::Public(public)) => { - match (shared, public) { - (Shared::Field(lhs), Public::Field(rhs)) => { - let ge = rep3::arithmetic::ge_public_bit(lhs, rhs, &mut self.io_context)?; - let result = Rep3RingShare::new( - Bit::cast_from_biguint(&ge.a), - Bit::cast_from_biguint(&ge.b), - ); - Rep3BrilligType::shared_u1(result) - } - (Shared::Ring128(lhs), Public::Int(rhs, IntegerBitSize::U128)) => { - Rep3BrilligType::shared_u1(rep3_ring::arithmetic::ge_public( - lhs, - rhs.into(), - &mut self.io_context, - )?) - } - (Shared::Ring64(lhs), Public::Int(rhs, IntegerBitSize::U64)) => { - Rep3BrilligType::shared_u1(rep3_ring::arithmetic::ge_public( - lhs, - u64::try_from(rhs).expect("must be u64").into(), - &mut self.io_context, - )?) - } - (Shared::Ring32(lhs), Public::Int(rhs, IntegerBitSize::U32)) => { - Rep3BrilligType::shared_u1(rep3_ring::arithmetic::ge_public( - lhs, - u32::try_from(rhs).expect("must be u32").into(), - &mut self.io_context, - )?) - } - (Shared::Ring16(lhs), Public::Int(rhs, IntegerBitSize::U16)) => { - Rep3BrilligType::shared_u1(rep3_ring::arithmetic::ge_public( - lhs, - u16::try_from(rhs).expect("must be u16").into(), - &mut self.io_context, - )?) - } - (Shared::Ring8(lhs), Public::Int(rhs, IntegerBitSize::U8)) => { - Rep3BrilligType::shared_u1(rep3_ring::arithmetic::ge_public( - lhs, - u8::try_from(rhs).expect("must be u8").into(), - &mut self.io_context, - )?) - } - (Shared::Ring1(lhs), Public::Int(rhs, IntegerBitSize::U1)) => { - Rep3BrilligType::shared_u1(rep3_ring::arithmetic::ge_public( - lhs, - bit_from_u128!(rhs), - &mut self.io_context, - )?) - } - x => eyre::bail!( - "type mismatch! Can only do bin ops on same types, but tried with {x:?}" - ), - } - } - (Rep3BrilligType::Shared(s1), Rep3BrilligType::Shared(s2)) => match (s1, s2) { - (Shared::Field(s1), Shared::Field(s2)) => { - let ge = rep3::arithmetic::ge_bit(s1, s2, &mut self.io_context)?; - let result = Rep3RingShare::new( - Bit::cast_from_biguint(&ge.a), - Bit::cast_from_biguint(&ge.b), - ); - Rep3BrilligType::shared_u1(result) - } - (Shared::Ring128(s1), Shared::Ring128(s2)) => Rep3BrilligType::shared_u1( - rep3_ring::arithmetic::ge(s1, s2, &mut self.io_context)?, - ), - (Shared::Ring64(s1), Shared::Ring64(s2)) => Rep3BrilligType::shared_u1( - rep3_ring::arithmetic::ge(s1, s2, &mut self.io_context)?, - ), - (Shared::Ring32(s1), Shared::Ring32(s2)) => Rep3BrilligType::shared_u1( - rep3_ring::arithmetic::ge(s1, s2, &mut self.io_context)?, - ), - (Shared::Ring16(s1), Shared::Ring16(s2)) => Rep3BrilligType::shared_u1( - rep3_ring::arithmetic::ge(s1, s2, &mut self.io_context)?, - ), - (Shared::Ring8(s1), Shared::Ring8(s2)) => Rep3BrilligType::shared_u1( - rep3_ring::arithmetic::ge(s1, s2, &mut self.io_context)?, - ), - (Shared::Ring1(s1), Shared::Ring1(s2)) => Rep3BrilligType::shared_u1( - rep3_ring::arithmetic::ge(s1, s2, &mut self.io_context)?, - ), - x => eyre::bail!( - "type mismatch! Can only do bin ops on same types, but tried with {x:?}" - ), - }, - }; - Ok(result) - } - fn to_radix( &mut self, val: Self::BrilligType, diff --git a/co-noir/co-brillig/src/mpc/shamir.rs b/co-noir/co-brillig/src/mpc/shamir.rs index d3d776a5c..0a104e128 100644 --- a/co-noir/co-brillig/src/mpc/shamir.rs +++ b/co-noir/co-brillig/src/mpc/shamir.rs @@ -274,36 +274,6 @@ impl BrilligDriver for ShamirBrilligDriver eyre::Result { - let result = match (lhs, rhs) { - (ShamirBrilligType::Public(lhs), ShamirBrilligType::Public(rhs)) => { - let result = self.plain_driver.gt(lhs, rhs)?; - ShamirBrilligType::Public(result) - } - _ => eyre::bail!("Cannot compare shared values with Shamir"), - }; - Ok(result) - } - - fn ge( - &mut self, - lhs: Self::BrilligType, - rhs: Self::BrilligType, - ) -> eyre::Result { - let result = match (lhs, rhs) { - (ShamirBrilligType::Public(lhs), ShamirBrilligType::Public(rhs)) => { - let result = self.plain_driver.ge(lhs, rhs)?; - ShamirBrilligType::Public(result) - } - _ => eyre::bail!("Cannot compare shared values with Shamir"), - }; - Ok(result) - } - fn to_radix( &mut self, val: Self::BrilligType,