1717"""Definition of ROCm operator strategy."""
1818# pylint: disable=invalid-name,unused-argument,unused-wildcard-import,wildcard-import
1919from tvm import topi
20- from tvm .auto_scheduler import is_auto_scheduler_enabled
2120from tvm .te import SpecializedCondition
2221from tvm .contrib .thrust import can_use_rocthrust
2322from tvm .contrib import miopen
2423
2524from .generic import *
2625from .. import op as _op
27- from .cuda import judge_winograd , naive_schedule
26+ from .cuda import batch_matmul_strategy_cuda , conv2d_strategy_cuda , dense_strategy_cuda
2827
2928
 @conv2d_strategy.register("rocm")
 def conv2d_strategy_rocm(attrs, inputs, out_type, target):
     """conv2d rocm strategy"""
-    strategy = _op.OpStrategy()
-    data, kernel = inputs
-    dilation_h, dilation_w = attrs.get_int_tuple("dilation")
     groups = attrs.groups
     layout = attrs.data_layout
-    stride_h, stride_w = attrs.get_int_tuple("strides")
-    kernel_layout = attrs.kernel_layout
     padding = attrs.get_int_tuple("padding")
-    if dilation_h < 1 or dilation_w < 1:
-        raise ValueError("dilation should be positive value")
-
-    if groups == 1:
-        if layout == "NCHW":
-            # TODO(@vinx13, @icemelon9): Use conv2d_NCHWc_int8 when dtype is int8/uint8.
-            assert kernel_layout == "OIHW"
-            strategy.add_implementation(
-                wrap_compute_conv2d(topi.cuda.conv2d_nchw),
-                wrap_topi_schedule(topi.cuda.schedule_conv2d_nchw),
-                name="conv2d_nchw.cuda",
-            )
-            _, _, kh, kw = get_const_tuple(kernel.shape)
-            if (
-                2 < kh < 8
-                and 2 < kw < 8
-                and kh == kw
-                and stride_h == 1
-                and stride_w == 1
-                and dilation_h == 1
-                and dilation_w == 1
-            ):
-                strategy.add_implementation(
-                    wrap_compute_conv2d(topi.cuda.conv2d_nchw_winograd),
-                    wrap_topi_schedule(topi.cuda.schedule_conv2d_nchw_winograd),
-                    name="conv2d_nchw_winograd.cuda",
-                    plevel=5,
-                )
-        elif layout == "NHWC":
-            assert kernel_layout == "HWIO"
-            strategy.add_implementation(
-                wrap_compute_conv2d(topi.gpu.conv2d_nhwc),
-                wrap_topi_schedule(topi.gpu.schedule_conv2d_nhwc),
-                name="conv2d_nhwc.gpu",
-            )
-            N, H, W, _ = get_const_tuple(data.shape)
-            KH, KW, CI, CO = get_const_tuple(kernel.shape)

-            (_, judge_winograd_autotvm, judge_winograd_auto_scheduler,) = judge_winograd(
-                N,
-                H,
-                W,
-                KH,
-                KW,
-                CI,
-                CO,
-                padding,
-                stride_h,
-                stride_w,
-                dilation_h,
-                dilation_w,
-                data.dtype,
-                kernel.dtype,
-                pre_flag=False,
-            )
+    strategy = conv2d_strategy_cuda(attrs, inputs, out_type, target)

-            if judge_winograd_autotvm:
-                strategy.add_implementation(
-                    wrap_compute_conv2d(topi.cuda.conv2d_nhwc_winograd_direct),
-                    wrap_topi_schedule(topi.cuda.schedule_conv2d_nhwc_winograd_direct),
-                    name="conv2d_nhwc_winograd_direct.cuda",
-                    plevel=5,
-                )
+    # add miopen implementation
+    if (
+        "miopen" in target.libs
+        and groups == 1
+        and layout == "NCHW"
+        and padding[0] == padding[2]
+        and padding[1] == padding[3]
+    ):
+        strategy.add_implementation(
+            wrap_compute_conv2d(topi.rocm.conv2d_nchw_miopen, True),
+            wrap_topi_schedule(topi.rocm.schedule_conv2d_nchw_miopen),
+            name="conv2d_nchw_miopen.rocm",
+            plevel=50,
+        )

-            if is_auto_scheduler_enabled() and judge_winograd_auto_scheduler:
-                strategy.add_implementation(
-                    wrap_compute_conv2d(topi.nn.conv2d_winograd_nhwc),
-                    naive_schedule,  # this implementation should never be picked by autotvm
-                    name="conv2d_nhwc.winograd",
-                    plevel=15,
-                )
-        elif layout == "HWCN":
-            assert kernel_layout == "HWIO"
-            strategy.add_implementation(
-                wrap_compute_conv2d(topi.cuda.conv2d_hwcn),
-                wrap_topi_schedule(topi.cuda.schedule_conv2d_hwcn),
-                name="conv2d_hwcn.cuda",
-            )
-        elif layout == "NCHW4c" and data.dtype in ["int8", "uint8"]:
-            assert kernel_layout == "OIHW4o4i"
-            strategy.add_implementation(
-                wrap_compute_conv2d(topi.cuda.conv2d_NCHWc_int8, True),
-                wrap_topi_schedule(topi.cuda.schedule_conv2d_NCHWc_int8),
-                name="conv2d_NCHWc_int8.cuda",
-            )
-        else:
-            raise RuntimeError("Unsupported conv2d layout {} for CUDA".format(layout))
-        # add miopen implementation
-        if (
-            "miopen" in target.libs
-            and layout == "NCHW"
-            and padding[0] == padding[2]
-            and padding[1] == padding[3]
-        ):
-            strategy.add_implementation(
-                wrap_compute_conv2d(topi.rocm.conv2d_nchw_miopen, True),
-                wrap_topi_schedule(topi.rocm.schedule_conv2d_nchw_miopen),
-                name="conv2d_nchw_miopen.rocm",
-                plevel=15,
-            )
-    elif is_depthwise_conv2d(data.shape, layout, kernel.shape, kernel_layout, groups):
-        if layout == "NCHW":
-            assert kernel_layout == "OIHW"
-            strategy.add_implementation(
-                wrap_compute_conv2d(topi.cuda.depthwise_conv2d_nchw),
-                wrap_topi_schedule(topi.cuda.schedule_depthwise_conv2d_nchw),
-                name="depthwise_conv2d_nchw.cuda",
-            )
-        elif layout == "NHWC":
-            assert kernel_layout == "HWOI"
-            strategy.add_implementation(
-                wrap_compute_conv2d(topi.nn.depthwise_conv2d_nhwc),
-                wrap_topi_schedule(topi.cuda.schedule_depthwise_conv2d_nhwc),
-                name="depthwise_conv2d_nhwc.cuda",
-            )
-        else:
-            raise RuntimeError("Unsupported depthwise_conv2d layout {}".format(layout))
-    else:  # group_conv2d
-        if layout == "NCHW":
-            # TODO(@vinx13, @icemelon9): Use group_conv2d_NCHWc_int8 when dtype is int8/uint8.
-            assert kernel_layout == "OIHW"
-            strategy.add_implementation(
-                wrap_compute_conv2d(topi.cuda.group_conv2d_nchw, has_groups=True),
-                wrap_topi_schedule(topi.cuda.schedule_group_conv2d_nchw),
-                name="group_conv2d_nchw.cuda",
-            )
-        elif layout == "NCHW4c" and data.dtype in ["int8", "uint8"]:
-            assert kernel_layout == "OIHW4o4i"
-            strategy.add_implementation(
-                wrap_compute_conv2d(topi.cuda.group_conv2d_NCHWc_int8, True),
-                wrap_topi_schedule(topi.cuda.schedule_group_conv2d_NCHWc_int8),
-                name="group_conv2d_NCHWc_int8.cuda",
-            )
-        else:
-            raise RuntimeError("Unsupported group_conv2d layout {}".format(layout))
     return strategy


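Note (not part of this diff): a minimal sketch, assuming a TVM build with the ROCm runtime enabled, of how the MIOpen implementation added above gets selected. conv2d_strategy_rocm now reuses conv2d_strategy_cuda and only appends the MIOpen kernel when the target's libs contain "miopen"; the target string below is an illustrative assumption.

    import tvm

    # Hypothetical target for illustration: "-libs=miopen,rocblas" enables the
    # MIOpen conv2d path above and the rocBLAS dense/batch_matmul paths below.
    target = tvm.target.Target("rocm -libs=miopen,rocblas")
    assert "miopen" in target.libs and "rocblas" in target.libs
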
 @dense_strategy.register("rocm")
 def dense_strategy_rocm(attrs, inputs, out_type, target):
     """Dense strategy for ROCM"""
     assert len(inputs[0].shape) == 2 and len(inputs[1].shape) == 2, "Only support 2-dim dense"
-    strategy = _op.OpStrategy()
-    strategy.add_implementation(
-        wrap_compute_dense(topi.rocm.dense),
-        wrap_topi_schedule(topi.rocm.schedule_dense),
-        name="dense.rocm",
-    )
+    strategy = dense_strategy_cuda(attrs, inputs, out_type, target)
+
     if target.kind.name == "rocm" and "rocblas" in target.libs:
         assert out_type.dtype == inputs[0].dtype, "Mixed precision not supported."
         strategy.add_implementation(
@@ -200,13 +73,8 @@ def dense_strategy_rocm(attrs, inputs, out_type, target):
 @batch_matmul_strategy.register("rocm")
 def batch_matmul_strategy_rocm(attrs, inputs, out_type, target):
     """Batch matmul strategy for ROCM"""
-    strategy = _op.OpStrategy()
-    strategy.add_implementation(
-        wrap_compute_batch_matmul(topi.cuda.batch_matmul),
-        wrap_topi_schedule(topi.cuda.schedule_batch_matmul),
-        name="batch_matmul.cuda",
-        plevel=10,
-    )
+    strategy = batch_matmul_strategy_cuda(attrs, inputs, out_type, target)
+
     if target.kind.name == "rocm" and "rocblas" in target.libs:
         assert out_type.dtype == inputs[0].dtype, "Mixed precision not supported."
         strategy.add_implementation(