From 9b400b1e7661a414bfd4dda78cb4ec6f6337813a Mon Sep 17 00:00:00 2001 From: Wassilios Lytras Date: Thu, 2 May 2024 16:49:15 +0200 Subject: [PATCH 1/2] Extending partner to accept key encryption algo and pass that down. * Fix github action build fail due to: https://stackoverflow.com/questions/71673404/importerror-cannot-import-name-unicodefun-from-click * Added partner setting to force canonicalize binary. * Formatted with black * https://github.com/abhishek-ram/pyas2-lib/issues/62 * Asserting error messages and _encrypted_data_with_faulty_key_algo --- pyas2lib/as2.py | 11 ++++ pyas2lib/cms.py | 101 +++++++++++++++++++------------- pyas2lib/constants.py | 4 ++ pyas2lib/tests/test_advanced.py | 3 + pyas2lib/tests/test_cms.py | 97 +++++++++++++++++++++++++++++- 5 files changed, 171 insertions(+), 45 deletions(-) diff --git a/pyas2lib/as2.py b/pyas2lib/as2.py index 338084d..cb320a3 100644 --- a/pyas2lib/as2.py +++ b/pyas2lib/as2.py @@ -25,6 +25,7 @@ DIGEST_ALGORITHMS, EDIINT_FEATURES, ENCRYPTION_ALGORITHMS, + KEY_ENCRYPTION_ALGORITHMS, MDN_CONFIRM_TEXT, MDN_FAILED_TEXT, MDN_MODES, @@ -183,6 +184,9 @@ class Partner: :param sign_alg: The signing algorithm to be used for generating the signature. (default `rsassa_pkcs1v15`) + :param key_enc_alg: The key encryption algorithm to be used. + (default `rsaes_pkcs1v15`) + """ as2_name: str @@ -202,6 +206,7 @@ class Partner: ignore_self_signed: bool = True canonicalize_as_binary: bool = False sign_alg: str = "rsassa_pkcs1v15" + key_enc_alg: str = "rsaes_pkcs1v15" def __post_init__(self): """Run the post initialisation checks for this class.""" @@ -236,6 +241,12 @@ def __post_init__(self): f"must be one of {SIGNATUR_ALGORITHMS}" ) + if self.key_enc_alg and self.key_enc_alg not in KEY_ENCRYPTION_ALGORITHMS: + raise ImproperlyConfigured( + f"Unsupported Key Encryption Algorithm {self.key_enc_alg}, " + f"must be one of {KEY_ENCRYPTION_ALGORITHMS}" + ) + def load_verify_cert(self): """Load the verification certificate of the partner and returned the parsed cert.""" if self.validate_certs: diff --git a/pyas2lib/cms.py b/pyas2lib/cms.py index 0172980..3266ef0 100644 --- a/pyas2lib/cms.py +++ b/pyas2lib/cms.py @@ -65,12 +65,15 @@ def decompress_message(compressed_data): raise DecompressionError("Decompression failed with cause: {}".format(e)) from e -def encrypt_message(data_to_encrypt, enc_alg, encryption_cert): +def encrypt_message( + data_to_encrypt, enc_alg, encryption_cert, key_enc_alg="rsaes_pkcs1v15" +): """Function encrypts data and returns the generated ASN.1 :param data_to_encrypt: A byte string of the data to be encrypted :param enc_alg: The algorithm to be used for encrypting the data :param encryption_cert: The certificate to be used for encrypting the data + :param key_enc_alg: The algo for the key encryption: rsaes_pkcs1v15 (default) or rsaes_oaep :return: A CMS ASN.1 byte string of the encrypted data. """ @@ -136,7 +139,12 @@ def encrypt_message(data_to_encrypt, enc_alg, encryption_cert): raise AS2Exception("Unsupported Encryption Algorithm") # Encrypt the key and build the ASN.1 message - encrypted_key = asymmetric.rsa_pkcs1v15_encrypt(encryption_cert, key) + if key_enc_alg == "rsaes_pkcs1v15": + encrypted_key = asymmetric.rsa_pkcs1v15_encrypt(encryption_cert, key) + elif key_enc_alg == "rsaes_oaep": + encrypted_key = asymmetric.rsa_oaep_encrypt(encryption_cert, key) + else: + raise AS2Exception(f"Unsupported Key Encryption Scheme: {key_enc_alg}") return cms.ContentInfo( { @@ -163,7 +171,11 @@ def encrypt_message(data_to_encrypt, enc_alg, encryption_cert): } ), "key_encryption_algorithm": cms.KeyEncryptionAlgorithm( - {"algorithm": cms.KeyEncryptionAlgorithmId("rsa")} + { + "algorithm": cms.KeyEncryptionAlgorithmId( + key_enc_alg + ) + } ), "encrypted_key": cms.OctetString(encrypted_key), } @@ -199,47 +211,52 @@ def decrypt_message(encrypted_data, decryption_key): key_enc_alg = recipient_info["key_encryption_algorithm"]["algorithm"].native encrypted_key = recipient_info["encrypted_key"].native - if cms.KeyEncryptionAlgorithmId(key_enc_alg) == cms.KeyEncryptionAlgorithmId( - "rsa" - ): - try: + try: + if cms.KeyEncryptionAlgorithmId( + key_enc_alg + ) == cms.KeyEncryptionAlgorithmId("rsaes_pkcs1v15"): key = asymmetric.rsa_pkcs1v15_decrypt(decryption_key[0], encrypted_key) - except Exception as e: - raise DecryptionError( - "Failed to decrypt the payload: Could not extract decryption key." - ) from e - - alg = cms_content["content"]["encrypted_content_info"][ - "content_encryption_algorithm" - ] - encapsulated_data = cms_content["content"]["encrypted_content_info"][ - "encrypted_content" - ].native - try: - if alg["algorithm"].native == "rc4": - decrypted_content = symmetric.rc4_decrypt(key, encapsulated_data) - elif alg.encryption_cipher == "tripledes": - cipher = "tripledes_192_cbc" - decrypted_content = symmetric.tripledes_cbc_pkcs5_decrypt( - key, encapsulated_data, alg.encryption_iv - ) - elif alg.encryption_cipher == "aes": - decrypted_content = symmetric.aes_cbc_pkcs7_decrypt( - key, encapsulated_data, alg.encryption_iv - ) - elif alg.encryption_cipher == "rc2": - decrypted_content = symmetric.rc2_cbc_pkcs5_decrypt( - key, encapsulated_data, alg["parameters"]["iv"].native - ) - else: - raise AS2Exception("Unsupported Encryption Algorithm") - except Exception as e: - raise DecryptionError( - "Failed to decrypt the payload: {}".format(e) - ) from e - else: - raise AS2Exception("Unsupported Encryption Algorithm") + elif cms.KeyEncryptionAlgorithmId( + key_enc_alg + ) == cms.KeyEncryptionAlgorithmId("rsaes_oaep"): + key = asymmetric.rsa_oaep_decrypt(decryption_key[0], encrypted_key) + else: + raise AS2Exception( + f"Unsupported Key Encryption Algorithm {key_enc_alg}" + ) + except Exception as e: + raise DecryptionError( + "Failed to decrypt the payload: Could not extract decryption key." + ) from e + + alg = cms_content["content"]["encrypted_content_info"][ + "content_encryption_algorithm" + ] + encapsulated_data = cms_content["content"]["encrypted_content_info"][ + "encrypted_content" + ].native + + try: + if alg["algorithm"].native == "rc4": + decrypted_content = symmetric.rc4_decrypt(key, encapsulated_data) + elif alg.encryption_cipher == "tripledes": + cipher = "tripledes_192_cbc" + decrypted_content = symmetric.tripledes_cbc_pkcs5_decrypt( + key, encapsulated_data, alg.encryption_iv + ) + elif alg.encryption_cipher == "aes": + decrypted_content = symmetric.aes_cbc_pkcs7_decrypt( + key, encapsulated_data, alg.encryption_iv + ) + elif alg.encryption_cipher == "rc2": + decrypted_content = symmetric.rc2_cbc_pkcs5_decrypt( + key, encapsulated_data, alg["parameters"]["iv"].native + ) + else: + raise AS2Exception("Unsupported Encryption Algorithm") + except Exception as e: + raise DecryptionError("Failed to decrypt the payload: {}".format(e)) from e else: raise DecryptionError("Encrypted data not found in ASN.1 ") diff --git a/pyas2lib/constants.py b/pyas2lib/constants.py index b5d4de2..f5456f7 100644 --- a/pyas2lib/constants.py +++ b/pyas2lib/constants.py @@ -32,3 +32,7 @@ "rsassa_pkcs1v15", "rsassa_pss", ) +KEY_ENCRYPTION_ALGORITHMS = ( + "rsaes_pkcs1v15", + "rsaes_oaep", +) diff --git a/pyas2lib/tests/test_advanced.py b/pyas2lib/tests/test_advanced.py index f010ecd..fefda8d 100644 --- a/pyas2lib/tests/test_advanced.py +++ b/pyas2lib/tests/test_advanced.py @@ -337,6 +337,9 @@ def test_partner_checks(self): with self.assertRaises(ImproperlyConfigured): as2.Partner("a partner", sign_alg="xyz") + with self.assertRaises(ImproperlyConfigured): + as2.Partner("a partner", key_enc_alg="xyz") + def test_message_checks(self): """Test the checks and other features of Message.""" msg = as2.Message() diff --git a/pyas2lib/tests/test_cms.py b/pyas2lib/tests/test_cms.py index 34655db..cf2f711 100644 --- a/pyas2lib/tests/test_cms.py +++ b/pyas2lib/tests/test_cms.py @@ -2,7 +2,9 @@ import os import pytest -from oscrypto import asymmetric +from oscrypto import asymmetric, symmetric, util + +from asn1crypto import algos, cms as crypto_cms, core from pyas2lib.as2 import Organization from pyas2lib import cms @@ -22,6 +24,68 @@ ).dump() +def _encrypted_data_with_faulty_key_algo(): + with open(os.path.join(TEST_DIR, "cert_test_public.pem"), "rb") as fp: + encrypt_cert = asymmetric.load_certificate(fp.read()) + enc_alg_list = "rc4_128_cbc".split("_") + cipher, key_length, _ = enc_alg_list[0], enc_alg_list[1], enc_alg_list[2] + key = util.rand_bytes(int(key_length) // 8) + algorithm_id = "1.2.840.113549.3.4" + encrypted_content = symmetric.rc4_encrypt(key, b"data") + enc_alg_asn1 = algos.EncryptionAlgorithm( + { + "algorithm": algorithm_id, + } + ) + encrypted_key = asymmetric.rsa_oaep_encrypt(encrypt_cert, key) + return crypto_cms.ContentInfo( + { + "content_type": crypto_cms.ContentType("enveloped_data"), + "content": crypto_cms.EnvelopedData( + { + "version": crypto_cms.CMSVersion("v0"), + "recipient_infos": [ + crypto_cms.KeyTransRecipientInfo( + { + "version": crypto_cms.CMSVersion("v0"), + "rid": crypto_cms.RecipientIdentifier( + { + "issuer_and_serial_number": crypto_cms.IssuerAndSerialNumber( + { + "issuer": encrypt_cert.asn1[ + "tbs_certificate" + ]["issuer"], + "serial_number": encrypt_cert.asn1[ + "tbs_certificate" + ]["serial_number"], + } + ) + } + ), + "key_encryption_algorithm": crypto_cms.KeyEncryptionAlgorithm( + { + "algorithm": crypto_cms.KeyEncryptionAlgorithmId( + "aes128_wrap" + ) + } + ), + "encrypted_key": crypto_cms.OctetString(encrypted_key), + } + ) + ], + "encrypted_content_info": crypto_cms.EncryptedContentInfo( + { + "content_type": crypto_cms.ContentType("data"), + "content_encryption_algorithm": enc_alg_asn1, + "encrypted_content": encrypted_content, + } + ), + } + ), + } + ).dump() + + def test_compress(): """Test the compression and decompression functions.""" compressed_data = cms.compress_message(b"data") @@ -87,9 +151,22 @@ def test_encryption(): "aes_128_cbc", "aes_192_cbc", "aes_256_cbc", + "tripledes_192_cbc", + ] + + key_enc_algos = [ + "rsaes_oaep", + "rsaes_pkcs1v15", + ] + + encryption_algos = [ + (alg, key_algo) for alg in enc_algorithms for key_algo in key_enc_algos ] - for enc_algorithm in enc_algorithms: - encrypted_data = cms.encrypt_message(b"data", enc_algorithm, encrypt_cert) + + for enc_algorithm, encryption_scheme in encryption_algos: + encrypted_data = cms.encrypt_message( + b"data", enc_algorithm, encrypt_cert, encryption_scheme + ) _, decrypted_data = cms.decrypt_message(encrypted_data, decrypt_key) assert decrypted_data == b"data" @@ -101,3 +178,17 @@ def test_encryption(): encrypted_data = cms.encrypt_message(b"data", "des_64_cbc", encrypt_cert) with pytest.raises(AS2Exception): cms.decrypt_message(encrypted_data, decrypt_key) + + # Test faulty key encryption algorithm + with pytest.raises( + AS2Exception, match="Unsupported Key Encryption Scheme: des_64_cbc" + ): + cms.encrypt_message(b"data", "rc2_128_cbc", encrypt_cert, "des_64_cbc") + + # Test unsupported key encryption algorithm + encrypted_data = _encrypted_data_with_faulty_key_algo() + with pytest.raises( + AS2Exception, + match="Failed to decrypt the payload: Could not extract decryption key.", + ): + cms.decrypt_message(encrypted_data, decrypt_key) From 69b13d5801c598673870925a3c419cd95bc3f51e Mon Sep 17 00:00:00 2001 From: Wassilios Lytras Date: Fri, 3 May 2024 10:00:30 +0200 Subject: [PATCH 2/2] Add specific error when MDN received, but Original Message was not found. Related to https://github.com/abhishek-ram/django-pyas2/issues/45 and will be implemented/used in django-pyas2. --- pyas2lib/as2.py | 5 +++++ pyas2lib/tests/test_advanced.py | 27 +++++++++++++++++++++++++++ 2 files changed, 32 insertions(+) diff --git a/pyas2lib/as2.py b/pyas2lib/as2.py index cb320a3..9c74e54 100644 --- a/pyas2lib/as2.py +++ b/pyas2lib/as2.py @@ -943,6 +943,11 @@ def parse(self, raw_content, find_message_cb): # Call the find message callback which should return a Message instance orig_message = find_message_cb(self.orig_message_id, orig_recipient) + if not orig_message: + status = "failed/Failure" + details_status = "original-message-not-found" + return status, details_status + # Extract the headers and save it mdn_headers = {} for k, v in self.payload.items(): diff --git a/pyas2lib/tests/test_advanced.py b/pyas2lib/tests/test_advanced.py index fefda8d..a0aca23 100644 --- a/pyas2lib/tests/test_advanced.py +++ b/pyas2lib/tests/test_advanced.py @@ -396,6 +396,33 @@ def test_mdn_not_found(self): self.assertEqual(status, "failed/Failure") self.assertEqual(detailed_status, "mdn-not-found") + def test_mdn_original_message_not_found(self): + """Test that the MDN parser raises MDN not found when a non MDN message is passed.""" + self.partner.mdn_mode = as2.SYNCHRONOUS_MDN + self.out_message = as2.Message(self.org, self.partner) + self.out_message.build(self.test_data) + + # Parse the generated AS2 message as the partner + raw_out_message = ( + self.out_message.headers_str + b"\r\n" + self.out_message.content + ) + in_message = as2.Message() + _, _, mdn = in_message.parse( + raw_out_message, + find_org_cb=self.find_org, + find_partner_cb=self.find_partner, + find_message_cb=lambda x, y: False, + ) + + # Parse the MDN + out_mdn = as2.Mdn() + status, detailed_status = out_mdn.parse( + mdn.headers_str + b"\r\n" + mdn.content, find_message_cb=lambda x, y: False + ) + + self.assertEqual(status, "failed/Failure") + self.assertEqual(detailed_status, "original-message-not-found") + def test_unsigned_mdn_sent_error(self): """Test the case where a signed mdn was expected but unsigned mdn was returned.""" self.partner.mdn_mode = as2.SYNCHRONOUS_MDN