ddim.py — torch.float16 → torch.bfloat16,修复 dtype 不匹配

attention.py — 4 处 softmax 都包裹了 torch.amp.autocast('cuda', enabled=False),阻止 autocast 将 bf16 提升到 fp32
This commit is contained in:
2026-02-08 17:02:05 +00:00
parent f86ab51a04
commit 7338cc384a
6 changed files with 59 additions and 21 deletions

View File

@@ -135,6 +135,21 @@ def apply_precision_settings(model: nn.Module, args: argparse.Namespace) -> nn.M
return model
def apply_torch_compile(model, hot_indices=(5, 8, 9)):
    """Speed up the UNet by compiling ResBlock forward passes with torch.compile.

    Walks the output_blocks listed in *hot_indices* (the hottest blocks by
    profiling) and replaces each contained ResBlock's ``_forward`` with a
    compiled version so the backend can fuse its operators.

    Args:
        model: wrapper whose ``model.model.diffusion_model`` is the UNet.
        hot_indices: indices into ``output_blocks`` to compile (default (5, 8, 9)).

    Returns:
        The same *model* object, mutated in place.
    """
    # Project-local import kept inside the function to avoid a hard
    # module-level dependency when this helper is unused.
    from unifolm_wma.modules.networks.wma_model import ResBlock

    diffusion_unet = model.model.diffusion_model
    compiled = 0
    for block_idx in hot_indices:
        for sublayer in diffusion_unet.output_blocks[block_idx]:
            if isinstance(sublayer, ResBlock):
                # Patch the bound method with its compiled counterpart.
                sublayer._forward = torch.compile(sublayer._forward, mode="default")
                compiled += 1
    print(f"  ✓ torch.compile: {compiled} ResBlocks in output_blocks{list(hot_indices)}")
    return model
def write_video(video_path: str, stacked_frames: list, fps: int) -> None:
"""Save a list of frames to a video file.
@@ -601,6 +616,9 @@ def run_inference(args: argparse.Namespace, gpu_num: int, gpu_no: int) -> None:
# Apply precision settings before moving to GPU
model = apply_precision_settings(model, args)
# Compile hot ResBlocks for operator fusion
apply_torch_compile(model)
# Export precision-converted checkpoint if requested
if args.export_precision_ckpt:
export_path = args.export_precision_ckpt