DDIM loop 内小张量分配优化,attention mask 缓存到 GPU,加速30s左右
This commit is contained in:
@@ -99,6 +99,7 @@ class CrossAttention(nn.Module):
|
||||
self.agent_state_context_len = agent_state_context_len
|
||||
self.agent_action_context_len = agent_action_context_len
|
||||
self.cross_attention_scale_learnable = cross_attention_scale_learnable
|
||||
self._attn_mask_cache = {}
|
||||
if self.image_cross_attention:
|
||||
self.to_k_ip = nn.Linear(context_dim, inner_dim, bias=False)
|
||||
self.to_v_ip = nn.Linear(context_dim, inner_dim, bias=False)
|
||||
@@ -275,7 +276,8 @@ class CrossAttention(nn.Module):
|
||||
attn_mask_aa = self._get_attn_mask_aa(x.shape[0],
|
||||
q.shape[1],
|
||||
k_aa.shape[1],
|
||||
block_size=16).to(k_aa.device)
|
||||
block_size=16,
|
||||
device=k_aa.device)
|
||||
else:
|
||||
if not spatial_self_attn:
|
||||
assert 1 > 2, ">>> ERROR: you should never go into here ..."
|
||||
@@ -386,14 +388,26 @@ class CrossAttention(nn.Module):
|
||||
|
||||
return self.to_out(out)
|
||||
|
||||
def _get_attn_mask_aa(self,
                      b,
                      l1,
                      l2,
                      block_size=16,
                      device=None):
    """Build (and memoize) a block-causal additive attention mask of shape (b, l1, l2).

    For batch row ``i``, every key position ``>= ((i % block_size) + 1) * (l2 // block_size)``
    is masked with ``-inf``; earlier positions stay ``0.0``. The result is cached in
    ``self._attn_mask_cache`` keyed by ``(b, l1, l2, block_size, device)`` so the mask is
    built on-device once instead of being re-allocated (and re-transferred) on every
    DDIM denoising step.

    Args:
        b: batch size — number of independent mask rows.
        l1: query sequence length.
        l2: key sequence length; assumed divisible by ``block_size`` — TODO confirm
            against callers (a remainder is silently dropped by the floor division).
        block_size: number of distinct visibility groups cycled over the batch rows.
        device: target device for the mask; defaults to the device of ``self.to_q``'s
            weight when not given.

    Returns:
        A ``float32`` tensor of shape ``(b, l1, l2)`` holding ``0.0`` (attend) or
        ``-inf`` (blocked), suitable for additive use on attention logits.
    """
    if device is None:
        device = self.to_q.weight.device
    # str(device) normalizes torch.device vs. plain-string inputs into one hashable key.
    cache_key = (b, l1, l2, block_size, str(device))
    if cache_key in self._attn_mask_cache:
        return self._attn_mask_cache[cache_key]
    num_token = l2 // block_size
    # Row i may attend to key positions [0, ((i % block_size) + 1) * num_token).
    start_positions = ((torch.arange(b, device=device) % block_size) +
                       1) * num_token
    col_indices = torch.arange(l2, device=device)
    # (b, l2) boolean mask: True where the key column is beyond row i's horizon.
    mask_2d = col_indices.unsqueeze(0) >= start_positions.unsqueeze(1)
    # Broadcast the per-row key mask across every query position.
    mask = mask_2d.unsqueeze(1).expand(b, l1, l2)
    attn_mask = torch.zeros_like(mask, dtype=torch.float32)
    attn_mask[mask] = float('-inf')
    # NOTE(review): the cache grows unboundedly across distinct shapes/devices —
    # fine for a fixed-shape DDIM loop, but worth bounding if shapes vary at runtime.
    self._attn_mask_cache[cache_key] = attn_mask
    return attn_mask
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user