Commit 2a992c5

toyxufengyuan14 and Feng Yuan authored

Add tensor compare ops (#80)

e.g. where, clamp, clamp_min, clamp_max

Co-authored-by: Feng Yuan <[email protected]>
1 parent 40cff1f commit 2a992c5
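At the Python level, the new ops are reached through the standard tensor API. A minimal usage sketch, assuming a PyTorch build with these XPU ops registered and an "xpu" device present (the device string and printed values are illustrative, not output from this commit):

import torch

# Assumption: an XPU-enabled build; swap in "cpu" to check the same semantics there.
device = "xpu"

cond = torch.tensor([True, False, True], device=device)
a = torch.tensor([1.0, 2.0, 3.0], device=device)
b = torch.tensor([-1.0, -2.0, -3.0], device=device)

# where: take from a where cond is True, else from b.
print(torch.where(cond, a, b))  # tensor([ 1., -2.,  3.], device='xpu:0')

# clamp and its one-sided variants.
t = torch.tensor([-2.0, 0.5, 4.0], device=device)
print(torch.clamp(t, min=-1.0, max=1.0))  # [-1.0, 0.5, 1.0]
print(torch.clamp_min(t, 0.0))            # [ 0.0, 0.5, 4.0]
print(torch.clamp_max(t, 0.0))            # [-2.0, 0.0, 0.0]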

File tree

5 files changed (+288, -0 lines)

src/aten/TensorCompare.cpp

Lines changed: 146 additions & 0 deletions
@@ -0,0 +1,146 @@
#include <ATen/ScalarOps.h>
#include <ATen/TensorIndexing.h>
#include <ATen/XPUNativeFunctions.h>
#include <ATen/core/Tensor.h>
#include <ATen/native/TensorCompare.h>
#include <ATen/native/TensorIterator.h>
#include <ATen/native/TypeProperties.h>
#include <aten/sycl/TensorCompare.h>

namespace at {

template <typename... Args>
Device out_device(Args&... inps) {
  for (const auto& i : {inps...}) {
    if (i.device() != at::kCPU) {
      return i.device();
    }
  }
  return at::kCPU;
}

Tensor& where_self_out(
    const Tensor& condition,
    const Tensor& self,
    const Tensor& other,
    Tensor& out) {
  const auto result_type = at::native::result_type(self, other);
  TORCH_CHECK(
      out.scalar_type() == result_type,
      "Expected out type to be ",
      result_type,
      " but got ",
      out.scalar_type());

  auto self_ = self.scalar_type() != result_type ? self.to(result_type) : self;
  auto other_ =
      other.scalar_type() != result_type ? other.to(result_type) : other;
  auto condition_ = condition;
  auto device = out_device(condition, self_, other_);
  if (device != at::kCPU) { // allow CPU scalars on non-cpu device
    if (condition.device() != device && condition.ndimension() == 0) {
      condition_ = condition.to(device);
    }
    if (self_.device() != device && self_.ndimension() == 0) {
      self_ = self_.to(device);
    }
    if (other_.device() != device && other_.ndimension() == 0) {
      other_ = other_.to(device);
    }
  }
  if (condition_.scalar_type() == ScalarType::Byte) {
    TORCH_WARN_ONCE(
        "where received a uint8 condition tensor. This behavior is deprecated and will be removed in a future version of PyTorch. Use a boolean condition instead.");
    condition_ = condition_.to(kBool);
  }
  TORCH_CHECK(
      condition_.scalar_type() == kBool,
      "where expected condition to be a boolean tensor, but got a tensor with dtype ",
      condition_.scalar_type());
  // if there's still a device mismatch, let tensoriterator error out with it
  auto iter = at::TensorIteratorConfig()
                  .check_all_same_dtype(false)
                  .add_output(out)
                  .add_const_input(condition_)
                  .add_const_input(self_)
                  .add_const_input(other_)
                  .build();
  native::xpu::where_kernel(iter);
  return out;
}

Tensor& XPUNativeFunctions::where_out(
    const Tensor& condition,
    const Tensor& self,
    const Tensor& other,
    Tensor& out) {
  return where_self_out(condition, self, other, out);
}

Tensor XPUNativeFunctions::where(
    const Tensor& condition,
    const Tensor& self,
    const Tensor& other) {
  auto device = out_device(condition, self, other);
  auto result_type = at::native::result_type(self, other);
  Tensor ret = at::empty({0}, self.options().dtype(result_type).device(device));
  where_self_out(condition, self, other, ret);
  return ret;
}

Tensor& XPUNativeFunctions::clamp_out(
    const Tensor& self,
    const c10::optional<Scalar>& min,
    const c10::optional<Scalar>& max,
    Tensor& result) {
  using at::native::detail::ClampLimits;
  if (min && max) {
    if ((*min).toDouble() != (*min).toDouble() ||
        (*max).toDouble() != (*max).toDouble()) {
      at::fill_(
          const_cast<Tensor&>(result),
          std::numeric_limits<double>::quiet_NaN());
    } else {
      auto iter = TensorIterator::unary_op(result, self);
      native::xpu::clamp_scalar_kernel(iter, *min, *max);
    }
  } else if (max) {
    auto iter = TensorIterator::unary_op(result, self);
    native::xpu::clamp_max_scalar_kernel(iter, *max);
  } else if (min) {
    auto iter = TensorIterator::unary_op(result, self);
    native::xpu::clamp_min_scalar_kernel(iter, *min);
  }
  return result;
}

Tensor& XPUNativeFunctions::clamp_min_out(
    const Tensor& self,
    const Scalar& min,
    Tensor& result) {
  if (min.toDouble() != min.toDouble()) {
    at::fill_(const_cast<Tensor&>(result), min);
  } else {
    auto iter = TensorIterator::unary_op(result, self);
    native::xpu::clamp_min_scalar_kernel(iter, min);
  }
  return result;
}

Tensor& XPUNativeFunctions::clamp_max_out(
    const Tensor& self,
    const Scalar& max,
    Tensor& result) {
  if (max.toDouble() != max.toDouble()) {
    // TODO this is not great, building TI again is expensive, but I can't use
    // fill_stub because fill is not structured
    // this is a corner case anyway
    at::fill_(const_cast<Tensor&>(result), native::wrapped_scalar_tensor(max));
  } else {
    auto iter = TensorIterator::unary_op(result, self);
    native::xpu::clamp_max_scalar_kernel(iter, max);
  }
  return result;
}

} // namespace at
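Two details of the host code above are easy to miss: where_self_out promotes self and other to their common result type before building the TensorIterator, and the clamp entry points detect a NaN scalar bound via the x != x self-comparison, filling the whole output with NaN instead of launching a kernel. A short sketch of the observable behavior, under the same XPU-build assumption as above:

import torch

device = "xpu"  # assumption: XPU-enabled build

# where() promotes to the common result type: int32 vs. float32 -> float32.
cond = torch.tensor([True, False], device=device)
i = torch.tensor([1, 2], dtype=torch.int32, device=device)
f = torch.tensor([0.5, 0.25], dtype=torch.float32, device=device)
assert torch.where(cond, i, f).dtype == torch.float32

# A NaN bound short-circuits clamp: the output is filled with NaN,
# matching the (*min).toDouble() != (*min).toDouble() check in clamp_out.
t = torch.tensor([1.0, 2.0], device=device)
assert torch.clamp(t, min=float("nan"), max=3.0).isnan().all()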

src/aten/sycl/TensorCompare.cpp

Lines changed: 110 additions & 0 deletions
@@ -0,0 +1,110 @@
#include <ATen/ATen.h>
#include <ATen/Dispatch.h>
#include <ATen/NumericUtils.h>
#include <ATen/native/TensorCompare.h>
#include <ATen/native/TensorIterator.h>

#include <aten/sycl/Loops.h>

namespace at {
namespace native {
namespace xpu {

template <typename scalar_t>
struct WhereFunctor {
  scalar_t operator()(bool cond_val, scalar_t self_val, scalar_t other_val)
      const {
    return cond_val ? self_val : other_val;
  }
};

template <typename scalar_t>
struct ClampFunctor {
  scalar_t operator()(scalar_t v, scalar_t lower, scalar_t upper) const {
    if (at::_isnan(v)) {
      return v;
    }
    if (at::_isnan(lower)) {
      return lower;
    }
    if (at::_isnan(upper)) {
      return upper;
    } else {
      return std::min(std::max(v, lower), upper);
    }
  }
};

template <typename scalar_t>
struct ClampScalarFunctor {
  using opmath_t = at::opmath_type<scalar_t>;
  scalar_t operator()(scalar_t v) const {
    if (_isnan(static_cast<opmath_t>(v))) {
      return v;
    } else if (minmax_ == at::native::detail::ClampLimits::Min) {
      return std::max(static_cast<opmath_t>(v), lim0_val_);
    } else if (minmax_ == at::native::detail::ClampLimits::Max) {
      return std::min(static_cast<opmath_t>(v), lim0_val_);
    } else {
      return std::min(std::max(static_cast<opmath_t>(v), lim0_val_), lim1_val_);
    }
  }
  ClampScalarFunctor(
      opmath_t lim0_val,
      opmath_t lim1_val,
      at::native::detail::ClampLimits minmax)
      : lim0_val_(lim0_val), lim1_val_(lim1_val), minmax_(minmax) {}

 private:
  opmath_t lim0_val_;
  opmath_t lim1_val_;
  at::native::detail::ClampLimits minmax_;
};

void where_kernel(TensorIterator& iter) {
  AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND4(
      kComplexHalf, kHalf, kBFloat16, kBool, iter.dtype(), "where_xpu", [&] {
        gpu_kernel(iter, WhereFunctor<scalar_t>());
      });
}

void clamp_kernel(TensorIteratorBase& iter) {
  AT_DISPATCH_ALL_TYPES_AND2(
      kHalf, kBFloat16, iter.common_dtype(), "clamp_xpu", [&] {
        gpu_kernel(iter, ClampFunctor<scalar_t>());
      });
}

void inline launch_clamp_scalar(
    TensorIteratorBase& iter,
    Scalar lim0,
    Scalar lim1,
    at::native::detail::ClampLimits minmax) {
  AT_DISPATCH_ALL_TYPES_AND2(
      kHalf, kBFloat16, iter.common_dtype(), "clamp_scalar_xpu", [&] {
        using opmath_t = at::opmath_type<scalar_t>;
        auto lim0_val = lim0.to<opmath_t>();
        auto lim1_val = lim1.to<opmath_t>();
        gpu_kernel(
            iter, ClampScalarFunctor<scalar_t>(lim0_val, lim1_val, minmax));
      });
}

void clamp_scalar_kernel(
    TensorIteratorBase& iter,
    const Scalar& min,
    const Scalar& max) {
  launch_clamp_scalar(iter, min, max, at::native::detail::ClampLimits::MinMax);
}

void clamp_min_scalar_kernel(TensorIteratorBase& iter, Scalar min) {
  launch_clamp_scalar(iter, min, min, at::native::detail::ClampLimits::Min);
}

void clamp_max_scalar_kernel(TensorIteratorBase& iter, Scalar max) {
  launch_clamp_scalar(iter, max, max, at::native::detail::ClampLimits::Max);
}

} // namespace xpu
} // namespace native
} // namespace at
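The functors above also fix the per-element NaN policy: a NaN input value is returned unchanged before any comparison runs, and ClampScalarFunctor does its comparisons in opmath_t (float for half and bfloat16 inputs) so the bounds are not truncated to the storage type. A small sketch of the pass-through behavior, again assuming an XPU-enabled build:

import torch

device = "xpu"  # assumption: XPU-enabled build

# NaN elements pass through ClampScalarFunctor untouched: at::_isnan(v)
# short-circuits before the min/max comparisons.
t = torch.tensor([float("nan"), -5.0, 5.0], device=device)
print(torch.clamp(t, min=-1.0, max=1.0))  # tensor([nan, -1., 1.], device='xpu:0')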

src/aten/sycl/TensorCompare.h

Lines changed: 24 additions & 0 deletions
@@ -0,0 +1,24 @@
#pragma once

#include <ATen/native/TensorIterator.h>

namespace at {
namespace native {
namespace xpu {

void where_kernel(TensorIterator& iter);

void clamp_kernel(TensorIteratorBase& iter);

void clamp_scalar_kernel(
    TensorIteratorBase& iter,
    const Scalar& min,
    const Scalar& max);

void clamp_min_scalar_kernel(TensorIteratorBase& iter, Scalar min);

void clamp_max_scalar_kernel(TensorIteratorBase& iter, Scalar max);

} // namespace xpu
} // namespace native
} // namespace at

test/xpu/test_ops.py

Lines changed: 3 additions & 0 deletions
@@ -97,6 +97,9 @@
     "bitwise_or",
     "bitwise_xor",
     "bitwise_not",
+    "where",
+    "clamp_min",
+    "clamp_max",
     "clamp",
 ]
 _xpu_tensor_factory_op_list = [

yaml/xpu_functions.yaml

Lines changed: 5 additions & 0 deletions
@@ -156,3 +156,8 @@ supported:
 - bitwise_or.Tensor_out
 - bitwise_xor.Tensor_out
 - bitwise_not.out
+- where.self_out
+- where.self
+- clamp.out
+- clamp_min.out
+- clamp_max.out
