Skip to content

Commit

Permalink
Adding key encryption algo option to support rsaes_oaep.
Browse files Browse the repository at this point in the history
  • Loading branch information
chadgates committed Mar 18, 2024
1 parent 2e4c253 commit 0b39518
Show file tree
Hide file tree
Showing 2 changed files with 68 additions and 33 deletions.
81 changes: 50 additions & 31 deletions pyas2lib/cms.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,12 +65,15 @@ def decompress_message(compressed_data):
raise DecompressionError("Decompression failed with cause: {}".format(e)) from e


def encrypt_message(data_to_encrypt, enc_alg, encryption_cert):
def encrypt_message(
data_to_encrypt, enc_alg, encryption_cert, key_enc_alg="rsaes_pkcs1v15"
):
"""Function encrypts data and returns the generated ASN.1
:param data_to_encrypt: A byte string of the data to be encrypted
:param enc_alg: The algorithm to be used for encrypting the data
:param encryption_cert: The certificate to be used for encrypting the data
:param key_enc_alg: The encryption scheme for the key encryption: rsaes_pkcs1v15 (default) or rsaes_oaep
:return: A CMS ASN.1 byte string of the encrypted data.
"""
Expand Down Expand Up @@ -136,7 +139,12 @@ def encrypt_message(data_to_encrypt, enc_alg, encryption_cert):
raise AS2Exception("Unsupported Encryption Algorithm")

# Encrypt the key and build the ASN.1 message
encrypted_key = asymmetric.rsa_pkcs1v15_encrypt(encryption_cert, key)
if key_enc_alg == "rsaes_pkcs1v15":
encrypted_key = asymmetric.rsa_pkcs1v15_encrypt(encryption_cert, key)
elif key_enc_alg == "rsaes_oaep":
encrypted_key = asymmetric.rsa_oaep_encrypt(encryption_cert, key)
else:
raise AS2Exception(f"Unsupported Key Encryption Scheme: {key_enc_alg}")

return cms.ContentInfo(
{
Expand All @@ -163,7 +171,11 @@ def encrypt_message(data_to_encrypt, enc_alg, encryption_cert):
}
),
"key_encryption_algorithm": cms.KeyEncryptionAlgorithm(
{"algorithm": cms.KeyEncryptionAlgorithmId("rsa")}
{
"algorithm": cms.KeyEncryptionAlgorithmId(
key_enc_alg
)
}
),
"encrypted_key": cms.OctetString(encrypted_key),
}
Expand Down Expand Up @@ -200,46 +212,53 @@ def decrypt_message(encrypted_data, decryption_key):
encrypted_key = recipient_info["encrypted_key"].native

if cms.KeyEncryptionAlgorithmId(key_enc_alg) == cms.KeyEncryptionAlgorithmId(
"rsa"
"rsaes_pkcs1v15"
):
try:
key = asymmetric.rsa_pkcs1v15_decrypt(decryption_key[0], encrypted_key)
except Exception as e:
raise DecryptionError(
"Failed to decrypt the payload: Could not extract decryption key."
) from e

alg = cms_content["content"]["encrypted_content_info"][
"content_encryption_algorithm"
]
encapsulated_data = cms_content["content"]["encrypted_content_info"][
"encrypted_content"
].native

elif cms.KeyEncryptionAlgorithmId(key_enc_alg) == cms.KeyEncryptionAlgorithmId(
"rsaes_oaep"
):
try:
if alg["algorithm"].native == "rc4":
decrypted_content = symmetric.rc4_decrypt(key, encapsulated_data)
elif alg.encryption_cipher == "tripledes":
cipher = "tripledes_192_cbc"
decrypted_content = symmetric.tripledes_cbc_pkcs5_decrypt(
key, encapsulated_data, alg.encryption_iv
)
elif alg.encryption_cipher == "aes":
decrypted_content = symmetric.aes_cbc_pkcs7_decrypt(
key, encapsulated_data, alg.encryption_iv
)
elif alg.encryption_cipher == "rc2":
decrypted_content = symmetric.rc2_cbc_pkcs5_decrypt(
key, encapsulated_data, alg["parameters"]["iv"].native
)
else:
raise AS2Exception("Unsupported Encryption Algorithm")
key = asymmetric.rsa_oaep_decrypt(decryption_key[0], encrypted_key)
except Exception as e:
raise DecryptionError(
"Failed to decrypt the payload: {}".format(e)
"Failed to decrypt the payload: Could not extract decryption key."
) from e
else:
raise AS2Exception("Unsupported Encryption Algorithm")
raise AS2Exception(f"Unsupported Key Encryption Algorithm {key_enc_alg}")

alg = cms_content["content"]["encrypted_content_info"][
"content_encryption_algorithm"
]
encapsulated_data = cms_content["content"]["encrypted_content_info"][
"encrypted_content"
].native

try:
if alg["algorithm"].native == "rc4":
decrypted_content = symmetric.rc4_decrypt(key, encapsulated_data)
elif alg.encryption_cipher == "tripledes":
cipher = "tripledes_192_cbc"
decrypted_content = symmetric.tripledes_cbc_pkcs5_decrypt(
key, encapsulated_data, alg.encryption_iv
)
elif alg.encryption_cipher == "aes":
decrypted_content = symmetric.aes_cbc_pkcs7_decrypt(
key, encapsulated_data, alg.encryption_iv
)
elif alg.encryption_cipher == "rc2":
decrypted_content = symmetric.rc2_cbc_pkcs5_decrypt(
key, encapsulated_data, alg["parameters"]["iv"].native
)
else:
raise AS2Exception("Unsupported Encryption Algorithm")
except Exception as e:
raise DecryptionError("Failed to decrypt the payload: {}".format(e)) from e
else:
raise DecryptionError("Encrypted data not found in ASN.1 ")

Expand Down
20 changes: 18 additions & 2 deletions pyas2lib/tests/test_cms.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,8 +88,20 @@ def test_encryption():
"aes_192_cbc",
"aes_256_cbc",
]
for enc_algorithm in enc_algorithms:
encrypted_data = cms.encrypt_message(b"data", enc_algorithm, encrypt_cert)

key_enc_algos = [
"rsaes_oaep",
"rsaes_pkcs1v15",
]

encryption_algos = [
(alg, scheme) for alg, scheme in zip(enc_algorithms, key_enc_algos)
]

for enc_algorithm, encryption_scheme in encryption_algos:
encrypted_data = cms.encrypt_message(
b"data", enc_algorithm, encrypt_cert, encryption_scheme
)
_, decrypted_data = cms.decrypt_message(encrypted_data, decrypt_key)
assert decrypted_data == b"data"

Expand All @@ -101,3 +113,7 @@ def test_encryption():
encrypted_data = cms.encrypt_message(b"data", "des_64_cbc", encrypt_cert)
with pytest.raises(AS2Exception):
cms.decrypt_message(encrypted_data, decrypt_key)

# Test faulty key encryption algorithm
with pytest.raises(AS2Exception):
cms.encrypt_message(b"data", "rc2_128_cbc", encrypt_cert, "des_64_cbc")

0 comments on commit 0b39518

Please sign in to comment.