diff --git a/scripts/evaluation/world_model_interaction.py b/scripts/evaluation/world_model_interaction.py index 73e2baa..5bac2b2 100644 --- a/scripts/evaluation/world_model_interaction.py +++ b/scripts/evaluation/world_model_interaction.py @@ -714,22 +714,23 @@ def preprocess_observation( return return_observations -def image_guided_synthesis_sim_mode( - model: torch.nn.Module, - prompts: list[str], - observation: dict, - noise_shape: tuple[int, int, int, int, int], - action_cond_step: int = 16, - n_samples: int = 1, - ddim_steps: int = 50, - ddim_eta: float = 1.0, - unconditional_guidance_scale: float = 1.0, - fs: int | None = None, - text_input: bool = True, - timestep_spacing: str = 'uniform', - guidance_rescale: float = 0.0, - sim_mode: bool = True, - **kwargs) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: +def image_guided_synthesis_sim_mode( + model: torch.nn.Module, + prompts: list[str], + observation: dict, + noise_shape: tuple[int, int, int, int, int], + action_cond_step: int = 16, + n_samples: int = 1, + ddim_steps: int = 50, + ddim_eta: float = 1.0, + unconditional_guidance_scale: float = 1.0, + fs: int | None = None, + text_input: bool = True, + timestep_spacing: str = 'uniform', + guidance_rescale: float = 0.0, + sim_mode: bool = True, + diffusion_autocast_dtype: Optional[torch.dtype] = None, + **kwargs) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: """ Performs image-guided video generation in a simulation-style mode with optional multimodal guidance (image, state, action, text). @@ -750,9 +751,10 @@ def image_guided_synthesis_sim_mode( fs (int | None): Frame index to condition on, broadcasted across the batch if specified. Default is None. text_input (bool): Whether to use text prompt as conditioning. If False, uses empty strings. Default is True. timestep_spacing (str): Timestep sampling method in DDIM sampler. Typically "uniform" or "linspace". - guidance_rescale (float): Guidance rescaling factor to mitigate overexposure from classifier-free guidance. - sim_mode (bool): Whether to perform world-model interaction or decision-making using the world-model. - **kwargs: Additional arguments passed to the DDIM sampler. + guidance_rescale (float): Guidance rescaling factor to mitigate overexposure from classifier-free guidance. + sim_mode (bool): Whether to perform world-model interaction or decision-making using the world-model. + diffusion_autocast_dtype (Optional[torch.dtype]): Autocast dtype for diffusion sampling (e.g., torch.bfloat16). + **kwargs: Additional arguments passed to the DDIM sampler. Returns: batch_variants (torch.Tensor): Predicted pixel-space video frames [B, C, T, H, W]. @@ -810,31 +812,37 @@ def image_guided_synthesis_sim_mode( uc = None kwargs.update({"unconditional_conditioning_img_nonetext": None}) cond_mask = None - cond_z0 = None - - if ddim_sampler is not None: - with profiler.profile_section("synthesis/ddim_sampling"): - samples, actions, states, intermedia = ddim_sampler.sample( - S=ddim_steps, - conditioning=cond, - batch_size=batch_size, - shape=noise_shape[1:], - verbose=False, - unconditional_guidance_scale=unconditional_guidance_scale, - unconditional_conditioning=uc, - eta=ddim_eta, - cfg_img=None, - mask=cond_mask, - x0=cond_z0, - fs=fs, - timestep_spacing=timestep_spacing, - guidance_rescale=guidance_rescale, - **kwargs) - - # Reconstruct from latent to pixel space - with profiler.profile_section("synthesis/decode_first_stage"): - batch_images = model.decode_first_stage(samples) - batch_variants = batch_images + cond_z0 = None + + if ddim_sampler is not None: + with profiler.profile_section("synthesis/ddim_sampling"): + autocast_ctx = nullcontext() + if diffusion_autocast_dtype is not None and model.device.type == "cuda": + autocast_ctx = torch.autocast("cuda", dtype=diffusion_autocast_dtype) + with autocast_ctx: + samples, actions, states, intermedia = ddim_sampler.sample( + S=ddim_steps, + conditioning=cond, + batch_size=batch_size, + shape=noise_shape[1:], + verbose=False, + unconditional_guidance_scale=unconditional_guidance_scale, + unconditional_conditioning=uc, + eta=ddim_eta, + cfg_img=None, + mask=cond_mask, + x0=cond_z0, + fs=fs, + timestep_spacing=timestep_spacing, + guidance_rescale=guidance_rescale, + **kwargs) + + # Reconstruct from latent to pixel space + with profiler.profile_section("synthesis/decode_first_stage"): + if samples.dtype != torch.float32: + samples = samples.float() + batch_images = model.decode_first_stage(samples) + batch_variants = batch_images return batch_variants, actions, states @@ -889,6 +897,13 @@ def run_inference(args: argparse.Namespace, gpu_num: int, gpu_no: int) -> None: model = model.cuda(gpu_no) device = get_device_from_parameters(model) + diffusion_autocast_dtype = None + if args.diffusion_dtype == "bf16": + with profiler.profile_section("model_loading/diffusion_bf16"): + model.model.to(dtype=torch.bfloat16) + diffusion_autocast_dtype = torch.bfloat16 + print(">>> diffusion backbone set to bfloat16") + log_inference_precision(model) profiler.record_memory("after_model_load") @@ -1014,20 +1029,21 @@ def run_inference(args: argparse.Namespace, gpu_num: int, gpu_no: int) -> None: # Use world-model in policy to generate action print(f'>>> Step {itr}: generating actions ...') with profiler.profile_section("action_generation"): - pred_videos_0, pred_actions, _ = image_guided_synthesis_sim_mode( - model, - sample['instruction'], - observation, - noise_shape, + pred_videos_0, pred_actions, _ = image_guided_synthesis_sim_mode( + model, + sample['instruction'], + observation, + noise_shape, action_cond_step=args.exe_steps, ddim_steps=args.ddim_steps, ddim_eta=args.ddim_eta, - unconditional_guidance_scale=args. - unconditional_guidance_scale, - fs=model_input_fs, - timestep_spacing=args.timestep_spacing, - guidance_rescale=args.guidance_rescale, - sim_mode=False) + unconditional_guidance_scale=args. + unconditional_guidance_scale, + fs=model_input_fs, + timestep_spacing=args.timestep_spacing, + guidance_rescale=args.guidance_rescale, + sim_mode=False, + diffusion_autocast_dtype=diffusion_autocast_dtype) # Update future actions in the observation queues with profiler.profile_section("update_action_queues"): @@ -1058,20 +1074,21 @@ def run_inference(args: argparse.Namespace, gpu_num: int, gpu_no: int) -> None: # Interaction with the world-model print(f'>>> Step {itr}: interacting with world model ...') with profiler.profile_section("world_model_interaction"): - pred_videos_1, _, pred_states = image_guided_synthesis_sim_mode( - model, - "", - observation, - noise_shape, + pred_videos_1, _, pred_states = image_guided_synthesis_sim_mode( + model, + "", + observation, + noise_shape, action_cond_step=args.exe_steps, ddim_steps=args.ddim_steps, ddim_eta=args.ddim_eta, unconditional_guidance_scale=args. - unconditional_guidance_scale, - fs=model_input_fs, - text_input=False, - timestep_spacing=args.timestep_spacing, - guidance_rescale=args.guidance_rescale) + unconditional_guidance_scale, + fs=model_input_fs, + text_input=False, + timestep_spacing=args.timestep_spacing, + guidance_rescale=args.guidance_rescale, + diffusion_autocast_dtype=diffusion_autocast_dtype) with profiler.profile_section("update_state_queues"): for step_idx in range(args.exe_steps): @@ -1216,13 +1233,20 @@ def get_parser(): help= "Rescale factor for guidance as discussed in 'Common Diffusion Noise Schedules and Sample Steps are Flawed' (https://huggingface.co/papers/2305.08891)." ) - parser.add_argument( - "--perframe_ae", - action='store_true', - default=False, - help= - "Use per-frame autoencoder decoding to reduce GPU memory usage. Recommended for models with resolutions like 576x1024." - ) + parser.add_argument( + "--perframe_ae", + action='store_true', + default=False, + help= + "Use per-frame autoencoder decoding to reduce GPU memory usage. Recommended for models with resolutions like 576x1024." + ) + parser.add_argument( + "--diffusion_dtype", + type=str, + choices=["fp32", "bf16"], + default="fp32", + help="Dtype for diffusion backbone weights and sampling autocast." + ) parser.add_argument( "--n_action_steps", type=int, diff --git a/unitree_g1_pack_camera/case1/output/tensorboard/events.out.tfevents.1768707839.node-0.217134.0 b/unitree_g1_pack_camera/case1/output/tensorboard/events.out.tfevents.1768707839.node-0.217134.0 new file mode 100644 index 0000000..d9d6f61 Binary files /dev/null and b/unitree_g1_pack_camera/case1/output/tensorboard/events.out.tfevents.1768707839.node-0.217134.0 differ diff --git a/unitree_g1_pack_camera/case1/output/tensorboard/events.out.tfevents.1768707906.node-0.218123.0 b/unitree_g1_pack_camera/case1/output/tensorboard/events.out.tfevents.1768707906.node-0.218123.0 new file mode 100644 index 0000000..d9df644 Binary files /dev/null and b/unitree_g1_pack_camera/case1/output/tensorboard/events.out.tfevents.1768707906.node-0.218123.0 differ diff --git a/unitree_g1_pack_camera/case1/run_world_model_interaction.sh b/unitree_g1_pack_camera/case1/run_world_model_interaction.sh index e0e900f..3054dd3 100644 --- a/unitree_g1_pack_camera/case1/run_world_model_interaction.sh +++ b/unitree_g1_pack_camera/case1/run_world_model_interaction.sh @@ -2,7 +2,7 @@ res_dir="unitree_g1_pack_camera/case1" dataset="unitree_g1_pack_camera" { - time CUDA_VISIBLE_DEVICES=0 python3 scripts/evaluation/world_model_interaction.py \ + time CUDA_VISIBLE_DEVICES=1 python3 scripts/evaluation/world_model_interaction.py \ --seed 123 \ --ckpt_path ckpts/unifolm_wma_dual.ckpt \ --config configs/inference/world_model_interaction.yaml \ @@ -20,5 +20,6 @@ dataset="unitree_g1_pack_camera" --n_iter 11 \ --timestep_spacing 'uniform_trailing' \ --guidance_rescale 0.7 \ - --perframe_ae + --perframe_ae \ + --diffusion_dtype bf16 } 2>&1 | tee "${res_dir}/output.log" diff --git a/useful.sh b/useful.sh index a466ad5..aec7681 100644 --- a/useful.sh +++ b/useful.sh @@ -29,4 +29,45 @@ python3 psnr_score_for_challenge.py --gt_video unitree_g1_pack_camera/case1/unit 1. torch.compile + AMP + TF32 + cudnn.benchmark 2. 排查 .to()/copy/clone 的重复位置并移出循环 - 3. 若需要更大幅度,再换采样器/降步数 \ No newline at end of file + 3. 若需要更大幅度,再换采样器/降步数 + + + + + + A100 上我推荐 BF16 优先(稳定性更好、PSNR 更稳),FP16 作为速度优先方案。 + + 下面是“分模块”的 消融方案(从稳到激进): + + 0)基线 + + - 全 FP32(你现在就是这个) + + 1)只对扩散主干做 BF16(最推荐) + + - 量化对象:model.model(扩散 UNet/WMAModel 主体) + - 保持 FP32:first_stage_model(VAE 编/解码)、cond_stage_model(文本)、embedder(图像)、image_proj_model + - 预期:PSNR 基本不掉 or 极小波动 + + 2)+ 轻量投影/MLP 做 BF16 + + - 增加:image_proj_model、state_projector、action_projector + - 预期:几乎不影响 PSNR + + 3)+ 文本/图像编码做 BF16 + + - 增加:cond_stage_model、embedder + - 预期:可能有轻微波动,通常仍可接受 + + 4)VAE 也做 BF16(最容易伤 PSNR) + + - 增加:first_stage_model + - 预期:画质/PSNR 最敏感,建议最后做消融 + + ——— + + 具体建议(A100) + + - 优先 BF16:稳定性好于 FP16 + - 只做半精度,不做 INT 量化:保持 PSNR + - VAE 尽量 FP32:最影响画质的模块