diff --git a/pyiceberg/table/__init__.py b/pyiceberg/table/__init__.py
index 7c63aa79a1..e5572e6e52 100644
--- a/pyiceberg/table/__init__.py
+++ b/pyiceberg/table/__init__.py
@@ -859,7 +859,11 @@ def upsert(
         return UpsertResult(rows_updated=update_row_cnt, rows_inserted=insert_row_cnt)
 
     def add_files(
-        self, file_paths: List[str], snapshot_properties: Dict[str, str] = EMPTY_DICT, check_duplicate_files: bool = True
+        self,
+        file_paths: List[str],
+        snapshot_properties: Dict[str, str] = EMPTY_DICT,
+        check_duplicate_files: bool = True,
+        branch: Optional[str] = MAIN_BRANCH,
     ) -> None:
         """
         Shorthand API for adding files as data files to the table transaction.
@@ -888,12 +892,12 @@ def add_files(
             self.set_properties(
                 **{TableProperties.DEFAULT_NAME_MAPPING: self.table_metadata.schema().name_mapping.model_dump_json()}
             )
-        with self.update_snapshot(snapshot_properties=snapshot_properties).fast_append() as update_snapshot:
+        with self._append_snapshot_producer(snapshot_properties, branch=branch) as append_files:
             data_files = _parquet_files_to_data_files(
                 table_metadata=self.table_metadata, file_paths=file_paths, io=self._table.io
             )
             for data_file in data_files:
-                update_snapshot.append_data_file(data_file)
+                append_files.append_data_file(data_file)
 
     def update_spec(self) -> UpdateSpec:
         """Create a new UpdateSpec to update the partitioning of the table.
@@ -1431,7 +1435,11 @@ def delete(
             )
 
     def add_files(
-        self, file_paths: List[str], snapshot_properties: Dict[str, str] = EMPTY_DICT, check_duplicate_files: bool = True
+        self,
+        file_paths: List[str],
+        snapshot_properties: Dict[str, str] = EMPTY_DICT,
+        check_duplicate_files: bool = True,
+        branch: Optional[str] = MAIN_BRANCH,
     ) -> None:
         """
         Shorthand API for adding files as data files to the table.
@@ -1444,7 +1452,10 @@ def add_files(
         """
         with self.transaction() as tx:
             tx.add_files(
-                file_paths=file_paths, snapshot_properties=snapshot_properties, check_duplicate_files=check_duplicate_files
+                file_paths=file_paths,
+                snapshot_properties=snapshot_properties,
+                check_duplicate_files=check_duplicate_files,
+                branch=branch,
             )
 
     def update_spec(self, case_sensitive: bool = True) -> UpdateSpec:
diff --git a/tests/integration/test_add_files.py b/tests/integration/test_add_files.py
index 64c8028be7..84a30ab371 100644
--- a/tests/integration/test_add_files.py
+++ b/tests/integration/test_add_files.py
@@ -926,3 +926,36 @@ def test_add_files_hour_transform(session_catalog: Catalog) -> None:
             writer.write_table(arrow_table)
 
     tbl.add_files(file_paths=[file_path])
+
+
+@pytest.mark.integration
+def test_add_files_to_branch(spark: SparkSession, session_catalog: Catalog, format_version: int) -> None:
+    identifier = f"default.test_add_files_branch_v{format_version}"
+    branch = "branch1"
+
+    tbl = _create_table(session_catalog, identifier, format_version)
+
+    file_paths = [f"s3://warehouse/default/addfile/v{format_version}/test-{i}.parquet" for i in range(5)]
+    # write parquet files
+    for file_path in file_paths:
+        fo = tbl.io.new_output(file_path)
+        with fo.create(overwrite=True) as fos:
+            with pq.ParquetWriter(fos, schema=ARROW_SCHEMA) as writer:
+                writer.write_table(ARROW_TABLE)
+
+    # Dummy write to avoid failures on creating branch in empty table
+    tbl.append(ARROW_TABLE)
+    assert tbl.metadata.current_snapshot_id is not None
+    tbl.manage_snapshots().create_branch(snapshot_id=tbl.metadata.current_snapshot_id, branch_name=branch).commit()
+
+    # add the parquet files as data files
+    tbl.add_files(file_paths=file_paths, branch=branch)
+
+    df = spark.table(identifier)
+    assert df.count() == 1, "Expected 1 row in the main table"
+
+    branch_df = spark.table(f"{identifier}.branch_{branch}")
+    assert branch_df.count() == 6, "Expected 6 rows in the branch"
+
+    for col in branch_df.columns:
+        assert branch_df.filter(branch_df[col].isNotNull()).count() == 6, "Expected all 6 rows to be non-null"
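
For reference, a minimal usage sketch of the new `branch` argument outside the test harness. The catalog name, table identifier, branch name, and file path below are placeholders for illustration, not part of the change; only the `add_files(..., branch=...)` signature and the snapshot-management calls come from this diff and the existing pyiceberg API.

```python
from pyiceberg.catalog import load_catalog

# Hypothetical catalog and table names, for illustration only.
catalog = load_catalog("default")
tbl = catalog.load_table("default.my_table")

# Create the target branch from the current snapshot first
# (the table must already have at least one snapshot).
tbl.manage_snapshots().create_branch(
    snapshot_id=tbl.metadata.current_snapshot_id, branch_name="branch1"
).commit()

# Add existing Parquet files as data files to that branch.
# When branch is omitted, add_files still defaults to the main branch.
tbl.add_files(
    file_paths=["s3://warehouse/default/my_file.parquet"],
    branch="branch1",
)
```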