Skip to content

Commit

Permalink
Improve UseRequestTimeouts validation (#2501)
Browse files Browse the repository at this point in the history
* Improve UseRequestTimeouts validation

* Clarify helper name
  • Loading branch information
MihaZupan authored May 15, 2024
1 parent 270a9cf commit c4b5055
Show file tree
Hide file tree
Showing 2 changed files with 93 additions and 19 deletions.
92 changes: 77 additions & 15 deletions src/ReverseProxy/Model/ProxyPipelineInitializerMiddleware.cs
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,11 @@
using System.Threading.Tasks;
using Microsoft.AspNetCore.Http;
#if NET8_0_OR_GREATER
using System.Threading;
using Microsoft.AspNetCore.Http.Timeouts;
using Microsoft.Extensions.Options;
#endif
using Microsoft.Extensions.Logging;
#if NET8_0_OR_GREATER
using Yarp.ReverseProxy.Configuration;
#endif
using Yarp.ReverseProxy.Utilities;

namespace Yarp.ReverseProxy.Model;
Expand All @@ -23,12 +22,23 @@ internal sealed class ProxyPipelineInitializerMiddleware
{
private readonly ILogger _logger;
private readonly RequestDelegate _next;
#if NET8_0_OR_GREATER
private readonly IOptionsMonitor<RequestTimeoutOptions> _timeoutOptions;
#endif

public ProxyPipelineInitializerMiddleware(RequestDelegate next,
ILogger<ProxyPipelineInitializerMiddleware> logger)
ILogger<ProxyPipelineInitializerMiddleware> logger
#if NET8_0_OR_GREATER
, IOptionsMonitor<RequestTimeoutOptions> timeoutOptions
#endif
)
{
_logger = logger ?? throw new ArgumentNullException(nameof(logger));
_next = next ?? throw new ArgumentNullException(nameof(next));

#if NET8_0_OR_GREATER
_timeoutOptions = timeoutOptions ?? throw new ArgumentNullException(nameof(timeoutOptions));
#endif
}

public Task Invoke(HttpContext context)
Expand All @@ -47,19 +57,11 @@ public Task Invoke(HttpContext context)
context.Response.StatusCode = StatusCodes.Status503ServiceUnavailable;
return Task.CompletedTask;
}

#if NET8_0_OR_GREATER
// There's no way to detect the presence of the timeout middleware before this, only the options.
if (endpoint.Metadata.GetMetadata<RequestTimeoutAttribute>() != null
&& context.Features.Get<IHttpRequestTimeoutFeature>() == null
// The feature is skipped if the request is already canceled. We'll handle canceled requests later for consistency.
&& !context.RequestAborted.IsCancellationRequested)
{
Log.TimeoutNotApplied(_logger, route.Config.RouteId);
// Out of an abundance of caution, refuse the request rather than allowing it to proceed without the configured timeout.
throw new InvalidOperationException($"The timeout was not applied for route '{route.Config.RouteId}', ensure `IApplicationBuilder.UseRequestTimeouts()`"
+ " is called between `IApplicationBuilder.UseRouting()` and `IApplicationBuilder.UseEndpoints()`.");
}
EnsureRequestTimeoutPolicyIsAppliedCorrectly(context, endpoint, route);
#endif

var destinationsState = cluster.DestinationsState;
context.Features.Set<IReverseProxyFeature>(new ReverseProxyFeature
{
Expand Down Expand Up @@ -91,6 +93,66 @@ private async Task AwaitWithActivity(HttpContext context, Activity activity)
}
}

#if NET8_0_OR_GREATER
private void EnsureRequestTimeoutPolicyIsAppliedCorrectly(HttpContext context, Endpoint endpoint, RouteModel route)
{
// There's no way to detect the presence of the timeout middleware before this, only the options.
if (endpoint.Metadata.GetMetadata<RequestTimeoutAttribute>() is { } requestTimeout &&
context.Features.Get<IHttpRequestTimeoutFeature>() is null &&
// The feature is skipped if the request is already canceled. We'll handle canceled requests later for consistency.
!context.RequestAborted.IsCancellationRequested &&
// The policy may set the timeout to null / infinite.
TimeoutPolicyRequestedATimeoutBeSet(requestTimeout))
{
// A timeout should have been set.
// Out of an abundance of caution, refuse the request rather than allowing it to proceed without the configured timeout.
ThrowIfDebuggerNotAttached(route);
}

void ThrowIfDebuggerNotAttached(RouteModel route)
{
// The feature is skipped if the debugger is attached.
if (!Debugger.IsAttached)
{
Log.TimeoutNotApplied(_logger, route.Config.RouteId);

throw new InvalidOperationException(
$"The timeout was not applied for route '{route.Config.RouteId}', " +
"ensure `IApplicationBuilder.UseRequestTimeouts()` is called between " +
"`IApplicationBuilder.UseRouting()` and `IApplicationBuilder.UseEndpoints()`.");
}
}
}

private bool TimeoutPolicyRequestedATimeoutBeSet(RequestTimeoutAttribute requestTimeout)
{
if (requestTimeout.Timeout is not TimeSpan timeout)
{
if (requestTimeout.PolicyName is not string policyName)
{
Debug.Fail("Either Timeout or PolicyName should have been set.");
return false;
}

if (!_timeoutOptions.CurrentValue.Policies.TryGetValue(policyName, out var policy))
{
// This should only happen if the policy existed at some point, but the options were updated to remove it.
return false;
}

if (policy.Timeout is null)
{
// The policy requested no timeout.
return false;
}

timeout = policy.Timeout.Value;
}

return timeout != Timeout.InfiniteTimeSpan;
}
#endif

private static class Log
{
private static readonly Action<ILogger, string, Exception?> _noClusterFound = LoggerMessage.Define<string>(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
using Yarp.Tests.Common;
using Yarp.ReverseProxy.Configuration;
using Yarp.ReverseProxy.Forwarder;
using System.Diagnostics;

namespace Yarp.ReverseProxy.Model.Tests;

Expand Down Expand Up @@ -122,9 +123,12 @@ public async Task Invoke_NoHealthyEndpoints_CallsNext()

Assert.Equal(StatusCodes.Status418ImATeapot, httpContext.Response.StatusCode);
}

#if NET8_0_OR_GREATER
[Fact]
public async Task Invoke_MissingTimeoutMiddleware_RefuseRequest()
[Theory]
[InlineData(1)]
[InlineData(Timeout.Infinite)]
public async Task Invoke_MissingTimeoutMiddleware_RefuseRequest(int timeoutMs)
{
var httpClient = new HttpMessageInvoker(new Mock<HttpMessageHandler>().Object);
var cluster1 = new ClusterState(clusterId: "cluster1")
Expand All @@ -140,15 +144,23 @@ public async Task Invoke_MissingTimeoutMiddleware_RefuseRequest()
var aspNetCoreEndpoint = CreateAspNetCoreEndpoint(routeConfig,
builder =>
{
builder.Metadata.Add(new RequestTimeoutAttribute(1));
builder.Metadata.Add(new RequestTimeoutAttribute(timeoutMs));
});
aspNetCoreEndpoints.Add(aspNetCoreEndpoint);
var httpContext = new DefaultHttpContext();
httpContext.SetEndpoint(aspNetCoreEndpoint);

var sut = Create<ProxyPipelineInitializerMiddleware>();

await Assert.ThrowsAsync<InvalidOperationException>(() => sut.Invoke(httpContext));
if (timeoutMs == Timeout.Infinite || Debugger.IsAttached)
{
// If the timeout was infinite or the debugger is attached, we shouldn't refuse the request.
await sut.Invoke(httpContext);
}
else
{
await Assert.ThrowsAsync<InvalidOperationException>(() => sut.Invoke(httpContext));
}
}
#endif

Expand Down

0 comments on commit c4b5055

Please sign in to comment.