Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 29 additions & 10 deletions src/relay/pass/mac_count.cc
Original file line number Diff line number Diff line change
Expand Up @@ -66,21 +66,21 @@ int64_t ConvMacCount(const Call& call_node) {
return 0;
}
Array<Expr> args = call_node->args;
CHECK(args.size() == 2)
CHECK_EQ(args.size(), 2)
<< "The number of input arguments of a CONV 2D node should be 2.";
const auto* conv_2d_attr = call_node->attrs.as<Conv2DAttrs>();
const auto* data_type = args[0]->checked_type().as<TensorTypeNode>();
Array<IndexExpr> data_shape = data_type->shape;
std::string data_layout = conv_2d_attr->data_layout;
int32_t C_ind = Layout(data_layout).IndexOf(LayoutAxis::Get('C'));
int32_t c_ind = Layout(data_layout).IndexOf(LayoutAxis::Get('c'));
CHECK(C_ind != -1)
CHECK_NE(C_ind, -1)
<< "There is no input channel dimension.";
int64_t input_channel = static_cast<int64_t>(data_shape[C_ind].as<IntImm>()->value);
if (c_ind != -1)
input_channel *= static_cast<int64_t>(data_shape[c_ind].as<IntImm>()->value);
Array<IndexExpr> kernel_size = conv_2d_attr->kernel_size;
CHECK(kernel_size.size() == 2)
CHECK_EQ(kernel_size.size(), 2)
<< "The dimension of the kernel in Conv 2D should be 2.";
const auto* expr = call_node->checked_type().as<TensorTypeNode>();
Array<IndexExpr> output_tensor = expr->shape;
Expand All @@ -99,21 +99,21 @@ int64_t Conv2dTransposeMacCount(const Call& call_node) {
return 0;
}
Array<Expr> args = call_node->args;
CHECK(args.size() == 2)
CHECK_EQ(args.size(), 2)
<< "The number of input arguments of a CONV 2D Transpose node should be 2.";
const auto* conv_2d_transpose_attr = call_node->attrs.as<Conv2DTransposeAttrs>();
const auto* data_type = args[0]->checked_type().as<TensorTypeNode>();
Array<IndexExpr> data_shape = data_type->shape;
std::string data_layout = conv_2d_transpose_attr->data_layout;
int32_t C_ind = Layout(data_layout).IndexOf(LayoutAxis::Get('C'));
int32_t c_ind = Layout(data_layout).IndexOf(LayoutAxis::Get('c'));
CHECK(C_ind != -1)
CHECK_NE(C_ind, -1)
<< "There is no input channel dimension.";
int64_t input_channel = static_cast<int64_t>(data_shape[C_ind].as<IntImm>()->value);
if (c_ind != -1)
input_channel *= static_cast<int64_t>(data_shape[c_ind].as<IntImm>()->value);
Array<IndexExpr> kernel_size = conv_2d_transpose_attr->kernel_size;
CHECK(kernel_size.size() == 2)
CHECK_EQ(kernel_size.size(), 2)
<< "The dimension of the kernel in Conv 2D Transpose should be 2.";
const auto* expr = call_node->checked_type().as<TensorTypeNode>();
Array<IndexExpr> output_tensor = expr->shape;
Expand All @@ -132,7 +132,7 @@ int64_t DenseMacCount(const Call& call_node) {
return 0;
}
Array<Expr> args = call_node->args;
CHECK(args.size() == 2)
CHECK_EQ(args.size(), 2)
<< "The number of input arguments of a Dense node should be 2.";
const auto* data_type = args[0]->checked_type().as<TensorTypeNode>();
const auto* weight_type = args[1]->checked_type().as<TensorTypeNode>();
Expand All @@ -144,12 +144,28 @@ int64_t DenseMacCount(const Call& call_node) {
int64_t d2 = static_cast<int64_t>(data_shape[1].as<IntImm>()->value);
int64_t d3 = static_cast<int64_t>(weight_shape[0].as<IntImm>()->value);
int64_t d4 = static_cast<int64_t>(weight_shape[1].as<IntImm>()->value);
CHECK(d2 == d4)
CHECK_EQ(d2, d4)
<< "The dimensions of input arguments do not match.";
int64_t count = d1 * d2 * d3;
return count;
}

int64_t BatchMatmulMacCount(const Call& call_node) {
if (!call_node->checked_type_.defined()) {
LOG(WARNING) << "The infer type pass should be called before the mac count pass";
return 0;
}
Array<Expr> args = call_node->args;
CHECK_EQ(args.size(), 2);
Array<IndexExpr> x_shape = args[0]->checked_type().as<TensorTypeNode>()->shape;
Array<IndexExpr> y_shape = args[1]->checked_type().as<TensorTypeNode>()->shape;
int64_t batch = x_shape[0].as<IntImm>()->value;
int64_t m = x_shape[1].as<IntImm>()->value;
int64_t k = x_shape[2].as<IntImm>()->value;
int64_t n = y_shape[1].as<IntImm>()->value;
return batch * m * k * n;
}

RELAY_REGISTER_OP("nn.conv2d")
.set_attr<FMacCount>("FMacCount", ConvMacCount);

Expand All @@ -159,14 +175,17 @@ RELAY_REGISTER_OP("nn.conv2d_transpose")
RELAY_REGISTER_OP("nn.dense")
.set_attr<FMacCount>("FMacCount", DenseMacCount);

RELAY_REGISTER_OP("nn.batch_matmul")
.set_attr<FMacCount>("FMacCount", BatchMatmulMacCount);

class MacCounter : private ExprVisitor {
public:
MacCounter() {
count_ = 0;
}
static int64_t GetTotalMacNumber(const Expr& expr) {
LOG(INFO) << "This pass only counts MACs in direct CONV 2D, "
<< "CONV 2D Transpose and Dense ops";
LOG(INFO) << "This pass only counts MACs in direct conv2d, "
<< "conv2d_transpose, dense, and batch_matmul ops";
MacCounter counter;
counter(expr);
return counter.count_;
Expand Down