Skip to content

Commit 92d77a8

Browse files
author
Ashutosh Parkhi
committed
[CMSIS-NN] Moved TIR Generation to C++
1 parent 01aeeb1 commit 92d77a8

File tree

7 files changed

+346
-335
lines changed

7 files changed

+346
-335
lines changed

python/tvm/relay/backend/__init__.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,4 +16,3 @@
1616
# under the License.
1717
"""Backend codegen modules for relay."""
1818
from . import compile_engine
19-
from .contrib import cmsisnn

python/tvm/relay/backend/contrib/cmsisnn/__init__.py

Lines changed: 0 additions & 18 deletions
This file was deleted.

python/tvm/relay/backend/contrib/cmsisnn/codegen.py

Lines changed: 0 additions & 134 deletions
This file was deleted.

python/tvm/relay/op/contrib/cmsisnn.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -80,5 +80,5 @@ def check_quantized_softmax(extract):
8080
)
8181

8282
return [
83-
("cmsisnn.qnn_softmax", softmax_pattern(), check_quantized_softmax),
83+
("cmsisnn.quantized_softmax", softmax_pattern(), check_quantized_softmax),
8484
]

src/relay/backend/contrib/cmsisnn/codegen_cmsisnn.cc

Lines changed: 27 additions & 181 deletions
Original file line numberDiff line numberDiff line change
@@ -16,190 +16,36 @@
1616
* specific language governing permissions and limitations
1717
* under the License.
1818
*/
19-
#include <cmath>
20-
#include <fstream>
21-
#include <map>
22-
#include <sstream>
23-
#include <string>
24-
#include <vector>
25-
26-
#include "../../../../runtime/file_utils.h"
27-
#include "../../../../target/source/codegen_c.h"
28-
#include "../../../qnn/utils.h"
19+
#include <tvm/relay/transform.h>
20+
#include <tvm/runtime/module.h>
21+
#include <tvm/runtime/registry.h>
2922

3023
namespace tvm {
31-
namespace runtime {
32-
33-
using namespace tir;
34-
35-
class CodeGenCMSISNN : public tvm::codegen::CodeGenC {
36-
public:
37-
void Init(bool output_ssa) {
38-
decl_stream << "#include <stdio.h>\n";
39-
decl_stream << "#include <stdlib.h>\n";
40-
decl_stream << "#include <dlpack/dlpack.h>\n";
41-
decl_stream << "#include <tvm/runtime/crt/module.h>\n";
42-
decl_stream << "#include <arm_nnfunctions.h>\n";
43-
CodeGenC::Init(output_ssa);
44-
}
45-
46-
/*!
47-
* \brief Emit code that offloads a subgraph to the Cortex-M
48-
*
49-
* \return string of code that offloads a subgraph to the Cortex-M
50-
*/
51-
void AddFunction(const PrimFunc& prim_func) {
52-
PrintExternCPrefix(stream);
53-
CodeGenC::AddFunction(prim_func);
54-
PrintExternCPostfix(stream);
55-
}
56-
57-
private:
58-
void VisitExpr_(const CallNode* op, std::ostream& os) { // NOLINT(*)
59-
if (!op->op.same_as(builtin::call_extern())) {
60-
return;
61-
}
62-
std::string cmsis_func_name = op->args[0].as<StringImmNode>()->value;
63-
if (cmsis_func_name == "arm_softmax_s8") {
64-
EmitSoftmax(op);
65-
}
66-
return;
67-
}
68-
69-
/*! * \brief Creates a cplusplus guard prefix for extern "C" printing */
70-
void PrintExternCPrefix(std::ostringstream& ss) {
71-
PrintIndent();
72-
ss << "#ifdef __cplusplus\n";
73-
ss << "extern \"C\" {\n";
74-
ss << "#endif\n";
75-
}
76-
77-
/*! * \brief Creates a cplusplus guard postfix for extern "C" printing */
78-
void PrintExternCPostfix(std::ostringstream& ss) {
79-
PrintIndent();
80-
ss << "#ifdef __cplusplus\n";
81-
ss << "}\n";
82-
ss << "#endif\n";
83-
}
84-
85-
/*! * \brief Emits CMSIS-NN code block for softmax */
86-
void EmitSoftmax(const CallNode* op) {
87-
// @tir.call_extern("arm_softmax_s8", buffer_0, num_rows, row_size, scale, buffer_1, dtype=int8)
88-
std::string cmsis_func_name = op->args[0].as<StringImmNode>()->value;
89-
int32_t num_rows = op->args[2].as<IntImmNode>()->value;
90-
int32_t row_size = op->args[3].as<IntImmNode>()->value;
91-
float quant_scale = op->args[4].as<FloatImmNode>()->value;
92-
93-
// calculate multiplier and shift for CMSIS-NN softmax API
94-
// Note: tfl micro assumptions
95-
// TF_LITE_ENSURE_TYPES_EQ(context, output->type, kTfLiteInt8);
96-
// TF_LITE_ENSURE_EQ(context, output->params.zero_point, -128);
97-
// TF_LITE_ENSURE(context, output->params.scale == 1.f / 256);
98-
double beta = 1.0;
99-
int32_t input_bits = 5;
100-
double beta_multiplier = (beta * quant_scale * (1 << (31 - input_bits)));
101-
beta_multiplier = std::min<double>(beta_multiplier, (1ll << 31) - 1.0);
102-
auto mult_shift_pair = tvm::relay::qnn::GetFixedPointMultiplierShift(beta_multiplier);
103-
int32_t mult = std::get<0>(mult_shift_pair);
104-
int32_t shift = std::get<1>(mult_shift_pair);
105-
int32_t diff_min = (1 << 5) - 1;
106-
diff_min <<= (31 - 5);
107-
diff_min >>= shift;
108-
diff_min *= -1;
109-
110-
PrintIndent();
111-
stream << "int32_t num_rows = " << num_rows << ";\n";
112-
PrintIndent();
113-
stream << "int32_t row_size = " << row_size << ";\n";
114-
PrintIndent();
115-
stream << "int32_t mult = " << mult << ";\n";
116-
PrintIndent();
117-
stream << "int32_t shift = " << shift << ";\n";
118-
PrintIndent();
119-
stream << "int32_t diff_min = " << diff_min << ";\n";
120-
PrintIndent();
121-
stream << cmsis_func_name << "(buffer,";
122-
PrintIndent();
123-
stream << " num_rows, row_size, mult, shift, diff_min, buffer1);\n";
124-
PrintIndent();
125-
stream << "return;\n";
126-
}
127-
};
128-
129-
class CMSISNNModuleNode : public runtime::ModuleNode {
130-
public:
131-
CMSISNNModuleNode(const std::string& code, const std::string& fmt,
132-
const Array<String>& func_names)
133-
: code_(code), fmt_(fmt), func_names_(func_names) {}
134-
135-
std::string GetSource(const std::string& format) final { return code_; }
136-
137-
const char* type_key() const { return "c"; }
138-
139-
PackedFunc GetFunction(const std::string& name, const ObjectPtr<Object>& sptr_to_self) final {
140-
if (name == "get_symbol") {
141-
return PackedFunc(
142-
[sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { *rv = this->func_names_[0]; });
143-
} else if (name == "get_func_names") {
144-
return PackedFunc(
145-
[sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { *rv = this->func_names_; });
146-
} else {
147-
return PackedFunc(nullptr);
148-
}
149-
}
150-
151-
void SaveToFile(const std::string& file_name, const std::string& format) final {
152-
std::string fmt = GetFileFormat(file_name, format);
153-
std::string meta_file = GetMetaFilePath(file_name);
154-
if (fmt == "c" || fmt == "cu") {
155-
ICHECK_NE(code_.length(), 0);
156-
SaveBinaryToFile(file_name, code_);
157-
} else {
158-
ICHECK_EQ(fmt, fmt_) << "Can only save to format=" << fmt_;
159-
}
160-
}
161-
162-
protected:
163-
std::string code_;
164-
std::string fmt_;
165-
Array<String> func_names_;
166-
};
167-
168-
class CMSISNNModule : public Module {
169-
public:
170-
CMSISNNModule() {}
171-
explicit CMSISNNModule(ObjectPtr<Object> n) : Module(n) {}
172-
inline CMSISNNModuleNode* operator->();
173-
inline const CMSISNNModuleNode* operator->() const;
174-
};
175-
176-
inline CMSISNNModuleNode* CMSISNNModule::operator->() {
177-
return static_cast<CMSISNNModuleNode*>(get_mutable());
178-
}
179-
180-
static Module CMSISNNModuleNodeCreate(IRModule mod) {
181-
bool output_ssa = false;
182-
CodeGenCMSISNN cg;
183-
Array<String> function_names;
184-
cg.Init(output_ssa);
185-
ICHECK(mod->functions.size() == 1) << "Supports modules with single PrimFunc.";
186-
for (auto kv : mod->functions) {
187-
ICHECK(kv.second->IsInstance<PrimFuncNode>()) << "CodegenCHost: Can only take PrimFunc";
188-
auto f = Downcast<PrimFunc>(kv.second);
189-
auto global_symbol = f->GetAttr<String>(tvm::attr::kGlobalSymbol);
190-
ICHECK(global_symbol.defined())
191-
<< "CodeGenCHost: Expect PrimFunc to have the global_symbol attribute";
192-
function_names.push_back(global_symbol.value());
193-
cg.AddFunction(f);
194-
}
195-
std::string code = cg.Finish();
196-
auto n = make_object<CMSISNNModuleNode>(code, "c", function_names);
197-
return Module(n);
24+
namespace relay {
25+
namespace contrib {
26+
namespace cmsisnn {
27+
28+
transform::Pass RelayToTIR();
29+
30+
runtime::Module CompileCMSISNN(const ObjectRef& ref) {
31+
IRModule relay_mod;
32+
Function relay_func = Downcast<Function>(ref);
33+
auto func_name = relay_func->GetAttr<String>(tvm::attr::kGlobalSymbol);
34+
GlobalVar var = GlobalVar(func_name.value());
35+
relay_mod->Add(var, relay_func);
36+
relay_mod = transform::InferType()(relay_mod);
37+
38+
Array<transform::Pass> pass_seqs{transform::InferType(), RelayToTIR()};
39+
transform::Sequential seq(pass_seqs);
40+
IRModule tir_mod = seq(relay_mod);
41+
42+
const auto* pf = runtime::Registry::Get("runtime.CMSISNNModuleNodeCreate");
43+
return (*pf)(tir_mod);
19844
}
19945

200-
TVM_REGISTER_GLOBAL("runtime.module.cmsisnn.create").set_body([](TVMArgs args, TVMRetValue* rv) {
201-
*rv = CMSISNNModuleNodeCreate(args[0]);
202-
});
46+
TVM_REGISTER_GLOBAL("relay.ext.cmsisnn").set_body_typed(CompileCMSISNN);
20347

204-
} // namespace runtime
48+
} // namespace cmsisnn
49+
} // namespace contrib
50+
} // namespace relay
20551
} // namespace tvm

0 commit comments

Comments
 (0)