sgemm_tcore: Move global_dmem_load back to kernel.cpp
This commit is contained in:
@@ -7,6 +7,214 @@
|
|||||||
#include "include/gemmini.h"
|
#include "include/gemmini.h"
|
||||||
#include "gemmini_mmio.h"
|
#include "gemmini_mmio.h"
|
||||||
|
|
||||||
|
inline void global_dmem_load(const uint32_t dim_n, const uint32_t dim_k,
|
||||||
|
const uint32_t k, const float *A, const float *B,
|
||||||
|
volatile float *local_a, volatile float *local_b,
|
||||||
|
const uint32_t tid_in_threadblock,
|
||||||
|
const uint32_t threadblock_id_x,
|
||||||
|
const uint32_t threadblock_id_y) {
|
||||||
|
const uint32_t local_a_row = tid_in_threadblock / BK;
|
||||||
|
const uint32_t local_a_col = tid_in_threadblock % BK;
|
||||||
|
const uint32_t local_as_row = tid_in_threadblock / BM;
|
||||||
|
const uint32_t local_as_col = tid_in_threadblock % BM;
|
||||||
|
const uint32_t local_b_row = tid_in_threadblock / BN;
|
||||||
|
const uint32_t local_b_col = tid_in_threadblock % BN;
|
||||||
|
|
||||||
|
constexpr uint32_t threads_in_threadblock = (BM * BN) / ELEM_PER_THREAD;
|
||||||
|
|
||||||
|
// Data move from GMEM to SMEM
|
||||||
|
//
|
||||||
|
// Make sure global offset values for A and B are contiguous between
|
||||||
|
// neighboring threads to ensure GMEM coalescing.
|
||||||
|
//
|
||||||
|
// TODO: Sharedmem swizzling is important here
|
||||||
|
if constexpr (!TRANSPOSE_AS) {
|
||||||
|
// FIXME: !TRANSPOSE_AS code is old
|
||||||
|
|
||||||
|
const uint32_t global_a_row = BM * threadblock_id_y + local_a_row;
|
||||||
|
// number of rows a full TB can read at a time
|
||||||
|
constexpr uint32_t row_stride_a = threads_in_threadblock / BK;
|
||||||
|
const float *global_a = A + dim_k * global_a_row + (k + local_a_col);
|
||||||
|
volatile float *local_a_tmp = local_a + BK * local_a_row + local_a_col;
|
||||||
|
|
||||||
|
#pragma GCC unroll 1
|
||||||
|
for (uint32_t local_row_offset = 0; local_row_offset < BM;
|
||||||
|
local_row_offset += row_stride_a) {
|
||||||
|
// const uint32_t global_a_offset =
|
||||||
|
// dim_k * (global_a_row + local_row_offset) + (k + local_a_col);
|
||||||
|
// local_a[BK * (local_a_row + local_row_offset) + local_a_col] =
|
||||||
|
// A[global_a_offset];
|
||||||
|
*local_a_tmp = *global_a;
|
||||||
|
|
||||||
|
global_a += dim_k * row_stride_a;
|
||||||
|
local_a_tmp += BK * row_stride_a;
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
if constexpr (!GMEM_COALESCED_A) {
|
||||||
|
constexpr uint32_t row_stride_as = threads_in_threadblock / BM;
|
||||||
|
const uint32_t global_a_row = BM * threadblock_id_y + local_as_col;
|
||||||
|
const float *global_a = A + dim_k * global_a_row + (k + local_as_row);
|
||||||
|
// FIXME experimenting with global coalescing
|
||||||
|
// const uint32_t global_a_row = BM * threadblock_id_y + local_as_row;
|
||||||
|
// const float *global_a = A + dim_k * global_a_row + (k + local_as_col);
|
||||||
|
volatile float *local_a_tmp = local_a + BM * local_as_row + local_as_col;
|
||||||
|
|
||||||
|
static_assert(
|
||||||
|
row_stride_as * 8 <= BK,
|
||||||
|
"manual loop unrolling condition not met; consider increasing BK");
|
||||||
|
static_assert(
|
||||||
|
(BK % (row_stride_as * 8)) == 0,
|
||||||
|
"manual loop unrolling condition not met; BK should be power-of-two");
|
||||||
|
|
||||||
|
#pragma GCC unroll 1
|
||||||
|
for (uint32_t local_row_offset = 0; local_row_offset < BK;
|
||||||
|
local_row_offset += row_stride_as * 8) {
|
||||||
|
// @perf: bank conflicts here
|
||||||
|
// const uint32_t global_a_offset =
|
||||||
|
// dim_k * (global_a_row) + (k + local_as_row + local_row_offset);
|
||||||
|
// FIXME experimenting with global coalescing
|
||||||
|
// const uint32_t global_a_offset =
|
||||||
|
// dim_k * (global_a_row + local_row_offset) + (k + local_as_col);
|
||||||
|
// local_a[BM * (local_as_row + local_row_offset) + local_as_col] =
|
||||||
|
// A[global_a_offset];
|
||||||
|
|
||||||
|
// *local_a_tmp = *global_a;
|
||||||
|
asm volatile ("flw ft0, (%0)" :: "r"(global_a));
|
||||||
|
global_a += row_stride_as;
|
||||||
|
asm volatile ("flw ft1, (%0)" :: "r"(global_a));
|
||||||
|
global_a += row_stride_as;
|
||||||
|
asm volatile ("flw ft2, (%0)" :: "r"(global_a));
|
||||||
|
global_a += row_stride_as;
|
||||||
|
asm volatile ("flw ft3, (%0)" :: "r"(global_a));
|
||||||
|
global_a += row_stride_as;
|
||||||
|
asm volatile ("flw ft4, (%0)" :: "r"(global_a));
|
||||||
|
global_a += row_stride_as;
|
||||||
|
asm volatile ("flw ft5, (%0)" :: "r"(global_a));
|
||||||
|
global_a += row_stride_as;
|
||||||
|
asm volatile ("flw ft6, (%0)" :: "r"(global_a));
|
||||||
|
global_a += row_stride_as;
|
||||||
|
asm volatile ("flw ft7, (%0)" :: "r"(global_a));
|
||||||
|
global_a += row_stride_as;
|
||||||
|
|
||||||
|
asm volatile ("fsw ft0, %0(%1)" :: "i"(BM * row_stride_as * 0 * sizeof(float)), "r"(local_a_tmp));
|
||||||
|
asm volatile ("fsw ft1, %0(%1)" :: "i"(BM * row_stride_as * 1 * sizeof(float)), "r"(local_a_tmp));
|
||||||
|
asm volatile ("fsw ft2, %0(%1)" :: "i"(BM * row_stride_as * 2 * sizeof(float)), "r"(local_a_tmp));
|
||||||
|
asm volatile ("fsw ft3, %0(%1)" :: "i"(BM * row_stride_as * 3 * sizeof(float)), "r"(local_a_tmp));
|
||||||
|
asm volatile ("fsw ft4, %0(%1)" :: "i"(BM * row_stride_as * 4 * sizeof(float)), "r"(local_a_tmp));
|
||||||
|
asm volatile ("fsw ft5, %0(%1)" :: "i"(BM * row_stride_as * 5 * sizeof(float)), "r"(local_a_tmp));
|
||||||
|
asm volatile ("fsw ft6, %0(%1)" :: "i"(BM * row_stride_as * 6 * sizeof(float)), "r"(local_a_tmp));
|
||||||
|
asm volatile ("fsw ft7, %0(%1)" :: "i"(BM * row_stride_as * 7 * sizeof(float)), "r"(local_a_tmp));
|
||||||
|
local_a_tmp += BM * row_stride_as * 8;
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
constexpr uint32_t row_stride_a = threads_in_threadblock / BK;
|
||||||
|
const uint32_t global_a_row = BM * threadblock_id_y + local_a_row;
|
||||||
|
const float *global_a = A + dim_k * global_a_row + (k + local_a_col);
|
||||||
|
// NOTE that SMEM writes are transposed
|
||||||
|
volatile float *local_a_tmp = local_a + BM * local_a_col + local_a_row;
|
||||||
|
|
||||||
|
static_assert(
|
||||||
|
row_stride_a * 8 <= BM,
|
||||||
|
"manual loop unrolling condition not met; consider increasing BM");
|
||||||
|
static_assert(
|
||||||
|
(BM % (row_stride_a * 8)) == 0,
|
||||||
|
"manual loop unrolling condition not met; BM should be power-of-two");
|
||||||
|
|
||||||
|
#pragma GCC unroll 1
|
||||||
|
for (uint32_t local_row_offset = 0; local_row_offset < BM;
|
||||||
|
local_row_offset += row_stride_a * 8) {
|
||||||
|
// const uint32_t global_a_offset =
|
||||||
|
// dim_k * (global_a_row + local_row_offset) + (k + local_a_col);
|
||||||
|
// NOTE that SMEM writes are transposed
|
||||||
|
// local_a[BM * (local_a_col) + local_a_row + local_row_offset] =
|
||||||
|
// A[global_a_offset];
|
||||||
|
|
||||||
|
asm volatile ("flw ft0, (%0)" :: "r"(global_a));
|
||||||
|
global_a += dim_k * row_stride_a;
|
||||||
|
asm volatile ("flw ft1, (%0)" :: "r"(global_a));
|
||||||
|
global_a += dim_k * row_stride_a;
|
||||||
|
asm volatile ("flw ft2, (%0)" :: "r"(global_a));
|
||||||
|
global_a += dim_k * row_stride_a;
|
||||||
|
asm volatile ("flw ft3, (%0)" :: "r"(global_a));
|
||||||
|
global_a += dim_k * row_stride_a;
|
||||||
|
asm volatile ("flw ft4, (%0)" :: "r"(global_a));
|
||||||
|
global_a += dim_k * row_stride_a;
|
||||||
|
asm volatile ("flw ft5, (%0)" :: "r"(global_a));
|
||||||
|
global_a += dim_k * row_stride_a;
|
||||||
|
asm volatile ("flw ft6, (%0)" :: "r"(global_a));
|
||||||
|
global_a += dim_k * row_stride_a;
|
||||||
|
asm volatile ("flw ft7, (%0)" :: "r"(global_a));
|
||||||
|
global_a += dim_k * row_stride_a;
|
||||||
|
|
||||||
|
// stride along columns
|
||||||
|
asm volatile ("fsw ft0, %0(%1)" :: "i"(row_stride_a * 0 * sizeof(float)), "r"(local_a_tmp));
|
||||||
|
asm volatile ("fsw ft1, %0(%1)" :: "i"(row_stride_a * 1 * sizeof(float)), "r"(local_a_tmp));
|
||||||
|
asm volatile ("fsw ft2, %0(%1)" :: "i"(row_stride_a * 2 * sizeof(float)), "r"(local_a_tmp));
|
||||||
|
asm volatile ("fsw ft3, %0(%1)" :: "i"(row_stride_a * 3 * sizeof(float)), "r"(local_a_tmp));
|
||||||
|
asm volatile ("fsw ft4, %0(%1)" :: "i"(row_stride_a * 4 * sizeof(float)), "r"(local_a_tmp));
|
||||||
|
asm volatile ("fsw ft5, %0(%1)" :: "i"(row_stride_a * 5 * sizeof(float)), "r"(local_a_tmp));
|
||||||
|
asm volatile ("fsw ft6, %0(%1)" :: "i"(row_stride_a * 6 * sizeof(float)), "r"(local_a_tmp));
|
||||||
|
asm volatile ("fsw ft7, %0(%1)" :: "i"(row_stride_a * 7 * sizeof(float)), "r"(local_a_tmp));
|
||||||
|
local_a_tmp += row_stride_a * 8;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
constexpr uint32_t row_stride_b = threads_in_threadblock / BN;
|
||||||
|
const uint32_t global_b_col = BN * threadblock_id_x + local_b_col;
|
||||||
|
const float *global_b = B + dim_n * (k + local_b_row) + global_b_col;
|
||||||
|
volatile float *local_b_tmp = local_b + BN * local_b_row + local_b_col;
|
||||||
|
|
||||||
|
static_assert(
|
||||||
|
row_stride_b * 8 <= BK,
|
||||||
|
"manual loop unrolling condition not met; consider increasing BK");
|
||||||
|
static_assert(
|
||||||
|
(BK % (row_stride_b * 8)) == 0,
|
||||||
|
"manual loop unrolling condition not met; BK should be power-of-two");
|
||||||
|
|
||||||
|
#pragma GCC unroll 1
|
||||||
|
for (uint32_t load_offset = 0; load_offset < BK;
|
||||||
|
load_offset += row_stride_b * 8) {
|
||||||
|
// const uint32_t global_b_offset =
|
||||||
|
// dim_n * (k + local_b_row + load_offset) + global_b_col;
|
||||||
|
// local_b[BN * (local_b_row + load_offset) + local_b_col] =
|
||||||
|
// B[global_b_offset];
|
||||||
|
|
||||||
|
// *local_b_tmp = *global_b;
|
||||||
|
|
||||||
|
// global_b += dim_n * row_stride_b;
|
||||||
|
// local_b_tmp += BN * row_stride_b;
|
||||||
|
|
||||||
|
asm volatile ("flw ft0, (%0)" :: "r"(global_b));
|
||||||
|
global_b += dim_n * row_stride_b;
|
||||||
|
asm volatile ("flw ft1, (%0)" :: "r"(global_b));
|
||||||
|
global_b += dim_n * row_stride_b;
|
||||||
|
asm volatile ("flw ft2, (%0)" :: "r"(global_b));
|
||||||
|
global_b += dim_n * row_stride_b;
|
||||||
|
asm volatile ("flw ft3, (%0)" :: "r"(global_b));
|
||||||
|
global_b += dim_n * row_stride_b;
|
||||||
|
asm volatile ("flw ft4, (%0)" :: "r"(global_b));
|
||||||
|
global_b += dim_n * row_stride_b;
|
||||||
|
asm volatile ("flw ft5, (%0)" :: "r"(global_b));
|
||||||
|
global_b += dim_n * row_stride_b;
|
||||||
|
asm volatile ("flw ft6, (%0)" :: "r"(global_b));
|
||||||
|
global_b += dim_n * row_stride_b;
|
||||||
|
asm volatile ("flw ft7, (%0)" :: "r"(global_b));
|
||||||
|
global_b += dim_n * row_stride_b;
|
||||||
|
|
||||||
|
asm volatile ("fsw ft0, %0(%1)" :: "i"(BN * row_stride_b * 0 * sizeof(float)), "r"(local_b_tmp));
|
||||||
|
asm volatile ("fsw ft1, %0(%1)" :: "i"(BN * row_stride_b * 1 * sizeof(float)), "r"(local_b_tmp));
|
||||||
|
asm volatile ("fsw ft2, %0(%1)" :: "i"(BN * row_stride_b * 2 * sizeof(float)), "r"(local_b_tmp));
|
||||||
|
asm volatile ("fsw ft3, %0(%1)" :: "i"(BN * row_stride_b * 3 * sizeof(float)), "r"(local_b_tmp));
|
||||||
|
local_b_tmp += BN * row_stride_b * 4;
|
||||||
|
asm volatile ("fsw ft4, %0(%1)" :: "i"(BN * row_stride_b * 0 * sizeof(float)), "r"(local_b_tmp));
|
||||||
|
asm volatile ("fsw ft5, %0(%1)" :: "i"(BN * row_stride_b * 1 * sizeof(float)), "r"(local_b_tmp));
|
||||||
|
asm volatile ("fsw ft6, %0(%1)" :: "i"(BN * row_stride_b * 2 * sizeof(float)), "r"(local_b_tmp));
|
||||||
|
asm volatile ("fsw ft7, %0(%1)" :: "i"(BN * row_stride_b * 3 * sizeof(float)), "r"(local_b_tmp));
|
||||||
|
local_b_tmp += BN * row_stride_b * 4;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
inline void thread_block_gemm(kernel_arg_t *__UNIFORM__ arg,
|
inline void thread_block_gemm(kernel_arg_t *__UNIFORM__ arg,
|
||||||
const uint32_t tid_in_threadblock,
|
const uint32_t tid_in_threadblock,
|
||||||
const uint32_t threads_per_threadblock,
|
const uint32_t threads_per_threadblock,
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
#ifndef _UTIL_H_
|
#ifndef _UTIL_H_
|
||||||
#define _UTIL_H_
|
#define _UTIL_H_
|
||||||
|
|
||||||
|
#include <vx_intrinsics.h>
|
||||||
#include <vx_spawn.h>
|
#include <vx_spawn.h>
|
||||||
#include "include/gemmini.h"
|
#include "include/gemmini.h"
|
||||||
#include "gemmini_mmio.h"
|
#include "gemmini_mmio.h"
|
||||||
@@ -335,212 +336,4 @@ inline void threadblock_barrier(const uint32_t barrier_id, const uint32_t count)
|
|||||||
vx_barrier(barrier_id, count);
|
vx_barrier(barrier_id, count);
|
||||||
}
|
}
|
||||||
|
|
||||||
inline void global_dmem_load(const uint32_t dim_n, const uint32_t dim_k,
|
|
||||||
const uint32_t k, const float *A, const float *B,
|
|
||||||
volatile float *local_a, volatile float *local_b,
|
|
||||||
const uint32_t tid_in_threadblock,
|
|
||||||
const uint32_t threadblock_id_x,
|
|
||||||
const uint32_t threadblock_id_y) {
|
|
||||||
const uint32_t local_a_row = tid_in_threadblock / BK;
|
|
||||||
const uint32_t local_a_col = tid_in_threadblock % BK;
|
|
||||||
const uint32_t local_as_row = tid_in_threadblock / BM;
|
|
||||||
const uint32_t local_as_col = tid_in_threadblock % BM;
|
|
||||||
const uint32_t local_b_row = tid_in_threadblock / BN;
|
|
||||||
const uint32_t local_b_col = tid_in_threadblock % BN;
|
|
||||||
|
|
||||||
constexpr uint32_t threads_in_threadblock = (BM * BN) / ELEM_PER_THREAD;
|
|
||||||
|
|
||||||
// Data move from GMEM to SMEM
|
|
||||||
//
|
|
||||||
// Make sure global offset values for A and B are contiguous between
|
|
||||||
// neighboring threads to ensure GMEM coalescing.
|
|
||||||
//
|
|
||||||
// TODO: Sharedmem swizzling is important here
|
|
||||||
if constexpr (!TRANSPOSE_AS) {
|
|
||||||
// FIXME: !TRANSPOSE_AS code is old
|
|
||||||
|
|
||||||
const uint32_t global_a_row = BM * threadblock_id_y + local_a_row;
|
|
||||||
// number of rows a full TB can read at a time
|
|
||||||
constexpr uint32_t row_stride_a = threads_in_threadblock / BK;
|
|
||||||
const float *global_a = A + dim_k * global_a_row + (k + local_a_col);
|
|
||||||
volatile float *local_a_tmp = local_a + BK * local_a_row + local_a_col;
|
|
||||||
|
|
||||||
#pragma GCC unroll 1
|
|
||||||
for (uint32_t local_row_offset = 0; local_row_offset < BM;
|
|
||||||
local_row_offset += row_stride_a) {
|
|
||||||
// const uint32_t global_a_offset =
|
|
||||||
// dim_k * (global_a_row + local_row_offset) + (k + local_a_col);
|
|
||||||
// local_a[BK * (local_a_row + local_row_offset) + local_a_col] =
|
|
||||||
// A[global_a_offset];
|
|
||||||
*local_a_tmp = *global_a;
|
|
||||||
|
|
||||||
global_a += dim_k * row_stride_a;
|
|
||||||
local_a_tmp += BK * row_stride_a;
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
if constexpr (!GMEM_COALESCED_A) {
|
|
||||||
constexpr uint32_t row_stride_as = threads_in_threadblock / BM;
|
|
||||||
const uint32_t global_a_row = BM * threadblock_id_y + local_as_col;
|
|
||||||
const float *global_a = A + dim_k * global_a_row + (k + local_as_row);
|
|
||||||
// FIXME experimenting with global coalescing
|
|
||||||
// const uint32_t global_a_row = BM * threadblock_id_y + local_as_row;
|
|
||||||
// const float *global_a = A + dim_k * global_a_row + (k + local_as_col);
|
|
||||||
volatile float *local_a_tmp = local_a + BM * local_as_row + local_as_col;
|
|
||||||
|
|
||||||
static_assert(
|
|
||||||
row_stride_as * 8 <= BK,
|
|
||||||
"manual loop unrolling condition not met; consider increasing BK");
|
|
||||||
static_assert(
|
|
||||||
(BK % (row_stride_as * 8)) == 0,
|
|
||||||
"manual loop unrolling condition not met; BK should be power-of-two");
|
|
||||||
|
|
||||||
#pragma GCC unroll 1
|
|
||||||
for (uint32_t local_row_offset = 0; local_row_offset < BK;
|
|
||||||
local_row_offset += row_stride_as * 8) {
|
|
||||||
// @perf: bank conflicts here
|
|
||||||
// const uint32_t global_a_offset =
|
|
||||||
// dim_k * (global_a_row) + (k + local_as_row + local_row_offset);
|
|
||||||
// FIXME experimenting with global coalescing
|
|
||||||
// const uint32_t global_a_offset =
|
|
||||||
// dim_k * (global_a_row + local_row_offset) + (k + local_as_col);
|
|
||||||
// local_a[BM * (local_as_row + local_row_offset) + local_as_col] =
|
|
||||||
// A[global_a_offset];
|
|
||||||
|
|
||||||
// *local_a_tmp = *global_a;
|
|
||||||
asm volatile ("flw ft0, (%0)" :: "r"(global_a));
|
|
||||||
global_a += row_stride_as;
|
|
||||||
asm volatile ("flw ft1, (%0)" :: "r"(global_a));
|
|
||||||
global_a += row_stride_as;
|
|
||||||
asm volatile ("flw ft2, (%0)" :: "r"(global_a));
|
|
||||||
global_a += row_stride_as;
|
|
||||||
asm volatile ("flw ft3, (%0)" :: "r"(global_a));
|
|
||||||
global_a += row_stride_as;
|
|
||||||
asm volatile ("flw ft4, (%0)" :: "r"(global_a));
|
|
||||||
global_a += row_stride_as;
|
|
||||||
asm volatile ("flw ft5, (%0)" :: "r"(global_a));
|
|
||||||
global_a += row_stride_as;
|
|
||||||
asm volatile ("flw ft6, (%0)" :: "r"(global_a));
|
|
||||||
global_a += row_stride_as;
|
|
||||||
asm volatile ("flw ft7, (%0)" :: "r"(global_a));
|
|
||||||
global_a += row_stride_as;
|
|
||||||
|
|
||||||
asm volatile ("fsw ft0, %0(%1)" :: "i"(BM * row_stride_as * 0 * sizeof(float)), "r"(local_a_tmp));
|
|
||||||
asm volatile ("fsw ft1, %0(%1)" :: "i"(BM * row_stride_as * 1 * sizeof(float)), "r"(local_a_tmp));
|
|
||||||
asm volatile ("fsw ft2, %0(%1)" :: "i"(BM * row_stride_as * 2 * sizeof(float)), "r"(local_a_tmp));
|
|
||||||
asm volatile ("fsw ft3, %0(%1)" :: "i"(BM * row_stride_as * 3 * sizeof(float)), "r"(local_a_tmp));
|
|
||||||
asm volatile ("fsw ft4, %0(%1)" :: "i"(BM * row_stride_as * 4 * sizeof(float)), "r"(local_a_tmp));
|
|
||||||
asm volatile ("fsw ft5, %0(%1)" :: "i"(BM * row_stride_as * 5 * sizeof(float)), "r"(local_a_tmp));
|
|
||||||
asm volatile ("fsw ft6, %0(%1)" :: "i"(BM * row_stride_as * 6 * sizeof(float)), "r"(local_a_tmp));
|
|
||||||
asm volatile ("fsw ft7, %0(%1)" :: "i"(BM * row_stride_as * 7 * sizeof(float)), "r"(local_a_tmp));
|
|
||||||
local_a_tmp += BM * row_stride_as * 8;
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
constexpr uint32_t row_stride_a = threads_in_threadblock / BK;
|
|
||||||
const uint32_t global_a_row = BM * threadblock_id_y + local_a_row;
|
|
||||||
const float *global_a = A + dim_k * global_a_row + (k + local_a_col);
|
|
||||||
// NOTE that SMEM writes are transposed
|
|
||||||
volatile float *local_a_tmp = local_a + BM * local_a_col + local_a_row;
|
|
||||||
|
|
||||||
static_assert(
|
|
||||||
row_stride_a * 8 <= BM,
|
|
||||||
"manual loop unrolling condition not met; consider increasing BM");
|
|
||||||
static_assert(
|
|
||||||
(BM % (row_stride_a * 8)) == 0,
|
|
||||||
"manual loop unrolling condition not met; BM should be power-of-two");
|
|
||||||
|
|
||||||
#pragma GCC unroll 1
|
|
||||||
for (uint32_t local_row_offset = 0; local_row_offset < BM;
|
|
||||||
local_row_offset += row_stride_a * 8) {
|
|
||||||
// const uint32_t global_a_offset =
|
|
||||||
// dim_k * (global_a_row + local_row_offset) + (k + local_a_col);
|
|
||||||
// NOTE that SMEM writes are transposed
|
|
||||||
// local_a[BM * (local_a_col) + local_a_row + local_row_offset] =
|
|
||||||
// A[global_a_offset];
|
|
||||||
|
|
||||||
asm volatile ("flw ft0, (%0)" :: "r"(global_a));
|
|
||||||
global_a += dim_k * row_stride_a;
|
|
||||||
asm volatile ("flw ft1, (%0)" :: "r"(global_a));
|
|
||||||
global_a += dim_k * row_stride_a;
|
|
||||||
asm volatile ("flw ft2, (%0)" :: "r"(global_a));
|
|
||||||
global_a += dim_k * row_stride_a;
|
|
||||||
asm volatile ("flw ft3, (%0)" :: "r"(global_a));
|
|
||||||
global_a += dim_k * row_stride_a;
|
|
||||||
asm volatile ("flw ft4, (%0)" :: "r"(global_a));
|
|
||||||
global_a += dim_k * row_stride_a;
|
|
||||||
asm volatile ("flw ft5, (%0)" :: "r"(global_a));
|
|
||||||
global_a += dim_k * row_stride_a;
|
|
||||||
asm volatile ("flw ft6, (%0)" :: "r"(global_a));
|
|
||||||
global_a += dim_k * row_stride_a;
|
|
||||||
asm volatile ("flw ft7, (%0)" :: "r"(global_a));
|
|
||||||
global_a += dim_k * row_stride_a;
|
|
||||||
|
|
||||||
// stride along columns
|
|
||||||
asm volatile ("fsw ft0, %0(%1)" :: "i"(row_stride_a * 0 * sizeof(float)), "r"(local_a_tmp));
|
|
||||||
asm volatile ("fsw ft1, %0(%1)" :: "i"(row_stride_a * 1 * sizeof(float)), "r"(local_a_tmp));
|
|
||||||
asm volatile ("fsw ft2, %0(%1)" :: "i"(row_stride_a * 2 * sizeof(float)), "r"(local_a_tmp));
|
|
||||||
asm volatile ("fsw ft3, %0(%1)" :: "i"(row_stride_a * 3 * sizeof(float)), "r"(local_a_tmp));
|
|
||||||
asm volatile ("fsw ft4, %0(%1)" :: "i"(row_stride_a * 4 * sizeof(float)), "r"(local_a_tmp));
|
|
||||||
asm volatile ("fsw ft5, %0(%1)" :: "i"(row_stride_a * 5 * sizeof(float)), "r"(local_a_tmp));
|
|
||||||
asm volatile ("fsw ft6, %0(%1)" :: "i"(row_stride_a * 6 * sizeof(float)), "r"(local_a_tmp));
|
|
||||||
asm volatile ("fsw ft7, %0(%1)" :: "i"(row_stride_a * 7 * sizeof(float)), "r"(local_a_tmp));
|
|
||||||
local_a_tmp += row_stride_a * 8;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
constexpr uint32_t row_stride_b = threads_in_threadblock / BN;
|
|
||||||
const uint32_t global_b_col = BN * threadblock_id_x + local_b_col;
|
|
||||||
const float *global_b = B + dim_n * (k + local_b_row) + global_b_col;
|
|
||||||
volatile float *local_b_tmp = local_b + BN * local_b_row + local_b_col;
|
|
||||||
|
|
||||||
static_assert(
|
|
||||||
row_stride_b * 8 <= BK,
|
|
||||||
"manual loop unrolling condition not met; consider increasing BK");
|
|
||||||
static_assert(
|
|
||||||
(BK % (row_stride_b * 8)) == 0,
|
|
||||||
"manual loop unrolling condition not met; BK should be power-of-two");
|
|
||||||
|
|
||||||
#pragma GCC unroll 1
|
|
||||||
for (uint32_t load_offset = 0; load_offset < BK;
|
|
||||||
load_offset += row_stride_b * 8) {
|
|
||||||
// const uint32_t global_b_offset =
|
|
||||||
// dim_n * (k + local_b_row + load_offset) + global_b_col;
|
|
||||||
// local_b[BN * (local_b_row + load_offset) + local_b_col] =
|
|
||||||
// B[global_b_offset];
|
|
||||||
|
|
||||||
// *local_b_tmp = *global_b;
|
|
||||||
|
|
||||||
// global_b += dim_n * row_stride_b;
|
|
||||||
// local_b_tmp += BN * row_stride_b;
|
|
||||||
|
|
||||||
asm volatile ("flw ft0, (%0)" :: "r"(global_b));
|
|
||||||
global_b += dim_n * row_stride_b;
|
|
||||||
asm volatile ("flw ft1, (%0)" :: "r"(global_b));
|
|
||||||
global_b += dim_n * row_stride_b;
|
|
||||||
asm volatile ("flw ft2, (%0)" :: "r"(global_b));
|
|
||||||
global_b += dim_n * row_stride_b;
|
|
||||||
asm volatile ("flw ft3, (%0)" :: "r"(global_b));
|
|
||||||
global_b += dim_n * row_stride_b;
|
|
||||||
asm volatile ("flw ft4, (%0)" :: "r"(global_b));
|
|
||||||
global_b += dim_n * row_stride_b;
|
|
||||||
asm volatile ("flw ft5, (%0)" :: "r"(global_b));
|
|
||||||
global_b += dim_n * row_stride_b;
|
|
||||||
asm volatile ("flw ft6, (%0)" :: "r"(global_b));
|
|
||||||
global_b += dim_n * row_stride_b;
|
|
||||||
asm volatile ("flw ft7, (%0)" :: "r"(global_b));
|
|
||||||
global_b += dim_n * row_stride_b;
|
|
||||||
|
|
||||||
asm volatile ("fsw ft0, %0(%1)" :: "i"(BN * row_stride_b * 0 * sizeof(float)), "r"(local_b_tmp));
|
|
||||||
asm volatile ("fsw ft1, %0(%1)" :: "i"(BN * row_stride_b * 1 * sizeof(float)), "r"(local_b_tmp));
|
|
||||||
asm volatile ("fsw ft2, %0(%1)" :: "i"(BN * row_stride_b * 2 * sizeof(float)), "r"(local_b_tmp));
|
|
||||||
asm volatile ("fsw ft3, %0(%1)" :: "i"(BN * row_stride_b * 3 * sizeof(float)), "r"(local_b_tmp));
|
|
||||||
local_b_tmp += BN * row_stride_b * 4;
|
|
||||||
asm volatile ("fsw ft4, %0(%1)" :: "i"(BN * row_stride_b * 0 * sizeof(float)), "r"(local_b_tmp));
|
|
||||||
asm volatile ("fsw ft5, %0(%1)" :: "i"(BN * row_stride_b * 1 * sizeof(float)), "r"(local_b_tmp));
|
|
||||||
asm volatile ("fsw ft6, %0(%1)" :: "i"(BN * row_stride_b * 2 * sizeof(float)), "r"(local_b_tmp));
|
|
||||||
asm volatile ("fsw ft7, %0(%1)" :: "i"(BN * row_stride_b * 3 * sizeof(float)), "r"(local_b_tmp));
|
|
||||||
local_b_tmp += BN * row_stride_b * 4;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#endif
|
#endif
|
||||||
|
|||||||
Reference in New Issue
Block a user