From 13209ec4e6a444be211029077edaf0137fae4d75 Mon Sep 17 00:00:00 2001 From: Alex Hoppen Date: Fri, 5 Dec 2025 00:50:34 +0100 Subject: [PATCH] Throw error when `PostgresCopyFromWriter` is escaped MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit When `PostgresCopyFromWriter` is escaped from the `copyFrom` closure (which users really shouldn’t do) and the user tries to write data to it after `copyFrom` has finished, throw an error instead of hitting a `preconditionFailure`. --- .../PostgresConnection+CopyFrom.swift | 10 +++++++ .../ConnectionStateMachine.swift | 2 +- .../ExtendedQueryStateMachine.swift | 2 +- .../New/PostgresConnectionTests.swift | 30 +++++++++++++++++++ 4 files changed, 42 insertions(+), 2 deletions(-) diff --git a/Sources/PostgresNIO/Connection/PostgresConnection+CopyFrom.swift b/Sources/PostgresNIO/Connection/PostgresConnection+CopyFrom.swift index bdfcbd2c..209d53c0 100644 --- a/Sources/PostgresNIO/Connection/PostgresConnection+CopyFrom.swift +++ b/Sources/PostgresNIO/Connection/PostgresConnection+CopyFrom.swift @@ -1,3 +1,13 @@ +/// Error thrown when `PostgresCopyFromWriter.write` is called while the backend is not expecting `CopyData` messages. +/// +/// This should only happen if the user explicitly escapes the `PostgresCopyFromWriter` from the `copyFrom` closure and +/// tries to write data to it when tries to write data after the `copyFrom` call has terminated. +struct NotInCopyFromModeError: Error, CustomStringConvertible { + var description: String { + "Writing COPY FROM data when backend is not in COPY mode" + } +} + /// Handle to send data for a `COPY ... FROM STDIN` query to the backend. public struct PostgresCopyFromWriter: Sendable { private let channelHandler: NIOLoopBound diff --git a/Sources/PostgresNIO/New/Connection State Machine/ConnectionStateMachine.swift b/Sources/PostgresNIO/New/Connection State Machine/ConnectionStateMachine.swift index c4040cc3..3af931b3 100644 --- a/Sources/PostgresNIO/New/Connection State Machine/ConnectionStateMachine.swift +++ b/Sources/PostgresNIO/New/Connection State Machine/ConnectionStateMachine.swift @@ -827,7 +827,7 @@ struct ConnectionStateMachine { /// should be aborted to avoid unnecessary work. mutating func checkBackendCanReceiveCopyData(channelIsWritable: Bool, promise: EventLoopPromise) -> CheckBackendCanReceiveCopyDataAction { guard case .extendedQuery(var queryState, let connectionContext) = self.state else { - preconditionFailure("Copy mode is only supported for extended queries") + return .failPromise(promise, error: NotInCopyFromModeError()) } self.state = .modifying // avoid CoW diff --git a/Sources/PostgresNIO/New/Connection State Machine/ExtendedQueryStateMachine.swift b/Sources/PostgresNIO/New/Connection State Machine/ExtendedQueryStateMachine.swift index be76b68c..f9ee2d6e 100644 --- a/Sources/PostgresNIO/New/Connection State Machine/ExtendedQueryStateMachine.swift +++ b/Sources/PostgresNIO/New/Connection State Machine/ExtendedQueryStateMachine.swift @@ -430,7 +430,7 @@ struct ExtendedQueryStateMachine { return .failPromise(promise, error: error) } guard case .copyingData(.readyToSend) = self.state else { - preconditionFailure("Not ready to send data") + return .failPromise(promise, error: NotInCopyFromModeError()) } if channelIsWritable { return .succeedPromise(promise) diff --git a/Tests/PostgresNIOTests/New/PostgresConnectionTests.swift b/Tests/PostgresNIOTests/New/PostgresConnectionTests.swift index 28c593cb..70f2f331 100644 --- a/Tests/PostgresNIOTests/New/PostgresConnectionTests.swift +++ b/Tests/PostgresNIOTests/New/PostgresConnectionTests.swift @@ -936,6 +936,36 @@ import Synchronization } } + @Test func testWritingDataAfterCopyFromHasFinishedThrowsError() async throws { + try await self.withAsyncTestingChannel { connection, channel in + try await withThrowingTaskGroup(of: Void.self) { taskGroup async throws -> () in + taskGroup.addTask { + var escapedWriter: PostgresCopyFromWriter? + try await connection.copyFrom(table: "test", logger: .psqlTest) { writer in + escapedWriter = writer + } + let writer = try #require(escapedWriter) + await #expect(throws: (any Error).self) { + try await writer.write(ByteBuffer(string: "oops")) + } + } + + _ = try await channel.waitForUnpreparedRequest() + + try await channel.sendUnpreparedRequestWithNoParametersBindResponse() + try await channel.writeInbound(PostgresBackendMessage.copyInResponse(.init(format: .textual, columnFormats: Array(repeating: .textual, count: 2)))) + + _ = try await channel.waitForCopyData() + try await channel.writeInbound(PostgresBackendMessage.commandComplete("COPY 0")) + + try await channel.waitForPostgresFrontendMessage(\.sync) + try await channel.writeInbound(PostgresBackendMessage.readyForQuery(.idle)) + + try await taskGroup.waitForAll() + } + } + } + func withAsyncTestingChannel(_ body: (PostgresConnection, NIOAsyncTestingChannel) async throws -> ()) async throws { let eventLoop = NIOAsyncTestingEventLoop() let channel = try await NIOAsyncTestingChannel(loop: eventLoop) { channel in