diff --git a/docs/examples/azcausal.ipynb b/docs/examples/azcausal.ipynb index b7f8a02..0cb7672 100644 --- a/docs/examples/azcausal.ipynb +++ b/docs/examples/azcausal.ipynb @@ -185,7 +185,7 @@ "panel = data.to_azcausal()\n", "model = SDID()\n", "result = model.fit(panel)\n", - "print(f\"Delta: {100*(TRUE_EFFECT - result.effect.percentage().value / 100): .2f}%\")\n", + "print(f\"Delta: {100 * (TRUE_EFFECT - result.effect.percentage().value / 100): .2f}%\")\n", "print(result.summary(title=\"Synthetic Data Experiment\"))" ] } diff --git a/src/causal_validation/config.py b/src/causal_validation/config.py index ca7928a..5134d17 100644 --- a/src/causal_validation/config.py +++ b/src/causal_validation/config.py @@ -3,10 +3,9 @@ field, ) import datetime as dt - -from jaxtyping import Float import typing as tp +from jaxtyping import Float import numpy as np from scipy.stats import halfcauchy @@ -49,6 +48,7 @@ class Config: weights_cfg (WeightConfig): Configuration for unit weights. Defaults to UniformWeights. """ + n_control_units: int n_pre_intervention_timepoints: int n_post_intervention_timepoints: int @@ -65,25 +65,27 @@ class Config: def __post_init__(self): self.rng = np.random.RandomState(self.seed) if self.covariate_means is not None: - assert self.covariate_means.shape == (self.n_control_units, - self.n_covariates) + assert self.covariate_means.shape == ( + self.n_control_units, + self.n_covariates, + ) if self.covariate_stds is not None: - assert self.covariate_stds.shape == (self.n_control_units, - self.n_covariates) + assert self.covariate_stds.shape == ( + self.n_control_units, + self.n_covariates, + ) if (self.n_covariates is not None) & (self.covariate_means is None): self.covariate_means = self.rng.normal( - loc=0.0, scale=5.0, size=(self.n_control_units, - self.n_covariates) + loc=0.0, scale=5.0, size=(self.n_control_units, self.n_covariates) ) if (self.n_covariates is not None) & (self.covariate_stds is None): - self.covariate_stds = ( - halfcauchy.rvs(scale=0.5, - size=(self.n_control_units, - self.n_covariates), - random_state=self.rng) + self.covariate_stds = halfcauchy.rvs( + scale=0.5, + size=(self.n_control_units, self.n_covariates), + random_state=self.rng, ) if (self.n_covariates is not None) & (self.covariate_coeffs is None): diff --git a/src/causal_validation/data.py b/src/causal_validation/data.py index e7922a8..73f0c6f 100644 --- a/src/causal_validation/data.py +++ b/src/causal_validation/data.py @@ -41,6 +41,7 @@ class Dataset: treated in pre-intervention period. _name: Optional name identifier for the dataset """ + Xtr: Float[np.ndarray, "N D"] Xte: Float[np.ndarray, "M D"] ytr: Float[np.ndarray, "N 1"] @@ -161,7 +162,8 @@ def pre_intervention_covariates( self, ) -> tp.Optional[ tp.Tuple[ - Float[np.ndarray, "N D F"], Float[np.ndarray, "N 1 F"], + Float[np.ndarray, "N D F"], + Float[np.ndarray, "N 1 F"], ] ]: if self.has_covariates: @@ -174,7 +176,8 @@ def post_intervention_covariates( self, ) -> tp.Optional[ tp.Tuple[ - Float[np.ndarray, "M D F"], Float[np.ndarray, "M 1 F"], + Float[np.ndarray, "M D F"], + Float[np.ndarray, "M 1 F"], ] ]: if self.has_covariates: @@ -220,8 +223,18 @@ def inflate(self, inflation_vals: Float[np.ndarray, "M 1"]) -> Dataset: Xte, yte = [deepcopy(i) for i in self.post_intervention_obs] inflated_yte = yte * inflation_vals return Dataset( - Xtr, Xte, ytr, inflated_yte, self._start_date, - self.Ptr, self.Pte, self.Rtr, self.Rte, yte, self.synthetic, self._name + Xtr, + Xte, + ytr, + inflated_yte, + self._start_date, + self.Ptr, + self.Pte, + self.Rtr, + self.Rte, + yte, + self.synthetic, + self._name, ) def __eq__(self, other: Dataset) -> bool: @@ -326,7 +339,16 @@ def reassign_treatment( Xtr = data.Xtr Xte = data.Xte return Dataset( - Xtr, Xte, ytr, yte, data._start_date, - data.Ptr, data.Pte, data.Rtr, data.Rte, - data.counterfactual, data.synthetic, data._name + Xtr, + Xte, + ytr, + yte, + data._start_date, + data.Ptr, + data.Pte, + data.Rtr, + data.Rte, + data.counterfactual, + data.synthetic, + data._name, ) diff --git a/src/causal_validation/simulate.py b/src/causal_validation/simulate.py index e85e80f..b1f85f6 100644 --- a/src/causal_validation/simulate.py +++ b/src/causal_validation/simulate.py @@ -37,11 +37,11 @@ def _simulate_base_obs( covariates = key.normal( loc=config.covariate_means, scale=config.covariate_stds, - size=(n_timepoints, n_units, config.n_covariates) + size=(n_timepoints, n_units, config.n_covariates), ) - Ptr = covariates[:config.n_pre_intervention_timepoints, :, :] - Pte = covariates[config.n_pre_intervention_timepoints:, :, :] + Ptr = covariates[: config.n_pre_intervention_timepoints, :, :] + Pte = covariates[config.n_pre_intervention_timepoints :, :, :] Xtr = Xtr_ + Ptr @ config.covariate_coeffs Xte = Xte_ + Pte @ config.covariate_coeffs @@ -53,8 +53,15 @@ def _simulate_base_obs( Rte = weights.weight_contr(Pte) data = Dataset( - Xtr, Xte, ytr, yte, _start_date=config.start_date, - Ptr=Ptr, Pte=Pte, Rtr=Rtr, Rte=Rte + Xtr, + Xte, + ytr, + yte, + _start_date=config.start_date, + Ptr=Ptr, + Pte=Pte, + Rtr=Rtr, + Rte=Rte, ) else: Xtr = obs[: config.n_pre_intervention_timepoints, :] diff --git a/src/causal_validation/transforms/__init__.py b/src/causal_validation/transforms/__init__.py index 44ac584..eb3895a 100644 --- a/src/causal_validation/transforms/__init__.py +++ b/src/causal_validation/transforms/__init__.py @@ -1,4 +1,7 @@ -from causal_validation.transforms.noise import Noise, CovariateNoise +from causal_validation.transforms.noise import ( + CovariateNoise, + Noise, +) from causal_validation.transforms.periodic import Periodic from causal_validation.transforms.trends import Trend diff --git a/src/causal_validation/transforms/base.py b/src/causal_validation/transforms/base.py index c4258fe..0c98762 100644 --- a/src/causal_validation/transforms/base.py +++ b/src/causal_validation/transforms/base.py @@ -75,9 +75,17 @@ def apply_values( Xte = Xte + post_intervention_vals[:, 1:] yte = yte + post_intervention_vals[:, :1] return Dataset( - Xtr, Xte, ytr, yte, data._start_date, - data.Ptr, data.Pte, data.Rtr, data.Rte, - data.counterfactual, data.synthetic + Xtr, + Xte, + ytr, + yte, + data._start_date, + data.Ptr, + data.Pte, + data.Rtr, + data.Rte, + data.counterfactual, + data.synthetic, ) @@ -96,11 +104,20 @@ def apply_values( Xte = Xte * post_intervention_vals yte = yte * post_intervention_vals return Dataset( - Xtr, Xte, ytr, yte, data._start_date, - data.Ptr, data.Pte, data.Rtr, data.Rte, - data.counterfactual, data.synthetic + Xtr, + Xte, + ytr, + yte, + data._start_date, + data.Ptr, + data.Pte, + data.Rtr, + data.Rte, + data.counterfactual, + data.synthetic, ) + @dataclass(kw_only=True) class AdditiveCovariateTransform(AbstractTransform): def apply_values( @@ -116,7 +133,15 @@ def apply_values( Pte = Pte + post_intervention_vals[:, 1:, :] Rte = Rte + post_intervention_vals[:, :1, :] return Dataset( - data.Xtr, data.Xte, data.ytr, data.yte, - data._start_date, Ptr, Pte, Rtr, Rte, - data.counterfactual, data.synthetic + data.Xtr, + data.Xte, + data.ytr, + data.yte, + data._start_date, + Ptr, + Pte, + Rtr, + Rte, + data.counterfactual, + data.synthetic, ) diff --git a/src/causal_validation/transforms/noise.py b/src/causal_validation/transforms/noise.py index e251225..70c436c 100644 --- a/src/causal_validation/transforms/noise.py +++ b/src/causal_validation/transforms/noise.py @@ -1,4 +1,7 @@ -from dataclasses import dataclass, field +from dataclasses import ( + dataclass, + field, +) from typing import Tuple from jaxtyping import Float @@ -7,12 +10,12 @@ from causal_validation.data import Dataset from causal_validation.transforms.base import ( + AdditiveCovariateTransform, AdditiveOutputTransform, - AdditiveCovariateTransform ) from causal_validation.transforms.parameter import ( + CovariateNoiseParameter, TimeVaryingParameter, - CovariateNoiseParameter ) @@ -53,8 +56,8 @@ class CovariateNoise(AdditiveCovariateTransform): def get_values(self, data: Dataset) -> Float[np.ndarray, "N D"]: noise = self.noise_dist.get_value( - n_units=data.n_units+1, + n_units=data.n_units + 1, n_timepoints=data.n_timepoints, - n_covariates=data.n_covariates + n_covariates=data.n_covariates, ) return noise diff --git a/src/causal_validation/transforms/parameter.py b/src/causal_validation/transforms/parameter.py index 5806b80..e0fac1e 100644 --- a/src/causal_validation/transforms/parameter.py +++ b/src/causal_validation/transforms/parameter.py @@ -57,8 +57,7 @@ def get_value( self, n_units: int, n_timepoints: int, n_covariates: int ) -> Float[np.ndarray, "{n_timepoints} {n_units} {n_covariates}"]: covariate_noise = self.sampling_dist.rvs( - size=(n_timepoints, n_units, n_covariates), - random_state=self.random_state + size=(n_timepoints, n_units, n_covariates), random_state=self.random_state ) return covariate_noise diff --git a/tests/test_causal_validation/test_config.py b/tests/test_causal_validation/test_config.py index afca1b4..77175f9 100644 --- a/tests/test_causal_validation/test_config.py +++ b/tests/test_causal_validation/test_config.py @@ -1,5 +1,8 @@ +from hypothesis import ( + given, + strategies as st, +) import numpy as np -from hypothesis import given, strategies as st from causal_validation.config import Config @@ -7,13 +10,13 @@ @given( n_units=st.integers(min_value=1, max_value=10), n_pre=st.integers(min_value=1, max_value=20), - n_post=st.integers(min_value=1, max_value=20) + n_post=st.integers(min_value=1, max_value=20), ) def test_config_basic_initialization(n_units, n_pre, n_post): cfg = Config( n_control_units=n_units, n_pre_intervention_timepoints=n_pre, - n_post_intervention_timepoints=n_post + n_post_intervention_timepoints=n_post, ) assert cfg.n_control_units == n_units assert cfg.n_pre_intervention_timepoints == n_pre @@ -29,7 +32,7 @@ def test_config_basic_initialization(n_units, n_pre, n_post): n_pre=st.integers(min_value=1, max_value=10), n_post=st.integers(min_value=1, max_value=10), n_covariates=st.integers(min_value=1, max_value=3), - seed=st.integers(min_value=1, max_value=1000) + seed=st.integers(min_value=1, max_value=1000), ) def test_config_with_covariates_auto_generation( n_units, n_pre, n_post, n_covariates, seed @@ -39,7 +42,7 @@ def test_config_with_covariates_auto_generation( n_pre_intervention_timepoints=n_pre, n_post_intervention_timepoints=n_post, n_covariates=n_covariates, - seed=seed + seed=seed, ) assert cfg.n_covariates == n_covariates assert cfg.covariate_means.shape == (n_units, n_covariates) @@ -50,7 +53,7 @@ def test_config_with_covariates_auto_generation( @given( n_units=st.integers(min_value=1, max_value=3), - n_covariates=st.integers(min_value=1, max_value=3) + n_covariates=st.integers(min_value=1, max_value=3), ) def test_config_with_explicit_covariate_means(n_units, n_covariates): means = np.random.random((n_units, n_covariates)) @@ -59,14 +62,14 @@ def test_config_with_explicit_covariate_means(n_units, n_covariates): n_pre_intervention_timepoints=10, n_post_intervention_timepoints=5, n_covariates=n_covariates, - covariate_means=means + covariate_means=means, ) np.testing.assert_array_equal(cfg.covariate_means, means) @given( n_units=st.integers(min_value=1, max_value=3), - n_covariates=st.integers(min_value=1, max_value=3) + n_covariates=st.integers(min_value=1, max_value=3), ) def test_config_with_explicit_covariate_stds(n_units, n_covariates): stds = np.random.random((n_units, n_covariates)) + 0.1 @@ -75,6 +78,6 @@ def test_config_with_explicit_covariate_stds(n_units, n_covariates): n_pre_intervention_timepoints=10, n_post_intervention_timepoints=5, n_covariates=n_covariates, - covariate_stds=stds + covariate_stds=stds, ) np.testing.assert_array_equal(cfg.covariate_stds, stds) diff --git a/tests/test_causal_validation/test_data.py b/tests/test_causal_validation/test_data.py index 28b4638..51796fc 100644 --- a/tests/test_causal_validation/test_data.py +++ b/tests/test_causal_validation/test_data.py @@ -1,4 +1,5 @@ from copy import deepcopy +import datetime as dt import string import typing as tp @@ -23,7 +24,6 @@ simulate_data, ) from causal_validation.types import InterventionTypes -import datetime as dt MIN_STRING_LENGTH = 1 MAX_STRING_LENGTH = 20 @@ -134,16 +134,16 @@ def test_to_df_no_cov(n_control: int, n_pre_treatment: int, n_post_treatment: in assert isinstance(index, DatetimeIndex) assert index[0].strftime("%Y-%m-%d") == data._start_date.strftime("%Y-%m-%d") + @given( n_control=st.integers(min_value=1, max_value=50), n_pre_treatment=st.integers(min_value=1, max_value=50), n_post_treatment=st.integers(min_value=1, max_value=50), n_covariates=st.integers(min_value=1, max_value=50), ) -def test_to_df_with_cov(n_control: int, - n_pre_treatment: int, - n_post_treatment: int, - n_covariates:int): +def test_to_df_with_cov( + n_control: int, n_pre_treatment: int, n_post_treatment: int, n_covariates: int +): constants = TestConstants( N_POST_TREATMENT=n_post_treatment, N_PRE_TREATMENT=n_pre_treatment, @@ -162,8 +162,7 @@ def test_to_df_with_cov(n_control: int, assert isinstance(df_covs, pd.DataFrame) assert df_covs.shape == ( n_pre_treatment + n_post_treatment, - n_covariates * (n_control + NUM_TREATED) - + NUM_NON_CONTROL_COLS - NUM_TREATED, + n_covariates * (n_control + NUM_TREATED) + NUM_NON_CONTROL_COLS - NUM_TREATED, ) colnames = data._get_columns() @@ -355,9 +354,18 @@ def test_counterfactual_synthetic_attributes(n_pre: int, n_post: int, n_control: synthetic_vals = np.random.randn(n_post, 1) data_with_attrs = Dataset( - data.Xtr, data.Xte, data.ytr, data.yte, data._start_date, - data.Ptr, data.Pte, data.Rtr, data.Rte, - counterfactual_vals, synthetic_vals, "test_dataset" + data.Xtr, + data.Xte, + data.ytr, + data.yte, + data._start_date, + data.Ptr, + data.Pte, + data.Rtr, + data.Rte, + counterfactual_vals, + synthetic_vals, + "test_dataset", ) np.testing.assert_array_equal(data_with_attrs.counterfactual, counterfactual_vals) @@ -416,6 +424,7 @@ def test_control_treated_properties(n_pre: int, n_post: int, n_control: int): np.testing.assert_array_equal(treated_units, expected_treated) assert treated_units.shape == (n_pre + n_post, 1) + @given( seeds=st.lists( elements=st.integers(min_value=1, max_value=1000), min_size=1, max_size=10 @@ -452,6 +461,7 @@ def test_dataset_container(seeds: tp.List[int], to_name: bool): assert k == f"Dataset {idx}" assert v == datasets[idx] + @given( n_pre=st.integers(min_value=10, max_value=100), n_post=st.integers(min_value=10, max_value=100), @@ -522,4 +532,3 @@ def test_covariate_properties_with_covariates( post_cov = data.post_intervention_covariates assert post_cov == (Pte, Rte) - diff --git a/tests/test_causal_validation/test_simulate.py b/tests/test_causal_validation/test_simulate.py index 52213dc..2ef23c0 100644 --- a/tests/test_causal_validation/test_simulate.py +++ b/tests/test_causal_validation/test_simulate.py @@ -1,5 +1,8 @@ +from hypothesis import ( + given, + strategies as st, +) import numpy as np -from hypothesis import given, strategies as st from causal_validation.config import Config from causal_validation.simulate import simulate @@ -9,14 +12,14 @@ n_units=st.integers(min_value=1, max_value=5), n_pre=st.integers(min_value=1, max_value=10), n_post=st.integers(min_value=1, max_value=10), - seed=st.integers(min_value=1, max_value=1000) + seed=st.integers(min_value=1, max_value=1000), ) def test_simulate_basic(n_units, n_pre, n_post, seed): cfg = Config( n_control_units=n_units, n_pre_intervention_timepoints=n_pre, n_post_intervention_timepoints=n_post, - seed=seed + seed=seed, ) data = simulate(cfg) @@ -32,7 +35,7 @@ def test_simulate_basic(n_units, n_pre, n_post, seed): n_pre=st.integers(min_value=1, max_value=10), n_post=st.integers(min_value=1, max_value=10), n_covariates=st.integers(min_value=1, max_value=3), - seed=st.integers(min_value=1, max_value=1000) + seed=st.integers(min_value=1, max_value=1000), ) def test_simulate_with_covariates(n_units, n_pre, n_post, n_covariates, seed): cfg = Config( @@ -40,7 +43,7 @@ def test_simulate_with_covariates(n_units, n_pre, n_post, n_covariates, seed): n_pre_intervention_timepoints=n_pre, n_post_intervention_timepoints=n_post, n_covariates=n_covariates, - seed=seed + seed=seed, ) data = simulate(cfg) @@ -59,14 +62,14 @@ def test_simulate_with_covariates(n_units, n_pre, n_post, n_covariates, seed): n_units=st.integers(min_value=1, max_value=5), n_pre=st.integers(min_value=1, max_value=10), n_post=st.integers(min_value=1, max_value=10), - seed=st.integers(min_value=1, max_value=1000) + seed=st.integers(min_value=1, max_value=1000), ) def test_simulate_reproducible(n_units, n_pre, n_post, seed): cfg1 = Config( n_control_units=n_units, n_pre_intervention_timepoints=n_pre, n_post_intervention_timepoints=n_post, - seed=seed + seed=seed, ) data1 = simulate(cfg1) @@ -74,7 +77,7 @@ def test_simulate_reproducible(n_units, n_pre, n_post, seed): n_control_units=n_units, n_pre_intervention_timepoints=n_pre, n_post_intervention_timepoints=n_post, - seed=seed + seed=seed, ) data2 = simulate(cfg2) @@ -88,7 +91,7 @@ def test_simulate_reproducible(n_units, n_pre, n_post, seed): n_units=st.integers(min_value=1, max_value=3), n_pre=st.integers(min_value=3, max_value=10), n_post=st.integers(min_value=1, max_value=5), - seed=st.integers(min_value=1, max_value=1000) + seed=st.integers(min_value=1, max_value=1000), ) def test_simulate_covariate_effects(n_units, n_pre, n_post, seed): cfg = Config( @@ -97,7 +100,7 @@ def test_simulate_covariate_effects(n_units, n_pre, n_post, seed): n_post_intervention_timepoints=n_post, n_covariates=1, covariate_coeffs=np.array([10.0]), - seed=seed + seed=seed, ) data_with_cov = simulate(cfg) @@ -105,7 +108,7 @@ def test_simulate_covariate_effects(n_units, n_pre, n_post, seed): n_control_units=n_units, n_pre_intervention_timepoints=n_pre, n_post_intervention_timepoints=n_post, - seed=seed + seed=seed, ) data_no_cov = simulate(cfg_no_cov) @@ -117,7 +120,7 @@ def test_simulate_covariate_effects(n_units, n_pre, n_post, seed): n_units=st.integers(min_value=1, max_value=3), n_pre=st.integers(min_value=3, max_value=10), n_post=st.integers(min_value=1, max_value=5), - seed=st.integers(min_value=1, max_value=1000) + seed=st.integers(min_value=1, max_value=1000), ) def test_simulate_exact_covariate_effects(n_units, n_pre, n_post, seed): cfg = Config( @@ -125,10 +128,10 @@ def test_simulate_exact_covariate_effects(n_units, n_pre, n_post, seed): n_pre_intervention_timepoints=n_pre, n_post_intervention_timepoints=n_post, n_covariates=2, - covariate_means=np.ones((n_units,2)), - covariate_stds= 1e-12*np.ones((n_units,2)), + covariate_means=np.ones((n_units, 2)), + covariate_stds=1e-12 * np.ones((n_units, 2)), covariate_coeffs=np.array([10.0, 5.0]), - seed=seed + seed=seed, ) data_with_cov = simulate(cfg) @@ -136,10 +139,9 @@ def test_simulate_exact_covariate_effects(n_units, n_pre, n_post, seed): n_control_units=n_units, n_pre_intervention_timepoints=n_pre, n_post_intervention_timepoints=n_post, - seed=seed + seed=seed, ) data_no_cov = simulate(cfg_no_cov) - assert np.allclose(data_with_cov.Xtr-15, data_no_cov.Xtr) - assert np.allclose(data_with_cov.Xte-15, data_no_cov.Xte) - + assert np.allclose(data_with_cov.Xtr - 15, data_no_cov.Xtr) + assert np.allclose(data_with_cov.Xte - 15, data_no_cov.Xte) diff --git a/tests/test_causal_validation/test_transforms/test_noise.py b/tests/test_causal_validation/test_transforms/test_noise.py index 8ffef8c..eb05731 100644 --- a/tests/test_causal_validation/test_transforms/test_noise.py +++ b/tests/test_causal_validation/test_transforms/test_noise.py @@ -130,6 +130,7 @@ def test_perturbation_impact( assert np.max(diff_te_list[0]) < np.max(diff_te_list[2]) assert np.min(diff_te_list[0]) < np.min(diff_te_list[2]) + # Covariate Noise Test def test_cov_slot_type(): noise_transform = CovariateNoise() @@ -138,7 +139,7 @@ def test_cov_slot_type(): @given(n_covariates=st.integers(min_value=1, max_value=50)) @settings(max_examples=5) -def test_output_covariate_transform(n_covariates:int): +def test_output_covariate_transform(n_covariates: int): CONSTANTS2 = TestConstants(N_COVARIATES=n_covariates) base_data = simulate_data(GLOBAL_MEAN, DEFAULT_SEED, CONSTANTS2) @@ -173,7 +174,7 @@ def test_output_covariate_transform(n_covariates:int): @given(n_covariates=st.integers(min_value=1, max_value=50)) @settings(max_examples=5) -def test_cov_composite_transform(n_covariates:int): +def test_cov_composite_transform(n_covariates: int): CONSTANTS2 = TestConstants(N_COVARIATES=n_covariates) base_data = simulate_data(GLOBAL_MEAN, DEFAULT_SEED, CONSTANTS2) @@ -206,7 +207,7 @@ def test_cov_perturbation_impact( loc_small: float, scale_large: float, scale_small: float, - n_covariates:int + n_covariates: int, ): CONSTANTS2 = TestConstants(N_COVARIATES=n_covariates) base_data = simulate_data(GLOBAL_MEAN, DEFAULT_SEED, CONSTANTS2) diff --git a/tests/test_causal_validation/test_weights.py b/tests/test_causal_validation/test_weights.py index 543d711..b409ba3 100644 --- a/tests/test_causal_validation/test_weights.py +++ b/tests/test_causal_validation/test_weights.py @@ -43,8 +43,9 @@ def test_weight_contr_3d(n_units: int, n_time: int, n_covariates: int): weighted_covs = weights.weight_contr(covariates) assert weighted_covs.shape == (n_time, 1, n_covariates) - expected = np.einsum("n d k, d i -> n i k", - covariates, weights.get_weights(covariates)) + expected = np.einsum( + "n d k, d i -> n i k", covariates, weights.get_weights(covariates) + ) np.testing.assert_almost_equal(weighted_covs, expected, decimal=6)