ddim.py — switch torch.float16 to torch.bfloat16; fixes a dtype mismatch
attention.py — all 4 softmax call sites are wrapped in torch.amp.autocast('cuda', enabled=False), preventing autocast from promoting bf16 to fp32
This commit is contained in:
@@ -135,6 +135,21 @@ def apply_precision_settings(model: nn.Module, args: argparse.Namespace) -> nn.M
|
||||
return model
|
||||
|
||||
|
||||
def apply_torch_compile(model, hot_indices=(5, 8, 9), mode="default"):
    """Compile ``ResBlock._forward`` in the hottest UNet output blocks.

    Wraps each ResBlock's ``_forward`` with ``torch.compile`` so the
    backend can fuse operators inside those blocks. The model is
    modified in place and also returned for call-chaining.

    Args:
        model: Wrapper object whose ``model.model.diffusion_model``
            attribute is the UNet holding ``output_blocks``.
        hot_indices: Indices into ``output_blocks`` to compile. Defaults
            to ``(5, 8, 9)`` — presumably the empirically hottest blocks;
            confirm against profiling data.
        mode: ``torch.compile`` mode string. New backward-compatible
            parameter; the default reproduces the original behavior.

    Returns:
        The same ``model`` object, with selected ResBlocks compiled.
    """
    # Project-local import kept inside the function so the module loads
    # even when torch.compile support is not exercised.
    from unifolm_wma.modules.networks.wma_model import ResBlock

    unet = model.model.diffusion_model
    compiled = 0
    for idx in hot_indices:
        for layer in unet.output_blocks[idx]:
            if isinstance(layer, ResBlock):
                # Compile the inner _forward rather than forward();
                # NOTE(review): assumes forward() delegates to _forward —
                # verify in wma_model.ResBlock.
                layer._forward = torch.compile(layer._forward, mode=mode)
                compiled += 1
    print(f" ✓ torch.compile: {compiled} ResBlocks in output_blocks{list(hot_indices)}")
    return model
|
||||
|
||||
|
||||
def write_video(video_path: str, stacked_frames: list, fps: int) -> None:
|
||||
"""Save a list of frames to a video file.
|
||||
|
||||
@@ -601,6 +616,9 @@ def run_inference(args: argparse.Namespace, gpu_num: int, gpu_no: int) -> None:
|
||||
# Apply precision settings before moving to GPU
|
||||
model = apply_precision_settings(model, args)
|
||||
|
||||
# Compile hot ResBlocks for operator fusion
|
||||
apply_torch_compile(model)
|
||||
|
||||
# Export precision-converted checkpoint if requested
|
||||
if args.export_precision_ckpt:
|
||||
export_path = args.export_precision_ckpt
|
||||
|
||||
Reference in New Issue
Block a user