Skip to content

Commit

Permalink
Add verify tag and zk to aead (#390)
Browse files Browse the repository at this point in the history
* add verify tag and zk to aead

* move local decryption into stream cipher

* fix arg name

* return plaintext

* truncate tag in zk methods
  • Loading branch information
sinui0 authored Dec 29, 2023
1 parent 3cb59eb commit 9169088
Show file tree
Hide file tree
Showing 6 changed files with 268 additions and 30 deletions.
93 changes: 93 additions & 0 deletions components/aead/src/aes_gcm/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -180,6 +180,20 @@ impl Aead for MpcAesGcm {
Ok(())
}

async fn decode_key_private(&mut self) -> Result<(), AeadError> {
self.aes_ctr
.decode_key_private()
.await
.map_err(AeadError::from)
}

async fn decode_key_blind(&mut self) -> Result<(), AeadError> {
self.aes_ctr
.decode_key_blind()
.await
.map_err(AeadError::from)
}

fn set_transcript_id(&mut self, id: &str) {
self.aes_ctr.set_transcript_id(id)
}
Expand Down Expand Up @@ -321,6 +335,52 @@ impl Aead for MpcAesGcm {
.map_err(AeadError::from)
.await
}

async fn verify_tag(
&mut self,
explicit_nonce: Vec<u8>,
mut ciphertext: Vec<u8>,
aad: Vec<u8>,
) -> Result<(), AeadError> {
let purported_tag = ciphertext.split_off(ciphertext.len() - AES_GCM_TAG_LEN);

let tag = self
.compute_tag(explicit_nonce.clone(), ciphertext, aad)
.await?;

// Reject if tag is incorrect
if tag == purported_tag {
Ok(())
} else {
Err(AeadError::CorruptedTag)
}
}

async fn prove_plaintext(
&mut self,
explicit_nonce: Vec<u8>,
mut ciphertext: Vec<u8>,
) -> Result<Vec<u8>, AeadError> {
ciphertext.truncate(ciphertext.len() - AES_GCM_TAG_LEN);

self.aes_ctr
.prove_plaintext(explicit_nonce, ciphertext)
.map_err(AeadError::from)
.await
}

async fn verify_plaintext(
&mut self,
explicit_nonce: Vec<u8>,
mut ciphertext: Vec<u8>,
) -> Result<(), AeadError> {
ciphertext.truncate(ciphertext.len() - AES_GCM_TAG_LEN);

self.aes_ctr
.verify_plaintext(explicit_nonce, ciphertext)
.map_err(AeadError::from)
.await
}
}

#[cfg(test)]
Expand Down Expand Up @@ -580,4 +640,37 @@ mod tests {
.unwrap_err();
assert!(matches!(err, AeadError::CorruptedTag));
}

#[tokio::test]
async fn test_aes_gcm_verify_tag() {
let key = vec![0u8; 16];
let iv = vec![0u8; 4];
let explicit_nonce = vec![0u8; 8];
let plaintext = vec![1u8; 32];
let aad = vec![2u8; 12];
let ciphertext = reference_impl(&key, &iv, &explicit_nonce, &plaintext, &aad);

let len = ciphertext.len();

let ((mut leader, mut follower), (_leader_vm, _follower_vm)) =
setup_pair(key.clone(), iv.clone()).await;

tokio::try_join!(
leader.verify_tag(explicit_nonce.clone(), ciphertext.clone(), aad.clone()),
follower.verify_tag(explicit_nonce.clone(), ciphertext.clone(), aad.clone())
)
.unwrap();

// corrupt tag
let mut corrupted = ciphertext.clone();
corrupted[len - 1] -= 1;

let (leader_res, follower_res) = tokio::join!(
leader.verify_tag(explicit_nonce.clone(), corrupted.clone(), aad.clone()),
follower.verify_tag(explicit_nonce.clone(), corrupted, aad.clone())
);

assert!(matches!(leader_res.unwrap_err(), AeadError::CorruptedTag));
assert!(matches!(follower_res.unwrap_err(), AeadError::CorruptedTag));
}
}
50 changes: 50 additions & 0 deletions components/aead/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,12 @@ pub trait Aead: Send {
/// Sets the key for the AEAD.
async fn set_key(&mut self, key: ValueRef, iv: ValueRef) -> Result<(), AeadError>;

/// Decodes the key for the AEAD, revealing it to this party.
async fn decode_key_private(&mut self) -> Result<(), AeadError>;

/// Decodes the key for the AEAD, revealing it to the other party(s).
async fn decode_key_blind(&mut self) -> Result<(), AeadError>;

/// Sets the transcript id
///
/// The AEAD assigns unique identifiers to each byte of plaintext
Expand Down Expand Up @@ -144,4 +150,48 @@ pub trait Aead: Send {
ciphertext: Vec<u8>,
aad: Vec<u8>,
) -> Result<(), AeadError>;

/// Verifies the tag of a ciphertext message.
///
/// This method checks the authenticity of the ciphertext, tag and additional data.
///
/// * `explicit_nonce` - The explicit nonce to use for decryption.
/// * `ciphertext` - The ciphertext and tag to authenticate and decrypt.
/// * `aad` - Additional authenticated data.
async fn verify_tag(
&mut self,
explicit_nonce: Vec<u8>,
ciphertext: Vec<u8>,
aad: Vec<u8>,
) -> Result<(), AeadError>;

/// Locally decrypts the provided ciphertext and then proves in ZK to the other party(s) that the
/// plaintext is correct.
///
/// Returns the plaintext.
///
/// This method requires this party to know the encryption key, which can be achieved by calling
/// the `decode_key_private` method.
///
/// # Arguments
///
/// * `explicit_nonce`: The explicit nonce to use for the keystream.
/// * `ciphertext`: The ciphertext to decrypt and prove.
async fn prove_plaintext(
&mut self,
explicit_nonce: Vec<u8>,
ciphertext: Vec<u8>,
) -> Result<Vec<u8>, AeadError>;

/// Verifies the other party(s) can prove they know a plaintext which encrypts to the given ciphertext.
///
/// # Arguments
///
/// * `explicit_nonce`: The explicit nonce to use for the keystream.
/// * `ciphertext`: The ciphertext to verify.
async fn verify_plaintext(
&mut self,
explicit_nonce: Vec<u8>,
ciphertext: Vec<u8>,
) -> Result<(), AeadError>;
}
2 changes: 1 addition & 1 deletion components/cipher/stream-cipher/benches/mock.rs
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ async fn bench_stream_cipher_zk(thread_count: usize, len: usize) {

let plaintext = vec![0u8; len];
let explicit_nonce = [0u8; 8];
let ciphertext = Aes128Ctr::apply_keystream(&key, &iv, 2, &explicit_nonce, &plaintext);
let ciphertext = Aes128Ctr::apply_keystream(&key, &iv, 2, &explicit_nonce, &plaintext).unwrap();

_ = tokio::try_join!(
leader.prove_plaintext(explicit_nonce.to_vec(), plaintext),
Expand Down
41 changes: 30 additions & 11 deletions components/cipher/stream-cipher/src/cipher.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,12 @@ use mpz_circuits::{
Circuit,
};

use crate::circuit::AES_CTR;
use crate::{circuit::AES_CTR, StreamCipherError};

/// A counter-mode block cipher circuit.
pub trait CtrCircuit: Default + Clone + Send + Sync + 'static {
/// The key type
type KEY: StaticValueType + Send + Sync + 'static;
type KEY: StaticValueType + TryFrom<Vec<u8>> + Send + Sync + 'static;
/// The block type
type BLOCK: StaticValueType
+ TryFrom<Vec<u8>>
Expand Down Expand Up @@ -54,12 +54,12 @@ pub trait CtrCircuit: Default + Clone + Send + Sync + 'static {

/// Applies the keystream to the message
fn apply_keystream(
key: &Self::KEY,
iv: &Self::IV,
key: &[u8],
iv: &[u8],
start_ctr: usize,
explicit_nonce: &Self::NONCE,
explicit_nonce: &[u8],
msg: &[u8],
) -> Vec<u8>;
) -> Result<Vec<u8>, StreamCipherError>;
}

/// A circuit for AES-128 in counter mode.
Expand All @@ -82,16 +82,35 @@ impl CtrCircuit for Aes128Ctr {
}

fn apply_keystream(
key: &Self::KEY,
iv: &Self::IV,
key: &[u8],
iv: &[u8],
start_ctr: usize,
explicit_nonce: &Self::NONCE,
explicit_nonce: &[u8],
msg: &[u8],
) -> Vec<u8> {
) -> Result<Vec<u8>, StreamCipherError> {
use ::cipher::{KeyIvInit, StreamCipher, StreamCipherSeek};
use aes::Aes128;
use ctr::Ctr32BE;

let key: &[u8; 16] = key
.try_into()
.map_err(|_| StreamCipherError::InvalidKeyLength {
expected: 16,
actual: key.len(),
})?;
let iv: &[u8; 4] = iv
.try_into()
.map_err(|_| StreamCipherError::InvalidIvLength {
expected: 4,
actual: iv.len(),
})?;
let explicit_nonce: &[u8; 8] = explicit_nonce.try_into().map_err(|_| {
StreamCipherError::InvalidExplicitNonceLength {
expected: 8,
actual: explicit_nonce.len(),
}
})?;

let mut full_iv = [0u8; 16];
full_iv[0..4].copy_from_slice(iv);
full_iv[4..12].copy_from_slice(explicit_nonce);
Expand All @@ -103,6 +122,6 @@ impl CtrCircuit for Aes128Ctr {
.expect("start counter is less than keystream length");
cipher.apply_keystream(&mut buf);

buf
Ok(buf)
}
}
37 changes: 28 additions & 9 deletions components/cipher/stream-cipher/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,10 @@ pub enum StreamCipherError {
VerifyError(#[from] mpz_garble::VerifyError),
#[error("key and iv is not set")]
KeyIvNotSet,
#[error("invalid key length: expected {expected}, got {actual}")]
InvalidKeyLength { expected: usize, actual: usize },
#[error("invalid iv length: expected {expected}, got {actual}")]
InvalidIvLength { expected: usize, actual: usize },
#[error("invalid explicit nonce length: expected {expected}, got {actual}")]
InvalidExplicitNonceLength { expected: usize, actual: usize },
#[error("missing value for {0}")]
Expand All @@ -57,6 +61,12 @@ where
/// Sets the key and iv for the stream cipher.
fn set_key(&mut self, key: ValueRef, iv: ValueRef);

/// Decodes the key for the stream cipher, revealing it to this party.
async fn decode_key_private(&mut self) -> Result<(), StreamCipherError>;

/// Decodes the key for the stream cipher, revealing it to the other party(s).
async fn decode_key_blind(&mut self) -> Result<(), StreamCipherError>;

/// Sets the transcript id
///
/// The stream cipher assigns unique identifiers to each byte of plaintext
Expand Down Expand Up @@ -149,17 +159,23 @@ where
ciphertext: Vec<u8>,
) -> Result<(), StreamCipherError>;

/// Privately proves to the other party(s) the plaintext encrypts to a certain ciphertext.
/// Locally decrypts the provided ciphertext and then proves in ZK to the other party(s) that the
/// plaintext is correct.
///
/// Returns the plaintext.
///
/// This method requires this party to know the encryption key, which can be achieved by calling
/// the `decode_key_private` method.
///
/// # Arguments
///
/// * `explicit_nonce`: The explicit nonce to use for the keystream.
/// * `plaintext`: The plaintext to prove.
/// * `ciphertext`: The ciphertext to decrypt and prove.
async fn prove_plaintext(
&mut self,
explicit_nonce: Vec<u8>,
plaintext: Vec<u8>,
) -> Result<(), StreamCipherError>;
ciphertext: Vec<u8>,
) -> Result<Vec<u8>, StreamCipherError>;

/// Verifies the other party(s) can prove they know a plaintext which encrypts to the given ciphertext.
///
Expand Down Expand Up @@ -306,7 +322,7 @@ mod tests {
(follower_encrypted_msg, follower_decrypted_msg),
) = futures::join!(leader_fut, follower_fut);

let reference = Aes128Ctr::apply_keystream(&key, &iv, 1, &explicit_nonce, &msg);
let reference = Aes128Ctr::apply_keystream(&key, &iv, 1, &explicit_nonce, &msg).unwrap();

assert_eq!(leader_encrypted_msg, reference);
assert_eq!(leader_decrypted_msg, msg);
Expand All @@ -324,7 +340,7 @@ mod tests {

let msg = b"This is a test message which will be encrypted using AES-CTR.".to_vec();

let ciphertext = Aes128Ctr::apply_keystream(&key, &iv, 1, &explicit_nonce, &msg);
let ciphertext = Aes128Ctr::apply_keystream(&key, &iv, 1, &explicit_nonce, &msg).unwrap();

let ((mut leader, mut follower), (mut leader_vm, mut follower_vm)) =
create_test_pair::<Aes128Ctr>(1, key, iv, 8).await;
Expand Down Expand Up @@ -398,7 +414,8 @@ mod tests {
.map(|(a, b)| a ^ b)
.collect::<Vec<u8>>();

let reference = Aes128Ctr::apply_keystream(&key, &iv, 1, &explicit_nonce, &[0u8; 16]);
let reference =
Aes128Ctr::apply_keystream(&key, &iv, 1, &explicit_nonce, &[0u8; 16]).unwrap();

assert_eq!(reference, key_block);
}
Expand All @@ -413,13 +430,15 @@ mod tests {

let msg = b"This is a test message which will be encrypted using AES-CTR.".to_vec();

let ciphertext = Aes128Ctr::apply_keystream(&key, &iv, 2, &explicit_nonce, &msg);
let ciphertext = Aes128Ctr::apply_keystream(&key, &iv, 2, &explicit_nonce, &msg).unwrap();

let ((mut leader, mut follower), (mut leader_vm, mut follower_vm)) =
create_test_pair::<Aes128Ctr>(2, key, iv, 8).await;

futures::try_join!(leader.decode_key_private(), follower.decode_key_blind()).unwrap();

futures::try_join!(
leader.prove_plaintext(explicit_nonce.to_vec(), msg),
leader.prove_plaintext(explicit_nonce.to_vec(), ciphertext.clone()),
follower.verify_plaintext(explicit_nonce.to_vec(), ciphertext)
)
.unwrap();
Expand Down
Loading

0 comments on commit 9169088

Please sign in to comment.