""" Profile the full inference pipeline of the world model, covering all 7 stages: 1. Image Embedding 2. VAE Encode 3. Text Conditioning 4. State/Action Projectors 5. DDIM Loop 6. VAE Decode 7. Post-process Reports stage-level timing, UNet sub-module breakdown, memory summary, and throughput analysis. TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL=1 CUDA_VISIBLE_DEVICES=0 python scripts/evaluation/profile_pipeline.py --ckpt_path ckpts/unifolm_wma_dual_mix_bf16.ckpt --config configs/inference/world_model_interaction.yaml --ddim_steps 50 --cfg_scale 1.0 --n_runs 3 --deep Usage: TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL=1 CUDA_VISIBLE_DEVICES=0 python scripts/evaluation/profile_pipeline.py --ckpt_path ckpts/unifolm_wma_dual_mix_bf16.ckpt --config configs/inference/world_model_interaction.yaml --ddim_steps 50 --cfg_scale 1.0 --n_runs 3 """ import argparse import torch import torch.nn as nn import torch.nn.functional as F import torch.utils.checkpoint # must be loaded before unifolm_wma.utils.common from contextlib import nullcontext, contextmanager from collections import defaultdict from omegaconf import OmegaConf from einops import rearrange, repeat from unifolm_wma.utils.utils import instantiate_from_config from unifolm_wma.models.samplers.ddim import DDIMSampler from unifolm_wma.modules.attention import ( SpatialTransformer, TemporalTransformer, BasicTransformerBlock, CrossAttention, FeedForward, ) from unifolm_wma.modules.networks.wma_model import ResBlock from unifolm_wma.models.diffusion_head.conditional_unet1d import ConditionalUnet1D # --- W7900D theoretical peak --- PEAK_BF16_TFLOPS = 61.0 MEM_BW_GBS = 864.0 # --------------------------------------------------------------------------- # Utility: patch norms to bypass autocast fp32 promotion # --------------------------------------------------------------------------- def patch_norm_bypass_autocast(): """Monkey-patch GroupNorm and LayerNorm to bypass autocast's fp32 policy.""" def _group_norm_forward(self, x): with 
torch.amp.autocast('cuda', enabled=False): return F.group_norm( x, self.num_groups, self.weight.to(x.dtype) if self.weight is not None else None, self.bias.to(x.dtype) if self.bias is not None else None, self.eps) def _layer_norm_forward(self, x): with torch.amp.autocast('cuda', enabled=False): return F.layer_norm( x, self.normalized_shape, self.weight.to(x.dtype) if self.weight is not None else None, self.bias.to(x.dtype) if self.bias is not None else None, self.eps) torch.nn.GroupNorm.forward = _group_norm_forward torch.nn.LayerNorm.forward = _layer_norm_forward # --------------------------------------------------------------------------- # Utility: torch.compile hot ResBlocks # --------------------------------------------------------------------------- def apply_torch_compile(model, hot_indices=(5, 8, 9)): """Compile ResBlock._forward in the hottest output_blocks for operator fusion.""" unet = model.model.diffusion_model compiled = 0 for idx in hot_indices: block = unet.output_blocks[idx] for layer in block: if isinstance(layer, ResBlock): layer._forward = torch.compile(layer._forward, mode="default") compiled += 1 print(f" torch.compile: {compiled} ResBlocks in output_blocks{list(hot_indices)}") # --------------------------------------------------------------------------- # Model loading # --------------------------------------------------------------------------- def load_model(args): config = OmegaConf.load(args.config) config['model']['params']['wma_config']['params']['use_checkpoint'] = False model = instantiate_from_config(config.model) state_dict = torch.load(args.ckpt_path, map_location="cpu") if "state_dict" in state_dict: state_dict = state_dict["state_dict"] model.load_state_dict(state_dict, strict=True) model.eval() model.model.to(torch.bfloat16) model.diffusion_autocast_dtype = torch.bfloat16 apply_torch_compile(model) model = model.cuda() return model # --------------------------------------------------------------------------- # CudaTimer — 
# ---------------------------------------------------------------------------
class CudaTimer:
    """Context manager for GPU-precise stage timing using CUDA events.

    On exit, the elapsed GPU time in milliseconds is appended to
    ``records[name]``.  A full device synchronize is issued both before the
    start event and after the end event so the measured interval covers only
    the work launched inside the ``with`` block.
    """

    def __init__(self, name, records):
        self.name = name
        self.records = records  # mapping name -> list of elapsed ms
        self.start = torch.cuda.Event(enable_timing=True)
        self.end = torch.cuda.Event(enable_timing=True)

    def __enter__(self):
        # Drain any previously queued GPU work so it isn't attributed here.
        torch.cuda.synchronize()
        self.start.record()
        return self

    def __exit__(self, *args):
        self.end.record()
        torch.cuda.synchronize()
        elapsed = self.start.elapsed_time(self.end)  # milliseconds
        self.records[self.name].append(elapsed)


# ---------------------------------------------------------------------------
# HookProfiler — sub-module level timing inside UNet via hooks
# ---------------------------------------------------------------------------
class HookProfiler:
    """Register forward hooks on UNet sub-modules to collect per-call timing.

    A pre-hook records a start CUDA event and a post-hook records the end
    event; elapsed times are only computed later in
    ``synchronize_and_collect`` (after one device-wide sync), so the hooks
    themselves add minimal overhead to the profiled run.
    """

    # Coarse-grained targets (original)
    COARSE_CLASSES = (
        SpatialTransformer,
        TemporalTransformer,
        ResBlock,
        ConditionalUnet1D,
    )
    # Fine-grained targets for deep DDIM analysis
    FINE_CLASSES = (
        CrossAttention,
        FeedForward,
    )

    def __init__(self, unet, deep=False):
        self.unet = unet
        self.deep = deep
        self.handles = []
        # per-instance data: {instance_id: [(start_event, end_event), ...]}
        self._events = defaultdict(list)
        # tag mapping: {instance_id: (class_name, module_name)}
        self._tags = {}
        # block location: {instance_id: block_location_str}
        self._block_loc = {}

    @staticmethod
    def _get_block_location(name):
        """Derive UNet block location from module name, e.g. 'input_blocks.3.1'."""
        parts = name.split('.')
        # Order matters: the named-block prefixes are checked before the
        # substring fallbacks for the action/state sub-UNets.
        if len(parts) >= 2 and parts[0] == 'input_blocks':
            return f"input_blocks.{parts[1]}"
        elif len(parts) >= 1 and parts[0] == 'middle_block':
            return "middle_block"
        elif len(parts) >= 2 and parts[0] == 'output_blocks':
            return f"output_blocks.{parts[1]}"
        elif 'action_unet' in name:
            return "action_unet"
        elif 'state_unet' in name:
            return "state_unet"
        elif name == 'out' or name.startswith('out.'):
            return "out"
        return "other"

    def register(self):
        """Attach pre/post forward hooks to target sub-modules + unet.out."""
        target_classes = self.COARSE_CLASSES
        if self.deep:
            target_classes = target_classes + self.FINE_CLASSES
        for name, mod in self.unet.named_modules():
            if isinstance(mod, target_classes):
                tag = type(mod).__name__
                inst_id = id(mod)
                self._tags[inst_id] = (tag, name)
                self._block_loc[inst_id] = self._get_block_location(name)
                self.handles.append(
                    mod.register_forward_pre_hook(self._make_pre_hook(inst_id)))
                self.handles.append(
                    mod.register_forward_hook(self._make_post_hook(inst_id)))
        # Also hook unet.out (nn.Sequential) — it has no target class, so it
        # would otherwise be invisible to the breakdown.
        out_mod = self.unet.out
        inst_id = id(out_mod)
        self._tags[inst_id] = ("UNet.out", "out")
        self._block_loc[inst_id] = "out"
        self.handles.append(
            out_mod.register_forward_pre_hook(self._make_pre_hook(inst_id)))
        self.handles.append(
            out_mod.register_forward_hook(self._make_post_hook(inst_id)))

    def _make_pre_hook(self, inst_id):
        # Bind the events dict locally so the closure does not hold `self`.
        events = self._events

        def hook(module, input):
            start = torch.cuda.Event(enable_timing=True)
            start.record()
            # End event is filled in by the matching post-hook.
            events[inst_id].append([start, None])

        return hook

    def _make_post_hook(self, inst_id):
        events = self._events

        def hook(module, input, output):
            end = torch.cuda.Event(enable_timing=True)
            end.record()
            # Complete the most recent (start, None) pair for this instance.
            events[inst_id][-1][1] = end

        return hook

    def reset(self):
        """Clear collected events for a fresh run."""
        self._events.clear()

    def synchronize_and_collect(self):
        """Sync GPU and compute elapsed times. Returns (by_type, by_instance, by_block)."""
        # One sync here makes every recorded event queryable; elapsed_time
        # would raise on events that have not completed.
        torch.cuda.synchronize()
        by_type = defaultdict(lambda: {"total_ms": 0.0, "count": 0, "calls": []})
        by_instance = {}
        # by_block: {block_loc: {tag: {"total_ms", "count"}}}
        by_block = defaultdict(lambda: defaultdict(lambda: {"total_ms": 0.0, "count": 0}))
        for inst_id, pairs in self._events.items():
            tag, mod_name = self._tags[inst_id]
            block_loc = self._block_loc.get(inst_id, "other")
            inst_times = []
            for start_evt, end_evt in pairs:
                # Pairs whose post-hook never fired (end is None) are skipped.
                if end_evt is not None:
                    ms = start_evt.elapsed_time(end_evt)
                    inst_times.append(ms)
                    by_type[tag]["total_ms"] += ms
                    by_type[tag]["count"] += 1
                    by_type[tag]["calls"].append(ms)
                    by_block[block_loc][tag]["total_ms"] += ms
                    by_block[block_loc][tag]["count"] += 1
            by_instance[(tag, mod_name)] = inst_times
        return dict(by_type), by_instance, dict(by_block)

    def remove(self):
        """Remove all hooks."""
        for h in self.handles:
            h.remove()
        self.handles.clear()


# ---------------------------------------------------------------------------
# Build dummy inputs matching the pipeline's expected shapes
# ---------------------------------------------------------------------------
def build_dummy_inputs(model, noise_shape):
    """Create synthetic observation dict and prompts for profiling.

    Shapes mirror the real dataloader output; image size 320x512 and
    state/action dims 16 are hard-coded to match the profiled config.
    """
    device = next(model.parameters()).device
    B, C, T, H, W = noise_shape
    dtype = torch.bfloat16
    # observation.images.top: [B, C, O, H, W] (permuted to [B,O,C,H,W] inside pipeline)
    O = 2  # number of observation frames
    obs_images = torch.randn(B, 3, O, 320, 512, device=device, dtype=dtype)
    obs_state = torch.randn(B, O, 16, device=device, dtype=dtype)
    action = torch.randn(B, 16, 16, device=device, dtype=dtype)
    observation = {
        'observation.images.top': obs_images,
        'observation.state': obs_state,
        'action': action,
    }
    prompts = ["a robot arm performing a task"] * B
    return observation, prompts


# ---------------------------------------------------------------------------
# Run one full pipeline pass with per-stage timing
--------------------------------------------------------------------------- def run_pipeline(model, observation, prompts, noise_shape, ddim_steps, cfg_scale, hook_profiler): """Execute the full 7-stage pipeline, returning per-stage timing dict.""" records = defaultdict(list) device = next(model.parameters()).device B, C, T, H, W = noise_shape dtype = torch.bfloat16 fs = torch.tensor([1] * B, dtype=torch.long, device=device) # --- Stage 1: Image Embedding --- with CudaTimer("1_Image_Embedding", records): img = observation['observation.images.top'].permute(0, 2, 1, 3, 4) cond_img = rearrange(img, 'b o c h w -> (b o) c h w')[-1:].to(dtype=dtype) with torch.autocast('cuda', dtype=torch.bfloat16): cond_img_emb = model.embedder(cond_img) cond_img_emb = model.image_proj_model(cond_img_emb) # --- Stage 2: VAE Encode --- with CudaTimer("2_VAE_Encode", records): videos = img.permute(0, 2, 1, 3, 4) # [B, C, O, H, W] b_v, c_v, t_v, h_v, w_v = videos.shape vae_dtype = next(model.first_stage_model.parameters()).dtype x_vae = rearrange(videos, 'b c t h w -> (b t) c h w').to(dtype=vae_dtype) z = model.encode_first_stage(x_vae) z = rearrange(z, '(b t) c h w -> b c t h w', b=b_v, t=t_v) img_cat_cond = z[:, :, -1:, :, :] img_cat_cond = repeat(img_cat_cond, 'b c t h w -> b c (repeat t) h w', repeat=T) cond = {"c_concat": [img_cat_cond]} vae_enc_input_bytes = x_vae.nelement() * x_vae.element_size() vae_enc_output_bytes = z.nelement() * z.element_size() # --- Stage 3: Text Conditioning --- with CudaTimer("3_Text_Conditioning", records): cond_ins_emb = model.get_learned_conditioning(prompts) # --- Stage 4: State/Action Projectors --- with CudaTimer("4_Projectors", records): projector_dtype = next(model.state_projector.parameters()).dtype with torch.autocast('cuda', dtype=torch.bfloat16): cond_state_emb = model.state_projector( observation['observation.state'].to(dtype=projector_dtype)) cond_state_emb = cond_state_emb + model.agent_state_pos_emb cond_action_emb = model.action_projector( 
observation['action'].to(dtype=projector_dtype)) cond_action_emb = cond_action_emb + model.agent_action_pos_emb # Assemble cross-attention conditioning cond["c_crossattn"] = [ torch.cat([cond_state_emb, cond_action_emb, cond_ins_emb, cond_img_emb], dim=1) ] n_obs_acting = getattr(model, 'n_obs_steps_acting', 2) cond["c_crossattn_action"] = [ observation['observation.images.top'][:, :, -n_obs_acting:], observation['observation.state'][:, -n_obs_acting:], True, # sim_mode False, ] # CFG: build unconditional conditioning if needed uc = None if cfg_scale != 1.0: uc_crossattn = torch.zeros_like(cond["c_crossattn"][0]) uc = { "c_concat": cond["c_concat"], "c_crossattn": [uc_crossattn], "c_crossattn_action": cond["c_crossattn_action"], } # --- Stage 5: DDIM Loop --- ddim_sampler = DDIMSampler(model) hook_profiler.reset() with CudaTimer("5_DDIM_Loop", records): with torch.autocast('cuda', dtype=torch.bfloat16): samples, actions, states, _ = ddim_sampler.sample( S=ddim_steps, conditioning=cond, batch_size=B, shape=noise_shape[1:], verbose=False, unconditional_guidance_scale=cfg_scale, unconditional_conditioning=uc, eta=1.0, cfg_img=None, mask=None, x0=None, fs=fs, timestep_spacing='uniform', guidance_rescale=0.0, unconditional_conditioning_img_nonetext=None, ) hook_by_type, hook_by_instance, hook_by_block = hook_profiler.synchronize_and_collect() # --- Stage 6: VAE Decode --- with CudaTimer("6_VAE_Decode", records): batch_images = model.decode_first_stage(samples) vae_dec_input_bytes = samples.nelement() * samples.element_size() vae_dec_output_bytes = batch_images.nelement() * batch_images.element_size() # --- Stage 7: Post-process --- with CudaTimer("7_Post_Process", records): batch_images_cpu = batch_images.cpu() actions_cpu = actions.cpu() states_cpu = states.cpu() # Simulate video save overhead: clamp + uint8 conversion _ = (batch_images_cpu.clamp(-1, 1) * 127.5 + 127.5).to(torch.uint8) # Flatten single-element lists stage_times = {k: v[0] for k, v in records.items()} 
bandwidth_info = { "vae_enc_input_bytes": vae_enc_input_bytes, "vae_enc_output_bytes": vae_enc_output_bytes, "vae_dec_input_bytes": vae_dec_input_bytes, "vae_dec_output_bytes": vae_dec_output_bytes, } return stage_times, hook_by_type, hook_by_instance, hook_by_block, bandwidth_info # --------------------------------------------------------------------------- # Reporting # --------------------------------------------------------------------------- def print_stage_timing(all_runs_stages): """Table 1: Stage Timing — name | mean(ms) | std | percent.""" import numpy as np stage_names = list(all_runs_stages[0].keys()) means = {} stds = {} for name in stage_names: vals = [run[name] for run in all_runs_stages] means[name] = np.mean(vals) stds[name] = np.std(vals) total = sum(means.values()) print() print("=" * 72) print("TABLE 1: STAGE TIMING") print("=" * 72) print(f"{'Stage':<25} {'Mean(ms)':>10} {'Std':>10} {'%':>8}") print("-" * 72) for name in stage_names: pct = means[name] / total * 100 if total > 0 else 0 print(f"{name:<25} {means[name]:>10.1f} {stds[name]:>10.2f} {pct:>7.1f}%") print("-" * 72) print(f"{'TOTAL':<25} {total:>10.1f}") print() def print_unet_breakdown(all_runs_hooks): """Table 2: UNet Sub-Module Breakdown — type | total | count | per-call | percent.""" import numpy as np # Aggregate across runs agg = defaultdict(lambda: {"totals": [], "counts": []}) for hook_by_type in all_runs_hooks: for tag, data in hook_by_type.items(): agg[tag]["totals"].append(data["total_ms"]) agg[tag]["counts"].append(data["count"]) print("=" * 80) print("TABLE 2: UNET SUB-MODULE BREAKDOWN") print("=" * 80) print(f"{'Module Type':<25} {'Total(ms)':>10} {'Count':>7} {'Per-call':>10} {'%':>8}") print("-" * 80) grand_total = 0 rows = [] for tag, d in agg.items(): mean_total = np.mean(d["totals"]) mean_count = np.mean(d["counts"]) per_call = mean_total / mean_count if mean_count > 0 else 0 grand_total += mean_total rows.append((tag, mean_total, mean_count, per_call)) 
rows.sort(key=lambda r: r[1], reverse=True) for tag, mean_total, mean_count, per_call in rows: pct = mean_total / grand_total * 100 if grand_total > 0 else 0 print(f"{tag:<25} {mean_total:>10.1f} {int(mean_count):>7} {per_call:>10.2f} {pct:>7.1f}%") print("-" * 80) print(f"{'TOTAL (hooked)':<25} {grand_total:>10.1f}") print() def print_block_timing(all_runs_blocks): """Table 2b: Per-UNet-block timing — which blocks are hottest.""" import numpy as np # Aggregate: {block_loc: {tag: [total_ms_per_run, ...]}} agg = defaultdict(lambda: defaultdict(list)) for by_block in all_runs_blocks: for block_loc, tag_dict in by_block.items(): for tag, data in tag_dict.items(): agg[block_loc][tag].append(data["total_ms"]) # Compute per-block totals block_totals = {} for block_loc, tag_dict in agg.items(): block_totals[block_loc] = sum(np.mean(v) for v in tag_dict.values()) grand_total = sum(block_totals.values()) # Sort blocks in logical order def block_sort_key(name): if name.startswith("input_blocks."): return (0, int(name.split('.')[1])) elif name == "middle_block": return (1, 0) elif name.startswith("output_blocks."): return (2, int(name.split('.')[1])) elif name == "out": return (3, 0) elif name == "action_unet": return (4, 0) elif name == "state_unet": return (5, 0) return (9, 0) sorted_blocks = sorted(block_totals.keys(), key=block_sort_key) print("=" * 90) print("TABLE 2b: PER-UNET-BLOCK TIMING (coarse modules, per DDIM loop)") print("=" * 90) print(f"{'Block':<22} {'Total(ms)':>10} {'%':>7} Breakdown") print("-" * 90) for block_loc in sorted_blocks: total = block_totals[block_loc] pct = total / grand_total * 100 if grand_total > 0 else 0 # Build breakdown string parts = [] for tag, vals in sorted(agg[block_loc].items(), key=lambda x: np.mean(x[1]), reverse=True): parts.append(f"{tag}={np.mean(vals):.0f}") breakdown = ", ".join(parts) print(f"{block_loc:<22} {total:>10.1f} {pct:>6.1f}% {breakdown}") print("-" * 90) print(f"{'TOTAL':<22} {grand_total:>10.1f}") print() def 
print_attn_ff_breakdown(all_runs_hooks): """Table 2c: CrossAttention vs FeedForward breakdown (--deep mode).""" import numpy as np agg = defaultdict(list) for hook_by_type in all_runs_hooks: for tag, data in hook_by_type.items(): if tag in ("CrossAttention", "FeedForward"): agg[tag].append(data["total_ms"]) if not agg: return print("=" * 70) print("TABLE 2c: ATTENTION vs FEEDFORWARD (deep hooks)") print("=" * 70) print(f"{'Component':<25} {'Total(ms)':>10} {'%':>8}") print("-" * 70) grand = 0 rows = [] for tag in ("CrossAttention", "FeedForward"): if tag in agg: mean_t = np.mean(agg[tag]) grand += mean_t rows.append((tag, mean_t)) for tag, mean_t in rows: pct = mean_t / grand * 100 if grand > 0 else 0 print(f"{tag:<25} {mean_t:>10.1f} {pct:>7.1f}%") print("-" * 70) print(f"{'TOTAL (attn+ff)':<25} {grand:>10.1f}") print() def print_unet_detailed(all_runs_instances): """Print per-instance UNet sub-module detail (--detailed mode).""" import numpy as np # Use last run's data by_instance = all_runs_instances[-1] print("=" * 100) print("DETAILED: PER-INSTANCE UNET SUB-MODULE TIMING (last run)") print("=" * 100) print(f"{'Type':<22} {'Module Name':<45} {'Calls':>6} {'Total(ms)':>10} {'Mean(ms)':>10}") print("-" * 100) rows = [] for (tag, mod_name), times in by_instance.items(): if len(times) == 0: continue total = sum(times) mean = np.mean(times) rows.append((tag, mod_name, len(times), total, mean)) rows.sort(key=lambda r: r[3], reverse=True) for tag, mod_name, count, total, mean in rows: short_name = mod_name[-42:] if len(mod_name) > 42 else mod_name print(f"{tag:<22} {short_name:<45} {count:>6} {total:>10.2f} {mean:>10.3f}") print() def print_memory_summary(mem_before, mem_peak): """Table 3: Memory Summary.""" delta = mem_peak - mem_before print("=" * 50) print("TABLE 3: MEMORY SUMMARY") print("=" * 50) print(f" Initial allocated: {mem_before / 1e9:.2f} GB") print(f" Peak allocated: {mem_peak / 1e9:.2f} GB") print(f" Delta (pipeline): {delta / 1e9:.2f} GB") print() def 
print_throughput(all_runs_stages, all_bw, ddim_steps, cfg_scale): """Table 4: Throughput — total latency, per-step, per-UNet-forward, VAE bandwidth.""" import numpy as np n_runs = len(all_runs_stages) # Total latency totals = [] for run in all_runs_stages: totals.append(sum(run.values())) mean_total = np.mean(totals) # DDIM loop time ddim_times = [run["5_DDIM_Loop"] for run in all_runs_stages] mean_ddim = np.mean(ddim_times) unet_calls = ddim_steps if cfg_scale == 1.0 else ddim_steps * 2 per_step = mean_ddim / ddim_steps per_unet = mean_ddim / unet_calls # VAE bandwidth mean_enc_time = np.mean([run["2_VAE_Encode"] for run in all_runs_stages]) mean_dec_time = np.mean([run["6_VAE_Decode"] for run in all_runs_stages]) bw = all_bw[-1] # use last run's byte counts enc_bytes = bw["vae_enc_input_bytes"] + bw["vae_enc_output_bytes"] dec_bytes = bw["vae_dec_input_bytes"] + bw["vae_dec_output_bytes"] enc_bw = enc_bytes / (mean_enc_time / 1000) / 1e9 if mean_enc_time > 0 else 0 dec_bw = dec_bytes / (mean_dec_time / 1000) / 1e9 if mean_dec_time > 0 else 0 print("=" * 60) print("TABLE 4: THROUGHPUT") print("=" * 60) print(f" Total pipeline latency: {mean_total:.1f} ms") print(f" DDIM loop latency: {mean_ddim:.1f} ms") print(f" DDIM steps: {ddim_steps}") print(f" CFG scale: {cfg_scale} ({'2x UNet/step' if cfg_scale != 1.0 else '1x UNet/step'})") print(f" UNet forward calls: {unet_calls}") print(f" Per DDIM step: {per_step:.1f} ms") print(f" Per UNet forward: {per_unet:.1f} ms") print(f" VAE encode bandwidth: {enc_bw:.1f} GB/s (peak HBM: {MEM_BW_GBS} GB/s)") print(f" VAE decode bandwidth: {dec_bw:.1f} GB/s (peak HBM: {MEM_BW_GBS} GB/s)") print(f" GPU BF16 peak: {PEAK_BF16_TFLOPS} TFLOPS") print() # --------------------------------------------------------------------------- # Main # --------------------------------------------------------------------------- def main(): patch_norm_bypass_autocast() parser = argparse.ArgumentParser( description="Profile the full inference pipeline") 
parser.add_argument("--ckpt_path", type=str, required=True) parser.add_argument("--config", type=str, required=True) parser.add_argument("--ddim_steps", type=int, default=50) parser.add_argument("--cfg_scale", type=float, default=1.0) parser.add_argument("--n_runs", type=int, default=3) parser.add_argument("--warmup", type=int, default=1) parser.add_argument("--detailed", action="store_true", help="Print per-instance UNet sub-module detail") parser.add_argument("--deep", action="store_true", help="Enable deep DDIM analysis: per-block, attn vs ff") args = parser.parse_args() noise_shape = [1, 4, 16, 40, 64] # --- Load model --- print("Loading model...") model = load_model(args) observation, prompts = build_dummy_inputs(model, noise_shape) # --- Setup hook profiler --- unet = model.model.diffusion_model hook_profiler = HookProfiler(unet, deep=args.deep) hook_profiler.register() print(f"Registered hooks on {len(hook_profiler.handles)} sub-modules") # --- Warmup --- print(f"Warmup: {args.warmup} run(s)...") with torch.no_grad(): for i in range(args.warmup): run_pipeline(model, observation, prompts, noise_shape, args.ddim_steps, args.cfg_scale, hook_profiler) print(f" warmup {i+1}/{args.warmup} done") # --- Measurement runs --- print(f"Measuring: {args.n_runs} run(s)...") torch.cuda.reset_peak_memory_stats() mem_before = torch.cuda.memory_allocated() all_stages = [] all_hooks = [] all_instances = [] all_blocks = [] all_bw = [] with torch.no_grad(): for i in range(args.n_runs): stage_times, hook_by_type, hook_by_instance, hook_by_block, bw = run_pipeline( model, observation, prompts, noise_shape, args.ddim_steps, args.cfg_scale, hook_profiler) all_stages.append(stage_times) all_hooks.append(hook_by_type) all_instances.append(hook_by_instance) all_blocks.append(hook_by_block) all_bw.append(bw) total = sum(stage_times.values()) print(f" run {i+1}/{args.n_runs}: {total:.1f} ms total") mem_peak = torch.cuda.max_memory_allocated() # --- Reports --- 
print_stage_timing(all_stages) print_unet_breakdown(all_hooks) print_block_timing(all_blocks) if args.deep: print_attn_ff_breakdown(all_hooks) if args.detailed: print_unet_detailed(all_instances) print_memory_summary(mem_before, mem_peak) print_throughput(all_stages, all_bw, args.ddim_steps, args.cfg_scale) hook_profiler.remove() print("Done.") if __name__ == "__main__": main()