From c2923023aa9bb31ffdcd8fffb81a35a21d0fa967 Mon Sep 17 00:00:00 2001 From: Wassilios Lytras Date: Tue, 19 Mar 2024 17:13:36 +0100 Subject: [PATCH] Adding Tests --- pyas2lib/cms.py | 34 +++++++-------- pyas2lib/tests/test_advanced.py | 27 ++++++++++++ pyas2lib/tests/test_async.py | 35 +++++++++++++++- pyas2lib/tests/test_cms.py | 74 ++++++++++++++++++++++++++++++++- 4 files changed, 149 insertions(+), 21 deletions(-) diff --git a/pyas2lib/cms.py b/pyas2lib/cms.py index db3d148..8366bb6 100644 --- a/pyas2lib/cms.py +++ b/pyas2lib/cms.py @@ -211,26 +211,24 @@ def decrypt_message(encrypted_data, decryption_key): key_enc_alg = recipient_info["key_encryption_algorithm"]["algorithm"].native encrypted_key = recipient_info["encrypted_key"].native - if cms.KeyEncryptionAlgorithmId(key_enc_alg) == cms.KeyEncryptionAlgorithmId( - "rsaes_pkcs1v15" - ): - try: + try: + if cms.KeyEncryptionAlgorithmId( + key_enc_alg + ) == cms.KeyEncryptionAlgorithmId("rsaes_pkcs1v15"): key = asymmetric.rsa_pkcs1v15_decrypt(decryption_key[0], encrypted_key) - except Exception as e: - raise DecryptionError( - "Failed to decrypt the payload: Could not extract decryption key." - ) from e - elif cms.KeyEncryptionAlgorithmId(key_enc_alg) == cms.KeyEncryptionAlgorithmId( - "rsaes_oaep" - ): - try: + + elif cms.KeyEncryptionAlgorithmId( + key_enc_alg + ) == cms.KeyEncryptionAlgorithmId("rsaes_oaep"): key = asymmetric.rsa_oaep_decrypt(decryption_key[0], encrypted_key) - except Exception as e: - raise DecryptionError( - "Failed to decrypt the payload: Could not extract decryption key." - ) from e - else: - raise AS2Exception(f"Unsupported Key Encryption Algorithm {key_enc_alg}") + else: + raise AS2Exception( + f"Unsupported Key Encryption Algorithm {key_enc_alg}" + ) + except Exception as e: + raise DecryptionError( + "Failed to decrypt the payload: Could not extract decryption key." + ) from e alg = cms_content["content"]["encrypted_content_info"][ "content_encryption_algorithm" diff --git a/pyas2lib/tests/test_advanced.py b/pyas2lib/tests/test_advanced.py index 0bc6db7..836c90b 100644 --- a/pyas2lib/tests/test_advanced.py +++ b/pyas2lib/tests/test_advanced.py @@ -390,6 +390,33 @@ def test_mdn_not_found(self): self.assertEqual(status, "failed/Failure") self.assertEqual(detailed_status, "mdn-not-found") + def test_mdn_original_message_not_found(self): + """Test that the MDN parser raises MDN not found when a non MDN message is passed.""" + self.partner.mdn_mode = as2.SYNCHRONOUS_MDN + self.out_message = as2.Message(self.org, self.partner) + self.out_message.build(self.test_data) + + # Parse the generated AS2 message as the partner + raw_out_message = ( + self.out_message.headers_str + b"\r\n" + self.out_message.content + ) + in_message = as2.Message() + _, _, mdn = in_message.parse( + raw_out_message, + find_org_cb=self.find_org, + find_partner_cb=self.find_partner, + find_message_cb=lambda x, y: False, + ) + + # Parse the MDN + out_mdn = as2.Mdn() + status, detailed_status = out_mdn.parse( + mdn.headers_str + b"\r\n" + mdn.content, find_message_cb=lambda x, y: False + ) + + self.assertEqual(status, "failed/Failure") + self.assertEqual(detailed_status, "original-message-not-found") + def test_unsigned_mdn_sent_error(self): """Test the case where a signed mdn was expected but unsigned mdn was returned.""" self.partner.mdn_mode = as2.SYNCHRONOUS_MDN diff --git a/pyas2lib/tests/test_async.py b/pyas2lib/tests/test_async.py index 57acb66..047e689 100644 --- a/pyas2lib/tests/test_async.py +++ b/pyas2lib/tests/test_async.py @@ -54,6 +54,9 @@ async def test_duplicate_message_async(): out_message = as2.Message(org, partner) out_message.build(test_data) + async def afind_message(message_id, message_recipient): + return out_message + # Parse the generated AS2 message as the partner raw_out_message = out_message.headers_str + b"\r\n" + out_message.content in_message = as2.Message() @@ -67,7 +70,7 @@ async def test_duplicate_message_async(): out_mdn = as2.Mdn() status, detailed_status = await out_mdn.aparse( mdn.headers_str + b"\r\n" + mdn.content, - find_message_cb=lambda x, y: out_message, + find_message_cb=afind_message, ) assert status == "processed/Warning" assert detailed_status == "duplicate-document" @@ -90,3 +93,33 @@ async def test_async_partnership(): # Compare contents of the input and output messages assert status == "processed" + + +@pytest.mark.asyncio +async def test_runtime_error(): + with pytest.raises(RuntimeError): + out_message = as2.Message(org, partner) + out_message.build(test_data) + raw_out_message = out_message.headers_str + b"\r\n" + out_message.content + + in_message = as2.Message() + status, _, _ = in_message.parse( + raw_out_message, find_org_partner_cb=afind_org_partner + ) + + with pytest.raises(RuntimeError): + partner.sign = True + partner.encrypt = True + partner.mdn_mode = as2.SYNCHRONOUS_MDN + out_message = as2.Message(org, partner) + out_message.build(test_data) + + # Parse the generated AS2 message as the partner + raw_out_message = out_message.headers_str + b"\r\n" + out_message.content + in_message = as2.Message() + _, _, mdn = in_message.parse( + raw_out_message, + find_org_cb=afind_org, + find_partner_cb=afind_partner, + find_message_cb=afind_duplicate_message, + ) diff --git a/pyas2lib/tests/test_cms.py b/pyas2lib/tests/test_cms.py index a324fa7..fa3f596 100644 --- a/pyas2lib/tests/test_cms.py +++ b/pyas2lib/tests/test_cms.py @@ -2,7 +2,9 @@ import os import pytest -from oscrypto import asymmetric +from oscrypto import asymmetric, symmetric, util + +from asn1crypto import algos, cms as crypto_cms, core from pyas2lib.as2 import Organization from pyas2lib import cms @@ -22,6 +24,68 @@ ).dump() +def encrypted_data_with_faulty_key_algo(): + with open(os.path.join(TEST_DIR, "cert_test_public.pem"), "rb") as fp: + encrypt_cert = asymmetric.load_certificate(fp.read()) + enc_alg_list = "rc4_128_cbc".split("_") + cipher, key_length, _ = enc_alg_list[0], enc_alg_list[1], enc_alg_list[2] + key = util.rand_bytes(int(key_length) // 8) + algorithm_id = "1.2.840.113549.3.4" + encrypted_content = symmetric.rc4_encrypt(key, b"data") + enc_alg_asn1 = algos.EncryptionAlgorithm( + { + "algorithm": algorithm_id, + } + ) + encrypted_key = asymmetric.rsa_oaep_encrypt(encrypt_cert, key) + return crypto_cms.ContentInfo( + { + "content_type": crypto_cms.ContentType("enveloped_data"), + "content": crypto_cms.EnvelopedData( + { + "version": crypto_cms.CMSVersion("v0"), + "recipient_infos": [ + crypto_cms.KeyTransRecipientInfo( + { + "version": crypto_cms.CMSVersion("v0"), + "rid": crypto_cms.RecipientIdentifier( + { + "issuer_and_serial_number": crypto_cms.IssuerAndSerialNumber( + { + "issuer": encrypt_cert.asn1[ + "tbs_certificate" + ]["issuer"], + "serial_number": encrypt_cert.asn1[ + "tbs_certificate" + ]["serial_number"], + } + ) + } + ), + "key_encryption_algorithm": crypto_cms.KeyEncryptionAlgorithm( + { + "algorithm": crypto_cms.KeyEncryptionAlgorithmId( + "aes128_wrap" + ) + } + ), + "encrypted_key": crypto_cms.OctetString(encrypted_key), + } + ) + ], + "encrypted_content_info": crypto_cms.EncryptedContentInfo( + { + "content_type": crypto_cms.ContentType("data"), + "content_encryption_algorithm": enc_alg_asn1, + "encrypted_content": encrypted_content, + } + ), + } + ), + } + ).dump() + + def test_compress(): """Test the compression and decompression functions.""" compressed_data = cms.compress_message(b"data") @@ -87,6 +151,7 @@ def test_encryption(): "aes_128_cbc", "aes_192_cbc", "aes_256_cbc", + "tripledes_192_cbc", ] key_enc_algos = [ @@ -95,7 +160,7 @@ def test_encryption(): ] encryption_algos = [ - (alg, scheme) for alg, scheme in zip(enc_algorithms, key_enc_algos) + (alg, key_algo) for alg in enc_algorithms for key_algo in key_enc_algos ] for enc_algorithm, encryption_scheme in encryption_algos: @@ -117,3 +182,8 @@ def test_encryption(): # Test faulty key encryption algorithm with pytest.raises(AS2Exception): cms.encrypt_message(b"data", "rc2_128_cbc", encrypt_cert, "des_64_cbc") + + # Test unsupported key encryption algorithm + encrypted_data = encrypted_data_with_faulty_key_algo() + with pytest.raises(AS2Exception): + cms.decrypt_message(encrypted_data, decrypt_key)