diff --git a/scripts/evaluation/world_model_interaction.py b/scripts/evaluation/world_model_interaction.py index 9c1565c..ac5ebde 100644 --- a/scripts/evaluation/world_model_interaction.py +++ b/scripts/evaluation/world_model_interaction.py @@ -649,7 +649,7 @@ def prepare_init_input(start_idx: int, return data, ori_state_dim, ori_action_dim -def get_latent_z(model, videos: Tensor) -> Tensor: +def get_latent_z(model, videos: Tensor) -> Tensor: """ Extracts latent features from a video batch using the model's first-stage encoder. @@ -661,11 +661,15 @@ def get_latent_z(model, videos: Tensor) -> Tensor: Tensor: Latent video tensor of shape [B, C, T, H, W]. """ profiler = get_profiler() - with profiler.profile_section("get_latent_z/encode"): - b, c, t, h, w = videos.shape - x = rearrange(videos, 'b c t h w -> (b t) c h w') - z = model.encode_first_stage(x) - z = rearrange(z, '(b t) c h w -> b c t h w', b=b, t=t) + with profiler.profile_section("get_latent_z/encode"): + b, c, t, h, w = videos.shape + x = rearrange(videos, 'b c t h w -> (b t) c h w') + vae_ctx = nullcontext() + if getattr(model, "vae_bf16", False) and model.device.type == "cuda": + vae_ctx = torch.autocast("cuda", dtype=torch.bfloat16) + with vae_ctx: + z = model.encode_first_stage(x) + z = rearrange(z, '(b t) c h w -> b c t h w', b=b, t=t) return z @@ -879,9 +883,18 @@ def image_guided_synthesis_sim_mode( # Reconstruct from latent to pixel space with profiler.profile_section("synthesis/decode_first_stage"): - if samples.dtype != torch.float32: - samples = samples.float() - batch_images = model.decode_first_stage(samples) + if getattr(model, "vae_bf16", False): + if samples.dtype != torch.bfloat16: + samples = samples.to(dtype=torch.bfloat16) + vae_ctx = nullcontext() + if model.device.type == "cuda": + vae_ctx = torch.autocast("cuda", dtype=torch.bfloat16) + with vae_ctx: + batch_images = model.decode_first_stage(samples) + else: + if samples.dtype != torch.float32: + samples = samples.float() + batch_images = model.decode_first_stage(samples) batch_variants = batch_images return batch_variants, actions, states @@ -944,6 +957,14 @@ def run_inference(args: argparse.Namespace, gpu_num: int, gpu_no: int) -> None: diffusion_autocast_dtype = torch.bfloat16 print(">>> diffusion backbone set to bfloat16") + if hasattr(model, "first_stage_model") and model.first_stage_model is not None: + if args.vae_dtype == "bf16": + model.first_stage_model.to(dtype=torch.bfloat16) + else: + model.first_stage_model.to(dtype=torch.float32) + model.vae_bf16 = args.vae_dtype == "bf16" + print(f">>> VAE dtype set to {args.vae_dtype}") + encoder_mode = args.encoder_mode encoder_bf16 = encoder_mode in ("autocast", "bf16_full") encoder_weight_dtype = torch.bfloat16 if encoder_mode == "bf16_full" else torch.float32 @@ -957,9 +978,21 @@ def run_inference(args: argparse.Namespace, gpu_num: int, gpu_no: int) -> None: f">>> encoder mode set to {encoder_mode} (weights={encoder_weight_dtype})" ) + projector_mode = args.projector_mode + projector_bf16 = projector_mode in ("autocast", "bf16_full") + projector_weight_dtype = torch.bfloat16 if projector_mode == "bf16_full" else torch.float32 + if hasattr(model, "image_proj_model") and model.image_proj_model is not None: + model.image_proj_model.to(dtype=projector_weight_dtype) + if hasattr(model, "state_projector") and model.state_projector is not None: + model.state_projector.to(dtype=projector_weight_dtype) + if hasattr(model, "action_projector") and model.action_projector is not None: + model.action_projector.to(dtype=projector_weight_dtype) if hasattr(model, "projector_bf16"): - model.projector_bf16 = args.projector_dtype == "bf16" - print(f">>> projector dtype set to {args.projector_dtype}") + model.projector_bf16 = projector_bf16 + model.projector_mode = projector_mode + print( + f">>> projector mode set to {projector_mode} (weights={projector_weight_dtype})" + ) log_inference_precision(model) @@ -1305,11 +1338,14 @@ def get_parser(): help="Dtype for diffusion backbone weights and sampling autocast." ) parser.add_argument( - "--projector_dtype", + "--projector_mode", type=str, - choices=["fp32", "bf16"], + choices=["fp32", "autocast", "bf16_full"], default="fp32", - help="Dtype for image/state/action projectors (autocast in forward)." + help= + "Projector precision mode for image/state/action projectors: " + "fp32=full fp32, autocast=fp32 weights + bf16 autocast in forward, " + "bf16_full=bf16 weights + bf16 forward." ) parser.add_argument( "--encoder_mode", @@ -1321,6 +1357,13 @@ def get_parser(): "fp32=full fp32, autocast=fp32 weights + bf16 autocast in forward, " "bf16_full=bf16 weights + bf16 forward." ) + parser.add_argument( + "--vae_dtype", + type=str, + choices=["fp32", "bf16"], + default="fp32", + help="Dtype for VAE/first_stage_model weights and forward autocast." + ) parser.add_argument( "--n_action_steps", type=int, diff --git a/src/unifolm_wma/models/ddpms.py b/src/unifolm_wma/models/ddpms.py index ceca543..2f6c3ca 100644 --- a/src/unifolm_wma/models/ddpms.py +++ b/src/unifolm_wma/models/ddpms.py @@ -2032,6 +2032,13 @@ class LatentVisualDiffusion(LatentDiffusion): target_dtype: torch.dtype | None) -> Tensor: use_bf16 = (self.projector_bf16 and x.device.type == "cuda" and torch.cuda.is_bf16_supported()) + if not use_bf16: + weight_dtype = None + for param in projector.parameters(): + weight_dtype = param.dtype + break + if weight_dtype is not None and x.dtype != weight_dtype: + x = x.to(dtype=weight_dtype) if use_bf16: with torch.autocast(device_type="cuda", dtype=torch.bfloat16): out = projector(x) diff --git a/unitree_g1_pack_camera/case1/output/tensorboard/events.out.tfevents.1768739615.node-0.468941.0 b/unitree_g1_pack_camera/case1/output/tensorboard/events.out.tfevents.1768739615.node-0.468941.0 new file mode 100644 index 0000000..8ca4ba7 Binary files /dev/null and b/unitree_g1_pack_camera/case1/output/tensorboard/events.out.tfevents.1768739615.node-0.468941.0 differ diff --git a/unitree_g1_pack_camera/case1/output/tensorboard/events.out.tfevents.1768740558.node-0.479988.0 b/unitree_g1_pack_camera/case1/output/tensorboard/events.out.tfevents.1768740558.node-0.479988.0 new file mode 100644 index 0000000..34b8a36 Binary files /dev/null and b/unitree_g1_pack_camera/case1/output/tensorboard/events.out.tfevents.1768740558.node-0.479988.0 differ diff --git a/unitree_g1_pack_camera/case1/output/tensorboard/events.out.tfevents.1768741386.node-0.487044.0 b/unitree_g1_pack_camera/case1/output/tensorboard/events.out.tfevents.1768741386.node-0.487044.0 new file mode 100644 index 0000000..5e94e6c Binary files /dev/null and b/unitree_g1_pack_camera/case1/output/tensorboard/events.out.tfevents.1768741386.node-0.487044.0 differ diff --git a/unitree_g1_pack_camera/case1/run_world_model_interaction.sh b/unitree_g1_pack_camera/case1/run_world_model_interaction.sh index a937845..62784f3 100644 --- a/unitree_g1_pack_camera/case1/run_world_model_interaction.sh +++ b/unitree_g1_pack_camera/case1/run_world_model_interaction.sh @@ -22,6 +22,7 @@ dataset="unitree_g1_pack_camera" --guidance_rescale 0.7 \ --perframe_ae \ --diffusion_dtype bf16 \ - --projector_dtype bf16 \ - --encoder_mode autocast #fp32/autocast/bf16_full + --projector_mode autocast \ + --encoder_mode bf16_full \ + --vae_dtype bf16 } 2>&1 | tee "${res_dir}/output.log"