ddim.py — switch torch.float16 → torch.bfloat16, fixing a dtype mismatch
attention.py — wrap all 4 softmax call sites in torch.amp.autocast('cuda', enabled=False), preventing autocast from promoting bf16 to fp32
This commit is contained in:
@@ -50,6 +50,20 @@ PEAK_BF16_TFLOPS = 61.0
|
||||
PEAK_FP32_TFLOPS = 30.5
|
||||
|
||||
|
||||
def apply_torch_compile(model, hot_indices=(5, 8, 9), mode="default"):
    """Compile ResBlock._forward in the hottest output_blocks for operator fusion.

    Args:
        model: Model wrapper whose ``model.model.diffusion_model`` is the
            UNet holding ``output_blocks`` (assumed layout — confirm against
            the wma model definition).
        hot_indices: Indices into ``unet.output_blocks`` whose ResBlocks
            should be compiled. Defaults to the measured hot blocks (5, 8, 9).
        mode: ``torch.compile`` mode string; the default ``"default"``
            preserves the original behavior.

    Returns:
        int: Number of ResBlocks whose ``_forward`` was compiled.
    """
    # Imported lazily to avoid a hard dependency at module import time.
    from unifolm_wma.modules.networks.wma_model import ResBlock

    unet = model.model.diffusion_model
    compiled = 0
    for idx in hot_indices:
        block = unet.output_blocks[idx]
        for layer in block:
            # Only ResBlock layers are compiled; other layer types (e.g.
            # attention) are intentionally left uncompiled.
            if isinstance(layer, ResBlock):
                layer._forward = torch.compile(layer._forward, mode=mode)
                compiled += 1
    print(f" ✓ torch.compile: {compiled} ResBlocks in output_blocks{list(hot_indices)}")
    return compiled
|
||||
|
||||
|
||||
def load_model(args):
|
||||
config = OmegaConf.load(args.config)
|
||||
config['model']['params']['wma_config']['params']['use_checkpoint'] = False
|
||||
@@ -62,6 +76,7 @@ def load_model(args):
|
||||
|
||||
model.eval()
|
||||
model.model.to(torch.bfloat16)
|
||||
apply_torch_compile(model)
|
||||
model = model.cuda()
|
||||
return model
|
||||
|
||||
|
||||
Reference in New Issue
Block a user