diff --git a/CHANGES.rst b/CHANGES.rst index 1879f11851..772fb316cb 100644 --- a/CHANGES.rst +++ b/CHANGES.rst @@ -14,6 +14,12 @@ API changes Service fixes and enhancements ------------------------------ +mast +^^^^ + +- Expand the supported data types for filter values in ``Mast.mast_query``. Previously, users had to input + filter values enclosed in lists, even when specifying a single value or dictionary. [#3422] + Infrastructure, Utility and Other Changes and Additions diff --git a/astroquery/mast/observations.py b/astroquery/mast/observations.py index 0c73a5c124..4a530742f8 100644 --- a/astroquery/mast/observations.py +++ b/astroquery/mast/observations.py @@ -1121,6 +1121,63 @@ def service_request_async(self, service, params, *, pagesize=None, page=None, ** return self._portal_api_connection.service_request_async(service, params, pagesize, page, **kwargs) + def _normalize_filter_value(self, key: str, val) -> list: + """ + Normalize a filter value into a list suitable for MAST filters. + + Parameters + ---------- + key : str + Parameter name (used for error messages). + val : any + Raw filter value. + + Returns + ------- + list + Normalized filter values. + """ + # Range filters must be dicts with 'min' and 'max' + if isinstance(val, dict): + if not {"min", "max"}.issubset(val.keys()): + raise InvalidQueryError( + f'Range filter for "{key}" must be a dictionary with "min" and "max" keys.' + ) + return [val] + + # Convert numpy arrays to lists + if isinstance(val, np.ndarray): + val = val.tolist() + + # Convert sets or tuples to lists + if isinstance(val, (set, tuple)): + val = list(val) + + # Wrap scalars into a list + return val if isinstance(val, list) else [val] + + def _build_filters(self, service_params): + """ + Construct filters for filtered services. + + Parameters + ---------- + service_params : dict + Parameters not classified as request/position keys. + + Returns + ------- + list of dict + Filters suitable for a MAST filtered query.
+ """ + filters = [] + for key, val in service_params.items(): + filters.append({ + "paramName": key, + "values": self._normalize_filter_value(key, val) + }) + return filters + def mast_query(self, service, columns=None, **kwargs): """ Given a Mashup service and parameters as keyword arguments, builds and excecutes a Mashup query. @@ -1129,7 +1186,7 @@ def mast_query(self, service, columns=None, **kwargs): ---------- service : str The Mashup service to query. - columns : str, optional + columns : str or list, optional Specifies the columns to be returned as a comma-separated list, e.g. "ID, ra, dec". **kwargs : Service-specific parameters and MashupRequest properties. See the @@ -1137,45 +1194,49 @@ def mast_query(self, service, columns=None, **kwargs): `MashupRequest Class Reference `__ for valid keyword arguments. + For filtered services (i.e. those with "filtered" in the service name), + parameters that are not related to position or MashupRequest properties + are treated as filters. If the column has discrete values, the parameter value should be a + single value or list of values, and values will be matched exactly. If the column is continuous, + you can filter by a single value, a list of values, or a range of values. If filtering by a range of values, + the parameter value should be a dict in the form ``{'min': minVal, 'max': maxVal}``. + Returns ------- response : `~astropy.table.Table` """ # Specific keywords related to positional and MashupRequest parameters. 
- position_keys = ['ra', 'dec', 'radius', 'position'] - request_keys = ['format', 'data', 'filename', 'timeout', 'clearcache', - 'removecache', 'removenullcolumns', 'page', 'pagesize'] + position_keys = {'ra', 'dec', 'radius', 'position'} + request_keys = {'format', 'data', 'filename', 'timeout', 'clearcache', + 'removecache', 'removenullcolumns', 'page', 'pagesize'} - # Explicit formatting for Mast's filtered services - if 'filtered' in service.lower(): + # Split params into categories + position_params = {k: v for k, v in kwargs.items() if k.lower() in position_keys} + request_params = {k: v for k, v in kwargs.items() if k.lower() in request_keys} + service_params = {k: v for k, v in kwargs.items() if k.lower() not in position_keys | request_keys} - # Separating the filter params from the positional and service_request method params. - filters = [{'paramName': k, 'values': kwargs[k]} for k in kwargs - if k.lower() not in position_keys+request_keys] - position_params = {k: v for k, v in kwargs.items() if k.lower() in position_keys} - request_params = {k: v for k, v in kwargs.items() if k.lower() in request_keys} + # Handle filtered vs. 
non-filtered services + if 'filtered' in service.lower(): + filters = self._build_filters(service_params) - # Mast's filtered services require at least one filter - if filters == []: - raise InvalidQueryError("Please provide at least one filter.") + if not filters: + raise InvalidQueryError('Please provide at least one filter.') - # Building 'params' for Mast.service_request - if columns is None: - columns = '*' + if columns is not None and isinstance(columns, list): + columns = ','.join(columns) - params = {'columns': columns, - 'filters': filters, - **position_params - } + params = { + 'columns': columns or '*', + 'filters': filters, + **position_params, + } else: - - # Separating service specific params from service_request method params - params = {k: v for k, v in kwargs.items() if k.lower() not in request_keys} - request_params = {k: v for k, v in kwargs.items() if k.lower() in request_keys} - - # Warning for wrong input if columns is not None: - warnings.warn("'columns' parameter will not mask non-filtered services", InputWarning) + warnings.warn( + "'columns' parameter is ignored for non-filtered services.", + InputWarning + ) + params = {**service_params, **position_params} return self.service_request(service, params, **request_params) diff --git a/astroquery/mast/tests/test_mast.py b/astroquery/mast/tests/test_mast.py index 8c1f07147f..6fc144d4b0 100644 --- a/astroquery/mast/tests/test_mast.py +++ b/astroquery/mast/tests/test_mast.py @@ -8,6 +8,7 @@ from unittest.mock import patch import pytest +import numpy as np from astropy.table import Table, unique from astropy.coordinates import SkyCoord @@ -551,9 +552,11 @@ def test_mast_query(patch_post): # filtered search result = mast.Mast.mast_query('Mast.Caom.Filtered', - dataproduct_type=['image'], - proposal_pi=['Osten, Rachel A.'], - s_dec=[{'min': 43.5, 'max': 45.5}]) + dataproduct_type=['image', 'spectrum'], + proposal_pi={'Osten, Rachel A.'}, + calib_level=np.asarray(3), + s_dec={'min': 43.5, 'max': 
45.5}, + columns=['proposal_pi', 's_dec', 'obs_id']) pp_list = result['proposal_pi'] sd_list = result['s_dec'] assert isinstance(result, Table) @@ -561,10 +564,18 @@ def test_mast_query(patch_post): assert max(sd_list) < 45.5 assert min(sd_list) > 43.5 - # error handling - with pytest.raises(InvalidQueryError) as invalid_query: + # warn if columns provided for non-filtered query + with pytest.warns(InputWarning, match="'columns' parameter is ignored"): + mast.Mast.mast_query('Mast.Caom.Cone', ra=23.34086, dec=60.658, radius=0.2, columns=['obs_id', 's_ra']) + + # error if no filters provided for filtered query + with pytest.raises(InvalidQueryError, match="Please provide at least one filter."): mast.Mast.mast_query('Mast.Caom.Filtered') - assert "Please provide at least one filter." in str(invalid_query.value) + + # error if a full range is not provided for range filter + with pytest.raises(InvalidQueryError, + match='Range filter for "s_ra" must be a dictionary with "min" and "max" keys.'): + mast.Mast.mast_query('Mast.Caom.Filtered', s_ra={'min': 10.0}) def test_resolve_object_single(patch_post): diff --git a/astroquery/mast/tests/test_mast_remote.py b/astroquery/mast/tests/test_mast_remote.py index a4db087e9c..0d4be9e961 100644 --- a/astroquery/mast/tests/test_mast_remote.py +++ b/astroquery/mast/tests/test_mast_remote.py @@ -417,17 +417,30 @@ def test_mast_service_request(self): assert len(result) == 10 def test_mast_query(self): - result = Mast.mast_query('Mast.Caom.Cone', ra=184.3, dec=54.5, radius=0.2) - - # Is result in the right format + # Cone search (unfiltered) + result = Mast.mast_query('Mast.Caom.Cone', ra=184.3, dec=54.5, radius=0.005) assert isinstance(result, Table) - - # Are the GALEX observations in the results table assert "GALEX" in result['obs_collection'] - - # Are the two GALEX observations with obs_id 6374399093149532160 in the results table assert len(result[np.where(result["obs_id"] == "6374399093149532160")]) == 2 + # Filtered query +
columns = ['target_name', 'obs_collection', 'calib_level', 'sequence_number', 't_min'] + result = Mast.mast_query('Mast.Caom.Filtered', + target_name=375422201, + obs_collection={'TESS'}, + calib_level=np.asarray(3), + sequence_number=[15, 16], + t_min={'min': 58710, 'max': 58720}, + columns=columns) + assert isinstance(result, Table) + assert all(result['target_name'] == '375422201') + assert all(result['obs_collection'] == 'TESS') + assert all(result['calib_level'] == 3) + assert all((result['sequence_number'] == 15) | (result['sequence_number'] == 16)) + assert (result['t_min'] >= 58710).all() and (result['t_min'] <= 58720).all() + assert all(c in list(result.columns.keys()) for c in columns) + assert len(result.columns) == 5 + def test_mast_session_info(self): sessionInfo = Mast.session_info(verbose=False) assert sessionInfo['ezid'] == 'anonymous' diff --git a/docs/mast/mast_mastquery.rst b/docs/mast/mast_mastquery.rst index b2c5cb05ee..3c3deaa3a7 100644 --- a/docs/mast/mast_mastquery.rst +++ b/docs/mast/mast_mastquery.rst @@ -7,7 +7,7 @@ The Mast class provides more direct access to the MAST interface. It requires more knowledge of the inner workings of the MAST API, and should be rarely needed. However in the case of new functionality not yet implemented in astroquery, this class does allow access. See the -`MAST api documentation `__ for more +`MAST API documentation `__ for more information. The basic MAST query function allows users to query through the following @@ -18,12 +18,17 @@ their corresponding parameters and returns query results as an Filtered Mast Queries ===================== -MAST's Filtered services use the parameters 'columns' and 'filters'. The 'columns' -parameter is a required string that specifies the columns to be returned as a -comma-separated list. The 'filters' parameter is a required list of filters to be -applied. 
The `~astroquery.mast.MastClass.mast_query` method accepts that list of -filters as keyword arguments paired with a list of values, similar to -`~astroquery.mast.ObservationsClass.query_criteria`. +MAST's filtered services (i.e. those with "filtered" in the service name) accept service-specific parameters, MashupRequest properties, and +column filters as keyword arguments and return a table of matching observations. See the +`service documentation `__ and the +`MashupRequest Class Reference `__ for valid keyword arguments. + +Parameters that are not related to position or MashupRequest properties are treated as column filters. +If the column has discrete values, the parameter value should be a single value or list of values, and values will be matched exactly. +If the column is continuous, you can filter by a single value, a list of values, or a range of values. If filtering by a range of values, +the parameter value should be a dict in the form ``{'min': minVal, 'max': maxVal}``. + +The ``columns`` parameter specifies the columns to be returned in the response as a comma-separated string or list of strings. The following example uses a JWST service with column names and filters specific to JWST services. For the full list of valid parameters view the @@ -34,22 +39,25 @@ JWST services. For the full list of valid parameters view the >>> from astroquery.mast import Mast ... >>> observations = Mast.mast_query('Mast.Jwst.Filtered.Nirspec', - ... columns='title, instrume, targname', - ... targoopp=['T']) + ... targoopp='T', + ... productLevel=['2a', '2b'], + ... duration={'min': 810, 'max': 820}, + ... 
columns=['filename', 'targoopp', 'productLevel', 'duration']) >>> print(observations) # doctest: +IGNORE_OUTPUT - title instrume targname - ------------------------------- -------- ---------------- - ToO Comet NIRSPEC ZTF (C/2022 E3) - ToO Comet NIRSPEC ZTF (C/2022 E3) - ToO Comet NIRSPEC ZTF (C/2022 E3) - ToO Comet NIRSPEC ZTF (C/2022 E3) - De-Mystifying SPRITEs with JWST NIRSPEC SPIRITS18nu - ToO Comet NIRSPEC ZTF (C/2022 E3) - ... ... ... - ToO Comet NIRSPEC ZTF (C/2022 E3) - ToO Comet NIRSPEC ZTF (C/2022 E3) - ToO Comet NIRSPEC ZTF (C/2022 E3) - Length = 319 rows + filename targoopp productLevel duration + -------------------------------------------- -------- ------------ -------- + jw05324004001_03102_00004_nrs2_rate.fits t 2a 816.978 + jw05324004001_03102_00004_nrs2_rateints.fits t 2a 816.978 + jw05324004001_03102_00001_nrs2_rate.fits t 2a 816.978 + jw05324004001_03102_00001_nrs2_rateints.fits t 2a 816.978 + jw05324004001_03102_00005_nrs2_rate.fits t 2a 816.978 + ... ... ... ... + jw05324004001_03102_00003_nrs1_s2d.fits t 2b 816.978 + jw05324004001_03102_00003_nrs1_x1d.fits t 2b 816.978 + jw05324004001_03102_00002_nrs1_cal.fits t 2b 816.978 + jw05324004001_03102_00002_nrs1_s2d.fits t 2b 816.978 + jw05324004001_03102_00002_nrs1_x1d.fits t 2b 816.978 + Length = 25 rows TESS Queries @@ -181,7 +189,7 @@ result in a warning. ... ra=254.287, ... dec=-4.09933, ... radius=0.02) # doctest: +SHOW_WARNINGS - InputWarning: 'columns' parameter will not mask non-filtered services + InputWarning: 'columns' parameter is ignored for non-filtered services. Advanced Service Request ========================