diff --git a/pyas2lib/as2.py b/pyas2lib/as2.py index 338084d..cb320a3 100644 --- a/pyas2lib/as2.py +++ b/pyas2lib/as2.py @@ -25,6 +25,7 @@ DIGEST_ALGORITHMS, EDIINT_FEATURES, ENCRYPTION_ALGORITHMS, + KEY_ENCRYPTION_ALGORITHMS, MDN_CONFIRM_TEXT, MDN_FAILED_TEXT, MDN_MODES, @@ -183,6 +184,9 @@ class Partner: :param sign_alg: The signing algorithm to be used for generating the signature. (default `rsassa_pkcs1v15`) + :param key_enc_alg: The key encryption algorithm to be used. + (default `rsaes_pkcs1v15`) + """ as2_name: str @@ -202,6 +206,7 @@ class Partner: ignore_self_signed: bool = True canonicalize_as_binary: bool = False sign_alg: str = "rsassa_pkcs1v15" + key_enc_alg: str = "rsaes_pkcs1v15" def __post_init__(self): """Run the post initialisation checks for this class.""" @@ -236,6 +241,12 @@ def __post_init__(self): f"must be one of {SIGNATUR_ALGORITHMS}" ) + if self.key_enc_alg and self.key_enc_alg not in KEY_ENCRYPTION_ALGORITHMS: + raise ImproperlyConfigured( + f"Unsupported Key Encryption Algorithm {self.key_enc_alg}, " + f"must be one of {KEY_ENCRYPTION_ALGORITHMS}" + ) + def load_verify_cert(self): """Load the verification certificate of the partner and returned the parsed cert.""" if self.validate_certs: diff --git a/pyas2lib/cms.py b/pyas2lib/cms.py index 0172980..3266ef0 100644 --- a/pyas2lib/cms.py +++ b/pyas2lib/cms.py @@ -65,12 +65,15 @@ def decompress_message(compressed_data): raise DecompressionError("Decompression failed with cause: {}".format(e)) from e -def encrypt_message(data_to_encrypt, enc_alg, encryption_cert): +def encrypt_message( + data_to_encrypt, enc_alg, encryption_cert, key_enc_alg="rsaes_pkcs1v15" +): """Function encrypts data and returns the generated ASN.1 :param data_to_encrypt: A byte string of the data to be encrypted :param enc_alg: The algorithm to be used for encrypting the data :param encryption_cert: The certificate to be used for encrypting the data + :param key_enc_alg: The algo for the key encryption: rsaes_pkcs1v15 (default) or rsaes_oaep :return: A CMS ASN.1 byte string of the encrypted data. """ @@ -136,7 +139,12 @@ def encrypt_message(data_to_encrypt, enc_alg, encryption_cert): raise AS2Exception("Unsupported Encryption Algorithm") # Encrypt the key and build the ASN.1 message - encrypted_key = asymmetric.rsa_pkcs1v15_encrypt(encryption_cert, key) + if key_enc_alg == "rsaes_pkcs1v15": + encrypted_key = asymmetric.rsa_pkcs1v15_encrypt(encryption_cert, key) + elif key_enc_alg == "rsaes_oaep": + encrypted_key = asymmetric.rsa_oaep_encrypt(encryption_cert, key) + else: + raise AS2Exception(f"Unsupported Key Encryption Scheme: {key_enc_alg}") return cms.ContentInfo( { @@ -163,7 +171,11 @@ def encrypt_message(data_to_encrypt, enc_alg, encryption_cert): } ), "key_encryption_algorithm": cms.KeyEncryptionAlgorithm( - {"algorithm": cms.KeyEncryptionAlgorithmId("rsa")} + { + "algorithm": cms.KeyEncryptionAlgorithmId( + key_enc_alg + ) + } ), "encrypted_key": cms.OctetString(encrypted_key), } @@ -199,47 +211,52 @@ def decrypt_message(encrypted_data, decryption_key): key_enc_alg = recipient_info["key_encryption_algorithm"]["algorithm"].native encrypted_key = recipient_info["encrypted_key"].native - if cms.KeyEncryptionAlgorithmId(key_enc_alg) == cms.KeyEncryptionAlgorithmId( - "rsa" - ): - try: + try: + if cms.KeyEncryptionAlgorithmId( + key_enc_alg + ) == cms.KeyEncryptionAlgorithmId("rsaes_pkcs1v15"): key = asymmetric.rsa_pkcs1v15_decrypt(decryption_key[0], encrypted_key) - except Exception as e: - raise DecryptionError( - "Failed to decrypt the payload: Could not extract decryption key." - ) from e - - alg = cms_content["content"]["encrypted_content_info"][ - "content_encryption_algorithm" - ] - encapsulated_data = cms_content["content"]["encrypted_content_info"][ - "encrypted_content" - ].native - try: - if alg["algorithm"].native == "rc4": - decrypted_content = symmetric.rc4_decrypt(key, encapsulated_data) - elif alg.encryption_cipher == "tripledes": - cipher = "tripledes_192_cbc" - decrypted_content = symmetric.tripledes_cbc_pkcs5_decrypt( - key, encapsulated_data, alg.encryption_iv - ) - elif alg.encryption_cipher == "aes": - decrypted_content = symmetric.aes_cbc_pkcs7_decrypt( - key, encapsulated_data, alg.encryption_iv - ) - elif alg.encryption_cipher == "rc2": - decrypted_content = symmetric.rc2_cbc_pkcs5_decrypt( - key, encapsulated_data, alg["parameters"]["iv"].native - ) - else: - raise AS2Exception("Unsupported Encryption Algorithm") - except Exception as e: - raise DecryptionError( - "Failed to decrypt the payload: {}".format(e) - ) from e - else: - raise AS2Exception("Unsupported Encryption Algorithm") + elif cms.KeyEncryptionAlgorithmId( + key_enc_alg + ) == cms.KeyEncryptionAlgorithmId("rsaes_oaep"): + key = asymmetric.rsa_oaep_decrypt(decryption_key[0], encrypted_key) + else: + raise AS2Exception( + f"Unsupported Key Encryption Algorithm {key_enc_alg}" + ) + except Exception as e: + raise DecryptionError( + "Failed to decrypt the payload: Could not extract decryption key." + ) from e + + alg = cms_content["content"]["encrypted_content_info"][ + "content_encryption_algorithm" + ] + encapsulated_data = cms_content["content"]["encrypted_content_info"][ + "encrypted_content" + ].native + + try: + if alg["algorithm"].native == "rc4": + decrypted_content = symmetric.rc4_decrypt(key, encapsulated_data) + elif alg.encryption_cipher == "tripledes": + cipher = "tripledes_192_cbc" + decrypted_content = symmetric.tripledes_cbc_pkcs5_decrypt( + key, encapsulated_data, alg.encryption_iv + ) + elif alg.encryption_cipher == "aes": + decrypted_content = symmetric.aes_cbc_pkcs7_decrypt( + key, encapsulated_data, alg.encryption_iv + ) + elif alg.encryption_cipher == "rc2": + decrypted_content = symmetric.rc2_cbc_pkcs5_decrypt( + key, encapsulated_data, alg["parameters"]["iv"].native + ) + else: + raise AS2Exception("Unsupported Encryption Algorithm") + except Exception as e: + raise DecryptionError("Failed to decrypt the payload: {}".format(e)) from e else: raise DecryptionError("Encrypted data not found in ASN.1 ") diff --git a/pyas2lib/constants.py b/pyas2lib/constants.py index b5d4de2..f5456f7 100644 --- a/pyas2lib/constants.py +++ b/pyas2lib/constants.py @@ -32,3 +32,7 @@ "rsassa_pkcs1v15", "rsassa_pss", ) +KEY_ENCRYPTION_ALGORITHMS = ( + "rsaes_pkcs1v15", + "rsaes_oaep", +) diff --git a/pyas2lib/tests/test_advanced.py b/pyas2lib/tests/test_advanced.py index f010ecd..fefda8d 100644 --- a/pyas2lib/tests/test_advanced.py +++ b/pyas2lib/tests/test_advanced.py @@ -337,6 +337,9 @@ def test_partner_checks(self): with self.assertRaises(ImproperlyConfigured): as2.Partner("a partner", sign_alg="xyz") + with self.assertRaises(ImproperlyConfigured): + as2.Partner("a partner", key_enc_alg="xyz") + def test_message_checks(self): """Test the checks and other features of Message.""" msg = as2.Message() diff --git a/pyas2lib/tests/test_cms.py b/pyas2lib/tests/test_cms.py index 34655db..fa3f596 100644 --- a/pyas2lib/tests/test_cms.py +++ b/pyas2lib/tests/test_cms.py @@ -2,7 +2,9 @@ import os import pytest -from oscrypto import asymmetric +from oscrypto import asymmetric, symmetric, util + +from asn1crypto import algos, cms as crypto_cms, core from pyas2lib.as2 import Organization from pyas2lib import cms @@ -22,6 +24,68 @@ ).dump() +def encrypted_data_with_faulty_key_algo(): + with open(os.path.join(TEST_DIR, "cert_test_public.pem"), "rb") as fp: + encrypt_cert = asymmetric.load_certificate(fp.read()) + enc_alg_list = "rc4_128_cbc".split("_") + cipher, key_length, _ = enc_alg_list[0], enc_alg_list[1], enc_alg_list[2] + key = util.rand_bytes(int(key_length) // 8) + algorithm_id = "1.2.840.113549.3.4" + encrypted_content = symmetric.rc4_encrypt(key, b"data") + enc_alg_asn1 = algos.EncryptionAlgorithm( + { + "algorithm": algorithm_id, + } + ) + encrypted_key = asymmetric.rsa_oaep_encrypt(encrypt_cert, key) + return crypto_cms.ContentInfo( + { + "content_type": crypto_cms.ContentType("enveloped_data"), + "content": crypto_cms.EnvelopedData( + { + "version": crypto_cms.CMSVersion("v0"), + "recipient_infos": [ + crypto_cms.KeyTransRecipientInfo( + { + "version": crypto_cms.CMSVersion("v0"), + "rid": crypto_cms.RecipientIdentifier( + { + "issuer_and_serial_number": crypto_cms.IssuerAndSerialNumber( + { + "issuer": encrypt_cert.asn1[ + "tbs_certificate" + ]["issuer"], + "serial_number": encrypt_cert.asn1[ + "tbs_certificate" + ]["serial_number"], + } + ) + } + ), + "key_encryption_algorithm": crypto_cms.KeyEncryptionAlgorithm( + { + "algorithm": crypto_cms.KeyEncryptionAlgorithmId( + "aes128_wrap" + ) + } + ), + "encrypted_key": crypto_cms.OctetString(encrypted_key), + } + ) + ], + "encrypted_content_info": crypto_cms.EncryptedContentInfo( + { + "content_type": crypto_cms.ContentType("data"), + "content_encryption_algorithm": enc_alg_asn1, + "encrypted_content": encrypted_content, + } + ), + } + ), + } + ).dump() + + def test_compress(): """Test the compression and decompression functions.""" compressed_data = cms.compress_message(b"data") @@ -87,9 +151,22 @@ def test_encryption(): "aes_128_cbc", "aes_192_cbc", "aes_256_cbc", + "tripledes_192_cbc", + ] + + key_enc_algos = [ + "rsaes_oaep", + "rsaes_pkcs1v15", ] - for enc_algorithm in enc_algorithms: - encrypted_data = cms.encrypt_message(b"data", enc_algorithm, encrypt_cert) + + encryption_algos = [ + (alg, key_algo) for alg in enc_algorithms for key_algo in key_enc_algos + ] + + for enc_algorithm, encryption_scheme in encryption_algos: + encrypted_data = cms.encrypt_message( + b"data", enc_algorithm, encrypt_cert, encryption_scheme + ) _, decrypted_data = cms.decrypt_message(encrypted_data, decrypt_key) assert decrypted_data == b"data" @@ -101,3 +178,12 @@ def test_encryption(): encrypted_data = cms.encrypt_message(b"data", "des_64_cbc", encrypt_cert) with pytest.raises(AS2Exception): cms.decrypt_message(encrypted_data, decrypt_key) + + # Test faulty key encryption algorithm + with pytest.raises(AS2Exception): + cms.encrypt_message(b"data", "rc2_128_cbc", encrypt_cert, "des_64_cbc") + + # Test unsupported key encryption algorithm + encrypted_data = encrypted_data_with_faulty_key_algo() + with pytest.raises(AS2Exception): + cms.decrypt_message(encrypted_data, decrypt_key)