@@ -652,7 +652,7 @@ def elemwise_scalar_op_has_c_code(
652652 # `ancestors_bitset[C] & (node_bitset[A] | node_bitset[B]) != 0`
653653 nodes_bitflags = {node : 1 << i for i , node in enumerate (fgraph .toposort ())}
654654 # Root variables have `None` as owner, which we can handle with a bitset of 0
655- ancestors_bitsets = {None : 0 }
655+ ancestors_bitsets : dict [ Apply | None , int ] = {None : 0 }
656656 for node , node_bitflag in nodes_bitflags .items ():
657657 # The bitset of each node is the union of the bitsets of its inputs, plus its own bit flag
658658 ancestors_bitsets [node ] = reduce (
@@ -694,9 +694,13 @@ def elemwise_scalar_op_has_c_code(
694694 # For simplicity, we always want to visit ancestors before clients
695695 # For ancestors, we want to visit the later nodes first (those that have more dependencies)
696696 # whereas for clients we want to visit earlier nodes first (those that have fewer dependencies)
697- # To achieve this we use the bitflag as the sorting key (which encodes the topological order)
698- # and negate it for ancestors.
699- fuseables_nodes_queue = [(- starting_bitflag , starting_node )]
697+ # To achieve this we use the ancestors_bitset as the sorting key (which encodes the topological order)
698+ # and negate it for ancestors. We use the ancestors_bitset instead of the node bitflag because we
699+ # update the former when we find a fuseable subgraph, emulating the effect of recomputing the
700+ # topological order on the remaining nodes.
701+ fuseables_nodes_queue = [
702+ (- ancestors_bitsets [starting_node ], starting_bitflag , starting_node )
703+ ]
700704 heapify (fuseables_nodes_queue )
701705
702706 # We keep 3 bitsets during the exploration of a new subgraph:
@@ -715,10 +719,12 @@ def elemwise_scalar_op_has_c_code(
715719 unfuseable_clients_bitset = 0
716720
717721 while fuseables_nodes_queue :
718- node_bitflag , node = heappop (fuseables_nodes_queue )
719- is_ancestor = node_bitflag < 0
722+ node_ancestors_bitset , node_bitflag , node = heappop (
723+ fuseables_nodes_queue
724+ )
725+ is_ancestor = node_ancestors_bitset < 0
720726 if is_ancestor :
721- node_bitflag = - node_bitflag
727+ node_ancestors_bitset = - node_ancestors_bitset
722728
723729 if node_bitflag & subgraph_bitset :
724730 # Already part of the subgraph
@@ -728,7 +734,7 @@ def elemwise_scalar_op_has_c_code(
728734 if node_bitflag & unfuseable_ancestors_bitset :
729735 # An unfuseable ancestor of the subgraph depends on this node, can't fuse
730736 continue
731- elif ancestors_bitsets [ node ] & unfuseable_clients_bitset :
737+ elif node_ancestors_bitset & unfuseable_clients_bitset :
732738 # This node depends on an unfuseable client of the subgraph, can't fuse
733739 continue
734740
@@ -749,7 +755,11 @@ def elemwise_scalar_op_has_c_code(
749755 if node in fuseable_clients .get (ancestor_node , ()):
750756 heappush (
751757 fuseables_nodes_queue ,
752- (- ancestor_bitflag , ancestor_node ),
758+ (
759+ - ancestors_bitsets [ancestor_node ],
760+ ancestor_bitflag ,
761+ ancestor_node ,
762+ ),
753763 )
754764 else :
755765 # If the node is not in the ancestor's fuseable clients set, it's not fuseable with it,
@@ -764,16 +774,17 @@ def elemwise_scalar_op_has_c_code(
764774 if is_ancestor and (client_bitflag & subgraph_bitset ):
765775 continue
766776 if client in next_fuseable_clients :
767- heappush (fuseables_nodes_queue , (client_bitflag , client ))
777+ heappush (
778+ fuseables_nodes_queue ,
779+ (ancestors_bitsets [client ], client_bitflag , client ),
780+ )
768781 else :
769782 # If a client is not in the node's fuseable clients set, it's nto fuseable with it,
770783 # nor any of its clients. But we don't need to keep track of those as any downstream
771784 # client we may consider later will also depend on this unfuseable client and be rejected
772785 unfuseable_clients_bitset |= client_bitflag
773786
774- # Finished exploring this subgraph
775- all_subgraphs_bitset |= subgraph_bitset
776-
787+ # Finished expansion of subgraph
777788 if subgraph_bitset == starting_bitflag :
778789 # We ended were we started, no fusion possible
779790 continue
@@ -816,6 +827,18 @@ def elemwise_scalar_op_has_c_code(
816827 for out in subgraph_outputs :
817828 fuseable_clients .pop (out .owner , None )
818829
830+ # When we fuse multi-output subgraphs, we also need to fuse the dependencies of successor nodes.
831+ # Nodes that previously depended on a subset of the fused outputs, now depend on all of them.
832+ if len (subgraph_outputs ) > 1 :
833+ subgraph_and_ancestors = (
834+ subgraph_bitset | unfuseable_ancestors_bitset
835+ )
836+ ancestors_bitsets |= (
837+ (node , node_ancestors_bitset | subgraph_and_ancestors )
838+ for node , node_ancestors_bitset in ancestors_bitsets .items ()
839+ if node_ancestors_bitset & subgraph_bitset
840+ )
841+
819842 # Add new subgraph to sorted_subgraphs
820843 # Because we start from sink nodes in reverse topological order, most times new subgraphs
821844 # don't depend on previous subgraphs, so we can just append them at the end.
@@ -828,8 +851,7 @@ def elemwise_scalar_op_has_c_code(
828851 else :
829852 # But not here, so we need to find the right position for insertion.
830853 # We iterate through the previous subgraphs in topological order (reverse of the stored order).
831- # We exclude cumulatively exclude each subgraph_bitset and perform the same dependency check again.
832- # The (index + 1) of the firs iteration where the check passes is the correct insertion position.
854+ # We cumulatively exclude each subgraph_bitset and perform the same dependency check again, until it passes.
833855 remaining_subgraphs_bitset = all_subgraphs_bitset
834856 for index , (other_subgraph_bitset , _ ) in enumerate (
835857 reversed (sorted_subgraphs )
@@ -840,12 +862,20 @@ def elemwise_scalar_op_has_c_code(
840862 unfuseable_ancestors_bitset & remaining_subgraphs_bitset
841863 ):
842864 break # bingo
865+ else : # no-break
866+ raise RuntimeError (
867+ "Failed to find insertion point for fused subgraph"
868+ )
843869 sorted_subgraphs .insert (
844870 - (index + 1 ),
845871 (subgraph_bitset , (subgraph_inputs , subgraph_outputs )),
846872 )
847873
848- # yield from sorted_subgraphs, discarding the subgraph_bitset
874+ # Add subgraph to all_subgraphs_bitset
875+ all_subgraphs_bitset |= subgraph_bitset
876+
877+ # Finished exploring the whole graph
878+ # Yield from sorted_subgraphs, discarding the subgraph_bitset
849879 yield from (io for _ , io in sorted_subgraphs )
850880
851881 max_operands = elemwise_max_operands_fct (None )
0 commit comments