Skip to content

Commit d3600f5

Browse files
author
Masahiro Masuda
committed
show supported data type in err msg
1 parent c1a9a6d commit d3600f5

File tree

1 file changed

+18
-9
lines changed

1 file changed

+18
-9
lines changed

src/runtime/contrib/rocthrust/thrust.cc

Lines changed: 18 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -119,7 +119,8 @@ void thrust_sort_common(DLTensor* input, DLTensor* values_out, DLTensor* indices
119119
} else if (out_dtype == "float64") {
120120
thrust_sort<float, double>(input, values_out, indices_out, is_ascend, sort_len);
121121
} else {
122-
LOG(FATAL) << "Unsupported output dtype: " << out_dtype;
122+
LOG(FATAL) << "Unsupported output dtype: " << out_dtype
123+
<< ". Supported output dtypes are int32, int64, float32 and float64";
123124
}
124125
} else if (data_dtype == "float64") {
125126
if (out_dtype == "int32") {
@@ -131,7 +132,8 @@ void thrust_sort_common(DLTensor* input, DLTensor* values_out, DLTensor* indices
131132
} else if (out_dtype == "float64") {
132133
thrust_sort<double, double>(input, values_out, indices_out, is_ascend, sort_len);
133134
} else {
134-
LOG(FATAL) << "Unsupported output dtype: " << out_dtype;
135+
LOG(FATAL) << "Unsupported output dtype: " << out_dtype
136+
<< ". Supported output dtypes are int32, int64, float32 and float64";
135137
}
136138
} else if (data_dtype == "int32") {
137139
if (out_dtype == "int32") {
@@ -143,7 +145,8 @@ void thrust_sort_common(DLTensor* input, DLTensor* values_out, DLTensor* indices
143145
} else if (out_dtype == "float64") {
144146
thrust_sort<int32_t, double>(input, values_out, indices_out, is_ascend, sort_len);
145147
} else {
146-
LOG(FATAL) << "Unsupported output dtype: " << out_dtype;
148+
LOG(FATAL) << "Unsupported output dtype: " << out_dtype
149+
<< ". Supported output dtypes are int32, int64, float32 and float64";
147150
}
148151
} else if (data_dtype == "int64") {
149152
if (out_dtype == "int32") {
@@ -155,10 +158,12 @@ void thrust_sort_common(DLTensor* input, DLTensor* values_out, DLTensor* indices
155158
} else if (out_dtype == "float64") {
156159
thrust_sort<int64_t, double>(input, values_out, indices_out, is_ascend, sort_len);
157160
} else {
158-
LOG(FATAL) << "Unsupported output dtype: " << out_dtype;
161+
LOG(FATAL) << "Unsupported output dtype: " << out_dtype
162+
<< ". Supported output dtypes are int32, int64, float32 and float64";
159163
}
160164
} else {
161-
LOG(FATAL) << "Unsupported input dtype: " << data_dtype;
165+
LOG(FATAL) << "Unsupported input dtype: " << data_dtype
166+
<< ". Supported input dtypes are int32, int64, and float32 and float64.";
162167
}
163168
}
164169

@@ -221,7 +226,8 @@ TVM_REGISTER_GLOBAL("tvm.contrib.thrust.stable_sort_by_key")
221226
thrust_stable_sort_by_key<int, float>(keys_in, values_in, keys_out, values_out,
222227
for_scatter);
223228
} else {
224-
LOG(FATAL) << "Unsupported value dtype: " << value_dtype;
229+
LOG(FATAL) << "Unsupported value dtype: " << value_dtype
230+
<< ". Supported value dtypes are int32, int64 and float32";
225231
}
226232
} else if (key_dtype == "int64") {
227233
if (value_dtype == "int32") {
@@ -234,7 +240,8 @@ TVM_REGISTER_GLOBAL("tvm.contrib.thrust.stable_sort_by_key")
234240
thrust_stable_sort_by_key<int64_t, float>(keys_in, values_in, keys_out, values_out,
235241
for_scatter);
236242
} else {
237-
LOG(FATAL) << "Unsupported value dtype: " << value_dtype;
243+
LOG(FATAL) << "Unsupported value dtype: " << value_dtype
244+
<< ". Supported value dtypes are int32, int64 and float32";
238245
}
239246
} else if (key_dtype == "float32") {
240247
if (value_dtype == "int32") {
@@ -247,10 +254,12 @@ TVM_REGISTER_GLOBAL("tvm.contrib.thrust.stable_sort_by_key")
247254
thrust_stable_sort_by_key<float, float>(keys_in, values_in, keys_out, values_out,
248255
for_scatter);
249256
} else {
250-
LOG(FATAL) << "Unsupported value dtype: " << value_dtype;
257+
LOG(FATAL) << "Unsupported value dtype: " << value_dtype
258+
<< ". Supported value dtypes are int32, int64 and float32";
251259
}
252260
} else {
253-
LOG(FATAL) << "Unsupported key dtype: " << key_dtype;
261+
LOG(FATAL) << "Unsupported key dtype: " << key_dtype
262+
<< ". Supported key dtypes are int32, int64, and float32.";
254263
}
255264
});
256265

0 commit comments

Comments
 (0)