Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Conform client/server connection handlers to h2 stream delegate #1875

Merged
merged 3 commits into from
May 22, 2024
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,9 @@ final class ClientConnectionHandler: ChannelInboundHandler, ChannelOutboundHandl
/// Resets once `channelReadComplete` returns.
private var inReadLoop: Bool

/// The context of the channel this handler is in.
private var context: ChannelHandlerContext?

/// Creates a new handler which manages the lifecycle of a connection.
///
/// - Parameters:
Expand Down Expand Up @@ -118,6 +121,11 @@ final class ClientConnectionHandler: ChannelInboundHandler, ChannelOutboundHandl

func handlerAdded(context: ChannelHandlerContext) {
assert(context.eventLoop === self.eventLoop)
self.context = context
}

func handlerRemoved(context: ChannelHandlerContext) {
self.context = nil
}

func channelActive(context: ChannelHandlerContext) {
Expand All @@ -144,31 +152,10 @@ final class ClientConnectionHandler: ChannelInboundHandler, ChannelOutboundHandl
func userInboundEventTriggered(context: ChannelHandlerContext, event: Any) {
switch event {
case let event as NIOHTTP2StreamCreatedEvent:
// Stream created, so the connection isn't idle.
self.maxIdleTimer?.cancel()
self.state.streamOpened(event.streamID)
self.streamCreated(event.streamID, channel: context.channel)

case let event as StreamClosedEvent:
switch self.state.streamClosed(event.streamID) {
case .startIdleTimer(let cancelKeepalive):
// All streams are closed, restart the idle timer, and stop the keep-alive timer (it may
// not stop if keep-alive is allowed when there are no active calls).
self.maxIdleTimer?.schedule(on: context.eventLoop) {
self.maxIdleTimerFired(context: context)
}

if cancelKeepalive {
self.keepaliveTimer?.cancel()
}

case .close:
// Connection was closing but waiting for all streams to close. They must all be closed
// now so close the connection.
context.close(promise: nil)

case .none:
()
}
self.streamClosed(event.streamID, channel: context.channel)

default:
()
Expand Down Expand Up @@ -263,6 +250,42 @@ final class ClientConnectionHandler: ChannelInboundHandler, ChannelOutboundHandl
}
}

extension ClientConnectionHandler: NIOHTTP2StreamDelegate {
func streamCreated(_ id: HTTP2StreamID, channel: any Channel) {
self.eventLoop.assertInEventLoop()

// Stream created, so the connection isn't idle.
self.maxIdleTimer?.cancel()
self.state.streamOpened(id)
}

func streamClosed(_ id: HTTP2StreamID, channel: any Channel) {
guard let context = self.context else { return }
self.eventLoop.assertInEventLoop()

switch self.state.streamClosed(id) {
case .startIdleTimer(let cancelKeepalive):
// All streams are closed, restart the idle timer, and stop the keep-alive timer (it may
// not stop if keep-alive is allowed when there are no active calls).
self.maxIdleTimer?.schedule(on: context.eventLoop) {
self.maxIdleTimerFired(context: context)
}

if cancelKeepalive {
self.keepaliveTimer?.cancel()
}

case .close:
// Connection was closing but waiting for all streams to close. They must all be closed
// now so close the connection.
context.close(promise: nil)

case .none:
()
}
}
}

extension ClientConnectionHandler {
private func maybeFlush(context: ChannelHandlerContext) {
if self.inReadLoop {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,9 @@ final class ServerConnectionManagementHandler: ChannelDuplexHandler {
/// Resets once `channelReadComplete` returns.
private var inReadLoop: Bool

/// The context of the channel this handler is in.
private var context: ChannelHandlerContext?

/// The current state of the connection.
private var state: StateMachine

Expand Down Expand Up @@ -236,6 +239,11 @@ final class ServerConnectionManagementHandler: ChannelDuplexHandler {

func handlerAdded(context: ChannelHandlerContext) {
assert(context.eventLoop === self.eventLoop)
self.context = context
}

func handlerRemoved(context: ChannelHandlerContext) {
self.context = nil
}

func channelActive(context: ChannelHandlerContext) {
Expand Down Expand Up @@ -266,23 +274,10 @@ final class ServerConnectionManagementHandler: ChannelDuplexHandler {
func userInboundEventTriggered(context: ChannelHandlerContext, event: Any) {
switch event {
case let event as NIOHTTP2StreamCreatedEvent:
// The connection isn't idle if a stream is open.
self.maxIdleTimer?.cancel()
self.state.streamOpened(event.streamID)
self.streamCreated(event.streamID, channel: context.channel)

case let event as StreamClosedEvent:
switch self.state.streamClosed(event.streamID) {
case .startIdleTimer:
self.maxIdleTimer?.schedule(on: context.eventLoop) {
self.initiateGracefulShutdown(context: context)
}

case .close:
context.close(mode: .all, promise: nil)

case .none:
()
}
self.streamClosed(event.streamID, channel: context.channel)

default:
()
Expand Down Expand Up @@ -335,6 +330,31 @@ final class ServerConnectionManagementHandler: ChannelDuplexHandler {
}
}

extension ServerConnectionManagementHandler: NIOHTTP2StreamDelegate {
func streamCreated(_ id: HTTP2StreamID, channel: any Channel) {
// The connection isn't idle if a stream is open.
self.maxIdleTimer?.cancel()
self.state.streamOpened(id)
}

func streamClosed(_ id: HTTP2StreamID, channel: any Channel) {
guard let context = self.context else { return }

switch self.state.streamClosed(id) {
case .startIdleTimer:
self.maxIdleTimer?.schedule(on: context.eventLoop) {
self.initiateGracefulShutdown(context: context)
}

case .close:
context.close(mode: .all, promise: nil)

case .none:
()
}
}
}

extension ServerConnectionManagementHandler {
private func maybeFlush(context: ChannelHandlerContext) {
if self.inReadLoop {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -226,6 +226,7 @@ final class ClientConnectionHandlerTests: XCTestCase {
extension ClientConnectionHandlerTests {
struct Connection {
let channel: EmbeddedChannel
let streamDelegate: any NIOHTTP2StreamDelegate
var loop: EmbeddedEventLoop {
self.channel.embeddedEventLoop
}
Expand All @@ -245,6 +246,7 @@ extension ClientConnectionHandlerTests {
keepaliveWithoutCalls: allowKeepaliveWithoutCalls
)

self.streamDelegate = handler
self.channel = EmbeddedChannel(handler: handler, loop: loop)
}

Expand All @@ -253,17 +255,11 @@ extension ClientConnectionHandlerTests {
}

func streamOpened(_ id: HTTP2StreamID) {
let event = NIOHTTP2StreamCreatedEvent(
streamID: id,
localInitialWindowSize: nil,
remoteInitialWindowSize: nil
)
self.channel.pipeline.fireUserInboundEventTriggered(event)
self.streamDelegate.streamCreated(id, channel: self.channel)
}

func streamClosed(_ id: HTTP2StreamID) {
let event = StreamClosedEvent(streamID: id, reason: nil)
self.channel.pipeline.fireUserInboundEventTriggered(event)
self.streamDelegate.streamClosed(id, channel: self.channel)
}

func goAway(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -341,6 +341,7 @@ extension ServerConnectionManagementHandlerTests {
extension ServerConnectionManagementHandlerTests {
struct Connection {
let channel: EmbeddedChannel
let streamDelegate: any NIOHTTP2StreamDelegate
let syncView: ServerConnectionManagementHandler.SyncView

var loop: EmbeddedEventLoop {
Expand Down Expand Up @@ -378,6 +379,7 @@ extension ServerConnectionManagementHandlerTests {
clock: self.clock
)

self.streamDelegate = handler
self.syncView = handler.syncView
self.channel = EmbeddedChannel(handler: handler, loop: loop)
}
Expand All @@ -398,17 +400,11 @@ extension ServerConnectionManagementHandlerTests {
}

func streamOpened(_ id: HTTP2StreamID) {
let event = NIOHTTP2StreamCreatedEvent(
streamID: id,
localInitialWindowSize: nil,
remoteInitialWindowSize: nil
)
self.channel.pipeline.fireUserInboundEventTriggered(event)
self.streamDelegate.streamCreated(id, channel: self.channel)
}

func streamClosed(_ id: HTTP2StreamID) {
let event = StreamClosedEvent(streamID: id, reason: nil)
self.channel.pipeline.fireUserInboundEventTriggered(event)
self.streamDelegate.streamClosed(id, channel: self.channel)
}

func ping(data: HTTP2PingData, ack: Bool) throws {
Expand Down