Skip to content

Commit 9ec7249

Browse files
authored
[TIR] Implement max/min_value for fp8 data types (#16723)
1 parent 939b8b9 commit 9ec7249

File tree

1 file changed

+12
-0
lines changed

1 file changed

+12
-0
lines changed

src/tir/op/op.cc

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -262,6 +262,12 @@ PrimExpr max_value(const DataType& dtype, Span span) {
262262
}
263263
} else if (dtype.is_bfloat16()) {
264264
return FloatImm(dtype, std::numeric_limits<float>::max(), span);
265+
} else if (dtype.is_float8()) {
266+
if (dtype.code() == DataType::TypeCode::kE5M2Float) {
267+
return FloatImm(dtype, 57344.0, span);
268+
} else if (dtype.code() == DataType::TypeCode::kE4M3Float) {
269+
return FloatImm(dtype, 448.0, span);
270+
}
265271
}
266272
LOG(FATAL) << "Cannot decide max_value for type" << dtype;
267273
}
@@ -296,6 +302,12 @@ PrimExpr min_value(const DataType& dtype, Span span) {
296302
}
297303
} else if (dtype.is_bfloat16()) {
298304
return FloatImm(dtype, std::numeric_limits<float>::lowest(), span);
305+
} else if (dtype.is_float8()) {
306+
if (dtype.code() == DataType::TypeCode::kE5M2Float) {
307+
return FloatImm(dtype, -57344.0, span);
308+
} else if (dtype.code() == DataType::TypeCode::kE4M3Float) {
309+
return FloatImm(dtype, -448.0, span);
310+
}
299311
}
300312
LOG(FATAL) << "Cannot decide min_value for type" << dtype;
301313
}

0 commit comments

Comments
 (0)