Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions pyiceberg/expressions/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,10 @@ def __repr__(self) -> str:
def ref(self) -> BoundReference[L]:
    """Return this bound reference itself."""
    return self

def __hash__(self) -> int:
    """Return the hash value of the BoundReference class.

    The hash is computed from the string form of the reference, so references
    with the same string form produce the same hash.
    """
    return hash(str(self))


class UnboundTerm(Term[Any], Unbound[BoundTerm[L]], ABC):
"""Represents an unbound term."""
Expand Down
120 changes: 115 additions & 5 deletions pyiceberg/io/pyarrow.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,11 +73,7 @@

from pyiceberg.conversions import to_bytes
from pyiceberg.exceptions import ResolveError
from pyiceberg.expressions import (
AlwaysTrue,
BooleanExpression,
BoundTerm,
)
from pyiceberg.expressions import AlwaysTrue, BooleanExpression, BoundIsNull, BoundReference, BoundTerm, Not, Or
from pyiceberg.expressions.literals import Literal
from pyiceberg.expressions.visitors import (
BoundBooleanExpressionVisitor,
Expand Down Expand Up @@ -638,10 +634,124 @@ def visit_or(self, left_result: pc.Expression, right_result: pc.Expression) -> p
return left_result | right_result


class _CollectIsValidPredicatesFromExpression(BoundBooleanExpressionVisitor[Any]):
    """Sort the bound terms of a boolean expression by whether a null check mentions them.

    Every term seen in a predicate ends up in exactly one of two sets:

    * ``is_valid_or_not_bound_terms``: terms with at least one ``is_null`` / ``not_null`` predicate.
    * ``null_unmentioned_bound_terms``: terms that only appear in other predicates.

    The visit methods record terms as a side effect and all return ``None``.
    """

    def __init__(self) -> None:
        # BoundTerms which have either is_null or is_not_null appearing at least once in the boolean expr.
        self.is_valid_or_not_bound_terms: set[BoundTerm[Any]] = set()
        # The remaining BoundTerms appearing in the boolean expr.
        self.null_unmentioned_bound_terms: set[BoundTerm[Any]] = set()
        super().__init__()

    def _handle_explicit_is_null_or_not(self, term: BoundTerm[Any]) -> None:
        """Handle the predicate case where either is_null or is_not_null is included."""
        # An explicit null check overrides any earlier non-null predicate on the same term.
        self.null_unmentioned_bound_terms.discard(term)
        self.is_valid_or_not_bound_terms.add(term)

    def _handle_skipped(self, term: BoundTerm[Any]) -> None:
        """Handle the predicate case where neither is_null nor is_not_null is included."""
        # Once a term has been seen with a null check it stays in the "mentioned" set.
        if term not in self.is_valid_or_not_bound_terms:
            self.null_unmentioned_bound_terms.add(term)

    def visit_in(self, term: BoundTerm[Any], literals: Set[Any]) -> None:
        self._handle_skipped(term)

    def visit_not_in(self, term: BoundTerm[Any], literals: Set[Any]) -> None:
        self._handle_skipped(term)

    # TODO: NaN predicates are currently treated like any other predicate that does not
    # mention null; decide whether is_nan needs dedicated handling.
    def visit_is_nan(self, term: BoundTerm[Any]) -> None:
        self._handle_skipped(term)

    # TODO: same open question as visit_is_nan; dedicated handling might need two more sets
    # (NaN-mentioned and NaN-unmentioned terms).
    def visit_not_nan(self, term: BoundTerm[Any]) -> None:
        self._handle_skipped(term)

    def visit_is_null(self, term: BoundTerm[Any]) -> None:
        self._handle_explicit_is_null_or_not(term)

    def visit_not_null(self, term: BoundTerm[Any]) -> None:
        self._handle_explicit_is_null_or_not(term)

    def visit_equal(self, term: BoundTerm[Any], literal: Literal[Any]) -> None:
        self._handle_skipped(term)

    def visit_not_equal(self, term: BoundTerm[Any], literal: Literal[Any]) -> None:
        self._handle_skipped(term)

    def visit_greater_than_or_equal(self, term: BoundTerm[Any], literal: Literal[Any]) -> None:
        self._handle_skipped(term)

    def visit_greater_than(self, term: BoundTerm[Any], literal: Literal[Any]) -> None:
        self._handle_skipped(term)

    def visit_less_than(self, term: BoundTerm[Any], literal: Literal[Any]) -> None:
        self._handle_skipped(term)

    def visit_less_than_or_equal(self, term: BoundTerm[Any], literal: Literal[Any]) -> None:
        self._handle_skipped(term)

    def visit_starts_with(self, term: BoundTerm[Any], literal: Literal[Any]) -> None:
        self._handle_skipped(term)

    def visit_not_starts_with(self, term: BoundTerm[Any], literal: Literal[Any]) -> None:
        self._handle_skipped(term)

    # The remaining nodes carry no term of their own; the terms were already recorded
    # when the leaf predicates were visited, so there is nothing to do here.
    def visit_true(self) -> None:
        return

    def visit_false(self) -> None:
        return

    def visit_not(self, child_result: Any) -> None:
        return

    def visit_and(self, left_result: Any, right_result: Any) -> None:
        return

    def visit_or(self, left_result: Any, right_result: Any) -> None:
        return


def _get_is_valid_or_not_bound_refs(expr: BooleanExpression) -> tuple[Set[BoundReference[Any]], Set[BoundReference[Any]]]:
    """Collect the bound references of the expr, categorized by whether a null check mentions them.

    Args:
        expr: The boolean expression to inspect (expected to be bound, since it is
            walked with a bound-expression visitor).

    Returns:
        A tuple of two sets: first the references that never appear in an is_null or
        is_not_null predicate, then the references that appear in at least one.

    Raises:
        ValueError: If a collected bound term is not a BoundReference.
    """

    def _as_bound_refs(terms: Set[BoundTerm[Any]]) -> Set[BoundReference[Any]]:
        """Narrow the collected terms to BoundReferences, rejecting anything else."""
        bound_refs: Set[BoundReference[Any]] = set()
        for term in terms:
            if not isinstance(term, BoundReference):
                raise ValueError("Collected Bound Term that is not reference.")
            bound_refs.add(term)
        return bound_refs

    collector = _CollectIsValidPredicatesFromExpression()
    boolean_expression_visit(expr, collector)
    return (
        _as_bound_refs(collector.null_unmentioned_bound_terms),
        _as_bound_refs(collector.is_valid_or_not_bound_terms),
    )


def expression_to_pyarrow(expr: BooleanExpression) -> pc.Expression:
    """Convert a boolean expression by visiting it with the Arrow expression converter."""
    arrow_converter = _ConvertToArrowExpression()
    return boolean_expression_visit(expr, arrow_converter)


def expression_to_reverted_pyarrow(expr: BooleanExpression) -> pc.Expression:
    """Complementary filter conversion function of expression_to_pyarrow.

    Could not use expression_to_pyarrow(Not(expr)) to achieve this effect because ~ in pc.Expression does not handle null.

    The result is ``Not(expr)`` OR-ed with an ``is_null`` check for every referenced field
    that ``expr`` itself never tests with is_null / is_not_null.

    NOTE(review): a row whose null-unmentioned field is null is always kept, even when another
    branch of ``expr`` matches it (e.g. ``a == 1 OR b == 2`` with a null ``a`` and ``b == 2``).
    Confirm that this is the intended behavior.

    Args:
        expr: The bound boolean expression to complement.

    Returns:
        The PyArrow expression for the complement built as described above.
    """
    null_unmentioned_bound_refs, _ = _get_is_valid_or_not_bound_refs(expr)
    preserver_expr: BooleanExpression = Not(expr)

    # Set iteration order is not stable across runs, so sort by field name to give the
    # generated expression a deterministic layout.
    for term in sorted(null_unmentioned_bound_refs, key=lambda bound_ref: bound_ref.field.name):
        preserver_expr = Or(preserver_expr, BoundIsNull(term=term))
    return expression_to_pyarrow(preserver_expr)


@lru_cache
def _get_file_format(file_format: FileFormat, **kwargs: Dict[str, Any]) -> ds.FileFormat:
if file_format == FileFormat.PARQUET:
Expand Down
5 changes: 2 additions & 3 deletions pyiceberg/table/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,6 @@
And,
BooleanExpression,
EqualTo,
Not,
Or,
Reference,
)
Expand Down Expand Up @@ -576,7 +575,7 @@ def delete(self, delete_filter: Union[str, BooleanExpression], snapshot_properti
delete_filter: A boolean expression to delete rows from a table
snapshot_properties: Custom properties to be added to the snapshot summary
"""
from pyiceberg.io.pyarrow import _dataframe_to_data_files, expression_to_pyarrow, project_table
from pyiceberg.io.pyarrow import _dataframe_to_data_files, expression_to_reverted_pyarrow, project_table

if (
self.table_metadata.properties.get(TableProperties.DELETE_MODE, TableProperties.DELETE_MODE_DEFAULT)
Expand All @@ -593,7 +592,7 @@ def delete(self, delete_filter: Union[str, BooleanExpression], snapshot_properti
# Check if there are any files that require an actual rewrite of a data file
if delete_snapshot.rewrites_needed is True:
bound_delete_filter = bind(self._table.schema(), delete_filter, case_sensitive=True)
preserve_row_filter = expression_to_pyarrow(Not(bound_delete_filter))
preserve_row_filter = expression_to_reverted_pyarrow(bound_delete_filter)

files = self._scan(row_filter=delete_filter).plan_files()

Expand Down
75 changes: 75 additions & 0 deletions tests/integration/test_deletes.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,41 @@ def test_partitioned_table_rewrite(spark: SparkSession, session_catalog: RestCat
assert tbl.scan().to_arrow().to_pydict() == {"number_partitioned": [11, 10], "number": [30, 30]}


@pytest.mark.integration
@pytest.mark.parametrize("format_version", [1, 2])
def test_partitioned_table_rewrite_with_null(spark: SparkSession, session_catalog: RestCatalog, format_version: int) -> None:
    """Delete rows by value from a partitioned table that also holds a NULL in the filtered column.

    The table is created and populated through Spark, then rows with ``number == 20`` are
    deleted. The row whose ``number`` is NULL must still be present afterwards.
    """
    identifier = "default.table_partitioned_delete"

    run_spark_commands(
        spark,
        [
            f"DROP TABLE IF EXISTS {identifier}",
            f"""
CREATE TABLE {identifier} (
number_partitioned int,
number int
)
USING iceberg
PARTITIONED BY (number_partitioned)
TBLPROPERTIES('format-version' = {format_version})
""",
            f"""
INSERT INTO {identifier} VALUES (10, 20), (10, 30)
""",
            f"""
INSERT INTO {identifier} VALUES (11, 20), (11, NULL)
""",
        ],
    )

    tbl = session_catalog.load_table(identifier)
    tbl.delete(EqualTo("number", 20))

    # We don't delete a whole partition, so there is only an overwrite
    assert [snapshot.summary.operation.value for snapshot in tbl.snapshots()] == ["append", "append", "overwrite"]
    # The NULL row does not match the delete filter and therefore has to be preserved.
    assert tbl.scan().to_arrow().to_pydict() == {"number_partitioned": [11, 10], "number": [None, 30]}


@pytest.mark.integration
@pytest.mark.parametrize("format_version", [1, 2])
def test_partitioned_table_no_match(spark: SparkSession, session_catalog: RestCatalog, format_version: int) -> None:
Expand Down Expand Up @@ -417,3 +452,43 @@ def test_delete_truncate(session_catalog: RestCatalog) -> None:
assert len(entries) == 1

assert entries[0].status == ManifestEntryStatus.DELETED


@pytest.mark.integration
def test_delete_overwrite_with_null(session_catalog: RestCatalog) -> None:
    """Overwrite rows matching ``ints == 2`` in a table that also contains a null value.

    The null row does not match the overwrite filter, so it must still be present in the
    table after the overwrite.
    """
    arrow_schema = pa.schema([pa.field("ints", pa.int32())])
    arrow_tbl = pa.Table.from_pylist(
        [{"ints": 1}, {"ints": 2}, {"ints": None}],
        schema=arrow_schema,
    )

    iceberg_schema = Schema(NestedField(1, "ints", IntegerType()))

    tbl_identifier = "default.test_delete_overwrite_with_null"

    # Start from a clean slate; a missing table is fine.
    try:
        session_catalog.drop_table(tbl_identifier)
    except NoSuchTableError:
        pass

    tbl = session_catalog.create_table(tbl_identifier, iceberg_schema)
    tbl.append(arrow_tbl)

    assert [snapshot.summary.operation for snapshot in tbl.snapshots()] == [Operation.APPEND]

    arrow_tbl_overwrite = pa.Table.from_pylist(
        [
            {"ints": 3},
            {"ints": 4},
        ],
        schema=arrow_schema,
    )
    tbl.overwrite(arrow_tbl_overwrite, "ints == 2")  # Should rewrite one file

    assert [snapshot.summary.operation for snapshot in tbl.snapshots()] == [
        Operation.APPEND,
        Operation.OVERWRITE,
        Operation.APPEND,
    ]

    assert tbl.scan().to_arrow()["ints"].to_pylist() == [3, 4, 1, None]
25 changes: 23 additions & 2 deletions tests/io/test_pyarrow.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
# specific language governing permissions and limitations
# under the License.
# pylint: disable=protected-access,unused-argument,redefined-outer-name

Comment thread
jqin61 marked this conversation as resolved.
import os
import tempfile
import uuid
Expand Down Expand Up @@ -69,12 +68,13 @@
_to_requested_schema,
bin_pack_arrow_table,
expression_to_pyarrow,
expression_to_reverted_pyarrow,
project_table,
schema_to_pyarrow,
)
from pyiceberg.manifest import DataFile, DataFileContent, FileFormat
from pyiceberg.partitioning import PartitionField, PartitionSpec
from pyiceberg.schema import Schema, make_compatible_name, visit
from pyiceberg.schema import Accessor, Schema, make_compatible_name, visit
from pyiceberg.table import FileScanTask, TableProperties
from pyiceberg.table.metadata import TableMetadataV2
from pyiceberg.transforms import IdentityTransform
Expand Down Expand Up @@ -725,6 +725,27 @@ def test_always_false_to_pyarrow(bound_reference: BoundReference[str]) -> None:
assert repr(expression_to_pyarrow(AlwaysFalse())) == "<pyarrow.compute.Expression false>"


def test_revert_expression_to_pyarrow() -> None:
    """The reverted filter negates the expression and adds is_null only for the null-unmentioned field.

    ``field_long`` appears in an is_null predicate, so only ``field_str`` gets an extra
    ``is_null`` branch in the expected PyArrow expression.
    """
    bound_reference_str = BoundReference(
        field=NestedField(1, "field_str", StringType(), required=False), accessor=Accessor(position=0, inner=None)
    )
    bound_eq_str_field = BoundEqualTo(term=bound_reference_str, literal=literal("hello"))

    # Distinct fields get distinct field ids.
    bound_reference_long = BoundReference(
        field=NestedField(2, "field_long", LongType(), required=False), accessor=Accessor(position=1, inner=None)
    )
    bound_larger_than_long_field = BoundGreaterThan(term=bound_reference_long, literal=literal(100))  # type: ignore

    bound_is_null_long_field = BoundIsNull(bound_reference_long)

    bound_expr = Or(And(bound_eq_str_field, bound_larger_than_long_field), bound_is_null_long_field)
    result = expression_to_reverted_pyarrow(bound_expr)
    assert (
        repr(result)
        == """<pyarrow.compute.Expression (invert((((field_str == "hello") and (field_long > 100)) or is_null(field_long, {nan_is_null=false}))) or is_null(field_str, {nan_is_null=false}))>"""
    )


@pytest.fixture
def schema_int() -> Schema:
    """Provide a schema with a single optional integer field named ``id``."""
    id_field = NestedField(1, "id", IntegerType(), required=False)
    return Schema(id_field)
Expand Down
66 changes: 65 additions & 1 deletion tests/io/test_pyarrow_visitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,17 +20,27 @@
import pyarrow as pa
import pytest

from pyiceberg.expressions import (
And,
BoundEqualTo,
BoundGreaterThan,
BoundIsNull,
BoundReference,
Or,
)
from pyiceberg.expressions.literals import LongLiteral, literal
from pyiceberg.io.pyarrow import (
_ConvertToArrowSchema,
_ConvertToIceberg,
_ConvertToIcebergWithoutIDs,
_get_is_valid_or_not_bound_refs,
_HasIds,
_pyarrow_schema_ensure_large_types,
pyarrow_to_schema,
schema_to_pyarrow,
visit_pyarrow,
)
from pyiceberg.schema import Schema, visit
from pyiceberg.schema import Accessor, Schema, visit
from pyiceberg.table.name_mapping import MappedField, NameMapping
from pyiceberg.types import (
BinaryType,
Expand Down Expand Up @@ -580,3 +590,57 @@ def test_pyarrow_schema_ensure_large_types(pyarrow_schema_nested_without_ids: pa
),
])
assert _pyarrow_schema_ensure_large_types(pyarrow_schema_nested_without_ids) == expected_schema


@pytest.fixture
def bound_reference_long() -> BoundReference[int]:
    """Provide a bound reference to an optional long field named ``field`` at position 0."""
    long_field = NestedField(1, "field", LongType(), required=False)
    position_accessor = Accessor(position=0, inner=None)
    return BoundReference(field=long_field, accessor=position_accessor)


def test_collect_null_mentioned_terms() -> None:
    """References are split by whether an is_null predicate mentions them."""
    str_ref = BoundReference(
        field=NestedField(1, "field_str", StringType(), required=False), accessor=Accessor(position=0, inner=None)
    )
    long_ref = BoundReference(
        field=NestedField(2, "field_long", LongType(), required=False), accessor=Accessor(position=1, inner=None)
    )
    bool_ref = BoundReference(
        field=NestedField(3, "field_bool", BooleanType(), required=False), accessor=Accessor(position=2, inner=None)
    )

    str_equals_hello = BoundEqualTo(term=str_ref, literal=literal("hello"))
    long_above_100 = BoundGreaterThan(term=long_ref, literal=literal(100))  # type: ignore
    bool_is_null = BoundIsNull(bool_ref)

    # (field_str == "hello" AND field_long > 100) OR field_bool IS NULL
    null_unmentioned, null_mentioned = _get_is_valid_or_not_bound_refs(
        Or(And(str_equals_hello, long_above_100), bool_is_null)
    )

    assert {ref.field.name for ref in null_unmentioned} == {"field_long", "field_str"}
    assert {ref.field.name for ref in null_mentioned} == {"field_bool"}


def test_collect_null_mentioned_terms_with_multiple_predicates_on_the_same_term() -> None:
    """Test a single term appearing in multiple places in the expression tree.

    ``field_long`` appears in both a comparison and an is_null predicate, so it must be
    reported only as null-mentioned, whatever order the predicates are visited in.
    """
    bound_reference_str = BoundReference(
        field=NestedField(1, "field_str", StringType(), required=False), accessor=Accessor(position=0, inner=None)
    )
    bound_eq_str_field = BoundEqualTo(term=bound_reference_str, literal=literal("hello"))

    # Distinct fields get distinct field ids.
    bound_reference_long = BoundReference(
        field=NestedField(2, "field_long", LongType(), required=False), accessor=Accessor(position=1, inner=None)
    )
    bound_larger_than_long_field = BoundGreaterThan(term=bound_reference_long, literal=literal(100))  # type: ignore

    bound_is_null_long_field = BoundIsNull(bound_reference_long)

    bound_expr = Or(
        And(Or(And(bound_eq_str_field, bound_larger_than_long_field), bound_is_null_long_field), bound_larger_than_long_field),
        bound_eq_str_field,
    )

    categorized_terms = _get_is_valid_or_not_bound_refs(bound_expr)
    assert {"field_str"} == {f.field.name for f in categorized_terms[0]}
    assert {"field_long"} == {f.field.name for f in categorized_terms[1]}