Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 41 additions & 10 deletions src/relay/backend/name_transforms.cc
Original file line number Diff line number Diff line change
Expand Up @@ -29,27 +29,34 @@ namespace tvm {
namespace relay {
namespace backend {

std::string ToCFunctionStyle(const std::string& original_name) {
ICHECK(!original_name.empty()) << "Function name is empty";
ICHECK_EQ(original_name.find("TVM"), 0) << "Function not TVM prefixed";

int tvm_prefix_length = 3;
std::string function_name("TVM");
std::string ToCamel(const std::string& original_name) {
std::string camel_name;
camel_name.reserve(original_name.size());

bool new_block = true;
for (const char& symbol : original_name.substr(tvm_prefix_length)) {
for (const char& symbol : original_name) {
if (std::isalpha(symbol)) {
if (new_block) {
function_name.push_back(std::toupper(symbol));
camel_name.push_back(std::toupper(symbol));
new_block = false;
} else {
function_name.push_back(std::tolower(symbol));
camel_name.push_back(std::tolower(symbol));
}
} else if (symbol == '_') {
new_block = true;
}
}
return function_name;
return camel_name;
}

std::string ToCFunctionStyle(const std::string& original_name) {
ICHECK(!original_name.empty()) << "Function name is empty";
ICHECK_EQ(original_name.find("TVM"), 0) << "Function not TVM prefixed";

int tvm_prefix_length = 3;
std::string function_prefix("TVM");

return function_prefix + ToCamel(original_name.substr(tvm_prefix_length));
}

std::string ToCVariableStyle(const std::string& original_name) {
Expand All @@ -71,6 +78,30 @@ std::string ToCConstantStyle(const std::string& original_name) {
return constant_name;
}

std::string ToRustStructStyle(const std::string& original_name) {
ICHECK(!original_name.empty()) << "Struct name is empty";
return ToCamel(original_name);
}

std::string ToRustMacroStyle(const std::string& original_name) {
ICHECK(!original_name.empty()) << "Macro name is empty";

std::string macro_name;
macro_name.resize(original_name.size());

std::transform(original_name.begin(), original_name.end(), macro_name.begin(), ::tolower);
return macro_name;
}

std::string ToRustConstantStyle(const std::string& original_name) {
ICHECK(!original_name.empty()) << "Constant name is empty";
std::string constant_name;
constant_name.resize(original_name.size());

std::transform(original_name.begin(), original_name.end(), constant_name.begin(), ::toupper);
return constant_name;
}

std::string CombineNames(const Array<String>& names) {
std::stringstream combine_stream;
ICHECK(!names.empty()) << "Name segments empty";
Expand Down
24 changes: 24 additions & 0 deletions src/relay/backend/name_transforms.h
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,30 @@ std::string ToCVariableStyle(const std::string& original_name);
*/
std::string ToCConstantStyle(const std::string& original_name);

/*!
* \brief Transform a name to the Rust struct style assuming it is
* appropriately constructed using the combining functions
* \param name Original name
* \return Transformed function in the Rust struct style
*/
std::string ToRustStructStyle(const std::string& original_name);

/*!
* \brief Transform a name to the Rust macro style assuming it is
* appropriately constructed using the combining functions
* \param name Original name
* \return Transformed function in the Rust macro style
*/
std::string ToRustMacroStyle(const std::string& original_name);

/*!
* \brief Transform a name to the Rust constant style assuming it is
* appropriately constructed using the combining functions
* \param name Original name
* \return Transformed function in the Rust constant style
*/
std::string ToRustConstantStyle(const std::string& original_name);

/*!
* \brief Combine names together for use as a generated name
* \param names Vector of strings to combine
Expand Down
50 changes: 48 additions & 2 deletions tests/cpp/name_transforms_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -23,15 +23,20 @@
#include <tvm/runtime/container/string.h>
#include <tvm/runtime/name_transforms.h>

using namespace tvm::relay::backend;
namespace tvm {
namespace relay {
namespace backend {

using namespace tvm::runtime;

std::string ToCamel(const std::string& original_name);

TEST(NameTransforms, ToCFunctionStyle) {
ASSERT_EQ(ToCFunctionStyle("TVM_Woof"), "TVMWoof");
ASSERT_EQ(ToCFunctionStyle("TVM_woof"), "TVMWoof");
ASSERT_EQ(ToCFunctionStyle("TVM_woof_woof"), "TVMWoofWoof");
ASSERT_EQ(ToCFunctionStyle("TVMGen_woof_woof"), "TVMGenWoofWoof");
EXPECT_THROW(ToCVariableStyle("Cake_Bakery"), InternalError); // Incorrect prefix
EXPECT_THROW(ToCFunctionStyle("Cake_Bakery"), InternalError); // Incorrect prefix
EXPECT_THROW(ToCFunctionStyle(""), InternalError);
}

Expand All @@ -51,6 +56,27 @@ TEST(NameTransforms, ToCConstantStyle) {
EXPECT_THROW(ToCConstantStyle(""), InternalError);
}

TEST(NameTransforms, ToRustStructStyle) {
ASSERT_EQ(ToRustStructStyle("Woof"), "Woof");
ASSERT_EQ(ToRustStructStyle("woof"), "Woof");
ASSERT_EQ(ToRustStructStyle("woof_woof"), "WoofWoof");
EXPECT_THROW(ToRustStructStyle(""), InternalError);
}

TEST(NameTransforms, ToRustMacroStyle) {
ASSERT_EQ(ToRustMacroStyle("Woof"), "woof");
ASSERT_EQ(ToRustMacroStyle("woof"), "woof");
ASSERT_EQ(ToRustMacroStyle("woof_Woof"), "woof_woof");
EXPECT_THROW(ToRustMacroStyle(""), InternalError);
}

TEST(NameTransforms, ToRustConstantStyle) {
ASSERT_EQ(ToRustConstantStyle("Woof"), "WOOF");
ASSERT_EQ(ToRustConstantStyle("woof"), "WOOF");
ASSERT_EQ(ToRustConstantStyle("woof_Woof"), "WOOF_WOOF");
EXPECT_THROW(ToRustConstantStyle(""), InternalError);
}

TEST(NameTransforms, PrefixName) {
ASSERT_EQ(PrefixName({"Woof"}), "TVM_Woof");
ASSERT_EQ(PrefixName({"woof"}), "TVM_woof");
Expand Down Expand Up @@ -94,3 +120,23 @@ TEST(NameTransforms, CombinedLogic) {
ASSERT_EQ(ToCVariableStyle(PrefixName({"Device", "target", "t"})), "tvm_device_target_t");
ASSERT_EQ(ToCVariableStyle(PrefixGeneratedName({"model", "Devices"})), "tvmgen_model_devices");
}

TEST(NameTransforms, Internal_ToCamel) {
ASSERT_EQ(ToCamel("Woof"), "Woof");
ASSERT_EQ(ToCamel("woof"), "Woof");
ASSERT_EQ(ToCamel("woof_woof"), "WoofWoof");
}

TEST(NameTransforms, Internal_ToCamel_Allocation) {
std::string woof = "Woof_woof_woof_woof";
std::string camel = ToCamel(woof);
std::string check;
check.reserve(woof.size());

// Check that the pre-allocation happens
ASSERT_EQ(camel.capacity(), check.capacity());
}

} // namespace backend
} // namespace relay
} // namespace tvm