feat: add flash pipeline kernel cases
kernels/wu_arch_cases/case20_flash_bwd_fused/README.md
@@ -0,0 +1,19 @@
# case20_flash_bwd_fused

FlashAttention backward-style fused pipeline smoke test.

The tensor warp performs one score MMA, then waits for the scalar warp to run
softmax plus dsoftmax on the TMEM C row. The scalar warp writes the dS row back
to TMEM A using signed fp16 values. The tensor warp then performs four more
MMA steps, for five MMA operations total in this case.
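
The ordering described above can be pictured with a small host-side toy model. This is only a sketch: the two Python threads stand in for the tensor and scalar warps, and the names `run_case`, `score_ready`, and `ds_ready` are hypothetical. The README does not describe the actual synchronization primitives the kernel uses.

```python
import threading


def run_case(extra_mma_steps=4):
    """Toy model of the tensor/scalar handoff; returns the ordered event log."""
    log = []
    score_ready = threading.Event()  # hypothetical: score row is in TMEM C
    ds_ready = threading.Event()     # hypothetical: dS row is in TMEM A

    def tensor_warp():
        log.append("tensor: score MMA")
        score_ready.set()
        ds_ready.wait()              # block until the scalar warp hands back dS
        for i in range(extra_mma_steps):
            log.append(f"tensor: MMA {i + 2}")

    def scalar_warp():
        score_ready.wait()           # block until the score MMA has finished
        log.append("scalar: softmax + dsoftmax on TMEM C row")
        log.append("scalar: store dS row to TMEM A")
        ds_ready.set()

    threads = [threading.Thread(target=f) for f in (scalar_warp, tensor_warp)]
    for t in threads:
        t.start()
    for t in threads:
        t.join()
    return log


if __name__ == "__main__":
    for event in run_case():
        print(event)
```

Because each side blocks on the other's event, the log always shows one score MMA, the two scalar steps, and then four further MMA steps.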

This case verifies:

- tensor warp MMA sequencing around a scalar TMEM handoff;
- scalar-only `FEXP.S` use for stable softmax;
- dsoftmax shape `dS = P * (dP - sum(P * dP))`;
- signed scalar TMEM stores feeding later tensor MMA operations.
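
As a reference for the softmax and dsoftmax items above, here is a minimal host-side sketch in plain Python. It assumes the usual max-subtraction form of stable softmax; `math.exp` merely stands in for `FEXP.S`, and the function names are illustrative rather than taken from the kernel.

```python
import math


def stable_softmax(scores):
    """Softmax with the row maximum subtracted so exp() cannot overflow."""
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def dsoftmax(p, dp):
    """Row-wise softmax backward: dS = P * (dP - sum(P * dP))."""
    dot = sum(pi * dpi for pi, dpi in zip(p, dp))
    return [pi * (dpi - dot) for pi, dpi in zip(p, dp)]


if __name__ == "__main__":
    p = stable_softmax([1000.0, 1001.0, 1002.0])  # large inputs, no overflow
    ds = dsoftmax(p, [0.5, -1.0, 2.0])
    print(p, ds)
```

Since `P` sums to 1, the entries of `dS` sum to zero up to rounding, which gives a quick sanity check on any dumped row.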

The input score row is uniform, so `P = 1/32`. The synthetic upstream gradient
uses two buckets, producing exact dS values `-1/32` for row entries 0..15 and
`+1/32` for row entries 16..31.
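
The expected row can be reproduced on the host. The README does not give the bucket values, so the sketch below assumes `dP = -1` for entries 0..15 and `dP = +1` for entries 16..31, which is one choice consistent with the stated result. The fp16 round trip uses the `e` (half-precision) format from Python's `struct` module.

```python
import struct


def to_fp16(x):
    """Round-trip a Python float through IEEE 754 half precision."""
    return struct.unpack("<e", struct.pack("<e", x))[0]


def expected_ds_row(n=32):
    """Expected dS row for a uniform score row and a two-bucket gradient."""
    p = [1.0 / n] * n                                # uniform scores: P = 1/32
    dp = [-1.0] * (n // 2) + [1.0] * (n // 2)        # ASSUMED bucket values
    dot = sum(pi * dpi for pi, dpi in zip(p, dp))    # the two halves cancel
    return [pi * (dpi - dot) for pi, dpi in zip(p, dp)]


if __name__ == "__main__":
    ds = expected_ds_row()
    # 1/32 is a power of two, so the fp16 conversion loses nothing.
    print(ds[0], ds[16], [to_fp16(v) for v in ds] == ds)
```

With these assumed inputs every intermediate value is a small multiple of a power of two, so the result is exact rather than merely close, matching the README's claim of exact dS values.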