Skip to content

Commit 191ed07

Browse files
committed
plumb S2AStub close to handshake end + add integration test.
1 parent d9d4317 commit 191ed07

File tree

10 files changed

+129
-64
lines changed

10 files changed

+129
-64
lines changed

netty/src/main/java/io/grpc/netty/InternalProtocolNegotiators.java

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
import io.netty.channel.ChannelHandler;
2525
import io.netty.handler.ssl.SslContext;
2626
import io.netty.util.AsciiString;
27+
import java.util.Optional;
2728
import java.util.concurrent.Executor;
2829

2930
/**
@@ -40,9 +41,10 @@ private InternalProtocolNegotiators() {}
4041
* @param executorPool a dedicated {@link Executor} pool for time-consuming TLS tasks
4142
*/
4243
public static InternalProtocolNegotiator.ProtocolNegotiator tls(SslContext sslContext,
43-
ObjectPool<? extends Executor> executorPool) {
44+
ObjectPool<? extends Executor> executorPool,
45+
Optional<Runnable> handshakeCompleteRunnable) {
4446
final io.grpc.netty.ProtocolNegotiator negotiator = ProtocolNegotiators.tls(sslContext,
45-
executorPool);
47+
executorPool, handshakeCompleteRunnable);
4648
final class TlsNegotiator implements InternalProtocolNegotiator.ProtocolNegotiator {
4749

4850
@Override
@@ -70,7 +72,7 @@ public void close() {
7072
* may happen immediately, even before the TLS Handshake is complete.
7173
*/
7274
public static InternalProtocolNegotiator.ProtocolNegotiator tls(SslContext sslContext) {
73-
return tls(sslContext, null);
75+
return tls(sslContext, null, Optional.empty());
7476
}
7577

7678
/**
@@ -167,7 +169,7 @@ public static ChannelHandler grpcNegotiationHandler(GrpcHttp2ConnectionHandler n
167169
public static ChannelHandler clientTlsHandler(
168170
ChannelHandler next, SslContext sslContext, String authority,
169171
ChannelLogger negotiationLogger) {
170-
return new ClientTlsHandler(next, sslContext, authority, null, negotiationLogger);
172+
return new ClientTlsHandler(next, sslContext, authority, null, negotiationLogger, Optional.empty());
171173
}
172174

173175
public static class ProtocolNegotiationHandler

netty/src/main/java/io/grpc/netty/NettyChannelBuilder.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,7 @@
6363
import java.util.Collections;
6464
import java.util.HashMap;
6565
import java.util.Map;
66+
import java.util.Optional;
6667
import java.util.concurrent.Executor;
6768
import java.util.concurrent.ScheduledExecutorService;
6869
import java.util.concurrent.TimeUnit;
@@ -604,7 +605,7 @@ static ProtocolNegotiator createProtocolNegotiatorByType(
604605
case PLAINTEXT_UPGRADE:
605606
return ProtocolNegotiators.plaintextUpgrade();
606607
case TLS:
607-
return ProtocolNegotiators.tls(sslContext, executorPool);
608+
return ProtocolNegotiators.tls(sslContext, executorPool, Optional.empty());
608609
default:
609610
throw new IllegalArgumentException("Unsupported negotiationType: " + negotiationType);
610611
}

netty/src/main/java/io/grpc/netty/ProtocolNegotiators.java

Lines changed: 18 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,7 @@
7272
import java.nio.channels.ClosedChannelException;
7373
import java.util.Arrays;
7474
import java.util.EnumSet;
75+
import java.util.Optional;
7576
import java.util.Set;
7677
import java.util.concurrent.Executor;
7778
import java.util.logging.Level;
@@ -543,16 +544,18 @@ protected void userEventTriggered0(ChannelHandlerContext ctx, Object evt) throws
543544
static final class ClientTlsProtocolNegotiator implements ProtocolNegotiator {
544545

545546
public ClientTlsProtocolNegotiator(SslContext sslContext,
546-
ObjectPool<? extends Executor> executorPool) {
547+
ObjectPool<? extends Executor> executorPool, Optional<Runnable> handshakeCompleteRunnable) {
547548
this.sslContext = checkNotNull(sslContext, "sslContext");
548549
this.executorPool = executorPool;
549550
if (this.executorPool != null) {
550551
this.executor = this.executorPool.getObject();
551552
}
553+
this.handshakeCompleteRunnable = handshakeCompleteRunnable;
552554
}
553555

554556
private final SslContext sslContext;
555557
private final ObjectPool<? extends Executor> executorPool;
558+
private final Optional<Runnable> handshakeCompleteRunnable;
556559
private Executor executor;
557560

558561
@Override
@@ -565,7 +568,7 @@ public ChannelHandler newHandler(GrpcHttp2ConnectionHandler grpcHandler) {
565568
ChannelHandler gnh = new GrpcNegotiationHandler(grpcHandler);
566569
ChannelLogger negotiationLogger = grpcHandler.getNegotiationLogger();
567570
ChannelHandler cth = new ClientTlsHandler(gnh, sslContext, grpcHandler.getAuthority(),
568-
this.executor, negotiationLogger);
571+
this.executor, negotiationLogger, handshakeCompleteRunnable);
569572
return new WaitUntilActiveHandler(cth, negotiationLogger);
570573
}
571574

@@ -583,15 +586,18 @@ static final class ClientTlsHandler extends ProtocolNegotiationHandler {
583586
private final String host;
584587
private final int port;
585588
private Executor executor;
589+
private final Optional<Runnable> handshakeCompleteRunnable;
586590

587591
ClientTlsHandler(ChannelHandler next, SslContext sslContext, String authority,
588-
Executor executor, ChannelLogger negotiationLogger) {
592+
Executor executor, ChannelLogger negotiationLogger,
593+
Optional<Runnable> handshakeCompleteRunnable) {
589594
super(next, negotiationLogger);
590595
this.sslContext = checkNotNull(sslContext, "sslContext");
591596
HostPort hostPort = parseAuthority(authority);
592597
this.host = hostPort.host;
593598
this.port = hostPort.port;
594599
this.executor = executor;
600+
this.handshakeCompleteRunnable = handshakeCompleteRunnable;
595601
}
596602

597603
@Override
@@ -634,6 +640,9 @@ protected void userEventTriggered0(ChannelHandlerContext ctx, Object evt) throws
634640
.withCause(t)
635641
.asRuntimeException();
636642
}
643+
if (handshakeCompleteRunnable.isPresent()) {
644+
handshakeCompleteRunnable.get().run();
645+
}
637646
ctx.fireExceptionCaught(t);
638647
}
639648
} else {
@@ -649,6 +658,9 @@ private void propagateTlsComplete(ChannelHandlerContext ctx, SSLSession session)
649658
.set(Grpc.TRANSPORT_ATTR_SSL_SESSION, session)
650659
.build();
651660
replaceProtocolNegotiationEvent(existingPne.withAttributes(attrs).withSecurity(security));
661+
if (handshakeCompleteRunnable.isPresent()) {
662+
handshakeCompleteRunnable.get().run();
663+
}
652664
fireProtocolNegotiationEvent(ctx);
653665
}
654666
}
@@ -683,8 +695,8 @@ static HostPort parseAuthority(String authority) {
683695
* @param executorPool a dedicated {@link Executor} pool for time-consuming TLS tasks
684696
*/
685697
public static ProtocolNegotiator tls(SslContext sslContext,
686-
ObjectPool<? extends Executor> executorPool) {
687-
return new ClientTlsProtocolNegotiator(sslContext, executorPool);
698+
ObjectPool<? extends Executor> executorPool, Optional<Runnable> handshakeCompleteRunnable) {
699+
return new ClientTlsProtocolNegotiator(sslContext, executorPool, handshakeCompleteRunnable);
688700
}
689701

690702
/**
@@ -693,7 +705,7 @@ public static ProtocolNegotiator tls(SslContext sslContext,
693705
* may happen immediately, even before the TLS Handshake is complete.
694706
*/
695707
public static ProtocolNegotiator tls(SslContext sslContext) {
696-
return tls(sslContext, null);
708+
return tls(sslContext, null, Optional.empty());
697709
}
698710

699711
public static ProtocolNegotiator.ClientFactory tlsClientFactory(SslContext sslContext) {

netty/src/test/java/io/grpc/netty/NettyClientTransportTest.java

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -105,6 +105,7 @@
105105
import java.util.HashMap;
106106
import java.util.List;
107107
import java.util.Map;
108+
import java.util.Optional;
108109
import java.util.concurrent.ExecutionException;
109110
import java.util.concurrent.LinkedBlockingQueue;
110111
import java.util.concurrent.TimeUnit;
@@ -766,7 +767,8 @@ public void tlsNegotiationServerExecutorShouldSucceed() throws Exception {
766767
.trustManager(caCert)
767768
.keyManager(clientCert, clientKey)
768769
.build();
769-
ProtocolNegotiator negotiator = ProtocolNegotiators.tls(clientContext, clientExecutorPool);
770+
ProtocolNegotiator negotiator = ProtocolNegotiators.tls(clientContext, clientExecutorPool,
771+
Optional.empty());
770772
// after starting the client, the Executor in the client pool should be used
771773
assertEquals(true, clientExecutorPool.isInUse());
772774
final NettyClientTransport transport = newTransport(negotiator);

netty/src/test/java/io/grpc/netty/ProtocolNegotiatorsTest.java

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -120,6 +120,7 @@
120120
import java.util.Arrays;
121121
import java.util.Collections;
122122
import java.util.List;
123+
import java.util.Optional;
123124
import java.util.Queue;
124125
import java.util.concurrent.CountDownLatch;
125126
import java.util.concurrent.TimeUnit;
@@ -876,7 +877,7 @@ public String applicationProtocol() {
876877
DefaultEventLoopGroup elg = new DefaultEventLoopGroup(1);
877878

878879
ClientTlsHandler handler = new ClientTlsHandler(grpcHandler, sslContext,
879-
"authority", elg, noopLogger);
880+
"authority", elg, noopLogger, Optional.empty());
880881
pipeline.addLast(handler);
881882
pipeline.replace(SslHandler.class, null, goodSslHandler);
882883
pipeline.fireUserEventTriggered(ProtocolNegotiationEvent.DEFAULT);
@@ -914,7 +915,7 @@ public String applicationProtocol() {
914915
.applicationProtocolConfig(apn).build();
915916

916917
ClientTlsHandler handler = new ClientTlsHandler(grpcHandler, sslContext,
917-
"authority", elg, noopLogger);
918+
"authority", elg, noopLogger, Optional.empty());
918919
pipeline.addLast(handler);
919920
pipeline.replace(SslHandler.class, null, goodSslHandler);
920921
pipeline.fireUserEventTriggered(ProtocolNegotiationEvent.DEFAULT);
@@ -938,7 +939,7 @@ public String applicationProtocol() {
938939
DefaultEventLoopGroup elg = new DefaultEventLoopGroup(1);
939940

940941
ClientTlsHandler handler = new ClientTlsHandler(grpcHandler, sslContext,
941-
"authority", elg, noopLogger);
942+
"authority", elg, noopLogger, Optional.empty());
942943
pipeline.addLast(handler);
943944

944945
final AtomicReference<Throwable> error = new AtomicReference<>();
@@ -966,7 +967,7 @@ public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) {
966967
@Test
967968
public void clientTlsHandler_closeDuringNegotiation() throws Exception {
968969
ClientTlsHandler handler = new ClientTlsHandler(grpcHandler, sslContext,
969-
"authority", null, noopLogger);
970+
"authority", null, noopLogger, Optional.empty());
970971
pipeline.addLast(new WriteBufferingAndExceptionHandler(handler));
971972
ChannelFuture pendingWrite = channel.writeAndFlush(NettyClientHandler.NOOP_MESSAGE);
972973

@@ -1228,7 +1229,8 @@ public void clientTlsHandler_firesNegotiation() throws Exception {
12281229
serverSslContext = GrpcSslContexts.forServer(server1Chain, server1Key).build();
12291230
}
12301231
FakeGrpcHttp2ConnectionHandler gh = FakeGrpcHttp2ConnectionHandler.newHandler();
1231-
ClientTlsProtocolNegotiator pn = new ClientTlsProtocolNegotiator(clientSslContext, null);
1232+
ClientTlsProtocolNegotiator pn = new ClientTlsProtocolNegotiator(clientSslContext,
1233+
null, Optional.empty());
12321234
WriteBufferingAndExceptionHandler clientWbaeh =
12331235
new WriteBufferingAndExceptionHandler(pn.newHandler(gh));
12341236

s2a/src/main/java/io/grpc/s2a/S2AChannelCredentials.java

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
import io.grpc.s2a.internal.channel.S2AHandshakerServiceChannel;
3232
import io.grpc.s2a.internal.handshaker.S2AIdentity;
3333
import io.grpc.s2a.internal.handshaker.S2AProtocolNegotiatorFactory;
34+
import io.grpc.s2a.internal.handshaker.S2AStub;
3435
import javax.annotation.concurrent.NotThreadSafe;
3536
import org.checkerframework.checker.nullness.qual.Nullable;
3637

@@ -59,6 +60,7 @@ public static final class Builder {
5960
private final String s2aAddress;
6061
private final ChannelCredentials s2aChannelCredentials;
6162
private @Nullable S2AIdentity localIdentity = null;
63+
private @Nullable S2AStub stub = null;
6264

6365
Builder(String s2aAddress, ChannelCredentials s2aChannelCredentials) {
6466
this.s2aAddress = s2aAddress;
@@ -104,6 +106,16 @@ public Builder setLocalUid(String localUid) {
104106
return this;
105107
}
106108

109+
/**
110+
* Sets the stub to use to communicate with S2A. This is only used for testing that the
111+
* stream to S2A gets closed.
112+
*/
113+
public Builder setStub(S2AStub stub) {
114+
checkNotNull(stub);
115+
this.stub = stub;
116+
return this;
117+
}
118+
107119
public ChannelCredentials build() {
108120
return InternalNettyChannelCredentials.create(buildProtocolNegotiatorFactory());
109121
}
@@ -113,7 +125,7 @@ InternalProtocolNegotiator.ClientFactory buildProtocolNegotiatorFactory() {
113125
SharedResourcePool.forResource(
114126
S2AHandshakerServiceChannel.getChannelResource(s2aAddress, s2aChannelCredentials));
115127
checkNotNull(s2aChannelPool, "s2aChannelPool");
116-
return S2AProtocolNegotiatorFactory.createClientFactory(localIdentity, s2aChannelPool);
128+
return S2AProtocolNegotiatorFactory.createClientFactory(localIdentity, s2aChannelPool, stub);
117129
}
118130
}
119131

0 commit comments

Comments
 (0)