From 9bf4d831de646c0ca31948e483221ee89a221166 Mon Sep 17 00:00:00 2001 From: "Chris Ross (ASP.NET)" Date: Wed, 21 Jun 2023 21:55:42 +0000 Subject: [PATCH] Merged PR 31986: [2.0] Forwarder blame improvements Passive health checks rely on the HttpForwarder to categorize failures so the check can understand if it was caused by the client, destination, or some other reason. Incorrect attribution can allow the client to induce an error and cause the destination to be marked unhealthy. Scenarios: 1. Blame the request body (client/destination) for timeouts. Today if the ActivityTimeout triggers while uploading the request body that will be reported as RequestTimeout and blamed on the destination. However, if the proxy was waiting for the client to send more data, that should be blamed on the client instead (RequestBodyClient). 2. A malicious client may pretend to be a gRPC request by setting the expected Content-Type, but then getting proxied to a HTTP/1.1 destination. This causes the proxy request data rate and size limit to be disabled, but then you can induce an error against the destination's size limit that would be blamed on the destination. Fix: Delay disabling the limits until we're sure we were assigned to an HTTP/2 connection. 3. I've included https://github.com/microsoft/reverse-proxy/pull/2119. This works around a Kestrel bug which caused a body length exception that we blamed on the destination. This fix is already in main. --- src/ReverseProxy/Forwarder/HttpForwarder.cs | 126 ++++++------------ src/ReverseProxy/Forwarder/StreamCopier.cs | 8 +- .../Forwarder/StreamCopyHttpContent.cs | 79 +++++++++-- .../Forwarder/HttpForwarderTests.cs | 34 +++-- .../Forwarder/StreamCopierTests.cs | 5 +- .../Forwarder/StreamCopyHttpContentTests.cs | 39 +++--- 6 files changed, 162 insertions(+), 129 deletions(-) diff --git a/src/ReverseProxy/Forwarder/HttpForwarder.cs b/src/ReverseProxy/Forwarder/HttpForwarder.cs index d290c334b..01046f1fc 100644 --- a/src/ReverseProxy/Forwarder/HttpForwarder.cs +++ b/src/ReverseProxy/Forwarder/HttpForwarder.cs @@ -416,7 +416,7 @@ public async ValueTask SendAsync( // :: Step 2: Setup copy of request body (background) Client --► Proxy --► Destination // Note that we must do this before step (3) because step (3) may also add headers to the HttpContent that we set up here. - var requestContent = SetupRequestBodyCopy(context.Request, isStreamingRequest, activityToken); + var requestContent = SetupRequestBodyCopy(context, isStreamingRequest, activityToken); destinationRequest.Content = requestContent; // :: Step 3: Copy request headers Client --► Proxy --► Destination @@ -496,12 +496,13 @@ private void FixupUpgradeRequestHeaders(HttpContext context, HttpRequestMessage // else not an upgrade, or H2->H2, no changes needed } - private StreamCopyHttpContent? SetupRequestBodyCopy(HttpRequest request, bool isStreamingRequest, ActivityCancellationTokenSource activityToken) + private StreamCopyHttpContent? SetupRequestBodyCopy(HttpContext context, bool isStreamingRequest, ActivityCancellationTokenSource activityToken) { // If we generate an HttpContent without a Content-Length then for HTTP/1.1 HttpClient will add a Transfer-Encoding: chunked header // even if it's a GET request. Some servers reject requests containing a Transfer-Encoding header if they're not expecting a body. // Try to be as specific as possible about the client's intent to send a body. The one thing we don't want to do is to start // reading the body early because that has side-effects like 100-continue. + var request = context.Request; var hasBody = true; var contentLength = request.Headers.ContentLength; var method = request.Method; @@ -512,10 +513,11 @@ private void FixupUpgradeRequestHeaders(HttpContext context, HttpRequestMessage // 5.0 servers provide a definitive answer for us. hasBody = canHaveBodyFeature.CanHaveBody; - // TODO: Kestrel bug, this shouldn't be true for ExtendedConnect. -#if NET7_0_OR_GREATER +#if NET7_0 + // TODO: Kestrel 7.0 bug only, hasBody shouldn't be true for ExtendedConnect. + // https://github.com/dotnet/aspnetcore/issues/46002 Fixed in 8.0 var connectFeature = request.HttpContext.Features.Get(); - if (connectFeature?.Protocol != null) + if (connectFeature?.IsExtendedConnect == true) { hasBody = false; } @@ -560,31 +562,13 @@ private void FixupUpgradeRequestHeaders(HttpContext context, HttpRequestMessage if (hasBody) { - if (isStreamingRequest) - { - DisableMinRequestBodyDataRateAndMaxRequestBodySize(request.HttpContext); - } - - // Note on `autoFlushHttpClientOutgoingStream: isStreamingRequest`: - // The.NET Core HttpClient stack keeps its own buffers on top of the underlying outgoing connection socket. - // We flush those buffers down to the socket on every write when this is set, - // but it does NOT result in calls to flush on the underlying socket. - // This is necessary because we proxy http2 transparently, - // and we are deliberately unaware of packet structure used e.g. in gRPC duplex channels. - // Because the sockets aren't flushed, the perf impact of this choice is expected to be small. - // Future: It may be wise to set this to true for *all* http2 incoming requests, - // but for now, out of an abundance of caution, we only do it for requests that look like gRPC. - return new StreamCopyHttpContent( - request: request, - autoFlushHttpClientOutgoingStream: isStreamingRequest, - clock: _clock, - activityToken); + return new StreamCopyHttpContent(context, isStreamingRequest, _clock, _logger, activityToken); } return null; } - private ForwarderError HandleRequestBodyFailure(HttpContext context, StreamCopyResult requestBodyCopyResult, Exception requestBodyException, Exception additionalException) + private ForwarderError HandleRequestBodyFailure(HttpContext context, StreamCopyResult requestBodyCopyResult, Exception requestBodyException, Exception additionalException, bool timedOut) { ForwarderError requestBodyError; int statusCode; @@ -593,19 +577,12 @@ private ForwarderError HandleRequestBodyFailure(HttpContext context, StreamCopyR // Failed while trying to copy the request body from the client. It's ambiguous if the request or response failed first. case StreamCopyResult.InputError: requestBodyError = ForwarderError.RequestBodyClient; - statusCode = StatusCodes.Status400BadRequest; + statusCode = timedOut ? StatusCodes.Status408RequestTimeout : StatusCodes.Status400BadRequest; break; // Failed while trying to copy the request body to the destination. It's ambiguous if the request or response failed first. case StreamCopyResult.OutputError: requestBodyError = ForwarderError.RequestBodyDestination; - statusCode = StatusCodes.Status502BadGateway; - break; - // Canceled while trying to copy the request body, either due to a client disconnect or a timeout. This probably caused the response to fail as a secondary error. - case StreamCopyResult.Canceled: - requestBodyError = ForwarderError.RequestBodyCanceled; - // Timeouts (504s) are handled at the SendAsync call site. - // The request body should only be canceled by the RequestAborted token. - statusCode = StatusCodes.Status502BadGateway; + statusCode = timedOut ? StatusCodes.Status504GatewayTimeout : StatusCodes.Status502BadGateway; break; default: throw new NotImplementedException(requestBodyCopyResult.ToString()); @@ -630,33 +607,46 @@ private ForwarderError HandleRequestBodyFailure(HttpContext context, StreamCopyR private async ValueTask HandleRequestFailureAsync(HttpContext context, StreamCopyHttpContent? requestContent, Exception requestException, HttpTransformer transformer, ActivityCancellationTokenSource requestCancellationSource, bool failedDuringRequestCreation) { - if (requestException is OperationCanceledException) + var triedRequestBody = requestContent?.ConsumptionTask.IsCompleted == true; + + if (requestCancellationSource.CancelledByLinkedToken) { - if (requestCancellationSource.CancelledByLinkedToken) + var requestBodyCanceled = false; + if (triedRequestBody) { - // Either the client went away (HttpContext.RequestAborted) or the CancellationToken provided to SendAsync was signaled. - return await ReportErrorAsync(ForwarderError.RequestCanceled, StatusCodes.Status502BadGateway); - } - else - { - Debug.Assert(requestCancellationSource.IsCancellationRequested || requestException.ToString().Contains("ConnectTimeout"), requestException.ToString()); - return await ReportErrorAsync(ForwarderError.RequestTimedOut, StatusCodes.Status504GatewayTimeout); + var (requestBodyCopyResult, requestBodyException) = requestContent!.ConsumptionTask.Result; + requestBodyCanceled = requestBodyCopyResult == StreamCopyResult.Canceled; + if (requestBodyCanceled) + { + requestException = new AggregateException(requestException, requestBodyException!); + } } + // Either the client went away (HttpContext.RequestAborted) or the CancellationToken provided to SendAsync was signaled. + return await ReportErrorAsync(requestBodyCanceled ? ForwarderError.RequestBodyCanceled : ForwarderError.RequestCanceled, + context.RequestAborted.IsCancellationRequested ? StatusCodes.Status400BadRequest : StatusCodes.Status502BadGateway); } // Check for request body errors, these may have triggered the response error. - if (requestContent?.ConsumptionTask.IsCompleted == true) + if (triedRequestBody) { - var (requestBodyCopyResult, requestBodyException) = requestContent.ConsumptionTask.Result; + var (requestBodyCopyResult, requestBodyException) = requestContent!.ConsumptionTask.Result; - if (requestBodyCopyResult != StreamCopyResult.Success) + if (requestBodyCopyResult is StreamCopyResult.InputError or StreamCopyResult.OutputError) { - var error = HandleRequestBodyFailure(context, requestBodyCopyResult, requestBodyException!, requestException); + var error = HandleRequestBodyFailure(context, requestBodyCopyResult, requestBodyException!, requestException, + timedOut: requestCancellationSource.IsCancellationRequested); await transformer.TransformResponseAsync(context, proxyResponse: null, requestCancellationSource.Token); return error; } } + if (requestException is OperationCanceledException) + { + Debug.Assert(requestCancellationSource.IsCancellationRequested || requestException.ToString().Contains("ConnectTimeout"), requestException.ToString()); + + return await ReportErrorAsync(ForwarderError.RequestTimedOut, StatusCodes.Status504GatewayTimeout); + } + // We couldn't communicate with the destination. return await ReportErrorAsync(failedDuringRequestCreation ? ForwarderError.RequestCreation : ForwarderError.Request, StatusCodes.Status502BadGateway); @@ -870,7 +860,7 @@ private ForwarderError FixupUpgradeResponseHeaders(HttpContext context, HttpResp return (StreamCopyResult.Success, null); } - private async ValueTask HandleResponseBodyErrorAsync(HttpContext context, StreamCopyHttpContent? requestContent, StreamCopyResult responseBodyCopyResult, Exception responseBodyException, CancellationTokenSource requestCancellationSource) + private async ValueTask HandleResponseBodyErrorAsync(HttpContext context, StreamCopyHttpContent? requestContent, StreamCopyResult responseBodyCopyResult, Exception responseBodyException, ActivityCancellationTokenSource requestCancellationSource) { if (requestContent is not null && requestContent.Started) { @@ -884,9 +874,10 @@ private async ValueTask HandleResponseBodyErrorAsync(HttpContext var (requestBodyCopyResult, requestBodyError) = await requestContent.ConsumptionTask; // Check for request body errors, these may have triggered the response error. - if (alreadyFinished && requestBodyCopyResult != StreamCopyResult.Success) + if (alreadyFinished && requestBodyCopyResult is StreamCopyResult.InputError or StreamCopyResult.OutputError) { - return HandleRequestBodyFailure(context, requestBodyCopyResult, requestBodyError!, responseBodyException); + return HandleRequestBodyFailure(context, requestBodyCopyResult, requestBodyError!, responseBodyException, + timedOut: requestCancellationSource.IsCancellationRequested && !requestCancellationSource.CancelledByLinkedToken); } } @@ -920,41 +911,6 @@ private static ValueTask CopyResponseTrailingHeadersAsync(HttpResponseMessage so return transformer.TransformResponseTrailersAsync(context, source, cancellationToken); } - - /// - /// Disable some ASP .NET Core server limits so that we can handle long-running gRPC requests unconstrained. - /// Note that the gRPC server implementation on ASP .NET Core does the same for client-streaming and duplex methods. - /// Since in Gateway we have no way to determine if the current request requires client-streaming or duplex comm, - /// we do this for *all* incoming requests that look like they might be gRPC. - /// - /// - /// Inspired on - /// . - /// - private void DisableMinRequestBodyDataRateAndMaxRequestBodySize(HttpContext httpContext) - { - var minRequestBodyDataRateFeature = httpContext.Features.Get(); - if (minRequestBodyDataRateFeature is not null) - { - minRequestBodyDataRateFeature.MinDataRate = null; - } - - var maxRequestBodySizeFeature = httpContext.Features.Get(); - if (maxRequestBodySizeFeature is not null) - { - if (!maxRequestBodySizeFeature.IsReadOnly) - { - maxRequestBodySizeFeature.MaxRequestBodySize = null; - } - else - { - // IsReadOnly could be true if middleware has already started reading the request body - // In that case we can't disable the max request body size for the request stream - _logger.LogWarning("Unable to disable max request body size."); - } - } - } - private void ReportProxyError(HttpContext context, ForwarderError error, Exception ex) { context.Features.Set(new ForwarderErrorFeature(error, ex)); diff --git a/src/ReverseProxy/Forwarder/StreamCopier.cs b/src/ReverseProxy/Forwarder/StreamCopier.cs index f34a0ff8e..0bea0b021 100644 --- a/src/ReverseProxy/Forwarder/StreamCopier.cs +++ b/src/ReverseProxy/Forwarder/StreamCopier.cs @@ -124,9 +124,13 @@ internal static class StreamCopier telemetry?.AfterWrite(); } - var result = ex is OperationCanceledException ? StreamCopyResult.Canceled : - (read == 0 ? StreamCopyResult.InputError : StreamCopyResult.OutputError); + if (activityToken.CancelledByLinkedToken) + { + return (StreamCopyResult.Canceled, ex); + } + // If the activity timeout triggered while reading or writing, blame the sender or receiver. + var result = read == 0 ? StreamCopyResult.InputError : StreamCopyResult.OutputError; return (result, ex); } finally diff --git a/src/ReverseProxy/Forwarder/StreamCopyHttpContent.cs b/src/ReverseProxy/Forwarder/StreamCopyHttpContent.cs index 7fd1a541d..12f433f94 100644 --- a/src/ReverseProxy/Forwarder/StreamCopyHttpContent.cs +++ b/src/ReverseProxy/Forwarder/StreamCopyHttpContent.cs @@ -9,6 +9,9 @@ using System.Threading; using System.Threading.Tasks; using Microsoft.AspNetCore.Http; +using Microsoft.AspNetCore.Http.Features; +using Microsoft.AspNetCore.Server.Kestrel.Core.Features; +using Microsoft.Extensions.Logging; using Yarp.ReverseProxy.Utilities; namespace Yarp.ReverseProxy.Forwarder; @@ -39,21 +42,22 @@ namespace Yarp.ReverseProxy.Forwarder; /// internal sealed class StreamCopyHttpContent : HttpContent { - private readonly HttpRequest _request; + private readonly HttpContext _context; // HttpClient's machinery keeps an internal buffer that doesn't get flushed to the socket on every write. // Some protocols (e.g. gRPC) may rely on specific bytes being sent, and HttpClient's buffering would prevent it. - private readonly bool _autoFlushHttpClientOutgoingStream; + private bool _isStreamingRequest; private readonly IClock _clock; + private readonly ILogger _logger; private readonly ActivityCancellationTokenSource _activityToken; private readonly TaskCompletionSource<(StreamCopyResult, Exception?)> _tcs = new(TaskCreationOptions.RunContinuationsAsynchronously); private int _started; - public StreamCopyHttpContent(HttpRequest request, bool autoFlushHttpClientOutgoingStream, IClock clock, ActivityCancellationTokenSource activityToken) + public StreamCopyHttpContent(HttpContext context, bool isStreamingRequest, IClock clock, ILogger logger, ActivityCancellationTokenSource activityToken) { - _request = request ?? throw new ArgumentNullException(nameof(request)); - _autoFlushHttpClientOutgoingStream = autoFlushHttpClientOutgoingStream; + _context = context ?? throw new ArgumentNullException(nameof(context)); + _isStreamingRequest = isStreamingRequest; _clock = clock ?? throw new ArgumentNullException(nameof(clock)); - + _logger = logger; _activityToken = activityToken; } @@ -137,11 +141,22 @@ protected override async Task SerializeToStreamAsync(Stream stream, TransportCon // _cancellation will be the same as cancellationToken for HTTP/1.1, so we can avoid the overhead of linking them CancellationTokenSource? linkedCts = null; - if (_activityToken.Token != cancellationToken) + if (_activityToken.Token == cancellationToken) + { + // We're talking to the destination via HTTP/1.1, so this can't be a streaming gRPC request. + _isStreamingRequest = false; + // TODO: Log if _isStreamingRequest is true? Something went wrong with protocol selection. + } + else { Debug.Assert(cancellationToken.CanBeCanceled); linkedCts = CancellationTokenSource.CreateLinkedTokenSource(_activityToken.Token, cancellationToken); cancellationToken = linkedCts.Token; + + if (_isStreamingRequest) + { + DisableMinRequestBodyDataRateAndMaxRequestBodySize(_context); + } } try @@ -163,8 +178,20 @@ protected override async Task SerializeToStreamAsync(Stream stream, TransportCon return; } - // Check that the content-length matches the request body size. This can be removed in .NET 7 now that SocketsHttpHandler enforces this: https://github.com/dotnet/runtime/issues/62258. - var (result, error) = await StreamCopier.CopyAsync(isRequest: true, _request.Body, stream, Headers.ContentLength ?? StreamCopier.UnknownLength, _clock, _activityToken, _autoFlushHttpClientOutgoingStream, cancellationToken); + // Check that the content-length matches the request body size. This can be removed in .NET 7 now that SocketsHttpHandler + // enforces this: https://github.com/dotnet/runtime/issues/62258. + // + // Note on `_isStreamingRequest`: + // The.NET Core HttpClient stack keeps its own buffers on top of the underlying outgoing connection socket. + // We flush those buffers down to the socket on every write when this is set, + // but it does NOT result in calls to flush on the underlying socket. + // This is necessary because we proxy http2 transparently, + // and we are deliberately unaware of packet structure used e.g. in gRPC duplex channels. + // Because the sockets aren't flushed, the perf impact of this choice is expected to be small. + // Future: It may be wise to set this to true for *all* http2 incoming requests, + // but for now, out of an abundance of caution, we only do it for requests that look like gRPC. + var (result, error) = await StreamCopier.CopyAsync(isRequest: true, _context.Request.Body, stream, + Headers.ContentLength ?? StreamCopier.UnknownLength, _clock, _activityToken, _isStreamingRequest, cancellationToken); _tcs.TrySetResult((result, error)); // Check for errors that weren't the result of the destination failing. @@ -199,4 +226,38 @@ protected override bool TryComputeLength(out long length) length = -1; return false; } + + /// + /// Disable some ASP .NET Core server limits so that we can handle long-running gRPC requests unconstrained. + /// Note that the gRPC server implementation on ASP .NET Core does the same for client-streaming and duplex methods. + /// Since in Gateway we have no way to determine if the current request requires client-streaming or duplex comm, + /// we do this for *all* incoming requests that look like they might be gRPC. + /// + /// + /// Inspired on + /// . + /// + private void DisableMinRequestBodyDataRateAndMaxRequestBodySize(HttpContext httpContext) + { + var minRequestBodyDataRateFeature = httpContext.Features.Get(); + if (minRequestBodyDataRateFeature is not null) + { + minRequestBodyDataRateFeature.MinDataRate = null; + } + + var maxRequestBodySizeFeature = httpContext.Features.Get(); + if (maxRequestBodySizeFeature is not null) + { + if (!maxRequestBodySizeFeature.IsReadOnly) + { + maxRequestBodySizeFeature.MaxRequestBodySize = null; + } + else + { + // IsReadOnly could be true if middleware has already started reading the request body + // In that case we can't disable the max request body size for the request stream + _logger.LogWarning("Unable to disable max request body size."); + } + } + } } diff --git a/test/ReverseProxy.Tests/Forwarder/HttpForwarderTests.cs b/test/ReverseProxy.Tests/Forwarder/HttpForwarderTests.cs index ed4e0861d..2fd5c5e43 100644 --- a/test/ReverseProxy.Tests/Forwarder/HttpForwarderTests.cs +++ b/test/ReverseProxy.Tests/Forwarder/HttpForwarderTests.cs @@ -699,8 +699,8 @@ public async Task UpgradableRequest_CancelsIfIdle() Assert.Equal(StatusCodes.Status101SwitchingProtocols, httpContext.Response.StatusCode); // When both are idle it's a race which gets reported as canceled first. - Assert.True(ForwarderError.UpgradeRequestCanceled == result - || ForwarderError.UpgradeResponseCanceled == result); + Assert.True(ForwarderError.UpgradeRequestClient == result + || ForwarderError.UpgradeResponseDestination == result); events.AssertContainProxyStages(upgrade: true); } @@ -1479,7 +1479,7 @@ public async Task RequestConnectTimedOut_Returns504() } [Fact] - public async Task RequestCanceled_Returns502() + public async Task RequestCanceled_Returns400() { var events = TestEventListener.Collect(); @@ -1503,7 +1503,7 @@ public async Task RequestCanceled_Returns502() var proxyError = await sut.SendAsync(httpContext, destinationPrefix, client); Assert.Equal(ForwarderError.RequestCanceled, proxyError); - Assert.Equal(StatusCodes.Status502BadGateway, httpContext.Response.StatusCode); + Assert.Equal(StatusCodes.Status400BadRequest, httpContext.Response.StatusCode); Assert.Equal(0, proxyResponseStream.Length); var errorFeature = httpContext.Features.Get(); Assert.Equal(ForwarderError.RequestCanceled, errorFeature.Error); @@ -1614,7 +1614,7 @@ public async Task RequestWithBody_KeptAliveByActivity() } [Fact] - public async Task RequestWithBodyCanceled_Returns502() + public async Task RequestWithBodyCanceled_Returns400() { var events = TestEventListener.Collect(); @@ -1640,7 +1640,7 @@ public async Task RequestWithBodyCanceled_Returns502() var proxyError = await sut.SendAsync(httpContext, destinationPrefix, client); Assert.Equal(ForwarderError.RequestCanceled, proxyError); - Assert.Equal(StatusCodes.Status502BadGateway, httpContext.Response.StatusCode); + Assert.Equal(StatusCodes.Status400BadRequest, httpContext.Response.StatusCode); Assert.Equal(0, proxyResponseStream.Length); var errorFeature = httpContext.Features.Get(); Assert.Equal(ForwarderError.RequestCanceled, errorFeature.Error); @@ -1762,7 +1762,7 @@ public async Task RequestBodyCanceledBeforeResponseError_Returns502() var proxyError = await sut.SendAsync(httpContext, destinationPrefix, client); Assert.Equal(ForwarderError.RequestBodyCanceled, proxyError); - Assert.Equal(StatusCodes.Status502BadGateway, httpContext.Response.StatusCode); + Assert.Equal(StatusCodes.Status400BadRequest, httpContext.Response.StatusCode); Assert.Equal(0, proxyResponseStream.Length); var errorFeature = httpContext.Features.Get(); Assert.Equal(ForwarderError.RequestBodyCanceled, errorFeature.Error); @@ -1993,6 +1993,7 @@ public async Task ResponseBodyCancelledAfterStart_Aborted() httpContext.Features.Set(responseBody); var destinationPrefix = "https://localhost:123/"; + var cts = new CancellationTokenSource(); var sut = CreateProxy(); var client = MockHttpHandler.CreateClient( (HttpRequestMessage request, CancellationToken cancellationToken) => @@ -2002,14 +2003,17 @@ public async Task ResponseBodyCancelledAfterStart_Aborted() Content = new StreamContent(new CallbackReadStream((_, _) => { responseBody.HasStarted = true; - throw new TaskCanceledException(); + cts.Cancel(); + cts.Token.ThrowIfCancellationRequested(); + throw new NotImplementedException(); })) }; message.Headers.AcceptRanges.Add("bytes"); return Task.FromResult(message); }); - var proxyError = await sut.SendAsync(httpContext, destinationPrefix, client); + var proxyError = await sut.SendAsync(httpContext, destinationPrefix, client, ForwarderRequestConfig.Empty, + HttpTransformer.Empty, cts.Token); Assert.Equal(ForwarderError.ResponseBodyCanceled, proxyError); Assert.Equal(StatusCodes.Status200OK, httpContext.Response.StatusCode); @@ -2017,7 +2021,7 @@ public async Task ResponseBodyCancelledAfterStart_Aborted() Assert.Equal("bytes", httpContext.Response.Headers[HeaderNames.AcceptRanges]); var errorFeature = httpContext.Features.Get(); Assert.Equal(ForwarderError.ResponseBodyCanceled, errorFeature.Error); - Assert.IsType(errorFeature.Exception); + Assert.IsType(errorFeature.Exception); AssertProxyStartFailedStop(events, destinationPrefix, httpContext.Response.StatusCode, errorFeature.Error); events.AssertContainProxyStages(hasRequestContent: false); @@ -2732,9 +2736,13 @@ public async Task ForwarderCancellations_CancellationsAreVisibleInTransforms(Can ? ForwarderError.RequestTimedOut : ForwarderError.RequestCanceled; - var expectedStatusCode = cancellationScenario == CancellationScenario.ActivityTimeout - ? StatusCodes.Status504GatewayTimeout - : StatusCodes.Status502BadGateway; + var expectedStatusCode = cancellationScenario switch + { + CancellationScenario.ActivityTimeout => StatusCodes.Status504GatewayTimeout, + CancellationScenario.RequestAborted => StatusCodes.Status400BadRequest, + CancellationScenario.ManualCancellationToken => StatusCodes.Status502BadGateway, + _ => throw new NotImplementedException(cancellationScenario.ToString()), + }; Assert.Equal(expectedError, proxyError); Assert.Equal(expectedStatusCode, httpContext.Response.StatusCode); diff --git a/test/ReverseProxy.Tests/Forwarder/StreamCopierTests.cs b/test/ReverseProxy.Tests/Forwarder/StreamCopierTests.cs index af2eb16cd..4c0b4d039 100644 --- a/test/ReverseProxy.Tests/Forwarder/StreamCopierTests.cs +++ b/test/ReverseProxy.Tests/Forwarder/StreamCopierTests.cs @@ -102,8 +102,9 @@ public async Task Cancelled_Reported(bool isRequest) var source = new MemoryStream(new byte[10]); var destination = new MemoryStream(); - using var cts = ActivityCancellationTokenSource.Rent(TimeSpan.FromSeconds(10), CancellationToken.None); - cts.Cancel(); + var requestCts = new CancellationTokenSource(); + using var cts = ActivityCancellationTokenSource.Rent(TimeSpan.FromSeconds(10), requestCts.Token); + requestCts.Cancel(); var (result, error) = await StreamCopier.CopyAsync(isRequest, source, destination, StreamCopier.UnknownLength, new ManualClock(), cts, cts.Token); Assert.Equal(StreamCopyResult.Canceled, result); Assert.IsAssignableFrom(error); diff --git a/test/ReverseProxy.Tests/Forwarder/StreamCopyHttpContentTests.cs b/test/ReverseProxy.Tests/Forwarder/StreamCopyHttpContentTests.cs index 0c6e88bc0..3c5e8852d 100644 --- a/test/ReverseProxy.Tests/Forwarder/StreamCopyHttpContentTests.cs +++ b/test/ReverseProxy.Tests/Forwarder/StreamCopyHttpContentTests.cs @@ -13,19 +13,22 @@ using Yarp.Tests.Common; using Yarp.ReverseProxy.Utilities; using Microsoft.AspNetCore.Http; +using Microsoft.Extensions.Logging; +using System.Xml.Linq; +using Microsoft.Extensions.Logging.Abstractions; namespace Yarp.ReverseProxy.Forwarder.Tests; public class StreamCopyHttpContentTests { - private static StreamCopyHttpContent CreateContent(HttpRequest request = null, bool autoFlushHttpClientOutgoingStream = false, IClock clock = null, ActivityCancellationTokenSource contentCancellation = null) + private static StreamCopyHttpContent CreateContent(HttpContext context = null, bool isStreamingRequest = false, IClock clock = null, ActivityCancellationTokenSource contentCancellation = null) { - request ??= new DefaultHttpContext().Request; + context ??= new DefaultHttpContext(); clock ??= new Clock(); contentCancellation ??= ActivityCancellationTokenSource.Rent(TimeSpan.FromSeconds(10), CancellationToken.None); - return new StreamCopyHttpContent(request, autoFlushHttpClientOutgoingStream, clock, contentCancellation); + return new StreamCopyHttpContent(context, isStreamingRequest, clock, NullLogger.Instance, contentCancellation); } [Fact] @@ -34,11 +37,11 @@ public async Task CopyToAsync_InvokesStreamCopier() const int SourceSize = (128 * 1024) - 3; var sourceBytes = Enumerable.Range(0, SourceSize).Select(i => (byte)(i % 256)).ToArray(); - var request = new DefaultHttpContext().Request; - request.Body = new MemoryStream(sourceBytes); + var context = new DefaultHttpContext(); + context.Request.Body = new MemoryStream(sourceBytes); var destination = new MemoryStream(); - var sut = CreateContent(request); + var sut = CreateContent(context); Assert.False(sut.ConsumptionTask.IsCompleted); Assert.False(sut.Started); @@ -68,12 +71,12 @@ public async Task CopyToAsync_AutoFlushing(bool autoFlush) expectedFlushes++; var sourceBytes = Enumerable.Range(0, SourceSize).Select(i => (byte)(i % 256)).ToArray(); - var request = new DefaultHttpContext().Request; - request.Body = new MemoryStream(sourceBytes); + var context = new DefaultHttpContext(); + context.Request.Body = new MemoryStream(sourceBytes); var destination = new MemoryStream(); var flushCountingDestination = new FlushCountingStream(destination); - var sut = CreateContent(request, autoFlushHttpClientOutgoingStream: autoFlush); + var sut = CreateContent(context, autoFlush); Assert.False(sut.ConsumptionTask.IsCompleted); Assert.False(sut.Started); @@ -91,11 +94,11 @@ public async Task CopyToAsync_AsyncSequencing() var tcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); var source = new Mock(); source.Setup(s => s.ReadAsync(It.IsAny>(), It.IsAny())).Returns(() => new ValueTask(tcs.Task)); - var request = new DefaultHttpContext().Request; - request.Body = source.Object; + var context = new DefaultHttpContext(); + context.Request.Body = source.Object; var destination = new MemoryStream(); - var sut = CreateContent(request); + var sut = CreateContent(context); Assert.False(sut.ConsumptionTask.IsCompleted); Assert.False(sut.Started); @@ -151,12 +154,12 @@ public async Task SerializeToStreamAsync_RespectsContentCancellation() return 0; }); - var request = new DefaultHttpContext().Request; - request.Body = source; + var context = new DefaultHttpContext(); + context.Request.Body = source; using var contentCts = ActivityCancellationTokenSource.Rent(TimeSpan.FromSeconds(10), CancellationToken.None); - var sut = CreateContent(request, contentCancellation: contentCts); + var sut = CreateContent(context, contentCancellation: contentCts); var copyToTask = sut.CopyToWithCancellationAsync(new MemoryStream()); contentCts.Cancel(); @@ -183,10 +186,10 @@ public async Task SerializeToStreamAsync_CanBeCanceledExternally() return 0; }); - var request = new DefaultHttpContext().Request; - request.Body = source; + var context = new DefaultHttpContext(); + context.Request.Body = source; - var sut = CreateContent(request); + var sut = CreateContent(context); using var cts = new CancellationTokenSource(); var copyToTask = sut.CopyToAsync(new MemoryStream(), cts.Token);