Files
unifolm-world-model-action/scripts/evaluation/profile_pipeline.py
olivame a2cd34dd51 1. einsum('b i d, b j d -> b i j') → torch.bmm(q, k.transpose(-1,-2)) — 直接映射 rocBLAS batched GEMM
2. baddbmm 把 scale 融合进 GEMM,少一次 kernel launch
3. 第二个 einsum 同理换 torch.bmm
每一轮加速1到两秒
2026-02-08 18:54:48 +00:00

734 lines
28 KiB
Python

"""
Profile the full inference pipeline of the world model, covering all 7 stages:
1. Image Embedding
2. VAE Encode
3. Text Conditioning
4. State/Action Projectors
5. DDIM Loop
6. VAE Decode
7. Post-process
Reports stage-level timing, UNet sub-module breakdown, memory summary,
and throughput analysis.
Usage (with --deep: additionally hooks CrossAttention/FeedForward and prints the attention-vs-feedforward table):
    TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL=1 CUDA_VISIBLE_DEVICES=0 python scripts/evaluation/profile_pipeline.py --ckpt_path ckpts/unifolm_wma_dual_mix_bf16.ckpt --config configs/inference/world_model_interaction.yaml --ddim_steps 50 --cfg_scale 1.0 --n_runs 3 --deep

Usage (standard):
    TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL=1 CUDA_VISIBLE_DEVICES=0 python scripts/evaluation/profile_pipeline.py --ckpt_path ckpts/unifolm_wma_dual_mix_bf16.ckpt --config configs/inference/world_model_interaction.yaml --ddim_steps 50 --cfg_scale 1.0 --n_runs 3
"""
import argparse
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint # must be loaded before unifolm_wma.utils.common
from contextlib import nullcontext, contextmanager
from collections import defaultdict
from omegaconf import OmegaConf
from einops import rearrange, repeat
from unifolm_wma.utils.utils import instantiate_from_config
from unifolm_wma.models.samplers.ddim import DDIMSampler
from unifolm_wma.modules.attention import (
SpatialTransformer, TemporalTransformer,
BasicTransformerBlock, CrossAttention, FeedForward,
)
from unifolm_wma.modules.networks.wma_model import ResBlock
from unifolm_wma.models.diffusion_head.conditional_unet1d import ConditionalUnet1D
# --- W7900D theoretical peak ---
# Reference figures: in the visible code they are only printed next to the
# measured numbers in the throughput table, never used in a computation.
# BF16 compute peak, printed with the unit "TFLOPS".
PEAK_BF16_TFLOPS = 61.0
# Memory bandwidth peak, printed as "peak HBM" with the unit "GB/s".
MEM_BW_GBS = 864.0
# ---------------------------------------------------------------------------
# Utility: patch norms to bypass autocast fp32 promotion
# ---------------------------------------------------------------------------
def patch_norm_bypass_autocast():
    """Install autocast-bypassing forwards on GroupNorm and LayerNorm.

    The replacement forwards call the functional norm with CUDA autocast
    disabled and with the affine parameters cast to the dtype of the input.
    The patch is applied to the ``torch.nn`` classes themselves, so every
    instance in the process is affected.
    """
    def _match_input_dtype(param, ref):
        # weight/bias may be None (norm layers without affine parameters).
        return None if param is None else param.to(ref.dtype)

    def _group_norm_forward(self, x):
        with torch.amp.autocast('cuda', enabled=False):
            weight = _match_input_dtype(self.weight, x)
            bias = _match_input_dtype(self.bias, x)
            return F.group_norm(x, self.num_groups, weight, bias, self.eps)

    def _layer_norm_forward(self, x):
        with torch.amp.autocast('cuda', enabled=False):
            weight = _match_input_dtype(self.weight, x)
            bias = _match_input_dtype(self.bias, x)
            return F.layer_norm(x, self.normalized_shape, weight, bias,
                                self.eps)

    torch.nn.GroupNorm.forward = _group_norm_forward
    torch.nn.LayerNorm.forward = _layer_norm_forward
# ---------------------------------------------------------------------------
# Utility: torch.compile hot ResBlocks
# ---------------------------------------------------------------------------
def apply_torch_compile(model, hot_indices=(5, 8, 9)):
    """Wrap ``ResBlock._forward`` with ``torch.compile`` in selected output blocks.

    Args:
        model: object exposing ``model.model.diffusion_model.output_blocks``.
        hot_indices: indices into ``output_blocks`` whose ResBlock layers get
            compiled (the intent, per the original note, is operator fusion
            in the hottest blocks).

    Prints how many ResBlocks were compiled; returns nothing.
    """
    unet = model.model.diffusion_model
    n_compiled = 0
    for block_idx in hot_indices:
        res_layers = [layer for layer in unet.output_blocks[block_idx]
                      if isinstance(layer, ResBlock)]
        for res in res_layers:
            # Only the inner _forward is replaced; the module itself is untouched.
            res._forward = torch.compile(res._forward, mode="default")
        n_compiled += len(res_layers)
    print(f" torch.compile: {n_compiled} ResBlocks in output_blocks{list(hot_indices)}")
# ---------------------------------------------------------------------------
# Model loading
# ---------------------------------------------------------------------------
def load_model(args):
    """Build the model from ``args.config`` and load weights from ``args.ckpt_path``.

    Steps, in order: disable ``use_checkpoint`` in the wma config, instantiate
    the model, load the state dict strictly, switch to eval mode, cast
    ``model.model`` to bfloat16, compile the hot ResBlocks, move to CUDA.

    Args:
        args: namespace with ``config`` (config file path) and ``ckpt_path``.

    Returns:
        The model after ``.cuda()``.
    """
    config = OmegaConf.load(args.config)
    wma_params = config['model']['params']['wma_config']['params']
    wma_params['use_checkpoint'] = False
    model = instantiate_from_config(config.model)

    # NOTE(review): torch.load may unpickle arbitrary objects depending on the
    # torch version's weights_only default -- only load trusted checkpoints.
    checkpoint = torch.load(args.ckpt_path, map_location="cpu")
    # Some checkpoints wrap the weights under a "state_dict" key.
    weights = checkpoint["state_dict"] if "state_dict" in checkpoint else checkpoint
    model.load_state_dict(weights, strict=True)

    model.eval()
    model.model.to(torch.bfloat16)
    model.diffusion_autocast_dtype = torch.bfloat16
    apply_torch_compile(model)
    return model.cuda()
# ---------------------------------------------------------------------------
# CudaTimer — precise GPU timing via CUDA events
# ---------------------------------------------------------------------------
class CudaTimer:
    """Context manager that times a stage on the GPU with CUDA events.

    On exit the elapsed time reported by ``Event.elapsed_time`` is appended to
    ``records[name]``, so ``records`` must map names to lists (the callers in
    this file pass a ``defaultdict(list)``).
    """

    def __init__(self, name, records):
        self.name = name
        self.records = records
        self.start = torch.cuda.Event(enable_timing=True)
        self.end = torch.cuda.Event(enable_timing=True)

    def __enter__(self):
        # Synchronize first so work queued before the stage is not counted.
        torch.cuda.synchronize()
        self.start.record()
        return self

    def __exit__(self, *args):
        self.end.record()
        # The end event must have completed before elapsed_time can be read.
        torch.cuda.synchronize()
        self.records[self.name].append(self.start.elapsed_time(self.end))
# ---------------------------------------------------------------------------
# HookProfiler — sub-module level timing inside UNet via hooks
# ---------------------------------------------------------------------------
class HookProfiler:
    """Time UNet sub-modules by bracketing their forward calls with CUDA events.

    A forward pre-hook records a start event and a forward hook records the
    matching end event; elapsed times are only computed in
    ``synchronize_and_collect`` after a device synchronize.
    """

    # Coarse-grained targets (always hooked)
    COARSE_CLASSES = (
        SpatialTransformer,
        TemporalTransformer,
        ResBlock,
        ConditionalUnet1D,
    )
    # Fine-grained targets, hooked in addition when deep=True
    FINE_CLASSES = (
        CrossAttention,
        FeedForward,
    )

    def __init__(self, unet, deep=False):
        self.unet = unet
        self.deep = deep
        self.handles = []
        # {instance_id: [[start_event, end_event_or_None], ...]}
        self._events = defaultdict(list)
        # {instance_id: (class_name, module_name)}
        self._tags = {}
        # {instance_id: block_location_str}
        self._block_loc = {}

    @staticmethod
    def _get_block_location(name):
        """Map a module name such as 'input_blocks.3.1' to its block label."""
        parts = name.split('.')
        head = parts[0]
        has_index = len(parts) >= 2
        if has_index and head == 'input_blocks':
            return f"input_blocks.{parts[1]}"
        if head == 'middle_block':
            return "middle_block"
        if has_index and head == 'output_blocks':
            return f"output_blocks.{parts[1]}"
        if 'action_unet' in name:
            return "action_unet"
        if 'state_unet' in name:
            return "state_unet"
        if name == 'out' or name.startswith('out.'):
            return "out"
        return "other"

    def _attach(self, mod, tag, name, block_loc):
        """Record the metadata for *mod* and hook its forward (pre, then post)."""
        inst_id = id(mod)
        self._tags[inst_id] = (tag, name)
        self._block_loc[inst_id] = block_loc
        self.handles.append(
            mod.register_forward_pre_hook(self._make_pre_hook(inst_id)))
        self.handles.append(
            mod.register_forward_hook(self._make_post_hook(inst_id)))

    def register(self):
        """Attach pre/post forward hooks to target sub-modules + unet.out."""
        if self.deep:
            wanted = self.COARSE_CLASSES + self.FINE_CLASSES
        else:
            wanted = self.COARSE_CLASSES
        for name, mod in self.unet.named_modules():
            if not isinstance(mod, wanted):
                continue
            self._attach(mod, type(mod).__name__, name,
                         self._get_block_location(name))
        # unet.out is hooked too, under its own fixed tag and location.
        self._attach(self.unet.out, "UNet.out", "out", "out")

    def _make_pre_hook(self, inst_id):
        events = self._events

        def hook(module, input):
            started = torch.cuda.Event(enable_timing=True)
            started.record()
            # The end slot is filled in by the matching post hook.
            events[inst_id].append([started, None])
        return hook

    def _make_post_hook(self, inst_id):
        events = self._events

        def hook(module, input, output):
            finished = torch.cuda.Event(enable_timing=True)
            finished.record()
            events[inst_id][-1][1] = finished
        return hook

    def reset(self):
        """Drop all recorded events (hooks stay registered)."""
        self._events.clear()

    def synchronize_and_collect(self):
        """Sync the GPU and turn event pairs into elapsed times.

        Returns:
            (by_type, by_instance, by_block) where by_type maps a class tag to
            {"total_ms", "count", "calls"}, by_instance maps (tag, module_name)
            to a list of per-call times, and by_block maps a block location to
            {tag: {"total_ms", "count"}}.
        """
        torch.cuda.synchronize()
        by_type = defaultdict(lambda: {"total_ms": 0.0, "count": 0, "calls": []})
        by_instance = {}
        by_block = defaultdict(lambda: defaultdict(lambda: {"total_ms": 0.0, "count": 0}))
        for inst_id, pairs in self._events.items():
            tag, mod_name = self._tags[inst_id]
            block_loc = self._block_loc.get(inst_id, "other")
            inst_times = []
            for start_evt, end_evt in pairs:
                if end_evt is None:
                    # Pre hook fired without a matching post hook; skip it.
                    continue
                ms = start_evt.elapsed_time(end_evt)
                inst_times.append(ms)
                type_entry = by_type[tag]
                type_entry["total_ms"] += ms
                type_entry["count"] += 1
                type_entry["calls"].append(ms)
                block_entry = by_block[block_loc][tag]
                block_entry["total_ms"] += ms
                block_entry["count"] += 1
            by_instance[(tag, mod_name)] = inst_times
        return dict(by_type), by_instance, dict(by_block)

    def remove(self):
        """Detach every registered hook and forget the handles."""
        for handle in self.handles:
            handle.remove()
        self.handles.clear()
# ---------------------------------------------------------------------------
# Build dummy inputs matching the pipeline's expected shapes
# ---------------------------------------------------------------------------
def build_dummy_inputs(model, noise_shape):
    """Create a random observation dict and prompt list for profiling.

    Args:
        model: module whose first parameter decides the tensor device.
        noise_shape: 5-element sequence; only its first entry (the batch
            size) is used here.

    Returns:
        (observation, prompts). ``observation`` holds bfloat16 tensors under
        'observation.images.top', 'observation.state' and 'action';
        ``prompts`` repeats one fixed prompt per batch element.
    """
    device = next(model.parameters()).device
    batch, _, _, _, _ = noise_shape
    n_obs = 2

    def _randn(*shape):
        return torch.randn(*shape, device=device, dtype=torch.bfloat16)

    observation = {
        # [B, C, O, H, W]; run_pipeline permutes this to [B, O, C, H, W].
        'observation.images.top': _randn(batch, 3, n_obs, 320, 512),
        'observation.state': _randn(batch, n_obs, 16),
        'action': _randn(batch, 16, 16),
    }
    return observation, ["a robot arm performing a task"] * batch
# ---------------------------------------------------------------------------
# Run one full pipeline pass with per-stage timing
# ---------------------------------------------------------------------------
def run_pipeline(model, observation, prompts, noise_shape, ddim_steps,
                 cfg_scale, hook_profiler):
    """Execute the full 7-stage pipeline once and time every stage.

    Args:
        model: the loaded world model.
        observation: dict with 'observation.images.top' ([B, C, O, H, W]),
            'observation.state' and 'action' tensors.
        prompts: list of text prompts passed to get_learned_conditioning.
        noise_shape: [B, C, T, H, W] of the latent to sample; only B and T
            are read directly, the rest is forwarded to the sampler.
        ddim_steps: number of DDIM steps (``S`` of the sampler).
        cfg_scale: guidance scale; an unconditional branch is only built
            when it differs from 1.0.
        hook_profiler: HookProfiler already registered on the UNet; it is
            reset right before the DDIM loop and collected right after.

    Returns:
        (stage_times, hook_by_type, hook_by_instance, hook_by_block,
        bandwidth_info) where stage_times maps stage name to the value
        reported by CudaTimer and bandwidth_info holds VAE input/output
        byte counts.
    """
    records = defaultdict(list)
    device = next(model.parameters()).device
    B, _, T, _, _ = noise_shape
    dtype = torch.bfloat16
    fs = torch.tensor([1] * B, dtype=torch.long, device=device)

    # --- Stage 1: Image Embedding ---
    with CudaTimer("1_Image_Embedding", records):
        img = observation['observation.images.top'].permute(0, 2, 1, 3, 4)
        # Last observation frame of EVERY batch element. Flattening (b o) and
        # slicing [-1:] kept only the last element's frame, which breaks the
        # batch-wise torch.cat below for B > 1; for B == 1 both are identical.
        cond_img = img[:, -1].contiguous().to(dtype=dtype)
        with torch.autocast('cuda', dtype=torch.bfloat16):
            cond_img_emb = model.embedder(cond_img)
            cond_img_emb = model.image_proj_model(cond_img_emb)

    # --- Stage 2: VAE Encode ---
    with CudaTimer("2_VAE_Encode", records):
        videos = img.permute(0, 2, 1, 3, 4)  # [B, C, O, H, W]
        b_v, _, t_v, _, _ = videos.shape
        vae_dtype = next(model.first_stage_model.parameters()).dtype
        x_vae = rearrange(videos, 'b c t h w -> (b t) c h w').to(dtype=vae_dtype)
        z = model.encode_first_stage(x_vae)
        z = rearrange(z, '(b t) c h w -> b c t h w', b=b_v, t=t_v)
        # Last encoded frame, repeated along time to T frames.
        img_cat_cond = z[:, :, -1:, :, :]
        img_cat_cond = repeat(img_cat_cond,
                              'b c t h w -> b c (repeat t) h w', repeat=T)
        cond = {"c_concat": [img_cat_cond]}

    vae_enc_input_bytes = x_vae.nelement() * x_vae.element_size()
    vae_enc_output_bytes = z.nelement() * z.element_size()

    # --- Stage 3: Text Conditioning ---
    with CudaTimer("3_Text_Conditioning", records):
        cond_ins_emb = model.get_learned_conditioning(prompts)

    # --- Stage 4: State/Action Projectors ---
    with CudaTimer("4_Projectors", records):
        projector_dtype = next(model.state_projector.parameters()).dtype
        with torch.autocast('cuda', dtype=torch.bfloat16):
            cond_state_emb = model.state_projector(
                observation['observation.state'].to(dtype=projector_dtype))
            cond_state_emb = cond_state_emb + model.agent_state_pos_emb
            cond_action_emb = model.action_projector(
                observation['action'].to(dtype=projector_dtype))
            cond_action_emb = cond_action_emb + model.agent_action_pos_emb

    # Assemble cross-attention conditioning
    cond["c_crossattn"] = [
        torch.cat([cond_state_emb, cond_action_emb, cond_ins_emb, cond_img_emb],
                  dim=1)
    ]
    n_obs_acting = getattr(model, 'n_obs_steps_acting', 2)
    cond["c_crossattn_action"] = [
        observation['observation.images.top'][:, :, -n_obs_acting:],
        observation['observation.state'][:, -n_obs_acting:],
        True,  # sim_mode
        False,
    ]

    # CFG: build unconditional conditioning if needed
    uc = None
    if cfg_scale != 1.0:
        uc_crossattn = torch.zeros_like(cond["c_crossattn"][0])
        uc = {
            "c_concat": cond["c_concat"],
            "c_crossattn": [uc_crossattn],
            "c_crossattn_action": cond["c_crossattn_action"],
        }

    # --- Stage 5: DDIM Loop ---
    ddim_sampler = DDIMSampler(model)
    # Start from a clean slate so the hooks only see this run's DDIM loop.
    hook_profiler.reset()
    with CudaTimer("5_DDIM_Loop", records):
        with torch.autocast('cuda', dtype=torch.bfloat16):
            samples, actions, states, _ = ddim_sampler.sample(
                S=ddim_steps,
                conditioning=cond,
                batch_size=B,
                shape=noise_shape[1:],
                verbose=False,
                unconditional_guidance_scale=cfg_scale,
                unconditional_conditioning=uc,
                eta=1.0,
                cfg_img=None,
                mask=None,
                x0=None,
                fs=fs,
                timestep_spacing='uniform',
                guidance_rescale=0.0,
                unconditional_conditioning_img_nonetext=None,
            )
    hook_by_type, hook_by_instance, hook_by_block = hook_profiler.synchronize_and_collect()

    # --- Stage 6: VAE Decode ---
    with CudaTimer("6_VAE_Decode", records):
        batch_images = model.decode_first_stage(samples)

    vae_dec_input_bytes = samples.nelement() * samples.element_size()
    vae_dec_output_bytes = batch_images.nelement() * batch_images.element_size()

    # --- Stage 7: Post-process ---
    with CudaTimer("7_Post_Process", records):
        batch_images_cpu = batch_images.cpu()
        actions_cpu = actions.cpu()
        states_cpu = states.cpu()
        # Simulate video save overhead: clamp + uint8 conversion
        _ = (batch_images_cpu.clamp(-1, 1) * 127.5 + 127.5).to(torch.uint8)

    # Each stage ran exactly once, so flatten the single-element lists.
    stage_times = {k: v[0] for k, v in records.items()}
    bandwidth_info = {
        "vae_enc_input_bytes": vae_enc_input_bytes,
        "vae_enc_output_bytes": vae_enc_output_bytes,
        "vae_dec_input_bytes": vae_dec_input_bytes,
        "vae_dec_output_bytes": vae_dec_output_bytes,
    }
    return stage_times, hook_by_type, hook_by_instance, hook_by_block, bandwidth_info
# ---------------------------------------------------------------------------
# Reporting
# ---------------------------------------------------------------------------
def print_stage_timing(all_runs_stages):
    """Table 1: Stage Timing — name | mean(ms) | std | percent.

    Args:
        all_runs_stages: list with one {stage_name: time} dict per run. Stage
            names and their order are taken from the first run. An empty
            list prints an empty table with a zero total instead of raising.
    """
    import numpy as np

    # Guard the empty case: indexing [0] used to raise IndexError.
    stage_names = list(all_runs_stages[0].keys()) if all_runs_stages else []
    means = {}
    stds = {}
    for name in stage_names:
        vals = [run[name] for run in all_runs_stages]
        means[name] = np.mean(vals)
        stds[name] = np.std(vals)  # population std (numpy default ddof=0)
    total = sum(means.values())

    print()
    print("=" * 72)
    print("TABLE 1: STAGE TIMING")
    print("=" * 72)
    print(f"{'Stage':<25} {'Mean(ms)':>10} {'Std':>10} {'%':>8}")
    print("-" * 72)
    for name in stage_names:
        pct = means[name] / total * 100 if total > 0 else 0
        print(f"{name:<25} {means[name]:>10.1f} {stds[name]:>10.2f} {pct:>7.1f}%")
    print("-" * 72)
    print(f"{'TOTAL':<25} {total:>10.1f}")
    print()
def print_unet_breakdown(all_runs_hooks):
    """Table 2: UNet Sub-Module Breakdown — type | total | count | per-call | percent.

    Args:
        all_runs_hooks: list with one {tag: {"total_ms", "count", ...}} dict
            per run. Totals and counts are averaged over the runs in which
            the tag appears; rows are sorted by mean total, largest first.
    """
    import numpy as np

    # Collect the per-run totals and counts for every tag.
    totals_by_tag = defaultdict(list)
    counts_by_tag = defaultdict(list)
    for run in all_runs_hooks:
        for tag, data in run.items():
            totals_by_tag[tag].append(data["total_ms"])
            counts_by_tag[tag].append(data["count"])

    print("=" * 80)
    print("TABLE 2: UNET SUB-MODULE BREAKDOWN")
    print("=" * 80)
    print(f"{'Module Type':<25} {'Total(ms)':>10} {'Count':>7} {'Per-call':>10} {'%':>8}")
    print("-" * 80)

    rows = []
    for tag, run_totals in totals_by_tag.items():
        mean_total = np.mean(run_totals)
        mean_count = np.mean(counts_by_tag[tag])
        per_call = mean_total / mean_count if mean_count > 0 else 0
        rows.append((tag, mean_total, mean_count, per_call))
    # Sum before sorting so the accumulation order matches insertion order.
    grand_total = sum(row[1] for row in rows)

    for tag, mean_total, mean_count, per_call in sorted(
            rows, key=lambda row: row[1], reverse=True):
        pct = mean_total / grand_total * 100 if grand_total > 0 else 0
        print(f"{tag:<25} {mean_total:>10.1f} {int(mean_count):>7} {per_call:>10.2f} {pct:>7.1f}%")
    print("-" * 80)
    print(f"{'TOTAL (hooked)':<25} {grand_total:>10.1f}")
    print()
def print_block_timing(all_runs_blocks,
                       nested_tags=("CrossAttention", "FeedForward")):
    """Table 2b: Per-UNet-block timing — which blocks are hottest.

    Args:
        all_runs_blocks: list with one {block_loc: {tag: {"total_ms", ...}}}
            dict per run; per-tag totals are averaged over runs.
        nested_tags: tags of fine-grained modules that are only hooked in
            deep mode. They are presumably nested inside the coarse
            transformer modules (TODO confirm against the model definition),
            so adding them to a block total would count that time twice.
            They are therefore excluded from totals and percentages and are
            listed separately in the breakdown column. Pass ``()`` to sum
            every tag.

    Without deep hooks the output is unchanged.
    """
    import numpy as np

    # Aggregate: {block_loc: {tag: [total_ms_per_run, ...]}}
    agg = defaultdict(lambda: defaultdict(list))
    for by_block in all_runs_blocks:
        for block_loc, tag_dict in by_block.items():
            for tag, data in tag_dict.items():
                agg[block_loc][tag].append(data["total_ms"])

    nested = set(nested_tags)
    # Per-block totals over coarse tags only, matching the table title.
    block_totals = {
        block_loc: sum(np.mean(vals) for tag, vals in tag_dict.items()
                       if tag not in nested)
        for block_loc, tag_dict in agg.items()
    }
    grand_total = sum(block_totals.values())

    # Sort blocks in logical order
    def block_sort_key(name):
        if name.startswith("input_blocks."):
            return (0, int(name.split('.')[1]))
        elif name == "middle_block":
            return (1, 0)
        elif name.startswith("output_blocks."):
            return (2, int(name.split('.')[1]))
        elif name == "out":
            return (3, 0)
        elif name == "action_unet":
            return (4, 0)
        elif name == "state_unet":
            return (5, 0)
        return (9, 0)

    sorted_blocks = sorted(block_totals.keys(), key=block_sort_key)

    print("=" * 90)
    print("TABLE 2b: PER-UNET-BLOCK TIMING (coarse modules, per DDIM loop)")
    print("=" * 90)
    print(f"{'Block':<22} {'Total(ms)':>10} {'%':>7} Breakdown")
    print("-" * 90)
    for block_loc in sorted_blocks:
        total = block_totals[block_loc]
        pct = total / grand_total * 100 if grand_total > 0 else 0
        # Build breakdown string, hottest tag first.
        ranked = sorted(agg[block_loc].items(),
                        key=lambda item: np.mean(item[1]), reverse=True)
        coarse_parts = [f"{tag}={np.mean(vals):.0f}"
                        for tag, vals in ranked if tag not in nested]
        nested_parts = [f"{tag}={np.mean(vals):.0f}"
                        for tag, vals in ranked if tag in nested]
        pieces = [", ".join(coarse_parts)]
        if nested_parts:
            pieces.append(f"[nested: {', '.join(nested_parts)}]")
        breakdown = " ".join(piece for piece in pieces if piece)
        print(f"{block_loc:<22} {total:>10.1f} {pct:>6.1f}% {breakdown}")
    print("-" * 90)
    print(f"{'TOTAL':<22} {grand_total:>10.1f}")
    print()
def print_attn_ff_breakdown(all_runs_hooks):
    """Table 2c: CrossAttention vs FeedForward breakdown (--deep mode).

    Args:
        all_runs_hooks: list with one {tag: {"total_ms", ...}} dict per run.
            Only the "CrossAttention" and "FeedForward" tags are read; when
            neither appears in any run nothing is printed.
    """
    import numpy as np

    components = ("CrossAttention", "FeedForward")
    # Per-component list of run totals, in run order.
    per_run = {
        name: [run[name]["total_ms"] for run in all_runs_hooks if name in run]
        for name in components
    }
    rows = [(name, np.mean(vals)) for name, vals in per_run.items() if vals]
    if not rows:
        return

    print("=" * 70)
    print("TABLE 2c: ATTENTION vs FEEDFORWARD (deep hooks)")
    print("=" * 70)
    print(f"{'Component':<25} {'Total(ms)':>10} {'%':>8}")
    print("-" * 70)
    grand = sum(mean_t for _, mean_t in rows)
    for name, mean_t in rows:
        share = mean_t / grand * 100 if grand > 0 else 0
        print(f"{name:<25} {mean_t:>10.1f} {share:>7.1f}%")
    print("-" * 70)
    print(f"{'TOTAL (attn+ff)':<25} {grand:>10.1f}")
    print()
def print_unet_detailed(all_runs_instances):
    """Print per-instance UNet sub-module detail (--detailed mode).

    Args:
        all_runs_instances: list with one {(tag, module_name): [ms, ...]}
            dict per run. Only the last run is shown; instances with no
            recorded calls are skipped. An empty list prints the header with
            no rows instead of raising.
    """
    import numpy as np

    # Use last run's data; guard the empty case ([-1] used to raise IndexError).
    by_instance = all_runs_instances[-1] if all_runs_instances else {}

    print("=" * 100)
    print("DETAILED: PER-INSTANCE UNET SUB-MODULE TIMING (last run)")
    print("=" * 100)
    print(f"{'Type':<22} {'Module Name':<45} {'Calls':>6} {'Total(ms)':>10} {'Mean(ms)':>10}")
    print("-" * 100)
    rows = [
        (tag, mod_name, len(times), sum(times), np.mean(times))
        for (tag, mod_name), times in by_instance.items()
        if times
    ]
    rows.sort(key=lambda r: r[3], reverse=True)
    for tag, mod_name, count, total, mean in rows:
        # Keep the tail of long names: it carries the most specific part.
        short_name = mod_name[-42:] if len(mod_name) > 42 else mod_name
        print(f"{tag:<22} {short_name:<45} {count:>6} {total:>10.2f} {mean:>10.3f}")
    print()
def print_memory_summary(mem_before, mem_peak):
    """Table 3: Memory Summary.

    Args:
        mem_before: allocated bytes before the measurement runs.
        mem_peak: peak allocated bytes during the measurement runs.

    Values are printed in GB (bytes / 1e9) together with their difference.
    """
    rule = "=" * 50
    print(rule)
    print("TABLE 3: MEMORY SUMMARY")
    print(rule)
    print(f" Initial allocated: {mem_before / 1e9:.2f} GB")
    print(f" Peak allocated: {mem_peak / 1e9:.2f} GB")
    print(f" Delta (pipeline): {(mem_peak - mem_before) / 1e9:.2f} GB")
    print()
def print_throughput(all_runs_stages, all_bw, ddim_steps, cfg_scale):
    """Table 4: Throughput — total latency, per-step, per-UNet-forward, VAE bandwidth.

    Args:
        all_runs_stages: list with one {stage_name: ms} dict per run; must
            contain "2_VAE_Encode", "5_DDIM_Loop" and "6_VAE_Decode".
        all_bw: list with one byte-count dict per run; the last one is used.
        ddim_steps: number of DDIM steps per run.
        cfg_scale: guidance scale; anything other than 1.0 is counted as two
            UNet forwards per step.

    When no runs were recorded only the table header and a short notice are
    printed (previously this raised IndexError / produced nan).
    """
    import numpy as np

    print("=" * 60)
    print("TABLE 4: THROUGHPUT")
    print("=" * 60)
    if not all_runs_stages or not all_bw:
        print(" No measurement runs recorded.")
        print()
        return

    # Total latency: mean over runs of the per-run stage sum.
    mean_total = np.mean([sum(run.values()) for run in all_runs_stages])

    # DDIM loop time
    mean_ddim = np.mean([run["5_DDIM_Loop"] for run in all_runs_stages])
    unet_calls = ddim_steps if cfg_scale == 1.0 else ddim_steps * 2
    # Guard the ratios like the other tables do, instead of printing inf.
    per_step = mean_ddim / ddim_steps if ddim_steps > 0 else 0
    per_unet = mean_ddim / unet_calls if unet_calls > 0 else 0

    # VAE bandwidth: (input + output bytes) over the mean stage time.
    mean_enc_time = np.mean([run["2_VAE_Encode"] for run in all_runs_stages])
    mean_dec_time = np.mean([run["6_VAE_Decode"] for run in all_runs_stages])
    bw = all_bw[-1]  # use last run's byte counts
    enc_bytes = bw["vae_enc_input_bytes"] + bw["vae_enc_output_bytes"]
    dec_bytes = bw["vae_dec_input_bytes"] + bw["vae_dec_output_bytes"]
    enc_bw = enc_bytes / (mean_enc_time / 1000) / 1e9 if mean_enc_time > 0 else 0
    dec_bw = dec_bytes / (mean_dec_time / 1000) / 1e9 if mean_dec_time > 0 else 0

    print(f" Total pipeline latency: {mean_total:.1f} ms")
    print(f" DDIM loop latency: {mean_ddim:.1f} ms")
    print(f" DDIM steps: {ddim_steps}")
    print(f" CFG scale: {cfg_scale} ({'2x UNet/step' if cfg_scale != 1.0 else '1x UNet/step'})")
    print(f" UNet forward calls: {unet_calls}")
    print(f" Per DDIM step: {per_step:.1f} ms")
    print(f" Per UNet forward: {per_unet:.1f} ms")
    print(f" VAE encode bandwidth: {enc_bw:.1f} GB/s (peak HBM: {MEM_BW_GBS} GB/s)")
    print(f" VAE decode bandwidth: {dec_bw:.1f} GB/s (peak HBM: {MEM_BW_GBS} GB/s)")
    print(f" GPU BF16 peak: {PEAK_BF16_TFLOPS} TFLOPS")
    print()
# ---------------------------------------------------------------------------
# Main
# ---------------------------------------------------------------------------
def main():
    """Parse CLI arguments, profile the pipeline and print all report tables.

    Flow: validate arguments, patch the norm layers, load the model, register
    the UNet hooks, run the warmup passes, run the measured passes while
    tracking peak memory, then print the tables. Hooks are removed even when
    a run fails.
    """
    parser = argparse.ArgumentParser(
        description="Profile the full inference pipeline")
    parser.add_argument("--ckpt_path", type=str, required=True)
    parser.add_argument("--config", type=str, required=True)
    parser.add_argument("--ddim_steps", type=int, default=50)
    parser.add_argument("--cfg_scale", type=float, default=1.0)
    parser.add_argument("--n_runs", type=int, default=3)
    parser.add_argument("--warmup", type=int, default=1)
    parser.add_argument("--detailed", action="store_true",
                        help="Print per-instance UNet sub-module detail")
    parser.add_argument("--deep", action="store_true",
                        help="Enable deep DDIM analysis: per-block, attn vs ff")
    args = parser.parse_args()
    # Fail fast: the reports need at least one measured run, and finding that
    # out only after loading and warming up the model wastes minutes.
    if args.n_runs < 1:
        parser.error("--n_runs must be >= 1")
    if args.ddim_steps < 1:
        parser.error("--ddim_steps must be >= 1")

    patch_norm_bypass_autocast()
    noise_shape = [1, 4, 16, 40, 64]

    # --- Load model ---
    print("Loading model...")
    model = load_model(args)
    observation, prompts = build_dummy_inputs(model, noise_shape)

    # --- Setup hook profiler ---
    unet = model.model.diffusion_model
    hook_profiler = HookProfiler(unet, deep=args.deep)
    hook_profiler.register()
    # Every hooked module contributes two handles (pre + post hook).
    print(f"Registered hooks on {len(hook_profiler.handles) // 2} sub-modules")

    try:
        # --- Warmup ---
        print(f"Warmup: {args.warmup} run(s)...")
        with torch.no_grad():
            for i in range(args.warmup):
                run_pipeline(model, observation, prompts, noise_shape,
                             args.ddim_steps, args.cfg_scale, hook_profiler)
                print(f" warmup {i+1}/{args.warmup} done")

        # --- Measurement runs ---
        print(f"Measuring: {args.n_runs} run(s)...")
        # Reset after warmup so the peak reflects the measured runs only.
        torch.cuda.reset_peak_memory_stats()
        mem_before = torch.cuda.memory_allocated()
        all_stages = []
        all_hooks = []
        all_instances = []
        all_blocks = []
        all_bw = []
        with torch.no_grad():
            for i in range(args.n_runs):
                stage_times, hook_by_type, hook_by_instance, hook_by_block, bw = run_pipeline(
                    model, observation, prompts, noise_shape,
                    args.ddim_steps, args.cfg_scale, hook_profiler)
                all_stages.append(stage_times)
                all_hooks.append(hook_by_type)
                all_instances.append(hook_by_instance)
                all_blocks.append(hook_by_block)
                all_bw.append(bw)
                total = sum(stage_times.values())
                print(f" run {i+1}/{args.n_runs}: {total:.1f} ms total")
        mem_peak = torch.cuda.max_memory_allocated()

        # --- Reports ---
        print_stage_timing(all_stages)
        print_unet_breakdown(all_hooks)
        print_block_timing(all_blocks)
        if args.deep:
            print_attn_ff_breakdown(all_hooks)
        if args.detailed:
            print_unet_detailed(all_instances)
        print_memory_summary(mem_before, mem_peak)
        print_throughput(all_stages, all_bw, args.ddim_steps, args.cfg_scale)
    finally:
        # Always detach the hooks, even if a run or a report raised.
        hook_profiler.remove()
    print("Done.")
# Run the profiler only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()