# Diffusion action head: conditional 1-D U-Net with video-latent conditioning.
import logging
from typing import Union

import einops
import torch
import torch.nn as nn
import torch.nn.functional as F  # needed by GEGLU.forward (was previously missing)
from einops import rearrange, repeat

# NOTE(review): CrossAttention.efficient_forward references `xformers`, which
# is never imported in this module; that code path requires xformers to be
# importable in the caller's environment.

from unifolm_wma.models.diffusion_head.base_nets import SpatialSoftmax
from unifolm_wma.models.diffusion_head.conv1d_components import (
    Downsample1d, Upsample1d, Conv1dBlock)
from unifolm_wma.models.diffusion_head.positional_embedding import SinusoidalPosEmb
from unifolm_wma.utils.basics import zero_module
from unifolm_wma.utils.common import (
    checkpoint,
    exists,
    default,
)
from unifolm_wma.utils.utils import instantiate_from_config
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
class GEGLU(nn.Module):
    """Gated-GELU feed-forward unit (GEGLU, https://arxiv.org/abs/2002.05202).

    Projects the input to ``2 * dim_out`` features and splits the result into
    a value half and a gate half; the output is ``value * gelu(gate)``.
    """

    def __init__(self, dim_in, dim_out):
        super().__init__()
        # One projection produces both halves (value and gate) in a single matmul.
        self.proj = nn.Linear(dim_in, dim_out * 2)

    def forward(self, x):
        x, gate = self.proj(x).chunk(2, dim=-1)
        # BUGFIX: the original body called `F.gelu`, but `torch.nn.functional`
        # was never imported as `F` in this module, so this raised NameError.
        # Use the fully-qualified functional API instead.
        return x * torch.nn.functional.gelu(gate)
class FeedForward(nn.Module):
    """Transformer MLP: Linear(+GELU) or GEGLU stem, dropout, output Linear.

    `mult` sets the hidden width relative to `dim`; `dim_out` defaults to
    `dim` when not given; `glu=True` swaps the stem for a GEGLU gate.
    """

    def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.):
        super().__init__()
        hidden = int(dim * mult)
        out_features = default(dim_out, dim)

        if glu:
            stem = GEGLU(dim, hidden)
        else:
            stem = nn.Sequential(nn.Linear(dim, hidden), nn.GELU())

        self.net = nn.Sequential(
            stem,
            nn.Dropout(dropout),
            nn.Linear(hidden, out_features),
        )

    def forward(self, x):
        return self.net(x)
class CrossAttention(nn.Module):
    """Multi-head attention; self-attention when `context` is None,
    cross-attention otherwise.

    Args:
        query_dim: feature size of the query input ``x``.
        context_dim: feature size of the context; defaults to ``query_dim``.
        heads / dim_head: number of heads and per-head width.
        dropout: dropout applied inside the output projection.
        relative_position: accepted for interface compatibility but unused
            in this implementation.
    """

    def __init__(self,
                 query_dim,
                 context_dim=None,
                 heads=8,
                 dim_head=64,
                 dropout=0.,
                 relative_position=False):
        super().__init__()
        inner_dim = dim_head * heads
        # Self-attention falls back to the query dimensionality.
        context_dim = query_dim if context_dim is None else context_dim

        self.scale = dim_head**-0.5
        self.heads = heads
        self.dim_head = dim_head

        self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
        self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
        self.to_v = nn.Linear(context_dim, inner_dim, bias=False)

        self.to_out = nn.Sequential(nn.Linear(inner_dim, query_dim),
                                    nn.Dropout(dropout))

    def forward(self, x, context=None, mask=None):
        """Standard scaled-dot-product attention.

        BUGFIX: this class previously defined no ``forward`` at all, so
        calling an instance (as BasicTransformerBlock does, with a ``mask``
        keyword) raised nn.Module's "forward not implemented" error.  This
        implementation mirrors the head split/merge used by
        ``efficient_forward`` but computes attention with plain torch ops.

        x:       (b, n, query_dim)
        context: (b, m, context_dim) or None for self-attention
        mask:    optional boolean (b, m) mask, True = attend
        """
        b, n, _ = x.shape
        h, d = self.heads, self.dim_head

        q = self.to_q(x)
        context = x if context is None else context
        k = self.to_k(context)
        v = self.to_v(context)

        # (b, t, h*d) -> (b*h, t, d), same layout as efficient_forward.
        def _split(t):
            return t.reshape(b, t.shape[1], h, d).permute(0, 2, 1, 3).reshape(
                b * h, t.shape[1], d).contiguous()

        q, k, v = _split(q), _split(k), _split(v)

        sim = torch.bmm(q, k.transpose(1, 2)) * self.scale
        if mask is not None:
            # Broadcast the key mask over heads and query positions.
            m = mask.reshape(b, -1)
            m = m.unsqueeze(1).expand(b, h, m.shape[-1]).reshape(b * h, 1, -1)
            sim = sim.masked_fill(~m, -torch.finfo(sim.dtype).max)

        attn = sim.softmax(dim=-1)
        out = torch.bmm(attn, v)
        out = out.reshape(b, h, n, d).permute(0, 2, 1, 3).reshape(b, n, h * d)
        return self.to_out(out)

    def efficient_forward(self, x, context=None):
        """xformers-backed attention path.

        NOTE(review): requires `xformers`, which this module never imports;
        using this path without it raises NameError.  `mask` is not supported
        here.
        """
        spatial_self_attn = (context is None)
        k_ip, v_ip, out_ip = None, None, None

        q = self.to_q(x)
        if spatial_self_attn:
            context = default(context, x)
        k = self.to_k(context)
        v = self.to_v(context)

        b, _, _ = q.shape
        q, k, v = map(
            lambda t: t.unsqueeze(3).reshape(b, t.shape[
                1], self.heads, self.dim_head).permute(0, 2, 1, 3).reshape(
                    b * self.heads, t.shape[1], self.dim_head).contiguous(),
            (q, k, v),
        )
        # actually compute the attention, what we cannot get enough of
        out = xformers.ops.memory_efficient_attention(q,
                                                      k,
                                                      v,
                                                      attn_bias=None,
                                                      op=None)
        out = (out.unsqueeze(0).reshape(
            b, self.heads, out.shape[1],
            self.dim_head).permute(0, 2, 1,
                                   3).reshape(b, out.shape[1],
                                              self.heads * self.dim_head))
        return self.to_out(out)
class BasicTransformerBlock(nn.Module):
    """Pre-norm transformer block: self-attention, cross-attention, MLP.

    Each sub-layer is residual:
        x = attn1(norm1(x)) + x          (self-attn; cross if disable_self_attn)
        x = attn2(norm2(x), context) + x (cross-attn over `context`)
        x = ff(norm3(x)) + x             (feed-forward, GEGLU if gated_ff)
    Optionally wrapped in activation checkpointing.
    """

    def __init__(self,
                 dim,
                 n_heads,
                 d_head,
                 dropout=0.,
                 context_dim=None,
                 gated_ff=True,
                 checkpoint=True,
                 disable_self_attn=False,
                 attention_cls=None):
        # `checkpoint` (bool) shadows the imported `checkpoint` function only
        # inside this constructor; forward() still sees the module-level one.
        super().__init__()
        # Allow callers to inject an alternative attention implementation.
        attn_cls = CrossAttention if attention_cls is None else attention_cls
        self.disable_self_attn = disable_self_attn
        # attn1: plain self-attention, unless disable_self_attn, in which case
        # it also conditions on `context` (making both attentions cross).
        self.attn1 = attn_cls(
            query_dim=dim,
            heads=n_heads,
            dim_head=d_head,
            dropout=dropout,
            context_dim=context_dim if self.disable_self_attn else None)
        self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff)
        self.attn2 = attn_cls(query_dim=dim,
                              context_dim=context_dim,
                              heads=n_heads,
                              dim_head=d_head,
                              dropout=dropout)

        self.norm1 = nn.LayerNorm(dim)
        self.norm2 = nn.LayerNorm(dim)
        self.norm3 = nn.LayerNorm(dim)
        self.checkpoint = checkpoint

    def forward(self, x, context=None, **kwargs):
        ## implementation tricks: checkpointing doesn't support non-tensor
        ## (e.g. None or scalar) arguments, so only tensors go in the tuple.
        input_tuple = (
            x,
        ) ## must be a 1-tuple, not (x): *input_tuple would otherwise unpack x
        if context is not None:
            input_tuple = (x, context)
        # `checkpoint` here is the module-level helper imported from
        # unifolm_wma.utils.common; self.checkpoint toggles it on/off.
        return checkpoint(self._forward, input_tuple, self.parameters(),
                          self.checkpoint)

    def _forward(self, x, context=None, mask=None):
        # NOTE(review): attn1/attn2 are called with a `mask` kwarg; the
        # attention class used must accept it in forward().
        x = self.attn1(self.norm1(x),
                       context=context if self.disable_self_attn else None,
                       mask=mask) + x
        x = self.attn2(self.norm2(x), context=context, mask=mask) + x
        x = self.ff(self.norm3(x)) + x
        return x
class ActionLatentImageCrossAttention(nn.Module):
    """Residual transformer stack letting action tokens attend to video latents.

    Args:
        in_channels: number of action channels (GroupNorm channel dimension).
        in_dim: per-channel feature size of the action input.
        n_heads / d_head: attention geometry; inner width is n_heads * d_head.
        depth: number of stacked BasicTransformerBlocks.
        context_dim: channel size `c` of the (b, t, c, h, w) context latents.
        use_linear: apply the in/out linear projections around the blocks.
    """

    def __init__(self,
                 in_channels,
                 in_dim,
                 n_heads,
                 d_head,
                 depth=1,
                 dropout=0.,
                 context_dim=None,
                 use_checkpoint=True,
                 disable_self_attn=False,
                 use_linear=True):
        super().__init__()
        self.in_channels = in_channels
        self.in_dim = in_dim
        inner_dim = n_heads * d_head
        self.norm = torch.nn.GroupNorm(num_groups=8,
                                       num_channels=in_channels,
                                       eps=1e-6,
                                       affine=True)

        self.proj_in_action = nn.Linear(in_dim, inner_dim)
        self.proj_in_cond = nn.Linear(context_dim, inner_dim)
        # Zero-initialized so the block starts as an identity residual.
        self.proj_out = zero_module(nn.Linear(inner_dim, in_dim))
        self.use_linear = use_linear

        self.transformer_blocks = nn.ModuleList(
            BasicTransformerBlock(inner_dim,
                                  n_heads,
                                  d_head,
                                  dropout=dropout,
                                  context_dim=context_dim,
                                  disable_self_attn=disable_self_attn,
                                  checkpoint=use_checkpoint,
                                  attention_cls=None)
            for _ in range(depth))

    def forward(self, x, context=None, **kwargs):
        # x must be 3-D: (batch, action channels, feature dim).
        act_b, act_ch, act_dim = x.shape
        # Flatten the (t, h, w) video-latent grid into a token sequence.
        b, t, c, h, w = context.shape
        context = rearrange(context, 'b t c h w -> b (t h w) c').contiguous()

        residual = x
        out = self.norm(x)  # normalize over the action channels
        if self.use_linear:
            out = self.proj_in_action(out)
            context = self.proj_in_cond(context)
        for block in self.transformer_blocks:
            out = block(out, context=context, **kwargs)
        if self.use_linear:
            out = self.proj_out(out)
        # Residual connection around the whole stack.
        return out + residual
class ConditionalResidualBlock1D(nn.Module):
    """Conv1d residual block with FiLM conditioning.

    Two Conv1dBlocks; between them the conditioning vector is mapped (via
    Mish + Linear) to a per-channel scale and bias (FiLM,
    https://arxiv.org/abs/1709.07871) when `cond_predict_scale` is set.
    A 1x1 conv aligns the residual path when channel counts differ.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 cond_dim,
                 kernel_size=3,
                 n_groups=8,
                 cond_predict_scale=True,
                 use_linear_act_proj=False):
        super().__init__()

        self.blocks = nn.ModuleList([
            Conv1dBlock(in_channels,
                        out_channels,
                        kernel_size,
                        n_groups=n_groups),
            Conv1dBlock(out_channels,
                        out_channels,
                        kernel_size,
                        n_groups=n_groups),
        ])

        self.cond_predict_scale = cond_predict_scale
        self.use_linear_act_proj = use_linear_act_proj
        self.out_channels = out_channels
        # FiLM modulation https://arxiv.org/abs/1709.07871
        # predicts per-channel scale and bias
        cond_channels = out_channels
        if cond_predict_scale:
            # BUGFIX: forward() splits the embedding into (scale, bias)
            # whenever cond_predict_scale is set, so the encoder must emit
            # 2 * out_channels features regardless of use_linear_act_proj.
            # The doubling was previously gated on
            # `cond_predict_scale and use_linear_act_proj`, which made the
            # (cond_predict_scale=True, use_linear_act_proj=False) combination
            # crash at the reshape below.
            cond_channels = out_channels * 2
        self.cond_encoder = nn.Sequential(
            nn.Mish(),
            nn.Linear(cond_dim, cond_channels),
        )
        # make sure dimensions compatible
        self.residual_conv = nn.Conv1d(in_channels, out_channels, 1) \
            if in_channels != out_channels else nn.Identity()

    def forward(self, x, cond=None):
        '''
        x : [ batch_size x in_channels x horizon ]
        cond : [ batch_size x T x cond_dim ]  (required despite the None
               default: its shape is read unconditionally)

        returns:
        out : [ batch_size x out_channels x horizon ]
        '''
        B, T, _ = cond.shape

        out = self.blocks[0](x)
        if self.cond_predict_scale:
            embed = self.cond_encoder(cond)
            if self.use_linear_act_proj:
                # Fold the (B, T) leading dims so scale/bias apply per token.
                embed = embed.reshape(B * T, -1)
                embed = embed.reshape(-1, 2, self.out_channels, 1)
            else:
                embed = embed.reshape(embed.shape[0], 2, self.out_channels, 1)
            scale = embed[:, 0, ...]
            bias = embed[:, 1, ...]
            out = scale * out + bias
        # NOTE(review): when cond_predict_scale is False the conditioning is
        # ignored entirely; the additive path (`out = out + embed`) was left
        # commented out upstream.
        out = self.blocks[1](out)
        out = out + self.residual_conv(x)
        return out
class ConditionalUnet1D(nn.Module):
    """Conditional 1-D U-Net used as a diffusion action head.

    Denoises an action trajectory `sample` of shape (B, T, input_dim),
    conditioned on (a) the diffusion timestep (sinusoidal embedding + MLP),
    (b) an observation encoding produced by `obs_encoder` from image and
    agent-pos inputs, and (c) intermediate feature maps handed over from a
    video-generation U-Net (`imagen_cond`), pooled into keypoint features by
    per-level SpatialSoftmax blocks and concatenated into the FiLM
    conditioning of each ConditionalResidualBlock1D.
    """

    def __init__(self,
                 input_dim,
                 n_obs_steps=1,
                 local_cond_dim=None,
                 global_cond_dim=None,
                 diffusion_step_embed_dim=256,
                 down_dims=[256, 512, 1024],
                 kernel_size=3,
                 n_groups=8,
                 cond_predict_scale=False,
                 horizon=16,
                 num_head_channels=64,
                 use_linear_attn=True,
                 use_linear_act_proj=True,
                 act_proj_dim=32,
                 cond_cross_attention=False,
                 context_dims=None,
                 image_size=None,
                 imagen_cond_gradient=False,
                 last_frame_only=False,
                 use_imagen_mid_only=False,
                 use_z_only=False,
                 spatial_num_kp=32,
                 obs_encoder_config=None):
        # NOTE(review): `down_dims` is a mutable default argument (only read
        # here, but a tuple would be safer).  `local_cond_dim` and
        # `global_cond_dim` are accepted but never used in this constructor.
        super().__init__()

        self.n_obs_steps = n_obs_steps
        # Observation encoder (image + proprioception) built from config.
        self.obs_encoder = instantiate_from_config(obs_encoder_config)

        all_dims = [input_dim] + list(down_dims)
        start_dim = down_dims[0]

        dsed = diffusion_step_embed_dim
        # Sinusoidal timestep embedding refined by a small MLP.
        diffusion_step_encoder = nn.Sequential(
            SinusoidalPosEmb(dsed),
            nn.Linear(dsed, dsed * 4),
            nn.Mish(),
            nn.Linear(dsed * 4, dsed),
        )
        # Base conditioning width: timestep embedding + flattened obs features.
        cond_dim = dsed + self.obs_encoder.output_shape()[-1] * self.n_obs_steps
        in_out = list(zip(all_dims[:-1], all_dims[1:]))
        local_cond_encoder = None
        down_modules = nn.ModuleList([])

        # Per-level sequence length fed to the cross-attention blocks.
        dim_a_list = []
        for ind, (dim_in, dim_out) in enumerate(in_out):
            is_last = ind >= (len(in_out) - 1)
            if ind == 0:
                dim_a = horizon
            else:
                # NOTE(review): `horizon // 2 * ind` grows linearly with the
                # level index rather than halving per level — confirm intent.
                dim_a = horizon // 2 * ind
            dim_a_list.append(dim_a)

            # for attention
            num_heads = dim_out // num_head_channels
            dim_head = num_head_channels
            # FiLM conditioning width for this level; the `2 *` terms are
            # presumably the (x, y) keypoint coordinates emitted by
            # SpatialSoftmax — TODO confirm against SpatialSoftmax output.
            if use_linear_act_proj:
                if use_imagen_mid_only:
                    cur_cond_dim = cond_dim + 2 * context_dims[-1]
                elif use_z_only:
                    cur_cond_dim = cond_dim + 2 * spatial_num_kp
                else:
                    cur_cond_dim = cond_dim + 2 * context_dims[ind]
            else:
                cur_cond_dim = cond_dim + horizon * context_dims[ind]

            down_modules.append(
                nn.ModuleList([
                    ConditionalResidualBlock1D(
                        dim_in,
                        dim_out,
                        cond_dim=cur_cond_dim,
                        kernel_size=kernel_size,
                        n_groups=n_groups,
                        cond_predict_scale=cond_predict_scale,
                        use_linear_act_proj=use_linear_act_proj),
                    ConditionalResidualBlock1D(
                        dim_out,
                        dim_out,
                        cond_dim=cur_cond_dim,
                        kernel_size=kernel_size,
                        n_groups=n_groups,
                        cond_predict_scale=cond_predict_scale,
                        use_linear_act_proj=use_linear_act_proj),
                    ActionLatentImageCrossAttention(
                        dim_out,
                        dim_a,
                        num_heads,
                        dim_head,
                        context_dim=context_dims[ind],
                        use_linear=use_linear_attn)
                    if cond_cross_attention else nn.Identity(),
                    Downsample1d(dim_out) if not is_last else nn.Identity()
                ]))

        # Middle blocks reuse cur_cond_dim / num_heads / dim_head left over
        # from the last down-level iteration (loop-variable leakage; works
        # because the deepest level's values are the right ones here).
        mid_dim = all_dims[-1]
        self.mid_modules = nn.ModuleList([
            ConditionalResidualBlock1D(
                mid_dim,
                mid_dim,
                cond_dim=cur_cond_dim,
                kernel_size=kernel_size,
                n_groups=n_groups,
                cond_predict_scale=cond_predict_scale,
                use_linear_act_proj=use_linear_act_proj),
            ConditionalResidualBlock1D(
                mid_dim,
                mid_dim,
                cond_dim=cur_cond_dim,
                kernel_size=kernel_size,
                n_groups=n_groups,
                cond_predict_scale=cond_predict_scale,
                use_linear_act_proj=use_linear_act_proj),
            ActionLatentImageCrossAttention(mid_dim,
                                            dim_a_list[-1],
                                            num_heads,
                                            dim_head,
                                            context_dim=context_dims[-1],
                                            use_linear=use_linear_attn)
            if cond_cross_attention else nn.Identity(),
        ])

        up_modules = nn.ModuleList([])
        # Reverse so context_dims[ind] lines up with the decoder order.
        context_dims = context_dims[::-1]
        for ind, (dim_in, dim_out) in enumerate(
                reversed(in_out[1:] + [(down_dims[-1], down_dims[-1])])):
            is_last = ind >= (len(in_out) - 1)
            if use_linear_act_proj:
                if use_imagen_mid_only:
                    cur_cond_dim = cond_dim + 2 * context_dims[0]
                elif use_z_only:
                    cur_cond_dim = cond_dim + 2 * spatial_num_kp
                else:
                    cur_cond_dim = cond_dim + 2 * context_dims[ind]
            else:
                cur_cond_dim = cond_dim + horizon * context_dims[ind]
            up_modules.append(
                nn.ModuleList([
                    # dim_out + dim_in: decoder input is concatenated with the
                    # skip connection popped from the encoder.
                    ConditionalResidualBlock1D(
                        dim_out + dim_in,
                        dim_in,
                        cond_dim=cur_cond_dim,
                        kernel_size=kernel_size,
                        n_groups=n_groups,
                        cond_predict_scale=cond_predict_scale,
                        use_linear_act_proj=use_linear_act_proj),
                    ConditionalResidualBlock1D(
                        dim_in,
                        dim_in,
                        cond_dim=cur_cond_dim,
                        kernel_size=kernel_size,
                        n_groups=n_groups,
                        cond_predict_scale=cond_predict_scale,
                        use_linear_act_proj=use_linear_act_proj),
                    ActionLatentImageCrossAttention(
                        dim_in,
                        dim_a_list.pop(),  # consume encoder lengths in reverse
                        num_heads,
                        dim_head,
                        context_dim=context_dims[ind],
                        use_linear=use_linear_attn)
                    if cond_cross_attention else nn.Identity(),
                    Upsample1d(dim_in) if not is_last else nn.Identity()
                ]))

        final_conv = nn.Sequential(
            Conv1dBlock(start_dim, start_dim, kernel_size=kernel_size),
            nn.Conv1d(start_dim, input_dim, 1),
        )

        if use_z_only:
            # Single SpatialSoftmax over the raw 4-channel video latent.
            h, w = image_size
            self.spatial_softmax_blocks = nn.ModuleList(
                [SpatialSoftmax((4, h, w), spatial_num_kp)])
        else:
            self.spatial_softmax_blocks = nn.ModuleList([])
            # Undo the earlier reversal so index 0 is the shallowest level.
            context_dims = context_dims[::-1]
            for ind, context_dim in enumerate(context_dims):
                h, w = image_size
                if ind != 0:
                    # Video features are spatially halved at each deeper level.
                    h //= 2**ind
                    w //= 2**ind
                net = SpatialSoftmax((context_dim, h, w), context_dim)
                self.spatial_softmax_blocks.append(net)
            # Append the deepest block once more for the mid stage...
            self.spatial_softmax_blocks.append(net)
            # ...then mirror the first four blocks for the decoder path, so
            # forward() can index down / mid / up stages contiguously.
            self.spatial_softmax_blocks += self.spatial_softmax_blocks[
                0:4][::-1]

        self.diffusion_step_encoder = diffusion_step_encoder
        self.local_cond_encoder = local_cond_encoder
        self.up_modules = up_modules
        self.down_modules = down_modules
        self.final_conv = final_conv

        self.cond_cross_attention = cond_cross_attention
        self.use_linear_act_proj = use_linear_act_proj

        # Per-scalar action projection: each action value becomes an
        # act_proj_dim token (used when use_linear_act_proj is on);
        # the horizon projections serve the alternative code path.
        self.proj_in_action = nn.Sequential(nn.Linear(1, act_proj_dim),
                                            nn.LayerNorm(act_proj_dim))
        self.proj_in_horizon = nn.Sequential(nn.Linear(horizon, act_proj_dim),
                                             nn.LayerNorm(act_proj_dim))
        self.proj_out_action = nn.Sequential(nn.LayerNorm(act_proj_dim),
                                             nn.Linear(act_proj_dim, 1))
        self.proj_out_horizon = nn.Sequential(nn.LayerNorm(act_proj_dim),
                                              nn.Linear(act_proj_dim, horizon))
        logger.info("number of parameters: %e",
                    sum(p.numel() for p in self.parameters()))

        self.imagen_cond_gradient = imagen_cond_gradient
        self.use_imagen_mid_only = use_imagen_mid_only
        self.use_z_only = use_z_only
        self.spatial_num_kp = spatial_num_kp
        self.last_frame_only = last_frame_only
        self.horizon = horizon

        # Context precomputation cache (opt-in; keyed by tensor data pointers).
        self._global_cond_cache_enabled = False
        self._global_cond_cache = {}

    def forward(self,
                sample: torch.Tensor,
                timestep: Union[torch.Tensor, float, int],
                imagen_cond=None,
                cond=None,
                **kwargs):
        """
        sample: (B,T,input_dim) noisy action trajectory
        timestep: (B,) or int, diffusion step
        imagen_cond: a list of hidden feature maps from the video-gen U-Net;
            split positionally below as [0:4]=down, [4]=mid, [5:]=up
        cond: sequence of two tensors, repacked into a dict:
            image: (B, 3, To, h, w)
            agent_pos: (B, Ta, d)
        kwargs: may carry 'x_start' (B, C, T, h, w), required when use_z_only
        output: (B,T,input_dim) predicted noise / denoised actions
        """

        if not self.imagen_cond_gradient:
            # Block gradients into the video-generation U-Net.
            imagen_cond = [c.detach() for c in imagen_cond]

        cond = {'image': cond[0], 'agent_pos': cond[1]}

        # (B, 3, To, h, w) -> (B, To, 3, h, w) -> fold batch and time.
        cond['image'] = cond['image'].permute(0, 2, 1, 3,
                                              4)
        cond['image'] = rearrange(cond['image'], 'b t c h w -> (b t) c h w')
        cond['agent_pos'] = rearrange(cond['agent_pos'], 'b t d -> (b t) d')

        B, T, D = sample.shape
        if self.use_linear_act_proj:
            # Lift each scalar action value to an act_proj_dim token:
            # (B, T, D) -> (B, T, D, act_proj_dim).
            sample = self.proj_in_action(sample.unsqueeze(-1))
            # NOTE(review): the cache key is the tensors' data_ptr(); this is
            # only safe while callers keep those tensors alive (freed storage
            # may be reused by a different tensor).
            _gc_key = (cond['image'].data_ptr(), cond['agent_pos'].data_ptr())
            if self._global_cond_cache_enabled and _gc_key in self._global_cond_cache:
                global_cond = self._global_cond_cache[_gc_key]
            else:
                global_cond = self.obs_encoder(cond)
                global_cond = rearrange(global_cond,
                                        '(b t) d -> b 1 (t d)',
                                        b=B,
                                        t=self.n_obs_steps)
                # Broadcast the obs encoding across the T action tokens.
                global_cond = repeat(global_cond,
                                     'b c d -> b (repeat c) d',
                                     repeat=T)
                if self._global_cond_cache_enabled:
                    self._global_cond_cache[_gc_key] = global_cond
        else:
            sample = einops.rearrange(sample, 'b h t -> b t h')
            sample = self.proj_in_horizon(sample)
            # NOTE(review): `robo_state_cond` is used here before ever being
            # assigned, and `global_cond` (needed further below) is never set
            # in this branch — this path raises NameError as written; it
            # likely should start from self.obs_encoder(cond).  Left as-is.
            robo_state_cond = rearrange(robo_state_cond, 'b t d -> b 1 (t d)')
            robo_state_cond = repeat(robo_state_cond,
                                     'b c d -> b (repeat c) d',
                                     repeat=2)

        # 1. time
        timesteps = timestep
        if not torch.is_tensor(timesteps):
            timesteps = torch.tensor([timesteps],
                                     dtype=torch.long,
                                     device=sample.device)
        elif torch.is_tensor(timesteps) and len(timesteps.shape) == 0:
            timesteps = timesteps[None].to(sample.device)
        # Broadcast to batch dimension in a way that's compatible with ONNX/Core ML
        timesteps = timesteps.expand(sample.shape[0])
        global_feature = self.diffusion_step_encoder(timesteps)
        # Positional split of the video U-Net features into stages.
        (imagen_cond_down, imagen_cond_mid, imagen_cond_up
         ) = imagen_cond[0:4], imagen_cond[4], imagen_cond[5:] #NOTE HAND CODE

        # Fold (B, T) so every action token runs through the 1-D U-Net.
        x = sample if not self.use_linear_act_proj else sample.reshape(
            B * T, D, -1)
        h = []  # encoder skip connections
        for idx, modules in enumerate(self.down_modules):
            if self.cond_cross_attention:
                # NOTE(review): `crossatten` is unpacked but never applied in
                # this forward pass.
                (resnet, resnet2, crossatten, downsample) = modules
            else:
                (resnet, resnet2, _, downsample) = modules

            # Access the cond from the unet embeds from video unet
            if self.use_imagen_mid_only:
                imagen_cond = imagen_cond_mid
            elif self.use_z_only:
                # (B, C, T, h, w) -> (B, T, C, h, w)
                imagen_cond = kwargs['x_start'].permute(0, 2, 1, 3, 4)
            else:
                imagen_cond = imagen_cond_down[idx]
            if self.last_frame_only:
                # Keep only the final frame, replicated across the horizon.
                imagen_cond = imagen_cond[:, -1].unsqueeze(1)
                imagen_cond = repeat(imagen_cond,
                                     'b t c h w -> b (repeat t) c h w',
                                     repeat=self.horizon)
            imagen_cond = rearrange(imagen_cond, 'b t c h w -> (b t) c h w')
            # Pool feature maps into keypoint coordinates.
            if self.use_imagen_mid_only:
                imagen_cond = self.spatial_softmax_blocks[len(
                    self.spatial_softmax_blocks) // 2](imagen_cond)
            elif self.use_z_only:
                imagen_cond = self.spatial_softmax_blocks[0](imagen_cond)
            else:
                imagen_cond = self.spatial_softmax_blocks[idx](imagen_cond)
            imagen_cond = rearrange(imagen_cond, '(b t) c d -> b t c d', b=B)

            if self.use_linear_act_proj:
                imagen_cond = imagen_cond.reshape(B, T, -1)
                cur_global_feature = global_feature.unsqueeze(
                    1).repeat_interleave(repeats=T, dim=1)
            else:
                imagen_cond = imagen_cond.permute(0, 3, 1, 2)
                imagen_cond = imagen_cond.reshape(B, 2, -1)
                cur_global_feature = global_feature.unsqueeze(
                    1).repeat_interleave(repeats=2, dim=1)
            # FiLM conditioning = [timestep | obs encoding | video keypoints].
            cur_global_feature = torch.cat(
                [cur_global_feature, global_cond, imagen_cond], axis=-1)
            x = resnet(x, cur_global_feature)
            x = resnet2(x, cur_global_feature)
            h.append(x)
            x = downsample(x)

        #>>> mide blocks
        resnet, resnet2, _ = self.mid_modules
        # Access the cond from the unet embeds from video unet
        if self.use_z_only:
            imagen_cond = kwargs['x_start'].permute(0, 2, 1, 3, 4)
        else:
            imagen_cond = imagen_cond_mid
        if self.last_frame_only:
            imagen_cond = imagen_cond[:, -1].unsqueeze(1)
            imagen_cond = repeat(imagen_cond,
                                 'b t c h w -> b (repeat t) c h w',
                                 repeat=self.horizon)
        imagen_cond = rearrange(imagen_cond, 'b t c h w -> (b t) c h w')
        # Relies on `idx` leaking from the down loop: advance to the mid block.
        idx += 1
        if self.use_z_only:
            imagen_cond = self.spatial_softmax_blocks[0](imagen_cond)
        else:
            # NOTE(review): unlike the down/up stages, use_imagen_mid_only is
            # not special-cased here, so a different SpatialSoftmax index is
            # used for the mid stage in that mode — confirm intent.
            imagen_cond = self.spatial_softmax_blocks[idx](imagen_cond)
        imagen_cond = rearrange(imagen_cond, '(b t) c d -> b t c d', b=B)
        if self.use_linear_act_proj:
            imagen_cond = imagen_cond.reshape(B, T, -1)
            cur_global_feature = global_feature.unsqueeze(1).repeat_interleave(
                repeats=T, dim=1)
        else:
            imagen_cond = imagen_cond.permute(0, 3, 1, 2)
            imagen_cond = imagen_cond.reshape(B, 2, -1)
            cur_global_feature = global_feature.unsqueeze(1).repeat_interleave(
                repeats=2, dim=1)
        cur_global_feature = torch.cat(
            [cur_global_feature, global_cond, imagen_cond], axis=-1)
        x = resnet(x, cur_global_feature)
        x = resnet2(x, cur_global_feature)

        #>>> up blocks
        # Advance past the mid SpatialSoftmax block.
        idx += 1
        for jdx, modules in enumerate(self.up_modules):
            if self.cond_cross_attention:
                (resnet, resnet2, crossatten, upsample) = modules
            else:
                (resnet, resnet2, _, upsample) = modules

            # Access the cond from the unet embeds from video unet
            if self.use_imagen_mid_only:
                imagen_cond = imagen_cond_mid
            elif self.use_z_only:
                imagen_cond = kwargs['x_start'].permute(0, 2, 1, 3, 4)
            else:
                imagen_cond = imagen_cond_up[jdx]
            if self.last_frame_only:
                imagen_cond = imagen_cond[:, -1].unsqueeze(1)
                imagen_cond = repeat(imagen_cond,
                                     'b t c h w -> b (repeat t) c h w',
                                     repeat=self.horizon)
            imagen_cond = rearrange(imagen_cond, 'b t c h w -> (b t) c h w')
            if self.use_imagen_mid_only:
                imagen_cond = self.spatial_softmax_blocks[len(
                    self.spatial_softmax_blocks) // 2](imagen_cond)
            elif self.use_z_only:
                imagen_cond = self.spatial_softmax_blocks[0](imagen_cond)
            else:
                # Decoder blocks live after the mid block: offset by idx.
                imagen_cond = self.spatial_softmax_blocks[jdx +
                                                          idx](imagen_cond)
            imagen_cond = rearrange(imagen_cond, '(b t) c d -> b t c d', b=B)

            if self.use_linear_act_proj:
                imagen_cond = imagen_cond.reshape(B, T, -1)
                cur_global_feature = global_feature.unsqueeze(
                    1).repeat_interleave(repeats=T, dim=1)
            else:
                imagen_cond = imagen_cond.permute(0, 3, 1, 2)
                imagen_cond = imagen_cond.reshape(B, 2, -1)
                cur_global_feature = global_feature.unsqueeze(
                    1).repeat_interleave(repeats=2, dim=1)

            cur_global_feature = torch.cat(
                [cur_global_feature, global_cond, imagen_cond], axis=-1)

            # Concatenate the matching encoder skip connection.
            x = torch.cat((x, h.pop()), dim=1)
            x = resnet(x, cur_global_feature)
            x = resnet2(x, cur_global_feature)
            x = upsample(x)

        x = self.final_conv(x)

        if self.use_linear_act_proj:
            # Collapse each act_proj_dim token back to a scalar action value.
            x = x.reshape(B, T, D, -1)
            x = self.proj_out_action(x)
            x = x.reshape(B, T, D)
        else:
            x = self.proj_out_horizon(x)
            x = einops.rearrange(x, 'b t h -> b h t')
        return x