Skip to content

Commit c6c89c3

Browse files
authored
[Hexagon] Add concept of DMA groups (#14254)
* [Hexagon] Add concept of groups to QueuedRingBuffer * fix failing RingBuffer unit test * add StartGroup and EndGroup builtins and lowering * use max queue number to increase dma copy perf * increase max dma queues to fix test failures * use DMA groups in LowerAsyncDMA pass Co-authored-by: Noah Verke <[email protected]> Co-authored-by: Eric Lunderberg <[email protected]> * elide merge_async_commit_queue_scope * LowerAsyncDMA bug fix + disallow non-contig copy; comments, tests pass * format and lint * add comments to Hex User DMA header * use unsigned queue ID; fix test fails * format and lint * use dma_copy_dltensor in TIR tenor intrin; fix test fails * address feedback: comments, types, names
1 parent fc2a9e5 commit c6c89c3

File tree

19 files changed

+488
-199
lines changed

19 files changed

+488
-199
lines changed

include/tvm/meta_schedule/postproc.h

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -111,10 +111,9 @@ class Postproc : public runtime::ObjectRef {
111111
TVM_DLL static Postproc DisallowDynamicLoop();
112112
/*!
113113
* \brief Create a postprocessor that checks if all async mem copies are not strided.
114-
* \param merge_async_commit_queue_scope Whether or not to merge async commit queue scope.
115114
* \return The postprocessor created
116115
*/
117-
TVM_DLL static Postproc DisallowAsyncStridedMemCopy(bool merge_async_commit_queue_scope = true);
116+
TVM_DLL static Postproc DisallowAsyncStridedMemCopy();
118117
/*!
119118
* \brief Create a postprocessor that rewrites the cooperative fetch annotation to
120119
* actual vectorized cooperative fetching in loop bindings.

include/tvm/tir/builtin.h

Lines changed: 36 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -727,14 +727,49 @@ TVM_DLL const Op& texture2d_load();
727727

728728
/*!
729729
* \brief Initiate a non-blocking DMA copy from source to destination
730+
*
731+
* The copy is launched immediately.
732+
*
733+
* If a `dma_start_group()` call is active, the copy will be added
734+
* to the current group for tracking of in-flight group counts.
735+
*
736+
* If no `dma_start_group()` call is active, the copy will be tracked
737+
* individually i.e. as a group with size 1.
730738
*/
731739
TVM_DLL const Op& dma_copy();
732740

733741
/*!
734-
* \brief Wait until the number of DMAs in flight is less than or equal to some maximum
742+
* \brief Wait until the number of DMA groups in flight is less than
743+
* or equal to some maximum
744+
*
745+
* Calling `dma_wait()` while a group is active is unsupported.
735746
*/
736747
TVM_DLL const Op& dma_wait();
737748

749+
/*!
750+
* \brief Start a group of DMA copies
751+
*
752+
* Any call to `dma_copy()` that occurs after `dma_start_group()` will
753+
* be added to the current group for tracking of in-flight group counts.
754+
*
755+
* Only one DMA group may be active at a given time. Calling
756+
* `dma_start_group()` while a group is active is unsupported.
757+
*/
758+
TVM_DLL const Op& dma_start_group();
759+
760+
/*!
761+
* \brief End a group of DMA copies
762+
*
763+
* Track all calls to `dma_copy()` that occurred since the preceding
764+
* `dma_start_group()` as a single group in-flight.
765+
*
766+
* Calling `dma_end_group()` without an active group is unsupported.
767+
*
768+
* Note: A group of DMA calls may be empty, and will still contribute
769+
* to the count of in-flight groups used by `dma_wait()`.
770+
*/
771+
TVM_DLL const Op& dma_end_group();
772+
738773
/*!
739774
* \brief Provide a true statement that can be used for simplifications
740775
*

python/tvm/meta_schedule/postproc/disallow_async_strided_mem_copy.py

Lines changed: 2 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -23,16 +23,9 @@
2323

2424
@register_object("meta_schedule.DisallowAsyncStridedMemCopy")
2525
class DisallowAsyncStridedMemCopy(Postproc):
26-
"""A postprocessor that disallows schedules that use async strided mem copies.
26+
"""A postprocessor that disallows schedules that use async strided mem copies."""
2727

28-
Parameters
29-
----------
30-
merge_async_commit_queue_scope : bool
31-
Whether or not to merge the async commit queue scope.
32-
"""
33-
34-
def __init__(self, merge_async_commit_queue_scope=True) -> None:
28+
def __init__(self) -> None:
3529
self.__init_handle_by_constructor__(
3630
_ffi_api.PostprocDisallowAsyncStridedMemCopy, # type: ignore # pylint: disable=no-member
37-
merge_async_commit_queue_scope,
3831
)

python/tvm/tir/tensor_intrin/hexagon.py

Lines changed: 21 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -47,20 +47,27 @@ def sync_dma_load_impl(a: T.handle, c: T.handle) -> None:
4747
T.writes(C[0:size])
4848
T.evaluate(
4949
T.tvm_call_packed(
50-
"device_api.hexagon.dma_copy",
51-
-1, # Use QueueId of -1 to not interfere with async copies.
52-
T.address_of(C[0], dtype="handle"),
53-
T.address_of(A[0], dtype="handle"),
54-
size,
55-
0, # Do not use experimental bypass mode.
56-
dtype="int32",
57-
)
58-
)
59-
T.evaluate(
60-
T.tvm_call_packed(
61-
"device_api.hexagon.dma_wait",
62-
-1,
63-
0, # Wait for the sync queue (-1) to have 0 messages.
50+
"device_api.hexagon.dma_copy_dltensor",
51+
T.tvm_stack_make_array(
52+
T.address_of(C[0], dtype="handle"),
53+
T.tvm_stack_make_shape(size, dtype="handle"),
54+
0,
55+
1,
56+
C.dtype,
57+
0,
58+
dtype="handle",
59+
),
60+
T.tvm_stack_make_array(
61+
T.address_of(A[0], dtype="handle"),
62+
T.tvm_stack_make_shape(size, dtype="handle"),
63+
0,
64+
1,
65+
A.dtype,
66+
0,
67+
dtype="handle",
68+
),
69+
T.cast(size, dtype="int"),
70+
False, # Do not use experimental bypass mode.
6471
dtype="int32",
6572
)
6673
)

src/driver/driver_api.cc

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,6 @@ TVM_REGISTER_PASS_CONFIG_OPTION("tir.is_entry_func", Bool);
5252
TVM_REGISTER_PASS_CONFIG_OPTION("tir.add_lower_pass", Array<Array<ObjectRef>>);
5353
TVM_REGISTER_PASS_CONFIG_OPTION("tir.debug_keep_trivial_loop", Bool);
5454
TVM_REGISTER_PASS_CONFIG_OPTION("tir.use_async_copy", Bool);
55-
TVM_REGISTER_PASS_CONFIG_OPTION("tir.merge_async_commit_queue_scope", Bool);
5655
TVM_REGISTER_PASS_CONFIG_OPTION("tir.instrument_lwp", Bool);
5756
TVM_REGISTER_PASS_CONFIG_OPTION("tir.vtcm_capacity", Integer);
5857
TVM_REGISTER_PASS_CONFIG_OPTION("tir.ptx_ldg32", Bool);

src/meta_schedule/postproc/disallow_async_strided_mem_copy.cc

Lines changed: 1 addition & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -145,9 +145,6 @@ class DisallowAsyncStridedMemCopyNode : public PostprocNode {
145145
pass_list.push_back(tir::transform::InjectDoubleBuffer());
146146
pass_list.push_back(tir::transform::VectorizeLoop(true));
147147
pass_list.push_back(tir::transform::StorageRewrite());
148-
transform::PassContext pass_ctx = transform::PassContext::Current();
149-
pass_ctx->config.Set("tir.merge_async_commit_queue_scope",
150-
Bool(merge_async_commit_queue_scope));
151148
tir::PrimFunc f = WithAttr(GetRef<tir::PrimFunc>(prim_func), "global_symbol",
152149
runtime::String(g_var->name_hint));
153150
IRModule mod = IRModule(Map<GlobalVar, BaseFunc>({{GlobalVar(g_var->name_hint), f}}));
@@ -169,15 +166,12 @@ class DisallowAsyncStridedMemCopyNode : public PostprocNode {
169166
return Postproc(n);
170167
}
171168

172-
bool merge_async_commit_queue_scope = true;
173-
174169
static constexpr const char* _type_key = "meta_schedule.DisallowAsyncStridedMemCopy";
175170
TVM_DECLARE_FINAL_OBJECT_INFO(DisallowAsyncStridedMemCopyNode, PostprocNode);
176171
};
177172

178-
Postproc Postproc::DisallowAsyncStridedMemCopy(bool merge_async_commit_queue_scope) {
173+
Postproc Postproc::DisallowAsyncStridedMemCopy() {
179174
ObjectPtr<DisallowAsyncStridedMemCopyNode> n = make_object<DisallowAsyncStridedMemCopyNode>();
180-
n->merge_async_commit_queue_scope = merge_async_commit_queue_scope;
181175
return Postproc(n);
182176
}
183177

src/runtime/hexagon/hexagon_device_api.cc

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -210,10 +210,10 @@ TVM_REGISTER_GLOBAL("device_api.hexagon.dma_copy_dltensor")
210210
});
211211

212212
TVM_REGISTER_GLOBAL("device_api.hexagon.dma_copy").set_body([](TVMArgs args, TVMRetValue* rv) {
213-
int queue_id = args[0];
213+
uint32_t queue_id = static_cast<int>(args[0]);
214214
void* dst = args[1];
215215
void* src = args[2];
216-
int size = args[3];
216+
uint32_t size = static_cast<int>(args[3]);
217217
ICHECK(size > 0);
218218
bool bypass_cache = args[4];
219219

@@ -226,13 +226,26 @@ TVM_REGISTER_GLOBAL("device_api.hexagon.dma_copy").set_body([](TVMArgs args, TVM
226226
});
227227

228228
TVM_REGISTER_GLOBAL("device_api.hexagon.dma_wait").set_body([](TVMArgs args, TVMRetValue* rv) {
229-
int queue_id = args[0];
229+
uint32_t queue_id = static_cast<int>(args[0]);
230230
int inflight = args[1];
231231
ICHECK(inflight >= 0);
232232
HexagonDeviceAPI::Global()->UserDMA()->Wait(queue_id, inflight);
233233
*rv = static_cast<int32_t>(0);
234234
});
235235

236+
TVM_REGISTER_GLOBAL("device_api.hexagon.dma_start_group")
237+
.set_body([](TVMArgs args, TVMRetValue* rv) {
238+
uint32_t queue_id = static_cast<int>(args[0]);
239+
HexagonDeviceAPI::Global()->UserDMA()->StartGroup(queue_id);
240+
*rv = static_cast<int32_t>(0);
241+
});
242+
243+
TVM_REGISTER_GLOBAL("device_api.hexagon.dma_end_group").set_body([](TVMArgs args, TVMRetValue* rv) {
244+
uint32_t queue_id = static_cast<int>(args[0]);
245+
HexagonDeviceAPI::Global()->UserDMA()->EndGroup(queue_id);
246+
*rv = static_cast<int32_t>(0);
247+
});
248+
236249
TVM_REGISTER_GLOBAL("device_api.hexagon.alloc_nd").set_body([](TVMArgs args, TVMRetValue* rv) {
237250
int32_t device_type = args[0];
238251
int32_t device_id = args[1];

src/runtime/hexagon/hexagon_user_dma.cc

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,8 @@ unsigned int HexagonUserDMA::Init() {
3232
return status;
3333
}
3434

35-
int HexagonUserDMA::Copy(int queue_id, void* dst, void* src, uint32_t length, bool bypass_cache) {
35+
int HexagonUserDMA::Copy(uint32_t queue_id, void* dst, void* src, uint32_t length,
36+
bool bypass_cache) {
3637
// length limited to 24 bits
3738
if (length > DESC_LENGTH_MASK) {
3839
return DMA_FAILURE;
@@ -103,15 +104,15 @@ int HexagonUserDMA::Copy(int queue_id, void* dst, void* src, uint32_t length, bo
103104
return DMA_SUCCESS;
104105
}
105106

106-
void HexagonUserDMA::Wait(int queue_id, uint32_t max_dmas_in_flight) {
107+
void HexagonUserDMA::Wait(uint32_t queue_id, uint32_t max_dmas_in_flight) {
107108
// wait (forever) until max DMAs in flight <= actual DMAs in flight
108-
while (DMAsInFlight(queue_id) > max_dmas_in_flight) {
109+
while (DMAGroupsInFlight(queue_id) > max_dmas_in_flight) {
109110
}
110111
}
111112

112-
uint32_t HexagonUserDMA::Poll(int queue_id) { return DMAsInFlight(queue_id); }
113+
uint32_t HexagonUserDMA::Poll(uint32_t queue_id) { return DMAGroupsInFlight(queue_id); }
113114

114-
uint32_t HexagonUserDMA::DMAsInFlight(int queue_id) {
115+
uint32_t HexagonUserDMA::DMAGroupsInFlight(uint32_t queue_id) {
115116
dmpoll(); // update DMA engine status
116117
return descriptors_->InFlight(queue_id);
117118
}
@@ -125,7 +126,8 @@ HexagonUserDMA::HexagonUserDMA() {
125126
unsigned int done = dma_desc_get_done(dma_desc);
126127
return (done != DESC_DONE_COMPLETE);
127128
};
128-
descriptors_ = new QueuedRingBuffer<dma_desc_2d_t>(MAX_DMA_DESCRIPTORS, desc_in_flight);
129+
descriptors_ =
130+
new QueuedRingBuffer<dma_desc_2d_t>(MAX_DMA_QUEUES, MAX_DMA_DESCRIPTORS, desc_in_flight);
129131
}
130132

131133
HexagonUserDMA::~HexagonUserDMA() {

src/runtime/hexagon/hexagon_user_dma.h

Lines changed: 25 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,8 @@ namespace hexagon {
3434
#define DMA_FAILURE -1
3535
#define DMA_RETRY 1
3636
#define MAX_DMA_DESCRIPTORS 100
37-
#define SYNC_DMA_QUEUE -1
37+
#define MAX_DMA_QUEUES 10
38+
#define SYNC_DMA_QUEUE MAX_DMA_QUEUES - 1
3839

3940
class HexagonUserDMA {
4041
public:
@@ -47,32 +48,50 @@ class HexagonUserDMA {
4748

4849
/*!
4950
* \brief Initiate DMA to copy memory from source to destination address
51+
* \param queue_id The virtual DMA queue
5052
* \param dst Destination address
5153
* \param src Source address
5254
* \param length Length in bytes to copy
5355
* \returns Status: DMA_SUCCESS or DMA_FAILURE
5456
*/
55-
int Copy(int queue_id, void* dst, void* src, uint32_t length, bool bypass_cache);
57+
int Copy(uint32_t queue_id, void* dst, void* src, uint32_t length, bool bypass_cache);
5658

5759
/*!
5860
* \brief Wait until the number of DMAs in flight is less than or equal to some maximum
61+
* \param queue_id The virtual DMA queue
5962
* \param max_dmas_in_flight Maximum number of DMAs allowed to be in flight
6063
* to satisfy the `Wait` e.g. use `Wait(0)` to wait on "all" outstanding DMAs to complete
6164
*/
62-
void Wait(int queue_id, uint32_t max_dmas_in_flight);
65+
void Wait(uint32_t queue_id, uint32_t max_dmas_in_flight);
6366

6467
/*!
6568
* \brief Poll the number of DMAs in flight
69+
* \param queue_id The virtual DMA queue
6670
* \returns Number of DMAs in flight
6771
*/
68-
uint32_t Poll(int queue_id);
72+
uint32_t Poll(uint32_t queue_id);
73+
74+
/*!
75+
* \brief Start a group of DMA copies
76+
* \param queue_id The virtual DMA queue
77+
*/
78+
void StartGroup(uint32_t queue_id) { descriptors_->StartGroup(queue_id); }
79+
80+
/*!
81+
* \brief End a group of DMA copies
82+
* \param queue_id The virtual DMA queue
83+
*/
84+
void EndGroup(uint32_t queue_id) { descriptors_->EndGroup(queue_id); }
6985

7086
private:
7187
//! \brief Initializes the Hexagon User DMA engine
7288
unsigned int Init();
7389

74-
//! \brief Calculates and returns the number of DMAs in flight
75-
uint32_t DMAsInFlight(int queue_id);
90+
/*!
91+
* \brief Calculates and returns the number of DMAs in flight
92+
* \param queue_id The virtual DMA queue
93+
*/
94+
uint32_t DMAGroupsInFlight(uint32_t queue_id);
7695

7796
//! \brief Tracks whether the very first DMA has been executed
7897
bool first_dma_ = true;

0 commit comments

Comments
 (0)