Skip to content

Commit eb36bdd

Browse files
tianleiwuankitm3k
authored andcommitted
[CUDA/ROCm/Migraphx] consolidate gpu data transfer (microsoft#22609)
### Description Consolidate the gpu data transfer in CUDA, ROCm and Migraphx EP. (1) Remove some redundant stream synchronize on default stream according to spec of cudaMemcpy (2) consolidate CUDA, ROCm and MigrphaX to try use same logic. ### Motivation This is a follow up on reviewing microsoft#22589. ### Context https://docs.nvidia.com/cuda/cuda-runtime-api/api-sync-behavior.html#api-sync-behavior ##### cudaMemcpy() * For transfers from pageable host memory to device memory, a stream sync is performed before the copy is initiated. The function will return once the pageable buffer has been copied to the staging memory for DMA transfer to device memory, **but the DMA to final destination may not have completed**. * For transfers from pinned host memory to device memory, the function is synchronous with respect to the host. * For transfers from device to either pageable or pinned host memory, the function returns only once the copy has completed. * For transfers from device memory to device memory, **no host-side synchronization is performed**. * For transfers from any host memory to any host memory, the function is fully synchronous with respect to the host. #### cudaMemcpyAsync * For transfers between device memory and pageable host memory, the function might be synchronous with respect to host. * For transfers from any host memory to any host memory, the function is fully synchronous with respect to the host. * If pageable memory must first be staged to pinned memory, the driver may synchronize with the stream and stage the copy into pinned memory. * For all other transfers, the function should be fully asynchronous. https://rocm.docs.amd.com/projects/HIP/en/latest/doxygen/html/group___memory.html ##### hipMemcpyAsync() If host or dest are not pinned, the memory copy will be performed synchronously. For best performance, use hipHostMalloc to allocate host memory that is transferred asynchronously. on HCC hipMemcpyAsync does not support overlapped H2D and D2H copies. For hipMemcpy, the copy is always performed by the device associated with the specified stream. ##### hipMemcpy() For hipMemcpy, the copy is always performed by the current device (set by hipSetDevice). https://github.com/ROCm/ROCm/blob/roc-5.7.x/tools/autotag/templates/rocm_changes/5.6.1.md ROCm 5.6.1 release note: hipMemcpy device-to-device (intra device) is now asynchronous with respect to the host
1 parent f187022 commit eb36bdd

File tree

1 file changed

+0
-3
lines changed

1 file changed

+0
-3
lines changed

onnxruntime/core/providers/migraphx/gpu_data_transfer.cc

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -67,9 +67,6 @@ common::Status GPUDataTransfer::CopyTensorAsync(const Tensor& src, Tensor& dst,
6767
} else if (src_device.Type() == OrtDevice::GPU) {
6868
// copying between GPU, this is non-blocking
6969
HIP_CALL_THROW(hipMemcpyAsync(dst_data, src_data, bytes, hipMemcpyDeviceToDevice, static_cast<hipStream_t>(stream.GetHandle())));
70-
} else {
71-
// copy from other CPU memory to GPU, this is blocking
72-
HIP_CALL_THROW(hipMemcpyWithStream(dst_data, src_data, bytes, hipMemcpyHostToDevice, static_cast<hipStream_t>(stream.GetHandle())));
7370
}
7471
} else if (src_device.Type() == OrtDevice::GPU) {
7572
// If dest are not pinned, the memory copy will be performed synchronously.

0 commit comments

Comments
 (0)