added sample_filter_outputs utility and accompanying simple tests #526
This looks great to me, just a small question about the use of modecontext here
Also you need to rebase :)
Pull Request Overview
This PR adds a sample_filter_outputs utility method to the StateSpace class that enables users to sample filtered, predicted, smoothed, and observed states and covariances from fitted models. This aligns with PyMC's workflow pattern of sample → sample_posterior_predictive.
- Adds sample_filter_outputs method to StateSpace class for sampling various filter outputs
- Includes validation logic to ensure requested filter output names are valid
- Adds comprehensive tests covering basic functionality and error handling
Reviewed Changes
Copilot reviewed 2 out of 2 changed files in this pull request and generated 3 comments.
File | Description |
---|---|
pymc_extras/statespace/core/statespace.py | Implements the main sample_filter_outputs method with validation and output sampling logic |
tests/statespace/core/test_statespace.py | Adds test cases for the new utility including positive and negative test scenarios |
Comments suppressed due to low confidence (1)
tests/statespace/core/test_statespace.py:1045
- The error message test hardcodes the expected array string representation, which may be fragile across different NumPy versions or platforms. Consider using a more flexible assertion that checks for key parts of the error message instead of the exact format.
msg = "['filter_covariances' 'filter_states'] not a valid filter output name!"
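One way to make that assertion less brittle, sketched here with the standard library only, is to check that each offending name appears in the message instead of matching the whole string. The `validate_names` helper and `VALID_NAMES` set are hypothetical stand-ins for illustration, not the PR's code.

```python
# Hypothetical stand-in for the validation inside sample_filter_outputs.
VALID_NAMES = {"filtered_states", "filtered_covariances"}


def validate_names(names):
    """Raise ValueError listing every name that is not a known filter output."""
    unknown = sorted(set(names) - VALID_NAMES)
    if unknown:
        raise ValueError(f"{unknown} not a valid filter output name!")


def error_message_for(names):
    """Return the ValueError message raised for `names`, or None if all are valid."""
    try:
        validate_names(names)
    except ValueError as err:
        return str(err)
    return None


bad_names = ["filter_states", "filter_covariances"]
message = error_message_for(bad_names)

# Assert on the stable parts of the message, not on list/array formatting.
assert message is not None
assert all(name in message for name in bad_names)
assert "not a valid filter output name" in message
```

Because the checks only look for substrings, they keep passing if the container's string representation changes between library versions.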
case "filtered_states" | "predicted_states" | "smoothed_states":
    dims = [TIME_DIM, "state"]
case "filtered_covariances" | "predicted_covariances" | "smoothed_covariances":
    dims = [TIME_DIM, "state", "state_aux"]
case "observed_states":
    dims = [TIME_DIM, "observed_state"]
case "observed_covariances":
The match statement uses hardcoded string literals that are repeated multiple times. Consider defining constants for the filter output names to improve maintainability and reduce the risk of typos.
case "filtered_states" | "predicted_states" | "smoothed_states":
    dims = [TIME_DIM, "state"]
case "filtered_covariances" | "predicted_covariances" | "smoothed_covariances":
    dims = [TIME_DIM, "state", "state_aux"]
case "observed_states":
    dims = [TIME_DIM, "observed_state"]
case "observed_covariances":

case FILTERED_STATES | PREDICTED_STATES | SMOOTHED_STATES:
    dims = [TIME_DIM, "state"]
case FILTERED_COVARIANCES | PREDICTED_COVARIANCES | SMOOTHED_COVARIANCES:
    dims = [TIME_DIM, "state", "state_aux"]
case OBSERVED_STATES:
    dims = [TIME_DIM, "observed_state"]
case OBSERVED_COVARIANCES:
A dictionary mapping cases to dims might be better than the casework here. There isn't already one in constants?
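A minimal sketch of what such a lookup could look like, assuming `TIME_DIM` is the string "time" and using a hypothetical `OUTPUT_DIMS` dictionary (the real constant in the package may be shaped differently). The dims are copied from the match statement above; `observed_covariances` is left out because its dims are not shown in the excerpt.

```python
# Assumed value for illustration; the real constant lives in the package.
TIME_DIM = "time"

# Hypothetical name -> dims lookup standing in for the match/case block.
OUTPUT_DIMS = {
    "filtered_states": (TIME_DIM, "state"),
    "predicted_states": (TIME_DIM, "state"),
    "smoothed_states": (TIME_DIM, "state"),
    "filtered_covariances": (TIME_DIM, "state", "state_aux"),
    "predicted_covariances": (TIME_DIM, "state", "state_aux"),
    "smoothed_covariances": (TIME_DIM, "state", "state_aux"),
    "observed_states": (TIME_DIM, "observed_state"),
}


def dims_for(name):
    """Look up the dims for a filter output, failing loudly on unknown names."""
    try:
        return OUTPUT_DIMS[name]
    except KeyError:
        raise ValueError(f"{name!r} not a valid filter output name!") from None


print(dims_for("smoothed_covariances"))
```

A single dictionary also gives the set of valid names for free (`OUTPUT_DIMS.keys()`), so validation and dim assignment cannot drift apart.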
There was a dictionary in constants.py, variable name FILTER_OUTPUT_DIMS, but the key names are singular whereas the returned names from both kalman_(filter|smoother).build_graph() are plural. I handled this inside the method, but I am wondering if you want to consolidate this either in constants.py or inside the kalman_(filter|smoother).build_graph() methods?
Barf. This is the kind of inconsistency that we really need to stomp out. Do you see any reason why there should be a singular version and a plural version? My guess is that I just did it out of sloppiness. If you don't, can you pick one and use it everywhere?
My preference would be plural. The logic is that the variable (the symbolic output) really is many states. On the other hand, dimensions should be singular, because it's just one dimension. Example: the "state" dimension is a label for the 1st dimension of a (100, 5) tensor, vs the "filtered_states" object which is a (100, 5) tensor concatenating the evolution of 5 states over 100 timesteps.
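The naming convention described above can be made concrete with a tiny stdlib-only sketch. The dimension name "time" is an assumption here; only "state" and "filtered_states" come from the comment.

```python
# A (100, 5) tensor of filtered states: 100 timesteps, 5 states.
variable_name = "filtered_states"  # plural: the object holds many states
dims = ("time", "state")           # singular: each entry labels one axis
shape = (100, 5)

# Pair each dimension label with the length of the axis it names.
axis_lengths = dict(zip(dims, shape))
print(variable_name, axis_lengths)
```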
1. Rebased from upstream
2. Added handling for when the filter_output param is passed in as a str
3. Removed case statement in favor of the dictionary mapping that already exists in conf.py
0d4df37 to 5b064d4
Hey @jessegrabowski, I updated some of the constants in
Dims should be singular, I have strong feelings on that. For the matrix names, I have less of a strong preference. On one hand, If you agree, that would mean all of the things you identified there would stay as-is I think.
I agree that the singular dim names sound better. No objections from me about keeping the rest as-is! I think the main issue of having the same thing named differently is now resolved 🤞
A few final nitpicks then let's merge this! It's looking really great.
# Filter output names are singular in constants.py but are returned as plural from kalman_.build_graph()
# filter_output_dims_mapping = {}
# for k in FILTER_OUTPUT_DIMS.keys():
#     filter_output_dims_mapping[k + "s"] = FILTER_OUTPUT_DIMS[k]
Remove
oops! Sorry about this. That was a careless oversight. I will clean that up right away!
else:
    unknown_filter_output_names = np.setdiff1d(
        filter_output_names, [x.name for x in all_filter_outputs]
    )
    if unknown_filter_output_names.size > 0:
        raise ValueError(
            f"{unknown_filter_output_names} not a valid filter output name!"
        )
    filter_output_names = [
        x for x in all_filter_outputs if x.name in filter_output_names
    ]
Move the input validation up to the top, so we fail quickly without doing any work if the user passes invalid names
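A rough sketch of the fail-fast ordering being asked for, not the PR's implementation: `sample_outputs_sketch` and `fake_expensive_step` are hypothetical names, and the "expensive step" is a stand-in for building the graph and sampling.

```python
def sample_outputs_sketch(requested, available, expensive_step):
    """Validate `requested` against `available` before running `expensive_step`."""
    # Accept a single name as well as a list of names.
    if isinstance(requested, str):
        requested = [requested]

    # Fail before any expensive work happens.
    unknown = [name for name in requested if name not in available]
    if unknown:
        raise ValueError(f"{unknown} not a valid filter output name!")

    return expensive_step(requested)


calls = []


def fake_expensive_step(names):
    """Record each invocation so the ordering can be observed."""
    calls.append(list(names))
    return {name: None for name in names}


available = ["filtered_states", "smoothed_states"]

try:
    sample_outputs_sketch("filter_states", available, fake_expensive_step)
except ValueError:
    pass

# The expensive step never ran for the invalid request.
assert calls == []

result = sample_outputs_sketch("smoothed_states", available, fake_expensive_step)
assert list(result) == ["smoothed_states"]
```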
frozen_model = freeze_dims_and_data(m)
with frozen_model:
    idata_filter = pm.sample_posterior_predictive(
nit: no need for an intermediate variable here, just directly return
with frozen_model:
    idata_filter = pm.sample_posterior_predictive(
        idata if group == "posterior" else idata.prior,
        var_names=[x.name for x in frozen_model.deterministics],
just use filter_output_names here. I'm not sure anything could go wrong with your approach, but it's an unnecessary extra bit of complexity.
…intermediate variables
@@ -1684,6 +1684,21 @@ def sample_filter_outputs(
if isinstance(filter_output_names, str):
    filter_output_names = [filter_output_names]

drop_keys = {"predicted_observed_states", "predicted_observed_covariances"}
I think we shouldn't treat these as special (even though I agree it's silly to ask for them). I'd be confused if I tried to ask for them and it said it's not a valid filter output name.
Having everything in one place is convenient, even if it's duplicative.
Okay, yeah I agree with you.
I just wasn't sure because in constants.py, FILTER_OUTPUT_DIMS has predicted_observed_states and predicted_observed_covariances, but the output from kalman_filter.build_graph() doesn't have predicted_observed_states and predicted_observed_covariances. It seems like these are named observed_states and observed_covariances.
Should I change the names in constants.py to match the returned names from kalman_filter.build_graph()?
Yes, these should be consistent. But where does the name change currently happen between the filter and the idata? Maybe this is an issue for another PR.
@jessegrabowski, I believe this happens in _postprocess_scan_results() in kalman_filter.py. It looks like the names of the filter outputs are hardcoded in there.
If you want to make this consistent in this PR I have no objection. I don't have a good sense if you should change FILTER_OUTPUT_DIMS to match the output names, or change the output names to match the FILTER_OUTPUT_DIMS. I'll defer to you if you have a sense of which one is better.
Okay, I will do it in this PR because I think it is somewhat related. I think the names should match whatever we put in FILTER_OUTPUT_DIMS
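A mismatch like the one discussed in this thread could be caught by a small consistency check. The sketch below is illustrative only: `check_names_consistent` is a hypothetical helper, and the sample data is a cut-down imitation of the names mentioned above, with "time" assumed as the time dimension.

```python
def check_names_consistent(dims_by_name, graph_output_names):
    """Return (missing_from_dims, missing_from_graph) as sorted lists."""
    dims_names = set(dims_by_name)
    graph_names = set(graph_output_names)
    return sorted(graph_names - dims_names), sorted(dims_names - graph_names)


# Illustrative data mirroring the mismatch described in this thread.
dims_by_name = {
    "filtered_states": ("time", "state"),
    "predicted_observed_states": ("time", "observed_state"),
}
graph_output_names = ["filtered_states", "observed_states"]

missing_from_dims, missing_from_graph = check_names_consistent(
    dims_by_name, graph_output_names
)
print("graph outputs without dims:", missing_from_dims)
print("dims keys never produced:", missing_from_graph)
```

Both lists being empty would indicate that the dims dictionary and the graph outputs agree on naming.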
I made one last nitpick, but it's not a blocker. Feel free to address or not, then merge :)
…pdated sample_filter_outputs to allow sampling any filter outputs defined in constants.py
Hey @jessegrabowski, I don't believe I have permissions to merge.
Great work as always :D
Thank you, Jesse! Always happy to help!
This addresses #521 and allows state space users to sample filtered, predicted, smoothed, and observed states and covariances using a utility that is consistent with the PyMC workflow of sample -> sample_posterior_predictive.