Skip to content

Commit 69fb774

Browse files
committed
Flow cancellation to transforms
1 parent 86ea977 commit 69fb774

File tree

7 files changed

+81
-15
lines changed

7 files changed

+81
-15
lines changed

src/ReverseProxy/Forwarder/HttpForwarder.cs

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -202,7 +202,7 @@ public async ValueTask<ForwarderError> SendAsync(
202202
{
203203
// :: Step 5: Copy response status line Client ◄-- Proxy ◄-- Destination
204204
// :: Step 6: Copy response headers Client ◄-- Proxy ◄-- Destination
205-
var copyBody = await CopyResponseStatusAndHeadersAsync(destinationResponse, context, transformer);
205+
var copyBody = await CopyResponseStatusAndHeadersAsync(destinationResponse, context, transformer, activityCancellationSource.Token);
206206

207207
if (!copyBody)
208208
{
@@ -269,7 +269,7 @@ public async ValueTask<ForwarderError> SendAsync(
269269
}
270270

271271
// :: Step 8: Copy response trailer headers and finish response Client ◄-- Proxy ◄-- Destination
272-
await CopyResponseTrailingHeadersAsync(destinationResponse, context, transformer);
272+
await CopyResponseTrailingHeadersAsync(destinationResponse, context, transformer, activityCancellationSource.Token);
273273

274274
if (isStreamingRequest)
275275
{
@@ -411,7 +411,7 @@ public async ValueTask<ForwarderError> SendAsync(
411411
destinationRequest.Content = requestContent;
412412

413413
// :: Step 3: Copy request headers Client --► Proxy --► Destination
414-
await transformer.TransformRequestAsync(context, destinationRequest, destinationPrefix);
414+
await transformer.TransformRequestAsync(context, destinationRequest, destinationPrefix, activityToken.Token);
415415

416416
// The transformer generated a response, do not forward.
417417
if (RequestUtilities.IsResponseSet(context.Response))
@@ -662,7 +662,7 @@ async ValueTask<ForwarderError> ReportErrorAsync(ForwarderError error, int statu
662662
}
663663
}
664664

665-
private static ValueTask<bool> CopyResponseStatusAndHeadersAsync(HttpResponseMessage source, HttpContext context, HttpTransformer transformer)
665+
private static ValueTask<bool> CopyResponseStatusAndHeadersAsync(HttpResponseMessage source, HttpContext context, HttpTransformer transformer, CancellationToken cancellationToken)
666666
{
667667
context.Response.StatusCode = (int)source.StatusCode;
668668

@@ -676,7 +676,7 @@ private static ValueTask<bool> CopyResponseStatusAndHeadersAsync(HttpResponseMes
676676
}
677677

678678
// Copies headers
679-
return transformer.TransformResponseAsync(context, source);
679+
return transformer.TransformResponseAsync(context, source, cancellationToken);
680680
}
681681

682682
private async ValueTask<ForwarderError> HandleUpgradedResponse(HttpContext context, HttpResponseMessage destinationResponse,
@@ -900,10 +900,10 @@ private async ValueTask<ForwarderError> HandleResponseBodyErrorAsync(HttpContext
900900
return error;
901901
}
902902

903-
private static ValueTask CopyResponseTrailingHeadersAsync(HttpResponseMessage source, HttpContext context, HttpTransformer transformer)
903+
private static ValueTask CopyResponseTrailingHeadersAsync(HttpResponseMessage source, HttpContext context, HttpTransformer transformer, CancellationToken cancellationToken)
904904
{
905905
// Copies trailers
906-
return transformer.TransformResponseTrailersAsync(context, source);
906+
return transformer.TransformResponseTrailersAsync(context, source, cancellationToken);
907907
}
908908

909909

src/ReverseProxy/Forwarder/HttpTransformer.cs

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
using System.Net.Http;
88
using System.Net.Http.Headers;
99
using System.Runtime.CompilerServices;
10+
using System.Threading;
1011
using System.Threading.Tasks;
1112
using Microsoft.AspNetCore.Http;
1213
using Microsoft.AspNetCore.Http.Features;
@@ -53,6 +54,23 @@ private static bool IsBodylessStatusCode(HttpStatusCode statusCode) =>
5354
_ => false
5455
};
5556

57+
/// <summary>
58+
/// A callback that is invoked prior to sending the proxied request. All HttpRequestMessage fields are
59+
/// initialized except RequestUri, which will be initialized after the callback if no value is provided.
60+
/// See <see cref="RequestUtilities.MakeDestinationAddress(string, PathString, QueryString)"/> for constructing a custom request Uri.
61+
/// The string parameter represents the destination URI prefix that should be used when constructing the RequestUri.
62+
/// The headers are copied by the base implementation, excluding some protocol headers like HTTP/2 pseudo headers (":authority").
63+
/// This method may be overridden to conditionally produce a response, such as for error conditions, and prevent the request from
64+
/// being proxied. This is indicated by setting the `HttpResponse.StatusCode` to a value other than 200, or calling `HttpResponse.StartAsync()`,
65+
/// or writing to the `HttpResponse.Body` or `BodyWriter`.
66+
/// </summary>
67+
/// <param name="httpContext">The incoming request.</param>
68+
/// <param name="proxyRequest">The outgoing proxy request.</param>
69+
/// <param name="destinationPrefix">The uri prefix for the selected destination server which can be used to create the RequestUri.</param>
70+
/// <param name="cancellationToken">Indicates that the request is being canceled.</param>
71+
public virtual ValueTask TransformRequestAsync(HttpContext httpContext, HttpRequestMessage proxyRequest, string destinationPrefix, CancellationToken cancellationToken)
72+
=> TransformRequestAsync(httpContext, proxyRequest, destinationPrefix);
73+
5674
/// <summary>
5775
/// A callback that is invoked prior to sending the proxied request. All HttpRequestMessage fields are
5876
/// initialized except RequestUri, which will be initialized after the callback if no value is provided.
@@ -126,9 +144,24 @@ public virtual ValueTask TransformRequestAsync(HttpContext httpContext, HttpRequ
126144
/// </summary>
127145
/// <param name="httpContext">The incoming request.</param>
128146
/// <param name="proxyResponse">The response from the destination. This can be null if the destination did not respond.</param>
147+
/// <param name="cancellationToken">Indicates that the request is being canceled.</param>
129148
/// <returns>A bool indicating if the response should be proxied to the client or not. A derived implementation
130149
/// that returns false may send an alternate response inline or return control to the caller for it to retry, respond,
131150
/// etc.</returns>
151+
public virtual ValueTask<bool> TransformResponseAsync(HttpContext httpContext, HttpResponseMessage? proxyResponse, CancellationToken cancellationToken)
152+
=> TransformResponseAsync(httpContext, proxyResponse);
153+
154+
/// <summary>
155+
/// A callback that is invoked when the proxied response is received. The status code and reason phrase will be copied
156+
/// to the HttpContext.Response before the callback is invoked, but may still be modified there. The headers will be
157+
/// copied to HttpContext.Response.Headers by the base implementation, excludes certain protocol headers like
158+
/// `Transfer-Encoding: chunked`.
159+
/// </summary>
160+
/// <param name="httpContext">The incoming request.</param>
161+
/// <param name="proxyResponse">The response from the destination. This can be null if the destination did not respond.</param>
162+
/// <returns>A bool indicating if the response should be proxied to the client or not. A derived implementation
163+
/// that returns false may send an alternate response inline or return control to the caller for it to retry, respond,
164+
/// etc.</returns>
132165
public virtual ValueTask<bool> TransformResponseAsync(HttpContext httpContext, HttpResponseMessage? proxyResponse)
133166
{
134167
if (proxyResponse is null)
@@ -171,6 +204,16 @@ public virtual ValueTask<bool> TransformResponseAsync(HttpContext httpContext, H
171204
return new ValueTask<bool>(true);
172205
}
173206

207+
/// <summary>
208+
/// A callback that is invoked after the response body to modify trailers, if supported. The trailers will be
209+
/// copied to the HttpContext.Response by the base implementation.
210+
/// </summary>
211+
/// <param name="httpContext">The incoming request.</param>
212+
/// <param name="proxyResponse">The response from the destination.</param>
213+
/// <param name="cancellationToken">Indicates that the request is being canceled.</param>
214+
public virtual ValueTask TransformResponseTrailersAsync(HttpContext httpContext, HttpResponseMessage proxyResponse, CancellationToken cancellationToken)
215+
=> TransformResponseTrailersAsync(httpContext, proxyResponse);
216+
174217
/// <summary>
175218
/// A callback that is invoked after the response body to modify trailers, if supported. The trailers will be
176219
/// copied to the HttpContext.Response by the base implementation.

src/ReverseProxy/Transforms/Builder/StructuredTransformer.cs

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
using System.Collections.Generic;
66
using System.Linq;
77
using System.Net.Http;
8+
using System.Threading;
89
using System.Threading.Tasks;
910
using Microsoft.AspNetCore.Http;
1011
using Microsoft.AspNetCore.Http.Features;
@@ -63,11 +64,11 @@ internal StructuredTransformer(bool? copyRequestHeaders, bool? copyResponseHeade
6364
/// </summary>
6465
internal ResponseTrailersTransform[] ResponseTrailerTransforms { get; }
6566

66-
public override async ValueTask TransformRequestAsync(HttpContext httpContext, HttpRequestMessage proxyRequest, string destinationPrefix)
67+
public override async ValueTask TransformRequestAsync(HttpContext httpContext, HttpRequestMessage proxyRequest, string destinationPrefix, CancellationToken cancellationToken)
6768
{
6869
if (ShouldCopyRequestHeaders.GetValueOrDefault(true))
6970
{
70-
await base.TransformRequestAsync(httpContext, proxyRequest, destinationPrefix);
71+
await base.TransformRequestAsync(httpContext, proxyRequest, destinationPrefix, cancellationToken);
7172
}
7273

7374
if (RequestTransforms.Length == 0)
@@ -83,6 +84,7 @@ public override async ValueTask TransformRequestAsync(HttpContext httpContext, H
8384
Path = httpContext.Request.Path,
8485
Query = new QueryTransformContext(httpContext.Request),
8586
HeadersCopied = ShouldCopyRequestHeaders.GetValueOrDefault(true),
87+
CancellationToken = cancellationToken,
8688
};
8789

8890
foreach (var requestTransform in RequestTransforms)
@@ -101,11 +103,11 @@ public override async ValueTask TransformRequestAsync(HttpContext httpContext, H
101103
transformContext.DestinationPrefix, transformContext.Path, transformContext.Query.QueryString);
102104
}
103105

104-
public override async ValueTask<bool> TransformResponseAsync(HttpContext httpContext, HttpResponseMessage? proxyResponse)
106+
public override async ValueTask<bool> TransformResponseAsync(HttpContext httpContext, HttpResponseMessage? proxyResponse, CancellationToken cancellationToken)
105107
{
106108
if (ShouldCopyResponseHeaders.GetValueOrDefault(true))
107109
{
108-
await base.TransformResponseAsync(httpContext, proxyResponse);
110+
await base.TransformResponseAsync(httpContext, proxyResponse, cancellationToken);
109111
}
110112

111113
if (ResponseTransforms.Length == 0)
@@ -118,6 +120,7 @@ public override async ValueTask<bool> TransformResponseAsync(HttpContext httpCon
118120
HttpContext = httpContext,
119121
ProxyResponse = proxyResponse,
120122
HeadersCopied = ShouldCopyResponseHeaders.GetValueOrDefault(true),
123+
CancellationToken = cancellationToken,
121124
};
122125

123126
foreach (var responseTransform in ResponseTransforms)
@@ -128,11 +131,11 @@ public override async ValueTask<bool> TransformResponseAsync(HttpContext httpCon
128131
return !transformContext.SuppressResponseBody;
129132
}
130133

131-
public override async ValueTask TransformResponseTrailersAsync(HttpContext httpContext, HttpResponseMessage proxyResponse)
134+
public override async ValueTask TransformResponseTrailersAsync(HttpContext httpContext, HttpResponseMessage proxyResponse, CancellationToken cancellationToken)
132135
{
133136
if (ShouldCopyResponseTrailers.GetValueOrDefault(true))
134137
{
135-
await base.TransformResponseTrailersAsync(httpContext, proxyResponse);
138+
await base.TransformResponseTrailersAsync(httpContext, proxyResponse, cancellationToken);
136139
}
137140

138141
if (ResponseTrailerTransforms.Length == 0)
@@ -150,6 +153,7 @@ public override async ValueTask TransformResponseTrailersAsync(HttpContext httpC
150153
HttpContext = httpContext,
151154
ProxyResponse = proxyResponse,
152155
HeadersCopied = ShouldCopyResponseTrailers.GetValueOrDefault(true),
156+
CancellationToken = cancellationToken,
153157
};
154158

155159
foreach (var responseTrailerTransform in ResponseTrailerTransforms)

src/ReverseProxy/Transforms/RequestTransformContext.cs

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
// Licensed under the MIT License.
33

44
using System.Net.Http;
5+
using System.Threading;
56
using Microsoft.AspNetCore.Http;
67

78
namespace Yarp.ReverseProxy.Transforms;
@@ -48,4 +49,9 @@ public class RequestTransformContext
4849
/// port and path base. The 'Path' and 'Query' properties will be appended to this after the transforms have run.
4950
/// </summary>
5051
public string DestinationPrefix { get; init; } = default!;
52+
53+
/// <summary>
54+
/// A <see cref="CancellationToken"/> indicating that the request is being aborted.
55+
/// </summary>
56+
public CancellationToken CancellationToken { get; set; }
5157
}

src/ReverseProxy/Transforms/ResponseTrailersTransformContext.cs

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
// Licensed under the MIT License.
33

44
using System.Net.Http;
5+
using System.Threading;
56
using Microsoft.AspNetCore.Http;
67

78
namespace Yarp.ReverseProxy.Transforms;
@@ -27,4 +28,9 @@ public class ResponseTrailersTransformContext
2728
/// should operate on.
2829
/// </summary>
2930
public bool HeadersCopied { get; set; }
31+
32+
/// <summary>
33+
/// A <see cref="CancellationToken"/> indicating that the request is being aborted.
34+
/// </summary>
35+
public CancellationToken CancellationToken { get; set; }
3036
}

src/ReverseProxy/Transforms/ResponseTransformContext.cs

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
// Licensed under the MIT License.
33

44
using System.Net.Http;
5+
using System.Threading;
56
using Microsoft.AspNetCore.Http;
67

78
namespace Yarp.ReverseProxy.Transforms;
@@ -33,4 +34,9 @@ public class ResponseTransformContext
3334
/// Defaults to false.
3435
/// </summary>
3536
public bool SuppressResponseBody { get; set; }
37+
38+
/// <summary>
39+
/// A <see cref="CancellationToken"/> indicating that the request is being aborted.
40+
/// </summary>
41+
public CancellationToken CancellationToken { get; set; }
3642
}

test/ReverseProxy.Tests/Transforms/Builder/TransformBuilderTests.cs

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
using System.Collections.Generic;
66
using System.Linq;
77
using System.Net.Http;
8+
using System.Threading;
89
using System.Threading.Tasks;
910
using Microsoft.AspNetCore.Http;
1011
using Microsoft.AspNetCore.HttpOverrides;
@@ -307,7 +308,7 @@ public async Task UseOriginalHost(bool? useOriginalHost, bool? copyHeaders)
307308
httpContext.Request.Host = new HostString("StartHost");
308309
var proxyRequest = new HttpRequestMessage();
309310
var destinationPrefix = "http://destinationhost:9090/path";
310-
await results.TransformRequestAsync(httpContext, proxyRequest, destinationPrefix);
311+
await results.TransformRequestAsync(httpContext, proxyRequest, destinationPrefix, CancellationToken.None);
311312

312313
if (useOriginalHost.GetValueOrDefault(false))
313314
{
@@ -373,7 +374,7 @@ public async Task UseCustomHost(bool? useOriginalHost, bool? copyHeaders)
373374
var proxyRequest = new HttpRequestMessage();
374375
var destinationPrefix = "http://destinationhost:9090/path";
375376

376-
await results.TransformRequestAsync(httpContext, proxyRequest, destinationPrefix);
377+
await results.TransformRequestAsync(httpContext, proxyRequest, destinationPrefix, CancellationToken.None);
377378

378379
Assert.Equal("CustomHost", proxyRequest.Headers.Host);
379380
}

0 commit comments

Comments
 (0)