This repository was archived by the owner on Aug 5, 2022. It is now read-only.

Commit fba2c64

ankalinin authored and ptbuilder committed

task-mkldnn-integration: MKLDNN to Caffe integration: remove NULLs. MKLDNN to Caffe integration, initial stage: scoring only.

2 parents: 531cd5d + 2825c19

22 files changed (+3690, −9 lines)

CMakeLists.txt
Lines changed: 1 addition & 0 deletions

@@ -30,6 +30,7 @@ caffe_option(CPU_ONLY "Build Caffe without CUDA support" OFF) # TODO: rename to
 caffe_option(USE_OPENMP "Build Caffe with OpenMP support" ON )
 caffe_option(USE_CUDNN "Build Caffe with cuDNN library support" ON IF NOT CPU_ONLY)
 caffe_option(USE_MKL2017_AS_DEFAULT_ENGINE "Use MKL2017 primitives for supported layers" OFF)
+caffe_option(USE_MKLDNN_AS_DEFAULT_ENGINE "Use MKL-DNN primitives for supported layers" OFF)
 caffe_option(BUILD_SHARED_LIBS "Build shared libraries" ON)
 caffe_option(BUILD_python "Build Python wrapper" ON)
 set(python_version "2" CACHE STRING "Specify which Python version to use")

Makefile
Lines changed: 22 additions & 6 deletions

@@ -403,10 +403,23 @@ ifeq ($(WITH_PYTHON_LAYER), 1)
 	LIBRARIES += $(PYTHON_LIBRARIES)
 endif
 
+# MKLDNN configuration
+# detect support for mkl-dnn primitives
+MKLDNN_LDFLAGS=
+MKLDNN_INCLUDE ?= $(MKLDNNROOT)/include
+ifneq ("$(wildcard $(MKLDNN_INCLUDE)/mkldnn.hpp)","")
+	CXXFLAGS += -DMKLDNN_SUPPORTED -std=c++11
+	ifeq ($(USE_MKLDNN_AS_DEFAULT_ENGINE), 1)
+		CXXFLAGS += -DUSE_MKLDNN_AS_DEFAULT_ENGINE
+	endif
+	LIBRARIES += mkldnn
+	MKLDNN_LDFLAGS+=-L$(MKLDNNROOT)/lib -Wl,-rpath,$(MKLDNNROOT)/lib
+endif
+
 # BLAS configuration (default = MKL)
+MKL_LDFLAGS=
 MKL_EXTERNAL := 0
 BLAS ?= mkl
-MKL_LDFLAGS=
 ifeq ($(BLAS), mkl)
 # MKL
 ICC_ON=0

@@ -463,6 +476,9 @@ endif
 INCLUDE_DIRS += $(BLAS_INCLUDE)
 LIBRARY_DIRS += $(BLAS_LIB)
 
+INCLUDE_DIRS += $(MKLDNN_INCLUDE)
+LIBRARY_DIRS += $(MKLDNN_LIB)
+
 LIBRARY_DIRS += $(LIB_BUILD_DIR)
 
 # Automatic dependency generation (nvcc is handled separately)

@@ -672,7 +688,7 @@ $(ALL_BUILD_DIRS): | $(BUILD_DIR_LINK)
 
 $(DYNAMIC_NAME): $(OBJS) | $(LIB_BUILD_DIR)
 	@ echo LD -o $@
-	$(Q)$(CXX) -shared -o $@ $(OBJS) $(VERSIONFLAGS) $(LINKFLAGS) $(MKL_LDFLAGS) $(CXX_HARDENING_FLAGS) $(LINKER_SHARED_HARDENING_FLAGS) $(LDFLAGS)
+	$(Q)$(CXX) -shared -o $@ $(OBJS) $(VERSIONFLAGS) $(LINKFLAGS) $(MKL_LDFLAGS) $(MKLDNN_LDFLAGS) $(CXX_HARDENING_FLAGS) $(LINKER_SHARED_HARDENING_FLAGS) $(LDFLAGS)
 	@ cd $(BUILD_DIR)/lib; rm -f $(DYNAMIC_NAME_SHORT); ln -s $(DYNAMIC_VERSIONED_NAME_SHORT) $(DYNAMIC_NAME_SHORT)
 
 $(STATIC_NAME): $(OBJS) | $(LIB_BUILD_DIR)

@@ -704,7 +720,7 @@ $(TEST_ALL_BIN): $(TEST_MAIN_SRC) $(TEST_OBJS) $(GTEST_OBJS) \
 	| $(DYNAMIC_NAME) $(TEST_BIN_DIR)
 	@ echo CXX/LD -o $@ $<
 	$(Q)$(CXX) $(TEST_MAIN_SRC) $(TEST_OBJS) $(GTEST_OBJS) \
-		-o $@ $(LINKFLAGS) $(MKL_LDFLAGS) $(CXX_HARDENING_FLAGS) $(LINKER_EXEC_HARDENING_FLAGS) $(LDFLAGS) -l$(LIBRARY_NAME) -Wl,-rpath,$(ORIGIN)/../lib
+		-o $@ $(LINKFLAGS) $(MKL_LDFLAGS) $(MKLDNN_LDFLAGS) $(CXX_HARDENING_FLAGS) $(LINKER_EXEC_HARDENING_FLAGS) $(LDFLAGS) -l$(LIBRARY_NAME) -Wl,-rpath,$(ORIGIN)/../lib
 
 $(TEST_CU_BINS): $(TEST_BIN_DIR)/%.testbin: $(TEST_CU_BUILD_DIR)/%.o \
 	$(GTEST_OBJS) | $(DYNAMIC_NAME) $(TEST_BIN_DIR)

@@ -716,7 +732,7 @@ $(TEST_CXX_BINS): $(TEST_BIN_DIR)/%.testbin: $(TEST_CXX_BUILD_DIR)/%.o \
 	$(GTEST_OBJS) | $(DYNAMIC_NAME) $(TEST_BIN_DIR)
 	@ echo LD $<
 	$(Q)$(CXX) $(TEST_MAIN_SRC) $< $(GTEST_OBJS) \
-		-o $@ $(LINKFLAGS) $(MKL_LDFLAGS) $(CXX_HARDENING_FLAGS) $(LINKER_EXEC_HARDENING_FLAGS) $(LDFLAGS) -l$(LIBRARY_NAME) -Wl,-rpath,$(ORIGIN)/../lib
+		-o $@ $(LINKFLAGS) $(MKL_LDFLAGS) $(MKLDNN_LDFLAGS) $(CXX_HARDENING_FLAGS) $(LINKER_EXEC_HARDENING_FLAGS) $(LDFLAGS) -l$(LIBRARY_NAME) -Wl,-rpath,$(ORIGIN)/../lib
 
 # Target for extension-less symlinks to tool binaries with extension '*.bin'.
 $(TOOL_BUILD_DIR)/%: $(TOOL_BUILD_DIR)/%.bin | $(TOOL_BUILD_DIR)

@@ -725,12 +741,12 @@ $(TOOL_BUILD_DIR)/%: $(TOOL_BUILD_DIR)/%.bin | $(TOOL_BUILD_DIR)
 
 $(TOOL_BINS): %.bin : %.o | $(DYNAMIC_NAME)
 	@ echo CXX/LD -o $@
-	$(Q)$(CXX) $< -o $@ $(LINKFLAGS) $(MKL_LDFLAGS) $(CXX_HARDENING_FLAGS) $(LINKER_EXEC_HARDENING_FLAGS) -l$(LIBRARY_NAME) $(LDFLAGS) \
+	$(Q)$(CXX) $< -o $@ $(LINKFLAGS) $(MKL_LDFLAGS) $(MKLDNN_LDFLAGS) $(CXX_HARDENING_FLAGS) $(LINKER_EXEC_HARDENING_FLAGS) -l$(LIBRARY_NAME) $(LDFLAGS) \
 		-Wl,-rpath,$(ORIGIN)/../lib
 
 $(EXAMPLE_BINS): %.bin : %.o | $(DYNAMIC_NAME)
 	@ echo CXX/LD -o $@
-	$(Q)$(CXX) $< -o $@ $(LINKFLAGS) $(CXX_HARDENING_FLAGS) $(LINKER_EXEC_HARDENING_FLAGS) -l$(LIBRARY_NAME) $(LDFLAGS) \
+	$(Q)$(CXX) $< -o $@ $(LINKFLAGS) $(MKL_LDFLAGS) $(MKLDNN_LDFLAGS) $(CXX_HARDENING_FLAGS) $(LINKER_EXEC_HARDENING_FLAGS) -l$(LIBRARY_NAME) $(LDFLAGS) \
 		-Wl,-rpath,$(ORIGIN)/../../lib
 
 proto: $(PROTO_GEN_CC) $(PROTO_GEN_HEADER)

Makefile.config.example
Lines changed: 1 addition & 0 deletions

@@ -9,6 +9,7 @@
 
 # Uncomment to use optimized MKL2017 primitives by default for supported layers
 # USE_MKL2017_AS_DEFAULT_ENGINE := 1
+# USE_MKLDNN_AS_DEFAULT_ENGINE := 1
 
 # uncomment to disable IO dependencies and corresponding data layers
 # USE_OPENCV := 0

cmake/Dependencies.cmake
Lines changed: 20 additions & 0 deletions

@@ -184,6 +184,26 @@ if(BLAS STREQUAL "MKL" OR BLAS STREQUAL "mkl")
   endif()
 endif()
 
+# ---[ MKLDNN
+set(MKLDNN_INCLUDE_DIR "$ENV{MKLDNNROOT}/include/")
+if(EXISTS ${MKLDNN_INCLUDE_DIR}/mkldnn.hpp)
+  message(STATUS "Found MKLDNN")
+  set(MKLDNN_SUPPORTED ON)
+  set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DMKLDNN_SUPPORTED -std=c++11")
+  if(USE_MKLDNN_AS_DEFAULT_ENGINE)
+    message(STATUS "MKLDNN engine will be used as a default engine")
+    set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DUSE_MKLDNN_AS_DEFAULT_ENGINE")
+  endif()
+  list(APPEND Caffe_LINKER_LIBS "$ENV{MKLDNNROOT}/lib/libmkldnn.so")
+  include_directories(SYSTEM ${MKLDNN_INCLUDE_DIR})
+else()
+  message(STATUS "MKLDNN not found. MKLDNN_INCLUDE_DIR = ${MKLDNN_INCLUDE_DIR}")
+  set(MKLDNN_SUPPORTED OFF)
+  if(USE_MKLDNN_AS_DEFAULT_ENGINE)
+    message(WARNING "Flag USE_MKLDNN_AS_DEFAULT_ENGINE was set, but MKLDNN not found")
+  endif()
+endif()
+
 # ---[ Python
 if(BUILD_python)
   if(NOT "${python_version}" VERSION_LESS "3.0.0")

cmake/Summary.cmake
Lines changed: 1 addition & 0 deletions

@@ -126,6 +126,7 @@ function(caffe_print_configuration_summary)
 if(BLAS STREQUAL "MKL" OR BLAS STREQUAL "mkl")
   caffe_status("  MKL2017_SUPPORTED : " MKL2017_SUPPORTED AND USE_MKL2017_AS_DEFAULT_ENGINE THEN "ON, is a default engine" ELSE " ${MKL2017_SUPPORTED}")
 endif()
+caffe_status("  MKLDNN_SUPPORTED  : " MKLDNN_SUPPORTED AND USE_MKLDNN_AS_DEFAULT_ENGINE THEN "ON, is a default engine" ELSE " ${MKLDNN_SUPPORTED}")
 caffe_status("  Boost             : Yes (ver. ${Boost_MAJOR_VERSION}.${Boost_MINOR_VERSION})")
 caffe_status("  glog              : Yes")
 caffe_status("  gflags            : Yes")
caffe/mkldnn_layers.hpp (new file; name inferred from its header guard)
Lines changed: 250 additions & 0 deletions
#ifndef CAFFE_MKLDNN_LAYERS_HPP_
#define CAFFE_MKLDNN_LAYERS_HPP_

#include <string>
#include <vector>

#include "boost/enable_shared_from_this.hpp"
#include "caffe/blob.hpp"
#include "caffe/common.hpp"
#include "caffe/layers/base_conv_layer.hpp"
#include "caffe/layers/conv_layer.hpp"
#include "caffe/layers/inner_product_layer.hpp"
#include "caffe/layers/neuron_layer.hpp"
#include "caffe/proto/caffe.pb.h"
#include "caffe/mkldnn_memory.hpp"
#include "mkldnn.hpp"

using namespace mkldnn;

namespace caffe {

// ===== CpuEngine =======================================
// cpu_engine singleton
class CpuEngine
{
public:
    static CpuEngine & Instance()
    {
        // It's thread-safe in C++11.
        static CpuEngine myInstance;
        return myInstance;
    }
    CpuEngine(CpuEngine const&) = delete;             // Copy construct
    CpuEngine(CpuEngine&&) = delete;                  // Move construct
    CpuEngine& operator=(CpuEngine const&) = delete;  // Copy assign
    CpuEngine& operator=(CpuEngine &&) = delete;      // Move assign

    engine & get_engine() { return _cpu_engine; }
protected:
    CpuEngine() : _cpu_engine(engine::cpu, 0) {}
    ~CpuEngine() {}
private:
    engine _cpu_engine;
};

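The singleton above relies on C++11 "magic statics": initialization of a function-local static is guaranteed to run exactly once, even when several threads hit Instance() concurrently. That is why both build systems add -std=c++11 together with -DMKLDNN_SUPPORTED. A minimal usage sketch (assuming the header path inferred above):

#include "caffe/mkldnn_layers.hpp"  // path assumed from the header guard

void example() {
    // The first call constructs the CPU engine exactly once, thread-safely;
    // every later call returns a reference to that same engine, so all
    // MKLDNN primitives in the net are created against one shared engine.
    mkldnn::engine& eng = caffe::CpuEngine::Instance().get_engine();
    (void)eng;
}
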
// ===== MKLDNNConvolutionLayer =======================================
template <typename Dtype>
class MKLDNNConvolutionLayer : public ConvolutionLayer<Dtype> {
public:
    explicit MKLDNNConvolutionLayer(const LayerParameter& param);
    virtual ~MKLDNNConvolutionLayer() {}
protected:
    virtual void Forward_cpu(const vector<Blob<Dtype>*>& bottom, const vector<Blob<Dtype>*>& top);
    virtual void Forward_gpu(const vector<Blob<Dtype>*>& bottom, const vector<Blob<Dtype>*>& top);
    virtual void Backward_cpu(const vector<Blob<Dtype>*>& top, const vector<bool>& propagate_down
                            , const vector<Blob<Dtype>*>& bottom);
    virtual void Backward_gpu(const vector<Blob<Dtype>*>& top, const vector<bool>& propagate_down
                            , const vector<Blob<Dtype>*>& bottom);
    // Customized methods
    virtual void LayerSetUp(const vector<Blob<Dtype>*>& bottom, const vector<Blob<Dtype>*>& top);
    void Reshape(const vector<Blob<Dtype>*>& bottom, const vector<Blob<Dtype>*>& top);
private:
    virtual void compute_output_shape();
    virtual void init_properties(const vector<Blob<Dtype>*>& bottom, const vector<Blob<Dtype>*>& top);
    void InitConvolution(const vector<Blob<Dtype>*>& bottom, const vector<Blob<Dtype>*>& top);

    shared_ptr<MKLDNNData<Dtype> > fwd_bottom_data, fwd_top_data, fwd_weights_data, fwd_bias_data;
    shared_ptr<convolution::primitive_desc> convFwd_pd;

    shared_ptr<convolution> convFwd;
    shared_ptr<memory> input_memory, weights_memory, bias_memory, output_memory;

    uint32_t width_, height_, width_out_, height_out_, kernel_w_, kernel_h_, stride_w_, stride_h_;
    int pad_w_, pad_h_;
};

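compute_output_shape() is overridden here; its body is not part of this hunk, but the standard no-dilation Caffe arithmetic it has to reproduce over the members above is, as a sketch (out_dim is an illustrative helper, not a name from the commit):

#include <cstdint>

// Illustrative helper (not the committed body): standard no-dilation
// convolution output size, e.g.
//   height_out_ = out_dim(height_, kernel_h_, stride_h_, pad_h_);
//   width_out_  = out_dim(width_,  kernel_w_, stride_w_, pad_w_);
inline uint32_t out_dim(uint32_t in, uint32_t kernel, uint32_t stride, int32_t pad) {
    return (in + 2 * pad - kernel) / stride + 1;
}
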
// ===== MKLDNNInnerProductLayer =======================================
template <typename Dtype>
class MKLDNNInnerProductLayer : public InnerProductLayer<Dtype> {
public:
    explicit MKLDNNInnerProductLayer(const LayerParameter& param);
    virtual ~MKLDNNInnerProductLayer();
protected:
    virtual void Forward_cpu(const vector<Blob<Dtype>*>& bottom, const vector<Blob<Dtype>*>& top);
    virtual void Forward_gpu(const vector<Blob<Dtype>*>& bottom, const vector<Blob<Dtype>*>& top);
    virtual void Backward_cpu(const vector<Blob<Dtype>*>& top, const vector<bool>& propagate_down
                            , const vector<Blob<Dtype>*>& bottom);
    virtual void Backward_gpu(const vector<Blob<Dtype>*>& top, const vector<bool>& propagate_down
                            , const vector<Blob<Dtype>*>& bottom);
    // Customized methods
    virtual void LayerSetUp(const vector<Blob<Dtype>*>& bottom, const vector<Blob<Dtype>*>& top);
    void Reshape(const vector<Blob<Dtype>*>& bottom, const vector<Blob<Dtype>*>& top);
private:
    void InitInnerProduct(const vector<Blob<Dtype>*>& bottom, const vector<Blob<Dtype>*>& top);

    shared_ptr<MKLDNNData<Dtype> > fwd_bottom_data, fwd_top_data, fwd_weights_data, fwd_bias_data;
    shared_ptr<inner_product::primitive_desc> ipFwd_pd;

    shared_ptr<inner_product> ipFwd;
    shared_ptr<memory> input_memory, weights_memory, bias_memory, output_memory;

    uint32_t w_, h_;
};

/**
 * @brief Normalize the input in a local region across feature maps.
 */

// ===== MKLDNNLRNLayer =======================================
template <typename Dtype>
class MKLDNNLRNLayer : public Layer<Dtype> {
public:
    explicit MKLDNNLRNLayer(const LayerParameter& param)
        : Layer<Dtype>(param)
        , fwd_top_data(NULL)
        , fwd_bottom_data(NULL)
        , lrnFwd_pd(NULL)
        , lrnFwd(NULL)
        , input_memory(NULL)
        , output_memory(NULL)
        , scratch_(NULL) {}
    virtual ~MKLDNNLRNLayer() {}
protected:
    virtual void LayerSetUp(const vector<Blob<Dtype>*>& bottom, const vector<Blob<Dtype>*>& top);
    virtual void Reshape(const vector<Blob<Dtype>*>& bottom, const vector<Blob<Dtype>*>& top);
    virtual void Forward_cpu(const vector<Blob<Dtype>*>& bottom, const vector<Blob<Dtype>*>& top);
    virtual void Backward_cpu(const vector<Blob<Dtype>*>& top, const vector<bool>& propagate_down
                            , const vector<Blob<Dtype>*>& bottom);
    virtual void Forward_gpu(const vector<Blob<Dtype>*>& bottom, const vector<Blob<Dtype>*>& top);
    virtual void Backward_gpu(const vector<Blob<Dtype>*>& top, const vector<bool>& propagate_down
                            , const vector<Blob<Dtype>*>& bottom);

    virtual inline const char* type() const { return "LRN"; }
    virtual inline int ExactNumBottomBlobs() const { return 1; }
    virtual inline int ExactNumTopBlobs() const { return 1; }
private:
    void InitLRN(const vector<Blob<Dtype>*>& bottom, const vector<Blob<Dtype>*>& top);

    Dtype alpha_, beta_, k_;
    int size_, num_, width_, height_, channels_;

    shared_ptr<MKLDNNData<Dtype> > fwd_top_data, fwd_bottom_data;
    shared_ptr<lrn::primitive_desc> lrnFwd_pd;

    shared_ptr<lrn> lrnFwd;
    shared_ptr<memory> input_memory, output_memory;

    shared_ptr<memory> scratch_;
};

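For reference, the across-channels local response normalization that alpha_, beta_, k_, and size_ parameterize (standard Caffe/AlexNet semantics, with n = size_ and N the number of channels) is:

b_{x,y}^{i} = a_{x,y}^{i} \Bigg/ \left( k + \frac{\alpha}{n} \sum_{j=\max(0,\,i-n/2)}^{\min(N-1,\,i+n/2)} \left( a_{x,y}^{j} \right)^{2} \right)^{\beta}
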
// ===== MKLDNNPoolingLayer =======================================
template <typename Dtype>
class MKLDNNPoolingLayer : public Layer<Dtype> {
public:
    explicit MKLDNNPoolingLayer(const LayerParameter& param)
        : Layer<Dtype>(param)
        , fwd_top_data(NULL)
        , fwd_bottom_data(NULL)
        , poolingFwd_pd(NULL)
        , poolingFwd(NULL)
        , indices_memory(NULL)
        , input_memory(NULL)
        , output_memory(NULL)
        , indices_pd(NULL)
    {}
    ~MKLDNNPoolingLayer() {}
protected:
    virtual void LayerSetUp(const vector<Blob<Dtype>*>& bottom, const vector<Blob<Dtype>*>& top);
    virtual void Reshape(const vector<Blob<Dtype>*>& bottom, const vector<Blob<Dtype>*>& top);

    virtual inline const char* type() const { return "Pooling"; }
    virtual inline int ExactNumBottomBlobs() const { return 1; }
    virtual inline int MinTopBlobs() const { return 1; }
    // MAX POOL layers can output an extra top blob for the mask;
    // others can only output the pooled inputs.
    virtual inline int MaxTopBlobs() const {
        return (this->layer_param_.pooling_param().pool() == PoolingParameter_PoolMethod_MAX) ? 2 : 1;
    }
protected:
    virtual void Forward_cpu(const vector<Blob<Dtype>*>& bottom, const vector<Blob<Dtype>*>& top);
    virtual void Forward_gpu(const vector<Blob<Dtype>*>& bottom, const vector<Blob<Dtype>*>& top);
    virtual void Backward_cpu(const vector<Blob<Dtype>*>& top, const vector<bool>& propagate_down
                            , const vector<Blob<Dtype>*>& bottom);
    virtual void Backward_gpu(const vector<Blob<Dtype>*>& top, const vector<bool>& propagate_down
                            , const vector<Blob<Dtype>*>& bottom);

private:
    void InitPooling(const vector<Blob<Dtype>*>& bottom, const vector<Blob<Dtype>*>& top);

    uint32_t num_, channels_, width_, height_, width_out_, height_out_;
    uint32_t kernel_w_, kernel_h_;
    uint32_t stride_w_, stride_h_;
    int32_t pad_w_, pad_h_;

    Blob<uint32_t> max_idx_;
    bool global_pooling_;

    shared_ptr<MKLDNNData<Dtype> > fwd_top_data, fwd_bottom_data;
    shared_ptr<pooling::primitive_desc> poolingFwd_pd;
    shared_ptr<pooling> poolingFwd;
    shared_ptr<memory> indices_memory, input_memory, output_memory;
    shared_ptr<memory::primitive_desc> indices_pd;
};

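The comment on MaxTopBlobs() means MAX pooling may expose its argmax indices (the max_idx_ blob) as a second top. A plain-C++ reference sketch of that mask semantics, reduced to one dimension for clarity (illustrative only, not the MKL-DNN primitive):

#include <cstdint>

// Reference-semantics sketch: 1-D max pooling that also records, in
// `mask`, the input index of each output's maximum -- the role played
// by max_idx_ / the optional second top blob.
void max_pool_1d(const float* in, uint32_t n, uint32_t kernel, uint32_t stride,
                 float* out, uint32_t* mask) {
    uint32_t n_out = (n - kernel) / stride + 1;
    for (uint32_t o = 0; o < n_out; ++o) {
        uint32_t start = o * stride, best = start;
        for (uint32_t i = start + 1; i < start + kernel; ++i)
            if (in[i] > in[best]) best = i;
        out[o] = in[best];
        mask[o] = best;  // consumed by backward to route gradients
    }
}
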
// ===== MKLDNNReLULayer =======================================
template <typename Dtype>
class MKLDNNReLULayer : public NeuronLayer<Dtype> {
public:
    /**
     * @param param provides ReLUParameter relu_param,
     *     with ReLULayer options:
     *   - negative_slope (\b optional, default 0).
     *     the value @f$ \nu @f$ by which negative values are multiplied.
     */
    explicit MKLDNNReLULayer(const LayerParameter& param)
        : NeuronLayer<Dtype>(param)
        , fwd_top_data(NULL)
        , fwd_bottom_data(NULL)
        , reluFwd_pd(NULL)
        , reluFwd(NULL)
        , input_memory(NULL)
        , output_memory(NULL)
    {}

    ~MKLDNNReLULayer() {}
protected:
    virtual void LayerSetUp(const vector<Blob<Dtype>*>& bottom, const vector<Blob<Dtype>*>& top);
    virtual void Reshape(const vector<Blob<Dtype>*>& bottom, const vector<Blob<Dtype>*>& top);
    virtual inline const char* type() const { return "ReLU"; }
    virtual void Forward_cpu(const vector<Blob<Dtype>*>& bottom, const vector<Blob<Dtype>*>& top);
    virtual void Forward_gpu(const vector<Blob<Dtype>*>& bottom, const vector<Blob<Dtype>*>& top);
    virtual void Backward_cpu(const vector<Blob<Dtype>*>& top, const vector<bool>& propagate_down
                            , const vector<Blob<Dtype>*>& bottom);
    virtual void Backward_gpu(const vector<Blob<Dtype>*>& top, const vector<bool>& propagate_down
                            , const vector<Blob<Dtype>*>& bottom);
private:
    void InitReLU(const vector<Blob<Dtype>*>& bottom, const vector<Blob<Dtype>*>& top);
    shared_ptr<MKLDNNData<Dtype> > fwd_top_data, fwd_bottom_data;
    shared_ptr<relu::primitive_desc> reluFwd_pd;

    shared_ptr<relu> reluFwd;
    shared_ptr<memory> input_memory, output_memory;

    uint32_t num_, width_, height_, channels_;
};

}  // namespace caffe
#endif  // #ifndef CAFFE_MKLDNN_LAYERS_HPP_
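
For the negative_slope parameter documented on MKLDNNReLULayer, the elementwise computation it configures is the standard Caffe ReLU semantics, shown here as a plain-C++ reference (not the MKL-DNN primitive):

// Reference semantics: y = x for x > 0, otherwise y = negative_slope * x
// (negative_slope = 0 gives the ordinary ReLU).
template <typename Dtype>
void relu_reference(const Dtype* x, Dtype* y, int n, Dtype negative_slope) {
    for (int i = 0; i < n; ++i)
        y[i] = x[i] > Dtype(0) ? x[i] : negative_slope * x[i];
}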
