@@ -220,6 +220,42 @@ def test_multi_input_ablation_with_mask(self) -> None:
220220            perturbations_per_eval = (1 , 2 , 3 ),
221221        )
222222
223+     def  test_multi_input_ablation_with_int_input_tensor_and_float_baseline (
224+         self ,
225+     ) ->  None :
226+         def  sum_forward (* inps : torch .Tensor ) ->  torch .Tensor :
227+             flattened  =  [torch .flatten (inp , start_dim = 1 ) for  inp  in  inps ]
228+             return  torch .cat (flattened , dim = 1 ).sum (1 )
229+ 
230+         ablation_algo  =  FeatureAblation (sum_forward )
231+         inp1  =  torch .tensor ([[0 , 1 ], [3 , 4 ]])
232+         inp2  =  torch .tensor (
233+             [
234+                 [[0.1 , 0.2 ], [0.3 , 0.2 ]],
235+                 [[0.4 , 0.5 ], [0.3 , 0.2 ]],
236+             ]
237+         )
238+         inp3  =  torch .tensor ([[0 ], [1 ]])
239+ 
240+         expected  =  (
241+             torch .tensor ([[- 0.2 , 0.8 ], [2.8 , 3.8 ]]),
242+             torch .tensor (
243+                 [
244+                     [[- 3.0 , - 2.9 ], [- 2.8 , - 2.9 ]],
245+                     [[- 2.7 , - 2.6 ], [- 2.8 , - 2.9 ]],
246+                 ]
247+             ),
248+             torch .tensor ([[- 0.4 ], [0.6 ]]),
249+         )
250+         self ._ablation_test_assert (
251+             ablation_algo ,
252+             (inp1 , inp2 , inp3 ),
253+             expected ,
254+             target = None ,
255+             baselines = (0.2 , 3.1 , 0.4 ),
256+             test_enable_cross_tensor_attribution = [False , True ],
257+         )
258+ 
223259    def  test_multi_input_ablation_with_mask_weighted (self ) ->  None :
224260        ablation_algo  =  FeatureAblation (BasicModel_MultiLayer_MultiInput ())
225261        ablation_algo .use_weights  =  True 
0 commit comments