diff --git a/Sources/NIOHTTP2/HTTP2Error.swift b/Sources/NIOHTTP2/HTTP2Error.swift index 402b5c56..b20330e5 100644 --- a/Sources/NIOHTTP2/HTTP2Error.swift +++ b/Sources/NIOHTTP2/HTTP2Error.swift @@ -200,6 +200,16 @@ public enum NIOHTTP2Errors { self.streamID = streamID } } + + /// An attempt was made to send trailers without setting END_STREAM on them. + public struct TrailersWithoutEndStream: NIOHTTP2Error { + /// The affected stream ID. + public var streamID: HTTP2StreamID + + public init(streamID: HTTP2StreamID) { + self.streamID = streamID + } + } } diff --git a/Sources/NIOHTTP2/StreamStateMachine.swift b/Sources/NIOHTTP2/StreamStateMachine.swift index 3401b70f..3f2e31c0 100644 --- a/Sources/NIOHTTP2/StreamStateMachine.swift +++ b/Sources/NIOHTTP2/StreamStateMachine.swift @@ -850,6 +850,12 @@ extension HTTP2StreamStateMachine { /// target final state. private mutating func processTrailers(_ headers: HPACKHeaders, isEndStreamSet endStream: Bool, targetState target: State, targetEffect effect: StreamStateChange?) -> StateMachineResultWithStreamEffect { // TODO(cory): Implement + + // End stream must be set on trailers. + guard endStream else { + return StateMachineResultWithStreamEffect(result: .streamError(streamID: self.streamID, underlyingError: NIOHTTP2Errors.TrailersWithoutEndStream(streamID: self.streamID), type: .protocolError), effect: nil) + } + self.state = target return StateMachineResultWithStreamEffect(result: .succeed, effect: effect) } diff --git a/Tests/NIOHTTP2Tests/ConnectionStateMachineTests+XCTest.swift b/Tests/NIOHTTP2Tests/ConnectionStateMachineTests+XCTest.swift index 770349c0..c11709b6 100644 --- a/Tests/NIOHTTP2Tests/ConnectionStateMachineTests+XCTest.swift +++ b/Tests/NIOHTTP2Tests/ConnectionStateMachineTests+XCTest.swift @@ -81,6 +81,8 @@ extension ConnectionStateMachineTests { ("testDisablingPushPreventsPush", testDisablingPushPreventsPush), ("testRatchetingGoawayEvenWhenFullyQueisced", testRatchetingGoawayEvenWhenFullyQueisced), ("testRatchetingGoawayForBothPeersEvenWhenFullyQuiesced", testRatchetingGoawayForBothPeersEvenWhenFullyQuiesced), + ("testClientTrailersMustHaveEndStreamSet", testClientTrailersMustHaveEndStreamSet), + ("testServerTrailersMustHaveEndStreamSet", testServerTrailersMustHaveEndStreamSet), ] } } diff --git a/Tests/NIOHTTP2Tests/ConnectionStateMachineTests.swift b/Tests/NIOHTTP2Tests/ConnectionStateMachineTests.swift index 5c92ecf2..7ee6f0b0 100644 --- a/Tests/NIOHTTP2Tests/ConnectionStateMachineTests.swift +++ b/Tests/NIOHTTP2Tests/ConnectionStateMachineTests.swift @@ -1476,8 +1476,8 @@ class ConnectionStateMachineTests: XCTestCase { assertSucceeds(self.client.sendHeaders(streamID: streamOne, headers: ConnectionStateMachineTests.requestHeaders, isEndStreamSet: false)) assertSucceeds(self.server.receiveHeaders(streamID: streamOne, headers: ConnectionStateMachineTests.requestHeaders, isEndStreamSet: false)) - assertSucceeds(self.client.sendHeaders(streamID: streamOne, headers: ConnectionStateMachineTests.trailers, isEndStreamSet: false)) - assertSucceeds(self.server.receiveHeaders(streamID: streamOne, headers: ConnectionStateMachineTests.trailers, isEndStreamSet: false)) + assertSucceeds(self.client.sendHeaders(streamID: streamOne, headers: ConnectionStateMachineTests.trailers, isEndStreamSet: true)) + assertSucceeds(self.server.receiveHeaders(streamID: streamOne, headers: ConnectionStateMachineTests.trailers, isEndStreamSet: true)) assertSucceeds(self.server.sendHeaders(streamID: streamOne, headers: ConnectionStateMachineTests.responseHeaders, isEndStreamSet: false)) assertSucceeds(self.client.receiveHeaders(streamID: streamOne, headers: ConnectionStateMachineTests.responseHeaders, isEndStreamSet: false)) assertSucceeds(self.server.sendPushPromise(originalStreamID: streamOne, childStreamID: streamTwo, headers: ConnectionStateMachineTests.requestHeaders)) @@ -1721,5 +1721,37 @@ class ConnectionStateMachineTests: XCTestCase { assertGoawaySucceeds(self.client.sendGoaway(lastStreamID: .rootStream), droppingStreams: nil) assertGoawaySucceeds(self.server.receiveGoaway(lastStreamID: .rootStream), droppingStreams: nil) } + + func testClientTrailersMustHaveEndStreamSet() { + let streamOne = HTTP2StreamID(1) + + self.exchangePreamble() + + // Client initiates a stream. + assertSucceeds(self.client.sendHeaders(streamID: streamOne, headers: ConnectionStateMachineTests.requestHeaders, isEndStreamSet: false)) + assertSucceeds(self.server.receiveHeaders(streamID: streamOne, headers: ConnectionStateMachineTests.requestHeaders, isEndStreamSet: false)) + + // If the client attempts to send trailers without end stream, this fails with a stream error. + assertStreamError(type: .protocolError, self.client.sendHeaders(streamID: streamOne, headers: ConnectionStateMachineTests.trailers, isEndStreamSet: false)) + assertStreamError(type: .protocolError, self.server.receiveHeaders(streamID: streamOne, headers: ConnectionStateMachineTests.trailers, isEndStreamSet: false)) + } + + func testServerTrailersMustHaveEndStreamSet() { + let streamOne = HTTP2StreamID(1) + + self.exchangePreamble() + + // Client initiates a stream. + assertSucceeds(self.client.sendHeaders(streamID: streamOne, headers: ConnectionStateMachineTests.requestHeaders, isEndStreamSet: true)) + assertSucceeds(self.server.receiveHeaders(streamID: streamOne, headers: ConnectionStateMachineTests.requestHeaders, isEndStreamSet: true)) + + // Server sends back headers. + assertSucceeds(self.server.sendHeaders(streamID: streamOne, headers: ConnectionStateMachineTests.responseHeaders, isEndStreamSet: false)) + assertSucceeds(self.client.receiveHeaders(streamID: streamOne, headers: ConnectionStateMachineTests.responseHeaders, isEndStreamSet: false)) + + // If the server attempts to send trailers without end stream, this fails with a stream error. + assertStreamError(type: .protocolError, self.server.sendHeaders(streamID: streamOne, headers: ConnectionStateMachineTests.trailers, isEndStreamSet: false)) + assertStreamError(type: .protocolError, self.client.receiveHeaders(streamID: streamOne, headers: ConnectionStateMachineTests.trailers, isEndStreamSet: false)) + } }