From a93d4c17deacfc09f684a606b148fd67751e1309 Mon Sep 17 00:00:00 2001 From: Wassilios Lytras Date: Mon, 18 Mar 2024 20:25:06 +0100 Subject: [PATCH] Add Async callback to MDN --- pyas2lib/as2.py | 24 +++++++++++++++++++++--- 1 file changed, 21 insertions(+), 3 deletions(-) diff --git a/pyas2lib/as2.py b/pyas2lib/as2.py index 4e580dd..ea288b0 100644 --- a/pyas2lib/as2.py +++ b/pyas2lib/as2.py @@ -776,7 +776,7 @@ def parse(self, *args, **kwargs): loop = asyncio.get_event_loop() if loop.is_running(): raise RuntimeError( - "Cannot run synchronous parse within an already running event loop." + "Cannot run synchronous parse within an already running event loop, use aparse." ) return loop.run_until_complete(self.aparse(*args, **kwargs)) @@ -955,7 +955,7 @@ def build( f"content:\n {mime_to_bytes(self.payload)}" ) - def parse(self, raw_content, find_message_cb): + async def aparse(self, raw_content, find_message_cb): """Function parses the RAW AS2 MDN, verifies it and extracts the processing status of the orginal AS2 message. @@ -980,7 +980,13 @@ def parse(self, raw_content, find_message_cb): self.orig_message_id, orig_recipient = self.detect_mdn() # Call the find message callback which should return a Message instance - orig_message = find_message_cb(self.orig_message_id, orig_recipient) + if find_message_cb: + if inspect.iscoroutinefunction(find_message_cb): + orig_message = await find_message_cb( + self.orig_message_id, orig_recipient + ) + else: + orig_message = find_message_cb(self.orig_message_id, orig_recipient) if not orig_message: status = "failed/Failure" @@ -1063,6 +1069,18 @@ def parse(self, raw_content, find_message_cb): logger.error(f"Failed to parse AS2 MDN\n: {traceback.format_exc()}") return status, detailed_status + def parse(self, *args, **kwargs): + """ + A synchronous wrapper for the asynchronous parse method. + It runs the parse coroutine in an event loop and returns the result. + """ + loop = asyncio.get_event_loop() + if loop.is_running(): + raise RuntimeError( + "Cannot run synchronous parse within an already running event loop, use aparse." + ) + return loop.run_until_complete(self.aparse(*args, **kwargs)) + def detect_mdn(self): """Function checks if the received raw message is an AS2 MDN or not.