第一次完整测例跑完
This commit is contained in:
0
src/unifolm_wma/modules/__init__.py
Normal file
0
src/unifolm_wma/modules/__init__.py
Normal file
806
src/unifolm_wma/modules/attention.py
Normal file
806
src/unifolm_wma/modules/attention.py
Normal file
@@ -0,0 +1,806 @@
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
from torch import nn, einsum
|
||||
from einops import rearrange, repeat
|
||||
from functools import partial
|
||||
|
||||
# Probe for the optional `xformers` dependency; attention modules below switch
# to memory-efficient kernels only when it imports cleanly.
try:
    import xformers
    import xformers.ops
    XFORMERS_IS_AVAILBLE = True
# Fix: the original used a bare `except:`, which also swallows
# KeyboardInterrupt/SystemExit. Only an ImportError means "not installed".
except ImportError:
    XFORMERS_IS_AVAILBLE = False
|
||||
|
||||
from unifolm_wma.utils.common import (
|
||||
checkpoint,
|
||||
exists,
|
||||
default,
|
||||
)
|
||||
from unifolm_wma.utils.basics import zero_module
|
||||
|
||||
|
||||
class RelativePosition(nn.Module):
    """Learnable relative-position embedding table.

    Adapted from
    https://github.com/evelinehong/Transformer_Relative_Position_PyTorch/blob/master/relative_position.py

    Stores one `num_units`-dim embedding per clipped relative offset in
    [-max_relative_position, max_relative_position] and returns, for a
    (length_q, length_k) pair, the embedding of every query/key offset.
    """

    def __init__(self, num_units, max_relative_position):
        super().__init__()
        self.num_units = num_units
        self.max_relative_position = max_relative_position
        # 2*max+1 rows: one per possible clipped offset, xavier-initialized.
        self.embeddings_table = nn.Parameter(
            torch.Tensor(max_relative_position * 2 + 1, num_units))
        nn.init.xavier_uniform_(self.embeddings_table)

    def forward(self, length_q, length_k):
        """Return a (length_q, length_k, num_units) tensor of offset embeddings."""
        dev = self.embeddings_table.device
        q_pos = torch.arange(length_q, device=dev)
        k_pos = torch.arange(length_k, device=dev)
        # offset[i, j] = j - i, clipped to the supported window, then shifted
        # so it indexes the embedding table non-negatively.
        offsets = k_pos[None, :] - q_pos[:, None]
        offsets = torch.clamp(offsets, -self.max_relative_position,
                              self.max_relative_position)
        index = (offsets + self.max_relative_position).long()
        return self.embeddings_table[index]
|
||||
|
||||
|
||||
class CrossAttention(nn.Module):
    """Multi-head (cross-)attention with optional extra conditioning streams.

    Besides the usual text-conditioned cross-attention, the module can attend
    separately to image tokens, agent-state tokens and agent-action tokens
    (each with its own K/V projections) and sum the results with per-stream
    scales. When xformers is available and this instance is used for spatial
    attention (``temporal_length is None``), ``forward`` is rebound to
    ``efficient_forward`` in ``__init__``.
    """

    def __init__(self,
                 query_dim,
                 context_dim=None,
                 heads=8,
                 dim_head=64,
                 dropout=0.,
                 relative_position=False,
                 temporal_length=None,
                 video_length=None,
                 agent_state_context_len=2,
                 agent_action_context_len=16,
                 image_cross_attention=False,
                 image_cross_attention_scale=1.0,
                 agent_state_cross_attention_scale=1.0,
                 agent_action_cross_attention_scale=1.0,
                 cross_attention_scale_learnable=False,
                 text_context_len=77):
        super().__init__()
        inner_dim = dim_head * heads
        # Self-attention when no context dim is given.
        context_dim = default(context_dim, query_dim)

        self.scale = dim_head**-0.5
        self.heads = heads
        self.dim_head = dim_head
        self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
        self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
        self.to_v = nn.Linear(context_dim, inner_dim, bias=False)

        self.to_out = nn.Sequential(nn.Linear(inner_dim, query_dim),
                                    nn.Dropout(dropout))

        self.relative_position = relative_position
        if self.relative_position:
            # Relative-position bias is only meaningful for temporal attention.
            assert (temporal_length is not None)
            self.relative_position_k = RelativePosition(
                num_units=dim_head, max_relative_position=temporal_length)
            self.relative_position_v = RelativePosition(
                num_units=dim_head, max_relative_position=temporal_length)
        else:
            ## only used for spatial attention, while NOT for temporal attention
            if XFORMERS_IS_AVAILBLE and temporal_length is None:
                # Rebind forward to the xformers-backed implementation.
                self.forward = self.efficient_forward

        self.video_length = video_length
        self.image_cross_attention = image_cross_attention
        self.image_cross_attention_scale = image_cross_attention_scale
        self.agent_state_cross_attention_scale = agent_state_cross_attention_scale
        self.agent_action_cross_attention_scale = agent_action_cross_attention_scale
        self.text_context_len = text_context_len
        self.agent_state_context_len = agent_state_context_len
        self.agent_action_context_len = agent_action_context_len
        self.cross_attention_scale_learnable = cross_attention_scale_learnable
        if self.image_cross_attention:
            # Separate K/V projections per extra conditioning stream:
            # _ip = image, _as = agent state, _aa = agent action.
            self.to_k_ip = nn.Linear(context_dim, inner_dim, bias=False)
            self.to_v_ip = nn.Linear(context_dim, inner_dim, bias=False)
            self.to_k_as = nn.Linear(context_dim, inner_dim, bias=False)
            self.to_v_as = nn.Linear(context_dim, inner_dim, bias=False)
            self.to_k_aa = nn.Linear(context_dim, inner_dim, bias=False)
            self.to_v_aa = nn.Linear(context_dim, inner_dim, bias=False)
            if cross_attention_scale_learnable:
                # Learnable gates; tanh(alpha)+1 keeps each gate in (0, 2)
                # and equal to 1 at init (alpha = 0).
                self.register_parameter('alpha_ctx',
                                        nn.Parameter(torch.tensor(0.)))
                self.register_parameter('alpha_cas',
                                        nn.Parameter(torch.tensor(0.)))
                self.register_parameter('alpha_caa',
                                        nn.Parameter(torch.tensor(0.)))

    def forward(self, x, context=None, mask=None):
        """Plain einsum attention path (used when xformers is unavailable
        or for temporal attention).

        x: (b, n, query_dim); context: optional (b, m, context_dim);
        mask: optional (b, i, j) causal mask (1 = keep).
        """
        spatial_self_attn = (context is None)
        k_ip, v_ip, out_ip = None, None, None
        k_as, v_as, out_as = None, None, None
        k_aa, v_aa, out_aa = None, None, None

        h = self.heads
        q = self.to_q(x)
        context = default(context, x)

        if self.image_cross_attention and not spatial_self_attn:
            # NOTE(review): this branch is intentionally unreachable — the
            # assert below always fires. Multi-stream conditioning is only
            # supported via efficient_forward (requires xformers).
            assert 1 > 2, ">>> ERROR: should setup xformers and use efficient_forward ..."
            # Dead code below mirrors efficient_forward's context slicing:
            # [agent_state | agent_action | text | image] along dim 1.
            context_agent_state = context[:, :self.agent_state_context_len, :]
            context_agent_action = context[:,
                                           self.agent_state_context_len:self.
                                           agent_state_context_len +
                                           self.agent_action_context_len, :]
            context_ins = context[:, self.agent_state_context_len +
                                  self.agent_action_context_len:self.
                                  agent_state_context_len +
                                  self.agent_action_context_len +
                                  self.text_context_len, :]
            context_image = context[:, self.agent_state_context_len +
                                    self.agent_action_context_len +
                                    self.text_context_len:, :]

            k = self.to_k(context_ins)
            v = self.to_v(context_ins)
            k_ip = self.to_k_ip(context_image)
            v_ip = self.to_v_ip(context_image)
            k_as = self.to_k_as(context_agent_state)
            v_as = self.to_v_as(context_agent_state)
            k_aa = self.to_k_aa(context_agent_action)
            v_aa = self.to_v_aa(context_agent_action)
        else:
            if not spatial_self_attn:
                # Cross-attention without image conditioning: only the text
                # tokens at the front of the context are used.
                context = context[:, :self.text_context_len, :]
            k = self.to_k(context)
            v = self.to_v(context)

        # Fold heads into the batch dim: (b, n, h*d) -> (b*h, n, d).
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h),
                      (q, k, v))

        sim = torch.einsum('b i d, b j d -> b i j', q, k) * self.scale
        if self.relative_position:
            len_q, len_k, len_v = q.shape[1], k.shape[1], v.shape[1]
            # Additive relative-position bias on the attention logits.
            k2 = self.relative_position_k(len_q, len_k)
            sim2 = einsum('b t d, t s d -> b t s', q,
                          k2) * self.scale  # TODO check
            sim += sim2
        del k

        if exists(mask):
            ## feasible for causal attention mask only
            max_neg_value = -torch.finfo(sim.dtype).max
            mask = repeat(mask, 'b i j -> (b h) i j', h=h)
            # Masked-out positions get -inf-like logits before softmax.
            sim.masked_fill_(~(mask > 0.5), max_neg_value)

        # attention, what we cannot get enough of
        sim = sim.softmax(dim=-1)

        out = torch.einsum('b i j, b j d -> b i d', sim, v)
        if self.relative_position:
            # Relative-position contribution on the value side.
            v2 = self.relative_position_v(len_q, len_v)
            out2 = einsum('b t s, t s d -> b t d', sim, v2)  # TODO check
            out += out2
        out = rearrange(out, '(b h) n d -> b n (h d)', h=h)

        # NOTE(review): also dead in practice — k_ip/k_as/k_aa are only set in
        # the unreachable branch above. Kept for parity with efficient_forward.
        if k_ip is not None and k_as is not None and k_aa is not None:
            ## for image cross-attention
            k_ip, v_ip = map(
                lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h),
                (k_ip, v_ip))
            sim_ip = torch.einsum('b i d, b j d -> b i j', q,
                                  k_ip) * self.scale
            del k_ip
            sim_ip = sim_ip.softmax(dim=-1)
            out_ip = torch.einsum('b i j, b j d -> b i d', sim_ip, v_ip)
            out_ip = rearrange(out_ip, '(b h) n d -> b n (h d)', h=h)

            ## for agent state cross-attention
            k_as, v_as = map(
                lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h),
                (k_as, v_as))
            sim_as = torch.einsum('b i d, b j d -> b i j', q,
                                  k_as) * self.scale
            del k_as
            sim_as = sim_as.softmax(dim=-1)
            out_as = torch.einsum('b i j, b j d -> b i d', sim_as, v_as)
            out_as = rearrange(out_as, '(b h) n d -> b n (h d)', h=h)

            ## for agent action cross-attention
            k_aa, v_aa = map(
                lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h),
                (k_aa, v_aa))
            sim_aa = torch.einsum('b i d, b j d -> b i j', q,
                                  k_aa) * self.scale
            del k_aa
            sim_aa = sim_aa.softmax(dim=-1)
            out_aa = torch.einsum('b i j, b j d -> b i d', sim_aa, v_aa)
            out_aa = rearrange(out_aa, '(b h) n d -> b n (h d)', h=h)

        if out_ip is not None and out_as is not None and out_aa is not None:
            # Blend the extra streams into the text-attention output.
            if self.cross_attention_scale_learnable:
                out = out + \
                    self.image_cross_attention_scale * out_ip * (torch.tanh(self.alpha_ctx) + 1) + \
                    self.agent_state_cross_attention_scale * out_as * (torch.tanh(self.alpha_cas) + 1) + \
                    self.agent_action_cross_attention_scale * out_aa * (torch.tanh(self.alpha_caa) + 1)
            else:
                out = out + \
                    self.image_cross_attention_scale * out_ip + \
                    self.agent_state_cross_attention_scale * out_as + \
                    self.agent_action_cross_attention_scale * out_aa

        return self.to_out(out)

    def efficient_forward(self, x, context=None, mask=None):
        """xformers-backed attention path (spatial attention only).

        The context layout along dim 1 is inferred from its length:
          - text + video tokens,
          - agent_state + text + video tokens, or
          - agent_state + agent_action + text + video tokens.
        """
        spatial_self_attn = (context is None)
        k, v, out = None, None, None
        k_ip, v_ip, out_ip = None, None, None
        k_as, v_as, out_as = None, None, None
        k_aa, v_aa, out_aa = None, None, None

        q = self.to_q(x)
        context = default(context, x)

        if self.image_cross_attention and not spatial_self_attn:
            if context.shape[1] == self.text_context_len + self.video_length:
                # Layout: [text | image]. NOTE(review): here the *full*
                # context (not just the text slice) feeds to_k/to_v, unlike
                # the other two branches — looks intentional but confirm.
                context_ins, context_image = context[:, :self.text_context_len, :], context[:,self.text_context_len:, :]
                k = self.to_k(context)
                v = self.to_v(context)
                k_ip = self.to_k_ip(context_image)
                v_ip = self.to_v_ip(context_image)
            elif context.shape[1] == self.agent_state_context_len + self.text_context_len + self.video_length:
                # Layout: [agent_state | text | image].
                context_agent_state = context[:, :self.agent_state_context_len, :]
                context_ins = context[:, self.agent_state_context_len:self.agent_state_context_len+self.text_context_len, :]
                context_image = context[:, self.agent_state_context_len+self.text_context_len:, :]
                k = self.to_k(context_ins)
                v = self.to_v(context_ins)
                k_ip = self.to_k_ip(context_image)
                v_ip = self.to_v_ip(context_image)
                k_as = self.to_k_as(context_agent_state)
                v_as = self.to_v_as(context_agent_state)
            else:
                # Layout: [agent_state | agent_action | text | image].
                context_agent_state = context[:, :self.agent_state_context_len, :]
                context_agent_action = context[:, self.agent_state_context_len:self.agent_state_context_len+self.agent_action_context_len, :]
                context_ins = context[:, self.agent_state_context_len+self.agent_action_context_len:self.agent_state_context_len+self.agent_action_context_len+self.text_context_len, :]
                context_image = context[:, self.agent_state_context_len+self.agent_action_context_len+self.text_context_len:, :]

                k = self.to_k(context_ins)
                v = self.to_v(context_ins)
                k_ip = self.to_k_ip(context_image)
                v_ip = self.to_v_ip(context_image)
                k_as = self.to_k_as(context_agent_state)
                v_as = self.to_v_as(context_agent_state)
                k_aa = self.to_k_aa(context_agent_action)
                v_aa = self.to_v_aa(context_agent_action)

                # Block-wise bias restricting which action tokens each
                # sample may attend to (see _get_attn_mask_aa).
                attn_mask_aa = self._get_attn_mask_aa(x.shape[0],
                                                      q.shape[1],
                                                      k_aa.shape[1],
                                                      block_size=16).to(k_aa.device)
        else:
            if not spatial_self_attn:
                # NOTE(review): unreachable by design — cross-attention
                # without image conditioning is not expected on this path.
                assert 1 > 2, ">>> ERROR: you should never go into here ..."
                context = context[:, :self.text_context_len, :]
            k = self.to_k(context)
            v = self.to_v(context)

        b, _, _ = q.shape
        # (b, n, h*d) -> (b*h, n, d), as required by xformers.
        q = q.unsqueeze(3).reshape(b, q.shape[1], self.heads, self.dim_head).permute(0, 2, 1, 3).reshape(b * self.heads, q.shape[1], self.dim_head).contiguous()
        if k is not None:
            k, v = map(
                lambda t: t.unsqueeze(3).reshape(b, t.shape[
                    1], self.heads, self.dim_head).permute(0, 2, 1, 3).reshape(
                        b * self.heads, t.shape[1], self.dim_head).contiguous(),
                (k, v),
            )
            out = xformers.ops.memory_efficient_attention(q,
                                                          k,
                                                          v,
                                                          attn_bias=None,
                                                          op=None)
            # (b*h, n, d) -> (b, n, h*d).
            out = (out.unsqueeze(0).reshape(
                b, self.heads, out.shape[1],
                self.dim_head).permute(0, 2, 1,
                                       3).reshape(b, out.shape[1],
                                                  self.heads * self.dim_head))

        if k_ip is not None:
            # For image cross-attention
            k_ip, v_ip = map(
                lambda t: t.unsqueeze(3).reshape(b, t.shape[
                    1], self.heads, self.dim_head).permute(0, 2, 1, 3).reshape(
                        b * self.heads, t.shape[1], self.dim_head).contiguous(
                        ),
                (k_ip, v_ip),
            )
            out_ip = xformers.ops.memory_efficient_attention(q,
                                                             k_ip,
                                                             v_ip,
                                                             attn_bias=None,
                                                             op=None)
            out_ip = (out_ip.unsqueeze(0).reshape(
                b, self.heads, out_ip.shape[1],
                self.dim_head).permute(0, 2, 1,
                                       3).reshape(b, out_ip.shape[1],
                                                  self.heads * self.dim_head))

        if k_as is not None:
            # For agent state cross-attention
            k_as, v_as = map(
                lambda t: t.unsqueeze(3).reshape(b, t.shape[
                    1], self.heads, self.dim_head).permute(0, 2, 1, 3).reshape(
                        b * self.heads, t.shape[1], self.dim_head).contiguous(
                        ),
                (k_as, v_as),
            )
            out_as = xformers.ops.memory_efficient_attention(q,
                                                             k_as,
                                                             v_as,
                                                             attn_bias=None,
                                                             op=None)
            out_as = (out_as.unsqueeze(0).reshape(
                b, self.heads, out_as.shape[1],
                self.dim_head).permute(0, 2, 1,
                                       3).reshape(b, out_as.shape[1],
                                                  self.heads * self.dim_head))
        if k_aa is not None:
            # For agent action cross-attention
            k_aa, v_aa = map(
                lambda t: t.unsqueeze(3).reshape(b, t.shape[
                    1], self.heads, self.dim_head).permute(0, 2, 1, 3).reshape(
                        b * self.heads, t.shape[1], self.dim_head).contiguous(
                        ),
                (k_aa, v_aa),
            )

            # Expand the per-sample mask across heads to match q's layout.
            attn_mask_aa = attn_mask_aa.unsqueeze(1).repeat(1,self.heads,1,1).reshape(
                b * self.heads, attn_mask_aa.shape[1], attn_mask_aa.shape[2])
            attn_mask_aa = attn_mask_aa.to(q.dtype)

            out_aa = xformers.ops.memory_efficient_attention(
                q, k_aa, v_aa, attn_bias=attn_mask_aa, op=None)

            out_aa = (out_aa.unsqueeze(0).reshape(
                b, self.heads, out_aa.shape[1],
                self.dim_head).permute(0, 2, 1,
                                       3).reshape(b, out_aa.shape[1],
                                                  self.heads * self.dim_head))
        if exists(mask):
            # Explicit masks are not supported on the xformers path.
            raise NotImplementedError

        # Streams that were not computed contribute nothing to the sum.
        out = 0.0 if out is None else out
        out_ip = 0.0 if out_ip is None else out_ip
        out_as = 0.0 if out_as is None else out_as
        out_aa = 0.0 if out_aa is None else out_aa

        if self.cross_attention_scale_learnable:
            out = out + \
                self.image_cross_attention_scale * out_ip * (torch.tanh(self.alpha_ctx) + 1) + \
                self.agent_state_cross_attention_scale * out_as * (torch.tanh(self.alpha_cas) + 1) + \
                self.agent_action_cross_attention_scale * out_aa * (torch.tanh(self.alpha_caa) + 1)

        else:
            out = out + \
                self.image_cross_attention_scale * out_ip + \
                self.agent_state_cross_attention_scale * out_as + \
                self.agent_action_cross_attention_scale * out_aa

        return self.to_out(out)

    def _get_attn_mask_aa(self, b, l1, l2, block_size=16):
        """Build an additive (-inf / 0) attention bias of shape (b, l1, l2).

        Sample i in the batch may only attend to the first
        ((i % block_size) + 1) * (l2 // block_size) action tokens; later
        columns receive -inf. The mask is constant across query positions.
        """
        num_token = l2 // block_size
        # Per-sample cutoff column, cycling with period `block_size`.
        start_positions = ((torch.arange(b) % block_size) + 1) * num_token
        col_indices = torch.arange(l2)
        mask_2d = col_indices.unsqueeze(0) >= start_positions.unsqueeze(1)
        mask = mask_2d.unsqueeze(1).expand(b, l1, l2)
        attn_mask = torch.zeros_like(mask, dtype=torch.float)
        attn_mask[mask] = float('-inf')
        return attn_mask
|
||||
|
||||
|
||||
class BasicTransformerBlock(nn.Module):
    """Pre-norm transformer block: self-attn -> cross-attn -> feed-forward,
    each with a residual connection, optionally gradient-checkpointed.

    ``attn1`` is self-attention (unless ``disable_self_attn``, in which case
    it also consumes ``context``); ``attn2`` is cross-attention carrying the
    multi-stream conditioning options through to ``CrossAttention``.
    """

    def __init__(self,
                 dim,
                 n_heads,
                 d_head,
                 dropout=0.,
                 context_dim=None,
                 gated_ff=True,
                 checkpoint=True,
                 disable_self_attn=False,
                 attention_cls=None,
                 video_length=None,
                 agent_state_context_len=2,
                 agent_action_context_len=16,
                 image_cross_attention=False,
                 image_cross_attention_scale=1.0,
                 cross_attention_scale_learnable=False,
                 text_context_len=77):
        super().__init__()
        # Allow a partially-applied CrossAttention (e.g. with relative
        # position enabled) to be injected by TemporalTransformer.
        attn_cls = CrossAttention if attention_cls is None else attention_cls
        self.disable_self_attn = disable_self_attn
        self.attn1 = attn_cls(
            query_dim=dim,
            heads=n_heads,
            dim_head=d_head,
            dropout=dropout,
            # None context_dim => attn1 acts as self-attention.
            context_dim=context_dim if self.disable_self_attn else None)
        self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff)
        self.attn2 = attn_cls(
            query_dim=dim,
            context_dim=context_dim,
            heads=n_heads,
            dim_head=d_head,
            dropout=dropout,
            video_length=video_length,
            agent_state_context_len=agent_state_context_len,
            agent_action_context_len=agent_action_context_len,
            image_cross_attention=image_cross_attention,
            image_cross_attention_scale=image_cross_attention_scale,
            cross_attention_scale_learnable=cross_attention_scale_learnable,
            text_context_len=text_context_len)
        self.image_cross_attention = image_cross_attention

        self.norm1 = nn.LayerNorm(dim)
        self.norm2 = nn.LayerNorm(dim)
        self.norm3 = nn.LayerNorm(dim)
        self.checkpoint = checkpoint

    def forward(self, x, context=None, mask=None, **kwargs):
        # implementation tricks: because checkpointing doesn't support non-tensor (e.g. None or scalar) arguments
        input_tuple = (
            x,
        )  # should not be (x), otherwise *input_tuple will decouple x into multiple arguments
        if context is not None:
            input_tuple = (x, context)
        if mask is not None:
            # Bind the mask via partial so checkpoint only sees tensors.
            # NOTE(review): this path forwards only (x,) — a non-None
            # `context` is dropped when a mask is given. Apparently masks are
            # only used for temporal self-attention (context None); confirm.
            forward_mask = partial(self._forward, mask=mask)
            return checkpoint(forward_mask, (x, ), self.parameters(),
                              self.checkpoint)
        return checkpoint(self._forward, input_tuple, self.parameters(),
                          self.checkpoint)

    def _forward(self, x, context=None, mask=None):
        # Pre-norm residual blocks: x + sublayer(norm(x)).
        x = self.attn1(self.norm1(x),
                       context=context if self.disable_self_attn else None,
                       mask=mask) + x
        x = self.attn2(self.norm2(x), context=context, mask=mask) + x
        x = self.ff(self.norm3(x)) + x
        return x
|
||||
|
||||
|
||||
class SpatialTransformer(nn.Module):
    """
    Transformer block for image-like data in spatial axis.
    First, project the input (aka embedding)
    and reshape to b, t, d.
    Then apply standard transformer action.
    Finally, reshape to image
    NEW: use_linear for more efficiency instead of the 1x1 convs
    """

    def __init__(self,
                 in_channels,
                 n_heads,
                 d_head,
                 depth=1,
                 dropout=0.,
                 context_dim=None,
                 use_checkpoint=True,
                 disable_self_attn=False,
                 use_linear=False,
                 video_length=None,
                 agent_state_context_len=2,
                 agent_action_context_len=16,
                 image_cross_attention=False,
                 cross_attention_scale_learnable=False):
        super().__init__()
        self.in_channels = in_channels
        inner_dim = n_heads * d_head
        self.norm = torch.nn.GroupNorm(num_groups=32,
                                       num_channels=in_channels,
                                       eps=1e-6,
                                       affine=True)
        # 1x1 conv vs. linear projection: mathematically equivalent here;
        # the linear variant operates on the flattened (h*w) token axis.
        if not use_linear:
            self.proj_in = nn.Conv2d(in_channels,
                                     inner_dim,
                                     kernel_size=1,
                                     stride=1,
                                     padding=0)
        else:
            self.proj_in = nn.Linear(in_channels, inner_dim)

        attention_cls = None
        self.transformer_blocks = nn.ModuleList([
            BasicTransformerBlock(
                inner_dim,
                n_heads,
                d_head,
                dropout=dropout,
                context_dim=context_dim,
                disable_self_attn=disable_self_attn,
                checkpoint=use_checkpoint,
                attention_cls=attention_cls,
                video_length=video_length,
                agent_state_context_len=agent_state_context_len,
                agent_action_context_len=agent_action_context_len,
                image_cross_attention=image_cross_attention,
                cross_attention_scale_learnable=cross_attention_scale_learnable,
            ) for d in range(depth)
        ])
        if not use_linear:
            # zero_module: output projection starts at zero so the block is
            # an identity mapping at init (residual-friendly).
            self.proj_out = zero_module(
                nn.Conv2d(inner_dim,
                          in_channels,
                          kernel_size=1,
                          stride=1,
                          padding=0))
        else:
            self.proj_out = zero_module(nn.Linear(inner_dim, in_channels))
        self.use_linear = use_linear

    def forward(self, x, context=None, **kwargs):
        """x: (b, c, h, w) feature map; returns the same shape (residual)."""
        b, c, h, w = x.shape
        x_in = x
        x = self.norm(x)
        # Conv projection happens before flattening; linear after.
        if not self.use_linear:
            x = self.proj_in(x)
        x = rearrange(x, 'b c h w -> b (h w) c').contiguous()
        if self.use_linear:
            x = self.proj_in(x)
        for i, block in enumerate(self.transformer_blocks):
            x = block(x, context=context, **kwargs)
        # Mirror the input projection ordering on the way out.
        if self.use_linear:
            x = self.proj_out(x)
        x = rearrange(x, 'b (h w) c -> b c h w', h=h, w=w).contiguous()
        if not self.use_linear:
            x = self.proj_out(x)
        return x + x_in
|
||||
|
||||
|
||||
class TemporalTransformer(nn.Module):
    """
    Transformer block for image-like data in temporal axis.
    First, reshape to b, t, d.
    Then apply standard transformer action.
    Finally, reshape to image

    Each spatial location (h, w) is treated as an independent sequence of
    length t. Supports causal masking, block-wise causality, and learned
    relative positions (forwarded to CrossAttention).
    """

    def __init__(self,
                 in_channels,
                 n_heads,
                 d_head,
                 depth=1,
                 dropout=0.,
                 context_dim=None,
                 use_checkpoint=True,
                 use_linear=False,
                 only_self_att=True,
                 causal_attention=False,
                 causal_block_size=1,
                 relative_position=False,
                 temporal_length=None):
        super().__init__()
        self.only_self_att = only_self_att
        self.relative_position = relative_position
        self.causal_attention = causal_attention
        self.causal_block_size = causal_block_size

        self.in_channels = in_channels
        inner_dim = n_heads * d_head
        self.norm = torch.nn.GroupNorm(num_groups=32,
                                       num_channels=in_channels,
                                       eps=1e-6,
                                       affine=True)
        # Fix: the original constructed a Conv1d proj_in unconditionally and
        # then immediately rebuilt it in the if/else below — a dead
        # allocation. Only the conditional construction is kept.
        if not use_linear:
            self.proj_in = nn.Conv1d(in_channels,
                                     inner_dim,
                                     kernel_size=1,
                                     stride=1,
                                     padding=0)
        else:
            self.proj_in = nn.Linear(in_channels, inner_dim)

        if relative_position:
            assert (temporal_length is not None)
            attention_cls = partial(CrossAttention,
                                    relative_position=True,
                                    temporal_length=temporal_length)
        else:
            attention_cls = partial(CrossAttention,
                                    temporal_length=temporal_length)
        if self.causal_attention:
            assert (temporal_length is not None)
            # Lower-triangular causal mask, sliced to the actual sequence
            # length in forward(). NOTE(review): plain attribute, not a
            # registered buffer, so it is not moved by .to()/saved in
            # state_dict — forward() moves it to x.device explicitly.
            self.mask = torch.tril(
                torch.ones([1, temporal_length, temporal_length]))

        if self.only_self_att:
            # Pure temporal self-attention: ignore any context dim.
            context_dim = None
        self.transformer_blocks = nn.ModuleList([
            BasicTransformerBlock(inner_dim,
                                  n_heads,
                                  d_head,
                                  dropout=dropout,
                                  context_dim=context_dim,
                                  attention_cls=attention_cls,
                                  checkpoint=use_checkpoint)
            for d in range(depth)
        ])
        if not use_linear:
            # Zero-initialized output projection: identity block at init.
            self.proj_out = zero_module(
                nn.Conv1d(inner_dim,
                          in_channels,
                          kernel_size=1,
                          stride=1,
                          padding=0))
        else:
            self.proj_out = zero_module(nn.Linear(inner_dim, in_channels))
        self.use_linear = use_linear

    def forward(self, x, context=None):
        """x: (b, c, t, h, w); context: optional (b*t, l, con) tokens when
        cross-attention is enabled. Returns (b, c, t, h, w) (residual).
        """
        b, c, t, h, w = x.shape
        x_in = x
        x = self.norm(x)
        # Every (h, w) position becomes its own length-t sequence.
        x = rearrange(x, 'b c t h w -> (b h w) c t').contiguous()
        if not self.use_linear:
            x = self.proj_in(x)
        x = rearrange(x, 'bhw c t -> bhw t c').contiguous()
        if self.use_linear:
            x = self.proj_in(x)

        temp_mask = None
        if self.causal_attention:
            # Slice the causal mask to the current temporal length.
            temp_mask = self.mask[:, :t, :t].to(x.device)

        if temp_mask is not None:
            mask = temp_mask.to(x.device)
            # Broadcast the single mask to every spatial sequence.
            mask = repeat(mask, 'l i j -> (l bhw) i j', bhw=b * h * w)
        else:
            mask = None

        if self.only_self_att:
            # NOTE: if no context is given, cross-attention defaults to self-attention
            for i, block in enumerate(self.transformer_blocks):
                x = block(x, mask=mask)
            x = rearrange(x, '(b hw) t c -> b hw t c', b=b).contiguous()
        else:
            x = rearrange(x, '(b hw) t c -> b hw t c', b=b).contiguous()
            context = rearrange(context, '(b t) l con -> b t l con',
                                t=t).contiguous()
            for i, block in enumerate(self.transformer_blocks):
                # Calculate each batch one by one (since number in shape could not greater then 65,535 for some package)
                for j in range(b):
                    context_j = repeat(context[j],
                                       't l con -> (t r) l con',
                                       r=(h * w) // t,
                                       t=t).contiguous()
                    # Note: causal mask will not applied in cross-attention case
                    x[j] = block(x[j], context=context_j)

        if self.use_linear:
            x = self.proj_out(x)
            x = rearrange(x, 'b (h w) t c -> b c t h w', h=h, w=w).contiguous()
        if not self.use_linear:
            # Conv1d expects channels-first over the temporal axis.
            x = rearrange(x, 'b hw t c -> (b hw) c t').contiguous()
            x = self.proj_out(x)
            x = rearrange(x, '(b h w) c t -> b c t h w', b=b, h=h,
                          w=w).contiguous()

        return x + x_in
|
||||
|
||||
|
||||
class GEGLU(nn.Module):
    """GELU-gated linear unit (Shazeer, "GLU Variants Improve Transformer").

    Projects the input to twice the output width, splits it into a value
    half and a gate half, and returns ``value * gelu(gate)``.
    """

    def __init__(self, dim_in, dim_out):
        super().__init__()
        # Single projection producing both halves at once.
        self.proj = nn.Linear(dim_in, dim_out * 2)

    def forward(self, x):
        value, gate = torch.chunk(self.proj(x), 2, dim=-1)
        return value * F.gelu(gate)
|
||||
|
||||
|
||||
class FeedForward(nn.Module):
    """Transformer MLP: expand by ``mult``, nonlinearity, project back.

    With ``glu=True`` the first stage is a GEGLU gate; otherwise it is a
    plain Linear + GELU. Output width defaults to the input width.
    """

    def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.):
        super().__init__()
        hidden = int(dim * mult)
        target = default(dim_out, dim)
        if glu:
            stem = GEGLU(dim, hidden)
        else:
            stem = nn.Sequential(nn.Linear(dim, hidden), nn.GELU())
        self.net = nn.Sequential(stem, nn.Dropout(dropout),
                                 nn.Linear(hidden, target))

    def forward(self, x):
        return self.net(x)
|
||||
|
||||
|
||||
class LinearAttention(nn.Module):
    """Linear-complexity attention over 2D feature maps.

    Softmax is applied to the keys over the token axis, then K·V is
    aggregated into a (d x e) context per head before multiplying by the
    queries — O(n) in the number of spatial tokens instead of O(n^2).
    """

    def __init__(self, dim, heads=4, dim_head=32):
        super().__init__()
        self.heads = heads
        hidden_dim = dim_head * heads
        # Single 1x1 conv emits q, k and v stacked along the channel axis.
        self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias=False)
        self.to_out = nn.Conv2d(hidden_dim, dim, 1)

    def forward(self, x):
        """x: (b, c, h, w) -> (b, c, h, w)."""
        b, c, h, w = x.shape
        qkv = self.to_qkv(x)
        # Split the stacked projection and flatten spatial dims to tokens:
        # each of q, k, v is (b, heads, dim_head, h*w).
        q, k, v = rearrange(qkv,
                            'b (qkv heads c) h w -> qkv b heads c (h w)',
                            heads=self.heads,
                            qkv=3)
        # Normalize keys across tokens (the linear-attention trick).
        k = k.softmax(dim=-1)
        # Aggregate values weighted by keys into a per-head context matrix.
        context = torch.einsum('bhdn,bhen->bhde', k, v)
        out = torch.einsum('bhde,bhdn->bhen', context, q)
        out = rearrange(out,
                        'b heads c (h w) -> b (heads c) h w',
                        heads=self.heads,
                        h=h,
                        w=w)
        return self.to_out(out)
|
||||
|
||||
|
||||
class SpatialSelfAttention(nn.Module):
    """Single-head self-attention over the spatial positions of a feature
    map (VQGAN/Stable-Diffusion style AttnBlock), with a residual output.
    """

    def __init__(self, in_channels):
        super().__init__()
        self.in_channels = in_channels

        self.norm = torch.nn.GroupNorm(num_groups=32,
                                       num_channels=in_channels,
                                       eps=1e-6,
                                       affine=True)
        # 1x1 convs act as per-pixel linear projections for q, k, v.
        self.q = torch.nn.Conv2d(in_channels,
                                 in_channels,
                                 kernel_size=1,
                                 stride=1,
                                 padding=0)
        self.k = torch.nn.Conv2d(in_channels,
                                 in_channels,
                                 kernel_size=1,
                                 stride=1,
                                 padding=0)
        self.v = torch.nn.Conv2d(in_channels,
                                 in_channels,
                                 kernel_size=1,
                                 stride=1,
                                 padding=0)
        self.proj_out = torch.nn.Conv2d(in_channels,
                                        in_channels,
                                        kernel_size=1,
                                        stride=1,
                                        padding=0)

    def forward(self, x):
        """x: (b, c, h, w) -> (b, c, h, w), residual added."""
        h_ = x
        h_ = self.norm(h_)
        q = self.q(h_)
        k = self.k(h_)
        v = self.v(h_)

        # Compute attention
        b, c, h, w = q.shape
        # q: (b, hw, c); k: (b, c, hw) -> logits w_: (b, hw_q, hw_k).
        q = rearrange(q, 'b c h w -> b (h w) c')
        k = rearrange(k, 'b c h w -> b c (h w)')
        w_ = torch.einsum('bij,bjk->bik', q, k)

        # Scale by 1/sqrt(c) before softmax over the key axis.
        w_ = w_ * (int(c)**(-0.5))
        w_ = torch.nn.functional.softmax(w_, dim=2)

        # Attend to values
        v = rearrange(v, 'b c h w -> b c (h w)')
        # Transpose so the einsum contracts over the key positions.
        w_ = rearrange(w_, 'b i j -> b j i')
        h_ = torch.einsum('bij,bjk->bik', v, w_)
        h_ = rearrange(h_, 'b c (h w) -> b c h w', h=h)
        h_ = self.proj_out(h_)

        return x + h_
|
||||
630
src/unifolm_wma/modules/encoders/condition.py
Normal file
630
src/unifolm_wma/modules/encoders/condition.py
Normal file
@@ -0,0 +1,630 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import kornia
|
||||
import open_clip
|
||||
import math
|
||||
|
||||
from torch.utils.checkpoint import checkpoint
|
||||
from transformers import T5Tokenizer, T5EncoderModel, CLIPTokenizer, CLIPTextModel
|
||||
|
||||
from unifolm_wma.utils.common import autocast
|
||||
from unifolm_wma.utils.utils import count_params
|
||||
from unifolm_wma.modules.encoders.resampler import reshape_tensor
|
||||
|
||||
|
||||
class AbstractEncoder(nn.Module):
    """Base class for conditioning encoders.

    Subclasses implement ``encode`` to map raw conditioning (text, labels,
    images, ...) to embedding tensors.
    """

    def __init__(self):
        super().__init__()

    def encode(self, *args, **kwargs):
        """Encode conditioning input; must be overridden by subclasses."""
        raise NotImplementedError
|
||||
|
||||
|
||||
class IdentityEncoder(AbstractEncoder):
    """Pass-through encoder: the conditioning is already an embedding."""

    def encode(self, x):
        # No transformation — return the input unchanged.
        return x
|
||||
|
||||
|
||||
class ClassEmbedder(nn.Module):
    """Embed integer class labels for cross-attention conditioning, with
    optional unconditional-guidance (ucg) label dropout during training.

    Class index ``n_classes - 1`` is reserved as the "unconditional" class.
    """

    def __init__(self, embed_dim, n_classes=1000, key='class', ucg_rate=0.1):
        super().__init__()
        self.key = key
        self.embedding = nn.Embedding(n_classes, embed_dim)
        self.n_classes = n_classes
        self.ucg_rate = ucg_rate

    def forward(self, batch, key=None, disable_dropout=False):
        if key is None:
            key = self.key
        # this is for use in crossattn: (B,) labels -> (B, 1) token axis.
        c = batch[key][:, None]
        if self.ucg_rate > 0. and not disable_dropout:
            # With probability ucg_rate, replace the label by the reserved
            # unconditional class (n_classes - 1).
            keep = 1. - torch.bernoulli(torch.ones_like(c) * self.ucg_rate)
            c = keep * c + (1 - keep) * torch.ones_like(c) * (self.n_classes -
                                                              1)
        return self.embedding(c.long())

    def get_unconditional_conditioning(self, bs, device="cuda"):
        """Return a batch of the reserved unconditional label."""
        uc_class = self.n_classes - 1  # 1000 classes --> 0 ... 999, one extra class for ucg (class 1000)
        uc = torch.ones((bs, ), device=device) * uc_class
        return {self.key: uc}
|
||||
|
||||
|
||||
def disabled_train(self, mode=True):
    """Replacement for ``nn.Module.train`` that ignores the requested mode.

    Assigned as ``model.train = disabled_train`` on frozen submodules so
    that a parent's ``.train()`` call cannot flip them out of eval mode.
    Returns ``self`` unchanged, matching ``nn.Module.train``'s contract.
    """
    return self
|
||||
|
||||
|
||||
class FrozenT5Embedder(AbstractEncoder):
    """Uses the T5 transformer encoder for text"""

    def __init__(self,
                 version="google/t5-v1_1-xxl",
                 device="cuda",
                 max_length=77,
                 freeze=True
                 ):  # others are google/t5-v1_1-xl and google/t5-v1_1-xxl
        super().__init__()
        # Downloads/loads pretrained weights from the HuggingFace hub.
        self.tokenizer = T5Tokenizer.from_pretrained(version)
        self.transformer = T5EncoderModel.from_pretrained(version)
        # NOTE(review): `device` is only used to move token ids in forward();
        # the transformer itself is not moved here — caller must place it.
        self.device = device
        self.max_length = max_length  # TODO: typical value?
        if freeze:
            self.freeze()

    def freeze(self):
        """Put the encoder in eval mode and disable all gradients."""
        self.transformer = self.transformer.eval()
        # self.train = disabled_train
        for param in self.parameters():
            param.requires_grad = False

    def forward(self, text):
        """Tokenize a list of strings and return the last hidden states,
        shape (batch, max_length, hidden_dim); inputs are padded/truncated
        to ``max_length``.
        """
        batch_encoding = self.tokenizer(text,
                                        truncation=True,
                                        max_length=self.max_length,
                                        return_length=True,
                                        return_overflowing_tokens=False,
                                        padding="max_length",
                                        return_tensors="pt")
        tokens = batch_encoding["input_ids"].to(self.device)
        outputs = self.transformer(input_ids=tokens)

        z = outputs.last_hidden_state
        return z

    def encode(self, text):
        # AbstractEncoder interface: encoding is just the forward pass.
        return self(text)
|
||||
|
||||
|
||||
class FrozenCLIPEmbedder(AbstractEncoder):
    """Uses the CLIP transformer encoder for text (from huggingface)"""
    # Which intermediate representation forward() exposes.
    LAYERS = ["last", "pooled", "hidden"]

    def __init__(self,
                 version="openai/clip-vit-large-patch14",
                 device="cuda",
                 max_length=77,
                 freeze=True,
                 layer="last",
                 layer_idx=None):  # clip-vit-base-patch32
        super().__init__()
        assert layer in self.LAYERS
        # NOTE(review): CLIPTokenizer / CLIPTextModel come from transformers —
        # confirm the module-level imports in the full file.
        self.tokenizer = CLIPTokenizer.from_pretrained(version)
        self.transformer = CLIPTextModel.from_pretrained(version)
        self.device = device
        self.max_length = max_length
        if freeze:
            self.freeze()
        self.layer = layer
        self.layer_idx = layer_idx
        if layer == "hidden":
            # A specific hidden layer needs a valid index into the
            # hidden-state stack (12 layers for the default text tower).
            assert layer_idx is not None
            assert 0 <= abs(layer_idx) <= 12

    def freeze(self):
        # Eval mode + no grads: the text encoder is a fixed feature extractor.
        self.transformer = self.transformer.eval()
        # self.train = disabled_train
        for param in self.parameters():
            param.requires_grad = False

    def forward(self, text):
        """Tokenize and encode `text`; the returned tensor depends on
        `self.layer` (full sequence, pooled vector, or a hidden layer)."""
        batch_encoding = self.tokenizer(text,
                                        truncation=True,
                                        max_length=self.max_length,
                                        return_length=True,
                                        return_overflowing_tokens=False,
                                        padding="max_length",
                                        return_tensors="pt")
        tokens = batch_encoding["input_ids"].to(self.device)
        # Hidden states are only materialized when actually requested.
        outputs = self.transformer(input_ids=tokens,
                                   output_hidden_states=self.layer == "hidden")
        if self.layer == "last":
            z = outputs.last_hidden_state
        elif self.layer == "pooled":
            # (batch, hidden) -> (batch, 1, hidden) to keep a sequence axis.
            z = outputs.pooler_output[:, None, :]
        else:
            z = outputs.hidden_states[self.layer_idx]
        return z

    def encode(self, text):
        # AbstractEncoder interface: encoding is just the forward pass.
        return self(text)
|
||||
|
||||
|
||||
class ClipImageEmbedder(nn.Module):
    # Wraps the original OpenAI CLIP image tower as an embedder with optional
    # unconditional-guidance (ucg) dropout of whole embeddings.

    def __init__(self,
                 model,
                 jit=False,
                 device='cuda' if torch.cuda.is_available() else 'cpu',
                 antialias=True,
                 ucg_rate=0.):
        super().__init__()
        from clip import load as load_clip
        self.model, _ = load_clip(name=model, device=device, jit=jit)

        self.antialias = antialias

        # CLIP preprocessing statistics; not persisted in checkpoints.
        self.register_buffer('mean',
                             torch.Tensor([0.48145466, 0.4578275, 0.40821073]),
                             persistent=False)
        self.register_buffer('std',
                             torch.Tensor([0.26862954, 0.26130258,
                                           0.27577711]),
                             persistent=False)
        self.ucg_rate = ucg_rate

    def preprocess(self, x):
        # normalize to [0,1]
        # NOTE(review): kornia must be imported at module level — confirm.
        x = kornia.geometry.resize(x, (224, 224),
                                   interpolation='bicubic',
                                   align_corners=True,
                                   antialias=self.antialias)
        x = (x + 1.) / 2.
        # re-normalize according to clip
        x = kornia.enhance.normalize(x, self.mean, self.std)
        return x

    def forward(self, x, no_dropout=False):
        # x is assumed to be in range [-1,1]
        out = self.model.encode_image(self.preprocess(x))
        out = out.to(x.dtype)
        if self.ucg_rate > 0. and not no_dropout:
            # Per-sample Bernoulli keep-mask: zero the whole embedding for
            # dropped samples (classifier-free guidance training).
            out = torch.bernoulli(
                (1. - self.ucg_rate) *
                torch.ones(out.shape[0], device=out.device))[:, None] * out
        return out
|
||||
|
||||
|
||||
class FrozenOpenCLIPEmbedder(AbstractEncoder):
    """
    Uses the OpenCLIP transformer encoder for text
    """
    LAYERS = [
        # "pooled",
        "last",
        "penultimate"
    ]

    def __init__(self,
                 arch="ViT-H-14",
                 version="laion2b_s32b_b79k",
                 device="cuda",
                 max_length=77,
                 freeze=True,
                 layer="last"):
        super().__init__()
        assert layer in self.LAYERS
        # Build on CPU, then drop the vision tower: only text is needed here.
        model, _, _ = open_clip.create_model_and_transforms(
            arch, device=torch.device('cpu'), pretrained=version)
        del model.visual
        self.model = model

        self.device = device
        self.max_length = max_length
        if freeze:
            self.freeze()
        self.layer = layer
        # layer_idx counts resblocks skipped from the end (0 = run them all).
        if self.layer == "last":
            self.layer_idx = 0
        elif self.layer == "penultimate":
            self.layer_idx = 1
        else:
            raise NotImplementedError()

    def freeze(self):
        # Eval mode + no grads: the text encoder is a fixed feature extractor.
        self.model = self.model.eval()
        for param in self.parameters():
            param.requires_grad = False

    def forward(self, text):
        tokens = open_clip.tokenize(
            text)  ## all clip models use 77 as context length
        z = self.encode_with_transformer(tokens.to(self.device))
        return z

    def encode_with_transformer(self, text):
        # Re-implements open_clip's text path so we can stop `layer_idx`
        # resblocks before the end.
        x = self.model.token_embedding(text)  # [batch_size, n_ctx, d_model]
        x = x + self.model.positional_embedding
        x = x.permute(1, 0, 2)  # NLD -> LND
        x = self.text_transformer_forward(x, attn_mask=self.model.attn_mask)
        x = x.permute(1, 0, 2)  # LND -> NLD
        x = self.model.ln_final(x)
        return x

    def text_transformer_forward(self, x: torch.Tensor, attn_mask=None):
        # Run resblocks, stopping early for the "penultimate" layer choice.
        for i, r in enumerate(self.model.transformer.resblocks):
            if i == len(self.model.transformer.resblocks) - self.layer_idx:
                break
            if self.model.transformer.grad_checkpointing and not torch.jit.is_scripting(
            ):
                # Trade compute for memory during training.
                x = checkpoint(r, x, attn_mask)
            else:
                x = r(x, attn_mask=attn_mask)
        return x

    def encode(self, text):
        # AbstractEncoder interface: encoding is just the forward pass.
        return self(text)
|
||||
|
||||
|
||||
class FrozenOpenCLIPImageEmbedder(AbstractEncoder):
    """
    Uses the OpenCLIP vision transformer encoder for images
    """

    def __init__(self,
                 arch="ViT-H-14",
                 version="laion2b_s32b_b79k",
                 device="cuda",
                 max_length=77,
                 freeze=True,
                 layer="pooled",
                 antialias=True,
                 ucg_rate=0.):
        super().__init__()
        # Build on CPU, then drop the text tower: only vision is needed here.
        model, _, _ = open_clip.create_model_and_transforms(
            arch,
            device=torch.device('cpu'),
            pretrained=version,
        )
        del model.transformer
        self.model = model
        # self.mapper = torch.nn.Linear(1280, 1024)
        self.device = device
        self.max_length = max_length
        if freeze:
            self.freeze()
        self.layer = layer
        if self.layer == "penultimate":
            raise NotImplementedError()
            self.layer_idx = 1  # unreachable; kept from upstream code

        self.antialias = antialias

        # CLIP preprocessing statistics; not persisted in checkpoints.
        self.register_buffer('mean',
                             torch.Tensor([0.48145466, 0.4578275, 0.40821073]),
                             persistent=False)
        self.register_buffer('std',
                             torch.Tensor([0.26862954, 0.26130258,
                                           0.27577711]),
                             persistent=False)
        self.ucg_rate = ucg_rate

    def preprocess(self, x):
        # normalize to [0,1]
        x = kornia.geometry.resize(x, (224, 224),
                                   interpolation='bicubic',
                                   align_corners=True,
                                   antialias=self.antialias)
        x = (x + 1.) / 2.
        # renormalize according to clip
        x = kornia.enhance.normalize(x, self.mean, self.std)
        return x

    def freeze(self):
        # Freezes the wrapped CLIP model's weights (iterates self.model, not
        # self, unlike the text embedders above).
        self.model = self.model.eval()
        for param in self.model.parameters():
            param.requires_grad = False

    @autocast
    def forward(self, image, no_dropout=False):
        # NOTE(review): `autocast` must be in module scope (e.g. from
        # torch.cuda.amp) — confirm against the full file.
        z = self.encode_with_vision_transformer(image)
        if self.ucg_rate > 0. and not no_dropout:
            # Per-sample embedding dropout for classifier-free guidance.
            z = torch.bernoulli(
                (1. - self.ucg_rate) *
                torch.ones(z.shape[0], device=z.device))[:, None] * z
        return z

    def encode_with_vision_transformer(self, img):
        img = self.preprocess(img)
        x = self.model.visual(img)
        return x

    def encode(self, text):
        # Despite the parameter name, this expects an image batch.
        return self(text)
|
||||
|
||||
|
||||
class FrozenOpenCLIPImageEmbedderV2(AbstractEncoder):
    """
    Uses the OpenCLIP vision transformer encoder for images
    """

    def __init__(self,
                 arch="ViT-H-14",
                 version="laion2b_s32b_b79k",
                 device="cuda",
                 freeze=True,
                 layer="pooled",
                 antialias=True):
        super().__init__()
        # Build on CPU, then drop the text tower: only vision is needed here.
        model, _, _ = open_clip.create_model_and_transforms(
            arch,
            device=torch.device('cpu'),
            pretrained=version,
        )
        del model.transformer
        self.model = model
        self.device = device

        if freeze:
            self.freeze()
        self.layer = layer
        if self.layer == "penultimate":
            raise NotImplementedError()
            self.layer_idx = 1  # unreachable; kept from upstream code

        self.antialias = antialias

        # CLIP preprocessing statistics; not persisted in checkpoints.
        self.register_buffer('mean',
                             torch.Tensor([0.48145466, 0.4578275, 0.40821073]),
                             persistent=False)
        self.register_buffer('std',
                             torch.Tensor([0.26862954, 0.26130258,
                                           0.27577711]),
                             persistent=False)

    def preprocess(self, x):
        # normalize to [0,1]
        x = kornia.geometry.resize(x, (224, 224),
                                   interpolation='bicubic',
                                   align_corners=True,
                                   antialias=self.antialias)
        x = (x + 1.) / 2.
        # renormalize according to clip
        x = kornia.enhance.normalize(x, self.mean, self.std)
        return x

    def freeze(self):
        # Freezes the wrapped CLIP model's weights.
        self.model = self.model.eval()
        for param in self.model.parameters():
            param.requires_grad = False

    def forward(self, image, no_dropout=False):
        ## image: b c h w
        # Unlike V1, this returns the full token grid (no pooling, no ucg
        # dropout; `no_dropout` is accepted for interface parity but unused).
        z = self.encode_with_vision_transformer(image)
        return z

    def encode_with_vision_transformer(self, x):
        # Replays the open_clip ViT trunk manually so we get per-patch tokens
        # (batch, grid**2 + 1, width) WITHOUT the final pooling/projection.
        x = self.preprocess(x)

        # to patches - whether to use dual patchnorm - https://arxiv.org/abs/2302.01327v1
        if self.model.visual.input_patchnorm:
            # einops - rearrange(x, 'b c (h p1) (w p2) -> b (h w) (c p1 p2)')
            x = x.reshape(x.shape[0], x.shape[1],
                          self.model.visual.grid_size[0],
                          self.model.visual.patch_size[0],
                          self.model.visual.grid_size[1],
                          self.model.visual.patch_size[1])
            x = x.permute(0, 2, 4, 1, 3, 5)
            x = x.reshape(
                x.shape[0], self.model.visual.grid_size[0] *
                self.model.visual.grid_size[1], -1)
            x = self.model.visual.patchnorm_pre_ln(x)
            x = self.model.visual.conv1(x)
        else:
            x = self.model.visual.conv1(x)  # shape = [*, width, grid, grid]
            x = x.reshape(x.shape[0], x.shape[1],
                          -1)  # shape = [*, width, grid ** 2]
            x = x.permute(0, 2, 1)  # shape = [*, grid ** 2, width]

        # class embeddings and positional embeddings
        x = torch.cat([
            self.model.visual.class_embedding.to(x.dtype) + torch.zeros(
                x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), x
        ],
                      dim=1)  # shape = [*, grid ** 2 + 1, width]
        x = x + self.model.visual.positional_embedding.to(x.dtype)

        # a patch_dropout of 0. would mean it is disabled and this function would do nothing but return what was passed in
        x = self.model.visual.patch_dropout(x)
        x = self.model.visual.ln_pre(x)

        x = x.permute(1, 0, 2)  # NLD -> LND
        x = self.model.visual.transformer(x)
        x = x.permute(1, 0, 2)  # LND -> NLD

        return x
|
||||
|
||||
|
||||
class FrozenCLIPT5Encoder(AbstractEncoder):
    # Joint text encoder: returns both CLIP and T5 embeddings for one prompt.

    def __init__(self,
                 clip_version="openai/clip-vit-large-patch14",
                 t5_version="google/t5-v1_1-xl",
                 device="cuda",
                 clip_max_length=77,
                 t5_max_length=77):
        super().__init__()
        self.clip_encoder = FrozenCLIPEmbedder(clip_version,
                                               device,
                                               max_length=clip_max_length)
        self.t5_encoder = FrozenT5Embedder(t5_version,
                                           device,
                                           max_length=t5_max_length)
        # NOTE(review): `count_params` must be in module scope — confirm.
        print(
            f"{self.clip_encoder.__class__.__name__} has {count_params(self.clip_encoder) * 1.e-6:.2f} M parameters, "
            f"{self.t5_encoder.__class__.__name__} comes with {count_params(self.t5_encoder) * 1.e-6:.2f} M params."
        )

    def encode(self, text):
        # AbstractEncoder interface: encoding is just the forward pass.
        return self(text)

    def forward(self, text):
        # Returned as a list so downstream code can route each embedding.
        clip_z = self.clip_encoder.encode(text)
        t5_z = self.t5_encoder.encode(text)
        return [clip_z, t5_z]
|
||||
|
||||
|
||||
class LinearProjector(nn.Module):
    """Single affine projection (Linear with bias) between feature spaces."""

    def __init__(self, input_dim: int, output_dim: int) -> None:
        super().__init__()
        self.projector = nn.Linear(input_dim, output_dim, bias=True)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Map features from ``input_dim`` to ``output_dim``."""
        return self.projector(x)
|
||||
|
||||
|
||||
class MLPProjector(nn.Module):
    """Two-layer MLP projector with a configurable activation.

    Supported ``mlp_type`` values:
        * "gelu-mlp": Linear -> GELU(tanh approximation) -> Linear
        * "silu-mlp": Linear -> SiLU -> Linear
    """

    def __init__(self,
                 input_dim: int,
                 output_dim: int,
                 mlp_type: str = "gelu-mlp") -> None:
        super().__init__()
        if mlp_type == "gelu-mlp":
            activation = nn.GELU(approximate='tanh')
        elif mlp_type == "silu-mlp":
            activation = nn.SiLU()
        else:
            raise ValueError(
                f"Projector with `{mlp_type = }` is not supported!")
        # Sequential keeps submodule indices (0, 1, 2) identical to the
        # original layout, so state_dict keys are unchanged.
        self.projector = nn.Sequential(
            nn.Linear(input_dim, output_dim, bias=True),
            activation,
            nn.Linear(output_dim, output_dim, bias=True),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Project ``x`` from ``input_dim`` to ``output_dim`` features."""
        return self.projector(x)
|
||||
|
||||
|
||||
class PerceiverAttention(nn.Module):
    # Perceiver-style cross-attention: learned latent tokens attend over the
    # concatenation of input features and the latents themselves.
    # NOTE(review): this class calls `reshape_tensor` and `math`, which are
    # not among this module's visible imports — confirm they are in scope.

    def __init__(self, *, dim, dim_head=64, heads=8):
        super().__init__()
        # NOTE(review): `scale` appears unused; forward() rescales with
        # 1/sqrt(sqrt(dim_head)) split across q and k instead.
        self.scale = dim_head**-0.5
        self.dim_head = dim_head
        self.heads = heads
        inner_dim = dim_head * heads

        self.norm1 = nn.LayerNorm(dim)
        self.norm2 = nn.LayerNorm(dim)

        self.to_q = nn.Linear(dim, inner_dim, bias=False)
        self.to_kv = nn.Linear(dim, inner_dim * 2, bias=False)
        self.to_out = nn.Linear(inner_dim, dim, bias=False)

    def forward(self, x, latents):
        """
        Args:
            x (torch.Tensor): image features
                shape (b, n1, D)
            latent (torch.Tensor): latent features
                shape (b, n2, D)
        """

        x = self.norm1(x)
        latents = self.norm2(latents)

        b, l, _ = latents.shape

        q = self.to_q(latents)
        # Keys/values attend over both the inputs and the current latents.
        kv_input = torch.cat((x, latents), dim=-2)
        k, v = self.to_kv(kv_input).chunk(2, dim=-1)

        q = reshape_tensor(q, self.heads)
        k = reshape_tensor(k, self.heads)
        v = reshape_tensor(v, self.heads)

        # attention
        # The 1/sqrt(dim_head) factor is split across q and k.
        scale = 1 / math.sqrt(math.sqrt(self.dim_head))
        weight = (q * scale) @ (k * scale).transpose(
            -2, -1)  # More stable with f16 than dividing afterwards
        # Softmax in fp32 for numerical stability, then cast back.
        weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)
        out = weight @ v

        # Merge heads: (b, heads, l, dim_head) -> (b, l, heads*dim_head).
        out = out.permute(0, 2, 1, 3).reshape(b, l, -1)

        return self.to_out(out)
|
||||
|
||||
|
||||
def FeedForward(dim, mult=4, ffd_type="gelu-ffd"):
    """Build a pre-norm two-layer feed-forward block.

    Args:
        dim: Input/output feature dimension.
        mult: Hidden-layer expansion factor (hidden size = dim * mult).
        ffd_type: "gelu-ffd" for tanh-approximated GELU, or "silu-ffd"
            for SiLU activation.

    Returns:
        nn.Sequential: LayerNorm -> Linear -> activation -> Linear.

    Raises:
        ValueError: If ``ffd_type`` is not a supported kind.
    """
    inner_dim = int(dim * mult)
    if ffd_type == "gelu-ffd":
        return nn.Sequential(
            nn.LayerNorm(dim),
            nn.Linear(dim, inner_dim, bias=False),
            nn.GELU(approximate='tanh'),
            nn.Linear(inner_dim, dim, bias=False),
        )
    elif ffd_type == "silu-ffd":
        return nn.Sequential(
            nn.LayerNorm(dim),
            nn.Linear(dim, inner_dim, bias=False),
            nn.SiLU(),
            nn.Linear(inner_dim, dim, bias=False),
        )
    else:
        # Bug fix: the message previously referenced the undefined name
        # `mlp_type`, so this branch raised NameError instead of the
        # intended ValueError.
        raise ValueError(f"FeedForward with `{ffd_type = }` is not supported!")
|
||||
|
||||
|
||||
class SATokenProjector(nn.Module):
    # Perceiver-resampler-style projector: learned query tokens attend over
    # the input sequence, then are linearly projected to `output_dim`.

    def __init__(self,
                 dim=1024,
                 depth=1,
                 dim_head=64,
                 heads=16,
                 num_queries=16,
                 output_dim=1024,
                 ff_mult=4,
                 chunk_size=None):
        super().__init__()
        self.num_queries = num_queries
        self.chunk_size = chunk_size

        # One query set per chunk when chunking is enabled.
        if chunk_size is not None:
            num_queries = num_queries * chunk_size

        self.latents = nn.Parameter(
            torch.randn(1, num_queries, dim) / dim**0.5)
        self.proj_out = nn.Linear(dim, output_dim)
        # NOTE(review): norm_out is LayerNorm(dim) but is applied AFTER
        # proj_out, whose output has `output_dim` features — this only works
        # when dim == output_dim (Resampler uses LayerNorm(output_dim)).
        self.norm_out = nn.LayerNorm(dim)

        self.layers = nn.ModuleList([])
        for _ in range(depth):
            self.layers.append(
                nn.ModuleList([
                    PerceiverAttention(dim=dim, dim_head=dim_head,
                                       heads=heads),
                    FeedForward(dim=dim, mult=ff_mult),
                ]))

    def forward(self, x):
        # Broadcast the learned queries over the batch.
        latents = self.latents.repeat(x.size(0), 1, 1)
        for attn, ff in self.layers:
            # Pre-norm residual attention followed by residual feed-forward.
            latents = attn(x, latents) + latents
            latents = ff(latents) + latents
        latents = self.proj_out(latents)
        latents = self.norm_out(latents)
        return latents
|
||||
153
src/unifolm_wma/modules/encoders/resampler.py
Normal file
153
src/unifolm_wma/modules/encoders/resampler.py
Normal file
@@ -0,0 +1,153 @@
|
||||
# modified from https://github.com/mlfoundations/open_flamingo/blob/main/open_flamingo/src/helpers.py
|
||||
# and https://github.com/lucidrains/imagen-pytorch/blob/main/imagen_pytorch/imagen_pytorch.py
|
||||
# and https://github.com/tencent-ailab/IP-Adapter/blob/main/ip_adapter/resampler.py
|
||||
import math
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
|
||||
class ImageProjModel(nn.Module):
    """Projection Model: maps one CLIP image embedding to several
    cross-attention context tokens."""

    def __init__(self,
                 cross_attention_dim=1024,
                 clip_embeddings_dim=1024,
                 clip_extra_context_tokens=4):
        super().__init__()
        self.cross_attention_dim = cross_attention_dim
        self.clip_extra_context_tokens = clip_extra_context_tokens
        self.proj = nn.Linear(
            clip_embeddings_dim,
            self.clip_extra_context_tokens * cross_attention_dim)
        self.norm = nn.LayerNorm(cross_attention_dim)

    def forward(self, image_embeds):
        """Map (B, clip_embeddings_dim) embeddings to
        (B, clip_extra_context_tokens, cross_attention_dim) tokens."""
        # Match the projection weights' dtype (e.g. under fp16 inference).
        proj_dtype = next(self.proj.parameters()).dtype
        tokens = self.proj(image_embeds.type(proj_dtype))
        tokens = tokens.reshape(-1, self.clip_extra_context_tokens,
                                self.cross_attention_dim)
        return self.norm(tokens)
|
||||
|
||||
|
||||
# FFN
def FeedForward(dim, mult=4):
    """Pre-norm feed-forward block: LayerNorm -> Linear -> GELU -> Linear."""
    hidden = int(dim * mult)
    layers = (
        nn.LayerNorm(dim),
        nn.Linear(dim, hidden, bias=False),
        nn.GELU(),
        nn.Linear(hidden, dim, bias=False),
    )
    return nn.Sequential(*layers)
|
||||
|
||||
|
||||
def reshape_tensor(x, heads):
    """Split the last axis of a (bs, length, width) tensor into heads.

    Returns shape (bs, heads, length, width // heads); note the heads stay
    a separate axis and are NOT folded into the batch dimension.
    """
    bs, length, _ = x.shape
    # (bs, length, width) -> (bs, length, heads, dim_per_head)
    x = x.view(bs, length, heads, -1)
    # -> (bs, heads, length, dim_per_head), made contiguous by the reshape
    return x.transpose(1, 2).reshape(bs, heads, length, -1)
|
||||
|
||||
|
||||
class PerceiverAttention(nn.Module):
    # Perceiver-style cross-attention: learned latent tokens attend over the
    # concatenation of input features and the latents themselves.

    def __init__(self, *, dim, dim_head=64, heads=8):
        super().__init__()
        # NOTE(review): `scale` appears unused; forward() rescales with
        # 1/sqrt(sqrt(dim_head)) split across q and k instead.
        self.scale = dim_head**-0.5
        self.dim_head = dim_head
        self.heads = heads
        inner_dim = dim_head * heads

        self.norm1 = nn.LayerNorm(dim)
        self.norm2 = nn.LayerNorm(dim)

        self.to_q = nn.Linear(dim, inner_dim, bias=False)
        self.to_kv = nn.Linear(dim, inner_dim * 2, bias=False)
        self.to_out = nn.Linear(inner_dim, dim, bias=False)

    def forward(self, x, latents):
        """
        Args:
            x (torch.Tensor): image features
                shape (b, n1, D)
            latent (torch.Tensor): latent features
                shape (b, n2, D)
        """
        x = self.norm1(x)
        latents = self.norm2(latents)

        b, l, _ = latents.shape

        q = self.to_q(latents)
        # Keys/values attend over both the inputs and the current latents.
        kv_input = torch.cat((x, latents), dim=-2)
        k, v = self.to_kv(kv_input).chunk(2, dim=-1)

        q = reshape_tensor(q, self.heads)
        k = reshape_tensor(k, self.heads)
        v = reshape_tensor(v, self.heads)

        # attention
        # The 1/sqrt(dim_head) factor is split across q and k.
        scale = 1 / math.sqrt(math.sqrt(self.dim_head))
        weight = (q * scale) @ (k * scale).transpose(
            -2, -1)  # More stable with f16 than dividing afterwards
        # Softmax in fp32 for numerical stability, then cast back.
        weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)
        out = weight @ v

        # Merge heads: (b, heads, l, dim_head) -> (b, l, heads*dim_head).
        out = out.permute(0, 2, 1, 3).reshape(b, l, -1)

        return self.to_out(out)
|
||||
|
||||
|
||||
class Resampler(nn.Module):
    # IP-Adapter-style resampler: learned query latents repeatedly attend to
    # projected input embeddings and are mapped to `output_dim` tokens.

    def __init__(
        self,
        dim=1024,
        depth=8,
        dim_head=64,
        heads=16,
        num_queries=8,
        embedding_dim=768,
        output_dim=1024,
        ff_mult=4,
        video_length=None,  # using frame-wise version or not
    ):
        super().__init__()
        ## queries for a single frame / image
        self.num_queries = num_queries
        self.video_length = video_length

        ## <num_queries> queries for each frame
        if video_length is not None:
            num_queries = num_queries * video_length

        self.latents = nn.Parameter(
            torch.randn(1, num_queries, dim) / dim**0.5)
        self.proj_in = nn.Linear(embedding_dim, dim)
        self.proj_out = nn.Linear(dim, output_dim)
        self.norm_out = nn.LayerNorm(output_dim)

        self.layers = nn.ModuleList([])
        for _ in range(depth):
            self.layers.append(
                nn.ModuleList([
                    PerceiverAttention(dim=dim, dim_head=dim_head,
                                       heads=heads),
                    FeedForward(dim=dim, mult=ff_mult),
                ]))

    def forward(self, x):
        """Resample (B, n, embedding_dim) features into
        (B, num_queries[*video_length], output_dim) latent tokens."""
        # Broadcast the learned queries over the batch.
        latents = self.latents.repeat(x.size(0), 1, 1)
        x = self.proj_in(x)

        for attn, ff in self.layers:
            # Residual attention followed by residual feed-forward.
            latents = attn(x, latents) + latents
            latents = ff(latents) + latents

        latents = self.proj_out(latents)
        latents = self.norm_out(latents)

        return latents
|
||||
1005
src/unifolm_wma/modules/networks/ae_modules.py
Normal file
1005
src/unifolm_wma/modules/networks/ae_modules.py
Normal file
File diff suppressed because it is too large
Load Diff
848
src/unifolm_wma/modules/networks/wma_model.py
Normal file
848
src/unifolm_wma/modules/networks/wma_model.py
Normal file
@@ -0,0 +1,848 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from torch import Tensor
|
||||
from functools import partial
|
||||
from abc import abstractmethod
|
||||
from einops import rearrange
|
||||
from omegaconf import OmegaConf
|
||||
from typing import Optional, Sequence, Any, Tuple, Union, List, Dict
|
||||
from collections.abc import Mapping, Iterable, Callable
|
||||
|
||||
from unifolm_wma.utils.diffusion import timestep_embedding
|
||||
from unifolm_wma.utils.common import checkpoint
|
||||
from unifolm_wma.utils.basics import (zero_module, conv_nd, linear,
|
||||
avg_pool_nd, normalization)
|
||||
from unifolm_wma.modules.attention import SpatialTransformer, TemporalTransformer
|
||||
from unifolm_wma.utils.utils import instantiate_from_config
|
||||
|
||||
|
||||
class TimestepBlock(nn.Module):
    """
    Any module where forward() takes timestep embeddings as a second argument.
    """

    @abstractmethod
    def forward(self, x, emb):
        """
        Apply the module to `x` given `emb` timestep embeddings.
        """
|
||||
|
||||
|
||||
class TimestepEmbedSequential(nn.Sequential, TimestepBlock):
    """
    A sequential module that passes timestep embeddings to the children that
    support it as an extra input.
    """

    def forward(self, x, emb, context=None, batch_size=None):
        # Per-layer dispatch: timestep blocks receive `emb`, spatial
        # transformers receive cross-attention `context`, and temporal
        # transformers additionally need frames regrouped onto a time axis.
        for layer in self:
            if isinstance(layer, TimestepBlock):
                x = layer(x, emb, batch_size=batch_size)
            elif isinstance(layer, SpatialTransformer):
                x = layer(x, context)
            elif isinstance(layer, TemporalTransformer):
                # (batch*frames, c, h, w) -> (batch, c, frames, h, w)
                x = rearrange(x, '(b f) c h w -> b c f h w', b=batch_size)
                x = layer(x, context)
                x = rearrange(x, 'b c f h w -> (b f) c h w')
            else:
                x = layer(x)
        return x
|
||||
|
||||
|
||||
class Downsample(nn.Module):
    """
    A downsampling layer with an optional convolution.
    :param channels: channels in the inputs and outputs.
    :param use_conv: a bool determining if a convolution is applied.
    :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
                 downsampling occurs in the inner-two dimensions.
    """

    def __init__(self,
                 channels,
                 use_conv,
                 dims=2,
                 out_channels=None,
                 padding=1):
        super().__init__()
        self.channels = channels
        self.out_channels = out_channels or channels
        self.use_conv = use_conv
        self.dims = dims
        # 3D signals keep their temporal (first) extent; only spatial axes halve.
        stride = (1, 2, 2) if dims == 3 else 2
        if use_conv:
            self.op = conv_nd(dims,
                              self.channels,
                              self.out_channels,
                              3,
                              stride=stride,
                              padding=padding)
        else:
            # Average pooling cannot change the channel count.
            assert self.channels == self.out_channels
            self.op = avg_pool_nd(dims, kernel_size=stride, stride=stride)

    def forward(self, x):
        # Channel count is part of the layer contract.
        assert x.shape[1] == self.channels
        return self.op(x)
|
||||
|
||||
|
||||
class Upsample(nn.Module):
    """
    An upsampling layer with an optional convolution.
    :param channels: channels in the inputs and outputs.
    :param use_conv: a bool determining if a convolution is applied.
    :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
                 upsampling occurs in the inner-two dimensions.
    """

    def __init__(self,
                 channels,
                 use_conv,
                 dims=2,
                 out_channels=None,
                 padding=1):
        super().__init__()
        self.channels = channels
        self.out_channels = out_channels or channels
        self.use_conv = use_conv
        self.dims = dims
        if use_conv:
            self.conv = conv_nd(dims,
                                self.channels,
                                self.out_channels,
                                3,
                                padding=padding)

    def forward(self, x):
        # Channel count is part of the layer contract.
        assert x.shape[1] == self.channels
        if self.dims == 3:
            # Keep the temporal extent; double the two spatial extents.
            target = (x.shape[2], x.shape[3] * 2, x.shape[4] * 2)
            x = F.interpolate(x, target, mode='nearest')
        else:
            x = F.interpolate(x, scale_factor=2, mode='nearest')
        return self.conv(x) if self.use_conv else x
|
||||
|
||||
|
||||
class ResBlock(TimestepBlock):
    """
    A residual block that can optionally change the number of channels.
    :param channels: the number of input channels.
    :param emb_channels: the number of timestep embedding channels.
    :param dropout: the rate of dropout.
    :param out_channels: if specified, the number of out channels.
    :param use_conv: if True and out_channels is specified, use a spatial
        convolution instead of a smaller 1x1 convolution to change the
        channels in the skip connection.
    :param dims: determines if the signal is 1D, 2D, or 3D.
    :param up: if True, use this block for upsampling.
    :param down: if True, use this block for downsampling.
    :param use_temporal_conv: if True, use the temporal convolution.
    :param use_image_dataset: if True, the temporal parameters will not be optimized.
    """

    def __init__(self,
                 channels,
                 emb_channels,
                 dropout,
                 out_channels=None,
                 use_scale_shift_norm=False,
                 dims=2,
                 use_checkpoint=False,
                 use_conv=False,
                 up=False,
                 down=False,
                 use_temporal_conv=False,
                 tempspatial_aware=False):
        super().__init__()
        self.channels = channels
        self.emb_channels = emb_channels
        self.dropout = dropout
        self.out_channels = out_channels or channels
        self.use_conv = use_conv
        self.use_checkpoint = use_checkpoint
        self.use_scale_shift_norm = use_scale_shift_norm
        self.use_temporal_conv = use_temporal_conv

        # Norm -> SiLU -> conv, possibly changing the channel count.
        self.in_layers = nn.Sequential(
            normalization(channels),
            nn.SiLU(),
            conv_nd(dims, channels, self.out_channels, 3, padding=1),
        )

        self.updown = up or down

        # When resampling, both the hidden path (h_upd) and the skip input
        # (x_upd) are resampled so the residual add stays shape-consistent.
        if up:
            self.h_upd = Upsample(channels, False, dims)
            self.x_upd = Upsample(channels, False, dims)
        elif down:
            self.h_upd = Downsample(channels, False, dims)
            self.x_upd = Downsample(channels, False, dims)
        else:
            self.h_upd = self.x_upd = nn.Identity()

        # Projects the timestep embedding; doubled width when the embedding
        # supplies (scale, shift) for scale-shift normalization.
        self.emb_layers = nn.Sequential(
            nn.SiLU(),
            nn.Linear(
                emb_channels,
                2 * self.out_channels
                if use_scale_shift_norm else self.out_channels,
            ),
        )
        # Final conv is zero-initialized so the block starts as an identity
        # perturbation of the skip path.
        self.out_layers = nn.Sequential(
            normalization(self.out_channels),
            nn.SiLU(),
            nn.Dropout(p=dropout),
            zero_module(
                nn.Conv2d(self.out_channels, self.out_channels, 3, padding=1)),
        )

        # Skip connection: identity when channels match, otherwise a conv.
        if self.out_channels == channels:
            self.skip_connection = nn.Identity()
        elif use_conv:
            self.skip_connection = conv_nd(dims,
                                           channels,
                                           self.out_channels,
                                           3,
                                           padding=1)
        else:
            self.skip_connection = conv_nd(dims, channels, self.out_channels,
                                           1)

        if self.use_temporal_conv:
            # "temopral" typo kept: renaming would break checkpoint keys.
            self.temopral_conv = TemporalConvBlock(
                self.out_channels,
                self.out_channels,
                dropout=0.1,
                spatial_aware=tempspatial_aware)

    def forward(self, x, emb, batch_size=None):
        """
        Apply the block to a Tensor, conditioned on a timestep embedding.
        :param x: an [N x C x ...] Tensor of features.
        :param emb: an [N x emb_channels] Tensor of timestep embeddings.
        :return: an [N x C x ...] Tensor of outputs.
        """
        input_tuple = (x, emb)
        if batch_size:
            # Bind batch_size via partial because checkpoint() only forwards
            # the positional input tuple.
            forward_batchsize = partial(self._forward, batch_size=batch_size)
            return checkpoint(forward_batchsize, input_tuple,
                              self.parameters(), self.use_checkpoint)
        return checkpoint(self._forward, input_tuple, self.parameters(),
                          self.use_checkpoint)

    def _forward(self, x, emb, batch_size=None):
        if self.updown:
            # Resample between the pre-activation and the conv so both the
            # hidden path and the skip input see the new resolution.
            in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1]
            h = in_rest(x)
            h = self.h_upd(h)
            x = self.x_upd(x)
            h = in_conv(h)
        else:
            h = self.in_layers(x)
        emb_out = self.emb_layers(emb).type(h.dtype)
        # Broadcast the embedding over the trailing spatial dims.
        while len(emb_out.shape) < len(h.shape):
            emb_out = emb_out[..., None]
        if self.use_scale_shift_norm:
            # FiLM-style conditioning: h = norm(h) * (1 + scale) + shift.
            out_norm, out_rest = self.out_layers[0], self.out_layers[1:]
            scale, shift = torch.chunk(emb_out, 2, dim=1)
            h = out_norm(h) * (1 + scale) + shift
            h = out_rest(h)
        else:
            h = h + emb_out
            h = self.out_layers(h)
        h = self.skip_connection(x) + h

        if self.use_temporal_conv and batch_size:
            # Regroup frames along a time axis for the temporal conv.
            h = rearrange(h, '(b t) c h w -> b c t h w', b=batch_size)
            h = self.temopral_conv(h)
            h = rearrange(h, 'b c t h w -> (b t) c h w')
        return h
|
||||
|
||||
|
||||
class TemporalConvBlock(nn.Module):
    """
    Residual block of four temporal 3D convolutions over (b, c, t, h, w).
    Adapted from modelscope: https://github.com/modelscope/modelscope/blob/master/modelscope/models/multi_modal/video_synthesis/unet_sd.py
    """

    def __init__(self,
                 in_channels,
                 out_channels=None,
                 dropout=0.0,
                 spatial_aware=False):
        super().__init__()
        if out_channels is None:
            out_channels = in_channels
        self.in_channels = in_channels
        self.out_channels = out_channels

        # Kernels/paddings over (t, h, w): purely temporal by default;
        # `spatial_aware` mixes in one spatial axis per kernel (h or w).
        if spatial_aware:
            th_kernel_shape, th_padding_shape = (3, 3, 1), (1, 1, 0)
            tw_kernel_shape, tw_padding_shape = (3, 1, 3), (1, 0, 1)
        else:
            th_kernel_shape = tw_kernel_shape = (3, 1, 1)
            th_padding_shape = tw_padding_shape = (1, 0, 0)

        # conv layers
        self.conv1 = nn.Sequential(
            nn.GroupNorm(32, in_channels), nn.SiLU(),
            nn.Conv3d(in_channels,
                      out_channels,
                      th_kernel_shape,
                      padding=th_padding_shape))
        # NOTE(review): conv2-conv4 map out_channels -> in_channels; the
        # stack only type-checks when in_channels == out_channels — confirm
        # that callers never pass differing values.
        self.conv2 = nn.Sequential(
            nn.GroupNorm(32, out_channels), nn.SiLU(), nn.Dropout(dropout),
            nn.Conv3d(out_channels,
                      in_channels,
                      tw_kernel_shape,
                      padding=tw_padding_shape))
        self.conv3 = nn.Sequential(
            nn.GroupNorm(32, out_channels), nn.SiLU(), nn.Dropout(dropout),
            nn.Conv3d(out_channels,
                      in_channels,
                      th_kernel_shape,
                      padding=th_padding_shape))
        self.conv4 = nn.Sequential(
            nn.GroupNorm(32, out_channels), nn.SiLU(), nn.Dropout(dropout),
            nn.Conv3d(out_channels,
                      in_channels,
                      tw_kernel_shape,
                      padding=tw_padding_shape))

        # Zero-initialize the final conv so the block starts as an identity.
        nn.init.zeros_(self.conv4[-1].weight)
        nn.init.zeros_(self.conv4[-1].bias)

    def forward(self, x):
        """Apply the four conv stages and add the residual skip."""
        residual = x
        for stage in (self.conv1, self.conv2, self.conv3, self.conv4):
            x = stage(x)
        return residual + x
|
||||
|
||||
|
||||
class WMAModel(nn.Module):
    """
    The full World-Model-Action model.

    A video-diffusion UNet (spatial + temporal transformer blocks) whose
    multi-scale encoder/decoder features are additionally fed into two small
    UNet heads that predict robot actions and states.
    """

    def __init__(self,
                 in_channels: int,
                 model_channels: int,
                 out_channels: int,
                 num_res_blocks: int,
                 attention_resolutions: Sequence[int],
                 dropout: float = 0.0,
                 channel_mult: Sequence[int] = (1, 2, 4, 8),
                 conv_resample: bool = True,
                 dims: int = 2,
                 context_dim: int | None = None,
                 use_scale_shift_norm: bool = False,
                 resblock_updown: bool = False,
                 num_heads: int = -1,
                 num_head_channels: int = -1,
                 transformer_depth: int = 1,
                 use_linear: bool = False,
                 use_checkpoint: bool = False,
                 temporal_conv: bool = False,
                 tempspatial_aware: bool = False,
                 temporal_attention: bool = True,
                 use_relative_position: bool = True,
                 use_causal_attention: bool = False,
                 temporal_length: int | None = None,
                 use_fp16: bool = False,
                 addition_attention: bool = False,
                 temporal_selfatt_only: bool = True,
                 image_cross_attention: bool = False,
                 cross_attention_scale_learnable: bool = False,
                 default_fs: int = 4,
                 fs_condition: bool = False,
                 n_obs_steps: int = 1,
                 num_stem_token: int = 1,
                 unet_head_config: OmegaConf | None = None,
                 stem_process_config: OmegaConf | None = None,
                 base_model_gen_only: bool = False):
        """
        Initialize the World-Model-Action network.

        Args:
            in_channels: Number of input channels to the backbone.
            model_channels: Base channel width for the UNet/backbone.
            out_channels: Number of output channels.
            num_res_blocks: Number of residual blocks per resolution stage.
            attention_resolutions: Resolutions at which to enable attention.
            dropout: Dropout probability used inside residual/attention blocks.
            channel_mult: Multipliers for channels at each resolution level.
            conv_resample: If True, use convolutional resampling for up/down sampling.
            dims: Spatial dimensionality of the backbone (1/2/3).
            context_dim: Optional context embedding dimension (for cross-attention).
            use_scale_shift_norm: Enable scale-shift (FiLM-style) normalization in blocks.
            resblock_updown: Use residual blocks for up/down sampling (instead of plain conv).
            num_heads: Number of attention heads (if >= 0). If -1, derive from num_head_channels.
            num_head_channels: Channels per attention head (if >= 0). If -1, derive from num_heads.
            transformer_depth: Number of transformer/attention blocks per stage.
            use_linear: Use linear attention variants where applicable.
            use_checkpoint: Enable gradient checkpointing in blocks to save memory.
            temporal_conv: Include temporal convolution along the time dimension.
            tempspatial_aware: If True, use time–space aware blocks.
            temporal_attention: Enable temporal self-attention.
            use_relative_position: Use relative position encodings in attention.
            use_causal_attention: Use causal (uni-directional) attention along time.
            temporal_length: Optional maximum temporal length expected by the model.
            use_fp16: Enable half-precision layers/normalization where supported.
            addition_attention: Add auxiliary attention modules.
            temporal_selfatt_only: Restrict attention to temporal-only (no spatial) if True.
            image_cross_attention: Enable cross-attention with image embeddings.
            cross_attention_scale_learnable: Make cross-attention scaling a learnable parameter.
            default_fs: Default frame-stride / fps.
            fs_condition: If True, condition on frame-stride/fps features.
            n_obs_steps: Number of observed steps used in conditioning heads.
            num_stem_token: Number of stem tokens for action tokenization.
            unet_head_config: OmegaConf for UNet heads (e.g., action/state heads).
            stem_process_config: OmegaConf for stem/preprocessor module.
            base_model_gen_only: Perform the generation using the base model without action and state outputs.
        """

        super(WMAModel, self).__init__()
        # Exactly one of num_heads / num_head_channels may be left at -1; the
        # other is then derived per-stage below.
        if num_heads == -1:
            assert num_head_channels != -1, 'Either num_heads or num_head_channels has to be set'
        if num_head_channels == -1:
            assert num_heads != -1, 'Either num_heads or num_head_channels has to be set'

        self.in_channels = in_channels
        self.model_channels = model_channels
        self.out_channels = out_channels
        self.num_res_blocks = num_res_blocks
        self.attention_resolutions = attention_resolutions
        self.dropout = dropout
        self.channel_mult = channel_mult
        self.conv_resample = conv_resample
        self.temporal_attention = temporal_attention
        time_embed_dim = model_channels * 4
        self.use_checkpoint = use_checkpoint
        self.dtype = torch.float16 if use_fp16 else torch.float32
        # NOTE(review): local flag shadows the `temporal_selfatt_only`
        # parameter and is always True for the in-loop temporal transformers;
        # the parameter itself is only used by `init_attn` below.
        temporal_self_att_only = True
        self.addition_attention = addition_attention
        self.temporal_length = temporal_length
        self.image_cross_attention = image_cross_attention
        self.cross_attention_scale_learnable = cross_attention_scale_learnable
        self.default_fs = default_fs
        self.fs_condition = fs_condition
        self.n_obs_steps = n_obs_steps
        self.num_stem_token = num_stem_token
        self.base_model_gen_only = base_model_gen_only

        # Time embedding blocks
        self.time_embed = nn.Sequential(
            linear(model_channels, time_embed_dim),
            nn.SiLU(),
            linear(time_embed_dim, time_embed_dim),
        )
        if fs_condition:
            self.fps_embedding = nn.Sequential(
                linear(model_channels, time_embed_dim),
                nn.SiLU(),
                linear(time_embed_dim, time_embed_dim),
            )
            # Zero-init so fps conditioning starts as a no-op.
            nn.init.zeros_(self.fps_embedding[-1].weight)
            nn.init.zeros_(self.fps_embedding[-1].bias)
        # Input Block
        self.input_blocks = nn.ModuleList([
            TimestepEmbedSequential(
                conv_nd(dims, in_channels, model_channels, 3, padding=1))
        ])
        if self.addition_attention:
            # Extra temporal attention applied right after the stem conv.
            self.init_attn = TimestepEmbedSequential(
                TemporalTransformer(model_channels,
                                    n_heads=8,
                                    d_head=num_head_channels,
                                    depth=transformer_depth,
                                    context_dim=context_dim,
                                    use_checkpoint=use_checkpoint,
                                    only_self_att=temporal_selfatt_only,
                                    causal_attention=False,
                                    relative_position=use_relative_position,
                                    temporal_length=temporal_length))

        # Encoder bookkeeping: channel count per skip connection and the
        # current downsampling factor `ds`.
        input_block_chans = [model_channels]
        ch = model_channels
        ds = 1
        for level, mult in enumerate(channel_mult):
            for _ in range(num_res_blocks):
                layers = [
                    ResBlock(ch,
                             time_embed_dim,
                             dropout,
                             out_channels=mult * model_channels,
                             dims=dims,
                             use_checkpoint=use_checkpoint,
                             use_scale_shift_norm=use_scale_shift_norm,
                             tempspatial_aware=tempspatial_aware,
                             use_temporal_conv=temporal_conv)
                ]
                ch = mult * model_channels
                if ds in attention_resolutions:
                    # Derive per-stage head count / head dim from whichever
                    # of the two knobs was specified.
                    if num_head_channels == -1:
                        dim_head = ch // num_heads
                    else:
                        num_heads = ch // num_head_channels
                        dim_head = num_head_channels

                    layers.append(
                        SpatialTransformer(
                            ch,
                            num_heads,
                            dim_head,
                            depth=transformer_depth,
                            context_dim=context_dim,
                            use_linear=use_linear,
                            use_checkpoint=use_checkpoint,
                            disable_self_attn=False,
                            video_length=temporal_length,
                            agent_state_context_len=self.n_obs_steps,
                            agent_action_context_len=self.temporal_length *
                            num_stem_token,
                            image_cross_attention=self.image_cross_attention,
                            cross_attention_scale_learnable=self.
                            cross_attention_scale_learnable,
                        ))
                    if self.temporal_attention:
                        layers.append(
                            TemporalTransformer(
                                ch,
                                num_heads,
                                dim_head,
                                depth=transformer_depth,
                                context_dim=context_dim,
                                use_linear=use_linear,
                                use_checkpoint=use_checkpoint,
                                only_self_att=temporal_self_att_only,
                                causal_attention=use_causal_attention,
                                relative_position=use_relative_position,
                                temporal_length=temporal_length))
                self.input_blocks.append(TimestepEmbedSequential(*layers))
                input_block_chans.append(ch)
            if level != len(channel_mult) - 1:
                # Downsample between levels (except after the last level).
                out_ch = ch
                self.input_blocks.append(
                    TimestepEmbedSequential(
                        ResBlock(ch,
                                 time_embed_dim,
                                 dropout,
                                 out_channels=out_ch,
                                 dims=dims,
                                 use_checkpoint=use_checkpoint,
                                 use_scale_shift_norm=use_scale_shift_norm,
                                 down=True)
                        if resblock_updown else Downsample(
                            ch, conv_resample, dims=dims, out_channels=out_ch))
                )
                ch = out_ch
                input_block_chans.append(ch)
                ds *= 2

        if num_head_channels == -1:
            dim_head = ch // num_heads
        else:
            num_heads = ch // num_head_channels
            dim_head = num_head_channels
        layers = [
            ResBlock(ch,
                     time_embed_dim,
                     dropout,
                     dims=dims,
                     use_checkpoint=use_checkpoint,
                     use_scale_shift_norm=use_scale_shift_norm,
                     tempspatial_aware=tempspatial_aware,
                     use_temporal_conv=temporal_conv),
            SpatialTransformer(
                ch,
                num_heads,
                dim_head,
                depth=transformer_depth,
                context_dim=context_dim,
                use_linear=use_linear,
                use_checkpoint=use_checkpoint,
                disable_self_attn=False,
                video_length=temporal_length,
                agent_state_context_len=self.n_obs_steps,
                agent_action_context_len=self.temporal_length * num_stem_token,
                image_cross_attention=self.image_cross_attention,
                cross_attention_scale_learnable=self.
                cross_attention_scale_learnable)
        ]
        if self.temporal_attention:
            layers.append(
                TemporalTransformer(ch,
                                    num_heads,
                                    dim_head,
                                    depth=transformer_depth,
                                    context_dim=context_dim,
                                    use_linear=use_linear,
                                    use_checkpoint=use_checkpoint,
                                    only_self_att=temporal_self_att_only,
                                    causal_attention=use_causal_attention,
                                    relative_position=use_relative_position,
                                    temporal_length=temporal_length))
        layers.append(
            ResBlock(ch,
                     time_embed_dim,
                     dropout,
                     dims=dims,
                     use_checkpoint=use_checkpoint,
                     use_scale_shift_norm=use_scale_shift_norm,
                     tempspatial_aware=tempspatial_aware,
                     use_temporal_conv=temporal_conv))

        # Middle Block
        self.middle_block = TimestepEmbedSequential(*layers)

        # Output Block
        self.output_blocks = nn.ModuleList([])
        for level, mult in list(enumerate(channel_mult))[::-1]:
            for i in range(num_res_blocks + 1):
                # Each decoder block consumes one encoder skip connection.
                ich = input_block_chans.pop()
                layers = [
                    ResBlock(ch + ich,
                             time_embed_dim,
                             dropout,
                             out_channels=mult * model_channels,
                             dims=dims,
                             use_checkpoint=use_checkpoint,
                             use_scale_shift_norm=use_scale_shift_norm,
                             tempspatial_aware=tempspatial_aware,
                             use_temporal_conv=temporal_conv)
                ]
                ch = model_channels * mult
                if ds in attention_resolutions:
                    if num_head_channels == -1:
                        dim_head = ch // num_heads
                    else:
                        num_heads = ch // num_head_channels
                        dim_head = num_head_channels
                    layers.append(
                        SpatialTransformer(
                            ch,
                            num_heads,
                            dim_head,
                            depth=transformer_depth,
                            context_dim=context_dim,
                            use_linear=use_linear,
                            use_checkpoint=use_checkpoint,
                            disable_self_attn=False,
                            video_length=temporal_length,
                            agent_state_context_len=self.n_obs_steps,
                            image_cross_attention=self.image_cross_attention,
                            cross_attention_scale_learnable=self.
                            cross_attention_scale_learnable))
                    if self.temporal_attention:
                        layers.append(
                            TemporalTransformer(
                                ch,
                                num_heads,
                                dim_head,
                                depth=transformer_depth,
                                context_dim=context_dim,
                                use_linear=use_linear,
                                use_checkpoint=use_checkpoint,
                                only_self_att=temporal_self_att_only,
                                causal_attention=use_causal_attention,
                                relative_position=use_relative_position,
                                temporal_length=temporal_length))
                if level and i == num_res_blocks:
                    # Upsample at the end of each level (except the top).
                    out_ch = ch
                    layers.append(
                        ResBlock(ch,
                                 time_embed_dim,
                                 dropout,
                                 out_channels=out_ch,
                                 dims=dims,
                                 use_checkpoint=use_checkpoint,
                                 use_scale_shift_norm=use_scale_shift_norm,
                                 up=True)
                        if resblock_updown else Upsample(
                            ch, conv_resample, dims=dims, out_channels=out_ch))
                    ds //= 2
                self.output_blocks.append(TimestepEmbedSequential(*layers))

        self.out = nn.Sequential(
            normalization(ch),
            nn.SiLU(),
            zero_module(
                conv_nd(dims, model_channels, out_channels, 3, padding=1)),
        )

        # Action and state prediction unet
        # The heads consume multi-scale backbone features, one context dim
        # per resolution level.
        unet_head_config['params']['context_dims'] = [
            mult * model_channels for mult in channel_mult
        ]
        self.action_unet = instantiate_from_config(unet_head_config)
        self.state_unet = instantiate_from_config(unet_head_config)

        # Initialize action token_projector
        self.action_token_projector = instantiate_from_config(
            stem_process_config)

    def forward(self,
                x: Tensor,
                x_action: Tensor,
                x_state: Tensor,
                timesteps: Tensor,
                context: Tensor | None = None,
                context_action: Tensor | None = None,
                features_adapter: Any = None,
                fs: Tensor | None = None,
                **kwargs) -> Tensor | tuple[Tensor, ...]:

        """
        Forward pass of the World-Model-Action backbone.

        Args:
            x: Input tensor (latent video), shape (B, C, T, H, W).
            x_action: action stream input.
            x_state: state stream input.
            timesteps: Diffusion timesteps, shape (B,) or scalar Tensor.
            context: conditioning context for cross-attention; its length
                along dim 1 selects how it is split below.
            context_action: conditioning context specific to action/state (implementation-specific).
            features_adapter: module or dict to adapt intermediate features.
            fs: frame-stride / fps conditioning.

        Returns:
            Tuple of Tensors (y, a_y, s_y): denoised video latent, action
            prediction, and state prediction (zeros for the latter two when
            ``base_model_gen_only``).
        """

        b, _, t, _, _ = x.shape
        t_emb = timestep_embedding(timesteps,
                                   self.model_channels,
                                   repeat_only=False).type(x.dtype)
        emb = self.time_embed(t_emb)

        # Split the packed context along dim 1 into (state, [action], text,
        # image) segments and re-pack them per frame.
        # NOTE(review): 77 matches the CLIP text token count and 16 appears to
        # be the per-frame image/action token count — hard-coded; confirm
        # against the conditioning pipeline before changing.
        bt, l_context, _ = context.shape
        if self.base_model_gen_only:
            assert l_context == 77 + self.n_obs_steps * 16, ">>> ERROR Context dim 1 ..." ## NOTE HANDCODE
        else:
            if l_context == self.n_obs_steps + 77 + t * 16:
                # Layout: [agent state | text (77) | per-frame image tokens].
                context_agent_state = context[:, :self.n_obs_steps]
                context_text = context[:, self.n_obs_steps:self.n_obs_steps +
                                       77, :]
                context_img = context[:, self.n_obs_steps + 77:, :]
                # State/text are shared across frames -> repeat per frame;
                # image tokens are already per-frame -> unfold the time axis.
                context_agent_state = context_agent_state.repeat_interleave(
                    repeats=t, dim=0)
                context_text = context_text.repeat_interleave(repeats=t, dim=0)
                context_img = rearrange(context_img,
                                        'b (t l) c -> (b t) l c',
                                        t=t)
                context = torch.cat(
                    [context_agent_state, context_text, context_img], dim=1)
            elif l_context == self.n_obs_steps + 16 + 77 + t * 16:
                # Layout: [agent state | action tokens (16) | text (77) |
                # per-frame image tokens].
                context_agent_state = context[:, :self.n_obs_steps]
                context_agent_action = context[:, self.
                                               n_obs_steps:self.n_obs_steps +
                                               16, :]
                # Project raw action tokens through the stem, then regroup so
                # every frame attends to the full set of action tokens.
                context_agent_action = rearrange(
                    context_agent_action.unsqueeze(2), 'b t l d -> (b t) l d')
                context_agent_action = self.action_token_projector(
                    context_agent_action)
                context_agent_action = rearrange(context_agent_action,
                                                 '(b o) l d -> b o l d',
                                                 o=t)
                context_agent_action = rearrange(context_agent_action,
                                                 'b o (t l) d -> b o t l d',
                                                 t=t)
                context_agent_action = context_agent_action.permute(
                    0, 2, 1, 3, 4)
                context_agent_action = rearrange(context_agent_action,
                                                 'b t o l d -> (b t) (o l) d')

                context_text = context[:, self.n_obs_steps +
                                       16:self.n_obs_steps + 16 + 77, :]
                context_text = context_text.repeat_interleave(repeats=t, dim=0)

                context_img = context[:, self.n_obs_steps + 16 + 77:, :]
                context_img = rearrange(context_img,
                                        'b (t l) c -> (b t) l c',
                                        t=t)
                context_agent_state = context_agent_state.repeat_interleave(
                    repeats=t, dim=0)
                context = torch.cat([
                    context_agent_state, context_agent_action, context_text,
                    context_img
                ],
                                    dim=1)

        # Per-frame processing: flatten time into the batch axis and repeat
        # the timestep embedding accordingly.
        emb = emb.repeat_interleave(repeats=t, dim=0)

        x = rearrange(x, 'b c t h w -> (b t) c h w')

        # Combine emb
        if self.fs_condition:
            if fs is None:
                fs = torch.tensor([self.default_fs] * b,
                                  dtype=torch.long,
                                  device=x.device)
            fs_emb = timestep_embedding(fs,
                                        self.model_channels,
                                        repeat_only=False).type(x.dtype)

            fs_embed = self.fps_embedding(fs_emb)
            fs_embed = fs_embed.repeat_interleave(repeats=t, dim=0)
            emb = emb + fs_embed

        h = x.type(self.dtype)
        adapter_idx = 0
        hs = []      # encoder skip connections (one per input block)
        hs_a = []    # multi-scale features collected for the action/state heads
        for id, module in enumerate(self.input_blocks):
            # NOTE(review): loop variable `id` shadows the builtin.
            h = module(h, emb, context=context, batch_size=b)
            if id == 0 and self.addition_attention:
                h = self.init_attn(h, emb, context=context, batch_size=b)
            # plug-in adapter features
            if ((id + 1) % 3 == 0) and features_adapter is not None:
                h = h + features_adapter[adapter_idx]
                adapter_idx += 1
            if id != 0:
                # Snapshot the feature map right before each downsample.
                if isinstance(module[0], Downsample):
                    hs_a.append(
                        rearrange(hs[-1], '(b t) c h w -> b t c h w', t=t))
            hs.append(h)
        hs_a.append(rearrange(h, '(b t) c h w -> b t c h w', t=t))

        if features_adapter is not None:
            assert len(
                features_adapter) == adapter_idx, 'Wrong features_adapter'
        h = self.middle_block(h, emb, context=context, batch_size=b)
        hs_a.append(rearrange(h, '(b t) c h w -> b t c h w', t=t))

        hs_out = []
        for module in self.output_blocks:
            h = torch.cat([h, hs.pop()], dim=1)
            h = module(h, emb, context=context, batch_size=b)
            # Snapshot the feature map right before each upsample (previous
            # iteration's output).
            if isinstance(module[-1], Upsample):
                hs_a.append(
                    rearrange(hs_out[-1], '(b t) c h w -> b t c h w', t=t))
            hs_out.append(h)
        h = h.type(x.dtype)
        hs_a.append(rearrange(hs_out[-1], '(b t) c h w -> b t c h w', t=t))

        y = self.out(h)
        y = rearrange(y, '(b t) c h w -> b c t h w', b=b)

        if not self.base_model_gen_only:
            # NOTE(review): `timesteps[:ba]` and `context_action[:2]` slice
            # the conditioning to the action batch size / first two entries —
            # looks intentional but hard-coded; confirm against the head API.
            ba, _, _ = x_action.shape
            a_y = self.action_unet(x_action, timesteps[:ba], hs_a,
                                   context_action[:2], **kwargs)
            # Predict state
            if b > 1:
                s_y = self.state_unet(x_state, timesteps[:ba], hs_a,
                                      context_action[:2], **kwargs)
            else:
                s_y = self.state_unet(x_state, timesteps, hs_a,
                                      context_action[:2], **kwargs)
        else:
            # Generation-only mode: emit zero placeholders for the heads.
            a_y = torch.zeros_like(x_action)
            s_y = torch.zeros_like(x_state)

        return y, a_y, s_y
|
||||
0
src/unifolm_wma/modules/vision/__init__.py
Normal file
0
src/unifolm_wma/modules/vision/__init__.py
Normal file
244
src/unifolm_wma/modules/vision/base_vision.py
Normal file
244
src/unifolm_wma/modules/vision/base_vision.py
Normal file
@@ -0,0 +1,244 @@
|
||||
"""
|
||||
base_vision.py
|
||||
|
||||
Abstract class definition of a Vision Backbone (Visual Featurizer), with full annotations of class methods, utility
|
||||
functions, and initialization logic.
|
||||
|
||||
We also define the generic TimmViTBackbone class here, providing a default interface for loading any TIMM Vision
|
||||
Transformer model for feature extraction.
|
||||
"""
|
||||
|
||||
import timm
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torchvision.transforms.functional as TVF
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from dataclasses import dataclass
|
||||
from functools import partial
|
||||
from typing import Any, Callable, Dict, Optional, Protocol, Tuple, Union
|
||||
|
||||
from PIL.Image import Image
|
||||
from timm.models.vision_transformer import Block, VisionTransformer
|
||||
from torch.distributed.fsdp.wrap import _module_wrap_policy, _or_policy, transformer_auto_wrap_policy
|
||||
from torchvision.transforms import Compose, Resize
|
||||
|
||||
|
||||
# === Utility Functions for Monkey-Patching ===
|
||||
def unpack_tuple(fn: Callable[[Any], Tuple[Any]]) -> Callable[[Any], Any]:
    """Wrap *fn* so that a tuple return value is collapsed to its first element.

    Non-tuple results are passed through unchanged.
    """

    def wrapper(*args: Any, **kwargs: Any) -> Any:
        out = fn(*args, **kwargs)
        if isinstance(out, tuple):
            return out[0]
        return out

    return wrapper
|
||||
|
||||
|
||||
# === Interface for an Image Transform ===
|
||||
class ImageTransform(Protocol):
    """Structural (duck-typed) interface for an image preprocessing callable.

    Implementations take a PIL image and return either a single tensor or a
    dict of named tensors (e.g. one per sub-backbone).
    """

    def __call__(
            self, img: Image,
            **kwargs: str) -> Union[torch.Tensor, Dict[str, torch.Tensor]]:
        ...
|
||||
|
||||
|
||||
# === Custom Torchvision Image Transforms ===
|
||||
@dataclass
class LetterboxPad:
    """Pad an image to a square with a constant fill color (letterboxing)."""

    # RGB fill color used for the added border.
    padding_fill_value: Tuple[int, int, int]

    def __call__(self, image: Image) -> Image:
        """Given a PIL.Image, pad to square by adding a symmetric border around the height/width."""
        w, h = image.size
        longest = max(image.size)
        pad_w = int((longest - w) / 2)
        pad_h = int((longest - h) / 2)
        # (left, top, right, bottom) — symmetric, so width/height each grow by
        # 2*pad (odd differences round down, leaving a 1px asymmetry).
        return TVF.pad(image, (pad_w, pad_h, pad_w, pad_h),
                       fill=self.padding_fill_value,
                       padding_mode="constant")
|
||||
|
||||
|
||||
# === Abstract Base Class for arbitrary Vision Backbones ===
|
||||
class VisionBackbone(nn.Module, ABC):
    """Abstract base class for a visual featurizer.

    Concrete subclasses must populate ``self.featurizer`` and
    ``self.image_transform`` and implement the abstract accessors below.
    """

    def __init__(self,
                 vision_backbone_id: str,
                 image_resize_strategy: str,
                 default_image_size: int = 224) -> None:
        """
        Args:
            vision_backbone_id: Identifier naming the backbone variant.
            image_resize_strategy: How inputs are resized (subclass-defined;
                e.g. "resize-naive" / "resize-crop" / "letterbox").
            default_image_size: Square input resolution in pixels.
        """
        super().__init__()
        self.identifier: str = vision_backbone_id
        self.image_resize_strategy: str = image_resize_strategy
        self.default_image_size: int = default_image_size

        # Instance attributes for a Vision Backbone — set by subclasses
        # (None until then).
        self.featurizer: nn.Module = None
        self.image_transform: ImageTransform = None

    def get_image_transform(self) -> ImageTransform:
        """Return the preprocessing transform matching this backbone."""
        return self.image_transform

    @abstractmethod
    def get_fsdp_wrapping_policy(self) -> Callable:
        """Return the FSDP auto-wrap policy for this backbone."""
        ...

    @abstractmethod
    def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
        """Run a forward pass through the featurizer given a set of processed images, returning patch/grid features."""
        raise NotImplementedError

    @property
    @abstractmethod
    def default_image_resolution(self) -> Tuple[int, int, int]:
        """Expected input resolution as (channels, height, width)."""
        ...

    @property
    @abstractmethod
    def embed_dim(self) -> int:
        """Dimensionality of the returned patch features."""
        ...

    @property
    @abstractmethod
    def num_patches(self) -> int:
        """Number of patch tokens produced per image."""
        ...

    @property
    @abstractmethod
    def half_precision_dtype(self) -> torch.dtype:
        """Dtype to use when running the backbone in half precision."""
        ...
|
||||
|
||||
|
||||
# === Abstract Base Class for Arbitrary TIMM Vision Transformer Backbones ===
|
||||
class TimmViTBackbone(VisionBackbone, ABC):
    """Generic VisionBackbone backed by a TIMM Vision Transformer.

    Loads the ViT from TIMM, patches its ``forward`` to return
    second-to-last-layer patch features, and builds an image transform
    according to ``image_resize_strategy``.
    """

    def __init__(
        self,
        vision_backbone_id: str,
        timm_path_or_url: str,
        image_resize_strategy: str,
        default_image_size: int = 224,
        override_act_layer: Optional[str] = None,
    ) -> None:
        """
        Args:
            vision_backbone_id: Identifier naming the backbone variant.
            timm_path_or_url: TIMM model name (or HF hub path) to load.
            image_resize_strategy: "resize-naive" | "resize-crop" | "letterbox".
            default_image_size: Square input resolution in pixels.
            override_act_layer: Optional activation-layer override passed to
                ``timm.create_model``.
        """
        super().__init__(vision_backbone_id,
                         image_resize_strategy,
                         default_image_size=default_image_size)
        self.timm_path_or_url = timm_path_or_url
        self.override_act_layer = override_act_layer
        self.dtype = torch.bfloat16

        # Initialize Featurizer (ViT) by downloading from HF / TIMM Hub if necessary
        if self.override_act_layer is None:
            self.featurizer: VisionTransformer = timm.create_model(
                self.timm_path_or_url,
                pretrained=True,
                num_classes=0,
                img_size=self.default_image_size)
        else:
            self.featurizer: VisionTransformer = timm.create_model(
                self.timm_path_or_url,
                pretrained=True,
                num_classes=0,
                img_size=self.default_image_size,
                act_layer=self.override_act_layer,
            )
        # Inference-only featurizer (no dropout/droppath updates).
        self.featurizer.eval()

        # Monkey-Patch the `forward()` function of the featurizer to ensure FSDP-compatibility
        # => Note: By default set `get_intermediate_layers` to return the *SECOND-TO-LAST* layer patches!
        # => TODO (siddk) Remove after resolution of https://github.com/pytorch/pytorch/issues/109385
        self.featurizer.forward = unpack_tuple(
            partial(self.featurizer.get_intermediate_layers,
                    n={len(self.featurizer.blocks) - 2}))

        # Validation =>> for now, this class *only* supports TIMM Vision Transformers (but can be extended!)
        assert isinstance(self.featurizer, VisionTransformer), (
            "Featurizer is not a TIMM VisionTransformer; if you would like to support a new visual representation, "
            "file an issue or implement the requisite logic (see `prismatic/models/backbones/vision/base_vision.py`)!"
        )

        # Get Config =>> Note :: Override default image size to ensure correct image transform
        self.data_cfg = timm.data.resolve_model_data_config(self.featurizer)
        self.data_cfg["input_size"] = (3, self.default_image_size,
                                       self.default_image_size)

        # Initialize Default Image Transform --> Modified by `self.image_resize_strategy`
        default_image_transform = timm.data.create_transform(**self.data_cfg,
                                                             is_training=False)

        # Fix =>> SigLIP & IN1K default transforms resize to *larger* than `self.default_image_size` (crops image)!
        if "siglip" in self.timm_path_or_url or "in1k" in self.timm_path_or_url:
            assert isinstance(default_image_transform,
                              Compose), "Unexpected `default_image_transform`!"
            assert isinstance(default_image_transform.transforms[0], Resize)
            default_image_transform = Compose([
                Resize(self.default_image_size,
                       interpolation=default_image_transform.transforms[0].
                       interpolation),
                *default_image_transform.transforms[1:],
            ])

        # Switch on `image_resize_strategy`
        if self.image_resize_strategy == "resize-naive":
            # Naive resize straight to (size, size), ignoring aspect ratio.
            assert isinstance(default_image_transform,
                              Compose), "Unexpected `default_image_transform`!"
            assert isinstance(default_image_transform.transforms[0], Resize)

            target_size = (self.default_image_size, self.default_image_size)
            self.image_transform = Compose([
                Resize(target_size,
                       interpolation=default_image_transform.transforms[0].
                       interpolation),
                *default_image_transform.transforms[1:],
            ])

        elif self.image_resize_strategy == "resize-crop":
            # Keep TIMM's default resize-then-center-crop pipeline.
            self.image_transform = default_image_transform

        elif self.image_resize_strategy == "letterbox":
            # Pad to square with the (rescaled) normalization mean, then apply
            # the default pipeline.
            assert isinstance(default_image_transform,
                              Compose), "Unexpected `default_image_transform`!"
            assert "mean" in self.data_cfg, "TIMM `data_cfg` missing image normalization mean!"

            # Compute Padding Fill Value (rescaled normalization mean if applicable)
            fill = tuple([int(x * 255) for x in self.data_cfg["mean"]])

            # Build New Transform
            self.image_transform = Compose(
                [LetterboxPad(fill), *default_image_transform.transforms])

        else:
            raise ValueError(
                f"Image Resize Strategy `{self.image_resize_strategy}` is not supported!"
            )

    def get_fsdp_wrapping_policy(self) -> Callable:
        """Return a simple FSDP policy that wraps each ViT block and then the _entire_ featurizer."""
        vit_wrap_policy = partial(_module_wrap_policy,
                                  module_classes={VisionTransformer})
        transformer_block_policy = partial(transformer_auto_wrap_policy,
                                           transformer_layer_cls={Block})
        return partial(_or_policy,
                       policies=[vit_wrap_policy, transformer_block_policy])

    def forward(
        self, pixel_values: Union[torch.Tensor, Dict[str, torch.Tensor]]
    ) -> torch.Tensor:
        """Runs transformed image/pixel tensor through vision backbone, returning _all_ patch features."""
        return self.featurizer(pixel_values)

    @property
    def default_image_resolution(self) -> Tuple[int, int, int]:
        # (channels, height, width) as overridden above.
        return self.data_cfg["input_size"]

    @property
    def embed_dim(self) -> int:
        return self.featurizer.embed_dim

    @property
    def num_patches(self) -> int:
        return self.featurizer.patch_embed.num_patches

    @property
    def half_precision_dtype(self) -> torch.dtype:
        return self.dtype
|
||||
273
src/unifolm_wma/modules/vision/dinosiglip_vit.py
Normal file
273
src/unifolm_wma/modules/vision/dinosiglip_vit.py
Normal file
@@ -0,0 +1,273 @@
|
||||
"""
|
||||
dinosiglip_vit.py
|
||||
|
||||
Vision backbone that returns concatenated features from both DINOv2 and SigLIP.
|
||||
"""
|
||||
|
||||
import timm
|
||||
import torch
|
||||
import torchvision.transforms as transforms
|
||||
|
||||
from dataclasses import dataclass
|
||||
from functools import partial
|
||||
from typing import Callable, Dict, Tuple
|
||||
from PIL import Image
|
||||
from timm.models.vision_transformer import Block, VisionTransformer
|
||||
from torch.distributed.fsdp.wrap import _module_wrap_policy, _or_policy, transformer_auto_wrap_policy
|
||||
from torchvision.transforms import Compose, Resize, Normalize
|
||||
|
||||
from unifolm_wma.modules.vision.base_vision import ImageTransform, LetterboxPad, VisionBackbone, unpack_tuple
|
||||
from unifolm_wma.utils.nn_utils import FusedMLPProjector, LinearProjector, MLPProjector
|
||||
|
||||
# Registry =>> Supported DinoSigLIP Pairs (as TIMM identifiers)
# Maps a backbone id (the `vision_backbone_id` passed to `DinoSigLIPViTBackbone`)
# to the TIMM model names of its DINOv2 and SigLIP featurizers.
DINOSigLIP_VISION_BACKBONES = {
    "dinosiglip-vit-so-224px": {
        "dino": "vit_large_patch14_reg4_dinov2.lvd142m",
        "siglip": "vit_so400m_patch14_siglip_224",
    },
    "dinosiglip-vit-so-384px": {
        "dino": "vit_large_patch14_reg4_dinov2.lvd142m",
        "siglip": "vit_so400m_patch14_siglip_384",
    },
}
|
||||
|
||||
|
||||
@dataclass
class DinoSigLIPImageTransform:
    """Paired image transform that applies both the DINO and SigLIP pipelines to one image."""

    dino_image_transform: ImageTransform
    siglip_image_transform: ImageTransform
    # Presumably flags this as a prismatic-style dual transform for downstream checks — TODO confirm.
    is_prismatic: bool = True

    def __call__(self, img: Image, **kwargs: str) -> Dict[str, torch.Tensor]:
        """Apply both transforms to `img`, returning the results keyed by backbone name."""
        dino_pixels = self.dino_image_transform(img, **kwargs)
        siglip_pixels = self.siglip_image_transform(img, **kwargs)
        return {"dino": dino_pixels, "siglip": siglip_pixels}
|
||||
|
||||
|
||||
class DinoSigLIPViTBackbone(VisionBackbone):
    """Fused vision backbone: runs images through both a DINOv2 ViT and a SigLIP ViT,
    concatenates their patch features along the channel dimension, and projects the
    result to `output_dim`.

    Args:
        vision_backbone_id: Key into `DINOSigLIP_VISION_BACKBONES` selecting the TIMM model pair.
        image_resize_strategy: One of {"resize-naive", "resize-crop", "letterbox"}.
        arch_specifier: Projector architecture ("linear", "*fused-gelu-mlp", "*gelu-mlp").
        output_dim: Output feature dimension of the projector.
        pretrained_checkpoint: Optional directory containing `openvla_dino.pt` /
            `openvla_siglip.pt` state dicts to load into the featurizers.
        freeze: If True, puts both featurizers in eval mode and disables their gradients.
        default_image_size: Square input resolution fed to both ViTs.
    """

    def __init__(self,
                 vision_backbone_id: str,
                 image_resize_strategy: str,
                 arch_specifier: str,
                 output_dim: int,
                 pretrained_checkpoint=None,
                 freeze=True,
                 default_image_size: int = 224) -> None:
        super().__init__(vision_backbone_id,
                         image_resize_strategy,
                         default_image_size=default_image_size)
        self.dino_timm_path_or_url = DINOSigLIP_VISION_BACKBONES[
            vision_backbone_id]["dino"]
        self.siglip_timm_path_or_url = DINOSigLIP_VISION_BACKBONES[
            vision_backbone_id]["siglip"]

        # Initialize both Featurizers (ViTs) by downloading from HF / TIMM Hub if necessary
        self.dino_featurizer: VisionTransformer = timm.create_model(
            self.dino_timm_path_or_url,
            pretrained=True,
            num_classes=0,
            img_size=self.default_image_size)
        if pretrained_checkpoint:
            ckpt = pretrained_checkpoint + '/openvla_dino.pt'
            self.dino_featurizer.load_state_dict(
                torch.load(ckpt, weights_only=True))
            print('>>> load dino weights')
        if freeze:
            self.dino_featurizer.eval()
            for param in self.dino_featurizer.parameters():
                param.requires_grad = False

        self.siglip_featurizer: VisionTransformer = timm.create_model(
            self.siglip_timm_path_or_url,
            pretrained=True,
            num_classes=0,
            img_size=self.default_image_size)
        if pretrained_checkpoint:
            ckpt = pretrained_checkpoint + '/openvla_siglip.pt'
            self.siglip_featurizer.load_state_dict(
                torch.load(ckpt, weights_only=True))
            print('>>> load siglip weights')
        if freeze:
            self.siglip_featurizer.eval()
            for param in self.siglip_featurizer.parameters():
                param.requires_grad = False

        # Monkey-Patch the `forward()` function of the featurizers to ensure FSDP-compatibility
        # => Note: By default set `get_intermediate_layers` to return the *SECOND-TO-LAST* layer patches!
        # => TODO (siddk) Remove after resolution of https://github.com/pytorch/pytorch/issues/109385
        self.dino_featurizer.forward = unpack_tuple(
            partial(self.dino_featurizer.get_intermediate_layers,
                    n={len(self.dino_featurizer.blocks) - 2}))
        self.siglip_featurizer.forward = unpack_tuple(
            partial(self.siglip_featurizer.get_intermediate_layers,
                    n={len(self.siglip_featurizer.blocks) - 2}))

        # Get Configs for _both_ Featurizers =>> Note :: Override default image size for larger resolution models
        self.dino_data_cfg = timm.data.resolve_model_data_config(
            self.dino_featurizer)
        self.dino_data_cfg["input_size"] = (3, self.default_image_size,
                                            self.default_image_size)

        self.siglip_data_cfg = timm.data.resolve_model_data_config(
            self.siglip_featurizer)
        self.siglip_data_cfg["input_size"] = (3, self.default_image_size,
                                              self.default_image_size)

        # Initialize *both* Transforms
        self.default_dino_transform = timm.data.create_transform(
            **self.dino_data_cfg, is_training=False)
        self.default_siglip_transform = timm.data.create_transform(
            **self.siglip_data_cfg, is_training=False)

        # Fix =>> SigLIP default transform resizes to *larger* than `self.default_image_size` (crops image)!!
        assert isinstance(self.default_siglip_transform,
                          Compose), "Unexpected `default_image_transform`!"
        assert isinstance(self.default_siglip_transform.transforms[0], Resize)
        self.default_siglip_transform = Compose([
            Resize(self.default_image_size,
                   interpolation=self.default_siglip_transform.transforms[0].
                   interpolation),
            *self.default_siglip_transform.transforms[1:],
        ])

        # Bugfix: `forward()` consumes `self.target_size` unconditionally, but it was
        # previously assigned only inside the "resize-naive" branch — the other resize
        # strategies crashed with AttributeError at inference time. Set it for all
        # strategies up front.
        self.target_size = (self.default_image_size, self.default_image_size)

        if self.image_resize_strategy == "resize-naive":
            assert isinstance(
                self.default_dino_transform,
                Compose), "Unexpected `default_dino_image_transform`!"
            assert isinstance(
                self.default_siglip_transform,
                Compose), "Unexpected `default_siglip_image_transform`!"
            assert isinstance(self.default_dino_transform.transforms[0],
                              Resize)
            assert isinstance(self.default_siglip_transform.transforms[0],
                              Resize)

            # Replace each pipeline's leading Resize with a fixed square resize,
            # keeping the original interpolation mode.
            dino_transform = Compose([
                Resize(self.target_size,
                       interpolation=self.default_dino_transform.transforms[0].
                       interpolation),
                *self.default_dino_transform.transforms[1:],
            ])
            siglip_transform = Compose([
                Resize(self.target_size,
                       interpolation=self.default_siglip_transform.
                       transforms[0].interpolation),
                *self.default_siglip_transform.transforms[1:],
            ])

            self.image_transform = DinoSigLIPImageTransform(
                dino_transform, siglip_transform)

        elif self.image_resize_strategy == "resize-crop":
            # Keep the default timm pipelines (resize-then-center-crop).
            self.image_transform = DinoSigLIPImageTransform(
                self.default_dino_transform, self.default_siglip_transform)

        elif self.image_resize_strategy == "letterbox":
            assert isinstance(self.default_dino_transform,
                              Compose), "Unexpected `default_dino_transform`!"
            assert isinstance(
                self.default_siglip_transform,
                Compose), "Unexpected `default_siglip_transform`!"
            assert ("mean" in self.dino_data_cfg
                    and "mean" in self.siglip_data_cfg
                    ), "DinoSigLIP `data_cfg` missing `mean`!"

            # Compute Padding Fill Value(s) (rescaled normalization mean if applicable)
            dino_fill = tuple(
                int(x * 255) for x in self.dino_data_cfg["mean"])
            siglip_fill = tuple(
                int(x * 255) for x in self.siglip_data_cfg["mean"])

            # Build New Transform =>> pad to square before the default pipelines.
            self.image_transform = DinoSigLIPImageTransform(
                Compose([
                    LetterboxPad(dino_fill),
                    *self.default_dino_transform.transforms
                ]),
                Compose([
                    LetterboxPad(siglip_fill),
                    *self.default_siglip_transform.transforms
                ]),
            )

        else:
            raise ValueError(
                f"Image Resize Strategy `{self.image_resize_strategy}` is not supported!"
            )

        # Projector mapping concatenated (dino + siglip) features to `output_dim`.
        self.arch_specifier = arch_specifier
        if arch_specifier == "linear":
            self.projector = LinearProjector(self.embed_dim, output_dim)
        elif arch_specifier.endswith("fused-gelu-mlp"):
            self.projector = FusedMLPProjector(self.embed_dim, output_dim)
        elif arch_specifier.endswith("gelu-mlp"):
            self.projector = MLPProjector(self.embed_dim, output_dim)
        else:
            raise ValueError(
                f"PrismaticVLM with `{arch_specifier = }` is not supported!")

        # Latch flipped by `forward()` once the featurizers are observed on GPU.
        self.on_gpu = False

    def get_fsdp_wrapping_policy(self) -> Callable:
        """Return a simple FSDP policy that wraps each ViT block and then both of the _entire_ featurizers."""
        vit_wrap_policy = partial(_module_wrap_policy,
                                  module_classes={VisionTransformer})
        transformer_block_policy = partial(transformer_auto_wrap_policy,
                                           transformer_layer_cls={Block})
        return partial(_or_policy,
                       policies=[vit_wrap_policy, transformer_block_policy])

    def forward(self, img) -> torch.Tensor:
        """Preprocess `img` for both backbones, run each featurizer, and return the
        projected concatenation of their patch features.

        `img` is assumed to be a float tensor in [-1, 1] with shape (B, 3, H, W) — TODO confirm.
        """
        # Map from [-1, 1] to [0, 255].
        # NOTE(review): Normalize below uses [0, 1]-range means/stds, so values here
        # are 255x the usual scale — verify this matches training-time preprocessing.
        img = torch.clamp(img.float(), -1., 1.)
        img = (img + 1.0) / 2.0
        img = img * 255

        # Resize shorter side to the target, then center-crop to a square.
        resize = transforms.Resize(min(self.target_size),
                                   interpolation=self.default_dino_transform.
                                   transforms[0].interpolation,
                                   max_size=None,
                                   antialias=True)
        center_crop = transforms.CenterCrop(self.target_size)
        img = center_crop(resize(img))

        # Per-backbone normalization (ImageNet stats for DINO, 0.5-centered for SigLIP).
        dino_normalizer = Normalize(mean=torch.tensor([0.4850, 0.4560,
                                                       0.4060]),
                                    std=torch.tensor([0.2290, 0.2240, 0.2250]))
        siglip_normalizer = Normalize(
            mean=torch.tensor([0.5000, 0.5000, 0.5000]),
            std=torch.tensor([0.5000, 0.5000, 0.5000]))
        pixel_values = {
            'dino': dino_normalizer(img),
            'siglip': siglip_normalizer(img)
        }

        # NOTE(review): on the first call after the model moves to GPU, the latch is
        # set but `pixel_values` are not moved this call — relies on `img` already
        # being on the featurizers' device; confirm against callers.
        if self.on_gpu:
            pixel_values = {k: v.cuda() for k, v in pixel_values.items()}
        elif next(self.dino_featurizer.parameters()).device.type != 'cpu':
            self.on_gpu = True

        # Run each backbone and concatenate patch features along the channel dim.
        dino_patches = self.dino_featurizer(pixel_values["dino"])
        siglip_patches = self.siglip_featurizer(pixel_values["siglip"])

        return self.projector(torch.cat([dino_patches, siglip_patches], dim=2))

    @property
    def default_image_resolution(self) -> Tuple[int, int, int]:
        """(C, H, W) input size expected by the backbones."""
        return self.dino_data_cfg["input_size"]

    @property
    def embed_dim(self) -> int:
        """Combined feature dimension of the concatenated DINO + SigLIP patches."""
        return self.dino_featurizer.embed_dim + self.siglip_featurizer.embed_dim

    @property
    def num_patches(self) -> int:
        """Patch-token count (both featurizers must agree)."""
        assert self.dino_featurizer.patch_embed.num_patches == self.siglip_featurizer.patch_embed.num_patches
        return self.dino_featurizer.patch_embed.num_patches

    @property
    def half_precision_dtype(self) -> torch.dtype:
        """Half-precision dtype used for this backbone."""
        return torch.bfloat16
|
||||
Reference in New Issue
Block a user