diff --git a/cpp/kernels/xqa/barriers.cuh b/cpp/kernels/xqa/barriers.cuh
index 3c0318be465..d21ba7e8674 100644
--- a/cpp/kernels/xqa/barriers.cuh
+++ b/cpp/kernels/xqa/barriers.cuh
@@ -68,7 +68,7 @@ public:
    template
    __device__ inline mha::conditional_t arrive(uint32_t update = 1)
    {
-        ArrivalToken token;
+        ArrivalToken token{};
 #if __CUDA_ARCH__ >= 900
        if constexpr (scope == Scope::CTA)
        {
@@ -128,9 +128,9 @@ public:
    __device__ inline bool isLocal() const
    {
-        uint32_t addrCtaRank;
+        uint32_t addrCtaRank{};
        asm("getctarank.u64 %0, %1;\n" : "=r"(addrCtaRank) : "l"(addr()));
-        uint32_t ctaRank;
+        uint32_t ctaRank{};
        asm("mov.u32 %0, %%cluster_ctarank;\n" : "=r"(ctaRank));
        return addrCtaRank == ctaRank;
    }
@@ -154,7 +154,7 @@ public:
 #if __CUDA_ARCH__ >= 900
        if constexpr (scope == Scope::CTA)
        {
-            ArrivalToken token;
+            ArrivalToken token{};
            asm volatile("mbarrier.arrive.expect_tx.relaxed.cta.b64 %0, [%1], %2;\n"
                : "=l"(token)
                : "l"(addr()), "r"(txCount)
@@ -181,7 +181,7 @@ public:
    {
        if constexpr (scope == Scope::CTA)
        {
-            ArrivalToken token;
+            ArrivalToken token{};
            switch (order)
            {
            case ArriveOrder::RELEASE:
@@ -239,7 +239,7 @@ public:
    template
    __device__ inline bool test_wait(ArrivalToken&& token)
    {
-        uint32_t ready;
+        uint32_t ready{};
        if constexpr (scope == Scope::CGA)
        {
            asm volatile(
@@ -271,7 +271,7 @@ public:
    template
    __device__ inline bool test_wait_parity(bool parity)
    {
-        uint32_t ready;
+        uint32_t ready{};
        if constexpr (scope == Scope::CGA)
        {
            asm volatile(
@@ -303,7 +303,7 @@ public:
    template
    __device__ inline bool try_wait(ArrivalToken&& token)
    {
-        uint32_t ready;
+        uint32_t ready{};
        if constexpr (scope == Scope::CGA)
        {
            asm volatile(
@@ -334,7 +334,7 @@ public:
    template
    __device__ inline bool try_wait_parity(bool parity)
    {
-        uint32_t ready;
+        uint32_t ready{};
        if constexpr (scope == Scope::CGA)
        {
            asm volatile(
diff --git a/cpp/kernels/xqa/mha_components.cuh b/cpp/kernels/xqa/mha_components.cuh
index a2b8619a0cb..4f4006ff77f 100644
--- a/cpp/kernels/xqa/mha_components.cuh
+++ b/cpp/kernels/xqa/mha_components.cuh
@@ -59,7 +59,7 @@ template
 __device__ inline QuadRegRowMaxT replicateForQuad(Warp const& warp, Vec const& src)
 {
    assertWarpConverged();
-    QuadRegRowMaxT dst;
+    QuadRegRowMaxT dst{};
 #pragma unroll
    for (uint32_t i = 0; i < src.size; i++)
    {
@@ -82,7 +82,7 @@ __device__ inline ThrdRegRowMaxT dedupFromQuad(Warp
        assert(src[i] == __shfl_sync(~0U, src[i], laneId() / 4 * 4));
    }
 #endif
-    ThrdRegRowMaxT dst;
+    ThrdRegRowMaxT dst{};
    uint32_t const lane = laneId();
    uint32_t const idxMat = lane / 8;
    uint32_t const idxRow = lane % 8;
diff --git a/cpp/kernels/xqa/mha_sm90.cu b/cpp/kernels/xqa/mha_sm90.cu
index ee78d67b651..1d40662da71 100644
--- a/cpp/kernels/xqa/mha_sm90.cu
+++ b/cpp/kernels/xqa/mha_sm90.cu
@@ -1616,7 +1616,7 @@ CUBIN_EXPORT __global__
    {
        if (warpElectSync())
        {
-            tma::load1DAsync(&smem.tokens[idxBuf], &scratchMem.tokens()[idxChunk],
+            tma::loadLinearAsync(&smem.tokens[idxBuf], &scratchMem.tokens()[idxChunk],
                sizeof(smem.tokens[idxBuf]), bar.produced);
            arrive_tx(bar.produced, sizeof(smem.tokens[idxBuf]), 1);
        }
diff --git a/cpp/kernels/xqa/mla_sm120.cu b/cpp/kernels/xqa/mla_sm120.cu
index 4ae50d5b5ba..74877512a7d 100644
--- a/cpp/kernels/xqa/mla_sm120.cu
+++ b/cpp/kernels/xqa/mla_sm120.cu
@@ -79,9 +79,9 @@ struct KVTilePartLoader
    // if greater than 1, then we need unrolling for the loading loop. Seems 1 is fine for latency.
    static inline constexpr uint32_t nbPageBuffers = 1;
 #if USE_PAGED_KV_CACHE
-    uint32_t const nbPages; // for bound check
+    uint32_t const nbPages;    // for bound check
    Vec pageBuffers[nbPageBuffers];
-    uint32_t idxTileRef;    // idxTile used to load the pages
+    uint32_t idxTileRef = ~0U; // idxTile used to load the pages
 #endif

    uint32_t const baseOffset;
@@ -117,6 +117,11 @@ __device__ inline KVTilePartLoader::KVTilePartLoader(
    , baseOffset{(idxReq * beamWidth) * 2}
 #endif
 {
+#pragma unroll
+    for (auto& pageBuffer : pageBuffers)
+    {
+        pageBuffer.fill(kBAD_PAGE_INDEX);
+    }
 }

 // tensorMap is for one whole page ([nbKHeads*tokensPerPage][headElems]) or whole cache
@@ -240,6 +245,7 @@ struct CgaXBuffer
 {
    XBuffer x;
    Vec rowSum;
+    Vec rowMaxLog2e;
 };

 struct PingPongMutex
@@ -247,6 +253,7 @@
 {
    using ShmStorage = CtaBarrier[2];
    ShmStorage& barriers;
    uint32_t const idxGrp;
+    bool skipWait = false;

    static __device__ inline void initStorage(ShmStorage& barriers, uint32_t thrdsPerGrp)
    {
@@ -261,14 +268,23 @@
    {
    }

+    __device__ inline void test_lock(uint32_t iter)
+    {
+        skipWait = barriers[idxGrp].test_wait_parity(toParity<1>(iter));
+    }
+
    __device__ inline void lock(uint32_t iter)
    {
-        barriers[idxGrp].wait_parity(toParity<1>(iter));
+        if (!skipWait)
+        {
+            barriers[idxGrp].wait_parity(toParity<1>(iter));
+        }
    }

    __device__ inline void unlock()
    {
        barriers[idxGrp ^ 1U].arrive();
+        skipWait = false;
    }
 };

@@ -299,10 +315,9 @@ constexpr bool useRegQ = USE_REG_Q;

 struct SharedMemA
 {
-    static inline constexpr uint32_t nbKBufs = 4;
-    static inline constexpr uint32_t nbXBufs = 1;
+    static inline constexpr uint32_t nbKBufs = 12;

-    static inline constexpr uint32_t regQParts = (useRegQ ? 1 : 0);
+    static inline constexpr uint32_t regQParts = (useRegQ ? 4 : 0);
    static inline constexpr uint32_t shmQParts = nbQParts - regQParts;

    using ShmQPart = Array2D;
@@ -310,10 +325,9 @@ struct SharedMemA
    Vec q;
    ShmKPart k[nbKBufs];

-    XBuffer x[nbXBufs];
-    Vec rowSum[nbXBufs];
-    Vec drain; // data does not matter. Used to help avoid fence.
+    // single buffer reused by two groups. sendX() warp will arbitrate the order of x buffer access via two xBars.
+    CgaXBuffer x;

    // scaled by log2e. Write by last CGA iteration (from the other producer CTA) and read by current producer CTA.
    Vec rowMaxLog2e;
@@ -324,21 +338,19 @@ struct SharedMemA
    PingPongMutex::ShmStorage tensorCoreMutex;

    CtaBarrierPair kBars[nbKBufs];
-    CtaBarrierPair xBars[nbXBufs];
+    static inline constexpr uint32_t nbXBars = nbMathGrpsA;
+    CtaBarrierPair xBars[nbXBars];
 #if USE_REG_Q
    CtaBarrierPair regQBar;
 #endif
    CtaBarrier shmQBar;
-    CgaBarrier cgaXBufConsumed;                    // for X
-
-    PingPongMutex::ShmStorage rowMaxTransferMutex; // protect the order of rowMax transfer to consumers
-    CgaBarrier consumerRowMaxConsumedBar;          // arrive by consumer CTAs.
+    CgaBarrier cgaXBufConsumed; // for X

    CtaBarrierPair multiBlockBars[nbMultiBlockBufs];

    __device__ inline void invalidateBarriers(uint32_t thrdIdx)
    {
-        constexpr uint32_t nbBars = (useRegQ ? 15 : 13) + 2 * (nbKBufs + nbXBufs);
+        constexpr uint32_t nbBars = (useRegQ ? 12 : 10) + 2 * (nbKBufs + nbXBars);
 #ifndef __CUDACC_RTC__
        constexpr uint32_t nbBarsRef
            = exactDiv(offsetof(SharedMemA, qkScaleLog2e) - offsetof(SharedMemA, rowMaxLog2eBar), 8);
@@ -375,16 +387,16 @@ struct SharedMemB
    // in the future.
    struct XVBuffer
    {
-        XBuffer x;
        VBuffer v;
-        XBuffer pad; // for output swizzling
+        CgaXBuffer x;
+        uint8_t pad[headGrpSize * 128 * 2 - sizeof(VBuffer) - sizeof(CgaXBuffer)]; // for output swizzling
    };

    XVBuffer xv[nbXVBufs];

    __device__ inline XBuffer& x(uint32_t idx)
    {
-        return xv[idx].x;
+        return xv[idx].x.x;
    }

    __device__ inline VBuffer& v(uint32_t idx)
@@ -392,15 +404,20 @@ struct SharedMemB
        return xv[idx].v;
    }

-    Vec xRowSum[nbXBufs];
+    __device__ inline Vec& xRowSum(uint32_t idx)
+    {
+        return xv[idx].x.rowSum;
+    }
+
+    __device__ inline Vec& xRowMaxLog2e(uint32_t idx)
+    {
+        return xv[idx].x.rowMaxLog2e;
+    }

    static inline constexpr uint32_t nbAccRowMaxSumCopies = 2;
    Vec accRowMaxLog2e[nbAccRowMaxSumCopies];
    Vec accRowSum[nbAccRowMaxSumCopies];
-    Vec xRowMaxLog2e[nbProducerCtasPerCga];
-    CgaBarrier xRowMaxLog2eProducedBar[nbProducerCtasPerCga];
-
    CtaBarrierPair xBars[nbXBufs];
    CtaBarrierPair vBars[nbVBufs];
@@ -411,23 +428,21 @@ struct SharedMemB

    __device__ inline void invalidateBarriers(uint32_t thrdIdx)
    {
-        constexpr uint32_t nbBars = 17;
+        constexpr uint32_t nbBars = 15;
 #ifndef __CUDACC_RTC__
-        constexpr uint32_t nbBarsRef
-            = exactDiv(offsetof(SharedMemB, isLastSubSeq) - offsetof(SharedMemB, xRowMaxLog2eProducedBar), 8);
+        constexpr uint32_t nbBarsRef = exactDiv(offsetof(SharedMemB, isLastSubSeq) - offsetof(SharedMemB, xBars), 8);
        static_assert(nbBars == nbBarsRef);
 #endif
        if (thrdIdx < nbBars)
        {
-            reinterpret_cast(&xRowMaxLog2eProducedBar[0])[thrdIdx].~CtaBarrier();
+            reinterpret_cast(&xBars[0])[thrdIdx].~CtaBarrier();
        }
    }

    __device__ inline Vec& getMultiBlockBufs()
    {
 #ifndef __CUDACC_RTC__
-        static_assert(
-            sizeof(Vec) < offsetof(SharedMemB, xRowMaxLog2eProducedBar));
+        static_assert(sizeof(Vec) < offsetof(SharedMemB, xBars));
 #endif
        return *reinterpret_cast*>(this);
    }
@@ -518,11 +533,10 @@ struct Producer
            b.initialize(1, thrdsPerGrp);
            b.consumed.arrive(thrdsPerGrp);
        }
-        if (warpRank < SharedMemA::nbXBufs)
+        if (warpRank < SharedMemA::nbXBars)
        {
            auto& b = smem.xBars[warpRank];
            b.initialize(thrdsPerGrp, 1);
-            b.consumed.arrive(1);
        }
 #if USE_REG_Q
        if (warpRank == 0)
@@ -546,10 +560,6 @@ struct Producer
            init(&smem.cgaXBufConsumed, 1 * nbVSplit);
            smem.cgaXBufConsumed.arrive(1 * nbVSplit);
            PingPongMutex::initStorage(smem.tensorCoreMutex, thrdsPerGrp);
-            PingPongMutex::initStorage(smem.rowMaxTransferMutex, thrdsPerGrp);
-            init(&smem.consumerRowMaxConsumedBar, warp_size * nbComputeWarpsB * nbVSplit);
-            smem.consumerRowMaxConsumedBar.arrive(
-                warp_size * nbComputeWarpsB * nbVSplit);
        }
        if (nbSubSeq > 1 && warpRank < nbMultiBlockBufs)
@@ -674,7 +684,6 @@ private:
        uint32_t const grpIdx = warpIdx.y;
        uint32_t const tileBaseRow = warpTile.y * warpIdx.x;
        PingPongMutex tensorCoreMutex{smem.tensorCoreMutex, grpIdx};
-        PingPongMutex rowMaxTransferMutex{smem.rowMaxTransferMutex, grpIdx};

        constexpr uint32_t partNbInstK = exactDiv(partElemsK, qmmaShape.k);
        using AtomA = Vec; // for 16x32 data, working as mat A of QMMA.16832
@@ -726,7 +735,6 @@ private:
        }
        smem.regQBar.consumed.arrive();
 #endif
-        smem.shmQBar.wait_parity(false);
        // main loop
 #pragma unroll 1
        for (uint32_t grpIter = 0; true; grpIter++)
@@ -824,6 +832,10 @@ private:
                }
            }
 #endif
+            if (ctaIter == 0)
+            {
+                smem.shmQBar.wait_parity(false);
+            }
 #pragma unroll
            for (uint32_t idxPart = SharedMemA::regQParts; idxPart < nbQParts; idxPart++)
            {
@@ -891,7 +903,11 @@ private:
            {
                applyMask(this_warp(), acc, 0, validTokens);
            }
-            WarpAcc const xF32 = scaleAndSoftmax(acc, grpIdx, grpIter, tileBaseRow, rowMaxTransferMutex);
+            ThrdRegRowMax rowMaxLog2e;
+            WarpAcc const xF32 = scaleAndSoftmax(rowMaxLog2e, acc, grpIdx, grpIter, tileBaseRow);
+
+            auto& xBar = smem.xBars[grpIdx];
+            bool const skipXBarWait = xBar.consumed.test_wait_parity(toParity<1>(grpIter));
            // convert to fp8
            WarpAcc const xF32Quant = xF32 * rcpXScale;
            // 0, 1, 8, 9, 2, 3, 10, 11, 4, 5, 12, 13, 6, 7, 14, 15
@@ -917,17 +933,19 @@ private:
                : computeRowSumF32(this_warp(), xF32);

            // store xF8 and rowSum into L2 scratch buffer
-            uint32_t const idxXBuf = ctaIter % SharedMemA::nbXBufs;
-            auto& xBar = smem.xBars[idxXBuf];
-            xBar.consumed.wait_parity(toParity(ctaIter));
-            storeRowMax(smem.rowSum[idxXBuf], rowSum, tileBaseRow, lane);
-            storeOrderedXToShm(smem.x[idxXBuf], xF8, tileBaseRow, lane);
+            if (!skipXBarWait)
+            {
+                xBar.consumed.wait_parity(toParity<1>(grpIter));
+            }
+            storeRowMax(smem.x.rowMaxLog2e, rowMaxLog2e, tileBaseRow, lane);
+            storeRowMax(smem.x.rowSum, rowSum, tileBaseRow, lane);
+            storeOrderedXToShm(smem.x.x, xF8, tileBaseRow, lane);
            xBar.produced.arrive();
        }
    }

-    __device__ inline WarpAcc scaleAndSoftmax(WarpAcc const& acc, uint32_t grpIdx, uint32_t grpIter,
-        uint32_t tileBaseRow, PingPongMutex& rowMaxTransferMutex);
+    __device__ inline WarpAcc scaleAndSoftmax(
+        ThrdRegRowMax& rowMaxLog2e, WarpAcc const& acc, uint32_t grpIdx, uint32_t grpIter, uint32_t tileBaseRow);

    __device__ inline void storeOrderedXToShm(XBuffer& dst,
        Array2D, WarpAcc::rows, exactDiv(WarpAcc::cols, 2)> const& src,
@@ -973,6 +991,11 @@ __device__ inline void Producer::loadK()

 __device__ inline void Producer::sendX()
 {
+    // let group 0 to produce first.
+    if (warpElectSync())
+    {
+        smem.xBars[0].consumed.arrive();
+    }
    for (uint32_t iter = 0; true; iter++)
    {
        uint32_t const idxTile = idxTileBeg() + iterStride() * iter;
@@ -980,18 +1003,20 @@ __device__ inline void Producer::sendX()
        {
            break;
        }
-        uint32_t const idxBuf = iter % SharedMemA::nbXBufs;
-        auto& xBar = smem.xBars[idxBuf];
-        xBar.produced.wait_parity(toParity(iter));
+        uint32_t const idxBar = iter % SharedMemA::nbXBars;
+        auto& xBar = smem.xBars[idxBar];
+        xBar.produced.wait_parity(toParity(iter));
        smem.cgaXBufConsumed.wait_parity(toParity<1>(iter));
        if (warpElectSync())
        {
            auto& dst = args.cgaXBuf[nbSubSeq * idxInputTokenGlobal + idxSubSeq][ctaRank];
-            tma::store1DAsync(&dst.x, &smem.x[idxBuf], sizeof(XBuffer));
-            tma::store1DAsync(&dst.rowSum, &smem.rowSum[idxBuf], sizeof(smem.rowSum[0]));
+            tma::store1DAsync(&dst, &smem.x, sizeof(CgaXBuffer));
            tma::commitGroup();
            tma::waitGroup<0>();
-            xBar.consumed.arrive();
+            // it's turn for the other math group to produce.
+            uint32_t const idxBarNext = (iter + 1) % SharedMemA::nbXBars;
+            auto& xBarNext = smem.xBars[idxBarNext];
+            xBarNext.consumed.arrive();
            asm volatile("fence.release.cluster;\n");
 #pragma unroll
            for (uint32_t i = 0; i < nbVSplit; i++)
@@ -1004,7 +1029,7 @@ __device__ inline void Producer::sendX()
 }

 __device__ inline Producer::WarpAcc Producer::scaleAndSoftmax(
-    WarpAcc const& acc, uint32_t grpIdx, uint32_t grpIter, uint32_t tileBaseRow, PingPongMutex& rowMaxTransferMutex)
+    ThrdRegRowMax& rowMaxLog2e, WarpAcc const& acc, uint32_t grpIdx, uint32_t grpIter, uint32_t tileBaseRow)
 {
    uint32_t const ctaIter = grpIdx + grpIter * nbMathGrps;
    uint32_t const cgaIter = ctaRank + ctaIter * nbProducerCtasPerCga;
@@ -1013,9 +1038,9 @@ __device__ inline Producer::WarpAcc Producer::scaleAndSoftmax(
    uint32_t const idxProducer = ctaRank;
    assert(ctaRank < nbProducerCtasPerCga);

-    auto const accLog2e = acc * smem.qkScaleLog2e;
+    float const qkScaleLog2e = smem.qkScaleLog2e;
    bool const skipWaitLastShmRowMax = smem.rowMaxLog2eBar[grpIdx].test_wait_parity(toParity<1>(grpIter));
-    QuadRegRowMax const tileRowMaxLog2e = computeRowMax(accLog2e);
+    QuadRegRowMax const tileRowMaxLog2e = computeRowMax(acc) * qkScaleLog2e;
    // get max with previous CTA's rowMax
    if (!skipWaitLastShmRowMax)
    {
@@ -1029,22 +1054,10 @@ __device__ inline Producer::WarpAcc Producer::scaleAndSoftmax(
    SharedMemA& smemNext = mapa(smem, ctaRank ^ 1U);
    CgaBarrier& nextRowMaxLog2eBar
        = smemNext.rowMaxLog2eBar[(cgaIter + 1) % (nbMathGrps * nbProducerCtasPerCga) / nbMathGrps];
-    ThrdRegRowMax const rowMaxLog2e = dedupFromQuad(warp, quadRowMaxLog2e);
+    rowMaxLog2e = dedupFromQuad(warp, quadRowMaxLog2e);
    storeRowMaxAsync(nextRowMaxLog2eBar, smemNext.rowMaxLog2e, rowMaxLog2e, tileBaseRow, lane);
    nextRowMaxLog2eBar.arrive_tx_relaxed(sizeof(rowMaxLog2e)); // notify that the next CTA can read rowMax now.
-    // transfer rowMax to consumers.
-    rowMaxTransferMutex.lock(grpIter); // @fixme: use test_wait_parity() early to avoid latency.
-    smem.consumerRowMaxConsumedBar.wait_parity(checkedVal(grpIdx, toParity<1>(ctaIter)));
-    for (uint32_t idxConsumer = 0; idxConsumer < nbVSplit; idxConsumer++)
-    {
-        auto& smemB = getConsumerShm(idxConsumer);
-        storeRowMaxAsync(smemB.xRowMaxLog2eProducedBar[idxProducer], smemB.xRowMaxLog2e[idxProducer],
-            rowMaxLog2e, tileBaseRow, lane);
-        smemB.xRowMaxLog2eProducedBar[idxProducer].arrive_tx_relaxed(sizeof(rowMaxLog2e));
-    }
-    rowMaxTransferMutex.unlock();
-
    WarpAcc x;
    // apply softmax
 #pragma unroll
@@ -1060,9 +1073,9 @@ __device__ inline Producer::WarpAcc Producer::scaleAndSoftmax(
 #pragma unroll
                for (uint32_t j = 0; j < InstAcc::cols; j++)
                {
-                    float elem = accLog2e(m, n)(i, j);
-                    assert(maxVal >= elem);
-                    x(m, n)(i, j) = exp2f(elem - maxVal);
+                    float elem = acc(m, n)(i, j);
+                    assert(maxVal >= elem * qkScaleLog2e);
+                    x(m, n)(i, j) = exp2f(elem * qkScaleLog2e - maxVal);
                }
            }
        }
@@ -1194,7 +1207,6 @@ struct Consumer
    {
        if (warpRank < nbProducerCtasPerCga)
        {
-            init(&smem.xRowMaxLog2eProducedBar[warpRank], Producer::thrdsPerGrp);
            init(&smem.cgaXBufProduced[warpRank], 1);
        }
        if (warpRank < SharedMemB::nbXBufs)
@@ -1297,8 +1309,7 @@ __device__ inline void Consumer::compute()
    uint32_t const cB = 0;

    WarpAcc acc{};
-    uint32_t idxXVBufLast;
-    bool skipWait_xRowMaxLog2eProducedBar = false;
+    uint32_t idxXVBufLast{};
    for (uint32_t iter = 0; true; iter++)
    {
        uint32_t const idxTile = iterToTile(iter);
@@ -1310,20 +1321,19 @@ __device__ inline void Consumer::compute()
        ThrdRegRowMax accRowMaxLog2e = loadShmRowMax(smem.accRowMaxLog2e[tileIdx.x], tileBase.y, lane);
        ThrdRegRowMax accRowSum = loadShmRowMax(smem.accRowSum[tileIdx.x], tileBase.y, lane);

-        uint32_t const idxProducer = iter % nbProducerCtasPerCga;
-        if (!skipWait_xRowMaxLog2eProducedBar)
-        {
-            smem.xRowMaxLog2eProducedBar[idxProducer].wait_parity(toParity(iter));
-        }
-        ThrdRegRowMax const xRowMaxLog2e = loadShmRowMax(smem.xRowMaxLog2e[idxProducer], tileBase.y, lane);
-        auto& prodSmem = getProducerShm(idxProducer);
-        uint32_t const drainData = hashRegData(xRowMaxLog2e);
-        tma::storeAsync(&prodSmem.drain[lane], drainData, prodSmem.consumerRowMaxConsumedBar);
-        prodSmem.consumerRowMaxConsumedBar.template arrive_tx(sizeof(drainData));
+        uint32_t const idxXBuf = iter % SharedMemB::nbXBufs;
+        uint32_t const idxVBuf = iter % SharedMemB::nbVBufs;
+        auto& xBar = smem.xBars[idxXBuf];
+        auto& vBar = smem.vBars[idxVBuf];
+        // @fixme: merge these two barriers and use test_wait_parity() early to avoid latency.
+        bool const skipVBarWait = vBar.produced.test_wait_parity(toParity(iter));
+        xBar.produced.wait_parity(toParity(iter));
+
+        ThrdRegRowMax const xRowMaxLog2e = loadShmRowMax(smem.xRowMaxLog2e(idxXBuf), tileBase.y, lane);
        assert(all(accRowMaxLog2e <= xRowMaxLog2e));
        auto const needRescaleVec = (xRowMaxLog2e > accRowMaxLog2e);
-        UniformNeedRescaleMask rescaleMask;
+        UniformNeedRescaleMask rescaleMask{};
 #pragma unroll
        for (uint32_t i = 0; i < rescaleMask.size; i++)
        {
@@ -1360,17 +1370,13 @@ __device__ inline void Consumer::compute()
        }
        accRowMaxLog2e = xRowMaxLog2e;
        storeRowMax(smem.accRowMaxLog2e[tileIdx.x], accRowMaxLog2e, tileBase.y, lane);
-
-        uint32_t const idxXBuf = iter % SharedMemB::nbXBufs;
-        uint32_t const idxVBuf = iter % SharedMemB::nbVBufs;
-        auto& xBar = smem.xBars[idxXBuf];
-        auto& vBar = smem.vBars[idxVBuf];
-        // @fixme: merge these two barriers and use test_wait_parity() early to avoid latency.
-        vBar.produced.wait_parity(toParity(iter));
-        xBar.produced.wait_parity(toParity(iter));
+        if (!skipVBarWait)
+        {
+            vBar.produced.wait_parity(toParity(iter));
+        }
        auto const& xBuf = smem.x(idxXBuf);
        auto const& vBuf = smem.v(idxVBuf)[tileIdx.x];
-        auto const xRowSum = loadShmRowMax(smem.xRowSum[idxXBuf], tileBase.y, lane);
+        auto const xRowSum = loadShmRowMax(smem.xRowSum(idxXBuf), tileBase.y, lane);
        accRowSum = accRowSum + xRowSum;
        storeRowMax(smem.accRowSum[tileIdx.x], accRowSum, tileBase.y, lane);
@@ -1386,13 +1392,6 @@ __device__ inline void Consumer::compute()
                auto const data = ldmatrix_16x16_trans<2>(
                    &vBuf.template at(qmmaShape.k * idxInstK + rB, idxAtomBx2 + cB));
                AtomB const v[2] = {data[0], data[2], data[1], data[3]};
-                if (idxInstK == tileNbInstK - 1 && idxAtomBx2 == warpTileNbAtomBx2 - 2)
-                {
-                    uint32_t const iterNext = iter + 1;
-                    skipWait_xRowMaxLog2eProducedBar
-                        = smem.xRowMaxLog2eProducedBar[iterNext % nbProducerCtasPerCga].test_wait_parity(
-                            toParity(iterNext));
-                }
 #pragma unroll
                for (uint32_t i = 0; i < WarpAcc::rows; i++)
                {
@@ -1479,11 +1478,9 @@ __device__ inline void Consumer::loadX()
        if (warpElectSync())
        {
            auto& src = args.cgaXBuf[nbSubSeq * idxInputTokenGlobal + idxSubSeq][idxScratchXBuf];
-            auto& dstX = smem.x(idxXBuf);
-            auto& dstRowSum = smem.xRowSum[idxXBuf];
-            tma::load1DAsync(&dstX, &src.x, sizeof(smem.x(0)), xBar.produced);
-            tma::load1DAsync(&dstRowSum, &src.rowSum, sizeof(smem.xRowSum[0]), xBar.produced);
-            xBar.produced.arrive_tx(sizeof(smem.x(0)) + sizeof(smem.xRowSum[0]));
+            auto& dst = smem.xv[idxXBuf].x;
+            tma::loadLinearAsync(&dst, &src.x, sizeof(CgaXBuffer), xBar.produced);
+            xBar.produced.arrive_tx(sizeof(CgaXBuffer));
            xBar.produced.wait_parity(toParity(iter));
            uint32_t const idxProducer = idxScratchXBuf;
            // @fixme: check if this works. If it doesn't, randomly pick some data from dstX and dstRowSum and use
@@ -1663,7 +1660,7 @@ __device__ inline void mergePartialOutputs(uint32_t& semaphore, Vec
 (idxSubSeq));
        if (warpElectSync())
        {
-            tma::load1DAsync(&shmBufs[idxBuf], &reqPartialResults[idxSubSeq].chunks[ctaRank],
+            tma::loadLinearAsync(&shmBufs[idxBuf], &reqPartialResults[idxSubSeq].chunks[ctaRank],
                sizeof(PartialResult::Chunk), bar.produced);
            bar.produced.arrive_tx(sizeof(PartialResult::Chunk));
        }
diff --git a/cpp/kernels/xqa/mla_sm120.cuh b/cpp/kernels/xqa/mla_sm120.cuh
index d5c790c8ca5..5087721dea3 100644
--- a/cpp/kernels/xqa/mla_sm120.cuh
+++ b/cpp/kernels/xqa/mla_sm120.cuh
@@ -8,7 +8,7 @@ template
 __device__ inline ThrdRegRowMaxT loadShmRowMax(
    Vec const& shm, uint32_t tileBaseRow, uint32_t lane = laneId())
 {
-    ThrdRegRowMaxT result;
+    ThrdRegRowMaxT result{};
 #pragma unroll
    for (uint32_t i = 0; i < result.size; i++)
    {
@@ -42,7 +42,7 @@ __device__ inline void storeRowMaxAsync(CgaBarrier& bar, Vec
 template
 __device__ inline QuadRegRowMaxT computeRowMax(WarpAccT const& acc)
 {
-    QuadRegRowMaxT rowMaxLog2e;
+    QuadRegRowMaxT rowMaxLog2e{};
    // compute per-thread row max
 #pragma unroll
    for (uint32_t n = 0; n < acc.cols; n++)
diff --git a/cpp/kernels/xqa/test/test.cpp b/cpp/kernels/xqa/test/test.cpp
index 8737fae6f73..57bf86b20ed 100644
--- a/cpp/kernels/xqa/test/test.cpp
+++ b/cpp/kernels/xqa/test/test.cpp
@@ -678,7 +678,6 @@ void runTest(uint32_t batchSize, uint32_t seqLen, bool testPerf, bool refCheck,
    }
    if (isTracing)
    {
-        runKernel();
        printf("Tracing is enabled\n");
    }
    checkCuda(cudaEventRecord(tic, stream));
diff --git a/cpp/kernels/xqa/tma.h b/cpp/kernels/xqa/tma.h
index a5614cf2fdf..38d7e439286 100644
--- a/cpp/kernels/xqa/tma.h
+++ b/cpp/kernels/xqa/tma.h
@@ -45,12 +45,19 @@ typedef struct CUtensorMap_st
 namespace tma
 {
-__device__ inline void load1DAsync(void* dst, void const* src, uint32_t nbBytes, CtaBarrier& bar)
+__device__ inline void loadLinearAsync(void* dst, void const* src, uint32_t nbBytes, CtaBarrier& bar)
 {
    asm volatile("cp.async.bulk.shared::cta.global.mbarrier::complete_tx::bytes [%0], [%1], %2, [%3];\n"
        :
        : "l"(__cvta_generic_to_shared(dst)), "l"(reinterpret_cast(src)), "r"(nbBytes),
-        "l"(__cvta_generic_to_shared(&bar)));
+        "l"(__cvta_generic_to_shared(&bar))
+        : "memory");
+}
+
+__device__ inline void prefetchLinear(void const* src, uint32_t nbBytes)
+{
+    asm volatile("cp.async.bulk.prefetch.L2.global [%0], %1;\n" ::"l"(reinterpret_cast(src)), "r"(nbBytes)
+                 : "memory");
 }

 // dsr and &bar must be remote address generated by mapa and src must be local address
@@ -59,7 +66,8 @@ __device__ inline void sm2smCopyAsync(void* dst, void const* src, uint32_t nbByt
    asm volatile("cp.async.bulk.shared::cluster.shared::cta.mbarrier::complete_tx::bytes [%0], [%1], %2, [%3];\n"
        :
        : "l"(__cvta_generic_to_shared(dst)), "l"(reinterpret_cast(src)), "r"(nbBytes),
-        "l"(__cvta_generic_to_shared(&bar)));
+        "l"(__cvta_generic_to_shared(&bar))
+        : "memory");
 }

 template
@@ -247,14 +255,14 @@ __device__ inline void setTensorMapGlbAddr(CUtensorMap& tensorMap, void* ptr)

 __device__ inline void commitGroup()
 {
-    asm volatile("cp.async.bulk.commit_group;\n");
+    asm volatile("cp.async.bulk.commit_group;\n" : : : "memory");
 }

 // wait until only targetNbInFlightGroups groups are still in-flight.
 template
 __device__ inline void waitGroup()
 {
-    asm volatile("cp.async.bulk.wait_group %0;\n" ::"n"(targetNbInFlightGroups));
+    asm volatile("cp.async.bulk.wait_group %0;\n" ::"n"(targetNbInFlightGroups) : "memory");
 }

 __device__ inline void prefetchTensorMap(CUtensorMap const& tensorMap, StateSpace loc = StateSpace::kGENERIC)
@@ -284,20 +292,23 @@ __device__ inline void storeAsync(void* dst, T const& src, CgaBarrier& bar)
    {
        asm volatile("st.async.weak.shared::cluster.mbarrier::complete_tx::bytes.u32 [%0], %1, [%2];\n" ::"l"(
                         __cvta_generic_to_shared(dst)),
-            "r"(srcVec[0]), "l"(__cvta_generic_to_shared(&bar)));
+            "r"(srcVec[0]), "l"(__cvta_generic_to_shared(&bar))
+            : "memory");
    }
    else if constexpr (nbWords == 2)
    {
        asm volatile("st.async.weak.shared::cluster.mbarrier::complete_tx::bytes.v2.u32 [%0], {%1, %2}, [%3];\n" ::"l"(
                         __cvta_generic_to_shared(dst)),
-            "r"(srcVec[0]), "r"(srcVec[1]), "l"(__cvta_generic_to_shared(&bar)));
+            "r"(srcVec[0]), "r"(srcVec[1]), "l"(__cvta_generic_to_shared(&bar))
+            : "memory");
    }
    else if constexpr (nbWords == 4)
    {
        asm volatile(
            "st.async.weak.shared::cluster.mbarrier::complete_tx::bytes.v4.u32 [%0], {%1, %2, %3, %4}, [%5];\n" ::"l"(
                __cvta_generic_to_shared(dst)),
-            "r"(srcVec[0]), "r"(srcVec[1]), "r"(srcVec[2]), "r"(srcVec[3]), "l"(__cvta_generic_to_shared(&bar)));
+            "r"(srcVec[0]), "r"(srcVec[1]), "r"(srcVec[2]), "r"(srcVec[3]), "l"(__cvta_generic_to_shared(&bar))
+            : "memory");
    }
    else
    {