Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Offload TLS negotiation to I/O threads #1338

Merged
merged 7 commits into from
Dec 18, 2024
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions src/connection.h
Original file line number Diff line number Diff line change
Expand Up @@ -54,8 +54,9 @@ typedef enum {
CONN_STATE_ERROR
} ConnectionState;

#define CONN_FLAG_CLOSE_SCHEDULED (1 << 0) /* Closed scheduled by a handler */
#define CONN_FLAG_WRITE_BARRIER (1 << 1) /* Write barrier requested */
#define CONN_FLAG_CLOSE_SCHEDULED (1 << 0) /* Closed scheduled by a handler */
#define CONN_FLAG_WRITE_BARRIER (1 << 1) /* Write barrier requested */
#define CONN_FLAG_ALLOW_ACCEPT_OFFLOAD (1 << 2) /* Connection accept can be offloaded to IO threads. */

#define CONN_TYPE_SOCKET "tcp"
#define CONN_TYPE_UNIX "unix"
Expand Down
52 changes: 52 additions & 0 deletions src/io_threads.c
Original file line number Diff line number Diff line change
Expand Up @@ -554,3 +554,55 @@ void trySendPollJobToIOThreads(void) {
aeSetPollProtect(server.el, 1);
IOJobQueue_push(jq, IOThreadPoll, server.el);
}

static void ioThreadAccept(void *data) {
client *c = (client *)data;
connAccept(c->conn, NULL);
c->io_read_state = CLIENT_COMPLETED_IO;
}

/*
* Attempts to offload an Accept operation (currently used for TLS accept) for a client
* connection to I/O threads.
*
* Returns:
* C_OK - If the accept operation was successfully queued for processing
* C_ERR - If the connection is not eligible for offloading
*
* Parameters:
* conn - The connection object to perform the accept operation on
*/
int trySendAcceptToIOThreads(connection *conn) {
if (server.io_threads_num <= 1) {
return C_ERR;
}

if (!(conn->flags & CONN_FLAG_ALLOW_ACCEPT_OFFLOAD)) {
return C_ERR;
}

client *c = connGetPrivateData(conn);
if (c->io_read_state != CLIENT_IDLE) {
return C_OK;
}

if (server.active_io_threads_num <= 1) {
return C_ERR;
}

size_t thread_id = (c->id % (server.active_io_threads_num - 1)) + 1;
IOJobQueue *job_queue = &io_jobs[thread_id];

if (IOJobQueue_isFull(job_queue)) {
return C_ERR;
}

c->io_read_state = CLIENT_PENDING_IO;
c->flag.pending_read = 1;
listLinkNodeTail(server.clients_pending_io_read, &c->pending_read_list_node);
connSetPostponeUpdateState(c->conn, 1);
server.stat_io_accept_offloaded++;
IOJobQueue_push(job_queue, ioThreadAccept, c);

return C_OK;
}
1 change: 1 addition & 0 deletions src/io_threads.h
Original file line number Diff line number Diff line change
Expand Up @@ -13,5 +13,6 @@ int tryOffloadFreeArgvToIOThreads(client *c);
void adjustIOThreadsByEventLoad(int numevents, int increase_only);
void drainIOThreadsQueue(void);
void trySendPollJobToIOThreads(void);
int trySendAcceptToIOThreads(connection *conn);

#endif /* IO_THREADS_H */
6 changes: 6 additions & 0 deletions src/networking.c
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,7 @@ client *createClient(connection *conn) {
if (server.tcpkeepalive) connKeepAlive(conn, server.tcpkeepalive);
connSetReadHandler(conn, readQueryFromClient);
connSetPrivateData(conn, c);
conn->flags |= CONN_FLAG_ALLOW_ACCEPT_OFFLOAD;
}
c->buf = zmalloc_usable(PROTO_REPLY_CHUNK_BYTES, &c->buf_usable_size);
selectDb(c, 0);
Expand Down Expand Up @@ -4722,9 +4723,14 @@ int processIOThreadsReadDone(void) {
processed++;
server.stat_io_reads_processed++;

/* Save the current conn state, as connUpdateState may modify it */
int in_accept_state = (connGetState(c->conn) == CONN_STATE_ACCEPTING);
connSetPostponeUpdateState(c->conn, 0);
connUpdateState(c->conn);

/* In accept state, no client's data was read - stop here. */
if (in_accept_state) continue;

/* On read error - stop here. */
if (handleReadResult(c) == C_ERR) {
continue;
Expand Down
2 changes: 2 additions & 0 deletions src/server.c
Original file line number Diff line number Diff line change
Expand Up @@ -2604,6 +2604,7 @@ void resetServerStats(void) {
server.stat_total_reads_processed = 0;
server.stat_io_writes_processed = 0;
server.stat_io_freed_objects = 0;
server.stat_io_accept_offloaded = 0;
server.stat_poll_processed_by_io_threads = 0;
server.stat_total_writes_processed = 0;
server.stat_client_qbuf_limit_disconnections = 0;
Expand Down Expand Up @@ -5862,6 +5863,7 @@ sds genValkeyInfoString(dict *section_dict, int all_sections, int everything) {
"io_threaded_reads_processed:%lld\r\n", server.stat_io_reads_processed,
"io_threaded_writes_processed:%lld\r\n", server.stat_io_writes_processed,
"io_threaded_freed_objects:%lld\r\n", server.stat_io_freed_objects,
"io_threaded_accept:%lld\r\n", server.stat_io_accept_offloaded,
ranshid marked this conversation as resolved.
Show resolved Hide resolved
"io_threaded_poll_processed:%lld\r\n", server.stat_poll_processed_by_io_threads,
"io_threaded_total_prefetch_batches:%lld\r\n", server.stat_total_prefetch_batches,
"io_threaded_total_prefetch_entries:%lld\r\n", server.stat_total_prefetch_entries,
Expand Down
1 change: 1 addition & 0 deletions src/server.h
Original file line number Diff line number Diff line change
Expand Up @@ -1841,6 +1841,7 @@ struct valkeyServer {
long long stat_io_reads_processed; /* Number of read events processed by IO threads */
long long stat_io_writes_processed; /* Number of write events processed by IO threads */
long long stat_io_freed_objects; /* Number of objects freed by IO threads */
long long stat_io_accept_offloaded; /* Number of offloaded accepts */
long long stat_poll_processed_by_io_threads; /* Total number of poll jobs processed by IO */
long long stat_total_reads_processed; /* Total number of read events processed */
long long stat_total_writes_processed; /* Total number of write events processed */
Expand Down
138 changes: 70 additions & 68 deletions src/tls.c
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
#include "server.h"
#include "connhelpers.h"
#include "adlist.h"
#include "io_threads.h"

#if (USE_OPENSSL == 1 /* BUILD_YES */) || ((USE_OPENSSL == 2 /* BUILD_MODULE */) && (BUILD_TLS_MODULE == 2))

Expand Down Expand Up @@ -437,16 +438,13 @@ static ConnectionType CT_TLS;
*
*/

typedef enum {
WANT_READ = 1,
WANT_WRITE
} WantIOType;

#define TLS_CONN_FLAG_READ_WANT_WRITE (1 << 0)
#define TLS_CONN_FLAG_WRITE_WANT_READ (1 << 1)
#define TLS_CONN_FLAG_FD_SET (1 << 2)
#define TLS_CONN_FLAG_POSTPONE_UPDATE_STATE (1 << 3)
#define TLS_CONN_FLAG_HAS_PENDING (1 << 4)
#define TLS_CONN_FLAG_ACCEPT_ERROR (1 << 5)
#define TLS_CONN_FLAG_ACCEPT_SUCCESS (1 << 6)

typedef struct tls_connection {
connection c;
Expand Down Expand Up @@ -514,20 +512,26 @@ static connection *connCreateAcceptedTLS(int fd, void *priv) {
return (connection *)conn;
}

static int connTLSAccept(connection *_conn, ConnectionCallbackFunc accept_handler);
static void tlsEventHandler(struct aeEventLoop *el, int fd, void *clientData, int mask);
static void updateSSLEvent(tls_connection *conn);

static void clearTLSWantFlags(tls_connection *conn) {
conn->flags &= ~(TLS_CONN_FLAG_WRITE_WANT_READ | TLS_CONN_FLAG_READ_WANT_WRITE);
}

/* Process the return code received from OpenSSL>
* Update the want parameter with expected I/O.
* Update the conn flags with the WANT_READ/WANT_WRITE flags.
* Update the connection's error state if a real error has occurred.
* Returns an SSL error code, or 0 if no further handling is required.
*/
static int handleSSLReturnCode(tls_connection *conn, int ret_value, WantIOType *want) {
static int handleSSLReturnCode(tls_connection *conn, int ret_value) {
clearTLSWantFlags(conn);
if (ret_value <= 0) {
int ssl_err = SSL_get_error(conn->ssl, ret_value);
switch (ssl_err) {
case SSL_ERROR_WANT_WRITE: *want = WANT_WRITE; return 0;
case SSL_ERROR_WANT_READ: *want = WANT_READ; return 0;
case SSL_ERROR_WANT_WRITE: conn->flags |= TLS_CONN_FLAG_READ_WANT_WRITE; return 0;
case SSL_ERROR_WANT_READ: conn->flags |= TLS_CONN_FLAG_WRITE_WANT_READ; return 0;
case SSL_ERROR_SYSCALL:
conn->c.last_errno = errno;
if (conn->ssl_error) zfree(conn->ssl_error);
Expand Down Expand Up @@ -563,11 +567,8 @@ static int updateStateAfterSSLIO(tls_connection *conn, int ret_value, int update
}

if (ret_value <= 0) {
WantIOType want = 0;
int ssl_err;
if (!(ssl_err = handleSSLReturnCode(conn, ret_value, &want))) {
if (want == WANT_READ) conn->flags |= TLS_CONN_FLAG_WRITE_WANT_READ;
if (want == WANT_WRITE) conn->flags |= TLS_CONN_FLAG_READ_WANT_WRITE;
if (!(ssl_err = handleSSLReturnCode(conn, ret_value))) {
if (update_event) updateSSLEvent(conn);
errno = EAGAIN;
return -1;
Expand All @@ -585,19 +586,17 @@ static int updateStateAfterSSLIO(tls_connection *conn, int ret_value, int update
return ret_value;
}

static void registerSSLEvent(tls_connection *conn, WantIOType want) {
static void registerSSLEvent(tls_connection *conn) {
int mask = aeGetFileEvents(server.el, conn->c.fd);

switch (want) {
case WANT_READ:
if (conn->flags & TLS_CONN_FLAG_WRITE_WANT_READ) {
if (mask & AE_WRITABLE) aeDeleteFileEvent(server.el, conn->c.fd, AE_WRITABLE);
if (!(mask & AE_READABLE)) aeCreateFileEvent(server.el, conn->c.fd, AE_READABLE, tlsEventHandler, conn);
break;
case WANT_WRITE:
} else if (conn->flags & TLS_CONN_FLAG_READ_WANT_WRITE) {
if (mask & AE_READABLE) aeDeleteFileEvent(server.el, conn->c.fd, AE_READABLE);
if (!(mask & AE_WRITABLE)) aeCreateFileEvent(server.el, conn->c.fd, AE_WRITABLE, tlsEventHandler, conn);
break;
default: serverAssert(0); break;
} else {
serverAssert(0);
}
}

Expand Down Expand Up @@ -650,12 +649,46 @@ static void updateSSLEvent(tls_connection *conn) {
if (!need_write && (mask & AE_WRITABLE)) aeDeleteFileEvent(server.el, conn->c.fd, AE_WRITABLE);
}

static int TLSHandleAcceptResult(tls_connection *conn, int call_handler_on_error) {
uriyage marked this conversation as resolved.
Show resolved Hide resolved
if (conn->flags & TLS_CONN_FLAG_ACCEPT_SUCCESS) {
conn->c.state = CONN_STATE_CONNECTED;
} else if (conn->flags & TLS_CONN_FLAG_ACCEPT_ERROR) {
conn->c.state = CONN_STATE_ERROR;
if (!call_handler_on_error) return C_ERR;
} else {
/* Still pending accept */
registerSSLEvent(conn);
return C_OK;
}

/* call accept handler */
if (!callHandler((connection *)conn, conn->c.conn_handler)) return C_ERR;
conn->c.conn_handler = NULL;
return C_OK;
}

static void updateSSLState(connection *conn_) {
tls_connection *conn = (tls_connection *)conn_;

if (conn->c.state == CONN_STATE_ACCEPTING) {
if (TLSHandleAcceptResult(conn, 1) == C_ERR || conn->c.state != CONN_STATE_CONNECTED) return;
}

updateSSLEvent(conn);
updatePendingData(conn);
}

static void TLSAccept(void *_conn) {
tls_connection *conn = (tls_connection *)_conn;
ERR_clear_error();
int ret = SSL_accept(conn->ssl);
if (ret > 0) {
conn->flags |= TLS_CONN_FLAG_ACCEPT_SUCCESS;
} else if (handleSSLReturnCode(conn, ret)) {
conn->flags |= TLS_CONN_FLAG_ACCEPT_ERROR;
}
}

static void tlsHandleEvent(tls_connection *conn, int mask) {
int ret, conn_error;

Expand All @@ -676,10 +709,8 @@ static void tlsHandleEvent(tls_connection *conn, int mask) {
}
ret = SSL_connect(conn->ssl);
if (ret <= 0) {
WantIOType want = 0;
if (!handleSSLReturnCode(conn, ret, &want)) {
registerSSLEvent(conn, want);

if (!handleSSLReturnCode(conn, ret)) {
registerSSLEvent(conn);
/* Avoid hitting UpdateSSLEvent, which knows nothing
* of what SSL_connect() wants and instead looks at our
* R/W handlers.
Expand All @@ -698,27 +729,7 @@ static void tlsHandleEvent(tls_connection *conn, int mask) {
conn->c.conn_handler = NULL;
break;
case CONN_STATE_ACCEPTING:
ERR_clear_error();
ret = SSL_accept(conn->ssl);
if (ret <= 0) {
WantIOType want = 0;
if (!handleSSLReturnCode(conn, ret, &want)) {
/* Avoid hitting UpdateSSLEvent, which knows nothing
* of what SSL_connect() wants and instead looks at our
* R/W handlers.
*/
registerSSLEvent(conn, want);
return;
}

/* If not handled, it's an error */
conn->c.state = CONN_STATE_ERROR;
} else {
conn->c.state = CONN_STATE_CONNECTED;
}

if (!callHandler((connection *)conn, conn->c.conn_handler)) return;
conn->c.conn_handler = NULL;
if (connTLSAccept((connection *)conn, NULL) == C_ERR || conn->c.state != CONN_STATE_CONNECTED) return;
break;
case CONN_STATE_CONNECTED: {
int call_read = ((mask & AE_READABLE) && conn->c.read_handler) ||
Expand All @@ -740,20 +751,17 @@ static void tlsHandleEvent(tls_connection *conn, int mask) {
int invert = conn->c.flags & CONN_FLAG_WRITE_BARRIER;

if (!invert && call_read) {
conn->flags &= ~TLS_CONN_FLAG_READ_WANT_WRITE;
if (!callHandler((connection *)conn, conn->c.read_handler)) return;
}

/* Fire the writable event. */
if (call_write) {
conn->flags &= ~TLS_CONN_FLAG_WRITE_WANT_READ;
if (!callHandler((connection *)conn, conn->c.write_handler)) return;
}

/* If we have to invert the call, fire the readable event now
* after the writable one. */
if (invert && call_read) {
conn->flags &= ~TLS_CONN_FLAG_READ_WANT_WRITE;
if (!callHandler((connection *)conn, conn->c.read_handler)) return;
}
updatePendingData(conn);
Expand Down Expand Up @@ -841,31 +849,25 @@ static void connTLSClose(connection *conn_) {

static int connTLSAccept(connection *_conn, ConnectionCallbackFunc accept_handler) {
tls_connection *conn = (tls_connection *)_conn;
int ret;

if (conn->c.state != CONN_STATE_ACCEPTING) return C_ERR;
ERR_clear_error();

int call_handler_on_error = 1;
/* Try to accept */
conn->c.conn_handler = accept_handler;
ret = SSL_accept(conn->ssl);

if (ret <= 0) {
WantIOType want = 0;
if (!handleSSLReturnCode(conn, ret, &want)) {
registerSSLEvent(conn, want); /* We'll fire back */
return C_OK;
} else {
conn->c.state = CONN_STATE_ERROR;
return C_ERR;
}
if (accept_handler) {
conn->c.conn_handler = accept_handler;
call_handler_on_error = 0;
}

conn->c.state = CONN_STATE_CONNECTED;
if (!callHandler((connection *)conn, conn->c.conn_handler)) return C_OK;
conn->c.conn_handler = NULL;
/* We're in IO thread - just call accept and return, the main thread will handle the rest */
if (!inMainThread()) {
TLSAccept(conn);
return C_OK;
}

return C_OK;
/* Try to offload accept to IO threads */
if (trySendAcceptToIOThreads(_conn) == C_OK) return C_OK;

TLSAccept(conn);
return TLSHandleAcceptResult(conn, call_handler_on_error);
}

static int connTLSConnect(connection *conn_,
Expand Down
Loading