# NOTE (from a prior review of attention.py): all 4 softmax call sites are
# wrapped in torch.amp.autocast('cuda', enabled=False), which prevents
# autocast from promoting bf16 to fp32 for the softmax.
"""
|
|
Profile one DDIM sampling iteration to capture all matmul/attention ops,
|
|
their matrix sizes, wall time, and compute utilization.
|
|
|
|
Uses torch.profiler for CUDA timing and FlopCounterMode for accurate
|
|
FLOPS counting (works on ROCm where Tensile kernels don't report FLOPS).
|
|
|
|
Usage:
|
|
TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL=1 CUDA_VISIBLE_DEVICES=0 \
|
|
python scripts/evaluation/profile_unet.py \
|
|
--ckpt_path ckpts/unifolm_wma_dual_mix_bf16.ckpt \
|
|
--config configs/inference/world_model_interaction.yaml
|
|
"""
|
|
|
|
import argparse
|
|
import torch
|
|
import torch.nn as nn
|
|
from collections import OrderedDict, defaultdict
|
|
from omegaconf import OmegaConf
|
|
from torch.utils.flop_counter import FlopCounterMode
|
|
from unifolm_wma.utils.utils import instantiate_from_config
|
|
import torch.nn.functional as F
|
|
|
|
|
|
def patch_norm_bypass_autocast():
    """Monkey-patch GroupNorm and LayerNorm to bypass autocast's fp32 policy."""

    def _group_norm_forward(self, x):
        # Disable autocast so the norm runs in the input's dtype (e.g. bf16)
        # instead of being promoted to fp32 by autocast's norm policy.
        with torch.amp.autocast('cuda', enabled=False):
            weight = None if self.weight is None else self.weight.to(x.dtype)
            bias = None if self.bias is None else self.bias.to(x.dtype)
            return F.group_norm(x, self.num_groups, weight, bias, self.eps)

    def _layer_norm_forward(self, x):
        with torch.amp.autocast('cuda', enabled=False):
            weight = None if self.weight is None else self.weight.to(x.dtype)
            bias = None if self.bias is None else self.bias.to(x.dtype)
            return F.layer_norm(x, self.normalized_shape, weight, bias, self.eps)

    # Patch the classes globally: every GroupNorm/LayerNorm instance in the
    # process picks up the dtype-preserving forward.
    torch.nn.GroupNorm.forward = _group_norm_forward
    torch.nn.LayerNorm.forward = _layer_norm_forward
|
|
|
|
|
|
# --- W7900D theoretical peak (TFLOPS) ---
# Denominators for the utilization percentages reported in print_report.
PEAK_BF16_TFLOPS = 61.0
PEAK_FP32_TFLOPS = 30.5
|
|
|
|
|
|
def apply_torch_compile(model, hot_indices=(5, 8, 9)):
    """Compile ResBlock._forward in the hottest output_blocks for operator fusion."""
    from unifolm_wma.modules.networks.wma_model import ResBlock

    unet = model.model.diffusion_model
    compiled = 0
    for idx in hot_indices:
        # Each output block is an iterable of layers; compile only ResBlocks.
        for layer in unet.output_blocks[idx]:
            if not isinstance(layer, ResBlock):
                continue
            layer._forward = torch.compile(layer._forward, mode="default")
            compiled += 1
    print(f" ✓ torch.compile: {compiled} ResBlocks in output_blocks{list(hot_indices)}")
|
|
|
|
|
|
def load_model(args):
    """Instantiate the model from config, load the checkpoint, and prepare it.

    Returns the model in eval mode with bf16 diffusion weights on CUDA, with
    the hottest ResBlocks compiled for operator fusion.
    """
    config = OmegaConf.load(args.config)
    # Gradient checkpointing trades compute for memory; disable it so the
    # profiler sees each op exactly once.
    config['model']['params']['wma_config']['params']['use_checkpoint'] = False
    model = instantiate_from_config(config.model)

    # weights_only=False: checkpoints may contain pickled non-tensor objects
    # (e.g. trainer metadata); torch>=2.6 flipped the default to True, which
    # would make this load fail. Only load trusted checkpoints this way.
    state_dict = torch.load(args.ckpt_path, map_location="cpu", weights_only=False)
    if "state_dict" in state_dict:
        state_dict = state_dict["state_dict"]
    model.load_state_dict(state_dict, strict=True)

    model.eval()
    # Cast the diffusion wrapper to bf16 before compiling/moving to GPU.
    model.model.to(torch.bfloat16)
    apply_torch_compile(model)
    model = model.cuda()
    return model
|
|
|
|
|
|
def build_call_kwargs(model, noise_shape):
    """Build dummy inputs matching the hybrid conditioning forward signature."""
    device = next(model.parameters()).device
    B, C, T, H, W = noise_shape  # [1, 4, 16, 40, 64]
    dtype = torch.bfloat16

    def rand(*shape):
        # All float inputs share the model's device and bf16 dtype.
        return torch.randn(*shape, device=device, dtype=dtype)

    x_action = rand(B, 16, 16)
    x_state = rand(B, 16, 16)
    timesteps = torch.tensor([500], device=device, dtype=torch.long)
    context = rand(B, 351, 1024)

    # Observation conditioning: two frames of images + proprioceptive state,
    # followed by two boolean flags expected by the forward signature.
    obs_images = rand(B, 3, 2, 320, 512)
    obs_state = rand(B, 2, 16)
    context_action = [obs_images, obs_state, True, False]
    fps = torch.tensor([1], device=device, dtype=torch.long)

    x_raw = rand(B, C, T, H, W)
    c_concat = [rand(B, C, T, H, W)]

    return {
        'x': x_raw,
        'x_action': x_action,
        'x_state': x_state,
        't': timesteps,
        'c_concat': c_concat,
        'c_crossattn': [context],
        'c_crossattn_action': context_action,
        's': fps,
    }
|
|
|
|
|
|
def profile_one_step(model, noise_shape):
    """Run one UNet forward pass under torch.profiler for CUDA timing."""
    diff_wrapper = model.model
    call_kwargs = build_call_kwargs(model, noise_shape)

    with torch.no_grad(), torch.autocast('cuda', dtype=torch.bfloat16):
        # Warmup: lets torch.compile finish compiling and kernel caches settle
        # so the profiled pass reflects steady-state timing.
        for _ in range(2):
            diff_wrapper(**call_kwargs)
        torch.cuda.synchronize()

        profiler = torch.profiler.profile(
            activities=[torch.profiler.ProfilerActivity.CUDA],
            record_shapes=True,
            with_flops=True,
        )
        with profiler as prof:
            diff_wrapper(**call_kwargs)
            # Ensure all launched kernels finish inside the profiled region.
            torch.cuda.synchronize()

    return prof
|
|
|
|
|
|
def count_flops(model, noise_shape):
    """Run one UNet forward pass under FlopCounterMode for accurate FLOPS."""
    diff_wrapper = model.model
    call_kwargs = build_call_kwargs(model, noise_shape)

    counter = FlopCounterMode(display=False)
    with torch.no_grad(), torch.autocast('cuda', dtype=torch.bfloat16):
        with counter:
            diff_wrapper(**call_kwargs)
            # Wait for all kernels so the counted pass fully completes.
            torch.cuda.synchronize()

    return counter
|
|
|
|
|
|
def print_report(prof, flop_counter):
    """Parse profiler results and print a structured report with accurate FLOPS.

    Args:
        prof: torch.profiler.profile result (only .key_averages() is used).
        flop_counter: FlopCounterMode that wrapped the same forward pass.
    """
    events = prof.key_averages()

    # --- Extract per-operator FLOPS from FlopCounterMode ---
    # flop_counts is {module_name: {op_name: count}}; use only "Global" to avoid double-counting
    flop_by_op = {}
    flop_by_module = {}
    if hasattr(flop_counter, 'flop_counts'):
        # Per-op: only from top-level "Global" entry (no parent/child duplication)
        global_ops = flop_counter.flop_counts.get("Global", {})
        for op_name, flop_count in global_ops.items():
            key = str(op_name).split('.')[-1]
            flop_by_op[key] = flop_by_op.get(key, 0) + flop_count

        # Per-module: collect all, skip "Global" so the aggregate total is not
        # listed again as if it were a module (fix: it was previously included).
        for module_name, op_dict in flop_counter.flop_counts.items():
            if module_name == "Global":
                continue
            module_total = sum(op_dict.values())
            if module_total > 0:
                flop_by_module[module_name] = module_total

    total_counted_flops = flop_counter.get_total_flops()

    # Collect matmul-like ops (name heuristic covers gemm/addmm/bmm/einsum etc.)
    matmul_ops = []
    other_ops = []

    for evt in events:
        if evt.device_time_total <= 0:
            continue
        name = evt.key
        is_matmul = any(k in name.lower() for k in
                        ['mm', 'gemm', 'addmm', 'bmm', 'einsum', 'dot', 'linear'])
        entry = {
            'name': name,
            'input_shapes': str(evt.input_shapes) if evt.input_shapes else '',
            'cuda_time_ms': evt.device_time_total / 1000.0,  # profiler reports µs
            'count': evt.count,
            'flops': evt.flops if evt.flops else 0,
        }
        if is_matmul:
            matmul_ops.append(entry)
        else:
            other_ops.append(entry)

    # Sort by CUDA time, heaviest first
    matmul_ops.sort(key=lambda x: x['cuda_time_ms'], reverse=True)
    other_ops.sort(key=lambda x: x['cuda_time_ms'], reverse=True)

    total_cuda_ms = sum(e['cuda_time_ms'] for e in matmul_ops + other_ops)
    total_matmul_ms = sum(e['cuda_time_ms'] for e in matmul_ops)

    # --- Print matmul ops ---
    print("=" * 130)
    print("MATMUL / LINEAR OPS (sorted by CUDA time)")
    print("=" * 130)
    print(f"{'Op':>35} | {'Count':>5} | {'CUDA(ms)':>10} | Shapes")
    print("-" * 130)

    for op in matmul_ops:
        shapes_str = op['input_shapes'][:60] if op['input_shapes'] else ''
        print(f"{op['name']:>35} | {op['count']:>5} | {op['cuda_time_ms']:>10.3f} | {shapes_str}")

    # --- Print top non-matmul ops ---
    print()
    print("=" * 130)
    print("TOP NON-MATMUL OPS (sorted by CUDA time)")
    print("=" * 130)
    print(f"{'Op':>40} | {'Count':>5} | {'CUDA(ms)':>10} | Shapes")
    print("-" * 130)

    for op in other_ops[:20]:
        shapes_str = op['input_shapes'][:60] if op['input_shapes'] else ''
        print(f"{op['name']:>40} | {op['count']:>5} | {op['cuda_time_ms']:>10.3f} | {shapes_str}")

    # --- FlopCounterMode per-operator breakdown ---
    print()
    print("=" * 130)
    print("FLOPS BY ATen OPERATOR (FlopCounterMode)")
    print("=" * 130)
    print(f"{'ATen Op':>25} | {'GFLOPS':>12} | {'% of Total':>10}")
    print("-" * 55)

    sorted_flop_ops = sorted(flop_by_op.items(), key=lambda x: x[1], reverse=True)
    for op_name, flops in sorted_flop_ops:
        gflops = flops / 1e9
        pct = flops / total_counted_flops * 100 if total_counted_flops > 0 else 0
        print(f"{op_name:>25} | {gflops:>12.2f} | {pct:>9.1f}%")

    # --- FlopCounterMode per-module breakdown ---
    if flop_by_module:
        print()
        print("=" * 130)
        print("FLOPS BY MODULE (FlopCounterMode)")
        print("=" * 130)
        print(f"{'Module':>60} | {'GFLOPS':>12} | {'% of Total':>10}")
        print("-" * 90)

        sorted_modules = sorted(flop_by_module.items(), key=lambda x: x[1], reverse=True)
        for mod_name, flops in sorted_modules[:30]:
            gflops = flops / 1e9
            pct = flops / total_counted_flops * 100 if total_counted_flops > 0 else 0
            # Keep only the tail of deeply nested module paths so columns align
            name_str = mod_name[-57:] if len(mod_name) > 57 else mod_name
            print(f"{name_str:>60} | {gflops:>12.2f} | {pct:>9.1f}%")

    # --- Summary ---
    print()
    print("=" * 130)
    print("SUMMARY")
    print("=" * 130)
    print(f" Total CUDA time: {total_cuda_ms:.1f} ms")
    if total_cuda_ms > 0:
        # Guard against ZeroDivisionError when no CUDA events were captured
        # (fix: the percentages previously divided by total_cuda_ms unguarded).
        print(f" Matmul CUDA time: {total_matmul_ms:.1f} ms ({total_matmul_ms/total_cuda_ms*100:.1f}%)")
        print(f" Non-matmul CUDA time: {total_cuda_ms - total_matmul_ms:.1f} ms ({(total_cuda_ms-total_matmul_ms)/total_cuda_ms*100:.1f}%)")
    print(f" Total FLOPS (FlopCounter): {total_counted_flops/1e9:.2f} GFLOPS")
    if total_matmul_ms > 0 and total_counted_flops > 0:
        avg_tflops = total_counted_flops / (total_matmul_ms / 1000.0) / 1e12
        avg_util = avg_tflops / PEAK_BF16_TFLOPS * 100
        overall_tflops = total_counted_flops / (total_cuda_ms / 1000.0) / 1e12
        overall_util = overall_tflops / PEAK_BF16_TFLOPS * 100
        print(f" Matmul throughput: {avg_tflops:.2f} TFLOPS/s ({avg_util:.1f}% of BF16 peak)")
        print(f" Overall throughput: {overall_tflops:.2f} TFLOPS/s ({overall_util:.1f}% of BF16 peak)")
        print(f" GPU peak (BF16): {PEAK_BF16_TFLOPS} TFLOPS")
|
|
|
|
|
|
if __name__ == '__main__':
    # Patch norms before the model is built so every instance gets the
    # dtype-preserving forward.
    patch_norm_bypass_autocast()
    parser = argparse.ArgumentParser()
    parser.add_argument("--ckpt_path", type=str, required=True)
    parser.add_argument("--config", type=str, required=True)
    args = parser.parse_args()

    print("Loading model...")
    model = load_model(args)

    # Latent shape [B, C, T, H, W] for one DDIM denoising step.
    noise_shape = [1, 4, 16, 40, 64]

    print(f"Profiling UNet forward pass with shape {noise_shape}...")
    prof = profile_one_step(model, noise_shape)

    print("Counting FLOPS with FlopCounterMode...")
    flop_counter = count_flops(model, noise_shape)

    print_report(prof, flop_counter)
|