@@ -2300,13 +2300,16 @@ template <int ITEMS_PER_THREAD, int SUBTILE_ROWS, int THREADS>__global__ void kd
23002300
23012301 const int n_out = numRows*numCols;
23022302
2303- int num_row_tiles = (numRows/SUBTILE_ROWS) + (numRows % SUBTILE_ROWS == 0 ? 0 : 1 );
2303+ // int num_row_tiles = (numRows/SUBTILE_ROWS) + (numRows % SUBTILE_ROWS == 0 ? 0 : 1);
23042304 // we have tiles of size numRows*32, thus col only increases every numRows
23052305 // num_row_tiles is the tiles after which the column increases by 32
23062306 // blockIdx.x is the index of the current tile
2307- int col = ((threadIdx.x % 32 ) + ((blockIdx.x /num_row_tiles)*32 ));
2307+ // int col = ((threadIdx.x % 32) + ((blockIdx.x/num_row_tiles)*32));
23082308 // base_row increases by SUBTILE_ROWS every block. It wraps back to zero once num_row_tiles is reached
2309- int base_row = (blockIdx.x *SUBTILE_ROWS) % (num_row_tiles*SUBTILE_ROWS);
2309+ // int base_row = (blockIdx.x*SUBTILE_ROWS) % (num_row_tiles*SUBTILE_ROWS);
2310+
2311+ int block_offset = blockIdx.x * THREADS * ITEMS_PER_THREAD;
2312+ int thread_offset = threadIdx.x * ITEMS_PER_THREAD;
23102313
23112314 // SUBTILE_ROWS is independent from ITEMS_PER_THREAD is independent from THREADS
23122315 // subtiles have 32*SUBTILE_ROWS elements <= THREADS*ITEMS_PER_THREAD
@@ -2321,33 +2324,59 @@ template <int ITEMS_PER_THREAD, int SUBTILE_ROWS, int THREADS>__global__ void kd
23212324
23222325 int local_values[ITEMS_PER_THREAD];
23232326 half local_output[ITEMS_PER_THREAD];
2324- float local_rowStats[ITEMS_PER_THREAD];
2325- __shared__ float smem_rowStats[SUBTILE_ROWS];
2327+ // float local_rowStats[ITEMS_PER_THREAD];
2328+ // __shared__ float smem_rowStats[SUBTILE_ROWS];
23262329
23272330 typedef hipcub::BlockLoad<int , THREADS, ITEMS_PER_THREAD, hipcub::BLOCK_LOAD_DIRECT> LoadInt32;
2328- typedef hipcub::BlockExchange<int , THREADS, ITEMS_PER_THREAD> ExchangeInt32;
2331+ // typedef hipcub::BlockExchange<int, THREADS, ITEMS_PER_THREAD> ExchangeInt32;
23292332 __shared__ typename LoadInt32::TempStorage loadint32;
2330- __shared__ typename ExchangeInt32::TempStorage exchangeint32;
2333+ // __shared__ typename ExchangeInt32::TempStorage exchangeint32;
23312334
23322335
23332336 // L1. Load sub-tile row/col statistics. Each thread only holds 1 col, load rows into shared memory.
2334- float colStat = col >= numCols ? 0 .0f : colStats[col];
2335- float local_biasValue = ((bias == NULL ) || (col >= numCols)) ? 0 .0f : __half2float (bias[col]);
2337+ // float colStat = col >= numCols ? 0.0f : colStats[col];
2338+ // float local_biasValue = ((bias == NULL) || (col >= numCols)) ? 0.0f : __half2float(bias[col]);
2339+ int row_idx, col_idx;
2340+ float colStat[ITEMS_PER_THREAD];
2341+ float local_biasValue[ITEMS_PER_THREAD];
2342+ float rowStat[ITEMS_PER_THREAD];
2343+ #pragma unroll ITEMS_PER_THREAD
2344+ for (int j = 0 ; j < ITEMS_PER_THREAD; j++)
2345+ {
2346+ row_idx = (block_offset + thread_offset + j) / numCols;
2347+ col_idx = (block_offset + thread_offset + j) % numCols;
2348+ colStat[j] = col_idx >= numCols ? 0 .0f : colStats[col_idx];
2349+ local_biasValue[j] = ((bias == NULL ) || (col_idx >= numCols)) ? 0 .0f : __half2float (bias[col_idx]);
2350+ rowStat[j] = row_idx >= numRows ? 0 .0f : rowStats[row_idx];
2351+ }
23362352 // no block loads for rows for now -- keep it simple
2337- for (int j = threadIdx.x ; j < SUBTILE_ROWS; j+=blockDim.x )
2353+ /* for(int j = threadIdx.x; j < SUBTILE_ROWS; j+=blockDim.x)
23382354 {
23392355 // todo: is this global mem access slow due to overlaps or does the L1 cache work well here?
23402356 int row = (base_row+j) % numRows; // wrap around
23412357 // each warp accesses the same element, for four consequitive elements
23422358 // todo: update description about striped shared memory, it is not needed
23432359 // rowidx: [0, 1, 2, 3...] and each warp reads ITEMS_PER_THREAD consequitive elements
23442360 smem_rowStats[j] = rowStats[row];
2345- }
2361+ }*/
23462362 __syncthreads ();
23472363
2364+ int valid_items = block_offset + THREADS * ITEMS_PER_THREAD < n_out ? THREADS * ITEMS_PER_THREAD : n_out - block_offset;
2365+ LoadInt32 (loadint32).Load (&(A[block_offset]), local_values, valid_items, 0 );
23482366
2367+ #pragma unroll ITEMS_PER_THREAD
2368+ for (int j = 0 ; j < ITEMS_PER_THREAD; j++)
2369+ local_output[j] = __float2half ((local_values[j]*MM_DEQUANT_CONST*rowStat[j]*colStat[j]) + local_biasValue[j]);
2370+
23492371 // each block processes SUBTILE_ROWS*32 elements
2350- const int items_per_load = THREADS*ITEMS_PER_THREAD;
2372+ #pragma unroll ITEMS_PER_THREAD
2373+ for (int j = 0 ; j < ITEMS_PER_THREAD; j++)
2374+ {
2375+ int outIdx = block_offset + thread_offset + j;
2376+ if (outIdx< n_out)
2377+ out[outIdx] = local_output[j];
2378+ }
2379+ /* const int items_per_load = THREADS*ITEMS_PER_THREAD;
23512380 const int rows_per_load = items_per_load/32;
23522381
23532382 int subtile_base_row = (threadIdx.x / 32)*ITEMS_PER_THREAD; // row within the tile
@@ -2368,7 +2397,7 @@ template <int ITEMS_PER_THREAD, int SUBTILE_ROWS, int THREADS>__global__ void kd
23682397 #pragma unroll ITEMS_PER_THREAD
23692398 for(int j = 0; j < ITEMS_PER_THREAD; j++)
23702399 local_rowStats[j] = smem_rowStats[subtile_base_row+row_offset+j];
2371-
2400+
23722401 #pragma unroll ITEMS_PER_THREAD
23732402 for(int j = 0; j < ITEMS_PER_THREAD; j++)
23742403 local_output[j] = __float2half((local_values[j]*MM_DEQUANT_CONST*local_rowStats[j]*colStat) + local_biasValue);
@@ -2388,7 +2417,7 @@ template <int ITEMS_PER_THREAD, int SUBTILE_ROWS, int THREADS>__global__ void kd
23882417 }
23892418
23902419 row_offset += rows_per_load;
2391- }
2420+ }*/
23922421}
23932422
23942423
@@ -2974,7 +3003,7 @@ template <int FORMAT> __global__ void kExtractOutliers(char *A, int *idx, char *
29743003{
29753004 int local_colidx = idx[blockIdx.x ];
29763005
2977- if (FORMAT==COL_TURING)
3006+ /* if(FORMAT==COL_TURING)
29783007 {
29793008 // TURING FORMAT:
29803009 // 8*32 tiles with 4*4 subtiles
@@ -3030,6 +3059,17 @@ template <int FORMAT> __global__ void kExtractOutliers(char *A, int *idx, char *
30303059 int out_idx = (row*idx_size) + blockIdx.x;
30313060 out[out_idx] = val;
30323061 }
3062+ }*/
3063+
3064+ // Only col format is used on ROCm
3065+ for (int row = threadIdx.x ; row < rowsA; row+= blockDim.x )
3066+ {
3067+ // col-major offset
3068+ int offset = local_colidx * rowsA + row;
3069+
3070+ char val = A[offset];
3071+ int out_idx = (row*idx_size) + blockIdx.x ;
3072+ out[out_idx] = val;
30333073 }
30343074}
30353075
0 commit comments