Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions adbc_drivers_validation/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -406,6 +406,10 @@ class SelectQuery:
bind_schema_path: Path | None = None
#: Data for the bind parameters.
bind_path: Path | None = None
#: Schema of the result set (when not executing a query). For some
#: databases and some situations, this is different than the schema you get
#: when you actually execute the query.
catalog_schema_path: Path | None = None

def setup_query(self) -> str | None:
if not self.setup_query_path:
Expand Down Expand Up @@ -434,6 +438,11 @@ def query(self) -> str:
def expected_schema(self) -> pyarrow.Schema:
return try_txtcase(self.expected_schema_path, query_schema, ["expected_schema"])

def catalog_schema(self) -> pyarrow.Schema:
if not self.catalog_schema_path:
return self.expected_schema()
return try_txtcase(self.catalog_schema_path, query_schema, ["catalog_schema"])

def expected_result(self) -> pyarrow.Table:
return try_txtcase(
self.expected_path, query_table, ["expected"], self.expected_schema()
Expand Down Expand Up @@ -548,6 +557,8 @@ def merge(
"expected_schema_path": parent.query.expected_schema_path,
"expected_path": parent.query.expected_path,
}
if parent.query.catalog_schema_path:
params["catalog_schema_path"] = parent.query.catalog_schema_path
if parent.query.setup_query_path:
params["setup_query_path"] = parent.query.setup_query_path
# TODO: we also want to test with ExecuteQuery so perhaps
Expand Down Expand Up @@ -588,6 +599,7 @@ def merge(
in {
"query_path",
"expected_schema_path",
"catalog_schema_path",
"expected_path",
"setup_query_path",
"bind_query_path",
Expand Down
9 changes: 4 additions & 5 deletions adbc_drivers_validation/tests/query.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,9 +123,8 @@ def test_execute_schema(
) -> None:
subquery = query.query
assert isinstance(subquery, model.SelectQuery)

sql = subquery.query()
expected_schema = subquery.expected_schema()
expected_schema = subquery.catalog_schema()

_setup_query(driver, conn, query)

Expand All @@ -142,16 +141,16 @@ def test_get_table_schema(
query: model.Query,
) -> None:
subquery = query.query

expected_schema = subquery.expected_schema()
assert isinstance(subquery, model.SelectQuery)
expected_schema = subquery.catalog_schema()

with setup_connection(query, conn):
_setup_query(driver, conn, query)

table_name = None
md = query.metadata()
table_name = md.setup.drop
if not table_name and isinstance(subquery, model.SelectQuery):
if not table_name:
# XXX: rather hacky, but extract the table name from the SELECT query
# that would normally be executed
query_str = subquery.query().split()
Expand Down
2 changes: 1 addition & 1 deletion adbc_drivers_validation/txtcase.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ def get_part(self, part: str, schema: pyarrow.Schema | None = None) -> typing.An

if part == "metadata":
return tomllib.loads(value)
if part in {"bind_schema", "expected_schema", "input_schema"}:
if part in {"bind_schema", "expected_schema", "catalog_schema", "input_schema"}:
return arrowjson.loads_schema(value)
elif part in {"bind_query", "query", "setup_query"}:
return value
Expand Down
Loading