ddim.py — torch.float16 → torch.bfloat16, fixing a dtype mismatch
attention.py — all 4 softmax calls are wrapped in torch.amp.autocast('cuda', enabled=False), preventing autocast from promoting bf16 to fp32
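As a rough, self-contained sketch of both fixes (the ddim.py hunk is not shown below; the tensor names here are illustrative placeholders, not the real variables in either file):

    import torch

    q = torch.randn(2, 8, 40, device='cuda', dtype=torch.bfloat16)
    k = torch.randn(2, 8, 40, device='cuda', dtype=torch.bfloat16)
    v = torch.randn(2, 8, 40, device='cuda', dtype=torch.bfloat16)

    # ddim.py-style fix: keep sampler buffers in bfloat16 (previously float16)
    # so they match tensors produced under bf16 autocast. `alphas` is an
    # illustrative name only.
    alphas = torch.rand(1000, device='cuda').to(torch.bfloat16)

    with torch.amp.autocast('cuda', dtype=torch.bfloat16):
        sim = torch.einsum('b i d, b j d -> b i j', q, k)
        # attention.py-style fix: autocast would otherwise run softmax in fp32;
        # disabling it here keeps the attention weights in bf16, matching v.
        with torch.amp.autocast('cuda', enabled=False):
            sim = sim.softmax(dim=-1)
        out = torch.einsum('b i j, b j d -> b i d', sim, v)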
@@ -173,7 +173,8 @@ class CrossAttention(nn.Module):
             sim.masked_fill_(~(mask > 0.5), max_neg_value)
 
         # attention, what we cannot get enough of
-        sim = sim.softmax(dim=-1)
+        with torch.amp.autocast('cuda', enabled=False):
+            sim = sim.softmax(dim=-1)
 
         out = torch.einsum('b i j, b j d -> b i d', sim, v)
         if self.relative_position:
@@ -190,7 +191,8 @@ class CrossAttention(nn.Module):
             sim_ip = torch.einsum('b i d, b j d -> b i j', q,
                                   k_ip) * self.scale
             del k_ip
-            sim_ip = sim_ip.softmax(dim=-1)
+            with torch.amp.autocast('cuda', enabled=False):
+                sim_ip = sim_ip.softmax(dim=-1)
             out_ip = torch.einsum('b i j, b j d -> b i d', sim_ip, v_ip)
             out_ip = rearrange(out_ip, '(b h) n d -> b n (h d)', h=h)
 
@@ -201,7 +203,8 @@ class CrossAttention(nn.Module):
             sim_as = torch.einsum('b i d, b j d -> b i j', q,
                                   k_as) * self.scale
             del k_as
-            sim_as = sim_as.softmax(dim=-1)
+            with torch.amp.autocast('cuda', enabled=False):
+                sim_as = sim_as.softmax(dim=-1)
             out_as = torch.einsum('b i j, b j d -> b i d', sim_as, v_as)
             out_as = rearrange(out_as, '(b h) n d -> b n (h d)', h=h)
 
@@ -212,7 +215,8 @@ class CrossAttention(nn.Module):
             sim_aa = torch.einsum('b i d, b j d -> b i j', q,
                                   k_aa) * self.scale
             del k_aa
-            sim_aa = sim_aa.softmax(dim=-1)
+            with torch.amp.autocast('cuda', enabled=False):
+                sim_aa = sim_aa.softmax(dim=-1)
             out_aa = torch.einsum('b i j, b j d -> b i d', sim_aa, v_aa)
             out_aa = rearrange(out_aa, '(b h) n d -> b n (h d)', h=h)
 