1. einsum('b i d, b j d -> b i j') → torch.bmm(q, k.transpose(-1,-2)) — 直接映射 rocBLAS batched GEMM
2. baddbmm 把 scale 融合进 GEMM,少一次 kernel launch 3. 第二个 einsum 同理换torch.bm 每一轮加速1到两秒
This commit is contained in:
@@ -118,4 +118,100 @@ SUMMARY
|
||||
Matmul throughput: 33.11 TFLOPS/s (54.3% of BF16 peak)
|
||||
Overall throughput: 18.87 TFLOPS/s (30.9% of BF16 peak)
|
||||
GPU peak (BF16): 61.0 TFLOPS
|
||||
(unifolm-wma) ASC1637@wx-ms-w7900d-0033:/mnt/ASC1637/unifolm-world-model-action$
|
||||
(unifolm-wma) ASC1637@wx-ms-w7900d-0033:/mnt/ASC1637/unifolm-world-model-action$
|
||||
|
||||
========================================================================
|
||||
TABLE 1: STAGE TIMING
|
||||
========================================================================
|
||||
Stage Mean(ms) Std %
|
||||
------------------------------------------------------------------------
|
||||
1_Image_Embedding 29.5 0.16 0.1%
|
||||
2_VAE_Encode 51.3 0.06 0.1%
|
||||
3_Text_Conditioning 14.7 0.18 0.0%
|
||||
4_Projectors 0.2 0.03 0.0%
|
||||
5_DDIM_Loop 33392.5 3.21 97.3%
|
||||
6_VAE_Decode 808.4 1.00 2.4%
|
||||
7_Post_Process 15.8 0.56 0.0%
|
||||
------------------------------------------------------------------------
|
||||
TOTAL 34312.4
|
||||
|
||||
================================================================================
|
||||
TABLE 2: UNET SUB-MODULE BREAKDOWN
|
||||
================================================================================
|
||||
Module Type Total(ms) Count Per-call %
|
||||
--------------------------------------------------------------------------------
|
||||
ResBlock 10256.3 1100 9.32 23.2%
|
||||
SpatialTransformer 9228.2 800 11.54 20.9%
|
||||
CrossAttention 8105.8 3300 2.46 18.3%
|
||||
ConditionalUnet1D 6409.5 100 64.10 14.5%
|
||||
TemporalTransformer 5847.0 850 6.88 13.2%
|
||||
FeedForward 4338.1 1650 2.63 9.8%
|
||||
UNet.out 73.8 50 1.48 0.2%
|
||||
--------------------------------------------------------------------------------
|
||||
TOTAL (hooked) 44258.7
|
||||
|
||||
==========================================================================================
|
||||
TABLE 2b: PER-UNET-BLOCK TIMING (coarse modules, per DDIM loop)
|
||||
==========================================================================================
|
||||
Block Total(ms) % Breakdown
|
||||
------------------------------------------------------------------------------------------
|
||||
input_blocks.1 3376.2 7.6% SpatialTransformer=1101, CrossAttention=990, ResBlock=543, TemporalTransformer=454, FeedForward=288
|
||||
input_blocks.2 3374.0 7.6% SpatialTransformer=1100, CrossAttention=991, ResBlock=540, TemporalTransformer=455, FeedForward=288
|
||||
input_blocks.4 1592.4 3.6% SpatialTransformer=394, ResBlock=374, CrossAttention=303, TemporalTransformer=272, FeedForward=249
|
||||
input_blocks.5 1642.5 3.7% ResBlock=425, SpatialTransformer=397, CrossAttention=303, TemporalTransformer=271, FeedForward=247
|
||||
input_blocks.7 1469.0 3.3% ResBlock=416, SpatialTransformer=324, FeedForward=251, CrossAttention=240, TemporalTransformer=237
|
||||
input_blocks.8 1543.7 3.5% ResBlock=491, SpatialTransformer=325, FeedForward=250, CrossAttention=240, TemporalTransformer=238
|
||||
input_blocks.10 217.5 0.5% ResBlock=218
|
||||
input_blocks.11 216.8 0.5% ResBlock=217
|
||||
middle_block 848.9 1.9% ResBlock=434, SpatialTransformer=151, CrossAttention=134, TemporalTransformer=69, FeedForward=61
|
||||
output_blocks.0 303.2 0.7% ResBlock=303
|
||||
output_blocks.1 303.1 0.7% ResBlock=303
|
||||
output_blocks.2 302.8 0.7% ResBlock=303
|
||||
output_blocks.3 1734.8 3.9% ResBlock=687, SpatialTransformer=322, FeedForward=249, CrossAttention=239, TemporalTransformer=237
|
||||
output_blocks.4 1739.8 3.9% ResBlock=688, SpatialTransformer=323, FeedForward=251, CrossAttention=239, TemporalTransformer=238
|
||||
output_blocks.5 1622.3 3.7% ResBlock=570, SpatialTransformer=324, FeedForward=251, CrossAttention=239, TemporalTransformer=238
|
||||
output_blocks.6 1881.0 4.3% ResBlock=664, SpatialTransformer=393, CrossAttention=301, TemporalTransformer=272, FeedForward=250
|
||||
output_blocks.7 1768.0 4.0% ResBlock=554, SpatialTransformer=393, CrossAttention=301, TemporalTransformer=272, FeedForward=249
|
||||
output_blocks.8 1688.7 3.8% ResBlock=474, SpatialTransformer=393, CrossAttention=301, TemporalTransformer=272, FeedForward=249
|
||||
output_blocks.9 3558.6 8.0% SpatialTransformer=1096, CrossAttention=992, ResBlock=727, TemporalTransformer=454, FeedForward=290
|
||||
output_blocks.10 3492.8 7.9% SpatialTransformer=1096, CrossAttention=992, ResBlock=662, TemporalTransformer=454, FeedForward=289
|
||||
output_blocks.11 3493.3 7.9% SpatialTransformer=1096, CrossAttention=992, ResBlock=662, TemporalTransformer=454, FeedForward=289
|
||||
out 73.8 0.2% UNet.out=74
|
||||
action_unet 3212.0 7.3% ConditionalUnet1D=3212
|
||||
state_unet 3197.6 7.2% ConditionalUnet1D=3198
|
||||
other 1606.2 3.6% TemporalTransformer=960, FeedForward=337, CrossAttention=309
|
||||
------------------------------------------------------------------------------------------
|
||||
TOTAL 44258.7
|
||||
|
||||
======================================================================
|
||||
TABLE 2c: ATTENTION vs FEEDFORWARD (deep hooks)
|
||||
======================================================================
|
||||
Component Total(ms) %
|
||||
----------------------------------------------------------------------
|
||||
CrossAttention 8105.8 65.1%
|
||||
FeedForward 4338.1 34.9%
|
||||
----------------------------------------------------------------------
|
||||
TOTAL (attn+ff) 12443.9
|
||||
|
||||
==================================================
|
||||
TABLE 3: MEMORY SUMMARY
|
||||
==================================================
|
||||
Initial allocated: 11.82 GB
|
||||
Peak allocated: 14.43 GB
|
||||
Delta (pipeline): 2.61 GB
|
||||
|
||||
============================================================
|
||||
TABLE 4: THROUGHPUT
|
||||
============================================================
|
||||
Total pipeline latency: 34312.4 ms
|
||||
DDIM loop latency: 33392.5 ms
|
||||
DDIM steps: 50
|
||||
CFG scale: 1.0 (1x UNet/step)
|
||||
UNet forward calls: 50
|
||||
Per DDIM step: 667.9 ms
|
||||
Per UNet forward: 667.9 ms
|
||||
VAE encode bandwidth: 0.1 GB/s (peak HBM: 864.0 GB/s)
|
||||
VAE decode bandwidth: 0.0 GB/s (peak HBM: 864.0 GB/s)
|
||||
GPU BF16 peak: 61.0 TFLOPS
|
||||
|
||||
Done.
|
||||
Reference in New Issue
Block a user