@@ -2306,9 +2306,252 @@ ur_mem_flags_t AccessModeToUr(access::mode AccessorMode) {
23062306 }
23072307}
23082308
2309+ // Gets UR argument struct for a given kernel and device based on the argument
2310+ // type. Refactored from SetKernelParamsAndLaunch to allow it to be used in
2311+ // the graphs extension (LaunchWithArgs for graphs is planned future work).
2312+ static void GetUrArgsBasedOnType (
2313+ device_image_impl *DeviceImageImpl,
2314+ const std::function<void *(Requirement *Req)> &getMemAllocationFunc,
2315+ context_impl &ContextImpl, detail::ArgDesc &Arg, size_t NextTrueIndex,
2316+ std::vector<ur_exp_kernel_arg_properties_t> &UrArgs) {
2317+ switch (Arg.MType ) {
2318+ case kernel_param_kind_t ::kind_dynamic_work_group_memory:
2319+ break ;
2320+ case kernel_param_kind_t ::kind_work_group_memory:
2321+ break ;
2322+ case kernel_param_kind_t ::kind_stream:
2323+ break ;
2324+ case kernel_param_kind_t ::kind_dynamic_accessor:
2325+ case kernel_param_kind_t ::kind_accessor: {
2326+ Requirement *Req = (Requirement *)(Arg.MPtr );
2327+
2328+ // getMemAllocationFunc is nullptr when there are no requirements. However,
2329+ // we may pass default constructed accessors to a command, which don't add
2330+ // requirements. In such case, getMemAllocationFunc is nullptr, but it's a
2331+ // valid case, so we need to properly handle it.
2332+ ur_mem_handle_t MemArg =
2333+ getMemAllocationFunc
2334+ ? reinterpret_cast <ur_mem_handle_t >(getMemAllocationFunc (Req))
2335+ : nullptr ;
2336+ ur_exp_kernel_arg_value_t Value = {};
2337+ Value.memObjTuple = {MemArg, AccessModeToUr (Req->MAccessMode )};
2338+ UrArgs.push_back ({UR_STRUCTURE_TYPE_EXP_KERNEL_ARG_PROPERTIES, nullptr ,
2339+ UR_EXP_KERNEL_ARG_TYPE_MEM_OBJ,
2340+ static_cast <uint32_t >(NextTrueIndex), sizeof (MemArg),
2341+ Value});
2342+ break ;
2343+ }
2344+ case kernel_param_kind_t ::kind_std_layout: {
2345+ ur_exp_kernel_arg_type_t Type;
2346+ if (Arg.MPtr ) {
2347+ Type = UR_EXP_KERNEL_ARG_TYPE_VALUE;
2348+ } else {
2349+ Type = UR_EXP_KERNEL_ARG_TYPE_LOCAL;
2350+ }
2351+ ur_exp_kernel_arg_value_t Value = {};
2352+ Value.value = {Arg.MPtr };
2353+ UrArgs.push_back ({UR_STRUCTURE_TYPE_EXP_KERNEL_ARG_PROPERTIES, nullptr ,
2354+ Type, static_cast <uint32_t >(NextTrueIndex),
2355+ static_cast <size_t >(Arg.MSize ), Value});
2356+
2357+ break ;
2358+ }
2359+ case kernel_param_kind_t ::kind_sampler: {
2360+ sampler *SamplerPtr = (sampler *)Arg.MPtr ;
2361+ ur_exp_kernel_arg_value_t Value = {};
2362+ Value.sampler = (ur_sampler_handle_t )detail::getSyclObjImpl (*SamplerPtr)
2363+ ->getOrCreateSampler (ContextImpl);
2364+ UrArgs.push_back ({UR_STRUCTURE_TYPE_EXP_KERNEL_ARG_PROPERTIES, nullptr ,
2365+ UR_EXP_KERNEL_ARG_TYPE_SAMPLER,
2366+ static_cast <uint32_t >(NextTrueIndex),
2367+ sizeof (ur_sampler_handle_t ), Value});
2368+ break ;
2369+ }
2370+ case kernel_param_kind_t ::kind_pointer: {
2371+ ur_exp_kernel_arg_value_t Value = {};
2372+ // We need to de-rerence to get the actual USM allocation - that's the
2373+ // pointer UR is expecting.
2374+ Value.pointer = *static_cast <void *const *>(Arg.MPtr );
2375+ UrArgs.push_back ({UR_STRUCTURE_TYPE_EXP_KERNEL_ARG_PROPERTIES, nullptr ,
2376+ UR_EXP_KERNEL_ARG_TYPE_POINTER,
2377+ static_cast <uint32_t >(NextTrueIndex), sizeof (Arg.MPtr ),
2378+ Value});
2379+ break ;
2380+ }
2381+ case kernel_param_kind_t ::kind_specialization_constants_buffer: {
2382+ assert (DeviceImageImpl != nullptr );
2383+ ur_mem_handle_t SpecConstsBuffer =
2384+ DeviceImageImpl->get_spec_const_buffer_ref ();
2385+ ur_exp_kernel_arg_value_t Value = {};
2386+ Value.memObjTuple = {SpecConstsBuffer, UR_MEM_FLAG_READ_ONLY};
2387+ UrArgs.push_back ({UR_STRUCTURE_TYPE_EXP_KERNEL_ARG_PROPERTIES, nullptr ,
2388+ UR_EXP_KERNEL_ARG_TYPE_MEM_OBJ,
2389+ static_cast <uint32_t >(NextTrueIndex),
2390+ sizeof (SpecConstsBuffer), Value});
2391+ break ;
2392+ }
2393+ case kernel_param_kind_t ::kind_invalid:
2394+ throw sycl::exception (sycl::make_error_code (sycl::errc::runtime),
2395+ " Invalid kernel param kind " +
2396+ codeToString (UR_RESULT_ERROR_INVALID_VALUE));
2397+ break ;
2398+ }
2399+ }
2400+
2401+ // Added by Revert
2402+ static ur_result_t SetKernelParamsAndLaunch (
2403+ queue_impl &Queue, std::vector<ArgDesc> &Args,
2404+ device_image_impl *DeviceImageImpl, ur_kernel_handle_t Kernel,
2405+ NDRDescT &NDRDesc, std::vector<ur_event_handle_t > &RawEvents,
2406+ detail::event_impl *OutEventImpl, const KernelArgMask *EliminatedArgMask,
2407+ const std::function<void *(Requirement *Req)> &getMemAllocationFunc,
2408+ bool IsCooperative, bool KernelUsesClusterLaunch,
2409+ uint32_t WorkGroupMemorySize, const RTDeviceBinaryImage *BinImage,
2410+ KernelNameStrRefT KernelName,
2411+ void *KernelFuncPtr = nullptr, int KernelNumArgs = 0,
2412+ detail::kernel_param_desc_t (*KernelParamDescGetter)(int ) = nullptr,
2413+ bool KernelHasSpecialCaptures = true) {
2414+ adapter_impl &Adapter = Queue.getAdapter ();
2415+
2416+ if (SYCLConfig<SYCL_JIT_AMDGCN_PTX_KERNELS>::get ()) {
2417+ std::vector<unsigned char > Empty;
2418+ Kernel = Scheduler::getInstance ().completeSpecConstMaterialization (
2419+ Queue, BinImage, KernelName,
2420+ DeviceImageImpl ? DeviceImageImpl->get_spec_const_blob_ref () : Empty);
2421+ }
2422+
2423+ std::vector<ur_exp_kernel_arg_properties_t > UrArgs;
2424+ UrArgs.reserve (Args.size ());
2425+
2426+ if (KernelFuncPtr && !KernelHasSpecialCaptures) {
2427+ auto setFunc = [&UrArgs,
2428+ KernelFuncPtr](const detail::kernel_param_desc_t &ParamDesc,
2429+ size_t NextTrueIndex) {
2430+ const void *ArgPtr = (const char *)KernelFuncPtr + ParamDesc.offset ;
2431+ switch (ParamDesc.kind ) {
2432+ case kernel_param_kind_t ::kind_std_layout: {
2433+ int Size = ParamDesc.info ;
2434+ ur_exp_kernel_arg_value_t Value = {};
2435+ Value.value = ArgPtr;
2436+ UrArgs.push_back ({UR_STRUCTURE_TYPE_EXP_KERNEL_ARG_PROPERTIES, nullptr ,
2437+ UR_EXP_KERNEL_ARG_TYPE_VALUE,
2438+ static_cast <uint32_t >(NextTrueIndex),
2439+ static_cast <size_t >(Size), Value});
2440+ break ;
2441+ }
2442+ case kernel_param_kind_t ::kind_pointer: {
2443+ ur_exp_kernel_arg_value_t Value = {};
2444+ Value.pointer = *static_cast <const void *const *>(ArgPtr);
2445+ UrArgs.push_back ({UR_STRUCTURE_TYPE_EXP_KERNEL_ARG_PROPERTIES, nullptr ,
2446+ UR_EXP_KERNEL_ARG_TYPE_POINTER,
2447+ static_cast <uint32_t >(NextTrueIndex),
2448+ sizeof (Value.pointer ), Value});
2449+ break ;
2450+ }
2451+ default :
2452+ throw std::runtime_error (" Direct kernel argument copy failed." );
2453+ }
2454+ };
2455+ applyFuncOnFilteredArgs (EliminatedArgMask, KernelNumArgs,
2456+ KernelParamDescGetter, setFunc);
2457+ } else {
2458+ auto setFunc = [&DeviceImageImpl, &getMemAllocationFunc, &Queue,
2459+ &UrArgs](detail::ArgDesc &Arg, size_t NextTrueIndex) {
2460+ GetUrArgsBasedOnType (DeviceImageImpl, getMemAllocationFunc,
2461+ Queue.getContextImpl (), Arg, NextTrueIndex, UrArgs);
2462+ };
2463+ applyFuncOnFilteredArgs (EliminatedArgMask, Args, setFunc);
2464+ }
2465+
2466+ std::optional<int > ImplicitLocalArg =
2467+ ProgramManager::getInstance ().kernelImplicitLocalArgPos (
2468+ KernelName);
2469+ // Set the implicit local memory buffer to support
2470+ // get_work_group_scratch_memory. This is for backend not supporting
2471+ // CUDA-style local memory setting. Note that we may have -1 as a position,
2472+ // this indicates the buffer is actually unused and was elided.
2473+ if (ImplicitLocalArg.has_value () && ImplicitLocalArg.value () != -1 ) {
2474+ UrArgs.push_back ({UR_STRUCTURE_TYPE_EXP_KERNEL_ARG_PROPERTIES,
2475+ nullptr ,
2476+ UR_EXP_KERNEL_ARG_TYPE_LOCAL,
2477+ static_cast <uint32_t >(ImplicitLocalArg.value ()),
2478+ WorkGroupMemorySize,
2479+ {nullptr }});
2480+ }
2481+
2482+ adjustNDRangePerKernel (NDRDesc, Kernel, Queue.getDeviceImpl ());
2483+
2484+ // Remember this information before the range dimensions are reversed
2485+ const bool HasLocalSize = (NDRDesc.LocalSize [0 ] != 0 );
2486+
2487+ ReverseRangeDimensionsForKernel (NDRDesc);
2488+
2489+ size_t RequiredWGSize[3 ] = {0 , 0 , 0 };
2490+ size_t *LocalSize = nullptr ;
2491+
2492+ if (HasLocalSize)
2493+ LocalSize = &NDRDesc.LocalSize [0 ];
2494+ else {
2495+ Adapter.call <UrApiKind::urKernelGetGroupInfo>(
2496+ Kernel, Queue.getDeviceImpl ().getHandleRef (),
2497+ UR_KERNEL_GROUP_INFO_COMPILE_WORK_GROUP_SIZE, sizeof (RequiredWGSize),
2498+ RequiredWGSize,
2499+ /* pPropSizeRet = */ nullptr );
2500+
2501+ const bool EnforcedLocalSize =
2502+ (RequiredWGSize[0 ] != 0 || RequiredWGSize[1 ] != 0 ||
2503+ RequiredWGSize[2 ] != 0 );
2504+ if (EnforcedLocalSize)
2505+ LocalSize = RequiredWGSize;
2506+ }
2507+ const bool HasOffset = NDRDesc.GlobalOffset [0 ] != 0 ||
2508+ NDRDesc.GlobalOffset [1 ] != 0 ||
2509+ NDRDesc.GlobalOffset [2 ] != 0 ;
2510+
2511+ std::vector<ur_kernel_launch_property_t > property_list;
2512+
2513+ if (KernelUsesClusterLaunch) {
2514+ ur_kernel_launch_property_value_t launch_property_value_cluster_range;
2515+ launch_property_value_cluster_range.clusterDim [0 ] =
2516+ NDRDesc.ClusterDimensions [0 ];
2517+ launch_property_value_cluster_range.clusterDim [1 ] =
2518+ NDRDesc.ClusterDimensions [1 ];
2519+ launch_property_value_cluster_range.clusterDim [2 ] =
2520+ NDRDesc.ClusterDimensions [2 ];
2521+
2522+ property_list.push_back ({UR_KERNEL_LAUNCH_PROPERTY_ID_CLUSTER_DIMENSION,
2523+ launch_property_value_cluster_range});
2524+ }
2525+ if (IsCooperative) {
2526+ ur_kernel_launch_property_value_t launch_property_value_cooperative;
2527+ launch_property_value_cooperative.cooperative = 1 ;
2528+ property_list.push_back ({UR_KERNEL_LAUNCH_PROPERTY_ID_COOPERATIVE,
2529+ launch_property_value_cooperative});
2530+ }
2531+ // If there is no implicit arg, let the driver handle it via a property
2532+ if (WorkGroupMemorySize && !ImplicitLocalArg.has_value ()) {
2533+ property_list.push_back ({UR_KERNEL_LAUNCH_PROPERTY_ID_WORK_GROUP_MEMORY,
2534+ {{WorkGroupMemorySize}}});
2535+ }
2536+ ur_event_handle_t UREvent = nullptr ;
2537+ ur_result_t Error =
2538+ Adapter.call_nocheck <UrApiKind::urEnqueueKernelLaunchWithArgsExp>(
2539+ Queue.getHandleRef (), Kernel, NDRDesc.Dims ,
2540+ HasOffset ? &NDRDesc.GlobalOffset [0 ] : nullptr ,
2541+ &NDRDesc.GlobalSize [0 ], LocalSize, UrArgs.size (), UrArgs.data (),
2542+ property_list.size (),
2543+ property_list.empty () ? nullptr : property_list.data (),
2544+ RawEvents.size (), RawEvents.empty () ? nullptr : &RawEvents[0 ],
2545+ OutEventImpl ? &UREvent : nullptr );
2546+ if (Error == UR_RESULT_SUCCESS && OutEventImpl) {
2547+ OutEventImpl->setHandle (UREvent);
2548+ }
2549+
2550+ return Error;
2551+ }
2552+
23092553// Sets arguments for a given kernel and device based on the argument type.
2310- // Refactored from SetKernelParamsAndLaunch to allow it to be used in the graphs
2311- // extension.
2554+ // This is a legacy path which the graphs extension still uses.
23122555static void SetArgBasedOnType (
23132556 adapter_impl &Adapter, ur_kernel_handle_t Kernel,
23142557 device_image_impl *DeviceImageImpl,
@@ -2389,6 +2632,8 @@ static void SetArgBasedOnType(
23892632 }
23902633}
23912634
2635+ // Before Revert
2636+ /*
23922637static ur_result_t SetKernelParamsAndLaunch(
23932638 queue_impl &Queue, std::vector<ArgDesc> &Args,
23942639 device_image_impl *DeviceImageImpl, ur_kernel_handle_t Kernel,
@@ -2471,7 +2716,7 @@ static ur_result_t SetKernelParamsAndLaunch(
24712716 Kernel, Queue.getDeviceImpl().getHandleRef(),
24722717 UR_KERNEL_GROUP_INFO_COMPILE_WORK_GROUP_SIZE, sizeof(RequiredWGSize),
24732718 RequiredWGSize,
2474- /* pPropSizeRet = */ nullptr );
2719+ nullptr); // pPropSizeRet == nullptr
24752720
24762721 const bool EnforcedLocalSize =
24772722 (RequiredWGSize[0] != 0 || RequiredWGSize[1] != 0 ||
@@ -2522,6 +2767,7 @@ static ur_result_t SetKernelParamsAndLaunch(
25222767
25232768 return Error;
25242769}
2770+ */
25252771
25262772static std::tuple<ur_kernel_handle_t , device_image_impl *,
25272773 const KernelArgMask *>
@@ -2757,7 +3003,7 @@ void enqueueImpKernel(
27573003 Queue, Args, DeviceImageImpl, Kernel, NDRDesc, EventsWaitList,
27583004 OutEventImpl, EliminatedArgMask, getMemAllocationFunc,
27593005 KernelIsCooperative, KernelUsesClusterLaunch, WorkGroupMemorySize,
2760- BinImage, KernelName, DeviceKernelInfo, KernelFuncPtr, KernelNumArgs,
3006+ BinImage, KernelName, KernelFuncPtr, KernelNumArgs,
27613007 KernelParamDescGetter, KernelHasSpecialCaptures);
27623008 }
27633009 if (UR_RESULT_SUCCESS != Error) {
0 commit comments