diff --git a/scripts/evaluation/world_model_interaction.py b/scripts/evaluation/world_model_interaction.py index 75f9c98..e2bfbc9 100644 --- a/scripts/evaluation/world_model_interaction.py +++ b/scripts/evaluation/world_model_interaction.py @@ -772,7 +772,11 @@ def image_guided_synthesis_sim_mode( with profiler.profile_section("synthesis/conditioning_prep"): img = observation['observation.images.top'].permute(0, 2, 1, 3, 4) cond_img = rearrange(img, 'b o c h w -> (b o) c h w')[-1:] - cond_img_emb = model.embedder(cond_img) + embedder_ctx = nullcontext() + if getattr(model, "encoder_bf16", False) and model.device.type == "cuda": + embedder_ctx = torch.autocast("cuda", dtype=torch.bfloat16) + with embedder_ctx: + cond_img_emb = model.embedder(cond_img) if model.model.conditioning_key == 'hybrid': z = get_latent_z(model, img.permute(0, 2, 1, 3, 4)) @@ -912,6 +916,16 @@ def run_inference(args: argparse.Namespace, gpu_num: int, gpu_no: int) -> None: diffusion_autocast_dtype = torch.bfloat16 print(">>> diffusion backbone set to bfloat16") + encoder_dtype = torch.float32 + if args.encoder_dtype == "bf16": + encoder_dtype = torch.bfloat16 + if hasattr(model, "cond_stage_model") and model.cond_stage_model is not None: + model.cond_stage_model.to(dtype=encoder_dtype) + if hasattr(model, "embedder") and model.embedder is not None: + model.embedder.to(dtype=encoder_dtype) + model.encoder_bf16 = args.encoder_dtype == "bf16" + print(f">>> encoder dtype set to {args.encoder_dtype}") + if hasattr(model, "projector_bf16"): model.projector_bf16 = args.projector_dtype == "bf16" print(f">>> projector dtype set to {args.projector_dtype}") @@ -1266,6 +1280,13 @@ def get_parser(): default="fp32", help="Dtype for image/state/action projectors (autocast in forward)." ) + parser.add_argument( + "--encoder_dtype", + type=str, + choices=["fp32", "bf16"], + default="fp32", + help="Dtype for text/image encoders (cond_stage_model/embedder)." + ) parser.add_argument( "--n_action_steps", type=int, diff --git a/unitree_g1_pack_camera/case1/output/tensorboard/events.out.tfevents.1768733111.node-0.392376.0 b/unitree_g1_pack_camera/case1/output/tensorboard/events.out.tfevents.1768733111.node-0.392376.0 new file mode 100644 index 0000000..245f59f Binary files /dev/null and b/unitree_g1_pack_camera/case1/output/tensorboard/events.out.tfevents.1768733111.node-0.392376.0 differ diff --git a/unitree_g1_pack_camera/case1/output/tensorboard/events.out.tfevents.1768733357.node-0.394471.0 b/unitree_g1_pack_camera/case1/output/tensorboard/events.out.tfevents.1768733357.node-0.394471.0 new file mode 100644 index 0000000..9b57dbe Binary files /dev/null and b/unitree_g1_pack_camera/case1/output/tensorboard/events.out.tfevents.1768733357.node-0.394471.0 differ diff --git a/unitree_g1_pack_camera/case1/run_world_model_interaction.sh b/unitree_g1_pack_camera/case1/run_world_model_interaction.sh index a6dd008..a1068da 100644 --- a/unitree_g1_pack_camera/case1/run_world_model_interaction.sh +++ b/unitree_g1_pack_camera/case1/run_world_model_interaction.sh @@ -22,5 +22,6 @@ dataset="unitree_g1_pack_camera" --guidance_rescale 0.7 \ --perframe_ae \ --diffusion_dtype bf16 \ - --projector_dtype bf16 + --projector_dtype bf16 \ + --encoder_dtype bf16 } 2>&1 | tee "${res_dir}/output.log" diff --git a/useful.sh b/useful.sh index bcbf234..64de661 100644 --- a/useful.sh +++ b/useful.sh @@ -78,4 +78,7 @@ BF16 projector比FP32 projector更准的可能原因: - 训练分布匹配:训练时你是 precision:16,projector 长期在低精度环境下被优化;推理用 FP32 反而可能偏离训练时的统计特性。 - LayerNorm/Softmax 敏感:Resampler/MLP 里 LN/Softmax 对精度很敏感,FP32 计算后再降精度,数值边界更容易“硬截断”;BF16 全程计算可能更平滑。 - 这也解释了为什么你看到 BF16 projector 反而更准。 \ No newline at end of file + 这也解释了为什么你看到 BF16 projector 反而更准。 + +embedder: + 改成 autocast only(权重 FP32,预处理 FP32,仅主干 BF16) \ No newline at end of file