diff --git a/components/aead/src/aes_gcm/mod.rs b/components/aead/src/aes_gcm/mod.rs index 77280be53e..25f7824733 100644 --- a/components/aead/src/aes_gcm/mod.rs +++ b/components/aead/src/aes_gcm/mod.rs @@ -180,6 +180,20 @@ impl Aead for MpcAesGcm { Ok(()) } + async fn decode_key_private(&mut self) -> Result<(), AeadError> { + self.aes_ctr + .decode_key_private() + .await + .map_err(AeadError::from) + } + + async fn decode_key_blind(&mut self) -> Result<(), AeadError> { + self.aes_ctr + .decode_key_blind() + .await + .map_err(AeadError::from) + } + fn set_transcript_id(&mut self, id: &str) { self.aes_ctr.set_transcript_id(id) } @@ -321,6 +335,52 @@ impl Aead for MpcAesGcm { .map_err(AeadError::from) .await } + + async fn verify_tag( + &mut self, + explicit_nonce: Vec, + mut ciphertext: Vec, + aad: Vec, + ) -> Result<(), AeadError> { + let purported_tag = ciphertext.split_off(ciphertext.len() - AES_GCM_TAG_LEN); + + let tag = self + .compute_tag(explicit_nonce.clone(), ciphertext, aad) + .await?; + + // Reject if tag is incorrect + if tag == purported_tag { + Ok(()) + } else { + Err(AeadError::CorruptedTag) + } + } + + async fn prove_plaintext( + &mut self, + explicit_nonce: Vec, + mut ciphertext: Vec, + ) -> Result, AeadError> { + ciphertext.truncate(ciphertext.len() - AES_GCM_TAG_LEN); + + self.aes_ctr + .prove_plaintext(explicit_nonce, ciphertext) + .map_err(AeadError::from) + .await + } + + async fn verify_plaintext( + &mut self, + explicit_nonce: Vec, + mut ciphertext: Vec, + ) -> Result<(), AeadError> { + ciphertext.truncate(ciphertext.len() - AES_GCM_TAG_LEN); + + self.aes_ctr + .verify_plaintext(explicit_nonce, ciphertext) + .map_err(AeadError::from) + .await + } } #[cfg(test)] @@ -580,4 +640,37 @@ mod tests { .unwrap_err(); assert!(matches!(err, AeadError::CorruptedTag)); } + + #[tokio::test] + async fn test_aes_gcm_verify_tag() { + let key = vec![0u8; 16]; + let iv = vec![0u8; 4]; + let explicit_nonce = vec![0u8; 8]; + let plaintext = vec![1u8; 32]; + let aad = vec![2u8; 12]; + let ciphertext = reference_impl(&key, &iv, &explicit_nonce, &plaintext, &aad); + + let len = ciphertext.len(); + + let ((mut leader, mut follower), (_leader_vm, _follower_vm)) = + setup_pair(key.clone(), iv.clone()).await; + + tokio::try_join!( + leader.verify_tag(explicit_nonce.clone(), ciphertext.clone(), aad.clone()), + follower.verify_tag(explicit_nonce.clone(), ciphertext.clone(), aad.clone()) + ) + .unwrap(); + + // corrupt tag + let mut corrupted = ciphertext.clone(); + corrupted[len - 1] -= 1; + + let (leader_res, follower_res) = tokio::join!( + leader.verify_tag(explicit_nonce.clone(), corrupted.clone(), aad.clone()), + follower.verify_tag(explicit_nonce.clone(), corrupted, aad.clone()) + ); + + assert!(matches!(leader_res.unwrap_err(), AeadError::CorruptedTag)); + assert!(matches!(follower_res.unwrap_err(), AeadError::CorruptedTag)); + } } diff --git a/components/aead/src/lib.rs b/components/aead/src/lib.rs index fa52f65dea..cb4a25287a 100644 --- a/components/aead/src/lib.rs +++ b/components/aead/src/lib.rs @@ -49,6 +49,12 @@ pub trait Aead: Send { /// Sets the key for the AEAD. async fn set_key(&mut self, key: ValueRef, iv: ValueRef) -> Result<(), AeadError>; + /// Decodes the key for the AEAD, revealing it to this party. + async fn decode_key_private(&mut self) -> Result<(), AeadError>; + + /// Decodes the key for the AEAD, revealing it to the other party(s). + async fn decode_key_blind(&mut self) -> Result<(), AeadError>; + /// Sets the transcript id /// /// The AEAD assigns unique identifiers to each byte of plaintext @@ -144,4 +150,48 @@ pub trait Aead: Send { ciphertext: Vec, aad: Vec, ) -> Result<(), AeadError>; + + /// Verifies the tag of a ciphertext message. + /// + /// This method checks the authenticity of the ciphertext, tag and additional data. + /// + /// * `explicit_nonce` - The explicit nonce to use for decryption. + /// * `ciphertext` - The ciphertext and tag to authenticate and decrypt. + /// * `aad` - Additional authenticated data. + async fn verify_tag( + &mut self, + explicit_nonce: Vec, + ciphertext: Vec, + aad: Vec, + ) -> Result<(), AeadError>; + + /// Locally decrypts the provided ciphertext and then proves in ZK to the other party(s) that the + /// plaintext is correct. + /// + /// Returns the plaintext. + /// + /// This method requires this party to know the encryption key, which can be achieved by calling + /// the `decode_key_private` method. + /// + /// # Arguments + /// + /// * `explicit_nonce`: The explicit nonce to use for the keystream. + /// * `ciphertext`: The ciphertext to decrypt and prove. + async fn prove_plaintext( + &mut self, + explicit_nonce: Vec, + ciphertext: Vec, + ) -> Result, AeadError>; + + /// Verifies the other party(s) can prove they know a plaintext which encrypts to the given ciphertext. + /// + /// # Arguments + /// + /// * `explicit_nonce`: The explicit nonce to use for the keystream. + /// * `ciphertext`: The ciphertext to verify. + async fn verify_plaintext( + &mut self, + explicit_nonce: Vec, + ciphertext: Vec, + ) -> Result<(), AeadError>; } diff --git a/components/cipher/stream-cipher/benches/mock.rs b/components/cipher/stream-cipher/benches/mock.rs index bb08b8f6c0..bae9b694e4 100644 --- a/components/cipher/stream-cipher/benches/mock.rs +++ b/components/cipher/stream-cipher/benches/mock.rs @@ -106,7 +106,7 @@ async fn bench_stream_cipher_zk(thread_count: usize, len: usize) { let plaintext = vec![0u8; len]; let explicit_nonce = [0u8; 8]; - let ciphertext = Aes128Ctr::apply_keystream(&key, &iv, 2, &explicit_nonce, &plaintext); + let ciphertext = Aes128Ctr::apply_keystream(&key, &iv, 2, &explicit_nonce, &plaintext).unwrap(); _ = tokio::try_join!( leader.prove_plaintext(explicit_nonce.to_vec(), plaintext), diff --git a/components/cipher/stream-cipher/src/cipher.rs b/components/cipher/stream-cipher/src/cipher.rs index ed50391b83..f9e75e3601 100644 --- a/components/cipher/stream-cipher/src/cipher.rs +++ b/components/cipher/stream-cipher/src/cipher.rs @@ -5,12 +5,12 @@ use mpz_circuits::{ Circuit, }; -use crate::circuit::AES_CTR; +use crate::{circuit::AES_CTR, StreamCipherError}; /// A counter-mode block cipher circuit. pub trait CtrCircuit: Default + Clone + Send + Sync + 'static { /// The key type - type KEY: StaticValueType + Send + Sync + 'static; + type KEY: StaticValueType + TryFrom> + Send + Sync + 'static; /// The block type type BLOCK: StaticValueType + TryFrom> @@ -54,12 +54,12 @@ pub trait CtrCircuit: Default + Clone + Send + Sync + 'static { /// Applies the keystream to the message fn apply_keystream( - key: &Self::KEY, - iv: &Self::IV, + key: &[u8], + iv: &[u8], start_ctr: usize, - explicit_nonce: &Self::NONCE, + explicit_nonce: &[u8], msg: &[u8], - ) -> Vec; + ) -> Result, StreamCipherError>; } /// A circuit for AES-128 in counter mode. @@ -82,16 +82,35 @@ impl CtrCircuit for Aes128Ctr { } fn apply_keystream( - key: &Self::KEY, - iv: &Self::IV, + key: &[u8], + iv: &[u8], start_ctr: usize, - explicit_nonce: &Self::NONCE, + explicit_nonce: &[u8], msg: &[u8], - ) -> Vec { + ) -> Result, StreamCipherError> { use ::cipher::{KeyIvInit, StreamCipher, StreamCipherSeek}; use aes::Aes128; use ctr::Ctr32BE; + let key: &[u8; 16] = key + .try_into() + .map_err(|_| StreamCipherError::InvalidKeyLength { + expected: 16, + actual: key.len(), + })?; + let iv: &[u8; 4] = iv + .try_into() + .map_err(|_| StreamCipherError::InvalidIvLength { + expected: 4, + actual: iv.len(), + })?; + let explicit_nonce: &[u8; 8] = explicit_nonce.try_into().map_err(|_| { + StreamCipherError::InvalidExplicitNonceLength { + expected: 8, + actual: explicit_nonce.len(), + } + })?; + let mut full_iv = [0u8; 16]; full_iv[0..4].copy_from_slice(iv); full_iv[4..12].copy_from_slice(explicit_nonce); @@ -103,6 +122,6 @@ impl CtrCircuit for Aes128Ctr { .expect("start counter is less than keystream length"); cipher.apply_keystream(&mut buf); - buf + Ok(buf) } } diff --git a/components/cipher/stream-cipher/src/lib.rs b/components/cipher/stream-cipher/src/lib.rs index 2d58b39d5a..1de38ff869 100644 --- a/components/cipher/stream-cipher/src/lib.rs +++ b/components/cipher/stream-cipher/src/lib.rs @@ -42,6 +42,10 @@ pub enum StreamCipherError { VerifyError(#[from] mpz_garble::VerifyError), #[error("key and iv is not set")] KeyIvNotSet, + #[error("invalid key length: expected {expected}, got {actual}")] + InvalidKeyLength { expected: usize, actual: usize }, + #[error("invalid iv length: expected {expected}, got {actual}")] + InvalidIvLength { expected: usize, actual: usize }, #[error("invalid explicit nonce length: expected {expected}, got {actual}")] InvalidExplicitNonceLength { expected: usize, actual: usize }, #[error("missing value for {0}")] @@ -57,6 +61,12 @@ where /// Sets the key and iv for the stream cipher. fn set_key(&mut self, key: ValueRef, iv: ValueRef); + /// Decodes the key for the stream cipher, revealing it to this party. + async fn decode_key_private(&mut self) -> Result<(), StreamCipherError>; + + /// Decodes the key for the stream cipher, revealing it to the other party(s). + async fn decode_key_blind(&mut self) -> Result<(), StreamCipherError>; + /// Sets the transcript id /// /// The stream cipher assigns unique identifiers to each byte of plaintext @@ -149,17 +159,23 @@ where ciphertext: Vec, ) -> Result<(), StreamCipherError>; - /// Privately proves to the other party(s) the plaintext encrypts to a certain ciphertext. + /// Locally decrypts the provided ciphertext and then proves in ZK to the other party(s) that the + /// plaintext is correct. + /// + /// Returns the plaintext. + /// + /// This method requires this party to know the encryption key, which can be achieved by calling + /// the `decode_key_private` method. /// /// # Arguments /// /// * `explicit_nonce`: The explicit nonce to use for the keystream. - /// * `plaintext`: The plaintext to prove. + /// * `ciphertext`: The ciphertext to decrypt and prove. async fn prove_plaintext( &mut self, explicit_nonce: Vec, - plaintext: Vec, - ) -> Result<(), StreamCipherError>; + ciphertext: Vec, + ) -> Result, StreamCipherError>; /// Verifies the other party(s) can prove they know a plaintext which encrypts to the given ciphertext. /// @@ -306,7 +322,7 @@ mod tests { (follower_encrypted_msg, follower_decrypted_msg), ) = futures::join!(leader_fut, follower_fut); - let reference = Aes128Ctr::apply_keystream(&key, &iv, 1, &explicit_nonce, &msg); + let reference = Aes128Ctr::apply_keystream(&key, &iv, 1, &explicit_nonce, &msg).unwrap(); assert_eq!(leader_encrypted_msg, reference); assert_eq!(leader_decrypted_msg, msg); @@ -324,7 +340,7 @@ mod tests { let msg = b"This is a test message which will be encrypted using AES-CTR.".to_vec(); - let ciphertext = Aes128Ctr::apply_keystream(&key, &iv, 1, &explicit_nonce, &msg); + let ciphertext = Aes128Ctr::apply_keystream(&key, &iv, 1, &explicit_nonce, &msg).unwrap(); let ((mut leader, mut follower), (mut leader_vm, mut follower_vm)) = create_test_pair::(1, key, iv, 8).await; @@ -398,7 +414,8 @@ mod tests { .map(|(a, b)| a ^ b) .collect::>(); - let reference = Aes128Ctr::apply_keystream(&key, &iv, 1, &explicit_nonce, &[0u8; 16]); + let reference = + Aes128Ctr::apply_keystream(&key, &iv, 1, &explicit_nonce, &[0u8; 16]).unwrap(); assert_eq!(reference, key_block); } @@ -413,13 +430,15 @@ mod tests { let msg = b"This is a test message which will be encrypted using AES-CTR.".to_vec(); - let ciphertext = Aes128Ctr::apply_keystream(&key, &iv, 2, &explicit_nonce, &msg); + let ciphertext = Aes128Ctr::apply_keystream(&key, &iv, 2, &explicit_nonce, &msg).unwrap(); let ((mut leader, mut follower), (mut leader_vm, mut follower_vm)) = create_test_pair::(2, key, iv, 8).await; + futures::try_join!(leader.decode_key_private(), follower.decode_key_blind()).unwrap(); + futures::try_join!( - leader.prove_plaintext(explicit_nonce.to_vec(), msg), + leader.prove_plaintext(explicit_nonce.to_vec(), ciphertext.clone()), follower.verify_plaintext(explicit_nonce.to_vec(), ciphertext) ) .unwrap(); diff --git a/components/cipher/stream-cipher/src/stream_cipher.rs b/components/cipher/stream-cipher/src/stream_cipher.rs index ec68fdd58b..601bf6921a 100644 --- a/components/cipher/stream-cipher/src/stream_cipher.rs +++ b/components/cipher/stream-cipher/src/stream_cipher.rs @@ -28,6 +28,8 @@ where } struct State { + /// Encoded key and IV for the cipher. + encoded_key_iv: Option, /// Key and IV for the cipher. key_iv: Option, /// Unique identifier for each execution of the cipher. @@ -41,11 +43,17 @@ struct State { } #[derive(Clone)] -struct KeyAndIv { +struct EncodedKeyAndIv { key: ValueRef, iv: ValueRef, } +#[derive(Clone)] +struct KeyAndIv { + key: Vec, + iv: Vec, +} + impl MpcStreamCipher where C: CtrCircuit, @@ -60,6 +68,7 @@ where Self { config, state: State { + encoded_key_iv: None, key_iv: None, execution_id, transcript_counter, @@ -102,9 +111,9 @@ where len: usize, mode: ExecutionMode, ) -> Result { - let KeyAndIv { key, iv } = self + let EncodedKeyAndIv { key, iv } = self .state - .key_iv + .encoded_key_iv .clone() .ok_or(StreamCipherError::KeyIvNotSet)?; @@ -218,7 +227,41 @@ where E: Thread + Execute + Prove + Verify + Decode + DecodePrivate + Send + Sync + 'static, { fn set_key(&mut self, key: ValueRef, iv: ValueRef) { + self.state.encoded_key_iv = Some(EncodedKeyAndIv { key, iv }); + } + + async fn decode_key_private(&mut self) -> Result<(), StreamCipherError> { + let EncodedKeyAndIv { key, iv } = self + .state + .encoded_key_iv + .clone() + .ok_or(StreamCipherError::KeyIvNotSet)?; + + let mut scope = self.thread_pool.new_scope(); + scope.push(move |thread| Box::pin(async move { thread.decode_private(&[key, iv]).await })); + let output = scope.wait().await.into_iter().next().unwrap()?; + + let [key, iv]: [_; 2] = output.try_into().expect("decoded 2 values"); + let key: Vec = key.try_into().expect("key is an array"); + let iv: Vec = iv.try_into().expect("iv is an array"); + self.state.key_iv = Some(KeyAndIv { key, iv }); + + Ok(()) + } + + async fn decode_key_blind(&mut self) -> Result<(), StreamCipherError> { + let EncodedKeyAndIv { key, iv } = self + .state + .encoded_key_iv + .clone() + .ok_or(StreamCipherError::KeyIvNotSet)?; + + let mut scope = self.thread_pool.new_scope(); + scope.push(move |thread| Box::pin(async move { thread.decode_blind(&[key, iv]).await })); + scope.wait().await.into_iter().next().unwrap()?; + + Ok(()) } fn set_transcript_id(&mut self, id: &str) { @@ -480,8 +523,22 @@ where async fn prove_plaintext( &mut self, explicit_nonce: Vec, - plaintext: Vec, - ) -> Result<(), StreamCipherError> { + ciphertext: Vec, + ) -> Result, StreamCipherError> { + let KeyAndIv { key, iv } = self + .state + .key_iv + .clone() + .ok_or(StreamCipherError::KeyIvNotSet)?; + + let plaintext = C::apply_keystream( + &key, + &iv, + self.config.start_ctr, + &explicit_nonce, + &ciphertext, + )?; + // Prove plaintext encrypts back to ciphertext let keystream = self .compute_keystream( @@ -497,7 +554,7 @@ where .apply_keystream( InputText::Private { ids: plaintext_ids, - text: plaintext, + text: plaintext.clone(), }, keystream, ExecutionMode::Prove, @@ -506,7 +563,7 @@ where self.prove(ciphertext).await?; - Ok(()) + Ok(plaintext) } async fn verify_plaintext( @@ -546,9 +603,9 @@ where explicit_nonce: Vec, ctr: usize, ) -> Result, StreamCipherError> { - let KeyAndIv { key, iv } = self + let EncodedKeyAndIv { key, iv } = self .state - .key_iv + .encoded_key_iv .clone() .ok_or(StreamCipherError::KeyIvNotSet)?;