Skip to content

Commit

Permalink
Bug/502 crash thread safety fix (#95)
Browse files Browse the repository at this point in the history
### Motivation

Fixes apple/swift-openapi-generator#502
- Ensure thread safety of `HTTPBody.collect(upTo)`.
- `makeAsyncIterator()`: Instead of crashing, return AsyncSequence which
throws `TooManyIterationsError` thereby honoring the contract for
`IterationBehavior.single` (HTTPBody, MultipartBody)

### Modifications

- HTTPBody, MultipartBody: `makeAsyncIterator()`: removed `try!`, catch
error and create a sequence which throws the error on iteration.
- This removed the need for `try checkIfCanCreateIterator()` in
`HTTPBody.collect(upTo)`.
**Note**: This creates a small change in behavior: There may be a
`TooManyBytesError` thrown before the check for `iterationBehavior`.
This approach uses the simplest code, IMO. If we want to keep that
`iterationBehavior` is checked first and only after that for the length,
then the code needs to be more complex.
- Removed `try checkIfCanCreateIterator()` in both classes (only used in
`HTTPBody`).

### Result

- No intentional crash in `makeAsyncIterator()` anymore.
- Tests supplied as example in
apple/swift-openapi-generator#502 succeed.

### Test Plan

- Added check in `Test_Body.testIterationBehavior_single()` to ensure
that using `makeAsyncIterator()` directly yields the expected error.
- Added tests to check iteration behavior of `MultipartBody`.

---------

Co-authored-by: Lars Peters <[email protected]>
  • Loading branch information
LarsPetersHH and Lars Peters authored Jan 18, 2024
1 parent 7f86e4a commit 95307ba
Show file tree
Hide file tree
Showing 4 changed files with 71 additions and 30 deletions.
25 changes: 8 additions & 17 deletions Sources/OpenAPIRuntime/Interface/HTTPBody.swift
Original file line number Diff line number Diff line change
Expand Up @@ -159,16 +159,6 @@ public final class HTTPBody: @unchecked Sendable {
return locked_iteratorCreated
}

/// Verifying that creating another iterator is allowed based on
/// the values of `iterationBehavior` and `locked_iteratorCreated`.
/// - Throws: If another iterator is not allowed to be created.
private func checkIfCanCreateIterator() throws {
lock.lock()
defer { lock.unlock() }
guard iterationBehavior == .single else { return }
if locked_iteratorCreated { throw TooManyIterationsError() }
}

/// Tries to mark an iterator as created, verifying that it is allowed
/// based on the values of `iterationBehavior` and `locked_iteratorCreated`.
/// - Throws: If another iterator is not allowed to be created.
Expand Down Expand Up @@ -341,10 +331,12 @@ extension HTTPBody: AsyncSequence {
/// Creates and returns an asynchronous iterator
///
/// - Returns: An asynchronous iterator for byte chunks.
/// - Note: The returned sequence throws an error if no further iterations are allowed. See ``IterationBehavior``.
public func makeAsyncIterator() -> AsyncIterator {
// The crash on error is intentional here.
try! tryToMarkIteratorCreated()
return .init(sequence.makeAsyncIterator())
do {
try tryToMarkIteratorCreated()
return .init(sequence.makeAsyncIterator())
} catch { return .init(throwing: error) }
}
}

Expand Down Expand Up @@ -381,10 +373,6 @@ extension HTTPBody {
/// than `maxBytes`.
/// - Returns: A byte chunk containing all the accumulated bytes.
fileprivate func collect(upTo maxBytes: Int) async throws -> ByteChunk {

// Check that we're allowed to iterate again.
try checkIfCanCreateIterator()

// If the length is known, verify it's within the limit.
if case .known(let knownBytes) = length {
guard knownBytes <= maxBytes else { throw TooManyBytesError(maxBytes: maxBytes) }
Expand Down Expand Up @@ -563,6 +551,9 @@ extension HTTPBody {
var iterator = iterator
self.produceNext = { try await iterator.next() }
}
/// Creates an iterator throwing the given error when iterated.
/// - Parameter error: The error to throw on iteration.
fileprivate init(throwing error: any Error) { self.produceNext = { throw error } }

/// Advances the iterator to the next element and returns it asynchronously.
///
Expand Down
22 changes: 9 additions & 13 deletions Sources/OpenAPIRuntime/Multipart/MultipartPublicTypes.swift
Original file line number Diff line number Diff line change
Expand Up @@ -209,16 +209,6 @@ public final class MultipartBody<Part: Sendable>: @unchecked Sendable {
var errorDescription: String? { description }
}

/// Verifying that creating another iterator is allowed based on the values of `iterationBehavior`
/// and `locked_iteratorCreated`.
/// - Throws: If another iterator is not allowed to be created.
internal func checkIfCanCreateIterator() throws {
lock.lock()
defer { lock.unlock() }
guard iterationBehavior == .single else { return }
if locked_iteratorCreated { throw TooManyIterationsError() }
}

/// Tries to mark an iterator as created, verifying that it is allowed based on the values
/// of `iterationBehavior` and `locked_iteratorCreated`.
/// - Throws: If another iterator is not allowed to be created.
Expand Down Expand Up @@ -331,10 +321,12 @@ extension MultipartBody: AsyncSequence {
/// Creates and returns an asynchronous iterator
///
/// - Returns: An asynchronous iterator for parts.
/// - Note: The returned sequence throws an error if no further iterations are allowed. See ``IterationBehavior``.
public func makeAsyncIterator() -> AsyncIterator {
// The crash on error is intentional here.
try! tryToMarkIteratorCreated()
return .init(sequence.makeAsyncIterator())
do {
try tryToMarkIteratorCreated()
return .init(sequence.makeAsyncIterator())
} catch { return .init(throwing: error) }
}
}

Expand All @@ -355,6 +347,10 @@ extension MultipartBody {
self.produceNext = { try await iterator.next() }
}

/// Creates an iterator throwing the given error when iterated.
/// - Parameter error: The error to throw on iteration.
fileprivate init(throwing error: any Error) { self.produceNext = { throw error } }

/// Advances the iterator to the next element and returns it asynchronously.
///
/// - Returns: The next element in the sequence, or `nil` if there are no more elements.
Expand Down
5 changes: 5 additions & 0 deletions Tests/OpenAPIRuntimeTests/Interface/Test_HTTPBody.swift
Original file line number Diff line number Diff line change
Expand Up @@ -173,6 +173,11 @@ final class Test_Body: Test_Runtime {
_ = try await String(collecting: body, upTo: .max)
XCTFail("Expected an error to be thrown")
} catch {}

do {
for try await _ in body {}
XCTFail("Expected an error to be thrown")
} catch {}
}

func testIterationBehavior_multiple() async throws {
Expand Down
49 changes: 49 additions & 0 deletions Tests/OpenAPIRuntimeTests/Interface/Test_MultipartBody.swift
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
//===----------------------------------------------------------------------===//
//
// This source file is part of the SwiftOpenAPIGenerator open source project
//
// Copyright (c) 2023 Apple Inc. and the SwiftOpenAPIGenerator project authors
// Licensed under Apache License v2.0
//
// See LICENSE.txt for license information
// See CONTRIBUTORS.txt for the list of SwiftOpenAPIGenerator project authors
//
// SPDX-License-Identifier: Apache-2.0
//
//===----------------------------------------------------------------------===//
import XCTest
@_spi(Generated) @testable import OpenAPIRuntime
import Foundation

final class Test_MultipartBody: XCTestCase {

func testIterationBehavior_single() async throws {
let sourceSequence = (0..<Int.random(in: 2..<10)).map { _ in UUID().uuidString }
let body = MultipartBody(sourceSequence, iterationBehavior: .single)

XCTAssertFalse(body.testing_iteratorCreated)

let iterated = try await body.reduce("") { $0 + $1 }
XCTAssertEqual(iterated, sourceSequence.joined())

XCTAssertTrue(body.testing_iteratorCreated)

do {
for try await _ in body {}
XCTFail("Expected an error to be thrown")
} catch {}
}

func testIterationBehavior_multiple() async throws {
let sourceSequence = (0..<Int.random(in: 2..<10)).map { _ in UUID().uuidString }
let body = MultipartBody(sourceSequence, iterationBehavior: .multiple)

XCTAssertFalse(body.testing_iteratorCreated)
for _ in 0..<2 {
let iterated = try await body.reduce("") { $0 + $1 }
XCTAssertEqual(iterated, sourceSequence.joined())
XCTAssertTrue(body.testing_iteratorCreated)
}
}

}

0 comments on commit 95307ba

Please sign in to comment.