From 3dcc344e8e8511b8d0d4087b16c8cc1bb4af04a2 Mon Sep 17 00:00:00 2001 From: Kevin Liu Date: Wed, 13 Mar 2024 21:25:35 -0700 Subject: [PATCH 1/8] cast to pyarrow schema --- pyiceberg/table/__init__.py | 5 +++++ tests/catalog/test_sql.py | 33 +++++++++++++++++++++++++++++++++ 2 files changed, 38 insertions(+) diff --git a/pyiceberg/table/__init__.py b/pyiceberg/table/__init__.py index 2b15cdeb08..bfcdae5150 100644 --- a/pyiceberg/table/__init__.py +++ b/pyiceberg/table/__init__.py @@ -1147,6 +1147,8 @@ def overwrite( except ModuleNotFoundError as e: raise ModuleNotFoundError("For writes PyArrow needs to be installed") from e + from pyiceberg.io.pyarrow import schema_to_pyarrow + if not isinstance(df, pa.Table): raise ValueError(f"Expected PyArrow table, got: {df}") @@ -1157,6 +1159,9 @@ def overwrite( raise ValueError("Cannot write to partitioned tables") _check_schema(self.schema(), other_schema=df.schema) + # safe to cast + pyarrow_schema = schema_to_pyarrow(self.schema()) + df = df.cast(pyarrow_schema) with self.transaction() as txn: with txn.update_snapshot(snapshot_properties=snapshot_properties).overwrite() as update_snapshot: diff --git a/tests/catalog/test_sql.py b/tests/catalog/test_sql.py index 3a77f8678a..b20f617e32 100644 --- a/tests/catalog/test_sql.py +++ b/tests/catalog/test_sql.py @@ -193,6 +193,39 @@ def test_create_table_with_pyarrow_schema( catalog.drop_table(random_identifier) +@pytest.mark.parametrize( + 'catalog', + [ + lazy_fixture('catalog_memory'), + # lazy_fixture('catalog_sqlite'), + ], +) +def test_write_pyarrow_schema(catalog: SqlCatalog, random_identifier: Identifier) -> None: + import pyarrow as pa + + pyarrow_table = pa.Table.from_arrays( + [ + pa.array([None, "A", "B", "C"]), # 'foo' column + pa.array([1, 2, 3, 4]), # 'bar' column + pa.array([True, None, False, True]), # 'baz' column + pa.array([None, "A", "B", "C"]), # 'large' column + ], + schema=pa.schema([ + pa.field('foo', pa.string(), nullable=True), + pa.field('bar', pa.int32(), nullable=False), + pa.field('baz', pa.bool_(), nullable=True), + pa.field('large', pa.large_string(), nullable=True), + ]), + ) + database_name, _table_name = random_identifier + catalog.create_namespace(database_name) + table = catalog.create_table(random_identifier, pyarrow_table.schema) + print(pyarrow_table.schema) + print(table.schema().as_struct()) + print() + table.overwrite(pyarrow_table) + + @pytest.mark.parametrize( 'catalog', [ From 05e74442c6d50b4ea1c508e6a04c674b96cf5f76 Mon Sep 17 00:00:00 2001 From: Kevin Liu Date: Sun, 24 Mar 2024 12:37:44 -0700 Subject: [PATCH 2/8] use Schema.as_arrow() --- pyiceberg/table/__init__.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/pyiceberg/table/__init__.py b/pyiceberg/table/__init__.py index bfcdae5150..2d27ec469d 100644 --- a/pyiceberg/table/__init__.py +++ b/pyiceberg/table/__init__.py @@ -1147,8 +1147,6 @@ def overwrite( except ModuleNotFoundError as e: raise ModuleNotFoundError("For writes PyArrow needs to be installed") from e - from pyiceberg.io.pyarrow import schema_to_pyarrow - if not isinstance(df, pa.Table): raise ValueError(f"Expected PyArrow table, got: {df}") @@ -1160,8 +1158,7 @@ def overwrite( _check_schema(self.schema(), other_schema=df.schema) # safe to cast - pyarrow_schema = schema_to_pyarrow(self.schema()) - df = df.cast(pyarrow_schema) + df = df.cast(self.schema().as_arrow()) with self.transaction() as txn: with txn.update_snapshot(snapshot_properties=snapshot_properties).overwrite() as update_snapshot: From d231dbc6507854ee8e5a8c2445322dcfad803c75 Mon Sep 17 00:00:00 2001 From: Kevin Liu Date: Sun, 24 Mar 2024 16:20:20 -0700 Subject: [PATCH 3/8] also for append --- pyiceberg/table/__init__.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/pyiceberg/table/__init__.py b/pyiceberg/table/__init__.py index 2d27ec469d..ee89efa93d 100644 --- a/pyiceberg/table/__init__.py +++ b/pyiceberg/table/__init__.py @@ -1119,6 +1119,8 @@ def append(self, df: pa.Table, snapshot_properties: Dict[str, str] = EMPTY_DICT) raise ValueError("Cannot write to partitioned tables") _check_schema(self.schema(), other_schema=df.schema) + # safe to cast + df = df.cast(self.schema().as_arrow()) with self.transaction() as txn: with txn.update_snapshot(snapshot_properties=snapshot_properties).fast_append() as update_snapshot: From 5b553ab6873cd81cd9975f150d5c5f76cb90796a Mon Sep 17 00:00:00 2001 From: Kevin Liu Date: Mon, 25 Mar 2024 09:29:24 -0700 Subject: [PATCH 4/8] _check_schema_compatible --- pyiceberg/table/__init__.py | 14 +++++++++++--- tests/table/test_init.py | 10 +++++----- 2 files changed, 16 insertions(+), 8 deletions(-) diff --git a/pyiceberg/table/__init__.py b/pyiceberg/table/__init__.py index ee89efa93d..e6c8efd98d 100644 --- a/pyiceberg/table/__init__.py +++ b/pyiceberg/table/__init__.py @@ -145,7 +145,15 @@ _JAVA_LONG_MAX = 9223372036854775807 -def _check_schema(table_schema: Schema, other_schema: "pa.Schema") -> None: +def _check_schema_compatible(table_schema: Schema, other_schema: "pa.Schema") -> None: + """ + Check if the `table_schema` is compatible with `other_schema`. + + Two schemas are considered compatible when they are equal in terms of the Iceberg Schema type. + + Raises: + ValueError: If the schemas are not compatible. + """ from pyiceberg.io.pyarrow import _pyarrow_to_schema_without_ids, pyarrow_to_schema name_mapping = table_schema.name_mapping @@ -1118,7 +1126,7 @@ def append(self, df: pa.Table, snapshot_properties: Dict[str, str] = EMPTY_DICT) if len(self.spec().fields) > 0: raise ValueError("Cannot write to partitioned tables") - _check_schema(self.schema(), other_schema=df.schema) + _check_schema_compatible(self.schema(), other_schema=df.schema) # safe to cast df = df.cast(self.schema().as_arrow()) @@ -1158,7 +1166,7 @@ def overwrite( if len(self.spec().fields) > 0: raise ValueError("Cannot write to partitioned tables") - _check_schema(self.schema(), other_schema=df.schema) + _check_schema_compatible(self.schema(), other_schema=df.schema) # safe to cast df = df.cast(self.schema().as_arrow()) diff --git a/tests/table/test_init.py b/tests/table/test_init.py index bb212d696e..5459e9c79b 100644 --- a/tests/table/test_init.py +++ b/tests/table/test_init.py @@ -63,7 +63,7 @@ TableIdentifier, UpdateSchema, _apply_table_update, - _check_schema, + _check_schema_compatible, _match_deletes_to_data_file, _TableMetadataUpdateContext, update_table_metadata, @@ -1033,7 +1033,7 @@ def test_schema_mismatch_type(table_schema_simple: Schema) -> None: """ with pytest.raises(ValueError, match=expected): - _check_schema(table_schema_simple, other_schema) + _check_schema_compatible(table_schema_simple, other_schema) def test_schema_mismatch_nullability(table_schema_simple: Schema) -> None: @@ -1054,7 +1054,7 @@ def test_schema_mismatch_nullability(table_schema_simple: Schema) -> None: """ with pytest.raises(ValueError, match=expected): - _check_schema(table_schema_simple, other_schema) + _check_schema_compatible(table_schema_simple, other_schema) def test_schema_mismatch_missing_field(table_schema_simple: Schema) -> None: @@ -1074,7 +1074,7 @@ def test_schema_mismatch_missing_field(table_schema_simple: Schema) -> None: """ with pytest.raises(ValueError, match=expected): - _check_schema(table_schema_simple, other_schema) + _check_schema_compatible(table_schema_simple, other_schema) def test_schema_mismatch_additional_field(table_schema_simple: Schema) -> None: @@ -1088,7 +1088,7 @@ def test_schema_mismatch_additional_field(table_schema_simple: Schema) -> None: expected = r"PyArrow table contains more columns: new_field. Update the schema first \(hint, use union_by_name\)." with pytest.raises(ValueError, match=expected): - _check_schema(table_schema_simple, other_schema) + _check_schema_compatible(table_schema_simple, other_schema) def test_table_properties(example_table_metadata_v2: Dict[str, Any]) -> None: From 5103d8a8af2b8bfff7d080adda96263ef5695f30 Mon Sep 17 00:00:00 2001 From: Kevin Liu Date: Mon, 25 Mar 2024 09:30:15 -0700 Subject: [PATCH 5/8] comment --- pyiceberg/table/__init__.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pyiceberg/table/__init__.py b/pyiceberg/table/__init__.py index e6c8efd98d..1e81357367 100644 --- a/pyiceberg/table/__init__.py +++ b/pyiceberg/table/__init__.py @@ -1127,7 +1127,7 @@ def append(self, df: pa.Table, snapshot_properties: Dict[str, str] = EMPTY_DICT) raise ValueError("Cannot write to partitioned tables") _check_schema_compatible(self.schema(), other_schema=df.schema) - # safe to cast + # the two schemas are compatible so safe to cast df = df.cast(self.schema().as_arrow()) with self.transaction() as txn: @@ -1167,7 +1167,7 @@ def overwrite( raise ValueError("Cannot write to partitioned tables") _check_schema_compatible(self.schema(), other_schema=df.schema) - # safe to cast + # the two schemas are compatible so safe to cast df = df.cast(self.schema().as_arrow()) with self.transaction() as txn: From f565dc83a78f994c3665370d14d3683133c6e7a6 Mon Sep 17 00:00:00 2001 From: Kevin Liu Date: Mon, 25 Mar 2024 09:31:36 -0700 Subject: [PATCH 6/8] use .as_arrow() --- pyiceberg/io/pyarrow.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyiceberg/io/pyarrow.py b/pyiceberg/io/pyarrow.py index 72de14880a..9dbb5e8abc 100644 --- a/pyiceberg/io/pyarrow.py +++ b/pyiceberg/io/pyarrow.py @@ -1784,7 +1784,7 @@ def write_file(io: FileIO, table_metadata: TableMetadata, tasks: Iterator[WriteT file_path = f'{table_metadata.location}/data/{task.generate_data_file_filename("parquet")}' schema = table_metadata.schema() - arrow_file_schema = schema_to_pyarrow(schema) + arrow_file_schema = schema.as_arrow() fo = io.new_output(file_path) row_group_size = PropertyUtil.property_as_int( From 1d6a08c204ae2a891af679f68449964acf5f046e Mon Sep 17 00:00:00 2001 From: Kevin Liu Date: Mon, 25 Mar 2024 09:36:35 -0700 Subject: [PATCH 7/8] add test for downcast schema --- tests/table/test_init.py | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/tests/table/test_init.py b/tests/table/test_init.py index 5459e9c79b..f1191295f3 100644 --- a/tests/table/test_init.py +++ b/tests/table/test_init.py @@ -1091,6 +1091,20 @@ def test_schema_mismatch_additional_field(table_schema_simple: Schema) -> None: _check_schema_compatible(table_schema_simple, other_schema) +def test_schema_downcast(table_schema_simple: Schema) -> None: + # large_string type is compatible with string type + other_schema = pa.schema(( + pa.field("foo", pa.large_string(), nullable=True), + pa.field("bar", pa.int32(), nullable=False), + pa.field("baz", pa.bool_(), nullable=True), + )) + + try: + _check_schema_compatible(table_schema_simple, other_schema) + except Exception: + pytest.fail("Unexpected Exception raised when calling `_check_schema`") + + def test_table_properties(example_table_metadata_v2: Dict[str, Any]) -> None: # metadata properties are all strings for k, v in example_table_metadata_v2["properties"].items(): From 6c7ca99758caef9f07de1ea34b222f5e63d892b3 Mon Sep 17 00:00:00 2001 From: Kevin Liu Date: Mon, 25 Mar 2024 20:20:30 -0700 Subject: [PATCH 8/8] cast only when necessary --- pyiceberg/table/__init__.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/pyiceberg/table/__init__.py b/pyiceberg/table/__init__.py index 1e81357367..2ad1f7fe81 100644 --- a/pyiceberg/table/__init__.py +++ b/pyiceberg/table/__init__.py @@ -1127,8 +1127,9 @@ def append(self, df: pa.Table, snapshot_properties: Dict[str, str] = EMPTY_DICT) raise ValueError("Cannot write to partitioned tables") _check_schema_compatible(self.schema(), other_schema=df.schema) - # the two schemas are compatible so safe to cast - df = df.cast(self.schema().as_arrow()) + # cast if the two schemas are compatible but not equal + if self.schema().as_arrow() != df.schema: + df = df.cast(self.schema().as_arrow()) with self.transaction() as txn: with txn.update_snapshot(snapshot_properties=snapshot_properties).fast_append() as update_snapshot: @@ -1167,8 +1168,9 @@ def overwrite( raise ValueError("Cannot write to partitioned tables") _check_schema_compatible(self.schema(), other_schema=df.schema) - # the two schemas are compatible so safe to cast - df = df.cast(self.schema().as_arrow()) + # cast if the two schemas are compatible but not equal + if self.schema().as_arrow() != df.schema: + df = df.cast(self.schema().as_arrow()) with self.transaction() as txn: with txn.update_snapshot(snapshot_properties=snapshot_properties).overwrite() as update_snapshot: