     Shape,
     Shape_i,
     SpecifyShape,
-    Unbroadcast,
     specify_shape,
-    unbroadcast,
 )
 from pytensor.tensor.subtensor import Subtensor, get_idx_list
 from pytensor.tensor.type import TensorType, discrete_dtypes, integer_dtypes
@@ -1296,78 +1294,3 @@ def local_track_shape_i(fgraph, node): |
     # structure.
     replacement = shape_feature.scheduled[node]
     return [shape_feature.shape_of[replacement][node.op.i]]
-
-
-@register_useless
-@register_canonicalize
-@register_specialize
-@node_rewriter([Unbroadcast])
-def local_useless_unbroadcast(fgraph, node):
-    """Remove `Unbroadcast` if it does not actually change the broadcasting pattern."""
-    if isinstance(node.op, Unbroadcast):
-        x = node.inputs[0]
-        if x.type.ndim == node.outputs[0].type.ndim and all(
-            s1 == s2
-            for s1, s2 in zip(x.type.shape, node.outputs[0].type.shape, strict=True)
-            if s1 == 1 or s2 == 1
-        ):
-            # No broadcastable flag was modified
-            # No need to copy over stack trace,
-            # because x should already have a stack trace.
-            return [x]
-        else:
-            # Keep the flags that modify something
-            new_axes = tuple(ax for ax in node.op.axes if x.type.shape[ax] == 1)
-            if new_axes == node.op.axes:
-                # All flags are useful
-                return None
-            else:
-                r = unbroadcast(x, *new_axes)
-                # Copy over stacktrace from previous output
-                copy_stack_trace(node.outputs, r)
-                return [r]
-
-
-@register_canonicalize
-@register_specialize
-@node_rewriter([Unbroadcast])
-def local_unbroadcast_lift(fgraph, node):
-    """
-    Lifts `Unbroadcast` through unary Elemwise operations,
-    and merges consecutive `Unbroadcast`s.
-
-    Unbroadcast(Elemwise(x)) => Elemwise(Unbroadcast(x))
-    Unbroadcast(Unbroadcast(x)) => Unbroadcast(x)
-
-    TODO: Implement equivalent Elemwise lift for SpecifyShape
-    """
-    op = node.op
-    if not isinstance(op, Unbroadcast):
-        return False
-
-    inp = node.inputs[0]
-    inode = inp.owner
-    if inode and isinstance(inode.op, Elemwise) and len(inode.inputs) == 1:
-        if len(fgraph.clients.get(inp, ())) == 1:
-            unbroadcasted = unbroadcast(inode.inputs[0], *op.axes)
-            copy_stack_trace(node.outputs, unbroadcasted)
-
-            rval = inode.op.make_node(unbroadcasted).outputs
-
-            # Copy over stacktrace from previous output (after unbroadcasting)
-            # and input (after elemwise operation) to new output, because an
-            # error in the new graph could have been caused by either of the
-            # two ops.
-            copy_stack_trace(node.outputs + node.inputs, rval)
-            return rval
-
-    if inode and isinstance(inode.op, Unbroadcast):
-        # Merge axis of each unbroadcast
-        axis = tuple(set(inode.op.axes).union(set(op.axes)))
-        iinput = inode.inputs[0]
-        rval = [unbroadcast(iinput, *axis)]
-        # Copy over stacktrace from previous output (after second unbroadcasting)
-        # and from previous input (after first unbroadcasting) because an error in
-        # the new graph could have been caused by either of the two Unbroadcast ops.
-        copy_stack_trace(node.outputs + node.inputs, rval)
-        return rval
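
For context, a minimal sketch of the graph transformation that the removed `local_unbroadcast_lift` rewrite performed, assuming an older pytensor release that still exports `unbroadcast` from `pytensor.tensor.shape` (this commit deletes it, so the import fails on newer versions):

```python
# Illustrative sketch only; assumes a pre-removal pytensor with
# pytensor.tensor.shape.unbroadcast still available.
import pytensor.tensor as pt
from pytensor.tensor.shape import unbroadcast

x = pt.row("x")                      # static shape (1, None): axis 0 is broadcastable
before = unbroadcast(pt.exp(x), 0)   # Unbroadcast(Elemwise(x))

# The removed rewrite lifted Unbroadcast through the unary Elemwise,
# yielding an equivalent graph with Unbroadcast closer to the input:
after = pt.exp(unbroadcast(x, 0))    # Elemwise(Unbroadcast(x))
```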