Skip to content

Commit 1a78643

Browse files
committed
Migrate C Interface API Generation to C++
Using the new name transformations added in apache#9088, the C interface API is now generated in C++ rather than in Python. Follow up PRs will clean up any remaining name transformation inconsistencies. Fixes apache#8792
1 parent cca7176 commit 1a78643

File tree

11 files changed

+385
-116
lines changed

11 files changed

+385
-116
lines changed

python/tvm/micro/interface_api.py

Lines changed: 0 additions & 101 deletions
This file was deleted.

python/tvm/micro/model_library_format.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,13 +25,14 @@
2525
import tarfile
2626
import typing
2727

28+
import tvm
2829
from tvm.ir.type import TupleType
2930
from .._ffi import get_global_func
30-
from .interface_api import generate_c_interface_header
3131
from ..contrib import utils
3232
from ..driver import build_module
3333
from ..runtime import ndarray as _nd
3434
from ..relay.backend import executor_factory
35+
from ..relay.backend.name_transforms import to_c_variable_style, prefix_generated_name
3536
from ..relay import param_dict
3637
from ..tir import expr
3738

@@ -43,6 +44,18 @@ class UnsupportedInModelLibraryFormatError(Exception):
4344
"""Raised when export_model_library_format does not support the given Module tree."""
4445

4546

47+
def generate_c_interface_header(module_name, inputs, outputs, include_path):
48+
"""Generate C Interface header to be included in MLF"""
49+
mangled_name = to_c_variable_style(prefix_generated_name(module_name))
50+
metadata_header = os.path.join(include_path, f"{mangled_name}.h")
51+
52+
interface_c_create = tvm._ffi.get_global_func("runtime.InterfaceCCreate")
53+
interface_c_module = interface_c_create(module_name, inputs, outputs)
54+
55+
with open(metadata_header, "w") as header_file:
56+
header_file.write(interface_c_module.get_source())
57+
58+
4659
def _populate_codegen_dir(mod, codegen_dir: str, module_name: str = None):
4760
"""Populate the codegen sub-directory as part of a Model Library Format export.
4861

python/tvm/relay/backend/name_transforms.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,18 @@ def to_c_variable_style(original_name: str):
4747
return _backend.ToCVariableStyle(original_name)
4848

4949

50+
def to_c_constant_style(original_name: str):
51+
"""Transform a name to the C constant style assuming it is
52+
appropriately constructed using the prefixing functions
53+
54+
Parameters
55+
----------
56+
original_name : str
57+
Original name to transform
58+
"""
59+
return _backend.ToCConstantStyle(original_name)
60+
61+
5062
def prefix_name(names: Union[List[str], str]):
5163
"""Apply TVM-specific prefix to a function name
5264

src/relay/backend/name_transforms.cc

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,13 @@ std::string ToCVariableStyle(const std::string& original_name) {
6060
return variable_name;
6161
}
6262

63+
std::string ToCConstantStyle(const std::string& original_name) {
64+
std::string constant_name = ToCVariableStyle(original_name);
65+
66+
std::transform(constant_name.begin(), constant_name.end(), constant_name.begin(), ::toupper);
67+
return constant_name;
68+
}
69+
6370
std::string CombineNames(const Array<String>& names) {
6471
std::stringstream combine_stream;
6572
ICHECK(!names.empty());
@@ -77,22 +84,16 @@ std::string CombineNames(const Array<String>& names) {
7784
std::string SanitiseName(const std::string& name) {
7885
ICHECK(!name.empty());
7986

80-
auto multipleSeparators = [](char before, char after) {
81-
return before == '_' && before == after;
82-
};
8387
auto isNotAlnum = [](char c) { return !std::isalnum(c); };
8488
std::string sanitised_input = name;
8589
std::replace_if(sanitised_input.begin(), sanitised_input.end(), isNotAlnum, '_');
8690

87-
sanitised_input.erase(
88-
std::unique(sanitised_input.begin(), sanitised_input.end(), multipleSeparators),
89-
sanitised_input.end());
90-
9191
return sanitised_input;
9292
}
9393

9494
TVM_REGISTER_GLOBAL("relay.backend.ToCFunctionStyle").set_body_typed(ToCFunctionStyle);
9595
TVM_REGISTER_GLOBAL("relay.backend.ToCVariableStyle").set_body_typed(ToCVariableStyle);
96+
TVM_REGISTER_GLOBAL("relay.backend.ToCConstantStyle").set_body_typed(ToCConstantStyle);
9697
TVM_REGISTER_GLOBAL("relay.backend.PrefixName").set_body_typed(PrefixName);
9798
TVM_REGISTER_GLOBAL("relay.backend.PrefixGeneratedName").set_body_typed(PrefixGeneratedName);
9899
TVM_REGISTER_GLOBAL("relay.backend.SanitiseName").set_body_typed(SanitiseName);

src/relay/backend/name_transforms.h

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,9 @@
3535
* ToCVariableStyle(PrefixGeneratedName(CombineNames({"model", "Devices"})))
3636
* // tvmgen_model_devices
3737
*
38+
* ToCConstantStyle(PrefixGeneratedName(CombineNames({"model", "Devices"})))
39+
* // TVMGEN_MODEL_DEVICES
40+
*
3841
*/
3942

4043
#include <tvm/runtime/container/array.h>
@@ -68,6 +71,14 @@ std::string ToCFunctionStyle(const std::string& original_name);
6871
*/
6972
std::string ToCVariableStyle(const std::string& original_name);
7073

74+
/*!
75+
* \brief Transform a name to the C constant style assuming it is
76+
* appropriately constructed using the prefixing functions
77+
* \param name Original name
78+
* \return Transformed function in the C constant style
79+
*/
80+
std::string ToCConstantStyle(const std::string& original_name);
81+
7182
/*!
7283
* \brief Combine names together for use as a generated name
7384
* \param names Vector of strings to combine

src/target/source/interface_c.cc

Lines changed: 132 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,132 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one
3+
* or more contributor license agreements. See the NOTICE file
4+
* distributed with this work for additional information
5+
* regarding copyright ownership. The ASF licenses this file
6+
* to you under the Apache License, Version 2.0 (the
7+
* "License"); you may not use this file except in compliance
8+
* with the License. You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing,
13+
* software distributed under the License is distributed on an
14+
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15+
* KIND, either express or implied. See the License for the
16+
* specific language governing permissions and limitations
17+
* under the License.
18+
*/
19+
20+
/*!
21+
* \file interface_c.cc
22+
* \brief Generates a C interface header for a given modules inputs and outputs
23+
*/
24+
25+
#include <tvm/runtime/container/array.h>
26+
#include <tvm/runtime/container/string.h>
27+
#include <tvm/runtime/module.h>
28+
#include <tvm/runtime/packed_func.h>
29+
#include <tvm/runtime/registry.h>
30+
31+
#include <string>
32+
33+
#include "../../relay/backend/name_transforms.h"
34+
35+
namespace tvm {
36+
namespace codegen {
37+
38+
using runtime::PackedFunc;
39+
using namespace tvm::relay::backend;
40+
41+
class InterfaceCNode : public runtime::ModuleNode {
42+
public:
43+
InterfaceCNode(std::string module_name, Array<String> inputs, Array<String> outputs)
44+
: module_name_(module_name), inputs_(inputs), outputs_(outputs) {}
45+
const char* type_key() const { return "h"; }
46+
47+
std::string GetSource(const std::string& format) final {
48+
std::stringstream code;
49+
std::string mangled_module_name = ToCVariableStyle(PrefixGeneratedName({module_name_}));
50+
std::string header_guard_name = ToCConstantStyle(PrefixGeneratedName({module_name_}));
51+
52+
EmitUpperHeaderGuard(code, header_guard_name);
53+
EmitBrief(code, "Input tensor pointers");
54+
EmitStruct(code, mangled_module_name, "inputs", inputs_);
55+
EmitBrief(code, "Output tensor pointers");
56+
EmitStruct(code, mangled_module_name, "outputs", outputs_);
57+
EmitRunFunction(code, mangled_module_name);
58+
EmitLowerHeaderGuard(code, header_guard_name);
59+
60+
return code.str();
61+
}
62+
63+
PackedFunc GetFunction(const std::string& name, const ObjectPtr<Object>& sptr_to_self) final {
64+
return PackedFunc(nullptr);
65+
}
66+
67+
private:
68+
void EmitUpperHeaderGuard(std::stringstream& code_stream, const std::string& header_guard_name) {
69+
code_stream << "#ifndef " << header_guard_name << "_H_\n"
70+
<< "#define " << header_guard_name << "_H_\n"
71+
<< "#include <stdint.h>\n\n"
72+
<< "#ifdef __cplusplus\n"
73+
<< "extern \"C\" {\n"
74+
<< "#endif\n\n";
75+
}
76+
77+
void EmitLowerHeaderGuard(std::stringstream& code_stream, const std::string& header_guard_name) {
78+
code_stream << "\n#ifdef __cplusplus\n"
79+
<< "}\n"
80+
<< "#endif\n\n"
81+
<< "#endif // " << header_guard_name << "_H_\n";
82+
}
83+
84+
void EmitBrief(std::stringstream& code_stream, const std::string& description) {
85+
code_stream << "/*!\n"
86+
<< " * \\brief " << description << " for TVM module \"" << module_name_ << "\" \n"
87+
<< " */\n";
88+
}
89+
90+
void EmitStruct(std::stringstream& code_stream, const std::string& mangled_module_name,
91+
const std::string& suffix, Array<String> properties) {
92+
code_stream << "struct " << mangled_module_name << "_" << suffix << " {\n";
93+
94+
std::vector<std::string> sanitised_properties;
95+
for (const String& property : properties) {
96+
std::string sanitised_property = SanitiseName(property);
97+
ICHECK(std::find(sanitised_properties.begin(), sanitised_properties.end(),
98+
sanitised_property) == sanitised_properties.end())
99+
<< "Sanitized input tensor name clash" << sanitised_property;
100+
code_stream << " void* " << sanitised_property << ";\n";
101+
sanitised_properties.push_back(sanitised_property);
102+
}
103+
code_stream << "};\n\n";
104+
}
105+
106+
void EmitRunFunction(std::stringstream& code_stream, const std::string& mangled_module_name) {
107+
code_stream << "/*!\n"
108+
<< " * \\brief entrypoint function for TVM module \"" << module_name_ << "\"\n"
109+
<< " * \\param inputs Input tensors for the module \n"
110+
<< " * \\param outputs Output tensors for the module \n"
111+
<< " */\n"
112+
<< "int32_t " << mangled_module_name << "_run(\n"
113+
<< " struct " << mangled_module_name << "_inputs* inputs,\n"
114+
<< " struct " << mangled_module_name << "_outputs* outputs\n"
115+
<< ");\n";
116+
}
117+
118+
std::string module_name_;
119+
Array<String> inputs_;
120+
Array<String> outputs_;
121+
};
122+
123+
runtime::Module InterfaceCCreate(std::string module_name, Array<String> inputs,
124+
Array<String> outputs) {
125+
auto n = make_object<InterfaceCNode>(module_name, inputs, outputs);
126+
return runtime::Module(n);
127+
}
128+
129+
TVM_REGISTER_GLOBAL("runtime.InterfaceCCreate").set_body_typed(InterfaceCCreate);
130+
131+
} // namespace codegen
132+
} // namespace tvm

tests/cpp/name_transforms_test.cc

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,13 @@ TEST(NameTransforms, ToCVariableStyle) {
4040
EXPECT_THROW(ToCVariableStyle(""), InternalError);
4141
}
4242

43+
TEST(NameTransforms, ToCConstantStyle) {
44+
ASSERT_EQ(ToCConstantStyle("TVM_Woof"), "TVM_WOOF");
45+
ASSERT_EQ(ToCConstantStyle("TVM_woof"), "TVM_WOOF");
46+
ASSERT_EQ(ToCConstantStyle("TVM_woof_Woof"), "TVM_WOOF_WOOF");
47+
EXPECT_THROW(ToCConstantStyle(""), InternalError);
48+
}
49+
4350
TEST(NameTransforms, PrefixName) {
4451
ASSERT_EQ(PrefixName({"Woof"}), "TVM_Woof");
4552
ASSERT_EQ(PrefixName({"woof"}), "TVM_woof");
@@ -69,10 +76,10 @@ TEST(NameTransforms, CombineNames) {
6976
}
7077

7178
TEST(NameTransforms, SanitiseName) {
72-
ASSERT_EQ(SanitiseName("+_+ "), "_");
79+
ASSERT_EQ(SanitiseName("+_+ "), "____");
7380
ASSERT_EQ(SanitiseName("input+"), "input_");
7481
ASSERT_EQ(SanitiseName("input-"), "input_");
75-
ASSERT_EQ(SanitiseName("input++"), "input_");
82+
ASSERT_EQ(SanitiseName("input++"), "input__");
7683
ASSERT_EQ(SanitiseName("woof:1"), "woof_1");
7784
EXPECT_THROW(SanitiseName(""), InternalError);
7885
}

0 commit comments

Comments
 (0)