Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
59 changes: 59 additions & 0 deletions src/causal_validation/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,11 @@
)
import datetime as dt

from jaxtyping import Float
import typing as tp

import numpy as np
from scipy.stats import halfcauchy

from causal_validation.types import (
Number,
Expand All @@ -20,9 +24,38 @@ class WeightConfig:

@dataclass(kw_only=True)
class Config:
"""Configuration for causal data generation.

Args:
n_control_units (int): Number of control units in the synthetic dataset.
n_pre_intervention_timepoints (int): Number of time points before intervention.
n_post_intervention_timepoints (int): Number of time points after intervention.
n_covariates (Optional[int]): Number of covariates. Defaults to None.
covariate_means (Optional[Float[np.ndarray, "D K"]]): Mean values for covariates
D is n_control_units and K is n_covariates. Defaults to None. If it is set
to None while n_covariates is provided, covariate_means will be generated
randomly from Normal distribution.
covariate_stds (Optional[Float[np.ndarray, "D K"]]): Standard deviations for
covariates. D is n_control_units and K is n_covariates. Defaults to None.
If it is set to None while n_covariates is provided, covariate_stds
will be generated randomly from Half-Cauchy distribution.
covariate_coeffs (Optional[np.ndarray]): Linear regression
coefficients to map covariates to output observations. K is n_covariates.
Defaults to None.
global_mean (Number): Global mean for data generation. Defaults to 20.0.
global_scale (Number): Global scale for data generation. Defaults to 0.2.
start_date (dt.date): Start date for time series. Defaults to 2023-01-01.
seed (int): Random seed for reproducibility. Defaults to 123.
weights_cfg (WeightConfig): Configuration for unit weights. Defaults to
UniformWeights.
"""
n_control_units: int
n_pre_intervention_timepoints: int
n_post_intervention_timepoints: int
n_covariates: tp.Optional[int] = None
covariate_means: tp.Optional[Float[np.ndarray, "D K"]] = None
covariate_stds: tp.Optional[Float[np.ndarray, "D K"]] = None
covariate_coeffs: tp.Optional[np.ndarray] = None
global_mean: Number = 20.0
global_scale: Number = 0.2
start_date: dt.date = dt.date(year=2023, month=1, day=1)
Expand All @@ -31,3 +64,29 @@ class Config:

def __post_init__(self):
self.rng = np.random.RandomState(self.seed)
if self.covariate_means is not None:
assert self.covariate_means.shape == (self.n_control_units,
self.n_covariates)

if self.covariate_stds is not None:
assert self.covariate_stds.shape == (self.n_control_units,
self.n_covariates)

if (self.n_covariates is not None) & (self.covariate_means is None):
self.covariate_means = self.rng.normal(
loc=0.0, scale=5.0, size=(self.n_control_units,
self.n_covariates)
)

if (self.n_covariates is not None) & (self.covariate_stds is None):
self.covariate_stds = (
halfcauchy.rvs(scale=0.5,
size=(self.n_control_units,
self.n_covariates),
random_state=self.rng)
)

if (self.n_covariates is not None) & (self.covariate_coeffs is None):
self.covariate_coeffs = self.rng.normal(
loc=0.0, scale=5.0, size=self.n_covariates
)
41 changes: 36 additions & 5 deletions src/causal_validation/simulate.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,9 +29,40 @@ def _simulate_base_obs(
obs = key.normal(
loc=config.global_mean, scale=config.global_scale, size=(n_timepoints, n_units)
)
Xtr = obs[: config.n_pre_intervention_timepoints, :]
Xte = obs[config.n_pre_intervention_timepoints :, :]
ytr = weights.weight_obs(Xtr)
yte = weights.weight_obs(Xte)
data = Dataset(Xtr, Xte, ytr, yte, _start_date=config.start_date)

if config.n_covariates is not None:
Xtr_ = obs[: config.n_pre_intervention_timepoints, :]
Xte_ = obs[config.n_pre_intervention_timepoints :, :]

covariates = key.normal(
loc=config.covariate_means,
scale=config.covariate_stds,
size=(n_timepoints, n_units, config.n_covariates)
)

Ptr = covariates[:config.n_pre_intervention_timepoints, :, :]
Pte = covariates[config.n_pre_intervention_timepoints:, :, :]

Xtr = Xtr_ + Ptr @ config.covariate_coeffs
Xte = Xte_ + Pte @ config.covariate_coeffs

ytr = weights.weight_contr(Xtr)
yte = weights.weight_contr(Xte)

Rtr = weights.weight_contr(Ptr)
Rte = weights.weight_contr(Pte)

data = Dataset(
Xtr, Xte, ytr, yte, _start_date=config.start_date,
Ptr=Ptr, Pte=Pte, Rtr=Rtr, Rte=Rte
)
else:
Xtr = obs[: config.n_pre_intervention_timepoints, :]
Xte = obs[config.n_pre_intervention_timepoints :, :]

ytr = weights.weight_contr(Xtr)
yte = weights.weight_contr(Xte)

data = Dataset(Xtr, Xte, ytr, yte, _start_date=config.start_date)

return data
32 changes: 25 additions & 7 deletions src/causal_validation/weights.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,15 +11,23 @@
if tp.TYPE_CHECKING:
from causal_validation.config import WeightConfig

# Constants for array dimensions
_NDIM_2D = 2
_NDIM_3D = 3


@dataclass
class AbstractWeights(BaseObject):
name: str = "Abstract Weights"

def _get_weights(self, obs: Float[np.ndarray, "N D"]) -> Float[np.ndarray, "D 1"]:
def _get_weights(
self, obs: Float[np.ndarray, "N D"] | Float[np.ndarray, "N D K"]
) -> Float[np.ndarray, "D 1"]:
raise NotImplementedError("Please implement `_get_weights` in all subclasses.")

def get_weights(self, obs: Float[np.ndarray, "N D"]) -> Float[np.ndarray, "D 1"]:
def get_weights(
self, obs: Float[np.ndarray, "N D"] | Float[np.ndarray, "N D K"]
) -> Float[np.ndarray, "D 1"]:
weights = self._get_weights(obs)

np.testing.assert_almost_equal(
Expand All @@ -28,21 +36,31 @@ def get_weights(self, obs: Float[np.ndarray, "N D"]) -> Float[np.ndarray, "D 1"]
assert min(weights >= 0), "Weights should be non-negative"
return weights

def __call__(self, obs: Float[np.ndarray, "N D"]) -> Float[np.ndarray, "N 1"]:
return self.weight_obs(obs)
def __call__(
self, obs: Float[np.ndarray, "N D"] | Float[np.ndarray, "N D K"]
) -> Float[np.ndarray, "N 1"] | Float[np.ndarray, "N 1 K"]:
return self.weight_contr(obs)

def weight_obs(self, obs: Float[np.ndarray, "N D"]) -> Float[np.ndarray, "N 1"]:
def weight_contr(
self, obs: Float[np.ndarray, "N D"] | Float[np.ndarray, "N D K"]
) -> Float[np.ndarray, "N 1"] | Float[np.ndarray, "N 1 K"]:
weights = self.get_weights(obs)

weighted_obs = obs @ weights
if obs.ndim == _NDIM_2D:
weighted_obs = obs @ weights
elif obs.ndim == _NDIM_3D:
weighted_obs = np.einsum("n d k, d i -> n i k", obs, weights)

return weighted_obs


@dataclass
class UniformWeights(AbstractWeights):
name: str = "Uniform Weights"

def _get_weights(self, obs: Float[np.ndarray, "N D"]) -> Float[np.ndarray, "D 1"]:
def _get_weights(
self, obs: Float[np.ndarray, "N D"] | Float[np.ndarray, "N D K"]
) -> Float[np.ndarray, "D 1"]:
n_units = obs.shape[1]
return np.repeat(1.0 / n_units, repeats=n_units).reshape(-1, 1)

Expand Down
80 changes: 80 additions & 0 deletions tests/test_causal_validation/test_config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
import numpy as np
from hypothesis import given, strategies as st

from causal_validation.config import Config


@given(
n_units=st.integers(min_value=1, max_value=10),
n_pre=st.integers(min_value=1, max_value=20),
n_post=st.integers(min_value=1, max_value=20)
)
def test_config_basic_initialization(n_units, n_pre, n_post):
cfg = Config(
n_control_units=n_units,
n_pre_intervention_timepoints=n_pre,
n_post_intervention_timepoints=n_post
)
assert cfg.n_control_units == n_units
assert cfg.n_pre_intervention_timepoints == n_pre
assert cfg.n_post_intervention_timepoints == n_post
assert cfg.n_covariates is None
assert cfg.covariate_means is None
assert cfg.covariate_stds is None
assert cfg.covariate_coeffs is None


@given(
n_units=st.integers(min_value=1, max_value=5),
n_pre=st.integers(min_value=1, max_value=10),
n_post=st.integers(min_value=1, max_value=10),
n_covariates=st.integers(min_value=1, max_value=3),
seed=st.integers(min_value=1, max_value=1000)
)
def test_config_with_covariates_auto_generation(
n_units, n_pre, n_post, n_covariates, seed
):
cfg = Config(
n_control_units=n_units,
n_pre_intervention_timepoints=n_pre,
n_post_intervention_timepoints=n_post,
n_covariates=n_covariates,
seed=seed
)
assert cfg.n_covariates == n_covariates
assert cfg.covariate_means.shape == (n_units, n_covariates)
assert cfg.covariate_stds.shape == (n_units, n_covariates)
assert cfg.covariate_coeffs.shape == (n_covariates,)
assert np.all(cfg.covariate_stds >= 0)


@given(
n_units=st.integers(min_value=1, max_value=3),
n_covariates=st.integers(min_value=1, max_value=3)
)
def test_config_with_explicit_covariate_means(n_units, n_covariates):
means = np.random.random((n_units, n_covariates))
cfg = Config(
n_control_units=n_units,
n_pre_intervention_timepoints=10,
n_post_intervention_timepoints=5,
n_covariates=n_covariates,
covariate_means=means
)
np.testing.assert_array_equal(cfg.covariate_means, means)


@given(
n_units=st.integers(min_value=1, max_value=3),
n_covariates=st.integers(min_value=1, max_value=3)
)
def test_config_with_explicit_covariate_stds(n_units, n_covariates):
stds = np.random.random((n_units, n_covariates)) + 0.1
cfg = Config(
n_control_units=n_units,
n_pre_intervention_timepoints=10,
n_post_intervention_timepoints=5,
n_covariates=n_covariates,
covariate_stds=stds
)
np.testing.assert_array_equal(cfg.covariate_stds, stds)
Loading