From a9a73a159d35c131f04ba9cff0d8996319067012 Mon Sep 17 00:00:00 2001 From: Simon Brockmann Date: Tue, 7 May 2019 16:31:03 +0200 Subject: [PATCH] Add cancellation token --- BoltsSwift.xcodeproj/project.pbxproj | 39 +++ Sources/BoltsSwift/CancellationToken.swift | 132 ++++++++++ .../CancellationTokenRegistration.swift | 52 ++++ .../BoltsSwift/CancellationTokenSource.swift | 36 +++ Sources/BoltsSwift/Errors.swift | 8 + Sources/BoltsSwift/Executor.swift | 4 +- Sources/BoltsSwift/Task+ContinueWith.swift | 137 ++++++++-- Sources/BoltsSwift/Task+Delay.swift | 26 ++ Sources/BoltsSwift/Task.swift | 8 +- Tests/CancellationTests.swift | 158 ++++++++++++ Tests/TaskTests.swift | 237 ++++++++++++------ 11 files changed, 730 insertions(+), 107 deletions(-) create mode 100644 Sources/BoltsSwift/CancellationToken.swift create mode 100644 Sources/BoltsSwift/CancellationTokenRegistration.swift create mode 100644 Sources/BoltsSwift/CancellationTokenSource.swift create mode 100644 Tests/CancellationTests.swift diff --git a/BoltsSwift.xcodeproj/project.pbxproj b/BoltsSwift.xcodeproj/project.pbxproj index 4ea1df8..20b73bd 100644 --- a/BoltsSwift.xcodeproj/project.pbxproj +++ b/BoltsSwift.xcodeproj/project.pbxproj @@ -38,6 +38,21 @@ 81D3007F1C93AF9F00E1A1ED /* TaskTests.swift in Sources */ = {isa = PBXBuildFile; fileRef = 81D300781C93AF9F00E1A1ED /* TaskTests.swift */; }; 81D300801C93AF9F00E1A1ED /* TaskTests.swift in Sources */ = {isa = PBXBuildFile; fileRef = 81D300781C93AF9F00E1A1ED /* TaskTests.swift */; }; 87FEF3721A9085FA00C60678 /* BoltsSwift.framework in Frameworks */ = {isa = PBXBuildFile; fileRef = 87FEF3661A9085FA00C60678 /* BoltsSwift.framework */; }; + EA34D7492281A8D60024A0C3 /* CancellationTests.swift in Sources */ = {isa = PBXBuildFile; fileRef = EA34D7482281A8D60024A0C3 /* CancellationTests.swift */; }; + EA34D74A2281A8D60024A0C3 /* CancellationTests.swift in Sources */ = {isa = PBXBuildFile; fileRef = EA34D7482281A8D60024A0C3 /* CancellationTests.swift */; }; + EA34D74B2281A8D60024A0C3 /* CancellationTests.swift in Sources */ = {isa = PBXBuildFile; fileRef = EA34D7482281A8D60024A0C3 /* CancellationTests.swift */; }; + EA6E8600227C7F10009A18B7 /* CancellationToken.swift in Sources */ = {isa = PBXBuildFile; fileRef = EA6E85FF227C7F10009A18B7 /* CancellationToken.swift */; }; + EA6E8601227C7F10009A18B7 /* CancellationToken.swift in Sources */ = {isa = PBXBuildFile; fileRef = EA6E85FF227C7F10009A18B7 /* CancellationToken.swift */; }; + EA6E8602227C7F10009A18B7 /* CancellationToken.swift in Sources */ = {isa = PBXBuildFile; fileRef = EA6E85FF227C7F10009A18B7 /* CancellationToken.swift */; }; + EA6E8603227C7F10009A18B7 /* CancellationToken.swift in Sources */ = {isa = PBXBuildFile; fileRef = EA6E85FF227C7F10009A18B7 /* CancellationToken.swift */; }; + EA6E8605227C7F26009A18B7 /* CancellationTokenRegistration.swift in Sources */ = {isa = PBXBuildFile; fileRef = EA6E8604227C7F26009A18B7 /* CancellationTokenRegistration.swift */; }; + EA6E8606227C7F26009A18B7 /* CancellationTokenRegistration.swift in Sources */ = {isa = PBXBuildFile; fileRef = EA6E8604227C7F26009A18B7 /* CancellationTokenRegistration.swift */; }; + EA6E8607227C7F26009A18B7 /* CancellationTokenRegistration.swift in Sources */ = {isa = PBXBuildFile; fileRef = EA6E8604227C7F26009A18B7 /* CancellationTokenRegistration.swift */; }; + EA6E8608227C7F26009A18B7 /* CancellationTokenRegistration.swift in Sources */ = {isa = PBXBuildFile; fileRef = EA6E8604227C7F26009A18B7 /* CancellationTokenRegistration.swift */; }; + EA6E860A227C7F32009A18B7 /* CancellationTokenSource.swift in Sources */ = {isa = PBXBuildFile; fileRef = EA6E8609227C7F32009A18B7 /* CancellationTokenSource.swift */; }; + EA6E860B227C7F32009A18B7 /* CancellationTokenSource.swift in Sources */ = {isa = PBXBuildFile; fileRef = EA6E8609227C7F32009A18B7 /* CancellationTokenSource.swift */; }; + EA6E860C227C7F32009A18B7 /* CancellationTokenSource.swift in Sources */ = {isa = PBXBuildFile; fileRef = EA6E8609227C7F32009A18B7 /* CancellationTokenSource.swift */; }; + EA6E860D227C7F32009A18B7 /* CancellationTokenSource.swift in Sources */ = {isa = PBXBuildFile; fileRef = EA6E8609227C7F32009A18B7 /* CancellationTokenSource.swift */; }; F569C0C11CFF6A07000749B6 /* Task+ContinueWith.swift in Sources */ = {isa = PBXBuildFile; fileRef = F569C0C01CFF6A07000749B6 /* Task+ContinueWith.swift */; }; F569C0C21CFF6A07000749B6 /* Task+ContinueWith.swift in Sources */ = {isa = PBXBuildFile; fileRef = F569C0C01CFF6A07000749B6 /* Task+ContinueWith.swift */; }; F569C0C31CFF6A07000749B6 /* Task+ContinueWith.swift in Sources */ = {isa = PBXBuildFile; fileRef = F569C0C01CFF6A07000749B6 /* Task+ContinueWith.swift */; }; @@ -118,6 +133,10 @@ 81D300781C93AF9F00E1A1ED /* TaskTests.swift */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.swift; path = TaskTests.swift; sourceTree = ""; }; 87FEF3661A9085FA00C60678 /* BoltsSwift.framework */ = {isa = PBXFileReference; explicitFileType = wrapper.framework; includeInIndex = 0; path = BoltsSwift.framework; sourceTree = BUILT_PRODUCTS_DIR; }; 87FEF3711A9085FA00C60678 /* BoltsSwiftTests.xctest */ = {isa = PBXFileReference; explicitFileType = wrapper.cfbundle; includeInIndex = 0; path = BoltsSwiftTests.xctest; sourceTree = BUILT_PRODUCTS_DIR; }; + EA34D7482281A8D60024A0C3 /* CancellationTests.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = CancellationTests.swift; sourceTree = ""; }; + EA6E85FF227C7F10009A18B7 /* CancellationToken.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = CancellationToken.swift; sourceTree = ""; }; + EA6E8604227C7F26009A18B7 /* CancellationTokenRegistration.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = CancellationTokenRegistration.swift; sourceTree = ""; }; + EA6E8609227C7F32009A18B7 /* CancellationTokenSource.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = CancellationTokenSource.swift; sourceTree = ""; }; F569C0C01CFF6A07000749B6 /* Task+ContinueWith.swift */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.swift; path = "Task+ContinueWith.swift"; sourceTree = ""; }; F569C0CB1CFF6AEE000749B6 /* Task+Delay.swift */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.swift; path = "Task+Delay.swift"; sourceTree = ""; }; F569C0D61CFF6B18000749B6 /* Task+WhenAll.swift */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.swift; path = "Task+WhenAll.swift"; sourceTree = ""; }; @@ -259,6 +278,9 @@ F569C0E01CFF6B1F000749B6 /* Task+WhenAny.swift */, 81D300641C93AF7300E1A1ED /* Executor.swift */, 81D300631C93AF7300E1A1ED /* Errors.swift */, + EA6E85FF227C7F10009A18B7 /* CancellationToken.swift */, + EA6E8604227C7F26009A18B7 /* CancellationTokenRegistration.swift */, + EA6E8609227C7F32009A18B7 /* CancellationTokenSource.swift */, ); path = BoltsSwift; sourceTree = ""; @@ -274,6 +296,7 @@ 81D300741C93AF9F00E1A1ED /* Tests */ = { isa = PBXGroup; children = ( + EA34D7482281A8D60024A0C3 /* CancellationTests.swift */, 81D300781C93AF9F00E1A1ED /* TaskTests.swift */, 81D300771C93AF9F00E1A1ED /* TaskCompletionSourceTests.swift */, 81D300751C93AF9F00E1A1ED /* ExecutorTests.swift */, @@ -512,6 +535,7 @@ developmentRegion = English; hasScannedForEncodings = 0; knownRegions = ( + English, en, ); mainGroup = 87FEF35C1A9085FA00C60678; @@ -592,9 +616,12 @@ 065894EF1C9A9391000FDDA6 /* Task.swift in Sources */, F569C0CF1CFF6AEE000749B6 /* Task+Delay.swift in Sources */, F569C0C41CFF6A07000749B6 /* Task+ContinueWith.swift in Sources */, + EA6E8603227C7F10009A18B7 /* CancellationToken.swift in Sources */, + EA6E860D227C7F32009A18B7 /* CancellationTokenSource.swift in Sources */, F569C0E41CFF6B1F000749B6 /* Task+WhenAny.swift in Sources */, 065894F11C9A9391000FDDA6 /* Executor.swift in Sources */, 065894F01C9A9391000FDDA6 /* TaskCompletionSource.swift in Sources */, + EA6E8608227C7F26009A18B7 /* CancellationTokenRegistration.swift in Sources */, ); runOnlyForDeploymentPostprocessing = 0; }; @@ -607,9 +634,12 @@ 065894F61C9A93B7000FDDA6 /* Task.swift in Sources */, F569C0CE1CFF6AEE000749B6 /* Task+Delay.swift in Sources */, F569C0C31CFF6A07000749B6 /* Task+ContinueWith.swift in Sources */, + EA6E8602227C7F10009A18B7 /* CancellationToken.swift in Sources */, + EA6E860C227C7F32009A18B7 /* CancellationTokenSource.swift in Sources */, F569C0E31CFF6B1F000749B6 /* Task+WhenAny.swift in Sources */, 065894F71C9A93B7000FDDA6 /* Errors.swift in Sources */, 065894F81C9A93B7000FDDA6 /* TaskCompletionSource.swift in Sources */, + EA6E8607227C7F26009A18B7 /* CancellationTokenRegistration.swift in Sources */, ); runOnlyForDeploymentPostprocessing = 0; }; @@ -619,6 +649,7 @@ files = ( 065895121C9A947B000FDDA6 /* ExecutorTests.swift in Sources */, 065895131C9A947B000FDDA6 /* TaskTests.swift in Sources */, + EA34D74B2281A8D60024A0C3 /* CancellationTests.swift in Sources */, 065895141C9A947B000FDDA6 /* TaskCompletionSourceTests.swift in Sources */, 810AB3221C9B1AC3005B6184 /* XCTestCase+TestName.swift in Sources */, ); @@ -633,9 +664,12 @@ 81D3006D1C93AF7300E1A1ED /* Task.swift in Sources */, F569C0CD1CFF6AEE000749B6 /* Task+Delay.swift in Sources */, F569C0C21CFF6A07000749B6 /* Task+ContinueWith.swift in Sources */, + EA6E8601227C7F10009A18B7 /* CancellationToken.swift in Sources */, + EA6E860B227C7F32009A18B7 /* CancellationTokenSource.swift in Sources */, F569C0E21CFF6B1F000749B6 /* Task+WhenAny.swift in Sources */, 81D300691C93AF7300E1A1ED /* Errors.swift in Sources */, 81D3006F1C93AF7300E1A1ED /* TaskCompletionSource.swift in Sources */, + EA6E8606227C7F26009A18B7 /* CancellationTokenRegistration.swift in Sources */, ); runOnlyForDeploymentPostprocessing = 0; }; @@ -645,6 +679,7 @@ files = ( 81D3007A1C93AF9F00E1A1ED /* ExecutorTests.swift in Sources */, 81D300801C93AF9F00E1A1ED /* TaskTests.swift in Sources */, + EA34D74A2281A8D60024A0C3 /* CancellationTests.swift in Sources */, 81D3007E1C93AF9F00E1A1ED /* TaskCompletionSourceTests.swift in Sources */, 810AB3211C9B1AC3005B6184 /* XCTestCase+TestName.swift in Sources */, ); @@ -659,9 +694,12 @@ 81D3006C1C93AF7300E1A1ED /* Task.swift in Sources */, F569C0CC1CFF6AEE000749B6 /* Task+Delay.swift in Sources */, F569C0C11CFF6A07000749B6 /* Task+ContinueWith.swift in Sources */, + EA6E8600227C7F10009A18B7 /* CancellationToken.swift in Sources */, + EA6E860A227C7F32009A18B7 /* CancellationTokenSource.swift in Sources */, F569C0E11CFF6B1F000749B6 /* Task+WhenAny.swift in Sources */, 81D300681C93AF7300E1A1ED /* Errors.swift in Sources */, 81D3006E1C93AF7300E1A1ED /* TaskCompletionSource.swift in Sources */, + EA6E8605227C7F26009A18B7 /* CancellationTokenRegistration.swift in Sources */, ); runOnlyForDeploymentPostprocessing = 0; }; @@ -671,6 +709,7 @@ files = ( 81D300791C93AF9F00E1A1ED /* ExecutorTests.swift in Sources */, 81D3007F1C93AF9F00E1A1ED /* TaskTests.swift in Sources */, + EA34D7492281A8D60024A0C3 /* CancellationTests.swift in Sources */, 81D3007D1C93AF9F00E1A1ED /* TaskCompletionSourceTests.swift in Sources */, 810AB3201C9B1AC3005B6184 /* XCTestCase+TestName.swift in Sources */, ); diff --git a/Sources/BoltsSwift/CancellationToken.swift b/Sources/BoltsSwift/CancellationToken.swift new file mode 100644 index 0000000..73f1b05 --- /dev/null +++ b/Sources/BoltsSwift/CancellationToken.swift @@ -0,0 +1,132 @@ +// +// CancellationToken.swift +// BoltsSwift +// +// Copyright © 2019 Facebook. All rights reserved. +// + +import Foundation + +public typealias CancellationObserver = () -> Void + +public final class CancellationToken { + public var cancellationRequested: Bool { + get { + return synchronizationQueue.sync(flags: .barrier) { () -> Bool in + return _cancellationRequested + + } + } + } + + private let synchronizationQueue = DispatchQueue(label: "com.bolts.cancellationToken", attributes: DispatchQueue.Attributes.concurrent) + private var _cancellationRequested: Bool + private var _regestrations = [CancellationTokenRegistration]() + private var _disposed: Bool + private var _cancelDelayedWorkItem: DispatchWorkItem? + + public init() { + _cancellationRequested = false + _disposed = false + } + + @discardableResult + public func registerCancellationObserver(observer: @escaping CancellationObserver) -> CancellationTokenRegistration? { + return synchronizationQueue.sync(flags: .barrier) { () -> CancellationTokenRegistration? in + if _disposed { + return nil + } + let registration = CancellationTokenRegistration.registrationWithToken(token: self, observer: observer) + _regestrations.append(registration) + return registration + } + } + + internal func unrigsterRegistration(registration: CancellationTokenRegistration) throws { + try synchronizationQueue.sync(flags: .barrier) { () -> Void in + try throwIfDisposed() + if let index = _regestrations.firstIndex(where: { $0 === registration }) { + _regestrations.remove(at: index) + } + } + } + + internal func cancel() throws { + var registrations: [CancellationTokenRegistration]? + try synchronizationQueue.sync(flags: .barrier) { () -> Void in + try throwIfDisposed() + if _cancellationRequested { + return + } + if let cancelWorkItem = _cancelDelayedWorkItem { + cancelWorkItem.cancel() + _cancelDelayedWorkItem = nil + } + _cancellationRequested = true + registrations = [CancellationTokenRegistration](_regestrations) + } + if registrations == nil { + return + } + try notifyCancellation(registrations: registrations!) + } + + private func notifyCancellation(registrations: [CancellationTokenRegistration]) throws { + for registration in registrations { + try registration.notifyDelegate() + } + } + + internal func cancelAfterInterval(interval: TimeInterval) throws { + try throwIfDisposed() + + if interval < 0 { + throw IntervalError() + } + + if interval == 0 { + try cancel() + return + } + + try synchronizationQueue.sync(flags: .barrier) { () -> Void in + try throwIfDisposed() + if let cancelWorkItem = _cancelDelayedWorkItem { + cancelWorkItem.cancel() + _cancelDelayedWorkItem = nil + } + if _cancellationRequested { + return + } + _cancelDelayedWorkItem = DispatchWorkItem { [weak self] in + try? self?.cancel() + } + DispatchQueue.global().asyncAfter(deadline: .now() + interval, execute: _cancelDelayedWorkItem!) + } + } + + internal func dispose() throws { + var registrations: [CancellationTokenRegistration]? + synchronizationQueue.sync(flags: .barrier) { () -> Void in + if _disposed { + return + } + registrations = [CancellationTokenRegistration](_regestrations) + _regestrations.removeAll() + } + if registrations != nil { + try registrations!.forEach { + try $0.dispose() + } + } + synchronizationQueue.sync(flags: .barrier, execute: { () -> Void in + _disposed = true + }) + } + + private func throwIfDisposed() throws { + if _disposed { + throw DisposedError() + } + } +} diff --git a/Sources/BoltsSwift/CancellationTokenRegistration.swift b/Sources/BoltsSwift/CancellationTokenRegistration.swift new file mode 100644 index 0000000..2d72467 --- /dev/null +++ b/Sources/BoltsSwift/CancellationTokenRegistration.swift @@ -0,0 +1,52 @@ +// +// CancellationTokenRegistration.swift +// BoltsSwift +// +// Copyright © 2019 Facebook. All rights reserved. +// + +import Foundation + +public final class CancellationTokenRegistration { + private var _disposed: Bool + private let _synchronizationQueue = DispatchQueue(label: "com.bolts.cancellationTokenRegistration", attributes: DispatchQueue.Attributes.concurrent) + private var _observer: CancellationObserver? + private weak var _token: CancellationToken? + + private init(token: CancellationToken, observer: @escaping CancellationObserver) { + _disposed = false + _observer = observer + _token = token + } + + public class func registrationWithToken(token: CancellationToken, observer: @escaping CancellationObserver) -> CancellationTokenRegistration { + return CancellationTokenRegistration(token: token, observer: observer) + } + + public func dispose() throws { + _synchronizationQueue.sync(flags: .barrier) { () -> Void in + if _disposed { + return + } + _disposed = true + } + if let token = _token { + try token.unrigsterRegistration(registration: self) + _token = nil + } + _observer = nil + } + + internal func notifyDelegate() throws { + try _synchronizationQueue.sync(flags: .barrier) { () -> Void in + try throwIfDisposed() + _observer?() + } + } + + private func throwIfDisposed() throws { + if _disposed { + throw DisposedError() + } + } +} diff --git a/Sources/BoltsSwift/CancellationTokenSource.swift b/Sources/BoltsSwift/CancellationTokenSource.swift new file mode 100644 index 0000000..1954b3c --- /dev/null +++ b/Sources/BoltsSwift/CancellationTokenSource.swift @@ -0,0 +1,36 @@ +// +// CancellationTokenSource.swift +// BoltsSwift +// +// Copyright © 2019 Facebook. All rights reserved. +// + +import Foundation + +public final class CancellationTokenSource { + public private(set) var token: CancellationToken + + public init() { + token = CancellationToken() + } + + public class func cancellationTokenSource() -> CancellationTokenSource { + return CancellationTokenSource() + } + + public func isCancellationRequested() -> Bool { + return token.cancellationRequested + } + + public func cancel() throws { + try token.cancel() + } + + public func cancelAfterInterval(interval: TimeInterval) throws { + try token.cancelAfterInterval(interval: interval) + } + + public func dispose() throws { + try token.dispose() + } +} diff --git a/Sources/BoltsSwift/Errors.swift b/Sources/BoltsSwift/Errors.swift index ec3db5b..d9434bc 100644 --- a/Sources/BoltsSwift/Errors.swift +++ b/Sources/BoltsSwift/Errors.swift @@ -32,3 +32,11 @@ public struct CancelledError: Error { */ public init() { } } + +public struct DisposedError: Error { + public init() {} +} + +public struct IntervalError: Error { + public init() {} +} diff --git a/Sources/BoltsSwift/Executor.swift b/Sources/BoltsSwift/Executor.swift index 462483a..34edc43 100644 --- a/Sources/BoltsSwift/Executor.swift +++ b/Sources/BoltsSwift/Executor.swift @@ -128,9 +128,9 @@ extension Executor : CustomStringConvertible, CustomDebugStringConvertible { case .operationQueue(let queue): return "\(description): \(queue)" case .closure(let closure): - return "\(description): \(closure)" + return "\(description): \(String(describing: closure))" case .escapingClosure(let closure): - return "\(description): \(closure)" + return "\(description): \(String(describing: closure))" default: return description } diff --git a/Sources/BoltsSwift/Task+ContinueWith.swift b/Sources/BoltsSwift/Task+ContinueWith.swift index 8fbc2b7..f2bdc98 100644 --- a/Sources/BoltsSwift/Task+ContinueWith.swift +++ b/Sources/BoltsSwift/Task+ContinueWith.swift @@ -20,25 +20,36 @@ extension Task { - parameter executor: The executor to invoke the closure on. - parameter options: The options to run the closure with - parameter continuation: The closure to execute. + - parameter cancellationToken: The cancellationToken to cancel the task queue. - returns: The task resulting from the continuation */ fileprivate func continueWithTask(_ executor: Executor, - options: TaskContinuationOptions, - continuation: @escaping ((Task) throws -> Task) - ) -> Task { + options: TaskContinuationOptions, + cancellationToken: CancellationToken? = nil, + continuation: @escaping ((Task) throws -> Task) + ) -> Task { let taskCompletionSource = TaskCompletionSource() let wrapperContinuation = { + if (cancellationToken?.cancellationRequested ?? false) { + taskCompletionSource.cancel() + return + } switch self.state { case .success where options.contains(.RunOnSuccess): fallthrough case .error where options.contains(.RunOnError): fallthrough case .cancelled where options.contains(.RunOnCancelled): executor.execute { + let wrappedState = TaskState>.fromClosure { try continuation(self) } switch wrappedState { case .success(let nextTask): + if (cancellationToken?.cancellationRequested ?? false) { + taskCompletionSource.cancel() + return + } switch nextTask.state { case .pending: nextTask.continueWith { nextTask in @@ -76,14 +87,15 @@ extension Task { /** Enqueues a given closure to be run once this task is complete. - - parameter executor: Determines how the the closure is called. The default is to call the closure immediately. - - parameter continuation: The closure that returns the result of the task. + - parameter executor: Determines how the the closure is called. The default is to call the closure immediately. + - parameter cancellationToken: The cancellationToken to cancel the task queue. + - parameter continuation: The closure that returns the result of the task. - returns: A task that will be completed with a result from a given closure. */ @discardableResult - public func continueWith(_ executor: Executor = .default, continuation: @escaping ((Task) throws -> S)) -> Task { - return continueWithTask(executor) { task in + public func continueWith(_ executor: Executor = .default, _ cancelationToken: CancellationToken? = nil, continuation: @escaping ((Task) throws -> S)) -> Task { + return continueWithTask(executor, cancelationToken) { task in let state = TaskState.fromClosure({ try continuation(task) }) @@ -93,15 +105,47 @@ extension Task { /** Enqueues a given closure to be run once this task is complete. + + - parameter cancellationToken: The cancellationToken to cancel the task queue. + - parameter continuation: The closure that returns the result of the task. + + - returns: A task that will be completed with a result from a given closure. + */ + @discardableResult + public func continueWith(_ cancelationToken: CancellationToken, continuation: @escaping ((Task) throws -> S)) -> Task { + return continueWithTask(.default, cancelationToken) { task in + let state = TaskState.fromClosure({ + try continuation(task) + }) + return Task(state: state) + } + } - - parameter executor: Determines how the the closure is called. The default is to call the closure immediately. - - parameter continuation: The closure that returns a task to chain on. + /** + Enqueues a given closure to be run once this task is complete. + + - parameter cancellationToken: The cancellationToken to cancel the task queue. + - parameter continuation: The closure that returns a task to chain on. + + - returns: A task that will be completed when a task returned from a closure is completed. + */ + @discardableResult + public func continueWithTask(_ cancellationToken: CancellationToken, continuation: @escaping ((Task) throws -> Task)) -> Task { + return continueWithTask(.default, options: .RunAlways, cancellationToken: cancellationToken, continuation: continuation) + } + /** + Enqueues a given closure to be run once this task is complete. + + - parameter executor: Determines how the the closure is called. The default is to call the closure immediately. + - parameter cancellationToken: The cancellationToken to cancel the task queue. + - parameter continuation: The closure that returns a task to chain on. + - returns: A task that will be completed when a task returned from a closure is completed. */ @discardableResult - public func continueWithTask(_ executor: Executor = .default, continuation: @escaping ((Task) throws -> Task)) -> Task { - return continueWithTask(executor, options: .RunAlways, continuation: continuation) + public func continueWithTask(_ executor: Executor = .default, _ cancellationToken: CancellationToken? = nil, continuation: @escaping ((Task) throws -> Task)) -> Task { + return continueWithTask(executor, options: .RunAlways, cancellationToken: cancellationToken, continuation: continuation) } } @@ -113,15 +157,17 @@ extension Task { /** Enqueues a given closure to be run once this task completes with success (has intended result). - - parameter executor: Determines how the the closure is called. The default is to call the closure immediately. - - parameter continuation: The closure that returns a task to chain on. + - parameter executor: Determines how the the closure is called. The default is to call the closure immediately. + - parameter cancellationToken: The cancellationToken to cancel the task queue. + - parameter continuation: The closure that returns a task to chain on. - returns: A task that will be completed when a task returned from a closure is completed. */ @discardableResult public func continueOnSuccessWith(_ executor: Executor = .default, - continuation: @escaping ((TResult) throws -> S)) -> Task { - return continueOnSuccessWithTask(executor) { taskResult in + _ cancellationToken: CancellationToken? = nil, + continuation: @escaping ((TResult) throws -> S)) -> Task { + return continueOnSuccessWithTask(executor, cancellationToken) { taskResult in let state = TaskState.fromClosure({ try continuation(taskResult) }) @@ -131,16 +177,53 @@ extension Task { /** Enqueues a given closure to be run once this task completes with success (has intended result). + + - parameter cancellationToken: The cancellationToken to cancel the task queue. + - parameter continuation: The closure that returns a task to chain on. + + - returns: A task that will be completed when a task returned from a closure is completed. + */ + @discardableResult + public func continueOnSuccessWith(_ cancellationToken: CancellationToken, + continuation: @escaping ((TResult) throws -> S)) -> Task { + return continueOnSuccessWithTask(.default, cancellationToken) { taskResult in + let state = TaskState.fromClosure({ + try continuation(taskResult) + }) + return Task(state: state) + } + } - - parameter executor: Determines how the the closure is called. The default is to call the closure immediately. - - parameter continuation: The closure that returns a task to chain on. + /** + Enqueues a given closure to be run once this task completes with success (has intended result). + + - parameter cancellationToken: The cancellationToken to cancel the task queue. + - parameter continuation: The closure that returns a task to chain on. + + - returns: A task that will be completed when a task returned from a closure is completed. + */ + @discardableResult + public func continueOnSuccessWithTask(_ cancellationToken: CancellationToken, + continuation: @escaping ((TResult) throws -> Task)) -> Task { + return continueWithTask(.default, options: .RunOnSuccess, cancellationToken: cancellationToken) { task in + return try continuation(task.result!) + } + } + + /** + Enqueues a given closure to be run once this task completes with success (has intended result). + + - parameter executor: Determines how the the closure is called. The default is to call the closure immediately. + - parameter cancellationToken: The cancellationToken to cancel the task queue. + - parameter continuation: The closure that returns a task to chain on. - returns: A task that will be completed when a task returned from a closure is completed. */ @discardableResult public func continueOnSuccessWithTask(_ executor: Executor = .default, - continuation: @escaping ((TResult) throws -> Task)) -> Task { - return continueWithTask(executor, options: .RunOnSuccess) { task in + _ cancellationToken: CancellationToken? = nil, + continuation: @escaping ((TResult) throws -> Task)) -> Task { + return continueWithTask(executor, options: .RunOnSuccess, cancellationToken: cancellationToken) { task in return try continuation(task.result!) } } @@ -160,8 +243,8 @@ extension Task { - returns: A task that will be completed when a task returned from a closure is completed. */ @discardableResult - public func continueOnErrorWith(_ executor: Executor = .default, continuation: @escaping ((E) throws -> TResult)) -> Task { - return continueOnErrorWithTask(executor) { (error: E) in + public func continueOnErrorWith(_ executor: Executor = .default, _ cancellationToken: CancellationToken? = nil, continuation: @escaping ((E) throws -> TResult)) -> Task { + return continueOnErrorWithTask(executor, cancellationToken) { (error: E) in let state = TaskState.fromClosure({ try continuation(error) }) @@ -178,8 +261,8 @@ extension Task { - returns: A task that will be completed when a task returned from a closure is completed. */ @discardableResult - public func continueOnErrorWith(_ executor: Executor = .default, continuation: @escaping ((Error) throws -> TResult)) -> Task { - return continueOnErrorWithTask(executor) { (error: Error) in + public func continueOnErrorWith(_ executor: Executor = .default, _ cancellationToken: CancellationToken? = nil, continuation: @escaping ((Error) throws -> TResult)) -> Task { + return continueOnErrorWithTask(executor, cancellationToken) { (error: Error) in let state = TaskState.fromClosure({ try continuation(error) }) @@ -196,8 +279,8 @@ extension Task { - returns: A task that will be completed when a task returned from a closure is completed. */ @discardableResult - public func continueOnErrorWithTask(_ executor: Executor = .default, continuation: @escaping ((E) throws -> Task)) -> Task { - return continueOnErrorWithTask(executor) { (error: Error) in + public func continueOnErrorWithTask(_ executor: Executor = .default, _ cancellationToken: CancellationToken? = nil, continuation: @escaping ((E) throws -> Task)) -> Task { + return continueOnErrorWithTask(executor, cancellationToken) { (error: Error) in if let error = error as? E { return try continuation(error) } @@ -214,8 +297,8 @@ extension Task { - returns: A task that will be completed when a task returned from a closure is completed. */ @discardableResult - public func continueOnErrorWithTask(_ executor: Executor = .default, continuation: @escaping ((Error) throws -> Task)) -> Task { - return continueWithTask(executor, options: .RunOnError) { task in + public func continueOnErrorWithTask(_ executor: Executor = .default, _ cancellationToken: CancellationToken? = nil, continuation: @escaping ((Error) throws -> Task)) -> Task { + return continueWithTask(executor, options: .RunOnError, cancellationToken: cancellationToken) { task in return try continuation(task.error!) } } diff --git a/Sources/BoltsSwift/Task+Delay.swift b/Sources/BoltsSwift/Task+Delay.swift index 1c8feb6..8b0aa6f 100644 --- a/Sources/BoltsSwift/Task+Delay.swift +++ b/Sources/BoltsSwift/Task+Delay.swift @@ -29,4 +29,30 @@ extension Task { } return taskCompletionSource.task } + + public class func withDelay(_ delay: TimeInterval, _ cancellationToken: CancellationToken) -> Task { + let taskCompletionSource = TaskCompletionSource() + + if cancellationToken.cancellationRequested { + taskCompletionSource.cancel() + return taskCompletionSource.task + } + + let dispatchItem = DispatchWorkItem { + if cancellationToken.cancellationRequested { + taskCompletionSource.cancel() + return + } + taskCompletionSource.trySet(result: ()) + } + + let time = DispatchTime.now() + delay + DispatchQueue.global(qos: .default).asyncAfter(deadline: time, execute: dispatchItem) + + cancellationToken.registerCancellationObserver { + dispatchItem.cancel() + taskCompletionSource.tryCancel() + } + return taskCompletionSource.task + } } diff --git a/Sources/BoltsSwift/Task.swift b/Sources/BoltsSwift/Task.swift index 95a6091..2c5a4b8 100644 --- a/Sources/BoltsSwift/Task.swift +++ b/Sources/BoltsSwift/Task.swift @@ -10,7 +10,7 @@ import Foundation enum TaskState { - case pending() + case pending case success(TResult) case error(Error) case cancelled @@ -53,7 +53,7 @@ public final class Task { fileprivate let synchronizationQueue = DispatchQueue(label: "com.bolts.task", attributes: DispatchQueue.Attributes.concurrent) fileprivate var _completedCondition: NSCondition? - fileprivate var _state: TaskState = .pending() + fileprivate var _state: TaskState = .pending fileprivate var _continuations: [Continuation] = Array() // MARK: Initializers @@ -108,7 +108,7 @@ public final class Task { The returned task will complete when the closure completes. */ public convenience init(_ executor: Executor = .default, closure: @escaping (() throws -> TResult)) { - self.init(state: .pending()) + self.init(state: .pending) executor.execute { self.trySet(state: TaskState.fromClosure(closure)) } @@ -242,7 +242,7 @@ public final class Task { var completedCondition: NSCondition? synchronizationQueue.sync(flags: .barrier, execute: { switch self._state { - case .pending(): + case .pending: stateChanged = true self._state = state continuations = self._continuations diff --git a/Tests/CancellationTests.swift b/Tests/CancellationTests.swift new file mode 100644 index 0000000..1e29ad6 --- /dev/null +++ b/Tests/CancellationTests.swift @@ -0,0 +1,158 @@ +// +// CancellationTests.swift +// BoltsSwift +// +// Created by Simon Brockmann on 07.05.19. +// Copyright © 2019 Facebook. All rights reserved. +// + +import XCTest +import BoltsSwift + +class CancellationTests: XCTestCase { + func testCancel() throws { + let cts = CancellationTokenSource.cancellationTokenSource() + XCTAssertFalse(cts.isCancellationRequested(), "Source should not be cancelled") + XCTAssertFalse(cts.token.cancellationRequested, "Token should not be cancelled") + + try cts.cancel() + + XCTAssertTrue(cts.isCancellationRequested(), "Source should be cancelled") + XCTAssertTrue(cts.token.cancellationRequested, "Token should be cancelled") + } + + func testCancelMultipleTimes() throws { + let cts = CancellationTokenSource.cancellationTokenSource() + XCTAssertFalse(cts.isCancellationRequested()) + XCTAssertFalse(cts.token.cancellationRequested) + + try cts.cancel() + XCTAssertTrue(cts.isCancellationRequested()); + XCTAssertTrue(cts.token.cancellationRequested); + + try cts.cancel() + XCTAssertTrue(cts.isCancellationRequested()); + XCTAssertTrue(cts.token.cancellationRequested); + } + + func testCancellationBlock() throws { + let cts = CancellationTokenSource.cancellationTokenSource() + var cancelled = false + cts.token.registerCancellationObserver { + cancelled = true + } + XCTAssertFalse(cts.isCancellationRequested(), "Source should not be cancelled"); + XCTAssertFalse(cts.token.cancellationRequested, "Token should not be cancelled"); + + try cts.cancel() + + XCTAssertTrue(cancelled, "Source should be cancelled"); + } + + func testCancellationAfterDelay() throws { + let cts = CancellationTokenSource.cancellationTokenSource() + XCTAssertFalse(cts.isCancellationRequested(), "Source should not be cancelled"); + XCTAssertFalse(cts.token.cancellationRequested, "Token should not be cancelled"); + + try cts.cancelAfterInterval(interval: 0.2) + XCTAssertFalse(cts.isCancellationRequested(), "Source should be cancelled") + XCTAssertFalse(cts.token.cancellationRequested, "Token should be cancelled") + + // Spin the run loop for half a second, since `delay` is in milliseconds, not seconds. + RunLoop.current.run(until: Date(timeIntervalSinceNow: 0.5)) + + XCTAssertTrue(cts.isCancellationRequested(), "Source should be cancelled") + XCTAssertTrue(cts.token.cancellationRequested, "Token should be cancelled") + } + + func testCancellationAfterDelayValidation() { + let cts = CancellationTokenSource.cancellationTokenSource() + XCTAssertFalse(cts.isCancellationRequested()) + XCTAssertFalse(cts.token.cancellationRequested) + + XCTAssertThrowsError(try cts.cancelAfterInterval(interval: -1)) + } + + func testCancellationAfterZeroDelay() throws { + let cts = CancellationTokenSource.cancellationTokenSource() + XCTAssertFalse(cts.isCancellationRequested()) + XCTAssertFalse(cts.token.cancellationRequested) + + try cts.cancelAfterInterval(interval: 0) + + XCTAssertTrue(cts.isCancellationRequested()); + XCTAssertTrue(cts.token.cancellationRequested); + } + + func testCancellationAfterDelayOnCancelled() throws { + let cts = CancellationTokenSource.cancellationTokenSource() + try cts.cancel() + XCTAssertTrue(cts.isCancellationRequested()) + XCTAssertTrue(cts.token.cancellationRequested) + + try cts.cancelAfterInterval(interval: 1) + + XCTAssertTrue(cts.isCancellationRequested()) + XCTAssertTrue(cts.token.cancellationRequested) + } + + func testDispose() throws { + var cts = CancellationTokenSource.cancellationTokenSource() + try cts.dispose() + + XCTAssertThrowsError(try cts.cancel()) + + cts = CancellationTokenSource.cancellationTokenSource() + try cts.cancel() + + XCTAssertTrue(cts.isCancellationRequested(), "Source should be cancelled") + XCTAssertTrue(cts.token.cancellationRequested, "Token should be cancelled") + + try cts.dispose() + XCTAssertThrowsError(try cts.cancel()) + } + + func testDisposeMultipleTimes() throws { + let cts = CancellationTokenSource.cancellationTokenSource() + try cts.dispose() + XCTAssertNoThrow(try cts.dispose()) + } + + func testDisposeRegistration() throws { + let cts = CancellationTokenSource.cancellationTokenSource() + let regestration = cts.token.registerCancellationObserver { + XCTFail() + }! + XCTAssertNoThrow(try regestration.dispose()) + try cts.cancel() + } + + func testDisposeRegistrationMultipleTimes() throws { + let cts = CancellationTokenSource.cancellationTokenSource() + let regestration = cts.token.registerCancellationObserver { + XCTFail() + }! + XCTAssertNoThrow(try regestration.dispose()) + XCTAssertNoThrow(try regestration.dispose()) + + try cts.cancel() + } + + func testDisposeRegistrationAfterCancellationToken() throws { + let cts = CancellationTokenSource.cancellationTokenSource() + let regestration = cts.token.registerCancellationObserver { + } + + try regestration!.dispose() + try cts.cancel() + } + + func testDisposeRegistrationBeforeCancellationToken() throws { + let cts = CancellationTokenSource.cancellationTokenSource() + let registration = cts.token.registerCancellationObserver { + }! + + try cts.dispose() + XCTAssertNoThrow(try registration.dispose()) + } +} diff --git a/Tests/TaskTests.swift b/Tests/TaskTests.swift index bf20440..216d478 100644 --- a/Tests/TaskTests.swift +++ b/Tests/TaskTests.swift @@ -294,23 +294,23 @@ class TaskTests: XCTestCase { count += 1 XCTAssertEqual(count, 1) return nil - }.continueWith { _ -> String? in - count += 1 - XCTAssertEqual(count, 2) - return nil - }.continueWith { _ -> String? in - count += 1 - XCTAssertEqual(count, 3) - return nil - }.continueWith { _ -> String? in - count += 1 - XCTAssertEqual(count, 4) - return nil - }.continueWith { _ -> String? in - count += 1 - XCTAssertEqual(count, 5) - expectation.fulfill() - return nil + }.continueWith { _ -> String? in + count += 1 + XCTAssertEqual(count, 2) + return nil + }.continueWith { _ -> String? in + count += 1 + XCTAssertEqual(count, 3) + return nil + }.continueWith { _ -> String? in + count += 1 + XCTAssertEqual(count, 4) + return nil + }.continueWith { _ -> String? in + count += 1 + XCTAssertEqual(count, 5) + expectation.fulfill() + return nil } waitForTestExpectations() @@ -324,19 +324,19 @@ class TaskTests: XCTestCase { Task.cancelledTask().continueWith(executor) { _ in count += 1 XCTAssertEqual(count, 1) - }.continueWith(executor) { _ in - count += 1 - XCTAssertEqual(count, 2) - }.continueWith(executor) { _ in - count += 1 - XCTAssertEqual(count, 3) - }.continueWith(executor) { _ in - count += 1 - XCTAssertEqual(count, 4) - }.continueWith(executor) { _ in - count += 1 - XCTAssertEqual(count, 5) - expectation.fulfill() + }.continueWith(executor) { _ in + count += 1 + XCTAssertEqual(count, 2) + }.continueWith(executor) { _ in + count += 1 + XCTAssertEqual(count, 3) + }.continueWith(executor) { _ in + count += 1 + XCTAssertEqual(count, 4) + }.continueWith(executor) { _ in + count += 1 + XCTAssertEqual(count, 5) + expectation.fulfill() } waitForTestExpectations() @@ -401,10 +401,10 @@ class TaskTests: XCTestCase { for i in 1...20 { let task = Task.withDelay(0.5) - .continueWith(continuation: { task -> Int in - OSAtomicIncrement32(&count) - return i - }) + .continueWith(continuation: { task -> Int in + OSAtomicIncrement32(&count) + return i + }) tasks.append(task) } @@ -431,10 +431,10 @@ class TaskTests: XCTestCase { for i in 1...20 { let task = Task.withDelay(0.5) - .continueWith(executor, continuation: { task -> Int in - OSAtomicIncrement32(&count) - return i - }) + .continueWith(executor, continuation: { task -> Int in + OSAtomicIncrement32(&count) + return i + }) tasks.append(task) } @@ -462,13 +462,13 @@ class TaskTests: XCTestCase { for i in 1...20 { let task = Task.withDelay(0.5) - .continueWithTask(executor, continuation: { task -> Task in - OSAtomicIncrement32(&count) - if i == 20 { - return Task.cancelledTask() - } - return Task(i) - }) + .continueWithTask(executor, continuation: { task -> Task in + OSAtomicIncrement32(&count) + if i == 20 { + return Task.cancelledTask() + } + return Task(i) + }) tasks.append(task) } @@ -494,10 +494,10 @@ class TaskTests: XCTestCase { for i in 1...20 { let task = Task.withDelay(0.5) - .continueWith(continuation: { task in - OSAtomicIncrement32(&count) - throw NSError(domain: "bolts", code: i, userInfo: nil) - }) + .continueWith(continuation: { task in + OSAtomicIncrement32(&count) + throw NSError(domain: "bolts", code: i, userInfo: nil) + }) tasks.append(task) } @@ -536,10 +536,10 @@ class TaskTests: XCTestCase { }) for i in 1...20 { let task = Task.withDelay(0.5) - .continueWith(executor, continuation: { task -> Int in - OSAtomicIncrement32(&count) - return i - }) + .continueWith(executor, continuation: { task -> Int in + OSAtomicIncrement32(&count) + return i + }) tasks.append(task) } @@ -568,10 +568,10 @@ class TaskTests: XCTestCase { for i in 1...20 { let task = Task.withDelay(Double(i) * 0.5) - .continueWithTask(executor, continuation: { task -> Task in - OSAtomicIncrement32(&count) - return Task(error: error) - }) + .continueWithTask(executor, continuation: { task -> Task in + OSAtomicIncrement32(&count) + return Task(error: error) + }) tasks.append(task) } @@ -599,10 +599,10 @@ class TaskTests: XCTestCase { for i in 1...20 { let task = Task.withDelay(Double(i) * 0.5) - .continueWithTask(executor, continuation: { task -> Task in - OSAtomicIncrement32(&count) - return Task.cancelledTask() - }) + .continueWithTask(executor, continuation: { task -> Task in + OSAtomicIncrement32(&count) + return Task.cancelledTask() + }) tasks.append(task) } @@ -638,19 +638,108 @@ class TaskTests: XCTestCase { Task.cancelledTask().continueWith { _ in count += 1 XCTAssertEqual(count, 1) - }.continueWith { _ in - count += 1 - XCTAssertEqual(count, 2) - }.continueWith { _ in - count += 1 - XCTAssertEqual(count, 3) - }.continueWith { _ in - count += 1 - XCTAssertEqual(count, 4) - }.continueWith { _ in - count += 1 - XCTAssertEqual(count, 5) - }.waitUntilCompleted() + }.continueWith { _ in + count += 1 + XCTAssertEqual(count, 2) + }.continueWith { _ in + count += 1 + XCTAssertEqual(count, 3) + }.continueWith { _ in + count += 1 + XCTAssertEqual(count, 4) + }.continueWith { _ in + count += 1 + XCTAssertEqual(count, 5) + }.waitUntilCompleted() XCTAssertEqual(count, 5) } + + // MARK: Cancellation + + func testOnSuccessWithToken() throws { + let cts = CancellationTokenSource.cancellationTokenSource() + var task = Task.withDelay(0.1) + + task = task.continueOnSuccessWith(.immediate, cts.token) { + XCTFail("Success block should not be triggered"); + } + + try cts.cancel() + task.waitUntilCompleted() + XCTAssertTrue(task.cancelled); + } + + func testOnContinueWithToken() throws { + let cts = CancellationTokenSource.cancellationTokenSource() + var task = Task.withDelay(0.1) + + task = task.continueWith(.immediate, cts.token) { _ in + XCTFail("Success block should not be triggered"); + } + + try cts.cancel() + task.waitUntilCompleted() + XCTAssertTrue(task.cancelled); + } + + + func testOnSuccessWithCancellationToken() throws { + let cts = CancellationTokenSource.cancellationTokenSource() + var task = Task(name) + + try cts.cancel() + + task = task.continueOnSuccessWith(.immediate, cts.token) { _ in + XCTFail("Success block should not be triggered") + return "" + } + + XCTAssertTrue(task.cancelled); + } + + func testOnContinueWithCancellationToken() throws { + let cts = CancellationTokenSource.cancellationTokenSource() + var task = Task(name) + + try cts.cancel() + + task = task.continueWith(.immediate, cts.token) { _ in + XCTFail("Success block should not be triggered") + return "" + } + + XCTAssertTrue(task.cancelled); + } + + func testDelayWithToken() throws { + let cts = CancellationTokenSource.cancellationTokenSource() + let task = Task.withDelay(0.1, cts.token) + + try cts.cancel() + task.waitUntilCompleted() + XCTAssertTrue(task.cancelled, "Task should be cancelled immediately") + } + + func testDelayWithCancelledToken() throws { + let cts = CancellationTokenSource.cancellationTokenSource() + try cts.cancel() + + let task = Task.withDelay(0.1, cts.token) + XCTAssertTrue(task.cancelled, "Task should be cancelled immediately") + } + + func testReturnTaskFromContinuationWithCancellation() throws { + let cts = CancellationTokenSource.cancellationTokenSource() + let expectation = self.expectation(description: "task") + let task = Task.withDelay(1) + + task.continueWith(cts.token) { task -> Task in + try cts.cancel() + return Task.withDelay(10) + }.continueWith { t in + XCTAssertTrue(t.cancelled); + expectation.fulfill() + } + self.waitForExpectations(timeout: 10.0, handler: nil) + } }