// Patch summary: unified diff against the HIP adapter of the SYCL Unified
// Runtime plugin (context.hpp, enqueue.cpp, kernel.cpp, kernel.hpp, usm.cpp).
//
// NOTE(review): in this copy of the patch the line breaks are collapsed and
// every angle-bracketed token is missing (bare `#include`, `std::lock_guard
// Guard(Mutex)`, `std::pair getUSMMapping`, `static_cast(PtrArg)`), so it does
// not look applicable as-is. The patch text below is left untouched.
//
// context.hpp: adds addUSMMapping / removeUSMMapping / getUSMMapping to
//   ur_context_handle_t_, backed by a new USMMappings member and taking the
//   existing Mutex. getUSMMapping first tries an exact lookup of the pointer,
//   then scans every entry for one whose range strictly contains the pointer
//   (PtrVal > BaseAddr && PtrVal < EndAddr), and returns {nullptr, 0} when
//   nothing matches.
// enqueue.cpp: urEnqueueKernelLaunch looks up each entry of
//   hKernel->getPtrArgs() in the queue's context and calls
//   hipMemPrefetchAsync(Addr, Size, HIPDev, HIPStream) for every hit; any
//   result other than hipSuccess returns UR_RESULT_ERROR_INVALID_KERNEL_ARGS.
//   In the hunk the prefetch loop sits before the enqueueEventsWait call.
// kernel.cpp / kernel.hpp: urKernelSetArgPointer now goes through the new
//   setKernelPtrArg, which inserts the value read through PtrArg into
//   Args.PtrArgs and then forwards to setKernelArg. isPtrArg and getPtrArgs
//   expose that set.
// usm.cpp: urUSMHostAlloc, urUSMDeviceAlloc and urUSMSharedAlloc register the
//   new allocation with addUSMMapping when Result is UR_RESULT_SUCCESS;
//   urUSMFree calls removeUSMMapping; the catch handlers of urUSMHostAlloc,
//   urUSMDeviceAlloc and urUSMFree now return the error directly.
//
// NOTE(review): urUSMFree calls removeUSMMapping after the HIP free call. If
//   a concurrent allocation can be handed the same address in between,
//   addUSMMapping's "mapping already exists" assert could fire and the later
//   removal would drop the new entry -- confirm, or remove before freeing.
// NOTE(review): nothing in this patch erases entries from Args.PtrArgs, so a
//   pointer set by an earlier urKernelSetArgPointer call presumably stays in
//   the set after the argument is overwritten -- confirm this is intended.
// NOTE(review): allocations from hipHostMalloc and hipMalloc are registered
//   too, so they are also passed to hipMemPrefetchAsync at launch -- TODO
//   confirm the call succeeds for those allocation kinds.
// NOTE(review): the fallback scan in getUSMMapping copies each entry by value,
//   and its assert checks Pair.second (the size) although the adjacent
//   comment talks about the offset. removeUSMMapping names its lock `guard`
//   while the other two helpers use `Guard`.
diff --git a/sycl/plugins/unified_runtime/ur/adapters/hip/context.hpp b/sycl/plugins/unified_runtime/ur/adapters/hip/context.hpp index f504bb01ce0bf..7d4fe0c26a424 100644 --- a/sycl/plugins/unified_runtime/ur/adapters/hip/context.hpp +++ b/sycl/plugins/unified_runtime/ur/adapters/hip/context.hpp @@ -7,6 +7,8 @@ //===-----------------------------------------------------------------===// #pragma once +#include + #include "common.hpp" #include "device.hpp" #include "platform.hpp" @@ -93,9 +95,61 @@ struct ur_context_handle_t_ { uint32_t getReferenceCount() const noexcept { return RefCount; } + /// We need to keep track of USM mappings in AMD HIP, as certain extra + /// synchronization *is* actually required for correctness. + /// During kernel enqueue we must dispatch a prefetch for each kernel argument + /// that points to a USM mapping to ensure the mapping is correctly + /// populated on the device (https://github.com/intel/llvm/issues/7252). Thus, + /// we keep track of mappings in the context, and then check against them just + /// before the kernel is launched. The stream against which the kernel is + /// launched is not known until enqueue time, but the USM mappings can happen + /// at any time. Thus, they are tracked on the context used for the urUSM* + /// mapping. + /// + /// The three utility function are simple wrappers around a mapping from a + /// pointer to a size. + void addUSMMapping(void *Ptr, size_t Size) { + std::lock_guard Guard(Mutex); + assert(USMMappings.find(Ptr) == USMMappings.end() && + "mapping already exists"); + USMMappings[Ptr] = Size; + } + + void removeUSMMapping(const void *Ptr) { + std::lock_guard guard(Mutex); + auto It = USMMappings.find(Ptr); + if (It != USMMappings.end()) + USMMappings.erase(It); + } + + std::pair getUSMMapping(const void *Ptr) { + std::lock_guard Guard(Mutex); + auto It = USMMappings.find(Ptr); + // The simple case is the fast case... + if (It != USMMappings.end()) + return *It; + + // ... 
but in the failure case we have to fall back to a full scan to search + // for "offset" pointers in case the user passes in the middle of an + // allocation. We have to do some not-so-ordained-by-the-standard ordered + // comparisons of pointers here, but it'll work on all platforms we support. + uintptr_t PtrVal = (uintptr_t)Ptr; + for (std::pair Pair : USMMappings) { + uintptr_t BaseAddr = (uintptr_t)Pair.first; + uintptr_t EndAddr = BaseAddr + Pair.second; + if (PtrVal > BaseAddr && PtrVal < EndAddr) { + // If we've found something now, offset *must* be nonzero + assert(Pair.second); + return Pair; + } + } + return {nullptr, 0}; + } + private: std::mutex Mutex; std::vector ExtendedDeleters; + std::unordered_map USMMappings; }; namespace { diff --git a/sycl/plugins/unified_runtime/ur/adapters/hip/enqueue.cpp b/sycl/plugins/unified_runtime/ur/adapters/hip/enqueue.cpp index 1b0b2acc2a3f8..7b36c1863fcc8 100644 --- a/sycl/plugins/unified_runtime/ur/adapters/hip/enqueue.cpp +++ b/sycl/plugins/unified_runtime/ur/adapters/hip/enqueue.cpp @@ -252,7 +252,9 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueKernelLaunch( std::unique_ptr RetImplEvent{nullptr}; try { - ScopedContext Active(hQueue->getDevice()); + ur_device_handle_t Dev = hQueue->getDevice(); + ScopedContext Active(Dev); + ur_context_handle_t Ctx = hQueue->getContext(); uint32_t StreamToken; ur_stream_quard Guard; @@ -260,6 +262,14 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueKernelLaunch( numEventsInWaitList, phEventWaitList, Guard, &StreamToken); hipFunction_t HIPFunc = hKernel->get(); + hipDevice_t HIPDev = Dev->get(); + for (const void *P : hKernel->getPtrArgs()) { + auto [Addr, Size] = Ctx->getUSMMapping(P); + if (!Addr) + continue; + if (hipMemPrefetchAsync(Addr, Size, HIPDev, HIPStream) != hipSuccess) + return UR_RESULT_ERROR_INVALID_KERNEL_ARGS; + } Result = enqueueEventsWait(hQueue, HIPStream, numEventsInWaitList, phEventWaitList); @@ -301,7 +311,7 @@ UR_APIEXPORT ur_result_t UR_APICALL 
urEnqueueKernelLaunch( int DeviceMaxLocalMem = 0; Result = UR_CHECK_ERROR(hipDeviceGetAttribute( &DeviceMaxLocalMem, hipDeviceAttributeMaxSharedMemoryPerBlock, - hQueue->getDevice()->get())); + HIPDev)); static const int EnvVal = std::atoi(LocalMemSzPtr); if (EnvVal <= 0 || EnvVal > DeviceMaxLocalMem) { diff --git a/sycl/plugins/unified_runtime/ur/adapters/hip/kernel.cpp b/sycl/plugins/unified_runtime/ur/adapters/hip/kernel.cpp index 8da2d969c2c55..93d431989617c 100644 --- a/sycl/plugins/unified_runtime/ur/adapters/hip/kernel.cpp +++ b/sycl/plugins/unified_runtime/ur/adapters/hip/kernel.cpp @@ -256,7 +256,7 @@ urKernelGetSubGroupInfo(ur_kernel_handle_t hKernel, ur_device_handle_t hDevice, UR_APIEXPORT ur_result_t UR_APICALL urKernelSetArgPointer( ur_kernel_handle_t hKernel, uint32_t argIndex, const ur_kernel_arg_pointer_properties_t *, const void *pArgValue) { - hKernel->setKernelArg(argIndex, sizeof(pArgValue), pArgValue); + hKernel->setKernelPtrArg(argIndex, sizeof(pArgValue), pArgValue); return UR_RESULT_SUCCESS; } diff --git a/sycl/plugins/unified_runtime/ur/adapters/hip/kernel.hpp b/sycl/plugins/unified_runtime/ur/adapters/hip/kernel.hpp index 0e4f3c0ea8bd0..1e2bd03a0f1ea 100644 --- a/sycl/plugins/unified_runtime/ur/adapters/hip/kernel.hpp +++ b/sycl/plugins/unified_runtime/ur/adapters/hip/kernel.hpp @@ -12,6 +12,7 @@ #include #include #include +#include #include "program.hpp" @@ -55,6 +56,7 @@ struct ur_kernel_handle_t_ { args_size_t ParamSizes; args_index_t Indices; args_size_t OffsetPerIndex; + std::set PtrArgs; std::uint32_t ImplicitOffsetArgs[3] = {0, 0, 0}; @@ -175,6 +177,19 @@ struct ur_kernel_handle_t_ { Args.addArg(Index, Size, Arg); } + /// We track all pointer arguments to be able to issue prefetches at enqueue + /// time + void setKernelPtrArg(int Index, size_t Size, const void *PtrArg) { + Args.PtrArgs.insert(*static_cast(PtrArg)); + setKernelArg(Index, Size, PtrArg); + } + + bool isPtrArg(const void *ptr) { + return Args.PtrArgs.find(ptr) != 
Args.PtrArgs.end(); + } + + std::set &getPtrArgs() { return Args.PtrArgs; } + void setKernelLocalArg(int Index, size_t Size) { Args.addLocalArg(Index, Size); } diff --git a/sycl/plugins/unified_runtime/ur/adapters/hip/usm.cpp b/sycl/plugins/unified_runtime/ur/adapters/hip/usm.cpp index 296954268a818..f7699441143d3 100644 --- a/sycl/plugins/unified_runtime/ur/adapters/hip/usm.cpp +++ b/sycl/plugins/unified_runtime/ur/adapters/hip/usm.cpp @@ -28,14 +28,14 @@ UR_APIEXPORT ur_result_t UR_APICALL urUSMHostAlloc( ScopedContext Active(hContext->getDevice()); Result = UR_CHECK_ERROR(hipHostMalloc(ppMem, size)); } catch (ur_result_t Error) { - Result = Error; + return Error; } if (Result == UR_RESULT_SUCCESS) { assert((!pUSMDesc || pUSMDesc->align == 0 || reinterpret_cast(*ppMem) % pUSMDesc->align == 0)); + hContext->addUSMMapping(*ppMem, size); } - return Result; } @@ -53,14 +53,14 @@ UR_APIEXPORT ur_result_t UR_APICALL urUSMDeviceAlloc( ScopedContext Active(hContext->getDevice()); Result = UR_CHECK_ERROR(hipMalloc(ppMem, size)); } catch (ur_result_t Error) { - Result = Error; + return Error; } if (Result == UR_RESULT_SUCCESS) { assert((!pUSMDesc || pUSMDesc->align == 0 || reinterpret_cast(*ppMem) % pUSMDesc->align == 0)); + hContext->addUSMMapping(*ppMem, size); } - return Result; } @@ -84,8 +84,8 @@ UR_APIEXPORT ur_result_t UR_APICALL urUSMSharedAlloc( if (Result == UR_RESULT_SUCCESS) { assert((!pUSMDesc || pUSMDesc->align == 0 || reinterpret_cast(*ppMem) % pUSMDesc->align == 0)); + hContext->addUSMMapping(*ppMem, size); } - return Result; } @@ -109,8 +109,9 @@ UR_APIEXPORT ur_result_t UR_APICALL urUSMFree(ur_context_handle_t hContext, Result = UR_CHECK_ERROR(hipFreeHost(pMem)); } } catch (ur_result_t Error) { - Result = Error; + return Error; } + hContext->removeUSMMapping(pMem); return Result; }