diff --git a/c/driver/postgresql/bind_stream.h b/c/driver/postgresql/bind_stream.h
index 25c55eec7e..4e73477d44 100644
--- a/c/driver/postgresql/bind_stream.h
+++ b/c/driver/postgresql/bind_stream.h
@@ -59,6 +59,10 @@ struct BindStream {
   bool autocommit = false;
   std::string tz_setting;
 
+  // Expected types from PostgreSQL (after DESCRIBE); used to resolve NA params
+  PostgresType expected_param_types;
+  bool has_expected_types = false;
+
   struct ArrowError na_error;
 
   BindStream() {
@@ -71,6 +75,24 @@ struct BindStream {
     ArrowArrayStreamMove(stream, &bind.value);
   }
 
+  // Record the parameter types PostgreSQL expects (one child per parameter) so
+  // that parameter type resolution can use them for Arrow NA (null-typed)
+  // columns. Returns InvalidState if the bind schema is not initialized or if
+  // the number of expected parameters differs from the number of bound columns.
+  Status ReconcileWithExpectedTypes(const PostgresType& expected_types) {
+    if (bind_schema->release == nullptr) {
+      return Status::InvalidState("[libpq] Bind stream schema not initialized");
+    }
+    if (expected_types.n_children() != bind_schema->n_children) {
+      return Status::InvalidState("[libpq] Expected ", expected_types.n_children(),
+                                  " parameters but bind stream has ",
+                                  bind_schema->n_children);
+    }
+    expected_param_types = expected_types;
+    has_expected_types = true;
+    return Status::Ok();
+  }
+
   template <typename Callback>
   Status Begin(Callback&& callback) {
     UNWRAP_NANOARROW(
@@ -111,9 +133,28 @@ struct BindStream {
 
     for (size_t i = 0; i < bind_field_writers.size(); i++) {
       PostgresType type;
-      UNWRAP_NANOARROW(na_error, Internal,
-                       PostgresType::FromSchema(type_resolver, bind_schema->children[i],
-                                                &type, &na_error));
+
+      // Handle NA type by using expected parameter type from PostgreSQL
+      if (has_expected_types && bind_schema_fields[i].type == NANOARROW_TYPE_NA &&
+          i < static_cast<size_t>(expected_param_types.n_children())) {
+        const auto& expected_type = expected_param_types.child(i);
+        // If PostgreSQL couldn't infer a concrete type (e.g., SELECT $1), don't
+        // force an "expected" type; fall back to Arrow-derived mapping.
+        if (expected_type.oid() != 0 &&
+            expected_type.type_id() != PostgresTypeId::kUnknown) {
+          type = expected_type;
+        } else {
+          UNWRAP_NANOARROW(
+              na_error, Internal,
+              PostgresType::FromSchema(type_resolver, bind_schema->children[i], &type,
+                                       &na_error));
+        }
+      } else {
+        // Normal case: derive type from Arrow schema
+        UNWRAP_NANOARROW(na_error, Internal,
+                         PostgresType::FromSchema(type_resolver, bind_schema->children[i],
+                                                  &type, &na_error));
+      }
 
       // tz-aware timestamps require special handling to set the timezone to UTC
       // prior to sending over the binary protocol; must be reset after execute
@@ -205,6 +246,14 @@ struct BindStream {
 
     for (int64_t col = 0; col < array_view->n_children; col++) {
       is_null_param[col] = ArrowArrayViewIsNull(array_view->children[col], current_row);
+
+      // Safety check: NA type arrays should only contain nulls
+      if (bind_schema_fields[col].type == NANOARROW_TYPE_NA && !is_null_param[col]) {
+        return Status::InvalidArgument(
+            "Parameter $", col + 1,
+            " has null type but contains a non-null value at row ", current_row);
+      }
+
       if (!is_null_param[col]) {
         // Note that this Write() call currently writes the (int32_t) byte size of the
         // field in addition to the serialized value.
diff --git a/c/driver/postgresql/copy/writer.h b/c/driver/postgresql/copy/writer.h
index 2b31310e70..512a29afee 100644
--- a/c/driver/postgresql/copy/writer.h
+++ b/c/driver/postgresql/copy/writer.h
@@ -111,6 +111,20 @@ class PostgresCopyFieldWriter {
   std::vector<std::unique_ptr<PostgresCopyFieldWriter>> children_;
 };
 
+// Field writer for Arrow null-typed (NA) columns. Every value in such a column
+// is expected to be null, so Write() is never expected to be called; if it is,
+// it reports an error instead of serializing anything.
+class PostgresCopyNullFieldWriter : public PostgresCopyFieldWriter {
+ public:
+  ArrowErrorCode Write(ArrowBuffer* buffer, int64_t index, ArrowError* error) override {
+    (void)buffer;
+    ArrowErrorSet(error,
+                  "[libpq] Unexpected non-null value for Arrow null type at row %" PRId64,
+                  index);
+    return EINVAL;
+  }
+};
+
 class PostgresCopyFieldTupleWriter : public PostgresCopyFieldWriter {
  public:
   void AppendChild(std::unique_ptr<PostgresCopyFieldWriter> child) {
@@ -131,7 +145,7 @@ class PostgresCopyFieldTupleWriter : public PostgresCopyFieldWriter {
         constexpr int32_t field_size_bytes = -1;
         NANOARROW_RETURN_NOT_OK(WriteChecked<int32_t>(buffer, field_size_bytes, error));
       } else {
-        children_[i]->Write(buffer, index, error);
+        NANOARROW_RETURN_NOT_OK(children_[i]->Write(buffer, index, error));
       }
     }
 
@@ -743,6 +757,9 @@ static inline ArrowErrorCode MakeCopyFieldWriter(
   NANOARROW_RETURN_NOT_OK(ArrowSchemaViewInit(&schema_view, schema, error));
 
   switch (schema_view.type) {
+    case NANOARROW_TYPE_NA:
+      *out = PostgresCopyFieldWriter::Create<PostgresCopyNullFieldWriter>(array_view);
+      return NANOARROW_OK;
     case NANOARROW_TYPE_BOOL:
       using T = PostgresCopyBooleanFieldWriter;
       *out = T::Create<T>(array_view);
diff --git a/c/driver/postgresql/postgres_type.h b/c/driver/postgresql/postgres_type.h
index e8935cc76b..9768bb1c9f 100644
--- a/c/driver/postgresql/postgres_type.h
+++ b/c/driver/postgresql/postgres_type.h
@@ -653,6 +653,12 @@ inline ArrowErrorCode PostgresType::FromSchema(const PostgresTypeResolver& resol
       // Dictionary arrays always resolve to the dictionary type when binding or ingesting
       return PostgresType::FromSchema(resolver, schema->dictionary, out, error);
 
+    case NANOARROW_TYPE_NA:
+      // NA type: the Arrow schema carries no value type, so default to TEXT.
+      // This provides a fallback when we don't have expected type information
+      // (e.g., COPY)
+      return resolver.Find(resolver.GetOID(PostgresTypeId::kText), out, error);
+
     default:
       ArrowErrorSet(error, "Can't map Arrow type '%s' to Postgres type",
                     ArrowTypeString(schema_view.type));
diff --git a/c/driver/postgresql/result_helper.cc b/c/driver/postgresql/result_helper.cc
index e455467bd5..2557e8f944 100644
--- a/c/driver/postgresql/result_helper.cc
+++ b/c/driver/postgresql/result_helper.cc
@@ -147,13 +147,15 @@ Status PqResultHelper::ResolveParamTypes(PostgresTypeResolver& type_resolver,
   for (int i = 0; i < num_params; i++) {
     const Oid pg_oid = PQparamtype(result_, i);
     PostgresType pg_type;
-    if (type_resolver.Find(pg_oid, &pg_type, &na_error) != NANOARROW_OK) {
-      std::string param_name = "$" + std::to_string(i + 1);
-      Status status =
-          Status::NotImplemented("[libpq] Parameter #", i + 1, " (\"", param_name,
-                                 "\") has unknown type code ", pg_oid);
-      ClearResult();
-      return status;
+    if (pg_oid == 0) {
+      // PostgreSQL didn't infer a type (can happen in ambiguous contexts).
+      pg_type = PostgresType(PostgresTypeId::kUnknown).WithPgTypeInfo(0, "unknown");
+    } else if (type_resolver.Find(pg_oid, &pg_type, &na_error) != NANOARROW_OK) {
+      // Don't fail preparation just because we can't resolve an expected parameter
+      // type. We'll fall back to Arrow-derived mapping.
+      pg_type =
+          PostgresType(PostgresTypeId::kUnknown)
+              .WithPgTypeInfo(pg_oid, "unknown");
     }
 
     std::string param_name = "$" + std::to_string(i + 1);
diff --git a/c/driver/postgresql/result_reader.cc b/c/driver/postgresql/result_reader.cc
index ad73d884a3..6552270c7a 100644
--- a/c/driver/postgresql/result_reader.cc
+++ b/c/driver/postgresql/result_reader.cc
@@ -148,8 +148,29 @@ Status PqResultArrayReader::Initialize(int64_t* rows_affected) {
 
   if (bind_stream_) {
     UNWRAP_STATUS(bind_stream_->Begin([] { return Status::Ok(); }));
+    const bool has_na_params = std::any_of(
+        bind_stream_->bind_schema_fields.begin(), bind_stream_->bind_schema_fields.end(),
+        [](const ArrowSchemaView& view) { return view.type == NANOARROW_TYPE_NA; });
+    if (has_na_params) {
+      // Prepare WITHOUT parameter types to let PostgreSQL infer them. This is required
+      // to resolve Arrow NA (all-null) parameters to the expected PostgreSQL types.
+      UNWRAP_STATUS(helper_.Prepare());
+
+      // Get PostgreSQL's expected parameter types
+      UNWRAP_STATUS(helper_.DescribePrepared());
+      PostgresType expected_types;
+      UNWRAP_STATUS(helper_.ResolveParamTypes(*type_resolver_, &expected_types));
+
+      // Reconcile Arrow schema with expected types
+      UNWRAP_STATUS(bind_stream_->ReconcileWithExpectedTypes(expected_types));
+    }
+
+    // Now set parameter types (will use reconciled types for NA fields)
     UNWRAP_STATUS(bind_stream_->SetParamTypes(conn_, *type_resolver_, autocommit_));
+
+    // Re-prepare with the actual parameter types
     UNWRAP_STATUS(helper_.Prepare(bind_stream_->param_types));
+
     UNWRAP_STATUS(BindNextAndExecute(nullptr));
 
     // If there were no arrays in the bind stream, we still need a result
@@ -252,6 +273,20 @@ Status PqResultArrayReader::ExecuteAll(int64_t* affected_rows) {
   // stream (if there is one) or execute the query without binding.
   if (bind_stream_) {
     UNWRAP_STATUS(bind_stream_->Begin([] { return Status::Ok(); }));
+
+    const bool has_na_params = std::any_of(
+        bind_stream_->bind_schema_fields.begin(), bind_stream_->bind_schema_fields.end(),
+        [](const ArrowSchemaView& view) { return view.type == NANOARROW_TYPE_NA; });
+    if (has_na_params) {
+      // Prepare without parameter types so PostgreSQL can infer them. This is required
+      // to resolve Arrow NA (all-null) parameters to the expected PostgreSQL types.
+      UNWRAP_STATUS(helper_.Prepare());
+      UNWRAP_STATUS(helper_.DescribePrepared());
+      PostgresType expected_types;
+      UNWRAP_STATUS(helper_.ResolveParamTypes(*type_resolver_, &expected_types));
+      UNWRAP_STATUS(bind_stream_->ReconcileWithExpectedTypes(expected_types));
+    }
+
     UNWRAP_STATUS(bind_stream_->SetParamTypes(conn_, *type_resolver_, autocommit_));
     UNWRAP_STATUS(helper_.Prepare(bind_stream_->param_types));
 
diff --git a/python/adbc_driver_postgresql/tests/test_dbapi.py b/python/adbc_driver_postgresql/tests/test_dbapi.py
index 952389de27..9d6c617040 100644
--- a/python/adbc_driver_postgresql/tests/test_dbapi.py
+++ b/python/adbc_driver_postgresql/tests/test_dbapi.py
@@ -586,3 +586,90 @@ def test_server_terminates_connection(postgres_uri: str) -> None:
         with pytest.raises(Exception):
             with conn2.cursor() as cur:
                 cur.execute("SELECT 1")
+
+
+# Tests for issue #3549: Cannot use null values as bound parameters
+def test_bind_null_insert(postgres: dbapi.Connection) -> None:
+    """Test INSERT with None parameter (issue #3549)."""
+    with postgres.cursor() as cur:
+        cur.execute("DROP TABLE IF EXISTS test_null_binding")
+        cur.execute("CREATE TABLE test_null_binding(a TEXT, b INT)")
+        # This should not raise an error about mapping Arrow type 'na' to Postgres type
+        cur.execute("INSERT INTO test_null_binding VALUES ($1, $2)", ("hello", None))
+    postgres.commit()
+
+
+def test_bind_null_update(postgres: dbapi.Connection) -> None:
+    """Test UPDATE with None parameter (issue #3549)."""
+    with postgres.cursor() as cur:
+        cur.execute("DROP TABLE IF EXISTS test_null_binding")
+        cur.execute("CREATE TABLE test_null_binding(a TEXT, b INT)")
+        cur.execute("INSERT INTO test_null_binding VALUES ('hello', 42)")
+    postgres.commit()
+
+    with postgres.cursor() as cur:
+        # This should not raise an error
+        cur.execute("UPDATE test_null_binding SET b=$2 WHERE a=$1", ("hello", None))
+    postgres.commit()
+
+    with postgres.cursor() as cur:
+        cur.execute("SELECT a, b FROM test_null_binding WHERE a='hello'")
+        result = cur.fetchone()
+        assert result is not None
+        assert result[0] == "hello"
+        assert result[1] is None
+
+
+def test_executemany_all_nulls(postgres: dbapi.Connection) -> None:
+    """Test executemany with all None values (issue #3549)."""
+    with postgres.cursor() as cur:
+        cur.execute("DROP TABLE IF EXISTS test_null_binding")
+        cur.execute("CREATE TABLE test_null_binding(a TEXT, b INT)")
+    postgres.commit()
+
+    with postgres.cursor() as cur:
+        # This is the critical test case from the issue
+        cur.executemany(
+            "INSERT INTO test_null_binding VALUES ($1, $2)",
+            [("hello", None), ("world", None)],
+        )
+    postgres.commit()
+
+    with postgres.cursor() as cur:
+        cur.execute("SELECT COUNT(*) FROM test_null_binding WHERE b IS NULL")
+        row = cur.fetchone()
+        assert row is not None
+        count = row[0]
+        assert count == 2
+
+
+def test_bind_multiple_null_parameters(postgres: dbapi.Connection) -> None:
+    """Test binding multiple None parameters in a single statement (issue #3549)."""
+    with postgres.cursor() as cur:
+        cur.execute("DROP TABLE IF EXISTS test_null_binding")
+        cur.execute("CREATE TABLE test_null_binding(a INT, b TEXT, c FLOAT)")
+    postgres.commit()
+
+    with postgres.cursor() as cur:
+        # All parameters are None
+        cur.execute(
+            "INSERT INTO test_null_binding VALUES ($1, $2, $3)", (None, None, None)
+        )
+    postgres.commit()
+
+    with postgres.cursor() as cur:
+        cur.execute("SELECT * FROM test_null_binding")
+        result = cur.fetchone()
+        assert result is not None
+        assert result[0] is None
+        assert result[1] is None
+        assert result[2] is None
+
+
+def test_bind_null_unknown_inference(postgres: dbapi.Connection) -> None:
+    """Test binding None where PostgreSQL can't infer a concrete parameter type."""
+    with postgres.cursor() as cur:
+        cur.execute("SELECT $1", (None,))
+        result = cur.fetchone()
+        assert result is not None
+        assert result[0] is None