全链路 bf16 混合精度修正与 UNet FLOPS profiling
- GroupNorm/LayerNorm bypass autocast,消除 bf16→fp32→bf16 转换开销 - DDIM 调度系数 cast 到输入 dtype,attention mask 直接用 bf16 分配 - alphas_cumprod 提升到 float64 保证数值精度 - SinusoidalPosEmb 输出 dtype跟随模型精度 - 新增 profile_unet.py 脚本及FLOPS 分析结果 - 启用 TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL - case1 PSNR: 30.45 → 30.24(bf16 精度预期内波动)
This commit is contained in:
121
profile_unet_flops.md
Normal file
121
profile_unet_flops.md
Normal file
@@ -0,0 +1,121 @@
|
||||
TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL=1 CUDA_VISIBLE_DEVICES=0 python scripts/evaluation/profile_unet.py --ckpt_path ckpts/unifolm_wma_dual_mix_bf16.ckpt --config configs/inference/world_model_interaction.yaml
|
||||
|
||||
|
||||
==================================================================================================================================
|
||||
FLOPS BY ATen OPERATOR (FlopCounterMode)
|
||||
==================================================================================================================================
|
||||
ATen Op | GFLOPS | % of Total
|
||||
-------------------------------------------------------
|
||||
convolution | 6185.17 | 46.4%
|
||||
addmm | 4411.17 | 33.1%
|
||||
mm | 1798.34 | 13.5%
|
||||
bmm | 949.54 | 7.1%
|
||||
|
||||
==================================================================================================================================
|
||||
FLOPS BY MODULE (FlopCounterMode)
|
||||
==================================================================================================================================
|
||||
Module | GFLOPS | % of Total
|
||||
------------------------------------------------------------------------------------------
|
||||
Global | 13344.23 | 100.0%
|
||||
DiffusionWrapper | 13344.23 | 100.0%
|
||||
DiffusionWrapper.diffusion_model | 13344.23 | 100.0%
|
||||
DiffusionWrapper.diffusion_model.output_blocks.8 | 997.87 | 7.5%
|
||||
DiffusionWrapper.diffusion_model.output_blocks.5 | 992.91 | 7.4%
|
||||
DiffusionWrapper.diffusion_model.output_blocks.9 | 941.81 | 7.1%
|
||||
DiffusionWrapper.diffusion_model.output_blocks.10 | 857.93 | 6.4%
|
||||
DiffusionWrapper.diffusion_model.output_blocks.11 | 857.93 | 6.4%
|
||||
DiffusionWrapper.diffusion_model.output_blocks.6 | 821.71 | 6.2%
|
||||
DiffusionWrapper.diffusion_model.input_blocks.1 | 765.65 | 5.7%
|
||||
DiffusionWrapper.diffusion_model.input_blocks.2 | 765.65 | 5.7%
|
||||
DiffusionWrapper.diffusion_model.output_blocks.7 | 737.82 | 5.5%
|
||||
DiffusionWrapper.diffusion_model.output_blocks.3 | 732.87 | 5.5%
|
||||
DiffusionWrapper.diffusion_model.output_blocks.4 | 732.87 | 5.5%
|
||||
DiffusionWrapper.diffusion_model.input_blocks.5 | 645.55 | 4.8%
|
||||
DiffusionWrapper.diffusion_model.input_blocks.8 | 640.59 | 4.8%
|
||||
DiffusionWrapper.diffusion_model.input_blocks.4 | 611.99 | 4.6%
|
||||
DiffusionWrapper.diffusion_model.input_blocks.7 | 607.04 | 4.5%
|
||||
DiffusionWrapper.diffusion_model.init_attn.0 | 459.02 | 3.4%
|
||||
DiffusionWrapper.diffusion_model.init_attn | 459.02 | 3.4%
|
||||
nWrapper.diffusion_model.init_attn.0.transformer_blocks.0 | 432.18 | 3.2%
|
||||
DiffusionWrapper.diffusion_model.output_blocks.6.0 | 427.85 | 3.2%
|
||||
DiffusionWrapper.diffusion_model.output_blocks.9.0 | 427.83 | 3.2%
|
||||
DiffusionWrapper.diffusion_model.output_blocks.3.0 | 343.99 | 2.6%
|
||||
DiffusionWrapper.diffusion_model.output_blocks.4.0 | 343.99 | 2.6%
|
||||
DiffusionWrapper.diffusion_model.output_blocks.7.0 | 343.96 | 2.6%
|
||||
DiffusionWrapper.diffusion_model.output_blocks.10.0 | 343.95 | 2.6%
|
||||
DiffusionWrapper.diffusion_model.output_blocks.11.0 | 343.95 | 2.6%
|
||||
DiffusionWrapper.diffusion_model.input_blocks.1.1 | 327.75 | 2.5%
|
||||
DiffusionWrapper.diffusion_model.input_blocks.2.1 | 327.75 | 2.5%
|
||||
|
||||
==================================================================================================================================
|
||||
SUMMARY
|
||||
==================================================================================================================================
|
||||
Total CUDA time: 761.4 ms
|
||||
Matmul CUDA time: 404.2 ms (53.1%)
|
||||
Non-matmul CUDA time: 357.1 ms (46.9%)
|
||||
Total FLOPS (FlopCounter): 13344.23 GFLOPS
|
||||
Matmul throughput: 33.01 TFLOPS/s (54.1% of BF16 peak)
|
||||
Overall throughput: 17.53 TFLOPS/s (28.7% of BF16 peak)
|
||||
GPU peak (BF16): 61.0 TFLOPS
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
==================================================================================================================================
|
||||
FLOPS BY ATen OPERATOR (FlopCounterMode)
|
||||
==================================================================================================================================
|
||||
ATen Op | GFLOPS | % of Total
|
||||
-------------------------------------------------------
|
||||
convolution | 6185.17 | 46.4%
|
||||
addmm | 4411.17 | 33.1%
|
||||
mm | 1798.34 | 13.5%
|
||||
bmm | 949.54 | 7.1%
|
||||
|
||||
==================================================================================================================================
|
||||
FLOPS BY MODULE (FlopCounterMode)
|
||||
==================================================================================================================================
|
||||
Module | GFLOPS | % of Total
|
||||
------------------------------------------------------------------------------------------
|
||||
DiffusionWrapper | 13344.23 | 100.0%
|
||||
Global | 13344.23 | 100.0%
|
||||
DiffusionWrapper.diffusion_model | 13344.23 | 100.0%
|
||||
DiffusionWrapper.diffusion_model.output_blocks.8 | 997.87 | 7.5%
|
||||
DiffusionWrapper.diffusion_model.output_blocks.5 | 992.91 | 7.4%
|
||||
DiffusionWrapper.diffusion_model.output_blocks.9 | 941.81 | 7.1%
|
||||
DiffusionWrapper.diffusion_model.output_blocks.10 | 857.93 | 6.4%
|
||||
DiffusionWrapper.diffusion_model.output_blocks.11 | 857.93 | 6.4%
|
||||
DiffusionWrapper.diffusion_model.output_blocks.6 | 821.71 | 6.2%
|
||||
DiffusionWrapper.diffusion_model.input_blocks.1 | 765.65 | 5.7%
|
||||
DiffusionWrapper.diffusion_model.input_blocks.2 | 765.65 | 5.7%
|
||||
DiffusionWrapper.diffusion_model.output_blocks.7 | 737.82 | 5.5%
|
||||
DiffusionWrapper.diffusion_model.output_blocks.3 | 732.87 | 5.5%
|
||||
DiffusionWrapper.diffusion_model.output_blocks.4 | 732.87 | 5.5%
|
||||
DiffusionWrapper.diffusion_model.input_blocks.5 | 645.55 | 4.8%
|
||||
DiffusionWrapper.diffusion_model.input_blocks.8 | 640.59 | 4.8%
|
||||
DiffusionWrapper.diffusion_model.input_blocks.4 | 611.99 | 4.6%
|
||||
DiffusionWrapper.diffusion_model.input_blocks.7 | 607.04 | 4.5%
|
||||
DiffusionWrapper.diffusion_model.init_attn | 459.02 | 3.4%
|
||||
DiffusionWrapper.diffusion_model.init_attn.0 | 459.02 | 3.4%
|
||||
nWrapper.diffusion_model.init_attn.0.transformer_blocks.0 | 432.18 | 3.2%
|
||||
DiffusionWrapper.diffusion_model.output_blocks.6.0 | 427.85 | 3.2%
|
||||
DiffusionWrapper.diffusion_model.output_blocks.9.0 | 427.83 | 3.2%
|
||||
DiffusionWrapper.diffusion_model.output_blocks.3.0 | 343.99 | 2.6%
|
||||
DiffusionWrapper.diffusion_model.output_blocks.4.0 | 343.99 | 2.6%
|
||||
DiffusionWrapper.diffusion_model.output_blocks.7.0 | 343.96 | 2.6%
|
||||
DiffusionWrapper.diffusion_model.output_blocks.10.0 | 343.95 | 2.6%
|
||||
DiffusionWrapper.diffusion_model.output_blocks.11.0 | 343.95 | 2.6%
|
||||
DiffusionWrapper.diffusion_model.input_blocks.1.1 | 327.75 | 2.5%
|
||||
DiffusionWrapper.diffusion_model.input_blocks.2.1 | 327.75 | 2.5%
|
||||
|
||||
==================================================================================================================================
|
||||
SUMMARY
|
||||
==================================================================================================================================
|
||||
Total CUDA time: 707.1 ms
|
||||
Matmul CUDA time: 403.1 ms (57.0%)
|
||||
Non-matmul CUDA time: 304.0 ms (43.0%)
|
||||
Total FLOPS (FlopCounter): 13344.23 GFLOPS
|
||||
Matmul throughput: 33.11 TFLOPS/s (54.3% of BF16 peak)
|
||||
Overall throughput: 18.87 TFLOPS/s (30.9% of BF16 peak)
|
||||
GPU peak (BF16): 61.0 TFLOPS
|
||||
(unifolm-wma) ASC1637@wx-ms-w7900d-0033:/mnt/ASC1637/unifolm-world-model-action$
|
||||
272
scripts/evaluation/profile_unet.py
Normal file
272
scripts/evaluation/profile_unet.py
Normal file
@@ -0,0 +1,272 @@
|
||||
"""
|
||||
Profile one DDIM sampling iteration to capture all matmul/attention ops,
|
||||
their matrix sizes, wall time, and compute utilization.
|
||||
|
||||
Uses torch.profiler for CUDA timing and FlopCounterMode for accurate
|
||||
FLOPS counting (works on ROCm where Tensile kernels don't report FLOPS).
|
||||
|
||||
Usage:
|
||||
TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL=1 CUDA_VISIBLE_DEVICES=0 \
|
||||
python scripts/evaluation/profile_unet.py \
|
||||
--ckpt_path ckpts/unifolm_wma_dual_mix_bf16.ckpt \
|
||||
--config configs/inference/world_model_interaction.yaml
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from collections import OrderedDict, defaultdict
|
||||
from omegaconf import OmegaConf
|
||||
from torch.utils.flop_counter import FlopCounterMode
|
||||
from unifolm_wma.utils.utils import instantiate_from_config
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
def patch_norm_bypass_autocast():
|
||||
"""Monkey-patch GroupNorm and LayerNorm to bypass autocast's fp32 policy."""
|
||||
|
||||
def _group_norm_forward(self, x):
|
||||
with torch.amp.autocast('cuda', enabled=False):
|
||||
return F.group_norm(
|
||||
x, self.num_groups,
|
||||
self.weight.to(x.dtype) if self.weight is not None else None,
|
||||
self.bias.to(x.dtype) if self.bias is not None else None,
|
||||
self.eps)
|
||||
|
||||
def _layer_norm_forward(self, x):
|
||||
with torch.amp.autocast('cuda', enabled=False):
|
||||
return F.layer_norm(
|
||||
x, self.normalized_shape,
|
||||
self.weight.to(x.dtype) if self.weight is not None else None,
|
||||
self.bias.to(x.dtype) if self.bias is not None else None,
|
||||
self.eps)
|
||||
|
||||
torch.nn.GroupNorm.forward = _group_norm_forward
|
||||
torch.nn.LayerNorm.forward = _layer_norm_forward
|
||||
|
||||
|
||||
# --- W7900D theoretical peak (TFLOPS) ---
|
||||
PEAK_BF16_TFLOPS = 61.0
|
||||
PEAK_FP32_TFLOPS = 30.5
|
||||
|
||||
|
||||
def load_model(args):
|
||||
config = OmegaConf.load(args.config)
|
||||
config['model']['params']['wma_config']['params']['use_checkpoint'] = False
|
||||
model = instantiate_from_config(config.model)
|
||||
|
||||
state_dict = torch.load(args.ckpt_path, map_location="cpu")
|
||||
if "state_dict" in state_dict:
|
||||
state_dict = state_dict["state_dict"]
|
||||
model.load_state_dict(state_dict, strict=True)
|
||||
|
||||
model.eval()
|
||||
model.model.to(torch.bfloat16)
|
||||
model = model.cuda()
|
||||
return model
|
||||
|
||||
|
||||
def build_call_kwargs(model, noise_shape):
|
||||
"""Build dummy inputs matching the hybrid conditioning forward signature."""
|
||||
device = next(model.parameters()).device
|
||||
B, C, T, H, W = noise_shape # [1, 4, 16, 40, 64]
|
||||
dtype = torch.bfloat16
|
||||
|
||||
x_action = torch.randn(B, 16, 16, device=device, dtype=dtype)
|
||||
x_state = torch.randn(B, 16, 16, device=device, dtype=dtype)
|
||||
timesteps = torch.tensor([500], device=device, dtype=torch.long)
|
||||
context = torch.randn(B, 351, 1024, device=device, dtype=dtype)
|
||||
|
||||
obs_images = torch.randn(B, 3, 2, 320, 512, device=device, dtype=dtype)
|
||||
obs_state = torch.randn(B, 2, 16, device=device, dtype=dtype)
|
||||
context_action = [obs_images, obs_state, True, False]
|
||||
fps = torch.tensor([1], device=device, dtype=torch.long)
|
||||
|
||||
x_raw = torch.randn(B, C, T, H, W, device=device, dtype=dtype)
|
||||
c_concat = [torch.randn(B, C, T, H, W, device=device, dtype=dtype)]
|
||||
|
||||
return dict(
|
||||
x=x_raw, x_action=x_action, x_state=x_state, t=timesteps,
|
||||
c_concat=c_concat, c_crossattn=[context],
|
||||
c_crossattn_action=context_action, s=fps,
|
||||
)
|
||||
|
||||
|
||||
def profile_one_step(model, noise_shape):
|
||||
"""Run one UNet forward pass under torch.profiler for CUDA timing."""
|
||||
diff_wrapper = model.model
|
||||
call_kwargs = build_call_kwargs(model, noise_shape)
|
||||
|
||||
with torch.no_grad(), torch.autocast('cuda', dtype=torch.bfloat16):
|
||||
# Warmup
|
||||
for _ in range(2):
|
||||
_ = diff_wrapper(**call_kwargs)
|
||||
torch.cuda.synchronize()
|
||||
|
||||
with torch.profiler.profile(
|
||||
activities=[torch.profiler.ProfilerActivity.CUDA],
|
||||
record_shapes=True,
|
||||
with_flops=True,
|
||||
) as prof:
|
||||
_ = diff_wrapper(**call_kwargs)
|
||||
torch.cuda.synchronize()
|
||||
|
||||
return prof
|
||||
|
||||
|
||||
def count_flops(model, noise_shape):
|
||||
"""Run one UNet forward pass under FlopCounterMode for accurate FLOPS."""
|
||||
diff_wrapper = model.model
|
||||
call_kwargs = build_call_kwargs(model, noise_shape)
|
||||
|
||||
with torch.no_grad(), torch.autocast('cuda', dtype=torch.bfloat16):
|
||||
flop_counter = FlopCounterMode(display=False)
|
||||
with flop_counter:
|
||||
_ = diff_wrapper(**call_kwargs)
|
||||
torch.cuda.synchronize()
|
||||
|
||||
return flop_counter
|
||||
|
||||
|
||||
def print_report(prof, flop_counter):
|
||||
"""Parse profiler results and print a structured report with accurate FLOPS."""
|
||||
events = prof.key_averages()
|
||||
|
||||
# --- Extract per-operator FLOPS from FlopCounterMode ---
|
||||
# flop_counts is {module_name: {op_name: count}}; use only "Global" to avoid double-counting
|
||||
flop_by_op = {}
|
||||
flop_by_module = {}
|
||||
if hasattr(flop_counter, 'flop_counts'):
|
||||
# Per-op: only from top-level "Global" entry (no parent/child duplication)
|
||||
global_ops = flop_counter.flop_counts.get("Global", {})
|
||||
for op_name, flop_count in global_ops.items():
|
||||
key = str(op_name).split('.')[-1]
|
||||
flop_by_op[key] = flop_by_op.get(key, 0) + flop_count
|
||||
|
||||
# Per-module: collect all, skip "Global" and top-level wrapper duplicates
|
||||
for module_name, op_dict in flop_counter.flop_counts.items():
|
||||
module_total = sum(op_dict.values())
|
||||
if module_total > 0:
|
||||
flop_by_module[module_name] = module_total
|
||||
|
||||
total_counted_flops = flop_counter.get_total_flops()
|
||||
|
||||
# Collect matmul-like ops
|
||||
matmul_ops = []
|
||||
other_ops = []
|
||||
|
||||
for evt in events:
|
||||
if evt.device_time_total <= 0:
|
||||
continue
|
||||
name = evt.key
|
||||
is_matmul = any(k in name.lower() for k in
|
||||
['mm', 'gemm', 'addmm', 'bmm', 'einsum', 'dot', 'linear'])
|
||||
entry = {
|
||||
'name': name,
|
||||
'input_shapes': str(evt.input_shapes) if evt.input_shapes else '',
|
||||
'cuda_time_ms': evt.device_time_total / 1000.0,
|
||||
'count': evt.count,
|
||||
'flops': evt.flops if evt.flops else 0,
|
||||
}
|
||||
if is_matmul:
|
||||
matmul_ops.append(entry)
|
||||
else:
|
||||
other_ops.append(entry)
|
||||
|
||||
# Sort by CUDA time
|
||||
matmul_ops.sort(key=lambda x: x['cuda_time_ms'], reverse=True)
|
||||
other_ops.sort(key=lambda x: x['cuda_time_ms'], reverse=True)
|
||||
|
||||
total_cuda_ms = sum(e['cuda_time_ms'] for e in matmul_ops + other_ops)
|
||||
total_matmul_ms = sum(e['cuda_time_ms'] for e in matmul_ops)
|
||||
# --- Print matmul ops ---
|
||||
print("=" * 130)
|
||||
print("MATMUL / LINEAR OPS (sorted by CUDA time)")
|
||||
print("=" * 130)
|
||||
print(f"{'Op':>35} | {'Count':>5} | {'CUDA(ms)':>10} | Shapes")
|
||||
print("-" * 130)
|
||||
|
||||
for op in matmul_ops:
|
||||
shapes_str = op['input_shapes'][:60] if op['input_shapes'] else ''
|
||||
print(f"{op['name']:>35} | {op['count']:>5} | {op['cuda_time_ms']:>10.3f} | {shapes_str}")
|
||||
|
||||
# --- Print top non-matmul ops ---
|
||||
print()
|
||||
print("=" * 130)
|
||||
print("TOP NON-MATMUL OPS (sorted by CUDA time)")
|
||||
print("=" * 130)
|
||||
print(f"{'Op':>40} | {'Count':>5} | {'CUDA(ms)':>10} | Shapes")
|
||||
print("-" * 130)
|
||||
|
||||
for op in other_ops[:20]:
|
||||
shapes_str = op['input_shapes'][:60] if op['input_shapes'] else ''
|
||||
print(f"{op['name']:>40} | {op['count']:>5} | {op['cuda_time_ms']:>10.3f} | {shapes_str}")
|
||||
|
||||
# --- FlopCounterMode per-operator breakdown ---
|
||||
print()
|
||||
print("=" * 130)
|
||||
print("FLOPS BY ATen OPERATOR (FlopCounterMode)")
|
||||
print("=" * 130)
|
||||
print(f"{'ATen Op':>25} | {'GFLOPS':>12} | {'% of Total':>10}")
|
||||
print("-" * 55)
|
||||
|
||||
sorted_flop_ops = sorted(flop_by_op.items(), key=lambda x: x[1], reverse=True)
|
||||
for op_name, flops in sorted_flop_ops:
|
||||
gflops = flops / 1e9
|
||||
pct = flops / total_counted_flops * 100 if total_counted_flops > 0 else 0
|
||||
print(f"{op_name:>25} | {gflops:>12.2f} | {pct:>9.1f}%")
|
||||
|
||||
# --- FlopCounterMode per-module breakdown ---
|
||||
if flop_by_module:
|
||||
print()
|
||||
print("=" * 130)
|
||||
print("FLOPS BY MODULE (FlopCounterMode)")
|
||||
print("=" * 130)
|
||||
print(f"{'Module':>60} | {'GFLOPS':>12} | {'% of Total':>10}")
|
||||
print("-" * 90)
|
||||
|
||||
sorted_modules = sorted(flop_by_module.items(), key=lambda x: x[1], reverse=True)
|
||||
for mod_name, flops in sorted_modules[:30]:
|
||||
gflops = flops / 1e9
|
||||
pct = flops / total_counted_flops * 100 if total_counted_flops > 0 else 0
|
||||
name_str = mod_name[-57:] if len(mod_name) > 57 else mod_name
|
||||
print(f"{name_str:>60} | {gflops:>12.2f} | {pct:>9.1f}%")
|
||||
|
||||
# --- Summary ---
|
||||
print()
|
||||
print("=" * 130)
|
||||
print("SUMMARY")
|
||||
print("=" * 130)
|
||||
print(f" Total CUDA time: {total_cuda_ms:.1f} ms")
|
||||
print(f" Matmul CUDA time: {total_matmul_ms:.1f} ms ({total_matmul_ms/total_cuda_ms*100:.1f}%)")
|
||||
print(f" Non-matmul CUDA time: {total_cuda_ms - total_matmul_ms:.1f} ms ({(total_cuda_ms-total_matmul_ms)/total_cuda_ms*100:.1f}%)")
|
||||
print(f" Total FLOPS (FlopCounter): {total_counted_flops/1e9:.2f} GFLOPS")
|
||||
if total_matmul_ms > 0 and total_counted_flops > 0:
|
||||
avg_tflops = total_counted_flops / (total_matmul_ms / 1000.0) / 1e12
|
||||
avg_util = avg_tflops / PEAK_BF16_TFLOPS * 100
|
||||
overall_tflops = total_counted_flops / (total_cuda_ms / 1000.0) / 1e12
|
||||
overall_util = overall_tflops / PEAK_BF16_TFLOPS * 100
|
||||
print(f" Matmul throughput: {avg_tflops:.2f} TFLOPS/s ({avg_util:.1f}% of BF16 peak)")
|
||||
print(f" Overall throughput: {overall_tflops:.2f} TFLOPS/s ({overall_util:.1f}% of BF16 peak)")
|
||||
print(f" GPU peak (BF16): {PEAK_BF16_TFLOPS} TFLOPS")
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
patch_norm_bypass_autocast()
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--ckpt_path", type=str, required=True)
|
||||
parser.add_argument("--config", type=str, required=True)
|
||||
args = parser.parse_args()
|
||||
|
||||
print("Loading model...")
|
||||
model = load_model(args)
|
||||
|
||||
noise_shape = [1, 4, 16, 40, 64]
|
||||
|
||||
print(f"Profiling UNet forward pass with shape {noise_shape}...")
|
||||
prof = profile_one_step(model, noise_shape)
|
||||
|
||||
print("Counting FLOPS with FlopCounterMode...")
|
||||
flop_counter = count_flops(model, noise_shape)
|
||||
|
||||
print_report(prof, flop_counter)
|
||||
@@ -25,6 +25,31 @@ from PIL import Image
|
||||
|
||||
from unifolm_wma.models.samplers.ddim import DDIMSampler
|
||||
from unifolm_wma.utils.utils import instantiate_from_config
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
def patch_norm_bypass_autocast():
|
||||
"""Monkey-patch GroupNorm and LayerNorm to bypass autocast's fp32 policy.
|
||||
This eliminates bf16->fp32->bf16 dtype conversions during UNet forward."""
|
||||
|
||||
def _group_norm_forward(self, x):
|
||||
with torch.amp.autocast('cuda', enabled=False):
|
||||
return F.group_norm(
|
||||
x, self.num_groups,
|
||||
self.weight.to(x.dtype) if self.weight is not None else None,
|
||||
self.bias.to(x.dtype) if self.bias is not None else None,
|
||||
self.eps)
|
||||
|
||||
def _layer_norm_forward(self, x):
|
||||
with torch.amp.autocast('cuda', enabled=False):
|
||||
return F.layer_norm(
|
||||
x, self.normalized_shape,
|
||||
self.weight.to(x.dtype) if self.weight is not None else None,
|
||||
self.bias.to(x.dtype) if self.bias is not None else None,
|
||||
self.eps)
|
||||
|
||||
torch.nn.GroupNorm.forward = _group_norm_forward
|
||||
torch.nn.LayerNorm.forward = _layer_norm_forward
|
||||
|
||||
|
||||
def get_device_from_parameters(module: nn.Module) -> torch.device:
|
||||
@@ -62,7 +87,7 @@ def apply_precision_settings(model: nn.Module, args: argparse.Namespace) -> nn.M
|
||||
model.diffusion_autocast_dtype = torch.bfloat16
|
||||
print(" ✓ Diffusion model weights converted to bfloat16")
|
||||
else:
|
||||
model.diffusion_autocast_dtype = None
|
||||
model.diffusion_autocast_dtype = torch.bfloat16
|
||||
print(" ✓ Diffusion model using fp32")
|
||||
|
||||
# 2. Set Projector precision
|
||||
@@ -98,6 +123,15 @@ def apply_precision_settings(model: nn.Module, args: argparse.Namespace) -> nn.M
|
||||
else:
|
||||
print(" ✓ VAE kept in fp32 for best quality")
|
||||
|
||||
# 5. Safety net: ensure no fp32 parameters remain when all components are bf16
|
||||
if args.diffusion_dtype == "bf16":
|
||||
fp32_params = [(n, p) for n, p in model.named_parameters() if p.dtype == torch.float32]
|
||||
if fp32_params:
|
||||
print(f" ⚠ Found {len(fp32_params)} fp32 params, converting to bf16")
|
||||
for name, param in fp32_params:
|
||||
param.data = param.data.to(torch.bfloat16)
|
||||
print(" ✓ All parameters converted to bfloat16")
|
||||
|
||||
return model
|
||||
|
||||
|
||||
@@ -942,6 +976,7 @@ def get_parser():
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
patch_norm_bypass_autocast()
|
||||
parser = get_parser()
|
||||
args = parser.parse_args()
|
||||
seed = args.seed
|
||||
|
||||
@@ -8,12 +8,14 @@ class SinusoidalPosEmb(nn.Module):
|
||||
def __init__(self, dim):
|
||||
super().__init__()
|
||||
self.dim = dim
|
||||
# Dummy buffer so .to(dtype) propagates to this module
|
||||
self.register_buffer('_dtype_buf', torch.zeros(1), persistent=False)
|
||||
|
||||
def forward(self, x):
|
||||
device = x.device
|
||||
half_dim = self.dim // 2
|
||||
emb = math.log(10000) / (half_dim - 1)
|
||||
emb = torch.exp(torch.arange(half_dim, device=device) * -emb)
|
||||
emb = x[:, None] * emb[None, :]
|
||||
emb = x.float()[:, None] * emb[None, :]
|
||||
emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
|
||||
return emb
|
||||
return emb.to(self._dtype_buf.dtype)
|
||||
|
||||
@@ -36,7 +36,7 @@ class DDIMSampler(object):
|
||||
alphas_cumprod = self.model.alphas_cumprod
|
||||
assert alphas_cumprod.shape[
|
||||
0] == self.ddpm_num_timesteps, 'alphas have to be defined for each timestep'
|
||||
to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.model
|
||||
to_torch = lambda x: x.clone().detach().to(torch.float64).to(self.model
|
||||
.device)
|
||||
|
||||
if self.model.use_dynamic_rescale:
|
||||
@@ -376,10 +376,10 @@ class DDIMSampler(object):
|
||||
sigmas = self.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas
|
||||
|
||||
# Use 0-d tensors directly (already on device); broadcasting handles shape
|
||||
a_t = alphas[index]
|
||||
a_prev = alphas_prev[index]
|
||||
sigma_t = sigmas[index]
|
||||
sqrt_one_minus_at = sqrt_one_minus_alphas[index]
|
||||
a_t = alphas[index].to(x.dtype)
|
||||
a_prev = alphas_prev[index].to(x.dtype)
|
||||
sigma_t = sigmas[index].to(x.dtype)
|
||||
sqrt_one_minus_at = sqrt_one_minus_alphas[index].to(x.dtype)
|
||||
|
||||
if self.model.parameterization != "v":
|
||||
pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
|
||||
|
||||
@@ -402,7 +402,7 @@ class CrossAttention(nn.Module):
|
||||
col_indices = torch.arange(l2, device=target_device)
|
||||
mask_2d = col_indices.unsqueeze(0) >= start_positions.unsqueeze(1)
|
||||
mask = mask_2d.unsqueeze(1).expand(b, l1, l2)
|
||||
attn_mask = torch.zeros(b, l1, l2, dtype=torch.float, device=target_device)
|
||||
attn_mask = torch.zeros(b, l1, l2, dtype=torch.bfloat16, device=target_device)
|
||||
attn_mask[mask] = float('-inf')
|
||||
|
||||
self._attn_mask_aa_cache_key = cache_key
|
||||
|
||||
@@ -422,7 +422,7 @@ class WMAModel(nn.Module):
|
||||
self.temporal_attention = temporal_attention
|
||||
time_embed_dim = model_channels * 4
|
||||
self.use_checkpoint = use_checkpoint
|
||||
self.dtype = torch.float16 if use_fp16 else torch.float32
|
||||
self.dtype = torch.float16 if use_fp16 else torch.bfloat16
|
||||
temporal_self_att_only = True
|
||||
self.addition_attention = addition_attention
|
||||
self.temporal_length = temporal_length
|
||||
|
||||
@@ -7,7 +7,9 @@
|
||||
#
|
||||
# thanks!
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from unifolm_wma.utils.utils import instantiate_from_config
|
||||
|
||||
|
||||
@@ -78,7 +80,11 @@ def nonlinearity(type='silu'):
|
||||
class GroupNormSpecific(nn.GroupNorm):
|
||||
|
||||
def forward(self, x):
|
||||
return super().forward(x.float()).type(x.dtype)
|
||||
with torch.amp.autocast('cuda', enabled=False):
|
||||
return F.group_norm(x, self.num_groups,
|
||||
self.weight.to(x.dtype) if self.weight is not None else None,
|
||||
self.bias.to(x.dtype) if self.bias is not None else None,
|
||||
self.eps)
|
||||
|
||||
|
||||
def normalization(channels, num_groups=32):
|
||||
|
||||
@@ -1,12 +1,12 @@
|
||||
2026-02-08 13:59:02.578826: I tensorflow/core/util/port.cc:113] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
|
||||
2026-02-08 13:59:02.581891: I external/local_tsl/tsl/cuda/cudart_stub.cc:31] Could not find cuda drivers on your machine, GPU will not be used.
|
||||
2026-02-08 13:59:02.613088: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
|
||||
2026-02-08 13:59:02.613125: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
|
||||
2026-02-08 13:59:02.614961: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
|
||||
2026-02-08 13:59:02.623180: I external/local_tsl/tsl/cuda/cudart_stub.cc:31] Could not find cuda drivers on your machine, GPU will not be used.
|
||||
2026-02-08 13:59:02.623460: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
|
||||
2026-02-08 15:47:30.035545: I tensorflow/core/util/port.cc:113] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
|
||||
2026-02-08 15:47:30.038628: I external/local_tsl/tsl/cuda/cudart_stub.cc:31] Could not find cuda drivers on your machine, GPU will not be used.
|
||||
2026-02-08 15:47:30.069635: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
|
||||
2026-02-08 15:47:30.069671: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
|
||||
2026-02-08 15:47:30.071534: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
|
||||
2026-02-08 15:47:30.080021: I external/local_tsl/tsl/cuda/cudart_stub.cc:31] Could not find cuda drivers on your machine, GPU will not be used.
|
||||
2026-02-08 15:47:30.080300: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
|
||||
To enable the following instructions: AVX2 AVX512F AVX512_VNNI AVX512_BF16 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
|
||||
2026-02-08 13:59:03.306638: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT
|
||||
2026-02-08 15:47:30.746161: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT
|
||||
[rank: 0] Global seed set to 123
|
||||
/mnt/ASC1637/miniconda3/envs/unifolm-wma/lib/python3.10/site-packages/kornia/feature/lightglue.py:44: FutureWarning: `torch.cuda.amp.custom_fwd(args...)` is deprecated. Please use `torch.amp.custom_fwd(args..., device_type='cuda')` instead.
|
||||
@torch.cuda.amp.custom_fwd(cast_inputs=torch.float32)
|
||||
@@ -23,7 +23,7 @@ INFO:root:Loading pretrained ViT-H-14 weights (laion2b_s32b_b79k).
|
||||
INFO:root:Loaded ViT-H-14 model config.
|
||||
DEBUG:urllib3.connectionpool:https://hf-mirror.com:443 "HEAD /laion/CLIP-ViT-H-14-laion2B-s32B-b79K/resolve/main/open_clip_pytorch_model.bin HTTP/1.1" 302 0
|
||||
INFO:root:Loading pretrained ViT-H-14 weights (laion2b_s32b_b79k).
|
||||
/mnt/ASC1637/unifolm-world-model-action/scripts/evaluation/world_model_interaction.py:149: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.
|
||||
/mnt/ASC1637/unifolm-world-model-action/scripts/evaluation/world_model_interaction.py:183: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.
|
||||
state_dict = torch.load(ckpt, map_location="cpu")
|
||||
>>> model checkpoint loaded.
|
||||
>>> Load pre-trained model ...
|
||||
@@ -36,6 +36,8 @@ INFO:root:Loading pretrained ViT-H-14 weights (laion2b_s32b_b79k).
|
||||
✓ Projectors converted to bfloat16
|
||||
✓ Encoders converted to bfloat16
|
||||
✓ VAE converted to bfloat16
|
||||
⚠ Found 601 fp32 params, converting to bf16
|
||||
✓ All parameters converted to bfloat16
|
||||
INFO:root:***** Configing Data *****
|
||||
>>> unitree_z1_stackbox: 1 data samples loaded.
|
||||
>>> unitree_z1_stackbox: data stats loaded.
|
||||
@@ -60,10 +62,6 @@ DEBUG:PIL.PngImagePlugin:STREAM b'pHYs' 41 9
|
||||
DEBUG:PIL.PngImagePlugin:STREAM b'IDAT' 62 4096
|
||||
|
||||
0%| | 0/8 [00:00<?, ?it/s]/mnt/ASC1637/miniconda3/envs/unifolm-wma/lib/python3.10/site-packages/torch/nn/functional.py:5501: UserWarning: Attempting to use hipBLASLt on an unsupported architecture! Overriding blas backend to hipblas (Triggered internally at ../aten/src/ATen/Context.cpp:296.)
|
||||
proj = linear(q, w, b)
|
||||
/mnt/ASC1637/miniconda3/envs/unifolm-wma/lib/python3.10/site-packages/torch/nn/functional.py:6278: UserWarning: Flash attention support on Navi31 GPU is still experimental. Enable it with TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL=1. (Triggered internally at ../aten/src/ATen/native/transformers/hip/sdp_utils.cpp:225.)
|
||||
attn_output = scaled_dot_product_attention(
|
||||
/mnt/ASC1637/miniconda3/envs/unifolm-wma/lib/python3.10/site-packages/torch/nn/functional.py:6278: UserWarning: Memory Efficient attention on Navi31 GPU is still experimental. Enable it with TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL=1. (Triggered internally at ../aten/src/ATen/native/transformers/hip/sdp_utils.cpp:269.)
|
||||
proj = linear(q, w, b)
|
||||
>>> Step 0: generating actions ...
|
||||
>>> Step 0: interacting with world model ...
|
||||
@@ -115,7 +113,7 @@ DEBUG:PIL.Image:Importing WmfImagePlugin
|
||||
DEBUG:PIL.Image:Importing WmfImagePlugin
|
||||
DEBUG:PIL.Image:Importing XbmImagePlugin
|
||||
DEBUG:PIL.Image:Importing XpmImagePlugin
|
||||
DEBUG:PIL.Image:Importing XVThumbImagePlugin
|
||||
DEBUG:PIL.Image:Importing XVThumbImagePlugin
|
||||
|
||||
12%|█▎ | 1/8 [01:19<09:17, 79.58s/it]
|
||||
25%|██▌ | 2/8 [02:38<07:54, 79.06s/it]
|
||||
@@ -139,6 +137,6 @@ DEBUG:PIL.Image:Importing XVThumbImagePlugin
|
||||
>>> Step 4: generating actions ...
|
||||
>>> Step 4: interacting with world model ...
|
||||
>>>>>>>>>>>>>>>>>>>>>>>>
|
||||
>>> Step 5: generating actions ...
|
||||
>>> Step 5: interacting with world model ...
|
||||
>>>>>>>>>>>>>>>>>>>>>>>>
|
||||
>>> Step 5: generating actions ...
|
||||
>>> Step 5: interacting with world model ...
|
||||
>>>>>>>>>>>>>>>>>>>>>>>>
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
{
|
||||
"gt_video": "/mnt/ASC1637/unifolm-world-model-action/unitree_z1_dual_arm_cleanup_pencils/case1/output/inference/unitree_z1_dual_arm_cleanup_pencils_case1_amd.mp4",
|
||||
"pred_video": "/mnt/ASC1637/unifolm-world-model-action/unitree_z1_dual_arm_cleanup_pencils/case1/output/inference/0_full_fs4.mp4",
|
||||
"psnr": 30.44844270035179
|
||||
"psnr": 30.24435361473318
|
||||
}
|
||||
@@ -2,7 +2,7 @@ res_dir="unitree_z1_dual_arm_cleanup_pencils/case1"
|
||||
dataset="unitree_z1_dual_arm_cleanup_pencils"
|
||||
|
||||
{
|
||||
time CUDA_VISIBLE_DEVICES=0 python3 scripts/evaluation/world_model_interaction.py \
|
||||
time TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL=1 CUDA_VISIBLE_DEVICES=0 python3 scripts/evaluation/world_model_interaction.py \
|
||||
--seed 123 \
|
||||
--ckpt_path ckpts/unifolm_wma_dual_mix_bf16.ckpt \
|
||||
--config configs/inference/world_model_interaction.yaml \
|
||||
|
||||
Reference in New Issue
Block a user