1+ # Licensed to the Apache Software Foundation (ASF) under one
2+ # or more contributor license agreements. See the NOTICE file
3+ # distributed with this work for additional information
4+ # regarding copyright ownership. The ASF licenses this file
5+ # to you under the Apache License, Version 2.0 (the
6+ # "License"); you may not use this file except in compliance
7+ # with the License. You may obtain a copy of the License at
8+ #
9+ # http://www.apache.org/licenses/LICENSE-2.0
10+ #
11+ # Unless required by applicable law or agreed to in writing,
12+ # software distributed under the License is distributed on an
13+ # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+ # KIND, either express or implied. See the License for the
15+ # specific language governing permissions and limitations
16+ # under the License.
17+ import mxnet as mx
18+ from gluoncv .model_zoo import get_model
19+
20+ import numpy as np
21+ import pytest
22+ import itertools
23+
24+ import tvm
25+ import tvm .relay .testing
26+ from tvm import relay
27+ from tvm .relay .op .contrib import dnnl
28+ import tvm .testing
29+ import argparse
30+
31+ has_dnnl_codegen = pytest .mark .skipif (
32+ not tvm .get_global_func ("relay.ext.dnnl" , True ), reason = "DNNL codegen not available"
33+ )
34+
35+ run_module = tvm .testing .parameter (
36+ pytest .param (False , marks = [has_dnnl_codegen , * tvm .testing .requires_llvm ()]),
37+ pytest .param (
38+ True , marks = [has_dnnl_codegen , * tvm .testing .requires_llvm ()]
39+ ),
40+ ids = ["compile" , "run" ],
41+ )
42+
43+
44+ def vmobj_to_list (o ):
45+ if isinstance (o , tvm .nd .NDArray ):
46+ return [o .numpy ()]
47+ elif isinstance (o , tvm .runtime .container .ADT ) or isinstance (o , list ):
48+ return [vmobj_to_list (f ) for f in o ]
49+ else :
50+ raise RuntimeError ("Unknown object type: %s" % type (o ))
51+
52+
53+ def assert_result_dict_holds (result_dict ):
54+ for k1 , k2 in itertools .combinations (result_dict , 2 ):
55+ res1 = vmobj_to_list (result_dict [k1 ])
56+ res2 = vmobj_to_list (result_dict [k2 ])
57+ for r1 , r2 in zip (res1 , res2 ):
58+ tvm .testing .assert_allclose (r1 , r2 , rtol = 1e-3 , atol = 1e-3 )
59+
60+
61+ def run_and_verify_func (config , target = "llvm" , run_module = True ):
62+ """Test a Relay func by compiling, running, and comparing TVM and DNNL outputs.
63+
64+ Parameters
65+ ----------
66+ config : Tuple[relay.Function, Dict[str, NDArray], List[str]]
67+ A tuple containing 1) The function to test, 2) A dictionary of var names to input shapes and
68+ 3) A list of which vars should be considered params.
69+
70+ run_module: bool
71+ If True, the built module will be run after being compiled.
72+ """
73+ f , input_shapes , is_param = config
74+ params = {x : np .random .uniform (- 1 , 1 , input_shapes [x ]).astype (np .float32 ) for x in is_param }
75+ input_dict = {
76+ k : np .random .uniform (- 1 , 1 , v ).astype (np .float32 )
77+ for k , v in input_shapes .items ()
78+ if k not in is_param
79+ }
80+ dev = tvm .device (target )
81+
82+ result_dict = dict ()
83+ for mode in ["graph" , "vm" ]:
84+ for use_dnnl in [False , True ]:
85+ mod = tvm .IRModule ()
86+ mod ["main" ] = f
87+ result_key = mode + ("_dnnl" if use_dnnl else "" )
88+ if use_dnnl :
89+ mod = dnnl .partition_for_dnnl (mod , params )
90+ with tvm .transform .PassContext (opt_level = 3 ):
91+ func = relay .create_executor (
92+ mode , mod = mod , device = dev , target = target
93+ ).evaluate ()
94+ else :
95+ with tvm .transform .PassContext (opt_level = 3 ):
96+ func = relay .create_executor (
97+ mode , mod = mod , device = dev , target = target
98+ ).evaluate ()
99+ if run_module :
100+ result_dict [result_key ] = func (** input_dict , ** params )
101+
102+ if run_module :
103+ assert_result_dict_holds (result_dict )
104+
105+
106+ def test_dnnl_not_compatible (run_module ):
107+ dtype = "float32"
108+ xshape = (1 , 32 , 14 , 14 )
109+ x_data = np .random .uniform (- 1 , 1 , xshape ).astype (dtype )
110+
111+ x = relay .var ("x" , shape = (xshape ), dtype = dtype )
112+ y = relay .add (x , x )
113+ z = relay .cast (relay .cast (y , "int32" ), "float32" )
114+ out = relay .nn .relu (z )
115+ f = relay .Function ([x ], out )
116+ mod = tvm .IRModule ()
117+ mod ["main" ] = f
118+ mod = dnnl .partition_for_dnnl (mod )
119+ for mode in ["graph" , "vm" ]:
120+ with tvm .transform .PassContext (opt_level = 3 ):
121+ func = relay .create_executor (
122+ mode , mod = mod , device = tvm .cpu (0 ), target = "llvm"
123+ ).evaluate ()
124+ if run_module :
125+ results = func (x_data )
126+
127+
128+ def test_conv2d (run_module ):
129+ def get_graph (
130+ x_shape = (1 , 32 , 8 , 8 ),
131+ k_shape = (16 , 32 , 3 , 3 ),
132+ groups = 1 ,
133+ padding = (0 , 0 ),
134+ strides = (1 , 1 ),
135+ dilation = (1 , 1 ),
136+ channels = None ,
137+ ):
138+ x = relay .var ("x" , shape = (x_shape ), dtype = "float32" )
139+ kernel = relay .var ("kernel" , shape = (k_shape ), dtype = "float32" )
140+ out = relay .nn .conv2d (
141+ x ,
142+ kernel ,
143+ kernel_size = k_shape [2 :4 ],
144+ groups = groups ,
145+ padding = padding ,
146+ strides = strides ,
147+ dilation = dilation ,
148+ channels = channels ,
149+ )
150+ f = relay .Function ([x , kernel ], out )
151+ return f , {"x" : x_shape , "kernel" : k_shape }, ["kernel" ]
152+
153+ for k_shape , groups in [((16 , 32 , 3 , 3 ), 1 ), ((32 , 1 , 3 , 3 ), 32 )]:
154+ for padding in [(0 , 0 ), (1 , 1 )]:
155+ for strides in [(1 , 1 ), (2 , 2 )]:
156+ for dilation in [(1 , 1 )]:
157+ run_and_verify_func (
158+ get_graph (
159+ k_shape = k_shape ,
160+ groups = groups ,
161+ padding = padding ,
162+ strides = strides ,
163+ dilation = dilation ,
164+ ),
165+ run_module = run_module ,
166+ )
167+
168+
169+ def test_conv2d_weights_const (run_module ):
170+ def get_graph (
171+ x_shape = (1 , 32 , 8 , 8 ),
172+ k_shape = (16 , 32 , 3 , 3 ),
173+ groups = 1 ,
174+ padding = (0 , 0 ),
175+ strides = (1 , 1 ),
176+ dilation = (1 , 1 ),
177+ ):
178+ x = relay .var ("x" , shape = (x_shape ), dtype = "float32" )
179+ kernel = relay .const (np .ones (k_shape ).astype ("float32" ))
180+ out = relay .nn .conv2d (
181+ x ,
182+ kernel ,
183+ channels = k_shape [0 ],
184+ kernel_size = k_shape [2 :4 ],
185+ groups = groups ,
186+ padding = padding ,
187+ strides = strides ,
188+ dilation = dilation ,
189+ )
190+ f = relay .Function ([x ], out )
191+ return f , {"x" : x_shape }, []
192+
193+ run_and_verify_func (get_graph (), run_module = run_module )
194+
195+
196+ def test_dense (run_module ):
197+ def get_graph (x_shape = (1 , 16 ), k_shape = (32 , 16 )):
198+ x = relay .var ("x" , shape = (x_shape ), dtype = "float32" )
199+ kernel = relay .var ("kernel" , shape = (k_shape ), dtype = "float32" )
200+ out = relay .nn .dense (x , kernel , units = k_shape [0 ])
201+ f = relay .Function ([x , kernel ], out )
202+ return f , {"x" : x_shape , "kernel" : k_shape }, ["kernel" ]
203+
204+ run_and_verify_func (get_graph (), run_module = run_module )
205+ run_and_verify_func (get_graph (k_shape = (1 , 16 )), run_module = run_module )
206+
207+
208+ def test_multiple_outputs (run_module ):
209+ def get_graph ():
210+ x = relay .var ("x" , shape = (1 , 3 ), dtype = "float32" )
211+ y = relay .var ("y" , shape = (1 , 3 ), dtype = "float32" )
212+ z = relay .add (x , y )
213+ w = relay .add (z , y )
214+ out = relay .Tuple ((z , w ))
215+ f = relay .Function ([x , y ], out )
216+ return f , {"x" : (1 , 3 ), "y" : (1 , 3 )}, []
217+
218+ run_and_verify_func (get_graph (), run_module = run_module )
219+
220+
221+ if __name__ == "__main__" :
222+ import sys
223+ sys .exit (pytest .main ([__file__ ] + sys .argv [1 :]))
0 commit comments