diff --git a/scripts/evaluation/world_model_interaction.py b/scripts/evaluation/world_model_interaction.py
index 2f2d690..75f9c98 100644
--- a/scripts/evaluation/world_model_interaction.py
+++ b/scripts/evaluation/world_model_interaction.py
@@ -912,6 +912,10 @@ def run_inference(args: argparse.Namespace, gpu_num: int, gpu_no: int) -> None:
         diffusion_autocast_dtype = torch.bfloat16
         print(">>> diffusion backbone set to bfloat16")
 
+    if hasattr(model, "projector_bf16"):
+        model.projector_bf16 = args.projector_dtype == "bf16"
+        print(f">>> projector dtype set to {args.projector_dtype}")
+
     log_inference_precision(model)
 
     profiler.record_memory("after_model_load")
@@ -1255,6 +1259,13 @@ def get_parser():
         default="fp32",
         help="Dtype for diffusion backbone weights and sampling autocast."
     )
+    parser.add_argument(
+        "--projector_dtype",
+        type=str,
+        choices=["fp32", "bf16"],
+        default="fp32",
+        help="Dtype for image/state/action projectors (autocast in forward)."
+    )
     parser.add_argument(
         "--n_action_steps",
         type=int,
diff --git a/unitree_g1_pack_camera/case1/run_world_model_interaction.sh b/unitree_g1_pack_camera/case1/run_world_model_interaction.sh
index 3054dd3..a6dd008 100644
--- a/unitree_g1_pack_camera/case1/run_world_model_interaction.sh
+++ b/unitree_g1_pack_camera/case1/run_world_model_interaction.sh
@@ -21,5 +21,6 @@ dataset="unitree_g1_pack_camera"
     --timestep_spacing 'uniform_trailing' \
     --guidance_rescale 0.7 \
     --perframe_ae \
-    --diffusion_dtype bf16
+    --diffusion_dtype bf16 \
+    --projector_dtype bf16
 } 2>&1 | tee "${res_dir}/output.log"
diff --git a/useful.sh b/useful.sh
index aec7681..bcbf234 100644
--- a/useful.sh
+++ b/useful.sh
@@ -71,3 +71,11 @@ python3 psnr_score_for_challenge.py --gt_video unitree_g1_pack_camera/case1/unit
 - Prefer BF16: more stable than FP16
 - Half precision only, no INT quantization: preserves PSNR
 - Keep the VAE in FP32 as far as possible: it is the module with the biggest impact on image quality
+
+Possible reasons a BF16 projector is more accurate than an FP32 projector:
+ - Consistent numerical path: the backbone runs its attention/MLP in BF16. An FP32 projector gets downcast on entry to the backbone, a "high-precision nonlinearity, then truncation" pattern that shifts the activation distribution; computing the projector directly in BF16 keeps its output
+   distribution aligned with the arithmetic the backbone actually performs.
+ - Match to the training distribution: training ran with precision:16, so the projector was optimized in a low-precision regime throughout; FP32 inference can drift from the statistics it saw during training.
+ - LayerNorm/Softmax sensitivity: the LN/Softmax ops inside the Resampler/MLP are precision-sensitive. Computing in FP32 and then reducing precision hard-truncates values at numerical boundaries, whereas end-to-end BF16 tends to be smoother.
+
+This also explains why you see the BF16 projector come out more accurate.
\ No newline at end of file
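
The projector module itself is not part of this diff, so the following is a minimal sketch of the pattern the new flag presumably relies on: a boolean on the module that switches its forward pass into a bf16 autocast region while the weights stay fp32. The ImageProjector class, its layer sizes, and its structure are hypothetical, not the repo's actual code.

import torch
import torch.nn as nn

class ImageProjector(nn.Module):
    """Hypothetical stand-in for the image/state/action projectors."""
    def __init__(self, in_dim: int = 1024, out_dim: int = 4096):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(in_dim, out_dim),
            nn.GELU(),
            nn.Linear(out_dim, out_dim),
        )
        # Toggled from run_inference: model.projector_bf16 = args.projector_dtype == "bf16"
        self.projector_bf16 = False

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if self.projector_bf16:
            # Weights stay fp32; autocast runs the matmuls in bf16 on the fly,
            # so the output distribution matches what the bf16 backbone consumes.
            with torch.autocast(device_type=x.device.type, dtype=torch.bfloat16):
                return self.mlp(x)
        return self.mlp(x)

Under this shape, --projector_dtype fp32 leaves the forward untouched, and the hasattr guard in run_inference keeps checkpoints that predate the flag working unchanged.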
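
The first and third bullets above can be made concrete with a toy experiment (not from the repo): the same fp32 weights produce different bf16 outputs depending on whether the computation runs in fp32 and is then downcast at the backbone boundary, or runs in bf16 end to end.

import torch

torch.manual_seed(0)
proj = torch.nn.Linear(64, 64)    # fp32 weights, standing in for a projector
x = torch.randn(8, 64)

# Path A: compute in fp32, then truncate at the backbone boundary.
y_cast = proj(x).to(torch.bfloat16)

# Path B: compute in bf16 end to end (as with --projector_dtype bf16).
with torch.autocast(device_type="cpu", dtype=torch.bfloat16):
    y_bf16 = proj(x)

# The two paths generally disagree; which one the backbone was trained
# against is what decides accuracy downstream.
print((y_cast.float() - y_bf16.float()).abs().max())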