diff --git a/python/tvm/dlight/gpu/gemv.py b/python/tvm/dlight/gpu/gemv.py index 644f4e6dfa7a..ed32ea77858f 100644 --- a/python/tvm/dlight/gpu/gemv.py +++ b/python/tvm/dlight/gpu/gemv.py @@ -469,7 +469,10 @@ def apply( TS, TR = 2, 64 elif target.kind.name == "rocm": VEC_C = 4 - LOAD_V_SHARED = True + # TODO: set LOAD_V_SHARED = False for now + # rocm might have some issues when load/store of shared do not belong to same data type + # and only works for certain vector lens, our commonly useful vector lens are in 4 + LOAD_V_SHARED = False LOAD_V_VEC = 8 UNROLL = 256 if isinstance(len_S, int):