Skip to content

Commit 319a4f2

Browse files
zhxchen17 (facebook-github-bot)

authored and committed
Remove unneeded _to_copy in edge dialect.
Summary: In executorch we will dtype-specialize the kernels and also run on a single device with export. Therefore _to_copy is not needed in edge dialect. Reviewed By: tugsbayasgalan Differential Revision: D56579169 fbshipit-source-id: 5a2e3cd453a11bd2ad009b439587b0fc589f7fe4
1 parent 8fcba36 commit 319a4f2

File tree

2 files changed

+29
-1
lines changed

2 files changed

+29
-1
lines changed

exir/passes/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -43,7 +43,7 @@
4343
from executorch.exir.passes.memory_planning_pass import MemoryPlanningPass
4444
from executorch.exir.passes.normalize_transpose_pass import NormalizeTransposePass
4545
from executorch.exir.passes.quant_fusion_pass import QuantFusionPass
46-
from executorch.exir.passes.remove_noop_pass import RemoveNoopPass
46+
from executorch.exir.passes.remove_noop_pass import RemoveNoopPass, RemoveToCopyPass
4747
from executorch.exir.passes.replace_aten_with_edge_pass import OpReplacePass
4848
from executorch.exir.passes.replace_broken_ops_with_function_ops_pass import (
4949
ReplaceBrokenOpsWithFunctionalOpsPass,
@@ -482,6 +482,7 @@ def dead_code_elimination_pass(graph_module: torch.fx.GraphModule) -> PassResult
482482
ScalarToTensorPass(),
483483
SymToTensorPass(),
484484
RemoveNoopPass(),
485+
RemoveToCopyPass(),
485486
]
486487
).passes
487488

exir/passes/remove_noop_pass.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -90,3 +90,30 @@ def call(self, graph_module: GraphModule) -> PassResult:
9090
graph_module.graph.eliminate_dead_code()
9191

9292
return PassResult(graph_module, True)
93+
94+
95+
class RemoveToCopyPass(ExportPass):
    """Remove ``aten._to_copy`` nodes that change neither dtype nor device.

    Such a node passes its input tensor through unchanged, so every use of
    the node is rewired to its first argument and the node itself is then
    swept away by dead-code elimination.
    """

    def call(self, graph_module: GraphModule) -> PassResult:
        removable_targets = (torch.ops.aten._to_copy.default,)

        for node in graph_module.graph.nodes:
            # Only function-call nodes whose target is _to_copy are candidates.
            if node.op != "call_function" or node.target not in removable_targets:
                continue

            src_val = node.args[0].meta["val"]
            out_val = node.meta["val"]

            # If the copy preserves both dtype and device it is a pass-through;
            # forward the input tensor to all users of this node.
            # NOTE(review): layout/memory_format are not compared here —
            # presumably irrelevant in this dialect; confirm if that changes.
            if src_val.dtype == out_val.dtype and src_val.device == out_val.device:
                node.replace_all_uses_with(node.args[0])

        # Drop the now-unused _to_copy nodes and check graph invariants.
        graph_module.graph.eliminate_dead_code()
        graph_module.graph.lint()

        return PassResult(graph_module, True)

0 commit comments

Comments
 (0)