29
29
MAX_STRING_LENGTH = 20
30
30
DEFAULT_SEED = 123
31
31
NUM_NON_CONTROL_COLS = 2
32
+ NUM_TREATED = 1
32
33
LARGE_N_POST = 5000
33
34
LARGE_N_PRE = 5000
34
35
@@ -109,15 +110,15 @@ def test_indicator(n_pre_treatment: int, n_post_treatment: int):
109
110
n_pre_treatment = st .integers (min_value = 1 , max_value = 50 ),
110
111
n_post_treatment = st .integers (min_value = 1 , max_value = 50 ),
111
112
)
112
- def test_to_df (n_control : int , n_pre_treatment : int , n_post_treatment : int ):
113
+ def test_to_df_no_cov (n_control : int , n_pre_treatment : int , n_post_treatment : int ):
113
114
constants = TestConstants (
114
115
N_POST_TREATMENT = n_post_treatment ,
115
116
N_PRE_TREATMENT = n_pre_treatment ,
116
117
N_CONTROL = n_control ,
117
118
)
118
119
data = simulate_data (0.0 , DEFAULT_SEED , constants = constants )
119
120
120
- df = data .to_df ()
121
+ df , _ = data .to_df ()
121
122
assert isinstance (df , pd .DataFrame )
122
123
assert df .shape == (
123
124
n_pre_treatment + n_post_treatment ,
@@ -133,6 +134,47 @@ def test_to_df(n_control: int, n_pre_treatment: int, n_post_treatment: int):
133
134
assert isinstance (index , DatetimeIndex )
134
135
assert index [0 ].strftime ("%Y-%m-%d" ) == data ._start_date .strftime ("%Y-%m-%d" )
135
136
137
+ @given (
138
+ n_control = st .integers (min_value = 1 , max_value = 50 ),
139
+ n_pre_treatment = st .integers (min_value = 1 , max_value = 50 ),
140
+ n_post_treatment = st .integers (min_value = 1 , max_value = 50 ),
141
+ n_covariates = st .integers (min_value = 1 , max_value = 50 ),
142
+ )
143
+ def test_to_df_with_cov (n_control : int ,
144
+ n_pre_treatment : int ,
145
+ n_post_treatment : int ,
146
+ n_covariates :int ):
147
+ constants = TestConstants (
148
+ N_POST_TREATMENT = n_post_treatment ,
149
+ N_PRE_TREATMENT = n_pre_treatment ,
150
+ N_CONTROL = n_control ,
151
+ N_COVARIATES = n_covariates ,
152
+ )
153
+ data = simulate_data (0.0 , DEFAULT_SEED , constants = constants )
154
+
155
+ df_outs , df_covs = data .to_df ()
156
+ assert isinstance (df_outs , pd .DataFrame )
157
+ assert df_outs .shape == (
158
+ n_pre_treatment + n_post_treatment ,
159
+ n_control + NUM_NON_CONTROL_COLS ,
160
+ )
161
+
162
+ assert isinstance (df_covs , pd .DataFrame )
163
+ assert df_covs .shape == (
164
+ n_pre_treatment + n_post_treatment ,
165
+ n_covariates * (n_control + NUM_TREATED )
166
+ + NUM_NON_CONTROL_COLS - NUM_TREATED ,
167
+ )
168
+
169
+ colnames = data ._get_columns ()
170
+ assert isinstance (colnames , list )
171
+ assert colnames [0 ] == "T"
172
+ assert len (colnames ) == n_control + 1
173
+
174
+ index = data .full_index
175
+ assert isinstance (index , DatetimeIndex )
176
+ assert index [0 ].strftime ("%Y-%m-%d" ) == data ._start_date .strftime ("%Y-%m-%d" )
177
+
136
178
137
179
@given (
138
180
n_control = st .integers (min_value = 2 , max_value = 50 ),
0 commit comments