添加CrossAttention kv缓存,减少重复计算,提升性能,psnr=31.8022 dB
This commit is contained in:
@@ -97,6 +97,9 @@ class CrossAttention(nn.Module):
|
||||
self.text_context_len = text_context_len
|
||||
self.agent_state_context_len = agent_state_context_len
|
||||
self.agent_action_context_len = agent_action_context_len
|
||||
self._kv_cache = {}
|
||||
self._kv_cache_enabled = False
|
||||
|
||||
self.cross_attention_scale_learnable = cross_attention_scale_learnable
|
||||
if self.image_cross_attention:
|
||||
self.to_k_ip = nn.Linear(context_dim, inner_dim, bias=False)
|
||||
@@ -243,7 +246,22 @@ class CrossAttention(nn.Module):
|
||||
q = self.to_q(x)
|
||||
context = default(context, x)
|
||||
|
||||
if self.image_cross_attention and not spatial_self_attn:
|
||||
use_cache = self._kv_cache_enabled and not spatial_self_attn
|
||||
cache_hit = use_cache and len(self._kv_cache) > 0
|
||||
|
||||
if cache_hit:
|
||||
# Reuse cached K/V (already in (b*h, n, d) shape)
|
||||
k = self._kv_cache['k']
|
||||
v = self._kv_cache['v']
|
||||
if 'k_ip' in self._kv_cache:
|
||||
k_ip = self._kv_cache['k_ip']
|
||||
v_ip = self._kv_cache['v_ip']
|
||||
k_as = self._kv_cache['k_as']
|
||||
v_as = self._kv_cache['v_as']
|
||||
k_aa = self._kv_cache['k_aa']
|
||||
v_aa = self._kv_cache['v_aa']
|
||||
q = rearrange(q, 'b n (h d) -> (b h) n d', h=h)
|
||||
elif self.image_cross_attention and not spatial_self_attn:
|
||||
context_agent_state = context[:, :self.agent_state_context_len, :]
|
||||
context_agent_action = context[:,
|
||||
self.agent_state_context_len:self.
|
||||
@@ -266,20 +284,39 @@ class CrossAttention(nn.Module):
|
||||
v_as = self.to_v_as(context_agent_state)
|
||||
k_aa = self.to_k_aa(context_agent_action)
|
||||
v_aa = self.to_v_aa(context_agent_action)
|
||||
|
||||
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h),
|
||||
(q, k, v))
|
||||
k_ip, v_ip = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h),
|
||||
(k_ip, v_ip))
|
||||
k_as, v_as = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h),
|
||||
(k_as, v_as))
|
||||
k_aa, v_aa = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h),
|
||||
(k_aa, v_aa))
|
||||
|
||||
if use_cache:
|
||||
self._kv_cache = {
|
||||
'k': k, 'v': v,
|
||||
'k_ip': k_ip, 'v_ip': v_ip,
|
||||
'k_as': k_as, 'v_as': v_as,
|
||||
'k_aa': k_aa, 'v_aa': v_aa,
|
||||
}
|
||||
else:
|
||||
if not spatial_self_attn:
|
||||
context = context[:, :self.text_context_len, :]
|
||||
k = self.to_k(context)
|
||||
v = self.to_v(context)
|
||||
|
||||
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h),
|
||||
(q, k, v))
|
||||
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h),
|
||||
(q, k, v))
|
||||
|
||||
if use_cache:
|
||||
self._kv_cache = {'k': k, 'v': v}
|
||||
|
||||
# baddbmm: fuse scale into GEMM → one kernel instead of matmul + mul
|
||||
sim = torch.baddbmm(
|
||||
torch.empty(q.shape[0], q.shape[1], k.shape[1], dtype=q.dtype, device=q.device),
|
||||
q, k.transpose(-1, -2), beta=0, alpha=self.scale)
|
||||
del k
|
||||
|
||||
if exists(mask):
|
||||
max_neg_value = -torch.finfo(sim.dtype).max
|
||||
@@ -293,40 +330,28 @@ class CrossAttention(nn.Module):
|
||||
out = rearrange(out, '(b h) n d -> b n (h d)', h=h)
|
||||
|
||||
if k_ip is not None and k_as is not None and k_aa is not None:
|
||||
## image cross-attention
|
||||
k_ip, v_ip = map(
|
||||
lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h),
|
||||
(k_ip, v_ip))
|
||||
## image cross-attention (k_ip/v_ip already in (b*h, n, d) shape)
|
||||
sim_ip = torch.baddbmm(
|
||||
torch.empty(q.shape[0], q.shape[1], k_ip.shape[1], dtype=q.dtype, device=q.device),
|
||||
q, k_ip.transpose(-1, -2), beta=0, alpha=self.scale)
|
||||
del k_ip
|
||||
with torch.amp.autocast('cuda', enabled=False):
|
||||
sim_ip = sim_ip.softmax(dim=-1)
|
||||
out_ip = torch.bmm(sim_ip, v_ip)
|
||||
out_ip = rearrange(out_ip, '(b h) n d -> b n (h d)', h=h)
|
||||
|
||||
## agent state cross-attention
|
||||
k_as, v_as = map(
|
||||
lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h),
|
||||
(k_as, v_as))
|
||||
## agent state cross-attention (k_as/v_as already in (b*h, n, d) shape)
|
||||
sim_as = torch.baddbmm(
|
||||
torch.empty(q.shape[0], q.shape[1], k_as.shape[1], dtype=q.dtype, device=q.device),
|
||||
q, k_as.transpose(-1, -2), beta=0, alpha=self.scale)
|
||||
del k_as
|
||||
with torch.amp.autocast('cuda', enabled=False):
|
||||
sim_as = sim_as.softmax(dim=-1)
|
||||
out_as = torch.bmm(sim_as, v_as)
|
||||
out_as = rearrange(out_as, '(b h) n d -> b n (h d)', h=h)
|
||||
|
||||
## agent action cross-attention
|
||||
k_aa, v_aa = map(
|
||||
lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h),
|
||||
(k_aa, v_aa))
|
||||
## agent action cross-attention (k_aa/v_aa already in (b*h, n, d) shape)
|
||||
sim_aa = torch.baddbmm(
|
||||
torch.empty(q.shape[0], q.shape[1], k_aa.shape[1], dtype=q.dtype, device=q.device),
|
||||
q, k_aa.transpose(-1, -2), beta=0, alpha=self.scale)
|
||||
del k_aa
|
||||
with torch.amp.autocast('cuda', enabled=False):
|
||||
sim_aa = sim_aa.softmax(dim=-1)
|
||||
out_aa = torch.bmm(sim_aa, v_aa)
|
||||
@@ -526,6 +551,20 @@ class CrossAttention(nn.Module):
|
||||
return attn_mask
|
||||
|
||||
|
||||
def enable_cross_attn_kv_cache(module):
|
||||
for m in module.modules():
|
||||
if isinstance(m, CrossAttention):
|
||||
m._kv_cache_enabled = True
|
||||
m._kv_cache = {}
|
||||
|
||||
|
||||
def disable_cross_attn_kv_cache(module):
|
||||
for m in module.modules():
|
||||
if isinstance(m, CrossAttention):
|
||||
m._kv_cache_enabled = False
|
||||
m._kv_cache = {}
|
||||
|
||||
|
||||
class BasicTransformerBlock(nn.Module):
|
||||
|
||||
def __init__(self,
|
||||
|
||||
Reference in New Issue
Block a user