@@ -262,6 +262,12 @@ PrimExpr max_value(const DataType& dtype, Span span) {
262262 }
263263 } else if (dtype.is_bfloat16 ()) {
264264 return FloatImm (dtype, std::numeric_limits<float >::max (), span);
265+ } else if (dtype.is_float8 ()) {
266+ if (dtype.code () == DataType::TypeCode::kE5M2Float ) {
267+ return FloatImm (dtype, 57344.0 , span);
268+ } else if (dtype.code () == DataType::TypeCode::kE4M3Float ) {
269+ return FloatImm (dtype, 448.0 , span);
270+ }
265271 }
266272 LOG (FATAL) << " Cannot decide max_value for type" << dtype;
267273}
@@ -296,6 +302,12 @@ PrimExpr min_value(const DataType& dtype, Span span) {
296302 }
297303 } else if (dtype.is_bfloat16 ()) {
298304 return FloatImm (dtype, std::numeric_limits<float >::lowest (), span);
305+ } else if (dtype.is_float8 ()) {
306+ if (dtype.code () == DataType::TypeCode::kE5M2Float ) {
307+ return FloatImm (dtype, -57344.0 , span);
308+ } else if (dtype.code () == DataType::TypeCode::kE4M3Float ) {
309+ return FloatImm (dtype, -448.0 , span);
310+ }
299311 }
300312 LOG (FATAL) << " Cannot decide min_value for type" << dtype;
301313}
0 commit comments