44# This source code is licensed under the BSD-style license found in the
55# LICENSE file in the root directory of this source tree.
66
7+ import logging
8+ import operator
9+ from typing import Dict
10+
711import torch
12+ from executorch .exir import memory
13+ from executorch .exir .dialects ._ops import ops as exir_ops
14+ from executorch .exir .dialects .edge ._ops import EdgeOpOverload , EdgeOpOverloadPacket
15+ from tabulate import tabulate
816
917
1018# Get the output size of a 1D convolution given the input size and parameters
@@ -23,3 +31,120 @@ def get_conv1d_output_size(
2331 lout = (L + 2 * padding - dilation * (kernel_size - 1 ) - 1 ) // stride + 1
2432
2533 return torch .Size ((in_size [0 ], out_channels , lout ))
34+
35+
36+ # Return the overload packet for the edge op
37+ def get_edge_overload_packet (edge_op : EdgeOpOverload ) -> EdgeOpOverloadPacket :
38+ edge_op_namespace , edge_op_name = (
39+ edge_op .namespace ,
40+ edge_op ._schema .name .split ("::" )[1 ],
41+ )
42+ edge_op_overload_packet = getattr (
43+ getattr (exir_ops .edge , edge_op_namespace ), edge_op_name
44+ )
45+ return edge_op_overload_packet
46+
47+
48+ # Get the frequency list of ops in a graph module
49+ def get_ops_count (graph_module : torch .fx .GraphModule ) -> Dict [str , int ]:
50+ freq = {}
51+ # Loop over nodes to count the number of times each op occurs
52+ for node in graph_module .graph .nodes :
53+ if node .op == "call_function" :
54+ # Ignore getitem, alloc and view cases, we only want actual operations
55+ if (
56+ node .target == operator .getitem
57+ or node .target .__name__ == "alloc"
58+ or node .target == memory .view
59+ ):
60+ continue
61+ # If the op is already present, increment the count
62+ if get_edge_overload_packet (node .target ).__name__ in freq :
63+ freq [get_edge_overload_packet (node .target ).__name__ ] += 1
64+ # else, add a new entry
65+ else :
66+ freq [get_edge_overload_packet (node .target ).__name__ ] = 1
67+ return freq
68+
69+
70+ # Print the ops and how many times they occur multiple graph modules:
71+ # from export, from to_edge, and from Jarvis. Print the available
72+ # implementations for each op, and error out if the op is not supported.
73+ def print_ops_info (
74+ export_gm : torch .fx .GraphModule ,
75+ to_edge_gm : torch .fx .GraphModule ,
76+ jarvis_gm : torch .fx .GraphModule ,
77+ ):
78+ export_ops_count = get_ops_count (export_gm )
79+ to_edge_ops_count = get_ops_count (to_edge_gm )
80+ jarvis_ops_count = get_ops_count (jarvis_gm )
81+
82+ # De-duplicate the "<op>" and "<op>_copy" ops
83+ keys_to_delete_and_add = []
84+ for k1 in export_ops_count :
85+ for k2 in {** to_edge_ops_count , ** jarvis_ops_count }:
86+ if k2 .startswith (k1 ):
87+ keys_to_delete_and_add .append ((k1 , k2 ))
88+ break
89+
90+ for k in keys_to_delete_and_add :
91+ export_ops_count [k [1 ]] = export_ops_count [k [0 ]]
92+ del export_ops_count [k [0 ]]
93+
94+ removed_ops = []
95+ # Get the counts of the ops that are removed from the final graph
96+ for k in {** export_ops_count , ** to_edge_ops_count }:
97+ if k not in jarvis_ops_count :
98+ removed_ops .append (k )
99+
100+ # Create a dict of ops and their counts to pass to tabulate
101+ ops_count = [
102+ [
103+ op ,
104+ jarvis_ops_count [op ],
105+ to_edge_ops_count [op ] if op in to_edge_ops_count else 0 ,
106+ export_ops_count [op ] if op in export_ops_count else 0 ,
107+ ]
108+ for op in jarvis_ops_count
109+ ]
110+ sorted_ops_count = sorted (ops_count , key = lambda x : x [1 ], reverse = True )
111+
112+ # Create a dict of deleted ops and their counts to pass to tabulate
113+ removed_ops_count = [
114+ [
115+ op ,
116+ 0 ,
117+ to_edge_ops_count [op ] if op in to_edge_ops_count else 0 ,
118+ export_ops_count [op ] if op in export_ops_count else 0 ,
119+ ]
120+ for op in removed_ops
121+ ]
122+
123+ # Print the final ops and their counts in a tabular format
124+ logging .info (
125+ tabulate (
126+ sorted_ops_count ,
127+ headers = [
128+ "Final Operators " , # one character longer than the longest op name
129+ "Jarvis (Final) Graph" ,
130+ "To_edge Graph" ,
131+ "Export Graph" ,
132+ ],
133+ tablefmt = "outline" ,
134+ )
135+ )
136+
137+ # Print the removed ops and their counts in a tabular format (if any)
138+ if removed_ops != []:
139+ logging .info (
140+ tabulate (
141+ removed_ops_count ,
142+ headers = [
143+ "Deleted Operators " , # one character longer than the longest op name
144+ "Jarvis (Final) Graph" ,
145+ "To_edge Graph" ,
146+ "Export Graph" ,
147+ ],
148+ tablefmt = "outline" ,
149+ )
150+ )
0 commit comments