First complete test case run finished

2026-01-18 00:30:10 +08:00
parent ca15cc593b
commit 25c6fc04db
180 changed files with 29305 additions and 0 deletions

View File

@@ -0,0 +1,217 @@
"""
Contains torch Modules that correspond to basic network building blocks, like
MLP, RNN, and CNN backbones.
"""
import abc
import numpy as np
import torch
import torch.nn.functional as F
class Module(torch.nn.Module):
"""
Base class for networks. The only difference from torch.nn.Module is that it
requires implementing @output_shape.
"""
@abc.abstractmethod
def output_shape(self, input_shape=None):
"""
Function to compute output shape from inputs to this module.
Args:
input_shape (iterable of int): shape of input. Does not include batch dimension.
Some modules may not need this argument, if their output does not depend
on the size of the input, or if they assume fixed size input.
Returns:
out_shape ([int]): list of integers corresponding to output shape
"""
raise NotImplementedError
"""
================================================
Visual Backbone Networks
================================================
"""
class ConvBase(Module):
"""
Base class for ConvNets.
"""
def __init__(self):
super(ConvBase, self).__init__()
# dirty hack - re-implement to pass the buck onto subclasses from ABC parent
def output_shape(self, input_shape):
"""
Function to compute output shape from inputs to this module.
Args:
input_shape (iterable of int): shape of input. Does not include batch dimension.
Some modules may not need this argument, if their output does not depend
on the size of the input, or if they assume fixed size input.
Returns:
out_shape ([int]): list of integers corresponding to output shape
"""
raise NotImplementedError
def forward(self, inputs):
x = self.nets(inputs)
if list(self.output_shape(list(inputs.shape)[1:])) != list(
x.shape)[1:]:
raise ValueError('Size mismatch: expect size %s, but got size %s' %
(str(self.output_shape(list(
inputs.shape)[1:])), str(list(x.shape)[1:])))
return x
class SpatialSoftmax(ConvBase):
"""
Spatial Softmax Layer.
Based on Deep Spatial Autoencoders for Visuomotor Learning by Finn et al.
https://rll.berkeley.edu/dsae/dsae.pdf
"""
def __init__(
self,
input_shape,
num_kp=32,
temperature=1.,
learnable_temperature=False,
output_variance=False,
noise_std=0.0,
):
"""
Args:
input_shape (list): shape of the input feature (C, H, W)
num_kp (int): number of keypoints (None to skip the 1x1 conv and use the input channels directly)
temperature (float): temperature term for the softmax.
learnable_temperature (bool): whether to learn the temperature
output_variance (bool): treat attention as a distribution, and compute second-order statistics to return
noise_std (float): add random spatial noise to the predicted keypoints
"""
super(SpatialSoftmax, self).__init__()
assert len(input_shape) == 3
self._in_c, self._in_h, self._in_w = input_shape # (C, H, W)
if num_kp is not None:
self.nets = torch.nn.Conv2d(self._in_c, num_kp, kernel_size=1)
self._num_kp = num_kp
else:
self.nets = None
self._num_kp = self._in_c
self.learnable_temperature = learnable_temperature
self.output_variance = output_variance
self.noise_std = noise_std
if self.learnable_temperature:
# temperature will be learned
temperature = torch.nn.Parameter(torch.ones(1) * temperature,
requires_grad=True)
self.register_parameter('temperature', temperature)
else:
# temperature held constant after initialization
temperature = torch.nn.Parameter(torch.ones(1) * temperature,
requires_grad=False)
self.register_buffer('temperature', temperature)
pos_x, pos_y = np.meshgrid(np.linspace(-1., 1., self._in_w),
np.linspace(-1., 1., self._in_h))
pos_x = torch.from_numpy(pos_x.reshape(1, self._in_h *
self._in_w)).float()
pos_y = torch.from_numpy(pos_y.reshape(1, self._in_h *
self._in_w)).float()
self.register_buffer('pos_x', pos_x)
self.register_buffer('pos_y', pos_y)
self.kps = None
def __repr__(self):
"""Pretty print network."""
header = format(str(self.__class__.__name__))
return header + '(num_kp={}, temperature={}, noise={})'.format(
self._num_kp, self.temperature.item(), self.noise_std)
def output_shape(self, input_shape):
"""
Function to compute output shape from inputs to this module.
Args:
input_shape (iterable of int): shape of input. Does not include batch dimension.
Some modules may not need this argument, if their output does not depend
on the size of the input, or if they assume fixed size input.
Returns:
out_shape ([int]): list of integers corresponding to output shape
"""
assert (len(input_shape) == 3)
assert (input_shape[0] == self._in_c)
return [self._num_kp, 2]
def forward(self, feature):
"""
Forward pass through spatial softmax layer. For each keypoint, a 2D spatial
probability distribution is created using a softmax, where the support is the
pixel locations. This distribution is used to compute the expected value of
the pixel location, which becomes a keypoint of dimension 2. K such keypoints
are created.
Returns:
out (torch.Tensor or tuple): mean keypoints of shape [B, K, 2], and possibly
keypoint variance of shape [B, K, 2, 2] corresponding to the covariance
under the 2D spatial softmax distribution
"""
assert (feature.shape[1] == self._in_c)
assert (feature.shape[2] == self._in_h)
assert (feature.shape[3] == self._in_w)
if self.nets is not None:
feature = self.nets(feature)
# [B, K, H, W] -> [B * K, H * W] where K is number of keypoints
feature = feature.reshape(-1, self._in_h * self._in_w)
# 2d softmax normalization
attention = F.softmax(feature / self.temperature, dim=-1)
# [1, H * W] x [B * K, H * W] -> [B * K, 1] for spatial coordinate mean in x and y dimensions
expected_x = torch.sum(self.pos_x * attention, dim=1, keepdim=True)
expected_y = torch.sum(self.pos_y * attention, dim=1, keepdim=True)
# stack to [B * K, 2]
expected_xy = torch.cat([expected_x, expected_y], 1)
# reshape to [B, K, 2]
feature_keypoints = expected_xy.view(-1, self._num_kp, 2)
if self.training:
noise = torch.randn_like(feature_keypoints) * self.noise_std
feature_keypoints += noise
if self.output_variance:
# treat attention as a distribution, and compute second-order statistics to return
expected_xx = torch.sum(self.pos_x * self.pos_x * attention,
dim=1,
keepdim=True)
expected_yy = torch.sum(self.pos_y * self.pos_y * attention,
dim=1,
keepdim=True)
expected_xy = torch.sum(self.pos_x * self.pos_y * attention,
dim=1,
keepdim=True)
var_x = expected_xx - expected_x * expected_x
var_y = expected_yy - expected_y * expected_y
var_xy = expected_xy - expected_x * expected_y
# stack to [B * K, 4] and then reshape to [B, K, 2, 2] where last 2 dims are covariance matrix
feature_covar = torch.cat([var_x, var_xy, var_xy, var_y],
1).reshape(-1, self._num_kp, 2, 2)
feature_keypoints = (feature_keypoints, feature_covar)
if isinstance(feature_keypoints, tuple):
self.kps = (feature_keypoints[0].detach(),
feature_keypoints[1].detach())
else:
self.kps = feature_keypoints.detach()
return feature_keypoints
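# A minimal usage sketch (an illustrative assumption, not part of the file above): feed a
# random feature map through SpatialSoftmax and check that K keypoints of dimension 2 come out.
def _spatial_softmax_example():
    net = SpatialSoftmax(input_shape=[16, 32, 32], num_kp=8)
    feat = torch.randn(4, 16, 32, 32)  # [B, C, H, W]
    kp = net(feat)  # expected keypoints, shape [B, K, 2]
    assert list(kp.shape) == [4, 8, 2]
    return kp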

View File

@@ -0,0 +1,83 @@
from diffusers.optimization import (Union, SchedulerType, Optional, Optimizer,
TYPE_TO_SCHEDULER_FUNCTION)
def get_scheduler(name: Union[str, SchedulerType],
optimizer: Optimizer,
num_warmup_steps: Optional[int] = None,
num_training_steps: Optional[int] = None,
**kwargs):
"""
Adds **kwargs support on top of diffusers' original implementation.
Unified API to get any scheduler from its name.
Args:
name (`str` or `SchedulerType`):
The name of the scheduler to use.
optimizer (`torch.optim.Optimizer`):
The optimizer that will be used during training.
num_warmup_steps (`int`, *optional*):
The number of warmup steps to do. This is not required by all schedulers (hence the argument being
optional), the function will raise an error if it's unset and the scheduler type requires it.
num_training_steps (`int`, *optional*):
The number of training steps to do. This is not required by all schedulers (hence the argument being
optional), the function will raise an error if it's unset and the scheduler type requires it.
"""
name = SchedulerType(name)
schedule_func = TYPE_TO_SCHEDULER_FUNCTION[name]
if name == SchedulerType.CONSTANT:
return schedule_func(optimizer, **kwargs)
# All other schedulers require `num_warmup_steps`
if num_warmup_steps is None:
raise ValueError(
f"{name} requires `num_warmup_steps`, please provide that argument."
)
if name == SchedulerType.CONSTANT_WITH_WARMUP:
return schedule_func(optimizer,
num_warmup_steps=num_warmup_steps,
**kwargs)
# All other schedulers require `num_training_steps`
if num_training_steps is None:
raise ValueError(
f"{name} requires `num_training_steps`, please provide that argument."
)
return schedule_func(optimizer,
num_warmup_steps=num_warmup_steps,
num_training_steps=num_training_steps,
**kwargs)
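# A minimal usage sketch (an illustrative assumption): request a cosine schedule with
# warmup for an existing optimizer through the wrapper above.
def _get_scheduler_example(optimizer: Optimizer):
    return get_scheduler("cosine",
                         optimizer=optimizer,
                         num_warmup_steps=500,
                         num_training_steps=10000)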
import torch
from torch.optim.lr_scheduler import _LRScheduler
import pytorch_lightning as pl
from diffusers.optimization import TYPE_TO_SCHEDULER_FUNCTION, SchedulerType
class SelectiveLRScheduler(_LRScheduler):
def __init__(self,
optimizer,
base_scheduler,
group_indices,
default_lr=[1e-5, 1e-4],
last_epoch=-1):
self.base_scheduler = base_scheduler
self.group_indices = group_indices # Indices of parameter groups to update
self.default_lr = default_lr
super().__init__(optimizer, last_epoch)
def step(self, epoch=None):
self.base_scheduler.step()
base_lrs = self.base_scheduler.get_last_lr()
for idx, group in enumerate(self.optimizer.param_groups):
if idx in self.group_indices:
group['lr'] = base_lrs[idx]
else:
# Reset the learning rate to its initial value
group['lr'] = self.default_lr[idx]

View File

@@ -0,0 +1,16 @@
import torch.nn as nn
class ModuleAttrMixin(nn.Module):
def __init__(self):
super().__init__()
self._dummy_variable = nn.Parameter()
@property
def device(self):
return next(iter(self.parameters())).device
@property
def dtype(self):
return next(iter(self.parameters())).dtype
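# A minimal usage sketch (an illustrative assumption): any subclass of ModuleAttrMixin
# can report the device/dtype of its parameters without extra bookkeeping.
class _TinyNet(ModuleAttrMixin):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(3, 3)

# _TinyNet().device evaluates to the device of the module's parameters (e.g. cpu).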

View File

@@ -0,0 +1,91 @@
import collections
import torch
import torch.nn as nn
from typing import Dict, Callable, List
def dict_apply(
x: Dict[str, torch.Tensor],
func: Callable[[torch.Tensor],
torch.Tensor]) -> Dict[str, torch.Tensor]:
result = dict()
for key, value in x.items():
if isinstance(value, dict):
result[key] = dict_apply(value, func)
else:
result[key] = func(value)
return result
def pad_remaining_dims(x, target):
assert x.shape == target.shape[:len(x.shape)]
return x.reshape(x.shape + (1, ) * (len(target.shape) - len(x.shape)))
def dict_apply_split(
x: Dict[str, torch.Tensor], split_func: Callable[[torch.Tensor],
Dict[str, torch.Tensor]]
) -> Dict[str, torch.Tensor]:
results = collections.defaultdict(dict)
for key, value in x.items():
result = split_func(value)
for k, v in result.items():
results[k][key] = v
return results
def dict_apply_reduce(
x: List[Dict[str,
torch.Tensor]], reduce_func: Callable[[List[torch.Tensor]],
torch.Tensor]
) -> Dict[str, torch.Tensor]:
result = dict()
for key in x[0].keys():
result[key] = reduce_func([x_[key] for x_ in x])
return result
def replace_submodules(root_module: nn.Module, predicate: Callable[[nn.Module],
bool],
func: Callable[[nn.Module], nn.Module]) -> nn.Module:
"""
predicate: Return true if the module is to be replaced.
func: Return new module to use.
"""
if predicate(root_module):
return func(root_module)
bn_list = [
k.split('.')
for k, m in root_module.named_modules(remove_duplicate=True)
if predicate(m)
]
for *parent, k in bn_list:
parent_module = root_module
if len(parent) > 0:
parent_module = root_module.get_submodule('.'.join(parent))
if isinstance(parent_module, nn.Sequential):
src_module = parent_module[int(k)]
else:
src_module = getattr(parent_module, k)
tgt_module = func(src_module)
if isinstance(parent_module, nn.Sequential):
parent_module[int(k)] = tgt_module
else:
setattr(parent_module, k, tgt_module)
# verify that all BN are replaced
bn_list = [
k.split('.')
for k, m in root_module.named_modules(remove_duplicate=True)
if predicate(m)
]
assert len(bn_list) == 0
return root_module
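# A minimal usage sketch (an illustrative assumption): one common use of replace_submodules
# is swapping every BatchNorm2d in a vision backbone for GroupNorm.
def _replace_bn_with_gn(backbone: nn.Module, features_per_group: int = 16) -> nn.Module:
    return replace_submodules(
        root_module=backbone,
        predicate=lambda m: isinstance(m, nn.BatchNorm2d),
        func=lambda m: nn.GroupNorm(num_groups=m.num_features // features_per_group,
                                    num_channels=m.num_features))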
def optimizer_to(optimizer, device):
for state in optimizer.state.values():
for k, v in state.items():
if isinstance(v, torch.Tensor):
state[k] = v.to(device=device)
return optimizer

View File

@@ -0,0 +1,960 @@
"""
A collection of utilities for working with nested tensor structures consisting
of numpy arrays and torch tensors.
"""
import collections
import numpy as np
import torch
def recursive_dict_list_tuple_apply(x, type_func_dict):
"""
Recursively apply functions to a nested dictionary or list or tuple, given a dictionary of
{data_type: function_to_apply}.
Args:
x (dict or list or tuple): a possibly nested dictionary or list or tuple
type_func_dict (dict): a mapping from data types to the functions to be
applied for each data type.
Returns:
y (dict or list or tuple): new nested dict-list-tuple
"""
assert (list not in type_func_dict)
assert (tuple not in type_func_dict)
assert (dict not in type_func_dict)
if isinstance(x, (dict, collections.OrderedDict)):
new_x = collections.OrderedDict() if isinstance(
x, collections.OrderedDict) else dict()
for k, v in x.items():
new_x[k] = recursive_dict_list_tuple_apply(v, type_func_dict)
return new_x
elif isinstance(x, (list, tuple)):
ret = [recursive_dict_list_tuple_apply(v, type_func_dict) for v in x]
if isinstance(x, tuple):
ret = tuple(ret)
return ret
else:
for t, f in type_func_dict.items():
if isinstance(x, t):
return f(x)
else:
raise NotImplementedError('Cannot handle data type %s' %
str(type(x)))
def map_tensor(x, func):
"""
Apply function @func to torch.Tensor objects in a nested dictionary or
list or tuple.
Args:
x (dict or list or tuple): a possibly nested dictionary or list or tuple
func (function): function to apply to each tensor
Returns:
y (dict or list or tuple): new nested dict-list-tuple
"""
return recursive_dict_list_tuple_apply(x, {
torch.Tensor: func,
type(None): lambda x: x,
})
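# A minimal usage sketch (an illustrative assumption): detach and move every tensor in a
# nested observation dict with a single map_tensor call.
def _map_tensor_example():
    obs = {"image": torch.zeros(2, 3, 84, 84), "state": {"qpos": torch.ones(2, 7)}}
    return map_tensor(obs, lambda t: t.detach().cpu())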
def map_ndarray(x, func):
"""
Apply function @func to np.ndarray objects in a nested dictionary or
list or tuple.
Args:
x (dict or list or tuple): a possibly nested dictionary or list or tuple
func (function): function to apply to each array
Returns:
y (dict or list or tuple): new nested dict-list-tuple
"""
return recursive_dict_list_tuple_apply(x, {
np.ndarray: func,
type(None): lambda x: x,
})
def map_tensor_ndarray(x, tensor_func, ndarray_func):
"""
Apply function @tensor_func to torch.Tensor objects and @ndarray_func to
np.ndarray objects in a nested dictionary or list or tuple.
Args:
x (dict or list or tuple): a possibly nested dictionary or list or tuple
tensor_func (function): function to apply to each tensor
ndarray_func (function): function to apply to each array
Returns:
y (dict or list or tuple): new nested dict-list-tuple
"""
return recursive_dict_list_tuple_apply(
x, {
torch.Tensor: tensor_func,
np.ndarray: ndarray_func,
type(None): lambda x: x,
})
def clone(x):
"""
Clones all torch tensors and numpy arrays in nested dictionary or list
or tuple and returns a new nested structure.
Args:
x (dict or list or tuple): a possibly nested dictionary or list or tuple
Returns:
y (dict or list or tuple): new nested dict-list-tuple
"""
return recursive_dict_list_tuple_apply(
x, {
torch.Tensor: lambda x: x.clone(),
np.ndarray: lambda x: x.copy(),
type(None): lambda x: x,
})
def detach(x):
"""
Detaches all torch tensors in nested dictionary or list
or tuple and returns a new nested structure.
Args:
x (dict or list or tuple): a possibly nested dictionary or list or tuple
Returns:
y (dict or list or tuple): new nested dict-list-tuple
"""
return recursive_dict_list_tuple_apply(x, {
torch.Tensor: lambda x: x.detach(),
})
def to_batch(x):
"""
Introduces a leading batch dimension of 1 for all torch tensors and numpy
arrays in nested dictionary or list or tuple and returns a new nested structure.
Args:
x (dict or list or tuple): a possibly nested dictionary or list or tuple
Returns:
y (dict or list or tuple): new nested dict-list-tuple
"""
return recursive_dict_list_tuple_apply(
x, {
torch.Tensor: lambda x: x[None, ...],
np.ndarray: lambda x: x[None, ...],
type(None): lambda x: x,
})
def to_sequence(x):
"""
Introduces a time dimension of 1 at dimension 1 for all torch tensors and numpy
arrays in nested dictionary or list or tuple and returns a new nested structure.
Args:
x (dict or list or tuple): a possibly nested dictionary or list or tuple
Returns:
y (dict or list or tuple): new nested dict-list-tuple
"""
return recursive_dict_list_tuple_apply(
x, {
torch.Tensor: lambda x: x[:, None, ...],
np.ndarray: lambda x: x[:, None, ...],
type(None): lambda x: x,
})
def index_at_time(x, ind):
"""
Indexes all torch tensors and numpy arrays in dimension 1 with index @ind in
nested dictionary or list or tuple and returns a new nested structure.
Args:
x (dict or list or tuple): a possibly nested dictionary or list or tuple
ind (int): index
Returns:
y (dict or list or tuple): new nested dict-list-tuple
"""
return recursive_dict_list_tuple_apply(
x, {
torch.Tensor: lambda x: x[:, ind, ...],
np.ndarray: lambda x: x[:, ind, ...],
type(None): lambda x: x,
})
def unsqueeze(x, dim):
"""
Adds dimension of size 1 at dimension @dim in all torch tensors and numpy arrays
in nested dictionary or list or tuple and returns a new nested structure.
Args:
x (dict or list or tuple): a possibly nested dictionary or list or tuple
dim (int): dimension
Returns:
y (dict or list or tuple): new nested dict-list-tuple
"""
return recursive_dict_list_tuple_apply(
x, {
torch.Tensor: lambda x: x.unsqueeze(dim=dim),
np.ndarray: lambda x: np.expand_dims(x, axis=dim),
type(None): lambda x: x,
})
def contiguous(x):
"""
Makes all torch tensors and numpy arrays contiguous in nested dictionary or
list or tuple and returns a new nested structure.
Args:
x (dict or list or tuple): a possibly nested dictionary or list or tuple
Returns:
y (dict or list or tuple): new nested dict-list-tuple
"""
return recursive_dict_list_tuple_apply(
x, {
torch.Tensor: lambda x: x.contiguous(),
np.ndarray: lambda x: np.ascontiguousarray(x),
type(None): lambda x: x,
})
def to_device(x, device):
"""
Sends all torch tensors in nested dictionary or list or tuple to device
@device, and returns a new nested structure.
Args:
x (dict or list or tuple): a possibly nested dictionary or list or tuple
device (torch.Device): device to send tensors to
Returns:
y (dict or list or tuple): new nested dict-list-tuple
"""
return recursive_dict_list_tuple_apply(
x, {
torch.Tensor: lambda x, d=device: x.to(d),
type(None): lambda x: x,
})
def to_tensor(x):
"""
Converts all numpy arrays in nested dictionary or list or tuple to
torch tensors (and leaves existing torch Tensors as-is), and returns
a new nested structure.
Args:
x (dict or list or tuple): a possibly nested dictionary or list or tuple
Returns:
y (dict or list or tuple): new nested dict-list-tuple
"""
return recursive_dict_list_tuple_apply(
x, {
torch.Tensor: lambda x: x,
np.ndarray: lambda x: torch.from_numpy(x),
type(None): lambda x: x,
})
def to_numpy(x):
"""
Converts all torch tensors in nested dictionary or list or tuple to
numpy (and leaves existing numpy arrays as-is), and returns
a new nested structure.
Args:
x (dict or list or tuple): a possibly nested dictionary or list or tuple
Returns:
y (dict or list or tuple): new nested dict-list-tuple
"""
def f(tensor):
if tensor.is_cuda:
return tensor.detach().cpu().numpy()
else:
return tensor.detach().numpy()
return recursive_dict_list_tuple_apply(x, {
torch.Tensor: f,
np.ndarray: lambda x: x,
type(None): lambda x: x,
})
def to_list(x):
"""
Converts all torch tensors and numpy arrays in nested dictionary or list
or tuple to a list, and returns a new nested structure. Useful for
json encoding.
Args:
x (dict or list or tuple): a possibly nested dictionary or list or tuple
Returns:
y (dict or list or tuple): new nested dict-list-tuple
"""
def f(tensor):
if tensor.is_cuda:
return tensor.detach().cpu().numpy().tolist()
else:
return tensor.detach().numpy().tolist()
return recursive_dict_list_tuple_apply(
x, {
torch.Tensor: f,
np.ndarray: lambda x: x.tolist(),
type(None): lambda x: x,
})
def to_float(x):
"""
Converts all torch tensors and numpy arrays in nested dictionary or list
or tuple to float type entries, and returns a new nested structure.
Args:
x (dict or list or tuple): a possibly nested dictionary or list or tuple
Returns:
y (dict or list or tuple): new nested dict-list-tuple
"""
return recursive_dict_list_tuple_apply(
x, {
torch.Tensor: lambda x: x.float(),
np.ndarray: lambda x: x.astype(np.float32),
type(None): lambda x: x,
})
def to_uint8(x):
"""
Converts all torch tensors and numpy arrays in nested dictionary or list
or tuple to uint8 type entries, and returns a new nested structure.
Args:
x (dict or list or tuple): a possibly nested dictionary or list or tuple
Returns:
y (dict or list or tuple): new nested dict-list-tuple
"""
return recursive_dict_list_tuple_apply(
x, {
torch.Tensor: lambda x: x.byte(),
np.ndarray: lambda x: x.astype(np.uint8),
type(None): lambda x: x,
})
def to_torch(x, device):
"""
Converts all numpy arrays and torch tensors in nested dictionary or list or tuple to
torch tensors on device @device and returns a new nested structure.
Args:
x (dict or list or tuple): a possibly nested dictionary or list or tuple
device (torch.Device): device to send tensors to
Returns:
y (dict or list or tuple): new nested dict-list-tuple
"""
return to_device(to_float(to_tensor(x)), device)
def to_one_hot_single(tensor, num_class):
"""
Convert tensor to one-hot representation, assuming a certain number of total class labels.
Args:
tensor (torch.Tensor): tensor containing integer labels
num_class (int): number of classes
Returns:
x (torch.Tensor): tensor containing one-hot representation of labels
"""
x = torch.zeros(tensor.size() + (num_class, )).to(tensor.device)
x.scatter_(-1, tensor.unsqueeze(-1), 1)
return x
def to_one_hot(tensor, num_class):
"""
Convert all tensors in nested dictionary or list or tuple to one-hot representation,
assuming a certain number of total class labels.
Args:
tensor (dict or list or tuple): a possibly nested dictionary or list or tuple
num_class (int): number of classes
Returns:
y (dict or list or tuple): new nested dict-list-tuple
"""
return map_tensor(tensor,
func=lambda x, nc=num_class: to_one_hot_single(x, nc))
def flatten_single(x, begin_axis=1):
"""
Flatten a tensor in all dimensions from @begin_axis onwards.
Args:
x (torch.Tensor): tensor to flatten
begin_axis (int): which axis to flatten from
Returns:
y (torch.Tensor): flattened tensor
"""
fixed_size = x.size()[:begin_axis]
_s = list(fixed_size) + [-1]
return x.reshape(*_s)
def flatten(x, begin_axis=1):
"""
Flatten all tensors in nested dictionary or list or tuple, from @begin_axis onwards.
Args:
x (dict or list or tuple): a possibly nested dictionary or list or tuple
begin_axis (int): which axis to flatten from
Returns:
y (dict or list or tuple): new nested dict-list-tuple
"""
return recursive_dict_list_tuple_apply(x, {
torch.Tensor:
lambda x, b=begin_axis: flatten_single(x, begin_axis=b),
})
def reshape_dimensions_single(x, begin_axis, end_axis, target_dims):
"""
Reshape selected dimensions in a tensor to a target dimension.
Args:
x (torch.Tensor): tensor to reshape
begin_axis (int): begin dimension
end_axis (int): end dimension
target_dims (tuple or list): target shape for the range of dimensions
(@begin_axis, @end_axis)
Returns:
y (torch.Tensor): reshaped tensor
"""
assert (begin_axis <= end_axis)
assert (begin_axis >= 0)
assert (end_axis < len(x.shape))
assert (isinstance(target_dims, (tuple, list)))
s = x.shape
final_s = []
for i in range(len(s)):
if i == begin_axis:
final_s.extend(target_dims)
elif i < begin_axis or i > end_axis:
final_s.append(s[i])
return x.reshape(*final_s)
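# A minimal usage sketch (an illustrative assumption): fold the leading [B, T] dimensions
# of a [B, T, D] tensor into a single [B * T] dimension.
def _reshape_dimensions_example():
    x = torch.zeros(2, 5, 8)
    y = reshape_dimensions_single(x, begin_axis=0, end_axis=1, target_dims=[-1])
    assert list(y.shape) == [10, 8]
    return y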
def reshape_dimensions(x, begin_axis, end_axis, target_dims):
"""
Reshape selected dimensions for all tensors in nested dictionary or list or tuple
to a target dimension.
Args:
x (dict or list or tuple): a possibly nested dictionary or list or tuple
begin_axis (int): begin dimension
end_axis (int): end dimension
target_dims (tuple or list): target shape for the range of dimensions
(@begin_axis, @end_axis)
Returns:
y (dict or list or tuple): new nested dict-list-tuple
"""
return recursive_dict_list_tuple_apply(
x, {
torch.Tensor:
lambda x, b=begin_axis, e=end_axis, t=target_dims:
reshape_dimensions_single(
x, begin_axis=b, end_axis=e, target_dims=t),
np.ndarray:
lambda x, b=begin_axis, e=end_axis, t=target_dims:
reshape_dimensions_single(
x, begin_axis=b, end_axis=e, target_dims=t),
type(None):
lambda x: x,
})
def join_dimensions(x, begin_axis, end_axis):
"""
Joins all dimensions between dimensions (@begin_axis, @end_axis) into a flat dimension, for
all tensors in nested dictionary or list or tuple.
Args:
x (dict or list or tuple): a possibly nested dictionary or list or tuple
begin_axis (int): begin dimension
end_axis (int): end dimension
Returns:
y (dict or list or tuple): new nested dict-list-tuple
"""
return recursive_dict_list_tuple_apply(
x, {
torch.Tensor:
lambda x, b=begin_axis, e=end_axis: reshape_dimensions_single(
x, begin_axis=b, end_axis=e, target_dims=[-1]),
np.ndarray:
lambda x, b=begin_axis, e=end_axis: reshape_dimensions_single(
x, begin_axis=b, end_axis=e, target_dims=[-1]),
type(None):
lambda x: x,
})
def expand_at_single(x, size, dim):
"""
Expand a tensor at a single dimension @dim by @size
Args:
x (torch.Tensor): input tensor
size (int): size to expand
dim (int): dimension to expand
Returns:
y (torch.Tensor): expanded tensor
"""
assert dim < x.ndimension()
assert x.shape[dim] == 1
expand_dims = [-1] * x.ndimension()
expand_dims[dim] = size
return x.expand(*expand_dims)
def expand_at(x, size, dim):
"""
Expand all tensors in nested dictionary or list or tuple at a single
dimension @dim by @size.
Args:
x (dict or list or tuple): a possibly nested dictionary or list or tuple
size (int): size to expand
dim (int): dimension to expand
Returns:
y (dict or list or tuple): new nested dict-list-tuple
"""
return map_tensor(x, lambda t, s=size, d=dim: expand_at_single(t, s, d))
def unsqueeze_expand_at(x, size, dim):
"""
Unsqueeze and expand a tensor at a dimension @dim by @size.
Args:
x (dict or list or tuple): a possibly nested dictionary or list or tuple
size (int): size to expand
dim (int): dimension to unsqueeze and expand
Returns:
y (dict or list or tuple): new nested dict-list-tuple
"""
x = unsqueeze(x, dim)
return expand_at(x, size, dim)
def repeat_by_expand_at(x, repeats, dim):
"""
Repeat a dimension by combining expand and reshape operations.
Args:
x (dict or list or tuple): a possibly nested dictionary or list or tuple
repeats (int): number of times to repeat the target dimension
dim (int): dimension to repeat on
Returns:
y (dict or list or tuple): new nested dict-list-tuple
"""
x = unsqueeze_expand_at(x, repeats, dim + 1)
return join_dimensions(x, dim, dim + 1)
def named_reduce_single(x, reduction, dim):
"""
Reduce tensor at a dimension by named reduction functions.
Args:
x (torch.Tensor): tensor to be reduced
reduction (str): one of ["sum", "max", "mean", "flatten"]
dim (int): dimension to be reduced (or begin axis for flatten)
Returns:
y (torch.Tensor): reduced tensor
"""
assert x.ndimension() > dim
assert reduction in ["sum", "max", "mean", "flatten"]
if reduction == "flatten":
x = flatten(x, begin_axis=dim)
elif reduction == "max":
x = torch.max(x, dim=dim)[0] # [B, D]
elif reduction == "sum":
x = torch.sum(x, dim=dim)
else:
x = torch.mean(x, dim=dim)
return x
def named_reduce(x, reduction, dim):
"""
Reduces all tensors in nested dictionary or list or tuple at a dimension
using a named reduction function.
Args:
x (dict or list or tuple): a possibly nested dictionary or list or tuple
reduction (str): one of ["sum", "max", "mean", "flatten"]
dim (int): dimension to be reduced (or begin axis for flatten)
Returns:
y (dict or list or tuple): new nested dict-list-tuple
"""
return map_tensor(
x, func=lambda t, r=reduction, d=dim: named_reduce_single(t, r, d))
def gather_along_dim_with_dim_single(x, target_dim, source_dim, indices):
"""
This function indexes out a target dimension of a tensor in a structured way,
by allowing a different value to be selected for each member of a flat index
tensor (@indices) corresponding to a source dimension. This can be interpreted
as moving along the source dimension, using the corresponding index value
in @indices to select values for all other dimensions outside of the
source and target dimensions. A common use case is to gather values
in target dimension 1 for each batch member (target dimension 0).
Args:
x (torch.Tensor): tensor to gather values for
target_dim (int): dimension to gather values along
source_dim (int): dimension to hold constant and use for gathering values
from the other dimensions
indices (torch.Tensor): flat index tensor with same shape as tensor @x along
@source_dim
Returns:
y (torch.Tensor): gathered tensor, with dimension @target_dim indexed out
"""
assert len(indices.shape) == 1
assert x.shape[source_dim] == indices.shape[0]
# unsqueeze in all dimensions except the source dimension
new_shape = [1] * x.ndimension()
new_shape[source_dim] = -1
indices = indices.reshape(*new_shape)
# repeat in all dimensions - but preserve shape of source dimension,
# and make sure target_dimension has singleton dimension
expand_shape = list(x.shape)
expand_shape[source_dim] = -1
expand_shape[target_dim] = 1
indices = indices.expand(*expand_shape)
out = x.gather(dim=target_dim, index=indices)
return out.squeeze(target_dim)
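# A minimal usage sketch (an illustrative assumption): pick a different timestep for each
# batch member of a [B, T, D] tensor.
def _gather_example():
    x = torch.arange(2 * 4 * 3).reshape(2, 4, 3).float()  # [B=2, T=4, D=3]
    idx = torch.tensor([1, 3])  # timestep to keep per batch member
    out = gather_along_dim_with_dim_single(x, target_dim=1, source_dim=0, indices=idx)
    assert list(out.shape) == [2, 3]  # [B, D]
    return out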
def gather_along_dim_with_dim(x, target_dim, source_dim, indices):
"""
Apply @gather_along_dim_with_dim_single to all tensors in a nested
dictionary or list or tuple.
Args:
x (dict or list or tuple): a possibly nested dictionary or list or tuple
target_dim (int): dimension to gather values along
source_dim (int): dimension to hold constant and use for gathering values
from the other dimensions
indices (torch.Tensor): flat index tensor with same shape as tensor @x along
@source_dim
Returns:
y (dict or list or tuple): new nested dict-list-tuple
"""
return map_tensor(x,
lambda y, t=target_dim, s=source_dim, i=indices:
gather_along_dim_with_dim_single(y, t, s, i))
def gather_sequence_single(seq, indices):
"""
Given a tensor with leading dimensions [B, T, ...], gather an element from each sequence in
the batch given an index for each sequence.
Args:
seq (torch.Tensor): tensor with leading dimensions [B, T, ...]
indices (torch.Tensor): tensor indices of shape [B]
Return:
y (torch.Tensor): indexed tensor of shape [B, ...]
"""
return gather_along_dim_with_dim_single(seq,
target_dim=1,
source_dim=0,
indices=indices)
def gather_sequence(seq, indices):
"""
Given a nested dictionary or list or tuple, gathers an element from each sequence of the batch
for tensors with leading dimensions [B, T, ...].
Args:
seq (dict or list or tuple): a possibly nested dictionary or list or tuple with tensors
of leading dimensions [B, T, ...]
indices (torch.Tensor): tensor indices of shape [B]
Returns:
y (dict or list or tuple): new nested dict-list-tuple with tensors of shape [B, ...]
"""
return gather_along_dim_with_dim(seq,
target_dim=1,
source_dim=0,
indices=indices)
def pad_sequence_single(seq,
padding,
batched=False,
pad_same=True,
pad_values=None):
"""
Pad input tensor or array @seq in the time dimension (dimension 1).
Args:
seq (np.ndarray or torch.Tensor): sequence to be padded
padding (tuple): begin and end padding, e.g. [1, 1] pads both begin and end of the sequence by 1
batched (bool): whether the sequence has a leading batch dimension
pad_same (bool): whether to pad by duplicating the boundary steps
pad_values (float): value to pad with when pad_same is False
Returns:
padded sequence (np.ndarray or torch.Tensor)
"""
assert isinstance(seq, (np.ndarray, torch.Tensor))
assert pad_same or pad_values is not None
if pad_values is not None:
assert isinstance(pad_values, float)
repeat_func = np.repeat if isinstance(
seq, np.ndarray) else torch.repeat_interleave
concat_func = np.concatenate if isinstance(seq, np.ndarray) else torch.cat
ones_like_func = np.ones_like if isinstance(
seq, np.ndarray) else torch.ones_like
seq_dim = 1 if batched else 0
begin_pad = []
end_pad = []
if padding[0] > 0:
pad = seq[[0]] if pad_same else ones_like_func(seq[[0]]) * pad_values
begin_pad.append(repeat_func(pad, padding[0], seq_dim))
if padding[1] > 0:
pad = seq[[-1]] if pad_same else ones_like_func(seq[[-1]]) * pad_values
end_pad.append(repeat_func(pad, padding[1], seq_dim))
return concat_func(begin_pad + [seq] + end_pad, seq_dim)
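# A minimal usage sketch (an illustrative assumption): pad a [T, D] sequence by repeating
# its first and last steps once each.
def _pad_sequence_example():
    seq = torch.arange(4 * 2).reshape(4, 2).float()  # [T=4, D=2]
    padded = pad_sequence_single(seq, padding=(1, 1), batched=False, pad_same=True)
    assert list(padded.shape) == [6, 2]
    return padded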
def pad_sequence(seq, padding, batched=False, pad_same=True, pad_values=None):
"""
Pad a nested dictionary or list or tuple of sequence tensors in the time dimension (dimension 1).
Args:
seq (dict or list or tuple): a possibly nested dictionary or list or tuple with tensors
of leading dimensions [B, T, ...]
padding (tuple): begin and end padding, e.g. [1, 1] pads both begin and end of the sequence by 1
batched (bool): whether the sequences have a leading batch dimension
pad_same (bool): whether to pad by duplicating the boundary steps
pad_values (float): value to pad with when pad_same is False
Returns:
padded sequence (dict or list or tuple)
"""
return recursive_dict_list_tuple_apply(
seq, {
torch.Tensor:
lambda x, p=padding, b=batched, ps=pad_same, pv=pad_values:
pad_sequence_single(x, p, b, ps, pv),
np.ndarray:
lambda x, p=padding, b=batched, ps=pad_same, pv=pad_values:
pad_sequence_single(x, p, b, ps, pv),
type(None):
lambda x: x,
})
def assert_size_at_dim_single(x, size, dim, msg):
"""
Ensure that array or tensor @x has size @size in dim @dim.
Args:
x (np.ndarray or torch.Tensor): input array or tensor
size (int): size that tensors should have at @dim
dim (int): dimension to check
msg (str): text to display if assertion fails
"""
assert x.shape[dim] == size, msg
def assert_size_at_dim(x, size, dim, msg):
"""
Ensure that arrays and tensors in nested dictionary or list or tuple have
size @size in dim @dim.
Args:
x (dict or list or tuple): a possibly nested dictionary or list or tuple
size (int): size that tensors should have at @dim
dim (int): dimension to check
"""
map_tensor(
x,
lambda t, s=size, d=dim, m=msg: assert_size_at_dim_single(t, s, d, m))
def get_shape(x):
"""
Get all shapes of arrays and tensors in nested dictionary or list or tuple.
Args:
x (dict or list or tuple): a possibly nested dictionary or list or tuple
Returns:
y (dict or list or tuple): new nested dict-list-tuple that contains each array or
tensor's shape
"""
return recursive_dict_list_tuple_apply(
x, {
torch.Tensor: lambda x: x.shape,
np.ndarray: lambda x: x.shape,
type(None): lambda x: x,
})
def list_of_flat_dict_to_dict_of_list(list_of_dict):
"""
Helper function to go from a list of flat dictionaries to a dictionary of lists.
By "flat" we mean that none of the values are dictionaries, but are numpy arrays,
floats, etc.
Args:
list_of_dict (list): list of flat dictionaries
Returns:
dict_of_list (dict): dictionary of lists
"""
assert isinstance(list_of_dict, list)
dic = collections.OrderedDict()
for i in range(len(list_of_dict)):
for k in list_of_dict[i]:
if k not in dic:
dic[k] = []
dic[k].append(list_of_dict[i][k])
return dic
def flatten_nested_dict_list(d, parent_key='', sep='_', item_key=''):
"""
Flatten a nested dict or list to a list.
For example, given a dict
{
a: 1
b: {
c: 2
}
c: 3
}
the function would return [(a, 1), (b_c, 2), (c, 3)]
Args:
d (dict, list): a nested dict or list to be flattened
parent_key (str): recursion helper
sep (str): separator for nesting keys
item_key (str): recursion helper
Returns:
list: a list of (key, value) tuples
"""
items = []
if isinstance(d, (tuple, list)):
new_key = parent_key + sep + item_key if len(
parent_key) > 0 else item_key
for i, v in enumerate(d):
items.extend(
flatten_nested_dict_list(v, new_key, sep=sep, item_key=str(i)))
return items
elif isinstance(d, dict):
new_key = parent_key + sep + item_key if len(
parent_key) > 0 else item_key
for k, v in d.items():
assert isinstance(k, str)
items.extend(
flatten_nested_dict_list(v, new_key, sep=sep, item_key=k))
return items
else:
new_key = parent_key + sep + item_key if len(
parent_key) > 0 else item_key
return [(new_key, d)]
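# A minimal usage sketch (an illustrative assumption): flattening a small nested dict
# reproduces the docstring example above.
def _flatten_nested_dict_list_example():
    d = {"a": 1, "b": {"c": 2}, "c": 3}
    assert flatten_nested_dict_list(d) == [("a", 1), ("b_c", 2), ("c", 3)]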
def time_distributed(inputs,
op,
activation=None,
inputs_as_kwargs=False,
inputs_as_args=False,
**kwargs):
"""
Apply function @op to all tensors in nested dictionary or list or tuple @inputs in both the
batch (B) and time (T) dimension, where the tensors are expected to have shape [B, T, ...].
Will do this by reshaping tensors to [B * T, ...], passing through the op, and then reshaping
outputs to [B, T, ...].
Args:
inputs (list or tuple or dict): a possibly nested dictionary or list or tuple with tensors
of leading dimensions [B, T, ...]
op: a layer op that accepts inputs
activation: activation to apply at the output
inputs_as_kwargs (bool): whether to feed input as a kwargs dict to the op
inputs_as_args (bool): whether to feed input as an args list to the op
kwargs (dict): other kwargs to supply to the op
Returns:
outputs (dict or list or tuple): new nested dict-list-tuple with tensors of leading dimension [B, T].
"""
batch_size, seq_len = flatten_nested_dict_list(inputs)[0][1].shape[:2]
inputs = join_dimensions(inputs, 0, 1)
if inputs_as_kwargs:
outputs = op(**inputs, **kwargs)
elif inputs_as_args:
outputs = op(*inputs, **kwargs)
else:
outputs = op(inputs, **kwargs)
if activation is not None:
outputs = map_tensor(outputs, activation)
outputs = reshape_dimensions(outputs,
begin_axis=0,
end_axis=0,
target_dims=(batch_size, seq_len))
return outputs
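# A minimal usage sketch (an illustrative assumption): apply a per-step linear layer to a
# [B, T, D] tensor by folding time into the batch dimension and back.
def _time_distributed_example():
    op = torch.nn.Linear(8, 4)
    x = torch.randn(2, 5, 8)  # [B, T, D]
    y = time_distributed(x, op)
    assert list(y.shape) == [2, 5, 4]
    return y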

View File

@@ -0,0 +1,701 @@
import logging
import torch
import torch.nn as nn
import torch.nn.functional as F
import einops
import xformers.ops
from einops import rearrange, repeat
from typing import Union
from unifolm_wma.models.diffusion_head.conv1d_components import (
Downsample1d, Upsample1d, Conv1dBlock)
from unifolm_wma.models.diffusion_head.positional_embedding import SinusoidalPosEmb
from unifolm_wma.models.diffusion_head.base_nets import SpatialSoftmax
from unifolm_wma.utils.basics import zero_module
from unifolm_wma.utils.common import (
checkpoint,
exists,
default,
)
from unifolm_wma.utils.utils import instantiate_from_config
logger = logging.getLogger(__name__)
class GEGLU(nn.Module):
def __init__(self, dim_in, dim_out):
super().__init__()
self.proj = nn.Linear(dim_in, dim_out * 2)
def forward(self, x):
x, gate = self.proj(x).chunk(2, dim=-1)
return x * F.gelu(gate)
class FeedForward(nn.Module):
def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.):
super().__init__()
inner_dim = int(dim * mult)
dim_out = default(dim_out, dim)
project_in = nn.Sequential(nn.Linear(
dim, inner_dim), nn.GELU()) if not glu else GEGLU(dim, inner_dim)
self.net = nn.Sequential(project_in, nn.Dropout(dropout),
nn.Linear(inner_dim, dim_out))
def forward(self, x):
return self.net(x)
class CrossAttention(nn.Module):
def __init__(self,
query_dim,
context_dim=None,
heads=8,
dim_head=64,
dropout=0.,
relative_position=False):
super().__init__()
inner_dim = dim_head * heads
context_dim = default(context_dim, query_dim)
self.scale = dim_head**-0.5
self.heads = heads
self.dim_head = dim_head
self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
self.to_out = nn.Sequential(nn.Linear(inner_dim, query_dim),
nn.Dropout(dropout))
def efficient_forward(self, x, context=None):
spatial_self_attn = (context is None)
k_ip, v_ip, out_ip = None, None, None
q = self.to_q(x)
if spatial_self_attn:
context = default(context, x)
k = self.to_k(context)
v = self.to_v(context)
b, _, _ = q.shape
q, k, v = map(
lambda t: t.unsqueeze(3).reshape(b, t.shape[
1], self.heads, self.dim_head).permute(0, 2, 1, 3).reshape(
b * self.heads, t.shape[1], self.dim_head).contiguous(),
(q, k, v),
)
# actually compute the attention, what we cannot get enough of
out = xformers.ops.memory_efficient_attention(q,
k,
v,
attn_bias=None,
op=None)
out = (out.unsqueeze(0).reshape(
b, self.heads, out.shape[1],
self.dim_head).permute(0, 2, 1,
3).reshape(b, out.shape[1],
self.heads * self.dim_head))
return self.to_out(out)
class BasicTransformerBlock(nn.Module):
def __init__(self,
dim,
n_heads,
d_head,
dropout=0.,
context_dim=None,
gated_ff=True,
checkpoint=True,
disable_self_attn=False,
attention_cls=None):
super().__init__()
attn_cls = CrossAttention if attention_cls is None else attention_cls
self.disable_self_attn = disable_self_attn
self.attn1 = attn_cls(
query_dim=dim,
heads=n_heads,
dim_head=d_head,
dropout=dropout,
context_dim=context_dim if self.disable_self_attn else None)
self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff)
self.attn2 = attn_cls(query_dim=dim,
context_dim=context_dim,
heads=n_heads,
dim_head=d_head,
dropout=dropout)
self.norm1 = nn.LayerNorm(dim)
self.norm2 = nn.LayerNorm(dim)
self.norm3 = nn.LayerNorm(dim)
self.checkpoint = checkpoint
def forward(self, x, context=None, **kwargs):
## implementation trick: gradient checkpointing doesn't support non-tensor (e.g. None or scalar) arguments
input_tuple = (x, )  ## must be (x,): a bare (x) is not a tuple, and *input_tuple would then unpack x itself
if context is not None:
input_tuple = (x, context)
return checkpoint(self._forward, input_tuple, self.parameters(),
self.checkpoint)
def _forward(self, x, context=None, mask=None):
x = self.attn1(self.norm1(x),
context=context if self.disable_self_attn else None,
mask=mask) + x
x = self.attn2(self.norm2(x), context=context, mask=mask) + x
x = self.ff(self.norm3(x)) + x
return x
class ActionLatentImageCrossAttention(nn.Module):
def __init__(self,
in_channels,
in_dim,
n_heads,
d_head,
depth=1,
dropout=0.,
context_dim=None,
use_checkpoint=True,
disable_self_attn=False,
use_linear=True):
super().__init__()
"""
in_channels: action input dim
"""
self.in_channels = in_channels
self.in_dim = in_dim
inner_dim = n_heads * d_head
self.norm = torch.nn.GroupNorm(num_groups=8,
num_channels=in_channels,
eps=1e-6,
affine=True)
self.proj_in_action = nn.Linear(in_dim, inner_dim)
self.proj_in_cond = nn.Linear(context_dim, inner_dim)
self.proj_out = zero_module(nn.Linear(inner_dim, in_dim))
self.use_linear = use_linear
attention_cls = None
self.transformer_blocks = nn.ModuleList([
BasicTransformerBlock(inner_dim,
n_heads,
d_head,
dropout=dropout,
context_dim=context_dim,
disable_self_attn=disable_self_attn,
checkpoint=use_checkpoint,
attention_cls=attention_cls)
for d in range(depth)
])
def forward(self, x, context=None, **kwargs):
ba, ca, da = x.shape
b, t, c, h, w = context.shape
context = rearrange(context, 'b t c h w -> b (t h w) c').contiguous()
x_in = x
x = self.norm(x) # ba x ja x d_in
if self.use_linear:
x = self.proj_in_action(x)
context = self.proj_in_cond(context)
for i, block in enumerate(self.transformer_blocks):
x = block(x, context=context, **kwargs)
if self.use_linear:
x = self.proj_out(x)
return x + x_in
class ConditionalResidualBlock1D(nn.Module):
def __init__(self,
in_channels,
out_channels,
cond_dim,
kernel_size=3,
n_groups=8,
cond_predict_scale=True,
use_linear_act_proj=False):
super().__init__()
self.blocks = nn.ModuleList([
Conv1dBlock(in_channels,
out_channels,
kernel_size,
n_groups=n_groups),
Conv1dBlock(out_channels,
out_channels,
kernel_size,
n_groups=n_groups),
])
self.cond_predict_scale = cond_predict_scale
self.use_linear_act_proj = use_linear_act_proj
self.out_channels = out_channels
# FiLM modulation https://arxiv.org/abs/1709.07871
# predicts per-channel scale and bias
cond_channels = out_channels
if cond_predict_scale and use_linear_act_proj:
cond_channels = out_channels * 2
self.cond_encoder = nn.Sequential(
nn.Mish(),
nn.Linear(cond_dim, cond_channels),
)
# make sure dimensions compatible
self.residual_conv = nn.Conv1d(in_channels, out_channels, 1) \
if in_channels != out_channels else nn.Identity()
def forward(self, x, cond=None):
'''
x : [ batch_size x in_channels x horizon ]
cond : [ batch_size x cond_dim]
returns:
out : [ batch_size x out_channels x horizon ]
'''
B, T, _ = cond.shape
out = self.blocks[0](x)
if self.cond_predict_scale:
embed = self.cond_encoder(cond)
if self.use_linear_act_proj:
embed = embed.reshape(B * T, -1)
embed = embed.reshape(-1, 2, self.out_channels, 1)
else:
embed = embed.reshape(embed.shape[0], 2, self.out_channels, 1)
scale = embed[:, 0, ...]
bias = embed[:, 1, ...]
out = scale * out + bias
# else:
# out = out + embed
out = self.blocks[1](out)
out = out + self.residual_conv(x)
return out
class ConditionalUnet1D(nn.Module):
def __init__(self,
input_dim,
n_obs_steps=1,
local_cond_dim=None,
global_cond_dim=None,
diffusion_step_embed_dim=256,
down_dims=[256, 512, 1024],
kernel_size=3,
n_groups=8,
cond_predict_scale=False,
horizon=16,
num_head_channels=64,
use_linear_attn=True,
use_linear_act_proj=True,
act_proj_dim=32,
cond_cross_attention=False,
context_dims=None,
image_size=None,
imagen_cond_gradient=False,
last_frame_only=False,
use_imagen_mid_only=False,
use_z_only=False,
spatial_num_kp=32,
obs_encoder_config=None):
super().__init__()
self.n_obs_steps = n_obs_steps
self.obs_encoder = instantiate_from_config(obs_encoder_config)
all_dims = [input_dim] + list(down_dims)
start_dim = down_dims[0]
dsed = diffusion_step_embed_dim
diffusion_step_encoder = nn.Sequential(
SinusoidalPosEmb(dsed),
nn.Linear(dsed, dsed * 4),
nn.Mish(),
nn.Linear(dsed * 4, dsed),
)
cond_dim = dsed + self.obs_encoder.output_shape()[-1] * self.n_obs_steps
in_out = list(zip(all_dims[:-1], all_dims[1:]))
local_cond_encoder = None
down_modules = nn.ModuleList([])
dim_a_list = []
for ind, (dim_in, dim_out) in enumerate(in_out):
is_last = ind >= (len(in_out) - 1)
if ind == 0:
dim_a = horizon
else:
dim_a = horizon // 2 * ind
dim_a_list.append(dim_a)
# for attention
num_heads = dim_out // num_head_channels
dim_head = num_head_channels
if use_linear_act_proj:
if use_imagen_mid_only:
cur_cond_dim = cond_dim + 2 * context_dims[-1]
elif use_z_only:
cur_cond_dim = cond_dim + 2 * spatial_num_kp
else:
cur_cond_dim = cond_dim + 2 * context_dims[ind]
else:
cur_cond_dim = cond_dim + horizon * context_dims[ind]
down_modules.append(
nn.ModuleList([
ConditionalResidualBlock1D(
dim_in,
dim_out,
cond_dim=cur_cond_dim,
kernel_size=kernel_size,
n_groups=n_groups,
cond_predict_scale=cond_predict_scale,
use_linear_act_proj=use_linear_act_proj),
ConditionalResidualBlock1D(
dim_out,
dim_out,
cond_dim=cur_cond_dim,
kernel_size=kernel_size,
n_groups=n_groups,
cond_predict_scale=cond_predict_scale,
use_linear_act_proj=use_linear_act_proj),
ActionLatentImageCrossAttention(
dim_out,
dim_a,
num_heads,
dim_head,
context_dim=context_dims[ind],
use_linear=use_linear_attn)
if cond_cross_attention else nn.Identity(),
Downsample1d(dim_out) if not is_last else nn.Identity()
]))
mid_dim = all_dims[-1]
self.mid_modules = nn.ModuleList([
ConditionalResidualBlock1D(
mid_dim,
mid_dim,
cond_dim=cur_cond_dim,
kernel_size=kernel_size,
n_groups=n_groups,
cond_predict_scale=cond_predict_scale,
use_linear_act_proj=use_linear_act_proj),
ConditionalResidualBlock1D(
mid_dim,
mid_dim,
cond_dim=cur_cond_dim,
kernel_size=kernel_size,
n_groups=n_groups,
cond_predict_scale=cond_predict_scale,
use_linear_act_proj=use_linear_act_proj),
ActionLatentImageCrossAttention(mid_dim,
dim_a_list[-1],
num_heads,
dim_head,
context_dim=context_dims[-1],
use_linear=use_linear_attn)
if cond_cross_attention else nn.Identity(),
])
up_modules = nn.ModuleList([])
context_dims = context_dims[::-1]
for ind, (dim_in, dim_out) in enumerate(
reversed(in_out[1:] + [(down_dims[-1], down_dims[-1])])):
is_last = ind >= (len(in_out) - 1)
if use_linear_act_proj:
if use_imagen_mid_only:
cur_cond_dim = cond_dim + 2 * context_dims[0]
elif use_z_only:
cur_cond_dim = cond_dim + 2 * spatial_num_kp
else:
cur_cond_dim = cond_dim + 2 * context_dims[ind]
else:
cur_cond_dim = cond_dim + horizon * context_dims[ind]
up_modules.append(
nn.ModuleList([
ConditionalResidualBlock1D(
dim_out + dim_in,
dim_in,
cond_dim=cur_cond_dim,
kernel_size=kernel_size,
n_groups=n_groups,
cond_predict_scale=cond_predict_scale,
use_linear_act_proj=use_linear_act_proj),
ConditionalResidualBlock1D(
dim_in,
dim_in,
cond_dim=cur_cond_dim,
kernel_size=kernel_size,
n_groups=n_groups,
cond_predict_scale=cond_predict_scale,
use_linear_act_proj=use_linear_act_proj),
ActionLatentImageCrossAttention(
dim_in,
dim_a_list.pop(),
num_heads,
dim_head,
context_dim=context_dims[ind],
use_linear=use_linear_attn)
if cond_cross_attention else nn.Identity(),
Upsample1d(dim_in) if not is_last else nn.Identity()
]))
final_conv = nn.Sequential(
Conv1dBlock(start_dim, start_dim, kernel_size=kernel_size),
nn.Conv1d(start_dim, input_dim, 1),
)
if use_z_only:
h, w = image_size
self.spatial_softmax_blocks = nn.ModuleList(
[SpatialSoftmax((4, h, w), spatial_num_kp)])
else:
self.spatial_softmax_blocks = nn.ModuleList([])
context_dims = context_dims[::-1]
for ind, context_dim in enumerate(context_dims):
h, w = image_size
if ind != 0:
h //= 2**ind
w //= 2**ind
net = SpatialSoftmax((context_dim, h, w), context_dim)
self.spatial_softmax_blocks.append(net)
self.spatial_softmax_blocks.append(net)
self.spatial_softmax_blocks += self.spatial_softmax_blocks[
0:4][::-1]
self.diffusion_step_encoder = diffusion_step_encoder
self.local_cond_encoder = local_cond_encoder
self.up_modules = up_modules
self.down_modules = down_modules
self.final_conv = final_conv
self.cond_cross_attention = cond_cross_attention
self.use_linear_act_proj = use_linear_act_proj
self.proj_in_action = nn.Sequential(nn.Linear(1, act_proj_dim),
nn.LayerNorm(act_proj_dim))
self.proj_in_horizon = nn.Sequential(nn.Linear(horizon, act_proj_dim),
nn.LayerNorm(act_proj_dim))
self.proj_out_action = nn.Sequential(nn.LayerNorm(act_proj_dim),
nn.Linear(act_proj_dim, 1))
self.proj_out_horizon = nn.Sequential(nn.LayerNorm(act_proj_dim),
nn.Linear(act_proj_dim, horizon))
logger.info("number of parameters: %e",
sum(p.numel() for p in self.parameters()))
self.imagen_cond_gradient = imagen_cond_gradient
self.use_imagen_mid_only = use_imagen_mid_only
self.use_z_only = use_z_only
self.spatial_num_kp = spatial_num_kp
self.last_frame_only = last_frame_only
self.horizon = horizon
def forward(self,
sample: torch.Tensor,
timestep: Union[torch.Tensor, float, int],
imagen_cond=None,
cond=None,
**kwargs):
"""
sample: (B,T,input_dim)
timestep: (B,) or int, diffusion step
imagen_cond: a list of hidden info from video gen unet
cond: dict:
image: (B, 3, To, h, w)
agent_pos: (B, Ta, d)
output: (B,T,input_dim)
"""
if not self.imagen_cond_gradient:
imagen_cond = [c.detach() for c in imagen_cond]
cond = {'image': cond[0], 'agent_pos': cond[1]}
cond['image'] = cond['image'].permute(0, 2, 1, 3,
4)
cond['image'] = rearrange(cond['image'], 'b t c h w -> (b t) c h w')
cond['agent_pos'] = rearrange(cond['agent_pos'], 'b t d -> (b t) d')
B, T, D = sample.shape
if self.use_linear_act_proj:
sample = self.proj_in_action(sample.unsqueeze(-1))
global_cond = self.obs_encoder(cond)
global_cond = rearrange(global_cond,
'(b t) d -> b 1 (t d)',
b=B,
t=self.n_obs_steps)
global_cond = repeat(global_cond,
'b c d -> b (repeat c) d',
repeat=T)
else:
    sample = einops.rearrange(sample, 'b h t -> b t h')
    sample = self.proj_in_horizon(sample)
    # encode observations into a global conditioning vector and repeat it
    # across the two projected tokens (mirrors the use_linear_act_proj branch)
    global_cond = self.obs_encoder(cond)
    global_cond = rearrange(global_cond,
                            '(b t) d -> b 1 (t d)',
                            b=B,
                            t=self.n_obs_steps)
    global_cond = repeat(global_cond,
                         'b c d -> b (repeat c) d',
                         repeat=2)
# 1. time
timesteps = timestep
if not torch.is_tensor(timesteps):
timesteps = torch.tensor([timesteps],
dtype=torch.long,
device=sample.device)
elif torch.is_tensor(timesteps) and len(timesteps.shape) == 0:
timesteps = timesteps[None].to(sample.device)
# Broadcast to batch dimension in a way that's compatible with ONNX/Core ML
timesteps = timesteps.expand(sample.shape[0])
global_feature = self.diffusion_step_encoder(timesteps)
(imagen_cond_down, imagen_cond_mid, imagen_cond_up
) = imagen_cond[0:4], imagen_cond[4], imagen_cond[5:] #NOTE HAND CODE
x = sample if not self.use_linear_act_proj else sample.reshape(
B * T, D, -1)
h = []
for idx, modules in enumerate(self.down_modules):
if self.cond_cross_attention:
(resnet, resnet2, crossatten, downsample) = modules
else:
(resnet, resnet2, _, downsample) = modules
# Access the cond from the unet embeds from video unet
if self.use_imagen_mid_only:
imagen_cond = imagen_cond_mid
elif self.use_z_only:
imagen_cond = kwargs['x_start'].permute(0, 2, 1, 3, 4)
else:
imagen_cond = imagen_cond_down[idx]
if self.last_frame_only:
imagen_cond = imagen_cond[:, -1].unsqueeze(1)
imagen_cond = repeat(imagen_cond,
'b t c h w -> b (repeat t) c h w',
repeat=self.horizon)
imagen_cond = rearrange(imagen_cond, 'b t c h w -> (b t) c h w')
if self.use_imagen_mid_only:
imagen_cond = self.spatial_softmax_blocks[len(
self.spatial_softmax_blocks) // 2](imagen_cond)
elif self.use_z_only:
imagen_cond = self.spatial_softmax_blocks[0](imagen_cond)
else:
imagen_cond = self.spatial_softmax_blocks[idx](imagen_cond)
imagen_cond = rearrange(imagen_cond, '(b t) c d -> b t c d', b=B)
if self.use_linear_act_proj:
imagen_cond = imagen_cond.reshape(B, T, -1)
cur_global_feature = global_feature.unsqueeze(
1).repeat_interleave(repeats=T, dim=1)
else:
imagen_cond = imagen_cond.permute(0, 3, 1, 2)
imagen_cond = imagen_cond.reshape(B, 2, -1)
cur_global_feature = global_feature.unsqueeze(
1).repeat_interleave(repeats=2, dim=1)
cur_global_feature = torch.cat(
[cur_global_feature, global_cond, imagen_cond], axis=-1)
x = resnet(x, cur_global_feature)
x = resnet2(x, cur_global_feature)
h.append(x)
x = downsample(x)
#>>> mid blocks
resnet, resnet2, _ = self.mid_modules
# Access the cond from the unet embeds from video unet
if self.use_z_only:
imagen_cond = kwargs['x_start'].permute(0, 2, 1, 3, 4)
else:
imagen_cond = imagen_cond_mid
if self.last_frame_only:
imagen_cond = imagen_cond[:, -1].unsqueeze(1)
imagen_cond = repeat(imagen_cond,
'b t c h w -> b (repeat t) c h w',
repeat=self.horizon)
imagen_cond = rearrange(imagen_cond, 'b t c h w -> (b t) c h w')
idx += 1
if self.use_z_only:
imagen_cond = self.spatial_softmax_blocks[0](imagen_cond)
else:
imagen_cond = self.spatial_softmax_blocks[idx](imagen_cond)
imagen_cond = rearrange(imagen_cond, '(b t) c d -> b t c d', b=B)
if self.use_linear_act_proj:
imagen_cond = imagen_cond.reshape(B, T, -1)
cur_global_feature = global_feature.unsqueeze(1).repeat_interleave(
repeats=T, dim=1)
else:
imagen_cond = imagen_cond.permute(0, 3, 1, 2)
imagen_cond = imagen_cond.reshape(B, 2, -1)
cur_global_feature = global_feature.unsqueeze(1).repeat_interleave(
repeats=2, dim=1)
cur_global_feature = torch.cat(
[cur_global_feature, global_cond, imagen_cond], axis=-1)
x = resnet(x, cur_global_feature)
x = resnet2(x, cur_global_feature)
#>>> up blocks
idx += 1
for jdx, modules in enumerate(self.up_modules):
if self.cond_cross_attention:
(resnet, resnet2, crossatten, upsample) = modules
else:
(resnet, resnet2, _, upsample) = modules
# Access the cond from the unet embeds from video unet
if self.use_imagen_mid_only:
imagen_cond = imagen_cond_mid
elif self.use_z_only:
imagen_cond = kwargs['x_start'].permute(0, 2, 1, 3, 4)
else:
imagen_cond = imagen_cond_up[jdx]
if self.last_frame_only:
imagen_cond = imagen_cond[:, -1].unsqueeze(1)
imagen_cond = repeat(imagen_cond,
'b t c h w -> b (repeat t) c h w',
repeat=self.horizon)
imagen_cond = rearrange(imagen_cond, 'b t c h w -> (b t) c h w')
if self.use_imagen_mid_only:
imagen_cond = self.spatial_softmax_blocks[len(
self.spatial_softmax_blocks) // 2](imagen_cond)
elif self.use_z_only:
imagen_cond = self.spatial_softmax_blocks[0](imagen_cond)
else:
imagen_cond = self.spatial_softmax_blocks[jdx +
idx](imagen_cond)
imagen_cond = rearrange(imagen_cond, '(b t) c d -> b t c d', b=B)
if self.use_linear_act_proj:
imagen_cond = imagen_cond.reshape(B, T, -1)
cur_global_feature = global_feature.unsqueeze(
1).repeat_interleave(repeats=T, dim=1)
else:
imagen_cond = imagen_cond.permute(0, 3, 1, 2)
imagen_cond = imagen_cond.reshape(B, 2, -1)
cur_global_feature = global_feature.unsqueeze(
1).repeat_interleave(repeats=2, dim=1)
cur_global_feature = torch.cat(
[cur_global_feature, global_cond, imagen_cond], axis=-1)
x = torch.cat((x, h.pop()), dim=1)
x = resnet(x, cur_global_feature)
x = resnet2(x, cur_global_feature)
x = upsample(x)
x = self.final_conv(x)
if self.use_linear_act_proj:
x = x.reshape(B, T, D, -1)
x = self.proj_out_action(x)
x = x.reshape(B, T, D)
else:
x = self.proj_out_horizon(x)
x = einops.rearrange(x, 'b t h -> b h t')
return x

View File

@@ -0,0 +1,52 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
class Downsample1d(nn.Module):
def __init__(self, dim):
super().__init__()
self.conv = nn.Conv1d(dim, dim, 3, 2, 1)
def forward(self, x):
return self.conv(x)
class Upsample1d(nn.Module):
def __init__(self, dim):
super().__init__()
self.conv = nn.ConvTranspose1d(dim, dim, 4, 2, 1)
def forward(self, x):
return self.conv(x)
class Conv1dBlock(nn.Module):
'''
Conv1d --> GroupNorm --> Mish
'''
def __init__(self, inp_channels, out_channels, kernel_size, n_groups=8):
super().__init__()
self.block = nn.Sequential(
nn.Conv1d(inp_channels,
out_channels,
kernel_size,
padding=kernel_size // 2),
# Rearrange('batch channels horizon -> batch channels 1 horizon'),
nn.GroupNorm(n_groups, out_channels),
# Rearrange('batch channels 1 horizon -> batch channels horizon'),
nn.Mish(),
)
def forward(self, x):
return self.block(x)
def test():
    cb = Conv1dBlock(256, 128, kernel_size=3)
    x = torch.zeros((1, 256, 16))
    o = cb(x)
    assert o.shape == (1, 128, 16)
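A small usage sketch for the 1D blocks above, with illustrative shapes only: Conv1dBlock preserves the horizon, Downsample1d halves it, and Upsample1d restores it.

if __name__ == '__main__':
    test()
    x = torch.zeros((1, 128, 16))   # (batch, channels, horizon)
    down = Downsample1d(128)        # stride-2 conv halves the horizon
    up = Upsample1d(128)            # transposed conv doubles it back
    d = down(x)
    assert d.shape == (1, 128, 8)
    u = up(d)
    assert u.shape == (1, 128, 16)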

View File

@@ -0,0 +1,80 @@
import copy
import torch
from torch.nn.modules.batchnorm import _BatchNorm
class EMAModel:
"""
Exponential Moving Average of model weights
"""
def __init__(self,
model,
update_after_step=0,
inv_gamma=1.0,
power=2 / 3,
min_value=0.0,
max_value=0.9999):
"""
@crowsonkb's notes on EMA Warmup:
If gamma=1 and power=1, implements a simple average. gamma=1, power=2/3 are good values for models you plan
to train for a million or more steps (reaches decay factor 0.999 at 31.6K steps, 0.9999 at 1M steps),
gamma=1, power=3/4 for models you plan to train for less (reaches decay factor 0.999 at 10K steps, 0.9999
at 215.4k steps).
Args:
inv_gamma (float): Inverse multiplicative factor of EMA warmup. Default: 1.
power (float): Exponential factor of EMA warmup. Default: 2/3.
min_value (float): The minimum EMA decay rate. Default: 0.
"""
self.averaged_model = model
self.averaged_model.eval()
self.averaged_model.requires_grad_(False)
self.update_after_step = update_after_step
self.inv_gamma = inv_gamma
self.power = power
self.min_value = min_value
self.max_value = max_value
self.decay = 0.0
self.optimization_step = 0
def get_decay(self, optimization_step):
"""
Compute the decay factor for the exponential moving average.
"""
step = max(0, optimization_step - self.update_after_step - 1)
value = 1 - (1 + step / self.inv_gamma)**-self.power
if step <= 0:
return 0.0
return max(self.min_value, min(value, self.max_value))
@torch.no_grad()
def step(self, new_model):
self.decay = self.get_decay(self.optimization_step)
for module, ema_module in zip(new_model.modules(),
self.averaged_model.modules()):
for param, ema_param in zip(module.parameters(recurse=False),
ema_module.parameters(recurse=False)):
# iterate over immediate parameters only.
if isinstance(param, dict):
raise RuntimeError('Dict parameter not supported')
if isinstance(module, _BatchNorm):
# skip batchnorms
ema_param.copy_(param.to(dtype=ema_param.dtype).data)
elif not param.requires_grad:
ema_param.copy_(param.to(dtype=ema_param.dtype).data)
else:
ema_param.mul_(self.decay)
ema_param.add_(param.data.to(dtype=ema_param.dtype),
alpha=1 - self.decay)
self.optimization_step += 1
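A minimal sketch of driving the EMA wrapper from a training loop; the tiny linear model and SGD optimizer below are placeholders, not the real training setup.

if __name__ == '__main__':
    net = torch.nn.Linear(4, 4)
    ema = EMAModel(copy.deepcopy(net), update_after_step=0)
    opt = torch.optim.SGD(net.parameters(), lr=1e-2)
    for _ in range(5):
        loss = net(torch.randn(8, 4)).pow(2).mean()
        opt.zero_grad()
        loss.backward()
        opt.step()
        ema.step(net)       # update the frozen averaged copy after each optimizer step
    print(ema.decay)        # decay warms up from 0 toward max_value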

View File

@@ -0,0 +1,19 @@
import math
import torch
import torch.nn as nn
class SinusoidalPosEmb(nn.Module):
def __init__(self, dim):
super().__init__()
self.dim = dim
def forward(self, x):
device = x.device
half_dim = self.dim // 2
emb = math.log(10000) / (half_dim - 1)
emb = torch.exp(torch.arange(half_dim, device=device) * -emb)
emb = x[:, None] * emb[None, :]
emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
return emb
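A quick sketch of applying the embedding to a batch of diffusion timesteps (toy sizes only):

if __name__ == '__main__':
    emb_layer = SinusoidalPosEmb(dim=128)
    timesteps = torch.randint(0, 1000, (16,)).float()   # one timestep per batch element
    emb = emb_layer(timesteps)
    assert emb.shape == (16, 128)                        # first half sine, second half cosine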

View File

@@ -0,0 +1,322 @@
import torch
import torch.nn as nn
import torchvision.transforms.functional as ttf
import unifolm_wma.models.diffusion_head.common.tensor_util as tu
class CropRandomizer(nn.Module):
"""
Randomly sample crops at input, and then average across crop features at output.
"""
def __init__(
self,
input_shape,
crop_height,
crop_width,
num_crops=1,
pos_enc=False,
):
"""
Args:
input_shape (tuple, list): shape of input (not including batch dimension)
crop_height (int): crop height
crop_width (int): crop width
num_crops (int): number of random crops to take
pos_enc (bool): if True, add 2 channels to the output to encode the spatial
location of the cropped pixels in the source image
"""
super().__init__()
assert len(input_shape) == 3 # (C, H, W)
assert crop_height < input_shape[1]
assert crop_width < input_shape[2]
self.input_shape = input_shape
self.crop_height = crop_height
self.crop_width = crop_width
self.num_crops = num_crops
self.pos_enc = pos_enc
def output_shape_in(self, input_shape=None):
"""
Function to compute output shape from inputs to this module. Corresponds to
the @forward_in operation, where raw inputs (usually observation modalities)
are passed in.
Args:
input_shape (iterable of int): shape of input. Does not include batch dimension.
Some modules may not need this argument, if their output does not depend
on the size of the input, or if they assume fixed size input.
Returns:
out_shape ([int]): list of integers corresponding to output shape
"""
# outputs are shape (C, CH, CW), or maybe C + 2 if using position encoding, because
# the number of crops are reshaped into the batch dimension, increasing the batch
# size from B to B * N
out_c = self.input_shape[0] + 2 if self.pos_enc else self.input_shape[0]
return [out_c, self.crop_height, self.crop_width]
def output_shape_out(self, input_shape=None):
"""
Function to compute output shape from inputs to this module. Corresponds to
the @forward_out operation, where processed inputs (usually encoded observation
modalities) are passed in.
Args:
input_shape (iterable of int): shape of input. Does not include batch dimension.
Some modules may not need this argument, if their output does not depend
on the size of the input, or if they assume fixed size input.
Returns:
out_shape ([int]): list of integers corresponding to output shape
"""
# since the forward_out operation splits [B * N, ...] -> [B, N, ...]
# and then pools to result in [B, ...], only the batch dimension changes,
# and so the other dimensions retain their shape.
return list(input_shape)
def forward_in(self, inputs):
"""
Samples N random crops for each input in the batch, and then reshapes
inputs to [B * N, ...].
"""
assert len(
inputs.shape) >= 3 # must have at least (C, H, W) dimensions
if self.training:
# generate random crops
out, _ = sample_random_image_crops(
images=inputs,
crop_height=self.crop_height,
crop_width=self.crop_width,
num_crops=self.num_crops,
pos_enc=self.pos_enc,
)
# [B, N, ...] -> [B * N, ...]
return tu.join_dimensions(out, 0, 1)
else:
# take center crop during eval
out = ttf.center_crop(img=inputs,
output_size=(self.crop_height,
self.crop_width))
if self.num_crops > 1:
B, C, H, W = out.shape
out = out.unsqueeze(1).expand(B, self.num_crops, C, H,
W).reshape(-1, C, H, W)
# [B * N, ...]
return out
def forward_out(self, inputs):
"""
Splits the outputs from shape [B * N, ...] -> [B, N, ...] and then average across N
to result in shape [B, ...] to make sure the network output is consistent with
what would have happened if there were no randomization.
"""
if self.num_crops <= 1:
return inputs
else:
batch_size = (inputs.shape[0] // self.num_crops)
out = tu.reshape_dimensions(inputs,
begin_axis=0,
end_axis=0,
target_dims=(batch_size,
self.num_crops))
return out.mean(dim=1)
def forward(self, inputs):
return self.forward_in(inputs)
def __repr__(self):
"""Pretty print network."""
header = '{}'.format(str(self.__class__.__name__))
msg = header + "(input_shape={}, crop_size=[{}, {}], num_crops={})".format(
self.input_shape, self.crop_height, self.crop_width,
self.num_crops)
return msg
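# Hedged usage sketch (not called anywhere): in eval mode CropRandomizer reduces to a plain
# center crop, so this demo runs without the tensor_util helpers needed by the training path.
# The 96x96 -> 84x84 sizes are illustrative assumptions, not the real observation shapes.
def _demo_crop_randomizer():
    randomizer = CropRandomizer(input_shape=(3, 96, 96), crop_height=84, crop_width=84)
    randomizer.eval()  # training mode would call sample_random_image_crops below instead
    obs = torch.zeros(4, 3, 96, 96)
    cropped = randomizer.forward_in(obs)  # deterministic center crop at eval time
    assert cropped.shape == (4, 3, 84, 84)
    assert randomizer.output_shape_in() == [3, 84, 84]
    return randomizer.forward_out(cropped)  # num_crops == 1, so this is a pass-through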
def crop_image_from_indices(images, crop_indices, crop_height, crop_width):
"""
Crops images at the locations specified by @crop_indices. Crops will be
taken across all channels.
Args:
images (torch.Tensor): batch of images of shape [..., C, H, W]
crop_indices (torch.Tensor): batch of indices of shape [..., N, 2] where
N is the number of crops to take per image and each entry corresponds
to the pixel height and width of where to take the crop. Note that
the indices can also be of shape [..., 2] if only 1 crop should
be taken per image. Leading dimensions must be consistent with
@images argument. Each index specifies the top left of the crop.
Values must be in range [0, H - CH - 1] x [0, W - CW - 1] where
H and W are the height and width of @images and CH and CW are
@crop_height and @crop_width.
crop_height (int): height of crop to take
crop_width (int): width of crop to take
Returns:
crops (torch.Tensor): cropped images of shape [..., C, @crop_height, @crop_width]
"""
# make sure length of input shapes is consistent
assert crop_indices.shape[-1] == 2
ndim_im_shape = len(images.shape)
ndim_indices_shape = len(crop_indices.shape)
assert (ndim_im_shape == ndim_indices_shape +
1) or (ndim_im_shape == ndim_indices_shape + 2)
# maybe pad so that @crop_indices is shape [..., N, 2]
is_padded = False
if ndim_im_shape == ndim_indices_shape + 2:
crop_indices = crop_indices.unsqueeze(-2)
is_padded = True
# make sure leading dimensions between images and indices are consistent
assert images.shape[:-3] == crop_indices.shape[:-2]
device = images.device
image_c, image_h, image_w = images.shape[-3:]
num_crops = crop_indices.shape[-2]
# make sure @crop_indices are in valid range
assert (crop_indices[..., 0] >= 0).all().item()
assert (crop_indices[..., 0] < (image_h - crop_height)).all().item()
assert (crop_indices[..., 1] >= 0).all().item()
assert (crop_indices[..., 1] < (image_w - crop_width)).all().item()
# convert each crop index (ch, cw) into a list of pixel indices that correspond to the entire window.
# 2D index array with columns [0, 1, ..., CH - 1] and shape [CH, CW]
crop_ind_grid_h = torch.arange(crop_height).to(device)
crop_ind_grid_h = tu.unsqueeze_expand_at(crop_ind_grid_h,
size=crop_width,
dim=-1)
# 2D index array with rows [0, 1, ..., CW - 1] and shape [CH, CW]
crop_ind_grid_w = torch.arange(crop_width).to(device)
crop_ind_grid_w = tu.unsqueeze_expand_at(crop_ind_grid_w,
size=crop_height,
dim=0)
# combine into shape [CH, CW, 2]
crop_in_grid = torch.cat(
(crop_ind_grid_h.unsqueeze(-1), crop_ind_grid_w.unsqueeze(-1)), dim=-1)
# Add above grid with the offset index of each sampled crop to get 2d indices for each crop.
# After broadcasting, this will be shape [..., N, CH, CW, 2] and each crop has a [CH, CW, 2]
# shape array that tells us which pixels from the corresponding source image to grab.
grid_reshape = [1] * len(crop_indices.shape[:-1]) + [
crop_height, crop_width, 2
]
all_crop_inds = crop_indices.unsqueeze(-2).unsqueeze(
-2) + crop_in_grid.reshape(grid_reshape)
# For using @torch.gather, convert to flat indices from 2D indices, and also
# repeat across the channel dimension. To get flat index of each pixel to grab for
# each sampled crop, we just use the mapping: ind = h_ind * @image_w + w_ind
all_crop_inds = all_crop_inds[..., 0] * image_w + all_crop_inds[
..., 1] # shape [..., N, CH, CW]
all_crop_inds = tu.unsqueeze_expand_at(all_crop_inds, size=image_c,
dim=-3) # shape [..., N, C, CH, CW]
all_crop_inds = tu.flatten(all_crop_inds,
begin_axis=-2) # shape [..., N, C, CH * CW]
# Repeat and flatten the source images -> [..., N, C, H * W] and then use gather to index with crop pixel inds
images_to_crop = tu.unsqueeze_expand_at(images, size=num_crops, dim=-4)
images_to_crop = tu.flatten(images_to_crop, begin_axis=-2)
crops = torch.gather(images_to_crop, dim=-1, index=all_crop_inds)
# [..., N, C, CH * CW] -> [..., N, C, CH, CW]
reshape_axis = len(crops.shape) - 1
crops = tu.reshape_dimensions(crops,
begin_axis=reshape_axis,
end_axis=reshape_axis,
target_dims=(crop_height, crop_width))
if is_padded:
# undo padding -> [..., C, CH, CW]
crops = crops.squeeze(-4)
return crops
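# Hedged sketch of the flat-index trick used above, independent of the tu helpers: 2D pixel
# coordinates (h, w) are mapped to flat indices h * W + w so a crop window can be gathered
# out of a flattened image. The tiny shapes below are illustrative only.
def _demo_flat_index_gather():
    C, H, W, ch, cw = 1, 5, 6, 2, 3
    image = torch.arange(C * H * W).reshape(C, H, W)
    top, left = 1, 2                                          # top-left corner of the crop
    rows = torch.arange(ch).unsqueeze(-1).expand(ch, cw)      # row offsets within the window
    cols = torch.arange(cw).unsqueeze(0).expand(ch, cw)       # column offsets within the window
    flat = (top + rows) * W + (left + cols)                    # flat index = h * W + w
    crop = torch.gather(image.reshape(C, H * W), dim=-1,
                        index=flat.reshape(1, -1).expand(C, -1)).reshape(C, ch, cw)
    assert torch.equal(crop, image[:, top:top + ch, left:left + cw])
    return crop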
def sample_random_image_crops(images,
crop_height,
crop_width,
num_crops,
pos_enc=False):
"""
For each image, randomly sample @num_crops crops of size (@crop_height, @crop_width), from
@images.
Args:
images (torch.Tensor): batch of images of shape [..., C, H, W]
crop_height (int): height of crop to take
crop_width (int): width of crop to take
num_crops (int): number of crops to sample
pos_enc (bool): if True, also add 2 channels to the outputs that gives a spatial
encoding of the original source pixel locations. This means that the
output crops will contain information about where in the source image
it was sampled from.
Returns:
crops (torch.Tensor): crops of shape (..., @num_crops, C, @crop_height, @crop_width)
if @pos_enc is False, otherwise (..., @num_crops, C + 2, @crop_height, @crop_width)
crop_inds (torch.Tensor): sampled crop indices of shape (..., N, 2)
"""
device = images.device
# maybe add 2 channels of spatial encoding to the source image
source_im = images
if pos_enc:
# spatial encoding [y, x] in [0, 1]
h, w = source_im.shape[-2:]
pos_y, pos_x = torch.meshgrid(torch.arange(h), torch.arange(w))
pos_y = pos_y.float().to(device) / float(h)
pos_x = pos_x.float().to(device) / float(w)
position_enc = torch.stack((pos_y, pos_x)) # shape [C, H, W]
# unsqueeze and expand to match leading dimensions -> shape [..., C, H, W]
leading_shape = source_im.shape[:-3]
position_enc = position_enc[(None, ) * len(leading_shape)]
position_enc = position_enc.expand(*leading_shape, -1, -1, -1)
# concat across channel dimension with input
source_im = torch.cat((source_im, position_enc), dim=-3)
# make sure sample boundaries ensure crops are fully within the images
image_c, image_h, image_w = source_im.shape[-3:]
max_sample_h = image_h - crop_height
max_sample_w = image_w - crop_width
# Sample crop locations for all tensor dimensions up to the last 3, which are [C, H, W].
# Each gets @num_crops samples - typically this will just be the batch dimension (B), so
# we will sample [B, N] indices, but this supports having more than one leading dimension,
# or possibly no leading dimension.
#
# Trick: sample in [0, 1) with rand, then re-scale to [0, M) and convert to long to get sampled ints
crop_inds_h = (
max_sample_h *
torch.rand(*source_im.shape[:-3], num_crops).to(device)).long()
crop_inds_w = (
max_sample_w *
torch.rand(*source_im.shape[:-3], num_crops).to(device)).long()
crop_inds = torch.cat(
(crop_inds_h.unsqueeze(-1), crop_inds_w.unsqueeze(-1)),
dim=-1) # shape [..., N, 2]
crops = crop_image_from_indices(
images=source_im,
crop_indices=crop_inds,
crop_height=crop_height,
crop_width=crop_width,
)
return crops, crop_inds
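A hedged usage sketch for the sampling helper above; it assumes the module's tensor_util import resolves, and the 96x96 to 84x84 sizes are illustrative only.

if __name__ == '__main__':
    imgs = torch.rand(4, 3, 96, 96)   # (B, C, H, W)
    crops, inds = sample_random_image_crops(
        images=imgs, crop_height=84, crop_width=84, num_crops=2, pos_enc=True)
    # pos_enc appends two channels holding the normalized (y, x) source location of each pixel
    assert crops.shape == (4, 2, 5, 84, 84)
    assert inds.shape == (4, 2, 2)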

View File

@@ -0,0 +1,30 @@
import torch
import torchvision
def get_resnet(name, weights=None, **kwargs):
"""
name: resnet18, resnet34, resnet50
weights: "IMAGENET1K_V1", "r3m"
"""
# load r3m weights
if (weights == "r3m") or (weights == "R3M"):
return get_r3m(name=name, **kwargs)
func = getattr(torchvision.models, name)
resnet = func(weights=weights, **kwargs)
resnet.fc = torch.nn.Identity()
return resnet
def get_r3m(name, **kwargs):
"""
name: resnet18, resnet34, resnet50
"""
import r3m
r3m.device = 'cpu'
model = r3m.load_r3m(name)
r3m_model = model.module
resnet_model = r3m_model.convnet
resnet_model = resnet_model.to('cpu')
return resnet_model
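A brief usage sketch; passing weights=None avoids any download, and the 512-dim output is specific to resnet18.

if __name__ == '__main__':
    backbone = get_resnet('resnet18', weights=None)     # final fc replaced by Identity
    feats = backbone(torch.zeros(2, 3, 224, 224))
    assert feats.shape == (2, 512)                       # resnet18 global feature dimension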

View File

@@ -0,0 +1,247 @@
import copy
import torch
import torch.nn as nn
import torchvision
import json
import os
from unifolm_wma.models.diffusion_head.vision.crop_randomizer import CropRandomizer
from unifolm_wma.models.diffusion_head.base_nets import SpatialSoftmax
from unifolm_wma.models.diffusion_head.common.module_attr_mixin import ModuleAttrMixin
from unifolm_wma.models.diffusion_head.common.pytorch_util import dict_apply, replace_submodules
from unifolm_wma.utils.utils import instantiate_from_config
from einops import rearrange, repeat
from typing import Dict, Tuple, Union
from pathlib import Path
class MultiImageObsEncoder(ModuleAttrMixin):
def __init__(
self,
rgb_model_config: Dict,
shape_meta_path: str | None = None,
resize_shape: Union[Tuple[int, int], Dict[str, tuple],
None] = None,
crop_shape: Union[Tuple[int, int], Dict[str, tuple], None] = None,
random_crop: bool = True,
# replace BatchNorm with GroupNorm
use_group_norm: bool = False,
# use single rgb model for all rgb inputs
share_rgb_model: bool = False,
# renormalize rgb input with imagenet normalization
# assuming input in [0,1]
imagenet_norm: bool = False,
use_spatial_softmax=False,
spatial_softmax_kp=32,
use_dinoSiglip=False):
"""
Assumes rgb input: B,C,H,W
Assumes low_dim input: B,D
"""
super().__init__()
if not shape_meta_path:
shape_meta_path = str(Path(os.getcwd()) / "configs/train/meta.json")
with open(shape_meta_path, 'r') as file:
shape_meta = json.load(file)
rgb_model = instantiate_from_config(rgb_model_config)
rgb_keys = list()
low_dim_keys = list()
key_model_map = nn.ModuleDict()
key_transform_map = nn.ModuleDict()
key_shape_map = dict()
# handle sharing vision backbone
if share_rgb_model:
assert isinstance(rgb_model, nn.Module)
key_model_map['rgb'] = rgb_model
obs_shape_meta = shape_meta['obs']
for key, attr in obs_shape_meta.items():
shape = tuple(attr['shape'])
type = attr.get('type', 'low_dim')
key_shape_map[key] = shape
if type == 'rgb':
rgb_keys.append(key)
if not use_dinoSiglip:
# configure model for this key
this_model = None
if not share_rgb_model:
if isinstance(rgb_model, dict):
# have provided model for each key
this_model = rgb_model[key]
else:
assert isinstance(rgb_model, nn.Module)
# have a copy of the rgb model
this_model = copy.deepcopy(rgb_model)
if this_model is not None:
if use_group_norm:
this_model = replace_submodules(
root_module=this_model,
predicate=lambda x: isinstance(
x, nn.BatchNorm2d),
func=lambda x: nn.GroupNorm(
num_groups=x.num_features // 16,
num_channels=x.num_features))
key_model_map[key] = this_model
# configure resize
input_shape = shape
this_resizer = nn.Identity()
if resize_shape is not None:
if isinstance(resize_shape, dict):
h, w = resize_shape[key]
else:
h, w = resize_shape
this_resizer = torchvision.transforms.Resize(size=(h,
w))
input_shape = (shape[0], h, w)
# configure randomizer
this_randomizer = nn.Identity()
if crop_shape is not None:
if isinstance(crop_shape, dict):
h, w = crop_shape[key]
else:
h, w = crop_shape
if random_crop:
this_randomizer = CropRandomizer(
input_shape=input_shape,
crop_height=h,
crop_width=w,
num_crops=1,
pos_enc=False)
else:
# apply a deterministic center crop when not randomizing
this_randomizer = torchvision.transforms.CenterCrop(
    size=(h, w))
# configure normalizer
this_normalizer = nn.Identity()
if imagenet_norm:
this_normalizer = torchvision.transforms.Normalize(
mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225])
this_transform = nn.Sequential(this_resizer,
this_randomizer,
this_normalizer)
key_transform_map[key] = this_transform
else:
key_model_map[key] = rgb_model
elif type == 'low_dim':
low_dim_keys.append(key)
else:
raise RuntimeError(f"Unsupported obs type: {type}")
rgb_keys = sorted(rgb_keys)
low_dim_keys = sorted(low_dim_keys)
self.shape_meta = shape_meta
self.key_model_map = key_model_map
self.key_transform_map = key_transform_map
self.share_rgb_model = share_rgb_model
self.rgb_keys = rgb_keys
self.low_dim_keys = low_dim_keys
self.key_shape_map = key_shape_map
self.use_dinoSiglip = use_dinoSiglip
# NOTE: optionally pool the ResNet feature maps with a spatial-softmax head
self.use_spatial_softmax = use_spatial_softmax
if use_spatial_softmax and not use_dinoSiglip:
model = nn.Sequential(
key_model_map['image'].conv1,
key_model_map['image'].bn1,
key_model_map['image'].relu,
key_model_map['image'].maxpool,
key_model_map['image'].layer1,
key_model_map['image'].layer2,
key_model_map['image'].layer3,
key_model_map['image'].layer4,
)
key_model_map['image'] = model
input_shape = self.output_shape(resnet_output_shape=True)
self.spatial_softmax = SpatialSoftmax(input_shape,
num_kp=spatial_softmax_kp)
def forward(self, obs_dict, resnet_output_shape=False):
batch_size = None
features = list()
# process rgb input
if self.share_rgb_model:
# pass all rgb obs to rgb model
imgs = list()
for key in self.rgb_keys:
img = obs_dict[key]
if batch_size is None:
batch_size = img.shape[0]
else:
assert batch_size == img.shape[0]
assert img.shape[1:] == self.key_shape_map[key]
img = self.key_transform_map[key](img)
imgs.append(img)
# (N*B,C,H,W)
imgs = torch.cat(imgs, dim=0)
# (N*B,D)
feature = self.key_model_map['rgb'](imgs)
# (N,B,D)
feature = feature.reshape(-1, batch_size, *feature.shape[1:])
# (B,N,D)
feature = torch.moveaxis(feature, 0, 1)
# (B,N*D)
feature = feature.reshape(batch_size, -1)
features.append(feature)
else:
# run each rgb obs to independent models
for key in self.rgb_keys:
img = obs_dict[key]
if batch_size is None:
batch_size = img.shape[0]
else:
assert batch_size == img.shape[0]
assert img.shape[1:] == self.key_shape_map[key]
if not self.use_dinoSiglip:
img = self.key_transform_map[key](img)
feature = self.key_model_map[key](img)
else:
feature = self.key_model_map[key](img)[:, :1, :]
if resnet_output_shape:
return feature
if not self.use_dinoSiglip and self.use_spatial_softmax:
feature = self.spatial_softmax(feature)
feature = feature.reshape(batch_size, -1)
features.append(feature)
# process lowdim input
for key in self.low_dim_keys:
data = obs_dict[key]
if batch_size is None:
batch_size = data.shape[0]
else:
assert batch_size == data.shape[0]
assert data.shape[1:] == self.key_shape_map[key]
features.append(data)
# concatenate all features
result = torch.cat(features, dim=-1)
return result
@torch.no_grad()
def output_shape(self, resnet_output_shape=False):
example_obs_dict = dict()
obs_shape_meta = self.shape_meta['obs']
batch_size = 1
for key, attr in obs_shape_meta.items():
shape = tuple(attr['shape'])
this_obs = torch.zeros((batch_size, ) + shape,
dtype=self.dtype,
device=self.device)
example_obs_dict[key] = this_obs
example_output = self.forward(example_obs_dict,
resnet_output_shape=resnet_output_shape)
output_shape = example_output.shape[1:]
return output_shape
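A hedged construction sketch for the encoder above. The config target path and the shape-meta file below are illustrative assumptions; a real setup must point them at an actual rgb-model config and a meta JSON describing the observation keys.

if __name__ == '__main__':
    # hypothetical config: instantiate_from_config is assumed to call target(**params)
    rgb_model_config = {
        'target': 'unifolm_wma.models.diffusion_head.vision.model_getter.get_resnet',  # assumed path
        'params': {'name': 'resnet18', 'weights': None},
    }
    encoder = MultiImageObsEncoder(
        rgb_model_config=rgb_model_config,
        shape_meta_path='configs/train/meta.json',  # assumed to list the 'rgb' and 'low_dim' obs keys
        crop_shape=(84, 84),
        random_crop=True,
        use_group_norm=True,
        imagenet_norm=True,
    )
    print(encoder.output_shape())   # flattened per-sample feature dimension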