@@ -48,7 +48,7 @@ def boolean_indexing_set_or_inc(fgraph, node):
4848
4949@node_rewriter ([Sum ])
5050def boolean_indexing_sum (fgraph , node ):
51- """Replace the sum of `AdvancedSubtensor` with boolean indexing.
51+ """Replace the sum of `AdvancedSubtensor` with exclusively boolean indexing.
5252
5353 JAX cannot JIT-compile functions that use boolean indexing, but can compile
5454 those expressions that can be re-expressed using `jax.numpy.where`. This
@@ -61,21 +61,30 @@ def boolean_indexing_sum(fgraph, node):
6161 if not isinstance (operand , TensorVariable ):
6262 return
6363
64+ # If it's not a scalar reduction, it couldn't have been a pure boolean mask
65+ if node .outputs [0 ].ndim != 0 :
66+ return
67+
6468 if operand .owner is None :
6569 return
6670
6771 if not isinstance (operand .owner .op , AdvancedSubtensor ):
6872 return
6973
70- x = operand .owner .inputs [0 ]
71- cond = operand .owner .inputs [1 ]
74+ # Get out if AdvancedSubtensor has more than a single indexing operation
75+ if len (operand .owner .inputs ) > 2 :
76+ return
77+
78+ [x , cond ] = operand .owner .inputs
7279
7380 if not isinstance (cond , TensorVariable ):
7481 return
7582
7683 if not cond .type .dtype == "bool" :
7784 return
7885
86+ # Output must be a scalar, since pure boolean indexing returns a vector
87+ # No need to worry about axis
7988 out = at .sum (at .where (cond , x , 0 ))
8089 return out .owner .outputs
8190
0 commit comments