Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 1 addition & 4 deletions pyiceberg/io/pyarrow.py
Original file line number Diff line number Diff line change
Expand Up @@ -1820,10 +1820,7 @@ def write_parquet(task: WriteTask) -> DataFile:
file_format=FileFormat.PARQUET,
partition=task.partition_key.partition if task.partition_key else Record(),
file_size_in_bytes=len(fo),
# After this has been fixed:
# https://github.com/apache/iceberg-python/issues/271
# sort_order_id=task.sort_order_id,
sort_order_id=None,
sort_order_id=task.sort_order_id,
# Just copy these from the table for now
spec_id=table_metadata.default_spec_id,
equality_ids=None,
Expand Down
88 changes: 84 additions & 4 deletions pyiceberg/table/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,7 @@
Summary,
update_snapshot_summaries,
)
from pyiceberg.table.sorting import UNSORTED_SORT_ORDER, SortOrder
from pyiceberg.table.sorting import UNSORTED_SORT_ORDER, SortField, SortOrder
from pyiceberg.transforms import IdentityTransform, TimeTransform, Transform, VoidTransform
from pyiceberg.typedef import (
EMPTY_DICT,
Expand All @@ -136,6 +136,7 @@
StructType,
transform_dict_value_to_str,
)
from pyiceberg.utils.arrow_sorting import PyArrowSortOptions
from pyiceberg.utils.concurrent import ExecutorFactory
from pyiceberg.utils.datetime import datetime_to_millis
from pyiceberg.utils.singleton import _convert_to_hashable_type
Expand Down Expand Up @@ -2721,9 +2722,29 @@ def _dataframe_to_data_files(
property_name=TableProperties.WRITE_TARGET_FILE_SIZE_BYTES,
default=TableProperties.WRITE_TARGET_FILE_SIZE_BYTES_DEFAULT,
)
sort_order: Optional[SortOrder] = table_metadata.sort_order_by_id(table_metadata.default_sort_order_id)

if len(table_metadata.spec().fields) > 0:
partitions = _determine_partitions(spec=table_metadata.spec(), schema=table_metadata.schema(), arrow_table=df)

if sort_order and not sort_order.is_unsorted:
try:
write_partitions = [
TablePartition(
partition_key=partition.partition_key,
arrow_table_partition=_sort_table_by_sort_order(
arrow_table=partition.arrow_table_partition, schema=table_metadata.schema(), sort_order=sort_order
),
)
for partition in partitions
]
except Exception as exc:
warnings.warn(f"Failed to sort table with error: {exc}")
sort_order = UNSORTED_SORT_ORDER
write_partitions = partitions
else:
write_partitions = partitions

yield from write_file(
io=io,
table_metadata=table_metadata,
Expand All @@ -2734,18 +2755,35 @@ def _dataframe_to_data_files(
record_batches=batches,
partition_key=partition.partition_key,
schema=table_metadata.schema(),
sort_order_id=sort_order.order_id if sort_order else None,
)
for partition in partitions
for partition in write_partitions
for batches in bin_pack_arrow_table(partition.arrow_table_partition, target_file_size)
]),
)
else:
if sort_order and not sort_order.is_unsorted:
try:
write_df = _sort_table_by_sort_order(arrow_table=df, schema=table_metadata.schema(), sort_order=sort_order)
except Exception as exc:
warnings.warn(f"Failed to sort table with error: {exc}")
sort_order = UNSORTED_SORT_ORDER
write_df = df
else:
write_df = df

yield from write_file(
io=io,
table_metadata=table_metadata,
tasks=iter([
WriteTask(write_uuid=write_uuid, task_id=next(counter), record_batches=batches, schema=table_metadata.schema())
for batches in bin_pack_arrow_table(df, target_file_size)
WriteTask(
write_uuid=write_uuid,
task_id=next(counter),
record_batches=batches,
schema=table_metadata.schema(),
sort_order_id=sort_order.order_id if sort_order else None,
)
for batches in bin_pack_arrow_table(write_df, target_file_size)
]),
)

Expand Down Expand Up @@ -3747,3 +3785,45 @@ def _determine_partitions(spec: PartitionSpec, schema: Schema, arrow_table: pa.T
table_partitions: List[TablePartition] = _get_table_partitions(arrow_table, spec, schema, slice_instructions)

return table_partitions


def _sort_table_by_sort_order(arrow_table: pa.Table, schema: Schema, sort_order: SortOrder) -> pa.Table:
"""
Sorts an Arrow Table using Iceberg Sort Order.

Args:
arrow_table (pyarrow.Table): Input Arrow table needs to be sorted
schema (Schema): Iceberg Schema of the Arrow Table
sort_order (SortOrder): Sort Order that needs to implemented

Returns:
pyarrow.Table:Sorted Arrow Table
"""
import pyarrow as pa

from pyiceberg.utils.arrow_sorting import convert_sort_field_to_pyarrow_sort_options, get_sort_indices_arrow_table

if unsupported_sort_transforms := [field for field in sort_order.fields if not field.transform.supports_pyarrow_transform]:
raise ValueError(
f"Not all sort transforms are supported for writes. Following sort orders cannot be written using pyarrow: {unsupported_sort_transforms}."
)

sort_columns: List[Tuple[SortField, NestedField]] = [
(sort_field, schema.find_field(sort_field.source_id)) for sort_field in sort_order.fields
]

sort_values_generated = pa.table({
str(sort_spec.source_id): sort_spec.transform.pyarrow_transform(field.field_type)(arrow_table[field.name])
for sort_spec, field in sort_columns
})

arrow_sort_options: list[Tuple[str, PyArrowSortOptions]] = [
(
str(sort_field.source_id),
convert_sort_field_to_pyarrow_sort_options(sort_field),
)
for sort_field in sort_order.fields
]

sort_indices = get_sort_indices_arrow_table(arrow_table=sort_values_generated, sort_seq=arrow_sort_options)
return arrow_table.take(sort_indices)
110 changes: 110 additions & 0 deletions pyiceberg/utils/arrow_sorting.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,110 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint:disable=redefined-outer-name
from typing import List, Tuple

import pyarrow as pa

from pyiceberg.table.sorting import NullOrder, SortDirection, SortField


class PyArrowSortOptions:
sort_direction: str
null_order: str

def __init__(self, sort_direction: str = "ascending", null_order: str = "at_end"):
if sort_direction not in ("ascending", "descending"):
raise ValueError('Sort Direction should be one of ["ascending","descending"]')
if null_order not in ("at_start", "at_end"):
raise ValueError('Sort Null Order should be one of ["at_start","at_end"]')

self.sort_direction = sort_direction
self.null_order = null_order


def convert_sort_field_to_pyarrow_sort_options(sort_field: SortField) -> PyArrowSortOptions:
"""
Convert an Iceberg Table Sort Field to Arrow Sort Options.

Args:
sort_field (SortField): Source Iceberg Sort Field to be converted

Returns:
PyArrowOptions: PyArrowOptions format for the input Sort Field
"""
pyarrow_sort_direction = {SortDirection.ASC: "ascending", SortDirection.DESC: "descending"}
pyarrow_null_ordering = {NullOrder.NULLS_LAST: "at_end", NullOrder.NULLS_FIRST: "at_start"}
return PyArrowSortOptions(
pyarrow_sort_direction.get(sort_field.direction, "ascending"),
pyarrow_null_ordering.get(sort_field.null_order, "at_end"),
)


def get_sort_indices_arrow_table(arrow_table: pa.Table, sort_seq: List[Tuple[str, PyArrowSortOptions]]) -> List[int]:
Copy link
Contributor Author

@vinjai vinjai Jul 6, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just wanted to clarify on the separate implementation for sort_indices other than the one provided by pyarrow.
This is because pyarrow sort_indices or Sort Options only supports one order for null placement across keys.
More details here:

While, the iceberg spec doesn't discriminate of having different null ordering across keys: https://iceberg.apache.org/spec/#sort-orders

This function specifically helps to implement the above functionality by sorting across keys and utilizing the stable nature of the sort_indices algo from pyarrow.


We can raise another issue to improve the performance of this function.


In future, if pyarrow sort_indices does support different null ordering across, we can mark this function as obsolete and keep the implementation clean in the iceberg table append and overwrite methods.

"""
Return the indices that would sort the input arrow table.

This function computes an array of indices that define a stable sort of the input arrow_table

Currently, pyarrow sort_indices function doesn't accept different null ordering across multiple keys
To make sure, we are able to sort null orders across multiple keys:
1. Utilize a stable sort algo (e.g. pyarrow sort indices)
2. Sort on the last key first and reverse iterate sort to the first key.

For instance:
If the sorting is defined on age asc and then name desc, the sorting can be decomposed into single key stable
sorts in the following way:
- first sort by name desc
- then sort by age asc

Using a stable sort, we can guarantee that the output would be same across different order keys.

Pyarrow sort_indices function is stable as mentioned in the doc: https://arrow.apache.org/docs/python/generated/pyarrow.compute.sort_indices.html

Args:
arrow_table (pa.Table): Input table to be sorted
sort_seq: Seq of PyArrowOptions to apply sorting

Returns:
List[int]: Indices of the arrow table for sorting
"""
import pyarrow as pa

index_column_name = "__idx__pyarrow_sort__"
cols = set(arrow_table.column_names)

while index_column_name in cols:
index_column_name = f"{index_column_name}_1"

sorted_table: pa.Table = arrow_table.add_column(0, index_column_name, [list(range(len(arrow_table)))])

for col_name, _ in sort_seq:
if col_name not in cols:
raise ValueError(
f"{col_name} not found in arrow table. Expected one of [{','.join([col_name for col_name, _ in cols])}]"
)

for col_name, sort_options in sort_seq[::-1]:
sorted_table = sorted_table.take(
# This function works because pyarrow sort_indices function is stable.
# As mentioned in the docs: https://arrow.apache.org/docs/python/generated/pyarrow.compute.sort_indices.html
pa.compute.sort_indices(
sorted_table, sort_keys=[(col_name, sort_options.sort_direction)], null_placement=sort_options.null_order
)
)

return sorted_table[index_column_name].to_pylist()
Loading