Skip to content

Commit f30dc38

Browse files
authored
Merge pull request #13 from ROCm/enable_matmul
Enable matmul function
2 parents fc9bf4d + 90bbdc6 commit f30dc38

File tree

6 files changed

+85
-37
lines changed

6 files changed

+85
-37
lines changed

bitsandbytes/autograd/_functions.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -224,10 +224,8 @@ def backward(ctx, grad_output):
224224

225225
def supports_igemmlt(device: torch.device) -> bool:
226226
"""check if this device supports the optimized int8 kernel"""
227-
"""Important: Could I use igemmlt on ROCm? """
228227
if torch.version.hip:
229-
#Well, lets currently disable it
230-
return False
228+
return True
231229
if torch.cuda.get_device_capability(device=device) < (7, 5):
232230
return False
233231
device_name = torch.cuda.get_device_name(device=device)

bitsandbytes/functional.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -452,13 +452,13 @@ def get_transform_buffer(
452452
rows = shape[0] * shape[1]
453453
cols = shape[-1]
454454

455-
state = (shape, to_order)
456455
if transpose:
457456
# swap dims
458457
tmp = rows
459458
rows = cols
460459
cols = tmp
461-
state = (shape[::-1], to_order)
460+
shape = shape[::-1]
461+
state = (shape, to_order)
462462

463463
if to_order == "row" or to_order == "col":
464464
return init_func(shape, dtype=dtype, device=device), state
@@ -1980,6 +1980,8 @@ def mm_dequant(
19801980
new_col_stats=None,
19811981
bias=None
19821982
):
1983+
if HIP_ENVIRONMENT:
1984+
A, quant_state = nvidia_transform(A, "row", state = quant_state)
19831985
assert A.dtype == torch.int32
19841986
if bias is not None: assert bias.dtype == torch.float16
19851987
out_shape = quant_state[0]
@@ -2552,7 +2554,10 @@ def dequant_min_max(xq, A, B, SA, SB, dtype=torch.half):
25522554
def extract_outliers(A, SA, idx):
25532555
shapeA = SA[0]
25542556
formatA = SA[1]
2555-
assert formatA in ["col_turing", "col_ampere"]
2557+
if not HIP_ENVIRONMENT:
2558+
assert formatA in ["col_turing", "col_ampere"]
2559+
else:
2560+
assert formatA in ["col"]
25562561
assert A.device.type == "cuda"
25572562

25582563
out = torch.zeros(
@@ -2567,7 +2572,7 @@ def extract_outliers(A, SA, idx):
25672572
ptrOut = get_ptr(out)
25682573

25692574
prev_device = pre_call(A.device)
2570-
if formatA == 'col_turing':
2575+
if formatA == 'col_turing' or HIP_ENVIRONMENT:
25712576
lib.cextractOutliers_turing(ptrA, ptrIdx, ptrOut, idx_size, rows, cols)
25722577
elif formatA == "col_ampere":
25732578
lib.cextractOutliers_ampere(ptrA, ptrIdx, ptrOut, idx_size, rows, cols)

csrc/kernels.hip

Lines changed: 55 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -2300,13 +2300,16 @@ template <int ITEMS_PER_THREAD, int SUBTILE_ROWS, int THREADS>__global__ void kd
23002300

23012301
const int n_out = numRows*numCols;
23022302

2303-
int num_row_tiles = (numRows/SUBTILE_ROWS) + (numRows % SUBTILE_ROWS == 0 ? 0 : 1);
2303+
//int num_row_tiles = (numRows/SUBTILE_ROWS) + (numRows % SUBTILE_ROWS == 0 ? 0 : 1);
23042304
// we have tiles of size numRows*32, thus col only increases every numRows
23052305
// num_row_tiles is the tiles after which the column increases by 32
23062306
// blockIdx.x is the index of the current tile
2307-
int col = ((threadIdx.x % 32) + ((blockIdx.x/num_row_tiles)*32));
2307+
//int col = ((threadIdx.x % 32) + ((blockIdx.x/num_row_tiles)*32));
23082308
// base_row increases by SUBTILE_ROWS every block. It wraps back to zero once num_row_tiles is reached
2309-
int base_row = (blockIdx.x*SUBTILE_ROWS) % (num_row_tiles*SUBTILE_ROWS);
2309+
//int base_row = (blockIdx.x*SUBTILE_ROWS) % (num_row_tiles*SUBTILE_ROWS);
2310+
2311+
int block_offset = blockIdx.x * THREADS * ITEMS_PER_THREAD;
2312+
int thread_offset = threadIdx.x * ITEMS_PER_THREAD;
23102313

23112314
// SUBTILE_ROWS is independent from ITEMS_PER_THREAD is independent from THREADS
23122315
// subtiles have 32*SUBTILE_ROWS elements <= THREADS*ITEMS_PER_THREAD
@@ -2321,33 +2324,59 @@ template <int ITEMS_PER_THREAD, int SUBTILE_ROWS, int THREADS>__global__ void kd
23212324

23222325
int local_values[ITEMS_PER_THREAD];
23232326
half local_output[ITEMS_PER_THREAD];
2324-
float local_rowStats[ITEMS_PER_THREAD];
2325-
__shared__ float smem_rowStats[SUBTILE_ROWS];
2327+
//float local_rowStats[ITEMS_PER_THREAD];
2328+
//__shared__ float smem_rowStats[SUBTILE_ROWS];
23262329

23272330
typedef hipcub::BlockLoad<int, THREADS, ITEMS_PER_THREAD, hipcub::BLOCK_LOAD_DIRECT> LoadInt32;
2328-
typedef hipcub::BlockExchange<int, THREADS, ITEMS_PER_THREAD> ExchangeInt32;
2331+
//typedef hipcub::BlockExchange<int, THREADS, ITEMS_PER_THREAD> ExchangeInt32;
23292332
__shared__ typename LoadInt32::TempStorage loadint32;
2330-
__shared__ typename ExchangeInt32::TempStorage exchangeint32;
2333+
//__shared__ typename ExchangeInt32::TempStorage exchangeint32;
23312334

23322335

23332336
// L1. Load sub-tile row/col statistics. Each thread only holds 1 col, load rows into shared memory.
2334-
float colStat = col >= numCols ? 0.0f : colStats[col];
2335-
float local_biasValue = ((bias == NULL) || (col >= numCols)) ? 0.0f : __half2float(bias[col]);
2337+
//float colStat = col >= numCols ? 0.0f : colStats[col];
2338+
//float local_biasValue = ((bias == NULL) || (col >= numCols)) ? 0.0f : __half2float(bias[col]);
2339+
int row_idx, col_idx;
2340+
float colStat[ITEMS_PER_THREAD];
2341+
float local_biasValue[ITEMS_PER_THREAD];
2342+
float rowStat[ITEMS_PER_THREAD];
2343+
#pragma unroll ITEMS_PER_THREAD
2344+
for(int j = 0; j < ITEMS_PER_THREAD; j++)
2345+
{
2346+
row_idx = (block_offset + thread_offset + j) / numCols;
2347+
col_idx = (block_offset + thread_offset + j) % numCols;
2348+
colStat[j] = col_idx >= numCols ? 0.0f : colStats[col_idx];
2349+
local_biasValue[j] = ((bias == NULL) || (col_idx >= numCols)) ? 0.0f : __half2float(bias[col_idx]);
2350+
rowStat[j] = row_idx >= numRows ? 0.0f : rowStats[row_idx];
2351+
}
23362352
// no block loads for rows for now -- keep it simple
2337-
for(int j = threadIdx.x; j < SUBTILE_ROWS; j+=blockDim.x)
2353+
/*for(int j = threadIdx.x; j < SUBTILE_ROWS; j+=blockDim.x)
23382354
{
23392355
// todo: is this global mem access slow due to overlaps or does the L1 cache work well here?
23402356
int row = (base_row+j) % numRows; // wrap around
23412357
// each warp accesses the same element, for four consequitive elements
23422358
// todo: update description about striped shared memory, it is not needed
23432359
// rowidx: [0, 1, 2, 3...] and each warp reads ITEMS_PER_THREAD consequitive elements
23442360
smem_rowStats[j] = rowStats[row];
2345-
}
2361+
}*/
23462362
__syncthreads();
23472363

2364+
int valid_items = block_offset + THREADS * ITEMS_PER_THREAD < n_out ? THREADS * ITEMS_PER_THREAD : n_out - block_offset;
2365+
LoadInt32(loadint32).Load(&(A[block_offset]), local_values, valid_items, 0);
23482366

2367+
#pragma unroll ITEMS_PER_THREAD
2368+
for(int j = 0; j < ITEMS_PER_THREAD; j++)
2369+
local_output[j] = __float2half((local_values[j]*MM_DEQUANT_CONST*rowStat[j]*colStat[j]) + local_biasValue[j]);
2370+
23492371
// each block processes SUBTILE_ROWS*32 elements
2350-
const int items_per_load = THREADS*ITEMS_PER_THREAD;
2372+
#pragma unroll ITEMS_PER_THREAD
2373+
for(int j = 0; j < ITEMS_PER_THREAD; j++)
2374+
{
2375+
int outIdx = block_offset + thread_offset + j;
2376+
if(outIdx< n_out)
2377+
out[outIdx] = local_output[j];
2378+
}
2379+
/*const int items_per_load = THREADS*ITEMS_PER_THREAD;
23512380
const int rows_per_load = items_per_load/32;
23522381
23532382
int subtile_base_row = (threadIdx.x / 32)*ITEMS_PER_THREAD; // row within the tile
@@ -2368,7 +2397,7 @@ template <int ITEMS_PER_THREAD, int SUBTILE_ROWS, int THREADS>__global__ void kd
23682397
#pragma unroll ITEMS_PER_THREAD
23692398
for(int j = 0; j < ITEMS_PER_THREAD; j++)
23702399
local_rowStats[j] = smem_rowStats[subtile_base_row+row_offset+j];
2371-
2400+
23722401
#pragma unroll ITEMS_PER_THREAD
23732402
for(int j = 0; j < ITEMS_PER_THREAD; j++)
23742403
local_output[j] = __float2half((local_values[j]*MM_DEQUANT_CONST*local_rowStats[j]*colStat) + local_biasValue);
@@ -2388,7 +2417,7 @@ template <int ITEMS_PER_THREAD, int SUBTILE_ROWS, int THREADS>__global__ void kd
23882417
}
23892418
23902419
row_offset += rows_per_load;
2391-
}
2420+
}*/
23922421
}
23932422

23942423

@@ -2974,7 +3003,7 @@ template <int FORMAT> __global__ void kExtractOutliers(char *A, int *idx, char *
29743003
{
29753004
int local_colidx = idx[blockIdx.x];
29763005

2977-
if(FORMAT==COL_TURING)
3006+
/*if(FORMAT==COL_TURING)
29783007
{
29793008
// TURING FORMAT:
29803009
// 8*32 tiles with 4*4 subtiles
@@ -3030,6 +3059,17 @@ template <int FORMAT> __global__ void kExtractOutliers(char *A, int *idx, char *
30303059
int out_idx = (row*idx_size) + blockIdx.x;
30313060
out[out_idx] = val;
30323061
}
3062+
}*/
3063+
3064+
//Only col format is used on ROCm
3065+
for(int row = threadIdx.x; row < rowsA; row+= blockDim.x)
3066+
{
3067+
//col-major offset
3068+
int offset = local_colidx * rowsA + row;
3069+
3070+
char val = A[offset];
3071+
int out_idx = (row*idx_size) + blockIdx.x;
3072+
out[out_idx] = val;
30333073
}
30343074
}
30353075

csrc/ops.hip

Lines changed: 20 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -587,7 +587,9 @@ template <int FORMATB, int DTYPE_OUT, int SCALE_ROWS> int igemmlt(hipblasLtHandl
587587
}
588588
else
589589
{
590-
has_error |= checkHipblasStatus(hipblasLtMatmulDescCreate(&matmulDesc, HIPBLAS_COMPUTE_32I, HIP_R_32F));
590+
has_error |= checkHipblasStatus(hipblasLtMatmulDescCreate(&matmulDesc, HIPBLAS_COMPUTE_32I, HIP_R_8I));
591+
hipblasOperation_t opA = HIPBLAS_OP_N;
592+
has_error |= checkHipblasStatus(hipblasLtMatmulDescSetAttribute(matmulDesc, HIPBLASLT_MATMUL_DESC_TRANSA, &opA, sizeof(opA)));
591593
has_error |= checkHipblasStatus(hipblasLtMatmulDescSetAttribute(matmulDesc, HIPBLASLT_MATMUL_DESC_TRANSB, &opT, sizeof(opT)));
592594
has_error |= checkHipblasStatus(hipblasLtMatrixLayoutCreate(&Cdesc, HIP_R_8I, m, n, ldc));
593595
has_error |= checkHipblasStatus(hipblasLtMatrixLayoutSetAttribute(Cdesc, HIPBLASLT_MATRIX_LAYOUT_ORDER, &col32, sizeof(col32)));
@@ -651,14 +653,18 @@ int fill_up_to_nearest_multiple(int value, int multiple)
651653
void dequant_mm_int32_fp16(int *A, float *rowStats, float *colStats, half *out, float* newRowStats, float* newcolStats, half *bias, int numRows, int numCols)
652654
{
653655
int threads = 512;
654-
int tileCols = fill_up_to_nearest_multiple(numCols, 32);
655-
int n = numRows*tileCols;
656-
int subtile_rows = 128;
657-
int tilesize = 32*subtile_rows;
658-
int num_blocks = numRows/subtile_rows;
659-
num_blocks += (numRows % subtile_rows == 0) ? 0 : 1;
660-
num_blocks = num_blocks*(tileCols/32);
661-
assert(threads <= tilesize);
656+
//int tileCols = fill_up_to_nearest_multiple(numCols, 32);
657+
//int n = numRows*tileCols;
658+
int tileCols = numCols;
659+
int n = numRows*numCols;
660+
//int subtile_rows = 128;
661+
//int tilesize = 32*subtile_rows;
662+
//int num_blocks = numRows/subtile_rows;
663+
//num_blocks += (numRows % subtile_rows == 0) ? 0 : 1;
664+
//num_blocks = num_blocks*(tileCols/32);
665+
//assert(threads <= tilesize);
666+
int num_blocks = numRows * numCols / (threads * 4);
667+
num_blocks += (numRows * numCols) % (threads * 4) == 0 ? 0 : 1;
662668

663669
hipLaunchKernelGGL(( kdequant_mm_int32_fp16<4, 128, 512>), dim3(num_blocks), dim3(threads), 0, 0, A, rowStats, colStats, out, newRowStats, newcolStats, bias, numRows, numCols, tileCols, n);
664670
CUDA_CHECK_RETURN(hipPeekAtLastError());
@@ -820,14 +826,17 @@ template <int FORMAT> void extractOutliers(char * A, int *idx, char *out, int id
820826

821827
int num_blocks = idx_size;
822828

823-
if(FORMAT == COL_TURING)
829+
/*if(FORMAT == COL_TURING)
824830
{
825831
tiledRows = fill_up_to_nearest_multiple(rows, 8);
826832
}
827833
else if(FORMAT == COL_AMPERE)
828834
{
829835
tiledRows = fill_up_to_nearest_multiple(rows, 32);
830-
}
836+
}*/
837+
838+
//for col format on ROCm
839+
tiledRows = rows;
831840

832841
hipLaunchKernelGGL(( kExtractOutliers<FORMAT>), dim3(num_blocks), dim3(threads), 0, 0, A, idx, out, idx_size, rows, cols, tiledRows, tiledCols);
833842
CUDA_CHECK_RETURN(hipPeekAtLastError());

tests/test_autograd.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -288,7 +288,6 @@ def test_matmul(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose):
288288
)
289289
names = ["dim1_{}_dim2_{}_dim3_{}_dim4_{}_func_{}_dtype_{}_requires_grad_{}_transpose_{}_decomp_{}_has_fp16_weights_{}_has_bias_{}".format(*vals) for vals in str_values]
290290

291-
@pytest.mark.skipif(HIP_ENVIRONMENT, reason="this test is not supported on ROCm yet")
292291
@pytest.mark.parametrize(
293292
"dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose, decomp, has_fp16_weights, has_bias",
294293
values,

tests/test_functional.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -970,7 +970,6 @@ def test_bench_8bit_training(batch, seq, model, hidden):
970970
values = list(product(dim1, dim4, dims, formatB, has_bias))
971971
names = ["dim1_{}_dim4_{}_dims_{}_formatB_{}_has_bias_{}".format(*vals) for vals in values]
972972

973-
@pytest.mark.skipif(HIP_ENVIRONMENT, reason="this test is not supported on ROCm yet")
974973
@pytest.mark.parametrize("dim1, dim4, dims, formatB, has_bias", values, ids=names)
975974
def test_dequant_mm(dim1, dim4, dims, formatB, has_bias):
976975
inner = torch.randint(1, 128, size=(1,)).item()
@@ -1312,7 +1311,6 @@ def test_row_scale_bench(dim1, dim4, inner):
13121311
for vals in values
13131312
]
13141313

1315-
@pytest.mark.skipif(HIP_ENVIRONMENT, reason="this test is not supported on ROCm yet")
13161314
@pytest.mark.parametrize(
13171315
"dim1, dim2, dim3, dims, dtype, orderA, orderOut, transpose",
13181316
values,
@@ -2045,7 +2043,6 @@ def quant_zp(x):
20452043
print(err1, err2, err3, err4, err5, err6)
20462044

20472045

2048-
@pytest.mark.skipif(HIP_ENVIRONMENT, reason="this test is not supported on ROCm yet")
20492046
def test_extract_outliers():
20502047
for i in range(k):
20512048
shapeA = (4096, 4096 * 4)

0 commit comments

Comments
 (0)