init commit

yuchen-x
2025-09-12 21:53:41 +08:00
parent 275a568149
commit d7be60f9fe
105 changed files with 16119 additions and 1 deletion

View File

@@ -0,0 +1,104 @@
# adapted from
# https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py
# and
# https://github.com/lucidrains/denoising-diffusion-pytorch/blob/7706bdfc6f527f58d33f84b7b522e61e6e3164b3/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py
# and
# https://github.com/openai/guided-diffusion/blob/0ba878e517b276c45d1195eb29f6f5f72659a05b/guided_diffusion/nn.py
#
# thanks!
import torch.nn as nn
from unifolm_wma.utils.utils import instantiate_from_config
def disabled_train(self, mode=True):
"""Overwrite model.train with this function to make sure train/eval mode
does not change anymore."""
return self
def zero_module(module):
"""
Zero out the parameters of a module and return it.
"""
for p in module.parameters():
p.detach().zero_()
return module
def scale_module(module, scale):
"""
Scale the parameters of a module and return it.
"""
for p in module.parameters():
p.detach().mul_(scale)
return module
def conv_nd(dims, *args, **kwargs):
"""
Create a 1D, 2D, or 3D convolution module.
"""
if dims == 1:
return nn.Conv1d(*args, **kwargs)
elif dims == 2:
return nn.Conv2d(*args, **kwargs)
elif dims == 3:
return nn.Conv3d(*args, **kwargs)
raise ValueError(f"unsupported dimensions: {dims}")
def linear(*args, **kwargs):
"""
Create a linear module.
"""
return nn.Linear(*args, **kwargs)
def avg_pool_nd(dims, *args, **kwargs):
"""
Create a 1D, 2D, or 3D average pooling module.
"""
if dims == 1:
return nn.AvgPool1d(*args, **kwargs)
elif dims == 2:
return nn.AvgPool2d(*args, **kwargs)
elif dims == 3:
return nn.AvgPool3d(*args, **kwargs)
raise ValueError(f"unsupported dimensions: {dims}")
def nonlinearity(type='silu'):
    if type == 'silu':
        return nn.SiLU()
    elif type == 'leaky_relu':
        return nn.LeakyReLU()
    raise ValueError(f"unsupported nonlinearity: {type}")
class GroupNormSpecific(nn.GroupNorm):
def forward(self, x):
return super().forward(x.float()).type(x.dtype)
def normalization(channels, num_groups=32):
"""
Make a standard normalization layer.
:param channels: number of input channels.
:return: an nn.Module for normalization.
"""
return GroupNormSpecific(num_groups, channels)
class HybridConditioner(nn.Module):
def __init__(self, c_concat_config, c_crossattn_config):
super().__init__()
self.concat_conditioner = instantiate_from_config(c_concat_config)
self.crossattn_conditioner = instantiate_from_config(
c_crossattn_config)
def forward(self, c_concat, c_crossattn):
c_concat = self.concat_conditioner(c_concat)
c_crossattn = self.crossattn_conditioner(c_crossattn)
return {'c_concat': [c_concat], 'c_crossattn': [c_crossattn]}
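# Minimal usage sketch: a zero-initialized output head built from the
# helpers above, a common pattern in diffusion UNets (the residual branch
# starts as the identity). The channel sizes and `dims=2` are illustrative
# assumptions, not values from this repo's configs.
def _example_output_head(in_channels=64, out_channels=3):
    return nn.Sequential(
        normalization(in_channels),
        nonlinearity('silu'),
        zero_module(conv_nd(2, in_channels, out_channels, 3, padding=1)),
    )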

View File

@@ -0,0 +1,226 @@
import os
import time
import logging
import json
mainlogger = logging.getLogger('mainlogger')
import torch
import torchvision
import pytorch_lightning as pl
import matplotlib.pyplot as plt
from pytorch_lightning.callbacks import Callback
from pytorch_lightning.utilities import rank_zero_only
from pytorch_lightning.utilities import rank_zero_info
from unifolm_wma.utils.save_video import log_local, prepare_to_log
STAT_DIR = '~/'
class ImageLogger(Callback):
def __init__(self, batch_frequency, max_images=8, clamp=True, rescale=True, save_dir=None, \
to_local=False, log_images_kwargs=None):
super().__init__()
self.rescale = rescale
self.batch_freq = batch_frequency
self.max_images = max_images
self.to_local = to_local
self.clamp = clamp
self.log_images_kwargs = log_images_kwargs if log_images_kwargs else {}
self.save_stat_dir = os.path.join(save_dir, "stat")
os.makedirs(self.save_stat_dir, exist_ok=True)
self.fps_stat = {}
self.fs_stat = {}
if self.to_local:
self.save_dir = os.path.join(save_dir, "images")
os.makedirs(os.path.join(self.save_dir, "train"), exist_ok=True)
os.makedirs(os.path.join(self.save_dir, "val"), exist_ok=True)
self.count_data = 0
def log_to_tensorboard(self,
pl_module,
batch_logs,
filename,
split,
save_fps=8):
""" log images and videos to tensorboard """
global_step = pl_module.global_step
for key in batch_logs:
value = batch_logs[key]
tag = "gs%d-%s/%s||%s||%s||%s" % (
global_step, split, key,
batch_logs['condition'][0].split('_')[0],
batch_logs['condition'][0].split('_')[1],
batch_logs['video_idx'])
if isinstance(value, list) and isinstance(value[0], str):
captions = ' |------| '.join(value)
pl_module.logger.experiment.add_text(tag,
captions,
global_step=global_step)
elif isinstance(value, torch.Tensor) and value.dim() == 5:
video = value
n = video.shape[0]
video = video.permute(2, 0, 1, 3, 4)
frame_grids = [
torchvision.utils.make_grid(framesheet,
nrow=int(n),
padding=0)
for framesheet in video
]
grid = torch.stack(
frame_grids, dim=0)
grid = (grid + 1.0) / 2.0
grid = grid.unsqueeze(dim=0)
pl_module.logger.experiment.add_video(tag,
grid,
fps=save_fps,
global_step=global_step)
            elif isinstance(value, torch.Tensor) and value.dim() == 4:
                img = value
                n = img.shape[0]
                grid = torchvision.utils.make_grid(img, nrow=int(n), padding=0)
                grid = (grid + 1.0) / 2.0
                pl_module.logger.experiment.add_image(tag,
                                                      grid,
                                                      global_step=global_step)
elif isinstance(value, torch.Tensor) and value.dim() == 3:
b, _, _ = value.shape
value1 = value[:b // 2, ...]
value2 = value[b // 2:, ...]
_, num_points, d = value1.shape
for i in range(d):
data1 = value1[0, :, i].cpu().detach().numpy()
data2 = value2[0, :, i].cpu().detach().numpy()
fig, ax = plt.subplots()
ax.plot(data1, label='Target 1')
ax.plot(data2, label='Sample 1')
ax.set_title(f'Comparison at dimension {i} for {key}')
ax.legend()
pl_module.logger.experiment.add_figure(
tag + f"| {key}_dim_{i}", fig, global_step=global_step)
plt.close(fig)
else:
pass
@rank_zero_only
def log_batch_imgs(self, pl_module, batch, batch_idx, split="train"):
""" generate images, then save and log to tensorboard """
# Update fps and fs statistics
batch_fps = batch['fps'].tolist()
batch_fs = batch['frame_stride'].tolist()
for num in batch_fps:
self.fps_stat[num] = self.fps_stat.get(num, 0) + 1
for num in batch_fs:
self.fs_stat[num] = self.fs_stat.get(num, 0) + 1
skip_freq = self.batch_freq if split == "train" else 5
## NOTE HAND CODE
self.count_data += 12.5 * 2
if self.count_data >= skip_freq:
self.count_data = 0
is_train = pl_module.training
if is_train:
pl_module.eval()
torch.cuda.empty_cache()
with torch.no_grad():
log_func = pl_module.log_images
batch_logs = log_func(batch,
split=split,
**self.log_images_kwargs)
# Log fps and fs statistics
with open(self.save_stat_dir + '/fps_fs_stat.json',
'w') as file:
json.dump({
'fps': self.fps_stat,
'fs': self.fs_stat
},
file,
indent=4)
batch_logs = prepare_to_log(batch_logs, self.max_images,
self.clamp)
torch.cuda.empty_cache()
filename = "ep{}_idx{}_rank{}".format(pl_module.current_epoch,
batch_idx,
pl_module.global_rank)
if self.to_local:
mainlogger.info("Log [%s] batch <%s> to local ..." %
(split, filename))
filename = "gs{}_".format(pl_module.global_step) + filename
log_local(batch_logs,
os.path.join(self.save_dir, split),
filename,
save_fps=10)
else:
mainlogger.info("Log [%s] batch <%s> to tensorboard ..." %
(split, filename))
self.log_to_tensorboard(pl_module,
batch_logs,
filename,
split,
save_fps=10)
mainlogger.info('Finish!')
if is_train:
pl_module.train()
def on_train_batch_end(self,
trainer,
pl_module,
outputs,
batch,
batch_idx,
dataloader_idx=None):
if self.batch_freq != -1 and pl_module.logdir:
self.log_batch_imgs(pl_module, batch, batch_idx, split="train")
def on_validation_batch_end(self,
trainer,
pl_module,
outputs,
batch,
batch_idx,
dataloader_idx=None):
        # Unlike validation_step(), which saves the whole validation set and
        # keeps only the latest copy, this records every validation run
        # (without overwriting) by keeping only a subset of samples.
if self.batch_freq != -1 and pl_module.logdir:
self.log_batch_imgs(pl_module, batch, batch_idx, split="val")
if hasattr(pl_module, 'calibrate_grad_norm'):
if (pl_module.calibrate_grad_norm
and batch_idx % 25 == 0) and batch_idx > 0:
self.log_gradients(trainer, pl_module, batch_idx=batch_idx)
class CUDACallback(Callback):
# See https://github.com/SeanNaren/minGPT/blob/master/mingpt/callback.py
def on_train_epoch_start(self, trainer, pl_module):
# Reset the memory use counter
        # Lightning >= 1.7 exposes the root device via the strategy
        if tuple(int(v) for v in pl.__version__.split('.')[:2]) >= (1, 7):
            gpu_index = trainer.strategy.root_device.index
        else:
            gpu_index = trainer.root_gpu
torch.cuda.reset_peak_memory_stats(gpu_index)
torch.cuda.synchronize(gpu_index)
self.start_time = time.time()
def on_train_epoch_end(self, trainer, pl_module):
        if tuple(int(v) for v in pl.__version__.split('.')[:2]) >= (1, 7):
            gpu_index = trainer.strategy.root_device.index
        else:
            gpu_index = trainer.root_gpu
torch.cuda.synchronize(gpu_index)
max_memory = torch.cuda.max_memory_allocated(gpu_index) / 2**20
epoch_time = time.time() - self.start_time
try:
max_memory = trainer.training_type_plugin.reduce(max_memory)
epoch_time = trainer.training_type_plugin.reduce(epoch_time)
rank_zero_info(f"Average Epoch time: {epoch_time:.2f} seconds")
rank_zero_info(f"Average Peak memory {max_memory:.2f}MiB")
except AttributeError:
pass
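# Minimal wiring sketch: how these callbacks could be attached to a
# pytorch_lightning Trainer. `save_dir`, `batch_frequency`, and `max_epochs`
# are illustrative assumptions; the actual values come from this repo's
# lightning configs.
def _example_trainer(save_dir="./logs"):
    callbacks = [
        ImageLogger(batch_frequency=1000, max_images=4, save_dir=save_dir),
        CUDACallback(),
    ]
    return pl.Trainer(callbacks=callbacks, max_epochs=1)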

View File

@@ -0,0 +1,111 @@
import math
from inspect import isfunction
import torch
from torch import Tensor, nn
import torch.distributed as dist
def gather_data(data, return_np=True):
''' gather data from multiple processes to one list '''
data_list = [torch.zeros_like(data) for _ in range(dist.get_world_size())]
dist.all_gather(data_list, data) # gather not supported with NCCL
if return_np:
data_list = [data.cpu().numpy() for data in data_list]
return data_list
def autocast(f):
def do_autocast(*args, **kwargs):
with torch.cuda.amp.autocast(
enabled=True,
dtype=torch.get_autocast_gpu_dtype(),
cache_enabled=torch.is_autocast_cache_enabled()):
return f(*args, **kwargs)
return do_autocast
def extract_into_tensor(a, t, x_shape):
b, *_ = t.shape
out = a.gather(-1, t)
return out.reshape(b, *((1, ) * (len(x_shape) - 1)))
def noise_like(shape, device, repeat=False):
repeat_noise = lambda: torch.randn((1, *shape[1:]), device=device).repeat(
shape[0], *((1, ) * (len(shape) - 1)))
noise = lambda: torch.randn(shape, device=device)
return repeat_noise() if repeat else noise()
def default(val, d):
if exists(val):
return val
return d() if isfunction(d) else d
def exists(val):
return val is not None
def identity(*args, **kwargs):
return nn.Identity()
def uniq(arr):
return {el: True for el in arr}.keys()
def mean_flat(tensor):
"""
Take the mean over all non-batch dimensions.
"""
return tensor.mean(dim=list(range(1, len(tensor.shape))))
def ismap(x):
if not isinstance(x, torch.Tensor):
return False
return (len(x.shape) == 4) and (x.shape[1] > 3)
def isimage(x):
if not isinstance(x, torch.Tensor):
return False
return (len(x.shape) == 4) and (x.shape[1] == 3 or x.shape[1] == 1)
def max_neg_value(t):
return -torch.finfo(t.dtype).max
def shape_to_str(x):
shape_str = "x".join([str(x) for x in x.shape])
return shape_str
def init_(tensor):
dim = tensor.shape[-1]
std = 1 / math.sqrt(dim)
tensor.uniform_(-std, std)
return tensor
ckpt = torch.utils.checkpoint.checkpoint
def checkpoint(func, inputs, params, flag):
"""
Evaluate a function without caching intermediate activations, allowing for
reduced memory at the expense of extra compute in the backward pass.
:param func: the function to evaluate.
:param inputs: the argument sequence to pass to `func`.
:param params: a sequence of parameters `func` depends on but does not
explicitly take as arguments.
:param flag: if False, disable gradient checkpointing.
"""
if flag:
return ckpt(func, *inputs, use_reentrant=False)
else:
return func(*inputs)
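# Minimal usage sketch: run a module's forward under gradient checkpointing.
# `block` is any nn.Module and `use_checkpoint` would normally come from the
# model config; both are illustrative assumptions.
def _example_checkpointed_forward(block, x, use_checkpoint=True):
    return checkpoint(block.forward, (x, ), tuple(block.parameters()),
                      use_checkpoint)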

View File

@@ -0,0 +1,242 @@
import os, sys
import numpy as np
import torch
import pytorch_lightning as pl
from functools import partial
from torch.utils.data import (DataLoader, Dataset, ConcatDataset,
WeightedRandomSampler)
from unifolm_wma.data.base import Txt2ImgIterableBaseDataset
from unifolm_wma.utils.utils import instantiate_from_config
def worker_init_fn(_):
worker_info = torch.utils.data.get_worker_info()
dataset = worker_info.dataset
worker_id = worker_info.id
    if isinstance(dataset, Txt2ImgIterableBaseDataset):
        # Shard the valid ids evenly across workers so each worker reads a
        # disjoint split
        split_size = dataset.num_records // worker_info.num_workers
        dataset.sample_ids = dataset.valid_ids[worker_id *
                                               split_size:(worker_id + 1) *
                                               split_size]
current_id = np.random.choice(len(np.random.get_state()[1]), 1)
return np.random.seed(np.random.get_state()[1][current_id] + worker_id)
else:
return np.random.seed(np.random.get_state()[1][0] + worker_id)
class WrappedDataset(Dataset):
"""Wraps an arbitrary object with __len__ and __getitem__ into a pytorch dataset"""
def __init__(self, dataset):
self.data = dataset
def __len__(self):
return len(self.data)
def __getitem__(self, idx):
return self.data[idx]
class DataModuleFromConfig(pl.LightningDataModule):
def __init__(self,
batch_size,
train=None,
validation=None,
test=None,
predict=None,
wrap=False,
num_workers=None,
shuffle_test_loader=False,
use_worker_init_fn=False,
shuffle_val_dataloader=True,
train_img=None,
dataset_and_weights=None):
super().__init__()
self.batch_size = batch_size
self.dataset_configs = dict()
self.num_workers = num_workers if num_workers is not None else batch_size * 2
self.use_worker_init_fn = use_worker_init_fn
if train is not None:
self.dataset_configs["train"] = train
self.train_dataloader = self._train_dataloader
if validation is not None:
self.dataset_configs["validation"] = validation
self.val_dataloader = partial(self._val_dataloader,
shuffle=shuffle_val_dataloader)
if test is not None:
self.dataset_configs["test"] = test
self.test_dataloader = partial(self._test_dataloader,
shuffle=shuffle_test_loader)
if predict is not None:
self.dataset_configs["predict"] = predict
self.predict_dataloader = self._predict_dataloader
self.img_loader = None
self.wrap = wrap
self.collate_fn = None
        self.dataset_weights = dataset_and_weights
        assert self.dataset_weights is not None, "dataset_and_weights must be provided"
        assert round(sum(self.dataset_weights.values()),
                     2) == 1.0, "The sum of dataset weights != 1.0"
def prepare_data(self):
pass
def setup(self, stage=None):
if 'train' in self.dataset_configs:
self.train_datasets = dict()
for dataname in self.dataset_weights:
data_dir = self.dataset_configs['train']['params']['data_dir']
transition_dir = '/'.join([data_dir, 'transitions'])
csv_file = f'{dataname}.csv'
meta_path = '/'.join([data_dir, csv_file])
self.dataset_configs['train']['params'][
'meta_path'] = meta_path
self.dataset_configs['train']['params'][
'transition_dir'] = transition_dir
self.dataset_configs['train']['params'][
'dataset_name'] = dataname
self.train_datasets[dataname] = instantiate_from_config(
self.dataset_configs['train'])
# Setup validation dataset
if 'validation' in self.dataset_configs:
self.val_datasets = dict()
for dataname in self.dataset_weights:
data_dir = self.dataset_configs['validation']['params'][
'data_dir']
transition_dir = '/'.join([data_dir, 'transitions'])
csv_file = f'{dataname}.csv'
meta_path = '/'.join([data_dir, csv_file])
self.dataset_configs['validation']['params'][
'meta_path'] = meta_path
self.dataset_configs['validation']['params'][
'transition_dir'] = transition_dir
self.dataset_configs['validation']['params'][
'dataset_name'] = dataname
self.val_datasets[dataname] = instantiate_from_config(
self.dataset_configs['validation'])
# Setup test dataset
if 'test' in self.dataset_configs:
self.test_datasets = dict()
for dataname in self.dataset_weights:
data_dir = self.dataset_configs['test']['params']['data_dir']
transition_dir = '/'.join([data_dir, 'transitions'])
csv_file = f'{dataname}.csv'
meta_path = '/'.join([data_dir, csv_file])
self.dataset_configs['test']['params']['meta_path'] = meta_path
self.dataset_configs['test']['params'][
'transition_dir'] = transition_dir
self.dataset_configs['test']['params'][
'dataset_name'] = dataname
self.test_datasets[dataname] = instantiate_from_config(
self.dataset_configs['test'])
        if self.wrap:
            for datasets in [
                    getattr(self, attr)
                    for attr in ('train_datasets', 'val_datasets',
                                 'test_datasets') if hasattr(self, attr)
            ]:
                for k in datasets:
                    datasets[k] = WrappedDataset(datasets[k])
def _train_dataloader(self):
is_iterable_dataset = False # NOTE Hand Code
if is_iterable_dataset or self.use_worker_init_fn:
init_fn = worker_init_fn
else:
init_fn = None
combined_dataset = []
sample_weights = []
for dataname, dataset in self.train_datasets.items():
combined_dataset.append(dataset)
sample_weights.append(
torch.full((len(dataset), ),
self.dataset_weights[dataname] / len(dataset)))
combined_dataset = ConcatDataset(combined_dataset)
sample_weights = torch.cat(sample_weights)
sampler = WeightedRandomSampler(sample_weights,
num_samples=len(combined_dataset),
replacement=True)
loader = DataLoader(combined_dataset,
sampler=sampler,
batch_size=self.batch_size,
num_workers=self.num_workers,
worker_init_fn=init_fn,
collate_fn=self.collate_fn,
drop_last=True
)
return loader
def _val_dataloader(self, shuffle=False):
is_iterable_dataset = False # NOTE Hand Code
if is_iterable_dataset or self.use_worker_init_fn:
init_fn = worker_init_fn
else:
init_fn = None
combined_dataset = []
sample_weights = []
for dataname, dataset in self.val_datasets.items():
combined_dataset.append(dataset)
sample_weights.append(
torch.full((len(dataset), ),
self.dataset_weights[dataname] / len(dataset)))
combined_dataset = ConcatDataset(combined_dataset)
sample_weights = torch.cat(sample_weights)
sampler = WeightedRandomSampler(sample_weights,
num_samples=len(combined_dataset),
replacement=True)
loader = DataLoader(combined_dataset,
sampler=sampler,
batch_size=self.batch_size,
num_workers=self.num_workers,
worker_init_fn=init_fn,
collate_fn=self.collate_fn)
return loader
def _test_dataloader(self, shuffle=False):
is_iterable_dataset = False # NOTE Hand Code
if is_iterable_dataset or self.use_worker_init_fn:
init_fn = worker_init_fn
else:
init_fn = None
combined_dataset = []
sample_weights = []
for dataname, dataset in self.test_datasets.items():
combined_dataset.append(dataset)
sample_weights.append(
torch.full((len(dataset), ),
self.dataset_weights[dataname] / len(dataset)))
combined_dataset = ConcatDataset(combined_dataset)
sample_weights = torch.cat(sample_weights)
sampler = WeightedRandomSampler(sample_weights,
num_samples=len(combined_dataset),
replacement=True)
loader = DataLoader(combined_dataset,
sampler=sampler,
batch_size=self.batch_size,
num_workers=self.num_workers,
worker_init_fn=init_fn,
collate_fn=self.collate_fn)
return loader
    def _predict_dataloader(self, shuffle=False):
        # The predict split is instantiated lazily here; setup() only builds
        # the train/val/test splits.
        predict_dataset = instantiate_from_config(
            self.dataset_configs["predict"])
        if isinstance(predict_dataset,
                      Txt2ImgIterableBaseDataset) or self.use_worker_init_fn:
            init_fn = worker_init_fn
        else:
            init_fn = None
        return DataLoader(
            predict_dataset,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            worker_init_fn=init_fn,
            collate_fn=self.collate_fn,
        )
def __len__(self):
count = 0
for _, values in self.train_datasets.items():
count += len(values)
return count
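# Minimal sketch of the weighting rule used by the dataloaders above: every
# sample of dataset d receives weight w_d / len(d), so datasets are drawn in
# proportion to w_d regardless of their size. The names, weights, and
# lengths below are illustrative assumptions.
def _example_per_sample_weights():
    weights = {"dataset_a": 0.7, "dataset_b": 0.3}  # must sum to 1.0
    lengths = {"dataset_a": 1000, "dataset_b": 50}
    return {d: weights[d] / lengths[d] for d in weights}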

View File

@@ -0,0 +1,191 @@
import math
import numpy as np
import torch
import torch.nn.functional as F
from einops import repeat
def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False):
"""
Create sinusoidal timestep embeddings.
:param timesteps: a 1-D Tensor of N indices, one per batch element.
These may be fractional.
:param dim: the dimension of the output.
:param max_period: controls the minimum frequency of the embeddings.
:return: an [N x dim] Tensor of positional embeddings.
"""
if not repeat_only:
half = dim // 2
freqs = torch.exp(
-math.log(max_period) *
torch.arange(start=0, end=half, dtype=torch.float32) /
half).to(device=timesteps.device)
args = timesteps[:, None].float() * freqs[None]
embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
if dim % 2:
embedding = torch.cat(
[embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
else:
embedding = repeat(timesteps, 'b -> b d', d=dim)
return embedding
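# Minimal usage sketch: embed a batch of integer timesteps as sinusoidal
# features. The embedding width of 256 is an illustrative assumption.
def _example_timestep_embedding():
    t = torch.randint(0, 1000, (4, ))
    return timestep_embedding(t, dim=256)  # shape (4, 256)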
def make_beta_schedule(schedule,
n_timestep,
linear_start=1e-4,
linear_end=2e-2,
cosine_s=8e-3):
if schedule == "linear":
betas = (torch.linspace(linear_start**0.5,
linear_end**0.5,
n_timestep,
dtype=torch.float64)**2)
elif schedule == "cosine":
timesteps = (
torch.arange(n_timestep + 1, dtype=torch.float64) / n_timestep +
cosine_s)
alphas = timesteps / (1 + cosine_s) * np.pi / 2
alphas = torch.cos(alphas).pow(2)
alphas = alphas / alphas[0]
betas = 1 - alphas[1:] / alphas[:-1]
        betas = torch.clamp(betas, min=0, max=0.999)
elif schedule == "sqrt_linear":
betas = torch.linspace(linear_start,
linear_end,
n_timestep,
dtype=torch.float64)
elif schedule == "sqrt":
betas = torch.linspace(linear_start,
linear_end,
n_timestep,
dtype=torch.float64)**0.5
else:
raise ValueError(f"schedule '{schedule}' unknown.")
return betas.numpy()
def make_ddim_timesteps(ddim_discr_method,
num_ddim_timesteps,
num_ddpm_timesteps,
verbose=True):
if ddim_discr_method == 'uniform':
c = num_ddpm_timesteps // num_ddim_timesteps
ddim_timesteps = np.asarray(list(range(0, num_ddpm_timesteps, c)))
steps_out = ddim_timesteps + 1
elif ddim_discr_method == 'uniform_trailing':
c = num_ddpm_timesteps / num_ddim_timesteps
ddim_timesteps = np.flip(np.round(np.arange(num_ddpm_timesteps, 0,
-c))).astype(np.int64)
steps_out = ddim_timesteps - 1
elif ddim_discr_method == 'quad':
ddim_timesteps = ((np.linspace(0, np.sqrt(num_ddpm_timesteps * .8),
num_ddim_timesteps))**2).astype(int)
steps_out = ddim_timesteps + 1
else:
raise NotImplementedError(
f'There is no ddim discretization method called "{ddim_discr_method}"'
)
# assert ddim_timesteps.shape[0] == num_ddim_timesteps
# add one to get the final alpha values right (the ones from first scale to data during sampling)
# steps_out = ddim_timesteps + 1
if verbose:
print(f'Selected timesteps for ddim sampler: {steps_out}')
return steps_out
def make_ddim_sampling_parameters(alphacums,
ddim_timesteps,
eta,
verbose=True):
# select alphas for computing the variance schedule
# print(f'ddim_timesteps={ddim_timesteps}, len_alphacums={len(alphacums)}')
alphas = alphacums[ddim_timesteps]
alphas_prev = np.asarray([alphacums[0]] +
alphacums[ddim_timesteps[:-1]].tolist())
# according the formula provided in https://arxiv.org/abs/2010.02502
sigmas = eta * np.sqrt(
(1 - alphas_prev) / (1 - alphas) * (1 - alphas / alphas_prev))
if verbose:
print(
f'Selected alphas for ddim sampler: a_t: {alphas}; a_(t-1): {alphas_prev}'
)
print(
f'For the chosen value of eta, which is {eta}, '
f'this results in the following sigma_t schedule for ddim sampler {sigmas}'
)
return sigmas, alphas, alphas_prev
def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999):
"""
Create a beta schedule that discretizes the given alpha_t_bar function,
which defines the cumulative product of (1-beta) over time from t = [0,1].
:param num_diffusion_timesteps: the number of betas to produce.
:param alpha_bar: a lambda that takes an argument t from 0 to 1 and
produces the cumulative product of (1-beta) up to that
part of the diffusion process.
:param max_beta: the maximum beta to use; use values lower than 1 to
prevent singularities.
"""
betas = []
for i in range(num_diffusion_timesteps):
t1 = i / num_diffusion_timesteps
t2 = (i + 1) / num_diffusion_timesteps
betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
return np.array(betas)
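# Usage sketch: the cosine schedule of Nichol & Dhariwal (improved-diffusion)
# expressed through betas_for_alpha_bar; the 0.008 offset follows that paper.
def _example_cosine_betas(num_diffusion_timesteps=1000):
    return betas_for_alpha_bar(
        num_diffusion_timesteps,
        lambda t: math.cos((t + 0.008) / 1.008 * math.pi / 2)**2,
    )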
def rescale_zero_terminal_snr(betas):
"""
Rescales betas to have zero terminal SNR Based on https://arxiv.org/pdf/2305.08891.pdf (Algorithm 1)
Args:
betas (`numpy.ndarray`):
the betas that the scheduler is being initialized with.
Returns:
`numpy.ndarray`: rescaled betas with zero terminal SNR
"""
# Convert betas to alphas_bar_sqrt
alphas = 1.0 - betas
alphas_cumprod = np.cumprod(alphas, axis=0)
alphas_bar_sqrt = np.sqrt(alphas_cumprod)
# Store old values.
alphas_bar_sqrt_0 = alphas_bar_sqrt[0].copy()
alphas_bar_sqrt_T = alphas_bar_sqrt[-1].copy()
# Shift so the last timestep is zero.
alphas_bar_sqrt -= alphas_bar_sqrt_T
# Scale so the first timestep is back to the old value.
alphas_bar_sqrt *= alphas_bar_sqrt_0 / (alphas_bar_sqrt_0 -
alphas_bar_sqrt_T)
# Convert alphas_bar_sqrt to betas
alphas_bar = alphas_bar_sqrt**2 # Revert sqrt
alphas = alphas_bar[1:] / alphas_bar[:-1] # Revert cumprod
alphas = np.concatenate([alphas_bar[0:1], alphas])
betas = 1 - alphas
return betas
def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
"""
Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and
Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4
"""
std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)),
keepdim=True)
std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)
# Rescale the results from guidance (fixes overexposure)
noise_pred_rescaled = noise_cfg * (std_text / std_cfg)
# Mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images
noise_cfg = guidance_rescale * noise_pred_rescaled + (
1 - guidance_rescale) * noise_cfg
return noise_cfg
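# Minimal usage sketch of classifier-free guidance with rescaling, following
# Section 3.4 of the paper cited above. `eps_uncond` and `eps_text` stand
# for the unconditional and text-conditioned noise predictions; the names
# and guidance scale are illustrative assumptions.
def _example_cfg(eps_uncond, eps_text, scale=7.5, guidance_rescale=0.7):
    noise_cfg = eps_uncond + scale * (eps_text - eps_uncond)
    return rescale_noise_cfg(noise_cfg, eps_text,
                             guidance_rescale=guidance_rescale)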

View File

@@ -0,0 +1,94 @@
import torch
import numpy as np
class AbstractDistribution:
def sample(self):
raise NotImplementedError()
def mode(self):
raise NotImplementedError()
class DiracDistribution(AbstractDistribution):
def __init__(self, value):
self.value = value
def sample(self):
return self.value
def mode(self):
return self.value
class DiagonalGaussianDistribution(object):
def __init__(self, parameters, deterministic=False):
self.parameters = parameters
self.mean, self.logvar = torch.chunk(parameters, 2, dim=1)
self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
self.deterministic = deterministic
self.std = torch.exp(0.5 * self.logvar)
self.var = torch.exp(self.logvar)
if self.deterministic:
self.var = self.std = torch.zeros_like(
self.mean).to(device=self.parameters.device)
def sample(self, noise=None):
if noise is None:
noise = torch.randn(self.mean.shape)
x = self.mean + self.std * noise.to(device=self.parameters.device)
return x
def kl(self, other=None):
if self.deterministic:
return torch.Tensor([0.])
else:
if other is None:
return 0.5 * torch.sum(
torch.pow(self.mean, 2) + self.var - 1.0 - self.logvar,
dim=[1, 2, 3])
else:
return 0.5 * torch.sum(
torch.pow(self.mean - other.mean, 2) / other.var +
self.var / other.var - 1.0 - self.logvar + other.logvar,
dim=[1, 2, 3])
def nll(self, sample, dims=[1, 2, 3]):
if self.deterministic:
return torch.Tensor([0.])
logtwopi = np.log(2.0 * np.pi)
return 0.5 * torch.sum(logtwopi + self.logvar +
torch.pow(sample - self.mean, 2) / self.var,
dim=dims)
def mode(self):
return self.mean
def normal_kl(mean1, logvar1, mean2, logvar2):
"""
source: https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/losses.py#L12
Compute the KL divergence between two gaussians.
Shapes are automatically broadcasted, so batches can be compared to
scalars, among other use cases.
"""
tensor = None
for obj in (mean1, logvar1, mean2, logvar2):
if isinstance(obj, torch.Tensor):
tensor = obj
break
assert tensor is not None, "at least one argument must be a Tensor"
# Force variances to be Tensors. Broadcasting helps convert scalars to
# Tensors, but it does not work for torch.exp().
logvar1, logvar2 = [
x if isinstance(x, torch.Tensor) else torch.tensor(x).to(tensor)
for x in (logvar1, logvar2)
]
return 0.5 * (-1.0 + logvar2 - logvar1 + torch.exp(logvar1 - logvar2) +
((mean1 - mean2)**2) * torch.exp(-logvar2))
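# Minimal sanity-check sketch: the KL between identical Gaussians is zero,
# and normal_kl broadcasts tensor arguments against Python scalars.
def _example_normal_kl():
    mean = torch.zeros(2, 3)
    logvar = torch.zeros(2, 3)
    kl_same = normal_kl(mean, logvar, mean, logvar)  # all zeros
    kl_std = normal_kl(mean, logvar, 1.0, 0.0)  # scalar broadcast
    return kl_same, kl_std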

View File

@@ -0,0 +1,84 @@
import torch
from torch import nn
class LitEma(nn.Module):
    def __init__(self, model, decay=0.9999, use_num_updates=True):
        super().__init__()
        if decay < 0.0 or decay > 1.0:
            raise ValueError('Decay must be between 0 and 1')
        self.m_name2s_name = {}
        self.register_buffer('decay', torch.tensor(decay, dtype=torch.float32))
        self.register_buffer(
            'num_updates',
            torch.tensor(0, dtype=torch.int)
            if use_num_updates else torch.tensor(-1, dtype=torch.int))
for name, p in model.named_parameters():
if p.requires_grad:
                # Remove '.' since it is not allowed in buffer names
s_name = name.replace('.', '')
self.m_name2s_name.update({name: s_name})
self.register_buffer(s_name, p.clone().detach().data)
self.collected_params = []
def forward(self, model):
decay = self.decay
if self.num_updates >= 0:
self.num_updates += 1
decay = min(self.decay,
(1 + self.num_updates) / (10 + self.num_updates))
one_minus_decay = 1.0 - decay
with torch.no_grad():
m_param = dict(model.named_parameters())
shadow_params = dict(self.named_buffers())
for key in m_param:
if m_param[key].requires_grad:
sname = self.m_name2s_name[key]
shadow_params[sname] = shadow_params[sname].type_as(
m_param[key])
shadow_params[sname].sub_(
one_minus_decay *
(shadow_params[sname] - m_param[key]))
                else:
                    assert key not in self.m_name2s_name
def copy_to(self, model):
m_param = dict(model.named_parameters())
shadow_params = dict(self.named_buffers())
for key in m_param:
if m_param[key].requires_grad:
m_param[key].data.copy_(
shadow_params[self.m_name2s_name[key]].data)
            else:
                assert key not in self.m_name2s_name
def store(self, parameters):
"""
Save the current parameters for restoring later.
Args:
parameters: Iterable of `torch.nn.Parameter`; the parameters to be
temporarily stored.
"""
self.collected_params = [param.clone() for param in parameters]
def restore(self, parameters):
"""
Restore the parameters stored with the `store` method.
Useful to validate the model with EMA parameters without affecting the
original optimization process. Store the parameters before the
`copy_to` method. After validation (or model saving), use this to
restore the former parameters.
Args:
parameters: Iterable of `torch.nn.Parameter`; the parameters to be
updated with the stored parameters.
"""
for c_param, param in zip(self.collected_params, parameters):
param.data.copy_(c_param.data)
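# Minimal usage sketch of the store/copy_to/restore pattern: evaluate with
# EMA weights, then put the live weights back. `model` is any nn.Module
# whose trainable parameters this LitEma instance was built from, and
# `evaluate` is an illustrative callable.
def _example_ema_eval(ema, model, evaluate):
    ema.store(model.parameters())
    ema.copy_to(model)
    try:
        return evaluate(model)
    finally:
        ema.restore(model.parameters())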

View File

@@ -0,0 +1,66 @@
"""
nn_utils.py
Utility functions and PyTorch submodule definitions.
"""
import torch
import torch.nn as nn
# === Definitions for Various Projection Modules, with Signature :: [..., in_dim] --> [..., out_dim] ===
class LinearProjector(nn.Module):
def __init__(self, vision_dim: int, llm_dim: int) -> None:
super().__init__()
self.projector = nn.Linear(vision_dim, llm_dim, bias=True)
def forward(self, img_patches: torch.Tensor) -> torch.Tensor:
return self.projector(img_patches)
class MLPProjector(nn.Module):
def __init__(self,
vision_dim: int,
llm_dim: int,
mlp_type: str = "gelu-mlp") -> None:
super().__init__()
if mlp_type == "gelu-mlp":
self.projector = nn.Sequential(
nn.Linear(vision_dim, llm_dim, bias=True),
nn.GELU(),
nn.Linear(llm_dim, llm_dim, bias=True),
)
else:
raise ValueError(
f"Projector with `{mlp_type = }` is not supported!")
def forward(self, img_patches: torch.Tensor) -> torch.Tensor:
return self.projector(img_patches)
class FusedMLPProjector(nn.Module):
def __init__(self,
fused_vision_dim: int,
llm_dim: int,
mlp_type: str = "fused-gelu-mlp") -> None:
super().__init__()
self.initial_projection_dim = fused_vision_dim * 4
if mlp_type == "fused-gelu-mlp":
self.projector = nn.Sequential(
nn.Linear(fused_vision_dim,
self.initial_projection_dim,
bias=True),
nn.GELU(),
nn.Linear(self.initial_projection_dim, llm_dim, bias=True),
nn.GELU(),
nn.Linear(llm_dim, llm_dim, bias=True),
)
else:
raise ValueError(
f"Fused Projector with `{mlp_type = }` is not supported!")
def forward(self, fused_img_patches: torch.Tensor) -> torch.Tensor:
return self.projector(fused_img_patches)
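# Minimal shape sketch: project ViT-style patch features into an LLM's
# embedding width. The batch size, patch count, and dimensions are
# illustrative assumptions.
def _example_projector():
    patches = torch.randn(2, 256, 1024)  # (batch, num_patches, vision_dim)
    projector = MLPProjector(vision_dim=1024, llm_dim=4096)
    return projector(patches).shape  # torch.Size([2, 256, 4096])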

View File

@@ -0,0 +1,147 @@
import math
import torch
import torch.nn as nn
def reshape_tensor(x, heads):
    """Split the last dim into attention heads: (b, n, h*d) -> (b, h, n, d)."""
    bs, length, _ = x.shape
    x = x.view(bs, length, heads, -1)
    x = x.transpose(1, 2)
    return x.reshape(bs, heads, length, -1)
class LinearProjector(nn.Module):
def __init__(self, input_dim: int, output_dim: int) -> None:
super().__init__()
self.projector = nn.Linear(input_dim, output_dim, bias=True)
def forward(self, x: torch.Tensor) -> torch.Tensor:
return self.projector(x)
class MLPProjector(nn.Module):
def __init__(
self,
input_dim: int,
output_dim: int,
mlp_type: str = "gelu-mlp") -> None:
super().__init__()
        if mlp_type == "gelu-mlp":
            self.projector = nn.Sequential(
                nn.Linear(input_dim, output_dim, bias=True),
                nn.GELU(approximate='tanh'),
                nn.Linear(output_dim, output_dim, bias=True),
            )
        elif mlp_type == "silu-mlp":
            self.projector = nn.Sequential(
                nn.Linear(input_dim, output_dim, bias=True),
                nn.SiLU(),
                nn.Linear(output_dim, output_dim, bias=True),
            )
        else:
            raise ValueError(
                f"Projector with `{mlp_type = }` is not supported!")
def forward(self, x: torch.Tensor) -> torch.Tensor:
return self.projector(x)
class PerceiverAttention(nn.Module):
def __init__(self, *, dim, dim_head=64, heads=8):
super().__init__()
self.scale = dim_head**-0.5
self.dim_head = dim_head
self.heads = heads
inner_dim = dim_head * heads
self.norm1 = nn.LayerNorm(dim)
self.norm2 = nn.LayerNorm(dim)
self.to_q = nn.Linear(dim, inner_dim, bias=False)
self.to_kv = nn.Linear(dim, inner_dim * 2, bias=False)
self.to_out = nn.Linear(inner_dim, dim, bias=False)
def forward(self, x, latents):
"""
Args:
x (torch.Tensor): image features
shape (b, n1, D)
latent (torch.Tensor): latent features
shape (b, n2, D)
"""
x = self.norm1(x)
latents = self.norm2(latents)
b, l, _ = latents.shape
q = self.to_q(latents)
kv_input = torch.cat((x, latents), dim=-2)
k, v = self.to_kv(kv_input).chunk(2, dim=-1)
q = reshape_tensor(q, self.heads)
k = reshape_tensor(k, self.heads)
v = reshape_tensor(v, self.heads)
scale = 1 / math.sqrt(math.sqrt(self.dim_head))
weight = (q * scale) @ (k * scale).transpose(-2, -1) # More stable with f16 than dividing afterwards
weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)
out = weight @ v
out = out.permute(0, 2, 1, 3).reshape(b, l, -1)
return self.to_out(out)
def FeedForward(dim, mult=4, ffd_type="gelu-ffd"):
    inner_dim = int(dim * mult)
    if ffd_type == "gelu-ffd":
        return nn.Sequential(
            nn.LayerNorm(dim),
            nn.Linear(dim, inner_dim, bias=False),
            nn.GELU(approximate='tanh'),
            nn.Linear(inner_dim, dim, bias=False),
        )
    elif ffd_type == "silu-ffd":
        return nn.Sequential(
            nn.LayerNorm(dim),
            nn.Linear(dim, inner_dim, bias=False),
            nn.SiLU(),
            nn.Linear(inner_dim, dim, bias=False),
        )
    else:
        raise ValueError(f"FeedForward with `{ffd_type = }` is not supported!")
class TokenProjector(nn.Module):
def __init__(
self,
dim=1024,
depth=1,
dim_head=64,
heads=16,
num_queries=16,
output_dim=1024,
ff_mult=4,
chunck_size=None,
):
super().__init__()
self.num_queries = num_queries
self.chunck_size = chunck_size
if chunck_size is not None:
num_queries = num_queries * chunck_size
self.latents = nn.Parameter(torch.randn(1, num_queries, dim) / dim**0.5)
self.proj_out = nn.Linear(dim, output_dim)
        self.norm_out = nn.LayerNorm(output_dim)
self.layers = nn.ModuleList([])
for _ in range(depth):
self.layers.append(
nn.ModuleList(
[
PerceiverAttention(dim=dim, dim_head=dim_head, heads=heads),
FeedForward(dim=dim, mult=ff_mult),
]
)
)
    def forward(self, x):
latents = self.latents.repeat(x.size(0), 1, 1)
for attn, ff in self.layers:
latents = attn(x, latents) + latents
latents = ff(latents) + latents
        latents = self.proj_out(latents)
        return self.norm_out(latents)
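# Minimal usage sketch: compress a token sequence into a fixed number of
# learned query tokens. All sizes are illustrative assumptions.
def _example_token_projector():
    projector = TokenProjector(dim=1024, depth=1, heads=16, num_queries=16,
                               output_dim=1024)
    x = torch.randn(2, 256, 1024)  # (batch, num_tokens, dim)
    return projector(x).shape  # torch.Size([2, 16, 1024])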

View File

@@ -0,0 +1,258 @@
import os
import torch
import numpy as np
import torchvision
from tqdm import tqdm
from PIL import Image
from einops import rearrange
from torch import Tensor
from torchvision.utils import make_grid
from torchvision.transforms.functional import to_tensor
def frames_to_mp4(frame_dir, output_path, fps):
def read_first_n_frames(d: os.PathLike, num_frames: int):
if num_frames:
images = [
Image.open(os.path.join(d, f))
for f in sorted(os.listdir(d))[:num_frames]
]
else:
images = [
Image.open(os.path.join(d, f)) for f in sorted(os.listdir(d))
]
images = [to_tensor(x) for x in images]
return torch.stack(images)
videos = read_first_n_frames(frame_dir, num_frames=None)
videos = videos.mul(255).to(torch.uint8).permute(0, 2, 3, 1)
torchvision.io.write_video(output_path,
videos,
fps=fps,
video_codec='h264',
options={'crf': '10'})
def tensor_to_mp4(video, savepath, fps, rescale=True, nrow=None):
"""
video: torch.Tensor, b,c,t,h,w, 0-1
if -1~1, enable rescale=True
"""
n = video.shape[0]
video = video.permute(2, 0, 1, 3, 4)
nrow = int(np.sqrt(n)) if nrow is None else nrow
frame_grids = [
torchvision.utils.make_grid(framesheet, nrow=nrow, padding=0)
for framesheet in video
]
grid = torch.stack(frame_grids,
dim=0)
grid = torch.clamp(grid.float(), -1., 1.)
if rescale:
grid = (grid + 1.0) / 2.0
grid = (grid * 255).to(torch.uint8).permute(
0, 2, 3, 1)
torchvision.io.write_video(savepath,
grid,
fps=fps,
video_codec='h264',
options={'crf': '10'})
def tensor2videogrids(video, root, filename, fps, rescale=True, clamp=True):
assert (video.dim() == 5)
assert (isinstance(video, torch.Tensor))
video = video.detach().cpu()
if clamp:
video = torch.clamp(video, -1., 1.)
n = video.shape[0]
video = video.permute(2, 0, 1, 3, 4)
frame_grids = [
torchvision.utils.make_grid(framesheet, nrow=int(np.sqrt(n)))
for framesheet in video
]
grid = torch.stack(frame_grids,
dim=0)
if rescale:
grid = (grid + 1.0) / 2.0
grid = (grid * 255).to(torch.uint8).permute(
0, 2, 3, 1)
path = os.path.join(root, filename)
torchvision.io.write_video(path,
grid,
fps=fps,
video_codec='h264',
options={'crf': '10'})
def log_local(batch_logs, save_dir, filename, save_fps=10, rescale=True):
    """Save images and videos from the batch_logs dict to local files."""
    if batch_logs is None:
        return None
def save_img_grid(grid, path, rescale):
if rescale:
grid = (grid + 1.0) / 2.0
grid = grid.transpose(0, 1).transpose(1, 2).squeeze(-1)
grid = grid.numpy()
grid = (grid * 255).astype(np.uint8)
os.makedirs(os.path.split(path)[0], exist_ok=True)
Image.fromarray(grid).save(path)
for key in batch_logs:
value = batch_logs[key]
if isinstance(value, list) and isinstance(value[0], str):
# A batch of captions
path = os.path.join(save_dir, "%s-%s.txt" % (key, filename))
            with open(path, 'w') as f:
                for i, txt in enumerate(value):
                    f.write(f'idx={i}, txt={txt}\n')
elif isinstance(value, torch.Tensor) and value.dim() == 5:
# Save video grids
video = value
# Only save grayscale or rgb mode
if video.shape[1] != 1 and video.shape[1] != 3:
continue
n = video.shape[0]
video = video.permute(2, 0, 1, 3, 4)
frame_grids = [
torchvision.utils.make_grid(framesheet, nrow=int(1), padding=0)
for framesheet in video
]
grid = torch.stack(frame_grids,
dim=0)
if rescale:
grid = (grid + 1.0) / 2.0
grid = (grid * 255).to(torch.uint8).permute(0, 2, 3, 1)
path = os.path.join(save_dir, "%s-%s.mp4" % (key, filename))
torchvision.io.write_video(path,
grid,
fps=save_fps,
video_codec='h264',
options={'crf': '10'})
            # Build a frame sheet (all frames of each sample in one row);
            # saving it is currently disabled.
            img = value
            video_frames = rearrange(img, 'b c t h w -> (b t) c h w')
            t = img.shape[2]
            grid = torchvision.utils.make_grid(video_frames, nrow=t, padding=0)
            path = os.path.join(save_dir, "%s-%s.jpg" % (key, filename))
            # save_img_grid(grid, path, rescale)
elif isinstance(value, torch.Tensor) and value.dim() == 4:
# Save image grids
img = value
# Only save grayscale or rgb mode
if img.shape[1] != 1 and img.shape[1] != 3:
continue
n = img.shape[0]
grid = torchvision.utils.make_grid(img, nrow=1, padding=0)
path = os.path.join(save_dir, "%s-%s.jpg" % (key, filename))
save_img_grid(grid, path, rescale)
else:
pass
def prepare_to_log(batch_logs, max_images=100000, clamp=True):
if batch_logs is None:
return None
for key in batch_logs:
N = batch_logs[key].shape[0] if hasattr(
batch_logs[key], 'shape') else len(batch_logs[key])
N = min(N, max_images)
batch_logs[key] = batch_logs[key][:N]
# In batch_logs: images <batched tensor> & instruction <text list>
if isinstance(batch_logs[key], torch.Tensor):
batch_logs[key] = batch_logs[key].detach().cpu()
if clamp:
try:
batch_logs[key] = torch.clamp(batch_logs[key].float(), -1.,
1.)
except RuntimeError:
print("clamp_scalar_cpu not implemented for Half")
return batch_logs
# ----------------------------------------------------------------------------------------------
def fill_with_black_squares(video, desired_len: int) -> Tensor:
if len(video) >= desired_len:
return video
return torch.cat([
video,
torch.zeros_like(video[0]).unsqueeze(0).repeat(
desired_len - len(video), 1, 1, 1),
],
dim=0)
# ----------------------------------------------------------------------------------------------
def load_num_videos(data_path, num_videos):
# First argument can be either data_path of np array
if isinstance(data_path, str):
videos = np.load(data_path)['arr_0'] # NTHWC
elif isinstance(data_path, np.ndarray):
videos = data_path
    else:
        raise TypeError(
            f"Expected a path or np.ndarray, got {type(data_path)}")
if num_videos is not None:
videos = videos[:num_videos, :, :, :, :]
return videos
def npz_to_video_grid(data_path,
out_path,
num_frames,
fps,
num_videos=None,
nrow=None,
verbose=True):
if isinstance(data_path, str):
videos = load_num_videos(data_path, num_videos)
elif isinstance(data_path, np.ndarray):
videos = data_path
    else:
        raise TypeError(
            f"Expected a path or np.ndarray, got {type(data_path)}")
n, t, h, w, c = videos.shape
videos_th = []
for i in range(n):
video = videos[i, :, :, :, :]
images = [video[j, :, :, :] for j in range(t)]
images = [to_tensor(img) for img in images]
video = torch.stack(images)
videos_th.append(video)
if verbose:
videos = [
fill_with_black_squares(v, num_frames)
for v in tqdm(videos_th, desc='Adding empty frames')
]
else:
videos = [fill_with_black_squares(v, num_frames)
for v in videos_th] # NTCHW
frame_grids = torch.stack(videos).permute(1, 0, 2, 3, 4)
if nrow is None:
nrow = int(np.ceil(np.sqrt(n)))
if verbose:
frame_grids = [
make_grid(fs, nrow=nrow)
for fs in tqdm(frame_grids, desc='Making grids')
]
else:
frame_grids = [make_grid(fs, nrow=nrow) for fs in frame_grids]
if os.path.dirname(out_path) != "":
os.makedirs(os.path.dirname(out_path), exist_ok=True)
frame_grids = (torch.stack(frame_grids) * 255).to(torch.uint8).permute(
0, 2, 3, 1)
torchvision.io.write_video(out_path,
frame_grids,
fps=fps,
video_codec='h264',
options={'crf': '10'})
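# Minimal usage sketch: write a random batch of videos in [-1, 1] to disk as
# an mp4 grid. The shape, fps, and path are illustrative assumptions (h264
# encoding requires a torchvision build with video support).
def _example_tensor_to_mp4(savepath="./demo.mp4"):
    video = torch.rand(2, 3, 16, 64, 64) * 2 - 1  # (b, c, t, h, w)
    tensor_to_mp4(video, savepath, fps=10, rescale=True)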

View File

@@ -0,0 +1,231 @@
import os
import logging
mainlogger = logging.getLogger('mainlogger')
import torch
import pandas as pd
from omegaconf import OmegaConf
from collections import OrderedDict
def init_workspace(name, logdir, model_config, lightning_config, rank=0):
workdir = os.path.join(logdir, name)
ckptdir = os.path.join(workdir, "checkpoints")
cfgdir = os.path.join(workdir, "configs")
loginfo = os.path.join(workdir, "loginfo")
    # Create logdirs and save configs (all ranks do this, to avoid a missing-directory error if rank 0 is slower)
os.makedirs(workdir, exist_ok=True)
os.makedirs(ckptdir, exist_ok=True)
os.makedirs(cfgdir, exist_ok=True)
os.makedirs(loginfo, exist_ok=True)
if rank == 0:
if "callbacks" in lightning_config and 'metrics_over_trainsteps_checkpoint' in lightning_config.callbacks:
os.makedirs(os.path.join(ckptdir, 'trainstep_checkpoints'),
exist_ok=True)
OmegaConf.save(model_config, os.path.join(cfgdir, "model.yaml"))
OmegaConf.save(OmegaConf.create({"lightning": lightning_config}),
os.path.join(cfgdir, "lightning.yaml"))
return workdir, ckptdir, cfgdir, loginfo
def check_config_attribute(config, name):
if name in config:
value = getattr(config, name)
return value
else:
return None
def get_trainer_callbacks(lightning_config, config, logdir, ckptdir, logger):
default_callbacks_cfg = {
"model_checkpoint": {
"target": "pytorch_lightning.callbacks.ModelCheckpoint",
"params": {
"dirpath": ckptdir,
"filename": "{epoch}",
"verbose": True,
"save_last": False,
}
},
"batch_logger": {
"target": "unifolm_wma.utils.callbacks.ImageLogger",
"params": {
"save_dir": logdir,
"batch_frequency": 1000,
"max_images": 4,
"clamp": True,
}
},
"learning_rate_logger": {
"target": "pytorch_lightning.callbacks.LearningRateMonitor",
"params": {
"logging_interval": "step",
"log_momentum": False
}
},
"cuda_callback": {
"target": "unifolm_wma.utils.callbacks.CUDACallback",
},
}
# Optional setting for saving checkpoints
monitor_metric = check_config_attribute(config.model.params, "monitor")
if monitor_metric is not None:
mainlogger.info(f"Monitoring {monitor_metric} as checkpoint metric.")
default_callbacks_cfg["model_checkpoint"]["params"][
"monitor"] = monitor_metric
default_callbacks_cfg["model_checkpoint"]["params"]["save_top_k"] = 3
default_callbacks_cfg["model_checkpoint"]["params"]["mode"] = "min"
if 'metrics_over_trainsteps_checkpoint' in lightning_config.callbacks:
mainlogger.info(
'Caution: Saving checkpoints every n train steps without deleting. This might require some free space.'
)
default_metrics_over_trainsteps_ckpt_dict = {
'metrics_over_trainsteps_checkpoint': {
"target": 'pytorch_lightning.callbacks.ModelCheckpoint',
'params': {
"dirpath": os.path.join(ckptdir, 'trainstep_checkpoints'),
"filename": "{epoch}-{step}",
"verbose": True,
'save_top_k': -1,
'every_n_train_steps': 10000,
'save_weights_only': True
}
}
}
default_callbacks_cfg.update(default_metrics_over_trainsteps_ckpt_dict)
if "callbacks" in lightning_config:
callbacks_cfg = lightning_config.callbacks
else:
callbacks_cfg = OmegaConf.create()
callbacks_cfg = OmegaConf.merge(default_callbacks_cfg, callbacks_cfg)
return callbacks_cfg
def get_trainer_logger(lightning_config, logdir, on_debug):
default_logger_cfgs = {
"tensorboard": {
"target": "pytorch_lightning.loggers.TensorBoardLogger",
"params": {
"save_dir": logdir,
"name": "tensorboard",
}
},
"testtube": {
"target": "pytorch_lightning.loggers.CSVLogger",
"params": {
"name": "testtube",
"save_dir": logdir,
}
},
}
os.makedirs(os.path.join(logdir, "tensorboard"), exist_ok=True)
default_logger_cfg = default_logger_cfgs["tensorboard"]
if "logger" in lightning_config:
logger_cfg = lightning_config.logger
else:
logger_cfg = OmegaConf.create()
logger_cfg = OmegaConf.merge(default_logger_cfg, logger_cfg)
return logger_cfg
def get_trainer_strategy(lightning_config):
default_strategy_dict = {
"target": "pytorch_lightning.strategies.DDPShardedStrategy"
}
if "strategy" in lightning_config:
strategy_cfg = lightning_config.strategy
return strategy_cfg
else:
strategy_cfg = OmegaConf.create()
strategy_cfg = OmegaConf.merge(default_strategy_dict, strategy_cfg)
return strategy_cfg
def load_checkpoints(model, model_cfg):
if check_config_attribute(model_cfg, "pretrained_checkpoint"):
pretrained_ckpt = model_cfg.pretrained_checkpoint
        assert os.path.exists(
            pretrained_ckpt
        ), "Error: Pre-trained checkpoint NOT found at: %s" % pretrained_ckpt
mainlogger.info(">>> Load weights from pretrained checkpoint")
pl_sd = torch.load(pretrained_ckpt, map_location="cpu")
try:
if 'state_dict' in pl_sd.keys():
model.load_state_dict(pl_sd["state_dict"], strict=False)
mainlogger.info(
">>> Loaded weights from pretrained checkpoint: %s" %
pretrained_ckpt)
            else:
                # deepspeed checkpoint: strip the 16-character
                # '_forward_module.' prefix from every key
                new_pl_sd = OrderedDict()
                for key in pl_sd['module'].keys():
                    new_pl_sd[key[16:]] = pl_sd['module'][key]
                model.load_state_dict(new_pl_sd, strict=False)
        except Exception:
            model.load_state_dict(pl_sd)
else:
mainlogger.info(">>> Start training from scratch")
return model
def set_logger(logfile, name='mainlogger'):
logger = logging.getLogger(name)
logger.setLevel(logging.INFO)
fh = logging.FileHandler(logfile, mode='w')
fh.setLevel(logging.INFO)
ch = logging.StreamHandler()
ch.setLevel(logging.DEBUG)
fh.setFormatter(
logging.Formatter("%(asctime)s-%(levelname)s: %(message)s"))
ch.setFormatter(logging.Formatter("%(message)s"))
logger.addHandler(fh)
logger.addHandler(ch)
return logger
def count_parameters(model):
return sum(p.numel() for p in model.parameters())
def count_trainable_parameters(model):
return sum(p.numel() for p in model.parameters() if p.requires_grad)
def get_num_parameters(model):
models = [('World Model', model.model.diffusion_model),
('Action Head', model.model.diffusion_model.action_unet),
('State Head', model.model.diffusion_model.state_unet),
('Total Trainable', model),
('Total', model)]
    data = []
    for name, module in models:
        if name == "Total Trainable":
            total_params = count_trainable_parameters(module)
        else:
            total_params = count_parameters(module)
if total_params < 0.1e9:
total_params_value = round(total_params / 1e6, 2)
unit = 'M'
else:
total_params_value = round(total_params / 1e9, 2)
unit = 'B'
data.append({
'Model Name': name,
'Params': f"{total_params_value} {unit}"
})
df = pd.DataFrame(data)
print(df)

View File

@@ -0,0 +1,81 @@
import importlib
import os
import numpy as np
import cv2
import torch
import torch.distributed as dist
def count_params(model, verbose=False):
total_params = sum(p.numel() for p in model.parameters())
if verbose:
print(
f"{model.__class__.__name__} has {total_params*1.e-6:.2f} M params."
)
return total_params
def check_istarget(name, para_list):
"""
name: full name of source para
para_list: partial name of target para
"""
istarget = False
for para in para_list:
if para in name:
return True
return istarget
def instantiate_from_config(config):
if not "target" in config:
if config == '__is_first_stage__':
return None
elif config == "__is_unconditional__":
return None
raise KeyError("Expected key `target` to instantiate.")
return get_obj_from_str(config["target"])(**config.get("params", dict()))
def get_obj_from_str(string, reload=False):
module, cls = string.rsplit(".", 1)
if reload:
module_imp = importlib.import_module(module)
importlib.reload(module_imp)
return getattr(importlib.import_module(module, package=None), cls)
def load_npz_from_dir(data_dir):
data = [
np.load(os.path.join(data_dir, data_name))['arr_0']
for data_name in os.listdir(data_dir)
]
data = np.concatenate(data, axis=0)
return data
def load_npz_from_paths(data_paths):
data = [np.load(data_path)['arr_0'] for data_path in data_paths]
data = np.concatenate(data, axis=0)
return data
def resize_numpy_image(image,
max_resolution=512 * 512,
resize_short_edge=None):
h, w = image.shape[:2]
if resize_short_edge is not None:
k = resize_short_edge / min(h, w)
else:
k = max_resolution / (h * w)
k = k**0.5
h = int(np.round(h * k / 64)) * 64
w = int(np.round(w * k / 64)) * 64
image = cv2.resize(image, (w, h), interpolation=cv2.INTER_LANCZOS4)
return image
def setup_dist(args):
if dist.is_initialized():
return
torch.cuda.set_device(args.local_rank)
torch.distributed.init_process_group('nccl', init_method='env://')
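# Minimal usage sketch: instantiate_from_config builds an object from a
# {"target": ..., "params": ...} mapping; the torch.nn.Linear config below
# is an illustrative assumption.
def _example_instantiate():
    config = {
        "target": "torch.nn.Linear",
        "params": {"in_features": 8, "out_features": 4},
    }
    return instantiate_from_config(config)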