@@ -2310,14 +2310,14 @@ ur_mem_flags_t AccessModeToUr(access::mode AccessorMode) {
23102310 }
23112311}
23122312
2313- // Gets UR argument struct for a given kernel and device based on the argument
2314- // type. Refactored from SetKernelParamsAndLaunch to allow it to be used in
2315- // the graphs extension (LaunchWithArgs for graphs is planned future work).
2316- static void GetUrArgsBasedOnType (
2313+ // Sets arguments for a given kernel and device based on the argument type.
2314+ // Refactored from SetKernelParamsAndLaunch to allow it to be used in the graphs
2315+ // extension.
2316+ static void SetArgBasedOnType (
2317+ adapter_impl &Adapter, ur_kernel_handle_t Kernel,
23172318 device_image_impl *DeviceImageImpl,
23182319 const std::function<void *(Requirement *Req)> &getMemAllocationFunc,
2319- context_impl &ContextImpl, detail::ArgDesc &Arg, size_t NextTrueIndex,
2320- std::vector<ur_exp_kernel_arg_properties_t> &UrArgs) {
2320+ context_impl &ContextImpl, detail::ArgDesc &Arg, size_t NextTrueIndex) {
23212321 switch (Arg.MType ) {
23222322 case kernel_param_kind_t ::kind_dynamic_work_group_memory:
23232323 break ;
@@ -2337,61 +2337,52 @@ static void GetUrArgsBasedOnType(
23372337 getMemAllocationFunc
23382338 ? reinterpret_cast <ur_mem_handle_t >(getMemAllocationFunc (Req))
23392339 : nullptr ;
2340- ur_exp_kernel_arg_value_t Value = {};
2341- Value.memObjTuple = {MemArg, AccessModeToUr (Req->MAccessMode )};
2342- UrArgs.push_back ({UR_STRUCTURE_TYPE_EXP_KERNEL_ARG_PROPERTIES, nullptr ,
2343- UR_EXP_KERNEL_ARG_TYPE_MEM_OBJ,
2344- static_cast <uint32_t >(NextTrueIndex), sizeof (MemArg),
2345- Value});
2340+ ur_kernel_arg_mem_obj_properties_t MemObjData{};
2341+ MemObjData.stype = UR_STRUCTURE_TYPE_KERNEL_ARG_MEM_OBJ_PROPERTIES;
2342+ MemObjData.memoryAccess = AccessModeToUr (Req->MAccessMode );
2343+ Adapter.call <UrApiKind::urKernelSetArgMemObj>(Kernel, NextTrueIndex,
2344+ &MemObjData, MemArg);
23462345 break ;
23472346 }
23482347 case kernel_param_kind_t ::kind_std_layout: {
2349- ur_exp_kernel_arg_type_t Type;
23502348 if (Arg.MPtr ) {
2351- Type = UR_EXP_KERNEL_ARG_TYPE_VALUE;
2349+ Adapter.call <UrApiKind::urKernelSetArgValue>(
2350+ Kernel, NextTrueIndex, Arg.MSize , nullptr , Arg.MPtr );
23522351 } else {
2353- Type = UR_EXP_KERNEL_ARG_TYPE_LOCAL;
2352+ Adapter.call <UrApiKind::urKernelSetArgLocal>(Kernel, NextTrueIndex,
2353+ Arg.MSize , nullptr );
23542354 }
2355- ur_exp_kernel_arg_value_t Value = {};
2356- Value.value = {Arg.MPtr };
2357- UrArgs.push_back ({UR_STRUCTURE_TYPE_EXP_KERNEL_ARG_PROPERTIES, nullptr ,
2358- Type, static_cast <uint32_t >(NextTrueIndex),
2359- static_cast <size_t >(Arg.MSize ), Value});
23602355
23612356 break ;
23622357 }
23632358 case kernel_param_kind_t ::kind_sampler: {
23642359 sampler *SamplerPtr = (sampler *)Arg.MPtr ;
2365- ur_exp_kernel_arg_value_t Value = {};
2366- Value.sampler = (ur_sampler_handle_t )detail::getSyclObjImpl (*SamplerPtr)
2367- ->getOrCreateSampler (ContextImpl);
2368- UrArgs.push_back ({UR_STRUCTURE_TYPE_EXP_KERNEL_ARG_PROPERTIES, nullptr ,
2369- UR_EXP_KERNEL_ARG_TYPE_SAMPLER,
2370- static_cast <uint32_t >(NextTrueIndex),
2371- sizeof (ur_sampler_handle_t ), Value});
2360+ ur_sampler_handle_t Sampler =
2361+ (ur_sampler_handle_t )detail::getSyclObjImpl (*SamplerPtr)
2362+ ->getOrCreateSampler (ContextImpl);
2363+ Adapter.call <UrApiKind::urKernelSetArgSampler>(Kernel, NextTrueIndex,
2364+ nullptr , Sampler);
23722365 break ;
23732366 }
23742367 case kernel_param_kind_t ::kind_pointer: {
2375- ur_exp_kernel_arg_value_t Value = {};
2376- // We need to de-rerence to get the actual USM allocation - that's the
2368+ // We need to de-rerence this to get the actual USM allocation - that's the
23772369 // pointer UR is expecting.
2378- Value.pointer = *static_cast <void *const *>(Arg.MPtr );
2379- UrArgs.push_back ({UR_STRUCTURE_TYPE_EXP_KERNEL_ARG_PROPERTIES, nullptr ,
2380- UR_EXP_KERNEL_ARG_TYPE_POINTER,
2381- static_cast <uint32_t >(NextTrueIndex), sizeof (Arg.MPtr ),
2382- Value});
2370+ const void *Ptr = *static_cast <const void *const *>(Arg.MPtr );
2371+ Adapter.call <UrApiKind::urKernelSetArgPointer>(Kernel, NextTrueIndex,
2372+ nullptr , Ptr);
23832373 break ;
23842374 }
23852375 case kernel_param_kind_t ::kind_specialization_constants_buffer: {
23862376 assert (DeviceImageImpl != nullptr );
23872377 ur_mem_handle_t SpecConstsBuffer =
23882378 DeviceImageImpl->get_spec_const_buffer_ref ();
2389- ur_exp_kernel_arg_value_t Value = {};
2390- Value.memObjTuple = {SpecConstsBuffer, UR_MEM_FLAG_READ_ONLY};
2391- UrArgs.push_back ({UR_STRUCTURE_TYPE_EXP_KERNEL_ARG_PROPERTIES, nullptr ,
2392- UR_EXP_KERNEL_ARG_TYPE_MEM_OBJ,
2393- static_cast <uint32_t >(NextTrueIndex),
2394- sizeof (SpecConstsBuffer), Value});
2379+
2380+ ur_kernel_arg_mem_obj_properties_t MemObjProps{};
2381+ MemObjProps.pNext = nullptr ;
2382+ MemObjProps.stype = UR_STRUCTURE_TYPE_KERNEL_ARG_MEM_OBJ_PROPERTIES;
2383+ MemObjProps.memoryAccess = UR_MEM_FLAG_READ_ONLY;
2384+ Adapter.call <UrApiKind::urKernelSetArgMemObj>(
2385+ Kernel, NextTrueIndex, &MemObjProps, SpecConstsBuffer);
23952386 break ;
23962387 }
23972388 case kernel_param_kind_t ::kind_invalid:
@@ -2424,32 +2415,22 @@ static ur_result_t SetKernelParamsAndLaunch(
24242415 DeviceImageImpl ? DeviceImageImpl->get_spec_const_blob_ref () : Empty);
24252416 }
24262417
2427- std::vector<ur_exp_kernel_arg_properties_t > UrArgs;
2428- UrArgs.reserve (Args.size ());
2429-
24302418 if (KernelFuncPtr && !KernelHasSpecialCaptures) {
2431- auto setFunc = [&UrArgs ,
2419+ auto setFunc = [&Adapter, Kernel ,
24322420 KernelFuncPtr](const detail::kernel_param_desc_t &ParamDesc,
24332421 size_t NextTrueIndex) {
24342422 const void *ArgPtr = (const char *)KernelFuncPtr + ParamDesc.offset ;
24352423 switch (ParamDesc.kind ) {
24362424 case kernel_param_kind_t ::kind_std_layout: {
24372425 int Size = ParamDesc.info ;
2438- ur_exp_kernel_arg_value_t Value = {};
2439- Value.value = ArgPtr;
2440- UrArgs.push_back ({UR_STRUCTURE_TYPE_EXP_KERNEL_ARG_PROPERTIES, nullptr ,
2441- UR_EXP_KERNEL_ARG_TYPE_VALUE,
2442- static_cast <uint32_t >(NextTrueIndex),
2443- static_cast <size_t >(Size), Value});
2426+ Adapter.call <UrApiKind::urKernelSetArgValue>(Kernel, NextTrueIndex,
2427+ Size, nullptr , ArgPtr);
24442428 break ;
24452429 }
24462430 case kernel_param_kind_t ::kind_pointer: {
2447- ur_exp_kernel_arg_value_t Value = {};
2448- Value.pointer = *static_cast <const void *const *>(ArgPtr);
2449- UrArgs.push_back ({UR_STRUCTURE_TYPE_EXP_KERNEL_ARG_PROPERTIES, nullptr ,
2450- UR_EXP_KERNEL_ARG_TYPE_POINTER,
2451- static_cast <uint32_t >(NextTrueIndex),
2452- sizeof (Value.pointer ), Value});
2431+ const void *Ptr = *static_cast <const void *const *>(ArgPtr);
2432+ Adapter.call <UrApiKind::urKernelSetArgPointer>(Kernel, NextTrueIndex,
2433+ nullptr , Ptr);
24532434 break ;
24542435 }
24552436 default :
@@ -2459,10 +2440,10 @@ static ur_result_t SetKernelParamsAndLaunch(
24592440 applyFuncOnFilteredArgs (EliminatedArgMask, KernelNumArgs,
24602441 KernelParamDescGetter, setFunc);
24612442 } else {
2462- auto setFunc = [&DeviceImageImpl, &getMemAllocationFunc , &Queue ,
2463- &UrArgs ](detail::ArgDesc &Arg, size_t NextTrueIndex) {
2464- GetUrArgsBasedOnType ( DeviceImageImpl, getMemAllocationFunc,
2465- Queue.getContextImpl (), Arg, NextTrueIndex, UrArgs );
2443+ auto setFunc = [&Adapter, Kernel, &DeviceImageImpl , &getMemAllocationFunc ,
2444+ &Queue ](detail::ArgDesc &Arg, size_t NextTrueIndex) {
2445+ SetArgBasedOnType (Adapter, Kernel, DeviceImageImpl, getMemAllocationFunc,
2446+ Queue.getContextImpl (), Arg, NextTrueIndex);
24662447 };
24672448 applyFuncOnFilteredArgs (EliminatedArgMask, Args, setFunc);
24682449 }
@@ -2475,12 +2456,8 @@ static ur_result_t SetKernelParamsAndLaunch(
24752456 // CUDA-style local memory setting. Note that we may have -1 as a position,
24762457 // this indicates the buffer is actually unused and was elided.
24772458 if (ImplicitLocalArg.has_value () && ImplicitLocalArg.value () != -1 ) {
2478- UrArgs.push_back ({UR_STRUCTURE_TYPE_EXP_KERNEL_ARG_PROPERTIES,
2479- nullptr ,
2480- UR_EXP_KERNEL_ARG_TYPE_LOCAL,
2481- static_cast <uint32_t >(ImplicitLocalArg.value ()),
2482- WorkGroupMemorySize,
2483- {nullptr }});
2459+ Adapter.call <UrApiKind::urKernelSetArgLocal>(
2460+ Kernel, ImplicitLocalArg.value (), WorkGroupMemorySize, nullptr );
24842461 }
24852462
24862463 adjustNDRangePerKernel (NDRDesc, Kernel, Queue.getDeviceImpl ());
@@ -2538,104 +2515,20 @@ static ur_result_t SetKernelParamsAndLaunch(
25382515 {{WorkGroupMemorySize}}});
25392516 }
25402517 ur_event_handle_t UREvent = nullptr ;
2541- ur_result_t Error =
2542- Adapter.call_nocheck <UrApiKind::urEnqueueKernelLaunchWithArgsExp>(
2543- Queue.getHandleRef (), Kernel, NDRDesc.Dims ,
2544- HasOffset ? &NDRDesc.GlobalOffset [0 ] : nullptr ,
2545- &NDRDesc.GlobalSize [0 ], LocalSize, UrArgs.size (), UrArgs.data (),
2546- property_list.size (),
2547- property_list.empty () ? nullptr : property_list.data (),
2548- RawEvents.size (), RawEvents.empty () ? nullptr : &RawEvents[0 ],
2549- OutEventImpl ? &UREvent : nullptr );
2518+ ur_result_t Error = Adapter.call_nocheck <UrApiKind::urEnqueueKernelLaunch>(
2519+ Queue.getHandleRef (), Kernel, NDRDesc.Dims ,
2520+ HasOffset ? &NDRDesc.GlobalOffset [0 ] : nullptr , &NDRDesc.GlobalSize [0 ],
2521+ LocalSize, property_list.size (),
2522+ property_list.empty () ? nullptr : property_list.data (), RawEvents.size (),
2523+ RawEvents.empty () ? nullptr : &RawEvents[0 ],
2524+ OutEventImpl ? &UREvent : nullptr );
25502525 if (Error == UR_RESULT_SUCCESS && OutEventImpl) {
25512526 OutEventImpl->setHandle (UREvent);
25522527 }
25532528
25542529 return Error;
25552530}
25562531
2557- // Sets arguments for a given kernel and device based on the argument type.
2558- // This is a legacy path which the graphs extension still uses.
2559- static void SetArgBasedOnType (
2560- adapter_impl &Adapter, ur_kernel_handle_t Kernel,
2561- device_image_impl *DeviceImageImpl,
2562- const std::function<void *(Requirement *Req)> &getMemAllocationFunc,
2563- context_impl &ContextImpl, detail::ArgDesc &Arg, size_t NextTrueIndex) {
2564- switch (Arg.MType ) {
2565- case kernel_param_kind_t ::kind_dynamic_work_group_memory:
2566- break ;
2567- case kernel_param_kind_t ::kind_work_group_memory:
2568- break ;
2569- case kernel_param_kind_t ::kind_stream:
2570- break ;
2571- case kernel_param_kind_t ::kind_dynamic_accessor:
2572- case kernel_param_kind_t ::kind_accessor: {
2573- Requirement *Req = (Requirement *)(Arg.MPtr );
2574-
2575- // getMemAllocationFunc is nullptr when there are no requirements. However,
2576- // we may pass default constructed accessors to a command, which don't add
2577- // requirements. In such case, getMemAllocationFunc is nullptr, but it's a
2578- // valid case, so we need to properly handle it.
2579- ur_mem_handle_t MemArg =
2580- getMemAllocationFunc
2581- ? reinterpret_cast <ur_mem_handle_t >(getMemAllocationFunc (Req))
2582- : nullptr ;
2583- ur_kernel_arg_mem_obj_properties_t MemObjData{};
2584- MemObjData.stype = UR_STRUCTURE_TYPE_KERNEL_ARG_MEM_OBJ_PROPERTIES;
2585- MemObjData.memoryAccess = AccessModeToUr (Req->MAccessMode );
2586- Adapter.call <UrApiKind::urKernelSetArgMemObj>(Kernel, NextTrueIndex,
2587- &MemObjData, MemArg);
2588- break ;
2589- }
2590- case kernel_param_kind_t ::kind_std_layout: {
2591- if (Arg.MPtr ) {
2592- Adapter.call <UrApiKind::urKernelSetArgValue>(
2593- Kernel, NextTrueIndex, Arg.MSize , nullptr , Arg.MPtr );
2594- } else {
2595- Adapter.call <UrApiKind::urKernelSetArgLocal>(Kernel, NextTrueIndex,
2596- Arg.MSize , nullptr );
2597- }
2598-
2599- break ;
2600- }
2601- case kernel_param_kind_t ::kind_sampler: {
2602- sampler *SamplerPtr = (sampler *)Arg.MPtr ;
2603- ur_sampler_handle_t Sampler =
2604- (ur_sampler_handle_t )detail::getSyclObjImpl (*SamplerPtr)
2605- ->getOrCreateSampler (ContextImpl);
2606- Adapter.call <UrApiKind::urKernelSetArgSampler>(Kernel, NextTrueIndex,
2607- nullptr , Sampler);
2608- break ;
2609- }
2610- case kernel_param_kind_t ::kind_pointer: {
2611- // We need to de-rerence this to get the actual USM allocation - that's the
2612- // pointer UR is expecting.
2613- const void *Ptr = *static_cast <const void *const *>(Arg.MPtr );
2614- Adapter.call <UrApiKind::urKernelSetArgPointer>(Kernel, NextTrueIndex,
2615- nullptr , Ptr);
2616- break ;
2617- }
2618- case kernel_param_kind_t ::kind_specialization_constants_buffer: {
2619- assert (DeviceImageImpl != nullptr );
2620- ur_mem_handle_t SpecConstsBuffer =
2621- DeviceImageImpl->get_spec_const_buffer_ref ();
2622-
2623- ur_kernel_arg_mem_obj_properties_t MemObjProps{};
2624- MemObjProps.pNext = nullptr ;
2625- MemObjProps.stype = UR_STRUCTURE_TYPE_KERNEL_ARG_MEM_OBJ_PROPERTIES;
2626- MemObjProps.memoryAccess = UR_MEM_FLAG_READ_ONLY;
2627- Adapter.call <UrApiKind::urKernelSetArgMemObj>(
2628- Kernel, NextTrueIndex, &MemObjProps, SpecConstsBuffer);
2629- break ;
2630- }
2631- case kernel_param_kind_t ::kind_invalid:
2632- throw sycl::exception (sycl::make_error_code (sycl::errc::runtime),
2633- " Invalid kernel param kind " +
2634- codeToString (UR_RESULT_ERROR_INVALID_VALUE));
2635- break ;
2636- }
2637- }
2638-
26392532static std::tuple<ur_kernel_handle_t , device_image_impl *,
26402533 const KernelArgMask *>
26412534getCGKernelInfo (const CGExecKernel &CommandGroup, context_impl &ContextImpl,
0 commit comments