diff --git a/pbjam/distributions.py b/pbjam/distributions.py index 9c86468..0d8a682 100644 --- a/pbjam/distributions.py +++ b/pbjam/distributions.py @@ -116,6 +116,9 @@ def __init__(self, a=1, b=1, loc=0, scale=1): self._set_stdatt() + def __repr__(self): + return f'Distribution: Beta(a={self.a}, b={self.b}, x0={self.loc}, x1={self.loc + self.scale})' + def rv(self): """ Draw random variable from distribution @@ -411,6 +414,8 @@ def __init__(self, loc=0, scale=1): self._set_stdatt() + def __repr__(self): + return f'Distribution: Uniform(x1={self.a}, x2={self.b})' def rv(self): """ Draw random variable from distribution @@ -544,6 +549,8 @@ def __init__(self, loc=0, scale=1): self._set_stdatt() + def __repr__(self): + return f'Distribution: Normal(μ={self.loc}, σ={self.scale})' def rv(self): """ Draw random variable from distribution @@ -656,6 +663,9 @@ def __init__(self,): """ self._set_stdatt() + + def __repr__(self): + return f'Distribution: TruncatedSine' def rv(self): """ Draw random variable from distribution @@ -773,6 +783,9 @@ def __init__(self, low, high): self._set_stdatt() + def __repr__(self): + return f'Distribution: RandInt(min={self.low}, max={self.high})' + def rv(self): """ Draw random variable from distribution diff --git a/pbjam/l1models.py b/pbjam/l1models.py index 5259e12..b5197a3 100644 --- a/pbjam/l1models.py +++ b/pbjam/l1models.py @@ -18,6 +18,7 @@ import pbjam.distributions as dist from dynesty import utils as dyfunc jax.config.update('jax_enable_x64', True) +from functools import partial class commonFuncs(jar.generalModelFuncs): """ @@ -622,17 +623,17 @@ def setPriors(self,): self.DR.logpdf[i], self.DR.cdf[i]) - AddKeys = [k for k in self.variables if k in self.addPriors.keys()] - - self.priors.update({key : self.addPriors[key] for key in AddKeys}) - # Core rotation prior self.priors['nurot_c'] = dist.uniform(loc=-2., scale=2.) - self.priors['nurot_e'] = dist.uniform(loc=-2., scale=2.) # The inclination prior is a sine truncated between 0, and pi/2. - self.priors['inc'] = dist.truncsine() + self.priors['inc'] = dist.truncsine() + + # override priors + AddKeys = [k for k in self.variables if k in self.addPriors.keys()] + self.priors.update({key : self.addPriors[key] for key in AddKeys}) + def model(self, thetaU,): """ @@ -1085,19 +1086,18 @@ def setPriors(self,): self.DR.logpdf[i], self.DR.cdf[i]) - AddKeys = [k for k in self.variables if k in self.addPriors.keys()] - - self.priors.update({key : self.addPriors[key] for key in AddKeys}) - self.priors['q'] = dist.uniform(loc=0.01, scale=0.6) - # Core rotation prior self.priors['nurot_c'] = dist.uniform(loc=-2., scale=3.) - self.priors['nurot_e'] = dist.uniform(loc=-2., scale=2.) # The inclination prior is a sine truncated between 0, and pi/2. - self.priors['inc'] = dist.truncsine() + self.priors['inc'] = dist.truncsine() + + # override priors + AddKeys = [k for k in self.variables if k in self.addPriors.keys()] + self.priors.update({key : self.addPriors[key] for key in AddKeys}) + def unpackParams(self, theta): """ Cast the parameters in a dictionary @@ -1256,6 +1256,7 @@ def nearest(self, nu, nu_target): return nu_target[jnp.argmin(jnp.abs(nu[:, None] - nu_target[None, :]), axis=1)] + @partial(jax.jit, static_argnums=(0,)) def Theta_p(self, nu, Dnu, nu_p): """ Compute the p-mode phase function Theta_p. @@ -1279,6 +1280,7 @@ def Theta_p(self, nu, Dnu, nu_p): (nu - self.nearest(nu, nu_p)) / Dnu + jnp.round((self.nearest(nu, nu_p) - nu_p[0]) / Dnu) ) + @partial(jax.jit, static_argnums=(0,)) def Theta_g(self, nu, DPi1, nu_g): """ Compute the g-mode phase function Theta_g. @@ -1303,6 +1305,7 @@ def Theta_g(self, nu, DPi1, nu_g): (1 / self.nearest(nu, nu_g) - 1 / nu) / DPi1 ) + @partial(jax.jit, static_argnums=(0,)) def zeta(self, nu, q, DPi1, Dnu, nu_p, nu_g): """ Compute the local mixing fraction zeta. @@ -1334,6 +1337,7 @@ def zeta(self, nu, q, DPi1, Dnu, nu_p, nu_g): return 1 / (1 + DPi1 / Dnu * nu**2 / q * jnp.sin(Theta_g)**2 / jnp.cos(Theta_p)**2) + @partial(jax.jit, static_argnums=(0,)) def zeta_p(self, nu, q, DPi1, Dnu, nu_p): """ Compute the mixing fraction zeta using only the p-mode phase function. Agrees with zeta only at the @@ -1361,6 +1365,7 @@ def zeta_p(self, nu, q, DPi1, Dnu, nu_p): return 1 / (1 + DPi1 / Dnu * nu**2 / (q * jnp.cos(Theta)**2 + jnp.sin(Theta)**2/q)) + @partial(jax.jit, static_argnums=(0,)) def zeta_g(self, nu, q, DPi1, Dnu, nu_g): """ @@ -1390,6 +1395,7 @@ def zeta_g(self, nu, q, DPi1, Dnu, nu_g): return 1 / (1 + DPi1 / Dnu * nu**2 * (q * jnp.cos(Theta)**2 + jnp.sin(Theta)**2/q)) + @partial(jax.jit, static_argnums=(0,)) def F(self, nu, nu_p, nu_g, Dnu, DPi1, q): """ Compute the characteristic function F such that F(nu) = 0 yields eigenvalues. @@ -1417,6 +1423,7 @@ def F(self, nu, nu_p, nu_g, Dnu, DPi1, q): return jnp.tan(self.Theta_p(nu, Dnu, nu_p)) * jnp.tan(self.Theta_g(nu, DPi1, nu_g)) - q + @partial(jax.jit, static_argnums=(0,)) def Fp(self, nu, nu_p, nu_g, Dnu, DPi1, qp=0): """ Compute the first derivative dF/dnu of the characteristic function F. @@ -1446,6 +1453,7 @@ def Fp(self, nu, nu_p, nu_g, Dnu, DPi1, qp=0): + jnp.tan(self.Theta_p(nu, Dnu, nu_p)) / jnp.cos(self.Theta_g(nu, DPi1, nu_g))**2 * jnp.pi / DPi1 / nu**2 - qp) + @partial(jax.jit, static_argnums=(0,)) def Fpp(self, nu, nu_p, nu_g, Dnu, DPi1, qpp=0): """ Compute the second derivative d^2F / dnu^2of the characteristic function F. @@ -1477,6 +1485,7 @@ def Fpp(self, nu, nu_p, nu_g, Dnu, DPi1, qpp=0): + 2 / jnp.cos(self.Theta_p(nu, Dnu, nu_p))**2 * jnp.pi / Dnu / jnp.cos(self.Theta_g(nu, DPi1, nu_g))**2 * jnp.pi / DPi1 / nu**2 - qpp) + @partial(jax.jit, static_argnums=(0,5)) def halley_iteration(self, x, y, yp, ypp, lmbda=1.): """ Perform Halley's method (2nd order Householder) iteration, with damping @@ -1501,6 +1510,7 @@ def halley_iteration(self, x, y, yp, ypp, lmbda=1.): """ return x - lmbda * 2 * y * yp / (2 * yp * yp - y * ypp) + @partial(jax.jit, static_argnums=(0,6)) def couple(self, nu_p, nu_g, q_p, q_g, DPi1, lmbda=.5): """ Solve the characteristic equation using Halley's method to couple @@ -1528,21 +1538,20 @@ def couple(self, nu_p, nu_g, q_p, q_g, DPi1, lmbda=.5): num : array-like Array of mixed mode frequencies. """ - - num_p = jnp.copy(nu_p) - - num_g = jnp.copy(nu_g) - for _ in range(self.rootiter): - num_p = self.halley_iteration(num_p, + def _body(i, x0): + num_p, num_g = x0 + a = self.halley_iteration(num_p, self.F(num_p, nu_p, nu_g, self.obs['dnu'][0], DPi1, q_p), self.Fp(num_p, nu_p, nu_g, self.obs['dnu'][0], DPi1), self.Fpp(num_p, nu_p, nu_g, self.obs['dnu'][0], DPi1), lmbda=lmbda) - num_g = self.halley_iteration(num_g, + b = self.halley_iteration(num_g, self.F(num_g, nu_p, nu_g, self.obs['dnu'][0], DPi1, q_g), self.Fp(num_g, nu_p, nu_g, self.obs['dnu'][0], DPi1), self.Fpp(num_g, nu_p, nu_g, self.obs['dnu'][0], DPi1), lmbda=lmbda) + return a, b + num_p, num_g = jax.lax.fori_loop(0, self.rootiter, _body, (nu_p, nu_g)) return jnp.append(num_p, num_g) def parseSamples(self, smp, Nmax=5000): diff --git a/pbjam/modeID.py b/pbjam/modeID.py index 1e953f1..48732c4 100644 --- a/pbjam/modeID.py +++ b/pbjam/modeID.py @@ -116,24 +116,14 @@ def runl20model(self, progress=True, dynamic=False, minSamples=5000, sampler_kwa return self.l20result - def runl1model(self, progress=True, dynamic=False, minSamples=5000, sampler_kwargs={}, logl_kwargs={}, model='auto', PCAsamples=500, PCAdims=7, **kwargs): + def makel1model(self, model='auto', PCAsamples=500, PCAdims=7, **kwargs): """ - Runs the l1 model on the selected spectrum. + Construct a model for the l = 1 residual power spectrum. Should follow the l20 model run. Parameters ---------- - progress : bool, optional - Whether to show progress during the model run. Default is True. - dynamic : bool, optional - Whether to use dynamic nested sampling. Default is False (static nested sampling). - minSamples : int, optional - The minimum number of samples to generate. Default is 5000. - sampler_kwargs : dict, optional - Additional keyword arguments for the sampler. Default is an empty dictionary. - logl_kwargs : dict, optional - Additional keyword arguments for the log-likelihood function. Default is an empty dictionary. model : str Choice of which model to use for estimating the l=1 mode locations. Choices are MS, SG, RGB models. PCAsamples : int, optional @@ -191,6 +181,30 @@ def runl1model(self, progress=True, dynamic=False, minSamples=5000, sampler_kwar modelChoice='simple') else: raise ValueError(f'Model {model} is invalid. Please use either MS, SG or RGB.') + + def runl1model(self, progress=True, dynamic=False, minSamples=5000, sampler_kwargs={}, logl_kwargs={}, **kwargs): + ''' + Run an l = 1 model of the residual power spectum. + + Keyword arguments not listed below will be passed to self.makel1model, + in the even that self.l1model is not yet defined. + + Parameters + ---------- + progress : bool, optional + Whether to show progress during the model run. Default is True. + dynamic : bool, optional + Whether to use dynamic nested sampling. Default is False (static nested sampling). + minSamples : int, optional + The minimum number of samples to generate. Default is 5000. + sampler_kwargs : dict, optional + Additional keyword arguments for the sampler. Default is an empty dictionary. + logl_kwargs : dict, optional + Additional keyword arguments for the log-likelihood function. Default is an empty dictionary. + ''' + + if not hasattr(self, 'l1model'): + self.makel1model(**kwargs) self.l1Samples = self.l1model.runSampler(progress=progress, dynamic=dynamic, diff --git a/pbjam/query.py b/pbjam/query.py new file mode 100644 index 0000000..609361e --- /dev/null +++ b/pbjam/query.py @@ -0,0 +1,252 @@ +## Convenience functions for looking up bp_rp (and teff), +## rescued from pbjam1 + +import time +from astroquery.mast import ObservationsClass as AsqMastObsCl +from astroquery.mast import Catalogs +from astroquery.simbad import Simbad +from astroquery.gaia import Gaia + +def _querySimbad(ID): + """ Query simbad for Gaia DR2 source ID. + + Looks up the target ID on Simbad to check if it has a Gaia DR2 ID. + + The input ID can be any commonly used identifier, such as a Bayer + designation, HD number or KIC. + + Notes + ----- + TIC numbers are note currently listed on Simbad. Do a separate MAST quiry + for this. + + Parameters + ---------- + ID : str + Target identifier. + + Returns + ------- + gaiaID : str + Gaia DR2 source ID. Returns None if no Gaia ID is found. + """ + + print('Querying Simbad for Gaia ID') + + try: + job = Simbad.query_objectids(ID) + except: + print(f'Unable to resolve {ID} with Simbad') + return None + + for line in job['id']: # as of astroquery >= 0.4.8, this is lowercase + if 'Gaia DR2' in line: + return line.replace('Gaia DR2 ', '') + return None + +def _queryTIC(ID, radius = 20): + """ Find bp_rp in TIC + + Queries the TIC at MAST to search for a target ID to return bp-rp value. The + TIC is already cross-matched with the Gaia catalog, so it contains a bp-rp + value for many targets (not all though). + + For some reason it does a cone search, which may return more than one + target. In which case the target matching the ID is found in the returned + list. + + Parameters + ---------- + ID : str + The TIC identifier to search for. + radius : float, optional + Radius in arcseconds to use for the sky cone search. Default is 20". + + Returns + ------- + bp_rp : float + Gaia bp-rp value from the TIC. + """ + + print('Querying TIC for Gaia values.') + job = Catalogs.query_object(objectname=ID, catalog='TIC', + radius = radius*units.arcsec) + + if len(job) > 0: + idx = job['ID'] == str(ID.replace('TIC','').replace(' ', '')) + return { + 'bp_rp': float(job['gaiabp'][idx] - job['gaiarp'][idx]), #This should crash if len(result) > 1. + 'teff': float(job['Teff'][idx]) + } + else: + return None + +def _queryMAST(ID): + """ Query ID at MAST + + Sends a query for a target ID to MAST which returns an Astropy Skycoords + object with the target coordinates. + + ID can be any commonly used identifier such as a Bayer designation, HD, KIC, + 2MASS or other name. + + Parameters + ---------- + ID : str + Target identifier + + Returns + ------- + job : astropy.Skycoords + An Astropy Skycoords object with the target coordinates. + + """ + + print(f'Querying MAST for the {ID} coordinates.') + mastobs = AsqMastObsCl() + try: + return mastobs.resolve_object(objectname = ID) + except: + return None + +def _queryGaia(ID=None,coords=None, radius = 20): + """ Query Gaia archive + + Sends an ADQL query to the Gaia archive to look up a requested target ID or + set of coordinates. + + If the query is based on coordinates a cone search will be performed and the + closest target is returned. Provided coordinates must be astropy.Skycoords. + + Parameters + ---------- + ID : str + Gaia source ID to search for. + coord : astropy.Skycoords + An Astropy Skycoords object with the target coordinates. Must only + contain one target. + radius : float, optional + Radius in arcseconds to use for the sky cone search. Default is 20". + + Returns + ------- + bp_rp : float + Gaia bp-rp value of the requested target from the Gaia archive. + """ + + if ID is not None: + adql_query = "select * from gaiadr2.gaia_source where source_id=%s" % (ID) + try: + job = Gaia.launch_job(adql_query).get_results() + except: + return None + + elif coords is not None: + ra = coords.to_value() + dec = coords.to_value() + adql_query = f"SELECT DISTANCE(POINT('ICRS', ra, dec), POINT('ICRS', {ra}, {dec})) AS dist, * FROM gaiadr2.gaia_source WHERE 1=CONTAINS(POINT('ICRS', ra, dec), CIRCLE('ICRS', {ra}, {dec}, {radius})) ORDER BY dist ASC" + + try: + job = Gaia.launch_job(adql_query).get_results() + except: + return None + else: + raise ValueError('No ID or coordinates provided when querying the Gaia archive.') + + return { + 'bp_rp': job['bp_rp'][0], + 'teff': job['teff_val'][0] + } + +def _format_name(name): + """ Format input ID + + Users tend to be inconsistent in naming targets, which is an issue for + looking stuff up on, e.g., Simbad. + + This function formats the name so that Simbad doesn't throw a fit. + + If the name doesn't look like anything in the variant list it will only be + changed to a lower-case string. + + Parameters + ---------- + name : str + Name to be formatted. + + Returns + ------- + name : str + Formatted name + + """ + + name = str(name) + name = name.lower() + + # Add naming exceptions here + variants = {'KIC': ['kic', 'kplr', 'KIC'], + 'Gaia DR2': ['gaia dr2', 'gdr2', 'dr2', 'Gaia DR2'], + 'Gaia DR1': ['gaia dr1', 'gdr1', 'dr1', 'Gaia DR1'], + 'EPIC': ['epic', 'ktwo', 'EPIC'], + 'TIC': ['tic', 'tess', 'TIC'] + } + + fname = None + for key in variants: + for x in variants[key]: + if x in name: + fname = name.replace(x,'') + fname = re.sub(r"\s+", "", fname, flags=re.UNICODE) + fname = key+' '+fname + return fname + + return name + + +def get_spec(ID): + """ Search online for bp_rp and Teff values based on ID. + + First a check is made to see if the target is a TIC number, in which case + the TIC will be queried, since this is already cross-matched with Gaia DR2. + + If it is not a TIC number, Simbad is queries to identify a possible Gaia + source ID. + + As a last resort MAST is queried to provide the target coordinates, after + which a Gaia query is launched to find the closest target. The default + search radius is 20" around the provided coordinates. + + Parameters + ---------- + ID : str + Target identifier to search for. + + Returns + ------- + props : dict + Gaia bp-rp and Teff value for the target. Is nan if no result is found or the + queries failed. + + """ + + time.sleep(1) + + ID = _format_name(ID) + + if 'TIC' in ID: + bp_rp = _queryTIC(ID) + + else: + try: + gaiaID = _querySimbad(ID) + res = _queryGaia(ID=gaiaID) + except: + try: + coords = _queryMAST(ID) + res = _queryGaia(coords=coords) + except: + print(f'Unable to retrieve a bp_rp and Teff value for {ID}.') + res = None + + return res \ No newline at end of file