Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
91 changes: 69 additions & 22 deletions pyiceberg/io/pyarrow.py
Original file line number Diff line number Diff line change
Expand Up @@ -470,9 +470,16 @@ def __setstate__(self, state: Dict[str, Any]) -> None:


def schema_to_pyarrow(
schema: Union[Schema, IcebergType], metadata: Dict[bytes, bytes] = EMPTY_DICT, include_field_ids: bool = True
schema: Union[Schema, IcebergType],
metadata: Dict[bytes, bytes] = EMPTY_DICT,
include_field_ids: bool = True,
with_large_types: bool = True,
) -> pa.schema:
return visit(schema, _ConvertToArrowSchema(metadata, include_field_ids))
pyarrow_schema = visit(schema, _ConvertToArrowSchema(metadata, include_field_ids))
if with_large_types:
return _pyarrow_schema_ensure_large_types(pyarrow_schema)
else:
return pyarrow_schema


class _ConvertToArrowSchema(SchemaVisitorPerPrimitiveType[pa.DataType]):
Expand Down Expand Up @@ -504,7 +511,7 @@ def field(self, field: NestedField, field_result: pa.DataType) -> pa.Field:

def list(self, list_type: ListType, element_result: pa.DataType) -> pa.DataType:
element_field = self.field(list_type.element_field, element_result)
return pa.large_list(value_type=element_field)
return pa.list_(value_type=element_field)

def map(self, map_type: MapType, key_result: pa.DataType, value_result: pa.DataType) -> pa.DataType:
key_field = self.field(map_type.key_field, key_result)
Expand Down Expand Up @@ -548,13 +555,13 @@ def visit_timestamptz(self, _: TimestamptzType) -> pa.DataType:
return pa.timestamp(unit="us", tz="UTC")

def visit_string(self, _: StringType) -> pa.DataType:
return pa.large_string()
return pa.string()

def visit_uuid(self, _: UUIDType) -> pa.DataType:
return pa.binary(16)

def visit_binary(self, _: BinaryType) -> pa.DataType:
return pa.large_binary()
return pa.binary()


def _convert_scalar(value: Any, iceberg_type: IcebergType) -> pa.scalar:
Expand Down Expand Up @@ -958,19 +965,23 @@ def after_map_value(self, element: pa.Field) -> None:

class _ConvertToLargeTypes(PyArrowSchemaVisitor[Union[pa.DataType, pa.Schema]]):
def schema(self, schema: pa.Schema, struct_result: pa.StructType) -> pa.Schema:
return pa.schema(struct_result)
return pa.schema(list(struct_result))

def struct(self, struct: pa.StructType, field_results: List[pa.Field]) -> pa.StructType:
return pa.struct(field_results)

def field(self, field: pa.Field, field_result: pa.DataType) -> pa.Field:
return field.with_type(field_result)
new_field = field.with_type(field_result)
return new_field

def list(self, list_type: pa.ListType, element_result: pa.DataType) -> pa.DataType:
return pa.large_list(element_result)
element_field = self.field(list_type.value_field, element_result)
return pa.large_list(element_field)

def map(self, map_type: pa.MapType, key_result: pa.DataType, value_result: pa.DataType) -> pa.DataType:
return pa.map_(key_result, value_result)
key_field = self.field(map_type.key_field, key_result)
value_field = self.field(map_type.item_field, value_result)
return pa.map_(key_type=key_field, item_type=value_field)

def primitive(self, primitive: pa.DataType) -> pa.DataType:
if primitive == pa.string():
Expand Down Expand Up @@ -1004,6 +1015,7 @@ def _task_to_record_batches(
positional_deletes: Optional[List[ChunkedArray]],
case_sensitive: bool,
name_mapping: Optional[NameMapping] = None,
with_large_types: bool = True,
) -> Iterator[pa.RecordBatch]:
_, _, path = PyArrowFileIO.parse_location(task.file.file_path)
arrow_format = ds.ParquetFileFormat(pre_buffer=True, buffer_size=(ONE_MEGABYTE * 8))
Expand Down Expand Up @@ -1049,7 +1061,7 @@ def _task_to_record_batches(
arrow_table = pa.Table.from_batches([batch])
arrow_table = arrow_table.filter(pyarrow_filter)
batch = arrow_table.to_batches()[0]
yield to_requested_schema(projected_schema, file_project_schema, batch)
yield to_requested_schema(projected_schema, file_project_schema, batch, with_large_types=with_large_types)
current_index += len(batch)


Expand All @@ -1062,11 +1074,22 @@ def _task_to_table(
positional_deletes: Optional[List[ChunkedArray]],
case_sensitive: bool,
name_mapping: Optional[NameMapping] = None,
with_large_types: bool = True,
) -> pa.Table:
batches = _task_to_record_batches(
fs, task, bound_row_filter, projected_schema, projected_field_ids, positional_deletes, case_sensitive, name_mapping
fs,
task,
bound_row_filter,
projected_schema,
projected_field_ids,
positional_deletes,
case_sensitive,
name_mapping,
with_large_types,
)
return pa.Table.from_batches(
batches, schema=schema_to_pyarrow(projected_schema, include_field_ids=False, with_large_types=with_large_types)
)
return pa.Table.from_batches(batches, schema=schema_to_pyarrow(projected_schema, include_field_ids=False))


def _read_all_delete_files(fs: FileSystem, tasks: Iterable[FileScanTask]) -> Dict[str, List[ChunkedArray]]:
Expand Down Expand Up @@ -1095,6 +1118,7 @@ def project_table(
projected_schema: Schema,
case_sensitive: bool = True,
limit: Optional[int] = None,
with_large_types: bool = True,
) -> pa.Table:
"""Resolve the right columns based on the identifier.

Expand Down Expand Up @@ -1146,6 +1170,7 @@ def project_table(
deletes_per_file.get(task.file.file_path),
case_sensitive,
table_metadata.name_mapping(),
with_large_types,
)
for task in tasks
]
Expand All @@ -1168,7 +1193,9 @@ def project_table(
tables = [f.result() for f in completed_futures if f.result()]

if len(tables) < 1:
return pa.Table.from_batches([], schema=schema_to_pyarrow(projected_schema, include_field_ids=False))
return pa.Table.from_batches(
[], schema=schema_to_pyarrow(projected_schema, include_field_ids=False, with_large_types=with_large_types)
)

result = pa.concat_tables(tables)

Expand All @@ -1186,6 +1213,7 @@ def project_batches(
projected_schema: Schema,
case_sensitive: bool = True,
limit: Optional[int] = None,
with_large_types: bool = True,
) -> Iterator[pa.RecordBatch]:
"""Resolve the right columns based on the identifier.

Expand Down Expand Up @@ -1238,6 +1266,7 @@ def project_batches(
deletes_per_file.get(task.file.file_path),
case_sensitive,
table_metadata.name_mapping(),
with_large_types,
)
for batch in batches:
if limit is not None:
Expand All @@ -1248,8 +1277,12 @@ def project_batches(
total_row_count += len(batch)


def to_requested_schema(requested_schema: Schema, file_schema: Schema, batch: pa.RecordBatch) -> pa.RecordBatch:
struct_array = visit_with_partner(requested_schema, batch, ArrowProjectionVisitor(file_schema), ArrowAccessor(file_schema))
def to_requested_schema(
requested_schema: Schema, file_schema: Schema, batch: pa.RecordBatch, with_large_types: bool = True
) -> pa.RecordBatch:
struct_array = visit_with_partner(
requested_schema, batch, ArrowProjectionVisitor(file_schema, with_large_types), ArrowAccessor(file_schema)
)

arrays = []
fields = []
Expand All @@ -1263,15 +1296,26 @@ def to_requested_schema(requested_schema: Schema, file_schema: Schema, batch: pa
class ArrowProjectionVisitor(SchemaWithPartnerVisitor[pa.Array, Optional[pa.Array]]):
file_schema: Schema

def __init__(self, file_schema: Schema):
def __init__(self, file_schema: Schema, with_large_types: bool = True):
self.file_schema = file_schema
self.with_large_types = with_large_types

def _cast_if_needed(self, field: NestedField, values: pa.Array) -> pa.Array:
file_field = self.file_schema.find_field(field.field_id)
if field.field_type.is_primitive:
if field.field_type != file_field.field_type:
return values.cast(schema_to_pyarrow(promote(file_field.field_type, field.field_type), include_field_ids=False))
elif (target_type := schema_to_pyarrow(field.field_type, include_field_ids=False)) != values.type:
return values.cast(
schema_to_pyarrow(
promote(file_field.field_type, field.field_type),
include_field_ids=False,
with_large_types=self.with_large_types,
)
)
elif (
target_type := schema_to_pyarrow(
field.field_type, include_field_ids=False, with_large_types=self.with_large_types
)
) != values.type:
# if file_field and field_type (e.g. String) are the same
# but the pyarrow type of the array is different from the expected type
# (e.g. string vs larger_string), we want to cast the array to the larger type
Expand Down Expand Up @@ -1302,7 +1346,7 @@ def struct(
field_arrays.append(array)
fields.append(self._construct_field(field, array.type))
elif field.optional:
arrow_type = schema_to_pyarrow(field.field_type, include_field_ids=False)
arrow_type = schema_to_pyarrow(field.field_type, include_field_ids=False, with_large_types=self.with_large_types)
field_arrays.append(pa.nulls(len(struct_array), type=arrow_type))
fields.append(self._construct_field(field, arrow_type))
else:
Expand All @@ -1320,7 +1364,10 @@ def list(self, list_type: ListType, list_array: Optional[pa.Array], value_array:
# https://github.com/apache/arrow/issues/38809
list_array = pa.LargeListArray.from_arrays(list_array.offsets, value_array)

arrow_field = pa.large_list(self._construct_field(list_type.element_field, value_array.type))
if self.with_large_types:
arrow_field = pa.large_list(self._construct_field(list_type.element_field, value_array.type))
else:
arrow_field = pa.list_(self._construct_field(list_type.element_field, value_array.type))
return list_array.cast(arrow_field)
else:
return None
Expand Down Expand Up @@ -1919,14 +1966,14 @@ def write_parquet(task: WriteTask) -> DataFile:
file_schema = table_schema

batches = [
to_requested_schema(requested_schema=file_schema, file_schema=table_schema, batch=batch)
to_requested_schema(requested_schema=file_schema, file_schema=table_schema, batch=batch, with_large_types=False)
for batch in task.record_batches
]
arrow_table = pa.Table.from_batches(batches)
file_path = f'{table_metadata.location}/data/{task.generate_data_file_path("parquet")}'
fo = io.new_output(file_path)
with fo.create(overwrite=True) as fos:
with pq.ParquetWriter(fos, schema=file_schema.as_arrow(), **parquet_writer_kwargs) as writer:
with pq.ParquetWriter(fos, schema=file_schema.as_arrow(with_large_types=False), **parquet_writer_kwargs) as writer:
writer.write(arrow_table, row_group_size=row_group_size)
statistics = data_file_statistics_from_parquet_metadata(
parquet_metadata=writer.writer.metadata,
Expand Down
4 changes: 2 additions & 2 deletions pyiceberg/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,11 +182,11 @@ def as_struct(self) -> StructType:
"""Return the schema as a struct."""
return StructType(*self.fields)

def as_arrow(self) -> "pa.Schema":
def as_arrow(self, with_large_types: bool = False) -> "pa.Schema":
"""Return the schema as an Arrow schema."""
from pyiceberg.io.pyarrow import schema_to_pyarrow

return schema_to_pyarrow(self)
return schema_to_pyarrow(self, with_large_types=with_large_types)

def find_field(self, name_or_id: Union[str, int], case_sensitive: bool = True) -> NestedField:
"""Find a field using a field name or field ID.
Expand Down
8 changes: 5 additions & 3 deletions pyiceberg/table/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2010,7 +2010,7 @@ def plan_files(self) -> Iterable[FileScanTask]:
for data_entry in data_entries
]

def to_arrow(self) -> pa.Table:
def to_arrow(self, with_large_types: bool = True) -> pa.Table:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hey @syun64 Thanks again for jumping on this issue. It is a very nasty one, so thanks for doing the hard work here.

Can I suggest one more direction? My first thoughts are that we should not bother the user with having to set this kind of flags. Instead, I think we can solve it when we concatenate the table:

image

When we do to_requested_schema, we can allow both a normal and a large string when we request a string type. When doing the concatenation of the batches into a table, we let Arrow coerce to a common type. WDYT?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thank you very much for taking the time to review @Fokko .

It’s great you brought this up because I didn’t feel great about introducing a flag either… but I felt like we needed a way for the user to control which type they would be using for their arrow table or RecordBatchReader.

Do you have a preference for which type (large or small) should be the common type for the schema? The reason I’ve introduced a flag here is because we would still need to choose to which type to use in the pyarrow schema we infer based on the Iceberg table schema. As we’ve discussed in this issue, I thought being intentional about which type we are choosing to represent our table or RecordBatchReader would make the behavior feel more consistent and error prone for the end user, than the alternative of rendering the type that PyArrow infers based on the parquet file.

If this does not sound like a great candidate for an API argument, would having a configuration to control this behavior be a better option? I think that was an idea that was discussed in a previous discussion here. Please let me know!

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I felt like we needed a way for the user to control which type they would be using for their arrow table or RecordBatchReader

I don't think we should expose this in the public API. Do people want to control this? In an ideal world:

  • When writing you want to take the type that's being handed to PyIceberg from the user
  • When reading you want to take this information from what comes out of the Parquet files

My first assumption was to go with the large one since that seems what most libraries seem to be using. But unfortunately, that doesn't seem to be the case.

Copy link
Contributor

@Fokko Fokko Jul 8, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@syun64 @HonahX I've played around with this, and I think we can let Arrow decide on the types: #902

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think we should expose this in the public API. Do people want to control this? In an ideal world:

  • When writing you want to take the type that's being handed to PyIceberg from the user
  • When reading you want to take this information from what comes out of the Parquet files

I agree with this in the ideal world. However, PyArrow API cannot handle both large_* and normal types in its APIs without the us manually casting the type to one or the other. For example, the RecordBatchReader will fail to produce the next RecordBatch if the schema doesn't align completely, and requires us to choose one and always cast the types.

If the concern is in exposing this option in the public API, I think we can walk back on this change and remove it from:

  • to_arrow_batch_reader()
  • to_arrow_table()
  • to_requested_schema()

But we may still need it in schema_to_pyarrow because here, we are making an opinionated decision about the type we are choosing to represent the data as, for when we write and for when we read.

My first assumption was to go with the large one since that seems what most libraries seem to be using. But unfortunately, that doesn't seem to be the case.

  • I think this is the case for daft and polars:

Daft:

>>> import pyarrow as pa
>>> import pyarrow.parquet as pq
>>> import daft
>>> daft.read_parquet("strings.parquet").to_arrow()
pyarrow.Table                                                                                                                                                                                                                                                                                                         
strings: large_string
----
strings: [["a","b"]]
>>> daft.read_parquet("strings.parquet").to_arrow()
pyarrow.Table                                                                                                                                                                                                                                                                                                         
strings: large_string
----
strings: [["a","b"]]
>>> daft.read_parquet("strings.parquet").to_arrow().cast(pa.schema([("strings", pa.string())])).write_parquet("small-strings.parquet")
Traceback (most recent call last):                                                                                                                                                                                                                                                                                    
  File "<stdin>", line 1, in <module>
AttributeError: 'pyarrow.lib.Table' object has no attribute 'write_parquet'
>>> daft.from_arrow(daft.read_parquet("strings.parquet").to_arrow().cast(pa.schema([("strings", pa.string())]))).write_parquet("small-strings.parquet")
╭────────────────────────────────╮                                                                                                                                                                                                                                                                                    
│ path                           │
│ ---                            │
│ Utf8                           │
╞════════════════════════════════╡
│ small-strings.parquet/74515f6… │
╰────────────────────────────────╯

(Showing first 1 of 1 rows)
>>> daft.read_parquet("small-strings.parquet").to_arrow()
pyarrow.Table                                                                                                                                                                                                                                                                                                         
strings: large_string
----
strings: [["a","b"]]
>>> pq.read_table("small-strings.parquet")
pyarrow.Table
strings: large_string
----
strings: [["a","b"]]

from pyiceberg.io.pyarrow import project_table

return project_table(
Expand All @@ -2021,15 +2021,16 @@ def to_arrow(self) -> pa.Table:
self.projection(),
case_sensitive=self.case_sensitive,
limit=self.limit,
with_large_types=with_large_types,
)

def to_arrow_batch_reader(self) -> pa.RecordBatchReader:
def to_arrow_batch_reader(self, with_large_types: bool = True) -> pa.RecordBatchReader:
import pyarrow as pa

from pyiceberg.io.pyarrow import project_batches, schema_to_pyarrow

return pa.RecordBatchReader.from_batches(
schema_to_pyarrow(self.projection()),
schema_to_pyarrow(self.projection(), include_field_ids=False, with_large_types=with_large_types),
project_batches(
self.plan_files(),
self.table_metadata,
Expand All @@ -2038,6 +2039,7 @@ def to_arrow_batch_reader(self) -> pa.RecordBatchReader:
self.projection(),
case_sensitive=self.case_sensitive,
limit=self.limit,
with_large_types=with_large_types,
),
)

Expand Down
Loading