From 495873fd0998ac338cd180cb8b98f6bcf07e9164 Mon Sep 17 00:00:00 2001 From: Semih Akbayrak Date: Tue, 23 Sep 2025 13:38:47 +0000 Subject: [PATCH 1/4] Covariate support for Dataset class --- pyproject.toml | 2 +- src/causal_validation/data.py | 106 +++++++++- src/causal_validation/validation/placebo.py | 9 +- src/causal_validation/validation/rmspe.py | 6 +- tests/test_causal_validation/test_data.py | 222 +++++++++++++++++--- 5 files changed, 295 insertions(+), 50 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index accaaef..2cf98ea 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -171,7 +171,7 @@ select = [ "TID", "ISC", ] -ignore = ["F722"] +ignore = ["F722", "PLW1641"] [tool.ruff.format] quote-style = "double" diff --git a/src/causal_validation/data.py b/src/causal_validation/data.py index 057acfa..a81f13a 100644 --- a/src/causal_validation/data.py +++ b/src/causal_validation/data.py @@ -21,15 +21,45 @@ @dataclass class Dataset: + """A causal inference dataset containing pre/post intervention observations + and optional associated covariates. + + Attributes: + Xtr: Pre-intervention control unit observations (N x D) + Xte: Post-intervention control unit observations (M x D) + ytr: Pre-intervention treated unit observations (N x 1) + yte: Post-intervention treated unit observations (M x 1) + _start_date: Start date for time indexing + Ptr: Pre-intervention control unit covariates (N x D x F) + Pte: Post-intervention control unit covariates (M x D x F) + Rtr: Pre-intervention treated unit covariates (N x 1 x F) + Rte: Post-intervention treated unit covariates (M x 1 x F) + counterfactual: Optional counterfactual outcomes (M x 1) + synthetic: Optional synthetic control outcomes (M x 1). + This is weighted combination of control units + minimizing a distance-based error w.r.t. the + treated in pre-intervention period. 
+ _name: Optional name identifier for the dataset + """ Xtr: Float[np.ndarray, "N D"] Xte: Float[np.ndarray, "M D"] ytr: Float[np.ndarray, "N 1"] yte: Float[np.ndarray, "M 1"] _start_date: dt.date + Ptr: tp.Optional[Float[np.ndarray, "N D F"]] = None + Pte: tp.Optional[Float[np.ndarray, "M D F"]] = None + Rtr: tp.Optional[Float[np.ndarray, "N 1 F"]] = None + Rte: tp.Optional[Float[np.ndarray, "M 1 F"]] = None counterfactual: tp.Optional[Float[np.ndarray, "M 1"]] = None synthetic: tp.Optional[Float[np.ndarray, "M 1"]] = None _name: str = None + def __post_init__(self): + covariates = [self.Ptr, self.Pte, self.Rtr, self.Rte] + self.has_covariates = all(cov is not None for cov in covariates) + if not self.has_covariates: + assert all(cov is None for cov in covariates) + def to_df( self, index_start: str = dt.date(year=2023, month=1, day=1) ) -> pd.DataFrame: @@ -59,6 +89,13 @@ def n_units(self) -> int: def n_timepoints(self) -> int: return self.n_post_intervention + self.n_pre_intervention + @property + def n_covariates(self) -> int: + if self.has_covariates: + return self.Ptr.shape[2] + else: + return 0 + @property def control_units(self) -> Float[np.ndarray, "{self.n_timepoints} {self.n_units}"]: return np.vstack([self.Xtr, self.Xte]) @@ -67,6 +104,26 @@ def control_units(self) -> Float[np.ndarray, "{self.n_timepoints} {self.n_units} def treated_units(self) -> Float[np.ndarray, "{self.n_timepoints} 1"]: return np.vstack([self.ytr, self.yte]) + @property + def control_covariates( + self, + ) -> tp.Optional[ + Float[np.ndarray, "{self.n_timepoints} {self.n_units} {self.n_covariates}"] + ]: + if self.has_covariates: + return np.vstack([self.Ptr, self.Pte]) + else: + return None + + @property + def treated_covariates( + self, + ) -> tp.Optional[Float[np.ndarray, "{self.n_timepoints} 1 {self.n_covariates}"]]: + if self.has_covariates: + return np.vstack([self.Rtr, self.Rte]) + else: + return None + @property def pre_intervention_obs( self, @@ -79,6 +136,32 @@ def 
post_intervention_obs( ) -> tp.Tuple[Float[np.ndarray, "M D"], Float[np.ndarray, "M 1"]]: return self.Xte, self.yte + @property + def pre_intervention_covariates( + self, + ) -> tp.Optional[ + tp.Tuple[ + Float[np.ndarray, "N D F"], Float[np.ndarray, "N 1 F"], + ] + ]: + if self.has_covariates: + return self.Ptr, self.Rtr + else: + return None + + @property + def post_intervention_covariates( + self, + ) -> tp.Optional[ + tp.Tuple[ + Float[np.ndarray, "M D F"], Float[np.ndarray, "M 1 F"], + ] + ]: + if self.has_covariates: + return self.Pte, self.Rte + else: + return None + @property def full_index(self) -> DatetimeIndex: return self._get_index(self._start_date) @@ -97,7 +180,12 @@ def get_index(self, period: InterventionTypes) -> DatetimeIndex: return self.full_index def _get_columns(self) -> tp.List[str]: - colnames = ["T"] + [f"C{i}" for i in range(self.n_units)] + if self.has_covariates: + colnames = ["T"] + [f"C{i}" for i in range(self.n_units)] + [ + f"F{i}" for i in range(self.n_covariates) + ] + else: + colnames = ["T"] + [f"C{i}" for i in range(self.n_units)] return colnames def _get_index(self, start_date: dt.date) -> DatetimeIndex: @@ -116,7 +204,10 @@ def inflate(self, inflation_vals: Float[np.ndarray, "M 1"]) -> Dataset: Xtr, ytr = [deepcopy(i) for i in self.pre_intervention_obs] Xte, yte = [deepcopy(i) for i in self.post_intervention_obs] inflated_yte = yte * inflation_vals - return Dataset(Xtr, Xte, ytr, inflated_yte, self._start_date, yte) + return Dataset( + Xtr, Xte, ytr, inflated_yte, self._start_date, + self.Ptr, self.Pte, self.Rtr, self.Rte, yte, self.synthetic, self._name + ) def __eq__(self, other: Dataset) -> bool: ytr = np.allclose(self.ytr, other.ytr) @@ -151,14 +242,21 @@ def _slots(self) -> tp.Dict[str, int]: def drop_unit(self, idx: int) -> Dataset: Xtr = np.delete(self.Xtr, [idx], axis=1) Xte = np.delete(self.Xte, [idx], axis=1) + Ptr = np.delete(self.Ptr, [idx], axis=1) if self.Ptr is not None else None + Pte = np.delete(self.Pte, 
[idx], axis=1) if self.Pte is not None else None return Dataset( Xtr, Xte, self.ytr, self.yte, self._start_date, + Ptr, + Pte, + self.Rtr, + self.Rte, self.counterfactual, self.synthetic, + self._name, ) def to_placebo_data(self, to_treat_idx: int) -> Dataset: @@ -212,5 +310,7 @@ def reassign_treatment( Xtr = data.Xtr Xte = data.Xte return Dataset( - Xtr, Xte, ytr, yte, data._start_date, data.counterfactual, data.synthetic + Xtr, Xte, ytr, yte, data._start_date, + data.Ptr, data.Pte, data.Rtr, data.Rte, + data.counterfactual, data.synthetic, data._name ) diff --git a/src/causal_validation/validation/placebo.py b/src/causal_validation/validation/placebo.py index b8f7c36..b8143fb 100644 --- a/src/causal_validation/validation/placebo.py +++ b/src/causal_validation/validation/placebo.py @@ -1,7 +1,6 @@ from dataclasses import dataclass import typing as tp -from azcausal.core.effect import Effect import numpy as np import pandas as pd from pandera import ( @@ -11,14 +10,8 @@ ) from rich.progress import ( Progress, - ProgressBar, - track, ) from scipy.stats import ttest_1samp -from tqdm import ( - tqdm, - trange, -) from causal_validation.data import ( Dataset, @@ -108,7 +101,7 @@ def execute(self, verbose: bool = True) -> PlaceboTestResult: "[blue]Datasets", total=n_datasets, visible=verbose ) unit_task = progress.add_task( - f"[green]Control Units", + "[green]Control Units", total=n_control, visible=verbose, ) diff --git a/src/causal_validation/validation/rmspe.py b/src/causal_validation/validation/rmspe.py index 6b541ff..b606722 100644 --- a/src/causal_validation/validation/rmspe.py +++ b/src/causal_validation/validation/rmspe.py @@ -2,18 +2,14 @@ import typing as tp from jaxtyping import Float -import numpy as np import pandas as pd from pandera import ( Check, Column, DataFrameSchema, ) -from rich import box from rich.progress import ( Progress, - ProgressBar, - track, ) from causal_validation.validation.placebo import PlaceboTest @@ -87,7 +83,7 @@ def execute(self, 
verbose: bool = True) -> RMSPETestResult: "[blue]Datasets", total=n_datasets, visible=verbose ) unit_task = progress.add_task( - f"[green]Treatment and Control Units", + "[green]Treatment and Control Units", total=n_control + 1, visible=verbose, ) diff --git a/tests/test_causal_validation/test_data.py b/tests/test_causal_validation/test_data.py index 07554a4..7434bf1 100644 --- a/tests/test_causal_validation/test_data.py +++ b/tests/test_causal_validation/test_data.py @@ -23,6 +23,7 @@ simulate_data, ) from causal_validation.types import InterventionTypes +import datetime as dt MIN_STRING_LENGTH = 1 MAX_STRING_LENGTH = 20 @@ -198,6 +199,10 @@ def test_drop_unit(n_pre: int, n_post: int, n_control: int): assert reduced_data.Xte.shape == desired_shape_Xte assert reduced_data.ytr.shape == desired_shape_ytr assert reduced_data.yte.shape == desired_shape_yte + + assert reduced_data.counterfactual == data.counterfactual + assert reduced_data.synthetic == data.synthetic + assert reduced_data._name == data._name @pytest.mark.parametrize("n_pre, n_post, n_control", [(60, 30, 10), (60, 30, 20)]) @@ -288,37 +293,188 @@ def test_naming_setter(name: str, extra_chars: str): @given( - seeds=st.lists( - elements=st.integers(min_value=1, max_value=1000), min_size=1, max_size=10 - ), - to_name=st.booleans(), + n_pre=st.integers(min_value=10, max_value=100), + n_post=st.integers(min_value=10, max_value=100), + n_control=st.integers(min_value=2, max_value=20), +) +@settings(max_examples=5) +def test_counterfactual_synthetic_attributes(n_pre: int, n_post: int, n_control: int): + constants = TestConstants( + N_POST_TREATMENT=n_post, + N_PRE_TREATMENT=n_pre, + N_CONTROL=n_control, + ) + data = simulate_data(0.0, DEFAULT_SEED, constants=constants) + + assert data.counterfactual is None + assert data.synthetic is None + + counterfactual_vals = np.random.randn(n_post, 1) + synthetic_vals = np.random.randn(n_post, 1) + + data_with_attrs = Dataset( + data.Xtr, data.Xte, data.ytr, data.yte, 
data._start_date, + data.Ptr, data.Pte, data.Rtr, data.Rte, + counterfactual_vals, synthetic_vals, "test_dataset" + ) + + np.testing.assert_array_equal(data_with_attrs.counterfactual, counterfactual_vals) + np.testing.assert_array_equal(data_with_attrs.synthetic, synthetic_vals) + assert data_with_attrs.name == "test_dataset" + + +@given( + n_pre=st.integers(min_value=10, max_value=100), + n_post=st.integers(min_value=10, max_value=100), + n_control=st.integers(min_value=2, max_value=20), +) +@settings(max_examples=5) +def test_inflate_method(n_pre: int, n_post: int, n_control: int): + constants = TestConstants( + N_POST_TREATMENT=n_post, + N_PRE_TREATMENT=n_pre, + N_CONTROL=n_control, + ) + data = simulate_data(0.0, DEFAULT_SEED, constants=constants) + + inflation_vals = np.ones((n_post, 1)) * 1.1 + inflated_data = data.inflate(inflation_vals) + + np.testing.assert_array_equal(inflated_data.Xtr, data.Xtr) + np.testing.assert_array_equal(inflated_data.ytr, data.ytr) + np.testing.assert_array_equal(inflated_data.Xte, data.Xte) + + expected_yte = data.yte * inflation_vals + np.testing.assert_array_equal(inflated_data.yte, expected_yte) + + np.testing.assert_array_equal(inflated_data.counterfactual, data.yte) + + +@given( + n_pre=st.integers(min_value=10, max_value=100), + n_post=st.integers(min_value=10, max_value=100), + n_control=st.integers(min_value=2, max_value=20), +) +@settings(max_examples=5) +def test_control_treated_properties(n_pre: int, n_post: int, n_control: int): + constants = TestConstants( + N_POST_TREATMENT=n_post, + N_PRE_TREATMENT=n_pre, + N_CONTROL=n_control, + ) + data = simulate_data(0.0, DEFAULT_SEED, constants=constants) + + control_units = data.control_units + expected_control = np.vstack([data.Xtr, data.Xte]) + np.testing.assert_array_equal(control_units, expected_control) + assert control_units.shape == (n_pre + n_post, n_control) + + treated_units = data.treated_units + expected_treated = np.vstack([data.ytr, data.yte]) + 
np.testing.assert_array_equal(treated_units, expected_treated) + assert treated_units.shape == (n_pre + n_post, 1) + + +@given( + n_pre=st.integers(min_value=10, max_value=100), + n_post=st.integers(min_value=10, max_value=100), + n_control=st.integers(min_value=2, max_value=20), ) -def test_dataset_container(seeds: tp.List[int], to_name: bool): - datasets = [simulate_data(0.0, s) for s in seeds] - if to_name: - names = [f"D_{idx}" for idx in range(len(datasets))] - else: - names = None - container = DatasetContainer(datasets, names) - - # Test names were correctly assigned - if to_name: - assert container.names == names - else: - assert container.names == [f"Dataset {idx}" for idx in range(len(datasets))] - - # Assert ordering - for idx, dataset in enumerate(container): - assert dataset == datasets[idx] - - # Assert no data was dropped/added - assert len(container) == len(datasets) - - # Test `as_dict()` method preserves order - container_dict = container.as_dict() - for idx, (k, v) in enumerate(container_dict.items()): - if to_name: - assert k == names[idx] - else: - assert k == f"Dataset {idx}" - assert v == datasets[idx] +@settings(max_examples=5) +def test_covariate_properties_without_covariates(n_pre: int, n_post: int, n_control: int): + constants = TestConstants( + N_POST_TREATMENT=n_post, + N_PRE_TREATMENT=n_pre, + N_CONTROL=n_control, + ) + data = simulate_data(0.0, DEFAULT_SEED, constants=constants) + + assert data.has_covariates is False + assert data.control_covariates is None + assert data.treated_covariates is None + assert data.pre_intervention_covariates is None + assert data.post_intervention_covariates is None + assert data.n_covariates == 0 + + +@given( + n_pre=st.integers(min_value=10, max_value=50), + n_post=st.integers(min_value=10, max_value=50), + n_control=st.integers(min_value=2, max_value=10), + n_covariates=st.integers(min_value=1, max_value=5), + Xtr=st.data(), + Xte=st.data(), + ytr=st.data(), + yte=st.data(), + Ptr=st.data(), + 
Pte=st.data(), + Rtr=st.data(), + Rte=st.data(), +) +@settings(max_examples=5) +def test_covariate_properties_with_covariates(n_pre: int, + n_post: int, + n_control: int, + n_covariates: int, + Xtr, + Xte, + ytr, + yte, + Ptr, + Pte, + Rtr, + Rte): + + Xtr = Xtr.draw(st.lists(st.floats(min_value=-10, max_value=10), + min_size=n_pre*n_control, max_size=n_pre*n_control)) + Xtr = np.array(Xtr).reshape(n_pre, n_control) + + Xte = Xte.draw(st.lists(st.floats(min_value=-10, max_value=10), + min_size=n_post*n_control, max_size=n_post*n_control)) + Xte = np.array(Xte).reshape(n_post, n_control) + + ytr = ytr.draw(st.lists(st.floats(min_value=-10, max_value=10), + min_size=n_pre, max_size=n_pre)) + ytr = np.array(ytr).reshape(n_pre, 1) + + yte = yte.draw(st.lists(st.floats(min_value=-10, max_value=10), + min_size=n_post, max_size=n_post)) + yte = np.array(yte).reshape(n_post, 1) + + Ptr = Ptr.draw(st.lists(st.floats(min_value=-10, max_value=10), + min_size=n_pre*n_control*n_covariates, max_size=n_pre*n_control*n_covariates)) + Ptr = np.array(Ptr).reshape(n_pre, n_control, n_covariates) + + Pte = Pte.draw(st.lists(st.floats(min_value=-10, max_value=10), + min_size=n_post*n_control*n_covariates, max_size=n_post*n_control*n_covariates)) + Pte = np.array(Pte).reshape(n_post, n_control, n_covariates) + + Rtr = Rtr.draw(st.lists(st.floats(min_value=-10, max_value=10), + min_size=n_pre*n_covariates, max_size=n_pre*n_covariates)) + Rtr = np.array(Rtr).reshape(n_pre, 1, n_covariates) + + Rte = Rte.draw(st.lists(st.floats(min_value=-10, max_value=10), + min_size=n_post*n_covariates, max_size=n_post*n_covariates)) + Rte = np.array(Rte).reshape(n_post, 1, n_covariates) + + data = Dataset(Xtr, Xte, ytr, yte, dt.date(2023, 1, 1), Ptr, Pte, Rtr, Rte) + + assert data.n_covariates == n_covariates + assert data.has_covariates is True + + control_covariates = data.control_covariates + expected_control_cov = np.vstack([Ptr, Pte]) + np.testing.assert_array_equal(control_covariates, 
expected_control_cov) + assert control_covariates.shape == (n_pre + n_post, n_control, n_covariates) + + treated_covariates = data.treated_covariates + expected_treated_cov = np.vstack([Rtr, Rte]) + np.testing.assert_array_equal(treated_covariates, expected_treated_cov) + assert treated_covariates.shape == (n_pre + n_post, 1, n_covariates) + + pre_cov = data.pre_intervention_covariates + assert pre_cov == (Ptr, Rtr) + + post_cov = data.post_intervention_covariates + assert post_cov == (Pte, Rte) + From 6682af4fbe5c7925d23e0aea230cdb16e293b0f3 Mon Sep 17 00:00:00 2001 From: Semih Akbayrak Date: Tue, 23 Sep 2025 13:47:10 +0000 Subject: [PATCH 2/4] Add dataset container test back --- tests/test_causal_validation/test_data.py | 35 +++++++++++++++++++++++ 1 file changed, 35 insertions(+) diff --git a/tests/test_causal_validation/test_data.py b/tests/test_causal_validation/test_data.py index 7434bf1..70befc4 100644 --- a/tests/test_causal_validation/test_data.py +++ b/tests/test_causal_validation/test_data.py @@ -374,6 +374,41 @@ def test_control_treated_properties(n_pre: int, n_post: int, n_control: int): np.testing.assert_array_equal(treated_units, expected_treated) assert treated_units.shape == (n_pre + n_post, 1) +@given( + seeds=st.lists( + elements=st.integers(min_value=1, max_value=1000), min_size=1, max_size=10 + ), + to_name=st.booleans(), +) +def test_dataset_container(seeds: tp.List[int], to_name: bool): + datasets = [simulate_data(0.0, s) for s in seeds] + if to_name: + names = [f"D_{idx}" for idx in range(len(datasets))] + else: + names = None + container = DatasetContainer(datasets, names) + + # Test names were correctly assigned + if to_name: + assert container.names == names + else: + assert container.names == [f"Dataset {idx}" for idx in range(len(datasets))] + + # Assert ordering + for idx, dataset in enumerate(container): + assert dataset == datasets[idx] + + # Assert no data was dropped/added + assert len(container) == len(datasets) + + # Test 
`as_dict()` method preserves order + container_dict = container.as_dict() + for idx, (k, v) in enumerate(container_dict.items()): + if to_name: + assert k == names[idx] + else: + assert k == f"Dataset {idx}" + assert v == datasets[idx] @given( n_pre=st.integers(min_value=10, max_value=100), From 0dc4c3d44df393dcb11a751da04561705f5df428 Mon Sep 17 00:00:00 2001 From: Semih Akbayrak Date: Tue, 23 Sep 2025 14:15:35 +0000 Subject: [PATCH 3/4] Fix linting errors in tests. --- tests/test_causal_validation/test_data.py | 112 +++++++----------- .../test_validation/test_placebo.py | 2 +- .../test_validation/test_rmspe.py | 2 +- 3 files changed, 42 insertions(+), 74 deletions(-) diff --git a/tests/test_causal_validation/test_data.py b/tests/test_causal_validation/test_data.py index 70befc4..1b7705c 100644 --- a/tests/test_causal_validation/test_data.py +++ b/tests/test_causal_validation/test_data.py @@ -199,7 +199,7 @@ def test_drop_unit(n_pre: int, n_post: int, n_control: int): assert reduced_data.Xte.shape == desired_shape_Xte assert reduced_data.ytr.shape == desired_shape_ytr assert reduced_data.yte.shape == desired_shape_yte - + assert reduced_data.counterfactual == data.counterfactual assert reduced_data.synthetic == data.synthetic assert reduced_data._name == data._name @@ -305,19 +305,19 @@ def test_counterfactual_synthetic_attributes(n_pre: int, n_post: int, n_control: N_CONTROL=n_control, ) data = simulate_data(0.0, DEFAULT_SEED, constants=constants) - + assert data.counterfactual is None assert data.synthetic is None - + counterfactual_vals = np.random.randn(n_post, 1) synthetic_vals = np.random.randn(n_post, 1) - + data_with_attrs = Dataset( data.Xtr, data.Xte, data.ytr, data.yte, data._start_date, data.Ptr, data.Pte, data.Rtr, data.Rte, counterfactual_vals, synthetic_vals, "test_dataset" ) - + np.testing.assert_array_equal(data_with_attrs.counterfactual, counterfactual_vals) np.testing.assert_array_equal(data_with_attrs.synthetic, synthetic_vals) assert 
data_with_attrs.name == "test_dataset" @@ -336,17 +336,17 @@ def test_inflate_method(n_pre: int, n_post: int, n_control: int): N_CONTROL=n_control, ) data = simulate_data(0.0, DEFAULT_SEED, constants=constants) - - inflation_vals = np.ones((n_post, 1)) * 1.1 + + inflation_vals = np.ones((n_post, 1)) * 1.1 inflated_data = data.inflate(inflation_vals) - + np.testing.assert_array_equal(inflated_data.Xtr, data.Xtr) np.testing.assert_array_equal(inflated_data.ytr, data.ytr) np.testing.assert_array_equal(inflated_data.Xte, data.Xte) - + expected_yte = data.yte * inflation_vals np.testing.assert_array_equal(inflated_data.yte, expected_yte) - + np.testing.assert_array_equal(inflated_data.counterfactual, data.yte) @@ -363,12 +363,12 @@ def test_control_treated_properties(n_pre: int, n_post: int, n_control: int): N_CONTROL=n_control, ) data = simulate_data(0.0, DEFAULT_SEED, constants=constants) - + control_units = data.control_units expected_control = np.vstack([data.Xtr, data.Xte]) np.testing.assert_array_equal(control_units, expected_control) assert control_units.shape == (n_pre + n_post, n_control) - + treated_units = data.treated_units expected_treated = np.vstack([data.ytr, data.yte]) np.testing.assert_array_equal(treated_units, expected_treated) @@ -416,14 +416,16 @@ def test_dataset_container(seeds: tp.List[int], to_name: bool): n_control=st.integers(min_value=2, max_value=20), ) @settings(max_examples=5) -def test_covariate_properties_without_covariates(n_pre: int, n_post: int, n_control: int): +def test_covariate_properties_without_covariates( + n_pre: int, n_post: int, n_control: int +): constants = TestConstants( N_POST_TREATMENT=n_post, N_PRE_TREATMENT=n_pre, N_CONTROL=n_control, ) data = simulate_data(0.0, DEFAULT_SEED, constants=constants) - + assert data.has_covariates is False assert data.control_covariates is None assert data.treated_covariates is None @@ -437,79 +439,45 @@ def test_covariate_properties_without_covariates(n_pre: int, n_post: int, n_cont 
n_post=st.integers(min_value=10, max_value=50), n_control=st.integers(min_value=2, max_value=10), n_covariates=st.integers(min_value=1, max_value=5), - Xtr=st.data(), - Xte=st.data(), - ytr=st.data(), - yte=st.data(), - Ptr=st.data(), - Pte=st.data(), - Rtr=st.data(), - Rte=st.data(), + seed=st.integers(min_value=1, max_value=10000), ) @settings(max_examples=5) -def test_covariate_properties_with_covariates(n_pre: int, - n_post: int, - n_control: int, - n_covariates: int, - Xtr, - Xte, - ytr, - yte, - Ptr, - Pte, - Rtr, - Rte): - - Xtr = Xtr.draw(st.lists(st.floats(min_value=-10, max_value=10), - min_size=n_pre*n_control, max_size=n_pre*n_control)) - Xtr = np.array(Xtr).reshape(n_pre, n_control) - - Xte = Xte.draw(st.lists(st.floats(min_value=-10, max_value=10), - min_size=n_post*n_control, max_size=n_post*n_control)) - Xte = np.array(Xte).reshape(n_post, n_control) - - ytr = ytr.draw(st.lists(st.floats(min_value=-10, max_value=10), - min_size=n_pre, max_size=n_pre)) - ytr = np.array(ytr).reshape(n_pre, 1) - - yte = yte.draw(st.lists(st.floats(min_value=-10, max_value=10), - min_size=n_post, max_size=n_post)) - yte = np.array(yte).reshape(n_post, 1) - - Ptr = Ptr.draw(st.lists(st.floats(min_value=-10, max_value=10), - min_size=n_pre*n_control*n_covariates, max_size=n_pre*n_control*n_covariates)) - Ptr = np.array(Ptr).reshape(n_pre, n_control, n_covariates) - - Pte = Pte.draw(st.lists(st.floats(min_value=-10, max_value=10), - min_size=n_post*n_control*n_covariates, max_size=n_post*n_control*n_covariates)) - Pte = np.array(Pte).reshape(n_post, n_control, n_covariates) - - Rtr = Rtr.draw(st.lists(st.floats(min_value=-10, max_value=10), - min_size=n_pre*n_covariates, max_size=n_pre*n_covariates)) - Rtr = np.array(Rtr).reshape(n_pre, 1, n_covariates) - - Rte = Rte.draw(st.lists(st.floats(min_value=-10, max_value=10), - min_size=n_post*n_covariates, max_size=n_post*n_covariates)) - Rte = np.array(Rte).reshape(n_post, 1, n_covariates) - +def 
test_covariate_properties_with_covariates( + n_pre: int, + n_post: int, + n_control: int, + n_covariates: int, + seed: int, +): + rng = np.random.RandomState(seed) + + Xtr = rng.uniform(-10, 10, (n_pre, n_control)) + Xte = rng.uniform(-10, 10, (n_post, n_control)) + ytr = rng.uniform(-10, 10, (n_pre, 1)) + yte = rng.uniform(-10, 10, (n_post, 1)) + Ptr = rng.uniform(-10, 10, (n_pre, n_control, n_covariates)) + Pte = rng.uniform(-10, 10, (n_post, n_control, n_covariates)) + Rtr = rng.uniform(-10, 10, (n_pre, 1, n_covariates)) + Rte = rng.uniform(-10, 10, (n_post, 1, n_covariates)) + data = Dataset(Xtr, Xte, ytr, yte, dt.date(2023, 1, 1), Ptr, Pte, Rtr, Rte) - + assert data.n_covariates == n_covariates assert data.has_covariates is True - + control_covariates = data.control_covariates expected_control_cov = np.vstack([Ptr, Pte]) np.testing.assert_array_equal(control_covariates, expected_control_cov) assert control_covariates.shape == (n_pre + n_post, n_control, n_covariates) - + treated_covariates = data.treated_covariates expected_treated_cov = np.vstack([Rtr, Rte]) np.testing.assert_array_equal(treated_covariates, expected_treated_cov) assert treated_covariates.shape == (n_pre + n_post, 1, n_covariates) - + pre_cov = data.pre_intervention_covariates assert pre_cov == (Ptr, Rtr) - + post_cov = data.post_intervention_covariates assert post_cov == (Pte, Rte) diff --git a/tests/test_causal_validation/test_validation/test_placebo.py b/tests/test_causal_validation/test_validation/test_placebo.py index 858c5f3..172de83 100644 --- a/tests/test_causal_validation/test_validation/test_placebo.py +++ b/tests/test_causal_validation/test_validation/test_placebo.py @@ -30,7 +30,7 @@ def test_schema_coerce(): df = PlaceboSchema.example() cols = df.columns for col in cols: - if not col in ["Model", "Dataset"]: + if col not in ["Model", "Dataset"]: df[col] = np.ceil((df[col])) PlaceboSchema.validate(df) diff --git a/tests/test_causal_validation/test_validation/test_rmspe.py 
b/tests/test_causal_validation/test_validation/test_rmspe.py index 1bc6b37..ead1dfa 100644 --- a/tests/test_causal_validation/test_validation/test_rmspe.py +++ b/tests/test_causal_validation/test_validation/test_rmspe.py @@ -35,7 +35,7 @@ def test_schema_coerce(): df = RMSPESchema.example() cols = df.columns for col in cols: - if not col in ["Model", "Dataset"]: + if col not in ["Model", "Dataset"]: df[col] = np.ceil((df[col])) RMSPESchema.validate(df) From 51def3b8565fe0bc2be51fa0ab37a58d7b9a500b Mon Sep 17 00:00:00 2001 From: Semih Akbayrak Date: Wed, 24 Sep 2025 15:08:48 +0000 Subject: [PATCH 4/4] Covariates can be included in simulations. --- src/causal_validation/config.py | 59 +++++++ src/causal_validation/simulate.py | 41 ++++- src/causal_validation/weights.py | 32 +++- tests/test_causal_validation/test_config.py | 80 ++++++++++ tests/test_causal_validation/test_simulate.py | 145 ++++++++++++++++++ tests/test_causal_validation/test_weights.py | 32 +++- 6 files changed, 376 insertions(+), 13 deletions(-) create mode 100644 tests/test_causal_validation/test_config.py create mode 100644 tests/test_causal_validation/test_simulate.py diff --git a/src/causal_validation/config.py b/src/causal_validation/config.py index 8040be1..ca7928a 100644 --- a/src/causal_validation/config.py +++ b/src/causal_validation/config.py @@ -4,7 +4,11 @@ ) import datetime as dt +from jaxtyping import Float +import typing as tp + import numpy as np +from scipy.stats import halfcauchy from causal_validation.types import ( Number, @@ -20,9 +24,38 @@ class WeightConfig: @dataclass(kw_only=True) class Config: + """Configuration for causal data generation. + + Args: + n_control_units (int): Number of control units in the synthetic dataset. + n_pre_intervention_timepoints (int): Number of time points before intervention. + n_post_intervention_timepoints (int): Number of time points after intervention. + n_covariates (Optional[int]): Number of covariates. Defaults to None. 
+ covariate_means (Optional[Float[np.ndarray, "D K"]]): Mean values for covariates + D is n_control_units and K is n_covariates. Defaults to None. If it is set + to None while n_covariates is provided, covariate_means will be generated + randomly from Normal distribution. + covariate_stds (Optional[Float[np.ndarray, "D K"]]): Standard deviations for + covariates. D is n_control_units and K is n_covariates. Defaults to None. + If it is set to None while n_covariates is provided, covariate_stds + will be generated randomly from Half-Cauchy distribution. + covariate_coeffs (Optional[np.ndarray]): Linear regression + coefficients to map covariates to output observations. K is n_covariates. + Defaults to None. + global_mean (Number): Global mean for data generation. Defaults to 20.0. + global_scale (Number): Global scale for data generation. Defaults to 0.2. + start_date (dt.date): Start date for time series. Defaults to 2023-01-01. + seed (int): Random seed for reproducibility. Defaults to 123. + weights_cfg (WeightConfig): Configuration for unit weights. Defaults to + UniformWeights. 
+ """ n_control_units: int n_pre_intervention_timepoints: int n_post_intervention_timepoints: int + n_covariates: tp.Optional[int] = None + covariate_means: tp.Optional[Float[np.ndarray, "D K"]] = None + covariate_stds: tp.Optional[Float[np.ndarray, "D K"]] = None + covariate_coeffs: tp.Optional[np.ndarray] = None global_mean: Number = 20.0 global_scale: Number = 0.2 start_date: dt.date = dt.date(year=2023, month=1, day=1) @@ -31,3 +64,29 @@ class Config: def __post_init__(self): self.rng = np.random.RandomState(self.seed) + if self.covariate_means is not None: + assert self.covariate_means.shape == (self.n_control_units, + self.n_covariates) + + if self.covariate_stds is not None: + assert self.covariate_stds.shape == (self.n_control_units, + self.n_covariates) + + if (self.n_covariates is not None) & (self.covariate_means is None): + self.covariate_means = self.rng.normal( + loc=0.0, scale=5.0, size=(self.n_control_units, + self.n_covariates) + ) + + if (self.n_covariates is not None) & (self.covariate_stds is None): + self.covariate_stds = ( + halfcauchy.rvs(scale=0.5, + size=(self.n_control_units, + self.n_covariates), + random_state=self.rng) + ) + + if (self.n_covariates is not None) & (self.covariate_coeffs is None): + self.covariate_coeffs = self.rng.normal( + loc=0.0, scale=5.0, size=self.n_covariates + ) diff --git a/src/causal_validation/simulate.py b/src/causal_validation/simulate.py index 2f02c10..e85e80f 100644 --- a/src/causal_validation/simulate.py +++ b/src/causal_validation/simulate.py @@ -29,9 +29,40 @@ def _simulate_base_obs( obs = key.normal( loc=config.global_mean, scale=config.global_scale, size=(n_timepoints, n_units) ) - Xtr = obs[: config.n_pre_intervention_timepoints, :] - Xte = obs[config.n_pre_intervention_timepoints :, :] - ytr = weights.weight_obs(Xtr) - yte = weights.weight_obs(Xte) - data = Dataset(Xtr, Xte, ytr, yte, _start_date=config.start_date) + + if config.n_covariates is not None: + Xtr_ = obs[: 
config.n_pre_intervention_timepoints, :] + Xte_ = obs[config.n_pre_intervention_timepoints :, :] + + covariates = key.normal( + loc=config.covariate_means, + scale=config.covariate_stds, + size=(n_timepoints, n_units, config.n_covariates) + ) + + Ptr = covariates[:config.n_pre_intervention_timepoints, :, :] + Pte = covariates[config.n_pre_intervention_timepoints:, :, :] + + Xtr = Xtr_ + Ptr @ config.covariate_coeffs + Xte = Xte_ + Pte @ config.covariate_coeffs + + ytr = weights.weight_contr(Xtr) + yte = weights.weight_contr(Xte) + + Rtr = weights.weight_contr(Ptr) + Rte = weights.weight_contr(Pte) + + data = Dataset( + Xtr, Xte, ytr, yte, _start_date=config.start_date, + Ptr=Ptr, Pte=Pte, Rtr=Rtr, Rte=Rte + ) + else: + Xtr = obs[: config.n_pre_intervention_timepoints, :] + Xte = obs[config.n_pre_intervention_timepoints :, :] + + ytr = weights.weight_contr(Xtr) + yte = weights.weight_contr(Xte) + + data = Dataset(Xtr, Xte, ytr, yte, _start_date=config.start_date) + return data diff --git a/src/causal_validation/weights.py b/src/causal_validation/weights.py index 42234f9..8d6c6f2 100644 --- a/src/causal_validation/weights.py +++ b/src/causal_validation/weights.py @@ -11,15 +11,23 @@ if tp.TYPE_CHECKING: from causal_validation.config import WeightConfig +# Constants for array dimensions +_NDIM_2D = 2 +_NDIM_3D = 3 + @dataclass class AbstractWeights(BaseObject): name: str = "Abstract Weights" - def _get_weights(self, obs: Float[np.ndarray, "N D"]) -> Float[np.ndarray, "D 1"]: + def _get_weights( + self, obs: Float[np.ndarray, "N D"] | Float[np.ndarray, "N D K"] + ) -> Float[np.ndarray, "D 1"]: raise NotImplementedError("Please implement `_get_weights` in all subclasses.") - def get_weights(self, obs: Float[np.ndarray, "N D"]) -> Float[np.ndarray, "D 1"]: + def get_weights( + self, obs: Float[np.ndarray, "N D"] | Float[np.ndarray, "N D K"] + ) -> Float[np.ndarray, "D 1"]: weights = self._get_weights(obs) np.testing.assert_almost_equal( @@ -28,13 +36,21 @@ def 
get_weights(self, obs: Float[np.ndarray, "N D"]) -> Float[np.ndarray, "D 1"] assert min(weights >= 0), "Weights should be non-negative" return weights - def __call__(self, obs: Float[np.ndarray, "N D"]) -> Float[np.ndarray, "N 1"]: - return self.weight_obs(obs) + def __call__( + self, obs: Float[np.ndarray, "N D"] | Float[np.ndarray, "N D K"] + ) -> Float[np.ndarray, "N 1"] | Float[np.ndarray, "N 1 K"]: + return self.weight_contr(obs) - def weight_obs(self, obs: Float[np.ndarray, "N D"]) -> Float[np.ndarray, "N 1"]: + def weight_contr( + self, obs: Float[np.ndarray, "N D"] | Float[np.ndarray, "N D K"] + ) -> Float[np.ndarray, "N 1"] | Float[np.ndarray, "N 1 K"]: weights = self.get_weights(obs) - weighted_obs = obs @ weights + if obs.ndim == _NDIM_2D: + weighted_obs = obs @ weights + elif obs.ndim == _NDIM_3D: + weighted_obs = np.einsum("n d k, d i -> n i k", obs, weights) + return weighted_obs @@ -42,7 +58,9 @@ def weight_obs(self, obs: Float[np.ndarray, "N D"]) -> Float[np.ndarray, "N 1"]: class UniformWeights(AbstractWeights): name: str = "Uniform Weights" - def _get_weights(self, obs: Float[np.ndarray, "N D"]) -> Float[np.ndarray, "D 1"]: + def _get_weights( + self, obs: Float[np.ndarray, "N D"] | Float[np.ndarray, "N D K"] + ) -> Float[np.ndarray, "D 1"]: n_units = obs.shape[1] return np.repeat(1.0 / n_units, repeats=n_units).reshape(-1, 1) diff --git a/tests/test_causal_validation/test_config.py b/tests/test_causal_validation/test_config.py new file mode 100644 index 0000000..afca1b4 --- /dev/null +++ b/tests/test_causal_validation/test_config.py @@ -0,0 +1,80 @@ +import numpy as np +from hypothesis import given, strategies as st + +from causal_validation.config import Config + + +@given( + n_units=st.integers(min_value=1, max_value=10), + n_pre=st.integers(min_value=1, max_value=20), + n_post=st.integers(min_value=1, max_value=20) +) +def test_config_basic_initialization(n_units, n_pre, n_post): + cfg = Config( + n_control_units=n_units, + 
@given(
    n_units=st.integers(min_value=1, max_value=5),
    n_pre=st.integers(min_value=1, max_value=10),
    n_post=st.integers(min_value=1, max_value=10),
    n_covariates=st.integers(min_value=1, max_value=3),
    seed=st.integers(min_value=1, max_value=1000),
)
def test_config_with_covariates_auto_generation(
    n_units, n_pre, n_post, n_covariates, seed
):
    """Unspecified covariate parameters are auto-drawn with correct shapes."""
    cfg = Config(
        n_control_units=n_units,
        n_pre_intervention_timepoints=n_pre,
        n_post_intervention_timepoints=n_post,
        n_covariates=n_covariates,
        seed=seed,
    )
    expected_shape = (n_units, n_covariates)
    assert cfg.n_covariates == n_covariates
    assert cfg.covariate_means.shape == expected_shape
    assert cfg.covariate_stds.shape == expected_shape
    assert cfg.covariate_coeffs.shape == (n_covariates,)
    assert np.all(cfg.covariate_stds >= 0)


@given(
    n_units=st.integers(min_value=1, max_value=3),
    n_covariates=st.integers(min_value=1, max_value=3),
)
def test_config_with_explicit_covariate_means(n_units, n_covariates):
    """User-supplied covariate means are stored untouched."""
    means = np.random.random((n_units, n_covariates))
    cfg = Config(
        n_control_units=n_units,
        n_pre_intervention_timepoints=10,
        n_post_intervention_timepoints=5,
        n_covariates=n_covariates,
        covariate_means=means,
    )
    np.testing.assert_array_equal(cfg.covariate_means, means)
@given(
    n_units=st.integers(min_value=1, max_value=3),
    n_covariates=st.integers(min_value=1, max_value=3),
)
def test_config_with_explicit_covariate_stds(n_units, n_covariates):
    """User-supplied covariate stds are stored untouched."""
    stds = np.random.random((n_units, n_covariates)) + 0.1
    cfg = Config(
        n_control_units=n_units,
        n_pre_intervention_timepoints=10,
        n_post_intervention_timepoints=5,
        n_covariates=n_covariates,
        covariate_stds=stds,
    )
    np.testing.assert_array_equal(cfg.covariate_stds, stds)


@given(
    n_units=st.integers(min_value=1, max_value=5),
    n_pre=st.integers(min_value=1, max_value=10),
    n_post=st.integers(min_value=1, max_value=10),
    seed=st.integers(min_value=1, max_value=1000),
)
def test_simulate_basic(n_units, n_pre, n_post, seed):
    """Covariate-free simulation produces correctly shaped arrays."""
    data = simulate(
        Config(
            n_control_units=n_units,
            n_pre_intervention_timepoints=n_pre,
            n_post_intervention_timepoints=n_post,
            seed=seed,
        )
    )
    assert data.Xtr.shape == (n_pre, n_units)
    assert data.Xte.shape == (n_post, n_units)
    assert data.ytr.shape == (n_pre, 1)
    assert data.yte.shape == (n_post, 1)
    assert not data.has_covariates
@given(
    n_units=st.integers(min_value=1, max_value=5),
    n_pre=st.integers(min_value=1, max_value=10),
    n_post=st.integers(min_value=1, max_value=10),
    n_covariates=st.integers(min_value=1, max_value=3),
    seed=st.integers(min_value=1, max_value=1000),
)
def test_simulate_with_covariates(n_units, n_pre, n_post, n_covariates, seed):
    """Simulation with covariates populates P/R arrays with matching shapes."""
    data = simulate(
        Config(
            n_control_units=n_units,
            n_pre_intervention_timepoints=n_pre,
            n_post_intervention_timepoints=n_post,
            n_covariates=n_covariates,
            seed=seed,
        )
    )
    assert data.has_covariates
    assert data.Xtr.shape == (n_pre, n_units)
    assert data.Xte.shape == (n_post, n_units)
    assert data.ytr.shape == (n_pre, 1)
    assert data.yte.shape == (n_post, 1)
    assert data.Ptr.shape == (n_pre, n_units, n_covariates)
    assert data.Pte.shape == (n_post, n_units, n_covariates)
    assert data.Rtr.shape == (n_pre, 1, n_covariates)
    assert data.Rte.shape == (n_post, 1, n_covariates)


@given(
    n_units=st.integers(min_value=1, max_value=5),
    n_pre=st.integers(min_value=1, max_value=10),
    n_post=st.integers(min_value=1, max_value=10),
    seed=st.integers(min_value=1, max_value=1000),
)
def test_simulate_reproducible(n_units, n_pre, n_post, seed):
    """Two simulations from identical configs yield identical data."""

    def run():
        return simulate(
            Config(
                n_control_units=n_units,
                n_pre_intervention_timepoints=n_pre,
                n_post_intervention_timepoints=n_post,
                seed=seed,
            )
        )

    first, second = run(), run()
    for attr in ("Xtr", "Xte", "ytr", "yte"):
        np.testing.assert_array_equal(getattr(first, attr), getattr(second, attr))


@given(
    n_units=st.integers(min_value=1, max_value=3),
    n_pre=st.integers(min_value=3, max_value=10),
    n_post=st.integers(min_value=1, max_value=5),
    seed=st.integers(min_value=1, max_value=1000),
)
def test_simulate_covariate_effects(n_units, n_pre, n_post, seed):
    """A strong covariate coefficient visibly shifts the treated series."""
    with_cov = simulate(
        Config(
            n_control_units=n_units,
            n_pre_intervention_timepoints=n_pre,
            n_post_intervention_timepoints=n_post,
            n_covariates=1,
            covariate_coeffs=np.array([10.0]),
            seed=seed,
        )
    )
    without_cov = simulate(
        Config(
            n_control_units=n_units,
            n_pre_intervention_timepoints=n_pre,
            n_post_intervention_timepoints=n_post,
            seed=seed,
        )
    )
    assert not np.allclose(with_cov.ytr, without_cov.ytr)
    assert not np.allclose(with_cov.yte, without_cov.yte)
@given(
    n_units=st.integers(min_value=1, max_value=3),
    n_pre=st.integers(min_value=3, max_value=10),
    n_post=st.integers(min_value=1, max_value=5),
    seed=st.integers(min_value=1, max_value=1000),
)
def test_simulate_exact_covariate_effects(n_units, n_pre, n_post, seed):
    """Near-deterministic covariates shift X by exactly ``means @ coeffs``."""
    coeffs = np.array([10.0, 5.0])
    with_cov = simulate(
        Config(
            n_control_units=n_units,
            n_pre_intervention_timepoints=n_pre,
            n_post_intervention_timepoints=n_post,
            n_covariates=2,
            covariate_means=np.ones((n_units, 2)),
            covariate_stds=1e-12 * np.ones((n_units, 2)),
            covariate_coeffs=coeffs,
            seed=seed,
        )
    )
    without_cov = simulate(
        Config(
            n_control_units=n_units,
            n_pre_intervention_timepoints=n_pre,
            n_post_intervention_timepoints=n_post,
            seed=seed,
        )
    )
    # Unit means with ~zero stds give a constant effect of ones @ coeffs.
    shift = float(np.ones(2) @ coeffs)  # = 15.0
    assert np.allclose(with_cov.Xtr - shift, without_cov.Xtr)
    assert np.allclose(with_cov.Xte - shift, without_cov.Xte)


@given(
    n_units=st.integers(min_value=1, max_value=100),
    n_time=st.integers(min_value=1, max_value=100),
)
def test_weight_contr(n_units: int, n_time: int):
    """Uniform weighting of a ones matrix returns the column mean."""
    obs = np.ones(shape=(n_time, n_units))
    weighted = UniformWeights()(obs)
    np.testing.assert_almost_equal(np.mean(weighted), weighted, decimal=6)
    np.testing.assert_almost_equal(
        obs @ UniformWeights().get_weights(obs), weighted, decimal=6
    )
@given(
    n_units=st.integers(min_value=1, max_value=10),
    n_time=st.integers(min_value=1, max_value=10),
    n_covariates=st.integers(min_value=1, max_value=5),
)
def test_weight_contr_3d(n_units: int, n_time: int, n_covariates: int):
    """3D covariate input collapses the unit axis to size 1."""
    covariates = np.ones(shape=(n_time, n_units, n_covariates))
    scheme = UniformWeights()
    weighted = scheme.weight_contr(covariates)

    assert weighted.shape == (n_time, 1, n_covariates)
    expected = np.einsum(
        "n d k, d i -> n i k", covariates, scheme.get_weights(covariates)
    )
    np.testing.assert_almost_equal(weighted, expected, decimal=6)


def test_weights_sum_to_one():
    """Uniform weights form a probability vector."""
    vals = UniformWeights().get_weights(np.random.random((10, 5)))
    np.testing.assert_almost_equal(vals.sum(), 1.0, decimal=6)


def test_weights_non_negative():
    """Uniform weights are all non-negative."""
    vals = UniformWeights().get_weights(np.random.random((10, 5)))
    assert np.all(vals >= 0)