- GroupNorm/LayerNorm bypass autocast, eliminating the bf16→fp32→bf16 conversion overhead
- DDIM schedule coefficients are cast to the input dtype; attention masks are allocated directly in bf16
- alphas_cumprod is promoted to float64 to preserve numerical precision
- SinusoidalPosEmb output dtype follows the model precision
- Add profile_unet.py script and FLOPS analysis results
- Enable TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL
- case1 PSNR: 30.45 → 30.24 (expected fluctuation within bf16 precision)
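The first three bullets can be sketched as below. This is a minimal illustration, not the repository's actual code: `SameDtypeGroupNorm` is a hypothetical name, and the beta schedule values are placeholders. The idea is to disable autocast inside the norm so the kernel runs in the activation dtype (avoiding the bf16→fp32→bf16 round trip), while the noise schedule itself stays in float64 and only the per-step coefficients are cast down at the use site.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class SameDtypeGroupNorm(nn.GroupNorm):
    """Hypothetical sketch: run GroupNorm in the activation dtype instead of
    letting autocast upcast to fp32 (the bf16 -> fp32 -> bf16 round trip)."""

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        with torch.autocast(device_type=x.device.type, enabled=False):
            # Cast the affine parameters to the activation dtype so the
            # whole kernel stays in bf16 when the input is bf16.
            return F.group_norm(
                x, self.num_groups,
                self.weight.to(x.dtype), self.bias.to(x.dtype), self.eps)

# Keep the noise schedule in float64 for precision; cast per-step DDIM
# coefficients to the model dtype only where they are consumed.
# (Placeholder linear schedule -- not the project's actual betas.)
betas = torch.linspace(1e-4, 2e-2, 1000, dtype=torch.float64)
alphas_cumprod = torch.cumprod(1.0 - betas, dim=0)   # stays float64
sqrt_ac = alphas_cumprod.sqrt().to(torch.bfloat16)   # coefficient in model dtype
```

The float64 accumulation matters because `cumprod` over ~1000 steps compounds rounding error; casting only the final coefficient keeps the fast path in bf16.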
TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL=1 CUDA_VISIBLE_DEVICES=0 python scripts/evaluation/profile_unet.py --ckpt_path ckpts/unifolm_wma_dual_mix_bf16.ckpt --config configs/inference/world_model_interaction.yaml
==================================================================================================================================
FLOPS BY ATen OPERATOR (FlopCounterMode)
ATen Op | GFLOPS | % of Total
convolution | 6185.17 | 46.4%
addmm | 4411.17 | 33.1%
mm | 1798.34 | 13.5%
bmm | 949.54 | 7.1%
==================================================================================================================================
FLOPS BY MODULE (FlopCounterMode)
Module | GFLOPS | % of Total
Global | 13344.23 | 100.0%
DiffusionWrapper | 13344.23 | 100.0%
DiffusionWrapper.diffusion_model | 13344.23 | 100.0%
DiffusionWrapper.diffusion_model.output_blocks.8 | 997.87 | 7.5%
DiffusionWrapper.diffusion_model.output_blocks.5 | 992.91 | 7.4%
DiffusionWrapper.diffusion_model.output_blocks.9 | 941.81 | 7.1%
DiffusionWrapper.diffusion_model.output_blocks.10 | 857.93 | 6.4%
DiffusionWrapper.diffusion_model.output_blocks.11 | 857.93 | 6.4%
DiffusionWrapper.diffusion_model.output_blocks.6 | 821.71 | 6.2%
DiffusionWrapper.diffusion_model.input_blocks.1 | 765.65 | 5.7%
DiffusionWrapper.diffusion_model.input_blocks.2 | 765.65 | 5.7%
DiffusionWrapper.diffusion_model.output_blocks.7 | 737.82 | 5.5%
DiffusionWrapper.diffusion_model.output_blocks.3 | 732.87 | 5.5%
DiffusionWrapper.diffusion_model.output_blocks.4 | 732.87 | 5.5%
DiffusionWrapper.diffusion_model.input_blocks.5 | 645.55 | 4.8%
DiffusionWrapper.diffusion_model.input_blocks.8 | 640.59 | 4.8%
DiffusionWrapper.diffusion_model.input_blocks.4 | 611.99 | 4.6%
DiffusionWrapper.diffusion_model.input_blocks.7 | 607.04 | 4.5%
DiffusionWrapper.diffusion_model.init_attn.0 | 459.02 | 3.4%
DiffusionWrapper.diffusion_model.init_attn | 459.02 | 3.4%
DiffusionWrapper.diffusion_model.init_attn.0.transformer_blocks.0 | 432.18 | 3.2%
DiffusionWrapper.diffusion_model.output_blocks.6.0 | 427.85 | 3.2%
DiffusionWrapper.diffusion_model.output_blocks.9.0 | 427.83 | 3.2%
DiffusionWrapper.diffusion_model.output_blocks.3.0 | 343.99 | 2.6%
DiffusionWrapper.diffusion_model.output_blocks.4.0 | 343.99 | 2.6%
DiffusionWrapper.diffusion_model.output_blocks.7.0 | 343.96 | 2.6%
DiffusionWrapper.diffusion_model.output_blocks.10.0 | 343.95 | 2.6%
DiffusionWrapper.diffusion_model.output_blocks.11.0 | 343.95 | 2.6%
DiffusionWrapper.diffusion_model.input_blocks.1.1 | 327.75 | 2.5%
DiffusionWrapper.diffusion_model.input_blocks.2.1 | 327.75 | 2.5%
==================================================================================================================================
SUMMARY
Total CUDA time: 761.4 ms
Matmul CUDA time: 404.2 ms (53.1%)
Non-matmul CUDA time: 357.1 ms (46.9%)
Total FLOPS (FlopCounter): 13344.23 GFLOPS
Matmul throughput: 33.01 TFLOPS/s (54.1% of BF16 peak)
Overall throughput: 17.53 TFLOPS/s (28.7% of BF16 peak)
GPU peak (BF16): 61.0 TFLOPS
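The derived lines in the summary follow directly from the counted FLOPs and measured times. A quick sanity check of the arithmetic, using only the figures printed above (note the script evidently divides *total* FLOPs by matmul-only CUDA time for the "Matmul throughput" line, treating non-matmul time as overhead — an inference from the numbers, not from the script's source):

```python
# Reproduce the SUMMARY arithmetic from the printed figures (first run).
total_gflops = 13344.23          # FlopCounterMode total
matmul_ms, total_ms = 404.2, 761.4
peak_tflops = 61.0               # BF16 peak as reported in the log

# GFLOPS per millisecond is numerically equal to TFLOPS per second.
matmul_tput = total_gflops / matmul_ms    # -> ~33.01 TFLOPS/s
overall_tput = total_gflops / total_ms    # -> ~17.53 TFLOPS/s

print(f"{matmul_tput:.2f} TFLOPS/s ({100 * matmul_tput / peak_tflops:.1f}% of peak)")
print(f"{overall_tput:.2f} TFLOPS/s ({100 * overall_tput / peak_tflops:.1f}% of peak)")
```

Both printed lines match the summary (33.01 TFLOPS/s at 54.1% of peak; 17.53 TFLOPS/s at 28.7%), confirming the non-matmul 46.9% of wall time is what drags overall utilization below the matmul-only figure.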
==================================================================================================================================
SUMMARY
Total CUDA time: 707.1 ms
Matmul CUDA time: 403.1 ms (57.0%)
Non-matmul CUDA time: 304.0 ms (43.0%)
Total FLOPS (FlopCounter): 13344.23 GFLOPS
Matmul throughput: 33.11 TFLOPS/s (54.3% of BF16 peak)
Overall throughput: 18.87 TFLOPS/s (30.9% of BF16 peak)
GPU peak (BF16): 61.0 TFLOPS