161 lines
8.5 KiB
Bash
161 lines
8.5 KiB
Bash
python3 psnr_score_for_challenge.py --gt_video unitree_g1_pack_camera/case1/unitree_g1_pack_camera_case1.mp4 --pred_video unitree_g1_pack_camera/case1/output/inference/0_full_fs6.mp4 --output_file unitree_g1_pack_camera/case1/psnr_result.json
|
||
采样阶段(synthesis/ddim_sampling)+ UNet 前向几乎吃掉全部时间,外加明显的 CPU 侧 aten::to/_to_copy 和 aten::copy_ 开销;整体优化优先级还是“减少采样步数 + 加速每步前向 + 降低无谓拷贝”。下
|
||
面是更针对这份 profile 的思路:
|
||
|
||
- 优先级1:减少采样步数/换更快采样器
|
||
- 把 DDIM 30 步降到 10–20 步,或改用 DPM-Solver++/UniPC;这往往是 1.5–3× 的最直接收益。采样逻辑在 src/unifolm_wma/models/samplers/ddim.py,入口在 scripts/evaluation/
|
||
world_model_interaction.py。
|
||
- 如允许训练侧投入,可做蒸馏(LCM/Consistency)让 4–8 步也可用。
|
||
- 优先级1:每步前向加速(编译 + AMP + TF32)
|
||
- torch.compile 只包 diffusion_model,见 scripts/evaluation/world_model_interaction.py(你文档里也已写)。
|
||
- 推理包一层 autocast(fp16/bf16)并开启 TF32:
|
||
torch.backends.cuda.matmul.allow_tf32 = True + torch.set_float32_matmul_precision("high")。
|
||
- torch.backends.cudnn.benchmark = True 在固定 shape 下很有效。
|
||
- 优先级1:消除循环内的 to()/copy/clone
|
||
- profile 里 aten::_to_copy CPU 时间很高,建议逐步排查是否在 DDIM loop 或条件准备里重复 .to(device) / .float() / .half()。
|
||
- 把常量(timesteps/sigmas/ts 等)提前放 GPU,避免每步创建;避免不必要的 clone()。
|
||
- 重点排查 src/unifolm_wma/models/samplers/ddim.py 与 scripts/evaluation/world_model_interaction.py 的数据准备段。
|
||
- 优先级2:注意力实现确认
|
||
- attention 只占 3% 左右,但如果没启用 xformers/SDPA,仍有收益空间。检查 src/unifolm_wma/modules/attention.py 的 XFORMERS_IS_AVAILBLE。
|
||
- 无 xformers 时可改用 scaled_dot_product_attention(Flash Attention 路径)。
|
||
- 优先级2:VAE 解码 & 保存 I/O
|
||
- synthesis/decode_first_stage 仍是秒级,建议 autocast + 可能的 torch.compile。位置在 src/unifolm_wma/models/autoencoder.py。
|
||
- save_results 约 38s:如果只是评测,考虑降低保存频率/分辨率或异步写盘。
|
||
- 优先级3:结构性减负
|
||
- 降低 temporal_length、输入分辨率或 model_channels 会线性降低 compute(配置在 configs/inference/world_model_interaction.yaml)。
|
||
- 如果 action_generation 与 world_model_interaction 共享条件,可以缓存 CLIP/VAE 编码,避免重复计算(model_architecture_analysis.md 的条件编码流程已说明)。
|
||
|
||
如果你希望我直接落地改动,推荐顺序:
|
||
|
||
1. torch.compile + AMP + TF32 + cudnn.benchmark
|
||
2. 排查 .to()/copy/clone 的重复位置并移出循环
|
||
3. 若需要更大幅度,再换采样器/降步数
|
||
|
||
|
||
|
||
|
||
|
||
A100 上我推荐 BF16 优先(稳定性更好、PSNR 更稳),FP16 作为速度优先方案。
|
||
|
||
下面是“分模块”的 消融方案(从稳到激进):
|
||
|
||
0)基线
|
||
|
||
- 全 FP32(你现在就是这个)
|
||
|
||
1)只对扩散主干做 BF16(最推荐)
|
||
|
||
- 量化对象:model.model(扩散 UNet/WMAModel 主体)
|
||
- 保持 FP32:first_stage_model(VAE 编/解码)、cond_stage_model(文本)、embedder(图像)、image_proj_model
|
||
- 预期:PSNR 基本不掉 or 极小波动
|
||
|
||
2)+ 轻量投影/MLP 做 BF16
|
||
|
||
- 增加:image_proj_model、state_projector、action_projector
|
||
- 预期:几乎不影响 PSNR
|
||
|
||
3)+ 文本/图像编码做 BF16
|
||
|
||
- 增加:cond_stage_model、embedder
|
||
- 预期:可能有轻微波动,通常仍可接受
|
||
|
||
4)VAE 也做 BF16(最容易伤 PSNR)
|
||
|
||
- 增加:first_stage_model
|
||
- 预期:画质/PSNR 最敏感,建议最后做消融
|
||
|
||
———
|
||
|
||
具体建议(A100)
|
||
|
||
- 优先 BF16:稳定性好于 FP16
|
||
- 只做半精度,不做 INT 量化:保持 PSNR
|
||
- VAE 尽量 FP32:最影响画质的模块
|
||
|
||
BF16 projector比FP32 projector更准的可能原因:
|
||
- 数值路径更一致:主干在 BF16 下做 attention/MLP,projector 若是 FP32 会在进入主干前被 downcast,导致“先高精度非线性→再截断”的分布偏移;直接 BF16 算 projector 反而让输出
|
||
分布更贴近主干的计算习惯。
|
||
- 训练分布匹配:训练时你是 precision:16,projector 长期在低精度环境下被优化;推理用 FP32 反而可能偏离训练时的统计特性。
|
||
- LayerNorm/Softmax 敏感:Resampler/MLP 里 LN/Softmax 对精度很敏感,FP32 计算后再降精度,数值边界更容易“硬截断”;BF16 全程计算可能更平滑。
|
||
|
||
这也解释了为什么你看到 BF16 projector 反而更准。
|
||
|
||
embedder:
|
||
改成 autocast only(权重 FP32,预处理 FP32,仅主干 BF16)
|
||
|
||
- 效果差的那次:encoder autocast set to bf16 (weights=fp32)
|
||
也就是“权重 FP32 + autocast BF16(我现在的 autocast only 方案)”
|
||
- 效果好的那次:cond_stage_model / embedder param dtype: torch.bfloat16
|
||
也就是“权重 BF16(全 BF16)”
|
||
|
||
为什么 FP32 反而掉分大?主要是“混精度路径不一致”:
|
||
|
||
1. autocast only 不是等价于 full BF16
|
||
autocast 会把部分算子留在 FP32(比如 LayerNorm / Softmax),另一部分用 BF16;
|
||
这会让 encoder 输出分布和“全 BF16”不一样,尤其当扩散主干也是 BF16 时,条件向量分布一偏就会影响 PSNR。
|
||
2. 先 FP32 计算再转 BF16
|
||
encoder 的输出最后还是要拼到 BF16 主干里,等于“先高精度 → 再截断”。
|
||
训练时如果你用的是 BF16,模型更适应“低精度直接计算”的分布。
|
||
3. 你的现象和 projector 一样
|
||
你已经观察到 projector BF16 更准,这说明模型很可能“更偏好一致的 BF16 路径”。
|
||
|
||
结论:
|
||
对你这个模型,full BF16(权重 BF16)更匹配训练分布,反而 PSNR 更高。
|
||
|
||
如果你要我改回“bf16=全 BF16(权重 BF16)”,并且保留现在的 autocast only 作为可选模式,我可以这样做:
|
||
|
||
1. 新增 --encoder_mode {fp32, autocast, bf16_full}
|
||
2. bf16_full = 权重 BF16 + 前向 BF16
|
||
3. autocast = 权重 FP32 + 仅主干 autocast(现在的实现)
|
||
|
||
|
||
|
||
1. DDIM loop 内小张量分配优化(已完成)
|
||
|
||
- 每步 torch.full(...) 改成预先构造/广播,减少 loop 内分配
|
||
- 位置:src/unifolm_wma/models/samplers/ddim.py
|
||
|
||
2. attention mask 缓存到 GPU(已完成)
|
||
|
||
- _get_attn_mask_aa 现在直接在目标 device 构造并缓存,避免每步 CPU→GPU 拷贝
|
||
- 位置:src/unifolm_wma/modules/attention.py
|
||
|
||
|
||
|
||
|
||
|
||
--export_casted_ckpt ckpts/unifolm_wma_dual_mixbf16.ckpt \
|
||
--export_only
|
||
|
||
|
||
checkpoint or previous run; diffusion backbone and condition embedder are bfloat16 while
|
||
projector is fp32, so casts are skipped as expected. The "inference precision" line reflects the
|
||
top-level param dtype (possibly fp32 embeddings), explaining the apparent inconsistency with
|
||
submodule param dtypes.
|
||
|
||
Explaining skip cast due to preserved dtypes
|
||
|
||
The skip cast happens because checkpoint loading preserves stored dtypes, so weights already
|
||
match the target (bf16 or fp32), making casting unnecessary. The log shows first param dtype for
|
||
checks, which can differ from submodules but is expected. For forcing recast, exporting a new
|
||
casted checkpoint is needed since repeated casts to the same dtype no-op.
|
||
|
||
• “skip cast” 的意思是该模块的权重已经是目标 dtype,所以 maybe_cast_module 直接跳过转换。你这段日志
|
||
里:
|
||
|
||
- diffusion backbone / cond_stage_model / embedder 已经是 torch.bfloat16
|
||
- projector 仍是 torch.float32(因为你设的是 projector_mode=autocast,权重保持 fp32)
|
||
- VAE 你设置为 fp32,也已经是 fp32
|
||
|
||
所以全部都匹配目标精度了,自然会 “skip cast”。这通常说明你加载的 ckpt 本身就是混合精度权重(比如
|
||
之前用 --export_casted_ckpt 导出的)。
|
||
|
||
另外这行:
|
||
|
||
>>> inference precision: model=torch.float32
|
||
|
||
只是取了模型第一个参数的 dtype(log_inference_precision 的实现如此),模型里可能还有一些主模块外
|
||
的 fp32 参数(比如位置编码等),所以看到整体是 fp32 很正常,不代表 diffusion 不在 bf16。后面的子
|
||
模块打印才更准确。
|
||
|
||
如果你想确认是否从“原始 fp32 ckpt”重新 cast,只要把 --ckpt_path 指回原始 ckpt 跑一次即可。 |