Skip to content
2 changes: 1 addition & 1 deletion Package.swift
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ let package = Package(
),
],
dependencies: [
.package(url: "https://github.com/PADL/IORingSwift", branch: "main"),
.package(url: "https://github.com/PADL/IORingSwift", from: "1.0.0"),
.package(url: "https://github.com/apple/swift-async-algorithms", from: "1.0.0"),
.package(url: "https://github.com/apple/swift-system", from: "1.2.1"),
.package(url: "https://github.com/apple/swift-argument-parser", from: "1.2.0"),
Expand Down
3 changes: 2 additions & 1 deletion Sources/NetLink/NFNetLink.swift
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,8 @@ public final class NFNLLog: Sendable {

public init(family: sa_family_t = sa_family_t(AF_BRIDGE), group: UInt16) throws {
_socket = try NLSocket(protocol: NETLINK_NETFILTER)
_log = NLObject(consumingObj: nfnl_log_alloc())
guard let logObj = nfnl_log_alloc() else { throw NLError.noMemory }
_log = NLObject(consumingObj: logObj)
try throwingNLError {
nfnl_log_pf_bind(_socket._sk, UInt8(family))
}
Expand Down
70 changes: 44 additions & 26 deletions Sources/NetLink/NetLink.swift
Original file line number Diff line number Diff line change
Expand Up @@ -73,18 +73,24 @@ Sendable, Equatable, Hashable, CustomStringConvertible {
try withUnsafeMutablePointer(to: &obj) { objRef in
_ = try throwingNLError {
nl_msg_parse(msg, { obj, objRef in
nl_object_get(obj)
objRef!
let ptr = objRef!
.withMemoryRebound(
to: OpaquePointer.self,
to: OpaquePointer?.self,
capacity: 1
) { objRef in
objRef.pointee = obj!
}
) { $0 }
if let existing = ptr.pointee {
nl_object_put(existing)
}
nl_object_get(obj)
ptr.pointee = obj
}, objRef)
}
}

guard obj != nil else {
throw NLError.invalidArgument
}

self.init(obj: obj, constructFromObject: constructFromObject)
nl_object_put(obj)
}
Expand Down Expand Up @@ -222,15 +228,15 @@ private func NLSocket_ErrCB(
let hdr = err.pointee.msg
debugPrint("NLSocket_ErrCB: error \(err.pointee)")
nlSocket.yield(sequence: hdr.nlmsg_seq, with: Result.failure(Errno(rawValue: -err.pointee.error)))
return err.pointee.error
return CInt(NL_SKIP.rawValue)
}

public final class NLSocket: @unchecked
Sendable {
private typealias Continuation = CheckedContinuation<NLObjectConstructible, Error>
private typealias Stream = AsyncThrowingStream<NLObjectConstructible, Error>
private typealias Ack = CheckedContinuation<(), Error>
public typealias Channel = AsyncThrowingChannel<NLObjectConstructible, Error>
public typealias NotificationStream = AsyncThrowingStream<NLObjectConstructible, Error>

private enum _Request {
case continuation(Continuation)
Expand All @@ -248,12 +254,18 @@ Sendable {
}

let _sk: OpaquePointer!
let _queue = DispatchQueue(label: "NLSocket")
private let _readSource: any DispatchSourceRead
private let _requests = Mutex<[UInt32: _Request]>([:])

public let notifications = Channel()
private let _notificationsContinuation: NotificationStream.Continuation
public let notifications: NotificationStream

public init(protocol: Int32) throws {
var continuation: NotificationStream.Continuation!
notifications = NotificationStream(bufferingPolicy: .unbounded) { continuation = $0 }
_notificationsContinuation = continuation

guard let sk = nl_socket_alloc() else { throw NLError.noMemory }
nl_socket_disable_seq_check(sk)
_sk = sk
Expand All @@ -266,7 +278,7 @@ Sendable {
let fd = nl_socket_get_fd(sk)
precondition(fd >= 0)

_readSource = DispatchSource.makeReadSource(fileDescriptor: fd, queue: .main)
_readSource = DispatchSource.makeReadSource(fileDescriptor: fd, queue: _queue)
_readSource.setEventHandler(handler: onReadReady)

nl_socket_modify_cb(
Expand Down Expand Up @@ -302,6 +314,7 @@ Sendable {

deinit {
_readSource.cancel()
_notificationsContinuation.finish()
nl_socket_free(_sk)
}

Expand Down Expand Up @@ -344,13 +357,15 @@ Sendable {
}

public func useNextSequenceNumber() -> UInt32 {
var nextSequenceNumber: UInt32
_queue.sync {
var nextSequenceNumber: UInt32

repeat {
nextSequenceNumber = nl_socket_use_seq(_sk)
} while nextSequenceNumber == 0
repeat {
nextSequenceNumber = nl_socket_use_seq(_sk)
} while nextSequenceNumber == 0

return nextSequenceNumber
return nextSequenceNumber
}
}

private func _lookup(sequence: UInt32, forceRemove: Bool) -> _Request? {
Expand Down Expand Up @@ -405,13 +420,7 @@ Sendable {
}
}
} else {
Task {
do {
try await notifications.send(result.get())
} catch {
notifications.fail(error)
}
}
_notificationsContinuation.yield(with: result)
}
}

Expand Down Expand Up @@ -466,7 +475,12 @@ Sendable {
}
stream = _stream
}
try message.send(on: self)
do {
try message.send(on: self)
} catch {
_requests.withLock { $0.removeValue(forKey: sequence) }
throw error
}
return stream
}
}
Expand Down Expand Up @@ -764,8 +778,10 @@ struct NLMessage: ~Copyable {
}

func append(opaque value: UnsafePointer<some Any>) throws {
_ = try withUnsafeBytes(of: value.pointee) { value in
try append(Array(value))
try withUnsafeBytes(of: value.pointee) { bytes in
try throwingNLError {
nlmsg_append(_msg, UnsafeMutableRawPointer(mutating: bytes.baseAddress), bytes.count, CInt(NLMSG_ALIGNTO))
}
}
}

Expand Down Expand Up @@ -858,7 +874,9 @@ struct NLMessage: ~Copyable {
}

func send(on socket: NLSocket) throws {
try throwingNLError { nl_send_auto(socket._sk, _msg) }
try socket._queue.sync {
try throwingNLError { nl_send_auto(socket._sk, _msg) }
}
}

deinit {
Expand Down
Loading