Skip to content

Commit

Permalink
Merge pull request #48 from outfoxx/fix/event-source-cleanup
Browse files Browse the repository at this point in the history
Cleanup EventSource handling of cancellation errors and logging
  • Loading branch information
kdubb committed Sep 11, 2023
2 parents dfee578 + 28a4c6d commit 7f7df30
Show file tree
Hide file tree
Showing 5 changed files with 130 additions and 53 deletions.
60 changes: 40 additions & 20 deletions Sources/Sunday/EventSource.swift
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ public class EventSource {
public enum Error: Swift.Error {
case invalidState
case eventTimeout
case requestStreamEmpty
}

/// Possible states of the `EventSource`
Expand Down Expand Up @@ -118,7 +119,7 @@ public class EventSource {
///
public private(set) var retryTime = retryTimeDefault

private let dataEventStreamFactory: (HTTP.Headers) async throws -> NetworkSession.DataEventStream
private let dataEventStreamFactory: (HTTP.Headers) async throws -> NetworkSession.DataEventStream?
private var dataEventStreamTask: Task<Void, Swift.Error>?
private var receivedString: String?

Expand Down Expand Up @@ -165,14 +166,14 @@ public class EventSource {
queue: DispatchQueue = .global(qos: .background),
eventTimeoutInterval: DispatchTimeInterval? = eventTimeoutIntervalDefault,
eventTimeoutCheckInterval: DispatchTimeInterval = eventTimeoutCheckIntervalDefault,
dataEventStreamFactory: @escaping (HTTP.Headers) async throws -> NetworkSession.DataEventStream
dataEventStreamFactory: @escaping (HTTP.Headers) async throws -> NetworkSession.DataEventStream?
) {
self.queue = DispatchQueue(label: "io.outfoxx.sunday.EventSource", attributes: [], target: queue)
readyStateValue = StateValue(.closed, queue: queue)
self.readyStateValue = StateValue(.closed, queue: queue)
self.dataEventStreamFactory = dataEventStreamFactory
self.eventTimeoutInterval = eventTimeoutInterval
self.eventTimeoutCheckInterval = eventTimeoutCheckInterval
receivedString = nil
self.receivedString = nil
}

deinit {
Expand Down Expand Up @@ -261,7 +262,7 @@ public class EventSource {
/// List of unique event types that handlers are
/// registered for.
///
public func events() -> [String] {
public func registeredListenerTypes() -> [String] {
queue.sync {
return Array(eventListeners.keys)
}
Expand Down Expand Up @@ -289,7 +290,9 @@ public class EventSource {
private func internalConnect() {

guard readyStateValue.isNotClosed else {
#if EVENT_SOURCE_EXTRA_LOGGING
logger.debug("Skipping connect due to close")
#endif
return
}

Expand All @@ -312,7 +315,12 @@ public class EventSource {
// Create a data stream and
do {

let dataStream = try await dataEventStreamFactory(headers)
guard let dataStream = try await dataEventStreamFactory(headers) else {
logger.debug("Stream factory empty")
fireErrorEvent(error: Error.requestStreamEmpty)
close()
return
}

for try await event in dataStream {

Expand Down Expand Up @@ -413,7 +421,9 @@ public class EventSource {
return
}

#if EVENT_SOURCE_EXTRA_LOGGING
logger.debug("Checking Event Timeout")
#endif

let eventTimeoutDeadline = lastEventReceivedTime + eventTimeoutInterval

Expand All @@ -424,7 +434,7 @@ public class EventSource {

logger.debug("Event Timeout Deadline Expired")

fireErrorEvent(error: .eventTimeout)
fireErrorEvent(error: Error.eventTimeout)

scheduleReconnect()
}
Expand All @@ -437,9 +447,9 @@ public class EventSource {
private func receivedHeaders(_ response: HTTPURLResponse) throws {

guard readyStateValue.ifNotClosed(updateTo: .open) else {
logger.error("invalid state for receiving headers: state=\(self.readyState.rawValue, privacy: .public)")
logger.error("Invalid state for receiving headers: state=\(self.readyState.rawValue, privacy: .public)")

fireErrorEvent(error: .invalidState)
fireErrorEvent(error: Error.invalidState)

scheduleReconnect()
return
Expand All @@ -464,9 +474,9 @@ public class EventSource {
private func receivedData(_ data: Data) throws {

guard readyState == .open else {
logger.error("invalid state for receiving data: state=\(self.readyState.rawValue, privacy: .public)")
logger.error("Invalid state for receiving data: state=\(self.readyState.rawValue, privacy: .public)")

fireErrorEvent(error: .invalidState)
fireErrorEvent(error: Error.invalidState)

scheduleReconnect()
return
Expand All @@ -483,6 +493,13 @@ public class EventSource {
return
}

// Quietly close dure to Task or URLTask cancellation
if isCancellationError(error: error) {
fireErrorEvent(error: error)
close()
return
}

logger.debug("Received Error: \(error.localizedDescription, privacy: .public)")

fireErrorEvent(error: error)
Expand All @@ -503,6 +520,13 @@ public class EventSource {
scheduleReconnect()
}

private func isCancellationError(error: Swift.Error) -> Bool {
switch error {
case let urlError as URLError where urlError.code == .cancelled: return true
case is CancellationError: return true
default: return false
}
}


// MARK: Reconnection
Expand Down Expand Up @@ -591,13 +615,13 @@ public class EventSource {
if let retry = info.retry {

if let retryTime = Int(retry.trimmingCharacters(in: .whitespaces), radix: 10) {
logger.debug("update retry timeout: retryTime=\(retryTime)ms")
logger.debug("Update retry timeout: retryTime=\(retryTime)ms")

self.retryTime = .milliseconds(retryTime)

}
else {
logger.debug("ignoring invalid retry timeout message: retry=\(retry, privacy: .public)")
logger.debug("Ignoring invalid retry timeout message: retry=\(retry, privacy: .public)")
}

}
Expand All @@ -616,14 +640,14 @@ public class EventSource {
lastEventId = eventId
}
else {
logger.debug("event id contains null, unable to use for last-event-id")
logger.debug("Event id contains null, unable to use for last-event-id")
}
}

if let onMessageCallback = onMessageCallback {

logger.debug(
"dispatch onMessage: event=\(info.event ?? "", privacy: .public), id=\(info.id ?? "", privacy: .public)"
"Dispatch onMessage: event=\(info.event ?? "", privacy: .public), id=\(info.id ?? "", privacy: .public)"
)

queue.async {
Expand All @@ -639,7 +663,7 @@ public class EventSource {
for eventHandler in eventHandlers {

logger.debug(
"dispatch listener: event=\(info.event ?? "", privacy: .public), id=\(info.id ?? "", privacy: .public)"
"Dispatch listener: event=\(info.event ?? "", privacy: .public), id=\(info.id ?? "", privacy: .public)"
)

eventHandler.value(event, info.id, info.data)
Expand All @@ -649,10 +673,6 @@ public class EventSource {

}

func fireErrorEvent(error: Error) {
fireErrorEvent(error: error as Swift.Error)
}

func fireErrorEvent(error: Swift.Error) {

if let onErrorCallback = onErrorCallback {
Expand Down
52 changes: 29 additions & 23 deletions Sources/Sunday/NetworkRequestFactory.swift
Original file line number Diff line number Diff line change
Expand Up @@ -433,22 +433,25 @@ public class NetworkRequestFactory: RequestFactory {
body: B?, contentTypes: [MediaType]? = nil, acceptTypes: [MediaType]? = nil, headers: Parameters? = nil
) -> EventSource where B: Encodable {

eventSource(from: { try await self.request(
method: method,
pathTemplate: pathTemplate,
pathParameters: pathParameters,
queryParameters: queryParameters,
body: body,
contentTypes: contentTypes,
acceptTypes: acceptTypes,
headers: headers
)})
eventSource(from: {
if self.session.isClosed { return nil }
return try await self.request(
method: method,
pathTemplate: pathTemplate,
pathParameters: pathParameters,
queryParameters: queryParameters,
body: body,
contentTypes: contentTypes,
acceptTypes: acceptTypes,
headers: headers
)
})
}

public func eventSource(from requestFactory: @escaping () async throws -> URLRequest) -> EventSource {
public func eventSource(from requestFactory: @escaping () async throws -> URLRequest?) -> EventSource {

return EventSource(queue: requestQueue) { headers in
let request = try await requestFactory()
guard let request = try await requestFactory() else { return nil }
return try self.eventSession.dataEventStream(for: request.adding(httpHeaders: headers))
}
}
Expand All @@ -461,22 +464,25 @@ public class NetworkRequestFactory: RequestFactory {

eventStream(
decoder: decoder,
from: { try await self.request(
method: method,
pathTemplate: pathTemplate,
pathParameters: pathParameters,
queryParameters: queryParameters,
body: body,
contentTypes: contentTypes,
acceptTypes: acceptTypes,
headers: headers
)}
from: {
if self.session.isClosed { return nil }
return try await self.request(
method: method,
pathTemplate: pathTemplate,
pathParameters: pathParameters,
queryParameters: queryParameters,
body: body,
contentTypes: contentTypes,
acceptTypes: acceptTypes,
headers: headers
)
}
)
}

public func eventStream<D>(
decoder: @escaping (TextMediaTypeDecoder, String?, String?, String, Logger) throws -> D?,
from requestFactory: @escaping () async throws -> URLRequest
from requestFactory: @escaping () async throws -> URLRequest?
) -> AsyncStream<D> {

guard let jsonDecoder = try? mediaTypeDecoders.find(for: .json) as? TextMediaTypeDecoder else {
Expand Down
4 changes: 2 additions & 2 deletions Sources/Sunday/Patching.swift
Original file line number Diff line number Diff line change
Expand Up @@ -166,7 +166,7 @@ public enum PatchOp<Value: Codable>: AnyPatchOp, Codable {

extension UpdateOp {

public static func merge<Value: Codable>(_ value: Value) -> UpdateOp<Value> { .set(value) }
public static func merge<NewValue: Codable>(_ value: NewValue) -> UpdateOp<NewValue> { .set(value) }

}

Expand All @@ -187,7 +187,7 @@ extension UpdateOp: CustomStringConvertible {

extension PatchOp {

public static func merge<Value: Codable>(_ value: Value) -> PatchOp<Value> { .set(value) }
public static func merge<NewValue: Codable>(_ value: NewValue) -> PatchOp<NewValue> { .set(value) }

}

Expand Down
12 changes: 6 additions & 6 deletions Sources/SundayServer/RoutableParam.swift
Original file line number Diff line number Diff line change
Expand Up @@ -93,19 +93,19 @@ public struct Param<T> {
return body(Data.self)
}

public static func body<T>(_ type: T.Type) -> Param<T> where T: Decodable {
return Param<T>(name: bodyParameterName) { _, req, res in
public static func body<B>(_ type: B.Type) -> Param<B> where B: Decodable {
return Param<B>(name: bodyParameterName) { _, req, res in
guard let decoder = res.properties[bodyDecoderPropertyName] as? MediaTypeDecoder else { return nil }
guard let body = req.body else { return nil }
return try decoder.decode(type, from: body)
}
}

public static func body<T>(ref: T.Type) -> Param<T> {
public static func body(ref: T.Type) -> Param<T> {
return body(ref: ref, using: Ref.self)
}

public static func body<T, TKP: TypeKeyProvider, VKP: ValueKeyProvider, TI: TypeIndex>(
public static func body<TKP: TypeKeyProvider, VKP: ValueKeyProvider, TI: TypeIndex>(
ref: T.Type,
using refType: CustomRef<TKP, VKP, TI>.Type
) -> Param<T> {
Expand All @@ -116,11 +116,11 @@ public struct Param<T> {
}
}

public static func body<T>(embebbedRef: T.Type) -> Param<T> {
public static func body(embebbedRef: T.Type) -> Param<T> {
return body(embeddedRef: embebbedRef, using: EmbeddedRef.self)
}

public static func body<T, TKP: TypeKeyProvider, TI: TypeIndex>(
public static func body<TKP: TypeKeyProvider, TI: TypeIndex>(
embeddedRef: T.Type,
using refType: CustomEmbeddedRef<TKP, TI>.Type
) -> Param<T> {
Expand Down
55 changes: 53 additions & 2 deletions Tests/SundayTests/EventSourceTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -234,10 +234,10 @@ class EventSourceTests: XCTestCase {
}

let handlerId = eventSource.addEventListener(for: "test") { _, _, _ in }
XCTAssertTrue(!eventSource.events().isEmpty)
XCTAssertTrue(!eventSource.registeredListenerTypes().isEmpty)

eventSource.removeEventListener(handlerId: handlerId, for: "test")
XCTAssertTrue(eventSource.events().isEmpty)
XCTAssertTrue(eventSource.registeredListenerTypes().isEmpty)
}

func testValidRetryTimeoutUpdate() throws {
Expand Down Expand Up @@ -555,4 +555,55 @@ class EventSourceTests: XCTestCase {
waitForExpectations()
}

func testCloseWhenRequestFactoryReturnsNil() throws {

let server = try! RoutingHTTPServer(port: .any, localOnly: true) {
Path("/simple") {
GET { _, res in
res.start(status: .ok, headers: [
HTTP.StdHeaders.contentType: [MediaType.eventStream.value],
HTTP.StdHeaders.transferEncoding: ["chunked"],
])
res.send(chunk: "event: test\n".data(using: .utf8) ?? Data())
res.send(chunk: "id: 123\n".data(using: .utf8) ?? Data())
res.send(chunk: "data: some test data\n\n".data(using: .utf8) ?? Data())
res.finish(trailers: [:])
}
}
}
guard let serverURL = server.startLocal(timeout: 5.0) else {
XCTFail("could not start local server")
return
}
defer { server.stop() }

let session = NetworkSession(configuration: .default)
defer { session.close(cancelOutstandingTasks: true) }

let url = try XCTUnwrap(URL(string: "/simple", relativeTo: serverURL))
var requestsReturned = 0
let eventSource =
EventSource {
if requestsReturned > 1 {
return nil
}
defer { requestsReturned += 1 }
let request = URLRequest(url: url).adding(httpHeaders: $0)
return try session.dataEventStream(for: request)
}

let closeErrorX = expectation(description: "EventSource Close Error")

eventSource.onError = { error in
guard let error = error as? EventSource.Error, case .requestStreamEmpty = error else {
return
}
closeErrorX.fulfill()
}

eventSource.connect()

waitForExpectations()
}

}

0 comments on commit 7f7df30

Please sign in to comment.