case20_flash_bwd_fused
FlashAttention backward-style fused pipeline smoke test.
The tensor warp performs one score MMA, then waits for the scalar warp to run softmax plus dsoftmax on the TMEM C row. The scalar warp writes the dS row back to TMEM A using signed fp16 values. The tensor warp then performs four more MMA steps, for five MMA operations total in this case.
This case verifies:
- tensor warp MMA sequencing around a scalar TMEM handoff;
- scalar-only FEXP.S use for stable softmax;
- dsoftmax shape dS = P * (dP - sum(P * dP));
- signed scalar TMEM stores feeding later tensor MMA operations.
The input score row is uniform, so P = 1/32. The synthetic upstream gradient
uses two buckets, producing exact dS values -1/32 for row entries 0..15 and
+1/32 for row entries 16..31.
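
The expected values above can be reproduced on the host with a short reference calculation. The following is a minimal sketch in plain Python, not the test's actual harness. It makes several assumptions that the description does not state: the uniform score value (0.0), the two bucket values for the upstream gradient (-1 for entries 0..15, +1 for entries 16..31), and the helper names. `math.exp` stands in for the FEXP.S instruction, and the `struct` format `"e"` is used to show that both dS values survive a round trip through IEEE fp16 unchanged.

```python
import math
import struct

ROW = 32  # row length implied by P = 1/32


def stable_softmax(scores):
    """Softmax with max subtraction, so exp() never sees a positive argument."""
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def dsoftmax(p, dp):
    """dS = P * (dP - sum(P * dP)), applied elementwise over one row."""
    dot = sum(pi * dpi for pi, dpi in zip(p, dp))
    return [pi * (dpi - dot) for pi, dpi in zip(p, dp)]


def to_fp16_bits(x):
    """Encode a Python float as an IEEE 754 half-precision bit pattern."""
    return struct.unpack("<H", struct.pack("<e", x))[0]


def from_fp16_bits(bits):
    """Decode an IEEE 754 half-precision bit pattern back to a float."""
    return struct.unpack("<e", struct.pack("<H", bits))[0]


# Assumption: any constant score gives the same uniform P; 0.0 is used here.
scores = [0.0] * ROW
# Assumption: bucket values -1 and +1; the source only says "two buckets".
dp = [-1.0] * 16 + [1.0] * 16

p = stable_softmax(scores)
ds = dsoftmax(p, dp)
ds_bits = [to_fp16_bits(v) for v in ds]

print(p[0], ds[0], ds[31])                # 0.03125 -0.03125 0.03125
print(hex(ds_bits[0]), hex(ds_bits[31]))  # 0xa800 0x2800
```

With these assumed inputs sum(P * dP) is exactly zero, so dS reduces to P * dP. Because 1/32 is a power of two, every intermediate value is exact in binary floating point, and the signed fp16 patterns 0xA800 and 0x2800 decode back to -1/32 and +1/32 without rounding.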