init commit
This commit is contained in:
104
src/unifolm_wma/utils/basics.py
Normal file
104
src/unifolm_wma/utils/basics.py
Normal file
@@ -0,0 +1,104 @@
|
||||
# adopted from
|
||||
# https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py
|
||||
# and
|
||||
# https://github.com/lucidrains/denoising-diffusion-pytorch/blob/7706bdfc6f527f58d33f84b7b522e61e6e3164b3/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py
|
||||
# and
|
||||
# https://github.com/openai/guided-diffusion/blob/0ba878e517b276c45d1195eb29f6f5f72659a05b/guided_diffusion/nn.py
|
||||
#
|
||||
# thanks!
|
||||
|
||||
import torch.nn as nn
|
||||
from unifolm_wma.utils.utils import instantiate_from_config
|
||||
|
||||
|
||||
def disabled_train(self, mode=True):
    """Replacement for ``nn.Module.train`` that freezes train/eval state.

    Assign it onto a module (``model.train = disabled_train.__get__(model)``-style)
    so later ``.train()``/``.eval()`` calls become no-ops.
    """
    # `mode` is deliberately ignored; the module keeps its current state.
    return self
|
||||
|
||||
|
||||
def zero_module(module):
    """
    Zero out the parameters of a module and return it.
    """
    # In-place zeroing outside of autograd tracking.
    for param in module.parameters():
        with torch.no_grad():
            param.zero_()
    return module
|
||||
|
||||
|
||||
def scale_module(module, scale):
    """
    Scale the parameters of a module in place and return it.
    """
    for param in module.parameters():
        with torch.no_grad():
            param.mul_(scale)
    return module
|
||||
|
||||
|
||||
def conv_nd(dims, *args, **kwargs):
    """
    Create a 1D, 2D, or 3D convolution module.

    :param dims: spatial dimensionality (1, 2 or 3).
    :raises ValueError: for any other value of ``dims``.
    """
    conv_classes = {1: nn.Conv1d, 2: nn.Conv2d, 3: nn.Conv3d}
    conv_cls = conv_classes.get(dims)
    if conv_cls is None:
        raise ValueError(f"unsupported dimensions: {dims}")
    return conv_cls(*args, **kwargs)
|
||||
|
||||
|
||||
def linear(*args, **kwargs):
    """Create a linear (fully-connected) module; thin alias for ``nn.Linear``."""
    return nn.Linear(*args, **kwargs)
|
||||
|
||||
|
||||
def avg_pool_nd(dims, *args, **kwargs):
    """
    Create a 1D, 2D, or 3D average pooling module.

    :param dims: spatial dimensionality (1, 2 or 3).
    :raises ValueError: for any other value of ``dims``.
    """
    pool_classes = {1: nn.AvgPool1d, 2: nn.AvgPool2d, 3: nn.AvgPool3d}
    pool_cls = pool_classes.get(dims)
    if pool_cls is None:
        raise ValueError(f"unsupported dimensions: {dims}")
    return pool_cls(*args, **kwargs)
|
||||
|
||||
|
||||
def nonlinearity(type='silu'):
    """Return an activation module selected by name.

    :param type: 'silu' (default) or 'leaky_relu'. (The parameter name shadows
        the builtin ``type``; kept as-is for backward compatibility with
        keyword callers.)
    :raises ValueError: for unknown activation names. Previously an unknown
        name silently returned ``None``, deferring the failure to the call
        site.
    """
    if type == 'silu':
        return nn.SiLU()
    elif type == 'leaky_relu':
        return nn.LeakyReLU()
    raise ValueError(f"unsupported activation: {type}")
|
||||
|
||||
|
||||
class GroupNormSpecific(nn.GroupNorm):
    """GroupNorm that normalizes in float32 and casts back to the input dtype.

    Useful under mixed precision: the statistics are computed at full
    precision while the output keeps the caller's dtype.
    """

    def forward(self, x):
        out = super().forward(x.float())
        return out.type(x.dtype)
|
||||
|
||||
|
||||
def normalization(channels, num_groups=32):
    """
    Make a standard normalization layer.

    :param channels: number of input channels.
    :param num_groups: number of groups for the group norm.
    :return: an nn.Module for normalization (float32-stable GroupNorm).
    """
    layer = GroupNormSpecific(num_groups, channels)
    return layer
|
||||
|
||||
|
||||
class HybridConditioner(nn.Module):
    """Runs two conditioning encoders — one for concat conditioning, one for
    cross-attention conditioning — built from their respective configs."""

    def __init__(self, c_concat_config, c_crossattn_config):
        super().__init__()
        self.concat_conditioner = instantiate_from_config(c_concat_config)
        self.crossattn_conditioner = instantiate_from_config(
            c_crossattn_config)

    def forward(self, c_concat, c_crossattn):
        """Encode both inputs and package them in the conditioning dict format."""
        return {
            'c_concat': [self.concat_conditioner(c_concat)],
            'c_crossattn': [self.crossattn_conditioner(c_crossattn)],
        }
|
||||
226
src/unifolm_wma/utils/callbacks.py
Normal file
226
src/unifolm_wma/utils/callbacks.py
Normal file
@@ -0,0 +1,226 @@
|
||||
import os
|
||||
import time
|
||||
import logging
|
||||
import json
|
||||
|
||||
# Project-wide logger; other modules fetch the same instance by this name.
mainlogger = logging.getLogger('mainlogger')
|
||||
|
||||
import torch
|
||||
import torchvision
|
||||
import pytorch_lightning as pl
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
from pytorch_lightning.callbacks import Callback
|
||||
from pytorch_lightning.utilities import rank_zero_only
|
||||
from pytorch_lightning.utilities import rank_zero_info
|
||||
|
||||
from unifolm_wma.utils.save_video import log_local, prepare_to_log
|
||||
|
||||
# NOTE(review): appears unused in this module — ImageLogger writes its stats
# under `<save_dir>/stat` instead. Confirm before removing.
STAT_DIR = '~/'
|
||||
|
||||
|
||||
class ImageLogger(Callback):
    """Periodically samples the model during train/val and logs the results.

    Depending on ``to_local``, rendered batches go either to
    ``<save_dir>/images/<split>/`` on disk or to the tensorboard experiment.
    Also accumulates fps / frame-stride histograms over the batches seen and
    dumps them to ``<save_dir>/stat/fps_fs_stat.json`` each time it logs.
    """

    def __init__(self, batch_frequency, max_images=8, clamp=True, rescale=True, save_dir=None, \
        to_local=False, log_images_kwargs=None):
        super().__init__()
        self.rescale = rescale
        self.batch_freq = batch_frequency  # logging period (see log_batch_imgs)
        self.max_images = max_images  # cap on images kept from a batch
        self.to_local = to_local  # True: save to disk; False: tensorboard
        self.clamp = clamp
        self.log_images_kwargs = log_images_kwargs if log_images_kwargs else {}
        self.save_stat_dir = os.path.join(save_dir, "stat")
        os.makedirs(self.save_stat_dir, exist_ok=True)
        self.fps_stat = {}  # histogram: fps value -> count
        self.fs_stat = {}  # histogram: frame stride -> count
        if self.to_local:
            self.save_dir = os.path.join(save_dir, "images")
            os.makedirs(os.path.join(self.save_dir, "train"), exist_ok=True)
            os.makedirs(os.path.join(self.save_dir, "val"), exist_ok=True)
        self.count_data = 0  # data counted since the last log event

    def log_to_tensorboard(self,
                           pl_module,
                           batch_logs,
                           filename,
                           split,
                           save_fps=8):
        """ log images and videos to tensorboard """
        global_step = pl_module.global_step
        for key in batch_logs:
            value = batch_logs[key]
            # Tag encodes step, split and key plus condition/video metadata.
            tag = "gs%d-%s/%s||%s||%s||%s" % (
                global_step, split, key,
                batch_logs['condition'][0].split('_')[0],
                batch_logs['condition'][0].split('_')[1],
                batch_logs['video_idx'])
            if isinstance(value, list) and isinstance(value[0], str):
                # Text entries (e.g. captions).
                captions = ' |------| '.join(value)
                pl_module.logger.experiment.add_text(tag,
                                                     captions,
                                                     global_step=global_step)
            elif isinstance(value, torch.Tensor) and value.dim() == 5:
                # 5-D tensor treated as a video batch; assumes layout
                # (batch, c, t, h, w) given the permute below — TODO confirm.
                video = value
                n = video.shape[0]
                video = video.permute(2, 0, 1, 3, 4)  # one grid per frame
                frame_grids = [
                    torchvision.utils.make_grid(framesheet,
                                                nrow=int(n),
                                                padding=0)
                    for framesheet in video
                ]
                grid = torch.stack(frame_grids, dim=0)
                grid = (grid + 1.0) / 2.0  # [-1, 1] -> [0, 1]
                grid = grid.unsqueeze(dim=0)
                pl_module.logger.experiment.add_video(tag,
                                                      grid,
                                                      fps=save_fps,
                                                      global_step=global_step)
            elif isinstance(value, torch.Tensor) and value.dim() == 4:
                # Image batch. BUGFIX: `n` was read here without being set in
                # this branch (NameError, or a stale value left over from a
                # 5-D entry earlier in the loop); use this batch's own size.
                img = value
                n = img.shape[0]
                grid = torchvision.utils.make_grid(img, nrow=int(n), padding=0)
                grid = (grid + 1.0) / 2.0  # [-1, 1] -> [0, 1]
                pl_module.logger.experiment.add_image(tag,
                                                      grid,
                                                      global_step=global_step)
            elif isinstance(value, torch.Tensor) and value.dim() == 3:
                # Curve data: first half of the batch are targets, the second
                # half samples; plot each feature dimension as a figure.
                b, _, _ = value.shape
                value1 = value[:b // 2, ...]
                value2 = value[b // 2:, ...]
                _, _, d = value1.shape
                for i in range(d):
                    data1 = value1[0, :, i].cpu().detach().numpy()
                    data2 = value2[0, :, i].cpu().detach().numpy()
                    fig, ax = plt.subplots()
                    ax.plot(data1, label='Target 1')
                    ax.plot(data2, label='Sample 1')
                    ax.set_title(f'Comparison at dimension {i} for {key}')
                    ax.legend()
                    pl_module.logger.experiment.add_figure(
                        tag + f"| {key}_dim_{i}", fig, global_step=global_step)
                    plt.close(fig)
            else:
                # Unrecognized entry types are skipped silently.
                pass

    @rank_zero_only
    def log_batch_imgs(self, pl_module, batch, batch_idx, split="train"):
        """ generate images, then save and log to tensorboard """
        # Update fps and fs statistics
        batch_fps = batch['fps'].tolist()
        batch_fs = batch['frame_stride'].tolist()
        for num in batch_fps:
            self.fps_stat[num] = self.fps_stat.get(num, 0) + 1
        for num in batch_fs:
            self.fs_stat[num] = self.fs_stat.get(num, 0) + 1
        skip_freq = self.batch_freq if split == "train" else 5
        # NOTE(review): hand-coded increment (12.5 * 2 per batch) — presumably
        # tuned to the project's effective batch size; confirm before changing.
        self.count_data += 12.5 * 2
        if self.count_data >= skip_freq:
            self.count_data = 0

            is_train = pl_module.training
            if is_train:
                # Sample in eval mode; restored at the end of this branch.
                pl_module.eval()
            torch.cuda.empty_cache()
            with torch.no_grad():
                log_func = pl_module.log_images
                batch_logs = log_func(batch,
                                      split=split,
                                      **self.log_images_kwargs)
                # Log fps and fs statistics
                with open(self.save_stat_dir + '/fps_fs_stat.json',
                          'w') as file:
                    json.dump({
                        'fps': self.fps_stat,
                        'fs': self.fs_stat
                    },
                              file,
                              indent=4)

            batch_logs = prepare_to_log(batch_logs, self.max_images,
                                        self.clamp)
            torch.cuda.empty_cache()

            filename = "ep{}_idx{}_rank{}".format(pl_module.current_epoch,
                                                  batch_idx,
                                                  pl_module.global_rank)
            if self.to_local:
                mainlogger.info("Log [%s] batch <%s> to local ..." %
                                (split, filename))
                filename = "gs{}_".format(pl_module.global_step) + filename
                log_local(batch_logs,
                          os.path.join(self.save_dir, split),
                          filename,
                          save_fps=10)
            else:
                mainlogger.info("Log [%s] batch <%s> to tensorboard ..." %
                                (split, filename))
                self.log_to_tensorboard(pl_module,
                                        batch_logs,
                                        filename,
                                        split,
                                        save_fps=10)
            mainlogger.info('Finish!')

            if is_train:
                pl_module.train()

    def on_train_batch_end(self,
                           trainer,
                           pl_module,
                           outputs,
                           batch,
                           batch_idx,
                           dataloader_idx=None):
        """Log a training batch whenever logging is enabled for this run."""
        if self.batch_freq != -1 and pl_module.logdir:
            self.log_batch_imgs(pl_module, batch, batch_idx, split="train")

    def on_validation_batch_end(self,
                                trainer,
                                pl_module,
                                outputs,
                                batch,
                                batch_idx,
                                dataloader_idx=None):
        # Different with validation_step() that saving the whole validation set
        # and only keep the latest, it records the performance of every
        # validation (without overwriting) by only keeping a subset.
        if self.batch_freq != -1 and pl_module.logdir:
            self.log_batch_imgs(pl_module, batch, batch_idx, split="val")
        if hasattr(pl_module, 'calibrate_grad_norm'):
            if (pl_module.calibrate_grad_norm
                    and batch_idx % 25 == 0) and batch_idx > 0:
                # NOTE(review): `log_gradients` is not defined on this class
                # in this file — this branch would raise AttributeError if the
                # module sets calibrate_grad_norm; confirm where it comes from.
                self.log_gradients(trainer, pl_module, batch_idx=batch_idx)
|
||||
|
||||
|
||||
class CUDACallback(Callback):
    """Reports per-epoch wall time and peak GPU memory of the root device.

    See https://github.com/SeanNaren/minGPT/blob/master/mingpt/callback.py
    """

    @staticmethod
    def _root_gpu_index(trainer):
        """Index of the root GPU across pytorch-lightning API generations.

        Lightning >= 1.7 exposes it via ``trainer.strategy.root_device``;
        older versions via ``trainer.root_gpu``. BUGFIX: the original code
        compared only the minor version (``minor >= 7``), which wrongly sent
        lightning 2.x (minor 0) down the removed ``root_gpu`` path.
        """
        major, minor = (int(part)
                        for part in pl.__version__.split('.')[:2])
        if (major, minor) >= (1, 7):
            return trainer.strategy.root_device.index
        return trainer.root_gpu

    def on_train_epoch_start(self, trainer, pl_module):
        # Reset the peak-memory counter so the epoch measurement starts clean.
        gpu_index = self._root_gpu_index(trainer)
        torch.cuda.reset_peak_memory_stats(gpu_index)
        torch.cuda.synchronize(gpu_index)
        self.start_time = time.time()

    def on_train_epoch_end(self, trainer, pl_module):
        gpu_index = self._root_gpu_index(trainer)
        torch.cuda.synchronize(gpu_index)
        max_memory = torch.cuda.max_memory_allocated(gpu_index) / 2**20  # MiB
        epoch_time = time.time() - self.start_time

        try:
            # `training_type_plugin` no longer exists in newer lightning; the
            # AttributeError handler below keeps this best-effort.
            max_memory = trainer.training_type_plugin.reduce(max_memory)
            epoch_time = trainer.training_type_plugin.reduce(epoch_time)

            rank_zero_info(f"Average Epoch time: {epoch_time:.2f} seconds")
            rank_zero_info(f"Average Peak memory {max_memory:.2f}MiB")
        except AttributeError:
            pass
|
||||
111
src/unifolm_wma/utils/common.py
Normal file
111
src/unifolm_wma/utils/common.py
Normal file
@@ -0,0 +1,111 @@
|
||||
import math
|
||||
from inspect import isfunction
|
||||
import torch
|
||||
from torch import Tensor, nn
|
||||
import torch.distributed as dist
|
||||
|
||||
|
||||
def gather_data(data, return_np=True):
    ''' gather data from multiple processes to one list '''
    world_size = dist.get_world_size()
    gathered = [torch.zeros_like(data) for _ in range(world_size)]
    dist.all_gather(gathered, data)  # gather not supported with NCCL
    if not return_np:
        return gathered
    # Optionally move every rank's tensor to host memory as numpy.
    return [tensor.cpu().numpy() for tensor in gathered]
|
||||
|
||||
|
||||
def autocast(f):
    """Decorator that runs ``f`` under CUDA automatic mixed precision.

    The autocast context uses the globally configured GPU autocast dtype and
    cache setting. BUGFIX: the wrapper now uses ``functools.wraps`` so the
    decorated function keeps its original ``__name__``/``__doc__``/signature
    metadata instead of appearing as ``do_autocast`` everywhere.
    """
    from functools import wraps

    @wraps(f)
    def do_autocast(*args, **kwargs):
        with torch.cuda.amp.autocast(
                enabled=True,
                dtype=torch.get_autocast_gpu_dtype(),
                cache_enabled=torch.is_autocast_cache_enabled()):
            return f(*args, **kwargs)

    return do_autocast
|
||||
|
||||
|
||||
def extract_into_tensor(a, t, x_shape):
    """Gather per-timestep coefficients from ``a`` at indices ``t`` and
    reshape to broadcast against a tensor of shape ``x_shape``."""
    batch = t.shape[0]
    gathered = a.gather(-1, t)
    # Trailing singleton dims so the result broadcasts over x's non-batch dims.
    trailing = (1, ) * (len(x_shape) - 1)
    return gathered.reshape(batch, *trailing)
|
||||
|
||||
|
||||
def noise_like(shape, device, repeat=False):
    """Sample standard-normal noise of ``shape`` on ``device``.

    With ``repeat=True`` a single noise slice is drawn and tiled across the
    batch dimension, so every batch element receives identical noise.
    """
    if repeat:
        single = torch.randn((1, *shape[1:]), device=device)
        return single.repeat(shape[0], *((1, ) * (len(shape) - 1)))
    return torch.randn(shape, device=device)
|
||||
|
||||
|
||||
def default(val, d):
    """Return ``val`` unless it is None; otherwise resolve the fallback ``d``
    (calling it if it is a plain function)."""
    if val is None:
        return d() if isfunction(d) else d
    return val
|
||||
|
||||
|
||||
def exists(val):
    """True when ``val`` is anything other than None (falsy values count)."""
    return not (val is None)
|
||||
|
||||
|
||||
def identity(*args, **kwargs):
    """Return a fresh ``nn.Identity`` module; all arguments are ignored."""
    module = nn.Identity()
    return module
|
||||
|
||||
|
||||
def uniq(arr):
    """Deduplicate ``arr`` preserving first-seen order; returns a dict keys view."""
    return dict.fromkeys(arr).keys()
|
||||
|
||||
|
||||
def mean_flat(tensor):
    """
    Take the mean over all non-batch dimensions.
    """
    non_batch_dims = list(range(1, len(tensor.shape)))
    return tensor.mean(dim=non_batch_dims)
|
||||
|
||||
|
||||
def ismap(x):
    """True for 4-D tensors whose second (channel) dim exceeds 3."""
    if not isinstance(x, torch.Tensor):
        return False
    return x.ndim == 4 and x.shape[1] > 3
|
||||
|
||||
|
||||
def isimage(x):
    """True for 4-D tensors with 1 or 3 channels in the second dim."""
    if not isinstance(x, torch.Tensor):
        return False
    return x.ndim == 4 and x.shape[1] in (1, 3)
|
||||
|
||||
|
||||
def max_neg_value(t):
    """Most negative finite value representable in ``t``'s dtype."""
    finfo = torch.finfo(t.dtype)
    return -finfo.max
|
||||
|
||||
|
||||
def shape_to_str(x):
    """Format ``x.shape`` as a compact 'd0xd1x...' string."""
    return "x".join(str(dim) for dim in x.shape)
|
||||
|
||||
|
||||
def init_(tensor):
    """In-place uniform init in [-1/sqrt(d), 1/sqrt(d)], d = last dim size."""
    bound = 1 / math.sqrt(tensor.shape[-1])
    tensor.uniform_(-bound, bound)
    return tensor
|
||||
|
||||
|
||||
# Short alias for torch's gradient-checkpointing entry point.
ckpt = torch.utils.checkpoint.checkpoint
|
||||
|
||||
|
||||
def checkpoint(func, inputs, params, flag):
|
||||
"""
|
||||
Evaluate a function without caching intermediate activations, allowing for
|
||||
reduced memory at the expense of extra compute in the backward pass.
|
||||
:param func: the function to evaluate.
|
||||
:param inputs: the argument sequence to pass to `func`.
|
||||
:param params: a sequence of parameters `func` depends on but does not
|
||||
explicitly take as arguments.
|
||||
:param flag: if False, disable gradient checkpointing.
|
||||
"""
|
||||
if flag:
|
||||
return ckpt(func, *inputs, use_reentrant=False)
|
||||
else:
|
||||
return func(*inputs)
|
||||
242
src/unifolm_wma/utils/data.py
Normal file
242
src/unifolm_wma/utils/data.py
Normal file
@@ -0,0 +1,242 @@
|
||||
import os, sys
|
||||
import numpy as np
|
||||
import torch
|
||||
import pytorch_lightning as pl
|
||||
|
||||
from functools import partial
|
||||
from torch.utils.data import (DataLoader, Dataset, ConcatDataset,
|
||||
WeightedRandomSampler)
|
||||
from unifolm_wma.data.base import Txt2ImgIterableBaseDataset
|
||||
from unifolm_wma.utils.utils import instantiate_from_config
|
||||
|
||||
|
||||
def worker_init_fn(_):
    """DataLoader worker_init_fn: seed numpy per worker and, for iterable
    datasets, assign each worker a disjoint shard of sample ids.

    The positional argument (the worker id DataLoader passes) is unused; the
    id is read from ``torch.utils.data.get_worker_info()`` instead.
    """
    worker_info = torch.utils.data.get_worker_info()

    dataset = worker_info.dataset
    worker_id = worker_info.id

    if isinstance(dataset, Txt2ImgIterableBaseDataset):
        # Shard the valid ids so workers iterate disjoint slices.
        split_size = dataset.num_records // worker_info.num_workers
        # Reset num_records to the true number to retain reliable length information
        dataset.sample_ids = dataset.valid_ids[worker_id *
                                               split_size:(worker_id + 1) *
                                               split_size]
        # Derive a per-worker seed from a randomly chosen word of numpy's
        # current RNG state vector (np.random.get_state()[1]).
        current_id = np.random.choice(len(np.random.get_state()[1]), 1)
        return np.random.seed(np.random.get_state()[1][current_id] + worker_id)
    else:
        # Seed from the first word of the RNG state, offset by the worker id.
        return np.random.seed(np.random.get_state()[1][0] + worker_id)
|
||||
|
||||
|
||||
class WrappedDataset(Dataset):
    """Wraps an arbitrary object with __len__ and __getitem__ into a pytorch dataset"""

    def __init__(self, dataset):
        # Keep the wrapped object; attribute name `data` is part of the API.
        self.data = dataset

    def __len__(self):
        """Length of the wrapped collection."""
        return len(self.data)

    def __getitem__(self, idx):
        """Delegate item access straight to the wrapped collection."""
        return self.data[idx]
|
||||
|
||||
|
||||
class DataModuleFromConfig(pl.LightningDataModule):
    """LightningDataModule that builds one dataset per entry of
    ``dataset_and_weights`` for each split and samples across them with a
    ``WeightedRandomSampler``.

    Refactor: the three near-identical ``setup`` branches and the three
    near-identical weighted dataloader builders are factored into
    ``_instantiate_datasets`` and ``_make_weighted_loader``; observable
    behavior (loader options, sampling, drop_last on train only) is unchanged.
    """

    def __init__(self,
                 batch_size,
                 train=None,
                 validation=None,
                 test=None,
                 predict=None,
                 wrap=False,
                 num_workers=None,
                 shuffle_test_loader=False,
                 use_worker_init_fn=False,
                 shuffle_val_dataloader=True,
                 train_img=None,
                 dataset_and_weights=None):
        super().__init__()
        self.batch_size = batch_size
        self.dataset_configs = dict()
        # Default: two workers per sample in a batch.
        self.num_workers = num_workers if num_workers is not None else batch_size * 2
        self.use_worker_init_fn = use_worker_init_fn
        if train is not None:
            self.dataset_configs["train"] = train
            self.train_dataloader = self._train_dataloader
        if validation is not None:
            self.dataset_configs["validation"] = validation
            self.val_dataloader = partial(self._val_dataloader,
                                          shuffle=shuffle_val_dataloader)
        if test is not None:
            self.dataset_configs["test"] = test
            self.test_dataloader = partial(self._test_dataloader,
                                           shuffle=shuffle_test_loader)
        if predict is not None:
            self.dataset_configs["predict"] = predict
            self.predict_dataloader = self._predict_dataloader

        self.img_loader = None
        self.wrap = wrap
        self.collate_fn = None
        self.dataset_weights = dataset_and_weights
        assert round(sum(self.dataset_weights.values()),
                     2) == 1.0, "The sum of dataset weights != 1.0"

    def prepare_data(self):
        pass

    def _instantiate_datasets(self, split):
        """Instantiate one dataset per weighted sub-dataset name for ``split``.

        Each sub-dataset reads ``<data_dir>/<name>.csv`` metadata and shares
        ``<data_dir>/transitions``.
        """
        config = self.dataset_configs[split]
        datasets = dict()
        for dataname in self.dataset_weights:
            data_dir = config['params']['data_dir']
            config['params']['meta_path'] = '/'.join(
                [data_dir, f'{dataname}.csv'])
            config['params']['transition_dir'] = '/'.join(
                [data_dir, 'transitions'])
            config['params']['dataset_name'] = dataname
            datasets[dataname] = instantiate_from_config(config)
        return datasets

    def setup(self, stage=None):
        if 'train' in self.dataset_configs:
            self.train_datasets = self._instantiate_datasets('train')
        if 'validation' in self.dataset_configs:
            self.val_datasets = self._instantiate_datasets('validation')
        if 'test' in self.dataset_configs:
            self.test_datasets = self._instantiate_datasets('test')

        if self.wrap:
            # NOTE(review): `self.datasets` is never assigned in this class,
            # so wrap=True raises AttributeError — presumably dead code
            # carried over from the original repo; confirm before relying on
            # this option.
            for k in self.datasets:
                self.datasets[k] = WrappedDataset(self.datasets[k])

    def _make_weighted_loader(self, datasets, drop_last=False):
        """Concatenate ``datasets`` and draw samples with per-dataset weights.

        A dataset's configured weight is spread uniformly over its samples so
        that the sampler hits each named dataset in the configured proportion.
        """
        init_fn = worker_init_fn if self.use_worker_init_fn else None
        parts = []
        weights = []
        for dataname, dataset in datasets.items():
            parts.append(dataset)
            weights.append(
                torch.full((len(dataset), ),
                           self.dataset_weights[dataname] / len(dataset)))
        combined = ConcatDataset(parts)
        sampler = WeightedRandomSampler(torch.cat(weights),
                                        num_samples=len(combined),
                                        replacement=True)
        return DataLoader(combined,
                          sampler=sampler,
                          batch_size=self.batch_size,
                          num_workers=self.num_workers,
                          worker_init_fn=init_fn,
                          collate_fn=self.collate_fn,
                          drop_last=drop_last)

    def _train_dataloader(self):
        # Only the train loader drops the last incomplete batch.
        return self._make_weighted_loader(self.train_datasets, drop_last=True)

    def _val_dataloader(self, shuffle=False):
        # `shuffle` is accepted for API compatibility; ordering is governed
        # by the weighted sampler instead.
        return self._make_weighted_loader(self.val_datasets)

    def _test_dataloader(self, shuffle=False):
        # `shuffle` accepted for API compatibility (see _val_dataloader).
        return self._make_weighted_loader(self.test_datasets)

    def _predict_dataloader(self, shuffle=False):
        # NOTE(review): relies on `self.datasets['predict']`, which setup()
        # never creates — this path would raise AttributeError; confirm
        # whether predict is actually used.
        if isinstance(self.datasets['predict'],
                      Txt2ImgIterableBaseDataset) or self.use_worker_init_fn:
            init_fn = worker_init_fn
        else:
            init_fn = None
        return DataLoader(
            self.datasets["predict"],
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            worker_init_fn=init_fn,
            collate_fn=self.collate_fn,
        )

    def __len__(self):
        """Total number of training samples across all sub-datasets."""
        count = 0
        for _, values in self.train_datasets.items():
            count += len(values)
        return count
|
||||
191
src/unifolm_wma/utils/diffusion.py
Normal file
191
src/unifolm_wma/utils/diffusion.py
Normal file
@@ -0,0 +1,191 @@
|
||||
import math
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from einops import repeat
|
||||
|
||||
|
||||
def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False):
    """
    Create sinusoidal timestep embeddings.
    :param timesteps: a 1-D Tensor of N indices, one per batch element.
                      These may be fractional.
    :param dim: the dimension of the output.
    :param max_period: controls the minimum frequency of the embeddings.
    :param repeat_only: if True, tile the raw timestep value across `dim`
                        instead of building sinusoids.
    :return: an [N x dim] Tensor of positional embeddings.
    """
    if repeat_only:
        return timesteps.unsqueeze(-1).repeat(1, dim)
    half = dim // 2
    # Geometrically spaced frequencies from 1 down to 1/max_period.
    exponent = (-math.log(max_period) *
                torch.arange(start=0, end=half, dtype=torch.float32) / half)
    freqs = torch.exp(exponent).to(device=timesteps.device)
    args = timesteps[:, None].float() * freqs[None]
    embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
    if dim % 2:
        # Odd target dimension: pad with a single zero column.
        pad = torch.zeros_like(embedding[:, :1])
        embedding = torch.cat([embedding, pad], dim=-1)
    return embedding
|
||||
|
||||
|
||||
def make_beta_schedule(schedule,
                       n_timestep,
                       linear_start=1e-4,
                       linear_end=2e-2,
                       cosine_s=8e-3):
    """Build a beta (noise variance) schedule as a float64 numpy array.

    :param schedule: one of 'linear', 'cosine', 'sqrt_linear', 'sqrt'.
    :raises ValueError: for an unknown schedule name.
    """
    if schedule == "linear":
        # Linear in sqrt(beta)-space, then squared.
        sqrt_betas = torch.linspace(linear_start**0.5,
                                    linear_end**0.5,
                                    n_timestep,
                                    dtype=torch.float64)
        betas = sqrt_betas**2
    elif schedule == "cosine":
        timesteps = (
            torch.arange(n_timestep + 1, dtype=torch.float64) / n_timestep +
            cosine_s)
        alphas = torch.cos(timesteps / (1 + cosine_s) * np.pi / 2).pow(2)
        alphas = alphas / alphas[0]
        betas = 1 - alphas[1:] / alphas[:-1]
        betas = np.clip(betas, a_min=0, a_max=0.999)
    elif schedule == "sqrt_linear":
        betas = torch.linspace(linear_start,
                               linear_end,
                               n_timestep,
                               dtype=torch.float64)
    elif schedule == "sqrt":
        betas = torch.linspace(linear_start,
                               linear_end,
                               n_timestep,
                               dtype=torch.float64)**0.5
    else:
        raise ValueError(f"schedule '{schedule}' unknown.")
    return betas.numpy()
|
||||
|
||||
|
||||
def make_ddim_timesteps(ddim_discr_method,
                        num_ddim_timesteps,
                        num_ddpm_timesteps,
                        verbose=True):
    """Select the subset of DDPM timesteps a DDIM sampler will visit.

    :param ddim_discr_method: 'uniform', 'uniform_trailing' or 'quad'.
    :raises NotImplementedError: for an unknown discretization method.
    """
    if ddim_discr_method == 'uniform':
        stride = num_ddpm_timesteps // num_ddim_timesteps
        picked = np.asarray(list(range(0, num_ddpm_timesteps, stride)))
        # +1 so the final alphas line up (first scale-to-data step).
        steps_out = picked + 1
    elif ddim_discr_method == 'uniform_trailing':
        stride = num_ddpm_timesteps / num_ddim_timesteps
        descending = np.arange(num_ddpm_timesteps, 0, -stride)
        steps_out = np.flip(np.round(descending)).astype(np.int64) - 1
    elif ddim_discr_method == 'quad':
        quad = np.linspace(0, np.sqrt(num_ddpm_timesteps * .8),
                           num_ddim_timesteps)**2
        steps_out = quad.astype(int) + 1
    else:
        raise NotImplementedError(
            f'There is no ddim discretization method called "{ddim_discr_method}"'
        )

    if verbose:
        print(f'Selected timesteps for ddim sampler: {steps_out}')
    return steps_out
|
||||
|
||||
|
||||
def make_ddim_sampling_parameters(alphacums,
                                  ddim_timesteps,
                                  eta,
                                  verbose=True):
    """Compute the (sigmas, alphas, alphas_prev) triple for DDIM sampling.

    :param alphacums: cumulative alpha products of the full DDPM schedule.
    :param ddim_timesteps: subset of timesteps the sampler will use.
    :param eta: stochasticity knob; eta=0 gives deterministic DDIM.
    """
    # Alphas at the selected steps, and the corresponding previous-step alphas.
    alphas = alphacums[ddim_timesteps]
    alphas_prev = np.asarray([alphacums[0]] +
                             alphacums[ddim_timesteps[:-1]].tolist())

    # Sigma formula from https://arxiv.org/abs/2010.02502
    sigmas = eta * np.sqrt(
        (1 - alphas_prev) / (1 - alphas) * (1 - alphas / alphas_prev))
    if verbose:
        print(
            f'Selected alphas for ddim sampler: a_t: {alphas}; a_(t-1): {alphas_prev}'
        )
        print(
            f'For the chosen value of eta, which is {eta}, '
            f'this results in the following sigma_t schedule for ddim sampler {sigmas}'
        )
    return sigmas, alphas, alphas_prev
|
||||
|
||||
|
||||
def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999):
    """
    Create a beta schedule that discretizes the given alpha_t_bar function,
    which defines the cumulative product of (1-beta) over time from t = [0,1].
    :param num_diffusion_timesteps: the number of betas to produce.
    :param alpha_bar: a lambda that takes an argument t from 0 to 1 and
                      produces the cumulative product of (1-beta) up to that
                      part of the diffusion process.
    :param max_beta: the maximum beta to use; use values lower than 1 to
                     prevent singularities.
    """
    steps = num_diffusion_timesteps
    betas = [
        min(1 - alpha_bar((i + 1) / steps) / alpha_bar(i / steps), max_beta)
        for i in range(steps)
    ]
    return np.array(betas)
|
||||
|
||||
|
||||
def rescale_zero_terminal_snr(betas):
    """
    Rescales betas to have zero terminal SNR. Based on
    https://arxiv.org/pdf/2305.08891.pdf (Algorithm 1).

    Args:
        betas (`numpy.ndarray`): the betas the scheduler is initialized with.

    Returns:
        `numpy.ndarray`: rescaled betas with zero terminal SNR.
    """
    # Work in sqrt(alpha_bar)-space.
    alphas_bar_sqrt = np.sqrt(np.cumprod(1.0 - betas, axis=0))

    first = alphas_bar_sqrt[0].copy()
    last = alphas_bar_sqrt[-1].copy()

    # Shift so the last timestep reaches exactly zero SNR ...
    alphas_bar_sqrt -= last
    # ... then rescale so the first timestep keeps its original value.
    alphas_bar_sqrt *= first / (first - last)

    # Convert back: undo sqrt, then undo the cumulative product.
    alphas_bar = alphas_bar_sqrt**2
    alphas = np.concatenate(
        [alphas_bar[0:1], alphas_bar[1:] / alphas_bar[:-1]])
    return 1 - alphas
|
||||
|
||||
|
||||
def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
    """
    Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of
    [Common Diffusion Noise Schedules and Sample Steps are Flawed]
    (https://arxiv.org/pdf/2305.08891.pdf), Section 3.4.
    """
    # Per-sample std over all non-batch dims.
    text_dims = list(range(1, noise_pred_text.ndim))
    std_text = noise_pred_text.std(dim=text_dims, keepdim=True)
    std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)
    # Rescale the results from guidance (fixes overexposure).
    noise_pred_rescaled = noise_cfg * (std_text / std_cfg)
    # Blend with the raw guidance output to avoid "plain looking" images.
    return (guidance_rescale * noise_pred_rescaled +
            (1 - guidance_rescale) * noise_cfg)
|
||||
94
src/unifolm_wma/utils/distributions.py
Normal file
94
src/unifolm_wma/utils/distributions.py
Normal file
@@ -0,0 +1,94 @@
|
||||
import torch
|
||||
import numpy as np
|
||||
|
||||
|
||||
class AbstractDistribution:
    """Minimal distribution interface; subclasses implement sample/mode."""

    def sample(self):
        """Draw one sample from the distribution; must be overridden."""
        raise NotImplementedError()

    def mode(self):
        """Return the distribution's mode; must be overridden."""
        raise NotImplementedError()
|
||||
|
||||
|
||||
class DiracDistribution(AbstractDistribution):
    """Point-mass distribution: sampling and mode both yield the stored value."""

    def __init__(self, value):
        self.value = value

    def sample(self):
        """A Dirac's sample is always its single support point."""
        return self.value

    def mode(self):
        """The mode coincides with the support point."""
        return self.value
|
||||
|
||||
|
||||
class DiagonalGaussianDistribution(object):
    """Factorized Gaussian over latents.

    `parameters` stacks mean and log-variance along channel dim 1
    (as produced by a VAE encoder head). When `deterministic`, the
    distribution collapses to a point mass at the mean.
    """

    def __init__(self, parameters, deterministic=False):
        self.parameters = parameters
        # Channel dim holds [mean | logvar]; clamp logvar for numerical safety.
        mean, logvar = torch.chunk(parameters, 2, dim=1)
        self.mean = mean
        self.logvar = torch.clamp(logvar, -30.0, 20.0)
        self.deterministic = deterministic
        self.std = torch.exp(0.5 * self.logvar)
        self.var = torch.exp(self.logvar)
        if self.deterministic:
            # Zero spread -> sample() degenerates to the mean.
            zeros = torch.zeros_like(self.mean).to(
                device=self.parameters.device)
            self.var = zeros
            self.std = zeros

    def sample(self, noise=None):
        """Reparameterized sample; caller may supply the standard-normal noise."""
        if noise is None:
            noise = torch.randn(self.mean.shape)
        return self.mean + self.std * noise.to(device=self.parameters.device)

    def kl(self, other=None):
        """KL divergence to `other` (or to N(0, I)), summed over dims [1, 2, 3]."""
        if self.deterministic:
            return torch.Tensor([0.])
        if other is None:
            integrand = torch.pow(self.mean, 2) + self.var - 1.0 - self.logvar
        else:
            integrand = (torch.pow(self.mean - other.mean, 2) / other.var +
                         self.var / other.var - 1.0 - self.logvar +
                         other.logvar)
        return 0.5 * torch.sum(integrand, dim=[1, 2, 3])

    def nll(self, sample, dims=[1, 2, 3]):
        """Gaussian negative log-likelihood of `sample`, summed over `dims`."""
        if self.deterministic:
            return torch.Tensor([0.])
        logtwopi = np.log(2.0 * np.pi)
        squared_err = torch.pow(sample - self.mean, 2)
        return 0.5 * torch.sum(logtwopi + self.logvar + squared_err / self.var,
                               dim=dims)

    def mode(self):
        """The mode of a Gaussian is its mean."""
        return self.mean
|
||||
|
||||
|
||||
def normal_kl(mean1, logvar1, mean2, logvar2):
    """
    source: https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/losses.py#L12
    Compute the KL divergence between two gaussians.
    Shapes are automatically broadcasted, so batches can be compared to
    scalars, among other use cases.
    """
    # Find any tensor argument so scalars can be promoted onto its device/dtype.
    tensor = next((obj for obj in (mean1, logvar1, mean2, logvar2)
                   if isinstance(obj, torch.Tensor)), None)
    assert tensor is not None, "at least one argument must be a Tensor"

    # Force variances to be Tensors: broadcasting handles scalar means,
    # but torch.exp() below requires tensor log-variances.
    logvar1, logvar2 = (x if isinstance(x, torch.Tensor) else
                        torch.tensor(x).to(tensor)
                        for x in (logvar1, logvar2))

    return 0.5 * (-1.0 + logvar2 - logvar1 + torch.exp(logvar1 - logvar2) +
                  ((mean1 - mean2)**2) * torch.exp(-logvar2))
|
||||
84
src/unifolm_wma/utils/ema.py
Normal file
84
src/unifolm_wma/utils/ema.py
Normal file
@@ -0,0 +1,84 @@
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
|
||||
class LitEma(nn.Module):
    """Exponential-moving-average tracker for a model's trainable parameters.

    Shadow copies are stored as buffers (so they follow state_dict / .to()),
    keyed by the parameter name with '.' removed, because buffer names may
    not contain dots.
    """

    def __init__(self, model, decay=0.9999, use_num_upates=True):
        # NOTE(review): `use_num_upates` (sic) is kept misspelled to preserve
        # the public constructor signature for existing callers.
        super().__init__()
        if decay < 0.0 or decay > 1.0:
            raise ValueError('Decay must be between 0 and 1')

        # Maps parameter name -> sanitized shadow-buffer name.
        self.m_name2s_name = {}
        self.register_buffer('decay', torch.tensor(decay, dtype=torch.float32))
        # num_updates == -1 disables the warm-up schedule in forward().
        self.register_buffer(
            'num_updates',
            torch.tensor(0, dtype=torch.int)
            if use_num_upates else torch.tensor(-1, dtype=torch.int))

        for name, p in model.named_parameters():
            if p.requires_grad:
                # Remove '.' since that character is not allowed in buffer names.
                s_name = name.replace('.', '')
                self.m_name2s_name.update({name: s_name})
                self.register_buffer(s_name, p.clone().detach().data)

        # Scratch storage used by store()/restore().
        self.collected_params = []

    def forward(self, model):
        """Move the shadow parameters one EMA step toward `model`'s parameters."""
        decay = self.decay

        if self.num_updates >= 0:
            self.num_updates += 1
            # Warm-up: effective decay ramps up toward `self.decay` so early
            # updates are not dominated by the random initialization.
            decay = min(self.decay,
                        (1 + self.num_updates) / (10 + self.num_updates))

        one_minus_decay = 1.0 - decay

        with torch.no_grad():
            m_param = dict(model.named_parameters())
            shadow_params = dict(self.named_buffers())

            for key in m_param:
                if m_param[key].requires_grad:
                    sname = self.m_name2s_name[key]
                    # In-place EMA step: shadow -= (1 - decay) * (shadow - param).
                    shadow_params[sname] = shadow_params[sname].type_as(
                        m_param[key])
                    shadow_params[sname].sub_(
                        one_minus_decay *
                        (shadow_params[sname] - m_param[key]))
                else:
                    # Non-trainable params must never have a shadow entry.
                    assert not key in self.m_name2s_name

    def copy_to(self, model):
        """Overwrite `model`'s trainable parameters with the shadow (EMA) values."""
        m_param = dict(model.named_parameters())
        shadow_params = dict(self.named_buffers())
        for key in m_param:
            if m_param[key].requires_grad:
                m_param[key].data.copy_(
                    shadow_params[self.m_name2s_name[key]].data)
            else:
                assert not key in self.m_name2s_name

    def store(self, parameters):
        """
        Save the current parameters for restoring later.
        Args:
            parameters: Iterable of `torch.nn.Parameter`; the parameters to be
                temporarily stored.
        """
        self.collected_params = [param.clone() for param in parameters]

    def restore(self, parameters):
        """
        Restore the parameters stored with the `store` method.
        Useful to validate the model with EMA parameters without affecting the
        original optimization process. Store the parameters before the
        `copy_to` method. After validation (or model saving), use this to
        restore the former parameters.
        Args:
            parameters: Iterable of `torch.nn.Parameter`; the parameters to be
                updated with the stored parameters.
        """
        for c_param, param in zip(self.collected_params, parameters):
            param.data.copy_(c_param.data)
|
||||
66
src/unifolm_wma/utils/nn_utils.py
Normal file
66
src/unifolm_wma/utils/nn_utils.py
Normal file
@@ -0,0 +1,66 @@
|
||||
"""
|
||||
nn_utils.py
|
||||
|
||||
Utility functions and PyTorch submodule definitions.
|
||||
"""
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
|
||||
# === Definitions for Various Projection Modules, with Signature :: [..., in_dim] --> [..., out_dim] ===
|
||||
class LinearProjector(nn.Module):
    """Single affine map from vision-feature width to the LLM embedding width."""

    def __init__(self, vision_dim: int, llm_dim: int) -> None:
        super().__init__()
        self.projector = nn.Linear(vision_dim, llm_dim, bias=True)

    def forward(self, img_patches: torch.Tensor) -> torch.Tensor:
        """Project [..., vision_dim] patches to [..., llm_dim]."""
        return self.projector(img_patches)
|
||||
|
||||
|
||||
class MLPProjector(nn.Module):
    """Two-layer GELU MLP projecting vision features to the LLM width."""

    def __init__(self,
                 vision_dim: int,
                 llm_dim: int,
                 mlp_type: str = "gelu-mlp") -> None:
        super().__init__()
        # Guard first so the error path is obvious; only one variant exists.
        if mlp_type != "gelu-mlp":
            raise ValueError(
                f"Projector with `{mlp_type = }` is not supported!")
        self.projector = nn.Sequential(
            nn.Linear(vision_dim, llm_dim, bias=True),
            nn.GELU(),
            nn.Linear(llm_dim, llm_dim, bias=True),
        )

    def forward(self, img_patches: torch.Tensor) -> torch.Tensor:
        """Project [..., vision_dim] patches to [..., llm_dim]."""
        return self.projector(img_patches)
|
||||
|
||||
|
||||
class FusedMLPProjector(nn.Module):
    """Three-layer GELU MLP for fused (concatenated) vision features.

    Expands to 4x the fused width before projecting down to the LLM width.
    """

    def __init__(self,
                 fused_vision_dim: int,
                 llm_dim: int,
                 mlp_type: str = "fused-gelu-mlp") -> None:
        super().__init__()
        self.initial_projection_dim = fused_vision_dim * 4
        if mlp_type != "fused-gelu-mlp":
            raise ValueError(
                f"Fused Projector with `{mlp_type = }` is not supported!")
        self.projector = nn.Sequential(
            nn.Linear(fused_vision_dim, self.initial_projection_dim,
                      bias=True),
            nn.GELU(),
            nn.Linear(self.initial_projection_dim, llm_dim, bias=True),
            nn.GELU(),
            nn.Linear(llm_dim, llm_dim, bias=True),
        )

    def forward(self, fused_img_patches: torch.Tensor) -> torch.Tensor:
        """Project [..., fused_vision_dim] patches to [..., llm_dim]."""
        return self.projector(fused_img_patches)
|
||||
147
src/unifolm_wma/utils/projector.py
Normal file
147
src/unifolm_wma/utils/projector.py
Normal file
@@ -0,0 +1,147 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
|
||||
class LinearProjector(nn.Module):
    """Single affine projection from input_dim to output_dim."""

    def __init__(self, input_dim: int, output_dim: int) -> None:
        super().__init__()
        self.projector = nn.Linear(input_dim, output_dim, bias=True)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Map [..., input_dim] to [..., output_dim]."""
        return self.projector(x)
|
||||
|
||||
|
||||
class MLPProjector(nn.Module):
    """Two-layer MLP projecting input_dim features to output_dim.

    Supports a tanh-approximated GELU variant ("gelu-mlp", default) and a
    SiLU variant ("silu-mlp").

    Bug fixed: the original layer definitions referenced the undefined names
    `vision_dim`/`llm_dim` (copied from another projector module) instead of
    the actual constructor arguments `input_dim`/`output_dim`, raising a
    NameError on construction.
    """

    def __init__(
            self,
            input_dim: int,
            output_dim: int,
            mlp_type: str = "gelu-mlp") -> None:
        super().__init__()
        if mlp_type == "gelu-mlp":
            self.projector = nn.Sequential(
                nn.Linear(input_dim, output_dim, bias=True),
                nn.GELU(approximate='tanh'),
                nn.Linear(output_dim, output_dim, bias=True),
            )
        elif mlp_type == "silu-mlp":
            self.projector = nn.Sequential(
                nn.Linear(input_dim, output_dim, bias=True),
                nn.SiLU(),
                nn.Linear(output_dim, output_dim, bias=True),
            )
        else:
            raise ValueError(f"Projector with `{mlp_type = }` is not supported!")

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Map [..., input_dim] to [..., output_dim]."""
        return self.projector(x)
|
||||
|
||||
|
||||
class PerceiverAttention(nn.Module):
    """Perceiver-style cross-attention: latents attend to [image features ++ latents].

    Bugs fixed:
    - the original called `math.sqrt`, but this module never imports `math`;
      the scale is now computed as `dim_head ** -0.25` (identical value).
    - the original called `reshape_tensor`, which is not defined anywhere in
      this file; the (b, n, heads*dim_head) -> (b, heads, n, dim_head) split
      is now done inline.
    """

    def __init__(self, *, dim, dim_head=64, heads=8):
        super().__init__()
        self.scale = dim_head**-0.5
        self.dim_head = dim_head
        self.heads = heads
        inner_dim = dim_head * heads

        self.norm1 = nn.LayerNorm(dim)
        self.norm2 = nn.LayerNorm(dim)

        self.to_q = nn.Linear(dim, inner_dim, bias=False)
        self.to_kv = nn.Linear(dim, inner_dim * 2, bias=False)
        self.to_out = nn.Linear(inner_dim, dim, bias=False)

    def forward(self, x, latents):
        """
        Args:
            x (torch.Tensor): image features
                shape (b, n1, D)
            latents (torch.Tensor): latent features
                shape (b, n2, D)
        Returns:
            torch.Tensor of shape (b, n2, D): updated latents.
        """
        x = self.norm1(x)
        latents = self.norm2(latents)

        b, l, _ = latents.shape

        q = self.to_q(latents)
        # Keys/values see both the image features and the latents themselves.
        kv_input = torch.cat((x, latents), dim=-2)
        k, v = self.to_kv(kv_input).chunk(2, dim=-1)

        def _split_heads(t):
            # (b, n, heads * dim_head) -> (b, heads, n, dim_head)
            return t.view(b, t.shape[1], self.heads, -1).transpose(1, 2)

        q = _split_heads(q)
        k = _split_heads(k)
        v = _split_heads(v)

        # Apply sqrt(scale) to q and k separately: more stable with f16 than
        # dividing the product afterwards. Equals 1 / sqrt(sqrt(dim_head)).
        scale = self.dim_head**-0.25
        weight = (q * scale) @ (k * scale).transpose(-2, -1)
        weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)
        out = weight @ v

        out = out.permute(0, 2, 1, 3).reshape(b, l, -1)

        return self.to_out(out)
|
||||
|
||||
|
||||
def FeedForward(dim, mult=4, ffd_type="gelu-ffd"):
    """Build a pre-norm two-layer feed-forward block (LayerNorm -> expand -> act -> contract).

    Args:
        dim: input/output feature width.
        mult: hidden-layer expansion factor.
        ffd_type: "gelu-ffd" (tanh-approximated GELU) or "silu-ffd".
    Raises:
        ValueError: for any other `ffd_type`.

    Bugs fixed: both comparisons used `=` instead of `==` (SyntaxError), and
    the error message referenced the undefined name `mlp_type`.
    """
    inner_dim = int(dim * mult)
    if ffd_type == "gelu-ffd":
        activation = nn.GELU(approximate='tanh')
    elif ffd_type == "silu-ffd":
        activation = nn.SiLU()
    else:
        raise ValueError(f"Projector with `{ffd_type = }` is not supported!")
    return nn.Sequential(
        nn.LayerNorm(dim),
        nn.Linear(dim, inner_dim, bias=False),
        activation,
        nn.Linear(inner_dim, dim, bias=False),
    )
|
||||
|
||||
|
||||
class TokenProjector(nn.Module):
    """Perceiver-style resampler: a learned set of query tokens attends to the
    input sequence through `depth` attention + feed-forward layers, then is
    projected to `output_dim`.

    Bugs fixed: `def forward(self, x)` was missing its colon (SyntaxError),
    and `forward` never returned `latents`, so callers always got None.

    NOTE(review): `norm_out` is LayerNorm(dim) applied AFTER `proj_out`
    (dim -> output_dim); this only works when output_dim == dim (the defaults).
    Preserved as-is — confirm intent before changing.
    """

    def __init__(
        self,
        dim=1024,
        depth=1,
        dim_head=64,
        heads=16,
        num_queries=16,
        output_dim=1024,
        ff_mult=4,
        chunck_size=None,
    ):
        super().__init__()
        self.num_queries = num_queries
        self.chunck_size = chunck_size
        # With chunking, each chunk gets its own set of query tokens.
        if chunck_size is not None:
            num_queries = num_queries * chunck_size

        # Learned queries, scaled down so early attention logits stay small.
        self.latents = nn.Parameter(torch.randn(1, num_queries, dim) / dim**0.5)
        self.proj_out = nn.Linear(dim, output_dim)
        self.norm_out = nn.LayerNorm(dim)

        self.layers = nn.ModuleList([])
        for _ in range(depth):
            self.layers.append(
                nn.ModuleList([
                    PerceiverAttention(dim=dim, dim_head=dim_head, heads=heads),
                    FeedForward(dim=dim, mult=ff_mult),
                ]))

    def forward(self, x):
        """Resample `x` (b, n, dim) into (b, num_queries, output_dim) tokens."""
        latents = self.latents.repeat(x.size(0), 1, 1)
        for attn, ff in self.layers:
            # Residual attention and feed-forward updates.
            latents = attn(x, latents) + latents
            latents = ff(latents) + latents
        latents = self.proj_out(latents)
        latents = self.norm_out(latents)
        return latents
|
||||
258
src/unifolm_wma/utils/save_video.py
Normal file
258
src/unifolm_wma/utils/save_video.py
Normal file
@@ -0,0 +1,258 @@
|
||||
import os
|
||||
import torch
|
||||
import numpy as np
|
||||
import torchvision
|
||||
|
||||
from tqdm import tqdm
|
||||
from PIL import Image
|
||||
from einops import rearrange
|
||||
from torch import Tensor
|
||||
from torchvision.utils import make_grid
|
||||
from torchvision.transforms.functional import to_tensor
|
||||
|
||||
|
||||
def frames_to_mp4(frame_dir, output_path, fps):
    """Encode all frames in `frame_dir` (sorted by filename) into an h264 mp4."""

    def _load_frames(d: os.PathLike, num_frames: int):
        # Falsy num_frames (None/0) means "all frames".
        names = sorted(os.listdir(d))
        if num_frames:
            names = names[:num_frames]
        return torch.stack(
            [to_tensor(Image.open(os.path.join(d, name))) for name in names])

    frames = _load_frames(frame_dir, num_frames=None)
    # [0,1] float TCHW -> uint8 THWC as expected by write_video.
    frames = frames.mul(255).to(torch.uint8).permute(0, 2, 3, 1)
    torchvision.io.write_video(output_path,
                               frames,
                               fps=fps,
                               video_codec='h264',
                               options={'crf': '10'})
|
||||
|
||||
|
||||
def tensor_to_mp4(video, savepath, fps, rescale=True, nrow=None):
    """
    video: torch.Tensor, b,c,t,h,w, 0-1
    if -1~1, enable rescale=True
    """
    batch = video.shape[0]
    # (b,c,t,h,w) -> (t,b,c,h,w): one grid sheet per timestep.
    per_frame = video.permute(2, 0, 1, 3, 4)
    grid_cols = int(np.sqrt(batch)) if nrow is None else nrow
    sheets = [
        torchvision.utils.make_grid(frame, nrow=grid_cols, padding=0)
        for frame in per_frame
    ]
    grid = torch.stack(sheets, dim=0)
    grid = torch.clamp(grid.float(), -1., 1.)
    if rescale:
        grid = (grid + 1.0) / 2.0
    # Float [0,1] -> uint8 THWC for write_video.
    grid = (grid * 255).to(torch.uint8).permute(0, 2, 3, 1)
    torchvision.io.write_video(savepath,
                               grid,
                               fps=fps,
                               video_codec='h264',
                               options={'crf': '10'})
|
||||
|
||||
|
||||
def tensor2videogrids(video, root, filename, fps, rescale=True, clamp=True):
    """Write a (b,c,t,h,w) tensor as a single mp4 grid at root/filename."""
    assert (video.dim() == 5)
    assert (isinstance(video, torch.Tensor))

    video = video.detach().cpu()
    if clamp:
        video = torch.clamp(video, -1., 1.)
    batch = video.shape[0]
    # (b,c,t,h,w) -> (t,b,c,h,w): one square grid sheet per timestep.
    per_frame = video.permute(2, 0, 1, 3, 4)
    sheets = [
        torchvision.utils.make_grid(frame, nrow=int(np.sqrt(batch)))
        for frame in per_frame
    ]
    grid = torch.stack(sheets, dim=0)
    if rescale:
        # Map [-1,1] -> [0,1] before quantizing.
        grid = (grid + 1.0) / 2.0
    grid = (grid * 255).to(torch.uint8).permute(0, 2, 3, 1)
    path = os.path.join(root, filename)
    torchvision.io.write_video(path,
                               grid,
                               fps=fps,
                               video_codec='h264',
                               options={'crf': '10'})
|
||||
|
||||
|
||||
def log_local(batch_logs, save_dir, filename, save_fps=10, rescale=True):
    """ save images and videos from images dict """
    # Keys map to either a list of caption strings, a 5-D video tensor
    # (b,c,t,h,w), or a 4-D image tensor (b,c,h,w); anything else is skipped.
    if batch_logs is None:
        return None

    def save_img_grid(grid, path, rescale):
        # Save one CHW grid tensor as an image; rescale maps [-1,1] -> [0,1].
        if rescale:
            grid = (grid + 1.0) / 2.0
        grid = grid.transpose(0, 1).transpose(1, 2).squeeze(-1)
        grid = grid.numpy()
        grid = (grid * 255).astype(np.uint8)
        os.makedirs(os.path.split(path)[0], exist_ok=True)
        Image.fromarray(grid).save(path)

    for key in batch_logs:
        value = batch_logs[key]
        if isinstance(value, list) and isinstance(value[0], str):
            # A batch of captions
            path = os.path.join(save_dir, "%s-%s.txt" % (key, filename))
            with open(path, 'w') as f:
                for i, txt in enumerate(value):
                    f.write(f'idx={i}, txt={txt}\n')
                # Redundant: the `with` block already closes the file.
                f.close()
        elif isinstance(value, torch.Tensor) and value.dim() == 5:
            # Save video grids
            video = value
            # Only save grayscale or rgb mode
            if video.shape[1] != 1 and video.shape[1] != 3:
                continue
            n = video.shape[0]
            # (b,c,t,h,w) -> (t,b,c,h,w): one column-grid sheet per timestep.
            video = video.permute(2, 0, 1, 3, 4)
            frame_grids = [
                torchvision.utils.make_grid(framesheet, nrow=int(1), padding=0)
                for framesheet in video
            ]
            grid = torch.stack(frame_grids,
                               dim=0)
            if rescale:
                grid = (grid + 1.0) / 2.0
            grid = (grid * 255).to(torch.uint8).permute(0, 2, 3, 1)
            path = os.path.join(save_dir, "%s-%s.mp4" % (key, filename))
            torchvision.io.write_video(path,
                                       grid,
                                       fps=save_fps,
                                       video_codec='h264',
                                       options={'crf': '10'})

            # Save frame sheet
            # NOTE(review): this frame-sheet grid is computed but never saved —
            # the save call below is commented out, so this section is dead code.
            img = value
            video_frames = rearrange(img, 'b c t h w -> (b t) c h w')
            t = img.shape[2]
            grid = torchvision.utils.make_grid(video_frames, nrow=t, padding=0)
            path = os.path.join(save_dir, "%s-%s.jpg" % (key, filename))
            # Save_img_grid(grid, path, rescale)
        elif isinstance(value, torch.Tensor) and value.dim() == 4:
            # Save image grids
            img = value
            # Only save grayscale or rgb mode
            if img.shape[1] != 1 and img.shape[1] != 3:
                continue
            n = img.shape[0]
            grid = torchvision.utils.make_grid(img, nrow=1, padding=0)
            path = os.path.join(save_dir, "%s-%s.jpg" % (key, filename))
            save_img_grid(grid, path, rescale)
        else:
            # Unsupported entry type/shape: silently skip.
            pass
|
||||
|
||||
|
||||
def prepare_to_log(batch_logs, max_images=100000, clamp=True):
    """Truncate each logged entry to `max_images`, move tensors to CPU, clamp to [-1, 1]."""
    if batch_logs is None:
        return None
    for key in batch_logs:
        entry = batch_logs[key]
        # Entries are batched tensors (images) or text lists (instructions).
        limit = min(entry.shape[0] if hasattr(entry, 'shape') else len(entry),
                    max_images)
        entry = entry[:limit]
        if isinstance(entry, torch.Tensor):
            entry = entry.detach().cpu()
            if clamp:
                try:
                    entry = torch.clamp(entry.float(), -1., 1.)
                except RuntimeError:
                    print("clamp_scalar_cpu not implemented for Half")
        batch_logs[key] = entry
    return batch_logs
|
||||
|
||||
|
||||
# ----------------------------------------------------------------------------------------------
|
||||
|
||||
|
||||
def fill_with_black_squares(video, desired_len: int) -> Tensor:
    """Pad `video` (T,C,H,W) at the end with black frames up to `desired_len`."""
    n_missing = desired_len - len(video)
    if n_missing <= 0:
        # Already long enough; return unchanged.
        return video

    black = torch.zeros_like(video[0]).unsqueeze(0).repeat(n_missing, 1, 1, 1)
    return torch.cat([video, black], dim=0)
|
||||
|
||||
|
||||
# ----------------------------------------------------------------------------------------------
|
||||
def load_num_videos(data_path, num_videos):
    """Load NTHWC videos from an .npz path (arr_0) or take an ndarray directly,
    keeping at most the first `num_videos` (None = all)."""
    # First argument can be either data_path of np array
    if isinstance(data_path, str):
        videos = np.load(data_path)['arr_0']  # NTHWC
    elif isinstance(data_path, np.ndarray):
        videos = data_path
    else:
        raise Exception

    if num_videos is None:
        return videos
    return videos[:num_videos, :, :, :, :]
|
||||
|
||||
|
||||
def npz_to_video_grid(data_path,
                      out_path,
                      num_frames,
                      fps,
                      num_videos=None,
                      nrow=None,
                      verbose=True):
    """Render a batch of NTHWC videos (from an .npz path or ndarray) as one
    mp4 grid at `out_path`.

    Shorter videos are padded with black frames to `num_frames`; `nrow`
    controls grid columns (default: ceil(sqrt(n))); `verbose` adds tqdm bars.
    """
    if isinstance(data_path, str):
        videos = load_num_videos(data_path, num_videos)
    elif isinstance(data_path, np.ndarray):
        videos = data_path
    else:
        raise Exception
    n, t, h, w, c = videos.shape
    # Convert each video from THWC uint8/float frames to a TCHW float tensor.
    videos_th = []
    for i in range(n):
        video = videos[i, :, :, :, :]
        images = [video[j, :, :, :] for j in range(t)]
        images = [to_tensor(img) for img in images]
        video = torch.stack(images)
        videos_th.append(video)
    # Pad every clip to the same length so they can be stacked below.
    if verbose:
        videos = [
            fill_with_black_squares(v, num_frames)
            for v in tqdm(videos_th, desc='Adding empty frames')
        ]
    else:
        videos = [fill_with_black_squares(v, num_frames)
                  for v in videos_th]  # NTCHW

    # (n,t,c,h,w) -> (t,n,c,h,w): one grid sheet per timestep.
    frame_grids = torch.stack(videos).permute(1, 0, 2, 3, 4)
    if nrow is None:
        nrow = int(np.ceil(np.sqrt(n)))
    if verbose:
        frame_grids = [
            make_grid(fs, nrow=nrow)
            for fs in tqdm(frame_grids, desc='Making grids')
        ]
    else:
        frame_grids = [make_grid(fs, nrow=nrow) for fs in frame_grids]

    if os.path.dirname(out_path) != "":
        os.makedirs(os.path.dirname(out_path), exist_ok=True)
    # [0,1] float -> uint8 THWC as expected by write_video.
    frame_grids = (torch.stack(frame_grids) * 255).to(torch.uint8).permute(
        0, 2, 3, 1)
    torchvision.io.write_video(out_path,
                               frame_grids,
                               fps=fps,
                               video_codec='h264',
                               options={'crf': '10'})
|
||||
231
src/unifolm_wma/utils/train.py
Normal file
231
src/unifolm_wma/utils/train.py
Normal file
@@ -0,0 +1,231 @@
|
||||
import os
|
||||
import logging
|
||||
|
||||
mainlogger = logging.getLogger('mainlogger')
|
||||
|
||||
import torch
|
||||
import pandas as pd
|
||||
|
||||
from omegaconf import OmegaConf
|
||||
from collections import OrderedDict
|
||||
|
||||
|
||||
def init_workspace(name, logdir, model_config, lightning_config, rank=0):
    """Create the run's directory layout under logdir/name and, on rank 0,
    persist the model/lightning configs; returns the four directory paths."""
    workdir = os.path.join(logdir, name)
    ckptdir = os.path.join(workdir, "checkpoints")
    cfgdir = os.path.join(workdir, "configs")
    loginfo = os.path.join(workdir, "loginfo")

    # Every rank creates the dirs, so a slow rank 0 cannot cause
    # missing-directory errors on the others.
    for directory in (workdir, ckptdir, cfgdir, loginfo):
        os.makedirs(directory, exist_ok=True)

    if rank == 0:
        if "callbacks" in lightning_config and 'metrics_over_trainsteps_checkpoint' in lightning_config.callbacks:
            os.makedirs(os.path.join(ckptdir, 'trainstep_checkpoints'),
                        exist_ok=True)
        OmegaConf.save(model_config, os.path.join(cfgdir, "model.yaml"))
        OmegaConf.save(OmegaConf.create({"lightning": lightning_config}),
                       os.path.join(cfgdir, "lightning.yaml"))
    return workdir, ckptdir, cfgdir, loginfo
|
||||
|
||||
|
||||
def check_config_attribute(config, name):
    """Return config.<name> when present (OmegaConf-style container), else None."""
    if name not in config:
        return None
    return getattr(config, name)
|
||||
|
||||
|
||||
def get_trainer_callbacks(lightning_config, config, logdir, ckptdir, logger):
    """Build the merged callback configuration for the Lightning Trainer.

    Starts from defaults (checkpointing, image/batch logging, LR monitoring,
    CUDA stats) and overlays anything under lightning_config.callbacks.
    `logger` is currently unused.
    """
    default_callbacks_cfg = {
        "model_checkpoint": {
            "target": "pytorch_lightning.callbacks.ModelCheckpoint",
            "params": {
                "dirpath": ckptdir,
                "filename": "{epoch}",
                "verbose": True,
                "save_last": False,
            }
        },
        "batch_logger": {
            "target": "unifolm_wma.utils.callbacks.ImageLogger",
            "params": {
                "save_dir": logdir,
                "batch_frequency": 1000,
                "max_images": 4,
                "clamp": True,
            }
        },
        "learning_rate_logger": {
            "target": "pytorch_lightning.callbacks.LearningRateMonitor",
            "params": {
                "logging_interval": "step",
                "log_momentum": False
            }
        },
        "cuda_callback": {
            "target": "unifolm_wma.utils.callbacks.CUDACallback",
        },
    }

    # Optional setting for saving checkpoints
    # When the model declares a monitor metric, keep the best 3 (lowest) checkpoints.
    monitor_metric = check_config_attribute(config.model.params, "monitor")
    if monitor_metric is not None:
        mainlogger.info(f"Monitoring {monitor_metric} as checkpoint metric.")
        default_callbacks_cfg["model_checkpoint"]["params"][
            "monitor"] = monitor_metric
        default_callbacks_cfg["model_checkpoint"]["params"]["save_top_k"] = 3
        default_callbacks_cfg["model_checkpoint"]["params"]["mode"] = "min"

    # Optional extra checkpointer that snapshots weights every N steps, never deleting.
    if 'metrics_over_trainsteps_checkpoint' in lightning_config.callbacks:
        mainlogger.info(
            'Caution: Saving checkpoints every n train steps without deleting. This might require some free space.'
        )
        default_metrics_over_trainsteps_ckpt_dict = {
            'metrics_over_trainsteps_checkpoint': {
                "target": 'pytorch_lightning.callbacks.ModelCheckpoint',
                'params': {
                    "dirpath": os.path.join(ckptdir, 'trainstep_checkpoints'),
                    "filename": "{epoch}-{step}",
                    "verbose": True,
                    'save_top_k': -1,
                    'every_n_train_steps': 10000,
                    'save_weights_only': True
                }
            }
        }
        default_callbacks_cfg.update(default_metrics_over_trainsteps_ckpt_dict)

    # User-provided callbacks override the defaults key-by-key.
    if "callbacks" in lightning_config:
        callbacks_cfg = lightning_config.callbacks
    else:
        callbacks_cfg = OmegaConf.create()
    callbacks_cfg = OmegaConf.merge(default_callbacks_cfg, callbacks_cfg)

    return callbacks_cfg
|
||||
|
||||
|
||||
def get_trainer_logger(lightning_config, logdir, on_debug):
    """Build the experiment-logger config, defaulting to TensorBoard.

    Any user overrides under lightning_config.logger are merged on top.
    `on_debug` is currently unused.
    """
    default_logger_cfgs = {
        "tensorboard": {
            "target": "pytorch_lightning.loggers.TensorBoardLogger",
            "params": {
                "save_dir": logdir,
                "name": "tensorboard",
            }
        },
        "testtube": {
            "target": "pytorch_lightning.loggers.CSVLogger",
            "params": {
                "name": "testtube",
                "save_dir": logdir,
            }
        },
    }
    os.makedirs(os.path.join(logdir, "tensorboard"), exist_ok=True)
    default_logger_cfg = default_logger_cfgs["tensorboard"]
    user_cfg = (lightning_config.logger
                if "logger" in lightning_config else OmegaConf.create())
    return OmegaConf.merge(default_logger_cfg, user_cfg)
|
||||
|
||||
|
||||
def get_trainer_strategy(lightning_config):
    """Return the training-strategy config; defaults to DDPShardedStrategy."""
    if "strategy" in lightning_config:
        # An explicit strategy replaces the default entirely (no merging).
        return lightning_config.strategy

    default_strategy_dict = {
        "target": "pytorch_lightning.strategies.DDPShardedStrategy"
    }
    return OmegaConf.merge(default_strategy_dict, OmegaConf.create())
|
||||
|
||||
|
||||
def load_checkpoints(model, model_cfg):
    """Load pretrained weights into `model` when model_cfg.pretrained_checkpoint is set.

    Handles both plain Lightning checkpoints ('state_dict' key) and DeepSpeed
    checkpoints ('module' key with a "_forward_module." prefix on every key);
    falls back to treating the file as a raw state dict.

    Fix: the fallback used a bare `except:`, which also swallowed
    KeyboardInterrupt/SystemExit; narrowed to `except Exception`.
    """
    if check_config_attribute(model_cfg, "pretrained_checkpoint"):
        pretrained_ckpt = model_cfg.pretrained_checkpoint
        assert os.path.exists(
            pretrained_ckpt
        ), "Error: Pre-trained checkpoint NOT found at:%s" % pretrained_ckpt
        mainlogger.info(">>> Load weights from pretrained checkpoint")

        pl_sd = torch.load(pretrained_ckpt, map_location="cpu")
        try:
            if 'state_dict' in pl_sd:
                model.load_state_dict(pl_sd["state_dict"], strict=False)
                mainlogger.info(
                    ">>> Loaded weights from pretrained checkpoint: %s" %
                    pretrained_ckpt)
            else:
                # DeepSpeed: weights live under 'module', each key carrying a
                # 16-character "_forward_module." prefix that must be stripped.
                new_pl_sd = OrderedDict(
                    (key[16:], value) for key, value in pl_sd['module'].items())
                model.load_state_dict(new_pl_sd, strict=False)
        except Exception:
            # Last resort: assume the file is a bare state dict.
            model.load_state_dict(pl_sd)
    else:
        mainlogger.info(">>> Start training from scratch")

    return model
|
||||
|
||||
|
||||
def set_logger(logfile, name='mainlogger'):
    """Configure logger `name` to write INFO+ records to `logfile` (truncated)
    and plain messages to the console; returns the logger."""
    logger = logging.getLogger(name)
    logger.setLevel(logging.INFO)

    file_handler = logging.FileHandler(logfile, mode='w')
    file_handler.setLevel(logging.INFO)
    file_handler.setFormatter(
        logging.Formatter("%(asctime)s-%(levelname)s: %(message)s"))

    console_handler = logging.StreamHandler()
    console_handler.setLevel(logging.DEBUG)
    console_handler.setFormatter(logging.Formatter("%(message)s"))

    logger.addHandler(file_handler)
    logger.addHandler(console_handler)
    return logger
|
||||
|
||||
|
||||
def count_parameters(model):
    """Total number of parameters in `model` (trainable and frozen)."""
    total = 0
    for param in model.parameters():
        total += param.numel()
    return total
|
||||
|
||||
|
||||
def count_trainable_parameters(model):
    """Number of parameters in `model` that require gradients."""
    trainable = (p.numel() for p in model.parameters() if p.requires_grad)
    return sum(trainable)
|
||||
|
||||
|
||||
def get_num_parameters(model):
    """Print a table of parameter counts for the model's sub-modules.

    Expects the project's model layout: model.model.diffusion_model with
    `action_unet` and `state_unet` attributes. Counts are formatted in M
    below 0.1e9 parameters, otherwise in B. Returns None (prints a DataFrame).
    """
    # (label, module) pairs; 'Total Trainable' and 'Total' both use the full
    # model but are counted differently below.
    models = [('World Model', model.model.diffusion_model),
              ('Action Head', model.model.diffusion_model.action_unet),
              ('State Head', model.model.diffusion_model.state_unet),
              ('Total Trainable', model),
              ('Total', model)]

    data = []
    # NOTE: the loop variable shadows the `model` argument; harmless because
    # the `models` list is fully built beforehand.
    for index, (name, model) in enumerate(models):
        if name == "Total Trainable":
            total_params = count_trainable_parameters(model)
        else:
            total_params = count_parameters(model)
        # Format in millions below 0.1B, otherwise in billions.
        if total_params < 0.1e9:
            total_params_value = round(total_params / 1e6, 2)
            unit = 'M'
        else:
            total_params_value = round(total_params / 1e9, 2)
            unit = 'B'

        data.append({
            'Model Name': name,
            'Params': f"{total_params_value} {unit}"
        })

    df = pd.DataFrame(data)
    print(df)
|
||||
81
src/unifolm_wma/utils/utils.py
Normal file
81
src/unifolm_wma/utils/utils.py
Normal file
@@ -0,0 +1,81 @@
|
||||
import importlib
|
||||
import numpy as np
|
||||
import cv2
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
|
||||
|
||||
def count_params(model, verbose=False):
    """Return the total number of parameters in `model`.

    When `verbose`, also print the count in millions.
    """
    n_params = sum(param.numel() for param in model.parameters())
    if verbose:
        print(
            f"{model.__class__.__name__} has {n_params*1.e-6:.2f} M params."
        )
    return n_params
|
||||
|
||||
|
||||
def check_istarget(name, para_list):
    """
    name: full name of source para
    para_list: partial name of target para

    Return True when any partial name in `para_list` occurs in `name`.
    """
    return any(para in name for para in para_list)
|
||||
|
||||
|
||||
def instantiate_from_config(config):
    """Instantiate config['target'] (dotted path) with config['params'] kwargs.

    The sentinel strings '__is_first_stage__' / '__is_unconditional__'
    (latent-diffusion convention) yield None instead of an object.
    """
    if "target" not in config:
        if config in ('__is_first_stage__', "__is_unconditional__"):
            return None
        raise KeyError("Expected key `target` to instantiate.")
    cls = get_obj_from_str(config["target"])
    kwargs = config.get("params", dict())
    return cls(**kwargs)
|
||||
|
||||
|
||||
def get_obj_from_str(string, reload=False):
    """Resolve a dotted path like 'pkg.mod.Name' to the named module attribute.

    With `reload`, the containing module is re-imported first.
    """
    module_name, obj_name = string.rsplit(".", 1)
    if reload:
        importlib.reload(importlib.import_module(module_name))
    target_module = importlib.import_module(module_name, package=None)
    return getattr(target_module, obj_name)
|
||||
|
||||
|
||||
def load_npz_from_dir(data_dir):
    """Load the 'arr_0' array from every file in `data_dir` and concatenate on axis 0.

    Fix: this module never imports `os`, so the original raised NameError at
    call time; imported locally here to keep the fix self-contained.
    """
    import os  # local import: `os` is missing from this module's imports

    data = [
        np.load(os.path.join(data_dir, data_name))['arr_0']
        for data_name in os.listdir(data_dir)
    ]
    return np.concatenate(data, axis=0)
|
||||
|
||||
|
||||
def load_npz_from_paths(data_paths):
    """Load the 'arr_0' array from each .npz path and concatenate on axis 0."""
    arrays = [np.load(path)['arr_0'] for path in data_paths]
    return np.concatenate(arrays, axis=0)
|
||||
|
||||
|
||||
def resize_numpy_image(image,
                       max_resolution=512 * 512,
                       resize_short_edge=None):
    """Resize an HWC image so its area is ~max_resolution (or its short edge is
    resize_short_edge), snapping height/width to multiples of 64."""
    h, w = image.shape[:2]
    if resize_short_edge is not None:
        k = resize_short_edge / min(h, w)
    else:
        # Area-based scale factor: sqrt of target/current pixel ratio.
        k = max_resolution / (h * w)
        k = k**0.5
    new_h = int(np.round(h * k / 64)) * 64
    new_w = int(np.round(w * k / 64)) * 64
    # cv2.resize takes (width, height).
    return cv2.resize(image, (new_w, new_h), interpolation=cv2.INTER_LANCZOS4)
|
||||
|
||||
|
||||
def setup_dist(args):
    """Bind this process to its local GPU and join the NCCL process group.

    Idempotent: returns immediately if the process group is already up.
    Expects `args.local_rank` and the env:// rendezvous variables
    (MASTER_ADDR/MASTER_PORT/RANK/WORLD_SIZE) to be set by the launcher.
    """
    if dist.is_initialized():
        return
    torch.cuda.set_device(args.local_rank)
    torch.distributed.init_process_group('nccl', init_method='env://')
|
||||
Reference in New Issue
Block a user