Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 48 additions & 3 deletions c/driver/postgresql/bind_stream.h
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,10 @@ struct BindStream {
bool autocommit = false;
std::string tz_setting;

// Expected types from PostgreSQL (after DESCRIBE); used to resolve NA params
PostgresType expected_param_types;
bool has_expected_types = false;

struct ArrowError na_error;

BindStream() {
Expand All @@ -71,6 +75,20 @@ struct BindStream {
ArrowArrayStreamMove(stream, &bind.value);
}

/// \brief Remember the parameter types PostgreSQL reported for the statement.
///
/// The stored types are kept alongside a flag marking them as available so
/// that parameters bound with the Arrow null (NA) type can later be resolved
/// against what the server expects.
///
/// \param expected_types Record-like type with one child per statement parameter.
/// \return InvalidState if the bind schema has not been populated yet or if the
///   number of expected parameters differs from the number of bound columns;
///   Ok otherwise.
Status ReconcileWithExpectedTypes(const PostgresType& expected_types) {
  // A released/empty schema means Begin() has not produced a schema yet.
  const bool schema_ready = bind_schema->release != nullptr;
  if (!schema_ready) {
    return Status::InvalidState("[libpq] Bind stream schema not initialized");
  }

  // Every statement parameter must line up with exactly one bound column.
  const auto n_expected = expected_types.n_children();
  const auto n_bound = bind_schema->n_children;
  if (n_expected != n_bound) {
    return Status::InvalidState("[libpq] Expected ", n_expected,
                                " parameters but bind stream has ", n_bound);
  }

  expected_param_types = expected_types;
  has_expected_types = true;
  return Status::Ok();
}

template <typename Callback>
Status Begin(Callback&& callback) {
UNWRAP_NANOARROW(
Expand Down Expand Up @@ -111,9 +129,28 @@ struct BindStream {

for (size_t i = 0; i < bind_field_writers.size(); i++) {
PostgresType type;
UNWRAP_NANOARROW(na_error, Internal,
PostgresType::FromSchema(type_resolver, bind_schema->children[i],
&type, &na_error));

// Handle NA type by using expected parameter type from PostgreSQL
if (has_expected_types && bind_schema_fields[i].type == NANOARROW_TYPE_NA &&
i < static_cast<size_t>(expected_param_types.n_children())) {
const auto& expected_type = expected_param_types.child(i);
// If PostgreSQL couldn't infer a concrete type (e.g., SELECT $1), don't
// force an "expected" type; fall back to Arrow-derived mapping.
if (expected_type.oid() != 0 &&
expected_type.type_id() != PostgresTypeId::kUnknown) {
type = expected_type;
} else {
UNWRAP_NANOARROW(
na_error, Internal,
PostgresType::FromSchema(type_resolver, bind_schema->children[i], &type,
&na_error));
}
} else {
// Normal case: derive type from Arrow schema
UNWRAP_NANOARROW(na_error, Internal,
PostgresType::FromSchema(type_resolver, bind_schema->children[i],
&type, &na_error));
}

// tz-aware timestamps require special handling to set the timezone to UTC
// prior to sending over the binary protocol; must be reset after execute
Expand Down Expand Up @@ -205,6 +242,14 @@ struct BindStream {

for (int64_t col = 0; col < array_view->n_children; col++) {
is_null_param[col] = ArrowArrayViewIsNull(array_view->children[col], current_row);

// Safety check: NA type arrays should only contain nulls
if (bind_schema_fields[col].type == NANOARROW_TYPE_NA && !is_null_param[col]) {
return Status::InvalidArgument(
"Parameter $", col + 1,
" has null type but contains a non-null value at row ", current_row);
}

if (!is_null_param[col]) {
// Note that this Write() call currently writes the (int32_t) byte size of the
// field in addition to the serialized value.
Expand Down
16 changes: 15 additions & 1 deletion c/driver/postgresql/copy/writer.h
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,17 @@ class PostgresCopyFieldWriter {
std::vector<std::unique_ptr<PostgresCopyFieldWriter>> children_;
};

// Field writer used for columns of the Arrow null (NA) type.
//
// It never serializes anything: being asked to write a value means a non-null
// element was found in a null-typed column, which is reported as EINVAL with a
// message naming the offending row.
class PostgresCopyNullFieldWriter : public PostgresCopyFieldWriter {
 public:
  // The output buffer is intentionally unnamed because nothing is appended to it.
  ArrowErrorCode Write(ArrowBuffer* /*buffer*/, int64_t index,
                       ArrowError* error) override {
    ArrowErrorSet(error,
                  "[libpq] Unexpected non-null value for Arrow null type at row %" PRId64,
                  index);
    return EINVAL;
  }
};

class PostgresCopyFieldTupleWriter : public PostgresCopyFieldWriter {
public:
void AppendChild(std::unique_ptr<PostgresCopyFieldWriter> child) {
Expand All @@ -131,7 +142,7 @@ class PostgresCopyFieldTupleWriter : public PostgresCopyFieldWriter {
constexpr int32_t field_size_bytes = -1;
NANOARROW_RETURN_NOT_OK(WriteChecked<int32_t>(buffer, field_size_bytes, error));
} else {
children_[i]->Write(buffer, index, error);
NANOARROW_RETURN_NOT_OK(children_[i]->Write(buffer, index, error));
}
}

Expand Down Expand Up @@ -743,6 +754,9 @@ static inline ArrowErrorCode MakeCopyFieldWriter(
NANOARROW_RETURN_NOT_OK(ArrowSchemaViewInit(&schema_view, schema, error));

switch (schema_view.type) {
case NANOARROW_TYPE_NA:
*out = PostgresCopyFieldWriter::Create<PostgresCopyNullFieldWriter>(array_view);
return NANOARROW_OK;
case NANOARROW_TYPE_BOOL:
using T = PostgresCopyBooleanFieldWriter;
*out = T::Create<T>(array_view);
Expand Down
6 changes: 6 additions & 0 deletions c/driver/postgresql/postgres_type.h
Original file line number Diff line number Diff line change
Expand Up @@ -653,6 +653,12 @@ inline ArrowErrorCode PostgresType::FromSchema(const PostgresTypeResolver& resol
// Dictionary arrays always resolve to the dictionary type when binding or ingesting
return PostgresType::FromSchema(resolver, schema->dictionary, out, error);

case NANOARROW_TYPE_NA:
// NA type - default to TEXT which PostgreSQL can coerce to any type
// This provides a fallback when we don't have expected type information (e.g.,
// COPY)
return resolver.Find(resolver.GetOID(PostgresTypeId::kText), out, error);

default:
ArrowErrorSet(error, "Can't map Arrow type '%s' to Postgres type",
ArrowTypeString(schema_view.type));
Expand Down
16 changes: 9 additions & 7 deletions c/driver/postgresql/result_helper.cc
Original file line number Diff line number Diff line change
Expand Up @@ -147,13 +147,15 @@ Status PqResultHelper::ResolveParamTypes(PostgresTypeResolver& type_resolver,
for (int i = 0; i < num_params; i++) {
const Oid pg_oid = PQparamtype(result_, i);
PostgresType pg_type;
if (type_resolver.Find(pg_oid, &pg_type, &na_error) != NANOARROW_OK) {
std::string param_name = "$" + std::to_string(i + 1);
Status status =
Status::NotImplemented("[libpq] Parameter #", i + 1, " (\"", param_name,
"\") has unknown type code ", pg_oid);
ClearResult();
return status;
if (pg_oid == 0) {
// PostgreSQL didn't infer a type (can happen in ambiguous contexts).
pg_type = PostgresType(PostgresTypeId::kUnknown).WithPgTypeInfo(0, "unknown");
} else if (type_resolver.Find(pg_oid, &pg_type, &na_error) != NANOARROW_OK) {
// Don't fail preparation just because we can't resolve an expected parameter
// type. We'll fall back to Arrow-derived mapping.
pg_type =
PostgresType(PostgresTypeId::kUnknown)
.WithPgTypeInfo(pg_oid, "unknown<oid:" + std::to_string(pg_oid) + ">");
}

std::string param_name = "$" + std::to_string(i + 1);
Expand Down
35 changes: 35 additions & 0 deletions c/driver/postgresql/result_reader.cc
Original file line number Diff line number Diff line change
Expand Up @@ -148,8 +148,29 @@ Status PqResultArrayReader::Initialize(int64_t* rows_affected) {
if (bind_stream_) {
UNWRAP_STATUS(bind_stream_->Begin([] { return Status::Ok(); }));

const bool has_na_params = std::any_of(
bind_stream_->bind_schema_fields.begin(), bind_stream_->bind_schema_fields.end(),
[](const ArrowSchemaView& view) { return view.type == NANOARROW_TYPE_NA; });
if (has_na_params) {
// Prepare WITHOUT parameter types to let PostgreSQL infer them. This is required
// to resolve Arrow NA (all-null) parameters to the expected PostgreSQL types.
UNWRAP_STATUS(helper_.Prepare());

// Get PostgreSQL's expected parameter types
UNWRAP_STATUS(helper_.DescribePrepared());
PostgresType expected_types;
UNWRAP_STATUS(helper_.ResolveParamTypes(*type_resolver_, &expected_types));

// Reconcile Arrow schema with expected types
UNWRAP_STATUS(bind_stream_->ReconcileWithExpectedTypes(expected_types));
}

// Now set parameter types (will use reconciled types for NA fields)
UNWRAP_STATUS(bind_stream_->SetParamTypes(conn_, *type_resolver_, autocommit_));

// Re-prepare with the actual parameter types
UNWRAP_STATUS(helper_.Prepare(bind_stream_->param_types));

UNWRAP_STATUS(BindNextAndExecute(nullptr));

// If there were no arrays in the bind stream, we still need a result
Expand Down Expand Up @@ -252,6 +273,20 @@ Status PqResultArrayReader::ExecuteAll(int64_t* affected_rows) {
// stream (if there is one) or execute the query without binding.
if (bind_stream_) {
UNWRAP_STATUS(bind_stream_->Begin([] { return Status::Ok(); }));

const bool has_na_params = std::any_of(
bind_stream_->bind_schema_fields.begin(), bind_stream_->bind_schema_fields.end(),
[](const ArrowSchemaView& view) { return view.type == NANOARROW_TYPE_NA; });
if (has_na_params) {
// Prepare without parameter types so PostgreSQL can infer them. This is required
// to resolve Arrow NA (all-null) parameters to the expected PostgreSQL types.
UNWRAP_STATUS(helper_.Prepare());
UNWRAP_STATUS(helper_.DescribePrepared());
PostgresType expected_types;
UNWRAP_STATUS(helper_.ResolveParamTypes(*type_resolver_, &expected_types));
UNWRAP_STATUS(bind_stream_->ReconcileWithExpectedTypes(expected_types));
}

UNWRAP_STATUS(bind_stream_->SetParamTypes(conn_, *type_resolver_, autocommit_));
UNWRAP_STATUS(helper_.Prepare(bind_stream_->param_types));

Expand Down
87 changes: 87 additions & 0 deletions python/adbc_driver_postgresql/tests/test_dbapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -586,3 +586,90 @@ def test_server_terminates_connection(postgres_uri: str) -> None:
with pytest.raises(Exception):
with conn2.cursor() as cur:
cur.execute("SELECT 1")


# Tests for issue #3549: Cannot use null values as bound parameters
def test_bind_null_insert(postgres: dbapi.Connection) -> None:
    """Test INSERT with None parameter (issue #3549).

    Binds a Python ``None`` as the second parameter of an INSERT into a freshly
    created table, then reads the row back to confirm it was stored as SQL NULL.
    """
    with postgres.cursor() as cur:
        cur.execute("DROP TABLE IF EXISTS test_null_binding")
        cur.execute("CREATE TABLE test_null_binding(a TEXT, b INT)")
        # This should not raise an error about mapping Arrow type 'na' to Postgres type
        cur.execute("INSERT INTO test_null_binding VALUES ($1, $2)", ("hello", None))
    postgres.commit()

    # Not raising is necessary but not sufficient: verify the row actually
    # landed and that the None parameter became NULL rather than some default.
    with postgres.cursor() as cur:
        cur.execute("SELECT a, b FROM test_null_binding WHERE a='hello'")
        result = cur.fetchone()
        assert result is not None
        assert result[0] == "hello"
        assert result[1] is None


def test_bind_null_update(postgres: dbapi.Connection) -> None:
    """Binding None in an UPDATE must overwrite the column with NULL (issue #3549)."""
    # Arrange: fresh table containing one row whose integer column is non-null.
    setup_statements = (
        "DROP TABLE IF EXISTS test_null_binding",
        "CREATE TABLE test_null_binding(a TEXT, b INT)",
        "INSERT INTO test_null_binding VALUES ('hello', 42)",
    )
    with postgres.cursor() as setup_cur:
        for statement in setup_statements:
            setup_cur.execute(statement)
    postgres.commit()

    # Act: the None bound to $2 must be accepted without raising.
    update_params = ("hello", None)
    with postgres.cursor() as update_cur:
        update_cur.execute(
            "UPDATE test_null_binding SET b=$2 WHERE a=$1", update_params
        )
    postgres.commit()

    # Assert: the row still exists and its integer column is now NULL.
    with postgres.cursor() as query_cur:
        query_cur.execute("SELECT a, b FROM test_null_binding WHERE a='hello'")
        row = query_cur.fetchone()
        assert row is not None
        text_value, int_value = row[0], row[1]
        assert text_value == "hello"
        assert int_value is None


def test_executemany_all_nulls(postgres: dbapi.Connection) -> None:
    """executemany() where one parameter is None in every row (issue #3549)."""
    with postgres.cursor() as setup_cur:
        setup_cur.execute("DROP TABLE IF EXISTS test_null_binding")
        setup_cur.execute("CREATE TABLE test_null_binding(a TEXT, b INT)")
    postgres.commit()

    # The second parameter is None for every row; this is the critical case
    # from the issue.
    parameter_rows = [("hello", None), ("world", None)]
    with postgres.cursor() as insert_cur:
        insert_cur.executemany(
            "INSERT INTO test_null_binding VALUES ($1, $2)",
            parameter_rows,
        )
    postgres.commit()

    # Both inserted rows should be visible with a NULL integer column.
    with postgres.cursor() as query_cur:
        query_cur.execute("SELECT COUNT(*) FROM test_null_binding WHERE b IS NULL")
        fetched = query_cur.fetchone()
        assert fetched is not None
        assert fetched[0] == 2


def test_bind_multiple_null_parameters(postgres: dbapi.Connection) -> None:
    """A statement whose parameters are all None must insert all NULLs (issue #3549)."""
    with postgres.cursor() as setup_cur:
        setup_cur.execute("DROP TABLE IF EXISTS test_null_binding")
        setup_cur.execute("CREATE TABLE test_null_binding(a INT, b TEXT, c FLOAT)")
    postgres.commit()

    # Every one of the three parameters is None, each targeting a different
    # column type.
    all_none = (None, None, None)
    with postgres.cursor() as insert_cur:
        insert_cur.execute(
            "INSERT INTO test_null_binding VALUES ($1, $2, $3)", all_none
        )
    postgres.commit()

    with postgres.cursor() as query_cur:
        query_cur.execute("SELECT * FROM test_null_binding")
        row = query_cur.fetchone()
        assert row is not None
        # Check the three columns in order: a, b, c.
        for position in range(3):
            assert row[position] is None


def test_bind_null_unknown_inference(postgres: dbapi.Connection) -> None:
    """Bind None in a context where PostgreSQL can't infer a concrete parameter type."""
    # A bare ``SELECT $1`` offers no column or operator to infer a type from.
    query = "SELECT $1"
    with postgres.cursor() as cursor:
        cursor.execute(query, (None,))
        row = cursor.fetchone()
        assert row is not None
        first_value = row[0]
        assert first_value is None
Loading