Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 13 additions & 0 deletions pbjam/distributions.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,9 @@ def __init__(self, a=1, b=1, loc=0, scale=1):

self._set_stdatt()

def __repr__(self):
return f'Distribution: Beta(a={self.a}, b={self.b}, x0={self.loc}, x1={self.loc + self.scale})'

def rv(self):
""" Draw random variable from distribution

Expand Down Expand Up @@ -411,6 +414,8 @@ def __init__(self, loc=0, scale=1):

self._set_stdatt()

def __repr__(self):
return f'Distribution: Uniform(x1={self.a}, x2={self.b})'

def rv(self):
""" Draw random variable from distribution
Expand Down Expand Up @@ -544,6 +549,8 @@ def __init__(self, loc=0, scale=1):

self._set_stdatt()

def __repr__(self):
return f'Distribution: Normal(μ={self.loc}, σ={self.scale})'

def rv(self):
""" Draw random variable from distribution
Expand Down Expand Up @@ -656,6 +663,9 @@ def __init__(self,):
"""

self._set_stdatt()

def __repr__(self):
return f'Distribution: TruncatedSine'

def rv(self):
""" Draw random variable from distribution
Expand Down Expand Up @@ -773,6 +783,9 @@ def __init__(self, low, high):

self._set_stdatt()

def __repr__(self):
return f'Distribution: RandInt(min={self.low}, max={self.high})'

def rv(self):
""" Draw random variable from distribution

Expand Down
49 changes: 29 additions & 20 deletions pbjam/l1models.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import pbjam.distributions as dist
from dynesty import utils as dyfunc
jax.config.update('jax_enable_x64', True)
from functools import partial

class commonFuncs(jar.generalModelFuncs):
"""
Expand Down Expand Up @@ -622,17 +623,17 @@ def setPriors(self,):
self.DR.logpdf[i],
self.DR.cdf[i])

AddKeys = [k for k in self.variables if k in self.addPriors.keys()]

self.priors.update({key : self.addPriors[key] for key in AddKeys})

# Core rotation prior
self.priors['nurot_c'] = dist.uniform(loc=-2., scale=2.)

self.priors['nurot_e'] = dist.uniform(loc=-2., scale=2.)

# The inclination prior is a sine truncated between 0, and pi/2.
self.priors['inc'] = dist.truncsine()
self.priors['inc'] = dist.truncsine()

# override priors
AddKeys = [k for k in self.variables if k in self.addPriors.keys()]
self.priors.update({key : self.addPriors[key] for key in AddKeys})


def model(self, thetaU,):
"""
Expand Down Expand Up @@ -1085,19 +1086,18 @@ def setPriors(self,):
self.DR.logpdf[i],
self.DR.cdf[i])

AddKeys = [k for k in self.variables if k in self.addPriors.keys()]

self.priors.update({key : self.addPriors[key] for key in AddKeys})

self.priors['q'] = dist.uniform(loc=0.01, scale=0.6)

# Core rotation prior
self.priors['nurot_c'] = dist.uniform(loc=-2., scale=3.)

self.priors['nurot_e'] = dist.uniform(loc=-2., scale=2.)

# The inclination prior is a sine truncated between 0, and pi/2.
self.priors['inc'] = dist.truncsine()
self.priors['inc'] = dist.truncsine()
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think for changing the order of things here you'll need to check that it doesn't influence the sampling.

I remember having problems with that at some point, where I thought using dictionaries should have solved this, but apparently it didn't.

Just so we aren't sampling nurot_e when we think it's something else like eps_g.


# override priors
AddKeys = [k for k in self.variables if k in self.addPriors.keys()]
self.priors.update({key : self.addPriors[key] for key in AddKeys})


def unpackParams(self, theta):
""" Cast the parameters in a dictionary
Expand Down Expand Up @@ -1256,6 +1256,7 @@ def nearest(self, nu, nu_target):

return nu_target[jnp.argmin(jnp.abs(nu[:, None] - nu_target[None, :]), axis=1)]

@partial(jax.jit, static_argnums=(0,))
def Theta_p(self, nu, Dnu, nu_p):
"""
Compute the p-mode phase function Theta_p.
Expand All @@ -1279,6 +1280,7 @@ def Theta_p(self, nu, Dnu, nu_p):
(nu - self.nearest(nu, nu_p)) / Dnu + jnp.round((self.nearest(nu, nu_p) - nu_p[0]) / Dnu)
)

@partial(jax.jit, static_argnums=(0,))
def Theta_g(self, nu, DPi1, nu_g):
"""
Compute the g-mode phase function Theta_g.
Expand All @@ -1303,6 +1305,7 @@ def Theta_g(self, nu, DPi1, nu_g):
(1 / self.nearest(nu, nu_g) - 1 / nu) / DPi1
)

@partial(jax.jit, static_argnums=(0,))
def zeta(self, nu, q, DPi1, Dnu, nu_p, nu_g):
"""
Compute the local mixing fraction zeta.
Expand Down Expand Up @@ -1334,6 +1337,7 @@ def zeta(self, nu, q, DPi1, Dnu, nu_p, nu_g):

return 1 / (1 + DPi1 / Dnu * nu**2 / q * jnp.sin(Theta_g)**2 / jnp.cos(Theta_p)**2)

@partial(jax.jit, static_argnums=(0,))
def zeta_p(self, nu, q, DPi1, Dnu, nu_p):
"""
Compute the mixing fraction zeta using only the p-mode phase function. Agrees with zeta only at the
Expand Down Expand Up @@ -1361,6 +1365,7 @@ def zeta_p(self, nu, q, DPi1, Dnu, nu_p):

return 1 / (1 + DPi1 / Dnu * nu**2 / (q * jnp.cos(Theta)**2 + jnp.sin(Theta)**2/q))

@partial(jax.jit, static_argnums=(0,))
def zeta_g(self, nu, q, DPi1, Dnu, nu_g):

"""
Expand Down Expand Up @@ -1390,6 +1395,7 @@ def zeta_g(self, nu, q, DPi1, Dnu, nu_g):

return 1 / (1 + DPi1 / Dnu * nu**2 * (q * jnp.cos(Theta)**2 + jnp.sin(Theta)**2/q))

@partial(jax.jit, static_argnums=(0,))
def F(self, nu, nu_p, nu_g, Dnu, DPi1, q):
"""
Compute the characteristic function F such that F(nu) = 0 yields eigenvalues.
Expand Down Expand Up @@ -1417,6 +1423,7 @@ def F(self, nu, nu_p, nu_g, Dnu, DPi1, q):

return jnp.tan(self.Theta_p(nu, Dnu, nu_p)) * jnp.tan(self.Theta_g(nu, DPi1, nu_g)) - q

@partial(jax.jit, static_argnums=(0,))
def Fp(self, nu, nu_p, nu_g, Dnu, DPi1, qp=0):
"""
Compute the first derivative dF/dnu of the characteristic function F.
Expand Down Expand Up @@ -1446,6 +1453,7 @@ def Fp(self, nu, nu_p, nu_g, Dnu, DPi1, qp=0):
+ jnp.tan(self.Theta_p(nu, Dnu, nu_p)) / jnp.cos(self.Theta_g(nu, DPi1, nu_g))**2 * jnp.pi / DPi1 / nu**2
- qp)

@partial(jax.jit, static_argnums=(0,))
def Fpp(self, nu, nu_p, nu_g, Dnu, DPi1, qpp=0):
"""
Compute the second derivative d^2F / dnu^2 of the characteristic function F.
Expand Down Expand Up @@ -1477,6 +1485,7 @@ def Fpp(self, nu, nu_p, nu_g, Dnu, DPi1, qpp=0):
+ 2 / jnp.cos(self.Theta_p(nu, Dnu, nu_p))**2 * jnp.pi / Dnu / jnp.cos(self.Theta_g(nu, DPi1, nu_g))**2 * jnp.pi / DPi1 / nu**2
- qpp)

@partial(jax.jit, static_argnums=(0,5))
def halley_iteration(self, x, y, yp, ypp, lmbda=1.):
"""
Perform Halley's method (2nd order Householder) iteration, with damping
Expand All @@ -1501,6 +1510,7 @@ def halley_iteration(self, x, y, yp, ypp, lmbda=1.):
"""
return x - lmbda * 2 * y * yp / (2 * yp * yp - y * ypp)

@partial(jax.jit, static_argnums=(0,6))
def couple(self, nu_p, nu_g, q_p, q_g, DPi1, lmbda=.5):
"""
Solve the characteristic equation using Halley's method to couple
Expand Down Expand Up @@ -1528,21 +1538,20 @@ def couple(self, nu_p, nu_g, q_p, q_g, DPi1, lmbda=.5):
num : array-like
Array of mixed mode frequencies.
"""

num_p = jnp.copy(nu_p)

num_g = jnp.copy(nu_g)

for _ in range(self.rootiter):
num_p = self.halley_iteration(num_p,
def _body(i, x0):
num_p, num_g = x0
a = self.halley_iteration(num_p,
self.F(num_p, nu_p, nu_g, self.obs['dnu'][0], DPi1, q_p),
self.Fp(num_p, nu_p, nu_g, self.obs['dnu'][0], DPi1),
self.Fpp(num_p, nu_p, nu_g, self.obs['dnu'][0], DPi1), lmbda=lmbda)
num_g = self.halley_iteration(num_g,
b = self.halley_iteration(num_g,
self.F(num_g, nu_p, nu_g, self.obs['dnu'][0], DPi1, q_g),
self.Fp(num_g, nu_p, nu_g, self.obs['dnu'][0], DPi1),
self.Fpp(num_g, nu_p, nu_g, self.obs['dnu'][0], DPi1), lmbda=lmbda)
return a, b

num_p, num_g = jax.lax.fori_loop(0, self.rootiter, _body, (nu_p, nu_g))
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does this improve the compile-time?

We of course also need to verify that results don't change.

return jnp.append(num_p, num_g)

def parseSamples(self, smp, Nmax=5000):
Expand Down
38 changes: 26 additions & 12 deletions pbjam/modeID.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,24 +116,14 @@ def runl20model(self, progress=True, dynamic=False, minSamples=5000, sampler_kwa

return self.l20result

def runl1model(self, progress=True, dynamic=False, minSamples=5000, sampler_kwargs={}, logl_kwargs={}, model='auto', PCAsamples=500, PCAdims=7, **kwargs):
def makel1model(self, model='auto', PCAsamples=500, PCAdims=7, **kwargs):
"""
Runs the l1 model on the selected spectrum.
Construct a model for the l = 1 residual power spectrum.

Should follow the l20 model run.

Parameters
----------
progress : bool, optional
Whether to show progress during the model run. Default is True.
dynamic : bool, optional
Whether to use dynamic nested sampling. Default is False (static nested sampling).
minSamples : int, optional
The minimum number of samples to generate. Default is 5000.
sampler_kwargs : dict, optional
Additional keyword arguments for the sampler. Default is an empty dictionary.
logl_kwargs : dict, optional
Additional keyword arguments for the log-likelihood function. Default is an empty dictionary.
model : str
Choice of which model to use for estimating the l=1 mode locations. Choices are MS, SG, RGB models.
PCAsamples : int, optional
Expand Down Expand Up @@ -191,6 +181,30 @@ def runl1model(self, progress=True, dynamic=False, minSamples=5000, sampler_kwar
modelChoice='simple')
else:
raise ValueError(f'Model {model} is invalid. Please use either MS, SG or RGB.')

def runl1model(self, progress=True, dynamic=False, minSamples=5000, sampler_kwargs={}, logl_kwargs={}, **kwargs):
'''
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Need to use " " " instead of ' ' '.

Run an l = 1 model of the residual power spectrum.

Keyword arguments not listed below will be passed to self.makel1model,
in the event that self.l1model is not yet defined.

Parameters
----------
progress : bool, optional
Whether to show progress during the model run. Default is True.
dynamic : bool, optional
Whether to use dynamic nested sampling. Default is False (static nested sampling).
minSamples : int, optional
The minimum number of samples to generate. Default is 5000.
sampler_kwargs : dict, optional
Additional keyword arguments for the sampler. Default is an empty dictionary.
logl_kwargs : dict, optional
Additional keyword arguments for the log-likelihood function. Default is an empty dictionary.
'''

if not hasattr(self, 'l1model'):
self.makel1model(**kwargs)

self.l1Samples = self.l1model.runSampler(progress=progress,
dynamic=dynamic,
Expand Down
Loading