Skip to content

Commit 243dc9a

Browse files
committed
Annotate internal functions with __device__ instead of __global__
Calling a function annotated with `__global__` can be done from the GPU (see https://stackoverflow.com/a/39448797), but requires a different calling convention.
1 parent 5f1ec7c commit 243dc9a

File tree

1 file changed

+7
-1
lines changed

1 file changed

+7
-1
lines changed

src/target/source/codegen_cuda.cc

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ void CodeGenCUDA::Init(bool output_ssa) {
4949
ICHECK_EQ(vid_global_barrier_state_, runtime::symbol::tvm_global_barrier_state);
5050
}
5151

52-
void CodeGenCUDA::PrintFuncPrefix(std::ostream& os) { os << "extern \"C\" __global__ "; }
52+
void CodeGenCUDA::PrintFuncPrefix(std::ostream& os) { os << "extern \"C\" "; }
5353

5454
class ThreadIdxExtractor : public tir::StmtVisitor {
5555
private:
@@ -76,6 +76,12 @@ class ThreadIdxExtractor : public tir::StmtVisitor {
7676
};
7777

7878
void CodeGenCUDA::PrintExtraAttrs(const PrimFunc& f, std::ostream& os) {
79+
if (f->GetAttr<String>(tvm::attr::kGlobalSymbol)) {
80+
os << " __global__ ";
81+
} else {
82+
os << " __device__ ";
83+
}
84+
7985
ThreadIdxExtractor extractor;
8086
extractor(f->body);
8187
arith::Analyzer analyzer;

0 commit comments

Comments
 (0)