Skip to content

Commit df24393

Browse files
tqchenjunrushao
andauthored
[TEST] Move llvm import test away from minimum test (#9171)
* [TEST] Move llvm import test away from minimum test The llvm import relies on the same system clang and llvm version and may be tricky to get right on all platforms. Given this is an advanced feature, and there has been some problems in windows(could relates to clang version update). This PR moves away from minimum tests. * Update test_minimal_target_codegen_llvm.py * Update test_target_codegen_llvm.py Co-authored-by: Junru Shao <[email protected]>
1 parent 659f3b7 commit df24393

File tree

2 files changed

+42
-42
lines changed

2 files changed

+42
-42
lines changed

tests/python/all-platform-minimal-test/test_minimal_target_codegen_llvm.py

Lines changed: 1 addition & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
import tvm.testing
2121
from tvm import te
2222
from tvm import topi
23-
from tvm.contrib import utils, clang
23+
from tvm.contrib import utils
2424
import numpy as np
2525
import ctypes
2626
import math
@@ -65,43 +65,3 @@ def check_llvm():
6565
tvm.testing.assert_allclose(c.numpy(), a.numpy() + b.numpy())
6666

6767
check_llvm()
68-
69-
70-
@tvm.testing.requires_llvm
71-
def test_llvm_import():
72-
"""all-platform-minimal-test: check shell dependent clang behavior."""
73-
# extern "C" is necessary to get the correct signature
74-
cc_code = """
75-
extern "C" float my_add(float x, float y) {
76-
return x + y;
77-
}
78-
"""
79-
n = 10
80-
A = te.placeholder((n,), name="A")
81-
B = te.compute(
82-
(n,), lambda *i: tvm.tir.call_pure_extern("float32", "my_add", A(*i), 1.0), name="B"
83-
)
84-
85-
def check_llvm(use_file):
86-
if not clang.find_clang(required=False):
87-
print("skip because clang is not available")
88-
return
89-
temp = utils.tempdir()
90-
ll_path = temp.relpath("temp.ll")
91-
ll_code = clang.create_llvm(cc_code, output=ll_path)
92-
s = te.create_schedule(B.op)
93-
if use_file:
94-
s[B].pragma(s[B].op.axis[0], "import_llvm", ll_path)
95-
else:
96-
s[B].pragma(s[B].op.axis[0], "import_llvm", ll_code)
97-
# BUILD and invoke the kernel.
98-
f = tvm.build(s, [A, B], "llvm")
99-
dev = tvm.cpu(0)
100-
# launch the kernel.
101-
a = tvm.nd.array(np.random.uniform(size=n).astype(A.dtype), dev)
102-
b = tvm.nd.array(np.random.uniform(size=n).astype(B.dtype), dev)
103-
f(a, b)
104-
tvm.testing.assert_allclose(b.numpy(), a.numpy() + 1.0)
105-
106-
check_llvm(use_file=True)
107-
check_llvm(use_file=False)

tests/python/unittest/test_target_codegen_llvm.py

Lines changed: 41 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323
import tvm.testing
2424
from tvm import te
2525
from tvm import topi
26-
from tvm.contrib import utils
26+
from tvm.contrib import utils, clang
2727
import numpy as np
2828
import ctypes
2929
import math
@@ -845,5 +845,45 @@ def make_call_extern(caller, callee):
845845
assert matches == sorted(matches)
846846

847847

848+
@tvm.testing.requires_llvm
849+
def test_llvm_import():
850+
"""all-platform-minimal-test: check shell dependent clang behavior."""
851+
# extern "C" is necessary to get the correct signature
852+
cc_code = """
853+
extern "C" float my_add(float x, float y) {
854+
return x + y;
855+
}
856+
"""
857+
n = 10
858+
A = te.placeholder((n,), name="A")
859+
B = te.compute(
860+
(n,), lambda *i: tvm.tir.call_pure_extern("float32", "my_add", A(*i), 1.0), name="B"
861+
)
862+
863+
def check_llvm(use_file):
864+
if not clang.find_clang(required=False):
865+
print("skip because clang is not available")
866+
return
867+
temp = utils.tempdir()
868+
ll_path = temp.relpath("temp.ll")
869+
ll_code = clang.create_llvm(cc_code, output=ll_path)
870+
s = te.create_schedule(B.op)
871+
if use_file:
872+
s[B].pragma(s[B].op.axis[0], "import_llvm", ll_path)
873+
else:
874+
s[B].pragma(s[B].op.axis[0], "import_llvm", ll_code)
875+
# BUILD and invoke the kernel.
876+
f = tvm.build(s, [A, B], "llvm")
877+
dev = tvm.cpu(0)
878+
# launch the kernel.
879+
a = tvm.nd.array(np.random.uniform(size=n).astype(A.dtype), dev)
880+
b = tvm.nd.array(np.random.uniform(size=n).astype(B.dtype), dev)
881+
f(a, b)
882+
tvm.testing.assert_allclose(b.numpy(), a.numpy() + 1.0)
883+
884+
check_llvm(use_file=True)
885+
check_llvm(use_file=False)
886+
887+
848888
if __name__ == "__main__":
849889
sys.exit(pytest.main([__file__] + sys.argv[1:]))

0 commit comments

Comments
 (0)