Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions include/tvm/ir.h
Original file line number Diff line number Diff line change
Expand Up @@ -411,6 +411,14 @@ constexpr const char* tvm_call_packed_lowered = "tvm_call_packed_lowered";
* }
*/
constexpr const char* tvm_storage_sync = "tvm_storage_sync";
/*!
* \brief See pseudo code
*
* Type tvm_warp_shuffle(Type value, int warp_id) {
*   return (value passed in by the lane indicated by warp_id);
* }
*/
constexpr const char* tvm_warp_shuffle = "tvm_warp_shuffle";
/*!
* \brief Initialize the global barrier.
* Call this at beginning of kernel that need global barrier.
Expand Down
9 changes: 9 additions & 0 deletions include/tvm/ir_pass.h
Original file line number Diff line number Diff line change
Expand Up @@ -407,6 +407,15 @@ LoweredFunc ThreadSync(LoweredFunc stmt, std::string storage_scope);
*/
LoweredFunc LowerThreadAllreduce(LoweredFunc f, int warp_size);

/*!
* \brief Lower warp memory in stmt.
* \param f The device function to be lowered.
* \param warp_size The size of a warp, within which no synchronization is needed.
*        This pass only takes effect if warp_size is greater than one.
* \return Transformed function.
*/
LoweredFunc LowerWarpMemory(LoweredFunc f, int warp_size);

/*!
* \brief Lower packed function call.
* \param f The function to be lowered.
Expand Down
4 changes: 4 additions & 0 deletions python/tvm/build_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -450,6 +450,10 @@ def build(sch,
else:
raise ValueError("unknown function type %d" % func.func_type)

for i, func in enumerate(fdevice):
warp_size = target.thread_warp_size
fdevice[i] = ir_pass.LowerWarpMemory(func, warp_size)

if "gpu" in target.keys and not fdevice:
warnings.warn(
"Specified target %s, but cannot find device code, did you do bind?" % target)
Expand Down
1 change: 1 addition & 0 deletions src/api/api_pass.cc
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,7 @@ REGISTER_PASS2(SplitPipeline);
REGISTER_PASS2(LiftAttrScope);
REGISTER_PASS1(NarrowChannelAccess);
REGISTER_PASS2(LowerThreadAllreduce);
REGISTER_PASS2(LowerWarpMemory);
REGISTER_PASS2(LowerIntrin);
REGISTER_PASS1(LowerTVMBuiltin);
REGISTER_PASS1(CombineContextCall);
Expand Down
14 changes: 3 additions & 11 deletions src/codegen/codegen_c.cc
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
#include <iomanip>
#include <cctype>
#include "./codegen_c.h"
#include "../pass/ir_util.h"
#include "../arithmetic/compute_expr.h"

namespace tvm {
Expand Down Expand Up @@ -544,15 +545,6 @@ void CodeGenC::PrintVecBinaryOp(
}
}

// Pattern-match `index` against Ramp(base, stride=1, lanes) — the index
// shape produced by a contiguous (unit-stride) vector load/store.
// On success, stores the ramp's base expression into *base and returns
// true; returns false when `index` is not a Ramp or its stride is not one.
// A unit-stride Ramp whose lane count differs from `lanes` is treated as
// an internal error (CHECK failure) rather than a pattern mismatch.
inline bool TryGetRamp1Base(Expr index, int lanes, Expr *base) {
const Ramp* r = index.as<Ramp>();
if (!r) return false;
if (!is_one(r->stride)) return false;
CHECK_EQ(r->lanes, lanes);
*base = r->base;
return true;
}

void CodeGenC::VisitExpr_(const Load* op, std::ostream& os) { // NOLINT(*)
int lanes = op->type.lanes();
// declare type.
Expand All @@ -563,7 +555,7 @@ void CodeGenC::VisitExpr_(const Load* op, std::ostream& os) { // NOLINT(*)
CHECK(is_one(op->predicate))
<< "predicated load is not supported";
Expr base;
if (TryGetRamp1Base(op->index, op->type.lanes(), &base)) {
if (GetRamp1Base(op->index, op->type.lanes(), &base)) {
std::string ref = GetVecLoad(op->type, op->buffer_var.get(), base);
os << ref;
} else {
Expand Down Expand Up @@ -617,7 +609,7 @@ void CodeGenC::VisitStmt_(const Store* op) {
CHECK(is_one(op->predicate))
<< "Predicated store is not supported";
Expr base;
if (TryGetRamp1Base(op->index, t.lanes(), &base)) {
if (GetRamp1Base(op->index, t.lanes(), &base)) {
std::string value = this->PrintExpr(op->value);
this->PrintVecStore(op->buffer_var.get(), t, base, value);
} else {
Expand Down
10 changes: 10 additions & 0 deletions src/codegen/intrin_rule_cuda.cc
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,12 @@ struct CUDAPopcount {
}
};

// Dispatcher that maps the tvm_warp_shuffle intrinsic to the CUDA
// __shfl builtin when generating CUDA device code (registered below
// under "tvm.intrin.rule.cuda.tvm_warp_shuffle").
struct CUDAShuffle {
// Returns the CUDA builtin name to emit; neither the value type nor the
// TVM intrinsic name affects the choice.
// NOTE(review): __shfl (without _sync) is the pre-CUDA-9 form; newer
// toolchains expect __shfl_sync with an explicit member mask — confirm
// the minimum CUDA version this targets.
std::string operator()(Type t, std::string name) const {
return "__shfl";
}
};

TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.exp")
.set_body(DispatchExtern<CUDAFastMath>);

Expand All @@ -67,6 +73,10 @@ TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.pow")
TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.popcount")
.set_body(DispatchExtern<CUDAPopcount>);

TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.tvm_warp_shuffle")
.set_body(DispatchExtern<CUDAShuffle>);


} // namespace intrin
} // namespace codegen
} // namespace tvm
11 changes: 11 additions & 0 deletions src/codegen/intrin_rule_opencl.cc
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,17 @@ TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.pow")
TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.popcount")
.set_body(DispatchExtern<Direct>);

// There is no warp shuffle instruction in standard OpenCL
// When shuffle is used, we assume it is intel's shuffle extension
// Dispatcher that maps the tvm_warp_shuffle intrinsic to Intel's OpenCL
// subgroup shuffle extension function (standard OpenCL has no warp
// shuffle, so the Intel extension is assumed to be available).
struct IntelShuffle {
// Returns the extension function name to emit; neither the value type
// nor the TVM intrinsic name affects the choice.
std::string operator()(Type t, std::string name) const {
return "intel_sub_group_shuffle";
}
};

TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.tvm_warp_shuffle")
.set_body(DispatchExtern<IntelShuffle>);

} // namespace intrin
} // namespace codegen
} // namespace tvm
17 changes: 17 additions & 0 deletions src/pass/ir_util.h
Original file line number Diff line number Diff line change
Expand Up @@ -161,6 +161,23 @@ inline int GetTempAllocaAlignment(Type type, int32_t const_size) {
}
return align;
}

/*!
* \brief Pattern match index to Ramp with stride=1
* This is a common pattern in continuous memory load.
* \param index The index formula
* \param lanes number of lanes in the ramp
* \param base The result base.
* \return true if pattern match success and store the base to base.
*/
inline bool GetRamp1Base(Expr index, int lanes, Expr *base) {
  // A contiguous vector access shows up as Ramp(base, /*stride=*/1, lanes).
  const Ramp* ramp = index.as<Ramp>();
  if (ramp == nullptr || !is_one(ramp->stride)) {
    return false;
  }
  // A lane-count disagreement is an internal inconsistency, not a pattern
  // mismatch, so fail hard instead of returning false.
  CHECK_EQ(ramp->lanes, lanes);
  *base = ramp->base;
  return true;
}
} // namespace ir
} // namespace tvm
#endif // TVM_PASS_IR_UTIL_H_
Loading