|
16 | 16 | * specific language governing permissions and limitations |
17 | 17 | * under the License. |
18 | 18 | */ |
19 | | -#include <cmath> |
20 | | -#include <fstream> |
21 | | -#include <map> |
22 | | -#include <sstream> |
23 | | -#include <string> |
24 | | -#include <vector> |
25 | | - |
26 | | -#include "../../../../runtime/file_utils.h" |
27 | | -#include "../../../../target/source/codegen_c.h" |
28 | | -#include "../../../qnn/utils.h" |
| 19 | +#include <tvm/relay/transform.h> |
| 20 | +#include <tvm/runtime/module.h> |
| 21 | +#include <tvm/runtime/registry.h> |
29 | 22 |
|
30 | 23 | namespace tvm { |
31 | | -namespace runtime { |
32 | | - |
33 | | -using namespace tir; |
34 | | - |
35 | | -class CodeGenCMSISNN : public tvm::codegen::CodeGenC { |
36 | | - public: |
37 | | - void Init(bool output_ssa) { |
38 | | - decl_stream << "#include <stdio.h>\n"; |
39 | | - decl_stream << "#include <stdlib.h>\n"; |
40 | | - decl_stream << "#include <dlpack/dlpack.h>\n"; |
41 | | - decl_stream << "#include <tvm/runtime/crt/module.h>\n"; |
42 | | - decl_stream << "#include <arm_nnfunctions.h>\n"; |
43 | | - CodeGenC::Init(output_ssa); |
44 | | - } |
45 | | - |
46 | | - /*! |
47 | | - * \brief Emit code that offloads a subgraph to the Cortex-M |
48 | | - * |
49 | | - * \return string of code that offloads a subgraph to the Cortex-M |
50 | | - */ |
51 | | - void AddFunction(const PrimFunc& prim_func) { |
52 | | - PrintExternCPrefix(stream); |
53 | | - CodeGenC::AddFunction(prim_func); |
54 | | - PrintExternCPostfix(stream); |
55 | | - } |
56 | | - |
57 | | - private: |
58 | | - void VisitExpr_(const CallNode* op, std::ostream& os) { // NOLINT(*) |
59 | | - if (!op->op.same_as(builtin::call_extern())) { |
60 | | - return; |
61 | | - } |
62 | | - std::string cmsis_func_name = op->args[0].as<StringImmNode>()->value; |
63 | | - if (cmsis_func_name == "arm_softmax_s8") { |
64 | | - EmitSoftmax(op); |
65 | | - } |
66 | | - return; |
67 | | - } |
68 | | - |
69 | | - /*! * \brief Creates a cplusplus guard prefix for extern "C" printing */ |
70 | | - void PrintExternCPrefix(std::ostringstream& ss) { |
71 | | - PrintIndent(); |
72 | | - ss << "#ifdef __cplusplus\n"; |
73 | | - ss << "extern \"C\" {\n"; |
74 | | - ss << "#endif\n"; |
75 | | - } |
76 | | - |
77 | | - /*! * \brief Creates a cplusplus guard postfix for extern "C" printing */ |
78 | | - void PrintExternCPostfix(std::ostringstream& ss) { |
79 | | - PrintIndent(); |
80 | | - ss << "#ifdef __cplusplus\n"; |
81 | | - ss << "}\n"; |
82 | | - ss << "#endif\n"; |
83 | | - } |
84 | | - |
85 | | - /*! * \brief Emits CMSIS-NN code block for softmax */ |
86 | | - void EmitSoftmax(const CallNode* op) { |
87 | | - // @tir.call_extern("arm_softmax_s8", buffer_0, num_rows, row_size, scale, buffer_1, dtype=int8) |
88 | | - std::string cmsis_func_name = op->args[0].as<StringImmNode>()->value; |
89 | | - int32_t num_rows = op->args[2].as<IntImmNode>()->value; |
90 | | - int32_t row_size = op->args[3].as<IntImmNode>()->value; |
91 | | - float quant_scale = op->args[4].as<FloatImmNode>()->value; |
92 | | - |
93 | | - // calculate multiplier and shift for CMSIS-NN softmax API |
94 | | - // Note: tfl micro assumptions |
95 | | - // TF_LITE_ENSURE_TYPES_EQ(context, output->type, kTfLiteInt8); |
96 | | - // TF_LITE_ENSURE_EQ(context, output->params.zero_point, -128); |
97 | | - // TF_LITE_ENSURE(context, output->params.scale == 1.f / 256); |
98 | | - double beta = 1.0; |
99 | | - int32_t input_bits = 5; |
100 | | - double beta_multiplier = (beta * quant_scale * (1 << (31 - input_bits))); |
101 | | - beta_multiplier = std::min<double>(beta_multiplier, (1ll << 31) - 1.0); |
102 | | - auto mult_shift_pair = tvm::relay::qnn::GetFixedPointMultiplierShift(beta_multiplier); |
103 | | - int32_t mult = std::get<0>(mult_shift_pair); |
104 | | - int32_t shift = std::get<1>(mult_shift_pair); |
105 | | - int32_t diff_min = (1 << 5) - 1; |
106 | | - diff_min <<= (31 - 5); |
107 | | - diff_min >>= shift; |
108 | | - diff_min *= -1; |
109 | | - |
110 | | - PrintIndent(); |
111 | | - stream << "int32_t num_rows = " << num_rows << ";\n"; |
112 | | - PrintIndent(); |
113 | | - stream << "int32_t row_size = " << row_size << ";\n"; |
114 | | - PrintIndent(); |
115 | | - stream << "int32_t mult = " << mult << ";\n"; |
116 | | - PrintIndent(); |
117 | | - stream << "int32_t shift = " << shift << ";\n"; |
118 | | - PrintIndent(); |
119 | | - stream << "int32_t diff_min = " << diff_min << ";\n"; |
120 | | - PrintIndent(); |
121 | | - stream << cmsis_func_name << "(buffer,"; |
122 | | - PrintIndent(); |
123 | | - stream << " num_rows, row_size, mult, shift, diff_min, buffer1);\n"; |
124 | | - PrintIndent(); |
125 | | - stream << "return;\n"; |
126 | | - } |
127 | | -}; |
128 | | - |
129 | | -class CMSISNNModuleNode : public runtime::ModuleNode { |
130 | | - public: |
131 | | - CMSISNNModuleNode(const std::string& code, const std::string& fmt, |
132 | | - const Array<String>& func_names) |
133 | | - : code_(code), fmt_(fmt), func_names_(func_names) {} |
134 | | - |
135 | | - std::string GetSource(const std::string& format) final { return code_; } |
136 | | - |
137 | | - const char* type_key() const { return "c"; } |
138 | | - |
139 | | - PackedFunc GetFunction(const std::string& name, const ObjectPtr<Object>& sptr_to_self) final { |
140 | | - if (name == "get_symbol") { |
141 | | - return PackedFunc( |
142 | | - [sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { *rv = this->func_names_[0]; }); |
143 | | - } else if (name == "get_func_names") { |
144 | | - return PackedFunc( |
145 | | - [sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { *rv = this->func_names_; }); |
146 | | - } else { |
147 | | - return PackedFunc(nullptr); |
148 | | - } |
149 | | - } |
150 | | - |
151 | | - void SaveToFile(const std::string& file_name, const std::string& format) final { |
152 | | - std::string fmt = GetFileFormat(file_name, format); |
153 | | - std::string meta_file = GetMetaFilePath(file_name); |
154 | | - if (fmt == "c" || fmt == "cu") { |
155 | | - ICHECK_NE(code_.length(), 0); |
156 | | - SaveBinaryToFile(file_name, code_); |
157 | | - } else { |
158 | | - ICHECK_EQ(fmt, fmt_) << "Can only save to format=" << fmt_; |
159 | | - } |
160 | | - } |
161 | | - |
162 | | - protected: |
163 | | - std::string code_; |
164 | | - std::string fmt_; |
165 | | - Array<String> func_names_; |
166 | | -}; |
167 | | - |
168 | | -class CMSISNNModule : public Module { |
169 | | - public: |
170 | | - CMSISNNModule() {} |
171 | | - explicit CMSISNNModule(ObjectPtr<Object> n) : Module(n) {} |
172 | | - inline CMSISNNModuleNode* operator->(); |
173 | | - inline const CMSISNNModuleNode* operator->() const; |
174 | | -}; |
175 | | - |
176 | | -inline CMSISNNModuleNode* CMSISNNModule::operator->() { |
177 | | - return static_cast<CMSISNNModuleNode*>(get_mutable()); |
178 | | -} |
179 | | - |
180 | | -static Module CMSISNNModuleNodeCreate(IRModule mod) { |
181 | | - bool output_ssa = false; |
182 | | - CodeGenCMSISNN cg; |
183 | | - Array<String> function_names; |
184 | | - cg.Init(output_ssa); |
185 | | - ICHECK(mod->functions.size() == 1) << "Supports modules with single PrimFunc."; |
186 | | - for (auto kv : mod->functions) { |
187 | | - ICHECK(kv.second->IsInstance<PrimFuncNode>()) << "CodegenCHost: Can only take PrimFunc"; |
188 | | - auto f = Downcast<PrimFunc>(kv.second); |
189 | | - auto global_symbol = f->GetAttr<String>(tvm::attr::kGlobalSymbol); |
190 | | - ICHECK(global_symbol.defined()) |
191 | | - << "CodeGenCHost: Expect PrimFunc to have the global_symbol attribute"; |
192 | | - function_names.push_back(global_symbol.value()); |
193 | | - cg.AddFunction(f); |
194 | | - } |
195 | | - std::string code = cg.Finish(); |
196 | | - auto n = make_object<CMSISNNModuleNode>(code, "c", function_names); |
197 | | - return Module(n); |
| 24 | +namespace relay { |
| 25 | +namespace contrib { |
| 26 | +namespace cmsisnn { |
| 27 | + |
// Forward declaration only; no definition of RelayToTIR() is visible in this chunk.
// CompileCMSISNN below runs the returned pass second, right after transform::InferType().
| 28 | +transform::Pass RelayToTIR(); |
| 29 | + |
| 30 | +runtime::Module CompileCMSISNN(const ObjectRef& ref) { |
| 31 | + IRModule relay_mod; |
| 32 | + Function relay_func = Downcast<Function>(ref); |
| 33 | + auto func_name = relay_func->GetAttr<String>(tvm::attr::kGlobalSymbol); |
| 34 | + GlobalVar var = GlobalVar(func_name.value()); |
| 35 | + relay_mod->Add(var, relay_func); |
| 36 | + relay_mod = transform::InferType()(relay_mod); |
| 37 | + |
| 38 | + Array<transform::Pass> pass_seqs{transform::InferType(), RelayToTIR()}; |
| 39 | + transform::Sequential seq(pass_seqs); |
| 40 | + IRModule tir_mod = seq(relay_mod); |
| 41 | + |
| 42 | + const auto* pf = runtime::Registry::Get("runtime.CMSISNNModuleNodeCreate"); |
| 43 | + return (*pf)(tir_mod); |
198 | 44 | } |
199 | 45 |
|
200 | | -TVM_REGISTER_GLOBAL("runtime.module.cmsisnn.create").set_body([](TVMArgs args, TVMRetValue* rv) { |
201 | | - *rv = CMSISNNModuleNodeCreate(args[0]); |
202 | | -}); |
// Registers CompileCMSISNN in the global function registry under the name "relay.ext.cmsisnn".
| 46 | +TVM_REGISTER_GLOBAL("relay.ext.cmsisnn").set_body_typed(CompileCMSISNN); |
203 | 47 |
|
204 | | -} // namespace runtime |
| 48 | +} // namespace cmsisnn |
| 49 | +} // namespace contrib |
| 50 | +} // namespace relay |
205 | 51 | } // namespace tvm |
0 commit comments