33#include < ATen/xpu/XPUContext.h>
44
55#include < comm/Runtime.h>
6- #include < iostream>
76
87namespace xpu {
98namespace sycl {
109
1110template <class KernelClass >
1211static int64_t syclMaxWorkGroupSize (
13- at::DeviceIndex dev_id = at::xpu::getDeviceIndexOfCurrentQueue()) {
14- auto q = c10::xpu::getCurrentXPUStream (dev_id).queue ();
15- auto ctx = q.get_context ();
16- auto dev = q.get_device ();
12+ at::DeviceIndex dev_id = at::xpu::current_device()) {
13+ auto & ctx = c10::xpu::get_device_context ();
14+ auto & dev = c10::xpu::get_raw_device (dev_id);
1715
1816 auto kid = ::sycl::get_kernel_id<KernelClass>();
1917 // The kernel won't be built for devices except for the first device.
@@ -30,73 +28,69 @@ static int64_t syclMaxWorkGroupSize(
3028
3129template <class KernelClass >
3230static int64_t syclMaxWorkGroupSize (
33- KernelClass /* kfn*/ ,
34- at::DeviceIndex dev_id = at::xpu::getDeviceIndexOfCurrentQueue ()) {
31+ const KernelClass& /* kfn*/ ,
32+ at::DeviceIndex dev_id = at::xpu::current_device ()) {
3533 return syclMaxWorkGroupSize<KernelClass>(dev_id);
3634}
3735
3836static inline int64_t syclDeviceMaxWorkGroupSize (
39- at::DeviceIndex dev_id = at::xpu::getDeviceIndexOfCurrentQueue ()) {
37+ at::DeviceIndex dev_id = at::xpu::current_device ()) {
4038 auto * dev_prop = at::xpu::getDeviceProperties (dev_id);
4139 return dev_prop->max_work_group_size ;
4240}
4341
4442static inline int64_t syclMaxSubGroupSize (
45- at::DeviceIndex dev_id = at::xpu::getDeviceIndexOfCurrentQueue ()) {
43+ at::DeviceIndex dev_id = at::xpu::current_device ()) {
4644 auto * dev_prop = at::xpu::getDeviceProperties (dev_id);
47- auto subgroup_sizes = dev_prop->sub_group_sizes ;
48- uint64_t max_val = 0 ;
49- for (auto i : subgroup_sizes) {
50- if (i > max_val)
51- max_val = i;
52- }
53- return max_val;
45+ const auto & subgroup_sizes = dev_prop->sub_group_sizes ;
46+ TORCH_CHECK (
47+ !subgroup_sizes.empty (),
48+ " The device subgroup sizes is empty, please check the device status." );
49+ return *std::max_element (subgroup_sizes.begin (), subgroup_sizes.end ());
5450}
5551
5652static inline int64_t syclMinSubGroupSize (
57- at::DeviceIndex dev_id = at::xpu::getDeviceIndexOfCurrentQueue ()) {
53+ at::DeviceIndex dev_id = at::xpu::current_device ()) {
5854 auto * dev_prop = at::xpu::getDeviceProperties (dev_id);
59- auto subgroup_sizes = dev_prop->sub_group_sizes ;
60- uint64_t min_val = dev_prop->max_work_group_size ;
61- for (auto i : subgroup_sizes) {
62- if (i < min_val)
63- min_val = i;
64- }
65- return min_val;
55+ const auto & subgroup_sizes = dev_prop->sub_group_sizes ;
56+ TORCH_CHECK (
57+ !subgroup_sizes.empty (),
58+ " The device subgroup sizes is empty, please check the device status." );
59+ return *std::min_element (subgroup_sizes.begin (), subgroup_sizes.end ());
6660}
6761
6862static inline int64_t syclMaxComputeUnitSize (
69- at::DeviceIndex dev_id = at::xpu::getDeviceIndexOfCurrentQueue ()) {
63+ at::DeviceIndex dev_id = at::xpu::current_device ()) {
7064 auto * dev_prop = at::xpu::getDeviceProperties (dev_id);
7165 return dev_prop->max_compute_units ;
7266}
7367
7468static inline int64_t syclGpuEuCount (
75- at::DeviceIndex dev_id = at::xpu::getDeviceIndexOfCurrentQueue ()) {
69+ at::DeviceIndex dev_id = at::xpu::current_device ()) {
7670 auto * dev_prop = at::xpu::getDeviceProperties (dev_id);
7771 return dev_prop->gpu_eu_count ;
7872}
7973
8074static inline int64_t syclGpuEuSimdWidth (
81- at::DeviceIndex dev_id = at::xpu::getDeviceIndexOfCurrentQueue ()) {
75+ at::DeviceIndex dev_id = at::xpu::current_device ()) {
8276 auto * dev_prop = at::xpu::getDeviceProperties (dev_id);
8377 return dev_prop->gpu_eu_simd_width ;
8478}
8579
8680static inline int64_t syclGpuHWThreadsPerEU (
87- at::DeviceIndex dev_id = at::xpu::getDeviceIndexOfCurrentQueue ()) {
81+ at::DeviceIndex dev_id = at::xpu::current_device ()) {
8882 auto * dev_prop = at::xpu::getDeviceProperties (dev_id);
8983 return dev_prop->gpu_hw_threads_per_eu ;
9084}
9185
9286static inline int64_t syclGpuEUCountPerSubslice (
93- at::DeviceIndex dev_id = at::xpu::getDeviceIndexOfCurrentQueue ()) {
87+ at::DeviceIndex dev_id = at::xpu::current_device ()) {
9488 auto * dev_prop = at::xpu::getDeviceProperties (dev_id);
9589 return dev_prop->gpu_eu_count_per_subslice ;
9690}
9791
9892static inline int64_t syclMaxWorkItemsPerTile (
99- at::DeviceIndex dev_id = at::xpu::getDeviceIndexOfCurrentQueue ()) {
93+ at::DeviceIndex dev_id = at::xpu::current_device ()) {
10094 auto * dev_prop = at::xpu::getDeviceProperties (dev_id);
10195 int64_t eu_cnt = dev_prop->gpu_eu_count ;
10296 int64_t simd_width = syclMaxSubGroupSize (dev_id);
@@ -105,110 +99,92 @@ static inline int64_t syclMaxWorkItemsPerTile(
10599}
106100
107101static inline int64_t syclMaxWorkItemsPerSubSlice (
108- at::DeviceIndex dev_id = at::xpu::getDeviceIndexOfCurrentQueue ()) {
102+ at::DeviceIndex dev_id = at::xpu::current_device ()) {
109103 auto * dev_prop = at::xpu::getDeviceProperties (dev_id);
110104 int64_t simd_width = syclMaxSubGroupSize (dev_id);
111105 int64_t eu_count = dev_prop->gpu_eu_count_per_subslice ;
112106 return simd_width * eu_count;
113107}
114108
115109static inline int64_t syclMaxWorkItemsPerEU (
116- at::DeviceIndex dev_id = at::xpu::getDeviceIndexOfCurrentQueue ()) {
110+ at::DeviceIndex dev_id = at::xpu::current_device ()) {
117111 auto * dev_prop = at::xpu::getDeviceProperties (dev_id);
118112 int64_t simd_width = syclMaxSubGroupSize (dev_id);
119113 int64_t hw_threads = dev_prop->gpu_hw_threads_per_eu ;
120114 return simd_width * hw_threads;
121115}
122116
123117static inline int64_t syclMaxNumSubGroups (
124- at::DeviceIndex dev_id = at::xpu::getDeviceIndexOfCurrentQueue ()) {
118+ at::DeviceIndex dev_id = at::xpu::current_device ()) {
125119 auto * dev_prop = at::xpu::getDeviceProperties (dev_id);
126120 return dev_prop->max_num_sub_groups ;
127121}
128122
129123static inline int64_t syclMaxDSSNum (
130- at::DeviceIndex dev_id = at::xpu::getDeviceIndexOfCurrentQueue ()) {
124+ at::DeviceIndex dev_id = at::xpu::current_device ()) {
131125 int64_t dss_num =
132126 syclMaxComputeUnitSize (dev_id) / syclGpuEUCountPerSubslice (dev_id);
133127 return dss_num;
134128}
135129
136130static inline size_t syclGlobalMemSize (
137- at::DeviceIndex dev_id = at::xpu::getDeviceIndexOfCurrentQueue ()) {
131+ at::DeviceIndex dev_id = at::xpu::current_device ()) {
138132 auto * dev_prop = at::xpu::getDeviceProperties (dev_id);
139133 return dev_prop->global_mem_size ;
140134}
141135
142136static inline int64_t syclLocalMemSize (
143- at::DeviceIndex dev_id = at::xpu::getDeviceIndexOfCurrentQueue ()) {
137+ at::DeviceIndex dev_id = at::xpu::current_device ()) {
144138 auto * dev_prop = at::xpu::getDeviceProperties (dev_id);
145139 return dev_prop->local_mem_size ;
146140}
147141
148142template <typename T>
149143uint32_t syclPrefVectorWidth (
150- at::DeviceIndex dev_id = at::xpu::getDeviceIndexOfCurrentQueue ()) {
144+ at::DeviceIndex dev_id = at::xpu::current_device ()) {
151145 (void )dev_id; // Suppress unused variable warning
152146
153147 // Hot fix. This is the preferred vector width for GPUs up to LNL/BMG.
154- uint32_t vec_width = 16 ;
148+ constexpr uint32_t vec_width = 16 ;
155149
156- if (std::is_same<T, char >::value) {
157- return vec_width / sizeof (char );
158- }
159- if (std::is_same<T, short >::value) {
160- return vec_width / sizeof (short );
161- }
162- if (std::is_same<T, int >::value) {
163- return vec_width / sizeof (int );
164- }
165- if (std::is_same<T, int64_t >::value) {
166- return vec_width / sizeof (int64_t );
167- }
168- if (std::is_same<T, float >::value) {
169- return vec_width / sizeof (float );
170- }
171- if (std::is_same<T, double >::value) {
172- return vec_width / sizeof (double );
150+ if constexpr (
151+ std::is_same_v<T, char > || std::is_same_v<T, short > ||
152+ std::is_same_v<T, int > || std::is_same_v<T, int64_t > ||
153+ std::is_same_v<T, float > || std::is_same_v<T, double > ||
154+ std::is_same_v<T, ::sycl::half>) {
155+ return vec_width / sizeof (T);
156+ } else {
157+ throw std::invalid_argument (
158+ " Invalid data type to fetch preferred vector width!" );
173159 }
174- if (std::is_same<T, ::sycl::half>::value) {
175- return vec_width / sizeof (::sycl::half);
176- }
177- throw std::invalid_argument (
178- " Invalid data type to fetch preferred vector width!" );
179160}
180161
181162template <typename T>
182163uint32_t syclNativeVectorWidth (
183- at::DeviceIndex dev_id = at::xpu::getDeviceIndexOfCurrentQueue ()) {
164+ at::DeviceIndex dev_id = at::xpu::current_device ()) {
184165 auto * dev_prop = at::xpu::getDeviceProperties (dev_id);
185- if (std::is_same <T, char >::value ) {
166+ if constexpr (std::is_same_v <T, char >) {
186167 return dev_prop->native_vector_width_char ;
187- }
188- if (std::is_same<T, short >::value) {
168+ } else if constexpr (std::is_same_v<T, short >) {
189169 return dev_prop->native_vector_width_short ;
190- }
191- if (std::is_same<T, int >::value) {
170+ } else if constexpr (std::is_same_v<T, int >) {
192171 return dev_prop->native_vector_width_int ;
193- }
194- if (std::is_same<T, int64_t >::value) {
172+ } else if constexpr (std::is_same_v<T, int64_t >) {
195173 return dev_prop->native_vector_width_long ;
196- }
197- if (std::is_same<T, float >::value) {
174+ } else if constexpr (std::is_same_v<T, float >) {
198175 return dev_prop->native_vector_width_float ;
199- }
200- if (std::is_same<T, double >::value) {
176+ } else if constexpr (std::is_same_v<T, double >) {
201177 return dev_prop->native_vector_width_double ;
202- }
203- if (std::is_same<T, ::sycl::half>::value) {
178+ } else if constexpr (std::is_same_v<T, ::sycl::half>) {
204179 return dev_prop->native_vector_width_half ;
180+ } else {
181+ throw std::invalid_argument (
182+ " Invalid data type to fetch native vector width!" );
205183 }
206- throw std::invalid_argument (
207- " Invalid data type to fetch native vector width!" );
208184}
209185
210186static inline bool syclHasFloat64 (
211- at::DeviceIndex dev_id = at::xpu::getDeviceIndexOfCurrentQueue ()) {
187+ at::DeviceIndex dev_id = at::xpu::current_device ()) {
212188 auto * dev_prop = at::xpu::getDeviceProperties (dev_id);
213189 return dev_prop->has_fp64 ;
214190}
0 commit comments