Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 24 additions & 0 deletions include/tvm/ir/memory_pools.h
Original file line number Diff line number Diff line change
Expand Up @@ -220,6 +220,14 @@ class PoolInfoProperties : public ObjectRef {

/* \brief Represents RW memory area */
struct WorkspacePoolInfoNode : public PoolInfoNode {
void VisitAttrs(tvm::AttrVisitor* v) { PoolInfoNode::VisitAttrs(v); }

bool SEqualReduce(const WorkspacePoolInfoNode* other, SEqualReducer equal) const {
return PoolInfoNode::SEqualReduce(other, equal);
}

void SHashReduce(SHashReducer hash_reduce) const { PoolInfoNode::SHashReduce(hash_reduce); }

static constexpr const char* _type_key = "ir.WorkspacePoolInfo";
TVM_DECLARE_FINAL_OBJECT_INFO(WorkspacePoolInfoNode, PoolInfoNode);
};
Expand Down Expand Up @@ -275,6 +283,22 @@ class ConstantInfo : public ObjectRef {
* data from constant_info_array */
struct ConstantPoolInfoNode : public PoolInfoNode {
Array<ConstantInfo> constant_info_array;

void VisitAttrs(tvm::AttrVisitor* v) {
PoolInfoNode::VisitAttrs(v);
v->Visit("constant_info_array", &constant_info_array);
}

bool SEqualReduce(const ConstantPoolInfoNode* other, SEqualReducer equal) const {
return PoolInfoNode::SEqualReduce(other, equal) &&
equal(constant_info_array, other->constant_info_array);
}

void SHashReduce(SHashReducer hash_reduce) const {
PoolInfoNode::SHashReduce(hash_reduce);
hash_reduce(constant_info_array);
}

static constexpr const char* _type_key = "ir.ConstantPoolInfo";
TVM_DECLARE_FINAL_OBJECT_INFO(ConstantPoolInfoNode, PoolInfoNode);
};
Expand Down
25 changes: 18 additions & 7 deletions python/tvm/testing/aot.py
Original file line number Diff line number Diff line change
Expand Up @@ -286,8 +286,11 @@ def _emit_main_workspace_pool_structs(main_file, workspace_pool_names, mod_name)
f"struct {_mangle_name(mod_name, 'workspace_pools')} "
f"{_mangle_name(mod_name, 'workspace_pools')} = {{"
)
for workspace_pool_name in workspace_pool_names:
main_file.write(f"\t.{workspace_pool_name} = {workspace_pool_name},\n")
for workspace_pool_name in workspace_pool_names.keys():
main_file.write(
f"\t.{workspace_pool_name} = {workspace_pool_names[workspace_pool_name]}"
f"{workspace_pool_name},\n"
)
main_file.write("};\n")


Expand Down Expand Up @@ -507,19 +510,27 @@ def _create_main(
compiled_model.executor_factory.executor_codegen_metadata
)
devices = compiled_model.executor_factory.get_devices()
workspace_pool_names = None
workspace_pool_names = {}
if executor_codegen_metadata.pool_inputs:
workspace_pool_names = [
allocated_pool.pool_info.pool_name
workspace_pool_names = {
allocated_pool.pool_info.pool_name: "&"
if isinstance(
allocated_pool.pool_info, tvm.ir.memory_pools.ConstantPoolInfo
)
else ""
for allocated_pool in dict(executor_codegen_metadata.pool_inputs).values()
if not allocated_pool.pool_info.is_internal
]
}
_emit_main_device_structs(main_file, devices, model.name)
if not use_workspace_io:
_emit_main_workspace_pool_structs(main_file, workspace_pool_names, model.name)
_emit_main_data_structs(main_file, model.inputs, model.outputs, model.name)
_emit_main_c_interface_call(
main_file, devices, workspace_pool_names, model.name, use_workspace_io
main_file,
devices,
list(workspace_pool_names.keys()),
model.name,
use_workspace_io,
)
else:
_emit_main_fake_packed_values(main_file)
Expand Down
2 changes: 0 additions & 2 deletions src/runtime/crt/microtvm_rpc_server/rpc_server.cc
Original file line number Diff line number Diff line change
Expand Up @@ -119,8 +119,6 @@ class MicroRPCServer {
rpc_server_{&io_},
is_running_{true} {}

void* operator new(size_t count, void* ptr) { return ptr; }

void Initialize() {
uint8_t initial_session_nonce = Session::kInvalidNonce;
tvm_crt_error_t error =
Expand Down
41 changes: 22 additions & 19 deletions src/target/source/codegen_params.cc
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,8 @@ static int ComputeNumElementsPerRow(int one_element_size_bytes, int indent_chars
}

template <typename T, typename Enable = std::enable_if<std::is_integral<T>::value>>
void PrintIntegralArray(void* data, size_t num_elements, int indent_chars, std::ostream& os) {
void PrintIntegralArray(void* data, size_t num_elements, int indent_chars, std::ostream& os,
const std::string& eol) {
int one_element_size_bytes = (sizeof(T) / 4) + (2 /* "0x" */) + (2 /* ", " */);
if (std::is_signed<T>::value) {
one_element_size_bytes += 1; // sign character
Expand Down Expand Up @@ -97,17 +98,18 @@ void PrintIntegralArray(void* data, size_t num_elements, int indent_chars, std::
os << ", ";
}
if ((i % elements_per_row) == elements_per_row - 1) {
os << "\n";
os << eol;
}
}

if ((num_elements % elements_per_row) != 0) {
os << "\n";
os << eol;
}
}

template <typename T, typename Enable = std::enable_if<std::is_floating_point<T>::value>>
void PrintFloatingPointArray(void* data, size_t num_elements, int indent_chars, std::ostream& os) {
void PrintFloatingPointArray(void* data, size_t num_elements, int indent_chars, std::ostream& os,
const std::string& eol) {
// Floats and doubles are printed as hex but casted.
int one_element_size_bytes = (sizeof(T) / 4) + (2 /* "0x" */) + (2 /* ", " */) + 1 /* sign */ +
1 /* decimal point */ + 1 /* exponent sign */;
Expand Down Expand Up @@ -149,16 +151,17 @@ void PrintFloatingPointArray(void* data, size_t num_elements, int indent_chars,
os << ", ";
}
if ((i % elements_per_row) == elements_per_row - 1) {
os << "\n";
os << eol;
}
}

if ((num_elements % elements_per_row) != 0) {
os << "\n";
os << eol;
}
}

void NDArrayDataToC(::tvm::runtime::NDArray arr, int indent_chars, std::ostream& os) {
void NDArrayDataToC(::tvm::runtime::NDArray arr, int indent_chars, std::ostream& os,
const std::string& eol) {
auto arr_type = arr.DataType();
CHECK_EQ(arr_type.lanes(), 1) << "CodegenParams: only support generating 1-lane parameters; saw "
<< arr_type.lanes();
Expand All @@ -180,13 +183,13 @@ void NDArrayDataToC(::tvm::runtime::NDArray arr, int indent_chars, std::ostream&
<< "CodegenParams: only support generating 8-, 16-, 32-, or 64-bit integer params; saw "
<< arr_type.bits() << "-bit array";
if (arr_type.bits() == 8) {
PrintIntegralArray<int8_t>(arr->data, num_elements, indent_chars, os);
PrintIntegralArray<int8_t>(arr->data, num_elements, indent_chars, os, eol);
} else if (arr_type.bits() == 16) {
PrintIntegralArray<int16_t>(arr->data, num_elements, indent_chars, os);
PrintIntegralArray<int16_t>(arr->data, num_elements, indent_chars, os, eol);
} else if (arr_type.bits() == 32) {
PrintIntegralArray<int32_t>(arr->data, num_elements, indent_chars, os);
PrintIntegralArray<int32_t>(arr->data, num_elements, indent_chars, os, eol);
} else if (arr_type.bits() == 64) {
PrintIntegralArray<int64_t>(arr->data, num_elements, indent_chars, os);
PrintIntegralArray<int64_t>(arr->data, num_elements, indent_chars, os, eol);
} else {
CHECK(false) << "should not get here";
}
Expand All @@ -199,13 +202,13 @@ void NDArrayDataToC(::tvm::runtime::NDArray arr, int indent_chars, std::ostream&
<< arr_type.bits() << "-bit array";

if (arr_type.bits() == 8) {
PrintIntegralArray<uint8_t>(arr->data, num_elements, indent_chars, os);
PrintIntegralArray<uint8_t>(arr->data, num_elements, indent_chars, os, eol);
} else if (arr_type.bits() == 16) {
PrintIntegralArray<uint16_t>(arr->data, num_elements, indent_chars, os);
PrintIntegralArray<uint16_t>(arr->data, num_elements, indent_chars, os, eol);
} else if (arr_type.bits() == 32) {
PrintIntegralArray<uint32_t>(arr->data, num_elements, indent_chars, os);
PrintIntegralArray<uint32_t>(arr->data, num_elements, indent_chars, os, eol);
} else if (arr_type.bits() == 64) {
PrintIntegralArray<uint64_t>(arr->data, num_elements, indent_chars, os);
PrintIntegralArray<uint64_t>(arr->data, num_elements, indent_chars, os, eol);
} else {
CHECK(false) << "should not get here";
}
Expand All @@ -216,11 +219,11 @@ void NDArrayDataToC(::tvm::runtime::NDArray arr, int indent_chars, std::ostream&
os.setf(std::ios::left, std::ios::adjustfield);
if (arr_type.bits() == 16) {
// NOTE: print types not widely supported by C as uint16_t.
PrintIntegralArray<uint16_t>(arr->data, num_elements, indent_chars, os);
PrintIntegralArray<uint16_t>(arr->data, num_elements, indent_chars, os, eol);
} else if (arr_type.bits() == 32) {
PrintFloatingPointArray<float>(arr->data, num_elements, indent_chars, os);
PrintFloatingPointArray<float>(arr->data, num_elements, indent_chars, os, eol);
} else if (arr_type.bits() == 64) {
PrintFloatingPointArray<double>(arr->data, num_elements, indent_chars, os);
PrintFloatingPointArray<double>(arr->data, num_elements, indent_chars, os, eol);
} else {
CHECK(false) << "CodegenParams: only support 32- or 64-bit floating point; saw "
<< arr_type.bits() << "-bit array";
Expand All @@ -233,7 +236,7 @@ void NDArrayDataToC(::tvm::runtime::NDArray arr, int indent_chars, std::ostream&
CHECK(arr_type.bits() == 16)
<< "CodegenParams: only support generating 16-bit bfloat params; saw " << arr_type.bits()
<< "-bit array";
PrintIntegralArray<uint16_t>(arr->data, num_elements, indent_chars, os);
PrintIntegralArray<uint16_t>(arr->data, num_elements, indent_chars, os, eol);
break;
}

Expand Down
4 changes: 3 additions & 1 deletion src/target/source/codegen_params.h
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
#include <tvm/runtime/ndarray.h>

#include <iostream>
#include <string>

namespace tvm {
namespace codegen {
Expand All @@ -44,7 +45,8 @@ namespace codegen {
* \param indent_chars Number of chars to indent
* \param os Output stream where the array data should be written.
*/
void NDArrayDataToC(::tvm::runtime::NDArray arr, int indent_chars, std::ostream& os);
void NDArrayDataToC(::tvm::runtime::NDArray arr, int indent_chars, std::ostream& os,
const std::string& eol = "\n");

} // namespace codegen
} // namespace tvm
Expand Down
52 changes: 50 additions & 2 deletions src/target/source/interface_c.cc
Original file line number Diff line number Diff line change
Expand Up @@ -29,9 +29,11 @@
#include <tvm/runtime/registry.h>
#include <tvm/tir/usmp/utils.h>

#include <numeric>
#include <string>

#include "../../relay/backend/name_transforms.h"
#include "codegen_params.h"

namespace tvm {
namespace codegen {
Expand Down Expand Up @@ -90,8 +92,13 @@ class InterfaceCNode : public runtime::ModuleNode {
for (const tir::usmp::AllocatedPoolInfo pool : pools_) {
String pool_name = pool->pool_info->pool_name;
Integer pool_size = pool->allocated_size;
EmitIntegerValueMacro(code, SanitizeName(pool_name) + " size",
SanitizeName(pool_name) + "_WORKSPACE_POOL_SIZE", pool_size->value);
if (const auto* pool_info = pool->pool_info.as<ConstantPoolInfoNode>()) {
EmitConstantPool(code, SanitizeName(pool_name) + " initialization data", pool_info);
} else {
EmitIntegerValueMacro(code, SanitizeName(pool_name) + " size",
SanitizeName(pool_name) + _macro_workspace_pool_size_postfix,
pool_size->value);
}
}
EmitLowerHeaderGuard(code);

Expand All @@ -103,6 +110,10 @@ class InterfaceCNode : public runtime::ModuleNode {
}

private:
constexpr static const char* _macro_workspace_pool_size_postfix = "_WORKSPACE_POOL_SIZE";
constexpr static const char* _macro_constant_pool_size_postfix = "_CONSTANT_POOL_SIZE";
constexpr static const char* _macro_constant_pool_data_postfix = "_CONSTANT_POOL_DATA";

void EmitUpperHeaderGuard(std::stringstream& code_stream) {
std::string header_guard_name = ToCConstantStyle(PrefixGeneratedName({module_name_, "H"}));
code_stream << "#ifndef " << header_guard_name << "_\n"
Expand Down Expand Up @@ -152,6 +163,43 @@ class InterfaceCNode : public runtime::ModuleNode {
code_stream << "#define " << macro_name_prefixed << " " << macro_value << "\n";
}

void EmitConstantPool(std::stringstream& code_, const std::string& brief_description,
const ConstantPoolInfoNode* pool_info) {
EmitBrief(code_, brief_description);
std::string name_prefixed =
ToCConstantStyle(PrefixGeneratedName({module_name_, SanitizeName(pool_info->pool_name)}));

if (pool_info->constant_info_array.size() > 0) {
std::vector<ConstantInfo> const_info_vec(pool_info->constant_info_array.begin(),
pool_info->constant_info_array.end());
std::sort(const_info_vec.begin(), const_info_vec.end(),
[](const ConstantInfo& a, const ConstantInfo& b) {
return a->byte_offset->value < b->byte_offset->value;
});
int64_t accumulated_pool_len =
const_info_vec.back()->byte_offset +
runtime::GetDataSize(*const_info_vec.back()->data.operator->());
const auto& accumulated_pool = runtime::NDArray::Empty(
{accumulated_pool_len}, DataType::UInt(8), const_info_vec.back()->data->device);
for (const auto& const_info : const_info_vec) {
const auto& data = const_info->data;
const auto& offs = const_info->byte_offset;
data.CopyToBytes(static_cast<uint8_t*>(accumulated_pool->data) + offs,
runtime::GetDataSize(*data.operator->()));
}

code_ << "#define " << name_prefixed << _macro_constant_pool_size_postfix << " "
<< accumulated_pool_len << "\n";
code_ << "#define " << name_prefixed << _macro_constant_pool_data_postfix << " \\\n";
codegen::NDArrayDataToC(accumulated_pool, 4, code_, "\\\n");
code_ << '\n';

} else {
LOG(FATAL) << "No constant data in constant pool found "
<< PrettyPrint(GetRef<ObjectRef>(pool_info));
}
}

void EmitRunFunction(std::stringstream& code_stream) {
std::string run_function = ToCVariableStyle(PrefixGeneratedName({module_name_, "run"}));
std::string inputs_struct = ToCVariableStyle(PrefixGeneratedName({module_name_, "inputs"}));
Expand Down
8 changes: 4 additions & 4 deletions src/target/source/source_module.cc
Original file line number Diff line number Diff line change
Expand Up @@ -292,7 +292,7 @@ class CSourceCrtMetadataModuleNode : public runtime::ModuleNode {
}

void GenerateConstantBuffer(const ConstantPoolInfoNode* pool_info, size_t allocated_size) {
size_t offset = 0;
size_t ord = 0;
if (pool_info->constant_info_array.size() > 0) {
// Pool is RO, form an initialized struct
code_ << "__attribute__((section(\".rodata.tvm\"), ";
Expand All @@ -312,8 +312,8 @@ class CSourceCrtMetadataModuleNode : public runtime::ModuleNode {
std::multiplies<int64_t>());
code_ << " ";
codegen_c_base_.PrintType(data.DataType(), code_);
code_ << " " << const_info->name_hint << "[" << num_elements
<< "] __attribute__((packed, aligned(" << metadata_->constant_alignment << ")));";
code_ << " " << const_info->name_hint << "[" << num_elements << "] __attribute__(("
<< (ord++ ? "packed, " : "") << "aligned(" << metadata_->constant_alignment << ")));";
code_ << " // " << num_elements * data.DataType().bytes()
<< " bytes, aligned offset: " << offs << "\n";
}
Expand All @@ -326,7 +326,7 @@ class CSourceCrtMetadataModuleNode : public runtime::ModuleNode {
code_ << " },\n";
}
code_ << "};";
code_ << "// of total size " << allocated_size << " bytes, aligned: " << offset << " bytes\n";
code_ << "// of total size " << allocated_size << " bytes\n";
} else {
LOG(FATAL) << "No constant data in constant pool found "
<< PrettyPrint(GetRef<ObjectRef>(pool_info));
Expand Down
43 changes: 43 additions & 0 deletions tests/cpp/target/source/interface_c_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
#include <tvm/runtime/module.h>
#include <tvm/tir/usmp/utils.h>

using ::testing::ContainsRegex;
using ::testing::HasSubstr;

namespace tvm {
Expand Down Expand Up @@ -126,6 +127,48 @@ TEST(InterfaceAPI, ContainsRunFunctionWithWorkspacePools) {
ASSERT_THAT(header_source, HasSubstr(run_function.str()));
}

TEST(InterfaceAPI, ContainsRunFunctionWithWorkspaceAndConstantPools) {
std::stringstream run_function;

run_function << "/*!\n"
<< " * \\brief entrypoint function for TVM module \"ultimate_cat_spotter\"\n"
<< " * \\param inputs Input tensors for the module \n"
<< " * \\param outputs Output tensors for the module \n"
<< " * \\param workspace_pools Workspace memory pool pointers for the module \n"
<< " */\n"
<< "int32_t tvmgen_ultimate_cat_spotter_run(\n"
<< " struct tvmgen_ultimate_cat_spotter_inputs* inputs,\n"
<< " struct tvmgen_ultimate_cat_spotter_outputs* outputs,\n"
<< " struct tvmgen_ultimate_cat_spotter_workspace_pools* workspace_pools\n"
<< ");\n";

PoolInfo pool_info = WorkspacePoolInfo("my_memory_pool", {});
PoolInfo const_info = ConstantPoolInfo(
"my_constant_pool", {},
{{"const1", 0, runtime::NDArray::Empty({1}, DataType::Int(32), {kDLCPU, 0})},
{"const2", 16, runtime::NDArray::Empty({1}, DataType::Float(64), {kDLCPU, 0})}});
tir::usmp::AllocatedPoolInfo allocated_pool_info =
tir::usmp::AllocatedPoolInfo(pool_info, 100000);
tir::usmp::AllocatedPoolInfo allocated_const_info =
tir::usmp::AllocatedPoolInfo(const_info, 100000);
runtime::Module test_module =
InterfaceCCreate("ultimate_cat_spotter", {"input"}, {"output"},
{allocated_pool_info, allocated_const_info}, {}, {}, 0);
std::string header_source = test_module->GetSource();
ASSERT_THAT(header_source, HasSubstr(run_function.str()));
ASSERT_THAT(
header_source,
HasSubstr("#define TVMGEN_ULTIMATE_CAT_SPOTTER_MY_CONSTANT_POOL_CONSTANT_POOL_SIZE 24"));
ASSERT_THAT(
header_source,
ContainsRegex(
"#define TVMGEN_ULTIMATE_CAT_SPOTTER_MY_CONSTANT_POOL_CONSTANT_POOL_DATA \\\\\\\n "
"0x\\w\\w, 0x\\w\\w, 0x\\w\\w, 0x\\w\\w, 0x\\w\\w, 0x\\w\\w, 0x\\w\\w, 0x\\w\\w, "
"0x\\w\\w, 0x\\w\\w, 0x\\w\\w, 0x\\w\\w, 0x\\w\\w, "
"0x\\w\\w, 0x\\w\\w, 0x\\w\\w, \\\\\\\n 0x\\w\\w, 0x\\w\\w, 0x\\w\\w, 0x\\w\\w, "
"0x\\w\\w, 0x\\w\\w, 0x\\w\\w, 0x\\w\\w\\\\\\\n"));
}

TEST(InterfaceAPI, ContainsRunFunctionWithWorkspacePoolsAndDevices) {
std::stringstream run_function;

Expand Down
Loading