import argparse, os, glob
import pandas as pd
import random
import torch
import torchvision
import h5py
import numpy as np
import logging
import einops
import warnings
import imageio
import atexit
import signal
import multiprocessing as mp
import time

from concurrent.futures import ThreadPoolExecutor
from queue import Empty, Queue
from pytorch_lightning import seed_everything
from omegaconf import OmegaConf
from tqdm import tqdm
from einops import rearrange, repeat
from collections import OrderedDict
from torch import nn
from eval_utils import populate_queues
from collections import deque
from typing import Optional, List, Any
from types import SimpleNamespace

torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True

from torch import Tensor
from PIL import Image

from unifolm_wma.models.samplers.ddim import DDIMSampler
from unifolm_wma.utils.utils import instantiate_from_config


def get_device_from_parameters(module: nn.Module) -> torch.device:
    """Get a module's device by checking one of its parameters.

    Args:
        module (nn.Module): The model whose device is to be inferred.

    Returns:
        torch.device: The device of the first parameter yielded by
            ``module.parameters()``.

    Raises:
        StopIteration: If the module has no parameters.
    """
    return next(iter(module.parameters())).device


def clone_observation_queues(
        queues: dict[str, deque]) -> dict[str, deque]:
    """Deep-clone queue tensors so pipeline branches can diverge safely.

    Tensor items are cloned; non-tensor items are carried over by reference.
    Each new deque keeps the ``maxlen`` of its source.
    """
    return {
        key: deque(
            ((item.clone() if torch.is_tensor(item) else item)
             for item in queue),
            maxlen=queue.maxlen)
        for key, queue in queues.items()
    }


def clone_observation_queues_to_cpu(
        queues: dict[str, deque]) -> dict[str, deque]:
    """Return a detached CPU copy of every queue.

    The trailing ``clone()`` guarantees an independent copy even for tensors
    that already live on the CPU, where ``.cpu()`` returns the same object.
    """
    return {
        key: deque(
            (item.detach().cpu().clone() if torch.is_tensor(item) else item
             for item in queue),
            maxlen=queue.maxlen)
        for key, queue in queues.items()
    }


def move_observation_queues_to_device(
        queues: dict[str, deque],
        device: torch.device) -> dict[str, deque]:
    """Return new queues whose tensor items have been moved to ``device``.

    Non-tensor items are carried over unchanged; ``maxlen`` is preserved.
    """
    return {
        key: deque(
            ((item.to(device, non_blocking=True)
              if torch.is_tensor(item) else item) for item in queue),
            maxlen=queue.maxlen)
        for key, queue in queues.items()
    }


def sync_module_device_attributes(module: nn.Module,
                                  device: torch.device) -> None:
    """Align cached `.device` attributes with the actual target device."""
    for submodule in module.modules():
        if hasattr(submodule, 'device'):
            try:
                setattr(submodule, 'device', device)
            except Exception:
                # Best effort: `device` may not be assignable on some
                # submodules (e.g. a read-only property); leave those as is.
                pass


def pipeline_print(message: str) -> None:
    """Print ``message`` and flush stdout immediately."""
    print(message, flush=True)


def build_observation_from_queues(
        queues: dict[str, deque],
        device: torch.device) -> dict[str, torch.Tensor]:
    """Stack the queued frames/states/actions into model-ready tensors.

    Every queue is stacked along a new dim 1. The image stack is additionally
    permuted with ``(0, 2, 1, 3, 4)``.
    NOTE(review): assuming each queued image is (B, C, H, W), the image
    tensor ends up as (B, C, O, H, W) -- confirm against ``populate_queues``.

    Args:
        queues: Queues keyed by 'observation.images.top',
            'observation.state' and 'action'.
        device: Target device for the returned tensors.

    Returns:
        Dict with the same three keys, each value moved to ``device``.
    """
    observation = {
        'observation.images.top':
        torch.stack(list(queues['observation.images.top']),
                    dim=1).permute(0, 2, 1, 3, 4),
        'observation.state':
        torch.stack(list(queues['observation.state']), dim=1),
        'action':
        torch.stack(list(queues['action']), dim=1),
    }
    return {
        key: value.to(device, non_blocking=True)
        for key, value in observation.items()
    }


def append_action_sequence(
        queues: dict[str, deque],
        action_seq: torch.Tensor,
        ori_action_dim: int) -> dict[str, deque]:
    """Push every step of ``action_seq`` (batch element 0) into the queues.

    Each step is cloned, its trailing dimensions from ``ori_action_dim``
    onward are zeroed, and the frame is handed to ``populate_queues`` under
    the 'action' key.

    Args:
        queues: Observation queues to extend.
        action_seq: Action tensor indexed as ``[batch, step, dim]``.
        ori_action_dim: Number of leading action dimensions to keep.

    Returns:
        The queues as returned by the last ``populate_queues`` call.
    """
    for idx in range(action_seq.shape[1]):
        action_frame = action_seq[0][idx:idx + 1].clone()
        action_frame[:, ori_action_dim:] = 0.0
        queues = populate_queues(queues, {'action': action_frame})
    return queues


def rollout_execution_segment(
        queues: dict[str, deque],
        seg_video: torch.Tensor,
        pred_states: torch.Tensor,
        zero_action_template: torch.Tensor,
        exe_steps: int,
        ori_state_dim: int,
        zero_pred_state: bool) -> dict[str, deque]:
    """Feed ``exe_steps`` predicted frames and states back into the queues.

    For each executed step one observation is pushed via ``populate_queues``:
    the idx-th video frame of batch element 0, the matching predicted state
    (or zeros when ``zero_pred_state`` is set) with dimensions from
    ``ori_state_dim`` onward zeroed, and an all-zero action shaped like
    ``zero_action_template``.

    Args:
        queues: Observation queues to extend.
        seg_video: Video tensor indexed as ``[batch, channel, time, ...]``.
        pred_states: State tensor indexed as ``[batch, step, dim]``.
        zero_action_template: Tensor whose shape/dtype define the zero action.
        exe_steps: Number of steps to roll out.
        ori_state_dim: Number of leading state dimensions to keep.
        zero_pred_state: If True, push zero states instead of predictions.

    Returns:
        The queues as returned by the last ``populate_queues`` call.
    """
    for idx in range(exe_steps):
        state_frame = (torch.zeros_like(pred_states[0][idx:idx + 1])
                       if zero_pred_state else
                       pred_states[0][idx:idx + 1].clone())
        state_frame[:, ori_state_dim:] = 0.0
        observation = {
            'observation.images.top':
            seg_video[0][:, idx:idx + 1].permute(1, 0, 2, 3),
            'observation.state':
            state_frame,
            'action':
            torch.zeros_like(zero_action_template),
        }
        queues = populate_queues(queues, observation)
    return queues


def write_video(video_path: str, stacked_frames: list, fps: int) -> None:
    """Save a list of frames to a video file.

    Args:
        video_path (str): Output path for the video.
        stacked_frames (list): List of image frames.
        fps (int): Frames per second for the video.
    """
    with warnings.catch_warnings():
        warnings.filterwarnings("ignore",
                                "pkg_resources is deprecated as an API",
                                category=DeprecationWarning)
        imageio.mimsave(video_path, stacked_frames, fps=fps)


def get_filelist(data_dir: str, postfixes: list[str]) -> list[str]:
    """Return sorted list of files in a directory matching specified postfixes.

    Args:
        data_dir (str): Directory path to search in.
        postfixes (list[str]): List of file extensions to match.

    Returns:
        list[str]: Sorted list of file paths.
    """
    file_list = []
    for postfix in postfixes:
        file_list.extend(glob.glob(os.path.join(data_dir, f"*.{postfix}")))
    file_list.sort()
    return file_list


def load_model_checkpoint(model: nn.Module, ckpt: str) -> nn.Module:
    """Load model weights from checkpoint file.

    Two checkpoint layouts are handled:
      * a dict with a "state_dict" entry, loaded strictly; if that fails,
        keys containing "framestride_embed" are renamed to "fps_embedding"
        and the strict load is retried;
      * otherwise a dict with a "module" entry whose keys have their first
        16 characters stripped before loading.
        NOTE(review): presumably a wrapper prefix such as
        "_forward_module." -- confirm against the checkpoint producer.

    Args:
        model (nn.Module): Model instance.
        ckpt (str): Path to the checkpoint file.

    Returns:
        nn.Module: Model with loaded weights.

    Raises:
        RuntimeError: If the strict load still fails after key renaming.
    """
    state_dict = torch.load(ckpt, map_location="cpu")
    if "state_dict" in state_dict:
        state_dict = state_dict["state_dict"]
        try:
            model.load_state_dict(state_dict, strict=True)
        except RuntimeError:
            # load_state_dict reports missing/unexpected keys as RuntimeError;
            # retry with the legacy embedding name mapped to the current one.
            renamed_sd = OrderedDict(
                (key.replace("framestride_embed", "fps_embedding"), value)
                for key, value in state_dict.items())
            model.load_state_dict(renamed_sd, strict=True)
    else:
        module_sd = state_dict['module']
        new_pl_sd = OrderedDict(
            (key[16:], value) for key, value in module_sd.items())
        model.load_state_dict(new_pl_sd)
    print('>>> model checkpoint loaded.')
    return model


def is_inferenced(save_dir: str, filename: str) -> bool:
    """Check if a given filename has already been processed and saved.

    Args:
        save_dir (str): Directory where results are saved.
        filename (str): Name of the file to check. The last four characters
            are dropped, i.e. a 3-character extension is assumed.

    Returns:
        bool: True if processed file exists, False otherwise.
    """
    video_file = os.path.join(save_dir, "samples_separate",
                              f"{filename[:-4]}_sample0.mp4")
    return os.path.exists(video_file)


def save_results(video: Tensor, filename: str, fps: int = 8) -> None:
    """Save video tensor to file using torchvision.

    Args:
        video (Tensor): Tensor of shape (B, C, T, H, W).
        filename (str): Output file path.
        fps (int, optional): Frames per second. Defaults to 8.
    """
    # Same encoding path as the async variant, executed inline.
    _save_results_sync(video.detach().cpu(), filename, fps)


# ========== Async I/O ==========
_io_executor: Optional[ThreadPoolExecutor] = None
_io_futures: List[Any] = []
_child_processes: List[mp.Process] = []


def _get_io_executor() -> ThreadPoolExecutor:
    """Return the shared background I/O pool, creating it on first use."""
    global _io_executor
    if _io_executor is None:
        _io_executor = ThreadPoolExecutor(max_workers=2)
    return _io_executor


def _flush_io():
    """Wait for all pending async I/O to finish."""
    for fut in _io_futures:
        try:
            fut.result()
        except Exception as e:
            # Report and keep draining so one failed job does not hide others.
            print(f">>> [async I/O] error: {e}")
    _io_futures.clear()


atexit.register(_flush_io)


def _register_child_process(proc: mp.Process) -> None:
    """Track ``proc`` so it is terminated on exit or on a signal."""
    _child_processes.append(proc)


def _unregister_child_process(proc: mp.Process) -> None:
    """Stop tracking ``proc``; a process that is not tracked is ignored."""
    try:
        _child_processes.remove(proc)
    except ValueError:
        pass


def _terminate_process(proc: mp.Process, join_timeout: float = 3.0) -> None:
    """Best-effort shutdown: terminate, join, then kill if still alive.

    Every step swallows exceptions on purpose because this runs from exit
    and signal handlers, where cleanup must never raise.
    """
    if proc is None:
        return
    try:
        alive = proc.is_alive()
    except Exception:
        alive = False
    if not alive:
        # Already finished: a short join is enough.
        try:
            proc.join(timeout=0.1)
        except Exception:
            pass
        return
    try:
        proc.terminate()
    except Exception:
        pass
    try:
        proc.join(timeout=join_timeout)
    except Exception:
        pass
    try:
        if proc.is_alive():
            proc.kill()
            proc.join(timeout=join_timeout)
    except Exception:
        pass


def _terminate_all_child_processes() -> None:
    """Terminate every tracked child process and clear the registry."""
    for proc in list(_child_processes):
        _terminate_process(proc)
    _child_processes.clear()


def _handle_termination_signal(signum, _frame) -> None:
    """Signal handler: stop child processes, then exit with 128 + signum."""
    signame = signal.Signals(signum).name
    print(f">>> Received {signame}, terminating child processes ...")
    _terminate_all_child_processes()
    raise SystemExit(128 + signum)


signal.signal(signal.SIGINT, _handle_termination_signal)
signal.signal(signal.SIGTERM, _handle_termination_signal)
atexit.register(_terminate_all_child_processes)


def _save_results_sync(video_cpu: Tensor, filename: str, fps: int) -> None:
    """Synchronous save on CPU tensor (runs in background thread).

    Values are clamped to [-1, 1], each time step is tiled into one grid
    (one column per batch item), mapped to [0, 255] and cast to uint8
    before being written as h264 with crf 10.
    """
    video = torch.clamp(video_cpu.float(), -1., 1.)
    n = video.shape[0]
    video = video.permute(2, 0, 1, 3, 4)
    frame_grids = [
        torchvision.utils.make_grid(framesheet, nrow=int(n), padding=0)
        for framesheet in video
    ]
    grid = torch.stack(frame_grids, dim=0)
    grid = (grid + 1.0) / 2.0
    grid = (grid * 255).to(torch.uint8).permute(0, 2, 3, 1)
    torchvision.io.write_video(filename,
                               grid,
                               fps=fps,
                               video_codec='h264',
                               options={'crf': '10'})


def save_results_async(video: Tensor, filename: str, fps: int = 8) -> None:
    """Submit video saving to background thread pool."""
    # Copy to CPU on the caller's thread; only encoding runs in the pool.
    video_cpu = video.detach().cpu()
    fut = _get_io_executor().submit(_save_results_sync, video_cpu, filename,
                                    fps)
    _io_futures.append(fut)


def _log_to_tb_sync(writer, video_cpu: Tensor, tag: str, fps: int) -> None:
    """Synchronous TensorBoard log on CPU tensor (runs in background thread).

    Tensors that are not 5-dimensional are ignored.
    """
    if video_cpu.dim() != 5:
        return
    n = video_cpu.shape[0]
    video = video_cpu.permute(2, 0, 1, 3, 4)
    frame_grids = [
        torchvision.utils.make_grid(framesheet, nrow=int(n), padding=0)
        for framesheet in video
    ]
    grid = torch.stack(frame_grids, dim=0)
    grid = (grid + 1.0) / 2.0
    grid = grid.unsqueeze(dim=0)
    writer.add_video(tag, grid, fps=fps)


def log_to_tensorboard_async(writer,
                             data: Tensor,
                             tag: str,
                             fps: int = 10) -> None:
    """Submit TensorBoard logging to background thread pool.

    Anything other than a 5-dimensional tensor is silently skipped.
    """
    if not (isinstance(data, torch.Tensor) and data.dim() == 5):
        return
    data_cpu = data.detach().cpu()
    fut = _get_io_executor().submit(_log_to_tb_sync, writer, data_cpu, tag,
                                    fps)
    _io_futures.append(fut)


def _video_tensor_to_frames(video: Tensor) -> np.ndarray:
    """Convert a video tensor in [-1, 1] to a uint8 frame array.

    The input is indexed as ``[batch, channel, time, height, width]``. The
    result has time first and channels last.
    """
    video = torch.clamp(video.float(), -1., 1.)
    n = video.shape[0]
    if n == 1:
        # Fast path for bs=1: skip make_grid and convert directly.
        frames = video[0].permute(1, 2, 3, 0).contiguous()
        frames = ((frames + 1.0) / 2.0 * 255).to(torch.uint8)
        return frames.numpy()
    video = video.permute(2, 0, 1, 3, 4)
    frame_grids = [
        torchvision.utils.make_grid(f, nrow=int(n), padding=0) for f in video
    ]
    grid = torch.stack(frame_grids, dim=0)
    grid = ((grid + 1.0) / 2.0 * 255).to(torch.uint8).permute(0, 2, 3, 1)
    return grid.numpy()


def _tensor_stats(name: str, tensor: Tensor) -> str:
    """Format summary statistics of ``tensor`` over its finite entries.

    Mean, std, min, max and mean absolute value are computed on finite
    values only; all five are NaN when no finite value exists. ``finite``
    is the fraction of finite entries.
    """
    t = tensor.detach().float()
    finite_mask = torch.isfinite(t)
    finite_ratio = finite_mask.float().mean().item()
    if finite_mask.any():
        tf = t[finite_mask]
        mean = tf.mean().item()
        std = tf.std(unbiased=False).item()
        min_v = tf.min().item()
        max_v = tf.max().item()
        abs_mean = tf.abs().mean().item()
    else:
        mean = std = min_v = max_v = abs_mean = float('nan')
    return (f"{name}: shape={tuple(t.shape)} dtype={tensor.dtype} "
            f"mean={mean:.6f} std={std:.6f} min={min_v:.6f} max={max_v:.6f} "
            f"abs_mean={abs_mean:.6f} finite={finite_ratio:.6f}")


def _debug_world_model_stats(wm_samples: Tensor, wm_video: Tensor,
                             prefix: str) -> None:
    """Print statistics for the world-model latents and the decoded video."""
    print(
        f">>> [debug_wm_stats] {prefix} latent {_tensor_stats('wm_samples', wm_samples)}"
    )
    print(
        f">>> [debug_wm_stats] {prefix} decoded {_tensor_stats('wm_video', wm_video)}"
    )


def _video_writer_process(q: mp.Queue, filename: str, fps: int):
    """Child-process loop: encode video tensors from ``q`` into ``filename``.

    Runs until a ``None`` sentinel is received. The writer is opened lazily
    on the first item, so no file is created when nothing is written.
    """
    writer = None
    try:
        while True:
            item = q.get()
            if item is None:
                break
            frames = _video_tensor_to_frames(item)
            if writer is None:
                writer = imageio.get_writer(
                    filename,
                    fps=fps,
                    codec='libx264',
                    ffmpeg_params=['-crf', '10', '-pix_fmt', 'yuv420p'])
            for frame in frames:
                writer.append_data(frame)
    finally:
        # Close the writer even when conversion or encoding fails, so the
        # frames already appended are finalised.
        if writer is not None:
            writer.close()


def _stop_writer_process(writer_proc: mp.Process, write_q: mp.Queue) -> None:
    """Ask the writer process to finish and release its queue (best effort).

    Sends the ``None`` sentinel, waits up to 5 seconds, force-terminates the
    process if it is still alive, closes the queue and unregisters the
    process. Each step ignores errors so cleanup always runs to the end.
    """
    try:
        write_q.put(None, timeout=1.0)
    except Exception:
        pass
    try:
        writer_proc.join(timeout=5.0)
    except Exception:
        pass
    try:
        if writer_proc.is_alive():
            _terminate_process(writer_proc)
    except Exception:
        pass
    try:
        write_q.close()
        write_q.join_thread()
    except Exception:
        pass
    _unregister_child_process(writer_proc)


def get_init_frame_path(data_dir: str, sample: dict) -> str:
    """Construct the init_frame path from directory and sample metadata.

    Args:
        data_dir (str): Base directory containing the 'images' folder.
        sample (dict): Dictionary containing 'data_dir' and 'videoid'.

    Returns:
        str: ``<data_dir>/images/<sample data_dir>/<videoid>.png``.
    """
    rel_image_fp = os.path.join(sample['data_dir'],
                                str(sample['videoid']) + '.png')
    return os.path.join(data_dir, 'images', rel_image_fp)


def get_transition_path(data_dir: str, sample: dict) -> str:
    """Construct the full transition file path from directory and sample metadata.

    Args:
        data_dir (str): Base directory containing transition files.
        sample (dict): Dictionary containing 'data_dir' and 'videoid'.

    Returns:
        str: Full path to the HDF5 transition file.
    """
    rel_transition_fp = os.path.join(sample['data_dir'],
                                     str(sample['videoid']) + '.h5')
    return os.path.join(data_dir, 'transitions', rel_transition_fp)


def prepare_init_input(
        start_idx: int,
        init_frame_path: str,
        transition_dict: dict[str, torch.Tensor],
        frame_stride: int,
        wma_data,
        video_length: int = 16,
        n_obs_steps: int = 2) -> tuple[dict[str, Tensor], int, int]:
    """Build the initial model input from one image plus recorded transitions.

    Args:
        start_idx (int): Starting frame index for the current clip.
        init_frame_path (str): Path of the RGB image used as initial frame.
        transition_dict (Dict[str, Tensor]): Dictionary containing tensors for
            'action', 'observation.state', 'action_type', 'state_type'.
        frame_stride (int): Temporal stride between sampled action indices.
        wma_data: Object providing ``normalizer``, ``get_uni_vec`` and
            ``spatial_transform``.
        video_length (int, optional): Number of action steps to sample.
            Default is 16.
        n_obs_steps (int, optional): Number of historical steps for
            observations. Default is 2.

    Returns:
        A tuple ``(data, ori_state_dim, ori_action_dim)``: ``data`` holds
        'observation.image' (scaled to [-1, 1]) plus the entries returned by
        ``wma_data.get_uni_vec``; the two ints are the last-dimension sizes
        of the raw state and action tensors before normalisation.
    """
    indices = [start_idx + frame_stride * i for i in range(video_length)]

    # Close the file handle as soon as the pixels are in memory.
    with Image.open(init_frame_path) as image_file:
        rgb = np.array(image_file.convert('RGB'))
    # (H, W, C) -> (1, H, W, C) -> (C, 1, H, W)
    init_frame = torch.tensor(rgb).unsqueeze(0).permute(3, 0, 1, 2).float()

    if start_idx < n_obs_steps - 1:
        # Not enough history: left-pad by repeating the first state.
        state_indices = list(range(0, start_idx + 1))
        states = transition_dict['observation.state'][state_indices, :]
        num_padding = n_obs_steps - 1 - start_idx
        first_slice = states[0:1, :]  # (t, d)
        padding = first_slice.repeat(num_padding, 1)
        states = torch.cat((padding, states), dim=0)
    else:
        state_indices = list(range(start_idx - n_obs_steps + 1, start_idx + 1))
        states = transition_dict['observation.state'][state_indices, :]

    actions = transition_dict['action'][indices, :]

    ori_state_dim = states.shape[-1]
    ori_action_dim = actions.shape[-1]

    frames_action_state_dict = {
        'action': actions,
        'observation.state': states,
    }
    frames_action_state_dict = wma_data.normalizer(frames_action_state_dict)
    frames_action_state_dict = wma_data.get_uni_vec(
        frames_action_state_dict,
        transition_dict['action_type'],
        transition_dict['state_type'],
    )

    if wma_data.spatial_transform is not None:
        init_frame = wma_data.spatial_transform(init_frame)
    init_frame = (init_frame / 255 - 0.5) * 2

    data = {
        'observation.image': init_frame,
    }
    data.update(frames_action_state_dict)
    return data, ori_state_dim, ori_action_dim


def get_latent_z(model, videos: Tensor) -> Tensor:
    """
    Extracts latent features from a video batch using the model's first-stage encoder.

    Args:
        model: the world model.
        videos (Tensor): Input videos of shape [B, C, T, H, W].

    Returns:
        Tensor: Latent video tensor of shape [B, C, T, H, W].
    """
    b, _, t, _, _ = videos.shape
    # Frames are folded into the batch axis for the per-frame encoder.
    x = rearrange(videos, 'b c t h w -> (b t) c h w')
    z = model.encode_first_stage(x)
    z = rearrange(z, '(b t) c h w -> b c t h w', b=b, t=t)
    return z


def preprocess_observation(
        model, observations: dict[str, np.ndarray]) -> dict[str, Tensor]:
    """Convert environment observation to LeRobot format observation.

    Args:
        model: Model providing ``normalize_inputs`` and ``device``.
        observations: Dictionary of observation batches from a Gym vector
            environment, with 'pixels' (an array or a dict of arrays) and
            'agent_pos'.

    Returns:
        Dictionary of observation batches with keys renamed to LeRobot format
        and values as tensors. Images are float32 and channel first; their
        values are NOT rescaled here. The state is normalised by the model.
    """
    # Map to expected inputs for the policy
    return_observations = {}
    if isinstance(observations["pixels"], dict):
        imgs = {
            f"observation.images.{key}": img
            for key, img in observations["pixels"].items()
        }
    else:
        imgs = {"observation.images.top": observations["pixels"]}

    for imgkey, img in imgs.items():
        img = torch.from_numpy(img)

        # Sanity check that images are channel last
        _, h, w, c = img.shape
        assert c < h and c < w, f"expect channel last images, but instead {img.shape}"

        # Sanity check that images are uint8
        assert img.dtype == torch.uint8, f"expect torch.uint8, but instead {img.dtype=}"

        # Convert to channel first of type float32 (value range unchanged)
        img = einops.rearrange(img, "b h w c -> b c h w").contiguous()
        img = img.type(torch.float32)

        return_observations[imgkey] = img

    return_observations["observation.state"] = torch.from_numpy(
        observations["agent_pos"]).float()
    return_observations['observation.state'] = model.normalize_inputs({
        'observation.state':
        return_observations['observation.state'].to(model.device)
    })['observation.state']

    return return_observations


def image_guided_synthesis_sim_mode(
        model: torch.nn.Module,
        prompts: list[str],
        observation: dict,
        noise_shape: tuple[int, int, int, int, int],
        action_cond_step: int = 16,
        n_samples: int = 1,
        ddim_steps: int = 50,
        ddim_eta: float = 1.0,
        unconditional_guidance_scale: float = 1.0,
        precision: int | None = 16,
        fs: int | None = None,
        text_input: bool = True,
        timestep_spacing: str = 'uniform',
        guidance_rescale: float = 0.0,
        sim_mode: bool = True,
        decode_video: bool = True,
        pipeline_split_step: int = 0,
        pipeline_compare_full: bool = False,
        handoff_callback=None,
        stop_at_handoff: bool = False,
        **kwargs) -> tuple[torch.Tensor | None, torch.Tensor, torch.Tensor,
                           torch.Tensor, dict[str, Any]]:
    """
    Performs image-guided video generation in a simulation-style mode with optional multimodal guidance (image, state, action, text).

    Args:
        model (torch.nn.Module): The diffusion-based generative model with multimodal conditioning.
        prompts (list[str]): A list of textual prompts to guide the synthesis process.
        observation (dict): A dictionary containing observed inputs including:
            - 'observation.images.top': Tensor of shape [B, C, O, H, W]
              (it is permuted to [B, O, C, H, W] internally)
            - 'observation.state': Tensor of shape [B, O, D] (state vector)
            - 'action': Tensor of shape [B, T, D] (action sequence)
        noise_shape (tuple[int, int, int, int, int]): Shape of the latent variable to generate,
            typically (B, C, T, H, W).
        action_cond_step (int): Accepted for interface compatibility; not used in this function.
        n_samples (int): Number of samples to generate (unused here, always generates 1). Default is 1.
        ddim_steps (int): Number of DDIM sampling steps. Default is 50.
        ddim_eta (float): DDIM eta parameter controlling the stochasticity. Default is 1.0.
        unconditional_guidance_scale (float): Scale for classifier-free guidance. If 1.0, guidance is off.
        precision (int | None): Forwarded unchanged to the sampler.
        fs (int | None): Conditioning value forwarded to the sampler as a long tensor,
            broadcast across the batch. Must not be None.
        text_input (bool): Whether to use text prompt as conditioning. If False, uses empty strings. Default is True.
        timestep_spacing (str): Timestep sampling method in DDIM sampler. Typically "uniform" or "linspace".
        guidance_rescale (float): Guidance rescaling factor to mitigate overexposure from classifier-free guidance.
        sim_mode (bool): Whether to perform world-model interaction or decision-making using the world-model.
            When False, the action embedding is replaced by zeros.
        decode_video (bool): Whether to decode latent samples to pixel-space video. Set to False to skip VAE decode
            for speed when only actions/states are needed.
        pipeline_split_step (int): Forwarded to the sampler as ``handoff_step``.
        pipeline_compare_full (bool): Accepted for interface compatibility; not used in this function.
        handoff_callback: Forwarded to the sampler as ``handoff_callback``.
        stop_at_handoff (bool): Forwarded to the sampler as ``stop_at_handoff``.
        **kwargs: Additional arguments passed to the DDIM sampler.

    Returns:
        batch_variants (torch.Tensor | None): Predicted pixel-space video frames [B, C, T, H, W],
            or None when decode_video=False.
        actions (torch.Tensor): Predicted action sequences [B, T, D] from diffusion decoding.
        states (torch.Tensor): Predicted state sequences [B, T, D] from diffusion decoding.
        samples (torch.Tensor): Latent samples returned by the sampler.
        intermedia (dict[str, Any]): Intermediate results returned by the sampler.

    Raises:
        ValueError: If ``fs`` is None.
    """
    if fs is None:
        # torch.tensor([None], dtype=torch.long) would fail with an opaque
        # error; surface the missing argument explicitly instead.
        raise ValueError("`fs` must be provided (got None).")

    ddim_sampler = DDIMSampler(model)
    batch_size = noise_shape[0]

    fs = torch.tensor([fs] * batch_size,
                      dtype=torch.long,
                      device=model.device)

    # [B, C, O, H, W] -> [B, O, C, H, W]
    img = observation['observation.images.top'].permute(0, 2, 1, 3, 4)
    # Only the last frame of the flattened (b o) axis is embedded.
    cond_img = rearrange(img, 'b o c h w -> (b o) c h w')[-1:]
    cond_img_emb = model.embedder(cond_img)
    cond_img_emb = model.image_proj_model(cond_img_emb)

    # Start from an empty conditioning dict so non-hybrid models do not hit
    # an undefined name below.
    cond: dict[str, Any] = {}
    if model.model.conditioning_key == 'hybrid':
        z = get_latent_z(model, img.permute(0, 2, 1, 3, 4))
        # Repeat the latest observed latent frame over the generated length.
        img_cat_cond = z[:, :, -1:, :, :]
        img_cat_cond = repeat(img_cat_cond,
                              'b c t h w -> b c (repeat t) h w',
                              repeat=noise_shape[2])
        cond["c_concat"] = [img_cat_cond]

    if not text_input:
        prompts = [""] * batch_size
    cond_ins_emb = model.get_learned_conditioning(prompts)

    cond_state_emb = model.state_projector(observation['observation.state'])
    cond_state_emb = cond_state_emb + model.agent_state_pos_emb

    cond_action_emb = model.action_projector(observation['action'])
    cond_action_emb = cond_action_emb + model.agent_action_pos_emb

    if not sim_mode:
        cond_action_emb = torch.zeros_like(cond_action_emb)

    cond["c_crossattn"] = [
        torch.cat(
            [cond_state_emb, cond_action_emb, cond_ins_emb, cond_img_emb],
            dim=1)
    ]
    cond["c_crossattn_action"] = [
        observation['observation.images.top'][:, :,
                                              -model.n_obs_steps_acting:],
        observation['observation.state'][:, -model.n_obs_steps_acting:],
        sim_mode,
        False,
    ]

    uc = None
    kwargs.update({"unconditional_conditioning_img_nonetext": None})
    cond_mask = None
    cond_z0 = None

    sample_kwargs = dict(
        S=ddim_steps,
        conditioning=cond,
        batch_size=batch_size,
        shape=noise_shape[1:],
        verbose=False,
        unconditional_guidance_scale=unconditional_guidance_scale,
        unconditional_conditioning=uc,
        eta=ddim_eta,
        cfg_img=None,
        mask=cond_mask,
        x0=cond_z0,
        precision=precision,
        fs=fs,
        timestep_spacing=timestep_spacing,
        guidance_rescale=guidance_rescale,
        **kwargs)
    samples, actions, states, intermedia = ddim_sampler.sample(
        **sample_kwargs,
        handoff_step=pipeline_split_step,
        handoff_callback=handoff_callback,
        stop_at_handoff=stop_at_handoff)

    batch_variants = None
    if decode_video:
        # Reconstruct from latent to pixel space
        batch_variants = model.decode_first_stage(samples)

    return batch_variants, actions, states, samples, intermedia


def load_inference_runtime(args: argparse.Namespace,
                           gpu_no: int) -> SimpleNamespace:
    """Load the model once and keep all GPU-side runtime state together.

    If ``<ckpt_path>.prepared.pt`` exists it is loaded directly; otherwise
    the model is built from the config, the checkpoint is loaded and the
    whole model object is saved to that path for later runs. Afterwards the
    cached device attributes are synchronised, KV projections are fused and
    the TensorRT backbone is loaded when an engine file is available and
    ``--disable_trt`` is not set.

    Args:
        args: Parsed CLI arguments (reads ``config``, ``ckpt_path``,
            ``perframe_ae``, ``trt_engine_path`` and ``disable_trt``).
        gpu_no: CUDA device index to place the model on.

    Returns:
        SimpleNamespace with ``model``, ``config``, ``device`` and ``gpu_no``.
    """
    config = OmegaConf.load(args.config)

    prepared_path = args.ckpt_path + ".prepared.pt"
    if os.path.exists(prepared_path):
        print(f">>> Loading prepared model from {prepared_path} ...")
        # SECURITY: weights_only=False unpickles arbitrary objects. Only load
        # prepared files that this script produced itself.
        model = torch.load(prepared_path,
                           map_location=f"cuda:{gpu_no}",
                           weights_only=False,
                           mmap=True)
        model.eval()
        model = model.cuda(gpu_no)
        # Restore runtime-only attributes that may be absent from the
        # serialised model.
        diffusion_model = model.model.diffusion_model
        if not hasattr(diffusion_model, '_ctx_cache_enabled'):
            diffusion_model._ctx_cache_enabled = False
        if not hasattr(diffusion_model, '_ctx_cache'):
            diffusion_model._ctx_cache = {}
        if not hasattr(diffusion_model, '_trt_backbone'):
            diffusion_model._trt_backbone = None
        if not hasattr(diffusion_model, '_state_stream'):
            diffusion_model._state_stream = torch.cuda.Stream(
                device=torch.device(f"cuda:{gpu_no}"))
        print(">>> Prepared model loaded.")
    else:
        config['model']['params']['wma_config']['params'][
            'use_checkpoint'] = False
        model = instantiate_from_config(config.model)
        model.perframe_ae = args.perframe_ae

        assert os.path.exists(args.ckpt_path), "Error: checkpoint Not Found!"
        model = load_model_checkpoint(model, args.ckpt_path)
        model.eval()
        model = model.cuda(gpu_no)
        print('>>> Load pre-trained model ...')

        print(f">>> Saving prepared model to {prepared_path} ...")
        torch.save(model, prepared_path)
        print(
            f">>> Prepared model saved ({os.path.getsize(prepared_path) / 1024**3:.1f} GB)."
        )

    device = get_device_from_parameters(model)
    sync_module_device_attributes(model, device)
    if hasattr(model,
               'cond_stage_model') and model.cond_stage_model is not None:
        sync_module_device_attributes(model.cond_stage_model, device)
    if hasattr(model.model.diffusion_model, '_state_stream'):
        # Recreate the stream on the device the parameters actually live on.
        model.model.diffusion_model._state_stream = torch.cuda.Stream(
            device=device)

    # Fuse KV projections in attention layers (to_k + to_v -> to_kv).
    from unifolm_wma.modules.attention import CrossAttention
    kv_count = sum(1 for m in model.modules()
                   if isinstance(m, CrossAttention) and m.fuse_kv())
    print(f" ✓ KV fused: {kv_count} attention layers")

    # Load TRT backbone if engine exists and the user did not request Torch fallback.
    trt_engine_path = args.trt_engine_path
    if trt_engine_path is None:
        trt_engine_path = os.path.join(
            os.path.dirname(os.path.abspath(__file__)), '..', '..',
            'trt_engines', 'video_backbone.engine')
    if args.disable_trt:
        print(
            ">>> TRT disabled by --disable_trt; using PyTorch video backbone.")
    elif os.path.exists(trt_engine_path):
        model.model.diffusion_model.load_trt_backbone(trt_engine_path)
    else:
        print(
            f">>> TRT engine not found at {trt_engine_path}; using PyTorch video backbone."
        )

    if torch.cuda.is_available():
        torch.cuda.synchronize(device)
    return SimpleNamespace(model=model,
                           config=config,
                           device=device,
                           gpu_no=gpu_no)


def load_inference_data(config: OmegaConf,
                        args: argparse.Namespace) -> SimpleNamespace:
    """Instantiate the test dataset described by ``config.data.params.test``.

    The dataset's ``data_dir`` is taken from the config, falling back to
    ``args.prompt_dir``. ``meta_path``, ``transition_dir`` and
    ``dataset_name`` are overridden from ``args``.

    Returns:
        SimpleNamespace whose ``test_datasets`` maps ``args.dataset`` to the
        instantiated dataset object.
    """
    logging.info("***** Configing Data *****")
    # Work on a resolved copy so the caller's config is left untouched.
    test_cfg = OmegaConf.create(
        OmegaConf.to_container(config.data.params.test, resolve=True))
    stats_root = OmegaConf.select(config, "data.params.test.params.data_dir")
    if stats_root is None:
        stats_root = args.prompt_dir
    test_cfg.params["data_dir"] = stats_root
    test_cfg.params["meta_path"] = os.path.join(args.prompt_dir,
                                                f"{args.dataset}.csv")
    test_cfg.params["transition_dir"] = os.path.join(stats_root,
                                                     "transitions")
    test_cfg.params["dataset_name"] = args.dataset
    target_dataset = instantiate_from_config(test_cfg)
    data = SimpleNamespace(test_datasets={args.dataset: target_dataset})
    print(">>> Dataset is successfully loaded ...")
    return data


def build_pipeline_iteration_input(
        runtime: SimpleNamespace,
        args: argparse.Namespace,
        base_queues: dict[str, deque],
        action_seq: torch.Tensor,
        noise_shape: list[int],
        model_input_fs: int,
        ori_action_dim: int,
        ori_state_dim: int) -> dict[str, deque]:
    """Derive the next iteration's input queues from a policy handoff.

    Runs the world model on the previous iteration's observation with the
    handed-off ``action_seq`` (sampling stops at the handoff step), decodes
    the handoff latent into video and rolls ``args.exe_steps`` predicted
    frames/states into a clone of ``base_queues``.

    Args:
        runtime: Namespace with ``model`` and ``device``.
        args: Parsed CLI arguments (sampling options, ``exe_steps``,
            ``pipeline_split_step``, ``zero_pred_state``).
        base_queues: Input queues of the iteration that emitted the handoff.
        action_seq: Actions emitted at the handoff.
        noise_shape: Latent shape passed to the sampler.
        model_input_fs: ``fs`` value forwarded to the sampler.
        ori_action_dim: Number of meaningful action dimensions.
        ori_state_dim: Number of meaningful state dimensions.

    Returns:
        New observation queues for the next iteration; ``base_queues`` itself
        is not modified because only clones are extended.
    """
    model = runtime.model
    device = runtime.device

    policy_observation = build_observation_from_queues(base_queues, device)
    pipeline_action_queues = clone_observation_queues(base_queues)
    pipeline_action_queues = append_action_sequence(pipeline_action_queues,
                                                    action_seq,
                                                    ori_action_dim)
    wm_handoff_observation = {
        'observation.images.top':
        policy_observation['observation.images.top'],
        'observation.state':
        policy_observation['observation.state'],
        'action':
        torch.stack(list(pipeline_action_queues['action']),
                    dim=1).to(device, non_blocking=True),
    }

    _, _, pred_states, wm_samples, wm_intermedia = image_guided_synthesis_sim_mode(
        model,
        "",
        wm_handoff_observation,
        noise_shape,
        action_cond_step=args.exe_steps,
        ddim_steps=args.ddim_steps,
        ddim_eta=args.ddim_eta,
        unconditional_guidance_scale=args.unconditional_guidance_scale,
        precision=args.precision,
        fs=model_input_fs,
        text_input=False,
        timestep_spacing=args.timestep_spacing,
        guidance_rescale=args.guidance_rescale,
        decode_video=False,
        pipeline_split_step=args.pipeline_split_step,
        pipeline_compare_full=False,
        stop_at_handoff=True)

    # Prefer the handoff's x0 prediction, then its raw samples, and finally
    # the sampler's returned latents.
    wm_handoff = wm_intermedia.get('handoff', {})
    handoff_video_latent = wm_handoff.get('pred_x0',
                                          wm_handoff.get('samples', None))
    if handoff_video_latent is None:
        handoff_video_latent = wm_samples
    handoff_video = model.decode_first_stage(handoff_video_latent)
    handoff_state_seq = wm_handoff.get('states', pred_states)

    cond_obs_queues = clone_observation_queues(base_queues)
    cond_obs_queues = append_action_sequence(cond_obs_queues, action_seq,
                                             ori_action_dim)
    cond_obs_queues = rollout_execution_segment(
        queues=cond_obs_queues,
        seg_video=handoff_video[:, :, :args.exe_steps],
        pred_states=handoff_state_seq,
        zero_action_template=action_seq[0][-1:],
        exe_steps=args.exe_steps,
        ori_state_dim=ori_state_dim,
        zero_pred_state=args.zero_pred_state)
    return cond_obs_queues


def run_pipeline_iteration_task(
        runtime: SimpleNamespace,
        args: argparse.Namespace,
        iter_idx: int,
        instruction: str,
        input_payload: dict[str, Any],
        noise_shape: list[int],
        model_input_fs: int,
        ori_action_dim: int,
        ori_state_dim: int,
        handoff_queue: Queue) -> dict[str, Any]:
    """Run one pipeline iteration (policy, then world model) on one GPU.

    Args:
        runtime: Namespace with ``model``, ``device`` and ``gpu_no``.
        args: Parsed CLI arguments.
        iter_idx: Zero-based index of this iteration.
        instruction: Text prompt for the policy pass.
        input_payload: Either ``{'mode': 'initial', 'queues_cpu': ...}`` or
            ``{'mode': 'handoff', 'base_queues_cpu': ..., 'action_seq_cpu':
            ..., 'source_iter': ...}``.
        noise_shape: Latent shape passed to the sampler.
        model_input_fs: ``fs`` value forwarded to the sampler.
        ori_action_dim: Number of meaningful action dimensions.
        ori_state_dim: Number of meaningful state dimensions.
        handoff_queue: Queue that receives at most one handoff event so the
            next iteration can be scheduled before this one finishes.

    Returns:
        Dict with 'iter_idx', 'gpu_no' and 'seg_video_cpu' (the first
        ``args.exe_steps`` decoded frames, on CPU).

    Raises:
        ValueError: If ``input_payload['mode']`` is not supported.
    """
    model = runtime.model
    device = runtime.device

    if input_payload['mode'] == 'initial':
        cond_obs_queues = move_observation_queues_to_device(
            input_payload['queues_cpu'], device)
    elif input_payload['mode'] == 'handoff':
        base_queues = move_observation_queues_to_device(
            input_payload['base_queues_cpu'], device)
        action_seq = input_payload['action_seq_cpu'].to(device,
                                                        non_blocking=True)
        source_iter = input_payload['source_iter']
        pipeline_print(
            f'>>> Step {iter_idx}@gpu{runtime.gpu_no}: building pipeline input '
            f'from step {source_iter} handoff ...')
        cond_obs_queues = build_pipeline_iteration_input(
            runtime=runtime,
            args=args,
            base_queues=base_queues,
            action_seq=action_seq,
            noise_shape=noise_shape,
            model_input_fs=model_input_fs,
            ori_action_dim=ori_action_dim,
            ori_state_dim=ori_state_dim)
    else:
        raise ValueError(f"Unsupported input mode: {input_payload['mode']}")

    # Snapshot the inputs before they are extended below. The CPU clone is
    # already an independent copy, so no extra on-device clone is needed.
    iter_input_queues_cpu = clone_observation_queues_to_cpu(cond_obs_queues)
    handoff_sent = False

    def _on_policy_handoff(handoff: dict[str, torch.Tensor]) -> None:
        """Publish one handoff event for the next iteration, at most once."""
        nonlocal handoff_sent
        if handoff_sent or (iter_idx + 1) >= args.n_iter:
            return
        handoff_step = handoff.get('step', args.pipeline_split_step)
        handoff_queue.put({
            'source_iter': iter_idx,
            'next_iter': iter_idx + 1,
            'base_queues_cpu': iter_input_queues_cpu,
            'action_seq_cpu': handoff['actions'].detach().cpu().clone(),
            'handoff_step': handoff_step,
        })
        handoff_sent = True
        pipeline_print(
            f'>>> Step {iter_idx}@gpu{runtime.gpu_no}: emitted policy handoff '
            f'for step {iter_idx + 1} at ddim step '
            f'{handoff_step} ...')

    policy_observation = build_observation_from_queues(cond_obs_queues,
                                                       device)

    pipeline_print(
        f'>>> Step {iter_idx}@gpu{runtime.gpu_no}: generating actions ...')
    _, pred_actions, _, _, _ = image_guided_synthesis_sim_mode(
        model,
        instruction,
        policy_observation,
        noise_shape,
        action_cond_step=args.exe_steps,
        ddim_steps=args.ddim_steps,
        ddim_eta=args.ddim_eta,
        unconditional_guidance_scale=args.unconditional_guidance_scale,
        precision=args.precision,
        fs=model_input_fs,
        timestep_spacing=args.timestep_spacing,
        guidance_rescale=args.guidance_rescale,
        sim_mode=False,
        decode_video=False,
        pipeline_split_step=args.pipeline_split_step,
        pipeline_compare_full=False,
        handoff_callback=_on_policy_handoff
        if args.pipeline_split_step > 0 else None)

    cond_obs_queues = append_action_sequence(cond_obs_queues, pred_actions,
                                             ori_action_dim)
    wm_observation = {
        'observation.images.top':
        policy_observation['observation.images.top'],
        'observation.state':
        policy_observation['observation.state'],
        'action':
        torch.stack(list(cond_obs_queues['action']),
                    dim=1).to(device, non_blocking=True),
    }

    pipeline_print(
        f'>>> Step {iter_idx}@gpu{runtime.gpu_no}: interacting with world model ...'
    )
    _, _, pred_states, wm_samples, _ = image_guided_synthesis_sim_mode(
        model,
        "",
        wm_observation,
        noise_shape,
        action_cond_step=args.exe_steps,
        ddim_steps=args.ddim_steps,
        ddim_eta=args.ddim_eta,
        unconditional_guidance_scale=args.unconditional_guidance_scale,
        precision=args.precision,
        fs=model_input_fs,
        text_input=False,
        timestep_spacing=args.timestep_spacing,
        guidance_rescale=args.guidance_rescale,
        decode_video=False,
        pipeline_split_step=0,
        pipeline_compare_full=False)
    wm_video = model.decode_first_stage(wm_samples)
    if args.debug_wm_stats:
        _debug_world_model_stats(
            wm_samples,
            wm_video,
            prefix=f"step={iter_idx}@gpu{runtime.gpu_no}")

    seg_video = wm_video[:, :, :args.exe_steps]
    pipeline_print('>' * 24)

    return {
        'iter_idx': iter_idx,
        'gpu_no': runtime.gpu_no,
        'seg_video_cpu': seg_video.detach().cpu(),
    }


def run_inference_multi_gpu_pipeline(
        args: argparse.Namespace,
        runtimes: list[SimpleNamespace]) -> None:
    os.makedirs(args.savedir + '/inference', exist_ok=True)
    config = runtimes[0].config
    data = load_inference_data(config, args)
    df = pd.read_csv(os.path.join(args.prompt_dir, f"{args.dataset}.csv"))
    model0 = runtimes[0].model

    assert (args.height % 16 == 0) and (
        args.width % 16
        == 0), "Error: image size [h,w] should be multiples of 16!"
    assert args.bs == 1, "Current implementation only support [batch size = 1]!"
h, w = args.height // 8, args.width // 8 channels = model0.model.diffusion_model.out_channels n_frames = args.video_length pipeline_print(f'>>> Generate {n_frames} frames under each generation ...') noise_shape = [args.bs, channels, n_frames, h, w] gpu_ids = [runtime.gpu_no for runtime in runtimes] pipeline_print(f'>>> Multi-GPU pipeline enabled on GPUs {gpu_ids}.') for idx in range(0, len(df)): sample = df.iloc[idx] init_frame_path = get_init_frame_path(args.prompt_dir, sample) ori_fps = float(sample['fps']) video_save_dir = args.savedir + f"/inference/sample_{sample['videoid']}" os.makedirs(video_save_dir, exist_ok=True) transition_path = get_transition_path(args.prompt_dir, sample) with h5py.File(transition_path, 'r') as h5f: transition_dict = {} for key in h5f.keys(): transition_dict[key] = torch.tensor(h5f[key][()]) for key in h5f.attrs.keys(): transition_dict[key] = h5f.attrs[key] for fs in args.frame_stride: sample_full_video_file = f"{video_save_dir}/../{sample['videoid']}_full_fs{fs}.mp4" write_q = mp.Queue(maxsize=4) writer_proc = mp.Process(target=_video_writer_process, args=(write_q, sample_full_video_file, args.save_fps)) writer_proc.daemon = True writer_proc.start() _register_child_process(writer_proc) executors = { runtime.gpu_no: ThreadPoolExecutor(max_workers=1) for runtime in runtimes } handoff_events = Queue() future_meta = {} submitted_iters = set() completed_results = {} next_write_idx = 0 model_input_fs = ori_fps // fs progress = tqdm(total=args.n_iter, ascii=False, desc=f'fs={fs} pipeline', leave=True) try: batch, ori_state_dim, ori_action_dim = prepare_init_input( 0, init_frame_path, transition_dict, fs, data.test_datasets[args.dataset], n_obs_steps=model0.n_obs_steps_imagen) initial_observation = { 'observation.images.top': batch['observation.image'].permute(1, 0, 2, 3)[-1].unsqueeze(0), 'observation.state': batch['observation.state'][-1].unsqueeze(0), 'action': torch.zeros_like(batch['action'][-1]).unsqueeze(0), } initial_queues = { 
"observation.images.top": deque(maxlen=model0.n_obs_steps_imagen), "observation.state": deque(maxlen=model0.n_obs_steps_imagen), "action": deque(maxlen=args.video_length), } initial_queues = populate_queues(initial_queues, initial_observation) initial_payload = { 'mode': 'initial', 'queues_cpu': clone_observation_queues_to_cpu(initial_queues), } first_gpu = runtimes[0].gpu_no future = executors[first_gpu].submit( run_pipeline_iteration_task, runtimes[0], args, 0, sample['instruction'], initial_payload, noise_shape, model_input_fs, ori_action_dim, ori_state_dim, handoff_events) future_meta[future] = {'iter_idx': 0, 'gpu_no': first_gpu} submitted_iters.add(0) while next_write_idx < args.n_iter: done_futures = [fut for fut in list(future_meta.keys()) if fut.done()] for fut in done_futures: meta = future_meta.pop(fut) result = fut.result() completed_results[result['iter_idx']] = result pipeline_print( f'>>> Step {result["iter_idx"]}@gpu{meta["gpu_no"]}: ' f'final segment ready.') while True: try: event = handoff_events.get_nowait() except Empty: break next_iter = event['next_iter'] if next_iter >= args.n_iter or next_iter in submitted_iters: continue target_runtime = runtimes[next_iter % len(runtimes)] payload = { 'mode': 'handoff', 'base_queues_cpu': event['base_queues_cpu'], 'action_seq_cpu': event['action_seq_cpu'], 'source_iter': event['source_iter'], 'handoff_step': event['handoff_step'], } pipeline_print( f'>>> Step {event["source_iter"]}: queueing step ' f'{next_iter} on gpu{target_runtime.gpu_no} from ' f'ddim step {event["handoff_step"]} ...') future = executors[target_runtime.gpu_no].submit( run_pipeline_iteration_task, target_runtime, args, next_iter, sample['instruction'], payload, noise_shape, model_input_fs, ori_action_dim, ori_state_dim, handoff_events) future_meta[future] = { 'iter_idx': next_iter, 'gpu_no': target_runtime.gpu_no } submitted_iters.add(next_iter) while next_write_idx in completed_results: result = completed_results.pop(next_write_idx) 
write_q.put(result['seg_video_cpu']) next_write_idx += 1 progress.update(1) progress.set_postfix_str( f'written={next_write_idx}/{args.n_iter}', refresh=False) if next_write_idx < args.n_iter: if not future_meta and len(submitted_iters) >= args.n_iter: raise RuntimeError( "Pipeline scheduler is idle before all iterations finished.") time.sleep(0.05) finally: progress.close() for executor in executors.values(): executor.shutdown(wait=True, cancel_futures=False) _stop_writer_process(writer_proc, write_q) _flush_io() def run_inference(args: argparse.Namespace, gpu_num: int, gpu_no: int, runtime: Optional[SimpleNamespace] = None) -> None: """ Run inference pipeline on prompts and image inputs. Args: args (argparse.Namespace): Parsed command-line arguments. gpu_num (int): Number of GPUs. gpu_no (int): Index of the current GPU. Returns: None """ # Create inference dir os.makedirs(args.savedir + '/inference', exist_ok=True) # Load prompt csv_path = os.path.join(args.prompt_dir, f"{args.dataset}.csv") df = pd.read_csv(csv_path) # Load config (always needed for data setup) config = runtime.config if runtime is not None else OmegaConf.load(args.config) # Parallel loading: model and data are independent def _load_model(): return load_inference_runtime(args, gpu_no).model def _load_data(): return load_inference_data(config, args) if runtime is None: with ThreadPoolExecutor(max_workers=2) as executor: model_future = executor.submit(_load_model) data_future = executor.submit(_load_data) model = model_future.result() data = data_future.result() device = get_device_from_parameters(model) else: model = runtime.model device = runtime.device data = _load_data() # Run over data assert (args.height % 16 == 0) and ( args.width % 16 == 0), "Error: image size [h,w] should be multiples of 16!" assert args.bs == 1, "Current implementation only support [batch size = 1]!" 
# Get latent noise shape h, w = args.height // 8, args.width // 8 channels = model.model.diffusion_model.out_channels n_frames = args.video_length print(f'>>> Generate {n_frames} frames under each generation ...') noise_shape = [args.bs, channels, n_frames, h, w] # Start inference for idx in range(0, len(df)): sample = df.iloc[idx] # Got initial frame path init_frame_path = get_init_frame_path(args.prompt_dir, sample) ori_fps = float(sample['fps']) video_save_dir = args.savedir + f"/inference/sample_{sample['videoid']}" os.makedirs(video_save_dir, exist_ok=True) # Load transitions to get the initial state later transition_path = get_transition_path(args.prompt_dir, sample) with h5py.File(transition_path, 'r') as h5f: transition_dict = {} for key in h5f.keys(): transition_dict[key] = torch.tensor(h5f[key][()]) for key in h5f.attrs.keys(): transition_dict[key] = h5f.attrs[key] # If many, test various frequence control and world-model generation for fs in args.frame_stride: # Writer process for incremental video saving sample_full_video_file = f"{video_save_dir}/../{sample['videoid']}_full_fs{fs}.mp4" write_q = mp.Queue(maxsize=4) writer_proc = mp.Process( target=_video_writer_process, args=(write_q, sample_full_video_file, args.save_fps)) writer_proc.daemon = True writer_proc.start() _register_child_process(writer_proc) try: # Initialize observation queues cond_obs_queues = { "observation.images.top": deque(maxlen=model.n_obs_steps_imagen), "observation.state": deque(maxlen=model.n_obs_steps_imagen), "action": deque(maxlen=args.video_length), } # Obtain initial frame and state start_idx = 0 model_input_fs = ori_fps // fs batch, ori_state_dim, ori_action_dim = prepare_init_input( start_idx, init_frame_path, transition_dict, fs, data.test_datasets[args.dataset], n_obs_steps=model.n_obs_steps_imagen) observation = { 'observation.images.top': batch['observation.image'].permute(1, 0, 2, 3)[-1].unsqueeze(0), 'observation.state': batch['observation.state'][-1].unsqueeze(0), 
'action': torch.zeros_like(batch['action'][-1]).unsqueeze(0) } observation = { key: observation[key].to(device, non_blocking=True) for key in observation } # Update observation queues cond_obs_queues = populate_queues(cond_obs_queues, observation) # Multi-round interaction with the world-model pending_handoff = None pending_handoff_source_itr = None pending_handoff_ddim_step = None for itr in tqdm(range(args.n_iter), ascii=False): # Build observation for policy pass. if pending_handoff is not None: print( f'>>> Step {itr}: consuming pipeline handoff ' f'from step {pending_handoff_source_itr} ' f'at ddim step {pending_handoff_ddim_step} ...') cond_obs_queues = pending_handoff pending_handoff = None pending_handoff_source_itr = None pending_handoff_ddim_step = None iter_input_queues = clone_observation_queues(cond_obs_queues) policy_observation = build_observation_from_queues( cond_obs_queues, device) # Use world-model in policy to generate action print(f'>>> Step {itr}: generating actions ...') _, pred_actions, _, _, policy_intermedia = image_guided_synthesis_sim_mode( model, sample['instruction'], policy_observation, noise_shape, action_cond_step=args.exe_steps, ddim_steps=args.ddim_steps, ddim_eta=args.ddim_eta, unconditional_guidance_scale=args. 
unconditional_guidance_scale, precision=args.precision, fs=model_input_fs, timestep_spacing=args.timestep_spacing, guidance_rescale=args.guidance_rescale, sim_mode=False, decode_video=False, pipeline_split_step=args.pipeline_split_step, pipeline_compare_full=False) pipeline_action_seq = None pipeline_pred_states = None pipeline_wm_samples = None pipeline_wm_intermedia = {} if args.pipeline_split_step > 0 and (itr + 1) < args.n_iter: pipeline_action_seq = policy_intermedia.get( 'handoff', {}).get('actions', pred_actions) pipeline_action_queues = clone_observation_queues( iter_input_queues) pipeline_action_queues = append_action_sequence( pipeline_action_queues, pipeline_action_seq, ori_action_dim) wm_handoff_observation = { 'observation.images.top': policy_observation['observation.images.top'], 'observation.state': policy_observation['observation.state'], 'action': torch.stack(list(pipeline_action_queues['action']), dim=1).to(device, non_blocking=True), } print( f'>>> Step {itr}: preparing pipeline handoff branch ...' ) _, _, pipeline_pred_states, pipeline_wm_samples, pipeline_wm_intermedia = image_guided_synthesis_sim_mode( model, "", wm_handoff_observation, noise_shape, action_cond_step=args.exe_steps, ddim_steps=args.ddim_steps, ddim_eta=args.ddim_eta, unconditional_guidance_scale=args. unconditional_guidance_scale, precision=args.precision, fs=model_input_fs, text_input=False, timestep_spacing=args.timestep_spacing, guidance_rescale=args.guidance_rescale, decode_video=False, pipeline_split_step=args.pipeline_split_step, pipeline_compare_full=False) # Update future actions in the observation queues cond_obs_queues = append_action_sequence( cond_obs_queues, pred_actions, ori_action_dim) # Reuse images/state and only rebuild action for WM pass. 
wm_observation = { 'observation.images.top': policy_observation['observation.images.top'], 'observation.state': policy_observation['observation.state'], 'action': torch.stack(list(cond_obs_queues['action']), dim=1).to( device, non_blocking=True), } # Interaction with the world-model print(f'>>> Step {itr}: interacting with world model ...') _, _, pred_states, wm_samples, wm_intermedia = image_guided_synthesis_sim_mode( model, "", wm_observation, noise_shape, action_cond_step=args.exe_steps, ddim_steps=args.ddim_steps, ddim_eta=args.ddim_eta, unconditional_guidance_scale=args. unconditional_guidance_scale, precision=args.precision, fs=model_input_fs, text_input=False, timestep_spacing=args.timestep_spacing, guidance_rescale=args.guidance_rescale, decode_video=False, pipeline_split_step=args.pipeline_split_step, pipeline_compare_full=False) # Decode full WM clip, then take executable segment to keep behavior closer to previous path. wm_video = model.decode_first_stage(wm_samples) if args.debug_wm_stats: _debug_world_model_stats( wm_samples, wm_video, prefix=f"step={itr}") seg_video = wm_video[:, :, :args.exe_steps] cond_obs_queues = rollout_execution_segment( queues=cond_obs_queues, seg_video=seg_video, pred_states=pred_states, zero_action_template=pred_actions[0][-1:], exe_steps=args.exe_steps, ori_state_dim=ori_state_dim, zero_pred_state=args.zero_pred_state) if args.pipeline_split_step > 0 and (itr + 1) < args.n_iter: policy_handoff = policy_intermedia.get('handoff', {}) wm_handoff = pipeline_wm_intermedia.get('handoff', {}) handoff_ddim_step = wm_handoff.get( 'step', policy_handoff.get('step', args.pipeline_split_step)) next_action_seq = pipeline_action_seq handoff_video_latent = wm_handoff.get( 'pred_x0', wm_handoff.get('samples', None)) if handoff_video_latent is None: handoff_video_latent = pipeline_wm_samples handoff_video = model.decode_first_stage( handoff_video_latent) handoff_state_seq = wm_handoff.get('states', pipeline_pred_states) pending_handoff = 
clone_observation_queues( iter_input_queues) pending_handoff = append_action_sequence( pending_handoff, next_action_seq, ori_action_dim) pending_handoff = rollout_execution_segment( queues=pending_handoff, seg_video=handoff_video[:, :, :args.exe_steps], pred_states=handoff_state_seq, zero_action_template=next_action_seq[0][-1:], exe_steps=args.exe_steps, ori_state_dim=ori_state_dim, zero_pred_state=args.zero_pred_state) pending_handoff_source_itr = itr pending_handoff_ddim_step = handoff_ddim_step print( f'>>> Step {itr}: prepared pipeline handoff ' f'for step {itr + 1} at ddim step ' f'{handoff_ddim_step} ...') print('>' * 24) # Send decoded segment to writer process write_q.put(seg_video.detach().cpu()) finally: _stop_writer_process(writer_proc, write_q) # Wait for all async I/O to complete _flush_io() def get_parser(): parser = argparse.ArgumentParser() parser.add_argument("--savedir", type=str, default=None, help="Path to save the results.") parser.add_argument("--ckpt_path", type=str, default=None, help="Path to the model checkpoint.") parser.add_argument("--config", type=str, help="Path to the model checkpoint.") parser.add_argument( "--prompt_dir", type=str, default=None, help="Directory containing videos and corresponding prompts.") parser.add_argument("--dataset", type=str, default=None, help="the name of dataset to test") parser.add_argument( "--ddim_steps", type=int, default=50, help="Number of DDIM steps. If non-positive, DDPM is used instead.") parser.add_argument( "--ddim_eta", type=float, default=1.0, help="Eta for DDIM sampling. Set to 0.0 for deterministic results.") parser.add_argument("--bs", type=int, default=1, help="Batch size for inference. 
Must be 1.") parser.add_argument("--height", type=int, default=320, help="Height of the generated images in pixels.") parser.add_argument("--width", type=int, default=512, help="Width of the generated images in pixels.") parser.add_argument( "--frame_stride", type=int, nargs='+', required=True, help= "frame stride control for 256 model (larger->larger motion), FPS control for 512 or 1024 model (smaller->larger motion)" ) parser.add_argument( "--unconditional_guidance_scale", type=float, default=1.0, help="Scale for classifier-free guidance during sampling.") parser.add_argument("--seed", type=int, default=123, help="Random seed for reproducibility.") parser.add_argument("--video_length", type=int, default=16, help="Number of frames in the generated video.") parser.add_argument("--num_generation", type=int, default=1, help="seed for seed_everything") parser.add_argument( "--timestep_spacing", type=str, default="uniform", help= "Strategy for timestep scaling. See Table 2 in the paper: 'Common Diffusion Noise Schedules and Sample Steps are Flawed' (https://huggingface.co/papers/2305.08891)." ) parser.add_argument( "--guidance_rescale", type=float, default=0.0, help= "Rescale factor for guidance as discussed in 'Common Diffusion Noise Schedules and Sample Steps are Flawed' (https://huggingface.co/papers/2305.08891)." ) parser.add_argument( "--perframe_ae", action='store_true', default=False, help= "Use per-frame autoencoder decoding to reduce GPU memory usage. Recommended for models with resolutions like 576x1024." ) parser.add_argument( "--precision", type=int, default=16, choices=[16, 32], help="Sampling precision for latent/action/state noise initialization. 
Default is 16.") parser.add_argument( "--n_action_steps", type=int, default=16, help="num of samples per prompt", ) parser.add_argument( "--exe_steps", type=int, default=16, help="num of samples to execute", ) parser.add_argument( "--n_iter", type=int, default=40, help="num of iteration to interact with the world model", ) parser.add_argument("--zero_pred_state", action='store_true', default=False, help="not using the predicted states as comparison") parser.add_argument( "--fast_policy_no_decode", action='store_true', default=False, help="Speed mode: policy pass only predicts actions, skip policy video decode/log/save.") parser.add_argument( "--debug_wm_stats", action='store_true', default=False, help="Print latent/decode statistics for world-model samples before writing video.") parser.add_argument( "--disable_trt", action='store_true', default=False, help="Disable TensorRT backbone loading and force the PyTorch video backbone path.") parser.add_argument( "--trt_engine_path", type=str, default=None, help="Optional explicit TensorRT engine path. Defaults to trt_engines/video_backbone.engine.") parser.add_argument("--save_fps", type=int, default=8, help="fps for the saving video") parser.add_argument( "--pipeline_split_step", type=int, default=0, help="Run DDIM sampling in two segments on a single GPU for pipeline experiments. " "Set to a value like 25 to test 25+25 splitting; 0 disables it.") parser.add_argument( "--pipeline_compare_full", action='store_true', default=False, help="When pipeline_split_step is enabled, also run the original full DDIM pass " "with the same initial noise and print max-abs diffs for validation.") parser.add_argument( "--pipeline_multi_gpu", action='store_true', default=False, help="Enable true asynchronous pipeline execution across multiple GPUs. 
" "When disabled, keep the original single-GPU/serial inference path.") parser.add_argument( "--pipeline_gpu_ids", type=int, nargs='+', default=[0, 1], help="Logical CUDA device ids used by the multi-GPU pipeline scheduler.") return parser if __name__ == '__main__': parser = get_parser() args = parser.parse_args() seed = args.seed if seed < 0: seed = random.randint(0, 2**31) seed_everything(seed) if args.pipeline_multi_gpu: if args.pipeline_split_step <= 0: raise ValueError( "--pipeline_multi_gpu requires --pipeline_split_step > 0.") if len(args.pipeline_gpu_ids) < 2: raise ValueError( "--pipeline_multi_gpu requires at least two gpu ids.") runtimes = [ load_inference_runtime(args, gpu_id) for gpu_id in args.pipeline_gpu_ids ] run_inference_multi_gpu_pipeline(args, runtimes[:2]) else: rank, gpu_num = 0, 1 run_inference(args, gpu_num, rank)