Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion docs/examples/azcausal.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -185,7 +185,7 @@
"panel = data.to_azcausal()\n",
"model = SDID()\n",
"result = model.fit(panel)\n",
"print(f\"Delta: {100*(TRUE_EFFECT - result.effect.percentage().value / 100): .2f}%\")\n",
"print(f\"Delta: {100 * (TRUE_EFFECT - result.effect.percentage().value / 100): .2f}%\")\n",
"print(result.summary(title=\"Synthetic Data Experiment\"))"
]
}
Expand Down
28 changes: 15 additions & 13 deletions src/causal_validation/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,9 @@
field,
)
import datetime as dt

from jaxtyping import Float
import typing as tp

from jaxtyping import Float
import numpy as np
from scipy.stats import halfcauchy

Expand Down Expand Up @@ -49,6 +48,7 @@ class Config:
weights_cfg (WeightConfig): Configuration for unit weights. Defaults to
UniformWeights.
"""

n_control_units: int
n_pre_intervention_timepoints: int
n_post_intervention_timepoints: int
Expand All @@ -65,25 +65,27 @@ class Config:
def __post_init__(self):
self.rng = np.random.RandomState(self.seed)
if self.covariate_means is not None:
assert self.covariate_means.shape == (self.n_control_units,
self.n_covariates)
assert self.covariate_means.shape == (
self.n_control_units,
self.n_covariates,
)

if self.covariate_stds is not None:
assert self.covariate_stds.shape == (self.n_control_units,
self.n_covariates)
assert self.covariate_stds.shape == (
self.n_control_units,
self.n_covariates,
)

if (self.n_covariates is not None) & (self.covariate_means is None):
self.covariate_means = self.rng.normal(
loc=0.0, scale=5.0, size=(self.n_control_units,
self.n_covariates)
loc=0.0, scale=5.0, size=(self.n_control_units, self.n_covariates)
)

if (self.n_covariates is not None) & (self.covariate_stds is None):
self.covariate_stds = (
halfcauchy.rvs(scale=0.5,
size=(self.n_control_units,
self.n_covariates),
random_state=self.rng)
self.covariate_stds = halfcauchy.rvs(
scale=0.5,
size=(self.n_control_units, self.n_covariates),
random_state=self.rng,
)

if (self.n_covariates is not None) & (self.covariate_coeffs is None):
Expand Down
36 changes: 29 additions & 7 deletions src/causal_validation/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ class Dataset:
treated in pre-intervention period.
_name: Optional name identifier for the dataset
"""

Xtr: Float[np.ndarray, "N D"]
Xte: Float[np.ndarray, "M D"]
ytr: Float[np.ndarray, "N 1"]
Expand Down Expand Up @@ -161,7 +162,8 @@ def pre_intervention_covariates(
self,
) -> tp.Optional[
tp.Tuple[
Float[np.ndarray, "N D F"], Float[np.ndarray, "N 1 F"],
Float[np.ndarray, "N D F"],
Float[np.ndarray, "N 1 F"],
]
]:
if self.has_covariates:
Expand All @@ -174,7 +176,8 @@ def post_intervention_covariates(
self,
) -> tp.Optional[
tp.Tuple[
Float[np.ndarray, "M D F"], Float[np.ndarray, "M 1 F"],
Float[np.ndarray, "M D F"],
Float[np.ndarray, "M 1 F"],
]
]:
if self.has_covariates:
Expand Down Expand Up @@ -220,8 +223,18 @@ def inflate(self, inflation_vals: Float[np.ndarray, "M 1"]) -> Dataset:
Xte, yte = [deepcopy(i) for i in self.post_intervention_obs]
inflated_yte = yte * inflation_vals
return Dataset(
Xtr, Xte, ytr, inflated_yte, self._start_date,
self.Ptr, self.Pte, self.Rtr, self.Rte, yte, self.synthetic, self._name
Xtr,
Xte,
ytr,
inflated_yte,
self._start_date,
self.Ptr,
self.Pte,
self.Rtr,
self.Rte,
yte,
self.synthetic,
self._name,
)

def __eq__(self, other: Dataset) -> bool:
Expand Down Expand Up @@ -326,7 +339,16 @@ def reassign_treatment(
Xtr = data.Xtr
Xte = data.Xte
return Dataset(
Xtr, Xte, ytr, yte, data._start_date,
data.Ptr, data.Pte, data.Rtr, data.Rte,
data.counterfactual, data.synthetic, data._name
Xtr,
Xte,
ytr,
yte,
data._start_date,
data.Ptr,
data.Pte,
data.Rtr,
data.Rte,
data.counterfactual,
data.synthetic,
data._name,
)
17 changes: 12 additions & 5 deletions src/causal_validation/simulate.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,11 +37,11 @@ def _simulate_base_obs(
covariates = key.normal(
loc=config.covariate_means,
scale=config.covariate_stds,
size=(n_timepoints, n_units, config.n_covariates)
size=(n_timepoints, n_units, config.n_covariates),
)

Ptr = covariates[:config.n_pre_intervention_timepoints, :, :]
Pte = covariates[config.n_pre_intervention_timepoints:, :, :]
Ptr = covariates[: config.n_pre_intervention_timepoints, :, :]
Pte = covariates[config.n_pre_intervention_timepoints :, :, :]

Xtr = Xtr_ + Ptr @ config.covariate_coeffs
Xte = Xte_ + Pte @ config.covariate_coeffs
Expand All @@ -53,8 +53,15 @@ def _simulate_base_obs(
Rte = weights.weight_contr(Pte)

data = Dataset(
Xtr, Xte, ytr, yte, _start_date=config.start_date,
Ptr=Ptr, Pte=Pte, Rtr=Rtr, Rte=Rte
Xtr,
Xte,
ytr,
yte,
_start_date=config.start_date,
Ptr=Ptr,
Pte=Pte,
Rtr=Rtr,
Rte=Rte,
)
else:
Xtr = obs[: config.n_pre_intervention_timepoints, :]
Expand Down
5 changes: 4 additions & 1 deletion src/causal_validation/transforms/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
from causal_validation.transforms.noise import Noise, CovariateNoise
from causal_validation.transforms.noise import (
CovariateNoise,
Noise,
)
from causal_validation.transforms.periodic import Periodic
from causal_validation.transforms.trends import Trend

Expand Down
43 changes: 34 additions & 9 deletions src/causal_validation/transforms/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,9 +75,17 @@ def apply_values(
Xte = Xte + post_intervention_vals[:, 1:]
yte = yte + post_intervention_vals[:, :1]
return Dataset(
Xtr, Xte, ytr, yte, data._start_date,
data.Ptr, data.Pte, data.Rtr, data.Rte,
data.counterfactual, data.synthetic
Xtr,
Xte,
ytr,
yte,
data._start_date,
data.Ptr,
data.Pte,
data.Rtr,
data.Rte,
data.counterfactual,
data.synthetic,
)


Expand All @@ -96,11 +104,20 @@ def apply_values(
Xte = Xte * post_intervention_vals
yte = yte * post_intervention_vals
return Dataset(
Xtr, Xte, ytr, yte, data._start_date,
data.Ptr, data.Pte, data.Rtr, data.Rte,
data.counterfactual, data.synthetic
Xtr,
Xte,
ytr,
yte,
data._start_date,
data.Ptr,
data.Pte,
data.Rtr,
data.Rte,
data.counterfactual,
data.synthetic,
)


@dataclass(kw_only=True)
class AdditiveCovariateTransform(AbstractTransform):
def apply_values(
Expand All @@ -116,7 +133,15 @@ def apply_values(
Pte = Pte + post_intervention_vals[:, 1:, :]
Rte = Rte + post_intervention_vals[:, :1, :]
return Dataset(
data.Xtr, data.Xte, data.ytr, data.yte,
data._start_date, Ptr, Pte, Rtr, Rte,
data.counterfactual, data.synthetic
data.Xtr,
data.Xte,
data.ytr,
data.yte,
data._start_date,
Ptr,
Pte,
Rtr,
Rte,
data.counterfactual,
data.synthetic,
)
13 changes: 8 additions & 5 deletions src/causal_validation/transforms/noise.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
from dataclasses import dataclass, field
from dataclasses import (
dataclass,
field,
)
from typing import Tuple

from jaxtyping import Float
Expand All @@ -7,12 +10,12 @@

from causal_validation.data import Dataset
from causal_validation.transforms.base import (
AdditiveCovariateTransform,
AdditiveOutputTransform,
AdditiveCovariateTransform
)
from causal_validation.transforms.parameter import (
CovariateNoiseParameter,
TimeVaryingParameter,
CovariateNoiseParameter
)


Expand Down Expand Up @@ -53,8 +56,8 @@ class CovariateNoise(AdditiveCovariateTransform):

def get_values(self, data: Dataset) -> Float[np.ndarray, "N D"]:
noise = self.noise_dist.get_value(
n_units=data.n_units+1,
n_units=data.n_units + 1,
n_timepoints=data.n_timepoints,
n_covariates=data.n_covariates
n_covariates=data.n_covariates,
)
return noise
3 changes: 1 addition & 2 deletions src/causal_validation/transforms/parameter.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,8 +57,7 @@ def get_value(
self, n_units: int, n_timepoints: int, n_covariates: int
) -> Float[np.ndarray, "{n_timepoints} {n_units} {n_covariates}"]:
covariate_noise = self.sampling_dist.rvs(
size=(n_timepoints, n_units, n_covariates),
random_state=self.random_state
size=(n_timepoints, n_units, n_covariates), random_state=self.random_state
)
return covariate_noise

Expand Down
21 changes: 12 additions & 9 deletions tests/test_causal_validation/test_config.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,22 @@
from hypothesis import (
given,
strategies as st,
)
import numpy as np
from hypothesis import given, strategies as st

from causal_validation.config import Config


@given(
n_units=st.integers(min_value=1, max_value=10),
n_pre=st.integers(min_value=1, max_value=20),
n_post=st.integers(min_value=1, max_value=20)
n_post=st.integers(min_value=1, max_value=20),
)
def test_config_basic_initialization(n_units, n_pre, n_post):
cfg = Config(
n_control_units=n_units,
n_pre_intervention_timepoints=n_pre,
n_post_intervention_timepoints=n_post
n_post_intervention_timepoints=n_post,
)
assert cfg.n_control_units == n_units
assert cfg.n_pre_intervention_timepoints == n_pre
Expand All @@ -29,7 +32,7 @@ def test_config_basic_initialization(n_units, n_pre, n_post):
n_pre=st.integers(min_value=1, max_value=10),
n_post=st.integers(min_value=1, max_value=10),
n_covariates=st.integers(min_value=1, max_value=3),
seed=st.integers(min_value=1, max_value=1000)
seed=st.integers(min_value=1, max_value=1000),
)
def test_config_with_covariates_auto_generation(
n_units, n_pre, n_post, n_covariates, seed
Expand All @@ -39,7 +42,7 @@ def test_config_with_covariates_auto_generation(
n_pre_intervention_timepoints=n_pre,
n_post_intervention_timepoints=n_post,
n_covariates=n_covariates,
seed=seed
seed=seed,
)
assert cfg.n_covariates == n_covariates
assert cfg.covariate_means.shape == (n_units, n_covariates)
Expand All @@ -50,7 +53,7 @@ def test_config_with_covariates_auto_generation(

@given(
n_units=st.integers(min_value=1, max_value=3),
n_covariates=st.integers(min_value=1, max_value=3)
n_covariates=st.integers(min_value=1, max_value=3),
)
def test_config_with_explicit_covariate_means(n_units, n_covariates):
means = np.random.random((n_units, n_covariates))
Expand All @@ -59,14 +62,14 @@ def test_config_with_explicit_covariate_means(n_units, n_covariates):
n_pre_intervention_timepoints=10,
n_post_intervention_timepoints=5,
n_covariates=n_covariates,
covariate_means=means
covariate_means=means,
)
np.testing.assert_array_equal(cfg.covariate_means, means)


@given(
n_units=st.integers(min_value=1, max_value=3),
n_covariates=st.integers(min_value=1, max_value=3)
n_covariates=st.integers(min_value=1, max_value=3),
)
def test_config_with_explicit_covariate_stds(n_units, n_covariates):
stds = np.random.random((n_units, n_covariates)) + 0.1
Expand All @@ -75,6 +78,6 @@ def test_config_with_explicit_covariate_stds(n_units, n_covariates):
n_pre_intervention_timepoints=10,
n_post_intervention_timepoints=5,
n_covariates=n_covariates,
covariate_stds=stds
covariate_stds=stds,
)
np.testing.assert_array_equal(cfg.covariate_stds, stds)
Loading