Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 9 additions & 6 deletions src/target/source/codegen_c.cc
Original file line number Diff line number Diff line change
Expand Up @@ -802,15 +802,16 @@ void CodeGenC::VisitExpr_(const LetNode* op, std::ostream& os) { // NOLINT(*)
}

void CodeGenC::VisitExpr_(const RampNode* op, std::ostream& os) { // NOLINT(*)
// constraint of current logic
ICHECK_EQ(op->base.dtype(), DataType::Int(32));
os << "((int" << op->lanes << ")(";
// NOTE: C have comma expression so cannot use (int2)(v0, v1)
// instead should use int2(v0, v1)
PrintType(op->dtype, os);
os << "(";
for (int i = 0; i < op->lanes; i++) {
os << "(" << PrintExpr(op->base) << ")"
<< "+(" << PrintExpr(op->stride) << "*" << i << ")";
if (i != op->lanes - 1) os << ", ";
}
os << "))";
os << ")";
}

void CodeGenC::VisitExpr_(const ShuffleNode* op, std::ostream& os) {
Expand Down Expand Up @@ -999,9 +1000,11 @@ void CodeGenC::PrintVecElemLoadExpr(DataType t, int i, const std::string& value,
}

if (i == 0) {
os << "((";
// NOTE: C have comma expression so cannot use (float2)(v0, v1)
// instead should use float2(v0, v1)
os << "(";
PrintType(t, os);
os << ")(";
os << "(";
}
os << value;
if (i != t.lanes() - 1) {
Expand Down
17 changes: 5 additions & 12 deletions src/target/source/codegen_metal.cc
Original file line number Diff line number Diff line change
Expand Up @@ -299,17 +299,6 @@ void CodeGenMetal::VisitExpr_(const BroadcastNode* op, std::ostream& os) { // N
os << ')';
}

void CodeGenMetal::VisitExpr_(const RampNode* op, std::ostream& os) { // NOLINT(*)
PrintType(op->dtype, os);
os << "(";
for (int i = 0; i < op->lanes; ++i) {
if (i != 0) os << ", ";
os << "(" << PrintExpr(op->base) << ")"
<< "+(" << PrintExpr(op->stride) << "*" << i << ")";
}
os << ')';
}

void CodeGenMetal::VisitExpr_(const CallNode* op, std::ostream& os) { // NOLINT(*)
if (op->op.same_as(builtin::reinterpret())) {
// generate as_type<TYPE>(ARG)
Expand Down Expand Up @@ -369,7 +358,11 @@ runtime::Module BuildMetal(IRModule mod, Target target) {
code << fsource;
}

return MetalModuleCreate(code.str(), fmt, ExtractFuncInfo(mod), source.str());
std::string code_str = code.str();
if (const auto* f = Registry::Get("tvm_callback_metal_postproc")) {
code_str = (*f)(code_str).operator std::string();
}
return MetalModuleCreate(code_str, fmt, ExtractFuncInfo(mod), source.str());
}

TVM_REGISTER_GLOBAL("target.build.metal").set_body_typed(BuildMetal);
Expand Down
4 changes: 2 additions & 2 deletions src/target/source/codegen_metal.h
Original file line number Diff line number Diff line change
Expand Up @@ -51,9 +51,9 @@ class CodeGenMetal final : public CodeGenC {
void PrintVecElemStore(const std::string& vec, DataType t, int i, const std::string& value) final;
// overload visitor
void VisitExpr_(const BroadcastNode* op, std::ostream& os) final; // NOLINT(*)
void VisitExpr_(const RampNode* op, std::ostream& os) final; // NOLINT(*)
void VisitExpr_(const CallNode* op, std::ostream& os) final; // NOLINT(*)
void VisitExpr_(const FloatImmNode* op, std::ostream& os) final;
void VisitExpr_(const FloatImmNode* op, std::ostream& os) final; // NOLINT(*)

// reuse parent's function.
using CodeGenC::PrintType;

Expand Down
37 changes: 37 additions & 0 deletions src/target/source/codegen_opencl.cc
Original file line number Diff line number Diff line change
Expand Up @@ -303,6 +303,31 @@ void CodeGenOpenCL::PrintVecStore(const BufferNode* buffer, DataType t, PrimExpr
stream << ");\n";
}

void CodeGenOpenCL::PrintVecElemLoadExpr(DataType t, int i, const std::string& value,
std::ostream& os) { // NOLINT(*)
ICHECK_GT(t.lanes(), 1);
if (t.bits() == 8 && (t.is_int() || t.is_uint())) {
if (i != 0) {
os << "|";
}
os << "((0x000000ff << " << i * 8 << ") & (" << value << " << " << i * 8 << "))";
return;
}
if (i == 0) {
// NOTE: opencl print things as (float2)(v0, v1)
os << "((";
PrintType(t, os);
os << ")(";
}
os << value;
if (i != t.lanes() - 1) {
os << ",";
} else {
os << "))";
}
return;
}

void CodeGenOpenCL::PrintStorageSync(const CallNode* op) {
const std::string& sync = op->args[0].as<StringImmNode>()->value;
if (sync == "warp") {
Expand Down Expand Up @@ -490,6 +515,18 @@ void CodeGenOpenCL::VisitExpr_(const BroadcastNode* op, std::ostream& os) { //
os << "))";
}

void CodeGenOpenCL::VisitExpr_(const RampNode* op, std::ostream& os) { // NOLINT(*)
os << "((";
PrintType(op->dtype, os);
os << ")(";
for (int i = 0; i < op->lanes; i++) {
os << "(" << PrintExpr(op->base) << ")"
<< "+(" << PrintExpr(op->stride) << "*" << i << ")";
if (i != op->lanes - 1) os << ", ";
}
os << "))";
}

void CodeGenOpenCL::VisitExpr_(const FloatImmNode* op, std::ostream& os) { // NOLINT(*)
if (std::isinf(op->value)) {
if (op->value < 0) {
Expand Down
3 changes: 3 additions & 0 deletions src/target/source/codegen_opencl.h
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,8 @@ class CodeGenOpenCL final : public CodeGenC {
std::string GetVecLoad(DataType t, const BufferNode* buffer, PrimExpr base) final;
void PrintVecStore(const BufferNode* buffer, DataType t, PrimExpr base,
const std::string& value) final; // NOLINT(*)
void PrintVecElemLoadExpr(DataType t, int i, const std::string& value,
std::ostream& os) final; // NOLINT(*)
// the address of load/store
void PrintVecAddr(const BufferNode* buffer, DataType t, PrimExpr base,
std::ostream& os); // NOLINT(*)
Expand All @@ -62,6 +64,7 @@ class CodeGenOpenCL final : public CodeGenC {
// overload visitor
void VisitStmt_(const AllocateNode* op) final; // NOLINT(*)
void VisitExpr_(const BroadcastNode* op, std::ostream& os) final; // NOLINT(*)
void VisitExpr_(const RampNode* op, std::ostream& os) final; // NOLINT(*)
void VisitExpr_(const CallNode* op, std::ostream& os) final; // NOLINT(*)
void VisitExpr_(const CastNode* op, std::ostream& os) final; // NOLINT(*)
void VisitExpr_(const FloatImmNode* op, std::ostream& os) final; // NOLINT(*)
Expand Down
27 changes: 27 additions & 0 deletions tests/python/unittest/test_target_codegen_metal.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@
from tvm.contrib.nvcc import have_fp16, have_int8, have_bf16
from tvm.contrib import nvcc
import tvm.testing
import tvm.script
from tvm.script import tir as T

tx = te.thread_axis("threadIdx.x")
bx = te.thread_axis("blockIdx.x")
Expand Down Expand Up @@ -54,6 +56,31 @@ def check_inf_nan(dev, n, value, dtype):
check_inf_nan(dev, 1, float("nan"), "float16")


@tvm.testing.requires_gpu
@tvm.testing.requires_metal
def test_unaligned_vectorize():
@tvm.script.ir_module
class IRModule:
@T.prim_func
def main(A: T.Buffer((2, 3), "float32"), B: T.Buffer((6,), "float32")):
T.func_attr({"global_symbol": "main"})
for i0_1 in T.thread_binding(3, thread="threadIdx.x"):
for i0_0 in T.vectorized(2):
with T.block("block"):
vi0 = T.axis.spatial(6, i0_0 * 3 + i0_1)
B[vi0] = A[vi0 // 3, vi0 % 3]

target = "metal"
dev = tvm.metal()

a = (np.arange(6).reshape(2, 3)).astype("float32")
a_nd = tvm.nd.array(a, dev)
b_nd = tvm.nd.empty((6,), "float32", dev)
f = tvm.build(IRModule, target=target)
f(a_nd, b_nd)
np.testing.assert_allclose(b_nd.numpy(), a.reshape(6), atol=1e-5, rtol=1e-5)


@tvm.testing.requires_gpu
@tvm.testing.requires_metal
def test_metal_erf():
Expand Down