diff --git a/scripts/evaluation/world_model_interaction.py b/scripts/evaluation/world_model_interaction.py index 71ed2fc..9784f62 100644 --- a/scripts/evaluation/world_model_interaction.py +++ b/scripts/evaluation/world_model_interaction.py @@ -1026,24 +1026,171 @@ def run_inference(args: argparse.Namespace, gpu_num: int, gpu_no: int) -> None: csv_path = os.path.join(args.prompt_dir, f"{args.dataset}.csv") df = pd.read_csv(csv_path) - # Load config - with profiler.profile_section("model_loading/config"): - config = OmegaConf.load(args.config) - config['model']['params']['wma_config']['params'][ - 'use_checkpoint'] = False - model = instantiate_from_config(config.model) - model.perframe_ae = args.perframe_ae + # Load config (always needed for data setup) + config = OmegaConf.load(args.config) - assert os.path.exists(args.ckpt_path), "Error: checkpoint Not Found!" + prepared_path = args.ckpt_path + ".prepared.pt" + if os.path.exists(prepared_path): + # ---- Fast path: load the fully-prepared model ---- + print(f">>> Loading prepared model from {prepared_path} ...") + with profiler.profile_section("model_loading/prepared"): + model = torch.load(prepared_path, + map_location=f"cuda:{gpu_no}", + weights_only=False, + mmap=True) + model.eval() + diffusion_autocast_dtype = (torch.bfloat16 + if args.diffusion_dtype == "bf16" + else None) + print(f">>> Prepared model loaded.") + else: + # ---- Normal path: construct + checkpoint + casting ---- + with profiler.profile_section("model_loading/config"): + config['model']['params']['wma_config']['params'][ + 'use_checkpoint'] = False + model = instantiate_from_config(config.model) + model.perframe_ae = args.perframe_ae - with profiler.profile_section("model_loading/checkpoint"): - model = load_model_checkpoint(model, args.ckpt_path, - device=f"cuda:{gpu_no}") - model.eval() - model = model.cuda(gpu_no) # move residual buffers not in state_dict - print(f'>>> Load pre-trained model ...') + assert os.path.exists(args.ckpt_path), "Error: checkpoint Not Found!" - # Build unnomalizer + with profiler.profile_section("model_loading/checkpoint"): + model = load_model_checkpoint(model, args.ckpt_path, + device=f"cuda:{gpu_no}") + model.eval() + model = model.cuda(gpu_no) # move residual buffers not in state_dict + print(f'>>> Load pre-trained model ...') + + diffusion_autocast_dtype = None + if args.diffusion_dtype == "bf16": + maybe_cast_module( + model.model, + torch.bfloat16, + "diffusion backbone", + profiler=profiler, + profile_name="model_loading/diffusion_bf16", + ) + diffusion_autocast_dtype = torch.bfloat16 + print(">>> diffusion backbone set to bfloat16") + + if hasattr(model, "first_stage_model") and model.first_stage_model is not None: + vae_weight_dtype = torch.bfloat16 if args.vae_dtype == "bf16" else torch.float32 + maybe_cast_module( + model.first_stage_model, + vae_weight_dtype, + "VAE", + profiler=profiler, + profile_name="model_loading/vae_cast", + ) + model.vae_bf16 = args.vae_dtype == "bf16" + print(f">>> VAE dtype set to {args.vae_dtype}") + + # --- VAE performance optimizations --- + if hasattr(model, "first_stage_model") and model.first_stage_model is not None: + vae = model.first_stage_model + + # Channels-last memory format: cuDNN uses faster NHWC kernels + if args.vae_channels_last: + vae = vae.to(memory_format=torch.channels_last) + vae._channels_last = True + model.first_stage_model = vae + print(">>> VAE converted to channels_last (NHWC) memory format") + + # torch.compile: fuses GroupNorm+SiLU, conv chains, etc. + if args.vae_compile: + vae.decoder = torch.compile(vae.decoder, mode="reduce-overhead") + vae.encoder = torch.compile(vae.encoder, mode="reduce-overhead") + print(">>> VAE encoder/decoder compiled with torch.compile (reduce-overhead)") + + # Batch decode size + vae_decode_bs = args.vae_decode_bs if args.vae_decode_bs > 0 else 9999 + model.vae_decode_bs = vae_decode_bs + model.vae_encode_bs = vae_decode_bs + if args.vae_decode_bs > 0: + print(f">>> VAE encode/decode batch size set to {args.vae_decode_bs}") + else: + print(">>> VAE encode/decode batch size: all frames at once") + + encoder_mode = args.encoder_mode + encoder_bf16 = encoder_mode in ("autocast", "bf16_full") + encoder_weight_dtype = torch.bfloat16 if encoder_mode == "bf16_full" else torch.float32 + if hasattr(model, "cond_stage_model") and model.cond_stage_model is not None: + maybe_cast_module( + model.cond_stage_model, + encoder_weight_dtype, + "cond_stage_model", + profiler=profiler, + profile_name="model_loading/encoder_cond_cast", + ) + if hasattr(model, "embedder") and model.embedder is not None: + maybe_cast_module( + model.embedder, + encoder_weight_dtype, + "embedder", + profiler=profiler, + profile_name="model_loading/encoder_embedder_cast", + ) + model.encoder_bf16 = encoder_bf16 + model.encoder_mode = encoder_mode + print( + f">>> encoder mode set to {encoder_mode} (weights={encoder_weight_dtype})" + ) + + projector_mode = args.projector_mode + projector_bf16 = projector_mode in ("autocast", "bf16_full") + projector_weight_dtype = torch.bfloat16 if projector_mode == "bf16_full" else torch.float32 + if hasattr(model, "image_proj_model") and model.image_proj_model is not None: + maybe_cast_module( + model.image_proj_model, + projector_weight_dtype, + "image_proj_model", + profiler=profiler, + profile_name="model_loading/projector_image_cast", + ) + if hasattr(model, "state_projector") and model.state_projector is not None: + maybe_cast_module( + model.state_projector, + projector_weight_dtype, + "state_projector", + profiler=profiler, + profile_name="model_loading/projector_state_cast", + ) + if hasattr(model, "action_projector") and model.action_projector is not None: + maybe_cast_module( + model.action_projector, + projector_weight_dtype, + "action_projector", + profiler=profiler, + profile_name="model_loading/projector_action_cast", + ) + if hasattr(model, "projector_bf16"): + model.projector_bf16 = projector_bf16 + model.projector_mode = projector_mode + print( + f">>> projector mode set to {projector_mode} (weights={projector_weight_dtype})" + ) + + log_inference_precision(model) + + if args.export_casted_ckpt: + metadata = { + "diffusion_dtype": args.diffusion_dtype, + "vae_dtype": args.vae_dtype, + "encoder_mode": args.encoder_mode, + "projector_mode": args.projector_mode, + "perframe_ae": args.perframe_ae, + } + save_casted_checkpoint(model, args.export_casted_ckpt, metadata) + if args.export_only: + print(">>> export_only set; skipping inference.") + return + + # Save prepared model for fast loading next time + if prepared_path: + print(f">>> Saving prepared model to {prepared_path} ...") + torch.save(model, prepared_path) + print(f">>> Prepared model saved ({os.path.getsize(prepared_path) / 1024**3:.1f} GB).") + + # Build normalizer (always needed, independent of model loading path) logging.info("***** Configing Data *****") with profiler.profile_section("data_loading"): data = instantiate_from_config(config.data) @@ -1051,130 +1198,6 @@ def run_inference(args: argparse.Namespace, gpu_num: int, gpu_no: int) -> None: print(">>> Dataset is successfully loaded ...") device = get_device_from_parameters(model) - diffusion_autocast_dtype = None - if args.diffusion_dtype == "bf16": - maybe_cast_module( - model.model, - torch.bfloat16, - "diffusion backbone", - profiler=profiler, - profile_name="model_loading/diffusion_bf16", - ) - diffusion_autocast_dtype = torch.bfloat16 - print(">>> diffusion backbone set to bfloat16") - - if hasattr(model, "first_stage_model") and model.first_stage_model is not None: - vae_weight_dtype = torch.bfloat16 if args.vae_dtype == "bf16" else torch.float32 - maybe_cast_module( - model.first_stage_model, - vae_weight_dtype, - "VAE", - profiler=profiler, - profile_name="model_loading/vae_cast", - ) - model.vae_bf16 = args.vae_dtype == "bf16" - print(f">>> VAE dtype set to {args.vae_dtype}") - - # --- VAE performance optimizations --- - if hasattr(model, "first_stage_model") and model.first_stage_model is not None: - vae = model.first_stage_model - - # Channels-last memory format: cuDNN uses faster NHWC kernels - if args.vae_channels_last: - vae = vae.to(memory_format=torch.channels_last) - vae._channels_last = True - model.first_stage_model = vae - print(">>> VAE converted to channels_last (NHWC) memory format") - - # torch.compile: fuses GroupNorm+SiLU, conv chains, etc. - if args.vae_compile: - vae.decoder = torch.compile(vae.decoder, mode="reduce-overhead") - vae.encoder = torch.compile(vae.encoder, mode="reduce-overhead") - print(">>> VAE encoder/decoder compiled with torch.compile (reduce-overhead)") - - # Batch decode size - vae_decode_bs = args.vae_decode_bs if args.vae_decode_bs > 0 else 9999 - model.vae_decode_bs = vae_decode_bs - model.vae_encode_bs = vae_decode_bs - if args.vae_decode_bs > 0: - print(f">>> VAE encode/decode batch size set to {args.vae_decode_bs}") - else: - print(">>> VAE encode/decode batch size: all frames at once") - - encoder_mode = args.encoder_mode - encoder_bf16 = encoder_mode in ("autocast", "bf16_full") - encoder_weight_dtype = torch.bfloat16 if encoder_mode == "bf16_full" else torch.float32 - if hasattr(model, "cond_stage_model") and model.cond_stage_model is not None: - maybe_cast_module( - model.cond_stage_model, - encoder_weight_dtype, - "cond_stage_model", - profiler=profiler, - profile_name="model_loading/encoder_cond_cast", - ) - if hasattr(model, "embedder") and model.embedder is not None: - maybe_cast_module( - model.embedder, - encoder_weight_dtype, - "embedder", - profiler=profiler, - profile_name="model_loading/encoder_embedder_cast", - ) - model.encoder_bf16 = encoder_bf16 - model.encoder_mode = encoder_mode - print( - f">>> encoder mode set to {encoder_mode} (weights={encoder_weight_dtype})" - ) - - projector_mode = args.projector_mode - projector_bf16 = projector_mode in ("autocast", "bf16_full") - projector_weight_dtype = torch.bfloat16 if projector_mode == "bf16_full" else torch.float32 - if hasattr(model, "image_proj_model") and model.image_proj_model is not None: - maybe_cast_module( - model.image_proj_model, - projector_weight_dtype, - "image_proj_model", - profiler=profiler, - profile_name="model_loading/projector_image_cast", - ) - if hasattr(model, "state_projector") and model.state_projector is not None: - maybe_cast_module( - model.state_projector, - projector_weight_dtype, - "state_projector", - profiler=profiler, - profile_name="model_loading/projector_state_cast", - ) - if hasattr(model, "action_projector") and model.action_projector is not None: - maybe_cast_module( - model.action_projector, - projector_weight_dtype, - "action_projector", - profiler=profiler, - profile_name="model_loading/projector_action_cast", - ) - if hasattr(model, "projector_bf16"): - model.projector_bf16 = projector_bf16 - model.projector_mode = projector_mode - print( - f">>> projector mode set to {projector_mode} (weights={projector_weight_dtype})" - ) - - log_inference_precision(model) - - if args.export_casted_ckpt: - metadata = { - "diffusion_dtype": args.diffusion_dtype, - "vae_dtype": args.vae_dtype, - "encoder_mode": args.encoder_mode, - "projector_mode": args.projector_mode, - "perframe_ae": args.perframe_ae, - } - save_casted_checkpoint(model, args.export_casted_ckpt, metadata) - if args.export_only: - print(">>> export_only set; skipping inference.") - return - profiler.record_memory("after_model_load") # Run over data