Skip to content

Commit

Permalink
Adding Tests
Browse files Browse the repository at this point in the history
  • Loading branch information
chadgates committed Mar 19, 2024
1 parent 23a8c39 commit c292302
Show file tree
Hide file tree
Showing 4 changed files with 149 additions and 21 deletions.
34 changes: 16 additions & 18 deletions pyas2lib/cms.py
Original file line number Diff line number Diff line change
Expand Up @@ -211,26 +211,24 @@ def decrypt_message(encrypted_data, decryption_key):
key_enc_alg = recipient_info["key_encryption_algorithm"]["algorithm"].native
encrypted_key = recipient_info["encrypted_key"].native

if cms.KeyEncryptionAlgorithmId(key_enc_alg) == cms.KeyEncryptionAlgorithmId(
"rsaes_pkcs1v15"
):
try:
try:
if cms.KeyEncryptionAlgorithmId(
key_enc_alg
) == cms.KeyEncryptionAlgorithmId("rsaes_pkcs1v15"):
key = asymmetric.rsa_pkcs1v15_decrypt(decryption_key[0], encrypted_key)
except Exception as e:
raise DecryptionError(
"Failed to decrypt the payload: Could not extract decryption key."
) from e
elif cms.KeyEncryptionAlgorithmId(key_enc_alg) == cms.KeyEncryptionAlgorithmId(
"rsaes_oaep"
):
try:

elif cms.KeyEncryptionAlgorithmId(
key_enc_alg
) == cms.KeyEncryptionAlgorithmId("rsaes_oaep"):
key = asymmetric.rsa_oaep_decrypt(decryption_key[0], encrypted_key)
except Exception as e:
raise DecryptionError(
"Failed to decrypt the payload: Could not extract decryption key."
) from e
else:
raise AS2Exception(f"Unsupported Key Encryption Algorithm {key_enc_alg}")
else:
raise AS2Exception(
f"Unsupported Key Encryption Algorithm {key_enc_alg}"
)
except Exception as e:
raise DecryptionError(
"Failed to decrypt the payload: Could not extract decryption key."
) from e

alg = cms_content["content"]["encrypted_content_info"][
"content_encryption_algorithm"
Expand Down
27 changes: 27 additions & 0 deletions pyas2lib/tests/test_advanced.py
Original file line number Diff line number Diff line change
Expand Up @@ -390,6 +390,33 @@ def test_mdn_not_found(self):
self.assertEqual(status, "failed/Failure")
self.assertEqual(detailed_status, "mdn-not-found")

def test_mdn_original_message_not_found(self):
"""Test that the MDN parser raises MDN not found when a non MDN message is passed."""
self.partner.mdn_mode = as2.SYNCHRONOUS_MDN
self.out_message = as2.Message(self.org, self.partner)
self.out_message.build(self.test_data)

# Parse the generated AS2 message as the partner
raw_out_message = (
self.out_message.headers_str + b"\r\n" + self.out_message.content
)
in_message = as2.Message()
_, _, mdn = in_message.parse(
raw_out_message,
find_org_cb=self.find_org,
find_partner_cb=self.find_partner,
find_message_cb=lambda x, y: False,
)

# Parse the MDN
out_mdn = as2.Mdn()
status, detailed_status = out_mdn.parse(
mdn.headers_str + b"\r\n" + mdn.content, find_message_cb=lambda x, y: False
)

self.assertEqual(status, "failed/Failure")
self.assertEqual(detailed_status, "original-message-not-found")

def test_unsigned_mdn_sent_error(self):
"""Test the case where a signed mdn was expected but unsigned mdn was returned."""
self.partner.mdn_mode = as2.SYNCHRONOUS_MDN
Expand Down
35 changes: 34 additions & 1 deletion pyas2lib/tests/test_async.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,9 @@ async def test_duplicate_message_async():
out_message = as2.Message(org, partner)
out_message.build(test_data)

async def afind_message(message_id, message_recipient):
return out_message

# Parse the generated AS2 message as the partner
raw_out_message = out_message.headers_str + b"\r\n" + out_message.content
in_message = as2.Message()
Expand All @@ -67,7 +70,7 @@ async def test_duplicate_message_async():
out_mdn = as2.Mdn()
status, detailed_status = await out_mdn.aparse(
mdn.headers_str + b"\r\n" + mdn.content,
find_message_cb=lambda x, y: out_message,
find_message_cb=afind_message,
)
assert status == "processed/Warning"
assert detailed_status == "duplicate-document"
Expand All @@ -90,3 +93,33 @@ async def test_async_partnership():

# Compare contents of the input and output messages
assert status == "processed"


@pytest.mark.asyncio
async def test_runtime_error():
with pytest.raises(RuntimeError):
out_message = as2.Message(org, partner)
out_message.build(test_data)
raw_out_message = out_message.headers_str + b"\r\n" + out_message.content

in_message = as2.Message()
status, _, _ = in_message.parse(
raw_out_message, find_org_partner_cb=afind_org_partner
)

with pytest.raises(RuntimeError):
partner.sign = True
partner.encrypt = True
partner.mdn_mode = as2.SYNCHRONOUS_MDN
out_message = as2.Message(org, partner)
out_message.build(test_data)

# Parse the generated AS2 message as the partner
raw_out_message = out_message.headers_str + b"\r\n" + out_message.content
in_message = as2.Message()
_, _, mdn = in_message.parse(
raw_out_message,
find_org_cb=afind_org,
find_partner_cb=afind_partner,
find_message_cb=afind_duplicate_message,
)
74 changes: 72 additions & 2 deletions pyas2lib/tests/test_cms.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,9 @@
import os

import pytest
from oscrypto import asymmetric
from oscrypto import asymmetric, symmetric, util

from asn1crypto import algos, cms as crypto_cms, core

from pyas2lib.as2 import Organization
from pyas2lib import cms
Expand All @@ -22,6 +24,68 @@
).dump()


def encrypted_data_with_faulty_key_algo():
with open(os.path.join(TEST_DIR, "cert_test_public.pem"), "rb") as fp:
encrypt_cert = asymmetric.load_certificate(fp.read())
enc_alg_list = "rc4_128_cbc".split("_")
cipher, key_length, _ = enc_alg_list[0], enc_alg_list[1], enc_alg_list[2]
key = util.rand_bytes(int(key_length) // 8)
algorithm_id = "1.2.840.113549.3.4"
encrypted_content = symmetric.rc4_encrypt(key, b"data")
enc_alg_asn1 = algos.EncryptionAlgorithm(
{
"algorithm": algorithm_id,
}
)
encrypted_key = asymmetric.rsa_oaep_encrypt(encrypt_cert, key)
return crypto_cms.ContentInfo(
{
"content_type": crypto_cms.ContentType("enveloped_data"),
"content": crypto_cms.EnvelopedData(
{
"version": crypto_cms.CMSVersion("v0"),
"recipient_infos": [
crypto_cms.KeyTransRecipientInfo(
{
"version": crypto_cms.CMSVersion("v0"),
"rid": crypto_cms.RecipientIdentifier(
{
"issuer_and_serial_number": crypto_cms.IssuerAndSerialNumber(
{
"issuer": encrypt_cert.asn1[
"tbs_certificate"
]["issuer"],
"serial_number": encrypt_cert.asn1[
"tbs_certificate"
]["serial_number"],
}
)
}
),
"key_encryption_algorithm": crypto_cms.KeyEncryptionAlgorithm(
{
"algorithm": crypto_cms.KeyEncryptionAlgorithmId(
"aes128_wrap"
)
}
),
"encrypted_key": crypto_cms.OctetString(encrypted_key),
}
)
],
"encrypted_content_info": crypto_cms.EncryptedContentInfo(
{
"content_type": crypto_cms.ContentType("data"),
"content_encryption_algorithm": enc_alg_asn1,
"encrypted_content": encrypted_content,
}
),
}
),
}
).dump()


def test_compress():
"""Test the compression and decompression functions."""
compressed_data = cms.compress_message(b"data")
Expand Down Expand Up @@ -87,6 +151,7 @@ def test_encryption():
"aes_128_cbc",
"aes_192_cbc",
"aes_256_cbc",
"tripledes_192_cbc",
]

key_enc_algos = [
Expand All @@ -95,7 +160,7 @@ def test_encryption():
]

encryption_algos = [
(alg, scheme) for alg, scheme in zip(enc_algorithms, key_enc_algos)
(alg, key_algo) for alg in enc_algorithms for key_algo in key_enc_algos
]

for enc_algorithm, encryption_scheme in encryption_algos:
Expand All @@ -117,3 +182,8 @@ def test_encryption():
# Test faulty key encryption algorithm
with pytest.raises(AS2Exception):
cms.encrypt_message(b"data", "rc2_128_cbc", encrypt_cert, "des_64_cbc")

# Test unsupported key encryption algorithm
encrypted_data = encrypted_data_with_faulty_key_algo()
with pytest.raises(AS2Exception):
cms.decrypt_message(encrypted_data, decrypt_key)

0 comments on commit c292302

Please sign in to comment.