全链路 bf16 混合精度修正与 UNet FLOPS profiling
- GroupNorm/LayerNorm bypass autocast,消除 bf16→fp32→bf16 转换开销 - DDIM 调度系数 cast 到输入 dtype,attention mask 直接用 bf16 分配 - alphas_cumprod 提升到 float64 保证数值精度 - SinusoidalPosEmb 输出 dtype 跟随模型精度 - 新增 profile_unet.py 脚本及 FLOPS 分析结果 - 启用 TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL - case1 PSNR: 30.45 → 30.24(bf16 精度预期内波动)
This commit is contained in:
@@ -8,12 +8,14 @@ class SinusoidalPosEmb(nn.Module):
|
||||
def __init__(self, dim):
    """Sinusoidal position embedding producing vectors of width ``dim``.

    A non-persistent dummy buffer is registered so that module-wide dtype
    moves (e.g. ``module.to(torch.bfloat16)``) also reach this module,
    letting ``forward`` emit embeddings in the model's working precision.
    """
    super().__init__()
    self.dim = dim
    # Tracks the module's current dtype; excluded from state_dict.
    self.register_buffer('_dtype_buf', torch.zeros(1), persistent=False)
|
||||
def forward(self, x):
    """Compute sinusoidal position embeddings for timesteps ``x``.

    Args:
        x: 1-D tensor of timestep values, shape ``(batch,)``.

    Returns:
        Tensor of shape ``(batch, 2 * (dim // 2))`` — first half ``sin``,
        second half ``cos`` — cast to this module's dtype (tracked via
        ``_dtype_buf``) so it follows the model precision (e.g. bf16).
    """
    device = x.device
    half_dim = self.dim // 2
    emb = math.log(10000) / (half_dim - 1)
    emb = torch.exp(torch.arange(half_dim, device=device) * -emb)
    # Compute the outer product in float32 for numerical accuracy,
    # regardless of the incoming timestep dtype.
    emb = x.float()[:, None] * emb[None, :]
    emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
    # Cast down to the module's working precision on the way out.
    return emb.to(self._dtype_buf.dtype)
|
||||
|
||||
@@ -36,7 +36,7 @@ class DDIMSampler(object):
|
||||
alphas_cumprod = self.model.alphas_cumprod
|
||||
assert alphas_cumprod.shape[
|
||||
0] == self.ddpm_num_timesteps, 'alphas have to be defined for each timestep'
|
||||
to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.model
|
||||
to_torch = lambda x: x.clone().detach().to(torch.float64).to(self.model
|
||||
.device)
|
||||
|
||||
if self.model.use_dynamic_rescale:
|
||||
@@ -376,10 +376,10 @@ class DDIMSampler(object):
|
||||
sigmas = self.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas
|
||||
|
||||
# Use 0-d tensors directly (already on device); broadcasting handles shape
|
||||
a_t = alphas[index]
|
||||
a_prev = alphas_prev[index]
|
||||
sigma_t = sigmas[index]
|
||||
sqrt_one_minus_at = sqrt_one_minus_alphas[index]
|
||||
a_t = alphas[index].to(x.dtype)
|
||||
a_prev = alphas_prev[index].to(x.dtype)
|
||||
sigma_t = sigmas[index].to(x.dtype)
|
||||
sqrt_one_minus_at = sqrt_one_minus_alphas[index].to(x.dtype)
|
||||
|
||||
if self.model.parameterization != "v":
|
||||
pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
|
||||
|
||||
@@ -402,7 +402,7 @@ class CrossAttention(nn.Module):
|
||||
col_indices = torch.arange(l2, device=target_device)
|
||||
mask_2d = col_indices.unsqueeze(0) >= start_positions.unsqueeze(1)
|
||||
mask = mask_2d.unsqueeze(1).expand(b, l1, l2)
|
||||
attn_mask = torch.zeros(b, l1, l2, dtype=torch.float, device=target_device)
|
||||
attn_mask = torch.zeros(b, l1, l2, dtype=torch.bfloat16, device=target_device)
|
||||
attn_mask[mask] = float('-inf')
|
||||
|
||||
self._attn_mask_aa_cache_key = cache_key
|
||||
|
||||
@@ -422,7 +422,7 @@ class WMAModel(nn.Module):
|
||||
self.temporal_attention = temporal_attention
|
||||
time_embed_dim = model_channels * 4
|
||||
self.use_checkpoint = use_checkpoint
|
||||
self.dtype = torch.float16 if use_fp16 else torch.float32
|
||||
self.dtype = torch.float16 if use_fp16 else torch.bfloat16
|
||||
temporal_self_att_only = True
|
||||
self.addition_attention = addition_attention
|
||||
self.temporal_length = temporal_length
|
||||
|
||||
@@ -7,7 +7,9 @@
|
||||
#
|
||||
# thanks!
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from unifolm_wma.utils.utils import instantiate_from_config
|
||||
|
||||
|
||||
@@ -78,7 +80,11 @@ def nonlinearity(type='silu'):
|
||||
class GroupNormSpecific(nn.GroupNorm):
    """GroupNorm that runs with autocast disabled.

    Normalizes directly in the input's dtype (weight/bias are cast to
    match), avoiding the bf16 -> fp32 -> bf16 round-trip that plain
    ``nn.GroupNorm`` incurs under autocast.
    """

    def forward(self, x):
        # Disable autocast so the normalization stays in x.dtype; the
        # pre-change path (super().forward(x.float()).type(x.dtype))
        # forced an fp32 round-trip and is removed.
        with torch.amp.autocast('cuda', enabled=False):
            return F.group_norm(x, self.num_groups,
                                self.weight.to(x.dtype) if self.weight is not None else None,
                                self.bias.to(x.dtype) if self.bias is not None else None,
                                self.eps)
|
||||
|
||||
|
||||
def normalization(channels, num_groups=32):
|
||||
|
||||
Reference in New Issue
Block a user