Skip to content

Commit f4dce24

Browse files
wyc-ruiker and wangyucheng authored
[Codegen][CUDA] Fix make_int4x cuda codegen vectorize (#8137)
Co-authored-by: wangyucheng <[email protected]>
1 parent 4344540 commit f4dce24

File tree

3 files changed

55 additions and 13 deletions

3 files changed

55 additions and 13 deletions

src/target/source/codegen_c.cc

Lines changed: 10 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -212,13 +212,18 @@ std::string CodeGenC::GetBufferRef(DataType t, const VarNode* buffer, PrimExpr i
212212
PrintType(t.element_of(), os);
213213
os << "*)";
214214
}
215-
os << vid << " + (";
216-
PrintExpr(index, os);
217-
os << ")";
218215
if (t.bits() == 4 || (t.bits() == 1 && t.is_int())) {
219-
os << " / " << (32 / t.bits());
216+
os << vid << ") + (";
217+
PrintExpr(index, os);
218+
os << ")";
219+
os << " / " << t.lanes();
220+
os << ")[0]";
221+
} else {
222+
os << vid << " + (";
223+
PrintExpr(index, os);
224+
os << ")";
225+
os << "))[0]";
220226
}
221-
os << "))[0]";
222227
}
223228
return os.str();
224229
}

src/target/source/codegen_cuda.cc

Lines changed: 37 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -809,18 +809,48 @@ void CodeGenCUDA::VisitExpr_(const BroadcastNode* op, std::ostream& os) { // NO
809809
return;
810810
}
811811

812-
if ((op->dtype.is_int() || op->dtype.is_uint()) && op->dtype.bits() == 4 && op->lanes == 8) {
813-
// make_int4x8
812+
if ((op->dtype.is_int() || op->dtype.is_uint()) && op->dtype.bits() == 4) {
813+
bool fail = false;
814814
const int64_t* p = as_const_int(op->value);
815815
ICHECK(p);
816816
int64_t v = *p & 0xF;
817-
v = (v << 28) | (v << 24) | (v << 20) | (v << 16) | (v << 12) | (v << 8) | (v << 4) | v;
818-
if (op->dtype.is_uint()) {
819-
os << "(uint)" << v;
817+
818+
if (op->lanes == 4) {
819+
v = (v << 12) | (v << 8) | (v << 4) | v;
820+
if (op->dtype.is_uint()) {
821+
os << "(uint16_t)" << v;
822+
} else {
823+
os << "(int16_t)" << v;
824+
}
820825
} else {
821-
os << "(int)" << v;
826+
v = (v << 28) | (v << 24) | (v << 20) | (v << 16) | (v << 12) | (v << 8) | (v << 4) | v;
827+
if (op->lanes == 8) {
828+
if (op->dtype.is_uint()) {
829+
os << "(uint)" << v;
830+
} else {
831+
os << "(int)" << v;
832+
}
833+
} else if (op->lanes == 16 || op->lanes == 32) {
834+
os << "make_";
835+
PrintType(op->dtype, os);
836+
os << '(';
837+
for (int i = 0; i < op->lanes / 8; ++i) {
838+
if (i != 0) os << ", ";
839+
if (op->dtype.is_uint()) {
840+
os << "(uint)" << v;
841+
} else {
842+
os << "(int)" << v;
843+
}
844+
}
845+
os << ')';
846+
} else {
847+
fail = true;
848+
}
849+
}
850+
851+
if (!fail) {
852+
return;
822853
}
823-
return;
824854
}
825855

826856
std::string v = PrintExpr(op->value);

tests/python/unittest/test_target_codegen_cuda.py

Lines changed: 8 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -215,14 +215,21 @@ def check_cuda(n, value, lanes):
215215
y, x = s[A].op.axis
216216
s[A].vectorize(x)
217217
s[A].bind(y, bx)
218-
fun = tvm.build(s, [A], "cuda", name="make_int4x8")
218+
kernel_name = "make_int4x" + str(lanes)
219+
fun = tvm.build(s, [A], "cuda", name=kernel_name)
219220
np_a = np.full((n, lanes), value, dtype="int8")
220221
a = tvm.nd.empty((n, lanes), dtype, dev)
221222
fun(a)
222223
np.testing.assert_equal(a.numpy(), np_a)
223224

225+
check_cuda(64, 1, 4)
226+
check_cuda(64, 7, 4)
224227
check_cuda(64, 1, 8)
225228
check_cuda(64, 7, 8)
229+
check_cuda(64, 1, 16)
230+
check_cuda(64, 7, 16)
231+
check_cuda(64, 1, 32)
232+
check_cuda(64, 7, 32)
226233

227234

228235
@tvm.testing.requires_gpu

0 commit comments

Comments
 (0)