End-to-end bf16 mixed-precision fixes and UNet FLOPS profiling

- GroupNorm/LayerNorm now bypass autocast, eliminating the bf16→fp32→bf16 conversion overhead (see the sketch after this list)
- DDIM scheduler coefficients are cast to the input dtype; the attention mask is allocated directly in bf16
- alphas_cumprod is promoted to float64 to preserve numerical precision
- SinusoidalPosEmb output dtype follows the model precision (sketch below)
- Added the profile_unet.py script and FLOPS analysis results (sketch below)
- Enabled TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL
- case1 PSNR: 30.45 → 30.24 (expected fluctuation within bf16 precision)
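
A minimal sketch of the autocast bypass for normalization layers; the class name and wiring below are hypothetical, not the repo's actual code. With autocast disabled, a bf16 input is normalized in bf16 directly instead of being upcast to fp32 and cast back:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class GroupNormNoAutocast(nn.GroupNorm):
    """Hypothetical sketch: GroupNorm that opts out of autocast."""
    def forward(self, x):
        # With autocast off, group_norm runs in x's dtype (bf16 end to end)
        # instead of autocast's fp32 fallback for normalization ops.
        with torch.autocast(device_type=x.device.type, enabled=False):
            weight = self.weight.to(x.dtype) if self.weight is not None else None
            bias = self.bias.to(x.dtype) if self.bias is not None else None
            return F.group_norm(x, self.num_groups, weight, bias, self.eps)
```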
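
A sketch of keeping the timestep embedding in the model precision; the module below is an assumed shape of SinusoidalPosEmb, not copied from the repo. Frequencies and sin/cos are computed in fp32 for stability, and only the final output is cast to the working dtype:

```python
import math
import torch
import torch.nn as nn

class SinusoidalPosEmb(nn.Module):
    # Assumed sketch: output dtype follows the model precision (e.g. bf16).
    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, t, dtype=torch.bfloat16):
        half = self.dim // 2
        # Compute frequencies and sin/cos in fp32 for numerical stability.
        freqs = torch.exp(
            -math.log(10000.0)
            * torch.arange(half, device=t.device, dtype=torch.float32) / half
        )
        args = t.float()[:, None] * freqs[None, :]
        # Cast only the final embedding so it matches the model dtype.
        return torch.cat([args.sin(), args.cos()], dim=-1).to(dtype)
```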
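
A rough sketch of what profile_unet.py might do, using torch.profiler's built-in FLOP counting; unet, sample, timesteps, and context are placeholder names, not identifiers from this repo:

```python
import torch
from torch.profiler import profile, ProfilerActivity

def profile_unet(unet, sample, timesteps, context):
    # Count per-operator FLOPs for one UNet forward pass under bf16 autocast.
    with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
                 with_flops=True) as prof:
        with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
            with torch.no_grad():
                unet(sample, timesteps, context)
    print(prof.key_averages().table(sort_by="flops", row_limit=20))
```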
2026-02-08 16:01:30 +00:00
parent 75c798ded0
commit f86ab51a04
11 changed files with 464 additions and 30 deletions

@@ -36,7 +36,7 @@ class DDIMSampler(object):
         alphas_cumprod = self.model.alphas_cumprod
         assert alphas_cumprod.shape[
             0] == self.ddpm_num_timesteps, 'alphas have to be defined for each timestep'
-        to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.model
+        to_torch = lambda x: x.clone().detach().to(torch.float64).to(self.model
             .device)
         if self.model.use_dynamic_rescale:
@@ -376,10 +376,10 @@ class DDIMSampler(object):
         sigmas = self.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas
         # Use 0-d tensors directly (already on device); broadcasting handles shape
-        a_t = alphas[index]
-        a_prev = alphas_prev[index]
-        sigma_t = sigmas[index]
-        sqrt_one_minus_at = sqrt_one_minus_alphas[index]
+        a_t = alphas[index].to(x.dtype)
+        a_prev = alphas_prev[index].to(x.dtype)
+        sigma_t = sigmas[index].to(x.dtype)
+        sqrt_one_minus_at = sqrt_one_minus_alphas[index].to(x.dtype)
         if self.model.parameterization != "v":
             pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()