Skip to content
Merged
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@
import static java.util.concurrent.TimeUnit.SECONDS;

import com.google.common.annotations.VisibleForTesting;
import com.google.common.collect.Maps;
import io.grpc.CallOptions;
import io.grpc.Channel;
import io.grpc.ChannelCredentials;
Expand All @@ -30,16 +29,15 @@
import io.grpc.internal.SharedResourceHolder.Resource;
import io.grpc.netty.NettyChannelBuilder;
import java.time.Duration;
import java.util.concurrent.ConcurrentMap;
import java.util.logging.Level;
import java.util.logging.Logger;
import javax.annotation.concurrent.ThreadSafe;

/**
* Provides APIs for managing gRPC channels to S2A servers. Each channel is local and plaintext. If
* credentials are provided, they are used to secure the channel.
* Provides APIs for managing gRPC channels to an S2A server. Each channel is local and plaintext.
* If credentials are provided, they are used to secure the channel.
*
* <p>This is done as follows: for each S2A server, provides an implementation of gRPC's {@link
* <p>This is done as follows: for an S2A server, provides an implementation of gRPC's {@link
* SharedResourceHolder.Resource} interface called a {@code Resource<Channel>}. A {@code
* Resource<Channel>} is a factory for creating gRPC channels to the S2A server at a given address,
* and a channel must be returned to the {@code Resource<Channel>} when it is no longer needed.
Expand All @@ -56,8 +54,6 @@
*/
@ThreadSafe
public final class S2AHandshakerServiceChannel {
private static final ConcurrentMap<String, Resource<Channel>> SHARED_RESOURCE_CHANNELS =
Maps.newConcurrentMap();
private static final Duration CHANNEL_SHUTDOWN_TIMEOUT = Duration.ofSeconds(10);

/**
Expand All @@ -72,9 +68,7 @@ public final class S2AHandshakerServiceChannel {
public static Resource<Channel> getChannelResource(
String s2aAddress, ChannelCredentials s2aChannelCredentials) {
checkNotNull(s2aAddress);
checkNotNull(s2aChannelCredentials);
return SHARED_RESOURCE_CHANNELS.computeIfAbsent(
s2aAddress, channelResource -> new ChannelResource(s2aAddress, s2aChannelCredentials));
return new ChannelResource(s2aAddress, s2aChannelCredentials);
}

/**
Expand Down
28 changes: 2 additions & 26 deletions s2a/src/main/java/io/grpc/s2a/handshaker/ProtoUtil.java
Original file line number Diff line number Diff line change
Expand Up @@ -16,36 +16,11 @@

package io.grpc.s2a.handshaker;

import com.google.common.annotations.VisibleForTesting;
import com.google.common.collect.ImmutableSet;

/** Converts proto messages to Netty strings. */
final class ProtoUtil {
/**
* Converts {@link Ciphersuite} to its {@link String} representation.
*
* @param ciphersuite the {@link Ciphersuite} to be converted.
* @return a {@link String} representing the ciphersuite.
* @throws AssertionError if the {@link Ciphersuite} is not one of the supported ciphersuites.
*/
static String convertCiphersuite(Ciphersuite ciphersuite) {
switch (ciphersuite) {
case CIPHERSUITE_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256:
return "TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256";
case CIPHERSUITE_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384:
return "TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384";
case CIPHERSUITE_ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256:
return "TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256";
case CIPHERSUITE_ECDHE_RSA_WITH_AES_128_GCM_SHA256:
return "TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256";
case CIPHERSUITE_ECDHE_RSA_WITH_AES_256_GCM_SHA384:
return "TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384";
case CIPHERSUITE_ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256:
return "TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256";
default:
throw new AssertionError(
String.format("Ciphersuite %d is not supported.", ciphersuite.getNumber()));
}
}

/**
* Converts a {@link TLSVersion} object to its {@link String} representation.
Expand All @@ -54,6 +29,7 @@ static String convertCiphersuite(Ciphersuite ciphersuite) {
* @return a {@link String} representation of the TLS version.
* @throws AssertionError if the {@code tlsVersion} is not one of the supported TLS versions.
*/
@VisibleForTesting
static String convertTlsProtocolVersion(TLSVersion tlsVersion) {
switch (tlsVersion) {
case TLS_VERSION_1_3:
Expand Down
15 changes: 8 additions & 7 deletions s2a/src/main/java/io/grpc/s2a/handshaker/S2AStub.java
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,7 @@ BlockingQueue<Result> getResponses() {
* @throws IOException if an unexpected response is received, or if the {@code reader} or {@code
* writer} calls their {@code onError} method.
*/
@SuppressWarnings("CheckReturnValue")
public SessionResp send(SessionReq req) throws IOException, InterruptedException {
if (doneWriting && doneReading) {
logger.log(Level.INFO, "Stream to the S2A is closed.");
Expand All @@ -92,9 +93,8 @@ public SessionResp send(SessionReq req) throws IOException, InterruptedException
createWriterIfNull();
if (!responses.isEmpty()) {
IOException exception = null;
SessionResp resp = null;
try {
resp = responses.take().getResultOrThrow();
responses.take().getResultOrThrow();
} catch (IOException e) {
exception = e;
}
Expand All @@ -104,14 +104,15 @@ public SessionResp send(SessionReq req) throws IOException, InterruptedException
"Received an unexpected response from a host at the S2A's address. The S2A might be"
+ " unavailable."
+ exception.getMessage());
} else {
throw new IOException("Received an unexpected response from a host at the S2A's address.");
}
return resp;
}
try {
writer.onNext(req);
} catch (RuntimeException e) {
writer.onError(e);
responses.offer(Result.createWithThrowable(e));
responses.add(Result.createWithThrowable(e));
}
try {
return responses.take().getResultOrThrow();
Expand Down Expand Up @@ -159,7 +160,7 @@ private class Reader implements StreamObserver<SessionResp> {
@Override
public void onNext(SessionResp resp) {
verify(!doneReading);
responses.offer(Result.createWithResponse(resp));
responses.add(Result.createWithResponse(resp));
}

/**
Expand All @@ -169,7 +170,7 @@ public void onNext(SessionResp resp) {
*/
@Override
public void onError(Throwable t) {
responses.offer(Result.createWithThrowable(t));
responses.add(Result.createWithThrowable(t));
}

/**
Expand All @@ -180,7 +181,7 @@ public void onError(Throwable t) {
public void onCompleted() {
logger.log(Level.INFO, "Reading from the S2A is complete.");
doneReading = true;
responses.offer(
responses.add(
Result.createWithThrowable(
new ConnectionClosedException("Reading from the S2A is complete.")));
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@ public final class AccessTokenManager {
private final TokenFetcher tokenFetcher;

/** Creates an {@code AccessTokenManager} based on the environment where the application runs. */
@SuppressWarnings("RethrowReflectiveOperationExceptionAsLinkageError")
public static Optional<AccessTokenManager> create() {
Optional<?> tokenFetcher;
try {
Expand All @@ -38,7 +37,7 @@ public static Optional<AccessTokenManager> create() {
} catch (ClassNotFoundException e) {
tokenFetcher = Optional.empty();
} catch (ReflectiveOperationException e) {
throw new AssertionError(e);
throw new LinkageError(e.getMessage(), e);
}
return tokenFetcher.isPresent()
? Optional.of(new AccessTokenManager((TokenFetcher) tokenFetcher.get()))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -89,10 +89,10 @@ public void getChannelResource_mtlsSuccess() throws Exception {

/**
* Creates two {@code Resoure<Channel>}s for the same target address and verifies that they are
* equal.
* distinct.
*/
@Test
public void getChannelResource_twoEqualChannels() {
public void getChannelResource_twoUnEqualChannels() {
Resource<Channel> resource =
S2AHandshakerServiceChannel.getChannelResource(
"localhost:" + plaintextServer.getPort(),
Expand All @@ -101,19 +101,19 @@ public void getChannelResource_twoEqualChannels() {
S2AHandshakerServiceChannel.getChannelResource(
"localhost:" + plaintextServer.getPort(),
InsecureChannelCredentials.create());
assertThat(resource).isEqualTo(resourceTwo);
assertThat(resource).isNotEqualTo(resourceTwo);
}

/** Same as getChannelResource_twoEqualChannels, but use mTLS. */
/** Same as getChannelResource_twoUnEqualChannels, but use mTLS. */
@Test
public void getChannelResource_mtlsTwoEqualChannels() throws Exception {
public void getChannelResource_mtlsTwoUnEqualChannels() throws Exception {
Resource<Channel> resource =
S2AHandshakerServiceChannel.getChannelResource(
"localhost:" + mtlsServer.getPort(), getTlsChannelCredentials());
Resource<Channel> resourceTwo =
S2AHandshakerServiceChannel.getChannelResource(
"localhost:" + mtlsServer.getPort(), getTlsChannelCredentials());
assertThat(resource).isEqualTo(resourceTwo);
assertThat(resource).isNotEqualTo(resourceTwo);
}

/**
Expand Down
7 changes: 6 additions & 1 deletion s2a/src/test/java/io/grpc/s2a/handshaker/FakeS2AServer.java
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
package io.grpc.s2a.handshaker;

import io.grpc.stub.StreamObserver;
import java.io.IOException;
import java.security.NoSuchAlgorithmException;
import java.security.spec.InvalidKeySpecException;
import java.util.logging.Logger;
Expand All @@ -38,7 +39,11 @@ public StreamObserver<SessionReq> setUpSession(StreamObserver<SessionResp> respo
@Override
public void onNext(SessionReq req) {
logger.info("Received a request from client.");
responseObserver.onNext(writer.handleResponse(req));
try {
responseObserver.onNext(writer.handleResponse(req));
} catch (IOException e) {
responseObserver.onError(e);
}
}

@Override
Expand Down
18 changes: 11 additions & 7 deletions s2a/src/test/java/io/grpc/s2a/handshaker/FakeS2AServerTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,9 @@
import io.grpc.benchmarks.Utils;
import io.grpc.s2a.handshaker.ValidatePeerCertificateChainReq.VerificationMode;
import io.grpc.stub.StreamObserver;
import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.nio.file.Files;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
Expand All @@ -45,9 +48,7 @@ public final class FakeS2AServerTest {
private static final Logger logger = Logger.getLogger(FakeS2AServerTest.class.getName());

private static final ImmutableList<ByteString> FAKE_CERT_DER_CHAIN =
ImmutableList.of(
ByteString.copyFrom(
new byte[] {'f', 'a', 'k', 'e', '-', 'd', 'e', 'r', '-', 'c', 'h', 'a', 'i', 'n'}));
ImmutableList.of(ByteString.copyFrom("fake-der-chain".getBytes(StandardCharsets.US_ASCII)));
private int port;
private String serverAddress;
private SessionResp response = null;
Expand All @@ -68,7 +69,7 @@ public void tearDown() {

@Test
public void callS2AServerOnce_getTlsConfiguration_returnsValidResult()
throws InterruptedException {
throws InterruptedException, IOException {
ExecutorService executor = Executors.newSingleThreadExecutor();
logger.info("Client connecting to: " + serverAddress);
ManagedChannel channel =
Expand Down Expand Up @@ -122,9 +123,12 @@ public void onCompleted() {}
GetTlsConfigurationResp.newBuilder()
.setClientTlsConfiguration(
GetTlsConfigurationResp.ClientTlsConfiguration.newBuilder()
.addCertificateChain(FakeWriter.LEAF_CERT)
.addCertificateChain(FakeWriter.INTERMEDIATE_CERT_2)
.addCertificateChain(FakeWriter.INTERMEDIATE_CERT_1)
.addCertificateChain(new String(Files.readAllBytes(
FakeWriter.leafCertFile.toPath()), StandardCharsets.UTF_8))
.addCertificateChain(new String(Files.readAllBytes(
FakeWriter.cert1File.toPath()), StandardCharsets.UTF_8))
.addCertificateChain(new String(Files.readAllBytes(
FakeWriter.cert2File.toPath()), StandardCharsets.UTF_8))
.setMinTlsVersion(TLSVersion.TLS_VERSION_1_3)
.setMaxTlsVersion(TLSVersion.TLS_VERSION_1_3)
.addCiphersuites(
Expand Down
Loading