diff --git a/src/causal_validation/config.py b/src/causal_validation/config.py index 8040be1..ca7928a 100644 --- a/src/causal_validation/config.py +++ b/src/causal_validation/config.py @@ -4,7 +4,11 @@ ) import datetime as dt +from jaxtyping import Float +import typing as tp + import numpy as np +from scipy.stats import halfcauchy from causal_validation.types import ( Number, @@ -20,9 +24,38 @@ class WeightConfig: @dataclass(kw_only=True) class Config: + """Configuration for causal data generation. + + Args: + n_control_units (int): Number of control units in the synthetic dataset. + n_pre_intervention_timepoints (int): Number of time points before intervention. + n_post_intervention_timepoints (int): Number of time points after intervention. + n_covariates (Optional[int]): Number of covariates. Defaults to None. + covariate_means (Optional[Float[np.ndarray, "D K"]]): Mean values for covariates + D is n_control_units and K is n_covariates. Defaults to None. If it is set + to None while n_covariates is provided, covariate_means will be generated + randomly from Normal distribution. + covariate_stds (Optional[Float[np.ndarray, "D K"]]): Standard deviations for + covariates. D is n_control_units and K is n_covariates. Defaults to None. + If it is set to None while n_covariates is provided, covariate_stds + will be generated randomly from Half-Cauchy distribution. + covariate_coeffs (Optional[np.ndarray]): Linear regression + coefficients to map covariates to output observations. K is n_covariates. + Defaults to None. + global_mean (Number): Global mean for data generation. Defaults to 20.0. + global_scale (Number): Global scale for data generation. Defaults to 0.2. + start_date (dt.date): Start date for time series. Defaults to 2023-01-01. + seed (int): Random seed for reproducibility. Defaults to 123. + weights_cfg (WeightConfig): Configuration for unit weights. Defaults to + UniformWeights. + """ n_control_units: int n_pre_intervention_timepoints: int n_post_intervention_timepoints: int + n_covariates: tp.Optional[int] = None + covariate_means: tp.Optional[Float[np.ndarray, "D K"]] = None + covariate_stds: tp.Optional[Float[np.ndarray, "D K"]] = None + covariate_coeffs: tp.Optional[np.ndarray] = None global_mean: Number = 20.0 global_scale: Number = 0.2 start_date: dt.date = dt.date(year=2023, month=1, day=1) @@ -31,3 +64,29 @@ class Config: def __post_init__(self): self.rng = np.random.RandomState(self.seed) + if self.covariate_means is not None: + assert self.covariate_means.shape == (self.n_control_units, + self.n_covariates) + + if self.covariate_stds is not None: + assert self.covariate_stds.shape == (self.n_control_units, + self.n_covariates) + + if (self.n_covariates is not None) & (self.covariate_means is None): + self.covariate_means = self.rng.normal( + loc=0.0, scale=5.0, size=(self.n_control_units, + self.n_covariates) + ) + + if (self.n_covariates is not None) & (self.covariate_stds is None): + self.covariate_stds = ( + halfcauchy.rvs(scale=0.5, + size=(self.n_control_units, + self.n_covariates), + random_state=self.rng) + ) + + if (self.n_covariates is not None) & (self.covariate_coeffs is None): + self.covariate_coeffs = self.rng.normal( + loc=0.0, scale=5.0, size=self.n_covariates + ) diff --git a/src/causal_validation/simulate.py b/src/causal_validation/simulate.py index 2f02c10..e85e80f 100644 --- a/src/causal_validation/simulate.py +++ b/src/causal_validation/simulate.py @@ -29,9 +29,40 @@ def _simulate_base_obs( obs = key.normal( loc=config.global_mean, scale=config.global_scale, size=(n_timepoints, n_units) ) - Xtr = obs[: config.n_pre_intervention_timepoints, :] - Xte = obs[config.n_pre_intervention_timepoints :, :] - ytr = weights.weight_obs(Xtr) - yte = weights.weight_obs(Xte) - data = Dataset(Xtr, Xte, ytr, yte, _start_date=config.start_date) + + if config.n_covariates is not None: + Xtr_ = obs[: config.n_pre_intervention_timepoints, :] + Xte_ = obs[config.n_pre_intervention_timepoints :, :] + + covariates = key.normal( + loc=config.covariate_means, + scale=config.covariate_stds, + size=(n_timepoints, n_units, config.n_covariates) + ) + + Ptr = covariates[:config.n_pre_intervention_timepoints, :, :] + Pte = covariates[config.n_pre_intervention_timepoints:, :, :] + + Xtr = Xtr_ + Ptr @ config.covariate_coeffs + Xte = Xte_ + Pte @ config.covariate_coeffs + + ytr = weights.weight_contr(Xtr) + yte = weights.weight_contr(Xte) + + Rtr = weights.weight_contr(Ptr) + Rte = weights.weight_contr(Pte) + + data = Dataset( + Xtr, Xte, ytr, yte, _start_date=config.start_date, + Ptr=Ptr, Pte=Pte, Rtr=Rtr, Rte=Rte + ) + else: + Xtr = obs[: config.n_pre_intervention_timepoints, :] + Xte = obs[config.n_pre_intervention_timepoints :, :] + + ytr = weights.weight_contr(Xtr) + yte = weights.weight_contr(Xte) + + data = Dataset(Xtr, Xte, ytr, yte, _start_date=config.start_date) + return data diff --git a/src/causal_validation/weights.py b/src/causal_validation/weights.py index 42234f9..8d6c6f2 100644 --- a/src/causal_validation/weights.py +++ b/src/causal_validation/weights.py @@ -11,15 +11,23 @@ if tp.TYPE_CHECKING: from causal_validation.config import WeightConfig +# Constants for array dimensions +_NDIM_2D = 2 +_NDIM_3D = 3 + @dataclass class AbstractWeights(BaseObject): name: str = "Abstract Weights" - def _get_weights(self, obs: Float[np.ndarray, "N D"]) -> Float[np.ndarray, "D 1"]: + def _get_weights( + self, obs: Float[np.ndarray, "N D"] | Float[np.ndarray, "N D K"] + ) -> Float[np.ndarray, "D 1"]: raise NotImplementedError("Please implement `_get_weights` in all subclasses.") - def get_weights(self, obs: Float[np.ndarray, "N D"]) -> Float[np.ndarray, "D 1"]: + def get_weights( + self, obs: Float[np.ndarray, "N D"] | Float[np.ndarray, "N D K"] + ) -> Float[np.ndarray, "D 1"]: weights = self._get_weights(obs) np.testing.assert_almost_equal( @@ -28,13 +36,21 @@ def get_weights(self, obs: Float[np.ndarray, "N D"]) -> Float[np.ndarray, "D 1"] assert min(weights >= 0), "Weights should be non-negative" return weights - def __call__(self, obs: Float[np.ndarray, "N D"]) -> Float[np.ndarray, "N 1"]: - return self.weight_obs(obs) + def __call__( + self, obs: Float[np.ndarray, "N D"] | Float[np.ndarray, "N D K"] + ) -> Float[np.ndarray, "N 1"] | Float[np.ndarray, "N 1 K"]: + return self.weight_contr(obs) - def weight_obs(self, obs: Float[np.ndarray, "N D"]) -> Float[np.ndarray, "N 1"]: + def weight_contr( + self, obs: Float[np.ndarray, "N D"] | Float[np.ndarray, "N D K"] + ) -> Float[np.ndarray, "N 1"] | Float[np.ndarray, "N 1 K"]: weights = self.get_weights(obs) - weighted_obs = obs @ weights + if obs.ndim == _NDIM_2D: + weighted_obs = obs @ weights + elif obs.ndim == _NDIM_3D: + weighted_obs = np.einsum("n d k, d i -> n i k", obs, weights) + return weighted_obs @@ -42,7 +58,9 @@ def weight_obs(self, obs: Float[np.ndarray, "N D"]) -> Float[np.ndarray, "N 1"]: class UniformWeights(AbstractWeights): name: str = "Uniform Weights" - def _get_weights(self, obs: Float[np.ndarray, "N D"]) -> Float[np.ndarray, "D 1"]: + def _get_weights( + self, obs: Float[np.ndarray, "N D"] | Float[np.ndarray, "N D K"] + ) -> Float[np.ndarray, "D 1"]: n_units = obs.shape[1] return np.repeat(1.0 / n_units, repeats=n_units).reshape(-1, 1) diff --git a/tests/test_causal_validation/test_config.py b/tests/test_causal_validation/test_config.py new file mode 100644 index 0000000..afca1b4 --- /dev/null +++ b/tests/test_causal_validation/test_config.py @@ -0,0 +1,80 @@ +import numpy as np +from hypothesis import given, strategies as st + +from causal_validation.config import Config + + +@given( + n_units=st.integers(min_value=1, max_value=10), + n_pre=st.integers(min_value=1, max_value=20), + n_post=st.integers(min_value=1, max_value=20) +) +def test_config_basic_initialization(n_units, n_pre, n_post): + cfg = Config( + n_control_units=n_units, + n_pre_intervention_timepoints=n_pre, + n_post_intervention_timepoints=n_post + ) + assert cfg.n_control_units == n_units + assert cfg.n_pre_intervention_timepoints == n_pre + assert cfg.n_post_intervention_timepoints == n_post + assert cfg.n_covariates is None + assert cfg.covariate_means is None + assert cfg.covariate_stds is None + assert cfg.covariate_coeffs is None + + +@given( + n_units=st.integers(min_value=1, max_value=5), + n_pre=st.integers(min_value=1, max_value=10), + n_post=st.integers(min_value=1, max_value=10), + n_covariates=st.integers(min_value=1, max_value=3), + seed=st.integers(min_value=1, max_value=1000) +) +def test_config_with_covariates_auto_generation( + n_units, n_pre, n_post, n_covariates, seed +): + cfg = Config( + n_control_units=n_units, + n_pre_intervention_timepoints=n_pre, + n_post_intervention_timepoints=n_post, + n_covariates=n_covariates, + seed=seed + ) + assert cfg.n_covariates == n_covariates + assert cfg.covariate_means.shape == (n_units, n_covariates) + assert cfg.covariate_stds.shape == (n_units, n_covariates) + assert cfg.covariate_coeffs.shape == (n_covariates,) + assert np.all(cfg.covariate_stds >= 0) + + +@given( + n_units=st.integers(min_value=1, max_value=3), + n_covariates=st.integers(min_value=1, max_value=3) +) +def test_config_with_explicit_covariate_means(n_units, n_covariates): + means = np.random.random((n_units, n_covariates)) + cfg = Config( + n_control_units=n_units, + n_pre_intervention_timepoints=10, + n_post_intervention_timepoints=5, + n_covariates=n_covariates, + covariate_means=means + ) + np.testing.assert_array_equal(cfg.covariate_means, means) + + +@given( + n_units=st.integers(min_value=1, max_value=3), + n_covariates=st.integers(min_value=1, max_value=3) +) +def test_config_with_explicit_covariate_stds(n_units, n_covariates): + stds = np.random.random((n_units, n_covariates)) + 0.1 + cfg = Config( + n_control_units=n_units, + n_pre_intervention_timepoints=10, + n_post_intervention_timepoints=5, + n_covariates=n_covariates, + covariate_stds=stds + ) + np.testing.assert_array_equal(cfg.covariate_stds, stds) diff --git a/tests/test_causal_validation/test_simulate.py b/tests/test_causal_validation/test_simulate.py new file mode 100644 index 0000000..52213dc --- /dev/null +++ b/tests/test_causal_validation/test_simulate.py @@ -0,0 +1,145 @@ +import numpy as np +from hypothesis import given, strategies as st + +from causal_validation.config import Config +from causal_validation.simulate import simulate + + +@given( + n_units=st.integers(min_value=1, max_value=5), + n_pre=st.integers(min_value=1, max_value=10), + n_post=st.integers(min_value=1, max_value=10), + seed=st.integers(min_value=1, max_value=1000) +) +def test_simulate_basic(n_units, n_pre, n_post, seed): + cfg = Config( + n_control_units=n_units, + n_pre_intervention_timepoints=n_pre, + n_post_intervention_timepoints=n_post, + seed=seed + ) + data = simulate(cfg) + + assert data.Xtr.shape == (n_pre, n_units) + assert data.Xte.shape == (n_post, n_units) + assert data.ytr.shape == (n_pre, 1) + assert data.yte.shape == (n_post, 1) + assert not data.has_covariates + + +@given( + n_units=st.integers(min_value=1, max_value=5), + n_pre=st.integers(min_value=1, max_value=10), + n_post=st.integers(min_value=1, max_value=10), + n_covariates=st.integers(min_value=1, max_value=3), + seed=st.integers(min_value=1, max_value=1000) +) +def test_simulate_with_covariates(n_units, n_pre, n_post, n_covariates, seed): + cfg = Config( + n_control_units=n_units, + n_pre_intervention_timepoints=n_pre, + n_post_intervention_timepoints=n_post, + n_covariates=n_covariates, + seed=seed + ) + data = simulate(cfg) + + assert data.Xtr.shape == (n_pre, n_units) + assert data.Xte.shape == (n_post, n_units) + assert data.ytr.shape == (n_pre, 1) + assert data.yte.shape == (n_post, 1) + assert data.has_covariates + assert data.Ptr.shape == (n_pre, n_units, n_covariates) + assert data.Pte.shape == (n_post, n_units, n_covariates) + assert data.Rtr.shape == (n_pre, 1, n_covariates) + assert data.Rte.shape == (n_post, 1, n_covariates) + + +@given( + n_units=st.integers(min_value=1, max_value=5), + n_pre=st.integers(min_value=1, max_value=10), + n_post=st.integers(min_value=1, max_value=10), + seed=st.integers(min_value=1, max_value=1000) +) +def test_simulate_reproducible(n_units, n_pre, n_post, seed): + cfg1 = Config( + n_control_units=n_units, + n_pre_intervention_timepoints=n_pre, + n_post_intervention_timepoints=n_post, + seed=seed + ) + data1 = simulate(cfg1) + + cfg2 = Config( + n_control_units=n_units, + n_pre_intervention_timepoints=n_pre, + n_post_intervention_timepoints=n_post, + seed=seed + ) + data2 = simulate(cfg2) + + np.testing.assert_array_equal(data1.Xtr, data2.Xtr) + np.testing.assert_array_equal(data1.Xte, data2.Xte) + np.testing.assert_array_equal(data1.ytr, data2.ytr) + np.testing.assert_array_equal(data1.yte, data2.yte) + + +@given( + n_units=st.integers(min_value=1, max_value=3), + n_pre=st.integers(min_value=3, max_value=10), + n_post=st.integers(min_value=1, max_value=5), + seed=st.integers(min_value=1, max_value=1000) +) +def test_simulate_covariate_effects(n_units, n_pre, n_post, seed): + cfg = Config( + n_control_units=n_units, + n_pre_intervention_timepoints=n_pre, + n_post_intervention_timepoints=n_post, + n_covariates=1, + covariate_coeffs=np.array([10.0]), + seed=seed + ) + data_with_cov = simulate(cfg) + + cfg_no_cov = Config( + n_control_units=n_units, + n_pre_intervention_timepoints=n_pre, + n_post_intervention_timepoints=n_post, + seed=seed + ) + data_no_cov = simulate(cfg_no_cov) + + assert not np.allclose(data_with_cov.ytr, data_no_cov.ytr) + assert not np.allclose(data_with_cov.yte, data_no_cov.yte) + + +@given( + n_units=st.integers(min_value=1, max_value=3), + n_pre=st.integers(min_value=3, max_value=10), + n_post=st.integers(min_value=1, max_value=5), + seed=st.integers(min_value=1, max_value=1000) +) +def test_simulate_exact_covariate_effects(n_units, n_pre, n_post, seed): + cfg = Config( + n_control_units=n_units, + n_pre_intervention_timepoints=n_pre, + n_post_intervention_timepoints=n_post, + n_covariates=2, + covariate_means=np.ones((n_units,2)), + covariate_stds= 1e-12*np.ones((n_units,2)), + covariate_coeffs=np.array([10.0, 5.0]), + seed=seed + ) + data_with_cov = simulate(cfg) + + cfg_no_cov = Config( + n_control_units=n_units, + n_pre_intervention_timepoints=n_pre, + n_post_intervention_timepoints=n_post, + seed=seed + ) + data_no_cov = simulate(cfg_no_cov) + + assert np.allclose(data_with_cov.Xtr-15, data_no_cov.Xtr) + assert np.allclose(data_with_cov.Xte-15, data_no_cov.Xte) + diff --git a/tests/test_causal_validation/test_weights.py b/tests/test_causal_validation/test_weights.py index 1b5ed6d..543d711 100644 --- a/tests/test_causal_validation/test_weights.py +++ b/tests/test_causal_validation/test_weights.py @@ -23,10 +23,40 @@ def test_uniform_weights(n_units: int, n_time: int): n_units=st.integers(min_value=1, max_value=100), n_time=st.integers(min_value=1, max_value=100), ) -def test_weight_obs(n_units: int, n_time: int): +def test_weight_contr(n_units: int, n_time: int): obs = np.ones(shape=(n_time, n_units)) weighted_obs = UniformWeights()(obs) np.testing.assert_almost_equal(np.mean(weighted_obs), weighted_obs, decimal=6) np.testing.assert_almost_equal( obs @ UniformWeights().get_weights(obs), weighted_obs, decimal=6 ) + + +@given( + n_units=st.integers(min_value=1, max_value=10), + n_time=st.integers(min_value=1, max_value=10), + n_covariates=st.integers(min_value=1, max_value=5), +) +def test_weight_contr_3d(n_units: int, n_time: int, n_covariates: int): + covariates = np.ones(shape=(n_time, n_units, n_covariates)) + weights = UniformWeights() + weighted_covs = weights.weight_contr(covariates) + + assert weighted_covs.shape == (n_time, 1, n_covariates) + expected = np.einsum("n d k, d i -> n i k", + covariates, weights.get_weights(covariates)) + np.testing.assert_almost_equal(weighted_covs, expected, decimal=6) + + +def test_weights_sum_to_one(): + obs = np.random.random((10, 5)) + weights = UniformWeights() + weight_vals = weights.get_weights(obs) + np.testing.assert_almost_equal(weight_vals.sum(), 1.0, decimal=6) + + +def test_weights_non_negative(): + obs = np.random.random((10, 5)) + weights = UniformWeights() + weight_vals = weights.get_weights(obs) + assert np.all(weight_vals >= 0)