From 04cf4c36d2216d910ca3dd072045e98e4f135643 Mon Sep 17 00:00:00 2001 From: Wassilios Lytras Date: Mon, 6 May 2024 08:43:05 +0200 Subject: [PATCH] Async Support/Wrapper for callback functions. --- pyas2lib/as2.py | 80 +++++++++++++++++--- pyas2lib/tests/test_async.py | 140 +++++++++++++++++++++++++++++++++++ setup.py | 3 +- 3 files changed, 210 insertions(+), 13 deletions(-) create mode 100644 pyas2lib/tests/test_async.py diff --git a/pyas2lib/as2.py b/pyas2lib/as2.py index 2e03331..1c98df4 100644 --- a/pyas2lib/as2.py +++ b/pyas2lib/as2.py @@ -1,7 +1,9 @@ """Define the core functions/classes of the pyas2 package.""" -import logging -import hashlib +import asyncio import binascii +import hashlib +import inspect +import logging import traceback from dataclasses import dataclass from email import encoders @@ -9,6 +11,7 @@ from email import message_from_bytes as parse_mime from email import utils as email_utils from email.mime.multipart import MIMEMultipart + from oscrypto import asymmetric from pyas2lib.cms import ( @@ -564,7 +567,7 @@ def _decompress_data(self, payload): return False, payload - def parse( + async def aparse( self, raw_content, find_org_cb=None, @@ -631,11 +634,25 @@ def parse( # Get the organization and partner for this transmission org_id = unquote_as2name(as2_headers["as2-to"]) partner_id = unquote_as2name(as2_headers["as2-from"]) + if find_org_partner_cb: - self.receiver, self.sender = find_org_partner_cb(org_id, partner_id) + if inspect.iscoroutinefunction(find_org_partner_cb): + self.receiver, self.sender = await find_org_partner_cb( + org_id, partner_id + ) + else: + self.receiver, self.sender = find_org_partner_cb(org_id, partner_id) + elif find_org_cb and find_partner_cb: - self.receiver = find_org_cb(org_id) - self.sender = find_partner_cb(partner_id) + if inspect.iscoroutinefunction(find_org_cb): + self.receiver = await find_org_cb(org_id) + else: + self.receiver = find_org_cb(org_id) + + if inspect.iscoroutinefunction(find_partner_cb): + self.sender = await find_partner_cb(partner_id) + else: + self.sender = find_partner_cb(partner_id) if not self.receiver: raise PartnerNotFound(f"Unknown AS2 organization with id {org_id}") @@ -643,10 +660,15 @@ def parse( if not self.sender: raise PartnerNotFound(f"Unknown AS2 partner with id {partner_id}") - if find_message_cb and find_message_cb(self.message_id, partner_id): - raise DuplicateDocument( - "Duplicate message received, message with this ID already processed." - ) + if find_message_cb: + if inspect.iscoroutinefunction(find_message_cb): + message_exists = await find_message_cb(self.message_id, partner_id) + else: + message_exists = find_message_cb(self.message_id, partner_id) + if message_exists: + raise DuplicateDocument( + "Duplicate message received, message with this ID already processed." + ) if ( self.sender.encrypt @@ -767,6 +789,18 @@ def parse( return status, exception, mdn + def parse(self, *args, **kwargs): + """ + A synchronous wrapper for the asynchronous parse method. + It runs the parse coroutine in an event loop and returns the result. + """ + loop = asyncio.get_event_loop() + if loop.is_running(): + raise RuntimeError( + "Cannot run synchronous parse within an already running event loop, use aparse." + ) + return loop.run_until_complete(self.aparse(*args, **kwargs)) + class Mdn: """Class for handling AS2 MDNs. Includes functions for both @@ -945,7 +979,7 @@ def build( f"content:\n {mime_to_bytes(self.payload)}" ) - def parse(self, raw_content, find_message_cb): + async def aparse(self, raw_content, find_message_cb): """Function parses the RAW AS2 MDN, verifies it and extracts the processing status of the orginal AS2 message. @@ -970,7 +1004,17 @@ def parse(self, raw_content, find_message_cb): self.orig_message_id, orig_recipient = self.detect_mdn() # Call the find message callback which should return a Message instance - orig_message = find_message_cb(self.orig_message_id, orig_recipient) + if inspect.iscoroutinefunction(find_message_cb): + orig_message = await find_message_cb( + self.orig_message_id, orig_recipient + ) + else: + orig_message = find_message_cb(self.orig_message_id, orig_recipient) + + if not orig_message: + status = "failed/Failure" + details_status = "original-message-not-found" + return status, details_status if not orig_message: status = "failed/Failure" @@ -1053,6 +1097,18 @@ def parse(self, raw_content, find_message_cb): logger.error(f"Failed to parse AS2 MDN\n: {traceback.format_exc()}") return status, detailed_status + def parse(self, *args, **kwargs): + """ + A synchronous wrapper for the asynchronous parse method. + It runs the parse coroutine in an event loop and returns the result. + """ + loop = asyncio.get_event_loop() + if loop.is_running(): + raise RuntimeError( + "Cannot run synchronous parse within an already running event loop, use aparse." + ) + return loop.run_until_complete(self.aparse(*args, **kwargs)) + def detect_mdn(self): """Function checks if the received raw message is an AS2 MDN or not. diff --git a/pyas2lib/tests/test_async.py b/pyas2lib/tests/test_async.py new file mode 100644 index 0000000..5e75e8a --- /dev/null +++ b/pyas2lib/tests/test_async.py @@ -0,0 +1,140 @@ +import os + +import pytest + +from pyas2lib import as2 +from pyas2lib.tests import TEST_DIR + +with open(os.path.join(TEST_DIR, "payload.txt"), "rb") as fp: + test_data = fp.read() + +with open(os.path.join(TEST_DIR, "cert_test.p12"), "rb") as fp: + private_key = fp.read() + +with open(os.path.join(TEST_DIR, "cert_test_public.pem"), "rb") as fp: + public_key = fp.read() + +org = as2.Organization( + as2_name="some_organization", + sign_key=private_key, + sign_key_pass="test", + decrypt_key=private_key, + decrypt_key_pass="test", +) +partner = as2.Partner( + as2_name="some_partner", + verify_cert=public_key, + encrypt_cert=public_key, +) + + +async def afind_org(headers): + return org + + +async def afind_partner(headers): + return partner + + +async def afind_duplicate_message(message_id, message_recipient): + return True + + +async def afind_org_partner(as2_org, as2_partner): + return org, partner + + +@pytest.mark.asyncio +async def test_async_callbacks_with_duplicate_message(): + """Test case where async callbacks are used and a duplicate message is sent to the partner""" + + # Build an As2 message to be transmitted to partner + partner.sign = True + partner.encrypt = True + partner.mdn_mode = as2.SYNCHRONOUS_MDN + out_message = as2.Message(org, partner) + out_message.build(test_data) + + async def afind_message(message_id, message_recipient): + return out_message + + # Parse the generated AS2 message as the partner + raw_out_message = out_message.headers_str + b"\r\n" + out_message.content + in_message = as2.Message() + _, _, mdn = await in_message.aparse( + raw_out_message, + find_org_cb=afind_org, + find_partner_cb=afind_partner, + find_message_cb=afind_duplicate_message, + ) + + out_mdn = as2.Mdn() + status, detailed_status = await out_mdn.aparse( + mdn.headers_str + b"\r\n" + mdn.content, + find_message_cb=afind_message, + ) + assert status == "processed/Warning" + assert detailed_status == "duplicate-document" + + +@pytest.mark.asyncio +async def test_async_partnership(): + """Test Async Partnership callback""" + + # Build an As2 message to be transmitted to partner + out_message = as2.Message(org, partner) + out_message.build(test_data) + raw_out_message = out_message.headers_str + b"\r\n" + out_message.content + + # Parse the generated AS2 message as the partner + in_message = as2.Message() + status, _, _ = await in_message.aparse( + raw_out_message, find_org_partner_cb=afind_org_partner + ) + + # Compare contents of the input and output messages + assert status == "processed" + + +@pytest.mark.asyncio +async def test_runtime_error(): + """Test to get Runtime error when calling parse instead of aparse from Async Context""" + + with pytest.raises( + RuntimeError, + match="Cannot run synchronous parse within an already running event loop, use aparse.", + ): + out_message = as2.Message(org, partner) + out_message.build(test_data) + raw_out_message = out_message.headers_str + b"\r\n" + out_message.content + + in_message = as2.Message() + status, _, _ = in_message.parse( + raw_out_message, find_org_partner_cb=afind_org_partner + ) + + with pytest.raises( + RuntimeError, + match="Cannot run synchronous parse within an already running event loop, use aparse.", + ): + partner.sign = True + partner.encrypt = True + partner.mdn_mode = as2.SYNCHRONOUS_MDN + out_message = as2.Message(org, partner) + out_message.build(test_data) + + # Parse the generated AS2 message as the partner + raw_out_message = out_message.headers_str + b"\r\n" + out_message.content + in_message = as2.Message() + _, _, mdn = await in_message.aparse( + raw_out_message, + find_org_cb=afind_org, + find_partner_cb=afind_partner, + find_message_cb=afind_duplicate_message, + ) + + out_mdn = as2.Mdn() + _, _ = out_mdn.parse( + mdn.headers_str + b"\r\n" + mdn.content, + find_message_cb=afind_duplicate_message, + ) diff --git a/setup.py b/setup.py index e4d55b8..37128a7 100644 --- a/setup.py +++ b/setup.py @@ -7,7 +7,8 @@ ] tests_require = [ - "pytest==6.2.5", + "pytest==7.4.4", + "pytest-asyncio==0.21.1", "toml==0.10.2", "pytest-cov==2.8.1", "coverage==5.0.4",