全链路 bf16 混合精度修正与 UNet FLOPS profiling
- GroupNorm/LayerNorm bypass autocast,消除 bf16→fp32→bf16 转换开销 - DDIM 调度系数 cast 到输入 dtype,attention mask 直接用 bf16 分配 - alphas_cumprod 提升到 float64 保证数值精度 - SinusoidalPosEmb 输出 dtype跟随模型精度 - 新增 profile_unet.py 脚本及FLOPS 分析结果 - 启用 TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL - case1 PSNR: 30.45 → 30.24(bf16 精度预期内波动)
This commit is contained in:
@@ -8,12 +8,14 @@ class SinusoidalPosEmb(nn.Module):
|
||||
def __init__(self, dim):
|
||||
super().__init__()
|
||||
self.dim = dim
|
||||
# Dummy buffer so .to(dtype) propagates to this module
|
||||
self.register_buffer('_dtype_buf', torch.zeros(1), persistent=False)
|
||||
|
||||
def forward(self, x):
|
||||
device = x.device
|
||||
half_dim = self.dim // 2
|
||||
emb = math.log(10000) / (half_dim - 1)
|
||||
emb = torch.exp(torch.arange(half_dim, device=device) * -emb)
|
||||
emb = x[:, None] * emb[None, :]
|
||||
emb = x.float()[:, None] * emb[None, :]
|
||||
emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
|
||||
return emb
|
||||
return emb.to(self._dtype_buf.dtype)
|
||||
|
||||
Reference in New Issue
Block a user