tensor: Test with multiple accumulators
This commit is contained in:
@@ -11,6 +11,10 @@ inline void vx_wmma() {
|
|||||||
asm volatile (".insn r %0, 0, 0, x0, x0, x0" :: "i"(RISCV_CUSTOM3));
|
asm volatile (".insn r %0, 0, 0, x0, x0, x0" :: "i"(RISCV_CUSTOM3));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Emits one raw R-type instruction through the assembler's `.insn r` directive,
// using RISCV_CUSTOM3 (defined outside this view) as the opcode operand.
// In `.insn r opcode, funct3, funct7, rd, rs1, rs2` order the fields are
// funct3 = 0, funct7 = 0, rd = x1, rs1 = x0, rs2 = x0; the only difference from
// the instruction emitted by vx_wmma() is rd = x1 instead of x0.
// NOTE(review): presumably rd = x1 selects the second accumulator register set
// (f24-f31) in the tensor unit -- confirm against the hardware decode logic.
// NOTE(review): the asm statement declares no outputs or clobbers, so the
// compiler is not told about any registers the instruction may modify -- verify
// this is intended.
inline void vx_wmma_new() {
|
||||||
|
asm volatile (".insn r %0, 0, 0, x1, x0, x0" :: "i"(RISCV_CUSTOM3));
|
||||||
|
}
|
||||||
|
|
||||||
#include "test_data.h"
|
#include "test_data.h"
|
||||||
|
|
||||||
inline constexpr void map_operand_32lanes(const int tid, int &row, int &col) {
|
inline constexpr void map_operand_32lanes(const int tid, int &row, int &col) {
|
||||||
@@ -122,6 +126,14 @@ void vx_wmma_load() {
|
|||||||
asm volatile ("flw f21, %0" :: "m"(C[row+0][col+5]));
|
asm volatile ("flw f21, %0" :: "m"(C[row+0][col+5]));
|
||||||
asm volatile ("flw f22, %0" :: "m"(C[row+2][col+4]));
|
asm volatile ("flw f22, %0" :: "m"(C[row+2][col+4]));
|
||||||
asm volatile ("flw f23, %0" :: "m"(C[row+2][col+5]));
|
asm volatile ("flw f23, %0" :: "m"(C[row+2][col+5]));
|
||||||
|
asm volatile ("flw f24, %0" :: "m"(C[row+0][col+0]));
|
||||||
|
asm volatile ("flw f25, %0" :: "m"(C[row+0][col+1]));
|
||||||
|
asm volatile ("flw f26, %0" :: "m"(C[row+2][col+0]));
|
||||||
|
asm volatile ("flw f27, %0" :: "m"(C[row+2][col+1]));
|
||||||
|
asm volatile ("flw f28, %0" :: "m"(C[row+0][col+4]));
|
||||||
|
asm volatile ("flw f29, %0" :: "m"(C[row+0][col+5]));
|
||||||
|
asm volatile ("flw f30, %0" :: "m"(C[row+2][col+4]));
|
||||||
|
asm volatile ("flw f31, %0" :: "m"(C[row+2][col+5]));
|
||||||
}
|
}
|
||||||
|
|
||||||
// float results[32*8];
|
// float results[32*8];
|
||||||
@@ -149,14 +161,22 @@ void store_wmma_result() {
|
|||||||
|
|
||||||
float *const results_wid = results + (DIM_M * DIM_M * wid);
|
float *const results_wid = results + (DIM_M * DIM_M * wid);
|
||||||
|
|
||||||
asm volatile("fsw f16, %0" ::"m"(results_wid[DIM_M * (row + 0) + (col + 0)]));
|
// asm volatile("fsw f16, %0" ::"m"(results_wid[DIM_M * (row + 0) + (col + 0)]));
|
||||||
asm volatile("fsw f17, %0" ::"m"(results_wid[DIM_M * (row + 0) + (col + 1)]));
|
// asm volatile("fsw f17, %0" ::"m"(results_wid[DIM_M * (row + 0) + (col + 1)]));
|
||||||
asm volatile("fsw f18, %0" ::"m"(results_wid[DIM_M * (row + 2) + (col + 0)]));
|
// asm volatile("fsw f18, %0" ::"m"(results_wid[DIM_M * (row + 2) + (col + 0)]));
|
||||||
asm volatile("fsw f19, %0" ::"m"(results_wid[DIM_M * (row + 2) + (col + 1)]));
|
// asm volatile("fsw f19, %0" ::"m"(results_wid[DIM_M * (row + 2) + (col + 1)]));
|
||||||
asm volatile("fsw f20, %0" ::"m"(results_wid[DIM_M * (row + 0) + (col + 4)]));
|
// asm volatile("fsw f20, %0" ::"m"(results_wid[DIM_M * (row + 0) + (col + 4)]));
|
||||||
asm volatile("fsw f21, %0" ::"m"(results_wid[DIM_M * (row + 0) + (col + 5)]));
|
// asm volatile("fsw f21, %0" ::"m"(results_wid[DIM_M * (row + 0) + (col + 5)]));
|
||||||
asm volatile("fsw f22, %0" ::"m"(results_wid[DIM_M * (row + 2) + (col + 4)]));
|
// asm volatile("fsw f22, %0" ::"m"(results_wid[DIM_M * (row + 2) + (col + 4)]));
|
||||||
asm volatile("fsw f23, %0" ::"m"(results_wid[DIM_M * (row + 2) + (col + 5)]));
|
// asm volatile("fsw f23, %0" ::"m"(results_wid[DIM_M * (row + 2) + (col + 5)]));
|
||||||
|
asm volatile("fsw f24, %0" ::"m"(results_wid[DIM_M * (row + 0) + (col + 0)]));
|
||||||
|
asm volatile("fsw f25, %0" ::"m"(results_wid[DIM_M * (row + 0) + (col + 1)]));
|
||||||
|
asm volatile("fsw f26, %0" ::"m"(results_wid[DIM_M * (row + 2) + (col + 0)]));
|
||||||
|
asm volatile("fsw f27, %0" ::"m"(results_wid[DIM_M * (row + 2) + (col + 1)]));
|
||||||
|
asm volatile("fsw f28, %0" ::"m"(results_wid[DIM_M * (row + 0) + (col + 4)]));
|
||||||
|
asm volatile("fsw f29, %0" ::"m"(results_wid[DIM_M * (row + 0) + (col + 5)]));
|
||||||
|
asm volatile("fsw f30, %0" ::"m"(results_wid[DIM_M * (row + 2) + (col + 4)]));
|
||||||
|
asm volatile("fsw f31, %0" ::"m"(results_wid[DIM_M * (row + 2) + (col + 5)]));
|
||||||
}
|
}
|
||||||
|
|
||||||
void print_wmma_result() {
|
void print_wmma_result() {
|
||||||
@@ -184,7 +204,7 @@ void wmma() {
|
|||||||
// for (int i = 0; i < 100; i++) {
|
// for (int i = 0; i < 100; i++) {
|
||||||
// vx_wmma();
|
// vx_wmma();
|
||||||
// }
|
// }
|
||||||
vx_wmma();
|
vx_wmma_new();
|
||||||
|
|
||||||
store_wmma_result();
|
store_wmma_result();
|
||||||
// print_wmma_result();
|
// print_wmma_result();
|
||||||
|
|||||||
Reference in New Issue
Block a user