全链路 bf16 混合精度修正与 UNet FLOPS profiling
- GroupNorm/LayerNorm bypass autocast,消除 bf16→fp32→bf16 转换开销
- DDIM 调度系数 cast 到输入 dtype,attention mask 直接用 bf16 分配
- alphas_cumprod 提升到 float64 保证数值精度
- SinusoidalPosEmb 输出 dtype 跟随模型精度
- 新增 profile_unet.py 脚本及 FLOPS 分析结果
- 启用 TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL
- case1 PSNR: 30.45 → 30.24(bf16 精度预期内波动)
This commit is contained in:
272
scripts/evaluation/profile_unet.py
Normal file
272
scripts/evaluation/profile_unet.py
Normal file
@@ -0,0 +1,272 @@
|
||||
"""
|
||||
Profile one DDIM sampling iteration to capture all matmul/attention ops,
|
||||
their matrix sizes, wall time, and compute utilization.
|
||||
|
||||
Uses torch.profiler for CUDA timing and FlopCounterMode for accurate
|
||||
FLOPS counting (works on ROCm where Tensile kernels don't report FLOPS).
|
||||
|
||||
Usage:
|
||||
TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL=1 CUDA_VISIBLE_DEVICES=0 \
|
||||
python scripts/evaluation/profile_unet.py \
|
||||
--ckpt_path ckpts/unifolm_wma_dual_mix_bf16.ckpt \
|
||||
--config configs/inference/world_model_interaction.yaml
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from collections import OrderedDict, defaultdict
|
||||
from omegaconf import OmegaConf
|
||||
from torch.utils.flop_counter import FlopCounterMode
|
||||
from unifolm_wma.utils.utils import instantiate_from_config
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
def patch_norm_bypass_autocast():
    """Replace GroupNorm/LayerNorm forwards so they run in the input dtype.

    Autocast's default policy promotes normalization to fp32, which inserts
    bf16->fp32->bf16 casts around every norm layer.  The patched forwards
    disable autocast locally and cast the affine parameters to the input's
    dtype instead.  The patch is process-wide (it rebinds the class-level
    ``forward``), so every norm instance in the process is affected.
    """

    def _norm_in_input_dtype(norm_fn, x, shape_arg, weight, bias, eps):
        # Run the functional norm outside autocast, with affine params cast
        # to match x (they may be stored in a different dtype).
        w = None if weight is None else weight.to(x.dtype)
        b = None if bias is None else bias.to(x.dtype)
        with torch.amp.autocast('cuda', enabled=False):
            return norm_fn(x, shape_arg, w, b, eps)

    def _group_norm_forward(self, x):
        return _norm_in_input_dtype(
            F.group_norm, x, self.num_groups, self.weight, self.bias,
            self.eps)

    def _layer_norm_forward(self, x):
        return _norm_in_input_dtype(
            F.layer_norm, x, self.normalized_shape, self.weight, self.bias,
            self.eps)

    torch.nn.GroupNorm.forward = _group_norm_forward
    torch.nn.LayerNorm.forward = _layer_norm_forward
|
||||
|
||||
|
||||
# --- W7900D theoretical peak (TFLOPS) ---
# Denominators for the utilization percentages printed by print_report().
PEAK_BF16_TFLOPS = 61.0  # bf16 peak; used for matmul/overall utilization
PEAK_FP32_TFLOPS = 30.5  # fp32 peak; not referenced elsewhere in this script
|
||||
|
||||
|
||||
def load_model(args):
    """Build the model from config, load checkpoint weights, and place the
    diffusion wrapper in bf16 on the GPU.

    ``use_checkpoint`` is forced off; NOTE(review): presumably so profiling
    sees plain forward kernels rather than the checkpointing wrapper —
    confirm against the wma_config semantics.
    """
    cfg = OmegaConf.load(args.config)
    cfg['model']['params']['wma_config']['params']['use_checkpoint'] = False
    model = instantiate_from_config(cfg.model)

    ckpt = torch.load(args.ckpt_path, map_location="cpu")
    # Lightning-style checkpoints nest the weights under "state_dict".
    model.load_state_dict(ckpt.get("state_dict", ckpt), strict=True)

    model.eval()
    # Only the diffusion wrapper (model.model) is converted to bf16.
    model.model.to(torch.bfloat16)
    return model.cuda()
|
||||
|
||||
|
||||
def build_call_kwargs(model, noise_shape):
    """Create random tensors shaped like the hybrid-conditioning forward inputs.

    Values are random — only shapes, dtypes, and device placement matter for
    profiling.  ``noise_shape`` is (B, C, T, H, W), e.g. [1, 4, 16, 40, 64].
    """
    device = next(model.parameters()).device
    B, C, T, H, W = noise_shape

    def bf16(*shape):
        # Float inputs are allocated directly in bf16 on the model's device.
        return torch.randn(*shape, device=device, dtype=torch.bfloat16)

    # NOTE(review): the trailing True/False flags mirror the sampler's call
    # site — confirm their meaning against the wrapper's forward signature.
    context_action = [bf16(B, 3, 2, 320, 512), bf16(B, 2, 16), True, False]

    return dict(
        x=bf16(B, C, T, H, W),
        x_action=bf16(B, 16, 16),
        x_state=bf16(B, 16, 16),
        t=torch.tensor([500], device=device, dtype=torch.long),
        c_concat=[bf16(B, C, T, H, W)],
        c_crossattn=[bf16(B, 351, 1024)],
        c_crossattn_action=context_action,
        s=torch.tensor([1], device=device, dtype=torch.long),
    )
|
||||
|
||||
|
||||
def profile_one_step(model, noise_shape):
    """Profile a single UNet forward pass; return the torch.profiler handle.

    Two untimed warmup passes run first so one-time kernel setup does not
    pollute the measured iteration.  Requires a CUDA device.
    """
    unet = model.model
    kwargs = build_call_kwargs(model, noise_shape)

    with torch.no_grad(), torch.autocast('cuda', dtype=torch.bfloat16):
        for _ in range(2):  # warmup
            unet(**kwargs)
        torch.cuda.synchronize()

        profiler = torch.profiler.profile(
            activities=[torch.profiler.ProfilerActivity.CUDA],
            record_shapes=True,
            with_flops=True,
        )
        with profiler as prof:
            unet(**kwargs)
            torch.cuda.synchronize()

    return prof
|
||||
|
||||
|
||||
def count_flops(model, noise_shape):
    """Count forward-pass FLOPS with FlopCounterMode; return the counter.

    FlopCounterMode intercepts ATen ops on the Python side, so it works on
    ROCm where Tensile GEMM kernels do not report FLOPS to the profiler.
    """
    unet = model.model
    kwargs = build_call_kwargs(model, noise_shape)

    counter = FlopCounterMode(display=False)
    with torch.no_grad(), torch.autocast('cuda', dtype=torch.bfloat16):
        with counter:
            unet(**kwargs)
        torch.cuda.synchronize()

    return counter
|
||||
|
||||
|
||||
def print_report(prof, flop_counter):
    """Parse profiler results and print a structured report with accurate FLOPS.

    Args:
        prof: a ``torch.profiler.profile`` result (anything with
            ``key_averages()``).
        flop_counter: a ``FlopCounterMode`` after a forward pass (exposes
            ``flop_counts`` and ``get_total_flops()``).

    Fixes:
        * The per-module table now actually skips the "Global" pseudo-entry,
          as the comment always claimed — "Global" is the whole-model total,
          not a module, and previously dominated the table.
        * Percentage lines guard against ``total_cuda_ms == 0`` (e.g. an
          empty profile) instead of raising ZeroDivisionError.
    """

    def _pct(part, whole):
        # Safe percentage: an empty profile yields 0.0 instead of crashing.
        return part / whole * 100 if whole > 0 else 0.0

    events = prof.key_averages()

    # --- Extract per-operator FLOPS from FlopCounterMode ---
    # flop_counts is {module_name: {op_name: count}}; use only "Global" for
    # the per-op totals to avoid parent/child double-counting.
    flop_by_op = {}
    flop_by_module = {}
    if hasattr(flop_counter, 'flop_counts'):
        global_ops = flop_counter.flop_counts.get("Global", {})
        for op_name, flop_count in global_ops.items():
            key = str(op_name).split('.')[-1]
            flop_by_op[key] = flop_by_op.get(key, 0) + flop_count

        # Per-module: skip the "Global" grand-total entry.
        for module_name, op_dict in flop_counter.flop_counts.items():
            if module_name == "Global":
                continue
            module_total = sum(op_dict.values())
            if module_total > 0:
                flop_by_module[module_name] = module_total

    total_counted_flops = flop_counter.get_total_flops()

    # --- Split kernel events into matmul-like vs everything else ---
    matmul_ops = []
    other_ops = []
    for evt in events:
        if evt.device_time_total <= 0:
            continue
        name = evt.key
        is_matmul = any(k in name.lower() for k in
                        ['mm', 'gemm', 'addmm', 'bmm', 'einsum', 'dot', 'linear'])
        entry = {
            'name': name,
            'input_shapes': str(evt.input_shapes) if evt.input_shapes else '',
            'cuda_time_ms': evt.device_time_total / 1000.0,  # profiler reports us
            'count': evt.count,
            'flops': evt.flops if evt.flops else 0,
        }
        (matmul_ops if is_matmul else other_ops).append(entry)

    # Sort by CUDA time, largest first
    matmul_ops.sort(key=lambda x: x['cuda_time_ms'], reverse=True)
    other_ops.sort(key=lambda x: x['cuda_time_ms'], reverse=True)

    total_cuda_ms = sum(e['cuda_time_ms'] for e in matmul_ops + other_ops)
    total_matmul_ms = sum(e['cuda_time_ms'] for e in matmul_ops)

    # --- Print matmul ops ---
    print("=" * 130)
    print("MATMUL / LINEAR OPS (sorted by CUDA time)")
    print("=" * 130)
    print(f"{'Op':>35} | {'Count':>5} | {'CUDA(ms)':>10} | Shapes")
    print("-" * 130)
    for op in matmul_ops:
        shapes_str = op['input_shapes'][:60] if op['input_shapes'] else ''
        print(f"{op['name']:>35} | {op['count']:>5} | {op['cuda_time_ms']:>10.3f} | {shapes_str}")

    # --- Print top non-matmul ops ---
    print()
    print("=" * 130)
    print("TOP NON-MATMUL OPS (sorted by CUDA time)")
    print("=" * 130)
    print(f"{'Op':>40} | {'Count':>5} | {'CUDA(ms)':>10} | Shapes")
    print("-" * 130)
    for op in other_ops[:20]:
        shapes_str = op['input_shapes'][:60] if op['input_shapes'] else ''
        print(f"{op['name']:>40} | {op['count']:>5} | {op['cuda_time_ms']:>10.3f} | {shapes_str}")

    # --- FlopCounterMode per-operator breakdown ---
    print()
    print("=" * 130)
    print("FLOPS BY ATen OPERATOR (FlopCounterMode)")
    print("=" * 130)
    print(f"{'ATen Op':>25} | {'GFLOPS':>12} | {'% of Total':>10}")
    print("-" * 55)
    for op_name, flops in sorted(flop_by_op.items(), key=lambda x: x[1], reverse=True):
        print(f"{op_name:>25} | {flops / 1e9:>12.2f} | {_pct(flops, total_counted_flops):>9.1f}%")

    # --- FlopCounterMode per-module breakdown ---
    if flop_by_module:
        print()
        print("=" * 130)
        print("FLOPS BY MODULE (FlopCounterMode)")
        print("=" * 130)
        print(f"{'Module':>60} | {'GFLOPS':>12} | {'% of Total':>10}")
        print("-" * 90)
        sorted_modules = sorted(flop_by_module.items(), key=lambda x: x[1], reverse=True)
        for mod_name, flops in sorted_modules[:30]:
            # Truncate long dotted module paths from the left (tail is the
            # informative part).
            name_str = mod_name[-57:] if len(mod_name) > 57 else mod_name
            print(f"{name_str:>60} | {flops / 1e9:>12.2f} | {_pct(flops, total_counted_flops):>9.1f}%")

    # --- Summary ---
    print()
    print("=" * 130)
    print("SUMMARY")
    print("=" * 130)
    print(f" Total CUDA time: {total_cuda_ms:.1f} ms")
    print(f" Matmul CUDA time: {total_matmul_ms:.1f} ms ({_pct(total_matmul_ms, total_cuda_ms):.1f}%)")
    print(f" Non-matmul CUDA time: {total_cuda_ms - total_matmul_ms:.1f} ms ({_pct(total_cuda_ms - total_matmul_ms, total_cuda_ms):.1f}%)")
    print(f" Total FLOPS (FlopCounter): {total_counted_flops/1e9:.2f} GFLOPS")
    if total_matmul_ms > 0 and total_counted_flops > 0:
        # total_cuda_ms >= total_matmul_ms > 0 here, so divisions are safe.
        avg_tflops = total_counted_flops / (total_matmul_ms / 1000.0) / 1e12
        avg_util = avg_tflops / PEAK_BF16_TFLOPS * 100
        overall_tflops = total_counted_flops / (total_cuda_ms / 1000.0) / 1e12
        overall_util = overall_tflops / PEAK_BF16_TFLOPS * 100
        print(f" Matmul throughput: {avg_tflops:.2f} TFLOPS/s ({avg_util:.1f}% of BF16 peak)")
        print(f" Overall throughput: {overall_tflops:.2f} TFLOPS/s ({overall_util:.1f}% of BF16 peak)")
    print(f" GPU peak (BF16): {PEAK_BF16_TFLOPS} TFLOPS")
|
||||
|
||||
|
||||
if __name__ == '__main__':
    # Patch the norm classes before the model is built so every layer picks
    # up the autocast bypass.
    patch_norm_bypass_autocast()

    cli = argparse.ArgumentParser()
    cli.add_argument("--ckpt_path", type=str, required=True)
    cli.add_argument("--config", type=str, required=True)
    cli_args = cli.parse_args()

    print("Loading model...")
    net = load_model(cli_args)

    shape = [1, 4, 16, 40, 64]
    print(f"Profiling UNet forward pass with shape {shape}...")
    timing_prof = profile_one_step(net, shape)

    print("Counting FLOPS with FlopCounterMode...")
    counter = count_flops(net, shape)

    print_report(timing_prof, counter)
|
||||
@@ -25,6 +25,31 @@ from PIL import Image
|
||||
|
||||
from unifolm_wma.models.samplers.ddim import DDIMSampler
|
||||
from unifolm_wma.utils.utils import instantiate_from_config
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
def patch_norm_bypass_autocast():
    """Monkey-patch GroupNorm and LayerNorm to bypass autocast's fp32 policy.
    This eliminates bf16->fp32->bf16 dtype conversions during UNet forward."""
    # NOTE(review): this helper is duplicated in
    # scripts/evaluation/profile_unet.py — consider extracting it to a shared
    # utility module.  The patch rebinds the class-level forward, so it
    # affects every GroupNorm/LayerNorm instance in the process.

    def _group_norm_forward(self, x):
        # Run the norm outside autocast in x's own dtype; affine params are
        # cast to match (they may be stored in a different dtype).
        with torch.amp.autocast('cuda', enabled=False):
            return F.group_norm(
                x, self.num_groups,
                self.weight.to(x.dtype) if self.weight is not None else None,
                self.bias.to(x.dtype) if self.bias is not None else None,
                self.eps)

    def _layer_norm_forward(self, x):
        # Same autocast bypass for LayerNorm.
        with torch.amp.autocast('cuda', enabled=False):
            return F.layer_norm(
                x, self.normalized_shape,
                self.weight.to(x.dtype) if self.weight is not None else None,
                self.bias.to(x.dtype) if self.bias is not None else None,
                self.eps)

    torch.nn.GroupNorm.forward = _group_norm_forward
    torch.nn.LayerNorm.forward = _layer_norm_forward
|
||||
|
||||
|
||||
def get_device_from_parameters(module: nn.Module) -> torch.device:
|
||||
@@ -62,7 +87,7 @@ def apply_precision_settings(model: nn.Module, args: argparse.Namespace) -> nn.M
|
||||
model.diffusion_autocast_dtype = torch.bfloat16
|
||||
print(" ✓ Diffusion model weights converted to bfloat16")
|
||||
else:
|
||||
model.diffusion_autocast_dtype = None
|
||||
model.diffusion_autocast_dtype = torch.bfloat16
|
||||
print(" ✓ Diffusion model using fp32")
|
||||
|
||||
# 2. Set Projector precision
|
||||
@@ -98,6 +123,15 @@ def apply_precision_settings(model: nn.Module, args: argparse.Namespace) -> nn.M
|
||||
else:
|
||||
print(" ✓ VAE kept in fp32 for best quality")
|
||||
|
||||
# 5. Safety net: ensure no fp32 parameters remain when all components are bf16
|
||||
if args.diffusion_dtype == "bf16":
|
||||
fp32_params = [(n, p) for n, p in model.named_parameters() if p.dtype == torch.float32]
|
||||
if fp32_params:
|
||||
print(f" ⚠ Found {len(fp32_params)} fp32 params, converting to bf16")
|
||||
for name, param in fp32_params:
|
||||
param.data = param.data.to(torch.bfloat16)
|
||||
print(" ✓ All parameters converted to bfloat16")
|
||||
|
||||
return model
|
||||
|
||||
|
||||
@@ -942,6 +976,7 @@ def get_parser():
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
patch_norm_bypass_autocast()
|
||||
parser = get_parser()
|
||||
args = parser.parse_args()
|
||||
seed = args.seed
|
||||
|
||||
Reference in New Issue
Block a user