Skip to content

Commit 931efc7

Browse files
authored
[Disco] Enable float8 data type in disco (#17398)
This PR enables the float8 data type in disco, except all reduce operation. Since in this PR, we pretend float8 to be uint8.
1 parent 85f2cc3 commit 931efc7

File tree

2 files changed

+9
-2
lines changed

2 files changed

+9
-2
lines changed

src/runtime/disco/nccl/nccl.cc

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -114,8 +114,12 @@ void AllReduce(NDArray send, ReduceKind reduce_kind, bool in_group, NDArray recv
114114
ShapeTuple shape = send.Shape();
115115
int64_t numel = shape->Product();
116116
deviceStream_t stream = ctx->GetDefaultStream();
117+
DataType dtype = DataType(send->dtype);
118+
if (dtype == DataType::NVFloat8E4M3() || dtype == DataType::NVFloat8E5M2()) {
119+
LOG(FATAL) << "Float8 data type cannot be allreduced, as nccl does not support this data type.";
120+
}
117121
NCCL_CALL(ncclAllReduce(send->data, recv->data, numel,
118-
/*datatype=*/AsNCCLDataType(DataType(send->dtype)),
122+
/*datatype=*/AsNCCLDataType(dtype),
119123
/*op=*/AsNCCLRedOp(reduce_kind),
120124
in_group ? ctx->group_comm : ctx->global_comm, stream));
121125
}

src/runtime/disco/nccl/nccl_context.h

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -86,7 +86,10 @@ inline ncclDataType_t AsNCCLDataType(runtime::DataType dtype) {
8686
if (dtype == DataType::Int(8)) {
8787
return ncclInt8;
8888
}
89-
if (dtype == DataType::UInt(8)) {
89+
if (dtype == DataType::UInt(8) || dtype == DataType::NVFloat8E4M3() ||
90+
dtype == DataType::NVFloat8E5M2()) {
91+
// For float8 data type, pretend to be uint8 in nccl.
92+
// And will throw error when allreduce, as it makes no sense in this case.
9093
return ncclUint8;
9194
}
9295
if (dtype == DataType::Int(32)) {

0 commit comments

Comments
 (0)