diff --git a/integrations/python/dataloader/src/openhouse/dataloader/data_loader.py b/integrations/python/dataloader/src/openhouse/dataloader/data_loader.py
index 23dd603df..926ef9aaf 100644
--- a/integrations/python/dataloader/src/openhouse/dataloader/data_loader.py
+++ b/integrations/python/dataloader/src/openhouse/dataloader/data_loader.py
@@ -96,6 +96,10 @@ def __init__(
             rows. Smaller values reduce peak memory but increase per-batch
             overhead. None uses the PyArrow default (~131K rows).
         """
+        if branch is not None and branch.strip() == "":
+            raise ValueError("branch must not be empty or whitespace")
+        if branch is not None and snapshot_id is not None:
+            raise ValueError("Cannot specify both branch and snapshot_id")
         self._catalog = catalog
         self._table_id = TableIdentifier(database, table, branch)
         self._snapshot_id = snapshot_id
@@ -121,7 +125,17 @@ def table_properties(self) -> Mapping[str, str]:
     @cached_property
     def snapshot_id(self) -> int | None:
         """Snapshot ID of the loaded table, or None if the table has no snapshots"""
-        return self._snapshot_id if self._snapshot_id is not None else self._iceberg_table.metadata.current_snapshot_id
+        if self._snapshot_id is not None:
+            return self._snapshot_id
+        if self._table_id.branch:
+            snapshot = self._iceberg_table.snapshot_by_name(self._table_id.branch)
+            if snapshot is None:
+                raise ValueError(
+                    f"Branch '{self._table_id.branch}' not found for table "
+                    f"{self._table_id.database}.{self._table_id.table}"
+                )
+            return snapshot.snapshot_id
+        return self._iceberg_table.metadata.current_snapshot_id
 
     def _verify_snapshot(self, snapshot: Snapshot | None) -> None:
         """Log the resolved snapshot or raise if a user-provided snapshot_id was not found."""
diff --git a/integrations/python/dataloader/tests/test_data_loader.py b/integrations/python/dataloader/tests/test_data_loader.py
index 7862a7f35..4919c5c4c 100644
--- a/integrations/python/dataloader/tests/test_data_loader.py
+++ b/integrations/python/dataloader/tests/test_data_loader.py
@@ -52,9 +52,9 @@ def test_package_imports():
 }
 
 
-def _write_parquet(tmp_path, data: dict) -> str:
+def _write_parquet(tmp_path, data: dict, filename: str = "test.parquet") -> str:
     """Write a Parquet file with Iceberg field IDs in column metadata."""
-    file_path = str(tmp_path / "test.parquet")
+    file_path = str(tmp_path / filename)
     table = pa.table(data)
     fields = [field.with_metadata({b"PARQUET:field_id": str(i + 1).encode()}) for i, field in enumerate(table.schema)]
     pq.write_table(table.cast(pa.schema(fields)), file_path)
@@ -324,6 +324,77 @@ def test_snapshot_id_with_columns_and_filters(tmp_path):
     assert "row_filter" in scan_kwargs
 
 
+# --- branch tests ---
+
+
+def test_branch_and_snapshot_id_raises():
+    """ValueError is raised when both branch and snapshot_id are provided."""
+    catalog = MagicMock()
+
+    with pytest.raises(ValueError, match="Cannot specify both branch and snapshot_id"):
+        OpenHouseDataLoader(catalog=catalog, database="db", table="tbl", branch="b", snapshot_id=42)
+
+
+def test_branch_snapshot_id_resolves():
+    """snapshot_id property resolves via snapshot_by_name when branch is set."""
+    catalog = MagicMock()
+    mock_snapshot = MagicMock()
+    mock_snapshot.snapshot_id = 123
+    catalog.load_table.return_value.snapshot_by_name.side_effect = (
+        lambda name: mock_snapshot if name == "my-branch" else None
+    )
+
+    loader = OpenHouseDataLoader(catalog=catalog, database="db", table="tbl", branch="my-branch")
+
+    assert loader.snapshot_id == 123
+
+
+def test_branch_snapshot_id_not_found_raises():
+    """ValueError is raised when branch does not exist in table metadata."""
+    catalog = MagicMock()
+    catalog.load_table.return_value.snapshot_by_name.side_effect = lambda name: None
+
+    loader = OpenHouseDataLoader(catalog=catalog, database="db", table="tbl", branch="missing")
+
+    with pytest.raises(ValueError, match="Branch 'missing' not found"):
+        _ = loader.snapshot_id
+
+
+def test_branch_reads_data_from_branch_snapshot():
+    """Branch splits come from the branch snapshot, not the main snapshot."""
+    catalog = MagicMock()
+
+    main_task = MagicMock()
+    main_task.file.file_path = "main.parquet"
+    branch_task = MagicMock()
+    branch_task.file.file_path = "branch.parquet"
+
+    branch_snapshot_id = 200
+
+    def fake_scan(**kwargs):
+        task = branch_task if kwargs.get("snapshot_id") == branch_snapshot_id else main_task
+        scan = MagicMock()
+        scan.plan_files.return_value = [task]
+        return scan
+
+    mock_snapshot = MagicMock()
+    mock_snapshot.snapshot_id = branch_snapshot_id
+
+    mock_table = catalog.load_table.return_value
+    mock_table.scan.side_effect = fake_scan
+    mock_table.snapshot_by_name.side_effect = lambda name: mock_snapshot if name == "my-branch" else None
+
+    # Without branch: splits come from main snapshot
+    main_splits = list(OpenHouseDataLoader(catalog=catalog, database="db", table="tbl"))
+    assert len(main_splits) == 1
+    assert main_splits[0]._file_scan_task.file.file_path == "main.parquet"
+
+    # With branch: splits come from branch snapshot
+    branch_splits = list(OpenHouseDataLoader(catalog=catalog, database="db", table="tbl", branch="my-branch"))
+    assert len(branch_splits) == 1
+    assert branch_splits[0]._file_scan_task.file.file_path == "branch.parquet"
+
+
 # --- batch_size tests ---