import argparse
import glob
import json
import logging
import os
import random
import time
import warnings
from collections import OrderedDict, deque
from contextlib import contextmanager, nullcontext
from dataclasses import asdict, dataclass, field
from typing import Any, Dict, List, Optional

import einops
import h5py
import imageio
import numpy as np
import pandas as pd
import torch
import torchvision
from einops import rearrange, repeat
from omegaconf import OmegaConf
from PIL import Image
from pytorch_lightning import seed_everything
from torch import Tensor, nn
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm

from eval_utils import populate_queues, log_to_tensorboard
from unifolm_wma.models.samplers.ddim import DDIMSampler
from unifolm_wma.utils.utils import instantiate_from_config


# ========== Profiling Infrastructure ==========
@dataclass
class TimingRecord:
    """Record for a single timing measurement."""
    name: str
    start_time: float = 0.0
    end_time: float = 0.0
    cuda_time_ms: float = 0.0
    count: int = 0
    children: List['TimingRecord'] = field(default_factory=list)

    @property
    def cpu_time_ms(self) -> float:
        """Wall-clock duration in milliseconds derived from start/end."""
        return (self.end_time - self.start_time) * 1000

    def to_dict(self) -> dict:
        """Serialize this record (and its children, recursively) to a dict."""
        return {
            'name': self.name,
            'cpu_time_ms': self.cpu_time_ms,
            'cuda_time_ms': self.cuda_time_ms,
            'count': self.count,
            'children': [child.to_dict() for child in self.children],
        }


class ProfilerManager:
    """Manages macro and micro-level profiling.

    Macro level: named wall-clock/CUDA-event timings collected via
    :meth:`profile_section`. Micro level: operator statistics collected via
    the PyTorch profiler (:meth:`start_pytorch_profiler`).
    """

    def __init__(self,
                 enabled: bool = False,
                 output_dir: str = "./profile_output",
                 profile_detail: str = "light"):
        self.enabled = enabled
        self.output_dir = output_dir
        self.profile_detail = profile_detail
        # section name -> list of CPU times (ms)
        self.macro_timings: Dict[str, List[float]] = {}
        # section name -> list of (cpu_ms, cuda_ms) pairs
        self.cuda_events: Dict[str, List[tuple]] = {}
        self.memory_snapshots: List[Dict] = []
        self.pytorch_profiler = None
        self.current_iteration = 0
        self.operator_stats: Dict[str, Dict] = {}
        self.profiler_config = self._build_profiler_config(profile_detail)
        if enabled:
            os.makedirs(output_dir, exist_ok=True)

    def _build_profiler_config(self, profile_detail: str) -> Dict[str, Any]:
        """Return profiler settings based on the requested detail level."""
        if profile_detail not in ("light", "full"):
            raise ValueError(f"Unsupported profile_detail: {profile_detail}")
        # "full" turns every optional profiler feature on; "light" turns all off.
        feature_flags = ("record_shapes", "profile_memory", "with_stack",
                         "with_flops", "with_modules", "group_by_input_shape")
        detailed = profile_detail == "full"
        return {flag: detailed for flag in feature_flags}

    @contextmanager
    def profile_section(self, name: str, sync_cuda: bool = True):
        """Context manager for profiling a code section.

        Records wall-clock time and, when CUDA is available, event-based GPU
        time under ``name``. A no-op when profiling is disabled.
        """
        if not self.enabled:
            yield
            return
        if sync_cuda and torch.cuda.is_available():
            torch.cuda.synchronize()
        start_event = None
        end_event = None
        if torch.cuda.is_available():
            start_event = torch.cuda.Event(enable_timing=True)
            end_event = torch.cuda.Event(enable_timing=True)
            start_event.record()
        start_time = time.perf_counter()
        try:
            yield
        finally:
            if sync_cuda and torch.cuda.is_available():
                torch.cuda.synchronize()
            cpu_time_ms = (time.perf_counter() - start_time) * 1000
            cuda_time_ms = 0.0
            if start_event is not None and end_event is not None:
                end_event.record()
                # elapsed_time requires both events to have completed.
                torch.cuda.synchronize()
                cuda_time_ms = start_event.elapsed_time(end_event)
            self.macro_timings.setdefault(name, []).append(cpu_time_ms)
            self.cuda_events.setdefault(name, []).append(
                (cpu_time_ms, cuda_time_ms))

    def record_memory(self, tag: str = ""):
        """Record current GPU memory state."""
        if not self.enabled or not torch.cuda.is_available():
            return
        self.memory_snapshots.append({
            'tag': tag,
            'iteration': self.current_iteration,
            'allocated_mb': torch.cuda.memory_allocated() / 1024**2,
            'reserved_mb': torch.cuda.memory_reserved() / 1024**2,
            'max_allocated_mb': torch.cuda.max_memory_allocated() / 1024**2,
        })

    def start_pytorch_profiler(self, wait: int = 1, warmup: int = 1,
                               active: int = 3):
        """Start PyTorch profiler for operator-level analysis.

        Returns a context manager: the profiler itself when enabled, or a
        ``nullcontext`` otherwise.
        """
        if not self.enabled:
            return nullcontext()
        self.pytorch_profiler = torch.profiler.profile(
            activities=[
                torch.profiler.ProfilerActivity.CPU,
                torch.profiler.ProfilerActivity.CUDA,
            ],
            schedule=torch.profiler.schedule(wait=wait,
                                             warmup=warmup,
                                             active=active,
                                             repeat=1),
            on_trace_ready=self._trace_handler,
            record_shapes=self.profiler_config["record_shapes"],
            profile_memory=self.profiler_config["profile_memory"],
            with_stack=self.profiler_config["with_stack"],
            with_flops=self.profiler_config["with_flops"],
            with_modules=self.profiler_config["with_modules"],
        )
        return self.pytorch_profiler

    def _trace_handler(self, prof):
        """Handle profiler trace output: export a chrome trace and fold the
        per-operator averages into ``self.operator_stats``."""
        trace_path = os.path.join(
            self.output_dir, f"trace_iter_{self.current_iteration}.json")
        prof.export_chrome_trace(trace_path)
        # Extract operator statistics
        key_averages = prof.key_averages(
            group_by_input_shape=self.profiler_config["group_by_input_shape"])
        for evt in key_averages:
            stats = self.operator_stats.setdefault(
                evt.key, {
                    'count': 0,
                    'cpu_time_total_us': 0,
                    'cuda_time_total_us': 0,
                    'self_cpu_time_total_us': 0,
                    'self_cuda_time_total_us': 0,
                    'cpu_memory_usage': 0,
                    'cuda_memory_usage': 0,
                    'flops': 0,
                })
            stats['count'] += evt.count
            stats['cpu_time_total_us'] += evt.cpu_time_total
            stats['cuda_time_total_us'] += evt.cuda_time_total
            stats['self_cpu_time_total_us'] += evt.self_cpu_time_total
            stats['self_cuda_time_total_us'] += evt.self_cuda_time_total
            # These attributes vary across torch versions; guard each one.
            if hasattr(evt, 'cpu_memory_usage'):
                stats['cpu_memory_usage'] += evt.cpu_memory_usage
            if hasattr(evt, 'cuda_memory_usage'):
                stats['cuda_memory_usage'] += evt.cuda_memory_usage
            if hasattr(evt, 'flops') and evt.flops:
                stats['flops'] += evt.flops

    def step_profiler(self):
        """Step the PyTorch profiler."""
        if self.pytorch_profiler is not None:
            self.pytorch_profiler.step()

    def generate_report(self) -> str:
        """Generate comprehensive profiling report."""
        if not self.enabled:
            return "Profiling disabled."
        report_lines = []
        report_lines.append("=" * 80)
        report_lines.append("PERFORMANCE PROFILING REPORT")
        report_lines.append("=" * 80)
        report_lines.append("")
        # Macro-level timing summary
        report_lines.append("-" * 40)
        report_lines.append("MACRO-LEVEL TIMING SUMMARY")
        report_lines.append("-" * 40)
        report_lines.append(
            f"{'Section':<40} {'Count':>8} {'Total(ms)':>12} {'Avg(ms)':>12} {'CUDA Avg(ms)':>14}"
        )
        report_lines.append("-" * 86)
        total_time = 0
        timing_data = []
        for name, times in sorted(self.macro_timings.items()):
            cuda_times = [ct for _, ct in self.cuda_events.get(name, [])]
            avg_time = np.mean(times)
            avg_cuda = np.mean(cuda_times) if cuda_times else 0
            total = sum(times)
            total_time += total
            timing_data.append({
                'name': name,
                'count': len(times),
                'total_ms': total,
                'avg_ms': avg_time,
                'cuda_avg_ms': avg_cuda,
                'times': times,
                'cuda_times': cuda_times,
            })
            report_lines.append(
                f"{name:<40} {len(times):>8} {total:>12.2f} {avg_time:>12.2f} {avg_cuda:>14.2f}"
            )
        report_lines.append("-" * 86)
        report_lines.append(f"{'TOTAL':<40} {'':<8} {total_time:>12.2f}")
        report_lines.append("")
        # Memory summary
        if self.memory_snapshots:
            report_lines.append("-" * 40)
            report_lines.append("GPU MEMORY SUMMARY")
            report_lines.append("-" * 40)
            max_alloc = max(s['max_allocated_mb']
                            for s in self.memory_snapshots)
            avg_alloc = np.mean(
                [s['allocated_mb'] for s in self.memory_snapshots])
            report_lines.append(f"Peak allocated: {max_alloc:>10.2f} MB")
            report_lines.append(f"Average allocated: {avg_alloc:>10.2f} MB")
            report_lines.append("")
        # Top operators by CUDA time
        if self.operator_stats:
            report_lines.append("-" * 40)
            report_lines.append("TOP 30 OPERATORS BY CUDA TIME")
            report_lines.append("-" * 40)
            sorted_ops = sorted(self.operator_stats.items(),
                                key=lambda x: x[1]['cuda_time_total_us'],
                                reverse=True)[:30]
            report_lines.append(
                f"{'Operator':<50} {'Count':>8} {'CUDA(ms)':>12} {'CPU(ms)':>12} {'Self CUDA(ms)':>14}"
            )
            report_lines.append("-" * 96)
            for op_name, stats in sorted_ops:
                # Truncate long operator names
                display_name = op_name[:47] + "..." if len(
                    op_name) > 50 else op_name
                report_lines.append(
                    f"{display_name:<50} {stats['count']:>8} "
                    f"{stats['cuda_time_total_us']/1000:>12.2f} "
                    f"{stats['cpu_time_total_us']/1000:>12.2f} "
                    f"{stats['self_cuda_time_total_us']/1000:>14.2f}")
            report_lines.append("")
            # Compute category breakdown
            report_lines.append("-" * 40)
            report_lines.append("OPERATOR CATEGORY BREAKDOWN")
            report_lines.append("-" * 40)
            categories = {
                'Attention': ['attention', 'softmax', 'bmm', 'baddbmm'],
                'Convolution': ['conv', 'cudnn'],
                'Normalization':
                ['norm', 'layer_norm', 'batch_norm', 'group_norm'],
                'Activation': ['relu', 'gelu', 'silu', 'sigmoid', 'tanh'],
                'Linear/GEMM': ['linear', 'addmm', 'mm', 'gemm'],
                'Memory': [
                    'copy', 'contiguous', 'view', 'reshape', 'permute',
                    'transpose'
                ],
                'Elementwise':
                ['add', 'mul', 'div', 'sub', 'pow', 'exp', 'sqrt'],
            }
            category_times = {cat: 0.0 for cat in categories}
            category_times['Other'] = 0.0
            for op_name, stats in self.operator_stats.items():
                op_lower = op_name.lower()
                categorized = False
                # First matching category wins.
                for cat, keywords in categories.items():
                    if any(kw in op_lower for kw in keywords):
                        category_times[cat] += stats['cuda_time_total_us']
                        categorized = True
                        break
                if not categorized:
                    category_times['Other'] += stats['cuda_time_total_us']
            total_op_time = sum(category_times.values())
            report_lines.append(
                f"{'Category':<30} {'CUDA Time(ms)':>15} {'Percentage':>12}")
            report_lines.append("-" * 57)
            for cat, time_us in sorted(category_times.items(),
                                       key=lambda x: -x[1]):
                pct = (time_us / total_op_time *
                       100) if total_op_time > 0 else 0
                report_lines.append(
                    f"{cat:<30} {time_us/1000:>15.2f} {pct:>11.1f}%")
            report_lines.append("")
        report = "\n".join(report_lines)
        return report

    def save_results(self):
        """Save all profiling results to files."""
        if not self.enabled:
            return
        # Save report
        report = self.generate_report()
        report_path = os.path.join(self.output_dir, "profiling_report.txt")
        with open(report_path, 'w') as f:
            f.write(report)
        print(f">>> Profiling report saved to: {report_path}")
        # Save detailed JSON data
        data = {
            'macro_timings': {
                name: {
                    'times': times,
                    'cuda_times':
                    [ct for _, ct in self.cuda_events.get(name, [])]
                }
                for name, times in self.macro_timings.items()
            },
            'memory_snapshots': self.memory_snapshots,
            'operator_stats': self.operator_stats,
        }
        json_path = os.path.join(self.output_dir, "profiling_data.json")
        with open(json_path, 'w') as f:
            json.dump(data, f, indent=2)
        print(f">>> Detailed profiling data saved to: {json_path}")
        # Print summary to console
        print("\n" + report)


# Global profiler instance
_profiler: Optional[ProfilerManager] = None


def get_profiler() -> ProfilerManager:
    """Get the global profiler instance (lazily created, disabled)."""
    global _profiler
    if _profiler is None:
        _profiler = ProfilerManager(enabled=False)
    return _profiler


def init_profiler(enabled: bool, output_dir: str,
                  profile_detail: str) -> ProfilerManager:
    """Initialize the global profiler."""
    global _profiler
    _profiler = ProfilerManager(
        enabled=enabled,
        output_dir=output_dir,
        profile_detail=profile_detail,
    )
    return _profiler


# ========== Original Functions ==========
def get_device_from_parameters(module: nn.Module) -> torch.device:
    """Get a module's device by checking one of its parameters.

    Args:
        module (nn.Module): The model whose device is to be inferred.

    Returns:
        torch.device: The device of the model's parameters.
    """
    return next(iter(module.parameters())).device


def write_video(video_path: str, stacked_frames: list, fps: int) -> None:
    """Save a list of frames to a video file.

    Args:
        video_path (str): Output path for the video.
        stacked_frames (list): List of image frames.
        fps (int): Frames per second for the video.
    """
    with warnings.catch_warnings():
        # imageio pulls in pkg_resources; silence its deprecation noise.
        warnings.filterwarnings("ignore",
                                "pkg_resources is deprecated as an API",
                                category=DeprecationWarning)
        imageio.mimsave(video_path, stacked_frames, fps=fps)
""" with warnings.catch_warnings(): warnings.filterwarnings("ignore", "pkg_resources is deprecated as an API", category=DeprecationWarning) imageio.mimsave(video_path, stacked_frames, fps=fps) def get_filelist(data_dir: str, postfixes: list[str]) -> list[str]: """Return sorted list of files in a directory matching specified postfixes. Args: data_dir (str): Directory path to search in. postfixes (list[str]): List of file extensions to match. Returns: list[str]: Sorted list of file paths. """ patterns = [ os.path.join(data_dir, f"*.{postfix}") for postfix in postfixes ] file_list = [] for pattern in patterns: file_list.extend(glob.glob(pattern)) file_list.sort() return file_list def load_model_checkpoint(model: nn.Module, ckpt: str) -> nn.Module: """Load model weights from checkpoint file. Args: model (nn.Module): Model instance. ckpt (str): Path to the checkpoint file. Returns: nn.Module: Model with loaded weights. """ state_dict = torch.load(ckpt, map_location="cpu") if "state_dict" in list(state_dict.keys()): state_dict = state_dict["state_dict"] try: model.load_state_dict(state_dict, strict=True) except: new_pl_sd = OrderedDict() for k, v in state_dict.items(): new_pl_sd[k] = v for k in list(new_pl_sd.keys()): if "framestride_embed" in k: new_key = k.replace("framestride_embed", "fps_embedding") new_pl_sd[new_key] = new_pl_sd[k] del new_pl_sd[k] model.load_state_dict(new_pl_sd, strict=True) else: new_pl_sd = OrderedDict() for key in state_dict['module'].keys(): new_pl_sd[key[16:]] = state_dict['module'][key] model.load_state_dict(new_pl_sd) print('>>> model checkpoint loaded.') return model def _module_param_dtype(module: nn.Module | None) -> str: if module is None: return "None" for param in module.parameters(): return str(param.dtype) return "no_params" def log_inference_precision(model: nn.Module) -> None: try: param = next(model.parameters()) device = str(param.device) model_dtype = str(param.dtype) except StopIteration: device = "unknown" model_dtype = 
"no_params" print(f">>> inference precision: model={model_dtype}, device={device}") for attr in [ "model", "first_stage_model", "cond_stage_model", "embedder", "image_proj_model" ]: if hasattr(model, attr): submodule = getattr(model, attr) print(f">>> {attr} param dtype: {_module_param_dtype(submodule)}") print( ">>> autocast gpu dtype default: " f"{torch.get_autocast_gpu_dtype()} " f"(enabled={torch.is_autocast_enabled()})") def is_inferenced(save_dir: str, filename: str) -> bool: """Check if a given filename has already been processed and saved. Args: save_dir (str): Directory where results are saved. filename (str): Name of the file to check. Returns: bool: True if processed file exists, False otherwise. """ video_file = os.path.join(save_dir, "samples_separate", f"{filename[:-4]}_sample0.mp4") return os.path.exists(video_file) def save_results(video: Tensor, filename: str, fps: int = 8) -> None: """Save video tensor to file using torchvision. Args: video (Tensor): Tensor of shape (B, C, T, H, W). filename (str): Output file path. fps (int, optional): Frames per second. Defaults to 8. """ video = video.detach().cpu() video = torch.clamp(video.float(), -1., 1.) n = video.shape[0] video = video.permute(2, 0, 1, 3, 4) frame_grids = [ torchvision.utils.make_grid(framesheet, nrow=int(n), padding=0) for framesheet in video ] grid = torch.stack(frame_grids, dim=0) grid = (grid + 1.0) / 2.0 grid = (grid * 255).to(torch.uint8).permute(0, 2, 3, 1) torchvision.io.write_video(filename, grid, fps=fps, video_codec='h264', options={'crf': '10'}) def get_init_frame_path(data_dir: str, sample: dict) -> str: """Construct the init_frame path from directory and sample metadata. Args: data_dir (str): Base directory containing videos. sample (dict): Dictionary containing 'data_dir' and 'videoid'. Returns: str: Full path to the video file. 
""" rel_video_fp = os.path.join(sample['data_dir'], str(sample['videoid']) + '.png') full_image_fp = os.path.join(data_dir, 'images', rel_video_fp) return full_image_fp def get_transition_path(data_dir: str, sample: dict) -> str: """Construct the full transition file path from directory and sample metadata. Args: data_dir (str): Base directory containing transition files. sample (dict): Dictionary containing 'data_dir' and 'videoid'. Returns: str: Full path to the HDF5 transition file. """ rel_transition_fp = os.path.join(sample['data_dir'], str(sample['videoid']) + '.h5') full_transition_fp = os.path.join(data_dir, 'transitions', rel_transition_fp) return full_transition_fp def prepare_init_input(start_idx: int, init_frame_path: str, transition_dict: dict[str, torch.Tensor], frame_stride: int, wma_data, video_length: int = 16, n_obs_steps: int = 2) -> dict[str, Tensor]: """ Extracts a structured sample from a video sequence including frames, states, and actions, along with properly padded observations and pre-processed tensors for model input. Args: start_idx (int): Starting frame index for the current clip. video: decord video instance. transition_dict (Dict[str, Tensor]): Dictionary containing tensors for 'action', 'observation.state', 'action_type', 'state_type'. frame_stride (int): Temporal stride between sampled frames. wma_data: Object that holds configuration and utility functions like normalization, transformation, and resolution info. video_length (int, optional): Number of frames to sample from the video. Default is 16. n_obs_steps (int, optional): Number of historical steps for observations. Default is 2. 
""" indices = [start_idx + frame_stride * i for i in range(video_length)] init_frame = Image.open(init_frame_path).convert('RGB') init_frame = torch.tensor(np.array(init_frame)).unsqueeze(0).permute( 3, 0, 1, 2).float() if start_idx < n_obs_steps - 1: state_indices = list(range(0, start_idx + 1)) states = transition_dict['observation.state'][state_indices, :] num_padding = n_obs_steps - 1 - start_idx first_slice = states[0:1, :] # (t, d) padding = first_slice.repeat(num_padding, 1) states = torch.cat((padding, states), dim=0) else: state_indices = list(range(start_idx - n_obs_steps + 1, start_idx + 1)) states = transition_dict['observation.state'][state_indices, :] actions = transition_dict['action'][indices, :] ori_state_dim = states.shape[-1] ori_action_dim = actions.shape[-1] frames_action_state_dict = { 'action': actions, 'observation.state': states, } frames_action_state_dict = wma_data.normalizer(frames_action_state_dict) frames_action_state_dict = wma_data.get_uni_vec( frames_action_state_dict, transition_dict['action_type'], transition_dict['state_type'], ) if wma_data.spatial_transform is not None: init_frame = wma_data.spatial_transform(init_frame) init_frame = (init_frame / 255 - 0.5) * 2 data = { 'observation.image': init_frame, } data.update(frames_action_state_dict) return data, ori_state_dim, ori_action_dim def get_latent_z(model, videos: Tensor) -> Tensor: """ Extracts latent features from a video batch using the model's first-stage encoder. Args: model: the world model. videos (Tensor): Input videos of shape [B, C, T, H, W]. Returns: Tensor: Latent video tensor of shape [B, C, T, H, W]. 
""" profiler = get_profiler() with profiler.profile_section("get_latent_z/encode"): b, c, t, h, w = videos.shape x = rearrange(videos, 'b c t h w -> (b t) c h w') z = model.encode_first_stage(x) z = rearrange(z, '(b t) c h w -> b c t h w', b=b, t=t) return z def preprocess_observation( model, observations: dict[str, np.ndarray]) -> dict[str, Tensor]: """Convert environment observation to LeRobot format observation. Args: observation: Dictionary of observation batches from a Gym vector environment. Returns: Dictionary of observation batches with keys renamed to LeRobot format and values as tensors. """ # Map to expected inputs for the policy return_observations = {} if isinstance(observations["pixels"], dict): imgs = { f"observation.images.{key}": img for key, img in observations["pixels"].items() } else: imgs = {"observation.images.top": observations["pixels"]} for imgkey, img in imgs.items(): img = torch.from_numpy(img) # Sanity check that images are channel last _, h, w, c = img.shape assert c < h and c < w, f"expect channel first images, but instead {img.shape}" # Sanity check that images are uint8 assert img.dtype == torch.uint8, f"expect torch.uint8, but instead {img.dtype=}" # Convert to channel first of type float32 in range [0,1] img = einops.rearrange(img, "b h w c -> b c h w").contiguous() img = img.type(torch.float32) return_observations[imgkey] = img return_observations["observation.state"] = torch.from_numpy( observations["agent_pos"]).float() return_observations['observation.state'] = model.normalize_inputs({ 'observation.state': return_observations['observation.state'].to(model.device) })['observation.state'] return return_observations def image_guided_synthesis_sim_mode( model: torch.nn.Module, prompts: list[str], observation: dict, noise_shape: tuple[int, int, int, int, int], action_cond_step: int = 16, n_samples: int = 1, ddim_steps: int = 50, ddim_eta: float = 1.0, unconditional_guidance_scale: float = 1.0, fs: int | None = None, text_input: bool 
= True, timestep_spacing: str = 'uniform', guidance_rescale: float = 0.0, sim_mode: bool = True, **kwargs) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: """ Performs image-guided video generation in a simulation-style mode with optional multimodal guidance (image, state, action, text). Args: model (torch.nn.Module): The diffusion-based generative model with multimodal conditioning. prompts (list[str]): A list of textual prompts to guide the synthesis process. observation (dict): A dictionary containing observed inputs including: - 'observation.images.top': Tensor of shape [B, O, C, H, W] (top-down images) - 'observation.state': Tensor of shape [B, O, D] (state vector) - 'action': Tensor of shape [B, T, D] (action sequence) noise_shape (tuple[int, int, int, int, int]): Shape of the latent variable to generate, typically (B, C, T, H, W). action_cond_step (int): Number of time steps where action conditioning is applied. Default is 16. n_samples (int): Number of samples to generate (unused here, always generates 1). Default is 1. ddim_steps (int): Number of DDIM sampling steps. Default is 50. ddim_eta (float): DDIM eta parameter controlling the stochasticity. Default is 1.0. unconditional_guidance_scale (float): Scale for classifier-free guidance. If 1.0, guidance is off. fs (int | None): Frame index to condition on, broadcasted across the batch if specified. Default is None. text_input (bool): Whether to use text prompt as conditioning. If False, uses empty strings. Default is True. timestep_spacing (str): Timestep sampling method in DDIM sampler. Typically "uniform" or "linspace". guidance_rescale (float): Guidance rescaling factor to mitigate overexposure from classifier-free guidance. sim_mode (bool): Whether to perform world-model interaction or decision-making using the world-model. **kwargs: Additional arguments passed to the DDIM sampler. Returns: batch_variants (torch.Tensor): Predicted pixel-space video frames [B, C, T, H, W]. 
actions (torch.Tensor): Predicted action sequences [B, T, D] from diffusion decoding. states (torch.Tensor): Predicted state sequences [B, T, D] from diffusion decoding. """ profiler = get_profiler() b, _, t, _, _ = noise_shape ddim_sampler = DDIMSampler(model) batch_size = noise_shape[0] fs = torch.tensor([fs] * batch_size, dtype=torch.long, device=model.device) with profiler.profile_section("synthesis/conditioning_prep"): img = observation['observation.images.top'].permute(0, 2, 1, 3, 4) cond_img = rearrange(img, 'b o c h w -> (b o) c h w')[-1:] cond_img_emb = model.embedder(cond_img) cond_img_emb = model.image_proj_model(cond_img_emb) if model.model.conditioning_key == 'hybrid': z = get_latent_z(model, img.permute(0, 2, 1, 3, 4)) img_cat_cond = z[:, :, -1:, :, :] img_cat_cond = repeat(img_cat_cond, 'b c t h w -> b c (repeat t) h w', repeat=noise_shape[2]) cond = {"c_concat": [img_cat_cond]} if not text_input: prompts = [""] * batch_size cond_ins_emb = model.get_learned_conditioning(prompts) cond_state_emb = model.state_projector(observation['observation.state']) cond_state_emb = cond_state_emb + model.agent_state_pos_emb cond_action_emb = model.action_projector(observation['action']) cond_action_emb = cond_action_emb + model.agent_action_pos_emb if not sim_mode: cond_action_emb = torch.zeros_like(cond_action_emb) cond["c_crossattn"] = [ torch.cat( [cond_state_emb, cond_action_emb, cond_ins_emb, cond_img_emb], dim=1) ] cond["c_crossattn_action"] = [ observation['observation.images.top'][:, :, -model.n_obs_steps_acting:], observation['observation.state'][:, -model.n_obs_steps_acting:], sim_mode, False, ] uc = None kwargs.update({"unconditional_conditioning_img_nonetext": None}) cond_mask = None cond_z0 = None if ddim_sampler is not None: with profiler.profile_section("synthesis/ddim_sampling"): samples, actions, states, intermedia = ddim_sampler.sample( S=ddim_steps, conditioning=cond, batch_size=batch_size, shape=noise_shape[1:], verbose=False, 
unconditional_guidance_scale=unconditional_guidance_scale, unconditional_conditioning=uc, eta=ddim_eta, cfg_img=None, mask=cond_mask, x0=cond_z0, fs=fs, timestep_spacing=timestep_spacing, guidance_rescale=guidance_rescale, **kwargs) # Reconstruct from latent to pixel space with profiler.profile_section("synthesis/decode_first_stage"): batch_images = model.decode_first_stage(samples) batch_variants = batch_images return batch_variants, actions, states def run_inference(args: argparse.Namespace, gpu_num: int, gpu_no: int) -> None: """ Run inference pipeline on prompts and image inputs. Args: args (argparse.Namespace): Parsed command-line arguments. gpu_num (int): Number of GPUs. gpu_no (int): Index of the current GPU. Returns: None """ profiler = get_profiler() # Create inference and tensorboard dirs os.makedirs(args.savedir + '/inference', exist_ok=True) log_dir = args.savedir + f"/tensorboard" os.makedirs(log_dir, exist_ok=True) writer = SummaryWriter(log_dir=log_dir) # Load prompt csv_path = os.path.join(args.prompt_dir, f"{args.dataset}.csv") df = pd.read_csv(csv_path) # Load config with profiler.profile_section("model_loading/config"): config = OmegaConf.load(args.config) config['model']['params']['wma_config']['params'][ 'use_checkpoint'] = False model = instantiate_from_config(config.model) model.perframe_ae = args.perframe_ae assert os.path.exists(args.ckpt_path), "Error: checkpoint Not Found!" 
with profiler.profile_section("model_loading/checkpoint"): model = load_model_checkpoint(model, args.ckpt_path) model.eval() print(f'>>> Load pre-trained model ...') # Build unnomalizer logging.info("***** Configing Data *****") with profiler.profile_section("data_loading"): data = instantiate_from_config(config.data) data.setup() print(">>> Dataset is successfully loaded ...") with profiler.profile_section("model_to_cuda"): model = model.cuda(gpu_no) device = get_device_from_parameters(model) log_inference_precision(model) profiler.record_memory("after_model_load") # Run over data assert (args.height % 16 == 0) and ( args.width % 16 == 0), "Error: image size [h,w] should be multiples of 16!" assert args.bs == 1, "Current implementation only support [batch size = 1]!" # Get latent noise shape h, w = args.height // 8, args.width // 8 channels = model.model.diffusion_model.out_channels n_frames = args.video_length print(f'>>> Generate {n_frames} frames under each generation ...') noise_shape = [args.bs, channels, n_frames, h, w] # Determine profiler iterations profile_active_iters = getattr(args, 'profile_iterations', 3) use_pytorch_profiler = profiler.enabled and profile_active_iters > 0 # Start inference for idx in range(0, len(df)): sample = df.iloc[idx] # Got initial frame path init_frame_path = get_init_frame_path(args.prompt_dir, sample) ori_fps = float(sample['fps']) video_save_dir = args.savedir + f"/inference/sample_{sample['videoid']}" os.makedirs(video_save_dir, exist_ok=True) os.makedirs(video_save_dir + '/dm', exist_ok=True) os.makedirs(video_save_dir + '/wm', exist_ok=True) # Load transitions to get the initial state later transition_path = get_transition_path(args.prompt_dir, sample) with profiler.profile_section("load_transitions"): with h5py.File(transition_path, 'r') as h5f: transition_dict = {} for key in h5f.keys(): transition_dict[key] = torch.tensor(h5f[key][()]) for key in h5f.attrs.keys(): transition_dict[key] = h5f.attrs[key] # If many, test 
various frequence control and world-model generation for fs in args.frame_stride: # For saving imagens in policy sample_save_dir = f'{video_save_dir}/dm/{fs}' os.makedirs(sample_save_dir, exist_ok=True) # For saving environmental changes in world-model sample_save_dir = f'{video_save_dir}/wm/{fs}' os.makedirs(sample_save_dir, exist_ok=True) # For collecting interaction videos wm_video = [] # Initialize observation queues cond_obs_queues = { "observation.images.top": deque(maxlen=model.n_obs_steps_imagen), "observation.state": deque(maxlen=model.n_obs_steps_imagen), "action": deque(maxlen=args.video_length), } # Obtain initial frame and state with profiler.profile_section("prepare_init_input"): start_idx = 0 model_input_fs = ori_fps // fs batch, ori_state_dim, ori_action_dim = prepare_init_input( start_idx, init_frame_path, transition_dict, fs, data.test_datasets[args.dataset], n_obs_steps=model.n_obs_steps_imagen) observation = { 'observation.images.top': batch['observation.image'].permute(1, 0, 2, 3)[-1].unsqueeze(0), 'observation.state': batch['observation.state'][-1].unsqueeze(0), 'action': torch.zeros_like(batch['action'][-1]).unsqueeze(0) } observation = { key: observation[key].to(device, non_blocking=True) for key in observation } # Update observation queues cond_obs_queues = populate_queues(cond_obs_queues, observation) # Setup PyTorch profiler context if enabled pytorch_prof_ctx = nullcontext() if use_pytorch_profiler: pytorch_prof_ctx = profiler.start_pytorch_profiler( wait=1, warmup=1, active=profile_active_iters ) # Multi-round interaction with the world-model with pytorch_prof_ctx: for itr in tqdm(range(args.n_iter)): profiler.current_iteration = itr profiler.record_memory(f"iter_{itr}_start") with profiler.profile_section("iteration_total"): # Get observation with profiler.profile_section("prepare_observation"): observation = { 'observation.images.top': torch.stack(list( cond_obs_queues['observation.images.top']), dim=1).permute(0, 2, 1, 3, 4), 
'observation.state': torch.stack(list(cond_obs_queues['observation.state']), dim=1), 'action': torch.stack(list(cond_obs_queues['action']), dim=1), } observation = { key: observation[key].to(device, non_blocking=True) for key in observation } # Use world-model in policy to generate action print(f'>>> Step {itr}: generating actions ...') with profiler.profile_section("action_generation"): pred_videos_0, pred_actions, _ = image_guided_synthesis_sim_mode( model, sample['instruction'], observation, noise_shape, action_cond_step=args.exe_steps, ddim_steps=args.ddim_steps, ddim_eta=args.ddim_eta, unconditional_guidance_scale=args. unconditional_guidance_scale, fs=model_input_fs, timestep_spacing=args.timestep_spacing, guidance_rescale=args.guidance_rescale, sim_mode=False) # Update future actions in the observation queues with profiler.profile_section("update_action_queues"): for act_idx in range(len(pred_actions[0])): obs_update = {'action': pred_actions[0][act_idx:act_idx + 1]} obs_update['action'][:, ori_action_dim:] = 0.0 cond_obs_queues = populate_queues(cond_obs_queues, obs_update) # Collect data for interacting the world-model using the predicted actions with profiler.profile_section("prepare_wm_observation"): observation = { 'observation.images.top': torch.stack(list( cond_obs_queues['observation.images.top']), dim=1).permute(0, 2, 1, 3, 4), 'observation.state': torch.stack(list(cond_obs_queues['observation.state']), dim=1), 'action': torch.stack(list(cond_obs_queues['action']), dim=1), } observation = { key: observation[key].to(device, non_blocking=True) for key in observation } # Interaction with the world-model print(f'>>> Step {itr}: interacting with world model ...') with profiler.profile_section("world_model_interaction"): pred_videos_1, _, pred_states = image_guided_synthesis_sim_mode( model, "", observation, noise_shape, action_cond_step=args.exe_steps, ddim_steps=args.ddim_steps, ddim_eta=args.ddim_eta, unconditional_guidance_scale=args. 
unconditional_guidance_scale, fs=model_input_fs, text_input=False, timestep_spacing=args.timestep_spacing, guidance_rescale=args.guidance_rescale) with profiler.profile_section("update_state_queues"): for step_idx in range(args.exe_steps): obs_update = { 'observation.images.top': pred_videos_1[0][:, step_idx:step_idx + 1].permute(1, 0, 2, 3), 'observation.state': torch.zeros_like(pred_states[0][step_idx:step_idx + 1]) if args.zero_pred_state else pred_states[0][step_idx:step_idx + 1], 'action': torch.zeros_like(pred_actions[0][-1:]) } obs_update['observation.state'][:, ori_state_dim:] = 0.0 cond_obs_queues = populate_queues(cond_obs_queues, obs_update) # Save the imagen videos for decision-making with profiler.profile_section("save_results"): sample_tag = f"{args.dataset}-vid{sample['videoid']}-dm-fs-{fs}/itr-{itr}" log_to_tensorboard(writer, pred_videos_0, sample_tag, fps=args.save_fps) # Save videos environment changes via world-model interaction sample_tag = f"{args.dataset}-vid{sample['videoid']}-wd-fs-{fs}/itr-{itr}" log_to_tensorboard(writer, pred_videos_1, sample_tag, fps=args.save_fps) # Save the imagen videos for decision-making sample_video_file = f'{video_save_dir}/dm/{fs}/itr-{itr}.mp4' save_results(pred_videos_0.cpu(), sample_video_file, fps=args.save_fps) # Save videos environment changes via world-model interaction sample_video_file = f'{video_save_dir}/wm/{fs}/itr-{itr}.mp4' save_results(pred_videos_1.cpu(), sample_video_file, fps=args.save_fps) print('>' * 24) # Collect the result of world-model interactions wm_video.append(pred_videos_1[:, :, :args.exe_steps].cpu()) profiler.record_memory(f"iter_{itr}_end") profiler.step_profiler() full_video = torch.cat(wm_video, dim=2) sample_tag = f"{args.dataset}-vid{sample['videoid']}-wd-fs-{fs}/full" log_to_tensorboard(writer, full_video, sample_tag, fps=args.save_fps) sample_full_video_file = f"{video_save_dir}/../{sample['videoid']}_full_fs{fs}.mp4" save_results(full_video, sample_full_video_file, 
fps=args.save_fps) # Save profiling results profiler.save_results() def get_parser(): parser = argparse.ArgumentParser() parser.add_argument("--savedir", type=str, default=None, help="Path to save the results.") parser.add_argument("--ckpt_path", type=str, default=None, help="Path to the model checkpoint.") parser.add_argument("--config", type=str, help="Path to the model checkpoint.") parser.add_argument( "--prompt_dir", type=str, default=None, help="Directory containing videos and corresponding prompts.") parser.add_argument("--dataset", type=str, default=None, help="the name of dataset to test") parser.add_argument( "--ddim_steps", type=int, default=50, help="Number of DDIM steps. If non-positive, DDPM is used instead.") parser.add_argument( "--ddim_eta", type=float, default=1.0, help="Eta for DDIM sampling. Set to 0.0 for deterministic results.") parser.add_argument("--bs", type=int, default=1, help="Batch size for inference. Must be 1.") parser.add_argument("--height", type=int, default=320, help="Height of the generated images in pixels.") parser.add_argument("--width", type=int, default=512, help="Width of the generated images in pixels.") parser.add_argument( "--frame_stride", type=int, nargs='+', required=True, help= "frame stride control for 256 model (larger->larger motion), FPS control for 512 or 1024 model (smaller->larger motion)" ) parser.add_argument( "--unconditional_guidance_scale", type=float, default=1.0, help="Scale for classifier-free guidance during sampling.") parser.add_argument("--seed", type=int, default=123, help="Random seed for reproducibility.") parser.add_argument("--video_length", type=int, default=16, help="Number of frames in the generated video.") parser.add_argument("--num_generation", type=int, default=1, help="seed for seed_everything") parser.add_argument( "--timestep_spacing", type=str, default="uniform", help= "Strategy for timestep scaling. 
See Table 2 in the paper: 'Common Diffusion Noise Schedules and Sample Steps are Flawed' (https://huggingface.co/papers/2305.08891)." ) parser.add_argument( "--guidance_rescale", type=float, default=0.0, help= "Rescale factor for guidance as discussed in 'Common Diffusion Noise Schedules and Sample Steps are Flawed' (https://huggingface.co/papers/2305.08891)." ) parser.add_argument( "--perframe_ae", action='store_true', default=False, help= "Use per-frame autoencoder decoding to reduce GPU memory usage. Recommended for models with resolutions like 576x1024." ) parser.add_argument( "--n_action_steps", type=int, default=16, help="num of samples per prompt", ) parser.add_argument( "--exe_steps", type=int, default=16, help="num of samples to execute", ) parser.add_argument( "--n_iter", type=int, default=40, help="num of iteration to interact with the world model", ) parser.add_argument("--zero_pred_state", action='store_true', default=False, help="not using the predicted states as comparison") parser.add_argument("--save_fps", type=int, default=8, help="fps for the saving video") # Profiling arguments parser.add_argument( "--profile", action='store_true', default=False, help="Enable performance profiling (macro and operator-level analysis)." ) parser.add_argument( "--profile_output_dir", type=str, default=None, help="Directory to save profiling results. Defaults to {savedir}/profile_output." ) parser.add_argument( "--profile_iterations", type=int, default=3, help="Number of iterations to run PyTorch profiler's active phase for operator-level analysis." ) parser.add_argument( "--profile_detail", type=str, choices=["light", "full"], default="light", help="Profiling detail level. Use 'full' for shapes/stacks/memory/flops." 
) return parser if __name__ == '__main__': parser = get_parser() args = parser.parse_args() seed = args.seed if seed < 0: seed = random.randint(0, 2**31) seed_everything(seed) # Initialize profiler profile_output_dir = args.profile_output_dir if profile_output_dir is None: profile_output_dir = os.path.join(args.savedir, "profile_output") init_profiler( enabled=args.profile, output_dir=profile_output_dir, profile_detail=args.profile_detail, ) rank, gpu_num = 0, 1 run_inference(args, gpu_num, rank)