diff --git a/src/MQTTAsync.c b/src/MQTTAsync.c index 11d8723f..8b3ff506 100644 --- a/src/MQTTAsync.c +++ b/src/MQTTAsync.c @@ -600,7 +600,7 @@ int MQTTAsync_connect(MQTTAsync handle, const MQTTAsync_connectOptions* options) } if (options->struct_version != 0 && options->ssl) /* check validity of SSL options structure */ { - if (strncmp(options->ssl->struct_id, "MQTS", 4) != 0 || options->ssl->struct_version < 0 || options->ssl->struct_version > 5) + if (strncmp(options->ssl->struct_id, "MQTS", 4) != 0 || options->ssl->struct_version < 0 || options->ssl->struct_version > 6) { rc = MQTTASYNC_BAD_STRUCTURE; goto exit; @@ -750,7 +750,7 @@ int MQTTAsync_connect(MQTTAsync handle, const MQTTAsync_connectOptions* options) if (m->c->sslopts->privateKey) free((void*)m->c->sslopts->privateKey); if (m->c->sslopts->privateKeyPassword) - free((void*)m->c->sslopts->privateKeyPassword); + free((void*)m->c->sslopts->privateKeyPassword); if (m->c->sslopts->enabledCipherSuites) free((void*)m->c->sslopts->enabledCipherSuites); if (m->c->sslopts->struct_version >= 2) @@ -758,6 +758,15 @@ int MQTTAsync_connect(MQTTAsync handle, const MQTTAsync_connectOptions* options) if (m->c->sslopts->CApath) free((void*)m->c->sslopts->CApath); } + if (m->c->sslopts->struct_version >= 6) + { + if (m->c->sslopts->pemRootCerts) + free((void*)m->c->sslopts->pemRootCerts); + if (m->c->sslopts->pemCertChain) + free((void*)m->c->sslopts->pemCertChain); + if (m->c->sslopts->pemPrivateKey) + free((void*)m->c->sslopts->pemPrivateKey); + } free((void*)m->c->sslopts); m->c->sslopts = NULL; } @@ -807,6 +816,15 @@ int MQTTAsync_connect(MQTTAsync handle, const MQTTAsync_connectOptions* options) m->c->sslopts->protos = (const unsigned char*)MQTTStrdup((const char*)options->ssl->protos); m->c->sslopts->protos_len = options->ssl->protos_len; } + if (m->c->sslopts->struct_version >= 6) + { + if (options->ssl->pemRootCerts) + m->c->sslopts->pemRootCerts = MQTTStrdup(options->ssl->pemRootCerts); + if (options->ssl->pemCertChain) + m->c->sslopts->pemCertChain = MQTTStrdup(options->ssl->pemCertChain); + if (options->ssl->pemPrivateKey) + m->c->sslopts->pemPrivateKey = MQTTStrdup(options->ssl->pemPrivateKey); + } } #else if (options->struct_version != 0 && options->ssl) diff --git a/src/MQTTAsync.h b/src/MQTTAsync.h index 90f9f7b7..8ee7ac3b 100644 --- a/src/MQTTAsync.h +++ b/src/MQTTAsync.h @@ -1125,7 +1125,7 @@ typedef struct * From the OpenSSL documentation: * If CApath is not NULL, it points to a directory containing CA certificates in PEM format. * Exists only if struct_version >= 2 - */ + */ const char* CApath; /** @@ -1174,9 +1174,24 @@ typedef struct * Exists only if struct_version >= 5 */ unsigned int protos_len; + + /** + * document + */ + const char* pemRootCerts; + + /** + * document + */ + const char* pemCertChain; + + /** + * document + */ + const char* pemPrivateKey; } MQTTAsync_SSLOptions; -#define MQTTAsync_SSLOptions_initializer { {'M', 'Q', 'T', 'S'}, 5, NULL, NULL, NULL, NULL, NULL, 1, MQTT_SSL_VERSION_DEFAULT, 0, NULL, NULL, NULL, NULL, NULL, 0, NULL, 0 } +#define MQTTAsync_SSLOptions_initializer { {'M', 'Q', 'T', 'S'}, 6, NULL, NULL, NULL, NULL, NULL, 1, MQTT_SSL_VERSION_DEFAULT, 0, NULL, NULL, NULL, NULL, 0, NULL, 0, NULL, NULL, NULL } /** Utility structure where name/value pairs are needed */ typedef struct diff --git a/src/MQTTClient.c b/src/MQTTClient.c index 3548e957..bde370f7 100644 --- a/src/MQTTClient.c +++ b/src/MQTTClient.c @@ -1601,6 +1601,15 @@ static MQTTResponse MQTTClient_connectURI(MQTTClient handle, MQTTClient_connectO if (m->c->sslopts->CApath) free((void*)m->c->sslopts->CApath); } + if (m->c->sslopts->struct_version >= 6) + { + if (m->c->sslopts->pemRootCerts) + free((void*)m->c->sslopts->pemRootCerts); + if (m->c->sslopts->pemCertChain) + free((void*)m->c->sslopts->pemCertChain); + if (m->c->sslopts->pemPrivateKey) + free((void*)m->c->sslopts->pemPrivateKey); + } free(m->c->sslopts); m->c->sslopts = NULL; } @@ -1649,6 +1658,15 @@ static MQTTResponse MQTTClient_connectURI(MQTTClient handle, MQTTClient_connectO m->c->sslopts->protos = options->ssl->protos; m->c->sslopts->protos_len = options->ssl->protos_len; } + if (m->c->sslopts->struct_version >= 6) + { + if (options->ssl->pemRootCerts) + m->c->sslopts->pemRootCerts = MQTTStrdup(options->ssl->pemRootCerts); + if (options->ssl->pemCertChain) + m->c->sslopts->pemCertChain = MQTTStrdup(options->ssl->pemCertChain); + if (options->ssl->pemPrivateKey) + m->c->sslopts->pemPrivateKey = MQTTStrdup(options->ssl->pemPrivateKey); + } } #endif @@ -1801,7 +1819,7 @@ MQTTResponse MQTTClient_connectAll(MQTTClient handle, MQTTClient_connectOptions* #if defined(OPENSSL) if (options->struct_version != 0 && options->ssl) /* check validity of SSL options structure */ { - if (strncmp(options->ssl->struct_id, "MQTS", 4) != 0 || options->ssl->struct_version < 0 || options->ssl->struct_version > 5) + if (strncmp(options->ssl->struct_id, "MQTS", 4) != 0 || options->ssl->struct_version < 0 || options->ssl->struct_version > 6) { rc.reasonCode = MQTTCLIENT_BAD_STRUCTURE; goto exit; diff --git a/src/MQTTClient.h b/src/MQTTClient.h index 8d6a32af..bc59d110 100644 --- a/src/MQTTClient.h +++ b/src/MQTTClient.h @@ -777,9 +777,24 @@ typedef struct * Exists only if struct_version >= 5 */ unsigned int protos_len; + + /** + * document + */ + const char* pemRootCerts; + + /** + * document + */ + const char* pemCertChain; + + /** + * document + */ + const char* pemPrivateKey; } MQTTClient_SSLOptions; -#define MQTTClient_SSLOptions_initializer { {'M', 'Q', 'T', 'S'}, 5, NULL, NULL, NULL, NULL, NULL, 1, MQTT_SSL_VERSION_DEFAULT, 0, NULL, NULL, NULL, NULL, NULL, 0, NULL, 0 } +#define MQTTClient_SSLOptions_initializer { {'M', 'Q', 'T', 'S'}, 6, NULL, NULL, NULL, NULL, NULL, 1, MQTT_SSL_VERSION_DEFAULT, 0, NULL, NULL, NULL, NULL, NULL, 0, NULL, 0, NULL, NULL, NULL } /** * MQTTClient_libraryInfo is used to store details relating to the currently used diff --git a/src/MQTTPacket.c b/src/MQTTPacket.c index 9d8f08dd..2ac8024f 100644 --- a/src/MQTTPacket.c +++ b/src/MQTTPacket.c @@ -361,10 +361,10 @@ int MQTTPacket_decode(networkHandles* net, size_t* value) */ int readInt(char** pptr) { - char* ptr = *pptr; - int len = 256*((unsigned char)(*ptr)) + (unsigned char)(*(ptr+1)); + char *ptr = *pptr; + int val = ((((uint32_t)ptr[0]) << 8) | ((uint32_t)ptr[1])); *pptr += 2; - return len; + return val; } @@ -452,10 +452,10 @@ void writeChar(char** pptr, char c) */ void writeInt(char** pptr, int anInt) { - **pptr = (char)(anInt / 256); - (*pptr)++; - **pptr = (char)(anInt % 256); - (*pptr)++; + char* ptr = *pptr; + ptr[0] = (uint8_t) ((anInt >> 8) & 0xFF); + ptr[1] = (uint8_t) (anInt & 0xFF); + *pptr += 2; } @@ -944,33 +944,28 @@ void MQTTPacket_free_packet(MQTTPacket* pack) */ void writeInt4(char** pptr, int anInt) { - **pptr = (char)(anInt / 16777216); - (*pptr)++; - anInt %= 16777216; - **pptr = (char)(anInt / 65536); - (*pptr)++; - anInt %= 65536; - **pptr = (char)(anInt / 256); - (*pptr)++; - **pptr = (char)(anInt % 256); - (*pptr)++; + unsigned char* ptr = (unsigned char*)*pptr; + ptr[0] = (uint8_t) ((anInt >> 24) & 0xFF); + ptr[1] = (uint8_t) ((anInt >> 16) & 0xFF); + ptr[2] = (uint8_t) ((anInt >> 8) & 0xFF); + ptr[3] = (uint8_t) (anInt & 0xFF); + *pptr += 4; } - /** * Calculates an integer from two bytes read from the input buffer - * @param pptr pointer to the input buffer - incremented by the number of bytes used & returned + * @param pptr pointer to the input buffer - incremented by the number of bytes + * used & returned * @return the integer value calculated */ int readInt4(char** pptr) { - unsigned char* ptr = (unsigned char*)*pptr; - int value = 16777216*(*ptr) + 65536*(*(ptr+1)) + 256*(*(ptr+2)) + (*(ptr+3)); + unsigned char *ptr = (unsigned char *)*pptr; + int val = ((((uint32_t)ptr[0]) << 24) | (((uint32_t)ptr[1]) << 16) | (((uint32_t)ptr[2]) << 8) | ((uint32_t)ptr[3])); *pptr += 4; - return value; + return val; } - void writeMQTTLenString(char** pptr, MQTTLenString lenstring) { writeInt(pptr, lenstring.len); diff --git a/src/MQTTProtocolClient.c b/src/MQTTProtocolClient.c index e4b8f16c..6bf07bd1 100644 --- a/src/MQTTProtocolClient.c +++ b/src/MQTTProtocolClient.c @@ -981,6 +981,15 @@ void MQTTProtocol_freeClient(Clients* client) if (client->sslopts->protos) free((void*)client->sslopts->protos); } + if (client->sslopts->struct_version >= 6) + { + if (client->sslopts->pemRootCerts) + free((void*)client->sslopts->pemRootCerts); + if (client->sslopts->pemCertChain) + free((void*)client->sslopts->pemCertChain); + if (client->sslopts->pemPrivateKey) + free((void*)client->sslopts->pemPrivateKey); + } free(client->sslopts); client->sslopts = NULL; } diff --git a/src/SSLSocket.c b/src/SSLSocket.c index fd80c727..d2dfae7b 100644 --- a/src/SSLSocket.c +++ b/src/SSLSocket.c @@ -122,6 +122,170 @@ static int SSLSocket_error(char* aString, SSL* ssl, SOCKET sock, int rc, int (*c return error; } +/* Loads an in-memory PEM certificate chain into the SSL context. */ +static int SSLSocket_use_certificate_chain(SSL_CTX* context, const char* pem_cert_chain, size_t pem_cert_chain_size) +{ + int rc = 0; + + X509* certificate; + BIO* pem; + + pem = BIO_new_mem_buf(pem_cert_chain, (int)pem_cert_chain_size); + if (!pem) + goto exit; + + certificate = PEM_read_bio_X509_AUX(pem, NULL, NULL, ""); + if (!certificate) + goto exit; + + + if (!SSL_CTX_use_certificate(context, certificate)) + goto exit; + + for (;;) + { + X509* certificate_authority = PEM_read_bio_X509(pem, NULL, NULL, ""); + if (!certificate_authority) + { + ERR_clear_error(); + rc = 1; + goto exit; + } + + if (!SSL_CTX_add_extra_chain_cert(context, certificate_authority)) + { + X509_free(certificate_authority); + goto exit; + } + } + +exit: + if (certificate) + X509_free(certificate); + BIO_free(pem); + + return rc; +} + +static int SSLSocket_use_pem_private_key(SSL_CTX* context, const char* pem_key, size_t pem_key_size) +{ + int rc = 0; + EVP_PKEY* private_key = NULL; + BIO* pem; + pem = BIO_new_mem_buf(pem_key, (int)pem_key_size); + + if (pem == NULL) + goto exit; + + private_key = PEM_read_bio_PrivateKey(pem, NULL, NULL, ""); + if (private_key == NULL) + goto exit; + if (!SSL_CTX_use_PrivateKey(context, private_key)) + goto exit; + + rc = 1; + +exit: + if (private_key != NULL) + EVP_PKEY_free(private_key); + + BIO_free(pem); + return rc; +} + + +/* Loads in-memory PEM verification certs into the SSL context and optionally + returns the verification cert names (root_names can be NULL). */ +static int SSLSocket_x509_store_load_certs(X509_STORE* cert_store, const char* pem_roots, size_t pem_roots_size, STACK_OF(X509_NAME) **root_names) { + int result = 1; + size_t num_roots = 0; + X509* root = NULL; + X509_NAME* root_name = NULL; + BIO* pem; + + pem = BIO_new_mem_buf(pem_roots, (int)pem_roots_size); + if (pem == NULL) + return 0; + + if (root_names != NULL) + { + *root_names = sk_X509_NAME_new_null(); + if (*root_names == NULL) + return 0; + } + + for (;;) + { + root = PEM_read_bio_X509_AUX(pem, NULL, NULL, ""); + if (root == NULL) + { + ERR_clear_error(); + break; /* We're at the end of stream. */ + } + if (root_names != NULL) + { + root_name = X509_get_subject_name(root); + if (root_name == NULL) + { + result = 0; + break; + } + + root_name = X509_NAME_dup(root_name); + if (root_name == NULL) + { + result = 0; + break; + } + sk_X509_NAME_push(*root_names, root_name); + root_name = NULL; + } + + ERR_clear_error(); + + if (!X509_STORE_add_cert(cert_store, root)) + { + unsigned long error = ERR_get_error(); + if (ERR_GET_LIB(error) != ERR_LIB_X509 || ERR_GET_REASON(error) != X509_R_CERT_ALREADY_IN_HASH_TABLE) + { + result = 0; + break; + } + } + X509_free(root); + num_roots++; + } + + if (num_roots == 0) + { + result = 0; + } + + if (result != 0) + { + if (root != NULL) + X509_free(root); + if (root_names != NULL) + { + sk_X509_NAME_pop_free(*root_names, X509_NAME_free); + *root_names = NULL; + if (root_name != NULL) + X509_NAME_free(root_name); + } + } + + BIO_free(pem); + return result; +} + +static int SSLSocket_load_verification_certs(SSL_CTX* context, const char* pem_roots, size_t pem_roots_size, STACK_OF(X509_NAME) * *root_name) +{ + X509_STORE* cert_store = SSL_CTX_get_cert_store(context); + X509_STORE_set_flags(cert_store, X509_V_FLAG_PARTIAL_CHAIN | X509_V_FLAG_TRUSTED_FIRST); + + return SSLSocket_x509_store_load_certs(cert_store, pem_roots, pem_roots_size, root_name); +} + static struct { int code; @@ -597,6 +761,45 @@ int SSLSocket_createContext(networkHandles* net, MQTTClient_SSLOptions* opts) SSL_CTX_set_security_level(net->ctx, 1); #endif + if (opts->pemRootCerts) + { + + STACK_OF(X509_NAME)* root_names = NULL; + if ((rc = SSLSocket_load_verification_certs(net->ctx, opts->pemRootCerts, strlen(opts->pemRootCerts), &root_names)) != 1) + { + if (opts->struct_version >= 3) + SSLSocket_error("SSLSocket_load_verification_certs", NULL, net->socket, rc, opts->ssl_error_cb, opts->ssl_error_context); + else + SSLSocket_error("SSLSocket_load_verification_certs", NULL, net->socket, rc, NULL, NULL); + goto exit; + } + SSL_CTX_set_client_CA_list(net->ctx, root_names); + } + + if (opts->pemCertChain) + { + if ((rc = SSLSocket_use_certificate_chain(net->ctx, opts->pemCertChain, strlen(opts->pemCertChain))) != 1) + { + if (opts->struct_version >= 3) + SSLSocket_error("SSLSocket_use_certificate_chain", NULL, net->socket, rc, opts->ssl_error_cb, opts->ssl_error_context); + else + SSLSocket_error("SSLSocket_use_certificate_chain", NULL, net->socket, rc, NULL, NULL); + goto exit; + } + } + + if (opts->pemPrivateKey) + { + if ((rc = SSLSocket_use_pem_private_key(net->ctx, opts->pemPrivateKey, strlen(opts->pemPrivateKey))) != 1 || !SSL_CTX_check_private_key(net->ctx)) + { + if (opts->struct_version >= 3) + SSLSocket_error("SSLSocket_use_pem_private_key", NULL, net->socket, rc, opts->ssl_error_cb, opts->ssl_error_context); + else + SSLSocket_error("SSLSocket_use_pem_private_key", NULL, net->socket, rc, NULL, NULL); + goto exit; + } + } + if (opts->keyStore) { if ((rc = SSL_CTX_use_certificate_chain_file(net->ctx, opts->keyStore)) != 1)