Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix HttpContext race condition by copying values to reader and writer #2294

Merged
merged 5 commits into from
Oct 18, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 14 additions & 1 deletion src/Grpc.AspNetCore.Server/Internal/GrpcProtocolHelpers.cs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
#region Copyright notice and license
#region Copyright notice and license

// Copyright 2019 The gRPC Authors
//
Expand Down Expand Up @@ -234,4 +234,17 @@ internal static bool ShouldSkipHeader(string name)
{
return name.StartsWith(':') || GrpcProtocolConstants.FilteredHeaders.Contains(name);
}

internal static IHttpRequestLifetimeFeature GetRequestLifetimeFeature(HttpContext httpContext)
{
var lifetimeFeature = httpContext.Features.Get<IHttpRequestLifetimeFeature>();
if (lifetimeFeature is null)
{
// This should only run in tests where the HttpContext is manually created.
lifetimeFeature = new HttpRequestLifetimeFeature();
httpContext.Features.Set(lifetimeFeature);
}

return lifetimeFeature;
}
}
14 changes: 12 additions & 2 deletions src/Grpc.AspNetCore.Server/Internal/HttpContextStreamReader.cs
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,10 @@
#endregion

using System.Diagnostics;
using System.IO.Pipelines;
using Grpc.Core;
using Grpc.Shared;
using Microsoft.AspNetCore.Http.Features;

namespace Grpc.AspNetCore.Server.Internal;

Expand All @@ -28,6 +30,8 @@ internal class HttpContextStreamReader<TRequest> : IAsyncStreamReader<TRequest>
{
private readonly HttpContextServerCallContext _serverCallContext;
private readonly Func<DeserializationContext, TRequest> _deserializer;
private readonly PipeReader _bodyReader;
private readonly IHttpRequestLifetimeFeature _requestLifetimeFeature;
private bool _completed;
private long _readCount;
private bool _endOfStream;
Expand All @@ -36,6 +40,12 @@ public HttpContextStreamReader(HttpContextServerCallContext serverCallContext, F
{
_serverCallContext = serverCallContext;
_deserializer = deserializer;

// Copy HttpContext values.
// This is done to avoid a race condition when reading them from HttpContext later when running in a separate thread.
_bodyReader = _serverCallContext.HttpContext.Request.BodyReader;
// Copy lifetime feature because HttpContext.RequestAborted on .NET 6 doesn't return the real cancellation token.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you explain both the 6.0 comment as well as this whole PR?

From my understanding; If you're using the cancellation token on a background thread it might be disposed instead of canceled when the HttpContext is reused. And if you access the token property after it's been reset you might be using a token from another request.

Copy link
Member Author

@JamesNK JamesNK Oct 16, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Root problem: dotnet/aspnetcore#42040
The fix it to access the cancellation token on the feature directly. HttpContext.RequestAborted isn't thread-safe but the feature is.

re: .NET 6. I observed that the cancellation token returned by HttpContext.RequestAborted and cached at the start of the request behaved strangely. It wouldn't immediately be canceled when the request was aborted. The cached HttpContext.RequestAborted value appeared to be different from the value after the abort happened (but while the request was still in progress, so the value isn't coming from another request). I didn't investigate why I got this odd behavior in .NET 6 since there was a workaround - access the token from the feature - and .NET 7 and .NET 8 behaved how I expected.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Root problem: dotnet/aspnetcore#42040
The fix it to access the cancellation token on the feature directly. HttpContext.RequestAborted isn't thread-safe but the feature is.

Oh, so gRPC isn't using the token after the middleware returns? That's what my comment was focused on.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No. This is in the context of one request.

_requestLifetimeFeature = GrpcProtocolHelpers.GetRequestLifetimeFeature(_serverCallContext.HttpContext);
}

public TRequest Current { get; private set; } = default!;
Expand All @@ -54,7 +64,7 @@ async Task<bool> MoveNextAsync(ValueTask<TRequest?> readStreamTask)
return Task.FromCanceled<bool>(cancellationToken);
}

if (_completed || _serverCallContext.CancellationToken.IsCancellationRequested)
if (_completed || _requestLifetimeFeature.RequestAborted.IsCancellationRequested)
{
return Task.FromException<bool>(new InvalidOperationException("Can't read messages after the request is complete."));
}
Expand All @@ -63,7 +73,7 @@ async Task<bool> MoveNextAsync(ValueTask<TRequest?> readStreamTask)
// In a long running stream this can allow the previous value to be GCed.
Current = null!;

var request = _serverCallContext.HttpContext.Request.BodyReader.ReadStreamMessageAsync(_serverCallContext, _deserializer, cancellationToken);
var request = _bodyReader.ReadStreamMessageAsync(_serverCallContext, _deserializer, cancellationToken);
if (!request.IsCompletedSuccessfully)
{
return MoveNextAsync(request);
Expand Down
14 changes: 12 additions & 2 deletions src/Grpc.AspNetCore.Server/Internal/HttpContextStreamWriter.cs
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,10 @@
#endregion

using System.Diagnostics;
using System.IO.Pipelines;
using Grpc.Core;
using Grpc.Shared;
using Microsoft.AspNetCore.Http.Features;

namespace Grpc.AspNetCore.Server.Internal;

Expand All @@ -29,6 +31,8 @@ internal class HttpContextStreamWriter<TResponse> : IServerStreamWriter<TRespons
{
private readonly HttpContextServerCallContext _context;
private readonly Action<TResponse, SerializationContext> _serializer;
private readonly PipeWriter _bodyWriter;
private readonly IHttpRequestLifetimeFeature _requestLifetimeFeature;
private readonly object _writeLock;
private Task? _writeTask;
private bool _completed;
Expand All @@ -39,6 +43,12 @@ public HttpContextStreamWriter(HttpContextServerCallContext context, Action<TRes
_context = context;
_serializer = serializer;
_writeLock = new object();

// Copy HttpContext values.
// This is done to avoid a race condition when reading them from HttpContext later when running in a separate thread.
_bodyWriter = context.HttpContext.Response.BodyWriter;
// Copy lifetime feature because HttpContext.RequestAborted on .NET 6 doesn't return the real cancellation token.
_requestLifetimeFeature = GrpcProtocolHelpers.GetRequestLifetimeFeature(context.HttpContext);
}

public WriteOptions? WriteOptions
Expand Down Expand Up @@ -77,7 +87,7 @@ private async Task WriteCoreAsync(TResponse message, CancellationToken cancellat
{
cancellationToken.ThrowIfCancellationRequested();

if (_completed || _context.CancellationToken.IsCancellationRequested)
if (_completed || _requestLifetimeFeature.RequestAborted.IsCancellationRequested)
{
throw new InvalidOperationException("Can't write the message because the request is complete.");
}
Expand All @@ -91,7 +101,7 @@ private async Task WriteCoreAsync(TResponse message, CancellationToken cancellat
}

// Save write task to track whether it is complete. Must be set inside lock.
_writeTask = _context.HttpContext.Response.BodyWriter.WriteStreamedMessageAsync(message, _context, _serializer, cancellationToken);
_writeTask = _bodyWriter.WriteStreamedMessageAsync(message, _context, _serializer, cancellationToken);
}

await _writeTask;
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
#region Copyright notice and license

// Copyright 2019 The gRPC Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#endregion

using Grpc.AspNetCore.Server.Internal.CallHandlers;
using Grpc.AspNetCore.Server.Tests.TestObjects;
using Grpc.Core;
using Grpc.Shared.Server;
using Grpc.Tests.Shared;
using Microsoft.AspNetCore.Http;
using Microsoft.AspNetCore.Http.Features;
using Microsoft.Extensions.Logging.Abstractions;
using NUnit.Framework;

namespace Grpc.AspNetCore.Server.Tests;

[TestFixture]
public class DuplexStreamingServerCallHandlerTests
{
private static readonly Marshaller<TestMessage> _marshaller = new Marshaller<TestMessage>((message, context) => { context.Complete(Array.Empty<byte>()); }, context => new TestMessage());

[Test]
public async Task HandleCallAsync_ConcurrentReadAndWrite_Success()
{
// Arrange
var invoker = new DuplexStreamingServerMethodInvoker<TestService, TestMessage, TestMessage>(
(service, reader, writer, context) =>
{
var message = new TestMessage();
var readTask = Task.Run(() => reader.MoveNext());
var writeTask = Task.Run(() => writer.WriteAsync(message));
return Task.WhenAll(readTask, writeTask);
},
new Method<TestMessage, TestMessage>(MethodType.DuplexStreaming, "test", "test", _marshaller, _marshaller),
HttpContextServerCallContextHelper.CreateMethodOptions(),
new TestGrpcServiceActivator<TestService>());
var handler = new DuplexStreamingServerCallHandler<TestService, TestMessage, TestMessage>(invoker, NullLoggerFactory.Instance);

// Verify there isn't a race condition when reading/writing on seperate threads.
// This test primarily exists to ensure that the stream reader and stream writer aren't accessing non-thread safe APIs on HttpContext.
for (var i = 0; i < 10_000; i++)
{
var httpContext = HttpContextHelpers.CreateContext();

// Act
await handler.HandleCallAsync(httpContext);

// Assert
var trailers = httpContext.Features.Get<IHttpResponseTrailersFeature>()!.Trailers;
Assert.AreEqual("0", trailers["grpc-status"].ToString());
}
}
}