Merged
Changes from 6 commits
3 changes: 1 addition & 2 deletions pyiceberg/catalog/__init__.py
@@ -49,7 +49,6 @@
from pyiceberg.schema import Schema
from pyiceberg.serializers import ToOutputFile
from pyiceberg.table import (
DOWNCAST_NS_TIMESTAMP_TO_US_ON_WRITE,
CommitTableRequest,
CommitTableResponse,
CreateTableTransaction,
@@ -674,7 +673,7 @@ def _convert_schema_if_needed(schema: Union[Schema, "pa.Schema"]) -> Schema:
try:
import pyarrow as pa

from pyiceberg.io.pyarrow import _ConvertToIcebergWithoutIDs, visit_pyarrow
from pyiceberg.io.pyarrow import DOWNCAST_NS_TIMESTAMP_TO_US_ON_WRITE, _ConvertToIcebergWithoutIDs, visit_pyarrow

downcast_ns_timestamp_to_us = Config().get_bool(DOWNCAST_NS_TIMESTAMP_TO_US_ON_WRITE) or False
if isinstance(schema, pa.Schema):
50 changes: 49 additions & 1 deletion pyiceberg/io/pyarrow.py
@@ -166,6 +166,8 @@

ONE_MEGABYTE = 1024 * 1024
BUFFER_SIZE = "buffer-size"
DOWNCAST_NS_TIMESTAMP_TO_US_ON_WRITE = "downcast-ns-timestamp-to-us-on-write"

ICEBERG_SCHEMA = b"iceberg.schema"
# The PARQUET: in front means that it is Parquet specific, in this case the field_id
PYARROW_PARQUET_FIELD_ID_KEY = b"PARQUET:field_id"
@@ -1934,7 +1936,7 @@ def data_file_statistics_from_parquet_metadata(


def write_file(io: FileIO, table_metadata: TableMetadata, tasks: Iterator[WriteTask]) -> Iterator[DataFile]:
from pyiceberg.table import DOWNCAST_NS_TIMESTAMP_TO_US_ON_WRITE, PropertyUtil, TableProperties
from pyiceberg.table import PropertyUtil, TableProperties

parquet_writer_kwargs = _get_parquet_writer_kwargs(table_metadata.properties)
row_group_size = PropertyUtil.property_as_int(
@@ -2015,6 +2017,50 @@ def bin_pack_arrow_table(tbl: pa.Table, target_file_size: int) -> Iterator[List[
return bin_packed_record_batches


def _check_schema_compatible(table_schema: Schema, other_schema: pa.Schema) -> None:
"""
Check if the `table_schema` is compatible with `other_schema`.

Two schemas are considered compatible when they are equal in terms of the Iceberg Schema type.

Raises:
ValueError: If the schemas are not compatible.
"""
name_mapping = table_schema.name_mapping
downcast_ns_timestamp_to_us = Config().get_bool(DOWNCAST_NS_TIMESTAMP_TO_US_ON_WRITE) or False
try:
task_schema = pyarrow_to_schema(
other_schema, name_mapping=name_mapping, downcast_ns_timestamp_to_us=downcast_ns_timestamp_to_us
)
except ValueError as e:
other_schema = _pyarrow_to_schema_without_ids(other_schema, downcast_ns_timestamp_to_us=downcast_ns_timestamp_to_us)
additional_names = set(other_schema.column_names) - set(table_schema.column_names)
raise ValueError(
f"PyArrow table contains more columns: {', '.join(sorted(additional_names))}. Update the schema first (hint, use union_by_name)."
) from e

if table_schema.as_struct() != task_schema.as_struct():
from rich.console import Console
from rich.table import Table as RichTable

console = Console(record=True)

rich_table = RichTable(show_header=True, header_style="bold")
rich_table.add_column("")
rich_table.add_column("Table field")
rich_table.add_column("Dataframe field")

for lhs in table_schema.fields:
try:
rhs = task_schema.find_field(lhs.field_id)
rich_table.add_row("✅" if lhs == rhs else "❌", str(lhs), str(rhs))
except ValueError:
rich_table.add_row("❌", str(lhs), "Missing")

console.print(rich_table)
raise ValueError(f"Mismatch in fields:\n{console.export_text()}")


def parquet_files_to_data_files(io: FileIO, table_metadata: TableMetadata, file_paths: Iterator[str]) -> Iterator[DataFile]:
for file_path in file_paths:
input_file = io.new_input(file_path)
@@ -2026,6 +2072,8 @@ def parquet_files_to_data_files(io: FileIO, table_metadata: TableMetadata, file_
f"Cannot add file {file_path} because it has field IDs. `add_files` only supports addition of files without field_ids"
)
schema = table_metadata.schema()
_check_schema_compatible(schema, parquet_metadata.schema.to_arrow_schema())
Contributor (HonahX) commented:
My understanding is that if we enable downcast-ns-timestamp-to-us-on-write, we now allow users to add Parquet files containing TIMESTAMP_NANOS data. My concern is that we may add Parquet files that do not align with the spec, which states that the timestamp/timestamptz types should map to TIMESTAMP_MICROS. Shall we be more restrictive when checking Parquet files that will be added directly to the table?
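As a rough sketch of the stricter check being suggested (illustration only, not code from this PR; the helper name and file name are made up), a file's Arrow schema could be scanned for nanosecond timestamps before it is admitted through add_files:

import pyarrow as pa
import pyarrow.parquet as pq

def has_ns_timestamp(schema: pa.Schema) -> bool:
    # Only inspects top-level fields; nested struct/list/map fields would need a recursive walk.
    return any(pa.types.is_timestamp(field.type) and field.type.unit == "ns" for field in schema)

arrow_schema = pq.read_schema("example.parquet")  # placeholder file name
if has_ns_timestamp(arrow_schema):
    raise ValueError("Refusing to add a file with TIMESTAMP_NANOS; the spec maps timestamp/timestamptz to TIMESTAMP_MICROS")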

Collaborator (PR author) replied:
Thank you for the catch @HonahX - I think you are right. Adding a file with nanosecond timestamps means Spark's Iceberg reader cannot read it correctly and instead fails with exceptions like:

ValueError: year 53177 is out of range

I will make downcast_ns_timestamp_to_us_on_write an input argument to _check_schema_compatible, so that we can prevent nanosecond timestamp types from being added through add_files, while continuing to support downcasting them on overwrite/append.
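A minimal sketch of that direction (an assumption about the follow-up change, not the code as merged): the flag becomes a parameter, so append/overwrite can pass the configured value while add_files passes False and nanosecond timestamps fail the conversion instead of being silently narrowed.

import pyarrow as pa

from pyiceberg.io.pyarrow import pyarrow_to_schema
from pyiceberg.schema import Schema

def check_schema_compatible_sketch(
    table_schema: Schema,
    other_schema: pa.Schema,
    downcast_ns_timestamp_to_us: bool = False,  # decided by the caller instead of read from Config
) -> None:
    # With downcasting disabled, nanosecond timestamp fields do not convert and an error is raised.
    task_schema = pyarrow_to_schema(
        other_schema,
        name_mapping=table_schema.name_mapping,
        downcast_ns_timestamp_to_us=downcast_ns_timestamp_to_us,
    )
    if table_schema.as_struct() != task_schema.as_struct():
        raise ValueError("Mismatch in fields between table schema and dataframe schema")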

Contributor commented:
In general, I think it is okay to be able to read more broadly. We do need tests to ensure that it works correctly. Looking at Arrow, there are already some physical types that we don't support (date64, etc.). In Java we do support reading timestamps that are encoded using INT96, but we should not produce them.
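For context, a small example (not part of this change; the file name is a placeholder) of listing the physical types recorded in a Parquet footer with PyArrow; this is where an INT96-encoded timestamp would show up, independent of the Arrow-level logical type:

import pyarrow.parquet as pq

metadata = pq.ParquetFile("example.parquet").metadata  # placeholder file name
for i in range(metadata.num_columns):
    column = metadata.schema.column(i)
    # physical_type is e.g. "INT64", "INT96" or "BYTE_ARRAY"; logical_type carries the timestamp unit, if any.
    print(column.name, column.physical_type, column.logical_type)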


statistics = data_file_statistics_from_parquet_metadata(
parquet_metadata=parquet_metadata,
stats_columns=compute_statistics_plan(schema, table_metadata.properties),
50 changes: 1 addition & 49 deletions pyiceberg/table/__init__.py
@@ -73,7 +73,7 @@
manifest_evaluator,
)
from pyiceberg.io import FileIO, OutputFile, load_file_io
from pyiceberg.io.pyarrow import _dataframe_to_data_files, expression_to_pyarrow, project_table
from pyiceberg.io.pyarrow import _check_schema_compatible, _dataframe_to_data_files, expression_to_pyarrow, project_table
from pyiceberg.manifest import (
POSITIONAL_DELETE_SCHEMA,
DataFile,
@@ -151,7 +151,6 @@
)
from pyiceberg.utils.bin_packing import ListPacker
from pyiceberg.utils.concurrent import ExecutorFactory
from pyiceberg.utils.config import Config
from pyiceberg.utils.datetime import datetime_to_millis
from pyiceberg.utils.deprecated import deprecated
from pyiceberg.utils.singleton import _convert_to_hashable_type
@@ -167,56 +166,9 @@

ALWAYS_TRUE = AlwaysTrue()
TABLE_ROOT_ID = -1
DOWNCAST_NS_TIMESTAMP_TO_US_ON_WRITE = "downcast-ns-timestamp-to-us-on-write"
_JAVA_LONG_MAX = 9223372036854775807


def _check_schema_compatible(table_schema: Schema, other_schema: "pa.Schema") -> None:
"""
Check if the `table_schema` is compatible with `other_schema`.

Two schemas are considered compatible when they are equal in terms of the Iceberg Schema type.

Raises:
ValueError: If the schemas are not compatible.
"""
from pyiceberg.io.pyarrow import _pyarrow_to_schema_without_ids, pyarrow_to_schema

downcast_ns_timestamp_to_us = Config().get_bool(DOWNCAST_NS_TIMESTAMP_TO_US_ON_WRITE) or False
name_mapping = table_schema.name_mapping
try:
task_schema = pyarrow_to_schema(
other_schema, name_mapping=name_mapping, downcast_ns_timestamp_to_us=downcast_ns_timestamp_to_us
)
except ValueError as e:
other_schema = _pyarrow_to_schema_without_ids(other_schema, downcast_ns_timestamp_to_us=downcast_ns_timestamp_to_us)
additional_names = set(other_schema.column_names) - set(table_schema.column_names)
raise ValueError(
f"PyArrow table contains more columns: {', '.join(sorted(additional_names))}. Update the schema first (hint, use union_by_name)."
) from e

if table_schema.as_struct() != task_schema.as_struct():
from rich.console import Console
from rich.table import Table as RichTable

console = Console(record=True)

rich_table = RichTable(show_header=True, header_style="bold")
rich_table.add_column("")
rich_table.add_column("Table field")
rich_table.add_column("Dataframe field")

for lhs in table_schema.fields:
try:
rhs = task_schema.find_field(lhs.field_id)
rich_table.add_row("✅" if lhs == rhs else "❌", str(lhs), str(rhs))
except ValueError:
rich_table.add_row("❌", str(lhs), "Missing")

console.print(rich_table)
raise ValueError(f"Mismatch in fields:\n{console.export_text()}")


class TableProperties:
PARQUET_ROW_GROUP_SIZE_BYTES = "write.parquet.row-group-size-bytes"
PARQUET_ROW_GROUP_SIZE_BYTES_DEFAULT = 128 * 1024 * 1024 # 128 MB
51 changes: 51 additions & 0 deletions tests/integration/test_add_files.py
@@ -453,6 +453,57 @@ def test_add_files_snapshot_properties(spark: SparkSession, session_catalog: Cat
assert summary["snapshot_prop_a"] == "test_prop_a"


@pytest.mark.integration
def test_add_files_fails_on_schema_mismatch(spark: SparkSession, session_catalog: Catalog, format_version: int) -> None:
identifier = f"default.table_schema_mismatch_fails_v{format_version}"

tbl = _create_table(session_catalog, identifier, format_version)
WRONG_SCHEMA = pa.schema([
("foo", pa.bool_()),
("bar", pa.string()),
("baz", pa.string()), # should be integer
("qux", pa.date32()),
])
file_path = f"s3://warehouse/default/table_schema_mismatch_fails/v{format_version}/test.parquet"
# write parquet files
fo = tbl.io.new_output(file_path)
with fo.create(overwrite=True) as fos:
with pq.ParquetWriter(fos, schema=WRONG_SCHEMA) as writer:
writer.write_table(
pa.Table.from_pylist(
[
{
"foo": True,
"bar": "bar_string",
"baz": "123",
"qux": date(2024, 3, 7),
},
{
"foo": True,
"bar": "bar_string",
"baz": "124",
"qux": date(2024, 3, 7),
},
],
schema=WRONG_SCHEMA,
)
)

expected = """Mismatch in fields:
┏━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━┓
┃ ┃ Table field ┃ Dataframe field ┃
┡━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━┩
│ ✅ │ 1: foo: optional boolean │ 1: foo: optional boolean │
│ ✅ │ 2: bar: optional string  │ 2: bar: optional string  │
│ ❌ │ 3: baz: optional int │ 3: baz: optional string │
│ ✅ │ 4: qux: optional date │ 4: qux: optional date │
└────┴──────────────────────────┴──────────────────────────┘
"""

with pytest.raises(ValueError, match=expected):
tbl.add_files(file_paths=[file_path])


@pytest.mark.integration
def test_timestamp_tz_ns_downcast_on_read(session_catalog: Catalog, format_version: int, mocker: MockerFixture) -> None:
nanoseconds_schema_iceberg = Schema(NestedField(1, "quux", TimestamptzType()))
91 changes: 91 additions & 0 deletions tests/io/test_pyarrow.py
@@ -60,6 +60,7 @@
PyArrowFile,
PyArrowFileIO,
StatsAggregator,
_check_schema_compatible,
_ConvertToArrowSchema,
_primitive_to_physical,
_read_deletes,
@@ -1718,3 +1719,93 @@ def test_bin_pack_arrow_table(arrow_table_with_null: pa.Table) -> None:
# and will produce half the number of files if we double the target size
bin_packed = bin_pack_arrow_table(bigger_arrow_tbl, target_file_size=arrow_table_with_null.nbytes * 2)
assert len(list(bin_packed)) == 5


def test_schema_mismatch_type(table_schema_simple: Schema) -> None:
other_schema = pa.schema((
pa.field("foo", pa.string(), nullable=True),
pa.field("bar", pa.decimal128(18, 6), nullable=False),
pa.field("baz", pa.bool_(), nullable=True),
))

expected = r"""Mismatch in fields:
┏━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┓
┃ ┃ Table field ┃ Dataframe field ┃
┡━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┩
│ ✅ │ 1: foo: optional string │ 1: foo: optional string │
│ ❌ │ 2: bar: required int │ 2: bar: required decimal\(18, 6\) │
│ ✅ │ 3: baz: optional boolean │ 3: baz: optional boolean │
└────┴──────────────────────────┴─────────────────────────────────┘
"""

with pytest.raises(ValueError, match=expected):
_check_schema_compatible(table_schema_simple, other_schema)


def test_schema_mismatch_nullability(table_schema_simple: Schema) -> None:
other_schema = pa.schema((
pa.field("foo", pa.string(), nullable=True),
pa.field("bar", pa.int32(), nullable=True),
pa.field("baz", pa.bool_(), nullable=True),
))

expected = """Mismatch in fields:
┏━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━┓
┃ ┃ Table field ┃ Dataframe field ┃
┡━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━┩
│ ✅ │ 1: foo: optional string │ 1: foo: optional string │
│ ❌ │ 2: bar: required int │ 2: bar: optional int │
│ ✅ │ 3: baz: optional boolean │ 3: baz: optional boolean │
└────┴──────────────────────────┴──────────────────────────┘
"""

with pytest.raises(ValueError, match=expected):
_check_schema_compatible(table_schema_simple, other_schema)


def test_schema_mismatch_missing_field(table_schema_simple: Schema) -> None:
other_schema = pa.schema((
pa.field("foo", pa.string(), nullable=True),
pa.field("baz", pa.bool_(), nullable=True),
))

expected = """Mismatch in fields:
┏━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━┓
┃ ┃ Table field ┃ Dataframe field ┃
┡━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━┩
│ ✅ │ 1: foo: optional string │ 1: foo: optional string │
│ ❌ │ 2: bar: required int │ Missing │
│ ✅ │ 3: baz: optional boolean │ 3: baz: optional boolean │
└────┴──────────────────────────┴──────────────────────────┘
"""

with pytest.raises(ValueError, match=expected):
_check_schema_compatible(table_schema_simple, other_schema)


def test_schema_mismatch_additional_field(table_schema_simple: Schema) -> None:
other_schema = pa.schema((
pa.field("foo", pa.string(), nullable=True),
pa.field("bar", pa.int32(), nullable=True),
pa.field("baz", pa.bool_(), nullable=True),
pa.field("new_field", pa.date32(), nullable=True),
))

expected = r"PyArrow table contains more columns: new_field. Update the schema first \(hint, use union_by_name\)."

with pytest.raises(ValueError, match=expected):
_check_schema_compatible(table_schema_simple, other_schema)


def test_schema_downcast(table_schema_simple: Schema) -> None:
# large_string type is compatible with string type
other_schema = pa.schema((
pa.field("foo", pa.large_string(), nullable=True),
pa.field("bar", pa.int32(), nullable=False),
pa.field("baz", pa.bool_(), nullable=True),
))

try:
_check_schema_compatible(table_schema_simple, other_schema)
except Exception:
pytest.fail("Unexpected Exception raised when calling `_check_schema`")