Skip to content

Commit d628396

Browse files
authored
s2a: Add S2AStub cleanup handler. (#11600)
* Add S2AStub cleanup handler. * Give TLS and Cleanup handlers name + update comment. * Don't add TLS handler twice. * Don't remove explicitly, since done by fireProtocolNegotiationEvent. * plumb S2AStub close to handshake end + add integration test. * close stub when TLS negotiation fails.
1 parent 2129078 commit d628396

File tree

10 files changed

+134
-42
lines changed

10 files changed

+134
-42
lines changed

netty/src/main/java/io/grpc/netty/InternalProtocolNegotiators.java

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
import io.netty.channel.ChannelHandler;
2525
import io.netty.handler.ssl.SslContext;
2626
import io.netty.util.AsciiString;
27+
import java.util.Optional;
2728
import java.util.concurrent.Executor;
2829

2930
/**
@@ -40,9 +41,10 @@ private InternalProtocolNegotiators() {}
4041
* @param executorPool a dedicated {@link Executor} pool for time-consuming TLS tasks
4142
*/
4243
public static InternalProtocolNegotiator.ProtocolNegotiator tls(SslContext sslContext,
43-
ObjectPool<? extends Executor> executorPool) {
44+
ObjectPool<? extends Executor> executorPool,
45+
Optional<Runnable> handshakeCompleteRunnable) {
4446
final io.grpc.netty.ProtocolNegotiator negotiator = ProtocolNegotiators.tls(sslContext,
45-
executorPool);
47+
executorPool, handshakeCompleteRunnable);
4648
final class TlsNegotiator implements InternalProtocolNegotiator.ProtocolNegotiator {
4749

4850
@Override
@@ -70,7 +72,7 @@ public void close() {
7072
* may happen immediately, even before the TLS Handshake is complete.
7173
*/
7274
public static InternalProtocolNegotiator.ProtocolNegotiator tls(SslContext sslContext) {
73-
return tls(sslContext, null);
75+
return tls(sslContext, null, Optional.empty());
7476
}
7577

7678
/**
@@ -167,7 +169,8 @@ public static ChannelHandler grpcNegotiationHandler(GrpcHttp2ConnectionHandler n
167169
public static ChannelHandler clientTlsHandler(
168170
ChannelHandler next, SslContext sslContext, String authority,
169171
ChannelLogger negotiationLogger) {
170-
return new ClientTlsHandler(next, sslContext, authority, null, negotiationLogger);
172+
return new ClientTlsHandler(next, sslContext, authority, null, negotiationLogger,
173+
Optional.empty());
171174
}
172175

173176
public static class ProtocolNegotiationHandler

netty/src/main/java/io/grpc/netty/NettyChannelBuilder.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,7 @@
6363
import java.util.Collections;
6464
import java.util.HashMap;
6565
import java.util.Map;
66+
import java.util.Optional;
6667
import java.util.concurrent.Executor;
6768
import java.util.concurrent.ScheduledExecutorService;
6869
import java.util.concurrent.TimeUnit;
@@ -604,7 +605,7 @@ static ProtocolNegotiator createProtocolNegotiatorByType(
604605
case PLAINTEXT_UPGRADE:
605606
return ProtocolNegotiators.plaintextUpgrade();
606607
case TLS:
607-
return ProtocolNegotiators.tls(sslContext, executorPool);
608+
return ProtocolNegotiators.tls(sslContext, executorPool, Optional.empty());
608609
default:
609610
throw new IllegalArgumentException("Unsupported negotiationType: " + negotiationType);
610611
}

netty/src/main/java/io/grpc/netty/ProtocolNegotiators.java

Lines changed: 21 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,7 @@
7272
import java.nio.channels.ClosedChannelException;
7373
import java.util.Arrays;
7474
import java.util.EnumSet;
75+
import java.util.Optional;
7576
import java.util.Set;
7677
import java.util.concurrent.Executor;
7778
import java.util.logging.Level;
@@ -543,16 +544,18 @@ protected void userEventTriggered0(ChannelHandlerContext ctx, Object evt) throws
543544
static final class ClientTlsProtocolNegotiator implements ProtocolNegotiator {
544545

545546
public ClientTlsProtocolNegotiator(SslContext sslContext,
546-
ObjectPool<? extends Executor> executorPool) {
547+
ObjectPool<? extends Executor> executorPool, Optional<Runnable> handshakeCompleteRunnable) {
547548
this.sslContext = checkNotNull(sslContext, "sslContext");
548549
this.executorPool = executorPool;
549550
if (this.executorPool != null) {
550551
this.executor = this.executorPool.getObject();
551552
}
553+
this.handshakeCompleteRunnable = handshakeCompleteRunnable;
552554
}
553555

554556
private final SslContext sslContext;
555557
private final ObjectPool<? extends Executor> executorPool;
558+
private final Optional<Runnable> handshakeCompleteRunnable;
556559
private Executor executor;
557560

558561
@Override
@@ -565,7 +568,7 @@ public ChannelHandler newHandler(GrpcHttp2ConnectionHandler grpcHandler) {
565568
ChannelHandler gnh = new GrpcNegotiationHandler(grpcHandler);
566569
ChannelLogger negotiationLogger = grpcHandler.getNegotiationLogger();
567570
ChannelHandler cth = new ClientTlsHandler(gnh, sslContext, grpcHandler.getAuthority(),
568-
this.executor, negotiationLogger);
571+
this.executor, negotiationLogger, handshakeCompleteRunnable);
569572
return new WaitUntilActiveHandler(cth, negotiationLogger);
570573
}
571574

@@ -583,15 +586,18 @@ static final class ClientTlsHandler extends ProtocolNegotiationHandler {
583586
private final String host;
584587
private final int port;
585588
private Executor executor;
589+
private final Optional<Runnable> handshakeCompleteRunnable;
586590

587591
ClientTlsHandler(ChannelHandler next, SslContext sslContext, String authority,
588-
Executor executor, ChannelLogger negotiationLogger) {
592+
Executor executor, ChannelLogger negotiationLogger,
593+
Optional<Runnable> handshakeCompleteRunnable) {
589594
super(next, negotiationLogger);
590595
this.sslContext = checkNotNull(sslContext, "sslContext");
591596
HostPort hostPort = parseAuthority(authority);
592597
this.host = hostPort.host;
593598
this.port = hostPort.port;
594599
this.executor = executor;
600+
this.handshakeCompleteRunnable = handshakeCompleteRunnable;
595601
}
596602

597603
@Override
@@ -620,6 +626,9 @@ protected void userEventTriggered0(ChannelHandlerContext ctx, Object evt) throws
620626
Exception ex =
621627
unavailableException("Failed ALPN negotiation: Unable to find compatible protocol");
622628
logSslEngineDetails(Level.FINE, ctx, "TLS negotiation failed.", ex);
629+
if (handshakeCompleteRunnable.isPresent()) {
630+
handshakeCompleteRunnable.get().run();
631+
}
623632
ctx.fireExceptionCaught(ex);
624633
}
625634
} else {
@@ -634,6 +643,9 @@ protected void userEventTriggered0(ChannelHandlerContext ctx, Object evt) throws
634643
.withCause(t)
635644
.asRuntimeException();
636645
}
646+
if (handshakeCompleteRunnable.isPresent()) {
647+
handshakeCompleteRunnable.get().run();
648+
}
637649
ctx.fireExceptionCaught(t);
638650
}
639651
} else {
@@ -649,6 +661,9 @@ private void propagateTlsComplete(ChannelHandlerContext ctx, SSLSession session)
649661
.set(Grpc.TRANSPORT_ATTR_SSL_SESSION, session)
650662
.build();
651663
replaceProtocolNegotiationEvent(existingPne.withAttributes(attrs).withSecurity(security));
664+
if (handshakeCompleteRunnable.isPresent()) {
665+
handshakeCompleteRunnable.get().run();
666+
}
652667
fireProtocolNegotiationEvent(ctx);
653668
}
654669
}
@@ -683,8 +698,8 @@ static HostPort parseAuthority(String authority) {
683698
* @param executorPool a dedicated {@link Executor} pool for time-consuming TLS tasks
684699
*/
685700
public static ProtocolNegotiator tls(SslContext sslContext,
686-
ObjectPool<? extends Executor> executorPool) {
687-
return new ClientTlsProtocolNegotiator(sslContext, executorPool);
701+
ObjectPool<? extends Executor> executorPool, Optional<Runnable> handshakeCompleteRunnable) {
702+
return new ClientTlsProtocolNegotiator(sslContext, executorPool, handshakeCompleteRunnable);
688703
}
689704

690705
/**
@@ -693,7 +708,7 @@ public static ProtocolNegotiator tls(SslContext sslContext,
693708
* may happen immediately, even before the TLS Handshake is complete.
694709
*/
695710
public static ProtocolNegotiator tls(SslContext sslContext) {
696-
return tls(sslContext, null);
711+
return tls(sslContext, null, Optional.empty());
697712
}
698713

699714
public static ProtocolNegotiator.ClientFactory tlsClientFactory(SslContext sslContext) {

netty/src/test/java/io/grpc/netty/NettyClientTransportTest.java

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -105,6 +105,7 @@
105105
import java.util.HashMap;
106106
import java.util.List;
107107
import java.util.Map;
108+
import java.util.Optional;
108109
import java.util.concurrent.ExecutionException;
109110
import java.util.concurrent.LinkedBlockingQueue;
110111
import java.util.concurrent.TimeUnit;
@@ -766,7 +767,8 @@ public void tlsNegotiationServerExecutorShouldSucceed() throws Exception {
766767
.trustManager(caCert)
767768
.keyManager(clientCert, clientKey)
768769
.build();
769-
ProtocolNegotiator negotiator = ProtocolNegotiators.tls(clientContext, clientExecutorPool);
770+
ProtocolNegotiator negotiator = ProtocolNegotiators.tls(clientContext, clientExecutorPool,
771+
Optional.empty());
770772
// after starting the client, the Executor in the client pool should be used
771773
assertEquals(true, clientExecutorPool.isInUse());
772774
final NettyClientTransport transport = newTransport(negotiator);

netty/src/test/java/io/grpc/netty/ProtocolNegotiatorsTest.java

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -120,6 +120,7 @@
120120
import java.util.Arrays;
121121
import java.util.Collections;
122122
import java.util.List;
123+
import java.util.Optional;
123124
import java.util.Queue;
124125
import java.util.concurrent.CountDownLatch;
125126
import java.util.concurrent.TimeUnit;
@@ -876,7 +877,7 @@ public String applicationProtocol() {
876877
DefaultEventLoopGroup elg = new DefaultEventLoopGroup(1);
877878

878879
ClientTlsHandler handler = new ClientTlsHandler(grpcHandler, sslContext,
879-
"authority", elg, noopLogger);
880+
"authority", elg, noopLogger, Optional.empty());
880881
pipeline.addLast(handler);
881882
pipeline.replace(SslHandler.class, null, goodSslHandler);
882883
pipeline.fireUserEventTriggered(ProtocolNegotiationEvent.DEFAULT);
@@ -914,7 +915,7 @@ public String applicationProtocol() {
914915
.applicationProtocolConfig(apn).build();
915916

916917
ClientTlsHandler handler = new ClientTlsHandler(grpcHandler, sslContext,
917-
"authority", elg, noopLogger);
918+
"authority", elg, noopLogger, Optional.empty());
918919
pipeline.addLast(handler);
919920
pipeline.replace(SslHandler.class, null, goodSslHandler);
920921
pipeline.fireUserEventTriggered(ProtocolNegotiationEvent.DEFAULT);
@@ -938,7 +939,7 @@ public String applicationProtocol() {
938939
DefaultEventLoopGroup elg = new DefaultEventLoopGroup(1);
939940

940941
ClientTlsHandler handler = new ClientTlsHandler(grpcHandler, sslContext,
941-
"authority", elg, noopLogger);
942+
"authority", elg, noopLogger, Optional.empty());
942943
pipeline.addLast(handler);
943944

944945
final AtomicReference<Throwable> error = new AtomicReference<>();
@@ -966,7 +967,7 @@ public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) {
966967
@Test
967968
public void clientTlsHandler_closeDuringNegotiation() throws Exception {
968969
ClientTlsHandler handler = new ClientTlsHandler(grpcHandler, sslContext,
969-
"authority", null, noopLogger);
970+
"authority", null, noopLogger, Optional.empty());
970971
pipeline.addLast(new WriteBufferingAndExceptionHandler(handler));
971972
ChannelFuture pendingWrite = channel.writeAndFlush(NettyClientHandler.NOOP_MESSAGE);
972973

@@ -1228,7 +1229,8 @@ public void clientTlsHandler_firesNegotiation() throws Exception {
12281229
serverSslContext = GrpcSslContexts.forServer(server1Chain, server1Key).build();
12291230
}
12301231
FakeGrpcHttp2ConnectionHandler gh = FakeGrpcHttp2ConnectionHandler.newHandler();
1231-
ClientTlsProtocolNegotiator pn = new ClientTlsProtocolNegotiator(clientSslContext, null);
1232+
ClientTlsProtocolNegotiator pn = new ClientTlsProtocolNegotiator(clientSslContext,
1233+
null, Optional.empty());
12321234
WriteBufferingAndExceptionHandler clientWbaeh =
12331235
new WriteBufferingAndExceptionHandler(pn.newHandler(gh));
12341236

s2a/src/main/java/io/grpc/s2a/S2AChannelCredentials.java

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
import io.grpc.s2a.internal.channel.S2AHandshakerServiceChannel;
3232
import io.grpc.s2a.internal.handshaker.S2AIdentity;
3333
import io.grpc.s2a.internal.handshaker.S2AProtocolNegotiatorFactory;
34+
import io.grpc.s2a.internal.handshaker.S2AStub;
3435
import javax.annotation.concurrent.NotThreadSafe;
3536
import org.checkerframework.checker.nullness.qual.Nullable;
3637

@@ -59,6 +60,7 @@ public static final class Builder {
5960
private final String s2aAddress;
6061
private final ChannelCredentials s2aChannelCredentials;
6162
private @Nullable S2AIdentity localIdentity = null;
63+
private @Nullable S2AStub stub = null;
6264

6365
Builder(String s2aAddress, ChannelCredentials s2aChannelCredentials) {
6466
this.s2aAddress = s2aAddress;
@@ -104,6 +106,16 @@ public Builder setLocalUid(String localUid) {
104106
return this;
105107
}
106108

109+
/**
110+
* Sets the stub to use to communicate with S2A. This is only used for testing that the
111+
* stream to S2A gets closed.
112+
*/
113+
public Builder setStub(S2AStub stub) {
114+
checkNotNull(stub);
115+
this.stub = stub;
116+
return this;
117+
}
118+
107119
public ChannelCredentials build() {
108120
return InternalNettyChannelCredentials.create(buildProtocolNegotiatorFactory());
109121
}
@@ -113,7 +125,7 @@ InternalProtocolNegotiator.ClientFactory buildProtocolNegotiatorFactory() {
113125
SharedResourcePool.forResource(
114126
S2AHandshakerServiceChannel.getChannelResource(s2aAddress, s2aChannelCredentials));
115127
checkNotNull(s2aChannelPool, "s2aChannelPool");
116-
return S2AProtocolNegotiatorFactory.createClientFactory(localIdentity, s2aChannelPool);
128+
return S2AProtocolNegotiatorFactory.createClientFactory(localIdentity, s2aChannelPool, stub);
117129
}
118130
}
119131

0 commit comments

Comments
 (0)