ddim.py — switch torch.float16 → torch.bfloat16 to fix a dtype mismatch
attention.py — wrap all 4 softmax call sites in torch.amp.autocast('cuda', enabled=False) to keep autocast from promoting bf16 to fp32
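The attention.py change described above can be sketched as follows. This is a minimal illustration, not the repository's actual code: the helper name `attn_softmax` is hypothetical, and it assumes the surrounding attention code runs under an autocast region where softmax would otherwise be computed in fp32.

```python
import torch

def attn_softmax(scores: torch.Tensor) -> torch.Tensor:
    # Hypothetical helper sketching the pattern from the commit message:
    # locally disable autocast so softmax runs in the tensor's own dtype
    # (e.g. bf16) instead of being promoted to fp32 by autocast's policy.
    with torch.amp.autocast('cuda', enabled=False):
        return torch.softmax(scores, dim=-1)
```

With `enabled=False` the context manager is a no-op for dtype casting, so the output dtype matches the input dtype, avoiding the fp32/bf16 mismatch downstream.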
@@ -209,9 +209,9 @@ class DDIMSampler(object):
         if precision is not None:
             if precision == 16:
-                img = img.to(dtype=torch.float16)
-                action = action.to(dtype=torch.float16)
-                state = state.to(dtype=torch.float16)
+                img = img.to(dtype=torch.bfloat16)
+                action = action.to(dtype=torch.bfloat16)
+                state = state.to(dtype=torch.bfloat16)
 
         if timesteps is None:
             timesteps = self.ddpm_num_timesteps if ddim_use_original_steps else self.ddim_timesteps