Skip to content

Commit 9e1a5ec

Browse files
authored
[RUNTIME] Enable OpenCL (#17)
1 parent e9ff9a8 commit 9e1a5ec

File tree

13 files changed

+415
-18
lines changed

13 files changed

+415
-18
lines changed

Makefile

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ endif
2626
export LDFLAGS = -pthread -lm
2727
export CFLAGS = -std=c++11 -Wall -O2\
2828
-Iinclude -Idmlc-core/include -IHalideIR/src -fPIC
29+
export FRAMEWORKS=
2930

3031
ifneq ($(ADD_CFLAGS), NONE)
3132
CFLAGS += $(ADD_CFLAGS)
@@ -43,6 +44,20 @@ else
4344
CFLAGS += -DTVM_CUDA_RUNTIME=0
4445
endif
4546

47+
48+
ifeq ($(USE_OPENCL), 1)
49+
CFLAGS += -DTVM_OPENCL_RUNTIME=1
50+
UNAME_S := $(shell uname -s)
51+
ifeq ($(UNAME_S), Darwin)
52+
FRAMEWORKS += -framework OpenCL
53+
else
54+
LDFLAGS += -lOpenCL
55+
endif
56+
else
57+
CFLAGS += -DTVM_OPENCL_RUNTIME=0
58+
endif
59+
60+
4661
include tests/cpp/unittest.mk
4762

4863
test: $(TEST)
@@ -59,7 +74,7 @@ lib/libtvm.a: $(ALL_DEP)
5974

6075
lib/libtvm.so: $(ALL_DEP)
6176
@mkdir -p $(@D)
62-
$(CXX) $(CFLAGS) -shared -o $@ $(filter %.o %.a, $^) $(LDFLAGS)
77+
$(CXX) $(CFLAGS) $(FRAMEWORKS) -shared -o $@ $(filter %.o %.a, $^) $(LDFLAGS)
6378

6479
$(LIB_HALIDE_IR): LIBHALIDEIR
6580

include/tvm/c_runtime_api.h

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -150,6 +150,23 @@ typedef TVMArray* TVMArrayHandle;
150150
*/
151151
TVM_DLL const char *TVMGetLastError(void);
152152

153+
/*!
154+
* \brief Initialize certain type of devices, this may
155+
* not be necessary for all device types. But is needed for OpenCL.
156+
*
157+
* \param dev_mask The device mask of device type to be initialized
158+
* \param option_keys Additional option keys to pass.
159+
* \param option_vals Additional option values to pass
160+
* \param num_options Number of options to be passed into it.
161+
* \param out_code 1: success, 0: already initialized
162+
* \return Whether the function is successful.
163+
*/
164+
TVM_DLL int TVMDeviceInit(int dev_mask,
165+
const char** option_keys,
166+
const char** option_vals,
167+
int num_options,
168+
int *out_code);
169+
153170
/*!
154171
* \brief Whether the specified context is enabled.
155172
*

make/config.mk

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,9 @@ ADD_CFLAGS =
3737
# whether use CUDA during compile
3838
USE_CUDA = 1
3939

40+
# whether use OpenCL during compile
41+
USE_OPENCL = 0
42+
4043
# add the path to CUDA library to link and compile flag
4144
# if you have already add them to environment variable, leave it as NONE
4245
# USE_CUDA_PATH = /usr/local/cuda

python/tvm/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
from . import schedule
1313

1414
from . import ndarray as nd
15-
from .ndarray import cpu, gpu, opencl
15+
from .ndarray import cpu, gpu, opencl, init_opencl
1616

1717
from ._base import TVMError
1818
from .function import *

python/tvm/_ctypes/_runtime_api.py

Lines changed: 25 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
import numpy as np
88

99
from .._base import _LIB
10-
from .._base import c_array
10+
from .._base import c_array, c_str
1111
from .._base import check_call
1212

1313

@@ -182,6 +182,30 @@ def sync(ctx):
182182
check_call(_LIB.TVMSynchronize(ctx, None))
183183

184184

185+
def init_opencl(**kwargs):
186+
"""Initialize the opencl with the options.
187+
188+
Parameters
189+
----------
190+
kwargs : dict
191+
The options
192+
"""
193+
keys = []
194+
vals = []
195+
for k, v in kwargs.items():
196+
keys.append(c_str(k))
197+
vals.append(c_str(v))
198+
dev_mask = ctypes.c_int(4)
199+
out_code = ctypes.c_int()
200+
check_call(_LIB.TVMDeviceInit(
201+
dev_mask,
202+
c_array(ctypes.c_char_p, keys),
203+
c_array(ctypes.c_char_p, vals),
204+
ctypes.c_int(len(keys)),
205+
ctypes.byref(out_code)))
206+
return out_code.value != 0
207+
208+
185209
class NDArrayBase(object):
186210
"""A simple Device/CPU Array object in runtime."""
187211
__slots__ = ["handle"]

python/tvm/ndarray.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
from ._ctypes._runtime_api import TVMContext, TVMDataType, NDArrayBase
1010
from ._ctypes._runtime_api import cpu, gpu, opencl, empty, sync
1111
from ._ctypes._runtime_api import _init_runtime_module
12+
from ._ctypes._runtime_api import init_opencl
1213

1314

1415
class NDArray(NDArrayBase):

python/tvm/schedule.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ def __getitem__(self, k):
2424
k = k.op
2525
if not isinstance(k, _tensor.Operation):
2626
raise ValueError("Expect schedule key to be Tensor or Operation")
27-
if not k in self.stage_map:
27+
if k not in self.stage_map:
2828
raise ValueError("Cannot find the operation %s in schedule" % (str(k)))
2929
return self.stage_map[k]
3030

src/runtime/c_runtime_api.cc

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,23 @@ inline size_t GetDataAlignment(TVMArray* arr) {
6464

6565
using namespace tvm::runtime;
6666

67+
int TVMDeviceInit(int dev_mask,
68+
const char** option_keys,
69+
const char** option_vals,
70+
int num_options,
71+
int* out_code) {
72+
API_BEGIN();
73+
*out_code = 1;
74+
switch (dev_mask) {
75+
case kOpenCL: {
76+
*out_code = DeviceInit<kOpenCL>(option_keys, option_vals, num_options);
77+
break;
78+
}
79+
default: break;
80+
}
81+
API_END();
82+
}
83+
6784
int TVMContextEnabled(TVMContext ctx,
6885
int* out_enabled) {
6986
API_BEGIN();

src/runtime/device_api.h

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
/*!
22
* Copyright (c) 2016 by Contributors
3-
* \file device_api.hx
3+
* \file device_api.h
44
* \brief Device specific API
55
*/
66
#ifndef TVM_RUNTIME_DEVICE_API_H_
@@ -11,6 +11,21 @@
1111

1212
namespace tvm {
1313
namespace runtime {
14+
/*!
15+
* \brief Initialize the device.
16+
* \param option_keys Additional option keys to pass.
17+
* \param option_vals Additional option values to pass
18+
* \param num_options Number of options to be passed into it.
19+
* \return 0 if success, 1: if already initialized
20+
* \tparam xpu The device mask.
21+
*/
22+
template<TVMDeviceMask xpu>
23+
inline bool DeviceInit(const char** option_keys,
24+
const char** option_vals,
25+
int num_options) {
26+
return true;
27+
}
28+
1429
/*!
1530
* \brief Whether ctx is enabled.
1631
* \param ctx The device context to perform operation.
@@ -93,7 +108,8 @@ inline void StreamSync(TVMContext ctx, TVMStreamHandle stream);
93108
} // namespace runtime
94109
} // namespace tvm
95110

96-
#include "./device_api_gpu.h"
97111
#include "./device_api_cpu.h"
112+
#include "./device_api_gpu.h"
113+
#include "./device_api_opencl.h"
98114

99115
#endif // TVM_RUNTIME_DEVICE_API_H_

src/runtime/device_api_gpu.h

Lines changed: 1 addition & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
/*!
22
* Copyright (c) 2016 by Contributors
3-
* \file ctxice_api_gpu.h
3+
* \file device_api_gpu.h
44
* \brief GPU specific API
55
*/
66
#ifndef TVM_RUNTIME_DEVICE_API_GPU_H_
@@ -14,15 +14,6 @@
1414

1515
namespace tvm {
1616
namespace runtime {
17-
/*!
18-
* \brief Check CUDA error.
19-
* \param msg Message to print if an error occured.
20-
*/
21-
#define CHECK_CUDA_ERROR(msg) \
22-
{ \
23-
cudaError_t e = cudaGetLastError(); \
24-
CHECK_EQ(e, cudaSuccess) << (msg) << " CUDA: " << cudaGetErrorString(e); \
25-
}
2617

2718
/*!
2819
* \brief Protected CUDA call.

0 commit comments

Comments
 (0)