ddim.py — torch.float16 → torch.bfloat16, fixing a dtype mismatch

attention.py — all 4 softmax calls are wrapped in torch.amp.autocast('cuda', enabled=False), preventing autocast from promoting bf16 to fp32
This commit is contained in:
2026-02-08 17:02:05 +00:00
parent f86ab51a04
commit 7338cc384a
6 changed files with 59 additions and 21 deletions

View File

@@ -209,9 +209,9 @@ class DDIMSampler(object):
if precision is not None:
if precision == 16:
img = img.to(dtype=torch.float16)
action = action.to(dtype=torch.float16)
state = state.to(dtype=torch.float16)
img = img.to(dtype=torch.bfloat16)
action = action.to(dtype=torch.bfloat16)
state = state.to(dtype=torch.bfloat16)
if timesteps is None:
timesteps = self.ddpm_num_timesteps if ddim_use_original_steps else self.ddim_timesteps

View File

@@ -173,7 +173,8 @@ class CrossAttention(nn.Module):
sim.masked_fill_(~(mask > 0.5), max_neg_value)
# attention, what we cannot get enough of
sim = sim.softmax(dim=-1)
with torch.amp.autocast('cuda', enabled=False):
sim = sim.softmax(dim=-1)
out = torch.einsum('b i j, b j d -> b i d', sim, v)
if self.relative_position:
@@ -190,7 +191,8 @@ class CrossAttention(nn.Module):
sim_ip = torch.einsum('b i d, b j d -> b i j', q,
k_ip) * self.scale
del k_ip
sim_ip = sim_ip.softmax(dim=-1)
with torch.amp.autocast('cuda', enabled=False):
sim_ip = sim_ip.softmax(dim=-1)
out_ip = torch.einsum('b i j, b j d -> b i d', sim_ip, v_ip)
out_ip = rearrange(out_ip, '(b h) n d -> b n (h d)', h=h)
@@ -201,7 +203,8 @@ class CrossAttention(nn.Module):
sim_as = torch.einsum('b i d, b j d -> b i j', q,
k_as) * self.scale
del k_as
sim_as = sim_as.softmax(dim=-1)
with torch.amp.autocast('cuda', enabled=False):
sim_as = sim_as.softmax(dim=-1)
out_as = torch.einsum('b i j, b j d -> b i d', sim_as, v_as)
out_as = rearrange(out_as, '(b h) n d -> b n (h d)', h=h)
@@ -212,7 +215,8 @@ class CrossAttention(nn.Module):
sim_aa = torch.einsum('b i d, b j d -> b i j', q,
k_aa) * self.scale
del k_aa
sim_aa = sim_aa.softmax(dim=-1)
with torch.amp.autocast('cuda', enabled=False):
sim_aa = sim_aa.softmax(dim=-1)
out_aa = torch.einsum('b i j, b j d -> b i d', sim_aa, v_aa)
out_aa = rearrange(out_aa, '(b h) n d -> b n (h d)', h=h)