44#
55
66import logging
7- from typing import cast , Any , Dict , List , Union
7+ from typing import Any , cast , Dict , List , Union
88
99import torch
1010from executorch .backends .apple .mps .mps_preprocess import MPSBackend
1111from executorch .backends .apple .mps .operators .node_visitor import get_node_visitors
1212from executorch .backends .apple .mps .utils .mps_utils import is_parameter
13+ from executorch .backends .transforms import get_shape
1314from executorch .exir .backend .backend_details import CompileSpec
1415from executorch .exir .backend .canonical_partitioners .pattern_op_partitioner import (
1516 generate_partitions_from_list_of_nodes ,
2021 PartitionResult ,
2122)
2223from executorch .exir .backend .utils import tag_constant_data
24+ from executorch .exir .dialects ._ops import ops as exir_ops
2325from torch .export .exported_program import ExportedProgram
2426from torch .fx .passes .infra .partitioner import Partition
2527from torch .fx .passes .operator_support import OperatorSupportBase
26- from executorch .exir .dialects ._ops import ops as exir_ops
27- from executorch .backends .transforms import get_shape
2828
2929FORMAT = "[%(levelname)s %(asctime)s %(filename)s:%(lineno)s] %(message)s"
3030logging .basicConfig (level = logging .DEBUG , format = FORMAT )
3636 exir_ops .edge .aten .index_put .default ,
3737]
3838
39+
3940class MPSOperatorSupport (OperatorSupportBase ):
4041 def __init__ (self , edge_program : torch .export .ExportedProgram , compiler_specs ):
4142 self .node_visitors = get_node_visitors (edge_program )
@@ -90,7 +91,10 @@ def mps_graph_advanced_indexing_support(self, node: torch.fx.Node):
9091
9192 def use_metal_kernel (self , node : torch .fx .Node ):
9293 if node .target in METAL_KERNELS :
93- if node .target == exir_ops .edge .aten .index .Tensor or node .target == exir_ops .edge .aten .index_put .default :
94+ if (
95+ node .target == exir_ops .edge .aten .index .Tensor
96+ or node .target == exir_ops .edge .aten .index_put .default
97+ ):
9498 if not self .mps_graph_advanced_indexing_support (node ):
9599 return True
96100 return False
@@ -104,7 +108,9 @@ def tag_nodes(self, partitions: List[Partition]) -> None:
104108 logging .warning (f"[WARNING] Using Metal kernel for op { node .name } !" )
105109 # Partition the Metal kernel into a separate partition
106110 crt_partition_counter += 1
107- delegation_tag = f"{ delegation_tag } _metal_kernel_{ crt_partition_counter } "
111+ delegation_tag = (
112+ f"{ delegation_tag } _metal_kernel_{ crt_partition_counter } "
113+ )
108114 crt_partition_counter += 1
109115 else :
110116 delegation_tag = f"{ delegation_tag } _{ crt_partition_counter } "
0 commit comments