Skip to content

Commit

Permalink
plumb S2AStub close to handshake end + add integration test.
Browse files Browse the repository at this point in the history
  • Loading branch information
rmehta19 committed Oct 10, 2024
1 parent d9d4317 commit 1698b54
Show file tree
Hide file tree
Showing 10 changed files with 130 additions and 64 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
import io.netty.channel.ChannelHandler;
import io.netty.handler.ssl.SslContext;
import io.netty.util.AsciiString;
import java.util.Optional;
import java.util.concurrent.Executor;

/**
Expand All @@ -40,9 +41,10 @@ private InternalProtocolNegotiators() {}
* @param executorPool a dedicated {@link Executor} pool for time-consuming TLS tasks
*/
public static InternalProtocolNegotiator.ProtocolNegotiator tls(SslContext sslContext,
ObjectPool<? extends Executor> executorPool) {
ObjectPool<? extends Executor> executorPool,
Optional<Runnable> handshakeCompleteRunnable) {
final io.grpc.netty.ProtocolNegotiator negotiator = ProtocolNegotiators.tls(sslContext,
executorPool);
executorPool, handshakeCompleteRunnable);
final class TlsNegotiator implements InternalProtocolNegotiator.ProtocolNegotiator {

@Override
Expand Down Expand Up @@ -70,7 +72,7 @@ public void close() {
* may happen immediately, even before the TLS Handshake is complete.
*/
public static InternalProtocolNegotiator.ProtocolNegotiator tls(SslContext sslContext) {
return tls(sslContext, null);
return tls(sslContext, null, Optional.empty());
}

/**
Expand Down Expand Up @@ -167,7 +169,8 @@ public static ChannelHandler grpcNegotiationHandler(GrpcHttp2ConnectionHandler n
public static ChannelHandler clientTlsHandler(
ChannelHandler next, SslContext sslContext, String authority,
ChannelLogger negotiationLogger) {
return new ClientTlsHandler(next, sslContext, authority, null, negotiationLogger);
return new ClientTlsHandler(next, sslContext, authority, null, negotiationLogger,
Optional.empty());

Check warning on line 173 in netty/src/main/java/io/grpc/netty/InternalProtocolNegotiators.java

View check run for this annotation

Codecov / codecov/patch

netty/src/main/java/io/grpc/netty/InternalProtocolNegotiators.java#L172-L173

Added lines #L172 - L173 were not covered by tests
}

public static class ProtocolNegotiationHandler
Expand Down
3 changes: 2 additions & 1 deletion netty/src/main/java/io/grpc/netty/NettyChannelBuilder.java
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@
import java.util.Collections;
import java.util.HashMap;
import java.util.Map;
import java.util.Optional;
import java.util.concurrent.Executor;
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.TimeUnit;
Expand Down Expand Up @@ -604,7 +605,7 @@ static ProtocolNegotiator createProtocolNegotiatorByType(
case PLAINTEXT_UPGRADE:
return ProtocolNegotiators.plaintextUpgrade();
case TLS:
return ProtocolNegotiators.tls(sslContext, executorPool);
return ProtocolNegotiators.tls(sslContext, executorPool, Optional.empty());
default:
throw new IllegalArgumentException("Unsupported negotiationType: " + negotiationType);
}
Expand Down
24 changes: 18 additions & 6 deletions netty/src/main/java/io/grpc/netty/ProtocolNegotiators.java
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@
import java.nio.channels.ClosedChannelException;
import java.util.Arrays;
import java.util.EnumSet;
import java.util.Optional;
import java.util.Set;
import java.util.concurrent.Executor;
import java.util.logging.Level;
Expand Down Expand Up @@ -543,16 +544,18 @@ protected void userEventTriggered0(ChannelHandlerContext ctx, Object evt) throws
static final class ClientTlsProtocolNegotiator implements ProtocolNegotiator {

public ClientTlsProtocolNegotiator(SslContext sslContext,
ObjectPool<? extends Executor> executorPool) {
ObjectPool<? extends Executor> executorPool, Optional<Runnable> handshakeCompleteRunnable) {
this.sslContext = checkNotNull(sslContext, "sslContext");
this.executorPool = executorPool;
if (this.executorPool != null) {
this.executor = this.executorPool.getObject();
}
this.handshakeCompleteRunnable = handshakeCompleteRunnable;
}

private final SslContext sslContext;
private final ObjectPool<? extends Executor> executorPool;
private final Optional<Runnable> handshakeCompleteRunnable;
private Executor executor;

@Override
Expand All @@ -565,7 +568,7 @@ public ChannelHandler newHandler(GrpcHttp2ConnectionHandler grpcHandler) {
ChannelHandler gnh = new GrpcNegotiationHandler(grpcHandler);
ChannelLogger negotiationLogger = grpcHandler.getNegotiationLogger();
ChannelHandler cth = new ClientTlsHandler(gnh, sslContext, grpcHandler.getAuthority(),
this.executor, negotiationLogger);
this.executor, negotiationLogger, handshakeCompleteRunnable);
return new WaitUntilActiveHandler(cth, negotiationLogger);
}

Expand All @@ -583,15 +586,18 @@ static final class ClientTlsHandler extends ProtocolNegotiationHandler {
private final String host;
private final int port;
private Executor executor;
private final Optional<Runnable> handshakeCompleteRunnable;

ClientTlsHandler(ChannelHandler next, SslContext sslContext, String authority,
Executor executor, ChannelLogger negotiationLogger) {
Executor executor, ChannelLogger negotiationLogger,
Optional<Runnable> handshakeCompleteRunnable) {
super(next, negotiationLogger);
this.sslContext = checkNotNull(sslContext, "sslContext");
HostPort hostPort = parseAuthority(authority);
this.host = hostPort.host;
this.port = hostPort.port;
this.executor = executor;
this.handshakeCompleteRunnable = handshakeCompleteRunnable;
}

@Override
Expand Down Expand Up @@ -634,6 +640,9 @@ protected void userEventTriggered0(ChannelHandlerContext ctx, Object evt) throws
.withCause(t)
.asRuntimeException();
}
if (handshakeCompleteRunnable.isPresent()) {
handshakeCompleteRunnable.get().run();

Check warning on line 644 in netty/src/main/java/io/grpc/netty/ProtocolNegotiators.java

View check run for this annotation

Codecov / codecov/patch

netty/src/main/java/io/grpc/netty/ProtocolNegotiators.java#L644

Added line #L644 was not covered by tests
}
ctx.fireExceptionCaught(t);
}
} else {
Expand All @@ -649,6 +658,9 @@ private void propagateTlsComplete(ChannelHandlerContext ctx, SSLSession session)
.set(Grpc.TRANSPORT_ATTR_SSL_SESSION, session)
.build();
replaceProtocolNegotiationEvent(existingPne.withAttributes(attrs).withSecurity(security));
if (handshakeCompleteRunnable.isPresent()) {
handshakeCompleteRunnable.get().run();

Check warning on line 662 in netty/src/main/java/io/grpc/netty/ProtocolNegotiators.java

View check run for this annotation

Codecov / codecov/patch

netty/src/main/java/io/grpc/netty/ProtocolNegotiators.java#L662

Added line #L662 was not covered by tests
}
fireProtocolNegotiationEvent(ctx);
}
}
Expand Down Expand Up @@ -683,8 +695,8 @@ static HostPort parseAuthority(String authority) {
* @param executorPool a dedicated {@link Executor} pool for time-consuming TLS tasks
*/
public static ProtocolNegotiator tls(SslContext sslContext,
ObjectPool<? extends Executor> executorPool) {
return new ClientTlsProtocolNegotiator(sslContext, executorPool);
ObjectPool<? extends Executor> executorPool, Optional<Runnable> handshakeCompleteRunnable) {
return new ClientTlsProtocolNegotiator(sslContext, executorPool, handshakeCompleteRunnable);
}

/**
Expand All @@ -693,7 +705,7 @@ public static ProtocolNegotiator tls(SslContext sslContext,
* may happen immediately, even before the TLS Handshake is complete.
*/
public static ProtocolNegotiator tls(SslContext sslContext) {
return tls(sslContext, null);
return tls(sslContext, null, Optional.empty());
}

public static ProtocolNegotiator.ClientFactory tlsClientFactory(SslContext sslContext) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,7 @@
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.TimeUnit;
Expand Down Expand Up @@ -766,7 +767,8 @@ public void tlsNegotiationServerExecutorShouldSucceed() throws Exception {
.trustManager(caCert)
.keyManager(clientCert, clientKey)
.build();
ProtocolNegotiator negotiator = ProtocolNegotiators.tls(clientContext, clientExecutorPool);
ProtocolNegotiator negotiator = ProtocolNegotiators.tls(clientContext, clientExecutorPool,
Optional.empty());
// after starting the client, the Executor in the client pool should be used
assertEquals(true, clientExecutorPool.isInUse());
final NettyClientTransport transport = newTransport(negotiator);
Expand Down
12 changes: 7 additions & 5 deletions netty/src/test/java/io/grpc/netty/ProtocolNegotiatorsTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,7 @@
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Optional;
import java.util.Queue;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;
Expand Down Expand Up @@ -876,7 +877,7 @@ public String applicationProtocol() {
DefaultEventLoopGroup elg = new DefaultEventLoopGroup(1);

ClientTlsHandler handler = new ClientTlsHandler(grpcHandler, sslContext,
"authority", elg, noopLogger);
"authority", elg, noopLogger, Optional.empty());
pipeline.addLast(handler);
pipeline.replace(SslHandler.class, null, goodSslHandler);
pipeline.fireUserEventTriggered(ProtocolNegotiationEvent.DEFAULT);
Expand Down Expand Up @@ -914,7 +915,7 @@ public String applicationProtocol() {
.applicationProtocolConfig(apn).build();

ClientTlsHandler handler = new ClientTlsHandler(grpcHandler, sslContext,
"authority", elg, noopLogger);
"authority", elg, noopLogger, Optional.empty());
pipeline.addLast(handler);
pipeline.replace(SslHandler.class, null, goodSslHandler);
pipeline.fireUserEventTriggered(ProtocolNegotiationEvent.DEFAULT);
Expand All @@ -938,7 +939,7 @@ public String applicationProtocol() {
DefaultEventLoopGroup elg = new DefaultEventLoopGroup(1);

ClientTlsHandler handler = new ClientTlsHandler(grpcHandler, sslContext,
"authority", elg, noopLogger);
"authority", elg, noopLogger, Optional.empty());
pipeline.addLast(handler);

final AtomicReference<Throwable> error = new AtomicReference<>();
Expand Down Expand Up @@ -966,7 +967,7 @@ public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) {
@Test
public void clientTlsHandler_closeDuringNegotiation() throws Exception {
ClientTlsHandler handler = new ClientTlsHandler(grpcHandler, sslContext,
"authority", null, noopLogger);
"authority", null, noopLogger, Optional.empty());
pipeline.addLast(new WriteBufferingAndExceptionHandler(handler));
ChannelFuture pendingWrite = channel.writeAndFlush(NettyClientHandler.NOOP_MESSAGE);

Expand Down Expand Up @@ -1228,7 +1229,8 @@ public void clientTlsHandler_firesNegotiation() throws Exception {
serverSslContext = GrpcSslContexts.forServer(server1Chain, server1Key).build();
}
FakeGrpcHttp2ConnectionHandler gh = FakeGrpcHttp2ConnectionHandler.newHandler();
ClientTlsProtocolNegotiator pn = new ClientTlsProtocolNegotiator(clientSslContext, null);
ClientTlsProtocolNegotiator pn = new ClientTlsProtocolNegotiator(clientSslContext,
null, Optional.empty());
WriteBufferingAndExceptionHandler clientWbaeh =
new WriteBufferingAndExceptionHandler(pn.newHandler(gh));

Expand Down
14 changes: 13 additions & 1 deletion s2a/src/main/java/io/grpc/s2a/S2AChannelCredentials.java
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
import io.grpc.s2a.internal.channel.S2AHandshakerServiceChannel;
import io.grpc.s2a.internal.handshaker.S2AIdentity;
import io.grpc.s2a.internal.handshaker.S2AProtocolNegotiatorFactory;
import io.grpc.s2a.internal.handshaker.S2AStub;
import javax.annotation.concurrent.NotThreadSafe;
import org.checkerframework.checker.nullness.qual.Nullable;

Expand Down Expand Up @@ -59,6 +60,7 @@ public static final class Builder {
private final String s2aAddress;
private final ChannelCredentials s2aChannelCredentials;
private @Nullable S2AIdentity localIdentity = null;
private @Nullable S2AStub stub = null;

Builder(String s2aAddress, ChannelCredentials s2aChannelCredentials) {
this.s2aAddress = s2aAddress;
Expand Down Expand Up @@ -104,6 +106,16 @@ public Builder setLocalUid(String localUid) {
return this;
}

/**
* Sets the stub to use to communicate with S2A. This is only used for testing that the
* stream to S2A gets closed.
*/
public Builder setStub(S2AStub stub) {
checkNotNull(stub);
this.stub = stub;
return this;
}

public ChannelCredentials build() {
return InternalNettyChannelCredentials.create(buildProtocolNegotiatorFactory());
}
Expand All @@ -113,7 +125,7 @@ InternalProtocolNegotiator.ClientFactory buildProtocolNegotiatorFactory() {
SharedResourcePool.forResource(
S2AHandshakerServiceChannel.getChannelResource(s2aAddress, s2aChannelCredentials));
checkNotNull(s2aChannelPool, "s2aChannelPool");
return S2AProtocolNegotiatorFactory.createClientFactory(localIdentity, s2aChannelPool);
return S2AProtocolNegotiatorFactory.createClientFactory(localIdentity, s2aChannelPool, stub);
}
}

Expand Down
Loading

0 comments on commit 1698b54

Please sign in to comment.