Skip to content

Commit

Permalink
Propagate ChannelShouldQuiesceEvent to child channels (#464)
Browse files Browse the repository at this point in the history
Motivation:

NIO has a 'ChannelShouldQuiesceEvent' which channels can listen for in
order to know when they should quiesce. This is typically used to
initiate a graceful shutdown of an HTTP/2 server. However, child
channels aren't notified of this event so HTTP/2 servers must keep track
of streams separately in order to notify them when the server is
quiescing.

Modifications:

- Propagate the `ChannelShouldQuiesceEvent` to child channels

Result:

Child channels can watch for `ChannelShouldQuiesceEvent`s
  • Loading branch information
glbrntt authored Sep 18, 2024
1 parent 8e667f8 commit 6693a60
Show file tree
Hide file tree
Showing 8 changed files with 149 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,15 @@ extension NIOHTTP2Handler.InboundStreamMultiplexer {
}
}

func userInboundEventReceived(_ event: Any) {
switch self {
case .inline(let multiplexer):
multiplexer.receivedUserInboundEvent(event)
case .legacy:
() // No-op: already sent down the pipeline by the `NIOHTTP2Handler`.
}
}

func channelWritabilityChangedReceived() {
switch self {
case .inline(let inlineStreamMultiplexer):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,10 @@ extension InlineStreamMultiplexer: HTTP2InboundStreamMultiplexer {
self._commonStreamMultiplexer.receivedFrame(frame, context: self.context, multiplexer: .inline(self))
}

func receivedUserInboundEvent(_ event: Any) {
self._commonStreamMultiplexer.selectivelyPropagateUserInboundEvent(context: self.context, event: event)
}

func streamError(streamID: HTTP2StreamID, error: Error) {
let streamError = NIOHTTP2Errors.streamError(streamID: streamID, baseError: error)
self._commonStreamMultiplexer.streamError(context: self.context, streamError)
Expand Down
5 changes: 5 additions & 0 deletions Sources/NIOHTTP2/HTTP2ChannelHandler.swift
Original file line number Diff line number Diff line change
Expand Up @@ -445,6 +445,11 @@ public final class NIOHTTP2Handler: ChannelDuplexHandler {
}
}

public func userInboundEventTriggered(context: ChannelHandlerContext, event: Any) {
self.inboundStreamMultiplexer?.userInboundEventReceived(event)
context.fireUserInboundEventTriggered(event)
}

public func channelRead(context: ChannelHandlerContext, data: NIOAny) {
let data = self.unwrapInboundIn(data)
self.frameDecoder.append(bytes: data)
Expand Down
18 changes: 18 additions & 0 deletions Sources/NIOHTTP2/HTTP2CommonInboundStreamMultiplexer.swift
Original file line number Diff line number Diff line change
Expand Up @@ -209,6 +209,24 @@ extension HTTP2CommonInboundStreamMultiplexer {
self.streamChannelContinuation?.finish()
}

internal func selectivelyPropagateUserInboundEvent(context: ChannelHandlerContext, event: Any) {
func propagateEvent(_ event: Any) {
for channel in self.streams.values {
channel.baseChannel.pipeline.fireUserInboundEventTriggered(event)
}
for channel in self._pendingStreams.values {
channel.baseChannel.pipeline.fireUserInboundEventTriggered(event)
}
}

switch event {
case is ChannelShouldQuiesceEvent:
propagateEvent(event)
default:
()
}
}

internal func propagateChannelWritabilityChanged(context: ChannelHandlerContext) {
for channel in self.streams.values {
channel.parentChannelWritabilityChanged(newValue: context.channel.isWritable)
Expand Down
2 changes: 1 addition & 1 deletion Sources/NIOHTTP2/HTTP2StreamMultiplexer.swift
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ public final class HTTP2StreamMultiplexer: ChannelInboundHandler, ChannelOutboun
case let evt as NIOHTTP2StreamCreatedEvent:
_ = self.commonStreamMultiplexer.streamCreated(event: evt)
default:
break
self.commonStreamMultiplexer.selectivelyPropagateUserInboundEvent(context: context, event: event)
}

context.fireUserInboundEventTriggered(event)
Expand Down
19 changes: 19 additions & 0 deletions Tests/NIOHTTP2Tests/HTTP2FramePayloadStreamMultiplexerTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -2322,3 +2322,22 @@ private final class ReadAndFrameConsumer: ChannelInboundHandler, ChannelOutbound
}
}
}

final class UserInboundEventRecorder: ChannelInboundHandler {
typealias InboundIn = Any

private let receivedEvents: NIOLockedValueBox<[Any]>

var events: [Any] {
self.receivedEvents.withLockedValue { $0 }
}

init() {
self.receivedEvents = NIOLockedValueBox([])
}

func userInboundEventTriggered(context: ChannelHandlerContext, event: Any) {
self.receivedEvents.withLockedValue { $0.append(event) }
context.fireUserInboundEventTriggered(event)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -2185,4 +2185,62 @@ class SimpleClientServerFramePayloadStreamTests: XCTestCase {
XCTAssertNoThrow(try self.clientChannel.finish())
XCTAssertNoThrow(try self.serverChannel.finish())
}

func testChannelShouldQuiesceIsPropagated() throws {
// Setup the connection.
let receivedShouldQuiesceEvent = self.clientChannel.eventLoop.makePromise(of: Void.self)
try self.basicHTTP2Connection { stream in
stream.pipeline.addHandler(ShouldQuiesceEventWaiter(promise: receivedShouldQuiesceEvent))
}

let connectionReceivedShouldQuiesceEvent = self.clientChannel.eventLoop.makePromise(of: Void.self)
try self.serverChannel.pipeline.addHandler(ShouldQuiesceEventWaiter(promise: connectionReceivedShouldQuiesceEvent)).wait()

// Create the stream channel.
let multiplexer = try self.clientChannel.pipeline.handler(type: HTTP2StreamMultiplexer.self).wait()
let streamPromise = self.clientChannel.eventLoop.makePromise(of: Channel.self)
multiplexer.createStreamChannel(promise: streamPromise) {
$0.eventLoop.makeSucceededVoidFuture()
}
self.clientChannel.embeddedEventLoop.run()
let stream = try streamPromise.futureResult.wait()

// Initiate request to open the stream on the server.
let headers = HPACKHeaders([(":path", "/"), (":method", "POST"), (":scheme", "http")])
let frame: HTTP2Frame.FramePayload = .headers(.init(headers: headers))
stream.writeAndFlush(frame, promise: nil)
self.interactInMemory(self.clientChannel, self.serverChannel)

// Fire the event on the server pipeline, this should propagate to the stream channel and
// the connection channel.
self.serverChannel.pipeline.fireUserInboundEventTriggered(ChannelShouldQuiesceEvent())
XCTAssertNoThrow(try receivedShouldQuiesceEvent.futureResult.wait())
XCTAssertNoThrow(try connectionReceivedShouldQuiesceEvent.futureResult.wait())

XCTAssertNoThrow(try self.clientChannel.finish())
XCTAssertNoThrow(try self.serverChannel.finish())
}
}


final class ShouldQuiesceEventWaiter: ChannelInboundHandler {
typealias InboundIn = Never

private let promise: EventLoopPromise<Void>

init(promise: EventLoopPromise<Void>) {
self.promise = promise
}

func channelInactive(context: ChannelHandlerContext) {
self.promise.fail(ChannelError.eof)
context.fireChannelInactive()
}

func userInboundEventTriggered(context: ChannelHandlerContext, event: Any) {
if event is ChannelShouldQuiesceEvent {
self.promise.succeed(())
}
context.fireUserInboundEventTriggered(event)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -408,4 +408,39 @@ class SimpleClientServerInlineStreamMultiplexerTests: XCTestCase {
XCTAssertNoThrow(try self.clientChannel.finish())
XCTAssertNoThrow(try self.serverChannel.finish())
}

func testChannelShouldQuiesceIsPropagated() throws {
// Setup the connection.
let receivedShouldQuiesceEvent = self.clientChannel.eventLoop.makePromise(of: Void.self)
try self.basicHTTP2Connection { stream in
stream.pipeline.addHandler(ShouldQuiesceEventWaiter(promise: receivedShouldQuiesceEvent))
}

let connectionReceivedShouldQuiesceEvent = self.clientChannel.eventLoop.makePromise(of: Void.self)
try self.serverChannel.pipeline.addHandler(ShouldQuiesceEventWaiter(promise: connectionReceivedShouldQuiesceEvent)).wait()

// Create the stream channel.
let multiplexer = try self.clientChannel.pipeline.handler(type: NIOHTTP2Handler.self).flatMap { $0.multiplexer }.wait()
let streamPromise = self.clientChannel.eventLoop.makePromise(of: Channel.self)
multiplexer.createStreamChannel(promise: streamPromise) {
$0.eventLoop.makeSucceededVoidFuture()
}
self.clientChannel.embeddedEventLoop.run()
let stream = try streamPromise.futureResult.wait()

// Initiate request to open the stream on the server.
let headers = HPACKHeaders([(":path", "/"), (":method", "POST"), (":scheme", "http")])
let frame: HTTP2Frame.FramePayload = .headers(.init(headers: headers))
stream.writeAndFlush(frame, promise: nil)
self.interactInMemory(self.clientChannel, self.serverChannel)

// Fire the event on the server pipeline, this should propagate to the stream channel and
// the connection channel.
self.serverChannel.pipeline.fireUserInboundEventTriggered(ChannelShouldQuiesceEvent())
XCTAssertNoThrow(try receivedShouldQuiesceEvent.futureResult.wait())
XCTAssertNoThrow(try connectionReceivedShouldQuiesceEvent.futureResult.wait())

XCTAssertNoThrow(try self.clientChannel.finish())
XCTAssertNoThrow(try self.serverChannel.finish())
}
}

0 comments on commit 6693a60

Please sign in to comment.