Skip to content

Commit cc24d4b

Browse files
committed
U3, WIP
Change-Id: Ibc088f19ad1dc9466fc368f8523baa30ee88b7d0
1 parent c80da03 commit cc24d4b

File tree

7 files changed

+311
-45
lines changed

7 files changed

+311
-45
lines changed

include/tvm/ir/memory_pools.h

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -220,6 +220,14 @@ class PoolInfoProperties : public ObjectRef {
220220

221221
/* \brief Represents RW memory area */
222222
struct WorkspacePoolInfoNode : public PoolInfoNode {
223+
void VisitAttrs(tvm::AttrVisitor* v) { PoolInfoNode::VisitAttrs(v); }
224+
225+
bool SEqualReduce(const WorkspacePoolInfoNode* other, SEqualReducer equal) const {
226+
return PoolInfoNode::SEqualReduce(other, equal);
227+
}
228+
229+
void SHashReduce(SHashReducer hash_reduce) const { PoolInfoNode::SHashReduce(hash_reduce); }
230+
223231
static constexpr const char* _type_key = "ir.WorkspacePoolInfo";
224232
TVM_DECLARE_FINAL_OBJECT_INFO(WorkspacePoolInfoNode, PoolInfoNode);
225233
};
@@ -275,6 +283,22 @@ class ConstantInfo : public ObjectRef {
275283
* data from constant_info_array */
276284
struct ConstantPoolInfoNode : public PoolInfoNode {
277285
Array<ConstantInfo> constant_info_array;
286+
287+
void VisitAttrs(tvm::AttrVisitor* v) {
288+
PoolInfoNode::VisitAttrs(v);
289+
v->Visit("constant_info_array", &constant_info_array);
290+
}
291+
292+
bool SEqualReduce(const ConstantPoolInfoNode* other, SEqualReducer equal) const {
293+
return PoolInfoNode::SEqualReduce(other, equal) &&
294+
equal(constant_info_array, other->constant_info_array);
295+
}
296+
297+
void SHashReduce(SHashReducer hash_reduce) const {
298+
PoolInfoNode::SHashReduce(hash_reduce);
299+
hash_reduce(constant_info_array);
300+
}
301+
278302
static constexpr const char* _type_key = "ir.ConstantPoolInfo";
279303
TVM_DECLARE_FINAL_OBJECT_INFO(ConstantPoolInfoNode, PoolInfoNode);
280304
};

python/tvm/testing/aot.py

Lines changed: 18 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -286,8 +286,11 @@ def _emit_main_workspace_pool_structs(main_file, workspace_pool_names, mod_name)
286286
f"struct {_mangle_name(mod_name, 'workspace_pools')} "
287287
f"{_mangle_name(mod_name, 'workspace_pools')} = {{"
288288
)
289-
for workspace_pool_name in workspace_pool_names:
290-
main_file.write(f"\t.{workspace_pool_name} = {workspace_pool_name},\n")
289+
for workspace_pool_name in workspace_pool_names.keys():
290+
main_file.write(
291+
f"\t.{workspace_pool_name} = {workspace_pool_names[workspace_pool_name]}"
292+
f"{workspace_pool_name},\n"
293+
)
291294
main_file.write("};\n")
292295

293296

@@ -507,19 +510,27 @@ def _create_main(
507510
compiled_model.executor_factory.executor_codegen_metadata
508511
)
509512
devices = compiled_model.executor_factory.get_devices()
510-
workspace_pool_names = None
513+
workspace_pool_names = {}
511514
if executor_codegen_metadata.pool_inputs:
512-
workspace_pool_names = [
513-
allocated_pool.pool_info.pool_name
515+
workspace_pool_names = {
516+
allocated_pool.pool_info.pool_name: "&"
517+
if isinstance(
518+
allocated_pool.pool_info, tvm.ir.memory_pools.ConstantPoolInfo
519+
)
520+
else ""
514521
for allocated_pool in dict(executor_codegen_metadata.pool_inputs).values()
515522
if not allocated_pool.pool_info.is_internal
516-
]
523+
}
517524
_emit_main_device_structs(main_file, devices, model.name)
518525
if not use_workspace_io:
519526
_emit_main_workspace_pool_structs(main_file, workspace_pool_names, model.name)
520527
_emit_main_data_structs(main_file, model.inputs, model.outputs, model.name)
521528
_emit_main_c_interface_call(
522-
main_file, devices, workspace_pool_names, model.name, use_workspace_io
529+
main_file,
530+
devices,
531+
list(workspace_pool_names.keys()),
532+
model.name,
533+
use_workspace_io,
523534
)
524535
else:
525536
_emit_main_fake_packed_values(main_file)

src/target/source/codegen_params.cc

Lines changed: 22 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,8 @@ static int ComputeNumElementsPerRow(int one_element_size_bytes, int indent_chars
5353
}
5454

5555
template <typename T, typename Enable = std::enable_if<std::is_integral<T>::value>>
56-
void PrintIntegralArray(void* data, size_t num_elements, int indent_chars, std::ostream& os) {
56+
void PrintIntegralArray(void* data, size_t num_elements, int indent_chars, std::ostream& os,
57+
const std::string& eol) {
5758
int one_element_size_bytes = (sizeof(T) / 4) + (2 /* "0x" */) + (2 /* ", " */);
5859
if (std::is_signed<T>::value) {
5960
one_element_size_bytes += 1; // sign character
@@ -97,17 +98,18 @@ void PrintIntegralArray(void* data, size_t num_elements, int indent_chars, std::
9798
os << ", ";
9899
}
99100
if ((i % elements_per_row) == elements_per_row - 1) {
100-
os << "\n";
101+
os << eol;
101102
}
102103
}
103104

104105
if ((num_elements % elements_per_row) != 0) {
105-
os << "\n";
106+
os << eol;
106107
}
107108
}
108109

109110
template <typename T, typename Enable = std::enable_if<std::is_floating_point<T>::value>>
110-
void PrintFloatingPointArray(void* data, size_t num_elements, int indent_chars, std::ostream& os) {
111+
void PrintFloatingPointArray(void* data, size_t num_elements, int indent_chars, std::ostream& os,
112+
const std::string& eol) {
111113
// Floats and doubles are printed as hex but casted.
112114
int one_element_size_bytes = (sizeof(T) / 4) + (2 /* "0x" */) + (2 /* ", " */) + 1 /* sign */ +
113115
1 /* decimal point */ + 1 /* exponent sign */;
@@ -149,16 +151,17 @@ void PrintFloatingPointArray(void* data, size_t num_elements, int indent_chars,
149151
os << ", ";
150152
}
151153
if ((i % elements_per_row) == elements_per_row - 1) {
152-
os << "\n";
154+
os << eol;
153155
}
154156
}
155157

156158
if ((num_elements % elements_per_row) != 0) {
157-
os << "\n";
159+
os << eol;
158160
}
159161
}
160162

161-
void NDArrayDataToC(::tvm::runtime::NDArray arr, int indent_chars, std::ostream& os) {
163+
void NDArrayDataToC(::tvm::runtime::NDArray arr, int indent_chars, std::ostream& os,
164+
const std::string& eol) {
162165
auto arr_type = arr.DataType();
163166
CHECK_EQ(arr_type.lanes(), 1) << "CodegenParams: only support generating 1-lane parameters; saw "
164167
<< arr_type.lanes();
@@ -180,13 +183,13 @@ void NDArrayDataToC(::tvm::runtime::NDArray arr, int indent_chars, std::ostream&
180183
<< "CodegenParams: only support generating 8-, 16-, 32-, or 64-bit integer params; saw "
181184
<< arr_type.bits() << "-bit array";
182185
if (arr_type.bits() == 8) {
183-
PrintIntegralArray<int8_t>(arr->data, num_elements, indent_chars, os);
186+
PrintIntegralArray<int8_t>(arr->data, num_elements, indent_chars, os, eol);
184187
} else if (arr_type.bits() == 16) {
185-
PrintIntegralArray<int16_t>(arr->data, num_elements, indent_chars, os);
188+
PrintIntegralArray<int16_t>(arr->data, num_elements, indent_chars, os, eol);
186189
} else if (arr_type.bits() == 32) {
187-
PrintIntegralArray<int32_t>(arr->data, num_elements, indent_chars, os);
190+
PrintIntegralArray<int32_t>(arr->data, num_elements, indent_chars, os, eol);
188191
} else if (arr_type.bits() == 64) {
189-
PrintIntegralArray<int64_t>(arr->data, num_elements, indent_chars, os);
192+
PrintIntegralArray<int64_t>(arr->data, num_elements, indent_chars, os, eol);
190193
} else {
191194
CHECK(false) << "should not get here";
192195
}
@@ -199,13 +202,13 @@ void NDArrayDataToC(::tvm::runtime::NDArray arr, int indent_chars, std::ostream&
199202
<< arr_type.bits() << "-bit array";
200203

201204
if (arr_type.bits() == 8) {
202-
PrintIntegralArray<uint8_t>(arr->data, num_elements, indent_chars, os);
205+
PrintIntegralArray<uint8_t>(arr->data, num_elements, indent_chars, os, eol);
203206
} else if (arr_type.bits() == 16) {
204-
PrintIntegralArray<uint16_t>(arr->data, num_elements, indent_chars, os);
207+
PrintIntegralArray<uint16_t>(arr->data, num_elements, indent_chars, os, eol);
205208
} else if (arr_type.bits() == 32) {
206-
PrintIntegralArray<uint32_t>(arr->data, num_elements, indent_chars, os);
209+
PrintIntegralArray<uint32_t>(arr->data, num_elements, indent_chars, os, eol);
207210
} else if (arr_type.bits() == 64) {
208-
PrintIntegralArray<uint64_t>(arr->data, num_elements, indent_chars, os);
211+
PrintIntegralArray<uint64_t>(arr->data, num_elements, indent_chars, os, eol);
209212
} else {
210213
CHECK(false) << "should not get here";
211214
}
@@ -216,11 +219,11 @@ void NDArrayDataToC(::tvm::runtime::NDArray arr, int indent_chars, std::ostream&
216219
os.setf(std::ios::left, std::ios::adjustfield);
217220
if (arr_type.bits() == 16) {
218221
// NOTE: print types not widely supported by C as uint16_t.
219-
PrintIntegralArray<uint16_t>(arr->data, num_elements, indent_chars, os);
222+
PrintIntegralArray<uint16_t>(arr->data, num_elements, indent_chars, os, eol);
220223
} else if (arr_type.bits() == 32) {
221-
PrintFloatingPointArray<float>(arr->data, num_elements, indent_chars, os);
224+
PrintFloatingPointArray<float>(arr->data, num_elements, indent_chars, os, eol);
222225
} else if (arr_type.bits() == 64) {
223-
PrintFloatingPointArray<double>(arr->data, num_elements, indent_chars, os);
226+
PrintFloatingPointArray<double>(arr->data, num_elements, indent_chars, os, eol);
224227
} else {
225228
CHECK(false) << "CodegenParams: only support 32- or 64-bit floating point; saw "
226229
<< arr_type.bits() << "-bit array";
@@ -233,7 +236,7 @@ void NDArrayDataToC(::tvm::runtime::NDArray arr, int indent_chars, std::ostream&
233236
CHECK(arr_type.bits() == 16)
234237
<< "CodegenParams: only support generating 16-bit bfloat params; saw " << arr_type.bits()
235238
<< "-bit array";
236-
PrintIntegralArray<uint16_t>(arr->data, num_elements, indent_chars, os);
239+
PrintIntegralArray<uint16_t>(arr->data, num_elements, indent_chars, os, eol);
237240
break;
238241
}
239242

src/target/source/codegen_params.h

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
#include <tvm/runtime/ndarray.h>
2828

2929
#include <iostream>
30+
#include <string>
3031

3132
namespace tvm {
3233
namespace codegen {
@@ -44,7 +45,8 @@ namespace codegen {
4445
* \param indent_chars Number of chars to indent
4546
* \param os Output stream where the array data should be written.
4647
*/
47-
void NDArrayDataToC(::tvm::runtime::NDArray arr, int indent_chars, std::ostream& os);
48+
void NDArrayDataToC(::tvm::runtime::NDArray arr, int indent_chars, std::ostream& os,
49+
const std::string& eol = "\n");
4850

4951
} // namespace codegen
5052
} // namespace tvm

src/target/source/interface_c.cc

Lines changed: 48 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,9 +29,11 @@
2929
#include <tvm/runtime/registry.h>
3030
#include <tvm/tir/usmp/utils.h>
3131

32+
#include <numeric>
3233
#include <string>
3334

3435
#include "../../relay/backend/name_transforms.h"
36+
#include "codegen_params.h"
3537

3638
namespace tvm {
3739
namespace codegen {
@@ -90,8 +92,12 @@ class InterfaceCNode : public runtime::ModuleNode {
9092
for (const tir::usmp::AllocatedPoolInfo pool : pools_) {
9193
String pool_name = pool->pool_info->pool_name;
9294
Integer pool_size = pool->allocated_size;
93-
EmitIntegerValueMacro(code, SanitizeName(pool_name) + " size",
94-
SanitizeName(pool_name) + "_WORKSPACE_POOL_SIZE", pool_size->value);
95+
if (const auto* pool_info = pool->pool_info.as<ConstantPoolInfoNode>()) {
96+
EmitConstantPool(code, SanitizeName(pool_name) + " initialization data", pool_info);
97+
} else {
98+
EmitIntegerValueMacro(code, SanitizeName(pool_name) + " size",
99+
SanitizeName(pool_name) + _macro_pool_size_postfix, pool_size->value);
100+
}
95101
}
96102
EmitLowerHeaderGuard(code);
97103

@@ -103,6 +109,9 @@ class InterfaceCNode : public runtime::ModuleNode {
103109
}
104110

105111
private:
112+
constexpr static const char* _macro_pool_size_postfix = "_POOL_SIZE_BYTES";
113+
constexpr static const char* _macro_pool_data_postfix = "_POOL_DATA";
114+
106115
void EmitUpperHeaderGuard(std::stringstream& code_stream) {
107116
std::string header_guard_name = ToCConstantStyle(PrefixGeneratedName({module_name_, "H"}));
108117
code_stream << "#ifndef " << header_guard_name << "_\n"
@@ -152,6 +161,43 @@ class InterfaceCNode : public runtime::ModuleNode {
152161
code_stream << "#define " << macro_name_prefixed << " " << macro_value << "\n";
153162
}
154163

164+
void EmitConstantPool(std::stringstream& code_, const std::string& brief_description,
165+
const ConstantPoolInfoNode* pool_info) {
166+
EmitBrief(code_, brief_description);
167+
std::string name_prefixed =
168+
ToCConstantStyle(PrefixGeneratedName({module_name_, SanitizeName(pool_info->pool_name)}));
169+
170+
if (pool_info->constant_info_array.size() > 0) {
171+
std::vector<ConstantInfo> const_info_vec(pool_info->constant_info_array.begin(),
172+
pool_info->constant_info_array.end());
173+
std::sort(const_info_vec.begin(), const_info_vec.end(),
174+
[](const ConstantInfo& a, const ConstantInfo& b) {
175+
return a->byte_offset->value < b->byte_offset->value;
176+
});
177+
int64_t accumulated_pool_len =
178+
const_info_vec.back()->byte_offset +
179+
runtime::GetDataSize(*const_info_vec.back()->data.operator->());
180+
const auto& accumulated_pool = runtime::NDArray::Empty(
181+
{accumulated_pool_len}, DataType::UInt(8), const_info_vec.back()->data->device);
182+
for (const auto& const_info : const_info_vec) {
183+
const auto& data = const_info->data;
184+
const auto& offs = const_info->byte_offset;
185+
data.CopyToBytes(static_cast<uint8_t*>(accumulated_pool->data) + offs,
186+
runtime::GetDataSize(*data.operator->()));
187+
}
188+
189+
code_ << "#define " << name_prefixed << _macro_pool_size_postfix << " "
190+
<< accumulated_pool_len << "\n";
191+
code_ << "#define " << name_prefixed << _macro_pool_data_postfix << " \\\n";
192+
codegen::NDArrayDataToC(accumulated_pool, 4, code_, "\\\n");
193+
code_ << '\n';
194+
195+
} else {
196+
LOG(FATAL) << "No constant data in constant pool found "
197+
<< PrettyPrint(GetRef<ObjectRef>(pool_info));
198+
}
199+
}
200+
155201
void EmitRunFunction(std::stringstream& code_stream) {
156202
std::string run_function = ToCVariableStyle(PrefixGeneratedName({module_name_, "run"}));
157203
std::string inputs_struct = ToCVariableStyle(PrefixGeneratedName({module_name_, "inputs"}));

tests/cpp/target/source/interface_c_test.cc

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -409,7 +409,7 @@ TEST(InterfaceAPI, ContainsWorkspacePoolStructSingle) {
409409

410410
ASSERT_THAT(
411411
header_source,
412-
HasSubstr("#define TVMGEN_ULTIMATE_CAT_SPOTTER_MY_MEMORY_POOL_WORKSPACE_POOL_SIZE 100000"));
412+
HasSubstr("#define TVMGEN_ULTIMATE_CAT_SPOTTER_MY_MEMORY_POOL_POOL_SIZE_BYTES 100000"));
413413
}
414414

415415
TEST(InterfaceAPI, ContainsWorkspacePoolStructMany) {
@@ -443,14 +443,14 @@ TEST(InterfaceAPI, ContainsWorkspacePoolStructMany) {
443443

444444
ASSERT_THAT(
445445
header_source,
446-
HasSubstr("#define TVMGEN_ULTIMATE_CAT_SPOTTER_MY_MEMORY_POOL_1_WORKSPACE_POOL_SIZE 100000"));
446+
HasSubstr("#define TVMGEN_ULTIMATE_CAT_SPOTTER_MY_MEMORY_POOL_1_POOL_SIZE_BYTES 100000"));
447447

448448
ASSERT_THAT(header_source,
449449
HasSubstr("* \\brief my_memory_pool_2 size for TVM module \"ultimate_cat_spotter\""));
450450

451451
ASSERT_THAT(
452452
header_source,
453-
HasSubstr("#define TVMGEN_ULTIMATE_CAT_SPOTTER_MY_MEMORY_POOL_2_WORKSPACE_POOL_SIZE 200000"));
453+
HasSubstr("#define TVMGEN_ULTIMATE_CAT_SPOTTER_MY_MEMORY_POOL_2_POOL_SIZE_BYTES 200000"));
454454
}
455455

456456
TEST(InterfaceAPI, ContainsWorkspacePoolStructSanitized) {
@@ -479,7 +479,7 @@ TEST(InterfaceAPI, ContainsWorkspacePoolStructSanitized) {
479479

480480
ASSERT_THAT(
481481
header_source,
482-
HasSubstr("#define TVMGEN_ULTIMATE_CAT_SPOTTER_MY_MEMORY_POOL_1_WORKSPACE_POOL_SIZE 100000"));
482+
HasSubstr("#define TVMGEN_ULTIMATE_CAT_SPOTTER_MY_MEMORY_POOL_1_POOL_SIZE_BYTES 100000"));
483483
}
484484

485485
TEST(InterfaceAPI, ContainsWorkspacePoolStructClash) {

0 commit comments

Comments
 (0)