Merged
21 changes: 16 additions & 5 deletions pyiceberg/table/__init__.py
@@ -859,7 +859,11 @@ def upsert(
return UpsertResult(rows_updated=update_row_cnt, rows_inserted=insert_row_cnt)

def add_files(
-self, file_paths: List[str], snapshot_properties: Dict[str, str] = EMPTY_DICT, check_duplicate_files: bool = True
+self,
+file_paths: List[str],
+snapshot_properties: Dict[str, str] = EMPTY_DICT,
+check_duplicate_files: bool = True,
+branch: Optional[str] = MAIN_BRANCH,
) -> None:
"""
Shorthand API for adding files as data files to the table transaction.
@@ -888,12 +892,12 @@ def add_files(
self.set_properties(
**{TableProperties.DEFAULT_NAME_MAPPING: self.table_metadata.schema().name_mapping.model_dump_json()}
)
-with self.update_snapshot(snapshot_properties=snapshot_properties).fast_append() as update_snapshot:
+with self._append_snapshot_producer(snapshot_properties, branch=branch) as append_files:
data_files = _parquet_files_to_data_files(
table_metadata=self.table_metadata, file_paths=file_paths, io=self._table.io
)
for data_file in data_files:
-update_snapshot.append_data_file(data_file)
+append_files.append_data_file(data_file)

def update_spec(self) -> UpdateSpec:
"""Create a new UpdateSpec to update the partitioning of the table.
@@ -1431,7 +1435,11 @@ def delete(
)

def add_files(
-self, file_paths: List[str], snapshot_properties: Dict[str, str] = EMPTY_DICT, check_duplicate_files: bool = True
+self,
+file_paths: List[str],
+snapshot_properties: Dict[str, str] = EMPTY_DICT,
+check_duplicate_files: bool = True,
+branch: Optional[str] = MAIN_BRANCH,
) -> None:
"""
Shorthand API for adding files as data files to the table.
@@ -1444,7 +1452,10 @@
"""
with self.transaction() as tx:
tx.add_files(
-file_paths=file_paths, snapshot_properties=snapshot_properties, check_duplicate_files=check_duplicate_files
+file_paths=file_paths,
+snapshot_properties=snapshot_properties,
+check_duplicate_files=check_duplicate_files,
+branch=branch,
)

def update_spec(self, case_sensitive: bool = True) -> UpdateSpec:
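
Based on the signature change above, callers can now point add_files at a named branch instead of main. A minimal usage sketch, assuming a configured catalog named "default", an existing table, an already-created branch called "audit", and a Parquet file whose schema matches the table; all of these names are placeholders, not part of the PR:

from pyiceberg.catalog import load_catalog

catalog = load_catalog("default")
tbl = catalog.load_table("default.my_table")

# Register an already-written Parquet file on the "audit" branch; main is left untouched.
tbl.add_files(
    file_paths=["s3://warehouse/default/data/file-0.parquet"],
    branch="audit",
)

Omitting branch keeps the previous behavior, since the parameter defaults to MAIN_BRANCH.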
33 changes: 33 additions & 0 deletions tests/integration/test_add_files.py
Contributor

It might be nice to add a negative test that attempts to add files to a non-existent branch, just to make sure that exceptions are handled gracefully and that meaningful errors are surfaced.

Contributor (Author)

It should be the same as the test for appending to a non-existing branch:

def test_append_to_non_existing_branch(session_catalog: Catalog, arrow_table_with_null: pa.Table) -> None:

Seen from the point of view of a snapshot, the flow for add_files is the same as for the append operation: in both cases, only a new snapshot with new files is appended.
Since we are not adding any different code, it would just introduce another test that goes through the same flow, increasing test time.

My suggestion would be to not bloat the test time with similar tests.
Thoughts?

Contributor

That makes sense to me!
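
For reference, a rough sketch of the kind of negative test discussed above (it was not added in this PR, per the thread's conclusion). It assumes the imports, fixtures, and _create_table helper of tests/integration/test_add_files.py, and it only asserts that some exception is raised, since the exact error type is not pinned down in this thread:

@pytest.mark.integration
def test_add_files_to_non_existing_branch(session_catalog: Catalog, format_version: int) -> None:
    # Hypothetical sketch only. In a real test the Parquet file would be written
    # first (as in test_add_files_to_branch below); the branch "does_not_exist"
    # is never created, so the commit is expected to fail.
    identifier = f"default.test_add_files_non_existing_branch_v{format_version}"
    tbl = _create_table(session_catalog, identifier, format_version)
    file_path = f"s3://warehouse/default/addfile/v{format_version}/test-0.parquet"
    with pytest.raises(Exception):
        tbl.add_files(file_paths=[file_path], branch="does_not_exist")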

@@ -926,3 +926,36 @@ def test_add_files_hour_transform(session_catalog: Catalog) -> None:
writer.write_table(arrow_table)

tbl.add_files(file_paths=[file_path])


@pytest.mark.integration
def test_add_files_to_branch(spark: SparkSession, session_catalog: Catalog, format_version: int) -> None:
identifier = f"default.test_add_files_branch_v{format_version}"
branch = "branch1"

tbl = _create_table(session_catalog, identifier, format_version)

file_paths = [f"s3://warehouse/default/addfile/v{format_version}/test-{i}.parquet" for i in range(5)]
# write parquet files
for file_path in file_paths:
fo = tbl.io.new_output(file_path)
with fo.create(overwrite=True) as fos:
with pq.ParquetWriter(fos, schema=ARROW_SCHEMA) as writer:
writer.write_table(ARROW_TABLE)

# Dummy write to avoid failures when creating a branch on an empty table
tbl.append(ARROW_TABLE)
assert tbl.metadata.current_snapshot_id is not None
tbl.manage_snapshots().create_branch(snapshot_id=tbl.metadata.current_snapshot_id, branch_name=branch).commit()

# add the parquet files as data files
tbl.add_files(file_paths=file_paths, branch=branch)

df = spark.table(identifier)
assert df.count() == 1, "Expected 1 row in Main table"

branch_df = spark.table(f"{identifier}.branch_{branch}")
assert branch_df.count() == 6, "Expected 6 rows in branch"

for col in branch_df.columns:
assert branch_df.filter(branch_df[col].isNotNull()).count() == 6, "Expected all 6 rows to be non-null"