@@ -2341,8 +2341,6 @@ def merge(
23412341 joined_expr , (get_column_left , get_column_right ) = self .expr .relational_join (
23422342 other .expr , type = how , conditions = conditions
23432343 )
2344- result_columns = []
2345- matching_join_labels = []
23462344
23472345 left_post_join_ids = tuple (get_column_left [id ] for id in left_join_ids )
23482346 right_post_join_ids = tuple (get_column_right [id ] for id in right_join_ids )
@@ -2351,6 +2349,9 @@ def merge(
23512349 joined_expr , left_post_join_ids , right_post_join_ids , how = how , drop = False
23522350 )
23532351
2352+ result_columns = []
2353+ matching_join_labels = []
2354+ # Select left value columns
23542355 for col_id in self .value_columns :
23552356 if col_id in left_join_ids :
23562357 key_part = left_join_ids .index (col_id )
@@ -2382,7 +2383,13 @@ def merge(
23822383 ],
23832384 )
23842385
2385- joined_expr = joined_expr .select_columns (result_columns )
2386+ left_idx_id_post_join = [get_column_left [id ] for id in self .index_columns ]
2387+ right_idx_id_post_join = [get_column_right [id ] for id in other .index_columns ]
2388+ index_cols = _resolve_index_col (
2389+ left_idx_id_post_join , right_idx_id_post_join , left_index , right_index , how
2390+ )
2391+
2392+ joined_expr = joined_expr .select_columns (result_columns + index_cols )
23862393 labels = utils .merge_column_labels (
23872394 self .column_labels ,
23882395 other .column_labels ,
@@ -2402,10 +2409,8 @@ def merge(
24022409 or self .session ._default_index_type == bigframes .enums .DefaultIndexKind .NULL
24032410 ):
24042411 return Block (joined_expr , index_columns = [], column_labels = labels )
2405- elif left_index :
2406- return Block (joined_expr , index_columns = [left_post_join_ids ], column_labels = labels )
2407- elif right_index :
2408- return Block (joined_expr , index_columns = [right_post_join_ids ], column_labels = labels )
2412+ elif index_cols :
2413+ return Block (joined_expr , index_columns = index_cols , column_labels = labels )
24092414 else :
24102415 expr , offset_index_id = joined_expr .promote_offsets ()
24112416 index_columns = [offset_index_id ]
@@ -3471,3 +3476,33 @@ def _pd_index_to_array_value(
34713476 rows .append (row )
34723477
34733478 return core .ArrayValue .from_pyarrow (pa .Table .from_pylist (rows ), session = session )
3479+
3480+
3481+ def _resolve_index_col (
3482+ left_index_cols : list [str ],
3483+ right_index_cols : list [str ],
3484+ left_index : bool ,
3485+ right_index : bool ,
3486+ how : typing .Literal [
3487+ "inner" ,
3488+ "left" ,
3489+ "outer" ,
3490+ "right" ,
3491+ "cross" ,
3492+ ],
3493+ ) -> list [str ]:
3494+ if left_index and right_index :
3495+ if how == "inner" or how == "left" :
3496+ return left_index_cols
3497+ if how == "right" :
3498+ return right_index_cols
3499+ if how == "outer" :
3500+ return []
3501+ else :
3502+ return []
3503+ elif left_index and not right_index :
3504+ return right_index_cols
3505+ elif right_index and not left_index :
3506+ return left_index_cols
3507+ else :
3508+ return []
0 commit comments