sgemm_tcore: Replace hardcoded NUM_LANES with NUM_THREADS
This commit is contained in:
@@ -7,7 +7,7 @@
|
|||||||
#include "include/gemmini.h"
|
#include "include/gemmini.h"
|
||||||
#include "gemmini_mmio.h"
|
#include "gemmini_mmio.h"
|
||||||
|
|
||||||
#define GEMMINI_DMA 1
|
#define GEMMINI_DMA 0
|
||||||
#if SMEM_SIZE == 0x4000
|
#if SMEM_SIZE == 0x4000
|
||||||
#define SMEM_ADDR_Q0 ((float * const) 0xff000000)
|
#define SMEM_ADDR_Q0 ((float * const) 0xff000000)
|
||||||
#define SMEM_ADDR_Q1 ((float * const) 0xff001000)
|
#define SMEM_ADDR_Q1 ((float * const) 0xff001000)
|
||||||
@@ -273,10 +273,10 @@ inline void thread_block_gemm(kernel_arg_t *__UNIFORM__ arg,
|
|||||||
|
|
||||||
// no double-buffering
|
// no double-buffering
|
||||||
const uint32_t threads_per_warpgroup = threads_per_threadblock;
|
const uint32_t threads_per_warpgroup = threads_per_threadblock;
|
||||||
const uint32_t warp_id_in_warpgroup = tid_in_threadblock / NUM_LANES;
|
const uint32_t warp_id_in_warpgroup = tid_in_threadblock / NUM_THREADS;
|
||||||
const uint32_t warp_row = warp_id_in_warpgroup / (BN / WN);
|
const uint32_t warp_row = warp_id_in_warpgroup / (BN / WN);
|
||||||
const uint32_t warp_col = warp_id_in_warpgroup % (BN / WN);
|
const uint32_t warp_col = warp_id_in_warpgroup % (BN / WN);
|
||||||
const uint32_t tid_in_warp = tid_in_threadblock % NUM_LANES;
|
const uint32_t tid_in_warp = tid_in_threadblock % NUM_THREADS;
|
||||||
|
|
||||||
volatile float *local_a = sharedmem_per_threadblock;
|
volatile float *local_a = sharedmem_per_threadblock;
|
||||||
constexpr size_t local_a_elems = (BM * BK);
|
constexpr size_t local_a_elems = (BM * BK);
|
||||||
|
|||||||
@@ -9,7 +9,7 @@
|
|||||||
|
|
||||||
#define DOUBLE_BUFFER 1
|
#define DOUBLE_BUFFER 1
|
||||||
#undef ELEM_PER_THREAD
|
#undef ELEM_PER_THREAD
|
||||||
#define ELEM_PER_THREAD (WMITER * WNITER * ((TCM * TCN) / NUM_LANES) / (DOUBLE_BUFFER ? 2 : 1))
|
#define ELEM_PER_THREAD (WMITER * WNITER * ((TCM * TCN) / NUM_THREADS) / (DOUBLE_BUFFER ? 2 : 1))
|
||||||
|
|
||||||
// FIXME: NUM_THREADS and NUM_WARPS hardcoded
|
// FIXME: NUM_THREADS and NUM_WARPS hardcoded
|
||||||
#if ((BM * BN / ELEM_PER_THREAD) > (CORES_PER_CLUSTER * 8 * 8))
|
#if ((BM * BN / ELEM_PER_THREAD) > (CORES_PER_CLUSTER * 8 * 8))
|
||||||
@@ -291,11 +291,11 @@ inline void thread_block_gemm(kernel_arg_t *__UNIFORM__ arg,
|
|||||||
const uint32_t threads_per_warpgroup = threads_per_threadblock / (DOUBLE_BUFFER ? 2 : 1);
|
const uint32_t threads_per_warpgroup = threads_per_threadblock / (DOUBLE_BUFFER ? 2 : 1);
|
||||||
const uint32_t warpgroup_id = tid_in_threadblock / threads_per_warpgroup;
|
const uint32_t warpgroup_id = tid_in_threadblock / threads_per_warpgroup;
|
||||||
const uint32_t tid_in_warpgroup = tid_in_threadblock % threads_per_warpgroup; // FIXME
|
const uint32_t tid_in_warpgroup = tid_in_threadblock % threads_per_warpgroup; // FIXME
|
||||||
const uint32_t warp_in_warpgroup = tid_in_warpgroup / NUM_LANES;
|
const uint32_t warp_in_warpgroup = tid_in_warpgroup / NUM_THREADS;
|
||||||
// FIXME: warp_row / BN should be warp-specialized?
|
// FIXME: warp_row / BN should be warp-specialized?
|
||||||
const uint32_t warp_row = warp_in_warpgroup / (BN / WN);
|
const uint32_t warp_row = warp_in_warpgroup / (BN / WN);
|
||||||
const uint32_t warp_col = warp_in_warpgroup % (BN / WN);
|
const uint32_t warp_col = warp_in_warpgroup % (BN / WN);
|
||||||
const uint32_t tid_in_warp = tid_in_threadblock % NUM_LANES;
|
const uint32_t tid_in_warp = tid_in_threadblock % NUM_THREADS;
|
||||||
|
|
||||||
volatile float *local_a = sharedmem_per_threadblock;
|
volatile float *local_a = sharedmem_per_threadblock;
|
||||||
// const size_t local_a_elems = threadblock_dim_x * threadblock_dim_y;
|
// const size_t local_a_elems = threadblock_dim_x * threadblock_dim_y;
|
||||||
|
|||||||
@@ -8,8 +8,6 @@
|
|||||||
#include "include/gemmini.h"
|
#include "include/gemmini.h"
|
||||||
#include "gemmini_mmio.h"
|
#include "gemmini_mmio.h"
|
||||||
|
|
||||||
#define NUM_LANES 8
|
|
||||||
|
|
||||||
#define SMEM_ADDR_Q0 ((float * const) 0xff000000)
|
#define SMEM_ADDR_Q0 ((float * const) 0xff000000)
|
||||||
#define SMEM_ADDR_Q1 ((float * const) 0xff001000)
|
#define SMEM_ADDR_Q1 ((float * const) 0xff001000)
|
||||||
#define SMEM_ADDR_Q2 ((float * const) 0xff002000)
|
#define SMEM_ADDR_Q2 ((float * const) 0xff002000)
|
||||||
@@ -52,7 +50,7 @@
|
|||||||
#define TCK 8
|
#define TCK 8
|
||||||
#define WMITER (WM / TCM)
|
#define WMITER (WM / TCM)
|
||||||
#define WNITER (WN / TCN)
|
#define WNITER (WN / TCN)
|
||||||
#define ELEM_PER_THREAD (WMITER * WNITER * ((TCM * TCN) / NUM_LANES) / (DOUBLE_BUFFER ? 2 : 1))
|
#define ELEM_PER_THREAD (WMITER * WNITER * ((TCM * TCN) / NUM_THREADS) / (DOUBLE_BUFFER ? 2 : 1))
|
||||||
|
|
||||||
// FIXME: NUM_THREADS and NUM_WARPS hardcoded
|
// FIXME: NUM_THREADS and NUM_WARPS hardcoded
|
||||||
#if ((BM * BN / ELEM_PER_THREAD) > (CORES_PER_CLUSTER * 8 * 8))
|
#if ((BM * BN / ELEM_PER_THREAD) > (CORES_PER_CLUSTER * 8 * 8))
|
||||||
@@ -101,9 +99,9 @@ inline constexpr void map_operand_8lanes(const int tid, int &row, int &col) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
inline constexpr void map_operand(const int tid, int &row, int &col) {
|
inline constexpr void map_operand(const int tid, int &row, int &col) {
|
||||||
if constexpr (NUM_LANES == 32) {
|
if constexpr (NUM_THREADS == 32) {
|
||||||
map_operand_32lanes(tid, row, col);
|
map_operand_32lanes(tid, row, col);
|
||||||
} else if constexpr (NUM_LANES == 8) {
|
} else if constexpr (NUM_THREADS == 8) {
|
||||||
map_operand_8lanes(tid, row, col);
|
map_operand_8lanes(tid, row, col);
|
||||||
} else {
|
} else {
|
||||||
// FIXME: not allowed
|
// FIXME: not allowed
|
||||||
@@ -137,9 +135,9 @@ inline constexpr void map_c_8lanes(const int tid, int &row, int &col) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
inline constexpr void map_c(const int tid, int &row, int &col) {
|
inline constexpr void map_c(const int tid, int &row, int &col) {
|
||||||
if constexpr (NUM_LANES == 32) {
|
if constexpr (NUM_THREADS == 32) {
|
||||||
map_c_32lanes(tid, row, col);
|
map_c_32lanes(tid, row, col);
|
||||||
} else if constexpr (NUM_LANES == 8) {
|
} else if constexpr (NUM_THREADS == 8) {
|
||||||
map_c_8lanes(tid, row, col);
|
map_c_8lanes(tid, row, col);
|
||||||
} else {
|
} else {
|
||||||
// FIXME: not allowed
|
// FIXME: not allowed
|
||||||
@@ -571,12 +569,12 @@ inline void thread_block_gemm(kernel_arg_t *__UNIFORM__ arg,
|
|||||||
const uint32_t threads_per_warpgroup = threads_per_threadblock / 1;
|
const uint32_t threads_per_warpgroup = threads_per_threadblock / 1;
|
||||||
const uint32_t warpgroup_id = tid_in_threadblock / threads_per_warpgroup;
|
const uint32_t warpgroup_id = tid_in_threadblock / threads_per_warpgroup;
|
||||||
const uint32_t tid_in_warpgroup = tid_in_threadblock % threads_per_warpgroup; // FIXME
|
const uint32_t tid_in_warpgroup = tid_in_threadblock % threads_per_warpgroup; // FIXME
|
||||||
const uint32_t warp_in_warpgroup = tid_in_warpgroup / NUM_LANES;
|
const uint32_t warp_in_warpgroup = tid_in_warpgroup / NUM_THREADS;
|
||||||
|
|
||||||
// FIXME: warp_row / BN should be warp-specialized?
|
// FIXME: warp_row / BN should be warp-specialized?
|
||||||
const uint32_t warp_row = warp_in_warpgroup / (BN / WN);
|
const uint32_t warp_row = warp_in_warpgroup / (BN / WN);
|
||||||
const uint32_t warp_col = warp_in_warpgroup % (BN / WN);
|
const uint32_t warp_col = warp_in_warpgroup % (BN / WN);
|
||||||
const uint32_t tid_in_warp = tid_in_threadblock % NUM_LANES;
|
const uint32_t tid_in_warp = tid_in_threadblock % NUM_THREADS;
|
||||||
|
|
||||||
volatile float *local_a = sharedmem_per_threadblock;
|
volatile float *local_a = sharedmem_per_threadblock;
|
||||||
// const size_t local_a_elems = threadblock_dim_x * threadblock_dim_y;
|
// const size_t local_a_elems = threadblock_dim_x * threadblock_dim_y;
|
||||||
|
|||||||
@@ -6,8 +6,6 @@
|
|||||||
#include "include/gemmini.h"
|
#include "include/gemmini.h"
|
||||||
#include "gemmini_mmio.h"
|
#include "gemmini_mmio.h"
|
||||||
|
|
||||||
#define NUM_LANES 8
|
|
||||||
|
|
||||||
// Constraints on parameters:
|
// Constraints on parameters:
|
||||||
// * Memory:
|
// * Memory:
|
||||||
// (BM + BN) * BK * sizeof(float) <= sharedmem size.
|
// (BM + BN) * BK * sizeof(float) <= sharedmem size.
|
||||||
@@ -30,7 +28,7 @@
|
|||||||
#define TCK 8
|
#define TCK 8
|
||||||
#define WMITER (WM / TCM)
|
#define WMITER (WM / TCM)
|
||||||
#define WNITER (WN / TCN)
|
#define WNITER (WN / TCN)
|
||||||
#define ELEM_PER_THREAD (WMITER * WNITER * (TCM * TCN) / NUM_LANES)
|
#define ELEM_PER_THREAD (WMITER * WNITER * (TCM * TCN) / NUM_THREADS)
|
||||||
|
|
||||||
// number of loop around the inner 0..TCK..BK loop to simulate perfect-DRAM
|
// number of loop around the inner 0..TCK..BK loop to simulate perfect-DRAM
|
||||||
// scenario
|
// scenario
|
||||||
@@ -91,9 +89,9 @@ inline constexpr void map_operand_8lanes(const int tid, int &row, int &col) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
inline constexpr void map_operand(const int tid, int &row, int &col) {
|
inline constexpr void map_operand(const int tid, int &row, int &col) {
|
||||||
if constexpr (NUM_LANES == 32) {
|
if constexpr (NUM_THREADS == 32) {
|
||||||
map_operand_32lanes(tid, row, col);
|
map_operand_32lanes(tid, row, col);
|
||||||
} else if constexpr (NUM_LANES == 8) {
|
} else if constexpr (NUM_THREADS == 8) {
|
||||||
map_operand_8lanes(tid, row, col);
|
map_operand_8lanes(tid, row, col);
|
||||||
} else {
|
} else {
|
||||||
// FIXME: not allowed
|
// FIXME: not allowed
|
||||||
@@ -127,9 +125,9 @@ inline constexpr void map_c_8lanes(const int tid, int &row, int &col) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
inline constexpr void map_c(const int tid, int &row, int &col) {
|
inline constexpr void map_c(const int tid, int &row, int &col) {
|
||||||
if constexpr (NUM_LANES == 32) {
|
if constexpr (NUM_THREADS == 32) {
|
||||||
map_c_32lanes(tid, row, col);
|
map_c_32lanes(tid, row, col);
|
||||||
} else if constexpr (NUM_LANES == 8) {
|
} else if constexpr (NUM_THREADS == 8) {
|
||||||
map_c_8lanes(tid, row, col);
|
map_c_8lanes(tid, row, col);
|
||||||
} else {
|
} else {
|
||||||
// FIXME: not allowed
|
// FIXME: not allowed
|
||||||
|
|||||||
Reference in New Issue
Block a user