Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Package.swift
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ let package = Package(
.product(name: "NIOFoundationCompat", package: "swift-nio"),
.product(name: "ServiceLifecycle", package: "swift-service-lifecycle"),
],
swiftSettings: swiftSettings
swiftSettings: swiftSettings + [.enableExperimentalFeature("Lifetimes")]
),
.target(
name: "_ConnectionPoolModule",
Expand Down
182 changes: 182 additions & 0 deletions Sources/PostgresNIO/Connection/PostgresConnection+CopyFrom.swift
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,131 @@ public struct PostgresCopyFromWriter: Sendable {
}
}

// PostgresBinaryCopyFromWriter relies on non-Escapable types, which were only introduced in Swift 6.2
#if compiler(>=6.2)
/// Handle to send binary data for a `COPY ... FROM STDIN` query to the backend.
///
/// It takes care of serializing `PostgresEncodable` column types into the binary format that Postgres expects.
///
/// Rows are accumulated in a local buffer, which is handed to the underlying `PostgresCopyFromWriter` once it has
/// grown beyond `bufferSize`.
public struct PostgresBinaryCopyFromWriter: ~Copyable {
    /// Handle to serialize columns into a row that is being written by `PostgresBinaryCopyFromWriter`.
    public struct ColumnWriter: ~Escapable, ~Copyable {
        /// Pointer to the `PostgresBinaryCopyFromWriter` that is gathering the serialized data.
        @usableFromInline
        let underlying: UnsafeMutablePointer<PostgresBinaryCopyFromWriter>

        /// The number of columns that have been written by this `ColumnWriter`.
        @usableFromInline
        var columns: UInt16 = 0

        /// - Warning: Do not call directly, call `withColumnWriter` instead
        @usableFromInline
        init(_underlying: UnsafeMutablePointer<PostgresBinaryCopyFromWriter>) {
            self.underlying = _underlying
        }

        /// Create a `ColumnWriter` that serializes into `underlying` and pass it to `body`.
        ///
        /// - Parameters:
        ///   - underlying: The writer whose buffer receives the serialized columns.
        ///   - body: Closure that receives the `ColumnWriter`. Errors thrown by it are rethrown.
        /// - Returns: The value returned by `body`.
        @usableFromInline
        static func withColumnWriter<T>(
            writingTo underlying: inout PostgresBinaryCopyFromWriter,
            body: (inout ColumnWriter) throws -> T
        ) rethrows -> T {
            return try withUnsafeMutablePointer(to: &underlying) { pointerToUnderlying in
                // We can guarantee that `ColumnWriter` never outlives `underlying` because `ColumnWriter` is
                // `~Escapable` and thus cannot escape the context of the closure to `withUnsafeMutablePointer`.
                // To model this without resorting to unsafe pointers, we would need to be able to declare an `inout`
                // reference to `PostgresBinaryCopyFromWriter` as a member of `ColumnWriter`, which isn't possible at
                // the moment (https://github.com/swiftlang/swift/issues/85832).
                var columnWriter = ColumnWriter(_underlying: pointerToUnderlying)
                return try body(&columnWriter)
            }
        }

        /// Serialize a single column to a row.
        ///
        /// Passing `nil` writes a NULL value for the column.
        ///
        /// - Important: It is critical that the data type encoded here exactly matches the data type in the
        ///   database. For example, if the database stores a 4-byte integer the corresponding `writeColumn` must
        ///   be called with an `Int32`. Serializing an integer of a different width will cause a deserialization
        ///   failure in the backend.
        @inlinable
        #if compiler(<6.3)
        @_lifetime(&self)
        #endif
        public mutating func writeColumn(_ column: (some PostgresEncodable)?) throws {
            columns += 1
            try invokeWriteColumn(on: underlying, column)
        }

        // Needed to work around https://github.com/swiftlang/swift/issues/83309, copying the implementation into
        // `writeColumn` causes an assertion failure when thread sanitizer is enabled.
        @inlinable
        func invokeWriteColumn(
            on writer: UnsafeMutablePointer<PostgresBinaryCopyFromWriter>,
            _ column: (some PostgresEncodable)?
        ) throws {
            try writer.pointee.writeColumn(column)
        }
    }

    /// The underlying `PostgresCopyFromWriter` that sends the serialized data to the backend.
    @usableFromInline let underlying: PostgresCopyFromWriter

    /// The buffer in which we accumulate binary data. Once this buffer exceeds `bufferSize`, we flush it to
    /// the backend.
    @usableFromInline var buffer = ByteBuffer()

    /// Once `buffer` exceeds this size, it gets flushed to the backend.
    @usableFromInline let bufferSize: Int

    /// Create a writer that buffers up to roughly `bufferSize` bytes before forwarding them to `underlying`.
    init(underlying: PostgresCopyFromWriter, bufferSize: Int) {
        self.underlying = underlying
        // Allocate 10% more than the buffer size because we only flush the buffer once it has exceeded `bufferSize`
        buffer.reserveCapacity(bufferSize + bufferSize / 10)
        self.bufferSize = bufferSize
    }

    /// Serialize a single row to the backend. Call `writeColumn` on `columnWriter` for every column that should be
    /// included in the row.
    ///
    /// If `body` throws, everything that was written for this row is discarded again so that the buffer only ever
    /// contains complete rows. The error is then rethrown.
    @inlinable
    public mutating func writeRow(_ body: (_ columnWriter: inout ColumnWriter) throws -> Void) async throws {
        // Write a placeholder for the number of columns
        let columnIndex = buffer.writerIndex
        buffer.writeInteger(UInt16(0))

        let columns: UInt16
        do {
            columns = try ColumnWriter.withColumnWriter(writingTo: &self) { columnWriter in
                try body(&columnWriter)
                return columnWriter.columns
            }
        } catch {
            // Roll back the partially written row. Without this, a caller that catches the error and continues
            // writing rows would send a truncated row with a column count of 0 to the backend.
            buffer.moveWriterIndex(to: columnIndex)
            throw error
        }

        // Fill in the number of columns
        buffer.setInteger(columns, at: columnIndex)

        if buffer.readableBytes > bufferSize {
            try await flush()
        }
    }

    /// Serialize a single column to the buffer. Should only be called by `ColumnWriter`.
    ///
    /// A `nil` column is written as a length of `-1` without any payload. Any other value is written as its
    /// encoded bytes, prefixed by their length as an `Int32`.
    @inlinable
    mutating func writeColumn(_ column: (some PostgresEncodable)?) throws {
        guard let column else {
            buffer.writeInteger(Int32(-1))
            return
        }
        try buffer.writeLengthPrefixed(as: Int32.self) { buffer in
            let startIndex = buffer.writerIndex
            try column.encode(into: &buffer, context: .default)
            return buffer.writerIndex - startIndex
        }
    }

    /// Flush any pending data in the buffer to the backend.
    ///
    /// Does nothing if the buffer does not contain any data.
    @usableFromInline
    mutating func flush() async throws {
        // Avoid handing an empty buffer to the underlying writer.
        guard buffer.readableBytes > 0 else {
            return
        }
        try await underlying.write(buffer)
        buffer.clear()
    }
}
#endif

/// Specifies the format in which data is transferred to the backend in a COPY operation.
///
/// See the Postgres documentation at https://www.postgresql.org/docs/current/sql-copy.html for the option's meanings
Expand All @@ -113,15 +238,25 @@ public struct PostgresCopyFromFormat: Sendable {
public init() {}
}

/// Options that can be used to modify the `binary` format of a COPY operation.
///
/// No options are defined at the moment; the type only exposes an empty initializer.
public struct BinaryOptions: Sendable {
/// Creates an empty set of binary-format options.
public init() {}
}

/// The transfer format selected for a COPY operation, together with its format-specific options.
enum Format {
/// Transfer the data as text, configured by the associated `TextOptions`.
case text(TextOptions)
/// Transfer the data in the binary format, configured by the associated `BinaryOptions`.
case binary(BinaryOptions)
}

var format: Format

/// Creates a format that transfers the COPY data as text, configured by `options`.
public static func text(_ options: TextOptions) -> PostgresCopyFromFormat {
    let selectedFormat = Format.text(options)
    return .init(format: selectedFormat)
}

/// Creates a format that transfers the COPY data in the binary format, configured by `options`.
public static func binary(_ options: BinaryOptions) -> PostgresCopyFromFormat {
    let selectedFormat = Format.binary(options)
    return .init(format: selectedFormat)
}
}

/// Create a `COPY ... FROM STDIN` query based on the given parameters.
Expand Down Expand Up @@ -153,6 +288,8 @@ private func buildCopyFromQuery(
// Set the delimiter as a Unicode code point. This avoids the possibility of SQL injection.
queryOptions.append("DELIMITER U&'\\\(String(format: "%04x", delimiter.value))'")
}
case .binary:
queryOptions.append("FORMAT binary")
}
precondition(!queryOptions.isEmpty)
query += " WITH ("
Expand All @@ -162,6 +299,51 @@ private func buildCopyFromQuery(
}

extension PostgresConnection {
#if compiler(>=6.2)
/// Copy data into a table using a `COPY <table name> FROM STDIN` query, transferring data in a binary format.
///
/// The binary COPY header (signature, flags field and an empty header extension area) is sent before `writeData`
/// is invoked. Any data that is still buffered when `writeData` returns is flushed afterwards.
///
/// - Parameters:
///   - table: The name of the table into which to copy the data.
///   - columns: The name of the columns to copy. If an empty array is passed, all columns are assumed to be copied.
///   - options: Options that modify the binary format of the COPY operation.
///   - bufferSize: How many bytes to accumulate in a local buffer before flushing it to the database. Can affect
///     performance characteristics of the copy operation.
///   - logger: The logger that is passed on to the underlying `copyFrom` call.
///   - file: The file from which this function was called.
///   - line: The line from which this function was called.
///   - writeData: Closure that produces the data for the table, to be streamed to the backend. Call `writeRow` on
///     the writer provided by the closure for every row to send to the backend and return from the closure once all
///     data is sent. Throw an error from the closure to fail the data transfer. The error thrown by the closure will
///     be rethrown by the `copyFromBinary` function.
///
/// - Important: The table and column names are inserted into the `COPY FROM` query as passed and might thus be
///   susceptible to SQL injection. Ensure no untrusted data is contained in these strings.
public func copyFromBinary(
    table: String,
    columns: [String] = [],
    options: PostgresCopyFromFormat.BinaryOptions = .init(),
    bufferSize: Int = 100_000,
    logger: Logger,
    file: String = #fileID,
    line: Int = #line,
    writeData: (inout PostgresBinaryCopyFromWriter) async throws -> Void
) async throws {
    // NOTE(review): `file` and `line` are accepted but not forwarded - confirm whether `copyFrom` takes them.
    // Forward the caller-provided `options` instead of silently replacing them with a default-constructed value.
    try await copyFrom(table: table, columns: columns, format: .binary(options), logger: logger) { writer in
        // Signature of the binary COPY format
        var header = ByteBuffer()
        header.writeString("PGCOPY\n")
        header.writeInteger(UInt8(0xff))
        header.writeString("\r\n\0")

        // Flag fields
        header.writeInteger(UInt32(0))

        // Header extension area length
        header.writeInteger(UInt32(0))
        try await writer.write(header)

        var binaryWriter = PostgresBinaryCopyFromWriter(underlying: writer, bufferSize: bufferSize)
        try await writeData(&binaryWriter)
        // Send whatever is still buffered, rows are only flushed automatically once `bufferSize` is exceeded.
        // NOTE(review): no file trailer (`Int16(-1)`) is written - confirm the end of the copy operation is enough.
        try await binaryWriter.flush()
    }
}
#endif

/// Copy data into a table using a `COPY <table name> FROM STDIN` query.
///
/// - Parameters:
Expand Down
36 changes: 36 additions & 0 deletions Tests/IntegrationTests/PSQLIntegrationTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -487,4 +487,40 @@ final class IntegrationTests: XCTestCase {
XCTAssertEqual((error as? PSQLError)?.serverInfo?[.sqlState], "42601") // scanner_yyerror
}
}

#if compiler(>=6.2) // copyFromBinary is only available in Swift 6.2+
/// Copies two rows into a freshly created table using `copyFromBinary` and verifies that they can be read back.
func testCopyFromBinary() async throws {
    let eventLoopGroup = MultiThreadedEventLoopGroup(numberOfThreads: 2)
    defer { XCTAssertNoThrow(try eventLoopGroup.syncShutdownGracefully()) }
    let eventLoop = eventLoopGroup.next()

    let conn = try await PostgresConnection.test(on: eventLoop).get()
    defer { XCTAssertNoThrow(try conn.close().wait()) }

    // The table might not exist yet, so a failure to drop it is ignored.
    _ = try? await conn.query("DROP TABLE copy_table", logger: .psqlTest).get()
    _ = try await conn.query("CREATE TABLE copy_table (id INT, name VARCHAR(100))", logger: .psqlTest).get()

    try await conn.copyFromBinary(table: "copy_table", columns: ["id", "name"], logger: .psqlTest) { writer in
        let records: [(id: Int, name: String)] = [
            (1, "Alice"),
            (42, "Bob")
        ]
        for record in records {
            try await writer.writeRow { columnWriter in
                // The `id` column is an `INT`, so it needs to be serialized as an `Int32`.
                try columnWriter.writeColumn(Int32(record.id))
                try columnWriter.writeColumn(record.name)
            }
        }
    }
    // Order the result explicitly because the assertions below depend on the position of each row.
    let rows = try await conn.query("SELECT id, name FROM copy_table ORDER BY id").get().rows.map { try $0.decode((Int, String).self) }
    guard rows.count == 2 else {
        XCTFail("Expected 2 rows, received \(rows.count)")
        return
    }
    XCTAssertEqual(rows[0].0, 1)
    XCTAssertEqual(rows[0].1, "Alice")
    XCTAssertEqual(rows[1].0, 42)
    XCTAssertEqual(rows[1].1, "Bob")
}
#endif
}
64 changes: 64 additions & 0 deletions Tests/PostgresNIOTests/New/PostgresConnectionTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -936,6 +936,70 @@ import Synchronization
}
}

#if compiler(>=6.2) // copyFromBinary is only available in Swift 6.2+
/// Verifies the query and the binary payload that `copyFromBinary` sends for two rows with an `Int32` and a
/// `String` column each.
@Test func testCopyFromBinary() async throws {
    try await self.withAsyncTestingChannel { connection, channel in
        try await withThrowingTaskGroup(of: Void.self) { taskGroup async throws -> Void in
            taskGroup.addTask {
                try await connection.copyFromBinary(table: "copy_table", logger: .psqlTest) {
                    writer in
                    try await writer.writeRow { columnWriter in
                        try columnWriter.writeColumn(Int32(1))
                        try columnWriter.writeColumn("Alice")
                    }
                    try await writer.writeRow { columnWriter in
                        try columnWriter.writeColumn(Int32(2))
                        try columnWriter.writeColumn("Bob")
                    }
                }
            }

            let copyRequest = try await channel.waitForUnpreparedRequest()
            #expect(copyRequest.parse.query == #"COPY "copy_table" FROM STDIN WITH (FORMAT binary)"#)

            try await channel.sendUnpreparedRequestWithNoParametersBindResponse()
            try await channel.writeInbound(
                PostgresBackendMessage.copyInResponse(
                    .init(format: .binary, columnFormats: [.binary, .binary])))

            let copyData = try await channel.waitForCopyData()
            #expect(copyData.result == .done)
            var data = copyData.data
            // Signature
            #expect(data.readString(length: 7) == "PGCOPY\n")
            #expect(data.readInteger(as: UInt8.self) == 0xff)
            #expect(data.readString(length: 3) == "\r\n\0")
            // Flags
            #expect(data.readInteger(as: UInt32.self) == 0)
            // Header extension area length
            #expect(data.readInteger(as: UInt32.self) == 0)

            struct Row: Equatable {
                let id: Int32
                let name: String
            }
            var rows: [Row] = []
            // Everything after the header is expected to consist of complete rows.
            while data.readableBytes > 0 {
                // Number of columns
                #expect(data.readInteger(as: UInt16.self) == 2)
                // 'id' column length, followed by its 4-byte value
                #expect(data.readInteger(as: UInt32.self) == 4)
                let id = data.readInteger(as: Int32.self)
                // 'name' column length
                let nameLength = data.readInteger(as: UInt32.self)
                let name = data.readString(length: Int(try #require(nameLength)))
                rows.append(Row(id: try #require(id), name: try #require(name)))
            }
            #expect(rows == [Row(id: 1, name: "Alice"), Row(id: 2, name: "Bob")])
            // Two rows were transferred, so the mocked backend reports two copied rows.
            try await channel.writeInbound(PostgresBackendMessage.commandComplete("COPY 2"))

            try await channel.waitForPostgresFrontendMessage(\.sync)
            try await channel.writeInbound(PostgresBackendMessage.readyForQuery(.idle))
        }
    }
}
#endif

func withAsyncTestingChannel(_ body: (PostgresConnection, NIOAsyncTestingChannel) async throws -> ()) async throws {
let eventLoop = NIOAsyncTestingEventLoop()
let channel = try await NIOAsyncTestingChannel(loop: eventLoop) { channel in
Expand Down
Loading