init commit
This commit is contained in:
0
src/unifolm_wma/models/__init__.py
Normal file
0
src/unifolm_wma/models/__init__.py
Normal file
267
src/unifolm_wma/models/autoencoder.py
Normal file
267
src/unifolm_wma/models/autoencoder.py
Normal file
@@ -0,0 +1,267 @@
|
||||
import os
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
import pytorch_lightning as pl
|
||||
|
||||
from einops import rearrange
|
||||
from unifolm_wma.modules.networks.ae_modules import Encoder, Decoder
|
||||
from unifolm_wma.utils.distributions import DiagonalGaussianDistribution
|
||||
from unifolm_wma.utils.utils import instantiate_from_config
|
||||
|
||||
|
||||
class AutoencoderKL(pl.LightningModule):
    """KL-regularized convolutional autoencoder (VAE) with adversarial loss.

    Wraps an ``Encoder``/``Decoder`` pair around a diagonal-Gaussian latent
    bottleneck. Training alternates between the autoencoder objective
    (optimizer index 0) and a discriminator (optimizer index 1), both
    implemented by the loss module built from ``lossconfig``.
    """

    def __init__(
        self,
        ddconfig,
        lossconfig,
        embed_dim,
        ckpt_path=None,
        ignore_keys=(),
        image_key="image",
        colorize_nlabels=None,
        monitor=None,
        test=False,
        logdir=None,
        input_dim=4,
        test_args=None,
    ):
        """
        Args:
            ddconfig (dict): kwargs for both ``Encoder`` and ``Decoder``;
                must have ``double_z=True`` so the encoder emits mean and
                log-variance (2 * z_channels).
            lossconfig (dict): config instantiated into the loss module
                (provides ``discriminator`` — see ``configure_optimizers``).
            embed_dim (int): channel width of the latent code.
            ckpt_path (str, optional): checkpoint to restore from.
            ignore_keys (iterable of str): state-dict key prefixes to drop
                when restoring. (Default changed from a shared mutable list
                to a tuple; behavior for callers is identical.)
            image_key (str): batch key that holds the input images.
            colorize_nlabels (int, optional): if given, register a random
                projection used by ``to_rgb`` for segmentation inputs.
            monitor (str, optional): metric name for checkpoint selection.
            test (bool): if True, set up test output directories.
            logdir (str, optional): base directory for test outputs.
            input_dim (int): expected rank of model inputs (see get_input).
            test_args (Namespace, optional): test-time options; required
                when ``test`` is True.
        """
        super().__init__()
        self.image_key = image_key
        self.encoder = Encoder(**ddconfig)
        self.decoder = Decoder(**ddconfig)
        self.loss = instantiate_from_config(lossconfig)
        # Encoder must output mean and log-variance channels.
        assert ddconfig["double_z"]
        self.quant_conv = torch.nn.Conv2d(2 * ddconfig["z_channels"],
                                          2 * embed_dim, 1)
        self.post_quant_conv = torch.nn.Conv2d(embed_dim,
                                               ddconfig["z_channels"], 1)
        self.embed_dim = embed_dim
        self.input_dim = input_dim
        self.test = test
        self.test_args = test_args
        self.logdir = logdir
        if colorize_nlabels is not None:
            # isinstance instead of type(...) == int: accepts int subclasses.
            assert isinstance(colorize_nlabels, int)
            self.register_buffer("colorize",
                                 torch.randn(3, colorize_nlabels, 1, 1))
        if monitor is not None:
            self.monitor = monitor
        if ckpt_path is not None:
            self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
        if self.test:
            self.init_test()

    def init_test(self):
        """Prepare output directories for test-time dumping of latents,
        reconstructions and inputs under ``<logdir>/test``."""
        # Validate before the first self.test_args access, not after.
        assert self.test_args is not None
        self.test = True
        save_dir = os.path.join(self.logdir, "test")
        if 'ckpt' in self.test_args:
            # NOTE(review): self._cur_epoch is only set by init_from_ckpt;
            # providing test_args.ckpt without ckpt_path would fail here.
            ckpt_name = os.path.basename(self.test_args.ckpt).split(
                '.ckpt')[0] + f'_epoch{self._cur_epoch}'
            self.root = os.path.join(save_dir, ckpt_name)
        else:
            self.root = save_dir
        if 'test_subdir' in self.test_args:
            # test_subdir overrides any ckpt-derived directory name.
            self.root = os.path.join(save_dir, self.test_args.test_subdir)

        self.root_zs = os.path.join(self.root, "zs")
        self.root_dec = os.path.join(self.root, "reconstructions")
        self.root_inputs = os.path.join(self.root, "inputs")
        os.makedirs(self.root, exist_ok=True)

        if self.test_args.save_z:
            os.makedirs(self.root_zs, exist_ok=True)
        if self.test_args.save_reconstruction:
            os.makedirs(self.root_dec, exist_ok=True)
        if self.test_args.save_input:
            os.makedirs(self.root_inputs, exist_ok=True)
        self.test_maximum = getattr(self.test_args, 'test_maximum', None)
        self.count = 0
        self.eval_metrics = {}
        self.decodes = []
        self.save_decode_samples = 2048

    def init_from_ckpt(self, path, ignore_keys=()):
        """Load weights from ``path`` (non-strict), dropping any key that
        starts with one of ``ignore_keys``."""
        sd = torch.load(path, map_location="cpu")
        try:
            self._cur_epoch = sd['epoch']
            sd = sd["state_dict"]
        except (KeyError, TypeError):
            # Raw state dicts have no 'epoch'/'state_dict' wrapper
            # (narrowed from a bare except that hid real errors).
            self._cur_epoch = 'null'
        for k in list(sd.keys()):
            # any(...) instead of a nested del-loop: the original deleted the
            # same key twice (KeyError) when it matched two prefixes.
            if any(k.startswith(ik) for ik in ignore_keys):
                print("Deleting key {} from state_dict.".format(k))
                del sd[k]
        self.load_state_dict(sd, strict=False)
        print(f"Restored from {path}")

    def encode(self, x, **kwargs):
        """Encode ``x`` to a DiagonalGaussianDistribution over the latent."""
        h = self.encoder(x)
        moments = self.quant_conv(h)
        posterior = DiagonalGaussianDistribution(moments)
        return posterior

    def decode(self, z, **kwargs):
        """Decode a latent ``z`` back to image space."""
        z = self.post_quant_conv(z)
        dec = self.decoder(z)
        return dec

    def forward(self, input, sample_posterior=True):
        """Encode, sample (or take the mode of) the posterior, and decode.

        Returns:
            (reconstruction, posterior) tuple.
        """
        posterior = self.encode(input)
        if sample_posterior:
            z = posterior.sample()
        else:
            z = posterior.mode()
        dec = self.decode(z)
        return dec, posterior

    def get_input(self, batch, k):
        """Fetch ``batch[k]``; fold a 5-D video batch (b, c, t, h, w) into
        a 4-D image batch ((b t), c, h, w) when the model expects 4-D input.

        Side effect: remembers ``self.b`` and ``self.t`` so the folding can
        be undone elsewhere.
        """
        x = batch[k]
        if x.dim() == 5 and self.input_dim == 4:
            b, c, t, h, w = x.shape
            self.b = b
            self.t = t
            x = rearrange(x, 'b c t h w -> (b t) c h w')

        return x

    def training_step(self, batch, batch_idx, optimizer_idx):
        """Alternating GAN-style step (PyTorch Lightning <2.0 multi-optimizer
        API): index 0 trains encoder/decoder/logvar, index 1 the
        discriminator."""
        inputs = self.get_input(batch, self.image_key)
        reconstructions, posterior = self(inputs)

        if optimizer_idx == 0:
            # train encoder+decoder+logvar
            aeloss, log_dict_ae = self.loss(inputs,
                                            reconstructions,
                                            posterior,
                                            optimizer_idx,
                                            self.global_step,
                                            last_layer=self.get_last_layer(),
                                            split="train")
            self.log("aeloss",
                     aeloss,
                     prog_bar=True,
                     logger=True,
                     on_step=True,
                     on_epoch=True)
            self.log_dict(log_dict_ae,
                          prog_bar=False,
                          logger=True,
                          on_step=True,
                          on_epoch=False)
            return aeloss

        if optimizer_idx == 1:
            # train the discriminator
            discloss, log_dict_disc = self.loss(
                inputs,
                reconstructions,
                posterior,
                optimizer_idx,
                self.global_step,
                last_layer=self.get_last_layer(),
                split="train")

            self.log("discloss",
                     discloss,
                     prog_bar=True,
                     logger=True,
                     on_step=True,
                     on_epoch=True)
            self.log_dict(log_dict_disc,
                          prog_bar=False,
                          logger=True,
                          on_step=True,
                          on_epoch=False)
            return discloss

    def validation_step(self, batch, batch_idx):
        """Compute and log both loss branches on a validation batch."""
        inputs = self.get_input(batch, self.image_key)
        reconstructions, posterior = self(inputs)
        aeloss, log_dict_ae = self.loss(inputs,
                                        reconstructions,
                                        posterior,
                                        0,
                                        self.global_step,
                                        last_layer=self.get_last_layer(),
                                        split="val")

        discloss, log_dict_disc = self.loss(inputs,
                                            reconstructions,
                                            posterior,
                                            1,
                                            self.global_step,
                                            last_layer=self.get_last_layer(),
                                            split="val")

        self.log("val/rec_loss", log_dict_ae["val/rec_loss"])
        self.log_dict(log_dict_ae)
        self.log_dict(log_dict_disc)
        # NOTE(review): returns the bound log_dict method (quirk inherited
        # from the original latent-diffusion code); kept as-is since
        # Lightning ignores the return value here.
        return self.log_dict

    def configure_optimizers(self):
        """Two Adam optimizers: one for the autoencoder, one for the
        discriminator. ``self.learning_rate`` is assumed to be set
        externally (e.g. by the trainer script) before this is called."""
        lr = self.learning_rate
        opt_ae = torch.optim.Adam(list(self.encoder.parameters()) +
                                  list(self.decoder.parameters()) +
                                  list(self.quant_conv.parameters()) +
                                  list(self.post_quant_conv.parameters()),
                                  lr=lr,
                                  betas=(0.5, 0.9))
        opt_disc = torch.optim.Adam(self.loss.discriminator.parameters(),
                                    lr=lr,
                                    betas=(0.5, 0.9))
        return [opt_ae, opt_disc], []

    def get_last_layer(self):
        """Weight of the final decoder conv, used for adaptive GAN weighting
        inside the loss."""
        return self.decoder.conv_out.weight

    @torch.no_grad()
    def log_images(self, batch, only_inputs=False, **kwargs):
        """Build a dict of input / reconstruction / random-sample images for
        the image logger. Multi-channel (non-RGB) tensors are projected to
        RGB via ``to_rgb``."""
        log = dict()
        x = self.get_input(batch, self.image_key)
        x = x.to(self.device)
        if not only_inputs:
            xrec, posterior = self(x)
            if x.shape[1] > 3:
                # colorize with random projection
                assert xrec.shape[1] > 3
                x = self.to_rgb(x)
                xrec = self.to_rgb(xrec)
            log["samples"] = self.decode(torch.randn_like(posterior.sample()))
            log["reconstructions"] = xrec
        log["inputs"] = x
        return log

    def to_rgb(self, x):
        """Project a segmentation tensor to 3 channels with a fixed random
        conv and rescale to [-1, 1]."""
        assert self.image_key == "segmentation"
        if not hasattr(self, "colorize"):
            self.register_buffer("colorize",
                                 torch.randn(3, x.shape[1], 1, 1).to(x))
        x = F.conv2d(x, weight=self.colorize)
        x = 2. * (x - x.min()) / (x.max() - x.min()) - 1.
        return x
|
||||
|
||||
|
||||
class IdentityFirstStage(torch.nn.Module):
    """No-op stand-in for a first-stage (autoencoder) model.

    ``encode``/``decode``/``forward`` all return their input unchanged;
    ``quantize`` additionally mimics the VQ interface's
    ``(z, loss, info)`` return shape when ``vq_interface`` is set.
    """

    def __init__(self, *args, vq_interface=False, **kwargs):
        # Initialize nn.Module FIRST so attribute assignment goes through a
        # fully constructed module (the original assigned before __init__).
        super().__init__()
        # TODO: Should be true by default but check to not break older stuff
        self.vq_interface = vq_interface

    def encode(self, x, *args, **kwargs):
        """Identity: return ``x`` unchanged."""
        return x

    def decode(self, x, *args, **kwargs):
        """Identity: return ``x`` unchanged."""
        return x

    def quantize(self, x, *args, **kwargs):
        """Identity; with ``vq_interface`` return the VQ-style tuple
        ``(x, None, [None, None, None])``."""
        if self.vq_interface:
            return x, None, [None, None, None]
        return x

    def forward(self, x, *args, **kwargs):
        """Identity: return ``x`` unchanged."""
        return x
|
||||
2524
src/unifolm_wma/models/ddpms.py
Normal file
2524
src/unifolm_wma/models/ddpms.py
Normal file
File diff suppressed because it is too large
Load Diff
0
src/unifolm_wma/models/diffusion_head/__init__.py
Normal file
0
src/unifolm_wma/models/diffusion_head/__init__.py
Normal file
217
src/unifolm_wma/models/diffusion_head/base_nets.py
Normal file
217
src/unifolm_wma/models/diffusion_head/base_nets.py
Normal file
@@ -0,0 +1,217 @@
|
||||
"""
|
||||
Contains torch Modules that correspond to basic network building blocks, like
|
||||
MLP, RNN, and CNN backbones.
|
||||
"""
|
||||
|
||||
import abc
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
class Module(torch.nn.Module):
    """
    Abstract network base class.

    Behaves exactly like ``torch.nn.Module`` except that subclasses are
    expected to implement :meth:`output_shape`.
    """

    @abc.abstractmethod
    def output_shape(self, input_shape=None):
        """
        Compute the shape this module outputs for a given input shape.

        Args:
            input_shape (iterable of int): input shape, excluding the batch
                dimension. Modules whose output size is fixed (or assumes a
                fixed-size input) may ignore it.

        Returns:
            out_shape ([int]): list of integers corresponding to output shape
        """
        raise NotImplementedError
|
||||
|
||||
|
||||
"""
|
||||
================================================
|
||||
Visual Backbone Networks
|
||||
================================================
|
||||
"""
|
||||
|
||||
|
||||
class ConvBase(Module):
    """
    Base class for ConvNets.

    Subclasses assign their backbone to ``self.nets``; ``forward`` runs it
    and verifies that the result matches the advertised ``output_shape``.
    """

    def __init__(self):
        super(ConvBase, self).__init__()

    # dirty hack - re-implement to pass the buck onto subclasses from ABC parent
    def output_shape(self, input_shape):
        """
        Compute the output shape produced for ``input_shape`` (both exclude
        the batch dimension). Must be overridden by subclasses.

        Returns:
            out_shape ([int]): list of integers corresponding to output shape
        """
        raise NotImplementedError

    def forward(self, inputs):
        """Run ``self.nets`` on ``inputs`` and sanity-check the shape."""
        out = self.nets(inputs)
        expected = list(self.output_shape(list(inputs.shape)[1:]))
        actual = list(out.shape)[1:]
        if expected != actual:
            raise ValueError('Size mismatch: expect size %s, but got size %s' %
                             (str(expected), str(actual)))
        return out
|
||||
|
||||
|
||||
class SpatialSoftmax(ConvBase):
    """
    Spatial Softmax Layer.

    Based on Deep Spatial Autoencoders for Visuomotor Learning by Finn et al.
    https://rll.berkeley.edu/dsae/dsae.pdf

    Converts a (C, H, W) feature map into K 2-D keypoints: each (optionally
    conv-projected) channel is softmaxed over pixel locations and the
    expected (x, y) coordinate under that distribution is returned.
    """

    def __init__(
        self,
        input_shape,
        num_kp=32,
        temperature=1.,
        learnable_temperature=False,
        output_variance=False,
        noise_std=0.0,
    ):
        """
        Args:
            input_shape (list): shape of the input feature (C, H, W)
            num_kp (int): number of keypoints (None for not using spatialsoftmax)
            temperature (float): temperature term for the softmax.
            learnable_temperature (bool): whether to learn the temperature
            output_variance (bool): treat attention as a distribution, and compute second-order statistics to return
            noise_std (float): add random spatial noise to the predicted keypoints
        """
        super(SpatialSoftmax, self).__init__()
        assert len(input_shape) == 3
        self._in_c, self._in_h, self._in_w = input_shape  # (C, H, W)

        if num_kp is not None:
            # 1x1 conv projects C channels down/up to num_kp keypoint maps.
            self.nets = torch.nn.Conv2d(self._in_c, num_kp, kernel_size=1)
            self._num_kp = num_kp
        else:
            # No projection: each input channel becomes one keypoint map.
            self.nets = None
            self._num_kp = self._in_c
        self.learnable_temperature = learnable_temperature
        self.output_variance = output_variance
        self.noise_std = noise_std

        if self.learnable_temperature:
            # temperature will be learned
            temperature = torch.nn.Parameter(torch.ones(1) * temperature,
                                             requires_grad=True)
            self.register_parameter('temperature', temperature)
        else:
            # temperature held constant after initialization
            # NOTE(review): registers a (requires_grad=False) Parameter as a
            # buffer — works since Parameter is a Tensor, but unusual.
            temperature = torch.nn.Parameter(torch.ones(1) * temperature,
                                             requires_grad=False)
            self.register_buffer('temperature', temperature)

        # Fixed normalized pixel-coordinate grids in [-1, 1], flattened to
        # (1, H*W) so they broadcast against the (B*K, H*W) attention map.
        pos_x, pos_y = np.meshgrid(np.linspace(-1., 1., self._in_w),
                                   np.linspace(-1., 1., self._in_h))
        pos_x = torch.from_numpy(pos_x.reshape(1, self._in_h *
                                               self._in_w)).float()
        pos_y = torch.from_numpy(pos_y.reshape(1, self._in_h *
                                               self._in_w)).float()
        self.register_buffer('pos_x', pos_x)
        self.register_buffer('pos_y', pos_y)

        # Last computed (detached) keypoints, cached by forward() for
        # inspection/visualization.
        self.kps = None

    def __repr__(self):
        """Pretty print network."""
        header = format(str(self.__class__.__name__))
        return header + '(num_kp={}, temperature={}, noise={})'.format(
            self._num_kp, self.temperature.item(), self.noise_std)

    def output_shape(self, input_shape):
        """
        Function to compute output shape from inputs to this module.

        Args:
            input_shape (iterable of int): shape of input. Does not include batch dimension.
                Some modules may not need this argument, if their output does not depend
                on the size of the input, or if they assume fixed size input.

        Returns:
            out_shape ([int]): list of integers corresponding to output shape
        """
        assert (len(input_shape) == 3)
        assert (input_shape[0] == self._in_c)
        return [self._num_kp, 2]

    def forward(self, feature):
        """
        Forward pass through spatial softmax layer. For each keypoint, a 2D spatial
        probability distribution is created using a softmax, where the support is the
        pixel locations. This distribution is used to compute the expected value of
        the pixel location, which becomes a keypoint of dimension 2. K such keypoints
        are created.

        Returns:
            out (torch.Tensor or tuple): mean keypoints of shape [B, K, 2], and possibly
                keypoint variance of shape [B, K, 2, 2] corresponding to the covariance
                under the 2D spatial softmax distribution
        """
        assert (feature.shape[1] == self._in_c)
        assert (feature.shape[2] == self._in_h)
        assert (feature.shape[3] == self._in_w)
        if self.nets is not None:
            feature = self.nets(feature)

        # [B, K, H, W] -> [B * K, H * W] where K is number of keypoints
        feature = feature.reshape(-1, self._in_h * self._in_w)
        # 2d softmax normalization
        attention = F.softmax(feature / self.temperature, dim=-1)
        # [1, H * W] x [B * K, H * W] -> [B * K, 1] for spatial coordinate mean in x and y dimensions
        expected_x = torch.sum(self.pos_x * attention, dim=1, keepdim=True)
        expected_y = torch.sum(self.pos_y * attention, dim=1, keepdim=True)
        # stack to [B * K, 2]
        expected_xy = torch.cat([expected_x, expected_y], 1)
        # reshape to [B, K, 2]
        feature_keypoints = expected_xy.view(-1, self._num_kp, 2)

        if self.training:
            # Train-time regularization: jitter keypoints with Gaussian noise.
            noise = torch.randn_like(feature_keypoints) * self.noise_std
            feature_keypoints += noise

        if self.output_variance:
            # treat attention as a distribution, and compute second-order statistics to return
            expected_xx = torch.sum(self.pos_x * self.pos_x * attention,
                                    dim=1,
                                    keepdim=True)
            expected_yy = torch.sum(self.pos_y * self.pos_y * attention,
                                    dim=1,
                                    keepdim=True)
            expected_xy = torch.sum(self.pos_x * self.pos_y * attention,
                                    dim=1,
                                    keepdim=True)
            # Covariance entries: E[ab] - E[a]E[b].
            var_x = expected_xx - expected_x * expected_x
            var_y = expected_yy - expected_y * expected_y
            var_xy = expected_xy - expected_x * expected_y
            # stack to [B * K, 4] and then reshape to [B, K, 2, 2] where last 2 dims are covariance matrix
            feature_covar = torch.cat([var_x, var_xy, var_xy, var_y],
                                      1).reshape(-1, self._num_kp, 2, 2)
            feature_keypoints = (feature_keypoints, feature_covar)

        # Cache detached copies for later inspection.
        if isinstance(feature_keypoints, tuple):
            self.kps = (feature_keypoints[0].detach(),
                        feature_keypoints[1].detach())
        else:
            self.kps = feature_keypoints.detach()
        return feature_keypoints
|
||||
83
src/unifolm_wma/models/diffusion_head/common/lr_scheduler.py
Normal file
83
src/unifolm_wma/models/diffusion_head/common/lr_scheduler.py
Normal file
@@ -0,0 +1,83 @@
|
||||
from diffusers.optimization import (Union, SchedulerType, Optional, Optimizer,
                                    TYPE_TO_SCHEDULER_FUNCTION)


def get_scheduler(name: Union[str, SchedulerType],
                  optimizer: Optimizer,
                  num_warmup_steps: Optional[int] = None,
                  num_training_steps: Optional[int] = None,
                  **kwargs):
    """
    Added kwargs vs diffuser's original implementation

    Unified API to get any scheduler from its name.

    Args:
        name (`str` or `SchedulerType`):
            The name of the scheduler to use.
        optimizer (`torch.optim.Optimizer`):
            The optimizer that will be used during training.
        num_warmup_steps (`int`, *optional*):
            The number of warmup steps to do. This is not required by all schedulers (hence the argument being
            optional), the function will raise an error if it's unset and the scheduler type requires it.
        num_training_steps (`int``, *optional*):
            The number of training steps to do. This is not required by all schedulers (hence the argument being
            optional), the function will raise an error if it's unset and the scheduler type requires it.
    """
    sched_type = SchedulerType(name)
    build = TYPE_TO_SCHEDULER_FUNCTION[sched_type]

    # Constant schedule needs neither warmup nor a training-step budget.
    if sched_type == SchedulerType.CONSTANT:
        return build(optimizer, **kwargs)

    # All other schedulers require `num_warmup_steps`
    if num_warmup_steps is None:
        raise ValueError(
            f"{sched_type} requires `num_warmup_steps`, please provide that argument."
        )

    if sched_type == SchedulerType.CONSTANT_WITH_WARMUP:
        return build(optimizer, num_warmup_steps=num_warmup_steps, **kwargs)

    # All other schedulers require `num_training_steps`
    if num_training_steps is None:
        raise ValueError(
            f"{sched_type} requires `num_training_steps`, please provide that argument."
        )

    return build(optimizer,
                 num_warmup_steps=num_warmup_steps,
                 num_training_steps=num_training_steps,
                 **kwargs)
|
||||
|
||||
|
||||
import torch
|
||||
from torch.optim.lr_scheduler import _LRScheduler
|
||||
import pytorch_lightning as pl
|
||||
from diffusers.optimization import TYPE_TO_SCHEDULER_FUNCTION, SchedulerType
|
||||
|
||||
|
||||
class SelectiveLRScheduler(_LRScheduler):
    """Apply a base LR schedule only to selected optimizer param groups.

    On every step, groups whose index is in ``group_indices`` take the
    base scheduler's latest learning rate; all other groups are pinned to
    the fixed value in ``default_lr`` at the same index.

    NOTE(review): ``_LRScheduler.__init__`` calls ``self.step()`` once, so
    ``base_scheduler`` is advanced one step during construction — confirm
    this is intended. Also assumes ``default_lr`` has an entry for every
    non-selected group index (IndexError otherwise).
    """

    def __init__(self,
                 optimizer,
                 base_scheduler,
                 group_indices,
                 default_lr=[1e-5, 1e-4],
                 last_epoch=-1):
        # base_scheduler: the wrapped scheduler whose LR values are copied
        # into the selected groups.
        self.base_scheduler = base_scheduler
        self.group_indices = group_indices  # Indices of parameter groups to update
        # Per-group fallback LRs for groups NOT in group_indices.
        self.default_lr = default_lr
        super().__init__(optimizer, last_epoch)

    def step(self, epoch=None):
        # epoch is accepted for API compatibility with _LRScheduler but is
        # intentionally ignored; the base scheduler advances by one step.
        self.base_scheduler.step()
        base_lrs = self.base_scheduler.get_last_lr()

        for idx, group in enumerate(self.optimizer.param_groups):
            if idx in self.group_indices:
                group['lr'] = base_lrs[idx]
            else:
                # Reset the learning rate to its initial value
                group['lr'] = self.default_lr[idx]
|
||||
@@ -0,0 +1,16 @@
|
||||
import torch.nn as nn
|
||||
|
||||
|
||||
class ModuleAttrMixin(nn.Module):
    """Mixin exposing the device and dtype of a module's parameters."""

    def __init__(self):
        super().__init__()
        # Empty parameter guarantees parameters() is never empty, so the
        # properties below always have something to inspect.
        self._dummy_variable = nn.Parameter()

    @property
    def device(self):
        """Device of the module's first parameter."""
        first_param = next(iter(self.parameters()))
        return first_param.device

    @property
    def dtype(self):
        """Dtype of the module's first parameter."""
        first_param = next(iter(self.parameters()))
        return first_param.dtype
|
||||
91
src/unifolm_wma/models/diffusion_head/common/pytorch_util.py
Normal file
91
src/unifolm_wma/models/diffusion_head/common/pytorch_util.py
Normal file
@@ -0,0 +1,91 @@
|
||||
import collections
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from typing import Dict, Callable, List
|
||||
|
||||
|
||||
def dict_apply(
        x: Dict[str, torch.Tensor],
        func: Callable[[torch.Tensor],
                       torch.Tensor]) -> Dict[str, torch.Tensor]:
    """Apply ``func`` to every tensor leaf of a (possibly nested) dict.

    Nested dicts are recursed into; all other values are passed to ``func``.
    Returns a new dict; the input is not modified.
    """
    return {
        key: dict_apply(value, func) if isinstance(value, dict) else func(value)
        for key, value in x.items()
    }
|
||||
|
||||
|
||||
def pad_remaining_dims(x, target):
    """Right-pad ``x`` with singleton dims until it has ``target``'s rank.

    ``x``'s shape must be a prefix of ``target``'s shape.
    """
    assert x.shape == target.shape[:len(x.shape)]
    missing = len(target.shape) - len(x.shape)
    return x.reshape(x.shape + (1, ) * missing)
|
||||
|
||||
|
||||
def dict_apply_split(
        x: Dict[str, torch.Tensor], split_func: Callable[[torch.Tensor],
                                                         Dict[str, torch.Tensor]]
) -> Dict[str, torch.Tensor]:
    """Split every value of ``x`` into named parts and regroup by part name.

    ``split_func`` maps one value to ``{part_name: part}``; the result maps
    ``part_name -> {original_key: part}``.
    """
    grouped = collections.defaultdict(dict)
    for key, value in x.items():
        for part_name, part in split_func(value).items():
            grouped[part_name][key] = part
    return grouped
|
||||
|
||||
|
||||
def dict_apply_reduce(
        x: List[Dict[str,
                     torch.Tensor]], reduce_func: Callable[[List[torch.Tensor]],
                                                           torch.Tensor]
) -> Dict[str, torch.Tensor]:
    """Reduce a list of same-keyed dicts into one dict, key by key.

    For each key of ``x[0]``, collects that key's value from every dict in
    ``x`` and passes the list to ``reduce_func``.
    """
    return {key: reduce_func([d[key] for d in x]) for key in x[0].keys()}
|
||||
|
||||
|
||||
def replace_submodules(root_module: nn.Module, predicate: Callable[[nn.Module],
                                                                   bool],
                       func: Callable[[nn.Module], nn.Module]) -> nn.Module:
    """
    predicate: Return true if the module is to be replaced.
    func: Return new module to use.

    Replaces, in place, every submodule of ``root_module`` matching
    ``predicate`` with ``func(submodule)`` and returns the root.
    """
    if predicate(root_module):
        return func(root_module)

    def _matching_paths():
        # Dotted paths (split into components) of all matching submodules.
        return [
            name.split('.')
            for name, mod in root_module.named_modules(remove_duplicate=True)
            if predicate(mod)
        ]

    for *parent_path, leaf in _matching_paths():
        parent = root_module
        if len(parent_path) > 0:
            parent = root_module.get_submodule('.'.join(parent_path))
        # nn.Sequential children are addressed by integer index, not name.
        if isinstance(parent, nn.Sequential):
            old_module = parent[int(leaf)]
        else:
            old_module = getattr(parent, leaf)
        new_module = func(old_module)
        if isinstance(parent, nn.Sequential):
            parent[int(leaf)] = new_module
        else:
            setattr(parent, leaf, new_module)
    # verify that all matching modules are replaced
    assert len(_matching_paths()) == 0
    return root_module
|
||||
|
||||
|
||||
def optimizer_to(optimizer, device):
    """Move every tensor in ``optimizer``'s state to ``device`` (in place).

    Returns the same optimizer for convenience.
    """
    for state in optimizer.state.values():
        moved = {
            key: (value.to(device=device)
                  if isinstance(value, torch.Tensor) else value)
            for key, value in state.items()
        }
        state.update(moved)
    return optimizer
|
||||
960
src/unifolm_wma/models/diffusion_head/common/tensor_util.py
Normal file
960
src/unifolm_wma/models/diffusion_head/common/tensor_util.py
Normal file
@@ -0,0 +1,960 @@
|
||||
"""
|
||||
A collection of utilities for working with nested tensor structures consisting
|
||||
of numpy arrays and torch tensors.
|
||||
"""
|
||||
import collections
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
|
||||
def recursive_dict_list_tuple_apply(x, type_func_dict):
    """
    Recursively map the leaves of a nested dict/list/tuple through
    type-specific functions.

    Args:
        x (dict or list or tuple): a possibly nested dictionary or list or tuple
        type_func_dict (dict): maps a leaf data type to the function applied
            to leaves of that type. Container types (dict/list/tuple) are
            handled by the recursion itself and must not appear as keys.

    Returns:
        y (dict or list or tuple): new nested dict-list-tuple

    Raises:
        NotImplementedError: when a leaf's type has no handler.
    """
    assert (list not in type_func_dict)
    assert (tuple not in type_func_dict)
    assert (dict not in type_func_dict)

    if isinstance(x, (dict, collections.OrderedDict)):
        # Preserve ordered-ness of the container.
        mapped = collections.OrderedDict() if isinstance(
            x, collections.OrderedDict) else dict()
        for key, value in x.items():
            mapped[key] = recursive_dict_list_tuple_apply(value, type_func_dict)
        return mapped
    if isinstance(x, (list, tuple)):
        items = [recursive_dict_list_tuple_apply(v, type_func_dict) for v in x]
        return tuple(items) if isinstance(x, tuple) else items
    for leaf_type, leaf_func in type_func_dict.items():
        if isinstance(x, leaf_type):
            return leaf_func(x)
    raise NotImplementedError('Cannot handle data type %s' % str(type(x)))
|
||||
|
||||
|
||||
def map_tensor(x, func):
    """
    Apply ``func`` to every torch.Tensor leaf in a nested dict/list/tuple;
    ``None`` leaves pass through unchanged.

    Args:
        x (dict or list or tuple): a possibly nested dictionary or list or tuple
        func (function): function to apply to each tensor

    Returns:
        y (dict or list or tuple): new nested dict-list-tuple
    """
    handlers = {
        torch.Tensor: func,
        type(None): lambda v: v,
    }
    return recursive_dict_list_tuple_apply(x, handlers)
|
||||
|
||||
|
||||
def map_ndarray(x, func):
    """
    Apply ``func`` to every np.ndarray leaf in a nested dict/list/tuple;
    ``None`` leaves pass through unchanged.

    Args:
        x (dict or list or tuple): a possibly nested dictionary or list or tuple
        func (function): function to apply to each array

    Returns:
        y (dict or list or tuple): new nested dict-list-tuple
    """
    handlers = {
        np.ndarray: func,
        type(None): lambda v: v,
    }
    return recursive_dict_list_tuple_apply(x, handlers)
|
||||
|
||||
|
||||
def map_tensor_ndarray(x, tensor_func, ndarray_func):
    """
    Apply ``tensor_func`` to torch.Tensor leaves and ``ndarray_func`` to
    np.ndarray leaves of a nested dict/list/tuple; ``None`` passes through.

    Args:
        x (dict or list or tuple): a possibly nested dictionary or list or tuple
        tensor_func (function): function to apply to each tensor
        ndarray_func (function): function to apply to each array

    Returns:
        y (dict or list or tuple): new nested dict-list-tuple
    """
    handlers = {
        torch.Tensor: tensor_func,
        np.ndarray: ndarray_func,
        type(None): lambda v: v,
    }
    return recursive_dict_list_tuple_apply(x, handlers)
|
||||
|
||||
|
||||
def clone(x):
    """
    Clone every torch tensor (``.clone()``) and numpy array (``.copy()``) in
    a nested dict/list/tuple; ``None`` passes through.

    Args:
        x (dict or list or tuple): a possibly nested dictionary or list or tuple

    Returns:
        y (dict or list or tuple): new nested dict-list-tuple
    """
    handlers = {
        torch.Tensor: lambda t: t.clone(),
        np.ndarray: lambda a: a.copy(),
        type(None): lambda v: v,
    }
    return recursive_dict_list_tuple_apply(x, handlers)
|
||||
|
||||
|
||||
def detach(x):
    """
    Detach every torch tensor in a nested dict/list/tuple and return a new
    nested structure.

    Args:
        x (dict or list or tuple): a possibly nested dictionary or list or tuple

    Returns:
        y (dict or list or tuple): new nested dict-list-tuple
    """
    return recursive_dict_list_tuple_apply(
        x,
        {
            torch.Tensor: lambda x: x.detach(),
            # Consistency fix: every sibling helper here (clone, to_batch,
            # map_tensor, ...) passes None leaves through; the original
            # detach raised NotImplementedError on them.
            type(None): lambda x: x,
        })
|
||||
|
||||
|
||||
def to_batch(x):
    """
    Add a leading batch dimension of size 1 to every torch tensor and numpy
    array in a nested dict/list/tuple; ``None`` passes through.

    Args:
        x (dict or list or tuple): a possibly nested dictionary or list or tuple

    Returns:
        y (dict or list or tuple): new nested dict-list-tuple
    """
    handlers = {
        torch.Tensor: lambda t: t[None, ...],
        np.ndarray: lambda a: a[None, ...],
        type(None): lambda v: v,
    }
    return recursive_dict_list_tuple_apply(x, handlers)
|
||||
|
||||
|
||||
def to_sequence(x):
    """
    Insert a time dimension of size 1 at axis 1 for every torch tensor and
    numpy array in a nested dict/list/tuple; ``None`` passes through.

    Args:
        x (dict or list or tuple): a possibly nested dictionary or list or tuple

    Returns:
        y (dict or list or tuple): new nested dict-list-tuple
    """
    handlers = {
        torch.Tensor: lambda t: t[:, None, ...],
        np.ndarray: lambda a: a[:, None, ...],
        type(None): lambda v: v,
    }
    return recursive_dict_list_tuple_apply(x, handlers)
|
||||
|
||||
|
||||
def index_at_time(x, ind):
    """
    Select index ``ind`` along axis 1 (the time axis) of every torch tensor
    and numpy array in a nested dict/list/tuple; ``None`` passes through.

    Args:
        x (dict or list or tuple): a possibly nested dictionary or list or tuple
        ind (int): index

    Returns:
        y (dict or list or tuple): new nested dict-list-tuple
    """
    handlers = {
        torch.Tensor: lambda t: t[:, ind, ...],
        np.ndarray: lambda a: a[:, ind, ...],
        type(None): lambda v: v,
    }
    return recursive_dict_list_tuple_apply(x, handlers)
|
||||
|
||||
|
||||
def unsqueeze(x, dim):
    """
    Insert a size-1 dimension at axis ``dim`` for every torch tensor and
    numpy array in a nested dict/list/tuple; ``None`` passes through.

    Args:
        x (dict or list or tuple): a possibly nested dictionary or list or tuple
        dim (int): dimension

    Returns:
        y (dict or list or tuple): new nested dict-list-tuple
    """
    handlers = {
        torch.Tensor: lambda t: t.unsqueeze(dim=dim),
        np.ndarray: lambda a: np.expand_dims(a, axis=dim),
        type(None): lambda v: v,
    }
    return recursive_dict_list_tuple_apply(x, handlers)
|
||||
|
||||
|
||||
def contiguous(x):
    """
    Make every torch tensor and numpy array in a nested dict/list/tuple
    contiguous in memory; ``None`` passes through.

    Args:
        x (dict or list or tuple): a possibly nested dictionary or list or tuple

    Returns:
        y (dict or list or tuple): new nested dict-list-tuple
    """
    handlers = {
        torch.Tensor: lambda t: t.contiguous(),
        np.ndarray: lambda a: np.ascontiguousarray(a),
        type(None): lambda v: v,
    }
    return recursive_dict_list_tuple_apply(x, handlers)
|
||||
|
||||
|
||||
def to_device(x, device):
    """
    Send every torch tensor in a nested dict/list/tuple to ``device``;
    ``None`` passes through. Numpy arrays are not handled.

    Args:
        x (dict or list or tuple): a possibly nested dictionary or list or tuple
        device (torch.Device): device to send tensors to

    Returns:
        y (dict or list or tuple): new nested dict-list-tuple
    """
    handlers = {
        # Bind device as a default arg so each leaf call is self-contained.
        torch.Tensor: lambda t, d=device: t.to(d),
        type(None): lambda v: v,
    }
    return recursive_dict_list_tuple_apply(x, handlers)
|
||||
|
||||
|
||||
def to_tensor(x):
    """Convert every numpy array in a nested structure to a torch tensor.

    Existing tensors are passed through unchanged (torch.from_numpy shares
    memory with the source array).

    Args:
        x (dict or list or tuple): a possibly nested dictionary or list or tuple

    Returns:
        dict or list or tuple: new nested structure of tensors
    """
    handlers = {
        torch.Tensor: lambda t: t,
        np.ndarray: lambda a: torch.from_numpy(a),
        type(None): lambda v: v,
    }
    return recursive_dict_list_tuple_apply(x, handlers)
|
||||
|
||||
|
||||
def to_numpy(x):
    """Convert every torch tensor in a nested structure to a numpy array.

    Existing numpy arrays pass through unchanged; CUDA tensors are moved to
    the CPU first, and all tensors are detached from the autograd graph.

    Args:
        x (dict or list or tuple): a possibly nested dictionary or list or tuple

    Returns:
        dict or list or tuple: new nested structure of numpy arrays
    """

    def _as_numpy(t):
        t = t.detach()
        if t.is_cuda:
            t = t.cpu()
        return t.numpy()

    handlers = {
        torch.Tensor: _as_numpy,
        np.ndarray: lambda a: a,
        type(None): lambda v: v,
    }
    return recursive_dict_list_tuple_apply(x, handlers)
|
||||
|
||||
|
||||
def to_list(x):
    """Convert every tensor/array in a nested structure to a plain Python list.

    Useful for JSON encoding. CUDA tensors are moved to CPU and detached
    before conversion.

    Args:
        x (dict or list or tuple): a possibly nested dictionary or list or tuple

    Returns:
        dict or list or tuple: new nested structure with list leaves
    """

    def _tensor_to_list(t):
        t = t.detach()
        if t.is_cuda:
            t = t.cpu()
        return t.numpy().tolist()

    handlers = {
        torch.Tensor: _tensor_to_list,
        np.ndarray: lambda a: a.tolist(),
        type(None): lambda v: v,
    }
    return recursive_dict_list_tuple_apply(x, handlers)
|
||||
|
||||
|
||||
def to_float(x):
    """Cast every tensor/array in a nested structure to 32-bit float.

    Args:
        x (dict or list or tuple): a possibly nested dictionary or list or tuple

    Returns:
        dict or list or tuple: new nested structure with float leaves
    """
    handlers = {
        torch.Tensor: lambda t: t.float(),
        np.ndarray: lambda a: a.astype(np.float32),
        type(None): lambda v: v,
    }
    return recursive_dict_list_tuple_apply(x, handlers)
|
||||
|
||||
|
||||
def to_uint8(x):
    """Cast every tensor/array in a nested structure to uint8.

    Args:
        x (dict or list or tuple): a possibly nested dictionary or list or tuple

    Returns:
        dict or list or tuple: new nested structure with uint8 leaves
    """
    handlers = {
        torch.Tensor: lambda t: t.byte(),
        np.ndarray: lambda a: a.astype(np.uint8),
        type(None): lambda v: v,
    }
    return recursive_dict_list_tuple_apply(x, handlers)
|
||||
|
||||
|
||||
def to_torch(x, device):
    """Convert a nested structure to float32 torch tensors on @device.

    Equivalent to to_tensor -> to_float -> to_device applied in order.

    Args:
        x (dict or list or tuple): a possibly nested dictionary or list or tuple
        device (torch.device): target device

    Returns:
        dict or list or tuple: new nested structure of float tensors on @device
    """
    tensors = to_tensor(x)
    floats = to_float(tensors)
    return to_device(floats, device)
|
||||
|
||||
|
||||
def to_one_hot_single(tensor, num_class):
    """Convert an integer-label tensor to a float one-hot representation.

    Args:
        tensor (torch.Tensor): tensor of integer labels
        num_class (int): total number of classes

    Returns:
        torch.Tensor: one-hot tensor with a trailing axis of size @num_class,
            on the same device as @tensor
    """
    out_shape = tensor.size() + (num_class, )
    out = torch.zeros(out_shape, device=tensor.device)
    # place a 1 at each label position along the new trailing axis
    out.scatter_(-1, tensor.unsqueeze(-1), 1)
    return out
|
||||
|
||||
|
||||
def to_one_hot(tensor, num_class):
    """Convert every tensor in a nested structure to a one-hot representation.

    Args:
        tensor (dict or list or tuple): a possibly nested dictionary or list or tuple
        num_class (int): total number of classes

    Returns:
        dict or list or tuple: new nested structure of one-hot tensors
    """
    return map_tensor(tensor, func=lambda t: to_one_hot_single(t, num_class))
|
||||
|
||||
|
||||
def flatten_single(x, begin_axis=1):
    """Flatten a tensor from @begin_axis onwards into one dimension.

    Args:
        x (torch.Tensor): tensor to flatten
        begin_axis (int): first axis to fold into the flattened dimension

    Returns:
        torch.Tensor: flattened tensor
    """
    lead_shape = list(x.size()[:begin_axis])
    return x.reshape(*(lead_shape + [-1]))
|
||||
|
||||
|
||||
def flatten(x, begin_axis=1):
    """Flatten every tensor in a nested structure from @begin_axis onwards.

    Args:
        x (dict or list or tuple): a possibly nested dictionary or list or tuple
        begin_axis (int): first axis to fold into the flattened dimension

    Returns:
        dict or list or tuple: new nested structure with flattened tensors
    """
    handlers = {
        torch.Tensor: lambda t: flatten_single(t, begin_axis=begin_axis),
    }
    return recursive_dict_list_tuple_apply(x, handlers)
|
||||
|
||||
|
||||
def reshape_dimensions_single(x, begin_axis, end_axis, target_dims):
    """Reshape the dimension range [@begin_axis, @end_axis] to @target_dims.

    Args:
        x (torch.Tensor): tensor to reshape
        begin_axis (int): first dimension of the range (inclusive)
        end_axis (int): last dimension of the range (inclusive)
        target_dims (tuple or list): replacement shape for that range

    Returns:
        torch.Tensor: reshaped tensor
    """
    assert (begin_axis <= end_axis)
    assert (begin_axis >= 0)
    assert (end_axis < len(x.shape))
    assert (isinstance(target_dims, (tuple, list)))
    shape = list(x.shape)
    # keep dims before the range, splice in target_dims, keep dims after it
    new_shape = shape[:begin_axis] + list(target_dims) + shape[end_axis + 1:]
    return x.reshape(*new_shape)
|
||||
|
||||
|
||||
def reshape_dimensions(x, begin_axis, end_axis, target_dims):
    """Reshape a dimension range of every tensor/array in a nested structure.

    Args:
        x (dict or list or tuple): a possibly nested dictionary or list or tuple
        begin_axis (int): first dimension of the range (inclusive)
        end_axis (int): last dimension of the range (inclusive)
        target_dims (tuple or list): replacement shape for that range

    Returns:
        dict or list or tuple: new nested structure with reshaped leaves
    """

    def _reshape(arr):
        return reshape_dimensions_single(arr,
                                         begin_axis=begin_axis,
                                         end_axis=end_axis,
                                         target_dims=target_dims)

    return recursive_dict_list_tuple_apply(x, {
        torch.Tensor: _reshape,
        np.ndarray: _reshape,
        type(None): lambda v: v,
    })
|
||||
|
||||
|
||||
def join_dimensions(x, begin_axis, end_axis):
    """Merge the dimension range [@begin_axis, @end_axis] into one flat axis
    for every tensor/array in a nested structure.

    Args:
        x (dict or list or tuple): a possibly nested dictionary or list or tuple
        begin_axis (int): first dimension of the range (inclusive)
        end_axis (int): last dimension of the range (inclusive)

    Returns:
        dict or list or tuple: new nested structure with merged leaves
    """

    def _join(arr):
        return reshape_dimensions_single(arr,
                                         begin_axis=begin_axis,
                                         end_axis=end_axis,
                                         target_dims=[-1])

    return recursive_dict_list_tuple_apply(x, {
        torch.Tensor: _join,
        np.ndarray: _join,
        type(None): lambda v: v,
    })
|
||||
|
||||
|
||||
def expand_at_single(x, size, dim):
    """Expand a singleton dimension @dim of a tensor to @size (no copy).

    Args:
        x (torch.Tensor): input tensor; must have size 1 at @dim
        size (int): target size at @dim
        dim (int): dimension to expand

    Returns:
        torch.Tensor: expanded view of @x
    """
    assert dim < x.ndimension()
    assert x.shape[dim] == 1
    # -1 leaves a dimension unchanged; only @dim is expanded
    reps = [size if axis == dim else -1 for axis in range(x.ndimension())]
    return x.expand(*reps)
|
||||
|
||||
|
||||
def expand_at(x, size, dim):
    """Expand every tensor in a nested structure at dimension @dim to @size.

    Args:
        x (dict or list or tuple): a possibly nested dictionary or list or tuple
        size (int): target size at @dim
        dim (int): dimension to expand

    Returns:
        dict or list or tuple: new nested structure with expanded tensors
    """
    return map_tensor(x, lambda t: expand_at_single(t, size, dim))
|
||||
|
||||
|
||||
def unsqueeze_expand_at(x, size, dim):
    """Insert a new axis at @dim and expand it to @size for every tensor.

    Args:
        x (dict or list or tuple): a possibly nested dictionary or list or tuple
        size (int): target size of the new axis
        dim (int): position of the new axis

    Returns:
        dict or list or tuple: new nested structure
    """
    return expand_at(unsqueeze(x, dim), size, dim)
|
||||
|
||||
|
||||
def repeat_by_expand_at(x, repeats, dim):
    """Repeat dimension @dim of every tensor @repeats times via expand+reshape.

    Args:
        x (dict or list or tuple): a possibly nested dictionary or list or tuple
        repeats (int): number of repetitions of @dim
        dim (int): dimension to repeat

    Returns:
        dict or list or tuple: new nested structure
    """
    expanded = unsqueeze_expand_at(x, repeats, dim + 1)
    return join_dimensions(expanded, dim, dim + 1)
|
||||
|
||||
|
||||
def named_reduce_single(x, reduction, dim):
    """Reduce a tensor at @dim using a reduction named by string.

    Args:
        x (torch.Tensor): tensor to reduce
        reduction (str): one of ["sum", "max", "mean", "flatten"]
        dim (int): dimension to reduce (begin axis for "flatten")

    Returns:
        torch.Tensor: reduced tensor
    """
    assert x.ndimension() > dim
    assert reduction in ["sum", "max", "mean", "flatten"]
    if reduction == "flatten":
        return flatten(x, begin_axis=dim)
    if reduction == "max":
        # torch.max with dim returns (values, indices); keep values only
        return torch.max(x, dim=dim)[0]
    if reduction == "sum":
        return torch.sum(x, dim=dim)
    return torch.mean(x, dim=dim)
|
||||
|
||||
|
||||
def named_reduce(x, reduction, dim):
    """Apply a named reduction at @dim to every tensor in a nested structure.

    Args:
        x (dict or list or tuple): a possibly nested dictionary or list or tuple
        reduction (str): one of ["sum", "max", "mean", "flatten"]
        dim (int): dimension to reduce (begin axis for "flatten")

    Returns:
        dict or list or tuple: new nested structure of reduced tensors
    """
    return map_tensor(x, func=lambda t: named_reduce_single(t, reduction, dim))
|
||||
|
||||
|
||||
def gather_along_dim_with_dim_single(x, target_dim, source_dim, indices):
    """Index out @target_dim of @x, choosing a different position per entry of
    @source_dim.

    Moving along @source_dim, the corresponding value in @indices selects the
    slice of @target_dim used for all remaining dimensions. The common case is
    target_dim=1, source_dim=0: pick one timestep per batch member.

    Args:
        x (torch.Tensor): tensor to gather from
        target_dim (int): dimension to index out
        source_dim (int): dimension the flat @indices runs along
        indices (torch.Tensor): 1-D index tensor matching x.shape[source_dim]

    Returns:
        torch.Tensor: gathered tensor with @target_dim removed
    """
    assert len(indices.shape) == 1
    assert x.shape[source_dim] == indices.shape[0]

    # broadcast-ready view: singleton everywhere except the source dimension
    view_shape = [1] * x.ndimension()
    view_shape[source_dim] = -1
    idx = indices.reshape(*view_shape)

    # expand to x's shape, but keep a singleton at the target dimension
    # (gather selects exactly one entry along it)
    tile_shape = list(x.shape)
    tile_shape[source_dim] = -1
    tile_shape[target_dim] = 1
    idx = idx.expand(*tile_shape)

    gathered = x.gather(dim=target_dim, index=idx)
    return gathered.squeeze(target_dim)
|
||||
|
||||
|
||||
def gather_along_dim_with_dim(x, target_dim, source_dim, indices):
    """Apply gather_along_dim_with_dim_single to every tensor in a nested
    structure.

    Args:
        x (dict or list or tuple): a possibly nested dictionary or list or tuple
        target_dim (int): dimension to index out
        source_dim (int): dimension the flat @indices runs along
        indices (torch.Tensor): 1-D index tensor matching each tensor's
            size at @source_dim

    Returns:
        dict or list or tuple: new nested structure
    """
    return map_tensor(
        x,
        lambda t: gather_along_dim_with_dim_single(t, target_dim, source_dim,
                                                   indices))
|
||||
|
||||
|
||||
def gather_sequence_single(seq, indices):
    """Pick one timestep per batch member from a [B, T, ...] tensor.

    Args:
        seq (torch.Tensor): tensor with leading dimensions [B, T, ...]
        indices (torch.Tensor): per-sequence time indices of shape [B]

    Returns:
        torch.Tensor: indexed tensor of shape [B, ...]
    """
    return gather_along_dim_with_dim_single(seq,
                                            target_dim=1,
                                            source_dim=0,
                                            indices=indices)
|
||||
|
||||
|
||||
def gather_sequence(seq, indices):
    """Pick one timestep per batch member from every [B, T, ...] tensor in a
    nested structure.

    Args:
        seq (dict or list or tuple): nested structure of [B, T, ...] tensors
        indices (torch.Tensor): per-sequence time indices of shape [B]

    Returns:
        dict or list or tuple: new nested structure of [B, ...] tensors
    """
    return gather_along_dim_with_dim(seq,
                                     target_dim=1,
                                     source_dim=0,
                                     indices=indices)
|
||||
|
||||
|
||||
def pad_sequence_single(seq,
                        padding,
                        batched=False,
                        pad_same=True,
                        pad_values=None):
    """Pad a tensor or array along the time dimension.

    Args:
        seq (np.ndarray or torch.Tensor): sequence to pad
        padding (tuple): (begin, end) pad lengths, e.g. [1, 1] pads one step
            at each end
        batched (bool): whether @seq has a leading batch dimension
            (NOTE(review): even when batched, the edge slice is taken along
            dimension 0 -- confirm the batched path against callers)
        pad_same (bool): pad by duplicating the edge step
        pad_values (float): constant fill value when not pad_same

    Returns:
        np.ndarray or torch.Tensor: padded sequence
    """
    assert isinstance(seq, (np.ndarray, torch.Tensor))
    assert pad_same or pad_values is not None
    if pad_values is not None:
        assert isinstance(pad_values, float)

    if isinstance(seq, np.ndarray):
        repeat_fn, concat_fn, ones_fn = np.repeat, np.concatenate, np.ones_like
    else:
        repeat_fn = torch.repeat_interleave
        concat_fn = torch.cat
        ones_fn = torch.ones_like
    axis = 1 if batched else 0

    def _edge_pad(edge_slice, count):
        # duplicate the edge step, or fill with the constant pad value
        frame = edge_slice if pad_same else ones_fn(edge_slice) * pad_values
        return repeat_fn(frame, count, axis)

    pieces = []
    if padding[0] > 0:
        pieces.append(_edge_pad(seq[[0]], padding[0]))
    pieces.append(seq)
    if padding[1] > 0:
        pieces.append(_edge_pad(seq[[-1]], padding[1]))
    return concat_fn(pieces, axis)
|
||||
|
||||
|
||||
def pad_sequence(seq, padding, batched=False, pad_same=True, pad_values=None):
    """Pad every sequence tensor/array in a nested structure along time.

    Args:
        seq (dict or list or tuple): nested structure of [B, T, ...] tensors
        padding (tuple): (begin, end) pad lengths
        batched (bool): whether sequences carry a leading batch dimension
        pad_same (bool): pad by duplicating the edge step
        pad_values (float): constant fill value when not pad_same

    Returns:
        dict or list or tuple: new nested structure of padded sequences
    """

    def _pad(arr):
        return pad_sequence_single(arr, padding, batched, pad_same, pad_values)

    return recursive_dict_list_tuple_apply(seq, {
        torch.Tensor: _pad,
        np.ndarray: _pad,
        type(None): lambda v: v,
    })
|
||||
|
||||
|
||||
def assert_size_at_dim_single(x, size, dim, msg):
    """Assert that @x has size @size at dimension @dim.

    Args:
        x (np.ndarray or torch.Tensor): array or tensor to check
        size (int): expected size at @dim
        dim (int): dimension to check
        msg (str): assertion failure message
    """
    actual = x.shape[dim]
    assert actual == size, msg
|
||||
|
||||
|
||||
def assert_size_at_dim(x, size, dim, msg):
    """Assert every tensor/array in a nested structure has size @size at @dim.

    Args:
        x (dict or list or tuple): a possibly nested dictionary or list or tuple
        size (int): expected size at @dim
        dim (int): dimension to check
        msg (str): assertion failure message
    """
    map_tensor(x, lambda t: assert_size_at_dim_single(t, size, dim, msg))
|
||||
|
||||
|
||||
def get_shape(x):
    """Collect the shape of every tensor/array in a nested structure.

    Args:
        x (dict or list or tuple): a possibly nested dictionary or list or tuple

    Returns:
        dict or list or tuple: same structure with each leaf replaced by
            its shape
    """
    handlers = {
        torch.Tensor: lambda t: t.shape,
        np.ndarray: lambda a: a.shape,
        type(None): lambda v: v,
    }
    return recursive_dict_list_tuple_apply(x, handlers)
|
||||
|
||||
|
||||
def list_of_flat_dict_to_dict_of_list(list_of_dict):
    """Turn a list of flat dictionaries into an ordered dictionary of lists.

    "Flat" means values are leaves (arrays, floats, ...), not dictionaries.
    Keys appear in first-seen order; a key missing from some dicts simply
    collects fewer entries.

    Args:
        list_of_dict (list): list of flat dictionaries

    Returns:
        collections.OrderedDict: mapping from key to list of values
    """
    assert isinstance(list_of_dict, list)
    merged = collections.OrderedDict()
    for entry in list_of_dict:
        for key, value in entry.items():
            merged.setdefault(key, []).append(value)
    return merged
|
||||
|
||||
|
||||
def flatten_nested_dict_list(d, parent_key='', sep='_', item_key=''):
    """Flatten a nested dict/list into a flat list of (joined_key, value).

    For example, {"a": 1, "b": {"c": 2}, "c": 3} becomes
    [("a", 1), ("b_c", 2), ("c", 3)].

    Args:
        d (dict or list): nested structure to flatten
        parent_key (str): recursion helper (accumulated key prefix)
        sep (str): separator joining nested keys
        item_key (str): recursion helper (key of the current item)

    Returns:
        list: list of (key, value) tuples
    """
    # every branch builds the same joined key, so compute it once up front
    new_key = parent_key + sep + item_key if len(parent_key) > 0 else item_key
    if isinstance(d, (tuple, list)):
        flat = []
        for i, v in enumerate(d):
            flat.extend(
                flatten_nested_dict_list(v, new_key, sep=sep, item_key=str(i)))
        return flat
    if isinstance(d, dict):
        flat = []
        for k, v in d.items():
            assert isinstance(k, str)
            flat.extend(
                flatten_nested_dict_list(v, new_key, sep=sep, item_key=k))
        return flat
    # leaf value
    return [(new_key, d)]
|
||||
|
||||
|
||||
def time_distributed(inputs,
                     op,
                     activation=None,
                     inputs_as_kwargs=False,
                     inputs_as_args=False,
                     **kwargs):
    """Apply @op over batch and time jointly for [B, T, ...] inputs.

    Folds the leading [B, T] dimensions of every tensor into [B * T, ...],
    runs @op once, then unfolds the outputs back to [B, T, ...].

    Args:
        inputs (dict or list or tuple): nested structure of [B, T, ...] tensors
        op: layer/op to apply to the folded inputs
        activation: optional activation applied to every output tensor
        inputs_as_kwargs (bool): feed the folded inputs as **kwargs to @op
        inputs_as_args (bool): feed the folded inputs as *args to @op
        kwargs (dict): extra keyword arguments forwarded to @op

    Returns:
        dict or list or tuple: nested structure of [B, T, ...] outputs
    """
    # B and T are read off the first tensor found in the nested inputs
    batch_size, seq_len = flatten_nested_dict_list(inputs)[0][1].shape[:2]
    merged = join_dimensions(inputs, 0, 1)

    if inputs_as_kwargs:
        outputs = op(**merged, **kwargs)
    elif inputs_as_args:
        outputs = op(*merged, **kwargs)
    else:
        outputs = op(merged, **kwargs)

    if activation is not None:
        outputs = map_tensor(outputs, activation)
    # unfold [B * T, ...] back into [B, T, ...]
    return reshape_dimensions(outputs,
                              begin_axis=0,
                              end_axis=0,
                              target_dims=(batch_size, seq_len))
|
||||
701
src/unifolm_wma/models/diffusion_head/conditional_unet1d.py
Normal file
701
src/unifolm_wma/models/diffusion_head/conditional_unet1d.py
Normal file
@@ -0,0 +1,701 @@
|
||||
import logging
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import einops
|
||||
|
||||
from einops import rearrange, repeat
|
||||
from typing import Union
|
||||
|
||||
from unifolm_wma.models.diffusion_head.conv1d_components import (
|
||||
Downsample1d, Upsample1d, Conv1dBlock)
|
||||
from unifolm_wma.models.diffusion_head.positional_embedding import SinusoidalPosEmb
|
||||
from unifolm_wma.models.diffusion_head.base_nets import SpatialSoftmax
|
||||
|
||||
from unifolm_wma.utils.basics import zero_module
|
||||
from unifolm_wma.utils.common import (
|
||||
checkpoint,
|
||||
exists,
|
||||
default,
|
||||
)
|
||||
from unifolm_wma.utils.utils import instantiate_from_config
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class GEGLU(nn.Module):
    """Gated GELU activation layer (https://arxiv.org/abs/2002.05202).

    Projects the input to twice the output width, splits it into a value and
    a gate half, and returns value * GELU(gate).
    """

    def __init__(self, dim_in, dim_out):
        """
        Args:
            dim_in (int): input feature dimension.
            dim_out (int): output feature dimension.
        """
        super().__init__()
        self.proj = nn.Linear(dim_in, dim_out * 2)

    def forward(self, x):
        x, gate = self.proj(x).chunk(2, dim=-1)
        # BUGFIX: this file never imports `torch.nn.functional as F`, so the
        # original `F.gelu(gate)` raised NameError at runtime; use the
        # fully-qualified path instead.
        return x * torch.nn.functional.gelu(gate)
|
||||
|
||||
|
||||
class FeedForward(nn.Module):
    """Position-wise feed-forward layer, optionally gated.

    Expands the input to `dim * mult` features with GELU (or GEGLU when
    `glu=True`), applies dropout, then projects to `dim_out` (defaults to
    `dim`).
    """

    def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.):
        super().__init__()
        hidden = int(dim * mult)
        dim_out = default(dim_out, dim)
        if glu:
            first_stage = GEGLU(dim, hidden)
        else:
            first_stage = nn.Sequential(nn.Linear(dim, hidden), nn.GELU())

        self.net = nn.Sequential(
            first_stage,
            nn.Dropout(dropout),
            nn.Linear(hidden, dim_out),
        )

    def forward(self, x):
        return self.net(x)
|
||||
|
||||
|
||||
class CrossAttention(nn.Module):
    """Multi-head (cross-)attention using xformers' memory-efficient kernel.

    When no `context` is given the layer performs self-attention; otherwise
    queries come from `x` and keys/values from `context`.

    NOTE(review): `xformers` is not imported at the top of this file and the
    `relative_position` argument is accepted but never used -- confirm both.
    """

    def __init__(self,
                 query_dim,
                 context_dim=None,
                 heads=8,
                 dim_head=64,
                 dropout=0.,
                 relative_position=False):
        super().__init__()
        inner_dim = dim_head * heads
        # fall back to self-attention dimensions when no context dim given
        context_dim = default(context_dim, query_dim)

        self.scale = dim_head**-0.5
        self.heads = heads
        self.dim_head = dim_head

        # bias-free projections for queries, keys, and values
        self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
        self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
        self.to_v = nn.Linear(context_dim, inner_dim, bias=False)

        self.to_out = nn.Sequential(nn.Linear(inner_dim, query_dim),
                                    nn.Dropout(dropout))

    def efficient_forward(self, x, context=None):
        """Compute attention via xformers.ops.memory_efficient_attention.

        Args:
            x: query tensor of shape (b, n, query_dim).
            context: optional key/value source; falls back to `x`
                (self-attention) when None.

        Returns:
            Tensor of shape (b, n, query_dim).
        """
        spatial_self_attn = (context is None)
        # placeholders for an image-prompt branch that is not used here
        k_ip, v_ip, out_ip = None, None, None

        q = self.to_q(x)
        if spatial_self_attn:
            context = default(context, x)
        k = self.to_k(context)
        v = self.to_v(context)

        b, _, _ = q.shape
        # reshape (b, n, heads * dim_head) -> (b * heads, n, dim_head),
        # the layout xformers expects
        q, k, v = map(
            lambda t: t.unsqueeze(3).reshape(b, t.shape[
                1], self.heads, self.dim_head).permute(0, 2, 1, 3).reshape(
                    b * self.heads, t.shape[1], self.dim_head).contiguous(),
            (q, k, v),
        )
        # actually compute the attention, what we cannot get enough of
        out = xformers.ops.memory_efficient_attention(q,
                                                      k,
                                                      v,
                                                      attn_bias=None,
                                                      op=None)
        # inverse reshape: (b * heads, n, dim_head) -> (b, n, heads * dim_head)
        out = (out.unsqueeze(0).reshape(
            b, self.heads, out.shape[1],
            self.dim_head).permute(0, 2, 1,
                                   3).reshape(b, out.shape[1],
                                              self.heads * self.dim_head))
        return self.to_out(out)
|
||||
|
||||
|
||||
class BasicTransformerBlock(nn.Module):
    """Transformer block: self-attention, cross-attention, feed-forward.

    Each sub-layer uses a pre-LayerNorm and a residual connection. When
    `disable_self_attn` is True, the first attention also receives `context`,
    making both attention layers cross-attention.
    """

    def __init__(self,
                 dim,
                 n_heads,
                 d_head,
                 dropout=0.,
                 context_dim=None,
                 gated_ff=True,
                 checkpoint=True,
                 disable_self_attn=False,
                 attention_cls=None):
        super().__init__()
        # default to this module's CrossAttention unless a class is injected
        attn_cls = CrossAttention if attention_cls is None else attention_cls
        self.disable_self_attn = disable_self_attn
        self.attn1 = attn_cls(
            query_dim=dim,
            heads=n_heads,
            dim_head=d_head,
            dropout=dropout,
            context_dim=context_dim if self.disable_self_attn else None)
        self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff)
        self.attn2 = attn_cls(query_dim=dim,
                              context_dim=context_dim,
                              heads=n_heads,
                              dim_head=d_head,
                              dropout=dropout)

        self.norm1 = nn.LayerNorm(dim)
        self.norm2 = nn.LayerNorm(dim)
        self.norm3 = nn.LayerNorm(dim)
        # whether to run _forward under gradient checkpointing
        self.checkpoint = checkpoint

    def forward(self, x, context=None, **kwargs):
        ## implementation tricks: because checkpointing doesn't support non-tensor (e.g. None or scalar) arguments
        input_tuple = (
            x,
        ) ## should not be (x), otherwise *input_tuple will decouple x into multiple arguments
        if context is not None:
            input_tuple = (x, context)
        return checkpoint(self._forward, input_tuple, self.parameters(),
                          self.checkpoint)

    def _forward(self, x, context=None, mask=None):
        # self-attention (or cross-attention when disable_self_attn is set)
        x = self.attn1(self.norm1(x),
                       context=context if self.disable_self_attn else None,
                       mask=mask) + x
        # cross-attention over the conditioning context
        x = self.attn2(self.norm2(x), context=context, mask=mask) + x
        x = self.ff(self.norm3(x)) + x
        return x
|
||||
|
||||
|
||||
class ActionLatentImageCrossAttention(nn.Module):
    """Cross-attention from action latents to flattened video features.

    Action tokens (queries) attend to space-time-flattened image features
    (keys/values) through a stack of BasicTransformerBlocks; the result is
    projected back and added residually to the input. `proj_out` is
    zero-initialized, so the module starts out as an identity mapping.
    """

    def __init__(self,
                 in_channels,
                 in_dim,
                 n_heads,
                 d_head,
                 depth=1,
                 dropout=0.,
                 context_dim=None,
                 use_checkpoint=True,
                 disable_self_attn=False,
                 use_linear=True):
        super().__init__()
        """
        in_channels: action input dim

        """
        # number of action tokens; GroupNorm normalizes over this axis
        self.in_channels = in_channels
        # per-token feature dimension of the action input
        self.in_dim = in_dim
        inner_dim = n_heads * d_head
        self.norm = torch.nn.GroupNorm(num_groups=8,
                                       num_channels=in_channels,
                                       eps=1e-6,
                                       affine=True)

        # project action tokens and conditioning features into attention width
        self.proj_in_action = nn.Linear(in_dim, inner_dim)
        self.proj_in_cond = nn.Linear(context_dim, inner_dim)
        # zero-init so the residual branch contributes nothing at start
        self.proj_out = zero_module(nn.Linear(inner_dim, in_dim))
        self.use_linear = use_linear

        attention_cls = None
        self.transformer_blocks = nn.ModuleList([
            BasicTransformerBlock(inner_dim,
                                  n_heads,
                                  d_head,
                                  dropout=dropout,
                                  context_dim=context_dim,
                                  disable_self_attn=disable_self_attn,
                                  checkpoint=use_checkpoint,
                                  attention_cls=attention_cls)
            for d in range(depth)
        ])

    def forward(self, x, context=None, **kwargs):
        """Args:
            x: action latents of shape (ba, ca, da).
            context: video features of shape (b, t, c, h, w); required
                despite the None default -- its shape is read unconditionally.

        Returns:
            Tensor shaped like `x` (residual output).
        """
        ba, ca, da = x.shape
        b, t, c, h, w = context.shape
        # flatten space-time into one token axis for attention
        context = rearrange(context, 'b t c h w -> b (t h w) c').contiguous()

        x_in = x
        x = self.norm(x)  # ba x ja x d_in
        if self.use_linear:
            x = self.proj_in_action(x)
            context = self.proj_in_cond(context)
        for i, block in enumerate(self.transformer_blocks):
            x = block(x, context=context, **kwargs)
        if self.use_linear:
            x = self.proj_out(x)
        # NOTE(review): with use_linear=False the residual add would combine
        # inner_dim with in_dim features -- presumably use_linear is always
        # True in practice; confirm against callers.
        return x + x_in
|
||||
|
||||
|
||||
class ConditionalResidualBlock1D(nn.Module):
    """1D convolutional residual block with FiLM conditioning.

    Runs two Conv1dBlocks; after the first, the conditioning vector is
    projected to a per-channel scale and bias (FiLM,
    https://arxiv.org/abs/1709.07871) that modulate the activations. A 1x1
    convolution aligns channels for the residual connection when needed.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 cond_dim,
                 kernel_size=3,
                 n_groups=8,
                 cond_predict_scale=True,
                 use_linear_act_proj=False):
        """
        Args:
            in_channels (int): input channel count.
            out_channels (int): output channel count.
            cond_dim (int): dimension of the conditioning vector.
            kernel_size (int): conv kernel size for both Conv1dBlocks.
            n_groups (int): group count passed to Conv1dBlock.
            cond_predict_scale (bool): if True, FiLM predicts both a
                per-channel scale and bias from the condition.
            use_linear_act_proj (bool): if True, forward() folds the (B, T)
                condition into the batch before the FiLM reshape.
        """
        super().__init__()

        self.blocks = nn.ModuleList([
            Conv1dBlock(in_channels,
                        out_channels,
                        kernel_size,
                        n_groups=n_groups),
            Conv1dBlock(out_channels,
                        out_channels,
                        kernel_size,
                        n_groups=n_groups),
        ])

        self.cond_predict_scale = cond_predict_scale
        self.use_linear_act_proj = use_linear_act_proj
        self.out_channels = out_channels
        # FiLM modulation https://arxiv.org/abs/1709.07871
        # predicts per-channel scale and bias
        cond_channels = out_channels
        if cond_predict_scale:
            # BUGFIX: forward() reshapes the embedding to (..., 2, C, 1)
            # whenever cond_predict_scale is True, regardless of
            # use_linear_act_proj -- so the projection must always emit
            # 2 * out_channels in that case. Previously the width was only
            # doubled when use_linear_act_proj was also True, which made the
            # cond_predict_scale=True / use_linear_act_proj=False path crash
            # at the reshape.
            cond_channels = out_channels * 2
        self.cond_encoder = nn.Sequential(
            nn.Mish(),
            nn.Linear(cond_dim, cond_channels),
        )
        # make sure dimensions compatible for the residual connection
        self.residual_conv = nn.Conv1d(in_channels, out_channels, 1) \
            if in_channels != out_channels else nn.Identity()

    def forward(self, x, cond=None):
        '''
        x : [ batch_size x in_channels x horizon ]
        cond : [ batch_size x T x cond_dim ] -- required despite the None
            default (its shape is read unconditionally)

        returns:
        out : [ batch_size x out_channels x horizon ]
        '''
        B, T, _ = cond.shape

        out = self.blocks[0](x)
        if self.cond_predict_scale:
            embed = self.cond_encoder(cond)
            if self.use_linear_act_proj:
                # fold the time dimension into the batch before FiLM
                embed = embed.reshape(B * T, -1)
                embed = embed.reshape(-1, 2, self.out_channels, 1)
            else:
                embed = embed.reshape(embed.shape[0], 2, self.out_channels, 1)
            scale = embed[:, 0, ...]
            bias = embed[:, 1, ...]
            out = scale * out + bias
        out = self.blocks[1](out)
        out = out + self.residual_conv(x)
        return out
|
||||
|
||||
|
||||
class ConditionalUnet1D(nn.Module):
    """1-D conditional U-Net diffusion head for action trajectories.

    Denoises an action sequence (B, T, input_dim) conditioned on:
      * the diffusion timestep (sinusoidal embedding),
      * encoded observations from ``obs_encoder``, and
      * intermediate feature maps from a video-generation U-Net
        (``imagen_cond``), pooled into keypoints with SpatialSoftmax.
    """

    def __init__(self,
                 input_dim,
                 n_obs_steps=1,
                 local_cond_dim=None,
                 global_cond_dim=None,
                 diffusion_step_embed_dim=256,
                 down_dims=[256, 512, 1024],  # NOTE(review): mutable default — safe only because it is never mutated
                 kernel_size=3,
                 n_groups=8,
                 cond_predict_scale=False,
                 horizon=16,
                 num_head_channels=64,
                 use_linear_attn=True,
                 use_linear_act_proj=True,
                 act_proj_dim=32,
                 cond_cross_attention=False,
                 context_dims=None,
                 image_size=None,
                 imagen_cond_gradient=False,
                 last_frame_only=False,
                 use_imagen_mid_only=False,
                 use_z_only=False,
                 spatial_num_kp=32,
                 obs_encoder_config=None):
        # input_dim: action dimensionality per timestep.
        # context_dims: per-stage channel counts of the video U-Net features;
        #   required unless use_z_only (not validated here).
        # local_cond_dim / global_cond_dim are currently unused (kept for
        #   interface compatibility with upstream diffusion-policy code).
        super().__init__()

        self.n_obs_steps = n_obs_steps
        self.obs_encoder = instantiate_from_config(obs_encoder_config)

        all_dims = [input_dim] + list(down_dims)
        start_dim = down_dims[0]

        dsed = diffusion_step_embed_dim
        # Sinusoidal timestep embedding followed by a small MLP.
        diffusion_step_encoder = nn.Sequential(
            SinusoidalPosEmb(dsed),
            nn.Linear(dsed, dsed * 4),
            nn.Mish(),
            nn.Linear(dsed * 4, dsed),
        )
        # Base conditioning size: timestep embedding + flattened obs features.
        cond_dim = dsed + self.obs_encoder.output_shape()[-1] * self.n_obs_steps
        in_out = list(zip(all_dims[:-1], all_dims[1:]))
        local_cond_encoder = None
        down_modules = nn.ModuleList([])

        # dim_a: per-stage action-token length fed to the cross-attention.
        dim_a_list = []
        for ind, (dim_in, dim_out) in enumerate(in_out):
            is_last = ind >= (len(in_out) - 1)
            if ind == 0:
                dim_a = horizon
            else:
                dim_a = horizon // 2 * ind
            dim_a_list.append(dim_a)

            # for attention
            num_heads = dim_out // num_head_channels
            dim_head = num_head_channels
            # Per-stage conditioning width: base cond + pooled video features.
            if use_linear_act_proj:
                if use_imagen_mid_only:
                    cur_cond_dim = cond_dim + 2 * context_dims[-1]
                elif use_z_only:
                    cur_cond_dim = cond_dim + 2 * spatial_num_kp
                else:
                    cur_cond_dim = cond_dim + 2 * context_dims[ind]
            else:
                cur_cond_dim = cond_dim + horizon * context_dims[ind]

            down_modules.append(
                nn.ModuleList([
                    ConditionalResidualBlock1D(
                        dim_in,
                        dim_out,
                        cond_dim=cur_cond_dim,
                        kernel_size=kernel_size,
                        n_groups=n_groups,
                        cond_predict_scale=cond_predict_scale,
                        use_linear_act_proj=use_linear_act_proj),
                    ConditionalResidualBlock1D(
                        dim_out,
                        dim_out,
                        cond_dim=cur_cond_dim,
                        kernel_size=kernel_size,
                        n_groups=n_groups,
                        cond_predict_scale=cond_predict_scale,
                        use_linear_act_proj=use_linear_act_proj),
                    ActionLatentImageCrossAttention(
                        dim_out,
                        dim_a,
                        num_heads,
                        dim_head,
                        context_dim=context_dims[ind],
                        use_linear=use_linear_attn)
                    if cond_cross_attention else nn.Identity(),
                    Downsample1d(dim_out) if not is_last else nn.Identity()
                ]))

        # Bottleneck blocks. NOTE(review): `cur_cond_dim`, `num_heads` and
        # `dim_head` here are the values leaked from the LAST iteration of the
        # down loop above — intentional for the mid stage, but fragile.
        mid_dim = all_dims[-1]
        self.mid_modules = nn.ModuleList([
            ConditionalResidualBlock1D(
                mid_dim,
                mid_dim,
                cond_dim=cur_cond_dim,
                kernel_size=kernel_size,
                n_groups=n_groups,
                cond_predict_scale=cond_predict_scale,
                use_linear_act_proj=use_linear_act_proj),
            ConditionalResidualBlock1D(
                mid_dim,
                mid_dim,
                cond_dim=cur_cond_dim,
                kernel_size=kernel_size,
                n_groups=n_groups,
                cond_predict_scale=cond_predict_scale,
                use_linear_act_proj=use_linear_act_proj),
            ActionLatentImageCrossAttention(mid_dim,
                                            dim_a_list[-1],
                                            num_heads,
                                            dim_head,
                                            context_dim=context_dims[-1],
                                            use_linear=use_linear_attn)
            if cond_cross_attention else nn.Identity(),
        ])

        up_modules = nn.ModuleList([])
        # Reverse context channel order for the decoder path.
        context_dims = context_dims[::-1]
        for ind, (dim_in, dim_out) in enumerate(
                reversed(in_out[1:] + [(down_dims[-1], down_dims[-1])])):
            is_last = ind >= (len(in_out) - 1)
            if use_linear_act_proj:
                if use_imagen_mid_only:
                    cur_cond_dim = cond_dim + 2 * context_dims[0]
                elif use_z_only:
                    cur_cond_dim = cond_dim + 2 * spatial_num_kp
                else:
                    cur_cond_dim = cond_dim + 2 * context_dims[ind]
            else:
                cur_cond_dim = cond_dim + horizon * context_dims[ind]
            up_modules.append(
                nn.ModuleList([
                    # dim_out + dim_in: skip connection concatenated on channels
                    ConditionalResidualBlock1D(
                        dim_out + dim_in,
                        dim_in,
                        cond_dim=cur_cond_dim,
                        kernel_size=kernel_size,
                        n_groups=n_groups,
                        cond_predict_scale=cond_predict_scale,
                        use_linear_act_proj=use_linear_act_proj),
                    ConditionalResidualBlock1D(
                        dim_in,
                        dim_in,
                        cond_dim=cur_cond_dim,
                        kernel_size=kernel_size,
                        n_groups=n_groups,
                        cond_predict_scale=cond_predict_scale,
                        use_linear_act_proj=use_linear_act_proj),
                    ActionLatentImageCrossAttention(
                        dim_in,
                        dim_a_list.pop(),
                        num_heads,
                        dim_head,
                        context_dim=context_dims[ind],
                        use_linear=use_linear_attn)
                    if cond_cross_attention else nn.Identity(),
                    Upsample1d(dim_in) if not is_last else nn.Identity()
                ]))

        final_conv = nn.Sequential(
            Conv1dBlock(start_dim, start_dim, kernel_size=kernel_size),
            nn.Conv1d(start_dim, input_dim, 1),
        )

        # SpatialSoftmax pools each video feature map into keypoints.
        if use_z_only:
            h, w = image_size
            self.spatial_softmax_blocks = nn.ModuleList(
                [SpatialSoftmax((4, h, w), spatial_num_kp)])
        else:
            self.spatial_softmax_blocks = nn.ModuleList([])
            # undo the earlier reversal so pooling follows the encoder order
            context_dims = context_dims[::-1]
            for ind, context_dim in enumerate(context_dims):
                h, w = image_size
                if ind != 0:
                    # feature maps are halved at every encoder stage
                    h //= 2**ind
                    w //= 2**ind
                net = SpatialSoftmax((context_dim, h, w), context_dim)
                self.spatial_softmax_blocks.append(net)
            # duplicate the deepest pooler for the mid stage, then mirror the
            # first four for the decoder path (down + mid + up ordering,
            # matching the idx arithmetic in forward()).
            self.spatial_softmax_blocks.append(net)
            self.spatial_softmax_blocks += self.spatial_softmax_blocks[
                0:4][::-1]

        self.diffusion_step_encoder = diffusion_step_encoder
        self.local_cond_encoder = local_cond_encoder
        self.up_modules = up_modules
        self.down_modules = down_modules
        self.final_conv = final_conv

        self.cond_cross_attention = cond_cross_attention
        self.use_linear_act_proj = use_linear_act_proj

        # Per-scalar action projection (each action value becomes a token)
        # and its inverse; horizon variants for the non-linear-proj path.
        self.proj_in_action = nn.Sequential(nn.Linear(1, act_proj_dim),
                                            nn.LayerNorm(act_proj_dim))
        self.proj_in_horizon = nn.Sequential(nn.Linear(horizon, act_proj_dim),
                                             nn.LayerNorm(act_proj_dim))
        self.proj_out_action = nn.Sequential(nn.LayerNorm(act_proj_dim),
                                             nn.Linear(act_proj_dim, 1))
        self.proj_out_horizon = nn.Sequential(nn.LayerNorm(act_proj_dim),
                                              nn.Linear(act_proj_dim, horizon))
        logger.info("number of parameters: %e",
                    sum(p.numel() for p in self.parameters()))

        self.imagen_cond_gradient = imagen_cond_gradient
        self.use_imagen_mid_only = use_imagen_mid_only
        self.use_z_only = use_z_only
        self.spatial_num_kp = spatial_num_kp
        self.last_frame_only = last_frame_only
        self.horizon = horizon

    def forward(self,
                sample: torch.Tensor,
                timestep: Union[torch.Tensor, float, int],
                imagen_cond=None,
                cond=None,
                **kwargs):
        """
        sample: (B,T,input_dim)
        timestep: (B,) or int, diffusion step
        imagen_cond: a list of hidden info from video gen unet
            (indices 0:4 = down stages, 4 = mid, 5: = up stages — see the
            hand-coded split below)
        cond: a (image, agent_pos) pair (indexed as cond[0]/cond[1]):
            image: (B, 3, To, h, w)
            agent_pos: (B, Ta, d)
        output: (B,T,input_dim)
        """

        if not self.imagen_cond_gradient:
            # cut gradients into the video U-Net unless explicitly requested
            imagen_cond = [c.detach() for c in imagen_cond]

        cond = {'image': cond[0], 'agent_pos': cond[1]}

        # (B, C, To, h, w) -> (B, To, C, h, w) -> (B*To, C, h, w)
        cond['image'] = cond['image'].permute(0, 2, 1, 3,
                                              4)
        cond['image'] = rearrange(cond['image'], 'b t c h w -> (b t) c h w')
        cond['agent_pos'] = rearrange(cond['agent_pos'], 'b t d -> (b t) d')

        B, T, D = sample.shape
        if self.use_linear_act_proj:
            # each scalar action value becomes a token of width act_proj_dim
            sample = self.proj_in_action(sample.unsqueeze(-1))
            global_cond = self.obs_encoder(cond)
            global_cond = rearrange(global_cond,
                                    '(b t) d -> b 1 (t d)',
                                    b=B,
                                    t=self.n_obs_steps)
            global_cond = repeat(global_cond,
                                 'b c d -> b (repeat c) d',
                                 repeat=T)
        else:
            sample = einops.rearrange(sample, 'b h t -> b t h')
            sample = self.proj_in_horizon(sample)
            # NOTE(review): `robo_state_cond` is never defined and
            # `global_cond` (used below) is never assigned in this branch —
            # this path raises NameError as written. Presumably dead code:
            # confirm use_linear_act_proj is always True in configs.
            robo_state_cond = rearrange(robo_state_cond, 'b t d -> b 1 (t d)')
            robo_state_cond = repeat(robo_state_cond,
                                     'b c d -> b (repeat c) d',
                                     repeat=2)

        # 1. time
        timesteps = timestep
        if not torch.is_tensor(timesteps):
            timesteps = torch.tensor([timesteps],
                                     dtype=torch.long,
                                     device=sample.device)
        elif torch.is_tensor(timesteps) and len(timesteps.shape) == 0:
            timesteps = timesteps[None].to(sample.device)
        # Broadcast to batch dimension in a way that's compatible with ONNX/Core ML
        timesteps = timesteps.expand(sample.shape[0])
        global_feature = self.diffusion_step_encoder(timesteps)
        (imagen_cond_down, imagen_cond_mid, imagen_cond_up
         ) = imagen_cond[0:4], imagen_cond[4], imagen_cond[5:]  #NOTE HAND CODE

        # Fold the per-scalar token axis into the batch: (B*T, D, proj)
        x = sample if not self.use_linear_act_proj else sample.reshape(
            B * T, D, -1)
        h = []
        for idx, modules in enumerate(self.down_modules):
            if self.cond_cross_attention:
                # NOTE(review): `crossatten` is unpacked but never applied
                (resnet, resnet2, crossatten, downsample) = modules
            else:
                (resnet, resnet2, _, downsample) = modules

            # Access the cond from the unet embeds from video unet
            if self.use_imagen_mid_only:
                imagen_cond = imagen_cond_mid
            elif self.use_z_only:
                imagen_cond = kwargs['x_start'].permute(0, 2, 1, 3, 4)
            else:
                imagen_cond = imagen_cond_down[idx]
            if self.last_frame_only:
                # broadcast the final frame's features over the whole horizon
                imagen_cond = imagen_cond[:, -1].unsqueeze(1)
                imagen_cond = repeat(imagen_cond,
                                     'b t c h w -> b (repeat t) c h w',
                                     repeat=self.horizon)
            imagen_cond = rearrange(imagen_cond, 'b t c h w -> (b t) c h w')
            # Pool the stage's feature map into keypoints.
            if self.use_imagen_mid_only:
                imagen_cond = self.spatial_softmax_blocks[len(
                    self.spatial_softmax_blocks) // 2](imagen_cond)
            elif self.use_z_only:
                imagen_cond = self.spatial_softmax_blocks[0](imagen_cond)
            else:
                imagen_cond = self.spatial_softmax_blocks[idx](imagen_cond)
            imagen_cond = rearrange(imagen_cond, '(b t) c d -> b t c d', b=B)

            if self.use_linear_act_proj:
                imagen_cond = imagen_cond.reshape(B, T, -1)
                cur_global_feature = global_feature.unsqueeze(
                    1).repeat_interleave(repeats=T, dim=1)
            else:
                imagen_cond = imagen_cond.permute(0, 3, 1, 2)
                imagen_cond = imagen_cond.reshape(B, 2, -1)
                cur_global_feature = global_feature.unsqueeze(
                    1).repeat_interleave(repeats=2, dim=1)
            cur_global_feature = torch.cat(
                [cur_global_feature, global_cond, imagen_cond], axis=-1)
            x = resnet(x, cur_global_feature)
            x = resnet2(x, cur_global_feature)
            h.append(x)
            x = downsample(x)

        #>>> mide blocks
        resnet, resnet2, _ = self.mid_modules
        # Access the cond from the unet embeds from video unet
        if self.use_z_only:
            imagen_cond = kwargs['x_start'].permute(0, 2, 1, 3, 4)
        else:
            imagen_cond = imagen_cond_mid
        if self.last_frame_only:
            imagen_cond = imagen_cond[:, -1].unsqueeze(1)
            imagen_cond = repeat(imagen_cond,
                                 'b t c h w -> b (repeat t) c h w',
                                 repeat=self.horizon)
        imagen_cond = rearrange(imagen_cond, 'b t c h w -> (b t) c h w')
        # NOTE(review): relies on `idx` leaking from the down loop above;
        # idx+1 selects the duplicated mid-stage SpatialSoftmax.
        idx += 1
        if self.use_z_only:
            imagen_cond = self.spatial_softmax_blocks[0](imagen_cond)
        else:
            imagen_cond = self.spatial_softmax_blocks[idx](imagen_cond)
        imagen_cond = rearrange(imagen_cond, '(b t) c d -> b t c d', b=B)
        if self.use_linear_act_proj:
            imagen_cond = imagen_cond.reshape(B, T, -1)
            cur_global_feature = global_feature.unsqueeze(1).repeat_interleave(
                repeats=T, dim=1)
        else:
            imagen_cond = imagen_cond.permute(0, 3, 1, 2)
            imagen_cond = imagen_cond.reshape(B, 2, -1)
            cur_global_feature = global_feature.unsqueeze(1).repeat_interleave(
                repeats=2, dim=1)
        cur_global_feature = torch.cat(
            [cur_global_feature, global_cond, imagen_cond], axis=-1)
        x = resnet(x, cur_global_feature)
        x = resnet2(x, cur_global_feature)

        #>>> up blocks
        idx += 1
        for jdx, modules in enumerate(self.up_modules):
            if self.cond_cross_attention:
                (resnet, resnet2, crossatten, upsample) = modules
            else:
                (resnet, resnet2, _, upsample) = modules

            # Access the cond from the unet embeds from video unet
            if self.use_imagen_mid_only:
                imagen_cond = imagen_cond_mid
            elif self.use_z_only:
                imagen_cond = kwargs['x_start'].permute(0, 2, 1, 3, 4)
            else:
                imagen_cond = imagen_cond_up[jdx]
            if self.last_frame_only:
                imagen_cond = imagen_cond[:, -1].unsqueeze(1)
                imagen_cond = repeat(imagen_cond,
                                     'b t c h w -> b (repeat t) c h w',
                                     repeat=self.horizon)
            imagen_cond = rearrange(imagen_cond, 'b t c h w -> (b t) c h w')
            if self.use_imagen_mid_only:
                imagen_cond = self.spatial_softmax_blocks[len(
                    self.spatial_softmax_blocks) // 2](imagen_cond)
            elif self.use_z_only:
                imagen_cond = self.spatial_softmax_blocks[0](imagen_cond)
            else:
                # jdx + idx indexes into the mirrored decoder poolers
                imagen_cond = self.spatial_softmax_blocks[jdx +
                                                          idx](imagen_cond)
            imagen_cond = rearrange(imagen_cond, '(b t) c d -> b t c d', b=B)

            if self.use_linear_act_proj:
                imagen_cond = imagen_cond.reshape(B, T, -1)
                cur_global_feature = global_feature.unsqueeze(
                    1).repeat_interleave(repeats=T, dim=1)
            else:
                imagen_cond = imagen_cond.permute(0, 3, 1, 2)
                imagen_cond = imagen_cond.reshape(B, 2, -1)
                cur_global_feature = global_feature.unsqueeze(
                    1).repeat_interleave(repeats=2, dim=1)

            cur_global_feature = torch.cat(
                [cur_global_feature, global_cond, imagen_cond], axis=-1)

            # concat the matching encoder skip connection on channels
            x = torch.cat((x, h.pop()), dim=1)
            x = resnet(x, cur_global_feature)
            x = resnet2(x, cur_global_feature)
            x = upsample(x)

        x = self.final_conv(x)

        if self.use_linear_act_proj:
            # (B*T, D, proj) -> (B, T, D, proj) -> project back to scalars
            x = x.reshape(B, T, D, -1)
            x = self.proj_out_action(x)
            x = x.reshape(B, T, D)
        else:
            x = self.proj_out_horizon(x)
            x = einops.rearrange(x, 'b t h -> b h t')
        return x
|
||||
52
src/unifolm_wma/models/diffusion_head/conv1d_components.py
Normal file
52
src/unifolm_wma/models/diffusion_head/conv1d_components.py
Normal file
@@ -0,0 +1,52 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
class Downsample1d(nn.Module):
    """Halve the temporal length of a (B, C, T) sequence with a strided conv."""

    def __init__(self, dim):
        super().__init__()
        # kernel 3 / stride 2 / padding 1 -> output length ceil(T / 2)
        self.conv = nn.Conv1d(dim, dim, kernel_size=3, stride=2, padding=1)

    def forward(self, x):
        return self.conv(x)
|
||||
|
||||
|
||||
class Upsample1d(nn.Module):
    """Double the temporal length of a (B, C, T) sequence with a transposed conv."""

    def __init__(self, dim):
        super().__init__()
        # kernel 4 / stride 2 / padding 1 -> output length exactly 2 * T
        self.conv = nn.ConvTranspose1d(dim, dim, kernel_size=4, stride=2,
                                       padding=1)

    def forward(self, x):
        return self.conv(x)
|
||||
|
||||
|
||||
class Conv1dBlock(nn.Module):
    '''
    Conv1d --> GroupNorm --> Mish
    Temporal length is preserved via symmetric padding (odd kernel sizes).
    '''

    def __init__(self, inp_channels, out_channels, kernel_size, n_groups=8):
        super().__init__()
        same_pad = kernel_size // 2
        self.block = nn.Sequential(
            nn.Conv1d(inp_channels, out_channels, kernel_size,
                      padding=same_pad),
            nn.GroupNorm(n_groups, out_channels),
            nn.Mish(),
        )

    def forward(self, x):
        return self.block(x)
|
||||
|
||||
|
||||
def test():
    """Ad-hoc smoke test: run a Conv1dBlock over a dummy (B, C, horizon) tensor."""
    cb = Conv1dBlock(256, 128, kernel_size=3)
    x = torch.zeros((1, 256, 16))
    o = cb(x)  # expected shape: (1, 128, 16); result intentionally unused
|
||||
80
src/unifolm_wma/models/diffusion_head/ema_model.py
Normal file
80
src/unifolm_wma/models/diffusion_head/ema_model.py
Normal file
@@ -0,0 +1,80 @@
|
||||
import copy
|
||||
import torch
|
||||
from torch.nn.modules.batchnorm import _BatchNorm
|
||||
|
||||
|
||||
class EMAModel:
    """
    Exponential Moving Average of models weights
    """

    def __init__(self,
                 model,
                 update_after_step=0,
                 inv_gamma=1.0,
                 power=2 / 3,
                 min_value=0.0,
                 max_value=0.9999):
        """
        @crowsonkb's notes on EMA Warmup:
        If gamma=1 and power=1, implements a simple average. gamma=1, power=2/3 are good values for models you plan
        to train for a million or more steps (reaches decay factor 0.999 at 31.6K steps, 0.9999 at 1M steps),
        gamma=1, power=3/4 for models you plan to train for less (reaches decay factor 0.999 at 10K steps, 0.9999
        at 215.4k steps).
        Args:
            inv_gamma (float): Inverse multiplicative factor of EMA warmup. Default: 1.
            power (float): Exponential factor of EMA warmup. Default: 2/3.
            min_value (float): The minimum EMA decay rate. Default: 0.
        """

        # NOTE(review): `model` is stored by reference, not deep-copied —
        # the caller must pass in a dedicated copy to average into.
        self.averaged_model = model
        self.averaged_model.eval()
        self.averaged_model.requires_grad_(False)

        self.update_after_step = update_after_step
        self.inv_gamma = inv_gamma
        self.power = power
        self.min_value = min_value
        self.max_value = max_value

        self.decay = 0.0  # last decay factor applied by step()
        self.optimization_step = 0

    def get_decay(self, optimization_step):
        """
        Compute the decay factor for the exponential moving average.
        """
        # Warmup schedule: decay ramps from 0 toward max_value as steps grow.
        step = max(0, optimization_step - self.update_after_step - 1)
        value = 1 - (1 + step / self.inv_gamma)**-self.power

        if step <= 0:
            return 0.0

        return max(self.min_value, min(value, self.max_value))

    @torch.no_grad()
    def step(self, new_model):
        # Blend new_model's parameters into the averaged model:
        # ema = decay * ema + (1 - decay) * new.
        self.decay = self.get_decay(self.optimization_step)

        all_dataptrs = set()  # NOTE(review): vestigial — never populated/used
        for module, ema_module in zip(new_model.modules(),
                                      self.averaged_model.modules()):
            for param, ema_param in zip(module.parameters(recurse=False),
                                        ema_module.parameters(recurse=False)):
                # iterative over immediate parameters only.
                if isinstance(param, dict):
                    raise RuntimeError('Dict parameter not supported')

                if isinstance(module, _BatchNorm):
                    # skip batchnorms
                    ema_param.copy_(param.to(dtype=ema_param.dtype).data)
                elif not param.requires_grad:
                    # frozen parameters are mirrored directly, not averaged
                    ema_param.copy_(param.to(dtype=ema_param.dtype).data)
                else:
                    ema_param.mul_(self.decay)
                    ema_param.add_(param.data.to(dtype=ema_param.dtype),
                                   alpha=1 - self.decay)

        # verify that iterating over module and then parameters is identical to parameters recursively.
        # assert old_all_dataptrs == all_dataptrs
        self.optimization_step += 1
|
||||
@@ -0,0 +1,19 @@
|
||||
import math
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
|
||||
class SinusoidalPosEmb(nn.Module):
    """Classic transformer sinusoidal embedding for scalar positions/timesteps.

    Maps a 1-D tensor of positions (B,) to embeddings (B, dim):
    first half sine components, second half cosine components.
    """

    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, x):
        half = self.dim // 2
        # geometric frequency ladder from 1 down to 1/10000
        scale = math.log(10000) / (half - 1)
        freqs = torch.exp(-scale * torch.arange(half, device=x.device))
        angles = x[:, None] * freqs[None, :]
        return torch.cat((angles.sin(), angles.cos()), dim=-1)
|
||||
322
src/unifolm_wma/models/diffusion_head/vision/crop_randomizer.py
Normal file
322
src/unifolm_wma/models/diffusion_head/vision/crop_randomizer.py
Normal file
@@ -0,0 +1,322 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torchvision.transforms.functional as ttf
|
||||
import unifolm_wma.models.diffusion_head.common.tensor_util as tu
|
||||
|
||||
|
||||
class CropRandomizer(nn.Module):
    """
    Randomly sample crops at input, and then average across crop features at output.
    """

    def __init__(
        self,
        input_shape,
        crop_height,
        crop_width,
        num_crops=1,
        pos_enc=False,
    ):
        """
        Args:
            input_shape (tuple, list): shape of input (not including batch dimension)
            crop_height (int): crop height
            crop_width (int): crop width
            num_crops (int): number of random crops to take
            pos_enc (bool): if True, add 2 channels to the output to encode the spatial
                location of the cropped pixels in the source image
        """
        super().__init__()

        assert len(input_shape) == 3  # (C, H, W)
        # crops must be strictly smaller than the source image
        assert crop_height < input_shape[1]
        assert crop_width < input_shape[2]

        self.input_shape = input_shape
        self.crop_height = crop_height
        self.crop_width = crop_width
        self.num_crops = num_crops
        self.pos_enc = pos_enc

    def output_shape_in(self, input_shape=None):
        """
        Function to compute output shape from inputs to this module. Corresponds to
        the @forward_in operation, where raw inputs (usually observation modalities)
        are passed in.

        Args:
            input_shape (iterable of int): shape of input. Does not include batch dimension.
                Some modules may not need this argument, if their output does not depend
                on the size of the input, or if they assume fixed size input.

        Returns:
            out_shape ([int]): list of integers corresponding to output shape
        """

        # outputs are shape (C, CH, CW), or maybe C + 2 if using position encoding, because
        # the number of crops are reshaped into the batch dimension, increasing the batch
        # size from B to B * N
        out_c = self.input_shape[0] + 2 if self.pos_enc else self.input_shape[0]
        return [out_c, self.crop_height, self.crop_width]

    def output_shape_out(self, input_shape=None):
        """
        Function to compute output shape from inputs to this module. Corresponds to
        the @forward_out operation, where processed inputs (usually encoded observation
        modalities) are passed in.

        Args:
            input_shape (iterable of int): shape of input. Does not include batch dimension.
                Some modules may not need this argument, if their output does not depend
                on the size of the input, or if they assume fixed size input.

        Returns:
            out_shape ([int]): list of integers corresponding to output shape
        """

        # since the forward_out operation splits [B * N, ...] -> [B, N, ...]
        # and then pools to result in [B, ...], only the batch dimension changes,
        # and so the other dimensions retain their shape.
        return list(input_shape)

    def forward_in(self, inputs):
        """
        Samples N random crops for each input in the batch, and then reshapes
        inputs to [B * N, ...].
        """
        assert len(
            inputs.shape) >= 3  # must have at least (C, H, W) dimensions
        if self.training:
            # generate random crops
            out, _ = sample_random_image_crops(
                images=inputs,
                crop_height=self.crop_height,
                crop_width=self.crop_width,
                num_crops=self.num_crops,
                pos_enc=self.pos_enc,
            )
            # [B, N, ...] -> [B * N, ...]
            return tu.join_dimensions(out, 0, 1)
        else:
            # take center crop during eval
            # NOTE(review): unlike training, the eval path never adds the
            # pos_enc channels — confirm downstream encoders tolerate this.
            out = ttf.center_crop(img=inputs,
                                  output_size=(self.crop_height,
                                               self.crop_width))
            if self.num_crops > 1:
                # replicate the center crop N times to keep batch layout
                # consistent with the training path
                B, C, H, W = out.shape
                out = out.unsqueeze(1).expand(B, self.num_crops, C, H,
                                              W).reshape(-1, C, H, W)
            # [B * N, ...]
            return out

    def forward_out(self, inputs):
        """
        Splits the outputs from shape [B * N, ...] -> [B, N, ...] and then average across N
        to result in shape [B, ...] to make sure the network output is consistent with
        what would have happened if there were no randomization.
        """
        if self.num_crops <= 1:
            return inputs
        else:
            batch_size = (inputs.shape[0] // self.num_crops)
            out = tu.reshape_dimensions(inputs,
                                        begin_axis=0,
                                        end_axis=0,
                                        target_dims=(batch_size,
                                                     self.num_crops))
            return out.mean(dim=1)

    def forward(self, inputs):
        # nn.Module entry point only performs the input-side cropping;
        # callers invoke forward_out() explicitly after the encoder.
        return self.forward_in(inputs)

    def __repr__(self):
        """Pretty print network."""
        header = '{}'.format(str(self.__class__.__name__))
        msg = header + "(input_shape={}, crop_size=[{}, {}], num_crops={})".format(
            self.input_shape, self.crop_height, self.crop_width,
            self.num_crops)
        return msg
|
||||
|
||||
|
||||
def crop_image_from_indices(images, crop_indices, crop_height, crop_width):
    """
    Crops images at the locations specified by @crop_indices. Crops will be
    taken across all channels.

    Args:
        images (torch.Tensor): batch of images of shape [..., C, H, W]

        crop_indices (torch.Tensor): batch of indices of shape [..., N, 2] where
            N is the number of crops to take per image and each entry corresponds
            to the pixel height and width of where to take the crop. Note that
            the indices can also be of shape [..., 2] if only 1 crop should
            be taken per image. Leading dimensions must be consistent with
            @images argument. Each index specifies the top left of the crop.
            Values must be in range [0, H - CH - 1] x [0, W - CW - 1] where
            H and W are the height and width of @images and CH and CW are
            @crop_height and @crop_width.

        crop_height (int): height of crop to take

        crop_width (int): width of crop to take

    Returns:
        crops (torch.Tesnor): cropped images of shape [..., C, @crop_height, @crop_width]
    """

    # make sure length of input shapes is consistent
    assert crop_indices.shape[-1] == 2
    ndim_im_shape = len(images.shape)
    ndim_indices_shape = len(crop_indices.shape)
    assert (ndim_im_shape == ndim_indices_shape +
            1) or (ndim_im_shape == ndim_indices_shape + 2)

    # maybe pad so that @crop_indices is shape [..., N, 2]
    is_padded = False
    if ndim_im_shape == ndim_indices_shape + 2:
        crop_indices = crop_indices.unsqueeze(-2)
        is_padded = True

    # make sure leading dimensions between images and indices are consistent
    assert images.shape[:-3] == crop_indices.shape[:-2]

    device = images.device
    image_c, image_h, image_w = images.shape[-3:]
    num_crops = crop_indices.shape[-2]

    # make sure @crop_indices are in valid range
    assert (crop_indices[..., 0] >= 0).all().item()
    assert (crop_indices[..., 0] < (image_h - crop_height)).all().item()
    assert (crop_indices[..., 1] >= 0).all().item()
    assert (crop_indices[..., 1] < (image_w - crop_width)).all().item()

    # convert each crop index (ch, cw) into a list of pixel indices that correspond to the entire window.

    # 2D index array with columns [0, 1, ..., CH - 1] and shape [CH, CW]
    crop_ind_grid_h = torch.arange(crop_height).to(device)
    crop_ind_grid_h = tu.unsqueeze_expand_at(crop_ind_grid_h,
                                             size=crop_width,
                                             dim=-1)
    # 2D index array with rows [0, 1, ..., CW - 1] and shape [CH, CW]
    crop_ind_grid_w = torch.arange(crop_width).to(device)
    crop_ind_grid_w = tu.unsqueeze_expand_at(crop_ind_grid_w,
                                             size=crop_height,
                                             dim=0)
    # combine into shape [CH, CW, 2]
    crop_in_grid = torch.cat(
        (crop_ind_grid_h.unsqueeze(-1), crop_ind_grid_w.unsqueeze(-1)), dim=-1)

    # Add above grid with the offset index of each sampled crop to get 2d indices for each crop.
    # After broadcasting, this will be shape [..., N, CH, CW, 2] and each crop has a [CH, CW, 2]
    # shape array that tells us which pixels from the corresponding source image to grab.
    grid_reshape = [1] * len(crop_indices.shape[:-1]) + [
        crop_height, crop_width, 2
    ]
    all_crop_inds = crop_indices.unsqueeze(-2).unsqueeze(
        -2) + crop_in_grid.reshape(grid_reshape)

    # For using @torch.gather, convert to flat indices from 2D indices, and also
    # repeat across the channel dimension. To get flat index of each pixel to grab for
    # each sampled crop, we just use the mapping: ind = h_ind * @image_w + w_ind
    all_crop_inds = all_crop_inds[..., 0] * image_w + all_crop_inds[
        ..., 1]  # shape [..., N, CH, CW]
    all_crop_inds = tu.unsqueeze_expand_at(all_crop_inds, size=image_c,
                                           dim=-3)  # shape [..., N, C, CH, CW]
    all_crop_inds = tu.flatten(all_crop_inds,
                               begin_axis=-2)  # shape [..., N, C, CH * CW]

    # Repeat and flatten the source images -> [..., N, C, H * W] and then use gather to index with crop pixel inds
    images_to_crop = tu.unsqueeze_expand_at(images, size=num_crops, dim=-4)
    images_to_crop = tu.flatten(images_to_crop, begin_axis=-2)
    crops = torch.gather(images_to_crop, dim=-1, index=all_crop_inds)
    # [..., N, C, CH * CW] -> [..., N, C, CH, CW]
    reshape_axis = len(crops.shape) - 1
    crops = tu.reshape_dimensions(crops,
                                  begin_axis=reshape_axis,
                                  end_axis=reshape_axis,
                                  target_dims=(crop_height, crop_width))

    if is_padded:
        # undo padding -> [..., C, CH, CW]
        crops = crops.squeeze(-4)
    return crops
|
||||
|
||||
|
||||
def sample_random_image_crops(images,
                              crop_height,
                              crop_width,
                              num_crops,
                              pos_enc=False):
    """
    For each image, randomly sample @num_crops crops of size (@crop_height, @crop_width), from
    @images.

    Args:
        images (torch.Tensor): batch of images of shape [..., C, H, W]

        crop_height (int): height of crop to take

        crop_width (int): width of crop to take

        num_crops (n): number of crops to sample

        pos_enc (bool): if True, also add 2 channels to the outputs that gives a spatial
            encoding of the original source pixel locations. This means that the
            output crops will contain information about where in the source image
            it was sampled from.

    Returns:
        crops (torch.Tensor): crops of shape (..., @num_crops, C, @crop_height, @crop_width)
            if @pos_enc is False, otherwise (..., @num_crops, C + 2, @crop_height, @crop_width)

        crop_inds (torch.Tensor): sampled crop indices of shape (..., N, 2)

    Raises:
        ValueError: if the requested crop is larger than the source image.
    """
    device = images.device

    # maybe add 2 channels of spatial encoding to the source image
    source_im = images
    if pos_enc:
        # spatial encoding [y, x] in [0, 1]
        h, w = source_im.shape[-2:]
        # indexing='ij' makes the historical default explicit and silences the
        # torch.meshgrid deprecation warning without changing the result.
        pos_y, pos_x = torch.meshgrid(torch.arange(h),
                                      torch.arange(w),
                                      indexing='ij')
        pos_y = pos_y.float().to(device) / float(h)
        pos_x = pos_x.float().to(device) / float(w)
        position_enc = torch.stack((pos_y, pos_x))  # shape [C, H, W]

        # unsqueeze and expand to match leading dimensions -> shape [..., C, H, W]
        leading_shape = source_im.shape[:-3]
        position_enc = position_enc[(None, ) * len(leading_shape)]
        position_enc = position_enc.expand(*leading_shape, -1, -1, -1)

        # concat across channel dimension with input
        source_im = torch.cat((source_im, position_enc), dim=-3)

    # make sure sample boundaries ensure crops are fully within the images
    image_c, image_h, image_w = source_im.shape[-3:]
    max_sample_h = image_h - crop_height
    max_sample_w = image_w - crop_width
    if max_sample_h < 0 or max_sample_w < 0:
        # Previously an oversized crop produced negative sample indices and a
        # cryptic failure (or wrong gather results) further down.
        raise ValueError(
            f"crop size ({crop_height}, {crop_width}) exceeds image size "
            f"({image_h}, {image_w})")

    # Sample crop locations for all tensor dimensions up to the last 3, which are [C, H, W].
    # Each gets @num_crops samples - typically this will just be the batch dimension (B), so
    # we will sample [B, N] indices, but this supports having more than one leading dimension,
    # or possibly no leading dimension.
    #
    # Trick: sample in [0, 1) with rand, then re-scale to [0, M) and convert to long to get sampled ints
    crop_inds_h = (
        max_sample_h *
        torch.rand(*source_im.shape[:-3], num_crops).to(device)).long()
    crop_inds_w = (
        max_sample_w *
        torch.rand(*source_im.shape[:-3], num_crops).to(device)).long()
    crop_inds = torch.cat(
        (crop_inds_h.unsqueeze(-1), crop_inds_w.unsqueeze(-1)),
        dim=-1)  # shape [..., N, 2]

    crops = crop_image_from_indices(
        images=source_im,
        crop_indices=crop_inds,
        crop_height=crop_height,
        crop_width=crop_width,
    )

    return crops, crop_inds
|
||||
30
src/unifolm_wma/models/diffusion_head/vision/model_getter.py
Normal file
30
src/unifolm_wma/models/diffusion_head/vision/model_getter.py
Normal file
@@ -0,0 +1,30 @@
|
||||
import torch
|
||||
import torchvision
|
||||
|
||||
|
||||
def get_resnet(name, weights=None, **kwargs):
    """
    name: resnet18, resnet34, resnet50
    weights: "IMAGENET1K_V1", "r3m"
    """
    # R3M checkpoints are loaded through their dedicated helper.
    if weights in ("r3m", "R3M"):
        return get_r3m(name=name, **kwargs)

    builder = getattr(torchvision.models, name)
    backbone = builder(weights=weights, **kwargs)
    # Drop the classification head so the network emits pooled features.
    backbone.fc = torch.nn.Identity()
    return backbone
|
||||
|
||||
|
||||
def get_r3m(name, **kwargs):
    """
    name: resnet18, resnet34, resnet50
    """
    import r3m
    # Force CPU loading; the caller is responsible for device placement.
    r3m.device = 'cpu'
    wrapped = r3m.load_r3m(name)
    # load_r3m returns a DataParallel-style wrapper: unwrap to the conv trunk.
    inner = wrapped.module
    backbone = inner.convnet
    return backbone.to('cpu')
|
||||
@@ -0,0 +1,247 @@
|
||||
import copy
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torchvision
|
||||
import json
|
||||
import os
|
||||
|
||||
from unifolm_wma.models.diffusion_head.vision.crop_randomizer import CropRandomizer
|
||||
from unifolm_wma.models.diffusion_head.base_nets import SpatialSoftmax
|
||||
from unifolm_wma.models.diffusion_head.common.module_attr_mixin import ModuleAttrMixin
|
||||
from unifolm_wma.models.diffusion_head.common.pytorch_util import dict_apply, replace_submodules
|
||||
from unifolm_wma.utils.utils import instantiate_from_config
|
||||
from einops import rearrange, repeat
|
||||
from typing import Dict, Tuple, Union
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
class MultiImageObsEncoder(ModuleAttrMixin):
    """Encode a dict of observations (several RGB camera streams plus low-dim
    vectors) into a single flat per-sample feature vector.
    """

    def __init__(
            self,
            rgb_model_config: Dict,
            shape_meta_path: str | None = None,
            resize_shape: Union[Tuple[int, int], Dict[str, tuple],
                                None] = None,
            crop_shape: Union[Tuple[int, int], Dict[str, tuple], None] = None,
            random_crop: bool = True,
            # replace BatchNorm with GroupNorm
            use_group_norm: bool = False,
            # use single rgb model for all rgb inputs
            share_rgb_model: bool = False,
            # renormalize rgb input with imagenet normalization
            # assuming input in [0,1]
            imagenet_norm: bool = False,
            use_spatial_softmax=False,
            spatial_softmax_kp=32,
            use_dinoSiglip=False):
        """
        Assumes rgb input: B,C,H,W
        Assumes low_dim input: B,D

        Args:
            rgb_model_config: config handed to `instantiate_from_config` to
                build the RGB backbone (one nn.Module, or a dict of per-key
                modules).
            shape_meta_path: JSON file with an 'obs' section mapping each
                observation key to its 'shape' and 'type' ('rgb'/'low_dim').
                Defaults to configs/train/meta.json under the current
                working directory.
            resize_shape: (H, W) or a per-key dict; resize before cropping.
            crop_shape: (H, W) or a per-key dict; random crop when
                `random_crop`, otherwise a center crop.
            random_crop: choose CropRandomizer vs. CenterCrop.
            use_group_norm: swap every BatchNorm2d in the backbone for
                GroupNorm (num_features // 16 groups).
            share_rgb_model: use one backbone instance for all rgb keys.
            imagenet_norm: apply ImageNet mean/std normalization.
            use_spatial_softmax: replace the backbone head for key 'image'
                with a SpatialSoftmax keypoint layer.
            spatial_softmax_kp: number of keypoints for SpatialSoftmax.
            use_dinoSiglip: backbone consumes raw images directly; no
                resize/crop/normalize pipeline is built.
        """
        super().__init__()

        if not shape_meta_path:
            shape_meta_path = str(Path(os.getcwd()) / "configs/train/meta.json")

        with open(shape_meta_path, 'r') as file:
            shape_meta = json.load(file)

        rgb_model = instantiate_from_config(rgb_model_config)

        rgb_keys = list()
        low_dim_keys = list()
        key_model_map = nn.ModuleDict()
        key_transform_map = nn.ModuleDict()
        key_shape_map = dict()

        # handle sharing vision backbone
        if share_rgb_model:
            assert isinstance(rgb_model, nn.Module)
            key_model_map['rgb'] = rgb_model

        obs_shape_meta = shape_meta['obs']
        for key, attr in obs_shape_meta.items():
            shape = tuple(attr['shape'])
            # renamed from `type` so the builtin is not shadowed
            obs_type = attr.get('type', 'low_dim')
            key_shape_map[key] = shape
            if obs_type == 'rgb':
                rgb_keys.append(key)
                if not use_dinoSiglip:
                    # configure model for this key
                    this_model = None
                    if not share_rgb_model:
                        if isinstance(rgb_model, dict):
                            # have provided model for each key
                            this_model = rgb_model[key]
                        else:
                            assert isinstance(rgb_model, nn.Module)
                            # have a copy of the rgb model
                            this_model = copy.deepcopy(rgb_model)

                    if this_model is not None:
                        if use_group_norm:
                            this_model = replace_submodules(
                                root_module=this_model,
                                predicate=lambda x: isinstance(
                                    x, nn.BatchNorm2d),
                                func=lambda x: nn.GroupNorm(
                                    num_groups=x.num_features // 16,
                                    num_channels=x.num_features))
                        key_model_map[key] = this_model

                    # configure resize
                    input_shape = shape
                    this_resizer = nn.Identity()
                    if resize_shape is not None:
                        if isinstance(resize_shape, dict):
                            h, w = resize_shape[key]
                        else:
                            h, w = resize_shape
                        this_resizer = torchvision.transforms.Resize(
                            size=(h, w))
                        input_shape = (shape[0], h, w)

                    # configure randomizer
                    this_randomizer = nn.Identity()
                    if crop_shape is not None:
                        if isinstance(crop_shape, dict):
                            h, w = crop_shape[key]
                        else:
                            h, w = crop_shape
                        if random_crop:
                            this_randomizer = CropRandomizer(
                                input_shape=input_shape,
                                crop_height=h,
                                crop_width=w,
                                num_crops=1,
                                pos_enc=False)
                        else:
                            # BUG FIX: this assignment previously targeted
                            # `this_normalizer`, which was unconditionally
                            # overwritten with nn.Identity() just below, so
                            # no crop was ever applied when random_crop=False
                            # even though crop_shape was set.
                            this_randomizer = torchvision.transforms.CenterCrop(
                                size=(h, w))
                    # configure normalizer
                    this_normalizer = nn.Identity()
                    if imagenet_norm:
                        this_normalizer = torchvision.transforms.Normalize(
                            mean=[0.485, 0.456, 0.406],
                            std=[0.229, 0.224, 0.225])

                    this_transform = nn.Sequential(this_resizer,
                                                   this_randomizer,
                                                   this_normalizer)
                    key_transform_map[key] = this_transform
                else:
                    key_model_map[key] = rgb_model
            elif obs_type == 'low_dim':
                low_dim_keys.append(key)
            else:
                raise RuntimeError(f"Unsupported obs type: {obs_type}")

        # Sort keys so the feature concatenation order is deterministic.
        rgb_keys = sorted(rgb_keys)
        low_dim_keys = sorted(low_dim_keys)

        self.shape_meta = shape_meta
        self.key_model_map = key_model_map
        self.key_transform_map = key_transform_map
        self.share_rgb_model = share_rgb_model
        self.rgb_keys = rgb_keys
        self.low_dim_keys = low_dim_keys
        self.key_shape_map = key_shape_map
        self.use_dinoSiglip = use_dinoSiglip

        ##NOTE add spatial softmax
        self.use_spatial_softmax = use_spatial_softmax
        if use_spatial_softmax and not use_dinoSiglip:
            # Keep only the convolutional trunk so the output is a spatial
            # feature map rather than pooled features.
            # NOTE(review): assumes a ResNet-style backbone registered under
            # the key 'image' — confirm against configs enabling this path.
            model = nn.Sequential(
                key_model_map['image'].conv1,
                key_model_map['image'].bn1,
                key_model_map['image'].relu,
                key_model_map['image'].maxpool,
                key_model_map['image'].layer1,
                key_model_map['image'].layer2,
                key_model_map['image'].layer3,
                key_model_map['image'].layer4,
            )
            key_model_map['image'] = model
            # Probe the trunk with a dummy batch to size the keypoint layer.
            input_shape = self.output_shape(resnet_output_shape=True)
            self.spatial_softmax = SpatialSoftmax(input_shape,
                                                  num_kp=spatial_softmax_kp)
|
||||
|
||||
def forward(self, obs_dict, resnet_output_shape=False):
|
||||
batch_size = None
|
||||
features = list()
|
||||
# process rgb input
|
||||
if self.share_rgb_model:
|
||||
# pass all rgb obs to rgb model
|
||||
imgs = list()
|
||||
for key in self.rgb_keys:
|
||||
img = obs_dict[key]
|
||||
if batch_size is None:
|
||||
batch_size = img.shape[0]
|
||||
else:
|
||||
assert batch_size == img.shape[0]
|
||||
assert img.shape[1:] == self.key_shape_map[key]
|
||||
img = self.key_transform_map[key](img)
|
||||
imgs.append(img)
|
||||
# (N*B,C,H,W)
|
||||
imgs = torch.cat(imgs, dim=0)
|
||||
# (N*B,D)
|
||||
feature = self.key_model_map['rgb'](imgs)
|
||||
# (N,B,D)
|
||||
feature = feature.reshape(-1, batch_size, *feature.shape[1:])
|
||||
# (B,N,D)
|
||||
feature = torch.moveaxis(feature, 0, 1)
|
||||
# (B,N*D)
|
||||
feature = feature.reshape(batch_size, -1)
|
||||
features.append(feature)
|
||||
else:
|
||||
# run each rgb obs to independent models
|
||||
for key in self.rgb_keys:
|
||||
img = obs_dict[key]
|
||||
if batch_size is None:
|
||||
batch_size = img.shape[0]
|
||||
else:
|
||||
assert batch_size == img.shape[0]
|
||||
assert img.shape[1:] == self.key_shape_map[key]
|
||||
if not self.use_dinoSiglip:
|
||||
img = self.key_transform_map[key](img)
|
||||
feature = self.key_model_map[key](img)
|
||||
else:
|
||||
feature = self.key_model_map[key](img)[:, :1, :]
|
||||
|
||||
if resnet_output_shape:
|
||||
return feature
|
||||
if not self.use_dinoSiglip and self.use_spatial_softmax:
|
||||
feature = self.spatial_softmax(feature)
|
||||
feature = feature.reshape(batch_size, -1)
|
||||
features.append(feature)
|
||||
|
||||
# process lowdim input
|
||||
for key in self.low_dim_keys:
|
||||
data = obs_dict[key]
|
||||
if batch_size is None:
|
||||
batch_size = data.shape[0]
|
||||
else:
|
||||
assert batch_size == data.shape[0]
|
||||
assert data.shape[1:] == self.key_shape_map[key]
|
||||
features.append(data)
|
||||
|
||||
# concatenate all features
|
||||
result = torch.cat(features, dim=-1)
|
||||
return result
|
||||
|
||||
@torch.no_grad()
|
||||
def output_shape(self, resnet_output_shape=False):
|
||||
example_obs_dict = dict()
|
||||
obs_shape_meta = self.shape_meta['obs']
|
||||
batch_size = 1
|
||||
for key, attr in obs_shape_meta.items():
|
||||
shape = tuple(attr['shape'])
|
||||
this_obs = torch.zeros((batch_size, ) + shape,
|
||||
dtype=self.dtype,
|
||||
device=self.device)
|
||||
example_obs_dict[key] = this_obs
|
||||
example_output = self.forward(example_obs_dict,
|
||||
resnet_output_shape=resnet_output_shape)
|
||||
output_shape = example_output.shape[1:]
|
||||
return output_shape
|
||||
473
src/unifolm_wma/models/samplers/ddim.py
Normal file
473
src/unifolm_wma/models/samplers/ddim.py
Normal file
@@ -0,0 +1,473 @@
|
||||
import numpy as np
|
||||
import torch
|
||||
import copy
|
||||
|
||||
from unifolm_wma.utils.diffusion import make_ddim_sampling_parameters, make_ddim_timesteps, rescale_noise_cfg
|
||||
from unifolm_wma.utils.common import noise_like
|
||||
from unifolm_wma.utils.common import extract_into_tensor
|
||||
from tqdm import tqdm
|
||||
|
||||
|
||||
class DDIMSampler(object):
    """DDIM sampler wrapping a pretrained diffusion model."""

    def __init__(self, model, schedule="linear", **kwargs):
        super().__init__()
        self.model = model
        # Number of DDPM timesteps the wrapped model was trained with.
        self.ddpm_num_timesteps = model.num_timesteps
        self.schedule = schedule
        self.counter = 0
|
||||
|
||||
def register_buffer(self, name, attr):
|
||||
if type(attr) == torch.Tensor:
|
||||
if attr.device != torch.device("cuda"):
|
||||
attr = attr.to(torch.device("cuda"))
|
||||
setattr(self, name, attr)
|
||||
|
||||
    def make_schedule(self,
                      ddim_num_steps,
                      ddim_discretize="uniform",
                      ddim_eta=0.,
                      verbose=True):
        """Precompute the DDIM timestep subset and all per-step coefficients.

        Args:
            ddim_num_steps: number of DDIM steps to select out of the model's
                DDPM timesteps.
            ddim_discretize: discretization strategy passed to
                `make_ddim_timesteps` (e.g. "uniform").
            ddim_eta: eta parameter controlling the stochasticity of DDIM
                (0 gives deterministic sampling).
            verbose: forwarded to the schedule helpers for logging.
        """
        self.ddim_timesteps = make_ddim_timesteps(
            ddim_discr_method=ddim_discretize,
            num_ddim_timesteps=ddim_num_steps,
            num_ddpm_timesteps=self.ddpm_num_timesteps,
            verbose=verbose)
        alphas_cumprod = self.model.alphas_cumprod
        assert alphas_cumprod.shape[
            0] == self.ddpm_num_timesteps, 'alphas have to be defined for each timestep'
        # Detach and move coefficient tensors onto the model's device.
        to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.model
                                                                     .device)

        # Per-step latent rescale factors (only when the model uses them).
        if self.model.use_dynamic_rescale:
            self.ddim_scale_arr = self.model.scale_arr[self.ddim_timesteps]
            # Previous-step scales: shift right, repeating the first entry.
            self.ddim_scale_arr_prev = torch.cat(
                [self.ddim_scale_arr[0:1], self.ddim_scale_arr[:-1]])

        # NOTE: register_buffer here is this sampler's own helper (moves
        # tensors to CUDA and setattr's them), not nn.Module.register_buffer.
        self.register_buffer('betas', to_torch(self.model.betas))
        self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod))
        self.register_buffer('alphas_cumprod_prev',
                             to_torch(self.model.alphas_cumprod_prev))

        # Calculations for diffusion q(x_t | x_{t-1}) and others
        self.register_buffer('sqrt_alphas_cumprod',
                             to_torch(np.sqrt(alphas_cumprod.cpu())))
        self.register_buffer('sqrt_one_minus_alphas_cumprod',
                             to_torch(np.sqrt(1. - alphas_cumprod.cpu())))
        self.register_buffer('log_one_minus_alphas_cumprod',
                             to_torch(np.log(1. - alphas_cumprod.cpu())))
        self.register_buffer('sqrt_recip_alphas_cumprod',
                             to_torch(np.sqrt(1. / alphas_cumprod.cpu())))
        self.register_buffer('sqrt_recipm1_alphas_cumprod',
                             to_torch(np.sqrt(1. / alphas_cumprod.cpu() - 1)))

        # DDIM sampling parameters
        ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(
            alphacums=alphas_cumprod.cpu(),
            ddim_timesteps=self.ddim_timesteps,
            eta=ddim_eta,
            verbose=verbose)
        self.register_buffer('ddim_sigmas', ddim_sigmas)
        self.register_buffer('ddim_alphas', ddim_alphas)
        self.register_buffer('ddim_alphas_prev', ddim_alphas_prev)
        # NOTE(review): np.sqrt on a torch tensor yields a torch tensor here;
        # this buffer stays on whatever device ddim_alphas is on — confirm.
        self.register_buffer('ddim_sqrt_one_minus_alphas',
                             np.sqrt(1. - ddim_alphas))
        # Sigmas to use when stepping over the original (full) DDPM schedule.
        sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt(
            (1 - self.alphas_cumprod_prev) / (1 - self.alphas_cumprod) *
            (1 - self.alphas_cumprod / self.alphas_cumprod_prev))
        self.register_buffer('ddim_sigmas_for_original_num_steps',
                             sigmas_for_original_sampling_steps)
|
||||
|
||||
@torch.no_grad()
|
||||
def sample(
|
||||
self,
|
||||
S,
|
||||
batch_size,
|
||||
shape,
|
||||
conditioning=None,
|
||||
callback=None,
|
||||
normals_sequence=None,
|
||||
img_callback=None,
|
||||
quantize_x0=False,
|
||||
eta=0.,
|
||||
mask=None,
|
||||
x0=None,
|
||||
temperature=1.,
|
||||
noise_dropout=0.,
|
||||
score_corrector=None,
|
||||
corrector_kwargs=None,
|
||||
verbose=True,
|
||||
schedule_verbose=False,
|
||||
x_T=None,
|
||||
log_every_t=100,
|
||||
unconditional_guidance_scale=1.,
|
||||
unconditional_conditioning=None,
|
||||
precision=None,
|
||||
fs=None,
|
||||
timestep_spacing='uniform', #uniform_trailing for starting from last timestep
|
||||
guidance_rescale=0.0,
|
||||
**kwargs):
|
||||
|
||||
# Check condition bs
|
||||
if conditioning is not None:
|
||||
if isinstance(conditioning, dict):
|
||||
try:
|
||||
cbs = conditioning[list(conditioning.keys())[0]].shape[0]
|
||||
except:
|
||||
cbs = conditioning[list(
|
||||
conditioning.keys())[0]][0].shape[0]
|
||||
|
||||
if cbs != batch_size:
|
||||
print(
|
||||
f"Warning: Got {cbs} conditionings but batch-size is {batch_size}"
|
||||
)
|
||||
else:
|
||||
if conditioning.shape[0] != batch_size:
|
||||
print(
|
||||
f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}"
|
||||
)
|
||||
|
||||
self.make_schedule(ddim_num_steps=S,
|
||||
ddim_discretize=timestep_spacing,
|
||||
ddim_eta=eta,
|
||||
verbose=schedule_verbose)
|
||||
|
||||
# Make shape
|
||||
if len(shape) == 3:
|
||||
C, H, W = shape
|
||||
size = (batch_size, C, H, W)
|
||||
elif len(shape) == 4:
|
||||
C, T, H, W = shape
|
||||
size = (batch_size, C, T, H, W)
|
||||
|
||||
samples, actions, states, intermediates = self.ddim_sampling(
|
||||
conditioning,
|
||||
size,
|
||||
callback=callback,
|
||||
img_callback=img_callback,
|
||||
quantize_denoised=quantize_x0,
|
||||
mask=mask,
|
||||
x0=x0,
|
||||
ddim_use_original_steps=False,
|
||||
noise_dropout=noise_dropout,
|
||||
temperature=temperature,
|
||||
score_corrector=score_corrector,
|
||||
corrector_kwargs=corrector_kwargs,
|
||||
x_T=x_T,
|
||||
log_every_t=log_every_t,
|
||||
unconditional_guidance_scale=unconditional_guidance_scale,
|
||||
unconditional_conditioning=unconditional_conditioning,
|
||||
verbose=verbose,
|
||||
precision=precision,
|
||||
fs=fs,
|
||||
guidance_rescale=guidance_rescale,
|
||||
**kwargs)
|
||||
return samples, actions, states, intermediates
|
||||
|
||||
    @torch.no_grad()
    def ddim_sampling(self,
                      cond,
                      shape,
                      x_T=None,
                      ddim_use_original_steps=False,
                      callback=None,
                      timesteps=None,
                      quantize_denoised=False,
                      mask=None,
                      x0=None,
                      img_callback=None,
                      log_every_t=100,
                      temperature=1.,
                      noise_dropout=0.,
                      score_corrector=None,
                      corrector_kwargs=None,
                      unconditional_guidance_scale=1.,
                      unconditional_conditioning=None,
                      verbose=True,
                      precision=None,
                      fs=None,
                      guidance_rescale=0.0,
                      **kwargs):
        """Main DDIM loop: jointly denoises the video latent (`img`) and the
        action/state trajectories, the latter stepped by diffusers-style
        schedulers taken from the model.

        Returns:
            (img, action, state, intermediates) where `intermediates` records
            latents and predictions every `log_every_t` steps.
        """
        device = self.model.betas.device
        # Action/state branches use their own noise schedulers from the model.
        dp_ddim_scheduler_action = self.model.dp_noise_scheduler_action
        dp_ddim_scheduler_state = self.model.dp_noise_scheduler_state

        b = shape[0]
        if x_T is None:
            img = torch.randn(shape, device=device)
            # NOTE(review): the action/state horizon of 16 is hard-coded in
            # both branches — confirm it matches the model's configuration.
            action = torch.randn((b, 16, self.model.agent_action_dim),
                                 device=device)
            state = torch.randn((b, 16, self.model.agent_state_dim),
                                device=device)
        else:
            # A provided x_T only seeds the video latent; action/state always
            # start from fresh Gaussian noise.
            img = x_T
            action = torch.randn((b, 16, self.model.agent_action_dim),
                                 device=device)
            state = torch.randn((b, 16, self.model.agent_state_dim),
                                device=device)

        if precision is not None:
            if precision == 16:
                img = img.to(dtype=torch.float16)
                action = action.to(dtype=torch.float16)
                state = state.to(dtype=torch.float16)

        if timesteps is None:
            timesteps = self.ddpm_num_timesteps if ddim_use_original_steps else self.ddim_timesteps
        elif timesteps is not None and not ddim_use_original_steps:
            # Truncate the precomputed DDIM schedule to the requested count.
            subset_end = int(
                min(timesteps / self.ddim_timesteps.shape[0], 1) *
                self.ddim_timesteps.shape[0]) - 1
            timesteps = self.ddim_timesteps[:subset_end]

        # History of latents/predictions for logging and visualization.
        intermediates = {
            'x_inter': [img],
            'pred_x0': [img],
            'x_inter_action': [action],
            'pred_x0_action': [action],
            'x_inter_state': [state],
            'pred_x0_state': [state],
        }
        # Denoising runs from the highest timestep down to 0.
        time_range = reversed(range(
            0, timesteps)) if ddim_use_original_steps else np.flip(timesteps)
        total_steps = timesteps if ddim_use_original_steps else timesteps.shape[
            0]
        if verbose:
            iterator = tqdm(time_range, desc='DDIM Sampler', total=total_steps)
        else:
            iterator = time_range

        # pop() so the flag is not forwarded again inside p_sample_ddim.
        clean_cond = kwargs.pop("clean_cond", False)

        dp_ddim_scheduler_action.set_timesteps(len(timesteps))
        dp_ddim_scheduler_state.set_timesteps(len(timesteps))
        for i, step in enumerate(iterator):
            # `index` counts down and is used to address the DDIM coefficient
            # tables inside p_sample_ddim.
            index = total_steps - i - 1
            ts = torch.full((b, ), step, device=device, dtype=torch.long)

            # Use mask to blend noised original latent (img_orig) & new sampled latent (img)
            if mask is not None:
                assert x0 is not None
                if clean_cond:
                    img_orig = x0
                else:
                    img_orig = self.model.q_sample(x0, ts)
                img = img_orig * mask + (1. - mask) * img

            outs = self.p_sample_ddim(
                img,
                action,
                state,
                cond,
                ts,
                index=index,
                use_original_steps=ddim_use_original_steps,
                quantize_denoised=quantize_denoised,
                temperature=temperature,
                noise_dropout=noise_dropout,
                score_corrector=score_corrector,
                corrector_kwargs=corrector_kwargs,
                unconditional_guidance_scale=unconditional_guidance_scale,
                unconditional_conditioning=unconditional_conditioning,
                mask=mask,
                x0=x0,
                fs=fs,
                guidance_rescale=guidance_rescale,
                **kwargs)

            img, pred_x0, model_output_action, model_output_state = outs

            # Step the action/state branches with their own DDIM schedulers,
            # using the raw model outputs from p_sample_ddim.
            action = dp_ddim_scheduler_action.step(
                model_output_action,
                step,
                action,
                generator=None,
            ).prev_sample
            state = dp_ddim_scheduler_state.step(
                model_output_state,
                step,
                state,
                generator=None,
            ).prev_sample

            if callback: callback(i)
            if img_callback: img_callback(pred_x0, i)

            if index % log_every_t == 0 or index == total_steps - 1:
                intermediates['x_inter'].append(img)
                intermediates['pred_x0'].append(pred_x0)
                intermediates['x_inter_action'].append(action)
                intermediates['x_inter_state'].append(state)

        return img, action, state, intermediates
|
||||
|
||||
    @torch.no_grad()
    def p_sample_ddim(self,
                      x,
                      x_action,
                      x_state,
                      c,
                      t,
                      index,
                      repeat_noise=False,
                      use_original_steps=False,
                      quantize_denoised=False,
                      temperature=1.,
                      noise_dropout=0.,
                      score_corrector=None,
                      corrector_kwargs=None,
                      unconditional_guidance_scale=1.,
                      unconditional_conditioning=None,
                      uc_type=None,
                      conditional_guidance_scale_temporal=None,
                      mask=None,
                      x0=None,
                      guidance_rescale=0.0,
                      **kwargs):
        """One DDIM denoising step for the video latent, plus a joint forward
        pass that also produces the raw action/state model outputs.

        Returns:
            (x_prev, pred_x0, model_output_action, model_output_state):
            the updated latent, the predicted clean latent, and the raw
            action/state outputs (stepped by the caller's schedulers).
        """
        b, *_, device = *x.shape, x.device
        # 5-D latents are video: (B, C, T, H, W).
        if x.dim() == 5:
            is_video = True
        else:
            is_video = False

        if unconditional_conditioning is None or unconditional_guidance_scale == 1.:
            # No classifier-free guidance: a single conditional forward pass.
            model_output, model_output_action, model_output_state = self.model.apply_model(
                x, x_action, x_state, t, c, **kwargs)  # unet denoiser
        else:
            # do_classifier_free_guidance
            if isinstance(c, torch.Tensor) or isinstance(c, dict):
                e_t_cond, e_t_cond_action, e_t_cond_state = self.model.apply_model(
                    x, x_action, x_state, t, c, **kwargs)
                e_t_uncond, e_t_uncond_action, e_t_uncond_state = self.model.apply_model(
                    x, x_action, x_state, t, unconditional_conditioning,
                    **kwargs)
            else:
                raise NotImplementedError
            # Standard CFG combination: uncond + scale * (cond - uncond).
            model_output = e_t_uncond + unconditional_guidance_scale * (
                e_t_cond - e_t_uncond)
            model_output_action = e_t_uncond_action + unconditional_guidance_scale * (
                e_t_cond_action - e_t_uncond_action)
            model_output_state = e_t_uncond_state + unconditional_guidance_scale * (
                e_t_cond_state - e_t_uncond_state)

            if guidance_rescale > 0.0:
                # Rescale guided outputs toward the conditional statistics to
                # counteract CFG over-exposure.
                model_output = rescale_noise_cfg(
                    model_output, e_t_cond, guidance_rescale=guidance_rescale)
                model_output_action = rescale_noise_cfg(
                    model_output_action,
                    e_t_cond_action,
                    guidance_rescale=guidance_rescale)
                model_output_state = rescale_noise_cfg(
                    model_output_state,
                    e_t_cond_state,
                    guidance_rescale=guidance_rescale)

        # Convert a v-prediction into an epsilon prediction when needed.
        if self.model.parameterization == "v":
            e_t = self.model.predict_eps_from_z_and_v(x, t, model_output)
        else:
            e_t = model_output

        if score_corrector is not None:
            assert self.model.parameterization == "eps", 'not implemented'
            e_t = score_corrector.modify_score(self.model, e_t, x, t, c,
                                               **corrector_kwargs)

        # Select DDIM coefficient tables (or the full DDPM ones).
        alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas
        alphas_prev = self.model.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prev
        sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas
        sigmas = self.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas

        # Broadcast shape for per-sample scalar coefficients.
        if is_video:
            size = (b, 1, 1, 1, 1)
        else:
            size = (b, 1, 1, 1)

        a_t = torch.full(size, alphas[index], device=device)
        a_prev = torch.full(size, alphas_prev[index], device=device)
        sigma_t = torch.full(size, sigmas[index], device=device)
        sqrt_one_minus_at = torch.full(size,
                                       sqrt_one_minus_alphas[index],
                                       device=device)

        # Predict the clean latent x0 from the current latent and e_t.
        if self.model.parameterization != "v":
            pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
        else:
            pred_x0 = self.model.predict_start_from_z_and_v(x, t, model_output)

        if self.model.use_dynamic_rescale:
            scale_t = torch.full(size,
                                 self.ddim_scale_arr[index],
                                 device=device)
            prev_scale_t = torch.full(size,
                                      self.ddim_scale_arr_prev[index],
                                      device=device)
            rescale = (prev_scale_t / scale_t)
            pred_x0 *= rescale

        if quantize_denoised:
            pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0)

        # "Direction pointing to x_t" term of the DDIM update.
        dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t

        # Stochastic component (zero when eta == 0 so sigma_t == 0).
        noise = sigma_t * noise_like(x.shape, device,
                                     repeat_noise) * temperature
        if noise_dropout > 0.:
            noise = torch.nn.functional.dropout(noise, p=noise_dropout)

        x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise

        return x_prev, pred_x0, model_output_action, model_output_state
|
||||
|
||||
    @torch.no_grad()
    def decode(self,
               x_latent,
               cond,
               t_start,
               unconditional_guidance_scale=1.0,
               unconditional_conditioning=None,
               use_original_steps=False,
               callback=None):
        """Denoise a (partially noised) latent from step `t_start` down to x0.

        NOTE(review): this method looks stale relative to `p_sample_ddim` in
        this file, which takes (x, x_action, x_state, c, t, index=...) and
        returns four values; the call below passes (x_dec, cond, ts) and
        unpacks only two, so invoking `decode` would raise. Confirm whether
        this path is dead code or needs updating for the action/state API.
        """
        timesteps = np.arange(self.ddpm_num_timesteps
                              ) if use_original_steps else self.ddim_timesteps
        # Only denoise the first `t_start` steps of the schedule.
        timesteps = timesteps[:t_start]

        time_range = np.flip(timesteps)
        total_steps = timesteps.shape[0]
        print(f"Running DDIM Sampling with {total_steps} timesteps")

        iterator = tqdm(time_range, desc='Decoding image', total=total_steps)
        x_dec = x_latent
        for i, step in enumerate(iterator):
            index = total_steps - i - 1
            ts = torch.full((x_latent.shape[0], ),
                            step,
                            device=x_latent.device,
                            dtype=torch.long)
            x_dec, _ = self.p_sample_ddim(
                x_dec,
                cond,
                ts,
                index=index,
                use_original_steps=use_original_steps,
                unconditional_guidance_scale=unconditional_guidance_scale,
                unconditional_conditioning=unconditional_conditioning)
            if callback: callback(i)
        return x_dec
|
||||
|
||||
@torch.no_grad()
|
||||
def stochastic_encode(self, x0, t, use_original_steps=False, noise=None):
|
||||
# fast, but does not allow for exact reconstruction
|
||||
if use_original_steps:
|
||||
sqrt_alphas_cumprod = self.sqrt_alphas_cumprod
|
||||
sqrt_one_minus_alphas_cumprod = self.sqrt_one_minus_alphas_cumprod
|
||||
else:
|
||||
sqrt_alphas_cumprod = torch.sqrt(self.ddim_alphas)
|
||||
sqrt_one_minus_alphas_cumprod = self.ddim_sqrt_one_minus_alphas
|
||||
|
||||
if noise is None:
|
||||
noise = torch.randn_like(x0)
|
||||
return (
|
||||
extract_into_tensor(sqrt_alphas_cumprod, t, x0.shape) * x0 +
|
||||
extract_into_tensor(sqrt_one_minus_alphas_cumprod, t, x0.shape) *
|
||||
noise)
|
||||
Reference in New Issue
Block a user