@@ -29,23 +29,24 @@ Tensor& sub_out(
2929 InvalidArgument,
3030 out);
3131
32- ET_KERNEL_CHECK (ctx, tensor_is_realhb_type (out), InvalidArgument, out);
32+ ET_KERNEL_CHECK (ctx, tensor_is_realh_type (out), InvalidArgument, out);
3333
3434 ScalarType a_type = a.scalar_type ();
3535 ScalarType b_type = b.scalar_type ();
3636 ScalarType alpha_type = utils::get_scalar_dtype (alpha);
3737 ScalarType common_type = promoteTypes (a_type, b_type, /* half_to_float*/ true );
3838 ScalarType out_type = out.scalar_type ();
3939
40+ ET_KERNEL_CHECK (ctx, canCast (common_type, out_type), InvalidArgument, out);
4041 ET_KERNEL_CHECK (
4142 ctx, check_alpha_type (alpha_type, common_type), InvalidArgument, out);
42- ET_KERNEL_CHECK (ctx, canCast (common_type, out_type), InvalidArgument, out);
43- ET_KERNEL_CHECK (ctx, tensor_is_realh_type (out), InvalidArgument, out);
4443
45- ET_SWITCH_REALH_TYPES (a_type, ctx, " sub.out" , CTYPE_A, [&]() {
46- ET_SWITCH_REALH_TYPES (b_type, ctx, " sub.out" , CTYPE_B, [&]() {
47- ET_SWITCH_REAL_TYPES (common_type, ctx, " sub.out" , CTYPE_IN, [&]() {
48- ET_SWITCH_REALH_TYPES (out_type, ctx, " sub.out" , CTYPE_OUT, [&]() {
44+ constexpr auto name = " sub.out" ;
45+
46+ ET_SWITCH_REALH_TYPES (a_type, ctx, name, CTYPE_A, [&]() {
47+ ET_SWITCH_REALH_TYPES (b_type, ctx, name, CTYPE_B, [&]() {
48+ ET_SWITCH_REAL_TYPES (common_type, ctx, name, CTYPE_IN, [&]() {
49+ ET_SWITCH_REALH_TYPES (out_type, ctx, name, CTYPE_OUT, [&]() {
4950 CTYPE_IN alpha_val;
5051 utils::extract_scalar (alpha, &alpha_val);
5152
@@ -84,11 +85,11 @@ Tensor& sub_scalar_out(
8485 out,
8586 " Failed to resize output tensor." );
8687
87- ET_KERNEL_CHECK (ctx, tensor_is_realhb_type (out), InvalidArgument, out);
88+ ET_KERNEL_CHECK (ctx, tensor_is_realh_type (out), InvalidArgument, out);
8889
8990 ScalarType a_type = a.scalar_type ();
9091 ScalarType b_type = utils::get_scalar_dtype (b);
91- ScalarType alpha_type = utils::get_scalar_dtype (b );
92+ ScalarType alpha_type = utils::get_scalar_dtype (alpha );
9293 ScalarType common_type =
9394 utils::promote_type_with_scalar (a_type, b, /* half_to_float*/ false );
9495 ScalarType out_type = out.scalar_type ();
@@ -100,31 +101,30 @@ Tensor& sub_scalar_out(
100101 common_type = ScalarType::Float;
101102 }
102103
103- ET_SWITCH_REALH_TYPES (a_type, ctx, " sub.Scalar_out" , CTYPE_A, [&]() {
104- ET_SWITCH_SCALAR_OBJ_REAL_TYPES (
105- b_type, ctx, " sub.Scalar_out" , CTYPE_B, [&]() {
106- ET_SWITCH_REAL_TYPES (
107- common_type, ctx, " sub.Scalar_out" , CTYPE_IN, [&]() {
108- ET_SWITCH_REALH_TYPES (
109- out_type, ctx, " sub.Scalar_out" , CTYPE_OUT, [&]() {
110- CTYPE_B b_val;
111- utils::extract_scalar (b, &b_val);
112- CTYPE_IN b_casted = static_cast <CTYPE_IN>(b_val);
113- CTYPE_IN alpha_val;
114- utils::extract_scalar (alpha, &alpha_val);
115-
116- apply_unary_map_fn (
117- [b_casted, alpha_val](const CTYPE_A val_a) {
118- CTYPE_IN a_casted = static_cast <CTYPE_IN>(val_a);
119- CTYPE_IN value = a_casted - alpha_val * b_casted;
120- return static_cast <CTYPE_OUT>(value);
121- },
122- a.const_data_ptr <CTYPE_A>(),
123- out.mutable_data_ptr <CTYPE_OUT>(),
124- out.numel ());
125- });
126- });
104+ constexpr auto name = " sub.Scalar_out" ;
105+
106+ ET_SWITCH_REALH_TYPES (a_type, ctx, name, CTYPE_A, [&]() {
107+ ET_SWITCH_SCALAR_OBJ_REAL_TYPES (b_type, ctx, name, CTYPE_B, [&]() {
108+ ET_SWITCH_REAL_TYPES (common_type, ctx, name, CTYPE_IN, [&]() {
109+ ET_SWITCH_REALH_TYPES (out_type, ctx, name, CTYPE_OUT, [&]() {
110+ CTYPE_B b_val;
111+ utils::extract_scalar (b, &b_val);
112+ CTYPE_IN b_casted = static_cast <CTYPE_IN>(b_val);
113+ CTYPE_IN alpha_val;
114+ utils::extract_scalar (alpha, &alpha_val);
115+
116+ apply_unary_map_fn (
117+ [b_casted, alpha_val](const CTYPE_A val_a) {
118+ CTYPE_IN a_casted = static_cast <CTYPE_IN>(val_a);
119+ CTYPE_IN value = a_casted - alpha_val * b_casted;
120+ return static_cast <CTYPE_OUT>(value);
121+ },
122+ a.const_data_ptr <CTYPE_A>(),
123+ out.mutable_data_ptr <CTYPE_OUT>(),
124+ out.numel ());
127125 });
126+ });
127+ });
128128 });
129129
130130 return out;
0 commit comments