diff --git a/pyiceberg/io/pyarrow.py b/pyiceberg/io/pyarrow.py index 71925c27cd..f53c9e3bca 100644 --- a/pyiceberg/io/pyarrow.py +++ b/pyiceberg/io/pyarrow.py @@ -1820,10 +1820,7 @@ def write_parquet(task: WriteTask) -> DataFile: file_format=FileFormat.PARQUET, partition=task.partition_key.partition if task.partition_key else Record(), file_size_in_bytes=len(fo), - # After this has been fixed: - # https://github.com/apache/iceberg-python/issues/271 - # sort_order_id=task.sort_order_id, - sort_order_id=None, + sort_order_id=task.sort_order_id, # Just copy these from the table for now spec_id=table_metadata.default_spec_id, equality_ids=None, diff --git a/pyiceberg/table/__init__.py b/pyiceberg/table/__init__.py index 2d4b342461..f65833ecce 100644 --- a/pyiceberg/table/__init__.py +++ b/pyiceberg/table/__init__.py @@ -115,7 +115,7 @@ Summary, update_snapshot_summaries, ) -from pyiceberg.table.sorting import UNSORTED_SORT_ORDER, SortOrder +from pyiceberg.table.sorting import UNSORTED_SORT_ORDER, SortField, SortOrder from pyiceberg.transforms import IdentityTransform, TimeTransform, Transform, VoidTransform from pyiceberg.typedef import ( EMPTY_DICT, @@ -136,6 +136,7 @@ StructType, transform_dict_value_to_str, ) +from pyiceberg.utils.arrow_sorting import PyArrowSortOptions from pyiceberg.utils.concurrent import ExecutorFactory from pyiceberg.utils.datetime import datetime_to_millis from pyiceberg.utils.singleton import _convert_to_hashable_type @@ -2721,9 +2722,29 @@ def _dataframe_to_data_files( property_name=TableProperties.WRITE_TARGET_FILE_SIZE_BYTES, default=TableProperties.WRITE_TARGET_FILE_SIZE_BYTES_DEFAULT, ) + sort_order: Optional[SortOrder] = table_metadata.sort_order_by_id(table_metadata.default_sort_order_id) if len(table_metadata.spec().fields) > 0: partitions = _determine_partitions(spec=table_metadata.spec(), schema=table_metadata.schema(), arrow_table=df) + + if sort_order and not sort_order.is_unsorted: + try: + write_partitions = [ + TablePartition( + partition_key=partition.partition_key, + arrow_table_partition=_sort_table_by_sort_order( + arrow_table=partition.arrow_table_partition, schema=table_metadata.schema(), sort_order=sort_order + ), + ) + for partition in partitions + ] + except Exception as exc: + warnings.warn(f"Failed to sort table with error: {exc}") + sort_order = UNSORTED_SORT_ORDER + write_partitions = partitions + else: + write_partitions = partitions + yield from write_file( io=io, table_metadata=table_metadata, @@ -2734,18 +2755,35 @@ def _dataframe_to_data_files( record_batches=batches, partition_key=partition.partition_key, schema=table_metadata.schema(), + sort_order_id=sort_order.order_id if sort_order else None, ) - for partition in partitions + for partition in write_partitions for batches in bin_pack_arrow_table(partition.arrow_table_partition, target_file_size) ]), ) else: + if sort_order and not sort_order.is_unsorted: + try: + write_df = _sort_table_by_sort_order(arrow_table=df, schema=table_metadata.schema(), sort_order=sort_order) + except Exception as exc: + warnings.warn(f"Failed to sort table with error: {exc}") + sort_order = UNSORTED_SORT_ORDER + write_df = df + else: + write_df = df + yield from write_file( io=io, table_metadata=table_metadata, tasks=iter([ - WriteTask(write_uuid=write_uuid, task_id=next(counter), record_batches=batches, schema=table_metadata.schema()) - for batches in bin_pack_arrow_table(df, target_file_size) + WriteTask( + write_uuid=write_uuid, + task_id=next(counter), + record_batches=batches, + 
schema=table_metadata.schema(), + sort_order_id=sort_order.order_id if sort_order else None, + ) + for batches in bin_pack_arrow_table(write_df, target_file_size) ]), ) @@ -3747,3 +3785,45 @@ def _determine_partitions(spec: PartitionSpec, schema: Schema, arrow_table: pa.T table_partitions: List[TablePartition] = _get_table_partitions(arrow_table, spec, schema, slice_instructions) return table_partitions + + +def _sort_table_by_sort_order(arrow_table: pa.Table, schema: Schema, sort_order: SortOrder) -> pa.Table: + """ + Sort an Arrow Table using an Iceberg SortOrder. + + Args: + arrow_table (pyarrow.Table): Input Arrow table to be sorted + schema (Schema): Iceberg Schema of the Arrow Table + sort_order (SortOrder): Sort order to be applied + + Returns: + pyarrow.Table: Sorted Arrow Table + """ + import pyarrow as pa + + from pyiceberg.utils.arrow_sorting import convert_sort_field_to_pyarrow_sort_options, get_sort_indices_arrow_table + + if unsupported_sort_transforms := [field for field in sort_order.fields if not field.transform.supports_pyarrow_transform]: + raise ValueError( + f"Not all sort transforms are supported for writes. Following sort orders cannot be written using pyarrow: {unsupported_sort_transforms}." + ) + + sort_columns: List[Tuple[SortField, NestedField]] = [ + (sort_field, schema.find_field(sort_field.source_id)) for sort_field in sort_order.fields + ] + + sort_values_generated = pa.table({ + str(sort_spec.source_id): sort_spec.transform.pyarrow_transform(field.field_type)(arrow_table[field.name]) + for sort_spec, field in sort_columns + }) + + arrow_sort_options: List[Tuple[str, PyArrowSortOptions]] = [ + ( + str(sort_field.source_id), + convert_sort_field_to_pyarrow_sort_options(sort_field), + ) + for sort_field in sort_order.fields + ] + + sort_indices = get_sort_indices_arrow_table(arrow_table=sort_values_generated, sort_seq=arrow_sort_options) + return arrow_table.take(sort_indices) diff --git a/pyiceberg/utils/arrow_sorting.py b/pyiceberg/utils/arrow_sorting.py new file mode 100644 index 0000000000..759feba52c --- /dev/null +++ b/pyiceberg/utils/arrow_sorting.py @@ -0,0 +1,110 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+# pylint:disable=redefined-outer-name +from typing import List, Tuple + +import pyarrow as pa + +from pyiceberg.table.sorting import NullOrder, SortDirection, SortField + + +class PyArrowSortOptions: + sort_direction: str + null_order: str + + def __init__(self, sort_direction: str = "ascending", null_order: str = "at_end"): + if sort_direction not in ("ascending", "descending"): + raise ValueError('Sort Direction should be one of ["ascending","descending"]') + if null_order not in ("at_start", "at_end"): + raise ValueError('Sort Null Order should be one of ["at_start","at_end"]') + + self.sort_direction = sort_direction + self.null_order = null_order + + +def convert_sort_field_to_pyarrow_sort_options(sort_field: SortField) -> PyArrowSortOptions: + """ + Convert an Iceberg Table Sort Field to Arrow Sort Options. + + Args: + sort_field (SortField): Source Iceberg Sort Field to be converted + + Returns: + PyArrowSortOptions: Equivalent PyArrow sort options for the input Sort Field + """ + pyarrow_sort_direction = {SortDirection.ASC: "ascending", SortDirection.DESC: "descending"} + pyarrow_null_ordering = {NullOrder.NULLS_LAST: "at_end", NullOrder.NULLS_FIRST: "at_start"} + return PyArrowSortOptions( + pyarrow_sort_direction.get(sort_field.direction, "ascending"), + pyarrow_null_ordering.get(sort_field.null_order, "at_end"), + ) + + +def get_sort_indices_arrow_table(arrow_table: pa.Table, sort_seq: List[Tuple[str, PyArrowSortOptions]]) -> List[int]: + """ + Return the indices that would sort the input arrow table. + + This function computes an array of indices that define a stable sort of the input arrow_table. + + Currently, pyarrow's sort_indices function does not accept a different null ordering per sort key. + To support per-key null ordering across multiple keys: + 1. Use a stable sort algorithm (pyarrow's sort_indices). + 2. Sort on the last key first and iterate backwards to the first key. + + For instance: + If the sorting is defined on age asc and then name desc, the sorting can be decomposed into single-key stable + sorts in the following way: + - first sort by name desc + - then sort by age asc + + Because each single-key pass is stable, the combined result is the same as sorting on all keys at once. + + Pyarrow's sort_indices function is documented as stable: https://arrow.apache.org/docs/python/generated/pyarrow.compute.sort_indices.html + + Args: + arrow_table (pa.Table): Input table to be sorted + sort_seq: Sequence of (column name, PyArrowSortOptions) pairs to sort by, in priority order + + Returns: + List[int]: Indices that sort the arrow table + """ + import pyarrow as pa + + index_column_name = "__idx__pyarrow_sort__" + cols = set(arrow_table.column_names) + + while index_column_name in cols: + index_column_name = f"{index_column_name}_1" + + sorted_table: pa.Table = arrow_table.add_column(0, index_column_name, [list(range(len(arrow_table)))]) + + for col_name, _ in sort_seq: + if col_name not in cols: + raise ValueError( + f"{col_name} not found in arrow table. Expected one of [{','.join(cols)}]" + ) + + for col_name, sort_options in sort_seq[::-1]: + sorted_table = sorted_table.take( + # This loop works because pyarrow's sort_indices function is stable. 
+ # As mentioned in the docs: https://arrow.apache.org/docs/python/generated/pyarrow.compute.sort_indices.html + pa.compute.sort_indices( + sorted_table, sort_keys=[(col_name, sort_options.sort_direction)], null_placement=sort_options.null_order + ) + ) + + return sorted_table[index_column_name].to_pylist() diff --git a/tests/integration/test_writes/test_sorted_writes.py b/tests/integration/test_writes/test_sorted_writes.py new file mode 100644 index 0000000000..494c4e79bb --- /dev/null +++ b/tests/integration/test_writes/test_sorted_writes.py @@ -0,0 +1,432 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint:disable=redefined-outer-name +from typing import Tuple + +import pyarrow as pa +import pytest +from pyspark.sql import SparkSession + +from pyiceberg.catalog import Catalog +from pyiceberg.partitioning import PartitionField, PartitionSpec +from pyiceberg.schema import Schema +from pyiceberg.table.sorting import NullOrder, SortDirection, SortField, SortOrder +from pyiceberg.transforms import ( + BucketTransform, + DayTransform, + IdentityTransform, + MonthTransform, + TruncateTransform, + YearTransform, +) +from utils import TABLE_SCHEMA, _create_table + +######################################################################################################################## +# Comparing against Spark works here because reads within a single data file are stable. +# These tests therefore require that sorted writes produce only one file per partition. 
+######################################################################################################################## + + +@pytest.mark.integration +@pytest.mark.parametrize( + "sort_col", ["int", "bool", "string", "string_long", "long", "float", "double", "date", "timestamp", "timestamptz", "binary"] +) +@pytest.mark.parametrize("sort_direction", [SortDirection.ASC, SortDirection.DESC]) +@pytest.mark.parametrize("sort_null_ordering", [NullOrder.NULLS_FIRST, NullOrder.NULLS_LAST]) +@pytest.mark.parametrize("format_version", [1, 2]) +def test_query_null_append_sort( + session_catalog: Catalog, + spark: SparkSession, + arrow_table_with_null: pa.Table, + sort_col: str, + sort_direction: SortDirection, + sort_null_ordering: NullOrder, + format_version: int, +) -> None: + table_identifier = f"default.arrow_table_v{format_version}_with_null_sorted_on_col_{sort_col}_in_direction_{sort_direction}_with_null_ordering_{str(sort_null_ordering).replace(' ', '_')}" + nested_field = TABLE_SCHEMA.find_field(sort_col) + + sort_order = SortOrder( + SortField( + source_id=nested_field.field_id, + direction=sort_direction, + transform=IdentityTransform(), + null_order=sort_null_ordering, + ) + ) + + tbl = _create_table( + session_catalog=session_catalog, + identifier=table_identifier, + properties={"format-version": str(format_version)}, + data=[], + sort_order=sort_order, + ) + + tbl.append(arrow_table_with_null) + + query_sorted_df = spark.sql( + f"SELECT * FROM {table_identifier} ORDER BY {sort_col} {sort_direction} {sort_null_ordering}" + ).toPandas() + + append_sorted_df = spark.table(table_identifier).toPandas() + + assert len(tbl.metadata.sort_orders) == 1, f"Expected exactly one sort order for {tbl}" + assert tbl.format_version == format_version, f"Expected v{format_version}, got: v{tbl.format_version}" + assert append_sorted_df.shape[0] == 3, f"Expected 3 total rows for {table_identifier}" + assert append_sorted_df.equals( + query_sorted_df + ), f"Expected sorted dataframe for v{format_version} on col: {sort_col} in direction {sort_direction} with null ordering as {sort_null_ordering}, got {append_sorted_df[sort_col]}" + + +@pytest.mark.integration +@pytest.mark.parametrize( + "sort_col_tuple_3", [("int", "bool", "string"), ("long", "float", "double"), ("date", "timestamp", "timestamptz")] +) +@pytest.mark.parametrize("sort_direction_tuple_3", [(SortDirection.ASC, SortDirection.DESC, SortDirection.DESC)]) +@pytest.mark.parametrize( + "sort_null_ordering_tuple_3", + [ + (NullOrder.NULLS_FIRST, NullOrder.NULLS_FIRST, NullOrder.NULLS_LAST), + (NullOrder.NULLS_FIRST, NullOrder.NULLS_FIRST, NullOrder.NULLS_FIRST), + ], +) +@pytest.mark.parametrize("format_version", [1, 2]) +def test_query_null_append_multi_sort( + session_catalog: Catalog, + spark: SparkSession, + arrow_table_with_null: pa.Table, + sort_col_tuple_3: Tuple[str, str, str], + sort_direction_tuple_3: Tuple[SortDirection, SortDirection, SortDirection], + sort_null_ordering_tuple_3: Tuple[NullOrder, NullOrder, NullOrder], + format_version: int, +) -> None: + table_identifier = f"default.arrow_table_v{format_version}_with_null_multi_sorted_on_cols_{'_'.join(sort_col_tuple_3)}" + + sort_options_list = list(zip(sort_col_tuple_3, sort_direction_tuple_3, sort_null_ordering_tuple_3)) + + sort_order = SortOrder(*[ + SortField( + source_id=TABLE_SCHEMA.find_field(sort_col).field_id, + direction=sort_direction, + transform=IdentityTransform(), + null_order=sort_null_ordering, + ) + for sort_col, sort_direction, sort_null_ordering in sort_options_list + ]) + 
+ tbl = _create_table( + session_catalog=session_catalog, + identifier=table_identifier, + properties={"format-version": str(format_version)}, + data=[], + sort_order=sort_order, + ) + + tbl.append(arrow_table_with_null) + + query_sorted_df = spark.sql( + f"SELECT * FROM {table_identifier} ORDER BY {','.join([f' {sort_col} {sort_direction} {sort_null_ordering}' for sort_col, sort_direction, sort_null_ordering in sort_options_list])}" + ).toPandas() + + append_sorted_df = spark.table(table_identifier).toPandas() + + assert tbl.format_version == format_version, f"Expected v{format_version}, got: v{tbl.format_version}" + assert append_sorted_df.shape[0] == 3, f"Expected 3 total rows for {table_identifier}" + assert append_sorted_df.equals( + query_sorted_df + ), f"Expected sorted dataframe for v{format_version} on col: {sort_options_list}, got {append_sorted_df}" + + +@pytest.mark.integration +@pytest.mark.parametrize("part_col", ["int", "date", "string"]) +@pytest.mark.parametrize( + "sort_col_tuple_2", [("bool", "string_long"), ("long", "float"), ("double", "timestamp"), ("timestamptz", "binary")] +) +@pytest.mark.parametrize("format_version", [1, 2]) +def test_query_null_append_partitioned_multi_sort( + session_catalog: Catalog, + spark: SparkSession, + arrow_table_with_null: pa.Table, + part_col: str, + sort_col_tuple_2: Tuple[str, str], + format_version: int, +) -> None: + table_identifier = f"default.arrow_table_v{format_version}_with_null_partitioned_on_{part_col}_multi_sorted_on_cols_{'_'.join(sort_col_tuple_2)}" + + partition_spec = PartitionSpec( + PartitionField( + source_id=TABLE_SCHEMA.find_field(part_col).field_id, field_id=1001, transform=IdentityTransform(), name=part_col + ) + ) + + sort_order = SortOrder(*[ + SortField( + source_id=TABLE_SCHEMA.find_field(sort_col).field_id, + transform=IdentityTransform(), + ) + for sort_col in sort_col_tuple_2 + ]) + + tbl = _create_table( + session_catalog=session_catalog, + identifier=table_identifier, + properties={"format-version": str(format_version)}, + data=[], + partition_spec=partition_spec, + sort_order=sort_order, + ) + + tbl.append(arrow_table_with_null) + + query_sorted_df = spark.sql( + f"SELECT * FROM {table_identifier} ORDER BY {part_col} , {','.join([f' {sort_col} ' for sort_col in sort_col_tuple_2])}" + ).toPandas() + + append_sorted_df = spark.sql(f"SELECT * FROM {table_identifier} ORDER BY {part_col}").toPandas() + + assert tbl.format_version == format_version, f"Expected v{format_version}, got: v{tbl.format_version}" + assert append_sorted_df.shape[0] == 3, f"Expected 3 total rows for {table_identifier}" + assert append_sorted_df.equals( + query_sorted_df + ), f"Expected sorted dataframe for v{format_version} on col: {sort_col_tuple_2}, got {append_sorted_df}" + + +@pytest.mark.integration +@pytest.mark.parametrize( + "sort_order", + [ + SortOrder(*[ + SortField(source_id=1, transform=IdentityTransform()), + SortField(source_id=4, transform=BucketTransform(2)), + ]), + SortOrder(SortField(source_id=5, transform=BucketTransform(2))), + SortOrder(SortField(source_id=8, transform=BucketTransform(2))), + SortOrder(SortField(source_id=9, transform=BucketTransform(2))), + SortOrder(SortField(source_id=4, transform=TruncateTransform(2))), + SortOrder(SortField(source_id=5, transform=TruncateTransform(2))), + ], +) +def test_invalid_sort_transform( + session_catalog: Catalog, spark: SparkSession, arrow_table_with_null: pa.Table, sort_order: SortOrder +) -> None: + import re + + table_identifier = 
f"""default.arrow_table_invalid_sort_transform_{'_'.join([f"__{re.sub(r'[^A-Za-z0-9_]', '', str(field.transform))}_{field.source_id}_{field.direction}_{str(field.null_order)}__".replace(' ', '') for field in sort_order.fields])}""" + + tbl = _create_table( + session_catalog=session_catalog, + identifier=table_identifier, + properties={"format-version": "1"}, + schema=TABLE_SCHEMA, + sort_order=sort_order, + ) + + with pytest.warns( + UserWarning, + match="Not all sort transforms are supported for writes. Following sort orders cannot be written using pyarrow: *", + ): + tbl.append(arrow_table_with_null) + + files_df = spark.sql( + f""" + SELECT * + FROM {table_identifier}.files + """ + ) + + assert [row.sort_order_id for row in files_df.select("sort_order_id").distinct().collect()] == [ + 0 + ], "Expected Sort Order Id to be set as 0 (Unsorted) in the manifest file" + + +@pytest.mark.integration +@pytest.mark.parametrize( + "sort_order", + [ + SortOrder(*[ + SortField(source_id=1, transform=IdentityTransform()), + SortField(source_id=4, transform=BucketTransform(2)), + ]), + SortOrder(SortField(source_id=5, transform=BucketTransform(2))), + SortOrder(SortField(source_id=8, transform=BucketTransform(2))), + SortOrder(SortField(source_id=9, transform=BucketTransform(2))), + SortOrder(SortField(source_id=4, transform=TruncateTransform(2))), + SortOrder(SortField(source_id=5, transform=TruncateTransform(2))), + ], +) +def test_invalid_sort_transform_partitioned( + session_catalog: Catalog, spark: SparkSession, arrow_table_with_null: pa.Table, sort_order: SortOrder +) -> None: + import re + + table_identifier = f"""default.arrow_table_invalid_sort_transform_partitioned_{'_'.join([f"__{re.sub(r'[^A-Za-z0-9_]', '', str(field.transform))}_{field.source_id}_{field.direction}_{str(field.null_order)}__".replace(' ', '') for field in sort_order.fields])}""" + + tbl = _create_table( + session_catalog=session_catalog, + identifier=table_identifier, + properties={"format-version": "1"}, + schema=TABLE_SCHEMA, + sort_order=sort_order, + partition_spec=PartitionSpec( + PartitionField(source_id=10, field_id=1001, transform=IdentityTransform(), name="identity_date") + ), + ) + + with pytest.warns( + UserWarning, + match="Not all sort transforms are supported for writes. 
Following sort orders cannot be written using pyarrow: *", + ): + tbl.append(arrow_table_with_null) + + files_df = spark.sql( + f""" + SELECT * + FROM {table_identifier}.files + """ + ) + + assert [row.sort_order_id for row in files_df.select("sort_order_id").distinct().collect()] == [ + 0 + ], "Expected Sort Order Id to be set as 0 (Unsorted) in the manifest file" + + +@pytest.mark.integration +@pytest.mark.parametrize( + "sort_order", + [ + SortOrder(*[ + SortField(source_id=8, transform=YearTransform()), + SortField(source_id=4, transform=IdentityTransform()), + ]), + SortOrder(*[ + SortField(source_id=10, transform=YearTransform()), + SortField(source_id=9, transform=DayTransform()), + ]), + ], +) +def test_valid_sort_transform( + session_catalog: Catalog, spark: SparkSession, arrow_table_with_null: pa.Table, sort_order: SortOrder +) -> None: + table_identifier = f"default.arrow_table_valid_sort_transform_{'_'.join([f'__{field.transform}_{field.source_id}_{field.direction}_{str(field.null_order)}__'.replace(' ', '') for field in sort_order.fields])}" + + _ = _create_table( + session_catalog=session_catalog, + identifier=table_identifier, + properties={"format-version": "1"}, + schema=TABLE_SCHEMA, + data=[arrow_table_with_null], + sort_order=sort_order, + ) + + def _get_sort_order_clause_spark_query(table_schema: Schema, sort_field: SortField) -> str: + if isinstance(sort_field.transform, YearTransform): + return f" YEAR({table_schema.find_field(sort_field.source_id).name}) {sort_field.direction} {sort_field.null_order} " + elif isinstance(sort_field.transform, MonthTransform): + return f" MONTH({table_schema.find_field(sort_field.source_id).name}) {sort_field.direction} {sort_field.null_order} " + elif isinstance(sort_field.transform, DayTransform): + return f" DAY({table_schema.find_field(sort_field.source_id).name}) {sort_field.direction} {sort_field.null_order} " + elif isinstance(sort_field.transform, IdentityTransform): + return f" {table_schema.find_field(sort_field.source_id).name} {sort_field.direction} {sort_field.null_order} " + else: + raise ValueError("Transform not supported for this test") + + query_sorted_df = spark.sql( + f"SELECT * FROM {table_identifier} ORDER BY {','.join([_get_sort_order_clause_spark_query(TABLE_SCHEMA, field) for field in sort_order.fields])}" + ).toPandas() + + append_sorted_df = spark.table(table_identifier).toPandas() + + assert append_sorted_df.shape[0] == 3, f"Expected 3 total rows for {table_identifier}" + assert append_sorted_df.equals( + query_sorted_df + ), f"Expected sorted dataframe on col: {','.join([f'[{field}]' for field in sort_order.fields])}, got {append_sorted_df}" + + +@pytest.mark.integration +@pytest.mark.parametrize("format_version", [1, 2]) +def test_manifest_for_sort_order( + session_catalog: Catalog, spark: SparkSession, arrow_table_with_null: pa.Table, format_version: int +) -> None: + table_identifier = f"default.arrow_table_v{format_version}_manifest_for_sort_order" + + tbl = _create_table( + session_catalog=session_catalog, + identifier=table_identifier, + properties={"format-version": str(format_version)}, + schema=TABLE_SCHEMA, + data=[arrow_table_with_null], + sort_order=SortOrder( + SortField( + source_id=4, + transform=IdentityTransform(), + ) + ), + ) + + files_df = spark.sql( + f""" + SELECT * + FROM {table_identifier}.files + """ + ) + + assert tbl.format_version == format_version, f"Expected v{format_version}, got: v{tbl.format_version}" + assert files_df.count() == 1, f"Expected 1 file in 
{table_identifier}.files, got: {files_df.count()}" + assert [row.sort_order_id for row in files_df.select("sort_order_id").collect()] == [ + 1 + ], "Expected Sort Order Id to be set as 1 in the manifest file" + + +@pytest.mark.integration +@pytest.mark.parametrize("format_version", [1, 2]) +def test_manifest_partitioned_for_sort_order( + session_catalog: Catalog, spark: SparkSession, arrow_table_with_null: pa.Table, format_version: int +) -> None: + table_identifier = f"default.arrow_table_v{format_version}_manifest_partitioned_for_sort_order" + + tbl = _create_table( + session_catalog=session_catalog, + identifier=table_identifier, + properties={"format-version": str(format_version)}, + schema=TABLE_SCHEMA, + data=[arrow_table_with_null], + sort_order=SortOrder( + SortField( + source_id=4, + transform=IdentityTransform(), + ) + ), + partition_spec=PartitionSpec( + PartitionField(source_id=10, field_id=1001, transform=IdentityTransform(), name="identity_date") + ), + ) + + files_df = spark.sql( + f""" + SELECT * + FROM {table_identifier}.files + """ + ) + + assert tbl.format_version == format_version, f"Expected v{format_version}, got: v{tbl.format_version}" + assert files_df.count() == 3, f"Expected 3 files in {table_identifier}.files, got: {files_df.count()}" + assert [row.sort_order_id for row in files_df.select("sort_order_id").collect()] == [ + 1, + 1, + 1, + ], "Expected Sort Order Id to be set as 1 in the manifest file" diff --git a/tests/integration/test_writes/utils.py b/tests/integration/test_writes/utils.py index 9f1f6df043..c12f3983fc 100644 --- a/tests/integration/test_writes/utils.py +++ b/tests/integration/test_writes/utils.py @@ -24,6 +24,7 @@ from pyiceberg.partitioning import UNPARTITIONED_PARTITION_SPEC, PartitionSpec from pyiceberg.schema import Schema from pyiceberg.table import Table +from pyiceberg.table.sorting import UNSORTED_SORT_ORDER, SortOrder from pyiceberg.typedef import EMPTY_DICT, Properties from pyiceberg.types import ( BinaryType, @@ -66,13 +67,16 @@ def _create_table( data: Optional[List[pa.Table]] = None, partition_spec: PartitionSpec = UNPARTITIONED_PARTITION_SPEC, schema: Union[Schema, "pa.Schema"] = TABLE_SCHEMA, + sort_order: SortOrder = UNSORTED_SORT_ORDER, ) -> Table: try: session_catalog.drop_table(identifier=identifier) except NoSuchTableError: pass - tbl = session_catalog.create_table(identifier=identifier, schema=schema, properties=properties, partition_spec=partition_spec) + tbl = session_catalog.create_table( + identifier=identifier, schema=schema, properties=properties, partition_spec=partition_spec, sort_order=sort_order + ) if data is not None: for d in data: diff --git a/tests/utils/test_arrow_sorting.py b/tests/utils/test_arrow_sorting.py new file mode 100644 index 0000000000..5b7447efaf --- /dev/null +++ b/tests/utils/test_arrow_sorting.py @@ -0,0 +1,60 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. 
See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint:disable=redefined-outer-name +from typing import List, Tuple + +import pyarrow as pa +import pytest + +from pyiceberg.utils.arrow_sorting import PyArrowSortOptions, get_sort_indices_arrow_table + + +@pytest.fixture +def example_arrow_table_for_sort() -> pa.Table: + return pa.table({ + "column1": [5, None, 3, 1, 1, None, 3], + "column2": ["b", "a", None, "c", "c", "d", "m"], + "column3": [10.5, None, 5.1, None, 2.5, 7.3, 3.3], + }) + + +@pytest.mark.parametrize( + "sort_keys, expected", + [ + ( + [ + ("column1", PyArrowSortOptions("ascending", "at_end")), + ("column2", PyArrowSortOptions("ascending", "at_start")), + ("column3", PyArrowSortOptions("descending", "at_end")), + ], + [4, 3, 2, 6, 0, 1, 5], + ) + ], +) +def test_get_sort_indices_arrow_table( + example_arrow_table_for_sort: pa.Table, sort_keys: List[Tuple[str, PyArrowSortOptions]], expected: List[int] +) -> None: + sorted_indices = get_sort_indices_arrow_table(example_arrow_table_for_sort, sort_keys) + assert sorted_indices == expected, "Table sort not in expected form" + + +@pytest.mark.parametrize("sort_keys, expected", [([("column1", PyArrowSortOptions())], [3, 4, 2, 6, 0, 1, 5])]) +def test_stability_get_sort_indices_arrow_table( + example_arrow_table_for_sort: pa.Table, sort_keys: List[Tuple[str, PyArrowSortOptions]], expected: List[int] +) -> None: + sorted_indices = get_sort_indices_arrow_table(example_arrow_table_for_sort, sort_keys) + assert sorted_indices == expected, "Arrow Table sort is not stable"
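
A standalone sketch (not part of the patch) of the reverse-iteration technique that get_sort_indices_arrow_table relies on, using nothing but pyarrow; the table data and column names below are invented for illustration:

# Illustration only: a two-key sort (age asc nulls last, then name desc nulls first)
# decomposed into single-key stable sorts applied in reverse key order.
import pyarrow as pa
import pyarrow.compute as pc

table = pa.table({
    "age": [30, None, 25, 25],
    "name": ["a", "b", None, "c"],
})

# Last key first: name descending, nulls at the start.
table = table.take(pc.sort_indices(table, sort_keys=[("name", "descending")], null_placement="at_start"))

# First key last: age ascending, nulls at the end. sort_indices is stable, so rows that
# tie on age keep the name ordering established by the previous pass.
table = table.take(pc.sort_indices(table, sort_keys=[("age", "ascending")], null_placement="at_end"))

print(table.to_pydict())
# {'age': [25, 25, 30, None], 'name': [None, 'c', 'a', 'b']} -- identical to a single sort on both keys.

This mirrors the loop over sort_seq[::-1] in pyiceberg/utils/arrow_sorting.py; the utility additionally carries a temporary index column so it can return positions into the original table rather than the reordered rows.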