import argparse, os, glob from contextlib import nullcontext import pandas as pd import random import torch import torchvision import h5py import numpy as np import logging import einops import warnings import imageio from pytorch_lightning import seed_everything from omegaconf import OmegaConf from tqdm import tqdm from einops import rearrange, repeat from collections import OrderedDict from torch import nn from eval_utils import populate_queues, log_to_tensorboard from collections import deque from torch import Tensor from torch.utils.tensorboard import SummaryWriter from PIL import Image from unifolm_wma.models.samplers.ddim import DDIMSampler from unifolm_wma.utils.utils import instantiate_from_config import torch.nn.functional as F def patch_norm_bypass_autocast(): """Monkey-patch GroupNorm and LayerNorm to bypass autocast's fp32 policy. This eliminates bf16->fp32->bf16 dtype conversions during UNet forward.""" def _group_norm_forward(self, x): with torch.amp.autocast('cuda', enabled=False): return F.group_norm( x, self.num_groups, self.weight.to(x.dtype) if self.weight is not None else None, self.bias.to(x.dtype) if self.bias is not None else None, self.eps) def _layer_norm_forward(self, x): with torch.amp.autocast('cuda', enabled=False): return F.layer_norm( x, self.normalized_shape, self.weight.to(x.dtype) if self.weight is not None else None, self.bias.to(x.dtype) if self.bias is not None else None, self.eps) torch.nn.GroupNorm.forward = _group_norm_forward torch.nn.LayerNorm.forward = _layer_norm_forward def get_device_from_parameters(module: nn.Module) -> torch.device: """Get a module's device by checking one of its parameters. Args: module (nn.Module): The model whose device is to be inferred. Returns: torch.device: The device of the model's parameters. """ return next(iter(module.parameters())).device def apply_precision_settings(model: nn.Module, args: argparse.Namespace) -> nn.Module: """Apply precision settings to model components based on command-line arguments. Args: model (nn.Module): The model to apply precision settings to. args (argparse.Namespace): Parsed command-line arguments containing precision settings. Returns: nn.Module: Model with precision settings applied. """ print(f">>> Applying precision settings:") print(f" - Diffusion dtype: {args.diffusion_dtype}") print(f" - Projector mode: {args.projector_mode}") print(f" - Encoder mode: {args.encoder_mode}") print(f" - VAE dtype: {args.vae_dtype}") # 1. Set Diffusion backbone precision if args.diffusion_dtype == "bf16": # Convert diffusion model weights to bf16 model.model.to(torch.bfloat16) model.diffusion_autocast_dtype = torch.bfloat16 print(" ✓ Diffusion model weights converted to bfloat16") else: model.diffusion_autocast_dtype = torch.bfloat16 print(" ✓ Diffusion model using fp32") # 2. Set Projector precision if args.projector_mode == "bf16_full": model.state_projector.to(torch.bfloat16) model.action_projector.to(torch.bfloat16) model.projector_autocast_dtype = None print(" ✓ Projectors converted to bfloat16") elif args.projector_mode == "autocast": model.projector_autocast_dtype = torch.bfloat16 print(" ✓ Projectors will use autocast (weights fp32, compute bf16)") else: model.projector_autocast_dtype = None # fp32 mode: do nothing, keep original precision # 3. Set Encoder precision if args.encoder_mode == "bf16_full": model.embedder.to(torch.bfloat16) model.image_proj_model.to(torch.bfloat16) model.encoder_autocast_dtype = None print(" ✓ Encoders converted to bfloat16") elif args.encoder_mode == "autocast": model.encoder_autocast_dtype = torch.bfloat16 print(" ✓ Encoders will use autocast (weights fp32, compute bf16)") else: model.encoder_autocast_dtype = None # fp32 mode: do nothing, keep original precision # 4. Set VAE precision if args.vae_dtype == "bf16": model.first_stage_model.to(torch.bfloat16) print(" ✓ VAE converted to bfloat16") else: print(" ✓ VAE kept in fp32 for best quality") # 5. Safety net: ensure no fp32 parameters remain when all components are bf16 if args.diffusion_dtype == "bf16": fp32_params = [(n, p) for n, p in model.named_parameters() if p.dtype == torch.float32] if fp32_params: print(f" ⚠ Found {len(fp32_params)} fp32 params, converting to bf16") for name, param in fp32_params: param.data = param.data.to(torch.bfloat16) print(" ✓ All parameters converted to bfloat16") return model def write_video(video_path: str, stacked_frames: list, fps: int) -> None: """Save a list of frames to a video file. Args: video_path (str): Output path for the video. stacked_frames (list): List of image frames. fps (int): Frames per second for the video. """ with warnings.catch_warnings(): warnings.filterwarnings("ignore", "pkg_resources is deprecated as an API", category=DeprecationWarning) imageio.mimsave(video_path, stacked_frames, fps=fps) def get_filelist(data_dir: str, postfixes: list[str]) -> list[str]: """Return sorted list of files in a directory matching specified postfixes. Args: data_dir (str): Directory path to search in. postfixes (list[str]): List of file extensions to match. Returns: list[str]: Sorted list of file paths. """ patterns = [ os.path.join(data_dir, f"*.{postfix}") for postfix in postfixes ] file_list = [] for pattern in patterns: file_list.extend(glob.glob(pattern)) file_list.sort() return file_list def load_model_checkpoint(model: nn.Module, ckpt: str) -> nn.Module: """Load model weights from checkpoint file. Args: model (nn.Module): Model instance. ckpt (str): Path to the checkpoint file. Returns: nn.Module: Model with loaded weights. """ state_dict = torch.load(ckpt, map_location="cpu") if "state_dict" in list(state_dict.keys()): state_dict = state_dict["state_dict"] try: model.load_state_dict(state_dict, strict=True) except: new_pl_sd = OrderedDict() for k, v in state_dict.items(): new_pl_sd[k] = v for k in list(new_pl_sd.keys()): if "framestride_embed" in k: new_key = k.replace("framestride_embed", "fps_embedding") new_pl_sd[new_key] = new_pl_sd[k] del new_pl_sd[k] model.load_state_dict(new_pl_sd, strict=True) else: new_pl_sd = OrderedDict() for key in state_dict['module'].keys(): new_pl_sd[key[16:]] = state_dict['module'][key] model.load_state_dict(new_pl_sd) print('>>> model checkpoint loaded.') return model def is_inferenced(save_dir: str, filename: str) -> bool: """Check if a given filename has already been processed and saved. Args: save_dir (str): Directory where results are saved. filename (str): Name of the file to check. Returns: bool: True if processed file exists, False otherwise. """ video_file = os.path.join(save_dir, "samples_separate", f"{filename[:-4]}_sample0.mp4") return os.path.exists(video_file) def save_results(video: Tensor, filename: str, fps: int = 8) -> None: """Save video tensor to file using torchvision. Args: video (Tensor): Tensor of shape (B, C, T, H, W). filename (str): Output file path. fps (int, optional): Frames per second. Defaults to 8. """ video = video.detach().cpu() video = torch.clamp(video.float(), -1., 1.) n = video.shape[0] video = video.permute(2, 0, 1, 3, 4) frame_grids = [ torchvision.utils.make_grid(framesheet, nrow=int(n), padding=0) for framesheet in video ] grid = torch.stack(frame_grids, dim=0) grid = (grid + 1.0) / 2.0 grid = (grid * 255).to(torch.uint8).permute(0, 2, 3, 1) torchvision.io.write_video(filename, grid, fps=fps, video_codec='h264', options={'crf': '10'}) def get_init_frame_path(data_dir: str, sample: dict) -> str: """Construct the init_frame path from directory and sample metadata. Args: data_dir (str): Base directory containing videos. sample (dict): Dictionary containing 'data_dir' and 'videoid'. Returns: str: Full path to the video file. """ rel_video_fp = os.path.join(sample['data_dir'], str(sample['videoid']) + '.png') full_image_fp = os.path.join(data_dir, 'images', rel_video_fp) return full_image_fp def get_transition_path(data_dir: str, sample: dict) -> str: """Construct the full transition file path from directory and sample metadata. Args: data_dir (str): Base directory containing transition files. sample (dict): Dictionary containing 'data_dir' and 'videoid'. Returns: str: Full path to the HDF5 transition file. """ rel_transition_fp = os.path.join(sample['data_dir'], str(sample['videoid']) + '.h5') full_transition_fp = os.path.join(data_dir, 'transitions', rel_transition_fp) return full_transition_fp def prepare_init_input(start_idx: int, init_frame_path: str, transition_dict: dict[str, torch.Tensor], frame_stride: int, wma_data, video_length: int = 16, n_obs_steps: int = 2) -> dict[str, Tensor]: """ Extracts a structured sample from a video sequence including frames, states, and actions, along with properly padded observations and pre-processed tensors for model input. Args: start_idx (int): Starting frame index for the current clip. video: decord video instance. transition_dict (Dict[str, Tensor]): Dictionary containing tensors for 'action', 'observation.state', 'action_type', 'state_type'. frame_stride (int): Temporal stride between sampled frames. wma_data: Object that holds configuration and utility functions like normalization, transformation, and resolution info. video_length (int, optional): Number of frames to sample from the video. Default is 16. n_obs_steps (int, optional): Number of historical steps for observations. Default is 2. """ indices = [start_idx + frame_stride * i for i in range(video_length)] init_frame = Image.open(init_frame_path).convert('RGB') init_frame = torch.tensor(np.array(init_frame)).unsqueeze(0).permute( 3, 0, 1, 2).float() if start_idx < n_obs_steps - 1: state_indices = list(range(0, start_idx + 1)) states = transition_dict['observation.state'][state_indices, :] num_padding = n_obs_steps - 1 - start_idx first_slice = states[0:1, :] # (t, d) padding = first_slice.repeat(num_padding, 1) states = torch.cat((padding, states), dim=0) else: state_indices = list(range(start_idx - n_obs_steps + 1, start_idx + 1)) states = transition_dict['observation.state'][state_indices, :] actions = transition_dict['action'][indices, :] ori_state_dim = states.shape[-1] ori_action_dim = actions.shape[-1] frames_action_state_dict = { 'action': actions, 'observation.state': states, } frames_action_state_dict = wma_data.normalizer(frames_action_state_dict) frames_action_state_dict = wma_data.get_uni_vec( frames_action_state_dict, transition_dict['action_type'], transition_dict['state_type'], ) if wma_data.spatial_transform is not None: init_frame = wma_data.spatial_transform(init_frame) init_frame = (init_frame / 255 - 0.5) * 2 data = { 'observation.image': init_frame, } data.update(frames_action_state_dict) return data, ori_state_dim, ori_action_dim def get_latent_z(model, videos: Tensor) -> Tensor: """ Extracts latent features from a video batch using the model's first-stage encoder. Args: model: the world model. videos (Tensor): Input videos of shape [B, C, T, H, W]. Returns: Tensor: Latent video tensor of shape [B, C, T, H, W]. """ b, c, t, h, w = videos.shape x = rearrange(videos, 'b c t h w -> (b t) c h w') # Auto-detect VAE dtype and convert input vae_dtype = next(model.first_stage_model.parameters()).dtype x = x.to(dtype=vae_dtype) z = model.encode_first_stage(x) z = rearrange(z, '(b t) c h w -> b c t h w', b=b, t=t) return z def preprocess_observation( model, observations: dict[str, np.ndarray]) -> dict[str, Tensor]: """Convert environment observation to LeRobot format observation. Args: observation: Dictionary of observation batches from a Gym vector environment. Returns: Dictionary of observation batches with keys renamed to LeRobot format and values as tensors. """ # Map to expected inputs for the policy return_observations = {} if isinstance(observations["pixels"], dict): imgs = { f"observation.images.{key}": img for key, img in observations["pixels"].items() } else: imgs = {"observation.images.top": observations["pixels"]} for imgkey, img in imgs.items(): img = torch.from_numpy(img) # Sanity check that images are channel last _, h, w, c = img.shape assert c < h and c < w, f"expect channel first images, but instead {img.shape}" # Sanity check that images are uint8 assert img.dtype == torch.uint8, f"expect torch.uint8, but instead {img.dtype=}" # Convert to channel first of type float32 in range [0,1] img = einops.rearrange(img, "b h w c -> b c h w").contiguous() img = img.type(torch.float32) return_observations[imgkey] = img return_observations["observation.state"] = torch.from_numpy( observations["agent_pos"]).float() return_observations['observation.state'] = model.normalize_inputs({ 'observation.state': return_observations['observation.state'].to(model.device) })['observation.state'] return return_observations def image_guided_synthesis_sim_mode( model: torch.nn.Module, prompts: list[str], observation: dict, noise_shape: tuple[int, int, int, int, int], action_cond_step: int = 16, n_samples: int = 1, ddim_steps: int = 50, ddim_eta: float = 1.0, unconditional_guidance_scale: float = 1.0, fs: int | None = None, text_input: bool = True, timestep_spacing: str = 'uniform', guidance_rescale: float = 0.0, sim_mode: bool = True, **kwargs) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: """ Performs image-guided video generation in a simulation-style mode with optional multimodal guidance (image, state, action, text). Args: model (torch.nn.Module): The diffusion-based generative model with multimodal conditioning. prompts (list[str]): A list of textual prompts to guide the synthesis process. observation (dict): A dictionary containing observed inputs including: - 'observation.images.top': Tensor of shape [B, O, C, H, W] (top-down images) - 'observation.state': Tensor of shape [B, O, D] (state vector) - 'action': Tensor of shape [B, T, D] (action sequence) noise_shape (tuple[int, int, int, int, int]): Shape of the latent variable to generate, typically (B, C, T, H, W). action_cond_step (int): Number of time steps where action conditioning is applied. Default is 16. n_samples (int): Number of samples to generate (unused here, always generates 1). Default is 1. ddim_steps (int): Number of DDIM sampling steps. Default is 50. ddim_eta (float): DDIM eta parameter controlling the stochasticity. Default is 1.0. unconditional_guidance_scale (float): Scale for classifier-free guidance. If 1.0, guidance is off. fs (int | None): Frame index to condition on, broadcasted across the batch if specified. Default is None. text_input (bool): Whether to use text prompt as conditioning. If False, uses empty strings. Default is True. timestep_spacing (str): Timestep sampling method in DDIM sampler. Typically "uniform" or "linspace". guidance_rescale (float): Guidance rescaling factor to mitigate overexposure from classifier-free guidance. sim_mode (bool): Whether to perform world-model interaction or decision-making using the world-model. **kwargs: Additional arguments passed to the DDIM sampler. Returns: batch_variants (torch.Tensor): Predicted pixel-space video frames [B, C, T, H, W]. actions (torch.Tensor): Predicted action sequences [B, T, D] from diffusion decoding. states (torch.Tensor): Predicted state sequences [B, T, D] from diffusion decoding. """ b, _, t, _, _ = noise_shape ddim_sampler = DDIMSampler(model) batch_size = noise_shape[0] fs = torch.tensor([fs] * batch_size, dtype=torch.long, device=model.device) # Auto-detect model dtype and convert inputs accordingly model_dtype = next(model.embedder.parameters()).dtype img = observation['observation.images.top'].permute(0, 2, 1, 3, 4) cond_img = rearrange(img, 'b o c h w -> (b o) c h w')[-1:].to(dtype=model_dtype) # Encoder autocast: weights stay fp32, compute in bf16 enc_ac_dtype = getattr(model, 'encoder_autocast_dtype', None) if enc_ac_dtype is not None and model.device.type == 'cuda': enc_ctx = torch.autocast('cuda', dtype=enc_ac_dtype) else: enc_ctx = nullcontext() with enc_ctx: cond_img_emb = model.embedder(cond_img) cond_img_emb = model.image_proj_model(cond_img_emb) if model.model.conditioning_key == 'hybrid': z = get_latent_z(model, img.permute(0, 2, 1, 3, 4)) img_cat_cond = z[:, :, -1:, :, :] img_cat_cond = repeat(img_cat_cond, 'b c t h w -> b c (repeat t) h w', repeat=noise_shape[2]) cond = {"c_concat": [img_cat_cond]} if not text_input: prompts = [""] * batch_size cond_ins_emb = model.get_learned_conditioning(prompts) # Auto-detect projector dtype and convert inputs projector_dtype = next(model.state_projector.parameters()).dtype # Projector autocast: weights stay fp32, compute in bf16 proj_ac_dtype = getattr(model, 'projector_autocast_dtype', None) if proj_ac_dtype is not None and model.device.type == 'cuda': proj_ctx = torch.autocast('cuda', dtype=proj_ac_dtype) else: proj_ctx = nullcontext() with proj_ctx: cond_state_emb = model.state_projector(observation['observation.state'].to(dtype=projector_dtype)) cond_state_emb = cond_state_emb + model.agent_state_pos_emb cond_action_emb = model.action_projector(observation['action'].to(dtype=projector_dtype)) cond_action_emb = cond_action_emb + model.agent_action_pos_emb if not sim_mode: cond_action_emb = torch.zeros_like(cond_action_emb) cond["c_crossattn"] = [ torch.cat( [cond_state_emb, cond_action_emb, cond_ins_emb, cond_img_emb], dim=1) ] cond["c_crossattn_action"] = [ observation['observation.images.top'][:, :, -model.n_obs_steps_acting:], observation['observation.state'][:, -model.n_obs_steps_acting:], sim_mode, False, ] uc = None kwargs.update({"unconditional_conditioning_img_nonetext": None}) cond_mask = None cond_z0 = None # Setup autocast context for diffusion sampling autocast_dtype = getattr(model, 'diffusion_autocast_dtype', None) if autocast_dtype is not None and model.device.type == 'cuda': autocast_ctx = torch.autocast('cuda', dtype=autocast_dtype) else: autocast_ctx = nullcontext() if ddim_sampler is not None: with autocast_ctx: samples, actions, states, intermedia = ddim_sampler.sample( S=ddim_steps, conditioning=cond, batch_size=batch_size, shape=noise_shape[1:], verbose=False, unconditional_guidance_scale=unconditional_guidance_scale, unconditional_conditioning=uc, eta=ddim_eta, cfg_img=None, mask=cond_mask, x0=cond_z0, fs=fs, timestep_spacing=timestep_spacing, guidance_rescale=guidance_rescale, **kwargs) # Reconstruct from latent to pixel space batch_images = model.decode_first_stage(samples) batch_variants = batch_images return batch_variants, actions, states def run_inference(args: argparse.Namespace, gpu_num: int, gpu_no: int) -> None: """ Run inference pipeline on prompts and image inputs. Args: args (argparse.Namespace): Parsed command-line arguments. gpu_num (int): Number of GPUs. gpu_no (int): Index of the current GPU. Returns: None """ # Create inference and tensorboard dirs os.makedirs(args.savedir + '/inference', exist_ok=True) log_dir = args.savedir + f"/tensorboard" os.makedirs(log_dir, exist_ok=True) writer = SummaryWriter(log_dir=log_dir) # Load prompt csv_path = os.path.join(args.prompt_dir, f"{args.dataset}.csv") df = pd.read_csv(csv_path) # Load config config = OmegaConf.load(args.config) config['model']['params']['wma_config']['params'][ 'use_checkpoint'] = False model = instantiate_from_config(config.model) model.perframe_ae = args.perframe_ae assert os.path.exists(args.ckpt_path), "Error: checkpoint Not Found!" model = load_model_checkpoint(model, args.ckpt_path) model.eval() print(f'>>> Load pre-trained model ...') # Apply precision settings before moving to GPU model = apply_precision_settings(model, args) # Export precision-converted checkpoint if requested if args.export_precision_ckpt: export_path = args.export_precision_ckpt os.makedirs(os.path.dirname(export_path) or '.', exist_ok=True) torch.save({"state_dict": model.state_dict()}, export_path) print(f">>> Precision-converted checkpoint saved to: {export_path}") return # Build unnomalizer logging.info("***** Configing Data *****") data = instantiate_from_config(config.data) data.setup() print(">>> Dataset is successfully loaded ...") model = model.cuda(gpu_no) device = get_device_from_parameters(model) # Run over data assert (args.height % 16 == 0) and ( args.width % 16 == 0), "Error: image size [h,w] should be multiples of 16!" assert args.bs == 1, "Current implementation only support [batch size = 1]!" # Get latent noise shape h, w = args.height // 8, args.width // 8 channels = model.model.diffusion_model.out_channels n_frames = args.video_length print(f'>>> Generate {n_frames} frames under each generation ...') noise_shape = [args.bs, channels, n_frames, h, w] # Start inference for idx in range(0, len(df)): sample = df.iloc[idx] # Got initial frame path init_frame_path = get_init_frame_path(args.prompt_dir, sample) ori_fps = float(sample['fps']) video_save_dir = args.savedir + f"/inference/sample_{sample['videoid']}" os.makedirs(video_save_dir, exist_ok=True) os.makedirs(video_save_dir + '/dm', exist_ok=True) os.makedirs(video_save_dir + '/wm', exist_ok=True) # Load transitions to get the initial state later transition_path = get_transition_path(args.prompt_dir, sample) with h5py.File(transition_path, 'r') as h5f: transition_dict = {} for key in h5f.keys(): transition_dict[key] = torch.tensor(h5f[key][()]) for key in h5f.attrs.keys(): transition_dict[key] = h5f.attrs[key] # If many, test various frequence control and world-model generation for fs in args.frame_stride: # For saving imagens in policy sample_save_dir = f'{video_save_dir}/dm/{fs}' os.makedirs(sample_save_dir, exist_ok=True) # For saving environmental changes in world-model sample_save_dir = f'{video_save_dir}/wm/{fs}' os.makedirs(sample_save_dir, exist_ok=True) # For collecting interaction videos wm_video = [] # Initialize observation queues cond_obs_queues = { "observation.images.top": deque(maxlen=model.n_obs_steps_imagen), "observation.state": deque(maxlen=model.n_obs_steps_imagen), "action": deque(maxlen=args.video_length), } # Obtain initial frame and state start_idx = 0 model_input_fs = ori_fps // fs batch, ori_state_dim, ori_action_dim = prepare_init_input( start_idx, init_frame_path, transition_dict, fs, data.test_datasets[args.dataset], n_obs_steps=model.n_obs_steps_imagen) observation = { 'observation.images.top': batch['observation.image'].permute(1, 0, 2, 3)[-1].unsqueeze(0), 'observation.state': batch['observation.state'][-1].unsqueeze(0), 'action': torch.zeros_like(batch['action'][-1]).unsqueeze(0) } observation = { key: observation[key].to(device, non_blocking=True) for key in observation } # Update observation queues cond_obs_queues = populate_queues(cond_obs_queues, observation) # Multi-round interaction with the world-model for itr in tqdm(range(args.n_iter)): # Get observation observation = { 'observation.images.top': torch.stack(list( cond_obs_queues['observation.images.top']), dim=1).permute(0, 2, 1, 3, 4), 'observation.state': torch.stack(list(cond_obs_queues['observation.state']), dim=1), 'action': torch.stack(list(cond_obs_queues['action']), dim=1), } observation = { key: observation[key].to(device, non_blocking=True) for key in observation } # Use world-model in policy to generate action print(f'>>> Step {itr}: generating actions ...') pred_videos_0, pred_actions, _ = image_guided_synthesis_sim_mode( model, sample['instruction'], observation, noise_shape, action_cond_step=args.exe_steps, ddim_steps=args.ddim_steps, ddim_eta=args.ddim_eta, unconditional_guidance_scale=args. unconditional_guidance_scale, fs=model_input_fs, timestep_spacing=args.timestep_spacing, guidance_rescale=args.guidance_rescale, sim_mode=False) # Update future actions in the observation queues for idx in range(len(pred_actions[0])): observation = {'action': pred_actions[0][idx:idx + 1]} observation['action'][:, ori_action_dim:] = 0.0 cond_obs_queues = populate_queues(cond_obs_queues, observation) # Collect data for interacting the world-model using the predicted actions observation = { 'observation.images.top': torch.stack(list( cond_obs_queues['observation.images.top']), dim=1).permute(0, 2, 1, 3, 4), 'observation.state': torch.stack(list(cond_obs_queues['observation.state']), dim=1), 'action': torch.stack(list(cond_obs_queues['action']), dim=1), } observation = { key: observation[key].to(device, non_blocking=True) for key in observation } # Interaction with the world-model print(f'>>> Step {itr}: interacting with world model ...') pred_videos_1, _, pred_states = image_guided_synthesis_sim_mode( model, "", observation, noise_shape, action_cond_step=args.exe_steps, ddim_steps=args.ddim_steps, ddim_eta=args.ddim_eta, unconditional_guidance_scale=args. unconditional_guidance_scale, fs=model_input_fs, text_input=False, timestep_spacing=args.timestep_spacing, guidance_rescale=args.guidance_rescale) for idx in range(args.exe_steps): observation = { 'observation.images.top': pred_videos_1[0][:, idx:idx + 1].permute(1, 0, 2, 3), 'observation.state': torch.zeros_like(pred_states[0][idx:idx + 1]) if args.zero_pred_state else pred_states[0][idx:idx + 1], 'action': torch.zeros_like(pred_actions[0][-1:]) } observation['observation.state'][:, ori_state_dim:] = 0.0 cond_obs_queues = populate_queues(cond_obs_queues, observation) # Save the imagen videos for decision-making sample_tag = f"{args.dataset}-vid{sample['videoid']}-dm-fs-{fs}/itr-{itr}" log_to_tensorboard(writer, pred_videos_0, sample_tag, fps=args.save_fps) # Save videos environment changes via world-model interaction sample_tag = f"{args.dataset}-vid{sample['videoid']}-wd-fs-{fs}/itr-{itr}" log_to_tensorboard(writer, pred_videos_1, sample_tag, fps=args.save_fps) # Save the imagen videos for decision-making sample_video_file = f'{video_save_dir}/dm/{fs}/itr-{itr}.mp4' save_results(pred_videos_0.cpu(), sample_video_file, fps=args.save_fps) # Save videos environment changes via world-model interaction sample_video_file = f'{video_save_dir}/wm/{fs}/itr-{itr}.mp4' save_results(pred_videos_1.cpu(), sample_video_file, fps=args.save_fps) print('>' * 24) # Collect the result of world-model interactions wm_video.append(pred_videos_1[:, :, :args.exe_steps].cpu()) full_video = torch.cat(wm_video, dim=2) sample_tag = f"{args.dataset}-vid{sample['videoid']}-wd-fs-{fs}/full" log_to_tensorboard(writer, full_video, sample_tag, fps=args.save_fps) sample_full_video_file = f"{video_save_dir}/../{sample['videoid']}_full_fs{fs}.mp4" save_results(full_video, sample_full_video_file, fps=args.save_fps) def get_parser(): parser = argparse.ArgumentParser() parser.add_argument("--savedir", type=str, default=None, help="Path to save the results.") parser.add_argument("--ckpt_path", type=str, default=None, help="Path to the model checkpoint.") parser.add_argument("--config", type=str, help="Path to the model checkpoint.") parser.add_argument( "--prompt_dir", type=str, default=None, help="Directory containing videos and corresponding prompts.") parser.add_argument("--dataset", type=str, default=None, help="the name of dataset to test") parser.add_argument( "--ddim_steps", type=int, default=50, help="Number of DDIM steps. If non-positive, DDPM is used instead.") parser.add_argument( "--ddim_eta", type=float, default=1.0, help="Eta for DDIM sampling. Set to 0.0 for deterministic results.") parser.add_argument("--bs", type=int, default=1, help="Batch size for inference. Must be 1.") parser.add_argument("--height", type=int, default=320, help="Height of the generated images in pixels.") parser.add_argument("--width", type=int, default=512, help="Width of the generated images in pixels.") parser.add_argument( "--frame_stride", type=int, nargs='+', required=True, help= "frame stride control for 256 model (larger->larger motion), FPS control for 512 or 1024 model (smaller->larger motion)" ) parser.add_argument( "--unconditional_guidance_scale", type=float, default=1.0, help="Scale for classifier-free guidance during sampling.") parser.add_argument("--seed", type=int, default=123, help="Random seed for reproducibility.") parser.add_argument("--video_length", type=int, default=16, help="Number of frames in the generated video.") parser.add_argument("--num_generation", type=int, default=1, help="seed for seed_everything") parser.add_argument( "--timestep_spacing", type=str, default="uniform", help= "Strategy for timestep scaling. See Table 2 in the paper: 'Common Diffusion Noise Schedules and Sample Steps are Flawed' (https://huggingface.co/papers/2305.08891)." ) parser.add_argument( "--guidance_rescale", type=float, default=0.0, help= "Rescale factor for guidance as discussed in 'Common Diffusion Noise Schedules and Sample Steps are Flawed' (https://huggingface.co/papers/2305.08891)." ) parser.add_argument( "--perframe_ae", action='store_true', default=False, help= "Use per-frame autoencoder decoding to reduce GPU memory usage. Recommended for models with resolutions like 576x1024." ) parser.add_argument( "--n_action_steps", type=int, default=16, help="num of samples per prompt", ) parser.add_argument( "--exe_steps", type=int, default=16, help="num of samples to execute", ) parser.add_argument( "--n_iter", type=int, default=40, help="num of iteration to interact with the world model", ) parser.add_argument("--zero_pred_state", action='store_true', default=False, help="not using the predicted states as comparison") parser.add_argument("--save_fps", type=int, default=8, help="fps for the saving video") parser.add_argument( "--diffusion_dtype", type=str, choices=["fp32", "bf16"], default="bf16", help="Diffusion backbone precision (fp32/bf16)") parser.add_argument( "--projector_mode", type=str, choices=["fp32", "autocast", "bf16_full"], default="bf16_full", help="Projector precision mode (fp32/autocast/bf16_full)") parser.add_argument( "--encoder_mode", type=str, choices=["fp32", "autocast", "bf16_full"], default="bf16_full", help="Encoder precision mode (fp32/autocast/bf16_full)") parser.add_argument( "--vae_dtype", type=str, choices=["fp32", "bf16"], default="fp32", help="VAE precision (fp32/bf16, most affects image quality)") parser.add_argument( "--export_precision_ckpt", type=str, default=None, help="Export precision-converted checkpoint to this path, then exit.") return parser if __name__ == '__main__': patch_norm_bypass_autocast() parser = get_parser() args = parser.parse_args() seed = args.seed if seed < 0: seed = random.randint(0, 2**31) seed_everything(seed) rank, gpu_num = 0, 1 run_inference(args, gpu_num, rank)