Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 10 additions & 2 deletions adbc_drivers_validation/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,9 @@ class DriverFeatures(BaseModel):
_secondary_catalog: str | FromEnv | None = PrivateAttr(default=None)
_secondary_catalog_schema: str | FromEnv | None = PrivateAttr(default=None)
supported_xdbc_fields: list[str] = Field(default_factory=list)
# Some databases support temporary tables, but they exist in the same
# namespace as regular tables, so we need to change how we test them.
quirk_bulk_ingest_temporary_shares_namespace: bool = Field(default=False)
# Some vendors sort the columns, so declaring FOREIGN KEY(b, a) REFERENCES
# foo(d, c) still gets returned in the order (a, c), (b, d)
quirk_get_objects_constraints_foreign_normalized: bool = Field(default=False)
Expand Down Expand Up @@ -331,8 +334,10 @@ def qualify_temp_table(
"""
raise NotImplementedError

# Build a dotted, quoted name: each identifier is passed through
# quote_one_identifier and the quoted pieces are joined with ".".
# Every argument is quoted exactly as given; this version performs no
# filtering of None or empty values.
def quote_identifier(self, *identifiers: str) -> str:
return ".".join(self.quote_one_identifier(ident) for ident in identifiers)
def quote_identifier(self, *identifiers: str | None) -> str:
return ".".join(
self.quote_one_identifier(ident) for ident in identifiers if ident
)

def quote_one_identifier(self, identifier: str) -> str:
"""Quote an identifier (e.g. table or column name)."""
Expand Down Expand Up @@ -546,6 +551,9 @@ def merge(
# absent a specific bind query, we should bind the
# parameters to the regular query
if parent.query.bind_query_path:
# If bind_query_path is defined, bind_schema_path and bind_path must be defined too
assert parent.query.bind_schema_path is not None
assert parent.query.bind_path is not None
params["bind_query_path"] = parent.query.bind_query_path
params["bind_schema_path"] = parent.query.bind_schema_path
params["bind_path"] = parent.query.bind_path
Expand Down
100 changes: 59 additions & 41 deletions adbc_drivers_validation/tests/ingest.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,6 @@ def generate_tests(
driver_param = f"{quirks.name}:{quirks.short_version}"
enabled = {
"test_not_null": quirks.features.statement_bulk_ingest,
"test_temporary": quirks.features.statement_bulk_ingest_temporary,
"test_schema": quirks.features.statement_bulk_ingest_schema,
"test_catalog": quirks.features.statement_bulk_ingest_catalog,
"test_many_columns": quirks.features.statement_bulk_ingest,
Expand All @@ -70,6 +69,7 @@ def generate_tests(
enabled = {
"test_replace_catalog": quirks.features.statement_bulk_ingest_catalog,
"test_replace_schema": quirks.features.statement_bulk_ingest_schema,
"test_temporary": quirks.features.statement_bulk_ingest_temporary,
}.get(metafunc.definition.name, None)
for query in quirks.query_set.queries.values():
marks = []
Expand Down Expand Up @@ -422,51 +422,69 @@ def test_temporary(
self,
driver: model.DriverQuirks,
conn_factory: typing.Callable[[], adbc_driver_manager.dbapi.Connection],
query: Query,
) -> None:
data1 = pyarrow.Table.from_pydict(
{
"idx": [1, 2, 3],
"value": ["foo", "bar", "baz"],
}
)
data2 = pyarrow.Table.from_pydict(
{
"idx": [4, 5, 6],
"value": ["qux", "quux", "spam"],
}
)
subquery = query.query
assert isinstance(subquery, model.IngestQuery)
data = subquery.input()
expected = subquery.expected()

table_name = "test_ingest_temporary"

idx = driver.quote_identifier("idx")
value = driver.quote_identifier("value")

with conn_factory() as conn:
with conn.cursor() as cursor:
driver.try_drop_table(cursor, table_name=table_name)
cursor.adbc_ingest(table_name, data1, temporary=True)
cursor.adbc_ingest(table_name, data2, temporary=False)

with conn.cursor() as cursor:
assert driver.features.current_schema is not None
normal_table = driver.quote_identifier(
driver.features.current_schema, table_name
)
temp_table = driver.qualify_temp_table(cursor, table_name)
select_normal = (
f"SELECT {idx}, {value} FROM {normal_table} ORDER BY {idx} ASC"
)
select_temporary = (
f"SELECT {idx}, {value} FROM {temp_table} ORDER BY {idx} ASC"
)

result_normal = execute_query_without_prepare(cursor, select_normal)

result_temporary = execute_query_without_prepare(
cursor, select_temporary
)

compare.compare_tables(data1, result_temporary)
compare.compare_tables(data2, result_normal)
if driver.features.quirk_bulk_ingest_temporary_shares_namespace:
with conn_factory() as conn:
with conn.cursor() as cursor:
driver.try_drop_table(cursor, table_name=table_name)
cursor.adbc_ingest(table_name, data, temporary=True)
temp_table = driver.qualify_temp_table(cursor, table_name)
select_temporary = (
f"SELECT {idx}, {value} FROM {temp_table} ORDER BY {idx} ASC"
)
result_temporary = execute_query_without_prepare(
cursor, select_temporary
)
compare.compare_tables(expected, result_temporary)

with conn_factory() as conn:
with conn.cursor() as cursor:
temp_table = driver.qualify_temp_table(cursor, table_name)
select_temporary = (
f"SELECT {idx}, {value} FROM {temp_table} ORDER BY {idx} ASC"
)
with pytest.raises(Exception) as excinfo:
execute_query_without_prepare(cursor, select_temporary)
assert driver.is_table_not_found(table_name, excinfo.value)
else:
with conn_factory() as conn:
with conn.cursor() as cursor:
driver.try_drop_table(cursor, table_name=table_name)
cursor.adbc_ingest(table_name, data.slice(0, 1), temporary=True)
cursor.adbc_ingest(table_name, data.slice(1), temporary=False)

with conn.cursor() as cursor:
assert driver.features.current_schema is not None
normal_table = driver.quote_identifier(
driver.features.current_catalog,
driver.features.current_schema,
table_name,
)
temp_table = driver.qualify_temp_table(cursor, table_name)
select_normal = (
f"SELECT {idx}, {value} FROM {normal_table} ORDER BY {idx} ASC"
)
select_temporary = (
f"SELECT {idx}, {value} FROM {temp_table} ORDER BY {idx} ASC"
)

result_normal = execute_query_without_prepare(cursor, select_normal)
result_temporary = execute_query_without_prepare(
cursor, select_temporary
)

compare.compare_tables(expected.slice(0, 1), result_temporary)
compare.compare_tables(expected.slice(1), result_normal)

def test_schema(
self,
Expand Down
12 changes: 4 additions & 8 deletions adbc_drivers_validation/tests/query.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,21 +52,17 @@ def generate_tests(all_quirks: list[model.DriverQuirks], metafunc) -> None:
):
marks.append(pytest.mark.skip(reason="bind not supported"))

if metafunc.definition.name == "test_execute_schema":
if metafunc.definition.name in {
"test_execute_schema",
"test_get_table_schema",
}:
if not isinstance(query.query, model.SelectQuery):
continue
if not query.name.startswith("type/select/"):
# There's no need to repeat this test multiple times per type
continue
if not quirks.features.statement_execute_schema:
marks.append(pytest.mark.skip(reason="not implemented"))
elif metafunc.definition.name == "test_get_table_schema":
if not isinstance(query.query, model.SelectQuery):
continue
elif not query.name.startswith("type/select/"):
continue
elif not quirks.features.connection_get_table_schema:
marks.append(pytest.mark.skip(reason="not implemented"))
elif metafunc.definition.name == "test_query":
if not isinstance(query.query, model.SelectQuery):
continue
Expand Down
10 changes: 7 additions & 3 deletions adbc_drivers_validation/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,9 +135,13 @@ def execute_query_without_prepare(
The result of the query.
"""
cursor.adbc_statement.set_sql_query(query)
handle, _ = cursor.adbc_statement.execute_query()
with pyarrow.RecordBatchReader._import_from_c(handle.address) as reader:
return reader.read_all()
try:
handle, _ = cursor.adbc_statement.execute_query()
with pyarrow.RecordBatchReader._import_from_c(handle.address) as reader:
return reader.read_all()
except Exception as e:
e.add_note(f"Query: {query}")
raise


def arrow_type_name(arrow_type, metadata=None, show_type_parameters=False):
Expand Down
Loading