第一次完整测例跑完
This commit is contained in:
0
src/unifolm_wma/__init__.py
Normal file
0
src/unifolm_wma/__init__.py
Normal file
26
src/unifolm_wma/data/base.py
Normal file
26
src/unifolm_wma/data/base.py
Normal file
@@ -0,0 +1,26 @@
|
||||
from abc import abstractmethod
|
||||
from torch.utils.data import IterableDataset
|
||||
|
||||
|
||||
class Txt2ImgIterableBaseDataset(IterableDataset):
    """Common interface making text2img ``IterableDataset``s chainable.

    Subclasses must provide ``__iter__``; this base class only stores the
    shared bookkeeping attributes and reports its size on construction.
    """

    def __init__(self, num_records=0, valid_ids=None, size=256):
        super().__init__()
        self.size = size
        self.num_records = num_records
        # Both attributes start from the same id list; subclasses may
        # overwrite ``sample_ids`` (e.g. per-worker sharding).
        self.valid_ids = valid_ids
        self.sample_ids = valid_ids
        print(
            f'{self.__class__.__name__} dataset contains {self.__len__()} examples.'
        )

    def __len__(self):
        # Length is the declared record count, not len(valid_ids).
        return self.num_records

    @abstractmethod
    def __iter__(self):
        pass
|
||||
230
src/unifolm_wma/data/normolize.py
Normal file
230
src/unifolm_wma/data/normolize.py
Normal file
@@ -0,0 +1,230 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import torch
|
||||
from torch import Tensor, nn
|
||||
from typing import Dict, List
|
||||
|
||||
|
||||
def create_stats_buffers(
    shapes: Dict[str, List[int]],
    modes: Dict[str, str],
    stats: Dict[str, Dict[str, Tensor]] = None,
) -> Dict[str, Dict[str, nn.ParameterDict]]:
    """
    Create buffers per modality (e.g. "observation.image", "action") containing their mean, std,
    min, max statistics.

    Args:
        shapes: Maps each modality key to its tensor shape. Image shapes are
            reduced to (c, 1, 1) so the stats are invariant to height/width.
        modes: Maps each modality key to "mean_std" or "min_max".
        stats: Optional statistics keyed by the *target* modality (see the
            aliasing below) with per-stat tensors used to fill the buffers.

    Returns:
        Dict: A Dictionary where keys are modalities and values are `nn.ParameterDict` containing
        `nn.Parameters` set to `requires_grad=False`, suitable to not be updated during
        backpropagation.
    """
    stats_buffers = {}

    for key, mode in modes.items():
        assert mode in ["mean_std", "min_max"]

        shape = tuple(shapes[key])

        if "image" in key:
            # Sanity checks: images must be channel-first (c, h, w).
            # (Fixed: the first assert message previously lacked its closing ")".)
            assert len(shape) == 3, f"number of dimensions of {key} != 3 ({shape=})"
            c, h, w = shape
            assert c < h and c < w, f"{key} is not channel first ({shape=})"
            # Override image shape to be invariant to height and width.
            shape = (c, 1, 1)

        # Aliasing: any "*action*" key shares the "action" stats and any
        # "*state*" key shares the "observation.state" stats (e.g.
        # "pre_action" and "next.state" in WMAData reuse those entries).
        if "action" in key:
            target_key = "action"
        elif "state" in key:
            target_key = "observation.state"
        else:
            target_key = key

        # Note: buffers are initialized to infinity. They should be overwritten
        # downstream by `stats` or `policy.load_state_Dict`, as expected; during
        # forward the normalizers assert they are not infinity anymore.
        # (Renamed from builtin-shadowing `min`/`max` locals.)
        inf_init = torch.ones(shape, dtype=torch.float32) * torch.inf
        if mode == "mean_std":
            buffer = nn.ParameterDict({
                "mean": nn.Parameter(inf_init.clone(), requires_grad=False),
                "std": nn.Parameter(inf_init.clone(), requires_grad=False),
            })
        else:  # "min_max" — guaranteed by the assert above
            buffer = nn.ParameterDict({
                "min": nn.Parameter(inf_init.clone(), requires_grad=False),
                "max": nn.Parameter(inf_init.clone(), requires_grad=False),
            })

        if stats is not None:
            # Note: The clone is needed to make sure that the logic in save_pretrained
            # doesn't see duplicated tensors anywhere (for example, when we use the same
            # stats for normalization and unnormalization).
            for stat_name in buffer:
                buffer[stat_name].data = stats[target_key][stat_name].clone()

        stats_buffers[key] = buffer
    return stats_buffers
|
||||
|
||||
|
||||
def _no_stats_error_str(name: str) -> str:
|
||||
return (
|
||||
f"`{name}` is infinity. You should either initialize with `stats` as an argument, or use a "
|
||||
"pretrained model.")
|
||||
|
||||
|
||||
class Normalize(nn.Module):
    """Normalizes data (e.g. "observation.image") for more stable and faster
    convergence during training.

    Statistics live in per-modality ``nn.ParameterDict`` buffers registered as
    ``buffer_<key with "." -> "_">`` attributes, so they participate in
    ``state_dict`` / ``load_state_dict`` without receiving gradients.
    """

    def __init__(
        self,
        shapes: Dict[str, List[int]],
        modes: Dict[str, str],
        stats: Dict[str, Dict[str, Tensor]] = None,
    ):
        """
        Args:
            shapes: Maps each modality key to its tensor shape (image shapes
                are made height/width-invariant by ``create_stats_buffers``).
            modes: Maps each modality key to "mean_std" (zero-mean/unit-std)
                or "min_max" (map to the [-1, 1] range).
            stats: Optional per-modality statistics used to fill the buffers
                immediately; when omitted, the buffers stay at infinity until
                a checkpoint overwrites them via ``load_state_dict``.
        """
        super().__init__()
        self.shapes = shapes
        self.modes = modes
        self.stats = stats
        for modality, param_dict in create_stats_buffers(shapes, modes, stats).items():
            setattr(self, "buffer_" + modality.replace(".", "_"), param_dict)

    @torch.no_grad()
    def forward(self, batch: Dict[str, Tensor]) -> Dict[str, Tensor]:
        """Normalize, in place, every modality of ``batch`` listed in ``self.modes``."""
        for modality, mode in self.modes.items():
            if modality not in batch:
                continue

            param_dict = getattr(self, "buffer_" + modality.replace(".", "_"))

            if mode == "mean_std":
                mean = param_dict["mean"]
                std = param_dict["std"]
                assert not torch.isinf(mean).any(), _no_stats_error_str("mean")
                assert not torch.isinf(std).any(), _no_stats_error_str("std")
                batch[modality] = (batch[modality] - mean) / (std + 1e-8)
            elif mode == "min_max":
                low = param_dict["min"]
                high = param_dict["max"]
                assert not torch.isinf(low).any(), _no_stats_error_str("min")
                assert not torch.isinf(high).any(), _no_stats_error_str("max")
                # Map to [0, 1] first, then shift/scale into [-1, 1].
                batch[modality] = (batch[modality] - low) / (high - low + 1e-8)
                batch[modality] = batch[modality] * 2 - 1
            else:
                raise ValueError(mode)
        return batch
|
||||
|
||||
|
||||
class Unnormalize(nn.Module):
    """
    Similar to `Normalize` but unnormalizes output data (e.g.
    `{"action": torch.randn(b,c)}`) back into the original range used by the
    environment.

    Statistics live in per-modality ``nn.ParameterDict`` buffers registered as
    ``buffer_<key with "." -> "_">`` attributes, so they participate in
    ``state_dict`` / ``load_state_dict`` without receiving gradients.
    """

    def __init__(
        self,
        shapes: Dict[str, List[int]],
        modes: Dict[str, str],
        stats: Dict[str, Dict[str, Tensor]] = None,
    ):
        """
        Args:
            shapes: Maps each modality key to its tensor shape (image shapes
                are made height/width-invariant by ``create_stats_buffers``).
            modes: Maps each modality key to "mean_std" or "min_max"; this
                must match the mode used when the data was normalized.
            stats: Optional per-modality statistics used to fill the buffers
                immediately; when omitted, the buffers stay at infinity until
                a checkpoint overwrites them via ``load_state_dict``.
        """
        super().__init__()
        self.shapes = shapes
        self.modes = modes
        self.stats = stats
        for modality, param_dict in create_stats_buffers(shapes, modes, stats).items():
            setattr(self, "buffer_" + modality.replace(".", "_"), param_dict)

    @torch.no_grad()
    def forward(self, batch: Dict[str, Tensor]) -> Dict[str, Tensor]:
        """Invert the normalization, in place, for every modality listed in ``self.modes``."""
        for modality, mode in self.modes.items():
            if modality not in batch:
                continue

            param_dict = getattr(self, "buffer_" + modality.replace(".", "_"))

            if mode == "mean_std":
                mean = param_dict["mean"]
                std = param_dict["std"]
                assert not torch.isinf(mean).any(), _no_stats_error_str("mean")
                assert not torch.isinf(std).any(), _no_stats_error_str("std")
                batch[modality] = batch[modality] * std + mean
            elif mode == "min_max":
                low = param_dict["min"]
                high = param_dict["max"]
                assert not torch.isinf(low).any(), _no_stats_error_str("min")
                assert not torch.isinf(high).any(), _no_stats_error_str("max")
                # [-1, 1] -> [0, 1] -> original [min, max] range.
                batch[modality] = (batch[modality] + 1) / 2
                batch[modality] = batch[modality] * (high - low) + low
            else:
                raise ValueError(mode)
        return batch
|
||||
60
src/unifolm_wma/data/utils.py
Normal file
60
src/unifolm_wma/data/utils.py
Normal file
@@ -0,0 +1,60 @@
|
||||
import torch
|
||||
|
||||
from huggingface_hub import hf_hub_download, snapshot_download
|
||||
from typing import Dict, List, Union
|
||||
from pathlib import Path
|
||||
from safetensors.torch import load_file
|
||||
|
||||
def unflatten_dict(d, sep="/"):
    """Invert a flat mapping like ``{"a/b": v}`` into nested dicts ``{"a": {"b": v}}``.

    Args:
        d: Mapping whose keys are ``sep``-joined paths.
        sep: Path separator used to split the keys.

    Returns:
        A new nested dict; intermediate levels are created as needed.
    """
    result = {}
    for flat_key, value in d.items():
        *parents, leaf = flat_key.split(sep)
        node = result
        for part in parents:
            node = node.setdefault(part, {})
        node[leaf] = value
    return result
|
||||
|
||||
|
||||
def load_episode_data_index(repo_id, version, root) -> Dict[str, torch.Tensor]:
    """episode_data_index contains the range of indices for each episode

    Loads from a local tree when ``root`` is given, otherwise downloads from
    the Hugging Face Hub dataset repo at the requested revision.

    Example:
    ```python
    from_id = episode_data_index["from"][episode_id].item()
    to_id = episode_data_index["to"][episode_id].item()
    episode_frames = [dataset[i] for i in range(from_id, to_id)]
    ```
    """
    if root is None:
        path = hf_hub_download(repo_id,
                               "meta_data/episode_data_index.safetensors",
                               repo_type="dataset",
                               revision=version)
    else:
        path = Path(root) / repo_id / "meta_data" / "episode_data_index.safetensors"
    return load_file(path)
|
||||
|
||||
|
||||
def load_stats(repo_id, version, root) -> Dict[str, Dict[str, torch.Tensor]]:
    """stats contains the statistics per modality computed over the full dataset, such as max, min, mean, std

    Loads from a local tree when ``root`` is given, otherwise downloads from
    the Hugging Face Hub; the flat safetensors keys are unflattened into
    nested per-modality dicts.

    Example:
    ```python
    normalized_action = (action - stats["action"]["mean"]) / stats["action"]["std"]
    ```
    """
    if root is None:
        path = hf_hub_download(repo_id,
                               "meta_data/stats.safetensors",
                               repo_type="dataset",
                               revision=version)
    else:
        path = Path(root) / repo_id / "meta_data" / "stats.safetensors"
    return unflatten_dict(load_file(path))
|
||||
408
src/unifolm_wma/data/wma_data.py
Normal file
408
src/unifolm_wma/data/wma_data.py
Normal file
@@ -0,0 +1,408 @@
|
||||
import torch
|
||||
import os
|
||||
import random
|
||||
import pandas as pd
|
||||
import h5py
|
||||
|
||||
from decord import VideoReader, cpu
|
||||
from torch.utils.data import Dataset
|
||||
from torchvision import transforms
|
||||
from pathlib import Path
|
||||
|
||||
from unifolm_wma.data.utils import load_stats
|
||||
from unifolm_wma.data.normolize import Normalize, Unnormalize
|
||||
|
||||
|
||||
class WMAData(Dataset):
    """
    Map-style dataset pairing video clips with robot action/state transitions.

    Each item returns a dict holding a normalized clip of ``video_length``
    frames, the language instruction, the previous ``n_obs_steps`` observed
    frames/states/actions, and the future actions/states aligned with the
    clip (padded to ``max_action_dim`` / ``max_state_dim`` with masks).

    Assuming the following dataset structure:
    dataset_dir/
    ├── videos
    │   ├── dataset_name
    │   │   ├── camera_view_dir
    │   │   ├── 0.mp4
    │   │   ├── 1.mp4
    │   │   └── ...
    │   └── ...
    ├── transitions
    │   ├── dataset_name
    │   ├── meta_data
    │   ├── 0.h5
    │   ├── 1.h5
    │   └── ...
    └── dataset_name.csv
    """

    def __init__(
        self,
        meta_path,
        data_dir,
        subsample=None,
        video_length=16,
        resolution=[256, 512],
        frame_stride=1,
        frame_stride_min=1,
        spatial_transform=None,
        crop_resolution=None,
        fps_max=None,
        load_raw_resolution=False,
        fixed_fps=None,
        random_fs=False,
        cond_robot_label_prob=0.0,
        transition_dir=None,
        dataset_name=None,
        normalization_mode='min_max',
        individual_normalization=False,
        n_obs_steps=1,
        max_action_dim=7,
        max_state_dim=7,
    ):
        """
        Args:
            meta_path: CSV read as strings; rows are dropped if any column is
                NaN. Columns used here: `data_dir`, `videoid`, `instruction`,
                `embodiment`.
            data_dir: Root containing the `videos/` and `transitions/` trees.
            subsample: If set, keep only this many rows (seeded sample).
            video_length: Number of frames per returned clip.
            resolution: Target (h, w); an int means a square target.
                NOTE(review): mutable default argument — not mutated here,
                but a tuple would be safer.
            frame_stride: Temporal stride (upper bound when `random_fs`).
            frame_stride_min: Lower bound for the random stride.
            spatial_transform: One of "random_crop", "center_crop",
                "resize_center_crop", "resize", or None.
            crop_resolution: Crop size used by "random_crop".
            fps_max: Cap on the fps reported for the sampled clip.
            load_raw_resolution: Decode at native size instead of 530x300.
            fixed_fps: If set, rescale the stride so clips match this fps.
            random_fs: Sample the stride per item in
                [frame_stride_min, frame_stride].
            cond_robot_label_prob: Probability of prefixing the instruction
                with the embodiment label ("<embodiment> [SEP] <instruction>").
            transition_dir: Root used to locate the dataset stats.
            dataset_name: Dataset key used for stats and transition paths.
            normalization_mode: "min_max" or "mean_std".
            individual_normalization: If True, build per-dataset
                Normalize/Unnormalize modules and apply them in __getitem__.
            n_obs_steps: Number of observed history steps returned.
            max_action_dim: Unified (padded) action vector size.
            max_state_dim: Unified (padded) state vector size.
        """
        self.meta_path = meta_path
        self.data_dir = data_dir
        self.subsample = subsample
        self.video_length = video_length
        # Accept a scalar resolution and promote it to (h, w).
        self.resolution = [resolution, resolution] if isinstance(
            resolution, int) else resolution
        self.fps_max = fps_max
        self.frame_stride = frame_stride
        self.frame_stride_min = frame_stride_min
        self.fixed_fps = fixed_fps
        self.load_raw_resolution = load_raw_resolution
        self.random_fs = random_fs
        self.cond_robot_label_prob = cond_robot_label_prob
        self.transition_dir = transition_dir
        self.dataset_name = dataset_name
        self.max_action_dim = max_action_dim
        self.max_state_dim = max_state_dim

        self._load_metadata()
        # Build the spatial augmentation pipeline, if any.
        if spatial_transform is not None:
            if spatial_transform == "random_crop":
                self.spatial_transform = transforms.RandomCrop(crop_resolution)
            elif spatial_transform == "center_crop":
                self.spatial_transform = transforms.Compose([
                    transforms.CenterCrop(resolution),
                ])
            elif spatial_transform == "resize_center_crop":
                # Resize the short side to min(resolution), then center crop.
                self.spatial_transform = transforms.Compose([
                    transforms.Resize(min(self.resolution)),
                    transforms.CenterCrop(self.resolution),
                ])
            elif spatial_transform == "resize":
                self.spatial_transform = transforms.Resize(self.resolution)
            else:
                raise NotImplementedError
        else:
            self.spatial_transform = None

        self.normalization_mode = normalization_mode
        self.individual_normalization = individual_normalization
        self.n_obs_steps = n_obs_steps
        self._load_stats()
        if individual_normalization:
            self._init_normalizers()

    def _load_metadata(self):
        """Read the CSV manifest, optionally subsample it, and drop NaN rows."""
        metadata = pd.read_csv(self.meta_path, dtype=str)
        if self.subsample is not None:
            # Fixed seed so the subsample is reproducible across runs.
            metadata = metadata.sample(self.subsample, random_state=0)

        self.metadata = metadata
        # Drop the rows containing NaN values.
        self.metadata.dropna(inplace=True)
        print(
            f">>> {metadata['data_dir'].iloc[0]}: {len(metadata)} data samples loaded."
        )

    def _load_stats(self):
        """Load this dataset's normalization statistics from the transitions tree."""
        self.stats = load_stats(self.dataset_name, None, self.transition_dir)
        print(f">>> {self.metadata['data_dir'].iloc[0]}: data stats loaded.")

    def _init_normalizers(self):
        """Build Normalize/Unnormalize modules for action/state modalities.

        Shapes are taken from the last dim of the loaded stats; "pre_action"
        shares the "action" stats and "next.state" shares "observation.state"
        via the aliasing inside create_stats_buffers.
        """
        shape_dict = {
            'pre_action': [self.stats['action']['max'].shape[-1]],
            'action': [self.stats['action']['max'].shape[-1]],
            'observation.state':
            [self.stats['observation.state']['max'].shape[-1]],
            'next.state': [self.stats['observation.state']['max'].shape[-1]]
        }
        normalization_mode_dict = {
            'pre_action': self.normalization_mode,
            'action': self.normalization_mode,
            'observation.state': self.normalization_mode,
            'next.state': self.normalization_mode
        }
        self.normalizer = Normalize(shape_dict, normalization_mode_dict,
                                    self.stats)
        self.unnormalizer = Unnormalize(shape_dict, normalization_mode_dict,
                                        self.stats)
        print(
            f">>> {self.metadata['data_dir'].iloc[0]}: normalizer initiated.")

    def _get_video_path(self, sample):
        """Return <data_dir>/videos/<sample.data_dir>/<videoid>.mp4."""
        rel_video_fp = os.path.join(sample['data_dir'],
                                    str(sample['videoid']) + '.mp4')
        full_video_fp = os.path.join(self.data_dir, 'videos', rel_video_fp)
        return full_video_fp

    def _get_transition_path(self, sample):
        """Return the .h5 transition path for a sample.

        When the sample's data_dir leaf is the dataset name itself the h5
        lives next to it; otherwise (camera-view subdirectory) the h5 lives
        one level up, since transitions are shared across camera views.
        """
        data_dir = Path(sample['data_dir'])
        if self.dataset_name == data_dir.name:
            rel_transition_fp = os.path.join(str(data_dir),
                                             str(sample['videoid']) + '.h5')
        else:
            rel_transition_fp = os.path.join(str(data_dir.parent),
                                             str(sample['videoid']) + '.h5')
        full_transition_fp = os.path.join(self.data_dir, 'transitions',
                                          rel_transition_fp)
        return full_transition_fp

    def get_uni_vec(self, action_state_dict, action_type, state_type):
        """Pad every action/state entry to the unified max dims, adding masks.

        Only the "action" and "next.state" entries get masks stored back into
        the dict (as 'action_mask' / 'state_mask'); the pre_action and
        observation.state masks are discarded.
        NOTE(review): action_type/state_type are currently unused by the
        _map_to_uni_* helpers — presumably reserved for type-specific
        remapping; confirm.
        """
        if 'pre_action' in action_state_dict:
            action_state_dict['pre_action'], _ = self._map_to_uni_action(
                action_state_dict['pre_action'], action_type)
        if 'action' in action_state_dict:
            action_state_dict['action'], action_state_dict[
                'action_mask'] = self._map_to_uni_action(
                    action_state_dict['action'], action_type)
        if 'observation.state' in action_state_dict:
            action_state_dict['observation.state'], _ = self._map_to_uni_state(
                action_state_dict['observation.state'], state_type)
        if 'next.state' in action_state_dict:
            action_state_dict['next.state'], action_state_dict[
                'state_mask'] = self._map_to_uni_state(
                    action_state_dict['next.state'], state_type)
        return action_state_dict

    def _map_to_uni_action(self, action, action_type):
        """Zero-pad actions (t, d) to (t, max_action_dim); mask marks real dims."""
        action_dim = action.shape[-1]
        uni_action = torch.nn.functional.pad(
            action, (0, self.max_action_dim - action_dim),
            mode='constant',
            value=0)
        uni_action_mask = torch.zeros_like(uni_action)
        uni_action_mask[:, :action_dim] = 1
        return uni_action, uni_action_mask

    def _map_to_uni_state(self, state, state_type):
        """Zero-pad states (t, d) to (t, max_state_dim); mask marks real dims."""
        state_dim = state.shape[-1]
        uni_state = torch.nn.functional.pad(
            state, (0, self.max_state_dim - state_dim),
            mode='constant',
            value=0)
        uni_state_mask = torch.zeros_like(uni_state)
        uni_state_mask[:, :state_dim] = 1
        return uni_state, uni_state_mask

    def __getitem__(self, index):
        """Sample one clip plus aligned transitions, retrying on broken videos.

        Decoding failures and too-short videos advance `index` and retry
        forever (bare-except retry loop), so a fully broken dataset would
        spin here — NOTE(review): consider a retry cap.
        """
        # Per-item temporal stride, optionally randomized.
        if self.random_fs:
            frame_stride = random.randint(self.frame_stride_min,
                                          self.frame_stride)
        else:
            frame_stride = self.frame_stride

        # Get frames until success
        while True:
            index = index % len(self.metadata)
            sample = self.metadata.iloc[index]
            video_path = self._get_video_path(sample)

            instruction = sample['instruction']
            # With probability cond_robot_label_prob, prepend the embodiment
            # label (skipped for the placeholder label 'x').
            if self.cond_robot_label_prob > 0.0 and random.random(
            ) < self.cond_robot_label_prob:
                if sample['embodiment'] != 'x':
                    instruction = sample['embodiment'] + ' [SEP] ' + sample[
                        'instruction']
            try:
                if self.load_raw_resolution:
                    video_reader = VideoReader(video_path, ctx=cpu(0))
                else:
                    # Decode directly at a fixed 530x300 to save memory.
                    video_reader = VideoReader(video_path,
                                               ctx=cpu(0),
                                               width=530,
                                               height=300)
                if len(video_reader) < self.video_length:
                    print(
                        f">>> Video length ({len(video_reader)}) is smaller than target length({self.video_length})"
                    )
                    index += 1
                    continue
                else:
                    pass
            except:
                index += 1
                print(f">>> Error: load video failed! path = {video_path}")
                continue

            fps_ori = video_reader.get_avg_fps()
            if self.fixed_fps is not None:
                # Rescale the stride so the effective clip fps matches fixed_fps.
                frame_stride = int(frame_stride *
                                   (1.0 * fps_ori / self.fixed_fps))

            # To avoid extreme cases when fixed_fps is used
            frame_stride = max(frame_stride, 1)

            # Get valid range (adapting case by case)
            required_frame_num = frame_stride * (self.video_length - 1) + 1
            frame_num = len(video_reader)
            if frame_num < required_frame_num:
                # Drop extra samples if fixed fps is required
                if self.fixed_fps is not None and frame_num < required_frame_num * 0.5:
                    index += 1
                    continue
                else:
                    # Shrink the stride so the clip fits into this video.
                    frame_stride = frame_num // self.video_length
                    required_frame_num = frame_stride * (self.video_length -
                                                         1) + 1

            # Select a random clip
            random_range = frame_num - required_frame_num
            # Keep one extra stride of headroom so the shifted "next" frames
            # below stay in range.
            start_idx = random.randint(
                0, random_range -
                frame_stride) if random_range - frame_stride > 0 else 0

            # Calculate frame indices
            frame_indices = [
                start_idx + frame_stride * i for i in range(self.video_length)
            ]
            try:
                # The returned 'video' frames are shifted one stride ahead of
                # frame_indices (they are the *next* observations); actions
                # are indexed at frame_indices.
                next_frame_indices = [
                    idx + frame_stride for idx in frame_indices
                ]
                frames = video_reader.get_batch(next_frame_indices)
                break
            except:
                print(
                    f">>> Error: Get frames failed! path = {video_path}; [max_ind vs frame_total:{max(frame_indices)} / {frame_num}]"
                )
                index += 1
                continue

        # Load transition data
        transition_path = self._get_transition_path(sample)
        with h5py.File(transition_path, 'r') as h5f:
            transition_dict = {}
            # Datasets become tensors; file attributes (e.g. action_type,
            # state_type) are copied through as-is.
            for key in h5f.keys():
                transition_dict[key] = torch.tensor(h5f[key][()])
            for key in h5f.attrs.keys():
                transition_dict[key] = h5f.attrs[key]

        # Load observable states: the n_obs_steps states ending at start_idx,
        # left-padded by repeating the first state when history is short.
        if start_idx < self.n_obs_steps - 1:
            state_indices = list(range(0, start_idx + 1))
            states = transition_dict['observation.state'][state_indices, :]
            num_padding = self.n_obs_steps - 1 - start_idx
            first_slice = states[0:1, :]  # (t, d)
            padding = first_slice.repeat(num_padding, 1)
            states = torch.cat((padding, states), dim=0)
        else:
            state_indices = list(
                range(start_idx - self.n_obs_steps + 1, start_idx + 1))
            states = transition_dict['observation.state'][state_indices, :]
            assert states.shape[
                0] == self.n_obs_steps, '>>> Do not have enough previous states as observation.'

        # Load observable actions: the n_obs_steps actions *before* start_idx,
        # left-padded with zeros (not repeats) when history is short.
        if start_idx < self.n_obs_steps:
            pre_action_indices = list(range(0, start_idx))
            pre_actions = transition_dict['action'][pre_action_indices, :]
            num_padding = self.n_obs_steps - start_idx
            first_slice = torch.zeros_like(transition_dict['action'][:1, :])
            padding = first_slice.repeat(num_padding, 1)
            pre_actions = torch.cat((padding, pre_actions), dim=0)
        else:
            pre_action_indices = list(
                range(start_idx - self.n_obs_steps, start_idx))
            pre_actions = transition_dict['action'][pre_action_indices, :]
            assert pre_actions.shape[
                0] == self.n_obs_steps, ">>> Do not have enough previous actions as observation"

        # Load future actions
        actions = transition_dict['action'][frame_indices, :]
        # Load future states (aligned with the shifted 'video' frames above).
        next_state_indices = [idx + frame_stride for idx in frame_indices]
        next_states = transition_dict['observation.state'][
            next_state_indices, :]
        frames_action_state_dict = {
            'pre_action': pre_actions,
            'action': actions,
            'observation.state': states,
            'next.state': next_states
        }
        if self.individual_normalization:
            # Per-dataset normalization using the stats loaded in __init__.
            frames_action_state_dict = self.normalizer(
                frames_action_state_dict)

        # Update action and states to unified (padded + masked) vectors.
        frames_action_state_dict = self.get_uni_vec(
            frames_action_state_dict,
            transition_dict['action_type'],
            transition_dict['state_type'],
        )

        # Load observable images: n_obs_steps frames ending at start_idx,
        # left-padded by repeating the first frame; result is (c, t, h, w).
        if start_idx < self.n_obs_steps - 1:
            action_net_frame_indices = list(range(0, start_idx + 1))
            action_net_frames = video_reader.get_batch(
                action_net_frame_indices)
            action_net_frames = torch.tensor(
                action_net_frames.asnumpy()).permute(0, 3, 1, 2).float()
            first_slice = action_net_frames[0:1, :]
            num_padding = self.n_obs_steps - 1 - start_idx
            padding = first_slice.repeat(num_padding, 1, 1, 1)
            action_net_frames = torch.cat((padding, action_net_frames), dim=0)
            assert (
                action_net_frames.shape[0] == self.n_obs_steps
            ), f'{len(action_net_frames)}, self.n_obs_steps={self.n_obs_steps}'
            # (t, c, h, w) -> (c, t, h, w)
            action_net_frames = action_net_frames.permute(1, 0, 2, 3)
        else:
            action_net_frame_indices = list(
                range(start_idx - self.n_obs_steps + 1, start_idx + 1))
            action_net_frames = video_reader.get_batch(
                action_net_frame_indices)
            assert (
                action_net_frames.shape[0] == self.n_obs_steps
            ), f'{len(action_net_frames)}, self.n_obs_steps={self.n_obs_steps}'
            # (t, h, w, c) -> (c, t, h, w)
            action_net_frames = torch.tensor(
                action_net_frames.asnumpy()).permute(3, 0, 1, 2).float()

        assert (frames.shape[0] == self.video_length
                ), f'{len(frames)}, self.video_length={self.video_length}'
        # (t, h, w, c) -> (c, t, h, w)
        frames = torch.tensor(frames.asnumpy()).permute(3, 0, 1, 2).float()

        if self.spatial_transform is not None:
            frames = self.spatial_transform(frames)
            action_net_frames = self.spatial_transform(action_net_frames)

        if self.resolution is not None:
            assert (frames.shape[2], frames.shape[3]) == (
                self.resolution[0], self.resolution[1]
            ), f'frames={frames.shape}, self.resolution={self.resolution}'
            assert (
                action_net_frames.shape[2], action_net_frames.shape[3]
            ) == (
                self.resolution[0], self.resolution[1]
            ), f'action_net_frames={action_net_frames.shape}, self.resolution={self.resolution}'

        # Normalize frames tensors to [-1,1]
        frames = (frames / 255 - 0.5) * 2
        action_net_frames = (action_net_frames / 255 - 0.5) * 2
        fps_clip = fps_ori // frame_stride
        if self.fps_max is not None and fps_clip > self.fps_max:
            fps_clip = self.fps_max

        data = {
            'video': frames,
            'instruction': instruction,
            'path': video_path,
            'fps': fps_clip,
            'frame_stride': frame_stride,
            'observation.image': action_net_frames,
        }
        # Merge in pre_action/action/observation.state/next.state (+ masks).
        data.update(frames_action_state_dict)

        return data

    def __len__(self):
        # One item per (non-NaN) row of the manifest.
        return len(self.metadata)
|
||||
0
src/unifolm_wma/models/__init__.py
Normal file
0
src/unifolm_wma/models/__init__.py
Normal file
267
src/unifolm_wma/models/autoencoder.py
Normal file
267
src/unifolm_wma/models/autoencoder.py
Normal file
@@ -0,0 +1,267 @@
|
||||
import os
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
import pytorch_lightning as pl
|
||||
|
||||
from einops import rearrange
|
||||
from unifolm_wma.modules.networks.ae_modules import Encoder, Decoder
|
||||
from unifolm_wma.utils.distributions import DiagonalGaussianDistribution
|
||||
from unifolm_wma.utils.utils import instantiate_from_config
|
||||
|
||||
|
||||
class AutoencoderKL(pl.LightningModule):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
ddconfig,
|
||||
lossconfig,
|
||||
embed_dim,
|
||||
ckpt_path=None,
|
||||
ignore_keys=[],
|
||||
image_key="image",
|
||||
colorize_nlabels=None,
|
||||
monitor=None,
|
||||
test=False,
|
||||
logdir=None,
|
||||
input_dim=4,
|
||||
test_args=None,
|
||||
):
|
||||
super().__init__()
|
||||
self.image_key = image_key
|
||||
self.encoder = Encoder(**ddconfig)
|
||||
self.decoder = Decoder(**ddconfig)
|
||||
self.loss = instantiate_from_config(lossconfig)
|
||||
assert ddconfig["double_z"]
|
||||
self.quant_conv = torch.nn.Conv2d(2 * ddconfig["z_channels"],
|
||||
2 * embed_dim, 1)
|
||||
self.post_quant_conv = torch.nn.Conv2d(embed_dim,
|
||||
ddconfig["z_channels"], 1)
|
||||
self.embed_dim = embed_dim
|
||||
self.input_dim = input_dim
|
||||
self.test = test
|
||||
self.test_args = test_args
|
||||
self.logdir = logdir
|
||||
if colorize_nlabels is not None:
|
||||
assert type(colorize_nlabels) == int
|
||||
self.register_buffer("colorize",
|
||||
torch.randn(3, colorize_nlabels, 1, 1))
|
||||
if monitor is not None:
|
||||
self.monitor = monitor
|
||||
if ckpt_path is not None:
|
||||
self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
|
||||
if self.test:
|
||||
self.init_test()
|
||||
|
||||
    def init_test(self, ):
        """Prepare output directories and counters for test-time dumping.

        Derives the output root from ``logdir``/``test_args`` and creates the
        subdirectories selected by the save_* flags.
        NOTE(review): ``self.test_args`` supports both ``in`` and attribute
        access — presumably an OmegaConf/namespace object; confirm.
        """
        self.test = True
        save_dir = os.path.join(self.logdir, "test")
        # Name the run after the checkpoint file + epoch when a ckpt is given.
        if 'ckpt' in self.test_args:
            ckpt_name = os.path.basename(self.test_args.ckpt).split(
                '.ckpt')[0] + f'_epoch{self._cur_epoch}'
            self.root = os.path.join(save_dir, ckpt_name)
        else:
            self.root = save_dir
        # An explicit test_subdir overrides the ckpt-derived name.
        if 'test_subdir' in self.test_args:
            self.root = os.path.join(save_dir, self.test_args.test_subdir)

        self.root_zs = os.path.join(self.root, "zs")
        self.root_dec = os.path.join(self.root, "reconstructions")
        self.root_inputs = os.path.join(self.root, "inputs")
        os.makedirs(self.root, exist_ok=True)

        # Create only the subdirectories that will actually be written to.
        if self.test_args.save_z:
            os.makedirs(self.root_zs, exist_ok=True)
        if self.test_args.save_reconstruction:
            os.makedirs(self.root_dec, exist_ok=True)
        if self.test_args.save_input:
            os.makedirs(self.root_inputs, exist_ok=True)
        # NOTE(review): this assert fires only after test_args was already
        # dereferenced above — effectively unreachable when test_args is None.
        assert (self.test_args is not None)
        self.test_maximum = getattr(self.test_args, 'test_maximum', None)
        self.count = 0
        self.eval_metrics = {}
        self.decodes = []
        self.save_decode_samples = 2048
|
||||
|
||||
def init_from_ckpt(self, path, ignore_keys=list()):
|
||||
sd = torch.load(path, map_location="cpu")
|
||||
try:
|
||||
self._cur_epoch = sd['epoch']
|
||||
sd = sd["state_dict"]
|
||||
except:
|
||||
self._cur_epoch = 'null'
|
||||
keys = list(sd.keys())
|
||||
for k in keys:
|
||||
for ik in ignore_keys:
|
||||
if k.startswith(ik):
|
||||
print("Deleting key {} from state_dict.".format(k))
|
||||
del sd[k]
|
||||
self.load_state_dict(sd, strict=False)
|
||||
print(f"Restored from {path}")
|
||||
|
||||
def encode(self, x, **kwargs):
|
||||
|
||||
h = self.encoder(x)
|
||||
moments = self.quant_conv(h)
|
||||
posterior = DiagonalGaussianDistribution(moments)
|
||||
return posterior
|
||||
|
||||
def decode(self, z, **kwargs):
    """Map a latent code *z* back to data space through the decoder."""
    post = self.post_quant_conv(z)
    return self.decoder(post)
||||
def forward(self, input, sample_posterior=True):
    """Full autoencode pass: encode, pick a latent, decode.

    Args:
        input: data batch to reconstruct.
        sample_posterior: draw a sample from the posterior when True,
            otherwise use its mode.

    Returns:
        (reconstruction, posterior) pair.
    """
    posterior = self.encode(input)
    z = posterior.sample() if sample_posterior else posterior.mode()
    return self.decode(z), posterior
||||
def get_input(self, batch, k):
    """Fetch ``batch[k]``; fold video frames into the batch dim if needed.

    A 5-D ``(b, c, t, h, w)`` tensor fed to a 4-D model is reshaped to
    ``(b*t, c, h, w)``; the original ``b`` and ``t`` are remembered on
    ``self`` so frames can be regrouped later.
    """
    x = batch[k]
    if x.dim() == 5 and self.input_dim == 4:
        self.b, _, self.t, _, _ = x.shape
        x = rearrange(x, 'b c t h w -> (b t) c h w')
    return x
||||
def training_step(self, batch, batch_idx, optimizer_idx):
    # Lightning training step with two optimizers (see
    # configure_optimizers): optimizer_idx 0 updates the autoencoder,
    # optimizer_idx 1 updates the adversarial discriminator.
    # Returns None implicitly for any other optimizer index.
    inputs = self.get_input(batch, self.image_key)
    reconstructions, posterior = self(inputs)

    if optimizer_idx == 0:
        # train encoder+decoder+logvar
        aeloss, log_dict_ae = self.loss(inputs,
                                        reconstructions,
                                        posterior,
                                        optimizer_idx,
                                        self.global_step,
                                        last_layer=self.get_last_layer(),
                                        split="train")
        # Scalar loss goes to the progress bar; the full metric dict is
        # logged per-step only.
        self.log("aeloss",
                 aeloss,
                 prog_bar=True,
                 logger=True,
                 on_step=True,
                 on_epoch=True)
        self.log_dict(log_dict_ae,
                      prog_bar=False,
                      logger=True,
                      on_step=True,
                      on_epoch=False)
        return aeloss

    if optimizer_idx == 1:
        # train the discriminator
        discloss, log_dict_disc = self.loss(
            inputs,
            reconstructions,
            posterior,
            optimizer_idx,
            self.global_step,
            last_layer=self.get_last_layer(),
            split="train")

        self.log("discloss",
                 discloss,
                 prog_bar=True,
                 logger=True,
                 on_step=True,
                 on_epoch=True)
        self.log_dict(log_dict_disc,
                      prog_bar=False,
                      logger=True,
                      on_step=True,
                      on_epoch=False)
        return discloss
||||
def validation_step(self, batch, batch_idx):
    # Evaluate both loss heads on the validation batch: pass 0 is the
    # autoencoder/reconstruction loss, pass 1 the discriminator loss.
    inputs = self.get_input(batch, self.image_key)
    reconstructions, posterior = self(inputs)
    aeloss, log_dict_ae = self.loss(inputs,
                                    reconstructions,
                                    posterior,
                                    0,
                                    self.global_step,
                                    last_layer=self.get_last_layer(),
                                    split="val")

    discloss, log_dict_disc = self.loss(inputs,
                                        reconstructions,
                                        posterior,
                                        1,
                                        self.global_step,
                                        last_layer=self.get_last_layer(),
                                        split="val")

    # Log the reconstruction loss explicitly (e.g. for checkpoint
    # monitoring), then both full metric dicts.
    self.log("val/rec_loss", log_dict_ae["val/rec_loss"])
    self.log_dict(log_dict_ae)
    self.log_dict(log_dict_disc)
    # NOTE(review): this returns the bound `self.log_dict` METHOD, not a
    # dict of metric values. Lightning ignores validation_step return
    # values by default, but confirm this is intentional.
    return self.log_dict
||||
def configure_optimizers(self):
    """Create the two Adam optimizers used by training_step.

    Optimizer 0 covers the autoencoder (encoder, decoder and both
    quantization convs); optimizer 1 covers the discriminator inside
    ``self.loss``. Both share ``self.learning_rate`` and GAN-style
    betas (0.5, 0.9). No LR schedulers are attached.
    """
    lr = self.learning_rate
    ae_params = (list(self.encoder.parameters()) +
                 list(self.decoder.parameters()) +
                 list(self.quant_conv.parameters()) +
                 list(self.post_quant_conv.parameters()))
    opt_ae = torch.optim.Adam(ae_params, lr=lr, betas=(0.5, 0.9))
    opt_disc = torch.optim.Adam(self.loss.discriminator.parameters(),
                                lr=lr,
                                betas=(0.5, 0.9))
    return [opt_ae, opt_disc], []
||||
def get_last_layer(self):
    """Return the decoder's final conv weight (passed to ``self.loss``
    as ``last_layer`` in training_step/validation_step)."""
    return self.decoder.conv_out.weight
||||
@torch.no_grad()
def log_images(self, batch, only_inputs=False, **kwargs):
    # Build a dict of image tensors for the image logger: inputs,
    # reconstructions and decoded random-latent samples.
    log = dict()
    x = self.get_input(batch, self.image_key)
    x = x.to(self.device)
    if not only_inputs:
        xrec, posterior = self(x)
        if x.shape[1] > 3:
            # colorize with random projection
            assert xrec.shape[1] > 3
            x = self.to_rgb(x)
            xrec = self.to_rgb(xrec)
        # Decode noise shaped like a posterior sample as "samples".
        log["samples"] = self.decode(torch.randn_like(posterior.sample()))
        log["reconstructions"] = xrec
    # NOTE(review): source formatting was mangled; "inputs" is assumed
    # to be logged unconditionally (upstream-LDM style) — confirm.
    log["inputs"] = x
    return log
|
||||
def to_rgb(self, x):
    """Project a multi-channel segmentation map to 3 channels for display.

    A random 1x1 conv weight is created once and cached as the
    ``colorize`` buffer; the projected result is rescaled to [-1, 1].
    """
    assert self.image_key == "segmentation"
    if not hasattr(self, "colorize"):
        # Lazily create the fixed random projection on first use.
        weight = torch.randn(3, x.shape[1], 1, 1).to(x)
        self.register_buffer("colorize", weight)
    projected = F.conv2d(x, weight=self.colorize)
    lo, hi = projected.min(), projected.max()
    return 2. * (projected - lo) / (hi - lo) - 1.
|
||||
|
||||
class IdentityFirstStage(torch.nn.Module):
    """No-op first stage: encode/decode/quantize all return the input."""

    def __init__(self, *args, vq_interface=False, **kwargs):
        # Initialize nn.Module FIRST, then set attributes (the original
        # assigned before super().__init__(), against nn.Module convention).
        super().__init__()
        self.vq_interface = vq_interface  # TODO: Should be true by default but check to not break older stuff

    def encode(self, x, *args, **kwargs):
        """Identity encode."""
        return x

    def decode(self, x, *args, **kwargs):
        """Identity decode."""
        return x

    def quantize(self, x, *args, **kwargs):
        """Identity quantize; mimics a VQ return signature when
        ``vq_interface`` is True."""
        if self.vq_interface:
            return x, None, [None, None, None]
        return x

    def forward(self, x, *args, **kwargs):
        return x
|
||||
2524
src/unifolm_wma/models/ddpms.py
Normal file
2524
src/unifolm_wma/models/ddpms.py
Normal file
File diff suppressed because it is too large
Load Diff
0
src/unifolm_wma/models/diffusion_head/__init__.py
Normal file
0
src/unifolm_wma/models/diffusion_head/__init__.py
Normal file
217
src/unifolm_wma/models/diffusion_head/base_nets.py
Normal file
217
src/unifolm_wma/models/diffusion_head/base_nets.py
Normal file
@@ -0,0 +1,217 @@
|
||||
"""
|
||||
Contains torch Modules that correspond to basic network building blocks, like
|
||||
MLP, RNN, and CNN backbones.
|
||||
"""
|
||||
|
||||
import abc
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
class Module(torch.nn.Module):
    """Base class for networks.

    Identical to torch.nn.Module except that subclasses must also
    implement :meth:`output_shape`.
    """

    @abc.abstractmethod
    def output_shape(self, input_shape=None):
        """Compute this module's output shape from an input shape.

        Args:
            input_shape (iterable of int): input shape WITHOUT the batch
                dimension. Modules whose output does not depend on the
                input size may ignore it.

        Returns:
            out_shape ([int]): list of integers giving the output shape.
        """
        raise NotImplementedError
|
||||
|
||||
"""
|
||||
================================================
|
||||
Visual Backbone Networks
|
||||
================================================
|
||||
"""
|
||||
|
||||
|
||||
class ConvBase(Module):
    """Base class for ConvNets.

    ``forward`` runs ``self.nets`` (provided by subclasses) and checks
    the result against the shape promised by :meth:`output_shape`.
    """

    def __init__(self):
        super(ConvBase, self).__init__()

    # dirty hack - re-implement to pass the buck onto subclasses from ABC parent
    def output_shape(self, input_shape):
        """Compute the output shape for *input_shape* (no batch dim).

        Returns:
            out_shape ([int]): list of integers giving the output shape.
        """
        raise NotImplementedError

    def forward(self, inputs):
        out = self.nets(inputs)
        expected = list(self.output_shape(list(inputs.shape)[1:]))
        actual = list(out.shape)[1:]
        if expected != actual:
            raise ValueError('Size mismatch: expect size %s, but got size %s' %
                             (str(expected), str(actual)))
        return out
||||
|
||||
|
||||
class SpatialSoftmax(ConvBase):
    """
    Spatial Softmax Layer.

    Turns a (C, H, W) feature map into K 2-D keypoints: a softmax over
    pixel locations per channel gives a spatial distribution, whose
    expected (x, y) coordinate is the keypoint.

    Based on Deep Spatial Autoencoders for Visuomotor Learning by Finn et al.
    https://rll.berkeley.edu/dsae/dsae.pdf
    """

    def __init__(
        self,
        input_shape,
        num_kp=32,
        temperature=1.,
        learnable_temperature=False,
        output_variance=False,
        noise_std=0.0,
    ):
        """
        Args:
            input_shape (list): shape of the input feature (C, H, W)
            num_kp (int): number of keypoints (None for not using spatialsoftmax)
            temperature (float): temperature term for the softmax.
            learnable_temperature (bool): whether to learn the temperature
            output_variance (bool): treat attention as a distribution, and compute second-order statistics to return
            noise_std (float): add random spatial noise to the predicted keypoints
        """
        super(SpatialSoftmax, self).__init__()
        assert len(input_shape) == 3
        self._in_c, self._in_h, self._in_w = input_shape  # (C, H, W)

        if num_kp is not None:
            # 1x1 conv maps C channels to num_kp keypoint channels.
            self.nets = torch.nn.Conv2d(self._in_c, num_kp, kernel_size=1)
            self._num_kp = num_kp
        else:
            # No projection: one keypoint per input channel.
            self.nets = None
            self._num_kp = self._in_c
        self.learnable_temperature = learnable_temperature
        self.output_variance = output_variance
        self.noise_std = noise_std

        if self.learnable_temperature:
            # temperature will be learned
            temperature = torch.nn.Parameter(torch.ones(1) * temperature,
                                             requires_grad=True)
            self.register_parameter('temperature', temperature)
        else:
            # temperature held constant after initialization
            # NOTE(review): a Parameter object is registered as a BUFFER
            # here; it works (Parameter is a Tensor) but a plain tensor
            # would be more conventional — confirm before changing.
            temperature = torch.nn.Parameter(torch.ones(1) * temperature,
                                             requires_grad=False)
            self.register_buffer('temperature', temperature)

        # Normalized pixel-coordinate grids in [-1, 1], flattened to
        # (1, H*W), registered as buffers so they follow .to(device).
        pos_x, pos_y = np.meshgrid(np.linspace(-1., 1., self._in_w),
                                   np.linspace(-1., 1., self._in_h))
        pos_x = torch.from_numpy(pos_x.reshape(1, self._in_h *
                                               self._in_w)).float()
        pos_y = torch.from_numpy(pos_y.reshape(1, self._in_h *
                                               self._in_w)).float()
        self.register_buffer('pos_x', pos_x)
        self.register_buffer('pos_y', pos_y)

        # Last computed keypoints, cached by forward() for inspection.
        self.kps = None

    def __repr__(self):
        """Pretty print network."""
        header = format(str(self.__class__.__name__))
        return header + '(num_kp={}, temperature={}, noise={})'.format(
            self._num_kp, self.temperature.item(), self.noise_std)

    def output_shape(self, input_shape):
        """
        Function to compute output shape from inputs to this module.

        Args:
            input_shape (iterable of int): shape of input. Does not include batch dimension.

        Returns:
            out_shape ([int]): [num_kp, 2] — one (x, y) pair per keypoint.
        """
        assert (len(input_shape) == 3)
        assert (input_shape[0] == self._in_c)
        return [self._num_kp, 2]

    def forward(self, feature):
        """
        Forward pass through spatial softmax layer. For each keypoint, a 2D spatial
        probability distribution is created using a softmax, where the support is the
        pixel locations. This distribution is used to compute the expected value of
        the pixel location, which becomes a keypoint of dimension 2. K such keypoints
        are created.

        Returns:
            out (torch.Tensor or tuple): mean keypoints of shape [B, K, 2], and possibly
                keypoint variance of shape [B, K, 2, 2] corresponding to the covariance
                under the 2D spatial softmax distribution
        """
        assert (feature.shape[1] == self._in_c)
        assert (feature.shape[2] == self._in_h)
        assert (feature.shape[3] == self._in_w)
        if self.nets is not None:
            feature = self.nets(feature)

        # [B, K, H, W] -> [B * K, H * W] where K is number of keypoints
        feature = feature.reshape(-1, self._in_h * self._in_w)
        # 2d softmax normalization
        attention = F.softmax(feature / self.temperature, dim=-1)
        # [1, H * W] x [B * K, H * W] -> [B * K, 1] for spatial coordinate mean in x and y dimensions
        expected_x = torch.sum(self.pos_x * attention, dim=1, keepdim=True)
        expected_y = torch.sum(self.pos_y * attention, dim=1, keepdim=True)
        # stack to [B * K, 2]
        expected_xy = torch.cat([expected_x, expected_y], 1)
        # reshape to [B, K, 2]
        feature_keypoints = expected_xy.view(-1, self._num_kp, 2)

        if self.training:
            # Training-time regularization: jitter keypoints.
            # NOTE(review): noise is sampled even when noise_std == 0.
            noise = torch.randn_like(feature_keypoints) * self.noise_std
            feature_keypoints += noise

        if self.output_variance:
            # treat attention as a distribution, and compute second-order statistics to return
            expected_xx = torch.sum(self.pos_x * self.pos_x * attention,
                                    dim=1,
                                    keepdim=True)
            expected_yy = torch.sum(self.pos_y * self.pos_y * attention,
                                    dim=1,
                                    keepdim=True)
            expected_xy = torch.sum(self.pos_x * self.pos_y * attention,
                                    dim=1,
                                    keepdim=True)
            var_x = expected_xx - expected_x * expected_x
            var_y = expected_yy - expected_y * expected_y
            var_xy = expected_xy - expected_x * expected_y
            # stack to [B * K, 4] and then reshape to [B, K, 2, 2] where last 2 dims are covariance matrix
            feature_covar = torch.cat([var_x, var_xy, var_xy, var_y],
                                      1).reshape(-1, self._num_kp, 2, 2)
            feature_keypoints = (feature_keypoints, feature_covar)

        # Cache detached keypoints (and covariance, if computed).
        if isinstance(feature_keypoints, tuple):
            self.kps = (feature_keypoints[0].detach(),
                        feature_keypoints[1].detach())
        else:
            self.kps = feature_keypoints.detach()
        return feature_keypoints
|
||||
83
src/unifolm_wma/models/diffusion_head/common/lr_scheduler.py
Normal file
83
src/unifolm_wma/models/diffusion_head/common/lr_scheduler.py
Normal file
@@ -0,0 +1,83 @@
|
||||
from diffusers.optimization import (Union, SchedulerType, Optional, Optimizer,
|
||||
TYPE_TO_SCHEDULER_FUNCTION)
|
||||
|
||||
|
||||
def get_scheduler(name: Union[str, SchedulerType],
                  optimizer: Optimizer,
                  num_warmup_steps: Optional[int] = None,
                  num_training_steps: Optional[int] = None,
                  **kwargs):
    """
    Added kwargs vs diffuser's original implementation

    Unified API to get any scheduler from its name.

    Args:
        name (`str` or `SchedulerType`):
            The name of the scheduler to use.
        optimizer (`torch.optim.Optimizer`):
            The optimizer that will be used during training.
        num_warmup_steps (`int`, *optional*):
            The number of warmup steps to do. This is not required by all schedulers (hence the argument being
            optional), the function will raise an error if it's unset and the scheduler type requires it.
        num_training_steps (`int``, *optional*):
            The number of training steps to do. This is not required by all schedulers (hence the argument being
            optional), the function will raise an error if it's unset and the scheduler type requires it.

    Raises:
        ValueError: if a required step-count argument is missing for the
            requested scheduler type.
    """
    name = SchedulerType(name)
    schedule_func = TYPE_TO_SCHEDULER_FUNCTION[name]
    # CONSTANT needs no step counts at all.
    if name == SchedulerType.CONSTANT:
        return schedule_func(optimizer, **kwargs)

    # All other schedulers require `num_warmup_steps`
    if num_warmup_steps is None:
        raise ValueError(
            f"{name} requires `num_warmup_steps`, please provide that argument."
        )

    if name == SchedulerType.CONSTANT_WITH_WARMUP:
        return schedule_func(optimizer,
                             num_warmup_steps=num_warmup_steps,
                             **kwargs)

    # All other schedulers require `num_training_steps`
    if num_training_steps is None:
        raise ValueError(
            f"{name} requires `num_training_steps`, please provide that argument."
        )

    return schedule_func(optimizer,
                         num_warmup_steps=num_warmup_steps,
                         num_training_steps=num_training_steps,
                         **kwargs)
|
||||
|
||||
|
||||
import torch
|
||||
from torch.optim.lr_scheduler import _LRScheduler
|
||||
import pytorch_lightning as pl
|
||||
from diffusers.optimization import TYPE_TO_SCHEDULER_FUNCTION, SchedulerType
|
||||
|
||||
|
||||
class SelectiveLRScheduler(_LRScheduler):
    """LR scheduler that applies a base scheduler to selected param groups.

    Groups whose index is in ``group_indices`` follow ``base_scheduler``;
    every other group is pinned to the corresponding entry of
    ``default_lr``.
    """

    def __init__(self,
                 optimizer,
                 base_scheduler,
                 group_indices,
                 default_lr=(1e-5, 1e-4),
                 last_epoch=-1):
        # Tuple default replaces the original mutable-list default
        # argument; any indexable sequence is accepted.
        self.base_scheduler = base_scheduler
        self.group_indices = group_indices  # Indices of parameter groups to update
        self.default_lr = default_lr
        super().__init__(optimizer, last_epoch)

    def step(self, epoch=None):
        # `epoch` is accepted for API compatibility but ignored; the
        # base scheduler keeps its own step count.
        self.base_scheduler.step()
        base_lrs = self.base_scheduler.get_last_lr()

        for idx, group in enumerate(self.optimizer.param_groups):
            if idx in self.group_indices:
                group['lr'] = base_lrs[idx]
            else:
                # Reset the learning rate to its initial value
                group['lr'] = self.default_lr[idx]
||||
@@ -0,0 +1,16 @@
|
||||
import torch.nn as nn
|
||||
|
||||
|
||||
class ModuleAttrMixin(nn.Module):
    """Mixin exposing a module's device and dtype as properties.

    An empty dummy parameter guarantees ``parameters()`` is never
    empty, so the probe below always succeeds.
    """

    def __init__(self):
        super().__init__()
        self._dummy_variable = nn.Parameter()

    def _probe_param(self):
        # Any parameter works; the dummy ensures at least one exists.
        return next(iter(self.parameters()))

    @property
    def device(self):
        return self._probe_param().device

    @property
    def dtype(self):
        return self._probe_param().dtype
|
||||
91
src/unifolm_wma/models/diffusion_head/common/pytorch_util.py
Normal file
91
src/unifolm_wma/models/diffusion_head/common/pytorch_util.py
Normal file
@@ -0,0 +1,91 @@
|
||||
import collections
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from typing import Dict, Callable, List
|
||||
|
||||
|
||||
def dict_apply(
        x: Dict[str, torch.Tensor],
        func: Callable[[torch.Tensor],
                       torch.Tensor]) -> Dict[str, torch.Tensor]:
    """Apply *func* to every leaf of a possibly nested dict.

    Nested dicts are recursed into; everything else is treated as a
    leaf. Returns a new dict with the same key structure.
    """
    return {
        key: dict_apply(value, func) if isinstance(value, dict) else func(value)
        for key, value in x.items()
    }
|
||||
|
||||
def pad_remaining_dims(x, target):
    """Append singleton dims to *x* until its rank matches *target*'s.

    *x*'s shape must be a prefix of *target*'s shape.
    """
    assert x.shape == target.shape[:len(x.shape)]
    extra = len(target.shape) - len(x.shape)
    return x.reshape(x.shape + (1,) * extra)
|
||||
|
||||
def dict_apply_split(
    x: Dict[str, torch.Tensor], split_func: Callable[[torch.Tensor],
                                                     Dict[str, torch.Tensor]]
) -> Dict[str, torch.Tensor]:
    """Split each value with *split_func* and regroup by the split names.

    ``{'a': v}`` with ``split_func(v) -> {'l': v1, 'r': v2}`` becomes
    ``{'l': {'a': v1}, 'r': {'a': v2}}`` (as a defaultdict).
    """
    grouped = collections.defaultdict(dict)
    for key, value in x.items():
        for part_name, part in split_func(value).items():
            grouped[part_name][key] = part
    return grouped
||||
|
||||
|
||||
def dict_apply_reduce(
    x: List[Dict[str,
                 torch.Tensor]], reduce_func: Callable[[List[torch.Tensor]],
                                                       torch.Tensor]
) -> Dict[str, torch.Tensor]:
    """Key-wise reduction over a list of dicts.

    ``result[k] = reduce_func([d[k] for d in x])``; keys come from the
    first dict, which all dicts are assumed to share.
    """
    return {
        key: reduce_func([item[key] for item in x])
        for key in x[0].keys()
    }
||||
|
||||
|
||||
def replace_submodules(root_module: nn.Module, predicate: Callable[[nn.Module],
                                                                   bool],
                       func: Callable[[nn.Module], nn.Module]) -> nn.Module:
    """
    Replace, in place, every submodule matching *predicate* with func(module).

    predicate: Return true if the module is to be replaced.
    func: Return new module to use.

    Returns the (mutated) root module; if the root itself matches,
    func(root_module) is returned instead.
    """
    if predicate(root_module):
        return func(root_module)

    # Collect dotted paths first so we don't mutate while iterating
    # named_modules().
    bn_list = [
        k.split('.')
        for k, m in root_module.named_modules(remove_duplicate=True)
        if predicate(m)
    ]
    for *parent, k in bn_list:
        parent_module = root_module
        if len(parent) > 0:
            parent_module = root_module.get_submodule('.'.join(parent))
        # nn.Sequential children are addressed by integer index, not attr.
        if isinstance(parent_module, nn.Sequential):
            src_module = parent_module[int(k)]
        else:
            src_module = getattr(parent_module, k)
        tgt_module = func(src_module)
        if isinstance(parent_module, nn.Sequential):
            parent_module[int(k)] = tgt_module
        else:
            setattr(parent_module, k, tgt_module)
    # verify that all BN are replaced
    bn_list = [
        k.split('.')
        for k, m in root_module.named_modules(remove_duplicate=True)
        if predicate(m)
    ]
    assert len(bn_list) == 0
    return root_module
||||
|
||||
|
||||
def optimizer_to(optimizer, device):
    """Move all tensor-valued optimizer state (e.g. momenta) to *device*.

    Mutates and returns the same optimizer.
    """
    for state in optimizer.state.values():
        for key, value in state.items():
            if isinstance(value, torch.Tensor):
                state[key] = value.to(device=device)
    return optimizer
||||
960
src/unifolm_wma/models/diffusion_head/common/tensor_util.py
Normal file
960
src/unifolm_wma/models/diffusion_head/common/tensor_util.py
Normal file
@@ -0,0 +1,960 @@
|
||||
"""
|
||||
A collection of utilities for working with nested tensor structures consisting
|
||||
of numpy arrays and torch tensors.
|
||||
"""
|
||||
import collections
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
|
||||
def recursive_dict_list_tuple_apply(x, type_func_dict):
    """Recursively transform the leaves of a nested dict/list/tuple.

    Args:
        x (dict or list or tuple): a possibly nested structure.
        type_func_dict (dict): maps leaf types to the function applied
            to leaves of that type. Container types (dict/list/tuple)
            must not appear as keys.

    Returns:
        A new structure of the same shape with transformed leaves.

    Raises:
        NotImplementedError: if a leaf's type has no handler.
    """
    assert (list not in type_func_dict)
    assert (tuple not in type_func_dict)
    assert (dict not in type_func_dict)

    if isinstance(x, (dict, collections.OrderedDict)):
        # Preserve OrderedDict-ness of the input mapping.
        out_cls = (collections.OrderedDict
                   if isinstance(x, collections.OrderedDict) else dict)
        out = out_cls()
        for key, value in x.items():
            out[key] = recursive_dict_list_tuple_apply(value, type_func_dict)
        return out
    if isinstance(x, (list, tuple)):
        mapped = [
            recursive_dict_list_tuple_apply(v, type_func_dict) for v in x
        ]
        return tuple(mapped) if isinstance(x, tuple) else mapped
    for leaf_type, handler in type_func_dict.items():
        if isinstance(x, leaf_type):
            return handler(x)
    raise NotImplementedError('Cannot handle data type %s' % str(type(x)))
|
||||
|
||||
def map_tensor(x, func):
    """Apply *func* to every torch.Tensor leaf of a nested
    dict/list/tuple; None leaves pass through unchanged."""
    handlers = {
        torch.Tensor: func,
        type(None): lambda v: v,
    }
    return recursive_dict_list_tuple_apply(x, handlers)
||||
|
||||
|
||||
def map_ndarray(x, func):
    """Apply *func* to every np.ndarray leaf of a nested
    dict/list/tuple; None leaves pass through unchanged."""
    handlers = {
        np.ndarray: func,
        type(None): lambda v: v,
    }
    return recursive_dict_list_tuple_apply(x, handlers)
||||
|
||||
|
||||
def map_tensor_ndarray(x, tensor_func, ndarray_func):
    """Apply *tensor_func* to tensor leaves and *ndarray_func* to array
    leaves of a nested dict/list/tuple; None passes through."""
    handlers = {
        torch.Tensor: tensor_func,
        np.ndarray: ndarray_func,
        type(None): lambda v: v,
    }
    return recursive_dict_list_tuple_apply(x, handlers)
||||
|
||||
|
||||
def clone(x):
    """Copy every leaf of a nested structure: tensors are cloned,
    arrays are copied, None passes through."""
    handlers = {
        torch.Tensor: lambda v: v.clone(),
        np.ndarray: lambda v: v.copy(),
        type(None): lambda v: v,
    }
    return recursive_dict_list_tuple_apply(x, handlers)
||||
|
||||
|
||||
def detach(x):
    """Detach every torch tensor leaf from the autograd graph.

    NOTE(review): unlike the sibling helpers, there is no None (or
    ndarray) handler here, so such leaves raise NotImplementedError —
    confirm that is intended.
    """
    handlers = {
        torch.Tensor: lambda v: v.detach(),
    }
    return recursive_dict_list_tuple_apply(x, handlers)
||||
|
||||
|
||||
def to_batch(x):
    """Prepend a batch dimension of size 1 to every tensor/array leaf
    of a nested structure; None passes through."""
    def add_batch_dim(v):
        return v[None, ...]

    handlers = {
        torch.Tensor: add_batch_dim,
        np.ndarray: add_batch_dim,
        type(None): lambda v: v,
    }
    return recursive_dict_list_tuple_apply(x, handlers)
||||
|
||||
|
||||
def to_sequence(x):
    """Insert a time dimension of size 1 at axis 1 of every
    tensor/array leaf of a nested structure; None passes through."""
    def add_time_dim(v):
        return v[:, None, ...]

    handlers = {
        torch.Tensor: add_time_dim,
        np.ndarray: add_time_dim,
        type(None): lambda v: v,
    }
    return recursive_dict_list_tuple_apply(x, handlers)
||||
|
||||
|
||||
def index_at_time(x, ind):
    """Select index *ind* along axis 1 of every tensor/array leaf of a
    nested structure; None passes through."""
    def pick(v):
        return v[:, ind, ...]

    handlers = {
        torch.Tensor: pick,
        np.ndarray: pick,
        type(None): lambda v: v,
    }
    return recursive_dict_list_tuple_apply(x, handlers)
||||
|
||||
|
||||
def unsqueeze(x, dim):
    """Insert a size-1 dimension at *dim* in every tensor/array leaf of
    a nested structure; None passes through."""
    handlers = {
        torch.Tensor: lambda v: v.unsqueeze(dim=dim),
        np.ndarray: lambda v: np.expand_dims(v, axis=dim),
        type(None): lambda v: v,
    }
    return recursive_dict_list_tuple_apply(x, handlers)
||||
|
||||
|
||||
def contiguous(x):
    """Make every tensor/array leaf of a nested structure contiguous in
    memory; None passes through."""
    handlers = {
        torch.Tensor: lambda v: v.contiguous(),
        np.ndarray: lambda v: np.ascontiguousarray(v),
        type(None): lambda v: v,
    }
    return recursive_dict_list_tuple_apply(x, handlers)
||||
|
||||
|
||||
def to_device(x, device):
    """Send every torch tensor leaf of a nested structure to *device*;
    None passes through. Numpy leaves are not handled (use to_torch)."""
    handlers = {
        torch.Tensor: lambda v: v.to(device),
        type(None): lambda v: v,
    }
    return recursive_dict_list_tuple_apply(x, handlers)
||||
|
||||
|
||||
def to_tensor(x):
    """Convert every np.ndarray leaf of a nested structure to a torch
    tensor; existing tensors and None leaves pass through."""
    handlers = {
        torch.Tensor: lambda v: v,
        np.ndarray: lambda v: torch.from_numpy(v),
        type(None): lambda v: v,
    }
    return recursive_dict_list_tuple_apply(x, handlers)
||||
|
||||
|
||||
def to_numpy(x):
    """Convert every torch tensor leaf of a nested structure to a numpy
    array; existing arrays and None leaves pass through."""

    def tensor_to_array(t):
        # .cpu() is a no-op for CPU tensors, so one code path suffices.
        return t.detach().cpu().numpy()

    handlers = {
        torch.Tensor: tensor_to_array,
        np.ndarray: lambda v: v,
        type(None): lambda v: v,
    }
    return recursive_dict_list_tuple_apply(x, handlers)
||||
|
||||
|
||||
def to_list(x):
    """Convert every tensor/array leaf of a nested structure to a plain
    (nested) Python list — handy for JSON encoding; None passes through."""

    def tensor_to_list(t):
        # .cpu() is a no-op for CPU tensors, so one code path suffices.
        return t.detach().cpu().numpy().tolist()

    handlers = {
        torch.Tensor: tensor_to_list,
        np.ndarray: lambda v: v.tolist(),
        type(None): lambda v: v,
    }
    return recursive_dict_list_tuple_apply(x, handlers)
||||
|
||||
|
||||
def to_float(x):
    """Cast every tensor leaf to torch float and every array leaf to
    float32; None passes through."""
    handlers = {
        torch.Tensor: lambda v: v.float(),
        np.ndarray: lambda v: v.astype(np.float32),
        type(None): lambda v: v,
    }
    return recursive_dict_list_tuple_apply(x, handlers)
||||
|
||||
|
||||
def to_uint8(x):
    """Cast every tensor leaf to torch uint8 (byte) and every array
    leaf to np.uint8; None passes through."""
    handlers = {
        torch.Tensor: lambda v: v.byte(),
        np.ndarray: lambda v: v.astype(np.uint8),
        type(None): lambda v: v,
    }
    return recursive_dict_list_tuple_apply(x, handlers)
||||
|
||||
|
||||
def to_torch(x, device):
    """
    Converts all numpy arrays and torch tensors in nested dictionary or list or tuple to
    float torch tensors on device @device and returns a new nested structure.

    Args:
        x (dict or list or tuple): a possibly nested dictionary or list or tuple
        device (torch.Device): device to send tensors to

    Returns:
        y (dict or list or tuple): new nested dict-list-tuple
    """
    # Composition of the sibling helpers: ndarray -> tensor, cast to
    # float, then move to the target device.
    return to_device(to_float(to_tensor(x)), device)
||||
|
||||
|
||||
def to_one_hot_single(tensor, num_class):
    """
    Convert an integer-label tensor to a one-hot representation, assuming a
    fixed total number of class labels.

    Args:
        tensor (torch.Tensor): tensor containing integer labels
        num_class (int): number of classes

    Returns:
        x (torch.Tensor): tensor containing one-hot representation of labels
    """
    one_hot_shape = tuple(tensor.size()) + (num_class, )
    out = torch.zeros(one_hot_shape, device=tensor.device)
    # place a 1 at each label position along the trailing class axis
    out.scatter_(-1, tensor.unsqueeze(-1), 1)
    return out
|
||||
|
||||
|
||||
def to_one_hot(tensor, num_class):
    """
    Convert all tensors nested inside a dict, list, or tuple to a one-hot
    representation, assuming a fixed total number of class labels.

    Args:
        tensor (dict or list or tuple): a possibly nested dictionary or list or tuple
        num_class (int): number of classes

    Returns:
        y (dict or list or tuple): new nested dict-list-tuple
    """

    def _convert(t, nc=num_class):
        return to_one_hot_single(t, nc)

    return map_tensor(tensor, func=_convert)
|
||||
|
||||
|
||||
def flatten_single(x, begin_axis=1):
    """
    Flatten a tensor over all dimensions from @begin_axis onwards.

    Args:
        x (torch.Tensor): tensor to flatten
        begin_axis (int): which axis to flatten from

    Returns:
        y (torch.Tensor): flattened tensor
    """
    # keep dimensions before begin_axis, collapse the rest into one
    keep_dims = list(x.size()[:begin_axis])
    keep_dims.append(-1)
    return x.reshape(*keep_dims)
|
||||
|
||||
|
||||
def flatten(x, begin_axis=1):
    """
    Flatten all tensors nested inside a dict, list, or tuple from @begin_axis
    onwards.

    Args:
        x (dict or list or tuple): a possibly nested dictionary or list or tuple
        begin_axis (int): which axis to flatten from

    Returns:
        y (dict or list or tuple): new nested dict-list-tuple
    """

    def _flatten(t, b=begin_axis):
        return flatten_single(t, begin_axis=b)

    return recursive_dict_list_tuple_apply(x, {torch.Tensor: _flatten})
|
||||
|
||||
|
||||
def reshape_dimensions_single(x, begin_axis, end_axis, target_dims):
    """
    Reshape the dimension range (@begin_axis, @end_axis) of a tensor into
    @target_dims, keeping all other dimensions intact.

    Args:
        x (torch.Tensor): tensor to reshape
        begin_axis (int): begin dimension
        end_axis (int): end dimension
        target_dims (tuple or list): target shape for the range of dimensions
            (@begin_axis, @end_axis)

    Returns:
        y (torch.Tensor): reshaped tensor
    """
    assert begin_axis <= end_axis
    assert begin_axis >= 0
    assert end_axis < len(x.shape)
    assert isinstance(target_dims, (tuple, list))

    new_shape = []
    for axis, dim_size in enumerate(x.shape):
        if axis == begin_axis:
            # splice in the replacement dims at the start of the range
            new_shape.extend(target_dims)
        elif axis < begin_axis or axis > end_axis:
            # dims outside the range are preserved; dims inside are dropped
            new_shape.append(dim_size)
    return x.reshape(*new_shape)
|
||||
|
||||
|
||||
def reshape_dimensions(x, begin_axis, end_axis, target_dims):
    """
    Reshape the dimension range (@begin_axis, @end_axis) of every tensor or
    array nested inside a dict, list, or tuple into @target_dims.

    Args:
        x (dict or list or tuple): a possibly nested dictionary or list or tuple
        begin_axis (int): begin dimension
        end_axis (int): end dimension
        target_dims (tuple or list): target shape for the range of dimensions
            (@begin_axis, @end_axis)

    Returns:
        y (dict or list or tuple): new nested dict-list-tuple
    """

    # one handler shared by both tensor and ndarray leaves
    def _reshape(t, b=begin_axis, e=end_axis, td=target_dims):
        return reshape_dimensions_single(t,
                                         begin_axis=b,
                                         end_axis=e,
                                         target_dims=td)

    return recursive_dict_list_tuple_apply(
        x, {
            torch.Tensor: _reshape,
            np.ndarray: _reshape,
            type(None): lambda v: v,
        })
|
||||
|
||||
|
||||
def join_dimensions(x, begin_axis, end_axis):
    """
    Join all dimensions between (@begin_axis, @end_axis) into one flat
    dimension, for every tensor or array nested inside a dict, list, or tuple.

    Args:
        x (dict or list or tuple): a possibly nested dictionary or list or tuple
        begin_axis (int): begin dimension
        end_axis (int): end dimension

    Returns:
        y (dict or list or tuple): new nested dict-list-tuple
    """

    # joining is just a reshape of the range into a single inferred dim
    def _join(t, b=begin_axis, e=end_axis):
        return reshape_dimensions_single(t,
                                         begin_axis=b,
                                         end_axis=e,
                                         target_dims=[-1])

    return recursive_dict_list_tuple_apply(
        x, {
            torch.Tensor: _join,
            np.ndarray: _join,
            type(None): lambda v: v,
        })
|
||||
|
||||
|
||||
def expand_at_single(x, size, dim):
    """
    Expand a tensor at the singleton dimension @dim to length @size.

    Args:
        x (torch.Tensor): input tensor
        size (int): size to expand
        dim (int): dimension to expand

    Returns:
        y (torch.Tensor): expanded tensor
    """
    assert dim < x.ndimension()
    assert x.shape[dim] == 1
    # -1 keeps a dimension as-is; only @dim is stretched
    sizes = [size if d == dim else -1 for d in range(x.ndimension())]
    return x.expand(*sizes)
|
||||
|
||||
|
||||
def expand_at(x, size, dim):
    """
    Expand every tensor nested inside a dict, list, or tuple at singleton
    dimension @dim to length @size.

    Args:
        x (dict or list or tuple): a possibly nested dictionary or list or tuple
        size (int): size to expand
        dim (int): dimension to expand

    Returns:
        y (dict or list or tuple): new nested dict-list-tuple
    """

    def _expand(t, s=size, d=dim):
        return expand_at_single(t, s, d)

    return map_tensor(x, _expand)
|
||||
|
||||
|
||||
def unsqueeze_expand_at(x, size, dim):
    """
    Insert a new dimension at @dim and expand it to length @size, for every
    tensor nested inside a dict, list, or tuple.

    Args:
        x (dict or list or tuple): a possibly nested dictionary or list or tuple
        size (int): size to expand
        dim (int): dimension to unsqueeze and expand

    Returns:
        y (dict or list or tuple): new nested dict-list-tuple
    """
    unsqueezed = unsqueeze(x, dim)
    return expand_at(unsqueezed, size, dim)
|
||||
|
||||
|
||||
def repeat_by_expand_at(x, repeats, dim):
    """
    Repeat dimension @dim @repeats times by combining an unsqueeze-expand
    with a reshape that merges the new dimension back in.

    Args:
        x (dict or list or tuple): a possibly nested dictionary or list or tuple
        repeats (int): number of times to repeat the target dimension
        dim (int): dimension to repeat on

    Returns:
        y (dict or list or tuple): new nested dict-list-tuple
    """
    expanded = unsqueeze_expand_at(x, repeats, dim + 1)
    # fold the inserted repeat dimension back into @dim
    return join_dimensions(expanded, dim, dim + 1)
|
||||
|
||||
|
||||
def named_reduce_single(x, reduction, dim):
    """
    Reduce a tensor along @dim using a named reduction.

    Args:
        x (torch.Tensor): tensor to be reduced
        reduction (str): one of ["sum", "max", "mean", "flatten"]
        dim (int): dimension to be reduced (or begin axis for flatten)

    Returns:
        y (torch.Tensor): reduced tensor
    """
    assert x.ndimension() > dim
    assert reduction in ["sum", "max", "mean", "flatten"]
    if reduction == "flatten":
        return flatten(x, begin_axis=dim)
    if reduction == "max":
        # torch.max over a dim returns (values, indices); keep values
        return torch.max(x, dim=dim)[0]  # [B, D]
    if reduction == "sum":
        return torch.sum(x, dim=dim)
    return torch.mean(x, dim=dim)
|
||||
|
||||
|
||||
def named_reduce(x, reduction, dim):
    """
    Reduce every tensor nested inside a dict, list, or tuple along @dim using
    a named reduction function.

    Args:
        x (dict or list or tuple): a possibly nested dictionary or list or tuple
        reduction (str): one of ["sum", "max", "mean", "flatten"]
        dim (int): dimension to be reduced (or begin axis for flatten)

    Returns:
        y (dict or list or tuple): new nested dict-list-tuple
    """

    def _reduce(t, r=reduction, d=dim):
        return named_reduce_single(t, r, d)

    return map_tensor(x, func=_reduce)
|
||||
|
||||
|
||||
def gather_along_dim_with_dim_single(x, target_dim, source_dim, indices):
    """
    Index out @target_dim of a tensor in a structured way: for each position
    along @source_dim, the corresponding entry of the flat @indices tensor
    selects the value taken from @target_dim (for all remaining dimensions).
    A common use case is to gather values in target dimension 1 for each
    batch member (target dimension 0).

    Args:
        x (torch.Tensor): tensor to gather values for
        target_dim (int): dimension to gather values along
        source_dim (int): dimension to hold constant and use for gathering
            values from the other dimensions
        indices (torch.Tensor): flat index tensor with same shape as tensor
            @x along @source_dim

    Returns:
        y (torch.Tensor): gathered tensor, with dimension @target_dim indexed out
    """
    assert len(indices.shape) == 1
    assert x.shape[source_dim] == indices.shape[0]

    # view indices as singleton in every dimension except source_dim
    view_shape = [1] * x.ndimension()
    view_shape[source_dim] = -1
    idx = indices.reshape(*view_shape)

    # broadcast to x's shape, but keep a singleton at target_dim so that
    # gather picks exactly one entry along it
    expand_shape = list(x.shape)
    expand_shape[source_dim] = -1
    expand_shape[target_dim] = 1
    idx = idx.expand(*expand_shape)

    gathered = x.gather(dim=target_dim, index=idx)
    return gathered.squeeze(target_dim)
|
||||
|
||||
|
||||
def gather_along_dim_with_dim(x, target_dim, source_dim, indices):
    """
    Apply @gather_along_dim_with_dim_single to every tensor nested inside a
    dict, list, or tuple.

    Args:
        x (dict or list or tuple): a possibly nested dictionary or list or tuple
        target_dim (int): dimension to gather values along
        source_dim (int): dimension to hold constant and use for gathering
            values from the other dimensions
        indices (torch.Tensor): flat index tensor with same shape as tensor
            @x along @source_dim

    Returns:
        y (dict or list or tuple): new nested dict-list-tuple
    """

    def _gather(t, td=target_dim, sd=source_dim, i=indices):
        return gather_along_dim_with_dim_single(t, td, sd, i)

    return map_tensor(x, _gather)
|
||||
|
||||
|
||||
def gather_sequence_single(seq, indices):
    """
    Given a tensor with leading dimensions [B, T, ...], pick one timestep per
    batch member using a per-sequence index.

    Args:
        seq (torch.Tensor): tensor with leading dimensions [B, T, ...]
        indices (torch.Tensor): tensor indices of shape [B]

    Return:
        y (torch.Tensor): indexed tensor of shape [B, ....]
    """
    # gather along time (dim 1), one index per batch element (dim 0)
    return gather_along_dim_with_dim_single(seq,
                                            target_dim=1,
                                            source_dim=0,
                                            indices=indices)
|
||||
|
||||
|
||||
def gather_sequence(seq, indices):
    """
    Given a nested dict, list, or tuple of tensors with leading dimensions
    [B, T, ...], pick one timestep per batch member using a per-sequence
    index.

    Args:
        seq (dict or list or tuple): a possibly nested dictionary or list or
            tuple with tensors of leading dimensions [B, T, ...]
        indices (torch.Tensor): tensor indices of shape [B]

    Returns:
        y (dict or list or tuple): new nested dict-list-tuple with tensors of
            shape [B, ...]
    """
    # gather along time (dim 1), one index per batch element (dim 0)
    return gather_along_dim_with_dim(seq,
                                     target_dim=1,
                                     source_dim=0,
                                     indices=indices)
|
||||
|
||||
|
||||
def pad_sequence_single(seq,
                        padding,
                        batched=False,
                        pad_same=True,
                        pad_values=None):
    """
    Pad input tensor or array @seq in the time dimension (dimension 1).

    Args:
        seq (np.ndarray or torch.Tensor): sequence to be padded
        padding (tuple): begin and end padding, e.g. [1, 1] pads both begin
            and end of the sequence by 1
        batched (bool): if sequence has the batch dimension
        pad_same (bool): if pad by duplicating the edge elements
        pad_values (int or float): scalar value to pad with when not @pad_same

    Returns:
        padded sequence (np.ndarray or torch.Tensor)
    """
    assert isinstance(seq, (np.ndarray, torch.Tensor))
    assert pad_same or pad_values is not None
    if pad_values is not None:
        # Fix: the original asserted `isinstance(pad_values, float)`, which
        # rejected integer scalars such as 0 despite the documented contract
        # of accepting a scalar. Accept any real scalar.
        assert isinstance(pad_values, (int, float))

    # pick numpy or torch implementations of repeat/concat/ones_like once
    is_np = isinstance(seq, np.ndarray)
    repeat_func = np.repeat if is_np else torch.repeat_interleave
    concat_func = np.concatenate if is_np else torch.cat
    ones_like_func = np.ones_like if is_np else torch.ones_like
    seq_dim = 1 if batched else 0

    begin_pad = []
    end_pad = []

    # NOTE(review): the edge slices `seq[[0]]` / `seq[[-1]]` are taken along
    # dimension 0 even when @batched is True (while repeat/concat use dim 1)
    # — confirm batched usage against callers before relying on it.
    if padding[0] > 0:
        pad = seq[[0]] if pad_same else ones_like_func(seq[[0]]) * pad_values
        begin_pad.append(repeat_func(pad, padding[0], seq_dim))
    if padding[1] > 0:
        pad = seq[[-1]] if pad_same else ones_like_func(seq[[-1]]) * pad_values
        end_pad.append(repeat_func(pad, padding[1], seq_dim))

    return concat_func(begin_pad + [seq] + end_pad, seq_dim)
|
||||
|
||||
|
||||
def pad_sequence(seq, padding, batched=False, pad_same=True, pad_values=None):
    """
    Pad every sequence tensor nested inside a dict, list, or tuple in the
    time dimension (dimension 1).

    Args:
        seq (dict or list or tuple): a possibly nested dictionary or list or
            tuple with tensors of leading dimensions [B, T, ...]
        padding (tuple): begin and end padding, e.g. [1, 1] pads both begin
            and end of the sequence by 1
        batched (bool): if sequence has the batch dimension
        pad_same (bool): if pad by duplicating
        pad_values (scalar or (ndarray, Tensor)): values to be padded if not pad_same

    Returns:
        padded sequence (dict or list or tuple)
    """

    # one handler shared by both tensor and ndarray leaves
    def _pad(t, p=padding, b=batched, ps=pad_same, pv=pad_values):
        return pad_sequence_single(t, p, b, ps, pv)

    return recursive_dict_list_tuple_apply(
        seq, {
            torch.Tensor: _pad,
            np.ndarray: _pad,
            type(None): lambda v: v,
        })
|
||||
|
||||
|
||||
def assert_size_at_dim_single(x, size, dim, msg):
    """
    Ensure that array or tensor @x has size @size in dimension @dim.

    Args:
        x (np.ndarray or torch.Tensor): input array or tensor
        size (int): size that tensors should have at @dim
        dim (int): dimension to check
        msg (str): text to display if assertion fails
    """
    actual = x.shape[dim]
    assert actual == size, msg
|
||||
|
||||
|
||||
def assert_size_at_dim(x, size, dim, msg):
    """
    Ensure that every array or tensor nested inside a dict, list, or tuple
    has size @size in dimension @dim.

    Args:
        x (dict or list or tuple): a possibly nested dictionary or list or tuple
        size (int): size that tensors should have at @dim
        dim (int): dimension to check
        msg (str): text to display if assertion fails
    """

    def _check(t, s=size, d=dim, m=msg):
        return assert_size_at_dim_single(t, s, d, m)

    map_tensor(x, _check)
|
||||
|
||||
|
||||
def get_shape(x):
    """
    Get the shapes of all arrays and tensors nested inside a dict, list, or
    tuple.

    Args:
        x (dict or list or tuple): a possibly nested dictionary or list or tuple

    Returns:
        y (dict or list or tuple): new nested dict-list-tuple that contains
            each array or tensor's shape
    """
    type_func_dict = {
        torch.Tensor: lambda t: t.shape,
        np.ndarray: lambda a: a.shape,
        type(None): lambda v: v,
    }
    return recursive_dict_list_tuple_apply(x, type_func_dict)
|
||||
|
||||
|
||||
def list_of_flat_dict_to_dict_of_list(list_of_dict):
    """
    Helper function to go from a list of flat dictionaries to a dictionary of
    lists. By "flat" we mean that none of the values are dictionaries, but
    are numpy arrays, floats, etc.

    Keys are inserted in first-seen order; a key missing from some entries
    simply yields a shorter list for that key.

    Args:
        list_of_dict (list): list of flat dictionaries

    Returns:
        dict_of_list (dict): dictionary of lists
    """
    assert isinstance(list_of_dict, list)
    dic = collections.OrderedDict()
    # idiomatic direct iteration instead of the original index loop;
    # setdefault creates each key's list on first encounter
    for entry in list_of_dict:
        for k, v in entry.items():
            dic.setdefault(k, []).append(v)
    return dic
|
||||
|
||||
|
||||
def flatten_nested_dict_list(d, parent_key='', sep='_', item_key=''):
    """
    Flatten a nested dict or list to a list of (key, value) tuples.

    For example, given a dict
    {
        a: 1
        b: {
            c: 2
        }
        c: 3
    }

    the function would return [(a, 1), (b_c, 2), (c, 3)]

    Args:
        d (dict, list): a nested dict or list to be flattened
        parent_key (str): recursion helper
        sep (str): separator for nesting keys
        item_key (str): recursion helper
    Returns:
        list: a list of (key, value) tuples
    """
    # the joined key is the same in every branch, so compute it once
    new_key = parent_key + sep + item_key if parent_key else item_key

    if isinstance(d, (tuple, list)):
        items = []
        for i, v in enumerate(d):
            items.extend(
                flatten_nested_dict_list(v, new_key, sep=sep, item_key=str(i)))
        return items

    if isinstance(d, dict):
        items = []
        for k, v in d.items():
            assert isinstance(k, str)
            items.extend(
                flatten_nested_dict_list(v, new_key, sep=sep, item_key=k))
        return items

    # leaf value
    return [(new_key, d)]
|
||||
|
||||
|
||||
def time_distributed(inputs,
                     op,
                     activation=None,
                     inputs_as_kwargs=False,
                     inputs_as_args=False,
                     **kwargs):
    """
    Apply @op to all tensors nested inside @inputs over both the batch (B)
    and time (T) dimensions, where tensors are expected to have shape
    [B, T, ...]. Tensors are reshaped to [B * T, ...], passed through the op,
    and the outputs reshaped back to [B, T, ...].

    Args:
        inputs (list or tuple or dict): a possibly nested dictionary or list
            or tuple with tensors of leading dimensions [B, T, ...]
        op: a layer op that accepts inputs
        activation: activation to apply at the output
        inputs_as_kwargs (bool): whether to feed input as a kwargs dict to the op
        inputs_as_args (bool): whether to feed input as an args list to the op
        kwargs (dict): other kwargs to supply to the op

    Returns:
        outputs (dict or list or tuple): new nested dict-list-tuple with
            tensors of leading dimension [B, T].
    """
    # read B and T off the first tensor found in the nested structure
    batch_size, seq_len = flatten_nested_dict_list(inputs)[0][1].shape[:2]

    # merge batch and time: [B, T, ...] -> [B*T, ...]
    merged = join_dimensions(inputs, 0, 1)
    if inputs_as_kwargs:
        outputs = op(**merged, **kwargs)
    elif inputs_as_args:
        outputs = op(*merged, **kwargs)
    else:
        outputs = op(merged, **kwargs)

    if activation is not None:
        outputs = map_tensor(outputs, activation)

    # split batch and time back out: [B*T, ...] -> [B, T, ...]
    return reshape_dimensions(outputs,
                              begin_axis=0,
                              end_axis=0,
                              target_dims=(batch_size, seq_len))
|
||||
701
src/unifolm_wma/models/diffusion_head/conditional_unet1d.py
Normal file
701
src/unifolm_wma/models/diffusion_head/conditional_unet1d.py
Normal file
@@ -0,0 +1,701 @@
|
||||
import logging
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import einops
|
||||
|
||||
from einops import rearrange, repeat
|
||||
from typing import Union
|
||||
|
||||
from unifolm_wma.models.diffusion_head.conv1d_components import (
|
||||
Downsample1d, Upsample1d, Conv1dBlock)
|
||||
from unifolm_wma.models.diffusion_head.positional_embedding import SinusoidalPosEmb
|
||||
from unifolm_wma.models.diffusion_head.base_nets import SpatialSoftmax
|
||||
|
||||
from unifolm_wma.utils.basics import zero_module
|
||||
from unifolm_wma.utils.common import (
|
||||
checkpoint,
|
||||
exists,
|
||||
default,
|
||||
)
|
||||
from unifolm_wma.utils.utils import instantiate_from_config
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class GEGLU(nn.Module):
    """Gated-GELU feed-forward projection (https://arxiv.org/abs/2002.05202).

    Projects the input to twice the output width, splits it into a value half
    and a gate half, and returns value * GELU(gate).
    """

    def __init__(self, dim_in, dim_out):
        super().__init__()
        # single linear producing both the value and the gate halves
        self.proj = nn.Linear(dim_in, dim_out * 2)

    def forward(self, x):
        x, gate = self.proj(x).chunk(2, dim=-1)
        # Fix: the original called `F.gelu`, but `torch.nn.functional` is
        # never imported as `F` in this file (NameError at runtime). Use the
        # functional namespace reachable through the existing `nn` import.
        return x * nn.functional.gelu(gate)
|
||||
|
||||
|
||||
class FeedForward(nn.Module):
    """Transformer feed-forward block: project up by @mult, non-linearity
    (GELU or gated-GELU), dropout, then project back to @dim_out
    (defaults to @dim)."""

    def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.):
        super().__init__()
        inner_dim = int(dim * mult)
        dim_out = default(dim_out, dim)

        # choose the up-projection flavour: plain Linear+GELU or gated GELU
        if glu:
            project_in = GEGLU(dim, inner_dim)
        else:
            project_in = nn.Sequential(nn.Linear(dim, inner_dim), nn.GELU())

        self.net = nn.Sequential(
            project_in,
            nn.Dropout(dropout),
            nn.Linear(inner_dim, dim_out),
        )

    def forward(self, x):
        return self.net(x)
|
||||
|
||||
|
||||
class CrossAttention(nn.Module):
    """Multi-head cross-attention (self-attention when no context is given).

    NOTE(review): `efficient_forward` uses `xformers.ops`, but `xformers` is
    not imported in this file's visible import block — confirm it is imported
    elsewhere or this path will raise NameError. Also, no plain `forward` is
    defined in this class here; presumably one is attached or selected
    elsewhere — verify against callers.
    """

    def __init__(self,
                 query_dim,
                 context_dim=None,
                 heads=8,
                 dim_head=64,
                 dropout=0.,
                 relative_position=False):
        # relative_position is accepted but unused in the visible code
        super().__init__()
        inner_dim = dim_head * heads
        # context defaults to the query stream (self-attention)
        context_dim = default(context_dim, query_dim)

        self.scale = dim_head**-0.5
        self.heads = heads
        self.dim_head = dim_head

        self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
        self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
        self.to_v = nn.Linear(context_dim, inner_dim, bias=False)

        self.to_out = nn.Sequential(nn.Linear(inner_dim, query_dim),
                                    nn.Dropout(dropout))

    def efficient_forward(self, x, context=None):
        """Memory-efficient attention via xformers.

        x: (b, n, query_dim); context: (b, m, context_dim) or None for
        self-attention. Returns (b, n, query_dim).
        """
        spatial_self_attn = (context is None)
        # k_ip/v_ip/out_ip are initialized but never used in this method
        k_ip, v_ip, out_ip = None, None, None

        q = self.to_q(x)
        if spatial_self_attn:
            context = default(context, x)
        k = self.to_k(context)
        v = self.to_v(context)

        b, _, _ = q.shape
        # split heads: (b, n, h*d) -> (b*h, n, d) for xformers
        q, k, v = map(
            lambda t: t.unsqueeze(3).reshape(b, t.shape[
                1], self.heads, self.dim_head).permute(0, 2, 1, 3).reshape(
                    b * self.heads, t.shape[1], self.dim_head).contiguous(),
            (q, k, v),
        )
        # actually compute the attention, what we cannot get enough of
        out = xformers.ops.memory_efficient_attention(q,
                                                      k,
                                                      v,
                                                      attn_bias=None,
                                                      op=None)
        # merge heads back: (b*h, n, d) -> (b, n, h*d)
        out = (out.unsqueeze(0).reshape(
            b, self.heads, out.shape[1],
            self.dim_head).permute(0, 2, 1,
                                   3).reshape(b, out.shape[1],
                                              self.heads * self.dim_head))
        return self.to_out(out)
|
||||
|
||||
|
||||
class BasicTransformerBlock(nn.Module):
    """Standard transformer block: self-attention, cross-attention to an
    optional context, then a feed-forward layer, each with a pre-LayerNorm
    and a residual connection."""

    def __init__(self,
                 dim,
                 n_heads,
                 d_head,
                 dropout=0.,
                 context_dim=None,
                 gated_ff=True,
                 checkpoint=True,
                 disable_self_attn=False,
                 attention_cls=None):
        super().__init__()
        attn_cls = CrossAttention if attention_cls is None else attention_cls
        self.disable_self_attn = disable_self_attn
        # attn1 is self-attention unless disable_self_attn, in which case it
        # also attends to the context
        self.attn1 = attn_cls(
            query_dim=dim,
            heads=n_heads,
            dim_head=d_head,
            dropout=dropout,
            context_dim=context_dim if self.disable_self_attn else None)
        self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff)
        # attn2 is cross-attention to the context (self-attention if context
        # is None at call time)
        self.attn2 = attn_cls(query_dim=dim,
                              context_dim=context_dim,
                              heads=n_heads,
                              dim_head=d_head,
                              dropout=dropout)

        self.norm1 = nn.LayerNorm(dim)
        self.norm2 = nn.LayerNorm(dim)
        self.norm3 = nn.LayerNorm(dim)
        # whether to use gradient checkpointing in forward
        self.checkpoint = checkpoint

    def forward(self, x, context=None, **kwargs):
        ## implementation tricks: because checkpointing doesn't support non-tensor (e.g. None or scalar) arguments
        input_tuple = (
            x,
        ) ## should not be (x), otherwise *input_tuple will decouple x into multiple arguments
        if context is not None:
            input_tuple = (x, context)
        # `checkpoint` (project helper) re-runs _forward during backward to
        # trade compute for memory when self.checkpoint is True
        return checkpoint(self._forward, input_tuple, self.parameters(),
                          self.checkpoint)

    def _forward(self, x, context=None, mask=None):
        # pre-norm residual sublayers; attn1 sees the context only when
        # self-attention is disabled
        x = self.attn1(self.norm1(x),
                       context=context if self.disable_self_attn else None,
                       mask=mask) + x
        x = self.attn2(self.norm2(x), context=context, mask=mask) + x
        x = self.ff(self.norm3(x)) + x
        return x
|
||||
|
||||
|
||||
class ActionLatentImageCrossAttention(nn.Module):
    """Cross-attention block that lets an action token sequence attend to a
    flattened video/image latent, with a residual connection around the
    transformer stack."""

    def __init__(self,
                 in_channels,
                 in_dim,
                 n_heads,
                 d_head,
                 depth=1,
                 dropout=0.,
                 context_dim=None,
                 use_checkpoint=True,
                 disable_self_attn=False,
                 use_linear=True):
        super().__init__()
        """
        in_channels: action input dim

        """
        self.in_channels = in_channels
        self.in_dim = in_dim
        inner_dim = n_heads * d_head
        # GroupNorm over the action-channel dimension before projection
        self.norm = torch.nn.GroupNorm(num_groups=8,
                                       num_channels=in_channels,
                                       eps=1e-6,
                                       affine=True)

        # project action tokens and conditioning context into the shared
        # attention width (inner_dim)
        self.proj_in_action = nn.Linear(in_dim, inner_dim)
        self.proj_in_cond = nn.Linear(context_dim, inner_dim)
        # zero-initialized output projection so the block starts as identity
        # (residual passes through unchanged at init)
        self.proj_out = zero_module(nn.Linear(inner_dim, in_dim))
        self.use_linear = use_linear

        attention_cls = None
        self.transformer_blocks = nn.ModuleList([
            BasicTransformerBlock(inner_dim,
                                  n_heads,
                                  d_head,
                                  dropout=dropout,
                                  context_dim=context_dim,
                                  disable_self_attn=disable_self_attn,
                                  checkpoint=use_checkpoint,
                                  attention_cls=attention_cls)
            for d in range(depth)
        ])

    def forward(self, x, context=None, **kwargs):
        # x: (batch, action_tokens, in_dim); context: (b, t, c, h, w) video
        # latent that is flattened into a token sequence for cross-attention
        ba, ca, da = x.shape
        b, t, c, h, w = context.shape
        context = rearrange(context, 'b t c h w -> b (t h w) c').contiguous()

        x_in = x
        x = self.norm(x)  # ba x ja x d_in
        if self.use_linear:
            x = self.proj_in_action(x)
            context = self.proj_in_cond(context)
        for i, block in enumerate(self.transformer_blocks):
            x = block(x, context=context, **kwargs)
        if self.use_linear:
            x = self.proj_out(x)
        # residual connection around the whole attention stack
        return x + x_in
|
||||
|
||||
|
||||
class ConditionalResidualBlock1D(nn.Module):
    """1D residual block with FiLM-style conditioning: two Conv1d blocks,
    with the first block's output modulated by a per-channel scale and bias
    predicted from the conditioning vector."""

    def __init__(self,
                 in_channels,
                 out_channels,
                 cond_dim,
                 kernel_size=3,
                 n_groups=8,
                 cond_predict_scale=True,
                 use_linear_act_proj=False):
        super().__init__()

        self.blocks = nn.ModuleList([
            Conv1dBlock(in_channels,
                        out_channels,
                        kernel_size,
                        n_groups=n_groups),
            Conv1dBlock(out_channels,
                        out_channels,
                        kernel_size,
                        n_groups=n_groups),
        ])

        self.cond_predict_scale = cond_predict_scale
        self.use_linear_act_proj = use_linear_act_proj
        self.out_channels = out_channels
        # FiLM modulation https://arxiv.org/abs/1709.07871
        # predicts per-channel scale and bias
        cond_channels = out_channels
        # NOTE(review): the doubled width is gated on BOTH flags, but forward
        # reshapes to (-1, 2, out_channels, 1) whenever cond_predict_scale is
        # True — with cond_predict_scale=True and use_linear_act_proj=False
        # the encoder emits only out_channels features and that reshape would
        # fail. Confirm which flag combinations are actually used.
        if cond_predict_scale and use_linear_act_proj:
            cond_channels = out_channels * 2
        self.cond_encoder = nn.Sequential(
            nn.Mish(),
            nn.Linear(cond_dim, cond_channels),
        )
        # make sure dimensions compatible
        self.residual_conv = nn.Conv1d(in_channels, out_channels, 1) \
            if in_channels != out_channels else nn.Identity()

    def forward(self, x, cond=None):
        '''
        x : [ batch_size x in_channels x horizon ]
        cond : [ batch_size x cond_dim]

        returns:
            out : [ batch_size x out_channels x horizon ]
        '''
        # NOTE(review): when cond_predict_scale is False, `cond` is never
        # applied (the additive path below is commented out) — confirm this
        # is intentional.
        B, T, _ = cond.shape

        out = self.blocks[0](x)
        if self.cond_predict_scale:
            embed = self.cond_encoder(cond)
            if self.use_linear_act_proj:
                embed = embed.reshape(B * T, -1)
                embed = embed.reshape(-1, 2, self.out_channels, 1)
            else:
                embed = embed.reshape(embed.shape[0], 2, self.out_channels, 1)
            # FiLM: modulate first conv output with predicted scale and bias
            scale = embed[:, 0, ...]
            bias = embed[:, 1, ...]
            out = scale * out + bias
        # else:
        # out = out + embed
        out = self.blocks[1](out)
        # residual connection (1x1 conv matches channel counts when needed)
        out = out + self.residual_conv(x)
        return out
|
||||
|
||||
|
||||
class ConditionalUnet1D(nn.Module):
    """1-D conditional UNet diffusion head for action prediction.

    Denoises an action sequence conditioned on (a) the diffusion timestep,
    (b) encoded observations (images + agent state via ``obs_encoder``), and
    (c) intermediate feature maps from a video-generation UNet
    (``imagen_cond``), which are compressed to keypoints with SpatialSoftmax
    blocks before being concatenated into the FiLM conditioning vector.
    """

    def __init__(self,
                 input_dim,
                 n_obs_steps=1,
                 local_cond_dim=None,
                 global_cond_dim=None,
                 diffusion_step_embed_dim=256,
                 down_dims=[256, 512, 1024],
                 kernel_size=3,
                 n_groups=8,
                 cond_predict_scale=False,
                 horizon=16,
                 num_head_channels=64,
                 use_linear_attn=True,
                 use_linear_act_proj=True,
                 act_proj_dim=32,
                 cond_cross_attention=False,
                 context_dims=None,
                 image_size=None,
                 imagen_cond_gradient=False,
                 last_frame_only=False,
                 use_imagen_mid_only=False,
                 use_z_only=False,
                 spatial_num_kp=32,
                 use_z_only_kwarg_placeholder=None) if False else None
    # NOTE: the placeholder line above never executes; see real signature below.

    def __init__(self,
                 input_dim,
                 n_obs_steps=1,
                 local_cond_dim=None,
                 global_cond_dim=None,
                 diffusion_step_embed_dim=256,
                 down_dims=[256, 512, 1024],
                 kernel_size=3,
                 n_groups=8,
                 cond_predict_scale=False,
                 horizon=16,
                 num_head_channels=64,
                 use_linear_attn=True,
                 use_linear_act_proj=True,
                 act_proj_dim=32,
                 cond_cross_attention=False,
                 context_dims=None,
                 image_size=None,
                 imagen_cond_gradient=False,
                 last_frame_only=False,
                 use_imagen_mid_only=False,
                 use_z_only=False,
                 spatial_num_kp=32,
                 obs_encoder_config=None):
        """Build the down/mid/up module stacks.

        Args:
            input_dim: dimensionality of a single action vector.
            n_obs_steps: number of observation steps consumed by obs_encoder.
            local_cond_dim, global_cond_dim: accepted for API compatibility
                but unused in this implementation (local_cond_encoder stays
                None).  NOTE(review): confirm these can be dropped.
            diffusion_step_embed_dim: width of the sinusoidal timestep MLP.
            down_dims: channel widths per UNet level.
                NOTE(review): mutable default argument — safe only because it
                is never mutated here.
            cond_predict_scale: enable FiLM scale+bias prediction in the
                residual blocks.
            horizon: action sequence length T.
            context_dims: per-level channel counts of the video-UNet feature
                maps in imagen_cond.
            image_size: (h, w) of the video-UNet feature maps at full
                resolution; required unless use_z_only (then only for z).
            use_z_only: condition on the raw video latent (kwargs['x_start'])
                instead of intermediate UNet features.
            obs_encoder_config: omegaconf-style config instantiated into the
                observation encoder.
        """
        super().__init__()

        self.n_obs_steps = n_obs_steps
        self.obs_encoder = instantiate_from_config(obs_encoder_config)

        all_dims = [input_dim] + list(down_dims)
        start_dim = down_dims[0]

        # Sinusoidal timestep embedding followed by a small MLP.
        dsed = diffusion_step_embed_dim
        diffusion_step_encoder = nn.Sequential(
            SinusoidalPosEmb(dsed),
            nn.Linear(dsed, dsed * 4),
            nn.Mish(),
            nn.Linear(dsed * 4, dsed),
        )
        # Conditioning = timestep embedding + flattened observation features.
        cond_dim = dsed + self.obs_encoder.output_shape()[-1] * self.n_obs_steps
        in_out = list(zip(all_dims[:-1], all_dims[1:]))
        local_cond_encoder = None
        down_modules = nn.ModuleList([])

        # dim_a tracks the temporal length at each level, used by the optional
        # cross-attention blocks.
        dim_a_list = []
        for ind, (dim_in, dim_out) in enumerate(in_out):
            is_last = ind >= (len(in_out) - 1)
            if ind == 0:
                dim_a = horizon
            else:
                # NOTE(review): `horizon // 2 * ind` grows linearly, not by
                # successive halving — confirm this matches the actual
                # downsampled lengths.
                dim_a = horizon // 2 * ind
            dim_a_list.append(dim_a)

            # for attention
            num_heads = dim_out // num_head_channels
            dim_head = num_head_channels
            # Per-level conditioning width: SpatialSoftmax yields 2 coords
            # per keypoint, hence the factor of 2.
            if use_linear_act_proj:
                if use_imagen_mid_only:
                    cur_cond_dim = cond_dim + 2 * context_dims[-1]
                elif use_z_only:
                    cur_cond_dim = cond_dim + 2 * spatial_num_kp
                else:
                    cur_cond_dim = cond_dim + 2 * context_dims[ind]
            else:
                cur_cond_dim = cond_dim + horizon * context_dims[ind]

            down_modules.append(
                nn.ModuleList([
                    ConditionalResidualBlock1D(
                        dim_in,
                        dim_out,
                        cond_dim=cur_cond_dim,
                        kernel_size=kernel_size,
                        n_groups=n_groups,
                        cond_predict_scale=cond_predict_scale,
                        use_linear_act_proj=use_linear_act_proj),
                    ConditionalResidualBlock1D(
                        dim_out,
                        dim_out,
                        cond_dim=cur_cond_dim,
                        kernel_size=kernel_size,
                        n_groups=n_groups,
                        cond_predict_scale=cond_predict_scale,
                        use_linear_act_proj=use_linear_act_proj),
                    ActionLatentImageCrossAttention(
                        dim_out,
                        dim_a,
                        num_heads,
                        dim_head,
                        context_dim=context_dims[ind],
                        use_linear=use_linear_attn)
                    if cond_cross_attention else nn.Identity(),
                    Downsample1d(dim_out) if not is_last else nn.Identity()
                ]))

        # Mid blocks reuse num_heads/dim_head/cur_cond_dim leaked from the
        # last loop iteration above (i.e. the deepest level's values).
        mid_dim = all_dims[-1]
        self.mid_modules = nn.ModuleList([
            ConditionalResidualBlock1D(
                mid_dim,
                mid_dim,
                cond_dim=cur_cond_dim,
                kernel_size=kernel_size,
                n_groups=n_groups,
                cond_predict_scale=cond_predict_scale,
                use_linear_act_proj=use_linear_act_proj),
            ConditionalResidualBlock1D(
                mid_dim,
                mid_dim,
                cond_dim=cur_cond_dim,
                kernel_size=kernel_size,
                n_groups=n_groups,
                cond_predict_scale=cond_predict_scale,
                use_linear_act_proj=use_linear_act_proj),
            ActionLatentImageCrossAttention(mid_dim,
                                            dim_a_list[-1],
                                            num_heads,
                                            dim_head,
                                            context_dim=context_dims[-1],
                                            use_linear=use_linear_attn)
            if cond_cross_attention else nn.Identity(),
        ])

        up_modules = nn.ModuleList([])
        # Reverse context_dims so indexing matches the upsampling order.
        context_dims = context_dims[::-1]
        for ind, (dim_in, dim_out) in enumerate(
                reversed(in_out[1:] + [(down_dims[-1], down_dims[-1])])):
            is_last = ind >= (len(in_out) - 1)
            if use_linear_act_proj:
                if use_imagen_mid_only:
                    cur_cond_dim = cond_dim + 2 * context_dims[0]
                elif use_z_only:
                    cur_cond_dim = cond_dim + 2 * spatial_num_kp
                else:
                    cur_cond_dim = cond_dim + 2 * context_dims[ind]
            else:
                cur_cond_dim = cond_dim + horizon * context_dims[ind]
            up_modules.append(
                nn.ModuleList([
                    ConditionalResidualBlock1D(
                        dim_out + dim_in,  # skip connection doubles channels
                        dim_in,
                        cond_dim=cur_cond_dim,
                        kernel_size=kernel_size,
                        n_groups=n_groups,
                        cond_predict_scale=cond_predict_scale,
                        use_linear_act_proj=use_linear_act_proj),
                    ConditionalResidualBlock1D(
                        dim_in,
                        dim_in,
                        cond_dim=cur_cond_dim,
                        kernel_size=kernel_size,
                        n_groups=n_groups,
                        cond_predict_scale=cond_predict_scale,
                        use_linear_act_proj=use_linear_act_proj),
                    ActionLatentImageCrossAttention(
                        dim_in,
                        dim_a_list.pop(),
                        num_heads,
                        dim_head,
                        context_dim=context_dims[ind],
                        use_linear=use_linear_attn)
                    if cond_cross_attention else nn.Identity(),
                    Upsample1d(dim_in) if not is_last else nn.Identity()
                ]))

        final_conv = nn.Sequential(
            Conv1dBlock(start_dim, start_dim, kernel_size=kernel_size),
            nn.Conv1d(start_dim, input_dim, 1),
        )

        # SpatialSoftmax blocks compress the video-UNet feature maps into
        # keypoint coordinates per level.
        if use_z_only:
            h, w = image_size
            self.spatial_softmax_blocks = nn.ModuleList(
                [SpatialSoftmax((4, h, w), spatial_num_kp)])
        else:
            self.spatial_softmax_blocks = nn.ModuleList([])
            # Undo the earlier reversal so indices align with down levels.
            context_dims = context_dims[::-1]
            for ind, context_dim in enumerate(context_dims):
                h, w = image_size
                if ind != 0:
                    # Feature maps shrink by 2x per level.
                    h //= 2**ind
                    w //= 2**ind
                net = SpatialSoftmax((context_dim, h, w), context_dim)
                self.spatial_softmax_blocks.append(net)
            # One extra block for the mid level (reuses the deepest net),
            # then mirrored copies for the up path — presumably matching the
            # 4-down/1-mid/4-up usage in forward(); TODO confirm for other
            # depths than 4.
            self.spatial_softmax_blocks.append(net)
            self.spatial_softmax_blocks += self.spatial_softmax_blocks[
                0:4][::-1]

        self.diffusion_step_encoder = diffusion_step_encoder
        self.local_cond_encoder = local_cond_encoder
        self.up_modules = up_modules
        self.down_modules = down_modules
        self.final_conv = final_conv

        self.cond_cross_attention = cond_cross_attention
        self.use_linear_act_proj = use_linear_act_proj

        # Per-action-dim projection (scalar -> act_proj_dim) and its inverse,
        # plus horizon-wise projections for the non-linear-act-proj path.
        self.proj_in_action = nn.Sequential(nn.Linear(1, act_proj_dim),
                                            nn.LayerNorm(act_proj_dim))
        self.proj_in_horizon = nn.Sequential(nn.Linear(horizon, act_proj_dim),
                                             nn.LayerNorm(act_proj_dim))
        self.proj_out_action = nn.Sequential(nn.LayerNorm(act_proj_dim),
                                             nn.Linear(act_proj_dim, 1))
        self.proj_out_horizon = nn.Sequential(nn.LayerNorm(act_proj_dim),
                                              nn.Linear(act_proj_dim, horizon))
        logger.info("number of parameters: %e",
                    sum(p.numel() for p in self.parameters()))

        self.imagen_cond_gradient = imagen_cond_gradient
        self.use_imagen_mid_only = use_imagen_mid_only
        self.use_z_only = use_z_only
        self.spatial_num_kp = spatial_num_kp
        self.last_frame_only = last_frame_only
        self.horizon = horizon

    def forward(self,
                sample: torch.Tensor,
                timestep: Union[torch.Tensor, float, int],
                imagen_cond=None,
                cond=None,
                **kwargs):
        """
        sample: (B,T,input_dim)
        timestep: (B,) or int, diffusion step
        imagen_cond: a list of hidden info from video gen unet
        cond: dict:
            image: (B, 3, To, h, w)
            agent_pos: (B, Ta, d)
        output: (B,T,input_dim)

        NOTE(review): the non-use_linear_act_proj branch below references
        `robo_state_cond` (and later `global_cond`) before assignment —
        that path would raise NameError as written; confirm intended code.
        NOTE(review): `crossatten` is unpacked from the module lists but
        never applied in this forward pass.
        """

        # Optionally cut gradients flowing back into the video UNet.
        if not self.imagen_cond_gradient:
            imagen_cond = [c.detach() for c in imagen_cond]

        cond = {'image': cond[0], 'agent_pos': cond[1]}

        # (B, C, To, h, w) -> (B, To, C, h, w) -> (B*To, C, h, w)
        cond['image'] = cond['image'].permute(0, 2, 1, 3,
                                              4)
        cond['image'] = rearrange(cond['image'], 'b t c h w -> (b t) c h w')
        cond['agent_pos'] = rearrange(cond['agent_pos'], 'b t d -> (b t) d')

        B, T, D = sample.shape
        if self.use_linear_act_proj:
            # Lift each scalar action entry to act_proj_dim features.
            sample = self.proj_in_action(sample.unsqueeze(-1))
            global_cond = self.obs_encoder(cond)
            global_cond = rearrange(global_cond,
                                    '(b t) d -> b 1 (t d)',
                                    b=B,
                                    t=self.n_obs_steps)
            # Broadcast the observation features over the horizon.
            global_cond = repeat(global_cond,
                                 'b c d -> b (repeat c) d',
                                 repeat=T)
        else:
            sample = einops.rearrange(sample, 'b h t -> b t h')
            sample = self.proj_in_horizon(sample)
            # NOTE(review): robo_state_cond is undefined here (see docstring).
            robo_state_cond = rearrange(robo_state_cond, 'b t d -> b 1 (t d)')
            robo_state_cond = repeat(robo_state_cond,
                                     'b c d -> b (repeat c) d',
                                     repeat=2)

        # 1. time
        timesteps = timestep
        if not torch.is_tensor(timesteps):
            timesteps = torch.tensor([timesteps],
                                     dtype=torch.long,
                                     device=sample.device)
        elif torch.is_tensor(timesteps) and len(timesteps.shape) == 0:
            timesteps = timesteps[None].to(sample.device)
        # Broadcast to batch dimension in a way that's compatible with ONNX/Core ML
        timesteps = timesteps.expand(sample.shape[0])
        global_feature = self.diffusion_step_encoder(timesteps)
        # Split the video-UNet features into down/mid/up groups; the split
        # points are hard-coded to a 4-level UNet.
        (imagen_cond_down, imagen_cond_mid, imagen_cond_up
         ) = imagen_cond[0:4], imagen_cond[4], imagen_cond[5:]  #NOTE HAND CODE

        # Fold (B, T) into the batch so the 1-D convs act per timestep.
        x = sample if not self.use_linear_act_proj else sample.reshape(
            B * T, D, -1)
        h = []
        for idx, modules in enumerate(self.down_modules):
            if self.cond_cross_attention:
                (resnet, resnet2, crossatten, downsample) = modules
            else:
                (resnet, resnet2, _, downsample) = modules

            # Access the cond from the unet embeds from video unet
            if self.use_imagen_mid_only:
                imagen_cond = imagen_cond_mid
            elif self.use_z_only:
                imagen_cond = kwargs['x_start'].permute(0, 2, 1, 3, 4)
            else:
                imagen_cond = imagen_cond_down[idx]
            if self.last_frame_only:
                # Use only the last predicted frame, tiled over the horizon.
                imagen_cond = imagen_cond[:, -1].unsqueeze(1)
                imagen_cond = repeat(imagen_cond,
                                     'b t c h w -> b (repeat t) c h w',
                                     repeat=self.horizon)
            imagen_cond = rearrange(imagen_cond, 'b t c h w -> (b t) c h w')
            # Compress feature maps to keypoint coordinates.
            if self.use_imagen_mid_only:
                imagen_cond = self.spatial_softmax_blocks[len(
                    self.spatial_softmax_blocks) // 2](imagen_cond)
            elif self.use_z_only:
                imagen_cond = self.spatial_softmax_blocks[0](imagen_cond)
            else:
                imagen_cond = self.spatial_softmax_blocks[idx](imagen_cond)
            imagen_cond = rearrange(imagen_cond, '(b t) c d -> b t c d', b=B)

            if self.use_linear_act_proj:
                imagen_cond = imagen_cond.reshape(B, T, -1)
                cur_global_feature = global_feature.unsqueeze(
                    1).repeat_interleave(repeats=T, dim=1)
            else:
                imagen_cond = imagen_cond.permute(0, 3, 1, 2)
                imagen_cond = imagen_cond.reshape(B, 2, -1)
                cur_global_feature = global_feature.unsqueeze(
                    1).repeat_interleave(repeats=2, dim=1)
            # FiLM conditioning vector: [timestep | observation | video kp].
            cur_global_feature = torch.cat(
                [cur_global_feature, global_cond, imagen_cond], axis=-1)
            x = resnet(x, cur_global_feature)
            x = resnet2(x, cur_global_feature)
            h.append(x)
            x = downsample(x)

        #>>> mide blocks
        resnet, resnet2, _ = self.mid_modules
        # Access the cond from the unet embeds from video unet
        if self.use_z_only:
            imagen_cond = kwargs['x_start'].permute(0, 2, 1, 3, 4)
        else:
            imagen_cond = imagen_cond_mid
        if self.last_frame_only:
            imagen_cond = imagen_cond[:, -1].unsqueeze(1)
            imagen_cond = repeat(imagen_cond,
                                 'b t c h w -> b (repeat t) c h w',
                                 repeat=self.horizon)
        imagen_cond = rearrange(imagen_cond, 'b t c h w -> (b t) c h w')
        # idx leaks from the down loop; +1 selects the mid-level softmax.
        idx += 1
        if self.use_z_only:
            imagen_cond = self.spatial_softmax_blocks[0](imagen_cond)
        else:
            imagen_cond = self.spatial_softmax_blocks[idx](imagen_cond)
        imagen_cond = rearrange(imagen_cond, '(b t) c d -> b t c d', b=B)
        if self.use_linear_act_proj:
            imagen_cond = imagen_cond.reshape(B, T, -1)
            cur_global_feature = global_feature.unsqueeze(1).repeat_interleave(
                repeats=T, dim=1)
        else:
            imagen_cond = imagen_cond.permute(0, 3, 1, 2)
            imagen_cond = imagen_cond.reshape(B, 2, -1)
            cur_global_feature = global_feature.unsqueeze(1).repeat_interleave(
                repeats=2, dim=1)
        cur_global_feature = torch.cat(
            [cur_global_feature, global_cond, imagen_cond], axis=-1)
        x = resnet(x, cur_global_feature)
        x = resnet2(x, cur_global_feature)

        #>>> up blocks
        idx += 1
        for jdx, modules in enumerate(self.up_modules):
            if self.cond_cross_attention:
                (resnet, resnet2, crossatten, upsample) = modules
            else:
                (resnet, resnet2, _, upsample) = modules

            # Access the cond from the unet embeds from video unet
            if self.use_imagen_mid_only:
                imagen_cond = imagen_cond_mid
            elif self.use_z_only:
                imagen_cond = kwargs['x_start'].permute(0, 2, 1, 3, 4)
            else:
                imagen_cond = imagen_cond_up[jdx]
            if self.last_frame_only:
                imagen_cond = imagen_cond[:, -1].unsqueeze(1)
                imagen_cond = repeat(imagen_cond,
                                     'b t c h w -> b (repeat t) c h w',
                                     repeat=self.horizon)
            imagen_cond = rearrange(imagen_cond, 'b t c h w -> (b t) c h w')
            if self.use_imagen_mid_only:
                imagen_cond = self.spatial_softmax_blocks[len(
                    self.spatial_softmax_blocks) // 2](imagen_cond)
            elif self.use_z_only:
                imagen_cond = self.spatial_softmax_blocks[0](imagen_cond)
            else:
                # jdx + idx walks the mirrored tail of spatial_softmax_blocks.
                imagen_cond = self.spatial_softmax_blocks[jdx +
                                                          idx](imagen_cond)
            imagen_cond = rearrange(imagen_cond, '(b t) c d -> b t c d', b=B)

            if self.use_linear_act_proj:
                imagen_cond = imagen_cond.reshape(B, T, -1)
                cur_global_feature = global_feature.unsqueeze(
                    1).repeat_interleave(repeats=T, dim=1)
            else:
                imagen_cond = imagen_cond.permute(0, 3, 1, 2)
                imagen_cond = imagen_cond.reshape(B, 2, -1)
                cur_global_feature = global_feature.unsqueeze(
                    1).repeat_interleave(repeats=2, dim=1)

            cur_global_feature = torch.cat(
                [cur_global_feature, global_cond, imagen_cond], axis=-1)

            # Skip connection from the matching down level.
            x = torch.cat((x, h.pop()), dim=1)
            x = resnet(x, cur_global_feature)
            x = resnet2(x, cur_global_feature)
            x = upsample(x)

        x = self.final_conv(x)

        if self.use_linear_act_proj:
            # Undo the per-action projection back to scalar entries.
            x = x.reshape(B, T, D, -1)
            x = self.proj_out_action(x)
            x = x.reshape(B, T, D)
        else:
            x = self.proj_out_horizon(x)
            x = einops.rearrange(x, 'b t h -> b h t')
        return x
|
||||
52
src/unifolm_wma/models/diffusion_head/conv1d_components.py
Normal file
52
src/unifolm_wma/models/diffusion_head/conv1d_components.py
Normal file
@@ -0,0 +1,52 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
class Downsample1d(nn.Module):
    """Halve the temporal length of a (B, C, T) sequence with a strided conv.

    Channel count is preserved; output length is floor((T - 1) / 2) + 1.
    """

    def __init__(self, dim):
        super().__init__()
        kernel_size, stride, padding = 3, 2, 1
        self.conv = nn.Conv1d(dim, dim, kernel_size, stride, padding)

    def forward(self, x):
        """Apply the strided convolution to ``x`` of shape (B, dim, T)."""
        return self.conv(x)
|
||||
|
||||
|
||||
class Upsample1d(nn.Module):
    """Double the temporal length of a (B, C, T) sequence.

    Uses a transposed convolution (kernel 4, stride 2, padding 1), which
    exactly doubles T while preserving the channel count.
    """

    def __init__(self, dim):
        super().__init__()
        kernel_size, stride, padding = 4, 2, 1
        self.conv = nn.ConvTranspose1d(dim, dim, kernel_size, stride, padding)

    def forward(self, x):
        """Apply the transposed convolution to ``x`` of shape (B, dim, T)."""
        return self.conv(x)
|
||||
|
||||
|
||||
class Conv1dBlock(nn.Module):
    '''
    Conv1d --> GroupNorm --> Mish

    Same-padding 1-D convolution (odd kernel sizes keep the temporal length)
    followed by group normalization and a Mish activation.
    '''

    def __init__(self, inp_channels, out_channels, kernel_size, n_groups=8):
        super().__init__()

        conv = nn.Conv1d(inp_channels,
                         out_channels,
                         kernel_size,
                         padding=kernel_size // 2)
        norm = nn.GroupNorm(n_groups, out_channels)
        act = nn.Mish()
        self.block = nn.Sequential(conv, norm, act)

    def forward(self, x):
        """Map (B, inp_channels, T) to (B, out_channels, T)."""
        return self.block(x)
|
||||
|
||||
|
||||
def test():
    """Smoke-test Conv1dBlock.

    A (1, 256, 16) input through a 256->128 block must keep the horizon
    (same-padding conv) and change only the channel count.  The original
    version computed the output and asserted nothing.
    """
    cb = Conv1dBlock(256, 128, kernel_size=3)
    x = torch.zeros((1, 256, 16))
    o = cb(x)
    assert o.shape == (1, 128, 16)
|
||||
80
src/unifolm_wma/models/diffusion_head/ema_model.py
Normal file
80
src/unifolm_wma/models/diffusion_head/ema_model.py
Normal file
@@ -0,0 +1,80 @@
|
||||
import copy
|
||||
import torch
|
||||
from torch.nn.modules.batchnorm import _BatchNorm
|
||||
|
||||
|
||||
class EMAModel:
    """
    Exponential Moving Average of models weights.

    Holds a frozen copy of the model and blends the live model's parameters
    into it after each optimizer step, using a warmed-up decay schedule.
    """

    def __init__(self,
                 model,
                 update_after_step=0,
                 inv_gamma=1.0,
                 power=2 / 3,
                 min_value=0.0,
                 max_value=0.9999):
        """
        @crowsonkb's notes on EMA Warmup:
            If gamma=1 and power=1, implements a simple average. gamma=1, power=2/3 are good values for models you plan
            to train for a million or more steps (reaches decay factor 0.999 at 31.6K steps, 0.9999 at 1M steps),
            gamma=1, power=3/4 for models you plan to train for less (reaches decay factor 0.999 at 10K steps, 0.9999
            at 215.4k steps).
        Args:
            model: the module to shadow; it is frozen and put in eval mode,
                so pass a copy, not the live training model.
            update_after_step (int): steps to wait before EMA starts.
            inv_gamma (float): Inverse multiplicative factor of EMA warmup. Default: 1.
            power (float): Exponential factor of EMA warmup. Default: 2/3.
            min_value (float): The minimum EMA decay rate. Default: 0.
            max_value (float): The maximum EMA decay rate. Default: 0.9999.
        """
        self.averaged_model = model
        self.averaged_model.eval()
        self.averaged_model.requires_grad_(False)

        self.update_after_step = update_after_step
        self.inv_gamma = inv_gamma
        self.power = power
        self.min_value = min_value
        self.max_value = max_value

        self.decay = 0.0
        self.optimization_step = 0

    def get_decay(self, optimization_step):
        """
        Compute the decay factor for the exponential moving average.

        Returns 0.0 until ``update_after_step`` has passed, then a value
        warmed up toward ``max_value`` and clamped to
        [min_value, max_value].
        """
        step = max(0, optimization_step - self.update_after_step - 1)
        if step <= 0:
            return 0.0

        value = 1 - (1 + step / self.inv_gamma)**-self.power
        return max(self.min_value, min(value, self.max_value))

    @torch.no_grad()
    def step(self, new_model):
        """Blend ``new_model``'s parameters into the averaged model.

        BatchNorm parameters and frozen (requires_grad=False) parameters are
        copied verbatim instead of averaged; everything else gets
        ``ema = decay * ema + (1 - decay) * param``.
        """
        self.decay = self.get_decay(self.optimization_step)

        for module, ema_module in zip(new_model.modules(),
                                      self.averaged_model.modules()):
            # Iterate over immediate parameters only; recursion is handled by
            # walking modules() in lockstep.
            for param, ema_param in zip(module.parameters(recurse=False),
                                        ema_module.parameters(recurse=False)):
                if isinstance(param, dict):
                    raise RuntimeError('Dict parameter not supported')

                if isinstance(module, _BatchNorm):
                    # skip batchnorms
                    ema_param.copy_(param.to(dtype=ema_param.dtype).data)
                elif not param.requires_grad:
                    ema_param.copy_(param.to(dtype=ema_param.dtype).data)
                else:
                    ema_param.mul_(self.decay)
                    ema_param.add_(param.data.to(dtype=ema_param.dtype),
                                   alpha=1 - self.decay)

        self.optimization_step += 1
|
||||
@@ -0,0 +1,19 @@
|
||||
import math
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
|
||||
class SinusoidalPosEmb(nn.Module):
    """Transformer-style sinusoidal embedding for scalar positions.

    Maps a 1-D tensor of positions/timesteps to a (len(x), dim) tensor of
    sin/cos features with log-spaced frequencies from 1 down to 1/10000.
    """

    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, x):
        """Embed ``x`` (shape (N,)) into shape (N, dim)."""
        half = self.dim // 2
        # Frequencies: exp(-log(10000) * k / (half - 1)) for k = 0..half-1.
        scale = math.log(10000) / (half - 1)
        freqs = torch.exp(torch.arange(half, device=x.device) * -scale)
        angles = x[:, None] * freqs[None, :]
        return torch.cat((angles.sin(), angles.cos()), dim=-1)
|
||||
322
src/unifolm_wma/models/diffusion_head/vision/crop_randomizer.py
Normal file
322
src/unifolm_wma/models/diffusion_head/vision/crop_randomizer.py
Normal file
@@ -0,0 +1,322 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torchvision.transforms.functional as ttf
|
||||
import unifolm_wma.models.diffusion_head.common.tensor_util as tu
|
||||
|
||||
|
||||
class CropRandomizer(nn.Module):
    """
    Randomly sample crops at input, and then average across crop features at output.

    Train mode: forward_in samples num_crops random crops per image.
    Eval mode: forward_in takes a deterministic center crop (replicated
    num_crops times so downstream shapes match).
    """

    def __init__(
        self,
        input_shape,
        crop_height,
        crop_width,
        num_crops=1,
        pos_enc=False,
    ):
        """
        Args:
            input_shape (tuple, list): shape of input (not including batch dimension)
            crop_height (int): crop height
            crop_width (int): crop width
            num_crops (int): number of random crops to take
            pos_enc (bool): if True, add 2 channels to the output to encode the spatial
                location of the cropped pixels in the source image
        """
        super().__init__()

        # NOTE(review): asserts are stripped under `python -O`; these are
        # sanity checks, not input validation.
        assert len(input_shape) == 3  # (C, H, W)
        assert crop_height < input_shape[1]
        assert crop_width < input_shape[2]

        self.input_shape = input_shape
        self.crop_height = crop_height
        self.crop_width = crop_width
        self.num_crops = num_crops
        self.pos_enc = pos_enc

    def output_shape_in(self, input_shape=None):
        """
        Function to compute output shape from inputs to this module. Corresponds to
        the @forward_in operation, where raw inputs (usually observation modalities)
        are passed in.

        Args:
            input_shape (iterable of int): shape of input. Does not include batch dimension.
                Some modules may not need this argument, if their output does not depend
                on the size of the input, or if they assume fixed size input.

        Returns:
            out_shape ([int]): list of integers corresponding to output shape
        """

        # outputs are shape (C, CH, CW), or maybe C + 2 if using position encoding, because
        # the number of crops are reshaped into the batch dimension, increasing the batch
        # size from B to B * N
        out_c = self.input_shape[0] + 2 if self.pos_enc else self.input_shape[0]
        return [out_c, self.crop_height, self.crop_width]

    def output_shape_out(self, input_shape=None):
        """
        Function to compute output shape from inputs to this module. Corresponds to
        the @forward_out operation, where processed inputs (usually encoded observation
        modalities) are passed in.

        Args:
            input_shape (iterable of int): shape of input. Does not include batch dimension.
                Some modules may not need this argument, if their output does not depend
                on the size of the input, or if they assume fixed size input.

        Returns:
            out_shape ([int]): list of integers corresponding to output shape
        """

        # since the forward_out operation splits [B * N, ...] -> [B, N, ...]
        # and then pools to result in [B, ...], only the batch dimension changes,
        # and so the other dimensions retain their shape.
        return list(input_shape)

    def forward_in(self, inputs):
        """
        Samples N random crops for each input in the batch, and then reshapes
        inputs to [B * N, ...].
        """
        assert len(
            inputs.shape) >= 3  # must have at least (C, H, W) dimensions
        if self.training:
            # generate random crops
            out, _ = sample_random_image_crops(
                images=inputs,
                crop_height=self.crop_height,
                crop_width=self.crop_width,
                num_crops=self.num_crops,
                pos_enc=self.pos_enc,
            )
            # [B, N, ...] -> [B * N, ...]
            return tu.join_dimensions(out, 0, 1)
        else:
            # take center crop during eval
            out = ttf.center_crop(img=inputs,
                                  output_size=(self.crop_height,
                                               self.crop_width))
            if self.num_crops > 1:
                # Replicate the single center crop so the output batch size
                # matches the training-mode [B * N, ...] layout.
                B, C, H, W = out.shape
                out = out.unsqueeze(1).expand(B, self.num_crops, C, H,
                                              W).reshape(-1, C, H, W)
            # [B * N, ...]
            return out

    def forward_out(self, inputs):
        """
        Splits the outputs from shape [B * N, ...] -> [B, N, ...] and then average across N
        to result in shape [B, ...] to make sure the network output is consistent with
        what would have happened if there were no randomization.
        """
        if self.num_crops <= 1:
            return inputs
        else:
            batch_size = (inputs.shape[0] // self.num_crops)
            out = tu.reshape_dimensions(inputs,
                                        begin_axis=0,
                                        end_axis=0,
                                        target_dims=(batch_size,
                                                     self.num_crops))
            return out.mean(dim=1)

    def forward(self, inputs):
        # nn.Module entry point only performs the input-side crop; callers
        # invoke forward_out explicitly after the encoder.
        return self.forward_in(inputs)

    def __repr__(self):
        """Pretty print network."""
        header = '{}'.format(str(self.__class__.__name__))
        msg = header + "(input_shape={}, crop_size=[{}, {}], num_crops={})".format(
            self.input_shape, self.crop_height, self.crop_width,
            self.num_crops)
        return msg
|
||||
|
||||
|
||||
def crop_image_from_indices(images, crop_indices, crop_height, crop_width):
    """
    Crops images at the locations specified by @crop_indices. Crops will be
    taken across all channels.

    Args:
        images (torch.Tensor): batch of images of shape [..., C, H, W]

        crop_indices (torch.Tensor): batch of indices of shape [..., N, 2] where
            N is the number of crops to take per image and each entry corresponds
            to the pixel height and width of where to take the crop. Note that
            the indices can also be of shape [..., 2] if only 1 crop should
            be taken per image. Leading dimensions must be consistent with
            @images argument. Each index specifies the top left of the crop.
            Values must be in range [0, H - CH - 1] x [0, W - CW - 1] where
            H and W are the height and width of @images and CH and CW are
            @crop_height and @crop_width.

        crop_height (int): height of crop to take

        crop_width (int): width of crop to take

    Returns:
        crops (torch.Tesnor): cropped images of shape [..., C, @crop_height, @crop_width]
    """

    # make sure length of input shapes is consistent
    assert crop_indices.shape[-1] == 2
    ndim_im_shape = len(images.shape)
    ndim_indices_shape = len(crop_indices.shape)
    assert (ndim_im_shape == ndim_indices_shape +
            1) or (ndim_im_shape == ndim_indices_shape + 2)

    # maybe pad so that @crop_indices is shape [..., N, 2]
    is_padded = False
    if ndim_im_shape == ndim_indices_shape + 2:
        crop_indices = crop_indices.unsqueeze(-2)
        is_padded = True

    # make sure leading dimensions between images and indices are consistent
    assert images.shape[:-3] == crop_indices.shape[:-2]

    device = images.device
    image_c, image_h, image_w = images.shape[-3:]
    num_crops = crop_indices.shape[-2]

    # make sure @crop_indices are in valid range
    # NOTE(review): asserts vanish under `python -O`; out-of-range indices
    # would then silently gather wrong pixels.
    assert (crop_indices[..., 0] >= 0).all().item()
    assert (crop_indices[..., 0] < (image_h - crop_height)).all().item()
    assert (crop_indices[..., 1] >= 0).all().item()
    assert (crop_indices[..., 1] < (image_w - crop_width)).all().item()

    # convert each crop index (ch, cw) into a list of pixel indices that correspond to the entire window.

    # 2D index array with columns [0, 1, ..., CH - 1] and shape [CH, CW]
    crop_ind_grid_h = torch.arange(crop_height).to(device)
    crop_ind_grid_h = tu.unsqueeze_expand_at(crop_ind_grid_h,
                                             size=crop_width,
                                             dim=-1)
    # 2D index array with rows [0, 1, ..., CW - 1] and shape [CH, CW]
    crop_ind_grid_w = torch.arange(crop_width).to(device)
    crop_ind_grid_w = tu.unsqueeze_expand_at(crop_ind_grid_w,
                                             size=crop_height,
                                             dim=0)
    # combine into shape [CH, CW, 2]
    crop_in_grid = torch.cat(
        (crop_ind_grid_h.unsqueeze(-1), crop_ind_grid_w.unsqueeze(-1)), dim=-1)

    # Add above grid with the offset index of each sampled crop to get 2d indices for each crop.
    # After broadcasting, this will be shape [..., N, CH, CW, 2] and each crop has a [CH, CW, 2]
    # shape array that tells us which pixels from the corresponding source image to grab.
    grid_reshape = [1] * len(crop_indices.shape[:-1]) + [
        crop_height, crop_width, 2
    ]
    all_crop_inds = crop_indices.unsqueeze(-2).unsqueeze(
        -2) + crop_in_grid.reshape(grid_reshape)

    # For using @torch.gather, convert to flat indices from 2D indices, and also
    # repeat across the channel dimension. To get flat index of each pixel to grab for
    # each sampled crop, we just use the mapping: ind = h_ind * @image_w + w_ind
    all_crop_inds = all_crop_inds[..., 0] * image_w + all_crop_inds[
        ..., 1]  # shape [..., N, CH, CW]
    all_crop_inds = tu.unsqueeze_expand_at(all_crop_inds, size=image_c,
                                           dim=-3)  # shape [..., N, C, CH, CW]
    all_crop_inds = tu.flatten(all_crop_inds,
                               begin_axis=-2)  # shape [..., N, C, CH * CW]

    # Repeat and flatten the source images -> [..., N, C, H * W] and then use gather to index with crop pixel inds
    images_to_crop = tu.unsqueeze_expand_at(images, size=num_crops, dim=-4)
    images_to_crop = tu.flatten(images_to_crop, begin_axis=-2)
    crops = torch.gather(images_to_crop, dim=-1, index=all_crop_inds)
    # [..., N, C, CH * CW] -> [..., N, C, CH, CW]
    reshape_axis = len(crops.shape) - 1
    crops = tu.reshape_dimensions(crops,
                                  begin_axis=reshape_axis,
                                  end_axis=reshape_axis,
                                  target_dims=(crop_height, crop_width))

    if is_padded:
        # undo padding -> [..., C, CH, CW]
        crops = crops.squeeze(-4)
    return crops
|
||||
|
||||
|
||||
def sample_random_image_crops(images,
                              crop_height,
                              crop_width,
                              num_crops,
                              pos_enc=False):
    """
    For each image, randomly sample @num_crops crops of size (@crop_height, @crop_width), from
    @images.

    Args:
        images (torch.Tensor): batch of images of shape [..., C, H, W]

        crop_height (int): height of crop to take

        crop_width (int): width of crop to take

        num_crops (n): number of crops to sample

        pos_enc (bool): if True, also add 2 channels to the outputs that gives a spatial
            encoding of the original source pixel locations. This means that the
            output crops will contain information about where in the source image
            it was sampled from.

    Returns:
        crops (torch.Tensor): crops of shape (..., @num_crops, C, @crop_height, @crop_width)
            if @pos_enc is False, otherwise (..., @num_crops, C + 2, @crop_height, @crop_width)

        crop_inds (torch.Tensor): sampled crop indices of shape (..., N, 2)
    """
    device = images.device

    # maybe add 2 channels of spatial encoding to the source image
    source_im = images
    if pos_enc:
        # spatial encoding [y, x] in [0, 1]
        h, w = source_im.shape[-2:]
        # NOTE(review): torch.meshgrid without indexing= warns on newer torch
        # and defaults to 'ij' indexing — confirm intended behavior is 'ij'.
        pos_y, pos_x = torch.meshgrid(torch.arange(h), torch.arange(w))
        pos_y = pos_y.float().to(device) / float(h)
        pos_x = pos_x.float().to(device) / float(w)
        position_enc = torch.stack((pos_y, pos_x))  # shape [C, H, W]

        # unsqueeze and expand to match leading dimensions -> shape [..., C, H, W]
        leading_shape = source_im.shape[:-3]
        position_enc = position_enc[(None, ) * len(leading_shape)]
        position_enc = position_enc.expand(*leading_shape, -1, -1, -1)

        # concat across channel dimension with input
        source_im = torch.cat((source_im, position_enc), dim=-3)

    # make sure sample boundaries ensure crops are fully within the images
    image_c, image_h, image_w = source_im.shape[-3:]
    max_sample_h = image_h - crop_height
    max_sample_w = image_w - crop_width

    # Sample crop locations for all tensor dimensions up to the last 3, which are [C, H, W].
    # Each gets @num_crops samples - typically this will just be the batch dimension (B), so
    # we will sample [B, N] indices, but this supports having more than one leading dimension,
    # or possibly no leading dimension.
    #
    # Trick: sample in [0, 1) with rand, then re-scale to [0, M) and convert to long to get sampled ints
    # NOTE: the two rand calls below consume RNG state in a fixed order;
    # reordering them would change sampled crops for a given seed.
    crop_inds_h = (
        max_sample_h *
        torch.rand(*source_im.shape[:-3], num_crops).to(device)).long()
    crop_inds_w = (
        max_sample_w *
        torch.rand(*source_im.shape[:-3], num_crops).to(device)).long()
    crop_inds = torch.cat(
        (crop_inds_h.unsqueeze(-1), crop_inds_w.unsqueeze(-1)),
        dim=-1)  # shape [..., N, 2]

    crops = crop_image_from_indices(
        images=source_im,
        crop_indices=crop_inds,
        crop_height=crop_height,
        crop_width=crop_width,
    )

    return crops, crop_inds
|
||||
30
src/unifolm_wma/models/diffusion_head/vision/model_getter.py
Normal file
30
src/unifolm_wma/models/diffusion_head/vision/model_getter.py
Normal file
@@ -0,0 +1,30 @@
|
||||
import torch
|
||||
import torchvision
|
||||
|
||||
|
||||
def get_resnet(name, weights=None, **kwargs):
    """Build a torchvision ResNet backbone with the classifier head removed.

    Args:
        name: one of "resnet18", "resnet34", "resnet50".
        weights: a torchvision weight spec such as "IMAGENET1K_V1", or
            "r3m"/"R3M" to load an R3M-pretrained backbone instead.

    Returns:
        The ResNet module with ``fc`` replaced by ``nn.Identity`` so it
        emits pooled features rather than class logits.
    """
    if weights in ("r3m", "R3M"):
        # R3M checkpoints need a dedicated loader.
        return get_r3m(name=name, **kwargs)

    constructor = getattr(torchvision.models, name)
    backbone = constructor(weights=weights, **kwargs)
    # Strip the classification head; downstream code consumes features.
    backbone.fc = torch.nn.Identity()
    return backbone
|
||||
|
||||
|
||||
def get_r3m(name, **kwargs):
    """Load an R3M-pretrained ResNet backbone on CPU.

    Args:
        name: one of "resnet18", "resnet34", "resnet50".

    Returns:
        The underlying torchvision ResNet (``convnet``) module, moved to CPU.
    """
    import r3m
    r3m.device = 'cpu'
    wrapped = r3m.load_r3m(name)
    # load_r3m returns a (DataParallel-style) wrapper; unwrap to the raw
    # module and pull out its convolutional trunk.
    core = wrapped.module
    backbone = core.convnet
    return backbone.to('cpu')
|
||||
@@ -0,0 +1,247 @@
|
||||
import copy
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torchvision
|
||||
import json
|
||||
import os
|
||||
|
||||
from unifolm_wma.models.diffusion_head.vision.crop_randomizer import CropRandomizer
|
||||
from unifolm_wma.models.diffusion_head.base_nets import SpatialSoftmax
|
||||
from unifolm_wma.models.diffusion_head.common.module_attr_mixin import ModuleAttrMixin
|
||||
from unifolm_wma.models.diffusion_head.common.pytorch_util import dict_apply, replace_submodules
|
||||
from unifolm_wma.utils.utils import instantiate_from_config
|
||||
from einops import rearrange, repeat
|
||||
from typing import Dict, Tuple, Union
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
class MultiImageObsEncoder(ModuleAttrMixin):
    """Encode a dict of multi-modal observations into one flat feature vector.

    RGB observations are run through vision backbones (optionally shared,
    group-normed, resized, cropped, and ImageNet-normalized); low-dim
    observations are concatenated as-is. The observation schema is read from
    a JSON ``shape_meta`` file.
    """

    def __init__(
            self,
            rgb_model_config: Dict,
            shape_meta_path: str | None = None,
            resize_shape: Union[Tuple[int, int], Dict[str, tuple],
                                None] = None,
            crop_shape: Union[Tuple[int, int], Dict[str, tuple], None] = None,
            random_crop: bool = True,
            # replace BatchNorm with GroupNorm
            use_group_norm: bool = False,
            # use single rgb model for all rgb inputs
            share_rgb_model: bool = False,
            # renormalize rgb input with imagenet normalization
            # assuming input in [0,1]
            imagenet_norm: bool = False,
            use_spatial_softmax=False,
            spatial_softmax_kp=32,
            use_dinoSiglip=False):
        """
        Assumes rgb input: B,C,H,W
        Assumes low_dim input: B,D
        """
        super().__init__()

        # Fall back to the default shape-meta file relative to the CWD.
        if not shape_meta_path:
            shape_meta_path = str(Path(os.getcwd()) / "configs/train/meta.json")

        with open(shape_meta_path, 'r') as file:
            shape_meta = json.load(file)

        rgb_model = instantiate_from_config(rgb_model_config)

        rgb_keys = list()
        low_dim_keys = list()
        key_model_map = nn.ModuleDict()
        key_transform_map = nn.ModuleDict()
        key_shape_map = dict()

        # handle sharing vision backbone
        if share_rgb_model:
            assert isinstance(rgb_model, nn.Module)
            key_model_map['rgb'] = rgb_model

        obs_shape_meta = shape_meta['obs']
        for key, attr in obs_shape_meta.items():
            shape = tuple(attr['shape'])
            obs_type = attr.get('type', 'low_dim')
            key_shape_map[key] = shape
            if obs_type == 'rgb':
                rgb_keys.append(key)
                if not use_dinoSiglip:
                    # configure model for this key
                    this_model = None
                    if not share_rgb_model:
                        if isinstance(rgb_model, dict):
                            # have provided model for each key
                            this_model = rgb_model[key]
                        else:
                            assert isinstance(rgb_model, nn.Module)
                            # have a copy of the rgb model
                            this_model = copy.deepcopy(rgb_model)

                    if this_model is not None:
                        if use_group_norm:
                            this_model = replace_submodules(
                                root_module=this_model,
                                predicate=lambda x: isinstance(
                                    x, nn.BatchNorm2d),
                                func=lambda x: nn.GroupNorm(
                                    num_groups=x.num_features // 16,
                                    num_channels=x.num_features))
                        key_model_map[key] = this_model

                    # configure resize
                    input_shape = shape
                    this_resizer = nn.Identity()
                    if resize_shape is not None:
                        if isinstance(resize_shape, dict):
                            h, w = resize_shape[key]
                        else:
                            h, w = resize_shape
                        this_resizer = torchvision.transforms.Resize(
                            size=(h, w))
                        input_shape = (shape[0], h, w)

                    # configure randomizer
                    this_randomizer = nn.Identity()
                    if crop_shape is not None:
                        if isinstance(crop_shape, dict):
                            h, w = crop_shape[key]
                        else:
                            h, w = crop_shape
                        if random_crop:
                            this_randomizer = CropRandomizer(
                                input_shape=input_shape,
                                crop_height=h,
                                crop_width=w,
                                num_crops=1,
                                pos_enc=False)
                        else:
                            # BUG FIX: this was previously assigned to
                            # `this_normalizer` and immediately overwritten
                            # below, silently dropping the center crop when
                            # random_crop=False.
                            this_randomizer = torchvision.transforms.CenterCrop(
                                size=(h, w))
                    # configure normalizer
                    this_normalizer = nn.Identity()
                    if imagenet_norm:
                        this_normalizer = torchvision.transforms.Normalize(
                            mean=[0.485, 0.456, 0.406],
                            std=[0.229, 0.224, 0.225])

                    # Per-key preprocessing pipeline: resize -> crop -> norm.
                    this_transform = nn.Sequential(this_resizer,
                                                   this_randomizer,
                                                   this_normalizer)
                    key_transform_map[key] = this_transform
                else:
                    # DINO/SigLIP backbones do their own preprocessing.
                    key_model_map[key] = rgb_model
            elif obs_type == 'low_dim':
                low_dim_keys.append(key)
            else:
                raise RuntimeError(f"Unsupported obs type: {obs_type}")

        # Sort for a deterministic feature concatenation order.
        rgb_keys = sorted(rgb_keys)
        low_dim_keys = sorted(low_dim_keys)

        self.shape_meta = shape_meta
        self.key_model_map = key_model_map
        self.key_transform_map = key_transform_map
        self.share_rgb_model = share_rgb_model
        self.rgb_keys = rgb_keys
        self.low_dim_keys = low_dim_keys
        self.key_shape_map = key_shape_map
        self.use_dinoSiglip = use_dinoSiglip

        ##NOTE add spatial softmax
        self.use_spatial_softmax = use_spatial_softmax
        if use_spatial_softmax and not use_dinoSiglip:
            # Truncate the 'image' backbone before pooling so the spatial
            # softmax sees the final conv feature map.
            # NOTE(review): assumes a ResNet-style backbone under key
            # 'image' — confirm against rgb_model_config.
            model = nn.Sequential(
                key_model_map['image'].conv1,
                key_model_map['image'].bn1,
                key_model_map['image'].relu,
                key_model_map['image'].maxpool,
                key_model_map['image'].layer1,
                key_model_map['image'].layer2,
                key_model_map['image'].layer3,
                key_model_map['image'].layer4,
            )
            key_model_map['image'] = model
            input_shape = self.output_shape(resnet_output_shape=True)
            self.spatial_softmax = SpatialSoftmax(input_shape,
                                                  num_kp=spatial_softmax_kp)

    def forward(self, obs_dict, resnet_output_shape=False):
        """Encode `obs_dict` into a (B, D) feature tensor.

        Args:
            obs_dict: maps observation keys to batched tensors whose
                trailing dims match the shape-meta entry for that key.
            resnet_output_shape: when True (non-shared path only), return
                the raw backbone feature map of the first rgb key instead
                of the concatenated features — used by `output_shape` to
                size the spatial softmax.
        """
        batch_size = None
        features = list()
        # process rgb input
        if self.share_rgb_model:
            # pass all rgb obs to rgb model
            imgs = list()
            for key in self.rgb_keys:
                img = obs_dict[key]
                if batch_size is None:
                    batch_size = img.shape[0]
                else:
                    assert batch_size == img.shape[0]
                assert img.shape[1:] == self.key_shape_map[key]
                img = self.key_transform_map[key](img)
                imgs.append(img)
            # (N*B,C,H,W)
            imgs = torch.cat(imgs, dim=0)
            # (N*B,D)
            feature = self.key_model_map['rgb'](imgs)
            # (N,B,D)
            feature = feature.reshape(-1, batch_size, *feature.shape[1:])
            # (B,N,D)
            feature = torch.moveaxis(feature, 0, 1)
            # (B,N*D)
            feature = feature.reshape(batch_size, -1)
            features.append(feature)
        else:
            # run each rgb obs to independent models
            for key in self.rgb_keys:
                img = obs_dict[key]
                if batch_size is None:
                    batch_size = img.shape[0]
                else:
                    assert batch_size == img.shape[0]
                assert img.shape[1:] == self.key_shape_map[key]
                if not self.use_dinoSiglip:
                    img = self.key_transform_map[key](img)
                    feature = self.key_model_map[key](img)
                else:
                    # keep only the first (CLS-like) token
                    feature = self.key_model_map[key](img)[:, :1, :]

                if resnet_output_shape:
                    # early exit with the raw feature map of this key
                    return feature
                if not self.use_dinoSiglip and self.use_spatial_softmax:
                    feature = self.spatial_softmax(feature)
                feature = feature.reshape(batch_size, -1)
                features.append(feature)

        # process lowdim input
        for key in self.low_dim_keys:
            data = obs_dict[key]
            if batch_size is None:
                batch_size = data.shape[0]
            else:
                assert batch_size == data.shape[0]
            assert data.shape[1:] == self.key_shape_map[key]
            features.append(data)

        # concatenate all features
        result = torch.cat(features, dim=-1)
        return result

    @torch.no_grad()
    def output_shape(self, resnet_output_shape=False):
        """Run a zero-filled example batch through `forward` and return the
        per-sample output shape (everything after the batch dim)."""
        example_obs_dict = dict()
        obs_shape_meta = self.shape_meta['obs']
        batch_size = 1
        for key, attr in obs_shape_meta.items():
            shape = tuple(attr['shape'])
            this_obs = torch.zeros((batch_size, ) + shape,
                                   dtype=self.dtype,
                                   device=self.device)
            example_obs_dict[key] = this_obs
        example_output = self.forward(example_obs_dict,
                                      resnet_output_shape=resnet_output_shape)
        output_shape = example_output.shape[1:]
        return output_shape
|
||||
473
src/unifolm_wma/models/samplers/ddim.py
Normal file
473
src/unifolm_wma/models/samplers/ddim.py
Normal file
@@ -0,0 +1,473 @@
|
||||
import numpy as np
|
||||
import torch
|
||||
import copy
|
||||
|
||||
from unifolm_wma.utils.diffusion import make_ddim_sampling_parameters, make_ddim_timesteps, rescale_noise_cfg
|
||||
from unifolm_wma.utils.common import noise_like
|
||||
from unifolm_wma.utils.common import extract_into_tensor
|
||||
from tqdm import tqdm
|
||||
|
||||
|
||||
class DDIMSampler(object):
    """DDIM sampler that jointly denoises video latents and agent
    action/state trajectories.

    The video latent follows the classic DDIM update implemented here;
    the action and state trajectories are stepped with the wrapped
    model's diffusion-policy DDIM schedulers
    (`dp_noise_scheduler_action` / `dp_noise_scheduler_state`).
    """

    def __init__(self, model, schedule="linear", **kwargs):
        super().__init__()
        self.model = model
        self.ddpm_num_timesteps = model.num_timesteps
        self.schedule = schedule
        self.counter = 0

    def register_buffer(self, name, attr):
        # Store a schedule tensor as a plain attribute (this is not an
        # nn.Module, so there is no real buffer registry).
        # NOTE(review): unconditionally moves tensors to CUDA — this breaks
        # CPU-only runs; confirm whether the device should follow the model.
        if type(attr) == torch.Tensor:
            if attr.device != torch.device("cuda"):
                attr = attr.to(torch.device("cuda"))
        setattr(self, name, attr)

    def make_schedule(self,
                      ddim_num_steps,
                      ddim_discretize="uniform",
                      ddim_eta=0.,
                      verbose=True):
        """Precompute the DDIM timestep subset and all alpha/sigma tables
        used by `p_sample_ddim`. Must be called before sampling (done by
        `sample`)."""
        self.ddim_timesteps = make_ddim_timesteps(
            ddim_discr_method=ddim_discretize,
            num_ddim_timesteps=ddim_num_steps,
            num_ddpm_timesteps=self.ddpm_num_timesteps,
            verbose=verbose)
        alphas_cumprod = self.model.alphas_cumprod
        assert alphas_cumprod.shape[
            0] == self.ddpm_num_timesteps, 'alphas have to be defined for each timestep'
        to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.model
                                                                     .device)

        # Per-timestep rescale factors (and their previous-step values)
        # for models trained with dynamic rescaling.
        if self.model.use_dynamic_rescale:
            self.ddim_scale_arr = self.model.scale_arr[self.ddim_timesteps]
            self.ddim_scale_arr_prev = torch.cat(
                [self.ddim_scale_arr[0:1], self.ddim_scale_arr[:-1]])

        self.register_buffer('betas', to_torch(self.model.betas))
        self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod))
        self.register_buffer('alphas_cumprod_prev',
                             to_torch(self.model.alphas_cumprod_prev))

        # Calculations for diffusion q(x_t | x_{t-1}) and others
        self.register_buffer('sqrt_alphas_cumprod',
                             to_torch(np.sqrt(alphas_cumprod.cpu())))
        self.register_buffer('sqrt_one_minus_alphas_cumprod',
                             to_torch(np.sqrt(1. - alphas_cumprod.cpu())))
        self.register_buffer('log_one_minus_alphas_cumprod',
                             to_torch(np.log(1. - alphas_cumprod.cpu())))
        self.register_buffer('sqrt_recip_alphas_cumprod',
                             to_torch(np.sqrt(1. / alphas_cumprod.cpu())))
        self.register_buffer('sqrt_recipm1_alphas_cumprod',
                             to_torch(np.sqrt(1. / alphas_cumprod.cpu() - 1)))

        # DDIM sampling parameters
        ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(
            alphacums=alphas_cumprod.cpu(),
            ddim_timesteps=self.ddim_timesteps,
            eta=ddim_eta,
            verbose=verbose)
        self.register_buffer('ddim_sigmas', ddim_sigmas)
        self.register_buffer('ddim_alphas', ddim_alphas)
        self.register_buffer('ddim_alphas_prev', ddim_alphas_prev)
        self.register_buffer('ddim_sqrt_one_minus_alphas',
                             np.sqrt(1. - ddim_alphas))
        sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt(
            (1 - self.alphas_cumprod_prev) / (1 - self.alphas_cumprod) *
            (1 - self.alphas_cumprod / self.alphas_cumprod_prev))
        self.register_buffer('ddim_sigmas_for_original_num_steps',
                             sigmas_for_original_sampling_steps)

    @torch.no_grad()
    def sample(
            self,
            S,
            batch_size,
            shape,
            conditioning=None,
            callback=None,
            normals_sequence=None,
            img_callback=None,
            quantize_x0=False,
            eta=0.,
            mask=None,
            x0=None,
            temperature=1.,
            noise_dropout=0.,
            score_corrector=None,
            corrector_kwargs=None,
            verbose=True,
            schedule_verbose=False,
            x_T=None,
            log_every_t=100,
            unconditional_guidance_scale=1.,
            unconditional_conditioning=None,
            precision=None,
            fs=None,
            timestep_spacing='uniform', #uniform_trailing for starting from last timestep
            guidance_rescale=0.0,
            **kwargs):
        """Entry point: build the DDIM schedule for `S` steps, then run the
        full sampling loop. `shape` is per-sample, either (C, H, W) or
        (C, T, H, W); returns (samples, actions, states, intermediates)."""

        # Check condition bs
        if conditioning is not None:
            if isinstance(conditioning, dict):
                try:
                    cbs = conditioning[list(conditioning.keys())[0]].shape[0]
                except:
                    # first value may be a list of tensors
                    cbs = conditioning[list(
                        conditioning.keys())[0]][0].shape[0]

                if cbs != batch_size:
                    print(
                        f"Warning: Got {cbs} conditionings but batch-size is {batch_size}"
                    )
            else:
                if conditioning.shape[0] != batch_size:
                    print(
                        f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}"
                    )

        self.make_schedule(ddim_num_steps=S,
                           ddim_discretize=timestep_spacing,
                           ddim_eta=eta,
                           verbose=schedule_verbose)

        # Make shape
        if len(shape) == 3:
            C, H, W = shape
            size = (batch_size, C, H, W)
        elif len(shape) == 4:
            C, T, H, W = shape
            size = (batch_size, C, T, H, W)

        samples, actions, states, intermediates = self.ddim_sampling(
            conditioning,
            size,
            callback=callback,
            img_callback=img_callback,
            quantize_denoised=quantize_x0,
            mask=mask,
            x0=x0,
            ddim_use_original_steps=False,
            noise_dropout=noise_dropout,
            temperature=temperature,
            score_corrector=score_corrector,
            corrector_kwargs=corrector_kwargs,
            x_T=x_T,
            log_every_t=log_every_t,
            unconditional_guidance_scale=unconditional_guidance_scale,
            unconditional_conditioning=unconditional_conditioning,
            verbose=verbose,
            precision=precision,
            fs=fs,
            guidance_rescale=guidance_rescale,
            **kwargs)
        return samples, actions, states, intermediates

    @torch.no_grad()
    def ddim_sampling(self,
                      cond,
                      shape,
                      x_T=None,
                      ddim_use_original_steps=False,
                      callback=None,
                      timesteps=None,
                      quantize_denoised=False,
                      mask=None,
                      x0=None,
                      img_callback=None,
                      log_every_t=100,
                      temperature=1.,
                      noise_dropout=0.,
                      score_corrector=None,
                      corrector_kwargs=None,
                      unconditional_guidance_scale=1.,
                      unconditional_conditioning=None,
                      verbose=True,
                      precision=None,
                      fs=None,
                      guidance_rescale=0.0,
                      **kwargs):
        """Main reverse-diffusion loop over the DDIM timesteps.

        Jointly denoises the latent video `img` plus the `action` and
        `state` trajectories; returns the final triple and a dict of
        intermediates logged every `log_every_t` indices."""
        device = self.model.betas.device
        dp_ddim_scheduler_action = self.model.dp_noise_scheduler_action
        dp_ddim_scheduler_state = self.model.dp_noise_scheduler_state

        b = shape[0]
        # Action/state trajectories always start from fresh Gaussian noise,
        # even when an initial latent x_T is supplied.
        # NOTE(review): horizon 16 is hard-coded here — confirm it matches
        # the policy's prediction horizon.
        if x_T is None:
            img = torch.randn(shape, device=device)
            action = torch.randn((b, 16, self.model.agent_action_dim),
                                 device=device)
            state = torch.randn((b, 16, self.model.agent_state_dim),
                                device=device)
        else:
            img = x_T
            action = torch.randn((b, 16, self.model.agent_action_dim),
                                 device=device)
            state = torch.randn((b, 16, self.model.agent_state_dim),
                                device=device)

        if precision is not None:
            if precision == 16:
                img = img.to(dtype=torch.float16)
                action = action.to(dtype=torch.float16)
                state = state.to(dtype=torch.float16)

        if timesteps is None:
            timesteps = self.ddpm_num_timesteps if ddim_use_original_steps else self.ddim_timesteps
        elif timesteps is not None and not ddim_use_original_steps:
            # Interpret `timesteps` as a cap and truncate the DDIM schedule.
            subset_end = int(
                min(timesteps / self.ddim_timesteps.shape[0], 1) *
                self.ddim_timesteps.shape[0]) - 1
            timesteps = self.ddim_timesteps[:subset_end]

        intermediates = {
            'x_inter': [img],
            'pred_x0': [img],
            'x_inter_action': [action],
            'pred_x0_action': [action],
            'x_inter_state': [state],
            'pred_x0_state': [state],
        }
        # Iterate from the noisiest timestep down to 0.
        time_range = reversed(range(
            0, timesteps)) if ddim_use_original_steps else np.flip(timesteps)
        total_steps = timesteps if ddim_use_original_steps else timesteps.shape[
            0]
        if verbose:
            iterator = tqdm(time_range, desc='DDIM Sampler', total=total_steps)
        else:
            iterator = time_range

        clean_cond = kwargs.pop("clean_cond", False)

        # Align the diffusion-policy schedulers with the DDIM step count.
        dp_ddim_scheduler_action.set_timesteps(len(timesteps))
        dp_ddim_scheduler_state.set_timesteps(len(timesteps))
        for i, step in enumerate(iterator):
            index = total_steps - i - 1
            ts = torch.full((b, ), step, device=device, dtype=torch.long)

            # Use mask to blend noised original latent (img_orig) & new sampled latent (img)
            if mask is not None:
                assert x0 is not None
                if clean_cond:
                    img_orig = x0
                else:
                    img_orig = self.model.q_sample(x0, ts)
                img = img_orig * mask + (1. - mask) * img

            outs = self.p_sample_ddim(
                img,
                action,
                state,
                cond,
                ts,
                index=index,
                use_original_steps=ddim_use_original_steps,
                quantize_denoised=quantize_denoised,
                temperature=temperature,
                noise_dropout=noise_dropout,
                score_corrector=score_corrector,
                corrector_kwargs=corrector_kwargs,
                unconditional_guidance_scale=unconditional_guidance_scale,
                unconditional_conditioning=unconditional_conditioning,
                mask=mask,
                x0=x0,
                fs=fs,
                guidance_rescale=guidance_rescale,
                **kwargs)

            img, pred_x0, model_output_action, model_output_state = outs

            # Step the action/state trajectories with their own schedulers.
            action = dp_ddim_scheduler_action.step(
                model_output_action,
                step,
                action,
                generator=None,
            ).prev_sample
            state = dp_ddim_scheduler_state.step(
                model_output_state,
                step,
                state,
                generator=None,
            ).prev_sample

            if callback: callback(i)
            if img_callback: img_callback(pred_x0, i)

            if index % log_every_t == 0 or index == total_steps - 1:
                intermediates['x_inter'].append(img)
                intermediates['pred_x0'].append(pred_x0)
                intermediates['x_inter_action'].append(action)
                intermediates['x_inter_state'].append(state)

        return img, action, state, intermediates

    @torch.no_grad()
    def p_sample_ddim(self,
                      x,
                      x_action,
                      x_state,
                      c,
                      t,
                      index,
                      repeat_noise=False,
                      use_original_steps=False,
                      quantize_denoised=False,
                      temperature=1.,
                      noise_dropout=0.,
                      score_corrector=None,
                      corrector_kwargs=None,
                      unconditional_guidance_scale=1.,
                      unconditional_conditioning=None,
                      uc_type=None,
                      conditional_guidance_scale_temporal=None,
                      mask=None,
                      x0=None,
                      guidance_rescale=0.0,
                      **kwargs):
        """Single DDIM update for the video latent at schedule position
        `index`; returns (x_prev, pred_x0, model_output_action,
        model_output_state). Action/state outputs are passed through
        unstepped — the caller advances them with the DP schedulers."""
        b, *_, device = *x.shape, x.device
        # 5-D input means (B, C, T, H, W) video latents.
        if x.dim() == 5:
            is_video = True
        else:
            is_video = False

        if unconditional_conditioning is None or unconditional_guidance_scale == 1.:
            model_output, model_output_action, model_output_state = self.model.apply_model(
                x, x_action, x_state, t, c, **kwargs) # unet denoiser
        else:
            # do_classifier_free_guidance
            if isinstance(c, torch.Tensor) or isinstance(c, dict):
                e_t_cond, e_t_cond_action, e_t_cond_state = self.model.apply_model(
                    x, x_action, x_state, t, c, **kwargs)
                e_t_uncond, e_t_uncond_action, e_t_uncond_state = self.model.apply_model(
                    x, x_action, x_state, t, unconditional_conditioning,
                    **kwargs)
            else:
                raise NotImplementedError
            model_output = e_t_uncond + unconditional_guidance_scale * (
                e_t_cond - e_t_uncond)
            model_output_action = e_t_uncond_action + unconditional_guidance_scale * (
                e_t_cond_action - e_t_uncond_action)
            model_output_state = e_t_uncond_state + unconditional_guidance_scale * (
                e_t_cond_state - e_t_uncond_state)

            if guidance_rescale > 0.0:
                model_output = rescale_noise_cfg(
                    model_output, e_t_cond, guidance_rescale=guidance_rescale)
                model_output_action = rescale_noise_cfg(
                    model_output_action,
                    e_t_cond_action,
                    guidance_rescale=guidance_rescale)
                model_output_state = rescale_noise_cfg(
                    model_output_state,
                    e_t_cond_state,
                    guidance_rescale=guidance_rescale)

        # Convert a v-prediction to an epsilon prediction if needed.
        if self.model.parameterization == "v":
            e_t = self.model.predict_eps_from_z_and_v(x, t, model_output)
        else:
            e_t = model_output

        if score_corrector is not None:
            assert self.model.parameterization == "eps", 'not implemented'
            e_t = score_corrector.modify_score(self.model, e_t, x, t, c,
                                               **corrector_kwargs)

        alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas
        alphas_prev = self.model.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prev
        sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas
        sigmas = self.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas

        # Broadcast shapes for the per-batch scalar schedule values.
        if is_video:
            size = (b, 1, 1, 1, 1)
        else:
            size = (b, 1, 1, 1)

        a_t = torch.full(size, alphas[index], device=device)
        a_prev = torch.full(size, alphas_prev[index], device=device)
        sigma_t = torch.full(size, sigmas[index], device=device)
        sqrt_one_minus_at = torch.full(size,
                                       sqrt_one_minus_alphas[index],
                                       device=device)

        # Predicted clean latent x0.
        if self.model.parameterization != "v":
            pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
        else:
            pred_x0 = self.model.predict_start_from_z_and_v(x, t, model_output)

        if self.model.use_dynamic_rescale:
            scale_t = torch.full(size,
                                 self.ddim_scale_arr[index],
                                 device=device)
            prev_scale_t = torch.full(size,
                                      self.ddim_scale_arr_prev[index],
                                      device=device)
            rescale = (prev_scale_t / scale_t)
            pred_x0 *= rescale

        if quantize_denoised:
            pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0)

        # Direction pointing to x_t (DDIM eq. 12).
        dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t

        noise = sigma_t * noise_like(x.shape, device,
                                     repeat_noise) * temperature
        if noise_dropout > 0.:
            noise = torch.nn.functional.dropout(noise, p=noise_dropout)

        x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise

        return x_prev, pred_x0, model_output_action, model_output_state

    @torch.no_grad()
    def decode(self,
               x_latent,
               cond,
               t_start,
               unconditional_guidance_scale=1.0,
               unconditional_conditioning=None,
               use_original_steps=False,
               callback=None):
        """Run the reverse process from `t_start` on an encoded latent.

        NOTE(review): this calls `p_sample_ddim(x_dec, cond, ts, ...)` with
        the pre-action/state signature and unpacks only 2 of its 4 return
        values — it appears stale relative to the current `p_sample_ddim`
        and would raise if invoked. Confirm whether it is dead code."""

        timesteps = np.arange(self.ddpm_num_timesteps
                              ) if use_original_steps else self.ddim_timesteps
        timesteps = timesteps[:t_start]

        time_range = np.flip(timesteps)
        total_steps = timesteps.shape[0]
        print(f"Running DDIM Sampling with {total_steps} timesteps")

        iterator = tqdm(time_range, desc='Decoding image', total=total_steps)
        x_dec = x_latent
        for i, step in enumerate(iterator):
            index = total_steps - i - 1
            ts = torch.full((x_latent.shape[0], ),
                            step,
                            device=x_latent.device,
                            dtype=torch.long)
            x_dec, _ = self.p_sample_ddim(
                x_dec,
                cond,
                ts,
                index=index,
                use_original_steps=use_original_steps,
                unconditional_guidance_scale=unconditional_guidance_scale,
                unconditional_conditioning=unconditional_conditioning)
            if callback: callback(i)
        return x_dec

    @torch.no_grad()
    def stochastic_encode(self, x0, t, use_original_steps=False, noise=None):
        """Noise a clean latent x0 to timestep `t` via the forward process
        q(x_t | x_0)."""
        # fast, but does not allow for exact reconstruction
        if use_original_steps:
            sqrt_alphas_cumprod = self.sqrt_alphas_cumprod
            sqrt_one_minus_alphas_cumprod = self.sqrt_one_minus_alphas_cumprod
        else:
            sqrt_alphas_cumprod = torch.sqrt(self.ddim_alphas)
            sqrt_one_minus_alphas_cumprod = self.ddim_sqrt_one_minus_alphas

        if noise is None:
            noise = torch.randn_like(x0)
        return (
            extract_into_tensor(sqrt_alphas_cumprod, t, x0.shape) * x0 +
            extract_into_tensor(sqrt_one_minus_alphas_cumprod, t, x0.shape) *
            noise)
|
||||
0
src/unifolm_wma/modules/__init__.py
Normal file
0
src/unifolm_wma/modules/__init__.py
Normal file
806
src/unifolm_wma/modules/attention.py
Normal file
806
src/unifolm_wma/modules/attention.py
Normal file
@@ -0,0 +1,806 @@
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
from torch import nn, einsum
|
||||
from einops import rearrange, repeat
|
||||
from functools import partial
|
||||
|
||||
try:
|
||||
import xformers
|
||||
import xformers.ops
|
||||
XFORMERS_IS_AVAILBLE = True
|
||||
except:
|
||||
XFORMERS_IS_AVAILBLE = False
|
||||
|
||||
from unifolm_wma.utils.common import (
|
||||
checkpoint,
|
||||
exists,
|
||||
default,
|
||||
)
|
||||
from unifolm_wma.utils.basics import zero_module
|
||||
|
||||
|
||||
class RelativePosition(nn.Module):
    """Learned relative-position embedding table.

    Adapted from
    https://github.com/evelinehong/Transformer_Relative_Position_PyTorch/blob/master/relative_position.py
    """

    def __init__(self, num_units, max_relative_position):
        super().__init__()
        self.num_units = num_units
        self.max_relative_position = max_relative_position
        # One row per clipped relative offset in
        # [-max_relative_position, +max_relative_position].
        table = torch.Tensor(max_relative_position * 2 + 1, num_units)
        self.embeddings_table = nn.Parameter(table)
        nn.init.xavier_uniform_(self.embeddings_table)

    def forward(self, length_q, length_k):
        """Return a (length_q, length_k, num_units) tensor where entry
        (i, j) is the embedding for the clipped offset j - i."""
        dev = self.embeddings_table.device
        q_idx = torch.arange(length_q, device=dev)[:, None]
        k_idx = torch.arange(length_k, device=dev)[None, :]
        # Signed key-minus-query offsets, clipped to the supported window
        # and shifted so they index valid table rows.
        offsets = torch.clamp(k_idx - q_idx, -self.max_relative_position,
                              self.max_relative_position)
        rows = (offsets + self.max_relative_position).long()
        return self.embeddings_table[rows]
|
||||
|
||||
|
||||
class CrossAttention(nn.Module):
|
||||
|
||||
def __init__(self,
|
||||
query_dim,
|
||||
context_dim=None,
|
||||
heads=8,
|
||||
dim_head=64,
|
||||
dropout=0.,
|
||||
relative_position=False,
|
||||
temporal_length=None,
|
||||
video_length=None,
|
||||
agent_state_context_len=2,
|
||||
agent_action_context_len=16,
|
||||
image_cross_attention=False,
|
||||
image_cross_attention_scale=1.0,
|
||||
agent_state_cross_attention_scale=1.0,
|
||||
agent_action_cross_attention_scale=1.0,
|
||||
cross_attention_scale_learnable=False,
|
||||
text_context_len=77):
|
||||
super().__init__()
|
||||
inner_dim = dim_head * heads
|
||||
context_dim = default(context_dim, query_dim)
|
||||
|
||||
self.scale = dim_head**-0.5
|
||||
self.heads = heads
|
||||
self.dim_head = dim_head
|
||||
self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
|
||||
self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
|
||||
self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
|
||||
|
||||
self.to_out = nn.Sequential(nn.Linear(inner_dim, query_dim),
|
||||
nn.Dropout(dropout))
|
||||
|
||||
self.relative_position = relative_position
|
||||
if self.relative_position:
|
||||
assert (temporal_length is not None)
|
||||
self.relative_position_k = RelativePosition(
|
||||
num_units=dim_head, max_relative_position=temporal_length)
|
||||
self.relative_position_v = RelativePosition(
|
||||
num_units=dim_head, max_relative_position=temporal_length)
|
||||
else:
|
||||
## only used for spatial attention, while NOT for temporal attention
|
||||
if XFORMERS_IS_AVAILBLE and temporal_length is None:
|
||||
self.forward = self.efficient_forward
|
||||
|
||||
self.video_length = video_length
|
||||
self.image_cross_attention = image_cross_attention
|
||||
self.image_cross_attention_scale = image_cross_attention_scale
|
||||
self.agent_state_cross_attention_scale = agent_state_cross_attention_scale
|
||||
self.agent_action_cross_attention_scale = agent_action_cross_attention_scale
|
||||
self.text_context_len = text_context_len
|
||||
self.agent_state_context_len = agent_state_context_len
|
||||
self.agent_action_context_len = agent_action_context_len
|
||||
self.cross_attention_scale_learnable = cross_attention_scale_learnable
|
||||
if self.image_cross_attention:
|
||||
self.to_k_ip = nn.Linear(context_dim, inner_dim, bias=False)
|
||||
self.to_v_ip = nn.Linear(context_dim, inner_dim, bias=False)
|
||||
self.to_k_as = nn.Linear(context_dim, inner_dim, bias=False)
|
||||
self.to_v_as = nn.Linear(context_dim, inner_dim, bias=False)
|
||||
self.to_k_aa = nn.Linear(context_dim, inner_dim, bias=False)
|
||||
self.to_v_aa = nn.Linear(context_dim, inner_dim, bias=False)
|
||||
if cross_attention_scale_learnable:
|
||||
self.register_parameter('alpha_ctx',
|
||||
nn.Parameter(torch.tensor(0.)))
|
||||
self.register_parameter('alpha_cas',
|
||||
nn.Parameter(torch.tensor(0.)))
|
||||
self.register_parameter('alpha_caa',
|
||||
nn.Parameter(torch.tensor(0.)))
|
||||
|
||||
    def forward(self, x, context=None, mask=None):
        """Plain-PyTorch attention path (temporal attention, or spatial
        attention when xformers is unavailable).

        Args:
            x: query tokens, shape (b, n, query_dim).
            context: optional key/value tokens; ``None`` means self-attention
                on ``x`` itself.
            mask: optional attention mask; only a causal-style boolean/0-1
                mask is supported (see the ``exists(mask)`` branch).

        Returns:
            Tensor of shape (b, n, query_dim) produced by ``to_out``.
        """
        spatial_self_attn = (context is None)
        # Per-modality key/value/output slots (image / agent state / agent
        # action); they stay None unless the disabled branch below runs.
        k_ip, v_ip, out_ip = None, None, None
        k_as, v_as, out_as = None, None, None
        k_aa, v_aa, out_aa = None, None, None

        h = self.heads
        q = self.to_q(x)
        context = default(context, x)

        if self.image_cross_attention and not spatial_self_attn:
            # NOTE(review): this multi-context branch is intentionally
            # disabled — it is only supported via efficient_forward
            # (xformers). Code kept below for parity, but it is unreachable.
            assert 1 > 2, ">>> ERROR: should setup xformers and use efficient_forward ..."
            # Packed context layout: [agent_state | agent_action | text | image].
            context_agent_state = context[:, :self.agent_state_context_len, :]
            context_agent_action = context[:,
                                           self.agent_state_context_len:self.
                                           agent_state_context_len +
                                           self.agent_action_context_len, :]
            context_ins = context[:, self.agent_state_context_len +
                                  self.agent_action_context_len:self.
                                  agent_state_context_len +
                                  self.agent_action_context_len +
                                  self.text_context_len, :]
            context_image = context[:, self.agent_state_context_len +
                                    self.agent_action_context_len +
                                    self.text_context_len:, :]

            k = self.to_k(context_ins)
            v = self.to_v(context_ins)
            k_ip = self.to_k_ip(context_image)
            v_ip = self.to_v_ip(context_image)
            k_as = self.to_k_as(context_agent_state)
            v_as = self.to_v_as(context_agent_state)
            k_aa = self.to_k_aa(context_agent_action)
            v_aa = self.to_v_aa(context_agent_action)
        else:
            if not spatial_self_attn:
                # Cross-attention against the text tokens only.
                context = context[:, :self.text_context_len, :]
            k = self.to_k(context)
            v = self.to_v(context)

        # Fold heads into the batch dim: (b, n, h*d) -> (b*h, n, d).
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h),
                      (q, k, v))

        sim = torch.einsum('b i d, b j d -> b i j', q, k) * self.scale
        if self.relative_position:
            # Additive relative-position bias on the attention logits.
            len_q, len_k, len_v = q.shape[1], k.shape[1], v.shape[1]
            k2 = self.relative_position_k(len_q, len_k)
            sim2 = einsum('b t d, t s d -> b t s', q,
                          k2) * self.scale  # TODO check
            sim += sim2
        del k

        if exists(mask):
            ## feasible for causal attention mask only
            max_neg_value = -torch.finfo(sim.dtype).max
            mask = repeat(mask, 'b i j -> (b h) i j', h=h)
            sim.masked_fill_(~(mask > 0.5), max_neg_value)

        # attention, what we cannot get enough of
        sim = sim.softmax(dim=-1)

        out = torch.einsum('b i j, b j d -> b i d', sim, v)
        if self.relative_position:
            # Relative-position contribution on the value side.
            v2 = self.relative_position_v(len_q, len_v)
            out2 = einsum('b t s, t s d -> b t d', sim, v2)  # TODO check
            out += out2
        out = rearrange(out, '(b h) n d -> b n (h d)', h=h)

        if k_ip is not None and k_as is not None and k_aa is not None:
            # Unreachable in practice (the branch that fills these is
            # disabled above); kept for parity with efficient_forward.
            ## for image cross-attention
            k_ip, v_ip = map(
                lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h),
                (k_ip, v_ip))
            sim_ip = torch.einsum('b i d, b j d -> b i j', q,
                                  k_ip) * self.scale
            del k_ip
            sim_ip = sim_ip.softmax(dim=-1)
            out_ip = torch.einsum('b i j, b j d -> b i d', sim_ip, v_ip)
            out_ip = rearrange(out_ip, '(b h) n d -> b n (h d)', h=h)

            ## for agent state cross-attention
            k_as, v_as = map(
                lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h),
                (k_as, v_as))
            sim_as = torch.einsum('b i d, b j d -> b i j', q,
                                  k_as) * self.scale
            del k_as
            sim_as = sim_as.softmax(dim=-1)
            out_as = torch.einsum('b i j, b j d -> b i d', sim_as, v_as)
            out_as = rearrange(out_as, '(b h) n d -> b n (h d)', h=h)

            ## for agent action cross-attention
            k_aa, v_aa = map(
                lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h),
                (k_aa, v_aa))
            sim_aa = torch.einsum('b i d, b j d -> b i j', q,
                                  k_aa) * self.scale
            del k_aa
            sim_aa = sim_aa.softmax(dim=-1)
            out_aa = torch.einsum('b i j, b j d -> b i d', sim_aa, v_aa)
            out_aa = rearrange(out_aa, '(b h) n d -> b n (h d)', h=h)

        if out_ip is not None and out_as is not None and out_aa is not None:
            if self.cross_attention_scale_learnable:
                # Learnable per-modality gates; tanh(alpha)+1 keeps each gate
                # in (0, 2), starting at 1 when alpha is 0.
                out = out + \
                    self.image_cross_attention_scale * out_ip * (torch.tanh(self.alpha_ctx) + 1) + \
                    self.agent_state_cross_attention_scale * out_as * (torch.tanh(self.alpha_cas) + 1) + \
                    self.agent_action_cross_attention_scale * out_aa * (torch.tanh(self.alpha_caa) + 1)
            else:
                out = out + \
                    self.image_cross_attention_scale * out_ip + \
                    self.agent_state_cross_attention_scale * out_as + \
                    self.agent_action_cross_attention_scale * out_aa

        return self.to_out(out)
|
||||
|
||||
    def efficient_forward(self, x, context=None, mask=None):
        """xformers-based attention path (bound to ``forward`` in __init__
        when xformers is available and this is a spatial layer).

        The packed context is dispatched on its token count to one of three
        layouts:
          * [text | image],
          * [agent_state | text | image], or
          * [agent_state | agent_action | text | image],
        each modality getting its own K/V projection; the per-modality
        outputs are blended with (optionally learnable) scales.

        Args:
            x: query tokens, shape (b, n, query_dim).
            context: optional packed context (layouts above); ``None`` means
                self-attention.
            mask: not supported here (raises NotImplementedError).

        Returns:
            Tensor of shape (b, n, query_dim) produced by ``to_out``.
        """
        spatial_self_attn = (context is None)
        k, v, out = None, None, None
        k_ip, v_ip, out_ip = None, None, None
        k_as, v_as, out_as = None, None, None
        k_aa, v_aa, out_aa = None, None, None

        q = self.to_q(x)
        context = default(context, x)

        if self.image_cross_attention and not spatial_self_attn:
            if context.shape[1] == self.text_context_len + self.video_length:
                # Layout: [text | image].
                context_ins, context_image = context[:, :self.text_context_len, :], context[:, self.text_context_len:, :]
                # NOTE(review): here the base K/V use the FULL context
                # (text+image) while the branches below use only the text
                # slice (and context_ins goes unused) — confirm intended.
                k = self.to_k(context)
                v = self.to_v(context)
                k_ip = self.to_k_ip(context_image)
                v_ip = self.to_v_ip(context_image)
            elif context.shape[1] == self.agent_state_context_len + self.text_context_len + self.video_length:
                # Layout: [agent_state | text | image].
                context_agent_state = context[:, :self.agent_state_context_len, :]
                context_ins = context[:, self.agent_state_context_len:self.agent_state_context_len + self.text_context_len, :]
                context_image = context[:, self.agent_state_context_len + self.text_context_len:, :]
                k = self.to_k(context_ins)
                v = self.to_v(context_ins)
                k_ip = self.to_k_ip(context_image)
                v_ip = self.to_v_ip(context_image)
                k_as = self.to_k_as(context_agent_state)
                v_as = self.to_v_as(context_agent_state)
            else:
                # Layout: [agent_state | agent_action | text | image].
                context_agent_state = context[:, :self.agent_state_context_len, :]
                context_agent_action = context[:, self.agent_state_context_len:self.agent_state_context_len + self.agent_action_context_len, :]
                context_ins = context[:, self.agent_state_context_len + self.agent_action_context_len:self.agent_state_context_len + self.agent_action_context_len + self.text_context_len, :]
                context_image = context[:, self.agent_state_context_len + self.agent_action_context_len + self.text_context_len:, :]

                k = self.to_k(context_ins)
                v = self.to_v(context_ins)
                k_ip = self.to_k_ip(context_image)
                v_ip = self.to_v_ip(context_image)
                k_as = self.to_k_as(context_agent_state)
                v_as = self.to_v_as(context_agent_state)
                k_aa = self.to_k_aa(context_agent_action)
                v_aa = self.to_v_aa(context_agent_action)

                # Prefix mask so each batch sample only sees a limited prefix
                # of the agent-action tokens (see _get_attn_mask_aa).
                attn_mask_aa = self._get_attn_mask_aa(x.shape[0],
                                                      q.shape[1],
                                                      k_aa.shape[1],
                                                      block_size=16).to(k_aa.device)
        else:
            if not spatial_self_attn:
                # Defensive: non-self-attention should always take the branch
                # above when image_cross_attention is enabled.
                assert 1 > 2, ">>> ERROR: you should never go into here ..."
                context = context[:, :self.text_context_len, :]
            k = self.to_k(context)
            v = self.to_v(context)

        # Fold heads into the batch dim: (b, n, h*d) -> (b*h, n, d), the
        # layout xformers expects.
        b, _, _ = q.shape
        q = q.unsqueeze(3).reshape(b, q.shape[1], self.heads, self.dim_head).permute(0, 2, 1, 3).reshape(b * self.heads, q.shape[1], self.dim_head).contiguous()
        if k is not None:
            k, v = map(
                lambda t: t.unsqueeze(3).reshape(b, t.shape[
                    1], self.heads, self.dim_head).permute(0, 2, 1, 3).reshape(
                        b * self.heads, t.shape[1], self.dim_head).contiguous(),
                (k, v),
            )
            out = xformers.ops.memory_efficient_attention(q,
                                                          k,
                                                          v,
                                                          attn_bias=None,
                                                          op=None)
            # Back to (b, n, h*d).
            out = (out.unsqueeze(0).reshape(
                b, self.heads, out.shape[1],
                self.dim_head).permute(0, 2, 1,
                                       3).reshape(b, out.shape[1],
                                                  self.heads * self.dim_head))

        if k_ip is not None:
            # For image cross-attention
            k_ip, v_ip = map(
                lambda t: t.unsqueeze(3).reshape(b, t.shape[
                    1], self.heads, self.dim_head).permute(0, 2, 1, 3).reshape(
                        b * self.heads, t.shape[1], self.dim_head).contiguous(
                        ),
                (k_ip, v_ip),
            )
            out_ip = xformers.ops.memory_efficient_attention(q,
                                                             k_ip,
                                                             v_ip,
                                                             attn_bias=None,
                                                             op=None)
            out_ip = (out_ip.unsqueeze(0).reshape(
                b, self.heads, out_ip.shape[1],
                self.dim_head).permute(0, 2, 1,
                                       3).reshape(b, out_ip.shape[1],
                                                  self.heads * self.dim_head))

        if k_as is not None:
            # For agent state cross-attention
            k_as, v_as = map(
                lambda t: t.unsqueeze(3).reshape(b, t.shape[
                    1], self.heads, self.dim_head).permute(0, 2, 1, 3).reshape(
                        b * self.heads, t.shape[1], self.dim_head).contiguous(
                        ),
                (k_as, v_as),
            )
            out_as = xformers.ops.memory_efficient_attention(q,
                                                             k_as,
                                                             v_as,
                                                             attn_bias=None,
                                                             op=None)
            out_as = (out_as.unsqueeze(0).reshape(
                b, self.heads, out_as.shape[1],
                self.dim_head).permute(0, 2, 1,
                                       3).reshape(b, out_as.shape[1],
                                                  self.heads * self.dim_head))
        if k_aa is not None:
            # For agent action cross-attention
            k_aa, v_aa = map(
                lambda t: t.unsqueeze(3).reshape(b, t.shape[
                    1], self.heads, self.dim_head).permute(0, 2, 1, 3).reshape(
                        b * self.heads, t.shape[1], self.dim_head).contiguous(
                        ),
                (k_aa, v_aa),
            )

            # Expand the (b, l1, l2) additive bias across heads and match the
            # query dtype for xformers.
            attn_mask_aa = attn_mask_aa.unsqueeze(1).repeat(1, self.heads, 1, 1).reshape(
                b * self.heads, attn_mask_aa.shape[1], attn_mask_aa.shape[2])
            attn_mask_aa = attn_mask_aa.to(q.dtype)

            out_aa = xformers.ops.memory_efficient_attention(
                q, k_aa, v_aa, attn_bias=attn_mask_aa, op=None)

            out_aa = (out_aa.unsqueeze(0).reshape(
                b, self.heads, out_aa.shape[1],
                self.dim_head).permute(0, 2, 1,
                                       3).reshape(b, out_aa.shape[1],
                                                  self.heads * self.dim_head))
        if exists(mask):
            # Arbitrary masks are not supported on the xformers path.
            raise NotImplementedError

        # Missing modalities contribute nothing to the blend.
        out = 0.0 if out is None else out
        out_ip = 0.0 if out_ip is None else out_ip
        out_as = 0.0 if out_as is None else out_as
        out_aa = 0.0 if out_aa is None else out_aa

        if self.cross_attention_scale_learnable:
            # Learnable gates; tanh(alpha)+1 keeps each gate in (0, 2).
            out = out + \
                self.image_cross_attention_scale * out_ip * (torch.tanh(self.alpha_ctx) + 1) + \
                self.agent_state_cross_attention_scale * out_as * (torch.tanh(self.alpha_cas) + 1) + \
                self.agent_action_cross_attention_scale * out_aa * (torch.tanh(self.alpha_caa) + 1)

        else:
            out = out + \
                self.image_cross_attention_scale * out_ip + \
                self.agent_state_cross_attention_scale * out_as + \
                self.agent_action_cross_attention_scale * out_aa

        return self.to_out(out)
|
||||
|
||||
def _get_attn_mask_aa(self, b, l1, l2, block_size=16):
|
||||
num_token = l2 // block_size
|
||||
start_positions = ((torch.arange(b) % block_size) + 1) * num_token
|
||||
col_indices = torch.arange(l2)
|
||||
mask_2d = col_indices.unsqueeze(0) >= start_positions.unsqueeze(1)
|
||||
mask = mask_2d.unsqueeze(1).expand(b, l1, l2)
|
||||
attn_mask = torch.zeros_like(mask, dtype=torch.float)
|
||||
attn_mask[mask] = float('-inf')
|
||||
return attn_mask
|
||||
|
||||
|
||||
class BasicTransformerBlock(nn.Module):
    """Standard transformer block: self-attn -> cross-attn -> feed-forward,
    each pre-normalized with a residual connection.

    ``attn2`` carries the multi-modal conditioning options (image / agent
    state / agent action); ``attn1`` is pure self-attention unless
    ``disable_self_attn`` is set, in which case it cross-attends too.
    """

    def __init__(self,
                 dim,
                 n_heads,
                 d_head,
                 dropout=0.,
                 context_dim=None,
                 gated_ff=True,
                 checkpoint=True,
                 disable_self_attn=False,
                 attention_cls=None,
                 video_length=None,
                 agent_state_context_len=2,
                 agent_action_context_len=16,
                 image_cross_attention=False,
                 image_cross_attention_scale=1.0,
                 cross_attention_scale_learnable=False,
                 text_context_len=77):
        super().__init__()
        attn_cls = CrossAttention if attention_cls is None else attention_cls
        self.disable_self_attn = disable_self_attn
        # attn1: self-attention (context only wired in when disable_self_attn).
        self.attn1 = attn_cls(
            query_dim=dim,
            heads=n_heads,
            dim_head=d_head,
            dropout=dropout,
            context_dim=context_dim if self.disable_self_attn else None)
        self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff)
        # attn2: cross-attention over the packed conditioning context.
        self.attn2 = attn_cls(
            query_dim=dim,
            context_dim=context_dim,
            heads=n_heads,
            dim_head=d_head,
            dropout=dropout,
            video_length=video_length,
            agent_state_context_len=agent_state_context_len,
            agent_action_context_len=agent_action_context_len,
            image_cross_attention=image_cross_attention,
            image_cross_attention_scale=image_cross_attention_scale,
            cross_attention_scale_learnable=cross_attention_scale_learnable,
            text_context_len=text_context_len)
        self.image_cross_attention = image_cross_attention

        self.norm1 = nn.LayerNorm(dim)
        self.norm2 = nn.LayerNorm(dim)
        self.norm3 = nn.LayerNorm(dim)
        self.checkpoint = checkpoint

    def forward(self, x, context=None, mask=None, **kwargs):
        # implementation tricks: because checkpointing doesn't support non-tensor (e.g. None or scalar) arguments
        input_tuple = (
            x,
        )  # should not be (x), otherwise *input_tuple will decouple x into multiple arguments
        if context is not None:
            input_tuple = (x, context)
        if mask is not None:
            # NOTE(review): on this path only x is forwarded, so a non-None
            # context is silently dropped when a mask is given — confirm that
            # masked (temporal) attention is never called with a context.
            forward_mask = partial(self._forward, mask=mask)
            return checkpoint(forward_mask, (x, ), self.parameters(),
                              self.checkpoint)
        return checkpoint(self._forward, input_tuple, self.parameters(),
                          self.checkpoint)

    def _forward(self, x, context=None, mask=None):
        # Pre-norm residual sub-blocks, applied in sequence.
        x = self.attn1(self.norm1(x),
                       context=context if self.disable_self_attn else None,
                       mask=mask) + x
        x = self.attn2(self.norm2(x), context=context, mask=mask) + x
        x = self.ff(self.norm3(x)) + x
        return x
|
||||
|
||||
|
||||
class SpatialTransformer(nn.Module):
    """
    Transformer block for image-like data in spatial axis.
    First, project the input (aka embedding)
    and reshape to b, t, d.
    Then apply standard transformer action.
    Finally, reshape to image
    NEW: use_linear for more efficiency instead of the 1x1 convs
    """

    def __init__(self,
                 in_channels,
                 n_heads,
                 d_head,
                 depth=1,
                 dropout=0.,
                 context_dim=None,
                 use_checkpoint=True,
                 disable_self_attn=False,
                 use_linear=False,
                 video_length=None,
                 agent_state_context_len=2,
                 agent_action_context_len=16,
                 image_cross_attention=False,
                 cross_attention_scale_learnable=False):
        super().__init__()
        self.in_channels = in_channels
        inner_dim = n_heads * d_head
        self.norm = torch.nn.GroupNorm(num_groups=32,
                                       num_channels=in_channels,
                                       eps=1e-6,
                                       affine=True)
        # Input projection: 1x1 conv applied in (b, c, h, w) space, or a
        # Linear applied after flattening to tokens (see forward ordering).
        if not use_linear:
            self.proj_in = nn.Conv2d(in_channels,
                                     inner_dim,
                                     kernel_size=1,
                                     stride=1,
                                     padding=0)
        else:
            self.proj_in = nn.Linear(in_channels, inner_dim)

        # None selects the default CrossAttention inside the blocks.
        attention_cls = None
        self.transformer_blocks = nn.ModuleList([
            BasicTransformerBlock(
                inner_dim,
                n_heads,
                d_head,
                dropout=dropout,
                context_dim=context_dim,
                disable_self_attn=disable_self_attn,
                checkpoint=use_checkpoint,
                attention_cls=attention_cls,
                video_length=video_length,
                agent_state_context_len=agent_state_context_len,
                agent_action_context_len=agent_action_context_len,
                image_cross_attention=image_cross_attention,
                cross_attention_scale_learnable=cross_attention_scale_learnable,
            ) for d in range(depth)
        ])
        # Output projection is zero-initialized so the whole module starts
        # out as the identity (residual added in forward).
        if not use_linear:
            self.proj_out = zero_module(
                nn.Conv2d(inner_dim,
                          in_channels,
                          kernel_size=1,
                          stride=1,
                          padding=0))
        else:
            self.proj_out = zero_module(nn.Linear(inner_dim, in_channels))
        self.use_linear = use_linear

    def forward(self, x, context=None, **kwargs):
        """Apply the block to (b, c, h, w) features; returns the same shape."""
        b, c, h, w = x.shape
        x_in = x
        x = self.norm(x)
        # Conv projection happens before flattening, Linear after.
        if not self.use_linear:
            x = self.proj_in(x)
        x = rearrange(x, 'b c h w -> b (h w) c').contiguous()
        if self.use_linear:
            x = self.proj_in(x)
        for i, block in enumerate(self.transformer_blocks):
            x = block(x, context=context, **kwargs)
        if self.use_linear:
            x = self.proj_out(x)
        x = rearrange(x, 'b (h w) c -> b c h w', h=h, w=w).contiguous()
        if not self.use_linear:
            x = self.proj_out(x)
        # Residual connection around the whole transformer stack.
        return x + x_in
|
||||
|
||||
|
||||
class TemporalTransformer(nn.Module):
    """
    Transformer block for image-like data in the temporal axis.

    The (b, c, t, h, w) input is reshaped so that every spatial location
    becomes an independent sequence of t tokens, a stack of
    BasicTransformerBlocks is applied, and the result is reshaped back and
    added to the input (residual; proj_out is zero-initialized so the block
    starts as identity).

    Args:
        in_channels: channel count of the incoming feature map.
        n_heads, d_head: attention head layout; inner width is n_heads*d_head.
        depth: number of stacked transformer blocks.
        dropout: dropout rate inside the blocks.
        context_dim: cross-attention context width (forced to None when
            only_self_att).
        use_checkpoint: enable gradient checkpointing inside the blocks.
        use_linear: use Linear projections in/out instead of 1x1 Conv1d.
        only_self_att: restrict all blocks to temporal self-attention.
        causal_attention: apply a lower-triangular temporal mask.
        causal_block_size: stored for configuration compatibility; not used
            in this module's computation.
        relative_position: add learned relative-position bias to attention.
        temporal_length: maximum sequence length (required when causal or
            relative-position attention is enabled).
    """

    def __init__(self,
                 in_channels,
                 n_heads,
                 d_head,
                 depth=1,
                 dropout=0.,
                 context_dim=None,
                 use_checkpoint=True,
                 use_linear=False,
                 only_self_att=True,
                 causal_attention=False,
                 causal_block_size=1,
                 relative_position=False,
                 temporal_length=None):
        super().__init__()
        self.only_self_att = only_self_att
        self.relative_position = relative_position
        self.causal_attention = causal_attention
        self.causal_block_size = causal_block_size

        self.in_channels = in_channels
        inner_dim = n_heads * d_head
        self.norm = torch.nn.GroupNorm(num_groups=32,
                                       num_channels=in_channels,
                                       eps=1e-6,
                                       affine=True)
        # Input projection. (Fix: the original built an unconditional Conv1d
        # here and then immediately overwrote it in the if/else below — a
        # dead module allocation that also consumed RNG state. Exactly one
        # projection is created now.)
        if not use_linear:
            self.proj_in = nn.Conv1d(in_channels,
                                     inner_dim,
                                     kernel_size=1,
                                     stride=1,
                                     padding=0)
        else:
            self.proj_in = nn.Linear(in_channels, inner_dim)

        if relative_position:
            assert (temporal_length is not None)
            attention_cls = partial(CrossAttention,
                                    relative_position=True,
                                    temporal_length=temporal_length)
        else:
            attention_cls = partial(CrossAttention,
                                    temporal_length=temporal_length)
        if self.causal_attention:
            assert (temporal_length is not None)
            # Lower-triangular (causal) mask over the maximum length; sliced
            # to the actual t and moved to the input device in forward.
            # Plain attribute (not a buffer), so it is not in state_dict.
            self.mask = torch.tril(
                torch.ones([1, temporal_length, temporal_length]))

        if self.only_self_att:
            # Self-attention only: make sure the blocks build no cross-attn
            # context projections.
            context_dim = None
        self.transformer_blocks = nn.ModuleList([
            BasicTransformerBlock(inner_dim,
                                  n_heads,
                                  d_head,
                                  dropout=dropout,
                                  context_dim=context_dim,
                                  attention_cls=attention_cls,
                                  checkpoint=use_checkpoint)
            for d in range(depth)
        ])
        # Zero-initialized output projection (identity at start of training).
        if not use_linear:
            self.proj_out = zero_module(
                nn.Conv1d(inner_dim,
                          in_channels,
                          kernel_size=1,
                          stride=1,
                          padding=0))
        else:
            self.proj_out = zero_module(nn.Linear(inner_dim, in_channels))
        self.use_linear = use_linear

    def forward(self, x, context=None):
        """Apply temporal attention to (b, c, t, h, w); returns same shape.

        Args:
            x: feature volume of shape (b, c, t, h, w).
            context: optional cross-attention context of shape
                (b*t, l, con); only used when ``only_self_att`` is False.
        """
        b, c, t, h, w = x.shape
        x_in = x
        x = self.norm(x)
        # Every spatial location becomes one temporal sequence.
        x = rearrange(x, 'b c t h w -> (b h w) c t').contiguous()
        if not self.use_linear:
            x = self.proj_in(x)
        x = rearrange(x, 'bhw c t -> bhw t c').contiguous()
        if self.use_linear:
            x = self.proj_in(x)

        temp_mask = None
        if self.causal_attention:
            # Slice the pre-built causal mask to the current length.
            temp_mask = self.mask[:, :t, :t].to(x.device)

        if temp_mask is not None:
            mask = temp_mask.to(x.device)
            # One mask copy per (batch, spatial-location) sequence.
            mask = repeat(mask, 'l i j -> (l bhw) i j', bhw=b * h * w)
        else:
            mask = None

        if self.only_self_att:
            # NOTE: if no context is given, cross-attention defaults to self-attention
            for i, block in enumerate(self.transformer_blocks):
                x = block(x, mask=mask)
            x = rearrange(x, '(b hw) t c -> b hw t c', b=b).contiguous()
        else:
            x = rearrange(x, '(b hw) t c -> b hw t c', b=b).contiguous()
            context = rearrange(context, '(b t) l con -> b t l con',
                                t=t).contiguous()
            for i, block in enumerate(self.transformer_blocks):
                # Calculate each batch one by one (since number in shape could not greater then 65,535 for some package)
                for j in range(b):
                    # Repeat per-frame context across spatial locations.
                    # NOTE(review): r=(h*w)//t assumes h*w is a multiple of t
                    # — confirm for all resolutions used.
                    context_j = repeat(context[j],
                                       't l con -> (t r) l con',
                                       r=(h * w) // t,
                                       t=t).contiguous()
                    # Note: causal mask will not applied in cross-attention case
                    x[j] = block(x[j], context=context_j)

        if self.use_linear:
            x = self.proj_out(x)
            x = rearrange(x, 'b (h w) t c -> b c t h w', h=h, w=w).contiguous()
        if not self.use_linear:
            x = rearrange(x, 'b hw t c -> (b hw) c t').contiguous()
            x = self.proj_out(x)
            x = rearrange(x, '(b h w) c t -> b c t h w', b=b, h=h,
                          w=w).contiguous()

        return x + x_in
|
||||
|
||||
|
||||
class GEGLU(nn.Module):
    """Gated-GELU projection: one Linear produces a value half and a gate
    half; output is value * GELU(gate)."""

    def __init__(self, dim_in, dim_out):
        super().__init__()
        # Single projection emitting both halves at once.
        self.proj = nn.Linear(dim_in, dim_out * 2)

    def forward(self, x):
        value, gate = torch.chunk(self.proj(x), 2, dim=-1)
        return value * F.gelu(gate)
|
||||
|
||||
|
||||
class FeedForward(nn.Module):
    """Transformer MLP: widen by ``mult``, optionally gate with GEGLU,
    apply dropout, then project to ``dim_out`` (defaults to ``dim``)."""

    def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.):
        super().__init__()
        hidden_dim = int(dim * mult)
        dim_out = default(dim_out, dim)
        # Choose the up-projection: plain Linear+GELU, or a gated GEGLU.
        if glu:
            input_proj = GEGLU(dim, hidden_dim)
        else:
            input_proj = nn.Sequential(nn.Linear(dim, hidden_dim), nn.GELU())

        self.net = nn.Sequential(
            input_proj,
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, dim_out),
        )

    def forward(self, x):
        return self.net(x)
|
||||
|
||||
|
||||
class LinearAttention(nn.Module):
    """Linear-complexity attention over the spatial grid: keys are softmaxed
    over positions, aggregated with the values first, and then applied to the
    queries, avoiding the (h*w)^2 similarity matrix."""

    def __init__(self, dim, heads=4, dim_head=32):
        super().__init__()
        self.heads = heads
        hidden_dim = dim_head * heads
        # Fused q/k/v projection as one 1x1 conv (no bias).
        self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias=False)
        self.to_out = nn.Conv2d(hidden_dim, dim, 1)

    def forward(self, x):
        b, c, h, w = x.shape
        qkv = self.to_qkv(x)
        # Split channels into (q|k|v, heads, head_dim) and flatten the grid:
        # each of q, k, v becomes (b, heads, head_dim, h*w).
        head_dim = qkv.shape[1] // (3 * self.heads)
        q, k, v = qkv.reshape(b, 3, self.heads, head_dim, h * w).unbind(dim=1)
        k = k.softmax(dim=-1)
        # Aggregate values against normalized keys first — O(h*w) overall.
        context = torch.einsum('bhdn,bhen->bhde', k, v)
        out = torch.einsum('bhde,bhdn->bhen', context, q)
        out = out.reshape(b, self.heads * head_dim, h, w)
        return self.to_out(out)
|
||||
|
||||
|
||||
class SpatialSelfAttention(nn.Module):
    """Single-head self-attention across the spatial (h*w) positions of a
    feature map, with GroupNorm in front and a residual connection."""

    def __init__(self, in_channels):
        super().__init__()
        self.in_channels = in_channels

        self.norm = torch.nn.GroupNorm(num_groups=32,
                                       num_channels=in_channels,
                                       eps=1e-6,
                                       affine=True)
        # 1x1 convs act as per-position linear projections.
        self.q = torch.nn.Conv2d(in_channels,
                                 in_channels,
                                 kernel_size=1,
                                 stride=1,
                                 padding=0)
        self.k = torch.nn.Conv2d(in_channels,
                                 in_channels,
                                 kernel_size=1,
                                 stride=1,
                                 padding=0)
        self.v = torch.nn.Conv2d(in_channels,
                                 in_channels,
                                 kernel_size=1,
                                 stride=1,
                                 padding=0)
        self.proj_out = torch.nn.Conv2d(in_channels,
                                        in_channels,
                                        kernel_size=1,
                                        stride=1,
                                        padding=0)

    def forward(self, x):
        residual = x
        normed = self.norm(x)
        q = self.q(normed)
        k = self.k(normed)
        v = self.v(normed)

        # Attention weights between every pair of spatial positions.
        b, c, h, w = q.shape
        q = q.reshape(b, c, h * w).permute(0, 2, 1)  # (b, hw, c)
        k = k.reshape(b, c, h * w)                   # (b, c, hw)
        attn = torch.bmm(q, k)                       # (b, hw_q, hw_k)

        attn = attn * (int(c)**(-0.5))
        attn = torch.nn.functional.softmax(attn, dim=2)

        # Gather values for each query position.
        v = v.reshape(b, c, h * w)
        attn = attn.permute(0, 2, 1)                 # (b, hw_k, hw_q)
        out = torch.bmm(v, attn)                     # (b, c, hw_q)
        out = out.reshape(b, c, h, w)
        out = self.proj_out(out)

        return residual + out
|
||||
630
src/unifolm_wma/modules/encoders/condition.py
Normal file
630
src/unifolm_wma/modules/encoders/condition.py
Normal file
@@ -0,0 +1,630 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import kornia
|
||||
import open_clip
|
||||
import math
|
||||
|
||||
from torch.utils.checkpoint import checkpoint
|
||||
from transformers import T5Tokenizer, T5EncoderModel, CLIPTokenizer, CLIPTextModel
|
||||
|
||||
from unifolm_wma.utils.common import autocast
|
||||
from unifolm_wma.utils.utils import count_params
|
||||
from unifolm_wma.modules.encoders.resampler import reshape_tensor
|
||||
|
||||
|
||||
class AbstractEncoder(nn.Module):
    """Base class for conditioning encoders; subclasses must provide
    ``encode``."""

    def __init__(self):
        super().__init__()

    def encode(self, *args, **kwargs):
        """Map raw conditioning input to an embedding; must be overridden."""
        raise NotImplementedError
|
||||
|
||||
|
||||
class IdentityEncoder(AbstractEncoder):
    """Encoder that returns its input unchanged (conditioning is already in
    the desired form)."""

    def encode(self, x):
        return x
|
||||
|
||||
|
||||
class ClassEmbedder(nn.Module):
    """Embed integer class labels for cross-attention conditioning.

    During training, each label is replaced with probability ``ucg_rate`` by
    the last class index (``n_classes - 1``), which is reserved as the
    unconditional token for classifier-free guidance.
    """

    def __init__(self, embed_dim, n_classes=1000, key='class', ucg_rate=0.1):
        super().__init__()
        self.key = key
        self.embedding = nn.Embedding(n_classes, embed_dim)
        self.n_classes = n_classes
        self.ucg_rate = ucg_rate

    def forward(self, batch, key=None, disable_dropout=False):
        lookup_key = self.key if key is None else key
        # Singleton token axis so the result works as a cross-attn context.
        c = batch[lookup_key][:, None]
        if self.ucg_rate > 0. and not disable_dropout:
            # keep==1 preserves the label, keep==0 swaps in the reserved
            # unconditional class (n_classes - 1).
            keep = 1. - torch.bernoulli(torch.ones_like(c) * self.ucg_rate)
            c = keep * c + (1 - keep) * torch.ones_like(c) * (self.n_classes -
                                                              1)
            c = c.long()
        return self.embedding(c)

    def get_unconditional_conditioning(self, bs, device="cuda"):
        # 1000 classes --> 0 ... 999, one extra class for ucg (class 1000)
        uc_class = self.n_classes - 1
        uc = torch.ones((bs, ), device=device) * uc_class
        return {self.key: uc}
|
||||
|
||||
|
||||
def disabled_train(self, mode=True):
    """No-op stand-in for ``nn.Module.train`` so a frozen model keeps its
    current train/eval mode no matter what outer ``.train()``/``.eval()``
    calls happen; ``mode`` is deliberately ignored."""
    return self
|
||||
|
||||
|
||||
class FrozenT5Embedder(AbstractEncoder):
    """Uses the (frozen) T5 transformer encoder for text.

    Returns the encoder's last hidden state of shape
    (batch, max_length, d_model).
    """

    def __init__(self,
                 version="google/t5-v1_1-xxl",
                 device="cuda",
                 max_length=77,
                 freeze=True
                 ):  # others are google/t5-v1_1-xl and google/t5-v1_1-xxl
        super().__init__()
        self.tokenizer = T5Tokenizer.from_pretrained(version)
        self.transformer = T5EncoderModel.from_pretrained(version)
        # NOTE(review): the transformer is not moved to `device` here; only
        # the token ids are (see forward). Presumably the surrounding
        # pipeline relocates the module — confirm.
        self.device = device
        self.max_length = max_length  # TODO: typical value?
        if freeze:
            self.freeze()

    def freeze(self):
        # Eval mode plus gradient stop for all parameters.
        self.transformer = self.transformer.eval()
        # self.train = disabled_train
        for param in self.parameters():
            param.requires_grad = False

    def forward(self, text):
        """Tokenize ``text`` (padded/truncated to max_length) and encode it."""
        batch_encoding = self.tokenizer(text,
                                        truncation=True,
                                        max_length=self.max_length,
                                        return_length=True,
                                        return_overflowing_tokens=False,
                                        padding="max_length",
                                        return_tensors="pt")
        tokens = batch_encoding["input_ids"].to(self.device)
        outputs = self.transformer(input_ids=tokens)

        z = outputs.last_hidden_state
        return z

    def encode(self, text):
        return self(text)
|
||||
|
||||
|
||||
class FrozenCLIPEmbedder(AbstractEncoder):
    """Uses the CLIP transformer encoder for text (from huggingface)"""
    # Which representation to return: final hidden states, the pooled
    # embedding, or a specific hidden layer (selected by layer_idx).
    LAYERS = ["last", "pooled", "hidden"]

    def __init__(self,
                 version="openai/clip-vit-large-patch14",
                 device="cuda",
                 max_length=77,
                 freeze=True,
                 layer="last",
                 layer_idx=None):  # clip-vit-base-patch32
        super().__init__()
        assert layer in self.LAYERS
        self.tokenizer = CLIPTokenizer.from_pretrained(version)
        self.transformer = CLIPTextModel.from_pretrained(version)
        # NOTE(review): the transformer is not moved to `device` here; only
        # token ids are (see forward) — confirm the caller relocates it.
        self.device = device
        self.max_length = max_length
        if freeze:
            self.freeze()
        self.layer = layer
        self.layer_idx = layer_idx
        if layer == "hidden":
            # The ViT-L text tower has 12 layers, hence the bound.
            assert layer_idx is not None
            assert 0 <= abs(layer_idx) <= 12

    def freeze(self):
        # Eval mode plus gradient stop for all parameters.
        self.transformer = self.transformer.eval()
        # self.train = disabled_train
        for param in self.parameters():
            param.requires_grad = False

    def forward(self, text):
        """Tokenize ``text`` (padded/truncated to max_length) and return the
        representation selected by ``self.layer``."""
        batch_encoding = self.tokenizer(text,
                                        truncation=True,
                                        max_length=self.max_length,
                                        return_length=True,
                                        return_overflowing_tokens=False,
                                        padding="max_length",
                                        return_tensors="pt")
        tokens = batch_encoding["input_ids"].to(self.device)
        # Hidden states are only materialized when actually requested.
        outputs = self.transformer(input_ids=tokens,
                                   output_hidden_states=self.layer == "hidden")
        if self.layer == "last":
            z = outputs.last_hidden_state
        elif self.layer == "pooled":
            # Keep a sequence axis of length 1 for cross-attention use.
            z = outputs.pooler_output[:, None, :]
        else:
            z = outputs.hidden_states[self.layer_idx]
        return z

    def encode(self, text):
        return self(text)
|
||||
|
||||
|
||||
class ClipImageEmbedder(nn.Module):
    """Embed images with an OpenAI CLIP vision tower.

    Input images are expected in [-1, 1]; with probability ``ucg_rate`` an
    embedding is zeroed out (classifier-free guidance dropout).
    """

    def __init__(self,
                 model,
                 jit=False,
                 device='cuda' if torch.cuda.is_available() else 'cpu',
                 antialias=True,
                 ucg_rate=0.):
        super().__init__()
        from clip import load as load_clip
        self.model, _ = load_clip(name=model, device=device, jit=jit)

        self.antialias = antialias

        # CLIP normalization statistics; registered as buffers so they follow
        # .to()/device moves but (persistent=False) stay out of checkpoints.
        self.register_buffer('mean',
                             torch.Tensor([0.48145466, 0.4578275, 0.40821073]),
                             persistent=False)
        self.register_buffer('std',
                             torch.Tensor([0.26862954, 0.26130258,
                                           0.27577711]),
                             persistent=False)
        self.ucg_rate = ucg_rate

    def preprocess(self, x):
        """Resize to CLIP's 224x224 input and apply CLIP normalization."""
        # normalize to [0,1]
        x = kornia.geometry.resize(x, (224, 224),
                                   interpolation='bicubic',
                                   align_corners=True,
                                   antialias=self.antialias)
        x = (x + 1.) / 2.
        # re-normalize according to clip
        x = kornia.enhance.normalize(x, self.mean, self.std)
        return x

    def forward(self, x, no_dropout=False):
        # x is assumed to be in range [-1,1]
        out = self.model.encode_image(self.preprocess(x))
        out = out.to(x.dtype)
        if self.ucg_rate > 0. and not no_dropout:
            # Zero out whole embeddings at rate ucg_rate (each sample is kept
            # with probability 1 - ucg_rate).
            out = torch.bernoulli(
                (1. - self.ucg_rate) *
                torch.ones(out.shape[0], device=out.device))[:, None] * out
        return out
|
||||
|
||||
|
||||
class FrozenOpenCLIPEmbedder(AbstractEncoder):
    """
    Uses the OpenCLIP transformer encoder for text
    """
    # Representation to return; layer_idx below is the number of final
    # transformer blocks to skip.
    LAYERS = [
        # "pooled",
        "last",
        "penultimate"
    ]

    def __init__(self,
                 arch="ViT-H-14",
                 version="laion2b_s32b_b79k",
                 device="cuda",
                 max_length=77,
                 freeze=True,
                 layer="last"):
        super().__init__()
        assert layer in self.LAYERS
        model, _, _ = open_clip.create_model_and_transforms(
            arch, device=torch.device('cpu'), pretrained=version)
        # The vision tower is not needed for text encoding; free its memory.
        del model.visual
        self.model = model

        # NOTE(review): the model is created on CPU and not moved to `device`
        # here; only tokens are (see forward) — confirm the caller relocates it.
        self.device = device
        self.max_length = max_length
        if freeze:
            self.freeze()
        self.layer = layer
        if self.layer == "last":
            self.layer_idx = 0
        elif self.layer == "penultimate":
            self.layer_idx = 1
        else:
            raise NotImplementedError()

    def freeze(self):
        # Eval mode plus gradient stop for all parameters.
        self.model = self.model.eval()
        for param in self.parameters():
            param.requires_grad = False

    def forward(self, text):
        tokens = open_clip.tokenize(
            text)  ## all clip models use 77 as context length
        z = self.encode_with_transformer(tokens.to(self.device))
        return z

    def encode_with_transformer(self, text):
        """Run the text tower manually so intermediate layers are reachable."""
        x = self.model.token_embedding(text)  # [batch_size, n_ctx, d_model]
        x = x + self.model.positional_embedding
        x = x.permute(1, 0, 2)  # NLD -> LND
        x = self.text_transformer_forward(x, attn_mask=self.model.attn_mask)
        x = x.permute(1, 0, 2)  # LND -> NLD
        x = self.model.ln_final(x)
        return x

    def text_transformer_forward(self, x: torch.Tensor, attn_mask=None):
        # Stop layer_idx blocks before the end (0 = run all, 1 = penultimate).
        for i, r in enumerate(self.model.transformer.resblocks):
            if i == len(self.model.transformer.resblocks) - self.layer_idx:
                break
            if self.model.transformer.grad_checkpointing and not torch.jit.is_scripting(
            ):
                x = checkpoint(r, x, attn_mask)
            else:
                x = r(x, attn_mask=attn_mask)
        return x

    def encode(self, text):
        return self(text)
|
||||
|
||||
|
||||
class FrozenOpenCLIPImageEmbedder(AbstractEncoder):
    """
    Uses the OpenCLIP vision transformer encoder for images
    """

    def __init__(self,
                 arch="ViT-H-14",
                 version="laion2b_s32b_b79k",
                 device="cuda",
                 max_length=77,
                 freeze=True,
                 layer="pooled",
                 antialias=True,
                 ucg_rate=0.):
        """Load an OpenCLIP vision tower.

        Args:
            arch: OpenCLIP architecture name.
            version: Pretrained-weight tag passed to open_clip.
            device: Stored device label (model itself is built on CPU here).
            max_length: Stored for interface parity with the text encoder.
            freeze: If True, switch to eval mode and disable gradients.
            layer: Output tap; only the default pooled output is supported
                ("penultimate" raises NotImplementedError).
            antialias: Whether to antialias the 224x224 resize.
            ucg_rate: Probability of zeroing an embedding (CFG dropout).
        """
        super().__init__()
        model, _, _ = open_clip.create_model_and_transforms(
            arch,
            device=torch.device('cpu'),
            pretrained=version,
        )
        # The text transformer is unused for image encoding; free its memory.
        del model.transformer
        self.model = model
        # self.mapper = torch.nn.Linear(1280, 1024)
        self.device = device
        self.max_length = max_length
        if freeze:
            self.freeze()
        self.layer = layer
        if self.layer == "penultimate":
            raise NotImplementedError()
            self.layer_idx = 1  # unreachable

        self.antialias = antialias

        # CLIP normalization statistics; non-persistent so they stay out of
        # checkpoints.
        self.register_buffer('mean',
                             torch.Tensor([0.48145466, 0.4578275, 0.40821073]),
                             persistent=False)
        self.register_buffer('std',
                             torch.Tensor([0.26862954, 0.26130258,
                                           0.27577711]),
                             persistent=False)
        self.ucg_rate = ucg_rate

    def preprocess(self, x):
        """Resize to 224x224 and map [-1, 1] input to CLIP-normalized input."""
        # normalize to [0,1]
        x = kornia.geometry.resize(x, (224, 224),
                                   interpolation='bicubic',
                                   align_corners=True,
                                   antialias=self.antialias)
        x = (x + 1.) / 2.
        # renormalize according to clip
        x = kornia.enhance.normalize(x, self.mean, self.std)
        return x

    def freeze(self):
        """Put the model in eval mode and disable all gradients."""
        self.model = self.model.eval()
        for param in self.model.parameters():
            param.requires_grad = False

    @autocast
    def forward(self, image, no_dropout=False):
        """Encode `image`; optionally apply per-sample CFG dropout."""
        z = self.encode_with_vision_transformer(image)
        if self.ucg_rate > 0. and not no_dropout:
            # Zero whole embeddings with probability ucg_rate.
            z = torch.bernoulli(
                (1. - self.ucg_rate) *
                torch.ones(z.shape[0], device=z.device))[:, None] * z
        return z

    def encode_with_vision_transformer(self, img):
        """Preprocess and run the OpenCLIP visual tower."""
        img = self.preprocess(img)
        x = self.model.visual(img)
        return x

    def encode(self, text):
        # NOTE(review): parameter is named `text` but forward() takes an
        # image batch — naming kept for the AbstractEncoder interface.
        return self(text)
|
||||
|
||||
|
||||
class FrozenOpenCLIPImageEmbedderV2(AbstractEncoder):
    """
    Uses the OpenCLIP vision transformer encoder for images
    """

    def __init__(self,
                 arch="ViT-H-14",
                 version="laion2b_s32b_b79k",
                 device="cuda",
                 freeze=True,
                 layer="pooled",
                 antialias=True):
        """Load an OpenCLIP vision tower whose forward returns ALL tokens.

        Unlike FrozenOpenCLIPImageEmbedder, encode_with_vision_transformer
        re-implements the visual forward pass and returns the full token
        sequence (class token + patch tokens) without pooling.
        """
        super().__init__()
        model, _, _ = open_clip.create_model_and_transforms(
            arch,
            device=torch.device('cpu'),
            pretrained=version,
        )
        # The text transformer is unused for image encoding; free its memory.
        del model.transformer
        self.model = model
        self.device = device

        if freeze:
            self.freeze()
        self.layer = layer
        if self.layer == "penultimate":
            raise NotImplementedError()
            self.layer_idx = 1  # unreachable

        self.antialias = antialias

        # CLIP normalization statistics; non-persistent so they stay out of
        # checkpoints.
        self.register_buffer('mean',
                             torch.Tensor([0.48145466, 0.4578275, 0.40821073]),
                             persistent=False)
        self.register_buffer('std',
                             torch.Tensor([0.26862954, 0.26130258,
                                           0.27577711]),
                             persistent=False)

    def preprocess(self, x):
        """Resize to 224x224 and map [-1, 1] input to CLIP-normalized input."""
        # normalize to [0,1]
        x = kornia.geometry.resize(x, (224, 224),
                                   interpolation='bicubic',
                                   align_corners=True,
                                   antialias=self.antialias)
        x = (x + 1.) / 2.
        # renormalize according to clip
        x = kornia.enhance.normalize(x, self.mean, self.std)
        return x

    def freeze(self):
        """Put the model in eval mode and disable all gradients."""
        self.model = self.model.eval()
        for param in self.model.parameters():
            param.requires_grad = False

    def forward(self, image, no_dropout=False):
        ## image: b c h w
        # `no_dropout` is accepted for interface parity but unused here.
        z = self.encode_with_vision_transformer(image)
        return z

    def encode_with_vision_transformer(self, x):
        """Visual forward pass returning the unpooled token sequence."""
        x = self.preprocess(x)

        # to patches - whether to use dual patchnorm - https://arxiv.org/abs/2302.01327v1
        if self.model.visual.input_patchnorm:
            # einops - rearrange(x, 'b c (h p1) (w p2) -> b (h w) (c p1 p2)')
            x = x.reshape(x.shape[0], x.shape[1],
                          self.model.visual.grid_size[0],
                          self.model.visual.patch_size[0],
                          self.model.visual.grid_size[1],
                          self.model.visual.patch_size[1])
            x = x.permute(0, 2, 4, 1, 3, 5)
            x = x.reshape(
                x.shape[0], self.model.visual.grid_size[0] *
                self.model.visual.grid_size[1], -1)
            x = self.model.visual.patchnorm_pre_ln(x)
            x = self.model.visual.conv1(x)
        else:
            x = self.model.visual.conv1(x)  # shape = [*, width, grid, grid]
            x = x.reshape(x.shape[0], x.shape[1],
                          -1)  # shape = [*, width, grid ** 2]
            x = x.permute(0, 2, 1)  # shape = [*, grid ** 2, width]

        # class embeddings and positional embeddings
        x = torch.cat([
            self.model.visual.class_embedding.to(x.dtype) + torch.zeros(
                x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), x
        ],
                      dim=1)  # shape = [*, grid ** 2 + 1, width]
        x = x + self.model.visual.positional_embedding.to(x.dtype)

        # a patch_dropout of 0. would mean it is disabled and this function would do nothing but return what was passed in
        x = self.model.visual.patch_dropout(x)
        x = self.model.visual.ln_pre(x)

        x = x.permute(1, 0, 2)  # NLD -> LND
        x = self.model.visual.transformer(x)
        x = x.permute(1, 0, 2)  # LND -> NLD

        # NOTE: no ln_post / pooling — callers receive raw transformer tokens.
        return x
|
||||
|
||||
|
||||
class FrozenCLIPT5Encoder(AbstractEncoder):
    """Runs a frozen CLIP text encoder and a frozen T5 encoder side by side,
    returning both embeddings as a two-element list."""

    def __init__(self,
                 clip_version="openai/clip-vit-large-patch14",
                 t5_version="google/t5-v1_1-xl",
                 device="cuda",
                 clip_max_length=77,
                 t5_max_length=77):
        super().__init__()
        self.clip_encoder = FrozenCLIPEmbedder(clip_version,
                                               device,
                                               max_length=clip_max_length)
        self.t5_encoder = FrozenT5Embedder(t5_version,
                                           device,
                                           max_length=t5_max_length)
        # Report parameter counts of both sub-encoders at construction time.
        print(
            f"{self.clip_encoder.__class__.__name__} has {count_params(self.clip_encoder) * 1.e-6:.2f} M parameters, "
            f"{self.t5_encoder.__class__.__name__} comes with {count_params(self.t5_encoder) * 1.e-6:.2f} M params."
        )

    def forward(self, text):
        """Return [clip_embedding, t5_embedding] for `text`."""
        return [self.clip_encoder.encode(text), self.t5_encoder.encode(text)]

    def encode(self, text):
        """Alias for forward(), kept for the AbstractEncoder interface."""
        return self(text)
|
||||
|
||||
|
||||
class LinearProjector(nn.Module):
    """A single affine layer mapping `input_dim` features to `output_dim`."""

    def __init__(self, input_dim: int, output_dim: int) -> None:
        super().__init__()
        self.projector = nn.Linear(input_dim, output_dim, bias=True)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the linear projection to the last dimension of `x`."""
        projected = self.projector(x)
        return projected
|
||||
|
||||
|
||||
class MLPProjector(nn.Module):
    """Two-layer MLP projector with a configurable activation.

    `mlp_type` selects the non-linearity: "gelu-mlp" (tanh-approximated GELU)
    or "silu-mlp". Any other value raises ValueError.
    """

    def __init__(self,
                 input_dim: int,
                 output_dim: int,
                 mlp_type: str = "gelu-mlp") -> None:
        super().__init__()
        if mlp_type == "gelu-mlp":
            activation = nn.GELU(approximate='tanh')
        elif mlp_type == "silu-mlp":
            activation = nn.SiLU()
        else:
            raise ValueError(
                f"Projector with `{mlp_type = }` is not supported!")
        self.projector = nn.Sequential(
            nn.Linear(input_dim, output_dim, bias=True),
            activation,
            nn.Linear(output_dim, output_dim, bias=True),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Project the last dimension of `x` to `output_dim`."""
        return self.projector(x)
|
||||
|
||||
|
||||
class PerceiverAttention(nn.Module):
    """Perceiver-style cross-attention: latent queries attend over the
    concatenation of input features and the latents themselves."""

    def __init__(self, *, dim, dim_head=64, heads=8):
        super().__init__()
        self.scale = dim_head**-0.5
        self.dim_head = dim_head
        self.heads = heads
        inner_dim = dim_head * heads

        # Separate pre-norms for the context (x) and the latents.
        self.norm1 = nn.LayerNorm(dim)
        self.norm2 = nn.LayerNorm(dim)

        self.to_q = nn.Linear(dim, inner_dim, bias=False)
        self.to_kv = nn.Linear(dim, inner_dim * 2, bias=False)
        self.to_out = nn.Linear(inner_dim, dim, bias=False)

    def forward(self, x, latents):
        """
        Args:
            x (torch.Tensor): image features
                shape (b, n1, D)
            latent (torch.Tensor): latent features
                shape (b, n2, D)
        """

        x = self.norm1(x)
        latents = self.norm2(latents)

        b, l, _ = latents.shape

        # Queries come from the latents; keys/values from x AND the latents.
        q = self.to_q(latents)
        kv_input = torch.cat((x, latents), dim=-2)
        k, v = self.to_kv(kv_input).chunk(2, dim=-1)

        q = reshape_tensor(q, self.heads)
        k = reshape_tensor(k, self.heads)
        v = reshape_tensor(v, self.heads)

        # attention
        # Split the 1/sqrt(d) factor across q and k for fp16 stability.
        scale = 1 / math.sqrt(math.sqrt(self.dim_head))
        weight = (q * scale) @ (k * scale).transpose(
            -2, -1)  # More stable with f16 than dividing afterwards
        weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)
        out = weight @ v

        # Merge heads back: (b, heads, l, d_head) -> (b, l, heads*d_head).
        out = out.permute(0, 2, 1, 3).reshape(b, l, -1)

        return self.to_out(out)
|
||||
|
||||
|
||||
def FeedForward(dim, mult=4, ffd_type="gelu-ffd"):
    """Build a pre-norm feed-forward block.

    Args:
        dim: Input/output feature dimension.
        mult: Hidden-layer expansion factor (hidden = int(dim * mult)).
        ffd_type: "gelu-ffd" for tanh-approximated GELU, "silu-ffd" for SiLU.

    Returns:
        nn.Sequential: LayerNorm -> Linear -> activation -> Linear.

    Raises:
        ValueError: If `ffd_type` is not one of the supported kinds.
    """
    inner_dim = int(dim * mult)
    if ffd_type == "gelu-ffd":
        return nn.Sequential(
            nn.LayerNorm(dim),
            nn.Linear(dim, inner_dim, bias=False),
            nn.GELU(approximate='tanh'),
            nn.Linear(inner_dim, dim, bias=False),
        )
    elif ffd_type == "silu-ffd":
        return nn.Sequential(
            nn.LayerNorm(dim),
            nn.Linear(dim, inner_dim, bias=False),
            nn.SiLU(),
            nn.Linear(inner_dim, dim, bias=False),
        )
    else:
        # Bug fix: previously referenced the undefined name `mlp_type`,
        # raising NameError instead of the intended ValueError.
        raise ValueError(f"FeedForward with `{ffd_type = }` is not supported!")
|
||||
|
||||
|
||||
class SATokenProjector(nn.Module):
    """Perceiver-style token projector: learned latent queries cross-attend
    to the input tokens, then get linearly projected and normalized."""

    def __init__(self,
                 dim=1024,
                 depth=1,
                 dim_head=64,
                 heads=16,
                 num_queries=16,
                 output_dim=1024,
                 ff_mult=4,
                 chunk_size=None):
        super().__init__()
        self.num_queries = num_queries
        self.chunk_size = chunk_size

        # With chunking enabled, allocate `num_queries` latents per chunk.
        total_queries = num_queries if chunk_size is None else num_queries * chunk_size

        self.latents = nn.Parameter(
            torch.randn(1, total_queries, dim) / dim**0.5)
        self.proj_out = nn.Linear(dim, output_dim)
        self.norm_out = nn.LayerNorm(dim)

        self.layers = nn.ModuleList([
            nn.ModuleList([
                PerceiverAttention(dim=dim, dim_head=dim_head, heads=heads),
                FeedForward(dim=dim, mult=ff_mult),
            ]) for _ in range(depth)
        ])

    def forward(self, x):
        """Cross-attend the learned latents to `x` and project the result."""
        latents = self.latents.repeat(x.size(0), 1, 1)
        for attn, ff in self.layers:
            latents = latents + attn(x, latents)
            latents = latents + ff(latents)
        # NOTE: norm_out is LayerNorm(dim) applied AFTER proj_out, matching
        # the original ordering (well-defined when output_dim == dim).
        return self.norm_out(self.proj_out(latents))
|
||||
153
src/unifolm_wma/modules/encoders/resampler.py
Normal file
153
src/unifolm_wma/modules/encoders/resampler.py
Normal file
@@ -0,0 +1,153 @@
|
||||
# modified from https://github.com/mlfoundations/open_flamingo/blob/main/open_flamingo/src/helpers.py
|
||||
# and https://github.com/lucidrains/imagen-pytorch/blob/main/imagen_pytorch/imagen_pytorch.py
|
||||
# and https://github.com/tencent-ailab/IP-Adapter/blob/main/ip_adapter/resampler.py
|
||||
import math
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
|
||||
class ImageProjModel(nn.Module):
    """Projection Model: maps a CLIP image embedding to several context
    tokens for cross-attention."""

    def __init__(self,
                 cross_attention_dim=1024,
                 clip_embeddings_dim=1024,
                 clip_extra_context_tokens=4):
        super().__init__()
        self.cross_attention_dim = cross_attention_dim
        self.clip_extra_context_tokens = clip_extra_context_tokens
        self.proj = nn.Linear(
            clip_embeddings_dim,
            self.clip_extra_context_tokens * cross_attention_dim)
        self.norm = nn.LayerNorm(cross_attention_dim)

    def forward(self, image_embeds):
        """Project `image_embeds` to (B, extra_tokens, cross_attention_dim)."""
        # Cast the input to the projection weights' dtype (mixed precision).
        proj_dtype = next(self.proj.parameters()).dtype
        tokens = self.proj(image_embeds.type(proj_dtype))
        tokens = tokens.reshape(-1, self.clip_extra_context_tokens,
                                self.cross_attention_dim)
        return self.norm(tokens)
|
||||
|
||||
|
||||
# FFN
|
||||
def FeedForward(dim, mult=4):
    """Pre-norm two-layer MLP with GELU; preserves the feature dim `dim`."""
    hidden = int(dim * mult)
    layers = [
        nn.LayerNorm(dim),
        nn.Linear(dim, hidden, bias=False),
        nn.GELU(),
        nn.Linear(hidden, dim, bias=False),
    ]
    return nn.Sequential(*layers)
|
||||
|
||||
|
||||
def reshape_tensor(x, heads):
    """Split the channel dim into attention heads.

    (bs, length, width) -> (bs, heads, length, width // heads)
    """
    bs, length, width = x.shape
    # (bs, length, width) -> (bs, length, heads, dim_per_head)
    out = x.view(bs, length, heads, width // heads)
    # -> (bs, heads, length, dim_per_head)
    out = out.permute(0, 2, 1, 3)
    return out.reshape(bs, heads, length, width // heads)
|
||||
|
||||
|
||||
class PerceiverAttention(nn.Module):
    """Perceiver-style cross-attention: latent queries attend over the
    concatenation of input features and the latents themselves."""

    def __init__(self, *, dim, dim_head=64, heads=8):
        super().__init__()
        self.scale = dim_head**-0.5
        self.dim_head = dim_head
        self.heads = heads
        inner_dim = dim_head * heads

        # Separate pre-norms for the context (x) and the latents.
        self.norm1 = nn.LayerNorm(dim)
        self.norm2 = nn.LayerNorm(dim)

        self.to_q = nn.Linear(dim, inner_dim, bias=False)
        self.to_kv = nn.Linear(dim, inner_dim * 2, bias=False)
        self.to_out = nn.Linear(inner_dim, dim, bias=False)

    def forward(self, x, latents):
        """
        Args:
            x (torch.Tensor): image features
                shape (b, n1, D)
            latent (torch.Tensor): latent features
                shape (b, n2, D)
        """
        x = self.norm1(x)
        latents = self.norm2(latents)

        b, l, _ = latents.shape

        # Queries come from the latents; keys/values from x AND the latents.
        q = self.to_q(latents)
        kv_input = torch.cat((x, latents), dim=-2)
        k, v = self.to_kv(kv_input).chunk(2, dim=-1)

        q = reshape_tensor(q, self.heads)
        k = reshape_tensor(k, self.heads)
        v = reshape_tensor(v, self.heads)

        # attention
        # Split the 1/sqrt(d) factor across q and k for fp16 stability.
        scale = 1 / math.sqrt(math.sqrt(self.dim_head))
        weight = (q * scale) @ (k * scale).transpose(
            -2, -1)  # More stable with f16 than dividing afterwards
        weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)
        out = weight @ v

        # Merge heads back: (b, heads, l, d_head) -> (b, l, heads*d_head).
        out = out.permute(0, 2, 1, 3).reshape(b, l, -1)

        return self.to_out(out)
|
||||
|
||||
|
||||
class Resampler(nn.Module):
    """Perceiver resampler: compresses input tokens of size `embedding_dim`
    into a fixed set of learned query tokens of size `output_dim`."""

    def __init__(
        self,
        dim=1024,
        depth=8,
        dim_head=64,
        heads=16,
        num_queries=8,
        embedding_dim=768,
        output_dim=1024,
        ff_mult=4,
        video_length=None,  # using frame-wise version or not
    ):
        super().__init__()
        # Queries for a single frame / image.
        self.num_queries = num_queries
        self.video_length = video_length

        # Frame-wise mode allocates <num_queries> latents per frame.
        total_queries = num_queries if video_length is None else num_queries * video_length

        self.latents = nn.Parameter(
            torch.randn(1, total_queries, dim) / dim**0.5)
        self.proj_in = nn.Linear(embedding_dim, dim)
        self.proj_out = nn.Linear(dim, output_dim)
        self.norm_out = nn.LayerNorm(output_dim)

        self.layers = nn.ModuleList([
            nn.ModuleList([
                PerceiverAttention(dim=dim, dim_head=dim_head, heads=heads),
                FeedForward(dim=dim, mult=ff_mult),
            ]) for _ in range(depth)
        ])

    def forward(self, x):
        """Resample `x` (B, N, embedding_dim) into (B, queries, output_dim)."""
        latents = self.latents.repeat(x.size(0), 1, 1)
        x = self.proj_in(x)

        for attn, ff in self.layers:
            latents = latents + attn(x, latents)
            latents = latents + ff(latents)

        return self.norm_out(self.proj_out(latents))
|
||||
1005
src/unifolm_wma/modules/networks/ae_modules.py
Normal file
1005
src/unifolm_wma/modules/networks/ae_modules.py
Normal file
File diff suppressed because it is too large
Load Diff
848
src/unifolm_wma/modules/networks/wma_model.py
Normal file
848
src/unifolm_wma/modules/networks/wma_model.py
Normal file
@@ -0,0 +1,848 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from torch import Tensor
|
||||
from functools import partial
|
||||
from abc import abstractmethod
|
||||
from einops import rearrange
|
||||
from omegaconf import OmegaConf
|
||||
from typing import Optional, Sequence, Any, Tuple, Union, List, Dict
|
||||
from collections.abc import Mapping, Iterable, Callable
|
||||
|
||||
from unifolm_wma.utils.diffusion import timestep_embedding
|
||||
from unifolm_wma.utils.common import checkpoint
|
||||
from unifolm_wma.utils.basics import (zero_module, conv_nd, linear,
|
||||
avg_pool_nd, normalization)
|
||||
from unifolm_wma.modules.attention import SpatialTransformer, TemporalTransformer
|
||||
from unifolm_wma.utils.utils import instantiate_from_config
|
||||
|
||||
|
||||
class TimestepBlock(nn.Module):
    """
    Any module where forward() takes timestep embeddings as a second argument.

    Used by TimestepEmbedSequential to decide which children receive `emb`.
    """

    @abstractmethod
    def forward(self, x, emb):
        """
        Apply the module to `x` given `emb` timestep embeddings.
        """
|
||||
|
||||
|
||||
class TimestepEmbedSequential(nn.Sequential, TimestepBlock):
    """
    A sequential module that passes timestep embeddings to the children that
    support it as an extra input.
    """

    def forward(self, x, emb, context=None, batch_size=None):
        """Run children in order, routing extra inputs by layer type.

        Args:
            x: Feature map with frames folded into the batch, (b*f, c, h, w)
                (implied by the rearrange pattern below).
            emb: Timestep embeddings, forwarded to TimestepBlock children.
            context: Conditioning tensor for spatial/temporal transformers.
            batch_size: True batch size b, required to unfold frames for
                temporal attention.
        """
        for layer in self:
            if isinstance(layer, TimestepBlock):
                x = layer(x, emb, batch_size=batch_size)
            elif isinstance(layer, SpatialTransformer):
                x = layer(x, context)
            elif isinstance(layer, TemporalTransformer):
                # Temporal layers attend across frames: unfold (b f) -> b, f.
                x = rearrange(x, '(b f) c h w -> b c f h w', b=batch_size)
                x = layer(x, context)
                x = rearrange(x, 'b c f h w -> (b f) c h w')
            else:
                x = layer(x)
        return x
|
||||
|
||||
|
||||
class Downsample(nn.Module):
    """
    A downsampling layer with an optional convolution.
    :param channels: channels in the inputs and outputs.
    :param use_conv: a bool determining if a convolution is applied.
    :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
                 downsampling occurs in the inner-two dimensions.
    """

    def __init__(self,
                 channels,
                 use_conv,
                 dims=2,
                 out_channels=None,
                 padding=1):
        super().__init__()
        self.channels = channels
        self.out_channels = out_channels or channels
        self.use_conv = use_conv
        self.dims = dims
        # 3D inputs keep the leading (temporal) axis at full resolution.
        stride = (1, 2, 2) if dims == 3 else 2
        if not use_conv:
            # Pooling cannot change the channel count.
            assert self.channels == self.out_channels
            self.op = avg_pool_nd(dims, kernel_size=stride, stride=stride)
        else:
            self.op = conv_nd(dims,
                              self.channels,
                              self.out_channels,
                              3,
                              stride=stride,
                              padding=padding)

    def forward(self, x):
        assert x.shape[1] == self.channels
        return self.op(x)
|
||||
|
||||
|
||||
class Upsample(nn.Module):
    """
    An upsampling layer with an optional convolution.
    :param channels: channels in the inputs and outputs.
    :param use_conv: a bool determining if a convolution is applied.
    :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
                 upsampling occurs in the inner-two dimensions.
    """

    def __init__(self,
                 channels,
                 use_conv,
                 dims=2,
                 out_channels=None,
                 padding=1):
        super().__init__()
        self.channels = channels
        self.out_channels = out_channels or channels
        self.use_conv = use_conv
        self.dims = dims
        if use_conv:
            self.conv = conv_nd(dims,
                                self.channels,
                                self.out_channels,
                                3,
                                padding=padding)

    def forward(self, x):
        assert x.shape[1] == self.channels
        if self.dims == 3:
            # Upsample only the spatial axes; the temporal length is fixed.
            target_size = (x.shape[2], x.shape[3] * 2, x.shape[4] * 2)
            x = F.interpolate(x, target_size, mode='nearest')
        else:
            x = F.interpolate(x, scale_factor=2, mode='nearest')
        return self.conv(x) if self.use_conv else x
|
||||
|
||||
|
||||
class ResBlock(TimestepBlock):
    """
    A residual block that can optionally change the number of channels.
    :param channels: the number of input channels.
    :param emb_channels: the number of timestep embedding channels.
    :param dropout: the rate of dropout.
    :param out_channels: if specified, the number of out channels.
    :param use_conv: if True and out_channels is specified, use a spatial
        convolution instead of a smaller 1x1 convolution to change the
        channels in the skip connection.
    :param dims: determines if the signal is 1D, 2D, or 3D.
    :param up: if True, use this block for upsampling.
    :param down: if True, use this block for downsampling.
    :param use_temporal_conv: if True, use the temporal convolution.
    :param use_image_dataset: if True, the temporal parameters will not be optimized.
    """

    def __init__(self,
                 channels,
                 emb_channels,
                 dropout,
                 out_channels=None,
                 use_scale_shift_norm=False,
                 dims=2,
                 use_checkpoint=False,
                 use_conv=False,
                 up=False,
                 down=False,
                 use_temporal_conv=False,
                 tempspatial_aware=False):
        super().__init__()
        self.channels = channels
        self.emb_channels = emb_channels
        self.dropout = dropout
        self.out_channels = out_channels or channels
        self.use_conv = use_conv
        self.use_checkpoint = use_checkpoint
        self.use_scale_shift_norm = use_scale_shift_norm
        self.use_temporal_conv = use_temporal_conv

        self.in_layers = nn.Sequential(
            normalization(channels),
            nn.SiLU(),
            conv_nd(dims, channels, self.out_channels, 3, padding=1),
        )

        self.updown = up or down

        # When up/down-sampling, resampling is applied both to the residual
        # branch (h_upd) and the skip input (x_upd) so shapes stay aligned.
        if up:
            self.h_upd = Upsample(channels, False, dims)
            self.x_upd = Upsample(channels, False, dims)
        elif down:
            self.h_upd = Downsample(channels, False, dims)
            self.x_upd = Downsample(channels, False, dims)
        else:
            self.h_upd = self.x_upd = nn.Identity()

        # Projects the timestep embedding; doubled width when using
        # scale-shift (FiLM-style) conditioning.
        self.emb_layers = nn.Sequential(
            nn.SiLU(),
            nn.Linear(
                emb_channels,
                2 * self.out_channels
                if use_scale_shift_norm else self.out_channels,
            ),
        )
        self.out_layers = nn.Sequential(
            normalization(self.out_channels),
            nn.SiLU(),
            nn.Dropout(p=dropout),
            # Zero-initialized so the block starts close to the identity.
            zero_module(
                nn.Conv2d(self.out_channels, self.out_channels, 3, padding=1)),
        )

        if self.out_channels == channels:
            self.skip_connection = nn.Identity()
        elif use_conv:
            self.skip_connection = conv_nd(dims,
                                           channels,
                                           self.out_channels,
                                           3,
                                           padding=1)
        else:
            self.skip_connection = conv_nd(dims, channels, self.out_channels,
                                           1)

        if self.use_temporal_conv:
            # NOTE: attribute name `temopral_conv` (sic) is kept — it is a
            # state-dict key that existing checkpoints depend on.
            self.temopral_conv = TemporalConvBlock(
                self.out_channels,
                self.out_channels,
                dropout=0.1,
                spatial_aware=tempspatial_aware)

    def forward(self, x, emb, batch_size=None):
        """
        Apply the block to a Tensor, conditioned on a timestep embedding.
        :param x: an [N x C x ...] Tensor of features.
        :param emb: an [N x emb_channels] Tensor of timestep embeddings.
        :return: an [N x C x ...] Tensor of outputs.
        """
        input_tuple = (x, emb)
        # batch_size cannot be passed through `checkpoint`'s positional
        # inputs, so it is bound via partial instead.
        if batch_size:
            forward_batchsize = partial(self._forward, batch_size=batch_size)
            return checkpoint(forward_batchsize, input_tuple,
                              self.parameters(), self.use_checkpoint)
        return checkpoint(self._forward, input_tuple, self.parameters(),
                          self.use_checkpoint)

    def _forward(self, x, emb, batch_size=None):
        if self.updown:
            # Resample between the norm/act and the conv of in_layers.
            in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1]
            h = in_rest(x)
            h = self.h_upd(h)
            x = self.x_upd(x)
            h = in_conv(h)
        else:
            h = self.in_layers(x)
        emb_out = self.emb_layers(emb).type(h.dtype)
        # Broadcast the embedding over all spatial dims.
        while len(emb_out.shape) < len(h.shape):
            emb_out = emb_out[..., None]
        if self.use_scale_shift_norm:
            # FiLM-style conditioning: norm, then scale/shift from emb.
            out_norm, out_rest = self.out_layers[0], self.out_layers[1:]
            scale, shift = torch.chunk(emb_out, 2, dim=1)
            h = out_norm(h) * (1 + scale) + shift
            h = out_rest(h)
        else:
            h = h + emb_out
            h = self.out_layers(h)
        h = self.skip_connection(x) + h

        if self.use_temporal_conv and batch_size:
            # Unfold frames from the batch for the temporal convolution.
            h = rearrange(h, '(b t) c h w -> b c t h w', b=batch_size)
            h = self.temopral_conv(h)
            h = rearrange(h, 'b c t h w -> (b t) c h w')
        return h
|
||||
|
||||
|
||||
class TemporalConvBlock(nn.Module):
    """
    Adapted from modelscope: https://github.com/modelscope/modelscope/blob/master/modelscope/models/multi_modal/video_synthesis/unet_sd.py
    """

    def __init__(self,
                 in_channels,
                 out_channels=None,
                 dropout=0.0,
                 spatial_aware=False):
        super(TemporalConvBlock, self).__init__()
        out_channels = in_channels if out_channels is None else out_channels
        self.in_channels = in_channels
        self.out_channels = out_channels

        # 3D kernel/padding shapes laid out as (t, h, w).
        if spatial_aware:
            th_kernel_shape, th_padding_shape = (3, 3, 1), (1, 1, 0)
            tw_kernel_shape, tw_padding_shape = (3, 1, 3), (1, 0, 1)
        else:
            th_kernel_shape = tw_kernel_shape = (3, 1, 1)
            th_padding_shape = tw_padding_shape = (1, 0, 0)

        def _conv_block(cin, cout, kernel, pad, with_dropout):
            # GroupNorm -> SiLU -> [Dropout] -> Conv3d, matching upstream.
            modules = [nn.GroupNorm(32, cin), nn.SiLU()]
            if with_dropout:
                modules.append(nn.Dropout(dropout))
            modules.append(nn.Conv3d(cin, cout, kernel, padding=pad))
            return nn.Sequential(*modules)

        # NOTE(review): conv2-4 project back to `in_channels`, so the chain
        # only composes when in_channels == out_channels — kept as upstream;
        # confirm callers never pass differing channel counts.
        self.conv1 = _conv_block(in_channels, out_channels, th_kernel_shape,
                                 th_padding_shape, False)
        self.conv2 = _conv_block(out_channels, in_channels, tw_kernel_shape,
                                 tw_padding_shape, True)
        self.conv3 = _conv_block(out_channels, in_channels, th_kernel_shape,
                                 th_padding_shape, True)
        self.conv4 = _conv_block(out_channels, in_channels, tw_kernel_shape,
                                 tw_padding_shape, True)

        # Zero out the last conv so the whole block starts as the identity.
        nn.init.zeros_(self.conv4[-1].weight)
        nn.init.zeros_(self.conv4[-1].bias)

    def forward(self, x):
        """Residual temporal convolution over (b, c, t, h, w) input."""
        residual = x
        for conv in (self.conv1, self.conv2, self.conv3, self.conv4):
            x = conv(x)
        return residual + x
|
||||
|
||||
|
||||
class WMAModel(nn.Module):
    """
    The full World-Model-Action model.

    A video-diffusion UNet backbone whose multi-scale features (collected in
    `hs_a` during `forward`) additionally drive two auxiliary UNet heads that
    predict agent actions and states.
    """

    def __init__(self,
                 in_channels: int,
                 model_channels: int,
                 out_channels: int,
                 num_res_blocks: int,
                 attention_resolutions: Sequence[int],
                 dropout: float = 0.0,
                 channel_mult: Sequence[int] = (1, 2, 4, 8),
                 conv_resample: bool = True,
                 dims: int = 2,
                 context_dim: int | None = None,
                 use_scale_shift_norm: bool = False,
                 resblock_updown: bool = False,
                 num_heads: int = -1,
                 num_head_channels: int = -1,
                 transformer_depth: int = 1,
                 use_linear: bool = False,
                 use_checkpoint: bool = False,
                 temporal_conv: bool = False,
                 tempspatial_aware: bool = False,
                 temporal_attention: bool = True,
                 use_relative_position: bool = True,
                 use_causal_attention: bool = False,
                 temporal_length: int | None = None,
                 use_fp16: bool = False,
                 addition_attention: bool = False,
                 temporal_selfatt_only: bool = True,
                 image_cross_attention: bool = False,
                 cross_attention_scale_learnable: bool = False,
                 default_fs: int = 4,
                 fs_condition: bool = False,
                 n_obs_steps: int = 1,
                 num_stem_token: int = 1,
                 unet_head_config: OmegaConf | None = None,
                 stem_process_config: OmegaConf | None = None,
                 base_model_gen_only: bool = False):
        """
        Initialize the World-Model-Action network.

        Args:
            in_channels: Number of input channels to the backbone.
            model_channels: Base channel width for the UNet/backbone.
            out_channels: Number of output channels.
            num_res_blocks: Number of residual blocks per resolution stage.
            attention_resolutions: Resolutions at which to enable attention.
            dropout: Dropout probability used inside residual/attention blocks.
            channel_mult: Multipliers for channels at each resolution level.
            conv_resample: If True, use convolutional resampling for up/down sampling.
            dims: Spatial dimensionality of the backbone (1/2/3).
            context_dim: Optional context embedding dimension (for cross-attention).
            use_scale_shift_norm: Enable scale-shift (FiLM-style) normalization in blocks.
            resblock_updown: Use residual blocks for up/down sampling (instead of plain conv).
            num_heads: Number of attention heads (if >= 0). If -1, derive from num_head_channels.
            num_head_channels: Channels per attention head (if >= 0). If -1, derive from num_heads.
            transformer_depth: Number of transformer/attention blocks per stage.
            use_linear: Use linear attention variants where applicable.
            use_checkpoint: Enable gradient checkpointing in blocks to save memory.
            temporal_conv: Include temporal convolution along the time dimension.
            tempspatial_aware: If True, use time-space aware blocks.
            temporal_attention: Enable temporal self-attention.
            use_relative_position: Use relative position encodings in attention.
            use_causal_attention: Use causal (uni-directional) attention along time.
            temporal_length: Optional maximum temporal length expected by the model.
            use_fp16: Enable half-precision layers/normalization where supported.
            addition_attention: Add auxiliary attention modules.
            temporal_selfatt_only: Restrict attention to temporal-only (no spatial) if True.
            image_cross_attention: Enable cross-attention with image embeddings.
            cross_attention_scale_learnable: Make cross-attention scaling a learnable parameter.
            default_fs: Default frame-stride / fps.
            fs_condition: If True, condition on frame-stride/fps features.
            n_obs_steps: Number of observed steps used in conditioning heads.
            num_stem_token: Number of stem tokens for action tokenization.
            unet_head_config: OmegaConf for UNet heads (e.g., action/state heads).
            stem_process_config: OmegaConf for stem/preprocessor module.
            base_model_gen_only: Perform the generation using the base model without action and state outputs.
        """
        super(WMAModel, self).__init__()
        # Exactly one of num_heads / num_head_channels may be left at -1.
        if num_heads == -1:
            assert num_head_channels != -1, 'Either num_heads or num_head_channels has to be set'
        if num_head_channels == -1:
            assert num_heads != -1, 'Either num_heads or num_head_channels has to be set'

        self.in_channels = in_channels
        self.model_channels = model_channels
        self.out_channels = out_channels
        self.num_res_blocks = num_res_blocks
        self.attention_resolutions = attention_resolutions
        self.dropout = dropout
        self.channel_mult = channel_mult
        self.conv_resample = conv_resample
        self.temporal_attention = temporal_attention
        time_embed_dim = model_channels * 4
        self.use_checkpoint = use_checkpoint
        self.dtype = torch.float16 if use_fp16 else torch.float32
        # NOTE(review): this local shadows the `temporal_selfatt_only` parameter
        # for all stages below the init_attn block — the parameter only affects
        # `init_attn`; confirm that is intentional.
        temporal_self_att_only = True
        self.addition_attention = addition_attention
        self.temporal_length = temporal_length
        self.image_cross_attention = image_cross_attention
        self.cross_attention_scale_learnable = cross_attention_scale_learnable
        self.default_fs = default_fs
        self.fs_condition = fs_condition
        self.n_obs_steps = n_obs_steps
        self.num_stem_token = num_stem_token
        self.base_model_gen_only = base_model_gen_only

        # Time embedding blocks: sinusoidal embedding -> 2-layer MLP.
        self.time_embed = nn.Sequential(
            linear(model_channels, time_embed_dim),
            nn.SiLU(),
            linear(time_embed_dim, time_embed_dim),
        )
        if fs_condition:
            # Frame-stride embedding MLP; zero-initialized last layer so fps
            # conditioning starts as a no-op.
            self.fps_embedding = nn.Sequential(
                linear(model_channels, time_embed_dim),
                nn.SiLU(),
                linear(time_embed_dim, time_embed_dim),
            )
            nn.init.zeros_(self.fps_embedding[-1].weight)
            nn.init.zeros_(self.fps_embedding[-1].bias)

        # Input Block: stem conv lifting in_channels -> model_channels.
        self.input_blocks = nn.ModuleList([
            TimestepEmbedSequential(
                conv_nd(dims, in_channels, model_channels, 3, padding=1))
        ])
        if self.addition_attention:
            # Extra temporal attention applied right after the stem conv
            # (n_heads is hard-coded to 8 here, unlike the stages below).
            self.init_attn = TimestepEmbedSequential(
                TemporalTransformer(model_channels,
                                    n_heads=8,
                                    d_head=num_head_channels,
                                    depth=transformer_depth,
                                    context_dim=context_dim,
                                    use_checkpoint=use_checkpoint,
                                    only_self_att=temporal_selfatt_only,
                                    causal_attention=False,
                                    relative_position=use_relative_position,
                                    temporal_length=temporal_length))

        # Encoder: per level, num_res_blocks ResBlocks (+ attention at the
        # configured resolutions), then a downsample except at the last level.
        input_block_chans = [model_channels]
        ch = model_channels
        ds = 1  # current downsampling factor, doubled after each Downsample
        for level, mult in enumerate(channel_mult):
            for _ in range(num_res_blocks):
                layers = [
                    ResBlock(ch,
                             time_embed_dim,
                             dropout,
                             out_channels=mult * model_channels,
                             dims=dims,
                             use_checkpoint=use_checkpoint,
                             use_scale_shift_norm=use_scale_shift_norm,
                             tempspatial_aware=tempspatial_aware,
                             use_temporal_conv=temporal_conv)
                ]
                ch = mult * model_channels
                if ds in attention_resolutions:
                    # Derive heads / head-dim from whichever of the two knobs
                    # was configured (num_heads may be overwritten per stage).
                    if num_head_channels == -1:
                        dim_head = ch // num_heads
                    else:
                        num_heads = ch // num_head_channels
                        dim_head = num_head_channels

                    layers.append(
                        SpatialTransformer(
                            ch,
                            num_heads,
                            dim_head,
                            depth=transformer_depth,
                            context_dim=context_dim,
                            use_linear=use_linear,
                            use_checkpoint=use_checkpoint,
                            disable_self_attn=False,
                            video_length=temporal_length,
                            agent_state_context_len=self.n_obs_steps,
                            agent_action_context_len=self.temporal_length *
                            num_stem_token,
                            image_cross_attention=self.image_cross_attention,
                            cross_attention_scale_learnable=self.
                            cross_attention_scale_learnable,
                        ))
                    if self.temporal_attention:
                        layers.append(
                            TemporalTransformer(
                                ch,
                                num_heads,
                                dim_head,
                                depth=transformer_depth,
                                context_dim=context_dim,
                                use_linear=use_linear,
                                use_checkpoint=use_checkpoint,
                                only_self_att=temporal_self_att_only,
                                causal_attention=use_causal_attention,
                                relative_position=use_relative_position,
                                temporal_length=temporal_length))
                self.input_blocks.append(TimestepEmbedSequential(*layers))
                input_block_chans.append(ch)
            if level != len(channel_mult) - 1:
                out_ch = ch
                self.input_blocks.append(
                    TimestepEmbedSequential(
                        ResBlock(ch,
                                 time_embed_dim,
                                 dropout,
                                 out_channels=out_ch,
                                 dims=dims,
                                 use_checkpoint=use_checkpoint,
                                 use_scale_shift_norm=use_scale_shift_norm,
                                 down=True)
                        if resblock_updown else Downsample(
                            ch, conv_resample, dims=dims, out_channels=out_ch))
                )
                ch = out_ch
                input_block_chans.append(ch)
                ds *= 2

        # Middle block: ResBlock -> spatial (+ optional temporal) attention -> ResBlock.
        if num_head_channels == -1:
            dim_head = ch // num_heads
        else:
            num_heads = ch // num_head_channels
            dim_head = num_head_channels
        layers = [
            ResBlock(ch,
                     time_embed_dim,
                     dropout,
                     dims=dims,
                     use_checkpoint=use_checkpoint,
                     use_scale_shift_norm=use_scale_shift_norm,
                     tempspatial_aware=tempspatial_aware,
                     use_temporal_conv=temporal_conv),
            SpatialTransformer(
                ch,
                num_heads,
                dim_head,
                depth=transformer_depth,
                context_dim=context_dim,
                use_linear=use_linear,
                use_checkpoint=use_checkpoint,
                disable_self_attn=False,
                video_length=temporal_length,
                agent_state_context_len=self.n_obs_steps,
                agent_action_context_len=self.temporal_length * num_stem_token,
                image_cross_attention=self.image_cross_attention,
                cross_attention_scale_learnable=self.
                cross_attention_scale_learnable)
        ]
        if self.temporal_attention:
            layers.append(
                TemporalTransformer(ch,
                                    num_heads,
                                    dim_head,
                                    depth=transformer_depth,
                                    context_dim=context_dim,
                                    use_linear=use_linear,
                                    use_checkpoint=use_checkpoint,
                                    only_self_att=temporal_self_att_only,
                                    causal_attention=use_causal_attention,
                                    relative_position=use_relative_position,
                                    temporal_length=temporal_length))
        layers.append(
            ResBlock(ch,
                     time_embed_dim,
                     dropout,
                     dims=dims,
                     use_checkpoint=use_checkpoint,
                     use_scale_shift_norm=use_scale_shift_norm,
                     tempspatial_aware=tempspatial_aware,
                     use_temporal_conv=temporal_conv))

        # Middle Block
        self.middle_block = TimestepEmbedSequential(*layers)

        # Decoder: mirrors the encoder, consuming skip channels popped from
        # input_block_chans; upsample after the last block of each non-zero level.
        self.output_blocks = nn.ModuleList([])
        for level, mult in list(enumerate(channel_mult))[::-1]:
            for i in range(num_res_blocks + 1):
                ich = input_block_chans.pop()
                layers = [
                    ResBlock(ch + ich,
                             time_embed_dim,
                             dropout,
                             out_channels=mult * model_channels,
                             dims=dims,
                             use_checkpoint=use_checkpoint,
                             use_scale_shift_norm=use_scale_shift_norm,
                             tempspatial_aware=tempspatial_aware,
                             use_temporal_conv=temporal_conv)
                ]
                ch = model_channels * mult
                if ds in attention_resolutions:
                    if num_head_channels == -1:
                        dim_head = ch // num_heads
                    else:
                        num_heads = ch // num_head_channels
                        dim_head = num_head_channels
                    # NOTE(review): unlike the encoder/middle SpatialTransformers,
                    # this one passes no agent_action_context_len — confirm the
                    # asymmetry is intended.
                    layers.append(
                        SpatialTransformer(
                            ch,
                            num_heads,
                            dim_head,
                            depth=transformer_depth,
                            context_dim=context_dim,
                            use_linear=use_linear,
                            use_checkpoint=use_checkpoint,
                            disable_self_attn=False,
                            video_length=temporal_length,
                            agent_state_context_len=self.n_obs_steps,
                            image_cross_attention=self.image_cross_attention,
                            cross_attention_scale_learnable=self.
                            cross_attention_scale_learnable))
                    if self.temporal_attention:
                        layers.append(
                            TemporalTransformer(
                                ch,
                                num_heads,
                                dim_head,
                                depth=transformer_depth,
                                context_dim=context_dim,
                                use_linear=use_linear,
                                use_checkpoint=use_checkpoint,
                                only_self_att=temporal_self_att_only,
                                causal_attention=use_causal_attention,
                                relative_position=use_relative_position,
                                temporal_length=temporal_length))
                if level and i == num_res_blocks:
                    out_ch = ch
                    layers.append(
                        ResBlock(ch,
                                 time_embed_dim,
                                 dropout,
                                 out_channels=out_ch,
                                 dims=dims,
                                 use_checkpoint=use_checkpoint,
                                 use_scale_shift_norm=use_scale_shift_norm,
                                 up=True)
                        if resblock_updown else Upsample(
                            ch, conv_resample, dims=dims, out_channels=out_ch))
                    ds //= 2
                self.output_blocks.append(TimestepEmbedSequential(*layers))

        # Final projection back to out_channels; zero-initialized conv.
        self.out = nn.Sequential(
            normalization(ch),
            nn.SiLU(),
            zero_module(
                conv_nd(dims, model_channels, out_channels, 3, padding=1)),
        )

        # Action and state prediction unet heads, fed with the per-level
        # backbone features (context_dims is injected into the shared config,
        # mutating the caller-supplied OmegaConf in place).
        unet_head_config['params']['context_dims'] = [
            mult * model_channels for mult in channel_mult
        ]
        self.action_unet = instantiate_from_config(unet_head_config)
        self.state_unet = instantiate_from_config(unet_head_config)

        # Initialize action token_projector
        self.action_token_projector = instantiate_from_config(
            stem_process_config)

    def forward(self,
                x: Tensor,
                x_action: Tensor,
                x_state: Tensor,
                timesteps: Tensor,
                context: Tensor | None = None,
                context_action: Tensor | None = None,
                features_adapter: Any = None,
                fs: Tensor | None = None,
                **kwargs) -> Tensor | tuple[Tensor, ...]:
        """
        Forward pass of the World-Model-Action backbone.

        Args:
            x: Input latent video, shape (B, C, T, H, W) — unpacked below.
            x_action: Action stream input, shape (B_a, *, *) (3-D tensor).
            x_state: State stream input.
            timesteps: Diffusion timesteps, shape (B,) or scalar Tensor.
            context: Conditioning context for cross-attention; its length along
                dim 1 selects how it is split into state/action/text/image parts.
            context_action: Conditioning context specific to action/state heads.
            features_adapter: Optional per-stage features added after every
                third input block.
            fs: Frame-stride / fps conditioning (defaults to `self.default_fs`).

        Returns:
            Tuple (y, a_y, s_y): denoised video latent, action prediction and
            state prediction (the latter two are zeros when
            `base_model_gen_only` is set).
        """
        b, _, t, _, _ = x.shape
        t_emb = timestep_embedding(timesteps,
                                   self.model_channels,
                                   repeat_only=False).type(x.dtype)
        emb = self.time_embed(t_emb)

        # Split the packed context by its length: the hard-coded 77 is the text
        # token count and 16 the per-frame image/action token count.
        bt, l_context, _ = context.shape
        if self.base_model_gen_only:
            assert l_context == 77 + self.n_obs_steps * 16, ">>> ERROR Context dim 1 ..."  ## NOTE HANDCODE
        else:
            if l_context == self.n_obs_steps + 77 + t * 16:
                # Layout: [agent_state | text(77) | per-frame image tokens].
                context_agent_state = context[:, :self.n_obs_steps]
                context_text = context[:, self.n_obs_steps:self.n_obs_steps +
                                       77, :]
                context_img = context[:, self.n_obs_steps + 77:, :]
                # Broadcast batch-level parts to every frame: (B, l, c) -> (B*T, l, c).
                context_agent_state = context_agent_state.repeat_interleave(
                    repeats=t, dim=0)
                context_text = context_text.repeat_interleave(repeats=t, dim=0)
                context_img = rearrange(context_img,
                                        'b (t l) c -> (b t) l c',
                                        t=t)
                context = torch.cat(
                    [context_agent_state, context_text, context_img], dim=1)
            elif l_context == self.n_obs_steps + 16 + 77 + t * 16:
                # Layout: [agent_state | agent_action(16) | text(77) | image tokens].
                context_agent_state = context[:, :self.n_obs_steps]
                context_agent_action = context[:, self.
                                               n_obs_steps:self.n_obs_steps +
                                               16, :]
                # Project action tokens, then reshuffle so each frame sees all
                # per-frame action tokens: (B, 16, d) -> (B*T, T*num_stem_token, d).
                context_agent_action = rearrange(
                    context_agent_action.unsqueeze(2), 'b t l d -> (b t) l d')
                context_agent_action = self.action_token_projector(
                    context_agent_action)
                context_agent_action = rearrange(context_agent_action,
                                                 '(b o) l d -> b o l d',
                                                 o=t)
                context_agent_action = rearrange(context_agent_action,
                                                 'b o (t l) d -> b o t l d',
                                                 t=t)
                context_agent_action = context_agent_action.permute(
                    0, 2, 1, 3, 4)
                context_agent_action = rearrange(context_agent_action,
                                                 'b t o l d -> (b t) (o l) d')

                context_text = context[:, self.n_obs_steps +
                                       16:self.n_obs_steps + 16 + 77, :]
                context_text = context_text.repeat_interleave(repeats=t, dim=0)

                context_img = context[:, self.n_obs_steps + 16 + 77:, :]
                context_img = rearrange(context_img,
                                        'b (t l) c -> (b t) l c',
                                        t=t)
                context_agent_state = context_agent_state.repeat_interleave(
                    repeats=t, dim=0)
                context = torch.cat([
                    context_agent_state, context_agent_action, context_text,
                    context_img
                ],
                                    dim=1)
            # NOTE(review): any other l_context silently leaves `context` unsplit.

        # Flatten time into batch so 2D blocks process each frame independently.
        emb = emb.repeat_interleave(repeats=t, dim=0)
        x = rearrange(x, 'b c t h w -> (b t) c h w')

        # Combine emb with the optional fps embedding.
        if self.fs_condition:
            if fs is None:
                fs = torch.tensor([self.default_fs] * b,
                                  dtype=torch.long,
                                  device=x.device)
            fs_emb = timestep_embedding(fs,
                                        self.model_channels,
                                        repeat_only=False).type(x.dtype)
            fs_embed = self.fps_embedding(fs_emb)
            fs_embed = fs_embed.repeat_interleave(repeats=t, dim=0)
            emb = emb + fs_embed

        h = x.type(self.dtype)
        adapter_idx = 0
        hs = []    # skip connections for the decoder
        hs_a = []  # multi-scale feature taps for the action/state heads
        for id, module in enumerate(self.input_blocks):
            h = module(h, emb, context=context, batch_size=b)
            if id == 0 and self.addition_attention:
                h = self.init_attn(h, emb, context=context, batch_size=b)
            # plug-in adapter features (after every third block)
            if ((id + 1) % 3 == 0) and features_adapter is not None:
                h = h + features_adapter[adapter_idx]
                adapter_idx += 1
            if id != 0:
                # Tap the feature just before each resolution change.
                if isinstance(module[0], Downsample):
                    hs_a.append(
                        rearrange(hs[-1], '(b t) c h w -> b t c h w', t=t))
            hs.append(h)
        hs_a.append(rearrange(h, '(b t) c h w -> b t c h w', t=t))

        if features_adapter is not None:
            assert len(
                features_adapter) == adapter_idx, 'Wrong features_adapter'
        h = self.middle_block(h, emb, context=context, batch_size=b)
        hs_a.append(rearrange(h, '(b t) c h w -> b t c h w', t=t))

        hs_out = []
        for module in self.output_blocks:
            h = torch.cat([h, hs.pop()], dim=1)
            h = module(h, emb, context=context, batch_size=b)
            if isinstance(module[-1], Upsample):
                hs_a.append(
                    rearrange(hs_out[-1], '(b t) c h w -> b t c h w', t=t))
            hs_out.append(h)
        h = h.type(x.dtype)
        hs_a.append(rearrange(hs_out[-1], '(b t) c h w -> b t c h w', t=t))

        y = self.out(h)
        y = rearrange(y, '(b t) c h w -> b c t h w', b=b)

        if not self.base_model_gen_only:
            ba, _, _ = x_action.shape
            # NOTE(review): `timesteps[:ba]` and `context_action[:2]` assume
            # specific batch layouts from the training loop — confirm against
            # the caller before changing.
            a_y = self.action_unet(x_action, timesteps[:ba], hs_a,
                                   context_action[:2], **kwargs)
            # Predict state
            if b > 1:
                s_y = self.state_unet(x_state, timesteps[:ba], hs_a,
                                      context_action[:2], **kwargs)
            else:
                s_y = self.state_unet(x_state, timesteps, hs_a,
                                      context_action[:2], **kwargs)
        else:
            # Generation-only mode: action/state outputs are placeholders.
            a_y = torch.zeros_like(x_action)
            s_y = torch.zeros_like(x_state)

        return y, a_y, s_y
|
||||
0
src/unifolm_wma/modules/vision/__init__.py
Normal file
0
src/unifolm_wma/modules/vision/__init__.py
Normal file
244
src/unifolm_wma/modules/vision/base_vision.py
Normal file
244
src/unifolm_wma/modules/vision/base_vision.py
Normal file
@@ -0,0 +1,244 @@
|
||||
"""
|
||||
base_vision.py
|
||||
|
||||
Abstract class definition of a Vision Backbone (Visual Featurizer), with full annotations of class methods, utility
|
||||
functions, and initialization logic.
|
||||
|
||||
We also define the generic TimmViTBackbone class here, providing a default interface for loading any TIMM Vision
|
||||
Transformer model for feature extraction.
|
||||
"""
|
||||
|
||||
import timm
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torchvision.transforms.functional as TVF
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from dataclasses import dataclass
|
||||
from functools import partial
|
||||
from typing import Any, Callable, Dict, Optional, Protocol, Tuple, Union
|
||||
|
||||
from PIL.Image import Image
|
||||
from timm.models.vision_transformer import Block, VisionTransformer
|
||||
from torch.distributed.fsdp.wrap import _module_wrap_policy, _or_policy, transformer_auto_wrap_policy
|
||||
from torchvision.transforms import Compose, Resize
|
||||
|
||||
|
||||
# === Utility Functions for Monkey-Patching ===
|
||||
def unpack_tuple(fn: Callable[[Any], Tuple[Any]]) -> Callable[[Any], Any]:
    """Wrap *fn* so that a tuple return value is collapsed to its first element.

    Non-tuple results pass through unchanged; positional and keyword arguments
    are forwarded as-is.
    """

    def wrapper(*args: Any, **kwargs: Any) -> Any:
        out = fn(*args, **kwargs)
        if isinstance(out, tuple):
            return out[0]
        return out

    return wrapper
|
||||
|
||||
|
||||
# === Interface for an Image Transform ===
|
||||
class ImageTransform(Protocol):
    """Structural (duck-typed) interface for an image preprocessing transform.

    Any callable taking a PIL image (plus string keyword arguments) and
    returning either a single tensor or a dict of named tensors satisfies it.
    """

    def __call__(
            self, img: Image,
            **kwargs: str) -> Union[torch.Tensor, Dict[str, torch.Tensor]]:
        ...
|
||||
|
||||
|
||||
# === Custom Torchvision Image Transforms ===
|
||||
@dataclass
class LetterboxPad:
    """Pad an image to a square using a constant fill color (RGB triple)."""

    padding_fill_value: Tuple[int, int, int]

    def __call__(self, image: Image) -> Image:
        """Given a PIL.Image, pad to square by adding a symmetric border around the height/width."""
        width, height = image.size
        target = max(image.size)
        # Symmetric border sizes (floor division matches int(x / 2) for
        # non-negative differences).
        pad_w = (target - width) // 2
        pad_h = (target - height) // 2
        return TVF.pad(image, (pad_w, pad_h, pad_w, pad_h),
                       fill=self.padding_fill_value,
                       padding_mode="constant")
|
||||
|
||||
|
||||
# === Abstract Base Class for arbitrary Vision Backbones ===
|
||||
class VisionBackbone(nn.Module, ABC):
    """Abstract base class for visual featurizers.

    Subclasses must populate `self.featurizer` (the feature-extraction module)
    and `self.image_transform` (the matching preprocessing transform), and
    implement the abstract hooks below.
    """

    def __init__(self,
                 vision_backbone_id: str,
                 image_resize_strategy: str,
                 default_image_size: int = 224) -> None:
        """Store backbone identity and preprocessing configuration.

        Args:
            vision_backbone_id: Registry identifier of the concrete backbone.
            image_resize_strategy: How images are resized before featurization
                (interpreted by subclasses, e.g. "resize-naive"/"letterbox").
            default_image_size: Square input edge length in pixels.
        """
        super().__init__()
        self.identifier: str = vision_backbone_id
        self.image_resize_strategy: str = image_resize_strategy
        self.default_image_size: int = default_image_size

        # Instance attributes for a Vision Backbone — filled in by subclasses.
        self.featurizer: nn.Module = None
        self.image_transform: ImageTransform = None

    def get_image_transform(self) -> ImageTransform:
        """Return the preprocessing transform paired with this backbone."""
        return self.image_transform

    @abstractmethod
    def get_fsdp_wrapping_policy(self) -> Callable:
        # Policy used to shard this backbone under FSDP.
        ...

    @abstractmethod
    def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
        """Run a forward pass through the featurizer given a set of processed images, returning patch/grid features."""
        raise NotImplementedError

    @property
    @abstractmethod
    def default_image_resolution(self) -> Tuple[int, int, int]:
        # Expected (channels, height, width) of preprocessed inputs.
        ...

    @property
    @abstractmethod
    def embed_dim(self) -> int:
        # Dimensionality of each output patch feature.
        ...

    @property
    @abstractmethod
    def num_patches(self) -> int:
        # Number of patch tokens produced per image.
        ...

    @property
    @abstractmethod
    def half_precision_dtype(self) -> torch.dtype:
        # Preferred reduced-precision dtype for mixed-precision training.
        ...
|
||||
|
||||
|
||||
# === Abstract Base Class for Arbitrary TIMM Vision Transformer Backbones ===
|
||||
class TimmViTBackbone(VisionBackbone, ABC):
    """Generic backbone wrapping any TIMM Vision Transformer for feature extraction.

    Loads the ViT from the TIMM hub, patches its `forward` to return
    second-to-last-layer patch features, and builds an image transform
    according to `image_resize_strategy`.
    """

    def __init__(
        self,
        vision_backbone_id: str,
        timm_path_or_url: str,
        image_resize_strategy: str,
        default_image_size: int = 224,
        override_act_layer: Optional[str] = None,
    ) -> None:
        """
        Args:
            vision_backbone_id: Registry identifier for this backbone.
            timm_path_or_url: TIMM model name (or URL) passed to `timm.create_model`.
            image_resize_strategy: One of "resize-naive", "resize-crop", "letterbox".
            default_image_size: Square input edge length; overrides the TIMM default.
            override_act_layer: Optional activation-layer name forwarded to TIMM.
        """
        super().__init__(vision_backbone_id,
                         image_resize_strategy,
                         default_image_size=default_image_size)
        self.timm_path_or_url = timm_path_or_url
        self.override_act_layer = override_act_layer
        self.dtype = torch.bfloat16

        # Initialize Featurizer (ViT) by downloading from HF / TIMM Hub if necessary.
        # num_classes=0 strips the classification head; eval() as we only extract features.
        if self.override_act_layer is None:
            self.featurizer: VisionTransformer = timm.create_model(
                self.timm_path_or_url,
                pretrained=True,
                num_classes=0,
                img_size=self.default_image_size)
        else:
            self.featurizer: VisionTransformer = timm.create_model(
                self.timm_path_or_url,
                pretrained=True,
                num_classes=0,
                img_size=self.default_image_size,
                act_layer=self.override_act_layer,
            )
        self.featurizer.eval()

        # Monkey-Patch the `forward()` function of the featurizer to ensure FSDP-compatibility
        # => Note: By default set `get_intermediate_layers` to return the *SECOND-TO-LAST* layer patches!
        # => TODO (siddk) Remove after resolution of https://github.com/pytorch/pytorch/issues/109385
        self.featurizer.forward = unpack_tuple(
            partial(self.featurizer.get_intermediate_layers,
                    n={len(self.featurizer.blocks) - 2}))

        # Validation =>> for now, this class *only* supports TIMM Vision Transformers (but can be extended!)
        assert isinstance(self.featurizer, VisionTransformer), (
            "Featurizer is not a TIMM VisionTransformer; if you would like to support a new visual representation, "
            "file an issue or implement the requisite logic (see `prismatic/models/backbones/vision/base_vision.py`)!"
        )

        # Get Config =>> Note :: Override default image size to ensure correct image transform
        self.data_cfg = timm.data.resolve_model_data_config(self.featurizer)
        self.data_cfg["input_size"] = (3, self.default_image_size,
                                       self.default_image_size)

        # Initialize Default Image Transform --> Modified by `self.image_resize_strategy`
        default_image_transform = timm.data.create_transform(**self.data_cfg,
                                                             is_training=False)

        # Fix =>> SigLIP & IN1K default transforms resize to *larger* than `self.default_image_size` (crops image)!
        if "siglip" in self.timm_path_or_url or "in1k" in self.timm_path_or_url:
            assert isinstance(default_image_transform,
                              Compose), "Unexpected `default_image_transform`!"
            assert isinstance(default_image_transform.transforms[0], Resize)
            # Rebuild with a Resize targeting default_image_size, keeping the
            # original interpolation and the remaining transforms untouched.
            default_image_transform = Compose([
                Resize(self.default_image_size,
                       interpolation=default_image_transform.transforms[0].
                       interpolation),
                *default_image_transform.transforms[1:],
            ])

        # Switch on `image_resize_strategy`
        if self.image_resize_strategy == "resize-naive":
            assert isinstance(default_image_transform,
                              Compose), "Unexpected `default_image_transform`!"
            assert isinstance(default_image_transform.transforms[0], Resize)

            # Force an exact (H, W) resize, ignoring aspect ratio.
            target_size = (self.default_image_size, self.default_image_size)
            self.image_transform = Compose([
                Resize(target_size,
                       interpolation=default_image_transform.transforms[0].
                       interpolation),
                *default_image_transform.transforms[1:],
            ])

        elif self.image_resize_strategy == "resize-crop":
            # Keep TIMM's default resize-then-center-crop pipeline as-is.
            self.image_transform = default_image_transform

        elif self.image_resize_strategy == "letterbox":
            assert isinstance(default_image_transform,
                              Compose), "Unexpected `default_image_transform`!"
            assert "mean" in self.data_cfg, "TIMM `data_cfg` missing image normalization mean!"

            # Compute Padding Fill Value (rescaled normalization mean if applicable)
            fill = tuple([int(x * 255) for x in self.data_cfg["mean"]])

            # Build New Transform: pad-to-square first, then the default pipeline.
            self.image_transform = Compose(
                [LetterboxPad(fill), *default_image_transform.transforms])

        else:
            raise ValueError(
                f"Image Resize Strategy `{self.image_resize_strategy}` is not supported!"
            )

    def get_fsdp_wrapping_policy(self) -> Callable:
        """Return a simple FSDP policy that wraps each ViT block and then the _entire_ featurizer."""
        vit_wrap_policy = partial(_module_wrap_policy,
                                  module_classes={VisionTransformer})
        transformer_block_policy = partial(transformer_auto_wrap_policy,
                                           transformer_layer_cls={Block})
        return partial(_or_policy,
                       policies=[vit_wrap_policy, transformer_block_policy])

    def forward(
        self, pixel_values: Union[torch.Tensor, Dict[str, torch.Tensor]]
    ) -> torch.Tensor:
        """Runs transformed image/pixel tensor through vision backbone, returning _all_ patch features."""
        return self.featurizer(pixel_values)

    @property
    def default_image_resolution(self) -> Tuple[int, int, int]:
        # (channels, height, width) as recorded in the (overridden) TIMM data config.
        return self.data_cfg["input_size"]

    @property
    def embed_dim(self) -> int:
        return self.featurizer.embed_dim

    @property
    def num_patches(self) -> int:
        return self.featurizer.patch_embed.num_patches

    @property
    def half_precision_dtype(self) -> torch.dtype:
        # bfloat16, set unconditionally in __init__.
        return self.dtype
|
||||
273
src/unifolm_wma/modules/vision/dinosiglip_vit.py
Normal file
273
src/unifolm_wma/modules/vision/dinosiglip_vit.py
Normal file
@@ -0,0 +1,273 @@
|
||||
"""
|
||||
dinosiglip_vit.py
|
||||
|
||||
Vision backbone that returns concatenated features from both DINOv2 and SigLIP.
|
||||
"""
|
||||
|
||||
import timm
|
||||
import torch
|
||||
import torchvision.transforms as transforms
|
||||
|
||||
from dataclasses import dataclass
|
||||
from functools import partial
|
||||
from typing import Callable, Dict, Tuple
|
||||
from PIL import Image
|
||||
from timm.models.vision_transformer import Block, VisionTransformer
|
||||
from torch.distributed.fsdp.wrap import _module_wrap_policy, _or_policy, transformer_auto_wrap_policy
|
||||
from torchvision.transforms import Compose, Resize, Normalize
|
||||
|
||||
from unifolm_wma.modules.vision.base_vision import ImageTransform, LetterboxPad, VisionBackbone, unpack_tuple
|
||||
from unifolm_wma.utils.nn_utils import FusedMLPProjector, LinearProjector, MLPProjector
|
||||
|
||||
# Registry =>> Supported DinoSigLIP Pairs (as TIMM identifiers)
# Maps a backbone id to the TIMM model names of its DINOv2 and SigLIP halves;
# the two variants differ only in the SigLIP input resolution (224 vs 384 px).
DINOSigLIP_VISION_BACKBONES = {
    "dinosiglip-vit-so-224px": {
        "dino": "vit_large_patch14_reg4_dinov2.lvd142m",
        "siglip": "vit_so400m_patch14_siglip_224",
    },
    "dinosiglip-vit-so-384px": {
        "dino": "vit_large_patch14_reg4_dinov2.lvd142m",
        "siglip": "vit_so400m_patch14_siglip_384",
    },
}
|
||||
|
||||
|
||||
@dataclass
class DinoSigLIPImageTransform:
    """Pair of transforms applied to one image, yielding inputs for both backbones."""

    dino_image_transform: ImageTransform
    siglip_image_transform: ImageTransform
    is_prismatic: bool = True

    def __call__(self, img: Image, **kwargs: str) -> Dict[str, torch.Tensor]:
        """Apply both member transforms and return their outputs keyed by backbone."""
        dino_pixels = self.dino_image_transform(img, **kwargs)
        siglip_pixels = self.siglip_image_transform(img, **kwargs)
        return {"dino": dino_pixels, "siglip": siglip_pixels}
|
||||
|
||||
|
||||
class DinoSigLIPViTBackbone(VisionBackbone):
    """Fused vision backbone: a DINOv2 ViT and a SigLIP ViT run side by side,
    their patch features are concatenated channel-wise, and the result is
    passed through a small projector to ``output_dim``.
    """

    def __init__(self,
                 vision_backbone_id: str,
                 image_resize_strategy: str,
                 arch_specifier: str,
                 output_dim: int,
                 pretrained_checkpoint=None,
                 freeze=True,
                 default_image_size: int = 224) -> None:
        """
        :param vision_backbone_id: key into ``DINOSigLIP_VISION_BACKBONES``.
        :param image_resize_strategy: "resize-naive" | "resize-crop" | "letterbox".
        :param arch_specifier: selects the projector head ("linear",
            "*fused-gelu-mlp", or "*gelu-mlp").
        :param output_dim: projector output dimensionality.
        :param pretrained_checkpoint: directory holding ``openvla_dino.pt`` /
            ``openvla_siglip.pt`` state dicts to load on top of the TIMM weights.
        :param freeze: if True, put both featurizers in eval mode and disable grads.
        :param default_image_size: square input resolution for both ViTs.
        """
        super().__init__(vision_backbone_id,
                         image_resize_strategy,
                         default_image_size=default_image_size)
        self.dino_timm_path_or_url = DINOSigLIP_VISION_BACKBONES[
            vision_backbone_id]["dino"]
        self.siglip_timm_path_or_url = DINOSigLIP_VISION_BACKBONES[
            vision_backbone_id]["siglip"]

        # Initialize both Featurizers (ViTs) by downloading from HF / TIMM Hub if necessary
        self.dino_featurizer: VisionTransformer = timm.create_model(
            self.dino_timm_path_or_url,
            pretrained=True,
            num_classes=0,  # no classification head — features only
            img_size=self.default_image_size)
        if pretrained_checkpoint:
            # Overwrite the TIMM weights with a project-provided checkpoint.
            ckpt = pretrained_checkpoint + '/openvla_dino.pt'
            self.dino_featurizer.load_state_dict(
                torch.load(ckpt, weights_only=True))
            print('>>> load dino weights')
        if freeze:
            self.dino_featurizer.eval()
            for param in self.dino_featurizer.parameters():
                param.requires_grad = False

        self.siglip_featurizer: VisionTransformer = timm.create_model(
            self.siglip_timm_path_or_url,
            pretrained=True,
            num_classes=0,
            img_size=self.default_image_size)
        if pretrained_checkpoint:
            ckpt = pretrained_checkpoint + '/openvla_siglip.pt'
            self.siglip_featurizer.load_state_dict(
                torch.load(ckpt, weights_only=True))
            print('>>> load siglip weights')
        if freeze:
            self.siglip_featurizer.eval()
            for param in self.siglip_featurizer.parameters():
                param.requires_grad = False

        # Monkey-Patch the `forward()` function of the featurizers to ensure FSDP-compatibility
        # => Note: By default set `get_intermediate_layers` to return the *SECOND-TO-LAST* layer patches!
        # => TODO (siddk) Remove after resolution of https://github.com/pytorch/pytorch/issues/109385
        # `n` is passed as a one-element set of block indices.
        self.dino_featurizer.forward = unpack_tuple(
            partial(self.dino_featurizer.get_intermediate_layers,
                    n={len(self.dino_featurizer.blocks) - 2}))
        self.siglip_featurizer.forward = unpack_tuple(
            partial(self.siglip_featurizer.get_intermediate_layers,
                    n={len(self.siglip_featurizer.blocks) - 2}))

        # Get Configs for _both_ Featurizers =>> Note :: Override default image size for larger resolution models
        self.dino_data_cfg = timm.data.resolve_model_data_config(
            self.dino_featurizer)
        self.dino_data_cfg["input_size"] = (3, self.default_image_size,
                                            self.default_image_size)

        self.siglip_data_cfg = timm.data.resolve_model_data_config(
            self.siglip_featurizer)
        self.siglip_data_cfg["input_size"] = (3, self.default_image_size,
                                              self.default_image_size)

        # Initialize *both* Transforms
        self.default_dino_transform = timm.data.create_transform(
            **self.dino_data_cfg, is_training=False)
        self.default_siglip_transform = timm.data.create_transform(
            **self.siglip_data_cfg, is_training=False)

        # Fix =>> SigLIP default transform resizes to *larger* than `self.default_image_size` (crops image)!!
        assert isinstance(self.default_siglip_transform,
                          Compose), "Unexpected `default_image_transform`!"
        assert isinstance(self.default_siglip_transform.transforms[0], Resize)
        self.default_siglip_transform = Compose([
            Resize(self.default_image_size,
                   interpolation=self.default_siglip_transform.transforms[0].
                   interpolation),
            *self.default_siglip_transform.transforms[1:],
        ])

        if self.image_resize_strategy == "resize-naive":
            # Stretch directly to (H, W) = (default_image_size, default_image_size),
            # ignoring aspect ratio; keeps each backbone's original interpolation mode.
            assert isinstance(
                self.default_dino_transform,
                Compose), "Unexpected `default_dino_image_transform`!"
            assert isinstance(
                self.default_siglip_transform,
                Compose), "Unexpected `default_siglip_image_transform`!"
            assert isinstance(self.default_dino_transform.transforms[0],
                              Resize)
            assert isinstance(self.default_siglip_transform.transforms[0],
                              Resize)

            # NOTE(review): `self.target_size` is only assigned in this branch,
            # but `forward()` reads it unconditionally — confirm forward is
            # only used with the "resize-naive" strategy.
            self.target_size = (self.default_image_size,
                                self.default_image_size)
            dino_transform = Compose([
                Resize(self.target_size,
                       interpolation=self.default_dino_transform.transforms[0].
                       interpolation),
                *self.default_dino_transform.transforms[1:],
            ])
            siglip_transform = Compose([
                Resize(self.target_size,
                       interpolation=self.default_siglip_transform.
                       transforms[0].interpolation),
                *self.default_siglip_transform.transforms[1:],
            ])

            self.image_transform = DinoSigLIPImageTransform(
                dino_transform, siglip_transform)

        elif self.image_resize_strategy == "resize-crop":
            # Keep the TIMM defaults: shorter-side resize followed by center crop.
            self.image_transform = DinoSigLIPImageTransform(
                self.default_dino_transform, self.default_siglip_transform)

        elif self.image_resize_strategy == "letterbox":
            # Pad to square with each backbone's normalization mean as fill,
            # then apply the default transform chain.
            assert isinstance(self.default_dino_transform,
                              Compose), "Unexpected `default_dino_transform`!"
            assert isinstance(
                self.default_siglip_transform,
                Compose), "Unexpected `default_siglip_transform`!"
            assert ("mean" in self.dino_data_cfg
                    and "mean" in self.siglip_data_cfg
                    ), "DinoSigLIP `data_cfg` missing `mean`!"

            # Compute Padding Fill Value(s) (rescaled normalization mean if applicable)
            dino_fill = tuple(
                [int(x * 255) for x in self.dino_data_cfg["mean"]])
            siglip_fill = tuple(
                [int(x * 255) for x in self.siglip_data_cfg["mean"]])

            # Build New Transform
            self.image_transform = DinoSigLIPImageTransform(
                Compose([
                    LetterboxPad(dino_fill),
                    *self.default_dino_transform.transforms
                ]),
                Compose([
                    LetterboxPad(siglip_fill),
                    *self.default_siglip_transform.transforms
                ]),
            )

        else:
            raise ValueError(
                f"Image Resize Strategy `{self.image_resize_strategy}` is not supported!"
            )

        # Projector head mapping concatenated (dino + siglip) features to output_dim.
        self.arch_specifier = arch_specifier
        if arch_specifier == "linear":
            self.projector = LinearProjector(self.embed_dim, output_dim)
        elif arch_specifier.endswith("fused-gelu-mlp"):
            self.projector = FusedMLPProjector(self.embed_dim, output_dim)
        elif arch_specifier.endswith("gelu-mlp"):
            self.projector = MLPProjector(self.embed_dim, output_dim)
        else:
            raise ValueError(
                f"PrismaticVLM with `{arch_specifier = }` is not supported!")

        # Lazily flipped in forward() once the featurizers are seen on a non-CPU device.
        self.on_gpu = False

    def get_fsdp_wrapping_policy(self) -> Callable:
        """Return a simple FSDP policy that wraps each ViT block and then both of the _entire_ featurizers."""
        vit_wrap_policy = partial(_module_wrap_policy,
                                  module_classes={VisionTransformer})
        transformer_block_policy = partial(transformer_auto_wrap_policy,
                                           transformer_layer_cls={Block})
        return partial(_or_policy,
                       policies=[vit_wrap_policy, transformer_block_policy])

    def forward(self, img) -> torch.Tensor:
        """Preprocess `img` (expected in [-1, 1]) and return projected,
        concatenated DINO+SigLIP patch features.
        """
        # Map from [-1, 1] to [0, 255] floats.
        img = torch.clamp(img.float(), -1., 1.)
        img = (img + 1.0) / 2.0
        img = img * 255

        resize = transforms.Resize(min(self.target_size),
                                   interpolation=self.default_dino_transform.
                                   transforms[0].interpolation,
                                   max_size=None,
                                   antialias=True)
        center_crop = transforms.CenterCrop(self.target_size)
        img = center_crop(resize(img))

        # NOTE(review): `img` is in [0, 255] here, yet these means/stds are the
        # usual [0, 1]-range ImageNet / SigLIP constants — verify this matches
        # the training-time transform before changing anything.
        dino_normalizer = Normalize(mean=torch.tensor([0.4850, 0.4560,
                                                       0.4060]),
                                    std=torch.tensor([0.2290, 0.2240, 0.2250]))
        siglip_normalizer = Normalize(
            mean=torch.tensor([0.5000, 0.5000, 0.5000]),
            std=torch.tensor([0.5000, 0.5000, 0.5000]))
        pixel_values = {
            'dino': dino_normalizer(img),
            'siglip': siglip_normalizer(img)
        }

        # NOTE(review): `on_gpu` is promoted one call late — the call that
        # first observes the featurizer on a non-CPU device does NOT move its
        # own pixel_values; confirm the first forward always gets GPU inputs.
        if self.on_gpu:
            pixel_values = {k: v.cuda() for k, v in pixel_values.items()}
        elif next(self.dino_featurizer.parameters()).device.type != 'cpu':
            self.on_gpu = True
        # Runs the transformed image/pixel tensors through each vision backbone, returning concatenated patches.
        dino_patches = self.dino_featurizer(pixel_values["dino"])
        siglip_patches = self.siglip_featurizer(pixel_values["siglip"])

        return self.projector(torch.cat([dino_patches, siglip_patches], dim=2))

    @property
    def default_image_resolution(self) -> Tuple[int, int, int]:
        # (C, H, W) as overridden in __init__ to the default image size.
        return self.dino_data_cfg["input_size"]

    @property
    def embed_dim(self) -> int:
        # Features are concatenated, so dims add.
        return self.dino_featurizer.embed_dim + self.siglip_featurizer.embed_dim

    @property
    def num_patches(self) -> int:
        # Both ViTs must tile the image identically for channel-wise concat.
        assert self.dino_featurizer.patch_embed.num_patches == self.siglip_featurizer.patch_embed.num_patches
        return self.dino_featurizer.patch_embed.num_patches

    @property
    def half_precision_dtype(self) -> torch.dtype:
        return torch.bfloat16
|
||||
104
src/unifolm_wma/utils/basics.py
Normal file
104
src/unifolm_wma/utils/basics.py
Normal file
@@ -0,0 +1,104 @@
|
||||
# adopted from
|
||||
# https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py
|
||||
# and
|
||||
# https://github.com/lucidrains/denoising-diffusion-pytorch/blob/7706bdfc6f527f58d33f84b7b522e61e6e3164b3/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py
|
||||
# and
|
||||
# https://github.com/openai/guided-diffusion/blob/0ba878e517b276c45d1195eb29f6f5f72659a05b/guided_diffusion/nn.py
|
||||
#
|
||||
# thanks!
|
||||
|
||||
import torch.nn as nn
|
||||
from unifolm_wma.utils.utils import instantiate_from_config
|
||||
|
||||
|
||||
def disabled_train(self, mode=True):
    """Replacement for ``nn.Module.train`` that ignores mode switches.

    Assign this as ``model.train`` so later ``.train()``/``.eval()`` calls
    become no-ops, keeping a frozen module in its current mode.
    """
    return self
|
||||
|
||||
|
||||
def zero_module(module):
    """Reset every parameter of ``module`` to zero, in place.

    Returns the same module so the call can be chained inline.
    """
    for weight in module.parameters():
        weight.detach().zero_()
    return module
|
||||
|
||||
|
||||
def scale_module(module, scale):
    """Multiply every parameter of ``module`` by ``scale``, in place.

    Returns the same module so the call can be chained inline.
    """
    for weight in module.parameters():
        weight.detach().mul_(scale)
    return module
|
||||
|
||||
|
||||
def conv_nd(dims, *args, **kwargs):
    """
    Create a 1D, 2D, or 3D convolution module.

    :param dims: spatial dimensionality (1, 2, or 3).
    :raises ValueError: for any other ``dims``.
    """
    conv_types = {1: nn.Conv1d, 2: nn.Conv2d, 3: nn.Conv3d}
    if dims not in conv_types:
        raise ValueError(f"unsupported dimensions: {dims}")
    return conv_types[dims](*args, **kwargs)
|
||||
|
||||
|
||||
def linear(*args, **kwargs):
    """Thin alias for ``nn.Linear``, kept for API symmetry with ``conv_nd``."""
    return nn.Linear(*args, **kwargs)
|
||||
|
||||
|
||||
def avg_pool_nd(dims, *args, **kwargs):
    """
    Create a 1D, 2D, or 3D average pooling module.

    :param dims: spatial dimensionality (1, 2, or 3).
    :raises ValueError: for any other ``dims``.
    """
    pool_types = {1: nn.AvgPool1d, 2: nn.AvgPool2d, 3: nn.AvgPool3d}
    if dims not in pool_types:
        raise ValueError(f"unsupported dimensions: {dims}")
    return pool_types[dims](*args, **kwargs)
|
||||
|
||||
|
||||
def nonlinearity(type='silu'):
    """Return an activation module by name.

    :param type: ``'silu'`` or ``'leaky_relu'``.
    :raises ValueError: for any other name. (Previously unknown names fell
        through and returned ``None`` silently, which only failed later when
        the "module" was called.)
    """
    if type == 'silu':
        return nn.SiLU()
    elif type == 'leaky_relu':
        return nn.LeakyReLU()
    raise ValueError(f"unsupported nonlinearity type: {type}")
|
||||
|
||||
|
||||
class GroupNormSpecific(nn.GroupNorm):
    """GroupNorm that always normalizes in float32.

    The input is upcast to float32 for the reduction (numerical stability
    under mixed precision); the result is cast back to the input's dtype.
    """

    def forward(self, x):
        normalized = super().forward(x.float())
        return normalized.type(x.dtype)
|
||||
|
||||
|
||||
def normalization(channels, num_groups=32):
    """
    Build the project's standard normalization layer.

    :param channels: number of input channels.
    :param num_groups: number of GroupNorm groups.
    :return: a float32-stable GroupNorm module (``GroupNormSpecific``).
    """
    return GroupNormSpecific(num_groups, channels)
|
||||
|
||||
|
||||
class HybridConditioner(nn.Module):
    """Pair of conditioning encoders: one for concat-conditioning, one for
    cross-attention-conditioning, each instantiated from its config.
    """

    def __init__(self, c_concat_config, c_crossattn_config):
        super().__init__()
        self.concat_conditioner = instantiate_from_config(c_concat_config)
        self.crossattn_conditioner = instantiate_from_config(
            c_crossattn_config)

    def forward(self, c_concat, c_crossattn):
        """Encode both signals; each is wrapped in a single-element list."""
        return {
            'c_concat': [self.concat_conditioner(c_concat)],
            'c_crossattn': [self.crossattn_conditioner(c_crossattn)],
        }
|
||||
226
src/unifolm_wma/utils/callbacks.py
Normal file
226
src/unifolm_wma/utils/callbacks.py
Normal file
@@ -0,0 +1,226 @@
|
||||
import os
|
||||
import time
|
||||
import logging
|
||||
import json
|
||||
|
||||
mainlogger = logging.getLogger('mainlogger')
|
||||
|
||||
import torch
|
||||
import torchvision
|
||||
import pytorch_lightning as pl
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
from pytorch_lightning.callbacks import Callback
|
||||
from pytorch_lightning.utilities import rank_zero_only
|
||||
from pytorch_lightning.utilities import rank_zero_info
|
||||
|
||||
from unifolm_wma.utils.save_video import log_local, prepare_to_log
|
||||
|
||||
# Default directory for statistics output.
# NOTE(review): not referenced anywhere in this module's visible code
# (ImageLogger derives `save_stat_dir` from its `save_dir` argument) —
# confirm against other modules before removing.
STAT_DIR = '~/'
|
||||
|
||||
|
||||
class ImageLogger(Callback):
    """Lightning callback that periodically renders model samples and logs
    them either to local disk or to TensorBoard, while also accumulating
    fps / frame-stride statistics of the batches seen so far.
    """

    def __init__(self, batch_frequency, max_images=8, clamp=True, rescale=True, save_dir=None, \
            to_local=False, log_images_kwargs=None):
        """
        :param batch_frequency: logging interval threshold (compared against a
            hand-tuned sample counter in ``log_batch_imgs``); -1 disables logging.
        :param max_images: cap on images kept per batch (applied by ``prepare_to_log``).
        :param clamp: clamp outputs before logging (forwarded to ``prepare_to_log``).
        :param rescale: stored but not read within this class's visible code.
        :param save_dir: root directory for "stat" and (optionally) "images" output.
        :param to_local: write media to disk instead of TensorBoard.
        :param log_images_kwargs: extra kwargs forwarded to ``pl_module.log_images``.
        """
        super().__init__()
        self.rescale = rescale
        self.batch_freq = batch_frequency
        self.max_images = max_images
        self.to_local = to_local
        self.clamp = clamp
        self.log_images_kwargs = log_images_kwargs if log_images_kwargs else {}
        # fps / frame-stride histograms are periodically dumped as JSON here.
        self.save_stat_dir = os.path.join(save_dir, "stat")
        os.makedirs(self.save_stat_dir, exist_ok=True)
        self.fps_stat = {}
        self.fs_stat = {}
        if self.to_local:
            self.save_dir = os.path.join(save_dir, "images")
            os.makedirs(os.path.join(self.save_dir, "train"), exist_ok=True)
            os.makedirs(os.path.join(self.save_dir, "val"), exist_ok=True)
        # Hand-coded sample counter driving the logging cadence (see log_batch_imgs).
        self.count_data = 0

    def log_to_tensorboard(self,
                           pl_module,
                           batch_logs,
                           filename,
                           split,
                           save_fps=8):
        """ log images and videos to tensorboard """
        # Dispatches each entry of `batch_logs` by type/rank:
        #   list[str] -> add_text, 5-D tensor -> add_video,
        #   4-D tensor -> add_image, 3-D tensor -> per-dim comparison figures.
        # `filename` is currently unused here.
        global_step = pl_module.global_step
        for key in batch_logs:
            value = batch_logs[key]
            # Tag encodes step/split/key plus the first condition string split
            # on '_' and the video index (assumes 'condition' has >= 1 entry
            # containing at least one '_').
            tag = "gs%d-%s/%s||%s||%s||%s" % (
                global_step, split, key,
                batch_logs['condition'][0].split('_')[0],
                batch_logs['condition'][0].split('_')[1],
                batch_logs['video_idx'])
            if isinstance(value, list) and isinstance(value[0], str):
                captions = ' |------| '.join(value)
                pl_module.logger.experiment.add_text(tag,
                                                     captions,
                                                     global_step=global_step)
            elif isinstance(value, torch.Tensor) and value.dim() == 5:
                video = value
                n = video.shape[0]
                # (b, c, t, h, w) -> (t, b, c, h, w): one grid per frame.
                video = video.permute(2, 0, 1, 3, 4)
                frame_grids = [
                    torchvision.utils.make_grid(framesheet,
                                                nrow=int(n),
                                                padding=0)
                    for framesheet in video
                ]
                grid = torch.stack(
                    frame_grids, dim=0)
                # Map from [-1, 1] to [0, 1] for TensorBoard display.
                grid = (grid + 1.0) / 2.0
                grid = grid.unsqueeze(dim=0)
                pl_module.logger.experiment.add_video(tag,
                                                      grid,
                                                      fps=save_fps,
                                                      global_step=global_step)
            elif isinstance(value, torch.Tensor) and value.dim() == 4:
                img = value
                # NOTE(review): `n` is only assigned in the 5-D branch above —
                # if a 4-D tensor is seen before any 5-D tensor in this loop,
                # this raises NameError; it otherwise reuses a stale batch
                # size. Confirm intended nrow here.
                grid = torchvision.utils.make_grid(img, nrow=int(n), padding=0)
                grid = (grid + 1.0) / 2.0
                pl_module.logger.experiment.add_image(tag,
                                                      grid,
                                                      global_step=global_step)
            elif isinstance(value, torch.Tensor) and value.dim() == 3:
                # First half of the batch is treated as targets, second half
                # as samples; one comparison figure per feature dimension.
                b, _, _ = value.shape
                value1 = value[:b // 2, ...]
                value2 = value[b // 2:, ...]
                _, num_points, d = value1.shape
                for i in range(d):
                    data1 = value1[0, :, i].cpu().detach().numpy()
                    data2 = value2[0, :, i].cpu().detach().numpy()
                    fig, ax = plt.subplots()
                    ax.plot(data1, label='Target 1')
                    ax.plot(data2, label='Sample 1')
                    ax.set_title(f'Comparison at dimension {i} for {key}')
                    ax.legend()
                    pl_module.logger.experiment.add_figure(
                        tag + f"| {key}_dim_{i}", fig, global_step=global_step)
                    plt.close(fig)
            else:
                # Unrecognized entry types are silently skipped.
                pass

    @rank_zero_only
    def log_batch_imgs(self, pl_module, batch, batch_idx, split="train"):
        """ generate images, then save and log to tensorboard """
        # Update fps and fs statistics
        batch_fps = batch['fps'].tolist()
        batch_fs = batch['frame_stride'].tolist()
        for num in batch_fps:
            self.fps_stat[num] = self.fps_stat.get(num, 0) + 1
        for num in batch_fs:
            self.fs_stat[num] = self.fs_stat.get(num, 0) + 1
        skip_freq = self.batch_freq if split == "train" else 5
        ## NOTE HAND CODE
        # The 12.5 * 2 increment is a hard-coded estimate of samples per step;
        # logging fires once the accumulated count crosses `skip_freq`.
        self.count_data += 12.5 * 2
        if self.count_data >= skip_freq:
            self.count_data = 0

            is_train = pl_module.training
            if is_train:
                # Sample in eval mode, restore train mode afterwards.
                pl_module.eval()
            torch.cuda.empty_cache()
            with torch.no_grad():
                log_func = pl_module.log_images
                batch_logs = log_func(batch,
                                      split=split,
                                      **self.log_images_kwargs)
                # Log fps and fs statistics
                with open(self.save_stat_dir + '/fps_fs_stat.json',
                          'w') as file:
                    json.dump({
                        'fps': self.fps_stat,
                        'fs': self.fs_stat
                    },
                              file,
                              indent=4)

            batch_logs = prepare_to_log(batch_logs, self.max_images,
                                        self.clamp)
            torch.cuda.empty_cache()

            filename = "ep{}_idx{}_rank{}".format(pl_module.current_epoch,
                                                  batch_idx,
                                                  pl_module.global_rank)
            if self.to_local:
                mainlogger.info("Log [%s] batch <%s> to local ..." %
                                (split, filename))
                filename = "gs{}_".format(pl_module.global_step) + filename
                log_local(batch_logs,
                          os.path.join(self.save_dir, split),
                          filename,
                          save_fps=10)
            else:
                mainlogger.info("Log [%s] batch <%s> to tensorboard ..." %
                                (split, filename))
                self.log_to_tensorboard(pl_module,
                                        batch_logs,
                                        filename,
                                        split,
                                        save_fps=10)
            mainlogger.info('Finish!')

            if is_train:
                pl_module.train()

    def on_train_batch_end(self,
                           trainer,
                           pl_module,
                           outputs,
                           batch,
                           batch_idx,
                           dataloader_idx=None):
        # batch_freq == -1 disables logging; pl_module.logdir must be set.
        if self.batch_freq != -1 and pl_module.logdir:
            self.log_batch_imgs(pl_module, batch, batch_idx, split="train")

    def on_validation_batch_end(self,
                                trainer,
                                pl_module,
                                outputs,
                                batch,
                                batch_idx,
                                dataloader_idx=None):
        #Different with validation_step() that saving the whole validation set and only keep the latest,
        #It records the performance of every validation (without overwritten) by only keep a subset
        if self.batch_freq != -1 and pl_module.logdir:
            self.log_batch_imgs(pl_module, batch, batch_idx, split="val")
        if hasattr(pl_module, 'calibrate_grad_norm'):
            if (pl_module.calibrate_grad_norm
                    and batch_idx % 25 == 0) and batch_idx > 0:
                # NOTE(review): `log_gradients` is not defined on this class in
                # the visible code — this path raises AttributeError if
                # `calibrate_grad_norm` is ever enabled. Confirm the method
                # exists elsewhere or guard this call.
                self.log_gradients(trainer, pl_module, batch_idx=batch_idx)
|
||||
|
||||
|
||||
class CUDACallback(Callback):
    # See https://github.com/SeanNaren/minGPT/blob/master/mingpt/callback.py
    """Logs per-epoch wall time and peak CUDA memory of the root device."""

    @staticmethod
    def _root_gpu_index(trainer):
        """Root GPU index across Lightning versions.

        ``trainer.root_gpu`` was replaced by ``trainer.strategy.root_device``
        from PL 1.7 onward. The original check compared only the minor version
        (``int(version.split('.')[1]) >= 7``), which misclassifies PL 2.x
        (minor 0) and would then touch the removed ``root_gpu`` attribute;
        compare (major, minor) instead.
        """
        major_minor = tuple(int(v) for v in pl.__version__.split('.')[:2])
        if major_minor >= (1, 7):
            return trainer.strategy.root_device.index
        return trainer.root_gpu

    def on_train_epoch_start(self, trainer, pl_module):
        # Reset the peak-memory counter so this epoch's maximum is measured cleanly.
        gpu_index = self._root_gpu_index(trainer)
        torch.cuda.reset_peak_memory_stats(gpu_index)
        torch.cuda.synchronize(gpu_index)
        self.start_time = time.time()

    def on_train_epoch_end(self, trainer, pl_module):
        gpu_index = self._root_gpu_index(trainer)
        torch.cuda.synchronize(gpu_index)
        max_memory = torch.cuda.max_memory_allocated(gpu_index) / 2**20  # MiB
        epoch_time = time.time() - self.start_time

        try:
            # Average across processes. `training_type_plugin` only exists on
            # older Lightning versions, hence the AttributeError guard below.
            max_memory = trainer.training_type_plugin.reduce(max_memory)
            epoch_time = trainer.training_type_plugin.reduce(epoch_time)

            rank_zero_info(f"Average Epoch time: {epoch_time:.2f} seconds")
            rank_zero_info(f"Average Peak memory {max_memory:.2f}MiB")
        except AttributeError:
            pass
|
||||
111
src/unifolm_wma/utils/common.py
Normal file
111
src/unifolm_wma/utils/common.py
Normal file
@@ -0,0 +1,111 @@
|
||||
import math
|
||||
from inspect import isfunction
|
||||
import torch
|
||||
from torch import Tensor, nn
|
||||
import torch.distributed as dist
|
||||
|
||||
|
||||
def gather_data(data, return_np=True):
    ''' gather data from multiple processes to one list '''
    # All-gather `data` (same shape on every rank) into a per-rank list.
    world_size = dist.get_world_size()
    gathered = [torch.zeros_like(data) for _ in range(world_size)]
    dist.all_gather(gathered, data)  # gather not supported with NCCL
    if return_np:
        gathered = [tensor.cpu().numpy() for tensor in gathered]
    return gathered
||||
|
||||
|
||||
def autocast(f):
    """Decorator running ``f`` under CUDA autocast with the current
    autocast GPU dtype and cache setting."""

    def do_autocast(*args, **kwargs):
        ctx = torch.cuda.amp.autocast(
            enabled=True,
            dtype=torch.get_autocast_gpu_dtype(),
            cache_enabled=torch.is_autocast_cache_enabled())
        with ctx:
            return f(*args, **kwargs)

    return do_autocast
||||
|
||||
|
||||
def extract_into_tensor(a, t, x_shape):
    """Gather per-timestep coefficients and reshape for broadcasting.

    Picks ``a[t]`` for each batch element and appends singleton dims so the
    result broadcasts against a tensor of shape ``x_shape``.
    """
    batch = t.shape[0]
    gathered = a.gather(-1, t)
    return gathered.reshape(batch, *([1] * (len(x_shape) - 1)))
||||
|
||||
|
||||
def noise_like(shape, device, repeat=False):
    """Sample standard-normal noise of ``shape`` on ``device``.

    With ``repeat=True``, a single sample is drawn and tiled along the batch
    dimension so every batch element shares identical noise.
    """
    if repeat:
        single = torch.randn((1, *shape[1:]), device=device)
        return single.repeat(shape[0], *((1, ) * (len(shape) - 1)))
    return torch.randn(shape, device=device)
|
||||
|
||||
|
||||
def default(val, d):
    """Return ``val`` unless it is None, in which case fall back to ``d``.

    ``d`` may be a zero-argument callable; it is invoked lazily, only when
    the fallback is actually needed.
    """
    if exists(val):
        return val
    return d() if isfunction(d) else d


def exists(val):
    """True when ``val`` is not None (empty containers still count as existing)."""
    return val is not None
|
||||
|
||||
|
||||
def identity(*args, **kwargs):
    """Return a fresh ``nn.Identity``; any arguments are accepted and ignored."""
    return nn.Identity()
|
||||
|
||||
|
||||
def uniq(arr):
    """De-duplicate ``arr`` preserving first-seen order; returns a dict keys view."""
    return dict.fromkeys(arr, True).keys()
|
||||
|
||||
|
||||
def mean_flat(tensor):
    """
    Take the mean over all non-batch dimensions.

    Returns one scalar per batch element.
    """
    non_batch_dims = list(range(1, tensor.dim()))
    return tensor.mean(dim=non_batch_dims)
|
||||
|
||||
|
||||
def ismap(x):
    """True for a 4-D tensor whose channel dim exceeds 3 (feature map, not image)."""
    return (isinstance(x, torch.Tensor) and len(x.shape) == 4
            and x.shape[1] > 3)
|
||||
|
||||
|
||||
def isimage(x):
    """True for a 4-D tensor with 1 or 3 channels (grayscale or RGB batch)."""
    return (isinstance(x, torch.Tensor) and len(x.shape) == 4
            and x.shape[1] in (3, 1))
|
||||
|
||||
|
||||
def max_neg_value(t):
    """Most negative finite value representable in ``t``'s dtype (useful for masking)."""
    info = torch.finfo(t.dtype)
    return -info.max
|
||||
|
||||
|
||||
def shape_to_str(x):
    """Render ``x``'s shape as an "AxBxC" string (e.g. ``"2x3x4"``)."""
    return "x".join(str(dim) for dim in x.shape)
|
||||
|
||||
|
||||
def init_(tensor):
    """Uniform init in ``[-1/sqrt(d), 1/sqrt(d)]``, d = last dim; in place.

    Returns the same tensor for chaining.
    """
    bound = 1 / math.sqrt(tensor.shape[-1])
    tensor.uniform_(-bound, bound)
    return tensor
|
||||
|
||||
|
||||
# Alias for torch's gradient-checkpointing entry point.
ckpt = torch.utils.checkpoint.checkpoint


def checkpoint(func, inputs, params, flag):
    """
    Evaluate a function without caching intermediate activations, allowing for
    reduced memory at the expense of extra compute in the backward pass.
    :param func: the function to evaluate.
    :param inputs: the argument sequence to pass to `func`.
    :param params: a sequence of parameters `func` depends on but does not
                   explicitly take as arguments (not used by this backend).
    :param flag: if False, disable gradient checkpointing.
    """
    if not flag:
        return func(*inputs)
    # Non-reentrant variant: supports keyword-free pass-through of `inputs`.
    return ckpt(func, *inputs, use_reentrant=False)
|
||||
242
src/unifolm_wma/utils/data.py
Normal file
242
src/unifolm_wma/utils/data.py
Normal file
@@ -0,0 +1,242 @@
|
||||
import os, sys
|
||||
import numpy as np
|
||||
import torch
|
||||
import pytorch_lightning as pl
|
||||
|
||||
from functools import partial
|
||||
from torch.utils.data import (DataLoader, Dataset, ConcatDataset,
|
||||
WeightedRandomSampler)
|
||||
from unifolm_wma.data.base import Txt2ImgIterableBaseDataset
|
||||
from unifolm_wma.utils.utils import instantiate_from_config
|
||||
|
||||
|
||||
def worker_init_fn(_):
    """DataLoader worker init: shard iterable datasets and de-correlate RNG.

    For ``Txt2ImgIterableBaseDataset`` workers, each worker is assigned a
    disjoint slice of ``valid_ids``; in every case numpy is reseeded per
    worker so workers do not produce identical random streams.
    """
    info = torch.utils.data.get_worker_info()
    dataset = info.dataset
    worker_id = info.id

    if not isinstance(dataset, Txt2ImgIterableBaseDataset):
        return np.random.seed(np.random.get_state()[1][0] + worker_id)

    shard = dataset.num_records // info.num_workers
    # num_records stays intact so __len__ remains truthful; only the
    # worker-local sample_ids slice shrinks.
    dataset.sample_ids = dataset.valid_ids[worker_id * shard:(worker_id + 1) *
                                           shard]
    current_id = np.random.choice(len(np.random.get_state()[1]), 1)
    return np.random.seed(np.random.get_state()[1][current_id] + worker_id)
|
||||
|
||||
|
||||
class WrappedDataset(Dataset):
    """Wraps an arbitrary object with __len__ and __getitem__ into a pytorch dataset"""

    def __init__(self, dataset):
        self.data = dataset

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx]
|
||||
|
||||
|
||||
class DataModuleFromConfig(pl.LightningDataModule):
|
||||
|
||||
def __init__(self,
|
||||
batch_size,
|
||||
train=None,
|
||||
validation=None,
|
||||
test=None,
|
||||
predict=None,
|
||||
wrap=False,
|
||||
num_workers=None,
|
||||
shuffle_test_loader=False,
|
||||
use_worker_init_fn=False,
|
||||
shuffle_val_dataloader=True,
|
||||
train_img=None,
|
||||
dataset_and_weights=None):
|
||||
super().__init__()
|
||||
self.batch_size = batch_size
|
||||
self.dataset_configs = dict()
|
||||
self.num_workers = num_workers if num_workers is not None else batch_size * 2
|
||||
self.use_worker_init_fn = use_worker_init_fn
|
||||
if train is not None:
|
||||
self.dataset_configs["train"] = train
|
||||
self.train_dataloader = self._train_dataloader
|
||||
if validation is not None:
|
||||
self.dataset_configs["validation"] = validation
|
||||
self.val_dataloader = partial(self._val_dataloader,
|
||||
shuffle=shuffle_val_dataloader)
|
||||
if test is not None:
|
||||
self.dataset_configs["test"] = test
|
||||
self.test_dataloader = partial(self._test_dataloader,
|
||||
shuffle=shuffle_test_loader)
|
||||
if predict is not None:
|
||||
self.dataset_configs["predict"] = predict
|
||||
self.predict_dataloader = self._predict_dataloader
|
||||
|
||||
self.img_loader = None
|
||||
self.wrap = wrap
|
||||
self.collate_fn = None
|
||||
self.dataset_weights = dataset_and_weights
|
||||
assert round(sum(self.dataset_weights.values()),
|
||||
2) == 1.0, "The sum of dataset weights != 1.0"
|
||||
|
||||
def prepare_data(self):
|
||||
pass
|
||||
|
||||
def setup(self, stage=None):
|
||||
if 'train' in self.dataset_configs:
|
||||
self.train_datasets = dict()
|
||||
for dataname in self.dataset_weights:
|
||||
data_dir = self.dataset_configs['train']['params']['data_dir']
|
||||
transition_dir = '/'.join([data_dir, 'transitions'])
|
||||
csv_file = f'{dataname}.csv'
|
||||
meta_path = '/'.join([data_dir, csv_file])
|
||||
self.dataset_configs['train']['params'][
|
||||
'meta_path'] = meta_path
|
||||
self.dataset_configs['train']['params'][
|
||||
'transition_dir'] = transition_dir
|
||||
self.dataset_configs['train']['params'][
|
||||
'dataset_name'] = dataname
|
||||
self.train_datasets[dataname] = instantiate_from_config(
|
||||
self.dataset_configs['train'])
|
||||
|
||||
# Setup validation dataset
|
||||
if 'validation' in self.dataset_configs:
|
||||
self.val_datasets = dict()
|
||||
for dataname in self.dataset_weights:
|
||||
data_dir = self.dataset_configs['validation']['params'][
|
||||
'data_dir']
|
||||
transition_dir = '/'.join([data_dir, 'transitions'])
|
||||
csv_file = f'{dataname}.csv'
|
||||
meta_path = '/'.join([data_dir, csv_file])
|
||||
self.dataset_configs['validation']['params'][
|
||||
'meta_path'] = meta_path
|
||||
self.dataset_configs['validation']['params'][
|
||||
'transition_dir'] = transition_dir
|
||||
self.dataset_configs['validation']['params'][
|
||||
'dataset_name'] = dataname
|
||||
self.val_datasets[dataname] = instantiate_from_config(
|
||||
self.dataset_configs['validation'])
|
||||
|
||||
# Setup test dataset
|
||||
if 'test' in self.dataset_configs:
|
||||
self.test_datasets = dict()
|
||||
for dataname in self.dataset_weights:
|
||||
data_dir = self.dataset_configs['test']['params']['data_dir']
|
||||
transition_dir = '/'.join([data_dir, 'transitions'])
|
||||
csv_file = f'{dataname}.csv'
|
||||
meta_path = '/'.join([data_dir, csv_file])
|
||||
self.dataset_configs['test']['params']['meta_path'] = meta_path
|
||||
self.dataset_configs['test']['params'][
|
||||
'transition_dir'] = transition_dir
|
||||
self.dataset_configs['test']['params'][
|
||||
'dataset_name'] = dataname
|
||||
self.test_datasets[dataname] = instantiate_from_config(
|
||||
self.dataset_configs['test'])
|
||||
|
||||
if self.wrap:
|
||||
for k in self.datasets:
|
||||
self.datasets[k] = WrappedDataset(self.datasets[k])
|
||||
|
||||
def _train_dataloader(self):
|
||||
is_iterable_dataset = False # NOTE Hand Code
|
||||
if is_iterable_dataset or self.use_worker_init_fn:
|
||||
init_fn = worker_init_fn
|
||||
else:
|
||||
init_fn = None
|
||||
combined_dataset = []
|
||||
sample_weights = []
|
||||
for dataname, dataset in self.train_datasets.items():
|
||||
combined_dataset.append(dataset)
|
||||
sample_weights.append(
|
||||
torch.full((len(dataset), ),
|
||||
self.dataset_weights[dataname] / len(dataset)))
|
||||
combined_dataset = ConcatDataset(combined_dataset)
|
||||
sample_weights = torch.cat(sample_weights)
|
||||
sampler = WeightedRandomSampler(sample_weights,
|
||||
num_samples=len(combined_dataset),
|
||||
replacement=True)
|
||||
loader = DataLoader(combined_dataset,
|
||||
sampler=sampler,
|
||||
batch_size=self.batch_size,
|
||||
num_workers=self.num_workers,
|
||||
worker_init_fn=init_fn,
|
||||
collate_fn=self.collate_fn,
|
||||
drop_last=True
|
||||
)
|
||||
return loader
|
||||
|
||||
def _val_dataloader(self, shuffle=False):
|
||||
is_iterable_dataset = False # NOTE Hand Code
|
||||
if is_iterable_dataset or self.use_worker_init_fn:
|
||||
init_fn = worker_init_fn
|
||||
else:
|
||||
init_fn = None
|
||||
combined_dataset = []
|
||||
sample_weights = []
|
||||
for dataname, dataset in self.val_datasets.items():
|
||||
combined_dataset.append(dataset)
|
||||
sample_weights.append(
|
||||
torch.full((len(dataset), ),
|
||||
self.dataset_weights[dataname] / len(dataset)))
|
||||
combined_dataset = ConcatDataset(combined_dataset)
|
||||
sample_weights = torch.cat(sample_weights)
|
||||
sampler = WeightedRandomSampler(sample_weights,
|
||||
num_samples=len(combined_dataset),
|
||||
replacement=True)
|
||||
loader = DataLoader(combined_dataset,
|
||||
sampler=sampler,
|
||||
batch_size=self.batch_size,
|
||||
num_workers=self.num_workers,
|
||||
worker_init_fn=init_fn,
|
||||
collate_fn=self.collate_fn)
|
||||
return loader
|
||||
|
||||
def _test_dataloader(self, shuffle=False):
|
||||
is_iterable_dataset = False # NOTE Hand Code
|
||||
if is_iterable_dataset or self.use_worker_init_fn:
|
||||
init_fn = worker_init_fn
|
||||
else:
|
||||
init_fn = None
|
||||
combined_dataset = []
|
||||
sample_weights = []
|
||||
for dataname, dataset in self.test_datasets.items():
|
||||
combined_dataset.append(dataset)
|
||||
sample_weights.append(
|
||||
torch.full((len(dataset), ),
|
||||
self.dataset_weights[dataname] / len(dataset)))
|
||||
combined_dataset = ConcatDataset(combined_dataset)
|
||||
sample_weights = torch.cat(sample_weights)
|
||||
sampler = WeightedRandomSampler(sample_weights,
|
||||
num_samples=len(combined_dataset),
|
||||
replacement=True)
|
||||
loader = DataLoader(combined_dataset,
|
||||
sampler=sampler,
|
||||
batch_size=self.batch_size,
|
||||
num_workers=self.num_workers,
|
||||
worker_init_fn=init_fn,
|
||||
collate_fn=self.collate_fn)
|
||||
return loader
|
||||
|
||||
def _predict_dataloader(self, shuffle=False):
    """Build the prediction DataLoader (no sampler, natural dataset order)."""
    needs_worker_init = (isinstance(self.datasets['predict'],
                                    Txt2ImgIterableBaseDataset)
                         or self.use_worker_init_fn)
    init_fn = worker_init_fn if needs_worker_init else None
    return DataLoader(
        self.datasets["predict"],
        batch_size=self.batch_size,
        num_workers=self.num_workers,
        worker_init_fn=init_fn,
        collate_fn=self.collate_fn,
    )
|
||||
|
||||
def __len__(self):
    """Total number of training samples across all training datasets."""
    return sum(len(dataset) for dataset in self.train_datasets.values())
|
||||
191
src/unifolm_wma/utils/diffusion.py
Normal file
191
src/unifolm_wma/utils/diffusion.py
Normal file
@@ -0,0 +1,191 @@
|
||||
import math
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from einops import repeat
|
||||
|
||||
|
||||
def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False):
    """
    Create sinusoidal timestep embeddings.
    :param timesteps: a 1-D Tensor of N indices, one per batch element.
                      These may be fractional.
    :param dim: the dimension of the output.
    :param max_period: controls the minimum frequency of the embeddings.
    :param repeat_only: if True, skip the sinusoidal basis and simply tile
                        each timestep value across the embedding dimension.
    :return: an [N x dim] Tensor of positional embeddings.
    """
    if repeat_only:
        # Broadcast every timestep value across all `dim` channels.
        return timesteps[:, None].expand(-1, dim).clone()

    half = dim // 2
    # Geometric frequency ladder from 1 down to ~1/max_period.
    exponent = torch.arange(half, dtype=torch.float32,
                            device=timesteps.device) / half
    freqs = torch.exp(-math.log(max_period) * exponent)
    angles = timesteps[:, None].float() * freqs[None]
    embedding = torch.cat([torch.cos(angles), torch.sin(angles)], dim=-1)
    if dim % 2:
        # Odd dim: pad one zero channel so the output width is exactly `dim`.
        zero_pad = torch.zeros_like(embedding[:, :1])
        embedding = torch.cat([embedding, zero_pad], dim=-1)
    return embedding
|
||||
|
||||
|
||||
def make_beta_schedule(schedule,
                       n_timestep,
                       linear_start=1e-4,
                       linear_end=2e-2,
                       cosine_s=8e-3):
    """Build a diffusion beta schedule.

    :param schedule: one of "linear", "cosine", "sqrt_linear", "sqrt".
    :param n_timestep: number of diffusion steps.
    :param linear_start, linear_end: endpoints for the linear-family schedules.
    :param cosine_s: small offset for the cosine schedule.
    :return: numpy array of length ``n_timestep`` (float64).
    :raises ValueError: for an unknown schedule name.
    """
    if schedule == "linear":
        # Linear in sqrt(beta): interpolate the square roots, then square.
        sqrt_betas = torch.linspace(linear_start**0.5,
                                    linear_end**0.5,
                                    n_timestep,
                                    dtype=torch.float64)
        betas = sqrt_betas**2
    elif schedule == "cosine":
        # Cosine schedule a la improved DDPM: betas derived from the ratio
        # of consecutive cumulative alphas.
        timesteps = (
            torch.arange(n_timestep + 1, dtype=torch.float64) / n_timestep +
            cosine_s)
        alphas = torch.cos(timesteps / (1 + cosine_s) * np.pi / 2).pow(2)
        alphas = alphas / alphas[0]
        betas = 1 - alphas[1:] / alphas[:-1]
        betas = np.clip(betas, a_min=0, a_max=0.999)
    elif schedule == "sqrt_linear":
        betas = torch.linspace(linear_start,
                               linear_end,
                               n_timestep,
                               dtype=torch.float64)
    elif schedule == "sqrt":
        betas = torch.linspace(linear_start,
                               linear_end,
                               n_timestep,
                               dtype=torch.float64)**0.5
    else:
        raise ValueError(f"schedule '{schedule}' unknown.")
    return betas.numpy()
|
||||
|
||||
|
||||
def make_ddim_timesteps(ddim_discr_method,
                        num_ddim_timesteps,
                        num_ddpm_timesteps,
                        verbose=True):
    """Select the subset of DDPM timesteps used by the DDIM sampler.

    :param ddim_discr_method: 'uniform', 'uniform_trailing' or 'quad'.
    :param num_ddim_timesteps: number of sampler steps to pick.
    :param num_ddpm_timesteps: size of the full DDPM schedule.
    :param verbose: print the chosen steps.
    :return: numpy array of selected timesteps.
    :raises NotImplementedError: for an unknown discretization method.
    """
    if ddim_discr_method == 'uniform':
        stride = num_ddpm_timesteps // num_ddim_timesteps
        ddim_timesteps = np.asarray(list(range(0, num_ddpm_timesteps, stride)))
        # Shift by one so the final alpha (first scale-to-data step) is right.
        steps_out = ddim_timesteps + 1
    elif ddim_discr_method == 'uniform_trailing':
        stride = num_ddpm_timesteps / num_ddim_timesteps
        ddim_timesteps = np.flip(
            np.round(np.arange(num_ddpm_timesteps, 0,
                               -stride))).astype(np.int64)
        steps_out = ddim_timesteps - 1
    elif ddim_discr_method == 'quad':
        # Quadratic spacing: denser steps near t=0.
        ddim_timesteps = ((np.linspace(0, np.sqrt(num_ddpm_timesteps * .8),
                                       num_ddim_timesteps))**2).astype(int)
        steps_out = ddim_timesteps + 1
    else:
        raise NotImplementedError(
            f'There is no ddim discretization method called "{ddim_discr_method}"'
        )

    if verbose:
        print(f'Selected timesteps for ddim sampler: {steps_out}')
    return steps_out
|
||||
|
||||
|
||||
def make_ddim_sampling_parameters(alphacums,
                                  ddim_timesteps,
                                  eta,
                                  verbose=True):
    """Compute (sigmas, alphas, alphas_prev) for DDIM sampling.

    Selects the cumulative alphas at the chosen timesteps and derives the
    variance schedule per Eq. (16) of https://arxiv.org/abs/2010.02502.
    """
    alphas = alphacums[ddim_timesteps]
    # alpha_(t-1): same sequence shifted right, anchored at alphacums[0].
    alphas_prev = np.asarray([alphacums[0]] +
                             alphacums[ddim_timesteps[:-1]].tolist())

    sigmas = eta * np.sqrt(
        (1 - alphas_prev) / (1 - alphas) * (1 - alphas / alphas_prev))
    if verbose:
        print(
            f'Selected alphas for ddim sampler: a_t: {alphas}; a_(t-1): {alphas_prev}'
        )
        print(
            f'For the chosen value of eta, which is {eta}, '
            f'this results in the following sigma_t schedule for ddim sampler {sigmas}'
        )
    return sigmas, alphas, alphas_prev
|
||||
|
||||
|
||||
def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999):
    """
    Create a beta schedule that discretizes the given alpha_t_bar function,
    which defines the cumulative product of (1-beta) over time from t = [0,1].
    :param num_diffusion_timesteps: the number of betas to produce.
    :param alpha_bar: a lambda that takes an argument t from 0 to 1 and
                      produces the cumulative product of (1-beta) up to that
                      part of the diffusion process.
    :param max_beta: the maximum beta to use; use values lower than 1 to
                     prevent singularities.
    """
    # beta_i = 1 - alpha_bar(t_{i+1}) / alpha_bar(t_i), capped at max_beta.
    return np.array([
        min(
            1 - alpha_bar((i + 1) / num_diffusion_timesteps) /
            alpha_bar(i / num_diffusion_timesteps), max_beta)
        for i in range(num_diffusion_timesteps)
    ])
|
||||
|
||||
|
||||
def rescale_zero_terminal_snr(betas):
    """
    Rescales betas to have zero terminal SNR Based on
    https://arxiv.org/pdf/2305.08891.pdf (Algorithm 1)

    Args:
        betas (`numpy.ndarray`):
            the betas that the scheduler is being initialized with.

    Returns:
        `numpy.ndarray`: rescaled betas with zero terminal SNR
    """
    # Work in sqrt(alpha_bar) space.
    sqrt_alpha_bar = np.sqrt(np.cumprod(1.0 - betas, axis=0))

    first = sqrt_alpha_bar[0].copy()
    last = sqrt_alpha_bar[-1].copy()

    # Shift so the last timestep is exactly zero, then rescale so the first
    # timestep keeps its original value.
    sqrt_alpha_bar -= last
    sqrt_alpha_bar *= first / (first - last)

    # Invert: sqrt(alpha_bar) -> alpha_bar -> per-step alphas -> betas.
    alpha_bar = sqrt_alpha_bar**2
    step_alphas = alpha_bar[1:] / alpha_bar[:-1]
    step_alphas = np.concatenate([alpha_bar[0:1], step_alphas])

    return 1 - step_alphas
|
||||
|
||||
|
||||
def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
    """
    Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of
    [Common Diffusion Noise Schedules and Sample Steps are Flawed]
    (https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4
    """
    non_batch_dims = list(range(1, noise_pred_text.ndim))
    std_text = noise_pred_text.std(dim=non_batch_dims, keepdim=True)
    std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)
    # Rescale the results from guidance (fixes overexposure).
    rescaled = noise_cfg * (std_text / std_cfg)
    # Blend with the un-rescaled guidance to avoid "plain looking" images.
    return guidance_rescale * rescaled + (1 - guidance_rescale) * noise_cfg
|
||||
94
src/unifolm_wma/utils/distributions.py
Normal file
94
src/unifolm_wma/utils/distributions.py
Normal file
@@ -0,0 +1,94 @@
|
||||
import torch
|
||||
import numpy as np
|
||||
|
||||
|
||||
class AbstractDistribution:
    """Minimal distribution interface: subclasses provide `sample` and `mode`."""

    def sample(self):
        raise NotImplementedError()

    def mode(self):
        raise NotImplementedError()
|
||||
|
||||
|
||||
class DiracDistribution(AbstractDistribution):
    """Point-mass distribution: sampling always yields the stored value."""

    def __init__(self, value):
        self.value = value

    def sample(self):
        return self.value

    def mode(self):
        return self.value
|
||||
|
||||
|
||||
class DiagonalGaussianDistribution(object):
    """Diagonal Gaussian over latents, parameterized by a tensor whose
    channel dimension concatenates [mean | logvar]."""

    def __init__(self, parameters, deterministic=False):
        self.parameters = parameters
        self.mean, self.logvar = torch.chunk(parameters, 2, dim=1)
        # Clamp logvar for numerical stability of exp().
        self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
        self.deterministic = deterministic
        self.std = torch.exp(0.5 * self.logvar)
        self.var = torch.exp(self.logvar)
        if self.deterministic:
            # Degenerate (Dirac) case: zero variance, sampling returns mean.
            self.var = self.std = torch.zeros_like(
                self.mean).to(device=self.parameters.device)

    def sample(self, noise=None):
        """Reparameterized sample: mean + std * noise."""
        if noise is None:
            noise = torch.randn(self.mean.shape)
        return self.mean + self.std * noise.to(device=self.parameters.device)

    def kl(self, other=None):
        """KL to the standard normal (other=None) or to another diagonal Gaussian."""
        if self.deterministic:
            return torch.Tensor([0.])
        if other is None:
            return 0.5 * torch.sum(
                torch.pow(self.mean, 2) + self.var - 1.0 - self.logvar,
                dim=[1, 2, 3])
        return 0.5 * torch.sum(
            torch.pow(self.mean - other.mean, 2) / other.var +
            self.var / other.var - 1.0 - self.logvar + other.logvar,
            dim=[1, 2, 3])

    def nll(self, sample, dims=[1, 2, 3]):
        """Negative log-likelihood of `sample` under this Gaussian."""
        if self.deterministic:
            return torch.Tensor([0.])
        logtwopi = np.log(2.0 * np.pi)
        return 0.5 * torch.sum(
            logtwopi + self.logvar +
            torch.pow(sample - self.mean, 2) / self.var,
            dim=dims)

    def mode(self):
        """Mode of a Gaussian is its mean."""
        return self.mean
|
||||
|
||||
|
||||
def normal_kl(mean1, logvar1, mean2, logvar2):
    """
    source: https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/losses.py#L12
    Compute the KL divergence between two gaussians.
    Shapes are automatically broadcasted, so batches can be compared to
    scalars, among other use cases.
    """
    # Find a reference tensor to anchor dtype/device for scalar promotion.
    ref = next((obj for obj in (mean1, logvar1, mean2, logvar2)
                if isinstance(obj, torch.Tensor)), None)
    assert ref is not None, "at least one argument must be a Tensor"

    # Force variances to be Tensors. Broadcasting helps convert scalars to
    # Tensors, but it does not work for torch.exp().
    logvar1, logvar2 = (
        lv if isinstance(lv, torch.Tensor) else torch.tensor(lv).to(ref)
        for lv in (logvar1, logvar2))

    return 0.5 * (-1.0 + logvar2 - logvar1 + torch.exp(logvar1 - logvar2) +
                  ((mean1 - mean2)**2) * torch.exp(-logvar2))
|
||||
84
src/unifolm_wma/utils/ema.py
Normal file
84
src/unifolm_wma/utils/ema.py
Normal file
@@ -0,0 +1,84 @@
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
|
||||
class LitEma(nn.Module):
    """Exponential-moving-average tracker for a model's trainable parameters.

    Shadow copies of every trainable parameter are stored as buffers on this
    module (so they travel with checkpoints/`.to(device)` calls) and updated
    towards the live parameters on each `forward(model)` call.
    """

    def __init__(self, model, decay=0.9999, use_num_upates=True):
        """
        Args:
            model: module whose trainable parameters are tracked.
            decay: EMA decay factor in [0, 1].
            use_num_upates: if True, warm up the effective decay as
                (1 + n) / (10 + n) during early updates.
        """
        super().__init__()
        if decay < 0.0 or decay > 1.0:
            raise ValueError('Decay must be between 0 and 1')

        # Maps parameter name -> sanitized buffer name.
        self.m_name2s_name = {}
        self.register_buffer('decay', torch.tensor(decay, dtype=torch.float32))
        # num_updates == -1 disables the warm-up schedule entirely.
        self.register_buffer(
            'num_updates',
            torch.tensor(0, dtype=torch.int)
            if use_num_upates else torch.tensor(-1, dtype=torch.int))

        for name, p in model.named_parameters():
            if p.requires_grad:
                # Remove dots: the '.' character is not allowed in buffer names.
                s_name = name.replace('.', '')
                self.m_name2s_name.update({name: s_name})
                self.register_buffer(s_name, p.clone().detach().data)

        # Scratch storage used by store()/restore().
        self.collected_params = []

    def forward(self, model):
        """Update every shadow buffer one EMA step towards `model`'s parameters."""
        decay = self.decay

        if self.num_updates >= 0:
            self.num_updates += 1
            # Warm-up: effective decay grows towards self.decay as updates
            # accumulate, so early EMA values track the model more closely.
            decay = min(self.decay,
                        (1 + self.num_updates) / (10 + self.num_updates))

        one_minus_decay = 1.0 - decay

        with torch.no_grad():
            m_param = dict(model.named_parameters())
            shadow_params = dict(self.named_buffers())

            for key in m_param:
                if m_param[key].requires_grad:
                    sname = self.m_name2s_name[key]
                    # Keep shadow dtype in sync with the live parameter.
                    shadow_params[sname] = shadow_params[sname].type_as(
                        m_param[key])
                    # shadow <- shadow - (1 - decay) * (shadow - param)
                    shadow_params[sname].sub_(
                        one_minus_decay *
                        (shadow_params[sname] - m_param[key]))
                else:
                    assert not key in self.m_name2s_name

    def copy_to(self, model):
        """Overwrite `model`'s trainable parameters with the EMA shadow values."""
        m_param = dict(model.named_parameters())
        shadow_params = dict(self.named_buffers())
        for key in m_param:
            if m_param[key].requires_grad:
                m_param[key].data.copy_(
                    shadow_params[self.m_name2s_name[key]].data)
            else:
                assert not key in self.m_name2s_name

    def store(self, parameters):
        """
        Save the current parameters for restoring later.
        Args:
            parameters: Iterable of `torch.nn.Parameter`; the parameters to be
                temporarily stored.
        """
        self.collected_params = [param.clone() for param in parameters]

    def restore(self, parameters):
        """
        Restore the parameters stored with the `store` method.
        Useful to validate the model with EMA parameters without affecting the
        original optimization process. Store the parameters before the
        `copy_to` method. After validation (or model saving), use this to
        restore the former parameters.
        Args:
            parameters: Iterable of `torch.nn.Parameter`; the parameters to be
                updated with the stored parameters.
        """
        for c_param, param in zip(self.collected_params, parameters):
            param.data.copy_(c_param.data)
|
||||
66
src/unifolm_wma/utils/nn_utils.py
Normal file
66
src/unifolm_wma/utils/nn_utils.py
Normal file
@@ -0,0 +1,66 @@
|
||||
"""
|
||||
nn_utils.py
|
||||
|
||||
Utility functions and PyTorch submodule definitions.
|
||||
"""
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
|
||||
# === Definitions for Various Projection Modules, with Signature :: [..., in_dim] --> [..., out_dim] ===
|
||||
class LinearProjector(nn.Module):
    """Single affine projection from vision features to the LLM width."""

    def __init__(self, vision_dim: int, llm_dim: int) -> None:
        super().__init__()
        self.projector = nn.Linear(vision_dim, llm_dim, bias=True)

    def forward(self, img_patches: torch.Tensor) -> torch.Tensor:
        """Project patch features [..., vision_dim] -> [..., llm_dim]."""
        return self.projector(img_patches)
|
||||
|
||||
|
||||
class MLPProjector(nn.Module):
    """Two-layer GELU MLP projecting vision features to the LLM width.

    Raises:
        ValueError: for an unsupported ``mlp_type``.
    """

    def __init__(self,
                 vision_dim: int,
                 llm_dim: int,
                 mlp_type: str = "gelu-mlp") -> None:
        super().__init__()
        if mlp_type != "gelu-mlp":
            raise ValueError(
                f"Projector with `{mlp_type = }` is not supported!")
        self.projector = nn.Sequential(
            nn.Linear(vision_dim, llm_dim, bias=True),
            nn.GELU(),
            nn.Linear(llm_dim, llm_dim, bias=True),
        )

    def forward(self, img_patches: torch.Tensor) -> torch.Tensor:
        """Project patch features [..., vision_dim] -> [..., llm_dim]."""
        return self.projector(img_patches)
|
||||
|
||||
|
||||
class FusedMLPProjector(nn.Module):
    """Three-layer GELU MLP for fused (concatenated) vision features.

    Expands to 4x the fused input width first, then maps down to the LLM width.
    Raises:
        ValueError: for an unsupported ``mlp_type``.
    """

    def __init__(self,
                 fused_vision_dim: int,
                 llm_dim: int,
                 mlp_type: str = "fused-gelu-mlp") -> None:
        super().__init__()
        # Widen before projecting down: 4x expansion of the fused features.
        self.initial_projection_dim = fused_vision_dim * 4
        if mlp_type != "fused-gelu-mlp":
            raise ValueError(
                f"Fused Projector with `{mlp_type = }` is not supported!")
        self.projector = nn.Sequential(
            nn.Linear(fused_vision_dim,
                      self.initial_projection_dim,
                      bias=True),
            nn.GELU(),
            nn.Linear(self.initial_projection_dim, llm_dim, bias=True),
            nn.GELU(),
            nn.Linear(llm_dim, llm_dim, bias=True),
        )

    def forward(self, fused_img_patches: torch.Tensor) -> torch.Tensor:
        """Project fused patch features [..., fused_vision_dim] -> [..., llm_dim]."""
        return self.projector(fused_img_patches)
|
||||
147
src/unifolm_wma/utils/projector.py
Normal file
147
src/unifolm_wma/utils/projector.py
Normal file
@@ -0,0 +1,147 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
|
||||
class LinearProjector(nn.Module):
    """Plain affine projection from input_dim to output_dim."""

    def __init__(self, input_dim: int, output_dim: int) -> None:
        super().__init__()
        self.projector = nn.Linear(input_dim, output_dim, bias=True)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the projection: [..., input_dim] -> [..., output_dim]."""
        return self.projector(x)
|
||||
|
||||
|
||||
class MLPProjector(nn.Module):
    """Two-layer MLP projector from ``input_dim`` to ``output_dim``.

    Args:
        input_dim: dimensionality of incoming features.
        output_dim: dimensionality of projected features.
        mlp_type: "gelu-mlp" (tanh-approximate GELU) or "silu-mlp".

    Raises:
        ValueError: for an unsupported ``mlp_type``.
    """

    def __init__(
            self,
            input_dim: int,
            output_dim: int,
            mlp_type: str = "gelu-mlp") -> None:
        super().__init__()
        # BUG FIX: the original body referenced undefined names
        # `vision_dim` / `llm_dim`; use the declared parameters instead.
        if mlp_type == "gelu-mlp":
            self.projector = nn.Sequential(
                nn.Linear(input_dim, output_dim, bias=True),
                nn.GELU(approximate='tanh'),
                nn.Linear(output_dim, output_dim, bias=True),
            )
        elif mlp_type == "silu-mlp":
            self.projector = nn.Sequential(
                nn.Linear(input_dim, output_dim, bias=True),
                nn.SiLU(),
                nn.Linear(output_dim, output_dim, bias=True),
            )
        else:
            raise ValueError(f"Projector with `{mlp_type = }` is not supported!")

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the MLP: [..., input_dim] -> [..., output_dim]."""
        return self.projector(x)
|
||||
|
||||
|
||||
def reshape_tensor(x, heads):
    """Split the channel dim into heads: (b, n, heads*d) -> (b, heads, n, d).

    BUG FIX: this helper was called by PerceiverAttention.forward but never
    defined anywhere in this module, which made forward raise NameError.
    """
    bs, length, _ = x.shape
    x = x.view(bs, length, heads, -1)
    return x.transpose(1, 2)


class PerceiverAttention(nn.Module):
    """Perceiver-style cross-attention: learned latents attend over image
    features concatenated with the latents themselves."""

    def __init__(self, *, dim, dim_head=64, heads=8):
        super().__init__()
        self.scale = dim_head**-0.5
        self.dim_head = dim_head
        self.heads = heads
        inner_dim = dim_head * heads

        self.norm1 = nn.LayerNorm(dim)
        self.norm2 = nn.LayerNorm(dim)

        self.to_q = nn.Linear(dim, inner_dim, bias=False)
        self.to_kv = nn.Linear(dim, inner_dim * 2, bias=False)
        self.to_out = nn.Linear(inner_dim, dim, bias=False)

    def forward(self, x, latents):
        """
        Args:
            x (torch.Tensor): image features
                shape (b, n1, D)
            latents (torch.Tensor): latent features
                shape (b, n2, D)
        Returns:
            torch.Tensor of shape (b, n2, D).
        """
        x = self.norm1(x)
        latents = self.norm2(latents)

        b, l, _ = latents.shape

        q = self.to_q(latents)
        # Keys/values attend over image features AND the latents themselves.
        kv_input = torch.cat((x, latents), dim=-2)
        k, v = self.to_kv(kv_input).chunk(2, dim=-1)

        q = reshape_tensor(q, self.heads)
        k = reshape_tensor(k, self.heads)
        v = reshape_tensor(v, self.heads)

        # BUG FIX: the original used math.sqrt but `math` is never imported
        # in this file; dim_head ** -0.25 is the same value without the import.
        scale = self.dim_head ** -0.25
        # Scale q and k separately: more stable with f16 than dividing afterwards.
        weight = (q * scale) @ (k * scale).transpose(-2, -1)
        weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)
        out = weight @ v

        out = out.permute(0, 2, 1, 3).reshape(b, l, -1)

        return self.to_out(out)
|
||||
|
||||
|
||||
def FeedForward(dim, mult=4, ffd_type="gelu-ffd"):
    """Build a pre-norm feed-forward block: LayerNorm -> expand -> act -> project.

    Args:
        dim: input/output width.
        mult: hidden-layer expansion factor.
        ffd_type: "gelu-ffd" (tanh-approximate GELU) or "silu-ffd".

    Returns:
        nn.Sequential implementing the block.

    Raises:
        ValueError: for an unsupported ``ffd_type``.
    """
    inner_dim = int(dim * mult)
    # BUG FIX: the original used assignment (`=`) instead of equality (`==`)
    # in both conditions (a SyntaxError), and the error message referenced
    # the undefined name `mlp_type` instead of `ffd_type`.
    if ffd_type == "gelu-ffd":
        return nn.Sequential(
            nn.LayerNorm(dim),
            nn.Linear(dim, inner_dim, bias=False),
            nn.GELU(approximate='tanh'),
            nn.Linear(inner_dim, dim, bias=False),
        )
    elif ffd_type == "silu-ffd":
        return nn.Sequential(
            nn.LayerNorm(dim),
            nn.Linear(dim, inner_dim, bias=False),
            nn.SiLU(),
            nn.Linear(inner_dim, dim, bias=False),
        )
    else:
        raise ValueError(f"Projector with `{ffd_type = }` is not supported!")
|
||||
|
||||
|
||||
class TokenProjector(nn.Module):
    """Perceiver-style resampler: maps a variable number of input tokens onto
    a fixed set of learned query tokens via cross-attention layers.

    Args:
        dim: width of the latent queries and attention layers.
        depth: number of (attention, feed-forward) layers.
        dim_head, heads: per-head width and head count for PerceiverAttention.
        num_queries: number of learned query tokens.
        output_dim: width of the projected output tokens.
        ff_mult: feed-forward expansion factor.
        chunck_size: if set, allocate one full query set per chunk
            (num_queries * chunck_size latents total).
    """

    def __init__(
            self,
            dim=1024,
            depth=1,
            dim_head=64,
            heads=16,
            num_queries=16,
            output_dim=1024,
            ff_mult=4,
            chunck_size=None,
    ):
        super().__init__()
        self.num_queries = num_queries
        self.chunck_size = chunck_size
        if chunck_size is not None:
            num_queries = num_queries * chunck_size

        # Scaled init keeps the latent magnitudes roughly unit-variance.
        self.latents = nn.Parameter(torch.randn(1, num_queries, dim) / dim**0.5)
        self.proj_out = nn.Linear(dim, output_dim)
        # NOTE(review): norm_out is LayerNorm(dim) but is applied AFTER
        # proj_out (which outputs output_dim); this only works when
        # output_dim == dim — confirm intended order.
        self.norm_out = nn.LayerNorm(dim)

        self.layers = nn.ModuleList([])
        for _ in range(depth):
            self.layers.append(
                nn.ModuleList([
                    PerceiverAttention(dim=dim, dim_head=dim_head, heads=heads),
                    FeedForward(dim=dim, mult=ff_mult),
                ]))

    # BUG FIX: original was `def forward(self, x)` (missing colon, a
    # SyntaxError) and never returned the computed latents.
    def forward(self, x):
        """Resample input tokens x (b, n, dim) onto the learned queries.

        Returns:
            torch.Tensor of shape (b, num_queries[, * chunck_size], output_dim).
        """
        latents = self.latents.repeat(x.size(0), 1, 1)
        for attn, ff in self.layers:
            latents = attn(x, latents) + latents
            latents = ff(latents) + latents
        latents = self.proj_out(latents)
        latents = self.norm_out(latents)
        return latents
|
||||
258
src/unifolm_wma/utils/save_video.py
Normal file
258
src/unifolm_wma/utils/save_video.py
Normal file
@@ -0,0 +1,258 @@
|
||||
import os
|
||||
import torch
|
||||
import numpy as np
|
||||
import torchvision
|
||||
|
||||
from tqdm import tqdm
|
||||
from PIL import Image
|
||||
from einops import rearrange
|
||||
from torch import Tensor
|
||||
from torchvision.utils import make_grid
|
||||
from torchvision.transforms.functional import to_tensor
|
||||
|
||||
|
||||
def frames_to_mp4(frame_dir, output_path, fps):
    """Assemble all image frames in `frame_dir` (sorted by filename) into an
    h264 .mp4 at `output_path`."""

    def _load_frames(d: os.PathLike, num_frames: int):
        names = sorted(os.listdir(d))
        if num_frames:
            names = names[:num_frames]
        frames = [to_tensor(Image.open(os.path.join(d, n))) for n in names]
        return torch.stack(frames)

    videos = _load_frames(frame_dir, num_frames=None)
    # to_tensor gives float in [0,1]; write_video expects uint8 THWC.
    videos = videos.mul(255).to(torch.uint8).permute(0, 2, 3, 1)
    torchvision.io.write_video(output_path,
                               videos,
                               fps=fps,
                               video_codec='h264',
                               options={'crf': '10'})
|
||||
|
||||
|
||||
def tensor_to_mp4(video, savepath, fps, rescale=True, nrow=None):
    """
    video: torch.Tensor, b,c,t,h,w, 0-1
    if -1~1, enable rescale=True
    """
    batch = video.shape[0]
    video = video.permute(2, 0, 1, 3, 4)  # (t, b, c, h, w)
    if nrow is None:
        nrow = int(np.sqrt(batch))
    frame_grids = [
        torchvision.utils.make_grid(framesheet, nrow=nrow, padding=0)
        for framesheet in video
    ]
    grid = torch.stack(frame_grids, dim=0)
    grid = torch.clamp(grid.float(), -1., 1.)
    if rescale:
        grid = (grid + 1.0) / 2.0  # [-1,1] -> [0,1]
    # write_video expects uint8 frames in THWC layout.
    grid = (grid * 255).to(torch.uint8).permute(0, 2, 3, 1)
    torchvision.io.write_video(savepath,
                               grid,
                               fps=fps,
                               video_codec='h264',
                               options={'crf': '10'})
|
||||
|
||||
|
||||
def tensor2videogrids(video, root, filename, fps, rescale=True, clamp=True):
    """Save a batch of videos (b,c,t,h,w) as one square grid .mp4 under `root`."""
    assert (video.dim() == 5)
    assert (isinstance(video, torch.Tensor))

    video = video.detach().cpu()
    if clamp:
        video = torch.clamp(video, -1., 1.)
    batch = video.shape[0]
    video = video.permute(2, 0, 1, 3, 4)  # (t, b, c, h, w)
    frame_grids = [
        torchvision.utils.make_grid(framesheet, nrow=int(np.sqrt(batch)))
        for framesheet in video
    ]
    grid = torch.stack(frame_grids, dim=0)
    if rescale:
        grid = (grid + 1.0) / 2.0  # [-1,1] -> [0,1]
    # write_video expects uint8 frames in THWC layout.
    grid = (grid * 255).to(torch.uint8).permute(0, 2, 3, 1)
    torchvision.io.write_video(os.path.join(root, filename),
                               grid,
                               fps=fps,
                               video_codec='h264',
                               options={'crf': '10'})
|
||||
|
||||
|
||||
def log_local(batch_logs, save_dir, filename, save_fps=10, rescale=True):
    """Save a dict of logged batch outputs to disk.

    Per-key handling:
      * list[str]  -> one "<key>-<filename>.txt" with one caption per line,
      * 5-D tensor -> "<key>-<filename>.mp4" video grid (grayscale/rgb only),
      * 4-D tensor -> "<key>-<filename>.jpg" image grid (grayscale/rgb only),
      * anything else is silently skipped.

    Args:
        batch_logs: mapping of log-key -> captions list or image/video tensor.
        save_dir: destination directory.
        filename: suffix appended to each key when naming files.
        save_fps: frame rate for saved videos.
        rescale: if True, map tensor values from [-1, 1] to [0, 1] before saving.
    """
    if batch_logs is None:
        return None
    """ save images and videos from images dict """

    def save_img_grid(grid, path, rescale):
        # Convert a CHW float grid to an HWC uint8 image and write it.
        if rescale:
            grid = (grid + 1.0) / 2.0
        grid = grid.transpose(0, 1).transpose(1, 2).squeeze(-1)
        grid = grid.numpy()
        grid = (grid * 255).astype(np.uint8)
        os.makedirs(os.path.split(path)[0], exist_ok=True)
        Image.fromarray(grid).save(path)

    for key in batch_logs:
        value = batch_logs[key]
        if isinstance(value, list) and isinstance(value[0], str):
            # A batch of captions
            path = os.path.join(save_dir, "%s-%s.txt" % (key, filename))
            with open(path, 'w') as f:
                for i, txt in enumerate(value):
                    f.write(f'idx={i}, txt={txt}\n')
                f.close()
        elif isinstance(value, torch.Tensor) and value.dim() == 5:
            # Save video grids
            video = value
            # Only save grayscale or rgb mode
            if video.shape[1] != 1 and video.shape[1] != 3:
                continue
            n = video.shape[0]
            # (b, c, t, h, w) -> (t, b, c, h, w): one grid per frame index.
            video = video.permute(2, 0, 1, 3, 4)
            frame_grids = [
                torchvision.utils.make_grid(framesheet, nrow=int(1), padding=0)
                for framesheet in video
            ]
            grid = torch.stack(frame_grids,
                               dim=0)
            if rescale:
                grid = (grid + 1.0) / 2.0
            # write_video expects uint8 THWC frames.
            grid = (grid * 255).to(torch.uint8).permute(0, 2, 3, 1)
            path = os.path.join(save_dir, "%s-%s.mp4" % (key, filename))
            torchvision.io.write_video(path,
                                       grid,
                                       fps=save_fps,
                                       video_codec='h264',
                                       options={'crf': '10'})

            # Save frame sheet
            img = value
            video_frames = rearrange(img, 'b c t h w -> (b t) c h w')
            t = img.shape[2]
            grid = torchvision.utils.make_grid(video_frames, nrow=t, padding=0)
            path = os.path.join(save_dir, "%s-%s.jpg" % (key, filename))
            # Save_img_grid(grid, path, rescale)
        elif isinstance(value, torch.Tensor) and value.dim() == 4:
            # Save image grids
            img = value
            # Only save grayscale or rgb mode
            if img.shape[1] != 1 and img.shape[1] != 3:
                continue
            n = img.shape[0]
            grid = torchvision.utils.make_grid(img, nrow=1, padding=0)
            path = os.path.join(save_dir, "%s-%s.jpg" % (key, filename))
            save_img_grid(grid, path, rescale)
        else:
            pass
|
||||
|
||||
|
||||
def prepare_to_log(batch_logs, max_images=100000, clamp=True):
    """Trim each logged entry to at most `max_images` items and move tensors
    to CPU (detached, optionally clamped to [-1, 1])."""
    if batch_logs is None:
        return None
    for key in batch_logs:
        entry = batch_logs[key]
        count = entry.shape[0] if hasattr(entry, 'shape') else len(entry)
        batch_logs[key] = entry[:min(count, max_images)]
        # batch_logs holds images <batched tensor> & instruction <text list>.
        if isinstance(batch_logs[key], torch.Tensor):
            batch_logs[key] = batch_logs[key].detach().cpu()
            if clamp:
                try:
                    batch_logs[key] = torch.clamp(batch_logs[key].float(),
                                                  -1., 1.)
                except RuntimeError:
                    print("clamp_scalar_cpu not implemented for Half")
    return batch_logs
|
||||
|
||||
|
||||
# ----------------------------------------------------------------------------------------------
|
||||
|
||||
|
||||
def fill_with_black_squares(video, desired_len: int) -> Tensor:
    """Pad `video` (T, C, H, W) with all-zero (black) frames up to `desired_len`.

    Returns the input unchanged when it is already long enough.
    """
    deficit = desired_len - len(video)
    if deficit <= 0:
        return video
    padding = torch.zeros_like(video[0]).unsqueeze(0).repeat(deficit, 1, 1, 1)
    return torch.cat([video, padding], dim=0)
|
||||
|
||||
|
||||
# ----------------------------------------------------------------------------------------------
|
||||
def load_num_videos(data_path, num_videos):
    """Load at most `num_videos` videos.

    Accepts either a path to an .npz file (array stored under 'arr_0') or an
    already-loaded ndarray; videos are laid out NTHWC.
    """
    if isinstance(data_path, str):
        videos = np.load(data_path)['arr_0']  # NTHWC
    elif isinstance(data_path, np.ndarray):
        videos = data_path
    else:
        raise Exception

    if num_videos is None:
        return videos
    return videos[:num_videos, :, :, :, :]
|
||||
|
||||
|
||||
def npz_to_video_grid(data_path,
                      out_path,
                      num_frames,
                      fps,
                      num_videos=None,
                      nrow=None,
                      verbose=True):
    """Render a batch of videos (NTHWC, from an .npz path or ndarray) into a
    single grid .mp4 at `out_path`.

    Shorter videos are padded with black frames to `num_frames`; `nrow`
    controls grid columns (default: ceil(sqrt(n))); `verbose` adds tqdm bars.
    """
    if isinstance(data_path, str):
        videos = load_num_videos(data_path, num_videos)
    elif isinstance(data_path, np.ndarray):
        videos = data_path
    else:
        raise Exception
    n, t, h, w, c = videos.shape
    videos_th = []
    # Convert each NTHWC numpy video to a (T, C, H, W) float tensor.
    for i in range(n):
        video = videos[i, :, :, :, :]
        images = [video[j, :, :, :] for j in range(t)]
        images = [to_tensor(img) for img in images]
        video = torch.stack(images)
        videos_th.append(video)
    # Pad every video to the same length with black frames.
    if verbose:
        videos = [
            fill_with_black_squares(v, num_frames)
            for v in tqdm(videos_th, desc='Adding empty frames')
        ]
    else:
        videos = [fill_with_black_squares(v, num_frames)
                  for v in videos_th]  # NTCHW

    # (n, t, c, h, w) -> (t, n, c, h, w): one grid per frame index.
    frame_grids = torch.stack(videos).permute(1, 0, 2, 3, 4)
    if nrow is None:
        nrow = int(np.ceil(np.sqrt(n)))
    if verbose:
        frame_grids = [
            make_grid(fs, nrow=nrow)
            for fs in tqdm(frame_grids, desc='Making grids')
        ]
    else:
        frame_grids = [make_grid(fs, nrow=nrow) for fs in frame_grids]

    if os.path.dirname(out_path) != "":
        os.makedirs(os.path.dirname(out_path), exist_ok=True)
    # write_video expects uint8 frames in THWC layout.
    frame_grids = (torch.stack(frame_grids) * 255).to(torch.uint8).permute(
        0, 2, 3, 1)
    torchvision.io.write_video(out_path,
                               frame_grids,
                               fps=fps,
                               video_codec='h264',
                               options={'crf': '10'})
|
||||
231
src/unifolm_wma/utils/train.py
Normal file
231
src/unifolm_wma/utils/train.py
Normal file
@@ -0,0 +1,231 @@
|
||||
import os
|
||||
import logging
|
||||
|
||||
mainlogger = logging.getLogger('mainlogger')
|
||||
|
||||
import torch
|
||||
import pandas as pd
|
||||
|
||||
from omegaconf import OmegaConf
|
||||
from collections import OrderedDict
|
||||
|
||||
|
||||
def init_workspace(name, logdir, model_config, lightning_config, rank=0):
    """Create the experiment directory tree and, on rank 0, dump the configs.

    Returns:
        (workdir, ckptdir, cfgdir, loginfo) paths under ``logdir/name``.
    """
    workdir = os.path.join(logdir, name)
    ckptdir = os.path.join(workdir, "checkpoints")
    cfgdir = os.path.join(workdir, "configs")
    loginfo = os.path.join(workdir, "loginfo")

    # All ranks create the directories, to avoid a missing-directory error
    # when rank 0 happens to be slower than the others.
    for directory in (workdir, ckptdir, cfgdir, loginfo):
        os.makedirs(directory, exist_ok=True)

    if rank == 0:
        if "callbacks" in lightning_config and 'metrics_over_trainsteps_checkpoint' in lightning_config.callbacks:
            os.makedirs(os.path.join(ckptdir, 'trainstep_checkpoints'),
                        exist_ok=True)
        OmegaConf.save(model_config, os.path.join(cfgdir, "model.yaml"))
        OmegaConf.save(OmegaConf.create({"lightning": lightning_config}),
                       os.path.join(cfgdir, "lightning.yaml"))
    return workdir, ckptdir, cfgdir, loginfo
|
||||
|
||||
|
||||
def check_config_attribute(config, name):
    """Return ``config.<name>`` when `name` is contained in `config`, else None."""
    return getattr(config, name) if name in config else None
|
||||
|
||||
|
||||
def get_trainer_callbacks(lightning_config, config, logdir, ckptdir, logger):
    """Assemble the merged callback configuration for the Lightning Trainer.

    Starts from default callbacks (checkpointing, image logging, LR
    monitoring, CUDA stats) and merges any user-supplied
    ``lightning_config.callbacks`` on top, so user values win.

    Args:
        lightning_config: OmegaConf Lightning section (may contain ``callbacks``).
        config: Full experiment config; ``config.model.params.monitor`` enables
            metric-based checkpointing.
        logdir: Directory for the image logger output.
        ckptdir: Directory for checkpoint files.
        logger: Unused here; kept for interface compatibility.

    Returns:
        OmegaConf mapping of callback configs.
    """
    default_callbacks_cfg = {
        "model_checkpoint": {
            "target": "pytorch_lightning.callbacks.ModelCheckpoint",
            "params": {
                "dirpath": ckptdir,
                "filename": "{epoch}",
                "verbose": True,
                "save_last": False,
            }
        },
        "batch_logger": {
            "target": "unifolm_wma.utils.callbacks.ImageLogger",
            "params": {
                "save_dir": logdir,
                "batch_frequency": 1000,
                "max_images": 4,
                "clamp": True,
            }
        },
        "learning_rate_logger": {
            "target": "pytorch_lightning.callbacks.LearningRateMonitor",
            "params": {
                "logging_interval": "step",
                "log_momentum": False
            }
        },
        "cuda_callback": {
            "target": "unifolm_wma.utils.callbacks.CUDACallback",
        },
    }

    # Optional setting for saving checkpoints: monitor a validation metric.
    monitor_metric = check_config_attribute(config.model.params, "monitor")
    if monitor_metric is not None:
        mainlogger.info(f"Monitoring {monitor_metric} as checkpoint metric.")
        ckpt_params = default_callbacks_cfg["model_checkpoint"]["params"]
        ckpt_params["monitor"] = monitor_metric
        ckpt_params["save_top_k"] = 3
        ckpt_params["mode"] = "min"

    # Fix: the original accessed `lightning_config.callbacks` unconditionally,
    # which raises when the config has no `callbacks` section (a case the
    # function itself handles further down). Guard the membership test.
    if ("callbacks" in lightning_config
            and 'metrics_over_trainsteps_checkpoint' in lightning_config.callbacks):
        mainlogger.info(
            'Caution: Saving checkpoints every n train steps without deleting. This might require some free space.'
        )
        default_metrics_over_trainsteps_ckpt_dict = {
            'metrics_over_trainsteps_checkpoint': {
                "target": 'pytorch_lightning.callbacks.ModelCheckpoint',
                'params': {
                    "dirpath": os.path.join(ckptdir, 'trainstep_checkpoints'),
                    "filename": "{epoch}-{step}",
                    "verbose": True,
                    'save_top_k': -1,
                    'every_n_train_steps': 10000,
                    'save_weights_only': True
                }
            }
        }
        default_callbacks_cfg.update(default_metrics_over_trainsteps_ckpt_dict)

    if "callbacks" in lightning_config:
        callbacks_cfg = lightning_config.callbacks
    else:
        callbacks_cfg = OmegaConf.create()
    # User-provided callback settings take precedence over the defaults.
    callbacks_cfg = OmegaConf.merge(default_callbacks_cfg, callbacks_cfg)

    return callbacks_cfg
|
||||
|
||||
|
||||
def get_trainer_logger(lightning_config, logdir, on_debug):
    """Build the merged experiment-logger config (TensorBoard by default).

    A ``testtube`` (CSVLogger) default is defined but not selected; any
    user-provided ``lightning_config.logger`` is merged on top of the
    TensorBoard default. *on_debug* is accepted for interface compatibility
    but not used here.
    """
    default_logger_cfgs = {
        "tensorboard": {
            "target": "pytorch_lightning.loggers.TensorBoardLogger",
            "params": {
                "save_dir": logdir,
                "name": "tensorboard",
            }
        },
        "testtube": {
            "target": "pytorch_lightning.loggers.CSVLogger",
            "params": {
                "name": "testtube",
                "save_dir": logdir,
            }
        },
    }
    os.makedirs(os.path.join(logdir, "tensorboard"), exist_ok=True)
    default_logger_cfg = default_logger_cfgs["tensorboard"]

    if "logger" in lightning_config:
        user_logger_cfg = lightning_config.logger
    else:
        user_logger_cfg = OmegaConf.create()
    return OmegaConf.merge(default_logger_cfg, user_logger_cfg)
|
||||
|
||||
|
||||
def get_trainer_strategy(lightning_config):
    """Return the Trainer strategy config.

    A user-provided ``lightning_config.strategy`` replaces the default
    wholesale (it is returned as-is, not merged); otherwise the default
    DDP-sharded strategy is returned.
    """
    if "strategy" in lightning_config:
        # User config overrides the default entirely.
        return lightning_config.strategy

    default_strategy_dict = {
        "target": "pytorch_lightning.strategies.DDPShardedStrategy"
    }
    return OmegaConf.merge(default_strategy_dict, OmegaConf.create())
|
||||
|
||||
|
||||
def load_checkpoints(model, model_cfg):
    """Initialise *model* from ``model_cfg.pretrained_checkpoint`` if configured.

    Supported checkpoint layouts:
      * Lightning checkpoints carrying a top-level ``state_dict``;
      * DeepSpeed checkpoints carrying a ``module`` dict whose keys have a
        16-character wrapper prefix (presumably ``_forward_module.`` -- TODO
        confirm against the DeepSpeed version in use) that is stripped;
      * a plain state dict, used as a last-resort fallback.

    Loading is non-strict, so missing/unexpected keys are tolerated.

    Returns:
        The model (weights loaded in place).
    """
    if check_config_attribute(model_cfg, "pretrained_checkpoint"):
        pretrained_ckpt = model_cfg.pretrained_checkpoint
        assert os.path.exists(
            pretrained_ckpt
        ), "Error: Pre-trained checkpoint NOT found at:%s" % pretrained_ckpt
        mainlogger.info(">>> Load weights from pretrained checkpoint")

        pl_sd = torch.load(pretrained_ckpt, map_location="cpu")
        try:
            if 'state_dict' in pl_sd:
                model.load_state_dict(pl_sd["state_dict"], strict=False)
                mainlogger.info(
                    ">>> Loaded weights from pretrained checkpoint: %s" %
                    pretrained_ckpt)
            else:
                # DeepSpeed layout: strip the module wrapper prefix from keys.
                new_pl_sd = OrderedDict(
                    (key[16:], value)
                    for key, value in pl_sd['module'].items())
                model.load_state_dict(new_pl_sd, strict=False)
        except Exception:
            # Fix: was a bare `except:`, which also swallowed
            # KeyboardInterrupt/SystemExit. Fall back to treating the file
            # as a raw state dict.
            model.load_state_dict(pl_sd)
    else:
        mainlogger.info(">>> Start training from scratch")

    return model
|
||||
|
||||
|
||||
def set_logger(logfile, name='mainlogger'):
    """Configure and return a logger writing to *logfile* and to the console.

    The file handler records INFO+ with timestamps; the console handler
    prints the bare message. The logger level is INFO, so DEBUG records are
    filtered regardless of handler levels. NOTE(review): calling this twice
    with the same *name* adds duplicate handlers.
    """
    logger = logging.getLogger(name)
    logger.setLevel(logging.INFO)

    file_handler = logging.FileHandler(logfile, mode='w')
    file_handler.setLevel(logging.INFO)
    file_handler.setFormatter(
        logging.Formatter("%(asctime)s-%(levelname)s: %(message)s"))

    console_handler = logging.StreamHandler()
    console_handler.setLevel(logging.DEBUG)
    console_handler.setFormatter(logging.Formatter("%(message)s"))

    for handler in (file_handler, console_handler):
        logger.addHandler(handler)
    return logger
|
||||
|
||||
|
||||
def count_parameters(model):
    """Return the total number of parameters in *model* (trainable or not)."""
    total = 0
    for param in model.parameters():
        total += param.numel()
    return total
|
||||
|
||||
|
||||
def count_trainable_parameters(model):
    """Return the number of parameters in *model* with ``requires_grad=True``."""
    trainable = (param.numel() for param in model.parameters()
                 if param.requires_grad)
    return sum(trainable)
|
||||
|
||||
|
||||
def get_num_parameters(model):
    """Print a table of parameter counts for the main sub-modules of *model*.

    Rows cover the diffusion model ("World Model"), its action/state heads,
    plus two whole-model rows: "Total Trainable" (``requires_grad`` params
    only) and "Total". Counts below 100M are shown in millions (M),
    otherwise in billions (B).

    Fixes: the loop variable no longer shadows the *model* parameter, and
    the unused ``enumerate`` index is gone.
    """
    named_modules = [('World Model', model.model.diffusion_model),
                     ('Action Head', model.model.diffusion_model.action_unet),
                     ('State Head', model.model.diffusion_model.state_unet),
                     ('Total Trainable', model),
                     ('Total', model)]

    rows = []
    for name, module in named_modules:
        # Only the "Total Trainable" row filters on requires_grad.
        if name == "Total Trainable":
            total_params = count_trainable_parameters(module)
        else:
            total_params = count_parameters(module)

        # Format in millions below 100M parameters, otherwise in billions.
        if total_params < 0.1e9:
            total_params_value = round(total_params / 1e6, 2)
            unit = 'M'
        else:
            total_params_value = round(total_params / 1e9, 2)
            unit = 'B'

        rows.append({
            'Model Name': name,
            'Params': f"{total_params_value} {unit}"
        })

    df = pd.DataFrame(rows)
    print(df)
|
||||
81
src/unifolm_wma/utils/utils.py
Normal file
81
src/unifolm_wma/utils/utils.py
Normal file
@@ -0,0 +1,81 @@
|
||||
import importlib
import os

import cv2
import numpy as np
import torch
import torch.distributed as dist
|
||||
|
||||
|
||||
def count_params(model, verbose=False):
    """Return the total parameter count of *model*.

    Args:
        model: Any ``torch.nn.Module``.
        verbose: When True, also print the count in millions.
    """
    total_params = 0
    for param in model.parameters():
        total_params += param.numel()
    if verbose:
        print(
            f"{model.__class__.__name__} has {total_params*1.e-6:.2f} M params."
        )
    return total_params
|
||||
|
||||
|
||||
def check_istarget(name, para_list):
    """Return True if any partial parameter name in *para_list* occurs in *name*.

    Args:
        name: Full name of the source parameter.
        para_list: Partial names of target parameters (substring match).
    """
    return any(part in name for part in para_list)
|
||||
|
||||
|
||||
def instantiate_from_config(config):
    """Instantiate the object described by *config*.

    *config* is a mapping with a dotted-path ``target`` and optional
    ``params`` kwargs. The sentinel strings ``'__is_first_stage__'`` and
    ``'__is_unconditional__'`` yield ``None``.

    Raises:
        KeyError: If *config* is not a sentinel and lacks a ``target`` key.
    """
    if "target" not in config:
        if config in ('__is_first_stage__', "__is_unconditional__"):
            return None
        raise KeyError("Expected key `target` to instantiate.")
    kwargs = config.get("params", dict())
    return get_obj_from_str(config["target"])(**kwargs)
|
||||
|
||||
|
||||
def get_obj_from_str(string, reload=False):
    """Resolve a dotted path like ``'pkg.mod.Name'`` to the object it names.

    Args:
        string: Dotted path; everything before the last dot is the module.
        reload: When True, reload the module before resolving the attribute.
    """
    module_name, attr_name = string.rsplit(".", 1)
    if reload:
        importlib.reload(importlib.import_module(module_name))
    module = importlib.import_module(module_name, package=None)
    return getattr(module, attr_name)
|
||||
|
||||
|
||||
def load_npz_from_dir(data_dir):
    """Concatenate the ``'arr_0'`` arrays of every file in *data_dir* along axis 0.

    Fixes: the original referenced ``os`` without it being imported in this
    module (NameError on first call); directory entries are now sorted so the
    concatenation order is deterministic (``os.listdir`` order is
    filesystem-dependent).

    Args:
        data_dir: Directory of ``.npz`` files each holding a single unnamed
            array (numpy stores it under the key ``'arr_0'``).

    Returns:
        One numpy array with all file contents stacked along axis 0.
    """
    arrays = [
        np.load(os.path.join(data_dir, data_name))['arr_0']
        for data_name in sorted(os.listdir(data_dir))
    ]
    return np.concatenate(arrays, axis=0)
|
||||
|
||||
|
||||
def load_npz_from_paths(data_paths):
    """Load ``'arr_0'`` from each path in *data_paths* and concatenate along axis 0."""
    chunks = []
    for data_path in data_paths:
        chunks.append(np.load(data_path)['arr_0'])
    return np.concatenate(chunks, axis=0)
|
||||
|
||||
|
||||
def resize_numpy_image(image,
                       max_resolution=512 * 512,
                       resize_short_edge=None):
    """Resize *image* (H, W[, C] numpy array) with Lanczos interpolation.

    If *resize_short_edge* is given, scale so the short edge reaches that
    length; otherwise scale so the area approaches *max_resolution*. Output
    dimensions are rounded to multiples of 64.

    NOTE(review): the scraped source lost indentation; this reconstruction
    applies the square root only to the area ratio (the short-edge ratio is
    already a linear scale) — confirm against the original file.
    """
    height, width = image.shape[:2]
    if resize_short_edge is not None:
        scale = resize_short_edge / min(height, width)
    else:
        scale = max_resolution / (height * width)
        scale = scale**0.5
    new_h = int(np.round(height * scale / 64)) * 64
    new_w = int(np.round(width * scale / 64)) * 64
    return cv2.resize(image, (new_w, new_h),
                      interpolation=cv2.INTER_LANCZOS4)
|
||||
|
||||
|
||||
def setup_dist(args):
    """Initialise the NCCL process group for distributed training.

    No-op when a process group already exists. Binds this process to the GPU
    given by ``args.local_rank`` and reads rendezvous info from the standard
    environment variables (``env://`` init method).
    """
    if dist.is_initialized():
        return
    torch.cuda.set_device(args.local_rank)
    dist.init_process_group('nccl', init_method='env://')
|
||||
Reference in New Issue
Block a user