全链路 bf16 混合精度修正与 UNet FLOPS profiling
- GroupNorm/LayerNorm bypass autocast,消除 bf16→fp32→bf16 转换开销 - DDIM 调度系数 cast 到输入 dtype,attention mask 直接用 bf16 分配 - alphas_cumprod 提升到 float64 保证数值精度 - SinusoidalPosEmb 输出 dtype 跟随模型精度 - 新增 profile_unet.py 脚本及 FLOPS 分析结果 - 启用 TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL - case1 PSNR: 30.45 → 30.24(bf16 精度预期内波动)
This commit is contained in:
@@ -7,7 +7,9 @@
|
||||
#
|
||||
# thanks!
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from unifolm_wma.utils.utils import instantiate_from_config
|
||||
|
||||
|
||||
@@ -78,7 +80,11 @@ def nonlinearity(type='silu'):
|
||||
class GroupNormSpecific(nn.GroupNorm):
    """GroupNorm that always runs outside autocast, in the input's dtype.

    Rationale (from the surrounding commit): the previous implementation
    upcast the input with ``x.float()`` and downcast the result back with
    ``.type(x.dtype)``, paying a bf16 -> fp32 -> bf16 round-trip on every
    call. Instead, we disable autocast for the normalization and cast the
    affine parameters to the input dtype, so the whole op stays in the
    caller's precision (e.g. bf16) with no extra conversions.

    NOTE(review): running GroupNorm natively in bf16 trades a little
    numerical precision for speed — the commit message reports a small,
    expected PSNR shift (30.45 -> 30.24). Confirm this is acceptable for
    the target workloads.
    """

    def forward(self, x):
        """Apply group normalization in ``x.dtype`` with autocast disabled.

        Args:
            x: input tensor of shape ``(N, C, *)`` with ``C`` divisible by
               ``self.num_groups``.

        Returns:
            Tensor of the same shape and dtype as ``x``.
        """
        # Bypass autocast so F.group_norm does not silently promote to fp32;
        # the dead pre-refactor line (`super().forward(x.float()).type(x.dtype)`)
        # left over from the diff has been removed.
        with torch.amp.autocast('cuda', enabled=False):
            return F.group_norm(
                x,
                self.num_groups,
                # Affine params are created in the module's default dtype;
                # cast them to match the input so no mixed-dtype promotion occurs.
                self.weight.to(x.dtype) if self.weight is not None else None,
                self.bias.to(x.dtype) if self.bias is not None else None,
                self.eps)
|
||||
|
||||
|
||||
def normalization(channels, num_groups=32):
|
||||
|
||||
Reference in New Issue
Block a user