Files
unifolm-world-model-action/scripts/evaluation/profile_unet.py
olivame f86ab51a04 全链路 bf16 混合精度修正与 UNet FLOPS profiling
- GroupNorm/LayerNorm bypass autocast,消除 bf16→fp32→bf16 转换开销
  - DDIM 调度系数 cast 到输入 dtype,attention mask 直接用 bf16 分配
  - alphas_cumprod 提升到 float64 保证数值精度
  - SinusoidalPosEmb 输出 dtype跟随模型精度
  - 新增 profile_unet.py 脚本及FLOPS 分析结果
  - 启用 TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL
  - case1 PSNR: 30.45 → 30.24(bf16 精度预期内波动)
2026-02-08 16:01:30 +00:00

273 lines
10 KiB
Python

"""
Profile one DDIM sampling iteration to capture all matmul/attention ops,
their matrix sizes, wall time, and compute utilization.
Uses torch.profiler for CUDA timing and FlopCounterMode for accurate
FLOPS counting (works on ROCm where Tensile kernels don't report FLOPS).
Usage:
TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL=1 CUDA_VISIBLE_DEVICES=0 \
python scripts/evaluation/profile_unet.py \
--ckpt_path ckpts/unifolm_wma_dual_mix_bf16.ckpt \
--config configs/inference/world_model_interaction.yaml
"""
import argparse
import torch
import torch.nn as nn
from collections import OrderedDict, defaultdict
from omegaconf import OmegaConf
from torch.utils.flop_counter import FlopCounterMode
from unifolm_wma.utils.utils import instantiate_from_config
import torch.nn.functional as F
def patch_norm_bypass_autocast():
    """Monkey-patch ``nn.GroupNorm`` and ``nn.LayerNorm`` to bypass autocast's fp32 policy.

    Both classes' ``forward`` methods are replaced process-wide. Each new
    forward disables CUDA autocast locally. It also casts the optional affine
    weight/bias to ``x.dtype``, so the normalization runs in the input's dtype
    (e.g. bf16) rather than being promoted to fp32 by autocast.
    """
    def _affine_as(param, dtype):
        # weight/bias can be None (e.g. GroupNorm(affine=False)); keep None as-is.
        return None if param is None else param.to(dtype)

    def _group_norm_forward(self, x):
        with torch.amp.autocast('cuda', enabled=False):
            weight = _affine_as(self.weight, x.dtype)
            bias = _affine_as(self.bias, x.dtype)
            return F.group_norm(x, self.num_groups, weight, bias, self.eps)

    def _layer_norm_forward(self, x):
        with torch.amp.autocast('cuda', enabled=False):
            weight = _affine_as(self.weight, x.dtype)
            bias = _affine_as(self.bias, x.dtype)
            return F.layer_norm(x, self.normalized_shape, weight, bias, self.eps)

    replacements = (
        (torch.nn.GroupNorm, _group_norm_forward),
        (torch.nn.LayerNorm, _layer_norm_forward),
    )
    for norm_cls, new_forward in replacements:
        norm_cls.forward = new_forward
# Theoretical peak throughput of the W7900D, in TFLOPS. These are hard-coded
# figures; verify them against the vendor spec before reusing on other GPUs.
# PEAK_BF16_TFLOPS is the denominator of the utilization percentages in the report.
PEAK_BF16_TFLOPS: float = 61.0
# NOTE(review): PEAK_FP32_TFLOPS is not referenced anywhere in the visible code.
PEAK_FP32_TFLOPS: float = 30.5
def load_model(args):
    """Instantiate the model from ``args.config`` and load weights from ``args.ckpt_path``.

    ``use_checkpoint`` is forced to False in the ``wma_config`` params before
    instantiation. Weights are loaded with ``strict=True``. If the checkpoint
    dict has a ``"state_dict"`` key, that entry is used.

    Returns:
        The model in eval mode. Its ``model`` submodule is cast to bf16, and the
        whole model is moved to CUDA.
    """
    cfg = OmegaConf.load(args.config)
    wma_params = cfg['model']['params']['wma_config']['params']
    wma_params['use_checkpoint'] = False
    net = instantiate_from_config(cfg.model)

    # NOTE(review): torch.load unpickles the file, so only load trusted checkpoints.
    # On PyTorch >= 2.6 the default weights_only=True may reject checkpoints that
    # store non-tensor objects.
    ckpt = torch.load(args.ckpt_path, map_location="cpu")
    weights = ckpt["state_dict"] if "state_dict" in ckpt else ckpt
    net.load_state_dict(weights, strict=True)

    net.eval()
    net.model.to(torch.bfloat16)
    return net.cuda()
def build_call_kwargs(model, noise_shape):
    """Build random bf16 dummy inputs matching the hybrid conditioning forward signature.

    Args:
        model: Module whose first parameter's device is used for every tensor.
        noise_shape: ``(B, C, T, H, W)``, e.g. ``[1, 4, 16, 40, 64]``. It sets
            the shape of ``x`` and ``c_concat[0]`` and the batch size ``B``
            used by every other input.

    Returns:
        A dict with keys ``x``, ``x_action``, ``x_state``, ``t``, ``c_concat``,
        ``c_crossattn``, ``c_crossattn_action`` and ``s``. ``t`` (all 500) and
        ``s`` (all 1) are int64 vectors of length ``B``. Every other tensor is
        random bf16.
    """
    device = next(model.parameters()).device
    B, C, T, H, W = noise_shape
    dtype = torch.bfloat16
    # The non-latent sizes below are fixed, independent of noise_shape.
    # NOTE(review): presumably they must agree with the model config; confirm.
    x_action = torch.randn(B, 16, 16, device=device, dtype=dtype)
    x_state = torch.randn(B, 16, 16, device=device, dtype=dtype)
    # One timestep / fps value per batch element, so inputs stay consistent for B > 1.
    timesteps = torch.full((B,), 500, device=device, dtype=torch.long)
    context = torch.randn(B, 351, 1024, device=device, dtype=dtype)
    obs_images = torch.randn(B, 3, 2, 320, 512, device=device, dtype=dtype)
    obs_state = torch.randn(B, 2, 16, device=device, dtype=dtype)
    # The two trailing booleans are passed through unchanged.
    # NOTE(review): their meaning is defined by the model's forward, which is not visible here.
    context_action = [obs_images, obs_state, True, False]
    fps = torch.full((B,), 1, device=device, dtype=torch.long)
    x_raw = torch.randn(B, C, T, H, W, device=device, dtype=dtype)
    c_concat = [torch.randn(B, C, T, H, W, device=device, dtype=dtype)]
    return dict(
        x=x_raw, x_action=x_action, x_state=x_state, t=timesteps,
        c_concat=c_concat, c_crossattn=[context],
        c_crossattn_action=context_action, s=fps,
    )
def profile_one_step(model, noise_shape):
    """Profile one forward pass of ``model.model`` with ``torch.profiler`` (CUDA activity only).

    Two warm-up passes run first and are not profiled. Everything runs under
    ``torch.no_grad()`` and bf16 CUDA autocast. Shapes and FLOP estimates are
    requested from the profiler.

    Returns:
        The finished ``torch.profiler.profile`` object.
    """
    unet = model.model
    inputs = build_call_kwargs(model, noise_shape)
    activities = [torch.profiler.ProfilerActivity.CUDA]
    with torch.no_grad(), torch.autocast('cuda', dtype=torch.bfloat16):
        # Warm-up passes (excluded from the profile).
        for _warm in range(2):
            out = unet(**inputs)
        torch.cuda.synchronize()

        with torch.profiler.profile(
            activities=activities,
            record_shapes=True,
            with_flops=True,
        ) as prof:
            out = unet(**inputs)
            # Wait for the GPU so all kernels fall inside the profiled region.
            torch.cuda.synchronize()
    return prof
def count_flops(model, noise_shape):
    """Count FLOPs of one forward pass of ``model.model`` with ``FlopCounterMode``.

    The pass runs under ``torch.no_grad()`` and bf16 CUDA autocast. The counter
    is created with ``display=False``, so nothing is printed here.

    Returns:
        The ``FlopCounterMode`` instance. Read ``flop_counts`` or
        ``get_total_flops()`` from it.
    """
    unet = model.model
    inputs = build_call_kwargs(model, noise_shape)
    counter = FlopCounterMode(display=False)
    with torch.no_grad(), torch.autocast('cuda', dtype=torch.bfloat16), counter:
        out = unet(**inputs)
        torch.cuda.synchronize()
    return counter
def print_report(prof, flop_counter):
"""Parse profiler results and print a structured report with accurate FLOPS."""
events = prof.key_averages()
# --- Extract per-operator FLOPS from FlopCounterMode ---
# flop_counts is {module_name: {op_name: count}}; use only "Global" to avoid double-counting
flop_by_op = {}
flop_by_module = {}
if hasattr(flop_counter, 'flop_counts'):
# Per-op: only from top-level "Global" entry (no parent/child duplication)
global_ops = flop_counter.flop_counts.get("Global", {})
for op_name, flop_count in global_ops.items():
key = str(op_name).split('.')[-1]
flop_by_op[key] = flop_by_op.get(key, 0) + flop_count
# Per-module: collect all, skip "Global" and top-level wrapper duplicates
for module_name, op_dict in flop_counter.flop_counts.items():
module_total = sum(op_dict.values())
if module_total > 0:
flop_by_module[module_name] = module_total
total_counted_flops = flop_counter.get_total_flops()
# Collect matmul-like ops
matmul_ops = []
other_ops = []
for evt in events:
if evt.device_time_total <= 0:
continue
name = evt.key
is_matmul = any(k in name.lower() for k in
['mm', 'gemm', 'addmm', 'bmm', 'einsum', 'dot', 'linear'])
entry = {
'name': name,
'input_shapes': str(evt.input_shapes) if evt.input_shapes else '',
'cuda_time_ms': evt.device_time_total / 1000.0,
'count': evt.count,
'flops': evt.flops if evt.flops else 0,
}
if is_matmul:
matmul_ops.append(entry)
else:
other_ops.append(entry)
# Sort by CUDA time
matmul_ops.sort(key=lambda x: x['cuda_time_ms'], reverse=True)
other_ops.sort(key=lambda x: x['cuda_time_ms'], reverse=True)
total_cuda_ms = sum(e['cuda_time_ms'] for e in matmul_ops + other_ops)
total_matmul_ms = sum(e['cuda_time_ms'] for e in matmul_ops)
# --- Print matmul ops ---
print("=" * 130)
print("MATMUL / LINEAR OPS (sorted by CUDA time)")
print("=" * 130)
print(f"{'Op':>35} | {'Count':>5} | {'CUDA(ms)':>10} | Shapes")
print("-" * 130)
for op in matmul_ops:
shapes_str = op['input_shapes'][:60] if op['input_shapes'] else ''
print(f"{op['name']:>35} | {op['count']:>5} | {op['cuda_time_ms']:>10.3f} | {shapes_str}")
# --- Print top non-matmul ops ---
print()
print("=" * 130)
print("TOP NON-MATMUL OPS (sorted by CUDA time)")
print("=" * 130)
print(f"{'Op':>40} | {'Count':>5} | {'CUDA(ms)':>10} | Shapes")
print("-" * 130)
for op in other_ops[:20]:
shapes_str = op['input_shapes'][:60] if op['input_shapes'] else ''
print(f"{op['name']:>40} | {op['count']:>5} | {op['cuda_time_ms']:>10.3f} | {shapes_str}")
# --- FlopCounterMode per-operator breakdown ---
print()
print("=" * 130)
print("FLOPS BY ATen OPERATOR (FlopCounterMode)")
print("=" * 130)
print(f"{'ATen Op':>25} | {'GFLOPS':>12} | {'% of Total':>10}")
print("-" * 55)
sorted_flop_ops = sorted(flop_by_op.items(), key=lambda x: x[1], reverse=True)
for op_name, flops in sorted_flop_ops:
gflops = flops / 1e9
pct = flops / total_counted_flops * 100 if total_counted_flops > 0 else 0
print(f"{op_name:>25} | {gflops:>12.2f} | {pct:>9.1f}%")
# --- FlopCounterMode per-module breakdown ---
if flop_by_module:
print()
print("=" * 130)
print("FLOPS BY MODULE (FlopCounterMode)")
print("=" * 130)
print(f"{'Module':>60} | {'GFLOPS':>12} | {'% of Total':>10}")
print("-" * 90)
sorted_modules = sorted(flop_by_module.items(), key=lambda x: x[1], reverse=True)
for mod_name, flops in sorted_modules[:30]:
gflops = flops / 1e9
pct = flops / total_counted_flops * 100 if total_counted_flops > 0 else 0
name_str = mod_name[-57:] if len(mod_name) > 57 else mod_name
print(f"{name_str:>60} | {gflops:>12.2f} | {pct:>9.1f}%")
# --- Summary ---
print()
print("=" * 130)
print("SUMMARY")
print("=" * 130)
print(f" Total CUDA time: {total_cuda_ms:.1f} ms")
print(f" Matmul CUDA time: {total_matmul_ms:.1f} ms ({total_matmul_ms/total_cuda_ms*100:.1f}%)")
print(f" Non-matmul CUDA time: {total_cuda_ms - total_matmul_ms:.1f} ms ({(total_cuda_ms-total_matmul_ms)/total_cuda_ms*100:.1f}%)")
print(f" Total FLOPS (FlopCounter): {total_counted_flops/1e9:.2f} GFLOPS")
if total_matmul_ms > 0 and total_counted_flops > 0:
avg_tflops = total_counted_flops / (total_matmul_ms / 1000.0) / 1e12
avg_util = avg_tflops / PEAK_BF16_TFLOPS * 100
overall_tflops = total_counted_flops / (total_cuda_ms / 1000.0) / 1e12
overall_util = overall_tflops / PEAK_BF16_TFLOPS * 100
print(f" Matmul throughput: {avg_tflops:.2f} TFLOPS/s ({avg_util:.1f}% of BF16 peak)")
print(f" Overall throughput: {overall_tflops:.2f} TFLOPS/s ({overall_util:.1f}% of BF16 peak)")
print(f" GPU peak (BF16): {PEAK_BF16_TFLOPS} TFLOPS")
def main():
    """Command-line entry point.

    Patches the norm layers and parses ``--ckpt_path`` / ``--config``. It then
    loads the model, profiles one UNet forward pass, counts its FLOPs and
    prints the report.
    """
    patch_norm_bypass_autocast()
    parser = argparse.ArgumentParser()
    for flag in ("--ckpt_path", "--config"):
        parser.add_argument(flag, type=str, required=True)
    args = parser.parse_args()

    print("Loading model...")
    model = load_model(args)

    # Unpacked as (B, C, T, H, W) by build_call_kwargs.
    noise_shape = [1, 4, 16, 40, 64]
    print(f"Profiling UNet forward pass with shape {noise_shape}...")
    prof = profile_one_step(model, noise_shape)

    print("Counting FLOPS with FlopCounterMode...")
    flop_counter = count_flops(model, noise_shape)

    print_report(prof, flop_counter)


if __name__ == '__main__':
    main()