diff --git a/csrc/ops.hip b/csrc/ops.hip index 4dcbbecfd..972209cfa 100644 --- a/csrc/ops.hip +++ b/csrc/ops.hip @@ -509,6 +509,7 @@ static std::string hipError_to_string(const hipError_t ret) } } + template int igemmlt(hipblasLtHandle_t ltHandle, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc) { #ifdef NO_HIPBLASLT @@ -524,28 +525,101 @@ template int igemmlt(hipblasLtHandl hipblasLtOrder_t col_ampere = HIPBLASLT_ORDER_COL; has_error |= checkHipblasStatus(hipblasLtMatrixLayoutCreate(&Adesc, HIP_R_8I, m, k, lda)); + if(has_error != 0) + { + std::cout<<"failed to run hipblasLtMatrixLayoutCreate for Adesc:"< int igemmlt(hipblasLtHandl } else { - has_error |= checkHipblasStatus(hipblasLtMatmul(ltHandle, matmulDesc,&alpha, A, Adesc, B, Bdesc, &beta, (int32_t*)C, Cdesc, (int32_t*)C, Cdesc, &heuristicResult[0].algo, nullptr, 0, 0)); + void* d_workspace=nullptr; + uint64_t workspace_size = 0; + for(int i = 0; i < returnedAlgoCount; i++) + workspace_size = max(workspace_size, heuristicResult[i].workspaceSize); + hipMalloc(&d_workspace, workspace_size); + has_error |= checkHipblasStatus(hipblasLtMatmul(ltHandle, matmulDesc,&alpha, A, Adesc, B, Bdesc, &beta, (int32_t*)C, Cdesc, (int32_t*)C, Cdesc, &heuristicResult[0].algo, d_workspace, workspace_size, 0)); + hipFree(d_workspace); + if(has_error != 0) + { + std::cout<<"failed to run hipblasLtMatmul"< int igemmlt(hipblasLtHandl } else { + uint64_t workspace_size = 0; + for(int i = 0; i < returnedAlgoCount; i++) + workspace_size = max(workspace_size, heuristicResult[i].workspaceSize); + void* d_workspace=nullptr; + hipMalloc(&d_workspace, workspace_size); if(!SCALE_ROWS) { float alpha = 1.0f, beta = 0.0f; - - has_error |= checkHipblasStatus(hipblasLtMatmul(ltHandle, matmulDesc,&alpha, A, Adesc, B, Bdesc, &beta, (int8_t*)C, Cdesc, (int8_t*)C, Cdesc, &heuristicResult[0].algo, nullptr, 0, 0)); + has_error |= checkHipblasStatus(hipblasLtMatmul(ltHandle, matmulDesc,&alpha, 
A, Adesc, B, Bdesc, &beta, (int8_t*)C, Cdesc, (int8_t*)C, Cdesc, &heuristicResult[0].algo, d_workspace, workspace_size, 0)); } else { float beta = 0.0f; - - has_error |= checkHipblasStatus(hipblasLtMatmul(ltHandle, matmulDesc, row_scale, A, Adesc, B, Bdesc, &beta, (int8_t*)C, Cdesc, (int8_t*)C, Cdesc, &heuristicResult[0].algo, nullptr, 0, 0)); + has_error |= checkHipblasStatus(hipblasLtMatmul(ltHandle, matmulDesc, row_scale, A, Adesc, B, Bdesc, &beta, (int8_t*)C, Cdesc, (int8_t*)C, Cdesc, &heuristicResult[0].algo, d_workspace, workspace_size, 0)); + } + hipFree(d_workspace); + if(has_error != 0) + { + std::cout<<"failed to run hipblasLtMatmul with int8"< int igemmlt(hipblasLtHandl #endif // NO_HIPBLASLT } + int fill_up_to_nearest_multiple(int value, int multiple) { return value + (value % multiple == 0 ? 0 : (multiple - (value % multiple))); diff --git a/csrc/ops_hip.cuh b/csrc/ops_hip.cuh index e57cbb3b5..49e02ade6 100644 --- a/csrc/ops_hip.cuh +++ b/csrc/ops_hip.cuh @@ -101,17 +101,20 @@ typedef enum Funcs_t class Context { public: - rocblas_handle m_handle; + hipblasLtHandle_t m_handle; + //rocblas_handle m_handle; Context() { - rocblas_handle handle; - rocblas_create_handle(&handle); + //rocblas_handle handle; + //rocblas_create_handle(&handle); + hipblasLtHandle_t handle; + hipblasLtCreate(&handle); m_handle = handle; } - }; +/* class ContextLt { public: @@ -124,6 +127,7 @@ class ContextLt m_handle = handle; } }; +*/ class ContextHipsparse {