Skip to content

Commit

Permalink
[HTTP/3] Abort response stream on dispose if content not finished (#5…
Browse files Browse the repository at this point in the history
…7156)

* Sends abort read/write if H/3 stream is disposed before respective contents are finsihed

* Minor tweaks in abort conditions

* Prevent reverting SendState from Aborted/ConnectionClosed back to sending state within Send* methods.
  • Loading branch information
ManickaP authored Aug 19, 2021
1 parent 9a55354 commit b75e55b
Show file tree
Hide file tree
Showing 4 changed files with 175 additions and 49 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,10 @@ internal sealed class Http3LoopbackStream : IDisposable
private const int MaximumVarIntBytes = 8;
private const long VarIntMax = (1L << 62) - 1;

private const long DataFrame = 0x0;
private const long HeadersFrame = 0x1;
private const long SettingsFrame = 0x4;
private const long GoAwayFrame = 0x7;
public const long DataFrame = 0x0;
public const long HeadersFrame = 0x1;
public const long SettingsFrame = 0x4;
public const long GoAwayFrame = 0x7;

public const long ControlStream = 0x0;
public const long PushStream = 0x1;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,7 @@ public void Dispose()
if (!_disposed)
{
_disposed = true;
AbortStream();
_stream.Dispose();
DisposeSyncHelper();
}
Expand All @@ -94,6 +95,7 @@ public async ValueTask DisposeAsync()
if (!_disposed)
{
_disposed = true;
AbortStream();
await _stream.DisposeAsync().ConfigureAwait(false);
DisposeSyncHelper();
}
Expand Down Expand Up @@ -358,6 +360,9 @@ private async Task SendContentAsync(HttpContent content, CancellationToken cance
await content.CopyToAsync(writeStream, null, cancellationToken).ConfigureAwait(false);
}

// Set to 0 to recognize that the whole request body has been sent and therefore there's no need to abort write side in case of a premature disposal.
_requestContentLengthRemaining = 0;

if (_sendBuffer.ActiveLength != 0)
{
// Our initial send buffer, which has our headers, is normally sent out on the first write to the Http3WriteStream.
Expand Down Expand Up @@ -1210,6 +1215,20 @@ private async ValueTask<bool> ReadNextDataFrameAsync(HttpResponseMessage respons
public void Trace(string message, [CallerMemberName] string? memberName = null) =>
_connection.Trace(StreamId, message, memberName);

private void AbortStream()
{
// If the request body isn't completed, cancel it now.
if (_requestContentLengthRemaining != 0) // 0 is used for the end of content writing, -1 is used for unknown Content-Length
{
_stream.AbortWrite((long)Http3ErrorCode.RequestCancelled);
}
// If the response body isn't completed, cancel it now.
if (_responseDataPayloadRemaining != -1) // -1 is used for EOF, 0 for consumed DATA frame payload before the next read
{
_stream.AbortRead((long)Http3ErrorCode.RequestCancelled);
}
}

// TODO: it may be possible for Http3RequestStream to implement Stream directly and avoid this allocation.
private sealed class Http3ReadStream : HttpBaseStream
{
Expand All @@ -1233,36 +1252,42 @@ public Http3ReadStream(Http3RequestStream stream)

protected override void Dispose(bool disposing)
{
if (_stream != null)
Http3RequestStream? stream = Interlocked.Exchange(ref _stream, null);
if (stream is null)
{
if (disposing)
{
// This will remove the stream from the connection properly.
_stream.Dispose();
}
else
{
// We shouldn't be using a managed instance here, but don't have much choice -- we
// need to remove the stream from the connection's GOAWAY collection.
_stream._connection.RemoveStream(_stream._stream);
_stream._connection = null!;
}
return;
}

_stream = null;
_response = null;
if (disposing)
{
// This will remove the stream from the connection properly.
stream.Dispose();
}
else
{
// We shouldn't be using a managed instance here, but don't have much choice -- we
// need to remove the stream from the connection's GOAWAY collection.
stream._connection.RemoveStream(stream._stream);
stream._connection = null!;
}

_response = null;

base.Dispose(disposing);
}

public override async ValueTask DisposeAsync()
{
if (_stream != null)
Http3RequestStream? stream = Interlocked.Exchange(ref _stream, null);
if (stream is null)
{
await _stream.DisposeAsync().ConfigureAwait(false);
_stream = null!;
return;
}

await stream.DisposeAsync().ConfigureAwait(false);

_response = null;

await base.DisposeAsync().ConfigureAwait(false);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -319,6 +319,119 @@ public async Task ReservedFrameType_Throws()
await new[] { clientTask, serverTask }.WhenAllOrAnyFailed(20_000);
}

[Fact]
public async Task RequestSentResponseDisposed_ThrowsOnServer()
{
byte[] data = Encoding.UTF8.GetBytes(new string('a', 1024));

using Http3LoopbackServer server = CreateHttp3LoopbackServer();

Task serverTask = Task.Run(async () =>
{
using Http3LoopbackConnection connection = (Http3LoopbackConnection)await server.EstablishGenericConnectionAsync();
using Http3LoopbackStream stream = await connection.AcceptRequestStreamAsync();
HttpRequestData request = await stream.ReadRequestDataAsync();
await stream.SendResponseHeadersAsync();

Stopwatch sw = Stopwatch.StartNew();
bool hasFailed = false;
while (sw.Elapsed < TimeSpan.FromSeconds(15))
{
try
{
await stream.SendResponseBodyAsync(data, isFinal: false);
}
catch (QuicStreamAbortedException)
{
hasFailed = true;
break;
}
}
Assert.True(hasFailed, $"Expected {nameof(QuicStreamAbortedException)}, instead ran successfully for {sw.Elapsed}");
});

Task clientTask = Task.Run(async () =>
{
using HttpClient client = CreateHttpClient();
using HttpRequestMessage request = new()
{
Method = HttpMethod.Get,
RequestUri = server.Address,
Version = HttpVersion30,
VersionPolicy = HttpVersionPolicy.RequestVersionExact
};

var response = await client.SendAsync(request, HttpCompletionOption.ResponseHeadersRead);
var stream = await response.Content.ReadAsStreamAsync();
byte[] buffer = new byte[512];
for (int i = 0; i < 5; ++i)
{
var count = await stream.ReadAsync(buffer);
}

// We haven't finished reading the whole respose, but we're disposing it, which should turn into an exception on the server-side.
response.Dispose();
await serverTask;
});

await new[] { clientTask, serverTask }.WhenAllOrAnyFailed(20_000);
}

[Fact]
public async Task RequestSendingResponseDisposed_ThrowsOnServer()
{
byte[] data = Encoding.UTF8.GetBytes(new string('a', 1024));

using Http3LoopbackServer server = CreateHttp3LoopbackServer();

Task serverTask = Task.Run(async () =>
{
using Http3LoopbackConnection connection = (Http3LoopbackConnection)await server.EstablishGenericConnectionAsync();
using Http3LoopbackStream stream = await connection.AcceptRequestStreamAsync();
HttpRequestData request = await stream.ReadRequestDataAsync(false);
await stream.SendResponseHeadersAsync();

Stopwatch sw = Stopwatch.StartNew();
bool hasFailed = false;
while (sw.Elapsed < TimeSpan.FromSeconds(15))
{
try
{
var (frameType, payload) = await stream.ReadFrameAsync();
Assert.Equal(Http3LoopbackStream.DataFrame, frameType);
}
catch (QuicStreamAbortedException)
{
hasFailed = true;
break;
}
}
Assert.True(hasFailed, $"Expected {nameof(QuicStreamAbortedException)}, instead ran successfully for {sw.Elapsed}");
});

Task clientTask = Task.Run(async () =>
{
using HttpClient client = CreateHttpClient();
using HttpRequestMessage request = new()
{
Method = HttpMethod.Get,
RequestUri = server.Address,
Version = HttpVersion30,
VersionPolicy = HttpVersionPolicy.RequestVersionExact,
Content = new ByteAtATimeContent(60*4, Task.CompletedTask, new TaskCompletionSource<bool>(), 250)
};

var response = await client.SendAsync(request, HttpCompletionOption.ResponseHeadersRead);
var stream = await response.Content.ReadAsStreamAsync();

// We haven't finished sending the whole request, but we're disposing the response, which should turn into an exception on the server-side.
response.Dispose();
await serverTask;
});

await new[] { clientTask, serverTask }.WhenAllOrAnyFailed(20_000);
}

[Fact]
public async Task ServerCertificateCustomValidationCallback_Succeeds()
{
Expand Down Expand Up @@ -885,7 +998,7 @@ public async Task StatusCodes_ReceiveSuccess(HttpStatusCode statusCode, bool qpa
VersionPolicy = HttpVersionPolicy.RequestVersionExact
};
HttpResponseMessage response = await client.SendAsync(request).WaitAsync(TimeSpan.FromSeconds(10));

Assert.Equal(statusCode, response.StatusCode);

await serverTask;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -265,7 +265,7 @@ internal override async ValueTask WriteAsync(ReadOnlySequence<byte> buffers, boo
{
ThrowIfDisposed();

using CancellationTokenRegistration registration = HandleWriteStartState(cancellationToken);
using CancellationTokenRegistration registration = HandleWriteStartState(buffers.IsEmpty, cancellationToken);

await SendReadOnlySequenceAsync(buffers, endStream ? QUIC_SEND_FLAGS.FIN : QUIC_SEND_FLAGS.NONE).ConfigureAwait(false);

Expand All @@ -281,7 +281,7 @@ internal override async ValueTask WriteAsync(ReadOnlyMemory<ReadOnlyMemory<byte>
{
ThrowIfDisposed();

using CancellationTokenRegistration registration = HandleWriteStartState(cancellationToken);
using CancellationTokenRegistration registration = HandleWriteStartState(buffers.IsEmpty, cancellationToken);

await SendReadOnlyMemoryListAsync(buffers, endStream ? QUIC_SEND_FLAGS.FIN : QUIC_SEND_FLAGS.NONE).ConfigureAwait(false);

Expand All @@ -292,20 +292,20 @@ internal override async ValueTask WriteAsync(ReadOnlyMemory<byte> buffer, bool e
{
ThrowIfDisposed();

using CancellationTokenRegistration registration = HandleWriteStartState(cancellationToken);
using CancellationTokenRegistration registration = HandleWriteStartState(buffer.IsEmpty, cancellationToken);

await SendReadOnlyMemoryAsync(buffer, endStream ? QUIC_SEND_FLAGS.FIN : QUIC_SEND_FLAGS.NONE).ConfigureAwait(false);

HandleWriteCompletedState();
}

private CancellationTokenRegistration HandleWriteStartState(CancellationToken cancellationToken)
private CancellationTokenRegistration HandleWriteStartState(bool emptyBuffer, CancellationToken cancellationToken)
{
if (_state.SendState == SendState.Closed)
{
throw new InvalidOperationException(SR.net_quic_writing_notallowed);
}
else if ( _state.SendState == SendState.Aborted)
if (_state.SendState == SendState.Aborted)
{
if (_state.SendErrorCode != -1)
{
Expand Down Expand Up @@ -363,10 +363,14 @@ private CancellationTokenRegistration HandleWriteStartState(CancellationToken ca

throw new OperationCanceledException(SR.net_quic_sending_aborted);
}
else if (_state.SendState == SendState.ConnectionClosed)
if (_state.SendState == SendState.ConnectionClosed)
{
throw GetConnectionAbortedException(_state);
}

// Change the state in the same lock where we check for final states to prevent coming back from Aborted/ConnectionClosed.
Debug.Assert(_state.SendState != SendState.Pending);
_state.SendState = emptyBuffer ? SendState.Finished : SendState.Pending;
}

return registration;
Expand Down Expand Up @@ -632,7 +636,10 @@ internal override void Shutdown()

lock (_state)
{
_state.SendState = SendState.Finished;
if (_state.SendState < SendState.Finished)
{
_state.SendState = SendState.Finished;
}
}

// it is ok to send shutdown several times, MsQuic will ignore it
Expand Down Expand Up @@ -1157,12 +1164,6 @@ private unsafe ValueTask SendReadOnlyMemoryAsync(
ReadOnlyMemory<byte> buffer,
QUIC_SEND_FLAGS flags)
{
lock (_state)
{
Debug.Assert(_state.SendState != SendState.Pending);
_state.SendState = buffer.IsEmpty ? SendState.Finished : SendState.Pending;
}

if (buffer.IsEmpty)
{
if ((flags & QUIC_SEND_FLAGS.FIN) == QUIC_SEND_FLAGS.FIN)
Expand Down Expand Up @@ -1211,13 +1212,6 @@ private unsafe ValueTask SendReadOnlySequenceAsync(
ReadOnlySequence<byte> buffers,
QUIC_SEND_FLAGS flags)
{

lock (_state)
{
Debug.Assert(_state.SendState != SendState.Pending);
_state.SendState = buffers.IsEmpty ? SendState.Finished : SendState.Pending;
}

if (buffers.IsEmpty)
{
if ((flags & QUIC_SEND_FLAGS.FIN) == QUIC_SEND_FLAGS.FIN)
Expand Down Expand Up @@ -1281,12 +1275,6 @@ private unsafe ValueTask SendReadOnlyMemoryListAsync(
ReadOnlyMemory<ReadOnlyMemory<byte>> buffers,
QUIC_SEND_FLAGS flags)
{
lock (_state)
{
Debug.Assert(_state.SendState != SendState.Pending);
_state.SendState = buffers.IsEmpty ? SendState.Finished : SendState.Pending;
}

if (buffers.IsEmpty)
{
if ((flags & QUIC_SEND_FLAGS.FIN) == QUIC_SEND_FLAGS.FIN)
Expand Down

0 comments on commit b75e55b

Please sign in to comment.