added sample_filter_outputs utility and accompanying simple tests #526

Merged
merged 5 commits into from
Jul 28, 2025

Dekermanjian
Contributor

This addresses #521. It allows state space users to sample filtered, predicted, smoothed, and observed states and covariances using a utility that is consistent with the PyMC workflow of sample -> sample_posterior_predictive.

Member

@jessegrabowski jessegrabowski left a comment

This looks great to me, just a small question about the use of modecontext here

@jessegrabowski jessegrabowski requested a review from Copilot July 21, 2025 10:23
@jessegrabowski
Member

Also you need to rebase :)

@Copilot Copilot AI left a comment

Pull Request Overview

This PR adds a sample_filter_outputs utility method to the StateSpace class that enables users to sample filtered, predicted, smoothed, and observed states and covariances from fitted models. This aligns with PyMC's workflow pattern of sample → sample_posterior_predictive.

  • Adds sample_filter_outputs method to StateSpace class for sampling various filter outputs
  • Includes validation logic to ensure requested filter output names are valid
  • Adds comprehensive tests covering basic functionality and error handling

Reviewed Changes

Copilot reviewed 2 out of 2 changed files in this pull request and generated 3 comments.

File | Description
pymc_extras/statespace/core/statespace.py | Implements the main sample_filter_outputs method with validation and output sampling logic
tests/statespace/core/test_statespace.py | Adds test cases for the new utility including positive and negative test scenarios
Comments suppressed due to low confidence (1)

tests/statespace/core/test_statespace.py:1045

  • The error message test hardcodes the expected array string representation, which may be fragile across different NumPy versions or platforms. Consider using a more flexible assertion that checks for key parts of the error message instead of the exact format.
    msg = "['filter_covariances' 'filter_states'] not a valid filter output name!"
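One way to make such an assertion less brittle is to check for the key parts of the message instead of the exact array repr. The following is a minimal sketch using only the standard library; `validate_names` and `VALID_NAMES` are hypothetical stand-ins and not the PR's code. In a pytest suite the same idea would typically be expressed with `pytest.raises(ValueError, match=...)`.

```python
import re

# Hypothetical subset of valid names, for illustration only.
VALID_NAMES = {"filtered_states", "filtered_covariances"}


def validate_names(names):
    """Hypothetical stand-in for the validation under test."""
    unknown = sorted(set(names) - VALID_NAMES)
    if unknown:
        raise ValueError(f"{unknown} not a valid filter output name!")


def error_message_for(names):
    """Return the ValueError message raised for `names`, or None if valid."""
    try:
        validate_names(names)
    except ValueError as err:
        return str(err)
    return None


message = error_message_for(["filter_states", "filter_covariances"])

# Check the key parts rather than the exact repr of the offending names.
assert message is not None
assert "filter_states" in message and "filter_covariances" in message
assert re.search(r"not a valid filter output name", message)
```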

Comment on lines 1745 to 1751
case "filtered_states" | "predicted_states" | "smoothed_states":
    dims = [TIME_DIM, "state"]
case "filtered_covariances" | "predicted_covariances" | "smoothed_covariances":
    dims = [TIME_DIM, "state", "state_aux"]
case "observed_states":
    dims = [TIME_DIM, "observed_state"]
case "observed_covariances":

Copilot AI Jul 21, 2025

The match statement uses hardcoded string literals that are repeated multiple times. Consider defining constants for the filter output names to improve maintainability and reduce the risk of typos.

Suggested change
- case "filtered_states" | "predicted_states" | "smoothed_states":
-     dims = [TIME_DIM, "state"]
- case "filtered_covariances" | "predicted_covariances" | "smoothed_covariances":
-     dims = [TIME_DIM, "state", "state_aux"]
- case "observed_states":
-     dims = [TIME_DIM, "observed_state"]
- case "observed_covariances":
+ case FILTERED_STATES | PREDICTED_STATES | SMOOTHED_STATES:
+     dims = [TIME_DIM, "state"]
+ case FILTERED_COVARIANCES | PREDICTED_COVARIANCES | SMOOTHED_COVARIANCES:
+     dims = [TIME_DIM, "state", "state_aux"]
+ case OBSERVED_STATES:
+     dims = [TIME_DIM, "observed_state"]
+ case OBSERVED_COVARIANCES:

Copilot uses AI. Check for mistakes.

Member

A dictionary mapping cases to dims might be better than the casework here. There isn't already one in constants?
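
A dictionary lookup of this kind could look roughly like the sketch below. The mapping name, the helper, and the value of `TIME_DIM` are assumptions for illustration and do not reproduce the project's `constants.py`; the entry for `observed_covariances` is left out because its dims are not shown in the snippet under review.

```python
TIME_DIM = "time"  # assumed value, for illustration only

# Hypothetical mapping mirroring the dims shown in the match statement.
FILTER_OUTPUT_DIMS_SKETCH = {
    "filtered_states": (TIME_DIM, "state"),
    "predicted_states": (TIME_DIM, "state"),
    "smoothed_states": (TIME_DIM, "state"),
    "filtered_covariances": (TIME_DIM, "state", "state_aux"),
    "predicted_covariances": (TIME_DIM, "state", "state_aux"),
    "smoothed_covariances": (TIME_DIM, "state", "state_aux"),
    "observed_states": (TIME_DIM, "observed_state"),
}


def lookup_dims(name):
    """Return the dims for a filter output name, or raise for unknown names."""
    try:
        return list(FILTER_OUTPUT_DIMS_SKETCH[name])
    except KeyError:
        raise ValueError(f"{name!r} not a valid filter output name!") from None
```

With a single mapping, the set of valid names and their dims live in one place, so validation and dim assignment cannot drift apart.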

Contributor Author

@Dekermanjian Dekermanjian Jul 21, 2025

There was a dictionary in constants.py named FILTER_OUTPUT_DIMS, but its key names are singular whereas the names returned from both kalman_(filter|smoother).build_graph() are plural. I handled this inside the method, but I am wondering if you want to consolidate this either in constants.py or inside the kalman_(filter|smoother).build_graph() methods?

Member

Barf. This is the kind of inconsistency that we really need to stomp out. Do you see any reason why there should be a singular version and a plural version? My guess is that I just did it out of sloppiness. If you don't, can you pick one and use it everywhere?

My preference would be plural. The logic is that the variable (the symbolic output) really is many states. On the other hand, dimensions should be singular, because it's just one dimension. Example: the "state" dimension is a label for the 1st dimension of a (100, 5) tensor, vs the "filtered_states" object which is a (100, 5) tensor concatenating the evolution of 5 states over 100 timesteps.
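
The naming convention described here can be illustrated with a small sketch in plain Python. The nested list stands in for the (100, 5) tensor, and all names are illustrative only.

```python
n_timesteps, n_states = 100, 5

# Plural variable name: the object holds many states over many timesteps.
filtered_states = [[0.0] * n_states for _ in range(n_timesteps)]

# Singular dimension names: each one labels a single axis of the object.
dims = ("time", "state")
shape = (len(filtered_states), len(filtered_states[0]))

# Pair each axis label with its length.
labeled_shape = dict(zip(dims, shape))
```

Here `labeled_shape` pairs "time" with 100 and "state" with 5: one label per axis, while the object itself is named for the many states it contains.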

2. Added handle for when filter_output param is passed in as a str
3. removed case statement in favor of dictionary mapping that already exists in conf.py
@Dekermanjian Dekermanjian force-pushed the filter_outputs_utility branch from 0d4df37 to 5b064d4 Compare July 21, 2025 21:50
@Dekermanjian
Contributor Author

Hey @jessegrabowski, I updated some of the constants in constants.py to be plural and updated the tests for any mismatches. There were a few constants that I was not sure about, so I wanted to ask you first before I change them. These are:

ALL_STATE_DIM = "state(?s)"
ALL_STATE_AUX_DIM = "state(?s)_aux"
OBS_STATE_DIM = "observed_state(?s)"
OBS_STATE_AUX_DIM = "observed_state(?s)_aux"

NEVER_TIME_VARYING = ["initial_state(?s)", "initial_state(?s)_cov(?s)"]
VECTOR_VALUED = ["initial_state(?s)", "state(?s)_intercept(?s)", "obs_intercept(?s)"]

LONG_MATRIX_NAMES = [
    "initial_state(?s)",
    "initial_state(?s)_cov(?s)",
    "state(?s)_intercept(?s)",
    "obs_intercept(?s)",
    "obs_cov(?s)",
    "state_cov(?s)",
]

@jessegrabowski
Member

Dims should be singular, I have strong feelings on that.

For the matrix names, I have less of a strong preference. On one hand, x0 is a vector of states, but on the other hand, it is the state vector. So it could go either way. I guess I lean to not changing things in that case?

If you agree, that would mean all of the things you identified there would stay as-is I think.

@Dekermanjian
Contributor Author

I agree that the singular dim names sound better. No objections from me about keeping the rest as-is! I think the main issue of having the same thing named differently is now resolved 🤞

Member

@jessegrabowski jessegrabowski left a comment

A few final nitpicks then let's merge this! It's looking really great.

Comment on lines 1729 to 1732
# Filter output names are singular in constants.py but are returned as plural from kalman_.build_graph()
# filter_output_dims_mapping = {}
# for k in FILTER_OUTPUT_DIMS.keys():
# filter_output_dims_mapping[k + "s"] = FILTER_OUTPUT_DIMS[k]

Member

Remove

Contributor Author

oops! Sorry about this. That was a careless oversight. I will clean that up right away!

Comment on lines 1742 to 1752
else:
    unknown_filter_output_names = np.setdiff1d(
        filter_output_names, [x.name for x in all_filter_outputs]
    )
    if unknown_filter_output_names.size > 0:
        raise ValueError(
            f"{unknown_filter_output_names} not a valid filter output name!"
        )
    filter_output_names = [
        x for x in all_filter_outputs if x.name in filter_output_names
    ]

Member

Move the input validation up to the top, so we fail quickly without doing any work if the user passes invalid names
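
A fail-fast version of the validation might be structured like the sketch below: normalize the argument, check it against the known names, and only then proceed to the expensive work. The helper name, the `VALID` list, and the treatment of `None` as "all outputs" are assumptions for illustration, not the merged implementation.

```python
# Hypothetical list of valid names, for illustration only.
VALID = ["filtered_states", "predicted_states", "smoothed_states"]


def normalize_and_validate(filter_output_names, valid_names):
    """Hypothetical helper: return the requested names as a list, or raise."""
    # Assumption: None means "all outputs".
    if filter_output_names is None:
        return list(valid_names)
    # Accept a single name passed as a plain string.
    if isinstance(filter_output_names, str):
        filter_output_names = [filter_output_names]
    unknown = [n for n in filter_output_names if n not in valid_names]
    if unknown:
        raise ValueError(f"{unknown} not a valid filter output name!")
    return list(filter_output_names)


# Called first, before any graph is built or any sampling is done.
requested = normalize_and_validate("filtered_states", VALID)
```

Because the check runs before any model graph is constructed, a typo in a name surfaces immediately instead of after the costly setup.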


frozen_model = freeze_dims_and_data(m)
with frozen_model:
    idata_filter = pm.sample_posterior_predictive(

Member

nit: no need for an intermediate variable here, just directly return

with frozen_model:
    idata_filter = pm.sample_posterior_predictive(
        idata if group == "posterior" else idata.prior,
        var_names=[x.name for x in frozen_model.deterministics],

Member

just use filter_output_names here. I'm not sure anything could go wrong with your approach, but it's an unnecessary extra bit of complexity.

@@ -1684,6 +1684,21 @@ def sample_filter_outputs(
if isinstance(filter_output_names, str):
    filter_output_names = [filter_output_names]

drop_keys = {"predicted_observed_states", "predicted_observed_covariances"}

Member

I think we shouldn't treat these as special (even though I agree it's silly to ask for them). I'd be confused if I tried to ask for them and it said it's not a valid filter output name.

Having everything in one place is convenient, even if it's duplicative.

Contributor Author

Okay, yeah I agree with you.

I just wasn't sure because in constants.py FILTER_OUTPUT_DIMS has predicted_observed_states and predicted_observed_covariances, but the output from kalman_filter.build_graph() doesn't. It seems like these are named observed_states and observed_covariances there.

Should I change the names in constants.py to match the returned names from kalman_filter.build_graph()?

Member

Yes, these should be consistent. But where does the name change currently happen between the filter and the idata? Maybe this is an issue for another PR.

Contributor Author

@jessegrabowski, I believe this happens in _postprocess_scan_results() in kalman_filter.py. It looks like the names of the filter outputs are hardcoded in there.

Member

If you want to make this consistent in this PR I have no objection. I don't have a good sense if you should change FILTER_OUTPUT_DIMS to match the output names, or change the output names to match the FILTER_OUTPUT_DIMS. I'll defer to you if you have a sense of which one is better.

Contributor Author

Okay, I will do it in this PR because I think it is somewhat related. I think the names should match whatever we put in FILTER_OUTPUT_DIMS

Member

@jessegrabowski jessegrabowski left a comment

I made one last nitpick, but it's not a blocker. Feel free to address or not, then merge :)

…pdated sample_filter_outputs to allow sampling any filter outputs defined in constants.py
@Dekermanjian
Contributor Author

> I made one last nitpick, but it's not a blocker. Feel free to address or not, then merge :)

Hey @jessegrabowski, I don't believe I have permissions to merge.

@jessegrabowski jessegrabowski merged commit 24930b5 into pymc-devs:main Jul 28, 2025
17 checks passed
@jessegrabowski
Member

Great work as always :D

@Dekermanjian
Contributor Author

Thank you, Jesse! Always happy to help!
