flash: Do GEMM II in Gemmini; verify 1st iteration
This commit is contained in:
@@ -95,7 +95,7 @@ inline void thread_block_copy_rowmax(const float *src, float *dest,
|
||||
asm volatile("threadblock_copy_rowmax_finish_%=:" ::);
|
||||
}
|
||||
|
||||
template <uint32_t dim_row, uint32_t dim_col>
|
||||
template <uint32_t dim_row, uint32_t dim_col, bool block_row_major = false>
|
||||
inline void thread_block_copy_tile(const float *src, float *dest,
|
||||
const uint32_t tid_in_threadblock,
|
||||
const uint32_t threads_per_threadblock,
|
||||
@@ -113,14 +113,18 @@ inline void thread_block_copy_tile(const float *src, float *dest,
|
||||
for (int row_offset = 0; row_offset < dim_row;
|
||||
row_offset += warps_in_threadblock) {
|
||||
const uint32_t row = row_offset + warp_id;
|
||||
const uint32_t first_thread_offset = dim_col * row;
|
||||
|
||||
constexpr uint32_t per_row_iter = dim_col / NUM_THREADS;
|
||||
uint32_t thread_offset = first_thread_offset + tid_in_warp;
|
||||
#pragma GCC unroll
|
||||
for (int i = 0; i < per_row_iter; i++) {
|
||||
dest[thread_offset] = src[thread_offset];
|
||||
thread_offset += NUM_THREADS;
|
||||
const uint32_t col_offset = NUM_THREADS * i;
|
||||
const uint32_t col = col_offset + tid_in_warp;
|
||||
const auto [smem_row, smem_col] =
|
||||
remap_to_gemmini_dma_layout<block_row_major, B_COL>(row, col);
|
||||
const uint32_t smem_offset = B_COL * smem_row + smem_col;
|
||||
const uint32_t gmem_offset = B_COL * row + col;
|
||||
|
||||
dest[gmem_offset] = src[smem_offset];
|
||||
}
|
||||
|
||||
threadblock_barrier(threadblock_id_in_cluster,
|
||||
@@ -415,6 +419,7 @@ __attribute__((always_inline)) inline void thread_block_online_softmax(
|
||||
asm volatile("thread_block_online_softmax_finish_%=:" ::);
|
||||
}
|
||||
|
||||
template <bool block_row_major = false>
|
||||
__attribute__((always_inline)) inline void thread_block_O_rescale(
|
||||
const float *smem_O_in, float *smem_O_out, const float *smem_O_row_scale,
|
||||
const uint32_t tid_in_threadblock, const uint32_t threads_per_threadblock,
|
||||
@@ -431,19 +436,21 @@ __attribute__((always_inline)) inline void thread_block_O_rescale(
|
||||
for (int row_offset = 0; row_offset < B_ROW;
|
||||
row_offset += warps_in_threadblock) {
|
||||
const uint32_t row = row_offset + warp_id;
|
||||
const uint32_t first_thread_offset = B_COL * row;
|
||||
constexpr uint32_t per_row_iter = B_COL / NUM_THREADS;
|
||||
uint32_t thread_offset = first_thread_offset + tid_in_warp;
|
||||
constexpr uint32_t per_row_iter = HEADDIM / NUM_THREADS;
|
||||
|
||||
// Oi rescale
|
||||
//
|
||||
#pragma GCC unroll
|
||||
for (int i = 0; i < per_row_iter; i++) {
|
||||
const float o = smem_O_in[thread_offset];
|
||||
const float scale = smem_O_row_scale[row];
|
||||
smem_O_out[thread_offset] = (o * scale);
|
||||
const uint32_t col_offset = NUM_THREADS * i;
|
||||
const uint32_t col = col_offset + tid_in_warp;
|
||||
const auto [smem_row, smem_col] =
|
||||
remap_to_gemmini_dma_layout<block_row_major, HEADDIM>(row, col);
|
||||
|
||||
thread_offset += NUM_THREADS;
|
||||
const uint32_t offset = HEADDIM * smem_row + smem_col;
|
||||
const float o = smem_O_in[offset];
|
||||
const float scale = smem_O_row_scale[row];
|
||||
smem_O_out[offset] = (o * scale);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user