@@ -2323,5 +2323,116 @@ def expected():
23232323 assert tvm .ir .structural_equal (a , b ), "Actual = \n " + str (a ) + "\n \n Expected = \n " + str (b )
23242324
23252325
@pytest.mark.parametrize(
    "data_layout, kernel_layout",
    [
        ("NCHW1c", "OIHW1i1o"),
        ("NCHW4c", "OIHW4i4o"),
        ("NCHW8c", "OIHW8i8o"),
        ("NCHW16c", "OIHW16i16o"),
    ],
)
def test_resnet_convert_layout_nchwc(data_layout, kernel_layout):
    """ConvertLayout a ResNet-style block from NCHW/OIHW to a blocked NCHWc layout.

    The graph is: stem conv -> relu -> max pool, then two parallel conv+relu
    branches added together, finished by a global max pool.  For NCHW1c the
    stem conv is converted as well (block size 1 divides the 3 input
    channels); for block sizes 4/8/16 the 3-channel stem conv is expected to
    stay in NCHW, with a layout_transform inserted after the max pool.
    """
    x = relay.var("x", shape=(1, 3, 224, 224))
    weight1 = relay.var("weight1", shape=(64, 3, 7, 7))
    weight2 = relay.var("weight2", shape=(64, 64, 3, 3))
    weight3 = relay.var("weight3", shape=(64, 64, 1, 1))

    def before():
        # Original graph, entirely in NCHW / OIHW.
        y = relay.nn.conv2d(
            x,
            weight1,
            strides=(2, 2),
            padding=(3, 3),
            channels=64,
            kernel_size=(7, 7),
            data_layout="NCHW",
            kernel_layout="OIHW",
        )
        y = relay.nn.relu(y)
        y = relay.nn.max_pool2d(y, pool_size=(3, 3), strides=(2, 2), padding=(1, 1))
        y1 = relay.nn.conv2d(
            y,
            weight2,
            channels=64,
            kernel_size=(3, 3),
            padding=(1, 1),
            data_layout="NCHW",
            kernel_layout="OIHW",
        )
        y1 = relay.nn.relu(y1)
        y2 = relay.nn.conv2d(
            y,
            weight3,
            channels=64,
            kernel_size=(1, 1),
            data_layout="NCHW",
            kernel_layout="OIHW",
        )
        y2 = relay.nn.relu(y2)
        y = y1 + y2
        y = relay.nn.global_max_pool2d(y, layout="NCHW")
        return y

    def expected():
        if data_layout == "NCHW1c":
            # Block size 1 divides the 3 input channels, so the stem conv is
            # converted too: inputs are layout_transformed up front and the
            # max pool runs in the blocked layout.
            y = relay.nn.contrib_conv2d_nchwc(
                relay.layout_transform(x, "NCHW", data_layout),
                relay.layout_transform(weight1, "OIHW", kernel_layout),
                strides=(2, 2),
                padding=(3, 3),
                channels=64,
                kernel_size=(7, 7),
                data_layout=data_layout,
                kernel_layout=kernel_layout,
            )
            y = relay.nn.relu(y)
            y = relay.nn.max_pool2d(
                y, pool_size=(3, 3), strides=(2, 2), padding=(1, 1), layout=data_layout
            )
        else:
            # 3 input channels are not divisible by the block size, so the
            # stem conv stays in NCHW and the transform to the blocked layout
            # is inserted after the max pool instead.
            y = relay.nn.conv2d(
                x,
                weight1,
                strides=(2, 2),
                padding=(3, 3),
                channels=64,
                kernel_size=(7, 7),
                data_layout="NCHW",
                kernel_layout="OIHW",
            )
            y = relay.nn.relu(y)
            y = relay.nn.max_pool2d(y, pool_size=(3, 3), strides=(2, 2), padding=(1, 1))
            y = relay.layout_transform(y, "NCHW", data_layout)
        # Both parallel branches run in the blocked layout in every case.
        y1 = relay.nn.contrib_conv2d_nchwc(
            y,
            relay.layout_transform(weight2, "OIHW", kernel_layout),
            channels=64,
            kernel_size=(3, 3),
            padding=(1, 1),
            data_layout=data_layout,
            kernel_layout=kernel_layout,
        )
        y1 = relay.nn.relu(y1)
        y2 = relay.nn.contrib_conv2d_nchwc(
            y,
            relay.layout_transform(weight3, "OIHW", kernel_layout),
            channels=64,
            kernel_size=(1, 1),
            data_layout=data_layout,
            kernel_layout=kernel_layout,
        )
        y2 = relay.nn.relu(y2)
        y = y1 + y2
        y = relay.nn.global_max_pool2d(y, layout=data_layout)
        # Transform back so the graph's output layout matches the original NCHW.
        y = relay.layout_transform(y, data_layout, "NCHW")
        return y

    a = run_opt_pass(
        before(), transform.ConvertLayout({"nn.conv2d": [data_layout, kernel_layout]})
    )
    b = run_opt_pass(expected(), transform.InferType())
    # Message normalized to "Expected" for consistency with the other asserts
    # in this file.
    assert tvm.ir.structural_equal(a, b), (
        "Actual = \n" + str(a) + "\n\nExpected = \n" + str(b)
    )
2435+
2436+
if __name__ == "__main__":
    # pytest.main returns an exit code; propagate it so running this file as
    # a script (e.g. in CI) reports test failures instead of always exiting 0.
    raise SystemExit(pytest.main([__file__]))
0 commit comments