Skip to content

Commit 18bc95e

Browse files
More checks for task cancellation and tests
1 parent 144464e commit 18bc95e

File tree

4 files changed

+282
-19
lines changed

4 files changed

+282
-19
lines changed

Sources/OpenAPIURLSession/URLSessionBidirectionalStreaming/URLSession+Extensions.swift

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ import Foundation
3232
task = dataTask(with: urlRequest)
3333
}
3434
return try await withTaskCancellationHandler {
35+
try Task.checkCancellation()
3536
let delegate = BidirectionalStreamingURLSessionDelegate(
3637
requestBody: requestBody,
3738
requestStreamBufferSize: requestStreamBufferSize,
@@ -47,8 +48,10 @@ import Foundation
4748
length: .init(from: response),
4849
iterationBehavior: .single
4950
)
51+
try Task.checkCancellation()
5052
return (try HTTPResponse(response), responseBody)
5153
} onCancel: {
54+
debug("Concurrency task cancelled, cancelling URLSession task.")
5255
task.cancel()
5356
}
5457
}

Sources/OpenAPIURLSession/URLSessionTransport.swift

Lines changed: 35 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ import class Foundation.FileHandle
2424
#if canImport(FoundationNetworking)
2525
@preconcurrency import struct FoundationNetworking.URLRequest
2626
import class FoundationNetworking.URLSession
27+
import class FoundationNetworking.URLSessionTask
2728
import class FoundationNetworking.URLResponse
2829
import class FoundationNetworking.HTTPURLResponse
2930
#endif
@@ -243,31 +244,50 @@ extension URLSession {
243244
func bufferedRequest(for request: HTTPRequest, baseURL: URL, requestBody: HTTPBody?) async throws -> (
244245
HTTPResponse, HTTPBody?
245246
) {
247+
try Task.checkCancellation()
246248
var urlRequest = try URLRequest(request, baseURL: baseURL)
247249
if let requestBody { urlRequest.httpBody = try await Data(collecting: requestBody, upTo: .max) }
250+
try Task.checkCancellation()
248251

249252
/// Use `dataTask(with:completionHandler:)` here because `data(for:[delegate:]) async` is only available on
250253
/// Darwin platforms newer than our minimum deployment target, and not at all on Linux.
251-
let (response, maybeResponseBodyData): (URLResponse, Data?) = try await withCheckedThrowingContinuation {
252-
continuation in
253-
let task = self.dataTask(with: urlRequest) { [urlRequest] data, response, error in
254-
if let error {
255-
continuation.resume(throwing: error)
256-
return
254+
let taskBox: LockedValueBox<URLSessionTask?> = .init(nil)
255+
return try await withTaskCancellationHandler {
256+
let (response, maybeResponseBodyData): (URLResponse, Data?) = try await withCheckedThrowingContinuation {
257+
continuation in
258+
let task = self.dataTask(with: urlRequest) { [urlRequest] data, response, error in
259+
if let error {
260+
continuation.resume(throwing: error)
261+
return
262+
}
263+
guard let response else {
264+
continuation.resume(throwing: URLSessionTransportError.noResponse(url: urlRequest.url))
265+
return
266+
}
267+
continuation.resume(with: .success((response, data)))
257268
}
258-
guard let response else {
259-
continuation.resume(throwing: URLSessionTransportError.noResponse(url: urlRequest.url))
260-
return
269+
// Swift concurrency task cancelled here.
270+
taskBox.withLockedValue { boxedTask in
271+
guard task.state == .suspended else {
272+
debug("URLSession task cannot be resumed, probably because it was cancelled by onCancel.")
273+
return
274+
}
275+
task.resume()
276+
boxedTask = task
261277
}
262-
continuation.resume(with: .success((response, data)))
263278
}
264-
task.resume()
265-
}
266279

267-
let maybeResponseBody = maybeResponseBodyData.map { data in
268-
HTTPBody(data, length: HTTPBody.Length(from: response), iterationBehavior: .multiple)
280+
let maybeResponseBody = maybeResponseBodyData.map { data in
281+
HTTPBody(data, length: HTTPBody.Length(from: response), iterationBehavior: .multiple)
282+
}
283+
return (try HTTPResponse(response), maybeResponseBody)
284+
} onCancel: {
285+
taskBox.withLockedValue { boxedTask in
286+
debug("Concurrency task cancelled, cancelling URLSession task.")
287+
boxedTask?.cancel()
288+
boxedTask = nil
289+
}
269290
}
270-
return (try HTTPResponse(response), maybeResponseBody)
271291
}
272292
}
273293

Lines changed: 240 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,240 @@
1+
//===----------------------------------------------------------------------===//
2+
//
3+
// This source file is part of the SwiftOpenAPIGenerator open source project
4+
//
5+
// Copyright (c) 2023 Apple Inc. and the SwiftOpenAPIGenerator project authors
6+
// Licensed under Apache License v2.0
7+
//
8+
// See LICENSE.txt for license information
9+
// See CONTRIBUTORS.txt for the list of SwiftOpenAPIGenerator project authors
10+
//
11+
// SPDX-License-Identifier: Apache-2.0
12+
//
13+
//===----------------------------------------------------------------------===//
14+
#if canImport(Darwin)
15+
16+
import Foundation
17+
import HTTPTypes
18+
import NIO
19+
import OpenAPIRuntime
20+
import XCTest
21+
@testable import OpenAPIURLSession
22+
23+
enum CancellationPoint: CaseIterable {
24+
case beforeSendingHead
25+
case beforeSendingRequestBody
26+
case partwayThroughSendingRequestBody
27+
case beforeConsumingResponseBody
28+
case partwayThroughConsumingResponseBody
29+
case afterConsumingResponseBody
30+
}
31+
32+
func testTaskCancelled(_ cancellationPoint: CancellationPoint, transport: URLSessionTransport) async throws {
33+
let requestPath = "/hello/world"
34+
let requestBodyElements = ["Hello,", "world!"]
35+
let requestBodySequence = MockAsyncSequence(elementsToVend: requestBodyElements, gatingProduction: true)
36+
let requestBody = HTTPBody(
37+
requestBodySequence,
38+
length: .known(Int64(requestBodyElements.joined().lengthOfBytes(using: .utf8))),
39+
iterationBehavior: .single
40+
)
41+
42+
let responseBodyMessage = "Hey!"
43+
44+
let taskShouldCancel = XCTestExpectation(description: "Concurrency task cancelled")
45+
let taskCancelled = XCTestExpectation(description: "Concurrency task cancelled")
46+
47+
try await withThrowingTaskGroup(of: Void.self) { group in
48+
let serverPort = try await AsyncTestHTTP1Server.start(connectionTaskGroup: &group) { connectionChannel in
49+
try await connectionChannel.executeThenClose { inbound, outbound in
50+
var requestPartIterator = inbound.makeAsyncIterator()
51+
var accumulatedBody = ByteBuffer()
52+
while let requestPart = try await requestPartIterator.next() {
53+
switch requestPart {
54+
case .head(let head):
55+
XCTAssertEqual(head.uri, requestPath)
56+
XCTAssertEqual(head.method, .POST)
57+
case .body(let buffer): accumulatedBody.writeImmutableBuffer(buffer)
58+
case .end:
59+
switch cancellationPoint {
60+
case .beforeConsumingResponseBody, .partwayThroughConsumingResponseBody,
61+
.afterConsumingResponseBody:
62+
XCTAssertEqual(
63+
String(decoding: accumulatedBody.readableBytesView, as: UTF8.self),
64+
requestBodyElements.joined()
65+
)
66+
case .beforeSendingHead, .beforeSendingRequestBody, .partwayThroughSendingRequestBody: break
67+
}
68+
try await outbound.write(.head(.init(version: .http1_1, status: .ok)))
69+
try await outbound.write(.body(ByteBuffer(string: responseBodyMessage)))
70+
try await outbound.write(.end(nil))
71+
}
72+
}
73+
}
74+
}
75+
debug("Server running on 127.0.0.1:\(serverPort)")
76+
77+
let task = Task {
78+
if case .beforeSendingHead = cancellationPoint {
79+
taskShouldCancel.fulfill()
80+
await fulfillment(of: [taskCancelled])
81+
}
82+
debug("Client starting request")
83+
async let (asyncResponse, asyncResponseBody) = try await transport.send(
84+
HTTPRequest(method: .post, scheme: nil, authority: nil, path: requestPath),
85+
body: requestBody,
86+
baseURL: URL(string: "http://127.0.0.1:\(serverPort)")!,
87+
operationID: "unused"
88+
)
89+
90+
if case .beforeSendingRequestBody = cancellationPoint {
91+
taskShouldCancel.fulfill()
92+
await fulfillment(of: [taskCancelled])
93+
}
94+
95+
requestBodySequence.openGate(for: 1)
96+
97+
if case .partwayThroughSendingRequestBody = cancellationPoint {
98+
taskShouldCancel.fulfill()
99+
await fulfillment(of: [taskCancelled])
100+
}
101+
102+
requestBodySequence.openGate()
103+
104+
let (response, maybeResponseBody) = try await (asyncResponse, asyncResponseBody)
105+
106+
debug("Client received response head: \(response)")
107+
XCTAssertEqual(response.status, .ok)
108+
let responseBody = try XCTUnwrap(maybeResponseBody)
109+
110+
if case .beforeConsumingResponseBody = cancellationPoint {
111+
taskShouldCancel.fulfill()
112+
await fulfillment(of: [taskCancelled])
113+
}
114+
115+
var iterator = responseBody.makeAsyncIterator()
116+
117+
_ = try await iterator.next()
118+
119+
if case .partwayThroughConsumingResponseBody = cancellationPoint {
120+
taskShouldCancel.fulfill()
121+
await fulfillment(of: [taskCancelled])
122+
}
123+
124+
while try await iterator.next() != nil {
125+
126+
}
127+
128+
if case .afterConsumingResponseBody = cancellationPoint {
129+
taskShouldCancel.fulfill()
130+
await fulfillment(of: [taskCancelled])
131+
}
132+
133+
}
134+
135+
await fulfillment(of: [taskShouldCancel])
136+
task.cancel()
137+
taskCancelled.fulfill()
138+
139+
switch transport.configuration.implementation {
140+
case .buffering:
141+
switch cancellationPoint {
142+
case .beforeSendingHead, .beforeSendingRequestBody, .partwayThroughSendingRequestBody:
143+
await XCTAssertThrowsError(try await task.value) { error in XCTAssertTrue(error is CancellationError) }
144+
case .beforeConsumingResponseBody, .partwayThroughConsumingResponseBody, .afterConsumingResponseBody:
145+
try await task.value
146+
}
147+
case .streaming:
148+
switch cancellationPoint {
149+
case .beforeSendingHead:
150+
await XCTAssertThrowsError(try await task.value) { error in XCTAssertTrue(error is CancellationError) }
151+
case .beforeSendingRequestBody, .partwayThroughSendingRequestBody:
152+
await XCTAssertThrowsError(try await task.value) { error in
153+
guard let urlError = error as? URLError else {
154+
XCTFail()
155+
return
156+
}
157+
XCTAssertEqual(urlError.code, .cancelled)
158+
}
159+
case .beforeConsumingResponseBody, .partwayThroughConsumingResponseBody, .afterConsumingResponseBody:
160+
try await task.value
161+
}
162+
}
163+
164+
group.cancelAll()
165+
}
166+
167+
}
168+
169+
func fulfillment(
170+
of expectations: [XCTestExpectation],
171+
timeout seconds: TimeInterval = .infinity,
172+
enforceOrder enforceOrderOfFulfillment: Bool = false,
173+
file: StaticString = #file,
174+
line: UInt = #line
175+
) async {
176+
guard
177+
case .completed = await XCTWaiter.fulfillment(
178+
of: expectations,
179+
timeout: seconds,
180+
enforceOrder: enforceOrderOfFulfillment
181+
)
182+
else {
183+
XCTFail("Expectation was not fulfilled", file: file, line: line)
184+
return
185+
}
186+
}
187+
188+
extension URLSessionTransportBufferedTests {
189+
func testCancellation_beforeSendingHead() async throws {
190+
try await testTaskCancelled(.beforeSendingHead, transport: transport)
191+
}
192+
193+
func testCancellation_beforeSendingRequestBody() async throws {
194+
try await testTaskCancelled(.beforeSendingRequestBody, transport: transport)
195+
}
196+
197+
func testCancellation_partwayThroughSendingRequestBody() async throws {
198+
try await testTaskCancelled(.partwayThroughSendingRequestBody, transport: transport)
199+
}
200+
201+
func testCancellation_beforeConsumingResponseBody() async throws {
202+
try await testTaskCancelled(.beforeConsumingResponseBody, transport: transport)
203+
}
204+
205+
func testCancellation_partwayThroughConsumingResponseBody() async throws {
206+
try await testTaskCancelled(.partwayThroughConsumingResponseBody, transport: transport)
207+
}
208+
209+
func testCancellation_afterConsumingResponseBody() async throws {
210+
try await testTaskCancelled(.afterConsumingResponseBody, transport: transport)
211+
}
212+
}
213+
214+
extension URLSessionTransportStreamingTests {
215+
func testCancellation_beforeSendingHead() async throws {
216+
try await testTaskCancelled(.beforeSendingHead, transport: transport)
217+
}
218+
219+
func testCancellation_beforeSendingRequestBody() async throws {
220+
try await testTaskCancelled(.beforeSendingRequestBody, transport: transport)
221+
}
222+
223+
func testCancellation_partwayThroughSendingRequestBody() async throws {
224+
try await testTaskCancelled(.partwayThroughSendingRequestBody, transport: transport)
225+
}
226+
227+
func testCancellation_beforeConsumingResponseBody() async throws {
228+
try await testTaskCancelled(.beforeConsumingResponseBody, transport: transport)
229+
}
230+
231+
func testCancellation_partwayThroughConsumingResponseBody() async throws {
232+
try await testTaskCancelled(.partwayThroughConsumingResponseBody, transport: transport)
233+
}
234+
235+
func testCancellation_afterConsumingResponseBody() async throws {
236+
try await testTaskCancelled(.afterConsumingResponseBody, transport: transport)
237+
}
238+
}
239+
240+
#endif // canImport(Darwin)

Tests/OpenAPIURLSessionTests/URLSessionTransportTests.swift

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@ class URLSessionTransportConverterTests: XCTestCase {
5656

5757
// swift-format-ignore: AllPublicDeclarationsHaveDocumentation
5858
class URLSessionTransportBufferedTests: XCTestCase {
59-
var transport: (any ClientTransport)!
59+
var transport: URLSessionTransport!
6060

6161
static override func setUp() { OpenAPIURLSession.debugLoggingEnabled = false }
6262

@@ -66,7 +66,7 @@ class URLSessionTransportBufferedTests: XCTestCase {
6666

6767
func testBasicGet() async throws { try await testHTTPBasicGet(transport: transport) }
6868

69-
func testBasicPost() async throws { try await testHTTPBasicGet(transport: transport) }
69+
func testBasicPost() async throws { try await testHTTPBasicPost(transport: transport) }
7070

7171
#if canImport(Darwin) // Only passes on Darwin because Linux doesn't replay the request body on 307.
7272
func testHTTPRedirect_multipleIterationBehavior_succeeds() async throws {
@@ -89,7 +89,7 @@ class URLSessionTransportBufferedTests: XCTestCase {
8989

9090
// swift-format-ignore: AllPublicDeclarationsHaveDocumentation
9191
class URLSessionTransportStreamingTests: XCTestCase {
92-
var transport: (any ClientTransport)!
92+
var transport: URLSessionTransport!
9393

9494
static override func setUp() { OpenAPIURLSession.debugLoggingEnabled = false }
9595

@@ -107,7 +107,7 @@ class URLSessionTransportStreamingTests: XCTestCase {
107107

108108
func testBasicGet() async throws { try await testHTTPBasicGet(transport: transport) }
109109

110-
func testBasicPost() async throws { try await testHTTPBasicGet(transport: transport) }
110+
func testBasicPost() async throws { try await testHTTPBasicPost(transport: transport) }
111111

112112
#if canImport(Darwin) // Only passes on Darwin because Linux doesn't replay the request body on 307.
113113
func testHTTPRedirect_multipleIterationBehavior_succeeds() async throws {

0 commit comments

Comments
 (0)