diff --git a/AUTHORS.md b/AUTHORS.md index f8b47d9..e1e95c1 100644 --- a/AUTHORS.md +++ b/AUTHORS.md @@ -1,5 +1,5 @@ - Abhishek Ram @abhishek-ram -- Chad Gates @chadgates +- Wassilios Lytras @chadgates - Bruno Ribeiro da Silva @loop0 - Robin C Samuel @robincsamuel - Brandon Joyce @brandonjoyce diff --git a/CHANGELOG.md b/CHANGELOG.md index 356f159..eb7674e 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,8 +1,14 @@ # Release History +## 1.4.4 - 2024- + +* feat: added partnership lookup function +* feat: added support for async callback functions +* feat: added support for optional key encryption algorithm rsaes_oaep for encryption and decryption + ## 1.4.3 - 2023-01-25 -* fix: update pyopenssl version to resovle pyca/cryptography#7959 +* fix: update pyopenssl version to resolve pyca/cryptography#7959 ## 1.4.2 - 2022-12-11 diff --git a/pyas2lib/as2.py b/pyas2lib/as2.py index 12ebac7..4e580dd 100644 --- a/pyas2lib/as2.py +++ b/pyas2lib/as2.py @@ -1,7 +1,9 @@ """Define the core functions/classes of the pyas2 package.""" -import logging -import hashlib +import asyncio import binascii +import hashlib +import inspect +import logging import traceback from dataclasses import dataclass from email import encoders @@ -9,6 +11,7 @@ from email import message_from_bytes as parse_mime from email import utils as email_utils from email.mime.multipart import MIMEMultipart + from oscrypto import asymmetric from pyas2lib.cms import ( @@ -539,7 +542,14 @@ def _decompress_data(self, payload): return False, payload - def parse(self, raw_content, find_org_cb, find_partner_cb, find_message_cb=None): + async def aparse( + self, + raw_content, + find_org_cb=None, + find_partner_cb=None, + find_message_cb=None, + find_org_partner_cb=None, + ): """Function parses the RAW AS2 message; decrypts, verifies and decompresses it and extracts the payload. @@ -547,18 +557,26 @@ def parse(self, raw_content, find_org_cb, find_partner_cb, find_message_cb=None) A byte string of the received HTTP headers followed by the body. :param find_org_cb: - A callback the returns an Organization object if exists. The - as2-to header value is passed as an argument to it. + A conditional callback the returns an Organization object if exists. The + as2-to header value is passed as an argument to it. Must be provided + when find_partner_cb is provided and find_org_partner_cb is None :param find_partner_cb: - A callback the returns an Partner object if exists. The - as2-from header value is passed as an argument to it. + An conditional callback the returns an Partner object if exists. The + as2-from header value is passed as an argument to it. Must be provided + when find_org_cb is provided and find_org_partner_cb is None. :param find_message_cb: An optional callback the returns an Message object if exists in order to check for duplicates. The message id and partner id is passed as arguments to it. + :param find_org_partner_cb: + A conditional callback that return Organization object and + Partner object if exist. The as2-to and as2-from header value + are passed as an argument to it. Must be provided + when find_org_cb and find_org_partner_cb is None. + :return: A three element tuple containing (status, (exception, traceback) , mdn). The status is a string indicating the status of the @@ -567,6 +585,18 @@ def parse(self, raw_content, find_org_cb, find_partner_cb, find_message_cb=None) the partner did not request it. """ + # Validate passed arguments + if not any( + [ + find_org_cb and find_partner_cb and not find_org_partner_cb, + find_org_partner_cb and not find_partner_cb and not find_org_cb, + ] + ): + raise TypeError( + "Incorrect arguments passed: either find_org_cb and find_partner_cb " + "or only find_org_partner_cb must be passed." + ) + # Parse the raw MIME message and extract its content and headers status, detailed_status, exception, mdn = "processed", None, (None, None), None self.payload = parse_mime(raw_content) @@ -580,19 +610,44 @@ def parse(self, raw_content, find_org_cb, find_partner_cb, find_message_cb=None) try: # Get the organization and partner for this transmission org_id = unquote_as2name(as2_headers["as2-to"]) - self.receiver = find_org_cb(org_id) + partner_id = unquote_as2name(as2_headers["as2-from"]) + + if find_org_partner_cb: + if inspect.iscoroutinefunction(find_org_partner_cb): + self.receiver, self.sender = await find_org_partner_cb( + org_id, partner_id + ) + else: + self.receiver, self.sender = find_org_partner_cb(org_id, partner_id) + + else: + if find_org_cb: + if inspect.iscoroutinefunction(find_org_cb): + self.receiver = await find_org_cb(org_id) + else: + self.receiver = find_org_cb(org_id) + + if find_partner_cb: + if inspect.iscoroutinefunction(find_partner_cb): + self.sender = await find_partner_cb(partner_id) + else: + self.sender = find_partner_cb(partner_id) + if not self.receiver: raise PartnerNotFound(f"Unknown AS2 organization with id {org_id}") - partner_id = unquote_as2name(as2_headers["as2-from"]) - self.sender = find_partner_cb(partner_id) if not self.sender: raise PartnerNotFound(f"Unknown AS2 partner with id {partner_id}") - if find_message_cb and find_message_cb(self.message_id, partner_id): - raise DuplicateDocument( - "Duplicate message received, message with this ID already processed." - ) + if find_message_cb: + if inspect.iscoroutinefunction(find_message_cb): + message_exists = await find_message_cb(self.message_id, partner_id) + else: + message_exists = find_message_cb(self.message_id, partner_id) + if message_exists: + raise DuplicateDocument( + "Duplicate message received, message with this ID already processed." + ) if ( self.sender.encrypt @@ -713,6 +768,18 @@ def parse(self, raw_content, find_org_cb, find_partner_cb, find_message_cb=None) return status, exception, mdn + def parse(self, *args, **kwargs): + """ + A synchronous wrapper for the asynchronous parse method. + It runs the parse coroutine in an event loop and returns the result. + """ + loop = asyncio.get_event_loop() + if loop.is_running(): + raise RuntimeError( + "Cannot run synchronous parse within an already running event loop." + ) + return loop.run_until_complete(self.aparse(*args, **kwargs)) + class Mdn: """Class for handling AS2 MDNs. Includes functions for both @@ -915,6 +982,11 @@ def parse(self, raw_content, find_message_cb): # Call the find message callback which should return a Message instance orig_message = find_message_cb(self.orig_message_id, orig_recipient) + if not orig_message: + status = "failed/Failure" + details_status = "original-message-not-found" + return status, details_status + # Extract the headers and save it mdn_headers = {} for k, v in self.payload.items(): diff --git a/pyas2lib/cms.py b/pyas2lib/cms.py index 0172980..db3d148 100644 --- a/pyas2lib/cms.py +++ b/pyas2lib/cms.py @@ -3,7 +3,7 @@ import zlib from datetime import datetime, timezone -from asn1crypto import cms, core, algos +from asn1crypto import algos, cms, core from asn1crypto.cms import SMIMECapabilityIdentifier from oscrypto import asymmetric, symmetric, util @@ -65,12 +65,15 @@ def decompress_message(compressed_data): raise DecompressionError("Decompression failed with cause: {}".format(e)) from e -def encrypt_message(data_to_encrypt, enc_alg, encryption_cert): +def encrypt_message( + data_to_encrypt, enc_alg, encryption_cert, key_enc_alg="rsaes_pkcs1v15" +): """Function encrypts data and returns the generated ASN.1 :param data_to_encrypt: A byte string of the data to be encrypted :param enc_alg: The algorithm to be used for encrypting the data :param encryption_cert: The certificate to be used for encrypting the data + :param key_enc_alg: The algo for the key encryption: rsaes_pkcs1v15 (default) or rsaes_oaep :return: A CMS ASN.1 byte string of the encrypted data. """ @@ -136,7 +139,12 @@ def encrypt_message(data_to_encrypt, enc_alg, encryption_cert): raise AS2Exception("Unsupported Encryption Algorithm") # Encrypt the key and build the ASN.1 message - encrypted_key = asymmetric.rsa_pkcs1v15_encrypt(encryption_cert, key) + if key_enc_alg == "rsaes_pkcs1v15": + encrypted_key = asymmetric.rsa_pkcs1v15_encrypt(encryption_cert, key) + elif key_enc_alg == "rsaes_oaep": + encrypted_key = asymmetric.rsa_oaep_encrypt(encryption_cert, key) + else: + raise AS2Exception(f"Unsupported Key Encryption Scheme: {key_enc_alg}") return cms.ContentInfo( { @@ -163,7 +171,11 @@ def encrypt_message(data_to_encrypt, enc_alg, encryption_cert): } ), "key_encryption_algorithm": cms.KeyEncryptionAlgorithm( - {"algorithm": cms.KeyEncryptionAlgorithmId("rsa")} + { + "algorithm": cms.KeyEncryptionAlgorithmId( + key_enc_alg + ) + } ), "encrypted_key": cms.OctetString(encrypted_key), } @@ -200,7 +212,7 @@ def decrypt_message(encrypted_data, decryption_key): encrypted_key = recipient_info["encrypted_key"].native if cms.KeyEncryptionAlgorithmId(key_enc_alg) == cms.KeyEncryptionAlgorithmId( - "rsa" + "rsaes_pkcs1v15" ): try: key = asymmetric.rsa_pkcs1v15_decrypt(decryption_key[0], encrypted_key) @@ -208,38 +220,45 @@ def decrypt_message(encrypted_data, decryption_key): raise DecryptionError( "Failed to decrypt the payload: Could not extract decryption key." ) from e - - alg = cms_content["content"]["encrypted_content_info"][ - "content_encryption_algorithm" - ] - encapsulated_data = cms_content["content"]["encrypted_content_info"][ - "encrypted_content" - ].native - + elif cms.KeyEncryptionAlgorithmId(key_enc_alg) == cms.KeyEncryptionAlgorithmId( + "rsaes_oaep" + ): try: - if alg["algorithm"].native == "rc4": - decrypted_content = symmetric.rc4_decrypt(key, encapsulated_data) - elif alg.encryption_cipher == "tripledes": - cipher = "tripledes_192_cbc" - decrypted_content = symmetric.tripledes_cbc_pkcs5_decrypt( - key, encapsulated_data, alg.encryption_iv - ) - elif alg.encryption_cipher == "aes": - decrypted_content = symmetric.aes_cbc_pkcs7_decrypt( - key, encapsulated_data, alg.encryption_iv - ) - elif alg.encryption_cipher == "rc2": - decrypted_content = symmetric.rc2_cbc_pkcs5_decrypt( - key, encapsulated_data, alg["parameters"]["iv"].native - ) - else: - raise AS2Exception("Unsupported Encryption Algorithm") + key = asymmetric.rsa_oaep_decrypt(decryption_key[0], encrypted_key) except Exception as e: raise DecryptionError( - "Failed to decrypt the payload: {}".format(e) + "Failed to decrypt the payload: Could not extract decryption key." ) from e else: - raise AS2Exception("Unsupported Encryption Algorithm") + raise AS2Exception(f"Unsupported Key Encryption Algorithm {key_enc_alg}") + + alg = cms_content["content"]["encrypted_content_info"][ + "content_encryption_algorithm" + ] + encapsulated_data = cms_content["content"]["encrypted_content_info"][ + "encrypted_content" + ].native + + try: + if alg["algorithm"].native == "rc4": + decrypted_content = symmetric.rc4_decrypt(key, encapsulated_data) + elif alg.encryption_cipher == "tripledes": + cipher = "tripledes_192_cbc" + decrypted_content = symmetric.tripledes_cbc_pkcs5_decrypt( + key, encapsulated_data, alg.encryption_iv + ) + elif alg.encryption_cipher == "aes": + decrypted_content = symmetric.aes_cbc_pkcs7_decrypt( + key, encapsulated_data, alg.encryption_iv + ) + elif alg.encryption_cipher == "rc2": + decrypted_content = symmetric.rc2_cbc_pkcs5_decrypt( + key, encapsulated_data, alg["parameters"]["iv"].native + ) + else: + raise AS2Exception("Unsupported Encryption Algorithm") + except Exception as e: + raise DecryptionError("Failed to decrypt the payload: {}".format(e)) from e else: raise DecryptionError("Encrypted data not found in ASN.1 ") diff --git a/pyas2lib/tests/test_advanced.py b/pyas2lib/tests/test_advanced.py index 0bc6db7..490b74a 100644 --- a/pyas2lib/tests/test_advanced.py +++ b/pyas2lib/tests/test_advanced.py @@ -9,6 +9,8 @@ from pyas2lib.exceptions import ImproperlyConfigured from pyas2lib.tests import Pyas2TestCase, TEST_DIR +import asyncio + class TestAdvanced(Pyas2TestCase): def setUp(self): @@ -515,6 +517,57 @@ def test_final_recipient_fallback(self): self.assertEqual(message_recipient, self.partner.as2_name) + @pytest.mark.asyncio + async def test_duplicate_message_async(self): + """Test case where a duplicate message is sent to the partner using async callbacks""" + + # Build an As2 message to be transmitted to partner + self.partner.sign = True + self.partner.encrypt = True + self.partner.mdn_mode = as2.SYNCHRONOUS_MDN + self.out_message = as2.Message(self.org, self.partner) + self.out_message.build(self.test_data) + + # Parse the generated AS2 message as the partner + raw_out_message = ( + self.out_message.headers_str + b"\r\n" + self.out_message.content + ) + in_message = as2.Message() + _, _, mdn = await in_message.parse( + raw_out_message, + find_org_cb=self.afind_org, + find_partner_cb=self.afind_partner, + find_message_cb=self.afind_duplicate_message, + ) + + out_mdn = as2.Mdn() + status, detailed_status = await out_mdn.parse( + mdn.headers_str + b"\r\n" + mdn.content, find_message_cb=self.afind_message + ) + self.assertEqual(status, "processed/Warning") + self.assertEqual(detailed_status, "duplicate-document") + + @pytest.mark.asyncio + async def test_async_partnership(self): + """Test Async Partnership callback with Unencrypted Unsigned Uncompressed Message""" + + # Build an As2 message to be transmitted to partner + out_message = as2.Message(self.org, self.partner) + out_message.build(self.test_data) + raw_out_message = out_message.headers_str + b"\r\n" + out_message.content + + # Parse the generated AS2 message as the partner + in_message = as2.Message() + status, _, _ = in_message.parse( + raw_out_message, + find_org_cb=self.find_org, + find_partner_cb=self.afind_org_partner, + ) + + # Compare contents of the input and output messages + self.assertEqual(status, "processed") + self.assertEqual(self.test_data, in_message.content) + def find_org(self, headers): return self.org @@ -524,6 +577,18 @@ def find_partner(self, headers): def find_message(self, message_id, message_recipient): return self.out_message + async def afind_org(self, headers): + return self.org + + async def afind_partner(self, headers): + return self.partner + + async def afind_duplicate_message(self, message_id, message_recipient): + return True + + async def afind_org_partner(self, as2_org, as2_partner): + return self.org, self.partner + class SterlingIntegratorTest(Pyas2TestCase): def setUp(self): diff --git a/pyas2lib/tests/test_basic.py b/pyas2lib/tests/test_basic.py index fed94b9..87f09cb 100644 --- a/pyas2lib/tests/test_basic.py +++ b/pyas2lib/tests/test_basic.py @@ -184,6 +184,30 @@ def test_encrypted_signed_compressed_message(self): self.assertEqual(out_message.mic, in_message.mic) self.assertEqual(self.test_data.splitlines(), in_message.content.splitlines()) + def test_encrypted_signed_message_partnership(self): + """Test Encrypted Signed Uncompressed Message with Partnership""" + + # Build an As2 message to be transmitted to partner + self.partner.sign = True + self.partner.encrypt = True + out_message = as2.Message(self.org, self.partner) + out_message.build(self.test_data) + raw_out_message = out_message.headers_str + b"\r\n" + out_message.content + + # Parse the generated AS2 message as the partner + in_message = as2.Message() + status, _, _ = in_message.parse( + raw_out_message, + find_org_partner_cb=self.find_org_partner, + ) + + # Compare the mic contents of the input and output messages + self.assertEqual(status, "processed") + self.assertTrue(in_message.signed) + self.assertTrue(in_message.encrypted) + self.assertEqual(out_message.mic, in_message.mic) + self.assertEqual(self.test_data.splitlines(), in_message.content.splitlines()) + def test_plain_message_with_domain(self): """Test Message building with an org domain""" @@ -229,3 +253,6 @@ def find_org(self, as2_id): def find_partner(self, as2_id): return self.partner + + def find_org_partner(self, as2_org, as2_partner): + return self.org, self.partner diff --git a/pyas2lib/tests/test_cms.py b/pyas2lib/tests/test_cms.py index 34655db..a324fa7 100644 --- a/pyas2lib/tests/test_cms.py +++ b/pyas2lib/tests/test_cms.py @@ -88,8 +88,20 @@ def test_encryption(): "aes_192_cbc", "aes_256_cbc", ] - for enc_algorithm in enc_algorithms: - encrypted_data = cms.encrypt_message(b"data", enc_algorithm, encrypt_cert) + + key_enc_algos = [ + "rsaes_oaep", + "rsaes_pkcs1v15", + ] + + encryption_algos = [ + (alg, scheme) for alg, scheme in zip(enc_algorithms, key_enc_algos) + ] + + for enc_algorithm, encryption_scheme in encryption_algos: + encrypted_data = cms.encrypt_message( + b"data", enc_algorithm, encrypt_cert, encryption_scheme + ) _, decrypted_data = cms.decrypt_message(encrypted_data, decrypt_key) assert decrypted_data == b"data" @@ -101,3 +113,7 @@ def test_encryption(): encrypted_data = cms.encrypt_message(b"data", "des_64_cbc", encrypt_cert) with pytest.raises(AS2Exception): cms.decrypt_message(encrypted_data, decrypt_key) + + # Test faulty key encryption algorithm + with pytest.raises(AS2Exception): + cms.encrypt_message(b"data", "rc2_128_cbc", encrypt_cert, "des_64_cbc") diff --git a/setup.py b/setup.py index e4d55b8..37128a7 100644 --- a/setup.py +++ b/setup.py @@ -7,7 +7,8 @@ ] tests_require = [ - "pytest==6.2.5", + "pytest==7.4.4", + "pytest-asyncio==0.21.1", "toml==0.10.2", "pytest-cov==2.8.1", "coverage==5.0.4",