diff --git a/python/tvm/relax/frontend/nn/llm/kv_cache.py b/python/tvm/relax/frontend/nn/llm/kv_cache.py index f60c40efa21c..399e418c464b 100644 --- a/python/tvm/relax/frontend/nn/llm/kv_cache.py +++ b/python/tvm/relax/frontend/nn/llm/kv_cache.py @@ -855,7 +855,7 @@ def get_tile_size(x, y, t): cnt = (x * y) // t assert (x * y) % t == 0 tile_y = (int)(math.ceil(math.sqrt(cnt))) - while (cnt % tile_y != 0 or y % tile_y != 0) and tile_y <= cnt: + while (cnt % tile_y != 0 or y % tile_y != 0 or x % (cnt // tile_y) != 0) and tile_y <= cnt: tile_y += 1 assert tile_y <= cnt tile_x = cnt // tile_y @@ -1509,7 +1509,7 @@ def get_tile_size(x, y, t): cnt = (x * y) // t assert (x * y) % t == 0 tile_y = (int)(math.ceil(math.sqrt(cnt))) - while (cnt % tile_y != 0 or y % tile_y != 0) and tile_y <= cnt: + while (cnt % tile_y != 0 or y % tile_y != 0 or x % (cnt // tile_y) != 0) and tile_y <= cnt: tile_y += 1 assert tile_y <= cnt tile_x = cnt // tile_y @@ -1867,7 +1867,7 @@ def get_tile_size(x, y, t): cnt = (x * y) // t assert (x * y) % t == 0 tile_y = (int)(math.ceil(math.sqrt(cnt))) - while (cnt % tile_y != 0 or y % tile_y != 0) and tile_y <= cnt: + while (cnt % tile_y != 0 or y % tile_y != 0 or x % (cnt // tile_y) != 0) and tile_y <= cnt: tile_y += 1 assert tile_y <= cnt tile_x = cnt // tile_y