diff --git a/scripts/evaluation/world_model_interaction.py b/scripts/evaluation/world_model_interaction.py
index cb25f2e..50182df 100644
--- a/scripts/evaluation/world_model_interaction.py
+++ b/scripts/evaluation/world_model_interaction.py
@@ -559,6 +559,7 @@ def image_guided_synthesis_sim_mode(
         autocast_ctx = nullcontext()
 
     batch_variants = None
+    samples = None
     if ddim_sampler is not None:
         with autocast_ctx:
             samples, actions, states, intermedia = ddim_sampler.sample(
@@ -583,7 +584,7 @@ def image_guided_synthesis_sim_mode(
             batch_images = model.decode_first_stage(samples)
             batch_variants = batch_images
 
-    return batch_variants, actions, states
+    return batch_variants, actions, states, samples
 
 
 def run_inference(args: argparse.Namespace, gpu_num: int, gpu_no: int) -> None:
@@ -693,7 +694,7 @@ def run_inference(args: argparse.Namespace, gpu_num: int, gpu_no: int) -> None:
             sample_save_dir = f'{video_save_dir}/wm/{fs}'
             os.makedirs(sample_save_dir, exist_ok=True)
             # For collecting interaction videos
-            wm_video = []
+            wm_latent = []
             # Initialize observation queues
             cond_obs_queues = {
                 "observation.images.top":
@@ -749,7 +750,7 @@ def run_inference(args: argparse.Namespace, gpu_num: int, gpu_no: int) -> None:
 
                 # Use world-model in policy to generate action
                 print(f'>>> Step {itr}: generating actions ...')
-                pred_videos_0, pred_actions, _ = image_guided_synthesis_sim_mode(
+                pred_videos_0, pred_actions, _, _ = image_guided_synthesis_sim_mode(
                     model,
                     sample['instruction'],
                     observation,
@@ -791,7 +792,7 @@ def run_inference(args: argparse.Namespace, gpu_num: int, gpu_no: int) -> None:
 
                 # Interaction with the world-model
                 print(f'>>> Step {itr}: interacting with world model ...')
-                pred_videos_1, _, pred_states = image_guided_synthesis_sim_mode(
+                pred_videos_1, _, pred_states, wm_samples = image_guided_synthesis_sim_mode(
                     model,
                     "",
                     observation,
@@ -804,12 +805,16 @@ def run_inference(args: argparse.Namespace, gpu_num: int, gpu_no: int) -> None:
                     fs=model_input_fs,
                     text_input=False,
                     timestep_spacing=args.timestep_spacing,
-                    guidance_rescale=args.guidance_rescale)
+                    guidance_rescale=args.guidance_rescale,
+                    decode_video=False)
+
+                # Decode only the last frame for CLIP embedding in next iteration
+                last_frame_pixel = model.decode_first_stage(wm_samples[:, :, -1:, :, :])
 
                 for idx in range(args.exe_steps):
                     observation = {
                         'observation.images.top':
-                            pred_videos_1[0][:, idx:idx + 1].permute(1, 0, 2, 3),
+                            last_frame_pixel[0, :, 0:1].permute(1, 0, 2, 3),
                         'observation.state':
                             torch.zeros_like(pred_states[0][idx:idx + 1])
                             if args.zero_pred_state else pred_states[0][idx:idx + 1],
@@ -827,30 +832,14 @@ def run_inference(args: argparse.Namespace, gpu_num: int, gpu_no: int) -> None:
                                    pred_videos_0,
                                    sample_tag,
                                    fps=args.save_fps)
-                # Save videos environment changes via world-model interaction
-                sample_tag = f"{args.dataset}-vid{sample['videoid']}-wd-fs-{fs}/itr-{itr}"
-                log_to_tensorboard(writer,
-                                   pred_videos_1,
-                                   sample_tag,
-                                   fps=args.save_fps)
-
-                # Save the imagen videos for decision-making
-                if pred_videos_0 is not None:
-                    sample_video_file = f'{video_save_dir}/dm/{fs}/itr-{itr}.mp4'
-                    save_results(pred_videos_0.cpu(),
-                                 sample_video_file,
-                                 fps=args.save_fps)
-                # Save videos environment changes via world-model interaction
-                sample_video_file = f'{video_save_dir}/wm/{fs}/itr-{itr}.mp4'
-                save_results(pred_videos_1.cpu(),
-                             sample_video_file,
-                             fps=args.save_fps)
                 print('>' * 24)
 
-                # Collect the result of world-model interactions
-                wm_video.append(pred_videos_1[:, :, :args.exe_steps].cpu())
+                # Store raw latent for deferred decode
+                wm_latent.append(wm_samples[:, :, :args.exe_steps].cpu())
 
-            full_video = torch.cat(wm_video, dim=2)
+            # Deferred decode: batch decode all stored latents
+            full_latent = torch.cat(wm_latent, dim=2).to(device)
+            full_video = model.decode_first_stage(full_latent).cpu()
             sample_tag = f"{args.dataset}-vid{sample['videoid']}-wd-fs-{fs}/full"
             log_to_tensorboard(writer,
                                full_video,
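
Note on the pattern this patch implements: instead of decoding every predicted clip to pixels at each interaction step, it keeps the sampler's latents (skipping per-step decode via `decode_video=False`), decodes only the single last frame that the next iteration needs as its conditioning observation, and runs one batched `decode_first_stage` over all accumulated latents at the end. Below is a minimal self-contained sketch of that deferred-decode loop. `TinyWorldModel`, its 1x1 conv "decoder", and all tensor shapes are stand-ins invented for illustration; only the `decode_first_stage` name and the buffering logic mirror the patch.

```python
import torch
import torch.nn as nn


class TinyWorldModel(nn.Module):
    """Stub for the real model; only decode_first_stage matters here."""

    def __init__(self, latent_ch: int = 4, pixel_ch: int = 3):
        super().__init__()
        # Stand-in for the VAE decoder: latents (B, C, T, H, W) -> pixels.
        self.decoder = nn.Conv3d(latent_ch, pixel_ch, kernel_size=1)

    def decode_first_stage(self, z: torch.Tensor) -> torch.Tensor:
        return self.decoder(z)


device = "cuda" if torch.cuda.is_available() else "cpu"
model = TinyWorldModel().to(device)

exe_steps, n_iters = 4, 3
wm_latent = []  # CPU buffer of undecoded latents, as in the patch

for itr in range(n_iters):
    # Pretend the sampler produced a latent clip (B, C, T, H, W).
    wm_samples = torch.randn(1, 4, 8, 16, 16, device=device)

    # Decode only the last frame: it is the sole frame the next
    # iteration needs as its image observation.
    last_frame_pixel = model.decode_first_stage(wm_samples[:, :, -1:])

    # Defer decoding of the executed frames; park the latents on CPU.
    wm_latent.append(wm_samples[:, :, :exe_steps].cpu())

# One batched decode at the end replaces n_iters per-step decodes.
full_latent = torch.cat(wm_latent, dim=2).to(device)
full_video = model.decode_first_stage(full_latent).cpu()
print(full_video.shape)  # (1, 3, exe_steps * n_iters, 16, 16)
```

The trade-off is memory for compute: latents are much smaller than decoded frames, so buffering them on CPU and decoding once amortizes the decoder cost across the whole rollout, at the price of losing per-iteration pixel videos (hence the removal of the per-`itr` `save_results`/`log_to_tensorboard` calls in the patch).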