@@ -1295,12 +1295,28 @@ def local_inplace_setsubtensor(fgraph, node):
12951295
12961296@node_rewriter ([AdvancedIncSubtensor1 ], inplace = True )
12971297def local_inplace_AdvancedIncSubtensor1 (fgraph , node ):
1298- if isinstance (node .op , AdvancedIncSubtensor1 ) and not node .op .inplace :
1299- new_op = node .op .clone_inplace ()
1300- new_node = new_op (* node .inputs )
1301- copy_stack_trace (node .outputs , new_node )
1302- return [new_node ]
1303- return False
1298+ if node .op .inplace :
1299+ return
1300+
1301+ x , y , idx = node .inputs
1302+ if fgraph .has_destroyers ([x ]):
1303+ # In this case we can't operate inplace, but if x is just an alloc of zeros
1304+ # We're better off duplicating it and then acting on it inplace.
1305+ if (
1306+ x .owner is not None
1307+ and isinstance (x .owner .op , Alloc )
1308+ and all (x .owner .inputs [0 ].type .broadcastable )
1309+ and isinstance (x .owner .inputs [0 ], Constant )
1310+ and x .owner .inputs [0 ].unique_value == 0
1311+ ):
1312+ x = x .owner .clone ().outputs [0 ]
1313+ else :
1314+ return None # Inplace isn't valid
1315+
1316+ new_op = node .op .clone_inplace ()
1317+ new_node = new_op (x , y , idx )
1318+ copy_stack_trace (node .outputs , new_node )
1319+ return [new_node ]
13041320
13051321
13061322compile .optdb .register (
0 commit comments