Skip to content

Commit 340d988

Browse files
committed
Overlap TMA store
1 parent 4499c4c commit 340d988

File tree

1 file changed

+4
-3
lines changed

1 file changed

+4
-3
lines changed

deep_gemm/include/deep_gemm/fp8_gemm.cuh

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -364,6 +364,10 @@ fp8_gemm_kernel(__nv_bfloat16* gmem_d, float* scales_b, int* grouped_layout,
364364
DG_STATIC_ASSERT(static_cast<int>(kSwizzleDMode > 0) + static_cast<int>(BLOCK_N_PADDING > 0) <= 1,
365365
"Swizzling and padding are not compatible");
366366

367+
// Wait last TMA store to be finished
368+
if (threadIdx.x < BLOCK_N / TMA_D_BLOCK_N)
369+
cute::tma_store_wait<0>();
370+
367371
// Write back to shared memory using STSM and issue TMA stores
368372
DG_STATIC_ASSERT(WGMMA::kNumAccum % 4 == 0, "Invalid STSM x2 vectorization");
369373
#pragma unroll
@@ -424,10 +428,7 @@ fp8_gemm_kernel(__nv_bfloat16* gmem_d, float* scales_b, int* grouped_layout,
424428
cute::SM90_TMA_STORE_2D::copy(&tensor_map_d, smem_ptr,
425429
n_block_idx * BLOCK_N + in_block_n_offset,
426430
scheduler.get_global_idx(shape_m, BLOCK_M, m_block_idx));
427-
428-
// Wait TMA to be finished
429431
cute::tma_store_arrive();
430-
cute::tma_store_wait<0>();
431432
}
432433
__syncwarp();
433434
}

0 commit comments

Comments
 (0)