diff --git a/pyproject.toml b/pyproject.toml
index bbc7c6a3..8fc82038 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -27,6 +27,7 @@ dependencies = [
     "pyarrow>=14.0.1",
     "scipy>=1.7.2", # kdtree
     "universal-pathlib>=0.2.2",
+    "tabulate>=0.7.0",
 ]
 
 [project.urls]
@@ -44,6 +45,7 @@ dev = [
     "pytest",
     "pytest-cov", # Used to report total code coverage
     "pytest-mock", # Used to mock objects in tests
+    "types-tabulate", # Type information for tabulate
 ]
 full = [
     "fsspec[full]", # complete file system specs.
diff --git a/src/lsdb/catalog/catalog.py b/src/lsdb/catalog/catalog.py
index 5a3cc002..9fab054b 100644
--- a/src/lsdb/catalog/catalog.py
+++ b/src/lsdb/catalog/catalog.py
@@ -182,6 +182,7 @@ def crossmatch(
         ) = BuiltInCrossmatchAlgorithm.KD_TREE,
         output_catalog_name: str | None = None,
         require_right_margin: bool = False,
+        suffix_method: str | None = None,
         **kwargs,
     ) -> Catalog:
         """Perform a cross-match between two catalogs
@@ -244,6 +245,11 @@ def crossmatch(
                 Default: {left_name}_x_{right_name}
             require_right_margin (bool): If true, raises an error if the right margin is missing which could
                 lead to incomplete crossmatches. Default: False
+            suffix_method (str): Method to use to add suffixes to columns. Options are:
+                - "overlapping_columns": only add suffixes to columns that are present in both catalogs
+                - "all_columns": add suffixes to all columns from both catalogs
+                Default: "all_columns". Warning: this default will change to "overlapping_columns"
+                in a future release.
 
         Returns:
             A Catalog with the data from the left and right catalogs merged with one row for each
@@ -268,10 +274,14 @@
         if output_catalog_name is None:
             output_catalog_name = f"{self.name}_x_{other.name}"
         ddf, ddf_map, alignment = crossmatch_catalog_data(
-            self, other, suffixes, algorithm=algorithm, **kwargs
+            self, other, suffixes, algorithm=algorithm, suffix_method=suffix_method, **kwargs
         )
         new_catalog_info = create_merged_catalog_info(
-            self.hc_structure.catalog_info, other.hc_structure.catalog_info, output_catalog_name, suffixes
+            self,
+            other,
+            output_catalog_name,
+            suffixes,
+            suffix_method,
         )
         hc_catalog = self.hc_structure.__class__(
             new_catalog_info, alignment.pixel_tree, schema=get_arrow_schema(ddf), moc=alignment.moc
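Note: a minimal usage sketch of the new parameter, assuming two hypothetical catalogs that share column names; the catalog paths are placeholders, not real data.

```python
import lsdb

# Load two catalogs that share column names such as "ra" and "dec"
# (paths are hypothetical).
left = lsdb.read_hats("path/to/object_catalog")
right = lsdb.read_hats("path/to/xmatch_catalog")

# Only the overlapping columns receive suffixes; columns unique to one
# side keep their original names. Omitting suffix_method currently
# behaves like "all_columns" and emits a FutureWarning.
xmatched = left.crossmatch(
    right,
    suffixes=("_left", "_right"),
    suffix_method="overlapping_columns",
)
```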
@@ -741,6 +751,7 @@ def merge_asof(
         direction: str = "backward",
         suffixes: tuple[str, str] | None = None,
         output_catalog_name: str | None = None,
+        suffix_method: str | None = None,
     ):
         """Uses the pandas `merge_asof` function to merge two catalogs on their indices by distance of keys
@@ -754,6 +765,12 @@ def merge_asof(
             other (lsdb.Catalog): the right catalog to merge to
             suffixes (Tuple[str,str]): the suffixes to apply to each partition's column names
             direction (str): the direction to perform the merge_asof
+            output_catalog_name (str): The name of the resulting catalog to be stored in metadata
+            suffix_method (str): Method to use to add suffixes to columns. Options are:
+                - "overlapping_columns": only add suffixes to columns that are present in both catalogs
+                - "all_columns": add suffixes to all columns from both catalogs
+                Default: "all_columns". Warning: this default will change to "overlapping_columns"
+                in a future release.
 
         Returns:
             A new catalog with the columns from each of the input catalogs with their respective suffixes
@@ -765,7 +782,9 @@ def merge_asof(
         if len(suffixes) != 2:
             raise ValueError("`suffixes` must be a tuple with two strings")
 
-        ddf, ddf_map, alignment = merge_asof_catalog_data(self, other, suffixes=suffixes, direction=direction)
+        ddf, ddf_map, alignment = merge_asof_catalog_data(
+            self, other, suffixes=suffixes, direction=direction, suffix_method=suffix_method
+        )
 
         if output_catalog_name is None:
             output_catalog_name = (
@@ -774,7 +793,11 @@ def merge_asof(
         )
 
         new_catalog_info = create_merged_catalog_info(
-            self.hc_structure.catalog_info, other.hc_structure.catalog_info, output_catalog_name, suffixes
+            self,
+            other,
+            output_catalog_name,
+            suffixes,
+            suffix_method,
         )
         hc_catalog = hc.catalog.Catalog(
             new_catalog_info, alignment.pixel_tree, schema=get_arrow_schema(ddf), moc=alignment.moc
@@ -789,6 +812,7 @@ def join(
         through: AssociationCatalog | None = None,
         suffixes: tuple[str, str] | None = None,
         output_catalog_name: str | None = None,
+        suffix_method: str | None = None,
     ) -> Catalog:
         """Perform a spatial join to another catalog
@@ -804,6 +828,11 @@ def join(
                 between pixels and individual rows.
             suffixes (Tuple[str,str]): suffixes to apply to the columns of each table
             output_catalog_name (str): The name of the resulting catalog to be stored in metadata
+            suffix_method (str): Method to use to add suffixes to columns. Options are:
+                - "overlapping_columns": only add suffixes to columns that are present in both catalogs
+                - "all_columns": add suffixes to all columns from both catalogs
+                Default: "all_columns". Warning: this default will change to "overlapping_columns"
+                in a future release.
 
         Returns:
             A new catalog with the columns from each of the input catalogs with their respective suffixes
@@ -818,7 +847,9 @@ def join(
         self._check_unloaded_columns([left_on, right_on])
 
         if through is not None:
-            ddf, ddf_map, alignment = join_catalog_data_through(self, other, through, suffixes)
+            ddf, ddf_map, alignment = join_catalog_data_through(
+                self, other, through, suffixes, suffix_method=suffix_method
+            )
         else:
             if left_on is None or right_on is None:
                 raise ValueError("Either both of left_on and right_on, or through must be set")
@@ -826,13 +857,19 @@ def join(
                 raise ValueError("left_on must be a column in the left catalog")
             if right_on not in other._ddf.columns:
                 raise ValueError("right_on must be a column in the right catalog")
-            ddf, ddf_map, alignment = join_catalog_data_on(self, other, left_on, right_on, suffixes)
+            ddf, ddf_map, alignment = join_catalog_data_on(
+                self, other, left_on, right_on, suffixes, suffix_method=suffix_method
+            )
 
         if output_catalog_name is None:
             output_catalog_name = self.hc_structure.catalog_info.catalog_name
 
         new_catalog_info = create_merged_catalog_info(
-            self.hc_structure.catalog_info, other.hc_structure.catalog_info, output_catalog_name, suffixes
+            self,
+            other,
+            output_catalog_name,
+            suffixes,
+            suffix_method,
        )
         hc_catalog = hc.catalog.Catalog(
             new_catalog_info, alignment.pixel_tree, schema=get_arrow_schema(ddf), moc=alignment.moc
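Note: continuing the sketch above, the same parameter applies to `join` and `merge_asof`; here only the shared key "id" is assumed to exist on both sides, so it is the only column renamed under "overlapping_columns".

```python
# Both catalogs have an "id" column; suffixes are applied to it only.
joined = left.join(
    right,
    left_on="id",
    right_on="id",
    suffixes=("_a", "_b"),
    suffix_method="overlapping_columns",
)

# The join keys are now "id_a" and "id_b"; columns unique to either
# side keep their original names.
print(joined.columns)
```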
diff --git a/src/lsdb/core/crossmatch/abstract_crossmatch_algorithm.py b/src/lsdb/core/crossmatch/abstract_crossmatch_algorithm.py
index 182ca68e..271f5b60 100644
--- a/src/lsdb/core/crossmatch/abstract_crossmatch_algorithm.py
+++ b/src/lsdb/core/crossmatch/abstract_crossmatch_algorithm.py
@@ -10,6 +10,8 @@
 from hats.catalog import TableProperties
 from hats.pixel_math.spatial_index import SPATIAL_INDEX_COLUMN
 
+from lsdb.dask.merge_catalog_functions import apply_suffixes
+
 if TYPE_CHECKING:
     from lsdb.catalog import Catalog
@@ -96,14 +98,14 @@ def __init__(
         self.right_catalog_info = right_catalog_info
         self.right_margin_catalog_info = right_margin_catalog_info
 
-    def crossmatch(self, suffixes, **kwargs) -> npd.NestedFrame:
+    def crossmatch(self, suffixes, suffix_method="all_columns", **kwargs) -> npd.NestedFrame:
         """Perform a crossmatch"""
         l_inds, r_inds, extra_cols = self.perform_crossmatch(**kwargs)
         if not len(l_inds) == len(r_inds) == len(extra_cols):
             raise ValueError(
                 "Crossmatch algorithm must return left and right indices and extra columns with same length"
             )
-        return self._create_crossmatch_df(l_inds, r_inds, extra_cols, suffixes)
+        return self._create_crossmatch_df(l_inds, r_inds, extra_cols, suffixes, suffix_method)
 
     def crossmatch_nested(self, nested_column_name, **kwargs) -> npd.NestedFrame:
         """Perform a crossmatch"""
@@ -196,6 +198,7 @@ def _create_crossmatch_df(
         right_idx: npt.NDArray[np.int64],
         extra_cols: pd.DataFrame,
         suffixes: tuple[str, str],
+        suffix_method="all_columns",
     ) -> npd.NestedFrame:
         """Creates a df containing the crossmatch result from matching indices and additional columns
 
@@ -209,8 +212,9 @@ def _create_crossmatch_df(
             additional columns added
         """
         # rename columns so no same names during merging
-        self._rename_columns_with_suffix(self.left, suffixes[0])
-        self._rename_columns_with_suffix(self.right, suffixes[1])
+        self.left, self.right = apply_suffixes(
+            self.left, self.right, suffixes, suffix_method, log_changes=False
+        )
         # concat dataframes together
         index_name = self.left.index.name if self.left.index.name is not None else "index"
         left_join_part = self.left.iloc[left_idx].reset_index()
diff --git a/src/lsdb/dask/crossmatch_catalog_data.py b/src/lsdb/dask/crossmatch_catalog_data.py
index f9504090..3cd81299 100644
--- a/src/lsdb/dask/crossmatch_catalog_data.py
+++ b/src/lsdb/dask/crossmatch_catalog_data.py
@@ -40,6 +40,7 @@ def perform_crossmatch(
     right_margin_catalog_info,
     algorithm,
     suffixes,
+    suffix_method,
     meta_df,
     **kwargs,
 ):
@@ -66,7 +67,7 @@ def perform_crossmatch(
         left_catalog_info,
         right_catalog_info,
         right_margin_catalog_info,
-    ).crossmatch(suffixes, **kwargs)
+    ).crossmatch(suffixes, suffix_method=suffix_method, **kwargs)
 
 
 # pylint: disable=too-many-arguments, unused-argument
@@ -119,6 +120,7 @@ def crossmatch_catalog_data(
     algorithm: (
         Type[AbstractCrossmatchAlgorithm] | BuiltInCrossmatchAlgorithm
     ) = BuiltInCrossmatchAlgorithm.KD_TREE,
+    suffix_method: str | None = None,
     **kwargs,
 ) -> tuple[nd.NestedFrame, DaskDFPixelMap, PixelAlignment]:
     """Cross-matches the data from two catalogs
@@ -157,7 +159,10 @@ def crossmatch_catalog_data(
 
     # generate meta table structure for dask df
     meta_df = generate_meta_df_for_joined_tables(
-        [left, right], suffixes, extra_columns=crossmatch_algorithm.extra_columns
+        (left, right),
+        suffixes,
+        suffix_method=suffix_method,
+        extra_columns=crossmatch_algorithm.extra_columns,
     )
 
     # perform the crossmatch on each partition pairing using dask delayed for lazy computation
@@ -166,6 +171,7 @@ def crossmatch_catalog_data(
         perform_crossmatch,
         crossmatch_algorithm,
         suffixes,
+        suffix_method,
         meta_df,
         **kwargs,
     )
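Note: a toy illustration of the two renaming strategies, mirroring what `apply_suffixes` does; plain pandas frames are used here only to keep the sketch self-contained (the diff operates on nested_pandas NestedFrames).

```python
import pandas as pd

left = pd.DataFrame({"id": [1], "ra": [10.0], "mag_g": [21.2]})
right = pd.DataFrame({"id": [2], "ra": [10.1], "period": [0.5]})

# "all_columns": every column from both frames is suffixed.
all_left = left.add_suffix("_a")      # id_a, ra_a, mag_g_a
all_right = right.add_suffix("_b")    # id_b, ra_b, period_b

# "overlapping_columns": only names present on both sides are suffixed.
overlap = set(left.columns) & set(right.columns)  # {"id", "ra"}
ovl_left = left.rename(columns={c: c + "_a" for c in overlap})
ovl_right = right.rename(columns={c: c + "_b" for c in overlap})
# -> id_a, ra_a, mag_g   and   id_b, ra_b, period
```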
diff --git a/src/lsdb/dask/join_catalog_data.py b/src/lsdb/dask/join_catalog_data.py
index 1d5fc79d..7feb36f8 100644
--- a/src/lsdb/dask/join_catalog_data.py
+++ b/src/lsdb/dask/join_catalog_data.py
@@ -18,6 +18,9 @@
     align_and_apply,
     align_catalogs,
     align_catalogs_with_association,
+    apply_left_suffix,
+    apply_right_suffix,
+    apply_suffixes,
     concat_partition_and_margin,
     construct_catalog_args,
     filter_by_spatial_index_to_pixel,
@@ -35,24 +38,6 @@
 NON_JOINING_ASSOCIATION_COLUMNS = ["Norder", "Dir", "Npix", "join_Norder", "join_Dir", "join_Npix"]
 
 
-def rename_columns_with_suffixes(left: npd.NestedFrame, right: npd.NestedFrame, suffixes: tuple[str, str]):
-    """Renames two dataframes with the suffixes specified
-
-    Args:
-        left (npd.NestedFrame): the left dataframe to apply the first suffix to
-        right (npd.NestedFrame): the right dataframe to apply the second suffix to
-        suffixes (Tuple[str, str]): the pair of suffixes to apply to the dataframes
-
-    Returns:
-        A tuple of (left, right) updated dataframes with their columns renamed
-    """
-    left_columns_renamed = {name: name + suffixes[0] for name in left.columns}
-    left = left.rename(columns=left_columns_renamed)
-    right_columns_renamed = {name: name + suffixes[1] for name in right.columns}
-    right = right.rename(columns=right_columns_renamed)
-    return left, right
-
-
 # pylint: disable=too-many-arguments, unused-argument
 def perform_join_on(
     left: npd.NestedFrame,
@@ -67,6 +52,7 @@ def perform_join_on(
     left_on: str,
     right_on: str,
     suffixes: tuple[str, str],
+    suffix_method: str | None = None,
 ):
     """Performs a join on two catalog partitions
 
@@ -83,6 +69,10 @@ def perform_join_on(
         left_on (str): the column to join on from the left partition
         right_on (str): the column to join on from the right partition
         suffixes (Tuple[str,str]): the suffixes to apply to each partition's column names
+        suffix_method (str): Method to use to add suffixes to columns. Options are:
+            - "overlapping_columns": only add suffixes to columns that are present in both catalogs
+            - "all_columns": add suffixes to all columns from both catalogs
+            Default: "all_columns"
 
     Returns:
         A dataframe with the result of merging the left and right partitions on the specified columns
@@ -92,10 +82,11 @@ def perform_join_on(
 
     right_joined_df = concat_partition_and_margin(right, right_margin)
 
-    left, right_joined_df = rename_columns_with_suffixes(left, right_joined_df, suffixes)
-    merged = left.reset_index().merge(
-        right_joined_df, left_on=left_on + suffixes[0], right_on=right_on + suffixes[1]
-    )
+    left_join_column = apply_left_suffix(left_on, right_joined_df.columns, suffixes, suffix_method)
+    right_join_column = apply_right_suffix(right_on, left.columns, suffixes, suffix_method)
+    left, right_joined_df = apply_suffixes(left, right_joined_df, suffixes, suffix_method, log_changes=False)
+
+    merged = left.reset_index().merge(right_joined_df, left_on=left_join_column, right_on=right_join_column)
     merged.set_index(SPATIAL_INDEX_COLUMN, inplace=True)
 
     return merged
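Note: a sketch of why the join keys are resolved before the rename, using the `apply_left_suffix`/`apply_right_suffix` helpers defined later in this diff; the column names are hypothetical.

```python
suffixes = ("_a", "_b")

# With "overlapping_columns", a key is only suffixed when the other
# side also has a column of that name.
left_key = apply_left_suffix("id", ["id", "flux"], suffixes, "overlapping_columns")
# -> "id_a", because "id" also exists in the right columns

right_key = apply_right_suffix("object_id", ["id", "ra"], suffixes, "overlapping_columns")
# -> "object_id", because the left frame has no "object_id" column

# With "all_columns", keys are always suffixed:
assert apply_left_suffix("id", ["id", "flux"], suffixes, "all_columns") == "id_a"
```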
@@ -147,7 +138,7 @@ def perform_join_nested(
     return merged
 
 
-# pylint: disable=too-many-arguments, unused-argument
+# pylint: disable=too-many-arguments, unused-argument, too-many-locals
 def perform_join_through(
     left: npd.NestedFrame,
     right: npd.NestedFrame,
@@ -162,6 +153,7 @@ def perform_join_through(
     right_margin_catalog_info: TableProperties,
     assoc_catalog_info: TableProperties,
     suffixes: tuple[str, str],
+    suffix_method: str | None = None,
 ):
     """Performs a join on two catalog partitions through an association catalog
 
@@ -180,6 +172,10 @@ def perform_join_through(
             catalog
         assoc_catalog_info (hc.TableProperties): the hats structure of the association catalog
         suffixes (Tuple[str,str]): the suffixes to apply to each partition's column names
+        suffix_method (str): Method to use to add suffixes to columns. Options are:
+            - "overlapping_columns": only add suffixes to columns that are present in both catalogs
+            - "all_columns": add suffixes to all columns from both catalogs
+            Default: "all_columns"
 
     Returns:
         A dataframe with the result of merging the left and right partitions on the specified columns
@@ -191,7 +187,13 @@ def perform_join_through(
 
     right_joined_df = concat_partition_and_margin(right, right_margin)
 
-    left, right_joined_df = rename_columns_with_suffixes(left, right_joined_df, suffixes)
+    left_join_column = apply_left_suffix(
+        assoc_catalog_info.primary_column, right_joined_df.columns, suffixes, suffix_method
+    )
+    right_join_column = apply_right_suffix(
+        assoc_catalog_info.join_column, left.columns, suffixes, suffix_method
+    )
+    left, right_joined_df = apply_suffixes(left, right_joined_df, suffixes, suffix_method, log_changes=False)
 
     # Edge case: if right_column + suffix == join_column_association, columns will be in the wrong order
     # so rename association column
@@ -202,8 +204,9 @@ def perform_join_through(
         columns={assoc_catalog_info.join_column_association: join_column_association}, inplace=True
     )
 
+    join_columns = [assoc_catalog_info.primary_column_association, join_column_association]
     join_columns_to_drop = []
-    for c in [assoc_catalog_info.primary_column_association, join_column_association]:
+    for c in join_columns:
         if c not in left.columns and c not in right_joined_df.columns and c not in join_columns_to_drop:
             join_columns_to_drop.append(c)
 
@@ -215,16 +218,21 @@ def perform_join_through(
         left.reset_index()
         .merge(
             through,
-            left_on=assoc_catalog_info.primary_column + suffixes[0],
+            left_on=left_join_column,
             right_on=assoc_catalog_info.primary_column_association,
         )
         .merge(
             right_joined_df,
             left_on=join_column_association,
-            right_on=assoc_catalog_info.join_column + suffixes[1],
+            right_on=right_join_column,
         )
     )
 
+    extra_join_cols = through.columns.drop(join_columns + cols_to_drop)
+    other_cols = merged.columns.drop(extra_join_cols)
+
+    merged = merged[other_cols.append(extra_join_cols)]
+
     merged.set_index(SPATIAL_INDEX_COLUMN, inplace=True)
     if len(join_columns_to_drop) > 0:
         merged.drop(join_columns_to_drop, axis=1, inplace=True)
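Note: a small sketch of the pandas Index manipulation used above to move the association table's extra payload columns to the end of the merged frame; the column names are hypothetical.

```python
import pandas as pd

merged = pd.DataFrame(columns=["id_a", "sep_arcsec", "ra_b", "dec_b"])
extra_join_cols = pd.Index(["sep_arcsec"])  # association payload columns

# Index.drop removes the payload columns; Index.append re-adds them at
# the end, so the output keeps left columns, then right, then payload.
other_cols = merged.columns.drop(extra_join_cols)
merged = merged[other_cols.append(extra_join_cols)]
print(list(merged.columns))  # ['id_a', 'ra_b', 'dec_b', 'sep_arcsec']
```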
@@ -241,6 +249,7 @@ def perform_merge_asof(
     right_catalog_info: TableProperties,
     suffixes: tuple[str, str],
     direction: str,
+    suffix_method: str | None = None,
 ):
     """Performs a merge_asof on two catalog partitions
 
@@ -253,6 +262,10 @@ def perform_merge_asof(
         right_catalog_info (hc.TableProperties): the catalog info of the right catalog
         suffixes (Tuple[str,str]): the suffixes to apply to each partition's column names
         direction (str): The direction to perform the merge_asof
+        suffix_method (str): Method to use to add suffixes to columns. Options are:
+            - "overlapping_columns": only add suffixes to columns that are present in both catalogs
+            - "all_columns": add suffixes to all columns from both catalogs
+            Default: "all_columns"
 
     Returns:
         A dataframe with the result of merging the left and right partitions on the specified columns with
@@ -261,7 +274,7 @@ def perform_merge_asof(
     if right_pixel.order > left_pixel.order:
         left = filter_by_spatial_index_to_pixel(left, right_pixel.order, right_pixel.pixel)
 
-    left, right = rename_columns_with_suffixes(left, right, suffixes)
+    left, right = apply_suffixes(left, right, suffixes, suffix_method, log_changes=False)
     left.sort_index(inplace=True)
     right.sort_index(inplace=True)
     merged = pd.merge_asof(left, right, left_index=True, right_index=True, direction=direction)
@@ -269,7 +282,12 @@
 
 
 def join_catalog_data_on(
-    left: Catalog, right: Catalog, left_on: str, right_on: str, suffixes: tuple[str, str]
+    left: Catalog,
+    right: Catalog,
+    left_on: str,
+    right_on: str,
+    suffixes: tuple[str, str],
+    suffix_method: str | None = None,
 ) -> tuple[nd.NestedFrame, DaskDFPixelMap, PixelAlignment]:
     """Joins two catalogs spatially on a specified column
 
@@ -279,6 +297,11 @@ def join_catalog_data_on(
         left_on (str): the column to join on from the left partition
         right_on (str): the column to join on from the right partition
         suffixes (Tuple[str,str]): the suffixes to apply to each partition's column names
+        suffix_method (str): Method to use to add suffixes to columns. Options are:
+            - "overlapping_columns": only add suffixes to columns that are present in both catalogs
+            - "all_columns": add suffixes to all columns from both catalogs
+            Default: "all_columns". Warning: this default will change to "overlapping_columns" in a
+            future release.
 
     Returns:
         A tuple of the dask dataframe with the result of the join, the pixel map from HEALPix
@@ -301,9 +324,10 @@ def join_catalog_data_on(
         left_on,
         right_on,
         suffixes,
+        suffix_method,
     )
 
-    meta_df = generate_meta_df_for_joined_tables([left, right], suffixes)
+    meta_df = generate_meta_df_for_joined_tables((left, right), suffixes, suffix_method=suffix_method)
 
     return construct_catalog_args(joined_partitions, meta_df, alignment)
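Note: for readers unfamiliar with it, a minimal illustration of pandas `merge_asof` on sorted indices, which is the underlying operation in `perform_merge_asof` above.

```python
import pandas as pd

left = pd.DataFrame({"a": [1, 2]}, index=[10, 20])
right = pd.DataFrame({"b": [3, 4]}, index=[9, 19])

# For each left row, take the closest earlier (backward) right row.
print(pd.merge_asof(left, right, left_index=True, right_index=True,
                    direction="backward"))
#     a  b
# 10  1  3
# 20  2  4
```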
@@ -358,7 +382,11 @@ def join_catalog_data_nested(
 
 
 def join_catalog_data_through(
-    left: Catalog, right: Catalog, association: AssociationCatalog, suffixes: tuple[str, str]
+    left: Catalog,
+    right: Catalog,
+    association: AssociationCatalog,
+    suffixes: tuple[str, str],
+    suffix_method: str | None = None,
 ) -> tuple[nd.NestedFrame, DaskDFPixelMap, PixelAlignment]:
     """Joins two catalogs with an association table
 
@@ -367,6 +395,11 @@ def join_catalog_data_through(
         right (lsdb.Catalog): the right catalog to join
         association (AssociationCatalog): the association catalog to join the catalogs with
         suffixes (Tuple[str,str]): the suffixes to apply to each partition's column names
+        suffix_method (str): Method to use to add suffixes to columns. Options are:
+            - "overlapping_columns": only add suffixes to columns that are present in both catalogs
+            - "all_columns": add suffixes to all columns from both catalogs
+            Default: "all_columns". Warning: this default will change to "overlapping_columns" in a
+            future release.
 
     Returns:
         A tuple of the dask dataframe with the result of the join, the pixel map from HEALPix
@@ -387,6 +420,7 @@ def join_catalog_data_through(
             association.hc_structure.catalog_info.primary_column,
             association.hc_structure.catalog_info.join_column,
             suffixes,
+            suffix_method=suffix_method,
         )
 
     if right.margin is None:
@@ -420,6 +454,7 @@ def join_catalog_data_through(
         ],
         perform_join_through,
         suffixes,
+        suffix_method,
     )
 
     association_join_columns = [
@@ -430,13 +465,19 @@ def join_catalog_data_through(
 
     # pylint: disable=protected-access
     extra_df = association._ddf._meta.drop(non_joining_columns + association_join_columns, axis=1)
-    meta_df = generate_meta_df_for_joined_tables([left, extra_df, right], [suffixes[0], "", suffixes[1]])
+    meta_df = generate_meta_df_for_joined_tables(
+        (left, right), suffixes, extra_columns=extra_df, suffix_method=suffix_method
+    )
 
     return construct_catalog_args(joined_partitions, meta_df, alignment)
 
 
 def merge_asof_catalog_data(
-    left: Catalog, right: Catalog, suffixes: tuple[str, str], direction: str = "backward"
+    left: Catalog,
+    right: Catalog,
+    suffixes: tuple[str, str],
+    direction: str = "backward",
+    suffix_method: str | None = None,
 ) -> tuple[nd.NestedFrame, DaskDFPixelMap, PixelAlignment]:
     """Uses the pandas `merge_asof` function to merge two catalogs on their indices by distance of keys
 
@@ -451,6 +492,11 @@ def merge_asof_catalog_data(
         right (lsdb.Catalog): the right catalog to join
         suffixes (Tuple[str,str]): the suffixes to apply to each partition's column names
         direction (str): the direction to perform the merge_asof
+        suffix_method (str): Method to use to add suffixes to columns. Options are:
+            - "overlapping_columns": only add suffixes to columns that are present in both catalogs
+            - "all_columns": add suffixes to all columns from both catalogs
+            Default: "all_columns". Warning: this default will change to "overlapping_columns" in a
+            future release.
 
     Returns:
         A tuple of the dask dataframe with the result of the join, the pixel map from HEALPix
@@ -463,9 +509,9 @@ def merge_asof_catalog_data(
     left_pixels, right_pixels = get_healpix_pixels_from_alignment(alignment)
 
     joined_partitions = align_and_apply(
-        [(left, left_pixels), (right, right_pixels)], perform_merge_asof, suffixes, direction
+        [(left, left_pixels), (right, right_pixels)], perform_merge_asof, suffixes, direction, suffix_method
     )
 
-    meta_df = generate_meta_df_for_joined_tables([left, right], suffixes)
+    meta_df = generate_meta_df_for_joined_tables((left, right), suffixes, suffix_method=suffix_method)
 
     return construct_catalog_args(joined_partitions, meta_df, alignment)
diff --git a/src/lsdb/dask/merge_catalog_functions.py b/src/lsdb/dask/merge_catalog_functions.py
index 1255bc68..a705aa39 100644
--- a/src/lsdb/dask/merge_catalog_functions.py
+++ b/src/lsdb/dask/merge_catalog_functions.py
@@ -1,6 +1,8 @@
 from __future__ import annotations
 
-from typing import TYPE_CHECKING, Callable, Sequence
+import logging
+import warnings
+from typing import TYPE_CHECKING, Callable, Literal, Sequence
 
 import hats.pixel_math.healpix_shim as hp
 import nested_pandas as npd
@@ -18,6 +20,7 @@
 from hats.pixel_tree import PixelAlignment, PixelAlignmentType, align_trees
 from hats.pixel_tree.moc_utils import copy_moc
 from hats.pixel_tree.pixel_alignment import align_with_mocs
+from tabulate import tabulate
 
 import lsdb.nested as nd
 from lsdb.dask.divisions import get_pixels_divisions
@@ -32,6 +35,152 @@
 ASSOC_NORDER = "assoc_Norder"
 ASSOC_NPIX = "assoc_Npix"
 
+DEFAULT_SUFFIX_METHOD: Literal["all_columns", "overlapping_columns"] = "all_columns"
+
+
+def apply_suffix_all_columns(
+    left_df: npd.NestedFrame, right_df: npd.NestedFrame, suffixes: tuple[str, str]
+) -> tuple[npd.NestedFrame, npd.NestedFrame]:
+    """Applies suffixes to all columns in both dataframes
+
+    Args:
+        left_df (npd.NestedFrame): The left dataframe
+        right_df (npd.NestedFrame): The right dataframe
+        suffixes (tuple[str, str]): The suffixes to apply to the left and right dataframes
+
+    Returns:
+        A tuple of the two dataframes with the suffixes applied
+    """
+    left_suffix, right_suffix = suffixes
+    left_df = left_df.add_suffix(left_suffix)
+    right_df = right_df.add_suffix(right_suffix)
+    return left_df, right_df
+
+
+def apply_suffix_overlapping_columns(
+    left_df: npd.NestedFrame, right_df: npd.NestedFrame, suffixes: tuple[str, str], log_changes: bool = True
+) -> tuple[npd.NestedFrame, npd.NestedFrame]:
+    """Applies suffixes to the overlapping columns in both dataframes
+
+    Optionally logs an info message with a table of the columns being renamed.
+
+    Args:
+        left_df (npd.NestedFrame): The left dataframe
+        right_df (npd.NestedFrame): The right dataframe
+        suffixes (tuple[str, str]): The suffixes to apply to the left and right dataframes
+        log_changes (bool): If True, logs an info message listing the columns being renamed.
+
+    Returns:
+        A tuple of the two dataframes with the suffixes applied
+    """
+    left_suffix, right_suffix = suffixes
+    overlapping_columns = set(left_df.columns).intersection(set(right_df.columns))
+    left_df = left_df.rename(columns={c: c + left_suffix for c in overlapping_columns})
+    right_df = right_df.rename(columns={c: c + right_suffix for c in overlapping_columns})
+
+    if overlapping_columns and log_changes:
+        table = tabulate(
+            [(c, c + left_suffix, c + right_suffix) for c in overlapping_columns],
+            headers=["Column", f"Left (suffix={left_suffix})", f"Right (suffix={right_suffix})"],
+            tablefmt="pretty",
+        )
+        logging.info("Renaming overlapping columns:\n%s", table)
+
+    return left_df, right_df
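Note: what the overlapping-columns log looks like; this standalone snippet reproduces the table the helper builds with tabulate (the rendered output is approximate, since "pretty" centers cells based on content width).

```python
from tabulate import tabulate

overlapping = ["id", "ra"]
print(tabulate(
    [(c, c + "_a", c + "_b") for c in overlapping],
    headers=["Column", "Left (suffix=_a)", "Right (suffix=_b)"],
    tablefmt="pretty",
))
# +--------+------------------+-------------------+
# | Column | Left (suffix=_a) | Right (suffix=_b) |
# +--------+------------------+-------------------+
# |   id   |       id_a       |       id_b        |
# |   ra   |       ra_a       |       ra_b        |
# +--------+------------------+-------------------+
```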
+
+
+def apply_suffixes(
+    left_df: npd.NestedFrame,
+    right_df: npd.NestedFrame,
+    suffixes: tuple[str, str],
+    suffix_method: str | None = None,
+    log_changes: bool = True,
+) -> tuple[npd.NestedFrame, npd.NestedFrame]:
+    """Applies suffixes to the columns of two dataframes using the specified suffix method
+
+    Args:
+        left_df (npd.NestedFrame): The left dataframe
+        right_df (npd.NestedFrame): The right dataframe
+        suffixes (tuple[str, str]): The suffixes to apply to the left and right dataframes
+        suffix_method (str | None): The method to use to generate suffixes. Options are 'all_columns'
+            and 'overlapping_columns'. If None, defaults to 'all_columns', but this will change to
+            'overlapping_columns' in a future release.
+        log_changes (bool): If True, logs an info message listing the columns being renamed.
+
+    Returns:
+        A tuple of the two dataframes with the suffixes applied
+    """
+    if suffix_method is None:
+        suffix_method = DEFAULT_SUFFIX_METHOD
+        warnings.warn(
+            "The default suffix behavior will change from applying suffixes to all columns to only "
+            "applying suffixes to overlapping columns in a future release. "
+            "To maintain the current behavior, explicitly set `suffix_method='all_columns'`. "
+            "To change to the new behavior, set `suffix_method='overlapping_columns'`.",
+            FutureWarning,
+        )
+
+    if suffix_method == "all_columns":
+        return apply_suffix_all_columns(left_df, right_df, suffixes)
+    if suffix_method == "overlapping_columns":
+        return apply_suffix_overlapping_columns(left_df, right_df, suffixes, log_changes)
+    raise ValueError(f"Invalid suffix method: {suffix_method}")
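Note: a sketch of the deprecation path, assuming this branch is importable as below; omitting suffix_method falls back to DEFAULT_SUFFIX_METHOD and warns. Plain pandas frames stand in for NestedFrames here.

```python
import warnings

import pandas as pd

from lsdb.dask.merge_catalog_functions import apply_suffixes

left = pd.DataFrame({"id": [1]})
right = pd.DataFrame({"id": [2]})

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    apply_suffixes(left, right, ("_a", "_b"))  # suffix_method omitted
assert any(issubclass(w.category, FutureWarning) for w in caught)

# Passing the method explicitly keeps the call warning-free:
apply_suffixes(left, right, ("_a", "_b"), suffix_method="all_columns")
```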
+
+
+def apply_left_suffix(
+    col_name: str,
+    right_col_names: list[str],
+    suffixes: tuple[str, str],
+    suffix_method: str | None = None,
+    log_changes: bool = False,
+) -> str:
+    """Applies the left suffix to a column name using the specified suffix method
+
+    Args:
+        col_name (str): The column name to apply the suffix to
+        right_col_names (list[str]): The list of column names in the right dataframe
+        suffixes (tuple[str, str]): The suffixes to apply to the left and right dataframes
+        suffix_method (str): The method to use to generate suffixes. Options are 'all_columns'
+            and 'overlapping_columns'
+        log_changes (bool): If True, logs an info message listing the columns being renamed.
+            Default: False
+
+    Returns:
+        The column name with the left suffix applied
+    """
+    left_df = npd.NestedFrame(columns=[col_name])
+    right_df = npd.NestedFrame(columns=right_col_names)
+    left_df, _ = apply_suffixes(left_df, right_df, suffixes, suffix_method, log_changes=log_changes)
+    return left_df.columns[0]
+
+
+def apply_right_suffix(
+    col_name: str,
+    left_col_names: list[str],
+    suffixes: tuple[str, str],
+    suffix_method: str | None = None,
+    log_changes: bool = False,
+) -> str:
+    """Applies the right suffix to a column name using the specified suffix method
+
+    Args:
+        col_name (str): The column name to apply the suffix to
+        left_col_names (list[str]): The column names in the left dataframe
+        suffixes (tuple[str, str]): The suffixes to apply to the left and right dataframes
+        suffix_method (str): The method to use to generate suffixes. Options are 'all_columns'
+            and 'overlapping_columns'
+        log_changes (bool): If True, logs an info message listing the columns being renamed.
+            Default: False
+
+    Returns:
+        The column name with the right suffix applied
+    """
+    left_df = npd.NestedFrame(columns=left_col_names)
+    right_df = npd.NestedFrame(columns=[col_name])
+    _, right_df = apply_suffixes(left_df, right_df, suffixes, suffix_method, log_changes=log_changes)
+    return right_df.columns[0]
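Note: the two helpers above reuse apply_suffixes on empty single-column frames so that name resolution cannot drift from the real renaming logic; a minimal sketch of the trick, using pandas for self-containment.

```python
import pandas as pd

def resolve_left_name(col, right_cols, suffixes):
    # Build throwaway empty frames that carry only the column names.
    left = pd.DataFrame(columns=[col])
    right = pd.DataFrame(columns=right_cols)
    overlap = set(left.columns) & set(right.columns)
    left = left.rename(columns={c: c + suffixes[0] for c in overlap})
    return left.columns[0]

assert resolve_left_name("id", ["id"], ("_a", "_b")) == "id_a"
assert resolve_left_name("mag", ["id"], ("_a", "_b")) == "mag"
```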
""" - meta = {} # Construct meta for crossmatched catalog columns - for table, suffix in zip(catalogs, suffixes): - for name, col_type in table.dtypes.items(): - if name not in paths.HIVE_COLUMNS: - meta[name + suffix] = pd.Series(dtype=col_type) + # pylint: disable=protected-access + left_meta, right_meta = apply_suffixes( + catalogs[0]._ddf._meta, + catalogs[1]._ddf._meta, + suffixes, + suffix_method, + ) + meta = pd.concat([left_meta, right_meta], axis=1) # Construct meta for crossmatch result columns if extra_columns is not None: - meta.update(extra_columns) + meta = pd.concat([meta, extra_columns], axis=1) if index_type is None: # pylint: disable=protected-access index_type = catalogs[0]._ddf._meta.index.dtype @@ -691,7 +845,11 @@ def align_catalog_to_partitions( def create_merged_catalog_info( - left_info: TableProperties, right_info: TableProperties, updated_name: str, suffixes: tuple[str, str] + left: Catalog, + right: Catalog, + updated_name: str, + suffixes: tuple[str, str], + suffix_method: str | None = None, ) -> TableProperties: """Creates the catalog info of the resulting catalog from merging two catalogs @@ -699,24 +857,23 @@ def create_merged_catalog_info( catalog name, and sets the total rows to 0 Args: - left_info (TableProperties): The catalog_info of the left catalog - right_info (TableProperties): The catalog_info of the right catalog + left (Catalog): The left catalog being merged + right (Catalog): The right catalog being merged updated_name (str): The updated name of the catalog suffixes (tuple[str, str]): The suffixes of the catalogs in the merged result + suffix_method (str): The method used to generate suffixes. Options are 'all_columns', + 'overlapping_columns' + + Returns: + The catalog info of the resulting merged catalog """ - default_cols = ( - [c + suffixes[0] for c in left_info.default_columns] if left_info.default_columns is not None else [] - ) - default_cols = ( - default_cols + [c + suffixes[1] for c in right_info.default_columns] - if right_info.default_columns is not None - else default_cols - ) - default_cols_to_use = default_cols if len(default_cols) > 0 else None + left_info = left.hc_structure.catalog_info + ra_col = apply_left_suffix(left_info.ra_column, right.columns, suffixes, suffix_method) # type: ignore + dec_col = apply_left_suffix(left_info.dec_column, right.columns, suffixes, suffix_method) # type: ignore return left_info.copy_and_update( catalog_name=updated_name, - ra_column=left_info.ra_column + suffixes[0], - dec_column=left_info.dec_column + suffixes[0], + ra_column=ra_col, + dec_column=dec_col, total_rows=0, - default_columns=default_cols_to_use, + default_columns=None, ) diff --git a/tests/conftest.py b/tests/conftest.py index fc3d6e99..ceb7d1f3 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -452,13 +452,6 @@ def assert_default_columns_in_columns(cat): for col in cat.hc_structure.catalog_info.default_columns: assert col in cat._ddf.columns - @staticmethod - def assert_columns_in_joined_catalog(joined_cat, cats, suffixes): - for cat, suffix in zip(cats, suffixes): - for col_name, dtype in cat.dtypes.items(): - if col_name not in paths.HIVE_COLUMNS: - assert (col_name + suffix, dtype) in joined_cat.dtypes.items() - @staticmethod def assert_columns_in_nested_joined_catalog( joined_cat, left_cat, right_cat, right_ignore_columns, nested_colname diff --git a/tests/lsdb/catalog/test_crossmatch.py b/tests/lsdb/catalog/test_crossmatch.py index 72d36105..c66a1a1a 100644 --- a/tests/lsdb/catalog/test_crossmatch.py +++ 
diff --git a/tests/lsdb/catalog/test_crossmatch.py b/tests/lsdb/catalog/test_crossmatch.py
index 72d36105..c66a1a1a 100644
--- a/tests/lsdb/catalog/test_crossmatch.py
+++ b/tests/lsdb/catalog/test_crossmatch.py
@@ -1,3 +1,5 @@
+import logging
+
 import nested_pandas as npd
 import numpy as np
 import pandas as pd
@@ -12,7 +14,7 @@
 from lsdb.core.crossmatch.abstract_crossmatch_algorithm import AbstractCrossmatchAlgorithm
 from lsdb.core.crossmatch.bounded_kdtree_match import BoundedKdTreeCrossmatch
 from lsdb.core.crossmatch.kdtree_match import KdTreeCrossmatch
-from lsdb.dask.merge_catalog_functions import align_catalogs
+from lsdb.dask.merge_catalog_functions import align_catalogs, apply_suffixes
 
 
 @pytest.mark.parametrize("algo", [KdTreeCrossmatch])
@@ -38,6 +40,11 @@ def test_kdtree_crossmatch(algo, small_sky_catalog, small_sky_xmatch_catalog, xm
             assert xmatch_row["id_small_sky_xmatch"].to_numpy() == correct_row["xmatch_id"]
             assert xmatch_row["_dist_arcsec"].to_numpy() == pytest.approx(correct_row["dist"] * 3600)
 
+        assert xmatched_cat.hc_structure.catalog_info.ra_column in xmatched_cat.columns
+        assert xmatched_cat.hc_structure.catalog_info.dec_column in xmatched_cat.columns
+        assert xmatched_cat.hc_structure.catalog_info.ra_column == "ra_small_sky"
+        assert xmatched_cat.hc_structure.catalog_info.dec_column == "dec_small_sky"
+
     @staticmethod
     def test_kdtree_crossmatch_nested(algo, small_sky_catalog, small_sky_xmatch_catalog, xmatch_correct):
         with pytest.warns(RuntimeWarning, match="Results may be incomplete and/or inaccurate"):
@@ -215,11 +222,90 @@ def test_crossmatch_negative_margin(
             assert len(xmatch_row) == 1
             assert xmatch_row["_dist_arcsec"].to_numpy() == pytest.approx(correct_row["dist"] * 3600)
 
+    @staticmethod
+    def test_overlapping_suffix_method(algo, small_sky_catalog, small_sky_xmatch_catalog, caplog):
+        suffixes = ("_left", "_right")
+        # Test that renamed columns are logged correctly
+        with caplog.at_level(logging.INFO):
+            xmatched = small_sky_catalog.crossmatch(
+                small_sky_xmatch_catalog,
+                algorithm=algo,
+                suffix_method="overlapping_columns",
+                suffixes=suffixes,
+            )
+
+        assert "Renaming overlapping columns" in caplog.text
+
+        computed = xmatched.compute()
+        for col in small_sky_catalog.columns:
+            if col in small_sky_xmatch_catalog.columns:
+                assert f"{col}{suffixes[0]}" in xmatched.columns
+                assert f"{col}{suffixes[0]}" in computed.columns
+                assert col in caplog.text
+                assert f"{col}{suffixes[0]}" in caplog.text
+            else:
+                assert col in xmatched.columns
+                assert col in computed.columns
+        for col in small_sky_xmatch_catalog.columns:
+            if col in small_sky_catalog.columns:
+                assert f"{col}{suffixes[1]}" in xmatched.columns
+                assert f"{col}{suffixes[1]}" in computed.columns
+                assert f"{col}{suffixes[1]}" in caplog.text
+            else:
+                assert col in xmatched.columns
+                assert col in computed.columns
+
+        assert xmatched.hc_structure.catalog_info.ra_column in xmatched.columns
+        assert xmatched.hc_structure.catalog_info.dec_column in xmatched.columns
+        assert xmatched.hc_structure.catalog_info.ra_column == "ra_left"
+        assert xmatched.hc_structure.catalog_info.dec_column == "dec_left"
+    @staticmethod
+    def test_overlapping_suffix_method_no_overlaps(algo, small_sky_catalog, small_sky_xmatch_catalog, caplog):
+        suffixes = ("_left", "_right")
+        small_sky_catalog = small_sky_catalog.rename(
+            {col: f"{col}_unique" for col in small_sky_catalog.columns}
+        )
+        small_sky_catalog.hc_structure.catalog_info.ra_column = (
+            f"{small_sky_catalog.hc_structure.catalog_info.ra_column}_unique"
+        )
+        small_sky_catalog.hc_structure.catalog_info.dec_column = (
+            f"{small_sky_catalog.hc_structure.catalog_info.dec_column}_unique"
+        )
+        # Test that no renames are logged when there are no overlapping columns
+        with caplog.at_level(logging.INFO):
+            xmatched = small_sky_catalog.crossmatch(
+                small_sky_xmatch_catalog,
+                algorithm=algo,
+                suffix_method="overlapping_columns",
+                suffixes=suffixes,
+            )
+
+        assert len(caplog.text) == 0
+
+        computed = xmatched.compute()
+        for col in small_sky_catalog.columns:
+            assert col in xmatched.columns
+            assert col in computed.columns
+        for col in small_sky_xmatch_catalog.columns:
+            assert col in xmatched.columns
+            assert col in computed.columns
+
+        assert xmatched.hc_structure.catalog_info.ra_column in xmatched.columns
+        assert xmatched.hc_structure.catalog_info.dec_column in xmatched.columns
+        assert xmatched.hc_structure.catalog_info.ra_column == "ra_unique"
+        assert xmatched.hc_structure.catalog_info.dec_column == "dec_unique"
+
     @staticmethod
     def test_wrong_suffixes(algo, small_sky_catalog, small_sky_xmatch_catalog):
         with pytest.raises(ValueError):
             small_sky_catalog.crossmatch(small_sky_xmatch_catalog, suffixes=("wrong",), algorithm=algo)
 
+    @staticmethod
+    def test_wrong_suffix_method(algo, small_sky_catalog, small_sky_xmatch_catalog):
+        with pytest.raises(ValueError, match="Invalid suffix method"):
+            small_sky_catalog.crossmatch(small_sky_xmatch_catalog, suffix_method="wrong", algorithm=algo)
+
     @staticmethod
     def test_right_margin_missing(algo, small_sky_catalog, small_sky_xmatch_catalog):
         small_sky_xmatch_catalog.margin = None
@@ -440,11 +526,10 @@ class MockCrossmatchAlgorithmOverwrite(AbstractCrossmatchAlgorithm):
 
     extra_columns = pd.DataFrame({"_DIST": pd.Series(dtype=np.float64)})
 
-    def crossmatch(self, suffixes, mock_results: pd.DataFrame = None):  # type: ignore
+    def crossmatch(self, suffixes, suffix_method="all_columns", mock_results: pd.DataFrame = None, **kwargs):
         left_reset = self.left.reset_index(drop=True)
         right_reset = self.right.reset_index(drop=True)
-        self._rename_columns_with_suffix(self.left, suffixes[0])
-        self._rename_columns_with_suffix(self.right, suffixes[1])
+        self.left, self.right = apply_suffixes(self.left, self.right, suffixes, suffix_method)
         mock_results = mock_results[mock_results["ss_id"].isin(left_reset["id"].to_numpy())]
         left_indexes = mock_results.apply(
             lambda row: left_reset[left_reset["id"] == row["ss_id"]].index[0], axis=1
diff --git a/tests/lsdb/catalog/test_join.py b/tests/lsdb/catalog/test_join.py
index 7f392e71..65b02113 100644
--- a/tests/lsdb/catalog/test_join.py
+++ b/tests/lsdb/catalog/test_join.py
@@ -1,4 +1,5 @@
 import nested_pandas as npd
+import numpy as np
 import pandas as pd
 import pyarrow as pa
 import pytest
@@ -18,7 +19,21 @@ def test_small_sky_join_small_sky_order1(small_sky_catalog, small_sky_order1_cat
             small_sky_order1_catalog, left_on="id", right_on="id", suffixes=suffixes
         )
     assert isinstance(joined._ddf, nd.NestedFrame)
-    helpers.assert_columns_in_joined_catalog(joined, [small_sky_catalog, small_sky_order1_catalog], suffixes)
+
+    expected_columns = [
+        "id_a",
+        "ra_a",
+        "dec_a",
+        "ra_error_a",
+        "dec_error_a",
+        "id_b",
+        "ra_b",
+        "dec_b",
+        "ra_error_b",
+        "dec_error_b",
+    ]
+
+    assert np.all(joined.columns == expected_columns)
     assert joined._ddf.index.name == SPATIAL_INDEX_COLUMN
     assert joined._ddf.index.dtype == pd.ArrowDtype(pa.int64())
     alignment = align_catalogs(small_sky_catalog, small_sky_order1_catalog)
@@ -26,6 +41,7 @@ def test_small_sky_join_small_sky_order1(small_sky_catalog, small_sky_order1_cat
     assert joined.get_healpix_pixels() == alignment.pixel_tree.get_healpix_pixels()
 
     joined_compute = joined.compute()
+    assert np.all(joined_compute.columns == expected_columns)
     assert isinstance(joined_compute, npd.NestedFrame)
     small_sky_compute = small_sky_catalog.compute()
     small_sky_order1_compute = small_sky_order1_catalog.compute()
@@ -40,6 +56,41 @@ def test_small_sky_join_small_sky_order1(small_sky_catalog, small_sky_order1_cat
     assert not joined.hc_structure.on_disk
 
 
+def test_small_sky_join_overlapping_suffix(small_sky_catalog, small_sky_order1_catalog, helpers):
+    suffixes = ("_a", "_b")
+    with pytest.warns(match="margin"):
+        joined = small_sky_catalog.join(
+            small_sky_order1_catalog,
+            left_on="id",
+            right_on="id",
+            suffixes=suffixes,
+            suffix_method="overlapping_columns",
+        )
+    assert isinstance(joined._ddf, nd.NestedFrame)
+
+    expected_columns = [
+        "id_a",
+        "ra_a",
+        "dec_a",
+        "ra_error_a",
+        "dec_error_a",
+        "id_b",
+        "ra_b",
+        "dec_b",
+        "ra_error_b",
+        "dec_error_b",
+    ]
+
+    assert np.all(joined.columns == expected_columns)
+
+    joined_compute = joined.compute()
+
+    assert np.all(joined_compute.columns == expected_columns)
+
+    helpers.assert_divisions_are_correct(joined)
+    helpers.assert_schema_correct(joined)
+
+
 def test_small_sky_join_small_sky_order1_source(
     small_sky_catalog, small_sky_order1_source_with_margin, helpers
 ):
@@ -47,15 +98,33 @@ def test_small_sky_join_small_sky_order1_source(
     joined = small_sky_catalog.join(
         small_sky_order1_source_with_margin, left_on="id", right_on="object_id", suffixes=suffixes
     )
-    helpers.assert_columns_in_joined_catalog(
-        joined, [small_sky_catalog, small_sky_order1_source_with_margin], suffixes
-    )
+
+    expected_columns = [
+        "id_a",
+        "ra_a",
+        "dec_a",
+        "ra_error_a",
+        "dec_error_a",
+        "source_id_b",
+        "source_ra_b",
+        "source_dec_b",
+        "mjd_b",
+        "mag_b",
+        "band_b",
+        "object_id_b",
+        "object_ra_b",
+        "object_dec_b",
+    ]
+
+    assert np.all(joined.columns == expected_columns)
     alignment = align_catalogs(small_sky_catalog, small_sky_order1_source_with_margin)
     assert joined.hc_structure.moc == alignment.moc
     assert joined.get_healpix_pixels() == alignment.pixel_tree.get_healpix_pixels()
 
     joined_compute = joined.compute()
+
+    assert np.all(joined_compute.columns == expected_columns)
     small_sky_order1_compute = small_sky_order1_source_with_margin.compute()
     assert len(joined_compute) == len(small_sky_order1_compute)
     joined_test = small_sky_order1_compute.merge(joined_compute, left_on="object_id", right_on="object_id_b")
@@ -71,15 +140,30 @@ def test_small_sky_join_default_columns(
     joined = small_sky_order1_default_cols_catalog.join(
         small_sky_order1_source_with_margin, left_on="id", right_on="object_id", suffixes=suffixes
     )
-    helpers.assert_columns_in_joined_catalog(
-        joined, [small_sky_order1_default_cols_catalog, small_sky_order1_source_with_margin], suffixes
-    )
+
+    expected_columns = [
+        "ra_a",
+        "dec_a",
+        "id_a",
+        "source_id_b",
+        "source_ra_b",
+        "source_dec_b",
+        "mjd_b",
+        "mag_b",
+        "band_b",
+        "object_id_b",
+        "object_ra_b",
+        "object_dec_b",
+    ]
+
+    assert np.all(joined.columns == expected_columns)
     alignment = align_catalogs(small_sky_order1_default_cols_catalog, small_sky_order1_source_with_margin)
     assert joined.hc_structure.moc == alignment.moc
    assert joined.get_healpix_pixels() == alignment.pixel_tree.get_healpix_pixels()
 
     joined_compute = joined.compute()
+    assert np.all(joined_compute.columns == expected_columns)
     small_sky_order1_compute = small_sky_order1_source_with_margin.compute()
     assert len(joined_compute) == len(small_sky_order1_compute)
     joined_test = small_sky_order1_compute.merge(joined_compute, left_on="object_id", right_on="object_id_b")
@@ -102,7 +186,7 @@ def test_join_wrong_suffixes(small_sky_catalog, small_sky_order1_catalog):
 
 
 def test_join_association(
-    small_sky_catalog, small_sky_order1_source_collection_catalog, small_sky_to_o1source_catalog, helpers
+    small_sky_catalog, small_sky_order1_source_collection_catalog, small_sky_to_o1source_catalog
 ):
     suffixes = ("_a", "_b")
     joined = small_sky_catalog.join(
@@ -112,13 +196,33 @@ def test_join_association(
     alignment = align_catalogs(small_sky_catalog, small_sky_order1_source_collection_catalog)
     assert joined.hc_structure.moc == alignment.moc
     assert joined.get_healpix_pixels() == alignment.pixel_tree.get_healpix_pixels()
-    helpers.assert_columns_in_joined_catalog(
-        joined, [small_sky_catalog, small_sky_order1_source_collection_catalog], suffixes
-    )
+
+    expected_columns = [
+        "id_a",
+        "ra_a",
+        "dec_a",
+        "ra_error_a",
+        "dec_error_a",
+        "source_id_b",
+        "source_ra_b",
+        "source_dec_b",
+        "mjd_b",
+        "mag_b",
+        "band_b",
+        "object_id_b",
+        "object_ra_b",
+        "object_dec_b",
+        "_dist_arcsec",
+    ]
+
+    assert np.all(joined.columns == expected_columns)
+
     assert joined._ddf.index.name == SPATIAL_INDEX_COLUMN
     assert joined._ddf.index.dtype == pd.ArrowDtype(pa.int64())
 
     joined_data = joined.compute()
+
+    assert np.all(joined_data.columns == expected_columns)
     assert isinstance(joined_data, npd.NestedFrame)
     association_data = small_sky_to_o1source_catalog.compute()
     assert len(joined_data) == len(association_data)
@@ -147,6 +251,44 @@ def test_join_association(
     assert joined_row.index == left_row.index
 
 
+def test_join_association_overlapping_suffix(
+    small_sky_catalog, small_sky_order1_source_collection_catalog, small_sky_to_o1source_catalog, helpers
+):
+    suffixes = ("_a", "_b")
+    joined = small_sky_catalog.join(
+        small_sky_order1_source_collection_catalog,
+        through=small_sky_to_o1source_catalog,
+        suffixes=suffixes,
+        suffix_method="overlapping_columns",
+    )
+    expected_columns = [
+        "id",
+        "ra",
+        "dec",
+        "ra_error",
+        "dec_error",
+        "source_id",
+        "source_ra",
+        "source_dec",
+        "mjd",
+        "mag",
+        "band",
+        "object_id",
+        "object_ra",
+        "object_dec",
+        "_dist_arcsec",
+    ]
+
+    assert np.all(joined.columns == expected_columns)
+
+    joined_compute = joined.compute()
+
+    assert np.all(joined_compute.columns == expected_columns)
+
+    helpers.assert_divisions_are_correct(joined)
+    helpers.assert_schema_correct(joined)
+
+
 def test_join_association_suffix_edge_case(
     small_sky_catalog, small_sky_order1_source_collection_catalog, small_sky_to_o1source_catalog
 ):
@@ -200,9 +342,26 @@ def test_join_nested(small_sky_catalog, small_sky_order1_source_with_margin, hel
         right_on="object_id",
         nested_column_name="sources",
     )
-    helpers.assert_columns_in_nested_joined_catalog(
-        joined, small_sky_catalog, small_sky_order1_source_with_margin, ["object_id"], "sources"
-    )
+    expected_columns = [
+        "id",
+        "ra",
+        "dec",
+        "ra_error",
+        "dec_error",
+        "sources",
+    ]
+    expected_nested_columns = [
+        "source_id",
+        "source_ra",
+        "source_dec",
+        "mjd",
+        "mag",
+        "band",
+        "object_ra",
+        "object_dec",
+    ]
+    assert np.all(joined.columns == expected_columns)
+    assert np.all(joined["sources"].nest.fields == expected_nested_columns)
     helpers.assert_divisions_are_correct(joined)
     alignment = align_catalogs(small_sky_catalog, small_sky_order1_source_with_margin)
     assert joined.hc_structure.moc == alignment.moc
@@ -263,6 +422,35 @@ def test_merge_asof(small_sky_catalog, small_sky_xmatch_catalog, direction, help
     pd.testing.assert_frame_equal(joined_compute.drop(columns=drop_cols), correct_result)
 
 
+def test_merge_asof_overlapping_suffix(small_sky_catalog, small_sky_xmatch_catalog, helpers):
+    suffixes = ("_a", "_b")
+    joined = small_sky_catalog.merge_asof(
+        small_sky_xmatch_catalog, direction="backward", suffixes=suffixes, suffix_method="overlapping_columns"
+    )
+
+    expected_columns = [
+        "id_a",
+        "ra_a",
+        "dec_a",
+        "ra_error_a",
+        "dec_error_a",
+        "id_b",
+        "ra_b",
+        "dec_b",
+        "ra_error_b",
+        "dec_error_b",
+        "calculated_dist",
+    ]
+    assert np.all(joined.columns == expected_columns)
+    helpers.assert_divisions_are_correct(joined)
+
+    joined_compute = joined.compute()
+
+    assert np.all(joined_compute.columns == expected_columns)
+    helpers.assert_divisions_are_correct(joined)
+    helpers.assert_schema_correct(joined)
+
+
 def merging_function(input_frame, map_input, *args, **kwargs):
     if len(input_frame) == 0:
         ## this is the empty call to infer meta