全链路 bf16 混合精度修正与 UNet FLOPS profiling

- GroupNorm/LayerNorm bypass autocast,消除 bf16→fp32→bf16 转换开销
  - DDIM 调度系数 cast 到输入 dtype,attention mask 直接用 bf16 分配
  - alphas_cumprod 提升到 float64 保证数值精度
  - SinusoidalPosEmb 输出 dtype跟随模型精度
  - 新增 profile_unet.py 脚本及FLOPS 分析结果
  - 启用 TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL
  - case1 PSNR: 30.45 → 30.24(bf16 精度预期内波动)
This commit is contained in:
2026-02-08 16:01:30 +00:00
parent 75c798ded0
commit f86ab51a04
11 changed files with 464 additions and 30 deletions

View File

@@ -8,12 +8,14 @@ class SinusoidalPosEmb(nn.Module):
def __init__(self, dim):
super().__init__()
self.dim = dim
# Dummy buffer so .to(dtype) propagates to this module
self.register_buffer('_dtype_buf', torch.zeros(1), persistent=False)
def forward(self, x):
device = x.device
half_dim = self.dim // 2
emb = math.log(10000) / (half_dim - 1)
emb = torch.exp(torch.arange(half_dim, device=device) * -emb)
emb = x[:, None] * emb[None, :]
emb = x.float()[:, None] * emb[None, :]
emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
return emb
return emb.to(self._dtype_buf.dtype)

View File

@@ -36,7 +36,7 @@ class DDIMSampler(object):
alphas_cumprod = self.model.alphas_cumprod
assert alphas_cumprod.shape[
0] == self.ddpm_num_timesteps, 'alphas have to be defined for each timestep'
to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.model
to_torch = lambda x: x.clone().detach().to(torch.float64).to(self.model
.device)
if self.model.use_dynamic_rescale:
@@ -376,10 +376,10 @@ class DDIMSampler(object):
sigmas = self.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas
# Use 0-d tensors directly (already on device); broadcasting handles shape
a_t = alphas[index]
a_prev = alphas_prev[index]
sigma_t = sigmas[index]
sqrt_one_minus_at = sqrt_one_minus_alphas[index]
a_t = alphas[index].to(x.dtype)
a_prev = alphas_prev[index].to(x.dtype)
sigma_t = sigmas[index].to(x.dtype)
sqrt_one_minus_at = sqrt_one_minus_alphas[index].to(x.dtype)
if self.model.parameterization != "v":
pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()

View File

@@ -402,7 +402,7 @@ class CrossAttention(nn.Module):
col_indices = torch.arange(l2, device=target_device)
mask_2d = col_indices.unsqueeze(0) >= start_positions.unsqueeze(1)
mask = mask_2d.unsqueeze(1).expand(b, l1, l2)
attn_mask = torch.zeros(b, l1, l2, dtype=torch.float, device=target_device)
attn_mask = torch.zeros(b, l1, l2, dtype=torch.bfloat16, device=target_device)
attn_mask[mask] = float('-inf')
self._attn_mask_aa_cache_key = cache_key

View File

@@ -422,7 +422,7 @@ class WMAModel(nn.Module):
self.temporal_attention = temporal_attention
time_embed_dim = model_channels * 4
self.use_checkpoint = use_checkpoint
self.dtype = torch.float16 if use_fp16 else torch.float32
self.dtype = torch.float16 if use_fp16 else torch.bfloat16
temporal_self_att_only = True
self.addition_attention = addition_attention
self.temporal_length = temporal_length

View File

@@ -7,7 +7,9 @@
#
# thanks!
import torch
import torch.nn as nn
import torch.nn.functional as F
from unifolm_wma.utils.utils import instantiate_from_config
@@ -78,7 +80,11 @@ def nonlinearity(type='silu'):
class GroupNormSpecific(nn.GroupNorm):
def forward(self, x):
return super().forward(x.float()).type(x.dtype)
with torch.amp.autocast('cuda', enabled=False):
return F.group_norm(x, self.num_groups,
self.weight.to(x.dtype) if self.weight is not None else None,
self.bias.to(x.dtype) if self.bias is not None else None,
self.eps)
def normalization(channels, num_groups=32):