1
+ from hypothesis import (
2
+ given ,
3
+ strategies as st ,
4
+ )
1
5
import numpy as np
2
- from hypothesis import given , strategies as st
3
6
4
7
from causal_validation .config import Config
5
8
6
9
7
10
@given (
8
11
n_units = st .integers (min_value = 1 , max_value = 10 ),
9
12
n_pre = st .integers (min_value = 1 , max_value = 20 ),
10
- n_post = st .integers (min_value = 1 , max_value = 20 )
13
+ n_post = st .integers (min_value = 1 , max_value = 20 ),
11
14
)
12
15
def test_config_basic_initialization (n_units , n_pre , n_post ):
13
16
cfg = Config (
14
17
n_control_units = n_units ,
15
18
n_pre_intervention_timepoints = n_pre ,
16
- n_post_intervention_timepoints = n_post
19
+ n_post_intervention_timepoints = n_post ,
17
20
)
18
21
assert cfg .n_control_units == n_units
19
22
assert cfg .n_pre_intervention_timepoints == n_pre
@@ -29,7 +32,7 @@ def test_config_basic_initialization(n_units, n_pre, n_post):
29
32
n_pre = st .integers (min_value = 1 , max_value = 10 ),
30
33
n_post = st .integers (min_value = 1 , max_value = 10 ),
31
34
n_covariates = st .integers (min_value = 1 , max_value = 3 ),
32
- seed = st .integers (min_value = 1 , max_value = 1000 )
35
+ seed = st .integers (min_value = 1 , max_value = 1000 ),
33
36
)
34
37
def test_config_with_covariates_auto_generation (
35
38
n_units , n_pre , n_post , n_covariates , seed
@@ -39,7 +42,7 @@ def test_config_with_covariates_auto_generation(
39
42
n_pre_intervention_timepoints = n_pre ,
40
43
n_post_intervention_timepoints = n_post ,
41
44
n_covariates = n_covariates ,
42
- seed = seed
45
+ seed = seed ,
43
46
)
44
47
assert cfg .n_covariates == n_covariates
45
48
assert cfg .covariate_means .shape == (n_units , n_covariates )
@@ -50,7 +53,7 @@ def test_config_with_covariates_auto_generation(
50
53
51
54
@given (
52
55
n_units = st .integers (min_value = 1 , max_value = 3 ),
53
- n_covariates = st .integers (min_value = 1 , max_value = 3 )
56
+ n_covariates = st .integers (min_value = 1 , max_value = 3 ),
54
57
)
55
58
def test_config_with_explicit_covariate_means (n_units , n_covariates ):
56
59
means = np .random .random ((n_units , n_covariates ))
@@ -59,14 +62,14 @@ def test_config_with_explicit_covariate_means(n_units, n_covariates):
59
62
n_pre_intervention_timepoints = 10 ,
60
63
n_post_intervention_timepoints = 5 ,
61
64
n_covariates = n_covariates ,
62
- covariate_means = means
65
+ covariate_means = means ,
63
66
)
64
67
np .testing .assert_array_equal (cfg .covariate_means , means )
65
68
66
69
67
70
@given (
68
71
n_units = st .integers (min_value = 1 , max_value = 3 ),
69
- n_covariates = st .integers (min_value = 1 , max_value = 3 )
72
+ n_covariates = st .integers (min_value = 1 , max_value = 3 ),
70
73
)
71
74
def test_config_with_explicit_covariate_stds (n_units , n_covariates ):
72
75
stds = np .random .random ((n_units , n_covariates )) + 0.1
@@ -75,6 +78,6 @@ def test_config_with_explicit_covariate_stds(n_units, n_covariates):
75
78
n_pre_intervention_timepoints = 10 ,
76
79
n_post_intervention_timepoints = 5 ,
77
80
n_covariates = n_covariates ,
78
- covariate_stds = stds
81
+ covariate_stds = stds ,
79
82
)
80
83
np .testing .assert_array_equal (cfg .covariate_stds , stds )
0 commit comments