From bf4d66c87445f05f333259c2392f4d9ee493e68e Mon Sep 17 00:00:00 2001 From: qhy <2728290997@qq.com> Date: Tue, 10 Feb 2026 19:36:17 +0800 Subject: [PATCH] =?UTF-8?q?=E8=B7=B3=E8=BF=87=E6=A8=A1=E5=9E=8B=E5=8A=A0?= =?UTF-8?q?=E8=BD=BD?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- scripts/evaluation/world_model_interaction.py | 41 ++++++++++++++----- 1 file changed, 30 insertions(+), 11 deletions(-) diff --git a/scripts/evaluation/world_model_interaction.py b/scripts/evaluation/world_model_interaction.py index 1b95a13..ad103a7 100644 --- a/scripts/evaluation/world_model_interaction.py +++ b/scripts/evaluation/world_model_interaction.py @@ -462,24 +462,43 @@ def run_inference(args: argparse.Namespace, gpu_num: int, gpu_no: int) -> None: csv_path = os.path.join(args.prompt_dir, f"{args.dataset}.csv") df = pd.read_csv(csv_path) - # Load config + # Load config (always needed for data setup) config = OmegaConf.load(args.config) - config['model']['params']['wma_config']['params'][ - 'use_checkpoint'] = False - model = instantiate_from_config(config.model) - model.perframe_ae = args.perframe_ae - assert os.path.exists(args.ckpt_path), "Error: checkpoint Not Found!" - model = load_model_checkpoint(model, args.ckpt_path) - model.eval() - print(f'>>> Load pre-trained model ...') - # Build unnomalizer + prepared_path = args.ckpt_path + ".prepared.pt" + if os.path.exists(prepared_path): + # ---- Fast path: load the fully-prepared model ---- + print(f">>> Loading prepared model from {prepared_path} ...") + model = torch.load(prepared_path, + map_location=f"cuda:{gpu_no}", + weights_only=False, + mmap=True) + model.eval() + print(f">>> Prepared model loaded.") + else: + # ---- Normal path: construct + load checkpoint ---- + config['model']['params']['wma_config']['params'][ + 'use_checkpoint'] = False + model = instantiate_from_config(config.model) + model.perframe_ae = args.perframe_ae + + assert os.path.exists(args.ckpt_path), "Error: checkpoint Not Found!" + model = load_model_checkpoint(model, args.ckpt_path) + model.eval() + model = model.cuda(gpu_no) + print(f'>>> Load pre-trained model ...') + + # Save prepared model for fast loading next time + print(f">>> Saving prepared model to {prepared_path} ...") + torch.save(model, prepared_path) + print(f">>> Prepared model saved ({os.path.getsize(prepared_path) / 1024**3:.1f} GB).") + + # Build normalizer (always needed, independent of model loading path) logging.info("***** Configing Data *****") data = instantiate_from_config(config.data) data.setup() print(">>> Dataset is successfully loaded ...") - model = model.cuda(gpu_no) device = get_device_from_parameters(model) # Run over data