From b0ebb7006ee24ccac59a7f12bc36da7ec163e715 Mon Sep 17 00:00:00 2001
From: olivame
Date: Tue, 10 Feb 2026 05:42:11 +0000
Subject: [PATCH] =?UTF-8?q?=E6=B7=BB=E5=8A=A0=E4=B8=89=E5=B1=82=E8=BF=AD?=
 =?UTF-8?q?=E4=BB=A3=E7=BA=A7=E6=80=A7=E8=83=BD=E5=88=86=E6=9E=90=E5=B7=A5?=
 =?UTF-8?q?=E5=85=B7=20profile=5Fiteration.py?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Layer1: CUDA Events 精确测量每个itr内10个阶段耗时
Layer2: torch.profiler GPU timeline trace
Layer3: CSV输出支持A/B对比

Co-Authored-By: Claude Opus 4.6 (1M context)
---
 scripts/evaluation/profile_iteration.py       | 986 ++++++++++++++++++
 .../case1/run_profile.sh                      |   5 +
 2 files changed, 991 insertions(+)
 create mode 100644 scripts/evaluation/profile_iteration.py
 create mode 100644 unitree_z1_dual_arm_cleanup_pencils/case1/run_profile.sh

diff --git a/scripts/evaluation/profile_iteration.py b/scripts/evaluation/profile_iteration.py
new file mode 100644
index 0000000..a213f97
--- /dev/null
+++ b/scripts/evaluation/profile_iteration.py
@@ -0,0 +1,986 @@
+"""
+Profile the full iteration loop of world model interaction.
+
+Three layers of profiling:
+  Layer 1: Iteration-level wall-clock breakdown (CUDA events)
+  Layer 2: GPU timeline trace (torch.profiler → Chrome trace)
+  Layer 3: A/B comparison (standardized CSV output)
+
+Usage:
+  # Layer 1 only (fast, default):
+  TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL=1 CUDA_VISIBLE_DEVICES=0 \
+  python scripts/evaluation/profile_iteration.py \
+    --ckpt_path ckpts/unifolm_wma_dual_mix_bf16.ckpt \
+    --config configs/inference/world_model_interaction.yaml \
+    --prompt_dir unitree_z1_dual_arm_cleanup_pencils/case1/world_model_interaction_prompts \
+    --dataset unitree_z1_dual_arm_cleanup_pencils \
+    --frame_stride 4 --n_iter 5
+
+  # Layer 1 + Layer 2 (GPU trace):
+  ... --trace --trace_dir ./profile_traces
+
+  # Layer 3 (A/B comparison): run twice, then compare the generated *_summary.csv files
+  ... --csv baseline.csv
+  ... --csv optimized.csv
+  python scripts/evaluation/profile_iteration.py --compare baseline_summary.csv optimized_summary.csv
+"""
+
+import argparse
+import csv
+import os
+import sys
+import time
+from collections import defaultdict, deque
+from contextlib import nullcontext
+
+import h5py
+import numpy as np
+import pandas as pd
+import torch
+import torchvision
+from einops import rearrange, repeat
+from omegaconf import OmegaConf
+from PIL import Image
+from pytorch_lightning import seed_everything
+from torch import Tensor
+
+from unifolm_wma.models.samplers.ddim import DDIMSampler
+from unifolm_wma.utils.utils import instantiate_from_config
+import torch.nn.functional as F
+
+
+# ──────────────────────────────────────────────────────────────────────
+# Constants
+# ──────────────────────────────────────────────────────────────────────
+STAGE_NAMES = [
+    "stack_to_device_1",
+    "synth_policy",
+    "update_action_queue",
+    "stack_to_device_2",
+    "synth_world_model",
+    "update_obs_queue",
+    "tensorboard_log",
+    "save_results",
+    "cpu_transfer",
+    "itr_total",
+]
+
+# Sub-stages inside image_guided_synthesis_sim_mode
+SYNTH_SUB_STAGES = [
+    "ddim_sampler_init",
+    "image_embedding",
+    "vae_encode",
+    "text_conditioning",
+    "projectors",
+    "cond_assembly",
+    "ddim_sampling",
+    "vae_decode",
+]
+
+
+# ──────────────────────────────────────────────────────────────────────
+# CudaTimer — GPU-precise timing via CUDA events
+# ──────────────────────────────────────────────────────────────────────
+class CudaTimer:
+    """Context manager that records GPU time between enter/exit using CUDA events."""
+
+    def __init__(self, name, records):
+        self.name = name
+        self.records = records
+
+    def __enter__(self):
+        torch.cuda.synchronize()
+        self._start = torch.cuda.Event(enable_timing=True)
+        self._end = torch.cuda.Event(enable_timing=True)
+        self._start.record()
+        return self
+
+    def __exit__(self, *args):
+        self._end.record()
+        torch.cuda.synchronize()
+        elapsed_ms = self._start.elapsed_time(self._end)
+        self.records[self.name].append(elapsed_ms)
+
+
+class WallTimer:
+    """Context manager that records CPU wall-clock time (for pure-CPU stages)."""
+
+    def __init__(self, name, records):
+        self.name = name
+        self.records = records
+
+    def __enter__(self):
+        torch.cuda.synchronize()
+        self._t0 = time.perf_counter()
+        return self
+
+    def __exit__(self, *args):
+        torch.cuda.synchronize()
+        elapsed_ms = (time.perf_counter() - self._t0) * 1000.0
+        self.records[self.name].append(elapsed_ms)
+
+
+# ──────────────────────────────────────────────────────────────────────
+# Model loading (reused from world_model_interaction.py)
+# ──────────────────────────────────────────────────────────────────────
+def patch_norm_bypass_autocast():
+    """Monkey-patch GroupNorm/LayerNorm.forward to run with CUDA autocast disabled, in the input dtype."""
+    def _group_norm_forward(self, x):
+        with torch.amp.autocast('cuda', enabled=False):
+            return F.group_norm(
+                x, self.num_groups,
+                self.weight.to(x.dtype) if self.weight is not None else None,
+                self.bias.to(x.dtype) if self.bias is not None else None,
+                self.eps)
+
+    def _layer_norm_forward(self, x):
+        with torch.amp.autocast('cuda', enabled=False):
+            return F.layer_norm(
+                x, self.normalized_shape,
+                self.weight.to(x.dtype) if self.weight is not None else None,
+                self.bias.to(x.dtype) if self.bias is not None else None,
+                self.eps)
+
+    torch.nn.GroupNorm.forward = _group_norm_forward
+    torch.nn.LayerNorm.forward = _layer_norm_forward
+
+
+def apply_torch_compile(model, hot_indices=(5, 8, 9)):
+    """Wrap `_forward` of every ResBlock in the selected UNet output blocks with torch.compile."""
+    from unifolm_wma.modules.networks.wma_model import ResBlock
+    unet = model.model.diffusion_model
+    compiled = 0
+    for idx in hot_indices:
+        block = unet.output_blocks[idx]
+        for layer in block:
+            if isinstance(layer, ResBlock):
+                layer._forward = torch.compile(layer._forward, mode="default")
+                compiled += 1
+    print(f" torch.compile: {compiled} ResBlocks in output_blocks{list(hot_indices)}")
+
+
+def load_model(args):
+    """Build the model from args.config, load args.ckpt_path, cast to bf16, move to CUDA; return (model, config)."""
+    config = OmegaConf.load(args.config)
+    config['model']['params']['wma_config']['params']['use_checkpoint'] = False
+    model = instantiate_from_config(config.model)
+    model.perframe_ae = args.perframe_ae
+
+    from collections import OrderedDict
+    state_dict = torch.load(args.ckpt_path, map_location="cpu")
+    if "state_dict" in state_dict:
+        state_dict = state_dict["state_dict"]
+    try:
+        model.load_state_dict(state_dict, strict=True)
+    except Exception:
+        new_sd = OrderedDict()
+        for k, v in state_dict.items():
+            new_sd[k] = v
+        for k in list(new_sd.keys()):
+            if "framestride_embed" in k:
+                new_sd[k.replace("framestride_embed", "fps_embedding")] = new_sd.pop(k)
+        model.load_state_dict(new_sd, strict=True)
+
+    model.eval()
+
+    # Apply precision: bf16 diffusion + encoders + projectors, fp32/bf16 VAE
+    model.model.to(torch.bfloat16)
+    model.diffusion_autocast_dtype = torch.bfloat16
+    model.embedder.to(torch.bfloat16)
+    model.image_proj_model.to(torch.bfloat16)
+    model.encoder_autocast_dtype = None
+    model.state_projector.to(torch.bfloat16)
+    model.action_projector.to(torch.bfloat16)
+    model.projector_autocast_dtype = None
+    if args.vae_dtype == "bf16":
+        model.first_stage_model.to(torch.bfloat16)
+
+    # Compile hot ResBlocks
+    apply_torch_compile(model)
+    model = model.cuda()
+    print(">>> Model loaded and ready.")
+    return model, config
+
+
+# ──────────────────────────────────────────────────────────────────────
+# Data preparation (reused from world_model_interaction.py)
+# ──────────────────────────────────────────────────────────────────────
+def get_init_frame_path(data_dir, sample):
+    """Return <data_dir>/images/<sample['data_dir']>/<sample['videoid']>.png."""
+    rel = os.path.join(sample['data_dir'], str(sample['videoid']) + '.png')
+    return os.path.join(data_dir, 'images', rel)
+
+
+def get_transition_path(data_dir, sample):
+    """Return <data_dir>/transitions/<sample['data_dir']>/<sample['videoid']>.h5."""
+    rel = os.path.join(sample['data_dir'], str(sample['videoid']) + '.h5')
+    return os.path.join(data_dir, 'transitions', rel)
+
+
+def prepare_init_input(start_idx, init_frame_path, transition_dict,
+                       frame_stride, wma_data, video_length=16, n_obs_steps=2):
+    """Build the initial input dict (frame, states, actions); return it with the original state/action dims."""
+    indices = [start_idx + frame_stride * i for i in range(video_length)]
+    init_frame = Image.open(init_frame_path).convert('RGB')
+    init_frame = torch.tensor(np.array(init_frame)).unsqueeze(0).permute(3, 0, 1, 2).float()
+
+    if start_idx < n_obs_steps - 1:
+        state_indices = list(range(0, start_idx + 1))
+        states = transition_dict['observation.state'][state_indices, :]
+        num_padding = n_obs_steps - 1 - start_idx
+        padding = states[0:1, :].repeat(num_padding, 1)
+        states = torch.cat((padding, states), dim=0)
+    else:
+        state_indices = list(range(start_idx - n_obs_steps + 1, start_idx + 1))
+        states = transition_dict['observation.state'][state_indices, :]
+
+    actions = transition_dict['action'][indices, :]
+    ori_state_dim = states.shape[-1]
+    ori_action_dim = actions.shape[-1]
+
+    frames_action_state_dict = {
+        'action': actions,
+        'observation.state': states,
+    }
+    frames_action_state_dict = wma_data.normalizer(frames_action_state_dict)
+    frames_action_state_dict = wma_data.get_uni_vec(
+        frames_action_state_dict,
+        transition_dict['action_type'],
+        transition_dict['state_type'],
+    )
+
+    if wma_data.spatial_transform is not None:
+        init_frame = wma_data.spatial_transform(init_frame)
+    init_frame = (init_frame / 255 - 0.5) * 2
+
+    data = {'observation.image': init_frame}
+    data.update(frames_action_state_dict)
+    return data, ori_state_dim, ori_action_dim
+
+
+def populate_queues(queues, batch):
+    """Append batch entries to same-named queues; a not-yet-full queue is filled to maxlen with the entry."""
+    for key in batch:
+        if key not in queues:
+            continue
+        if len(queues[key]) != queues[key].maxlen:
+            while len(queues[key]) != queues[key].maxlen:
+                queues[key].append(batch[key])
+        else:
+            queues[key].append(batch[key])
+    return queues
+
+
+# ──────────────────────────────────────────────────────────────────────
+# Instrumented image_guided_synthesis_sim_mode with sub-stage timing
+# ──────────────────────────────────────────────────────────────────────
+def get_latent_z(model, videos):
+    """Encode a (b, c, t, h, w) video frame-by-frame with the first-stage model; return a (b, c, t, h, w) latent."""
+    b, c, t, h, w = videos.shape
+    x = rearrange(videos, 'b c t h w -> (b t) c h w')
+    vae_dtype = next(model.first_stage_model.parameters()).dtype
+    x = x.to(dtype=vae_dtype)
+    z = model.encode_first_stage(x)
+    z = rearrange(z, '(b t) c h w -> b c t h w', b=b, t=t)
+    return z
+
+
+def save_results(video, filename, fps=8):
+    """Clamp video to [-1, 1], tile the batch into one grid per frame, and write an H.264 video to filename."""
+    video = video.detach().cpu()
+    video = torch.clamp(video.float(), -1., 1.)
+    n = video.shape[0]
+    video = video.permute(2, 0, 1, 3, 4)
+    frame_grids = [
+        torchvision.utils.make_grid(framesheet, nrow=int(n), padding=0)
+        for framesheet in video
+    ]
+    grid = torch.stack(frame_grids, dim=0)
+    grid = (grid + 1.0) / 2.0
+    grid = (grid * 255).to(torch.uint8).permute(0, 2, 3, 1)
+    torchvision.io.write_video(filename, grid, fps=fps,
+                               video_codec='h264', options={'crf': '10'})
+
+
+def profiled_synthesis(model, prompts, observation, noise_shape,
+                       ddim_steps, ddim_eta, unconditional_guidance_scale,
+                       fs, text_input, timestep_spacing, guidance_rescale,
+                       sim_mode, decode_video, records, prefix):
+    """image_guided_synthesis_sim_mode with per-sub-stage CUDA event timing.
+
+    Args:
+        prefix: "policy" or "wm" — prepended to sub-stage names in records.
+    """
+    b, _, t, _, _ = noise_shape
+    batch_size = noise_shape[0]
+    device = next(model.parameters()).device
+
+    # --- sub-stage: ddim_sampler_init ---
+    with CudaTimer(f"{prefix}/ddim_sampler_init", records):
+        ddim_sampler = DDIMSampler(model)
+        fs_t = torch.tensor([fs] * batch_size, dtype=torch.long, device=device)
+
+    # --- sub-stage: image_embedding ---
+    with CudaTimer(f"{prefix}/image_embedding", records):
+        model_dtype = next(model.embedder.parameters()).dtype
+        img = observation['observation.images.top'].permute(0, 2, 1, 3, 4)
+        cond_img = rearrange(img, 'b o c h w -> (b o) c h w')[-1:].to(dtype=model_dtype)
+        cond_img_emb = model.embedder(cond_img)
+        cond_img_emb = model.image_proj_model(cond_img_emb)
+
+    # --- sub-stage: vae_encode ---
+    with CudaTimer(f"{prefix}/vae_encode", records):
+        if model.model.conditioning_key == 'hybrid':
+            z = get_latent_z(model, img.permute(0, 2, 1, 3, 4))
+            img_cat_cond = z[:, :, -1:, :, :]
+            img_cat_cond = repeat(img_cat_cond,
+                                  'b c t h w -> b c (repeat t) h w',
+                                  repeat=noise_shape[2])
+            cond = {"c_concat": [img_cat_cond]}
+
+    # --- sub-stage: text_conditioning ---
+    with CudaTimer(f"{prefix}/text_conditioning", records):
+        if not text_input:
+            prompts_use = [""] * batch_size
+        else:
+            prompts_use = prompts if isinstance(prompts, list) else [prompts] * batch_size
+        cond_ins_emb = model.get_learned_conditioning(prompts_use)
+
+    # --- sub-stage: projectors ---
+    with CudaTimer(f"{prefix}/projectors", records):
+        projector_dtype = next(model.state_projector.parameters()).dtype
+        cond_state_emb = model.state_projector(
+            observation['observation.state'].to(dtype=projector_dtype))
+        cond_state_emb = cond_state_emb + model.agent_state_pos_emb
+        cond_action_emb = model.action_projector(
+            observation['action'].to(dtype=projector_dtype))
+        cond_action_emb = cond_action_emb + model.agent_action_pos_emb
+        if not sim_mode:
+            cond_action_emb = torch.zeros_like(cond_action_emb)
+
+    # --- sub-stage: cond_assembly ---
+    with CudaTimer(f"{prefix}/cond_assembly", records):
+        cond["c_crossattn"] = [
+            torch.cat([cond_state_emb, cond_action_emb, cond_ins_emb, cond_img_emb], dim=1)
+        ]
+        cond["c_crossattn_action"] = [
+            observation['observation.images.top'][:, :, -model.n_obs_steps_acting:],
+            observation['observation.state'][:, -model.n_obs_steps_acting:],
+            sim_mode,
+            False,
+        ]
+
+    # --- sub-stage: ddim_sampling ---
+    autocast_dtype = getattr(model, 'diffusion_autocast_dtype', None)
+    if autocast_dtype is not None and device.type == 'cuda':
+        autocast_ctx = torch.autocast('cuda', dtype=autocast_dtype)
+    else:
+        autocast_ctx = nullcontext()
+
+    with CudaTimer(f"{prefix}/ddim_sampling", records):
+        with autocast_ctx:
+            samples, actions, states, _ = ddim_sampler.sample(
+                S=ddim_steps,
+                conditioning=cond,
+                batch_size=batch_size,
+                shape=noise_shape[1:],
+                verbose=False,
+                unconditional_guidance_scale=unconditional_guidance_scale,
+                unconditional_conditioning=None,
+                eta=ddim_eta,
+                cfg_img=None,
+                mask=None,
+                x0=None,
+                fs=fs_t,
+                timestep_spacing=timestep_spacing,
+                guidance_rescale=guidance_rescale,
+                unconditional_conditioning_img_nonetext=None,
+            )
+
+    # --- sub-stage: vae_decode ---
+    batch_variants = None
+    if decode_video:
+        with CudaTimer(f"{prefix}/vae_decode", records):
+            batch_variants = model.decode_first_stage(samples)
+    else:
+        records[f"{prefix}/vae_decode"].append(0.0)
+
+    return batch_variants, actions, states
+
+
+# ──────────────────────────────────────────────────────────────────────
+# Instrumented iteration loop
+# ──────────────────────────────────────────────────────────────────────
+def run_profiled_iterations(model, args, config, noise_shape, device):
+    """Run the full iteration loop with per-stage timing.
+
+    Returns:
+        all_records: list of dicts, one per itr, {stage_name: ms}
+    """
+    # Load data
+    csv_path = os.path.join(args.prompt_dir, f"{args.dataset}.csv")
+    df = pd.read_csv(csv_path)
+    sample = df.iloc[0]
+
+    data_module = instantiate_from_config(config.data)
+    data_module.setup()
+
+    init_frame_path = get_init_frame_path(args.prompt_dir, sample)
+    ori_fps = float(sample['fps'])
+    fs = args.frame_stride
+    model_input_fs = ori_fps // fs
+
+    transition_path = get_transition_path(args.prompt_dir, sample)
+    with h5py.File(transition_path, 'r') as h5f:
+        transition_dict = {}
+        for key in h5f.keys():
+            transition_dict[key] = torch.tensor(h5f[key][()])
+        for key in h5f.attrs.keys():
+            transition_dict[key] = h5f.attrs[key]
+
+    # Prepare initial observation
+    batch, ori_state_dim, ori_action_dim = prepare_init_input(
+        0, init_frame_path, transition_dict, fs,
+        data_module.test_datasets[args.dataset],
+        n_obs_steps=model.n_obs_steps_imagen)
+
+    observation = {
+        'observation.images.top':
+            batch['observation.image'].permute(1, 0, 2, 3)[-1].unsqueeze(0),
+        'observation.state':
+            batch['observation.state'][-1].unsqueeze(0),
+        'action':
+            torch.zeros_like(batch['action'][-1]).unsqueeze(0),
+    }
+    observation = {k: v.to(device, non_blocking=True) for k, v in observation.items()}
+
+    cond_obs_queues = {
+        "observation.images.top": deque(maxlen=model.n_obs_steps_imagen),
+        "observation.state": deque(maxlen=model.n_obs_steps_imagen),
+        "action": deque(maxlen=args.video_length),
+    }
+    cond_obs_queues = populate_queues(cond_obs_queues, observation)
+
+    # Temp dir for save_results profiling
+    tmp_dir = os.path.join(args.savedir, "profile_tmp")
+    os.makedirs(tmp_dir, exist_ok=True)
+
+    prompt_text = sample['instruction']
+    all_records = []
+
+    print(f">>> Running {args.n_iter} profiled iterations ...")
+    for itr in range(args.n_iter):
+        rec = defaultdict(list)
+
+        # ── itr_total start ──
+        torch.cuda.synchronize()
+        itr_start = torch.cuda.Event(enable_timing=True)
+        itr_end = torch.cuda.Event(enable_timing=True)
+        itr_start.record()
+
+        # ① stack_to_device_1
+        with CudaTimer("stack_to_device_1", rec):
+            observation = {
+                'observation.images.top':
+                    torch.stack(list(cond_obs_queues['observation.images.top']),
+                                dim=1).permute(0, 2, 1, 3, 4),
+                'observation.state':
+                    torch.stack(list(cond_obs_queues['observation.state']), dim=1),
+                'action':
+                    torch.stack(list(cond_obs_queues['action']), dim=1),
+            }
+            observation = {k: v.to(device, non_blocking=True) for k, v in observation.items()}
+
+        # ② synth_policy
+        with CudaTimer("synth_policy", rec):
+            pred_videos_0, pred_actions, _ = profiled_synthesis(
+                model, prompt_text, observation, noise_shape,
+                ddim_steps=args.ddim_steps, ddim_eta=args.ddim_eta,
+                unconditional_guidance_scale=args.unconditional_guidance_scale,
+                fs=model_input_fs, text_input=True,
+                timestep_spacing=args.timestep_spacing,
+                guidance_rescale=args.guidance_rescale,
+                sim_mode=False,
+                decode_video=not args.fast_policy_no_decode,
+                records=rec, prefix="policy")
+
+        # ③ update_action_queue
+        with WallTimer("update_action_queue", rec):
+            for idx in range(len(pred_actions[0])):
+                obs_a = {'action': pred_actions[0][idx:idx + 1]}
+                obs_a['action'][:, ori_action_dim:] = 0.0
+                cond_obs_queues = populate_queues(cond_obs_queues, obs_a)
+
+        # ④ stack_to_device_2
+        with CudaTimer("stack_to_device_2", rec):
+            observation = {
+                'observation.images.top':
+                    torch.stack(list(cond_obs_queues['observation.images.top']),
+                                dim=1).permute(0, 2, 1, 3, 4),
+                'observation.state':
+                    torch.stack(list(cond_obs_queues['observation.state']), dim=1),
+                'action':
+                    torch.stack(list(cond_obs_queues['action']), dim=1),
+            }
+            observation = {k: v.to(device, non_blocking=True) for k, v in observation.items()}
+
+        # ⑤ synth_world_model
+        with CudaTimer("synth_world_model", rec):
+            pred_videos_1, _, pred_states = profiled_synthesis(
+                model, "", observation, noise_shape,
+                ddim_steps=args.ddim_steps, ddim_eta=args.ddim_eta,
+                unconditional_guidance_scale=args.unconditional_guidance_scale,
+                fs=model_input_fs, text_input=False,
+                timestep_spacing=args.timestep_spacing,
+                guidance_rescale=args.guidance_rescale,
+                sim_mode=True, decode_video=True,
+                records=rec, prefix="wm")
+
+        # ⑥ update_obs_queue
+        with WallTimer("update_obs_queue", rec):
+            for idx in range(args.exe_steps):
+                obs_u = {
+                    'observation.images.top':
+                        pred_videos_1[0][:, idx:idx + 1].permute(1, 0, 2, 3),
+                    'observation.state':
+                        pred_states[0][idx:idx + 1],
+                    'action':
+                        torch.zeros_like(pred_actions[0][-1:]),
+                }
+                obs_u['observation.state'][:, ori_state_dim:] = 0.0
+                cond_obs_queues = populate_queues(cond_obs_queues, obs_u)
+
+        # ⑦ tensorboard_log (simulate — no actual writer, measure make_grid cost)
+        with WallTimer("tensorboard_log", rec):
+            for vid in [pred_videos_0, pred_videos_1]:
+                if vid is not None and vid.dim() == 5:
+                    v = vid.permute(2, 0, 1, 3, 4)
+                    grids = [torchvision.utils.make_grid(f, nrow=1, padding=0) for f in v]
+                    _ = torch.stack(grids, dim=0)
+
+        # ⑧ save_results
+        with WallTimer("save_results", rec):
+            if pred_videos_0 is not None:
+                save_results(pred_videos_0.cpu(),
+                             os.path.join(tmp_dir, f"dm_{itr}.mp4"),
+                             fps=args.save_fps)
+            save_results(pred_videos_1.cpu(),
+                         os.path.join(tmp_dir, f"wm_{itr}.mp4"),
+                         fps=args.save_fps)
+
+        # ⑨ cpu_transfer
+        with CudaTimer("cpu_transfer", rec):
+            _ = pred_videos_1[:, :, :args.exe_steps].cpu()
+
+        # ── itr_total end ──
+        itr_end.record()
+        torch.cuda.synchronize()
+        itr_total_ms = itr_start.elapsed_time(itr_end)
+        rec["itr_total"].append(itr_total_ms)
+
+        # Flatten: each stage has exactly one entry per itr
+        itr_rec = {k: v[0] for k, v in rec.items()}
+        all_records.append(itr_rec)
+
+        # Print live progress
+        print(f" itr {itr}: {itr_total_ms:.0f} ms total | "
+              f"policy={itr_rec.get('synth_policy', 0):.0f} | "
+              f"wm={itr_rec.get('synth_world_model', 0):.0f} | "
+              f"save={itr_rec.get('save_results', 0):.0f} | "
+              f"tb={itr_rec.get('tensorboard_log', 0):.0f}")
+
+    return all_records
+
+
+# ──────────────────────────────────────────────────────────────────────
+# Layer 1: Console report
+# ──────────────────────────────────────────────────────────────────────
+def print_iteration_report(all_records, warmup=1):
+    """Print a structured table of per-stage timing across iterations."""
+    if len(all_records) <= warmup:
+        records = all_records
+    else:
+        records = all_records[warmup:]
+        print(f"\n(Skipping first {warmup} itr(s) as warmup)\n")
+
+    # Collect all stage keys in a stable order
+    all_keys = []
+    seen = set()
+    for rec in records:
+        for k in rec:
+            if k not in seen:
+                all_keys.append(k)
+                seen.add(k)
+
+    # Separate top-level stages from sub-stages
+    top_keys = [k for k in all_keys if '/' not in k]
+    sub_keys = [k for k in all_keys if '/' in k]
+
+    def _print_table(keys, title):
+        if not keys:
+            return
+        print("=" * 82)
+        print(title)
+        print("=" * 82)
+        print(f"{'Stage':<35} {'Mean(ms)':>10} {'Std':>8} {'Min':>10} {'Max':>10} {'%':>7}")
+        print("-" * 82)
+
+        total_mean = np.mean([rec.get("itr_total", 0) for rec in records])
+        for k in keys:
+            vals = [rec.get(k, 0) for rec in records]
+            mean = np.mean(vals)
+            std = np.std(vals)
+            mn = np.min(vals)
+            mx = np.max(vals)
+            pct = mean / total_mean * 100 if total_mean > 0 else 0
+            print(f"{k:<35} {mean:>10.1f} {std:>8.1f} {mn:>10.1f} {mx:>10.1f} {pct:>6.1f}%")
+        print("-" * 82)
+        print()
+
+    _print_table(top_keys, "TABLE 1: ITERATION-LEVEL BREAKDOWN")
+    _print_table(sub_keys, "TABLE 2: SYNTHESIS SUB-STAGE BREAKDOWN")
+
+
+# ──────────────────────────────────────────────────────────────────────
+# Layer 3: CSV output for A/B comparison
+# ──────────────────────────────────────────────────────────────────────
+def write_csv(all_records, csv_path, warmup=1):
+    """Write per-iteration timing to CSV for later comparison."""
+    records = all_records[warmup:] if len(all_records) > warmup else all_records
+
+    # Collect all keys
+    all_keys = []
+    seen = set()
+    for rec in records:
+        for k in rec:
+            if k not in seen:
+                all_keys.append(k)
+                seen.add(k)
+
+    with open(csv_path, 'w', newline='') as f:
+        writer = csv.DictWriter(f, fieldnames=['itr'] + all_keys)
+        writer.writeheader()
+        for i, rec in enumerate(records):
+            row = {'itr': i}
+            row.update({k: f"{rec.get(k, 0):.2f}" for k in all_keys})
+            writer.writerow(row)
+
+    # Summary (mean/std/min/max) goes to a sibling file; splitext keeps it distinct from csv_path
+    summary_path = os.path.splitext(csv_path)[0] + '_summary.csv'
+    with open(summary_path, 'w', newline='') as f:
+        writer = csv.DictWriter(f, fieldnames=['stat'] + all_keys)
+        writer.writeheader()
+        for stat_name, stat_fn in [('mean', np.mean), ('std', np.std),
+                                   ('min', np.min), ('max', np.max)]:
+            row = {'stat': stat_name}
+            row.update({k: f"{stat_fn([r.get(k, 0) for r in records]):.2f}"
+                        for k in all_keys})
+            writer.writerow(row)
+
+    print(f">>> CSV written to: {csv_path}")
+    print(f">>> Summary written to: {summary_path}")
+
+
+def compare_csvs(path_a, path_b):
+    """Compare two summary CSVs and print a diff table."""
+    df_a = pd.read_csv(path_a, index_col='stat')
+    df_b = pd.read_csv(path_b, index_col='stat')
+
+    # Use mean row for comparison
+    mean_a = df_a.loc['mean'].astype(float)
+    mean_b = df_b.loc['mean'].astype(float)
+
+    print("=" * 90)
+    print(f"A/B COMPARISON: {os.path.basename(path_a)} vs {os.path.basename(path_b)}")
+    print("=" * 90)
+    print(f"{'Stage':<35} {'A(ms)':>10} {'B(ms)':>10} {'Diff':>10} {'Speedup':>10}")
+    print("-" * 90)
+
+    for col in mean_a.index:
+        if col not in mean_b.index:
+            continue
+        a_val = mean_a[col]
+        b_val = mean_b[col]
+        diff = b_val - a_val
+        speedup = a_val / b_val if b_val > 0 else float('inf')
+        marker = " <<<" if abs(diff) > 50 else ""
+        print(f"{col:<35} {a_val:>10.1f} {b_val:>10.1f} {diff:>+10.1f} {speedup:>9.2f}x{marker}")
+
+    print("-" * 90)
+    total_a = mean_a.get('itr_total', 0)
+    total_b = mean_b.get('itr_total', 0)
+    print(f"{'itr_total':<35} {total_a:>10.1f} {total_b:>10.1f} "
+          f"{total_b - total_a:>+10.1f} {total_a / total_b if total_b > 0 else 0:>9.2f}x")
+    print()
+
+
+# ──────────────────────────────────────────────────────────────────────
+# Layer 2: GPU timeline trace wrapper
+# ──────────────────────────────────────────────────────────────────────
+def run_with_trace(model, args, config, noise_shape, device):
+    """Run iterations under torch.profiler to generate Chrome/TensorBoard traces."""
+    trace_dir = args.trace_dir
+    os.makedirs(trace_dir, exist_ok=True)
+
+    # We need the same data setup as run_profiled_iterations
+    csv_path = os.path.join(args.prompt_dir, f"{args.dataset}.csv")
+    df = pd.read_csv(csv_path)
+    sample = df.iloc[0]
+
+    data_module = instantiate_from_config(config.data)
+    data_module.setup()
+
+    init_frame_path = get_init_frame_path(args.prompt_dir, sample)
+    ori_fps = float(sample['fps'])
+    fs = args.frame_stride
+    model_input_fs = ori_fps // fs
+
+    transition_path = get_transition_path(args.prompt_dir, sample)
+    with h5py.File(transition_path, 'r') as h5f:
+        transition_dict = {}
+        for key in h5f.keys():
+            transition_dict[key] = torch.tensor(h5f[key][()])
+        for key in h5f.attrs.keys():
+            transition_dict[key] = h5f.attrs[key]
+
+    batch, ori_state_dim, ori_action_dim = prepare_init_input(
+        0, init_frame_path, transition_dict, fs,
+        data_module.test_datasets[args.dataset],
+        n_obs_steps=model.n_obs_steps_imagen)
+
+    observation = {
+        'observation.images.top':
+            batch['observation.image'].permute(1, 0, 2, 3)[-1].unsqueeze(0),
+        'observation.state':
+            batch['observation.state'][-1].unsqueeze(0),
+        'action':
+            torch.zeros_like(batch['action'][-1]).unsqueeze(0),
+    }
+    observation = {k: v.to(device, non_blocking=True) for k, v in observation.items()}
+
+    cond_obs_queues = {
+        "observation.images.top": deque(maxlen=model.n_obs_steps_imagen),
+        "observation.state": deque(maxlen=model.n_obs_steps_imagen),
+        "action": deque(maxlen=args.video_length),
+    }
+    cond_obs_queues = populate_queues(cond_obs_queues, observation)
+
+    tmp_dir = os.path.join(args.savedir, "profile_tmp")
+    os.makedirs(tmp_dir, exist_ok=True)
+    prompt_text = sample['instruction']
+
+    # Total iterations: warmup + active
+    n_warmup = 1
+    n_active = min(args.n_iter, 2)  # trace 2 active iterations max
+    n_total = n_warmup + n_active
+
+    print(f">>> GPU trace: {n_warmup} warmup + {n_active} active iterations")
+    print(f">>> Trace output: {trace_dir}")
+
+    with torch.no_grad(), torch.profiler.profile(
+        activities=[
+            torch.profiler.ProfilerActivity.CPU,
+            torch.profiler.ProfilerActivity.CUDA,
+        ],
+        schedule=torch.profiler.schedule(
+            wait=0, warmup=n_warmup, active=n_active, repeat=1),
+        on_trace_ready=torch.profiler.tensorboard_trace_handler(trace_dir),
+        record_shapes=True,
+        with_stack=True,
+    ) as prof:
+        for itr_idx in range(n_total):
+            phase = "warmup" if itr_idx < n_warmup else "active"
+            print(f" trace itr {itr_idx} ({phase})...")
+
+            # ── One full iteration (same logic as run_inference) ──
+            obs_loc = {
+                'observation.images.top':
+                    torch.stack(list(cond_obs_queues['observation.images.top']),
+                                dim=1).permute(0, 2, 1, 3, 4),
+                'observation.state':
+                    torch.stack(list(cond_obs_queues['observation.state']), dim=1),
+                'action':
+                    torch.stack(list(cond_obs_queues['action']), dim=1),
+            }
+            obs_loc = {k: v.to(device) for k, v in obs_loc.items()}
+
+            # Policy pass
+            dummy_rec = defaultdict(list)
+            pv0, pa, _ = profiled_synthesis(
+                model, prompt_text, obs_loc, noise_shape,
+                ddim_steps=args.ddim_steps, ddim_eta=args.ddim_eta,
+                unconditional_guidance_scale=args.unconditional_guidance_scale,
+                fs=model_input_fs, text_input=True,
+                timestep_spacing=args.timestep_spacing,
+                guidance_rescale=args.guidance_rescale,
+                sim_mode=False,
+                decode_video=not args.fast_policy_no_decode,
+                records=dummy_rec, prefix="policy")
+
+            for idx in range(len(pa[0])):
+                oa = {'action': pa[0][idx:idx + 1]}
+                oa['action'][:, ori_action_dim:] = 0.0
+                populate_queues(cond_obs_queues, oa)
+
+            # Re-stack for world model
+            obs_loc2 = {
+                'observation.images.top':
+                    torch.stack(list(cond_obs_queues['observation.images.top']),
+                                dim=1).permute(0, 2, 1, 3, 4),
+                'observation.state':
+                    torch.stack(list(cond_obs_queues['observation.state']), dim=1),
+                'action':
+                    torch.stack(list(cond_obs_queues['action']), dim=1),
+            }
+            obs_loc2 = {k: v.to(device) for k, v in obs_loc2.items()}
+
+            # World model pass
+            pv1, _, ps = profiled_synthesis(
+                model, "", obs_loc2, noise_shape,
+                ddim_steps=args.ddim_steps, ddim_eta=args.ddim_eta,
+                unconditional_guidance_scale=args.unconditional_guidance_scale,
+                fs=model_input_fs, text_input=False,
+                timestep_spacing=args.timestep_spacing,
+                guidance_rescale=args.guidance_rescale,
+                sim_mode=True, decode_video=True,
+                records=dummy_rec, prefix="wm")
+
+            # Update obs queue
+            for idx in range(args.exe_steps):
+                ou = {
+                    'observation.images.top':
+                        pv1[0][:, idx:idx + 1].permute(1, 0, 2, 3),
+                    'observation.state': ps[0][idx:idx + 1],
+                    'action': torch.zeros_like(pa[0][-1:]),
+                }
+                ou['observation.state'][:, ori_state_dim:] = 0.0
+                populate_queues(cond_obs_queues, ou)
+
+            # Save results (captures CPU stall in trace)
+            if pv0 is not None:
+                save_results(pv0.cpu(),
+                             os.path.join(tmp_dir, f"trace_dm_{itr_idx}.mp4"),
+                             fps=args.save_fps)
+            save_results(pv1.cpu(),
+                         os.path.join(tmp_dir, f"trace_wm_{itr_idx}.mp4"),
+                         fps=args.save_fps)
+
+            prof.step()
+
+    print(f">>> Trace saved to {trace_dir}")
+    print(" View with: tensorboard --logdir", trace_dir)
+    print(" Or open the .json file in chrome://tracing")
+
+
+# ──────────────────────────────────────────────────────────────────────
+# Argument parser
+# ──────────────────────────────────────────────────────────────────────
+def get_parser():
+    """Build the argparse parser covering --compare mode and the profiling options."""
+    p = argparse.ArgumentParser(description="Profile full iteration loop")
+
+    # Compare mode (no model needed)
+    p.add_argument("--compare", nargs=2, metavar=("A_SUMMARY", "B_SUMMARY"),
+                   help="Compare two summary CSVs and exit")
+
+    # Model / data
+    p.add_argument("--ckpt_path", type=str, default=None)
+    p.add_argument("--config", type=str, default=None)
+    p.add_argument("--prompt_dir", type=str, default=None)
+    p.add_argument("--dataset", type=str, default=None)
+    p.add_argument("--savedir", type=str, default="profile_output")
+
+    # Inference params (match world_model_interaction.py)
+    p.add_argument("--ddim_steps", type=int, default=50)
+    p.add_argument("--ddim_eta", type=float, default=1.0)
+    p.add_argument("--bs", type=int, default=1)
+    p.add_argument("--height", type=int, default=320)
+    p.add_argument("--width", type=int, default=512)
+    p.add_argument("--frame_stride", type=int, default=4)
+    p.add_argument("--unconditional_guidance_scale", type=float, default=1.0)
+    p.add_argument("--video_length", type=int, default=16)
+    p.add_argument("--timestep_spacing", type=str, default="uniform_trailing")
+    p.add_argument("--guidance_rescale", type=float, default=0.7)
+    p.add_argument("--exe_steps", type=int, default=16)
+    p.add_argument("--n_iter", type=int, default=5)
+    p.add_argument("--save_fps", type=int, default=8)
+    p.add_argument("--seed", type=int, default=123)
+    p.add_argument("--perframe_ae", action='store_true', default=False)
+    p.add_argument("--vae_dtype", type=str, choices=["fp32", "bf16"], default="bf16")
+    p.add_argument("--fast_policy_no_decode", action='store_true', default=False)
+
+    # Profiling control
+    p.add_argument("--warmup", type=int, default=1,
+                   help="Number of warmup iterations to skip in statistics")
+    p.add_argument("--csv", type=str, default=None,
+                   help="Write per-iteration timing to this CSV file")
+    p.add_argument("--trace", action='store_true', default=False,
+                   help="Enable Layer 2: GPU timeline trace")
+    p.add_argument("--trace_dir", type=str, default="./profile_traces",
+                   help="Directory for trace output")
+
+    return p
+
+
+# ──────────────────────────────────────────────────────────────────────
+# Main
+# ──────────────────────────────────────────────────────────────────────
+def main():
+    """CLI entry point: handle --compare, else load the model and run the requested profiling layers."""
+    patch_norm_bypass_autocast()
+    parser = get_parser()
+    args = parser.parse_args()
+
+    # ── Compare mode: no model needed ──
+    if args.compare:
+        compare_csvs(args.compare[0], args.compare[1])
+        return
+
+    # ── Validate required args ──
+    for required in ['ckpt_path', 'config', 'prompt_dir', 'dataset']:
+        if getattr(args, required) is None:
+            parser.error(f"--{required} is required for profiling mode")
+
+    seed_everything(args.seed)
+    os.makedirs(args.savedir, exist_ok=True)
+
+    # ── Load model ──
+    print("=" * 60)
+    print("PROFILE ITERATION — Loading model...")
+    print("=" * 60)
+    model, config = load_model(args)
+    device = next(model.parameters()).device
+
+    h, w = args.height // 8, args.width // 8
+    channels = model.model.diffusion_model.out_channels
+    noise_shape = [args.bs, channels, args.video_length, h, w]
+    print(f">>> Noise shape: {noise_shape}")
+    print(f">>> DDIM steps: {args.ddim_steps}")
+    print(f">>> fast_policy_no_decode: {args.fast_policy_no_decode}")
+
+    # ── Layer 2: GPU trace (optional) ──
+    if args.trace:
+        with torch.no_grad():
+            run_with_trace(model, args, config, noise_shape, device)
+        print()
+
+    # ── Layer 1: Iteration-level breakdown ──
+    print("=" * 60)
+    print("LAYER 1: ITERATION-LEVEL PROFILING")
+    print("=" * 60)
+    with torch.no_grad():
+        all_records = run_profiled_iterations(
+            model, args, config, noise_shape, device)
+
+    # Print report
+    print_iteration_report(all_records, warmup=args.warmup)
+
+    # ── Layer 3: CSV output for A/B comparison ──
+    if args.csv:
+        write_csv(all_records, args.csv, warmup=args.warmup)
+
+    print("Done.")
+
+
+if __name__ == '__main__':
+    main()
diff --git a/unitree_z1_dual_arm_cleanup_pencils/case1/run_profile.sh b/unitree_z1_dual_arm_cleanup_pencils/case1/run_profile.sh
new file mode 100644
index 0000000..18bef0a
--- /dev/null
+++ b/unitree_z1_dual_arm_cleanup_pencils/case1/run_profile.sh
@@ -0,0 +1,5 @@
+#!/bin/bash
+res_dir="unitree_z1_dual_arm_cleanup_pencils/case1"
+dataset="unitree_z1_dual_arm_cleanup_pencils"
+
+mkdir -p "${res_dir}/profile_output" && TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL=1 CUDA_VISIBLE_DEVICES=0 python3 scripts/evaluation/profile_iteration.py --seed 123 --ckpt_path ckpts/unifolm_wma_dual_mix_bf16.ckpt --config configs/inference/world_model_interaction.yaml --savedir "${res_dir}/profile_output" --prompt_dir "${res_dir}/world_model_interaction_prompts" --dataset "${dataset}" --bs 1 --height 320 --width 512 --unconditional_guidance_scale 1.0 --ddim_steps 50 --ddim_eta 1.0 --video_length 16 --frame_stride 4 --exe_steps 16 --n_iter 5 --warmup 1 --timestep_spacing uniform_trailing --guidance_rescale 0.7 --perframe_ae --vae_dtype bf16 --fast_policy_no_decode --csv "${res_dir}/profile_output/baseline.csv" 2>&1 | tee "${res_dir}/profile_output/profile.log"  # mkdir first: tee opens the log file before python creates the output dir