Skip to content

Commit dd28f09

Browse files
authored
fix issue when build with hipblasLt on rocm6.1 (#22553)
### Description <!-- Describe your changes. --> The hipblasLt library is released with ROCm 6.x, and onnxruntime's current code needs some modifications to match the new hipblasLt API. ### Motivation and Context <!-- - Why is this change required? What problem does it solve? - If it fixes an open issue, please link to the issue here. -->
1 parent 7ad7873 commit dd28f09

File tree

3 files changed

+16
-17
lines changed

3 files changed

+16
-17
lines changed

onnxruntime/core/providers/rocm/rocm_call.cc

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -170,9 +170,4 @@ template Status RocmCall<ncclResult_t, false>(ncclResult_t retCode, const char*
170170
template void RocmCall<ncclResult_t, true>(ncclResult_t retCode, const char* exprString, const char* libName, ncclResult_t successCode, const char* msg, const char* file, const int line);
171171
#endif
172172

173-
#ifdef USE_HIPBLASLT
174-
template Status RocmCall<hipblasStatus_t, false>(hipblasStatus_t retCode, const char* exprString, const char* libName, hipblasStatus_t successCode, const char* msg, const char* file, const int line);
175-
template void RocmCall<hipblasStatus_t, true>(hipblasStatus_t retCode, const char* exprString, const char* libName, hipblasStatus_t successCode, const char* msg, const char* file, const int line);
176-
#endif
177-
178173
} // namespace onnxruntime

onnxruntime/core/providers/rocm/tunable/gemm_hipblaslt.h

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -37,26 +37,26 @@ enum ActivationType {
3737
};
3838

3939
template <typename T>
40-
constexpr hipblasltDatatype_t HipBlasDataTypeFor();
40+
constexpr hipDataType HipBlasDataTypeFor();
4141

4242
template <>
43-
constexpr hipblasltDatatype_t HipBlasDataTypeFor<float>() {
44-
return HIPBLASLT_R_32F;
43+
constexpr hipDataType HipBlasDataTypeFor<float>() {
44+
return HIP_R_32F;
4545
}
4646

4747
template <>
48-
constexpr hipblasltDatatype_t HipBlasDataTypeFor<half>() {
49-
return HIPBLASLT_R_16F;
48+
constexpr hipDataType HipBlasDataTypeFor<half>() {
49+
return HIP_R_16F;
5050
}
5151

5252
template <>
53-
constexpr hipblasltDatatype_t HipBlasDataTypeFor<BFloat16>() {
54-
return HIPBLASLT_R_16B;
53+
constexpr hipDataType HipBlasDataTypeFor<BFloat16>() {
54+
return HIP_R_16BF;
5555
}
5656

5757
template <>
58-
constexpr hipblasltDatatype_t HipBlasDataTypeFor<double>() {
59-
return HIPBLASLT_R_64F;
58+
constexpr hipDataType HipBlasDataTypeFor<double>() {
59+
return HIP_R_64F;
6060
}
6161

6262
template <BlasOp Op>
@@ -108,7 +108,7 @@ auto GetHipBlasLtTypeStringAndOps(ActivationType activation_type = ActivationTyp
108108

109109
hipblasOperation_t trans_a = MapBlasOpToHipBlasLt<OpB>();
110110
hipblasOperation_t trans_b = MapBlasOpToHipBlasLt<OpA>();
111-
hipblasltDatatype_t in_out_datatype = HipBlasDataTypeFor<T>();
111+
hipDataType in_out_datatype = HipBlasDataTypeFor<T>();
112112
std::vector<hipblasLtMatmulHeuristicResult_t> heuristic_result;
113113

114114
HIPBLASLT_CALL_THROW(hipblaslt_ext::getAllAlgos(handle,
@@ -119,7 +119,7 @@ auto GetHipBlasLtTypeStringAndOps(ActivationType activation_type = ActivationTyp
119119
in_out_datatype,
120120
in_out_datatype,
121121
in_out_datatype,
122-
HIPBLASLT_COMPUTE_F32,
122+
HIPBLAS_COMPUTE_32F,
123123
heuristic_result));
124124
HIPBLASLT_CALL_THROW(hipblasLtDestroy(handle));
125125

@@ -161,7 +161,7 @@ auto GetHipBlasLtTypeStringAndOps(ActivationType activation_type = ActivationTyp
161161
HIPBLASLT_RETURN_IF_ERROR(hipblasLtMatrixLayoutCreate(&mat_a, in_out_datatype, row_a, col_a, lda));
162162
HIPBLASLT_RETURN_IF_ERROR(hipblasLtMatrixLayoutCreate(&mat_b, in_out_datatype, row_b, col_b, ldb));
163163
HIPBLASLT_RETURN_IF_ERROR(hipblasLtMatrixLayoutCreate(&mat_c, in_out_datatype, row_c, col_c, ldc));
164-
HIPBLASLT_RETURN_IF_ERROR(hipblasLtMatmulDescCreate(&matmul, HIPBLASLT_COMPUTE_F32, HIPBLASLT_R_32F));
164+
HIPBLASLT_RETURN_IF_ERROR(hipblasLtMatmulDescCreate(&matmul, HIPBLAS_COMPUTE_32F, HIP_R_32F));
165165

166166
int batch = GetBatchCountFromParams<T>(params);
167167
if (batch > 1) {

tools/ci_build/amd_hipify.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,10 @@ def hipify(hipify_perl_path, src_file_path, dst_file_path):
2121
s = s.replace("kCudaStreamCopyIn", "kHipStreamCopyIn")
2222
s = s.replace("kCudaStreamCopyOut", "kHipStreamCopyOut")
2323
s = s.replace("kTotalCudaStreams", "kTotalHipStreams")
24+
25+
# In ROCm 6.0, hipify-perl's -roc option also maps __half to rocblas_half, which we don't want.
26+
s = s.replace("rocblas_half", "__half")
27+
2428
# these should be "hip" but it's easier to just use rocm to avoid complicated file renaming
2529
s = s.replace("CudaGraph", "RocmGraph")
2630
s = s.replace("CUDAGraph", "ROCMGraph")

0 commit comments

Comments
 (0)