Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
75 changes: 75 additions & 0 deletions csrc/ops.hip
Original file line number Diff line number Diff line change
Expand Up @@ -476,13 +476,51 @@ template <int DTYPE_OUT, int SCALE_ROWS> int igemmlt(
hipblasLtPointerMode_t pointerMode = HIPBLASLT_POINTER_MODE_ALPHA_DEVICE_VECTOR_BETA_HOST;

has_error |= checkHipblasStatus(hipblasLtMatrixLayoutCreate(&aDesc, HIP_R_8I, m, k, lda));
if(has_error != 0)
{
std::cout<<"failed to run hipblasLtMatrixLayoutCreate for Adesc:"<<m<<" "<<k<<" "<<lda<<std::endl;
return has_error;
}

has_error |= checkHipblasStatus(hipblasLtMatrixLayoutCreate(&bDesc, HIP_R_8I, m, n, ldb));
if(has_error != 0)
{
std::cout<<"failed to run hipblasLtMatrixLayoutCreate for Bdesc:"<<m<<" "<<n<<" "<<ldb<<std::endl;
hipblasLtMatrixLayoutDestroy(aDesc);
return has_error;
}

has_error |= checkHipblasStatus(hipblasLtMatrixLayoutCreate(&cDesc, outType, k, n, ldc));
if(has_error != 0)
{
std::cout<<"failed to run hipblasLtMatrixLayoutCreate for Cdesc"<<k<<" "<<n<<" "<<ldc<<std::endl;
hipblasLtMatrixLayoutDestroy(aDesc);
hipblasLtMatrixLayoutDestroy(bDesc);
return has_error;
}

// Default layout order is col major

has_error |= checkHipblasStatus(hipblasLtMatmulDescCreate(&matmulDesc, HIPBLAS_COMPUTE_32I, scaleType));
if(has_error != 0)
{
std::cout<<"failed to run hipblasLtMatmulDescCreate"<<std::endl;
hipblasLtMatrixLayoutDestroy(aDesc);
hipblasLtMatrixLayoutDestroy(bDesc);
hipblasLtMatrixLayoutDestroy(cDesc);
return has_error;
}

has_error |= checkHipblasStatus(hipblasLtMatmulDescSetAttribute(matmulDesc, HIPBLASLT_MATMUL_DESC_TRANSA, &opT, sizeof(opT)));
if(has_error != 0)
{
std::cout<<"failed to run hipblasLtMatmulDescSetAttribute HIPBLASLT_MATMUL_DESC_TRANSA"<<std::endl;
hipblasLtMatrixLayoutDestroy(aDesc);
hipblasLtMatrixLayoutDestroy(bDesc);
hipblasLtMatrixLayoutDestroy(cDesc);
hipblasLtMatmulDescDestroy(matmulDesc);
return has_error;
}

if (DTYPE_OUT == 32) {

Expand Down Expand Up @@ -524,6 +562,15 @@ template <int DTYPE_OUT, int SCALE_ROWS> int igemmlt(
(int32_t*)C, cDesc,
&heuristicResult[0].algo, NULL, 0, stream
));
if(has_error != 0)
{
std::cout<<"failed to run hipblasLtMatmul"<<std::endl;
hipblasLtMatrixLayoutDestroy(aDesc);
hipblasLtMatrixLayoutDestroy(bDesc);
hipblasLtMatrixLayoutDestroy(cDesc);
hipblasLtMatmulDescDestroy(matmulDesc);
return has_error;
}
}
} else {
// This path is unlikely to be used, as 8-bit accumulation can lead to likely overflows.
Expand All @@ -538,6 +585,15 @@ template <int DTYPE_OUT, int SCALE_ROWS> int igemmlt(
(int8_t*)C, cDesc,
NULL, NULL, 0, stream
));
if(has_error != 0)
{
std::cout<<"failed to run hipblasLtMatmul with int8"<<std::endl;
hipblasLtMatrixLayoutDestroy(aDesc);
hipblasLtMatrixLayoutDestroy(bDesc);
hipblasLtMatrixLayoutDestroy(cDesc);
hipblasLtMatmulDescDestroy(matmulDesc);
return has_error;
}
} else {
hipblasLtPointerMode_t alphaVec = HIPBLASLT_POINTER_MODE_ALPHA_DEVICE_VECTOR_BETA_HOST;
float beta = 0.0f;
Expand All @@ -547,6 +603,16 @@ template <int DTYPE_OUT, int SCALE_ROWS> int igemmlt(
&pointerMode,
sizeof(alphaVec)
));
if(has_error != 0)
{
std::cout<<"failed to run hipblasLtMatmulDescSetAttribute HIPBLASLT_MATMUL_DESC_POINTER_MODE for int8"<<std::endl;
hipblasLtMatrixLayoutDestroy(aDesc);
hipblasLtMatrixLayoutDestroy(bDesc);
hipblasLtMatrixLayoutDestroy(cDesc);
hipblasLtMatmulDescDestroy(matmulDesc);
return has_error;
}

has_error |= checkHipblasStatus(hipblasLtMatmul(
ltHandle, matmulDesc,
row_scale, A, aDesc,
Expand All @@ -555,6 +621,15 @@ template <int DTYPE_OUT, int SCALE_ROWS> int igemmlt(
(int8_t*)C, cDesc,
NULL, NULL, 0, stream
));
if(has_error != 0)
{
std::cout<<"failed to run hipblasLtMatmul with int8"<<std::endl;
hipblasLtMatrixLayoutDestroy(aDesc);
hipblasLtMatrixLayoutDestroy(bDesc);
hipblasLtMatrixLayoutDestroy(cDesc);
hipblasLtMatmulDescDestroy(matmulDesc);
return has_error;
}
}
}

Expand Down
20 changes: 5 additions & 15 deletions csrc/ops_hip.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -99,28 +99,18 @@ typedef enum Funcs_t
class Context
{
public:
rocblas_handle m_handle;
hipblasLtHandle_t m_handle;
//rocblas_handle m_handle;

Context()
{
rocblas_handle handle;
rocblas_create_handle(&handle);
m_handle = handle;
}

};

class ContextLt
{
public:
hipblasLtHandle_t m_handle;

ContextLt()
{
//rocblas_handle handle;
//rocblas_create_handle(&handle);
hipblasLtHandle_t handle;
hipblasLtCreate(&handle);
m_handle = handle;
}

};

class ContextHipsparse
Expand Down