diff --git a/pyas2lib/as2.py b/pyas2lib/as2.py index 9c74e54..e65c288 100644 --- a/pyas2lib/as2.py +++ b/pyas2lib/as2.py @@ -564,7 +564,14 @@ def _decompress_data(self, payload): return False, payload - def parse(self, raw_content, find_org_cb, find_partner_cb, find_message_cb=None): + def parse( + self, + raw_content, + find_org_cb=None, + find_partner_cb=None, + find_message_cb=None, + find_org_partner_cb=None, + ): """Function parses the RAW AS2 message; decrypts, verifies and decompresses it and extracts the payload. @@ -572,18 +579,24 @@ def parse(self, raw_content, find_org_cb, find_partner_cb, find_message_cb=None) A byte string of the received HTTP headers followed by the body. :param find_org_cb: - A callback the returns an Organization object if exists. The + A conditional callback the returns an Organization object if exists. The as2-to header value is passed as an argument to it. :param find_partner_cb: - A callback the returns an Partner object if exists. The + A conditional callback the returns a Partner object if exists. The as2-from header value is passed as an argument to it. :param find_message_cb: - An optional callback the returns an Message object if exists in + An optional callback the returns a Message object if exists in order to check for duplicates. The message id and partner id is passed as arguments to it. + :param find_org_partner_cb: + A conditional callback that return Organization object and + Partner object if exist. The as2-to and as2-from header value + are passed as an argument to it. Must be provided + when find_org_cb and find_org_partner_cb is None. + :return: A three element tuple containing (status, (exception, traceback) , mdn). The status is a string indicating the status of the @@ -592,6 +605,18 @@ def parse(self, raw_content, find_org_cb, find_partner_cb, find_message_cb=None) the partner did not request it. """ + # Validate passed arguments + if not any( + [ + find_org_cb and find_partner_cb and not find_org_partner_cb, + find_org_partner_cb and not find_partner_cb and not find_org_cb, + ] + ): + raise TypeError( + "Incorrect arguments passed: either find_org_cb and find_partner_cb " + "or only find_org_partner_cb must be passed." + ) + # Parse the raw MIME message and extract its content and headers status, detailed_status, exception, mdn = "processed", None, (None, None), None self.payload = parse_mime(raw_content) @@ -605,14 +630,17 @@ def parse(self, raw_content, find_org_cb, find_partner_cb, find_message_cb=None) try: # Get the organization and partner for this transmission org_id = unquote_as2name(as2_headers["as2-to"]) - self.receiver = find_org_cb(org_id) - if not self.receiver: - raise PartnerNotFound(f"Unknown AS2 organization with id {org_id}") - partner_id = unquote_as2name(as2_headers["as2-from"]) - self.sender = find_partner_cb(partner_id) - if not self.sender: - raise PartnerNotFound(f"Unknown AS2 partner with id {partner_id}") + if find_org_partner_cb: + self.receiver, self.sender = find_org_partner_cb(org_id, partner_id) + elif find_org_cb and find_partner_cb: + self.receiver = find_org_cb(org_id) + if not self.receiver: + raise PartnerNotFound(f"Unknown AS2 organization with id {org_id}") + + self.sender = find_partner_cb(partner_id) + if not self.sender: + raise PartnerNotFound(f"Unknown AS2 partner with id {partner_id}") if find_message_cb and find_message_cb(self.message_id, partner_id): raise DuplicateDocument( diff --git a/pyas2lib/tests/test_basic.py b/pyas2lib/tests/test_basic.py index fed94b9..87f09cb 100644 --- a/pyas2lib/tests/test_basic.py +++ b/pyas2lib/tests/test_basic.py @@ -184,6 +184,30 @@ def test_encrypted_signed_compressed_message(self): self.assertEqual(out_message.mic, in_message.mic) self.assertEqual(self.test_data.splitlines(), in_message.content.splitlines()) + def test_encrypted_signed_message_partnership(self): + """Test Encrypted Signed Uncompressed Message with Partnership""" + + # Build an As2 message to be transmitted to partner + self.partner.sign = True + self.partner.encrypt = True + out_message = as2.Message(self.org, self.partner) + out_message.build(self.test_data) + raw_out_message = out_message.headers_str + b"\r\n" + out_message.content + + # Parse the generated AS2 message as the partner + in_message = as2.Message() + status, _, _ = in_message.parse( + raw_out_message, + find_org_partner_cb=self.find_org_partner, + ) + + # Compare the mic contents of the input and output messages + self.assertEqual(status, "processed") + self.assertTrue(in_message.signed) + self.assertTrue(in_message.encrypted) + self.assertEqual(out_message.mic, in_message.mic) + self.assertEqual(self.test_data.splitlines(), in_message.content.splitlines()) + def test_plain_message_with_domain(self): """Test Message building with an org domain""" @@ -229,3 +253,6 @@ def find_org(self, as2_id): def find_partner(self, as2_id): return self.partner + + def find_org_partner(self, as2_org, as2_partner): + return self.org, self.partner