init commit

This commit is contained in:
yuchen-x
2025-09-12 21:53:41 +08:00
parent 275a568149
commit d7be60f9fe
105 changed files with 16119 additions and 1 deletion

View File

@@ -0,0 +1,541 @@
import argparse, os, glob
import datetime, time
import pandas as pd
import torch
import torchvision
import torchvision.transforms as transforms
import random
from pytorch_lightning import seed_everything
from PIL import Image
from omegaconf import OmegaConf
from tqdm import tqdm
from einops import rearrange, repeat
from collections import OrderedDict
from unifolm_wma.models.samplers.ddim import DDIMSampler
from unifolm_wma.utils.utils import instantiate_from_config
def get_filelist(data_dir: str, postfixes: list[str]) -> list[str]:
    """
    Collect files under ``data_dir`` whose extension is in ``postfixes``.

    Args:
        data_dir (str): Directory to search (non-recursive).
        postfixes (list[str]): File extensions without the dot (e.g. ['csv', 'jpg']).

    Returns:
        list[str]: Sorted list of matched file paths.
    """
    matches: list[str] = []
    for postfix in postfixes:
        matches.extend(glob.glob(os.path.join(data_dir, f"*.{postfix}")))
    return sorted(matches)
def load_model_checkpoint(model: torch.nn.Module,
                          ckpt: str) -> torch.nn.Module:
    """
    Load model weights from checkpoint file.

    Supports three checkpoint layouts:
      * a checkpoint with a top-level ``"state_dict"`` key (plain / Lightning),
      * an older layout whose ``framestride_embed`` keys must be renamed to
        ``fps_embedding`` (retried on load failure),
      * a DeepSpeed checkpoint storing weights under ``"module"`` with a
        fixed 16-character prefix on every key.

    Args:
        model (torch.nn.Module): The model to load weights into.
        ckpt (str): Path to the checkpoint file.

    Returns:
        torch.nn.Module: The same model instance with weights loaded.
    """
    state_dict = torch.load(ckpt, map_location="cpu")
    if "state_dict" in state_dict:
        state_dict = state_dict["state_dict"]
        try:
            loaded = model.load_state_dict(state_dict, strict=False)
            print("Missing keys:")
            for k in loaded.missing_keys:
                print(f" {k}")
            print("Unexpected keys:")
            for k in loaded.unexpected_keys:
                print(f" {k}")
        # Narrowed from a bare `except:` (which also swallowed
        # KeyboardInterrupt/SystemExit); load_state_dict raises RuntimeError
        # on key/shape mismatch.
        except RuntimeError:
            # Rename the keys for 256x256 model and retry.
            new_pl_sd = OrderedDict(state_dict)
            for k in list(new_pl_sd.keys()):
                if "framestride_embed" in k:
                    new_key = k.replace("framestride_embed", "fps_embedding")
                    new_pl_sd[new_key] = new_pl_sd.pop(k)
            model.load_state_dict(new_pl_sd, strict=False)
    else:
        # DeepSpeed layout: weights live under 'module' with a fixed
        # 16-character prefix on every key, which is stripped here.
        new_pl_sd = OrderedDict()
        for key in state_dict['module'].keys():
            new_pl_sd[key[16:]] = state_dict['module'][key]
        model.load_state_dict(new_pl_sd)
    print('>>> model checkpoint loaded.')
    return model
def load_prompts(prompt_file: str) -> list[str]:
    """
    Load prompts from a text file, one per line.

    Blank and whitespace-only lines are skipped.

    Args:
        prompt_file (str): Path to the prompt file.

    Returns:
        list[str]: List of non-empty prompt strings.
    """
    # 'with' guarantees the handle is closed even if reading raises
    # (the original leaked the handle on exception).
    with open(prompt_file, 'r') as f:
        return [stripped for line in f if (stripped := line.strip())]
def load_data_prompts(
    data_dir: str,
    savedir: str,
    video_size: tuple[int, int] = (256, 256),
    video_frames: int = 16
) -> tuple[list[str], list[torch.Tensor], list[str], list[float], list[float],
           list[int]]:
    """
    Load image prompts, process them into video format, and retrieve metadata.

    Each image in ``data_dir`` is resized/cropped/normalized and tiled along
    the time axis into a ``video_frames``-long clip. Metadata (instruction,
    fps, fs, num_gen) is looked up by video id — the image filename without
    its extension — in the single CSV found in ``data_dir``. Images that were
    already inferenced (a result video exists under ``savedir``) or that have
    no CSV row are skipped.

    Args:
        data_dir (str): Directory containing images and a CSV file.
        savedir (str): Output directory to check if inference was already done.
        video_size (tuple[int, int], optional): Target (height, width) of video frames.
        video_frames (int, optional): Number of frames in each video.

    Returns:
        tuple: (filenames, video tensors, prompts, fps values, fs values, num_generations)
    """
    transform = transforms.Compose([
        transforms.Resize(min(video_size)),
        transforms.CenterCrop(video_size),
        transforms.ToTensor(),
        # Map [0, 1] tensors to [-1, 1], the range the diffusion model expects.
        transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
    ])
    # Load prompt csv
    prompt_file = get_filelist(data_dir, ['csv'])
    assert len(prompt_file) > 0, "Error: found NO image prompt file!"
    prompt_csv = pd.read_csv(prompt_file[0])
    # Normalize ids to str once (was re-applied on every loop iteration).
    prompt_csv['videoid'] = prompt_csv['videoid'].map(str)
    # Load image prompts
    file_list = get_filelist(data_dir, ['jpg', 'png', 'jpeg', 'JPEG', 'PNG'])
    data_list = []
    filename_list = []
    prompt_list = []
    fps_list = []
    fs_list = []
    num_gen_list = []
    for file_path in file_list:
        _, filename = os.path.split(file_path)
        if is_inferenced(savedir, filename):
            continue
        # NOTE(review): [:-4] assumes a 3-char extension; '.jpeg'/'.JPEG'
        # files would keep a trailing dot — preserved from the original.
        video_id = filename[:-4]
        rows = prompt_csv[prompt_csv['videoid'] == video_id]
        if rows.empty:
            # No metadata row for this image; skip it.
            continue
        row = rows.iloc[0]
        # Decode the image only after the skip checks (avoids wasted work).
        image = Image.open(file_path).convert('RGB')
        image_tensor = transform(image).unsqueeze(1)
        # Tile the single frame along time: (c, 1, h, w) -> (c, T, h, w).
        frame_tensor = repeat(image_tensor,
                              'c t h w -> c (repeat t) h w',
                              repeat=video_frames)
        data_list.append(frame_tensor)
        filename_list.append(filename)
        prompt_list.append(row['instruction'])
        fps_list.append(row['fps'])
        fs_list.append(row['fs'])
        num_gen_list.append(int(row['num_gen']))
    return filename_list, data_list, prompt_list, fps_list, fs_list, num_gen_list
def is_inferenced(save_dir: str, filename: str) -> bool:
    """
    Check whether a result video for ``filename`` already exists.

    Args:
        save_dir (str): Directory where results are saved.
        filename (str): Base filename; its extension is dropped to build the .mp4 name.

    Returns:
        bool: True if the result file exists, else False.
    """
    stem = filename[:-4]
    return os.path.exists(os.path.join(save_dir, f"{stem}.mp4"))
def save_results_seperate(prompt: str | list[str],
                          samples: torch.Tensor,
                          filename: str,
                          fakedir: str,
                          fps: int = 8) -> None:
    """
    Save generated video samples as .mp4 files.

    Args:
        prompt (str | list[str]): The prompt text (first entry used if a list).
        samples (torch.Tensor): Generated video tensor of shape [B, C, T, H, W]
            with values in [-1, 1].
        filename (str): Output filename; its extension is replaced by .mp4.
        fakedir (str): Directory to save output videos.
        fps (int, optional): Frames per second.

    Returns:
        None
    """
    if isinstance(prompt, list):
        prompt = prompt[0]
    # Save video
    for save_dir, video in zip([fakedir], [samples]):
        if video is None:
            continue
        clip = torch.clamp(video.detach().cpu().float(), -1., 1.)
        out_path = os.path.join(save_dir, f'{filename.split(".")[0]}.mp4')
        for sample in clip:
            # [-1, 1] -> [0, 255] uint8, channel-last (T, H, W, C) as
            # torchvision's write_video expects.
            frames = (sample + 1.0) / 2.0
            frames = (frames * 255).to(torch.uint8).permute(1, 2, 3, 0)
            torchvision.io.write_video(out_path,
                                       frames,
                                       fps=fps,
                                       video_codec='h264',
                                       options={'crf': '0'})
def get_latent_z(model: torch.nn.Module, videos: torch.Tensor) -> torch.Tensor:
    """
    Encode videos into the first-stage latent space.

    Args:
        model (torch.nn.Module): Model exposing ``encode_first_stage``.
        videos (torch.Tensor): Video tensor of shape [B, C, T, H, W].

    Returns:
        torch.Tensor: Latent representation of shape [B, C', T, H', W'].
    """
    b, c, t, h, w = videos.shape
    # Fold time into the batch dimension so every frame is encoded independently.
    frames = videos.permute(0, 2, 1, 3, 4).reshape(b * t, c, h, w)
    latents = model.encode_first_stage(frames)
    # Restore the [B, C, T, H, W] layout.
    latents = latents.reshape(b, t, *latents.shape[1:]).permute(0, 2, 1, 3, 4)
    return latents
def image_guided_synthesis(model: torch.nn.Module,
                           prompts: list[str],
                           videos: torch.Tensor,
                           noise_shape: list[int],
                           ddim_steps: int = 50,
                           ddim_eta: float = 1.0,
                           unconditional_guidance_scale: float = 1.0,
                           fs: int | None = None,
                           text_input: bool = False,
                           timestep_spacing: str = 'uniform',
                           guidance_rescale: float = 0.0,
                           **kwargs) -> torch.Tensor:
    """
    Run DDIM-based image-to-video synthesis with hybrid/text+image guidance.

    Args:
        model (torch.nn.Module): Diffusion model.
        prompts (list[str]): Text prompts (replaced by empty strings when
            ``text_input`` is False).
        videos (torch.Tensor): Input images/videos of shape [B, C, T, H, W];
            only the first frame is used for image conditioning.
        noise_shape (list[int]): Latent noise shape [B, C, T, H, W].
        ddim_steps (int, optional): Number of DDIM steps.
        ddim_eta (float, optional): Eta value for DDIM.
        unconditional_guidance_scale (float, optional): Guidance scale.
        fs (int | None, optional): Frame-stride input for the sampler,
            broadcast across the batch.
        text_input (bool, optional): If True, use text guidance.
        timestep_spacing (str, optional): Timestep schedule spacing.
        guidance_rescale (float, optional): Rescale guidance effect.
        **kwargs: Additional sampler args.

    Returns:
        torch.Tensor: Synthesized videos of shape [B, 1, C, T, H, W].
    """
    ddim_sampler = DDIMSampler(model)
    batch_size = noise_shape[0]
    # Broadcast the scalar frame stride to one value per batch element.
    fs = torch.tensor([fs] * batch_size, dtype=torch.long, device=model.device)
    if not text_input:
        prompts = [""] * batch_size
    b, c, t, h, w = videos.shape
    # Image conditioning: embed and project the first frame, then lay the
    # embedding out per-timestep as (b t) for cross-attention.
    img = videos[:, :, 0]
    img_emb = model.embedder(img)
    img_emb = model.image_proj_model(img_emb)
    img_emb = rearrange(img_emb, 'b (t l) c -> (b t) l c', t=t)
    # Text conditioning, repeated per timestep, concatenated with the image
    # embedding along the token dimension.
    cond_emb = model.get_learned_conditioning(prompts)
    cond_emb = cond_emb.repeat_interleave(repeats=t, dim=0)
    cond = {"c_crossattn": [torch.cat([cond_emb, img_emb], dim=1)]}
    if model.model.conditioning_key == 'hybrid':
        # Hybrid mode additionally concatenates the first-frame latent,
        # tiled over time, onto the model's input channels.
        z = get_latent_z(model, videos)
        img_cat_cond = z[:, :, :1, :, :]
        img_cat_cond = repeat(img_cat_cond,
                              'b c t h w -> b c (repeat t) h w',
                              repeat=z.shape[2])
        cond["c_concat"] = [img_cat_cond]
    # No unconditional branch or mask is used here; classifier-free guidance
    # behaviour is controlled solely by unconditional_guidance_scale.
    uc = None
    cond_mask = None
    kwargs.update({"unconditional_conditioning_img_nonetext": None})
    batch_variants = []
    if ddim_sampler is not None:
        samples, _, _, _ = ddim_sampler.sample(
            S=ddim_steps,
            batch_size=batch_size,
            shape=noise_shape[1:],
            conditioning=cond,
            eta=ddim_eta,
            mask=cond_mask,
            x0=None,
            verbose=False,
            unconditional_guidance_scale=unconditional_guidance_scale,
            unconditional_conditioning=uc,
            fs=fs,
            timestep_spacing=timestep_spacing,
            guidance_rescale=guidance_rescale,
            **kwargs)
        # Reconstruct from latent to pixel space
        batch_images = model.decode_first_stage(samples)
        batch_variants.append(batch_images)
    # Stack the (single) variant and move batch first: [B, 1, C, T, H, W].
    batch_variants = torch.stack(batch_variants)
    return batch_variants.permute(1, 0, 2, 3, 4, 5)
def run_inference(args: argparse.Namespace, gpu_num: int, gpu_no: int) -> None:
    """
    Run inference pipeline on prompts and image inputs.

    Loads the model from ``args.config``/``args.ckpt_path``, reads image
    prompts and per-sample metadata from ``args.prompt_dir``, autoregressively
    generates ``num_gen`` clips per sample (each conditioned on the last frame
    of the previous clip), and writes the concatenated result as an .mp4
    under ``args.savedir``/samples.

    Args:
        args (argparse.Namespace): Parsed command-line arguments.
        gpu_num (int): Number of GPUs (used to shard samples across ranks).
        gpu_no (int): Index of the current GPU.

    Returns:
        None
    """
    # Load config
    config = OmegaConf.load(args.config)
    # Set use_checkpoint as False as when using deepspeed, it encounters an error "deepspeed backend not set"
    config['model']['params']['wma_config']['params'][
        'use_checkpoint'] = False
    model = instantiate_from_config(config.model)
    model = model.cuda(gpu_no)
    model.perframe_ae = args.perframe_ae
    assert os.path.exists(args.ckpt_path), "Error: checkpoint Not Found!"
    model = load_model_checkpoint(model, args.ckpt_path)
    model.eval()
    # Run over data
    assert (args.height % 16 == 0) and (
        args.width % 16
        == 0), "Error: image size [h,w] should be multiples of 16!"
    assert args.bs == 1, "Current implementation only support [batch size = 1]!"
    # Get latent noise shape; the first-stage autoencoder downsamples
    # spatially by 8x, hence height//8 and width//8.
    h, w = args.height // 8, args.width // 8
    channels = model.model.diffusion_model.out_channels
    n_frames = args.video_length
    print(f'>>> Generate {n_frames} frames under each generation ...')
    noise_shape = [args.bs, channels, n_frames, h, w]
    fakedir = os.path.join(args.savedir, "samples")
    os.makedirs(fakedir, exist_ok=True)
    # Prompt file setting
    assert os.path.exists(args.prompt_dir), "Error: prompt file Not Found!"
    filename_list, data_list, prompt_list, fps_list, fs_list, num_gen_list = load_data_prompts(
        args.prompt_dir,
        args.savedir,
        video_size=(args.height, args.width),
        video_frames=n_frames)
    # Shard samples evenly across ranks (any remainder samples are dropped).
    num_samples = len(prompt_list)
    samples_split = num_samples // gpu_num
    print('>>> Prompts testing [rank:%d] %d/%d samples loaded.' %
          (gpu_no, samples_split, num_samples))
    indices = list(range(samples_split * gpu_no, samples_split * (gpu_no + 1)))
    fps_list_rank = [fps_list[i] for i in indices]
    fs_list_rank = [fs_list[i] for i in indices]
    prompt_list_rank = [prompt_list[i] for i in indices]
    data_list_rank = [data_list[i] for i in indices]
    filename_list_rank = [filename_list[i] for i in indices]
    with torch.no_grad(), torch.cuda.amp.autocast():
        # Create a new result csv
        for idx, indice in enumerate(
                tqdm(range(0, len(prompt_list_rank), args.bs),
                     desc=f'Sample batch')):
            fps = fps_list_rank[indice:indice + args.bs]
            fs = fs_list_rank[indice:indice + args.bs]
            prompts = prompt_list_rank[indice:indice + args.bs]
            # NOTE(review): this indexes the UNSHARDED num_gen_list while all
            # other metadata uses the *_rank lists; only correct for
            # gpu_num == 1 / gpu_no == 0 — confirm before multi-GPU use.
            num_gen = num_gen_list[indice:indice + args.bs]
            videos = data_list_rank[indice:indice + args.bs]
            filenames = filename_list_rank[indice:indice + args.bs]
            if isinstance(videos, list):
                videos = torch.stack(videos, dim=0).to("cuda")
            else:
                videos = videos.unsqueeze(0).to("cuda")
            results = []
            print(
                f">>> {prompts[0]}, frame_stride:{fs[0]}, and {num_gen[0]} generation ..."
            )
            # Autoregressive rollout: each generation is conditioned on the
            # last frame of the previous one, tiled across all timesteps.
            for _ in range(num_gen[0]):
                batch_samples = image_guided_synthesis(
                    model, prompts, videos, noise_shape, args.ddim_steps,
                    args.ddim_eta, args.unconditional_guidance_scale,
                    fps[0] // fs[0], args.text_input, args.timestep_spacing,
                    args.guidance_rescale)
                results.extend(batch_samples)
                videos = repeat(batch_samples[0][:, :, -1, :, :].unsqueeze(2),
                                'b c t h w -> b c (repeat t) h w',
                                repeat=batch_samples[0].shape[2])
            # Concatenate all generations along the time axis before saving.
            batch_samples = [torch.concat(results, axis=2)]
            # Save each example individually
            for nn, samples in enumerate(batch_samples):
                prompt = prompts[nn]
                filename = filenames[nn]
                save_results_seperate(prompt,
                                      samples,
                                      filename,
                                      fakedir,
                                      fps=8)
def get_parser() -> argparse.ArgumentParser:
    """
    Build the command-line argument parser for the inference script.

    Returns:
        argparse.ArgumentParser: Parser for command-line arguments.
    """
    parser = argparse.ArgumentParser()
    add = parser.add_argument
    add("--savedir", type=str, default=None, help="Path to save the results.")
    add("--ckpt_path", type=str, default=None,
        help="Path to the model checkpoint.")
    add("--config", type=str, help="Path to the YAML configuration file.")
    add("--prompt_dir", type=str, default=None,
        help="Directory containing videos and corresponding prompts.")
    add("--ddim_steps", type=int, default=50,
        help="Number of DDIM steps. If non-positive, DDPM is used instead.")
    add("--ddim_eta", type=float, default=1.0,
        help="Eta for DDIM sampling. Set to 0.0 for deterministic results.")
    add("--bs", type=int, default=1,
        help="Batch size for inference. Must be 1.")
    add("--height", type=int, default=320,
        help="Height of the generated images in pixels.")
    add("--width", type=int, default=512,
        help="Width of the generated images in pixels.")
    add("--unconditional_guidance_scale", type=float, default=1.0,
        help="Scale for classifier-free guidance during sampling.")
    add("--seed", type=int, default=123,
        help="Random seed for reproducibility.")
    add("--video_length", type=int, default=16,
        help="Number of frames in the generated video.")
    add("--text_input", action='store_true', default=False,
        help=
        "Whether to provide a text prompt as input to the image-to-video model."
        )
    add("--timestep_spacing", type=str, default="uniform",
        help=
        "Strategy for timestep scaling. See Table 2 in the paper: 'Common Diffusion Noise Schedules and Sample Steps are Flawed' (https://huggingface.co/papers/2305.08891)."
        )
    add("--guidance_rescale", type=float, default=0.0,
        help=
        "Rescale factor for guidance as discussed in 'Common Diffusion Noise Schedules and Sample Steps are Flawed' (https://huggingface.co/papers/2305.08891)."
        )
    add("--perframe_ae", action='store_true', default=False,
        help=
        "Use per-frame autoencoder decoding to reduce GPU memory usage. Recommended for models with resolutions like 576x1024."
        )
    return parser
if __name__ == '__main__':
    parser = get_parser()
    args = parser.parse_args()
    # A negative seed means "pick one at random"; otherwise use it as given.
    seed = args.seed if args.seed >= 0 else random.randint(0, 2**31)
    seed_everything(seed)
    # Single-process inference: rank 0 of 1 GPU.
    rank, gpu_num = 0, 1
    run_inference(args, gpu_num, rank)

View File

@@ -0,0 +1,812 @@
import argparse, os, glob
import pandas as pd
import random
import torch
import torchvision
import h5py
import numpy as np
import logging
import einops
import warnings
import imageio
from pytorch_lightning import seed_everything
from omegaconf import OmegaConf
from tqdm import tqdm
from einops import rearrange, repeat
from collections import OrderedDict
from torch import nn
from eval_utils import populate_queues, log_to_tensorboard
from collections import deque
from torch import Tensor
from torch.utils.tensorboard import SummaryWriter
from PIL import Image
from unifolm_wma.models.samplers.ddim import DDIMSampler
from unifolm_wma.utils.utils import instantiate_from_config
def get_device_from_parameters(module: nn.Module) -> torch.device:
    """Return the device of ``module``, inferred from its first parameter.

    Args:
        module (nn.Module): The model whose device is to be inferred.

    Returns:
        torch.device: The device of the model's parameters.
    """
    first_param = next(iter(module.parameters()))
    return first_param.device
def write_video(video_path: str, stacked_frames: list, fps: int) -> None:
    """Write a sequence of frames to a video file via imageio.

    Args:
        video_path (str): Output path for the video.
        stacked_frames (list): List of image frames.
        fps (int): Frames per second for the video.
    """
    # imageio pulls in pkg_resources; silence its deprecation warning, which
    # is noise for callers of this helper.
    with warnings.catch_warnings():
        warnings.filterwarnings(
            "ignore",
            "pkg_resources is deprecated as an API",
            category=DeprecationWarning,
        )
        imageio.mimsave(video_path, stacked_frames, fps=fps)
def get_filelist(data_dir: str, postfixes: list[str]) -> list[str]:
    """Return sorted list of files in a directory matching specified postfixes.

    Args:
        data_dir (str): Directory path to search in (non-recursive).
        postfixes (list[str]): List of file extensions to match (without dot).

    Returns:
        list[str]: Sorted list of file paths.
    """
    hits = [
        path
        for postfix in postfixes
        for path in glob.glob(os.path.join(data_dir, f"*.{postfix}"))
    ]
    return sorted(hits)
def load_model_checkpoint(model: nn.Module, ckpt: str) -> nn.Module:
    """Load model weights from checkpoint file.

    Handles three layouts: a checkpoint with a top-level ``"state_dict"``
    key; an older variant whose ``framestride_embed`` keys must be renamed
    to ``fps_embedding`` (retried on load failure); and a DeepSpeed
    checkpoint storing weights under ``"module"`` with a fixed 16-character
    prefix on every key.

    Args:
        model (nn.Module): Model instance.
        ckpt (str): Path to the checkpoint file.

    Returns:
        nn.Module: The same model instance with loaded weights.
    """
    state_dict = torch.load(ckpt, map_location="cpu")
    if "state_dict" in state_dict:
        state_dict = state_dict["state_dict"]
        try:
            model.load_state_dict(state_dict, strict=True)
        # Narrowed from a bare `except:` (which also swallowed
        # KeyboardInterrupt/SystemExit); strict loading raises RuntimeError
        # on key/shape mismatch.
        except RuntimeError:
            # Rename legacy keys and retry.
            new_pl_sd = OrderedDict(state_dict)
            for k in list(new_pl_sd.keys()):
                if "framestride_embed" in k:
                    new_key = k.replace("framestride_embed", "fps_embedding")
                    new_pl_sd[new_key] = new_pl_sd.pop(k)
            model.load_state_dict(new_pl_sd, strict=True)
    else:
        # DeepSpeed layout: strip the fixed 16-character prefix on each key.
        new_pl_sd = OrderedDict()
        for key in state_dict['module'].keys():
            new_pl_sd[key[16:]] = state_dict['module'][key]
        model.load_state_dict(new_pl_sd)
    print('>>> model checkpoint loaded.')
    return model
def is_inferenced(save_dir: str, filename: str) -> bool:
    """Check if a given filename has already been processed and saved.

    Args:
        save_dir (str): Directory where results are saved.
        filename (str): Name of the file to check (extension is dropped).

    Returns:
        bool: True if processed file exists, False otherwise.
    """
    candidate = os.path.join(save_dir, "samples_separate",
                             f"{filename[:-4]}_sample0.mp4")
    return os.path.exists(candidate)
def save_results(video: Tensor, filename: str, fps: int = 8) -> None:
    """Save video tensor to file using torchvision.

    Args:
        video (Tensor): Tensor of shape (B, C, T, H, W) with values in [-1, 1].
        filename (str): Output file path.
        fps (int, optional): Frames per second. Defaults to 8.
    """
    clip = torch.clamp(video.detach().cpu().float(), -1., 1.)
    batch = clip.shape[0]
    # (B, C, T, H, W) -> (T, B, C, H, W): build one grid of all batch
    # elements for every frame.
    per_frame = clip.permute(2, 0, 1, 3, 4)
    frame_grids = [
        torchvision.utils.make_grid(framesheet, nrow=int(batch), padding=0)
        for framesheet in per_frame
    ]
    grid = torch.stack(frame_grids, dim=0)
    # [-1, 1] -> [0, 255] uint8, channel-last (T, H, W, C) as write_video expects.
    grid = (grid + 1.0) / 2.0
    grid = (grid * 255).to(torch.uint8).permute(0, 2, 3, 1)
    torchvision.io.write_video(filename,
                               grid,
                               fps=fps,
                               video_codec='h264',
                               options={'crf': '10'})
def get_init_frame_path(data_dir: str, sample: dict) -> str:
    """Construct the init_frame path from directory and sample metadata.

    Args:
        data_dir (str): Base directory containing videos.
        sample (dict): Dictionary containing 'data_dir' and 'videoid'.

    Returns:
        str: Full path to the initial-frame PNG file.
    """
    image_name = f"{sample['videoid']}.png"
    return os.path.join(data_dir, 'images', sample['data_dir'], image_name)
def get_transition_path(data_dir: str, sample: dict) -> str:
    """Construct the full transition file path from directory and sample metadata.

    Args:
        data_dir (str): Base directory containing transition files.
        sample (dict): Dictionary containing 'data_dir' and 'videoid'.

    Returns:
        str: Full path to the HDF5 transition file.
    """
    transition_name = f"{sample['videoid']}.h5"
    return os.path.join(data_dir, 'transitions', sample['data_dir'],
                        transition_name)
def prepare_init_input(start_idx: int,
                       init_frame_path: str,
                       transition_dict: dict[str, torch.Tensor],
                       frame_stride: int,
                       wma_data,
                       video_length: int = 16,
                       n_obs_steps: int = 2) -> tuple[dict[str, Tensor], int, int]:
    """
    Extract a structured sample from a video sequence including frames, states, and actions,
    along with properly padded observations and pre-processed tensors for model input.

    Args:
        start_idx (int): Starting frame index for the current clip.
        init_frame_path (str): Path to the initial-frame image to load.
        transition_dict (dict[str, Tensor]): Dictionary containing tensors for 'action',
            'observation.state', 'action_type', 'state_type'.
        frame_stride (int): Temporal stride between sampled frames.
        wma_data: Object that holds configuration and utility functions like normalization,
            transformation, and resolution info.
        video_length (int, optional): Number of frames to sample from the video. Default is 16.
        n_obs_steps (int, optional): Number of historical steps for observations. Default is 2.

    Returns:
        tuple: (data dict with 'observation.image' plus normalized
        action/state entries, original state dim, original action dim).
    """
    # Action indices for the clip: start_idx, start_idx+stride, ...
    indices = [start_idx + frame_stride * i for i in range(video_length)]
    # Load the conditioning frame as a float tensor of shape (C, 1, H, W).
    init_frame = Image.open(init_frame_path).convert('RGB')
    init_frame = torch.tensor(np.array(init_frame)).unsqueeze(0).permute(
        3, 0, 1, 2).float()
    if start_idx < n_obs_steps - 1:
        # Not enough history: left-pad by repeating the earliest state so the
        # observation window always has n_obs_steps rows.
        state_indices = list(range(0, start_idx + 1))
        states = transition_dict['observation.state'][state_indices, :]
        num_padding = n_obs_steps - 1 - start_idx
        first_slice = states[0:1, :]  # (t, d)
        padding = first_slice.repeat(num_padding, 1)
        states = torch.cat((padding, states), dim=0)
    else:
        # Full history available: take the last n_obs_steps states ending at start_idx.
        state_indices = list(range(start_idx - n_obs_steps + 1, start_idx + 1))
        states = transition_dict['observation.state'][state_indices, :]
    actions = transition_dict['action'][indices, :]
    # Remember the raw dims before normalization/unification may change them.
    ori_state_dim = states.shape[-1]
    ori_action_dim = actions.shape[-1]
    frames_action_state_dict = {
        'action': actions,
        'observation.state': states,
    }
    # Normalize, then map into the unified action/state vector space.
    frames_action_state_dict = wma_data.normalizer(frames_action_state_dict)
    frames_action_state_dict = wma_data.get_uni_vec(
        frames_action_state_dict,
        transition_dict['action_type'],
        transition_dict['state_type'],
    )
    if wma_data.spatial_transform is not None:
        init_frame = wma_data.spatial_transform(init_frame)
    # Rescale pixel values from [0, 255] to [-1, 1].
    init_frame = (init_frame / 255 - 0.5) * 2
    data = {
        'observation.image': init_frame,
    }
    data.update(frames_action_state_dict)
    return data, ori_state_dim, ori_action_dim
def get_latent_z(model, videos: Tensor) -> Tensor:
    """
    Extract latent features from a video batch using the model's first-stage encoder.

    Args:
        model: the world model (must expose ``encode_first_stage``).
        videos (Tensor): Input videos of shape [B, C, T, H, W].

    Returns:
        Tensor: Latent video tensor of shape [B, C', T, H', W'].
    """
    b, c, t, h, w = videos.shape
    # Fold time into the batch axis so each frame is encoded independently.
    frames = videos.permute(0, 2, 1, 3, 4).reshape(b * t, c, h, w)
    latents = model.encode_first_stage(frames)
    # Unfold back to the [B, C, T, H, W] layout.
    latents = latents.reshape(b, t, *latents.shape[1:]).permute(0, 2, 1, 3, 4)
    return latents
def preprocess_observation(
        model, observations: dict[str, np.ndarray]) -> dict[str, Tensor]:
    """Convert environment observation to LeRobot format observation.

    Args:
        model: Policy/world model exposing ``normalize_inputs`` and ``device``;
            used to normalize the proprioceptive state.
        observations: Dictionary of observation batches from a Gym vector
            environment. Expects 'pixels' (uint8 channel-last images, either a
            single array or a dict keyed by camera name) and 'agent_pos'.

    Returns:
        Dictionary of observation batches with keys renamed to LeRobot format:
        images as channel-first float32 tensors, state normalized via the model.
    """
    # Map to expected inputs for the policy
    return_observations = {}
    if isinstance(observations["pixels"], dict):
        imgs = {
            f"observation.images.{key}": img
            for key, img in observations["pixels"].items()
        }
    else:
        imgs = {"observation.images.top": observations["pixels"]}
    for imgkey, img in imgs.items():
        img = torch.from_numpy(img)
        # Sanity check that images are channel last
        # (message fixed: it previously claimed "channel first" while the
        # condition checks for channel-LAST input).
        _, h, w, c = img.shape
        assert c < h and c < w, f"expect channel last images, but instead {img.shape}"
        # Sanity check that images are uint8
        assert img.dtype == torch.uint8, f"expect torch.uint8, but instead {img.dtype=}"
        # Convert to channel-first float32. NOTE(review): values are NOT
        # rescaled here — they stay in [0, 255]; confirm downstream
        # normalization expects this.
        img = img.permute(0, 3, 1, 2).contiguous()
        img = img.type(torch.float32)
        return_observations[imgkey] = img
    return_observations["observation.state"] = torch.from_numpy(
        observations["agent_pos"]).float()
    # Normalize the state with the model's own input normalizer, on its device.
    return_observations['observation.state'] = model.normalize_inputs({
        'observation.state':
        return_observations['observation.state'].to(model.device)
    })['observation.state']
    return return_observations
def image_guided_synthesis_sim_mode(
        model: torch.nn.Module,
        prompts: list[str],
        observation: dict,
        noise_shape: tuple[int, int, int, int, int],
        action_cond_step: int = 16,
        n_samples: int = 1,
        ddim_steps: int = 50,
        ddim_eta: float = 1.0,
        unconditional_guidance_scale: float = 1.0,
        fs: int | None = None,
        text_input: bool = True,
        timestep_spacing: str = 'uniform',
        guidance_rescale: float = 0.0,
        sim_mode: bool = True,
        **kwargs) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Performs image-guided video generation in a simulation-style mode with optional multimodal guidance (image, state, action, text).

    Args:
        model (torch.nn.Module): The diffusion-based generative model with multimodal conditioning.
        prompts (list[str]): A list of textual prompts to guide the synthesis process.
        observation (dict): A dictionary containing observed inputs including:
            - 'observation.images.top': Tensor of shape [B, O, C, H, W] (top-down images)
            - 'observation.state': Tensor of shape [B, O, D] (state vector)
            - 'action': Tensor of shape [B, T, D] (action sequence)
        noise_shape (tuple[int, int, int, int, int]): Shape of the latent variable to generate,
            typically (B, C, T, H, W).
        action_cond_step (int): Number of time steps where action conditioning is applied. Default is 16.
        n_samples (int): Number of samples to generate (unused here, always generates 1). Default is 1.
        ddim_steps (int): Number of DDIM sampling steps. Default is 50.
        ddim_eta (float): DDIM eta parameter controlling the stochasticity. Default is 1.0.
        unconditional_guidance_scale (float): Scale for classifier-free guidance. If 1.0, guidance is off.
        fs (int | None): Frame stride to condition on, broadcast across the batch if specified. Default is None.
        text_input (bool): Whether to use text prompt as conditioning. If False, uses empty strings. Default is True.
        timestep_spacing (str): Timestep sampling method in DDIM sampler. Typically "uniform" or "linspace".
        guidance_rescale (float): Guidance rescaling factor to mitigate overexposure from classifier-free guidance.
        sim_mode (bool): Whether to perform world-model interaction or decision-making using the world-model.
        **kwargs: Additional arguments passed to the DDIM sampler.

    Returns:
        batch_variants (torch.Tensor): Predicted pixel-space video frames [B, C, T, H, W].
        actions (torch.Tensor): Predicted action sequences [B, T, D] from diffusion decoding.
        states (torch.Tensor): Predicted state sequences [B, T, D] from diffusion decoding.
    """
    b, _, t, _, _ = noise_shape
    ddim_sampler = DDIMSampler(model)
    batch_size = noise_shape[0]
    # Broadcast the scalar frame stride to one value per batch element.
    fs = torch.tensor([fs] * batch_size, dtype=torch.long, device=model.device)
    # Image conditioning: use the LAST observed frame (most recent) as the
    # conditioning image.
    img = observation['observation.images.top'].permute(0, 2, 1, 3, 4)
    cond_img = rearrange(img, 'b o c h w -> (b o) c h w')[-1:]
    cond_img_emb = model.embedder(cond_img)
    cond_img_emb = model.image_proj_model(cond_img_emb)
    if model.model.conditioning_key == 'hybrid':
        # Hybrid mode: concatenate the latest-frame latent, tiled over the
        # full clip length, onto the model's input channels.
        # NOTE(review): `cond` is only bound inside this branch; any other
        # conditioning_key would raise NameError below — confirm 'hybrid'
        # is the only supported mode.
        z = get_latent_z(model, img.permute(0, 2, 1, 3, 4))
        img_cat_cond = z[:, :, -1:, :, :]
        img_cat_cond = repeat(img_cat_cond,
                              'b c t h w -> b c (repeat t) h w',
                              repeat=noise_shape[2])
        cond = {"c_concat": [img_cat_cond]}
    if not text_input:
        prompts = [""] * batch_size
    # Cross-attention conditioning: state, action, instruction text, image
    # embeddings concatenated along the token dimension.
    cond_ins_emb = model.get_learned_conditioning(prompts)
    cond_state_emb = model.state_projector(observation['observation.state'])
    cond_state_emb = cond_state_emb + model.agent_state_pos_emb
    cond_action_emb = model.action_projector(observation['action'])
    cond_action_emb = cond_action_emb + model.agent_action_pos_emb
    if not sim_mode:
        # Decision-making mode: zero out the action conditioning so the
        # model must predict actions rather than follow given ones.
        cond_action_emb = torch.zeros_like(cond_action_emb)
    cond["c_crossattn"] = [
        torch.cat(
            [cond_state_emb, cond_action_emb, cond_ins_emb, cond_img_emb],
            dim=1)
    ]
    # Extra inputs for the action head: the most recent n_obs_steps_acting
    # observations plus mode flags.
    cond["c_crossattn_action"] = [
        observation['observation.images.top'][:, :,
                                              -model.n_obs_steps_acting:],
        observation['observation.state'][:, -model.n_obs_steps_acting:],
        sim_mode,
        False,
    ]
    # No unconditional branch / mask / x0 is used here.
    uc = None
    kwargs.update({"unconditional_conditioning_img_nonetext": None})
    cond_mask = None
    cond_z0 = None
    if ddim_sampler is not None:
        samples, actions, states, intermedia = ddim_sampler.sample(
            S=ddim_steps,
            conditioning=cond,
            batch_size=batch_size,
            shape=noise_shape[1:],
            verbose=False,
            unconditional_guidance_scale=unconditional_guidance_scale,
            unconditional_conditioning=uc,
            eta=ddim_eta,
            cfg_img=None,
            mask=cond_mask,
            x0=cond_z0,
            fs=fs,
            timestep_spacing=timestep_spacing,
            guidance_rescale=guidance_rescale,
            **kwargs)
        # Reconstruct from latent to pixel space
        batch_images = model.decode_first_stage(samples)
        batch_variants = batch_images
    return batch_variants, actions, states
def run_inference(args: argparse.Namespace, gpu_num: int, gpu_no: int) -> None:
    """
    Run the multi-round world-model interaction inference pipeline.

    For every prompt row in ``<prompt_dir>/<dataset>.csv`` and every frame
    stride in ``args.frame_stride``, each iteration alternates two passes:
    (1) a decision-making ("dm") pass that imagines a video and predicts
    actions from the instruction, then (2) a world-model ("wm") pass that
    consumes those predicted actions (with no text input) to roll the
    environment forward. Per-iteration videos and the concatenated full
    rollout are written to disk and logged to TensorBoard.

    Args:
        args (argparse.Namespace): Parsed command-line arguments.
        gpu_num (int): Number of GPUs.  # NOTE(review): not used in the body.
        gpu_no (int): Index of the GPU this process places the model on.

    Returns:
        None
    """
    # Create inference output and TensorBoard directories.
    os.makedirs(args.savedir + '/inference', exist_ok=True)
    log_dir = args.savedir + f"/tensorboard"
    os.makedirs(log_dir, exist_ok=True)
    writer = SummaryWriter(log_dir=log_dir)
    # Load prompts: one CSV row per sample, keyed by 'videoid'.
    csv_path = os.path.join(args.prompt_dir, f"{args.dataset}.csv")
    df = pd.read_csv(csv_path)
    # Load config; disable gradient checkpointing for inference.
    config = OmegaConf.load(args.config)
    config['model']['params']['wma_config']['params'][
        'use_checkpoint'] = False
    model = instantiate_from_config(config.model)
    model.perframe_ae = args.perframe_ae
    assert os.path.exists(args.ckpt_path), "Error: checkpoint Not Found!"
    model = load_model_checkpoint(model, args.ckpt_path)
    model.eval()
    print(f'>>> Load pre-trained model ...')
    # Build datasets (test split is reused below to normalize initial inputs).
    logging.info("***** Configing Data *****")
    data = instantiate_from_config(config.data)
    data.setup()
    print(">>> Dataset is successfully loaded ...")
    model = model.cuda(gpu_no)
    device = get_device_from_parameters(model)
    # Sanity-check resolution and batch-size constraints.
    assert (args.height % 16 == 0) and (
        args.width % 16
        == 0), "Error: image size [h,w] should be multiples of 16!"
    assert args.bs == 1, "Current implementation only support [batch size = 1]!"
    # Latent noise shape: the VAE downsamples the image by a factor of 8.
    h, w = args.height // 8, args.width // 8
    channels = model.model.diffusion_model.out_channels
    n_frames = args.video_length
    print(f'>>> Generate {n_frames} frames under each generation ...')
    noise_shape = [args.bs, channels, n_frames, h, w]
    # Start inference over all prompt rows.
    # NOTE(review): loop variable `idx` is reused by two inner loops below;
    # harmless here because the outer `for` reassigns it, but worth renaming.
    for idx in range(0, len(df)):
        sample = df.iloc[idx]
        # Get the initial-frame path for this sample.
        init_frame_path = get_init_frame_path(args.prompt_dir, sample)
        ori_fps = float(sample['fps'])
        video_save_dir = args.savedir + f"/inference/sample_{sample['videoid']}"
        os.makedirs(video_save_dir, exist_ok=True)
        os.makedirs(video_save_dir + '/dm', exist_ok=True)
        os.makedirs(video_save_dir + '/wm', exist_ok=True)
        # Load the recorded transitions (HDF5) to obtain the initial state.
        transition_path = get_transition_path(args.prompt_dir, sample)
        with h5py.File(transition_path, 'r') as h5f:
            transition_dict = {}
            for key in h5f.keys():
                transition_dict[key] = torch.tensor(h5f[key][()])
            for key in h5f.attrs.keys():
                transition_dict[key] = h5f.attrs[key]
        # If several strides are given, test each frequency-control setting.
        for fs in args.frame_stride:
            # Directory for imagined videos from the decision-making pass.
            # NOTE(review): `sample_save_dir` is overwritten immediately below
            # and never read afterwards; only the mkdir side effects matter.
            sample_save_dir = f'{video_save_dir}/dm/{fs}'
            os.makedirs(sample_save_dir, exist_ok=True)
            # Directory for environment rollouts from the world-model pass.
            sample_save_dir = f'{video_save_dir}/wm/{fs}'
            os.makedirs(sample_save_dir, exist_ok=True)
            # Collects the executed world-model chunks for the full rollout video.
            wm_video = []
            # Observation queues hold the most recent conditioning history.
            cond_obs_queues = {
                "observation.images.top":
                deque(maxlen=model.n_obs_steps_imagen),
                "observation.state": deque(maxlen=model.n_obs_steps_imagen),
                "action": deque(maxlen=args.video_length),
            }
            # Obtain the initial frame and state; model-input FPS is the source
            # FPS divided by the frame stride.
            start_idx = 0
            model_input_fs = ori_fps // fs
            batch, ori_state_dim, ori_action_dim = prepare_init_input(
                start_idx,
                init_frame_path,
                transition_dict,
                fs,
                data.test_datasets[args.dataset],
                n_obs_steps=model.n_obs_steps_imagen)
            observation = {
                'observation.images.top':
                batch['observation.image'].permute(1, 0, 2,
                                                   3)[-1].unsqueeze(0),
                'observation.state':
                batch['observation.state'][-1].unsqueeze(0),
                'action':
                torch.zeros_like(batch['action'][-1]).unsqueeze(0)
            }
            observation = {
                key: observation[key].to(device, non_blocking=True)
                for key in observation
            }
            # Seed the observation queues with the initial observation.
            cond_obs_queues = populate_queues(cond_obs_queues, observation)
            # Multi-round interaction with the world model.
            for itr in tqdm(range(args.n_iter)):
                # Stack queued history into batched conditioning tensors.
                observation = {
                    'observation.images.top':
                    torch.stack(list(
                        cond_obs_queues['observation.images.top']),
                                dim=1).permute(0, 2, 1, 3, 4),
                    'observation.state':
                    torch.stack(list(cond_obs_queues['observation.state']),
                                dim=1),
                    'action':
                    torch.stack(list(cond_obs_queues['action']), dim=1),
                }
                observation = {
                    key: observation[key].to(device, non_blocking=True)
                    for key in observation
                }
                # Decision-making pass: imagine a video and predict actions
                # from the text instruction.
                print(f'>>> Step {itr}: generating actions ...')
                pred_videos_0, pred_actions, _ = image_guided_synthesis_sim_mode(
                    model,
                    sample['instruction'],
                    observation,
                    noise_shape,
                    action_cond_step=args.exe_steps,
                    ddim_steps=args.ddim_steps,
                    ddim_eta=args.ddim_eta,
                    unconditional_guidance_scale=args.
                    unconditional_guidance_scale,
                    fs=model_input_fs,
                    timestep_spacing=args.timestep_spacing,
                    guidance_rescale=args.guidance_rescale,
                    sim_mode=False)
                # Push the predicted future actions into the queues, zeroing
                # the padded action dimensions beyond the dataset's own.
                for idx in range(len(pred_actions[0])):
                    observation = {'action': pred_actions[0][idx:idx + 1]}
                    observation['action'][:, ori_action_dim:] = 0.0
                    cond_obs_queues = populate_queues(cond_obs_queues,
                                                      observation)
                # Re-stack the queues so the world model sees the new actions.
                observation = {
                    'observation.images.top':
                    torch.stack(list(
                        cond_obs_queues['observation.images.top']),
                                dim=1).permute(0, 2, 1, 3, 4),
                    'observation.state':
                    torch.stack(list(cond_obs_queues['observation.state']),
                                dim=1),
                    'action':
                    torch.stack(list(cond_obs_queues['action']), dim=1),
                }
                observation = {
                    key: observation[key].to(device, non_blocking=True)
                    for key in observation
                }
                # World-model pass: roll the environment forward under the
                # predicted actions (no text conditioning).
                print(f'>>> Step {itr}: interacting with world model ...')
                pred_videos_1, _, pred_states = image_guided_synthesis_sim_mode(
                    model,
                    "",
                    observation,
                    noise_shape,
                    action_cond_step=args.exe_steps,
                    ddim_steps=args.ddim_steps,
                    ddim_eta=args.ddim_eta,
                    unconditional_guidance_scale=args.
                    unconditional_guidance_scale,
                    fs=model_input_fs,
                    text_input=False,
                    timestep_spacing=args.timestep_spacing,
                    guidance_rescale=args.guidance_rescale)
                # Feed the executed frames/states back into the queues;
                # optionally zero the predicted states for comparison runs.
                for idx in range(args.exe_steps):
                    observation = {
                        'observation.images.top':
                        pred_videos_1[0][:, idx:idx + 1].permute(1, 0, 2, 3),
                        'observation.state':
                        torch.zeros_like(pred_states[0][idx:idx + 1]) if
                        args.zero_pred_state else pred_states[0][idx:idx + 1],
                        'action':
                        torch.zeros_like(pred_actions[0][-1:])
                    }
                    observation['observation.state'][:, ori_state_dim:] = 0.0
                    cond_obs_queues = populate_queues(cond_obs_queues,
                                                      observation)
                # Log the imagined (decision-making) video to TensorBoard.
                sample_tag = f"{args.dataset}-vid{sample['videoid']}-dm-fs-{fs}/itr-{itr}"
                log_to_tensorboard(writer,
                                   pred_videos_0,
                                   sample_tag,
                                   fps=args.save_fps)
                # Log the world-model rollout video to TensorBoard.
                sample_tag = f"{args.dataset}-vid{sample['videoid']}-wd-fs-{fs}/itr-{itr}"
                log_to_tensorboard(writer,
                                   pred_videos_1,
                                   sample_tag,
                                   fps=args.save_fps)
                # Save the imagined (decision-making) video to disk.
                sample_video_file = f'{video_save_dir}/dm/{fs}/itr-{itr}.mp4'
                save_results(pred_videos_0.cpu(),
                             sample_video_file,
                             fps=args.save_fps)
                # Save the world-model rollout video to disk.
                sample_video_file = f'{video_save_dir}/wm/{fs}/itr-{itr}.mp4'
                save_results(pred_videos_1.cpu(),
                             sample_video_file,
                             fps=args.save_fps)
                print('>' * 24)
                # Keep only the executed prefix of this rollout chunk.
                wm_video.append(pred_videos_1[:, :, :args.exe_steps].cpu())
            # Concatenate all executed chunks along time and save the full rollout.
            full_video = torch.cat(wm_video, dim=2)
            sample_tag = f"{args.dataset}-vid{sample['videoid']}-wd-fs-{fs}/full"
            log_to_tensorboard(writer,
                               full_video,
                               sample_tag,
                               fps=args.save_fps)
            sample_full_video_file = f"{video_save_dir}/../{sample['videoid']}_full_fs{fs}.mp4"
            save_results(full_video, sample_full_video_file, fps=args.save_fps)
def get_parser():
    """Build the command-line argument parser for world-model interaction inference.

    Fixes over the previous version: several help strings were wrong
    copy-pastes (``--config`` claimed to be a checkpoint path,
    ``--num_generation`` claimed to be a seed). All names, types and
    defaults are unchanged, so existing invocations keep working.

    Returns:
        argparse.ArgumentParser: The configured parser.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--savedir",
                        type=str,
                        default=None,
                        help="Path to save the results.")
    parser.add_argument("--ckpt_path",
                        type=str,
                        default=None,
                        help="Path to the model checkpoint.")
    parser.add_argument("--config",
                        type=str,
                        help="Path to the model config (YAML) file.")
    parser.add_argument(
        "--prompt_dir",
        type=str,
        default=None,
        help="Directory containing videos and corresponding prompts.")
    parser.add_argument("--dataset",
                        type=str,
                        default=None,
                        help="Name of the dataset to test.")
    parser.add_argument(
        "--ddim_steps",
        type=int,
        default=50,
        help="Number of DDIM steps. If non-positive, DDPM is used instead.")
    parser.add_argument(
        "--ddim_eta",
        type=float,
        default=1.0,
        help="Eta for DDIM sampling. Set to 0.0 for deterministic results.")
    parser.add_argument("--bs",
                        type=int,
                        default=1,
                        help="Batch size for inference. Must be 1.")
    parser.add_argument("--height",
                        type=int,
                        default=320,
                        help="Height of the generated images in pixels.")
    parser.add_argument("--width",
                        type=int,
                        default=512,
                        help="Width of the generated images in pixels.")
    parser.add_argument(
        "--frame_stride",
        type=int,
        nargs='+',
        required=True,
        help=
        "frame stride control for 256 model (larger->larger motion), FPS control for 512 or 1024 model (smaller->larger motion)"
    )
    parser.add_argument(
        "--unconditional_guidance_scale",
        type=float,
        default=1.0,
        help="Scale for classifier-free guidance during sampling.")
    parser.add_argument("--seed",
                        type=int,
                        default=123,
                        help="Random seed for reproducibility.")
    parser.add_argument("--video_length",
                        type=int,
                        default=16,
                        help="Number of frames in the generated video.")
    parser.add_argument("--num_generation",
                        type=int,
                        default=1,
                        help="Number of generations per prompt.")
    parser.add_argument(
        "--timestep_spacing",
        type=str,
        default="uniform",
        help=
        "Strategy for timestep scaling. See Table 2 in the paper: 'Common Diffusion Noise Schedules and Sample Steps are Flawed' (https://huggingface.co/papers/2305.08891)."
    )
    parser.add_argument(
        "--guidance_rescale",
        type=float,
        default=0.0,
        help=
        "Rescale factor for guidance as discussed in 'Common Diffusion Noise Schedules and Sample Steps are Flawed' (https://huggingface.co/papers/2305.08891)."
    )
    parser.add_argument(
        "--perframe_ae",
        action='store_true',
        default=False,
        help=
        "Use per-frame autoencoder decoding to reduce GPU memory usage. Recommended for models with resolutions like 576x1024."
    )
    parser.add_argument(
        "--n_action_steps",
        type=int,
        default=16,
        help="Number of action steps predicted per generation.",
    )
    parser.add_argument(
        "--exe_steps",
        type=int,
        default=16,
        help="Number of predicted action steps to execute per iteration.",
    )
    parser.add_argument(
        "--n_iter",
        type=int,
        default=40,
        help="Number of iterations to interact with the world model.",
    )
    parser.add_argument("--zero_pred_state",
                        action='store_true',
                        default=False,
                        help="Zero out the predicted states (for comparison).")
    parser.add_argument("--save_fps",
                        type=int,
                        default=8,
                        help="FPS of the saved videos.")
    return parser
if __name__ == '__main__':
    # Parse CLI arguments and resolve the RNG seed: a negative seed means
    # "pick one at random".
    cli_args = get_parser().parse_args()
    chosen_seed = cli_args.seed if cli_args.seed >= 0 else random.randint(
        0, 2**31)
    seed_everything(chosen_seed)
    # Single-process inference: one GPU, rank 0.
    gpu_count, gpu_rank = 1, 0
    run_inference(cli_args, gpu_count, gpu_rank)

View File

@@ -0,0 +1,23 @@
#!/bin/bash
# Launch single-GPU inference with the base model checkpoint.
# Replace the placeholder paths (ckpt, res_dir, --prompt_dir) before running.

# NOTE(review): model_name is defined but never referenced below — confirm
# whether it was meant to appear in the save path.
model_name=base_model
ckpt=/path/to/base/model
config=configs/inference/base_model_inference.yaml
res_dir="/path/to/result/directory"
seed=123

# ddim_eta=1.0 -> stochastic DDIM sampling; guidance_rescale / trailing
# timestep spacing follow "Common Diffusion Noise Schedules and Sample
# Steps are Flawed" (arXiv:2305.08891).
CUDA_VISIBLE_DEVICES=0 python3 scripts/evaluation/base_model_inference.py \
--seed ${seed} \
--ckpt_path $ckpt \
--config $config \
--savedir "${res_dir}/videos" \
--bs 1 --height 320 --width 512 \
--unconditional_guidance_scale 1.0 \
--ddim_steps 16 \
--ddim_eta 1.0 \
--prompt_dir "/path/to/examples/base_model_prompts" \
--text_input \
--video_length 16 \
--timestep_spacing 'uniform_trailing' \
--guidance_rescale 0.7 \
--perframe_ae

View File

@@ -0,0 +1,42 @@
# Launch world-model interaction evaluation over several datasets on one GPU.
# Replace the placeholder paths (ckpt, res_dir, --prompt_dir) before running.
model_name=testing
ckpt=/path/to/model/checkpoint
config=configs/inference/world_model_interaction.yaml
seed=123
res_dir="/path/to/result/directory"

# Datasets with their matching interaction-iteration counts (n_iters) and
# frame strides (fses); the three arrays are index-aligned.
datasets=(
    "unitree_z1_stackbox"
    "unitree_z1_dual_arm_stackbox"
    "unitree_z1_dual_arm_stackbox_v2"
    "unitree_z1_dual_arm_cleanup_pencils"
    "unitree_g1_pack_camera"
)
n_iters=(12 7 11 8 11)
fses=(4 4 4 4 6)

for i in "${!datasets[@]}"; do
    dataset=${datasets[$i]}
    n_iter=${n_iters[$i]}
    fs=${fses[$i]}
    # One evaluation run per dataset; results go under ${res_dir}/${model_name}.
    CUDA_VISIBLE_DEVICES=0 python3 scripts/evaluation/world_model_interaction.py \
    --seed ${seed} \
    --ckpt_path $ckpt \
    --config $config \
    --savedir "${res_dir}/${model_name}/${dataset}" \
    --bs 1 --height 320 --width 512 \
    --unconditional_guidance_scale 1.0 \
    --ddim_steps 50 \
    --ddim_eta 1.0 \
    --prompt_dir "/path/to/unifolm-world-model-action/examples/world_model_interaction_prompts" \
    --dataset ${dataset} \
    --video_length 16 \
    --frame_stride ${fs} \
    --n_action_steps 16 \
    --exe_steps 16 \
    --n_iter ${n_iter} \
    --timestep_spacing 'uniform_trailing' \
    --guidance_rescale 0.7 \
    --perframe_ae
done

32
scripts/train.sh Normal file
View File

@@ -0,0 +1,32 @@
# NCCL configuration — uncomment and adjust for multi-node training.
# export NCCL_DEBUG=debug
# export NCCL_IB_DISABLE=0
# export NCCL_IB_GID_INDEX=3
# export NCCL_NET_GDR_LEVEL=3
# export CUDA_LAUNCH_BLOCKING=1
# export NCCL_TOPO_FILE=/tmp/topo.txt
# export MASTER_ADDR="master.ip."
# export MASTER_PORT=12366

# Experiment name and training config; edit before running.
name="experiment_name"
config_file=configs/train/config.yaml

# Save root dir for logs, checkpoints, tensorboard record, etc.
save_root="/path/to/savedir"
mkdir -p $save_root/$name

# Run: single node, 8 GPUs, via torch.distributed.launch.
CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 python3 -m torch.distributed.launch \
--nproc_per_node=8 --nnodes=1 --master_addr=127.0.0.1 --master_port=12366 --node_rank=0 \
./scripts/trainer.py \
--base $config_file \
--train \
--name $name \
--logdir $save_root \
--devices 8 \
--total_gpus=8 \
lightning.trainer.num_nodes=1

214
scripts/trainer.py Normal file
View File

@@ -0,0 +1,214 @@
import argparse, os, datetime
import pytorch_lightning as pl
import torch
from omegaconf import OmegaConf
from transformers import logging as transf_logging
from pytorch_lightning import seed_everything
from pytorch_lightning.trainer import Trainer
from unifolm_wma.utils.utils import instantiate_from_config
from unifolm_wma.utils.train import get_trainer_callbacks, get_trainer_logger, get_trainer_strategy
from unifolm_wma.utils.train import set_logger, init_workspace, load_checkpoints, get_num_parameters
def get_parser(**parser_kwargs):
    """Build the command-line argument parser for the training entry point.

    PyTorch Lightning ``Trainer`` arguments are appended to this parser by
    the caller via ``Trainer.add_argparse_args``. Fixes over the previous
    version: the ``--logdir`` help string was unprofessional placeholder
    text, and several flags had unclear one-word descriptions. All names,
    types and defaults are unchanged.

    Args:
        **parser_kwargs: Keyword arguments forwarded to
            :class:`argparse.ArgumentParser`.

    Returns:
        argparse.ArgumentParser: The configured parser.
    """
    parser = argparse.ArgumentParser(**parser_kwargs)
    parser.add_argument("--seed",
                        "-s",
                        type=int,
                        default=20250912,
                        help="seed for seed_everything")
    parser.add_argument("--name",
                        "-n",
                        type=str,
                        default="",
                        help="experiment name, used as the saving folder")
    parser.add_argument(
        "--base",
        "-b",
        nargs="*",
        metavar="base_config.yaml",
        help="paths to base configs. Loaded from left-to-right.",
        default=list())
    parser.add_argument("--train",
                        "-t",
                        action='store_true',
                        default=False,
                        help='run training')
    parser.add_argument("--val",
                        "-v",
                        action='store_true',
                        default=False,
                        help='run validation')
    parser.add_argument("--test",
                        action='store_true',
                        default=False,
                        help='run testing')
    parser.add_argument("--logdir",
                        "-l",
                        type=str,
                        default="logs",
                        help="directory for logs, checkpoints and configs")
    parser.add_argument("--auto_resume",
                        action='store_true',
                        default=False,
                        help="resume from full-info checkpoint")
    parser.add_argument("--auto_resume_weight_only",
                        action='store_true',
                        default=False,
                        help="resume from weight-only checkpoint")
    parser.add_argument("--debug",
                        "-d",
                        action='store_true',
                        default=False,
                        help="enable post-mortem debugging")
    return parser
def get_nondefault_trainer_args(args):
    """Return the sorted names of Trainer options whose values in ``args``
    differ from the PyTorch Lightning Trainer defaults."""
    # Build a baseline namespace holding only the Trainer defaults.
    baseline = Trainer.add_argparse_args(
        argparse.ArgumentParser()).parse_args([])
    changed = [
        name for name in vars(baseline)
        if getattr(args, name) != getattr(baseline, name)
    ]
    return sorted(changed)
if __name__ == "__main__":
now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S")
local_rank = int(os.environ.get('LOCAL_RANK'))
global_rank = int(os.environ.get('RANK'))
num_rank = int(os.environ.get('WORLD_SIZE'))
parser = get_parser()
# Extends existing argparse by default Trainer attributes
parser = Trainer.add_argparse_args(parser)
args, unknown = parser.parse_known_args()
transf_logging.set_verbosity_error()
seed_everything(args.seed)
configs = [OmegaConf.load(cfg) for cfg in args.base]
cli = OmegaConf.from_dotlist(unknown)
config = OmegaConf.merge(*configs, cli)
lightning_config = config.pop("lightning", OmegaConf.create())
trainer_config = lightning_config.get("trainer", OmegaConf.create())
# Setup workspace directories
workdir, ckptdir, cfgdir, loginfo = init_workspace(args.name, args.logdir,
config,
lightning_config,
global_rank)
logger = set_logger(
logfile=os.path.join(loginfo, 'log_%d:%s.txt' % (global_rank, now)))
logger.info("@lightning version: %s [>=1.8 required]" % (pl.__version__))
logger.info("***** Configing Model *****")
config.model.params.logdir = workdir
model = instantiate_from_config(config.model)
# Load checkpoints
model = load_checkpoints(model, config.model)
# Register_schedule again to make ZTSNR work
if model.rescale_betas_zero_snr:
model.register_schedule(given_betas=model.given_betas,
beta_schedule=model.beta_schedule,
timesteps=model.timesteps,
linear_start=model.linear_start,
linear_end=model.linear_end,
cosine_s=model.cosine_s)
# Update trainer config
for k in get_nondefault_trainer_args(args):
trainer_config[k] = getattr(args, k)
num_nodes = trainer_config.num_nodes
ngpu_per_node = trainer_config.devices
logger.info(f"Running on {num_rank}={num_nodes}x{ngpu_per_node} GPUs")
# Setup learning rate
base_lr = config.model.base_learning_rate
bs = config.data.params.batch_size
if getattr(config.model, 'scale_lr', True):
model.learning_rate = num_rank * bs * base_lr
else:
model.learning_rate = base_lr
logger.info("***** Configing Data *****")
data = instantiate_from_config(config.data)
data.setup()
for k in data.train_datasets:
logger.info(
f"{k}, {data.train_datasets[k].__class__.__name__}, {len(data.train_datasets[k])}"
)
if hasattr(data, 'val_datasets'):
for k in data.val_datasets:
logger.info(
f"{k}, {data.val_datasets[k].__class__.__name__}, {len(data.val_datasets[k])}"
)
for item in unknown:
if item.startswith('--total_gpus'):
num_gpus = int(item.split('=')[-1])
break
model.datasets_len = len(data)
logger.info("***** Configing Trainer *****")
if "accelerator" not in trainer_config:
trainer_config["accelerator"] = "gpu"
# Setup trainer args: pl-logger and callbacks
trainer_kwargs = dict()
trainer_kwargs["num_sanity_val_steps"] = 0
logger_cfg = get_trainer_logger(lightning_config, workdir, args.debug)
trainer_kwargs["logger"] = instantiate_from_config(logger_cfg)
# Setup callbacks
callbacks_cfg = get_trainer_callbacks(lightning_config, config, workdir,
ckptdir, logger)
trainer_kwargs["callbacks"] = [
instantiate_from_config(callbacks_cfg[k]) for k in callbacks_cfg
]
strategy_cfg = get_trainer_strategy(lightning_config)
trainer_kwargs["strategy"] = strategy_cfg if type(
strategy_cfg) == str else instantiate_from_config(strategy_cfg)
trainer_kwargs['precision'] = lightning_config.get('precision', 32)
trainer_kwargs["sync_batchnorm"] = False
# Trainer config: others
trainer_args = argparse.Namespace(**trainer_config)
trainer = Trainer.from_argparse_args(trainer_args, **trainer_kwargs)
# Allow checkpointing via USR1
def melk(*args, **kwargs):
if trainer.global_rank == 0:
print("Summoning checkpoint.")
ckpt_path = os.path.join(ckptdir, "last_summoning.ckpt")
trainer.save_checkpoint(ckpt_path)
def divein(*args, **kwargs):
if trainer.global_rank == 0:
import pudb
pudb.set_trace()
import signal
signal.signal(signal.SIGUSR1, melk)
signal.signal(signal.SIGUSR2, divein)
# List the key model sizes
total_params = get_num_parameters(model)
logger.info("***** Running the Loop *****")
if args.train:
try:
if "strategy" in lightning_config and lightning_config[
'strategy'].startswith('deepspeed'):
logger.info("<Training in DeepSpeed Mode>")
if trainer_kwargs['precision'] == 16:
with torch.cuda.amp.autocast():
trainer.fit(model, data)
else:
trainer.fit(model, data)
else:
logger.info("<Training in DDPSharded Mode>")
trainer.fit(model, data)
except Exception:
raise