diff --git a/topi/python/topi/arm_cpu/conv2d.py b/topi/python/topi/arm_cpu/conv2d.py index 77b37ed5a1e2..55d6bba2c43f 100644 --- a/topi/python/topi/arm_cpu/conv2d.py +++ b/topi/python/topi/arm_cpu/conv2d.py @@ -280,13 +280,15 @@ def _schedule_spatial_pack(cfg, s, data_vec, kernel_vec, s[conv].compute_at(s[last], ow) # mark parallel - s[last].parallel(co) + p = s[last].fuse(n, co) + s[last].parallel(p) if data_vec.op.name == 'data_vec_undilated': - _, h, _, _, _, _, _, _ = s[data_vec].op.axis + n, h, _, _, _, _, _, _ = s[data_vec].op.axis else: - _, h, _, _, _, _ = s[data_vec].op.axis - s[data_vec].parallel(h) + n, h, _, _, _, _ = s[data_vec].op.axis + p = s[data_vec].fuse(n, h) + s[data_vec].parallel(p) if kernel_vec.op.name == 'kernel_vec': co, _, _, _, _ = s[kernel_vec].op.axis @@ -470,8 +472,9 @@ def _schedule_winograd(cfg, s, output, last): # output n, co, h, w = s[last].op.axis co, coi = cfg['tile_k'].apply(s, last, co) - s[M].compute_at(s[last], co) - s[last].parallel(co) + p = s[last].fuse(n, co) + s[M].compute_at(s[last], p) + s[last].parallel(p) MM = s.cache_read(M, 'global', [Y]) m = get_const_int(V.shape[0]) + 1 - 3