ddim.py — torch.float16 → torch.bfloat16,修复 dtype 不匹配

attention.py — 4 处 softmax 都包裹了 torch.amp.autocast('cuda', enabled=False),阻止 autocast 将 bf16 提升到 fp32
This commit is contained in:
2026-02-08 17:02:05 +00:00
parent f86ab51a04
commit 7338cc384a
6 changed files with 59 additions and 21 deletions

View File

@@ -50,6 +50,20 @@ PEAK_BF16_TFLOPS = 61.0
PEAK_FP32_TFLOPS = 30.5
def apply_torch_compile(model, hot_indices=(5, 8, 9)):
    """Compile ResBlock._forward in the hottest output_blocks for operator fusion.

    Args:
        model: Object exposing ``model.model.diffusion_model.output_blocks``.
        hot_indices: Indices into ``output_blocks``; every ``ResBlock`` found
            in those blocks has its ``_forward`` wrapped by ``torch.compile``.

    Returns:
        The same ``model`` object, patched in place, so the call can be used
        either as a statement or as ``model = apply_torch_compile(model)``.

    Raises:
        IndexError: If an entry of ``hot_indices`` is out of range for
            ``output_blocks`` (propagated from the indexing below).
    """
    # Imported lazily so this module can be loaded without the model package.
    from unifolm_wma.modules.networks.wma_model import ResBlock

    unet = model.model.diffusion_model
    compiled = 0
    for idx in hot_indices:
        block = unet.output_blocks[idx]
        for layer in block:
            if isinstance(layer, ResBlock):
                # Instance-level override: only this layer's _forward is
                # replaced; other ResBlock instances are left untouched.
                layer._forward = torch.compile(layer._forward, mode="default")
                compiled += 1
    print(f" ✓ torch.compile: {compiled} ResBlocks in output_blocks{list(hot_indices)}")
    return model
def load_model(args):
config = OmegaConf.load(args.config)
config['model']['params']['wma_config']['params']['use_checkpoint'] = False
@@ -62,6 +76,7 @@ def load_model(args):
model.eval()
model.model.to(torch.bfloat16)
apply_torch_compile(model)
model = model.cuda()
return model

View File

@@ -135,6 +135,21 @@ def apply_precision_settings(model: nn.Module, args: argparse.Namespace) -> nn.M
return model
def apply_torch_compile(model, hot_indices=(5, 8, 9)):
    """Wrap ``ResBlock._forward`` with ``torch.compile`` in selected output blocks.

    Walks ``model.model.diffusion_model.output_blocks`` at each index in
    ``hot_indices`` and, for every ``ResBlock`` layer found there, replaces
    the instance's ``_forward`` with its compiled counterpart (operator
    fusion on the hottest blocks). Prints how many layers were patched.

    Args:
        model: Object exposing ``model.model.diffusion_model.output_blocks``.
        hot_indices: Indices into ``output_blocks`` to process.

    Returns:
        The same ``model`` object, modified in place.
    """
    from unifolm_wma.modules.networks.wma_model import ResBlock

    output_blocks = model.model.diffusion_model.output_blocks
    compiled = 0
    for block_index in hot_indices:
        for module in output_blocks[block_index]:
            if not isinstance(module, ResBlock):
                continue
            # Override on the instance only; the class method is untouched.
            module._forward = torch.compile(module._forward, mode="default")
            compiled += 1

    print(f" ✓ torch.compile: {compiled} ResBlocks in output_blocks{list(hot_indices)}")
    return model
def write_video(video_path: str, stacked_frames: list, fps: int) -> None:
"""Save a list of frames to a video file.
@@ -601,6 +616,9 @@ def run_inference(args: argparse.Namespace, gpu_num: int, gpu_no: int) -> None:
# Apply precision settings before moving to GPU
model = apply_precision_settings(model, args)
# Compile hot ResBlocks for operator fusion
apply_torch_compile(model)
# Export precision-converted checkpoint if requested
if args.export_precision_ckpt:
export_path = args.export_precision_ckpt