5656import com .google .common .io .Files ;
5757import io .grpc .CallCredentials ;
5858import io .grpc .ChannelCredentials ;
59+ import io .grpc .CompositeChannelCredentials ;
5960import io .grpc .Grpc ;
6061import io .grpc .InsecureChannelCredentials ;
6162import io .grpc .ManagedChannel ;
6970import java .nio .charset .StandardCharsets ;
7071import java .security .GeneralSecurityException ;
7172import java .security .KeyStore ;
73+ import java .util .ArrayList ;
7274import java .util .HashMap ;
7375import java .util .List ;
7476import java .util .Map ;
@@ -139,14 +141,15 @@ public final class InstantiatingGrpcChannelProvider implements TransportChannelP
139141 @ Nullable private final Boolean keepAliveWithoutCalls ;
140142 private final ChannelPoolSettings channelPoolSettings ;
141143 @ Nullable private final Credentials credentials ;
144+ @ Nullable private final CallCredentials mtlsS2ACallCredentials ;
142145 @ Nullable private final ChannelPrimer channelPrimer ;
143146 @ Nullable private final Boolean attemptDirectPath ;
144147 @ Nullable private final Boolean attemptDirectPathXds ;
145148 @ Nullable private final Boolean allowNonDefaultServiceAccount ;
146149 @ VisibleForTesting final ImmutableMap <String , ?> directPathServiceConfig ;
147150 @ Nullable private final MtlsProvider mtlsProvider ;
148151 @ Nullable private final SecureSessionAgent s2aConfigProvider ;
149- @ Nullable private final List <HardBoundTokenTypes > allowedHardBoundTokenTypes ;
152+ private final List <HardBoundTokenTypes > allowedHardBoundTokenTypes ;
150153 @ VisibleForTesting final Map <String , String > headersWithDuplicatesRemoved = new HashMap <>();
151154
152155 @ Nullable
@@ -188,6 +191,7 @@ private InstantiatingGrpcChannelProvider(Builder builder) {
188191 this .channelPoolSettings = builder .channelPoolSettings ;
189192 this .channelConfigurator = builder .channelConfigurator ;
190193 this .credentials = builder .credentials ;
194+ this .mtlsS2ACallCredentials = builder .mtlsS2ACallCredentials ;
191195 this .channelPrimer = builder .channelPrimer ;
192196 this .attemptDirectPath = builder .attemptDirectPath ;
193197 this .attemptDirectPathXds = builder .attemptDirectPathXds ;
@@ -648,6 +652,12 @@ private ManagedChannel createSingleChannel() throws IOException {
648652 }
649653 if (channelCredentials != null ) {
650654 // Create the channel using S2A-secured channel credentials.
655+ if (mtlsS2ACallCredentials != null ) {
656+ // Set {@code mtlsS2ACallCredentials} to be per-RPC call credentials,
657+ // which will be used to fetch MTLS_S2A hard bound tokens from the metdata server.
658+ channelCredentials =
659+ CompositeChannelCredentials .create (channelCredentials , mtlsS2ACallCredentials );
660+ }
651661 builder = Grpc .newChannelBuilder (endpoint , channelCredentials );
652662 } else {
653663 // Use default if we cannot initialize channel credentials via DCA or S2A.
@@ -812,18 +822,20 @@ public static final class Builder {
812822 @ Nullable private Boolean keepAliveWithoutCalls ;
813823 @ Nullable private ApiFunction <ManagedChannelBuilder , ManagedChannelBuilder > channelConfigurator ;
814824 @ Nullable private Credentials credentials ;
825+ @ Nullable private CallCredentials mtlsS2ACallCredentials ;
815826 @ Nullable private ChannelPrimer channelPrimer ;
816827 private ChannelPoolSettings channelPoolSettings ;
817828 @ Nullable private Boolean attemptDirectPath ;
818829 @ Nullable private Boolean attemptDirectPathXds ;
819830 @ Nullable private Boolean allowNonDefaultServiceAccount ;
820831 @ Nullable private ImmutableMap <String , ?> directPathServiceConfig ;
821- @ Nullable private List <HardBoundTokenTypes > allowedHardBoundTokenTypes ;
832+ private List <HardBoundTokenTypes > allowedHardBoundTokenTypes ;
822833
823834 private Builder () {
824835 processorCount = Runtime .getRuntime ().availableProcessors ();
825836 envProvider = System ::getenv ;
826837 channelPoolSettings = ChannelPoolSettings .staticallySized (1 );
838+ allowedHardBoundTokenTypes = new ArrayList <>();
827839 }
828840
829841 private Builder (InstantiatingGrpcChannelProvider provider ) {
@@ -841,11 +853,13 @@ private Builder(InstantiatingGrpcChannelProvider provider) {
841853 this .keepAliveWithoutCalls = provider .keepAliveWithoutCalls ;
842854 this .channelConfigurator = provider .channelConfigurator ;
843855 this .credentials = provider .credentials ;
856+ this .mtlsS2ACallCredentials = provider .mtlsS2ACallCredentials ;
844857 this .channelPrimer = provider .channelPrimer ;
845858 this .channelPoolSettings = provider .channelPoolSettings ;
846859 this .attemptDirectPath = provider .attemptDirectPath ;
847860 this .attemptDirectPathXds = provider .attemptDirectPathXds ;
848861 this .allowNonDefaultServiceAccount = provider .allowNonDefaultServiceAccount ;
862+ this .allowedHardBoundTokenTypes = provider .allowedHardBoundTokenTypes ;
849863 this .directPathServiceConfig = provider .directPathServiceConfig ;
850864 this .mtlsProvider = provider .mtlsProvider ;
851865 this .s2aConfigProvider = provider .s2aConfigProvider ;
@@ -914,7 +928,10 @@ Builder setUseS2A(boolean useS2A) {
914928 */
915929 @ InternalApi
916930 public Builder setAllowHardBoundTokenTypes (List <HardBoundTokenTypes > allowedValues ) {
917- this .allowedHardBoundTokenTypes = allowedValues ;
931+ this .allowedHardBoundTokenTypes =
932+ Preconditions .checkNotNull (
933+ allowedValues , "List of allowed HardBoundTokenTypes cannot be null" );
934+ ;
918935 return this ;
919936 }
920937
@@ -1133,7 +1150,50 @@ public Builder setDirectPathServiceConfig(Map<String, ?> serviceConfig) {
11331150 return this ;
11341151 }
11351152
1153+ boolean isMtlsS2AHardBoundTokensEnabled () {
1154+ // If S2A cannot be used, the list of allowed hard bound token types is empty or doesn't
1155+ // contain
1156+ // {@code HardBoundTokenTypes.MTLS_S2A}, the {@code credentials} are null or not of type
1157+ // {@code
1158+ // ComputeEngineCredentials} then {@code HardBoundTokenTypes.MTLS_S2A} hard bound tokens
1159+ // should
1160+ // not
1161+ // be used. {@code HardBoundTokenTypes.MTLS_S2A} hard bound tokens can only be used on MTLS
1162+ // channels established using S2A and when tokens from MDS (i.e {@code
1163+ // ComputeEngineCredentials}
1164+ // are being used.
1165+ if (!this .useS2A
1166+ || this .allowedHardBoundTokenTypes .isEmpty ()
1167+ || this .credentials == null
1168+ || !(this .credentials instanceof ComputeEngineCredentials )) {
1169+ return false ;
1170+ }
1171+ return allowedHardBoundTokenTypes .stream ()
1172+ .anyMatch (val -> val .equals (HardBoundTokenTypes .MTLS_S2A ));
1173+ }
1174+
1175+ CallCredentials createHardBoundTokensCallCredentials (
1176+ ComputeEngineCredentials .GoogleAuthTransport googleAuthTransport ,
1177+ ComputeEngineCredentials .BindingEnforcement bindingEnforcement ) {
1178+ // We only set scopes and HTTP transport factory from the original credentials because
1179+ // only those are used in gRPC CallCredentials to fetch request metadata.
1180+ return MoreCallCredentials .from (
1181+ ((ComputeEngineCredentials ) this .credentials )
1182+ .toBuilder ()
1183+ .setGoogleAuthTransport (googleAuthTransport )
1184+ .setBindingEnforcement (bindingEnforcement )
1185+ .build ());
1186+ }
1187+
11361188 public InstantiatingGrpcChannelProvider build () {
1189+ if (isMtlsS2AHardBoundTokensEnabled ()) {
1190+ // Set a {@code ComputeEngineCredentials} instance to be per-RPC call credentials,
1191+ // which will be used to fetch MTLS_S2A hard bound tokens from the metdata server.
1192+ this .mtlsS2ACallCredentials =
1193+ createHardBoundTokensCallCredentials (
1194+ ComputeEngineCredentials .GoogleAuthTransport .MTLS ,
1195+ ComputeEngineCredentials .BindingEnforcement .ON );
1196+ }
11371197 InstantiatingGrpcChannelProvider instantiatingGrpcChannelProvider =
11381198 new InstantiatingGrpcChannelProvider (this );
11391199 instantiatingGrpcChannelProvider .removeApiKeyCredentialDuplicateHeaders ();
0 commit comments