Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions pyiceberg/expressions/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,10 @@ def __repr__(self) -> str:
def ref(self) -> BoundReference[L]:
return self

def __hash__(self) -> int:
    """Return hash value of the BoundReference class.

    Hashes the string form so that references with the same textual
    representation hash identically.
    """
    representation = str(self)
    return hash(representation)


class UnboundTerm(Term[Any], Unbound[BoundTerm[L]], ABC):
"""Represents an unbound term."""
Expand Down
152 changes: 145 additions & 7 deletions pyiceberg/io/pyarrow.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,11 +73,7 @@

from pyiceberg.conversions import to_bytes
from pyiceberg.exceptions import ResolveError
from pyiceberg.expressions import (
AlwaysTrue,
BooleanExpression,
BoundTerm,
)
from pyiceberg.expressions import AlwaysTrue, BooleanExpression, BoundIsNaN, BoundIsNull, BoundTerm, Not, Or
from pyiceberg.expressions.literals import Literal
from pyiceberg.expressions.visitors import (
BoundBooleanExpressionVisitor,
Expand Down Expand Up @@ -576,11 +572,11 @@ def _convert_scalar(value: Any, iceberg_type: IcebergType) -> pa.scalar:


class _ConvertToArrowExpression(BoundBooleanExpressionVisitor[pc.Expression]):
def visit_in(self, term: BoundTerm[pc.Expression], literals: Set[Any]) -> pc.Expression:
def visit_in(self, term: BoundTerm[Any], literals: Set[Any]) -> pc.Expression:
    """Translate an IN predicate into a pyarrow compute expression."""
    field = term.ref().field
    value_array = pa.array(literals, type=schema_to_pyarrow(field.field_type))
    return pc.field(field.name).isin(value_array)

def visit_not_in(self, term: BoundTerm[pc.Expression], literals: Set[Any]) -> pc.Expression:
def visit_not_in(self, term: BoundTerm[Any], literals: Set[Any]) -> pc.Expression:
    """Translate a NOT IN predicate into a negated pyarrow compute expression."""
    field = term.ref().field
    value_array = pa.array(literals, type=schema_to_pyarrow(field.field_type))
    return ~pc.field(field.name).isin(value_array)

Expand Down Expand Up @@ -638,10 +634,152 @@ def visit_or(self, left_result: pc.Expression, right_result: pc.Expression) -> p
return left_result | right_result


class _NullNaNUnmentionedTermsCollector(BoundBooleanExpressionVisitor[None]):
    """Collect the bound terms of a boolean expression, categorized by null/NaN mention.

    After :meth:`collect` runs, the four sets below partition the terms seen in the
    expression along two independent axes:

    * null axis: terms appearing in an explicit ``is_null``/``is_not_null`` predicate
      vs. terms for which null-ness is never mentioned.
    * NaN axis: terms appearing in an explicit ``is_nan``/``is_not_nan`` predicate
      vs. terms for which NaN-ness is never mentioned.

    An explicit mention always wins: once a term has been seen in an explicit
    predicate it is removed from, and never re-added to, the corresponding
    "unmentioned" set.
    """

    # BoundTerms which have either is_null or is_not_null appearing at least once in the boolean expr.
    is_null_or_not_bound_terms: Set[BoundTerm[Any]]
    # The remaining BoundTerms appearing in the boolean expr.
    null_unmentioned_bound_terms: Set[BoundTerm[Any]]
    # BoundTerms which have either is_nan or is_not_nan appearing at least once in the boolean expr.
    is_nan_or_not_bound_terms: Set[BoundTerm[Any]]
    # The remaining BoundTerms appearing in the boolean expr.
    nan_unmentioned_bound_terms: Set[BoundTerm[Any]]

    def __init__(self) -> None:
        super().__init__()
        self.is_null_or_not_bound_terms = set()
        self.null_unmentioned_bound_terms = set()
        self.is_nan_or_not_bound_terms = set()
        self.nan_unmentioned_bound_terms = set()

    def _handle_explicit_is_null_or_not(self, term: BoundTerm[Any]) -> None:
        """Handle the predicate case where either is_null or is_not_null is included."""
        # discard is a no-op when the term was never recorded as unmentioned.
        self.null_unmentioned_bound_terms.discard(term)
        self.is_null_or_not_bound_terms.add(term)

    def _handle_null_unmentioned(self, term: BoundTerm[Any]) -> None:
        """Handle the predicate case where neither is_null nor is_not_null is included."""
        if term not in self.is_null_or_not_bound_terms:
            self.null_unmentioned_bound_terms.add(term)

    def _handle_explicit_is_nan_or_not(self, term: BoundTerm[Any]) -> None:
        """Handle the predicate case where either is_nan or is_not_nan is included."""
        # discard is a no-op when the term was never recorded as unmentioned.
        self.nan_unmentioned_bound_terms.discard(term)
        self.is_nan_or_not_bound_terms.add(term)

    def _handle_nan_unmentioned(self, term: BoundTerm[Any]) -> None:
        """Handle the predicate case where neither is_nan nor is_not_nan is included."""
        if term not in self.is_nan_or_not_bound_terms:
            self.nan_unmentioned_bound_terms.add(term)

    def _handle_plain_predicate_term(self, term: BoundTerm[Any]) -> None:
        """Handle a term from any predicate that mentions neither null-ness nor NaN-ness.

        Shared by every visitor below that is not an explicit
        is_null/is_not_null/is_nan/is_not_nan check.
        """
        self._handle_null_unmentioned(term)
        self._handle_nan_unmentioned(term)

    def visit_in(self, term: BoundTerm[Any], literals: Set[Any]) -> None:
        self._handle_plain_predicate_term(term)

    def visit_not_in(self, term: BoundTerm[Any], literals: Set[Any]) -> None:
        self._handle_plain_predicate_term(term)

    def visit_is_nan(self, term: BoundTerm[Any]) -> None:
        self._handle_null_unmentioned(term)
        self._handle_explicit_is_nan_or_not(term)

    def visit_not_nan(self, term: BoundTerm[Any]) -> None:
        self._handle_null_unmentioned(term)
        self._handle_explicit_is_nan_or_not(term)

    def visit_is_null(self, term: BoundTerm[Any]) -> None:
        self._handle_explicit_is_null_or_not(term)
        self._handle_nan_unmentioned(term)

    def visit_not_null(self, term: BoundTerm[Any]) -> None:
        self._handle_explicit_is_null_or_not(term)
        self._handle_nan_unmentioned(term)

    def visit_equal(self, term: BoundTerm[Any], literal: Literal[Any]) -> None:
        self._handle_plain_predicate_term(term)

    def visit_not_equal(self, term: BoundTerm[Any], literal: Literal[Any]) -> None:
        self._handle_plain_predicate_term(term)

    def visit_greater_than_or_equal(self, term: BoundTerm[Any], literal: Literal[Any]) -> None:
        self._handle_plain_predicate_term(term)

    def visit_greater_than(self, term: BoundTerm[Any], literal: Literal[Any]) -> None:
        self._handle_plain_predicate_term(term)

    def visit_less_than(self, term: BoundTerm[Any], literal: Literal[Any]) -> None:
        self._handle_plain_predicate_term(term)

    def visit_less_than_or_equal(self, term: BoundTerm[Any], literal: Literal[Any]) -> None:
        self._handle_plain_predicate_term(term)

    def visit_starts_with(self, term: BoundTerm[Any], literal: Literal[Any]) -> None:
        self._handle_plain_predicate_term(term)

    def visit_not_starts_with(self, term: BoundTerm[Any], literal: Literal[Any]) -> None:
        self._handle_plain_predicate_term(term)

    def visit_true(self) -> None:
        # No terms involved.
        return

    def visit_false(self) -> None:
        # No terms involved.
        return

    def visit_not(self, child_result: None) -> None:
        # Terms were already categorized while visiting the child.
        return

    def visit_and(self, left_result: None, right_result: None) -> None:
        # Terms were already categorized while visiting the children.
        return

    def visit_or(self, left_result: None, right_result: None) -> None:
        # Terms were already categorized while visiting the children.
        return

    def collect(
        self,
        expr: BooleanExpression,
    ) -> None:
        """Collect the bound references categorized by having at least one is_null or is_not_null in the expr and the remaining."""
        boolean_expression_visit(expr, self)


def expression_to_pyarrow(expr: BooleanExpression) -> pc.Expression:
    """Convert a bound Iceberg boolean expression into a pyarrow compute expression."""
    visitor = _ConvertToArrowExpression()
    return boolean_expression_visit(expr, visitor)


def _expression_to_complementary_pyarrow(expr: BooleanExpression) -> pc.Expression:
    """Complementary filter conversion function of expression_to_pyarrow.

    Could not use expression_to_pyarrow(Not(expr)) to achieve this complementary effect because ~ in pyarrow.compute.Expression does not handle null.
    """
    collector = _NullNaNUnmentionedTermsCollector()
    collector.collect(expr)

    def _sorted_by_field_name(terms: Set[BoundTerm[Any]]) -> List[BoundTerm[Any]]:
        # Sorting keeps the layout of the built expression deterministic.
        return sorted(terms, key=lambda t: t.ref().field.name)

    # Start from the plain negation, then explicitly preserve rows whose
    # null-ness / NaN-ness the original expression never mentioned.
    preserve_expr: BooleanExpression = Not(expr)
    for null_term in _sorted_by_field_name(collector.null_unmentioned_bound_terms):
        preserve_expr = Or(preserve_expr, BoundIsNull(term=null_term))
    for nan_term in _sorted_by_field_name(collector.nan_unmentioned_bound_terms):
        preserve_expr = Or(preserve_expr, BoundIsNaN(term=nan_term))
    return expression_to_pyarrow(preserve_expr)


@lru_cache
def _get_file_format(file_format: FileFormat, **kwargs: Dict[str, Any]) -> ds.FileFormat:
if file_format == FileFormat.PARQUET:
Expand Down
9 changes: 6 additions & 3 deletions pyiceberg/table/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,6 @@
And,
BooleanExpression,
EqualTo,
Not,
Or,
Reference,
)
Expand Down Expand Up @@ -576,7 +575,11 @@ def delete(self, delete_filter: Union[str, BooleanExpression], snapshot_properti
delete_filter: A boolean expression to delete rows from a table
snapshot_properties: Custom properties to be added to the snapshot summary
"""
from pyiceberg.io.pyarrow import _dataframe_to_data_files, expression_to_pyarrow, project_table
from pyiceberg.io.pyarrow import (
_dataframe_to_data_files,
_expression_to_complementary_pyarrow,
project_table,
)

if (
self.table_metadata.properties.get(TableProperties.DELETE_MODE, TableProperties.DELETE_MODE_DEFAULT)
Expand All @@ -593,7 +596,7 @@ def delete(self, delete_filter: Union[str, BooleanExpression], snapshot_properti
# Check if there are any files that require an actual rewrite of a data file
if delete_snapshot.rewrites_needed is True:
bound_delete_filter = bind(self._table.schema(), delete_filter, case_sensitive=True)
preserve_row_filter = expression_to_pyarrow(Not(bound_delete_filter))
preserve_row_filter = _expression_to_complementary_pyarrow(bound_delete_filter)

files = self._scan(row_filter=delete_filter).plan_files()

Expand Down
138 changes: 137 additions & 1 deletion tests/integration/test_deletes.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
from pyiceberg.manifest import ManifestEntryStatus
from pyiceberg.schema import Schema
from pyiceberg.table.snapshots import Operation, Summary
from pyiceberg.types import IntegerType, NestedField
from pyiceberg.types import FloatType, IntegerType, NestedField


def run_spark_commands(spark: SparkSession, sqls: List[str]) -> None:
Expand Down Expand Up @@ -105,6 +105,40 @@ def test_partitioned_table_rewrite(spark: SparkSession, session_catalog: RestCat
assert tbl.scan().to_arrow().to_pydict() == {"number_partitioned": [11, 10], "number": [30, 30]}


@pytest.mark.integration
@pytest.mark.parametrize("format_version", [1, 2])
def test_rewrite_partitioned_table_with_null(spark: SparkSession, session_catalog: RestCatalog, format_version: int) -> None:
    """Deleting by an equality filter must preserve NULL-valued rows in rewritten files."""
    identifier = "default.table_partitioned_delete"

    run_spark_commands(
        spark,
        [
            f"DROP TABLE IF EXISTS {identifier}",
            f"""
            CREATE TABLE {identifier} (
                number_partitioned int,
                number int
            )
            USING iceberg
            PARTITIONED BY (number_partitioned)
            TBLPROPERTIES('format-version' = {format_version})
            """,
            f"""
            INSERT INTO {identifier} VALUES (10, 20), (10, 30)
            """,
            f"""
            INSERT INTO {identifier} VALUES (11, 20), (11, NULL)
            """,
        ],
    )

    tbl = session_catalog.load_table(identifier)
    tbl.delete(EqualTo("number", 20))

    # We don't delete a whole partition, so there is only a overwrite
    assert [snapshot.summary.operation.value for snapshot in tbl.snapshots()] == ["append", "append", "overwrite"]
    # The (11, NULL) row must survive the file rewrite.
    assert tbl.scan().to_arrow().to_pydict() == {"number_partitioned": [11, 10], "number": [None, 30]}


@pytest.mark.integration
@pytest.mark.parametrize("format_version", [1, 2])
def test_partitioned_table_no_match(spark: SparkSession, session_catalog: RestCatalog, format_version: int) -> None:
Expand Down Expand Up @@ -417,3 +451,105 @@ def test_delete_truncate(session_catalog: RestCatalog) -> None:
assert len(entries) == 1

assert entries[0].status == ManifestEntryStatus.DELETED


@pytest.mark.integration
def test_delete_overwrite_table_with_null(session_catalog: RestCatalog) -> None:
    """Overwrite with a row filter must keep NULL rows the filter does not match.

    Exercises the complementary filter used to select surviving rows when a
    data file is rewritten during an overwrite.
    """
    arrow_schema = pa.schema([pa.field("ints", pa.int32())])
    arrow_tbl = pa.Table.from_pylist(
        [{"ints": 1}, {"ints": 2}, {"ints": None}],
        schema=arrow_schema,
    )

    iceberg_schema = Schema(NestedField(1, "ints", IntegerType()))

    tbl_identifier = "default.test_delete_overwrite_with_null"

    try:
        session_catalog.drop_table(tbl_identifier)
    except NoSuchTableError:
        pass

    tbl = session_catalog.create_table(tbl_identifier, iceberg_schema)
    tbl.append(arrow_tbl)

    assert [snapshot.summary.operation for snapshot in tbl.snapshots()] == [Operation.APPEND]

    arrow_tbl_overwrite = pa.Table.from_pylist(
        [
            {"ints": 3},
            {"ints": 4},
        ],
        schema=arrow_schema,
    )
    tbl.overwrite(arrow_tbl_overwrite, "ints == 2")  # Should rewrite one file

    assert [snapshot.summary.operation for snapshot in tbl.snapshots()] == [
        Operation.APPEND,
        Operation.OVERWRITE,
        Operation.APPEND,
    ]

    # The NULL row is unmatched by "ints == 2" and must survive the rewrite.
    assert tbl.scan().to_arrow()["ints"].to_pylist() == [3, 4, 1, None]


@pytest.mark.integration
def test_delete_overwrite_table_with_nan(session_catalog: RestCatalog) -> None:
    """Overwrite with a row filter must keep NaN rows the filter does not match.

    We want to test that _expression_to_complementary_pyarrow generates a correct
    complementary filter for selecting records that remain in the newly written
    file. Compared with test_delete_overwrite_table_with_null, which tests rows
    with null cells, NaN testing faces a trickier issue: a filter of
    (field == value) will not include NaN cells, but (field != value) will.
    (Interestingly, neither == nor != includes null.)

    This means that with a "floats == 2.0" equality predicate (as used in
    test_delete_overwrite_table_with_null) the test would pass even without the
    is_null/not_null/is_nan/not_nan handling in _NullNaNUnmentionedTermsCollector
    (a helper of _expression_to_complementary_pyarrow). Instead we filter with
    !=, so that the complement behaves like == and exposes the issue.
    """
    from math import isnan

    arrow_schema = pa.schema([pa.field("floats", pa.float32())])

    # Create Arrow Table with NaN values
    data = [pa.array([1.0, float("nan"), 2.0], type=pa.float32())]
    arrow_tbl = pa.Table.from_arrays(
        data,
        schema=arrow_schema,
    )

    iceberg_schema = Schema(NestedField(1, "floats", FloatType()))

    tbl_identifier = "default.test_delete_overwrite_with_nan"

    try:
        session_catalog.drop_table(tbl_identifier)
    except NoSuchTableError:
        pass

    tbl = session_catalog.create_table(tbl_identifier, iceberg_schema)
    tbl.append(arrow_tbl)

    assert [snapshot.summary.operation for snapshot in tbl.snapshots()] == [Operation.APPEND]

    arrow_tbl_overwrite = pa.Table.from_pylist(
        [
            {"floats": 3.0},
            {"floats": 4.0},
        ],
        schema=arrow_schema,
    )
    tbl.overwrite(arrow_tbl_overwrite, "floats != 2.0")  # Should rewrite one file

    assert [snapshot.summary.operation for snapshot in tbl.snapshots()] == [
        Operation.APPEND,
        Operation.OVERWRITE,
        Operation.APPEND,
    ]

    result = tbl.scan().to_arrow()["floats"].to_pylist()

    # NaN and 2.0 were not matched by the delete filter's complement handling
    # and must survive alongside the newly appended 3.0 and 4.0.
    assert any(isnan(e) for e in result)
    assert 2.0 in result
    assert 3.0 in result
    assert 4.0 in result
Loading