From 0b39518fcb687e3279d12815c5da840eae34f140 Mon Sep 17 00:00:00 2001 From: Wassilios Lytras Date: Mon, 18 Mar 2024 13:57:31 +0100 Subject: [PATCH] Adding key encryption algo option to support rsaes_oaep. --- pyas2lib/cms.py | 81 +++++++++++++++++++++++--------------- pyas2lib/tests/test_cms.py | 20 +++++++++- 2 files changed, 68 insertions(+), 33 deletions(-) diff --git a/pyas2lib/cms.py b/pyas2lib/cms.py index 0172980..8f95a11 100644 --- a/pyas2lib/cms.py +++ b/pyas2lib/cms.py @@ -65,12 +65,15 @@ def decompress_message(compressed_data): raise DecompressionError("Decompression failed with cause: {}".format(e)) from e -def encrypt_message(data_to_encrypt, enc_alg, encryption_cert): +def encrypt_message( + data_to_encrypt, enc_alg, encryption_cert, key_enc_alg="rsaes_pkcs1v15" +): """Function encrypts data and returns the generated ASN.1 :param data_to_encrypt: A byte string of the data to be encrypted :param enc_alg: The algorithm to be used for encrypting the data :param encryption_cert: The certificate to be used for encrypting the data + :param key_enc_alg: The encryption scheme for the key encryption: rsaes_pkcs1v15 (default) or rsaes_oaep :return: A CMS ASN.1 byte string of the encrypted data. """ @@ -136,7 +139,12 @@ def encrypt_message(data_to_encrypt, enc_alg, encryption_cert): raise AS2Exception("Unsupported Encryption Algorithm") # Encrypt the key and build the ASN.1 message - encrypted_key = asymmetric.rsa_pkcs1v15_encrypt(encryption_cert, key) + if key_enc_alg == "rsaes_pkcs1v15": + encrypted_key = asymmetric.rsa_pkcs1v15_encrypt(encryption_cert, key) + elif key_enc_alg == "rsaes_oaep": + encrypted_key = asymmetric.rsa_oaep_encrypt(encryption_cert, key) + else: + raise AS2Exception(f"Unsupported Key Encryption Scheme: {key_enc_alg}") return cms.ContentInfo( { @@ -163,7 +171,11 @@ def encrypt_message(data_to_encrypt, enc_alg, encryption_cert): } ), "key_encryption_algorithm": cms.KeyEncryptionAlgorithm( - {"algorithm": cms.KeyEncryptionAlgorithmId("rsa")} + { + "algorithm": cms.KeyEncryptionAlgorithmId( + key_enc_alg + ) + } ), "encrypted_key": cms.OctetString(encrypted_key), } @@ -200,7 +212,7 @@ def decrypt_message(encrypted_data, decryption_key): encrypted_key = recipient_info["encrypted_key"].native if cms.KeyEncryptionAlgorithmId(key_enc_alg) == cms.KeyEncryptionAlgorithmId( - "rsa" + "rsaes_pkcs1v15" ): try: key = asymmetric.rsa_pkcs1v15_decrypt(decryption_key[0], encrypted_key) @@ -208,38 +220,45 @@ def decrypt_message(encrypted_data, decryption_key): raise DecryptionError( "Failed to decrypt the payload: Could not extract decryption key." ) from e - - alg = cms_content["content"]["encrypted_content_info"][ - "content_encryption_algorithm" - ] - encapsulated_data = cms_content["content"]["encrypted_content_info"][ - "encrypted_content" - ].native - + elif cms.KeyEncryptionAlgorithmId(key_enc_alg) == cms.KeyEncryptionAlgorithmId( + "rsaes_oaep" + ): try: - if alg["algorithm"].native == "rc4": - decrypted_content = symmetric.rc4_decrypt(key, encapsulated_data) - elif alg.encryption_cipher == "tripledes": - cipher = "tripledes_192_cbc" - decrypted_content = symmetric.tripledes_cbc_pkcs5_decrypt( - key, encapsulated_data, alg.encryption_iv - ) - elif alg.encryption_cipher == "aes": - decrypted_content = symmetric.aes_cbc_pkcs7_decrypt( - key, encapsulated_data, alg.encryption_iv - ) - elif alg.encryption_cipher == "rc2": - decrypted_content = symmetric.rc2_cbc_pkcs5_decrypt( - key, encapsulated_data, alg["parameters"]["iv"].native - ) - else: - raise AS2Exception("Unsupported Encryption Algorithm") + key = asymmetric.rsa_oaep_decrypt(decryption_key[0], encrypted_key) except Exception as e: raise DecryptionError( - "Failed to decrypt the payload: {}".format(e) + "Failed to decrypt the payload: Could not extract decryption key." ) from e else: - raise AS2Exception("Unsupported Encryption Algorithm") + raise AS2Exception(f"Unsupported Key Encryption Algorithm {key_enc_alg}") + + alg = cms_content["content"]["encrypted_content_info"][ + "content_encryption_algorithm" + ] + encapsulated_data = cms_content["content"]["encrypted_content_info"][ + "encrypted_content" + ].native + + try: + if alg["algorithm"].native == "rc4": + decrypted_content = symmetric.rc4_decrypt(key, encapsulated_data) + elif alg.encryption_cipher == "tripledes": + cipher = "tripledes_192_cbc" + decrypted_content = symmetric.tripledes_cbc_pkcs5_decrypt( + key, encapsulated_data, alg.encryption_iv + ) + elif alg.encryption_cipher == "aes": + decrypted_content = symmetric.aes_cbc_pkcs7_decrypt( + key, encapsulated_data, alg.encryption_iv + ) + elif alg.encryption_cipher == "rc2": + decrypted_content = symmetric.rc2_cbc_pkcs5_decrypt( + key, encapsulated_data, alg["parameters"]["iv"].native + ) + else: + raise AS2Exception("Unsupported Encryption Algorithm") + except Exception as e: + raise DecryptionError("Failed to decrypt the payload: {}".format(e)) from e else: raise DecryptionError("Encrypted data not found in ASN.1 ") diff --git a/pyas2lib/tests/test_cms.py b/pyas2lib/tests/test_cms.py index 34655db..a324fa7 100644 --- a/pyas2lib/tests/test_cms.py +++ b/pyas2lib/tests/test_cms.py @@ -88,8 +88,20 @@ def test_encryption(): "aes_192_cbc", "aes_256_cbc", ] - for enc_algorithm in enc_algorithms: - encrypted_data = cms.encrypt_message(b"data", enc_algorithm, encrypt_cert) + + key_enc_algos = [ + "rsaes_oaep", + "rsaes_pkcs1v15", + ] + + encryption_algos = [ + (alg, scheme) for alg, scheme in zip(enc_algorithms, key_enc_algos) + ] + + for enc_algorithm, encryption_scheme in encryption_algos: + encrypted_data = cms.encrypt_message( + b"data", enc_algorithm, encrypt_cert, encryption_scheme + ) _, decrypted_data = cms.decrypt_message(encrypted_data, decrypt_key) assert decrypted_data == b"data" @@ -101,3 +113,7 @@ def test_encryption(): encrypted_data = cms.encrypt_message(b"data", "des_64_cbc", encrypt_cert) with pytest.raises(AS2Exception): cms.decrypt_message(encrypted_data, decrypt_key) + + # Test faulty key encryption algorithm + with pytest.raises(AS2Exception): + cms.encrypt_message(b"data", "rc2_128_cbc", encrypt_cert, "des_64_cbc")