diff --git a/Package.swift b/Package.swift index 0956684..c4b24bc 100644 --- a/Package.swift +++ b/Package.swift @@ -39,7 +39,7 @@ let package = Package( ), ], dependencies: [ - .package(url: "https://github.com/PADL/IORingSwift", branch: "main"), + .package(url: "https://github.com/PADL/IORingSwift", from: "1.0.0"), .package(url: "https://github.com/apple/swift-async-algorithms", from: "1.0.0"), .package(url: "https://github.com/apple/swift-system", from: "1.2.1"), .package(url: "https://github.com/apple/swift-argument-parser", from: "1.2.0"), diff --git a/Sources/NetLink/NFNetLink.swift b/Sources/NetLink/NFNetLink.swift index 55b9323..c210406 100644 --- a/Sources/NetLink/NFNetLink.swift +++ b/Sources/NetLink/NFNetLink.swift @@ -140,7 +140,8 @@ public final class NFNLLog: Sendable { public init(family: sa_family_t = sa_family_t(AF_BRIDGE), group: UInt16) throws { _socket = try NLSocket(protocol: NETLINK_NETFILTER) - _log = NLObject(consumingObj: nfnl_log_alloc()) + guard let logObj = nfnl_log_alloc() else { throw NLError.noMemory } + _log = NLObject(consumingObj: logObj) try throwingNLError { nfnl_log_pf_bind(_socket._sk, UInt8(family)) } diff --git a/Sources/NetLink/NetLink.swift b/Sources/NetLink/NetLink.swift index 0b902ab..6ddbaca 100644 --- a/Sources/NetLink/NetLink.swift +++ b/Sources/NetLink/NetLink.swift @@ -73,18 +73,24 @@ Sendable, Equatable, Hashable, CustomStringConvertible { try withUnsafeMutablePointer(to: &obj) { objRef in _ = try throwingNLError { nl_msg_parse(msg, { obj, objRef in - nl_object_get(obj) - objRef! + let ptr = objRef! .withMemoryRebound( - to: OpaquePointer.self, + to: OpaquePointer?.self, capacity: 1 - ) { objRef in - objRef.pointee = obj! - } + ) { $0 } + if let existing = ptr.pointee { + nl_object_put(existing) + } + nl_object_get(obj) + ptr.pointee = obj }, objRef) } } + guard obj != nil else { + throw NLError.invalidArgument + } + self.init(obj: obj, constructFromObject: constructFromObject) nl_object_put(obj) } @@ -222,7 +228,7 @@ private func NLSocket_ErrCB( let hdr = err.pointee.msg debugPrint("NLSocket_ErrCB: error \(err.pointee)") nlSocket.yield(sequence: hdr.nlmsg_seq, with: Result.failure(Errno(rawValue: -err.pointee.error))) - return err.pointee.error + return CInt(NL_SKIP.rawValue) } public final class NLSocket: @unchecked @@ -230,7 +236,7 @@ Sendable { private typealias Continuation = CheckedContinuation private typealias Stream = AsyncThrowingStream private typealias Ack = CheckedContinuation<(), Error> - public typealias Channel = AsyncThrowingChannel + public typealias NotificationStream = AsyncThrowingStream private enum _Request { case continuation(Continuation) @@ -248,12 +254,18 @@ Sendable { } let _sk: OpaquePointer! + let _queue = DispatchQueue(label: "NLSocket") private let _readSource: any DispatchSourceRead private let _requests = Mutex<[UInt32: _Request]>([:]) - public let notifications = Channel() + private let _notificationsContinuation: NotificationStream.Continuation + public let notifications: NotificationStream public init(protocol: Int32) throws { + var continuation: NotificationStream.Continuation! + notifications = NotificationStream(bufferingPolicy: .unbounded) { continuation = $0 } + _notificationsContinuation = continuation + guard let sk = nl_socket_alloc() else { throw NLError.noMemory } nl_socket_disable_seq_check(sk) _sk = sk @@ -266,7 +278,7 @@ Sendable { let fd = nl_socket_get_fd(sk) precondition(fd >= 0) - _readSource = DispatchSource.makeReadSource(fileDescriptor: fd, queue: .main) + _readSource = DispatchSource.makeReadSource(fileDescriptor: fd, queue: _queue) _readSource.setEventHandler(handler: onReadReady) nl_socket_modify_cb( @@ -302,6 +314,7 @@ Sendable { deinit { _readSource.cancel() + _notificationsContinuation.finish() nl_socket_free(_sk) } @@ -344,13 +357,15 @@ Sendable { } public func useNextSequenceNumber() -> UInt32 { - var nextSequenceNumber: UInt32 + _queue.sync { + var nextSequenceNumber: UInt32 - repeat { - nextSequenceNumber = nl_socket_use_seq(_sk) - } while nextSequenceNumber == 0 + repeat { + nextSequenceNumber = nl_socket_use_seq(_sk) + } while nextSequenceNumber == 0 - return nextSequenceNumber + return nextSequenceNumber + } } private func _lookup(sequence: UInt32, forceRemove: Bool) -> _Request? { @@ -405,13 +420,7 @@ Sendable { } } } else { - Task { - do { - try await notifications.send(result.get()) - } catch { - notifications.fail(error) - } - } + _notificationsContinuation.yield(with: result) } } @@ -466,7 +475,12 @@ Sendable { } stream = _stream } - try message.send(on: self) + do { + try message.send(on: self) + } catch { + _requests.withLock { $0.removeValue(forKey: sequence) } + throw error + } return stream } } @@ -764,8 +778,10 @@ struct NLMessage: ~Copyable { } func append(opaque value: UnsafePointer) throws { - _ = try withUnsafeBytes(of: value.pointee) { value in - try append(Array(value)) + try withUnsafeBytes(of: value.pointee) { bytes in + try throwingNLError { + nlmsg_append(_msg, UnsafeMutableRawPointer(mutating: bytes.baseAddress), bytes.count, CInt(NLMSG_ALIGNTO)) + } } } @@ -858,7 +874,9 @@ struct NLMessage: ~Copyable { } func send(on socket: NLSocket) throws { - try throwingNLError { nl_send_auto(socket._sk, _msg) } + try socket._queue.sync { + try throwingNLError { nl_send_auto(socket._sk, _msg) } + } } deinit {