diff --git a/scripts/evaluation/world_model_interaction.py b/scripts/evaluation/world_model_interaction.py index 281693c..2f740ba 100644 --- a/scripts/evaluation/world_model_interaction.py +++ b/scripts/evaluation/world_model_interaction.py @@ -752,13 +752,6 @@ def run_inference(args: argparse.Namespace, gpu_num: int, gpu_no: int) -> None: if hasattr(model, "first_stage_model") and model.first_stage_model is not None: vae = model.first_stage_model - # Channels-last memory format: cuDNN uses faster NHWC kernels - if args.vae_channels_last: - vae = vae.to(memory_format=torch.channels_last) - vae._channels_last = True - model.first_stage_model = vae - print(">>> VAE converted to channels_last (NHWC) memory format") - # torch.compile: fuses GroupNorm+SiLU, conv chains, etc. if args.vae_compile: vae.decoder = torch.compile(vae.decoder, mode="reduce-overhead") @@ -1173,12 +1166,6 @@ def get_parser(): default=False, help="Apply torch.compile to VAE decoder for kernel fusion." ) - parser.add_argument( - "--vae_channels_last", - action='store_true', - default=False, - help="Convert VAE to channels-last (NHWC) memory format for faster cuDNN convolutions." - ) parser.add_argument( "--vae_decode_bs", type=int, diff --git a/src/unifolm_wma/models/autoencoder.py b/src/unifolm_wma/models/autoencoder.py index 1a79699..94b2d8c 100644 --- a/src/unifolm_wma/models/autoencoder.py +++ b/src/unifolm_wma/models/autoencoder.py @@ -99,16 +99,12 @@ class AutoencoderKL(pl.LightningModule): print(f"Restored from {path}") def encode(self, x, **kwargs): - if getattr(self, '_channels_last', False): - x = x.to(memory_format=torch.channels_last) h = self.encoder(x) moments = self.quant_conv(h) posterior = DiagonalGaussianDistribution(moments) return posterior def decode(self, z, **kwargs): - if getattr(self, '_channels_last', False): - z = z.to(memory_format=torch.channels_last) z = self.post_quant_conv(z) dec = self.decoder(z) return dec diff --git a/unitree_g1_pack_camera/case1/run_world_model_interaction.sh b/unitree_g1_pack_camera/case1/run_world_model_interaction.sh index 36b0458..ace451b 100644 --- a/unitree_g1_pack_camera/case1/run_world_model_interaction.sh +++ b/unitree_g1_pack_camera/case1/run_world_model_interaction.sh @@ -24,6 +24,5 @@ dataset="unitree_g1_pack_camera" --diffusion_dtype bf16 \ --projector_mode bf16_full \ --encoder_mode bf16_full \ - --vae_dtype bf16 \ - --vae_channels_last + --vae_dtype bf16 } 2>&1 | tee "${res_dir}/output.log"