Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,10 @@ def __init__(
rows. Smaller values reduce peak memory but increase per-batch overhead.
None uses the PyArrow default (~131K rows).
"""
if branch is not None and branch.strip() == "":
raise ValueError("branch must not be empty or whitespace")
if branch is not None and snapshot_id is not None:
raise ValueError("Cannot specify both branch and snapshot_id")
self._catalog = catalog
self._table_id = TableIdentifier(database, table, branch)
self._snapshot_id = snapshot_id
Expand All @@ -121,7 +125,17 @@ def table_properties(self) -> Mapping[str, str]:
@cached_property
def snapshot_id(self) -> int | None:
"""Snapshot ID of the loaded table, or None if the table has no snapshots"""
return self._snapshot_id if self._snapshot_id is not None else self._iceberg_table.metadata.current_snapshot_id
if self._snapshot_id is not None:
return self._snapshot_id
if self._table_id.branch:
snapshot = self._iceberg_table.snapshot_by_name(self._table_id.branch)
if snapshot is None:
raise ValueError(
f"Branch '{self._table_id.branch}' not found for table "
f"{self._table_id.database}.{self._table_id.table}"
)
return snapshot.snapshot_id
return self._iceberg_table.metadata.current_snapshot_id

def _verify_snapshot(self, snapshot: Snapshot | None) -> None:
"""Log the resolved snapshot or raise if a user-provided snapshot_id was not found."""
Expand Down
75 changes: 73 additions & 2 deletions integrations/python/dataloader/tests/test_data_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,9 +52,9 @@ def test_package_imports():
}


def _write_parquet(tmp_path, data: dict) -> str:
def _write_parquet(tmp_path, data: dict, filename: str = "test.parquet") -> str:
"""Write a Parquet file with Iceberg field IDs in column metadata."""
file_path = str(tmp_path / "test.parquet")
file_path = str(tmp_path / filename)
table = pa.table(data)
fields = [field.with_metadata({b"PARQUET:field_id": str(i + 1).encode()}) for i, field in enumerate(table.schema)]
pq.write_table(table.cast(pa.schema(fields)), file_path)
Expand Down Expand Up @@ -324,6 +324,77 @@ def test_snapshot_id_with_columns_and_filters(tmp_path):
assert "row_filter" in scan_kwargs


# --- branch tests ---


def test_branch_and_snapshot_id_raises():
    """Supplying branch together with snapshot_id is rejected with a ValueError."""
    expected_message = "Cannot specify both branch and snapshot_id"
    loader_kwargs = {
        "catalog": MagicMock(),
        "database": "db",
        "table": "tbl",
        "branch": "b",
        "snapshot_id": 42,
    }

    with pytest.raises(ValueError, match=expected_message):
        OpenHouseDataLoader(**loader_kwargs)


def test_branch_snapshot_id_resolves():
    """With a branch set, snapshot_id comes from the snapshot_by_name lookup."""
    branch_name = "my-branch"
    branch_snapshot = MagicMock()
    branch_snapshot.snapshot_id = 123

    def lookup(name):
        # Only the requested branch is known to the fake table.
        if name == branch_name:
            return branch_snapshot
        return None

    catalog = MagicMock()
    catalog.load_table.return_value.snapshot_by_name.side_effect = lookup

    loader = OpenHouseDataLoader(catalog=catalog, database="db", table="tbl", branch=branch_name)

    assert loader.snapshot_id == 123


def test_branch_snapshot_id_not_found_raises():
    """Reading snapshot_id raises ValueError when the branch lookup yields nothing."""

    def unknown_ref(name):
        # The fake table knows no named references at all.
        return None

    catalog = MagicMock()
    catalog.load_table.return_value.snapshot_by_name.side_effect = unknown_ref

    loader = OpenHouseDataLoader(catalog=catalog, database="db", table="tbl", branch="missing")

    with pytest.raises(ValueError, match="Branch 'missing' not found"):
        _ = loader.snapshot_id


def test_branch_reads_data_from_branch_snapshot():
    """Iterating with a branch yields the branch snapshot's split, not main's."""
    branch_name = "my-branch"
    branch_snapshot_id = 200

    def make_task(path):
        # Minimal stand-in for a file scan task exposing only file.file_path.
        task = MagicMock()
        task.file.file_path = path
        return task

    main_task = make_task("main.parquet")
    branch_task = make_task("branch.parquet")

    def fake_scan(**kwargs):
        # Serve the branch file only when the scan is pinned to the branch snapshot.
        scan = MagicMock()
        if kwargs.get("snapshot_id") == branch_snapshot_id:
            scan.plan_files.return_value = [branch_task]
        else:
            scan.plan_files.return_value = [main_task]
        return scan

    branch_snapshot = MagicMock()
    branch_snapshot.snapshot_id = branch_snapshot_id

    def fake_snapshot_by_name(name):
        return branch_snapshot if name == branch_name else None

    catalog = MagicMock()
    fake_table = catalog.load_table.return_value
    fake_table.scan.side_effect = fake_scan
    fake_table.snapshot_by_name.side_effect = fake_snapshot_by_name

    # Without branch: splits come from main snapshot
    splits_on_main = list(OpenHouseDataLoader(catalog=catalog, database="db", table="tbl"))
    assert len(splits_on_main) == 1
    assert splits_on_main[0]._file_scan_task.file.file_path == "main.parquet"

    # With branch: splits come from branch snapshot
    splits_on_branch = list(
        OpenHouseDataLoader(catalog=catalog, database="db", table="tbl", branch=branch_name)
    )
    assert len(splits_on_branch) == 1
    assert splits_on_branch[0]._file_scan_task.file.file_path == "branch.parquet"


# --- batch_size tests ---


Expand Down