@@ -145,7 +145,7 @@ def conv2d_strategy_cuda(attrs, inputs, out_type, target):
145145 if layout == "NCHW" :
146146 assert kernel_layout == "OIHW"
147147 if (
148- (target .kind .name in ["cuda" , "vulkan" ])
148+ (target .kind .name in ["cuda" , "vulkan" , "rocm" ])
149149 and data .dtype in ("int8" , "uint8" )
150150 and kernel .dtype in ("int8" , "uint8" )
151151 ):
@@ -297,7 +297,7 @@ def conv2d_strategy_cuda(attrs, inputs, out_type, target):
297297 Need to satisfy tensor core schedule."
298298 )
299299 elif (
300- (target .kind .name in ["cuda" , "vulkan" ])
300+ (target .kind .name in ["cuda" , "vulkan" , "rocm" ])
301301 and layout == "NCHW4c"
302302 and data .dtype in ["int8" , "uint8" ]
303303 ):
@@ -376,7 +376,7 @@ def conv2d_strategy_cuda(attrs, inputs, out_type, target):
376376 ic_chunk = in_channels // 4
377377
378378 if (
379- (target .kind .name in ["cuda" , "vulkan" ])
379+ (target .kind .name in ["cuda" , "vulkan" , "rocm" ])
380380 and data .dtype in ["int8" , "uint8" ]
381381 and kernel .dtype in ["int8" , "uint8" ]
382382 and channels % groups == 0
@@ -836,7 +836,7 @@ def dense_strategy_cuda(attrs, inputs, out_type, target):
836836 b , i = get_const_tuple (data .shape )
837837 o , _ = get_const_tuple (weights .shape )
838838 if (
839- target .kind .name in ["cuda" , "vulkan" ]
839+ target .kind .name in ["cuda" , "vulkan" , "rocm" ]
840840 and data .dtype == "int8"
841841 and weights .dtype == "int8"
842842 and out_type .dtype == "int32"
0 commit comments