Skip to content

Commit

Permalink
change cipher api to support not-inplace io
Browse files Browse the repository at this point in the history
  • Loading branch information
pabuhler committed Jan 23, 2024
1 parent 1807a48 commit 08673a1
Show file tree
Hide file tree
Showing 13 changed files with 293 additions and 141 deletions.
32 changes: 23 additions & 9 deletions crypto/cipher/aes_gcm_mbedtls.c
Original file line number Diff line number Diff line change
Expand Up @@ -281,8 +281,10 @@ static srtp_err_status_t srtp_aes_gcm_mbedtls_set_aad(void *cv,
* enc_len length of encrypt buffer
*/
static srtp_err_status_t srtp_aes_gcm_mbedtls_encrypt(void *cv,
uint8_t *buf,
size_t *enc_len)
const uint8_t *src,
size_t src_len,
uint8_t *dst,
size_t *dst_len)
{
FUNC_ENTRY();
srtp_aes_gcm_ctx_t *c = (srtp_aes_gcm_ctx_t *)cv;
Expand All @@ -292,16 +294,22 @@ static srtp_err_status_t srtp_aes_gcm_mbedtls_encrypt(void *cv,
return (srtp_err_status_bad_param);
}

errCode = mbedtls_gcm_crypt_and_tag(c->ctx, MBEDTLS_GCM_ENCRYPT, *enc_len,
if (*dst_len < src_len) {
return srtp_err_status_buffer_small;
}

errCode = mbedtls_gcm_crypt_and_tag(c->ctx, MBEDTLS_GCM_ENCRYPT, src_len,
c->iv, c->iv_len, c->aad, c->aad_size,
buf, buf, c->tag_len, c->tag);
src, dst, c->tag_len, c->tag);

c->aad_size = 0;
if (errCode != 0) {
debug_print(srtp_mod_aes_gcm, "mbedtls error code: %d", errCode);
return srtp_err_status_bad_param;
}

*dst_len = src_len;

return (srtp_err_status_ok);
}

Expand Down Expand Up @@ -337,8 +345,10 @@ static srtp_err_status_t srtp_aes_gcm_mbedtls_get_tag(void *cv,
* enc_len length of encrypt buffer
*/
static srtp_err_status_t srtp_aes_gcm_mbedtls_decrypt(void *cv,
uint8_t *buf,
size_t *enc_len)
const uint8_t *src,
size_t src_len,
uint8_t *dst,
size_t *dst_len)
{
FUNC_ENTRY();
srtp_aes_gcm_ctx_t *c = (srtp_aes_gcm_ctx_t *)cv;
Expand All @@ -348,12 +358,16 @@ static srtp_err_status_t srtp_aes_gcm_mbedtls_decrypt(void *cv,
return (srtp_err_status_bad_param);
}

if (*dst_len < (src_len - c->tag_len)) {
return srtp_err_status_buffer_small;
}

debug_print(srtp_mod_aes_gcm, "AAD: %s",
srtp_octet_string_hex_string(c->aad, c->aad_size));

errCode = mbedtls_gcm_auth_decrypt(
c->ctx, (*enc_len - c->tag_len), c->iv, c->iv_len, c->aad, c->aad_size,
buf + (*enc_len - c->tag_len), c->tag_len, buf, buf);
c->ctx, (src_len - c->tag_len), c->iv, c->iv_len, c->aad, c->aad_size,
src + (src_len - c->tag_len), c->tag_len, src, dst);
c->aad_size = 0;
if (errCode != 0) {
return (srtp_err_status_auth_fail);
Expand All @@ -363,7 +377,7 @@ static srtp_err_status_t srtp_aes_gcm_mbedtls_decrypt(void *cv,
* Reduce the buffer size by the tag length since the tag
* is not part of the original payload
*/
*enc_len -= c->tag_len;
*dst_len = (src_len - c->tag_len);

return (srtp_err_status_ok);
}
Expand Down
62 changes: 43 additions & 19 deletions crypto/cipher/aes_gcm_nss.c
Original file line number Diff line number Diff line change
Expand Up @@ -281,8 +281,10 @@ static srtp_err_status_t srtp_aes_gcm_nss_set_aad(void *cv,

static srtp_err_status_t srtp_aes_gcm_nss_do_crypto(void *cv,
bool encrypt,
uint8_t *buf,
size_t *enc_len)
const uint8_t *src,
size_t src_len,
uint8_t *dst,
size_t *dst_len)
{
srtp_aes_gcm_ctx_t *c = (srtp_aes_gcm_ctx_t *)cv;

Expand All @@ -299,13 +301,13 @@ static srtp_err_status_t srtp_aes_gcm_nss_do_crypto(void *cv,
SECItem param = { siBuffer, (unsigned char *)&c->params,
sizeof(CK_GCM_PARAMS) };
if (encrypt) {
rv = PK11_Encrypt(c->key, CKM_AES_GCM, &param, buf, &out_len,
*enc_len + 16, buf, *enc_len);
rv = PK11_Encrypt(c->key, CKM_AES_GCM, &param, dst, &out_len, *dst_len,
src, src_len);
} else {
rv = PK11_Decrypt(c->key, CKM_AES_GCM, &param, buf, &out_len, *enc_len,
buf, *enc_len);
rv = PK11_Decrypt(c->key, CKM_AES_GCM, &param, dst, &out_len, *dst_len,
src, src_len);
}
*enc_len = out_len;
*dst_len = out_len;
srtp_err_status_t status = (srtp_err_status_ok);
if (rv != SECSuccess) {
status = (srtp_err_status_cipher_fail);
Expand All @@ -328,32 +330,43 @@ static srtp_err_status_t srtp_aes_gcm_nss_do_crypto(void *cv,
* enc_len length of encrypt buffer
*/
static srtp_err_status_t srtp_aes_gcm_nss_encrypt(void *cv,
uint8_t *buf,
size_t *enc_len)
const uint8_t *src,
size_t src_len,
uint8_t *dst,
size_t *dst_len)
{
srtp_aes_gcm_ctx_t *c = (srtp_aes_gcm_ctx_t *)cv;

//#todo, this might need som looking at
// nss requires space for tag, currently we assume that ther is space, this
// should change, the best would be to merge the cipher encrypt and get_tag
// api
*dst_len += 16;

// When we get a non-NULL buffer, we know that the caller is
// prepared to also take the tag. When we get a NULL buffer,
// even though there's no data, we need to give NSS a buffer
// where it can write the tag. We can't just use c->tag because
// memcpy has undefined behavior on overlapping ranges.
uint8_t tagbuf[16];
uint8_t *non_null_buf = buf;
if (!non_null_buf && (*enc_len == 0)) {
const uint8_t *non_null_buf = src;
uint8_t *non_null_dst_buf = dst;
if (!non_null_buf && (src_len == 0)) {
non_null_buf = tagbuf;
non_null_dst_buf = tagbuf;
*dst_len = sizeof(tagbuf);
} else if (!non_null_buf) {
return srtp_err_status_bad_param;
}

srtp_err_status_t status =
srtp_aes_gcm_nss_do_crypto(cv, true, non_null_buf, enc_len);
srtp_err_status_t status = srtp_aes_gcm_nss_do_crypto(
cv, true, non_null_buf, src_len, non_null_dst_buf, dst_len);
if (status != srtp_err_status_ok) {
return status;
}

memcpy(c->tag, non_null_buf + (*enc_len - c->tag_size), c->tag_size);
*enc_len -= c->tag_size;
memcpy(c->tag, non_null_dst_buf + (*dst_len - c->tag_size), c->tag_size);
*dst_len -= c->tag_size;
return srtp_err_status_ok;
}

Expand Down Expand Up @@ -387,11 +400,22 @@ static srtp_err_status_t srtp_aes_gcm_nss_get_tag(void *cv,
* enc_len length of encrypt buffer
*/
static srtp_err_status_t srtp_aes_gcm_nss_decrypt(void *cv,
uint8_t *buf,
size_t *enc_len)
const uint8_t *src,
size_t src_len,
uint8_t *dst,
size_t *dst_len)
{
srtp_err_status_t status =
srtp_aes_gcm_nss_do_crypto(cv, false, buf, enc_len);
uint8_t tagbuf[16];
uint8_t *non_null_dst_buf = dst;
if (!non_null_dst_buf && (*dst_len == 0)) {
non_null_dst_buf = tagbuf;
*dst_len = sizeof(tagbuf);
} else if (!non_null_dst_buf) {
return srtp_err_status_bad_param;
}

srtp_err_status_t status = srtp_aes_gcm_nss_do_crypto(
cv, false, src, src_len, non_null_dst_buf, dst_len);
if (status != srtp_err_status_ok) {
int err = PR_GetError();
if (err == SEC_ERROR_BAD_DATA) {
Expand Down
26 changes: 17 additions & 9 deletions crypto/cipher/aes_gcm_ossl.c
Original file line number Diff line number Diff line change
Expand Up @@ -293,8 +293,10 @@ static srtp_err_status_t srtp_aes_gcm_openssl_set_aad(void *cv,
* enc_len length of encrypt buffer
*/
static srtp_err_status_t srtp_aes_gcm_openssl_encrypt(void *cv,
uint8_t *buf,
size_t *enc_len)
const uint8_t *src,
size_t src_len,
uint8_t *dst,
size_t *dst_len)
{
srtp_aes_gcm_ctx_t *c = (srtp_aes_gcm_ctx_t *)cv;
if (c->dir != srtp_direction_encrypt && c->dir != srtp_direction_decrypt) {
Expand All @@ -304,7 +306,8 @@ static srtp_err_status_t srtp_aes_gcm_openssl_encrypt(void *cv,
/*
* Encrypt the data
*/
EVP_Cipher(c->ctx, buf, buf, *enc_len);
EVP_Cipher(c->ctx, dst, src, src_len);
*dst_len = src_len;

return (srtp_err_status_ok);
}
Expand Down Expand Up @@ -354,8 +357,10 @@ static srtp_err_status_t srtp_aes_gcm_openssl_get_tag(void *cv,
* enc_len length of encrypt buffer
*/
static srtp_err_status_t srtp_aes_gcm_openssl_decrypt(void *cv,
uint8_t *buf,
size_t *enc_len)
const uint8_t *src,
size_t src_len,
uint8_t *dst,
size_t *dst_len)
{
srtp_aes_gcm_ctx_t *c = (srtp_aes_gcm_ctx_t *)cv;
if (c->dir != srtp_direction_encrypt && c->dir != srtp_direction_decrypt) {
Expand All @@ -364,12 +369,15 @@ static srtp_err_status_t srtp_aes_gcm_openssl_decrypt(void *cv,

/*
* Set the tag before decrypting
*
* explicitly cast away const of src
*/
if (!EVP_CIPHER_CTX_ctrl(c->ctx, EVP_CTRL_GCM_SET_TAG, c->tag_len,
buf + (*enc_len - c->tag_len))) {
if (!EVP_CIPHER_CTX_ctrl(
c->ctx, EVP_CTRL_GCM_SET_TAG, c->tag_len,
(void *)(uintptr_t)(src + (src_len - c->tag_len)))) {
return (srtp_err_status_auth_fail);
}
EVP_Cipher(c->ctx, buf, buf, *enc_len - c->tag_len);
EVP_Cipher(c->ctx, dst, src, src_len - c->tag_len);

/*
* Check the tag
Expand All @@ -382,7 +390,7 @@ static srtp_err_status_t srtp_aes_gcm_openssl_decrypt(void *cv,
* Reduce the buffer size by the tag length since the tag
* is not part of the original payload
*/
*enc_len -= c->tag_len;
*dst_len = src_len -= c->tag_len;

return (srtp_err_status_ok);
}
Expand Down
72 changes: 42 additions & 30 deletions crypto/cipher/aes_icm.c
Original file line number Diff line number Diff line change
Expand Up @@ -295,12 +295,20 @@ static void srtp_aes_icm_advance(srtp_aes_icm_ctx_t *c)
*/

static srtp_err_status_t srtp_aes_icm_encrypt(void *cv,
uint8_t *buf,
size_t *enc_len)
const uint8_t *src,
size_t src_len,
uint8_t *dst,
size_t *dst_len)
{
srtp_aes_icm_ctx_t *c = (srtp_aes_icm_ctx_t *)cv;
size_t bytes_to_encr = *enc_len;
size_t bytes_to_encr = src_len;
uint32_t *b;
const uint32_t *s;

// check out length if not equal or greater bail!
*dst_len = src_len;

unsigned char *buf = dst;

/* check that there's enough segment left*/
size_t bytes_of_new_keystream = bytes_to_encr - c->bytes_in_buffer;
Expand All @@ -314,7 +322,7 @@ static srtp_err_status_t srtp_aes_icm_encrypt(void *cv,
/* deal with odd case of small bytes_to_encr */
for (size_t i = (sizeof(v128_t) - c->bytes_in_buffer);
i < (sizeof(v128_t) - c->bytes_in_buffer + bytes_to_encr); i++) {
*buf++ ^= c->keystream_buffer.v8[i];
*buf++ = *src++ ^ c->keystream_buffer.v8[i];
}

c->bytes_in_buffer -= bytes_to_encr;
Expand All @@ -326,7 +334,7 @@ static srtp_err_status_t srtp_aes_icm_encrypt(void *cv,
/* encrypt bytes until the remaining data is 16-byte aligned */
for (size_t i = (sizeof(v128_t) - c->bytes_in_buffer);
i < sizeof(v128_t); i++) {
*buf++ ^= c->keystream_buffer.v8[i];
*buf++ = *src++ ^ c->keystream_buffer.v8[i];
}

bytes_to_encr -= c->bytes_in_buffer;
Expand All @@ -345,36 +353,40 @@ static srtp_err_status_t srtp_aes_icm_encrypt(void *cv,

#if ALIGN_32
b = (uint32_t *)buf;
*b++ ^= c->keystream_buffer.v32[0];
*b++ ^= c->keystream_buffer.v32[1];
*b++ ^= c->keystream_buffer.v32[2];
*b++ ^= c->keystream_buffer.v32[3];
s = (const uint32_t *)src;
*b++ = *s++ ^ c->keystream_buffer.v32[0];
*b++ = *s++ ^ c->keystream_buffer.v32[1];
*b++ = *s++ ^ c->keystream_buffer.v32[2];
*b++ = *s++ ^ c->keystream_buffer.v32[3];
buf = (uint8_t *)b;
src = (const uint8_t *)s;
#else
if ((((uintptr_t)buf) & 0x03) != 0) {
*buf++ ^= c->keystream_buffer.v8[0];
*buf++ ^= c->keystream_buffer.v8[1];
*buf++ ^= c->keystream_buffer.v8[2];
*buf++ ^= c->keystream_buffer.v8[3];
*buf++ ^= c->keystream_buffer.v8[4];
*buf++ ^= c->keystream_buffer.v8[5];
*buf++ ^= c->keystream_buffer.v8[6];
*buf++ ^= c->keystream_buffer.v8[7];
*buf++ ^= c->keystream_buffer.v8[8];
*buf++ ^= c->keystream_buffer.v8[9];
*buf++ ^= c->keystream_buffer.v8[10];
*buf++ ^= c->keystream_buffer.v8[11];
*buf++ ^= c->keystream_buffer.v8[12];
*buf++ ^= c->keystream_buffer.v8[13];
*buf++ ^= c->keystream_buffer.v8[14];
*buf++ ^= c->keystream_buffer.v8[15];
*buf++ = *src++ ^ c->keystream_buffer.v8[0];
*buf++ = *src++ ^ c->keystream_buffer.v8[1];
*buf++ = *src++ ^ c->keystream_buffer.v8[2];
*buf++ = *src++ ^ c->keystream_buffer.v8[3];
*buf++ = *src++ ^ c->keystream_buffer.v8[4];
*buf++ = *src++ ^ c->keystream_buffer.v8[5];
*buf++ = *src++ ^ c->keystream_buffer.v8[6];
*buf++ = *src++ ^ c->keystream_buffer.v8[7];
*buf++ = *src++ ^ c->keystream_buffer.v8[8];
*buf++ = *src++ ^ c->keystream_buffer.v8[9];
*buf++ = *src++ ^ c->keystream_buffer.v8[10];
*buf++ = *src++ ^ c->keystream_buffer.v8[11];
*buf++ = *src++ ^ c->keystream_buffer.v8[12];
*buf++ = *src++ ^ c->keystream_buffer.v8[13];
*buf++ = *src++ ^ c->keystream_buffer.v8[14];
*buf++ = *src++ ^ c->keystream_buffer.v8[15];
} else {
b = (uint32_t *)buf;
*b++ ^= c->keystream_buffer.v32[0];
*b++ ^= c->keystream_buffer.v32[1];
*b++ ^= c->keystream_buffer.v32[2];
*b++ ^= c->keystream_buffer.v32[3];
s = (const uint32_t *)src;
*b++ = *s++ ^ c->keystream_buffer.v32[0];
*b++ = *s++ ^ c->keystream_buffer.v32[1];
*b++ = *s++ ^ c->keystream_buffer.v32[2];
*b++ = *s++ ^ c->keystream_buffer.v32[3];
buf = (uint8_t *)b;
src = (const uint8_t *)s;
}
#endif /* #if ALIGN_32 */
}
Expand All @@ -385,7 +397,7 @@ static srtp_err_status_t srtp_aes_icm_encrypt(void *cv,
srtp_aes_icm_advance(c);

for (size_t i = 0; i < (bytes_to_encr & 0xf); i++) {
*buf++ ^= c->keystream_buffer.v8[i];
*buf++ = *src++ ^ c->keystream_buffer.v8[i];
}

/* reset the keystream buffer size to right value */
Expand Down
Loading

0 comments on commit 08673a1

Please sign in to comment.