Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 21 additions & 0 deletions src/target/source/codegen_cuda.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1260,6 +1260,21 @@ void CodeGenCUDA::VisitExpr_(const BroadcastNode* op, std::ostream& os) { // NO
return;
}

if (op->dtype.is_float8()) {
int lanes = op->dtype.lanes();
ICHECK(lanes == 1 || lanes == 2 || lanes == 4);
std::string v = PrintExpr(op->value);
// Implicit conversion from float back to fp8
PrintType(op->dtype, os);
os << "(make_float" << lanes << "(";
for (int i = 0; i < lanes; ++i) {
if (i != 0) os << ", ";
os << "static_cast<float>(" << v << ")";
}
os << "))";
return;
}

if ((op->dtype.is_int() || op->dtype.is_uint()) && op->dtype.bits() == 4) {
bool fail = false;
const int64_t* p = as_const_int(op->value);
Expand Down Expand Up @@ -1359,6 +1374,12 @@ inline void PrintConst(const FloatImmNode* op, std::ostream& os, CodeGenCUDA* p)
os << '(' << std::scientific << op->value << 'f' << ')';
return;
}
// Type code is kE5M2Float or kE4M4Float
if (op->dtype.is_float8()) {
p->PrintType(op->dtype, os);
os << '(' << std::scientific << op->value << 'f' << ')';
return;
}
// Type code is kFloat
switch (op->dtype.bits()) {
case 64:
Expand Down
15 changes: 15 additions & 0 deletions tests/python/codegen/test_target_codegen_cuda_fp8.py
Original file line number Diff line number Diff line change
Expand Up @@ -799,5 +799,20 @@ def test_main(self, weight_shape, model_dtype, target_str, compiled_functions):
tvm.testing.assert_allclose(weight_np, dequant_weight_np, atol=10, rtol=5e-2)


@tvm.testing.requires_cuda_compute_version(9)
@pytest.mark.parametrize("dtype", ["e5m2_float8", "e4m3_float8"])
def test_const(dtype):
@T.prim_func
def func(A: T.Buffer((4,), dtype)) -> None:
A_local = T.alloc_buffer((4,), dtype=dtype, scope="local")
for tx in T.thread_binding(0, 4, "threadIdx.x"):
for i in T.vectorized(4):
A_local[i] = T.float32(1.0).astype(dtype)
A[tx] = A_local[tx]

mod = tvm.IRModule({"main": func})
tvm.build(mod, target="cuda")


if __name__ == "__main__":
tvm.testing.main()