# Optimization notes (translated):
#   2. Use baddbmm to fuse the scale factor into the GEMM — saves one kernel launch.
#   3. Likewise, replace the second einsum with torch.bmm. Each run gets ~1-2 s faster.
"""
Profile the full inference pipeline of the world model, covering all 7 stages:

1. Image Embedding
2. VAE Encode
3. Text Conditioning
4. State/Action Projectors
5. DDIM Loop
6. VAE Decode
7. Post-process

Reports stage-level timing, UNet sub-module breakdown, memory summary,
and throughput analysis.

Usage:

    TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL=1 CUDA_VISIBLE_DEVICES=0 python scripts/evaluation/profile_pipeline.py --ckpt_path ckpts/unifolm_wma_dual_mix_bf16.ckpt --config configs/inference/world_model_interaction.yaml --ddim_steps 50 --cfg_scale 1.0 --n_runs 3

With deep DDIM analysis (per-block and attention-vs-feedforward breakdown):

    TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL=1 CUDA_VISIBLE_DEVICES=0 python scripts/evaluation/profile_pipeline.py --ckpt_path ckpts/unifolm_wma_dual_mix_bf16.ckpt --config configs/inference/world_model_interaction.yaml --ddim_steps 50 --cfg_scale 1.0 --n_runs 3 --deep
"""
|
|
|
|
import argparse
|
|
import torch
|
|
import torch.nn as nn
|
|
import torch.nn.functional as F
|
|
import torch.utils.checkpoint # must be loaded before unifolm_wma.utils.common
|
|
from contextlib import nullcontext, contextmanager
|
|
from collections import defaultdict
|
|
from omegaconf import OmegaConf
|
|
from einops import rearrange, repeat
|
|
|
|
from unifolm_wma.utils.utils import instantiate_from_config
|
|
from unifolm_wma.models.samplers.ddim import DDIMSampler
|
|
from unifolm_wma.modules.attention import (
|
|
SpatialTransformer, TemporalTransformer,
|
|
BasicTransformerBlock, CrossAttention, FeedForward,
|
|
)
|
|
from unifolm_wma.modules.networks.wma_model import ResBlock
|
|
from unifolm_wma.models.diffusion_head.conditional_unet1d import ConditionalUnet1D
|
|
|
|
# --- W7900D theoretical peak ---
# NOTE(review): presumably AMD Radeon Pro W7900-class figures — confirm they
# match the actual deployment GPU. Only used as reference lines in the
# throughput report (Table 4); they do not affect any measurement.
PEAK_BF16_TFLOPS = 61.0  # theoretical BF16 compute peak (TFLOPS)
MEM_BW_GBS = 864.0  # theoretical HBM memory bandwidth (GB/s)
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Utility: patch norms to bypass autocast fp32 promotion
|
|
# ---------------------------------------------------------------------------
|
|
def patch_norm_bypass_autocast():
    """Monkey-patch GroupNorm and LayerNorm to bypass autocast's fp32 policy.

    Autocast promotes normalization layers to fp32 by default; the patched
    forwards run the functional norm with autocast disabled and cast the
    affine parameters to the input's dtype, so bf16 activations stay bf16.
    """

    def _norm_outside_autocast(norm_fn, x, shape_arg, weight, bias, eps):
        # Shared helper: run `norm_fn` with autocast off, matching the affine
        # parameters (when present) to the input dtype.
        with torch.amp.autocast('cuda', enabled=False):
            return norm_fn(
                x, shape_arg,
                weight.to(x.dtype) if weight is not None else None,
                bias.to(x.dtype) if bias is not None else None,
                eps)

    def _group_norm_forward(self, x):
        return _norm_outside_autocast(
            F.group_norm, x, self.num_groups, self.weight, self.bias, self.eps)

    def _layer_norm_forward(self, x):
        return _norm_outside_autocast(
            F.layer_norm, x, self.normalized_shape, self.weight, self.bias,
            self.eps)

    torch.nn.GroupNorm.forward = _group_norm_forward
    torch.nn.LayerNorm.forward = _layer_norm_forward
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Utility: torch.compile hot ResBlocks
|
|
# ---------------------------------------------------------------------------
|
|
def apply_torch_compile(model, hot_indices=(5, 8, 9)):
    """Compile ResBlock._forward in the hottest output_blocks for operator fusion."""
    unet = model.model.diffusion_model
    n_compiled = 0
    for block_idx in hot_indices:
        # Only the ResBlock layers in each hot block are compiled.
        res_layers = [m for m in unet.output_blocks[block_idx]
                      if isinstance(m, ResBlock)]
        for res in res_layers:
            res._forward = torch.compile(res._forward, mode="default")
        n_compiled += len(res_layers)
    print(f" torch.compile: {n_compiled} ResBlocks in output_blocks{list(hot_indices)}")
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Model loading
|
|
# ---------------------------------------------------------------------------
|
|
def load_model(args):
    """Build the model from config, load checkpoint weights, and prepare it
    for profiling (eval mode, bf16 diffusion model, compiled hot ResBlocks,
    moved to CUDA).

    Args:
        args: parsed CLI namespace with `config` (YAML path) and `ckpt_path`.
    """
    config = OmegaConf.load(args.config)
    # Profiling measures raw forward cost; gradient checkpointing is off.
    config['model']['params']['wma_config']['params']['use_checkpoint'] = False
    model = instantiate_from_config(config.model)

    # NOTE(review): torch.load without weights_only unpickles arbitrary
    # objects — only load trusted checkpoints.
    ckpt = torch.load(args.ckpt_path, map_location="cpu")
    state_dict = ckpt["state_dict"] if "state_dict" in ckpt else ckpt
    model.load_state_dict(state_dict, strict=True)

    model.eval()
    model.model.to(torch.bfloat16)
    model.diffusion_autocast_dtype = torch.bfloat16
    apply_torch_compile(model)
    return model.cuda()
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# CudaTimer — precise GPU timing via CUDA events
|
|
# ---------------------------------------------------------------------------
|
|
class CudaTimer:
    """Context manager for GPU-precise stage timing using CUDA events.

    On exit, the elapsed milliseconds for the stage are appended to
    ``records[name]``.
    """

    def __init__(self, name, records):
        self.name = name
        self.records = records
        self.start = torch.cuda.Event(enable_timing=True)
        self.end = torch.cuda.Event(enable_timing=True)

    def __enter__(self):
        # Drain any pending GPU work so this stage measures only its own kernels.
        torch.cuda.synchronize()
        self.start.record()
        return self

    def __exit__(self, *exc_info):
        self.end.record()
        torch.cuda.synchronize()
        self.records[self.name].append(self.start.elapsed_time(self.end))
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# HookProfiler — sub-module level timing inside UNet via hooks
|
|
# ---------------------------------------------------------------------------
|
|
class HookProfiler:
    """Register forward hooks on UNet sub-modules to collect per-call timing.

    A forward pre-hook records a start CUDA event and a forward post-hook
    records the matching end event; elapsed times are resolved lazily in
    :meth:`synchronize_and_collect`.
    """

    # Coarse-grained targets (original)
    COARSE_CLASSES = (
        SpatialTransformer,
        TemporalTransformer,
        ResBlock,
        ConditionalUnet1D,
    )

    # Fine-grained targets for deep DDIM analysis
    FINE_CLASSES = (
        CrossAttention,
        FeedForward,
    )

    def __init__(self, unet, deep=False):
        self.unet = unet
        self.deep = deep
        self.handles = []
        # per-instance data: {instance_id: [[start_event, end_event_or_None], ...]}
        self._events = defaultdict(list)
        # tag mapping: {instance_id: (class_name, module_name)}
        self._tags = {}
        # block location: {instance_id: block_location_str}
        self._block_loc = {}

    @staticmethod
    def _get_block_location(name):
        """Derive UNet block location from module name, e.g. 'input_blocks.3.1'."""
        parts = name.split('.')
        head = parts[0]
        if head == 'input_blocks' and len(parts) >= 2:
            return f"input_blocks.{parts[1]}"
        if head == 'middle_block':
            return "middle_block"
        if head == 'output_blocks' and len(parts) >= 2:
            return f"output_blocks.{parts[1]}"
        if 'action_unet' in name:
            return "action_unet"
        if 'state_unet' in name:
            return "state_unet"
        if name == 'out' or name.startswith('out.'):
            return "out"
        return "other"

    def _attach(self, mod, tag, name, loc):
        """Record metadata for `mod` and install its pre/post timing hooks."""
        inst_id = id(mod)
        self._tags[inst_id] = (tag, name)
        self._block_loc[inst_id] = loc
        self.handles.append(
            mod.register_forward_pre_hook(self._make_pre_hook(inst_id)))
        self.handles.append(
            mod.register_forward_hook(self._make_post_hook(inst_id)))

    def register(self):
        """Attach pre/post forward hooks to target sub-modules + unet.out."""
        target_classes = self.COARSE_CLASSES
        if self.deep:
            target_classes = target_classes + self.FINE_CLASSES

        for name, mod in self.unet.named_modules():
            if isinstance(mod, target_classes):
                self._attach(mod, type(mod).__name__, name,
                             self._get_block_location(name))

        # Also hook unet.out (nn.Sequential)
        self._attach(self.unet.out, "UNet.out", "out", "out")

    def _make_pre_hook(self, inst_id):
        events = self._events

        def hook(module, input):
            # Open a new [start, end] pair; the post-hook fills in the end.
            start = torch.cuda.Event(enable_timing=True)
            start.record()
            events[inst_id].append([start, None])
        return hook

    def _make_post_hook(self, inst_id):
        events = self._events

        def hook(module, input, output):
            # Close the most recent pair opened by the pre-hook.
            end = torch.cuda.Event(enable_timing=True)
            end.record()
            events[inst_id][-1][1] = end
        return hook

    def reset(self):
        """Clear collected events for a fresh run."""
        self._events.clear()

    def synchronize_and_collect(self):
        """Sync GPU and compute elapsed times. Returns (by_type, by_instance, by_block)."""
        torch.cuda.synchronize()
        by_type = defaultdict(lambda: {"total_ms": 0.0, "count": 0, "calls": []})
        by_instance = {}
        # by_block: {block_loc: {tag: {"total_ms", "count"}}}
        by_block = defaultdict(lambda: defaultdict(lambda: {"total_ms": 0.0, "count": 0}))

        for inst_id, pairs in self._events.items():
            tag, mod_name = self._tags[inst_id]
            loc = self._block_loc.get(inst_id, "other")
            # Unpaired starts (end is None, e.g. forward raised) are skipped.
            inst_times = [s.elapsed_time(e) for s, e in pairs if e is not None]
            for ms in inst_times:
                type_entry = by_type[tag]
                type_entry["total_ms"] += ms
                type_entry["count"] += 1
                type_entry["calls"].append(ms)
                block_entry = by_block[loc][tag]
                block_entry["total_ms"] += ms
                block_entry["count"] += 1
            by_instance[(tag, mod_name)] = inst_times

        return dict(by_type), by_instance, dict(by_block)

    def remove(self):
        """Remove all hooks."""
        while self.handles:
            self.handles.pop().remove()
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Build dummy inputs matching the pipeline's expected shapes
|
|
# ---------------------------------------------------------------------------
|
|
def build_dummy_inputs(model, noise_shape):
    """Create synthetic observation dict and prompts for profiling.

    Fix: the original unpacked all five elements of `noise_shape` but used
    only the batch size, leaving four unused locals and needlessly requiring
    exactly five elements; only `noise_shape[0]` is read now.

    Args:
        model: the loaded world model; used only to pick the target device.
        noise_shape: latent shape [B, C, T, H, W]; only B is needed here —
            the raw image/state/action sizes below are fixed to the
            pipeline's expected input resolution.

    Returns:
        (observation, prompts): observation dict with synthetic
        'observation.images.top' [B, 3, O, 320, 512] (permuted to
        [B, O, C, H, W] inside the pipeline), 'observation.state' [B, O, 16]
        and 'action' [B, 16, 16] bf16 tensors, plus B identical text prompts.
    """
    device = next(model.parameters()).device
    batch_size = noise_shape[0]  # C/T/H/W describe the latent, not the raw input
    dtype = torch.bfloat16

    # Number of observation steps fed to the model.
    n_obs_steps = 2
    observation = {
        'observation.images.top': torch.randn(
            batch_size, 3, n_obs_steps, 320, 512, device=device, dtype=dtype),
        'observation.state': torch.randn(
            batch_size, n_obs_steps, 16, device=device, dtype=dtype),
        'action': torch.randn(batch_size, 16, 16, device=device, dtype=dtype),
    }
    prompts = ["a robot arm performing a task"] * batch_size
    return observation, prompts
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Run one full pipeline pass with per-stage timing
|
|
# ---------------------------------------------------------------------------
|
|
def run_pipeline(model, observation, prompts, noise_shape, ddim_steps,
                 cfg_scale, hook_profiler):
    """Execute the full 7-stage pipeline, returning per-stage timing dict.

    Args:
        model: loaded world model (embedder, VAE, projectors, UNet, heads).
        observation: dict with 'observation.images.top' [B, 3, O, H, W],
            'observation.state' and 'action' tensors (see build_dummy_inputs).
        prompts: list of B text prompts for the conditioning encoder.
        noise_shape: latent shape [B, C, T, H, W] for DDIM sampling.
        ddim_steps: number of DDIM denoising steps.
        cfg_scale: classifier-free guidance scale; != 1.0 builds an extra
            unconditional conditioning dict (doubles UNet calls per step).
        hook_profiler: HookProfiler already registered on the UNet; it is
            reset before the DDIM loop so its stats cover only stage 5.

    Returns:
        (stage_times, hook_by_type, hook_by_instance, hook_by_block,
        bandwidth_info): per-stage milliseconds, the three hook aggregations,
        and VAE input/output byte counts used for bandwidth estimates.
    """
    records = defaultdict(list)
    device = next(model.parameters()).device
    B, C, T, H, W = noise_shape
    dtype = torch.bfloat16

    # Frame-stride conditioning tensor, one entry per batch element.
    fs = torch.tensor([1] * B, dtype=torch.long, device=device)

    # --- Stage 1: Image Embedding ---
    with CudaTimer("1_Image_Embedding", records):
        img = observation['observation.images.top'].permute(0, 2, 1, 3, 4)
        # Only the most recent frame ([-1:]) is embedded as image conditioning.
        cond_img = rearrange(img, 'b o c h w -> (b o) c h w')[-1:].to(dtype=dtype)
        with torch.autocast('cuda', dtype=torch.bfloat16):
            cond_img_emb = model.embedder(cond_img)
            cond_img_emb = model.image_proj_model(cond_img_emb)

    # --- Stage 2: VAE Encode ---
    with CudaTimer("2_VAE_Encode", records):
        # Undo the Stage-1 permute: back to [B, C, O, H, W].
        videos = img.permute(0, 2, 1, 3, 4)  # [B, C, O, H, W]
        b_v, c_v, t_v, h_v, w_v = videos.shape
        vae_dtype = next(model.first_stage_model.parameters()).dtype
        x_vae = rearrange(videos, 'b c t h w -> (b t) c h w').to(dtype=vae_dtype)
        z = model.encode_first_stage(x_vae)
        z = rearrange(z, '(b t) c h w -> b c t h w', b=b_v, t=t_v)
        # Last encoded frame, tiled across all T latent timesteps as the
        # concat conditioning.
        img_cat_cond = z[:, :, -1:, :, :]
        img_cat_cond = repeat(img_cat_cond,
                              'b c t h w -> b c (repeat t) h w', repeat=T)
        cond = {"c_concat": [img_cat_cond]}

    # Byte counts for the Table-4 bandwidth estimate (cheap host-side ops).
    vae_enc_input_bytes = x_vae.nelement() * x_vae.element_size()
    vae_enc_output_bytes = z.nelement() * z.element_size()

    # --- Stage 3: Text Conditioning ---
    with CudaTimer("3_Text_Conditioning", records):
        cond_ins_emb = model.get_learned_conditioning(prompts)

    # --- Stage 4: State/Action Projectors ---
    with CudaTimer("4_Projectors", records):
        projector_dtype = next(model.state_projector.parameters()).dtype
        with torch.autocast('cuda', dtype=torch.bfloat16):
            cond_state_emb = model.state_projector(
                observation['observation.state'].to(dtype=projector_dtype))
            cond_state_emb = cond_state_emb + model.agent_state_pos_emb

            cond_action_emb = model.action_projector(
                observation['action'].to(dtype=projector_dtype))
            cond_action_emb = cond_action_emb + model.agent_action_pos_emb

    # Assemble cross-attention conditioning
    cond["c_crossattn"] = [
        torch.cat([cond_state_emb, cond_action_emb, cond_ins_emb, cond_img_emb],
                  dim=1)
    ]
    n_obs_acting = getattr(model, 'n_obs_steps_acting', 2)
    # NOTE(review): positional list consumed by the sampler; the two booleans
    # appear to be flags (first one labeled sim_mode) — confirm against
    # DDIMSampler/model before reordering.
    cond["c_crossattn_action"] = [
        observation['observation.images.top'][:, :, -n_obs_acting:],
        observation['observation.state'][:, -n_obs_acting:],
        True,  # sim_mode
        False,
    ]

    # CFG: build unconditional conditioning if needed
    uc = None
    if cfg_scale != 1.0:
        # Unconditional branch zeroes only the cross-attention embedding.
        uc_crossattn = torch.zeros_like(cond["c_crossattn"][0])
        uc = {
            "c_concat": cond["c_concat"],
            "c_crossattn": [uc_crossattn],
            "c_crossattn_action": cond["c_crossattn_action"],
        }

    # --- Stage 5: DDIM Loop ---
    ddim_sampler = DDIMSampler(model)
    # Reset hook stats so they cover only this DDIM loop.
    hook_profiler.reset()

    with CudaTimer("5_DDIM_Loop", records):
        with torch.autocast('cuda', dtype=torch.bfloat16):
            samples, actions, states, _ = ddim_sampler.sample(
                S=ddim_steps,
                conditioning=cond,
                batch_size=B,
                shape=noise_shape[1:],
                verbose=False,
                unconditional_guidance_scale=cfg_scale,
                unconditional_conditioning=uc,
                eta=1.0,
                cfg_img=None,
                mask=None,
                x0=None,
                fs=fs,
                timestep_spacing='uniform',
                guidance_rescale=0.0,
                unconditional_conditioning_img_nonetext=None,
            )

    hook_by_type, hook_by_instance, hook_by_block = hook_profiler.synchronize_and_collect()

    # --- Stage 6: VAE Decode ---
    with CudaTimer("6_VAE_Decode", records):
        batch_images = model.decode_first_stage(samples)

    vae_dec_input_bytes = samples.nelement() * samples.element_size()
    vae_dec_output_bytes = batch_images.nelement() * batch_images.element_size()

    # --- Stage 7: Post-process ---
    with CudaTimer("7_Post_Process", records):
        # Device-to-host copies dominate this stage.
        batch_images_cpu = batch_images.cpu()
        actions_cpu = actions.cpu()
        states_cpu = states.cpu()
        # Simulate video save overhead: clamp + uint8 conversion
        _ = (batch_images_cpu.clamp(-1, 1) * 127.5 + 127.5).to(torch.uint8)

    # Flatten single-element lists
    stage_times = {k: v[0] for k, v in records.items()}

    bandwidth_info = {
        "vae_enc_input_bytes": vae_enc_input_bytes,
        "vae_enc_output_bytes": vae_enc_output_bytes,
        "vae_dec_input_bytes": vae_dec_input_bytes,
        "vae_dec_output_bytes": vae_dec_output_bytes,
    }

    return stage_times, hook_by_type, hook_by_instance, hook_by_block, bandwidth_info
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Reporting
|
|
# ---------------------------------------------------------------------------
|
|
def print_stage_timing(all_runs_stages):
    """Table 1: Stage Timing — name | mean(ms) | std | percent."""
    import numpy as np
    stage_names = list(all_runs_stages[0].keys())
    # Gather each stage's samples across runs, then reduce.
    per_stage = {name: [run[name] for run in all_runs_stages]
                 for name in stage_names}
    means = {name: np.mean(vals) for name, vals in per_stage.items()}
    stds = {name: np.std(vals) for name, vals in per_stage.items()}
    total = sum(means.values())

    print()
    print("=" * 72)
    print("TABLE 1: STAGE TIMING")
    print("=" * 72)
    print(f"{'Stage':<25} {'Mean(ms)':>10} {'Std':>10} {'%':>8}")
    print("-" * 72)
    for name in stage_names:
        share = means[name] / total * 100 if total > 0 else 0
        print(f"{name:<25} {means[name]:>10.1f} {stds[name]:>10.2f} {share:>7.1f}%")
    print("-" * 72)
    print(f"{'TOTAL':<25} {total:>10.1f}")
    print()
|
|
|
|
|
|
def print_unet_breakdown(all_runs_hooks):
    """Table 2: UNet Sub-Module Breakdown — type | total | count | per-call | percent."""
    import numpy as np
    # Aggregate totals/counts per module type across runs
    agg = defaultdict(lambda: {"totals": [], "counts": []})
    for hook_by_type in all_runs_hooks:
        for tag, data in hook_by_type.items():
            agg[tag]["totals"].append(data["total_ms"])
            agg[tag]["counts"].append(data["count"])

    print("=" * 80)
    print("TABLE 2: UNET SUB-MODULE BREAKDOWN")
    print("=" * 80)
    print(f"{'Module Type':<25} {'Total(ms)':>10} {'Count':>7} {'Per-call':>10} {'%':>8}")
    print("-" * 80)

    rows = []
    grand_total = 0
    for tag, d in agg.items():
        mean_total = np.mean(d["totals"])
        mean_count = np.mean(d["counts"])
        grand_total += mean_total
        rows.append((tag, mean_total, mean_count,
                     mean_total / mean_count if mean_count > 0 else 0))

    # Heaviest module types first.
    for tag, mean_total, mean_count, per_call in sorted(
            rows, key=lambda r: r[1], reverse=True):
        share = mean_total / grand_total * 100 if grand_total > 0 else 0
        print(f"{tag:<25} {mean_total:>10.1f} {int(mean_count):>7} {per_call:>10.2f} {share:>7.1f}%")
    print("-" * 80)
    print(f"{'TOTAL (hooked)':<25} {grand_total:>10.1f}")
    print()
|
|
|
|
|
|
def print_block_timing(all_runs_blocks):
    """Table 2b: Per-UNet-block timing — which blocks are hottest."""
    import numpy as np
    # Aggregate: {block_loc: {tag: [total_ms per run, ...]}}
    agg = defaultdict(lambda: defaultdict(list))
    for by_block in all_runs_blocks:
        for block_loc, tag_dict in by_block.items():
            for tag, data in tag_dict.items():
                agg[block_loc][tag].append(data["total_ms"])

    # Mean total per block, then the grand total over all blocks.
    block_totals = {
        block_loc: sum(np.mean(v) for v in tag_dict.values())
        for block_loc, tag_dict in agg.items()
    }
    grand_total = sum(block_totals.values())

    # Sort blocks in logical UNet order: input -> middle -> output -> out -> heads.
    def block_sort_key(name):
        if name.startswith("input_blocks."):
            return (0, int(name.split('.')[1]))
        if name == "middle_block":
            return (1, 0)
        if name.startswith("output_blocks."):
            return (2, int(name.split('.')[1]))
        fixed = {"out": (3, 0), "action_unet": (4, 0), "state_unet": (5, 0)}
        return fixed.get(name, (9, 0))

    print("=" * 90)
    print("TABLE 2b: PER-UNET-BLOCK TIMING (coarse modules, per DDIM loop)")
    print("=" * 90)
    print(f"{'Block':<22} {'Total(ms)':>10} {'%':>7} Breakdown")
    print("-" * 90)

    for block_loc in sorted(block_totals.keys(), key=block_sort_key):
        total = block_totals[block_loc]
        share = total / grand_total * 100 if grand_total > 0 else 0
        # Per-type breakdown string, heaviest type first.
        ranked = sorted(agg[block_loc].items(),
                        key=lambda x: np.mean(x[1]), reverse=True)
        breakdown = ", ".join(f"{tag}={np.mean(vals):.0f}" for tag, vals in ranked)
        print(f"{block_loc:<22} {total:>10.1f} {share:>6.1f}% {breakdown}")

    print("-" * 90)
    print(f"{'TOTAL':<22} {grand_total:>10.1f}")
    print()
|
|
|
|
|
|
def print_attn_ff_breakdown(all_runs_hooks):
    """Table 2c: CrossAttention vs FeedForward breakdown (--deep mode)."""
    import numpy as np
    tracked = ("CrossAttention", "FeedForward")
    agg = defaultdict(list)
    for hook_by_type in all_runs_hooks:
        for tag in tracked:
            if tag in hook_by_type:
                agg[tag].append(hook_by_type[tag]["total_ms"])

    # Nothing collected means deep hooks were not active — skip the table.
    if not agg:
        return

    print("=" * 70)
    print("TABLE 2c: ATTENTION vs FEEDFORWARD (deep hooks)")
    print("=" * 70)
    print(f"{'Component':<25} {'Total(ms)':>10} {'%':>8}")
    print("-" * 70)

    rows = [(tag, np.mean(agg[tag])) for tag in tracked if tag in agg]
    grand = sum(mean_t for _, mean_t in rows)
    for tag, mean_t in rows:
        share = mean_t / grand * 100 if grand > 0 else 0
        print(f"{tag:<25} {mean_t:>10.1f} {share:>7.1f}%")
    print("-" * 70)
    print(f"{'TOTAL (attn+ff)':<25} {grand:>10.1f}")
    print()
|
|
|
|
|
|
def print_unet_detailed(all_runs_instances):
    """Print per-instance UNet sub-module detail (--detailed mode)."""
    import numpy as np
    # Use last run's data
    by_instance = all_runs_instances[-1]
    print("=" * 100)
    print("DETAILED: PER-INSTANCE UNET SUB-MODULE TIMING (last run)")
    print("=" * 100)
    print(f"{'Type':<22} {'Module Name':<45} {'Calls':>6} {'Total(ms)':>10} {'Mean(ms)':>10}")
    print("-" * 100)

    # Skip instances that were never called; sort by total time, descending.
    rows = [
        (tag, mod_name, len(times), sum(times), np.mean(times))
        for (tag, mod_name), times in by_instance.items()
        if times
    ]
    for tag, mod_name, calls, total, mean in sorted(
            rows, key=lambda r: r[3], reverse=True):
        # Keep the most specific (right-most) part of long module paths.
        shown = mod_name if len(mod_name) <= 42 else mod_name[-42:]
        print(f"{tag:<22} {shown:<45} {calls:>6} {total:>10.2f} {mean:>10.3f}")
    print()
|
|
|
|
|
|
def print_memory_summary(mem_before, mem_peak):
    """Table 3: Memory Summary — initial vs peak allocated bytes, reported in GB."""
    gb = 1e9
    print("=" * 50)
    print("TABLE 3: MEMORY SUMMARY")
    print("=" * 50)
    print(f"  Initial allocated: {mem_before / gb:.2f} GB")
    print(f"  Peak allocated: {mem_peak / gb:.2f} GB")
    print(f"  Delta (pipeline): {(mem_peak - mem_before) / gb:.2f} GB")
    print()
|
|
|
|
|
|
def print_throughput(all_runs_stages, all_bw, ddim_steps, cfg_scale):
    """Table 4: Throughput — total latency, per-step, per-UNet-forward, VAE bandwidth."""
    import numpy as np

    # Latency means across runs.
    mean_total = np.mean([sum(run.values()) for run in all_runs_stages])
    mean_ddim = np.mean([run["5_DDIM_Loop"] for run in all_runs_stages])

    # CFG != 1.0 runs the UNet twice per DDIM step (cond + uncond branch).
    unet_calls = ddim_steps * (2 if cfg_scale != 1.0 else 1)
    per_step = mean_ddim / ddim_steps
    per_unet = mean_ddim / unet_calls

    # Effective VAE bandwidth: (bytes in + bytes out) / mean stage time.
    mean_enc_time = np.mean([run["2_VAE_Encode"] for run in all_runs_stages])
    mean_dec_time = np.mean([run["6_VAE_Decode"] for run in all_runs_stages])
    bw = all_bw[-1]  # use last run's byte counts
    enc_bytes = bw["vae_enc_input_bytes"] + bw["vae_enc_output_bytes"]
    dec_bytes = bw["vae_dec_input_bytes"] + bw["vae_dec_output_bytes"]
    enc_bw = enc_bytes / (mean_enc_time / 1000) / 1e9 if mean_enc_time > 0 else 0
    dec_bw = dec_bytes / (mean_dec_time / 1000) / 1e9 if mean_dec_time > 0 else 0

    print("=" * 60)
    print("TABLE 4: THROUGHPUT")
    print("=" * 60)
    print(f"  Total pipeline latency: {mean_total:.1f} ms")
    print(f"  DDIM loop latency: {mean_ddim:.1f} ms")
    print(f"  DDIM steps: {ddim_steps}")
    print(f"  CFG scale: {cfg_scale} ({'2x UNet/step' if cfg_scale != 1.0 else '1x UNet/step'})")
    print(f"  UNet forward calls: {unet_calls}")
    print(f"  Per DDIM step: {per_step:.1f} ms")
    print(f"  Per UNet forward: {per_unet:.1f} ms")
    print(f"  VAE encode bandwidth: {enc_bw:.1f} GB/s (peak HBM: {MEM_BW_GBS} GB/s)")
    print(f"  VAE decode bandwidth: {dec_bw:.1f} GB/s (peak HBM: {MEM_BW_GBS} GB/s)")
    print(f"  GPU BF16 peak: {PEAK_BF16_TFLOPS} TFLOPS")
    print()
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Main
|
|
# ---------------------------------------------------------------------------
|
|
def main():
    """CLI entry point: load the model, run warmup + timed pipeline passes,
    then print the timing, memory, and throughput report tables."""
    # Patch norms before model construction so every GroupNorm/LayerNorm
    # instance picks up the autocast-bypassing forward.
    patch_norm_bypass_autocast()

    parser = argparse.ArgumentParser(
        description="Profile the full inference pipeline")
    parser.add_argument("--ckpt_path", type=str, required=True)
    parser.add_argument("--config", type=str, required=True)
    parser.add_argument("--ddim_steps", type=int, default=50)
    parser.add_argument("--cfg_scale", type=float, default=1.0)
    parser.add_argument("--n_runs", type=int, default=3)
    parser.add_argument("--warmup", type=int, default=1)
    parser.add_argument("--detailed", action="store_true",
                        help="Print per-instance UNet sub-module detail")
    parser.add_argument("--deep", action="store_true",
                        help="Enable deep DDIM analysis: per-block, attn vs ff")
    args = parser.parse_args()

    # Latent shape [B, C, T, H, W] fed to DDIM sampling.
    noise_shape = [1, 4, 16, 40, 64]

    # --- Load model ---
    print("Loading model...")
    model = load_model(args)
    observation, prompts = build_dummy_inputs(model, noise_shape)

    # --- Setup hook profiler ---
    unet = model.model.diffusion_model
    hook_profiler = HookProfiler(unet, deep=args.deep)
    hook_profiler.register()
    print(f"Registered hooks on {len(hook_profiler.handles)} sub-modules")

    # --- Warmup ---
    # Warmup runs trigger torch.compile and CUDA allocator warm paths so the
    # measured runs reflect steady state.
    print(f"Warmup: {args.warmup} run(s)...")
    with torch.no_grad():
        for i in range(args.warmup):
            run_pipeline(model, observation, prompts, noise_shape,
                         args.ddim_steps, args.cfg_scale, hook_profiler)
            print(f"  warmup {i+1}/{args.warmup} done")

    # --- Measurement runs ---
    print(f"Measuring: {args.n_runs} run(s)...")
    # Reset peak stats after warmup so Table 3 covers only measured runs.
    torch.cuda.reset_peak_memory_stats()
    mem_before = torch.cuda.memory_allocated()

    all_stages = []
    all_hooks = []
    all_instances = []
    all_blocks = []
    all_bw = []

    with torch.no_grad():
        for i in range(args.n_runs):
            stage_times, hook_by_type, hook_by_instance, hook_by_block, bw = run_pipeline(
                model, observation, prompts, noise_shape,
                args.ddim_steps, args.cfg_scale, hook_profiler)
            all_stages.append(stage_times)
            all_hooks.append(hook_by_type)
            all_instances.append(hook_by_instance)
            all_blocks.append(hook_by_block)
            all_bw.append(bw)
            total = sum(stage_times.values())
            print(f"  run {i+1}/{args.n_runs}: {total:.1f} ms total")

    mem_peak = torch.cuda.max_memory_allocated()

    # --- Reports ---
    print_stage_timing(all_stages)
    print_unet_breakdown(all_hooks)
    print_block_timing(all_blocks)
    if args.deep:
        print_attn_ff_breakdown(all_hooks)
    if args.detailed:
        print_unet_detailed(all_instances)
    print_memory_summary(mem_before, mem_peak)
    print_throughput(all_stages, all_bw, args.ddim_steps, args.cfg_scale)

    hook_profiler.remove()
    print("Done.")
|
|
|
|
|
|
# Script entry point.
if __name__ == "__main__":
    main()
|