r"""
Profile one DDIM sampling iteration (one UNet forward pass) to capture all
matmul/attention ops, their matrix sizes, wall time, and compute utilization.

Uses torch.profiler for CUDA timing and FlopCounterMode for accurate FLOPS
counting (works on ROCm where Tensile kernels don't report FLOPS).

Usage:
    TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL=1 CUDA_VISIBLE_DEVICES=0 \
    python scripts/evaluation/profile_unet.py \
        --ckpt_path ckpts/unifolm_wma_dual_mix_bf16.ckpt \
        --config configs/inference/world_model_interaction.yaml
"""
import argparse
import torch
import torch.nn as nn
from collections import OrderedDict, defaultdict
from omegaconf import OmegaConf
from torch.utils.flop_counter import FlopCounterMode
from unifolm_wma.utils.utils import instantiate_from_config
import torch.nn.functional as F


def patch_norm_bypass_autocast():
    """Monkey-patch GroupNorm and LayerNorm to bypass autocast's fp32 policy.

    The patched forwards disable autocast and cast the affine weight/bias to
    the input's dtype, so the norm runs in the input dtype (e.g. bf16) instead
    of being upcast. This mutates ``torch.nn.GroupNorm`` / ``torch.nn.LayerNorm``
    globally for the whole process.
    """

    def _group_norm_forward(self, x):
        with torch.amp.autocast('cuda', enabled=False):
            return F.group_norm(
                x, self.num_groups,
                self.weight.to(x.dtype) if self.weight is not None else None,
                self.bias.to(x.dtype) if self.bias is not None else None,
                self.eps)

    def _layer_norm_forward(self, x):
        with torch.amp.autocast('cuda', enabled=False):
            return F.layer_norm(
                x, self.normalized_shape,
                self.weight.to(x.dtype) if self.weight is not None else None,
                self.bias.to(x.dtype) if self.bias is not None else None,
                self.eps)

    torch.nn.GroupNorm.forward = _group_norm_forward
    torch.nn.LayerNorm.forward = _layer_norm_forward


# --- W7900D theoretical peak (TFLOPS) ---
PEAK_BF16_TFLOPS = 61.0
PEAK_FP32_TFLOPS = 30.5

# Lower-case substrings that classify a profiler event as matmul/attention-like.
# 'cijk' targets ROCm Tensile / hipBLASLt GEMM kernels, which are typically
# named "Cijk_..." and contain none of the generic keywords; 'attn' / 'flash'
# target fused attention kernels. This is a name heuristic — extend it if your
# backend's kernels end up in the non-matmul table.
MATMUL_KEYWORDS = ('mm', 'gemm', 'addmm', 'bmm', 'einsum', 'dot', 'linear',
                   'cijk', 'attn', 'flash')


def apply_torch_compile(model, hot_indices=(5, 8, 9)):
    """Compile ResBlock._forward in the hottest output_blocks for operator fusion.

    Args:
        model: the loaded world model; its UNet is ``model.model.diffusion_model``.
        hot_indices: indices into ``unet.output_blocks`` whose ResBlocks get
            their ``_forward`` replaced by a ``torch.compile``-d version.
    """
    from unifolm_wma.modules.networks.wma_model import ResBlock
    unet = model.model.diffusion_model
    compiled = 0
    for idx in hot_indices:
        block = unet.output_blocks[idx]
        for layer in block:
            if isinstance(layer, ResBlock):
                layer._forward = torch.compile(layer._forward, mode="default")
                compiled += 1
    print(f" ✓ torch.compile: {compiled} ResBlocks in output_blocks{list(hot_indices)}")


def load_model(args):
    """Instantiate the model from ``args.config``, load ``args.ckpt_path``, and move it to GPU.

    Gradient checkpointing is disabled in the config, the diffusion wrapper
    (``model.model``) is cast to bf16, and selected ResBlocks are compiled
    before the model is moved to CUDA.
    """
    config = OmegaConf.load(args.config)
    config['model']['params']['wma_config']['params']['use_checkpoint'] = False
    model = instantiate_from_config(config.model)
    state_dict = torch.load(args.ckpt_path, map_location="cpu")
    # Accept both raw state dicts and checkpoints that nest it under "state_dict".
    if "state_dict" in state_dict:
        state_dict = state_dict["state_dict"]
    model.load_state_dict(state_dict, strict=True)
    model.eval()
    model.model.to(torch.bfloat16)
    apply_torch_compile(model)
    model = model.cuda()
    return model


def build_call_kwargs(model, noise_shape):
    """Build dummy inputs matching the hybrid conditioning forward signature.

    Args:
        model: the loaded model; inputs are placed on its parameters' device.
        noise_shape: ``[B, C, T, H, W]`` latent shape, e.g. ``[1, 4, 16, 40, 64]``.

    Returns:
        A kwargs dict for ``model.model(**kwargs)``. All tensors are random
        bf16 except the integer timestep / fps tensors, which have one entry
        per batch element.
    """
    device = next(model.parameters()).device
    B, C, T, H, W = noise_shape  # [1, 4, 16, 40, 64]
    dtype = torch.bfloat16

    # NOTE(review): the conditioning sizes below (16-dim action/state with 16
    # steps, 351x1024 text context, 2 observation frames at 320x512) are
    # assumed to match the checkpoint's config — confirm when changing models.
    x_action = torch.randn(B, 16, 16, device=device, dtype=dtype)
    x_state = torch.randn(B, 16, 16, device=device, dtype=dtype)
    # One timestep / fps value per batch element so B > 1 stays shape-consistent.
    timesteps = torch.full((B,), 500, device=device, dtype=torch.long)
    context = torch.randn(B, 351, 1024, device=device, dtype=dtype)
    obs_images = torch.randn(B, 3, 2, 320, 512, device=device, dtype=dtype)
    obs_state = torch.randn(B, 2, 16, device=device, dtype=dtype)
    # NOTE(review): meaning of the two trailing booleans is defined by the
    # model's action-conditioning code, not visible here.
    context_action = [obs_images, obs_state, True, False]
    fps = torch.full((B,), 1, device=device, dtype=torch.long)
    x_raw = torch.randn(B, C, T, H, W, device=device, dtype=dtype)
    c_concat = [torch.randn(B, C, T, H, W, device=device, dtype=dtype)]

    return dict(
        x=x_raw,
        x_action=x_action,
        x_state=x_state,
        t=timesteps,
        c_concat=c_concat,
        c_crossattn=[context],
        c_crossattn_action=context_action,
        s=fps,
    )


def profile_one_step(model, noise_shape):
    """Run one UNet forward pass under torch.profiler for CUDA timing.

    Two un-profiled warmup passes run first (which also triggers any
    ``torch.compile`` compilation), then exactly one pass is profiled.

    Returns:
        The finished ``torch.profiler.profile`` object.
    """
    diff_wrapper = model.model
    call_kwargs = build_call_kwargs(model, noise_shape)

    with torch.no_grad(), torch.autocast('cuda', dtype=torch.bfloat16):
        # Warmup
        for _ in range(2):
            _ = diff_wrapper(**call_kwargs)
        torch.cuda.synchronize()

        with torch.profiler.profile(
            activities=[torch.profiler.ProfilerActivity.CUDA],
            record_shapes=True,
            with_flops=True,
        ) as prof:
            _ = diff_wrapper(**call_kwargs)
            # Make sure all kernels finish inside the profiling window.
            torch.cuda.synchronize()

    return prof


def count_flops(model, noise_shape):
    """Run one UNet forward pass under FlopCounterMode for accurate FLOPS.

    Returns:
        The finished ``FlopCounterMode`` instance (``display=False``).
    """
    diff_wrapper = model.model
    call_kwargs = build_call_kwargs(model, noise_shape)

    with torch.no_grad(), torch.autocast('cuda', dtype=torch.bfloat16):
        flop_counter = FlopCounterMode(display=False)
        with flop_counter:
            _ = diff_wrapper(**call_kwargs)
        torch.cuda.synchronize()

    return flop_counter


def _pct(part, whole):
    """Return ``part`` as a percentage of ``whole``, or 0.0 when ``whole`` is not positive."""
    return part / whole * 100 if whole > 0 else 0.0


def _collect_flop_counts(flop_counter):
    """Extract (flop_by_op, flop_by_module, total_flops) from a finished FlopCounterMode.

    ``flop_counts`` is ``{module_name: {op: count}}``. Per-op totals come only
    from the "Global" entry so nothing is double-counted; per-module totals
    cover every other entry ("Global" is the grand total, not a module).
    """
    flop_by_op = {}
    flop_by_module = {}
    if hasattr(flop_counter, 'flop_counts'):
        global_ops = flop_counter.flop_counts.get("Global", {})
        for op_name, flop_count in global_ops.items():
            # "aten.mm" -> "mm"
            key = str(op_name).split('.')[-1]
            flop_by_op[key] = flop_by_op.get(key, 0) + flop_count

        for module_name, op_dict in flop_counter.flop_counts.items():
            if module_name == "Global":
                continue
            module_total = sum(op_dict.values())
            if module_total > 0:
                flop_by_module[module_name] = module_total

    return flop_by_op, flop_by_module, flop_counter.get_total_flops()


def _split_events(events):
    """Split profiler key-averages into (matmul_ops, other_ops).

    Events with no device time are dropped; an event is "matmul" when its
    lower-cased name contains any of ``MATMUL_KEYWORDS``. Both lists are
    sorted by device time, largest first.
    """
    matmul_ops = []
    other_ops = []
    for evt in events:
        if evt.device_time_total <= 0:
            continue
        name = evt.key
        lowered = name.lower()
        entry = {
            'name': name,
            'input_shapes': str(evt.input_shapes) if evt.input_shapes else '',
            'cuda_time_ms': evt.device_time_total / 1000.0,
            'count': evt.count,
            'flops': evt.flops if evt.flops else 0,
        }
        if any(k in lowered for k in MATMUL_KEYWORDS):
            matmul_ops.append(entry)
        else:
            other_ops.append(entry)

    matmul_ops.sort(key=lambda e: e['cuda_time_ms'], reverse=True)
    other_ops.sort(key=lambda e: e['cuda_time_ms'], reverse=True)
    return matmul_ops, other_ops


def _print_event_table(title, ops, name_width):
    """Print one profiler-event table (name, count, CUDA ms, first 60 chars of shapes)."""
    print("=" * 130)
    print(title)
    print("=" * 130)
    print(f"{'Op':>{name_width}} | {'Count':>5} | {'CUDA(ms)':>10} | Shapes")
    print("-" * 130)
    for op in ops:
        shapes_str = op['input_shapes'][:60]
        print(f"{op['name']:>{name_width}} | {op['count']:>5} | {op['cuda_time_ms']:>10.3f} | {shapes_str}")


def _print_summary(total_cuda_ms, total_matmul_ms, total_counted_flops):
    """Print the timing / throughput summary; safe when no device time was captured."""
    non_matmul_ms = total_cuda_ms - total_matmul_ms
    print()
    print("=" * 130)
    print("SUMMARY")
    print("=" * 130)
    print(f" Total CUDA time: {total_cuda_ms:.1f} ms")
    print(f" Matmul CUDA time: {total_matmul_ms:.1f} ms ({_pct(total_matmul_ms, total_cuda_ms):.1f}%)")
    print(f" Non-matmul CUDA time: {non_matmul_ms:.1f} ms ({_pct(non_matmul_ms, total_cuda_ms):.1f}%)")
    print(f" Total FLOPS (FlopCounter): {total_counted_flops/1e9:.2f} GFLOPS")

    # total_matmul_ms > 0 implies total_cuda_ms > 0, so both divisions are safe.
    if total_matmul_ms > 0 and total_counted_flops > 0:
        # All counted FLOPs are attributed to matmul time; if FLOP-bearing ops
        # (e.g. convolutions) land in the non-matmul table, this overestimates.
        avg_tflops = total_counted_flops / (total_matmul_ms / 1000.0) / 1e12
        avg_util = avg_tflops / PEAK_BF16_TFLOPS * 100
        overall_tflops = total_counted_flops / (total_cuda_ms / 1000.0) / 1e12
        overall_util = overall_tflops / PEAK_BF16_TFLOPS * 100
        print(f" Matmul throughput: {avg_tflops:.2f} TFLOP/s ({avg_util:.1f}% of BF16 peak)")
        print(f" Overall throughput: {overall_tflops:.2f} TFLOP/s ({overall_util:.1f}% of BF16 peak)")

    print(f" GPU peak (BF16): {PEAK_BF16_TFLOPS} TFLOPS")


def print_report(prof, flop_counter):
    """Parse profiler results and print a structured report with accurate FLOPS.

    Args:
        prof: finished ``torch.profiler.profile`` (see ``profile_one_step``).
        flop_counter: finished ``FlopCounterMode`` (see ``count_flops``).
    """
    flop_by_op, flop_by_module, total_counted_flops = _collect_flop_counts(flop_counter)
    matmul_ops, other_ops = _split_events(prof.key_averages())

    total_cuda_ms = sum(e['cuda_time_ms'] for e in matmul_ops + other_ops)
    total_matmul_ms = sum(e['cuda_time_ms'] for e in matmul_ops)

    # --- Profiler tables ---
    _print_event_table("MATMUL / LINEAR OPS (sorted by CUDA time)", matmul_ops, 35)
    print()
    _print_event_table("TOP NON-MATMUL OPS (sorted by CUDA time)", other_ops[:20], 40)

    # --- FlopCounterMode per-operator breakdown ---
    print()
    print("=" * 130)
    print("FLOPS BY ATen OPERATOR (FlopCounterMode)")
    print("=" * 130)
    print(f"{'ATen Op':>25} | {'GFLOPS':>12} | {'% of Total':>10}")
    print("-" * 55)
    for op_name, flops in sorted(flop_by_op.items(), key=lambda kv: kv[1], reverse=True):
        print(f"{op_name:>25} | {flops / 1e9:>12.2f} | {_pct(flops, total_counted_flops):>9.1f}%")

    # --- FlopCounterMode per-module breakdown (parents include their children) ---
    if flop_by_module:
        print()
        print("=" * 130)
        print("FLOPS BY MODULE (FlopCounterMode)")
        print("=" * 130)
        print(f"{'Module':>60} | {'GFLOPS':>12} | {'% of Total':>10}")
        print("-" * 90)
        ranked = sorted(flop_by_module.items(), key=lambda kv: kv[1], reverse=True)
        for mod_name, flops in ranked[:30]:
            # Keep the tail of long dotted paths; the leaf module is most informative.
            name_str = mod_name[-57:]
            print(f"{name_str:>60} | {flops / 1e9:>12.2f} | {_pct(flops, total_counted_flops):>9.1f}%")

    _print_summary(total_cuda_ms, total_matmul_ms, total_counted_flops)


def main():
    """Parse CLI args, load the model, profile one forward pass, and print the report."""
    patch_norm_bypass_autocast()

    parser = argparse.ArgumentParser()
    parser.add_argument("--ckpt_path", type=str, required=True)
    parser.add_argument("--config", type=str, required=True)
    args = parser.parse_args()

    print("Loading model...")
    model = load_model(args)

    noise_shape = [1, 4, 16, 40, 64]

    print(f"Profiling UNet forward pass with shape {noise_shape}...")
    prof = profile_one_step(model, noise_shape)

    print("Counting FLOPS with FlopCounterMode...")
    flop_counter = count_flops(model, noise_shape)

    print_report(prof, flop_counter)


if __name__ == '__main__':
    main()