flash: Do GEMM II in Gemmini; verify 1st iteration

This commit is contained in:
Hansung Kim
2024-09-08 16:09:06 -07:00
parent 3f50ac57ee
commit cdb8377b62
2 changed files with 113 additions and 170 deletions

View File

@@ -95,7 +95,7 @@ inline void thread_block_copy_rowmax(const float *src, float *dest,
asm volatile("threadblock_copy_rowmax_finish_%=:" ::);
}
template <uint32_t dim_row, uint32_t dim_col>
template <uint32_t dim_row, uint32_t dim_col, bool block_row_major = false>
inline void thread_block_copy_tile(const float *src, float *dest,
const uint32_t tid_in_threadblock,
const uint32_t threads_per_threadblock,
@@ -113,14 +113,18 @@ inline void thread_block_copy_tile(const float *src, float *dest,
for (int row_offset = 0; row_offset < dim_row;
row_offset += warps_in_threadblock) {
const uint32_t row = row_offset + warp_id;
const uint32_t first_thread_offset = dim_col * row;
constexpr uint32_t per_row_iter = dim_col / NUM_THREADS;
uint32_t thread_offset = first_thread_offset + tid_in_warp;
#pragma GCC unroll
for (int i = 0; i < per_row_iter; i++) {
dest[thread_offset] = src[thread_offset];
thread_offset += NUM_THREADS;
const uint32_t col_offset = NUM_THREADS * i;
const uint32_t col = col_offset + tid_in_warp;
const auto [smem_row, smem_col] =
remap_to_gemmini_dma_layout<block_row_major, B_COL>(row, col);
const uint32_t smem_offset = B_COL * smem_row + smem_col;
const uint32_t gmem_offset = B_COL * row + col;
dest[gmem_offset] = src[smem_offset];
}
threadblock_barrier(threadblock_id_in_cluster,
@@ -415,6 +419,7 @@ __attribute__((always_inline)) inline void thread_block_online_softmax(
asm volatile("thread_block_online_softmax_finish_%=:" ::);
}
template <bool block_row_major = false>
__attribute__((always_inline)) inline void thread_block_O_rescale(
const float *smem_O_in, float *smem_O_out, const float *smem_O_row_scale,
const uint32_t tid_in_threadblock, const uint32_t threads_per_threadblock,
@@ -431,19 +436,21 @@ __attribute__((always_inline)) inline void thread_block_O_rescale(
for (int row_offset = 0; row_offset < B_ROW;
row_offset += warps_in_threadblock) {
const uint32_t row = row_offset + warp_id;
const uint32_t first_thread_offset = B_COL * row;
constexpr uint32_t per_row_iter = B_COL / NUM_THREADS;
uint32_t thread_offset = first_thread_offset + tid_in_warp;
constexpr uint32_t per_row_iter = HEADDIM / NUM_THREADS;
// Oi rescale
//
#pragma GCC unroll
for (int i = 0; i < per_row_iter; i++) {
const float o = smem_O_in[thread_offset];
const float scale = smem_O_row_scale[row];
smem_O_out[thread_offset] = (o * scale);
const uint32_t col_offset = NUM_THREADS * i;
const uint32_t col = col_offset + tid_in_warp;
const auto [smem_row, smem_col] =
remap_to_gemmini_dma_layout<block_row_major, HEADDIM>(row, col);
thread_offset += NUM_THREADS;
const uint32_t offset = HEADDIM * smem_row + smem_col;
const float o = smem_O_in[offset];
const float scale = smem_O_row_scale[row];
smem_O_out[offset] = (o * scale);
}
}