From f3d902fd611bb0e34ae6963dc93d44b27fe39741 Mon Sep 17 00:00:00 2001 From: Wassilios Lytras Date: Wed, 11 Jan 2023 17:32:48 +0100 Subject: [PATCH 01/19] Changes for cargoo --- pyas2lib/as2.py | 53 +++++++++++++++++++++++++++++++++++++++++-------- setup.py | 2 +- 2 files changed, 46 insertions(+), 9 deletions(-) diff --git a/pyas2lib/as2.py b/pyas2lib/as2.py index cb058ad..70f1158 100644 --- a/pyas2lib/as2.py +++ b/pyas2lib/as2.py @@ -522,7 +522,14 @@ def _decompress_data(self, payload): return False, payload - def parse(self, raw_content, find_org_cb, find_partner_cb, find_message_cb=None): + def parse( + self, + raw_content, + find_org_cb=None, + find_partner_cb=None, + find_message_cb=None, + find_org_partner_cb=None, + ): """Function parses the RAW AS2 message; decrypts, verifies and decompresses it and extracts the payload. @@ -530,18 +537,26 @@ def parse(self, raw_content, find_org_cb, find_partner_cb, find_message_cb=None) A byte string of the received HTTP headers followed by the body. :param find_org_cb: - A callback the returns an Organization object if exists. The - as2-to header value is passed as an argument to it. + A conditional callback the returns an Organization object if exists. The + as2-to header value is passed as an argument to it. Must be provided + when find_partner_cb is provided and find_org_partner_cb is None :param find_partner_cb: - A callback the returns an Partner object if exists. The - as2-from header value is passed as an argument to it. + An conditional callback the returns an Partner object if exists. The + as2-from header value is passed as an argument to it. Must be provided + when find_org_cb is provided and find_org_partner_cb is None. :param find_message_cb: An optional callback the returns an Message object if exists in order to check for duplicates. The message id and partner id is passed as arguments to it. + :param find_org_partner_cb: + A conditional callback that return Organization object and + Partner object if exist. The as2-to and as2-from header value + are passed as an argument to it. Must be provided + when find_org_cb and find_org_partner_cb is None. + :return: A three element tuple containing (status, (exception, traceback) , mdn). The status is a string indicating the status of the @@ -550,6 +565,18 @@ def parse(self, raw_content, find_org_cb, find_partner_cb, find_message_cb=None) the partner did not request it. """ + # Validate passed arguments + if not any( + [ + find_org_cb and find_partner_cb and not find_org_partner_cb, + find_org_partner_cb and not find_partner_cb and not find_org_cb, + ] + ): + raise TypeError( + "Incorrect arguments passed: either find_org_cb and find_partner_cb " + "or only find_org_partner_cb must be passed." + ) + # Parse the raw MIME message and extract its content and headers status, detailed_status, exception, mdn = "processed", None, (None, None), None self.payload = parse_mime(raw_content) @@ -563,12 +590,17 @@ def parse(self, raw_content, find_org_cb, find_partner_cb, find_message_cb=None) try: # Get the organization and partner for this transmission org_id = unquote_as2name(as2_headers["as2-to"]) - self.receiver = find_org_cb(org_id) + partner_id = unquote_as2name(as2_headers["as2-from"]) + + if find_org_partner_cb: + self.receiver, self.sender = find_org_partner_cb(org_id, partner_id) + else: + self.receiver = find_org_cb(org_id) + self.sender = find_partner_cb(partner_id) + if not self.receiver: raise PartnerNotFound(f"Unknown AS2 organization with id {org_id}") - partner_id = unquote_as2name(as2_headers["as2-from"]) - self.sender = find_partner_cb(partner_id) if not self.sender: raise PartnerNotFound(f"Unknown AS2 partner with id {partner_id}") @@ -898,6 +930,11 @@ def parse(self, raw_content, find_message_cb): # Call the find message callback which should return a Message instance orig_message = find_message_cb(self.orig_message_id, orig_recipient) + if not orig_message: + status = "failed/Failure" + details_status = "original-message-not-found" + return status, details_status + # Extract the headers and save it mdn_headers = {} for k, v in self.payload.items(): diff --git a/setup.py b/setup.py index 7bc4ff0..e5f350d 100644 --- a/setup.py +++ b/setup.py @@ -3,7 +3,7 @@ install_requires = [ "asn1crypto==1.5.1", "oscrypto==1.3.0", - "pyOpenSSL==21.0.0", + "pyOpenSSL==23.0.0", ] tests_require = [ From 47e62faf9fd0c77c91eb21421dc530fbc5d9343d Mon Sep 17 00:00:00 2001 From: Wassilios Lytras Date: Sun, 10 Mar 2024 22:48:47 +0100 Subject: [PATCH 02/19] Adding Async Support for Callbacks --- pyas2lib/as2.py | 46 +++++++++++++++++++++++++++------ pyas2lib/tests/test_advanced.py | 40 ++++++++++++++++++++++++++++ 2 files changed, 78 insertions(+), 8 deletions(-) diff --git a/pyas2lib/as2.py b/pyas2lib/as2.py index 70f1158..11760b7 100644 --- a/pyas2lib/as2.py +++ b/pyas2lib/as2.py @@ -10,6 +10,8 @@ from email import utils as email_utils from email.mime.multipart import MIMEMultipart from oscrypto import asymmetric +import asyncio +import inspect from pyas2lib.cms import ( compress_message, @@ -522,7 +524,7 @@ def _decompress_data(self, payload): return False, payload - def parse( + async def aparse( self, raw_content, find_org_cb=None, @@ -593,10 +595,23 @@ def parse( partner_id = unquote_as2name(as2_headers["as2-from"]) if find_org_partner_cb: - self.receiver, self.sender = find_org_partner_cb(org_id, partner_id) + if inspect.iscoroutinefunction(find_org_partner_cb): + self.receiver, self.sender = await find_org_partner_cb(org_id, partner_id) + else: + self.receiver, self.sender = find_org_partner_cb(org_id, partner_id) + else: - self.receiver = find_org_cb(org_id) - self.sender = find_partner_cb(partner_id) + if find_org_cb: + if inspect.iscoroutinefunction(find_org_cb): + self.receiver = await find_org_cb(org_id) + else: + self.receiver = find_org_cb(org_id) + + if find_partner_cb: + if inspect.iscoroutinefunction(find_partner_cb): + self.sender = await find_partner_cb(partner_id) + else: + self.sender = find_partner_cb(partner_id) if not self.receiver: raise PartnerNotFound(f"Unknown AS2 organization with id {org_id}") @@ -604,10 +619,15 @@ def parse( if not self.sender: raise PartnerNotFound(f"Unknown AS2 partner with id {partner_id}") - if find_message_cb and find_message_cb(self.message_id, partner_id): - raise DuplicateDocument( - "Duplicate message received, message with this ID already processed." - ) + if find_message_cb: + if inspect.iscoroutinefunction(find_message_cb): + message_exists = await find_message_cb(self.message_id, partner_id) + else: + message_exists = find_message_cb(self.message_id, partner_id) + if message_exists: + raise DuplicateDocument( + "Duplicate message received, message with this ID already processed." + ) if ( self.sender.encrypt @@ -728,6 +748,16 @@ def parse( return status, exception, mdn + def parse(self, *args, **kwargs): + """ + A synchronous wrapper for the asynchronous parse method. + It runs the parse coroutine in an event loop and returns the result. + """ + loop = asyncio.get_event_loop() + if loop.is_running(): + raise RuntimeError("Cannot run synchronous parse within an already running event loop.") + return loop.run_until_complete(self.aparse(*args, **kwargs)) + class Mdn: """Class for handling AS2 MDNs. Includes functions for both diff --git a/pyas2lib/tests/test_advanced.py b/pyas2lib/tests/test_advanced.py index 0bc6db7..1b8f6fd 100644 --- a/pyas2lib/tests/test_advanced.py +++ b/pyas2lib/tests/test_advanced.py @@ -9,6 +9,7 @@ from pyas2lib.exceptions import ImproperlyConfigured from pyas2lib.tests import Pyas2TestCase, TEST_DIR +import asyncio class TestAdvanced(Pyas2TestCase): def setUp(self): @@ -515,6 +516,36 @@ def test_final_recipient_fallback(self): self.assertEqual(message_recipient, self.partner.as2_name) + @pytest.mark.asyncio + async def test_duplicate_message_async(self): + """Test case where a duplicate message is sent to the partner asynchronously""" + + # Build an As2 message to be transmitted to partner + self.partner.sign = True + self.partner.encrypt = True + self.partner.mdn_mode = as2.SYNCHRONOUS_MDN + self.out_message = as2.Message(self.org, self.partner) + self.out_message.build(self.test_data) + + # Parse the generated AS2 message as the partner + raw_out_message = ( + self.out_message.headers_str + b"\r\n" + self.out_message.content + ) + in_message = as2.Message() + _, _, mdn = await in_message.parse( + raw_out_message, + find_org_cb=self.afind_org, + find_partner_cb=self.afind_partner, + find_message_cb=self.afind_duplicate_message, + ) + + out_mdn = as2.Mdn() + status, detailed_status = await out_mdn.parse( + mdn.headers_str + b"\r\n" + mdn.content, find_message_cb=self.afind_message + ) + self.assertEqual(status, "processed/Warning") + self.assertEqual(detailed_status, "duplicate-document") + def find_org(self, headers): return self.org @@ -524,6 +555,15 @@ def find_partner(self, headers): def find_message(self, message_id, message_recipient): return self.out_message + async def afind_org(self, headers): + return self.org + + async def afind_partner(self, headers): + return self.partner + + async def afind_duplicate_message(self, message_id, message_recipient): + return True + class SterlingIntegratorTest(Pyas2TestCase): def setUp(self): From 91fe0570c0c24ee7af7c2c18b8d0e4101b90cafc Mon Sep 17 00:00:00 2001 From: Wassilios Lytras Date: Mon, 18 Mar 2024 08:23:19 +0100 Subject: [PATCH 03/19] Adding Async Support for Callbacks --- pyas2lib/tests/test_advanced.py | 27 +++++++++++++++++++++++++-- pyas2lib/tests/test_basic.py | 27 +++++++++++++++++++++++++++ 2 files changed, 52 insertions(+), 2 deletions(-) diff --git a/pyas2lib/tests/test_advanced.py b/pyas2lib/tests/test_advanced.py index 1b8f6fd..2fd23f6 100644 --- a/pyas2lib/tests/test_advanced.py +++ b/pyas2lib/tests/test_advanced.py @@ -518,7 +518,7 @@ def test_final_recipient_fallback(self): @pytest.mark.asyncio async def test_duplicate_message_async(self): - """Test case where a duplicate message is sent to the partner asynchronously""" + """Test case where a duplicate message is sent to the partner using async callbacks""" # Build an As2 message to be transmitted to partner self.partner.sign = True @@ -546,6 +546,28 @@ async def test_duplicate_message_async(self): self.assertEqual(status, "processed/Warning") self.assertEqual(detailed_status, "duplicate-document") + @pytest.mark.asyncio + async def test_async_partnership(self): + """Test Async Partnership callback with Unencrypted Unsigned Uncompressed Message""" + + # Build an As2 message to be transmitted to partner + out_message = as2.Message(self.org, self.partner) + out_message.build(self.test_data) + raw_out_message = out_message.headers_str + b"\r\n" + out_message.content + + # Parse the generated AS2 message as the partner + in_message = as2.Message() + status, _, _ = in_message.parse( + raw_out_message, + find_org_cb=self.find_org, + find_partner_cb=self.afind_org_partner, + ) + + # Compare contents of the input and output messages + self.assertEqual(status, "processed") + self.assertEqual(self.test_data, in_message.content) + + def find_org(self, headers): return self.org @@ -564,7 +586,8 @@ async def afind_partner(self, headers): async def afind_duplicate_message(self, message_id, message_recipient): return True - + async def afind_org_partner(self, as2_org, as2_partner): + return self.org, self.partner class SterlingIntegratorTest(Pyas2TestCase): def setUp(self): self.org = as2.Organization( diff --git a/pyas2lib/tests/test_basic.py b/pyas2lib/tests/test_basic.py index 3eadbb8..afd4aa6 100644 --- a/pyas2lib/tests/test_basic.py +++ b/pyas2lib/tests/test_basic.py @@ -183,6 +183,30 @@ def test_encrypted_signed_compressed_message(self): self.assertEqual(out_message.mic, in_message.mic) self.assertEqual(self.test_data.splitlines(), in_message.content.splitlines()) + def test_encrypted_signed_message_partnership(self): + """Test Encrypted Signed Uncompressed Message with Partnership""" + + # Build an As2 message to be transmitted to partner + self.partner.sign = True + self.partner.encrypt = True + out_message = as2.Message(self.org, self.partner) + out_message.build(self.test_data) + raw_out_message = out_message.headers_str + b"\r\n" + out_message.content + + # Parse the generated AS2 message as the partner + in_message = as2.Message() + status, _, _ = in_message.parse( + raw_out_message, + find_org_partner_cb=self.find_org_partner, + ) + + # Compare the mic contents of the input and output messages + self.assertEqual(status, "processed") + self.assertTrue(in_message.signed) + self.assertTrue(in_message.encrypted) + self.assertEqual(out_message.mic, in_message.mic) + self.assertEqual(self.test_data.splitlines(), in_message.content.splitlines()) + def test_plain_message_with_domain(self): """Test Message building with an org domain""" @@ -205,3 +229,6 @@ def find_org(self, as2_id): def find_partner(self, as2_id): return self.partner + + def find_org_partner(self, as2_org, as2_partner): + return self.org, self.partner From 2e4c253276bfc56071684f03ae4228c0e16b3646 Mon Sep 17 00:00:00 2001 From: Wassilios Lytras Date: Mon, 18 Mar 2024 11:24:38 +0100 Subject: [PATCH 04/19] Add Changelog entries --- AUTHORS.md | 2 +- CHANGELOG.md | 7 ++++++- 2 files changed, 7 insertions(+), 2 deletions(-) diff --git a/AUTHORS.md b/AUTHORS.md index f8b47d9..e1e95c1 100644 --- a/AUTHORS.md +++ b/AUTHORS.md @@ -1,5 +1,5 @@ - Abhishek Ram @abhishek-ram -- Chad Gates @chadgates +- Wassilios Lytras @chadgates - Bruno Ribeiro da Silva @loop0 - Robin C Samuel @robincsamuel - Brandon Joyce @brandonjoyce diff --git a/CHANGELOG.md b/CHANGELOG.md index 356f159..9131c81 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,8 +1,13 @@ # Release History +## 1.4.4 - 2024- + +* feat: added partnership lookup function +* feat: added support for async callback functions + ## 1.4.3 - 2023-01-25 -* fix: update pyopenssl version to resovle pyca/cryptography#7959 +* fix: update pyopenssl version to resolve pyca/cryptography#7959 ## 1.4.2 - 2022-12-11 From 0b39518fcb687e3279d12815c5da840eae34f140 Mon Sep 17 00:00:00 2001 From: Wassilios Lytras Date: Mon, 18 Mar 2024 13:57:31 +0100 Subject: [PATCH 05/19] Adding key encryption algo option to support rsaes_oaep. --- pyas2lib/cms.py | 81 +++++++++++++++++++++++--------------- pyas2lib/tests/test_cms.py | 20 +++++++++- 2 files changed, 68 insertions(+), 33 deletions(-) diff --git a/pyas2lib/cms.py b/pyas2lib/cms.py index 0172980..8f95a11 100644 --- a/pyas2lib/cms.py +++ b/pyas2lib/cms.py @@ -65,12 +65,15 @@ def decompress_message(compressed_data): raise DecompressionError("Decompression failed with cause: {}".format(e)) from e -def encrypt_message(data_to_encrypt, enc_alg, encryption_cert): +def encrypt_message( + data_to_encrypt, enc_alg, encryption_cert, key_enc_alg="rsaes_pkcs1v15" +): """Function encrypts data and returns the generated ASN.1 :param data_to_encrypt: A byte string of the data to be encrypted :param enc_alg: The algorithm to be used for encrypting the data :param encryption_cert: The certificate to be used for encrypting the data + :param key_enc_alg: The encryption scheme for the key encryption: rsaes_pkcs1v15 (default) or rsaes_oaep :return: A CMS ASN.1 byte string of the encrypted data. """ @@ -136,7 +139,12 @@ def encrypt_message(data_to_encrypt, enc_alg, encryption_cert): raise AS2Exception("Unsupported Encryption Algorithm") # Encrypt the key and build the ASN.1 message - encrypted_key = asymmetric.rsa_pkcs1v15_encrypt(encryption_cert, key) + if key_enc_alg == "rsaes_pkcs1v15": + encrypted_key = asymmetric.rsa_pkcs1v15_encrypt(encryption_cert, key) + elif key_enc_alg == "rsaes_oaep": + encrypted_key = asymmetric.rsa_oaep_encrypt(encryption_cert, key) + else: + raise AS2Exception(f"Unsupported Key Encryption Scheme: {key_enc_alg}") return cms.ContentInfo( { @@ -163,7 +171,11 @@ def encrypt_message(data_to_encrypt, enc_alg, encryption_cert): } ), "key_encryption_algorithm": cms.KeyEncryptionAlgorithm( - {"algorithm": cms.KeyEncryptionAlgorithmId("rsa")} + { + "algorithm": cms.KeyEncryptionAlgorithmId( + key_enc_alg + ) + } ), "encrypted_key": cms.OctetString(encrypted_key), } @@ -200,7 +212,7 @@ def decrypt_message(encrypted_data, decryption_key): encrypted_key = recipient_info["encrypted_key"].native if cms.KeyEncryptionAlgorithmId(key_enc_alg) == cms.KeyEncryptionAlgorithmId( - "rsa" + "rsaes_pkcs1v15" ): try: key = asymmetric.rsa_pkcs1v15_decrypt(decryption_key[0], encrypted_key) @@ -208,38 +220,45 @@ def decrypt_message(encrypted_data, decryption_key): raise DecryptionError( "Failed to decrypt the payload: Could not extract decryption key." ) from e - - alg = cms_content["content"]["encrypted_content_info"][ - "content_encryption_algorithm" - ] - encapsulated_data = cms_content["content"]["encrypted_content_info"][ - "encrypted_content" - ].native - + elif cms.KeyEncryptionAlgorithmId(key_enc_alg) == cms.KeyEncryptionAlgorithmId( + "rsaes_oaep" + ): try: - if alg["algorithm"].native == "rc4": - decrypted_content = symmetric.rc4_decrypt(key, encapsulated_data) - elif alg.encryption_cipher == "tripledes": - cipher = "tripledes_192_cbc" - decrypted_content = symmetric.tripledes_cbc_pkcs5_decrypt( - key, encapsulated_data, alg.encryption_iv - ) - elif alg.encryption_cipher == "aes": - decrypted_content = symmetric.aes_cbc_pkcs7_decrypt( - key, encapsulated_data, alg.encryption_iv - ) - elif alg.encryption_cipher == "rc2": - decrypted_content = symmetric.rc2_cbc_pkcs5_decrypt( - key, encapsulated_data, alg["parameters"]["iv"].native - ) - else: - raise AS2Exception("Unsupported Encryption Algorithm") + key = asymmetric.rsa_oaep_decrypt(decryption_key[0], encrypted_key) except Exception as e: raise DecryptionError( - "Failed to decrypt the payload: {}".format(e) + "Failed to decrypt the payload: Could not extract decryption key." ) from e else: - raise AS2Exception("Unsupported Encryption Algorithm") + raise AS2Exception(f"Unsupported Key Encryption Algorithm {key_enc_alg}") + + alg = cms_content["content"]["encrypted_content_info"][ + "content_encryption_algorithm" + ] + encapsulated_data = cms_content["content"]["encrypted_content_info"][ + "encrypted_content" + ].native + + try: + if alg["algorithm"].native == "rc4": + decrypted_content = symmetric.rc4_decrypt(key, encapsulated_data) + elif alg.encryption_cipher == "tripledes": + cipher = "tripledes_192_cbc" + decrypted_content = symmetric.tripledes_cbc_pkcs5_decrypt( + key, encapsulated_data, alg.encryption_iv + ) + elif alg.encryption_cipher == "aes": + decrypted_content = symmetric.aes_cbc_pkcs7_decrypt( + key, encapsulated_data, alg.encryption_iv + ) + elif alg.encryption_cipher == "rc2": + decrypted_content = symmetric.rc2_cbc_pkcs5_decrypt( + key, encapsulated_data, alg["parameters"]["iv"].native + ) + else: + raise AS2Exception("Unsupported Encryption Algorithm") + except Exception as e: + raise DecryptionError("Failed to decrypt the payload: {}".format(e)) from e else: raise DecryptionError("Encrypted data not found in ASN.1 ") diff --git a/pyas2lib/tests/test_cms.py b/pyas2lib/tests/test_cms.py index 34655db..a324fa7 100644 --- a/pyas2lib/tests/test_cms.py +++ b/pyas2lib/tests/test_cms.py @@ -88,8 +88,20 @@ def test_encryption(): "aes_192_cbc", "aes_256_cbc", ] - for enc_algorithm in enc_algorithms: - encrypted_data = cms.encrypt_message(b"data", enc_algorithm, encrypt_cert) + + key_enc_algos = [ + "rsaes_oaep", + "rsaes_pkcs1v15", + ] + + encryption_algos = [ + (alg, scheme) for alg, scheme in zip(enc_algorithms, key_enc_algos) + ] + + for enc_algorithm, encryption_scheme in encryption_algos: + encrypted_data = cms.encrypt_message( + b"data", enc_algorithm, encrypt_cert, encryption_scheme + ) _, decrypted_data = cms.decrypt_message(encrypted_data, decrypt_key) assert decrypted_data == b"data" @@ -101,3 +113,7 @@ def test_encryption(): encrypted_data = cms.encrypt_message(b"data", "des_64_cbc", encrypt_cert) with pytest.raises(AS2Exception): cms.decrypt_message(encrypted_data, decrypt_key) + + # Test faulty key encryption algorithm + with pytest.raises(AS2Exception): + cms.encrypt_message(b"data", "rc2_128_cbc", encrypt_cert, "des_64_cbc") From 43b84e3d52a3019d1c3c4b18796d5e651af903a9 Mon Sep 17 00:00:00 2001 From: Wassilios Lytras Date: Mon, 18 Mar 2024 14:00:00 +0100 Subject: [PATCH 06/19] Adding key encryption algo option to support rsaes_oaep. --- CHANGELOG.md | 1 + 1 file changed, 1 insertion(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 9131c81..eb7674e 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -4,6 +4,7 @@ * feat: added partnership lookup function * feat: added support for async callback functions +* feat: added support for optional key encryption algorithm rsaes_oaep for encryption and decryption ## 1.4.3 - 2023-01-25 From e02be3862f71bf27b4ba28f185d1d29cbdc40e4a Mon Sep 17 00:00:00 2001 From: Wassilios Lytras Date: Mon, 18 Mar 2024 14:34:04 +0100 Subject: [PATCH 07/19] Adding key encryption algo option to support rsaes_oaep. --- pyas2lib/tests/test_advanced.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/pyas2lib/tests/test_advanced.py b/pyas2lib/tests/test_advanced.py index 2fd23f6..490b74a 100644 --- a/pyas2lib/tests/test_advanced.py +++ b/pyas2lib/tests/test_advanced.py @@ -11,6 +11,7 @@ import asyncio + class TestAdvanced(Pyas2TestCase): def setUp(self): self.org = as2.Organization( @@ -529,7 +530,7 @@ async def test_duplicate_message_async(self): # Parse the generated AS2 message as the partner raw_out_message = ( - self.out_message.headers_str + b"\r\n" + self.out_message.content + self.out_message.headers_str + b"\r\n" + self.out_message.content ) in_message = as2.Message() _, _, mdn = await in_message.parse( @@ -567,7 +568,6 @@ async def test_async_partnership(self): self.assertEqual(status, "processed") self.assertEqual(self.test_data, in_message.content) - def find_org(self, headers): return self.org @@ -588,6 +588,8 @@ async def afind_duplicate_message(self, message_id, message_recipient): async def afind_org_partner(self, as2_org, as2_partner): return self.org, self.partner + + class SterlingIntegratorTest(Pyas2TestCase): def setUp(self): self.org = as2.Organization( From 5dd639fe5e6bdbb2674222299e9dff8383c6b198 Mon Sep 17 00:00:00 2001 From: Wassilios Lytras Date: Mon, 18 Mar 2024 15:46:45 +0100 Subject: [PATCH 08/19] Formatting --- pyas2lib/as2.py | 17 +++++++++++------ pyas2lib/cms.py | 4 ++-- 2 files changed, 13 insertions(+), 8 deletions(-) diff --git a/pyas2lib/as2.py b/pyas2lib/as2.py index 5f0c79c..4e580dd 100644 --- a/pyas2lib/as2.py +++ b/pyas2lib/as2.py @@ -1,7 +1,9 @@ """Define the core functions/classes of the pyas2 package.""" -import logging -import hashlib +import asyncio import binascii +import hashlib +import inspect +import logging import traceback from dataclasses import dataclass from email import encoders @@ -9,9 +11,8 @@ from email import message_from_bytes as parse_mime from email import utils as email_utils from email.mime.multipart import MIMEMultipart + from oscrypto import asymmetric -import asyncio -import inspect from pyas2lib.cms import ( compress_message, @@ -613,7 +614,9 @@ async def aparse( if find_org_partner_cb: if inspect.iscoroutinefunction(find_org_partner_cb): - self.receiver, self.sender = await find_org_partner_cb(org_id, partner_id) + self.receiver, self.sender = await find_org_partner_cb( + org_id, partner_id + ) else: self.receiver, self.sender = find_org_partner_cb(org_id, partner_id) @@ -772,7 +775,9 @@ def parse(self, *args, **kwargs): """ loop = asyncio.get_event_loop() if loop.is_running(): - raise RuntimeError("Cannot run synchronous parse within an already running event loop.") + raise RuntimeError( + "Cannot run synchronous parse within an already running event loop." + ) return loop.run_until_complete(self.aparse(*args, **kwargs)) diff --git a/pyas2lib/cms.py b/pyas2lib/cms.py index 8f95a11..db3d148 100644 --- a/pyas2lib/cms.py +++ b/pyas2lib/cms.py @@ -3,7 +3,7 @@ import zlib from datetime import datetime, timezone -from asn1crypto import cms, core, algos +from asn1crypto import algos, cms, core from asn1crypto.cms import SMIMECapabilityIdentifier from oscrypto import asymmetric, symmetric, util @@ -73,7 +73,7 @@ def encrypt_message( :param data_to_encrypt: A byte string of the data to be encrypted :param enc_alg: The algorithm to be used for encrypting the data :param encryption_cert: The certificate to be used for encrypting the data - :param key_enc_alg: The encryption scheme for the key encryption: rsaes_pkcs1v15 (default) or rsaes_oaep + :param key_enc_alg: The algo for the key encryption: rsaes_pkcs1v15 (default) or rsaes_oaep :return: A CMS ASN.1 byte string of the encrypted data. """ From c0508299e24f6d503cac0ddab7c0124a95b7bbca Mon Sep 17 00:00:00 2001 From: Wassilios Lytras Date: Mon, 18 Mar 2024 15:49:40 +0100 Subject: [PATCH 09/19] Add pytest-asyncio testing capabilities --- setup.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/setup.py b/setup.py index e4d55b8..37128a7 100644 --- a/setup.py +++ b/setup.py @@ -7,7 +7,8 @@ ] tests_require = [ - "pytest==6.2.5", + "pytest==7.4.4", + "pytest-asyncio==0.21.1", "toml==0.10.2", "pytest-cov==2.8.1", "coverage==5.0.4", From a93d4c17deacfc09f684a606b148fd67751e1309 Mon Sep 17 00:00:00 2001 From: Wassilios Lytras Date: Mon, 18 Mar 2024 20:25:06 +0100 Subject: [PATCH 10/19] Add Async callback to MDN --- pyas2lib/as2.py | 24 +++++++++++++++++++++--- 1 file changed, 21 insertions(+), 3 deletions(-) diff --git a/pyas2lib/as2.py b/pyas2lib/as2.py index 4e580dd..ea288b0 100644 --- a/pyas2lib/as2.py +++ b/pyas2lib/as2.py @@ -776,7 +776,7 @@ def parse(self, *args, **kwargs): loop = asyncio.get_event_loop() if loop.is_running(): raise RuntimeError( - "Cannot run synchronous parse within an already running event loop." + "Cannot run synchronous parse within an already running event loop, use aparse." ) return loop.run_until_complete(self.aparse(*args, **kwargs)) @@ -955,7 +955,7 @@ def build( f"content:\n {mime_to_bytes(self.payload)}" ) - def parse(self, raw_content, find_message_cb): + async def aparse(self, raw_content, find_message_cb): """Function parses the RAW AS2 MDN, verifies it and extracts the processing status of the orginal AS2 message. @@ -980,7 +980,13 @@ def parse(self, raw_content, find_message_cb): self.orig_message_id, orig_recipient = self.detect_mdn() # Call the find message callback which should return a Message instance - orig_message = find_message_cb(self.orig_message_id, orig_recipient) + if find_message_cb: + if inspect.iscoroutinefunction(find_message_cb): + orig_message = await find_message_cb( + self.orig_message_id, orig_recipient + ) + else: + orig_message = find_message_cb(self.orig_message_id, orig_recipient) if not orig_message: status = "failed/Failure" @@ -1063,6 +1069,18 @@ def parse(self, raw_content, find_message_cb): logger.error(f"Failed to parse AS2 MDN\n: {traceback.format_exc()}") return status, detailed_status + def parse(self, *args, **kwargs): + """ + A synchronous wrapper for the asynchronous parse method. + It runs the parse coroutine in an event loop and returns the result. + """ + loop = asyncio.get_event_loop() + if loop.is_running(): + raise RuntimeError( + "Cannot run synchronous parse within an already running event loop, use aparse." + ) + return loop.run_until_complete(self.aparse(*args, **kwargs)) + def detect_mdn(self): """Function checks if the received raw message is an AS2 MDN or not. From 23a8c395f23ba30e5a2dd58fc2f94ed40d7a36a3 Mon Sep 17 00:00:00 2001 From: Wassilios Lytras Date: Mon, 18 Mar 2024 20:25:21 +0100 Subject: [PATCH 11/19] Async Tests --- pyas2lib/tests/test_advanced.py | 65 ----------------------- pyas2lib/tests/test_async.py | 92 +++++++++++++++++++++++++++++++++ 2 files changed, 92 insertions(+), 65 deletions(-) create mode 100644 pyas2lib/tests/test_async.py diff --git a/pyas2lib/tests/test_advanced.py b/pyas2lib/tests/test_advanced.py index 490b74a..0bc6db7 100644 --- a/pyas2lib/tests/test_advanced.py +++ b/pyas2lib/tests/test_advanced.py @@ -9,8 +9,6 @@ from pyas2lib.exceptions import ImproperlyConfigured from pyas2lib.tests import Pyas2TestCase, TEST_DIR -import asyncio - class TestAdvanced(Pyas2TestCase): def setUp(self): @@ -517,57 +515,6 @@ def test_final_recipient_fallback(self): self.assertEqual(message_recipient, self.partner.as2_name) - @pytest.mark.asyncio - async def test_duplicate_message_async(self): - """Test case where a duplicate message is sent to the partner using async callbacks""" - - # Build an As2 message to be transmitted to partner - self.partner.sign = True - self.partner.encrypt = True - self.partner.mdn_mode = as2.SYNCHRONOUS_MDN - self.out_message = as2.Message(self.org, self.partner) - self.out_message.build(self.test_data) - - # Parse the generated AS2 message as the partner - raw_out_message = ( - self.out_message.headers_str + b"\r\n" + self.out_message.content - ) - in_message = as2.Message() - _, _, mdn = await in_message.parse( - raw_out_message, - find_org_cb=self.afind_org, - find_partner_cb=self.afind_partner, - find_message_cb=self.afind_duplicate_message, - ) - - out_mdn = as2.Mdn() - status, detailed_status = await out_mdn.parse( - mdn.headers_str + b"\r\n" + mdn.content, find_message_cb=self.afind_message - ) - self.assertEqual(status, "processed/Warning") - self.assertEqual(detailed_status, "duplicate-document") - - @pytest.mark.asyncio - async def test_async_partnership(self): - """Test Async Partnership callback with Unencrypted Unsigned Uncompressed Message""" - - # Build an As2 message to be transmitted to partner - out_message = as2.Message(self.org, self.partner) - out_message.build(self.test_data) - raw_out_message = out_message.headers_str + b"\r\n" + out_message.content - - # Parse the generated AS2 message as the partner - in_message = as2.Message() - status, _, _ = in_message.parse( - raw_out_message, - find_org_cb=self.find_org, - find_partner_cb=self.afind_org_partner, - ) - - # Compare contents of the input and output messages - self.assertEqual(status, "processed") - self.assertEqual(self.test_data, in_message.content) - def find_org(self, headers): return self.org @@ -577,18 +524,6 @@ def find_partner(self, headers): def find_message(self, message_id, message_recipient): return self.out_message - async def afind_org(self, headers): - return self.org - - async def afind_partner(self, headers): - return self.partner - - async def afind_duplicate_message(self, message_id, message_recipient): - return True - - async def afind_org_partner(self, as2_org, as2_partner): - return self.org, self.partner - class SterlingIntegratorTest(Pyas2TestCase): def setUp(self): diff --git a/pyas2lib/tests/test_async.py b/pyas2lib/tests/test_async.py new file mode 100644 index 0000000..57acb66 --- /dev/null +++ b/pyas2lib/tests/test_async.py @@ -0,0 +1,92 @@ +import pytest +from pyas2lib import as2 +import os + +from pyas2lib.tests import TEST_DIR + +with open(os.path.join(TEST_DIR, "payload.txt"), "rb") as fp: + test_data = fp.read() + +with open(os.path.join(TEST_DIR, "cert_test.p12"), "rb") as fp: + private_key = fp.read() + +with open(os.path.join(TEST_DIR, "cert_test_public.pem"), "rb") as fp: + public_key = fp.read() + +org = as2.Organization( + as2_name="some_organization", + sign_key=private_key, + sign_key_pass="test", + decrypt_key=private_key, + decrypt_key_pass="test", +) +partner = as2.Partner( + as2_name="some_partner", + verify_cert=public_key, + encrypt_cert=public_key, +) + + +async def afind_org(headers): + return org + + +async def afind_partner(headers): + return partner + + +async def afind_duplicate_message(message_id, message_recipient): + return True + + +async def afind_org_partner(as2_org, as2_partner): + return org, partner + + +@pytest.mark.asyncio +async def test_duplicate_message_async(): + """Test case where a duplicate message is sent to the partner using async callbacks""" + + # Build an As2 message to be transmitted to partner + partner.sign = True + partner.encrypt = True + partner.mdn_mode = as2.SYNCHRONOUS_MDN + out_message = as2.Message(org, partner) + out_message.build(test_data) + + # Parse the generated AS2 message as the partner + raw_out_message = out_message.headers_str + b"\r\n" + out_message.content + in_message = as2.Message() + _, _, mdn = await in_message.aparse( + raw_out_message, + find_org_cb=afind_org, + find_partner_cb=afind_partner, + find_message_cb=afind_duplicate_message, + ) + + out_mdn = as2.Mdn() + status, detailed_status = await out_mdn.aparse( + mdn.headers_str + b"\r\n" + mdn.content, + find_message_cb=lambda x, y: out_message, + ) + assert status == "processed/Warning" + assert detailed_status == "duplicate-document" + + +@pytest.mark.asyncio +async def test_async_partnership(): + """Test Async Partnership callback""" + + # Build an As2 message to be transmitted to partner + out_message = as2.Message(org, partner) + out_message.build(test_data) + raw_out_message = out_message.headers_str + b"\r\n" + out_message.content + + # Parse the generated AS2 message as the partner + in_message = as2.Message() + status, _, _ = await in_message.aparse( + raw_out_message, find_org_partner_cb=afind_org_partner + ) + + # Compare contents of the input and output messages + assert status == "processed" From c2923023aa9bb31ffdcd8fffb81a35a21d0fa967 Mon Sep 17 00:00:00 2001 From: Wassilios Lytras Date: Tue, 19 Mar 2024 17:13:36 +0100 Subject: [PATCH 12/19] Adding Tests --- pyas2lib/cms.py | 34 +++++++-------- pyas2lib/tests/test_advanced.py | 27 ++++++++++++ pyas2lib/tests/test_async.py | 35 +++++++++++++++- pyas2lib/tests/test_cms.py | 74 ++++++++++++++++++++++++++++++++- 4 files changed, 149 insertions(+), 21 deletions(-) diff --git a/pyas2lib/cms.py b/pyas2lib/cms.py index db3d148..8366bb6 100644 --- a/pyas2lib/cms.py +++ b/pyas2lib/cms.py @@ -211,26 +211,24 @@ def decrypt_message(encrypted_data, decryption_key): key_enc_alg = recipient_info["key_encryption_algorithm"]["algorithm"].native encrypted_key = recipient_info["encrypted_key"].native - if cms.KeyEncryptionAlgorithmId(key_enc_alg) == cms.KeyEncryptionAlgorithmId( - "rsaes_pkcs1v15" - ): - try: + try: + if cms.KeyEncryptionAlgorithmId( + key_enc_alg + ) == cms.KeyEncryptionAlgorithmId("rsaes_pkcs1v15"): key = asymmetric.rsa_pkcs1v15_decrypt(decryption_key[0], encrypted_key) - except Exception as e: - raise DecryptionError( - "Failed to decrypt the payload: Could not extract decryption key." - ) from e - elif cms.KeyEncryptionAlgorithmId(key_enc_alg) == cms.KeyEncryptionAlgorithmId( - "rsaes_oaep" - ): - try: + + elif cms.KeyEncryptionAlgorithmId( + key_enc_alg + ) == cms.KeyEncryptionAlgorithmId("rsaes_oaep"): key = asymmetric.rsa_oaep_decrypt(decryption_key[0], encrypted_key) - except Exception as e: - raise DecryptionError( - "Failed to decrypt the payload: Could not extract decryption key." - ) from e - else: - raise AS2Exception(f"Unsupported Key Encryption Algorithm {key_enc_alg}") + else: + raise AS2Exception( + f"Unsupported Key Encryption Algorithm {key_enc_alg}" + ) + except Exception as e: + raise DecryptionError( + "Failed to decrypt the payload: Could not extract decryption key." + ) from e alg = cms_content["content"]["encrypted_content_info"][ "content_encryption_algorithm" diff --git a/pyas2lib/tests/test_advanced.py b/pyas2lib/tests/test_advanced.py index 0bc6db7..836c90b 100644 --- a/pyas2lib/tests/test_advanced.py +++ b/pyas2lib/tests/test_advanced.py @@ -390,6 +390,33 @@ def test_mdn_not_found(self): self.assertEqual(status, "failed/Failure") self.assertEqual(detailed_status, "mdn-not-found") + def test_mdn_original_message_not_found(self): + """Test that the MDN parser raises MDN not found when a non MDN message is passed.""" + self.partner.mdn_mode = as2.SYNCHRONOUS_MDN + self.out_message = as2.Message(self.org, self.partner) + self.out_message.build(self.test_data) + + # Parse the generated AS2 message as the partner + raw_out_message = ( + self.out_message.headers_str + b"\r\n" + self.out_message.content + ) + in_message = as2.Message() + _, _, mdn = in_message.parse( + raw_out_message, + find_org_cb=self.find_org, + find_partner_cb=self.find_partner, + find_message_cb=lambda x, y: False, + ) + + # Parse the MDN + out_mdn = as2.Mdn() + status, detailed_status = out_mdn.parse( + mdn.headers_str + b"\r\n" + mdn.content, find_message_cb=lambda x, y: False + ) + + self.assertEqual(status, "failed/Failure") + self.assertEqual(detailed_status, "original-message-not-found") + def test_unsigned_mdn_sent_error(self): """Test the case where a signed mdn was expected but unsigned mdn was returned.""" self.partner.mdn_mode = as2.SYNCHRONOUS_MDN diff --git a/pyas2lib/tests/test_async.py b/pyas2lib/tests/test_async.py index 57acb66..047e689 100644 --- a/pyas2lib/tests/test_async.py +++ b/pyas2lib/tests/test_async.py @@ -54,6 +54,9 @@ async def test_duplicate_message_async(): out_message = as2.Message(org, partner) out_message.build(test_data) + async def afind_message(message_id, message_recipient): + return out_message + # Parse the generated AS2 message as the partner raw_out_message = out_message.headers_str + b"\r\n" + out_message.content in_message = as2.Message() @@ -67,7 +70,7 @@ async def test_duplicate_message_async(): out_mdn = as2.Mdn() status, detailed_status = await out_mdn.aparse( mdn.headers_str + b"\r\n" + mdn.content, - find_message_cb=lambda x, y: out_message, + find_message_cb=afind_message, ) assert status == "processed/Warning" assert detailed_status == "duplicate-document" @@ -90,3 +93,33 @@ async def test_async_partnership(): # Compare contents of the input and output messages assert status == "processed" + + +@pytest.mark.asyncio +async def test_runtime_error(): + with pytest.raises(RuntimeError): + out_message = as2.Message(org, partner) + out_message.build(test_data) + raw_out_message = out_message.headers_str + b"\r\n" + out_message.content + + in_message = as2.Message() + status, _, _ = in_message.parse( + raw_out_message, find_org_partner_cb=afind_org_partner + ) + + with pytest.raises(RuntimeError): + partner.sign = True + partner.encrypt = True + partner.mdn_mode = as2.SYNCHRONOUS_MDN + out_message = as2.Message(org, partner) + out_message.build(test_data) + + # Parse the generated AS2 message as the partner + raw_out_message = out_message.headers_str + b"\r\n" + out_message.content + in_message = as2.Message() + _, _, mdn = in_message.parse( + raw_out_message, + find_org_cb=afind_org, + find_partner_cb=afind_partner, + find_message_cb=afind_duplicate_message, + ) diff --git a/pyas2lib/tests/test_cms.py b/pyas2lib/tests/test_cms.py index a324fa7..fa3f596 100644 --- a/pyas2lib/tests/test_cms.py +++ b/pyas2lib/tests/test_cms.py @@ -2,7 +2,9 @@ import os import pytest -from oscrypto import asymmetric +from oscrypto import asymmetric, symmetric, util + +from asn1crypto import algos, cms as crypto_cms, core from pyas2lib.as2 import Organization from pyas2lib import cms @@ -22,6 +24,68 @@ ).dump() +def encrypted_data_with_faulty_key_algo(): + with open(os.path.join(TEST_DIR, "cert_test_public.pem"), "rb") as fp: + encrypt_cert = asymmetric.load_certificate(fp.read()) + enc_alg_list = "rc4_128_cbc".split("_") + cipher, key_length, _ = enc_alg_list[0], enc_alg_list[1], enc_alg_list[2] + key = util.rand_bytes(int(key_length) // 8) + algorithm_id = "1.2.840.113549.3.4" + encrypted_content = symmetric.rc4_encrypt(key, b"data") + enc_alg_asn1 = algos.EncryptionAlgorithm( + { + "algorithm": algorithm_id, + } + ) + encrypted_key = asymmetric.rsa_oaep_encrypt(encrypt_cert, key) + return crypto_cms.ContentInfo( + { + "content_type": crypto_cms.ContentType("enveloped_data"), + "content": crypto_cms.EnvelopedData( + { + "version": crypto_cms.CMSVersion("v0"), + "recipient_infos": [ + crypto_cms.KeyTransRecipientInfo( + { + "version": crypto_cms.CMSVersion("v0"), + "rid": crypto_cms.RecipientIdentifier( + { + "issuer_and_serial_number": crypto_cms.IssuerAndSerialNumber( + { + "issuer": encrypt_cert.asn1[ + "tbs_certificate" + ]["issuer"], + "serial_number": encrypt_cert.asn1[ + "tbs_certificate" + ]["serial_number"], + } + ) + } + ), + "key_encryption_algorithm": crypto_cms.KeyEncryptionAlgorithm( + { + "algorithm": crypto_cms.KeyEncryptionAlgorithmId( + "aes128_wrap" + ) + } + ), + "encrypted_key": crypto_cms.OctetString(encrypted_key), + } + ) + ], + "encrypted_content_info": crypto_cms.EncryptedContentInfo( + { + "content_type": crypto_cms.ContentType("data"), + "content_encryption_algorithm": enc_alg_asn1, + "encrypted_content": encrypted_content, + } + ), + } + ), + } + ).dump() + + def test_compress(): """Test the compression and decompression functions.""" compressed_data = cms.compress_message(b"data") @@ -87,6 +151,7 @@ def test_encryption(): "aes_128_cbc", "aes_192_cbc", "aes_256_cbc", + "tripledes_192_cbc", ] key_enc_algos = [ @@ -95,7 +160,7 @@ def test_encryption(): ] encryption_algos = [ - (alg, scheme) for alg, scheme in zip(enc_algorithms, key_enc_algos) + (alg, key_algo) for alg in enc_algorithms for key_algo in key_enc_algos ] for enc_algorithm, encryption_scheme in encryption_algos: @@ -117,3 +182,8 @@ def test_encryption(): # Test faulty key encryption algorithm with pytest.raises(AS2Exception): cms.encrypt_message(b"data", "rc2_128_cbc", encrypt_cert, "des_64_cbc") + + # Test unsupported key encryption algorithm + encrypted_data = encrypted_data_with_faulty_key_algo() + with pytest.raises(AS2Exception): + cms.decrypt_message(encrypted_data, decrypt_key) From 625e479b4e0318ff4617e731857d29ee5365ca90 Mon Sep 17 00:00:00 2001 From: Wassilios Lytras Date: Tue, 19 Mar 2024 17:21:44 +0100 Subject: [PATCH 13/19] Adding Tests --- pyas2lib/tests/test_async.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/pyas2lib/tests/test_async.py b/pyas2lib/tests/test_async.py index 047e689..d266071 100644 --- a/pyas2lib/tests/test_async.py +++ b/pyas2lib/tests/test_async.py @@ -117,9 +117,15 @@ async def test_runtime_error(): # Parse the generated AS2 message as the partner raw_out_message = out_message.headers_str + b"\r\n" + out_message.content in_message = as2.Message() - _, _, mdn = in_message.parse( + _, _, mdn = await in_message.aparse( raw_out_message, find_org_cb=afind_org, find_partner_cb=afind_partner, find_message_cb=afind_duplicate_message, ) + + out_mdn = as2.Mdn() + _, _ = out_mdn.parse( + mdn.headers_str + b"\r\n" + mdn.content, + find_message_cb=afind_duplicate_message, + ) From e3622114054c45a4a83822b708be8a2ee3ceb140 Mon Sep 17 00:00:00 2001 From: Wassilios Lytras Date: Tue, 19 Mar 2024 17:32:35 +0100 Subject: [PATCH 14/19] Remove obsolete checks --- pyas2lib/as2.py | 33 +++++++++++++++------------------ 1 file changed, 15 insertions(+), 18 deletions(-) diff --git a/pyas2lib/as2.py b/pyas2lib/as2.py index ea288b0..190c439 100644 --- a/pyas2lib/as2.py +++ b/pyas2lib/as2.py @@ -620,18 +620,16 @@ async def aparse( else: self.receiver, self.sender = find_org_partner_cb(org_id, partner_id) - else: - if find_org_cb: - if inspect.iscoroutinefunction(find_org_cb): - self.receiver = await find_org_cb(org_id) - else: - self.receiver = find_org_cb(org_id) + elif find_org_cb and find_partner_cb: + if inspect.iscoroutinefunction(find_org_cb): + self.receiver = await find_org_cb(org_id) + else: + self.receiver = find_org_cb(org_id) - if find_partner_cb: - if inspect.iscoroutinefunction(find_partner_cb): - self.sender = await find_partner_cb(partner_id) - else: - self.sender = find_partner_cb(partner_id) + if inspect.iscoroutinefunction(find_partner_cb): + self.sender = await find_partner_cb(partner_id) + else: + self.sender = find_partner_cb(partner_id) if not self.receiver: raise PartnerNotFound(f"Unknown AS2 organization with id {org_id}") @@ -980,13 +978,12 @@ async def aparse(self, raw_content, find_message_cb): self.orig_message_id, orig_recipient = self.detect_mdn() # Call the find message callback which should return a Message instance - if find_message_cb: - if inspect.iscoroutinefunction(find_message_cb): - orig_message = await find_message_cb( - self.orig_message_id, orig_recipient - ) - else: - orig_message = find_message_cb(self.orig_message_id, orig_recipient) + if inspect.iscoroutinefunction(find_message_cb): + orig_message = await find_message_cb( + self.orig_message_id, orig_recipient + ) + else: + orig_message = find_message_cb(self.orig_message_id, orig_recipient) if not orig_message: status = "failed/Failure" From cbf0659f60150ffda76610618960dc5829f778ea Mon Sep 17 00:00:00 2001 From: Wassilios Lytras Date: Wed, 20 Mar 2024 13:54:48 +0100 Subject: [PATCH 15/19] https://github.com/abhishek-ram/pyas2-lib/issues/60 and also making https://github.com/abhishek-ram/pyas2-lib/issues/62 available on the partner --- pyas2lib/as2.py | 18 ++++++++++++++++-- 1 file changed, 16 insertions(+), 2 deletions(-) diff --git a/pyas2lib/as2.py b/pyas2lib/as2.py index 190c439..74ae207 100644 --- a/pyas2lib/as2.py +++ b/pyas2lib/as2.py @@ -182,6 +182,12 @@ class Partner: :param canonicalize_as_binary: force binary canonicalization for this partner + :param sign_alg: The signing algorithm to be used for generating the + signature. (default `rsassa_pkcs1v15`) + + :param key_enc_alg: The key encryption algorithm to be used. + (default `rsaes_pkcs1v15`) + """ as2_name: str @@ -200,6 +206,8 @@ class Partner: mdn_confirm_text: str = MDN_CONFIRM_TEXT ignore_self_signed: bool = True canonicalize_as_binary: bool = False + sign_alg: str = "rsassa_pkcs1v15" + key_enc_alg: str = "rsaes_pkcs1v15" def __post_init__(self): """Run the post initialisation checks for this class.""" @@ -469,7 +477,10 @@ def build( ) del signature["MIME-Version"] signature_data = sign_message( - mic_content, self.digest_alg, self.sender.sign_key + mic_content, + self.digest_alg, + self.sender.sign_key, + self.receiver.sign_alg, ) signature.set_payload(signature_data) encoders.encode_base64(signature) @@ -930,7 +941,10 @@ def build( del signature["MIME-Version"] signed_data = sign_message( - canonicalize(self.payload), self.digest_alg, message.receiver.sign_key + canonicalize(self.payload), + self.digest_alg, + message.receiver.sign_key, + message.sender.sign_alg, ) signature.set_payload(signed_data) encoders.encode_base64(signature) From adbd933dc6e6cb73c77dd8d5f92a11b714a00e52 Mon Sep 17 00:00:00 2001 From: Wassilios Lytras Date: Wed, 17 Apr 2024 23:01:25 +0200 Subject: [PATCH 16/19] https://github.com/abhishek-ram/pyas2-lib/issues/60 Extending Partner with Signature Algo and pass the setting to signing function. --- pyas2lib/as2.py | 21 +++++++++++++++++++-- pyas2lib/constants.py | 4 ++++ pyas2lib/tests/test_advanced.py | 3 +++ 3 files changed, 26 insertions(+), 2 deletions(-) diff --git a/pyas2lib/as2.py b/pyas2lib/as2.py index 12ebac7..338084d 100644 --- a/pyas2lib/as2.py +++ b/pyas2lib/as2.py @@ -28,6 +28,7 @@ MDN_CONFIRM_TEXT, MDN_FAILED_TEXT, MDN_MODES, + SIGNATUR_ALGORITHMS, SYNCHRONOUS_MDN, ) from pyas2lib.exceptions import ( @@ -179,6 +180,9 @@ class Partner: :param canonicalize_as_binary: force binary canonicalization for this partner + :param sign_alg: The signing algorithm to be used for generating the + signature. (default `rsassa_pkcs1v15`) + """ as2_name: str @@ -197,6 +201,7 @@ class Partner: mdn_confirm_text: str = MDN_CONFIRM_TEXT ignore_self_signed: bool = True canonicalize_as_binary: bool = False + sign_alg: str = "rsassa_pkcs1v15" def __post_init__(self): """Run the post initialisation checks for this class.""" @@ -225,6 +230,12 @@ def __post_init__(self): f"must be one of {DIGEST_ALGORITHMS}" ) + if self.sign_alg and self.sign_alg not in SIGNATUR_ALGORITHMS: + raise ImproperlyConfigured( + f"Unsupported Signature Algorithm {self.sign_alg}, " + f"must be one of {SIGNATUR_ALGORITHMS}" + ) + def load_verify_cert(self): """Load the verification certificate of the partner and returned the parsed cert.""" if self.validate_certs: @@ -466,7 +477,10 @@ def build( ) del signature["MIME-Version"] signature_data = sign_message( - mic_content, self.digest_alg, self.sender.sign_key + mic_content, + self.digest_alg, + self.sender.sign_key, + self.receiver.sign_alg, ) signature.set_payload(signature_data) encoders.encode_base64(signature) @@ -865,7 +879,10 @@ def build( del signature["MIME-Version"] signed_data = sign_message( - canonicalize(self.payload), self.digest_alg, message.receiver.sign_key + canonicalize(self.payload), + self.digest_alg, + message.receiver.sign_key, + message.sender.sign_alg, ) signature.set_payload(signed_data) encoders.encode_base64(signature) diff --git a/pyas2lib/constants.py b/pyas2lib/constants.py index 53e6c1f..b5d4de2 100644 --- a/pyas2lib/constants.py +++ b/pyas2lib/constants.py @@ -28,3 +28,7 @@ "aes_192_cbc", "aes_256_cbc", ) +SIGNATUR_ALGORITHMS = ( + "rsassa_pkcs1v15", + "rsassa_pss", +) diff --git a/pyas2lib/tests/test_advanced.py b/pyas2lib/tests/test_advanced.py index 0bc6db7..f010ecd 100644 --- a/pyas2lib/tests/test_advanced.py +++ b/pyas2lib/tests/test_advanced.py @@ -334,6 +334,9 @@ def test_partner_checks(self): with self.assertRaises(ImproperlyConfigured): as2.Partner("a partner", mdn_digest_alg="xyz") + with self.assertRaises(ImproperlyConfigured): + as2.Partner("a partner", sign_alg="xyz") + def test_message_checks(self): """Test the checks and other features of Message.""" msg = as2.Message() From 0d6765af73a3997fed7ad63a64e6734c392bcb94 Mon Sep 17 00:00:00 2001 From: Wassilios Lytras Date: Wed, 1 May 2024 15:03:53 +0200 Subject: [PATCH 17/19] https://github.com/abhishek-ram/pyas2-lib/issues/62 --- pyas2lib/as2.py | 11 ++++ pyas2lib/cms.py | 101 +++++++++++++++++++------------- pyas2lib/constants.py | 4 ++ pyas2lib/tests/test_advanced.py | 3 + pyas2lib/tests/test_cms.py | 92 ++++++++++++++++++++++++++++- 5 files changed, 166 insertions(+), 45 deletions(-) diff --git a/pyas2lib/as2.py b/pyas2lib/as2.py index 338084d..cb320a3 100644 --- a/pyas2lib/as2.py +++ b/pyas2lib/as2.py @@ -25,6 +25,7 @@ DIGEST_ALGORITHMS, EDIINT_FEATURES, ENCRYPTION_ALGORITHMS, + KEY_ENCRYPTION_ALGORITHMS, MDN_CONFIRM_TEXT, MDN_FAILED_TEXT, MDN_MODES, @@ -183,6 +184,9 @@ class Partner: :param sign_alg: The signing algorithm to be used for generating the signature. (default `rsassa_pkcs1v15`) + :param key_enc_alg: The key encryption algorithm to be used. + (default `rsaes_pkcs1v15`) + """ as2_name: str @@ -202,6 +206,7 @@ class Partner: ignore_self_signed: bool = True canonicalize_as_binary: bool = False sign_alg: str = "rsassa_pkcs1v15" + key_enc_alg: str = "rsaes_pkcs1v15" def __post_init__(self): """Run the post initialisation checks for this class.""" @@ -236,6 +241,12 @@ def __post_init__(self): f"must be one of {SIGNATUR_ALGORITHMS}" ) + if self.key_enc_alg and self.key_enc_alg not in KEY_ENCRYPTION_ALGORITHMS: + raise ImproperlyConfigured( + f"Unsupported Key Encryption Algorithm {self.key_enc_alg}, " + f"must be one of {KEY_ENCRYPTION_ALGORITHMS}" + ) + def load_verify_cert(self): """Load the verification certificate of the partner and returned the parsed cert.""" if self.validate_certs: diff --git a/pyas2lib/cms.py b/pyas2lib/cms.py index 0172980..3266ef0 100644 --- a/pyas2lib/cms.py +++ b/pyas2lib/cms.py @@ -65,12 +65,15 @@ def decompress_message(compressed_data): raise DecompressionError("Decompression failed with cause: {}".format(e)) from e -def encrypt_message(data_to_encrypt, enc_alg, encryption_cert): +def encrypt_message( + data_to_encrypt, enc_alg, encryption_cert, key_enc_alg="rsaes_pkcs1v15" +): """Function encrypts data and returns the generated ASN.1 :param data_to_encrypt: A byte string of the data to be encrypted :param enc_alg: The algorithm to be used for encrypting the data :param encryption_cert: The certificate to be used for encrypting the data + :param key_enc_alg: The algo for the key encryption: rsaes_pkcs1v15 (default) or rsaes_oaep :return: A CMS ASN.1 byte string of the encrypted data. """ @@ -136,7 +139,12 @@ def encrypt_message(data_to_encrypt, enc_alg, encryption_cert): raise AS2Exception("Unsupported Encryption Algorithm") # Encrypt the key and build the ASN.1 message - encrypted_key = asymmetric.rsa_pkcs1v15_encrypt(encryption_cert, key) + if key_enc_alg == "rsaes_pkcs1v15": + encrypted_key = asymmetric.rsa_pkcs1v15_encrypt(encryption_cert, key) + elif key_enc_alg == "rsaes_oaep": + encrypted_key = asymmetric.rsa_oaep_encrypt(encryption_cert, key) + else: + raise AS2Exception(f"Unsupported Key Encryption Scheme: {key_enc_alg}") return cms.ContentInfo( { @@ -163,7 +171,11 @@ def encrypt_message(data_to_encrypt, enc_alg, encryption_cert): } ), "key_encryption_algorithm": cms.KeyEncryptionAlgorithm( - {"algorithm": cms.KeyEncryptionAlgorithmId("rsa")} + { + "algorithm": cms.KeyEncryptionAlgorithmId( + key_enc_alg + ) + } ), "encrypted_key": cms.OctetString(encrypted_key), } @@ -199,47 +211,52 @@ def decrypt_message(encrypted_data, decryption_key): key_enc_alg = recipient_info["key_encryption_algorithm"]["algorithm"].native encrypted_key = recipient_info["encrypted_key"].native - if cms.KeyEncryptionAlgorithmId(key_enc_alg) == cms.KeyEncryptionAlgorithmId( - "rsa" - ): - try: + try: + if cms.KeyEncryptionAlgorithmId( + key_enc_alg + ) == cms.KeyEncryptionAlgorithmId("rsaes_pkcs1v15"): key = asymmetric.rsa_pkcs1v15_decrypt(decryption_key[0], encrypted_key) - except Exception as e: - raise DecryptionError( - "Failed to decrypt the payload: Could not extract decryption key." - ) from e - - alg = cms_content["content"]["encrypted_content_info"][ - "content_encryption_algorithm" - ] - encapsulated_data = cms_content["content"]["encrypted_content_info"][ - "encrypted_content" - ].native - try: - if alg["algorithm"].native == "rc4": - decrypted_content = symmetric.rc4_decrypt(key, encapsulated_data) - elif alg.encryption_cipher == "tripledes": - cipher = "tripledes_192_cbc" - decrypted_content = symmetric.tripledes_cbc_pkcs5_decrypt( - key, encapsulated_data, alg.encryption_iv - ) - elif alg.encryption_cipher == "aes": - decrypted_content = symmetric.aes_cbc_pkcs7_decrypt( - key, encapsulated_data, alg.encryption_iv - ) - elif alg.encryption_cipher == "rc2": - decrypted_content = symmetric.rc2_cbc_pkcs5_decrypt( - key, encapsulated_data, alg["parameters"]["iv"].native - ) - else: - raise AS2Exception("Unsupported Encryption Algorithm") - except Exception as e: - raise DecryptionError( - "Failed to decrypt the payload: {}".format(e) - ) from e - else: - raise AS2Exception("Unsupported Encryption Algorithm") + elif cms.KeyEncryptionAlgorithmId( + key_enc_alg + ) == cms.KeyEncryptionAlgorithmId("rsaes_oaep"): + key = asymmetric.rsa_oaep_decrypt(decryption_key[0], encrypted_key) + else: + raise AS2Exception( + f"Unsupported Key Encryption Algorithm {key_enc_alg}" + ) + except Exception as e: + raise DecryptionError( + "Failed to decrypt the payload: Could not extract decryption key." + ) from e + + alg = cms_content["content"]["encrypted_content_info"][ + "content_encryption_algorithm" + ] + encapsulated_data = cms_content["content"]["encrypted_content_info"][ + "encrypted_content" + ].native + + try: + if alg["algorithm"].native == "rc4": + decrypted_content = symmetric.rc4_decrypt(key, encapsulated_data) + elif alg.encryption_cipher == "tripledes": + cipher = "tripledes_192_cbc" + decrypted_content = symmetric.tripledes_cbc_pkcs5_decrypt( + key, encapsulated_data, alg.encryption_iv + ) + elif alg.encryption_cipher == "aes": + decrypted_content = symmetric.aes_cbc_pkcs7_decrypt( + key, encapsulated_data, alg.encryption_iv + ) + elif alg.encryption_cipher == "rc2": + decrypted_content = symmetric.rc2_cbc_pkcs5_decrypt( + key, encapsulated_data, alg["parameters"]["iv"].native + ) + else: + raise AS2Exception("Unsupported Encryption Algorithm") + except Exception as e: + raise DecryptionError("Failed to decrypt the payload: {}".format(e)) from e else: raise DecryptionError("Encrypted data not found in ASN.1 ") diff --git a/pyas2lib/constants.py b/pyas2lib/constants.py index b5d4de2..f5456f7 100644 --- a/pyas2lib/constants.py +++ b/pyas2lib/constants.py @@ -32,3 +32,7 @@ "rsassa_pkcs1v15", "rsassa_pss", ) +KEY_ENCRYPTION_ALGORITHMS = ( + "rsaes_pkcs1v15", + "rsaes_oaep", +) diff --git a/pyas2lib/tests/test_advanced.py b/pyas2lib/tests/test_advanced.py index f010ecd..fefda8d 100644 --- a/pyas2lib/tests/test_advanced.py +++ b/pyas2lib/tests/test_advanced.py @@ -337,6 +337,9 @@ def test_partner_checks(self): with self.assertRaises(ImproperlyConfigured): as2.Partner("a partner", sign_alg="xyz") + with self.assertRaises(ImproperlyConfigured): + as2.Partner("a partner", key_enc_alg="xyz") + def test_message_checks(self): """Test the checks and other features of Message.""" msg = as2.Message() diff --git a/pyas2lib/tests/test_cms.py b/pyas2lib/tests/test_cms.py index 34655db..fa3f596 100644 --- a/pyas2lib/tests/test_cms.py +++ b/pyas2lib/tests/test_cms.py @@ -2,7 +2,9 @@ import os import pytest -from oscrypto import asymmetric +from oscrypto import asymmetric, symmetric, util + +from asn1crypto import algos, cms as crypto_cms, core from pyas2lib.as2 import Organization from pyas2lib import cms @@ -22,6 +24,68 @@ ).dump() +def encrypted_data_with_faulty_key_algo(): + with open(os.path.join(TEST_DIR, "cert_test_public.pem"), "rb") as fp: + encrypt_cert = asymmetric.load_certificate(fp.read()) + enc_alg_list = "rc4_128_cbc".split("_") + cipher, key_length, _ = enc_alg_list[0], enc_alg_list[1], enc_alg_list[2] + key = util.rand_bytes(int(key_length) // 8) + algorithm_id = "1.2.840.113549.3.4" + encrypted_content = symmetric.rc4_encrypt(key, b"data") + enc_alg_asn1 = algos.EncryptionAlgorithm( + { + "algorithm": algorithm_id, + } + ) + encrypted_key = asymmetric.rsa_oaep_encrypt(encrypt_cert, key) + return crypto_cms.ContentInfo( + { + "content_type": crypto_cms.ContentType("enveloped_data"), + "content": crypto_cms.EnvelopedData( + { + "version": crypto_cms.CMSVersion("v0"), + "recipient_infos": [ + crypto_cms.KeyTransRecipientInfo( + { + "version": crypto_cms.CMSVersion("v0"), + "rid": crypto_cms.RecipientIdentifier( + { + "issuer_and_serial_number": crypto_cms.IssuerAndSerialNumber( + { + "issuer": encrypt_cert.asn1[ + "tbs_certificate" + ]["issuer"], + "serial_number": encrypt_cert.asn1[ + "tbs_certificate" + ]["serial_number"], + } + ) + } + ), + "key_encryption_algorithm": crypto_cms.KeyEncryptionAlgorithm( + { + "algorithm": crypto_cms.KeyEncryptionAlgorithmId( + "aes128_wrap" + ) + } + ), + "encrypted_key": crypto_cms.OctetString(encrypted_key), + } + ) + ], + "encrypted_content_info": crypto_cms.EncryptedContentInfo( + { + "content_type": crypto_cms.ContentType("data"), + "content_encryption_algorithm": enc_alg_asn1, + "encrypted_content": encrypted_content, + } + ), + } + ), + } + ).dump() + + def test_compress(): """Test the compression and decompression functions.""" compressed_data = cms.compress_message(b"data") @@ -87,9 +151,22 @@ def test_encryption(): "aes_128_cbc", "aes_192_cbc", "aes_256_cbc", + "tripledes_192_cbc", + ] + + key_enc_algos = [ + "rsaes_oaep", + "rsaes_pkcs1v15", ] - for enc_algorithm in enc_algorithms: - encrypted_data = cms.encrypt_message(b"data", enc_algorithm, encrypt_cert) + + encryption_algos = [ + (alg, key_algo) for alg in enc_algorithms for key_algo in key_enc_algos + ] + + for enc_algorithm, encryption_scheme in encryption_algos: + encrypted_data = cms.encrypt_message( + b"data", enc_algorithm, encrypt_cert, encryption_scheme + ) _, decrypted_data = cms.decrypt_message(encrypted_data, decrypt_key) assert decrypted_data == b"data" @@ -101,3 +178,12 @@ def test_encryption(): encrypted_data = cms.encrypt_message(b"data", "des_64_cbc", encrypt_cert) with pytest.raises(AS2Exception): cms.decrypt_message(encrypted_data, decrypt_key) + + # Test faulty key encryption algorithm + with pytest.raises(AS2Exception): + cms.encrypt_message(b"data", "rc2_128_cbc", encrypt_cert, "des_64_cbc") + + # Test unsupported key encryption algorithm + encrypted_data = encrypted_data_with_faulty_key_algo() + with pytest.raises(AS2Exception): + cms.decrypt_message(encrypted_data, decrypt_key) From 5a29efb47f08a84cd4e1467279c9424de13f9d7b Mon Sep 17 00:00:00 2001 From: Wassilios Lytras Date: Thu, 2 May 2024 08:26:12 +0200 Subject: [PATCH 18/19] Asserting error messages and _encrypted_data_with_faulty_key_algo --- pyas2lib/tests/test_cms.py | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/pyas2lib/tests/test_cms.py b/pyas2lib/tests/test_cms.py index fa3f596..cf2f711 100644 --- a/pyas2lib/tests/test_cms.py +++ b/pyas2lib/tests/test_cms.py @@ -24,7 +24,7 @@ ).dump() -def encrypted_data_with_faulty_key_algo(): +def _encrypted_data_with_faulty_key_algo(): with open(os.path.join(TEST_DIR, "cert_test_public.pem"), "rb") as fp: encrypt_cert = asymmetric.load_certificate(fp.read()) enc_alg_list = "rc4_128_cbc".split("_") @@ -180,10 +180,15 @@ def test_encryption(): cms.decrypt_message(encrypted_data, decrypt_key) # Test faulty key encryption algorithm - with pytest.raises(AS2Exception): + with pytest.raises( + AS2Exception, match="Unsupported Key Encryption Scheme: des_64_cbc" + ): cms.encrypt_message(b"data", "rc2_128_cbc", encrypt_cert, "des_64_cbc") # Test unsupported key encryption algorithm - encrypted_data = encrypted_data_with_faulty_key_algo() - with pytest.raises(AS2Exception): + encrypted_data = _encrypted_data_with_faulty_key_algo() + with pytest.raises( + AS2Exception, + match="Failed to decrypt the payload: Could not extract decryption key.", + ): cms.decrypt_message(encrypted_data, decrypt_key) From 69b13d5801c598673870925a3c419cd95bc3f51e Mon Sep 17 00:00:00 2001 From: Wassilios Lytras Date: Fri, 3 May 2024 10:00:30 +0200 Subject: [PATCH 19/19] Add specific error when MDN received, but Original Message was not found. Related to https://github.com/abhishek-ram/django-pyas2/issues/45 and will be implemented/used in django-pyas2. --- pyas2lib/as2.py | 5 +++++ pyas2lib/tests/test_advanced.py | 27 +++++++++++++++++++++++++++ 2 files changed, 32 insertions(+) diff --git a/pyas2lib/as2.py b/pyas2lib/as2.py index cb320a3..9c74e54 100644 --- a/pyas2lib/as2.py +++ b/pyas2lib/as2.py @@ -943,6 +943,11 @@ def parse(self, raw_content, find_message_cb): # Call the find message callback which should return a Message instance orig_message = find_message_cb(self.orig_message_id, orig_recipient) + if not orig_message: + status = "failed/Failure" + details_status = "original-message-not-found" + return status, details_status + # Extract the headers and save it mdn_headers = {} for k, v in self.payload.items(): diff --git a/pyas2lib/tests/test_advanced.py b/pyas2lib/tests/test_advanced.py index fefda8d..a0aca23 100644 --- a/pyas2lib/tests/test_advanced.py +++ b/pyas2lib/tests/test_advanced.py @@ -396,6 +396,33 @@ def test_mdn_not_found(self): self.assertEqual(status, "failed/Failure") self.assertEqual(detailed_status, "mdn-not-found") + def test_mdn_original_message_not_found(self): + """Test that the MDN parser raises MDN not found when a non MDN message is passed.""" + self.partner.mdn_mode = as2.SYNCHRONOUS_MDN + self.out_message = as2.Message(self.org, self.partner) + self.out_message.build(self.test_data) + + # Parse the generated AS2 message as the partner + raw_out_message = ( + self.out_message.headers_str + b"\r\n" + self.out_message.content + ) + in_message = as2.Message() + _, _, mdn = in_message.parse( + raw_out_message, + find_org_cb=self.find_org, + find_partner_cb=self.find_partner, + find_message_cb=lambda x, y: False, + ) + + # Parse the MDN + out_mdn = as2.Mdn() + status, detailed_status = out_mdn.parse( + mdn.headers_str + b"\r\n" + mdn.content, find_message_cb=lambda x, y: False + ) + + self.assertEqual(status, "failed/Failure") + self.assertEqual(detailed_status, "original-message-not-found") + def test_unsigned_mdn_sent_error(self): """Test the case where a signed mdn was expected but unsigned mdn was returned.""" self.partner.mdn_mode = as2.SYNCHRONOUS_MDN