第一次完整测例跑完

This commit is contained in:
2026-01-18 00:30:10 +08:00
parent ca15cc593b
commit 25c6fc04db
180 changed files with 29305 additions and 0 deletions

View File

View File

@@ -0,0 +1,806 @@
import torch
import torch.nn.functional as F
from torch import nn, einsum
from einops import rearrange, repeat
from functools import partial
# xformers is an optional dependency: when it cannot be loaded, the attention
# layers fall back to the plain PyTorch implementation.
try:
    import xformers
    import xformers.ops
    XFORMERS_IS_AVAILBLE = True
except Exception:
    # `Exception` (rather than a bare `except`) still covers ImportError and
    # binary/ABI load failures, but lets KeyboardInterrupt/SystemExit through.
    XFORMERS_IS_AVAILBLE = False
from unifolm_wma.utils.common import (
checkpoint,
exists,
default,
)
from unifolm_wma.utils.basics import zero_module
class RelativePosition(nn.Module):
    """Learned embedding table indexed by clipped relative offsets.

    Adapted from
    https://github.com/evelinehong/Transformer_Relative_Position_PyTorch/blob/master/relative_position.py
    """

    def __init__(self, num_units, max_relative_position):
        """Create a table with one row per offset in [-max, +max].

        Args:
            num_units: size of each embedding vector.
            max_relative_position: offsets are clamped to +/- this value.
        """
        super().__init__()
        self.num_units = num_units
        self.max_relative_position = max_relative_position
        table = torch.Tensor(max_relative_position * 2 + 1, num_units)
        self.embeddings_table = nn.Parameter(table)
        nn.init.xavier_uniform_(self.embeddings_table)

    def forward(self, length_q, length_k):
        """Return embeddings of shape (length_q, length_k, num_units).

        Entry [i, j] is the table row for the offset ``j - i`` after clamping.
        """
        dev = self.embeddings_table.device
        limit = self.max_relative_position
        q_pos = torch.arange(length_q, device=dev)[:, None]
        k_pos = torch.arange(length_k, device=dev)[None, :]
        offsets = torch.clamp(k_pos - q_pos, -limit, limit)
        # Shift into [0, 2 * limit] so the offsets can be used as row indices.
        rows = (offsets + limit).long()
        return self.embeddings_table[rows]
class CrossAttention(nn.Module):
    """Multi-head attention over a query sequence and an optional context.

    With ``context=None`` the layer attends over ``x`` itself.  When
    ``image_cross_attention`` is enabled, the context is treated as several
    token groups concatenated along dim 1 (agent state, agent action, text,
    image).  Each group has its own key/value projections and the per-group
    attention outputs are summed with configurable scales.  That multi-group
    path exists only in :meth:`efficient_forward`, which needs xformers.
    """

    def __init__(self,
                 query_dim,
                 context_dim=None,
                 heads=8,
                 dim_head=64,
                 dropout=0.,
                 relative_position=False,
                 temporal_length=None,
                 video_length=None,
                 agent_state_context_len=2,
                 agent_action_context_len=16,
                 image_cross_attention=False,
                 image_cross_attention_scale=1.0,
                 agent_state_cross_attention_scale=1.0,
                 agent_action_cross_attention_scale=1.0,
                 cross_attention_scale_learnable=False,
                 text_context_len=77):
        """Build the projections.

        Args:
            query_dim: feature size of ``x``.
            context_dim: feature size of the context; defaults to ``query_dim``.
            heads, dim_head: number of heads and per-head width.
            dropout: dropout applied after the output projection.
            relative_position: add learned relative-position terms
                (requires ``temporal_length``).
            temporal_length: when ``None`` and xformers is available (and
                ``relative_position`` is off), ``forward`` is rebound to
                :meth:`efficient_forward`.
            video_length: token count compared against the context length to
                pick the context layout in :meth:`efficient_forward`.
            agent_state_context_len, agent_action_context_len,
            text_context_len: token counts of the respective context groups.
            image_cross_attention: create the extra per-group projections.
            *_cross_attention_scale: weights of the extra branch outputs.
            cross_attention_scale_learnable: additionally scale each extra
                branch by ``tanh(alpha) + 1`` with a learnable ``alpha``.
        """
        super().__init__()
        inner_dim = dim_head * heads
        context_dim = default(context_dim, query_dim)

        self.scale = dim_head**-0.5
        self.heads = heads
        self.dim_head = dim_head

        self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
        self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
        self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
        self.to_out = nn.Sequential(nn.Linear(inner_dim, query_dim),
                                    nn.Dropout(dropout))

        self.relative_position = relative_position
        if self.relative_position:
            assert (temporal_length is not None)
            self.relative_position_k = RelativePosition(
                num_units=dim_head, max_relative_position=temporal_length)
            self.relative_position_v = RelativePosition(
                num_units=dim_head, max_relative_position=temporal_length)
        else:
            ## only used for spatial attention, while NOT for temporal attention
            if XFORMERS_IS_AVAILBLE and temporal_length is None:
                self.forward = self.efficient_forward

        self.video_length = video_length
        self.image_cross_attention = image_cross_attention
        self.image_cross_attention_scale = image_cross_attention_scale
        self.agent_state_cross_attention_scale = agent_state_cross_attention_scale
        self.agent_action_cross_attention_scale = agent_action_cross_attention_scale
        self.text_context_len = text_context_len
        self.agent_state_context_len = agent_state_context_len
        self.agent_action_context_len = agent_action_context_len
        self.cross_attention_scale_learnable = cross_attention_scale_learnable
        if self.image_cross_attention:
            # Separate key/value projections for image (ip), agent state (as)
            # and agent action (aa) tokens.
            self.to_k_ip = nn.Linear(context_dim, inner_dim, bias=False)
            self.to_v_ip = nn.Linear(context_dim, inner_dim, bias=False)
            self.to_k_as = nn.Linear(context_dim, inner_dim, bias=False)
            self.to_v_as = nn.Linear(context_dim, inner_dim, bias=False)
            self.to_k_aa = nn.Linear(context_dim, inner_dim, bias=False)
            self.to_v_aa = nn.Linear(context_dim, inner_dim, bias=False)
            if cross_attention_scale_learnable:
                self.register_parameter('alpha_ctx',
                                        nn.Parameter(torch.tensor(0.)))
                self.register_parameter('alpha_cas',
                                        nn.Parameter(torch.tensor(0.)))
                self.register_parameter('alpha_caa',
                                        nn.Parameter(torch.tensor(0.)))

    def forward(self, x, context=None, mask=None):
        """Plain PyTorch attention (no xformers).

        Args:
            x: query tensor, indexed as (batch, tokens, query_dim).
            context: optional key/value source; ``None`` means self-attention.
                Only the first ``text_context_len`` tokens are used.
            mask: optional (batch, i, j) mask; positions with value <= 0.5
                are excluded from the softmax.

        Raises:
            AssertionError: if ``image_cross_attention`` is enabled and a
                context is given; that case is only implemented in
                :meth:`efficient_forward`.
        """
        spatial_self_attn = (context is None)
        if self.image_cross_attention and not spatial_self_attn:
            # Raised explicitly (not via `assert`) so the guard survives
            # `python -O`.  The multi-group fallback that used to follow this
            # guard was unreachable and has been removed.
            raise AssertionError(
                ">>> ERROR: should setup xformers and use efficient_forward ..."
            )

        h = self.heads
        q = self.to_q(x)
        if spatial_self_attn:
            context = default(context, x)
        else:
            context = context[:, :self.text_context_len, :]
        k = self.to_k(context)
        v = self.to_v(context)

        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h),
                      (q, k, v))

        sim = torch.einsum('b i d, b j d -> b i j', q, k) * self.scale
        if self.relative_position:
            len_q, len_k, len_v = q.shape[1], k.shape[1], v.shape[1]
            k2 = self.relative_position_k(len_q, len_k)
            sim2 = einsum('b t d, t s d -> b t s', q,
                          k2) * self.scale  # TODO check
            sim += sim2
        del k

        if exists(mask):
            ## feasible for causal attention mask only
            max_neg_value = -torch.finfo(sim.dtype).max
            mask = repeat(mask, 'b i j -> (b h) i j', h=h)
            sim.masked_fill_(~(mask > 0.5), max_neg_value)

        # attention, what we cannot get enough of
        sim = sim.softmax(dim=-1)

        out = torch.einsum('b i j, b j d -> b i d', sim, v)
        if self.relative_position:
            v2 = self.relative_position_v(len_q, len_v)
            out2 = einsum('b t s, t s d -> b t d', sim, v2)  # TODO check
            out += out2
        out = rearrange(out, '(b h) n d -> b n (h d)', h=h)
        return self.to_out(out)

    def _split_heads(self, t, b):
        """Reshape (b, n, heads * dim_head) -> (b * heads, n, dim_head)."""
        n = t.shape[1]
        return (t.reshape(b, n, self.heads, self.dim_head)
                .permute(0, 2, 1, 3)
                .reshape(b * self.heads, n, self.dim_head)
                .contiguous())

    def _merge_heads(self, t, b):
        """Reshape (b * heads, n, dim_head) -> (b, n, heads * dim_head)."""
        n = t.shape[1]
        return (t.reshape(b, self.heads, n, self.dim_head)
                .permute(0, 2, 1, 3)
                .reshape(b, n, self.heads * self.dim_head))

    def _attend(self, q, k, v, b, attn_bias=None):
        """Run xformers attention for head-split ``q`` against unsplit k/v."""
        k = self._split_heads(k, b)
        v = self._split_heads(v, b)
        out = xformers.ops.memory_efficient_attention(q,
                                                      k,
                                                      v,
                                                      attn_bias=attn_bias,
                                                      op=None)
        return self._merge_heads(out, b)

    def efficient_forward(self, x, context=None, mask=None):
        """xformers-based attention with optional per-group branches.

        With ``image_cross_attention`` and a context, the layout is chosen
        from ``context.shape[1]``:

        * ``text_context_len + video_length``: [text | image]
        * ``agent_state_context_len + text_context_len + video_length``:
          [agent state | text | image]
        * anything else: [agent state | agent action | text | image]

        Raises:
            NotImplementedError: if ``mask`` is given.
            AssertionError: if a context is given while
                ``image_cross_attention`` is disabled.
        """
        if exists(mask):
            # Masks are unsupported on this path; fail before doing any work.
            raise NotImplementedError

        spatial_self_attn = (context is None)
        k_ip, v_ip, out_ip = None, None, None
        k_as, v_as, out_as = None, None, None
        k_aa, v_aa, out_aa = None, None, None

        q = self.to_q(x)
        context = default(context, x)

        if self.image_cross_attention and not spatial_self_attn:
            n_state = self.agent_state_context_len
            n_action = self.agent_action_context_len
            n_text = self.text_context_len
            n_ctx = context.shape[1]
            if n_ctx == n_text + self.video_length:
                context_ins = context[:, :n_text, :]
                context_image = context[:, n_text:, :]
            elif n_ctx == n_state + n_text + self.video_length:
                context_agent_state = context[:, :n_state, :]
                context_ins = context[:, n_state:n_state + n_text, :]
                context_image = context[:, n_state + n_text:, :]
                k_as = self.to_k_as(context_agent_state)
                v_as = self.to_v_as(context_agent_state)
            else:
                context_agent_state = context[:, :n_state, :]
                context_agent_action = context[:, n_state:n_state +
                                               n_action, :]
                context_ins = context[:, n_state + n_action:n_state +
                                      n_action + n_text, :]
                context_image = context[:, n_state + n_action + n_text:, :]
                k_as = self.to_k_as(context_agent_state)
                v_as = self.to_v_as(context_agent_state)
                k_aa = self.to_k_aa(context_agent_action)
                v_aa = self.to_v_aa(context_agent_action)
            # Text keys/values come from the text slice in every layout.  The
            # [text | image] layout used to project the whole context here,
            # leaving its `context_ins` slice unused.
            k = self.to_k(context_ins)
            v = self.to_v(context_ins)
            k_ip = self.to_k_ip(context_image)
            v_ip = self.to_v_ip(context_image)
        else:
            if not spatial_self_attn:
                # Explicit raise so the guard is not stripped by `python -O`.
                raise AssertionError(
                    ">>> ERROR: you should never go into here ...")
            k = self.to_k(context)
            v = self.to_v(context)

        b = q.shape[0]
        q = self._split_heads(q, b)

        out = self._attend(q, k, v, b)

        if k_ip is not None:
            # For image cross-attention
            out_ip = self._attend(q, k_ip, v_ip, b)

        if k_as is not None:
            # For agent state cross-attention
            out_as = self._attend(q, k_as, v_as, b)

        if k_aa is not None:
            # For agent action cross-attention: an additive bias hides part of
            # the action tokens per batch row (see _get_attn_mask_aa).
            attn_mask_aa = self._get_attn_mask_aa(b,
                                                  q.shape[1],
                                                  k_aa.shape[1],
                                                  block_size=16,
                                                  device=k_aa.device)
            attn_mask_aa = attn_mask_aa.unsqueeze(1).repeat(
                1, self.heads, 1, 1).reshape(b * self.heads,
                                             attn_mask_aa.shape[1],
                                             attn_mask_aa.shape[2])
            attn_mask_aa = attn_mask_aa.to(q.dtype)
            out_aa = self._attend(q, k_aa, v_aa, b, attn_bias=attn_mask_aa)

        # Absent branches contribute a scalar zero, so the learnable alphas
        # still take part in the expression below.
        out_ip = 0.0 if out_ip is None else out_ip
        out_as = 0.0 if out_as is None else out_as
        out_aa = 0.0 if out_aa is None else out_aa

        if self.cross_attention_scale_learnable:
            out = out + \
                self.image_cross_attention_scale * out_ip * (torch.tanh(self.alpha_ctx) + 1) + \
                self.agent_state_cross_attention_scale * out_as * (torch.tanh(self.alpha_cas) + 1) + \
                self.agent_action_cross_attention_scale * out_aa * (torch.tanh(self.alpha_caa) + 1)
        else:
            out = out + \
                self.image_cross_attention_scale * out_ip + \
                self.agent_state_cross_attention_scale * out_as + \
                self.agent_action_cross_attention_scale * out_aa

        return self.to_out(out)

    def _get_attn_mask_aa(self, b, l1, l2, block_size=16, device=None):
        """Build an additive attention bias of shape (b, l1, l2).

        For batch row ``i`` every key column at or beyond
        ``((i % block_size) + 1) * (l2 // block_size)`` is set to ``-inf``;
        all other entries are 0.

        Args:
            b: number of batch rows.
            l1: number of query positions.
            l2: number of key positions.
            block_size: period of the per-row cut-off pattern.
            device: device to build the mask on (default: PyTorch's default
                device, as before).
        """
        num_token = l2 // block_size
        start_positions = (
            (torch.arange(b, device=device) % block_size) + 1) * num_token
        col_indices = torch.arange(l2, device=device)
        mask_2d = col_indices.unsqueeze(0) >= start_positions.unsqueeze(1)
        mask = mask_2d.unsqueeze(1).expand(b, l1, l2)
        attn_mask = torch.zeros_like(mask, dtype=torch.float)
        attn_mask[mask] = float('-inf')
        return attn_mask
class BasicTransformerBlock(nn.Module):
    """Pre-norm transformer block: attention, attention, feed-forward.

    Each sub-layer is applied to a LayerNorm of its input and added back
    residually.  ``attn1`` is self-attention unless ``disable_self_attn`` is
    set; ``attn2`` attends to ``context``.
    """

    def __init__(self,
                 dim,
                 n_heads,
                 d_head,
                 dropout=0.,
                 context_dim=None,
                 gated_ff=True,
                 checkpoint=True,
                 disable_self_attn=False,
                 attention_cls=None,
                 video_length=None,
                 agent_state_context_len=2,
                 agent_action_context_len=16,
                 image_cross_attention=False,
                 image_cross_attention_scale=1.0,
                 cross_attention_scale_learnable=False,
                 text_context_len=77):
        super().__init__()
        attn_cls = attention_cls if attention_cls is not None else CrossAttention
        self.disable_self_attn = disable_self_attn

        # Arguments common to both attention layers.
        shared = dict(query_dim=dim,
                      heads=n_heads,
                      dim_head=d_head,
                      dropout=dropout)
        # attn1 only receives a context dimension when it acts as cross-attention.
        self.attn1 = attn_cls(
            context_dim=context_dim if self.disable_self_attn else None,
            **shared)
        self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff)
        self.attn2 = attn_cls(
            context_dim=context_dim,
            video_length=video_length,
            agent_state_context_len=agent_state_context_len,
            agent_action_context_len=agent_action_context_len,
            image_cross_attention=image_cross_attention,
            image_cross_attention_scale=image_cross_attention_scale,
            cross_attention_scale_learnable=cross_attention_scale_learnable,
            text_context_len=text_context_len,
            **shared)
        self.image_cross_attention = image_cross_attention

        self.norm1 = nn.LayerNorm(dim)
        self.norm2 = nn.LayerNorm(dim)
        self.norm3 = nn.LayerNorm(dim)
        self.checkpoint = checkpoint

    def forward(self, x, context=None, mask=None, **kwargs):
        """Run the block through the checkpoint helper.

        The checkpoint helper is only given tensor positional inputs, so a
        mask is bound with ``partial`` rather than passed positionally.
        NOTE: when ``mask`` is given, only ``x`` is forwarded, so ``context``
        is not used on that path.  Extra ``**kwargs`` are ignored.
        """
        params = self.parameters()
        if mask is not None:
            masked_forward = partial(self._forward, mask=mask)
            return checkpoint(masked_forward, (x, ), params, self.checkpoint)

        # A one-element tuple must be written (x, ) so that unpacking yields x itself.
        inputs = (x, ) if context is None else (x, context)
        return checkpoint(self._forward, inputs, params, self.checkpoint)

    def _forward(self, x, context=None, mask=None):
        """Apply attn1, attn2 and the feed-forward layer with residuals."""
        first_context = context if self.disable_self_attn else None
        x = self.attn1(self.norm1(x), context=first_context, mask=mask) + x
        x = self.attn2(self.norm2(x), context=context, mask=mask) + x
        return self.ff(self.norm3(x)) + x
class SpatialTransformer(nn.Module):
    """
    Transformer over the spatial positions of a (b, c, h, w) feature map.

    The input is group-normalised, projected to ``n_heads * d_head``
    channels, flattened to (b, h*w, c), passed through ``depth`` transformer
    blocks, projected back and added to the input.  The output projection is
    wrapped in ``zero_module``.  With ``use_linear`` the projections are
    ``nn.Linear`` layers applied on the flattened layout instead of 1x1
    convolutions.
    """

    def __init__(self,
                 in_channels,
                 n_heads,
                 d_head,
                 depth=1,
                 dropout=0.,
                 context_dim=None,
                 use_checkpoint=True,
                 disable_self_attn=False,
                 use_linear=False,
                 video_length=None,
                 agent_state_context_len=2,
                 agent_action_context_len=16,
                 image_cross_attention=False,
                 cross_attention_scale_learnable=False):
        super().__init__()
        self.in_channels = in_channels
        inner_dim = n_heads * d_head
        self.norm = torch.nn.GroupNorm(num_groups=32,
                                       num_channels=in_channels,
                                       eps=1e-6,
                                       affine=True)

        if use_linear:
            self.proj_in = nn.Linear(in_channels, inner_dim)
        else:
            self.proj_in = nn.Conv2d(in_channels,
                                     inner_dim,
                                     kernel_size=1,
                                     stride=1,
                                     padding=0)

        # None lets BasicTransformerBlock pick its default attention class.
        attention_cls = None
        blocks = []
        for _ in range(depth):
            blocks.append(
                BasicTransformerBlock(
                    inner_dim,
                    n_heads,
                    d_head,
                    dropout=dropout,
                    context_dim=context_dim,
                    disable_self_attn=disable_self_attn,
                    checkpoint=use_checkpoint,
                    attention_cls=attention_cls,
                    video_length=video_length,
                    agent_state_context_len=agent_state_context_len,
                    agent_action_context_len=agent_action_context_len,
                    image_cross_attention=image_cross_attention,
                    cross_attention_scale_learnable=
                    cross_attention_scale_learnable,
                ))
        self.transformer_blocks = nn.ModuleList(blocks)

        if use_linear:
            self.proj_out = zero_module(nn.Linear(inner_dim, in_channels))
        else:
            self.proj_out = zero_module(
                nn.Conv2d(inner_dim,
                          in_channels,
                          kernel_size=1,
                          stride=1,
                          padding=0))
        self.use_linear = use_linear

    def forward(self, x, context=None, **kwargs):
        """Apply the blocks to ``x`` (b, c, h, w) and return ``x`` plus the result."""
        _, _, height, width = x.shape
        residual = x
        hidden = self.norm(x)

        # Conv projection runs on the image layout, linear on the token layout.
        if self.use_linear:
            hidden = rearrange(hidden, 'b c h w -> b (h w) c').contiguous()
            hidden = self.proj_in(hidden)
        else:
            hidden = self.proj_in(hidden)
            hidden = rearrange(hidden, 'b c h w -> b (h w) c').contiguous()

        for block in self.transformer_blocks:
            hidden = block(hidden, context=context, **kwargs)

        if self.use_linear:
            hidden = self.proj_out(hidden)
            hidden = rearrange(hidden, 'b (h w) c -> b c h w', h=height,
                               w=width).contiguous()
        else:
            hidden = rearrange(hidden, 'b (h w) c -> b c h w', h=height,
                               w=width).contiguous()
            hidden = self.proj_out(hidden)
        return hidden + residual
class TemporalTransformer(nn.Module):
    """
    Transformer over the temporal axis of a (b, c, t, h, w) feature volume.

    Every spatial location is treated as an independent sequence of length
    ``t``: the input is group-normalised, reshaped to (b*h*w, t, c), passed
    through ``depth`` transformer blocks, projected back and added to the
    input.  The output projection is wrapped in ``zero_module``.
    """

    def __init__(self,
                 in_channels,
                 n_heads,
                 d_head,
                 depth=1,
                 dropout=0.,
                 context_dim=None,
                 use_checkpoint=True,
                 use_linear=False,
                 only_self_att=True,
                 causal_attention=False,
                 causal_block_size=1,
                 relative_position=False,
                 temporal_length=None):
        """Build the temporal transformer.

        Args:
            in_channels: channels of the input volume.
            n_heads, d_head: attention heads and per-head width.
            depth: number of transformer blocks.
            dropout: dropout rate passed to the blocks.
            context_dim: context feature size (ignored when ``only_self_att``).
            use_checkpoint: forwarded to each block's ``checkpoint`` flag.
            use_linear: use ``nn.Linear`` projections instead of 1x1 ``Conv1d``.
            only_self_att: run the blocks without a context.
            causal_attention: apply a lower-triangular mask on the
                self-attention path (requires ``temporal_length``).
            causal_block_size: stored on the module; not used in this class.
            relative_position: enable relative position terms in attention
                (requires ``temporal_length``).
            temporal_length: maximum sequence length for the mask and the
                relative position tables.
        """
        super().__init__()
        self.only_self_att = only_self_att
        self.relative_position = relative_position
        self.causal_attention = causal_attention
        self.causal_block_size = causal_block_size

        self.in_channels = in_channels
        inner_dim = n_heads * d_head
        self.norm = torch.nn.GroupNorm(num_groups=32,
                                       num_channels=in_channels,
                                       eps=1e-6,
                                       affine=True)
        # The input projection is built exactly once; an earlier unconditional
        # Conv1d that was immediately overwritten has been dropped.
        if not use_linear:
            self.proj_in = nn.Conv1d(in_channels,
                                     inner_dim,
                                     kernel_size=1,
                                     stride=1,
                                     padding=0)
        else:
            self.proj_in = nn.Linear(in_channels, inner_dim)

        if relative_position:
            assert (temporal_length is not None)
            attention_cls = partial(CrossAttention,
                                    relative_position=True,
                                    temporal_length=temporal_length)
        else:
            attention_cls = partial(CrossAttention,
                                    temporal_length=temporal_length)
        if self.causal_attention:
            assert (temporal_length is not None)
            # Plain attribute (not a buffer); moved to the input device in forward().
            self.mask = torch.tril(
                torch.ones([1, temporal_length, temporal_length]))

        if self.only_self_att:
            context_dim = None
        self.transformer_blocks = nn.ModuleList([
            BasicTransformerBlock(inner_dim,
                                  n_heads,
                                  d_head,
                                  dropout=dropout,
                                  context_dim=context_dim,
                                  attention_cls=attention_cls,
                                  checkpoint=use_checkpoint)
            for _ in range(depth)
        ])
        if not use_linear:
            self.proj_out = zero_module(
                nn.Conv1d(inner_dim,
                          in_channels,
                          kernel_size=1,
                          stride=1,
                          padding=0))
        else:
            self.proj_out = zero_module(nn.Linear(inner_dim, in_channels))
        self.use_linear = use_linear

    def forward(self, x, context=None):
        """Apply temporal attention to ``x`` (b, c, t, h, w).

        Args:
            x: input volume.
            context: required when ``only_self_att`` is False; reshaped as
                '(b t) l con -> b t l con'.

        Returns:
            Tensor of the same shape as ``x`` (residual added).

        Raises:
            ValueError: if ``only_self_att`` is False and ``context`` is None.
        """
        b, c, t, h, w = x.shape
        x_in = x
        x = self.norm(x)
        x = rearrange(x, 'b c t h w -> (b h w) c t').contiguous()
        if not self.use_linear:
            x = self.proj_in(x)
        x = rearrange(x, 'bhw c t -> bhw t c').contiguous()
        if self.use_linear:
            x = self.proj_in(x)

        if self.only_self_att:
            mask = None
            if self.causal_attention:
                # Slice the stored mask to the current length, one copy per sequence.
                mask = self.mask[:, :t, :t].to(x.device)
                mask = repeat(mask, 'l i j -> (l bhw) i j', bhw=b * h * w)
            # NOTE: if no context is given, cross-attention defaults to self-attention
            for block in self.transformer_blocks:
                x = block(x, mask=mask)
            x = rearrange(x, '(b hw) t c -> b hw t c', b=b).contiguous()
        else:
            if context is None:
                raise ValueError(
                    "TemporalTransformer with only_self_att=False requires a context tensor"
                )
            x = rearrange(x, '(b hw) t c -> b hw t c', b=b).contiguous()
            context = rearrange(context, '(b t) l con -> b t l con',
                                t=t).contiguous()
            for block in self.transformer_blocks:
                # Calculate each batch one by one (since number in shape could not greater then 65,535 for some package)
                for j in range(b):
                    context_j = repeat(context[j],
                                       't l con -> (t r) l con',
                                       r=(h * w) // t,
                                       t=t).contiguous()
                    # Note: causal mask will not applied in cross-attention case
                    x[j] = block(x[j], context=context_j)

        if self.use_linear:
            x = self.proj_out(x)
            x = rearrange(x, 'b (h w) t c -> b c t h w', h=h, w=w).contiguous()
        if not self.use_linear:
            x = rearrange(x, 'b hw t c -> (b hw) c t').contiguous()
            x = self.proj_out(x)
            x = rearrange(x, '(b h w) c t -> b c t h w', b=b, h=h,
                          w=w).contiguous()

        return x + x_in
class GEGLU(nn.Module):
    """Gated linear unit with a GELU gate."""

    def __init__(self, dim_in, dim_out):
        super().__init__()
        # A single projection produces both the value half and the gate half.
        self.proj = nn.Linear(dim_in, dim_out * 2)

    def forward(self, x):
        """Return ``value * gelu(gate)`` with last dimension ``dim_out``."""
        projected = self.proj(x)
        value, gate = projected.chunk(2, dim=-1)
        return value * F.gelu(gate)
class FeedForward(nn.Module):
    """Position-wise MLP: input projection, dropout, output projection."""

    def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.):
        """
        Args:
            dim: input feature size.
            dim_out: output feature size; defaults to ``dim``.
            mult: hidden width is ``int(dim * mult)``.
            glu: use a GEGLU input projection instead of Linear + GELU.
            dropout: dropout rate between the two projections.
        """
        super().__init__()
        inner_dim = int(dim * mult)
        dim_out = default(dim_out, dim)
        if glu:
            project_in = GEGLU(dim, inner_dim)
        else:
            project_in = nn.Sequential(nn.Linear(dim, inner_dim), nn.GELU())

        self.net = nn.Sequential(
            project_in,
            nn.Dropout(dropout),
            nn.Linear(inner_dim, dim_out),
        )

    def forward(self, x):
        """Apply the MLP to ``x``."""
        return self.net(x)
class LinearAttention(nn.Module):
    """Linear attention over the flattened positions of a (b, c, h, w) map."""

    def __init__(self, dim, heads=4, dim_head=32):
        super().__init__()
        self.heads = heads
        hidden_dim = dim_head * heads
        # One bias-free 1x1 convolution yields queries, keys and values at once.
        self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias=False)
        self.to_out = nn.Conv2d(hidden_dim, dim, 1)

    def forward(self, x):
        """Return a tensor with the same (b, c, h, w) layout as ``x``."""
        _, _, height, width = x.shape
        q, k, v = rearrange(self.to_qkv(x),
                            'b (qkv heads c) h w -> qkv b heads c (h w)',
                            heads=self.heads,
                            qkv=3)
        # Keys are normalised over positions, then summarised against the values.
        k = k.softmax(dim=-1)
        context = torch.einsum('bhdn,bhen->bhde', k, v)
        attended = torch.einsum('bhde,bhdn->bhen', context, q)
        attended = rearrange(attended,
                             'b heads c (h w) -> b (heads c) h w',
                             heads=self.heads,
                             h=height,
                             w=width)
        return self.to_out(attended)
class SpatialSelfAttention(nn.Module):
    """Single-head self-attention over the positions of a (b, c, h, w) map."""

    def __init__(self, in_channels):
        super().__init__()
        self.in_channels = in_channels

        def pointwise_conv():
            # 1x1 convolution that keeps the channel count.
            return torch.nn.Conv2d(in_channels,
                                   in_channels,
                                   kernel_size=1,
                                   stride=1,
                                   padding=0)

        self.norm = torch.nn.GroupNorm(num_groups=32,
                                       num_channels=in_channels,
                                       eps=1e-6,
                                       affine=True)
        self.q = pointwise_conv()
        self.k = pointwise_conv()
        self.v = pointwise_conv()
        self.proj_out = pointwise_conv()

    def forward(self, x):
        """Return ``x`` plus the projected attention output."""
        normed = self.norm(x)
        q = self.q(normed)
        k = self.k(normed)
        v = self.v(normed)

        # Compute attention
        b, c, h, w = q.shape
        q = rearrange(q, 'b c h w -> b (h w) c')
        k = rearrange(k, 'b c h w -> b c (h w)')
        weights = torch.einsum('bij,bjk->bik', q, k)
        weights = weights * (int(c)**(-0.5))
        weights = F.softmax(weights, dim=2)

        # Attend to values
        v = rearrange(v, 'b c h w -> b c (h w)')
        weights = rearrange(weights, 'b i j -> b j i')
        attended = torch.einsum('bij,bjk->bik', v, weights)
        attended = rearrange(attended, 'b c (h w) -> b c h w', h=h)
        attended = self.proj_out(attended)

        return x + attended

View File

@@ -0,0 +1,630 @@
import torch
import torch.nn as nn
import kornia
import open_clip
import math
from torch.utils.checkpoint import checkpoint
from transformers import T5Tokenizer, T5EncoderModel, CLIPTokenizer, CLIPTextModel
from unifolm_wma.utils.common import autocast
from unifolm_wma.utils.utils import count_params
from unifolm_wma.modules.encoders.resampler import reshape_tensor
class AbstractEncoder(nn.Module):
    """Base class for conditioning encoders; subclasses implement ``encode``."""

    def __init__(self):
        super().__init__()

    def encode(self, *args, **kwargs):
        """Encode the given input. Must be overridden."""
        raise NotImplementedError
class IdentityEncoder(AbstractEncoder):
    """Encoder that passes its input through unchanged."""

    def encode(self, x):
        """Return ``x`` as is."""
        return x
class ClassEmbedder(nn.Module):
    """Embeds class labels, randomly replacing some with the last class id.

    The last index (``n_classes - 1``) is used as the unconditional class.
    """

    def __init__(self, embed_dim, n_classes=1000, key='class', ucg_rate=0.1):
        """
        Args:
            embed_dim: size of each class embedding.
            n_classes: number of embedding rows.
            key: default batch key holding the labels.
            ucg_rate: probability of replacing a label with the last class id.
        """
        super().__init__()
        self.key = key
        self.embedding = nn.Embedding(n_classes, embed_dim)
        self.n_classes = n_classes
        self.ucg_rate = ucg_rate

    def forward(self, batch, key=None, disable_dropout=False):
        """Look up ``batch[key]`` and return embeddings with a length-1 token axis."""
        lookup_key = self.key if key is None else key
        # this is for use in crossattn
        c = batch[lookup_key][:, None]
        drop_labels = self.ucg_rate > 0. and not disable_dropout
        if drop_labels:
            # keep == 1 keeps the label, keep == 0 swaps in the last class id.
            keep = 1. - torch.bernoulli(torch.ones_like(c) * self.ucg_rate)
            c = keep * c + (1 - keep) * torch.ones_like(c) * (self.n_classes -
                                                              1)
            c = c.long()
        return self.embedding(c)

    def get_unconditional_conditioning(self, bs, device="cuda"):
        """Return ``{key: tensor}`` with ``bs`` copies of the last class id."""
        # 1000 classes --> 0 ... 999, one extra class for ucg (class 1000)
        null_class = self.n_classes - 1
        labels = torch.ones((bs, ), device=device) * null_class
        return {self.key: labels}
def disabled_train(self, mode=True):
    """Replacement for ``model.train`` that ignores ``mode``.

    Assigning this over a model's ``train`` keeps later train/eval switches
    from changing the model's mode.  Returns ``self`` unchanged.
    """
    del mode  # intentionally unused
    return self
class FrozenT5Embedder(AbstractEncoder):
    """Text encoder built on a pretrained T5 encoder."""

    def __init__(self,
                 version="google/t5-v1_1-xxl",
                 device="cuda",
                 max_length=77,
                 freeze=True
                 ):  # others are google/t5-v1_1-xl and google/t5-v1_1-xxl
        """
        Args:
            version: model identifier passed to ``from_pretrained``.
            device: device the token ids are moved to in ``forward``.
            max_length: texts are truncated/padded to this many tokens.
            freeze: put the transformer in eval mode and disable gradients.
        """
        super().__init__()
        self.tokenizer = T5Tokenizer.from_pretrained(version)
        self.transformer = T5EncoderModel.from_pretrained(version)
        self.device = device
        self.max_length = max_length  # TODO: typical value?
        if freeze:
            self.freeze()

    def freeze(self):
        """Switch the transformer to eval mode and stop gradients on all parameters."""
        self.transformer = self.transformer.eval()
        # self.train = disabled_train
        for p in self.parameters():
            p.requires_grad = False

    def forward(self, text):
        """Tokenize ``text`` and return the encoder's last hidden state."""
        encoded = self.tokenizer(text,
                                 truncation=True,
                                 max_length=self.max_length,
                                 return_length=True,
                                 return_overflowing_tokens=False,
                                 padding="max_length",
                                 return_tensors="pt")
        input_ids = encoded["input_ids"].to(self.device)
        return self.transformer(input_ids=input_ids).last_hidden_state

    def encode(self, text):
        """Alias for calling the module on ``text``."""
        return self(text)
class FrozenCLIPEmbedder(AbstractEncoder):
    """Text encoder built on the Hugging Face CLIP text model."""
    LAYERS = ["last", "pooled", "hidden"]

    def __init__(self,
                 version="openai/clip-vit-large-patch14",
                 device="cuda",
                 max_length=77,
                 freeze=True,
                 layer="last",
                 layer_idx=None):  # clip-vit-base-patch32
        """
        Args:
            version: model identifier passed to ``from_pretrained``.
            device: device the token ids are moved to in ``forward``.
            max_length: texts are truncated/padded to this many tokens.
            freeze: put the transformer in eval mode and disable gradients.
            layer: which output to return, one of ``LAYERS``.
            layer_idx: hidden-state index, required when ``layer == "hidden"``.
        """
        super().__init__()
        assert layer in self.LAYERS
        self.tokenizer = CLIPTokenizer.from_pretrained(version)
        self.transformer = CLIPTextModel.from_pretrained(version)
        self.device = device
        self.max_length = max_length
        if freeze:
            self.freeze()
        self.layer = layer
        self.layer_idx = layer_idx
        if layer == "hidden":
            assert layer_idx is not None
            assert 0 <= abs(layer_idx) <= 12

    def freeze(self):
        """Switch the transformer to eval mode and stop gradients on all parameters."""
        self.transformer = self.transformer.eval()
        # self.train = disabled_train
        for p in self.parameters():
            p.requires_grad = False

    def forward(self, text):
        """Tokenize ``text`` and return the output selected by ``self.layer``."""
        encoded = self.tokenizer(text,
                                 truncation=True,
                                 max_length=self.max_length,
                                 return_length=True,
                                 return_overflowing_tokens=False,
                                 padding="max_length",
                                 return_tensors="pt")
        input_ids = encoded["input_ids"].to(self.device)
        want_hidden = self.layer == "hidden"
        outputs = self.transformer(input_ids=input_ids,
                                   output_hidden_states=want_hidden)
        if self.layer == "last":
            return outputs.last_hidden_state
        if self.layer == "pooled":
            # Add a length-1 token axis to the pooled vector.
            return outputs.pooler_output[:, None, :]
        return outputs.hidden_states[self.layer_idx]

    def encode(self, text):
        """Alias for calling the module on ``text``."""
        return self(text)
class ClipImageEmbedder(nn.Module):
    """Image encoder built on a model loaded through the ``clip`` package."""

    def __init__(self,
                 model,
                 jit=False,
                 device='cuda' if torch.cuda.is_available() else 'cpu',
                 antialias=True,
                 ucg_rate=0.):
        """
        Args:
            model: model name passed to ``clip.load``.
            jit: forwarded to ``clip.load``.
            device: device the CLIP model is loaded on.
            antialias: forwarded to the resize in ``preprocess``.
            ucg_rate: probability of zeroing a sample's embedding in ``forward``.
        """
        super().__init__()
        from clip import load as load_clip
        self.model, _ = load_clip(name=model, device=device, jit=jit)

        self.antialias = antialias

        # Per-channel statistics for ``preprocess``; kept out of the state dict.
        clip_mean = torch.Tensor([0.48145466, 0.4578275, 0.40821073])
        clip_std = torch.Tensor([0.26862954, 0.26130258, 0.27577711])
        self.register_buffer('mean', clip_mean, persistent=False)
        self.register_buffer('std', clip_std, persistent=False)
        self.ucg_rate = ucg_rate

    def preprocess(self, x):
        """Resize to 224x224, map [-1, 1] to [0, 1] and normalise with mean/std."""
        resized = kornia.geometry.resize(x, (224, 224),
                                         interpolation='bicubic',
                                         align_corners=True,
                                         antialias=self.antialias)
        # normalize to [0,1]
        unit_range = (resized + 1.) / 2.
        # re-normalize according to clip
        return kornia.enhance.normalize(unit_range, self.mean, self.std)

    def forward(self, x, no_dropout=False):
        """Encode images; with ``ucg_rate > 0`` randomly zero whole embeddings."""
        # x is assumed to be in range [-1,1]
        out = self.model.encode_image(self.preprocess(x))
        out = out.to(x.dtype)
        if self.ucg_rate > 0. and not no_dropout:
            keep_prob = (1. - self.ucg_rate) * torch.ones(out.shape[0],
                                                          device=out.device)
            keep = torch.bernoulli(keep_prob)
            out = keep[:, None] * out
        return out
class FrozenOpenCLIPEmbedder(AbstractEncoder):
    """
    Text encoder built on the OpenCLIP text transformer.
    """
    LAYERS = [
        # "pooled",
        "last",
        "penultimate"
    ]

    def __init__(self,
                 arch="ViT-H-14",
                 version="laion2b_s32b_b79k",
                 device="cuda",
                 max_length=77,
                 freeze=True,
                 layer="last"):
        """
        Args:
            arch: OpenCLIP architecture name.
            version: pretrained weights tag.
            device: device the tokens are moved to in ``forward``.
            max_length: stored on the module; tokenization uses OpenCLIP's default.
            freeze: put the model in eval mode and disable gradients.
            layer: ``"last"`` runs every residual block, ``"penultimate"``
                stops one block early.
        """
        super().__init__()
        assert layer in self.LAYERS
        model, _, _ = open_clip.create_model_and_transforms(
            arch, device=torch.device('cpu'), pretrained=version)
        # Only the text tower is needed.
        del model.visual
        self.model = model

        self.device = device
        self.max_length = max_length
        if freeze:
            self.freeze()
        self.layer = layer
        # layer_idx is the number of trailing residual blocks to skip.
        if self.layer == "last":
            self.layer_idx = 0
        elif self.layer == "penultimate":
            self.layer_idx = 1
        else:
            raise NotImplementedError()

    def freeze(self):
        """Switch the model to eval mode and stop gradients on all parameters."""
        self.model = self.model.eval()
        for p in self.parameters():
            p.requires_grad = False

    def forward(self, text):
        """Tokenize ``text`` and return per-token features."""
        ## all clip models use 77 as context length
        tokens = open_clip.tokenize(text)
        return self.encode_with_transformer(tokens.to(self.device))

    def encode_with_transformer(self, text):
        """Embed token ids, run the residual blocks and apply the final layer norm."""
        x = self.model.token_embedding(text)  # [batch_size, n_ctx, d_model]
        x = x + self.model.positional_embedding
        x = x.permute(1, 0, 2)  # NLD -> LND
        x = self.text_transformer_forward(x, attn_mask=self.model.attn_mask)
        x = x.permute(1, 0, 2)  # LND -> NLD
        return self.model.ln_final(x)

    def text_transformer_forward(self, x: torch.Tensor, attn_mask=None):
        """Run the residual blocks, skipping the last ``layer_idx`` of them."""
        blocks = self.model.transformer.resblocks
        stop_at = len(blocks) - self.layer_idx
        for i, block in enumerate(blocks):
            if i == stop_at:
                break
            use_ckpt = (self.model.transformer.grad_checkpointing
                        and not torch.jit.is_scripting())
            if use_ckpt:
                x = checkpoint(block, x, attn_mask)
            else:
                x = block(x, attn_mask=attn_mask)
        return x

    def encode(self, text):
        """Alias for calling the module on ``text``."""
        return self(text)
class FrozenOpenCLIPImageEmbedder(AbstractEncoder):
    """
    Image encoder built on the OpenCLIP vision tower.
    """

    def __init__(self,
                 arch="ViT-H-14",
                 version="laion2b_s32b_b79k",
                 device="cuda",
                 max_length=77,
                 freeze=True,
                 layer="pooled",
                 antialias=True,
                 ucg_rate=0.):
        """
        Args:
            arch: OpenCLIP architecture name.
            version: pretrained weights tag.
            device: stored on the module.
            max_length: stored on the module.
            freeze: put the model in eval mode and disable gradients.
            layer: ``"penultimate"`` is rejected with NotImplementedError.
            antialias: forwarded to the resize in ``preprocess``.
            ucg_rate: probability of zeroing a sample's embedding in ``forward``.
        """
        super().__init__()
        model, _, _ = open_clip.create_model_and_transforms(
            arch,
            device=torch.device('cpu'),
            pretrained=version,
        )
        # Only the vision tower is needed.
        del model.transformer
        self.model = model
        # self.mapper = torch.nn.Linear(1280, 1024)
        self.device = device
        self.max_length = max_length
        if freeze:
            self.freeze()
        self.layer = layer
        if self.layer == "penultimate":
            raise NotImplementedError()
            self.layer_idx = 1  # unreachable: kept from the original after the raise

        self.antialias = antialias

        # Per-channel statistics for ``preprocess``; kept out of the state dict.
        clip_mean = torch.Tensor([0.48145466, 0.4578275, 0.40821073])
        clip_std = torch.Tensor([0.26862954, 0.26130258, 0.27577711])
        self.register_buffer('mean', clip_mean, persistent=False)
        self.register_buffer('std', clip_std, persistent=False)
        self.ucg_rate = ucg_rate

    def preprocess(self, x):
        """Resize to 224x224, map [-1, 1] to [0, 1] and normalise with mean/std."""
        resized = kornia.geometry.resize(x, (224, 224),
                                         interpolation='bicubic',
                                         align_corners=True,
                                         antialias=self.antialias)
        # normalize to [0,1]
        unit_range = (resized + 1.) / 2.
        # renormalize according to clip
        return kornia.enhance.normalize(unit_range, self.mean, self.std)

    def freeze(self):
        """Switch the model to eval mode and stop gradients on its parameters."""
        self.model = self.model.eval()
        for p in self.model.parameters():
            p.requires_grad = False

    @autocast
    def forward(self, image, no_dropout=False):
        """Encode images; with ``ucg_rate > 0`` randomly zero whole embeddings."""
        z = self.encode_with_vision_transformer(image)
        if self.ucg_rate > 0. and not no_dropout:
            keep_prob = (1. - self.ucg_rate) * torch.ones(z.shape[0],
                                                          device=z.device)
            keep = torch.bernoulli(keep_prob)
            z = keep[:, None] * z
        return z

    def encode_with_vision_transformer(self, img):
        """Preprocess ``img`` and run the vision tower."""
        return self.model.visual(self.preprocess(img))

    def encode(self, text):
        """Alias for calling the module; the argument is passed to ``forward`` as the image."""
        return self(text)
class FrozenOpenCLIPImageEmbedderV2(AbstractEncoder):
    """
    Uses the OpenCLIP vision transformer encoder for images.

    Unlike calling ``model.visual`` directly, this variant re-implements the
    visual forward pass step by step and returns the token sequence produced
    by the visual transformer (class token followed by patch tokens).
    """

    def __init__(self,
                 arch="ViT-H-14",
                 version="laion2b_s32b_b79k",
                 device="cuda",
                 freeze=True,
                 layer="pooled",
                 antialias=True):
        """
        Args:
            arch: OpenCLIP architecture name.
            version: Pretrained weights tag (passed as ``pretrained``).
            device: Stored on the instance; the model itself is created on CPU.
            freeze: If True, put the model in eval mode and disable gradients.
            layer: Output selector; ``"penultimate"`` is rejected.
            antialias: Forwarded to the resize call in ``preprocess``.

        Raises:
            NotImplementedError: If ``layer == "penultimate"``.
        """
        super().__init__()
        model, _, _ = open_clip.create_model_and_transforms(
            arch,
            device=torch.device('cpu'),
            pretrained=version,
        )
        # Only the visual tower is used; drop the text transformer.
        del model.transformer
        self.model = model
        self.device = device
        if freeze:
            self.freeze()
        self.layer = layer
        if self.layer == "penultimate":
            # Nothing may follow this raise: the former `layer_idx`
            # assignment placed after it could never execute.
            raise NotImplementedError()
        self.antialias = antialias
        # Per-channel statistics used to renormalise images for CLIP.
        # Non-persistent: they are not written to the state dict.
        self.register_buffer('mean',
                             torch.Tensor([0.48145466, 0.4578275, 0.40821073]),
                             persistent=False)
        self.register_buffer('std',
                             torch.Tensor([0.26862954, 0.26130258,
                                           0.27577711]),
                             persistent=False)

    def preprocess(self, x):
        """Resize to 224x224, rescale with ``(x + 1) / 2`` and CLIP-normalise.

        NOTE(review): the rescale suggests inputs are expected in [-1, 1];
        confirm against callers.
        """
        x = kornia.geometry.resize(x, (224, 224),
                                   interpolation='bicubic',
                                   align_corners=True,
                                   antialias=self.antialias)
        # normalize to [0,1]
        x = (x + 1.) / 2.
        # renormalize according to clip
        x = kornia.enhance.normalize(x, self.mean, self.std)
        return x

    def freeze(self):
        """Switch the wrapped model to eval mode and disable its gradients."""
        self.model = self.model.eval()
        for param in self.model.parameters():
            param.requires_grad = False

    def forward(self, image, no_dropout=False):
        """Return visual transformer tokens for ``image`` (b c h w).

        ``no_dropout`` is accepted for interface compatibility and is not
        used by this class.
        """
        ## image: b c h w
        z = self.encode_with_vision_transformer(image)
        return z

    def encode_with_vision_transformer(self, x):
        """Run the visual tower manually and return all output tokens.

        Steps: preprocess, patch embedding, prepend class embedding, add
        positional embedding, patch dropout, pre-LayerNorm, transformer.

        Returns:
            Tensor laid out as (batch, grid ** 2 + 1, width).
        """
        visual = self.model.visual
        x = self.preprocess(x)
        # to patches - whether to use dual patchnorm - https://arxiv.org/abs/2302.01327v1
        if visual.input_patchnorm:
            # einops - rearrange(x, 'b c (h p1) (w p2) -> b (h w) (c p1 p2)')
            x = x.reshape(x.shape[0], x.shape[1],
                          visual.grid_size[0],
                          visual.patch_size[0],
                          visual.grid_size[1],
                          visual.patch_size[1])
            x = x.permute(0, 2, 4, 1, 3, 5)
            x = x.reshape(x.shape[0],
                          visual.grid_size[0] * visual.grid_size[1], -1)
            x = visual.patchnorm_pre_ln(x)
            x = visual.conv1(x)
        else:
            x = visual.conv1(x)  # shape = [*, width, grid, grid]
            x = x.reshape(x.shape[0], x.shape[1],
                          -1)  # shape = [*, width, grid ** 2]
            x = x.permute(0, 2, 1)  # shape = [*, grid ** 2, width]
        # class embeddings and positional embeddings
        # Adding zeros broadcasts the class embedding to one token per sample.
        cls_token = visual.class_embedding.to(x.dtype) + torch.zeros(
            x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device)
        x = torch.cat([cls_token, x],
                      dim=1)  # shape = [*, grid ** 2 + 1, width]
        x = x + visual.positional_embedding.to(x.dtype)
        # a patch_dropout of 0. would mean it is disabled and this function would do nothing but return what was passed in
        x = visual.patch_dropout(x)
        x = visual.ln_pre(x)
        x = x.permute(1, 0, 2)  # NLD -> LND
        x = visual.transformer(x)
        x = x.permute(1, 0, 2)  # LND -> NLD
        return x
class FrozenCLIPT5Encoder(AbstractEncoder):
    """Pair of text encoders (CLIP and T5) applied to the same input."""

    def __init__(self,
                 clip_version="openai/clip-vit-large-patch14",
                 t5_version="google/t5-v1_1-xl",
                 device="cuda",
                 clip_max_length=77,
                 t5_max_length=77):
        """Build both encoders and print their parameter counts.

        Args:
            clip_version: Version string handed to ``FrozenCLIPEmbedder``.
            t5_version: Version string handed to ``FrozenT5Embedder``.
            device: Device argument handed to both encoders.
            clip_max_length: ``max_length`` for the CLIP encoder.
            t5_max_length: ``max_length`` for the T5 encoder.
        """
        super().__init__()
        clip = FrozenCLIPEmbedder(clip_version,
                                  device,
                                  max_length=clip_max_length)
        self.clip_encoder = clip
        t5 = FrozenT5Embedder(t5_version, device, max_length=t5_max_length)
        self.t5_encoder = t5
        print(
            f"{self.clip_encoder.__class__.__name__} has {count_params(self.clip_encoder) * 1.e-6:.2f} M parameters, "
            f"{self.t5_encoder.__class__.__name__} comes with {count_params(self.t5_encoder) * 1.e-6:.2f} M params."
        )

    def forward(self, text):
        """Return ``[clip_embedding, t5_embedding]`` for ``text``."""
        return [
            self.clip_encoder.encode(text),
            self.t5_encoder.encode(text),
        ]

    def encode(self, text):
        """Alias for calling the module."""
        return self(text)
class LinearProjector(nn.Module):
    """Single affine layer mapping ``input_dim`` features to ``output_dim``."""

    def __init__(self, input_dim: int, output_dim: int) -> None:
        super().__init__()
        self.projector = nn.Linear(in_features=input_dim,
                                   out_features=output_dim,
                                   bias=True)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the linear projection to the last dimension of ``x``."""
        projected = self.projector(x)
        return projected
class MLPProjector(nn.Module):
    """Two-layer MLP projector with a selectable activation."""

    def __init__(self,
                 input_dim: int,
                 output_dim: int,
                 mlp_type: str = "gelu-mlp") -> None:
        """
        Args:
            input_dim: Size of the incoming feature dimension.
            output_dim: Size of the hidden and output feature dimensions.
            mlp_type: ``"gelu-mlp"`` (tanh-approximated GELU) or
                ``"silu-mlp"``.

        Raises:
            ValueError: If ``mlp_type`` is not one of the supported names.
        """
        super().__init__()
        if mlp_type not in ("gelu-mlp", "silu-mlp"):
            raise ValueError(
                f"Projector with `{mlp_type = }` is not supported!")
        # Layers are created in the same order they are applied.
        first = nn.Linear(input_dim, output_dim, bias=True)
        if mlp_type == "gelu-mlp":
            activation = nn.GELU(approximate='tanh')
        else:
            activation = nn.SiLU()
        second = nn.Linear(output_dim, output_dim, bias=True)
        self.projector = nn.Sequential(first, activation, second)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Project the last dimension of ``x`` through the MLP."""
        out = self.projector(x)
        return out
class PerceiverAttention(nn.Module):
    """Cross-attention in which latent tokens attend to inputs and themselves."""

    def __init__(self, *, dim, dim_head=64, heads=8):
        """
        Args:
            dim: Feature size of both the inputs and the latents.
            dim_head: Feature size of each attention head.
            heads: Number of attention heads.
        """
        super().__init__()
        # `scale` is stored but forward() derives its own scaling factor.
        self.scale = dim_head**-0.5
        self.dim_head = dim_head
        self.heads = heads
        width = dim_head * heads
        self.norm1 = nn.LayerNorm(dim)
        self.norm2 = nn.LayerNorm(dim)
        self.to_q = nn.Linear(dim, width, bias=False)
        self.to_kv = nn.Linear(dim, width * 2, bias=False)
        self.to_out = nn.Linear(width, dim, bias=False)

    def forward(self, x, latents):
        """
        Args:
            x (torch.Tensor): image features
                shape (b, n1, D)
            latent (torch.Tensor): latent features
                shape (b, n2, D)
        """
        x = self.norm1(x)
        latents = self.norm2(latents)
        batch, num_latents, _ = latents.shape

        # Queries come from the latents; keys/values from [x ; latents].
        queries = reshape_tensor(self.to_q(latents), self.heads)
        keys, values = self.to_kv(torch.cat((x, latents),
                                            dim=-2)).chunk(2, dim=-1)
        keys = reshape_tensor(keys, self.heads)
        values = reshape_tensor(values, self.heads)

        # attention
        # Scaling q and k separately by dim_head ** -0.25 is
        # more stable with f16 than dividing afterwards.
        factor = 1 / math.sqrt(math.sqrt(self.dim_head))
        logits = (queries * factor) @ (keys * factor).transpose(-2, -1)
        probs = torch.softmax(logits.float(), dim=-1).type(logits.dtype)
        mixed = probs @ values

        mixed = mixed.permute(0, 2, 1, 3).reshape(batch, num_latents, -1)
        return self.to_out(mixed)
def FeedForward(dim, mult=4, ffd_type="gelu-ffd"):
    """Build a LayerNorm -> Linear -> activation -> Linear feed-forward block.

    Args:
        dim: Input and output feature size.
        mult: Expansion factor; the hidden size is ``int(dim * mult)``.
        ffd_type: ``"gelu-ffd"`` (tanh-approximated GELU) or ``"silu-ffd"``.

    Returns:
        An ``nn.Sequential`` whose two linear layers have no bias.

    Raises:
        ValueError: If ``ffd_type`` is not one of the supported names.
    """
    if ffd_type == "gelu-ffd":
        activation = nn.GELU(approximate='tanh')
    elif ffd_type == "silu-ffd":
        activation = nn.SiLU()
    else:
        # Report the offending `ffd_type`; interpolating any other name here
        # would turn this into a NameError instead of a ValueError.
        raise ValueError(
            f"FeedForward with `{ffd_type = }` is not supported!")
    inner_dim = int(dim * mult)
    return nn.Sequential(
        nn.LayerNorm(dim),
        nn.Linear(dim, inner_dim, bias=False),
        activation,
        nn.Linear(inner_dim, dim, bias=False),
    )
class SATokenProjector(nn.Module):
    """Perceiver-style projector that distils inputs into learned query tokens.

    A fixed set of learned latent tokens repeatedly attends to the input
    sequence, and the result is projected to ``output_dim`` and normalised.
    """

    def __init__(self,
                 dim=1024,
                 depth=1,
                 dim_head=64,
                 heads=16,
                 num_queries=16,
                 output_dim=1024,
                 ff_mult=4,
                 chunk_size=None):
        """
        Args:
            dim: Feature size of the inputs and of the latent tokens.
            depth: Number of (attention, feed-forward) layer pairs.
            dim_head: Per-head feature size of each attention layer.
            heads: Number of attention heads.
            num_queries: Number of latent tokens per chunk.
            output_dim: Feature size of the returned tokens.
            ff_mult: Hidden-size multiplier of each feed-forward block.
            chunk_size: If given, the number of latent tokens becomes
                ``num_queries * chunk_size``.
        """
        super().__init__()
        self.num_queries = num_queries
        self.chunk_size = chunk_size
        if chunk_size is not None:
            num_queries = num_queries * chunk_size
        self.latents = nn.Parameter(
            torch.randn(1, num_queries, dim) / dim**0.5)
        self.proj_out = nn.Linear(dim, output_dim)
        # The norm runs AFTER proj_out, so it must be sized by `output_dim`.
        # Sizing it by `dim` only works when both happen to be equal.
        self.norm_out = nn.LayerNorm(output_dim)
        self.layers = nn.ModuleList([])
        for _ in range(depth):
            self.layers.append(
                nn.ModuleList([
                    PerceiverAttention(dim=dim, dim_head=dim_head,
                                       heads=heads),
                    FeedForward(dim=dim, mult=ff_mult),
                ]))

    def forward(self, x):
        """Return the projected latent tokens for the input batch ``x``.

        Args:
            x: Input features; the first dimension is the batch size and the
                last dimension is expected to be ``dim``.

        Returns:
            Tensor with one row of latent tokens per batch element and
            ``output_dim`` features.
        """
        # One copy of the learned latents per batch element.
        latents = self.latents.repeat(x.size(0), 1, 1)
        for attn, ff in self.layers:
            latents = attn(x, latents) + latents
            latents = ff(latents) + latents
        latents = self.proj_out(latents)
        latents = self.norm_out(latents)
        return latents

View File

@@ -0,0 +1,153 @@
# modified from https://github.com/mlfoundations/open_flamingo/blob/main/open_flamingo/src/helpers.py
# and https://github.com/lucidrains/imagen-pytorch/blob/main/imagen_pytorch/imagen_pytorch.py
# and https://github.com/tencent-ailab/IP-Adapter/blob/main/ip_adapter/resampler.py
import math
import torch
import torch.nn as nn
class ImageProjModel(nn.Module):
    """Project one image embedding into several cross-attention tokens."""

    def __init__(self,
                 cross_attention_dim=1024,
                 clip_embeddings_dim=1024,
                 clip_extra_context_tokens=4):
        """
        Args:
            cross_attention_dim: Feature size of every produced token.
            clip_embeddings_dim: Feature size of the incoming embedding.
            clip_extra_context_tokens: Number of tokens produced per embedding.
        """
        super().__init__()
        self.cross_attention_dim = cross_attention_dim
        self.clip_extra_context_tokens = clip_extra_context_tokens
        flat_width = self.clip_extra_context_tokens * cross_attention_dim
        self.proj = nn.Linear(clip_embeddings_dim, flat_width)
        self.norm = nn.LayerNorm(cross_attention_dim)

    def forward(self, image_embeds):
        """Return normalised tokens of shape (-1, tokens, cross_attention_dim)."""
        # Cast the input to the dtype of the projection's first parameter.
        target_dtype = next(self.proj.parameters()).dtype
        flat = self.proj(image_embeds.type(target_dtype))
        tokens = flat.reshape(-1, self.clip_extra_context_tokens,
                              self.cross_attention_dim)
        return self.norm(tokens)
# FFN
def FeedForward(dim, mult=4):
    """Build a LayerNorm -> Linear -> GELU -> Linear block (bias-free linears).

    The hidden width is ``int(dim * mult)``; input and output width are
    both ``dim``.
    """
    hidden_dim = int(dim * mult)
    stages = [
        nn.LayerNorm(dim),
        nn.Linear(dim, hidden_dim, bias=False),
        nn.GELU(),
        nn.Linear(hidden_dim, dim, bias=False),
    ]
    return nn.Sequential(*stages)
def reshape_tensor(x, heads):
    """Split the last dimension of a 3-D tensor into attention heads.

    Args:
        x: Tensor of shape (bs, length, width).
        heads: Number of heads; ``width`` must be divisible by it.

    Returns:
        Tensor of shape (bs, heads, length, width // heads). The result stays
        4-D: batch and head dimensions are not merged.
    """
    bs, length, _ = x.shape
    # (bs, length, width) -> (bs, length, heads, dim_per_head)
    per_head = x.view(bs, length, heads, -1)
    # (bs, length, heads, dim_per_head) -> (bs, heads, length, dim_per_head)
    per_head = per_head.transpose(1, 2)
    # Same 4-D shape as above; the reshape does not fold heads into batch.
    return per_head.reshape(bs, heads, length, -1)
class PerceiverAttention(nn.Module):
    """Cross-attention in which latent tokens attend to inputs and themselves."""

    def __init__(self, *, dim, dim_head=64, heads=8):
        """
        Args:
            dim: Feature size of both the inputs and the latents.
            dim_head: Feature size of each attention head.
            heads: Number of attention heads.
        """
        super().__init__()
        # `scale` is stored but forward() derives its own scaling factor.
        self.scale = dim_head**-0.5
        self.dim_head = dim_head
        self.heads = heads
        width = dim_head * heads
        self.norm1 = nn.LayerNorm(dim)
        self.norm2 = nn.LayerNorm(dim)
        self.to_q = nn.Linear(dim, width, bias=False)
        self.to_kv = nn.Linear(dim, width * 2, bias=False)
        self.to_out = nn.Linear(width, dim, bias=False)

    def forward(self, x, latents):
        """
        Args:
            x (torch.Tensor): image features
                shape (b, n1, D)
            latent (torch.Tensor): latent features
                shape (b, n2, D)

        Returns:
            torch.Tensor: updated latent features, shape (b, n2, D)
        """
        x = self.norm1(x)
        latents = self.norm2(latents)
        batch, num_latents, _ = latents.shape

        # Queries come from the latents; keys/values from [x ; latents].
        queries = reshape_tensor(self.to_q(latents), self.heads)
        keys, values = self.to_kv(torch.cat((x, latents),
                                            dim=-2)).chunk(2, dim=-1)
        keys = reshape_tensor(keys, self.heads)
        values = reshape_tensor(values, self.heads)

        # attention
        # Scaling q and k separately by dim_head ** -0.25 is
        # more stable with f16 than dividing afterwards.
        factor = 1 / math.sqrt(math.sqrt(self.dim_head))
        logits = (queries * factor) @ (keys * factor).transpose(-2, -1)
        probs = torch.softmax(logits.float(), dim=-1).type(logits.dtype)
        mixed = probs @ values

        # (b, heads, n2, dim_head) -> (b, n2, heads * dim_head)
        mixed = mixed.permute(0, 2, 1, 3).reshape(batch, num_latents, -1)
        return self.to_out(mixed)
class Resampler(nn.Module):
    """Perceiver resampler: compress an embedding sequence into query tokens."""

    def __init__(
        self,
        dim=1024,
        depth=8,
        dim_head=64,
        heads=16,
        num_queries=8,
        embedding_dim=768,
        output_dim=1024,
        ff_mult=4,
        video_length=None,  # using frame-wise version or not
    ):
        """
        Args:
            dim: Internal feature size of latents and attention layers.
            depth: Number of (attention, feed-forward) layer pairs.
            dim_head: Per-head feature size of each attention layer.
            heads: Number of attention heads.
            num_queries: Number of query tokens for a single frame / image.
            embedding_dim: Feature size of the incoming embeddings.
            output_dim: Feature size of the returned tokens.
            ff_mult: Hidden-size multiplier of each feed-forward block.
            video_length: If given, ``num_queries`` tokens are allocated for
                each of ``video_length`` frames.
        """
        super().__init__()
        # Keep the per-frame query count; the latent count may be larger.
        self.num_queries = num_queries
        self.video_length = video_length
        total_queries = num_queries
        if video_length is not None:
            total_queries = num_queries * video_length

        self.latents = nn.Parameter(
            torch.randn(1, total_queries, dim) / dim**0.5)
        self.proj_in = nn.Linear(embedding_dim, dim)
        self.proj_out = nn.Linear(dim, output_dim)
        self.norm_out = nn.LayerNorm(output_dim)

        self.layers = nn.ModuleList([])
        for _ in range(depth):
            attention = PerceiverAttention(dim=dim,
                                           dim_head=dim_head,
                                           heads=heads)
            feed_forward = FeedForward(dim=dim, mult=ff_mult)
            self.layers.append(nn.ModuleList([attention, feed_forward]))

    def forward(self, x):
        """Return normalised query tokens for the embedding batch ``x``."""
        batch = x.size(0)
        tokens = self.latents.repeat(batch, 1, 1)
        x = self.proj_in(x)
        for attention, feed_forward in self.layers:
            tokens = attention(x, tokens) + tokens
            tokens = feed_forward(tokens) + tokens
        return self.norm_out(self.proj_out(tokens))

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,848 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
from functools import partial
from abc import abstractmethod
from einops import rearrange
from omegaconf import OmegaConf
from typing import Optional, Sequence, Any, Tuple, Union, List, Dict
from collections.abc import Mapping, Iterable, Callable
from unifolm_wma.utils.diffusion import timestep_embedding
from unifolm_wma.utils.common import checkpoint
from unifolm_wma.utils.basics import (zero_module, conv_nd, linear,
avg_pool_nd, normalization)
from unifolm_wma.modules.attention import SpatialTransformer, TemporalTransformer
from unifolm_wma.utils.utils import instantiate_from_config
class TimestepBlock(nn.Module):
    """
    Any module where forward() takes timestep embeddings as a second argument.

    Serves as a marker base class: TimestepEmbedSequential uses
    ``isinstance(layer, TimestepBlock)`` to decide which children receive
    the embedding.
    """

    # NOTE: nn.Module does not use ABCMeta, so @abstractmethod documents the
    # intent here but is not enforced when a subclass is instantiated.
    @abstractmethod
    def forward(self, x, emb):
        """
        Apply the module to `x` given `emb` timestep embeddings.
        """
class TimestepEmbedSequential(nn.Sequential, TimestepBlock):
    """
    A sequential container that routes extra inputs to its children by type:
    timestep embeddings to TimestepBlock children, context to the
    spatial/temporal transformers, and nothing extra to any other layer.
    """

    def forward(self, x, emb, context=None, batch_size=None):
        """Apply every child in order.

        Args:
            x: Features laid out as ((b f), c, h, w).
            emb: Timestep embeddings for TimestepBlock children.
            context: Conditioning passed to transformer children.
            batch_size: Number of videos ``b``; needed to unfold the frame
                axis for TemporalTransformer children.
        """
        for module in self:
            if isinstance(module, TimestepBlock):
                x = module(x, emb, batch_size=batch_size)
                continue
            if isinstance(module, SpatialTransformer):
                x = module(x, context)
                continue
            if isinstance(module, TemporalTransformer):
                # Temporal layers work on (b, c, f, h, w); unfold, apply, refold.
                video = rearrange(x, '(b f) c h w -> b c f h w', b=batch_size)
                video = module(video, context)
                x = rearrange(video, 'b c f h w -> (b f) c h w')
                continue
            x = module(x)
        return x
class Downsample(nn.Module):
    """
    Stride-2 downsampling, done by a convolution or by average pooling.
    :param channels: channels in the inputs and outputs.
    :param use_conv: a bool determining if a convolution is applied.
    :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
                 downsampling occurs in the inner-two dimensions.
    :param out_channels: output channels; falls back to `channels` when falsy.
    :param padding: padding of the convolution (unused when pooling).
    """

    def __init__(self,
                 channels,
                 use_conv,
                 dims=2,
                 out_channels=None,
                 padding=1):
        super().__init__()
        self.channels = channels
        self.out_channels = out_channels or channels
        self.use_conv = use_conv
        self.dims = dims
        # For 3D signals the first of the three axes keeps its length.
        stride = (1, 2, 2) if dims == 3 else 2
        if not use_conv:
            # Pooling cannot change the channel count.
            assert self.channels == self.out_channels
            self.op = avg_pool_nd(dims, kernel_size=stride, stride=stride)
        else:
            self.op = conv_nd(dims,
                              self.channels,
                              self.out_channels,
                              3,
                              stride=stride,
                              padding=padding)

    def forward(self, x):
        """Downsample `x`, whose channel axis (dim 1) must equal `channels`."""
        assert x.shape[1] == self.channels
        return self.op(x)
class Upsample(nn.Module):
    """
    Nearest-neighbour 2x upsampling followed by an optional convolution.
    :param channels: channels in the inputs and outputs.
    :param use_conv: a bool determining if a convolution is applied.
    :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
                 upsampling occurs in the inner-two dimensions.
    :param out_channels: output channels; falls back to `channels` when falsy.
    :param padding: padding of the optional convolution.
    """

    def __init__(self,
                 channels,
                 use_conv,
                 dims=2,
                 out_channels=None,
                 padding=1):
        super().__init__()
        self.channels = channels
        self.out_channels = out_channels or channels
        self.use_conv = use_conv
        self.dims = dims
        # The conv attribute only exists when a convolution was requested.
        if use_conv:
            self.conv = conv_nd(dims,
                                self.channels,
                                self.out_channels,
                                3,
                                padding=padding)

    def forward(self, x):
        """Upsample `x`, whose channel axis (dim 1) must equal `channels`."""
        assert x.shape[1] == self.channels
        if self.dims != 3:
            x = F.interpolate(x, scale_factor=2, mode='nearest')
        else:
            # Keep axis 2 as is and double only the last two axes.
            target_size = (x.shape[2], x.shape[3] * 2, x.shape[4] * 2)
            x = F.interpolate(x, target_size, mode='nearest')
        if self.use_conv:
            x = self.conv(x)
        return x
class ResBlock(TimestepBlock):
    """
    A residual block that can optionally change the number of channels.
    :param channels: the number of input channels.
    :param emb_channels: the number of timestep embedding channels.
    :param dropout: the rate of dropout.
    :param out_channels: if specified, the number of out channels.
    :param use_scale_shift_norm: if True, the embedding yields a scale and a
        shift applied after the output normalization instead of being added.
    :param dims: determines if the signal is 1D, 2D, or 3D.
    :param use_checkpoint: forwarded to `checkpoint` as its enable flag.
    :param use_conv: if True and out_channels is specified, use a spatial
        convolution instead of a smaller 1x1 convolution to change the
        channels in the skip connection.
    :param up: if True, use this block for upsampling.
    :param down: if True, use this block for downsampling.
    :param use_temporal_conv: if True, use the temporal convolution.
    :param tempspatial_aware: forwarded to TemporalConvBlock as
        `spatial_aware`.
    """

    def __init__(self,
                 channels,
                 emb_channels,
                 dropout,
                 out_channels=None,
                 use_scale_shift_norm=False,
                 dims=2,
                 use_checkpoint=False,
                 use_conv=False,
                 up=False,
                 down=False,
                 use_temporal_conv=False,
                 tempspatial_aware=False):
        super().__init__()
        self.channels = channels
        self.emb_channels = emb_channels
        self.dropout = dropout
        self.out_channels = out_channels or channels
        self.use_conv = use_conv
        self.use_checkpoint = use_checkpoint
        self.use_scale_shift_norm = use_scale_shift_norm
        self.use_temporal_conv = use_temporal_conv
        # norm -> SiLU -> 3x3 conv that maps channels to out_channels.
        self.in_layers = nn.Sequential(
            normalization(channels),
            nn.SiLU(),
            conv_nd(dims, channels, self.out_channels, 3, padding=1),
        )
        self.updown = up or down
        # Two resamplers: one for the residual branch (h) and one for the
        # skip branch (x). `up` wins if both flags are set.
        if up:
            self.h_upd = Upsample(channels, False, dims)
            self.x_upd = Upsample(channels, False, dims)
        elif down:
            self.h_upd = Downsample(channels, False, dims)
            self.x_upd = Downsample(channels, False, dims)
        else:
            self.h_upd = self.x_upd = nn.Identity()
        # Embedding projection; twice as wide when it has to provide both a
        # scale and a shift.
        self.emb_layers = nn.Sequential(
            nn.SiLU(),
            nn.Linear(
                emb_channels,
                2 * self.out_channels
                if use_scale_shift_norm else self.out_channels,
            ),
        )
        # NOTE(review): the last layer is a hard-coded nn.Conv2d, whereas
        # in_layers and skip_connection honour `dims` through conv_nd;
        # confirm whether dims != 2 is meant to be supported here.
        # zero_module presumably zero-initialises the conv - TODO confirm.
        self.out_layers = nn.Sequential(
            normalization(self.out_channels),
            nn.SiLU(),
            nn.Dropout(p=dropout),
            zero_module(
                nn.Conv2d(self.out_channels, self.out_channels, 3, padding=1)),
        )
        # Skip path: identity when channel counts match, otherwise a 3x3
        # (use_conv) or 1x1 convolution.
        if self.out_channels == channels:
            self.skip_connection = nn.Identity()
        elif use_conv:
            self.skip_connection = conv_nd(dims,
                                           channels,
                                           self.out_channels,
                                           3,
                                           padding=1)
        else:
            self.skip_connection = conv_nd(dims, channels, self.out_channels,
                                           1)
        if self.use_temporal_conv:
            # The attribute name is misspelled ("temopral"); it is also the
            # sub-module name used in state-dict keys, so renaming it would
            # change checkpoint keys.
            # The temporal block's dropout is fixed at 0.1 and does not follow
            # the `dropout` argument.
            self.temopral_conv = TemporalConvBlock(
                self.out_channels,
                self.out_channels,
                dropout=0.1,
                spatial_aware=tempspatial_aware)

    def forward(self, x, emb, batch_size=None):
        """
        Apply the block to a Tensor, conditioned on a timestep embedding.
        :param x: an [N x C x ...] Tensor of features.
        :param emb: an [N x emb_channels] Tensor of timestep embeddings.
        :param batch_size: number of videos in the batch; when truthy it is
            bound into `_forward` so the temporal convolution can run.
        :return: an [N x C x ...] Tensor of outputs.
        """
        input_tuple = (x, emb)
        if batch_size:
            # batch_size is bound with partial so that only (x, emb) are
            # passed to `checkpoint` as inputs.
            forward_batchsize = partial(self._forward, batch_size=batch_size)
            return checkpoint(forward_batchsize, input_tuple,
                              self.parameters(), self.use_checkpoint)
        return checkpoint(self._forward, input_tuple, self.parameters(),
                          self.use_checkpoint)

    def _forward(self, x, emb, batch_size=None):
        """Compute the block output; see `forward` for the arguments."""
        if self.updown:
            # Resample between the norm/activation and the convolution, and
            # resample the skip input the same way.
            in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1]
            h = in_rest(x)
            h = self.h_upd(h)
            x = self.x_upd(x)
            h = in_conv(h)
        else:
            h = self.in_layers(x)
        emb_out = self.emb_layers(emb).type(h.dtype)
        # Append trailing singleton axes until emb_out broadcasts against h.
        while len(emb_out.shape) < len(h.shape):
            emb_out = emb_out[..., None]
        if self.use_scale_shift_norm:
            # Modulate the normalized features: norm(h) * (1 + scale) + shift.
            out_norm, out_rest = self.out_layers[0], self.out_layers[1:]
            scale, shift = torch.chunk(emb_out, 2, dim=1)
            h = out_norm(h) * (1 + scale) + shift
            h = out_rest(h)
        else:
            h = h + emb_out
            h = self.out_layers(h)
        h = self.skip_connection(x) + h
        if self.use_temporal_conv and batch_size:
            # Unfold the frame axis, convolve over time, then fold it back.
            h = rearrange(h, '(b t) c h w -> b c t h w', b=batch_size)
            h = self.temopral_conv(h)
            h = rearrange(h, 'b c t h w -> (b t) c h w')
        return h
class TemporalConvBlock(nn.Module):
    """
    Four stacked GroupNorm/SiLU/Conv3d stages with a residual connection.
    Adapted from modelscope: https://github.com/modelscope/modelscope/blob/master/modelscope/models/multi_modal/video_synthesis/unet_sd.py
    """

    def __init__(self,
                 in_channels,
                 out_channels=None,
                 dropout=0.0,
                 spatial_aware=False):
        """
        Args:
            in_channels: Channels of the input (and of the residual output).
            out_channels: Channels produced by the first stage; defaults to
                ``in_channels``.
            dropout: Dropout rate used in stages 2-4.
            spatial_aware: If True, kernels also span one spatial axis
                ((3, 3, 1) and (3, 1, 3)) instead of time only ((3, 1, 1)).
        """
        super().__init__()
        out_channels = in_channels if out_channels is None else out_channels
        self.in_channels = in_channels
        self.out_channels = out_channels

        if spatial_aware:
            th_kernel, th_pad = (3, 3, 1), (1, 1, 0)
            tw_kernel, tw_pad = (3, 1, 3), (1, 0, 1)
        else:
            th_kernel, th_pad = (3, 1, 1), (1, 0, 0)
            tw_kernel, tw_pad = (3, 1, 1), (1, 0, 0)

        def later_stage(kernel, pad):
            # Stages 2-4 share one layout: norm, activation, dropout, conv.
            return nn.Sequential(
                nn.GroupNorm(32, out_channels),
                nn.SiLU(),
                nn.Dropout(dropout),
                nn.Conv3d(out_channels, in_channels, kernel, padding=pad),
            )

        # conv layers
        self.conv1 = nn.Sequential(
            nn.GroupNorm(32, in_channels),
            nn.SiLU(),
            nn.Conv3d(in_channels, out_channels, th_kernel, padding=th_pad),
        )
        self.conv2 = later_stage(tw_kernel, tw_pad)
        self.conv3 = later_stage(th_kernel, th_pad)
        self.conv4 = later_stage(tw_kernel, tw_pad)

        # Zero out the last layer params, so the conv block is identity
        last_conv = self.conv4[-1]
        nn.init.zeros_(last_conv.weight)
        nn.init.zeros_(last_conv.bias)

    def forward(self, x):
        """Return ``x`` plus the output of the four convolution stages."""
        residual = x
        for stage in (self.conv1, self.conv2, self.conv3, self.conv4):
            x = stage(x)
        return residual + x
class WMAModel(nn.Module):
"""
The full World-Model-Action model.
"""
def __init__(self,
             in_channels: int,
             model_channels: int,
             out_channels: int,
             num_res_blocks: int,
             attention_resolutions: Sequence[int],
             dropout: float = 0.0,
             channel_mult: Sequence[int] = (1, 2, 4, 8),
             conv_resample: bool = True,
             dims: int = 2,
             context_dim: int | None = None,
             use_scale_shift_norm: bool = False,
             resblock_updown: bool = False,
             num_heads: int = -1,
             num_head_channels: int = -1,
             transformer_depth: int = 1,
             use_linear: bool = False,
             use_checkpoint: bool = False,
             temporal_conv: bool = False,
             tempspatial_aware: bool = False,
             temporal_attention: bool = True,
             use_relative_position: bool = True,
             use_causal_attention: bool = False,
             temporal_length: int | None = None,
             use_fp16: bool = False,
             addition_attention: bool = False,
             temporal_selfatt_only: bool = True,
             image_cross_attention: bool = False,
             cross_attention_scale_learnable: bool = False,
             default_fs: int = 4,
             fs_condition: bool = False,
             n_obs_steps: int = 1,
             num_stem_token: int = 1,
             unet_head_config: OmegaConf | None = None,
             stem_process_config: OmegaConf | None = None,
             base_model_gen_only: bool = False):
    """
    Initialize the World-Model-Action network.

    Builds the time embedding, the encoder (`input_blocks`), the
    `middle_block`, the decoder (`output_blocks`), the output head, two
    prediction heads (`action_unet`, `state_unet`) and the action token
    projector.

    Args:
        in_channels: Number of input channels to the backbone.
        model_channels: Base channel width for the UNet/backbone.
        out_channels: Number of output channels.
        num_res_blocks: Number of residual blocks per resolution stage.
        attention_resolutions: Downsampling factors (1, 2, 4, ...) at which attention layers are added.
        dropout: Dropout probability passed to the residual blocks.
        channel_mult: Multipliers for channels at each resolution level.
        conv_resample: If True, use convolutional resampling for up/down sampling.
        dims: Spatial dimensionality of the backbone (1/2/3).
        context_dim: Optional context embedding dimension (for cross-attention).
        use_scale_shift_norm: Enable scale-shift (FiLM-style) normalization in blocks.
        resblock_updown: Use residual blocks for up/down sampling (instead of plain conv).
        num_heads: Number of attention heads (if >= 0). If -1, derive from num_head_channels.
        num_head_channels: Channels per attention head (if >= 0). If -1, derive from num_heads.
        transformer_depth: Number of transformer/attention blocks per stage.
        use_linear: Forwarded as `use_linear` to the spatial/temporal transformers.
        use_checkpoint: Enable gradient checkpointing in blocks to save memory.
        temporal_conv: Include temporal convolution along the time dimension.
        tempspatial_aware: If True, use timespace aware blocks.
        temporal_attention: Add a temporal transformer after each spatial transformer.
        use_relative_position: Use relative position encodings in attention.
        use_causal_attention: Use causal (uni-directional) attention along time.
        temporal_length: Temporal length expected by the model. Must not be None,
            because it is multiplied by `num_stem_token` below.
        use_fp16: If True, `self.dtype` is set to float16, otherwise float32.
        addition_attention: Add an extra temporal transformer (`init_attn`).
        temporal_selfatt_only: Forwarded as `only_self_att` to `init_attn` only.
        image_cross_attention: Enable cross-attention with image embeddings.
        cross_attention_scale_learnable: Make cross-attention scaling a learnable parameter.
        default_fs: Default frame-stride / fps.
        fs_condition: If True, condition on frame-stride/fps features.
        n_obs_steps: Number of observed steps used in conditioning heads.
        num_stem_token: Number of stem tokens for action tokenization.
        unet_head_config: OmegaConf for UNet heads (e.g., action/state heads).
            Must not be None; its `params.context_dims` entry is overwritten in place.
        stem_process_config: OmegaConf for stem/preprocessor module.
        base_model_gen_only: Perform the generation using the base model without action and state outputs.
    """
    super(WMAModel, self).__init__()
    # At least one of the two head-sizing arguments must be provided.
    if num_heads == -1:
        assert num_head_channels != -1, 'Either num_heads or num_head_channels has to be set'
    if num_head_channels == -1:
        assert num_heads != -1, 'Either num_heads or num_head_channels has to be set'
    self.in_channels = in_channels
    self.model_channels = model_channels
    self.out_channels = out_channels
    self.num_res_blocks = num_res_blocks
    self.attention_resolutions = attention_resolutions
    self.dropout = dropout
    self.channel_mult = channel_mult
    self.conv_resample = conv_resample
    self.temporal_attention = temporal_attention
    time_embed_dim = model_channels * 4
    self.use_checkpoint = use_checkpoint
    self.dtype = torch.float16 if use_fp16 else torch.float32
    # NOTE(review): hard-coded to True and used for every TemporalTransformer
    # in the encoder, middle block and decoder; the `temporal_selfatt_only`
    # argument only reaches `init_attn`. Confirm this is intended.
    temporal_self_att_only = True
    self.addition_attention = addition_attention
    self.temporal_length = temporal_length
    self.image_cross_attention = image_cross_attention
    self.cross_attention_scale_learnable = cross_attention_scale_learnable
    self.default_fs = default_fs
    self.fs_condition = fs_condition
    self.n_obs_steps = n_obs_steps
    self.num_stem_token = num_stem_token
    self.base_model_gen_only = base_model_gen_only
    # Time embedding blocks
    self.time_embed = nn.Sequential(
        linear(model_channels, time_embed_dim),
        nn.SiLU(),
        linear(time_embed_dim, time_embed_dim),
    )
    if fs_condition:
        self.fps_embedding = nn.Sequential(
            linear(model_channels, time_embed_dim),
            nn.SiLU(),
            linear(time_embed_dim, time_embed_dim),
        )
        # Zero-initialise the last layer so the fps embedding starts at zero.
        nn.init.zeros_(self.fps_embedding[-1].weight)
        nn.init.zeros_(self.fps_embedding[-1].bias)
    # Input Block
    self.input_blocks = nn.ModuleList([
        TimestepEmbedSequential(
            conv_nd(dims, in_channels, model_channels, 3, padding=1))
    ])
    if self.addition_attention:
        # Uses a fixed n_heads=8 and d_head=num_head_channels, unlike the
        # per-level head sizing below.
        self.init_attn = TimestepEmbedSequential(
            TemporalTransformer(model_channels,
                                n_heads=8,
                                d_head=num_head_channels,
                                depth=transformer_depth,
                                context_dim=context_dim,
                                use_checkpoint=use_checkpoint,
                                only_self_att=temporal_selfatt_only,
                                causal_attention=False,
                                relative_position=use_relative_position,
                                temporal_length=temporal_length))
    # Channel count of every encoder output, consumed (popped) by the
    # decoder to size its skip connections.
    input_block_chans = [model_channels]
    ch = model_channels
    # Current downsampling factor; compared against attention_resolutions.
    ds = 1
    for level, mult in enumerate(channel_mult):
        for _ in range(num_res_blocks):
            layers = [
                ResBlock(ch,
                         time_embed_dim,
                         dropout,
                         out_channels=mult * model_channels,
                         dims=dims,
                         use_checkpoint=use_checkpoint,
                         use_scale_shift_norm=use_scale_shift_norm,
                         tempspatial_aware=tempspatial_aware,
                         use_temporal_conv=temporal_conv)
            ]
            ch = mult * model_channels
            if ds in attention_resolutions:
                # Derive whichever of (num_heads, dim_head) was not given.
                if num_head_channels == -1:
                    dim_head = ch // num_heads
                else:
                    num_heads = ch // num_head_channels
                    dim_head = num_head_channels
                layers.append(
                    SpatialTransformer(
                        ch,
                        num_heads,
                        dim_head,
                        depth=transformer_depth,
                        context_dim=context_dim,
                        use_linear=use_linear,
                        use_checkpoint=use_checkpoint,
                        disable_self_attn=False,
                        video_length=temporal_length,
                        agent_state_context_len=self.n_obs_steps,
                        agent_action_context_len=self.temporal_length *
                        num_stem_token,
                        image_cross_attention=self.image_cross_attention,
                        cross_attention_scale_learnable=self.
                        cross_attention_scale_learnable,
                    ))
                if self.temporal_attention:
                    layers.append(
                        TemporalTransformer(
                            ch,
                            num_heads,
                            dim_head,
                            depth=transformer_depth,
                            context_dim=context_dim,
                            use_linear=use_linear,
                            use_checkpoint=use_checkpoint,
                            only_self_att=temporal_self_att_only,
                            causal_attention=use_causal_attention,
                            relative_position=use_relative_position,
                            temporal_length=temporal_length))
            self.input_blocks.append(TimestepEmbedSequential(*layers))
            input_block_chans.append(ch)
        # Downsample between levels (not after the last one).
        if level != len(channel_mult) - 1:
            out_ch = ch
            self.input_blocks.append(
                TimestepEmbedSequential(
                    ResBlock(ch,
                             time_embed_dim,
                             dropout,
                             out_channels=out_ch,
                             dims=dims,
                             use_checkpoint=use_checkpoint,
                             use_scale_shift_norm=use_scale_shift_norm,
                             down=True)
                    if resblock_updown else Downsample(
                        ch, conv_resample, dims=dims, out_channels=out_ch))
            )
            ch = out_ch
            input_block_chans.append(ch)
            ds *= 2
    if num_head_channels == -1:
        dim_head = ch // num_heads
    else:
        num_heads = ch // num_head_channels
        dim_head = num_head_channels
    # Middle block layout: ResBlock, SpatialTransformer,
    # [TemporalTransformer], ResBlock.
    layers = [
        ResBlock(ch,
                 time_embed_dim,
                 dropout,
                 dims=dims,
                 use_checkpoint=use_checkpoint,
                 use_scale_shift_norm=use_scale_shift_norm,
                 tempspatial_aware=tempspatial_aware,
                 use_temporal_conv=temporal_conv),
        SpatialTransformer(
            ch,
            num_heads,
            dim_head,
            depth=transformer_depth,
            context_dim=context_dim,
            use_linear=use_linear,
            use_checkpoint=use_checkpoint,
            disable_self_attn=False,
            video_length=temporal_length,
            agent_state_context_len=self.n_obs_steps,
            agent_action_context_len=self.temporal_length * num_stem_token,
            image_cross_attention=self.image_cross_attention,
            cross_attention_scale_learnable=self.
            cross_attention_scale_learnable)
    ]
    if self.temporal_attention:
        layers.append(
            TemporalTransformer(ch,
                                num_heads,
                                dim_head,
                                depth=transformer_depth,
                                context_dim=context_dim,
                                use_linear=use_linear,
                                use_checkpoint=use_checkpoint,
                                only_self_att=temporal_self_att_only,
                                causal_attention=use_causal_attention,
                                relative_position=use_relative_position,
                                temporal_length=temporal_length))
    layers.append(
        ResBlock(ch,
                 time_embed_dim,
                 dropout,
                 dims=dims,
                 use_checkpoint=use_checkpoint,
                 use_scale_shift_norm=use_scale_shift_norm,
                 tempspatial_aware=tempspatial_aware,
                 use_temporal_conv=temporal_conv))
    # Middle Block
    self.middle_block = TimestepEmbedSequential(*layers)
    # Output Block
    self.output_blocks = nn.ModuleList([])
    # Walk the levels in reverse; each decoder block consumes one recorded
    # encoder channel count as its skip-connection width.
    for level, mult in list(enumerate(channel_mult))[::-1]:
        for i in range(num_res_blocks + 1):
            ich = input_block_chans.pop()
            layers = [
                ResBlock(ch + ich,
                         time_embed_dim,
                         dropout,
                         out_channels=mult * model_channels,
                         dims=dims,
                         use_checkpoint=use_checkpoint,
                         use_scale_shift_norm=use_scale_shift_norm,
                         tempspatial_aware=tempspatial_aware,
                         use_temporal_conv=temporal_conv)
            ]
            ch = model_channels * mult
            if ds in attention_resolutions:
                if num_head_channels == -1:
                    dim_head = ch // num_heads
                else:
                    num_heads = ch // num_head_channels
                    dim_head = num_head_channels
                # NOTE(review): unlike the encoder and middle block, no
                # agent_action_context_len is passed here, so the
                # SpatialTransformer default applies - confirm intended.
                layers.append(
                    SpatialTransformer(
                        ch,
                        num_heads,
                        dim_head,
                        depth=transformer_depth,
                        context_dim=context_dim,
                        use_linear=use_linear,
                        use_checkpoint=use_checkpoint,
                        disable_self_attn=False,
                        video_length=temporal_length,
                        agent_state_context_len=self.n_obs_steps,
                        image_cross_attention=self.image_cross_attention,
                        cross_attention_scale_learnable=self.
                        cross_attention_scale_learnable))
                if self.temporal_attention:
                    layers.append(
                        TemporalTransformer(
                            ch,
                            num_heads,
                            dim_head,
                            depth=transformer_depth,
                            context_dim=context_dim,
                            use_linear=use_linear,
                            use_checkpoint=use_checkpoint,
                            only_self_att=temporal_self_att_only,
                            causal_attention=use_causal_attention,
                            relative_position=use_relative_position,
                            temporal_length=temporal_length))
            # Upsample at the end of every level except level 0.
            if level and i == num_res_blocks:
                out_ch = ch
                layers.append(
                    ResBlock(ch,
                             time_embed_dim,
                             dropout,
                             out_channels=out_ch,
                             dims=dims,
                             use_checkpoint=use_checkpoint,
                             use_scale_shift_norm=use_scale_shift_norm,
                             up=True)
                    if resblock_updown else Upsample(
                        ch, conv_resample, dims=dims, out_channels=out_ch))
                ds //= 2
            self.output_blocks.append(TimestepEmbedSequential(*layers))
    # NOTE(review): the norm is sized by `ch` but the conv by
    # `model_channels`; these agree only when channel_mult[0] == 1.
    self.out = nn.Sequential(
        normalization(ch),
        nn.SiLU(),
        zero_module(
            conv_nd(dims, model_channels, out_channels, 3, padding=1)),
    )
    # Action and state prediction unet
    # The caller's config object is modified in place, and both heads are
    # instantiated from the same config.
    unet_head_config['params']['context_dims'] = [
        mult * model_channels for mult in channel_mult
    ]
    self.action_unet = instantiate_from_config(unet_head_config)
    self.state_unet = instantiate_from_config(unet_head_config)
    # Initialize action token_projector
    self.action_token_projector = instantiate_from_config(
        stem_process_config)
def forward(self,
            x: Tensor,
            x_action: Tensor,
            x_state: Tensor,
            timesteps: Tensor,
            context: Tensor | None = None,
            context_action: Tensor | None = None,
            features_adapter: Any = None,
            fs: Tensor | None = None,
            **kwargs) -> Tensor | tuple[Tensor, ...]:
    """
    Forward pass of the World-Model-Action backbone.

    Runs the video UNet frame-by-frame over the latent video, records
    intermediate feature maps along the way, and hands those feature maps
    to the action and state heads.

    Args:
        x: Latent video. It is unpacked as five dimensions and rearranged
            with the pattern 'b c t h w', i.e. (B, C, T, H, W).
        x_action: Action stream input. Its shape is unpacked as three
            dimensions; the first one is used to slice `timesteps`.
        x_state: State stream input for the state head.
        timesteps: Diffusion timesteps, shape (B,) or scalar Tensor.
        context: Conditioning context for cross-attention. Although the
            default is None, `context.shape` is read unconditionally, so a
            3-D tensor has to be supplied.
        context_action: Conditioning for the action/state heads. Only
            `context_action[:2]` is forwarded, so it must be sliceable
            whenever `base_model_gen_only` is False.
        features_adapter: Optional sized, indexable collection of feature
            maps; one entry is added to the hidden state after every third
            input block.
        fs: Frame-stride / fps conditioning. When it is None and
            `fs_condition` is set, a tensor filled with `self.default_fs`
            is used instead.
        **kwargs: Forwarded unchanged to the action and state heads.

    Returns:
        Tuple `(y, a_y, s_y)`: `y` is the UNet output rearranged to
        'b c t h w'; `a_y` and `s_y` are the outputs of the action and
        state heads, or zero tensors shaped like `x_action` / `x_state`
        when `base_model_gen_only` is set.
    """
    b, _, t, _, _ = x.shape
    t_emb = timestep_embedding(timesteps,
                               self.model_channels,
                               repeat_only=False).type(x.dtype)
    emb = self.time_embed(t_emb)
    bt, l_context, _ = context.shape
    if self.base_model_gen_only:
        # Hard-coded layout check: 77 tokens plus 16 tokens per observation
        # step. In this mode the context is passed on without re-layout.
        assert l_context == 77 + self.n_obs_steps * 16, ">>> ERROR Context dim 1 ..."
    else:
        if l_context == self.n_obs_steps + 77 + t * 16:
            # Layout: [agent state | 77 tokens | 16 tokens per frame].
            context_agent_state = context[:, :self.n_obs_steps]
            context_text = context[:, self.n_obs_steps:self.n_obs_steps +
                                   77, :]
            context_img = context[:, self.n_obs_steps + 77:, :]
            # State and text tokens are shared by all frames of a sample, so
            # they are repeated t times; the per-frame tokens are split out.
            context_agent_state = context_agent_state.repeat_interleave(
                repeats=t, dim=0)
            context_text = context_text.repeat_interleave(repeats=t, dim=0)
            context_img = rearrange(context_img,
                                    'b (t l) c -> (b t) l c',
                                    t=t)
            context = torch.cat(
                [context_agent_state, context_text, context_img], dim=1)
        elif l_context == self.n_obs_steps + 16 + 77 + t * 16:
            # Layout: [agent state | 16 action tokens | 77 tokens |
            # 16 tokens per frame].
            context_agent_state = context[:, :self.n_obs_steps]
            context_agent_action = context[:, self.
                                           n_obs_steps:self.n_obs_steps +
                                           16, :]
            # Each of the 16 action tokens becomes its own length-1 sequence
            # before it goes through the action token projector.
            context_agent_action = rearrange(
                context_agent_action.unsqueeze(2), 'b t l d -> (b t) l d')
            context_agent_action = self.action_token_projector(
                context_agent_action)
            # NOTE(review): regrouping the 16-token axis with o=t looks like
            # it assumes t == 16 -- confirm against callers.
            context_agent_action = rearrange(context_agent_action,
                                             '(b o) l d -> b o l d',
                                             o=t)
            context_agent_action = rearrange(context_agent_action,
                                             'b o (t l) d -> b o t l d',
                                             t=t)
            context_agent_action = context_agent_action.permute(
                0, 2, 1, 3, 4)
            context_agent_action = rearrange(context_agent_action,
                                             'b t o l d -> (b t) (o l) d')
            context_text = context[:, self.n_obs_steps +
                                   16:self.n_obs_steps + 16 + 77, :]
            context_text = context_text.repeat_interleave(repeats=t, dim=0)
            context_img = context[:, self.n_obs_steps + 16 + 77:, :]
            context_img = rearrange(context_img,
                                    'b (t l) c -> (b t) l c',
                                    t=t)
            context_agent_state = context_agent_state.repeat_interleave(
                repeats=t, dim=0)
            context = torch.cat([
                context_agent_state, context_agent_action, context_text,
                context_img
            ],
                                dim=1)
    # Frames are folded into the batch axis, so the timestep embedding is
    # repeated once per frame to line up with `x`.
    emb = emb.repeat_interleave(repeats=t, dim=0)
    x = rearrange(x, 'b c t h w -> (b t) c h w')
    # Add the frame-stride embedding to the timestep embedding.
    if self.fs_condition:
        if fs is None:
            fs = torch.tensor([self.default_fs] * b,
                              dtype=torch.long,
                              device=x.device)
        fs_emb = timestep_embedding(fs,
                                    self.model_channels,
                                    repeat_only=False).type(x.dtype)
        fs_embed = self.fps_embedding(fs_emb)
        fs_embed = fs_embed.repeat_interleave(repeats=t, dim=0)
        emb = emb + fs_embed
    h = x.type(self.dtype)
    adapter_idx = 0
    # `hs` holds the skip connections consumed by the output blocks;
    # `hs_a` holds the feature maps handed to the action/state heads,
    # each rearranged to 'b t c h w'.
    hs = []
    hs_a = []
    for id, module in enumerate(self.input_blocks):
        h = module(h, emb, context=context, batch_size=b)
        if id == 0 and self.addition_attention:
            h = self.init_attn(h, emb, context=context, batch_size=b)
        # Plug in adapter features after every third input block.
        if ((id + 1) % 3 == 0) and features_adapter is not None:
            h = h + features_adapter[adapter_idx]
            adapter_idx += 1
        if id != 0:
            # When this block starts with a Downsample, record the output of
            # the previous block (`hs[-1]`, appended one iteration earlier).
            if isinstance(module[0], Downsample):
                hs_a.append(
                    rearrange(hs[-1], '(b t) c h w -> b t c h w', t=t))
        hs.append(h)
    # Output of the last input block.
    hs_a.append(rearrange(h, '(b t) c h w -> b t c h w', t=t))
    if features_adapter is not None:
        # Every supplied adapter feature must have been consumed.
        assert len(
            features_adapter) == adapter_idx, 'Wrong features_adapter'
    h = self.middle_block(h, emb, context=context, batch_size=b)
    hs_a.append(rearrange(h, '(b t) c h w -> b t c h w', t=t))
    hs_out = []
    for module in self.output_blocks:
        # Concatenate the matching skip connection along the channel axis.
        h = torch.cat([h, hs.pop()], dim=1)
        h = module(h, emb, context=context, batch_size=b)
        # When this block ends with an Upsample, record the output of the
        # previous output block (`hs_out[-1]`).
        if isinstance(module[-1], Upsample):
            hs_a.append(
                rearrange(hs_out[-1], '(b t) c h w -> b t c h w', t=t))
        hs_out.append(h)
    h = h.type(x.dtype)
    # Output of the last output block.
    hs_a.append(rearrange(hs_out[-1], '(b t) c h w -> b t c h w', t=t))
    y = self.out(h)
    y = rearrange(y, '(b t) c h w -> b c t h w', b=b)
    if not self.base_model_gen_only:
        ba, _, _ = x_action.shape
        # Predict action
        a_y = self.action_unet(x_action, timesteps[:ba], hs_a,
                               context_action[:2], **kwargs)
        # Predict state
        # NOTE(review): the two branches differ only in whether `timesteps`
        # is sliced to the action batch size -- confirm why b == 1 is special.
        if b > 1:
            s_y = self.state_unet(x_state, timesteps[:ba], hs_a,
                                  context_action[:2], **kwargs)
        else:
            s_y = self.state_unet(x_state, timesteps, hs_a,
                                  context_action[:2], **kwargs)
    else:
        # Generation-only mode: the heads are skipped and zeros are returned.
        a_y = torch.zeros_like(x_action)
        s_y = torch.zeros_like(x_state)
    return y, a_y, s_y

View File

@@ -0,0 +1,244 @@
"""
base_vision.py
Abstract class definition of a Vision Backbone (Visual Featurizer), with full annotations of class methods, utility
functions, and initialization logic.
We also define the generic TimmViTBackbone class here, providing a default interface for loading any TIMM Vision
Transformer model for feature extraction.
"""
import timm
import torch
import torch.nn as nn
import torchvision.transforms.functional as TVF
from abc import ABC, abstractmethod
from dataclasses import dataclass
from functools import partial
from typing import Any, Callable, Dict, Optional, Protocol, Tuple, Union
from PIL.Image import Image
from timm.models.vision_transformer import Block, VisionTransformer
from torch.distributed.fsdp.wrap import _module_wrap_policy, _or_policy, transformer_auto_wrap_policy
from torchvision.transforms import Compose, Resize
# === Utility Functions for Monkey-Patching ===
def unpack_tuple(fn: Callable[[Any], Tuple[Any]]) -> Callable[[Any], Any]:
    """Wrap `fn` so that a tuple result is reduced to its first element.

    Any result that is not a tuple is passed through untouched. Positional
    and keyword arguments are forwarded to `fn` as given.
    """

    def wrapper(*args: Any, **kwargs: Any) -> Any:
        output = fn(*args, **kwargs)
        if isinstance(output, tuple):
            return output[0]
        return output

    return wrapper
# === Interface for an Image Transform ===
class ImageTransform(Protocol):
    """Structural type for callables that turn an image into model inputs."""

    def __call__(
        self,
        img: Image,
        **kwargs: str,
    ) -> Union[torch.Tensor, Dict[str, torch.Tensor]]:
        """Transform `img` into a tensor, or into a dict of named tensors."""
        ...
# === Custom Torchvision Image Transforms ===
@dataclass
class LetterboxPad:
    """Transform that pads an image towards a square with a constant fill."""

    # Fill value handed to `TVF.pad` for the added border.
    padding_fill_value: Tuple[int, int, int]

    def __call__(self, image: Image) -> Image:
        """Given a PIL.Image, pad to square by adding a symmetric border around the height/width.

        The same padding is applied on both sides of each axis; it is half of
        the gap to the longer side, truncated to an integer, so an odd gap
        leaves the result one pixel short of square along that axis.
        """
        width, height = image.size
        longest_side = max(image.size)
        pad_x = int((longest_side - width) / 2)
        pad_y = int((longest_side - height) / 2)
        # Border order is (left, top, right, bottom).
        border = (pad_x, pad_y, pad_x, pad_y)
        return TVF.pad(image,
                       border,
                       fill=self.padding_fill_value,
                       padding_mode="constant")
# === Abstract Base Class for arbitrary Vision Backbones ===
class VisionBackbone(nn.Module, ABC):
    """Abstract interface shared by all vision backbones (visual featurizers).

    Subclasses are expected to populate `featurizer` and `image_transform`,
    and to implement the abstract methods and properties declared here.
    """

    def __init__(self,
                 vision_backbone_id: str,
                 image_resize_strategy: str,
                 default_image_size: int = 224) -> None:
        """Store the backbone identifier, resize strategy and image size."""
        super().__init__()
        self.identifier: str = vision_backbone_id
        self.image_resize_strategy: str = image_resize_strategy
        self.default_image_size: int = default_image_size

        # Placeholders; concrete backbones assign these in their own __init__.
        self.featurizer: nn.Module = None
        self.image_transform: ImageTransform = None

    def get_image_transform(self) -> ImageTransform:
        """Return the image transform stored on this backbone."""
        return self.image_transform

    @abstractmethod
    def get_fsdp_wrapping_policy(self) -> Callable:
        """Return the FSDP auto-wrap policy for this backbone."""

    @abstractmethod
    def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
        """Run a forward pass through the featurizer given a set of processed images, returning patch/grid features."""
        raise NotImplementedError

    @property
    @abstractmethod
    def default_image_resolution(self) -> Tuple[int, int, int]:
        """Default input resolution as a 3-tuple."""

    @property
    @abstractmethod
    def embed_dim(self) -> int:
        """Feature dimension produced by the backbone."""

    @property
    @abstractmethod
    def num_patches(self) -> int:
        """Number of patches produced per image."""

    @property
    @abstractmethod
    def half_precision_dtype(self) -> torch.dtype:
        """Dtype to use when running the backbone in half precision."""
# === Abstract Base Class for Arbitrary TIMM Vision Transformer Backbones ===
class TimmViTBackbone(VisionBackbone, ABC):
    """Vision backbone built around a single TIMM Vision Transformer.

    The featurizer is created through `timm.create_model`, switched to eval
    mode, and its `forward` is replaced so that it returns the patches of the
    second-to-last block. The image transform is derived from the model's
    TIMM data config and then adapted to `image_resize_strategy`.
    """

    def __init__(
        self,
        vision_backbone_id: str,
        timm_path_or_url: str,
        image_resize_strategy: str,
        default_image_size: int = 224,
        override_act_layer: Optional[str] = None,
    ) -> None:
        """Create the featurizer and build the matching image transform.

        Args:
            vision_backbone_id: Identifier stored on the backbone.
            timm_path_or_url: Model name/path handed to `timm.create_model`.
            image_resize_strategy: One of "resize-naive", "resize-crop" or
                "letterbox".
            default_image_size: Side length used for the model and transform.
            override_act_layer: Optional `act_layer` passed to TIMM.

        Raises:
            ValueError: If `image_resize_strategy` is not supported.
        """
        super().__init__(vision_backbone_id,
                         image_resize_strategy,
                         default_image_size=default_image_size)
        self.timm_path_or_url = timm_path_or_url
        self.override_act_layer = override_act_layer
        self.dtype = torch.bfloat16

        # Initialize Featurizer (ViT) by downloading from HF / TIMM Hub if necessary;
        # `act_layer` is only passed along when an override was requested.
        create_kwargs: Dict[str, Any] = {
            "pretrained": True,
            "num_classes": 0,
            "img_size": self.default_image_size,
        }
        if self.override_act_layer is not None:
            create_kwargs["act_layer"] = self.override_act_layer
        self.featurizer: VisionTransformer = timm.create_model(
            self.timm_path_or_url, **create_kwargs)
        self.featurizer.eval()

        # Monkey-Patch the `forward()` function of the featurizer to ensure FSDP-compatibility
        #   => Note: By default set `get_intermediate_layers` to return the *SECOND-TO-LAST* layer patches!
        #   => TODO (siddk) Remove after resolution of https://github.com/pytorch/pytorch/issues/109385
        second_to_last = len(self.featurizer.blocks) - 2
        self.featurizer.forward = unpack_tuple(
            partial(self.featurizer.get_intermediate_layers,
                    n={second_to_last}))

        # Validation =>> for now, this class *only* supports TIMM Vision Transformers (but can be extended!)
        assert isinstance(self.featurizer, VisionTransformer), (
            "Featurizer is not a TIMM VisionTransformer; if you would like to support a new visual representation, "
            "file an issue or implement the requisite logic (see `prismatic/models/backbones/vision/base_vision.py`)!"
        )

        # Get Config =>> Note :: Override default image size to ensure correct image transform
        side = self.default_image_size
        self.data_cfg = timm.data.resolve_model_data_config(self.featurizer)
        self.data_cfg["input_size"] = (3, side, side)

        # Initialize Default Image Transform --> Modified by `self.image_resize_strategy`
        default_image_transform = timm.data.create_transform(**self.data_cfg,
                                                             is_training=False)

        # Fix =>> SigLIP & IN1K default transforms resize to *larger* than `self.default_image_size` (crops image)!
        if any(tag in self.timm_path_or_url for tag in ("siglip", "in1k")):
            default_image_transform = self._with_leading_resize(
                default_image_transform, side)

        # Switch on `image_resize_strategy`
        strategy = self.image_resize_strategy
        if strategy == "resize-naive":
            self.image_transform = self._with_leading_resize(
                default_image_transform, (side, side))
        elif strategy == "resize-crop":
            self.image_transform = default_image_transform
        elif strategy == "letterbox":
            assert isinstance(default_image_transform,
                              Compose), "Unexpected `default_image_transform`!"
            assert "mean" in self.data_cfg, "TIMM `data_cfg` missing image normalization mean!"

            # Compute Padding Fill Value (rescaled normalization mean if applicable)
            fill = tuple(int(channel * 255) for channel in self.data_cfg["mean"])

            # Build New Transform
            self.image_transform = Compose(
                [LetterboxPad(fill), *default_image_transform.transforms])
        else:
            raise ValueError(
                f"Image Resize Strategy `{self.image_resize_strategy}` is not supported!"
            )

    @staticmethod
    def _with_leading_resize(transform: Any,
                             size: Union[int, Tuple[int, int]]) -> Compose:
        """Return a copy of `transform` whose leading `Resize` targets `size`.

        The interpolation mode of the original `Resize` is kept, as are all
        transforms after it.
        """
        assert isinstance(transform,
                          Compose), "Unexpected `default_image_transform`!"
        leading = transform.transforms[0]
        assert isinstance(leading, Resize)
        return Compose([
            Resize(size, interpolation=leading.interpolation),
            *transform.transforms[1:],
        ])

    def get_fsdp_wrapping_policy(self) -> Callable:
        """Return a simple FSDP policy that wraps each ViT block and then the _entire_ featurizer."""
        whole_vit = partial(_module_wrap_policy,
                            module_classes={VisionTransformer})
        per_block = partial(transformer_auto_wrap_policy,
                            transformer_layer_cls={Block})
        return partial(_or_policy, policies=[whole_vit, per_block])

    def forward(
        self, pixel_values: Union[torch.Tensor, Dict[str, torch.Tensor]]
    ) -> torch.Tensor:
        """Runs transformed image/pixel tensor through vision backbone, returning _all_ patch features."""
        patches = self.featurizer(pixel_values)
        return patches

    @property
    def default_image_resolution(self) -> Tuple[int, int, int]:
        """The `input_size` entry of the (overridden) TIMM data config."""
        return self.data_cfg["input_size"]

    @property
    def embed_dim(self) -> int:
        """Embedding dimension reported by the featurizer."""
        return self.featurizer.embed_dim

    @property
    def num_patches(self) -> int:
        """Patch count reported by the featurizer's patch embedding."""
        return self.featurizer.patch_embed.num_patches

    @property
    def half_precision_dtype(self) -> torch.dtype:
        """Half-precision dtype chosen in `__init__` (`torch.bfloat16`)."""
        return self.dtype

View File

@@ -0,0 +1,273 @@
"""
dinosiglip_vit.py
Vision backbone that returns concatenated features from both DINOv2 and SigLIP.
"""
import timm
import torch
import torchvision.transforms as transforms
from dataclasses import dataclass
from functools import partial
from typing import Callable, Dict, Tuple
from PIL import Image
from timm.models.vision_transformer import Block, VisionTransformer
from torch.distributed.fsdp.wrap import _module_wrap_policy, _or_policy, transformer_auto_wrap_policy
from torchvision.transforms import Compose, Resize, Normalize
from unifolm_wma.modules.vision.base_vision import ImageTransform, LetterboxPad, VisionBackbone, unpack_tuple
from unifolm_wma.utils.nn_utils import FusedMLPProjector, LinearProjector, MLPProjector
# Registry =>> Supported DinoSigLIP Pairs (as TIMM identifiers)
#   Maps a backbone id to the TIMM model names of its "dino" and "siglip" halves;
#   `DinoSigLIPViTBackbone` looks both names up by `vision_backbone_id`.
DINOSigLIP_VISION_BACKBONES = {
    "dinosiglip-vit-so-224px": {
        "dino": "vit_large_patch14_reg4_dinov2.lvd142m",
        "siglip": "vit_so400m_patch14_siglip_224",
    },
    "dinosiglip-vit-so-384px": {
        "dino": "vit_large_patch14_reg4_dinov2.lvd142m",
        "siglip": "vit_so400m_patch14_siglip_384",
    },
}
@dataclass
class DinoSigLIPImageTransform:
    """Pair of image transforms applied to the same image, one per featurizer."""

    # Transform producing the "dino" entry of the result.
    dino_image_transform: ImageTransform
    # Transform producing the "siglip" entry of the result.
    siglip_image_transform: ImageTransform
    is_prismatic: bool = True

    def __call__(self, img: Image, **kwargs: str) -> Dict[str, torch.Tensor]:
        """Apply both transforms to `img` and return the results keyed by name."""
        dino_pixels = self.dino_image_transform(img, **kwargs)
        siglip_pixels = self.siglip_image_transform(img, **kwargs)
        return {"dino": dino_pixels, "siglip": siglip_pixels}
class DinoSigLIPViTBackbone(VisionBackbone):
    """Vision backbone that fuses DINOv2 and SigLIP patch features.

    Both featurizers are TIMM Vision Transformers whose `forward` is patched
    to return the patches of the second-to-last block. `forward` preprocesses
    an image tensor, runs both featurizers, concatenates their patches along
    the feature axis, and maps the result through a projector.
    """

    def __init__(self,
                 vision_backbone_id: str,
                 image_resize_strategy: str,
                 arch_specifier: str,
                 output_dim: int,
                 pretrained_checkpoint=None,
                 freeze=True,
                 default_image_size: int = 224) -> None:
        """Build both featurizers, their image transforms, and the projector.

        Args:
            vision_backbone_id: Key into `DINOSigLIP_VISION_BACKBONES`.
            image_resize_strategy: One of "resize-naive", "resize-crop" or
                "letterbox"; selects how `image_transform` is built.
            arch_specifier: Projector type: "linear", or a name ending in
                "fused-gelu-mlp" or "gelu-mlp".
            output_dim: Output dimension handed to the projector.
            pretrained_checkpoint: Optional directory prefix; when truthy,
                `openvla_dino.pt` and `openvla_siglip.pt` are loaded from it.
            freeze: When truthy, both featurizers are put in eval mode and
                their parameters stop requiring gradients.
            default_image_size: Side length used for models and transforms.

        Raises:
            KeyError: If `vision_backbone_id` is not in the registry.
            ValueError: If `image_resize_strategy` or `arch_specifier` is
                not supported.
        """
        super().__init__(vision_backbone_id,
                         image_resize_strategy,
                         default_image_size=default_image_size)
        self.dino_timm_path_or_url = DINOSigLIP_VISION_BACKBONES[
            vision_backbone_id]["dino"]
        self.siglip_timm_path_or_url = DINOSigLIP_VISION_BACKBONES[
            vision_backbone_id]["siglip"]

        # `forward()` resizes and center-crops to this size for *every* resize
        # strategy, so it is defined up front rather than only in one branch
        # (previously "resize-crop" / "letterbox" left it undefined and
        # `forward()` failed with AttributeError).
        self.target_size = (self.default_image_size,
                            self.default_image_size)

        # Initialize both Featurizers (ViTs) by downloading from HF / TIMM Hub if necessary
        self.dino_featurizer: VisionTransformer = self._build_featurizer(
            self.dino_timm_path_or_url, "dino", pretrained_checkpoint, freeze)
        self.siglip_featurizer: VisionTransformer = self._build_featurizer(
            self.siglip_timm_path_or_url, "siglip", pretrained_checkpoint,
            freeze)

        # Monkey-Patch the `forward()` function of the featurizers to ensure FSDP-compatibility
        #   => Note: By default set `get_intermediate_layers` to return the *SECOND-TO-LAST* layer patches!
        #   => TODO (siddk) Remove after resolution of https://github.com/pytorch/pytorch/issues/109385
        self.dino_featurizer.forward = unpack_tuple(
            partial(self.dino_featurizer.get_intermediate_layers,
                    n={len(self.dino_featurizer.blocks) - 2}))
        self.siglip_featurizer.forward = unpack_tuple(
            partial(self.siglip_featurizer.get_intermediate_layers,
                    n={len(self.siglip_featurizer.blocks) - 2}))

        # Get Configs for _both_ Featurizers =>> Note :: Override default image size for larger resolution models
        self.dino_data_cfg = timm.data.resolve_model_data_config(
            self.dino_featurizer)
        self.dino_data_cfg["input_size"] = (3, self.default_image_size,
                                            self.default_image_size)
        self.siglip_data_cfg = timm.data.resolve_model_data_config(
            self.siglip_featurizer)
        self.siglip_data_cfg["input_size"] = (3, self.default_image_size,
                                              self.default_image_size)

        # Initialize *both* Transforms
        self.default_dino_transform = timm.data.create_transform(
            **self.dino_data_cfg, is_training=False)
        self.default_siglip_transform = timm.data.create_transform(
            **self.siglip_data_cfg, is_training=False)

        # Fix =>> SigLIP default transform resizes to *larger* than `self.default_image_size` (crops image)!!
        assert isinstance(self.default_siglip_transform,
                          Compose), "Unexpected `default_image_transform`!"
        assert isinstance(self.default_siglip_transform.transforms[0], Resize)
        self.default_siglip_transform = Compose([
            Resize(self.default_image_size,
                   interpolation=self.default_siglip_transform.transforms[0].
                   interpolation),
            *self.default_siglip_transform.transforms[1:],
        ])

        if self.image_resize_strategy == "resize-naive":
            assert isinstance(
                self.default_dino_transform,
                Compose), "Unexpected `default_dino_image_transform`!"
            assert isinstance(
                self.default_siglip_transform,
                Compose), "Unexpected `default_siglip_image_transform`!"
            assert isinstance(self.default_dino_transform.transforms[0],
                              Resize)
            assert isinstance(self.default_siglip_transform.transforms[0],
                              Resize)
            dino_transform = Compose([
                Resize(self.target_size,
                       interpolation=self.default_dino_transform.transforms[0].
                       interpolation),
                *self.default_dino_transform.transforms[1:],
            ])
            siglip_transform = Compose([
                Resize(self.target_size,
                       interpolation=self.default_siglip_transform.
                       transforms[0].interpolation),
                *self.default_siglip_transform.transforms[1:],
            ])
            self.image_transform = DinoSigLIPImageTransform(
                dino_transform, siglip_transform)
        elif self.image_resize_strategy == "resize-crop":
            self.image_transform = DinoSigLIPImageTransform(
                self.default_dino_transform, self.default_siglip_transform)
        elif self.image_resize_strategy == "letterbox":
            assert isinstance(self.default_dino_transform,
                              Compose), "Unexpected `default_dino_transform`!"
            assert isinstance(
                self.default_siglip_transform,
                Compose), "Unexpected `default_siglip_transform`!"
            assert ("mean" in self.dino_data_cfg
                    and "mean" in self.siglip_data_cfg
                    ), "DinoSigLIP `data_cfg` missing `mean`!"

            # Compute Padding Fill Value(s) (rescaled normalization mean if applicable)
            dino_fill = tuple(
                [int(x * 255) for x in self.dino_data_cfg["mean"]])
            siglip_fill = tuple(
                [int(x * 255) for x in self.siglip_data_cfg["mean"]])

            # Build New Transform
            self.image_transform = DinoSigLIPImageTransform(
                Compose([
                    LetterboxPad(dino_fill),
                    *self.default_dino_transform.transforms
                ]),
                Compose([
                    LetterboxPad(siglip_fill),
                    *self.default_siglip_transform.transforms
                ]),
            )
        else:
            raise ValueError(
                f"Image Resize Strategy `{self.image_resize_strategy}` is not supported!"
            )

        # Projector applied to the concatenated DINO + SigLIP features.
        self.arch_specifier = arch_specifier
        if arch_specifier == "linear":
            self.projector = LinearProjector(self.embed_dim, output_dim)
        elif arch_specifier.endswith("fused-gelu-mlp"):
            self.projector = FusedMLPProjector(self.embed_dim, output_dim)
        elif arch_specifier.endswith("gelu-mlp"):
            self.projector = MLPProjector(self.embed_dim, output_dim)
        else:
            raise ValueError(
                f"PrismaticVLM with `{arch_specifier = }` is not supported!")

        # Latched by `forward()` once the featurizer is seen on a non-CPU device.
        self.on_gpu = False

    def _build_featurizer(self, timm_path_or_url: str, name: str,
                          pretrained_checkpoint,
                          freeze) -> VisionTransformer:
        """Create one TIMM featurizer, optionally load weights and freeze it.

        `name` ("dino" or "siglip") selects the checkpoint file
        `openvla_<name>.pt` and appears in the load message.
        """
        featurizer = timm.create_model(timm_path_or_url,
                                       pretrained=True,
                                       num_classes=0,
                                       img_size=self.default_image_size)
        if pretrained_checkpoint:
            ckpt = pretrained_checkpoint + f'/openvla_{name}.pt'
            featurizer.load_state_dict(torch.load(ckpt, weights_only=True))
            print(f'>>> load {name} weights')
        if freeze:
            featurizer.eval()
            for param in featurizer.parameters():
                param.requires_grad = False
        return featurizer

    def get_fsdp_wrapping_policy(self) -> Callable:
        """Return a simple FSDP policy that wraps each ViT block and then both of the _entire_ featurizers."""
        vit_wrap_policy = partial(_module_wrap_policy,
                                  module_classes={VisionTransformer})
        transformer_block_policy = partial(transformer_auto_wrap_policy,
                                           transformer_layer_cls={Block})
        return partial(_or_policy,
                       policies=[vit_wrap_policy, transformer_block_policy])

    def forward(self, img) -> torch.Tensor:
        """Runs the transformed image/pixel tensors through each vision backbone, returning concatenated patches.

        The input is clamped to [-1, 1], rescaled, resized and center-cropped
        to `self.target_size`, normalized separately for each featurizer, and
        the two patch sequences are concatenated on dim 2 before projection.
        """
        img = torch.clamp(img.float(), -1., 1.)
        img = (img + 1.0) / 2.0
        # NOTE(review): the image is scaled to [0, 255] here while the
        # normalization constants below are on a [0, 1] scale. Kept exactly
        # as-is because changing it would change the features fed to the
        # projector -- confirm whether this is intentional.
        img = img * 255

        resize = transforms.Resize(min(self.target_size),
                                   interpolation=self.default_dino_transform.
                                   transforms[0].interpolation,
                                   max_size=None,
                                   antialias=True)
        center_crop = transforms.CenterCrop(self.target_size)
        img = center_crop(resize(img))

        dino_normalizer = Normalize(mean=torch.tensor([0.4850, 0.4560,
                                                       0.4060]),
                                    std=torch.tensor([0.2290, 0.2240, 0.2250]))
        siglip_normalizer = Normalize(
            mean=torch.tensor([0.5000, 0.5000, 0.5000]),
            std=torch.tensor([0.5000, 0.5000, 0.5000]))
        pixel_values = {
            'dino': dino_normalizer(img),
            'siglip': siglip_normalizer(img)
        }

        # NOTE(review): the call that first detects a non-CPU featurizer only
        # sets the latch; inputs are moved with `.cuda()` from the next call
        # on. Behavior is preserved -- confirm the intended device handling.
        if self.on_gpu:
            pixel_values = {k: v.cuda() for k, v in pixel_values.items()}
        elif next(self.dino_featurizer.parameters()).device.type != 'cpu':
            self.on_gpu = True

        dino_patches = self.dino_featurizer(pixel_values["dino"])
        siglip_patches = self.siglip_featurizer(pixel_values["siglip"])
        return self.projector(torch.cat([dino_patches, siglip_patches], dim=2))

    @property
    def default_image_resolution(self) -> Tuple[int, int, int]:
        """The `input_size` entry of the (overridden) DINO data config."""
        return self.dino_data_cfg["input_size"]

    @property
    def embed_dim(self) -> int:
        """Sum of the two featurizers' embedding dimensions."""
        return self.dino_featurizer.embed_dim + self.siglip_featurizer.embed_dim

    @property
    def num_patches(self) -> int:
        """Patch count, asserted to be identical for both featurizers."""
        assert self.dino_featurizer.patch_embed.num_patches == self.siglip_featurizer.patch_embed.num_patches
        return self.dino_featurizer.patch_embed.num_patches

    @property
    def half_precision_dtype(self) -> torch.dtype:
        """Half-precision dtype for this backbone (`torch.bfloat16`)."""
        return torch.bfloat16