Skip to content

Commit 6d0c7d5

Browse files
committed
UDP: support for multihoming with unbound sockets
1 parent 0df6cf7 commit 6d0c7d5

File tree

10 files changed

+481
-26
lines changed

10 files changed

+481
-26
lines changed

FlyingSocks/Sources/AsyncSocket.swift

Lines changed: 72 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,25 @@ public extension AsyncSocketPool where Self == SocketPool<Poll> {
6262

6363
public struct AsyncSocket: Sendable {
6464

65+
public struct Message: Sendable {
66+
public let peerAddress: sockaddr_storage
67+
public let bytes: [UInt8]
68+
public let interfaceIndex: UInt32?
69+
public let localAddress: sockaddr_storage?
70+
71+
public init(
72+
peerAddress: sockaddr_storage,
73+
bytes: [UInt8],
74+
interfaceIndex: UInt32? = nil,
75+
localAddress: sockaddr_storage? = nil
76+
) {
77+
self.peerAddress = peerAddress
78+
self.bytes = bytes
79+
self.interfaceIndex = interfaceIndex
80+
self.localAddress = localAddress
81+
}
82+
}
83+
6584
public let socket: Socket
6685
let pool: any AsyncSocketPool
6786

@@ -143,6 +162,23 @@ public struct AsyncSocket: Sendable {
143162
} while true
144163
}
145164

165+
#if !canImport(WinSDK)
166+
public func receive(atMost length: Int) async throws -> Message {
167+
try Task.checkCancellation()
168+
169+
repeat {
170+
do {
171+
let (peerAddress, bytes, interfaceIndex, localAddress) = try socket.receive(length: length)
172+
return Message(peerAddress: peerAddress, bytes: bytes, interfaceIndex: interfaceIndex, localAddress: localAddress)
173+
} catch SocketError.blocked {
174+
try await pool.suspendSocket(socket, untilReadyFor: .read)
175+
} catch {
176+
throw error
177+
}
178+
} while true
179+
}
180+
#endif
181+
146182
/// Reads bytes from the socket up to by not over/
147183
/// - Parameter bytes: The max number of bytes to read
148184
/// - Returns: an array of the read bytes capped to the number of bytes provided.
@@ -190,6 +226,31 @@ public struct AsyncSocket: Sendable {
190226
try await send(Array(data), to: address)
191227
}
192228

229+
#if !canImport(WinSDK)
230+
public func send(
231+
message: [UInt8],
232+
to peerAddress: some SocketAddress,
233+
interfaceIndex: UInt32? = nil,
234+
from localAddress: (some SocketAddress)? = nil
235+
) async throws {
236+
let sent = try await pool.loopUntilReady(for: .write, on: socket) {
237+
try socket.send(message: message, to: peerAddress, interfaceIndex: interfaceIndex, from: localAddress)
238+
}
239+
guard sent == message.count else {
240+
throw SocketError.disconnected
241+
}
242+
}
243+
244+
public func send(
245+
message: Data,
246+
to peerAddress: some SocketAddress,
247+
interfaceIndex: UInt32? = nil,
248+
from localAddress: (some SocketAddress)? = nil
249+
) async throws {
250+
try await send(message: Array(message), to: peerAddress, interfaceIndex: interfaceIndex, from: localAddress)
251+
}
252+
#endif
253+
193254
public func close() throws {
194255
try socket.close()
195256
}
@@ -275,7 +336,8 @@ public struct AsyncSocketSequence: AsyncSequence, AsyncIteratorProtocol, Sendabl
275336
public struct AsyncSocketMessageSequence: AsyncSequence, AsyncIteratorProtocol, Sendable {
276337
public static let DefaultMaxMessageLength: Int = 1500
277338

278-
public typealias Element = (sockaddr_storage, [UInt8])
339+
// Windows has a different recvmsg() API signature which is presently unsupported
340+
public typealias Element = AsyncSocket.Message
279341

280342
private let socket: AsyncSocket
281343
private let maxMessageLength: Int
@@ -288,7 +350,15 @@ public struct AsyncSocketMessageSequence: AsyncSequence, AsyncIteratorProtocol,
288350
}
289351

290352
public mutating func next() async throws -> Element? {
291-
return try await socket.receive(atMost: maxMessageLength)
353+
#if !canImport(WinSDK)
354+
try await socket.receive(atMost: maxMessageLength)
355+
#else
356+
let peerAddress: sockaddr_storage
357+
let bytes: [UInt8]
358+
359+
(peerAddress, bytes) = try await socket.receive(atMost: maxMessageLength)
360+
return AsyncSocket.Message(peerAddress: peerAddress, bytes: bytes)
361+
#endif
292362
}
293363
}
294364

FlyingSocks/Sources/Socket+Android.swift

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,10 @@ let EPOLLET: UInt32 = 1 << 31;
3737

3838
public extension Socket {
3939
typealias FileDescriptorType = Int32
40+
typealias IovLengthType = UInt
41+
typealias ControlMessageHeaderLengthType = Int
42+
typealias IPv4InterfaceIndexType = Int32
43+
typealias IPv6InterfaceIndexType = Int32
4044
}
4145

4246
extension Socket.FileDescriptor {
@@ -47,6 +51,10 @@ extension Socket {
4751
static let stream = Int32(SOCK_STREAM)
4852
static let datagram = Int32(SOCK_DGRAM)
4953
static let in_addr_any = Android.in_addr(s_addr: Android.in_addr_t(0))
54+
static let ipproto_ip = Int32(IPPROTO_IP)
55+
static let ipproto_ipv6 = Int32(IPPROTO_IPV6)
56+
static let ip_pktinfo = Int32(IP_PKTINFO)
57+
static let ipv6_pktinfo = Int32(IPV6_PKTINFO)
5058

5159
static func makeAddressINET(port: UInt16) -> Android.sockaddr_in {
5260
Android.sockaddr_in(
@@ -184,6 +192,14 @@ extension Socket {
184192
static func sendto(_ fd: FileDescriptorType, _ buffer: UnsafeRawPointer!, _ nbyte: Int, _ flags: Int32, _ destaddr: UnsafePointer<sockaddr>!, _ destlen: socklen_t) -> Int {
185193
Android.sendto(fd, buffer, nbyte, flags, destaddr, destlen)
186194
}
195+
196+
static func recvmsg(_ fd: FileDescriptorType, _ message: UnsafeMutablePointer<msghdr>, _ flags: Int32) -> Int {
197+
Android.recvmsg(fd, message, flags)
198+
}
199+
200+
static func sendmsg(_ fd: FileDescriptorType, _ message: UnsafePointer<msghdr>, _ flags: Int32) -> Int {
201+
Android.sendmsg(fd, message, flags)
202+
}
187203
}
188204

189205
#endif

FlyingSocks/Sources/Socket+Darwin.swift

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,10 @@ import Darwin
3434

3535
public extension Socket {
3636
typealias FileDescriptorType = Int32
37+
typealias IovLengthType = Int
38+
typealias ControlMessageHeaderLengthType = UInt32
39+
typealias IPv4InterfaceIndexType = UInt32
40+
typealias IPv6InterfaceIndexType = UInt32
3741
}
3842

3943
extension Socket.FileDescriptor {
@@ -44,6 +48,10 @@ extension Socket {
4448
static let stream = Int32(SOCK_STREAM)
4549
static let datagram = Int32(SOCK_DGRAM)
4650
static let in_addr_any = Darwin.in_addr(s_addr: Darwin.in_addr_t(0))
51+
static let ipproto_ip = Int32(IPPROTO_IP)
52+
static let ipproto_ipv6 = Int32(IPPROTO_IPV6)
53+
static let ip_pktinfo = Int32(IP_PKTINFO)
54+
static let ipv6_pktinfo = Int32(50) // __APPLE_USE_RFC_2292
4755

4856
static func makeAddressINET(port: UInt16) -> Darwin.sockaddr_in {
4957
Darwin.sockaddr_in(
@@ -185,6 +193,14 @@ extension Socket {
185193
static func sendto(_ fd: FileDescriptorType, _ buffer: UnsafeRawPointer!, _ nbyte: Int, _ flags: Int32, _ destaddr: UnsafePointer<sockaddr>!, _ destlen: socklen_t) -> Int {
186194
Darwin.sendto(fd, buffer, nbyte, flags, destaddr, destlen)
187195
}
196+
197+
static func recvmsg(_ fd: FileDescriptorType, _ message: UnsafeMutablePointer<msghdr>, _ flags: Int32) -> Int {
198+
Darwin.recvmsg(fd, message, flags)
199+
}
200+
201+
static func sendmsg(_ fd: FileDescriptorType, _ message: UnsafePointer<msghdr>, _ flags: Int32) -> Int {
202+
Darwin.sendmsg(fd, message, flags)
203+
}
188204
}
189205

190206
#endif

FlyingSocks/Sources/Socket+Glibc.swift

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,10 @@ import Glibc
3434

3535
public extension Socket {
3636
typealias FileDescriptorType = Int32
37+
typealias IovLengthType = Int
38+
typealias ControlMessageHeaderLengthType = Int
39+
typealias IPv4InterfaceIndexType = Int32
40+
typealias IPv6InterfaceIndexType = UInt32
3741
}
3842

3943
extension Socket.FileDescriptor {
@@ -44,6 +48,10 @@ extension Socket {
4448
static let stream = Int32(SOCK_STREAM.rawValue)
4549
static let datagram = Int32(SOCK_DGRAM.rawValue)
4650
static let in_addr_any = Glibc.in_addr(s_addr: Glibc.in_addr_t(0))
51+
static let ipproto_ip = Int32(IPPROTO_IP)
52+
static let ipproto_ipv6 = Int32(IPPROTO_IPV6)
53+
static let ip_pktinfo = Int32(IP_PKTINFO)
54+
static let ipv6_pktinfo = Int32(IPV6_PKTINFO)
4755

4856
static func makeAddressINET(port: UInt16) -> Glibc.sockaddr_in {
4957
Glibc.sockaddr_in(
@@ -181,6 +189,19 @@ extension Socket {
181189
static func sendto(_ fd: FileDescriptorType, _ buffer: UnsafeRawPointer!, _ nbyte: Int, _ flags: Int32, _ destaddr: UnsafePointer<sockaddr>!, _ destlen: socklen_t) -> Int {
182190
Glibc.sendto(fd, buffer, nbyte, flags, destaddr, destlen)
183191
}
192+
193+
static func recvmsg(_ fd: FileDescriptorType, _ message: UnsafeMutablePointer<msghdr>, _ flags: Int32) -> Int {
194+
Glibc.recvmsg(fd, message, flags)
195+
}
196+
197+
static func sendmsg(_ fd: FileDescriptorType, _ message: UnsafePointer<msghdr>, _ flags: Int32) -> Int {
198+
Glibc.sendmsg(fd, message, flags)
199+
}
200+
}
201+
202+
struct in6_pktinfo {
203+
var ipi6_addr: in6_addr
204+
var ipi6_ifindex: CUnsignedInt
184205
}
185206

186207
#endif

FlyingSocks/Sources/Socket+Musl.swift

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,10 @@ import Musl
3434

3535
public extension Socket {
3636
typealias FileDescriptorType = Int32
37+
typealias IovLengthType = Int
38+
typealias ControlMessageHeaderLengthType = UInt32
39+
typealias IPv4InterfaceIndexType = Int32
40+
typealias IPv6InterfaceIndexType = UInt32
3741
}
3842

3943
extension Socket.FileDescriptor {
@@ -44,6 +48,10 @@ extension Socket {
4448
static let stream = Int32(SOCK_STREAM)
4549
static let datagram = Int32(SOCK_DGRAM)
4650
static let in_addr_any = Musl.in_addr(s_addr: Musl.in_addr_t(0))
51+
static let ipproto_ip = Int32(IPPROTO_IP)
52+
static let ipproto_ipv6 = Int32(IPPROTO_IPV6)
53+
static let ip_pktinfo = Int32(IP_PKTINFO)
54+
static let ipv6_pktinfo = Int32(IPV6_PKTINFO)
4755

4856
static func makeAddressINET(port: UInt16) -> Musl.sockaddr_in {
4957
Musl.sockaddr_in(
@@ -181,6 +189,14 @@ extension Socket {
181189
static func sendto(_ fd: FileDescriptorType, _ buffer: UnsafeRawPointer!, _ nbyte: Int, _ flags: Int32, _ destaddr: UnsafePointer<sockaddr>!, _ destlen: socklen_t) -> Int {
182190
Musl.sendto(fd, buffer, nbyte, flags, destaddr, destlen)
183191
}
192+
193+
static func recvmsg(_ fd: FileDescriptorType, _ message: UnsafeMutablePointer<msghdr>, _ flags: Int32) -> Int {
194+
Musl.recvmsg(fd, message, flags)
195+
}
196+
197+
static func sendmsg(_ fd: FileDescriptorType, _ message: UnsafePointer<msghdr>, _ flags: Int32) -> Int {
198+
Musl.sendmsg(fd, message, flags)
199+
}
184200
}
185201

186202
#endif

FlyingSocks/Sources/Socket+WinSock2.swift

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,10 @@ public typealias sa_family_t = UInt8
4444

4545
public extension Socket {
4646
typealias FileDescriptorType = UInt64
47+
typealias IovLengthType = UInt
48+
typealias ControlMessageHeaderLengthType = DWORD
49+
typealias IPv4InterfaceIndexType = ULONG
50+
typealias IPv6InterfaceIndexType = ULONG
4751
}
4852

4953
extension Socket.FileDescriptor {
@@ -54,6 +58,10 @@ extension Socket {
5458
static let stream = Int32(SOCK_STREAM)
5559
static let datagram = Int32(SOCK_DGRAM)
5660
static let in_addr_any = WinSDK.in_addr()
61+
static let ipproto_ip = Int32(IPPROTO_IP)
62+
static let ipproto_ipv6 = Int32(IPPROTO_IPV6)
63+
static let ip_pktinfo = Int32(IP_PKTINFO)
64+
static let ipv6_pktinfo = Int32(IPV6_PKTINFO)
5765

5866
static func makeAddressINET(port: UInt16) -> WinSDK.sockaddr_in {
5967
WinSDK.sockaddr_in(
@@ -193,6 +201,14 @@ extension Socket {
193201
static func sendto(_ fd: FileDescriptorType, _ buffer: UnsafeRawPointer!, _ nbyte: Int, _ flags: Int32, _ destaddr: UnsafePointer<sockaddr>!, _ destlen: socklen_t) -> Int {
194202
WinSDK.sendto(fd, buffer, nbyte, flags, destaddr, destlen)
195203
}
204+
205+
static func recvmsg(_ fd: FileDescriptorType, _ message: UnsafeMutablePointer<msghdr>, _ flags: Int32) -> Int {
206+
WinSDK.recvmsg(fd, message, flags)
207+
}
208+
209+
static func sendmsg(_ fd: FileDescriptorType, _ message: UnsafePointer<msghdr>, _ flags: Int32) -> Int {
210+
WinSDK.sendmsg(fd, message, flags)
211+
}
196212
}
197213

198214
#endif

0 commit comments

Comments
 (0)