1. einsum('b i d, b j d -> b i j') → torch.bmm(q, k.transpose(-1,-2)) — 直接映射 rocBLAS batched GEMM
2. baddbmm 把 scale 融合进 GEMM,少一次 kernel launch 3. 第二个 einsum 同理换成 torch.bmm,每一轮加速 1 到 2 秒
This commit is contained in:
@@ -86,9 +86,8 @@ class CrossAttention(nn.Module):
|
||||
self.relative_position_v = RelativePosition(
|
||||
num_units=dim_head, max_relative_position=temporal_length)
|
||||
else:
|
||||
## only used for spatial attention, while NOT for temporal attention
|
||||
if XFORMERS_IS_AVAILBLE and temporal_length is None:
|
||||
self.forward = self.efficient_forward
|
||||
## bmm fused-scale attention for all non-relative-position cases
|
||||
self.forward = self.bmm_forward
|
||||
|
||||
self.video_length = video_length
|
||||
self.image_cross_attention = image_cross_attention
|
||||
@@ -234,6 +233,119 @@ class CrossAttention(nn.Module):
|
||||
|
||||
return self.to_out(out)
|
||||
|
||||
def bmm_forward(self, x, context=None, mask=None):
|
||||
spatial_self_attn = (context is None)
|
||||
k_ip, v_ip, out_ip = None, None, None
|
||||
k_as, v_as, out_as = None, None, None
|
||||
k_aa, v_aa, out_aa = None, None, None
|
||||
|
||||
h = self.heads
|
||||
q = self.to_q(x)
|
||||
context = default(context, x)
|
||||
|
||||
if self.image_cross_attention and not spatial_self_attn:
|
||||
context_agent_state = context[:, :self.agent_state_context_len, :]
|
||||
context_agent_action = context[:,
|
||||
self.agent_state_context_len:self.
|
||||
agent_state_context_len +
|
||||
self.agent_action_context_len, :]
|
||||
context_ins = context[:, self.agent_state_context_len +
|
||||
self.agent_action_context_len:self.
|
||||
agent_state_context_len +
|
||||
self.agent_action_context_len +
|
||||
self.text_context_len, :]
|
||||
context_image = context[:, self.agent_state_context_len +
|
||||
self.agent_action_context_len +
|
||||
self.text_context_len:, :]
|
||||
|
||||
k = self.to_k(context_ins)
|
||||
v = self.to_v(context_ins)
|
||||
k_ip = self.to_k_ip(context_image)
|
||||
v_ip = self.to_v_ip(context_image)
|
||||
k_as = self.to_k_as(context_agent_state)
|
||||
v_as = self.to_v_as(context_agent_state)
|
||||
k_aa = self.to_k_aa(context_agent_action)
|
||||
v_aa = self.to_v_aa(context_agent_action)
|
||||
else:
|
||||
if not spatial_self_attn:
|
||||
context = context[:, :self.text_context_len, :]
|
||||
k = self.to_k(context)
|
||||
v = self.to_v(context)
|
||||
|
||||
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h),
|
||||
(q, k, v))
|
||||
|
||||
# baddbmm: fuse scale into GEMM → one kernel instead of matmul + mul
|
||||
sim = torch.baddbmm(
|
||||
torch.empty(q.shape[0], q.shape[1], k.shape[1], dtype=q.dtype, device=q.device),
|
||||
q, k.transpose(-1, -2), beta=0, alpha=self.scale)
|
||||
del k
|
||||
|
||||
if exists(mask):
|
||||
max_neg_value = -torch.finfo(sim.dtype).max
|
||||
mask = repeat(mask, 'b i j -> (b h) i j', h=h)
|
||||
sim.masked_fill_(~(mask > 0.5), max_neg_value)
|
||||
|
||||
with torch.amp.autocast('cuda', enabled=False):
|
||||
sim = sim.softmax(dim=-1)
|
||||
|
||||
out = torch.bmm(sim, v)
|
||||
out = rearrange(out, '(b h) n d -> b n (h d)', h=h)
|
||||
|
||||
if k_ip is not None and k_as is not None and k_aa is not None:
|
||||
## image cross-attention
|
||||
k_ip, v_ip = map(
|
||||
lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h),
|
||||
(k_ip, v_ip))
|
||||
sim_ip = torch.baddbmm(
|
||||
torch.empty(q.shape[0], q.shape[1], k_ip.shape[1], dtype=q.dtype, device=q.device),
|
||||
q, k_ip.transpose(-1, -2), beta=0, alpha=self.scale)
|
||||
del k_ip
|
||||
with torch.amp.autocast('cuda', enabled=False):
|
||||
sim_ip = sim_ip.softmax(dim=-1)
|
||||
out_ip = torch.bmm(sim_ip, v_ip)
|
||||
out_ip = rearrange(out_ip, '(b h) n d -> b n (h d)', h=h)
|
||||
|
||||
## agent state cross-attention
|
||||
k_as, v_as = map(
|
||||
lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h),
|
||||
(k_as, v_as))
|
||||
sim_as = torch.baddbmm(
|
||||
torch.empty(q.shape[0], q.shape[1], k_as.shape[1], dtype=q.dtype, device=q.device),
|
||||
q, k_as.transpose(-1, -2), beta=0, alpha=self.scale)
|
||||
del k_as
|
||||
with torch.amp.autocast('cuda', enabled=False):
|
||||
sim_as = sim_as.softmax(dim=-1)
|
||||
out_as = torch.bmm(sim_as, v_as)
|
||||
out_as = rearrange(out_as, '(b h) n d -> b n (h d)', h=h)
|
||||
|
||||
## agent action cross-attention
|
||||
k_aa, v_aa = map(
|
||||
lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h),
|
||||
(k_aa, v_aa))
|
||||
sim_aa = torch.baddbmm(
|
||||
torch.empty(q.shape[0], q.shape[1], k_aa.shape[1], dtype=q.dtype, device=q.device),
|
||||
q, k_aa.transpose(-1, -2), beta=0, alpha=self.scale)
|
||||
del k_aa
|
||||
with torch.amp.autocast('cuda', enabled=False):
|
||||
sim_aa = sim_aa.softmax(dim=-1)
|
||||
out_aa = torch.bmm(sim_aa, v_aa)
|
||||
out_aa = rearrange(out_aa, '(b h) n d -> b n (h d)', h=h)
|
||||
|
||||
if out_ip is not None and out_as is not None and out_aa is not None:
|
||||
if self.cross_attention_scale_learnable:
|
||||
out = out + \
|
||||
self.image_cross_attention_scale * out_ip * (torch.tanh(self.alpha_ctx) + 1) + \
|
||||
self.agent_state_cross_attention_scale * out_as * (torch.tanh(self.alpha_cas) + 1) + \
|
||||
self.agent_action_cross_attention_scale * out_aa * (torch.tanh(self.alpha_caa) + 1)
|
||||
else:
|
||||
out = out + \
|
||||
self.image_cross_attention_scale * out_ip + \
|
||||
self.agent_state_cross_attention_scale * out_as + \
|
||||
self.agent_action_cross_attention_scale * out_aa
|
||||
|
||||
return self.to_out(out)
|
||||
|
||||
def efficient_forward(self, x, context=None, mask=None):
|
||||
spatial_self_attn = (context is None)
|
||||
k, v, out = None, None, None
|
||||
|
||||
Reference in New Issue
Block a user