Skip to content

Commit f35c021

Browse files
committed
checkpoint: managed to let code run without error. need to handle column coalescing next
1 parent 2813897 commit f35c021

File tree

3 files changed

+51
-12
lines changed

3 files changed

+51
-12
lines changed

bigframes/core/blocks.py

Lines changed: 47 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -2341,8 +2341,6 @@ def merge(
23412341
joined_expr, (get_column_left, get_column_right) = self.expr.relational_join(
23422342
other.expr, type=how, conditions=conditions
23432343
)
2344-
result_columns = []
2345-
matching_join_labels = []
23462344

23472345
left_post_join_ids = tuple(get_column_left[id] for id in left_join_ids)
23482346
right_post_join_ids = tuple(get_column_right[id] for id in right_join_ids)
@@ -2351,12 +2349,18 @@ def merge(
23512349
joined_expr, left_post_join_ids, right_post_join_ids, how=how, drop=False
23522350
)
23532351

2352+
result_columns = []
2353+
matching_join_labels = []
2354+
# Select left value columns
23542355
for col_id in self.value_columns:
23552356
if col_id in left_join_ids:
23562357
key_part = left_join_ids.index(col_id)
23572358
matching_right_id = right_join_ids[key_part]
2358-
if self.col_id_to_label[col_id] == other.col_id_to_label.get(
2359-
matching_right_id, None
2359+
2360+
if (
2361+
matching_right_id in other.index_columns # Coalesce with right index
2362+
or self.col_id_to_label[col_id]
2363+
== other.col_id_to_label[matching_right_id]
23602364
):
23612365
matching_join_labels.append(self.col_id_to_label[col_id])
23622366
result_columns.append(coalesced_ids[key_part])
@@ -2382,7 +2386,13 @@ def merge(
23822386
],
23832387
)
23842388

2385-
joined_expr = joined_expr.select_columns(result_columns)
2389+
left_idx_id_post_join = [get_column_left[id] for id in self.index_columns]
2390+
right_idx_id_post_join = [get_column_right[id] for id in other.index_columns]
2391+
index_cols = _resolve_index_col(
2392+
left_idx_id_post_join, right_idx_id_post_join, left_index, right_index, how
2393+
)
2394+
2395+
joined_expr = joined_expr.select_columns(result_columns + index_cols)
23862396
labels = utils.merge_column_labels(
23872397
self.column_labels,
23882398
other.column_labels,
@@ -2402,10 +2412,8 @@ def merge(
24022412
or self.session._default_index_type == bigframes.enums.DefaultIndexKind.NULL
24032413
):
24042414
return Block(joined_expr, index_columns=[], column_labels=labels)
2405-
elif left_index:
2406-
return Block(joined_expr, index_columns=[left_post_join_ids], column_labels=labels)
2407-
elif right_index:
2408-
return Block(joined_expr, index_columns=[right_post_join_ids], column_labels=labels)
2415+
elif index_cols:
2416+
return Block(joined_expr, index_columns=index_cols, column_labels=labels)
24092417
else:
24102418
expr, offset_index_id = joined_expr.promote_offsets()
24112419
index_columns = [offset_index_id]
@@ -3471,3 +3479,33 @@ def _pd_index_to_array_value(
34713479
rows.append(row)
34723480

34733481
return core.ArrayValue.from_pyarrow(pa.Table.from_pylist(rows), session=session)
3482+
3483+
3484+
def _resolve_index_col(
3485+
left_index_cols: list[str],
3486+
right_index_cols: list[str],
3487+
left_index: bool,
3488+
right_index: bool,
3489+
how: typing.Literal[
3490+
"inner",
3491+
"left",
3492+
"outer",
3493+
"right",
3494+
"cross",
3495+
],
3496+
) -> list[str]:
3497+
if left_index and right_index:
3498+
if how == "inner" or how == "left":
3499+
return left_index_cols
3500+
if how == "right":
3501+
return right_index_cols
3502+
if how == "outer":
3503+
return []
3504+
else:
3505+
return []
3506+
elif left_index and not right_index:
3507+
return right_index_cols
3508+
elif right_index and not left_index:
3509+
return left_index_cols
3510+
else:
3511+
return []

bigframes/core/reshape/merge.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,7 @@ def merge(
7979
sort=sort,
8080
suffixes=suffixes,
8181
left_index=left_index,
82-
right_index=right_index
82+
right_index=right_index,
8383
)
8484
return dataframe.DataFrame(block)
8585

tests/system/small/core/test_reshape.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,7 @@ def test_join_with_index(
5757
bf_result, pd_result, check_dtype=False, check_index_type=False
5858
)
5959

60+
6061
@pytest.mark.parametrize(
6162
("left_on", "right_on", "left_index", "right_index"),
6263
[
@@ -68,8 +69,8 @@ def test_join_with_index(
6869
def test_join_with_multiindex(
6970
session: session.Session, left_on, right_on, left_index, right_index
7071
):
71-
multi_idx = pd.MultiIndex.from_tuples([(1,2), (2, 3), (3,4)])
72-
df1 = pd.DataFrame({"col_a": [1, 2, 3], "col_b": [2, 3, 4]}, index=multi_idx)
72+
multi_idx = pd.MultiIndex.from_tuples([(1, 2), (2, 3), (3, 4)])
73+
df1 = pd.DataFrame({"col_a": [1, 2, 3], "col_b": [2, 3, 4]}, index=multi_idx)
7374
bf1 = session.read_pandas(df1)
7475
df2 = pd.DataFrame({"col_c": [1, 2, 3], "col_d": [2, 3, 4]}, index=multi_idx)
7576
bf2 = session.read_pandas(df2)

0 commit comments

Comments
 (0)