32 lines
2.9 KiB
Bash
32 lines
2.9 KiB
Bash
python3 psnr_score_for_challenge.py --gt_video unitree_g1_pack_camera/case1/unitree_g1_pack_camera_case1.mp4 --pred_video unitree_g1_pack_camera/case1/output/inference/0_full_fs6.mp4 --output_file unitree_g1_pack_camera/case1/psnr_result.json
|
||
采样阶段(synthesis/ddim_sampling)+ UNet 前向几乎吃掉全部时间,外加明显的 CPU 侧 aten::to/_to_copy 和 aten::copy_ 开销;整体优化优先级还是“减少采样步数 + 加速每步前向 + 降低无谓拷贝”。下
|
||
面是更针对这份 profile 的思路:
|
||
|
||
- 优先级1:减少采样步数/换更快采样器
|
||
- 把 DDIM 30 步降到 10–20 步,或改用 DPM-Solver++/UniPC;这往往是 1.5–3× 的最直接收益。采样逻辑在 src/unifolm_wma/models/samplers/ddim.py,入口在 scripts/evaluation/
|
||
world_model_interaction.py。
|
||
- 如允许训练侧投入,可做蒸馏(LCM/Consistency)让 4–8 步也可用。
|
||
- 优先级1:每步前向加速(编译 + AMP + TF32)
|
||
- torch.compile 只包 diffusion_model,见 scripts/evaluation/world_model_interaction.py(你文档里也已写)。
|
||
- 推理包一层 autocast(fp16/bf16)并开启 TF32:
|
||
torch.backends.cuda.matmul.allow_tf32 = True + torch.set_float32_matmul_precision("high")。
|
||
- torch.backends.cudnn.benchmark = True 在固定 shape 下很有效。
|
||
- 优先级1:消除循环内的 to()/copy/clone
|
||
- profile 里 aten::_to_copy CPU 时间很高,建议逐步排查是否在 DDIM loop 或条件准备里重复 .to(device) / .float() / .half()。
|
||
- 把常量(timesteps/sigmas/ts 等)提前放 GPU,避免每步创建;避免不必要的 clone()。
|
||
- 重点排查 src/unifolm_wma/models/samplers/ddim.py 与 scripts/evaluation/world_model_interaction.py 的数据准备段。
|
||
- 优先级2:注意力实现确认
|
||
- attention 只占 3% 左右,但如果没启用 xformers/SDPA,仍有收益空间。检查 src/unifolm_wma/modules/attention.py 的 XFORMERS_IS_AVAILBLE。
|
||
- 无 xformers 时可改用 scaled_dot_product_attention(Flash Attention 路径)。
|
||
- 优先级2:VAE 解码 & 保存 I/O
|
||
- synthesis/decode_first_stage 仍是秒级,建议 autocast + 可能的 torch.compile。位置在 src/unifolm_wma/models/autoencoder.py。
|
||
- save_results 约 38s:如果只是评测,考虑降低保存频率/分辨率或异步写盘。
|
||
- 优先级3:结构性减负
|
||
- 降低 temporal_length、输入分辨率或 model_channels 会线性降低 compute(配置在 configs/inference/world_model_interaction.yaml)。
|
||
- 如果 action_generation 与 world_model_interaction 共享条件,可以缓存 CLIP/VAE 编码,避免重复计算(model_architecture_analysis.md 的条件编码流程已说明)。
|
||
|
||
如果你希望我直接落地改动,推荐顺序:
|
||
|
||
1. torch.compile + AMP + TF32 + cudnn.benchmark
|
||
2. 排查 .to()/copy/clone 的重复位置并移出循环
|
||
3. 若需要更大幅度,再换采样器/降步数 |