Skip to content

Commit

Permalink
Async Support/Wrapper for callback functions.
Browse files Browse the repository at this point in the history
  • Loading branch information
chadgates committed May 6, 2024
1 parent c97ec7a commit 04cf4c3
Show file tree
Hide file tree
Showing 3 changed files with 210 additions and 13 deletions.
80 changes: 68 additions & 12 deletions pyas2lib/as2.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,17 @@
"""Define the core functions/classes of the pyas2 package."""
import logging
import hashlib
import asyncio
import binascii
import hashlib
import inspect
import logging
import traceback
from dataclasses import dataclass
from email import encoders
from email import message as email_message
from email import message_from_bytes as parse_mime
from email import utils as email_utils
from email.mime.multipart import MIMEMultipart

from oscrypto import asymmetric

from pyas2lib.cms import (
Expand Down Expand Up @@ -564,7 +567,7 @@ def _decompress_data(self, payload):

return False, payload

def parse(
async def aparse(
self,
raw_content,
find_org_cb=None,
Expand Down Expand Up @@ -631,22 +634,41 @@ def parse(
# Get the organization and partner for this transmission
org_id = unquote_as2name(as2_headers["as2-to"])
partner_id = unquote_as2name(as2_headers["as2-from"])

if find_org_partner_cb:
self.receiver, self.sender = find_org_partner_cb(org_id, partner_id)
if inspect.iscoroutinefunction(find_org_partner_cb):
self.receiver, self.sender = await find_org_partner_cb(
org_id, partner_id
)
else:
self.receiver, self.sender = find_org_partner_cb(org_id, partner_id)

elif find_org_cb and find_partner_cb:
self.receiver = find_org_cb(org_id)
self.sender = find_partner_cb(partner_id)
if inspect.iscoroutinefunction(find_org_cb):
self.receiver = await find_org_cb(org_id)
else:
self.receiver = find_org_cb(org_id)

if inspect.iscoroutinefunction(find_partner_cb):
self.sender = await find_partner_cb(partner_id)
else:
self.sender = find_partner_cb(partner_id)

if not self.receiver:
raise PartnerNotFound(f"Unknown AS2 organization with id {org_id}")

if not self.sender:
raise PartnerNotFound(f"Unknown AS2 partner with id {partner_id}")

if find_message_cb and find_message_cb(self.message_id, partner_id):
raise DuplicateDocument(
"Duplicate message received, message with this ID already processed."
)
if find_message_cb:
if inspect.iscoroutinefunction(find_message_cb):
message_exists = await find_message_cb(self.message_id, partner_id)
else:
message_exists = find_message_cb(self.message_id, partner_id)
if message_exists:
raise DuplicateDocument(
"Duplicate message received, message with this ID already processed."
)

if (
self.sender.encrypt
Expand Down Expand Up @@ -767,6 +789,18 @@ def parse(

return status, exception, mdn

def parse(self, *args, **kwargs):
"""
A synchronous wrapper for the asynchronous parse method.
It runs the parse coroutine in an event loop and returns the result.
"""
loop = asyncio.get_event_loop()
if loop.is_running():
raise RuntimeError(
"Cannot run synchronous parse within an already running event loop, use aparse."
)
return loop.run_until_complete(self.aparse(*args, **kwargs))


class Mdn:
"""Class for handling AS2 MDNs. Includes functions for both
Expand Down Expand Up @@ -945,7 +979,7 @@ def build(
f"content:\n {mime_to_bytes(self.payload)}"
)

def parse(self, raw_content, find_message_cb):
async def aparse(self, raw_content, find_message_cb):
"""Function parses the RAW AS2 MDN, verifies it and extracts the
processing status of the orginal AS2 message.
Expand All @@ -970,7 +1004,17 @@ def parse(self, raw_content, find_message_cb):
self.orig_message_id, orig_recipient = self.detect_mdn()

# Call the find message callback which should return a Message instance
orig_message = find_message_cb(self.orig_message_id, orig_recipient)
if inspect.iscoroutinefunction(find_message_cb):
orig_message = await find_message_cb(
self.orig_message_id, orig_recipient
)
else:
orig_message = find_message_cb(self.orig_message_id, orig_recipient)

if not orig_message:
status = "failed/Failure"
details_status = "original-message-not-found"
return status, details_status

if not orig_message:
status = "failed/Failure"
Expand Down Expand Up @@ -1053,6 +1097,18 @@ def parse(self, raw_content, find_message_cb):
logger.error(f"Failed to parse AS2 MDN\n: {traceback.format_exc()}")
return status, detailed_status

def parse(self, *args, **kwargs):
"""
A synchronous wrapper for the asynchronous parse method.
It runs the parse coroutine in an event loop and returns the result.
"""
loop = asyncio.get_event_loop()
if loop.is_running():
raise RuntimeError(
"Cannot run synchronous parse within an already running event loop, use aparse."
)
return loop.run_until_complete(self.aparse(*args, **kwargs))

def detect_mdn(self):
"""Function checks if the received raw message is an AS2 MDN or not.
Expand Down
140 changes: 140 additions & 0 deletions pyas2lib/tests/test_async.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,140 @@
import os

import pytest

from pyas2lib import as2
from pyas2lib.tests import TEST_DIR

with open(os.path.join(TEST_DIR, "payload.txt"), "rb") as fp:
test_data = fp.read()

with open(os.path.join(TEST_DIR, "cert_test.p12"), "rb") as fp:
private_key = fp.read()

with open(os.path.join(TEST_DIR, "cert_test_public.pem"), "rb") as fp:
public_key = fp.read()

org = as2.Organization(
as2_name="some_organization",
sign_key=private_key,
sign_key_pass="test",
decrypt_key=private_key,
decrypt_key_pass="test",
)
partner = as2.Partner(
as2_name="some_partner",
verify_cert=public_key,
encrypt_cert=public_key,
)


async def afind_org(headers):
return org


async def afind_partner(headers):
return partner


async def afind_duplicate_message(message_id, message_recipient):
return True


async def afind_org_partner(as2_org, as2_partner):
return org, partner


@pytest.mark.asyncio
async def test_async_callbacks_with_duplicate_message():
"""Test case where async callbacks are used and a duplicate message is sent to the partner"""

# Build an As2 message to be transmitted to partner
partner.sign = True
partner.encrypt = True
partner.mdn_mode = as2.SYNCHRONOUS_MDN
out_message = as2.Message(org, partner)
out_message.build(test_data)

async def afind_message(message_id, message_recipient):
return out_message

# Parse the generated AS2 message as the partner
raw_out_message = out_message.headers_str + b"\r\n" + out_message.content
in_message = as2.Message()
_, _, mdn = await in_message.aparse(
raw_out_message,
find_org_cb=afind_org,
find_partner_cb=afind_partner,
find_message_cb=afind_duplicate_message,
)

out_mdn = as2.Mdn()
status, detailed_status = await out_mdn.aparse(
mdn.headers_str + b"\r\n" + mdn.content,
find_message_cb=afind_message,
)
assert status == "processed/Warning"
assert detailed_status == "duplicate-document"


@pytest.mark.asyncio
async def test_async_partnership():
"""Test Async Partnership callback"""

# Build an As2 message to be transmitted to partner
out_message = as2.Message(org, partner)
out_message.build(test_data)
raw_out_message = out_message.headers_str + b"\r\n" + out_message.content

# Parse the generated AS2 message as the partner
in_message = as2.Message()
status, _, _ = await in_message.aparse(
raw_out_message, find_org_partner_cb=afind_org_partner
)

# Compare contents of the input and output messages
assert status == "processed"


@pytest.mark.asyncio
async def test_runtime_error():
"""Test to get Runtime error when calling parse instead of aparse from Async Context"""

with pytest.raises(
RuntimeError,
match="Cannot run synchronous parse within an already running event loop, use aparse.",
):
out_message = as2.Message(org, partner)
out_message.build(test_data)
raw_out_message = out_message.headers_str + b"\r\n" + out_message.content

in_message = as2.Message()
status, _, _ = in_message.parse(
raw_out_message, find_org_partner_cb=afind_org_partner
)

with pytest.raises(
RuntimeError,
match="Cannot run synchronous parse within an already running event loop, use aparse.",
):
partner.sign = True
partner.encrypt = True
partner.mdn_mode = as2.SYNCHRONOUS_MDN
out_message = as2.Message(org, partner)
out_message.build(test_data)

# Parse the generated AS2 message as the partner
raw_out_message = out_message.headers_str + b"\r\n" + out_message.content
in_message = as2.Message()
_, _, mdn = await in_message.aparse(
raw_out_message,
find_org_cb=afind_org,
find_partner_cb=afind_partner,
find_message_cb=afind_duplicate_message,
)

out_mdn = as2.Mdn()
_, _ = out_mdn.parse(
mdn.headers_str + b"\r\n" + mdn.content,
find_message_cb=afind_duplicate_message,
)
3 changes: 2 additions & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,8 @@
]

tests_require = [
"pytest==6.2.5",
"pytest==7.4.4",
"pytest-asyncio==0.21.1",
"toml==0.10.2",
"pytest-cov==2.8.1",
"coverage==5.0.4",
Expand Down

0 comments on commit 04cf4c3

Please sign in to comment.