Skip to content

Commit

Permalink
Extending partner to accept key encryption algo and pass that down.
Browse files Browse the repository at this point in the history
* Fix github action build fail due to:
https://stackoverflow.com/questions/71673404/importerror-cannot-import-name-unicodefun-from-click

* Added partner setting to force canonicalize binary.

* Formatted with black

* #62

* Asserting error messages and _encrypted_data_with_faulty_key_algo
  • Loading branch information
chadgates authored May 2, 2024
1 parent 499fd03 commit 9b400b1
Show file tree
Hide file tree
Showing 5 changed files with 171 additions and 45 deletions.
11 changes: 11 additions & 0 deletions pyas2lib/as2.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
DIGEST_ALGORITHMS,
EDIINT_FEATURES,
ENCRYPTION_ALGORITHMS,
KEY_ENCRYPTION_ALGORITHMS,
MDN_CONFIRM_TEXT,
MDN_FAILED_TEXT,
MDN_MODES,
Expand Down Expand Up @@ -183,6 +184,9 @@ class Partner:
:param sign_alg: The signing algorithm to be used for generating the
signature. (default `rsassa_pkcs1v15`)
:param key_enc_alg: The key encryption algorithm to be used.
(default `rsaes_pkcs1v15`)
"""

as2_name: str
Expand All @@ -202,6 +206,7 @@ class Partner:
ignore_self_signed: bool = True
canonicalize_as_binary: bool = False
sign_alg: str = "rsassa_pkcs1v15"
key_enc_alg: str = "rsaes_pkcs1v15"

def __post_init__(self):
"""Run the post initialisation checks for this class."""
Expand Down Expand Up @@ -236,6 +241,12 @@ def __post_init__(self):
f"must be one of {SIGNATUR_ALGORITHMS}"
)

if self.key_enc_alg and self.key_enc_alg not in KEY_ENCRYPTION_ALGORITHMS:
raise ImproperlyConfigured(
f"Unsupported Key Encryption Algorithm {self.key_enc_alg}, "
f"must be one of {KEY_ENCRYPTION_ALGORITHMS}"
)

def load_verify_cert(self):
"""Load the verification certificate of the partner and returned the parsed cert."""
if self.validate_certs:
Expand Down
101 changes: 59 additions & 42 deletions pyas2lib/cms.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,12 +65,15 @@ def decompress_message(compressed_data):
raise DecompressionError("Decompression failed with cause: {}".format(e)) from e


def encrypt_message(data_to_encrypt, enc_alg, encryption_cert):
def encrypt_message(
data_to_encrypt, enc_alg, encryption_cert, key_enc_alg="rsaes_pkcs1v15"
):
"""Function encrypts data and returns the generated ASN.1
:param data_to_encrypt: A byte string of the data to be encrypted
:param enc_alg: The algorithm to be used for encrypting the data
:param encryption_cert: The certificate to be used for encrypting the data
:param key_enc_alg: The algo for the key encryption: rsaes_pkcs1v15 (default) or rsaes_oaep
:return: A CMS ASN.1 byte string of the encrypted data.
"""
Expand Down Expand Up @@ -136,7 +139,12 @@ def encrypt_message(data_to_encrypt, enc_alg, encryption_cert):
raise AS2Exception("Unsupported Encryption Algorithm")

# Encrypt the key and build the ASN.1 message
encrypted_key = asymmetric.rsa_pkcs1v15_encrypt(encryption_cert, key)
if key_enc_alg == "rsaes_pkcs1v15":
encrypted_key = asymmetric.rsa_pkcs1v15_encrypt(encryption_cert, key)
elif key_enc_alg == "rsaes_oaep":
encrypted_key = asymmetric.rsa_oaep_encrypt(encryption_cert, key)
else:
raise AS2Exception(f"Unsupported Key Encryption Scheme: {key_enc_alg}")

return cms.ContentInfo(
{
Expand All @@ -163,7 +171,11 @@ def encrypt_message(data_to_encrypt, enc_alg, encryption_cert):
}
),
"key_encryption_algorithm": cms.KeyEncryptionAlgorithm(
{"algorithm": cms.KeyEncryptionAlgorithmId("rsa")}
{
"algorithm": cms.KeyEncryptionAlgorithmId(
key_enc_alg
)
}
),
"encrypted_key": cms.OctetString(encrypted_key),
}
Expand Down Expand Up @@ -199,47 +211,52 @@ def decrypt_message(encrypted_data, decryption_key):
key_enc_alg = recipient_info["key_encryption_algorithm"]["algorithm"].native
encrypted_key = recipient_info["encrypted_key"].native

if cms.KeyEncryptionAlgorithmId(key_enc_alg) == cms.KeyEncryptionAlgorithmId(
"rsa"
):
try:
try:
if cms.KeyEncryptionAlgorithmId(
key_enc_alg
) == cms.KeyEncryptionAlgorithmId("rsaes_pkcs1v15"):
key = asymmetric.rsa_pkcs1v15_decrypt(decryption_key[0], encrypted_key)
except Exception as e:
raise DecryptionError(
"Failed to decrypt the payload: Could not extract decryption key."
) from e

alg = cms_content["content"]["encrypted_content_info"][
"content_encryption_algorithm"
]
encapsulated_data = cms_content["content"]["encrypted_content_info"][
"encrypted_content"
].native

try:
if alg["algorithm"].native == "rc4":
decrypted_content = symmetric.rc4_decrypt(key, encapsulated_data)
elif alg.encryption_cipher == "tripledes":
cipher = "tripledes_192_cbc"
decrypted_content = symmetric.tripledes_cbc_pkcs5_decrypt(
key, encapsulated_data, alg.encryption_iv
)
elif alg.encryption_cipher == "aes":
decrypted_content = symmetric.aes_cbc_pkcs7_decrypt(
key, encapsulated_data, alg.encryption_iv
)
elif alg.encryption_cipher == "rc2":
decrypted_content = symmetric.rc2_cbc_pkcs5_decrypt(
key, encapsulated_data, alg["parameters"]["iv"].native
)
else:
raise AS2Exception("Unsupported Encryption Algorithm")
except Exception as e:
raise DecryptionError(
"Failed to decrypt the payload: {}".format(e)
) from e
else:
raise AS2Exception("Unsupported Encryption Algorithm")
elif cms.KeyEncryptionAlgorithmId(
key_enc_alg
) == cms.KeyEncryptionAlgorithmId("rsaes_oaep"):
key = asymmetric.rsa_oaep_decrypt(decryption_key[0], encrypted_key)
else:
raise AS2Exception(
f"Unsupported Key Encryption Algorithm {key_enc_alg}"
)
except Exception as e:
raise DecryptionError(
"Failed to decrypt the payload: Could not extract decryption key."
) from e

alg = cms_content["content"]["encrypted_content_info"][
"content_encryption_algorithm"
]
encapsulated_data = cms_content["content"]["encrypted_content_info"][
"encrypted_content"
].native

try:
if alg["algorithm"].native == "rc4":
decrypted_content = symmetric.rc4_decrypt(key, encapsulated_data)
elif alg.encryption_cipher == "tripledes":
cipher = "tripledes_192_cbc"
decrypted_content = symmetric.tripledes_cbc_pkcs5_decrypt(
key, encapsulated_data, alg.encryption_iv
)
elif alg.encryption_cipher == "aes":
decrypted_content = symmetric.aes_cbc_pkcs7_decrypt(
key, encapsulated_data, alg.encryption_iv
)
elif alg.encryption_cipher == "rc2":
decrypted_content = symmetric.rc2_cbc_pkcs5_decrypt(
key, encapsulated_data, alg["parameters"]["iv"].native
)
else:
raise AS2Exception("Unsupported Encryption Algorithm")
except Exception as e:
raise DecryptionError("Failed to decrypt the payload: {}".format(e)) from e
else:
raise DecryptionError("Encrypted data not found in ASN.1 ")

Expand Down
4 changes: 4 additions & 0 deletions pyas2lib/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,3 +32,7 @@
"rsassa_pkcs1v15",
"rsassa_pss",
)
KEY_ENCRYPTION_ALGORITHMS = (
"rsaes_pkcs1v15",
"rsaes_oaep",
)
3 changes: 3 additions & 0 deletions pyas2lib/tests/test_advanced.py
Original file line number Diff line number Diff line change
Expand Up @@ -337,6 +337,9 @@ def test_partner_checks(self):
with self.assertRaises(ImproperlyConfigured):
as2.Partner("a partner", sign_alg="xyz")

with self.assertRaises(ImproperlyConfigured):
as2.Partner("a partner", key_enc_alg="xyz")

def test_message_checks(self):
"""Test the checks and other features of Message."""
msg = as2.Message()
Expand Down
97 changes: 94 additions & 3 deletions pyas2lib/tests/test_cms.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,9 @@
import os

import pytest
from oscrypto import asymmetric
from oscrypto import asymmetric, symmetric, util

from asn1crypto import algos, cms as crypto_cms, core

from pyas2lib.as2 import Organization
from pyas2lib import cms
Expand All @@ -22,6 +24,68 @@
).dump()


def _encrypted_data_with_faulty_key_algo():
with open(os.path.join(TEST_DIR, "cert_test_public.pem"), "rb") as fp:
encrypt_cert = asymmetric.load_certificate(fp.read())
enc_alg_list = "rc4_128_cbc".split("_")
cipher, key_length, _ = enc_alg_list[0], enc_alg_list[1], enc_alg_list[2]
key = util.rand_bytes(int(key_length) // 8)
algorithm_id = "1.2.840.113549.3.4"
encrypted_content = symmetric.rc4_encrypt(key, b"data")
enc_alg_asn1 = algos.EncryptionAlgorithm(
{
"algorithm": algorithm_id,
}
)
encrypted_key = asymmetric.rsa_oaep_encrypt(encrypt_cert, key)
return crypto_cms.ContentInfo(
{
"content_type": crypto_cms.ContentType("enveloped_data"),
"content": crypto_cms.EnvelopedData(
{
"version": crypto_cms.CMSVersion("v0"),
"recipient_infos": [
crypto_cms.KeyTransRecipientInfo(
{
"version": crypto_cms.CMSVersion("v0"),
"rid": crypto_cms.RecipientIdentifier(
{
"issuer_and_serial_number": crypto_cms.IssuerAndSerialNumber(
{
"issuer": encrypt_cert.asn1[
"tbs_certificate"
]["issuer"],
"serial_number": encrypt_cert.asn1[
"tbs_certificate"
]["serial_number"],
}
)
}
),
"key_encryption_algorithm": crypto_cms.KeyEncryptionAlgorithm(
{
"algorithm": crypto_cms.KeyEncryptionAlgorithmId(
"aes128_wrap"
)
}
),
"encrypted_key": crypto_cms.OctetString(encrypted_key),
}
)
],
"encrypted_content_info": crypto_cms.EncryptedContentInfo(
{
"content_type": crypto_cms.ContentType("data"),
"content_encryption_algorithm": enc_alg_asn1,
"encrypted_content": encrypted_content,
}
),
}
),
}
).dump()


def test_compress():
"""Test the compression and decompression functions."""
compressed_data = cms.compress_message(b"data")
Expand Down Expand Up @@ -87,9 +151,22 @@ def test_encryption():
"aes_128_cbc",
"aes_192_cbc",
"aes_256_cbc",
"tripledes_192_cbc",
]

key_enc_algos = [
"rsaes_oaep",
"rsaes_pkcs1v15",
]

encryption_algos = [
(alg, key_algo) for alg in enc_algorithms for key_algo in key_enc_algos
]
for enc_algorithm in enc_algorithms:
encrypted_data = cms.encrypt_message(b"data", enc_algorithm, encrypt_cert)

for enc_algorithm, encryption_scheme in encryption_algos:
encrypted_data = cms.encrypt_message(
b"data", enc_algorithm, encrypt_cert, encryption_scheme
)
_, decrypted_data = cms.decrypt_message(encrypted_data, decrypt_key)
assert decrypted_data == b"data"

Expand All @@ -101,3 +178,17 @@ def test_encryption():
encrypted_data = cms.encrypt_message(b"data", "des_64_cbc", encrypt_cert)
with pytest.raises(AS2Exception):
cms.decrypt_message(encrypted_data, decrypt_key)

# Test faulty key encryption algorithm
with pytest.raises(
AS2Exception, match="Unsupported Key Encryption Scheme: des_64_cbc"
):
cms.encrypt_message(b"data", "rc2_128_cbc", encrypt_cert, "des_64_cbc")

# Test unsupported key encryption algorithm
encrypted_data = _encrypted_data_with_faulty_key_algo()
with pytest.raises(
AS2Exception,
match="Failed to decrypt the payload: Could not extract decryption key.",
):
cms.decrypt_message(encrypted_data, decrypt_key)

0 comments on commit 9b400b1

Please sign in to comment.