diff --git a/co-circom/circom-mpc-compiler/src/lib.rs b/co-circom/circom-mpc-compiler/src/lib.rs index 0b727760b..7d14d209d 100644 --- a/co-circom/circom-mpc-compiler/src/lib.rs +++ b/co-circom/circom-mpc-compiler/src/lib.rs @@ -332,8 +332,8 @@ where OperatorType::Lesser => self.emit_opcode(MpcOpCode::Lt), OperatorType::Greater => self.emit_opcode(MpcOpCode::Gt), OperatorType::Eq(size) => { - assert_ne!(size, 0); - self.emit_opcode(MpcOpCode::Eq); + assert_ne!(size, 0, "size must be > 0"); + self.emit_opcode(MpcOpCode::Eq(size)); } OperatorType::NotEq => self.emit_opcode(MpcOpCode::Neq), OperatorType::BoolOr => self.emit_opcode(MpcOpCode::BoolOr), diff --git a/co-circom/circom-mpc-vm/src/mpc_vm.rs b/co-circom/circom-mpc-vm/src/mpc_vm.rs index 53d059eb7..fd38b64bb 100644 --- a/co-circom/circom-mpc-vm/src/mpc_vm.rs +++ b/co-circom/circom-mpc-vm/src/mpc_vm.rs @@ -327,7 +327,7 @@ impl> Component { match inst { op_codes::MpcOpCode::PushConstant(index) => { let constant = ctx.constant_table[*index].clone(); - tracing::debug!("pushing constant {}", constant); + tracing::trace!("pushing constant {}", constant); self.push_field(constant); } op_codes::MpcOpCode::PushIndex(index) => self.push_index(*index), @@ -338,7 +338,7 @@ impl> Component { .iter() .cloned() .for_each(|signal| { - tracing::debug!("pushing signal {signal}"); + tracing::trace!("pushing signal {signal}"); self.push_field(signal); }); } @@ -477,7 +477,7 @@ impl> Component { let mut input_signals = vec![C::VmType::default(); *amount]; for i in 0..*amount { input_signals[*amount - i - 1] = self.pop_field(); - tracing::debug!("popping {}", input_signals.last().unwrap()); + tracing::trace!("popping {}", input_signals.last().unwrap()); } let component = &mut self.sub_components[sub_comp_index]; @@ -602,10 +602,22 @@ impl> Component { let lhs = self.pop_field(); self.push_field(protocol.ge(lhs, rhs)?); } - op_codes::MpcOpCode::Eq => { - let rhs = self.pop_field(); - let lhs = self.pop_field(); - self.push_field(protocol.eq(lhs, rhs)?); + op_codes::MpcOpCode::Eq(size) => { + let size = *size; + let mut lhs = Vec::with_capacity(size); + let mut rhs = Vec::with_capacity(size); + for _ in 0..size { + rhs.push(self.pop_field()); + } + for _ in 0..size { + lhs.push(self.pop_field()); + } + let mut result = protocol.public_one(); + for (lhs, rhs) in izip!(lhs, rhs) { + let cmp = protocol.eq(lhs, rhs)?; + result = protocol.bool_and(cmp, result)?; + } + self.push_field(result); } op_codes::MpcOpCode::Neq => { let rhs = self.pop_field(); diff --git a/co-circom/circom-mpc-vm/src/op_codes.rs b/co-circom/circom-mpc-vm/src/op_codes.rs index 7ad8e4559..7c2b939c1 100644 --- a/co-circom/circom-mpc-vm/src/op_codes.rs +++ b/co-circom/circom-mpc-vm/src/op_codes.rs @@ -95,8 +95,8 @@ pub enum MpcOpCode { Gt, /// Pops two elements from the field stack, compares if the first popped value is greater than or equal to the second, and pushes a boolean result onto the stack. Ge, - /// Pops two elements from the field stack, compares if the first popped value is equal to the second, and pushes a boolean result onto the stack. - Eq, + /// Pops the provided amount of elements time 2 from the field stack and compares if the first n popped values are equal to the second set, and pushes a boolean result onto the stack. + Eq(usize), /// Pops two elements from the field stack, compares if the first popped value is not equal to the second, and pushes a boolean result onto the stack. Neq, /// Pops two boolean values from the field stack, computes their boolean OR, and pushes the result onto the stack. @@ -171,7 +171,7 @@ impl std::fmt::Display for MpcOpCode { MpcOpCode::Le => "LESS_EQ_OP".to_owned(), MpcOpCode::Gt => "GREATER_THAN_OP".to_owned(), MpcOpCode::Ge => "GREATER_EQ_OP".to_owned(), - MpcOpCode::Eq => "IS_EQUAL_OP".to_owned(), + MpcOpCode::Eq(size) => format!("IS_EQUAL_OP {size}"), MpcOpCode::Neq => "NOT_EQUAL_OP".to_owned(), MpcOpCode::BoolOr => "BOOL_OR_OP".to_owned(), MpcOpCode::BoolAnd => "BOOL_AND_OP".to_owned(), diff --git a/test_vectors/WitnessExtension/kats/array_equals/input0.json b/test_vectors/WitnessExtension/kats/array_equals/input0.json new file mode 100644 index 000000000..e013d4ea6 --- /dev/null +++ b/test_vectors/WitnessExtension/kats/array_equals/input0.json @@ -0,0 +1,3 @@ +{ + "in": ["22", "-11", "22", "-11"] +} diff --git a/test_vectors/WitnessExtension/kats/array_equals/input1.json b/test_vectors/WitnessExtension/kats/array_equals/input1.json new file mode 100644 index 000000000..079af0f70 --- /dev/null +++ b/test_vectors/WitnessExtension/kats/array_equals/input1.json @@ -0,0 +1,3 @@ +{ + "in": ["22", "11", "22", "11"] +} diff --git a/test_vectors/WitnessExtension/kats/array_equals/input2.json b/test_vectors/WitnessExtension/kats/array_equals/input2.json new file mode 100644 index 000000000..9eb7dcd09 --- /dev/null +++ b/test_vectors/WitnessExtension/kats/array_equals/input2.json @@ -0,0 +1,3 @@ +{ + "in": ["0", "32094032", "0", "32094032"] +} diff --git a/test_vectors/WitnessExtension/kats/array_equals/witness0.wtns b/test_vectors/WitnessExtension/kats/array_equals/witness0.wtns new file mode 100644 index 000000000..467294e08 Binary files /dev/null and b/test_vectors/WitnessExtension/kats/array_equals/witness0.wtns differ diff --git a/test_vectors/WitnessExtension/kats/array_equals/witness1.wtns b/test_vectors/WitnessExtension/kats/array_equals/witness1.wtns new file mode 100644 index 000000000..467294e08 Binary files /dev/null and b/test_vectors/WitnessExtension/kats/array_equals/witness1.wtns differ diff --git a/test_vectors/WitnessExtension/kats/array_equals/witness2.wtns b/test_vectors/WitnessExtension/kats/array_equals/witness2.wtns new file mode 100644 index 000000000..467294e08 Binary files /dev/null and b/test_vectors/WitnessExtension/kats/array_equals/witness2.wtns differ diff --git a/test_vectors/WitnessExtension/tests/array_equals.circom b/test_vectors/WitnessExtension/tests/array_equals.circom new file mode 100644 index 000000000..44b990180 --- /dev/null +++ b/test_vectors/WitnessExtension/tests/array_equals.circom @@ -0,0 +1,10 @@ +pragma circom 2.0.0; + +template Main() { + signal input a[2]; + signal input b[2]; + + a === b; +} + +component main = Main(); diff --git a/tests/tests/circom/witness_extension_tests/plain_vm.rs b/tests/tests/circom/witness_extension_tests/plain_vm.rs index 8f2859837..fa54e3e50 100644 --- a/tests/tests/circom/witness_extension_tests/plain_vm.rs +++ b/tests/tests/circom/witness_extension_tests/plain_vm.rs @@ -108,6 +108,7 @@ pub fn from_test_name(fn_name: &str) -> TestInputs { } witness_extension_test_plain!(aliascheck_test); +witness_extension_test_plain!(array_equals); witness_extension_test_plain!(babyadd_tester); witness_extension_test_plain!(babycheck_test); witness_extension_test_plain!(babypbk_test); diff --git a/tests/tests/circom/witness_extension_tests/rep3.rs b/tests/tests/circom/witness_extension_tests/rep3.rs index a8288c238..783d08f0b 100644 --- a/tests/tests/circom/witness_extension_tests/rep3.rs +++ b/tests/tests/circom/witness_extension_tests/rep3.rs @@ -199,6 +199,7 @@ macro_rules! witness_extension_test_rep3_ignored { } witness_extension_test_rep3!(aliascheck_test); +witness_extension_test_rep3!(array_equals); witness_extension_test_rep3!(babyadd_tester); witness_extension_test_rep3!(babycheck_test); witness_extension_test_rep3!(babypbk_test);