-
Notifications
You must be signed in to change notification settings - Fork 7
several quality-of-life improvements #293
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Open
darthoctopus
wants to merge
5
commits into
dev
Choose a base branch
from
qol
base: dev
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Open
Changes from all commits
Commits
Show all changes
5 commits
Select commit
Hold shift + click to select a range
4de9a39
old astroquery routines, updated because of API changes
darthoctopus a366e0b
allow priors on model-specific quantities to be overriden rather than…
darthoctopus 451b635
separate out function call for constructing vs running l=1 model
darthoctopus e5142bc
i think this is how you jax.jit a for loop without precompiling an un…
darthoctopus e0ed651
repr strings for distribution objects
darthoctopus File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -18,6 +18,7 @@ | |
import pbjam.distributions as dist | ||
from dynesty import utils as dyfunc | ||
jax.config.update('jax_enable_x64', True) | ||
from functools import partial | ||
|
||
class commonFuncs(jar.generalModelFuncs): | ||
""" | ||
|
@@ -622,17 +623,17 @@ def setPriors(self,): | |
self.DR.logpdf[i], | ||
self.DR.cdf[i]) | ||
|
||
AddKeys = [k for k in self.variables if k in self.addPriors.keys()] | ||
|
||
self.priors.update({key : self.addPriors[key] for key in AddKeys}) | ||
|
||
# Core rotation prior | ||
self.priors['nurot_c'] = dist.uniform(loc=-2., scale=2.) | ||
|
||
self.priors['nurot_e'] = dist.uniform(loc=-2., scale=2.) | ||
|
||
# The inclination prior is a sine truncated between 0, and pi/2. | ||
self.priors['inc'] = dist.truncsine() | ||
self.priors['inc'] = dist.truncsine() | ||
|
||
# override priors | ||
AddKeys = [k for k in self.variables if k in self.addPriors.keys()] | ||
self.priors.update({key : self.addPriors[key] for key in AddKeys}) | ||
|
||
|
||
def model(self, thetaU,): | ||
""" | ||
|
@@ -1085,19 +1086,18 @@ def setPriors(self,): | |
self.DR.logpdf[i], | ||
self.DR.cdf[i]) | ||
|
||
AddKeys = [k for k in self.variables if k in self.addPriors.keys()] | ||
|
||
self.priors.update({key : self.addPriors[key] for key in AddKeys}) | ||
|
||
self.priors['q'] = dist.uniform(loc=0.01, scale=0.6) | ||
|
||
# Core rotation prior | ||
self.priors['nurot_c'] = dist.uniform(loc=-2., scale=3.) | ||
|
||
self.priors['nurot_e'] = dist.uniform(loc=-2., scale=2.) | ||
|
||
# The inclination prior is a sine truncated between 0, and pi/2. | ||
self.priors['inc'] = dist.truncsine() | ||
self.priors['inc'] = dist.truncsine() | ||
|
||
# override priors | ||
AddKeys = [k for k in self.variables if k in self.addPriors.keys()] | ||
self.priors.update({key : self.addPriors[key] for key in AddKeys}) | ||
|
||
|
||
def unpackParams(self, theta): | ||
""" Cast the parameters in a dictionary | ||
|
@@ -1256,6 +1256,7 @@ def nearest(self, nu, nu_target): | |
|
||
return nu_target[jnp.argmin(jnp.abs(nu[:, None] - nu_target[None, :]), axis=1)] | ||
|
||
@partial(jax.jit, static_argnums=(0,)) | ||
def Theta_p(self, nu, Dnu, nu_p): | ||
""" | ||
Compute the p-mode phase function Theta_p. | ||
|
@@ -1279,6 +1280,7 @@ def Theta_p(self, nu, Dnu, nu_p): | |
(nu - self.nearest(nu, nu_p)) / Dnu + jnp.round((self.nearest(nu, nu_p) - nu_p[0]) / Dnu) | ||
) | ||
|
||
@partial(jax.jit, static_argnums=(0,)) | ||
def Theta_g(self, nu, DPi1, nu_g): | ||
""" | ||
Compute the g-mode phase function Theta_g. | ||
|
@@ -1303,6 +1305,7 @@ def Theta_g(self, nu, DPi1, nu_g): | |
(1 / self.nearest(nu, nu_g) - 1 / nu) / DPi1 | ||
) | ||
|
||
@partial(jax.jit, static_argnums=(0,)) | ||
def zeta(self, nu, q, DPi1, Dnu, nu_p, nu_g): | ||
""" | ||
Compute the local mixing fraction zeta. | ||
|
@@ -1334,6 +1337,7 @@ def zeta(self, nu, q, DPi1, Dnu, nu_p, nu_g): | |
|
||
return 1 / (1 + DPi1 / Dnu * nu**2 / q * jnp.sin(Theta_g)**2 / jnp.cos(Theta_p)**2) | ||
|
||
@partial(jax.jit, static_argnums=(0,)) | ||
def zeta_p(self, nu, q, DPi1, Dnu, nu_p): | ||
""" | ||
Compute the mixing fraction zeta using only the p-mode phase function. Agrees with zeta only at the | ||
|
@@ -1361,6 +1365,7 @@ def zeta_p(self, nu, q, DPi1, Dnu, nu_p): | |
|
||
return 1 / (1 + DPi1 / Dnu * nu**2 / (q * jnp.cos(Theta)**2 + jnp.sin(Theta)**2/q)) | ||
|
||
@partial(jax.jit, static_argnums=(0,)) | ||
def zeta_g(self, nu, q, DPi1, Dnu, nu_g): | ||
|
||
""" | ||
|
@@ -1390,6 +1395,7 @@ def zeta_g(self, nu, q, DPi1, Dnu, nu_g): | |
|
||
return 1 / (1 + DPi1 / Dnu * nu**2 * (q * jnp.cos(Theta)**2 + jnp.sin(Theta)**2/q)) | ||
|
||
@partial(jax.jit, static_argnums=(0,)) | ||
def F(self, nu, nu_p, nu_g, Dnu, DPi1, q): | ||
""" | ||
Compute the characteristic function F such that F(nu) = 0 yields eigenvalues. | ||
|
@@ -1417,6 +1423,7 @@ def F(self, nu, nu_p, nu_g, Dnu, DPi1, q): | |
|
||
return jnp.tan(self.Theta_p(nu, Dnu, nu_p)) * jnp.tan(self.Theta_g(nu, DPi1, nu_g)) - q | ||
|
||
@partial(jax.jit, static_argnums=(0,)) | ||
def Fp(self, nu, nu_p, nu_g, Dnu, DPi1, qp=0): | ||
""" | ||
Compute the first derivative dF/dnu of the characteristic function F. | ||
|
@@ -1446,6 +1453,7 @@ def Fp(self, nu, nu_p, nu_g, Dnu, DPi1, qp=0): | |
+ jnp.tan(self.Theta_p(nu, Dnu, nu_p)) / jnp.cos(self.Theta_g(nu, DPi1, nu_g))**2 * jnp.pi / DPi1 / nu**2 | ||
- qp) | ||
|
||
@partial(jax.jit, static_argnums=(0,)) | ||
def Fpp(self, nu, nu_p, nu_g, Dnu, DPi1, qpp=0): | ||
""" | ||
Compute the second derivative d^2F / dnu^2of the characteristic function F. | ||
|
@@ -1477,6 +1485,7 @@ def Fpp(self, nu, nu_p, nu_g, Dnu, DPi1, qpp=0): | |
+ 2 / jnp.cos(self.Theta_p(nu, Dnu, nu_p))**2 * jnp.pi / Dnu / jnp.cos(self.Theta_g(nu, DPi1, nu_g))**2 * jnp.pi / DPi1 / nu**2 | ||
- qpp) | ||
|
||
@partial(jax.jit, static_argnums=(0,5)) | ||
def halley_iteration(self, x, y, yp, ypp, lmbda=1.): | ||
""" | ||
Perform Halley's method (2nd order Householder) iteration, with damping | ||
|
@@ -1501,6 +1510,7 @@ def halley_iteration(self, x, y, yp, ypp, lmbda=1.): | |
""" | ||
return x - lmbda * 2 * y * yp / (2 * yp * yp - y * ypp) | ||
|
||
@partial(jax.jit, static_argnums=(0,6)) | ||
def couple(self, nu_p, nu_g, q_p, q_g, DPi1, lmbda=.5): | ||
""" | ||
Solve the characteristic equation using Halley's method to couple | ||
|
@@ -1528,21 +1538,20 @@ def couple(self, nu_p, nu_g, q_p, q_g, DPi1, lmbda=.5): | |
num : array-like | ||
Array of mixed mode frequencies. | ||
""" | ||
|
||
num_p = jnp.copy(nu_p) | ||
|
||
num_g = jnp.copy(nu_g) | ||
|
||
for _ in range(self.rootiter): | ||
num_p = self.halley_iteration(num_p, | ||
def _body(i, x0): | ||
num_p, num_g = x0 | ||
a = self.halley_iteration(num_p, | ||
self.F(num_p, nu_p, nu_g, self.obs['dnu'][0], DPi1, q_p), | ||
self.Fp(num_p, nu_p, nu_g, self.obs['dnu'][0], DPi1), | ||
self.Fpp(num_p, nu_p, nu_g, self.obs['dnu'][0], DPi1), lmbda=lmbda) | ||
num_g = self.halley_iteration(num_g, | ||
b = self.halley_iteration(num_g, | ||
self.F(num_g, nu_p, nu_g, self.obs['dnu'][0], DPi1, q_g), | ||
self.Fp(num_g, nu_p, nu_g, self.obs['dnu'][0], DPi1), | ||
self.Fpp(num_g, nu_p, nu_g, self.obs['dnu'][0], DPi1), lmbda=lmbda) | ||
return a, b | ||
|
||
num_p, num_g = jax.lax.fori_loop(0, self.rootiter, _body, (nu_p, nu_g)) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Does this improve the compile-time? We of course also need to verify that results don't change. |
||
return jnp.append(num_p, num_g) | ||
|
||
def parseSamples(self, smp, Nmax=5000): | ||
|
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -116,24 +116,14 @@ def runl20model(self, progress=True, dynamic=False, minSamples=5000, sampler_kwa | |
|
||
return self.l20result | ||
|
||
def runl1model(self, progress=True, dynamic=False, minSamples=5000, sampler_kwargs={}, logl_kwargs={}, model='auto', PCAsamples=500, PCAdims=7, **kwargs): | ||
def makel1model(self, model='auto', PCAsamples=500, PCAdims=7, **kwargs): | ||
""" | ||
Runs the l1 model on the selected spectrum. | ||
Construct a model for the l = 1 residual power spectrum. | ||
|
||
Should follow the l20 model run. | ||
|
||
Parameters | ||
---------- | ||
progress : bool, optional | ||
Whether to show progress during the model run. Default is True. | ||
dynamic : bool, optional | ||
Whether to use dynamic nested sampling. Default is False (static nested sampling). | ||
minSamples : int, optional | ||
The minimum number of samples to generate. Default is 5000. | ||
sampler_kwargs : dict, optional | ||
Additional keyword arguments for the sampler. Default is an empty dictionary. | ||
logl_kwargs : dict, optional | ||
Additional keyword arguments for the log-likelihood function. Default is an empty dictionary. | ||
model : str | ||
Choice of which model to use for estimating the l=1 mode locations. Choices are MS, SG, RGB models. | ||
PCAsamples : int, optional | ||
|
@@ -191,6 +181,30 @@ def runl1model(self, progress=True, dynamic=False, minSamples=5000, sampler_kwar | |
modelChoice='simple') | ||
else: | ||
raise ValueError(f'Model {model} is invalid. Please use either MS, SG or RGB.') | ||
|
||
def runl1model(self, progress=True, dynamic=False, minSamples=5000, sampler_kwargs={}, logl_kwargs={}, **kwargs): | ||
''' | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Need to use " " " instead of ' ' '. |
||
Run an l = 1 model of the residual power spectum. | ||
|
||
Keyword arguments not listed below will be passed to self.makel1model, | ||
in the even that self.l1model is not yet defined. | ||
|
||
Parameters | ||
---------- | ||
progress : bool, optional | ||
Whether to show progress during the model run. Default is True. | ||
dynamic : bool, optional | ||
Whether to use dynamic nested sampling. Default is False (static nested sampling). | ||
minSamples : int, optional | ||
The minimum number of samples to generate. Default is 5000. | ||
sampler_kwargs : dict, optional | ||
Additional keyword arguments for the sampler. Default is an empty dictionary. | ||
logl_kwargs : dict, optional | ||
Additional keyword arguments for the log-likelihood function. Default is an empty dictionary. | ||
''' | ||
|
||
if not hasattr(self, 'l1model'): | ||
self.makel1model(**kwargs) | ||
|
||
self.l1Samples = self.l1model.runSampler(progress=progress, | ||
dynamic=dynamic, | ||
|
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think for changing the order of things here you'll need to check that it doesn't influence the sampling.
I remember having problems with that at some point, where I thought using dictionaries should have solved this, but apparently it didn't.
Just so we aren't sampling nurot_e when we think it's something else like eps_g.