diff --git a/pyiceberg/partitioning.py b/pyiceberg/partitioning.py
index 481207db7a..da52d5df8e 100644
--- a/pyiceberg/partitioning.py
+++ b/pyiceberg/partitioning.py
@@ -387,7 +387,7 @@ def partition(self) -> Record:  # partition key transformed with iceberg interna
         for raw_partition_field_value in self.raw_partition_field_values:
             partition_fields = self.partition_spec.source_id_to_fields_map[raw_partition_field_value.field.source_id]
             if len(partition_fields) != 1:
-                raise ValueError("partition_fields must contain exactly one field.")
+                raise ValueError(f"Cannot have redundant partitions: {partition_fields}")
             partition_field = partition_fields[0]
             iceberg_typed_key_values[partition_field.name] = partition_record_value(
                 partition_field=partition_field,
diff --git a/pyiceberg/table/__init__.py b/pyiceberg/table/__init__.py
index aa108de08b..f160ab2441 100644
--- a/pyiceberg/table/__init__.py
+++ b/pyiceberg/table/__init__.py
@@ -392,10 +392,11 @@ def append(self, df: pa.Table, snapshot_properties: Dict[str, str] = EMPTY_DICT)
         if not isinstance(df, pa.Table):
             raise ValueError(f"Expected PyArrow table, got: {df}")
 
-        supported_transforms = {IdentityTransform}
-        if not all(type(field.transform) in supported_transforms for field in self.table_metadata.spec().fields):
+        if unsupported_partitions := [
+            field for field in self.table_metadata.spec().fields if not field.transform.supports_pyarrow_transform
+        ]:
             raise ValueError(
-                f"All transforms are not supported, expected: {supported_transforms}, but get: {[str(field) for field in self.table_metadata.spec().fields if field.transform not in supported_transforms]}."
+                f"Not all partition types are supported for writes. Following partitions cannot be written using pyarrow: {unsupported_partitions}."
             )
 
         _check_schema_compatible(self._table.schema(), other_schema=df.schema)
@@ -3643,33 +3644,6 @@ class TablePartition:
     arrow_table_partition: pa.Table
 
 
-def _get_partition_sort_order(partition_columns: list[str], reverse: bool = False) -> dict[str, Any]:
-    order = "ascending" if not reverse else "descending"
-    null_placement = "at_start" if reverse else "at_end"
-    return {"sort_keys": [(column_name, order) for column_name in partition_columns], "null_placement": null_placement}
-
-
-def group_by_partition_scheme(arrow_table: pa.Table, partition_columns: list[str]) -> pa.Table:
-    """Given a table, sort it by current partition scheme."""
-    # only works for identity for now
-    sort_options = _get_partition_sort_order(partition_columns, reverse=False)
-    sorted_arrow_table = arrow_table.sort_by(sorting=sort_options["sort_keys"], null_placement=sort_options["null_placement"])
-    return sorted_arrow_table
-
-
-def get_partition_columns(
-    spec: PartitionSpec,
-    schema: Schema,
-) -> list[str]:
-    partition_cols = []
-    for partition_field in spec.fields:
-        column_name = schema.find_column_name(partition_field.source_id)
-        if not column_name:
-            raise ValueError(f"{partition_field=} could not be found in {schema}.")
-        partition_cols.append(column_name)
-    return partition_cols
-
-
 def _get_table_partitions(
     arrow_table: pa.Table,
     partition_spec: PartitionSpec,
@@ -3724,13 +3698,30 @@ def _determine_partitions(spec: PartitionSpec, schema: Schema, arrow_table: pa.T
     """
     import pyarrow as pa
 
-    partition_columns = get_partition_columns(spec=spec, schema=schema)
-    arrow_table = group_by_partition_scheme(arrow_table, partition_columns)
-
-    reversing_sort_order_options = _get_partition_sort_order(partition_columns, reverse=True)
-    reversed_indices = pa.compute.sort_indices(arrow_table, **reversing_sort_order_options).to_pylist()
-
-    slice_instructions: list[dict[str, Any]] = []
+    partition_columns: List[Tuple[PartitionField, NestedField]] = [
+        (partition_field, schema.find_field(partition_field.source_id)) for partition_field in spec.fields
+    ]
+    partition_values_table = pa.table({
+        str(partition.field_id): partition.transform.pyarrow_transform(field.field_type)(arrow_table[field.name])
+        for partition, field in partition_columns
+    })
+
+    # Sort by partitions
+    sort_indices = pa.compute.sort_indices(
+        partition_values_table,
+        sort_keys=[(col, "ascending") for col in partition_values_table.column_names],
+        null_placement="at_end",
+    ).to_pylist()
+    arrow_table = arrow_table.take(sort_indices)
+
+    # Get slice_instructions to group by partitions
+    partition_values_table = partition_values_table.take(sort_indices)
+    reversed_indices = pa.compute.sort_indices(
+        partition_values_table,
+        sort_keys=[(col, "descending") for col in partition_values_table.column_names],
+        null_placement="at_start",
+    ).to_pylist()
+    slice_instructions: List[Dict[str, Any]] = []
     last = len(reversed_indices)
     reversed_indices_size = len(reversed_indices)
     ptr = 0
@@ -3741,6 +3732,6 @@ def _determine_partitions(spec: PartitionSpec, schema: Schema, arrow_table: pa.T
             last = reversed_indices[ptr]
         ptr = ptr + group_size
 
-    table_partitions: list[TablePartition] = _get_table_partitions(arrow_table, spec, schema, slice_instructions)
+    table_partitions: List[TablePartition] = _get_table_partitions(arrow_table, spec, schema, slice_instructions)
 
     return table_partitions
diff --git a/pyiceberg/transforms.py b/pyiceberg/transforms.py
index 6dcae59e49..38cc6221a2 100644
--- a/pyiceberg/transforms.py
+++ b/pyiceberg/transforms.py
@@ -20,7 +20,7 @@
 from abc import ABC, abstractmethod
 from enum import IntEnum
 from functools import singledispatch
-from typing import Any, Callable, Generic, Optional, TypeVar
+from typing import TYPE_CHECKING, Any, Callable, Generic, Optional, TypeVar
 from typing import Literal as LiteralType
 from uuid import UUID
 
@@ -82,6 +82,9 @@
 from pyiceberg.utils.parsing import ParseNumberFromBrackets
 from pyiceberg.utils.singleton import Singleton
 
+if TYPE_CHECKING:
+    import pyarrow as pa
+
 S = TypeVar("S")
 T = TypeVar("T")
 
@@ -175,6 +178,13 @@ def __eq__(self, other: Any) -> bool:
             return self.root == other.root
         return False
 
+    @property
+    def supports_pyarrow_transform(self) -> bool:
+        return False
+
+    @abstractmethod
+    def pyarrow_transform(self, source: IcebergType) -> "Callable[[pa.Array], pa.Array]": ...
+
 
 class BucketTransform(Transform[S, int]):
     """Base Transform class to transform a value into a bucket partition value.
@@ -290,6 +300,9 @@ def __repr__(self) -> str:
         """Return the string representation of the BucketTransform class."""
         return f"BucketTransform(num_buckets={self._num_buckets})"
 
+    def pyarrow_transform(self, source: IcebergType) -> "Callable[[pa.Array], pa.Array]":
+        raise NotImplementedError()
+
 
 class TimeResolution(IntEnum):
     YEAR = 6
@@ -349,6 +362,10 @@ def dedup_name(self) -> str:
     def preserves_order(self) -> bool:
         return True
 
+    @property
+    def supports_pyarrow_transform(self) -> bool:
+        return True
+
 
 class YearTransform(TimeTransform[S]):
     """Transforms a datetime value into a year value.
@@ -391,6 +408,21 @@ def __repr__(self) -> str:
         """Return the string representation of the YearTransform class."""
         return "YearTransform()"
 
+    def pyarrow_transform(self, source: IcebergType) -> "Callable[[pa.Array], pa.Array]":
+        import pyarrow as pa
+        import pyarrow.compute as pc
+
+        if isinstance(source, DateType):
+            epoch = datetime.EPOCH_DATE
+        elif isinstance(source, TimestampType):
+            epoch = datetime.EPOCH_TIMESTAMP
+        elif isinstance(source, TimestamptzType):
+            epoch = datetime.EPOCH_TIMESTAMPTZ
+        else:
+            raise ValueError(f"Cannot apply year transform for type: {source}")
+
+        return lambda v: pc.years_between(pa.scalar(epoch), v) if v is not None else None
+
 
 class MonthTransform(TimeTransform[S]):
     """Transforms a datetime value into a month value.
@@ -433,6 +465,27 @@ def __repr__(self) -> str:
         """Return the string representation of the MonthTransform class."""
         return "MonthTransform()"
 
+    def pyarrow_transform(self, source: IcebergType) -> "Callable[[pa.Array], pa.Array]":
+        import pyarrow as pa
+        import pyarrow.compute as pc
+
+        if isinstance(source, DateType):
+            epoch = datetime.EPOCH_DATE
+        elif isinstance(source, TimestampType):
+            epoch = datetime.EPOCH_TIMESTAMP
+        elif isinstance(source, TimestamptzType):
+            epoch = datetime.EPOCH_TIMESTAMPTZ
+        else:
+            raise ValueError(f"Cannot apply month transform for type: {source}")
+
+        def month_func(v: pa.Array) -> pa.Array:
+            return pc.add(
+                pc.multiply(pc.years_between(pa.scalar(epoch), v), pa.scalar(12)),
+                pc.add(pc.month(v), pa.scalar(-1)),
+            )
+
+        return lambda v: month_func(v) if v is not None else None
+
 
 class DayTransform(TimeTransform[S]):
     """Transforms a datetime value into a day value.
@@ -478,6 +531,21 @@ def __repr__(self) -> str:
         """Return the string representation of the DayTransform class."""
         return "DayTransform()"
 
+    def pyarrow_transform(self, source: IcebergType) -> "Callable[[pa.Array], pa.Array]":
+        import pyarrow as pa
+        import pyarrow.compute as pc
+
+        if isinstance(source, DateType):
+            epoch = datetime.EPOCH_DATE
+        elif isinstance(source, TimestampType):
+            epoch = datetime.EPOCH_TIMESTAMP
+        elif isinstance(source, TimestamptzType):
+            epoch = datetime.EPOCH_TIMESTAMPTZ
+        else:
+            raise ValueError(f"Cannot apply day transform for type: {source}")
+
+        return lambda v: pc.days_between(pa.scalar(epoch), v) if v is not None else None
+
 
 class HourTransform(TimeTransform[S]):
     """Transforms a datetime value into a hour value.
@@ -515,6 +583,19 @@ def __repr__(self) -> str:
         """Return the string representation of the HourTransform class."""
         return "HourTransform()"
 
+    def pyarrow_transform(self, source: IcebergType) -> "Callable[[pa.Array], pa.Array]":
+        import pyarrow as pa
+        import pyarrow.compute as pc
+
+        if isinstance(source, TimestampType):
+            epoch = datetime.EPOCH_TIMESTAMP
+        elif isinstance(source, TimestamptzType):
+            epoch = datetime.EPOCH_TIMESTAMPTZ
+        else:
+            raise ValueError(f"Cannot apply hour transform for type: {source}")
+
+        return lambda v: pc.hours_between(pa.scalar(epoch), v) if v is not None else None
+
 
 def _base64encode(buffer: bytes) -> str:
     """Convert bytes to base64 string."""
@@ -585,6 +666,13 @@ def __repr__(self) -> str:
         """Return the string representation of the IdentityTransform class."""
         return "IdentityTransform()"
 
+    def pyarrow_transform(self, source: IcebergType) -> "Callable[[pa.Array], pa.Array]":
+        return lambda v: v
+
+    @property
+    def supports_pyarrow_transform(self) -> bool:
+        return True
+
 
 class TruncateTransform(Transform[S, S]):
     """A transform for truncating a value to a specified width.
@@ -725,6 +813,9 @@ def __repr__(self) -> str:
         """Return the string representation of the TruncateTransform class."""
         return f"TruncateTransform(width={self._width})"
 
+    def pyarrow_transform(self, source: IcebergType) -> "Callable[[pa.Array], pa.Array]":
+        raise NotImplementedError()
+
 
 @singledispatch
 def _human_string(value: Any, _type: IcebergType) -> str:
@@ -807,6 +898,9 @@ def __repr__(self) -> str:
         """Return the string representation of the UnknownTransform class."""
         return f"UnknownTransform(transform={repr(self._transform)})"
 
+    def pyarrow_transform(self, source: IcebergType) -> "Callable[[pa.Array], pa.Array]":
+        raise NotImplementedError()
+
 
 class VoidTransform(Transform[S, None], Singleton):
     """A transform that always returns None."""
@@ -835,6 +929,9 @@ def __repr__(self) -> str:
         """Return the string representation of the VoidTransform class."""
         return "VoidTransform()"
 
+    def pyarrow_transform(self, source: IcebergType) -> "Callable[[pa.Array], pa.Array]":
+        raise NotImplementedError()
+
 
 def _truncate_number(
     name: str, pred: BoundLiteralPredicate[L], transform: Callable[[Optional[L]], Optional[L]]
diff --git a/tests/conftest.py b/tests/conftest.py
index 01915b7d82..d3f23689a2 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -2158,3 +2158,46 @@ def arrow_table_with_only_nulls(pa_schema: "pa.Schema") -> "pa.Table":
     import pyarrow as pa
 
     return pa.Table.from_pylist([{}, {}], schema=pa_schema)
+
+
+@pytest.fixture(scope="session")
+def arrow_table_date_timestamps() -> "pa.Table":
+    """Pyarrow table with only date, timestamp and timestamptz values."""
+    import pyarrow as pa
+
+    return pa.Table.from_pydict(
+        {
+            "date": [date(2023, 12, 31), date(2024, 1, 1), date(2024, 1, 31), date(2024, 2, 1), date(2024, 2, 1), None],
+            "timestamp": [
+                datetime(2023, 12, 31, 0, 0, 0),
+                datetime(2024, 1, 1, 0, 0, 0),
+                datetime(2024, 1, 31, 0, 0, 0),
+                datetime(2024, 2, 1, 0, 0, 0),
+                datetime(2024, 2, 1, 6, 0, 0),
+                None,
+            ],
+            "timestamptz": [
+                datetime(2023, 12, 31, 0, 0, 0, tzinfo=timezone.utc),
+                datetime(2024, 1, 1, 0, 0, 0, tzinfo=timezone.utc),
+                datetime(2024, 1, 31, 0, 0, 0, tzinfo=timezone.utc),
+                datetime(2024, 2, 1, 0, 0, 0, tzinfo=timezone.utc),
+                datetime(2024, 2, 1, 6, 0, 0, tzinfo=timezone.utc),
+                None,
+            ],
+        },
+        schema=pa.schema([
+            ("date", pa.date32()),
+            ("timestamp", pa.timestamp(unit="us")),
+            ("timestamptz", pa.timestamp(unit="us", tz="UTC")),
+        ]),
+    )
+
+
+@pytest.fixture(scope="session")
+def arrow_table_date_timestamps_schema() -> Schema:
+    """Pyarrow table Schema with only date, timestamp and timestamptz values."""
+    return Schema(
+        NestedField(field_id=1, name="date", field_type=DateType(), required=False),
+        NestedField(field_id=2, name="timestamp", field_type=TimestampType(), required=False),
+        NestedField(field_id=3, name="timestamptz", field_type=TimestamptzType(), required=False),
+    )
diff --git a/tests/integration/test_writes/test_partitioned_writes.py b/tests/integration/test_writes/test_partitioned_writes.py
index 5cb03e59d8..76d559ca57 100644
--- a/tests/integration/test_writes/test_partitioned_writes.py
+++ b/tests/integration/test_writes/test_partitioned_writes.py
@@ -16,6 +16,10 @@
 # under the License.
 # pylint:disable=redefined-outer-name
 
+
+from datetime import date
+from typing import Any, Set
+
 import pyarrow as pa
 import pytest
 from pyspark.sql import SparkSession
@@ -23,12 +27,14 @@
 from pyiceberg.catalog import Catalog
 from pyiceberg.exceptions import NoSuchTableError
 from pyiceberg.partitioning import PartitionField, PartitionSpec
+from pyiceberg.schema import Schema
 from pyiceberg.transforms import (
     BucketTransform,
     DayTransform,
     HourTransform,
     IdentityTransform,
     MonthTransform,
+    Transform,
     TruncateTransform,
     YearTransform,
 )
@@ -351,18 +357,6 @@ def test_invalid_arguments(spark: SparkSession, session_catalog: Catalog) -> Non
         (PartitionSpec(PartitionField(source_id=5, field_id=1001, transform=TruncateTransform(2), name="long_trunc"))),
         (PartitionSpec(PartitionField(source_id=2, field_id=1001, transform=TruncateTransform(2), name="string_trunc"))),
         (PartitionSpec(PartitionField(source_id=11, field_id=1001, transform=TruncateTransform(2), name="binary_trunc"))),
-        (PartitionSpec(PartitionField(source_id=8, field_id=1001, transform=YearTransform(), name="timestamp_year"))),
-        (PartitionSpec(PartitionField(source_id=9, field_id=1001, transform=YearTransform(), name="timestamptz_year"))),
-        (PartitionSpec(PartitionField(source_id=10, field_id=1001, transform=YearTransform(), name="date_year"))),
-        (PartitionSpec(PartitionField(source_id=8, field_id=1001, transform=MonthTransform(), name="timestamp_month"))),
-        (PartitionSpec(PartitionField(source_id=9, field_id=1001, transform=MonthTransform(), name="timestamptz_month"))),
-        (PartitionSpec(PartitionField(source_id=10, field_id=1001, transform=MonthTransform(), name="date_month"))),
-        (PartitionSpec(PartitionField(source_id=8, field_id=1001, transform=DayTransform(), name="timestamp_day"))),
-        (PartitionSpec(PartitionField(source_id=9, field_id=1001, transform=DayTransform(), name="timestamptz_day"))),
-        (PartitionSpec(PartitionField(source_id=10, field_id=1001, transform=DayTransform(), name="date_day"))),
-        (PartitionSpec(PartitionField(source_id=8, field_id=1001, transform=HourTransform(), name="timestamp_hour"))),
-        (PartitionSpec(PartitionField(source_id=9, field_id=1001, transform=HourTransform(), name="timestamptz_hour"))),
-        (PartitionSpec(PartitionField(source_id=10, field_id=1001, transform=HourTransform(), name="date_hour"))),
     ],
 )
 def test_unsupported_transform(
@@ -382,5 +376,186 @@ def test_unsupported_transform(
         properties={"format-version": "1"},
     )
 
-    with pytest.raises(ValueError, match="All transforms are not supported.*"):
+    with pytest.raises(
+        ValueError,
+        match="Not all partition types are supported for writes. Following partitions cannot be written using pyarrow: *",
+    ):
         tbl.append(arrow_table_with_null)
+
+
+@pytest.mark.integration
+@pytest.mark.parametrize(
+    "transform,expected_rows",
+    [
+        pytest.param(YearTransform(), 2, id="year_transform"),
+        pytest.param(MonthTransform(), 3, id="month_transform"),
+        pytest.param(DayTransform(), 3, id="day_transform"),
+    ],
+)
+@pytest.mark.parametrize("part_col", ["date", "timestamp", "timestamptz"])
+@pytest.mark.parametrize("format_version", [1, 2])
+def test_append_ymd_transform_partitioned(
+    session_catalog: Catalog,
+    spark: SparkSession,
+    arrow_table_with_null: pa.Table,
+    transform: Transform[Any, Any],
+    expected_rows: int,
+    part_col: str,
+    format_version: int,
+) -> None:
+    # Given
+    identifier = f"default.arrow_table_v{format_version}_with_{str(transform)}_partition_on_col_{part_col}"
+    nested_field = TABLE_SCHEMA.find_field(part_col)
+    partition_spec = PartitionSpec(
+        PartitionField(source_id=nested_field.field_id, field_id=1001, transform=transform, name=part_col)
+    )
+
+    # When
+    tbl = _create_table(
+        session_catalog=session_catalog,
+        identifier=identifier,
+        properties={"format-version": str(format_version)},
+        data=[arrow_table_with_null],
+        partition_spec=partition_spec,
+    )
+
+    # Then
+    assert tbl.format_version == format_version, f"Expected v{format_version}, got: v{tbl.format_version}"
+    df = spark.table(identifier)
+    assert df.count() == 3, f"Expected 3 total rows for {identifier}"
+    for col in TEST_DATA_WITH_NULL.keys():
+        assert df.where(f"{col} is not null").count() == 2, f"Expected 2 non-null rows for {col}"
+        assert df.where(f"{col} is null").count() == 1, f"Expected 1 null row for {col} is null"
+
+    assert tbl.inspect.partitions().num_rows == expected_rows
+    files_df = spark.sql(
+        f"""
+            SELECT *
+            FROM {identifier}.files
+        """
+    )
+    assert files_df.count() == expected_rows
+
+
+@pytest.mark.integration
+@pytest.mark.parametrize(
+    "transform,expected_partitions",
+    [
+        pytest.param(YearTransform(), {53, 54, None}, id="year_transform"),
+        pytest.param(MonthTransform(), {647, 648, 649, None}, id="month_transform"),
+        pytest.param(
+            DayTransform(), {date(2023, 12, 31), date(2024, 1, 1), date(2024, 1, 31), date(2024, 2, 1), None}, id="day_transform"
+        ),
+        pytest.param(HourTransform(), {473328, 473352, 474072, 474096, 474102, None}, id="hour_transform"),
+    ],
+)
+@pytest.mark.parametrize("format_version", [1, 2])
+def test_append_transform_partition_verify_partitions_count(
+    session_catalog: Catalog,
+    spark: SparkSession,
+    arrow_table_date_timestamps: pa.Table,
+    arrow_table_date_timestamps_schema: Schema,
+    transform: Transform[Any, Any],
+    expected_partitions: Set[Any],
+    format_version: int,
+) -> None:
+    # Given
+    part_col = "timestamptz"
+    identifier = f"default.arrow_table_v{format_version}_with_{str(transform)}_transform_partitioned_on_col_{part_col}"
+    nested_field = arrow_table_date_timestamps_schema.find_field(part_col)
+    partition_spec = PartitionSpec(
+        PartitionField(source_id=nested_field.field_id, field_id=1001, transform=transform, name=part_col),
+    )
+
+    # When
+    tbl = _create_table(
+        session_catalog=session_catalog,
+        identifier=identifier,
+        properties={"format-version": str(format_version)},
+        data=[arrow_table_date_timestamps],
+        partition_spec=partition_spec,
+        schema=arrow_table_date_timestamps_schema,
+    )
+
+    # Then
+    assert tbl.format_version == format_version, f"Expected v{format_version}, got: v{tbl.format_version}"
+    df = spark.table(identifier)
+    assert df.count() == 6, f"Expected 6 total rows for {identifier}"
+    for col in arrow_table_date_timestamps.column_names:
+        assert df.where(f"{col} is not null").count() == 5, f"Expected 5 non-null rows for {col}"
+        assert df.where(f"{col} is null").count() == 1, f"Expected 1 null row for {col} is null"
+
+    partitions_table = tbl.inspect.partitions()
+    assert partitions_table.num_rows == len(expected_partitions)
+    assert {part[part_col] for part in partitions_table["partition"].to_pylist()} == expected_partitions
+    files_df = spark.sql(
+        f"""
+            SELECT *
+            FROM {identifier}.files
+        """
+    )
+    assert files_df.count() == len(expected_partitions)
+
+
+@pytest.mark.integration
+@pytest.mark.parametrize("format_version", [1, 2])
+def test_append_multiple_partitions(
+    session_catalog: Catalog,
+    spark: SparkSession,
+    arrow_table_date_timestamps: pa.Table,
+    arrow_table_date_timestamps_schema: Schema,
+    format_version: int,
+) -> None:
+    # Given
+    identifier = f"default.arrow_table_v{format_version}_with_multiple_partitions"
+    partition_spec = PartitionSpec(
+        PartitionField(
+            source_id=arrow_table_date_timestamps_schema.find_field("date").field_id,
+            field_id=1001,
+            transform=YearTransform(),
+            name="date_year",
+        ),
+        PartitionField(
+            source_id=arrow_table_date_timestamps_schema.find_field("timestamptz").field_id,
+            field_id=1000,
+            transform=HourTransform(),
+            name="timestamptz_hour",
+        ),
+    )
+
+    # When
+    tbl = _create_table(
+        session_catalog=session_catalog,
+        identifier=identifier,
+        properties={"format-version": str(format_version)},
+        data=[arrow_table_date_timestamps],
+        partition_spec=partition_spec,
+        schema=arrow_table_date_timestamps_schema,
+    )
+
+    # Then
+    assert tbl.format_version == format_version, f"Expected v{format_version}, got: v{tbl.format_version}"
+    df = spark.table(identifier)
+    assert df.count() == 6, f"Expected 6 total rows for {identifier}"
+    for col in arrow_table_date_timestamps.column_names:
+        assert df.where(f"{col} is not null").count() == 5, f"Expected 5 non-null rows for {col}"
+        assert df.where(f"{col} is null").count() == 1, f"Expected 1 null row for {col} is null"
+
+    partitions_table = tbl.inspect.partitions()
+    assert partitions_table.num_rows == 6
+    partitions = partitions_table["partition"].to_pylist()
+    assert {(part["date_year"], part["timestamptz_hour"]) for part in partitions} == {
+        (53, 473328),
+        (54, 473352),
+        (54, 474072),
+        (54, 474096),
+        (54, 474102),
+        (None, None),
+    }
+    files_df = spark.sql(
+        f"""
+            SELECT *
+            FROM {identifier}.files
+        """
+    )
+    assert files_df.count() == 6
diff --git a/tests/test_transforms.py b/tests/test_transforms.py
index b8bef4b998..3a9ffd6009 100644
--- a/tests/test_transforms.py
+++ b/tests/test_transforms.py
@@ -17,7 +17,7 @@
 # pylint: disable=eval-used,protected-access,redefined-outer-name
 from datetime import date
 from decimal import Decimal
-from typing import Any, Callable, Optional
+from typing import TYPE_CHECKING, Any, Callable, Optional
 from uuid import UUID
 
 import mmh3 as mmh3
@@ -69,6 +69,7 @@
     TimestampLiteral,
     literal,
 )
+from pyiceberg.partitioning import _to_partition_representation
 from pyiceberg.schema import Accessor
 from pyiceberg.transforms import (
     BucketTransform,
@@ -111,6 +112,9 @@
     timestamptz_to_micros,
 )
 
+if TYPE_CHECKING:
+    import pyarrow as pa
+
 
 @pytest.mark.parametrize(
     "test_input,test_type,expected",
@@ -1808,3 +1812,31 @@ def test_strict_binary(bound_reference_binary: BoundReference[str]) -> None:
     _test_projection(
         lhs=transform.strict_project(name="name", pred=BoundIn(term=bound_reference_binary, literals=set_of_literals)), rhs=None
     )
+
+
+@pytest.mark.parametrize(
+    "transform",
+    [
+        pytest.param(YearTransform(), id="year_transform"),
+        pytest.param(MonthTransform(), id="month_transform"),
+        pytest.param(DayTransform(), id="day_transform"),
+        pytest.param(HourTransform(), id="hour_transform"),
+    ],
+)
+@pytest.mark.parametrize(
+    "source_col, source_type", [("date", DateType()), ("timestamp", TimestampType()), ("timestamptz", TimestamptzType())]
+)
+def test_ymd_pyarrow_transforms(
+    arrow_table_date_timestamps: "pa.Table",
+    source_col: str,
+    source_type: PrimitiveType,
+    transform: Transform[Any, Any],
+) -> None:
+    if transform.can_transform(source_type):
+        assert transform.pyarrow_transform(source_type)(arrow_table_date_timestamps[source_col]).to_pylist() == [
+            transform.transform(source_type)(_to_partition_representation(source_type, v))
+            for v in arrow_table_date_timestamps[source_col].to_pylist()
+        ]
+    else:
+        with pytest.raises(ValueError):
+            transform.pyarrow_transform(DateType())(arrow_table_date_timestamps[source_col])