打印推理权重精度信息

This commit is contained in:
2026-01-18 11:19:10 +08:00
parent c86c2be5ff
commit 7b499284bf
9 changed files with 256 additions and 143 deletions

32
useful.sh Normal file
View File

@@ -0,0 +1,32 @@
python3 psnr_score_for_challenge.py --gt_video unitree_g1_pack_camera/case1/unitree_g1_pack_camera_case1.mp4 --pred_video unitree_g1_pack_camera/case1/output/inference/0_full_fs6.mp4 --output_file unitree_g1_pack_camera/case1/psnr_result.json
采样阶段synthesis/ddim_sampling+ UNet 前向几乎吃掉全部时间,外加明显的 CPU 侧 aten::to/_to_copy 和 aten::copy_ 开销;整体优化优先级还是“减少采样步数 + 加速每步前向 + 降低无谓拷贝”。下
面是更针对这份 profile 的思路:
- 优先级1减少采样步数/换更快采样器
- 把 DDIM 30 步降到 1020 步,或改用 DPM-Solver++/UniPC这往往是 1.53× 的最直接收益。采样逻辑在 src/unifolm_wma/models/samplers/ddim.py入口在 scripts/evaluation/
world_model_interaction.py。
- 如允许训练侧投入可做蒸馏LCM/Consistency让 48 步也可用。
- 优先级1每步前向加速编译 + AMP + TF32
- torch.compile 只包 diffusion_model见 scripts/evaluation/world_model_interaction.py你文档里也已写
- 推理包一层 autocastfp16/bf16并开启 TF32
torch.backends.cuda.matmul.allow_tf32 = True + torch.set_float32_matmul_precision("high")
- torch.backends.cudnn.benchmark = True 在固定 shape 下很有效。
- 优先级1消除循环内的 to()/copy/clone
- profile 里 aten::_to_copy CPU 时间很高,建议逐步排查是否在 DDIM loop 或条件准备里重复 .to(device) / .float() / .half()
- 把常量timesteps/sigmas/ts 等)提前放 GPU避免每步创建避免不必要的 clone()
- 重点排查 src/unifolm_wma/models/samplers/ddim.py 与 scripts/evaluation/world_model_interaction.py 的数据准备段。
- 优先级2注意力实现确认
- attention 只占 3% 左右,但如果没启用 xformers/SDPA仍有收益空间。检查 src/unifolm_wma/modules/attention.py 的 XFORMERS_IS_AVAILBLE。
- 无 xformers 时可改用 scaled_dot_product_attentionFlash Attention 路径)。
- 优先级2VAE 解码 & 保存 I/O
- synthesis/decode_first_stage 仍是秒级,建议 autocast + 可能的 torch.compile。位置在 src/unifolm_wma/models/autoencoder.py。
- save_results 约 38s如果只是评测考虑降低保存频率/分辨率或异步写盘。
- 优先级3结构性减负
- 降低 temporal_length、输入分辨率或 model_channels 会线性降低 compute配置在 configs/inference/world_model_interaction.yaml
- 如果 action_generation 与 world_model_interaction 共享条件,可以缓存 CLIP/VAE 编码避免重复计算model_architecture_analysis.md 的条件编码流程已说明)。
如果你希望我直接落地改动,推荐顺序:
1. torch.compile + AMP + TF32 + cudnn.benchmark
2. 排查 .to()/copy/clone 的重复位置并移出循环
3. 若需要更大幅度,再换采样器/降步数