Skip to content

Commit 4c3e783

Browse files
Format changes. (#42)
1 parent 129ce7e commit 4c3e783

File tree

13 files changed

+164
-87
lines changed

13 files changed

+164
-87
lines changed

docs/examples/azcausal.ipynb

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -185,7 +185,7 @@
185185
"panel = data.to_azcausal()\n",
186186
"model = SDID()\n",
187187
"result = model.fit(panel)\n",
188-
"print(f\"Delta: {100*(TRUE_EFFECT - result.effect.percentage().value / 100): .2f}%\")\n",
188+
"print(f\"Delta: {100 * (TRUE_EFFECT - result.effect.percentage().value / 100): .2f}%\")\n",
189189
"print(result.summary(title=\"Synthetic Data Experiment\"))"
190190
]
191191
}

src/causal_validation/config.py

Lines changed: 15 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,9 @@
33
field,
44
)
55
import datetime as dt
6-
7-
from jaxtyping import Float
86
import typing as tp
97

8+
from jaxtyping import Float
109
import numpy as np
1110
from scipy.stats import halfcauchy
1211

@@ -49,6 +48,7 @@ class Config:
4948
weights_cfg (WeightConfig): Configuration for unit weights. Defaults to
5049
UniformWeights.
5150
"""
51+
5252
n_control_units: int
5353
n_pre_intervention_timepoints: int
5454
n_post_intervention_timepoints: int
@@ -65,25 +65,27 @@ class Config:
6565
def __post_init__(self):
6666
self.rng = np.random.RandomState(self.seed)
6767
if self.covariate_means is not None:
68-
assert self.covariate_means.shape == (self.n_control_units,
69-
self.n_covariates)
68+
assert self.covariate_means.shape == (
69+
self.n_control_units,
70+
self.n_covariates,
71+
)
7072

7173
if self.covariate_stds is not None:
72-
assert self.covariate_stds.shape == (self.n_control_units,
73-
self.n_covariates)
74+
assert self.covariate_stds.shape == (
75+
self.n_control_units,
76+
self.n_covariates,
77+
)
7478

7579
if (self.n_covariates is not None) & (self.covariate_means is None):
7680
self.covariate_means = self.rng.normal(
77-
loc=0.0, scale=5.0, size=(self.n_control_units,
78-
self.n_covariates)
81+
loc=0.0, scale=5.0, size=(self.n_control_units, self.n_covariates)
7982
)
8083

8184
if (self.n_covariates is not None) & (self.covariate_stds is None):
82-
self.covariate_stds = (
83-
halfcauchy.rvs(scale=0.5,
84-
size=(self.n_control_units,
85-
self.n_covariates),
86-
random_state=self.rng)
85+
self.covariate_stds = halfcauchy.rvs(
86+
scale=0.5,
87+
size=(self.n_control_units, self.n_covariates),
88+
random_state=self.rng,
8789
)
8890

8991
if (self.n_covariates is not None) & (self.covariate_coeffs is None):

src/causal_validation/data.py

Lines changed: 29 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@ class Dataset:
4141
treated in pre-intervention period.
4242
_name: Optional name identifier for the dataset
4343
"""
44+
4445
Xtr: Float[np.ndarray, "N D"]
4546
Xte: Float[np.ndarray, "M D"]
4647
ytr: Float[np.ndarray, "N 1"]
@@ -161,7 +162,8 @@ def pre_intervention_covariates(
161162
self,
162163
) -> tp.Optional[
163164
tp.Tuple[
164-
Float[np.ndarray, "N D F"], Float[np.ndarray, "N 1 F"],
165+
Float[np.ndarray, "N D F"],
166+
Float[np.ndarray, "N 1 F"],
165167
]
166168
]:
167169
if self.has_covariates:
@@ -174,7 +176,8 @@ def post_intervention_covariates(
174176
self,
175177
) -> tp.Optional[
176178
tp.Tuple[
177-
Float[np.ndarray, "M D F"], Float[np.ndarray, "M 1 F"],
179+
Float[np.ndarray, "M D F"],
180+
Float[np.ndarray, "M 1 F"],
178181
]
179182
]:
180183
if self.has_covariates:
@@ -220,8 +223,18 @@ def inflate(self, inflation_vals: Float[np.ndarray, "M 1"]) -> Dataset:
220223
Xte, yte = [deepcopy(i) for i in self.post_intervention_obs]
221224
inflated_yte = yte * inflation_vals
222225
return Dataset(
223-
Xtr, Xte, ytr, inflated_yte, self._start_date,
224-
self.Ptr, self.Pte, self.Rtr, self.Rte, yte, self.synthetic, self._name
226+
Xtr,
227+
Xte,
228+
ytr,
229+
inflated_yte,
230+
self._start_date,
231+
self.Ptr,
232+
self.Pte,
233+
self.Rtr,
234+
self.Rte,
235+
yte,
236+
self.synthetic,
237+
self._name,
225238
)
226239

227240
def __eq__(self, other: Dataset) -> bool:
@@ -326,7 +339,16 @@ def reassign_treatment(
326339
Xtr = data.Xtr
327340
Xte = data.Xte
328341
return Dataset(
329-
Xtr, Xte, ytr, yte, data._start_date,
330-
data.Ptr, data.Pte, data.Rtr, data.Rte,
331-
data.counterfactual, data.synthetic, data._name
342+
Xtr,
343+
Xte,
344+
ytr,
345+
yte,
346+
data._start_date,
347+
data.Ptr,
348+
data.Pte,
349+
data.Rtr,
350+
data.Rte,
351+
data.counterfactual,
352+
data.synthetic,
353+
data._name,
332354
)

src/causal_validation/simulate.py

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -37,11 +37,11 @@ def _simulate_base_obs(
3737
covariates = key.normal(
3838
loc=config.covariate_means,
3939
scale=config.covariate_stds,
40-
size=(n_timepoints, n_units, config.n_covariates)
40+
size=(n_timepoints, n_units, config.n_covariates),
4141
)
4242

43-
Ptr = covariates[:config.n_pre_intervention_timepoints, :, :]
44-
Pte = covariates[config.n_pre_intervention_timepoints:, :, :]
43+
Ptr = covariates[: config.n_pre_intervention_timepoints, :, :]
44+
Pte = covariates[config.n_pre_intervention_timepoints :, :, :]
4545

4646
Xtr = Xtr_ + Ptr @ config.covariate_coeffs
4747
Xte = Xte_ + Pte @ config.covariate_coeffs
@@ -53,8 +53,15 @@ def _simulate_base_obs(
5353
Rte = weights.weight_contr(Pte)
5454

5555
data = Dataset(
56-
Xtr, Xte, ytr, yte, _start_date=config.start_date,
57-
Ptr=Ptr, Pte=Pte, Rtr=Rtr, Rte=Rte
56+
Xtr,
57+
Xte,
58+
ytr,
59+
yte,
60+
_start_date=config.start_date,
61+
Ptr=Ptr,
62+
Pte=Pte,
63+
Rtr=Rtr,
64+
Rte=Rte,
5865
)
5966
else:
6067
Xtr = obs[: config.n_pre_intervention_timepoints, :]

src/causal_validation/transforms/__init__.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,7 @@
1-
from causal_validation.transforms.noise import Noise, CovariateNoise
1+
from causal_validation.transforms.noise import (
2+
CovariateNoise,
3+
Noise,
4+
)
25
from causal_validation.transforms.periodic import Periodic
36
from causal_validation.transforms.trends import Trend
47

src/causal_validation/transforms/base.py

Lines changed: 34 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -75,9 +75,17 @@ def apply_values(
7575
Xte = Xte + post_intervention_vals[:, 1:]
7676
yte = yte + post_intervention_vals[:, :1]
7777
return Dataset(
78-
Xtr, Xte, ytr, yte, data._start_date,
79-
data.Ptr, data.Pte, data.Rtr, data.Rte,
80-
data.counterfactual, data.synthetic
78+
Xtr,
79+
Xte,
80+
ytr,
81+
yte,
82+
data._start_date,
83+
data.Ptr,
84+
data.Pte,
85+
data.Rtr,
86+
data.Rte,
87+
data.counterfactual,
88+
data.synthetic,
8189
)
8290

8391

@@ -96,11 +104,20 @@ def apply_values(
96104
Xte = Xte * post_intervention_vals
97105
yte = yte * post_intervention_vals
98106
return Dataset(
99-
Xtr, Xte, ytr, yte, data._start_date,
100-
data.Ptr, data.Pte, data.Rtr, data.Rte,
101-
data.counterfactual, data.synthetic
107+
Xtr,
108+
Xte,
109+
ytr,
110+
yte,
111+
data._start_date,
112+
data.Ptr,
113+
data.Pte,
114+
data.Rtr,
115+
data.Rte,
116+
data.counterfactual,
117+
data.synthetic,
102118
)
103119

120+
104121
@dataclass(kw_only=True)
105122
class AdditiveCovariateTransform(AbstractTransform):
106123
def apply_values(
@@ -116,7 +133,15 @@ def apply_values(
116133
Pte = Pte + post_intervention_vals[:, 1:, :]
117134
Rte = Rte + post_intervention_vals[:, :1, :]
118135
return Dataset(
119-
data.Xtr, data.Xte, data.ytr, data.yte,
120-
data._start_date, Ptr, Pte, Rtr, Rte,
121-
data.counterfactual, data.synthetic
136+
data.Xtr,
137+
data.Xte,
138+
data.ytr,
139+
data.yte,
140+
data._start_date,
141+
Ptr,
142+
Pte,
143+
Rtr,
144+
Rte,
145+
data.counterfactual,
146+
data.synthetic,
122147
)

src/causal_validation/transforms/noise.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,7 @@
1-
from dataclasses import dataclass, field
1+
from dataclasses import (
2+
dataclass,
3+
field,
4+
)
25
from typing import Tuple
36

47
from jaxtyping import Float
@@ -7,12 +10,12 @@
710

811
from causal_validation.data import Dataset
912
from causal_validation.transforms.base import (
13+
AdditiveCovariateTransform,
1014
AdditiveOutputTransform,
11-
AdditiveCovariateTransform
1215
)
1316
from causal_validation.transforms.parameter import (
17+
CovariateNoiseParameter,
1418
TimeVaryingParameter,
15-
CovariateNoiseParameter
1619
)
1720

1821

@@ -53,8 +56,8 @@ class CovariateNoise(AdditiveCovariateTransform):
5356

5457
def get_values(self, data: Dataset) -> Float[np.ndarray, "N D"]:
5558
noise = self.noise_dist.get_value(
56-
n_units=data.n_units+1,
59+
n_units=data.n_units + 1,
5760
n_timepoints=data.n_timepoints,
58-
n_covariates=data.n_covariates
61+
n_covariates=data.n_covariates,
5962
)
6063
return noise

src/causal_validation/transforms/parameter.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -57,8 +57,7 @@ def get_value(
5757
self, n_units: int, n_timepoints: int, n_covariates: int
5858
) -> Float[np.ndarray, "{n_timepoints} {n_units} {n_covariates}"]:
5959
covariate_noise = self.sampling_dist.rvs(
60-
size=(n_timepoints, n_units, n_covariates),
61-
random_state=self.random_state
60+
size=(n_timepoints, n_units, n_covariates), random_state=self.random_state
6261
)
6362
return covariate_noise
6463

tests/test_causal_validation/test_config.py

Lines changed: 12 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,22 @@
1+
from hypothesis import (
2+
given,
3+
strategies as st,
4+
)
15
import numpy as np
2-
from hypothesis import given, strategies as st
36

47
from causal_validation.config import Config
58

69

710
@given(
811
n_units=st.integers(min_value=1, max_value=10),
912
n_pre=st.integers(min_value=1, max_value=20),
10-
n_post=st.integers(min_value=1, max_value=20)
13+
n_post=st.integers(min_value=1, max_value=20),
1114
)
1215
def test_config_basic_initialization(n_units, n_pre, n_post):
1316
cfg = Config(
1417
n_control_units=n_units,
1518
n_pre_intervention_timepoints=n_pre,
16-
n_post_intervention_timepoints=n_post
19+
n_post_intervention_timepoints=n_post,
1720
)
1821
assert cfg.n_control_units == n_units
1922
assert cfg.n_pre_intervention_timepoints == n_pre
@@ -29,7 +32,7 @@ def test_config_basic_initialization(n_units, n_pre, n_post):
2932
n_pre=st.integers(min_value=1, max_value=10),
3033
n_post=st.integers(min_value=1, max_value=10),
3134
n_covariates=st.integers(min_value=1, max_value=3),
32-
seed=st.integers(min_value=1, max_value=1000)
35+
seed=st.integers(min_value=1, max_value=1000),
3336
)
3437
def test_config_with_covariates_auto_generation(
3538
n_units, n_pre, n_post, n_covariates, seed
@@ -39,7 +42,7 @@ def test_config_with_covariates_auto_generation(
3942
n_pre_intervention_timepoints=n_pre,
4043
n_post_intervention_timepoints=n_post,
4144
n_covariates=n_covariates,
42-
seed=seed
45+
seed=seed,
4346
)
4447
assert cfg.n_covariates == n_covariates
4548
assert cfg.covariate_means.shape == (n_units, n_covariates)
@@ -50,7 +53,7 @@ def test_config_with_covariates_auto_generation(
5053

5154
@given(
5255
n_units=st.integers(min_value=1, max_value=3),
53-
n_covariates=st.integers(min_value=1, max_value=3)
56+
n_covariates=st.integers(min_value=1, max_value=3),
5457
)
5558
def test_config_with_explicit_covariate_means(n_units, n_covariates):
5659
means = np.random.random((n_units, n_covariates))
@@ -59,14 +62,14 @@ def test_config_with_explicit_covariate_means(n_units, n_covariates):
5962
n_pre_intervention_timepoints=10,
6063
n_post_intervention_timepoints=5,
6164
n_covariates=n_covariates,
62-
covariate_means=means
65+
covariate_means=means,
6366
)
6467
np.testing.assert_array_equal(cfg.covariate_means, means)
6568

6669

6770
@given(
6871
n_units=st.integers(min_value=1, max_value=3),
69-
n_covariates=st.integers(min_value=1, max_value=3)
72+
n_covariates=st.integers(min_value=1, max_value=3),
7073
)
7174
def test_config_with_explicit_covariate_stds(n_units, n_covariates):
7275
stds = np.random.random((n_units, n_covariates)) + 0.1
@@ -75,6 +78,6 @@ def test_config_with_explicit_covariate_stds(n_units, n_covariates):
7578
n_pre_intervention_timepoints=10,
7679
n_post_intervention_timepoints=5,
7780
n_covariates=n_covariates,
78-
covariate_stds=stds
81+
covariate_stds=stds,
7982
)
8083
np.testing.assert_array_equal(cfg.covariate_stds, stds)

0 commit comments

Comments
 (0)