Skip to content

Commit b5192aa

Browse files
authored
Enable direct cancellation for IHttpForwarder, transforms #1542 (#1985)
1 parent c253f4a commit b5192aa

File tree

11 files changed

+153
-27
lines changed

11 files changed

+153
-27
lines changed

src/ReverseProxy/Forwarder/ForwarderRequestConfig.cs

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,11 +3,13 @@
33

44
using System;
55
using System.Net.Http;
6+
using System.Threading;
7+
using Microsoft.AspNetCore.Http;
68

79
namespace Yarp.ReverseProxy.Forwarder;
810

911
/// <summary>
10-
/// Config for <see cref="IHttpForwarder.SendAsync"/>
12+
/// Config for <see cref="IHttpForwarder.SendAsync(HttpContext, string, HttpMessageInvoker, ForwarderRequestConfig, HttpTransformer, CancellationToken)"/>
1113
/// </summary>
1214
public sealed record ForwarderRequestConfig
1315
{

src/ReverseProxy/Forwarder/HttpForwarder.cs

Lines changed: 21 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -83,12 +83,21 @@ public HttpForwarder(ILogger<HttpForwarder> logger, IClock clock)
8383
/// ASP .NET Core (Kestrel) will finally send response trailers (if any)
8484
/// after we complete the steps above and relinquish control.
8585
/// </remarks>
86-
public async ValueTask<ForwarderError> SendAsync(
86+
public ValueTask<ForwarderError> SendAsync(
8787
HttpContext context,
8888
string destinationPrefix,
8989
HttpMessageInvoker httpClient,
9090
ForwarderRequestConfig requestConfig,
9191
HttpTransformer transformer)
92+
=> SendAsync(context, destinationPrefix, httpClient, requestConfig, transformer, CancellationToken.None);
93+
94+
public async ValueTask<ForwarderError> SendAsync(
95+
HttpContext context,
96+
string destinationPrefix,
97+
HttpMessageInvoker httpClient,
98+
ForwarderRequestConfig requestConfig,
99+
HttpTransformer transformer,
100+
CancellationToken cancellationToken)
92101
{
93102
_ = context ?? throw new ArgumentNullException(nameof(context));
94103
_ = destinationPrefix ?? throw new ArgumentNullException(nameof(destinationPrefix));
@@ -110,7 +119,7 @@ public async ValueTask<ForwarderError> SendAsync(
110119

111120
ForwarderTelemetry.Log.ForwarderStart(destinationPrefix);
112121

113-
var activityCancellationSource = ActivityCancellationTokenSource.Rent(requestConfig?.ActivityTimeout ?? DefaultTimeout, context.RequestAborted);
122+
var activityCancellationSource = ActivityCancellationTokenSource.Rent(requestConfig?.ActivityTimeout ?? DefaultTimeout, context.RequestAborted, cancellationToken);
114123
try
115124
{
116125
var isClientHttp2OrGreater = ProtocolHelper.IsHttp2OrGreater(context.Request.Protocol);
@@ -193,7 +202,7 @@ public async ValueTask<ForwarderError> SendAsync(
193202
{
194203
// :: Step 5: Copy response status line Client ◄-- Proxy ◄-- Destination
195204
// :: Step 6: Copy response headers Client ◄-- Proxy ◄-- Destination
196-
var copyBody = await CopyResponseStatusAndHeadersAsync(destinationResponse, context, transformer);
205+
var copyBody = await CopyResponseStatusAndHeadersAsync(destinationResponse, context, transformer, activityCancellationSource.Token);
197206

198207
if (!copyBody)
199208
{
@@ -260,7 +269,7 @@ public async ValueTask<ForwarderError> SendAsync(
260269
}
261270

262271
// :: Step 8: Copy response trailer headers and finish response Client ◄-- Proxy ◄-- Destination
263-
await CopyResponseTrailingHeadersAsync(destinationResponse, context, transformer);
272+
await CopyResponseTrailingHeadersAsync(destinationResponse, context, transformer, activityCancellationSource.Token);
264273

265274
if (isStreamingRequest)
266275
{
@@ -402,14 +411,17 @@ public async ValueTask<ForwarderError> SendAsync(
402411
destinationRequest.Content = requestContent;
403412

404413
// :: Step 3: Copy request headers Client --► Proxy --► Destination
405-
await transformer.TransformRequestAsync(context, destinationRequest, destinationPrefix);
414+
await transformer.TransformRequestAsync(context, destinationRequest, destinationPrefix, activityToken.Token);
406415

407416
// The transformer generated a response, do not forward.
408417
if (RequestUtilities.IsResponseSet(context.Response))
409418
{
410419
return (destinationRequest, requestContent, false);
411420
}
412421

422+
// Transforms may have taken a while, especially if they buffered the body, they count as forward progress.
423+
activityToken.ResetTimeout();
424+
413425
FixupUpgradeRequestHeaders(context, destinationRequest, outgoingUpgrade, outgoingConnect);
414426

415427
// Allow someone to custom build the request uri, otherwise provide a default for them.
@@ -653,7 +665,7 @@ async ValueTask<ForwarderError> ReportErrorAsync(ForwarderError error, int statu
653665
}
654666
}
655667

656-
private static ValueTask<bool> CopyResponseStatusAndHeadersAsync(HttpResponseMessage source, HttpContext context, HttpTransformer transformer)
668+
private static ValueTask<bool> CopyResponseStatusAndHeadersAsync(HttpResponseMessage source, HttpContext context, HttpTransformer transformer, CancellationToken cancellationToken)
657669
{
658670
context.Response.StatusCode = (int)source.StatusCode;
659671

@@ -667,7 +679,7 @@ private static ValueTask<bool> CopyResponseStatusAndHeadersAsync(HttpResponseMes
667679
}
668680

669681
// Copies headers
670-
return transformer.TransformResponseAsync(context, source);
682+
return transformer.TransformResponseAsync(context, source, cancellationToken);
671683
}
672684

673685
private async ValueTask<ForwarderError> HandleUpgradedResponse(HttpContext context, HttpResponseMessage destinationResponse,
@@ -891,10 +903,10 @@ private async ValueTask<ForwarderError> HandleResponseBodyErrorAsync(HttpContext
891903
return error;
892904
}
893905

894-
private static ValueTask CopyResponseTrailingHeadersAsync(HttpResponseMessage source, HttpContext context, HttpTransformer transformer)
906+
private static ValueTask CopyResponseTrailingHeadersAsync(HttpResponseMessage source, HttpContext context, HttpTransformer transformer, CancellationToken cancellationToken)
895907
{
896908
// Copies trailers
897-
return transformer.TransformResponseTrailersAsync(context, source);
909+
return transformer.TransformResponseTrailersAsync(context, source, cancellationToken);
898910
}
899911

900912

src/ReverseProxy/Forwarder/HttpTransformer.cs

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
using System.Net.Http;
88
using System.Net.Http.Headers;
99
using System.Runtime.CompilerServices;
10+
using System.Threading;
1011
using System.Threading.Tasks;
1112
using Microsoft.AspNetCore.Http;
1213
using Microsoft.AspNetCore.Http.Features;
@@ -53,6 +54,23 @@ private static bool IsBodylessStatusCode(HttpStatusCode statusCode) =>
5354
_ => false
5455
};
5556

57+
/// <summary>
58+
/// A callback that is invoked prior to sending the proxied request. All HttpRequestMessage fields are
59+
/// initialized except RequestUri, which will be initialized after the callback if no value is provided.
60+
/// See <see cref="RequestUtilities.MakeDestinationAddress(string, PathString, QueryString)"/> for constructing a custom request Uri.
61+
/// The string parameter represents the destination URI prefix that should be used when constructing the RequestUri.
62+
/// The headers are copied by the base implementation, excluding some protocol headers like HTTP/2 pseudo headers (":authority").
63+
/// This method may be overridden to conditionally produce a response, such as for error conditions, and prevent the request from
64+
/// being proxied. This is indicated by setting the `HttpResponse.StatusCode` to a value other than 200, or calling `HttpResponse.StartAsync()`,
65+
/// or writing to the `HttpResponse.Body` or `BodyWriter`.
66+
/// </summary>
67+
/// <param name="httpContext">The incoming request.</param>
68+
/// <param name="proxyRequest">The outgoing proxy request.</param>
69+
/// <param name="destinationPrefix">The uri prefix for the selected destination server which can be used to create the RequestUri.</param>
70+
/// <param name="cancellationToken">Indicates that the request is being canceled.</param>
71+
public virtual ValueTask TransformRequestAsync(HttpContext httpContext, HttpRequestMessage proxyRequest, string destinationPrefix, CancellationToken cancellationToken)
72+
=> TransformRequestAsync(httpContext, proxyRequest, destinationPrefix);
73+
5674
/// <summary>
5775
/// A callback that is invoked prior to sending the proxied request. All HttpRequestMessage fields are
5876
/// initialized except RequestUri, which will be initialized after the callback if no value is provided.
@@ -126,9 +144,24 @@ public virtual ValueTask TransformRequestAsync(HttpContext httpContext, HttpRequ
126144
/// </summary>
127145
/// <param name="httpContext">The incoming request.</param>
128146
/// <param name="proxyResponse">The response from the destination. This can be null if the destination did not respond.</param>
147+
/// <param name="cancellationToken">Indicates that the request is being canceled.</param>
129148
/// <returns>A bool indicating if the response should be proxied to the client or not. A derived implementation
130149
/// that returns false may send an alternate response inline or return control to the caller for it to retry, respond,
131150
/// etc.</returns>
151+
public virtual ValueTask<bool> TransformResponseAsync(HttpContext httpContext, HttpResponseMessage? proxyResponse, CancellationToken cancellationToken)
152+
=> TransformResponseAsync(httpContext, proxyResponse);
153+
154+
/// <summary>
155+
/// A callback that is invoked when the proxied response is received. The status code and reason phrase will be copied
156+
/// to the HttpContext.Response before the callback is invoked, but may still be modified there. The headers will be
157+
/// copied to HttpContext.Response.Headers by the base implementation, excludes certain protocol headers like
158+
/// `Transfer-Encoding: chunked`.
159+
/// </summary>
160+
/// <param name="httpContext">The incoming request.</param>
161+
/// <param name="proxyResponse">The response from the destination. This can be null if the destination did not respond.</param>
162+
/// <returns>A bool indicating if the response should be proxied to the client or not. A derived implementation
163+
/// that returns false may send an alternate response inline or return control to the caller for it to retry, respond,
164+
/// etc.</returns>
132165
public virtual ValueTask<bool> TransformResponseAsync(HttpContext httpContext, HttpResponseMessage? proxyResponse)
133166
{
134167
if (proxyResponse is null)
@@ -171,6 +204,16 @@ public virtual ValueTask<bool> TransformResponseAsync(HttpContext httpContext, H
171204
return new ValueTask<bool>(true);
172205
}
173206

207+
/// <summary>
208+
/// A callback that is invoked after the response body to modify trailers, if supported. The trailers will be
209+
/// copied to the HttpContext.Response by the base implementation.
210+
/// </summary>
211+
/// <param name="httpContext">The incoming request.</param>
212+
/// <param name="proxyResponse">The response from the destination.</param>
213+
/// <param name="cancellationToken">Indicates that the request is being canceled.</param>
214+
public virtual ValueTask TransformResponseTrailersAsync(HttpContext httpContext, HttpResponseMessage proxyResponse, CancellationToken cancellationToken)
215+
=> TransformResponseTrailersAsync(httpContext, proxyResponse);
216+
174217
/// <summary>
175218
/// A callback that is invoked after the response body to modify trailers, if supported. The trailers will be
176219
/// copied to the HttpContext.Response by the base implementation.

src/ReverseProxy/Forwarder/IHttpForwarder.cs

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
// Licensed under the MIT License.
33

44
using System.Net.Http;
5+
using System.Threading;
56
using System.Threading.Tasks;
67
using Microsoft.AspNetCore.Http;
78

@@ -24,4 +25,19 @@ public interface IHttpForwarder
2425
/// <returns>The result of forwarding the request and response.</returns>
2526
ValueTask<ForwarderError> SendAsync(HttpContext context, string destinationPrefix, HttpMessageInvoker httpClient,
2627
ForwarderRequestConfig requestConfig, HttpTransformer transformer);
28+
29+
/// <summary>
30+
/// Forwards the incoming request to the destination server, and the response back to the client.
31+
/// </summary>
32+
/// <param name="context">The HttpContext to forward.</param>
33+
/// <param name="destinationPrefix">The url prefix for where to forward the request to.</param>
34+
/// <param name="httpClient">The HTTP client used to forward the request.</param>
35+
/// <param name="requestConfig">Config for the outgoing request.</param>
36+
/// <param name="transformer">Request and response transforms. Use <see cref="HttpTransformer.Default"/> if
37+
/// custom transformations are not needed.</param>
38+
/// <param name="cancellationToken">A cancellation token that can be used to abort the request.</param>
39+
/// <returns>The result of forwarding the request and response.</returns>
40+
ValueTask<ForwarderError> SendAsync(HttpContext context, string destinationPrefix, HttpMessageInvoker httpClient,
41+
ForwarderRequestConfig requestConfig, HttpTransformer transformer, CancellationToken cancellationToken)
42+
=> SendAsync(context, destinationPrefix, httpClient, requestConfig, transformer);
2743
}

src/ReverseProxy/Transforms/Builder/StructuredTransformer.cs

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
using System.Collections.Generic;
66
using System.Linq;
77
using System.Net.Http;
8+
using System.Threading;
89
using System.Threading.Tasks;
910
using Microsoft.AspNetCore.Http;
1011
using Microsoft.AspNetCore.Http.Features;
@@ -63,11 +64,11 @@ internal StructuredTransformer(bool? copyRequestHeaders, bool? copyResponseHeade
6364
/// </summary>
6465
internal ResponseTrailersTransform[] ResponseTrailerTransforms { get; }
6566

66-
public override async ValueTask TransformRequestAsync(HttpContext httpContext, HttpRequestMessage proxyRequest, string destinationPrefix)
67+
public override async ValueTask TransformRequestAsync(HttpContext httpContext, HttpRequestMessage proxyRequest, string destinationPrefix, CancellationToken cancellationToken)
6768
{
6869
if (ShouldCopyRequestHeaders.GetValueOrDefault(true))
6970
{
70-
await base.TransformRequestAsync(httpContext, proxyRequest, destinationPrefix);
71+
await base.TransformRequestAsync(httpContext, proxyRequest, destinationPrefix, cancellationToken);
7172
}
7273

7374
if (RequestTransforms.Length == 0)
@@ -83,6 +84,7 @@ public override async ValueTask TransformRequestAsync(HttpContext httpContext, H
8384
Path = httpContext.Request.Path,
8485
Query = new QueryTransformContext(httpContext.Request),
8586
HeadersCopied = ShouldCopyRequestHeaders.GetValueOrDefault(true),
87+
CancellationToken = cancellationToken,
8688
};
8789

8890
foreach (var requestTransform in RequestTransforms)
@@ -101,11 +103,11 @@ public override async ValueTask TransformRequestAsync(HttpContext httpContext, H
101103
transformContext.DestinationPrefix, transformContext.Path, transformContext.Query.QueryString);
102104
}
103105

104-
public override async ValueTask<bool> TransformResponseAsync(HttpContext httpContext, HttpResponseMessage? proxyResponse)
106+
public override async ValueTask<bool> TransformResponseAsync(HttpContext httpContext, HttpResponseMessage? proxyResponse, CancellationToken cancellationToken)
105107
{
106108
if (ShouldCopyResponseHeaders.GetValueOrDefault(true))
107109
{
108-
await base.TransformResponseAsync(httpContext, proxyResponse);
110+
await base.TransformResponseAsync(httpContext, proxyResponse, cancellationToken);
109111
}
110112

111113
if (ResponseTransforms.Length == 0)
@@ -118,6 +120,7 @@ public override async ValueTask<bool> TransformResponseAsync(HttpContext httpCon
118120
HttpContext = httpContext,
119121
ProxyResponse = proxyResponse,
120122
HeadersCopied = ShouldCopyResponseHeaders.GetValueOrDefault(true),
123+
CancellationToken = cancellationToken,
121124
};
122125

123126
foreach (var responseTransform in ResponseTransforms)
@@ -128,11 +131,11 @@ public override async ValueTask<bool> TransformResponseAsync(HttpContext httpCon
128131
return !transformContext.SuppressResponseBody;
129132
}
130133

131-
public override async ValueTask TransformResponseTrailersAsync(HttpContext httpContext, HttpResponseMessage proxyResponse)
134+
public override async ValueTask TransformResponseTrailersAsync(HttpContext httpContext, HttpResponseMessage proxyResponse, CancellationToken cancellationToken)
132135
{
133136
if (ShouldCopyResponseTrailers.GetValueOrDefault(true))
134137
{
135-
await base.TransformResponseTrailersAsync(httpContext, proxyResponse);
138+
await base.TransformResponseTrailersAsync(httpContext, proxyResponse, cancellationToken);
136139
}
137140

138141
if (ResponseTrailerTransforms.Length == 0)
@@ -150,6 +153,7 @@ public override async ValueTask TransformResponseTrailersAsync(HttpContext httpC
150153
HttpContext = httpContext,
151154
ProxyResponse = proxyResponse,
152155
HeadersCopied = ShouldCopyResponseTrailers.GetValueOrDefault(true),
156+
CancellationToken = cancellationToken,
153157
};
154158

155159
foreach (var responseTrailerTransform in ResponseTrailerTransforms)

src/ReverseProxy/Transforms/RequestTransformContext.cs

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
// Licensed under the MIT License.
33

44
using System.Net.Http;
5+
using System.Threading;
56
using Microsoft.AspNetCore.Http;
67

78
namespace Yarp.ReverseProxy.Transforms;
@@ -48,4 +49,9 @@ public class RequestTransformContext
4849
/// port and path base. The 'Path' and 'Query' properties will be appended to this after the transforms have run.
4950
/// </summary>
5051
public string DestinationPrefix { get; init; } = default!;
52+
53+
/// <summary>
54+
/// A <see cref="CancellationToken"/> indicating that the request is being aborted.
55+
/// </summary>
56+
public CancellationToken CancellationToken { get; set; }
5157
}

src/ReverseProxy/Transforms/ResponseTrailersTransformContext.cs

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
// Licensed under the MIT License.
33

44
using System.Net.Http;
5+
using System.Threading;
56
using Microsoft.AspNetCore.Http;
67

78
namespace Yarp.ReverseProxy.Transforms;
@@ -27,4 +28,9 @@ public class ResponseTrailersTransformContext
2728
/// should operate on.
2829
/// </summary>
2930
public bool HeadersCopied { get; set; }
31+
32+
/// <summary>
33+
/// A <see cref="CancellationToken"/> indicating that the request is being aborted.
34+
/// </summary>
35+
public CancellationToken CancellationToken { get; set; }
3036
}

src/ReverseProxy/Transforms/ResponseTransformContext.cs

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
// Licensed under the MIT License.
33

44
using System.Net.Http;
5+
using System.Threading;
56
using Microsoft.AspNetCore.Http;
67

78
namespace Yarp.ReverseProxy.Transforms;
@@ -33,4 +34,9 @@ public class ResponseTransformContext
3334
/// Defaults to false.
3435
/// </summary>
3536
public bool SuppressResponseBody { get; set; }
37+
38+
/// <summary>
39+
/// A <see cref="CancellationToken"/> indicating that the request is being aborted.
40+
/// </summary>
41+
public CancellationToken CancellationToken { get; set; }
3642
}

src/ReverseProxy/Utilities/ActivityCancellationTokenSource.cs

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,8 @@ internal sealed class ActivityCancellationTokenSource : CancellationTokenSource
1919
};
2020

2121
private int _activityTimeoutMs;
22-
private CancellationTokenRegistration _linkedRegistration;
22+
private CancellationTokenRegistration _linkedRegistration1;
23+
private CancellationTokenRegistration _linkedRegistration2;
2324

2425
private ActivityCancellationTokenSource() { }
2526

@@ -28,7 +29,7 @@ public void ResetTimeout()
2829
CancelAfter(_activityTimeoutMs);
2930
}
3031

31-
public static ActivityCancellationTokenSource Rent(TimeSpan activityTimeout, CancellationToken linkedToken)
32+
public static ActivityCancellationTokenSource Rent(TimeSpan activityTimeout, CancellationToken linkedToken1 = default, CancellationToken linkedToken2 = default)
3233
{
3334
if (_sharedSources.TryDequeue(out var cts))
3435
{
@@ -40,16 +41,19 @@ public static ActivityCancellationTokenSource Rent(TimeSpan activityTimeout, Can
4041
}
4142

4243
cts._activityTimeoutMs = (int)activityTimeout.TotalMilliseconds;
43-
cts._linkedRegistration = linkedToken.UnsafeRegister(_linkedTokenCancelDelegate, cts);
44+
cts._linkedRegistration1 = linkedToken1.UnsafeRegister(_linkedTokenCancelDelegate, cts);
45+
cts._linkedRegistration2 = linkedToken2.UnsafeRegister(_linkedTokenCancelDelegate, cts);
4446
cts.ResetTimeout();
4547

4648
return cts;
4749
}
4850

4951
public void Return()
5052
{
51-
_linkedRegistration.Dispose();
52-
_linkedRegistration = default;
53+
_linkedRegistration1.Dispose();
54+
_linkedRegistration1 = default;
55+
_linkedRegistration2.Dispose();
56+
_linkedRegistration2 = default;
5357

5458
if (TryReset())
5559
{

0 commit comments

Comments
 (0)