flash: Change timing for QKV move
Verified with warp_specialized false; true remains to be fixed.
This commit is contained in:
@@ -15,7 +15,7 @@
|
|||||||
|
|
||||||
constexpr uint32_t ROWMAX_SETS = 3;
|
constexpr uint32_t ROWMAX_SETS = 3;
|
||||||
constexpr bool DEBUG = true;
|
constexpr bool DEBUG = true;
|
||||||
constexpr bool WARP_SPECIALIZED = true;
|
constexpr bool WARP_SPECIALIZED = false;
|
||||||
|
|
||||||
constexpr uint32_t DEV_FAKE_SMEM_START_ADDR = 0xf0000000;
|
constexpr uint32_t DEV_FAKE_SMEM_START_ADDR = 0xf0000000;
|
||||||
|
|
||||||
@@ -490,8 +490,10 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
|
|||||||
threads_per_threadblock / NUM_THREADS;
|
threads_per_threadblock / NUM_THREADS;
|
||||||
|
|
||||||
// warpgroup context
|
// warpgroup context
|
||||||
constexpr uint32_t threads_per_warpgroup = threads_per_threadblock / 2;
|
constexpr uint32_t threads_per_warpgroup =
|
||||||
constexpr uint32_t warpgroups_per_cluster = threadblocks_per_cluster * 2;
|
threads_per_threadblock / (WARP_SPECIALIZED ? 2 : 1);
|
||||||
|
constexpr uint32_t warpgroups_per_cluster =
|
||||||
|
threadblocks_per_cluster * (WARP_SPECIALIZED ? 2 : 1);
|
||||||
const uint32_t warps_per_warpgroup_per_core =
|
const uint32_t warps_per_warpgroup_per_core =
|
||||||
NUM_WARPS / warpgroups_per_cluster;
|
NUM_WARPS / warpgroups_per_cluster;
|
||||||
const uint32_t warpgroup_id = task_id / threads_per_warpgroup;
|
const uint32_t warpgroup_id = task_id / threads_per_warpgroup;
|
||||||
@@ -507,6 +509,25 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
|
|||||||
const uint32_t dim_seqlen = arg->dim_seqlen;
|
const uint32_t dim_seqlen = arg->dim_seqlen;
|
||||||
const uint32_t dim_headdim = arg->dim_headdim;
|
const uint32_t dim_headdim = arg->dim_headdim;
|
||||||
|
|
||||||
|
// get global memory addresses from kernel arguments
|
||||||
|
const float *gmem_Q = reinterpret_cast<float *>(arg->addr_q);
|
||||||
|
const float *gmem_K = reinterpret_cast<float *>(arg->addr_k);
|
||||||
|
const float *gmem_V = reinterpret_cast<float *>(arg->addr_v);
|
||||||
|
float *gmem_O = reinterpret_cast<float *>(arg->addr_o);
|
||||||
|
|
||||||
|
float *gmem_tmp_d0 = reinterpret_cast<float *>(0xd0000000UL);
|
||||||
|
float *gmem_tmp_d1 = reinterpret_cast<float *>(0xd1000000UL);
|
||||||
|
float *gmem_tmp_d2 = reinterpret_cast<float *>(0xd2000000UL);
|
||||||
|
float *gmem_tmp_d3 = reinterpret_cast<float *>(0xd3000000UL);
|
||||||
|
float *gmem_tmp_d4 = reinterpret_cast<float *>(0xd4000000UL);
|
||||||
|
float *gmem_tmp_d5 = reinterpret_cast<float *>(0xd5000000UL);
|
||||||
|
float *gmem_tmp_d6 = reinterpret_cast<float *>(0xd6000000UL);
|
||||||
|
float *gmem_tmp_d7 = reinterpret_cast<float *>(0xd7000000UL);
|
||||||
|
float *gmem_tmp_e0 = reinterpret_cast<float *>(0xe0000000UL);
|
||||||
|
float *gmem_tmp_e1 = reinterpret_cast<float *>(0xe1000000UL);
|
||||||
|
float *gmem_tmp_e2 = reinterpret_cast<float *>(0xe2000000UL);
|
||||||
|
float *gmem_tmp_e3 = reinterpret_cast<float *>(0xe3000000UL);
|
||||||
|
|
||||||
// static shared memory allocation
|
// static shared memory allocation
|
||||||
constexpr uint32_t smem_Q_size = B_ROW * HEADDIM;
|
constexpr uint32_t smem_Q_size = B_ROW * HEADDIM;
|
||||||
constexpr uint32_t smem_K_size = B_COL * HEADDIM;
|
constexpr uint32_t smem_K_size = B_COL * HEADDIM;
|
||||||
@@ -572,32 +593,23 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
|
|||||||
smem_cursor -= smem_scratchpad_size;
|
smem_cursor -= smem_scratchpad_size;
|
||||||
float *smem_scratchpad_1 = smem_cursor;
|
float *smem_scratchpad_1 = smem_cursor;
|
||||||
|
|
||||||
|
// select the correct buffer by warpgroup
|
||||||
|
float *smem_Q = (warpgroup_id % 2) ? smem_Q1 : smem_Q0;
|
||||||
|
float *smem_K = (warpgroup_id % 2) ? smem_K1 : smem_K0;
|
||||||
|
float *smem_V = (warpgroup_id % 2) ? smem_V1 : smem_V0;
|
||||||
|
float *smem_S = (warpgroup_id % 2) ? smem_S1 : smem_S0;
|
||||||
|
float *smem_O = (warpgroup_id % 2) ? smem_O1 : smem_O0;
|
||||||
|
float *smem_P = smem_S;
|
||||||
|
float *smem_O_row_scale =
|
||||||
|
(warpgroup_id % 2) ? smem_O_row_scale_1 : smem_O_row_scale_0;
|
||||||
|
float *smem_rowmax = (warpgroup_id % 2) ? smem_rowmax_1 : smem_rowmax_0;
|
||||||
|
float *smem_rowsum = (warpgroup_id % 2) ? smem_rowsum_1 : smem_rowsum_0;
|
||||||
|
float *smem_scratchpad =
|
||||||
|
(warpgroup_id % 2) ? smem_scratchpad_1 : smem_scratchpad_0;
|
||||||
|
|
||||||
// initialize rowmax/rowsum values in sharedmem
|
// initialize rowmax/rowsum values in sharedmem
|
||||||
if (warpgroup_id == 0) {
|
thread_block_init_sharedmem(tid_in_warpgroup, threads_per_warpgroup, smem_O,
|
||||||
thread_block_init_sharedmem(tid_in_warpgroup, threads_per_warpgroup, smem_O0,
|
smem_rowmax, smem_rowsum, smem_O_row_scale);
|
||||||
smem_rowmax_0, smem_rowsum_0, smem_O_row_scale_0);
|
|
||||||
} else {
|
|
||||||
thread_block_init_sharedmem(tid_in_warpgroup, threads_per_warpgroup, smem_O1,
|
|
||||||
smem_rowmax_1, smem_rowsum_1, smem_O_row_scale_1);
|
|
||||||
}
|
|
||||||
|
|
||||||
const float *gmem_Q = reinterpret_cast<float *>(arg->addr_q);
|
|
||||||
const float *gmem_K = reinterpret_cast<float *>(arg->addr_k);
|
|
||||||
const float *gmem_V = reinterpret_cast<float *>(arg->addr_v);
|
|
||||||
float *gmem_O = reinterpret_cast<float *>(arg->addr_o);
|
|
||||||
|
|
||||||
float *gmem_tmp_d0 = reinterpret_cast<float *>(0xd0000000UL);
|
|
||||||
float *gmem_tmp_d1 = reinterpret_cast<float *>(0xd1000000UL);
|
|
||||||
float *gmem_tmp_d2 = reinterpret_cast<float *>(0xd2000000UL);
|
|
||||||
float *gmem_tmp_d3 = reinterpret_cast<float *>(0xd3000000UL);
|
|
||||||
float *gmem_tmp_d4 = reinterpret_cast<float *>(0xd4000000UL);
|
|
||||||
float *gmem_tmp_d5 = reinterpret_cast<float *>(0xd5000000UL);
|
|
||||||
float *gmem_tmp_d6 = reinterpret_cast<float *>(0xd6000000UL);
|
|
||||||
float *gmem_tmp_d7 = reinterpret_cast<float *>(0xd7000000UL);
|
|
||||||
float *gmem_tmp_e0 = reinterpret_cast<float *>(0xe0000000UL);
|
|
||||||
float *gmem_tmp_e1 = reinterpret_cast<float *>(0xe1000000UL);
|
|
||||||
float *gmem_tmp_e2 = reinterpret_cast<float *>(0xe2000000UL);
|
|
||||||
float *gmem_tmp_e3 = reinterpret_cast<float *>(0xe3000000UL);
|
|
||||||
|
|
||||||
constexpr uint32_t global_barrier_id = NUM_WARPS - 1; // arbitrary
|
constexpr uint32_t global_barrier_id = NUM_WARPS - 1; // arbitrary
|
||||||
|
|
||||||
@@ -606,13 +618,38 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
|
|||||||
threadblock_barrier(global_barrier_id, warps_per_threadblock_per_core);
|
threadblock_barrier(global_barrier_id, warps_per_threadblock_per_core);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// read Q and K into SMEM before the loop starts
|
||||||
|
//
|
||||||
|
static_assert(B_ROW == B_COL, "currently only supports square tiles");
|
||||||
|
|
||||||
|
// load Q; this stays in SMEM for the entire loop
|
||||||
|
if constexpr (!WARP_SPECIALIZED) {
|
||||||
|
load_tile_to_smem<float, MemLayout::MN_major, MemLayout::MN_major, B_ROW,
|
||||||
|
HEADDIM, threads_per_warpgroup>(
|
||||||
|
dim_seqlen, warpgroup_id, 0 /* dim_k == headdim */, gmem_Q, smem_Q,
|
||||||
|
tid_in_warpgroup);
|
||||||
|
} else {
|
||||||
|
// FIXME: transpose to K-major in SMEM for correctness
|
||||||
|
load_tile_to_smem<float, MemLayout::K_major, MemLayout::K_major, B_ROW,
|
||||||
|
HEADDIM, threads_per_warpgroup>(
|
||||||
|
dim_seqlen, warpgroup_id, 0 /* dim_k == headdim */, gmem_Q, smem_Q,
|
||||||
|
tid_in_warpgroup);
|
||||||
|
}
|
||||||
|
|
||||||
|
// load K
|
||||||
|
load_tile_to_smem<float, MemLayout::MN_major, MemLayout::MN_major, B_COL,
|
||||||
|
HEADDIM, threads_per_warpgroup>(
|
||||||
|
dim_seqlen, /*tile_k=*/0, 0 /* dim_k == headdim */, gmem_K, smem_K,
|
||||||
|
tid_in_warpgroup);
|
||||||
|
|
||||||
|
// protect write to SMEM
|
||||||
|
threadblock_barrier(warpgroup_id_in_cluster, warps_per_warpgroup_per_core);
|
||||||
|
|
||||||
asm volatile ("tile_loop_start_%=:" :: );
|
asm volatile ("tile_loop_start_%=:" :: );
|
||||||
|
|
||||||
// "inner loop" along the columns of K^T
|
// "inner loop" along the columns of K^T
|
||||||
const uint32_t k_tiles = (dim_seqlen / B_COL);
|
const uint32_t k_tiles = (dim_seqlen / B_COL);
|
||||||
for (uint32_t tile_k = 0; tile_k < k_tiles; tile_k++) {
|
for (uint32_t tile_k = 0; tile_k < k_tiles; tile_k++) {
|
||||||
asm volatile ("buf_select_start_%=:" :: );
|
|
||||||
|
|
||||||
// float *smem_P_produce = (tile_k % 2) ? smem_P0 : smem_P1;
|
// float *smem_P_produce = (tile_k % 2) ? smem_P0 : smem_P1;
|
||||||
// float *smem_P_consume = (tile_k % 2) ? smem_P1 : smem_P0;
|
// float *smem_P_consume = (tile_k % 2) ? smem_P1 : smem_P0;
|
||||||
// float *smem_V_produce = (tile_k % 2) ? smem_V0 : smem_V1;
|
// float *smem_V_produce = (tile_k % 2) ? smem_V0 : smem_V1;
|
||||||
@@ -622,67 +659,87 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
|
|||||||
// float *smem_O_row_scale_consume =
|
// float *smem_O_row_scale_consume =
|
||||||
// (tile_k % 2) ? smem_O_row_scale_1 : smem_O_row_scale_0;
|
// (tile_k % 2) ? smem_O_row_scale_1 : smem_O_row_scale_0;
|
||||||
|
|
||||||
float *smem_Q = (warpgroup_id % 2) ? smem_Q1 : smem_Q0;
|
constexpr bool skip_gemm_qk = false;
|
||||||
float *smem_K = (warpgroup_id % 2) ? smem_K1 : smem_K0;
|
|
||||||
float *smem_V = (warpgroup_id % 2) ? smem_V1 : smem_V0;
|
|
||||||
float *smem_S = (warpgroup_id % 2) ? smem_S1 : smem_S0;
|
|
||||||
float *smem_O = (warpgroup_id % 2) ? smem_O1 : smem_O0;
|
|
||||||
float *smem_P = smem_S;
|
|
||||||
float *smem_O_row_scale =
|
|
||||||
(warpgroup_id % 2) ? smem_O_row_scale_1 : smem_O_row_scale_0;
|
|
||||||
float *smem_rowmax = (warpgroup_id % 2) ? smem_rowmax_1 : smem_rowmax_0;
|
|
||||||
float *smem_rowsum = (warpgroup_id % 2) ? smem_rowsum_1 : smem_rowsum_0;
|
|
||||||
float *smem_scratchpad =
|
|
||||||
(warpgroup_id % 2) ? smem_scratchpad_1 : smem_scratchpad_0;
|
|
||||||
|
|
||||||
asm volatile ("buf_select_finish_%=:" :: );
|
|
||||||
|
|
||||||
const uint32_t tile_k_ = tile_k;
|
|
||||||
|
|
||||||
constexpr bool skip_gemm_qk = true;
|
|
||||||
if constexpr (!skip_gemm_qk) {
|
if constexpr (!skip_gemm_qk) {
|
||||||
static_assert(B_ROW == B_COL, "currently only supports square tiles");
|
|
||||||
|
|
||||||
// load Q
|
|
||||||
load_tile_to_smem<float, MemLayout::MN_major, MemLayout::MN_major, B_ROW,
|
|
||||||
HEADDIM, threads_per_warpgroup>(
|
|
||||||
dim_seqlen, 0 /*FIXME: only work on first B_ROW rows of Q for now*/,
|
|
||||||
0 /* always 0 because dim_k == headdim */, gmem_Q, smem_Q,
|
|
||||||
tid_in_warpgroup);
|
|
||||||
|
|
||||||
// load K
|
|
||||||
load_tile_to_smem<float, MemLayout::MN_major, MemLayout::MN_major, B_COL,
|
|
||||||
HEADDIM, threads_per_warpgroup>(
|
|
||||||
dim_seqlen, tile_k_, 0 /* always 0 because dim_k == headdim */,
|
|
||||||
gmem_K, smem_K, tid_in_warpgroup);
|
|
||||||
|
|
||||||
// GMEM->SMEM and compute barrier
|
|
||||||
threadblock_barrier(warpgroup_id_in_cluster,
|
|
||||||
warps_per_warpgroup_per_core);
|
|
||||||
|
|
||||||
// clear out accumulators before GEMM
|
|
||||||
initialize_accum_regs<0>();
|
|
||||||
initialize_accum_regs<1>();
|
|
||||||
|
|
||||||
// GEMM I: S = Q*K
|
// GEMM I: S = Q*K
|
||||||
thread_block_gemm_single_tile<float, MemLayout::MN_major,
|
//
|
||||||
MemLayout::MN_major, B_ROW, B_COL, HEADDIM,
|
// FIXME: deduplicate this between GEMM II
|
||||||
/*load_accum=*/false,
|
if constexpr (!WARP_SPECIALIZED) {
|
||||||
/*write_to_smem=*/true>(
|
// clear out accumulators before GEMM
|
||||||
smem_Q, smem_K, nullptr /*ignore accum*/, smem_S, tid_in_warpgroup,
|
initialize_accum_regs<0>();
|
||||||
threads_per_warpgroup, warpgroups_per_cluster,
|
initialize_accum_regs<1>();
|
||||||
warpgroup_id_in_cluster);
|
|
||||||
|
thread_block_gemm_single_tile<float, MemLayout::MN_major,
|
||||||
|
MemLayout::MN_major, B_ROW, B_COL,
|
||||||
|
HEADDIM,
|
||||||
|
/*load_accum=*/false,
|
||||||
|
/*write_to_smem=*/true>(
|
||||||
|
smem_Q, smem_K, nullptr /*ignore accum*/, smem_S, tid_in_warpgroup,
|
||||||
|
threads_per_warpgroup, warpgroups_per_cluster,
|
||||||
|
warpgroup_id_in_cluster);
|
||||||
|
} else {
|
||||||
|
// when warp-specialized, there's only enough warps to do 64x32 tile
|
||||||
|
// size so we need to do 2 GEMM calls
|
||||||
|
static_assert(B_ROW / 2 == 32,
|
||||||
|
"tile size assumption for warp-specialization not met");
|
||||||
|
|
||||||
|
// assumes smem_Q is K-major
|
||||||
|
// FIXME: fix this to MN-major
|
||||||
|
float *smem_Q_half0 = smem_Q;
|
||||||
|
float *smem_Q_half1 = smem_Q + (B_ROW / 2) * HEADDIM;
|
||||||
|
float *smem_S_half0 = smem_S;
|
||||||
|
float *smem_S_half1 = smem_S + (B_ROW / 2) * B_COL;
|
||||||
|
|
||||||
|
// clear out accumulators before GEMM
|
||||||
|
initialize_accum_regs<0>();
|
||||||
|
initialize_accum_regs<1>();
|
||||||
|
|
||||||
|
// split by rows into 2 chunks
|
||||||
|
thread_block_gemm_single_tile<float, MemLayout::K_major /*FIXME*/,
|
||||||
|
MemLayout::MN_major, B_ROW / 2, B_COL,
|
||||||
|
HEADDIM,
|
||||||
|
/*load_accum=*/false,
|
||||||
|
/*write_to_smem=*/true>(
|
||||||
|
smem_Q_half0, smem_K, nullptr /*ignore accum*/, smem_S_half0,
|
||||||
|
tid_in_warpgroup, threads_per_warpgroup, warpgroups_per_cluster,
|
||||||
|
warpgroup_id_in_cluster);
|
||||||
|
thread_block_gemm_single_tile<float, MemLayout::K_major /*FIXME*/,
|
||||||
|
MemLayout::MN_major, B_ROW / 2, B_COL,
|
||||||
|
HEADDIM,
|
||||||
|
/*load_accum=*/false,
|
||||||
|
/*write_to_smem=*/true>(
|
||||||
|
smem_Q_half1, smem_K, nullptr /*ignore accum*/, smem_S_half1,
|
||||||
|
tid_in_warpgroup, threads_per_warpgroup, warpgroups_per_cluster,
|
||||||
|
warpgroup_id_in_cluster);
|
||||||
|
}
|
||||||
} else {
|
} else {
|
||||||
// load Q*K
|
// load Q*K
|
||||||
load_tile_to_smem<float, MemLayout::K_major, MemLayout::K_major, B_COL,
|
load_tile_to_smem<float, MemLayout::K_major, MemLayout::K_major, B_COL,
|
||||||
HEADDIM, threads_per_warpgroup>(
|
HEADDIM, threads_per_warpgroup>(
|
||||||
dim_seqlen, warpgroup_id /* parallelize across rows */, tile_k_,
|
dim_seqlen, warpgroup_id /* parallelize across rows */, tile_k,
|
||||||
gmem_Q /*=gmem_S*/, smem_S, tid_in_warpgroup);
|
gmem_Q /*contains S*/, smem_S, tid_in_warpgroup);
|
||||||
}
|
}
|
||||||
|
|
||||||
// protect GEMM result writes (smem_S) before softmax
|
// protect write to SMEM (smem_S) before softmax
|
||||||
threadblock_barrier(warpgroup_id_in_cluster, warps_per_warpgroup_per_core);
|
threadblock_barrier(warpgroup_id_in_cluster, warps_per_warpgroup_per_core);
|
||||||
|
|
||||||
|
if constexpr (DEBUG) {
|
||||||
|
if (warpgroup_id == 0) {
|
||||||
|
if (tile_k == 0) {
|
||||||
|
thread_block_copy_tile(smem_S, gmem_tmp_d0,
|
||||||
|
tid_in_warpgroup, threads_per_warpgroup,
|
||||||
|
warpgroup_id_in_cluster);
|
||||||
|
} else if (tile_k == 1) {
|
||||||
|
thread_block_copy_tile(smem_S, gmem_tmp_d1,
|
||||||
|
tid_in_warpgroup, threads_per_warpgroup,
|
||||||
|
warpgroup_id_in_cluster);
|
||||||
|
}
|
||||||
|
|
||||||
|
threadblock_barrier(warpgroup_id_in_cluster,
|
||||||
|
warps_per_warpgroup_per_core);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// inter-warpgroup barrier before online softmax
|
// inter-warpgroup barrier before online softmax
|
||||||
threadblock_barrier(global_barrier_id, warps_per_threadblock_per_core);
|
threadblock_barrier(global_barrier_id, warps_per_threadblock_per_core);
|
||||||
|
|
||||||
@@ -693,32 +750,36 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
|
|||||||
smem_scratchpad, smem_rowmax, smem_rowsum,
|
smem_scratchpad, smem_rowmax, smem_rowsum,
|
||||||
smem_O_row_scale);
|
smem_O_row_scale);
|
||||||
|
|
||||||
// TODO: put the data movement for QKV here for inter-warpgroup
|
// data movement for K and V
|
||||||
//
|
//
|
||||||
|
// Q stays in SMEM for the entire loop
|
||||||
|
//
|
||||||
|
// load K for the next iteration
|
||||||
|
load_tile_to_smem<float, MemLayout::MN_major, MemLayout::MN_major, B_COL,
|
||||||
|
HEADDIM, threads_per_warpgroup>(
|
||||||
|
dim_seqlen, tile_k + 1, 0 /* dim_k == headdim */, gmem_K, smem_K,
|
||||||
|
tid_in_warpgroup);
|
||||||
|
|
||||||
|
// load V for the current iteration
|
||||||
// V dimension is [seqlen, headdim], stored N(headdim)-major
|
// V dimension is [seqlen, headdim], stored N(headdim)-major
|
||||||
load_tile_to_smem<float, MemLayout::MN_major, MemLayout::MN_major, B_COL,
|
load_tile_to_smem<float, MemLayout::MN_major, MemLayout::MN_major, B_COL,
|
||||||
HEADDIM, threads_per_warpgroup>(
|
HEADDIM, threads_per_warpgroup>(
|
||||||
HEADDIM, 0 /* full N-dimension */, tile_k_, gmem_V, smem_V,
|
HEADDIM, 0 /* full N-dimension */, tile_k, gmem_V, smem_V,
|
||||||
tid_in_warpgroup);
|
tid_in_warpgroup);
|
||||||
|
|
||||||
|
// protect write to SMEM
|
||||||
threadblock_barrier(warpgroup_id_in_cluster, warps_per_warpgroup_per_core);
|
threadblock_barrier(warpgroup_id_in_cluster, warps_per_warpgroup_per_core);
|
||||||
|
|
||||||
if constexpr (DEBUG) {
|
if constexpr (DEBUG) {
|
||||||
if (warpgroup_id == 0) {
|
if (warpgroup_id == 0) {
|
||||||
if (tile_k_ == 0) {
|
if (tile_k == 0) {
|
||||||
// thread_block_copy_tile(smem_P, gmem_tmp_d0,
|
|
||||||
// tid_in_warpgroup, threads_per_warpgroup,
|
|
||||||
// warpgroup_id_in_cluster);
|
|
||||||
thread_block_copy_rowmax(smem_rowmax, gmem_tmp_e0, tid_in_warpgroup,
|
thread_block_copy_rowmax(smem_rowmax, gmem_tmp_e0, tid_in_warpgroup,
|
||||||
threads_per_warpgroup,
|
threads_per_warpgroup,
|
||||||
warpgroup_id_in_cluster);
|
warpgroup_id_in_cluster);
|
||||||
thread_block_copy_rowmax(smem_rowsum, gmem_tmp_e2, tid_in_warpgroup,
|
thread_block_copy_rowmax(smem_rowsum, gmem_tmp_e2, tid_in_warpgroup,
|
||||||
threads_per_warpgroup,
|
threads_per_warpgroup,
|
||||||
warpgroup_id_in_cluster);
|
warpgroup_id_in_cluster);
|
||||||
} else if (tile_k_ == 1) {
|
} else if (tile_k == 1) {
|
||||||
// thread_block_copy_tile(smem_P, gmem_tmp_d1,
|
|
||||||
// tid_in_warpgroup, threads_per_warpgroup,
|
|
||||||
// warpgroup_id_in_cluster);
|
|
||||||
thread_block_copy_rowmax(smem_rowmax, gmem_tmp_e1, tid_in_warpgroup,
|
thread_block_copy_rowmax(smem_rowmax, gmem_tmp_e1, tid_in_warpgroup,
|
||||||
threads_per_warpgroup,
|
threads_per_warpgroup,
|
||||||
warpgroup_id_in_cluster);
|
warpgroup_id_in_cluster);
|
||||||
@@ -748,24 +809,18 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
|
|||||||
if constexpr (DEBUG) {
|
if constexpr (DEBUG) {
|
||||||
if (warpgroup_id == 0) {
|
if (warpgroup_id == 0) {
|
||||||
// O before PV
|
// O before PV
|
||||||
if (tile_k_ == 0) {
|
if (tile_k == 0) {
|
||||||
thread_block_copy_tile(smem_P, gmem_tmp_d0, tid_in_warpgroup,
|
thread_block_copy_tile(smem_P, gmem_tmp_d2, tid_in_warpgroup,
|
||||||
threads_per_warpgroup,
|
threads_per_warpgroup,
|
||||||
warpgroup_id_in_cluster);
|
warpgroup_id_in_cluster);
|
||||||
thread_block_copy_tile(smem_V, gmem_tmp_d6, tid_in_warpgroup,
|
thread_block_copy_tile(smem_O, gmem_tmp_d4, tid_in_warpgroup,
|
||||||
threads_per_warpgroup,
|
threads_per_warpgroup,
|
||||||
warpgroup_id_in_cluster);
|
warpgroup_id_in_cluster);
|
||||||
thread_block_copy_tile(smem_O, gmem_tmp_d2, tid_in_warpgroup,
|
} else if (tile_k == 1) {
|
||||||
|
thread_block_copy_tile(smem_P, gmem_tmp_d3, tid_in_warpgroup,
|
||||||
threads_per_warpgroup,
|
threads_per_warpgroup,
|
||||||
warpgroup_id_in_cluster);
|
warpgroup_id_in_cluster);
|
||||||
} else if (tile_k_ == 1) {
|
thread_block_copy_tile(smem_O, gmem_tmp_d5, tid_in_warpgroup,
|
||||||
thread_block_copy_tile(smem_P, gmem_tmp_d1, tid_in_warpgroup,
|
|
||||||
threads_per_warpgroup,
|
|
||||||
warpgroup_id_in_cluster);
|
|
||||||
thread_block_copy_tile(smem_V, gmem_tmp_d7, tid_in_warpgroup,
|
|
||||||
threads_per_warpgroup,
|
|
||||||
warpgroup_id_in_cluster);
|
|
||||||
thread_block_copy_tile(smem_O, gmem_tmp_d3, tid_in_warpgroup,
|
|
||||||
threads_per_warpgroup,
|
threads_per_warpgroup,
|
||||||
warpgroup_id_in_cluster);
|
warpgroup_id_in_cluster);
|
||||||
}
|
}
|
||||||
@@ -838,12 +893,12 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
|
|||||||
if constexpr (DEBUG) {
|
if constexpr (DEBUG) {
|
||||||
if (warpgroup_id == 0) {
|
if (warpgroup_id == 0) {
|
||||||
// O after PV
|
// O after PV
|
||||||
if (tile_k_ == 0) {
|
if (tile_k == 0) {
|
||||||
thread_block_copy_tile(smem_O, gmem_tmp_d4, tid_in_warpgroup,
|
thread_block_copy_tile(smem_O, gmem_tmp_d6, tid_in_warpgroup,
|
||||||
threads_per_warpgroup,
|
threads_per_warpgroup,
|
||||||
warpgroup_id_in_cluster);
|
warpgroup_id_in_cluster);
|
||||||
} else if (tile_k_ == 1) {
|
} else if (tile_k == 1) {
|
||||||
thread_block_copy_tile(smem_O, gmem_tmp_d5, tid_in_warpgroup,
|
thread_block_copy_tile(smem_O, gmem_tmp_d7, tid_in_warpgroup,
|
||||||
threads_per_warpgroup,
|
threads_per_warpgroup,
|
||||||
warpgroup_id_in_cluster);
|
warpgroup_id_in_cluster);
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user