@@ -652,12 +652,12 @@ def elemwise_scalar_op_has_c_code(
652652 # `ancestors_bitset[C] & (node_bitset[A] | node_bitset[B]) != 0`
653653 nodes_bitflags = {node : 1 << i for i , node in enumerate (fgraph .toposort ())}
654654 # Root variables have `None` as owner, which we can handle with a bitset of 0
655- ancestors_bitset = {None : 0 }
655+ ancestors_bitsets = {None : 0 }
656656 for node , node_bitflag in nodes_bitflags .items ():
657657 # The bitset of each node is the union of the bitsets of its inputs, plus its own bit flag
658- ancestors_bitset [node ] = reduce (
658+ ancestors_bitsets [node ] = reduce (
659659 or_ ,
660- (ancestors_bitset [inp .owner ] for inp in node .inputs ),
660+ (ancestors_bitsets [inp .owner ] for inp in node .inputs ),
661661 node_bitflag ,
662662 )
663663 # Handle root and leaf nodes gracefully
@@ -666,10 +666,12 @@ def elemwise_scalar_op_has_c_code(
666666 nodes_bitflags [None ] = 0
667667 # Nothing ever depends on the special Output nodes, so just use a new bit for all of them
668668 out_bitflag = 1 << len (nodes_bitflags )
669- for out in fg .outputs :
670- for client , _ in fg_clients [out ]:
671- if isinstance (client .op , Output ):
672- nodes_bitflags [client ] = out_bitflag
669+ nodes_bitflags |= (
670+ (client , out_bitflag )
671+ for out in fg .outputs
672+ for client , _ in fg_clients [out ]
673+ if isinstance (client .op , Output )
674+ )
673675
674676 # Start main loop to find collection of fuseable subgraphs
675677 # We store the collection in `sorted_subgraphs`, in reverse topological order
@@ -726,7 +728,7 @@ def elemwise_scalar_op_has_c_code(
726728 if node_bitflag & unfuseable_ancestors_bitset :
727729 # An unfuseable ancestor of the subgraph depends on this node, can't fuse
728730 continue
729- elif ancestors_bitset [node ] & unfuseable_clients_bitset :
731+ elif ancestors_bitsets [node ] & unfuseable_clients_bitset :
730732 # This node depends on an unfuseable client of the subgraph, can't fuse
731733 continue
732734
@@ -742,7 +744,7 @@ def elemwise_scalar_op_has_c_code(
742744 for inp in node .inputs :
743745 ancestor_node = inp .owner
744746 ancestor_bitflag = nodes_bitflags [ancestor_node ]
745- if ancestor_bitflag & subgraph_bitset :
747+ if ( not is_ancestor ) and ( ancestor_bitflag & subgraph_bitset ) :
746748 continue
747749 if node in fuseable_clients .get (ancestor_node , ()):
748750 heappush (
@@ -752,14 +754,14 @@ def elemwise_scalar_op_has_c_code(
752754 else :
753755 # If the node is not in the ancestor's fuseable clients set, it's not fuseable with it,
754756 # nor with any of the ancestor's ancestors
755- unfuseable_ancestors_bitset |= ancestors_bitset [
757+ unfuseable_ancestors_bitset |= ancestors_bitsets [
756758 ancestor_node
757759 ]
758760
759761 next_fuseable_clients = fuseable_clients .get (node , ())
760762 for client , _ in fg_clients [node .outputs [0 ]]:
761763 client_bitflag = nodes_bitflags [client ]
762- if client_bitflag & subgraph_bitset :
764+ if is_ancestor and ( client_bitflag & subgraph_bitset ) :
763765 continue
764766 if client in next_fuseable_clients :
765767 heappush (fuseables_nodes_queue , (client_bitflag , client ))
0 commit comments