|  | 
| 20 | 20 | import static software.amazon.awssdk.http.auth.aws.internal.signer.util.SignerConstant.STREAMING_ECDSA_SIGNED_PAYLOAD; | 
| 21 | 21 | import static software.amazon.awssdk.http.auth.aws.internal.signer.util.SignerConstant.STREAMING_ECDSA_SIGNED_PAYLOAD_TRAILER; | 
| 22 | 22 | import static software.amazon.awssdk.http.auth.aws.internal.signer.util.SignerConstant.STREAMING_UNSIGNED_PAYLOAD_TRAILER; | 
|  | 23 | +import static software.amazon.awssdk.http.auth.aws.internal.signer.util.SignerConstant.X_AMZ_DECODED_CONTENT_LENGTH; | 
| 23 | 24 | import static software.amazon.awssdk.http.auth.aws.internal.signer.util.SignerConstant.X_AMZ_TRAILER; | 
|  | 25 | +import static software.amazon.awssdk.http.auth.aws.internal.signer.util.SignerUtils.computeAndMoveContentLength; | 
| 24 | 26 | 
 | 
| 25 | 27 | import java.io.InputStream; | 
|  | 28 | +import java.nio.ByteBuffer; | 
| 26 | 29 | import java.nio.charset.StandardCharsets; | 
| 27 | 30 | import java.util.ArrayList; | 
| 28 | 31 | import java.util.Collections; | 
| 29 | 32 | import java.util.List; | 
|  | 33 | +import java.util.Optional; | 
|  | 34 | +import java.util.concurrent.CompletableFuture; | 
|  | 35 | +import org.reactivestreams.Publisher; | 
| 30 | 36 | import software.amazon.awssdk.annotations.SdkInternalApi; | 
| 31 | 37 | import software.amazon.awssdk.checksums.SdkChecksum; | 
| 32 | 38 | import software.amazon.awssdk.checksums.spi.ChecksumAlgorithm; | 
|  | 
| 35 | 41 | import software.amazon.awssdk.http.SdkHttpRequest; | 
| 36 | 42 | import software.amazon.awssdk.http.auth.aws.internal.signer.CredentialScope; | 
| 37 | 43 | import software.amazon.awssdk.http.auth.aws.internal.signer.NoOpPayloadChecksumStore; | 
|  | 44 | +import software.amazon.awssdk.http.auth.aws.internal.signer.chunkedencoding.AsyncChunkEncodedPayload; | 
| 38 | 45 | import software.amazon.awssdk.http.auth.aws.internal.signer.chunkedencoding.ChecksumTrailerProvider; | 
| 39 | 46 | import software.amazon.awssdk.http.auth.aws.internal.signer.chunkedencoding.ChunkedEncodedInputStream; | 
|  | 47 | +import software.amazon.awssdk.http.auth.aws.internal.signer.chunkedencoding.ChunkedEncodedPayload; | 
|  | 48 | +import software.amazon.awssdk.http.auth.aws.internal.signer.chunkedencoding.ChunkedEncodedPublisher; | 
|  | 49 | +import software.amazon.awssdk.http.auth.aws.internal.signer.chunkedencoding.SyncChunkEncodedPayload; | 
| 40 | 50 | import software.amazon.awssdk.http.auth.aws.internal.signer.chunkedencoding.TrailerProvider; | 
| 41 | 51 | import software.amazon.awssdk.http.auth.aws.internal.signer.io.ChecksumInputStream; | 
| 42 | 52 | import software.amazon.awssdk.http.auth.aws.internal.signer.io.ResettableContentStreamProvider; | 
| @@ -83,39 +93,73 @@ public ContentStreamProvider sign(ContentStreamProvider payload, V4aRequestSigni | 
| 83 | 93 |             .chunkSize(chunkSize) | 
| 84 | 94 |             .header(chunk -> Integer.toHexString(chunk.remaining()).getBytes(StandardCharsets.UTF_8)); | 
| 85 | 95 | 
 | 
| 86 |  | -        preExistingTrailers.forEach(trailer -> chunkedEncodedInputStreamBuilder.addTrailer(() -> trailer)); | 
|  | 96 | +        SyncChunkEncodedPayload chunkedPayload = new SyncChunkEncodedPayload(chunkedEncodedInputStreamBuilder); | 
|  | 97 | + | 
|  | 98 | +        signCommon(chunkedPayload, requestSigningResult); | 
|  | 99 | + | 
|  | 100 | +        return new ResettableContentStreamProvider(chunkedEncodedInputStreamBuilder::build); | 
|  | 101 | +    } | 
|  | 102 | + | 
|  | 103 | +    /** | 
|  | 104 | +     * Given a payload and result of request signing, sign the payload via the SigV4 process. | 
|  | 105 | +     */ | 
|  | 106 | +    @Override | 
|  | 107 | +    public Publisher<ByteBuffer> signAsync(Publisher<ByteBuffer> payload, V4aRequestSigningResult requestSigningResult) { | 
|  | 108 | +        ChunkedEncodedPublisher.Builder chunkedStreamBuilder = ChunkedEncodedPublisher.builder() | 
|  | 109 | +                                                                                      .publisher(payload) | 
|  | 110 | +                                                                                      .chunkSize(chunkSize) | 
|  | 111 | +                                                                                      .addEmptyTrailingChunk(true); | 
|  | 112 | +        AsyncChunkEncodedPayload chunkedPayload = new AsyncChunkEncodedPayload(chunkedStreamBuilder); | 
|  | 113 | + | 
|  | 114 | +        signCommon(chunkedPayload, requestSigningResult); | 
|  | 115 | + | 
|  | 116 | +        return chunkedStreamBuilder.build(); | 
|  | 117 | +    } | 
|  | 118 | + | 
|  | 119 | +    private ChunkedEncodedPayload signCommon(ChunkedEncodedPayload payload, V4aRequestSigningResult requestSigningResult) { | 
|  | 120 | +        SdkHttpRequest.Builder request = requestSigningResult.getSignedRequest(); | 
|  | 121 | + | 
|  | 122 | +        payload.decodedContentLength(request.firstMatchingHeader(X_AMZ_DECODED_CONTENT_LENGTH) | 
|  | 123 | +                                            .map(Long::parseLong) | 
|  | 124 | +                                            .orElseThrow(() -> { | 
|  | 125 | +                                                String msg = String.format("Expected header '%s' to be present", | 
|  | 126 | +                                                                           X_AMZ_DECODED_CONTENT_LENGTH); | 
|  | 127 | +                                                return new RuntimeException(msg); | 
|  | 128 | +                                            })); | 
|  | 129 | + | 
|  | 130 | +        preExistingTrailers.forEach(trailer -> payload.addTrailer(() -> trailer)); | 
| 87 | 131 | 
 | 
| 88 | 132 |         switch (requestSigningResult.getSigningConfig().getSignedBodyValue()) { | 
| 89 | 133 |             case STREAMING_ECDSA_SIGNED_PAYLOAD: { | 
| 90 | 134 |                 RollingSigner rollingSigner = new RollingSigner(requestSigningResult.getSignature(), | 
| 91 | 135 |                                                                 requestSigningResult.getSigningConfig()); | 
| 92 |  | -                chunkedEncodedInputStreamBuilder.addExtension(new SigV4aChunkExtensionProvider(rollingSigner, credentialScope)); | 
|  | 136 | +                payload.addExtension(new SigV4aChunkExtensionProvider(rollingSigner, credentialScope)); | 
| 93 | 137 |                 break; | 
| 94 | 138 |             } | 
| 95 | 139 |             case STREAMING_UNSIGNED_PAYLOAD_TRAILER: | 
| 96 |  | -                setupChecksumTrailerIfNeeded(chunkedEncodedInputStreamBuilder); | 
|  | 140 | +                setupChecksumTrailerIfNeeded(payload); | 
| 97 | 141 |                 break; | 
| 98 | 142 |             case STREAMING_ECDSA_SIGNED_PAYLOAD_TRAILER: { | 
| 99 | 143 |                 RollingSigner rollingSigner = new RollingSigner(requestSigningResult.getSignature(), | 
| 100 | 144 |                                                                 requestSigningResult.getSigningConfig()); | 
| 101 |  | -                chunkedEncodedInputStreamBuilder.addExtension(new SigV4aChunkExtensionProvider(rollingSigner, credentialScope)); | 
| 102 |  | -                setupChecksumTrailerIfNeeded(chunkedEncodedInputStreamBuilder); | 
| 103 |  | -                chunkedEncodedInputStreamBuilder.addTrailer( | 
| 104 |  | -                    new SigV4aTrailerProvider(chunkedEncodedInputStreamBuilder.trailers(), rollingSigner, credentialScope) | 
|  | 145 | +                payload.addExtension(new SigV4aChunkExtensionProvider(rollingSigner, credentialScope)); | 
|  | 146 | +                setupChecksumTrailerIfNeeded(payload); | 
|  | 147 | +                payload.addTrailer( | 
|  | 148 | +                    new SigV4aTrailerProvider(payload.trailers(), rollingSigner, credentialScope) | 
| 105 | 149 |                 ); | 
| 106 | 150 |                 break; | 
| 107 | 151 |             } | 
| 108 | 152 |             default: | 
| 109 | 153 |                 throw new UnsupportedOperationException(); | 
| 110 | 154 |         } | 
| 111 | 155 | 
 | 
| 112 |  | -        return new ResettableContentStreamProvider(chunkedEncodedInputStreamBuilder::build); | 
|  | 156 | +        return payload; | 
| 113 | 157 |     } | 
| 114 | 158 | 
 | 
| 115 | 159 |     @Override | 
| 116 | 160 |     public void beforeSigning(SdkHttpRequest.Builder request, ContentStreamProvider payload, String checksum) { | 
| 117 | 161 |         long encodedContentLength = 0; | 
| 118 |  | -        long contentLength = SignerUtils.computeAndMoveContentLength(request, payload); | 
|  | 162 | +        long contentLength = computeAndMoveContentLength(request, payload); | 
| 119 | 163 |         setupPreExistingTrailers(request); | 
| 120 | 164 | 
 | 
| 121 | 165 |         // pre-existing trailers | 
| @@ -157,6 +201,72 @@ public void beforeSigning(SdkHttpRequest.Builder request, ContentStreamProvider | 
| 157 | 201 |         // CRT-signed request doesn't expect 'aws-chunked' Content-Encoding, so we don't add it | 
| 158 | 202 |     } | 
| 159 | 203 | 
 | 
|  | 204 | +    @Override | 
|  | 205 | +    public CompletableFuture<Pair<SdkHttpRequest.Builder, Optional<Publisher<ByteBuffer>>>> beforeSigningAsync( | 
|  | 206 | +        SdkHttpRequest.Builder request, Publisher<ByteBuffer> payload, String checksum) { | 
|  | 207 | + | 
|  | 208 | +        return SignerUtils.moveContentLength(request, payload) | 
|  | 209 | +                          .thenApply(p -> { | 
|  | 210 | +                              SdkHttpRequest.Builder requestBuilder = p.left(); | 
|  | 211 | +                              setupPreExistingTrailers(requestBuilder); | 
|  | 212 | + | 
|  | 213 | +                              long decodedContentLength = | 
|  | 214 | +                                  requestBuilder.firstMatchingHeader(X_AMZ_DECODED_CONTENT_LENGTH) | 
|  | 215 | +                                                .map(Long::parseLong) | 
|  | 216 | +                                                // should not happen, this header is added by | 
|  | 217 | +                                                // moveContentLength | 
|  | 218 | +                                                .orElseThrow(() -> new RuntimeException( | 
|  | 219 | +                                                    X_AMZ_DECODED_CONTENT_LENGTH + " header not present")); | 
|  | 220 | + | 
|  | 221 | +                              long encodedContentLength = calculateEncodedContentLength(request, decodedContentLength, checksum); | 
|  | 222 | + | 
|  | 223 | +                              if (checksumAlgorithm != null) { | 
|  | 224 | +                                  String checksumHeaderName = checksumHeaderName(checksumAlgorithm); | 
|  | 225 | +                                  request.appendHeader(X_AMZ_TRAILER, checksumHeaderName); | 
|  | 226 | +                              } | 
|  | 227 | +                              request.putHeader(Header.CONTENT_LENGTH, Long.toString(encodedContentLength)); | 
|  | 228 | + | 
|  | 229 | +                              return Pair.of(requestBuilder, p.right()); | 
|  | 230 | +                          }); | 
|  | 231 | +    } | 
|  | 232 | + | 
|  | 233 | +    private long calculateEncodedContentLength(SdkHttpRequest.Builder requestBuilder, long decodedContentLength, | 
|  | 234 | +                                               String checksum) { | 
|  | 235 | +        long encodedContentLength = 0; | 
|  | 236 | + | 
|  | 237 | +        encodedContentLength += calculateExistingTrailersLength(); | 
|  | 238 | + | 
|  | 239 | +        switch (checksum) { | 
|  | 240 | +            case STREAMING_ECDSA_SIGNED_PAYLOAD: { | 
|  | 241 | +                long extensionsLength = 161; // ;chunk-signature:<sigv4a-ecsda hex signature, 144 bytes> | 
|  | 242 | +                encodedContentLength += calculateChunksLength(decodedContentLength, extensionsLength); | 
|  | 243 | +                break; | 
|  | 244 | +            } | 
|  | 245 | +            case STREAMING_UNSIGNED_PAYLOAD_TRAILER: | 
|  | 246 | +                if (checksumAlgorithm != null) { | 
|  | 247 | +                    encodedContentLength += calculateChecksumTrailerLength(checksumHeaderName(checksumAlgorithm)); | 
|  | 248 | +                } | 
|  | 249 | +                encodedContentLength += calculateChunksLength(decodedContentLength, 0); | 
|  | 250 | +                break; | 
|  | 251 | +            case STREAMING_ECDSA_SIGNED_PAYLOAD_TRAILER: { | 
|  | 252 | +                long extensionsLength = 161; // ;chunk-signature:<sigv4a-ecsda hex signature, 144 bytes> | 
|  | 253 | +                encodedContentLength += calculateChunksLength(decodedContentLength, extensionsLength); | 
|  | 254 | +                if (checksumAlgorithm != null) { | 
|  | 255 | +                    encodedContentLength += calculateChecksumTrailerLength(checksumHeaderName(checksumAlgorithm)); | 
|  | 256 | +                } | 
|  | 257 | +                encodedContentLength += 170; // x-amz-trailer-signature:<sigv4a-ecsda hex signature, 144 bytes>\r\n | 
|  | 258 | +                break; | 
|  | 259 | +            } | 
|  | 260 | +            default: | 
|  | 261 | +                throw new UnsupportedOperationException(); | 
|  | 262 | +        } | 
|  | 263 | + | 
|  | 264 | +        // terminating \r\n | 
|  | 265 | +        encodedContentLength += 2; | 
|  | 266 | + | 
|  | 267 | +        return encodedContentLength; | 
|  | 268 | +    } | 
|  | 269 | + | 
| 160 | 270 |     /** | 
| 161 | 271 |      * Set up a map of pre-existing trailer (headers) for the given request to be used when chunk-encoding the payload. | 
| 162 | 272 |      * <p> | 
| @@ -270,6 +380,30 @@ private void setupChecksumTrailerIfNeeded(ChunkedEncodedInputStream.Builder buil | 
| 270 | 380 |         builder.inputStream(checksumInputStream).addTrailer(checksumTrailer); | 
| 271 | 381 |     } | 
| 272 | 382 | 
 | 
|  | 383 | +    private void setupChecksumTrailerIfNeeded(ChunkedEncodedPayload payload) { | 
|  | 384 | +        if (checksumAlgorithm == null) { | 
|  | 385 | +            return; | 
|  | 386 | +        } | 
|  | 387 | +        String checksumHeaderName = checksumHeaderName(checksumAlgorithm); | 
|  | 388 | + | 
|  | 389 | +        String cachedChecksum = getCachedChecksum(); | 
|  | 390 | + | 
|  | 391 | +        if (cachedChecksum != null) { | 
|  | 392 | +            LOG.debug(() -> String.format("Cached payload checksum available for algorithm %s: %s. Using cached value", | 
|  | 393 | +                                          checksumAlgorithm.algorithmId(), checksumHeaderName)); | 
|  | 394 | +            payload.addTrailer(() -> Pair.of(checksumHeaderName, Collections.singletonList(cachedChecksum))); | 
|  | 395 | +            return; | 
|  | 396 | +        } | 
|  | 397 | + | 
|  | 398 | +        SdkChecksum sdkChecksum = fromChecksumAlgorithm(checksumAlgorithm); | 
|  | 399 | +        payload.checksumPayload(sdkChecksum); | 
|  | 400 | + | 
|  | 401 | +        TrailerProvider checksumTrailer = | 
|  | 402 | +            new ChecksumTrailerProvider(sdkChecksum, checksumHeaderName, checksumAlgorithm, payloadChecksumStore); | 
|  | 403 | + | 
|  | 404 | +        payload.addTrailer(checksumTrailer); | 
|  | 405 | +    } | 
|  | 406 | + | 
| 273 | 407 |     private String getCachedChecksum() { | 
| 274 | 408 |         byte[] checksumBytes = payloadChecksumStore.getChecksumValue(checksumAlgorithm); | 
| 275 | 409 |         if (checksumBytes != null) { | 
|  | 
0 commit comments