diff --git a/src/runtime/disco/nccl/nccl.cc b/src/runtime/disco/nccl/nccl.cc index a5240aa2b2c5..6ee54e14f37b 100644 --- a/src/runtime/disco/nccl/nccl.cc +++ b/src/runtime/disco/nccl/nccl.cc @@ -114,8 +114,12 @@ void AllReduce(NDArray send, ReduceKind reduce_kind, bool in_group, NDArray recv ShapeTuple shape = send.Shape(); int64_t numel = shape->Product(); deviceStream_t stream = ctx->GetDefaultStream(); + DataType dtype = DataType(send->dtype); + if (dtype == DataType::NVFloat8E4M3() || dtype == DataType::NVFloat8E5M2()) { + LOG(FATAL) << "Float8 data type cannot be allreduced, as nccl does not support this data type."; + } NCCL_CALL(ncclAllReduce(send->data, recv->data, numel, - /*datatype=*/AsNCCLDataType(DataType(send->dtype)), + /*datatype=*/AsNCCLDataType(dtype), /*op=*/AsNCCLRedOp(reduce_kind), in_group ? ctx->group_comm : ctx->global_comm, stream)); } diff --git a/src/runtime/disco/nccl/nccl_context.h b/src/runtime/disco/nccl/nccl_context.h index b874da219fe4..6c1eaf749a67 100644 --- a/src/runtime/disco/nccl/nccl_context.h +++ b/src/runtime/disco/nccl/nccl_context.h @@ -86,7 +86,10 @@ inline ncclDataType_t AsNCCLDataType(runtime::DataType dtype) { if (dtype == DataType::Int(8)) { return ncclInt8; } - if (dtype == DataType::UInt(8)) { + if (dtype == DataType::UInt(8) || dtype == DataType::NVFloat8E4M3() || + dtype == DataType::NVFloat8E5M2()) { + // For float8 data type, pretend to be uint8 in nccl. + // And will throw error when allreduce, as it makes no sense in this case. return ncclUint8; } if (dtype == DataType::Int(32)) {