
Commit 4f52f85

Improve error handling for clamp.
Parent: c77852e

5 files changed: +260, −39 lines

test/test_ops_error_message.py (16 additions, 0 deletions)

@@ -179,3 +179,19 @@ def test():
         callable=test,
         expect="""mm(): cannot matrix-multiply tensors f32[2,5] and f32[8,2]. Expected the size of dimension 1 of the first input tensor (5) to be equal the size of dimension 0 of the second input tensor (8)."""
     )
+
+  def test_clamp_scalar_raises_error_on_no_min_and_max(self):
+    device = torch_xla.device()
+    a = torch.rand(2, 5, device=device)
+
+    def test():
+      # Dispatch to `clamp()` overload explicitly.
+      # Otherwise, it's dispatched to `clamp.Tensor()`, which doesn't have
+      # this check.
+      return torch.ops.aten.clamp.default(a)
+
+    self.assertExpectedRaisesInline(
+        exc_type=RuntimeError,
+        callable=test,
+        expect="""clamp(): expected at least one of `min` or `max` arguments to be specified."""
+    )
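
For reference, a minimal sketch of how the new check surfaces outside the expecttest harness (assuming an XLA-capable install with a configured device):

    import torch
    import torch_xla

    a = torch.rand(2, 5, device=torch_xla.device())

    try:
        # Explicit dispatch to the `clamp()` overload; per the test comment
        # above, plain `torch.clamp(a)` would be routed to `clamp.Tensor()`,
        # which does not perform this check.
        torch.ops.aten.clamp.default(a)
    except RuntimeError as e:
        # clamp(): expected at least one of `min` or `max` arguments to be specified.
        print(e)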

test/test_ops_error_message.py.bak (new file, 194 additions)

import expecttest
import os
import torch
import torch_xla
import unittest


def onlyOnCPU(fn):
  accelerator = os.environ.get("PJRT_DEVICE").lower()
  return unittest.skipIf(accelerator != "cpu", "PJRT_DEVICE=CPU required")(fn)


class TestOpsErrorMessage(expecttest.TestCase):

  def test_add_broadcast_error(self):
    a = torch.rand(2, 2, 4, 4, device="xla")
    b = torch.rand(2, 2, device="xla")

    def test():
      return torch.add(a, b)

    self.assertExpectedRaisesInline(
        exc_type=RuntimeError,
        callable=test,
        expect="""Shapes are not compatible for broadcasting: f32[2,2,4,4] vs. f32[2,2]. Expected dimension 2 of shape f32[2,2,4,4] (4) to match dimension 0 of shape f32[2,2] (2). Either that or that any of them is either 1 or unbounded. Try reshaping one of the tensors to match the other."""
    )

  @onlyOnCPU
  def test_construct_large_tensor_raises_error(self):

    def test():
      # When eager-mode is enabled, OOM is triggered here.
      a = torch.rand(1024, 1024, 1024, 1024, 1024, device=torch_xla.device())
      b = a.sum()
      # OOM is raised when we try to bring data from the device.
      return b.cpu()

    self.assertExpectedRaisesInline(
        exc_type=RuntimeError,
        callable=test,
        expect="""Error preparing computation: Out of memory allocating 4503599761588224 bytes."""
    )

  def test_cat_raises_error_on_incompatible_shapes(self):
    a = torch.rand(2, 2, device=torch_xla.device())
    b = torch.rand(5, 1, device=torch_xla.device())

    def test():
      return torch.cat([a, b])

    self.assertExpectedRaisesInline(
        exc_type=RuntimeError,
        callable=test,
        expect="""cat(): cannot concatenate tensors of shape f32[2,2] with f32[5,1] at dimension 0. Expected shapes to be equal (except at dimension 0) or that either of them was a 1D empty tensor of size (0,)."""
    )

  def test_div_raises_error_on_invalid_rounding_mode(self):
    a = torch.rand(2, 2, device=torch_xla.device())

    def test():
      return torch.div(a, 2, rounding_mode="bad")

    self.assertExpectedRaisesInline(
        exc_type=RuntimeError,
        callable=test,
        expect="""div(): invalid rounding mode `bad`. Expected it to be either 'trunc', 'floor', or be left unspecified."""
    )

  def test_flip_raises_error_on_duplicated_dims(self):
    a = torch.rand(2, 2, 2, 2, device=torch_xla.device())
    dims = [0, 0, 0, 1, 2, 3, -1]

    def test():
      return torch.flip(a, dims=dims)

    self.assertExpectedRaisesInline(
        exc_type=RuntimeError,
        callable=test,
        expect="""flip(): expected each dimension to appear at most once. Found dimensions: 0 (3 times), 3 (2 times). Consider changing dims from [0, 0, 0, 1, 2, 3, -1] to [0, 1, 2, 3]."""
    )

  def test_full_raises_error_on_negative_size(self):
    shape = [2, -2, 2]

    def test():
      return torch.full(shape, 1.5, device="xla")

    self.assertExpectedRaisesInline(
        exc_type=RuntimeError,
        callable=test,
        expect="""full(): expected concrete sizes (i.e. non-symbolic) to be positive values. However found negative ones: [2, -2, 2]."""
    )

  def test_gather_raises_error_on_rank_mismatch(self):
    S = 2

    input = torch.arange(4, device=torch_xla.device()).view(S, S)
    index = torch.randint(0, S, (S, S, S), device=torch_xla.device())
    dim = 1

    def test():
      return torch.gather(input, dim, index)

    self.assertExpectedRaisesInline(
        exc_type=RuntimeError,
        callable=test,
        expect="""gather(): expected rank of input (2) and index (3) tensors to be the same."""
    )

  def test_gather_raises_error_on_invalid_index_size(self):
    S = 2
    X = S + 2

    input = torch.arange(16, device=torch_xla.device()).view(S, S, S, S)
    index = torch.randint(0, S, (X, S, X, S), device=torch_xla.device())
    dim = 1

    def test():
      return torch.gather(input, dim, index)

    self.assertExpectedRaisesInline(
        exc_type=RuntimeError,
        callable=test,
        expect="""gather(): expected sizes of index [4, 2, 4, 2] to be smaller or equal those of input [2, 2, 2, 2] on all dimensions, except on dimension 1. However, that's not true on dimensions [0, 2]."""
    )

  def test_random__raises_error_on_empty_interval(self):
    a = torch.empty(10, device=torch_xla.device())
    from_ = 3
    to_ = 1

    def test():
      return a.random_(from_, to_)

    self.assertExpectedRaisesInline(
        exc_type=RuntimeError,
        callable=test,
        expect="""random_(): expected `from` (3) to be smaller than `to` (1)."""
    )

  def test_random__raises_error_on_value_out_of_type_value_range(self):
    a = torch.empty(10, device=torch_xla.device(), dtype=torch.float16)
    from_ = 3
    to_ = 65_504 + 2

    def test():
      return a.random_(from_, to_)

    self.assertExpectedRaisesInline(
        exc_type=RuntimeError,
        callable=test,
        expect="""random_(): expected `to` to be within the range [-65504, 65504]. However got value 65505, which is greater than the upper bound."""
    )

  def test_mm_raises_error_on_non_matrix_input(self):
    device = torch_xla.device()
    a = torch.rand(2, 2, 2, device=device)
    b = torch.rand(2, 2, device=device)

    def test():
      torch.mm(a, b)

    self.assertExpectedRaisesInline(
        exc_type=RuntimeError,
        callable=test,
        expect="""mm(): expected the first input tensor f32[2,2,2] to be a matrix (i.e. a 2D tensor)."""
    )

  def test_mm_raises_error_on_incompatible_shapes(self):
    device = torch_xla.device()
    a = torch.rand(2, 5, device=device)
    b = torch.rand(8, 2, device=device)

    def test():
      torch.mm(a, b)

    self.assertExpectedRaisesInline(
        exc_type=RuntimeError,
        callable=test,
        expect="""mm(): cannot matrix-multiply tensors f32[2,5] and f32[8,2]. Expected the size of dimension 1 of the first input tensor (5) to be equal the size of dimension 0 of the second input tensor (8)."""
    )

  def test_clamp_raises_error_on_no_min_and_max(self):
    device = torch_xla.device()
    a = torch.rand(2, 5, device=device)

    def test():
      return torch.ops.aten.clamp.default(a)

    self.assertExpectedRaisesInline(
        exc_type=RuntimeError,
        callable=test,
        expect="""mm(): cannot matrix-multiply tensors f32[2,5] and f32[8,2]. Expected the size of dimension 1 of the first input tensor (5) to be equal the size of dimension 0 of the second input tensor (8)."""
    )

torch_xla/csrc/aten_xla_type.cpp (20 additions, 11 deletions)

@@ -1372,24 +1372,31 @@ at::Tensor XLANativeFunctions::clamp(const at::Tensor& self,
                                      const std::optional<at::Scalar>& min,
                                      const std::optional<at::Scalar>& max) {
   TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::");
-  XLA_ASSIGN_OR_THROW(XLATensorPtr xla_self, bridge::GetXlaTensor(self));
-  return bridge::AtenFromXlaTensor(tensor_methods::clamp(xla_self, min, max));
+  XLA_ASSIGN_OR_THROW(absl_nonnull XLATensorPtr xla_self,
+                      bridge::GetXlaTensor(self));
+  XLA_ASSIGN_OR_THROW(absl_nonnull XLATensorPtr output,
+                      tensor_methods::clamp(xla_self, min, max));
+  return bridge::AtenFromXlaTensor(std::move(output));
 }
 
 at::Tensor XLANativeFunctions::clamp_max(const at::Tensor& self,
                                          const at::Scalar& max) {
   TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::");
-  XLA_ASSIGN_OR_THROW(XLATensorPtr xla_self, bridge::GetXlaTensor(self));
-  return bridge::AtenFromXlaTensor(
-      tensor_methods::clamp(xla_self, std::nullopt, max));
+  XLA_ASSIGN_OR_THROW(absl_nonnull XLATensorPtr xla_self,
+                      bridge::GetXlaTensor(self));
+  XLA_ASSIGN_OR_THROW(absl_nonnull XLATensorPtr output,
+                      tensor_methods::clamp(xla_self, std::nullopt, max));
+  return bridge::AtenFromXlaTensor(std::move(output));
 }
 
 at::Tensor XLANativeFunctions::clamp_min(const at::Tensor& self,
                                          const at::Scalar& min) {
   TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::");
-  XLA_ASSIGN_OR_THROW(XLATensorPtr xla_self, bridge::GetXlaTensor(self));
-  return bridge::AtenFromXlaTensor(
-      tensor_methods::clamp(xla_self, min, std::nullopt));
+  XLA_ASSIGN_OR_THROW(absl_nonnull XLATensorPtr xla_self,
+                      bridge::GetXlaTensor(self));
+  XLA_ASSIGN_OR_THROW(absl_nonnull XLATensorPtr output,
+                      tensor_methods::clamp(xla_self, min, std::nullopt));
+  return bridge::AtenFromXlaTensor(std::move(output));
 }
 
 at::Tensor XLANativeFunctions::clone(

@@ -1947,9 +1954,11 @@ at::Tensor XLANativeFunctions::hardtanh(const at::Tensor& self,
                                         const at::Scalar& min_val,
                                         const at::Scalar& max_val) {
   TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::");
-  XLA_ASSIGN_OR_THROW(XLATensorPtr xla_self, bridge::GetXlaTensor(self));
-  return bridge::AtenFromXlaTensor(
-      tensor_methods::clamp(xla_self, min_val, max_val));
+  XLA_ASSIGN_OR_THROW(absl_nonnull XLATensorPtr xla_self,
+                      bridge::GetXlaTensor(self));
+  XLA_ASSIGN_OR_THROW(absl_nonnull XLATensorPtr output,
+                      tensor_methods::clamp(xla_self, min_val, max_val));
+  return bridge::AtenFromXlaTensor(std::move(output));
 }
 
 at::Tensor XLANativeFunctions::hardtanh_backward(const at::Tensor& grad_output,
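
Here `XLA_ASSIGN_OR_THROW` is what turns an error `absl::Status` returned by `tensor_methods::clamp` into the `RuntimeError` observed from Python. Since `clamp_min`, `clamp_max`, and `hardtanh` always supply at least one bound, they pass the new check; a sketch of the expected user-visible behavior (same device assumptions as above):

    import torch
    import torch_xla

    a = torch.rand(2, 5, device=torch_xla.device())

    a.clamp_min(0.25)                           # lower bound only: OK
    a.clamp_max(0.75)                           # upper bound only: OK
    torch.nn.functional.hardtanh(a, -1.0, 1.0)  # both bounds: OK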

torch_xla/csrc/tensor_methods.cpp (27 additions, 22 deletions)

@@ -177,22 +177,6 @@ torch::lazy::Value MaybeExpand(const torch::lazy::Value& input,
       input, torch::lazy::ToVector<int64_t>(target_shape.dimensions()));
 }
 
-MinMaxValues GetMinMaxValues(const XLATensorPtr& tensor,
-                             const std::optional<at::Scalar>& min,
-                             const std::optional<at::Scalar>& max) {
-  XLA_CHECK(min || max)
-      << "At least one of \'min\' or \'max\' must not be None";
-  xla::PrimitiveType raw_element_type = XlaTypeFromTorchType(tensor->dtype());
-  XlaHelpers::MinMax min_max = XlaHelpers::MinMaxValues(raw_element_type);
-  auto shape = tensor->shape();
-  return {XLAGraphExecutor::Get()->GetIrValueForScalar(
-              min ? *min : min_max.min, shape.get().element_type(),
-              tensor->GetDevice()),
-          XLAGraphExecutor::Get()->GetIrValueForScalar(
-              max ? *max : min_max.max, shape.get().element_type(),
-              tensor->GetDevice())};
-}
-
 void CheckRank(const XLATensorPtr& t, int64_t expected_rank,
                const std::string& tag, const std::string& arg_name,
                int arg_number) {

@@ -506,6 +490,16 @@ absl::Status CheckMMMatrixSizesAreCompatible(const XLATensorPtr& mat1,
   return absl::OkStatus();
 }
 
+absl::Status CheckClampMinOrMax(const std::optional<at::Scalar>& min,
+                                const std::optional<at::Scalar>& max) {
+  if (!min.has_value() && !max.has_value()) {
+    return XLA_ERROR_WITH_LOCATION(
+        absl::InvalidArgumentError("clamp(): expected at least one of `min` or "
+                                   "`max` arguments to be specified."));
+  }
+  return absl::OkStatus();
+}
+
 }  // namespace
 
 //////////////////////////////////////////////////////////////////////////////

@@ -1357,12 +1351,23 @@ void celu_(XLATensorPtr& input, const at::Scalar& alpha) {
   input->SetInPlaceIrValue(Celu(input->GetIrValue(), alpha));
 }
 
-XLATensorPtr clamp(const XLATensorPtr& input,
-                   const std::optional<at::Scalar>& min,
-                   const std::optional<at::Scalar>& max) {
-  MinMaxValues min_max = GetMinMaxValues(input, min, max);
-  return input->CreateFrom(
-      Clamp(input->GetIrValue(), min_max.min, min_max.max));
+absl::StatusOr<absl_nonnull XLATensorPtr> clamp(
+    const XLATensorPtr& input, const std::optional<at::Scalar>& min,
+    const std::optional<at::Scalar>& max) {
+  XLA_RETURN_IF_ERROR(CheckClampMinOrMax(min, max));
+
+  xla::Shape shape = input->shape();
+  const torch::lazy::BackendDevice& device = input->GetDevice();
+
+  xla::PrimitiveType raw_element_type = XlaTypeFromTorchType(input->dtype());
+  XlaHelpers::MinMax min_max = XlaHelpers::MinMaxValues(raw_element_type);
+
+  torch::lazy::Value min_value = XLAGraphExecutor::Get()->GetIrValueForScalar(
+      min.value_or(min_max.min), shape.element_type(), device);
+  torch::lazy::Value max_value = XLAGraphExecutor::Get()->GetIrValueForScalar(
+      max.value_or(min_max.max), shape.element_type(), device);
+
+  return input->CreateFrom(Clamp(input->GetIrValue(), min_value, max_value));
 }
 
 XLATensorPtr clone(const XLATensorPtr& input) {
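
The rewritten `clamp` keeps the previous fallback semantics: a missing bound is replaced with the extreme value of the tensor's element type (via `XlaHelpers::MinMaxValues`), so the absent side is effectively unconstrained. A small sketch of the resulting behavior (same device assumptions as above):

    import torch
    import torch_xla

    a = torch.rand(2, 5, device=torch_xla.device())

    out = torch.clamp(a, min=0.5)  # `max` falls back to the float32 maximum
    assert bool((out >= 0.5).all())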

torch_xla/csrc/tensor_methods.h (3 additions, 6 deletions)

@@ -316,12 +316,9 @@ XLATensorPtr pixel_shuffle(const XLATensorPtr& self, int64_t upscale_factor);
 XLATensorPtr celu(const XLATensorPtr& input, const at::Scalar& alpha);
 void celu_(XLATensorPtr& input, const at::Scalar& alpha);
 
-XLATensorPtr clamp(const XLATensorPtr& input,
-                   const std::optional<at::Scalar>& min,
-                   const std::optional<at::Scalar>& max);
-XLATensorPtr clamp(const XLATensorPtr& input,
-                   const std::optional<at::Tensor>& min,
-                   const std::optional<at::Tensor>& max);
+absl::StatusOr<absl_nonnull XLATensorPtr> clamp(
+    const XLATensorPtr& input, const std::optional<at::Scalar>& min,
+    const std::optional<at::Scalar>& max);
 
 XLATensorPtr clone(const XLATensorPtr& input);