Skip to content

Commit 129ce7e

Browse files
Noise transformation for covariates (#41)
* Dataset to_df revision to support covariates properly. * Preserve covariates after transformation * Add noise transformation for covariates
1 parent 93bf0b4 commit 129ce7e

File tree

10 files changed

+285
-30
lines changed

10 files changed

+285
-30
lines changed

src/causal_validation/data.py

Lines changed: 30 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -62,16 +62,36 @@ def __post_init__(self):
6262

6363
def to_df(
6464
self, index_start: str = dt.date(year=2023, month=1, day=1)
65-
) -> pd.DataFrame:
66-
inputs = np.vstack([self.Xtr, self.Xte])
67-
outputs = np.vstack([self.ytr, self.yte])
68-
data = np.hstack([outputs, inputs])
65+
) -> tp.Tuple[pd.DataFrame, tp.Optional[pd.DataFrame]]:
66+
control_outputs = np.vstack([self.Xtr, self.Xte])
67+
treated_outputs = np.vstack([self.ytr, self.yte])
68+
data = np.hstack([treated_outputs, control_outputs])
6969
index = self._get_index(index_start)
7070
colnames = self._get_columns()
7171
indicator = self._get_indicator()
72-
df = pd.DataFrame(data, index=index, columns=colnames)
73-
df = df.assign(treated=indicator)
74-
return df
72+
df_outputs = pd.DataFrame(data, index=index, columns=colnames)
73+
df_outputs = df_outputs.assign(treated=indicator)
74+
75+
if not self.has_covariates:
76+
cov_df = None
77+
else:
78+
control_covs = np.concatenate([self.Ptr, self.Pte], axis=0)
79+
treated_covs = np.concatenate([self.Rtr, self.Rte], axis=0)
80+
81+
all_covs = np.concatenate([treated_covs, control_covs], axis=1)
82+
83+
unit_cols = self._get_columns()
84+
covariate_cols = [f"F{i}" for i in range(self.n_covariates)]
85+
86+
cov_data = all_covs.reshape(self.n_timepoints, -1)
87+
88+
col_tuples = [(unit, cov) for unit in unit_cols for cov in covariate_cols]
89+
multi_cols = pd.MultiIndex.from_tuples(col_tuples)
90+
91+
cov_df = pd.DataFrame(cov_data, index=index, columns=multi_cols)
92+
cov_df = cov_df.assign(treated=indicator)
93+
94+
return df_outputs, cov_df
7595

7696
@property
7797
def n_post_intervention(self) -> int:
@@ -180,12 +200,7 @@ def get_index(self, period: InterventionTypes) -> DatetimeIndex:
180200
return self.full_index
181201

182202
def _get_columns(self) -> tp.List[str]:
183-
if self.has_covariates:
184-
colnames = ["T"] + [f"C{i}" for i in range(self.n_units)] + [
185-
f"F{i}" for i in range(self.n_covariates)
186-
]
187-
else:
188-
colnames = ["T"] + [f"C{i}" for i in range(self.n_units)]
203+
colnames = ["T"] + [f"C{i}" for i in range(self.n_units)]
189204
return colnames
190205

191206
def _get_index(self, start_date: dt.date) -> DatetimeIndex:
@@ -224,7 +239,8 @@ def __eq__(self, other: Dataset) -> bool:
224239

225240
def to_azcausal(self):
226241
time_index = np.arange(self.n_timepoints)
227-
data = self.to_df().assign(time=time_index).melt(id_vars=["time", "treated"])
242+
data_df, _ = self.to_df()
243+
data = data_df.assign(time=time_index).melt(id_vars=["time", "treated"])
228244
data.loc[:, "treated"] = np.where(
229245
(data["variable"] == "T") & (data["treated"] == 1.0), 1, 0
230246
)

src/causal_validation/testing.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ class TestConstants:
1111
N_CONTROL: int = 10
1212
N_PRE_TREATMENT: int = 500
1313
N_POST_TREATMENT: int = 500
14+
N_COVARIATES: tp.Optional[int] = None
1415
DATA_SLOTS: tp.Tuple[str, str, str, str] = ("Xtr", "Xte", "ytr", "yte")
1516
ZERO_DIVISION_ERROR: float = 1e-6
1617
GLOBAL_SCALE: float = 1.0
@@ -26,6 +27,7 @@ def simulate_data(
2627
n_control_units=constants.N_CONTROL,
2728
n_pre_intervention_timepoints=constants.N_PRE_TREATMENT,
2829
n_post_intervention_timepoints=constants.N_POST_TREATMENT,
30+
n_covariates=constants.N_COVARIATES,
2931
global_mean=global_mean,
3032
global_scale=constants.GLOBAL_SCALE,
3133
seed=seed,
Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
1-
from causal_validation.transforms.noise import Noise
1+
from causal_validation.transforms.noise import Noise, CovariateNoise
22
from causal_validation.transforms.periodic import Periodic
33
from causal_validation.transforms.trends import Trend
44

5-
__all__ = ["Trend", "Periodic", "Noise"]
5+
__all__ = ["Trend", "Periodic", "Noise", "CovariateNoise"]

src/causal_validation/transforms/base.py

Lines changed: 28 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,7 @@ def _get_parameter_values(self, data: Dataset) -> tp.Dict[str, np.ndarray]:
6161

6262

6363
@dataclass(kw_only=True)
64-
class AdditiveTransform(AbstractTransform):
64+
class AdditiveOutputTransform(AbstractTransform):
6565
def apply_values(
6666
self,
6767
pre_intervention_vals: np.ndarray,
@@ -75,12 +75,14 @@ def apply_values(
7575
Xte = Xte + post_intervention_vals[:, 1:]
7676
yte = yte + post_intervention_vals[:, :1]
7777
return Dataset(
78-
Xtr, Xte, ytr, yte, data._start_date, data.counterfactual, data.synthetic
78+
Xtr, Xte, ytr, yte, data._start_date,
79+
data.Ptr, data.Pte, data.Rtr, data.Rte,
80+
data.counterfactual, data.synthetic
7981
)
8082

8183

8284
@dataclass(kw_only=True)
83-
class MultiplicativeTransform(AbstractTransform):
85+
class MultiplicativeOutputTransform(AbstractTransform):
8486
def apply_values(
8587
self,
8688
pre_intervention_vals: np.ndarray,
@@ -94,5 +96,27 @@ def apply_values(
9496
Xte = Xte * post_intervention_vals
9597
yte = yte * post_intervention_vals
9698
return Dataset(
97-
Xtr, Xte, ytr, yte, data._start_date, data.counterfactual, data.synthetic
99+
Xtr, Xte, ytr, yte, data._start_date,
100+
data.Ptr, data.Pte, data.Rtr, data.Rte,
101+
data.counterfactual, data.synthetic
102+
)
103+
104+
@dataclass(kw_only=True)
105+
class AdditiveCovariateTransform(AbstractTransform):
106+
def apply_values(
107+
self,
108+
pre_intervention_vals: np.ndarray,
109+
post_intervention_vals: np.ndarray,
110+
data: Dataset,
111+
) -> Dataset:
112+
Ptr, Rtr = [deepcopy(i) for i in data.pre_intervention_covariates]
113+
Pte, Rte = [deepcopy(i) for i in data.post_intervention_covariates]
114+
Ptr = Ptr + pre_intervention_vals[:, 1:, :]
115+
Rtr = Rtr + pre_intervention_vals[:, :1, :]
116+
Pte = Pte + post_intervention_vals[:, 1:, :]
117+
Rte = Rte + post_intervention_vals[:, :1, :]
118+
return Dataset(
119+
data.Xtr, data.Xte, data.ytr, data.yte,
120+
data._start_date, Ptr, Pte, Rtr, Rte,
121+
data.counterfactual, data.synthetic
98122
)

src/causal_validation/transforms/noise.py

Lines changed: 31 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,12 +6,18 @@
66
from scipy.stats import norm
77

88
from causal_validation.data import Dataset
9-
from causal_validation.transforms.base import AdditiveTransform
10-
from causal_validation.transforms.parameter import TimeVaryingParameter
9+
from causal_validation.transforms.base import (
10+
AdditiveOutputTransform,
11+
AdditiveCovariateTransform
12+
)
13+
from causal_validation.transforms.parameter import (
14+
TimeVaryingParameter,
15+
CovariateNoiseParameter
16+
)
1117

1218

1319
@dataclass(kw_only=True)
14-
class Noise(AdditiveTransform):
20+
class Noise(AdditiveOutputTransform):
1521
"""
1622
Transform the treatment by adding TimeVaryingParameter noise terms sampled from
1723
a specified sampling distribution. By default, the sampling distribution is
@@ -30,3 +36,25 @@ def get_values(self, data: Dataset) -> Float[np.ndarray, "N D"]:
3036
).reshape(-1)
3137
noise[:, 0] = noise_treatment
3238
return noise
39+
40+
41+
@dataclass(kw_only=True)
42+
class CovariateNoise(AdditiveCovariateTransform):
43+
"""
44+
Transform the covariates by adding CovariateNoiseParameter noise terms sampled from
45+
a specified sampling distribution. By default, the sampling distribution is
46+
Normal with 0 loc and 0.1 scale.
47+
"""
48+
49+
noise_dist: CovariateNoiseParameter = field(
50+
default_factory=lambda: CovariateNoiseParameter(sampling_dist=norm(0, 0.1))
51+
)
52+
_slots: Tuple[str] = ("noise_dist",)
53+
54+
def get_values(self, data: Dataset) -> Float[np.ndarray, "N D"]:
55+
noise = self.noise_dist.get_value(
56+
n_units=data.n_units+1,
57+
n_timepoints=data.n_timepoints,
58+
n_covariates=data.n_covariates
59+
)
60+
return noise

src/causal_validation/transforms/parameter.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,18 @@ def get_value(
5151
return np.tile(time_param, reps=n_units)
5252

5353

54+
@dataclass
55+
class CovariateNoiseParameter(RandomParameter):
56+
def get_value(
57+
self, n_units: int, n_timepoints: int, n_covariates: int
58+
) -> Float[np.ndarray, "{n_timepoints} {n_units} {n_covariates}"]:
59+
covariate_noise = self.sampling_dist.rvs(
60+
size=(n_timepoints, n_units, n_covariates),
61+
random_state=self.random_state
62+
)
63+
return covariate_noise
64+
65+
5466
ParameterOrFloat = tp.Union[Parameter, float]
5567

5668

src/causal_validation/transforms/periodic.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,12 +5,12 @@
55
import numpy as np
66

77
from causal_validation.data import Dataset
8-
from causal_validation.transforms.base import AdditiveTransform
8+
from causal_validation.transforms.base import AdditiveOutputTransform
99
from causal_validation.transforms.parameter import ParameterOrFloat
1010

1111

1212
@dataclass(kw_only=True)
13-
class Periodic(AdditiveTransform):
13+
class Periodic(AdditiveOutputTransform):
1414
amplitude: ParameterOrFloat = 1.0
1515
frequency: ParameterOrFloat = 1.0
1616
shift: ParameterOrFloat = 0.0

src/causal_validation/transforms/trends.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,12 +5,12 @@
55
import numpy as np
66

77
from causal_validation.data import Dataset
8-
from causal_validation.transforms.base import AdditiveTransform
8+
from causal_validation.transforms.base import AdditiveOutputTransform
99
from causal_validation.transforms.parameter import ParameterOrFloat
1010

1111

1212
@dataclass(kw_only=True)
13-
class Trend(AdditiveTransform):
13+
class Trend(AdditiveOutputTransform):
1414
degree: int = 1
1515
coefficient: ParameterOrFloat = 1.0
1616
intercept: ParameterOrFloat = 0.0

tests/test_causal_validation/test_data.py

Lines changed: 44 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
MAX_STRING_LENGTH = 20
3030
DEFAULT_SEED = 123
3131
NUM_NON_CONTROL_COLS = 2
32+
NUM_TREATED = 1
3233
LARGE_N_POST = 5000
3334
LARGE_N_PRE = 5000
3435

@@ -109,15 +110,15 @@ def test_indicator(n_pre_treatment: int, n_post_treatment: int):
109110
n_pre_treatment=st.integers(min_value=1, max_value=50),
110111
n_post_treatment=st.integers(min_value=1, max_value=50),
111112
)
112-
def test_to_df(n_control: int, n_pre_treatment: int, n_post_treatment: int):
113+
def test_to_df_no_cov(n_control: int, n_pre_treatment: int, n_post_treatment: int):
113114
constants = TestConstants(
114115
N_POST_TREATMENT=n_post_treatment,
115116
N_PRE_TREATMENT=n_pre_treatment,
116117
N_CONTROL=n_control,
117118
)
118119
data = simulate_data(0.0, DEFAULT_SEED, constants=constants)
119120

120-
df = data.to_df()
121+
df, _ = data.to_df()
121122
assert isinstance(df, pd.DataFrame)
122123
assert df.shape == (
123124
n_pre_treatment + n_post_treatment,
@@ -133,6 +134,47 @@ def test_to_df(n_control: int, n_pre_treatment: int, n_post_treatment: int):
133134
assert isinstance(index, DatetimeIndex)
134135
assert index[0].strftime("%Y-%m-%d") == data._start_date.strftime("%Y-%m-%d")
135136

137+
@given(
138+
n_control=st.integers(min_value=1, max_value=50),
139+
n_pre_treatment=st.integers(min_value=1, max_value=50),
140+
n_post_treatment=st.integers(min_value=1, max_value=50),
141+
n_covariates=st.integers(min_value=1, max_value=50),
142+
)
143+
def test_to_df_with_cov(n_control: int,
144+
n_pre_treatment: int,
145+
n_post_treatment: int,
146+
n_covariates:int):
147+
constants = TestConstants(
148+
N_POST_TREATMENT=n_post_treatment,
149+
N_PRE_TREATMENT=n_pre_treatment,
150+
N_CONTROL=n_control,
151+
N_COVARIATES=n_covariates,
152+
)
153+
data = simulate_data(0.0, DEFAULT_SEED, constants=constants)
154+
155+
df_outs, df_covs = data.to_df()
156+
assert isinstance(df_outs, pd.DataFrame)
157+
assert df_outs.shape == (
158+
n_pre_treatment + n_post_treatment,
159+
n_control + NUM_NON_CONTROL_COLS,
160+
)
161+
162+
assert isinstance(df_covs, pd.DataFrame)
163+
assert df_covs.shape == (
164+
n_pre_treatment + n_post_treatment,
165+
n_covariates * (n_control + NUM_TREATED)
166+
+ NUM_NON_CONTROL_COLS - NUM_TREATED,
167+
)
168+
169+
colnames = data._get_columns()
170+
assert isinstance(colnames, list)
171+
assert colnames[0] == "T"
172+
assert len(colnames) == n_control + 1
173+
174+
index = data.full_index
175+
assert isinstance(index, DatetimeIndex)
176+
assert index[0].strftime("%Y-%m-%d") == data._start_date.strftime("%Y-%m-%d")
177+
136178

137179
@given(
138180
n_control=st.integers(min_value=2, max_value=50),

0 commit comments

Comments
 (0)