""" Profile the full iteration loop of world model interaction. Three layers of profiling: Layer 1: Iteration-level wall-clock breakdown (CUDA events) Layer 2: GPU timeline trace (torch.profiler → Chrome trace) Layer 3: A/B comparison (standardized CSV output) Usage: # Layer 1 only (fast, default): TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL=1 CUDA_VISIBLE_DEVICES=0 \ python scripts/evaluation/profile_iteration.py \ --ckpt_path ckpts/unifolm_wma_dual_mix_bf16.ckpt \ --config configs/inference/world_model_interaction.yaml \ --prompt_dir unitree_z1_dual_arm_cleanup_pencils/case1/world_model_interaction_prompts \ --dataset unitree_z1_dual_arm_cleanup_pencils \ --frame_stride 4 --n_iter 5 # Layer 1 + Layer 2 (GPU trace): ... --trace --trace_dir ./profile_traces # Layer 3 (A/B comparison): run twice, diff the CSVs ... --csv baseline.csv ... --csv optimized.csv python scripts/evaluation/profile_iteration.py --compare baseline.csv optimized.csv """ import argparse import csv import os import sys import time from collections import defaultdict, deque from contextlib import nullcontext import h5py import numpy as np import pandas as pd import torch import torchvision from einops import rearrange, repeat from omegaconf import OmegaConf from PIL import Image from pytorch_lightning import seed_everything from torch import Tensor from unifolm_wma.models.samplers.ddim import DDIMSampler from unifolm_wma.utils.utils import instantiate_from_config import torch.nn.functional as F # ────────────────────────────────────────────────────────────────────── # Constants # ────────────────────────────────────────────────────────────────────── STAGE_NAMES = [ "stack_to_device_1", "synth_policy", "update_action_queue", "stack_to_device_2", "synth_world_model", "update_obs_queue", "tensorboard_log", "save_results", "cpu_transfer", "itr_total", ] # Sub-stages inside image_guided_synthesis_sim_mode SYNTH_SUB_STAGES = [ "ddim_sampler_init", "image_embedding", "vae_encode", "text_conditioning", "projectors", "cond_assembly", "ddim_sampling", "vae_decode", ] # ────────────────────────────────────────────────────────────────────── # CudaTimer — GPU-precise timing via CUDA events # ────────────────────────────────────────────────────────────────────── class CudaTimer: """Context manager that records GPU time between enter/exit using CUDA events.""" def __init__(self, name, records): self.name = name self.records = records def __enter__(self): torch.cuda.synchronize() self._start = torch.cuda.Event(enable_timing=True) self._end = torch.cuda.Event(enable_timing=True) self._start.record() return self def __exit__(self, *args): self._end.record() torch.cuda.synchronize() elapsed_ms = self._start.elapsed_time(self._end) self.records[self.name].append(elapsed_ms) class WallTimer: """Context manager that records CPU wall-clock time (for pure-CPU stages).""" def __init__(self, name, records): self.name = name self.records = records def __enter__(self): torch.cuda.synchronize() self._t0 = time.perf_counter() return self def __exit__(self, *args): torch.cuda.synchronize() elapsed_ms = (time.perf_counter() - self._t0) * 1000.0 self.records[self.name].append(elapsed_ms) # ────────────────────────────────────────────────────────────────────── # Model loading (reused from world_model_interaction.py) # ────────────────────────────────────────────────────────────────────── def patch_norm_bypass_autocast(): def _group_norm_forward(self, x): with torch.amp.autocast('cuda', enabled=False): return F.group_norm( x, self.num_groups, self.weight.to(x.dtype) if self.weight is not None else None, self.bias.to(x.dtype) if self.bias is not None else None, self.eps) def _layer_norm_forward(self, x): with torch.amp.autocast('cuda', enabled=False): return F.layer_norm( x, self.normalized_shape, self.weight.to(x.dtype) if self.weight is not None else None, self.bias.to(x.dtype) if self.bias is not None else None, self.eps) torch.nn.GroupNorm.forward = _group_norm_forward torch.nn.LayerNorm.forward = _layer_norm_forward def apply_torch_compile(model, hot_indices=(5, 8, 9)): from unifolm_wma.modules.networks.wma_model import ResBlock unet = model.model.diffusion_model compiled = 0 for idx in hot_indices: block = unet.output_blocks[idx] for layer in block: if isinstance(layer, ResBlock): layer._forward = torch.compile(layer._forward, mode="default") compiled += 1 print(f" torch.compile: {compiled} ResBlocks in output_blocks{list(hot_indices)}") def load_model(args): config = OmegaConf.load(args.config) config['model']['params']['wma_config']['params']['use_checkpoint'] = False model = instantiate_from_config(config.model) model.perframe_ae = args.perframe_ae from collections import OrderedDict state_dict = torch.load(args.ckpt_path, map_location="cpu") if "state_dict" in state_dict: state_dict = state_dict["state_dict"] try: model.load_state_dict(state_dict, strict=True) except Exception: new_sd = OrderedDict() for k, v in state_dict.items(): new_sd[k] = v for k in list(new_sd.keys()): if "framestride_embed" in k: new_sd[k.replace("framestride_embed", "fps_embedding")] = new_sd.pop(k) model.load_state_dict(new_sd, strict=True) model.eval() # Apply precision: bf16 diffusion + encoders + projectors, fp32/bf16 VAE model.model.to(torch.bfloat16) model.diffusion_autocast_dtype = torch.bfloat16 model.embedder.to(torch.bfloat16) model.image_proj_model.to(torch.bfloat16) model.encoder_autocast_dtype = None model.state_projector.to(torch.bfloat16) model.action_projector.to(torch.bfloat16) model.projector_autocast_dtype = None if args.vae_dtype == "bf16": model.first_stage_model.to(torch.bfloat16) # Compile hot ResBlocks apply_torch_compile(model) model = model.cuda() print(">>> Model loaded and ready.") return model, config # ────────────────────────────────────────────────────────────────────── # Data preparation (reused from world_model_interaction.py) # ────────────────────────────────────────────────────────────────────── def get_init_frame_path(data_dir, sample): rel = os.path.join(sample['data_dir'], str(sample['videoid']) + '.png') return os.path.join(data_dir, 'images', rel) def get_transition_path(data_dir, sample): rel = os.path.join(sample['data_dir'], str(sample['videoid']) + '.h5') return os.path.join(data_dir, 'transitions', rel) def prepare_init_input(start_idx, init_frame_path, transition_dict, frame_stride, wma_data, video_length=16, n_obs_steps=2): indices = [start_idx + frame_stride * i for i in range(video_length)] init_frame = Image.open(init_frame_path).convert('RGB') init_frame = torch.tensor(np.array(init_frame)).unsqueeze(0).permute(3, 0, 1, 2).float() if start_idx < n_obs_steps - 1: state_indices = list(range(0, start_idx + 1)) states = transition_dict['observation.state'][state_indices, :] num_padding = n_obs_steps - 1 - start_idx padding = states[0:1, :].repeat(num_padding, 1) states = torch.cat((padding, states), dim=0) else: state_indices = list(range(start_idx - n_obs_steps + 1, start_idx + 1)) states = transition_dict['observation.state'][state_indices, :] actions = transition_dict['action'][indices, :] ori_state_dim = states.shape[-1] ori_action_dim = actions.shape[-1] frames_action_state_dict = { 'action': actions, 'observation.state': states, } frames_action_state_dict = wma_data.normalizer(frames_action_state_dict) frames_action_state_dict = wma_data.get_uni_vec( frames_action_state_dict, transition_dict['action_type'], transition_dict['state_type'], ) if wma_data.spatial_transform is not None: init_frame = wma_data.spatial_transform(init_frame) init_frame = (init_frame / 255 - 0.5) * 2 data = {'observation.image': init_frame} data.update(frames_action_state_dict) return data, ori_state_dim, ori_action_dim def populate_queues(queues, batch): for key in batch: if key not in queues: continue if len(queues[key]) != queues[key].maxlen: while len(queues[key]) != queues[key].maxlen: queues[key].append(batch[key]) else: queues[key].append(batch[key]) return queues # ────────────────────────────────────────────────────────────────────── # Instrumented image_guided_synthesis_sim_mode with sub-stage timing # ────────────────────────────────────────────────────────────────────── def get_latent_z(model, videos): b, c, t, h, w = videos.shape x = rearrange(videos, 'b c t h w -> (b t) c h w') vae_dtype = next(model.first_stage_model.parameters()).dtype x = x.to(dtype=vae_dtype) z = model.encode_first_stage(x) z = rearrange(z, '(b t) c h w -> b c t h w', b=b, t=t) return z def save_results(video, filename, fps=8): video = video.detach().cpu() video = torch.clamp(video.float(), -1., 1.) n = video.shape[0] video = video.permute(2, 0, 1, 3, 4) frame_grids = [ torchvision.utils.make_grid(framesheet, nrow=int(n), padding=0) for framesheet in video ] grid = torch.stack(frame_grids, dim=0) grid = (grid + 1.0) / 2.0 grid = (grid * 255).to(torch.uint8).permute(0, 2, 3, 1) torchvision.io.write_video(filename, grid, fps=fps, video_codec='h264', options={'crf': '10'}) def profiled_synthesis(model, prompts, observation, noise_shape, ddim_steps, ddim_eta, unconditional_guidance_scale, fs, text_input, timestep_spacing, guidance_rescale, sim_mode, decode_video, records, prefix): """image_guided_synthesis_sim_mode with per-sub-stage CUDA event timing. Args: prefix: "policy" or "wm" — prepended to sub-stage names in records. """ b, _, t, _, _ = noise_shape batch_size = noise_shape[0] device = next(model.parameters()).device # --- sub-stage: ddim_sampler_init --- with CudaTimer(f"{prefix}/ddim_sampler_init", records): ddim_sampler = DDIMSampler(model) fs_t = torch.tensor([fs] * batch_size, dtype=torch.long, device=device) # --- sub-stage: image_embedding --- with CudaTimer(f"{prefix}/image_embedding", records): model_dtype = next(model.embedder.parameters()).dtype img = observation['observation.images.top'].permute(0, 2, 1, 3, 4) cond_img = rearrange(img, 'b o c h w -> (b o) c h w')[-1:].to(dtype=model_dtype) cond_img_emb = model.embedder(cond_img) cond_img_emb = model.image_proj_model(cond_img_emb) # --- sub-stage: vae_encode --- with CudaTimer(f"{prefix}/vae_encode", records): if model.model.conditioning_key == 'hybrid': z = get_latent_z(model, img.permute(0, 2, 1, 3, 4)) img_cat_cond = z[:, :, -1:, :, :] img_cat_cond = repeat(img_cat_cond, 'b c t h w -> b c (repeat t) h w', repeat=noise_shape[2]) cond = {"c_concat": [img_cat_cond]} # --- sub-stage: text_conditioning --- with CudaTimer(f"{prefix}/text_conditioning", records): if not text_input: prompts_use = [""] * batch_size else: prompts_use = prompts if isinstance(prompts, list) else [prompts] * batch_size cond_ins_emb = model.get_learned_conditioning(prompts_use) # --- sub-stage: projectors --- with CudaTimer(f"{prefix}/projectors", records): projector_dtype = next(model.state_projector.parameters()).dtype cond_state_emb = model.state_projector( observation['observation.state'].to(dtype=projector_dtype)) cond_state_emb = cond_state_emb + model.agent_state_pos_emb cond_action_emb = model.action_projector( observation['action'].to(dtype=projector_dtype)) cond_action_emb = cond_action_emb + model.agent_action_pos_emb if not sim_mode: cond_action_emb = torch.zeros_like(cond_action_emb) # --- sub-stage: cond_assembly --- with CudaTimer(f"{prefix}/cond_assembly", records): cond["c_crossattn"] = [ torch.cat([cond_state_emb, cond_action_emb, cond_ins_emb, cond_img_emb], dim=1) ] cond["c_crossattn_action"] = [ observation['observation.images.top'][:, :, -model.n_obs_steps_acting:], observation['observation.state'][:, -model.n_obs_steps_acting:], sim_mode, False, ] # --- sub-stage: ddim_sampling --- autocast_dtype = getattr(model, 'diffusion_autocast_dtype', None) if autocast_dtype is not None and device.type == 'cuda': autocast_ctx = torch.autocast('cuda', dtype=autocast_dtype) else: autocast_ctx = nullcontext() with CudaTimer(f"{prefix}/ddim_sampling", records): with autocast_ctx: samples, actions, states, _ = ddim_sampler.sample( S=ddim_steps, conditioning=cond, batch_size=batch_size, shape=noise_shape[1:], verbose=False, unconditional_guidance_scale=unconditional_guidance_scale, unconditional_conditioning=None, eta=ddim_eta, cfg_img=None, mask=None, x0=None, fs=fs_t, timestep_spacing=timestep_spacing, guidance_rescale=guidance_rescale, unconditional_conditioning_img_nonetext=None, ) # --- sub-stage: vae_decode --- batch_variants = None if decode_video: with CudaTimer(f"{prefix}/vae_decode", records): batch_variants = model.decode_first_stage(samples) else: records[f"{prefix}/vae_decode"].append(0.0) return batch_variants, actions, states # ────────────────────────────────────────────────────────────────────── # Instrumented iteration loop # ────────────────────────────────────────────────────────────────────── def run_profiled_iterations(model, args, config, noise_shape, device): """Run the full iteration loop with per-stage timing. Returns: all_records: list of dicts, one per itr, {stage_name: ms} """ # Load data csv_path = os.path.join(args.prompt_dir, f"{args.dataset}.csv") df = pd.read_csv(csv_path) sample = df.iloc[0] data_module = instantiate_from_config(config.data) data_module.setup() init_frame_path = get_init_frame_path(args.prompt_dir, sample) ori_fps = float(sample['fps']) fs = args.frame_stride model_input_fs = ori_fps // fs transition_path = get_transition_path(args.prompt_dir, sample) with h5py.File(transition_path, 'r') as h5f: transition_dict = {} for key in h5f.keys(): transition_dict[key] = torch.tensor(h5f[key][()]) for key in h5f.attrs.keys(): transition_dict[key] = h5f.attrs[key] # Prepare initial observation batch, ori_state_dim, ori_action_dim = prepare_init_input( 0, init_frame_path, transition_dict, fs, data_module.test_datasets[args.dataset], n_obs_steps=model.n_obs_steps_imagen) observation = { 'observation.images.top': batch['observation.image'].permute(1, 0, 2, 3)[-1].unsqueeze(0), 'observation.state': batch['observation.state'][-1].unsqueeze(0), 'action': torch.zeros_like(batch['action'][-1]).unsqueeze(0), } observation = {k: v.to(device, non_blocking=True) for k, v in observation.items()} cond_obs_queues = { "observation.images.top": deque(maxlen=model.n_obs_steps_imagen), "observation.state": deque(maxlen=model.n_obs_steps_imagen), "action": deque(maxlen=args.video_length), } cond_obs_queues = populate_queues(cond_obs_queues, observation) # Temp dir for save_results profiling tmp_dir = os.path.join(args.savedir, "profile_tmp") os.makedirs(tmp_dir, exist_ok=True) prompt_text = sample['instruction'] all_records = [] print(f">>> Running {args.n_iter} profiled iterations ...") for itr in range(args.n_iter): rec = defaultdict(list) # ── itr_total start ── torch.cuda.synchronize() itr_start = torch.cuda.Event(enable_timing=True) itr_end = torch.cuda.Event(enable_timing=True) itr_start.record() # ① stack_to_device_1 with CudaTimer("stack_to_device_1", rec): observation = { 'observation.images.top': torch.stack(list(cond_obs_queues['observation.images.top']), dim=1).permute(0, 2, 1, 3, 4), 'observation.state': torch.stack(list(cond_obs_queues['observation.state']), dim=1), 'action': torch.stack(list(cond_obs_queues['action']), dim=1), } observation = {k: v.to(device, non_blocking=True) for k, v in observation.items()} # ② synth_policy with CudaTimer("synth_policy", rec): pred_videos_0, pred_actions, _ = profiled_synthesis( model, prompt_text, observation, noise_shape, ddim_steps=args.ddim_steps, ddim_eta=args.ddim_eta, unconditional_guidance_scale=args.unconditional_guidance_scale, fs=model_input_fs, text_input=True, timestep_spacing=args.timestep_spacing, guidance_rescale=args.guidance_rescale, sim_mode=False, decode_video=not args.fast_policy_no_decode, records=rec, prefix="policy") # ③ update_action_queue with WallTimer("update_action_queue", rec): for idx in range(len(pred_actions[0])): obs_a = {'action': pred_actions[0][idx:idx + 1]} obs_a['action'][:, ori_action_dim:] = 0.0 cond_obs_queues = populate_queues(cond_obs_queues, obs_a) # ④ stack_to_device_2 with CudaTimer("stack_to_device_2", rec): observation = { 'observation.images.top': torch.stack(list(cond_obs_queues['observation.images.top']), dim=1).permute(0, 2, 1, 3, 4), 'observation.state': torch.stack(list(cond_obs_queues['observation.state']), dim=1), 'action': torch.stack(list(cond_obs_queues['action']), dim=1), } observation = {k: v.to(device, non_blocking=True) for k, v in observation.items()} # ⑤ synth_world_model with CudaTimer("synth_world_model", rec): pred_videos_1, _, pred_states = profiled_synthesis( model, "", observation, noise_shape, ddim_steps=args.ddim_steps, ddim_eta=args.ddim_eta, unconditional_guidance_scale=args.unconditional_guidance_scale, fs=model_input_fs, text_input=False, timestep_spacing=args.timestep_spacing, guidance_rescale=args.guidance_rescale, sim_mode=True, decode_video=True, records=rec, prefix="wm") # ⑥ update_obs_queue with WallTimer("update_obs_queue", rec): for idx in range(args.exe_steps): obs_u = { 'observation.images.top': pred_videos_1[0][:, idx:idx + 1].permute(1, 0, 2, 3), 'observation.state': pred_states[0][idx:idx + 1], 'action': torch.zeros_like(pred_actions[0][-1:]), } obs_u['observation.state'][:, ori_state_dim:] = 0.0 cond_obs_queues = populate_queues(cond_obs_queues, obs_u) # ⑦ tensorboard_log (simulate — no actual writer, measure make_grid cost) with WallTimer("tensorboard_log", rec): for vid in [pred_videos_0, pred_videos_1]: if vid is not None and vid.dim() == 5: v = vid.permute(2, 0, 1, 3, 4) grids = [torchvision.utils.make_grid(f, nrow=1, padding=0) for f in v] _ = torch.stack(grids, dim=0) # ⑧ save_results with WallTimer("save_results", rec): if pred_videos_0 is not None: save_results(pred_videos_0.cpu(), os.path.join(tmp_dir, f"dm_{itr}.mp4"), fps=args.save_fps) save_results(pred_videos_1.cpu(), os.path.join(tmp_dir, f"wm_{itr}.mp4"), fps=args.save_fps) # ⑨ cpu_transfer with CudaTimer("cpu_transfer", rec): _ = pred_videos_1[:, :, :args.exe_steps].cpu() # ── itr_total end ── itr_end.record() torch.cuda.synchronize() itr_total_ms = itr_start.elapsed_time(itr_end) rec["itr_total"].append(itr_total_ms) # Flatten: each stage has exactly one entry per itr itr_rec = {k: v[0] for k, v in rec.items()} all_records.append(itr_rec) # Print live progress print(f" itr {itr}: {itr_total_ms:.0f} ms total | " f"policy={itr_rec.get('synth_policy', 0):.0f} | " f"wm={itr_rec.get('synth_world_model', 0):.0f} | " f"save={itr_rec.get('save_results', 0):.0f} | " f"tb={itr_rec.get('tensorboard_log', 0):.0f}") return all_records # ────────────────────────────────────────────────────────────────────── # Layer 1: Console report # ────────────────────────────────────────────────────────────────────── def print_iteration_report(all_records, warmup=1): """Print a structured table of per-stage timing across iterations.""" if len(all_records) <= warmup: records = all_records else: records = all_records[warmup:] print(f"\n(Skipping first {warmup} itr(s) as warmup)\n") # Collect all stage keys in a stable order all_keys = [] seen = set() for rec in records: for k in rec: if k not in seen: all_keys.append(k) seen.add(k) # Separate top-level stages from sub-stages top_keys = [k for k in all_keys if '/' not in k] sub_keys = [k for k in all_keys if '/' in k] def _print_table(keys, title): if not keys: return print("=" * 82) print(title) print("=" * 82) print(f"{'Stage':<35} {'Mean(ms)':>10} {'Std':>8} {'Min':>10} {'Max':>10} {'%':>7}") print("-" * 82) total_mean = np.mean([rec.get("itr_total", 0) for rec in records]) for k in keys: vals = [rec.get(k, 0) for rec in records] mean = np.mean(vals) std = np.std(vals) mn = np.min(vals) mx = np.max(vals) pct = mean / total_mean * 100 if total_mean > 0 else 0 print(f"{k:<35} {mean:>10.1f} {std:>8.1f} {mn:>10.1f} {mx:>10.1f} {pct:>6.1f}%") print("-" * 82) print() _print_table(top_keys, "TABLE 1: ITERATION-LEVEL BREAKDOWN") _print_table(sub_keys, "TABLE 2: SYNTHESIS SUB-STAGE BREAKDOWN") # ────────────────────────────────────────────────────────────────────── # Layer 3: CSV output for A/B comparison # ────────────────────────────────────────────────────────────────────── def write_csv(all_records, csv_path, warmup=1): """Write per-iteration timing to CSV for later comparison.""" records = all_records[warmup:] if len(all_records) > warmup else all_records # Collect all keys all_keys = [] seen = set() for rec in records: for k in rec: if k not in seen: all_keys.append(k) seen.add(k) with open(csv_path, 'w', newline='') as f: writer = csv.DictWriter(f, fieldnames=['itr'] + all_keys) writer.writeheader() for i, rec in enumerate(records): row = {'itr': i} row.update({k: f"{rec.get(k, 0):.2f}" for k in all_keys}) writer.writerow(row) # Also write a summary row summary_path = csv_path.replace('.csv', '_summary.csv') with open(summary_path, 'w', newline='') as f: writer = csv.DictWriter(f, fieldnames=['stat'] + all_keys) writer.writeheader() for stat_name, stat_fn in [('mean', np.mean), ('std', np.std), ('min', np.min), ('max', np.max)]: row = {'stat': stat_name} row.update({k: f"{stat_fn([r.get(k, 0) for r in records]):.2f}" for k in all_keys}) writer.writerow(row) print(f">>> CSV written to: {csv_path}") print(f">>> Summary written to: {summary_path}") def compare_csvs(path_a, path_b): """Compare two summary CSVs and print a diff table.""" df_a = pd.read_csv(path_a, index_col='stat') df_b = pd.read_csv(path_b, index_col='stat') # Use mean row for comparison mean_a = df_a.loc['mean'].astype(float) mean_b = df_b.loc['mean'].astype(float) print("=" * 90) print(f"A/B COMPARISON: {os.path.basename(path_a)} vs {os.path.basename(path_b)}") print("=" * 90) print(f"{'Stage':<35} {'A(ms)':>10} {'B(ms)':>10} {'Diff':>10} {'Speedup':>10}") print("-" * 90) for col in mean_a.index: if col not in mean_b.index: continue a_val = mean_a[col] b_val = mean_b[col] diff = b_val - a_val speedup = a_val / b_val if b_val > 0 else float('inf') marker = " <<<" if abs(diff) > 50 else "" print(f"{col:<35} {a_val:>10.1f} {b_val:>10.1f} {diff:>+10.1f} {speedup:>9.2f}x{marker}") print("-" * 90) total_a = mean_a.get('itr_total', 0) total_b = mean_b.get('itr_total', 0) print(f"{'itr_total':<35} {total_a:>10.1f} {total_b:>10.1f} " f"{total_b - total_a:>+10.1f} {total_a / total_b if total_b > 0 else 0:>9.2f}x") print() # ────────────────────────────────────────────────────────────────────── # Layer 2: GPU timeline trace wrapper # ────────────────────────────────────────────────────────────────────── def run_with_trace(model, args, config, noise_shape, device): """Run iterations under torch.profiler to generate Chrome/TensorBoard traces.""" trace_dir = args.trace_dir os.makedirs(trace_dir, exist_ok=True) # We need the same data setup as run_profiled_iterations csv_path = os.path.join(args.prompt_dir, f"{args.dataset}.csv") df = pd.read_csv(csv_path) sample = df.iloc[0] data_module = instantiate_from_config(config.data) data_module.setup() init_frame_path = get_init_frame_path(args.prompt_dir, sample) ori_fps = float(sample['fps']) fs = args.frame_stride model_input_fs = ori_fps // fs transition_path = get_transition_path(args.prompt_dir, sample) with h5py.File(transition_path, 'r') as h5f: transition_dict = {} for key in h5f.keys(): transition_dict[key] = torch.tensor(h5f[key][()]) for key in h5f.attrs.keys(): transition_dict[key] = h5f.attrs[key] batch, ori_state_dim, ori_action_dim = prepare_init_input( 0, init_frame_path, transition_dict, fs, data_module.test_datasets[args.dataset], n_obs_steps=model.n_obs_steps_imagen) observation = { 'observation.images.top': batch['observation.image'].permute(1, 0, 2, 3)[-1].unsqueeze(0), 'observation.state': batch['observation.state'][-1].unsqueeze(0), 'action': torch.zeros_like(batch['action'][-1]).unsqueeze(0), } observation = {k: v.to(device, non_blocking=True) for k, v in observation.items()} cond_obs_queues = { "observation.images.top": deque(maxlen=model.n_obs_steps_imagen), "observation.state": deque(maxlen=model.n_obs_steps_imagen), "action": deque(maxlen=args.video_length), } cond_obs_queues = populate_queues(cond_obs_queues, observation) tmp_dir = os.path.join(args.savedir, "profile_tmp") os.makedirs(tmp_dir, exist_ok=True) prompt_text = sample['instruction'] # Total iterations: warmup + active n_warmup = 1 n_active = min(args.n_iter, 2) # trace 2 active iterations max n_total = n_warmup + n_active print(f">>> GPU trace: {n_warmup} warmup + {n_active} active iterations") print(f">>> Trace output: {trace_dir}") with torch.no_grad(), torch.profiler.profile( activities=[ torch.profiler.ProfilerActivity.CPU, torch.profiler.ProfilerActivity.CUDA, ], schedule=torch.profiler.schedule( wait=0, warmup=n_warmup, active=n_active, repeat=1), on_trace_ready=torch.profiler.tensorboard_trace_handler(trace_dir), record_shapes=True, with_stack=True, ) as prof: for itr_idx in range(n_total): phase = "warmup" if itr_idx < n_warmup else "active" print(f" trace itr {itr_idx} ({phase})...") # ── One full iteration (same logic as run_inference) ── obs_loc = { 'observation.images.top': torch.stack(list(cond_obs_queues['observation.images.top']), dim=1).permute(0, 2, 1, 3, 4), 'observation.state': torch.stack(list(cond_obs_queues['observation.state']), dim=1), 'action': torch.stack(list(cond_obs_queues['action']), dim=1), } obs_loc = {k: v.to(device) for k, v in obs_loc.items()} # Policy pass dummy_rec = defaultdict(list) pv0, pa, _ = profiled_synthesis( model, prompt_text, obs_loc, noise_shape, ddim_steps=args.ddim_steps, ddim_eta=args.ddim_eta, unconditional_guidance_scale=args.unconditional_guidance_scale, fs=model_input_fs, text_input=True, timestep_spacing=args.timestep_spacing, guidance_rescale=args.guidance_rescale, sim_mode=False, decode_video=not args.fast_policy_no_decode, records=dummy_rec, prefix="policy") for idx in range(len(pa[0])): oa = {'action': pa[0][idx:idx + 1]} oa['action'][:, ori_action_dim:] = 0.0 populate_queues(cond_obs_queues, oa) # Re-stack for world model obs_loc2 = { 'observation.images.top': torch.stack(list(cond_obs_queues['observation.images.top']), dim=1).permute(0, 2, 1, 3, 4), 'observation.state': torch.stack(list(cond_obs_queues['observation.state']), dim=1), 'action': torch.stack(list(cond_obs_queues['action']), dim=1), } obs_loc2 = {k: v.to(device) for k, v in obs_loc2.items()} # World model pass pv1, _, ps = profiled_synthesis( model, "", obs_loc2, noise_shape, ddim_steps=args.ddim_steps, ddim_eta=args.ddim_eta, unconditional_guidance_scale=args.unconditional_guidance_scale, fs=model_input_fs, text_input=False, timestep_spacing=args.timestep_spacing, guidance_rescale=args.guidance_rescale, sim_mode=True, decode_video=True, records=dummy_rec, prefix="wm") # Update obs queue for idx in range(args.exe_steps): ou = { 'observation.images.top': pv1[0][:, idx:idx + 1].permute(1, 0, 2, 3), 'observation.state': ps[0][idx:idx + 1], 'action': torch.zeros_like(pa[0][-1:]), } ou['observation.state'][:, ori_state_dim:] = 0.0 populate_queues(cond_obs_queues, ou) # Save results (captures CPU stall in trace) if pv0 is not None: save_results(pv0.cpu(), os.path.join(tmp_dir, f"trace_dm_{itr_idx}.mp4"), fps=args.save_fps) save_results(pv1.cpu(), os.path.join(tmp_dir, f"trace_wm_{itr_idx}.mp4"), fps=args.save_fps) prof.step() print(f">>> Trace saved to {trace_dir}") print(" View with: tensorboard --logdir", trace_dir) print(" Or open the .json file in chrome://tracing") # ────────────────────────────────────────────────────────────────────── # Argument parser # ────────────────────────────────────────────────────────────────────── def get_parser(): p = argparse.ArgumentParser(description="Profile full iteration loop") # Compare mode (no model needed) p.add_argument("--compare", nargs=2, metavar=("A_SUMMARY", "B_SUMMARY"), help="Compare two summary CSVs and exit") # Model / data p.add_argument("--ckpt_path", type=str, default=None) p.add_argument("--config", type=str, default=None) p.add_argument("--prompt_dir", type=str, default=None) p.add_argument("--dataset", type=str, default=None) p.add_argument("--savedir", type=str, default="profile_output") # Inference params (match world_model_interaction.py) p.add_argument("--ddim_steps", type=int, default=50) p.add_argument("--ddim_eta", type=float, default=1.0) p.add_argument("--bs", type=int, default=1) p.add_argument("--height", type=int, default=320) p.add_argument("--width", type=int, default=512) p.add_argument("--frame_stride", type=int, default=4) p.add_argument("--unconditional_guidance_scale", type=float, default=1.0) p.add_argument("--video_length", type=int, default=16) p.add_argument("--timestep_spacing", type=str, default="uniform_trailing") p.add_argument("--guidance_rescale", type=float, default=0.7) p.add_argument("--exe_steps", type=int, default=16) p.add_argument("--n_iter", type=int, default=5) p.add_argument("--save_fps", type=int, default=8) p.add_argument("--seed", type=int, default=123) p.add_argument("--perframe_ae", action='store_true', default=False) p.add_argument("--vae_dtype", type=str, choices=["fp32", "bf16"], default="bf16") p.add_argument("--fast_policy_no_decode", action='store_true', default=False) # Profiling control p.add_argument("--warmup", type=int, default=1, help="Number of warmup iterations to skip in statistics") p.add_argument("--csv", type=str, default=None, help="Write per-iteration timing to this CSV file") p.add_argument("--trace", action='store_true', default=False, help="Enable Layer 2: GPU timeline trace") p.add_argument("--trace_dir", type=str, default="./profile_traces", help="Directory for trace output") return p # ────────────────────────────────────────────────────────────────────── # Main # ────────────────────────────────────────────────────────────────────── def main(): patch_norm_bypass_autocast() parser = get_parser() args = parser.parse_args() # ── Compare mode: no model needed ── if args.compare: compare_csvs(args.compare[0], args.compare[1]) return # ── Validate required args ── for required in ['ckpt_path', 'config', 'prompt_dir', 'dataset']: if getattr(args, required) is None: parser.error(f"--{required} is required for profiling mode") seed_everything(args.seed) os.makedirs(args.savedir, exist_ok=True) # ── Load model ── print("=" * 60) print("PROFILE ITERATION — Loading model...") print("=" * 60) model, config = load_model(args) device = next(model.parameters()).device h, w = args.height // 8, args.width // 8 channels = model.model.diffusion_model.out_channels noise_shape = [args.bs, channels, args.video_length, h, w] print(f">>> Noise shape: {noise_shape}") print(f">>> DDIM steps: {args.ddim_steps}") print(f">>> fast_policy_no_decode: {args.fast_policy_no_decode}") # ── Layer 2: GPU trace (optional) ── if args.trace: with torch.no_grad(): run_with_trace(model, args, config, noise_shape, device) print() # ── Layer 1: Iteration-level breakdown ── print("=" * 60) print("LAYER 1: ITERATION-LEVEL PROFILING") print("=" * 60) with torch.no_grad(): all_records = run_profiled_iterations( model, args, config, noise_shape, device) # Print report print_iteration_report(all_records, warmup=args.warmup) # ── Layer 3: CSV output for A/B comparison ── if args.csv: write_csv(all_records, args.csv, warmup=args.warmup) print("Done.") if __name__ == '__main__': main()