Skip to content

Commit

Permalink
[secure-transport] simplify Process() method (openthread#10933)
Browse files Browse the repository at this point in the history
This commit simplifies `SecureTransport::Process()`:
- The `mReceiveCallback` is invoked only after a successful `read()`
  Afterward, we `continue` back through the loop to check if it can
  read again.
- The `rval` checks are now combined into a single `switch()`
  statement to determine if the connection should be disconnected,
  reset, or if the process should wait, and then takes the proper
  action.
  • Loading branch information
abtink authored Nov 19, 2024
1 parent 013bb3e commit 8d39758
Show file tree
Hide file tree
Showing 2 changed files with 60 additions and 55 deletions.
114 changes: 59 additions & 55 deletions src/core/meshcop/secure_transport.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1004,10 +1004,10 @@ void SecureTransport::HandleTimer(void)

void SecureTransport::Process(void)
{
uint8_t buf[OPENTHREAD_CONFIG_DTLS_MAX_CONTENT_LEN];
bool shouldDisconnect = false;
uint8_t buf[kMaxContentLen];
int rval;
ConnectEvent event;
ConnectEvent disconnectEvent;
bool shouldReset;

while (IsStateConnectingOrConnected())
{
Expand All @@ -1025,72 +1025,76 @@ void SecureTransport::Process(void)
else
{
rval = mbedtls_ssl_read(&mSsl, buf, sizeof(buf));
}

if (rval > 0)
{
mReceiveCallback.InvokeIfSet(buf, static_cast<uint16_t>(rval));
if (rval > 0)
{
mReceiveCallback.InvokeIfSet(buf, static_cast<uint16_t>(rval));
continue;
}
}
else if (rval == 0 || rval == MBEDTLS_ERR_SSL_WANT_READ || rval == MBEDTLS_ERR_SSL_WANT_WRITE)

// Check `rval` to determine if the connection should be
// disconnected, reset, or if we should wait.

disconnectEvent = kConnected;
shouldReset = true;

switch (rval)
{
case 0:
case MBEDTLS_ERR_SSL_WANT_READ:
case MBEDTLS_ERR_SSL_WANT_WRITE:
shouldReset = false;
break;
}
else
{
switch (rval)

case MBEDTLS_ERR_SSL_PEER_CLOSE_NOTIFY:
mbedtls_ssl_close_notify(&mSsl);
disconnectEvent = kDisconnectedPeerClosed;
break;

case MBEDTLS_ERR_SSL_HELLO_VERIFY_REQUIRED:
break;

case MBEDTLS_ERR_SSL_FATAL_ALERT_MESSAGE:
mbedtls_ssl_close_notify(&mSsl);
disconnectEvent = kDisconnectedError;
break;

case MBEDTLS_ERR_SSL_INVALID_MAC:
if (mSsl.MBEDTLS_PRIVATE(state) != MBEDTLS_SSL_HANDSHAKE_OVER)
{
case MBEDTLS_ERR_SSL_PEER_CLOSE_NOTIFY:
mbedtls_ssl_close_notify(&mSsl);
event = kDisconnectedPeerClosed;
ExitNow(shouldDisconnect = true);
OT_UNREACHABLE_CODE(break);

case MBEDTLS_ERR_SSL_HELLO_VERIFY_REQUIRED:
break;

case MBEDTLS_ERR_SSL_FATAL_ALERT_MESSAGE:
mbedtls_ssl_close_notify(&mSsl);
event = kDisconnectedError;
ExitNow(shouldDisconnect = true);
OT_UNREACHABLE_CODE(break);

case MBEDTLS_ERR_SSL_INVALID_MAC:
if (mSsl.MBEDTLS_PRIVATE(state) != MBEDTLS_SSL_HANDSHAKE_OVER)
{
mbedtls_ssl_send_alert_message(&mSsl, MBEDTLS_SSL_ALERT_LEVEL_FATAL,
MBEDTLS_SSL_ALERT_MSG_BAD_RECORD_MAC);
event = kDisconnectedError;
ExitNow(shouldDisconnect = true);
}

break;

default:
if (mSsl.MBEDTLS_PRIVATE(state) != MBEDTLS_SSL_HANDSHAKE_OVER)
{
mbedtls_ssl_send_alert_message(&mSsl, MBEDTLS_SSL_ALERT_LEVEL_FATAL,
MBEDTLS_SSL_ALERT_MSG_HANDSHAKE_FAILURE);
event = kDisconnectedError;
ExitNow(shouldDisconnect = true);
}

break;
mbedtls_ssl_send_alert_message(&mSsl, MBEDTLS_SSL_ALERT_LEVEL_FATAL,
MBEDTLS_SSL_ALERT_MSG_BAD_RECORD_MAC);
disconnectEvent = kDisconnectedError;
}
break;

default:
if (mSsl.MBEDTLS_PRIVATE(state) != MBEDTLS_SSL_HANDSHAKE_OVER)
{
mbedtls_ssl_send_alert_message(&mSsl, MBEDTLS_SSL_ALERT_LEVEL_FATAL,
MBEDTLS_SSL_ALERT_MSG_HANDSHAKE_FAILURE);
disconnectEvent = kDisconnectedError;
}

break;
}

if (disconnectEvent != kConnected)
{
Disconnect(disconnectEvent);
}
else if (shouldReset)
{
mbedtls_ssl_session_reset(&mSsl);

if (mCipherSuite == kEcjpakeWithAes128Ccm8)
{
mbedtls_ssl_set_hs_ecjpake_password(&mSsl, mPsk, mPskLength);
}
break;
}
}

exit:

if (shouldDisconnect)
{
Disconnect(event);
break; // from `while()` loop
}
}

Expand Down
1 change: 1 addition & 0 deletions src/core/meshcop/secure_transport.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -441,6 +441,7 @@ class SecureTransport : public InstanceLocator
void HandleReceive(Message &aMessage, const Ip6::MessageInfo &aMessageInfo);

private:
static constexpr uint16_t kMaxContentLen = OPENTHREAD_CONFIG_DTLS_MAX_CONTENT_LEN;
static constexpr uint32_t kGuardTimeNewConnectionMilli = 2000;
static constexpr size_t kSecureTransportKeyBlockSize = 40;
static constexpr size_t kSecureTransportRandomBufferSize = 32;
Expand Down

0 comments on commit 8d39758

Please sign in to comment.