@@ -2341,8 +2341,6 @@ def merge(
23412341 joined_expr , (get_column_left , get_column_right ) = self .expr .relational_join (
23422342 other .expr , type = how , conditions = conditions
23432343 )
2344- result_columns = []
2345- matching_join_labels = []
23462344
23472345 left_post_join_ids = tuple (get_column_left [id ] for id in left_join_ids )
23482346 right_post_join_ids = tuple (get_column_right [id ] for id in right_join_ids )
@@ -2351,12 +2349,18 @@ def merge(
23512349 joined_expr , left_post_join_ids , right_post_join_ids , how = how , drop = False
23522350 )
23532351
2352+ result_columns = []
2353+ matching_join_labels = []
2354+ # Select left value columns
23542355 for col_id in self .value_columns :
23552356 if col_id in left_join_ids :
23562357 key_part = left_join_ids .index (col_id )
23572358 matching_right_id = right_join_ids [key_part ]
2358- if self .col_id_to_label [col_id ] == other .col_id_to_label .get (
2359- matching_right_id , None
2359+
2360+ if (
2361+ matching_right_id in other .index_columns # Coalesce with right index
2362+ or self .col_id_to_label [col_id ]
2363+ == other .col_id_to_label [matching_right_id ]
23602364 ):
23612365 matching_join_labels .append (self .col_id_to_label [col_id ])
23622366 result_columns .append (coalesced_ids [key_part ])
@@ -2382,7 +2386,13 @@ def merge(
23822386 ],
23832387 )
23842388
2385- joined_expr = joined_expr .select_columns (result_columns )
2389+ left_idx_id_post_join = [get_column_left [id ] for id in self .index_columns ]
2390+ right_idx_id_post_join = [get_column_right [id ] for id in other .index_columns ]
2391+ index_cols = _resolve_index_col (
2392+ left_idx_id_post_join , right_idx_id_post_join , left_index , right_index , how
2393+ )
2394+
2395+ joined_expr = joined_expr .select_columns (result_columns + index_cols )
23862396 labels = utils .merge_column_labels (
23872397 self .column_labels ,
23882398 other .column_labels ,
@@ -2402,10 +2412,8 @@ def merge(
24022412 or self .session ._default_index_type == bigframes .enums .DefaultIndexKind .NULL
24032413 ):
24042414 return Block (joined_expr , index_columns = [], column_labels = labels )
2405- elif left_index :
2406- return Block (joined_expr , index_columns = [left_post_join_ids ], column_labels = labels )
2407- elif right_index :
2408- return Block (joined_expr , index_columns = [right_post_join_ids ], column_labels = labels )
2415+ elif index_cols :
2416+ return Block (joined_expr , index_columns = index_cols , column_labels = labels )
24092417 else :
24102418 expr , offset_index_id = joined_expr .promote_offsets ()
24112419 index_columns = [offset_index_id ]
@@ -3471,3 +3479,33 @@ def _pd_index_to_array_value(
34713479 rows .append (row )
34723480
34733481 return core .ArrayValue .from_pyarrow (pa .Table .from_pylist (rows ), session = session )
3482+
3483+
3484+ def _resolve_index_col (
3485+ left_index_cols : list [str ],
3486+ right_index_cols : list [str ],
3487+ left_index : bool ,
3488+ right_index : bool ,
3489+ how : typing .Literal [
3490+ "inner" ,
3491+ "left" ,
3492+ "outer" ,
3493+ "right" ,
3494+ "cross" ,
3495+ ],
3496+ ) -> list [str ]:
3497+ if left_index and right_index :
3498+ if how == "inner" or how == "left" :
3499+ return left_index_cols
3500+ if how == "right" :
3501+ return right_index_cols
3502+ if how == "outer" :
3503+ return []
3504+ else :
3505+ return []
3506+ elif left_index and not right_index :
3507+ return right_index_cols
3508+ elif right_index and not left_index :
3509+ return left_index_cols
3510+ else :
3511+ return []
0 commit comments