全链路 bf16 混合精度修正与 UNet FLOPS profiling
- GroupNorm/LayerNorm bypass autocast,消除 bf16→fp32→bf16 转换开销 - DDIM 调度系数 cast 到输入 dtype,attention mask 直接用 bf16 分配 - alphas_cumprod 提升到 float64 保证数值精度 - SinusoidalPosEmb 输出 dtype跟随模型精度 - 新增 profile_unet.py 脚本及FLOPS 分析结果 - 启用 TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL - case1 PSNR: 30.45 → 30.24(bf16 精度预期内波动)
This commit is contained in:
@@ -25,6 +25,31 @@ from PIL import Image
|
||||
|
||||
from unifolm_wma.models.samplers.ddim import DDIMSampler
|
||||
from unifolm_wma.utils.utils import instantiate_from_config
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
def patch_norm_bypass_autocast():
|
||||
"""Monkey-patch GroupNorm and LayerNorm to bypass autocast's fp32 policy.
|
||||
This eliminates bf16->fp32->bf16 dtype conversions during UNet forward."""
|
||||
|
||||
def _group_norm_forward(self, x):
|
||||
with torch.amp.autocast('cuda', enabled=False):
|
||||
return F.group_norm(
|
||||
x, self.num_groups,
|
||||
self.weight.to(x.dtype) if self.weight is not None else None,
|
||||
self.bias.to(x.dtype) if self.bias is not None else None,
|
||||
self.eps)
|
||||
|
||||
def _layer_norm_forward(self, x):
|
||||
with torch.amp.autocast('cuda', enabled=False):
|
||||
return F.layer_norm(
|
||||
x, self.normalized_shape,
|
||||
self.weight.to(x.dtype) if self.weight is not None else None,
|
||||
self.bias.to(x.dtype) if self.bias is not None else None,
|
||||
self.eps)
|
||||
|
||||
torch.nn.GroupNorm.forward = _group_norm_forward
|
||||
torch.nn.LayerNorm.forward = _layer_norm_forward
|
||||
|
||||
|
||||
def get_device_from_parameters(module: nn.Module) -> torch.device:
|
||||
@@ -62,7 +87,7 @@ def apply_precision_settings(model: nn.Module, args: argparse.Namespace) -> nn.M
|
||||
model.diffusion_autocast_dtype = torch.bfloat16
|
||||
print(" ✓ Diffusion model weights converted to bfloat16")
|
||||
else:
|
||||
model.diffusion_autocast_dtype = None
|
||||
model.diffusion_autocast_dtype = torch.bfloat16
|
||||
print(" ✓ Diffusion model using fp32")
|
||||
|
||||
# 2. Set Projector precision
|
||||
@@ -98,6 +123,15 @@ def apply_precision_settings(model: nn.Module, args: argparse.Namespace) -> nn.M
|
||||
else:
|
||||
print(" ✓ VAE kept in fp32 for best quality")
|
||||
|
||||
# 5. Safety net: ensure no fp32 parameters remain when all components are bf16
|
||||
if args.diffusion_dtype == "bf16":
|
||||
fp32_params = [(n, p) for n, p in model.named_parameters() if p.dtype == torch.float32]
|
||||
if fp32_params:
|
||||
print(f" ⚠ Found {len(fp32_params)} fp32 params, converting to bf16")
|
||||
for name, param in fp32_params:
|
||||
param.data = param.data.to(torch.bfloat16)
|
||||
print(" ✓ All parameters converted to bfloat16")
|
||||
|
||||
return model
|
||||
|
||||
|
||||
@@ -942,6 +976,7 @@ def get_parser():
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
patch_norm_bypass_autocast()
|
||||
parser = get_parser()
|
||||
args = parser.parse_args()
|
||||
seed = args.seed
|
||||
|
||||
Reference in New Issue
Block a user