End-to-end bf16 mixed-precision fixes and UNet FLOPS profiling

- GroupNorm/LayerNorm bypass autocast, eliminating bf16→fp32→bf16 conversion overhead
- Cast DDIM scheduler coefficients to the input dtype; allocate the attention mask directly in bf16
- Promote alphas_cumprod to float64 to preserve numerical precision
- SinusoidalPosEmb output dtype now follows the model precision
- Add profile_unet.py script and FLOPS analysis results
- Enable TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL
- case1 PSNR: 30.45 → 30.24 (expected fluctuation within bf16 precision)
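The first bullet can be sketched as follows. Under `torch.autocast`, normalization layers are upcast to fp32 by default, so a bf16 activation incurs a bf16→fp32 cast on the way in and an fp32→bf16 cast on the way out. A subclass that disables autocast around the op and casts its parameters to the input dtype keeps the whole call in bf16. This is a minimal sketch under that assumption; the class name `BF16GroupNorm` is hypothetical, not the commit's actual identifier.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class BF16GroupNorm(nn.GroupNorm):
    """GroupNorm that stays in the input dtype instead of autocast's fp32.

    Hypothetical helper: disabling autocast for this op avoids the
    bf16 -> fp32 -> bf16 round trip that autocast would otherwise insert.
    """
    def forward(self, x):
        with torch.autocast(device_type=x.device.type, enabled=False):
            return F.group_norm(
                x, self.num_groups,
                self.weight.to(x.dtype), self.bias.to(x.dtype), self.eps)

x = torch.randn(2, 8, 4, 4, dtype=torch.bfloat16)
norm = BF16GroupNorm(num_groups=2, num_channels=8)
y = norm(x)  # output dtype matches the bf16 input, no intermediate fp32 cast
```

The same pattern would apply to LayerNorm via `F.layer_norm`; only the functional call and parameter names change.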
2026-02-08 16:01:30 +00:00
parent 75c798ded0
commit f86ab51a04
11 changed files with 464 additions and 30 deletions


@@ -402,7 +402,7 @@ class CrossAttention(nn.Module):
         col_indices = torch.arange(l2, device=target_device)
         mask_2d = col_indices.unsqueeze(0) >= start_positions.unsqueeze(1)
         mask = mask_2d.unsqueeze(1).expand(b, l1, l2)
-        attn_mask = torch.zeros(b, l1, l2, dtype=torch.float, device=target_device)
+        attn_mask = torch.zeros(b, l1, l2, dtype=torch.bfloat16, device=target_device)
         attn_mask[mask] = float('-inf')
         self._attn_mask_aa_cache_key = cache_key