From ab0b3431bc4037b5ccabe8696d9a42eb58c03bb4 Mon Sep 17 00:00:00 2001 From: Sean McGuire Date: Tue, 16 Sep 2025 18:33:29 -0400 Subject: [PATCH 01/12] update crossmatch with suffix kwarg --- pyproject.toml | 1 + src/lsdb/catalog/catalog.py | 8 +- .../abstract_crossmatch_algorithm.py | 10 +- src/lsdb/dask/crossmatch_catalog_data.py | 13 ++- src/lsdb/dask/merge_catalog_functions.py | 100 ++++++++++++++++-- tests/lsdb/catalog/test_crossmatch.py | 9 +- 6 files changed, 121 insertions(+), 20 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index bbc7c6a39..0f12629b3 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -27,6 +27,7 @@ dependencies = [ "pyarrow>=14.0.1", "scipy>=1.7.2", # kdtree "universal-pathlib>=0.2.2", + "tabulate>=0.7.0", ] [project.urls] diff --git a/src/lsdb/catalog/catalog.py b/src/lsdb/catalog/catalog.py index 5a3cc0029..38ef5f83d 100644 --- a/src/lsdb/catalog/catalog.py +++ b/src/lsdb/catalog/catalog.py @@ -182,6 +182,7 @@ def crossmatch( ) = BuiltInCrossmatchAlgorithm.KD_TREE, output_catalog_name: str | None = None, require_right_margin: bool = False, + suffix_method: str | None = None, **kwargs, ) -> Catalog: """Perform a cross-match between two catalogs @@ -244,6 +245,11 @@ def crossmatch( Default: {left_name}_x_{right_name} require_right_margin (bool): If true, raises an error if the right margin is missing which could lead to incomplete crossmatches. Default: False + suffix_method (str): Method to use to add suffixes to columns. Options are: + - "overlapping_columns": only add suffixes to columns that are present in both catalogs + - "all_columns": add suffixes to all columns from both catalogs + Default: "all_columns" Warning: This default will change to "overlapping_columns" in a future + release. Returns: A Catalog with the data from the left and right catalogs merged with one row for each @@ -268,7 +274,7 @@ def crossmatch( if output_catalog_name is None: output_catalog_name = f"{self.name}_x_{other.name}" ddf, ddf_map, alignment = crossmatch_catalog_data( - self, other, suffixes, algorithm=algorithm, **kwargs + self, other, suffixes, algorithm=algorithm, suffix_method=suffix_method, **kwargs ) new_catalog_info = create_merged_catalog_info( self.hc_structure.catalog_info, other.hc_structure.catalog_info, output_catalog_name, suffixes diff --git a/src/lsdb/core/crossmatch/abstract_crossmatch_algorithm.py b/src/lsdb/core/crossmatch/abstract_crossmatch_algorithm.py index 182ca68e7..151f45fff 100644 --- a/src/lsdb/core/crossmatch/abstract_crossmatch_algorithm.py +++ b/src/lsdb/core/crossmatch/abstract_crossmatch_algorithm.py @@ -10,6 +10,8 @@ from hats.catalog import TableProperties from hats.pixel_math.spatial_index import SPATIAL_INDEX_COLUMN +from lsdb.dask.merge_catalog_functions import apply_suffix_all_columns + if TYPE_CHECKING: from lsdb.catalog import Catalog @@ -96,14 +98,14 @@ def __init__( self.right_catalog_info = right_catalog_info self.right_margin_catalog_info = right_margin_catalog_info - def crossmatch(self, suffixes, **kwargs) -> npd.NestedFrame: + def crossmatch(self, suffixes, suffix_function=apply_suffix_all_columns, **kwargs) -> npd.NestedFrame: """Perform a crossmatch""" l_inds, r_inds, extra_cols = self.perform_crossmatch(**kwargs) if not len(l_inds) == len(r_inds) == len(extra_cols): raise ValueError( "Crossmatch algorithm must return left and right indices and extra columns with same length" ) - return self._create_crossmatch_df(l_inds, r_inds, extra_cols, suffixes) + return self._create_crossmatch_df(l_inds, r_inds, extra_cols, suffixes, suffix_function) def crossmatch_nested(self, nested_column_name, **kwargs) -> npd.NestedFrame: """Perform a crossmatch""" @@ -196,6 +198,7 @@ def _create_crossmatch_df( right_idx: npt.NDArray[np.int64], extra_cols: pd.DataFrame, suffixes: tuple[str, str], + suffix_function=apply_suffix_all_columns, ) -> npd.NestedFrame: """Creates a df containing the crossmatch result from matching indices and additional columns @@ -209,8 +212,7 @@ def _create_crossmatch_df( additional columns added """ # rename columns so no same names during merging - self._rename_columns_with_suffix(self.left, suffixes[0]) - self._rename_columns_with_suffix(self.right, suffixes[1]) + self.left, self.right = suffix_function(self.left, self.right, suffixes) # concat dataframes together index_name = self.left.index.name if self.left.index.name is not None else "index" left_join_part = self.left.iloc[left_idx].reset_index() diff --git a/src/lsdb/dask/crossmatch_catalog_data.py b/src/lsdb/dask/crossmatch_catalog_data.py index f95040900..0544d7f46 100644 --- a/src/lsdb/dask/crossmatch_catalog_data.py +++ b/src/lsdb/dask/crossmatch_catalog_data.py @@ -20,6 +20,7 @@ generate_meta_df_for_joined_tables, generate_meta_df_for_nested_tables, get_healpix_pixels_from_alignment, + get_suffix_function, ) from lsdb.types import DaskDFPixelMap @@ -40,6 +41,7 @@ def perform_crossmatch( right_margin_catalog_info, algorithm, suffixes, + suffix_function, meta_df, **kwargs, ): @@ -66,7 +68,7 @@ def perform_crossmatch( left_catalog_info, right_catalog_info, right_margin_catalog_info, - ).crossmatch(suffixes, **kwargs) + ).crossmatch(suffixes, suffix_function=suffix_function, **kwargs) # pylint: disable=too-many-arguments, unused-argument @@ -119,6 +121,7 @@ def crossmatch_catalog_data( algorithm: ( Type[AbstractCrossmatchAlgorithm] | BuiltInCrossmatchAlgorithm ) = BuiltInCrossmatchAlgorithm.KD_TREE, + suffix_method: str | None = None, **kwargs, ) -> tuple[nd.NestedFrame, DaskDFPixelMap, PixelAlignment]: """Cross-matches the data from two catalogs @@ -155,9 +158,14 @@ def crossmatch_catalog_data( # get lists of HEALPix pixels from alignment to pass to cross-match left_pixels, right_pixels = get_healpix_pixels_from_alignment(alignment) + suffix_function = get_suffix_function(suffix_method) + # generate meta table structure for dask df meta_df = generate_meta_df_for_joined_tables( - [left, right], suffixes, extra_columns=crossmatch_algorithm.extra_columns + (left, right), + suffixes, + suffix_function=suffix_function, + extra_columns=crossmatch_algorithm.extra_columns, ) # perform the crossmatch on each partition pairing using dask delayed for lazy computation @@ -166,6 +174,7 @@ def crossmatch_catalog_data( perform_crossmatch, crossmatch_algorithm, suffixes, + suffix_function, meta_df, **kwargs, ) diff --git a/src/lsdb/dask/merge_catalog_functions.py b/src/lsdb/dask/merge_catalog_functions.py index 1255bc68f..2d2d19b6a 100644 --- a/src/lsdb/dask/merge_catalog_functions.py +++ b/src/lsdb/dask/merge_catalog_functions.py @@ -1,6 +1,8 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Callable, Sequence +import logging +import warnings +from typing import TYPE_CHECKING, Callable, Sequence, Literal import hats.pixel_math.healpix_shim as hp import nested_pandas as npd @@ -18,6 +20,7 @@ from hats.pixel_tree import PixelAlignment, PixelAlignmentType, align_trees from hats.pixel_tree.moc_utils import copy_moc from hats.pixel_tree.pixel_alignment import align_with_mocs +from tabulate import tabulate import lsdb.nested as nd from lsdb.dask.divisions import get_pixels_divisions @@ -33,6 +36,85 @@ ASSOC_NPIX = "assoc_Npix" +def apply_suffix_all_columns( + left_df: npd.NestedFrame, right_df: npd.NestedFrame, suffixes: tuple[str, str] +) -> tuple[npd.NestedFrame, npd.NestedFrame]: + """Applies suffixes to all columns in both dataframes + + Args: + left_df (npd.NestedFrame): The left dataframe + right_df (npd.NestedFrame): The right dataframe + suffixes (tuple[str, str]): The suffixes to apply to the left and right dataframes + + Returns: + A tuple of the two dataframes with the suffixes applied + """ + left_suffix, right_suffix = suffixes + left_df = left_df.add_suffix(left_suffix) + right_df = right_df.add_suffix(right_suffix) + return left_df, right_df + + +def apply_suffix_overlapping_columns( + left_df: npd.NestedFrame, right_df: npd.NestedFrame, suffixes: tuple[str, str] +) -> tuple[npd.NestedFrame, npd.NestedFrame]: + """Applies suffixes to overlapping columns in both dataframes + + Logs an info message for each column that is being renamed. + + Args: + left_df (npd.NestedFrame): The left dataframe + right_df (npd.NestedFrame): The right dataframe + suffixes (tuple[str, str]): The suffixes to apply to the left and right dataframes + + Returns: + A tuple of the two dataframes with the suffixes applied + """ + left_suffix, right_suffix = suffixes + overlapping_columns = set(left_df.columns).intersection(set(right_df.columns)) + overlapping_columns = [c for c in overlapping_columns] + left_df = left_df.rename(columns={c: c + left_suffix for c in overlapping_columns}) + right_df = right_df.rename(columns={c: c + right_suffix for c in overlapping_columns}) + + table = tabulate( + [(c, c + left_suffix, c + right_suffix) for c in overlapping_columns], + headers=["Column", f"Left (suffix={left_suffix})", f"Right (suffix={right_suffix})"], + tablefmt="pretty", + ) + + logging.info(f"Renaming overlapping columns:\n{table}") + + return left_df, right_df + + +def get_suffix_function( + suffix_method: str | None = None, +) -> Callable[[npd.NestedFrame, npd.NestedFrame, tuple[str, str]], tuple[npd.NestedFrame, npd.NestedFrame]]: + """Gets a function that can be used to generate suffixes for columns based on a specified method + + Args: + suffix_method (str): The method to use to generate suffixes. Options are 'all_columns', 'overlapping_columns', + + Returns: + A function that takes in two dataframes and returns a tuple of the two dataframes with the suffixes applied + """ + if suffix_method is None: + suffix_method = "all_columns" + warnings.warn( + "The default suffix behavior will change from applying suffixes to all columns to only applying suffixes to overlapping columns in a future release." + "To maintain the current behavior, explicitly set `suffix_method='all_columns'`. To change to the new behavior, set `suffix_method='overlapping_columns'`.", + FutureWarning, + ) + + suffix_functions = { + "all_columns": apply_suffix_all_columns, + "overlapping_columns": apply_suffix_overlapping_columns, + } + if suffix_method not in suffix_functions: + raise ValueError(f"Invalid suffix method: {suffix_method}") + return suffix_functions[suffix_method] + + def concat_partition_and_margin( partition: npd.NestedFrame, margin: npd.NestedFrame | None ) -> npd.NestedFrame: @@ -524,8 +606,9 @@ def get_healpix_pixels_from_association( def generate_meta_df_for_joined_tables( - catalogs: Sequence[Catalog], - suffixes: Sequence[str], + catalogs: tuple[Catalog, Catalog], + suffixes: tuple[str, str], + suffix_function: Callable, extra_columns: pd.DataFrame | None = None, index_name: str = SPATIAL_INDEX_COLUMN, index_type: npt.DTypeLike | None = None, @@ -547,15 +630,14 @@ def generate_meta_df_for_joined_tables( An empty dataframe with the columns of each catalog with their respective suffix, and any extra columns specified, with the index name set. """ - meta = {} # Construct meta for crossmatched catalog columns - for table, suffix in zip(catalogs, suffixes): - for name, col_type in table.dtypes.items(): - if name not in paths.HIVE_COLUMNS: - meta[name + suffix] = pd.Series(dtype=col_type) + left_meta, right_meta = suffix_function( + catalogs[0]._ddf._meta, catalogs[1]._ddf._meta, suffixes # pylint: disable=protected-access + ) + meta = pd.concat([left_meta, right_meta], axis=1) # Construct meta for crossmatch result columns if extra_columns is not None: - meta.update(extra_columns) + meta = pd.concat([meta, extra_columns], axis=1) if index_type is None: # pylint: disable=protected-access index_type = catalogs[0]._ddf._meta.index.dtype diff --git a/tests/lsdb/catalog/test_crossmatch.py b/tests/lsdb/catalog/test_crossmatch.py index 72d36105a..6da5cc21e 100644 --- a/tests/lsdb/catalog/test_crossmatch.py +++ b/tests/lsdb/catalog/test_crossmatch.py @@ -12,7 +12,7 @@ from lsdb.core.crossmatch.abstract_crossmatch_algorithm import AbstractCrossmatchAlgorithm from lsdb.core.crossmatch.bounded_kdtree_match import BoundedKdTreeCrossmatch from lsdb.core.crossmatch.kdtree_match import KdTreeCrossmatch -from lsdb.dask.merge_catalog_functions import align_catalogs +from lsdb.dask.merge_catalog_functions import align_catalogs, apply_suffix_all_columns @pytest.mark.parametrize("algo", [KdTreeCrossmatch]) @@ -440,11 +440,12 @@ class MockCrossmatchAlgorithmOverwrite(AbstractCrossmatchAlgorithm): extra_columns = pd.DataFrame({"_DIST": pd.Series(dtype=np.float64)}) - def crossmatch(self, suffixes, mock_results: pd.DataFrame = None): # type: ignore + def crossmatch( + self, suffixes, suffix_function=apply_suffix_all_columns, mock_results: pd.DataFrame = None, **kwargs + ): left_reset = self.left.reset_index(drop=True) right_reset = self.right.reset_index(drop=True) - self._rename_columns_with_suffix(self.left, suffixes[0]) - self._rename_columns_with_suffix(self.right, suffixes[1]) + self.left, self.right = suffix_function(self.left, self.right, suffixes) mock_results = mock_results[mock_results["ss_id"].isin(left_reset["id"].to_numpy())] left_indexes = mock_results.apply( lambda row: left_reset[left_reset["id"] == row["ss_id"]].index[0], axis=1 From f8cc93dd4821189d7509330b544e5202d28d13a5 Mon Sep 17 00:00:00 2001 From: Sean McGuire Date: Tue, 16 Sep 2025 18:59:46 -0400 Subject: [PATCH 02/12] use suffix_method in join methods --- src/lsdb/catalog/catalog.py | 19 +++- .../abstract_crossmatch_algorithm.py | 9 +- src/lsdb/dask/crossmatch_catalog_data.py | 6 +- src/lsdb/dask/join_catalog_data.py | 93 +++++++++++++------ tests/lsdb/catalog/test_crossmatch.py | 7 +- 5 files changed, 92 insertions(+), 42 deletions(-) diff --git a/src/lsdb/catalog/catalog.py b/src/lsdb/catalog/catalog.py index 38ef5f83d..656742638 100644 --- a/src/lsdb/catalog/catalog.py +++ b/src/lsdb/catalog/catalog.py @@ -747,6 +747,7 @@ def merge_asof( direction: str = "backward", suffixes: tuple[str, str] | None = None, output_catalog_name: str | None = None, + suffix_method: str | None = None, ): """Uses the pandas `merge_asof` function to merge two catalogs on their indices by distance of keys @@ -760,6 +761,12 @@ def merge_asof( other (lsdb.Catalog): the right catalog to merge to suffixes (Tuple[str,str]): the suffixes to apply to each partition's column names direction (str): the direction to perform the merge_asof + output_catalog_name (str): The name of the resulting catalog to be stored in metadata + suffix_method (str): Method to use to add suffixes to columns. Options are: + - "overlapping_columns": only add suffixes to columns that are present in both catalogs + - "all_columns": add suffixes to all columns from both catalogs + Default: "all_columns" Warning: This default will change to "overlapping_columns" in a future + release. Returns: A new catalog with the columns from each of the input catalogs with their respective suffixes @@ -771,7 +778,7 @@ def merge_asof( if len(suffixes) != 2: raise ValueError("`suffixes` must be a tuple with two strings") - ddf, ddf_map, alignment = merge_asof_catalog_data(self, other, suffixes=suffixes, direction=direction) + ddf, ddf_map, alignment = merge_asof_catalog_data(self, other, suffixes=suffixes, direction=direction, suffix_method=suffix_method) if output_catalog_name is None: output_catalog_name = ( @@ -795,6 +802,7 @@ def join( through: AssociationCatalog | None = None, suffixes: tuple[str, str] | None = None, output_catalog_name: str | None = None, + suffix_method: str | None = None, ) -> Catalog: """Perform a spatial join to another catalog @@ -810,6 +818,11 @@ def join( between pixels and individual rows. suffixes (Tuple[str,str]): suffixes to apply to the columns of each table output_catalog_name (str): The name of the resulting catalog to be stored in metadata + suffix_method (str): Method to use to add suffixes to columns. Options are: + - "overlapping_columns": only add suffixes to columns that are present in both catalogs + - "all_columns": add suffixes to all columns from both catalogs + Default: "all_columns" Warning: This default will change to "overlapping_columns" in a future + release. Returns: A new catalog with the columns from each of the input catalogs with their respective suffixes @@ -824,7 +837,7 @@ def join( self._check_unloaded_columns([left_on, right_on]) if through is not None: - ddf, ddf_map, alignment = join_catalog_data_through(self, other, through, suffixes) + ddf, ddf_map, alignment = join_catalog_data_through(self, other, through, suffixes, suffix_method=suffix_method) else: if left_on is None or right_on is None: raise ValueError("Either both of left_on and right_on, or through must be set") @@ -832,7 +845,7 @@ def join( raise ValueError("left_on must be a column in the left catalog") if right_on not in other._ddf.columns: raise ValueError("right_on must be a column in the right catalog") - ddf, ddf_map, alignment = join_catalog_data_on(self, other, left_on, right_on, suffixes) + ddf, ddf_map, alignment = join_catalog_data_on(self, other, left_on, right_on, suffixes, suffix_method=suffix_method) if output_catalog_name is None: output_catalog_name = self.hc_structure.catalog_info.catalog_name diff --git a/src/lsdb/core/crossmatch/abstract_crossmatch_algorithm.py b/src/lsdb/core/crossmatch/abstract_crossmatch_algorithm.py index 151f45fff..94c5bde62 100644 --- a/src/lsdb/core/crossmatch/abstract_crossmatch_algorithm.py +++ b/src/lsdb/core/crossmatch/abstract_crossmatch_algorithm.py @@ -10,7 +10,7 @@ from hats.catalog import TableProperties from hats.pixel_math.spatial_index import SPATIAL_INDEX_COLUMN -from lsdb.dask.merge_catalog_functions import apply_suffix_all_columns +from lsdb.dask.merge_catalog_functions import apply_suffix_all_columns, get_suffix_function if TYPE_CHECKING: from lsdb.catalog import Catalog @@ -98,14 +98,14 @@ def __init__( self.right_catalog_info = right_catalog_info self.right_margin_catalog_info = right_margin_catalog_info - def crossmatch(self, suffixes, suffix_function=apply_suffix_all_columns, **kwargs) -> npd.NestedFrame: + def crossmatch(self, suffixes, suffix_method="all_columns", **kwargs) -> npd.NestedFrame: """Perform a crossmatch""" l_inds, r_inds, extra_cols = self.perform_crossmatch(**kwargs) if not len(l_inds) == len(r_inds) == len(extra_cols): raise ValueError( "Crossmatch algorithm must return left and right indices and extra columns with same length" ) - return self._create_crossmatch_df(l_inds, r_inds, extra_cols, suffixes, suffix_function) + return self._create_crossmatch_df(l_inds, r_inds, extra_cols, suffixes, suffix_method) def crossmatch_nested(self, nested_column_name, **kwargs) -> npd.NestedFrame: """Perform a crossmatch""" @@ -198,7 +198,7 @@ def _create_crossmatch_df( right_idx: npt.NDArray[np.int64], extra_cols: pd.DataFrame, suffixes: tuple[str, str], - suffix_function=apply_suffix_all_columns, + suffix_method="all_columns", ) -> npd.NestedFrame: """Creates a df containing the crossmatch result from matching indices and additional columns @@ -212,6 +212,7 @@ def _create_crossmatch_df( additional columns added """ # rename columns so no same names during merging + suffix_function = get_suffix_function(suffix_method) self.left, self.right = suffix_function(self.left, self.right, suffixes) # concat dataframes together index_name = self.left.index.name if self.left.index.name is not None else "index" diff --git a/src/lsdb/dask/crossmatch_catalog_data.py b/src/lsdb/dask/crossmatch_catalog_data.py index 0544d7f46..919bea5e3 100644 --- a/src/lsdb/dask/crossmatch_catalog_data.py +++ b/src/lsdb/dask/crossmatch_catalog_data.py @@ -41,7 +41,7 @@ def perform_crossmatch( right_margin_catalog_info, algorithm, suffixes, - suffix_function, + suffix_method, meta_df, **kwargs, ): @@ -68,7 +68,7 @@ def perform_crossmatch( left_catalog_info, right_catalog_info, right_margin_catalog_info, - ).crossmatch(suffixes, suffix_function=suffix_function, **kwargs) + ).crossmatch(suffixes, suffix_method=suffix_method, **kwargs) # pylint: disable=too-many-arguments, unused-argument @@ -174,7 +174,7 @@ def crossmatch_catalog_data( perform_crossmatch, crossmatch_algorithm, suffixes, - suffix_function, + suffix_method, meta_df, **kwargs, ) diff --git a/src/lsdb/dask/join_catalog_data.py b/src/lsdb/dask/join_catalog_data.py index 1d5fc79dd..63ed1b607 100644 --- a/src/lsdb/dask/join_catalog_data.py +++ b/src/lsdb/dask/join_catalog_data.py @@ -25,6 +25,7 @@ generate_meta_df_for_nested_tables, get_healpix_pixels_from_alignment, get_healpix_pixels_from_association, + get_suffix_function, ) from lsdb.types import DaskDFPixelMap @@ -35,24 +36,6 @@ NON_JOINING_ASSOCIATION_COLUMNS = ["Norder", "Dir", "Npix", "join_Norder", "join_Dir", "join_Npix"] -def rename_columns_with_suffixes(left: npd.NestedFrame, right: npd.NestedFrame, suffixes: tuple[str, str]): - """Renames two dataframes with the suffixes specified - - Args: - left (npd.NestedFrame): the left dataframe to apply the first suffix to - right (npd.NestedFrame): the right dataframe to apply the second suffix to - suffixes (Tuple[str, str]): the pair of suffixes to apply to the dataframes - - Returns: - A tuple of (left, right) updated dataframes with their columns renamed - """ - left_columns_renamed = {name: name + suffixes[0] for name in left.columns} - left = left.rename(columns=left_columns_renamed) - right_columns_renamed = {name: name + suffixes[1] for name in right.columns} - right = right.rename(columns=right_columns_renamed) - return left, right - - # pylint: disable=too-many-arguments, unused-argument def perform_join_on( left: npd.NestedFrame, @@ -67,6 +50,7 @@ def perform_join_on( left_on: str, right_on: str, suffixes: tuple[str, str], + suffix_method: str | None = None, ): """Performs a join on two catalog partitions @@ -83,6 +67,10 @@ def perform_join_on( left_on (str): the column to join on from the left partition right_on (str): the column to join on from the right partition suffixes (Tuple[str,str]): the suffixes to apply to each partition's column names + suffix_method (str): Method to use to add suffixes to columns. Options are: + - "overlapping_columns": only add suffixes to columns that are present in both catalogs + - "all_columns": add suffixes to all columns from both catalogs + Default: "all_columns" Returns: A dataframe with the result of merging the left and right partitions on the specified columns @@ -92,7 +80,8 @@ def perform_join_on( right_joined_df = concat_partition_and_margin(right, right_margin) - left, right_joined_df = rename_columns_with_suffixes(left, right_joined_df, suffixes) + suffix_function = get_suffix_function(suffix_method) + left, right_joined_df = suffix_function(left, right_joined_df, suffixes) merged = left.reset_index().merge( right_joined_df, left_on=left_on + suffixes[0], right_on=right_on + suffixes[1] ) @@ -162,6 +151,7 @@ def perform_join_through( right_margin_catalog_info: TableProperties, assoc_catalog_info: TableProperties, suffixes: tuple[str, str], + suffix_method: str | None = None, ): """Performs a join on two catalog partitions through an association catalog @@ -180,6 +170,10 @@ def perform_join_through( catalog assoc_catalog_info (hc.TableProperties): the hats structure of the association catalog suffixes (Tuple[str,str]): the suffixes to apply to each partition's column names + suffix_method (str): Method to use to add suffixes to columns. Options are: + - "overlapping_columns": only add suffixes to columns that are present in both catalogs + - "all_columns": add suffixes to all columns from both catalogs + Default: "all_columns" Returns: A dataframe with the result of merging the left and right partitions on the specified columns @@ -191,7 +185,8 @@ def perform_join_through( right_joined_df = concat_partition_and_margin(right, right_margin) - left, right_joined_df = rename_columns_with_suffixes(left, right_joined_df, suffixes) + suffix_function = get_suffix_function(suffix_method) + left, right_joined_df = suffix_function(left, right_joined_df, suffixes) # Edge case: if right_column + suffix == join_column_association, columns will be in the wrong order # so rename association column @@ -241,6 +236,7 @@ def perform_merge_asof( right_catalog_info: TableProperties, suffixes: tuple[str, str], direction: str, + suffix_method: str | None = None, ): """Performs a merge_asof on two catalog partitions @@ -253,6 +249,10 @@ def perform_merge_asof( right_catalog_info (hc.TableProperties): the catalog info of the right catalog suffixes (Tuple[str,str]): the suffixes to apply to each partition's column names direction (str): The direction to perform the merge_asof + suffix_method (str): Method to use to add suffixes to columns. Options are: + - "overlapping_columns": only add suffixes to columns that are present in both catalogs + - "all_columns": add suffixes to all columns from both catalogs + Default: "all_columns" Returns: A dataframe with the result of merging the left and right partitions on the specified columns with @@ -261,7 +261,8 @@ def perform_merge_asof( if right_pixel.order > left_pixel.order: left = filter_by_spatial_index_to_pixel(left, right_pixel.order, right_pixel.pixel) - left, right = rename_columns_with_suffixes(left, right, suffixes) + suffix_function = get_suffix_function(suffix_method) + left, right = suffix_function(left, right, suffixes) left.sort_index(inplace=True) right.sort_index(inplace=True) merged = pd.merge_asof(left, right, left_index=True, right_index=True, direction=direction) @@ -269,7 +270,12 @@ def perform_merge_asof( def join_catalog_data_on( - left: Catalog, right: Catalog, left_on: str, right_on: str, suffixes: tuple[str, str] + left: Catalog, + right: Catalog, + left_on: str, + right_on: str, + suffixes: tuple[str, str], + suffix_method: str | None = None, ) -> tuple[nd.NestedFrame, DaskDFPixelMap, PixelAlignment]: """Joins two catalogs spatially on a specified column @@ -279,6 +285,11 @@ def join_catalog_data_on( left_on (str): the column to join on from the left partition right_on (str): the column to join on from the right partition suffixes (Tuple[str,str]): the suffixes to apply to each partition's column names + suffix_method (str): Method to use to add suffixes to columns. Options are: + - "overlapping_columns": only add suffixes to columns that are present in both catalogs + - "all_columns": add suffixes to all columns from both catalogs + Default: "all_columns" Warning: This default will change to "overlapping_columns" in a future + release. Returns: A tuple of the dask dataframe with the result of the join, the pixel map from HEALPix @@ -301,9 +312,11 @@ def join_catalog_data_on( left_on, right_on, suffixes, + suffix_method, ) - meta_df = generate_meta_df_for_joined_tables([left, right], suffixes) + suffix_function = get_suffix_function(suffix_method) + meta_df = generate_meta_df_for_joined_tables([left, right], suffixes, suffix_function=suffix_function) return construct_catalog_args(joined_partitions, meta_df, alignment) @@ -358,7 +371,11 @@ def join_catalog_data_nested( def join_catalog_data_through( - left: Catalog, right: Catalog, association: AssociationCatalog, suffixes: tuple[str, str] + left: Catalog, + right: Catalog, + association: AssociationCatalog, + suffixes: tuple[str, str], + suffix_method: str | None = None, ) -> tuple[nd.NestedFrame, DaskDFPixelMap, PixelAlignment]: """Joins two catalogs with an association table @@ -367,6 +384,11 @@ def join_catalog_data_through( right (lsdb.Catalog): the right catalog to join association (AssociationCatalog): the association catalog to join the catalogs with suffixes (Tuple[str,str]): the suffixes to apply to each partition's column names + suffix_method (str): Method to use to add suffixes to columns. Options are: + - "overlapping_columns": only add suffixes to columns that are present in both catalogs + - "all_columns": add suffixes to all columns from both catalogs + Default: "all_columns" Warning: This default will change to "overlapping_columns" in a future + release. Returns: A tuple of the dask dataframe with the result of the join, the pixel map from HEALPix @@ -387,6 +409,7 @@ def join_catalog_data_through( association.hc_structure.catalog_info.primary_column, association.hc_structure.catalog_info.join_column, suffixes, + suffix_method=suffix_method, ) if right.margin is None: @@ -420,6 +443,7 @@ def join_catalog_data_through( ], perform_join_through, suffixes, + suffix_method, ) association_join_columns = [ @@ -430,13 +454,20 @@ def join_catalog_data_through( # pylint: disable=protected-access extra_df = association._ddf._meta.drop(non_joining_columns + association_join_columns, axis=1) - meta_df = generate_meta_df_for_joined_tables([left, extra_df, right], [suffixes[0], "", suffixes[1]]) + suffix_function = get_suffix_function(suffix_method) + meta_df = generate_meta_df_for_joined_tables( + [left, right], [suffixes[0], suffixes[1]], extra_columns=extra_df, suffix_function=suffix_function + ) return construct_catalog_args(joined_partitions, meta_df, alignment) def merge_asof_catalog_data( - left: Catalog, right: Catalog, suffixes: tuple[str, str], direction: str = "backward" + left: Catalog, + right: Catalog, + suffixes: tuple[str, str], + direction: str = "backward", + suffix_method: str | None = None, ) -> tuple[nd.NestedFrame, DaskDFPixelMap, PixelAlignment]: """Uses the pandas `merge_asof` function to merge two catalogs on their indices by distance of keys @@ -451,6 +482,11 @@ def merge_asof_catalog_data( right (lsdb.Catalog): the right catalog to join suffixes (Tuple[str,str]): the suffixes to apply to each partition's column names direction (str): the direction to perform the merge_asof + suffix_method (str): Method to use to add suffixes to columns. Options are: + - "overlapping_columns": only add suffixes to columns that are present in both catalogs + - "all_columns": add suffixes to all columns from both catalogs + Default: "all_columns" Warning: This default will change to "overlapping_columns" in a future + release. Returns: A tuple of the dask dataframe with the result of the join, the pixel map from HEALPix @@ -463,9 +499,10 @@ def merge_asof_catalog_data( left_pixels, right_pixels = get_healpix_pixels_from_alignment(alignment) joined_partitions = align_and_apply( - [(left, left_pixels), (right, right_pixels)], perform_merge_asof, suffixes, direction + [(left, left_pixels), (right, right_pixels)], perform_merge_asof, suffixes, direction, suffix_method ) - meta_df = generate_meta_df_for_joined_tables([left, right], suffixes) + suffix_function = get_suffix_function(suffix_method) + meta_df = generate_meta_df_for_joined_tables([left, right], suffixes, suffix_function=suffix_function) return construct_catalog_args(joined_partitions, meta_df, alignment) diff --git a/tests/lsdb/catalog/test_crossmatch.py b/tests/lsdb/catalog/test_crossmatch.py index 6da5cc21e..ad41d71a7 100644 --- a/tests/lsdb/catalog/test_crossmatch.py +++ b/tests/lsdb/catalog/test_crossmatch.py @@ -12,7 +12,7 @@ from lsdb.core.crossmatch.abstract_crossmatch_algorithm import AbstractCrossmatchAlgorithm from lsdb.core.crossmatch.bounded_kdtree_match import BoundedKdTreeCrossmatch from lsdb.core.crossmatch.kdtree_match import KdTreeCrossmatch -from lsdb.dask.merge_catalog_functions import align_catalogs, apply_suffix_all_columns +from lsdb.dask.merge_catalog_functions import align_catalogs, get_suffix_function @pytest.mark.parametrize("algo", [KdTreeCrossmatch]) @@ -440,11 +440,10 @@ class MockCrossmatchAlgorithmOverwrite(AbstractCrossmatchAlgorithm): extra_columns = pd.DataFrame({"_DIST": pd.Series(dtype=np.float64)}) - def crossmatch( - self, suffixes, suffix_function=apply_suffix_all_columns, mock_results: pd.DataFrame = None, **kwargs - ): + def crossmatch(self, suffixes, suffix_method="all_columns", mock_results: pd.DataFrame = None, **kwargs): left_reset = self.left.reset_index(drop=True) right_reset = self.right.reset_index(drop=True) + suffix_function = get_suffix_function(suffix_method) self.left, self.right = suffix_function(self.left, self.right, suffixes) mock_results = mock_results[mock_results["ss_id"].isin(left_reset["id"].to_numpy())] left_indexes = mock_results.apply( From 844a23cc519298396eef78c08078e4807079232c Mon Sep 17 00:00:00 2001 From: Sean McGuire Date: Wed, 17 Sep 2025 17:39:00 -0400 Subject: [PATCH 03/12] fix behavior and unit test --- src/lsdb/dask/join_catalog_data.py | 37 +++++++++-- src/lsdb/dask/merge_catalog_functions.py | 38 ++++++++++- tests/conftest.py | 28 +++++++- tests/lsdb/catalog/test_crossmatch.py | 25 ++++++++ tests/lsdb/catalog/test_join.py | 81 ++++++++++++++++++++++++ 5 files changed, 202 insertions(+), 7 deletions(-) diff --git a/src/lsdb/dask/join_catalog_data.py b/src/lsdb/dask/join_catalog_data.py index 63ed1b607..3f6baed4c 100644 --- a/src/lsdb/dask/join_catalog_data.py +++ b/src/lsdb/dask/join_catalog_data.py @@ -26,6 +26,8 @@ get_healpix_pixels_from_alignment, get_healpix_pixels_from_association, get_suffix_function, + apply_left_suffix, + apply_right_suffix, ) from lsdb.types import DaskDFPixelMap @@ -82,9 +84,17 @@ def perform_join_on( suffix_function = get_suffix_function(suffix_method) left, right_joined_df = suffix_function(left, right_joined_df, suffixes) - merged = left.reset_index().merge( - right_joined_df, left_on=left_on + suffixes[0], right_on=right_on + suffixes[1] + + left_join_column = ( + left_on if left_on in left.columns else apply_left_suffix(left_on, suffix_function, suffixes) + ) + right_join_column = ( + right_on + if right_on in right_joined_df.columns + else apply_right_suffix(right_on, suffix_function, suffixes) ) + + merged = left.reset_index().merge(right_joined_df, left_on=left_join_column, right_on=right_join_column) merged.set_index(SPATIAL_INDEX_COLUMN, inplace=True) return merged @@ -188,6 +198,17 @@ def perform_join_through( suffix_function = get_suffix_function(suffix_method) left, right_joined_df = suffix_function(left, right_joined_df, suffixes) + left_join_column = ( + assoc_catalog_info.primary_column + if assoc_catalog_info.primary_column in left.columns + else apply_left_suffix(assoc_catalog_info.primary_column, suffix_function, suffixes) + ) + right_join_column = ( + assoc_catalog_info.join_column + if assoc_catalog_info.join_column in right_joined_df.columns + else apply_right_suffix(assoc_catalog_info.join_column, suffix_function, suffixes) + ) + # Edge case: if right_column + suffix == join_column_association, columns will be in the wrong order # so rename association column join_column_association = assoc_catalog_info.join_column_association @@ -197,8 +218,9 @@ def perform_join_through( columns={assoc_catalog_info.join_column_association: join_column_association}, inplace=True ) + join_columns = [assoc_catalog_info.primary_column_association, join_column_association] join_columns_to_drop = [] - for c in [assoc_catalog_info.primary_column_association, join_column_association]: + for c in join_columns: if c not in left.columns and c not in right_joined_df.columns and c not in join_columns_to_drop: join_columns_to_drop.append(c) @@ -210,16 +232,21 @@ def perform_join_through( left.reset_index() .merge( through, - left_on=assoc_catalog_info.primary_column + suffixes[0], + left_on=left_join_column, right_on=assoc_catalog_info.primary_column_association, ) .merge( right_joined_df, left_on=join_column_association, - right_on=assoc_catalog_info.join_column + suffixes[1], + right_on=right_join_column, ) ) + extra_join_cols = through.columns.drop(join_columns + cols_to_drop) + other_cols = merged.columns.drop(extra_join_cols) + + merged = merged[other_cols.append(extra_join_cols)] + merged.set_index(SPATIAL_INDEX_COLUMN, inplace=True) if len(join_columns_to_drop) > 0: merged.drop(join_columns_to_drop, axis=1, inplace=True) diff --git a/src/lsdb/dask/merge_catalog_functions.py b/src/lsdb/dask/merge_catalog_functions.py index 2d2d19b6a..f0ec77800 100644 --- a/src/lsdb/dask/merge_catalog_functions.py +++ b/src/lsdb/dask/merge_catalog_functions.py @@ -35,6 +35,8 @@ ASSOC_NORDER = "assoc_Norder" ASSOC_NPIX = "assoc_Npix" +DEFAULT_SUFFIX_METHOD: Literal["all_columns", "overlapping_columns"] = "all_columns" + def apply_suffix_all_columns( left_df: npd.NestedFrame, right_df: npd.NestedFrame, suffixes: tuple[str, str] @@ -99,7 +101,7 @@ def get_suffix_function( A function that takes in two dataframes and returns a tuple of the two dataframes with the suffixes applied """ if suffix_method is None: - suffix_method = "all_columns" + suffix_method = DEFAULT_SUFFIX_METHOD warnings.warn( "The default suffix behavior will change from applying suffixes to all columns to only applying suffixes to overlapping columns in a future release." "To maintain the current behavior, explicitly set `suffix_method='all_columns'`. To change to the new behavior, set `suffix_method='overlapping_columns'`.", @@ -115,6 +117,40 @@ def get_suffix_function( return suffix_functions[suffix_method] +def apply_left_suffix(col_name: str, suffix_function: Callable, suffixes: tuple[str, str]) -> str: + """Applies the left suffix to a column name using the specified suffix function + + Args: + col_name (str): The column name to apply the suffix to + suffix_function (Callable): The function to use to apply the suffix + suffixes (tuple[str, str]): The suffixes to apply to the left and right dataframes + + Returns: + The column name with the left suffix applied + """ + left_df = npd.NestedFrame(columns=[col_name]) + right_df = npd.NestedFrame(columns=[col_name]) + left_df, _ = suffix_function(left_df, right_df, suffixes) + return left_df.columns[0] + + +def apply_right_suffix(col_name: str, suffix_function: Callable, suffixes: tuple[str, str]) -> str: + """Applies the right suffix to a column name using the specified suffix function + + Args: + col_name (str): The column name to apply the suffix to + suffix_function (Callable): The function to use to apply the suffix + suffixes (tuple[str, str]): The suffixes to apply to the left and right dataframes + + Returns: + The column name with the right suffix applied + """ + left_df = npd.NestedFrame(columns=[col_name]) + right_df = npd.NestedFrame(columns=[col_name]) + _, right_df = suffix_function(left_df, right_df, suffixes) + return right_df.columns[0] + + def concat_partition_and_margin( partition: npd.NestedFrame, margin: npd.NestedFrame | None ) -> npd.NestedFrame: diff --git a/tests/conftest.py b/tests/conftest.py index fc3d6e994..7c52f14e1 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -17,6 +17,7 @@ import lsdb import lsdb.nested as nd +from lsdb.dask.merge_catalog_functions import DEFAULT_SUFFIX_METHOD DATA_DIR_NAME = "data" SMALL_SKY_DIR_NAME = "small_sky" @@ -452,13 +453,38 @@ def assert_default_columns_in_columns(cat): for col in cat.hc_structure.catalog_info.default_columns: assert col in cat._ddf.columns + @classmethod + def assert_columns_in_joined_catalog( + cls, joined_cat, cats, suffixes, suffix_method=DEFAULT_SUFFIX_METHOD + ): + assert_methods = { + "all_columns": cls.assert_all_suffix_columns_in_joined_catalog, + "overlapping_columns": cls.assert_overlapping_suffix_columns_in_joined_catalog, + } + assert_method = assert_methods.get(suffix_method) + if assert_method is None: + raise ValueError(f"Unknown suffix_strategy: {suffix_method}") + assert_method(joined_cat, cats, suffixes) + @staticmethod - def assert_columns_in_joined_catalog(joined_cat, cats, suffixes): + def assert_all_suffix_columns_in_joined_catalog(joined_cat, cats, suffixes): for cat, suffix in zip(cats, suffixes): for col_name, dtype in cat.dtypes.items(): if col_name not in paths.HIVE_COLUMNS: assert (col_name + suffix, dtype) in joined_cat.dtypes.items() + @staticmethod + def assert_overlapping_suffix_columns_in_joined_catalog(joined_cat, cats, suffixes): + cat_columns = [set(cat.columns) for cat in cats] + overlapping_columns = cat_columns[0].intersection(*cat_columns[1:]) + for cat, suffix in zip(cats, suffixes): + for col_name, dtype in cat.dtypes.items(): + if col_name not in paths.HIVE_COLUMNS: + if col_name in overlapping_columns: + assert (col_name + suffix, dtype) in joined_cat.dtypes.items() + else: + assert (col_name, dtype) in joined_cat.dtypes.items() + @staticmethod def assert_columns_in_nested_joined_catalog( joined_cat, left_cat, right_cat, right_ignore_columns, nested_colname diff --git a/tests/lsdb/catalog/test_crossmatch.py b/tests/lsdb/catalog/test_crossmatch.py index ad41d71a7..e8a505204 100644 --- a/tests/lsdb/catalog/test_crossmatch.py +++ b/tests/lsdb/catalog/test_crossmatch.py @@ -215,6 +215,31 @@ def test_crossmatch_negative_margin( assert len(xmatch_row) == 1 assert xmatch_row["_dist_arcsec"].to_numpy() == pytest.approx(correct_row["dist"] * 3600) + @staticmethod + def test_overlapping_suffix_method(algo, small_sky_catalog, small_sky_xmatch_catalog): + suffixes = ("_left", "_right") + xmatched = small_sky_catalog.crossmatch( + small_sky_xmatch_catalog, + algorithm=algo, + suffix_method="overlapping_columns", + suffixes=suffixes, + ) + computed = xmatched.compute() + for col in small_sky_catalog.columns: + if col in small_sky_xmatch_catalog.columns: + assert f"{col}{suffixes[0]}" in xmatched.columns + assert f"{col}{suffixes[0]}" in computed.columns + else: + assert col in xmatched.columns + assert col in computed.columns + for col in small_sky_xmatch_catalog.columns: + if col in small_sky_catalog.columns: + assert f"{col}{suffixes[1]}" in xmatched.columns + assert f"{col}{suffixes[1]}" in computed.columns + else: + assert col in xmatched.columns + assert col in computed.columns + @staticmethod def test_wrong_suffixes(algo, small_sky_catalog, small_sky_xmatch_catalog): with pytest.raises(ValueError): diff --git a/tests/lsdb/catalog/test_join.py b/tests/lsdb/catalog/test_join.py index 7f392e716..43bbfd80a 100644 --- a/tests/lsdb/catalog/test_join.py +++ b/tests/lsdb/catalog/test_join.py @@ -40,6 +40,33 @@ def test_small_sky_join_small_sky_order1(small_sky_catalog, small_sky_order1_cat assert not joined.hc_structure.on_disk +def test_small_sky_join_overlapping_suffix(small_sky_catalog, small_sky_order1_catalog, helpers): + suffixes = ("_a", "_b") + with pytest.warns(match="margin"): + joined = small_sky_catalog.join( + small_sky_order1_catalog, + left_on="id", + right_on="id", + suffixes=suffixes, + suffix_method="overlapping_columns", + ) + assert isinstance(joined._ddf, nd.NestedFrame) + helpers.assert_columns_in_joined_catalog( + joined, [small_sky_catalog, small_sky_order1_catalog], suffixes, suffix_method="overlapping_columns" + ) + + joined_compute = joined.compute() + + helpers.assert_columns_in_joined_catalog( + joined_compute, + [small_sky_catalog, small_sky_order1_catalog], + suffixes, + suffix_method="overlapping_columns", + ) + helpers.assert_divisions_are_correct(joined) + helpers.assert_schema_correct(joined) + + def test_small_sky_join_small_sky_order1_source( small_sky_catalog, small_sky_order1_source_with_margin, helpers ): @@ -147,6 +174,35 @@ def test_join_association( assert joined_row.index == left_row.index +def test_join_association_overlapping_suffix( + small_sky_catalog, small_sky_order1_source_collection_catalog, small_sky_to_o1source_catalog, helpers +): + suffixes = ("_a", "_b") + joined = small_sky_catalog.join( + small_sky_order1_source_collection_catalog, + through=small_sky_to_o1source_catalog, + suffixes=suffixes, + suffix_method="overlapping_columns", + ) + helpers.assert_columns_in_joined_catalog( + joined, + [small_sky_catalog, small_sky_order1_source_collection_catalog], + suffixes, + suffix_method="overlapping_columns", + ) + + joined_compute = joined.compute() + + helpers.assert_columns_in_joined_catalog( + joined_compute, + [small_sky_catalog, small_sky_order1_source_collection_catalog], + suffixes, + suffix_method="overlapping_columns", + ) + helpers.assert_divisions_are_correct(joined) + helpers.assert_schema_correct(joined) + + def test_join_association_suffix_edge_case( small_sky_catalog, small_sky_order1_source_collection_catalog, small_sky_to_o1source_catalog ): @@ -263,6 +319,31 @@ def test_merge_asof(small_sky_catalog, small_sky_xmatch_catalog, direction, help pd.testing.assert_frame_equal(joined_compute.drop(columns=drop_cols), correct_result) +def test_merge_asof_overlapping_suffix(small_sky_catalog, small_sky_xmatch_catalog, helpers): + suffixes = ("_a", "_b") + joined = small_sky_catalog.merge_asof( + small_sky_xmatch_catalog, direction="backward", suffixes=suffixes, suffix_method="overlapping_columns" + ) + helpers.assert_columns_in_joined_catalog( + joined, + [small_sky_catalog, small_sky_xmatch_catalog], + suffixes, + suffix_method="overlapping_columns", + ) + helpers.assert_divisions_are_correct(joined) + + joined_compute = joined.compute() + + helpers.assert_columns_in_joined_catalog( + joined_compute, + [small_sky_catalog, small_sky_xmatch_catalog], + suffixes, + suffix_method="overlapping_columns", + ) + helpers.assert_divisions_are_correct(joined) + helpers.assert_schema_correct(joined) + + def merging_function(input_frame, map_input, *args, **kwargs): if len(input_frame) == 0: ## this is the empty call to infer meta From 23df7a7428c7b9076999b6a9e88573a5367989b0 Mon Sep 17 00:00:00 2001 From: Sean McGuire Date: Wed, 17 Sep 2025 17:48:05 -0400 Subject: [PATCH 04/12] lint --- pyproject.toml | 1 + src/lsdb/catalog/catalog.py | 12 +++++++++--- .../crossmatch/abstract_crossmatch_algorithm.py | 2 +- src/lsdb/dask/join_catalog_data.py | 12 ++++++------ src/lsdb/dask/merge_catalog_functions.py | 17 ++++++++++------- 5 files changed, 27 insertions(+), 17 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 0f12629b3..8fc82038a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -45,6 +45,7 @@ dev = [ "pytest", "pytest-cov", # Used to report total code coverage "pytest-mock", # Used to mock objects in tests + "types-tabulate", # Type information for tabulate ] full = [ "fsspec[full]", # complete file system specs. diff --git a/src/lsdb/catalog/catalog.py b/src/lsdb/catalog/catalog.py index 656742638..e687cd5bb 100644 --- a/src/lsdb/catalog/catalog.py +++ b/src/lsdb/catalog/catalog.py @@ -778,7 +778,9 @@ def merge_asof( if len(suffixes) != 2: raise ValueError("`suffixes` must be a tuple with two strings") - ddf, ddf_map, alignment = merge_asof_catalog_data(self, other, suffixes=suffixes, direction=direction, suffix_method=suffix_method) + ddf, ddf_map, alignment = merge_asof_catalog_data( + self, other, suffixes=suffixes, direction=direction, suffix_method=suffix_method + ) if output_catalog_name is None: output_catalog_name = ( @@ -837,7 +839,9 @@ def join( self._check_unloaded_columns([left_on, right_on]) if through is not None: - ddf, ddf_map, alignment = join_catalog_data_through(self, other, through, suffixes, suffix_method=suffix_method) + ddf, ddf_map, alignment = join_catalog_data_through( + self, other, through, suffixes, suffix_method=suffix_method + ) else: if left_on is None or right_on is None: raise ValueError("Either both of left_on and right_on, or through must be set") @@ -845,7 +849,9 @@ def join( raise ValueError("left_on must be a column in the left catalog") if right_on not in other._ddf.columns: raise ValueError("right_on must be a column in the right catalog") - ddf, ddf_map, alignment = join_catalog_data_on(self, other, left_on, right_on, suffixes, suffix_method=suffix_method) + ddf, ddf_map, alignment = join_catalog_data_on( + self, other, left_on, right_on, suffixes, suffix_method=suffix_method + ) if output_catalog_name is None: output_catalog_name = self.hc_structure.catalog_info.catalog_name diff --git a/src/lsdb/core/crossmatch/abstract_crossmatch_algorithm.py b/src/lsdb/core/crossmatch/abstract_crossmatch_algorithm.py index 94c5bde62..70c82bb5e 100644 --- a/src/lsdb/core/crossmatch/abstract_crossmatch_algorithm.py +++ b/src/lsdb/core/crossmatch/abstract_crossmatch_algorithm.py @@ -10,7 +10,7 @@ from hats.catalog import TableProperties from hats.pixel_math.spatial_index import SPATIAL_INDEX_COLUMN -from lsdb.dask.merge_catalog_functions import apply_suffix_all_columns, get_suffix_function +from lsdb.dask.merge_catalog_functions import get_suffix_function if TYPE_CHECKING: from lsdb.catalog import Catalog diff --git a/src/lsdb/dask/join_catalog_data.py b/src/lsdb/dask/join_catalog_data.py index 3f6baed4c..ddc9b1c60 100644 --- a/src/lsdb/dask/join_catalog_data.py +++ b/src/lsdb/dask/join_catalog_data.py @@ -18,6 +18,8 @@ align_and_apply, align_catalogs, align_catalogs_with_association, + apply_left_suffix, + apply_right_suffix, concat_partition_and_margin, construct_catalog_args, filter_by_spatial_index_to_pixel, @@ -26,8 +28,6 @@ get_healpix_pixels_from_alignment, get_healpix_pixels_from_association, get_suffix_function, - apply_left_suffix, - apply_right_suffix, ) from lsdb.types import DaskDFPixelMap @@ -146,7 +146,7 @@ def perform_join_nested( return merged -# pylint: disable=too-many-arguments, unused-argument +# pylint: disable=too-many-arguments, unused-argument, too-many-locals def perform_join_through( left: npd.NestedFrame, right: npd.NestedFrame, @@ -343,7 +343,7 @@ def join_catalog_data_on( ) suffix_function = get_suffix_function(suffix_method) - meta_df = generate_meta_df_for_joined_tables([left, right], suffixes, suffix_function=suffix_function) + meta_df = generate_meta_df_for_joined_tables((left, right), suffixes, suffix_function=suffix_function) return construct_catalog_args(joined_partitions, meta_df, alignment) @@ -483,7 +483,7 @@ def join_catalog_data_through( extra_df = association._ddf._meta.drop(non_joining_columns + association_join_columns, axis=1) suffix_function = get_suffix_function(suffix_method) meta_df = generate_meta_df_for_joined_tables( - [left, right], [suffixes[0], suffixes[1]], extra_columns=extra_df, suffix_function=suffix_function + (left, right), suffixes, extra_columns=extra_df, suffix_function=suffix_function ) return construct_catalog_args(joined_partitions, meta_df, alignment) @@ -530,6 +530,6 @@ def merge_asof_catalog_data( ) suffix_function = get_suffix_function(suffix_method) - meta_df = generate_meta_df_for_joined_tables([left, right], suffixes, suffix_function=suffix_function) + meta_df = generate_meta_df_for_joined_tables((left, right), suffixes, suffix_function=suffix_function) return construct_catalog_args(joined_partitions, meta_df, alignment) diff --git a/src/lsdb/dask/merge_catalog_functions.py b/src/lsdb/dask/merge_catalog_functions.py index f0ec77800..a925f079d 100644 --- a/src/lsdb/dask/merge_catalog_functions.py +++ b/src/lsdb/dask/merge_catalog_functions.py @@ -2,7 +2,7 @@ import logging import warnings -from typing import TYPE_CHECKING, Callable, Sequence, Literal +from typing import TYPE_CHECKING, Callable, Literal, Sequence import hats.pixel_math.healpix_shim as hp import nested_pandas as npd @@ -74,7 +74,6 @@ def apply_suffix_overlapping_columns( """ left_suffix, right_suffix = suffixes overlapping_columns = set(left_df.columns).intersection(set(right_df.columns)) - overlapping_columns = [c for c in overlapping_columns] left_df = left_df.rename(columns={c: c + left_suffix for c in overlapping_columns}) right_df = right_df.rename(columns={c: c + right_suffix for c in overlapping_columns}) @@ -84,7 +83,7 @@ def apply_suffix_overlapping_columns( tablefmt="pretty", ) - logging.info(f"Renaming overlapping columns:\n{table}") + logging.info("Renaming overlapping columns:\n%s", table) return left_df, right_df @@ -95,16 +94,20 @@ def get_suffix_function( """Gets a function that can be used to generate suffixes for columns based on a specified method Args: - suffix_method (str): The method to use to generate suffixes. Options are 'all_columns', 'overlapping_columns', + suffix_method (str): The method to use to generate suffixes. Options are 'all_columns', + 'overlapping_columns', Returns: - A function that takes in two dataframes and returns a tuple of the two dataframes with the suffixes applied + A function that takes in two dataframes and returns a tuple of the two dataframes with the suffixes + applied """ if suffix_method is None: suffix_method = DEFAULT_SUFFIX_METHOD warnings.warn( - "The default suffix behavior will change from applying suffixes to all columns to only applying suffixes to overlapping columns in a future release." - "To maintain the current behavior, explicitly set `suffix_method='all_columns'`. To change to the new behavior, set `suffix_method='overlapping_columns'`.", + "The default suffix behavior will change from applying suffixes to all columns to only " + "applying suffixes to overlapping columns in a future release." + "To maintain the current behavior, explicitly set `suffix_method='all_columns'`. " + "To change to the new behavior, set `suffix_method='overlapping_columns'`.", FutureWarning, ) From 5c2ad321aa33be32fc2a1079c44cf02926a2d5c8 Mon Sep 17 00:00:00 2001 From: Sean McGuire Date: Thu, 18 Sep 2025 16:29:24 -0400 Subject: [PATCH 05/12] add test of logging --- tests/lsdb/catalog/test_crossmatch.py | 24 +++++++++++++++++------- 1 file changed, 17 insertions(+), 7 deletions(-) diff --git a/tests/lsdb/catalog/test_crossmatch.py b/tests/lsdb/catalog/test_crossmatch.py index e8a505204..65d36aa27 100644 --- a/tests/lsdb/catalog/test_crossmatch.py +++ b/tests/lsdb/catalog/test_crossmatch.py @@ -1,3 +1,5 @@ +import logging + import nested_pandas as npd import numpy as np import pandas as pd @@ -216,19 +218,26 @@ def test_crossmatch_negative_margin( assert xmatch_row["_dist_arcsec"].to_numpy() == pytest.approx(correct_row["dist"] * 3600) @staticmethod - def test_overlapping_suffix_method(algo, small_sky_catalog, small_sky_xmatch_catalog): + def test_overlapping_suffix_method(algo, small_sky_catalog, small_sky_xmatch_catalog, caplog): suffixes = ("_left", "_right") - xmatched = small_sky_catalog.crossmatch( - small_sky_xmatch_catalog, - algorithm=algo, - suffix_method="overlapping_columns", - suffixes=suffixes, - ) + # Test that remaned columns are logged correctly + with caplog.at_level(logging.INFO): + xmatched = small_sky_catalog.crossmatch( + small_sky_xmatch_catalog, + algorithm=algo, + suffix_method="overlapping_columns", + suffixes=suffixes, + ) + + assert "Renaming overlapping columns" in caplog.text + computed = xmatched.compute() for col in small_sky_catalog.columns: if col in small_sky_xmatch_catalog.columns: assert f"{col}{suffixes[0]}" in xmatched.columns assert f"{col}{suffixes[0]}" in computed.columns + assert col in caplog.text + assert f"{col}{suffixes[0]}" in caplog.text else: assert col in xmatched.columns assert col in computed.columns @@ -236,6 +245,7 @@ def test_overlapping_suffix_method(algo, small_sky_catalog, small_sky_xmatch_cat if col in small_sky_catalog.columns: assert f"{col}{suffixes[1]}" in xmatched.columns assert f"{col}{suffixes[1]}" in computed.columns + assert f"{col}{suffixes[1]}" in caplog.text else: assert col in xmatched.columns assert col in computed.columns From 17d06146014a6bb136c950ac5c4db909674e0f8a Mon Sep 17 00:00:00 2001 From: Sean McGuire Date: Thu, 18 Sep 2025 16:54:30 -0400 Subject: [PATCH 06/12] add error test case --- tests/lsdb/catalog/test_crossmatch.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/tests/lsdb/catalog/test_crossmatch.py b/tests/lsdb/catalog/test_crossmatch.py index 65d36aa27..df3dff3fc 100644 --- a/tests/lsdb/catalog/test_crossmatch.py +++ b/tests/lsdb/catalog/test_crossmatch.py @@ -255,6 +255,11 @@ def test_wrong_suffixes(algo, small_sky_catalog, small_sky_xmatch_catalog): with pytest.raises(ValueError): small_sky_catalog.crossmatch(small_sky_xmatch_catalog, suffixes=("wrong",), algorithm=algo) + @staticmethod + def test_wrong_suffix_method(algo, small_sky_catalog, small_sky_xmatch_catalog): + with pytest.raises(ValueError, match="Invalid suffix method"): + small_sky_catalog.crossmatch(small_sky_xmatch_catalog, suffix_method="wrong", algorithm=algo) + @staticmethod def test_right_margin_missing(algo, small_sky_catalog, small_sky_xmatch_catalog): small_sky_xmatch_catalog.margin = None From 1d7ea01dbd1a31477da0f980058bae83217a2027 Mon Sep 17 00:00:00 2001 From: Sean McGuire <123987820+smcguire-cmu@users.noreply.github.com> Date: Fri, 19 Sep 2025 16:03:17 -0400 Subject: [PATCH 07/12] Update tests/lsdb/catalog/test_crossmatch.py Co-authored-by: Melissa DeLucchi <113376043+delucchi-cmu@users.noreply.github.com> --- tests/lsdb/catalog/test_crossmatch.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/lsdb/catalog/test_crossmatch.py b/tests/lsdb/catalog/test_crossmatch.py index df3dff3fc..3234c945e 100644 --- a/tests/lsdb/catalog/test_crossmatch.py +++ b/tests/lsdb/catalog/test_crossmatch.py @@ -220,7 +220,7 @@ def test_crossmatch_negative_margin( @staticmethod def test_overlapping_suffix_method(algo, small_sky_catalog, small_sky_xmatch_catalog, caplog): suffixes = ("_left", "_right") - # Test that remaned columns are logged correctly + # Test that renamed columns are logged correctly with caplog.at_level(logging.INFO): xmatched = small_sky_catalog.crossmatch( small_sky_xmatch_catalog, From 3656a82a7aaefa5e184702c811123594f531d126 Mon Sep 17 00:00:00 2001 From: Sean McGuire Date: Fri, 19 Sep 2025 16:40:39 -0400 Subject: [PATCH 08/12] update metadata with correct suffix behavior --- src/lsdb/catalog/catalog.py | 18 ++++++-- src/lsdb/dask/join_catalog_data.py | 26 +++-------- src/lsdb/dask/merge_catalog_functions.py | 59 +++++++++++++++--------- tests/lsdb/catalog/test_crossmatch.py | 46 ++++++++++++++++++ 4 files changed, 105 insertions(+), 44 deletions(-) diff --git a/src/lsdb/catalog/catalog.py b/src/lsdb/catalog/catalog.py index e687cd5bb..9fab054b4 100644 --- a/src/lsdb/catalog/catalog.py +++ b/src/lsdb/catalog/catalog.py @@ -277,7 +277,11 @@ def crossmatch( self, other, suffixes, algorithm=algorithm, suffix_method=suffix_method, **kwargs ) new_catalog_info = create_merged_catalog_info( - self.hc_structure.catalog_info, other.hc_structure.catalog_info, output_catalog_name, suffixes + self, + other, + output_catalog_name, + suffixes, + suffix_method, ) hc_catalog = self.hc_structure.__class__( new_catalog_info, alignment.pixel_tree, schema=get_arrow_schema(ddf), moc=alignment.moc @@ -789,7 +793,11 @@ def merge_asof( ) new_catalog_info = create_merged_catalog_info( - self.hc_structure.catalog_info, other.hc_structure.catalog_info, output_catalog_name, suffixes + self, + other, + output_catalog_name, + suffixes, + suffix_method, ) hc_catalog = hc.catalog.Catalog( new_catalog_info, alignment.pixel_tree, schema=get_arrow_schema(ddf), moc=alignment.moc @@ -857,7 +865,11 @@ def join( output_catalog_name = self.hc_structure.catalog_info.catalog_name new_catalog_info = create_merged_catalog_info( - self.hc_structure.catalog_info, other.hc_structure.catalog_info, output_catalog_name, suffixes + self, + other, + output_catalog_name, + suffixes, + suffix_method, ) hc_catalog = hc.catalog.Catalog( new_catalog_info, alignment.pixel_tree, schema=get_arrow_schema(ddf), moc=alignment.moc diff --git a/src/lsdb/dask/join_catalog_data.py b/src/lsdb/dask/join_catalog_data.py index ddc9b1c60..c6c155087 100644 --- a/src/lsdb/dask/join_catalog_data.py +++ b/src/lsdb/dask/join_catalog_data.py @@ -83,17 +83,10 @@ def perform_join_on( right_joined_df = concat_partition_and_margin(right, right_margin) suffix_function = get_suffix_function(suffix_method) + left_join_column = apply_left_suffix(left_on, right_joined_df.columns, suffixes, suffix_function) + right_join_column = apply_right_suffix(right_on, left.columns, suffixes, suffix_function) left, right_joined_df = suffix_function(left, right_joined_df, suffixes) - left_join_column = ( - left_on if left_on in left.columns else apply_left_suffix(left_on, suffix_function, suffixes) - ) - right_join_column = ( - right_on - if right_on in right_joined_df.columns - else apply_right_suffix(right_on, suffix_function, suffixes) - ) - merged = left.reset_index().merge(right_joined_df, left_on=left_join_column, right_on=right_join_column) merged.set_index(SPATIAL_INDEX_COLUMN, inplace=True) return merged @@ -196,18 +189,13 @@ def perform_join_through( right_joined_df = concat_partition_and_margin(right, right_margin) suffix_function = get_suffix_function(suffix_method) - left, right_joined_df = suffix_function(left, right_joined_df, suffixes) - - left_join_column = ( - assoc_catalog_info.primary_column - if assoc_catalog_info.primary_column in left.columns - else apply_left_suffix(assoc_catalog_info.primary_column, suffix_function, suffixes) + left_join_column = apply_left_suffix( + assoc_catalog_info.primary_column, right_joined_df.columns, suffixes, suffix_function ) - right_join_column = ( - assoc_catalog_info.join_column - if assoc_catalog_info.join_column in right_joined_df.columns - else apply_right_suffix(assoc_catalog_info.join_column, suffix_function, suffixes) + right_join_column = apply_right_suffix( + assoc_catalog_info.join_column, left.columns, suffixes, suffix_function ) + left, right_joined_df = suffix_function(left, right_joined_df, suffixes) # Edge case: if right_column + suffix == join_column_association, columns will be in the wrong order # so rename association column diff --git a/src/lsdb/dask/merge_catalog_functions.py b/src/lsdb/dask/merge_catalog_functions.py index a925f079d..08cf28e0b 100644 --- a/src/lsdb/dask/merge_catalog_functions.py +++ b/src/lsdb/dask/merge_catalog_functions.py @@ -83,7 +83,8 @@ def apply_suffix_overlapping_columns( tablefmt="pretty", ) - logging.info("Renaming overlapping columns:\n%s", table) + if overlapping_columns: + logging.info("Renaming overlapping columns:\n%s", table) return left_df, right_df @@ -120,35 +121,47 @@ def get_suffix_function( return suffix_functions[suffix_method] -def apply_left_suffix(col_name: str, suffix_function: Callable, suffixes: tuple[str, str]) -> str: +def apply_left_suffix( + col_name: str, + right_col_names: list[str], + suffixes: tuple[str, str], + suffix_function: Callable, +) -> str: """Applies the left suffix to a column name using the specified suffix function Args: col_name (str): The column name to apply the suffix to - suffix_function (Callable): The function to use to apply the suffix + right_col_names (list[str]): The list of column names in the right dataframe suffixes (tuple[str, str]): The suffixes to apply to the left and right dataframes + suffix_function (Callable): The function to use to apply the suffix Returns: The column name with the left suffix applied """ left_df = npd.NestedFrame(columns=[col_name]) - right_df = npd.NestedFrame(columns=[col_name]) + right_df = npd.NestedFrame(columns=right_col_names) left_df, _ = suffix_function(left_df, right_df, suffixes) return left_df.columns[0] -def apply_right_suffix(col_name: str, suffix_function: Callable, suffixes: tuple[str, str]) -> str: +def apply_right_suffix( + col_name: str, + left_col_names: list[str], + suffixes: tuple[str, str], + suffix_function: Callable, +) -> str: """Applies the right suffix to a column name using the specified suffix function Args: col_name (str): The column name to apply the suffix to - suffix_function (Callable): The function to use to apply the suffix + left_col_names (list[str]): The column names in the left dataframe suffixes (tuple[str, str]): The suffixes to apply to the left and right dataframes + suffix_function (Callable): The function to use to apply the suffix Returns: The column name with the right suffix applied """ - left_df = npd.NestedFrame(columns=[col_name]) + left_df = npd.NestedFrame(columns=left_col_names) right_df = npd.NestedFrame(columns=[col_name]) _, right_df = suffix_function(left_df, right_df, suffixes) return right_df.columns[0] @@ -812,7 +825,11 @@ def align_catalog_to_partitions( def create_merged_catalog_info( - left_info: TableProperties, right_info: TableProperties, updated_name: str, suffixes: tuple[str, str] + left: Catalog, + right: Catalog, + updated_name: str, + suffixes: tuple[str, str], + suffix_method: str | None = None, ) -> TableProperties: """Creates the catalog info of the resulting catalog from merging two catalogs @@ -820,24 +837,22 @@ def create_merged_catalog_info( catalog name, and sets the total rows to 0 Args: - left_info (TableProperties): The catalog_info of the left catalog - right_info (TableProperties): The catalog_info of the right catalog + left (Catalog): The left catalog being merged + right (Catalog): The right catalog being merged updated_name (str): The updated name of the catalog suffixes (tuple[str, str]): The suffixes of the catalogs in the merged result + suffix_method (str): The method used to generate suffixes. Options are 'all_columns', + 'overlapping_columns' + + Returns: + The catalog info of the resulting merged catalog """ - default_cols = ( - [c + suffixes[0] for c in left_info.default_columns] if left_info.default_columns is not None else [] - ) - default_cols = ( - default_cols + [c + suffixes[1] for c in right_info.default_columns] - if right_info.default_columns is not None - else default_cols - ) - default_cols_to_use = default_cols if len(default_cols) > 0 else None + suffix_function = get_suffix_function(suffix_method) + left_info = left.hc_structure.catalog_info return left_info.copy_and_update( catalog_name=updated_name, - ra_column=left_info.ra_column + suffixes[0], - dec_column=left_info.dec_column + suffixes[0], + ra_column=apply_left_suffix(left_info.ra_column, right.columns, suffixes, suffix_function), + dec_column=apply_left_suffix(left_info.dec_column, right.columns, suffixes, suffix_function), total_rows=0, - default_columns=default_cols_to_use, + default_columns=None, ) diff --git a/tests/lsdb/catalog/test_crossmatch.py b/tests/lsdb/catalog/test_crossmatch.py index 3234c945e..f181d1d6e 100644 --- a/tests/lsdb/catalog/test_crossmatch.py +++ b/tests/lsdb/catalog/test_crossmatch.py @@ -40,6 +40,11 @@ def test_kdtree_crossmatch(algo, small_sky_catalog, small_sky_xmatch_catalog, xm assert xmatch_row["id_small_sky_xmatch"].to_numpy() == correct_row["xmatch_id"] assert xmatch_row["_dist_arcsec"].to_numpy() == pytest.approx(correct_row["dist"] * 3600) + assert xmatched_cat.hc_structure.catalog_info.ra_column in xmatched_cat.columns + assert xmatched_cat.hc_structure.catalog_info.dec_column in xmatched_cat.columns + assert xmatched_cat.hc_structure.catalog_info.ra_column == "ra_small_sky" + assert xmatched_cat.hc_structure.catalog_info.dec_column == "dec_small_sky" + @staticmethod def test_kdtree_crossmatch_nested(algo, small_sky_catalog, small_sky_xmatch_catalog, xmatch_correct): with pytest.warns(RuntimeWarning, match="Results may be incomplete and/or inaccurate"): @@ -250,6 +255,47 @@ def test_overlapping_suffix_method(algo, small_sky_catalog, small_sky_xmatch_cat assert col in xmatched.columns assert col in computed.columns + assert xmatched.hc_structure.catalog_info.ra_column in xmatched.columns + assert xmatched.hc_structure.catalog_info.dec_column in xmatched.columns + assert xmatched.hc_structure.catalog_info.ra_column == "ra_left" + assert xmatched.hc_structure.catalog_info.dec_column == "dec_left" + + @staticmethod + def test_overlapping_suffix_method_no_overlaps(algo, small_sky_catalog, small_sky_xmatch_catalog, caplog): + suffixes = ("_left", "_right") + small_sky_catalog = small_sky_catalog.rename( + {col: f"{col}_unique" for col in small_sky_catalog.columns} + ) + small_sky_catalog.hc_structure.catalog_info.ra_column = ( + f"{small_sky_catalog.hc_structure.catalog_info.ra_column}_unique" + ) + small_sky_catalog.hc_structure.catalog_info.dec_column = ( + f"{small_sky_catalog.hc_structure.catalog_info.dec_column}_unique" + ) + # Test that renamed columns are logged correctly + with caplog.at_level(logging.INFO): + xmatched = small_sky_catalog.crossmatch( + small_sky_xmatch_catalog, + algorithm=algo, + suffix_method="overlapping_columns", + suffixes=suffixes, + ) + + assert len(caplog.text) == 0 + + computed = xmatched.compute() + for col in small_sky_catalog.columns: + assert col in xmatched.columns + assert col in computed.columns + for col in small_sky_xmatch_catalog.columns: + assert col in xmatched.columns + assert col in computed.columns + + assert xmatched.hc_structure.catalog_info.ra_column in xmatched.columns + assert xmatched.hc_structure.catalog_info.dec_column in xmatched.columns + assert xmatched.hc_structure.catalog_info.ra_column == "ra_unique" + assert xmatched.hc_structure.catalog_info.dec_column == "dec_unique" + @staticmethod def test_wrong_suffixes(algo, small_sky_catalog, small_sky_xmatch_catalog): with pytest.raises(ValueError): From d3a1d8f5539f9aaf5a393035c03940ea5a4511db Mon Sep 17 00:00:00 2001 From: Sean McGuire Date: Fri, 19 Sep 2025 17:13:39 -0400 Subject: [PATCH 09/12] remove indirection --- src/lsdb/dask/merge_catalog_functions.py | 12 +- tests/conftest.py | 32 ---- tests/lsdb/catalog/test_join.py | 199 +++++++++++++++++------ 3 files changed, 158 insertions(+), 85 deletions(-) diff --git a/src/lsdb/dask/merge_catalog_functions.py b/src/lsdb/dask/merge_catalog_functions.py index 08cf28e0b..087e185fd 100644 --- a/src/lsdb/dask/merge_catalog_functions.py +++ b/src/lsdb/dask/merge_catalog_functions.py @@ -112,13 +112,11 @@ def get_suffix_function( FutureWarning, ) - suffix_functions = { - "all_columns": apply_suffix_all_columns, - "overlapping_columns": apply_suffix_overlapping_columns, - } - if suffix_method not in suffix_functions: - raise ValueError(f"Invalid suffix method: {suffix_method}") - return suffix_functions[suffix_method] + if suffix_method == "all_columns": + return apply_suffix_all_columns + elif suffix_method == "overlapping_columns": + return apply_suffix_overlapping_columns + raise ValueError(f"Invalid suffix method: {suffix_method}") def apply_left_suffix( diff --git a/tests/conftest.py b/tests/conftest.py index 7c52f14e1..9d8d33674 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -453,38 +453,6 @@ def assert_default_columns_in_columns(cat): for col in cat.hc_structure.catalog_info.default_columns: assert col in cat._ddf.columns - @classmethod - def assert_columns_in_joined_catalog( - cls, joined_cat, cats, suffixes, suffix_method=DEFAULT_SUFFIX_METHOD - ): - assert_methods = { - "all_columns": cls.assert_all_suffix_columns_in_joined_catalog, - "overlapping_columns": cls.assert_overlapping_suffix_columns_in_joined_catalog, - } - assert_method = assert_methods.get(suffix_method) - if assert_method is None: - raise ValueError(f"Unknown suffix_strategy: {suffix_method}") - assert_method(joined_cat, cats, suffixes) - - @staticmethod - def assert_all_suffix_columns_in_joined_catalog(joined_cat, cats, suffixes): - for cat, suffix in zip(cats, suffixes): - for col_name, dtype in cat.dtypes.items(): - if col_name not in paths.HIVE_COLUMNS: - assert (col_name + suffix, dtype) in joined_cat.dtypes.items() - - @staticmethod - def assert_overlapping_suffix_columns_in_joined_catalog(joined_cat, cats, suffixes): - cat_columns = [set(cat.columns) for cat in cats] - overlapping_columns = cat_columns[0].intersection(*cat_columns[1:]) - for cat, suffix in zip(cats, suffixes): - for col_name, dtype in cat.dtypes.items(): - if col_name not in paths.HIVE_COLUMNS: - if col_name in overlapping_columns: - assert (col_name + suffix, dtype) in joined_cat.dtypes.items() - else: - assert (col_name, dtype) in joined_cat.dtypes.items() - @staticmethod def assert_columns_in_nested_joined_catalog( joined_cat, left_cat, right_cat, right_ignore_columns, nested_colname diff --git a/tests/lsdb/catalog/test_join.py b/tests/lsdb/catalog/test_join.py index 43bbfd80a..efe65ef02 100644 --- a/tests/lsdb/catalog/test_join.py +++ b/tests/lsdb/catalog/test_join.py @@ -1,4 +1,5 @@ import nested_pandas as npd +import numpy as np import pandas as pd import pyarrow as pa import pytest @@ -18,7 +19,21 @@ def test_small_sky_join_small_sky_order1(small_sky_catalog, small_sky_order1_cat small_sky_order1_catalog, left_on="id", right_on="id", suffixes=suffixes ) assert isinstance(joined._ddf, nd.NestedFrame) - helpers.assert_columns_in_joined_catalog(joined, [small_sky_catalog, small_sky_order1_catalog], suffixes) + + expected_columns = [ + "id_a", + "ra_a", + "dec_a", + "ra_error_a", + "dec_error_a", + "id_b", + "ra_b", + "dec_b", + "ra_error_b", + "dec_error_b", + ] + + assert np.all(joined.columns == expected_columns) assert joined._ddf.index.name == SPATIAL_INDEX_COLUMN assert joined._ddf.index.dtype == pd.ArrowDtype(pa.int64()) alignment = align_catalogs(small_sky_catalog, small_sky_order1_catalog) @@ -26,6 +41,7 @@ def test_small_sky_join_small_sky_order1(small_sky_catalog, small_sky_order1_cat assert joined.get_healpix_pixels() == alignment.pixel_tree.get_healpix_pixels() joined_compute = joined.compute() + assert np.all(joined_compute.columns == expected_columns) assert isinstance(joined_compute, npd.NestedFrame) small_sky_compute = small_sky_catalog.compute() small_sky_order1_compute = small_sky_order1_catalog.compute() @@ -51,18 +67,26 @@ def test_small_sky_join_overlapping_suffix(small_sky_catalog, small_sky_order1_c suffix_method="overlapping_columns", ) assert isinstance(joined._ddf, nd.NestedFrame) - helpers.assert_columns_in_joined_catalog( - joined, [small_sky_catalog, small_sky_order1_catalog], suffixes, suffix_method="overlapping_columns" - ) + + expected_columns = [ + "id_a", + "ra_a", + "dec_a", + "ra_error_a", + "dec_error_a", + "id_b", + "ra_b", + "dec_b", + "ra_error_b", + "dec_error_b", + ] + + assert np.all(joined.columns == expected_columns) joined_compute = joined.compute() - helpers.assert_columns_in_joined_catalog( - joined_compute, - [small_sky_catalog, small_sky_order1_catalog], - suffixes, - suffix_method="overlapping_columns", - ) + assert np.all(joined_compute.columns == expected_columns) + helpers.assert_divisions_are_correct(joined) helpers.assert_schema_correct(joined) @@ -74,15 +98,33 @@ def test_small_sky_join_small_sky_order1_source( joined = small_sky_catalog.join( small_sky_order1_source_with_margin, left_on="id", right_on="object_id", suffixes=suffixes ) - helpers.assert_columns_in_joined_catalog( - joined, [small_sky_catalog, small_sky_order1_source_with_margin], suffixes - ) + + expected_columns = [ + "id_a", + "ra_a", + "dec_a", + "ra_error_a", + "dec_error_a", + "source_id_b", + "source_ra_b", + "source_dec_b", + "mjd_b", + "mag_b", + "band_b", + "object_id_b", + "object_ra_b", + "object_dec_b", + ] + + assert np.all(joined.columns == expected_columns) alignment = align_catalogs(small_sky_catalog, small_sky_order1_source_with_margin) assert joined.hc_structure.moc == alignment.moc assert joined.get_healpix_pixels() == alignment.pixel_tree.get_healpix_pixels() joined_compute = joined.compute() + + assert np.all(joined_compute.columns == expected_columns) small_sky_order1_compute = small_sky_order1_source_with_margin.compute() assert len(joined_compute) == len(small_sky_order1_compute) joined_test = small_sky_order1_compute.merge(joined_compute, left_on="object_id", right_on="object_id_b") @@ -98,15 +140,30 @@ def test_small_sky_join_default_columns( joined = small_sky_order1_default_cols_catalog.join( small_sky_order1_source_with_margin, left_on="id", right_on="object_id", suffixes=suffixes ) - helpers.assert_columns_in_joined_catalog( - joined, [small_sky_order1_default_cols_catalog, small_sky_order1_source_with_margin], suffixes - ) + + expected_columns = [ + "ra_a", + "dec_a", + "id_a", + "source_id_b", + "source_ra_b", + "source_dec_b", + "mjd_b", + "mag_b", + "band_b", + "object_id_b", + "object_ra_b", + "object_dec_b", + ] + + assert np.all(joined.columns == expected_columns) alignment = align_catalogs(small_sky_order1_default_cols_catalog, small_sky_order1_source_with_margin) assert joined.hc_structure.moc == alignment.moc assert joined.get_healpix_pixels() == alignment.pixel_tree.get_healpix_pixels() joined_compute = joined.compute() + assert np.all(joined_compute.columns == expected_columns) small_sky_order1_compute = small_sky_order1_source_with_margin.compute() assert len(joined_compute) == len(small_sky_order1_compute) joined_test = small_sky_order1_compute.merge(joined_compute, left_on="object_id", right_on="object_id_b") @@ -139,13 +196,33 @@ def test_join_association( alignment = align_catalogs(small_sky_catalog, small_sky_order1_source_collection_catalog) assert joined.hc_structure.moc == alignment.moc assert joined.get_healpix_pixels() == alignment.pixel_tree.get_healpix_pixels() - helpers.assert_columns_in_joined_catalog( - joined, [small_sky_catalog, small_sky_order1_source_collection_catalog], suffixes - ) + + expected_columns = [ + "id_a", + "ra_a", + "dec_a", + "ra_error_a", + "dec_error_a", + "source_id_b", + "source_ra_b", + "source_dec_b", + "mjd_b", + "mag_b", + "band_b", + "object_id_b", + "object_ra_b", + "object_dec_b", + "_dist_arcsec", + ] + + assert np.all(joined.columns == expected_columns) + assert joined._ddf.index.name == SPATIAL_INDEX_COLUMN assert joined._ddf.index.dtype == pd.ArrowDtype(pa.int64()) joined_data = joined.compute() + + assert np.all(joined_data.columns == expected_columns) assert isinstance(joined_data, npd.NestedFrame) association_data = small_sky_to_o1source_catalog.compute() assert len(joined_data) == len(association_data) @@ -184,21 +261,30 @@ def test_join_association_overlapping_suffix( suffixes=suffixes, suffix_method="overlapping_columns", ) - helpers.assert_columns_in_joined_catalog( - joined, - [small_sky_catalog, small_sky_order1_source_collection_catalog], - suffixes, - suffix_method="overlapping_columns", - ) + expected_columns = [ + "id", + "ra", + "dec", + "ra_error", + "dec_error", + "source_id", + "source_ra", + "source_dec", + "mjd", + "mag", + "band", + "object_id", + "object_ra", + "object_dec", + "_dist_arcsec", + ] + + assert np.all(joined.columns == expected_columns) joined_compute = joined.compute() - helpers.assert_columns_in_joined_catalog( - joined_compute, - [small_sky_catalog, small_sky_order1_source_collection_catalog], - suffixes, - suffix_method="overlapping_columns", - ) + assert np.all(joined_compute.columns == expected_columns) + helpers.assert_divisions_are_correct(joined) helpers.assert_schema_correct(joined) @@ -256,9 +342,26 @@ def test_join_nested(small_sky_catalog, small_sky_order1_source_with_margin, hel right_on="object_id", nested_column_name="sources", ) - helpers.assert_columns_in_nested_joined_catalog( - joined, small_sky_catalog, small_sky_order1_source_with_margin, ["object_id"], "sources" - ) + expected_columns = [ + "id", + "ra", + "dec", + "ra_error", + "dec_error", + "sources", + ] + expected_nested_columns = [ + "source_id", + "source_ra", + "source_dec", + "mjd", + "mag", + "band", + "object_ra", + "object_dec", + ] + assert np.all(joined.columns == expected_columns) + assert np.all(joined["sources"].nest.fields == expected_nested_columns) helpers.assert_divisions_are_correct(joined) alignment = align_catalogs(small_sky_catalog, small_sky_order1_source_with_margin) assert joined.hc_structure.moc == alignment.moc @@ -324,22 +427,26 @@ def test_merge_asof_overlapping_suffix(small_sky_catalog, small_sky_xmatch_catal joined = small_sky_catalog.merge_asof( small_sky_xmatch_catalog, direction="backward", suffixes=suffixes, suffix_method="overlapping_columns" ) - helpers.assert_columns_in_joined_catalog( - joined, - [small_sky_catalog, small_sky_xmatch_catalog], - suffixes, - suffix_method="overlapping_columns", - ) + + expected_columns = [ + "id_a", + "ra_a", + "dec_a", + "ra_error_a", + "dec_error_a", + "id_b", + "ra_b", + "dec_b", + "ra_error_b", + "dec_error_b", + "calculated_dist", + ] + assert np.all(joined.columns == expected_columns) helpers.assert_divisions_are_correct(joined) joined_compute = joined.compute() - helpers.assert_columns_in_joined_catalog( - joined_compute, - [small_sky_catalog, small_sky_xmatch_catalog], - suffixes, - suffix_method="overlapping_columns", - ) + assert np.all(joined_compute.columns == expected_columns) helpers.assert_divisions_are_correct(joined) helpers.assert_schema_correct(joined) From d6815cf13a9d173a79e2114a39ad5b6d50e4ad5a Mon Sep 17 00:00:00 2001 From: Sean McGuire Date: Fri, 19 Sep 2025 17:22:32 -0400 Subject: [PATCH 10/12] lint --- src/lsdb/dask/merge_catalog_functions.py | 10 +++++++--- tests/conftest.py | 1 - tests/lsdb/catalog/test_join.py | 2 +- 3 files changed, 8 insertions(+), 5 deletions(-) diff --git a/src/lsdb/dask/merge_catalog_functions.py b/src/lsdb/dask/merge_catalog_functions.py index 087e185fd..b5874e9b1 100644 --- a/src/lsdb/dask/merge_catalog_functions.py +++ b/src/lsdb/dask/merge_catalog_functions.py @@ -114,7 +114,7 @@ def get_suffix_function( if suffix_method == "all_columns": return apply_suffix_all_columns - elif suffix_method == "overlapping_columns": + if suffix_method == "overlapping_columns": return apply_suffix_overlapping_columns raise ValueError(f"Invalid suffix method: {suffix_method}") @@ -847,10 +847,14 @@ def create_merged_catalog_info( """ suffix_function = get_suffix_function(suffix_method) left_info = left.hc_structure.catalog_info + # type: ignore + ra_col = apply_left_suffix(left_info.ra_column, right.columns, suffixes, suffix_function) + # type: ignore + dec_col = apply_left_suffix(left_info.dec_column, right.columns, suffixes, suffix_function) return left_info.copy_and_update( catalog_name=updated_name, - ra_column=apply_left_suffix(left_info.ra_column, right.columns, suffixes, suffix_function), - dec_column=apply_left_suffix(left_info.dec_column, right.columns, suffixes, suffix_function), + ra_column=ra_col, + dec_column=dec_col, total_rows=0, default_columns=None, ) diff --git a/tests/conftest.py b/tests/conftest.py index 9d8d33674..ceb7d1f38 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -17,7 +17,6 @@ import lsdb import lsdb.nested as nd -from lsdb.dask.merge_catalog_functions import DEFAULT_SUFFIX_METHOD DATA_DIR_NAME = "data" SMALL_SKY_DIR_NAME = "small_sky" diff --git a/tests/lsdb/catalog/test_join.py b/tests/lsdb/catalog/test_join.py index efe65ef02..65b02113b 100644 --- a/tests/lsdb/catalog/test_join.py +++ b/tests/lsdb/catalog/test_join.py @@ -186,7 +186,7 @@ def test_join_wrong_suffixes(small_sky_catalog, small_sky_order1_catalog): def test_join_association( - small_sky_catalog, small_sky_order1_source_collection_catalog, small_sky_to_o1source_catalog, helpers + small_sky_catalog, small_sky_order1_source_collection_catalog, small_sky_to_o1source_catalog ): suffixes = ("_a", "_b") joined = small_sky_catalog.join( From 55a0e944169a6a9a733c92e422154962785ee1f6 Mon Sep 17 00:00:00 2001 From: Sean McGuire Date: Fri, 19 Sep 2025 17:27:21 -0400 Subject: [PATCH 11/12] lint --- src/lsdb/dask/merge_catalog_functions.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/lsdb/dask/merge_catalog_functions.py b/src/lsdb/dask/merge_catalog_functions.py index b5874e9b1..615ae9abe 100644 --- a/src/lsdb/dask/merge_catalog_functions.py +++ b/src/lsdb/dask/merge_catalog_functions.py @@ -847,10 +847,10 @@ def create_merged_catalog_info( """ suffix_function = get_suffix_function(suffix_method) left_info = left.hc_structure.catalog_info - # type: ignore - ra_col = apply_left_suffix(left_info.ra_column, right.columns, suffixes, suffix_function) - # type: ignore - dec_col = apply_left_suffix(left_info.dec_column, right.columns, suffixes, suffix_function) + ra_col = apply_left_suffix(left_info.ra_column, right.columns, suffixes, suffix_function) # type: ignore + dec_col = apply_left_suffix( + left_info.dec_column, right.columns, suffixes, suffix_function # type: ignore + ) return left_info.copy_and_update( catalog_name=updated_name, ra_column=ra_col, From 622300aff193e56f44045b5e501c45763dd8726d Mon Sep 17 00:00:00 2001 From: Sean McGuire Date: Wed, 24 Sep 2025 16:11:12 -0400 Subject: [PATCH 12/12] use apply_suffixes method instead of suffix function pointers --- .../abstract_crossmatch_algorithm.py | 7 +- src/lsdb/dask/crossmatch_catalog_data.py | 5 +- src/lsdb/dask/join_catalog_data.py | 28 +++----- src/lsdb/dask/merge_catalog_functions.py | 69 ++++++++++++------- tests/lsdb/catalog/test_crossmatch.py | 5 +- 5 files changed, 62 insertions(+), 52 deletions(-) diff --git a/src/lsdb/core/crossmatch/abstract_crossmatch_algorithm.py b/src/lsdb/core/crossmatch/abstract_crossmatch_algorithm.py index 70c82bb5e..271f5b60e 100644 --- a/src/lsdb/core/crossmatch/abstract_crossmatch_algorithm.py +++ b/src/lsdb/core/crossmatch/abstract_crossmatch_algorithm.py @@ -10,7 +10,7 @@ from hats.catalog import TableProperties from hats.pixel_math.spatial_index import SPATIAL_INDEX_COLUMN -from lsdb.dask.merge_catalog_functions import get_suffix_function +from lsdb.dask.merge_catalog_functions import apply_suffixes if TYPE_CHECKING: from lsdb.catalog import Catalog @@ -212,8 +212,9 @@ def _create_crossmatch_df( additional columns added """ # rename columns so no same names during merging - suffix_function = get_suffix_function(suffix_method) - self.left, self.right = suffix_function(self.left, self.right, suffixes) + self.left, self.right = apply_suffixes( + self.left, self.right, suffixes, suffix_method, log_changes=False + ) # concat dataframes together index_name = self.left.index.name if self.left.index.name is not None else "index" left_join_part = self.left.iloc[left_idx].reset_index() diff --git a/src/lsdb/dask/crossmatch_catalog_data.py b/src/lsdb/dask/crossmatch_catalog_data.py index 919bea5e3..3cd81299a 100644 --- a/src/lsdb/dask/crossmatch_catalog_data.py +++ b/src/lsdb/dask/crossmatch_catalog_data.py @@ -20,7 +20,6 @@ generate_meta_df_for_joined_tables, generate_meta_df_for_nested_tables, get_healpix_pixels_from_alignment, - get_suffix_function, ) from lsdb.types import DaskDFPixelMap @@ -158,13 +157,11 @@ def crossmatch_catalog_data( # get lists of HEALPix pixels from alignment to pass to cross-match left_pixels, right_pixels = get_healpix_pixels_from_alignment(alignment) - suffix_function = get_suffix_function(suffix_method) - # generate meta table structure for dask df meta_df = generate_meta_df_for_joined_tables( (left, right), suffixes, - suffix_function=suffix_function, + suffix_method=suffix_method, extra_columns=crossmatch_algorithm.extra_columns, ) diff --git a/src/lsdb/dask/join_catalog_data.py b/src/lsdb/dask/join_catalog_data.py index c6c155087..7feb36f8d 100644 --- a/src/lsdb/dask/join_catalog_data.py +++ b/src/lsdb/dask/join_catalog_data.py @@ -20,6 +20,7 @@ align_catalogs_with_association, apply_left_suffix, apply_right_suffix, + apply_suffixes, concat_partition_and_margin, construct_catalog_args, filter_by_spatial_index_to_pixel, @@ -27,7 +28,6 @@ generate_meta_df_for_nested_tables, get_healpix_pixels_from_alignment, get_healpix_pixels_from_association, - get_suffix_function, ) from lsdb.types import DaskDFPixelMap @@ -82,10 +82,9 @@ def perform_join_on( right_joined_df = concat_partition_and_margin(right, right_margin) - suffix_function = get_suffix_function(suffix_method) - left_join_column = apply_left_suffix(left_on, right_joined_df.columns, suffixes, suffix_function) - right_join_column = apply_right_suffix(right_on, left.columns, suffixes, suffix_function) - left, right_joined_df = suffix_function(left, right_joined_df, suffixes) + left_join_column = apply_left_suffix(left_on, right_joined_df.columns, suffixes, suffix_method) + right_join_column = apply_right_suffix(right_on, left.columns, suffixes, suffix_method) + left, right_joined_df = apply_suffixes(left, right_joined_df, suffixes, suffix_method, log_changes=False) merged = left.reset_index().merge(right_joined_df, left_on=left_join_column, right_on=right_join_column) merged.set_index(SPATIAL_INDEX_COLUMN, inplace=True) @@ -188,14 +187,13 @@ def perform_join_through( right_joined_df = concat_partition_and_margin(right, right_margin) - suffix_function = get_suffix_function(suffix_method) left_join_column = apply_left_suffix( - assoc_catalog_info.primary_column, right_joined_df.columns, suffixes, suffix_function + assoc_catalog_info.primary_column, right_joined_df.columns, suffixes, suffix_method ) right_join_column = apply_right_suffix( - assoc_catalog_info.join_column, left.columns, suffixes, suffix_function + assoc_catalog_info.join_column, left.columns, suffixes, suffix_method ) - left, right_joined_df = suffix_function(left, right_joined_df, suffixes) + left, right_joined_df = apply_suffixes(left, right_joined_df, suffixes, suffix_method, log_changes=False) # Edge case: if right_column + suffix == join_column_association, columns will be in the wrong order # so rename association column @@ -276,8 +274,7 @@ def perform_merge_asof( if right_pixel.order > left_pixel.order: left = filter_by_spatial_index_to_pixel(left, right_pixel.order, right_pixel.pixel) - suffix_function = get_suffix_function(suffix_method) - left, right = suffix_function(left, right, suffixes) + left, right = apply_suffixes(left, right, suffixes, suffix_method, log_changes=False) left.sort_index(inplace=True) right.sort_index(inplace=True) merged = pd.merge_asof(left, right, left_index=True, right_index=True, direction=direction) @@ -330,8 +327,7 @@ def join_catalog_data_on( suffix_method, ) - suffix_function = get_suffix_function(suffix_method) - meta_df = generate_meta_df_for_joined_tables((left, right), suffixes, suffix_function=suffix_function) + meta_df = generate_meta_df_for_joined_tables((left, right), suffixes, suffix_method=suffix_method) return construct_catalog_args(joined_partitions, meta_df, alignment) @@ -469,9 +465,8 @@ def join_catalog_data_through( # pylint: disable=protected-access extra_df = association._ddf._meta.drop(non_joining_columns + association_join_columns, axis=1) - suffix_function = get_suffix_function(suffix_method) meta_df = generate_meta_df_for_joined_tables( - (left, right), suffixes, extra_columns=extra_df, suffix_function=suffix_function + (left, right), suffixes, extra_columns=extra_df, suffix_method=suffix_method ) return construct_catalog_args(joined_partitions, meta_df, alignment) @@ -517,7 +512,6 @@ def merge_asof_catalog_data( [(left, left_pixels), (right, right_pixels)], perform_merge_asof, suffixes, direction, suffix_method ) - suffix_function = get_suffix_function(suffix_method) - meta_df = generate_meta_df_for_joined_tables((left, right), suffixes, suffix_function=suffix_function) + meta_df = generate_meta_df_for_joined_tables((left, right), suffixes, suffix_method=suffix_method) return construct_catalog_args(joined_partitions, meta_df, alignment) diff --git a/src/lsdb/dask/merge_catalog_functions.py b/src/lsdb/dask/merge_catalog_functions.py index 615ae9abe..a705aa392 100644 --- a/src/lsdb/dask/merge_catalog_functions.py +++ b/src/lsdb/dask/merge_catalog_functions.py @@ -58,7 +58,7 @@ def apply_suffix_all_columns( def apply_suffix_overlapping_columns( - left_df: npd.NestedFrame, right_df: npd.NestedFrame, suffixes: tuple[str, str] + left_df: npd.NestedFrame, right_df: npd.NestedFrame, suffixes: tuple[str, str], log_changes: bool = True ) -> tuple[npd.NestedFrame, npd.NestedFrame]: """Applies suffixes to overlapping columns in both dataframes @@ -68,6 +68,7 @@ def apply_suffix_overlapping_columns( left_df (npd.NestedFrame): The left dataframe right_df (npd.NestedFrame): The right dataframe suffixes (tuple[str, str]): The suffixes to apply to the left and right dataframes + log_changes (bool): If True, logs an info message for each column that is being renamed. Returns: A tuple of the two dataframes with the suffixes applied @@ -83,24 +84,32 @@ def apply_suffix_overlapping_columns( tablefmt="pretty", ) - if overlapping_columns: + if overlapping_columns and log_changes: logging.info("Renaming overlapping columns:\n%s", table) return left_df, right_df -def get_suffix_function( +def apply_suffixes( + left_df: npd.NestedFrame, + right_df: npd.NestedFrame, + suffixes: tuple[str, str], suffix_method: str | None = None, -) -> Callable[[npd.NestedFrame, npd.NestedFrame, tuple[str, str]], tuple[npd.NestedFrame, npd.NestedFrame]]: - """Gets a function that can be used to generate suffixes for columns based on a specified method + log_changes: bool = True, +) -> tuple[npd.NestedFrame, npd.NestedFrame]: + """Applies suffixes to the columns of two dataframes using the specified suffix method Args: - suffix_method (str): The method to use to generate suffixes. Options are 'all_columns', - 'overlapping_columns', + left_df (npd.NestedFrame): The left dataframe + right_df (npd.NestedFrame): The right dataframe + suffixes (tuple[str, str]): The suffixes to apply to the left and right dataframes + suffix_method (str | None): The method to use to generate suffixes. Options are 'all_columns', + 'overlapping_columns'. If None, defaults to 'all_columns' but will change to + 'overlapping_columns' in a future release. + log_changes (bool): If True, logs an info message for each column that is being renamed. Returns: - A function that takes in two dataframes and returns a tuple of the two dataframes with the suffixes - applied + A tuple of the two dataframes with the suffixes applied """ if suffix_method is None: suffix_method = DEFAULT_SUFFIX_METHOD @@ -113,9 +122,9 @@ def get_suffix_function( ) if suffix_method == "all_columns": - return apply_suffix_all_columns + return apply_suffix_all_columns(left_df, right_df, suffixes) if suffix_method == "overlapping_columns": - return apply_suffix_overlapping_columns + return apply_suffix_overlapping_columns(left_df, right_df, suffixes, log_changes) raise ValueError(f"Invalid suffix method: {suffix_method}") @@ -123,7 +132,8 @@ def apply_left_suffix( col_name: str, right_col_names: list[str], suffixes: tuple[str, str], - suffix_function: Callable, + suffix_method: str | None = None, + log_changes: bool = False, ) -> str: """Applies the left suffix to a column name using the specified suffix function @@ -131,14 +141,17 @@ def apply_left_suffix( col_name (str): The column name to apply the suffix to right_col_names (list[str]): The list of column names in the right dataframe suffixes (tuple[str, str]): The suffixes to apply to the left and right dataframes - suffix_function (Callable): The function to use to apply the suffix + suffix_method (str): The method to use to generate suffixes. Options are 'all_columns', + 'overlapping_columns' + log_changes (bool): If True, logs an info message for each column that is being renamed. + Default: False Returns: The column name with the left suffix applied """ left_df = npd.NestedFrame(columns=[col_name]) right_df = npd.NestedFrame(columns=right_col_names) - left_df, _ = suffix_function(left_df, right_df, suffixes) + left_df, _ = apply_suffixes(left_df, right_df, suffixes, suffix_method, log_changes=log_changes) return left_df.columns[0] @@ -146,7 +159,8 @@ def apply_right_suffix( col_name: str, left_col_names: list[str], suffixes: tuple[str, str], - suffix_function: Callable, + suffix_method: str | None = None, + log_changes: bool = False, ) -> str: """Applies the right suffix to a column name using the specified suffix function @@ -154,14 +168,17 @@ def apply_right_suffix( col_name (str): The column name to apply the suffix to left_col_names (list[str]): The column names in the left dataframe suffixes (tuple[str, str]): The suffixes to apply to the left and right dataframes - suffix_function (Callable): The function to use to apply the suffix + suffix_method (str): The method to use to generate suffixes. Options are 'all_columns', + 'overlapping_columns' + log_changes (bool): If True, logs an info message for each column that is being renamed. + Default: False Returns: The column name with the right suffix applied """ left_df = npd.NestedFrame(columns=left_col_names) right_df = npd.NestedFrame(columns=[col_name]) - _, right_df = suffix_function(left_df, right_df, suffixes) + _, right_df = apply_suffixes(left_df, right_df, suffixes, suffix_method, log_changes=log_changes) return right_df.columns[0] @@ -658,7 +675,7 @@ def get_healpix_pixels_from_association( def generate_meta_df_for_joined_tables( catalogs: tuple[Catalog, Catalog], suffixes: tuple[str, str], - suffix_function: Callable, + suffix_method: str | None = None, extra_columns: pd.DataFrame | None = None, index_name: str = SPATIAL_INDEX_COLUMN, index_type: npt.DTypeLike | None = None, @@ -671,6 +688,7 @@ def generate_meta_df_for_joined_tables( Args: catalogs (Sequence[lsdb.Catalog]): The catalogs to merge together suffixes (Sequence[Str]): The column suffixes to apply each catalog + suffix_method (str): The method to use to generate suffixes. extra_columns (pd.Dataframe): Any additional columns to the merged catalogs index_name (str): The name of the index in the resulting DataFrame index_type (npt.DTypeLike): The type of the index in the resulting DataFrame. @@ -681,8 +699,12 @@ def generate_meta_df_for_joined_tables( columns specified, with the index name set. """ # Construct meta for crossmatched catalog columns - left_meta, right_meta = suffix_function( - catalogs[0]._ddf._meta, catalogs[1]._ddf._meta, suffixes # pylint: disable=protected-access + # pylint: disable=protected-access + left_meta, right_meta = apply_suffixes( + catalogs[0]._ddf._meta, + catalogs[1]._ddf._meta, + suffixes, + suffix_method, ) meta = pd.concat([left_meta, right_meta], axis=1) # Construct meta for crossmatch result columns @@ -845,12 +867,9 @@ def create_merged_catalog_info( Returns: The catalog info of the resulting merged catalog """ - suffix_function = get_suffix_function(suffix_method) left_info = left.hc_structure.catalog_info - ra_col = apply_left_suffix(left_info.ra_column, right.columns, suffixes, suffix_function) # type: ignore - dec_col = apply_left_suffix( - left_info.dec_column, right.columns, suffixes, suffix_function # type: ignore - ) + ra_col = apply_left_suffix(left_info.ra_column, right.columns, suffixes, suffix_method) # type: ignore + dec_col = apply_left_suffix(left_info.dec_column, right.columns, suffixes, suffix_method) # type: ignore return left_info.copy_and_update( catalog_name=updated_name, ra_column=ra_col, diff --git a/tests/lsdb/catalog/test_crossmatch.py b/tests/lsdb/catalog/test_crossmatch.py index f181d1d6e..c66a1a1a0 100644 --- a/tests/lsdb/catalog/test_crossmatch.py +++ b/tests/lsdb/catalog/test_crossmatch.py @@ -14,7 +14,7 @@ from lsdb.core.crossmatch.abstract_crossmatch_algorithm import AbstractCrossmatchAlgorithm from lsdb.core.crossmatch.bounded_kdtree_match import BoundedKdTreeCrossmatch from lsdb.core.crossmatch.kdtree_match import KdTreeCrossmatch -from lsdb.dask.merge_catalog_functions import align_catalogs, get_suffix_function +from lsdb.dask.merge_catalog_functions import align_catalogs, apply_suffixes @pytest.mark.parametrize("algo", [KdTreeCrossmatch]) @@ -529,8 +529,7 @@ class MockCrossmatchAlgorithmOverwrite(AbstractCrossmatchAlgorithm): def crossmatch(self, suffixes, suffix_method="all_columns", mock_results: pd.DataFrame = None, **kwargs): left_reset = self.left.reset_index(drop=True) right_reset = self.right.reset_index(drop=True) - suffix_function = get_suffix_function(suffix_method) - self.left, self.right = suffix_function(self.left, self.right, suffixes) + self.left, self.right = apply_suffixes(self.left, self.right, suffixes, suffix_method) mock_results = mock_results[mock_results["ss_id"].isin(left_reset["id"].to_numpy())] left_indexes = mock_results.apply( lambda row: left_reset[left_reset["id"] == row["ss_id"]].index[0], axis=1