Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feature/rsaes oaep #68

Merged
merged 10 commits into from
May 2, 2024
11 changes: 11 additions & 0 deletions pyas2lib/as2.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
DIGEST_ALGORITHMS,
EDIINT_FEATURES,
ENCRYPTION_ALGORITHMS,
KEY_ENCRYPTION_ALGORITHMS,
MDN_CONFIRM_TEXT,
MDN_FAILED_TEXT,
MDN_MODES,
Expand Down Expand Up @@ -183,6 +184,9 @@ class Partner:
:param sign_alg: The signing algorithm to be used for generating the
signature. (default `rsassa_pkcs1v15`)

:param key_enc_alg: The key encryption algorithm to be used.
(default `rsaes_pkcs1v15`)

"""

as2_name: str
Expand All @@ -202,6 +206,7 @@ class Partner:
ignore_self_signed: bool = True
canonicalize_as_binary: bool = False
sign_alg: str = "rsassa_pkcs1v15"
key_enc_alg: str = "rsaes_pkcs1v15"

def __post_init__(self):
"""Run the post initialisation checks for this class."""
Expand Down Expand Up @@ -236,6 +241,12 @@ def __post_init__(self):
f"must be one of {SIGNATUR_ALGORITHMS}"
)

if self.key_enc_alg and self.key_enc_alg not in KEY_ENCRYPTION_ALGORITHMS:
raise ImproperlyConfigured(
f"Unsupported Key Encryption Algorithm {self.key_enc_alg}, "
f"must be one of {KEY_ENCRYPTION_ALGORITHMS}"
)

def load_verify_cert(self):
"""Load the verification certificate of the partner and returned the parsed cert."""
if self.validate_certs:
Expand Down
101 changes: 59 additions & 42 deletions pyas2lib/cms.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,12 +65,15 @@ def decompress_message(compressed_data):
raise DecompressionError("Decompression failed with cause: {}".format(e)) from e


def encrypt_message(data_to_encrypt, enc_alg, encryption_cert):
def encrypt_message(
data_to_encrypt, enc_alg, encryption_cert, key_enc_alg="rsaes_pkcs1v15"
):
"""Function encrypts data and returns the generated ASN.1

:param data_to_encrypt: A byte string of the data to be encrypted
:param enc_alg: The algorithm to be used for encrypting the data
:param encryption_cert: The certificate to be used for encrypting the data
:param key_enc_alg: The algo for the key encryption: rsaes_pkcs1v15 (default) or rsaes_oaep

:return: A CMS ASN.1 byte string of the encrypted data.
"""
Expand Down Expand Up @@ -136,7 +139,12 @@ def encrypt_message(data_to_encrypt, enc_alg, encryption_cert):
raise AS2Exception("Unsupported Encryption Algorithm")

# Encrypt the key and build the ASN.1 message
encrypted_key = asymmetric.rsa_pkcs1v15_encrypt(encryption_cert, key)
if key_enc_alg == "rsaes_pkcs1v15":
encrypted_key = asymmetric.rsa_pkcs1v15_encrypt(encryption_cert, key)
elif key_enc_alg == "rsaes_oaep":
encrypted_key = asymmetric.rsa_oaep_encrypt(encryption_cert, key)
else:
raise AS2Exception(f"Unsupported Key Encryption Scheme: {key_enc_alg}")

return cms.ContentInfo(
{
Expand All @@ -163,7 +171,11 @@ def encrypt_message(data_to_encrypt, enc_alg, encryption_cert):
}
),
"key_encryption_algorithm": cms.KeyEncryptionAlgorithm(
{"algorithm": cms.KeyEncryptionAlgorithmId("rsa")}
{
"algorithm": cms.KeyEncryptionAlgorithmId(
abhishek-ram marked this conversation as resolved.
Show resolved Hide resolved
key_enc_alg
)
}
),
"encrypted_key": cms.OctetString(encrypted_key),
}
Expand Down Expand Up @@ -199,47 +211,52 @@ def decrypt_message(encrypted_data, decryption_key):
key_enc_alg = recipient_info["key_encryption_algorithm"]["algorithm"].native
encrypted_key = recipient_info["encrypted_key"].native

if cms.KeyEncryptionAlgorithmId(key_enc_alg) == cms.KeyEncryptionAlgorithmId(
"rsa"
):
try:
try:
if cms.KeyEncryptionAlgorithmId(
key_enc_alg
) == cms.KeyEncryptionAlgorithmId("rsaes_pkcs1v15"):
key = asymmetric.rsa_pkcs1v15_decrypt(decryption_key[0], encrypted_key)
except Exception as e:
raise DecryptionError(
"Failed to decrypt the payload: Could not extract decryption key."
) from e

alg = cms_content["content"]["encrypted_content_info"][
"content_encryption_algorithm"
]
encapsulated_data = cms_content["content"]["encrypted_content_info"][
"encrypted_content"
].native

try:
if alg["algorithm"].native == "rc4":
decrypted_content = symmetric.rc4_decrypt(key, encapsulated_data)
elif alg.encryption_cipher == "tripledes":
cipher = "tripledes_192_cbc"
decrypted_content = symmetric.tripledes_cbc_pkcs5_decrypt(
key, encapsulated_data, alg.encryption_iv
)
elif alg.encryption_cipher == "aes":
decrypted_content = symmetric.aes_cbc_pkcs7_decrypt(
key, encapsulated_data, alg.encryption_iv
)
elif alg.encryption_cipher == "rc2":
decrypted_content = symmetric.rc2_cbc_pkcs5_decrypt(
key, encapsulated_data, alg["parameters"]["iv"].native
)
else:
raise AS2Exception("Unsupported Encryption Algorithm")
except Exception as e:
raise DecryptionError(
"Failed to decrypt the payload: {}".format(e)
) from e
else:
raise AS2Exception("Unsupported Encryption Algorithm")
elif cms.KeyEncryptionAlgorithmId(
key_enc_alg
) == cms.KeyEncryptionAlgorithmId("rsaes_oaep"):
key = asymmetric.rsa_oaep_decrypt(decryption_key[0], encrypted_key)
else:
raise AS2Exception(
f"Unsupported Key Encryption Algorithm {key_enc_alg}"
)
except Exception as e:
raise DecryptionError(
"Failed to decrypt the payload: Could not extract decryption key."
) from e

alg = cms_content["content"]["encrypted_content_info"][
"content_encryption_algorithm"
]
encapsulated_data = cms_content["content"]["encrypted_content_info"][
"encrypted_content"
].native

try:
if alg["algorithm"].native == "rc4":
decrypted_content = symmetric.rc4_decrypt(key, encapsulated_data)
elif alg.encryption_cipher == "tripledes":
cipher = "tripledes_192_cbc"
decrypted_content = symmetric.tripledes_cbc_pkcs5_decrypt(
key, encapsulated_data, alg.encryption_iv
)
elif alg.encryption_cipher == "aes":
decrypted_content = symmetric.aes_cbc_pkcs7_decrypt(
key, encapsulated_data, alg.encryption_iv
)
elif alg.encryption_cipher == "rc2":
decrypted_content = symmetric.rc2_cbc_pkcs5_decrypt(
key, encapsulated_data, alg["parameters"]["iv"].native
)
else:
raise AS2Exception("Unsupported Encryption Algorithm")
except Exception as e:
raise DecryptionError("Failed to decrypt the payload: {}".format(e)) from e
else:
raise DecryptionError("Encrypted data not found in ASN.1 ")

Expand Down
4 changes: 4 additions & 0 deletions pyas2lib/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,3 +32,7 @@
"rsassa_pkcs1v15",
"rsassa_pss",
)
KEY_ENCRYPTION_ALGORITHMS = (
"rsaes_pkcs1v15",
"rsaes_oaep",
)
3 changes: 3 additions & 0 deletions pyas2lib/tests/test_advanced.py
Original file line number Diff line number Diff line change
Expand Up @@ -337,6 +337,9 @@ def test_partner_checks(self):
with self.assertRaises(ImproperlyConfigured):
as2.Partner("a partner", sign_alg="xyz")

with self.assertRaises(ImproperlyConfigured):
as2.Partner("a partner", key_enc_alg="xyz")

def test_message_checks(self):
"""Test the checks and other features of Message."""
msg = as2.Message()
Expand Down
92 changes: 89 additions & 3 deletions pyas2lib/tests/test_cms.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,9 @@
import os

import pytest
from oscrypto import asymmetric
from oscrypto import asymmetric, symmetric, util

from asn1crypto import algos, cms as crypto_cms, core

from pyas2lib.as2 import Organization
from pyas2lib import cms
Expand All @@ -22,6 +24,68 @@
).dump()


def encrypted_data_with_faulty_key_algo():
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

_encrypted_data_with_faulty_key_algo

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done

with open(os.path.join(TEST_DIR, "cert_test_public.pem"), "rb") as fp:
encrypt_cert = asymmetric.load_certificate(fp.read())
enc_alg_list = "rc4_128_cbc".split("_")
cipher, key_length, _ = enc_alg_list[0], enc_alg_list[1], enc_alg_list[2]
key = util.rand_bytes(int(key_length) // 8)
algorithm_id = "1.2.840.113549.3.4"
encrypted_content = symmetric.rc4_encrypt(key, b"data")
enc_alg_asn1 = algos.EncryptionAlgorithm(
{
"algorithm": algorithm_id,
}
)
encrypted_key = asymmetric.rsa_oaep_encrypt(encrypt_cert, key)
return crypto_cms.ContentInfo(
{
"content_type": crypto_cms.ContentType("enveloped_data"),
"content": crypto_cms.EnvelopedData(
{
"version": crypto_cms.CMSVersion("v0"),
"recipient_infos": [
crypto_cms.KeyTransRecipientInfo(
{
"version": crypto_cms.CMSVersion("v0"),
"rid": crypto_cms.RecipientIdentifier(
{
"issuer_and_serial_number": crypto_cms.IssuerAndSerialNumber(
{
"issuer": encrypt_cert.asn1[
"tbs_certificate"
]["issuer"],
"serial_number": encrypt_cert.asn1[
"tbs_certificate"
]["serial_number"],
}
)
}
),
"key_encryption_algorithm": crypto_cms.KeyEncryptionAlgorithm(
{
"algorithm": crypto_cms.KeyEncryptionAlgorithmId(
"aes128_wrap"
)
}
),
"encrypted_key": crypto_cms.OctetString(encrypted_key),
}
)
],
"encrypted_content_info": crypto_cms.EncryptedContentInfo(
{
"content_type": crypto_cms.ContentType("data"),
"content_encryption_algorithm": enc_alg_asn1,
"encrypted_content": encrypted_content,
}
),
}
),
}
).dump()


def test_compress():
"""Test the compression and decompression functions."""
compressed_data = cms.compress_message(b"data")
Expand Down Expand Up @@ -87,9 +151,22 @@ def test_encryption():
"aes_128_cbc",
"aes_192_cbc",
"aes_256_cbc",
"tripledes_192_cbc",
]

key_enc_algos = [
"rsaes_oaep",
"rsaes_pkcs1v15",
]
for enc_algorithm in enc_algorithms:
encrypted_data = cms.encrypt_message(b"data", enc_algorithm, encrypt_cert)

encryption_algos = [
(alg, key_algo) for alg in enc_algorithms for key_algo in key_enc_algos
]

for enc_algorithm, encryption_scheme in encryption_algos:
encrypted_data = cms.encrypt_message(
b"data", enc_algorithm, encrypt_cert, encryption_scheme
)
_, decrypted_data = cms.decrypt_message(encrypted_data, decrypt_key)
assert decrypted_data == b"data"

Expand All @@ -101,3 +178,12 @@ def test_encryption():
encrypted_data = cms.encrypt_message(b"data", "des_64_cbc", encrypt_cert)
with pytest.raises(AS2Exception):
cms.decrypt_message(encrypted_data, decrypt_key)

# Test faulty key encryption algorithm
with pytest.raises(AS2Exception):
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

assert the message as well

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done

cms.encrypt_message(b"data", "rc2_128_cbc", encrypt_cert, "des_64_cbc")

# Test unsupported key encryption algorithm
encrypted_data = encrypted_data_with_faulty_key_algo()
with pytest.raises(AS2Exception):
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

assert the message as well

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done

cms.decrypt_message(encrypted_data, decrypt_key)
Loading