第一次完整测例跑完

This commit is contained in:
2026-01-18 00:30:10 +08:00
parent ca15cc593b
commit 25c6fc04db
180 changed files with 29305 additions and 0 deletions

View File

View File

@@ -0,0 +1,26 @@
from abc import abstractmethod
from torch.utils.data import IterableDataset
class Txt2ImgIterableBaseDataset(IterableDataset):
'''
Define an interface to make the IterableDatasets for text2img data chainable
'''
def __init__(self, num_records=0, valid_ids=None, size=256):
super().__init__()
self.num_records = num_records
self.valid_ids = valid_ids
self.sample_ids = valid_ids
self.size = size
print(
f'{self.__class__.__name__} dataset contains {self.__len__()} examples.'
)
def __len__(self):
return self.num_records
@abstractmethod
def __iter__(self):
pass

View File

@@ -0,0 +1,230 @@
#!/usr/bin/env python
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
from torch import Tensor, nn
from typing import Dict, List
def create_stats_buffers(
shapes: Dict[str, List[int]],
modes: Dict[str, str],
stats: Dict[str, Dict[str, Tensor]] = None,
) -> Dict[str, Dict[str, nn.ParameterDict]]:
"""
Create buffers per modality (e.g. "observation.image", "action") containing their mean, std, min, max
statistics.
Args: (see Normalize and Unnormalize)
Returns:
Dict: A Dictionary where keys are modalities and values are `nn.ParameterDict` containing
`nn.Parameters` set to `requires_grad=False`, suitable to not be updated during backpropagation.
"""
stats_buffers = {}
for key, mode in modes.items():
assert mode in ["mean_std", "min_max"]
shape = tuple(shapes[key])
if "image" in key:
# sanity checks
assert len(
shape) == 3, f"number of dimensions of {key} != 3 ({shape=}"
c, h, w = shape
assert c < h and c < w, f"{key} is not channel first ({shape=})"
# override image shape to be invariant to height and width
shape = (c, 1, 1)
# Note: we initialize mean, std, min, max to infinity. They should be overwritten
# downstream by `stats` or `policy.load_state_Dict`, as expected. During forward,
# we assert they are not infinity anymore.
if "action" in key:
target_key = "action"
elif "state" in key:
target_key = 'observation.state'
else:
target_key = key
buffer = {}
if mode == "mean_std":
mean = torch.ones(shape, dtype=torch.float32) * torch.inf
std = torch.ones(shape, dtype=torch.float32) * torch.inf
buffer = nn.ParameterDict({
"mean":
nn.Parameter(mean, requires_grad=False),
"std":
nn.Parameter(std, requires_grad=False),
})
elif mode == "min_max":
min = torch.ones(shape, dtype=torch.float32) * torch.inf
max = torch.ones(shape, dtype=torch.float32) * torch.inf
buffer = nn.ParameterDict({
"min":
nn.Parameter(min, requires_grad=False),
"max":
nn.Parameter(max, requires_grad=False),
})
if stats is not None:
# Note: The clone is needed to make sure that the logic in save_pretrained doesn't see duplicated
# tensors anywhere (for example, when we use the same stats for normalization and
# unnormalization). See the logic here
if mode == "mean_std":
buffer["mean"].data = stats[target_key]["mean"].clone()
buffer["std"].data = stats[target_key]["std"].clone()
elif mode == "min_max":
buffer["min"].data = stats[target_key]["min"].clone()
buffer["max"].data = stats[target_key]["max"].clone()
stats_buffers[key] = buffer
return stats_buffers
def _no_stats_error_str(name: str) -> str:
return (
f"`{name}` is infinity. You should either initialize with `stats` as an argument, or use a "
"pretrained model.")
class Normalize(nn.Module):
"""Normalizes data (e.g. "observation.image") for more stable and faster convergence during training."""
def __init__(
self,
shapes: Dict[str, List[int]],
modes: Dict[str, str],
stats: Dict[str, Dict[str, Tensor]] = None,
):
"""
Args:
shapes (Dict): A Dictionary where keys are input modalities (e.g. "observation.image") and values
are their shapes (e.g. `[3,96,96]`]). These shapes are used to create the tensor buffer containing
mean, std, min, max statistics. If the provided `shapes` contain keys related to images, the shape
is adjusted to be invariant to height and width, assuming a channel-first (c, h, w) format.
modes (Dict): A Dictionary where keys are output modalities (e.g. "observation.image") and values
are their normalization modes among:
- "mean_std": subtract the mean and divide by standard deviation.
- "min_max": map to [-1, 1] range.
stats (Dict, optional): A Dictionary where keys are output modalities (e.g. "observation.image")
and values are Dictionaries of statistic types and their values (e.g.
`{"mean": torch.randn(3,1,1)}, "std": torch.randn(3,1,1)}`). If provided, as expected for
training the model for the first time, these statistics will overwrite the default buffers. If
not provided, as expected for finetuning or evaluation, the default buffers should to be
overwritten by a call to `policy.load_state_Dict(state_Dict)`. That way, initializing the
dataset is not needed to get the stats, since they are already in the policy state_Dict.
"""
super().__init__()
self.shapes = shapes
self.modes = modes
self.stats = stats
stats_buffers = create_stats_buffers(shapes, modes, stats)
for key, buffer in stats_buffers.items():
setattr(self, "buffer_" + key.replace(".", "_"), buffer)
@torch.no_grad()
def forward(self, batch: Dict[str, Tensor]) -> Dict[str, Tensor]:
for key, mode in self.modes.items():
if key not in batch:
continue
buffer = getattr(self, "buffer_" + key.replace(".", "_"))
if mode == "mean_std":
mean = buffer["mean"]
std = buffer["std"]
assert not torch.isinf(mean).any(), _no_stats_error_str("mean")
assert not torch.isinf(std).any(), _no_stats_error_str("std")
batch[key] = (batch[key] - mean) / (std + 1e-8)
elif mode == "min_max":
min = buffer["min"]
max = buffer["max"]
assert not torch.isinf(min).any(), _no_stats_error_str("min")
assert not torch.isinf(max).any(), _no_stats_error_str("max")
# normalize to [0,1]
batch[key] = (batch[key] - min) / (max - min + 1e-8)
# normalize to [-1, 1]
batch[key] = batch[key] * 2 - 1
else:
raise ValueError(mode)
return batch
class Unnormalize(nn.Module):
"""
Similar to `Normalize` but unnormalizes output data (e.g. `{"action": torch.randn(b,c)}`) in their
original range used by the environment.
"""
def __init__(
self,
shapes: Dict[str, List[int]],
modes: Dict[str, str],
stats: Dict[str, Dict[str, Tensor]] = None,
):
"""
Args:
shapes (Dict): A Dictionary where keys are input modalities (e.g. "observation.image") and values
are their shapes (e.g. `[3,96,96]`]). These shapes are used to create the tensor buffer containing
mean, std, min, max statistics. If the provided `shapes` contain keys related to images, the shape
is adjusted to be invariant to height and width, assuming a channel-first (c, h, w) format.
modes (Dict): A Dictionary where keys are output modalities (e.g. "observation.image") and values
are their normalization modes among:
- "mean_std": subtract the mean and divide by standard deviation.
- "min_max": map to [-1, 1] range.
stats (Dict, optional): A Dictionary where keys are output modalities (e.g. "observation.image")
and values are Dictionaries of statistic types and their values (e.g.
`{"mean": torch.randn(3,1,1)}, "std": torch.randn(3,1,1)}`). If provided, as expected for
training the model for the first time, these statistics will overwrite the default buffers. If
not provided, as expected for finetuning or evaluation, the default buffers should to be
overwritten by a call to `policy.load_state_Dict(state_Dict)`. That way, initializing the
dataset is not needed to get the stats, since they are already in the policy state_Dict.
"""
super().__init__()
self.shapes = shapes
self.modes = modes
self.stats = stats
stats_buffers = create_stats_buffers(shapes, modes, stats)
for key, buffer in stats_buffers.items():
setattr(self, "buffer_" + key.replace(".", "_"), buffer)
@torch.no_grad()
def forward(self, batch: Dict[str, Tensor]) -> Dict[str, Tensor]:
for key, mode in self.modes.items():
if key not in batch:
continue
buffer = getattr(self, "buffer_" + key.replace(".", "_"))
if mode == "mean_std":
mean = buffer["mean"]
std = buffer["std"]
assert not torch.isinf(mean).any(), _no_stats_error_str("mean")
assert not torch.isinf(std).any(), _no_stats_error_str("std")
batch[key] = batch[key] * std + mean
elif mode == "min_max":
min = buffer["min"]
max = buffer["max"]
assert not torch.isinf(min).any(), _no_stats_error_str("min")
assert not torch.isinf(max).any(), _no_stats_error_str("max")
batch[key] = (batch[key] + 1) / 2
batch[key] = batch[key] * (max - min) + min
else:
raise ValueError(mode)
return batch

View File

@@ -0,0 +1,60 @@
import torch
from huggingface_hub import hf_hub_download, snapshot_download
from typing import Dict, List, Union
from pathlib import Path
from safetensors.torch import load_file
def unflatten_dict(d, sep="/"):
outdict = {}
for key, value in d.items():
parts = key.split(sep)
d = outdict
for part in parts[:-1]:
if part not in d:
d[part] = {}
d = d[part]
d[parts[-1]] = value
return outdict
def load_episode_data_index(repo_id, version, root) -> Dict[str, torch.Tensor]:
"""episode_data_index contains the range of indices for each episode
Example:
```python
from_id = episode_data_index["from"][episode_id].item()
to_id = episode_data_index["to"][episode_id].item()
episode_frames = [dataset[i] for i in range(from_id, to_id)]
```
"""
if root is not None:
path = Path(
root) / repo_id / "meta_data" / "episode_data_index.safetensors"
else:
path = hf_hub_download(repo_id,
"meta_data/episode_data_index.safetensors",
repo_type="dataset",
revision=version)
return load_file(path)
def load_stats(repo_id, version, root) -> Dict[str, Dict[str, torch.Tensor]]:
"""stats contains the statistics per modality computed over the full dataset, such as max, min, mean, std
Example:
```python
normalized_action = (action - stats["action"]["mean"]) / stats["action"]["std"]
```
"""
if root is not None:
path = Path(root) / repo_id / "meta_data" / "stats.safetensors"
else:
path = hf_hub_download(repo_id,
"meta_data/stats.safetensors",
repo_type="dataset",
revision=version)
stats = load_file(path)
return unflatten_dict(stats)

View File

@@ -0,0 +1,408 @@
import torch
import os
import random
import pandas as pd
import h5py
from decord import VideoReader, cpu
from torch.utils.data import Dataset
from torchvision import transforms
from pathlib import Path
from unifolm_wma.data.utils import load_stats
from unifolm_wma.data.normolize import Normalize, Unnormalize
class WMAData(Dataset):
"""
Assuming the following dataset structure:
dataset_dir/
├── videos
│ ├──dataset_name
│ │ ├──camera_view_dir
│ │ ├── 0.mp4
│ │ ├── 1.mp4
│ │ └── ...
│ └── ...
├── transitions
│ ├── dataset_name
│ ├── meta_data
│ ├── 0.h5
│ ├── 1.h5
│ └── ...
└── dataset_name.csv
"""
def __init__(
self,
meta_path,
data_dir,
subsample=None,
video_length=16,
resolution=[256, 512],
frame_stride=1,
frame_stride_min=1,
spatial_transform=None,
crop_resolution=None,
fps_max=None,
load_raw_resolution=False,
fixed_fps=None,
random_fs=False,
cond_robot_label_prob=0.0,
transition_dir=None,
dataset_name=None,
normalization_mode='min_max',
individual_normalization=False,
n_obs_steps=1,
max_action_dim=7,
max_state_dim=7,
):
self.meta_path = meta_path
self.data_dir = data_dir
self.subsample = subsample
self.video_length = video_length
self.resolution = [resolution, resolution] if isinstance(
resolution, int) else resolution
self.fps_max = fps_max
self.frame_stride = frame_stride
self.frame_stride_min = frame_stride_min
self.fixed_fps = fixed_fps
self.load_raw_resolution = load_raw_resolution
self.random_fs = random_fs
self.cond_robot_label_prob = cond_robot_label_prob
self.transition_dir = transition_dir
self.dataset_name = dataset_name
self.max_action_dim = max_action_dim
self.max_state_dim = max_state_dim
self._load_metadata()
if spatial_transform is not None:
if spatial_transform == "random_crop":
self.spatial_transform = transforms.RandomCrop(crop_resolution)
elif spatial_transform == "center_crop":
self.spatial_transform = transforms.Compose([
transforms.CenterCrop(resolution),
])
elif spatial_transform == "resize_center_crop":
self.spatial_transform = transforms.Compose([
transforms.Resize(min(self.resolution)),
transforms.CenterCrop(self.resolution),
])
elif spatial_transform == "resize":
self.spatial_transform = transforms.Resize(self.resolution)
else:
raise NotImplementedError
else:
self.spatial_transform = None
self.normalization_mode = normalization_mode
self.individual_normalization = individual_normalization
self.n_obs_steps = n_obs_steps
self._load_stats()
if individual_normalization:
self._init_normalizers()
def _load_metadata(self):
metadata = pd.read_csv(self.meta_path, dtype=str)
if self.subsample is not None:
metadata = metadata.sample(self.subsample, random_state=0)
self.metadata = metadata
# drop the rows contain NaN values
self.metadata.dropna(inplace=True)
print(
f">>> {metadata['data_dir'].iloc[0]}: {len(metadata)} data samples loaded."
)
def _load_stats(self):
self.stats = load_stats(self.dataset_name, None, self.transition_dir)
print(f">>> {self.metadata['data_dir'].iloc[0]}: data stats loaded.")
def _init_normalizers(self):
shape_dict = {
'pre_action': [self.stats['action']['max'].shape[-1]],
'action': [self.stats['action']['max'].shape[-1]],
'observation.state':
[self.stats['observation.state']['max'].shape[-1]],
'next.state': [self.stats['observation.state']['max'].shape[-1]]
}
normalization_mode_dict = {
'pre_action': self.normalization_mode,
'action': self.normalization_mode,
'observation.state': self.normalization_mode,
'next.state': self.normalization_mode
}
self.normalizer = Normalize(shape_dict, normalization_mode_dict,
self.stats)
self.unnormalizer = Unnormalize(shape_dict, normalization_mode_dict,
self.stats)
print(
f">>> {self.metadata['data_dir'].iloc[0]}: normalizer initiated.")
def _get_video_path(self, sample):
rel_video_fp = os.path.join(sample['data_dir'],
str(sample['videoid']) + '.mp4')
full_video_fp = os.path.join(self.data_dir, 'videos', rel_video_fp)
return full_video_fp
def _get_transition_path(self, sample):
data_dir = Path(sample['data_dir'])
if self.dataset_name == data_dir.name:
rel_transition_fp = os.path.join(str(data_dir),
str(sample['videoid']) + '.h5')
else:
rel_transition_fp = os.path.join(str(data_dir.parent),
str(sample['videoid']) + '.h5')
full_transition_fp = os.path.join(self.data_dir, 'transitions',
rel_transition_fp)
return full_transition_fp
def get_uni_vec(self, action_state_dict, action_type, state_type):
if 'pre_action' in action_state_dict:
action_state_dict['pre_action'], _ = self._map_to_uni_action(
action_state_dict['pre_action'], action_type)
if 'action' in action_state_dict:
action_state_dict['action'], action_state_dict[
'action_mask'] = self._map_to_uni_action(
action_state_dict['action'], action_type)
if 'observation.state' in action_state_dict:
action_state_dict['observation.state'], _ = self._map_to_uni_state(
action_state_dict['observation.state'], state_type)
if 'next.state' in action_state_dict:
action_state_dict['next.state'], action_state_dict[
'state_mask'] = self._map_to_uni_state(
action_state_dict['next.state'], state_type)
return action_state_dict
def _map_to_uni_action(self, action, action_type):
action_dim = action.shape[-1]
uni_action = torch.nn.functional.pad(
action, (0, self.max_action_dim - action_dim),
mode='constant',
value=0)
uni_action_mask = torch.zeros_like(uni_action)
uni_action_mask[:, :action_dim] = 1
return uni_action, uni_action_mask
def _map_to_uni_state(self, state, state_type):
state_dim = state.shape[-1]
uni_state = torch.nn.functional.pad(
state, (0, self.max_state_dim - state_dim),
mode='constant',
value=0)
uni_state_mask = torch.zeros_like(uni_state)
uni_state_mask[:, :state_dim] = 1
return uni_state, uni_state_mask
def __getitem__(self, index):
if self.random_fs:
frame_stride = random.randint(self.frame_stride_min,
self.frame_stride)
else:
frame_stride = self.frame_stride
# Get frames until success
while True:
index = index % len(self.metadata)
sample = self.metadata.iloc[index]
video_path = self._get_video_path(sample)
instruction = sample['instruction']
if self.cond_robot_label_prob > 0.0 and random.random(
) < self.cond_robot_label_prob:
if sample['embodiment'] != 'x':
instruction = sample['embodiment'] + ' [SEP] ' + sample[
'instruction']
try:
if self.load_raw_resolution:
video_reader = VideoReader(video_path, ctx=cpu(0))
else:
video_reader = VideoReader(video_path,
ctx=cpu(0),
width=530,
height=300)
if len(video_reader) < self.video_length:
print(
f">>> Video length ({len(video_reader)}) is smaller than target length({self.video_length})"
)
index += 1
continue
else:
pass
except:
index += 1
print(f">>> Error: load video failed! path = {video_path}")
continue
fps_ori = video_reader.get_avg_fps()
if self.fixed_fps is not None:
frame_stride = int(frame_stride *
(1.0 * fps_ori / self.fixed_fps))
# To avoid extreme cases when fixed_fps is used
frame_stride = max(frame_stride, 1)
# Get valid range (adapting case by case)
required_frame_num = frame_stride * (self.video_length - 1) + 1
frame_num = len(video_reader)
if frame_num < required_frame_num:
# Drop extra samples if fixed fps is required
if self.fixed_fps is not None and frame_num < required_frame_num * 0.5:
index += 1
continue
else:
frame_stride = frame_num // self.video_length
required_frame_num = frame_stride * (self.video_length -
1) + 1
# Select a random clip
random_range = frame_num - required_frame_num
start_idx = random.randint(
0, random_range -
frame_stride) if random_range - frame_stride > 0 else 0
# Calculate frame indices
frame_indices = [
start_idx + frame_stride * i for i in range(self.video_length)
]
try:
next_frame_indices = [
idx + frame_stride for idx in frame_indices
]
frames = video_reader.get_batch(next_frame_indices)
break
except:
print(
f">>> Error: Get frames failed! path = {video_path}; [max_ind vs frame_total:{max(frame_indices)} / {frame_num}]"
)
index += 1
continue
# Load transition data
transition_path = self._get_transition_path(sample)
with h5py.File(transition_path, 'r') as h5f:
transition_dict = {}
for key in h5f.keys():
transition_dict[key] = torch.tensor(h5f[key][()])
for key in h5f.attrs.keys():
transition_dict[key] = h5f.attrs[key]
# Load observable states
if start_idx < self.n_obs_steps - 1:
state_indices = list(range(0, start_idx + 1))
states = transition_dict['observation.state'][state_indices, :]
num_padding = self.n_obs_steps - 1 - start_idx
first_slice = states[0:1, :] # (t, d)
padding = first_slice.repeat(num_padding, 1)
states = torch.cat((padding, states), dim=0)
else:
state_indices = list(
range(start_idx - self.n_obs_steps + 1, start_idx + 1))
states = transition_dict['observation.state'][state_indices, :]
assert states.shape[
0] == self.n_obs_steps, '>>> Do not have enough previous states as observation.'
# Load observable actions
if start_idx < self.n_obs_steps:
pre_action_indices = list(range(0, start_idx))
pre_actions = transition_dict['action'][pre_action_indices, :]
num_padding = self.n_obs_steps - start_idx
first_slice = torch.zeros_like(transition_dict['action'][:1, :])
padding = first_slice.repeat(num_padding, 1)
pre_actions = torch.cat((padding, pre_actions), dim=0)
else:
pre_action_indices = list(
range(start_idx - self.n_obs_steps, start_idx))
pre_actions = transition_dict['action'][pre_action_indices, :]
assert pre_actions.shape[
0] == self.n_obs_steps, ">>> Do not have enough previous actions as observation"
# Load future actions
actions = transition_dict['action'][frame_indices, :]
# Load future states
next_state_indices = [idx + frame_stride for idx in frame_indices]
next_states = transition_dict['observation.state'][
next_state_indices, :]
frames_action_state_dict = {
'pre_action': pre_actions,
'action': actions,
'observation.state': states,
'next.state': next_states
}
if self.individual_normalization:
frames_action_state_dict = self.normalizer(
frames_action_state_dict)
# Update action and states to unified vector
frames_action_state_dict = self.get_uni_vec(
frames_action_state_dict,
transition_dict['action_type'],
transition_dict['state_type'],
)
# Load observable images
if start_idx < self.n_obs_steps - 1:
action_net_frame_indices = list(range(0, start_idx + 1))
action_net_frames = video_reader.get_batch(
action_net_frame_indices)
action_net_frames = torch.tensor(
action_net_frames.asnumpy()).permute(0, 3, 1, 2).float()
first_slice = action_net_frames[0:1, :]
num_padding = self.n_obs_steps - 1 - start_idx
padding = first_slice.repeat(num_padding, 1, 1, 1)
action_net_frames = torch.cat((padding, action_net_frames), dim=0)
assert (
action_net_frames.shape[0] == self.n_obs_steps
), f'{len(action_net_frames)}, self.n_obs_steps={self.n_obs_steps}'
action_net_frames = action_net_frames.permute(1, 0, 2, 3)
else:
action_net_frame_indices = list(
range(start_idx - self.n_obs_steps + 1, start_idx + 1))
action_net_frames = video_reader.get_batch(
action_net_frame_indices)
assert (
action_net_frames.shape[0] == self.n_obs_steps
), f'{len(action_net_frames)}, self.n_obs_steps={self.n_obs_steps}'
action_net_frames = torch.tensor(
action_net_frames.asnumpy()).permute(3, 0, 1, 2).float()
assert (frames.shape[0] == self.video_length
), f'{len(frames)}, self.video_length={self.video_length}'
frames = torch.tensor(frames.asnumpy()).permute(3, 0, 1, 2).float()
if self.spatial_transform is not None:
frames = self.spatial_transform(frames)
action_net_frames = self.spatial_transform(action_net_frames)
if self.resolution is not None:
assert (frames.shape[2], frames.shape[3]) == (
self.resolution[0], self.resolution[1]
), f'frames={frames.shape}, self.resolution={self.resolution}'
assert (
action_net_frames.shape[2], action_net_frames.shape[3]
) == (
self.resolution[0], self.resolution[1]
), f'action_net_frames={action_net_frames.shape}, self.resolution={self.resolution}'
# Normalize frames tensors to [-1,1]
frames = (frames / 255 - 0.5) * 2
action_net_frames = (action_net_frames / 255 - 0.5) * 2
fps_clip = fps_ori // frame_stride
if self.fps_max is not None and fps_clip > self.fps_max:
fps_clip = self.fps_max
data = {
'video': frames,
'instruction': instruction,
'path': video_path,
'fps': fps_clip,
'frame_stride': frame_stride,
'observation.image': action_net_frames,
}
data.update(frames_action_state_dict)
return data
def __len__(self):
return len(self.metadata)

View File

View File

@@ -0,0 +1,267 @@
import os
import torch
import torch.nn.functional as F
import pytorch_lightning as pl
from einops import rearrange
from unifolm_wma.modules.networks.ae_modules import Encoder, Decoder
from unifolm_wma.utils.distributions import DiagonalGaussianDistribution
from unifolm_wma.utils.utils import instantiate_from_config
class AutoencoderKL(pl.LightningModule):
def __init__(
self,
ddconfig,
lossconfig,
embed_dim,
ckpt_path=None,
ignore_keys=[],
image_key="image",
colorize_nlabels=None,
monitor=None,
test=False,
logdir=None,
input_dim=4,
test_args=None,
):
super().__init__()
self.image_key = image_key
self.encoder = Encoder(**ddconfig)
self.decoder = Decoder(**ddconfig)
self.loss = instantiate_from_config(lossconfig)
assert ddconfig["double_z"]
self.quant_conv = torch.nn.Conv2d(2 * ddconfig["z_channels"],
2 * embed_dim, 1)
self.post_quant_conv = torch.nn.Conv2d(embed_dim,
ddconfig["z_channels"], 1)
self.embed_dim = embed_dim
self.input_dim = input_dim
self.test = test
self.test_args = test_args
self.logdir = logdir
if colorize_nlabels is not None:
assert type(colorize_nlabels) == int
self.register_buffer("colorize",
torch.randn(3, colorize_nlabels, 1, 1))
if monitor is not None:
self.monitor = monitor
if ckpt_path is not None:
self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
if self.test:
self.init_test()
def init_test(self, ):
self.test = True
save_dir = os.path.join(self.logdir, "test")
if 'ckpt' in self.test_args:
ckpt_name = os.path.basename(self.test_args.ckpt).split(
'.ckpt')[0] + f'_epoch{self._cur_epoch}'
self.root = os.path.join(save_dir, ckpt_name)
else:
self.root = save_dir
if 'test_subdir' in self.test_args:
self.root = os.path.join(save_dir, self.test_args.test_subdir)
self.root_zs = os.path.join(self.root, "zs")
self.root_dec = os.path.join(self.root, "reconstructions")
self.root_inputs = os.path.join(self.root, "inputs")
os.makedirs(self.root, exist_ok=True)
if self.test_args.save_z:
os.makedirs(self.root_zs, exist_ok=True)
if self.test_args.save_reconstruction:
os.makedirs(self.root_dec, exist_ok=True)
if self.test_args.save_input:
os.makedirs(self.root_inputs, exist_ok=True)
assert (self.test_args is not None)
self.test_maximum = getattr(self.test_args, 'test_maximum', None)
self.count = 0
self.eval_metrics = {}
self.decodes = []
self.save_decode_samples = 2048
def init_from_ckpt(self, path, ignore_keys=list()):
sd = torch.load(path, map_location="cpu")
try:
self._cur_epoch = sd['epoch']
sd = sd["state_dict"]
except:
self._cur_epoch = 'null'
keys = list(sd.keys())
for k in keys:
for ik in ignore_keys:
if k.startswith(ik):
print("Deleting key {} from state_dict.".format(k))
del sd[k]
self.load_state_dict(sd, strict=False)
print(f"Restored from {path}")
def encode(self, x, **kwargs):
h = self.encoder(x)
moments = self.quant_conv(h)
posterior = DiagonalGaussianDistribution(moments)
return posterior
def decode(self, z, **kwargs):
z = self.post_quant_conv(z)
dec = self.decoder(z)
return dec
def forward(self, input, sample_posterior=True):
posterior = self.encode(input)
if sample_posterior:
z = posterior.sample()
else:
z = posterior.mode()
dec = self.decode(z)
return dec, posterior
def get_input(self, batch, k):
x = batch[k]
if x.dim() == 5 and self.input_dim == 4:
b, c, t, h, w = x.shape
self.b = b
self.t = t
x = rearrange(x, 'b c t h w -> (b t) c h w')
return x
def training_step(self, batch, batch_idx, optimizer_idx):
inputs = self.get_input(batch, self.image_key)
reconstructions, posterior = self(inputs)
if optimizer_idx == 0:
# train encoder+decoder+logvar
aeloss, log_dict_ae = self.loss(inputs,
reconstructions,
posterior,
optimizer_idx,
self.global_step,
last_layer=self.get_last_layer(),
split="train")
self.log("aeloss",
aeloss,
prog_bar=True,
logger=True,
on_step=True,
on_epoch=True)
self.log_dict(log_dict_ae,
prog_bar=False,
logger=True,
on_step=True,
on_epoch=False)
return aeloss
if optimizer_idx == 1:
# train the discriminator
discloss, log_dict_disc = self.loss(
inputs,
reconstructions,
posterior,
optimizer_idx,
self.global_step,
last_layer=self.get_last_layer(),
split="train")
self.log("discloss",
discloss,
prog_bar=True,
logger=True,
on_step=True,
on_epoch=True)
self.log_dict(log_dict_disc,
prog_bar=False,
logger=True,
on_step=True,
on_epoch=False)
return discloss
def validation_step(self, batch, batch_idx):
inputs = self.get_input(batch, self.image_key)
reconstructions, posterior = self(inputs)
aeloss, log_dict_ae = self.loss(inputs,
reconstructions,
posterior,
0,
self.global_step,
last_layer=self.get_last_layer(),
split="val")
discloss, log_dict_disc = self.loss(inputs,
reconstructions,
posterior,
1,
self.global_step,
last_layer=self.get_last_layer(),
split="val")
self.log("val/rec_loss", log_dict_ae["val/rec_loss"])
self.log_dict(log_dict_ae)
self.log_dict(log_dict_disc)
return self.log_dict
def configure_optimizers(self):
lr = self.learning_rate
opt_ae = torch.optim.Adam(list(self.encoder.parameters()) +
list(self.decoder.parameters()) +
list(self.quant_conv.parameters()) +
list(self.post_quant_conv.parameters()),
lr=lr,
betas=(0.5, 0.9))
opt_disc = torch.optim.Adam(self.loss.discriminator.parameters(),
lr=lr,
betas=(0.5, 0.9))
return [opt_ae, opt_disc], []
def get_last_layer(self):
return self.decoder.conv_out.weight
@torch.no_grad()
def log_images(self, batch, only_inputs=False, **kwargs):
log = dict()
x = self.get_input(batch, self.image_key)
x = x.to(self.device)
if not only_inputs:
xrec, posterior = self(x)
if x.shape[1] > 3:
# colorize with random projection
assert xrec.shape[1] > 3
x = self.to_rgb(x)
xrec = self.to_rgb(xrec)
log["samples"] = self.decode(torch.randn_like(posterior.sample()))
log["reconstructions"] = xrec
log["inputs"] = x
return log
def to_rgb(self, x):
assert self.image_key == "segmentation"
if not hasattr(self, "colorize"):
self.register_buffer("colorize",
torch.randn(3, x.shape[1], 1, 1).to(x))
x = F.conv2d(x, weight=self.colorize)
x = 2. * (x - x.min()) / (x.max() - x.min()) - 1.
return x
class IdentityFirstStage(torch.nn.Module):
def __init__(self, *args, vq_interface=False, **kwargs):
self.vq_interface = vq_interface # TODO: Should be true by default but check to not break older stuff
super().__init__()
def encode(self, x, *args, **kwargs):
return x
def decode(self, x, *args, **kwargs):
return x
def quantize(self, x, *args, **kwargs):
if self.vq_interface:
return x, None, [None, None, None]
return x
def forward(self, x, *args, **kwargs):
return x

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,217 @@
"""
Contains torch Modules that correspond to basic network building blocks, like
MLP, RNN, and CNN backbones.
"""
import abc
import numpy as np
import torch
import torch.nn.functional as F
class Module(torch.nn.Module):
"""
Base class for networks. The only difference from torch.nn.Module is that it
requires implementing @output_shape.
"""
@abc.abstractmethod
def output_shape(self, input_shape=None):
"""
Function to compute output shape from inputs to this module.
Args:
input_shape (iterable of int): shape of input. Does not include batch dimension.
Some modules may not need this argument, if their output does not depend
on the size of the input, or if they assume fixed size input.
Returns:
out_shape ([int]): list of integers corresponding to output shape
"""
raise NotImplementedError
"""
================================================
Visual Backbone Networks
================================================
"""
class ConvBase(Module):
"""
Base class for ConvNets.
"""
def __init__(self):
super(ConvBase, self).__init__()
# dirty hack - re-implement to pass the buck onto subclasses from ABC parent
def output_shape(self, input_shape):
"""
Function to compute output shape from inputs to this module.
Args:
input_shape (iterable of int): shape of input. Does not include batch dimension.
Some modules may not need this argument, if their output does not depend
on the size of the input, or if they assume fixed size input.
Returns:
out_shape ([int]): list of integers corresponding to output shape
"""
raise NotImplementedError
def forward(self, inputs):
x = self.nets(inputs)
if list(self.output_shape(list(inputs.shape)[1:])) != list(
x.shape)[1:]:
raise ValueError('Size mismatch: expect size %s, but got size %s' %
(str(self.output_shape(list(
inputs.shape)[1:])), str(list(x.shape)[1:])))
return x
class SpatialSoftmax(ConvBase):
"""
Spatial Softmax Layer.
Based on Deep Spatial Autoencoders for Visuomotor Learning by Finn et al.
https://rll.berkeley.edu/dsae/dsae.pdf
"""
def __init__(
self,
input_shape,
num_kp=32,
temperature=1.,
learnable_temperature=False,
output_variance=False,
noise_std=0.0,
):
"""
Args:
input_shape (list): shape of the input feature (C, H, W)
num_kp (int): number of keypoints (None for not using spatialsoftmax)
temperature (float): temperature term for the softmax.
learnable_temperature (bool): whether to learn the temperature
output_variance (bool): treat attention as a distribution, and compute second-order statistics to return
noise_std (float): add random spatial noise to the predicted keypoints
"""
super(SpatialSoftmax, self).__init__()
assert len(input_shape) == 3
self._in_c, self._in_h, self._in_w = input_shape # (C, H, W)
if num_kp is not None:
self.nets = torch.nn.Conv2d(self._in_c, num_kp, kernel_size=1)
self._num_kp = num_kp
else:
self.nets = None
self._num_kp = self._in_c
self.learnable_temperature = learnable_temperature
self.output_variance = output_variance
self.noise_std = noise_std
if self.learnable_temperature:
# temperature will be learned
temperature = torch.nn.Parameter(torch.ones(1) * temperature,
requires_grad=True)
self.register_parameter('temperature', temperature)
else:
# temperature held constant after initialization
temperature = torch.nn.Parameter(torch.ones(1) * temperature,
requires_grad=False)
self.register_buffer('temperature', temperature)
pos_x, pos_y = np.meshgrid(np.linspace(-1., 1., self._in_w),
np.linspace(-1., 1., self._in_h))
pos_x = torch.from_numpy(pos_x.reshape(1, self._in_h *
self._in_w)).float()
pos_y = torch.from_numpy(pos_y.reshape(1, self._in_h *
self._in_w)).float()
self.register_buffer('pos_x', pos_x)
self.register_buffer('pos_y', pos_y)
self.kps = None
def __repr__(self):
"""Pretty print network."""
header = format(str(self.__class__.__name__))
return header + '(num_kp={}, temperature={}, noise={})'.format(
self._num_kp, self.temperature.item(), self.noise_std)
def output_shape(self, input_shape):
"""
Function to compute output shape from inputs to this module.
Args:
input_shape (iterable of int): shape of input. Does not include batch dimension.
Some modules may not need this argument, if their output does not depend
on the size of the input, or if they assume fixed size input.
Returns:
out_shape ([int]): list of integers corresponding to output shape
"""
assert (len(input_shape) == 3)
assert (input_shape[0] == self._in_c)
return [self._num_kp, 2]
def forward(self, feature):
"""
Forward pass through spatial softmax layer. For each keypoint, a 2D spatial
probability distribution is created using a softmax, where the support is the
pixel locations. This distribution is used to compute the expected value of
the pixel location, which becomes a keypoint of dimension 2. K such keypoints
are created.
Returns:
out (torch.Tensor or tuple): mean keypoints of shape [B, K, 2], and possibly
keypoint variance of shape [B, K, 2, 2] corresponding to the covariance
under the 2D spatial softmax distribution
"""
assert (feature.shape[1] == self._in_c)
assert (feature.shape[2] == self._in_h)
assert (feature.shape[3] == self._in_w)
if self.nets is not None:
feature = self.nets(feature)
# [B, K, H, W] -> [B * K, H * W] where K is number of keypoints
feature = feature.reshape(-1, self._in_h * self._in_w)
# 2d softmax normalization
attention = F.softmax(feature / self.temperature, dim=-1)
# [1, H * W] x [B * K, H * W] -> [B * K, 1] for spatial coordinate mean in x and y dimensions
expected_x = torch.sum(self.pos_x * attention, dim=1, keepdim=True)
expected_y = torch.sum(self.pos_y * attention, dim=1, keepdim=True)
# stack to [B * K, 2]
expected_xy = torch.cat([expected_x, expected_y], 1)
# reshape to [B, K, 2]
feature_keypoints = expected_xy.view(-1, self._num_kp, 2)
if self.training:
noise = torch.randn_like(feature_keypoints) * self.noise_std
feature_keypoints += noise
if self.output_variance:
# treat attention as a distribution, and compute second-order statistics to return
expected_xx = torch.sum(self.pos_x * self.pos_x * attention,
dim=1,
keepdim=True)
expected_yy = torch.sum(self.pos_y * self.pos_y * attention,
dim=1,
keepdim=True)
expected_xy = torch.sum(self.pos_x * self.pos_y * attention,
dim=1,
keepdim=True)
var_x = expected_xx - expected_x * expected_x
var_y = expected_yy - expected_y * expected_y
var_xy = expected_xy - expected_x * expected_y
# stack to [B * K, 4] and then reshape to [B, K, 2, 2] where last 2 dims are covariance matrix
feature_covar = torch.cat([var_x, var_xy, var_xy, var_y],
1).reshape(-1, self._num_kp, 2, 2)
feature_keypoints = (feature_keypoints, feature_covar)
if isinstance(feature_keypoints, tuple):
self.kps = (feature_keypoints[0].detach(),
feature_keypoints[1].detach())
else:
self.kps = feature_keypoints.detach()
return feature_keypoints

View File

@@ -0,0 +1,83 @@
from diffusers.optimization import (Union, SchedulerType, Optional, Optimizer,
TYPE_TO_SCHEDULER_FUNCTION)
def get_scheduler(name: Union[str, SchedulerType],
optimizer: Optimizer,
num_warmup_steps: Optional[int] = None,
num_training_steps: Optional[int] = None,
**kwargs):
"""
Added kwargs vs diffuser's original implementation
Unified API to get any scheduler from its name.
Args:
name (`str` or `SchedulerType`):
The name of the scheduler to use.
optimizer (`torch.optim.Optimizer`):
The optimizer that will be used during training.
num_warmup_steps (`int`, *optional*):
The number of warmup steps to do. This is not required by all schedulers (hence the argument being
optional), the function will raise an error if it's unset and the scheduler type requires it.
num_training_steps (`int``, *optional*):
The number of training steps to do. This is not required by all schedulers (hence the argument being
optional), the function will raise an error if it's unset and the scheduler type requires it.
"""
name = SchedulerType(name)
schedule_func = TYPE_TO_SCHEDULER_FUNCTION[name]
if name == SchedulerType.CONSTANT:
return schedule_func(optimizer, **kwargs)
# All other schedulers require `num_warmup_steps`
if num_warmup_steps is None:
raise ValueError(
f"{name} requires `num_warmup_steps`, please provide that argument."
)
if name == SchedulerType.CONSTANT_WITH_WARMUP:
return schedule_func(optimizer,
num_warmup_steps=num_warmup_steps,
**kwargs)
# All other schedulers require `num_training_steps`
if num_training_steps is None:
raise ValueError(
f"{name} requires `num_training_steps`, please provide that argument."
)
return schedule_func(optimizer,
num_warmup_steps=num_warmup_steps,
num_training_steps=num_training_steps,
**kwargs)
import torch
from torch.optim.lr_scheduler import _LRScheduler
import pytorch_lightning as pl
from diffusers.optimization import TYPE_TO_SCHEDULER_FUNCTION, SchedulerType
class SelectiveLRScheduler(_LRScheduler):
def __init__(self,
optimizer,
base_scheduler,
group_indices,
default_lr=[1e-5, 1e-4],
last_epoch=-1):
self.base_scheduler = base_scheduler
self.group_indices = group_indices # Indices of parameter groups to update
self.default_lr = default_lr
super().__init__(optimizer, last_epoch)
def step(self, epoch=None):
self.base_scheduler.step()
base_lrs = self.base_scheduler.get_last_lr()
for idx, group in enumerate(self.optimizer.param_groups):
if idx in self.group_indices:
group['lr'] = base_lrs[idx]
else:
# Reset the learning rate to its initial value
group['lr'] = self.default_lr[idx]

View File

@@ -0,0 +1,16 @@
import torch.nn as nn
class ModuleAttrMixin(nn.Module):
def __init__(self):
super().__init__()
self._dummy_variable = nn.Parameter()
@property
def device(self):
return next(iter(self.parameters())).device
@property
def dtype(self):
return next(iter(self.parameters())).dtype

View File

@@ -0,0 +1,91 @@
import collections
import torch
import torch.nn as nn
from typing import Dict, Callable, List
def dict_apply(
x: Dict[str, torch.Tensor],
func: Callable[[torch.Tensor],
torch.Tensor]) -> Dict[str, torch.Tensor]:
result = dict()
for key, value in x.items():
if isinstance(value, dict):
result[key] = dict_apply(value, func)
else:
result[key] = func(value)
return result
def pad_remaining_dims(x, target):
assert x.shape == target.shape[:len(x.shape)]
return x.reshape(x.shape + (1, ) * (len(target.shape) - len(x.shape)))
def dict_apply_split(
x: Dict[str, torch.Tensor], split_func: Callable[[torch.Tensor],
Dict[str, torch.Tensor]]
) -> Dict[str, torch.Tensor]:
results = collections.defaultdict(dict)
for key, value in x.items():
result = split_func(value)
for k, v in result.items():
results[k][key] = v
return results
def dict_apply_reduce(
x: List[Dict[str,
torch.Tensor]], reduce_func: Callable[[List[torch.Tensor]],
torch.Tensor]
) -> Dict[str, torch.Tensor]:
result = dict()
for key in x[0].keys():
result[key] = reduce_func([x_[key] for x_ in x])
return result
def replace_submodules(root_module: nn.Module, predicate: Callable[[nn.Module],
bool],
func: Callable[[nn.Module], nn.Module]) -> nn.Module:
"""
predicate: Return true if the module is to be replaced.
func: Return new module to use.
"""
if predicate(root_module):
return func(root_module)
bn_list = [
k.split('.')
for k, m in root_module.named_modules(remove_duplicate=True)
if predicate(m)
]
for *parent, k in bn_list:
parent_module = root_module
if len(parent) > 0:
parent_module = root_module.get_submodule('.'.join(parent))
if isinstance(parent_module, nn.Sequential):
src_module = parent_module[int(k)]
else:
src_module = getattr(parent_module, k)
tgt_module = func(src_module)
if isinstance(parent_module, nn.Sequential):
parent_module[int(k)] = tgt_module
else:
setattr(parent_module, k, tgt_module)
# verify that all BN are replaced
bn_list = [
k.split('.')
for k, m in root_module.named_modules(remove_duplicate=True)
if predicate(m)
]
assert len(bn_list) == 0
return root_module
def optimizer_to(optimizer, device):
for state in optimizer.state.values():
for k, v in state.items():
if isinstance(v, torch.Tensor):
state[k] = v.to(device=device)
return optimizer

View File

@@ -0,0 +1,960 @@
"""
A collection of utilities for working with nested tensor structures consisting
of numpy arrays and torch tensors.
"""
import collections
import numpy as np
import torch
def recursive_dict_list_tuple_apply(x, type_func_dict):
"""
Recursively apply functions to a nested dictionary or list or tuple, given a dictionary of
{data_type: function_to_apply}.
Args:
x (dict or list or tuple): a possibly nested dictionary or list or tuple
type_func_dict (dict): a mapping from data types to the functions to be
applied for each data type.
Returns:
y (dict or list or tuple): new nested dict-list-tuple
"""
assert (list not in type_func_dict)
assert (tuple not in type_func_dict)
assert (dict not in type_func_dict)
if isinstance(x, (dict, collections.OrderedDict)):
new_x = collections.OrderedDict() if isinstance(
x, collections.OrderedDict) else dict()
for k, v in x.items():
new_x[k] = recursive_dict_list_tuple_apply(v, type_func_dict)
return new_x
elif isinstance(x, (list, tuple)):
ret = [recursive_dict_list_tuple_apply(v, type_func_dict) for v in x]
if isinstance(x, tuple):
ret = tuple(ret)
return ret
else:
for t, f in type_func_dict.items():
if isinstance(x, t):
return f(x)
else:
raise NotImplementedError('Cannot handle data type %s' %
str(type(x)))
def map_tensor(x, func):
"""
Apply function @func to torch.Tensor objects in a nested dictionary or
list or tuple.
Args:
x (dict or list or tuple): a possibly nested dictionary or list or tuple
func (function): function to apply to each tensor
Returns:
y (dict or list or tuple): new nested dict-list-tuple
"""
return recursive_dict_list_tuple_apply(x, {
torch.Tensor: func,
type(None): lambda x: x,
})
def map_ndarray(x, func):
"""
Apply function @func to np.ndarray objects in a nested dictionary or
list or tuple.
Args:
x (dict or list or tuple): a possibly nested dictionary or list or tuple
func (function): function to apply to each array
Returns:
y (dict or list or tuple): new nested dict-list-tuple
"""
return recursive_dict_list_tuple_apply(x, {
np.ndarray: func,
type(None): lambda x: x,
})
def map_tensor_ndarray(x, tensor_func, ndarray_func):
"""
Apply function @tensor_func to torch.Tensor objects and @ndarray_func to
np.ndarray objects in a nested dictionary or list or tuple.
Args:
x (dict or list or tuple): a possibly nested dictionary or list or tuple
tensor_func (function): function to apply to each tensor
ndarray_Func (function): function to apply to each array
Returns:
y (dict or list or tuple): new nested dict-list-tuple
"""
return recursive_dict_list_tuple_apply(
x, {
torch.Tensor: tensor_func,
np.ndarray: ndarray_func,
type(None): lambda x: x,
})
def clone(x):
"""
Clones all torch tensors and numpy arrays in nested dictionary or list
or tuple and returns a new nested structure.
Args:
x (dict or list or tuple): a possibly nested dictionary or list or tuple
Returns:
y (dict or list or tuple): new nested dict-list-tuple
"""
return recursive_dict_list_tuple_apply(
x, {
torch.Tensor: lambda x: x.clone(),
np.ndarray: lambda x: x.copy(),
type(None): lambda x: x,
})
def detach(x):
"""
Detaches all torch tensors in nested dictionary or list
or tuple and returns a new nested structure.
Args:
x (dict or list or tuple): a possibly nested dictionary or list or tuple
Returns:
y (dict or list or tuple): new nested dict-list-tuple
"""
return recursive_dict_list_tuple_apply(x, {
torch.Tensor: lambda x: x.detach(),
})
def to_batch(x):
"""
Introduces a leading batch dimension of 1 for all torch tensors and numpy
arrays in nested dictionary or list or tuple and returns a new nested structure.
Args:
x (dict or list or tuple): a possibly nested dictionary or list or tuple
Returns:
y (dict or list or tuple): new nested dict-list-tuple
"""
return recursive_dict_list_tuple_apply(
x, {
torch.Tensor: lambda x: x[None, ...],
np.ndarray: lambda x: x[None, ...],
type(None): lambda x: x,
})
def to_sequence(x):
"""
Introduces a time dimension of 1 at dimension 1 for all torch tensors and numpy
arrays in nested dictionary or list or tuple and returns a new nested structure.
Args:
x (dict or list or tuple): a possibly nested dictionary or list or tuple
Returns:
y (dict or list or tuple): new nested dict-list-tuple
"""
return recursive_dict_list_tuple_apply(
x, {
torch.Tensor: lambda x: x[:, None, ...],
np.ndarray: lambda x: x[:, None, ...],
type(None): lambda x: x,
})
def index_at_time(x, ind):
"""
Indexes all torch tensors and numpy arrays in dimension 1 with index @ind in
nested dictionary or list or tuple and returns a new nested structure.
Args:
x (dict or list or tuple): a possibly nested dictionary or list or tuple
ind (int): index
Returns:
y (dict or list or tuple): new nested dict-list-tuple
"""
return recursive_dict_list_tuple_apply(
x, {
torch.Tensor: lambda x: x[:, ind, ...],
np.ndarray: lambda x: x[:, ind, ...],
type(None): lambda x: x,
})
def unsqueeze(x, dim):
"""
Adds dimension of size 1 at dimension @dim in all torch tensors and numpy arrays
in nested dictionary or list or tuple and returns a new nested structure.
Args:
x (dict or list or tuple): a possibly nested dictionary or list or tuple
dim (int): dimension
Returns:
y (dict or list or tuple): new nested dict-list-tuple
"""
return recursive_dict_list_tuple_apply(
x, {
torch.Tensor: lambda x: x.unsqueeze(dim=dim),
np.ndarray: lambda x: np.expand_dims(x, axis=dim),
type(None): lambda x: x,
})
def contiguous(x):
"""
Makes all torch tensors and numpy arrays contiguous in nested dictionary or
list or tuple and returns a new nested structure.
Args:
x (dict or list or tuple): a possibly nested dictionary or list or tuple
Returns:
y (dict or list or tuple): new nested dict-list-tuple
"""
return recursive_dict_list_tuple_apply(
x, {
torch.Tensor: lambda x: x.contiguous(),
np.ndarray: lambda x: np.ascontiguousarray(x),
type(None): lambda x: x,
})
def to_device(x, device):
"""
Sends all torch tensors in nested dictionary or list or tuple to device
@device, and returns a new nested structure.
Args:
x (dict or list or tuple): a possibly nested dictionary or list or tuple
device (torch.Device): device to send tensors to
Returns:
y (dict or list or tuple): new nested dict-list-tuple
"""
return recursive_dict_list_tuple_apply(
x, {
torch.Tensor: lambda x, d=device: x.to(d),
type(None): lambda x: x,
})
def to_tensor(x):
"""
Converts all numpy arrays in nested dictionary or list or tuple to
torch tensors (and leaves existing torch Tensors as-is), and returns
a new nested structure.
Args:
x (dict or list or tuple): a possibly nested dictionary or list or tuple
Returns:
y (dict or list or tuple): new nested dict-list-tuple
"""
return recursive_dict_list_tuple_apply(
x, {
torch.Tensor: lambda x: x,
np.ndarray: lambda x: torch.from_numpy(x),
type(None): lambda x: x,
})
def to_numpy(x):
"""
Converts all torch tensors in nested dictionary or list or tuple to
numpy (and leaves existing numpy arrays as-is), and returns
a new nested structure.
Args:
x (dict or list or tuple): a possibly nested dictionary or list or tuple
Returns:
y (dict or list or tuple): new nested dict-list-tuple
"""
def f(tensor):
if tensor.is_cuda:
return tensor.detach().cpu().numpy()
else:
return tensor.detach().numpy()
return recursive_dict_list_tuple_apply(x, {
torch.Tensor: f,
np.ndarray: lambda x: x,
type(None): lambda x: x,
})
def to_list(x):
"""
Converts all torch tensors and numpy arrays in nested dictionary or list
or tuple to a list, and returns a new nested structure. Useful for
json encoding.
Args:
x (dict or list or tuple): a possibly nested dictionary or list or tuple
Returns:
y (dict or list or tuple): new nested dict-list-tuple
"""
def f(tensor):
if tensor.is_cuda:
return tensor.detach().cpu().numpy().tolist()
else:
return tensor.detach().numpy().tolist()
return recursive_dict_list_tuple_apply(
x, {
torch.Tensor: f,
np.ndarray: lambda x: x.tolist(),
type(None): lambda x: x,
})
def to_float(x):
"""
Converts all torch tensors and numpy arrays in nested dictionary or list
or tuple to float type entries, and returns a new nested structure.
Args:
x (dict or list or tuple): a possibly nested dictionary or list or tuple
Returns:
y (dict or list or tuple): new nested dict-list-tuple
"""
return recursive_dict_list_tuple_apply(
x, {
torch.Tensor: lambda x: x.float(),
np.ndarray: lambda x: x.astype(np.float32),
type(None): lambda x: x,
})
def to_uint8(x):
"""
Converts all torch tensors and numpy arrays in nested dictionary or list
or tuple to uint8 type entries, and returns a new nested structure.
Args:
x (dict or list or tuple): a possibly nested dictionary or list or tuple
Returns:
y (dict or list or tuple): new nested dict-list-tuple
"""
return recursive_dict_list_tuple_apply(
x, {
torch.Tensor: lambda x: x.byte(),
np.ndarray: lambda x: x.astype(np.uint8),
type(None): lambda x: x,
})
def to_torch(x, device):
"""
Converts all numpy arrays and torch tensors in nested dictionary or list or tuple to
torch tensors on device @device and returns a new nested structure.
Args:
x (dict or list or tuple): a possibly nested dictionary or list or tuple
device (torch.Device): device to send tensors to
Returns:
y (dict or list or tuple): new nested dict-list-tuple
"""
return to_device(to_float(to_tensor(x)), device)
def to_one_hot_single(tensor, num_class):
"""
Convert tensor to one-hot representation, assuming a certain number of total class labels.
Args:
tensor (torch.Tensor): tensor containing integer labels
num_class (int): number of classes
Returns:
x (torch.Tensor): tensor containing one-hot representation of labels
"""
x = torch.zeros(tensor.size() + (num_class, )).to(tensor.device)
x.scatter_(-1, tensor.unsqueeze(-1), 1)
return x
def to_one_hot(tensor, num_class):
"""
Convert all tensors in nested dictionary or list or tuple to one-hot representation,
assuming a certain number of total class labels.
Args:
tensor (dict or list or tuple): a possibly nested dictionary or list or tuple
num_class (int): number of classes
Returns:
y (dict or list or tuple): new nested dict-list-tuple
"""
return map_tensor(tensor,
func=lambda x, nc=num_class: to_one_hot_single(x, nc))
def flatten_single(x, begin_axis=1):
"""
Flatten a tensor in all dimensions from @begin_axis onwards.
Args:
x (torch.Tensor): tensor to flatten
begin_axis (int): which axis to flatten from
Returns:
y (torch.Tensor): flattened tensor
"""
fixed_size = x.size()[:begin_axis]
_s = list(fixed_size) + [-1]
return x.reshape(*_s)
def flatten(x, begin_axis=1):
"""
Flatten all tensors in nested dictionary or list or tuple, from @begin_axis onwards.
Args:
x (dict or list or tuple): a possibly nested dictionary or list or tuple
begin_axis (int): which axis to flatten from
Returns:
y (dict or list or tuple): new nested dict-list-tuple
"""
return recursive_dict_list_tuple_apply(x, {
torch.Tensor:
lambda x, b=begin_axis: flatten_single(x, begin_axis=b),
})
def reshape_dimensions_single(x, begin_axis, end_axis, target_dims):
"""
Reshape selected dimensions in a tensor to a target dimension.
Args:
x (torch.Tensor): tensor to reshape
begin_axis (int): begin dimension
end_axis (int): end dimension
target_dims (tuple or list): target shape for the range of dimensions
(@begin_axis, @end_axis)
Returns:
y (torch.Tensor): reshaped tensor
"""
assert (begin_axis <= end_axis)
assert (begin_axis >= 0)
assert (end_axis < len(x.shape))
assert (isinstance(target_dims, (tuple, list)))
s = x.shape
final_s = []
for i in range(len(s)):
if i == begin_axis:
final_s.extend(target_dims)
elif i < begin_axis or i > end_axis:
final_s.append(s[i])
return x.reshape(*final_s)
def reshape_dimensions(x, begin_axis, end_axis, target_dims):
"""
Reshape selected dimensions for all tensors in nested dictionary or list or tuple
to a target dimension.
Args:
x (dict or list or tuple): a possibly nested dictionary or list or tuple
begin_axis (int): begin dimension
end_axis (int): end dimension
target_dims (tuple or list): target shape for the range of dimensions
(@begin_axis, @end_axis)
Returns:
y (dict or list or tuple): new nested dict-list-tuple
"""
return recursive_dict_list_tuple_apply(
x, {
torch.Tensor:
lambda x, b=begin_axis, e=end_axis, t=target_dims:
reshape_dimensions_single(
x, begin_axis=b, end_axis=e, target_dims=t),
np.ndarray:
lambda x, b=begin_axis, e=end_axis, t=target_dims:
reshape_dimensions_single(
x, begin_axis=b, end_axis=e, target_dims=t),
type(None):
lambda x: x,
})
def join_dimensions(x, begin_axis, end_axis):
"""
Joins all dimensions between dimensions (@begin_axis, @end_axis) into a flat dimension, for
all tensors in nested dictionary or list or tuple.
Args:
x (dict or list or tuple): a possibly nested dictionary or list or tuple
begin_axis (int): begin dimension
end_axis (int): end dimension
Returns:
y (dict or list or tuple): new nested dict-list-tuple
"""
return recursive_dict_list_tuple_apply(
x, {
torch.Tensor:
lambda x, b=begin_axis, e=end_axis: reshape_dimensions_single(
x, begin_axis=b, end_axis=e, target_dims=[-1]),
np.ndarray:
lambda x, b=begin_axis, e=end_axis: reshape_dimensions_single(
x, begin_axis=b, end_axis=e, target_dims=[-1]),
type(None):
lambda x: x,
})
def expand_at_single(x, size, dim):
"""
Expand a tensor at a single dimension @dim by @size
Args:
x (torch.Tensor): input tensor
size (int): size to expand
dim (int): dimension to expand
Returns:
y (torch.Tensor): expanded tensor
"""
assert dim < x.ndimension()
assert x.shape[dim] == 1
expand_dims = [-1] * x.ndimension()
expand_dims[dim] = size
return x.expand(*expand_dims)
def expand_at(x, size, dim):
"""
Expand all tensors in nested dictionary or list or tuple at a single
dimension @dim by @size.
Args:
x (dict or list or tuple): a possibly nested dictionary or list or tuple
size (int): size to expand
dim (int): dimension to expand
Returns:
y (dict or list or tuple): new nested dict-list-tuple
"""
return map_tensor(x, lambda t, s=size, d=dim: expand_at_single(t, s, d))
def unsqueeze_expand_at(x, size, dim):
"""
Unsqueeze and expand a tensor at a dimension @dim by @size.
Args:
x (dict or list or tuple): a possibly nested dictionary or list or tuple
size (int): size to expand
dim (int): dimension to unsqueeze and expand
Returns:
y (dict or list or tuple): new nested dict-list-tuple
"""
x = unsqueeze(x, dim)
return expand_at(x, size, dim)
def repeat_by_expand_at(x, repeats, dim):
"""
Repeat a dimension by combining expand and reshape operations.
Args:
x (dict or list or tuple): a possibly nested dictionary or list or tuple
repeats (int): number of times to repeat the target dimension
dim (int): dimension to repeat on
Returns:
y (dict or list or tuple): new nested dict-list-tuple
"""
x = unsqueeze_expand_at(x, repeats, dim + 1)
return join_dimensions(x, dim, dim + 1)
def named_reduce_single(x, reduction, dim):
"""
Reduce tensor at a dimension by named reduction functions.
Args:
x (torch.Tensor): tensor to be reduced
reduction (str): one of ["sum", "max", "mean", "flatten"]
dim (int): dimension to be reduced (or begin axis for flatten)
Returns:
y (torch.Tensor): reduced tensor
"""
assert x.ndimension() > dim
assert reduction in ["sum", "max", "mean", "flatten"]
if reduction == "flatten":
x = flatten(x, begin_axis=dim)
elif reduction == "max":
x = torch.max(x, dim=dim)[0] # [B, D]
elif reduction == "sum":
x = torch.sum(x, dim=dim)
else:
x = torch.mean(x, dim=dim)
return x
def named_reduce(x, reduction, dim):
"""
Reduces all tensors in nested dictionary or list or tuple at a dimension
using a named reduction function.
Args:
x (dict or list or tuple): a possibly nested dictionary or list or tuple
reduction (str): one of ["sum", "max", "mean", "flatten"]
dim (int): dimension to be reduced (or begin axis for flatten)
Returns:
y (dict or list or tuple): new nested dict-list-tuple
"""
return map_tensor(
x, func=lambda t, r=reduction, d=dim: named_reduce_single(t, r, d))
def gather_along_dim_with_dim_single(x, target_dim, source_dim, indices):
"""
This function indexes out a target dimension of a tensor in a structured way,
by allowing a different value to be selected for each member of a flat index
tensor (@indices) corresponding to a source dimension. This can be interpreted
as moving along the source dimension, using the corresponding index value
in @indices to select values for all other dimensions outside of the
source and target dimensions. A common use case is to gather values
in target dimension 1 for each batch member (target dimension 0).
Args:
x (torch.Tensor): tensor to gather values for
target_dim (int): dimension to gather values along
source_dim (int): dimension to hold constant and use for gathering values
from the other dimensions
indices (torch.Tensor): flat index tensor with same shape as tensor @x along
@source_dim
Returns:
y (torch.Tensor): gathered tensor, with dimension @target_dim indexed out
"""
assert len(indices.shape) == 1
assert x.shape[source_dim] == indices.shape[0]
# unsqueeze in all dimensions except the source dimension
new_shape = [1] * x.ndimension()
new_shape[source_dim] = -1
indices = indices.reshape(*new_shape)
# repeat in all dimensions - but preserve shape of source dimension,
# and make sure target_dimension has singleton dimension
expand_shape = list(x.shape)
expand_shape[source_dim] = -1
expand_shape[target_dim] = 1
indices = indices.expand(*expand_shape)
out = x.gather(dim=target_dim, index=indices)
return out.squeeze(target_dim)
def gather_along_dim_with_dim(x, target_dim, source_dim, indices):
"""
Apply @gather_along_dim_with_dim_single to all tensors in a nested
dictionary or list or tuple.
Args:
x (dict or list or tuple): a possibly nested dictionary or list or tuple
target_dim (int): dimension to gather values along
source_dim (int): dimension to hold constant and use for gathering values
from the other dimensions
indices (torch.Tensor): flat index tensor with same shape as tensor @x along
@source_dim
Returns:
y (dict or list or tuple): new nested dict-list-tuple
"""
return map_tensor(x,
lambda y, t=target_dim, s=source_dim, i=indices:
gather_along_dim_with_dim_single(y, t, s, i))
def gather_sequence_single(seq, indices):
"""
Given a tensor with leading dimensions [B, T, ...], gather an element from each sequence in
the batch given an index for each sequence.
Args:
seq (torch.Tensor): tensor with leading dimensions [B, T, ...]
indices (torch.Tensor): tensor indices of shape [B]
Return:
y (torch.Tensor): indexed tensor of shape [B, ....]
"""
return gather_along_dim_with_dim_single(seq,
target_dim=1,
source_dim=0,
indices=indices)
def gather_sequence(seq, indices):
"""
Given a nested dictionary or list or tuple, gathers an element from each sequence of the batch
for tensors with leading dimensions [B, T, ...].
Args:
seq (dict or list or tuple): a possibly nested dictionary or list or tuple with tensors
of leading dimensions [B, T, ...]
indices (torch.Tensor): tensor indices of shape [B]
Returns:
y (dict or list or tuple): new nested dict-list-tuple with tensors of shape [B, ...]
"""
return gather_along_dim_with_dim(seq,
target_dim=1,
source_dim=0,
indices=indices)
def pad_sequence_single(seq,
padding,
batched=False,
pad_same=True,
pad_values=None):
"""
Pad input tensor or array @seq in the time dimension (dimension 1).
Args:
seq (np.ndarray or torch.Tensor): sequence to be padded
padding (tuple): begin and end padding, e.g. [1, 1] pads both begin and end of the sequence by 1
batched (bool): if sequence has the batch dimension
pad_same (bool): if pad by duplicating
pad_values (scalar or (ndarray, Tensor)): values to be padded if not pad_same
Returns:
padded sequence (np.ndarray or torch.Tensor)
"""
assert isinstance(seq, (np.ndarray, torch.Tensor))
assert pad_same or pad_values is not None
if pad_values is not None:
assert isinstance(pad_values, float)
repeat_func = np.repeat if isinstance(
seq, np.ndarray) else torch.repeat_interleave
concat_func = np.concatenate if isinstance(seq, np.ndarray) else torch.cat
ones_like_func = np.ones_like if isinstance(
seq, np.ndarray) else torch.ones_like
seq_dim = 1 if batched else 0
begin_pad = []
end_pad = []
if padding[0] > 0:
pad = seq[[0]] if pad_same else ones_like_func(seq[[0]]) * pad_values
begin_pad.append(repeat_func(pad, padding[0], seq_dim))
if padding[1] > 0:
pad = seq[[-1]] if pad_same else ones_like_func(seq[[-1]]) * pad_values
end_pad.append(repeat_func(pad, padding[1], seq_dim))
return concat_func(begin_pad + [seq] + end_pad, seq_dim)
def pad_sequence(seq, padding, batched=False, pad_same=True, pad_values=None):
"""
Pad a nested dictionary or list or tuple of sequence tensors in the time dimension (dimension 1).
Args:
seq (dict or list or tuple): a possibly nested dictionary or list or tuple with tensors
of leading dimensions [B, T, ...]
padding (tuple): begin and end padding, e.g. [1, 1] pads both begin and end of the sequence by 1
batched (bool): if sequence has the batch dimension
pad_same (bool): if pad by duplicating
pad_values (scalar or (ndarray, Tensor)): values to be padded if not pad_same
Returns:
padded sequence (dict or list or tuple)
"""
return recursive_dict_list_tuple_apply(
seq, {
torch.Tensor:
lambda x, p=padding, b=batched, ps=pad_same, pv=pad_values:
pad_sequence_single(x, p, b, ps, pv),
np.ndarray:
lambda x, p=padding, b=batched, ps=pad_same, pv=pad_values:
pad_sequence_single(x, p, b, ps, pv),
type(None):
lambda x: x,
})
def assert_size_at_dim_single(x, size, dim, msg):
"""
Ensure that array or tensor @x has size @size in dim @dim.
Args:
x (np.ndarray or torch.Tensor): input array or tensor
size (int): size that tensors should have at @dim
dim (int): dimension to check
msg (str): text to display if assertion fails
"""
assert x.shape[dim] == size, msg
def assert_size_at_dim(x, size, dim, msg):
"""
Ensure that arrays and tensors in nested dictionary or list or tuple have
size @size in dim @dim.
Args:
x (dict or list or tuple): a possibly nested dictionary or list or tuple
size (int): size that tensors should have at @dim
dim (int): dimension to check
"""
map_tensor(
x,
lambda t, s=size, d=dim, m=msg: assert_size_at_dim_single(t, s, d, m))
def get_shape(x):
"""
Get all shapes of arrays and tensors in nested dictionary or list or tuple.
Args:
x (dict or list or tuple): a possibly nested dictionary or list or tuple
Returns:
y (dict or list or tuple): new nested dict-list-tuple that contains each array or
tensor's shape
"""
return recursive_dict_list_tuple_apply(
x, {
torch.Tensor: lambda x: x.shape,
np.ndarray: lambda x: x.shape,
type(None): lambda x: x,
})
def list_of_flat_dict_to_dict_of_list(list_of_dict):
"""
Helper function to go from a list of flat dictionaries to a dictionary of lists.
By "flat" we mean that none of the values are dictionaries, but are numpy arrays,
floats, etc.
Args:
list_of_dict (list): list of flat dictionaries
Returns:
dict_of_list (dict): dictionary of lists
"""
assert isinstance(list_of_dict, list)
dic = collections.OrderedDict()
for i in range(len(list_of_dict)):
for k in list_of_dict[i]:
if k not in dic:
dic[k] = []
dic[k].append(list_of_dict[i][k])
return dic
def flatten_nested_dict_list(d, parent_key='', sep='_', item_key=''):
"""
Flatten a nested dict or list to a list.
For example, given a dict
{
a: 1
b: {
c: 2
}
c: 3
}
the function would return [(a, 1), (b_c, 2), (c, 3)]
Args:
d (dict, list): a nested dict or list to be flattened
parent_key (str): recursion helper
sep (str): separator for nesting keys
item_key (str): recursion helper
Returns:
list: a list of (key, value) tuples
"""
items = []
if isinstance(d, (tuple, list)):
new_key = parent_key + sep + item_key if len(
parent_key) > 0 else item_key
for i, v in enumerate(d):
items.extend(
flatten_nested_dict_list(v, new_key, sep=sep, item_key=str(i)))
return items
elif isinstance(d, dict):
new_key = parent_key + sep + item_key if len(
parent_key) > 0 else item_key
for k, v in d.items():
assert isinstance(k, str)
items.extend(
flatten_nested_dict_list(v, new_key, sep=sep, item_key=k))
return items
else:
new_key = parent_key + sep + item_key if len(
parent_key) > 0 else item_key
return [(new_key, d)]
def time_distributed(inputs,
op,
activation=None,
inputs_as_kwargs=False,
inputs_as_args=False,
**kwargs):
"""
Apply function @op to all tensors in nested dictionary or list or tuple @inputs in both the
batch (B) and time (T) dimension, where the tensors are expected to have shape [B, T, ...].
Will do this by reshaping tensors to [B * T, ...], passing through the op, and then reshaping
outputs to [B, T, ...].
Args:
inputs (list or tuple or dict): a possibly nested dictionary or list or tuple with tensors
of leading dimensions [B, T, ...]
op: a layer op that accepts inputs
activation: activation to apply at the output
inputs_as_kwargs (bool): whether to feed input as a kwargs dict to the op
inputs_as_args (bool) whether to feed input as a args list to the op
kwargs (dict): other kwargs to supply to the op
Returns:
outputs (dict or list or tuple): new nested dict-list-tuple with tensors of leading dimension [B, T].
"""
batch_size, seq_len = flatten_nested_dict_list(inputs)[0][1].shape[:2]
inputs = join_dimensions(inputs, 0, 1)
if inputs_as_kwargs:
outputs = op(**inputs, **kwargs)
elif inputs_as_args:
outputs = op(*inputs, **kwargs)
else:
outputs = op(inputs, **kwargs)
if activation is not None:
outputs = map_tensor(outputs, activation)
outputs = reshape_dimensions(outputs,
begin_axis=0,
end_axis=0,
target_dims=(batch_size, seq_len))
return outputs

View File

@@ -0,0 +1,701 @@
import logging
import torch
import torch.nn as nn
import einops
from einops import rearrange, repeat
from typing import Union
from unifolm_wma.models.diffusion_head.conv1d_components import (
Downsample1d, Upsample1d, Conv1dBlock)
from unifolm_wma.models.diffusion_head.positional_embedding import SinusoidalPosEmb
from unifolm_wma.models.diffusion_head.base_nets import SpatialSoftmax
from unifolm_wma.utils.basics import zero_module
from unifolm_wma.utils.common import (
checkpoint,
exists,
default,
)
from unifolm_wma.utils.utils import instantiate_from_config
logger = logging.getLogger(__name__)
class GEGLU(nn.Module):
def __init__(self, dim_in, dim_out):
super().__init__()
self.proj = nn.Linear(dim_in, dim_out * 2)
def forward(self, x):
x, gate = self.proj(x).chunk(2, dim=-1)
return x * F.gelu(gate)
class FeedForward(nn.Module):
def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.):
super().__init__()
inner_dim = int(dim * mult)
dim_out = default(dim_out, dim)
project_in = nn.Sequential(nn.Linear(
dim, inner_dim), nn.GELU()) if not glu else GEGLU(dim, inner_dim)
self.net = nn.Sequential(project_in, nn.Dropout(dropout),
nn.Linear(inner_dim, dim_out))
def forward(self, x):
return self.net(x)
class CrossAttention(nn.Module):
def __init__(self,
query_dim,
context_dim=None,
heads=8,
dim_head=64,
dropout=0.,
relative_position=False):
super().__init__()
inner_dim = dim_head * heads
context_dim = default(context_dim, query_dim)
self.scale = dim_head**-0.5
self.heads = heads
self.dim_head = dim_head
self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
self.to_out = nn.Sequential(nn.Linear(inner_dim, query_dim),
nn.Dropout(dropout))
def efficient_forward(self, x, context=None):
spatial_self_attn = (context is None)
k_ip, v_ip, out_ip = None, None, None
q = self.to_q(x)
if spatial_self_attn:
context = default(context, x)
k = self.to_k(context)
v = self.to_v(context)
b, _, _ = q.shape
q, k, v = map(
lambda t: t.unsqueeze(3).reshape(b, t.shape[
1], self.heads, self.dim_head).permute(0, 2, 1, 3).reshape(
b * self.heads, t.shape[1], self.dim_head).contiguous(),
(q, k, v),
)
# actually compute the attention, what we cannot get enough of
out = xformers.ops.memory_efficient_attention(q,
k,
v,
attn_bias=None,
op=None)
out = (out.unsqueeze(0).reshape(
b, self.heads, out.shape[1],
self.dim_head).permute(0, 2, 1,
3).reshape(b, out.shape[1],
self.heads * self.dim_head))
return self.to_out(out)
class BasicTransformerBlock(nn.Module):
def __init__(self,
dim,
n_heads,
d_head,
dropout=0.,
context_dim=None,
gated_ff=True,
checkpoint=True,
disable_self_attn=False,
attention_cls=None):
super().__init__()
attn_cls = CrossAttention if attention_cls is None else attention_cls
self.disable_self_attn = disable_self_attn
self.attn1 = attn_cls(
query_dim=dim,
heads=n_heads,
dim_head=d_head,
dropout=dropout,
context_dim=context_dim if self.disable_self_attn else None)
self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff)
self.attn2 = attn_cls(query_dim=dim,
context_dim=context_dim,
heads=n_heads,
dim_head=d_head,
dropout=dropout)
self.norm1 = nn.LayerNorm(dim)
self.norm2 = nn.LayerNorm(dim)
self.norm3 = nn.LayerNorm(dim)
self.checkpoint = checkpoint
def forward(self, x, context=None, **kwargs):
## implementation tricks: because checkpointing doesn't support non-tensor (e.g. None or scalar) arguments
input_tuple = (
x,
) ## should not be (x), otherwise *input_tuple will decouple x into multiple arguments
if context is not None:
input_tuple = (x, context)
return checkpoint(self._forward, input_tuple, self.parameters(),
self.checkpoint)
def _forward(self, x, context=None, mask=None):
x = self.attn1(self.norm1(x),
context=context if self.disable_self_attn else None,
mask=mask) + x
x = self.attn2(self.norm2(x), context=context, mask=mask) + x
x = self.ff(self.norm3(x)) + x
return x
class ActionLatentImageCrossAttention(nn.Module):
def __init__(self,
in_channels,
in_dim,
n_heads,
d_head,
depth=1,
dropout=0.,
context_dim=None,
use_checkpoint=True,
disable_self_attn=False,
use_linear=True):
super().__init__()
"""
in_channels: action input dim
"""
self.in_channels = in_channels
self.in_dim = in_dim
inner_dim = n_heads * d_head
self.norm = torch.nn.GroupNorm(num_groups=8,
num_channels=in_channels,
eps=1e-6,
affine=True)
self.proj_in_action = nn.Linear(in_dim, inner_dim)
self.proj_in_cond = nn.Linear(context_dim, inner_dim)
self.proj_out = zero_module(nn.Linear(inner_dim, in_dim))
self.use_linear = use_linear
attention_cls = None
self.transformer_blocks = nn.ModuleList([
BasicTransformerBlock(inner_dim,
n_heads,
d_head,
dropout=dropout,
context_dim=context_dim,
disable_self_attn=disable_self_attn,
checkpoint=use_checkpoint,
attention_cls=attention_cls)
for d in range(depth)
])
def forward(self, x, context=None, **kwargs):
ba, ca, da = x.shape
b, t, c, h, w = context.shape
context = rearrange(context, 'b t c h w -> b (t h w) c').contiguous()
x_in = x
x = self.norm(x) # ba x ja x d_in
if self.use_linear:
x = self.proj_in_action(x)
context = self.proj_in_cond(context)
for i, block in enumerate(self.transformer_blocks):
x = block(x, context=context, **kwargs)
if self.use_linear:
x = self.proj_out(x)
return x + x_in
class ConditionalResidualBlock1D(nn.Module):
def __init__(self,
in_channels,
out_channels,
cond_dim,
kernel_size=3,
n_groups=8,
cond_predict_scale=True,
use_linear_act_proj=False):
super().__init__()
self.blocks = nn.ModuleList([
Conv1dBlock(in_channels,
out_channels,
kernel_size,
n_groups=n_groups),
Conv1dBlock(out_channels,
out_channels,
kernel_size,
n_groups=n_groups),
])
self.cond_predict_scale = cond_predict_scale
self.use_linear_act_proj = use_linear_act_proj
self.out_channels = out_channels
# FiLM modulation https://arxiv.org/abs/1709.07871
# predicts per-channel scale and bias
cond_channels = out_channels
if cond_predict_scale and use_linear_act_proj:
cond_channels = out_channels * 2
self.cond_encoder = nn.Sequential(
nn.Mish(),
nn.Linear(cond_dim, cond_channels),
)
# make sure dimensions compatible
self.residual_conv = nn.Conv1d(in_channels, out_channels, 1) \
if in_channels != out_channels else nn.Identity()
def forward(self, x, cond=None):
'''
x : [ batch_size x in_channels x horizon ]
cond : [ batch_size x cond_dim]
returns:
out : [ batch_size x out_channels x horizon ]
'''
B, T, _ = cond.shape
out = self.blocks[0](x)
if self.cond_predict_scale:
embed = self.cond_encoder(cond)
if self.use_linear_act_proj:
embed = embed.reshape(B * T, -1)
embed = embed.reshape(-1, 2, self.out_channels, 1)
else:
embed = embed.reshape(embed.shape[0], 2, self.out_channels, 1)
scale = embed[:, 0, ...]
bias = embed[:, 1, ...]
out = scale * out + bias
# else:
# out = out + embed
out = self.blocks[1](out)
out = out + self.residual_conv(x)
return out
class ConditionalUnet1D(nn.Module):
def __init__(self,
input_dim,
n_obs_steps=1,
local_cond_dim=None,
global_cond_dim=None,
diffusion_step_embed_dim=256,
down_dims=[256, 512, 1024],
kernel_size=3,
n_groups=8,
cond_predict_scale=False,
horizon=16,
num_head_channels=64,
use_linear_attn=True,
use_linear_act_proj=True,
act_proj_dim=32,
cond_cross_attention=False,
context_dims=None,
image_size=None,
imagen_cond_gradient=False,
last_frame_only=False,
use_imagen_mid_only=False,
use_z_only=False,
spatial_num_kp=32,
obs_encoder_config=None):
super().__init__()
self.n_obs_steps = n_obs_steps
self.obs_encoder = instantiate_from_config(obs_encoder_config)
all_dims = [input_dim] + list(down_dims)
start_dim = down_dims[0]
dsed = diffusion_step_embed_dim
diffusion_step_encoder = nn.Sequential(
SinusoidalPosEmb(dsed),
nn.Linear(dsed, dsed * 4),
nn.Mish(),
nn.Linear(dsed * 4, dsed),
)
cond_dim = dsed + self.obs_encoder.output_shape()[-1] * self.n_obs_steps
in_out = list(zip(all_dims[:-1], all_dims[1:]))
local_cond_encoder = None
down_modules = nn.ModuleList([])
dim_a_list = []
for ind, (dim_in, dim_out) in enumerate(in_out):
is_last = ind >= (len(in_out) - 1)
if ind == 0:
dim_a = horizon
else:
dim_a = horizon // 2 * ind
dim_a_list.append(dim_a)
# for attention
num_heads = dim_out // num_head_channels
dim_head = num_head_channels
if use_linear_act_proj:
if use_imagen_mid_only:
cur_cond_dim = cond_dim + 2 * context_dims[-1]
elif use_z_only:
cur_cond_dim = cond_dim + 2 * spatial_num_kp
else:
cur_cond_dim = cond_dim + 2 * context_dims[ind]
else:
cur_cond_dim = cond_dim + horizon * context_dims[ind]
down_modules.append(
nn.ModuleList([
ConditionalResidualBlock1D(
dim_in,
dim_out,
cond_dim=cur_cond_dim,
kernel_size=kernel_size,
n_groups=n_groups,
cond_predict_scale=cond_predict_scale,
use_linear_act_proj=use_linear_act_proj),
ConditionalResidualBlock1D(
dim_out,
dim_out,
cond_dim=cur_cond_dim,
kernel_size=kernel_size,
n_groups=n_groups,
cond_predict_scale=cond_predict_scale,
use_linear_act_proj=use_linear_act_proj),
ActionLatentImageCrossAttention(
dim_out,
dim_a,
num_heads,
dim_head,
context_dim=context_dims[ind],
use_linear=use_linear_attn)
if cond_cross_attention else nn.Identity(),
Downsample1d(dim_out) if not is_last else nn.Identity()
]))
mid_dim = all_dims[-1]
self.mid_modules = nn.ModuleList([
ConditionalResidualBlock1D(
mid_dim,
mid_dim,
cond_dim=cur_cond_dim,
kernel_size=kernel_size,
n_groups=n_groups,
cond_predict_scale=cond_predict_scale,
use_linear_act_proj=use_linear_act_proj),
ConditionalResidualBlock1D(
mid_dim,
mid_dim,
cond_dim=cur_cond_dim,
kernel_size=kernel_size,
n_groups=n_groups,
cond_predict_scale=cond_predict_scale,
use_linear_act_proj=use_linear_act_proj),
ActionLatentImageCrossAttention(mid_dim,
dim_a_list[-1],
num_heads,
dim_head,
context_dim=context_dims[-1],
use_linear=use_linear_attn)
if cond_cross_attention else nn.Identity(),
])
up_modules = nn.ModuleList([])
context_dims = context_dims[::-1]
for ind, (dim_in, dim_out) in enumerate(
reversed(in_out[1:] + [(down_dims[-1], down_dims[-1])])):
is_last = ind >= (len(in_out) - 1)
if use_linear_act_proj:
if use_imagen_mid_only:
cur_cond_dim = cond_dim + 2 * context_dims[0]
elif use_z_only:
cur_cond_dim = cond_dim + 2 * spatial_num_kp
else:
cur_cond_dim = cond_dim + 2 * context_dims[ind]
else:
cur_cond_dim = cond_dim + horizon * context_dims[ind]
up_modules.append(
nn.ModuleList([
ConditionalResidualBlock1D(
dim_out + dim_in,
dim_in,
cond_dim=cur_cond_dim,
kernel_size=kernel_size,
n_groups=n_groups,
cond_predict_scale=cond_predict_scale,
use_linear_act_proj=use_linear_act_proj),
ConditionalResidualBlock1D(
dim_in,
dim_in,
cond_dim=cur_cond_dim,
kernel_size=kernel_size,
n_groups=n_groups,
cond_predict_scale=cond_predict_scale,
use_linear_act_proj=use_linear_act_proj),
ActionLatentImageCrossAttention(
dim_in,
dim_a_list.pop(),
num_heads,
dim_head,
context_dim=context_dims[ind],
use_linear=use_linear_attn)
if cond_cross_attention else nn.Identity(),
Upsample1d(dim_in) if not is_last else nn.Identity()
]))
final_conv = nn.Sequential(
Conv1dBlock(start_dim, start_dim, kernel_size=kernel_size),
nn.Conv1d(start_dim, input_dim, 1),
)
if use_z_only:
h, w = image_size
self.spatial_softmax_blocks = nn.ModuleList(
[SpatialSoftmax((4, h, w), spatial_num_kp)])
else:
self.spatial_softmax_blocks = nn.ModuleList([])
context_dims = context_dims[::-1]
for ind, context_dim in enumerate(context_dims):
h, w = image_size
if ind != 0:
h //= 2**ind
w //= 2**ind
net = SpatialSoftmax((context_dim, h, w), context_dim)
self.spatial_softmax_blocks.append(net)
self.spatial_softmax_blocks.append(net)
self.spatial_softmax_blocks += self.spatial_softmax_blocks[
0:4][::-1]
self.diffusion_step_encoder = diffusion_step_encoder
self.local_cond_encoder = local_cond_encoder
self.up_modules = up_modules
self.down_modules = down_modules
self.final_conv = final_conv
self.cond_cross_attention = cond_cross_attention
self.use_linear_act_proj = use_linear_act_proj
self.proj_in_action = nn.Sequential(nn.Linear(1, act_proj_dim),
nn.LayerNorm(act_proj_dim))
self.proj_in_horizon = nn.Sequential(nn.Linear(horizon, act_proj_dim),
nn.LayerNorm(act_proj_dim))
self.proj_out_action = nn.Sequential(nn.LayerNorm(act_proj_dim),
nn.Linear(act_proj_dim, 1))
self.proj_out_horizon = nn.Sequential(nn.LayerNorm(act_proj_dim),
nn.Linear(act_proj_dim, horizon))
logger.info("number of parameters: %e",
sum(p.numel() for p in self.parameters()))
self.imagen_cond_gradient = imagen_cond_gradient
self.use_imagen_mid_only = use_imagen_mid_only
self.use_z_only = use_z_only
self.spatial_num_kp = spatial_num_kp
self.last_frame_only = last_frame_only
self.horizon = horizon
def forward(self,
sample: torch.Tensor,
timestep: Union[torch.Tensor, float, int],
imagen_cond=None,
cond=None,
**kwargs):
"""
sample: (B,T,input_dim)
timestep: (B,) or int, diffusion step
imagen_cond: a list of hidden info from video gen unet
cond: dict:
image: (B, 3, To, h, w)
agent_pos: (B, Ta, d)
output: (B,T,input_dim)
"""
if not self.imagen_cond_gradient:
imagen_cond = [c.detach() for c in imagen_cond]
cond = {'image': cond[0], 'agent_pos': cond[1]}
cond['image'] = cond['image'].permute(0, 2, 1, 3,
4)
cond['image'] = rearrange(cond['image'], 'b t c h w -> (b t) c h w')
cond['agent_pos'] = rearrange(cond['agent_pos'], 'b t d -> (b t) d')
B, T, D = sample.shape
if self.use_linear_act_proj:
sample = self.proj_in_action(sample.unsqueeze(-1))
global_cond = self.obs_encoder(cond)
global_cond = rearrange(global_cond,
'(b t) d -> b 1 (t d)',
b=B,
t=self.n_obs_steps)
global_cond = repeat(global_cond,
'b c d -> b (repeat c) d',
repeat=T)
else:
sample = einops.rearrange(sample, 'b h t -> b t h')
sample = self.proj_in_horizon(sample)
robo_state_cond = rearrange(robo_state_cond, 'b t d -> b 1 (t d)')
robo_state_cond = repeat(robo_state_cond,
'b c d -> b (repeat c) d',
repeat=2)
# 1. time
timesteps = timestep
if not torch.is_tensor(timesteps):
timesteps = torch.tensor([timesteps],
dtype=torch.long,
device=sample.device)
elif torch.is_tensor(timesteps) and len(timesteps.shape) == 0:
timesteps = timesteps[None].to(sample.device)
# Broadcast to batch dimension in a way that's compatible with ONNX/Core ML
timesteps = timesteps.expand(sample.shape[0])
global_feature = self.diffusion_step_encoder(timesteps)
(imagen_cond_down, imagen_cond_mid, imagen_cond_up
) = imagen_cond[0:4], imagen_cond[4], imagen_cond[5:] #NOTE HAND CODE
x = sample if not self.use_linear_act_proj else sample.reshape(
B * T, D, -1)
h = []
for idx, modules in enumerate(self.down_modules):
if self.cond_cross_attention:
(resnet, resnet2, crossatten, downsample) = modules
else:
(resnet, resnet2, _, downsample) = modules
# Access the cond from the unet embeds from video unet
if self.use_imagen_mid_only:
imagen_cond = imagen_cond_mid
elif self.use_z_only:
imagen_cond = kwargs['x_start'].permute(0, 2, 1, 3, 4)
else:
imagen_cond = imagen_cond_down[idx]
if self.last_frame_only:
imagen_cond = imagen_cond[:, -1].unsqueeze(1)
imagen_cond = repeat(imagen_cond,
'b t c h w -> b (repeat t) c h w',
repeat=self.horizon)
imagen_cond = rearrange(imagen_cond, 'b t c h w -> (b t) c h w')
if self.use_imagen_mid_only:
imagen_cond = self.spatial_softmax_blocks[len(
self.spatial_softmax_blocks) // 2](imagen_cond)
elif self.use_z_only:
imagen_cond = self.spatial_softmax_blocks[0](imagen_cond)
else:
imagen_cond = self.spatial_softmax_blocks[idx](imagen_cond)
imagen_cond = rearrange(imagen_cond, '(b t) c d -> b t c d', b=B)
if self.use_linear_act_proj:
imagen_cond = imagen_cond.reshape(B, T, -1)
cur_global_feature = global_feature.unsqueeze(
1).repeat_interleave(repeats=T, dim=1)
else:
imagen_cond = imagen_cond.permute(0, 3, 1, 2)
imagen_cond = imagen_cond.reshape(B, 2, -1)
cur_global_feature = global_feature.unsqueeze(
1).repeat_interleave(repeats=2, dim=1)
cur_global_feature = torch.cat(
[cur_global_feature, global_cond, imagen_cond], axis=-1)
x = resnet(x, cur_global_feature)
x = resnet2(x, cur_global_feature)
h.append(x)
x = downsample(x)
#>>> mide blocks
resnet, resnet2, _ = self.mid_modules
# Access the cond from the unet embeds from video unet
if self.use_z_only:
imagen_cond = kwargs['x_start'].permute(0, 2, 1, 3, 4)
else:
imagen_cond = imagen_cond_mid
if self.last_frame_only:
imagen_cond = imagen_cond[:, -1].unsqueeze(1)
imagen_cond = repeat(imagen_cond,
'b t c h w -> b (repeat t) c h w',
repeat=self.horizon)
imagen_cond = rearrange(imagen_cond, 'b t c h w -> (b t) c h w')
idx += 1
if self.use_z_only:
imagen_cond = self.spatial_softmax_blocks[0](imagen_cond)
else:
imagen_cond = self.spatial_softmax_blocks[idx](imagen_cond)
imagen_cond = rearrange(imagen_cond, '(b t) c d -> b t c d', b=B)
if self.use_linear_act_proj:
imagen_cond = imagen_cond.reshape(B, T, -1)
cur_global_feature = global_feature.unsqueeze(1).repeat_interleave(
repeats=T, dim=1)
else:
imagen_cond = imagen_cond.permute(0, 3, 1, 2)
imagen_cond = imagen_cond.reshape(B, 2, -1)
cur_global_feature = global_feature.unsqueeze(1).repeat_interleave(
repeats=2, dim=1)
cur_global_feature = torch.cat(
[cur_global_feature, global_cond, imagen_cond], axis=-1)
x = resnet(x, cur_global_feature)
x = resnet2(x, cur_global_feature)
#>>> up blocks
idx += 1
for jdx, modules in enumerate(self.up_modules):
if self.cond_cross_attention:
(resnet, resnet2, crossatten, upsample) = modules
else:
(resnet, resnet2, _, upsample) = modules
# Access the cond from the unet embeds from video unet
if self.use_imagen_mid_only:
imagen_cond = imagen_cond_mid
elif self.use_z_only:
imagen_cond = kwargs['x_start'].permute(0, 2, 1, 3, 4)
else:
imagen_cond = imagen_cond_up[jdx]
if self.last_frame_only:
imagen_cond = imagen_cond[:, -1].unsqueeze(1)
imagen_cond = repeat(imagen_cond,
'b t c h w -> b (repeat t) c h w',
repeat=self.horizon)
imagen_cond = rearrange(imagen_cond, 'b t c h w -> (b t) c h w')
if self.use_imagen_mid_only:
imagen_cond = self.spatial_softmax_blocks[len(
self.spatial_softmax_blocks) // 2](imagen_cond)
elif self.use_z_only:
imagen_cond = self.spatial_softmax_blocks[0](imagen_cond)
else:
imagen_cond = self.spatial_softmax_blocks[jdx +
idx](imagen_cond)
imagen_cond = rearrange(imagen_cond, '(b t) c d -> b t c d', b=B)
if self.use_linear_act_proj:
imagen_cond = imagen_cond.reshape(B, T, -1)
cur_global_feature = global_feature.unsqueeze(
1).repeat_interleave(repeats=T, dim=1)
else:
imagen_cond = imagen_cond.permute(0, 3, 1, 2)
imagen_cond = imagen_cond.reshape(B, 2, -1)
cur_global_feature = global_feature.unsqueeze(
1).repeat_interleave(repeats=2, dim=1)
cur_global_feature = torch.cat(
[cur_global_feature, global_cond, imagen_cond], axis=-1)
x = torch.cat((x, h.pop()), dim=1)
x = resnet(x, cur_global_feature)
x = resnet2(x, cur_global_feature)
x = upsample(x)
x = self.final_conv(x)
if self.use_linear_act_proj:
x = x.reshape(B, T, D, -1)
x = self.proj_out_action(x)
x = x.reshape(B, T, D)
else:
x = self.proj_out_horizon(x)
x = einops.rearrange(x, 'b t h -> b h t')
return x

View File

@@ -0,0 +1,52 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
class Downsample1d(nn.Module):
def __init__(self, dim):
super().__init__()
self.conv = nn.Conv1d(dim, dim, 3, 2, 1)
def forward(self, x):
return self.conv(x)
class Upsample1d(nn.Module):
def __init__(self, dim):
super().__init__()
self.conv = nn.ConvTranspose1d(dim, dim, 4, 2, 1)
def forward(self, x):
return self.conv(x)
class Conv1dBlock(nn.Module):
'''
Conv1d --> GroupNorm --> Mish
'''
def __init__(self, inp_channels, out_channels, kernel_size, n_groups=8):
super().__init__()
self.block = nn.Sequential(
nn.Conv1d(inp_channels,
out_channels,
kernel_size,
padding=kernel_size // 2),
# Rearrange('batch channels horizon -> batch channels 1 horizon'),
nn.GroupNorm(n_groups, out_channels),
# Rearrange('batch channels 1 horizon -> batch channels horizon'),
nn.Mish(),
)
def forward(self, x):
return self.block(x)
def test():
cb = Conv1dBlock(256, 128, kernel_size=3)
x = torch.zeros((1, 256, 16))
o = cb(x)

View File

@@ -0,0 +1,80 @@
import copy
import torch
from torch.nn.modules.batchnorm import _BatchNorm
class EMAModel:
"""
Exponential Moving Average of models weights
"""
def __init__(self,
model,
update_after_step=0,
inv_gamma=1.0,
power=2 / 3,
min_value=0.0,
max_value=0.9999):
"""
@crowsonkb's notes on EMA Warmup:
If gamma=1 and power=1, implements a simple average. gamma=1, power=2/3 are good values for models you plan
to train for a million or more steps (reaches decay factor 0.999 at 31.6K steps, 0.9999 at 1M steps),
gamma=1, power=3/4 for models you plan to train for less (reaches decay factor 0.999 at 10K steps, 0.9999
at 215.4k steps).
Args:
inv_gamma (float): Inverse multiplicative factor of EMA warmup. Default: 1.
power (float): Exponential factor of EMA warmup. Default: 2/3.
min_value (float): The minimum EMA decay rate. Default: 0.
"""
self.averaged_model = model
self.averaged_model.eval()
self.averaged_model.requires_grad_(False)
self.update_after_step = update_after_step
self.inv_gamma = inv_gamma
self.power = power
self.min_value = min_value
self.max_value = max_value
self.decay = 0.0
self.optimization_step = 0
def get_decay(self, optimization_step):
"""
Compute the decay factor for the exponential moving average.
"""
step = max(0, optimization_step - self.update_after_step - 1)
value = 1 - (1 + step / self.inv_gamma)**-self.power
if step <= 0:
return 0.0
return max(self.min_value, min(value, self.max_value))
@torch.no_grad()
def step(self, new_model):
self.decay = self.get_decay(self.optimization_step)
all_dataptrs = set()
for module, ema_module in zip(new_model.modules(),
self.averaged_model.modules()):
for param, ema_param in zip(module.parameters(recurse=False),
ema_module.parameters(recurse=False)):
# iterative over immediate parameters only.
if isinstance(param, dict):
raise RuntimeError('Dict parameter not supported')
if isinstance(module, _BatchNorm):
# skip batchnorms
ema_param.copy_(param.to(dtype=ema_param.dtype).data)
elif not param.requires_grad:
ema_param.copy_(param.to(dtype=ema_param.dtype).data)
else:
ema_param.mul_(self.decay)
ema_param.add_(param.data.to(dtype=ema_param.dtype),
alpha=1 - self.decay)
# verify that iterating over module and then parameters is identical to parameters recursively.
# assert old_all_dataptrs == all_dataptrs
self.optimization_step += 1

View File

@@ -0,0 +1,19 @@
import math
import torch
import torch.nn as nn
class SinusoidalPosEmb(nn.Module):
def __init__(self, dim):
super().__init__()
self.dim = dim
def forward(self, x):
device = x.device
half_dim = self.dim // 2
emb = math.log(10000) / (half_dim - 1)
emb = torch.exp(torch.arange(half_dim, device=device) * -emb)
emb = x[:, None] * emb[None, :]
emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
return emb

View File

@@ -0,0 +1,322 @@
import torch
import torch.nn as nn
import torchvision.transforms.functional as ttf
import unifolm_wma.models.diffusion_head.common.tensor_util as tu
class CropRandomizer(nn.Module):
"""
Randomly sample crops at input, and then average across crop features at output.
"""
def __init__(
self,
input_shape,
crop_height,
crop_width,
num_crops=1,
pos_enc=False,
):
"""
Args:
input_shape (tuple, list): shape of input (not including batch dimension)
crop_height (int): crop height
crop_width (int): crop width
num_crops (int): number of random crops to take
pos_enc (bool): if True, add 2 channels to the output to encode the spatial
location of the cropped pixels in the source image
"""
super().__init__()
assert len(input_shape) == 3 # (C, H, W)
assert crop_height < input_shape[1]
assert crop_width < input_shape[2]
self.input_shape = input_shape
self.crop_height = crop_height
self.crop_width = crop_width
self.num_crops = num_crops
self.pos_enc = pos_enc
def output_shape_in(self, input_shape=None):
"""
Function to compute output shape from inputs to this module. Corresponds to
the @forward_in operation, where raw inputs (usually observation modalities)
are passed in.
Args:
input_shape (iterable of int): shape of input. Does not include batch dimension.
Some modules may not need this argument, if their output does not depend
on the size of the input, or if they assume fixed size input.
Returns:
out_shape ([int]): list of integers corresponding to output shape
"""
# outputs are shape (C, CH, CW), or maybe C + 2 if using position encoding, because
# the number of crops are reshaped into the batch dimension, increasing the batch
# size from B to B * N
out_c = self.input_shape[0] + 2 if self.pos_enc else self.input_shape[0]
return [out_c, self.crop_height, self.crop_width]
def output_shape_out(self, input_shape=None):
"""
Function to compute output shape from inputs to this module. Corresponds to
the @forward_out operation, where processed inputs (usually encoded observation
modalities) are passed in.
Args:
input_shape (iterable of int): shape of input. Does not include batch dimension.
Some modules may not need this argument, if their output does not depend
on the size of the input, or if they assume fixed size input.
Returns:
out_shape ([int]): list of integers corresponding to output shape
"""
# since the forward_out operation splits [B * N, ...] -> [B, N, ...]
# and then pools to result in [B, ...], only the batch dimension changes,
# and so the other dimensions retain their shape.
return list(input_shape)
def forward_in(self, inputs):
"""
Samples N random crops for each input in the batch, and then reshapes
inputs to [B * N, ...].
"""
assert len(
inputs.shape) >= 3 # must have at least (C, H, W) dimensions
if self.training:
# generate random crops
out, _ = sample_random_image_crops(
images=inputs,
crop_height=self.crop_height,
crop_width=self.crop_width,
num_crops=self.num_crops,
pos_enc=self.pos_enc,
)
# [B, N, ...] -> [B * N, ...]
return tu.join_dimensions(out, 0, 1)
else:
# take center crop during eval
out = ttf.center_crop(img=inputs,
output_size=(self.crop_height,
self.crop_width))
if self.num_crops > 1:
B, C, H, W = out.shape
out = out.unsqueeze(1).expand(B, self.num_crops, C, H,
W).reshape(-1, C, H, W)
# [B * N, ...]
return out
def forward_out(self, inputs):
"""
Splits the outputs from shape [B * N, ...] -> [B, N, ...] and then average across N
to result in shape [B, ...] to make sure the network output is consistent with
what would have happened if there were no randomization.
"""
if self.num_crops <= 1:
return inputs
else:
batch_size = (inputs.shape[0] // self.num_crops)
out = tu.reshape_dimensions(inputs,
begin_axis=0,
end_axis=0,
target_dims=(batch_size,
self.num_crops))
return out.mean(dim=1)
def forward(self, inputs):
return self.forward_in(inputs)
def __repr__(self):
"""Pretty print network."""
header = '{}'.format(str(self.__class__.__name__))
msg = header + "(input_shape={}, crop_size=[{}, {}], num_crops={})".format(
self.input_shape, self.crop_height, self.crop_width,
self.num_crops)
return msg
def crop_image_from_indices(images, crop_indices, crop_height, crop_width):
"""
Crops images at the locations specified by @crop_indices. Crops will be
taken across all channels.
Args:
images (torch.Tensor): batch of images of shape [..., C, H, W]
crop_indices (torch.Tensor): batch of indices of shape [..., N, 2] where
N is the number of crops to take per image and each entry corresponds
to the pixel height and width of where to take the crop. Note that
the indices can also be of shape [..., 2] if only 1 crop should
be taken per image. Leading dimensions must be consistent with
@images argument. Each index specifies the top left of the crop.
Values must be in range [0, H - CH - 1] x [0, W - CW - 1] where
H and W are the height and width of @images and CH and CW are
@crop_height and @crop_width.
crop_height (int): height of crop to take
crop_width (int): width of crop to take
Returns:
crops (torch.Tesnor): cropped images of shape [..., C, @crop_height, @crop_width]
"""
# make sure length of input shapes is consistent
assert crop_indices.shape[-1] == 2
ndim_im_shape = len(images.shape)
ndim_indices_shape = len(crop_indices.shape)
assert (ndim_im_shape == ndim_indices_shape +
1) or (ndim_im_shape == ndim_indices_shape + 2)
# maybe pad so that @crop_indices is shape [..., N, 2]
is_padded = False
if ndim_im_shape == ndim_indices_shape + 2:
crop_indices = crop_indices.unsqueeze(-2)
is_padded = True
# make sure leading dimensions between images and indices are consistent
assert images.shape[:-3] == crop_indices.shape[:-2]
device = images.device
image_c, image_h, image_w = images.shape[-3:]
num_crops = crop_indices.shape[-2]
# make sure @crop_indices are in valid range
assert (crop_indices[..., 0] >= 0).all().item()
assert (crop_indices[..., 0] < (image_h - crop_height)).all().item()
assert (crop_indices[..., 1] >= 0).all().item()
assert (crop_indices[..., 1] < (image_w - crop_width)).all().item()
# convert each crop index (ch, cw) into a list of pixel indices that correspond to the entire window.
# 2D index array with columns [0, 1, ..., CH - 1] and shape [CH, CW]
crop_ind_grid_h = torch.arange(crop_height).to(device)
crop_ind_grid_h = tu.unsqueeze_expand_at(crop_ind_grid_h,
size=crop_width,
dim=-1)
# 2D index array with rows [0, 1, ..., CW - 1] and shape [CH, CW]
crop_ind_grid_w = torch.arange(crop_width).to(device)
crop_ind_grid_w = tu.unsqueeze_expand_at(crop_ind_grid_w,
size=crop_height,
dim=0)
# combine into shape [CH, CW, 2]
crop_in_grid = torch.cat(
(crop_ind_grid_h.unsqueeze(-1), crop_ind_grid_w.unsqueeze(-1)), dim=-1)
# Add above grid with the offset index of each sampled crop to get 2d indices for each crop.
# After broadcasting, this will be shape [..., N, CH, CW, 2] and each crop has a [CH, CW, 2]
# shape array that tells us which pixels from the corresponding source image to grab.
grid_reshape = [1] * len(crop_indices.shape[:-1]) + [
crop_height, crop_width, 2
]
all_crop_inds = crop_indices.unsqueeze(-2).unsqueeze(
-2) + crop_in_grid.reshape(grid_reshape)
# For using @torch.gather, convert to flat indices from 2D indices, and also
# repeat across the channel dimension. To get flat index of each pixel to grab for
# each sampled crop, we just use the mapping: ind = h_ind * @image_w + w_ind
all_crop_inds = all_crop_inds[..., 0] * image_w + all_crop_inds[
..., 1] # shape [..., N, CH, CW]
all_crop_inds = tu.unsqueeze_expand_at(all_crop_inds, size=image_c,
dim=-3) # shape [..., N, C, CH, CW]
all_crop_inds = tu.flatten(all_crop_inds,
begin_axis=-2) # shape [..., N, C, CH * CW]
# Repeat and flatten the source images -> [..., N, C, H * W] and then use gather to index with crop pixel inds
images_to_crop = tu.unsqueeze_expand_at(images, size=num_crops, dim=-4)
images_to_crop = tu.flatten(images_to_crop, begin_axis=-2)
crops = torch.gather(images_to_crop, dim=-1, index=all_crop_inds)
# [..., N, C, CH * CW] -> [..., N, C, CH, CW]
reshape_axis = len(crops.shape) - 1
crops = tu.reshape_dimensions(crops,
begin_axis=reshape_axis,
end_axis=reshape_axis,
target_dims=(crop_height, crop_width))
if is_padded:
# undo padding -> [..., C, CH, CW]
crops = crops.squeeze(-4)
return crops
def sample_random_image_crops(images,
crop_height,
crop_width,
num_crops,
pos_enc=False):
"""
For each image, randomly sample @num_crops crops of size (@crop_height, @crop_width), from
@images.
Args:
images (torch.Tensor): batch of images of shape [..., C, H, W]
crop_height (int): height of crop to take
crop_width (int): width of crop to take
num_crops (n): number of crops to sample
pos_enc (bool): if True, also add 2 channels to the outputs that gives a spatial
encoding of the original source pixel locations. This means that the
output crops will contain information about where in the source image
it was sampled from.
Returns:
crops (torch.Tensor): crops of shape (..., @num_crops, C, @crop_height, @crop_width)
if @pos_enc is False, otherwise (..., @num_crops, C + 2, @crop_height, @crop_width)
crop_inds (torch.Tensor): sampled crop indices of shape (..., N, 2)
"""
device = images.device
# maybe add 2 channels of spatial encoding to the source image
source_im = images
if pos_enc:
# spatial encoding [y, x] in [0, 1]
h, w = source_im.shape[-2:]
pos_y, pos_x = torch.meshgrid(torch.arange(h), torch.arange(w))
pos_y = pos_y.float().to(device) / float(h)
pos_x = pos_x.float().to(device) / float(w)
position_enc = torch.stack((pos_y, pos_x)) # shape [C, H, W]
# unsqueeze and expand to match leading dimensions -> shape [..., C, H, W]
leading_shape = source_im.shape[:-3]
position_enc = position_enc[(None, ) * len(leading_shape)]
position_enc = position_enc.expand(*leading_shape, -1, -1, -1)
# concat across channel dimension with input
source_im = torch.cat((source_im, position_enc), dim=-3)
# make sure sample boundaries ensure crops are fully within the images
image_c, image_h, image_w = source_im.shape[-3:]
max_sample_h = image_h - crop_height
max_sample_w = image_w - crop_width
# Sample crop locations for all tensor dimensions up to the last 3, which are [C, H, W].
# Each gets @num_crops samples - typically this will just be the batch dimension (B), so
# we will sample [B, N] indices, but this supports having more than one leading dimension,
# or possibly no leading dimension.
#
# Trick: sample in [0, 1) with rand, then re-scale to [0, M) and convert to long to get sampled ints
crop_inds_h = (
max_sample_h *
torch.rand(*source_im.shape[:-3], num_crops).to(device)).long()
crop_inds_w = (
max_sample_w *
torch.rand(*source_im.shape[:-3], num_crops).to(device)).long()
crop_inds = torch.cat(
(crop_inds_h.unsqueeze(-1), crop_inds_w.unsqueeze(-1)),
dim=-1) # shape [..., N, 2]
crops = crop_image_from_indices(
images=source_im,
crop_indices=crop_inds,
crop_height=crop_height,
crop_width=crop_width,
)
return crops, crop_inds

View File

@@ -0,0 +1,30 @@
import torch
import torchvision
def get_resnet(name, weights=None, **kwargs):
"""
name: resnet18, resnet34, resnet50
weights: "IMAGENET1K_V1", "r3m"
"""
# load r3m weights
if (weights == "r3m") or (weights == "R3M"):
return get_r3m(name=name, **kwargs)
func = getattr(torchvision.models, name)
resnet = func(weights=weights, **kwargs)
resnet.fc = torch.nn.Identity()
return resnet
def get_r3m(name, **kwargs):
"""
name: resnet18, resnet34, resnet50
"""
import r3m
r3m.device = 'cpu'
model = r3m.load_r3m(name)
r3m_model = model.module
resnet_model = r3m_model.convnet
resnet_model = resnet_model.to('cpu')
return resnet_model

View File

@@ -0,0 +1,247 @@
import copy
import torch
import torch.nn as nn
import torchvision
import json
import os
from unifolm_wma.models.diffusion_head.vision.crop_randomizer import CropRandomizer
from unifolm_wma.models.diffusion_head.base_nets import SpatialSoftmax
from unifolm_wma.models.diffusion_head.common.module_attr_mixin import ModuleAttrMixin
from unifolm_wma.models.diffusion_head.common.pytorch_util import dict_apply, replace_submodules
from unifolm_wma.utils.utils import instantiate_from_config
from einops import rearrange, repeat
from typing import Dict, Tuple, Union
from pathlib import Path
class MultiImageObsEncoder(ModuleAttrMixin):
def __init__(
self,
rgb_model_config: Dict,
shape_meta_path: str | None = None,
resize_shape: Union[Tuple[int, int], Dict[str, tuple],
None] = None,
crop_shape: Union[Tuple[int, int], Dict[str, tuple], None] = None,
random_crop: bool = True,
# replace BatchNorm with GroupNorm
use_group_norm: bool = False,
# use single rgb model for all rgb inputs
share_rgb_model: bool = False,
# renormalize rgb input with imagenet normalization
# assuming input in [0,1]
imagenet_norm: bool = False,
use_spatial_softmax=False,
spatial_softmax_kp=32,
use_dinoSiglip=False):
"""
Assumes rgb input: B,C,H,W
Assumes low_dim input: B,D
"""
super().__init__()
if not shape_meta_path:
shape_meta_path = str(Path(os.getcwd()) / "configs/train/meta.json")
with open(shape_meta_path, 'r') as file:
shape_meta = json.load(file)
rgb_model = instantiate_from_config(rgb_model_config)
rgb_keys = list()
low_dim_keys = list()
key_model_map = nn.ModuleDict()
key_transform_map = nn.ModuleDict()
key_shape_map = dict()
# handle sharing vision backbone
if share_rgb_model:
assert isinstance(rgb_model, nn.Module)
key_model_map['rgb'] = rgb_model
obs_shape_meta = shape_meta['obs']
for key, attr in obs_shape_meta.items():
shape = tuple(attr['shape'])
type = attr.get('type', 'low_dim')
key_shape_map[key] = shape
if type == 'rgb':
rgb_keys.append(key)
if not use_dinoSiglip:
# configure model for this key
this_model = None
if not share_rgb_model:
if isinstance(rgb_model, dict):
# have provided model for each key
this_model = rgb_model[key]
else:
assert isinstance(rgb_model, nn.Module)
# have a copy of the rgb model
this_model = copy.deepcopy(rgb_model)
if this_model is not None:
if use_group_norm:
this_model = replace_submodules(
root_module=this_model,
predicate=lambda x: isinstance(
x, nn.BatchNorm2d),
func=lambda x: nn.GroupNorm(
num_groups=x.num_features // 16,
num_channels=x.num_features))
key_model_map[key] = this_model
# configure resize
input_shape = shape
this_resizer = nn.Identity()
if resize_shape is not None:
if isinstance(resize_shape, dict):
h, w = resize_shape[key]
else:
h, w = resize_shape
this_resizer = torchvision.transforms.Resize(size=(h,
w))
input_shape = (shape[0], h, w)
# configure randomizer
this_randomizer = nn.Identity()
if crop_shape is not None:
if isinstance(crop_shape, dict):
h, w = crop_shape[key]
else:
h, w = crop_shape
if random_crop:
this_randomizer = CropRandomizer(
input_shape=input_shape,
crop_height=h,
crop_width=w,
num_crops=1,
pos_enc=False)
else:
this_normalizer = torchvision.transforms.CenterCrop(
size=(h, w))
# configure normalizer
this_normalizer = nn.Identity()
if imagenet_norm:
this_normalizer = torchvision.transforms.Normalize(
mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225])
this_transform = nn.Sequential(this_resizer,
this_randomizer,
this_normalizer)
key_transform_map[key] = this_transform
else:
key_model_map[key] = rgb_model
elif type == 'low_dim':
low_dim_keys.append(key)
else:
raise RuntimeError(f"Unsupported obs type: {type}")
rgb_keys = sorted(rgb_keys)
low_dim_keys = sorted(low_dim_keys)
self.shape_meta = shape_meta
self.key_model_map = key_model_map
self.key_transform_map = key_transform_map
self.share_rgb_model = share_rgb_model
self.rgb_keys = rgb_keys
self.low_dim_keys = low_dim_keys
self.key_shape_map = key_shape_map
self.use_dinoSiglip = use_dinoSiglip
##NOTE add spatial softmax
self.use_spatial_softmax = use_spatial_softmax
if use_spatial_softmax and not use_dinoSiglip:
model = nn.Sequential(
key_model_map['image'].conv1,
key_model_map['image'].bn1,
key_model_map['image'].relu,
key_model_map['image'].maxpool,
key_model_map['image'].layer1,
key_model_map['image'].layer2,
key_model_map['image'].layer3,
key_model_map['image'].layer4,
)
key_model_map['image'] = model
input_shape = self.output_shape(resnet_output_shape=True)
self.spatial_softmax = SpatialSoftmax(input_shape,
num_kp=spatial_softmax_kp)
def forward(self, obs_dict, resnet_output_shape=False):
batch_size = None
features = list()
# process rgb input
if self.share_rgb_model:
# pass all rgb obs to rgb model
imgs = list()
for key in self.rgb_keys:
img = obs_dict[key]
if batch_size is None:
batch_size = img.shape[0]
else:
assert batch_size == img.shape[0]
assert img.shape[1:] == self.key_shape_map[key]
img = self.key_transform_map[key](img)
imgs.append(img)
# (N*B,C,H,W)
imgs = torch.cat(imgs, dim=0)
# (N*B,D)
feature = self.key_model_map['rgb'](imgs)
# (N,B,D)
feature = feature.reshape(-1, batch_size, *feature.shape[1:])
# (B,N,D)
feature = torch.moveaxis(feature, 0, 1)
# (B,N*D)
feature = feature.reshape(batch_size, -1)
features.append(feature)
else:
# run each rgb obs to independent models
for key in self.rgb_keys:
img = obs_dict[key]
if batch_size is None:
batch_size = img.shape[0]
else:
assert batch_size == img.shape[0]
assert img.shape[1:] == self.key_shape_map[key]
if not self.use_dinoSiglip:
img = self.key_transform_map[key](img)
feature = self.key_model_map[key](img)
else:
feature = self.key_model_map[key](img)[:, :1, :]
if resnet_output_shape:
return feature
if not self.use_dinoSiglip and self.use_spatial_softmax:
feature = self.spatial_softmax(feature)
feature = feature.reshape(batch_size, -1)
features.append(feature)
# process lowdim input
for key in self.low_dim_keys:
data = obs_dict[key]
if batch_size is None:
batch_size = data.shape[0]
else:
assert batch_size == data.shape[0]
assert data.shape[1:] == self.key_shape_map[key]
features.append(data)
# concatenate all features
result = torch.cat(features, dim=-1)
return result
@torch.no_grad()
def output_shape(self, resnet_output_shape=False):
example_obs_dict = dict()
obs_shape_meta = self.shape_meta['obs']
batch_size = 1
for key, attr in obs_shape_meta.items():
shape = tuple(attr['shape'])
this_obs = torch.zeros((batch_size, ) + shape,
dtype=self.dtype,
device=self.device)
example_obs_dict[key] = this_obs
example_output = self.forward(example_obs_dict,
resnet_output_shape=resnet_output_shape)
output_shape = example_output.shape[1:]
return output_shape

View File

@@ -0,0 +1,473 @@
import numpy as np
import torch
import copy
from unifolm_wma.utils.diffusion import make_ddim_sampling_parameters, make_ddim_timesteps, rescale_noise_cfg
from unifolm_wma.utils.common import noise_like
from unifolm_wma.utils.common import extract_into_tensor
from tqdm import tqdm
class DDIMSampler(object):
def __init__(self, model, schedule="linear", **kwargs):
super().__init__()
self.model = model
self.ddpm_num_timesteps = model.num_timesteps
self.schedule = schedule
self.counter = 0
def register_buffer(self, name, attr):
if type(attr) == torch.Tensor:
if attr.device != torch.device("cuda"):
attr = attr.to(torch.device("cuda"))
setattr(self, name, attr)
def make_schedule(self,
ddim_num_steps,
ddim_discretize="uniform",
ddim_eta=0.,
verbose=True):
self.ddim_timesteps = make_ddim_timesteps(
ddim_discr_method=ddim_discretize,
num_ddim_timesteps=ddim_num_steps,
num_ddpm_timesteps=self.ddpm_num_timesteps,
verbose=verbose)
alphas_cumprod = self.model.alphas_cumprod
assert alphas_cumprod.shape[
0] == self.ddpm_num_timesteps, 'alphas have to be defined for each timestep'
to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.model
.device)
if self.model.use_dynamic_rescale:
self.ddim_scale_arr = self.model.scale_arr[self.ddim_timesteps]
self.ddim_scale_arr_prev = torch.cat(
[self.ddim_scale_arr[0:1], self.ddim_scale_arr[:-1]])
self.register_buffer('betas', to_torch(self.model.betas))
self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod))
self.register_buffer('alphas_cumprod_prev',
to_torch(self.model.alphas_cumprod_prev))
# Calculations for diffusion q(x_t | x_{t-1}) and others
self.register_buffer('sqrt_alphas_cumprod',
to_torch(np.sqrt(alphas_cumprod.cpu())))
self.register_buffer('sqrt_one_minus_alphas_cumprod',
to_torch(np.sqrt(1. - alphas_cumprod.cpu())))
self.register_buffer('log_one_minus_alphas_cumprod',
to_torch(np.log(1. - alphas_cumprod.cpu())))
self.register_buffer('sqrt_recip_alphas_cumprod',
to_torch(np.sqrt(1. / alphas_cumprod.cpu())))
self.register_buffer('sqrt_recipm1_alphas_cumprod',
to_torch(np.sqrt(1. / alphas_cumprod.cpu() - 1)))
# DDIM sampling parameters
ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(
alphacums=alphas_cumprod.cpu(),
ddim_timesteps=self.ddim_timesteps,
eta=ddim_eta,
verbose=verbose)
self.register_buffer('ddim_sigmas', ddim_sigmas)
self.register_buffer('ddim_alphas', ddim_alphas)
self.register_buffer('ddim_alphas_prev', ddim_alphas_prev)
self.register_buffer('ddim_sqrt_one_minus_alphas',
np.sqrt(1. - ddim_alphas))
sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt(
(1 - self.alphas_cumprod_prev) / (1 - self.alphas_cumprod) *
(1 - self.alphas_cumprod / self.alphas_cumprod_prev))
self.register_buffer('ddim_sigmas_for_original_num_steps',
sigmas_for_original_sampling_steps)
@torch.no_grad()
def sample(
self,
S,
batch_size,
shape,
conditioning=None,
callback=None,
normals_sequence=None,
img_callback=None,
quantize_x0=False,
eta=0.,
mask=None,
x0=None,
temperature=1.,
noise_dropout=0.,
score_corrector=None,
corrector_kwargs=None,
verbose=True,
schedule_verbose=False,
x_T=None,
log_every_t=100,
unconditional_guidance_scale=1.,
unconditional_conditioning=None,
precision=None,
fs=None,
timestep_spacing='uniform', #uniform_trailing for starting from last timestep
guidance_rescale=0.0,
**kwargs):
# Check condition bs
if conditioning is not None:
if isinstance(conditioning, dict):
try:
cbs = conditioning[list(conditioning.keys())[0]].shape[0]
except:
cbs = conditioning[list(
conditioning.keys())[0]][0].shape[0]
if cbs != batch_size:
print(
f"Warning: Got {cbs} conditionings but batch-size is {batch_size}"
)
else:
if conditioning.shape[0] != batch_size:
print(
f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}"
)
self.make_schedule(ddim_num_steps=S,
ddim_discretize=timestep_spacing,
ddim_eta=eta,
verbose=schedule_verbose)
# Make shape
if len(shape) == 3:
C, H, W = shape
size = (batch_size, C, H, W)
elif len(shape) == 4:
C, T, H, W = shape
size = (batch_size, C, T, H, W)
samples, actions, states, intermediates = self.ddim_sampling(
conditioning,
size,
callback=callback,
img_callback=img_callback,
quantize_denoised=quantize_x0,
mask=mask,
x0=x0,
ddim_use_original_steps=False,
noise_dropout=noise_dropout,
temperature=temperature,
score_corrector=score_corrector,
corrector_kwargs=corrector_kwargs,
x_T=x_T,
log_every_t=log_every_t,
unconditional_guidance_scale=unconditional_guidance_scale,
unconditional_conditioning=unconditional_conditioning,
verbose=verbose,
precision=precision,
fs=fs,
guidance_rescale=guidance_rescale,
**kwargs)
return samples, actions, states, intermediates
@torch.no_grad()
def ddim_sampling(self,
cond,
shape,
x_T=None,
ddim_use_original_steps=False,
callback=None,
timesteps=None,
quantize_denoised=False,
mask=None,
x0=None,
img_callback=None,
log_every_t=100,
temperature=1.,
noise_dropout=0.,
score_corrector=None,
corrector_kwargs=None,
unconditional_guidance_scale=1.,
unconditional_conditioning=None,
verbose=True,
precision=None,
fs=None,
guidance_rescale=0.0,
**kwargs):
device = self.model.betas.device
dp_ddim_scheduler_action = self.model.dp_noise_scheduler_action
dp_ddim_scheduler_state = self.model.dp_noise_scheduler_state
b = shape[0]
if x_T is None:
img = torch.randn(shape, device=device)
action = torch.randn((b, 16, self.model.agent_action_dim),
device=device)
state = torch.randn((b, 16, self.model.agent_state_dim),
device=device)
else:
img = x_T
action = torch.randn((b, 16, self.model.agent_action_dim),
device=device)
state = torch.randn((b, 16, self.model.agent_state_dim),
device=device)
if precision is not None:
if precision == 16:
img = img.to(dtype=torch.float16)
action = action.to(dtype=torch.float16)
state = state.to(dtype=torch.float16)
if timesteps is None:
timesteps = self.ddpm_num_timesteps if ddim_use_original_steps else self.ddim_timesteps
elif timesteps is not None and not ddim_use_original_steps:
subset_end = int(
min(timesteps / self.ddim_timesteps.shape[0], 1) *
self.ddim_timesteps.shape[0]) - 1
timesteps = self.ddim_timesteps[:subset_end]
intermediates = {
'x_inter': [img],
'pred_x0': [img],
'x_inter_action': [action],
'pred_x0_action': [action],
'x_inter_state': [state],
'pred_x0_state': [state],
}
time_range = reversed(range(
0, timesteps)) if ddim_use_original_steps else np.flip(timesteps)
total_steps = timesteps if ddim_use_original_steps else timesteps.shape[
0]
if verbose:
iterator = tqdm(time_range, desc='DDIM Sampler', total=total_steps)
else:
iterator = time_range
clean_cond = kwargs.pop("clean_cond", False)
dp_ddim_scheduler_action.set_timesteps(len(timesteps))
dp_ddim_scheduler_state.set_timesteps(len(timesteps))
for i, step in enumerate(iterator):
index = total_steps - i - 1
ts = torch.full((b, ), step, device=device, dtype=torch.long)
# Use mask to blend noised original latent (img_orig) & new sampled latent (img)
if mask is not None:
assert x0 is not None
if clean_cond:
img_orig = x0
else:
img_orig = self.model.q_sample(x0, ts)
img = img_orig * mask + (1. - mask) * img
outs = self.p_sample_ddim(
img,
action,
state,
cond,
ts,
index=index,
use_original_steps=ddim_use_original_steps,
quantize_denoised=quantize_denoised,
temperature=temperature,
noise_dropout=noise_dropout,
score_corrector=score_corrector,
corrector_kwargs=corrector_kwargs,
unconditional_guidance_scale=unconditional_guidance_scale,
unconditional_conditioning=unconditional_conditioning,
mask=mask,
x0=x0,
fs=fs,
guidance_rescale=guidance_rescale,
**kwargs)
img, pred_x0, model_output_action, model_output_state = outs
action = dp_ddim_scheduler_action.step(
model_output_action,
step,
action,
generator=None,
).prev_sample
state = dp_ddim_scheduler_state.step(
model_output_state,
step,
state,
generator=None,
).prev_sample
if callback: callback(i)
if img_callback: img_callback(pred_x0, i)
if index % log_every_t == 0 or index == total_steps - 1:
intermediates['x_inter'].append(img)
intermediates['pred_x0'].append(pred_x0)
intermediates['x_inter_action'].append(action)
intermediates['x_inter_state'].append(state)
return img, action, state, intermediates
@torch.no_grad()
def p_sample_ddim(self,
x,
x_action,
x_state,
c,
t,
index,
repeat_noise=False,
use_original_steps=False,
quantize_denoised=False,
temperature=1.,
noise_dropout=0.,
score_corrector=None,
corrector_kwargs=None,
unconditional_guidance_scale=1.,
unconditional_conditioning=None,
uc_type=None,
conditional_guidance_scale_temporal=None,
mask=None,
x0=None,
guidance_rescale=0.0,
**kwargs):
b, *_, device = *x.shape, x.device
if x.dim() == 5:
is_video = True
else:
is_video = False
if unconditional_conditioning is None or unconditional_guidance_scale == 1.:
model_output, model_output_action, model_output_state = self.model.apply_model(
x, x_action, x_state, t, c, **kwargs) # unet denoiser
else:
# do_classifier_free_guidance
if isinstance(c, torch.Tensor) or isinstance(c, dict):
e_t_cond, e_t_cond_action, e_t_cond_state = self.model.apply_model(
x, x_action, x_state, t, c, **kwargs)
e_t_uncond, e_t_uncond_action, e_t_uncond_state = self.model.apply_model(
x, x_action, x_state, t, unconditional_conditioning,
**kwargs)
else:
raise NotImplementedError
model_output = e_t_uncond + unconditional_guidance_scale * (
e_t_cond - e_t_uncond)
model_output_action = e_t_uncond_action + unconditional_guidance_scale * (
e_t_cond_action - e_t_uncond_action)
model_output_state = e_t_uncond_state + unconditional_guidance_scale * (
e_t_cond_state - e_t_uncond_state)
if guidance_rescale > 0.0:
model_output = rescale_noise_cfg(
model_output, e_t_cond, guidance_rescale=guidance_rescale)
model_output_action = rescale_noise_cfg(
model_output_action,
e_t_cond_action,
guidance_rescale=guidance_rescale)
model_output_state = rescale_noise_cfg(
model_output_state,
e_t_cond_state,
guidance_rescale=guidance_rescale)
if self.model.parameterization == "v":
e_t = self.model.predict_eps_from_z_and_v(x, t, model_output)
else:
e_t = model_output
if score_corrector is not None:
assert self.model.parameterization == "eps", 'not implemented'
e_t = score_corrector.modify_score(self.model, e_t, x, t, c,
**corrector_kwargs)
alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas
alphas_prev = self.model.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prev
sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas
sigmas = self.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas
if is_video:
size = (b, 1, 1, 1, 1)
else:
size = (b, 1, 1, 1)
a_t = torch.full(size, alphas[index], device=device)
a_prev = torch.full(size, alphas_prev[index], device=device)
sigma_t = torch.full(size, sigmas[index], device=device)
sqrt_one_minus_at = torch.full(size,
sqrt_one_minus_alphas[index],
device=device)
if self.model.parameterization != "v":
pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
else:
pred_x0 = self.model.predict_start_from_z_and_v(x, t, model_output)
if self.model.use_dynamic_rescale:
scale_t = torch.full(size,
self.ddim_scale_arr[index],
device=device)
prev_scale_t = torch.full(size,
self.ddim_scale_arr_prev[index],
device=device)
rescale = (prev_scale_t / scale_t)
pred_x0 *= rescale
if quantize_denoised:
pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0)
dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t
noise = sigma_t * noise_like(x.shape, device,
repeat_noise) * temperature
if noise_dropout > 0.:
noise = torch.nn.functional.dropout(noise, p=noise_dropout)
x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise
return x_prev, pred_x0, model_output_action, model_output_state
@torch.no_grad()
def decode(self,
x_latent,
cond,
t_start,
unconditional_guidance_scale=1.0,
unconditional_conditioning=None,
use_original_steps=False,
callback=None):
timesteps = np.arange(self.ddpm_num_timesteps
) if use_original_steps else self.ddim_timesteps
timesteps = timesteps[:t_start]
time_range = np.flip(timesteps)
total_steps = timesteps.shape[0]
print(f"Running DDIM Sampling with {total_steps} timesteps")
iterator = tqdm(time_range, desc='Decoding image', total=total_steps)
x_dec = x_latent
for i, step in enumerate(iterator):
index = total_steps - i - 1
ts = torch.full((x_latent.shape[0], ),
step,
device=x_latent.device,
dtype=torch.long)
x_dec, _ = self.p_sample_ddim(
x_dec,
cond,
ts,
index=index,
use_original_steps=use_original_steps,
unconditional_guidance_scale=unconditional_guidance_scale,
unconditional_conditioning=unconditional_conditioning)
if callback: callback(i)
return x_dec
@torch.no_grad()
def stochastic_encode(self, x0, t, use_original_steps=False, noise=None):
# fast, but does not allow for exact reconstruction
if use_original_steps:
sqrt_alphas_cumprod = self.sqrt_alphas_cumprod
sqrt_one_minus_alphas_cumprod = self.sqrt_one_minus_alphas_cumprod
else:
sqrt_alphas_cumprod = torch.sqrt(self.ddim_alphas)
sqrt_one_minus_alphas_cumprod = self.ddim_sqrt_one_minus_alphas
if noise is None:
noise = torch.randn_like(x0)
return (
extract_into_tensor(sqrt_alphas_cumprod, t, x0.shape) * x0 +
extract_into_tensor(sqrt_one_minus_alphas_cumprod, t, x0.shape) *
noise)

View File

View File

@@ -0,0 +1,806 @@
import torch
import torch.nn.functional as F
from torch import nn, einsum
from einops import rearrange, repeat
from functools import partial
try:
import xformers
import xformers.ops
XFORMERS_IS_AVAILBLE = True
except:
XFORMERS_IS_AVAILBLE = False
from unifolm_wma.utils.common import (
checkpoint,
exists,
default,
)
from unifolm_wma.utils.basics import zero_module
class RelativePosition(nn.Module):
""" https://github.com/evelinehong/Transformer_Relative_Position_PyTorch/blob/master/relative_position.py """
def __init__(self, num_units, max_relative_position):
super().__init__()
self.num_units = num_units
self.max_relative_position = max_relative_position
self.embeddings_table = nn.Parameter(
torch.Tensor(max_relative_position * 2 + 1, num_units))
nn.init.xavier_uniform_(self.embeddings_table)
def forward(self, length_q, length_k):
device = self.embeddings_table.device
range_vec_q = torch.arange(length_q, device=device)
range_vec_k = torch.arange(length_k, device=device)
distance_mat = range_vec_k[None, :] - range_vec_q[:, None]
distance_mat_clipped = torch.clamp(distance_mat,
-self.max_relative_position,
self.max_relative_position)
final_mat = distance_mat_clipped + self.max_relative_position
final_mat = final_mat.long()
embeddings = self.embeddings_table[final_mat]
return embeddings
class CrossAttention(nn.Module):
def __init__(self,
query_dim,
context_dim=None,
heads=8,
dim_head=64,
dropout=0.,
relative_position=False,
temporal_length=None,
video_length=None,
agent_state_context_len=2,
agent_action_context_len=16,
image_cross_attention=False,
image_cross_attention_scale=1.0,
agent_state_cross_attention_scale=1.0,
agent_action_cross_attention_scale=1.0,
cross_attention_scale_learnable=False,
text_context_len=77):
super().__init__()
inner_dim = dim_head * heads
context_dim = default(context_dim, query_dim)
self.scale = dim_head**-0.5
self.heads = heads
self.dim_head = dim_head
self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
self.to_out = nn.Sequential(nn.Linear(inner_dim, query_dim),
nn.Dropout(dropout))
self.relative_position = relative_position
if self.relative_position:
assert (temporal_length is not None)
self.relative_position_k = RelativePosition(
num_units=dim_head, max_relative_position=temporal_length)
self.relative_position_v = RelativePosition(
num_units=dim_head, max_relative_position=temporal_length)
else:
## only used for spatial attention, while NOT for temporal attention
if XFORMERS_IS_AVAILBLE and temporal_length is None:
self.forward = self.efficient_forward
self.video_length = video_length
self.image_cross_attention = image_cross_attention
self.image_cross_attention_scale = image_cross_attention_scale
self.agent_state_cross_attention_scale = agent_state_cross_attention_scale
self.agent_action_cross_attention_scale = agent_action_cross_attention_scale
self.text_context_len = text_context_len
self.agent_state_context_len = agent_state_context_len
self.agent_action_context_len = agent_action_context_len
self.cross_attention_scale_learnable = cross_attention_scale_learnable
if self.image_cross_attention:
self.to_k_ip = nn.Linear(context_dim, inner_dim, bias=False)
self.to_v_ip = nn.Linear(context_dim, inner_dim, bias=False)
self.to_k_as = nn.Linear(context_dim, inner_dim, bias=False)
self.to_v_as = nn.Linear(context_dim, inner_dim, bias=False)
self.to_k_aa = nn.Linear(context_dim, inner_dim, bias=False)
self.to_v_aa = nn.Linear(context_dim, inner_dim, bias=False)
if cross_attention_scale_learnable:
self.register_parameter('alpha_ctx',
nn.Parameter(torch.tensor(0.)))
self.register_parameter('alpha_cas',
nn.Parameter(torch.tensor(0.)))
self.register_parameter('alpha_caa',
nn.Parameter(torch.tensor(0.)))
def forward(self, x, context=None, mask=None):
spatial_self_attn = (context is None)
k_ip, v_ip, out_ip = None, None, None
k_as, v_as, out_as = None, None, None
k_aa, v_aa, out_aa = None, None, None
h = self.heads
q = self.to_q(x)
context = default(context, x)
if self.image_cross_attention and not spatial_self_attn:
assert 1 > 2, ">>> ERROR: should setup xformers and use efficient_forward ..."
context_agent_state = context[:, :self.agent_state_context_len, :]
context_agent_action = context[:,
self.agent_state_context_len:self.
agent_state_context_len +
self.agent_action_context_len, :]
context_ins = context[:, self.agent_state_context_len +
self.agent_action_context_len:self.
agent_state_context_len +
self.agent_action_context_len +
self.text_context_len, :]
context_image = context[:, self.agent_state_context_len +
self.agent_action_context_len +
self.text_context_len:, :]
k = self.to_k(context_ins)
v = self.to_v(context_ins)
k_ip = self.to_k_ip(context_image)
v_ip = self.to_v_ip(context_image)
k_as = self.to_k_as(context_agent_state)
v_as = self.to_v_as(context_agent_state)
k_aa = self.to_k_aa(context_agent_action)
v_aa = self.to_v_aa(context_agent_action)
else:
if not spatial_self_attn:
context = context[:, :self.text_context_len, :]
k = self.to_k(context)
v = self.to_v(context)
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h),
(q, k, v))
sim = torch.einsum('b i d, b j d -> b i j', q, k) * self.scale
if self.relative_position:
len_q, len_k, len_v = q.shape[1], k.shape[1], v.shape[1]
k2 = self.relative_position_k(len_q, len_k)
sim2 = einsum('b t d, t s d -> b t s', q,
k2) * self.scale # TODO check
sim += sim2
del k
if exists(mask):
## feasible for causal attention mask only
max_neg_value = -torch.finfo(sim.dtype).max
mask = repeat(mask, 'b i j -> (b h) i j', h=h)
sim.masked_fill_(~(mask > 0.5), max_neg_value)
# attention, what we cannot get enough of
sim = sim.softmax(dim=-1)
out = torch.einsum('b i j, b j d -> b i d', sim, v)
if self.relative_position:
v2 = self.relative_position_v(len_q, len_v)
out2 = einsum('b t s, t s d -> b t d', sim, v2) # TODO check
out += out2
out = rearrange(out, '(b h) n d -> b n (h d)', h=h)
if k_ip is not None and k_as is not None and k_aa is not None:
## for image cross-attention
k_ip, v_ip = map(
lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h),
(k_ip, v_ip))
sim_ip = torch.einsum('b i d, b j d -> b i j', q,
k_ip) * self.scale
del k_ip
sim_ip = sim_ip.softmax(dim=-1)
out_ip = torch.einsum('b i j, b j d -> b i d', sim_ip, v_ip)
out_ip = rearrange(out_ip, '(b h) n d -> b n (h d)', h=h)
## for agent state cross-attention
k_as, v_as = map(
lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h),
(k_as, v_as))
sim_as = torch.einsum('b i d, b j d -> b i j', q,
k_as) * self.scale
del k_as
sim_as = sim_as.softmax(dim=-1)
out_as = torch.einsum('b i j, b j d -> b i d', sim_as, v_as)
out_as = rearrange(out_as, '(b h) n d -> b n (h d)', h=h)
## for agent action cross-attention
k_aa, v_aa = map(
lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h),
(k_aa, v_aa))
sim_aa = torch.einsum('b i d, b j d -> b i j', q,
k_aa) * self.scale
del k_aa
sim_aa = sim_aa.softmax(dim=-1)
out_aa = torch.einsum('b i j, b j d -> b i d', sim_aa, v_aa)
out_aa = rearrange(out_aa, '(b h) n d -> b n (h d)', h=h)
if out_ip is not None and out_as is not None and out_aa is not None:
if self.cross_attention_scale_learnable:
out = out + \
self.image_cross_attention_scale * out_ip * (torch.tanh(self.alpha_ctx) + 1) + \
self.agent_state_cross_attention_scale * out_as * (torch.tanh(self.alpha_cas) + 1) + \
self.agent_action_cross_attention_scale * out_aa * (torch.tanh(self.alpha_caa) + 1)
else:
out = out + \
self.image_cross_attention_scale * out_ip + \
self.agent_state_cross_attention_scale * out_as + \
self.agent_action_cross_attention_scale * out_aa
return self.to_out(out)
def efficient_forward(self, x, context=None, mask=None):
spatial_self_attn = (context is None)
k, v, out = None, None, None
k_ip, v_ip, out_ip = None, None, None
k_as, v_as, out_as = None, None, None
k_aa, v_aa, out_aa = None, None, None
q = self.to_q(x)
context = default(context, x)
if self.image_cross_attention and not spatial_self_attn:
if context.shape[1] == self.text_context_len + self.video_length:
context_ins, context_image = context[:, :self.text_context_len, :], context[:,self.text_context_len:, :]
k = self.to_k(context)
v = self.to_v(context)
k_ip = self.to_k_ip(context_image)
v_ip = self.to_v_ip(context_image)
elif context.shape[1] == self.agent_state_context_len + self.text_context_len + self.video_length:
context_agent_state = context[:, :self.agent_state_context_len, :]
context_ins = context[:, self.agent_state_context_len:self.agent_state_context_len+self.text_context_len, :]
context_image = context[:, self.agent_state_context_len+self.text_context_len:, :]
k = self.to_k(context_ins)
v = self.to_v(context_ins)
k_ip = self.to_k_ip(context_image)
v_ip = self.to_v_ip(context_image)
k_as = self.to_k_as(context_agent_state)
v_as = self.to_v_as(context_agent_state)
else:
context_agent_state = context[:, :self.agent_state_context_len, :]
context_agent_action = context[:, self.agent_state_context_len:self.agent_state_context_len+self.agent_action_context_len, :]
context_ins = context[:, self.agent_state_context_len+self.agent_action_context_len:self.agent_state_context_len+self.agent_action_context_len+self.text_context_len, :]
context_image = context[:, self.agent_state_context_len+self.agent_action_context_len+self.text_context_len:, :]
k = self.to_k(context_ins)
v = self.to_v(context_ins)
k_ip = self.to_k_ip(context_image)
v_ip = self.to_v_ip(context_image)
k_as = self.to_k_as(context_agent_state)
v_as = self.to_v_as(context_agent_state)
k_aa = self.to_k_aa(context_agent_action)
v_aa = self.to_v_aa(context_agent_action)
attn_mask_aa = self._get_attn_mask_aa(x.shape[0],
q.shape[1],
k_aa.shape[1],
block_size=16).to(k_aa.device)
else:
if not spatial_self_attn:
assert 1 > 2, ">>> ERROR: you should never go into here ..."
context = context[:, :self.text_context_len, :]
k = self.to_k(context)
v = self.to_v(context)
b, _, _ = q.shape
q = q.unsqueeze(3).reshape(b, q.shape[1], self.heads, self.dim_head).permute(0, 2, 1, 3).reshape(b * self.heads, q.shape[1], self.dim_head).contiguous()
if k is not None:
k, v = map(
lambda t: t.unsqueeze(3).reshape(b, t.shape[
1], self.heads, self.dim_head).permute(0, 2, 1, 3).reshape(
b * self.heads, t.shape[1], self.dim_head).contiguous(),
(k, v),
)
out = xformers.ops.memory_efficient_attention(q,
k,
v,
attn_bias=None,
op=None)
out = (out.unsqueeze(0).reshape(
b, self.heads, out.shape[1],
self.dim_head).permute(0, 2, 1,
3).reshape(b, out.shape[1],
self.heads * self.dim_head))
if k_ip is not None:
# For image cross-attention
k_ip, v_ip = map(
lambda t: t.unsqueeze(3).reshape(b, t.shape[
1], self.heads, self.dim_head).permute(0, 2, 1, 3).reshape(
b * self.heads, t.shape[1], self.dim_head).contiguous(
),
(k_ip, v_ip),
)
out_ip = xformers.ops.memory_efficient_attention(q,
k_ip,
v_ip,
attn_bias=None,
op=None)
out_ip = (out_ip.unsqueeze(0).reshape(
b, self.heads, out_ip.shape[1],
self.dim_head).permute(0, 2, 1,
3).reshape(b, out_ip.shape[1],
self.heads * self.dim_head))
if k_as is not None:
# For agent state cross-attention
k_as, v_as = map(
lambda t: t.unsqueeze(3).reshape(b, t.shape[
1], self.heads, self.dim_head).permute(0, 2, 1, 3).reshape(
b * self.heads, t.shape[1], self.dim_head).contiguous(
),
(k_as, v_as),
)
out_as = xformers.ops.memory_efficient_attention(q,
k_as,
v_as,
attn_bias=None,
op=None)
out_as = (out_as.unsqueeze(0).reshape(
b, self.heads, out_as.shape[1],
self.dim_head).permute(0, 2, 1,
3).reshape(b, out_as.shape[1],
self.heads * self.dim_head))
if k_aa is not None:
# For agent action cross-attention
k_aa, v_aa = map(
lambda t: t.unsqueeze(3).reshape(b, t.shape[
1], self.heads, self.dim_head).permute(0, 2, 1, 3).reshape(
b * self.heads, t.shape[1], self.dim_head).contiguous(
),
(k_aa, v_aa),
)
attn_mask_aa = attn_mask_aa.unsqueeze(1).repeat(1,self.heads,1,1).reshape(
b * self.heads, attn_mask_aa.shape[1], attn_mask_aa.shape[2])
attn_mask_aa = attn_mask_aa.to(q.dtype)
out_aa = xformers.ops.memory_efficient_attention(
q, k_aa, v_aa, attn_bias=attn_mask_aa, op=None)
out_aa = (out_aa.unsqueeze(0).reshape(
b, self.heads, out_aa.shape[1],
self.dim_head).permute(0, 2, 1,
3).reshape(b, out_aa.shape[1],
self.heads * self.dim_head))
if exists(mask):
raise NotImplementedError
out = 0.0 if out is None else out
out_ip = 0.0 if out_ip is None else out_ip
out_as = 0.0 if out_as is None else out_as
out_aa = 0.0 if out_aa is None else out_aa
if self.cross_attention_scale_learnable:
out = out + \
self.image_cross_attention_scale * out_ip * (torch.tanh(self.alpha_ctx) + 1) + \
self.agent_state_cross_attention_scale * out_as * (torch.tanh(self.alpha_cas) + 1) + \
self.agent_action_cross_attention_scale * out_aa * (torch.tanh(self.alpha_caa) + 1)
else:
out = out + \
self.image_cross_attention_scale * out_ip + \
self.agent_state_cross_attention_scale * out_as + \
self.agent_action_cross_attention_scale * out_aa
return self.to_out(out)
def _get_attn_mask_aa(self, b, l1, l2, block_size=16):
num_token = l2 // block_size
start_positions = ((torch.arange(b) % block_size) + 1) * num_token
col_indices = torch.arange(l2)
mask_2d = col_indices.unsqueeze(0) >= start_positions.unsqueeze(1)
mask = mask_2d.unsqueeze(1).expand(b, l1, l2)
attn_mask = torch.zeros_like(mask, dtype=torch.float)
attn_mask[mask] = float('-inf')
return attn_mask
class BasicTransformerBlock(nn.Module):
def __init__(self,
dim,
n_heads,
d_head,
dropout=0.,
context_dim=None,
gated_ff=True,
checkpoint=True,
disable_self_attn=False,
attention_cls=None,
video_length=None,
agent_state_context_len=2,
agent_action_context_len=16,
image_cross_attention=False,
image_cross_attention_scale=1.0,
cross_attention_scale_learnable=False,
text_context_len=77):
super().__init__()
attn_cls = CrossAttention if attention_cls is None else attention_cls
self.disable_self_attn = disable_self_attn
self.attn1 = attn_cls(
query_dim=dim,
heads=n_heads,
dim_head=d_head,
dropout=dropout,
context_dim=context_dim if self.disable_self_attn else None)
self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff)
self.attn2 = attn_cls(
query_dim=dim,
context_dim=context_dim,
heads=n_heads,
dim_head=d_head,
dropout=dropout,
video_length=video_length,
agent_state_context_len=agent_state_context_len,
agent_action_context_len=agent_action_context_len,
image_cross_attention=image_cross_attention,
image_cross_attention_scale=image_cross_attention_scale,
cross_attention_scale_learnable=cross_attention_scale_learnable,
text_context_len=text_context_len)
self.image_cross_attention = image_cross_attention
self.norm1 = nn.LayerNorm(dim)
self.norm2 = nn.LayerNorm(dim)
self.norm3 = nn.LayerNorm(dim)
self.checkpoint = checkpoint
def forward(self, x, context=None, mask=None, **kwargs):
# implementation tricks: because checkpointing doesn't support non-tensor (e.g. None or scalar) arguments
input_tuple = (
x,
) # should not be (x), otherwise *input_tuple will decouple x into multiple arguments
if context is not None:
input_tuple = (x, context)
if mask is not None:
forward_mask = partial(self._forward, mask=mask)
return checkpoint(forward_mask, (x, ), self.parameters(),
self.checkpoint)
return checkpoint(self._forward, input_tuple, self.parameters(),
self.checkpoint)
def _forward(self, x, context=None, mask=None):
x = self.attn1(self.norm1(x),
context=context if self.disable_self_attn else None,
mask=mask) + x
x = self.attn2(self.norm2(x), context=context, mask=mask) + x
x = self.ff(self.norm3(x)) + x
return x
class SpatialTransformer(nn.Module):
"""
Transformer block for image-like data in spatial axis.
First, project the input (aka embedding)
and reshape to b, t, d.
Then apply standard transformer action.
Finally, reshape to image
NEW: use_linear for more efficiency instead of the 1x1 convs
"""
def __init__(self,
in_channels,
n_heads,
d_head,
depth=1,
dropout=0.,
context_dim=None,
use_checkpoint=True,
disable_self_attn=False,
use_linear=False,
video_length=None,
agent_state_context_len=2,
agent_action_context_len=16,
image_cross_attention=False,
cross_attention_scale_learnable=False):
super().__init__()
self.in_channels = in_channels
inner_dim = n_heads * d_head
self.norm = torch.nn.GroupNorm(num_groups=32,
num_channels=in_channels,
eps=1e-6,
affine=True)
if not use_linear:
self.proj_in = nn.Conv2d(in_channels,
inner_dim,
kernel_size=1,
stride=1,
padding=0)
else:
self.proj_in = nn.Linear(in_channels, inner_dim)
attention_cls = None
self.transformer_blocks = nn.ModuleList([
BasicTransformerBlock(
inner_dim,
n_heads,
d_head,
dropout=dropout,
context_dim=context_dim,
disable_self_attn=disable_self_attn,
checkpoint=use_checkpoint,
attention_cls=attention_cls,
video_length=video_length,
agent_state_context_len=agent_state_context_len,
agent_action_context_len=agent_action_context_len,
image_cross_attention=image_cross_attention,
cross_attention_scale_learnable=cross_attention_scale_learnable,
) for d in range(depth)
])
if not use_linear:
self.proj_out = zero_module(
nn.Conv2d(inner_dim,
in_channels,
kernel_size=1,
stride=1,
padding=0))
else:
self.proj_out = zero_module(nn.Linear(inner_dim, in_channels))
self.use_linear = use_linear
def forward(self, x, context=None, **kwargs):
b, c, h, w = x.shape
x_in = x
x = self.norm(x)
if not self.use_linear:
x = self.proj_in(x)
x = rearrange(x, 'b c h w -> b (h w) c').contiguous()
if self.use_linear:
x = self.proj_in(x)
for i, block in enumerate(self.transformer_blocks):
x = block(x, context=context, **kwargs)
if self.use_linear:
x = self.proj_out(x)
x = rearrange(x, 'b (h w) c -> b c h w', h=h, w=w).contiguous()
if not self.use_linear:
x = self.proj_out(x)
return x + x_in
class TemporalTransformer(nn.Module):
"""
Transformer block for image-like data in temporal axis.
First, reshape to b, t, d.
Then apply standard transformer action.
Finally, reshape to image
"""
def __init__(self,
in_channels,
n_heads,
d_head,
depth=1,
dropout=0.,
context_dim=None,
use_checkpoint=True,
use_linear=False,
only_self_att=True,
causal_attention=False,
causal_block_size=1,
relative_position=False,
temporal_length=None):
super().__init__()
self.only_self_att = only_self_att
self.relative_position = relative_position
self.causal_attention = causal_attention
self.causal_block_size = causal_block_size
self.in_channels = in_channels
inner_dim = n_heads * d_head
self.norm = torch.nn.GroupNorm(num_groups=32,
num_channels=in_channels,
eps=1e-6,
affine=True)
self.proj_in = nn.Conv1d(in_channels,
inner_dim,
kernel_size=1,
stride=1,
padding=0)
if not use_linear:
self.proj_in = nn.Conv1d(in_channels,
inner_dim,
kernel_size=1,
stride=1,
padding=0)
else:
self.proj_in = nn.Linear(in_channels, inner_dim)
if relative_position:
assert (temporal_length is not None)
attention_cls = partial(CrossAttention,
relative_position=True,
temporal_length=temporal_length)
else:
attention_cls = partial(CrossAttention,
temporal_length=temporal_length)
if self.causal_attention:
assert (temporal_length is not None)
self.mask = torch.tril(
torch.ones([1, temporal_length, temporal_length]))
if self.only_self_att:
context_dim = None
self.transformer_blocks = nn.ModuleList([
BasicTransformerBlock(inner_dim,
n_heads,
d_head,
dropout=dropout,
context_dim=context_dim,
attention_cls=attention_cls,
checkpoint=use_checkpoint)
for d in range(depth)
])
if not use_linear:
self.proj_out = zero_module(
nn.Conv1d(inner_dim,
in_channels,
kernel_size=1,
stride=1,
padding=0))
else:
self.proj_out = zero_module(nn.Linear(inner_dim, in_channels))
self.use_linear = use_linear
def forward(self, x, context=None):
b, c, t, h, w = x.shape
x_in = x
x = self.norm(x)
x = rearrange(x, 'b c t h w -> (b h w) c t').contiguous()
if not self.use_linear:
x = self.proj_in(x)
x = rearrange(x, 'bhw c t -> bhw t c').contiguous()
if self.use_linear:
x = self.proj_in(x)
temp_mask = None
if self.causal_attention:
# Slice the from mask map
temp_mask = self.mask[:, :t, :t].to(x.device)
if temp_mask is not None:
mask = temp_mask.to(x.device)
mask = repeat(mask, 'l i j -> (l bhw) i j', bhw=b * h * w)
else:
mask = None
if self.only_self_att:
# NOTE: if no context is given, cross-attention defaults to self-attention
for i, block in enumerate(self.transformer_blocks):
x = block(x, mask=mask)
x = rearrange(x, '(b hw) t c -> b hw t c', b=b).contiguous()
else:
x = rearrange(x, '(b hw) t c -> b hw t c', b=b).contiguous()
context = rearrange(context, '(b t) l con -> b t l con',
t=t).contiguous()
for i, block in enumerate(self.transformer_blocks):
# Calculate each batch one by one (since number in shape could not greater then 65,535 for some package)
for j in range(b):
context_j = repeat(context[j],
't l con -> (t r) l con',
r=(h * w) // t,
t=t).contiguous()
# Note: causal mask will not applied in cross-attention case
x[j] = block(x[j], context=context_j)
if self.use_linear:
x = self.proj_out(x)
x = rearrange(x, 'b (h w) t c -> b c t h w', h=h, w=w).contiguous()
if not self.use_linear:
x = rearrange(x, 'b hw t c -> (b hw) c t').contiguous()
x = self.proj_out(x)
x = rearrange(x, '(b h w) c t -> b c t h w', b=b, h=h,
w=w).contiguous()
return x + x_in
class GEGLU(nn.Module):
def __init__(self, dim_in, dim_out):
super().__init__()
self.proj = nn.Linear(dim_in, dim_out * 2)
def forward(self, x):
x, gate = self.proj(x).chunk(2, dim=-1)
return x * F.gelu(gate)
class FeedForward(nn.Module):
def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.):
super().__init__()
inner_dim = int(dim * mult)
dim_out = default(dim_out, dim)
project_in = nn.Sequential(nn.Linear(
dim, inner_dim), nn.GELU()) if not glu else GEGLU(dim, inner_dim)
self.net = nn.Sequential(project_in, nn.Dropout(dropout),
nn.Linear(inner_dim, dim_out))
def forward(self, x):
return self.net(x)
class LinearAttention(nn.Module):
def __init__(self, dim, heads=4, dim_head=32):
super().__init__()
self.heads = heads
hidden_dim = dim_head * heads
self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias=False)
self.to_out = nn.Conv2d(hidden_dim, dim, 1)
def forward(self, x):
b, c, h, w = x.shape
qkv = self.to_qkv(x)
q, k, v = rearrange(qkv,
'b (qkv heads c) h w -> qkv b heads c (h w)',
heads=self.heads,
qkv=3)
k = k.softmax(dim=-1)
context = torch.einsum('bhdn,bhen->bhde', k, v)
out = torch.einsum('bhde,bhdn->bhen', context, q)
out = rearrange(out,
'b heads c (h w) -> b (heads c) h w',
heads=self.heads,
h=h,
w=w)
return self.to_out(out)
class SpatialSelfAttention(nn.Module):
def __init__(self, in_channels):
super().__init__()
self.in_channels = in_channels
self.norm = torch.nn.GroupNorm(num_groups=32,
num_channels=in_channels,
eps=1e-6,
affine=True)
self.q = torch.nn.Conv2d(in_channels,
in_channels,
kernel_size=1,
stride=1,
padding=0)
self.k = torch.nn.Conv2d(in_channels,
in_channels,
kernel_size=1,
stride=1,
padding=0)
self.v = torch.nn.Conv2d(in_channels,
in_channels,
kernel_size=1,
stride=1,
padding=0)
self.proj_out = torch.nn.Conv2d(in_channels,
in_channels,
kernel_size=1,
stride=1,
padding=0)
def forward(self, x):
h_ = x
h_ = self.norm(h_)
q = self.q(h_)
k = self.k(h_)
v = self.v(h_)
# Compute attention
b, c, h, w = q.shape
q = rearrange(q, 'b c h w -> b (h w) c')
k = rearrange(k, 'b c h w -> b c (h w)')
w_ = torch.einsum('bij,bjk->bik', q, k)
w_ = w_ * (int(c)**(-0.5))
w_ = torch.nn.functional.softmax(w_, dim=2)
# Attend to values
v = rearrange(v, 'b c h w -> b c (h w)')
w_ = rearrange(w_, 'b i j -> b j i')
h_ = torch.einsum('bij,bjk->bik', v, w_)
h_ = rearrange(h_, 'b c (h w) -> b c h w', h=h)
h_ = self.proj_out(h_)
return x + h_

View File

@@ -0,0 +1,630 @@
import torch
import torch.nn as nn
import kornia
import open_clip
import math
from torch.utils.checkpoint import checkpoint
from transformers import T5Tokenizer, T5EncoderModel, CLIPTokenizer, CLIPTextModel
from unifolm_wma.utils.common import autocast
from unifolm_wma.utils.utils import count_params
from unifolm_wma.modules.encoders.resampler import reshape_tensor
class AbstractEncoder(nn.Module):
def __init__(self):
super().__init__()
def encode(self, *args, **kwargs):
raise NotImplementedError
class IdentityEncoder(AbstractEncoder):
def encode(self, x):
return x
class ClassEmbedder(nn.Module):
def __init__(self, embed_dim, n_classes=1000, key='class', ucg_rate=0.1):
super().__init__()
self.key = key
self.embedding = nn.Embedding(n_classes, embed_dim)
self.n_classes = n_classes
self.ucg_rate = ucg_rate
def forward(self, batch, key=None, disable_dropout=False):
if key is None:
key = self.key
# this is for use in crossattn
c = batch[key][:, None]
if self.ucg_rate > 0. and not disable_dropout:
mask = 1. - torch.bernoulli(torch.ones_like(c) * self.ucg_rate)
c = mask * c + (1 - mask) * torch.ones_like(c) * (self.n_classes -
1)
c = c.long()
c = self.embedding(c)
return c
def get_unconditional_conditioning(self, bs, device="cuda"):
uc_class = self.n_classes - 1 # 1000 classes --> 0 ... 999, one extra class for ucg (class 1000)
uc = torch.ones((bs, ), device=device) * uc_class
uc = {self.key: uc}
return uc
def disabled_train(self, mode=True):
"""Overwrite model.train with this function to make sure train/eval mode
does not change anymore."""
return self
class FrozenT5Embedder(AbstractEncoder):
"""Uses the T5 transformer encoder for text"""
def __init__(self,
version="google/t5-v1_1-xxl",
device="cuda",
max_length=77,
freeze=True
): # others are google/t5-v1_1-xl and google/t5-v1_1-xxl
super().__init__()
self.tokenizer = T5Tokenizer.from_pretrained(version)
self.transformer = T5EncoderModel.from_pretrained(version)
self.device = device
self.max_length = max_length # TODO: typical value?
if freeze:
self.freeze()
def freeze(self):
self.transformer = self.transformer.eval()
# self.train = disabled_train
for param in self.parameters():
param.requires_grad = False
def forward(self, text):
batch_encoding = self.tokenizer(text,
truncation=True,
max_length=self.max_length,
return_length=True,
return_overflowing_tokens=False,
padding="max_length",
return_tensors="pt")
tokens = batch_encoding["input_ids"].to(self.device)
outputs = self.transformer(input_ids=tokens)
z = outputs.last_hidden_state
return z
def encode(self, text):
return self(text)
class FrozenCLIPEmbedder(AbstractEncoder):
"""Uses the CLIP transformer encoder for text (from huggingface)"""
LAYERS = ["last", "pooled", "hidden"]
def __init__(self,
version="openai/clip-vit-large-patch14",
device="cuda",
max_length=77,
freeze=True,
layer="last",
layer_idx=None): # clip-vit-base-patch32
super().__init__()
assert layer in self.LAYERS
self.tokenizer = CLIPTokenizer.from_pretrained(version)
self.transformer = CLIPTextModel.from_pretrained(version)
self.device = device
self.max_length = max_length
if freeze:
self.freeze()
self.layer = layer
self.layer_idx = layer_idx
if layer == "hidden":
assert layer_idx is not None
assert 0 <= abs(layer_idx) <= 12
def freeze(self):
self.transformer = self.transformer.eval()
# self.train = disabled_train
for param in self.parameters():
param.requires_grad = False
def forward(self, text):
batch_encoding = self.tokenizer(text,
truncation=True,
max_length=self.max_length,
return_length=True,
return_overflowing_tokens=False,
padding="max_length",
return_tensors="pt")
tokens = batch_encoding["input_ids"].to(self.device)
outputs = self.transformer(input_ids=tokens,
output_hidden_states=self.layer == "hidden")
if self.layer == "last":
z = outputs.last_hidden_state
elif self.layer == "pooled":
z = outputs.pooler_output[:, None, :]
else:
z = outputs.hidden_states[self.layer_idx]
return z
def encode(self, text):
return self(text)
class ClipImageEmbedder(nn.Module):
def __init__(self,
model,
jit=False,
device='cuda' if torch.cuda.is_available() else 'cpu',
antialias=True,
ucg_rate=0.):
super().__init__()
from clip import load as load_clip
self.model, _ = load_clip(name=model, device=device, jit=jit)
self.antialias = antialias
self.register_buffer('mean',
torch.Tensor([0.48145466, 0.4578275, 0.40821073]),
persistent=False)
self.register_buffer('std',
torch.Tensor([0.26862954, 0.26130258,
0.27577711]),
persistent=False)
self.ucg_rate = ucg_rate
def preprocess(self, x):
# normalize to [0,1]
x = kornia.geometry.resize(x, (224, 224),
interpolation='bicubic',
align_corners=True,
antialias=self.antialias)
x = (x + 1.) / 2.
# re-normalize according to clip
x = kornia.enhance.normalize(x, self.mean, self.std)
return x
def forward(self, x, no_dropout=False):
# x is assumed to be in range [-1,1]
out = self.model.encode_image(self.preprocess(x))
out = out.to(x.dtype)
if self.ucg_rate > 0. and not no_dropout:
out = torch.bernoulli(
(1. - self.ucg_rate) *
torch.ones(out.shape[0], device=out.device))[:, None] * out
return out
class FrozenOpenCLIPEmbedder(AbstractEncoder):
"""
Uses the OpenCLIP transformer encoder for text
"""
LAYERS = [
# "pooled",
"last",
"penultimate"
]
def __init__(self,
arch="ViT-H-14",
version="laion2b_s32b_b79k",
device="cuda",
max_length=77,
freeze=True,
layer="last"):
super().__init__()
assert layer in self.LAYERS
model, _, _ = open_clip.create_model_and_transforms(
arch, device=torch.device('cpu'), pretrained=version)
del model.visual
self.model = model
self.device = device
self.max_length = max_length
if freeze:
self.freeze()
self.layer = layer
if self.layer == "last":
self.layer_idx = 0
elif self.layer == "penultimate":
self.layer_idx = 1
else:
raise NotImplementedError()
def freeze(self):
self.model = self.model.eval()
for param in self.parameters():
param.requires_grad = False
def forward(self, text):
tokens = open_clip.tokenize(
text) ## all clip models use 77 as context length
z = self.encode_with_transformer(tokens.to(self.device))
return z
def encode_with_transformer(self, text):
x = self.model.token_embedding(text) # [batch_size, n_ctx, d_model]
x = x + self.model.positional_embedding
x = x.permute(1, 0, 2) # NLD -> LND
x = self.text_transformer_forward(x, attn_mask=self.model.attn_mask)
x = x.permute(1, 0, 2) # LND -> NLD
x = self.model.ln_final(x)
return x
def text_transformer_forward(self, x: torch.Tensor, attn_mask=None):
for i, r in enumerate(self.model.transformer.resblocks):
if i == len(self.model.transformer.resblocks) - self.layer_idx:
break
if self.model.transformer.grad_checkpointing and not torch.jit.is_scripting(
):
x = checkpoint(r, x, attn_mask)
else:
x = r(x, attn_mask=attn_mask)
return x
def encode(self, text):
return self(text)
class FrozenOpenCLIPImageEmbedder(AbstractEncoder):
"""
Uses the OpenCLIP vision transformer encoder for images
"""
def __init__(self,
arch="ViT-H-14",
version="laion2b_s32b_b79k",
device="cuda",
max_length=77,
freeze=True,
layer="pooled",
antialias=True,
ucg_rate=0.):
super().__init__()
model, _, _ = open_clip.create_model_and_transforms(
arch,
device=torch.device('cpu'),
pretrained=version,
)
del model.transformer
self.model = model
# self.mapper = torch.nn.Linear(1280, 1024)
self.device = device
self.max_length = max_length
if freeze:
self.freeze()
self.layer = layer
if self.layer == "penultimate":
raise NotImplementedError()
self.layer_idx = 1
self.antialias = antialias
self.register_buffer('mean',
torch.Tensor([0.48145466, 0.4578275, 0.40821073]),
persistent=False)
self.register_buffer('std',
torch.Tensor([0.26862954, 0.26130258,
0.27577711]),
persistent=False)
self.ucg_rate = ucg_rate
def preprocess(self, x):
# normalize to [0,1]
x = kornia.geometry.resize(x, (224, 224),
interpolation='bicubic',
align_corners=True,
antialias=self.antialias)
x = (x + 1.) / 2.
# renormalize according to clip
x = kornia.enhance.normalize(x, self.mean, self.std)
return x
def freeze(self):
self.model = self.model.eval()
for param in self.model.parameters():
param.requires_grad = False
@autocast
def forward(self, image, no_dropout=False):
z = self.encode_with_vision_transformer(image)
if self.ucg_rate > 0. and not no_dropout:
z = torch.bernoulli(
(1. - self.ucg_rate) *
torch.ones(z.shape[0], device=z.device))[:, None] * z
return z
def encode_with_vision_transformer(self, img):
img = self.preprocess(img)
x = self.model.visual(img)
return x
def encode(self, text):
return self(text)
class FrozenOpenCLIPImageEmbedderV2(AbstractEncoder):
"""
Uses the OpenCLIP vision transformer encoder for images
"""
def __init__(self,
arch="ViT-H-14",
version="laion2b_s32b_b79k",
device="cuda",
freeze=True,
layer="pooled",
antialias=True):
super().__init__()
model, _, _ = open_clip.create_model_and_transforms(
arch,
device=torch.device('cpu'),
pretrained=version,
)
del model.transformer
self.model = model
self.device = device
if freeze:
self.freeze()
self.layer = layer
if self.layer == "penultimate":
raise NotImplementedError()
self.layer_idx = 1
self.antialias = antialias
self.register_buffer('mean',
torch.Tensor([0.48145466, 0.4578275, 0.40821073]),
persistent=False)
self.register_buffer('std',
torch.Tensor([0.26862954, 0.26130258,
0.27577711]),
persistent=False)
def preprocess(self, x):
# normalize to [0,1]
x = kornia.geometry.resize(x, (224, 224),
interpolation='bicubic',
align_corners=True,
antialias=self.antialias)
x = (x + 1.) / 2.
# renormalize according to clip
x = kornia.enhance.normalize(x, self.mean, self.std)
return x
def freeze(self):
self.model = self.model.eval()
for param in self.model.parameters():
param.requires_grad = False
def forward(self, image, no_dropout=False):
## image: b c h w
z = self.encode_with_vision_transformer(image)
return z
def encode_with_vision_transformer(self, x):
x = self.preprocess(x)
# to patches - whether to use dual patchnorm - https://arxiv.org/abs/2302.01327v1
if self.model.visual.input_patchnorm:
# einops - rearrange(x, 'b c (h p1) (w p2) -> b (h w) (c p1 p2)')
x = x.reshape(x.shape[0], x.shape[1],
self.model.visual.grid_size[0],
self.model.visual.patch_size[0],
self.model.visual.grid_size[1],
self.model.visual.patch_size[1])
x = x.permute(0, 2, 4, 1, 3, 5)
x = x.reshape(
x.shape[0], self.model.visual.grid_size[0] *
self.model.visual.grid_size[1], -1)
x = self.model.visual.patchnorm_pre_ln(x)
x = self.model.visual.conv1(x)
else:
x = self.model.visual.conv1(x) # shape = [*, width, grid, grid]
x = x.reshape(x.shape[0], x.shape[1],
-1) # shape = [*, width, grid ** 2]
x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width]
# class embeddings and positional embeddings
x = torch.cat([
self.model.visual.class_embedding.to(x.dtype) + torch.zeros(
x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), x
],
dim=1) # shape = [*, grid ** 2 + 1, width]
x = x + self.model.visual.positional_embedding.to(x.dtype)
# a patch_dropout of 0. would mean it is disabled and this function would do nothing but return what was passed in
x = self.model.visual.patch_dropout(x)
x = self.model.visual.ln_pre(x)
x = x.permute(1, 0, 2) # NLD -> LND
x = self.model.visual.transformer(x)
x = x.permute(1, 0, 2) # LND -> NLD
return x
class FrozenCLIPT5Encoder(AbstractEncoder):
def __init__(self,
clip_version="openai/clip-vit-large-patch14",
t5_version="google/t5-v1_1-xl",
device="cuda",
clip_max_length=77,
t5_max_length=77):
super().__init__()
self.clip_encoder = FrozenCLIPEmbedder(clip_version,
device,
max_length=clip_max_length)
self.t5_encoder = FrozenT5Embedder(t5_version,
device,
max_length=t5_max_length)
print(
f"{self.clip_encoder.__class__.__name__} has {count_params(self.clip_encoder) * 1.e-6:.2f} M parameters, "
f"{self.t5_encoder.__class__.__name__} comes with {count_params(self.t5_encoder) * 1.e-6:.2f} M params."
)
def encode(self, text):
return self(text)
def forward(self, text):
clip_z = self.clip_encoder.encode(text)
t5_z = self.t5_encoder.encode(text)
return [clip_z, t5_z]
class LinearProjector(nn.Module):
def __init__(self, input_dim: int, output_dim: int) -> None:
super().__init__()
self.projector = nn.Linear(input_dim, output_dim, bias=True)
def forward(self, x: torch.Tensor) -> torch.Tensor:
return self.projector(x)
class MLPProjector(nn.Module):
def __init__(self,
input_dim: int,
output_dim: int,
mlp_type: str = "gelu-mlp") -> None:
super().__init__()
if mlp_type == "gelu-mlp":
self.projector = nn.Sequential(
nn.Linear(input_dim, output_dim, bias=True),
nn.GELU(approximate='tanh'),
nn.Linear(output_dim, output_dim, bias=True),
)
elif mlp_type == "silu-mlp":
self.projector = nn.Sequential(
nn.Linear(input_dim, output_dim, bias=True),
nn.SiLU(),
nn.Linear(output_dim, output_dim, bias=True),
)
else:
raise ValueError(
f"Projector with `{mlp_type = }` is not supported!")
def forward(self, x: torch.Tensor) -> torch.Tensor:
return self.projector(x)
class PerceiverAttention(nn.Module):
def __init__(self, *, dim, dim_head=64, heads=8):
super().__init__()
self.scale = dim_head**-0.5
self.dim_head = dim_head
self.heads = heads
inner_dim = dim_head * heads
self.norm1 = nn.LayerNorm(dim)
self.norm2 = nn.LayerNorm(dim)
self.to_q = nn.Linear(dim, inner_dim, bias=False)
self.to_kv = nn.Linear(dim, inner_dim * 2, bias=False)
self.to_out = nn.Linear(inner_dim, dim, bias=False)
def forward(self, x, latents):
"""
Args:
x (torch.Tensor): image features
shape (b, n1, D)
latent (torch.Tensor): latent features
shape (b, n2, D)
"""
x = self.norm1(x)
latents = self.norm2(latents)
b, l, _ = latents.shape
q = self.to_q(latents)
kv_input = torch.cat((x, latents), dim=-2)
k, v = self.to_kv(kv_input).chunk(2, dim=-1)
q = reshape_tensor(q, self.heads)
k = reshape_tensor(k, self.heads)
v = reshape_tensor(v, self.heads)
# attention
scale = 1 / math.sqrt(math.sqrt(self.dim_head))
weight = (q * scale) @ (k * scale).transpose(
-2, -1) # More stable with f16 than dividing afterwards
weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)
out = weight @ v
out = out.permute(0, 2, 1, 3).reshape(b, l, -1)
return self.to_out(out)
def FeedForward(dim, mult=4, ffd_type="gelu-ffd"):
inner_dim = int(dim * mult)
if ffd_type == "gelu-ffd":
return nn.Sequential(
nn.LayerNorm(dim),
nn.Linear(dim, inner_dim, bias=False),
nn.GELU(approximate='tanh'),
nn.Linear(inner_dim, dim, bias=False),
)
elif ffd_type == "silu-ffd":
return nn.Sequential(
nn.LayerNorm(dim),
nn.Linear(dim, inner_dim, bias=False),
nn.SiLU(),
nn.Linear(inner_dim, dim, bias=False),
)
else:
raise ValueError(f"Projector with `{mlp_type = }` is not supported!")
class SATokenProjector(nn.Module):
def __init__(self,
dim=1024,
depth=1,
dim_head=64,
heads=16,
num_queries=16,
output_dim=1024,
ff_mult=4,
chunk_size=None):
super().__init__()
self.num_queries = num_queries
self.chunk_size = chunk_size
if chunk_size is not None:
num_queries = num_queries * chunk_size
self.latents = nn.Parameter(
torch.randn(1, num_queries, dim) / dim**0.5)
self.proj_out = nn.Linear(dim, output_dim)
self.norm_out = nn.LayerNorm(dim)
self.layers = nn.ModuleList([])
for _ in range(depth):
self.layers.append(
nn.ModuleList([
PerceiverAttention(dim=dim, dim_head=dim_head,
heads=heads),
FeedForward(dim=dim, mult=ff_mult),
]))
def forward(self, x):
latents = self.latents.repeat(x.size(0), 1, 1)
for attn, ff in self.layers:
latents = attn(x, latents) + latents
latents = ff(latents) + latents
latents = self.proj_out(latents)
latents = self.norm_out(latents)
return latents

View File

@@ -0,0 +1,153 @@
# modified from https://github.com/mlfoundations/open_flamingo/blob/main/open_flamingo/src/helpers.py
# and https://github.com/lucidrains/imagen-pytorch/blob/main/imagen_pytorch/imagen_pytorch.py
# and https://github.com/tencent-ailab/IP-Adapter/blob/main/ip_adapter/resampler.py
import math
import torch
import torch.nn as nn
class ImageProjModel(nn.Module):
"""Projection Model"""
def __init__(self,
cross_attention_dim=1024,
clip_embeddings_dim=1024,
clip_extra_context_tokens=4):
super().__init__()
self.cross_attention_dim = cross_attention_dim
self.clip_extra_context_tokens = clip_extra_context_tokens
self.proj = nn.Linear(
clip_embeddings_dim,
self.clip_extra_context_tokens * cross_attention_dim)
self.norm = nn.LayerNorm(cross_attention_dim)
def forward(self, image_embeds):
#embeds = image_embeds
embeds = image_embeds.type(list(self.proj.parameters())[0].dtype)
clip_extra_context_tokens = self.proj(embeds).reshape(
-1, self.clip_extra_context_tokens, self.cross_attention_dim)
clip_extra_context_tokens = self.norm(clip_extra_context_tokens)
return clip_extra_context_tokens
# FFN
def FeedForward(dim, mult=4):
inner_dim = int(dim * mult)
return nn.Sequential(
nn.LayerNorm(dim),
nn.Linear(dim, inner_dim, bias=False),
nn.GELU(),
nn.Linear(inner_dim, dim, bias=False),
)
def reshape_tensor(x, heads):
bs, length, width = x.shape
#(bs, length, width) --> (bs, length, n_heads, dim_per_head)
x = x.view(bs, length, heads, -1)
# (bs, length, n_heads, dim_per_head) --> (bs, n_heads, length, dim_per_head)
x = x.transpose(1, 2)
# (bs, n_heads, length, dim_per_head) --> (bs*n_heads, length, dim_per_head)
x = x.reshape(bs, heads, length, -1)
return x
class PerceiverAttention(nn.Module):
def __init__(self, *, dim, dim_head=64, heads=8):
super().__init__()
self.scale = dim_head**-0.5
self.dim_head = dim_head
self.heads = heads
inner_dim = dim_head * heads
self.norm1 = nn.LayerNorm(dim)
self.norm2 = nn.LayerNorm(dim)
self.to_q = nn.Linear(dim, inner_dim, bias=False)
self.to_kv = nn.Linear(dim, inner_dim * 2, bias=False)
self.to_out = nn.Linear(inner_dim, dim, bias=False)
def forward(self, x, latents):
"""
Args:
x (torch.Tensor): image features
shape (b, n1, D)
latent (torch.Tensor): latent features
shape (b, n2, D)
"""
x = self.norm1(x)
latents = self.norm2(latents)
b, l, _ = latents.shape
q = self.to_q(latents)
kv_input = torch.cat((x, latents), dim=-2)
k, v = self.to_kv(kv_input).chunk(2, dim=-1)
q = reshape_tensor(q, self.heads)
k = reshape_tensor(k, self.heads)
v = reshape_tensor(v, self.heads)
# attention
scale = 1 / math.sqrt(math.sqrt(self.dim_head))
weight = (q * scale) @ (k * scale).transpose(
-2, -1) # More stable with f16 than dividing afterwards
weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)
out = weight @ v
out = out.permute(0, 2, 1, 3).reshape(b, l, -1)
return self.to_out(out)
class Resampler(nn.Module):
def __init__(
self,
dim=1024,
depth=8,
dim_head=64,
heads=16,
num_queries=8,
embedding_dim=768,
output_dim=1024,
ff_mult=4,
video_length=None, # using frame-wise version or not
):
super().__init__()
## queries for a single frame / image
self.num_queries = num_queries
self.video_length = video_length
## <num_queries> queries for each frame
if video_length is not None:
num_queries = num_queries * video_length
self.latents = nn.Parameter(
torch.randn(1, num_queries, dim) / dim**0.5)
self.proj_in = nn.Linear(embedding_dim, dim)
self.proj_out = nn.Linear(dim, output_dim)
self.norm_out = nn.LayerNorm(output_dim)
self.layers = nn.ModuleList([])
for _ in range(depth):
self.layers.append(
nn.ModuleList([
PerceiverAttention(dim=dim, dim_head=dim_head,
heads=heads),
FeedForward(dim=dim, mult=ff_mult),
]))
def forward(self, x):
latents = self.latents.repeat(x.size(0), 1, 1)
x = self.proj_in(x)
for attn, ff in self.layers:
latents = attn(x, latents) + latents
latents = ff(latents) + latents
latents = self.proj_out(latents)
latents = self.norm_out(latents)
return latents

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,848 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
from functools import partial
from abc import abstractmethod
from einops import rearrange
from omegaconf import OmegaConf
from typing import Optional, Sequence, Any, Tuple, Union, List, Dict
from collections.abc import Mapping, Iterable, Callable
from unifolm_wma.utils.diffusion import timestep_embedding
from unifolm_wma.utils.common import checkpoint
from unifolm_wma.utils.basics import (zero_module, conv_nd, linear,
avg_pool_nd, normalization)
from unifolm_wma.modules.attention import SpatialTransformer, TemporalTransformer
from unifolm_wma.utils.utils import instantiate_from_config
class TimestepBlock(nn.Module):
"""
Any module where forward() takes timestep embeddings as a second argument.
"""
@abstractmethod
def forward(self, x, emb):
"""
Apply the module to `x` given `emb` timestep embeddings.
"""
class TimestepEmbedSequential(nn.Sequential, TimestepBlock):
"""
A sequential module that passes timestep embeddings to the children that
support it as an extra input.
"""
def forward(self, x, emb, context=None, batch_size=None):
for layer in self:
if isinstance(layer, TimestepBlock):
x = layer(x, emb, batch_size=batch_size)
elif isinstance(layer, SpatialTransformer):
x = layer(x, context)
elif isinstance(layer, TemporalTransformer):
x = rearrange(x, '(b f) c h w -> b c f h w', b=batch_size)
x = layer(x, context)
x = rearrange(x, 'b c f h w -> (b f) c h w')
else:
x = layer(x)
return x
class Downsample(nn.Module):
"""
A downsampling layer with an optional convolution.
:param channels: channels in the inputs and outputs.
:param use_conv: a bool determining if a convolution is applied.
:param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
downsampling occurs in the inner-two dimensions.
"""
def __init__(self,
channels,
use_conv,
dims=2,
out_channels=None,
padding=1):
super().__init__()
self.channels = channels
self.out_channels = out_channels or channels
self.use_conv = use_conv
self.dims = dims
stride = 2 if dims != 3 else (1, 2, 2)
if use_conv:
self.op = conv_nd(dims,
self.channels,
self.out_channels,
3,
stride=stride,
padding=padding)
else:
assert self.channels == self.out_channels
self.op = avg_pool_nd(dims, kernel_size=stride, stride=stride)
def forward(self, x):
assert x.shape[1] == self.channels
return self.op(x)
class Upsample(nn.Module):
"""
An upsampling layer with an optional convolution.
:param channels: channels in the inputs and outputs.
:param use_conv: a bool determining if a convolution is applied.
:param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
upsampling occurs in the inner-two dimensions.
"""
def __init__(self,
channels,
use_conv,
dims=2,
out_channels=None,
padding=1):
super().__init__()
self.channels = channels
self.out_channels = out_channels or channels
self.use_conv = use_conv
self.dims = dims
if use_conv:
self.conv = conv_nd(dims,
self.channels,
self.out_channels,
3,
padding=padding)
def forward(self, x):
assert x.shape[1] == self.channels
if self.dims == 3:
x = F.interpolate(x, (x.shape[2], x.shape[3] * 2, x.shape[4] * 2),
mode='nearest')
else:
x = F.interpolate(x, scale_factor=2, mode='nearest')
if self.use_conv:
x = self.conv(x)
return x
class ResBlock(TimestepBlock):
"""
A residual block that can optionally change the number of channels.
:param channels: the number of input channels.
:param emb_channels: the number of timestep embedding channels.
:param dropout: the rate of dropout.
:param out_channels: if specified, the number of out channels.
:param use_conv: if True and out_channels is specified, use a spatial
convolution instead of a smaller 1x1 convolution to change the
channels in the skip connection.
:param dims: determines if the signal is 1D, 2D, or 3D.
:param up: if True, use this block for upsampling.
:param down: if True, use this block for downsampling.
:param use_temporal_conv: if True, use the temporal convolution.
:param use_image_dataset: if True, the temporal parameters will not be optimized.
"""
def __init__(self,
channels,
emb_channels,
dropout,
out_channels=None,
use_scale_shift_norm=False,
dims=2,
use_checkpoint=False,
use_conv=False,
up=False,
down=False,
use_temporal_conv=False,
tempspatial_aware=False):
super().__init__()
self.channels = channels
self.emb_channels = emb_channels
self.dropout = dropout
self.out_channels = out_channels or channels
self.use_conv = use_conv
self.use_checkpoint = use_checkpoint
self.use_scale_shift_norm = use_scale_shift_norm
self.use_temporal_conv = use_temporal_conv
self.in_layers = nn.Sequential(
normalization(channels),
nn.SiLU(),
conv_nd(dims, channels, self.out_channels, 3, padding=1),
)
self.updown = up or down
if up:
self.h_upd = Upsample(channels, False, dims)
self.x_upd = Upsample(channels, False, dims)
elif down:
self.h_upd = Downsample(channels, False, dims)
self.x_upd = Downsample(channels, False, dims)
else:
self.h_upd = self.x_upd = nn.Identity()
self.emb_layers = nn.Sequential(
nn.SiLU(),
nn.Linear(
emb_channels,
2 * self.out_channels
if use_scale_shift_norm else self.out_channels,
),
)
self.out_layers = nn.Sequential(
normalization(self.out_channels),
nn.SiLU(),
nn.Dropout(p=dropout),
zero_module(
nn.Conv2d(self.out_channels, self.out_channels, 3, padding=1)),
)
if self.out_channels == channels:
self.skip_connection = nn.Identity()
elif use_conv:
self.skip_connection = conv_nd(dims,
channels,
self.out_channels,
3,
padding=1)
else:
self.skip_connection = conv_nd(dims, channels, self.out_channels,
1)
if self.use_temporal_conv:
self.temopral_conv = TemporalConvBlock(
self.out_channels,
self.out_channels,
dropout=0.1,
spatial_aware=tempspatial_aware)
def forward(self, x, emb, batch_size=None):
"""
Apply the block to a Tensor, conditioned on a timestep embedding.
:param x: an [N x C x ...] Tensor of features.
:param emb: an [N x emb_channels] Tensor of timestep embeddings.
:return: an [N x C x ...] Tensor of outputs.
"""
input_tuple = (x, emb)
if batch_size:
forward_batchsize = partial(self._forward, batch_size=batch_size)
return checkpoint(forward_batchsize, input_tuple,
self.parameters(), self.use_checkpoint)
return checkpoint(self._forward, input_tuple, self.parameters(),
self.use_checkpoint)
def _forward(self, x, emb, batch_size=None):
if self.updown:
in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1]
h = in_rest(x)
h = self.h_upd(h)
x = self.x_upd(x)
h = in_conv(h)
else:
h = self.in_layers(x)
emb_out = self.emb_layers(emb).type(h.dtype)
while len(emb_out.shape) < len(h.shape):
emb_out = emb_out[..., None]
if self.use_scale_shift_norm:
out_norm, out_rest = self.out_layers[0], self.out_layers[1:]
scale, shift = torch.chunk(emb_out, 2, dim=1)
h = out_norm(h) * (1 + scale) + shift
h = out_rest(h)
else:
h = h + emb_out
h = self.out_layers(h)
h = self.skip_connection(x) + h
if self.use_temporal_conv and batch_size:
h = rearrange(h, '(b t) c h w -> b c t h w', b=batch_size)
h = self.temopral_conv(h)
h = rearrange(h, 'b c t h w -> (b t) c h w')
return h
class TemporalConvBlock(nn.Module):
"""
Adapted from modelscope: https://github.com/modelscope/modelscope/blob/master/modelscope/models/multi_modal/video_synthesis/unet_sd.py
"""
def __init__(self,
in_channels,
out_channels=None,
dropout=0.0,
spatial_aware=False):
super(TemporalConvBlock, self).__init__()
if out_channels is None:
out_channels = in_channels
self.in_channels = in_channels
self.out_channels = out_channels
th_kernel_shape = (3, 1, 1) if not spatial_aware else (3, 3, 1)
th_padding_shape = (1, 0, 0) if not spatial_aware else (1, 1, 0)
tw_kernel_shape = (3, 1, 1) if not spatial_aware else (3, 1, 3)
tw_padding_shape = (1, 0, 0) if not spatial_aware else (1, 0, 1)
# conv layers
self.conv1 = nn.Sequential(
nn.GroupNorm(32, in_channels), nn.SiLU(),
nn.Conv3d(in_channels,
out_channels,
th_kernel_shape,
padding=th_padding_shape))
self.conv2 = nn.Sequential(
nn.GroupNorm(32, out_channels), nn.SiLU(), nn.Dropout(dropout),
nn.Conv3d(out_channels,
in_channels,
tw_kernel_shape,
padding=tw_padding_shape))
self.conv3 = nn.Sequential(
nn.GroupNorm(32, out_channels), nn.SiLU(), nn.Dropout(dropout),
nn.Conv3d(out_channels,
in_channels,
th_kernel_shape,
padding=th_padding_shape))
self.conv4 = nn.Sequential(
nn.GroupNorm(32, out_channels), nn.SiLU(), nn.Dropout(dropout),
nn.Conv3d(out_channels,
in_channels,
tw_kernel_shape,
padding=tw_padding_shape))
# Zero out the last layer params,so the conv block is identity
nn.init.zeros_(self.conv4[-1].weight)
nn.init.zeros_(self.conv4[-1].bias)
def forward(self, x):
identity = x
x = self.conv1(x)
x = self.conv2(x)
x = self.conv3(x)
x = self.conv4(x)
return identity + x
class WMAModel(nn.Module):
"""
The full World-Model-Action model.
"""
def __init__(self,
in_channels: int,
model_channels: int,
out_channels: int,
num_res_blocks: int,
attention_resolutions: Sequence[int],
dropout: float = 0.0,
channel_mult: Sequence[int] = (1, 2, 4, 8),
conv_resample: bool = True,
dims: int = 2,
context_dim: int | None = None,
use_scale_shift_norm: bool = False,
resblock_updown: bool = False,
num_heads: int = -1,
num_head_channels: int = -1,
transformer_depth: int = 1,
use_linear: bool = False,
use_checkpoint: bool = False,
temporal_conv: bool = False,
tempspatial_aware: bool = False,
temporal_attention: bool = True,
use_relative_position: bool = True,
use_causal_attention: bool = False,
temporal_length: int | None = None,
use_fp16: bool = False,
addition_attention: bool = False,
temporal_selfatt_only: bool = True,
image_cross_attention: bool = False,
cross_attention_scale_learnable: bool = False,
default_fs: int = 4,
fs_condition: bool = False,
n_obs_steps: int = 1,
num_stem_token: int = 1,
unet_head_config: OmegaConf | None = None,
stem_process_config: OmegaConf | None = None,
base_model_gen_only: bool = False):
"""
Initialize the World-Model-Action network.
Args:
in_channels: Number of input channels to the backbone.
model_channels: Base channel width for the UNet/backbone.
out_channels: Number of output channels.
num_res_blocks: Number of residual blocks per resolution stage.
attention_resolutions: Resolutions at which to enable attention.
dropout: Dropout probability used inside residual/attention blocks.
channel_mult: Multipliers for channels at each resolution level.
conv_resample: If True, use convolutional resampling for up/down sampling.
dims: Spatial dimensionality of the backbone (1/2/3).
context_dim: Optional context embedding dimension (for cross-attention).
use_scale_shift_norm: Enable scale-shift (FiLM-style) normalization in blocks.
resblock_updown: Use residual blocks for up/down sampling (instead of plain conv).
num_heads: Number of attention heads (if >= 0). If -1, derive from num_head_channels.
num_head_channels: Channels per attention head (if >= 0). If -1, derive from num_heads.
transformer_depth: Number of transformer/attention blocks per stage.
use_linear: Use linear attention variants where applicable.
use_checkpoint: Enable gradient checkpointing in blocks to save memory.
temporal_conv: Include temporal convolution along the time dimension.
tempspatial_aware: If True, use timespace aware blocks.
temporal_attention: Enable temporal self-attention.
use_relative_position: Use relative position encodings in attention.
use_causal_attention: Use causal (uni-directional) attention along time.
temporal_length: Optional maximum temporal length expected by the model.
use_fp16: Enable half-precision layers/normalization where supported.
addition_attention: Add auxiliary attention modules.
temporal_selfatt_only: Restrict attention to temporal-only (no spatial) if True.
image_cross_attention: Enable cross-attention with image embeddings.
cross_attention_scale_learnable: Make cross-attention scaling a learnable parameter.
default_fs: Default frame-stride / fps.
fs_condition: If True, condition on frame-stride/fps features.
n_obs_steps: Number of observed steps used in conditioning heads.
num_stem_token: Number of stem tokens for action tokenization.
unet_head_config: OmegaConf for UNet heads (e.g., action/state heads).
stem_process_config: OmegaConf for stem/preprocessor module.
base_model_gen_only: Perform the generation using the base model with out action and state outputs.
"""
super(WMAModel, self).__init__()
if num_heads == -1:
assert num_head_channels != -1, 'Either num_heads or num_head_channels has to be set'
if num_head_channels == -1:
assert num_heads != -1, 'Either num_heads or num_head_channels has to be set'
self.in_channels = in_channels
self.model_channels = model_channels
self.out_channels = out_channels
self.num_res_blocks = num_res_blocks
self.attention_resolutions = attention_resolutions
self.dropout = dropout
self.channel_mult = channel_mult
self.conv_resample = conv_resample
self.temporal_attention = temporal_attention
time_embed_dim = model_channels * 4
self.use_checkpoint = use_checkpoint
self.dtype = torch.float16 if use_fp16 else torch.float32
temporal_self_att_only = True
self.addition_attention = addition_attention
self.temporal_length = temporal_length
self.image_cross_attention = image_cross_attention
self.cross_attention_scale_learnable = cross_attention_scale_learnable
self.default_fs = default_fs
self.fs_condition = fs_condition
self.n_obs_steps = n_obs_steps
self.num_stem_token = num_stem_token
self.base_model_gen_only = base_model_gen_only
# Time embedding blocks
self.time_embed = nn.Sequential(
linear(model_channels, time_embed_dim),
nn.SiLU(),
linear(time_embed_dim, time_embed_dim),
)
if fs_condition:
self.fps_embedding = nn.Sequential(
linear(model_channels, time_embed_dim),
nn.SiLU(),
linear(time_embed_dim, time_embed_dim),
)
nn.init.zeros_(self.fps_embedding[-1].weight)
nn.init.zeros_(self.fps_embedding[-1].bias)
# Input Block
self.input_blocks = nn.ModuleList([
TimestepEmbedSequential(
conv_nd(dims, in_channels, model_channels, 3, padding=1))
])
if self.addition_attention:
self.init_attn = TimestepEmbedSequential(
TemporalTransformer(model_channels,
n_heads=8,
d_head=num_head_channels,
depth=transformer_depth,
context_dim=context_dim,
use_checkpoint=use_checkpoint,
only_self_att=temporal_selfatt_only,
causal_attention=False,
relative_position=use_relative_position,
temporal_length=temporal_length))
input_block_chans = [model_channels]
ch = model_channels
ds = 1
for level, mult in enumerate(channel_mult):
for _ in range(num_res_blocks):
layers = [
ResBlock(ch,
time_embed_dim,
dropout,
out_channels=mult * model_channels,
dims=dims,
use_checkpoint=use_checkpoint,
use_scale_shift_norm=use_scale_shift_norm,
tempspatial_aware=tempspatial_aware,
use_temporal_conv=temporal_conv)
]
ch = mult * model_channels
if ds in attention_resolutions:
if num_head_channels == -1:
dim_head = ch // num_heads
else:
num_heads = ch // num_head_channels
dim_head = num_head_channels
layers.append(
SpatialTransformer(
ch,
num_heads,
dim_head,
depth=transformer_depth,
context_dim=context_dim,
use_linear=use_linear,
use_checkpoint=use_checkpoint,
disable_self_attn=False,
video_length=temporal_length,
agent_state_context_len=self.n_obs_steps,
agent_action_context_len=self.temporal_length *
num_stem_token,
image_cross_attention=self.image_cross_attention,
cross_attention_scale_learnable=self.
cross_attention_scale_learnable,
))
if self.temporal_attention:
layers.append(
TemporalTransformer(
ch,
num_heads,
dim_head,
depth=transformer_depth,
context_dim=context_dim,
use_linear=use_linear,
use_checkpoint=use_checkpoint,
only_self_att=temporal_self_att_only,
causal_attention=use_causal_attention,
relative_position=use_relative_position,
temporal_length=temporal_length))
self.input_blocks.append(TimestepEmbedSequential(*layers))
input_block_chans.append(ch)
if level != len(channel_mult) - 1:
out_ch = ch
self.input_blocks.append(
TimestepEmbedSequential(
ResBlock(ch,
time_embed_dim,
dropout,
out_channels=out_ch,
dims=dims,
use_checkpoint=use_checkpoint,
use_scale_shift_norm=use_scale_shift_norm,
down=True)
if resblock_updown else Downsample(
ch, conv_resample, dims=dims, out_channels=out_ch))
)
ch = out_ch
input_block_chans.append(ch)
ds *= 2
if num_head_channels == -1:
dim_head = ch // num_heads
else:
num_heads = ch // num_head_channels
dim_head = num_head_channels
layers = [
ResBlock(ch,
time_embed_dim,
dropout,
dims=dims,
use_checkpoint=use_checkpoint,
use_scale_shift_norm=use_scale_shift_norm,
tempspatial_aware=tempspatial_aware,
use_temporal_conv=temporal_conv),
SpatialTransformer(
ch,
num_heads,
dim_head,
depth=transformer_depth,
context_dim=context_dim,
use_linear=use_linear,
use_checkpoint=use_checkpoint,
disable_self_attn=False,
video_length=temporal_length,
agent_state_context_len=self.n_obs_steps,
agent_action_context_len=self.temporal_length * num_stem_token,
image_cross_attention=self.image_cross_attention,
cross_attention_scale_learnable=self.
cross_attention_scale_learnable)
]
if self.temporal_attention:
layers.append(
TemporalTransformer(ch,
num_heads,
dim_head,
depth=transformer_depth,
context_dim=context_dim,
use_linear=use_linear,
use_checkpoint=use_checkpoint,
only_self_att=temporal_self_att_only,
causal_attention=use_causal_attention,
relative_position=use_relative_position,
temporal_length=temporal_length))
layers.append(
ResBlock(ch,
time_embed_dim,
dropout,
dims=dims,
use_checkpoint=use_checkpoint,
use_scale_shift_norm=use_scale_shift_norm,
tempspatial_aware=tempspatial_aware,
use_temporal_conv=temporal_conv))
# Middle Block
self.middle_block = TimestepEmbedSequential(*layers)
# Output Block
self.output_blocks = nn.ModuleList([])
for level, mult in list(enumerate(channel_mult))[::-1]:
for i in range(num_res_blocks + 1):
ich = input_block_chans.pop()
layers = [
ResBlock(ch + ich,
time_embed_dim,
dropout,
out_channels=mult * model_channels,
dims=dims,
use_checkpoint=use_checkpoint,
use_scale_shift_norm=use_scale_shift_norm,
tempspatial_aware=tempspatial_aware,
use_temporal_conv=temporal_conv)
]
ch = model_channels * mult
if ds in attention_resolutions:
if num_head_channels == -1:
dim_head = ch // num_heads
else:
num_heads = ch // num_head_channels
dim_head = num_head_channels
layers.append(
SpatialTransformer(
ch,
num_heads,
dim_head,
depth=transformer_depth,
context_dim=context_dim,
use_linear=use_linear,
use_checkpoint=use_checkpoint,
disable_self_attn=False,
video_length=temporal_length,
agent_state_context_len=self.n_obs_steps,
image_cross_attention=self.image_cross_attention,
cross_attention_scale_learnable=self.
cross_attention_scale_learnable))
if self.temporal_attention:
layers.append(
TemporalTransformer(
ch,
num_heads,
dim_head,
depth=transformer_depth,
context_dim=context_dim,
use_linear=use_linear,
use_checkpoint=use_checkpoint,
only_self_att=temporal_self_att_only,
causal_attention=use_causal_attention,
relative_position=use_relative_position,
temporal_length=temporal_length))
if level and i == num_res_blocks:
out_ch = ch
layers.append(
ResBlock(ch,
time_embed_dim,
dropout,
out_channels=out_ch,
dims=dims,
use_checkpoint=use_checkpoint,
use_scale_shift_norm=use_scale_shift_norm,
up=True)
if resblock_updown else Upsample(
ch, conv_resample, dims=dims, out_channels=out_ch))
ds //= 2
self.output_blocks.append(TimestepEmbedSequential(*layers))
self.out = nn.Sequential(
normalization(ch),
nn.SiLU(),
zero_module(
conv_nd(dims, model_channels, out_channels, 3, padding=1)),
)
# Action and state prediction unet
unet_head_config['params']['context_dims'] = [
mult * model_channels for mult in channel_mult
]
self.action_unet = instantiate_from_config(unet_head_config)
self.state_unet = instantiate_from_config(unet_head_config)
# Initialize action token_projector
self.action_token_projector = instantiate_from_config(
stem_process_config)
def forward(self,
x: Tensor,
x_action: Tensor,
x_state: Tensor,
timesteps: Tensor,
context: Tensor | None = None,
context_action: Tensor | None = None,
features_adapter: Any = None,
fs: Tensor | None = None,
**kwargs) -> Tensor | tuple[Tensor, ...]:
"""
Forward pass of the World-Model-Action backbone.
Args:
x: Input tensor (latent video), shape (B, C,...).
x_action: action stream input.
x_state: state stream input.
timesteps: Diffusion timesteps, shape (B,) or scalar Tensor.
context: conditioning context for cross-attention.
context_action: conditioning context specific to action/state (implementation-specific).
features_adapter: module or dict to adapt intermediate features.
fs: frame-stride / fps conditioning.
Returns:
Tuple of Tensors for predictions:
"""
b, _, t, _, _ = x.shape
t_emb = timestep_embedding(timesteps,
self.model_channels,
repeat_only=False).type(x.dtype)
emb = self.time_embed(t_emb)
bt, l_context, _ = context.shape
if self.base_model_gen_only:
assert l_context == 77 + self.n_obs_steps * 16, ">>> ERROR Context dim 1 ..." ## NOTE HANDCODE
else:
if l_context == self.n_obs_steps + 77 + t * 16:
context_agent_state = context[:, :self.n_obs_steps]
context_text = context[:, self.n_obs_steps:self.n_obs_steps +
77, :]
context_img = context[:, self.n_obs_steps + 77:, :]
context_agent_state = context_agent_state.repeat_interleave(
repeats=t, dim=0)
context_text = context_text.repeat_interleave(repeats=t, dim=0)
context_img = rearrange(context_img,
'b (t l) c -> (b t) l c',
t=t)
context = torch.cat(
[context_agent_state, context_text, context_img], dim=1)
elif l_context == self.n_obs_steps + 16 + 77 + t * 16:
context_agent_state = context[:, :self.n_obs_steps]
context_agent_action = context[:, self.
n_obs_steps:self.n_obs_steps +
16, :]
context_agent_action = rearrange(
context_agent_action.unsqueeze(2), 'b t l d -> (b t) l d')
context_agent_action = self.action_token_projector(
context_agent_action)
context_agent_action = rearrange(context_agent_action,
'(b o) l d -> b o l d',
o=t)
context_agent_action = rearrange(context_agent_action,
'b o (t l) d -> b o t l d',
t=t)
context_agent_action = context_agent_action.permute(
0, 2, 1, 3, 4)
context_agent_action = rearrange(context_agent_action,
'b t o l d -> (b t) (o l) d')
context_text = context[:, self.n_obs_steps +
16:self.n_obs_steps + 16 + 77, :]
context_text = context_text.repeat_interleave(repeats=t, dim=0)
context_img = context[:, self.n_obs_steps + 16 + 77:, :]
context_img = rearrange(context_img,
'b (t l) c -> (b t) l c',
t=t)
context_agent_state = context_agent_state.repeat_interleave(
repeats=t, dim=0)
context = torch.cat([
context_agent_state, context_agent_action, context_text,
context_img
],
dim=1)
emb = emb.repeat_interleave(repeats=t, dim=0)
x = rearrange(x, 'b c t h w -> (b t) c h w')
# Combine emb
if self.fs_condition:
if fs is None:
fs = torch.tensor([self.default_fs] * b,
dtype=torch.long,
device=x.device)
fs_emb = timestep_embedding(fs,
self.model_channels,
repeat_only=False).type(x.dtype)
fs_embed = self.fps_embedding(fs_emb)
fs_embed = fs_embed.repeat_interleave(repeats=t, dim=0)
emb = emb + fs_embed
h = x.type(self.dtype)
adapter_idx = 0
hs = []
hs_a = []
for id, module in enumerate(self.input_blocks):
h = module(h, emb, context=context, batch_size=b)
if id == 0 and self.addition_attention:
h = self.init_attn(h, emb, context=context, batch_size=b)
# plug-in adapter features
if ((id + 1) % 3 == 0) and features_adapter is not None:
h = h + features_adapter[adapter_idx]
adapter_idx += 1
if id != 0:
if isinstance(module[0], Downsample):
hs_a.append(
rearrange(hs[-1], '(b t) c h w -> b t c h w', t=t))
hs.append(h)
hs_a.append(rearrange(h, '(b t) c h w -> b t c h w', t=t))
if features_adapter is not None:
assert len(
features_adapter) == adapter_idx, 'Wrong features_adapter'
h = self.middle_block(h, emb, context=context, batch_size=b)
hs_a.append(rearrange(h, '(b t) c h w -> b t c h w', t=t))
hs_out = []
for module in self.output_blocks:
h = torch.cat([h, hs.pop()], dim=1)
h = module(h, emb, context=context, batch_size=b)
if isinstance(module[-1], Upsample):
hs_a.append(
rearrange(hs_out[-1], '(b t) c h w -> b t c h w', t=t))
hs_out.append(h)
h = h.type(x.dtype)
hs_a.append(rearrange(hs_out[-1], '(b t) c h w -> b t c h w', t=t))
y = self.out(h)
y = rearrange(y, '(b t) c h w -> b c t h w', b=b)
if not self.base_model_gen_only:
ba, _, _ = x_action.shape
a_y = self.action_unet(x_action, timesteps[:ba], hs_a,
context_action[:2], **kwargs)
# Predict state
if b > 1:
s_y = self.state_unet(x_state, timesteps[:ba], hs_a,
context_action[:2], **kwargs)
else:
s_y = self.state_unet(x_state, timesteps, hs_a,
context_action[:2], **kwargs)
else:
a_y = torch.zeros_like(x_action)
s_y = torch.zeros_like(x_state)
return y, a_y, s_y

View File

@@ -0,0 +1,244 @@
"""
base_vision.py
Abstract class definition of a Vision Backbone (Visual Featurizer), with full annotations of class methods, utility
functions, and initialization logic.
We also define the generic TimmViTBackbone class here, providing a default interface for loading any TIMM Vision
Transformer model for feature extraction.
"""
import timm
import torch
import torch.nn as nn
import torchvision.transforms.functional as TVF
from abc import ABC, abstractmethod
from dataclasses import dataclass
from functools import partial
from typing import Any, Callable, Dict, Optional, Protocol, Tuple, Union
from PIL.Image import Image
from timm.models.vision_transformer import Block, VisionTransformer
from torch.distributed.fsdp.wrap import _module_wrap_policy, _or_policy, transformer_auto_wrap_policy
from torchvision.transforms import Compose, Resize
# === Utility Functions for Monkey-Patching ===
def unpack_tuple(fn: Callable[[Any], Tuple[Any]]) -> Callable[[Any], Any]:
def wrapper(*args: Any, **kwargs: Any) -> Any:
result = fn(*args, **kwargs)
return result[0] if isinstance(result, tuple) else result
return wrapper
# === Interface for an Image Transform ===
class ImageTransform(Protocol):
def __call__(
self, img: Image,
**kwargs: str) -> Union[torch.Tensor, Dict[str, torch.Tensor]]:
...
# === Custom Torchvision Image Transforms ===
@dataclass
class LetterboxPad:
padding_fill_value: Tuple[int, int, int]
def __call__(self, image: Image) -> Image:
"""Given a PIL.Image, pad to square by adding a symmetric border around the height/width."""
(w, h), max_wh = image.size, max(image.size)
horizontal_pad, vertical_pad = int((max_wh - w) / 2), int(
(max_wh - h) / 2)
padding = (horizontal_pad, vertical_pad, horizontal_pad, vertical_pad)
return TVF.pad(image,
padding,
fill=self.padding_fill_value,
padding_mode="constant")
# === Abstract Base Class for arbitrary Vision Backbones ===
class VisionBackbone(nn.Module, ABC):
def __init__(self,
vision_backbone_id: str,
image_resize_strategy: str,
default_image_size: int = 224) -> None:
super().__init__()
self.identifier: str = vision_backbone_id
self.image_resize_strategy: str = image_resize_strategy
self.default_image_size: int = default_image_size
# Instance attributes for a Vision Backbone
self.featurizer: nn.Module = None
self.image_transform: ImageTransform = None
def get_image_transform(self) -> ImageTransform:
return self.image_transform
@abstractmethod
def get_fsdp_wrapping_policy(self) -> Callable:
...
@abstractmethod
def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
"""Run a forward pass through the featurizer given a set of processed images, returning patch/grid features."""
raise NotImplementedError
@property
@abstractmethod
def default_image_resolution(self) -> Tuple[int, int, int]:
...
@property
@abstractmethod
def embed_dim(self) -> int:
...
@property
@abstractmethod
def num_patches(self) -> int:
...
@property
@abstractmethod
def half_precision_dtype(self) -> torch.dtype:
...
# === Abstract Base Class for Arbitrary TIMM Vision Transformer Backbones ===
class TimmViTBackbone(VisionBackbone, ABC):
def __init__(
self,
vision_backbone_id: str,
timm_path_or_url: str,
image_resize_strategy: str,
default_image_size: int = 224,
override_act_layer: Optional[str] = None,
) -> None:
super().__init__(vision_backbone_id,
image_resize_strategy,
default_image_size=default_image_size)
self.timm_path_or_url = timm_path_or_url
self.override_act_layer = override_act_layer
self.dtype = torch.bfloat16
# Initialize Featurizer (ViT) by downloading from HF / TIMM Hub if necessary
if self.override_act_layer is None:
self.featurizer: VisionTransformer = timm.create_model(
self.timm_path_or_url,
pretrained=True,
num_classes=0,
img_size=self.default_image_size)
else:
self.featurizer: VisionTransformer = timm.create_model(
self.timm_path_or_url,
pretrained=True,
num_classes=0,
img_size=self.default_image_size,
act_layer=self.override_act_layer,
)
self.featurizer.eval()
# Monkey-Patch the `forward()` function of the featurizer to ensure FSDP-compatibility
# => Note: By default set `get_intermediate_layers` to return the *SECOND-TO-LAST* layer patches!
# => TODO (siddk) Remove after resolution of https://github.com/pytorch/pytorch/issues/109385
self.featurizer.forward = unpack_tuple(
partial(self.featurizer.get_intermediate_layers,
n={len(self.featurizer.blocks) - 2}))
# Validation =>> for now, this class *only* supports TIMM Vision Transformers (but can be extended!)
assert isinstance(self.featurizer, VisionTransformer), (
"Featurizer is not a TIMM VisionTransformer; if you would like to support a new visual representation, "
"file an issue or implement the requisite logic (see `prismatic/models/backbones/vision/base_vision.py`)!"
)
# Get Config =>> Note :: Override default image size to ensure correct image transform
self.data_cfg = timm.data.resolve_model_data_config(self.featurizer)
self.data_cfg["input_size"] = (3, self.default_image_size,
self.default_image_size)
# Initialize Default Image Transform --> Modified by `self.image_resize_strategy`
default_image_transform = timm.data.create_transform(**self.data_cfg,
is_training=False)
# Fix =>> SigLIP & IN1K default transforms resize to *larger* than `self.default_image_size` (crops image)!
if "siglip" in self.timm_path_or_url or "in1k" in self.timm_path_or_url:
assert isinstance(default_image_transform,
Compose), "Unexpected `default_image_transform`!"
assert isinstance(default_image_transform.transforms[0], Resize)
default_image_transform = Compose([
Resize(self.default_image_size,
interpolation=default_image_transform.transforms[0].
interpolation),
*default_image_transform.transforms[1:],
])
# Switch on `image_resize_strategy`
if self.image_resize_strategy == "resize-naive":
assert isinstance(default_image_transform,
Compose), "Unexpected `default_image_transform`!"
assert isinstance(default_image_transform.transforms[0], Resize)
target_size = (self.default_image_size, self.default_image_size)
self.image_transform = Compose([
Resize(target_size,
interpolation=default_image_transform.transforms[0].
interpolation),
*default_image_transform.transforms[1:],
])
elif self.image_resize_strategy == "resize-crop":
self.image_transform = default_image_transform
elif self.image_resize_strategy == "letterbox":
assert isinstance(default_image_transform,
Compose), "Unexpected `default_image_transform`!"
assert "mean" in self.data_cfg, "TIMM `data_cfg` missing image normalization mean!"
# Compute Padding Fill Value (rescaled normalization mean if applicable)
fill = tuple([int(x * 255) for x in self.data_cfg["mean"]])
# Build New Transform
self.image_transform = Compose(
[LetterboxPad(fill), *default_image_transform.transforms])
else:
raise ValueError(
f"Image Resize Strategy `{self.image_resize_strategy}` is not supported!"
)
def get_fsdp_wrapping_policy(self) -> Callable:
"""Return a simple FSDP policy that wraps each ViT block and then the _entire_ featurizer."""
vit_wrap_policy = partial(_module_wrap_policy,
module_classes={VisionTransformer})
transformer_block_policy = partial(transformer_auto_wrap_policy,
transformer_layer_cls={Block})
return partial(_or_policy,
policies=[vit_wrap_policy, transformer_block_policy])
def forward(
self, pixel_values: Union[torch.Tensor, Dict[str, torch.Tensor]]
) -> torch.Tensor:
"""Runs transformed image/pixel tensor through vision backbone, returning _all_ patch features."""
return self.featurizer(pixel_values)
@property
def default_image_resolution(self) -> Tuple[int, int, int]:
return self.data_cfg["input_size"]
@property
def embed_dim(self) -> int:
return self.featurizer.embed_dim
@property
def num_patches(self) -> int:
return self.featurizer.patch_embed.num_patches
@property
def half_precision_dtype(self) -> torch.dtype:
return self.dtype

View File

@@ -0,0 +1,273 @@
"""
dinosiglip_vit.py
Vision backbone that returns concatenated features from both DINOv2 and SigLIP.
"""
import timm
import torch
import torchvision.transforms as transforms
from dataclasses import dataclass
from functools import partial
from typing import Callable, Dict, Tuple
from PIL import Image
from timm.models.vision_transformer import Block, VisionTransformer
from torch.distributed.fsdp.wrap import _module_wrap_policy, _or_policy, transformer_auto_wrap_policy
from torchvision.transforms import Compose, Resize, Normalize
from unifolm_wma.modules.vision.base_vision import ImageTransform, LetterboxPad, VisionBackbone, unpack_tuple
from unifolm_wma.utils.nn_utils import FusedMLPProjector, LinearProjector, MLPProjector
# Registry =>> Supported DinoSigLIP Pairs (as TIMM identifiers)
DINOSigLIP_VISION_BACKBONES = {
"dinosiglip-vit-so-224px": {
"dino": "vit_large_patch14_reg4_dinov2.lvd142m",
"siglip": "vit_so400m_patch14_siglip_224",
},
"dinosiglip-vit-so-384px": {
"dino": "vit_large_patch14_reg4_dinov2.lvd142m",
"siglip": "vit_so400m_patch14_siglip_384",
},
}
@dataclass
class DinoSigLIPImageTransform:
dino_image_transform: ImageTransform
siglip_image_transform: ImageTransform
is_prismatic: bool = True
def __call__(self, img: Image, **kwargs: str) -> Dict[str, torch.Tensor]:
return {
"dino": self.dino_image_transform(img, **kwargs),
"siglip": self.siglip_image_transform(img, **kwargs)
}
class DinoSigLIPViTBackbone(VisionBackbone):
def __init__(self,
vision_backbone_id: str,
image_resize_strategy: str,
arch_specifier: str,
output_dim: int,
pretrained_checkpoint=None,
freeze=True,
default_image_size: int = 224) -> None:
super().__init__(vision_backbone_id,
image_resize_strategy,
default_image_size=default_image_size)
self.dino_timm_path_or_url = DINOSigLIP_VISION_BACKBONES[
vision_backbone_id]["dino"]
self.siglip_timm_path_or_url = DINOSigLIP_VISION_BACKBONES[
vision_backbone_id]["siglip"]
# Initialize both Featurizers (ViTs) by downloading from HF / TIMM Hub if necessary
self.dino_featurizer: VisionTransformer = timm.create_model(
self.dino_timm_path_or_url,
pretrained=True,
num_classes=0,
img_size=self.default_image_size)
if pretrained_checkpoint:
ckpt = pretrained_checkpoint + '/openvla_dino.pt'
self.dino_featurizer.load_state_dict(
torch.load(ckpt, weights_only=True))
print('>>> load dino weights')
if freeze:
self.dino_featurizer.eval()
for param in self.dino_featurizer.parameters():
param.requires_grad = False
self.siglip_featurizer: VisionTransformer = timm.create_model(
self.siglip_timm_path_or_url,
pretrained=True,
num_classes=0,
img_size=self.default_image_size)
if pretrained_checkpoint:
ckpt = pretrained_checkpoint + '/openvla_siglip.pt'
self.siglip_featurizer.load_state_dict(
torch.load(ckpt, weights_only=True))
print('>>> load siglip weights')
if freeze:
self.siglip_featurizer.eval()
for param in self.siglip_featurizer.parameters():
param.requires_grad = False
# Monkey-Patch the `forward()` function of the featurizers to ensure FSDP-compatibility
# => Note: By default set `get_intermediate_layers` to return the *SECOND-TO-LAST* layer patches!
# => TODO (siddk) Remove after resolution of https://github.com/pytorch/pytorch/issues/109385
self.dino_featurizer.forward = unpack_tuple(
partial(self.dino_featurizer.get_intermediate_layers,
n={len(self.dino_featurizer.blocks) - 2}))
self.siglip_featurizer.forward = unpack_tuple(
partial(self.siglip_featurizer.get_intermediate_layers,
n={len(self.siglip_featurizer.blocks) - 2}))
# Get Configs for _both_ Featurizers =>> Note :: Override default image size for larger resolution models
self.dino_data_cfg = timm.data.resolve_model_data_config(
self.dino_featurizer)
self.dino_data_cfg["input_size"] = (3, self.default_image_size,
self.default_image_size)
self.siglip_data_cfg = timm.data.resolve_model_data_config(
self.siglip_featurizer)
self.siglip_data_cfg["input_size"] = (3, self.default_image_size,
self.default_image_size)
# Initialize *both* Transforms
self.default_dino_transform = timm.data.create_transform(
**self.dino_data_cfg, is_training=False)
self.default_siglip_transform = timm.data.create_transform(
**self.siglip_data_cfg, is_training=False)
# Fix =>> SigLIP default transform resizes to *larger* than `self.default_image_size` (crops image)!!
assert isinstance(self.default_siglip_transform,
Compose), "Unexpected `default_image_transform`!"
assert isinstance(self.default_siglip_transform.transforms[0], Resize)
self.default_siglip_transform = Compose([
Resize(self.default_image_size,
interpolation=self.default_siglip_transform.transforms[0].
interpolation),
*self.default_siglip_transform.transforms[1:],
])
if self.image_resize_strategy == "resize-naive":
assert isinstance(
self.default_dino_transform,
Compose), "Unexpected `default_dino_image_transform`!"
assert isinstance(
self.default_siglip_transform,
Compose), "Unexpected `default_siglip_image_transform`!"
assert isinstance(self.default_dino_transform.transforms[0],
Resize)
assert isinstance(self.default_siglip_transform.transforms[0],
Resize)
self.target_size = (self.default_image_size,
self.default_image_size)
dino_transform = Compose([
Resize(self.target_size,
interpolation=self.default_dino_transform.transforms[0].
interpolation),
*self.default_dino_transform.transforms[1:],
])
siglip_transform = Compose([
Resize(self.target_size,
interpolation=self.default_siglip_transform.
transforms[0].interpolation),
*self.default_siglip_transform.transforms[1:],
])
self.image_transform = DinoSigLIPImageTransform(
dino_transform, siglip_transform)
elif self.image_resize_strategy == "resize-crop":
self.image_transform = DinoSigLIPImageTransform(
self.default_dino_transform, self.default_siglip_transform)
elif self.image_resize_strategy == "letterbox":
assert isinstance(self.default_dino_transform,
Compose), "Unexpected `default_dino_transform`!"
assert isinstance(
self.default_siglip_transform,
Compose), "Unexpected `default_siglip_transform`!"
assert ("mean" in self.dino_data_cfg
and "mean" in self.siglip_data_cfg
), "DinoSigLIP `data_cfg` missing `mean`!"
# Compute Padding Fill Value(s) (rescaled normalization mean if applicable)
dino_fill = tuple(
[int(x * 255) for x in self.dino_data_cfg["mean"]])
siglip_fill = tuple(
[int(x * 255) for x in self.siglip_data_cfg["mean"]])
# Build New Transform
self.image_transform = DinoSigLIPImageTransform(
Compose([
LetterboxPad(dino_fill),
*self.default_dino_transform.transforms
]),
Compose([
LetterboxPad(siglip_fill),
*self.default_siglip_transform.transforms
]),
)
else:
raise ValueError(
f"Image Resize Strategy `{self.image_resize_strategy}` is not supported!"
)
self.arch_specifier = arch_specifier
if arch_specifier == "linear":
self.projector = LinearProjector(self.embed_dim, output_dim)
elif arch_specifier.endswith("fused-gelu-mlp"):
self.projector = FusedMLPProjector(self.embed_dim, output_dim)
elif arch_specifier.endswith("gelu-mlp"):
self.projector = MLPProjector(self.embed_dim, output_dim)
else:
raise ValueError(
f"PrismaticVLM with `{arch_specifier = }` is not supported!")
self.on_gpu = False
def get_fsdp_wrapping_policy(self) -> Callable:
"""Return a simple FSDP policy that wraps each ViT block and then both of the _entire_ featurizers."""
vit_wrap_policy = partial(_module_wrap_policy,
module_classes={VisionTransformer})
transformer_block_policy = partial(transformer_auto_wrap_policy,
transformer_layer_cls={Block})
return partial(_or_policy,
policies=[vit_wrap_policy, transformer_block_policy])
def forward(self, img) -> torch.Tensor:
img = torch.clamp(img.float(), -1., 1.)
img = (img + 1.0) / 2.0
img = img * 255
resize = transforms.Resize(min(self.target_size),
interpolation=self.default_dino_transform.
transforms[0].interpolation,
max_size=None,
antialias=True)
center_crop = transforms.CenterCrop(self.target_size)
img = center_crop(resize(img))
dino_normalizer = Normalize(mean=torch.tensor([0.4850, 0.4560,
0.4060]),
std=torch.tensor([0.2290, 0.2240, 0.2250]))
siglip_normalizer = Normalize(
mean=torch.tensor([0.5000, 0.5000, 0.5000]),
std=torch.tensor([0.5000, 0.5000, 0.5000]))
pixel_values = {
'dino': dino_normalizer(img),
'siglip': siglip_normalizer(img)
}
if self.on_gpu:
pixel_values = {k: v.cuda() for k, v in pixel_values.items()}
elif next(self.dino_featurizer.parameters()).device.type != 'cpu':
self.on_gpu = True
"""Runs the transformed image/pixel tensors through each vision backbone, returning concatenated patches."""
dino_patches = self.dino_featurizer(pixel_values["dino"])
siglip_patches = self.siglip_featurizer(pixel_values["siglip"])
return self.projector(torch.cat([dino_patches, siglip_patches], dim=2))
@property
def default_image_resolution(self) -> Tuple[int, int, int]:
return self.dino_data_cfg["input_size"]
@property
def embed_dim(self) -> int:
return self.dino_featurizer.embed_dim + self.siglip_featurizer.embed_dim
@property
def num_patches(self) -> int:
assert self.dino_featurizer.patch_embed.num_patches == self.siglip_featurizer.patch_embed.num_patches
return self.dino_featurizer.patch_embed.num_patches
@property
def half_precision_dtype(self) -> torch.dtype:
return torch.bfloat16

View File

@@ -0,0 +1,104 @@
# adopted from
# https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py
# and
# https://github.com/lucidrains/denoising-diffusion-pytorch/blob/7706bdfc6f527f58d33f84b7b522e61e6e3164b3/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py
# and
# https://github.com/openai/guided-diffusion/blob/0ba878e517b276c45d1195eb29f6f5f72659a05b/guided_diffusion/nn.py
#
# thanks!
import torch.nn as nn
from unifolm_wma.utils.utils import instantiate_from_config
def disabled_train(self, mode=True):
"""Overwrite model.train with this function to make sure train/eval mode
does not change anymore."""
return self
def zero_module(module):
"""
Zero out the parameters of a module and return it.
"""
for p in module.parameters():
p.detach().zero_()
return module
def scale_module(module, scale):
"""
Scale the parameters of a module and return it.
"""
for p in module.parameters():
p.detach().mul_(scale)
return module
def conv_nd(dims, *args, **kwargs):
"""
Create a 1D, 2D, or 3D convolution module.
"""
if dims == 1:
return nn.Conv1d(*args, **kwargs)
elif dims == 2:
return nn.Conv2d(*args, **kwargs)
elif dims == 3:
return nn.Conv3d(*args, **kwargs)
raise ValueError(f"unsupported dimensions: {dims}")
def linear(*args, **kwargs):
"""
Create a linear module.
"""
return nn.Linear(*args, **kwargs)
def avg_pool_nd(dims, *args, **kwargs):
"""
Create a 1D, 2D, or 3D average pooling module.
"""
if dims == 1:
return nn.AvgPool1d(*args, **kwargs)
elif dims == 2:
return nn.AvgPool2d(*args, **kwargs)
elif dims == 3:
return nn.AvgPool3d(*args, **kwargs)
raise ValueError(f"unsupported dimensions: {dims}")
def nonlinearity(type='silu'):
if type == 'silu':
return nn.SiLU()
elif type == 'leaky_relu':
return nn.LeakyReLU()
class GroupNormSpecific(nn.GroupNorm):
def forward(self, x):
return super().forward(x.float()).type(x.dtype)
def normalization(channels, num_groups=32):
"""
Make a standard normalization layer.
:param channels: number of input channels.
:return: an nn.Module for normalization.
"""
return GroupNormSpecific(num_groups, channels)
class HybridConditioner(nn.Module):
def __init__(self, c_concat_config, c_crossattn_config):
super().__init__()
self.concat_conditioner = instantiate_from_config(c_concat_config)
self.crossattn_conditioner = instantiate_from_config(
c_crossattn_config)
def forward(self, c_concat, c_crossattn):
c_concat = self.concat_conditioner(c_concat)
c_crossattn = self.crossattn_conditioner(c_crossattn)
return {'c_concat': [c_concat], 'c_crossattn': [c_crossattn]}

View File

@@ -0,0 +1,226 @@
import os
import time
import logging
import json
mainlogger = logging.getLogger('mainlogger')
import torch
import torchvision
import pytorch_lightning as pl
import matplotlib.pyplot as plt
from pytorch_lightning.callbacks import Callback
from pytorch_lightning.utilities import rank_zero_only
from pytorch_lightning.utilities import rank_zero_info
from unifolm_wma.utils.save_video import log_local, prepare_to_log
STAT_DIR = '~/'
class ImageLogger(Callback):
def __init__(self, batch_frequency, max_images=8, clamp=True, rescale=True, save_dir=None, \
to_local=False, log_images_kwargs=None):
super().__init__()
self.rescale = rescale
self.batch_freq = batch_frequency
self.max_images = max_images
self.to_local = to_local
self.clamp = clamp
self.log_images_kwargs = log_images_kwargs if log_images_kwargs else {}
self.save_stat_dir = os.path.join(save_dir, "stat")
os.makedirs(self.save_stat_dir, exist_ok=True)
self.fps_stat = {}
self.fs_stat = {}
if self.to_local:
self.save_dir = os.path.join(save_dir, "images")
os.makedirs(os.path.join(self.save_dir, "train"), exist_ok=True)
os.makedirs(os.path.join(self.save_dir, "val"), exist_ok=True)
self.count_data = 0
def log_to_tensorboard(self,
pl_module,
batch_logs,
filename,
split,
save_fps=8):
""" log images and videos to tensorboard """
global_step = pl_module.global_step
for key in batch_logs:
value = batch_logs[key]
tag = "gs%d-%s/%s||%s||%s||%s" % (
global_step, split, key,
batch_logs['condition'][0].split('_')[0],
batch_logs['condition'][0].split('_')[1],
batch_logs['video_idx'])
if isinstance(value, list) and isinstance(value[0], str):
captions = ' |------| '.join(value)
pl_module.logger.experiment.add_text(tag,
captions,
global_step=global_step)
elif isinstance(value, torch.Tensor) and value.dim() == 5:
video = value
n = video.shape[0]
video = video.permute(2, 0, 1, 3, 4)
frame_grids = [
torchvision.utils.make_grid(framesheet,
nrow=int(n),
padding=0)
for framesheet in video
]
grid = torch.stack(
frame_grids, dim=0)
grid = (grid + 1.0) / 2.0
grid = grid.unsqueeze(dim=0)
pl_module.logger.experiment.add_video(tag,
grid,
fps=save_fps,
global_step=global_step)
elif isinstance(value, torch.Tensor) and value.dim() == 4:
img = value
grid = torchvision.utils.make_grid(img, nrow=int(n), padding=0)
grid = (grid + 1.0) / 2.0
pl_module.logger.experiment.add_image(tag,
grid,
global_step=global_step)
elif isinstance(value, torch.Tensor) and value.dim() == 3:
b, _, _ = value.shape
value1 = value[:b // 2, ...]
value2 = value[b // 2:, ...]
_, num_points, d = value1.shape
for i in range(d):
data1 = value1[0, :, i].cpu().detach().numpy()
data2 = value2[0, :, i].cpu().detach().numpy()
fig, ax = plt.subplots()
ax.plot(data1, label='Target 1')
ax.plot(data2, label='Sample 1')
ax.set_title(f'Comparison at dimension {i} for {key}')
ax.legend()
pl_module.logger.experiment.add_figure(
tag + f"| {key}_dim_{i}", fig, global_step=global_step)
plt.close(fig)
else:
pass
@rank_zero_only
def log_batch_imgs(self, pl_module, batch, batch_idx, split="train"):
""" generate images, then save and log to tensorboard """
# Update fps and fs statistics
batch_fps = batch['fps'].tolist()
batch_fs = batch['frame_stride'].tolist()
for num in batch_fps:
self.fps_stat[num] = self.fps_stat.get(num, 0) + 1
for num in batch_fs:
self.fs_stat[num] = self.fs_stat.get(num, 0) + 1
skip_freq = self.batch_freq if split == "train" else 5
## NOTE HAND CODE
self.count_data += 12.5 * 2
if self.count_data >= skip_freq:
self.count_data = 0
is_train = pl_module.training
if is_train:
pl_module.eval()
torch.cuda.empty_cache()
with torch.no_grad():
log_func = pl_module.log_images
batch_logs = log_func(batch,
split=split,
**self.log_images_kwargs)
# Log fps and fs statistics
with open(self.save_stat_dir + '/fps_fs_stat.json',
'w') as file:
json.dump({
'fps': self.fps_stat,
'fs': self.fs_stat
},
file,
indent=4)
batch_logs = prepare_to_log(batch_logs, self.max_images,
self.clamp)
torch.cuda.empty_cache()
filename = "ep{}_idx{}_rank{}".format(pl_module.current_epoch,
batch_idx,
pl_module.global_rank)
if self.to_local:
mainlogger.info("Log [%s] batch <%s> to local ..." %
(split, filename))
filename = "gs{}_".format(pl_module.global_step) + filename
log_local(batch_logs,
os.path.join(self.save_dir, split),
filename,
save_fps=10)
else:
mainlogger.info("Log [%s] batch <%s> to tensorboard ..." %
(split, filename))
self.log_to_tensorboard(pl_module,
batch_logs,
filename,
split,
save_fps=10)
mainlogger.info('Finish!')
if is_train:
pl_module.train()
def on_train_batch_end(self,
trainer,
pl_module,
outputs,
batch,
batch_idx,
dataloader_idx=None):
if self.batch_freq != -1 and pl_module.logdir:
self.log_batch_imgs(pl_module, batch, batch_idx, split="train")
def on_validation_batch_end(self,
trainer,
pl_module,
outputs,
batch,
batch_idx,
dataloader_idx=None):
#Different with validation_step() that saving the whole validation set and only keep the latest,
#It records the performance of every validation (without overwritten) by only keep a subset
if self.batch_freq != -1 and pl_module.logdir:
self.log_batch_imgs(pl_module, batch, batch_idx, split="val")
if hasattr(pl_module, 'calibrate_grad_norm'):
if (pl_module.calibrate_grad_norm
and batch_idx % 25 == 0) and batch_idx > 0:
self.log_gradients(trainer, pl_module, batch_idx=batch_idx)
class CUDACallback(Callback):
# See https://github.com/SeanNaren/minGPT/blob/master/mingpt/callback.py
def on_train_epoch_start(self, trainer, pl_module):
# Reset the memory use counter
# Lightning update
if int((pl.__version__).split('.')[1]) >= 7:
gpu_index = trainer.strategy.root_device.index
else:
gpu_index = trainer.root_gpu
torch.cuda.reset_peak_memory_stats(gpu_index)
torch.cuda.synchronize(gpu_index)
self.start_time = time.time()
def on_train_epoch_end(self, trainer, pl_module):
if int((pl.__version__).split('.')[1]) >= 7:
gpu_index = trainer.strategy.root_device.index
else:
gpu_index = trainer.root_gpu
torch.cuda.synchronize(gpu_index)
max_memory = torch.cuda.max_memory_allocated(gpu_index) / 2**20
epoch_time = time.time() - self.start_time
try:
max_memory = trainer.training_type_plugin.reduce(max_memory)
epoch_time = trainer.training_type_plugin.reduce(epoch_time)
rank_zero_info(f"Average Epoch time: {epoch_time:.2f} seconds")
rank_zero_info(f"Average Peak memory {max_memory:.2f}MiB")
except AttributeError:
pass

View File

@@ -0,0 +1,111 @@
import math
from inspect import isfunction
import torch
from torch import Tensor, nn
import torch.distributed as dist
def gather_data(data, return_np=True):
''' gather data from multiple processes to one list '''
data_list = [torch.zeros_like(data) for _ in range(dist.get_world_size())]
dist.all_gather(data_list, data) # gather not supported with NCCL
if return_np:
data_list = [data.cpu().numpy() for data in data_list]
return data_list
def autocast(f):
def do_autocast(*args, **kwargs):
with torch.cuda.amp.autocast(
enabled=True,
dtype=torch.get_autocast_gpu_dtype(),
cache_enabled=torch.is_autocast_cache_enabled()):
return f(*args, **kwargs)
return do_autocast
def extract_into_tensor(a, t, x_shape):
b, *_ = t.shape
out = a.gather(-1, t)
return out.reshape(b, *((1, ) * (len(x_shape) - 1)))
def noise_like(shape, device, repeat=False):
repeat_noise = lambda: torch.randn((1, *shape[1:]), device=device).repeat(
shape[0], *((1, ) * (len(shape) - 1)))
noise = lambda: torch.randn(shape, device=device)
return repeat_noise() if repeat else noise()
def default(val, d):
if exists(val):
return val
return d() if isfunction(d) else d
def exists(val):
return val is not None
def identity(*args, **kwargs):
return nn.Identity()
def uniq(arr):
return {el: True for el in arr}.keys()
def mean_flat(tensor):
"""
Take the mean over all non-batch dimensions.
"""
return tensor.mean(dim=list(range(1, len(tensor.shape))))
def ismap(x):
if not isinstance(x, torch.Tensor):
return False
return (len(x.shape) == 4) and (x.shape[1] > 3)
def isimage(x):
if not isinstance(x, torch.Tensor):
return False
return (len(x.shape) == 4) and (x.shape[1] == 3 or x.shape[1] == 1)
def max_neg_value(t):
return -torch.finfo(t.dtype).max
def shape_to_str(x):
shape_str = "x".join([str(x) for x in x.shape])
return shape_str
def init_(tensor):
dim = tensor.shape[-1]
std = 1 / math.sqrt(dim)
tensor.uniform_(-std, std)
return tensor
ckpt = torch.utils.checkpoint.checkpoint
def checkpoint(func, inputs, params, flag):
"""
Evaluate a function without caching intermediate activations, allowing for
reduced memory at the expense of extra compute in the backward pass.
:param func: the function to evaluate.
:param inputs: the argument sequence to pass to `func`.
:param params: a sequence of parameters `func` depends on but does not
explicitly take as arguments.
:param flag: if False, disable gradient checkpointing.
"""
if flag:
return ckpt(func, *inputs, use_reentrant=False)
else:
return func(*inputs)

View File

@@ -0,0 +1,242 @@
import os, sys
import numpy as np
import torch
import pytorch_lightning as pl
from functools import partial
from torch.utils.data import (DataLoader, Dataset, ConcatDataset,
WeightedRandomSampler)
from unifolm_wma.data.base import Txt2ImgIterableBaseDataset
from unifolm_wma.utils.utils import instantiate_from_config
def worker_init_fn(_):
worker_info = torch.utils.data.get_worker_info()
dataset = worker_info.dataset
worker_id = worker_info.id
if isinstance(dataset, Txt2ImgIterableBaseDataset):
split_size = dataset.num_records // worker_info.num_workers
# Reset num_records to the true number to retain reliable length information
dataset.sample_ids = dataset.valid_ids[worker_id *
split_size:(worker_id + 1) *
split_size]
current_id = np.random.choice(len(np.random.get_state()[1]), 1)
return np.random.seed(np.random.get_state()[1][current_id] + worker_id)
else:
return np.random.seed(np.random.get_state()[1][0] + worker_id)
class WrappedDataset(Dataset):
"""Wraps an arbitrary object with __len__ and __getitem__ into a pytorch dataset"""
def __init__(self, dataset):
self.data = dataset
def __len__(self):
return len(self.data)
def __getitem__(self, idx):
return self.data[idx]
class DataModuleFromConfig(pl.LightningDataModule):
def __init__(self,
batch_size,
train=None,
validation=None,
test=None,
predict=None,
wrap=False,
num_workers=None,
shuffle_test_loader=False,
use_worker_init_fn=False,
shuffle_val_dataloader=True,
train_img=None,
dataset_and_weights=None):
super().__init__()
self.batch_size = batch_size
self.dataset_configs = dict()
self.num_workers = num_workers if num_workers is not None else batch_size * 2
self.use_worker_init_fn = use_worker_init_fn
if train is not None:
self.dataset_configs["train"] = train
self.train_dataloader = self._train_dataloader
if validation is not None:
self.dataset_configs["validation"] = validation
self.val_dataloader = partial(self._val_dataloader,
shuffle=shuffle_val_dataloader)
if test is not None:
self.dataset_configs["test"] = test
self.test_dataloader = partial(self._test_dataloader,
shuffle=shuffle_test_loader)
if predict is not None:
self.dataset_configs["predict"] = predict
self.predict_dataloader = self._predict_dataloader
self.img_loader = None
self.wrap = wrap
self.collate_fn = None
self.dataset_weights = dataset_and_weights
assert round(sum(self.dataset_weights.values()),
2) == 1.0, "The sum of dataset weights != 1.0"
def prepare_data(self):
pass
def setup(self, stage=None):
if 'train' in self.dataset_configs:
self.train_datasets = dict()
for dataname in self.dataset_weights:
data_dir = self.dataset_configs['train']['params']['data_dir']
transition_dir = '/'.join([data_dir, 'transitions'])
csv_file = f'{dataname}.csv'
meta_path = '/'.join([data_dir, csv_file])
self.dataset_configs['train']['params'][
'meta_path'] = meta_path
self.dataset_configs['train']['params'][
'transition_dir'] = transition_dir
self.dataset_configs['train']['params'][
'dataset_name'] = dataname
self.train_datasets[dataname] = instantiate_from_config(
self.dataset_configs['train'])
# Setup validation dataset
if 'validation' in self.dataset_configs:
self.val_datasets = dict()
for dataname in self.dataset_weights:
data_dir = self.dataset_configs['validation']['params'][
'data_dir']
transition_dir = '/'.join([data_dir, 'transitions'])
csv_file = f'{dataname}.csv'
meta_path = '/'.join([data_dir, csv_file])
self.dataset_configs['validation']['params'][
'meta_path'] = meta_path
self.dataset_configs['validation']['params'][
'transition_dir'] = transition_dir
self.dataset_configs['validation']['params'][
'dataset_name'] = dataname
self.val_datasets[dataname] = instantiate_from_config(
self.dataset_configs['validation'])
# Setup test dataset
if 'test' in self.dataset_configs:
self.test_datasets = dict()
for dataname in self.dataset_weights:
data_dir = self.dataset_configs['test']['params']['data_dir']
transition_dir = '/'.join([data_dir, 'transitions'])
csv_file = f'{dataname}.csv'
meta_path = '/'.join([data_dir, csv_file])
self.dataset_configs['test']['params']['meta_path'] = meta_path
self.dataset_configs['test']['params'][
'transition_dir'] = transition_dir
self.dataset_configs['test']['params'][
'dataset_name'] = dataname
self.test_datasets[dataname] = instantiate_from_config(
self.dataset_configs['test'])
if self.wrap:
for k in self.datasets:
self.datasets[k] = WrappedDataset(self.datasets[k])
def _train_dataloader(self):
is_iterable_dataset = False # NOTE Hand Code
if is_iterable_dataset or self.use_worker_init_fn:
init_fn = worker_init_fn
else:
init_fn = None
combined_dataset = []
sample_weights = []
for dataname, dataset in self.train_datasets.items():
combined_dataset.append(dataset)
sample_weights.append(
torch.full((len(dataset), ),
self.dataset_weights[dataname] / len(dataset)))
combined_dataset = ConcatDataset(combined_dataset)
sample_weights = torch.cat(sample_weights)
sampler = WeightedRandomSampler(sample_weights,
num_samples=len(combined_dataset),
replacement=True)
loader = DataLoader(combined_dataset,
sampler=sampler,
batch_size=self.batch_size,
num_workers=self.num_workers,
worker_init_fn=init_fn,
collate_fn=self.collate_fn,
drop_last=True
)
return loader
def _val_dataloader(self, shuffle=False):
is_iterable_dataset = False # NOTE Hand Code
if is_iterable_dataset or self.use_worker_init_fn:
init_fn = worker_init_fn
else:
init_fn = None
combined_dataset = []
sample_weights = []
for dataname, dataset in self.val_datasets.items():
combined_dataset.append(dataset)
sample_weights.append(
torch.full((len(dataset), ),
self.dataset_weights[dataname] / len(dataset)))
combined_dataset = ConcatDataset(combined_dataset)
sample_weights = torch.cat(sample_weights)
sampler = WeightedRandomSampler(sample_weights,
num_samples=len(combined_dataset),
replacement=True)
loader = DataLoader(combined_dataset,
sampler=sampler,
batch_size=self.batch_size,
num_workers=self.num_workers,
worker_init_fn=init_fn,
collate_fn=self.collate_fn)
return loader
def _test_dataloader(self, shuffle=False):
is_iterable_dataset = False # NOTE Hand Code
if is_iterable_dataset or self.use_worker_init_fn:
init_fn = worker_init_fn
else:
init_fn = None
combined_dataset = []
sample_weights = []
for dataname, dataset in self.test_datasets.items():
combined_dataset.append(dataset)
sample_weights.append(
torch.full((len(dataset), ),
self.dataset_weights[dataname] / len(dataset)))
combined_dataset = ConcatDataset(combined_dataset)
sample_weights = torch.cat(sample_weights)
sampler = WeightedRandomSampler(sample_weights,
num_samples=len(combined_dataset),
replacement=True)
loader = DataLoader(combined_dataset,
sampler=sampler,
batch_size=self.batch_size,
num_workers=self.num_workers,
worker_init_fn=init_fn,
collate_fn=self.collate_fn)
return loader
def _predict_dataloader(self, shuffle=False):
if isinstance(self.datasets['predict'],
Txt2ImgIterableBaseDataset) or self.use_worker_init_fn:
init_fn = worker_init_fn
else:
init_fn = None
return DataLoader(
self.datasets["predict"],
batch_size=self.batch_size,
num_workers=self.num_workers,
worker_init_fn=init_fn,
collate_fn=self.collate_fn,
)
def __len__(self):
count = 0
for _, values in self.train_datasets.items():
count += len(values)
return count

View File

@@ -0,0 +1,191 @@
import math
import numpy as np
import torch
import torch.nn.functional as F
from einops import repeat
def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False):
"""
Create sinusoidal timestep embeddings.
:param timesteps: a 1-D Tensor of N indices, one per batch element.
These may be fractional.
:param dim: the dimension of the output.
:param max_period: controls the minimum frequency of the embeddings.
:return: an [N x dim] Tensor of positional embeddings.
"""
if not repeat_only:
half = dim // 2
freqs = torch.exp(
-math.log(max_period) *
torch.arange(start=0, end=half, dtype=torch.float32) /
half).to(device=timesteps.device)
args = timesteps[:, None].float() * freqs[None]
embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
if dim % 2:
embedding = torch.cat(
[embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
else:
embedding = repeat(timesteps, 'b -> b d', d=dim)
return embedding
def make_beta_schedule(schedule,
n_timestep,
linear_start=1e-4,
linear_end=2e-2,
cosine_s=8e-3):
if schedule == "linear":
betas = (torch.linspace(linear_start**0.5,
linear_end**0.5,
n_timestep,
dtype=torch.float64)**2)
elif schedule == "cosine":
timesteps = (
torch.arange(n_timestep + 1, dtype=torch.float64) / n_timestep +
cosine_s)
alphas = timesteps / (1 + cosine_s) * np.pi / 2
alphas = torch.cos(alphas).pow(2)
alphas = alphas / alphas[0]
betas = 1 - alphas[1:] / alphas[:-1]
betas = np.clip(betas, a_min=0, a_max=0.999)
elif schedule == "sqrt_linear":
betas = torch.linspace(linear_start,
linear_end,
n_timestep,
dtype=torch.float64)
elif schedule == "sqrt":
betas = torch.linspace(linear_start,
linear_end,
n_timestep,
dtype=torch.float64)**0.5
else:
raise ValueError(f"schedule '{schedule}' unknown.")
return betas.numpy()
def make_ddim_timesteps(ddim_discr_method,
num_ddim_timesteps,
num_ddpm_timesteps,
verbose=True):
if ddim_discr_method == 'uniform':
c = num_ddpm_timesteps // num_ddim_timesteps
ddim_timesteps = np.asarray(list(range(0, num_ddpm_timesteps, c)))
steps_out = ddim_timesteps + 1
elif ddim_discr_method == 'uniform_trailing':
c = num_ddpm_timesteps / num_ddim_timesteps
ddim_timesteps = np.flip(np.round(np.arange(num_ddpm_timesteps, 0,
-c))).astype(np.int64)
steps_out = ddim_timesteps - 1
elif ddim_discr_method == 'quad':
ddim_timesteps = ((np.linspace(0, np.sqrt(num_ddpm_timesteps * .8),
num_ddim_timesteps))**2).astype(int)
steps_out = ddim_timesteps + 1
else:
raise NotImplementedError(
f'There is no ddim discretization method called "{ddim_discr_method}"'
)
# assert ddim_timesteps.shape[0] == num_ddim_timesteps
# add one to get the final alpha values right (the ones from first scale to data during sampling)
# steps_out = ddim_timesteps + 1
if verbose:
print(f'Selected timesteps for ddim sampler: {steps_out}')
return steps_out
def make_ddim_sampling_parameters(alphacums,
ddim_timesteps,
eta,
verbose=True):
# select alphas for computing the variance schedule
# print(f'ddim_timesteps={ddim_timesteps}, len_alphacums={len(alphacums)}')
alphas = alphacums[ddim_timesteps]
alphas_prev = np.asarray([alphacums[0]] +
alphacums[ddim_timesteps[:-1]].tolist())
# according the formula provided in https://arxiv.org/abs/2010.02502
sigmas = eta * np.sqrt(
(1 - alphas_prev) / (1 - alphas) * (1 - alphas / alphas_prev))
if verbose:
print(
f'Selected alphas for ddim sampler: a_t: {alphas}; a_(t-1): {alphas_prev}'
)
print(
f'For the chosen value of eta, which is {eta}, '
f'this results in the following sigma_t schedule for ddim sampler {sigmas}'
)
return sigmas, alphas, alphas_prev
def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999):
"""
Create a beta schedule that discretizes the given alpha_t_bar function,
which defines the cumulative product of (1-beta) over time from t = [0,1].
:param num_diffusion_timesteps: the number of betas to produce.
:param alpha_bar: a lambda that takes an argument t from 0 to 1 and
produces the cumulative product of (1-beta) up to that
part of the diffusion process.
:param max_beta: the maximum beta to use; use values lower than 1 to
prevent singularities.
"""
betas = []
for i in range(num_diffusion_timesteps):
t1 = i / num_diffusion_timesteps
t2 = (i + 1) / num_diffusion_timesteps
betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
return np.array(betas)
def rescale_zero_terminal_snr(betas):
"""
Rescales betas to have zero terminal SNR Based on https://arxiv.org/pdf/2305.08891.pdf (Algorithm 1)
Args:
betas (`numpy.ndarray`):
the betas that the scheduler is being initialized with.
Returns:
`numpy.ndarray`: rescaled betas with zero terminal SNR
"""
# Convert betas to alphas_bar_sqrt
alphas = 1.0 - betas
alphas_cumprod = np.cumprod(alphas, axis=0)
alphas_bar_sqrt = np.sqrt(alphas_cumprod)
# Store old values.
alphas_bar_sqrt_0 = alphas_bar_sqrt[0].copy()
alphas_bar_sqrt_T = alphas_bar_sqrt[-1].copy()
# Shift so the last timestep is zero.
alphas_bar_sqrt -= alphas_bar_sqrt_T
# Scale so the first timestep is back to the old value.
alphas_bar_sqrt *= alphas_bar_sqrt_0 / (alphas_bar_sqrt_0 -
alphas_bar_sqrt_T)
# Convert alphas_bar_sqrt to betas
alphas_bar = alphas_bar_sqrt**2 # Revert sqrt
alphas = alphas_bar[1:] / alphas_bar[:-1] # Revert cumprod
alphas = np.concatenate([alphas_bar[0:1], alphas])
betas = 1 - alphas
return betas
def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
"""
Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and
Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4
"""
std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)),
keepdim=True)
std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)
# Rescale the results from guidance (fixes overexposure)
noise_pred_rescaled = noise_cfg * (std_text / std_cfg)
# Mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images
noise_cfg = guidance_rescale * noise_pred_rescaled + (
1 - guidance_rescale) * noise_cfg
return noise_cfg

View File

@@ -0,0 +1,94 @@
import torch
import numpy as np
class AbstractDistribution:
def sample(self):
raise NotImplementedError()
def mode(self):
raise NotImplementedError()
class DiracDistribution(AbstractDistribution):
def __init__(self, value):
self.value = value
def sample(self):
return self.value
def mode(self):
return self.value
class DiagonalGaussianDistribution(object):
def __init__(self, parameters, deterministic=False):
self.parameters = parameters
self.mean, self.logvar = torch.chunk(parameters, 2, dim=1)
self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
self.deterministic = deterministic
self.std = torch.exp(0.5 * self.logvar)
self.var = torch.exp(self.logvar)
if self.deterministic:
self.var = self.std = torch.zeros_like(
self.mean).to(device=self.parameters.device)
def sample(self, noise=None):
if noise is None:
noise = torch.randn(self.mean.shape)
x = self.mean + self.std * noise.to(device=self.parameters.device)
return x
def kl(self, other=None):
if self.deterministic:
return torch.Tensor([0.])
else:
if other is None:
return 0.5 * torch.sum(
torch.pow(self.mean, 2) + self.var - 1.0 - self.logvar,
dim=[1, 2, 3])
else:
return 0.5 * torch.sum(
torch.pow(self.mean - other.mean, 2) / other.var +
self.var / other.var - 1.0 - self.logvar + other.logvar,
dim=[1, 2, 3])
def nll(self, sample, dims=[1, 2, 3]):
if self.deterministic:
return torch.Tensor([0.])
logtwopi = np.log(2.0 * np.pi)
return 0.5 * torch.sum(logtwopi + self.logvar +
torch.pow(sample - self.mean, 2) / self.var,
dim=dims)
def mode(self):
return self.mean
def normal_kl(mean1, logvar1, mean2, logvar2):
"""
source: https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/losses.py#L12
Compute the KL divergence between two gaussians.
Shapes are automatically broadcasted, so batches can be compared to
scalars, among other use cases.
"""
tensor = None
for obj in (mean1, logvar1, mean2, logvar2):
if isinstance(obj, torch.Tensor):
tensor = obj
break
assert tensor is not None, "at least one argument must be a Tensor"
# Force variances to be Tensors. Broadcasting helps convert scalars to
# Tensors, but it does not work for torch.exp().
logvar1, logvar2 = [
x if isinstance(x, torch.Tensor) else torch.tensor(x).to(tensor)
for x in (logvar1, logvar2)
]
return 0.5 * (-1.0 + logvar2 - logvar1 + torch.exp(logvar1 - logvar2) +
((mean1 - mean2)**2) * torch.exp(-logvar2))

View File

@@ -0,0 +1,84 @@
import torch
from torch import nn
class LitEma(nn.Module):
def __init__(self, model, decay=0.9999, use_num_upates=True):
super().__init__()
if decay < 0.0 or decay > 1.0:
raise ValueError('Decay must be between 0 and 1')
self.m_name2s_name = {}
self.register_buffer('decay', torch.tensor(decay, dtype=torch.float32))
self.register_buffer(
'num_updates',
torch.tensor(0, dtype=torch.int)
if use_num_upates else torch.tensor(-1, dtype=torch.int))
for name, p in model.named_parameters():
if p.requires_grad:
#Remove as '.'-character is not allowed in buffers
s_name = name.replace('.', '')
self.m_name2s_name.update({name: s_name})
self.register_buffer(s_name, p.clone().detach().data)
self.collected_params = []
def forward(self, model):
decay = self.decay
if self.num_updates >= 0:
self.num_updates += 1
decay = min(self.decay,
(1 + self.num_updates) / (10 + self.num_updates))
one_minus_decay = 1.0 - decay
with torch.no_grad():
m_param = dict(model.named_parameters())
shadow_params = dict(self.named_buffers())
for key in m_param:
if m_param[key].requires_grad:
sname = self.m_name2s_name[key]
shadow_params[sname] = shadow_params[sname].type_as(
m_param[key])
shadow_params[sname].sub_(
one_minus_decay *
(shadow_params[sname] - m_param[key]))
else:
assert not key in self.m_name2s_name
def copy_to(self, model):
m_param = dict(model.named_parameters())
shadow_params = dict(self.named_buffers())
for key in m_param:
if m_param[key].requires_grad:
m_param[key].data.copy_(
shadow_params[self.m_name2s_name[key]].data)
else:
assert not key in self.m_name2s_name
def store(self, parameters):
"""
Save the current parameters for restoring later.
Args:
parameters: Iterable of `torch.nn.Parameter`; the parameters to be
temporarily stored.
"""
self.collected_params = [param.clone() for param in parameters]
def restore(self, parameters):
"""
Restore the parameters stored with the `store` method.
Useful to validate the model with EMA parameters without affecting the
original optimization process. Store the parameters before the
`copy_to` method. After validation (or model saving), use this to
restore the former parameters.
Args:
parameters: Iterable of `torch.nn.Parameter`; the parameters to be
updated with the stored parameters.
"""
for c_param, param in zip(self.collected_params, parameters):
param.data.copy_(c_param.data)

View File

@@ -0,0 +1,66 @@
"""
nn_utils.py
Utility functions and PyTorch submodule definitions.
"""
import torch
import torch.nn as nn
# === Definitions for Various Projection Modules, with Signature :: [..., in_dim] --> [..., out_dim] ===
class LinearProjector(nn.Module):
def __init__(self, vision_dim: int, llm_dim: int) -> None:
super().__init__()
self.projector = nn.Linear(vision_dim, llm_dim, bias=True)
def forward(self, img_patches: torch.Tensor) -> torch.Tensor:
return self.projector(img_patches)
class MLPProjector(nn.Module):
def __init__(self,
vision_dim: int,
llm_dim: int,
mlp_type: str = "gelu-mlp") -> None:
super().__init__()
if mlp_type == "gelu-mlp":
self.projector = nn.Sequential(
nn.Linear(vision_dim, llm_dim, bias=True),
nn.GELU(),
nn.Linear(llm_dim, llm_dim, bias=True),
)
else:
raise ValueError(
f"Projector with `{mlp_type = }` is not supported!")
def forward(self, img_patches: torch.Tensor) -> torch.Tensor:
return self.projector(img_patches)
class FusedMLPProjector(nn.Module):
def __init__(self,
fused_vision_dim: int,
llm_dim: int,
mlp_type: str = "fused-gelu-mlp") -> None:
super().__init__()
self.initial_projection_dim = fused_vision_dim * 4
if mlp_type == "fused-gelu-mlp":
self.projector = nn.Sequential(
nn.Linear(fused_vision_dim,
self.initial_projection_dim,
bias=True),
nn.GELU(),
nn.Linear(self.initial_projection_dim, llm_dim, bias=True),
nn.GELU(),
nn.Linear(llm_dim, llm_dim, bias=True),
)
else:
raise ValueError(
f"Fused Projector with `{mlp_type = }` is not supported!")
def forward(self, fused_img_patches: torch.Tensor) -> torch.Tensor:
return self.projector(fused_img_patches)

View File

@@ -0,0 +1,147 @@
import torch
import torch.nn as nn
class LinearProjector(nn.Module):
def __init__(self, input_dim: int, output_dim: int) -> None:
super().__init__()
self.projector = nn.Linear(input_dim, output_dim, bias=True)
def forward(self, x: torch.Tensor) -> torch.Tensor:
return self.projector(x)
class MLPProjector(nn.Module):
def __init__(
self,
input_dim: int,
output_dim: int,
mlp_type: str = "gelu-mlp") -> None:
super().__init__()
if mlp_type == "gelu-mlp":
self.projector = nn.Sequential(
nn.Linear(vision_dim, llm_dim, bias=True),
nn.GELU(approximate='tanh'),
nn.Linear(llm_dim, llm_dim, bias=True),
)
elif mlp_type == "silu-mlp":
self.projector = nn.Sequential(
nn.Linear(vision_dim, llm_dim, bias=True),
nn.SiLU(),
nn.Linear(llm_dim, llm_dim, bias=True),
)
else:
raise ValueError(f"Projector with `{mlp_type = }` is not supported!")
def forward(self, x: torch.Tensor) -> torch.Tensor:
return self.projector(x)
class PerceiverAttention(nn.Module):
def __init__(self, *, dim, dim_head=64, heads=8):
super().__init__()
self.scale = dim_head**-0.5
self.dim_head = dim_head
self.heads = heads
inner_dim = dim_head * heads
self.norm1 = nn.LayerNorm(dim)
self.norm2 = nn.LayerNorm(dim)
self.to_q = nn.Linear(dim, inner_dim, bias=False)
self.to_kv = nn.Linear(dim, inner_dim * 2, bias=False)
self.to_out = nn.Linear(inner_dim, dim, bias=False)
def forward(self, x, latents):
"""
Args:
x (torch.Tensor): image features
shape (b, n1, D)
latent (torch.Tensor): latent features
shape (b, n2, D)
"""
x = self.norm1(x)
latents = self.norm2(latents)
b, l, _ = latents.shape
q = self.to_q(latents)
kv_input = torch.cat((x, latents), dim=-2)
k, v = self.to_kv(kv_input).chunk(2, dim=-1)
q = reshape_tensor(q, self.heads)
k = reshape_tensor(k, self.heads)
v = reshape_tensor(v, self.heads)
scale = 1 / math.sqrt(math.sqrt(self.dim_head))
weight = (q * scale) @ (k * scale).transpose(-2, -1) # More stable with f16 than dividing afterwards
weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)
out = weight @ v
out = out.permute(0, 2, 1, 3).reshape(b, l, -1)
return self.to_out(out)
def FeedForward(dim, mult=4, ffd_type="gelu-ffd"):
inner_dim = int(dim * mult)
if ffd_type = "gelu-ffd":
return nn.Sequential(
nn.LayerNorm(dim),
nn.Linear(dim, inner_dim, bias=False),
nn.GELU(approximate='tanh'),
nn.Linear(inner_dim, dim, bias=False),
)
elif ffd_type = "silu-ffd":
return nn.Sequential(
nn.LayerNorm(dim),
nn.Linear(dim, inner_dim, bias=False),
nn.SiLU(),
nn.Linear(inner_dim, dim, bias=False),
)
else:
raise ValueError(f"Projector with `{mlp_type = }` is not supported!")
class TokenProjector(nn.Module):
def __init__(
self,
dim=1024,
depth=1,
dim_head=64,
heads=16,
num_queries=16,
output_dim=1024,
ff_mult=4,
chunck_size=None,
):
super().__init__()
self.num_queries = num_queries
self.chunck_size = chunck_size
if chunck_size is not None:
num_queries = num_queries * chunck_size
self.latents = nn.Parameter(torch.randn(1, num_queries, dim) / dim**0.5)
self.proj_out = nn.Linear(dim, output_dim)
self.norm_out = nn.LayerNorm(dim)
self.layers = nn.ModuleList([])
for _ in range(depth):
self.layers.append(
nn.ModuleList(
[
PerceiverAttention(dim=dim, dim_head=dim_head, heads=heads),
FeedForward(dim=dim, mult=ff_mult),
]
)
)
def forward(self, x)
latents = self.latents.repeat(x.size(0), 1, 1)
for attn, ff in self.layers:
latents = attn(x, latents) + latents
latents = ff(latents) + latents
latents = self.proj_out(latents)
latents = self.norm_out(latents)

View File

@@ -0,0 +1,258 @@
import os
import torch
import numpy as np
import torchvision
from tqdm import tqdm
from PIL import Image
from einops import rearrange
from torch import Tensor
from torchvision.utils import make_grid
from torchvision.transforms.functional import to_tensor
def frames_to_mp4(frame_dir, output_path, fps):
def read_first_n_frames(d: os.PathLike, num_frames: int):
if num_frames:
images = [
Image.open(os.path.join(d, f))
for f in sorted(os.listdir(d))[:num_frames]
]
else:
images = [
Image.open(os.path.join(d, f)) for f in sorted(os.listdir(d))
]
images = [to_tensor(x) for x in images]
return torch.stack(images)
videos = read_first_n_frames(frame_dir, num_frames=None)
videos = videos.mul(255).to(torch.uint8).permute(0, 2, 3, 1)
torchvision.io.write_video(output_path,
videos,
fps=fps,
video_codec='h264',
options={'crf': '10'})
def tensor_to_mp4(video, savepath, fps, rescale=True, nrow=None):
"""
video: torch.Tensor, b,c,t,h,w, 0-1
if -1~1, enable rescale=True
"""
n = video.shape[0]
video = video.permute(2, 0, 1, 3, 4)
nrow = int(np.sqrt(n)) if nrow is None else nrow
frame_grids = [
torchvision.utils.make_grid(framesheet, nrow=nrow, padding=0)
for framesheet in video
]
grid = torch.stack(frame_grids,
dim=0)
grid = torch.clamp(grid.float(), -1., 1.)
if rescale:
grid = (grid + 1.0) / 2.0
grid = (grid * 255).to(torch.uint8).permute(
0, 2, 3, 1)
torchvision.io.write_video(savepath,
grid,
fps=fps,
video_codec='h264',
options={'crf': '10'})
def tensor2videogrids(video, root, filename, fps, rescale=True, clamp=True):
assert (video.dim() == 5)
assert (isinstance(video, torch.Tensor))
video = video.detach().cpu()
if clamp:
video = torch.clamp(video, -1., 1.)
n = video.shape[0]
video = video.permute(2, 0, 1, 3, 4)
frame_grids = [
torchvision.utils.make_grid(framesheet, nrow=int(np.sqrt(n)))
for framesheet in video
]
grid = torch.stack(frame_grids,
dim=0)
if rescale:
grid = (grid + 1.0) / 2.0
grid = (grid * 255).to(torch.uint8).permute(
0, 2, 3, 1)
path = os.path.join(root, filename)
torchvision.io.write_video(path,
grid,
fps=fps,
video_codec='h264',
options={'crf': '10'})
def log_local(batch_logs, save_dir, filename, save_fps=10, rescale=True):
if batch_logs is None:
return None
""" save images and videos from images dict """
def save_img_grid(grid, path, rescale):
if rescale:
grid = (grid + 1.0) / 2.0
grid = grid.transpose(0, 1).transpose(1, 2).squeeze(-1)
grid = grid.numpy()
grid = (grid * 255).astype(np.uint8)
os.makedirs(os.path.split(path)[0], exist_ok=True)
Image.fromarray(grid).save(path)
for key in batch_logs:
value = batch_logs[key]
if isinstance(value, list) and isinstance(value[0], str):
# A batch of captions
path = os.path.join(save_dir, "%s-%s.txt" % (key, filename))
with open(path, 'w') as f:
for i, txt in enumerate(value):
f.write(f'idx={i}, txt={txt}\n')
f.close()
elif isinstance(value, torch.Tensor) and value.dim() == 5:
# Save video grids
video = value
# Only save grayscale or rgb mode
if video.shape[1] != 1 and video.shape[1] != 3:
continue
n = video.shape[0]
video = video.permute(2, 0, 1, 3, 4)
frame_grids = [
torchvision.utils.make_grid(framesheet, nrow=int(1), padding=0)
for framesheet in video
]
grid = torch.stack(frame_grids,
dim=0)
if rescale:
grid = (grid + 1.0) / 2.0
grid = (grid * 255).to(torch.uint8).permute(0, 2, 3, 1)
path = os.path.join(save_dir, "%s-%s.mp4" % (key, filename))
torchvision.io.write_video(path,
grid,
fps=save_fps,
video_codec='h264',
options={'crf': '10'})
# Save frame sheet
img = value
video_frames = rearrange(img, 'b c t h w -> (b t) c h w')
t = img.shape[2]
grid = torchvision.utils.make_grid(video_frames, nrow=t, padding=0)
path = os.path.join(save_dir, "%s-%s.jpg" % (key, filename))
# Save_img_grid(grid, path, rescale)
elif isinstance(value, torch.Tensor) and value.dim() == 4:
# Save image grids
img = value
# Only save grayscale or rgb mode
if img.shape[1] != 1 and img.shape[1] != 3:
continue
n = img.shape[0]
grid = torchvision.utils.make_grid(img, nrow=1, padding=0)
path = os.path.join(save_dir, "%s-%s.jpg" % (key, filename))
save_img_grid(grid, path, rescale)
else:
pass
def prepare_to_log(batch_logs, max_images=100000, clamp=True):
if batch_logs is None:
return None
for key in batch_logs:
N = batch_logs[key].shape[0] if hasattr(
batch_logs[key], 'shape') else len(batch_logs[key])
N = min(N, max_images)
batch_logs[key] = batch_logs[key][:N]
# In batch_logs: images <batched tensor> & instruction <text list>
if isinstance(batch_logs[key], torch.Tensor):
batch_logs[key] = batch_logs[key].detach().cpu()
if clamp:
try:
batch_logs[key] = torch.clamp(batch_logs[key].float(), -1.,
1.)
except RuntimeError:
print("clamp_scalar_cpu not implemented for Half")
return batch_logs
# ----------------------------------------------------------------------------------------------
def fill_with_black_squares(video, desired_len: int) -> Tensor:
if len(video) >= desired_len:
return video
return torch.cat([
video,
torch.zeros_like(video[0]).unsqueeze(0).repeat(
desired_len - len(video), 1, 1, 1),
],
dim=0)
# ----------------------------------------------------------------------------------------------
def load_num_videos(data_path, num_videos):
# First argument can be either data_path of np array
if isinstance(data_path, str):
videos = np.load(data_path)['arr_0'] # NTHWC
elif isinstance(data_path, np.ndarray):
videos = data_path
else:
raise Exception
if num_videos is not None:
videos = videos[:num_videos, :, :, :, :]
return videos
def npz_to_video_grid(data_path,
out_path,
num_frames,
fps,
num_videos=None,
nrow=None,
verbose=True):
if isinstance(data_path, str):
videos = load_num_videos(data_path, num_videos)
elif isinstance(data_path, np.ndarray):
videos = data_path
else:
raise Exception
n, t, h, w, c = videos.shape
videos_th = []
for i in range(n):
video = videos[i, :, :, :, :]
images = [video[j, :, :, :] for j in range(t)]
images = [to_tensor(img) for img in images]
video = torch.stack(images)
videos_th.append(video)
if verbose:
videos = [
fill_with_black_squares(v, num_frames)
for v in tqdm(videos_th, desc='Adding empty frames')
]
else:
videos = [fill_with_black_squares(v, num_frames)
for v in videos_th] # NTCHW
frame_grids = torch.stack(videos).permute(1, 0, 2, 3, 4)
if nrow is None:
nrow = int(np.ceil(np.sqrt(n)))
if verbose:
frame_grids = [
make_grid(fs, nrow=nrow)
for fs in tqdm(frame_grids, desc='Making grids')
]
else:
frame_grids = [make_grid(fs, nrow=nrow) for fs in frame_grids]
if os.path.dirname(out_path) != "":
os.makedirs(os.path.dirname(out_path), exist_ok=True)
frame_grids = (torch.stack(frame_grids) * 255).to(torch.uint8).permute(
0, 2, 3, 1)
torchvision.io.write_video(out_path,
frame_grids,
fps=fps,
video_codec='h264',
options={'crf': '10'})

View File

@@ -0,0 +1,231 @@
import os
import logging
mainlogger = logging.getLogger('mainlogger')
import torch
import pandas as pd
from omegaconf import OmegaConf
from collections import OrderedDict
def init_workspace(name, logdir, model_config, lightning_config, rank=0):
workdir = os.path.join(logdir, name)
ckptdir = os.path.join(workdir, "checkpoints")
cfgdir = os.path.join(workdir, "configs")
loginfo = os.path.join(workdir, "loginfo")
# Create logdirs and save configs (all ranks will do to avoid missing directory error if rank:0 is slower)
os.makedirs(workdir, exist_ok=True)
os.makedirs(ckptdir, exist_ok=True)
os.makedirs(cfgdir, exist_ok=True)
os.makedirs(loginfo, exist_ok=True)
if rank == 0:
if "callbacks" in lightning_config and 'metrics_over_trainsteps_checkpoint' in lightning_config.callbacks:
os.makedirs(os.path.join(ckptdir, 'trainstep_checkpoints'),
exist_ok=True)
OmegaConf.save(model_config, os.path.join(cfgdir, "model.yaml"))
OmegaConf.save(OmegaConf.create({"lightning": lightning_config}),
os.path.join(cfgdir, "lightning.yaml"))
return workdir, ckptdir, cfgdir, loginfo
def check_config_attribute(config, name):
if name in config:
value = getattr(config, name)
return value
else:
return None
def get_trainer_callbacks(lightning_config, config, logdir, ckptdir, logger):
default_callbacks_cfg = {
"model_checkpoint": {
"target": "pytorch_lightning.callbacks.ModelCheckpoint",
"params": {
"dirpath": ckptdir,
"filename": "{epoch}",
"verbose": True,
"save_last": False,
}
},
"batch_logger": {
"target": "unifolm_wma.utils.callbacks.ImageLogger",
"params": {
"save_dir": logdir,
"batch_frequency": 1000,
"max_images": 4,
"clamp": True,
}
},
"learning_rate_logger": {
"target": "pytorch_lightning.callbacks.LearningRateMonitor",
"params": {
"logging_interval": "step",
"log_momentum": False
}
},
"cuda_callback": {
"target": "unifolm_wma.utils.callbacks.CUDACallback",
},
}
# Optional setting for saving checkpoints
monitor_metric = check_config_attribute(config.model.params, "monitor")
if monitor_metric is not None:
mainlogger.info(f"Monitoring {monitor_metric} as checkpoint metric.")
default_callbacks_cfg["model_checkpoint"]["params"][
"monitor"] = monitor_metric
default_callbacks_cfg["model_checkpoint"]["params"]["save_top_k"] = 3
default_callbacks_cfg["model_checkpoint"]["params"]["mode"] = "min"
if 'metrics_over_trainsteps_checkpoint' in lightning_config.callbacks:
mainlogger.info(
'Caution: Saving checkpoints every n train steps without deleting. This might require some free space.'
)
default_metrics_over_trainsteps_ckpt_dict = {
'metrics_over_trainsteps_checkpoint': {
"target": 'pytorch_lightning.callbacks.ModelCheckpoint',
'params': {
"dirpath": os.path.join(ckptdir, 'trainstep_checkpoints'),
"filename": "{epoch}-{step}",
"verbose": True,
'save_top_k': -1,
'every_n_train_steps': 10000,
'save_weights_only': True
}
}
}
default_callbacks_cfg.update(default_metrics_over_trainsteps_ckpt_dict)
if "callbacks" in lightning_config:
callbacks_cfg = lightning_config.callbacks
else:
callbacks_cfg = OmegaConf.create()
callbacks_cfg = OmegaConf.merge(default_callbacks_cfg, callbacks_cfg)
return callbacks_cfg
def get_trainer_logger(lightning_config, logdir, on_debug):
default_logger_cfgs = {
"tensorboard": {
"target": "pytorch_lightning.loggers.TensorBoardLogger",
"params": {
"save_dir": logdir,
"name": "tensorboard",
}
},
"testtube": {
"target": "pytorch_lightning.loggers.CSVLogger",
"params": {
"name": "testtube",
"save_dir": logdir,
}
},
}
os.makedirs(os.path.join(logdir, "tensorboard"), exist_ok=True)
default_logger_cfg = default_logger_cfgs["tensorboard"]
if "logger" in lightning_config:
logger_cfg = lightning_config.logger
else:
logger_cfg = OmegaConf.create()
logger_cfg = OmegaConf.merge(default_logger_cfg, logger_cfg)
return logger_cfg
def get_trainer_strategy(lightning_config):
default_strategy_dict = {
"target": "pytorch_lightning.strategies.DDPShardedStrategy"
}
if "strategy" in lightning_config:
strategy_cfg = lightning_config.strategy
return strategy_cfg
else:
strategy_cfg = OmegaConf.create()
strategy_cfg = OmegaConf.merge(default_strategy_dict, strategy_cfg)
return strategy_cfg
def load_checkpoints(model, model_cfg):
if check_config_attribute(model_cfg, "pretrained_checkpoint"):
pretrained_ckpt = model_cfg.pretrained_checkpoint
assert os.path.exists(
pretrained_ckpt
), "Error: Pre-trained checkpoint NOT found at:%s" % pretrained_ckpt
mainlogger.info(">>> Load weights from pretrained checkpoint")
pl_sd = torch.load(pretrained_ckpt, map_location="cpu")
try:
if 'state_dict' in pl_sd.keys():
model.load_state_dict(pl_sd["state_dict"], strict=False)
mainlogger.info(
">>> Loaded weights from pretrained checkpoint: %s" %
pretrained_ckpt)
else:
# deepspeed
new_pl_sd = OrderedDict()
for key in pl_sd['module'].keys():
new_pl_sd[key[16:]] = pl_sd['module'][key]
model.load_state_dict(new_pl_sd, strict=False)
except:
model.load_state_dict(pl_sd)
else:
mainlogger.info(">>> Start training from scratch")
return model
def set_logger(logfile, name='mainlogger'):
logger = logging.getLogger(name)
logger.setLevel(logging.INFO)
fh = logging.FileHandler(logfile, mode='w')
fh.setLevel(logging.INFO)
ch = logging.StreamHandler()
ch.setLevel(logging.DEBUG)
fh.setFormatter(
logging.Formatter("%(asctime)s-%(levelname)s: %(message)s"))
ch.setFormatter(logging.Formatter("%(message)s"))
logger.addHandler(fh)
logger.addHandler(ch)
return logger
def count_parameters(model):
return sum(p.numel() for p in model.parameters())
def count_trainable_parameters(model):
return sum(p.numel() for p in model.parameters() if p.requires_grad)
def get_num_parameters(model):
models = [('World Model', model.model.diffusion_model),
('Action Head', model.model.diffusion_model.action_unet),
('State Head', model.model.diffusion_model.state_unet),
('Total Trainable', model),
('Total', model)]
data = []
for index, (name, model) in enumerate(models):
if name == "Total Trainable":
total_params = count_trainable_parameters(model)
else:
total_params = count_parameters(model)
if total_params < 0.1e9:
total_params_value = round(total_params / 1e6, 2)
unit = 'M'
else:
total_params_value = round(total_params / 1e9, 2)
unit = 'B'
data.append({
'Model Name': name,
'Params': f"{total_params_value} {unit}"
})
df = pd.DataFrame(data)
print(df)

View File

@@ -0,0 +1,81 @@
import importlib
import numpy as np
import cv2
import torch
import torch.distributed as dist
def count_params(model, verbose=False):
total_params = sum(p.numel() for p in model.parameters())
if verbose:
print(
f"{model.__class__.__name__} has {total_params*1.e-6:.2f} M params."
)
return total_params
def check_istarget(name, para_list):
"""
name: full name of source para
para_list: partial name of target para
"""
istarget = False
for para in para_list:
if para in name:
return True
return istarget
def instantiate_from_config(config):
if not "target" in config:
if config == '__is_first_stage__':
return None
elif config == "__is_unconditional__":
return None
raise KeyError("Expected key `target` to instantiate.")
return get_obj_from_str(config["target"])(**config.get("params", dict()))
def get_obj_from_str(string, reload=False):
module, cls = string.rsplit(".", 1)
if reload:
module_imp = importlib.import_module(module)
importlib.reload(module_imp)
return getattr(importlib.import_module(module, package=None), cls)
def load_npz_from_dir(data_dir):
data = [
np.load(os.path.join(data_dir, data_name))['arr_0']
for data_name in os.listdir(data_dir)
]
data = np.concatenate(data, axis=0)
return data
def load_npz_from_paths(data_paths):
data = [np.load(data_path)['arr_0'] for data_path in data_paths]
data = np.concatenate(data, axis=0)
return data
def resize_numpy_image(image,
max_resolution=512 * 512,
resize_short_edge=None):
h, w = image.shape[:2]
if resize_short_edge is not None:
k = resize_short_edge / min(h, w)
else:
k = max_resolution / (h * w)
k = k**0.5
h = int(np.round(h * k / 64)) * 64
w = int(np.round(w * k / 64)) * 64
image = cv2.resize(image, (w, h), interpolation=cv2.INTER_LANCZOS4)
return image
def setup_dist(args):
if dist.is_initialized():
return
torch.cuda.set_device(args.local_rank)
torch.distributed.init_process_group('nccl', init_method='env://')