Skip to content

Commit 6ea3f63

Browse files
committed
Checkpoint: the code now runs without errors. Column coalescing still needs to be handled next.
1 parent 2813897 commit 6ea3f63

File tree

3 files changed

+46
-10
lines changed

3 files changed

+46
-10
lines changed

bigframes/core/blocks.py

Lines changed: 42 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -2341,8 +2341,6 @@ def merge(
23412341
joined_expr, (get_column_left, get_column_right) = self.expr.relational_join(
23422342
other.expr, type=how, conditions=conditions
23432343
)
2344-
result_columns = []
2345-
matching_join_labels = []
23462344

23472345
left_post_join_ids = tuple(get_column_left[id] for id in left_join_ids)
23482346
right_post_join_ids = tuple(get_column_right[id] for id in right_join_ids)
@@ -2351,6 +2349,9 @@ def merge(
23512349
joined_expr, left_post_join_ids, right_post_join_ids, how=how, drop=False
23522350
)
23532351

2352+
result_columns = []
2353+
matching_join_labels = []
2354+
# Select left value columns
23542355
for col_id in self.value_columns:
23552356
if col_id in left_join_ids:
23562357
key_part = left_join_ids.index(col_id)
@@ -2382,7 +2383,13 @@ def merge(
23822383
],
23832384
)
23842385

2385-
joined_expr = joined_expr.select_columns(result_columns)
2386+
left_idx_id_post_join = [get_column_left[id] for id in self.index_columns]
2387+
right_idx_id_post_join = [get_column_right[id] for id in other.index_columns]
2388+
index_cols = _resolve_index_col(
2389+
left_idx_id_post_join, right_idx_id_post_join, left_index, right_index, how
2390+
)
2391+
2392+
joined_expr = joined_expr.select_columns(result_columns + index_cols)
23862393
labels = utils.merge_column_labels(
23872394
self.column_labels,
23882395
other.column_labels,
@@ -2402,10 +2409,8 @@ def merge(
24022409
or self.session._default_index_type == bigframes.enums.DefaultIndexKind.NULL
24032410
):
24042411
return Block(joined_expr, index_columns=[], column_labels=labels)
2405-
elif left_index:
2406-
return Block(joined_expr, index_columns=[left_post_join_ids], column_labels=labels)
2407-
elif right_index:
2408-
return Block(joined_expr, index_columns=[right_post_join_ids], column_labels=labels)
2412+
elif index_cols:
2413+
return Block(joined_expr, index_columns=index_cols, column_labels=labels)
24092414
else:
24102415
expr, offset_index_id = joined_expr.promote_offsets()
24112416
index_columns = [offset_index_id]
@@ -3471,3 +3476,33 @@ def _pd_index_to_array_value(
34713476
rows.append(row)
34723477

34733478
return core.ArrayValue.from_pyarrow(pa.Table.from_pylist(rows), session=session)
3479+
3480+
3481+
def _resolve_index_col(
3482+
left_index_cols: list[str],
3483+
right_index_cols: list[str],
3484+
left_index: bool,
3485+
right_index: bool,
3486+
how: typing.Literal[
3487+
"inner",
3488+
"left",
3489+
"outer",
3490+
"right",
3491+
"cross",
3492+
],
3493+
) -> list[str]:
3494+
if left_index and right_index:
3495+
if how == "inner" or how == "left":
3496+
return left_index_cols
3497+
if how == "right":
3498+
return right_index_cols
3499+
if how == "outer":
3500+
return []
3501+
else:
3502+
return []
3503+
elif left_index and not right_index:
3504+
return right_index_cols
3505+
elif right_index and not left_index:
3506+
return left_index_cols
3507+
else:
3508+
return []

bigframes/core/reshape/merge.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,7 @@ def merge(
7979
sort=sort,
8080
suffixes=suffixes,
8181
left_index=left_index,
82-
right_index=right_index
82+
right_index=right_index,
8383
)
8484
return dataframe.DataFrame(block)
8585

tests/system/small/core/test_reshape.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,7 @@ def test_join_with_index(
5757
bf_result, pd_result, check_dtype=False, check_index_type=False
5858
)
5959

60+
6061
@pytest.mark.parametrize(
6162
("left_on", "right_on", "left_index", "right_index"),
6263
[
@@ -68,8 +69,8 @@ def test_join_with_index(
6869
def test_join_with_multiindex(
6970
session: session.Session, left_on, right_on, left_index, right_index
7071
):
71-
multi_idx = pd.MultiIndex.from_tuples([(1,2), (2, 3), (3,4)])
72-
df1 = pd.DataFrame({"col_a": [1, 2, 3], "col_b": [2, 3, 4]}, index=multi_idx)
72+
multi_idx = pd.MultiIndex.from_tuples([(1, 2), (2, 3), (3, 4)])
73+
df1 = pd.DataFrame({"col_a": [1, 2, 3], "col_b": [2, 3, 4]}, index=multi_idx)
7374
bf1 = session.read_pandas(df1)
7475
df2 = pd.DataFrame({"col_c": [1, 2, 3], "col_d": [2, 3, 4]}, index=multi_idx)
7576
bf2 = session.read_pandas(df2)

0 commit comments

Comments
 (0)