Skip to content

Commit a238376

Browse files
committed
[CODEGEN][METAL] Fix unaligned vector load
This PR fixes the implementation of unaligned vector load. Previously vector construction was printed as (float2)(v0, v1). This will cause problem as C have comma expression, and (v0, v1) will be evaluated as v1. The final result will become float2(v1, v1). The bug affects all codegen that uses the default implementation, such as metal. We added a testcase on metal to cover this case.
1 parent a5ed21d commit a238376

File tree

4 files changed

+46
-21
lines changed

4 files changed

+46
-21
lines changed

src/target/source/codegen_c.cc

Lines changed: 12 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -695,7 +695,7 @@ void CodeGenC::VisitExpr_(const BufferLoadNode* op, std::ostream& os) { // NOLI
695695
can_vector_load = true;
696696
}
697697
}
698-
698+
LOG(INFO) << "Something wrong";
699699
if (can_vector_load) {
700700
std::string ref = GetVecLoad(op->dtype, op->buffer.get(), base.Eval());
701701
HandleVolatileLoads(ref, op, os);
@@ -722,8 +722,10 @@ void CodeGenC::VisitExpr_(const BufferLoadNode* op, std::ostream& os) { // NOLI
722722
value_temp << '[';
723723
PrintVecElemLoad(sindex, index.dtype(), i, value_temp);
724724
value_temp << ']';
725+
LOG(INFO) << "PrinVecElemLoad" << value_temp.str();
725726
PrintVecElemLoadExpr(op->dtype, i, value_temp.str(), svalue_expr);
726727
}
728+
LOG(INFO) << svalue_expr.str();
727729
os << svalue_expr.str();
728730
}
729731
}
@@ -802,15 +804,16 @@ void CodeGenC::VisitExpr_(const LetNode* op, std::ostream& os) { // NOLINT(*)
802804
}
803805

804806
void CodeGenC::VisitExpr_(const RampNode* op, std::ostream& os) { // NOLINT(*)
805-
// constraint of current logic
806-
ICHECK_EQ(op->base.dtype(), DataType::Int(32));
807-
os << "((int" << op->lanes << ")(";
807+
// NOTE: C have comma expression so cannot use (int2)(v0, v1)
808+
// instead should use int2(v0, v1)
809+
PrintType(op->dtype, os);
810+
os << "(";
808811
for (int i = 0; i < op->lanes; i++) {
809812
os << "(" << PrintExpr(op->base) << ")"
810813
<< "+(" << PrintExpr(op->stride) << "*" << i << ")";
811814
if (i != op->lanes - 1) os << ", ";
812815
}
813-
os << "))";
816+
os << ")";
814817
}
815818

816819
void CodeGenC::VisitExpr_(const ShuffleNode* op, std::ostream& os) {
@@ -999,9 +1002,11 @@ void CodeGenC::PrintVecElemLoadExpr(DataType t, int i, const std::string& value,
9991002
}
10001003

10011004
if (i == 0) {
1002-
os << "((";
1005+
// NOTE: C have comma expression so cannot use (float2)(v0, v1)
1006+
// instead should use float2(v0, v1)
1007+
os << "(";
10031008
PrintType(t, os);
1004-
os << ")(";
1009+
os << "(";
10051010
}
10061011
os << value;
10071012
if (i != t.lanes() - 1) {

src/target/source/codegen_metal.cc

Lines changed: 5 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -299,17 +299,6 @@ void CodeGenMetal::VisitExpr_(const BroadcastNode* op, std::ostream& os) { // N
299299
os << ')';
300300
}
301301

302-
void CodeGenMetal::VisitExpr_(const RampNode* op, std::ostream& os) { // NOLINT(*)
303-
PrintType(op->dtype, os);
304-
os << "(";
305-
for (int i = 0; i < op->lanes; ++i) {
306-
if (i != 0) os << ", ";
307-
os << "(" << PrintExpr(op->base) << ")"
308-
<< "+(" << PrintExpr(op->stride) << "*" << i << ")";
309-
}
310-
os << ')';
311-
}
312-
313302
void CodeGenMetal::VisitExpr_(const CallNode* op, std::ostream& os) { // NOLINT(*)
314303
if (op->op.same_as(builtin::reinterpret())) {
315304
// generate as_type<TYPE>(ARG)
@@ -369,7 +358,11 @@ runtime::Module BuildMetal(IRModule mod, Target target) {
369358
code << fsource;
370359
}
371360

372-
return MetalModuleCreate(code.str(), fmt, ExtractFuncInfo(mod), source.str());
361+
std::string code_str = code.str();
362+
if (const auto* f = Registry::Get("tvm_callback_metal_postproc")) {
363+
code_str = (*f)(code_str).operator std::string();
364+
}
365+
return MetalModuleCreate(code_str, fmt, ExtractFuncInfo(mod), source.str());
373366
}
374367

375368
TVM_REGISTER_GLOBAL("target.build.metal").set_body_typed(BuildMetal);

src/target/source/codegen_metal.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -51,9 +51,9 @@ class CodeGenMetal final : public CodeGenC {
5151
void PrintVecElemStore(const std::string& vec, DataType t, int i, const std::string& value) final;
5252
// overload visitor
5353
void VisitExpr_(const BroadcastNode* op, std::ostream& os) final; // NOLINT(*)
54-
void VisitExpr_(const RampNode* op, std::ostream& os) final; // NOLINT(*)
5554
void VisitExpr_(const CallNode* op, std::ostream& os) final; // NOLINT(*)
56-
void VisitExpr_(const FloatImmNode* op, std::ostream& os) final;
55+
void VisitExpr_(const FloatImmNode* op, std::ostream& os) final; // NOLINT(*)
56+
5757
// reuse parent's function.
5858
using CodeGenC::PrintType;
5959

tests/python/unittest/test_target_codegen_metal.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,8 @@
2222
from tvm.contrib.nvcc import have_fp16, have_int8, have_bf16
2323
from tvm.contrib import nvcc
2424
import tvm.testing
25+
import tvm.script
26+
from tvm.script import tir as T
2527

2628
tx = te.thread_axis("threadIdx.x")
2729
bx = te.thread_axis("blockIdx.x")
@@ -54,6 +56,31 @@ def check_inf_nan(dev, n, value, dtype):
5456
check_inf_nan(dev, 1, float("nan"), "float16")
5557

5658

59+
@tvm.testing.requires_gpu
60+
@tvm.testing.requires_metal
61+
def test_unaligned_vectorize():
62+
@tvm.script.ir_module
63+
class IRModule:
64+
@T.prim_func
65+
def main(A: T.Buffer((2, 3), "float32"), B: T.Buffer((6,), "float32")):
66+
T.func_attr({"global_symbol": "main"})
67+
for i0_1 in T.thread_binding(3, thread="threadIdx.x"):
68+
for i0_0 in T.vectorized(2):
69+
with T.block("block"):
70+
vi0 = T.axis.spatial(6, i0_0 * 3 + i0_1)
71+
B[vi0] = A[vi0 // 3, vi0 % 3]
72+
73+
target = "metal"
74+
dev = tvm.metal()
75+
76+
a = (np.arange(6).reshape(2, 3)).astype("float32")
77+
a_nd = tvm.nd.array(a, dev)
78+
b_nd = tvm.nd.empty((6,), "float32", dev)
79+
f = tvm.build(IRModule, target=target)
80+
f(a_nd, b_nd)
81+
np.testing.assert_allclose(b_nd.numpy(), a.reshape(6), atol=1e-5, rtol=1e-5)
82+
83+
5784
@tvm.testing.requires_gpu
5885
@tvm.testing.requires_metal
5986
def test_metal_erf():

0 commit comments

Comments
 (0)