ddim.py — torch.float16 → torch.bfloat16, fixing a dtype mismatch

attention.py — all 4 softmax call sites are wrapped in torch.amp.autocast('cuda', enabled=False), preventing autocast from promoting bf16 to fp32
This commit is contained in:
2026-02-08 17:02:05 +00:00
parent f86ab51a04
commit 7338cc384a
6 changed files with 59 additions and 21 deletions

View File

@@ -209,9 +209,9 @@ class DDIMSampler(object):
         if precision is not None:
             if precision == 16:
-                img = img.to(dtype=torch.float16)
-                action = action.to(dtype=torch.float16)
-                state = state.to(dtype=torch.float16)
+                img = img.to(dtype=torch.bfloat16)
+                action = action.to(dtype=torch.bfloat16)
+                state = state.to(dtype=torch.bfloat16)
         if timesteps is None:
             timesteps = self.ddpm_num_timesteps if ddim_use_original_steps else self.ddim_timesteps