Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 1 addition & 2 deletions include/tvm/meta_schedule/postproc.h
Original file line number Diff line number Diff line change
Expand Up @@ -111,10 +111,9 @@ class Postproc : public runtime::ObjectRef {
TVM_DLL static Postproc DisallowDynamicLoop();
/*!
* \brief Create a postprocessor that checks if all async mem copies are not strided.
* \param merge_async_commit_queue_scope Whether or not to merge async commit queue scope.
* \return The postprocessor created
*/
TVM_DLL static Postproc DisallowAsyncStridedMemCopy(bool merge_async_commit_queue_scope = true);
TVM_DLL static Postproc DisallowAsyncStridedMemCopy();
/*!
* \brief Create a postprocessor that rewrites the cooperative fetch annotation to
* actual vectorized cooperative fetching in loop bindings.
Expand Down
37 changes: 36 additions & 1 deletion include/tvm/tir/builtin.h
Original file line number Diff line number Diff line change
Expand Up @@ -727,14 +727,49 @@ TVM_DLL const Op& texture2d_load();

/*!
* \brief Initiate a non-blocking DMA copy from source to destination
*
* The copy is launched immediately.
*
* If a `dma_start_group()` call is active, the copy will be added
* to the current group for tracking of in-flight group counts.
*
* If no `dma_start_group()` call is active, the copy will be tracked
* individually i.e. as a group with size 1.
*/
TVM_DLL const Op& dma_copy();

/*!
* \brief Wait until the number of DMAs in flight is less than or equal to some maximum
* \brief Wait until the number of DMA groups in flight is less than
* or equal to some maximum
*
* Calling `dma_wait()` while a group is active is unsupported.
*/
TVM_DLL const Op& dma_wait();

/*!
* \brief Start a group of DMA copies
*
* Any call to `dma_copy()` that occurs after `dma_start_group()` will
* be added to the current group for tracking of in-flight group counts.
*
* Only one DMA group may be active at a given time. Calling
* `dma_start_group()` while a group is active is unsupported.
*/
TVM_DLL const Op& dma_start_group();

/*!
* \brief End a group of DMA copies
*
* Track all calls to `dma_copy()` that occurred since the preceding
* `dma_start_group()` as a single group in-flight.
*
* Calling `dma_end_group()` without an active group is unsupported.
*
* Note: A group of DMA calls may be empty, and will still contribute
* to the count of in-flight groups used by `dma_wait()`.
*/
TVM_DLL const Op& dma_end_group();

/*!
* \brief Provide a true statement that can be used for simplifications
*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,16 +23,9 @@

@register_object("meta_schedule.DisallowAsyncStridedMemCopy")
class DisallowAsyncStridedMemCopy(Postproc):
"""A postprocessor that disallows schedules that use async strided mem copies.
"""A postprocessor that disallows schedules that use async strided mem copies."""

Parameters
----------
merge_async_commit_queue_scope : bool
Whether or not to merge the async commit queue scope.
"""

def __init__(self, merge_async_commit_queue_scope=True) -> None:
def __init__(self) -> None:
self.__init_handle_by_constructor__(
_ffi_api.PostprocDisallowAsyncStridedMemCopy, # type: ignore # pylint: disable=no-member
merge_async_commit_queue_scope,
)
35 changes: 21 additions & 14 deletions python/tvm/tir/tensor_intrin/hexagon.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,20 +47,27 @@ def sync_dma_load_impl(a: T.handle, c: T.handle) -> None:
T.writes(C[0:size])
T.evaluate(
T.tvm_call_packed(
"device_api.hexagon.dma_copy",
-1, # Use QueueId of -1 to not interfere with async copies.
T.address_of(C[0], dtype="handle"),
T.address_of(A[0], dtype="handle"),
size,
0, # Do not use experimental bypass mode.
dtype="int32",
)
)
T.evaluate(
T.tvm_call_packed(
"device_api.hexagon.dma_wait",
-1,
0, # Wait for the sync queue (-1) to have 0 messages.
"device_api.hexagon.dma_copy_dltensor",
T.tvm_stack_make_array(
T.address_of(C[0], dtype="handle"),
T.tvm_stack_make_shape(size, dtype="handle"),
0,
1,
C.dtype,
0,
dtype="handle",
),
T.tvm_stack_make_array(
T.address_of(A[0], dtype="handle"),
T.tvm_stack_make_shape(size, dtype="handle"),
0,
1,
A.dtype,
0,
dtype="handle",
),
T.cast(size, dtype="int"),
False, # Do not use experimental bypass mode.
dtype="int32",
)
)
Expand Down
1 change: 0 additions & 1 deletion src/driver/driver_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,6 @@ TVM_REGISTER_PASS_CONFIG_OPTION("tir.is_entry_func", Bool);
TVM_REGISTER_PASS_CONFIG_OPTION("tir.add_lower_pass", Array<Array<ObjectRef>>);
TVM_REGISTER_PASS_CONFIG_OPTION("tir.debug_keep_trivial_loop", Bool);
TVM_REGISTER_PASS_CONFIG_OPTION("tir.use_async_copy", Bool);
TVM_REGISTER_PASS_CONFIG_OPTION("tir.merge_async_commit_queue_scope", Bool);
TVM_REGISTER_PASS_CONFIG_OPTION("tir.instrument_lwp", Bool);
TVM_REGISTER_PASS_CONFIG_OPTION("tir.vtcm_capacity", Integer);
TVM_REGISTER_PASS_CONFIG_OPTION("tir.ptx_ldg32", Bool);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -145,9 +145,6 @@ class DisallowAsyncStridedMemCopyNode : public PostprocNode {
pass_list.push_back(tir::transform::InjectDoubleBuffer());
pass_list.push_back(tir::transform::VectorizeLoop(true));
pass_list.push_back(tir::transform::StorageRewrite());
transform::PassContext pass_ctx = transform::PassContext::Current();
pass_ctx->config.Set("tir.merge_async_commit_queue_scope",
Bool(merge_async_commit_queue_scope));
tir::PrimFunc f = WithAttr(GetRef<tir::PrimFunc>(prim_func), "global_symbol",
runtime::String(g_var->name_hint));
IRModule mod = IRModule(Map<GlobalVar, BaseFunc>({{GlobalVar(g_var->name_hint), f}}));
Expand All @@ -169,15 +166,12 @@ class DisallowAsyncStridedMemCopyNode : public PostprocNode {
return Postproc(n);
}

bool merge_async_commit_queue_scope = true;

static constexpr const char* _type_key = "meta_schedule.DisallowAsyncStridedMemCopy";
TVM_DECLARE_FINAL_OBJECT_INFO(DisallowAsyncStridedMemCopyNode, PostprocNode);
};

Postproc Postproc::DisallowAsyncStridedMemCopy(bool merge_async_commit_queue_scope) {
Postproc Postproc::DisallowAsyncStridedMemCopy() {
ObjectPtr<DisallowAsyncStridedMemCopyNode> n = make_object<DisallowAsyncStridedMemCopyNode>();
n->merge_async_commit_queue_scope = merge_async_commit_queue_scope;
return Postproc(n);
}

Expand Down
19 changes: 16 additions & 3 deletions src/runtime/hexagon/hexagon_device_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -210,10 +210,10 @@ TVM_REGISTER_GLOBAL("device_api.hexagon.dma_copy_dltensor")
});

TVM_REGISTER_GLOBAL("device_api.hexagon.dma_copy").set_body([](TVMArgs args, TVMRetValue* rv) {
int queue_id = args[0];
uint32_t queue_id = static_cast<int>(args[0]);
void* dst = args[1];
void* src = args[2];
int size = args[3];
uint32_t size = static_cast<int>(args[3]);
ICHECK(size > 0);
bool bypass_cache = args[4];

Expand All @@ -226,13 +226,26 @@ TVM_REGISTER_GLOBAL("device_api.hexagon.dma_copy").set_body([](TVMArgs args, TVM
});

TVM_REGISTER_GLOBAL("device_api.hexagon.dma_wait").set_body([](TVMArgs args, TVMRetValue* rv) {
int queue_id = args[0];
uint32_t queue_id = static_cast<int>(args[0]);
int inflight = args[1];
ICHECK(inflight >= 0);
HexagonDeviceAPI::Global()->UserDMA()->Wait(queue_id, inflight);
*rv = static_cast<int32_t>(0);
});

TVM_REGISTER_GLOBAL("device_api.hexagon.dma_start_group")
.set_body([](TVMArgs args, TVMRetValue* rv) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nit: When the input/output types are fixed, the .set_body_typed() method can be used to avoid needing manual argument wrangling.

.set_body_typed([](int queue_id) -> int32_t {
      return HexagonDeviceAPI::Global()->UserDMA()->StartGroup(queue_id);
    });

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Chose not to implement this for this iteration as it seems like we could / should redo the entire Hexagon Device API with this change.

uint32_t queue_id = static_cast<int>(args[0]);
HexagonDeviceAPI::Global()->UserDMA()->StartGroup(queue_id);
*rv = static_cast<int32_t>(0);
});

TVM_REGISTER_GLOBAL("device_api.hexagon.dma_end_group").set_body([](TVMArgs args, TVMRetValue* rv) {
  // args[0] carries the queue id as an int; it is narrowed to the unsigned
  // queue id type expected by HexagonUserDMA::EndGroup.
  const auto target_queue = static_cast<uint32_t>(static_cast<int>(args[0]));
  HexagonDeviceAPI::Global()->UserDMA()->EndGroup(target_queue);
  // The packed function always reports 0 to the caller.
  *rv = int32_t{0};
});

TVM_REGISTER_GLOBAL("device_api.hexagon.alloc_nd").set_body([](TVMArgs args, TVMRetValue* rv) {
int32_t device_type = args[0];
int32_t device_id = args[1];
Expand Down
14 changes: 8 additions & 6 deletions src/runtime/hexagon/hexagon_user_dma.cc
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,8 @@ unsigned int HexagonUserDMA::Init() {
return status;
}

int HexagonUserDMA::Copy(int queue_id, void* dst, void* src, uint32_t length, bool bypass_cache) {
int HexagonUserDMA::Copy(uint32_t queue_id, void* dst, void* src, uint32_t length,
bool bypass_cache) {
// length limited to 24 bits
if (length > DESC_LENGTH_MASK) {
return DMA_FAILURE;
Expand Down Expand Up @@ -103,15 +104,15 @@ int HexagonUserDMA::Copy(int queue_id, void* dst, void* src, uint32_t length, bo
return DMA_SUCCESS;
}

void HexagonUserDMA::Wait(int queue_id, uint32_t max_dmas_in_flight) {
void HexagonUserDMA::Wait(uint32_t queue_id, uint32_t max_dmas_in_flight) {
// wait (forever) until actual DMAs in flight <= max DMAs in flight
while (DMAsInFlight(queue_id) > max_dmas_in_flight) {
while (DMAGroupsInFlight(queue_id) > max_dmas_in_flight) {
}
}

uint32_t HexagonUserDMA::Poll(int queue_id) { return DMAsInFlight(queue_id); }
uint32_t HexagonUserDMA::Poll(uint32_t queue_id) { return DMAGroupsInFlight(queue_id); }

uint32_t HexagonUserDMA::DMAsInFlight(int queue_id) {
uint32_t HexagonUserDMA::DMAGroupsInFlight(uint32_t queue_id) {
dmpoll(); // update DMA engine status
return descriptors_->InFlight(queue_id);
}
Expand All @@ -125,7 +126,8 @@ HexagonUserDMA::HexagonUserDMA() {
unsigned int done = dma_desc_get_done(dma_desc);
return (done != DESC_DONE_COMPLETE);
};
descriptors_ = new QueuedRingBuffer<dma_desc_2d_t>(MAX_DMA_DESCRIPTORS, desc_in_flight);
descriptors_ =
new QueuedRingBuffer<dma_desc_2d_t>(MAX_DMA_QUEUES, MAX_DMA_DESCRIPTORS, desc_in_flight);
}

HexagonUserDMA::~HexagonUserDMA() {
Expand Down
31 changes: 25 additions & 6 deletions src/runtime/hexagon/hexagon_user_dma.h
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,8 @@ namespace hexagon {
#define DMA_FAILURE -1
#define DMA_RETRY 1
#define MAX_DMA_DESCRIPTORS 100
#define SYNC_DMA_QUEUE -1
//! \brief Number of virtual DMA queues
#define MAX_DMA_QUEUES 10
//! \brief Queue id used for synchronous DMA copies: the last of the MAX_DMA_QUEUES queues.
//! The expansion is parenthesized so the macro keeps its value when used inside a
//! larger expression (e.g. `SYNC_DMA_QUEUE * 2` or `n - SYNC_DMA_QUEUE`).
#define SYNC_DMA_QUEUE (MAX_DMA_QUEUES - 1)

class HexagonUserDMA {
public:
Expand All @@ -47,32 +48,50 @@ class HexagonUserDMA {

/*!
* \brief Initiate DMA to copy memory from source to destination address
* \param queue_id The virtual DMA queue
* \param dst Destination address
* \param src Source address
* \param length Length in bytes to copy
* \returns Status: DMA_SUCCESS or DMA_FAILURE
*/
int Copy(int queue_id, void* dst, void* src, uint32_t length, bool bypass_cache);
int Copy(uint32_t queue_id, void* dst, void* src, uint32_t length, bool bypass_cache);

/*!
* \brief Wait until the number of DMAs in flight is less than or equal to some maximum
* \param queue_id The virtual DMA queue
* \param max_dmas_in_flight Maximum number of DMAs allowed to be in flight
* to satisfy the `Wait` e.g. use `Wait(0)` to wait on "all" outstanding DMAs to complete
*/
void Wait(int queue_id, uint32_t max_dmas_in_flight);
void Wait(uint32_t queue_id, uint32_t max_dmas_in_flight);

/*!
* \brief Poll the number of DMAs in flight
* \param queue_id The virtual DMA queue
* \returns Number of DMAs in flight
*/
uint32_t Poll(int queue_id);
uint32_t Poll(uint32_t queue_id);

/*!
* \brief Start a group of DMA copies
*
* Forwards to `StartGroup` on the descriptor ring buffer (`descriptors_`)
* for the given queue.
* \param queue_id The virtual DMA queue
*/
void StartGroup(uint32_t queue_id) { descriptors_->StartGroup(queue_id); }

/*!
* \brief End a group of DMA copies
*
* Forwards to `EndGroup` on the descriptor ring buffer (`descriptors_`)
* for the given queue.
* \param queue_id The virtual DMA queue
*/
void EndGroup(uint32_t queue_id) { descriptors_->EndGroup(queue_id); }

private:
//! \brief Initializes the Hexagon User DMA engine
unsigned int Init();

//! \brief Calculates and returns the number of DMAs in flight
uint32_t DMAsInFlight(int queue_id);
/*!
* \brief Calculates and returns the number of DMA groups in flight
* \param queue_id The virtual DMA queue
*/
uint32_t DMAGroupsInFlight(uint32_t queue_id);

//! \brief Tracks whether the very first DMA has been executed
bool first_dma_ = true;
Expand Down
Loading