diff --git a/tests/regression/flash_attention/flash_impl.hpp b/tests/regression/flash_attention/flash_impl.hpp index 93dc3cc9..47e21c70 100644 --- a/tests/regression/flash_attention/flash_impl.hpp +++ b/tests/regression/flash_attention/flash_impl.hpp @@ -11,8 +11,8 @@ #define ROW_REMAINDER_LOGIC constexpr uint32_t ROWMAX_SETS = 3; -constexpr bool WARP_SPECIALIZED = false; -constexpr bool TENSOR_CORE = false; +constexpr bool WARP_SPECIALIZED = true; +constexpr bool TENSOR_CORE = true; // temporary safety stop for wrong configs static_assert(NUM_CORES == 4); diff --git a/tests/regression/flash_attention/kernel.cpp b/tests/regression/flash_attention/kernel.cpp index 1d88b4de..3c2d463c 100644 --- a/tests/regression/flash_attention/kernel.cpp +++ b/tests/regression/flash_attention/kernel.cpp @@ -11,6 +11,9 @@ constexpr bool DEBUG = false; constexpr bool Q_IS_K_MAJOR = true; +// temporary safety stop +static_assert(TENSOR_CORE && WARP_SPECIALIZED); + void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { // @perf: All threads are running these compute whose result is mostly same // across the threadblock @@ -90,80 +93,78 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { "flashattention kernel assumes 1 threadblock occupancy per cluster"); uint8_t *smem_per_threadblock = reinterpret_cast( DEV_SMEM_START_ADDR); - float *smem_cursor = reinterpret_cast(smem_per_threadblock); - // constexpr uint32_t DEV_FAKE_SMEM_START_ADDR = 0xf0000000; - // float *smem_cursor = reinterpret_cast(DEV_FAKE_SMEM_START_ADDR); - float *smem_Q0 = smem_cursor; - smem_cursor += smem_Q_size; - float *smem_Q1 = smem_cursor; - smem_cursor += smem_Q_size; - float *smem_K0 = smem_cursor; - smem_cursor += smem_K_size; - float *smem_K1 = smem_cursor; - smem_cursor += smem_K_size; - float *smem_V0 = smem_cursor; - smem_cursor += smem_V_size; - float *smem_V1 = smem_cursor; - smem_cursor += smem_V_size; - float *smem_S0 = smem_cursor; - smem_cursor += smem_QK_size; - float *smem_S1 = smem_cursor; - smem_cursor += smem_QK_size; - float *smem_P0 = smem_S0; // in-place update - float *smem_P1 = smem_S1; // in-place update - float *smem_O0 = smem_cursor; - smem_cursor += smem_O_size; - float *smem_O1 = smem_cursor; - smem_cursor += smem_O_size; + constexpr uint32_t smem_start = DEV_SMEM_START_ADDR; + constexpr uint32_t smem_octet0 = 0 * (SMEM_SIZE / 8); + constexpr uint32_t smem_octet1 = 1 * (SMEM_SIZE / 8); + constexpr uint32_t smem_octet2 = 2 * (SMEM_SIZE / 8); + constexpr uint32_t smem_octet3 = 3 * (SMEM_SIZE / 8); + constexpr uint32_t smem_octet4 = 4 * (SMEM_SIZE / 8); + constexpr uint32_t smem_octet5 = 5 * (SMEM_SIZE / 8); + constexpr uint32_t smem_octet6 = 6 * (SMEM_SIZE / 8); + constexpr uint32_t smem_octet7 = 7 * (SMEM_SIZE / 8); - // NOTE: this has to match with smem_* - static_assert(sizeof(elem_t) == sizeof(float)); - constexpr uint32_t spad_addr_factor = DIM * sizeof(elem_t); - constexpr uint32_t spad_addr_Q0 = 0; - constexpr uint32_t spad_addr_Q1 = - spad_addr_Q0 + (smem_Q_size * sizeof(float) / spad_addr_factor); - constexpr uint32_t spad_addr_K0 = - spad_addr_Q1 + (smem_Q_size * sizeof(float) / spad_addr_factor); - constexpr uint32_t spad_addr_K1 = - spad_addr_K0 + (smem_K_size * sizeof(float) / spad_addr_factor); - constexpr uint32_t spad_addr_V0 = - spad_addr_K1 + (smem_K_size * sizeof(float) / spad_addr_factor); - constexpr uint32_t spad_addr_V1 = - spad_addr_V0 + (smem_V_size * sizeof(float) / spad_addr_factor); - constexpr uint32_t spad_addr_S0 = - spad_addr_V1 + (smem_V_size * sizeof(float) / spad_addr_factor); - constexpr uint32_t spad_addr_S1 = - spad_addr_S0 + (smem_QK_size * sizeof(float) / spad_addr_factor); + // allocation strategy: since the two warpgroups only access *0 and *1 + // buffers each, allocate *0 in the first half of SMEM, and *1 in the latter + // half + // at the same time, make sure Q and K are in different banks so that they + // can be accessed in parallel for GEMM; same for P and V + constexpr uint32_t smem_Q0_offset = smem_octet0; + constexpr uint32_t smem_Q1_offset = smem_octet4; + constexpr uint32_t smem_K0_offset = smem_octet1; + constexpr uint32_t smem_K1_offset = smem_octet5; + constexpr uint32_t smem_V0_offset = smem_K0_offset + smem_K_size * sizeof(float); + constexpr uint32_t smem_V1_offset = smem_K1_offset + smem_K_size * sizeof(float); + constexpr uint32_t smem_S0_offset = smem_octet2; + constexpr uint32_t smem_S1_offset = smem_octet6; + constexpr uint32_t smem_P0_offset = smem_Q0_offset + smem_Q_size * sizeof(float); + constexpr uint32_t smem_P1_offset = smem_Q1_offset + smem_Q_size * sizeof(float); + constexpr uint32_t smem_O0_offset = smem_octet3; + constexpr uint32_t smem_O1_offset = smem_octet7; + + float *smem_Q0 = reinterpret_cast(smem_start + smem_Q0_offset); + float *smem_Q1 = reinterpret_cast(smem_start + smem_Q1_offset); + float *smem_K0 = reinterpret_cast(smem_start + smem_K0_offset); + float *smem_K1 = reinterpret_cast(smem_start + smem_K1_offset); + float *smem_V0 = reinterpret_cast(smem_start + smem_V0_offset); + float *smem_V1 = reinterpret_cast(smem_start + smem_V1_offset); + float *smem_S0 = reinterpret_cast(smem_start + smem_S0_offset); + float *smem_S1 = reinterpret_cast(smem_start + smem_S1_offset); + float *smem_P0 = reinterpret_cast(smem_start + smem_P0_offset); + float *smem_P1 = reinterpret_cast(smem_start + smem_P1_offset); + float *smem_O0 = reinterpret_cast(smem_start + smem_O0_offset); + float *smem_O1 = reinterpret_cast(smem_start + smem_O1_offset); // allocate rowmax/rowsum storage at the end of the sharedmem address space constexpr uint32_t smem_rowmax_size = B_ROW * ROWMAX_SETS; constexpr uint32_t smem_rowsum_size = B_ROW; constexpr uint32_t smem_O_row_scale_size = B_ROW; - // FIXME: dangerous - smem_cursor = reinterpret_cast(0xff038000); - float *smem_rowmax_0 = smem_cursor; - smem_cursor += smem_rowmax_size; - float *smem_rowmax_1 = smem_cursor; - smem_cursor += smem_rowmax_size; - float *smem_rowsum_0 = smem_cursor; - smem_cursor += smem_rowsum_size; - float *smem_rowsum_1 = smem_cursor; - smem_cursor += smem_rowsum_size; - float *smem_O_row_scale_0 = smem_cursor; - smem_cursor += smem_O_row_scale_size; - float *smem_O_row_scale_1 = smem_cursor; - smem_cursor += smem_O_row_scale_size; + float *smem_cursor_0 = smem_O0 + smem_O_size; + float *smem_cursor_1 = smem_O1 + smem_O_size; + // // FIXME: dangerous + // smem_cursor = reinterpret_cast(0xff038000); + float *smem_rowmax_0 = smem_cursor_0; + smem_cursor_0 += smem_rowmax_size; + float *smem_rowmax_1 = smem_cursor_1; + smem_cursor_1 += smem_rowmax_size; + float *smem_rowsum_0 = smem_cursor_0; + smem_cursor_0 += smem_rowsum_size; + float *smem_rowsum_1 = smem_cursor_1; + smem_cursor_1 += smem_rowsum_size; + float *smem_O_row_scale_0 = smem_cursor_0; + smem_cursor_0 += smem_O_row_scale_size; + float *smem_O_row_scale_1 = smem_cursor_1; + smem_cursor_1 += smem_O_row_scale_size; // sharedmem "scratchpad" area to put temporary data, e.g. for tree reduction // in rowsum // NOTE: out-of bounds is not checked constexpr uint32_t smem_scratchpad_size = threads_per_warpgroup * 2 /*arbitrary slack*/; - float *smem_scratchpad_0 = smem_cursor; - smem_cursor += smem_scratchpad_size; - float *smem_scratchpad_1 = smem_cursor; - smem_cursor += smem_scratchpad_size; + float *smem_scratchpad_0 = smem_cursor_0; + smem_cursor_0 += smem_scratchpad_size; + float *smem_scratchpad_1 = smem_cursor_1; + smem_cursor_1 += smem_scratchpad_size; // select the correct buffer by warpgroup float *smem_Q = (warpgroup_id % 2) ? smem_Q1 : smem_Q0; @@ -179,6 +180,21 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { float *smem_scratchpad = (warpgroup_id % 2) ? smem_scratchpad_1 : smem_scratchpad_0; + static_assert(sizeof(elem_t) == sizeof(float)); + constexpr uint32_t spad_addr_factor = DIM * sizeof(elem_t); + constexpr uint32_t spad_addr_Q0 = smem_Q0_offset / spad_addr_factor; + constexpr uint32_t spad_addr_Q1 = smem_Q1_offset / spad_addr_factor; + constexpr uint32_t spad_addr_K0 = smem_K0_offset / spad_addr_factor; + constexpr uint32_t spad_addr_K1 = smem_K1_offset / spad_addr_factor; + constexpr uint32_t spad_addr_V0 = smem_V0_offset / spad_addr_factor; + constexpr uint32_t spad_addr_V1 = smem_V1_offset / spad_addr_factor; + constexpr uint32_t spad_addr_S0 = smem_S0_offset / spad_addr_factor; + constexpr uint32_t spad_addr_S1 = smem_S1_offset / spad_addr_factor; + constexpr uint32_t spad_addr_P0 = smem_P0_offset / spad_addr_factor; + constexpr uint32_t spad_addr_P1 = smem_P1_offset / spad_addr_factor; + constexpr uint32_t spad_addr_O0 = smem_O0_offset / spad_addr_factor; + constexpr uint32_t spad_addr_O1 = smem_O1_offset / spad_addr_factor; + const auto spad_addr_Q = (warpgroup_id % 2) ? spad_addr_Q1 : spad_addr_Q0; const auto spad_addr_K = (warpgroup_id % 2) ? spad_addr_K1 : spad_addr_K0; const auto spad_addr_V = (warpgroup_id % 2) ? spad_addr_V1 : spad_addr_V0;