From 002f76ab754e39a551e362f3a043f94e72b050d1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Pascal=20B=C3=BChler?= Date: Tue, 23 Jan 2024 13:33:42 +0100 Subject: [PATCH] change cipher api to support not-inplace io --- crypto/cipher/aes_gcm_mbedtls.c | 32 ++++++++++----- crypto/cipher/aes_gcm_nss.c | 61 +++++++++++++++++++--------- crypto/cipher/aes_gcm_ossl.c | 26 +++++++----- crypto/cipher/aes_icm.c | 72 +++++++++++++++++++-------------- crypto/cipher/aes_icm_mbedtls.c | 16 ++++++-- crypto/cipher/aes_icm_nss.c | 10 +++-- crypto/cipher/aes_icm_ossl.c | 14 ++++--- crypto/cipher/cipher.c | 66 ++++++++++++++++++------------ crypto/cipher/null_cipher.c | 16 +++++--- crypto/include/cipher.h | 30 ++++++++------ crypto/test/cipher_driver.c | 6 +-- include/srtp.h | 3 +- srtp/srtp.c | 25 ++++++------ 13 files changed, 237 insertions(+), 140 deletions(-) diff --git a/crypto/cipher/aes_gcm_mbedtls.c b/crypto/cipher/aes_gcm_mbedtls.c index 2965d202e..5b190963b 100644 --- a/crypto/cipher/aes_gcm_mbedtls.c +++ b/crypto/cipher/aes_gcm_mbedtls.c @@ -281,8 +281,10 @@ static srtp_err_status_t srtp_aes_gcm_mbedtls_set_aad(void *cv, * enc_len length of encrypt buffer */ static srtp_err_status_t srtp_aes_gcm_mbedtls_encrypt(void *cv, - uint8_t *buf, - size_t *enc_len) + const uint8_t *src, + size_t src_len, + uint8_t *dst, + size_t *dst_len) { FUNC_ENTRY(); srtp_aes_gcm_ctx_t *c = (srtp_aes_gcm_ctx_t *)cv; @@ -292,9 +294,13 @@ static srtp_err_status_t srtp_aes_gcm_mbedtls_encrypt(void *cv, return (srtp_err_status_bad_param); } - errCode = mbedtls_gcm_crypt_and_tag(c->ctx, MBEDTLS_GCM_ENCRYPT, *enc_len, + if (*dst_len < src_len) { + return srtp_err_status_buffer_small; + } + + errCode = mbedtls_gcm_crypt_and_tag(c->ctx, MBEDTLS_GCM_ENCRYPT, src_len, c->iv, c->iv_len, c->aad, c->aad_size, - buf, buf, c->tag_len, c->tag); + src, dst, c->tag_len, c->tag); c->aad_size = 0; if (errCode != 0) { @@ -302,6 +308,8 @@ static srtp_err_status_t srtp_aes_gcm_mbedtls_encrypt(void *cv, return srtp_err_status_bad_param; } + *dst_len = src_len; + return (srtp_err_status_ok); } @@ -337,8 +345,10 @@ static srtp_err_status_t srtp_aes_gcm_mbedtls_get_tag(void *cv, * enc_len length of encrypt buffer */ static srtp_err_status_t srtp_aes_gcm_mbedtls_decrypt(void *cv, - uint8_t *buf, - size_t *enc_len) + const uint8_t *src, + size_t src_len, + uint8_t *dst, + size_t *dst_len) { FUNC_ENTRY(); srtp_aes_gcm_ctx_t *c = (srtp_aes_gcm_ctx_t *)cv; @@ -348,12 +358,16 @@ static srtp_err_status_t srtp_aes_gcm_mbedtls_decrypt(void *cv, return (srtp_err_status_bad_param); } + if (*dst_len < (src_len - c->tag_len)) { + return srtp_err_status_buffer_small; + } + debug_print(srtp_mod_aes_gcm, "AAD: %s", srtp_octet_string_hex_string(c->aad, c->aad_size)); errCode = mbedtls_gcm_auth_decrypt( - c->ctx, (*enc_len - c->tag_len), c->iv, c->iv_len, c->aad, c->aad_size, - buf + (*enc_len - c->tag_len), c->tag_len, buf, buf); + c->ctx, (src_len - c->tag_len), c->iv, c->iv_len, c->aad, c->aad_size, + src + (src_len - c->tag_len), c->tag_len, src, dst); c->aad_size = 0; if (errCode != 0) { return (srtp_err_status_auth_fail); @@ -363,7 +377,7 @@ static srtp_err_status_t srtp_aes_gcm_mbedtls_decrypt(void *cv, * Reduce the buffer size by the tag length since the tag * is not part of the original payload */ - *enc_len -= c->tag_len; + *dst_len = (src_len - c->tag_len); return (srtp_err_status_ok); } diff --git a/crypto/cipher/aes_gcm_nss.c b/crypto/cipher/aes_gcm_nss.c index 7f514f54d..4023853da 100644 --- a/crypto/cipher/aes_gcm_nss.c +++ b/crypto/cipher/aes_gcm_nss.c @@ -281,8 +281,10 @@ static srtp_err_status_t srtp_aes_gcm_nss_set_aad(void *cv, static srtp_err_status_t srtp_aes_gcm_nss_do_crypto(void *cv, bool encrypt, - uint8_t *buf, - size_t *enc_len) + const uint8_t *src, + size_t src_len, + uint8_t *dst, + size_t *dst_len) { srtp_aes_gcm_ctx_t *c = (srtp_aes_gcm_ctx_t *)cv; @@ -299,13 +301,13 @@ static srtp_err_status_t srtp_aes_gcm_nss_do_crypto(void *cv, SECItem param = { siBuffer, (unsigned char *)&c->params, sizeof(CK_GCM_PARAMS) }; if (encrypt) { - rv = PK11_Encrypt(c->key, CKM_AES_GCM, ¶m, buf, &out_len, - *enc_len + 16, buf, *enc_len); + rv = PK11_Encrypt(c->key, CKM_AES_GCM, ¶m, dst, &out_len, *dst_len, + src, src_len); } else { - rv = PK11_Decrypt(c->key, CKM_AES_GCM, ¶m, buf, &out_len, *enc_len, - buf, *enc_len); + rv = PK11_Decrypt(c->key, CKM_AES_GCM, ¶m, dst, &out_len, *dst_len, + src, src_len); } - *enc_len = out_len; + *dst_len = out_len; srtp_err_status_t status = (srtp_err_status_ok); if (rv != SECSuccess) { status = (srtp_err_status_cipher_fail); @@ -328,32 +330,42 @@ static srtp_err_status_t srtp_aes_gcm_nss_do_crypto(void *cv, * enc_len length of encrypt buffer */ static srtp_err_status_t srtp_aes_gcm_nss_encrypt(void *cv, - uint8_t *buf, - size_t *enc_len) + const uint8_t *src, + size_t src_len, + uint8_t *dst, + size_t *dst_len) { srtp_aes_gcm_ctx_t *c = (srtp_aes_gcm_ctx_t *)cv; + //#todo, this might need som looking at + // nss requires space for tag, currently we assume that ther is space, this + // should change, the best would be to merge the cipher encrypt and get_tag api + *dst_len += 16; + // When we get a non-NULL buffer, we know that the caller is // prepared to also take the tag. When we get a NULL buffer, // even though there's no data, we need to give NSS a buffer // where it can write the tag. We can't just use c->tag because // memcpy has undefined behavior on overlapping ranges. uint8_t tagbuf[16]; - uint8_t *non_null_buf = buf; - if (!non_null_buf && (*enc_len == 0)) { + const uint8_t *non_null_buf = src; + uint8_t *non_null_dst_buf = dst; + if (!non_null_buf && (src_len == 0)) { non_null_buf = tagbuf; + non_null_dst_buf = tagbuf; + *dst_len = sizeof(tagbuf); } else if (!non_null_buf) { return srtp_err_status_bad_param; } - srtp_err_status_t status = - srtp_aes_gcm_nss_do_crypto(cv, true, non_null_buf, enc_len); + srtp_err_status_t status = srtp_aes_gcm_nss_do_crypto( + cv, true, non_null_buf, src_len, non_null_dst_buf, dst_len); if (status != srtp_err_status_ok) { return status; } - memcpy(c->tag, non_null_buf + (*enc_len - c->tag_size), c->tag_size); - *enc_len -= c->tag_size; + memcpy(c->tag, non_null_dst_buf + (*dst_len - c->tag_size), c->tag_size); + *dst_len -= c->tag_size; return srtp_err_status_ok; } @@ -387,11 +399,22 @@ static srtp_err_status_t srtp_aes_gcm_nss_get_tag(void *cv, * enc_len length of encrypt buffer */ static srtp_err_status_t srtp_aes_gcm_nss_decrypt(void *cv, - uint8_t *buf, - size_t *enc_len) + const uint8_t *src, + size_t src_len, + uint8_t *dst, + size_t *dst_len) { - srtp_err_status_t status = - srtp_aes_gcm_nss_do_crypto(cv, false, buf, enc_len); + uint8_t tagbuf[16]; + uint8_t *non_null_dst_buf = dst; + if (!non_null_dst_buf && (*dst_len == 0)) { + non_null_dst_buf = tagbuf; + *dst_len = sizeof(tagbuf); + } else if (!non_null_dst_buf) { + return srtp_err_status_bad_param; + } + + srtp_err_status_t status = srtp_aes_gcm_nss_do_crypto( + cv, false, src, src_len, non_null_dst_buf, dst_len); if (status != srtp_err_status_ok) { int err = PR_GetError(); if (err == SEC_ERROR_BAD_DATA) { diff --git a/crypto/cipher/aes_gcm_ossl.c b/crypto/cipher/aes_gcm_ossl.c index 6a56450db..692ab8957 100644 --- a/crypto/cipher/aes_gcm_ossl.c +++ b/crypto/cipher/aes_gcm_ossl.c @@ -293,8 +293,10 @@ static srtp_err_status_t srtp_aes_gcm_openssl_set_aad(void *cv, * enc_len length of encrypt buffer */ static srtp_err_status_t srtp_aes_gcm_openssl_encrypt(void *cv, - uint8_t *buf, - size_t *enc_len) + const uint8_t *src, + size_t src_len, + uint8_t *dst, + size_t *dst_len) { srtp_aes_gcm_ctx_t *c = (srtp_aes_gcm_ctx_t *)cv; if (c->dir != srtp_direction_encrypt && c->dir != srtp_direction_decrypt) { @@ -304,7 +306,8 @@ static srtp_err_status_t srtp_aes_gcm_openssl_encrypt(void *cv, /* * Encrypt the data */ - EVP_Cipher(c->ctx, buf, buf, *enc_len); + EVP_Cipher(c->ctx, dst, src, src_len); + *dst_len = src_len; return (srtp_err_status_ok); } @@ -354,8 +357,10 @@ static srtp_err_status_t srtp_aes_gcm_openssl_get_tag(void *cv, * enc_len length of encrypt buffer */ static srtp_err_status_t srtp_aes_gcm_openssl_decrypt(void *cv, - uint8_t *buf, - size_t *enc_len) + const uint8_t *src, + size_t src_len, + uint8_t *dst, + size_t *dst_len) { srtp_aes_gcm_ctx_t *c = (srtp_aes_gcm_ctx_t *)cv; if (c->dir != srtp_direction_encrypt && c->dir != srtp_direction_decrypt) { @@ -364,12 +369,15 @@ static srtp_err_status_t srtp_aes_gcm_openssl_decrypt(void *cv, /* * Set the tag before decrypting + * + * explicitly cast away const of src */ - if (!EVP_CIPHER_CTX_ctrl(c->ctx, EVP_CTRL_GCM_SET_TAG, c->tag_len, - buf + (*enc_len - c->tag_len))) { + if (!EVP_CIPHER_CTX_ctrl( + c->ctx, EVP_CTRL_GCM_SET_TAG, c->tag_len, + (void *)(uintptr_t)(src + (src_len - c->tag_len)))) { return (srtp_err_status_auth_fail); } - EVP_Cipher(c->ctx, buf, buf, *enc_len - c->tag_len); + EVP_Cipher(c->ctx, dst, src, src_len - c->tag_len); /* * Check the tag @@ -382,7 +390,7 @@ static srtp_err_status_t srtp_aes_gcm_openssl_decrypt(void *cv, * Reduce the buffer size by the tag length since the tag * is not part of the original payload */ - *enc_len -= c->tag_len; + *dst_len = src_len -= c->tag_len; return (srtp_err_status_ok); } diff --git a/crypto/cipher/aes_icm.c b/crypto/cipher/aes_icm.c index 744df6aae..a612edc2d 100644 --- a/crypto/cipher/aes_icm.c +++ b/crypto/cipher/aes_icm.c @@ -295,12 +295,20 @@ static void srtp_aes_icm_advance(srtp_aes_icm_ctx_t *c) */ static srtp_err_status_t srtp_aes_icm_encrypt(void *cv, - uint8_t *buf, - size_t *enc_len) + const uint8_t *src, + size_t src_len, + uint8_t *dst, + size_t *dst_len) { srtp_aes_icm_ctx_t *c = (srtp_aes_icm_ctx_t *)cv; - size_t bytes_to_encr = *enc_len; + size_t bytes_to_encr = src_len; uint32_t *b; + const uint32_t *s; + + // check out length if not equal or greater bail! + *dst_len = src_len; + + unsigned char *buf = dst; /* check that there's enough segment left*/ size_t bytes_of_new_keystream = bytes_to_encr - c->bytes_in_buffer; @@ -314,7 +322,7 @@ static srtp_err_status_t srtp_aes_icm_encrypt(void *cv, /* deal with odd case of small bytes_to_encr */ for (size_t i = (sizeof(v128_t) - c->bytes_in_buffer); i < (sizeof(v128_t) - c->bytes_in_buffer + bytes_to_encr); i++) { - *buf++ ^= c->keystream_buffer.v8[i]; + *buf++ = *src++ ^ c->keystream_buffer.v8[i]; } c->bytes_in_buffer -= bytes_to_encr; @@ -326,7 +334,7 @@ static srtp_err_status_t srtp_aes_icm_encrypt(void *cv, /* encrypt bytes until the remaining data is 16-byte aligned */ for (size_t i = (sizeof(v128_t) - c->bytes_in_buffer); i < sizeof(v128_t); i++) { - *buf++ ^= c->keystream_buffer.v8[i]; + *buf++ = *src++ ^ c->keystream_buffer.v8[i]; } bytes_to_encr -= c->bytes_in_buffer; @@ -345,36 +353,40 @@ static srtp_err_status_t srtp_aes_icm_encrypt(void *cv, #if ALIGN_32 b = (uint32_t *)buf; - *b++ ^= c->keystream_buffer.v32[0]; - *b++ ^= c->keystream_buffer.v32[1]; - *b++ ^= c->keystream_buffer.v32[2]; - *b++ ^= c->keystream_buffer.v32[3]; + s = (const uint32_t *)src; + *b++ = *s++ ^ c->keystream_buffer.v32[0]; + *b++ = *s++ ^ c->keystream_buffer.v32[1]; + *b++ = *s++ ^ c->keystream_buffer.v32[2]; + *b++ = *s++ ^ c->keystream_buffer.v32[3]; buf = (uint8_t *)b; + src = (const uint8_t *)s; #else if ((((uintptr_t)buf) & 0x03) != 0) { - *buf++ ^= c->keystream_buffer.v8[0]; - *buf++ ^= c->keystream_buffer.v8[1]; - *buf++ ^= c->keystream_buffer.v8[2]; - *buf++ ^= c->keystream_buffer.v8[3]; - *buf++ ^= c->keystream_buffer.v8[4]; - *buf++ ^= c->keystream_buffer.v8[5]; - *buf++ ^= c->keystream_buffer.v8[6]; - *buf++ ^= c->keystream_buffer.v8[7]; - *buf++ ^= c->keystream_buffer.v8[8]; - *buf++ ^= c->keystream_buffer.v8[9]; - *buf++ ^= c->keystream_buffer.v8[10]; - *buf++ ^= c->keystream_buffer.v8[11]; - *buf++ ^= c->keystream_buffer.v8[12]; - *buf++ ^= c->keystream_buffer.v8[13]; - *buf++ ^= c->keystream_buffer.v8[14]; - *buf++ ^= c->keystream_buffer.v8[15]; + *buf++ = *src++ ^ c->keystream_buffer.v8[0]; + *buf++ = *src++ ^ c->keystream_buffer.v8[1]; + *buf++ = *src++ ^ c->keystream_buffer.v8[2]; + *buf++ = *src++ ^ c->keystream_buffer.v8[3]; + *buf++ = *src++ ^ c->keystream_buffer.v8[4]; + *buf++ = *src++ ^ c->keystream_buffer.v8[5]; + *buf++ = *src++ ^ c->keystream_buffer.v8[6]; + *buf++ = *src++ ^ c->keystream_buffer.v8[7]; + *buf++ = *src++ ^ c->keystream_buffer.v8[8]; + *buf++ = *src++ ^ c->keystream_buffer.v8[9]; + *buf++ = *src++ ^ c->keystream_buffer.v8[10]; + *buf++ = *src++ ^ c->keystream_buffer.v8[11]; + *buf++ = *src++ ^ c->keystream_buffer.v8[12]; + *buf++ = *src++ ^ c->keystream_buffer.v8[13]; + *buf++ = *src++ ^ c->keystream_buffer.v8[14]; + *buf++ = *src++ ^ c->keystream_buffer.v8[15]; } else { b = (uint32_t *)buf; - *b++ ^= c->keystream_buffer.v32[0]; - *b++ ^= c->keystream_buffer.v32[1]; - *b++ ^= c->keystream_buffer.v32[2]; - *b++ ^= c->keystream_buffer.v32[3]; + s = (const uint32_t *)src; + *b++ = *s++ ^ c->keystream_buffer.v32[0]; + *b++ = *s++ ^ c->keystream_buffer.v32[1]; + *b++ = *s++ ^ c->keystream_buffer.v32[2]; + *b++ = *s++ ^ c->keystream_buffer.v32[3]; buf = (uint8_t *)b; + src = (const uint8_t *)s; } #endif /* #if ALIGN_32 */ } @@ -385,7 +397,7 @@ static srtp_err_status_t srtp_aes_icm_encrypt(void *cv, srtp_aes_icm_advance(c); for (size_t i = 0; i < (bytes_to_encr & 0xf); i++) { - *buf++ ^= c->keystream_buffer.v8[i]; + *buf++ = *src++ ^ c->keystream_buffer.v8[i]; } /* reset the keystream buffer size to right value */ diff --git a/crypto/cipher/aes_icm_mbedtls.c b/crypto/cipher/aes_icm_mbedtls.c index 7cbde4d03..e2aa98968 100644 --- a/crypto/cipher/aes_icm_mbedtls.c +++ b/crypto/cipher/aes_icm_mbedtls.c @@ -290,22 +290,30 @@ static srtp_err_status_t srtp_aes_icm_mbedtls_set_iv( * enc_len length of encrypt buffer */ static srtp_err_status_t srtp_aes_icm_mbedtls_encrypt(void *cv, - uint8_t *buf, - size_t *enc_len) + const uint8_t *src, + size_t src_len, + uint8_t *dst, + size_t *dst_len) { srtp_aes_icm_ctx_t *c = (srtp_aes_icm_ctx_t *)cv; int errCode = 0; debug_print(srtp_mod_aes_icm, "rs0: %s", v128_hex_string(&c->counter)); + if (*dst_len < src_len) { + return srtp_err_status_buffer_small; + } + errCode = - mbedtls_aes_crypt_ctr(c->ctx, *enc_len, &(c->nc_off), c->counter.v8, - c->stream_block.v8, buf, buf); + mbedtls_aes_crypt_ctr(c->ctx, src_len, &(c->nc_off), c->counter.v8, + c->stream_block.v8, src, dst); if (errCode != 0) { debug_print(srtp_mod_aes_icm, "encrypt error: %d", errCode); return srtp_err_status_cipher_fail; } + *dst_len = src_len; + return srtp_err_status_ok; } diff --git a/crypto/cipher/aes_icm_nss.c b/crypto/cipher/aes_icm_nss.c index 0f9c883fd..25611bea9 100644 --- a/crypto/cipher/aes_icm_nss.c +++ b/crypto/cipher/aes_icm_nss.c @@ -323,8 +323,10 @@ static srtp_err_status_t srtp_aes_icm_nss_set_iv(void *cv, * enc_len length of encrypt buffer */ static srtp_err_status_t srtp_aes_icm_nss_encrypt(void *cv, - uint8_t *buf, - size_t *enc_len) + const uint8_t *src, + size_t src_len, + uint8_t *dst, + size_t *dst_len) { srtp_aes_icm_ctx_t *c = (srtp_aes_icm_ctx_t *)cv; @@ -333,8 +335,8 @@ static srtp_err_status_t srtp_aes_icm_nss_encrypt(void *cv, } int out_len = 0; - int rv = PK11_CipherOp(c->ctx, buf, &out_len, *enc_len, buf, *enc_len); - *enc_len = out_len; + int rv = PK11_CipherOp(c->ctx, dst, &out_len, *dst_len, src, src_len); + *dst_len = out_len; srtp_err_status_t status = (srtp_err_status_ok); if (rv != SECSuccess) { status = (srtp_err_status_cipher_fail); diff --git a/crypto/cipher/aes_icm_ossl.c b/crypto/cipher/aes_icm_ossl.c index 4319de5c6..287642d7a 100644 --- a/crypto/cipher/aes_icm_ossl.c +++ b/crypto/cipher/aes_icm_ossl.c @@ -298,23 +298,25 @@ static srtp_err_status_t srtp_aes_icm_openssl_set_iv( * enc_len length of encrypt buffer */ static srtp_err_status_t srtp_aes_icm_openssl_encrypt(void *cv, - uint8_t *buf, - size_t *enc_len) + const uint8_t *src, + size_t src_len, + uint8_t *dst, + size_t *dst_len) { srtp_aes_icm_ctx_t *c = (srtp_aes_icm_ctx_t *)cv; int len = 0; debug_print(srtp_mod_aes_icm, "rs0: %s", v128_hex_string(&c->counter)); - if (!EVP_EncryptUpdate(c->ctx, buf, &len, buf, *enc_len)) { + if (!EVP_EncryptUpdate(c->ctx, dst, &len, src, src_len)) { return srtp_err_status_cipher_fail; } - *enc_len = len; + *dst_len = len; - if (!EVP_EncryptFinal_ex(c->ctx, buf + len, &len)) { + if (!EVP_EncryptFinal_ex(c->ctx, dst + len, &len)) { return srtp_err_status_cipher_fail; } - *enc_len += len; + *dst_len += len; return srtp_err_status_ok; } diff --git a/crypto/cipher/cipher.c b/crypto/cipher/cipher.c index 35a2321b6..0abe3f64e 100644 --- a/crypto/cipher/cipher.c +++ b/crypto/cipher/cipher.c @@ -105,29 +105,34 @@ srtp_err_status_t srtp_cipher_output(srtp_cipher_t *c, octet_string_set_to_zero(buffer, *num_octets_to_output); /* exor keystream into buffer */ - return (((c)->type)->encrypt(((c)->state), buffer, num_octets_to_output)); + return (((c)->type)->encrypt(((c)->state), buffer, *num_octets_to_output, + buffer, num_octets_to_output)); } srtp_err_status_t srtp_cipher_encrypt(srtp_cipher_t *c, - uint8_t *buffer, - size_t *num_octets_to_output) + const uint8_t *src, + size_t src_len, + uint8_t *dst, + size_t *dst_len) { if (!c || !c->type || !c->state) { return (srtp_err_status_bad_param); } - return (((c)->type)->encrypt(((c)->state), buffer, num_octets_to_output)); + return (((c)->type)->encrypt(((c)->state), src, src_len, dst, dst_len)); } srtp_err_status_t srtp_cipher_decrypt(srtp_cipher_t *c, - uint8_t *buffer, - size_t *num_octets_to_output) + const uint8_t *src, + size_t src_len, + uint8_t *dst, + size_t *dst_len) { if (!c || !c->type || !c->state) { return (srtp_err_status_bad_param); } - return (((c)->type)->decrypt(((c)->state), buffer, num_octets_to_output)); + return (((c)->type)->decrypt(((c)->state), src, src_len, dst, dst_len)); } srtp_err_status_t srtp_cipher_get_tag(srtp_cipher_t *c, @@ -290,8 +295,9 @@ srtp_err_status_t srtp_cipher_type_test( } /* encrypt */ - len = test_case->plaintext_length_octets; - status = srtp_cipher_encrypt(c, buffer, &len); + len = sizeof(buffer); + status = srtp_cipher_encrypt( + c, buffer, test_case->plaintext_length_octets, buffer, &len); if (status) { srtp_cipher_dealloc(c); return status; @@ -390,8 +396,9 @@ srtp_err_status_t srtp_cipher_type_test( } /* decrypt */ - len = test_case->ciphertext_length_octets; - status = srtp_cipher_decrypt(c, buffer, &len); + len = sizeof(buffer); + status = srtp_cipher_decrypt( + c, buffer, test_case->ciphertext_length_octets, buffer, &len); if (status) { srtp_cipher_dealloc(c); return status; @@ -452,21 +459,24 @@ srtp_err_status_t srtp_cipher_type_test( } for (size_t j = 0; j < NUM_RAND_TESTS; j++) { - size_t length; size_t plaintext_len; + size_t encrypted_len; + size_t decrypted_len; uint8_t key[MAX_KEY_LEN]; uint8_t iv[MAX_KEY_LEN]; /* choose a length at random (leaving room for IV and padding) */ - length = srtp_cipher_rand_u32_for_tests() % (SELF_TEST_BUF_OCTETS - 64); - debug_print(srtp_mod_cipher, "random plaintext length %zu\n", length); - srtp_cipher_rand_for_tests(buffer, length); + plaintext_len = + srtp_cipher_rand_u32_for_tests() % (SELF_TEST_BUF_OCTETS - 64); + debug_print(srtp_mod_cipher, "random plaintext length %zu\n", + plaintext_len); + srtp_cipher_rand_for_tests(buffer, plaintext_len); debug_print(srtp_mod_cipher, "plaintext: %s", - srtp_octet_string_hex_string(buffer, length)); + srtp_octet_string_hex_string(buffer, plaintext_len)); /* copy plaintext into second buffer */ - for (size_t i = 0; i < length; i++) { + for (size_t i = 0; i < plaintext_len; i++) { buffer2[i] = buffer[i]; } @@ -511,8 +521,9 @@ srtp_err_status_t srtp_cipher_type_test( } /* encrypt buffer with cipher */ - plaintext_len = length; - status = srtp_cipher_encrypt(c, buffer, &length); + encrypted_len = sizeof(buffer); + status = srtp_cipher_encrypt(c, buffer, plaintext_len, buffer, + &encrypted_len); if (status) { srtp_cipher_dealloc(c); return status; @@ -522,15 +533,15 @@ srtp_err_status_t srtp_cipher_type_test( /* * Get the GCM tag */ - status = srtp_cipher_get_tag(c, buffer + length, &tag_len); + status = srtp_cipher_get_tag(c, buffer + encrypted_len, &tag_len); if (status) { srtp_cipher_dealloc(c); return status; } - length += tag_len; + encrypted_len += tag_len; } debug_print(srtp_mod_cipher, "ciphertext: %s", - srtp_octet_string_hex_string(buffer, length)); + srtp_octet_string_hex_string(buffer, encrypted_len)); /* * re-initialize cipher for decryption, re-set the iv, then @@ -562,17 +573,19 @@ srtp_err_status_t srtp_cipher_type_test( srtp_octet_string_hex_string( test_case->aad, test_case->aad_length_octets)); } - status = srtp_cipher_decrypt(c, buffer, &length); + decrypted_len = sizeof(buffer); + status = srtp_cipher_decrypt(c, buffer, encrypted_len, buffer, + &decrypted_len); if (status) { srtp_cipher_dealloc(c); return status; } debug_print(srtp_mod_cipher, "plaintext[2]: %s", - srtp_octet_string_hex_string(buffer, length)); + srtp_octet_string_hex_string(buffer, decrypted_len)); /* compare the resulting plaintext with the original one */ - if (length != plaintext_len) { + if (decrypted_len != plaintext_len) { srtp_cipher_dealloc(c); return srtp_err_status_algo_fail; } @@ -654,7 +667,8 @@ uint64_t srtp_cipher_bits_per_second(srtp_cipher_t *c, } // Encrypt the buffer - if (srtp_cipher_encrypt(c, enc_buf, &len) != srtp_err_status_ok) { + if (srtp_cipher_encrypt(c, enc_buf, len, enc_buf, &len) != + srtp_err_status_ok) { srtp_crypto_free(enc_buf); return 0; } diff --git a/crypto/cipher/null_cipher.c b/crypto/cipher/null_cipher.c index a507f4758..88dca5614 100644 --- a/crypto/cipher/null_cipher.c +++ b/crypto/cipher/null_cipher.c @@ -116,13 +116,19 @@ static srtp_err_status_t srtp_null_cipher_set_iv(void *cv, } static srtp_err_status_t srtp_null_cipher_encrypt(void *cv, - uint8_t *buf, - size_t *bytes_to_encr) + const uint8_t *src, + size_t src_len, + uint8_t *dst, + size_t *dst_len) { - /* srtp_null_cipher_ctx_t *c = (srtp_null_cipher_ctx_t *)cv; */ (void)cv; - (void)buf; - (void)bytes_to_encr; + if (src != dst) { + if (*dst_len < src_len) { + return srtp_err_status_buffer_small; + } + memcpy(dst, src, src_len); + } + *dst_len = src_len; return srtp_err_status_ok; } diff --git a/crypto/include/cipher.h b/crypto/include/cipher.h index 4c5b7b716..995cfcb0a 100644 --- a/crypto/include/cipher.h +++ b/crypto/include/cipher.h @@ -97,16 +97,18 @@ typedef srtp_err_status_t (*srtp_cipher_set_aad_func_t)(void *state, size_t aad_len); /* a srtp_cipher_encrypt_func_t encrypts data in-place */ -typedef srtp_err_status_t (*srtp_cipher_encrypt_func_t)( - void *state, - uint8_t *buffer, - size_t *octets_to_encrypt); +typedef srtp_err_status_t (*srtp_cipher_encrypt_func_t)(void *state, + const uint8_t *src, + size_t src_len, + uint8_t *dst, + size_t *dst_len); /* a srtp_cipher_decrypt_func_t decrypts data in-place */ -typedef srtp_err_status_t (*srtp_cipher_decrypt_func_t)( - void *state, - uint8_t *buffer, - size_t *octets_to_decrypt); +typedef srtp_err_status_t (*srtp_cipher_decrypt_func_t)(void *state, + const uint8_t *src, + size_t src_len, + uint8_t *dst, + size_t *dst_len); /* * a srtp_cipher_set_iv_func_t function sets the current initialization vector @@ -219,11 +221,15 @@ srtp_err_status_t srtp_cipher_output(srtp_cipher_t *c, uint8_t *buffer, size_t *num_octets_to_output); srtp_err_status_t srtp_cipher_encrypt(srtp_cipher_t *c, - uint8_t *buffer, - size_t *num_octets_to_output); + const uint8_t *src, + size_t src_len, + uint8_t *dst, + size_t *dst_len); srtp_err_status_t srtp_cipher_decrypt(srtp_cipher_t *c, - uint8_t *buffer, - size_t *num_octets_to_output); + const uint8_t *src, + size_t src_len, + uint8_t *dst, + size_t *dst_len); srtp_err_status_t srtp_cipher_get_tag(srtp_cipher_t *c, uint8_t *buffer, size_t *tag_len); diff --git a/crypto/test/cipher_driver.c b/crypto/test/cipher_driver.c index 89806b024..ceedd08ec 100644 --- a/crypto/test/cipher_driver.c +++ b/crypto/test/cipher_driver.c @@ -381,7 +381,7 @@ srtp_err_status_t cipher_driver_test_buffering(srtp_cipher_t *c) } /* generate 'reference' value by encrypting all at once */ - status = srtp_cipher_encrypt(c, buffer0, &buflen); + status = srtp_cipher_encrypt(c, buffer0, buflen, buffer0, &buflen); if (status) { return status; } @@ -404,7 +404,7 @@ srtp_err_status_t cipher_driver_test_buffering(srtp_cipher_t *c) len = end - current; } - status = srtp_cipher_encrypt(c, current, &len); + status = srtp_cipher_encrypt(c, current, len, current, &len); if (status) { return status; } @@ -566,7 +566,7 @@ uint64_t cipher_array_bits_per_second(srtp_cipher_t *cipher_array[], srtp_cipher_set_iv(cipher_array[cipher_index], (uint8_t *)&nonce, srtp_direction_encrypt); srtp_cipher_encrypt(cipher_array[cipher_index], enc_buf, - &octets_to_encrypt); + octets_to_encrypt, enc_buf, &octets_to_encrypt); /* choose a cipher at random from the array*/ cipher_index = (*((size_t *)enc_buf)) % num_cipher; diff --git a/include/srtp.h b/include/srtp.h index 8d512358e..ffc95c7ff 100644 --- a/include/srtp.h +++ b/include/srtp.h @@ -213,8 +213,9 @@ typedef enum { /**< invalid */ srtp_err_status_pkt_idx_old = 26, /**< packet index is too old to */ /**< consider */ - srtp_err_status_pkt_idx_adv = 27 /**< packet index advanced, reset */ + srtp_err_status_pkt_idx_adv = 27, /**< packet index advanced, reset */ /**< needed */ + srtp_err_status_buffer_small = 28, /**< out buffer is too small */ } srtp_err_status_t; typedef struct srtp_ctx_t_ srtp_ctx_t; diff --git a/srtp/srtp.c b/srtp/srtp.c index e4d004b88..3c9ac2ace 100644 --- a/srtp/srtp.c +++ b/srtp/srtp.c @@ -826,7 +826,7 @@ static srtp_err_status_t srtp_kdf_generate(srtp_kdf_t *kdf, /* generate keystream output */ octet_string_set_to_zero(key, length); - status = srtp_cipher_encrypt(kdf->cipher, key, &length); + status = srtp_cipher_encrypt(kdf->cipher, key, length, key, &length); if (status) { return status; } @@ -1853,7 +1853,7 @@ static srtp_err_status_t srtp_protect_aead(srtp_ctx_t *ctx, /* Encrypt the payload */ status = srtp_cipher_encrypt(session_keys->rtp_cipher, enc_start, - &enc_octet_len); + enc_octet_len, enc_start, &enc_octet_len); if (status) { return srtp_err_status_cipher_fail; } @@ -1988,7 +1988,7 @@ static srtp_err_status_t srtp_unprotect_aead(srtp_ctx_t *ctx, /* Decrypt the ciphertext. This also checks the auth tag based * on the AAD we just specified above */ status = srtp_cipher_decrypt(session_keys->rtp_cipher, enc_start, - &enc_octet_len); + enc_octet_len, enc_start, &enc_octet_len); if (status) { return status; } @@ -2342,7 +2342,7 @@ srtp_err_status_t srtp_protect_mki(srtp_ctx_t *ctx, /* if we're encrypting, exor keystream into the message */ if (enc_start) { status = srtp_cipher_encrypt(session_keys->rtp_cipher, enc_start, - &enc_octet_len); + enc_octet_len, enc_start, &enc_octet_len); if (status) { return srtp_err_status_cipher_fail; } @@ -2665,7 +2665,7 @@ srtp_err_status_t srtp_unprotect_mki(srtp_ctx_t *ctx, /* if we're decrypting, add keystream into ciphertext */ if (enc_start) { status = srtp_cipher_decrypt(session_keys->rtp_cipher, enc_start, - &enc_octet_len); + enc_octet_len, enc_start, &enc_octet_len); if (status) { return srtp_err_status_cipher_fail; } @@ -3604,7 +3604,7 @@ static srtp_err_status_t srtp_protect_rtcp_aead( /* if we're encrypting, exor keystream into the message */ if (enc_start) { status = srtp_cipher_encrypt(session_keys->rtcp_cipher, enc_start, - &enc_octet_len); + enc_octet_len, enc_start, &enc_octet_len); if (status) { return srtp_err_status_cipher_fail; } @@ -3623,7 +3623,8 @@ static srtp_err_status_t srtp_protect_rtcp_aead( * to run the cipher to get the auth tag. */ size_t nolen = 0; - status = srtp_cipher_encrypt(session_keys->rtcp_cipher, NULL, &nolen); + status = srtp_cipher_encrypt(session_keys->rtcp_cipher, NULL, nolen, + NULL, &nolen); if (status) { return srtp_err_status_cipher_fail; } @@ -3772,7 +3773,7 @@ static srtp_err_status_t srtp_unprotect_rtcp_aead( /* if we're decrypting, exor keystream into the message */ if (enc_start) { status = srtp_cipher_decrypt(session_keys->rtcp_cipher, enc_start, - &enc_octet_len); + enc_octet_len, enc_start, &enc_octet_len); if (status) { return status; } @@ -3781,8 +3782,8 @@ static srtp_err_status_t srtp_unprotect_rtcp_aead( * Still need to run the cipher to check the tag */ tmp_len = tag_len; - status = - srtp_cipher_decrypt(session_keys->rtcp_cipher, auth_tag, &tmp_len); + status = srtp_cipher_decrypt(session_keys->rtcp_cipher, auth_tag, + tmp_len, auth_tag, &tmp_len); if (status) { return status; } @@ -4050,7 +4051,7 @@ srtp_err_status_t srtp_protect_rtcp_mki(srtp_t ctx, /* if we're encrypting, exor keystream into the message */ if (enc_start) { status = srtp_cipher_encrypt(session_keys->rtcp_cipher, enc_start, - &enc_octet_len); + enc_octet_len, enc_start, &enc_octet_len); if (status) { return srtp_err_status_cipher_fail; } @@ -4304,7 +4305,7 @@ srtp_err_status_t srtp_unprotect_rtcp_mki(srtp_t ctx, /* if we're decrypting, exor keystream into the message */ if (enc_start) { status = srtp_cipher_decrypt(session_keys->rtcp_cipher, enc_start, - &enc_octet_len); + enc_octet_len, enc_start, &enc_octet_len); if (status) { return srtp_err_status_cipher_fail; }