第一次完整测例跑完
This commit is contained in:
541
scripts/evaluation/base_model_inference.py
Normal file
541
scripts/evaluation/base_model_inference.py
Normal file
@@ -0,0 +1,541 @@
|
||||
import argparse, os, glob
|
||||
import datetime, time
|
||||
import pandas as pd
|
||||
import torch
|
||||
import torchvision
|
||||
import torchvision.transforms as transforms
|
||||
import random
|
||||
|
||||
from pytorch_lightning import seed_everything
|
||||
from PIL import Image
|
||||
from omegaconf import OmegaConf
|
||||
from tqdm import tqdm
|
||||
from einops import rearrange, repeat
|
||||
from collections import OrderedDict
|
||||
|
||||
from unifolm_wma.models.samplers.ddim import DDIMSampler
|
||||
from unifolm_wma.utils.utils import instantiate_from_config
|
||||
|
||||
|
||||
def get_filelist(data_dir: str, postfixes: list[str]) -> list[str]:
    """Collect files in *data_dir* whose extension is one of *postfixes*.

    Args:
        data_dir (str): Directory to search (non-recursive).
        postfixes (list[str]): File extensions without the dot,
            e.g. ['csv', 'jpg'].

    Returns:
        list[str]: Sorted list of matching file paths.
    """
    matches: list[str] = []
    for ext in postfixes:
        matches.extend(glob.glob(os.path.join(data_dir, f"*.{ext}")))
    return sorted(matches)
|
||||
|
||||
|
||||
def load_model_checkpoint(model: torch.nn.Module,
                          ckpt: str) -> torch.nn.Module:
    """Load model weights from a checkpoint file.

    Supports three checkpoint layouts:
      * a Lightning-style checkpoint with a top-level ``state_dict`` entry
        (missing/unexpected keys are printed for inspection),
      * an older layout whose keys use ``framestride_embed`` instead of
        ``fps_embedding`` (keys are renamed before loading),
      * a DeepSpeed checkpoint with a ``module`` entry whose keys carry a
        16-character prefix that is stripped.

    Args:
        model (torch.nn.Module): The model to load weights into.
        ckpt (str): Path to the checkpoint file.

    Returns:
        torch.nn.Module: The same model instance with weights loaded.
    """
    state_dict = torch.load(ckpt, map_location="cpu")
    # Membership test works directly on the dict; no list() copy needed.
    if "state_dict" in state_dict:
        state_dict = state_dict["state_dict"]
        try:
            loaded = model.load_state_dict(state_dict, strict=False)
            print("Missing keys:")
            for k in loaded.missing_keys:
                print(f"    {k}")
            print("Unexpected keys:")
            for k in loaded.unexpected_keys:
                print(f"    {k}")
        # Narrowed from a bare `except:` (which also swallowed
        # KeyboardInterrupt/SystemExit).
        except Exception:
            # Rename the keys for the 256x256 model layout.
            new_pl_sd = OrderedDict(state_dict)
            for k in list(new_pl_sd.keys()):
                if "framestride_embed" in k:
                    new_key = k.replace("framestride_embed", "fps_embedding")
                    new_pl_sd[new_key] = new_pl_sd.pop(k)
            model.load_state_dict(new_pl_sd, strict=False)
    else:
        # DeepSpeed layout: strip the 16-char key prefix under 'module'.
        new_pl_sd = OrderedDict()
        for key, value in state_dict['module'].items():
            new_pl_sd[key[16:]] = value
        model.load_state_dict(new_pl_sd)
    print('>>> model checkpoint loaded.')
    return model
|
||||
|
||||
|
||||
def load_prompts(prompt_file: str) -> list[str]:
    """Load prompts from a text file, one per line.

    Lines that are empty after stripping whitespace are skipped.

    Args:
        prompt_file (str): Path to the prompt file.

    Returns:
        list[str]: Non-empty prompt strings, in file order.
    """
    # `with` guarantees the handle is closed even on error; iterating the
    # handle avoids materializing the whole file via readlines().
    with open(prompt_file, 'r') as f:
        return [stripped for line in f if (stripped := line.strip())]
|
||||
|
||||
|
||||
def load_data_prompts(
    data_dir: str,
    savedir: str,
    video_size: tuple[int, int] = (256, 256),
    video_frames: int = 16
) -> tuple[list[str], list[torch.Tensor], list[str], list[float], list[float],
           list[int]]:
    """Load image prompts, tile them into video tensors, and read metadata.

    Each image in `data_dir` is resized/cropped/normalized and repeated
    `video_frames` times along the time axis.  Per-video metadata
    (instruction, fps, fs, num_gen) is looked up in the first CSV file
    found in `data_dir`, keyed by the image's base name (`videoid` column).
    Images whose result video already exists in `savedir`, or whose id is
    missing from the CSV, are skipped.

    Args:
        data_dir (str): Directory containing the images and the CSV file.
        savedir (str): Output directory used to detect already-finished runs.
        video_size (tuple[int, int], optional): Target (height, width) of
            video frames.
        video_frames (int, optional): Number of frames per video.

    Returns:
        tuple: (filenames, video tensors, prompts, fps values, fs values,
        num_generations), index-aligned across all six lists.
    """
    transform = transforms.Compose([
        transforms.Resize(min(video_size)),
        transforms.CenterCrop(video_size),
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
    ])

    # Load prompt csv
    prompt_file = get_filelist(data_dir, ['csv'])
    assert len(prompt_file) > 0, "Error: found NO image prompt file!"

    # Load image prompts
    file_list = get_filelist(data_dir, ['jpg', 'png', 'jpeg', 'JPEG', 'PNG'])
    data_list = []
    filename_list = []
    prompt_list = []
    fps_list = []
    fs_list = []
    num_gen_list = []
    prompt_csv = pd.read_csv(prompt_file[0])
    # Normalize ids once (hoisted out of the loop) so string/int
    # `videoid` columns compare equal against the filename stem.
    prompt_csv['videoid'] = prompt_csv['videoid'].map(str)

    for path in file_list:
        image = Image.open(path).convert('RGB')
        image_tensor = transform(image).unsqueeze(1)
        # Tile the single frame along time: [C,1,H,W] -> [C,T,H,W].
        frame_tensor = repeat(image_tensor,
                              'c t h w -> c (repeat t) h w',
                              repeat=video_frames)
        _, filename = os.path.split(path)

        if is_inferenced(savedir, filename):
            continue
        # BUGFIX: use splitext instead of filename[:-4] so 4-char
        # extensions such as ".jpeg" produce the correct video id.
        video_id = os.path.splitext(filename)[0]
        row = prompt_csv[prompt_csv['videoid'] == video_id]
        if row.empty:
            continue
        data_list.append(frame_tensor)
        filename_list.append(filename)
        prompt_list.append(row['instruction'].values[0])
        fps_list.append(row['fps'].values[0])
        fs_list.append(row['fs'].values[0])
        num_gen_list.append(int(row['num_gen'].values[0]))

    return filename_list, data_list, prompt_list, fps_list, fs_list, num_gen_list
|
||||
|
||||
|
||||
def is_inferenced(save_dir: str, filename: str) -> bool:
    """Check whether a result video for *filename* already exists.

    Args:
        save_dir (str): Directory where result videos are saved.
        filename (str): Source image filename; its extension is ignored.

    Returns:
        bool: True if `<save_dir>/<stem>.mp4` exists, else False.
    """
    # BUGFIX: splitext handles any extension length; the old
    # filename[:-4] slicing mangled 4-char extensions like ".jpeg".
    stem = os.path.splitext(filename)[0]
    return os.path.exists(os.path.join(save_dir, f"{stem}.mp4"))
|
||||
|
||||
|
||||
def save_results_seperate(prompt: str | list[str],
                          samples: torch.Tensor,
                          filename: str,
                          fakedir: str,
                          fps: int = 8) -> None:
    """Write generated video samples to .mp4 files.

    Args:
        prompt (str | list[str]): Prompt text; only the first entry of a
            list is kept.
        samples (torch.Tensor): Generated videos of shape [B, C, T, H, W],
            values in [-1, 1].
        filename (str): Base output filename; its stem names the .mp4.
        fakedir (str): Directory to save output videos into.
        fps (int, optional): Frames per second of the written video.

    Returns:
        None
    """
    if isinstance(prompt, list):
        prompt = prompt[0]

    stem = filename.split(".")[0]
    for out_dir, video in zip([fakedir], [samples]):
        if video is None:
            continue
        clip = torch.clamp(video.detach().cpu().float(), -1., 1.)
        for sample in clip:
            # Map [-1, 1] floats to [0, 255] uint8, channels-last [T,H,W,C].
            frames = ((sample + 1.0) / 2.0 * 255).to(torch.uint8)
            frames = frames.permute(1, 2, 3, 0)
            torchvision.io.write_video(os.path.join(out_dir, f'{stem}.mp4'),
                                       frames,
                                       fps=fps,
                                       video_codec='h264',
                                       options={'crf': '0'})
|
||||
|
||||
|
||||
def get_latent_z(model: torch.nn.Module, videos: torch.Tensor) -> torch.Tensor:
    """Encode videos into the model's latent space.

    Args:
        model (torch.nn.Module): Model exposing `encode_first_stage`.
        videos (torch.Tensor): Video tensor of shape [B, C, T, H, W].

    Returns:
        torch.Tensor: Latent tensor of shape [B, C', T, H', W'].
    """
    b, c, t, h, w = videos.shape
    # Torch-native equivalent of einops rearrange; drops the third-party
    # dependency for this trivial fold/unfold.
    # Fold time into the batch axis: [B, C, T, H, W] -> [B*T, C, H, W].
    flat = videos.permute(0, 2, 1, 3, 4).reshape(b * t, c, h, w)
    z = model.encode_first_stage(flat)
    # Unfold back to [B, C', T, H', W'] (channel/spatial dims may change).
    return z.reshape(b, t, *z.shape[1:]).permute(0, 2, 1, 3, 4)
|
||||
|
||||
|
||||
def image_guided_synthesis(model: torch.nn.Module,
                           prompts: list[str],
                           videos: torch.Tensor,
                           noise_shape: list[int],
                           ddim_steps: int = 50,
                           ddim_eta: float = 1.0,
                           unconditional_guidance_scale: float = 1.0,
                           fs: int | None = None,
                           text_input: bool = False,
                           timestep_spacing: str = 'uniform',
                           guidance_rescale: float = 0.0,
                           **kwargs) -> torch.Tensor:
    """
    Run DDIM-based image-to-video synthesis with hybrid/text+image guidance.

    Args:
        model (torch.nn.Module): Diffusion model.
        prompts (list[str]): Text prompts; replaced by empty strings when
            `text_input` is False.
        videos (torch.Tensor): Input images/videos of shape [B, C, T, H, W];
            only the first frame is used for conditioning.
        noise_shape (list[int]): Latent noise shape [B, C, T, H, W].
        ddim_steps (int, optional): Number of DDIM steps.
        ddim_eta (float, optional): Eta value for DDIM.
        unconditional_guidance_scale (float, optional): Guidance scale.
        fs (int | None, optional): Frame-stride/FPS input for the sampler.
        text_input (bool, optional): If True, use text guidance.
        timestep_spacing (str, optional): Timestep schedule spacing.
        guidance_rescale (float, optional): Rescale guidance effect.
        **kwargs: Additional sampler args.

    Returns:
        torch.Tensor: Synthesized videos of shape [B, 1, C, T, H, W].
    """

    ddim_sampler = DDIMSampler(model)
    batch_size = noise_shape[0]
    # Broadcast the scalar frame-stride to one value per batch element.
    fs = torch.tensor([fs] * batch_size, dtype=torch.long, device=model.device)

    if not text_input:
        prompts = [""] * batch_size

    b, c, t, h, w = videos.shape
    # Cross-attention conditioning uses only the first frame of each clip.
    img = videos[:, :, 0]
    img_emb = model.embedder(img)
    img_emb = model.image_proj_model(img_emb)
    # Split the per-clip embedding into per-timestep chunks.
    img_emb = rearrange(img_emb, 'b (t l) c -> (b t) l c', t=t)
    cond_emb = model.get_learned_conditioning(prompts)
    # Duplicate the text embedding for every timestep to match img_emb.
    cond_emb = cond_emb.repeat_interleave(repeats=t, dim=0)

    cond = {"c_crossattn": [torch.cat([cond_emb, img_emb], dim=1)]}
    if model.model.conditioning_key == 'hybrid':
        # Hybrid conditioning additionally concatenates the first frame's
        # latent, tiled across the full time axis.
        z = get_latent_z(model, videos)
        img_cat_cond = z[:, :, :1, :, :]
        img_cat_cond = repeat(img_cat_cond,
                              'b c t h w -> b c (repeat t) h w',
                              repeat=z.shape[2])
        cond["c_concat"] = [img_cat_cond]

    # No explicit negative conditioning; CFG scale alone steers sampling.
    uc = None
    cond_mask = None
    kwargs.update({"unconditional_conditioning_img_nonetext": None})

    batch_variants = []
    if ddim_sampler is not None:
        samples, _, _, _ = ddim_sampler.sample(
            S=ddim_steps,
            batch_size=batch_size,
            shape=noise_shape[1:],
            conditioning=cond,
            eta=ddim_eta,
            mask=cond_mask,
            x0=None,
            verbose=False,
            unconditional_guidance_scale=unconditional_guidance_scale,
            unconditional_conditioning=uc,
            fs=fs,
            timestep_spacing=timestep_spacing,
            guidance_rescale=guidance_rescale,
            **kwargs)

        # Reconstruct from latent to pixel space
        batch_images = model.decode_first_stage(samples)
        batch_variants.append(batch_images)

    # [N_variants, B, C, T, H, W] -> [B, N_variants, C, T, H, W].
    batch_variants = torch.stack(batch_variants)
    return batch_variants.permute(1, 0, 2, 3, 4, 5)
|
||||
|
||||
|
||||
def run_inference(args: argparse.Namespace, gpu_num: int, gpu_no: int) -> None:
    """
    Run the full image-to-video inference pipeline on a prompt directory.

    Loads the model from config + checkpoint, shards the prompt list across
    ranks, autoregressively generates `num_gen` clips per prompt (seeding
    each clip with the last frame of the previous one), and writes each
    concatenated result to `<savedir>/samples/<stem>.mp4`.

    Args:
        args (argparse.Namespace): Parsed command-line arguments.
        gpu_num (int): Number of GPUs (samples are split evenly across them).
        gpu_no (int): Index of the current GPU.

    Returns:
        None
    """
    # Load config
    config = OmegaConf.load(args.config)
    # Set use_checkpoint as False as when using deepspeed, it encounters an error "deepspeed backend not set"
    config['model']['params']['wma_config']['params'][
        'use_checkpoint'] = False
    model = instantiate_from_config(config.model)
    model = model.cuda(gpu_no)
    model.perframe_ae = args.perframe_ae
    assert os.path.exists(args.ckpt_path), "Error: checkpoint Not Found!"
    model = load_model_checkpoint(model, args.ckpt_path)
    model.eval()

    # Run over data
    assert (args.height % 16 == 0) and (
        args.width % 16
        == 0), "Error: image size [h,w] should be multiples of 16!"
    assert args.bs == 1, "Current implementation only support [batch size = 1]!"

    # Get latent noise shape (VAE downsamples spatial dims by 8).
    h, w = args.height // 8, args.width // 8
    channels = model.model.diffusion_model.out_channels
    n_frames = args.video_length
    print(f'>>> Generate {n_frames} frames under each generation ...')
    noise_shape = [args.bs, channels, n_frames, h, w]

    fakedir = os.path.join(args.savedir, "samples")
    os.makedirs(fakedir, exist_ok=True)

    # Prompt file setting
    assert os.path.exists(args.prompt_dir), "Error: prompt file Not Found!"
    filename_list, data_list, prompt_list, fps_list, fs_list, num_gen_list = load_data_prompts(
        args.prompt_dir,
        args.savedir,
        video_size=(args.height, args.width),
        video_frames=n_frames)

    # Shard the loaded samples evenly across ranks.
    num_samples = len(prompt_list)
    samples_split = num_samples // gpu_num
    print('>>> Prompts testing [rank:%d] %d/%d samples loaded.' %
          (gpu_no, samples_split, num_samples))

    indices = list(range(samples_split * gpu_no, samples_split * (gpu_no + 1)))
    fps_list_rank = [fps_list[i] for i in indices]
    fs_list_rank = [fs_list[i] for i in indices]
    prompt_list_rank = [prompt_list[i] for i in indices]
    data_list_rank = [data_list[i] for i in indices]
    filename_list_rank = [filename_list[i] for i in indices]

    with torch.no_grad(), torch.cuda.amp.autocast():
        # Create a new result csv
        for idx, indice in enumerate(
                tqdm(range(0, len(prompt_list_rank), args.bs),
                     desc=f'Sample batch')):
            fps = fps_list_rank[indice:indice + args.bs]
            fs = fs_list_rank[indice:indice + args.bs]
            prompts = prompt_list_rank[indice:indice + args.bs]
            # NOTE(review): num_gen_list is sliced from the global list, not
            # the rank-sliced one — harmless for gpu_num == 1 (the only mode
            # used by __main__), wrong for multi-GPU; confirm before sharding.
            num_gen = num_gen_list[indice:indice + args.bs]
            videos = data_list_rank[indice:indice + args.bs]
            filenames = filename_list_rank[indice:indice + args.bs]
            if isinstance(videos, list):
                videos = torch.stack(videos, dim=0).to("cuda")
            else:
                videos = videos.unsqueeze(0).to("cuda")

            results = []
            print(
                f">>> {prompts[0]}, frame_stride:{fs[0]}, and {num_gen[0]} generation ..."
            )
            # Autoregressive rollout: each generation is seeded with the
            # last frame of the previous one, tiled across the time axis.
            for _ in range(num_gen[0]):
                batch_samples = image_guided_synthesis(
                    model, prompts, videos, noise_shape, args.ddim_steps,
                    args.ddim_eta, args.unconditional_guidance_scale,
                    fps[0] // fs[0], args.text_input, args.timestep_spacing,
                    args.guidance_rescale)
                results.extend(batch_samples)
                videos = repeat(batch_samples[0][:, :, -1, :, :].unsqueeze(2),
                                'b c t h w -> b c (repeat t) h w',
                                repeat=batch_samples[0].shape[2])
            # Concatenate all generations along the time axis.
            batch_samples = [torch.concat(results, axis=2)]

            # Save each example individually
            for nn, samples in enumerate(batch_samples):
                prompt = prompts[nn]
                filename = filenames[nn]
                save_results_seperate(prompt,
                                      samples,
                                      filename,
                                      fakedir,
                                      fps=8)
|
||||
|
||||
|
||||
def get_parser() -> argparse.ArgumentParser:
    """Build the command-line argument parser for base-model inference.

    Returns:
        argparse.ArgumentParser: Configured parser.
    """
    p = argparse.ArgumentParser()
    p.add_argument("--savedir", type=str, default=None,
                   help="Path to save the results.")
    p.add_argument("--ckpt_path", type=str, default=None,
                   help="Path to the model checkpoint.")
    p.add_argument("--config", type=str,
                   help="Path to the YAML configuration file.")
    p.add_argument("--prompt_dir", type=str, default=None,
                   help="Directory containing videos and corresponding prompts.")
    p.add_argument("--ddim_steps", type=int, default=50,
                   help="Number of DDIM steps. If non-positive, DDPM is used instead.")
    p.add_argument("--ddim_eta", type=float, default=1.0,
                   help="Eta for DDIM sampling. Set to 0.0 for deterministic results.")
    p.add_argument("--bs", type=int, default=1,
                   help="Batch size for inference. Must be 1.")
    p.add_argument("--height", type=int, default=320,
                   help="Height of the generated images in pixels.")
    p.add_argument("--width", type=int, default=512,
                   help="Width of the generated images in pixels.")
    p.add_argument("--unconditional_guidance_scale", type=float, default=1.0,
                   help="Scale for classifier-free guidance during sampling.")
    p.add_argument("--seed", type=int, default=123,
                   help="Random seed for reproducibility.")
    p.add_argument("--video_length", type=int, default=16,
                   help="Number of frames in the generated video.")
    p.add_argument("--text_input", action='store_true', default=False,
                   help="Whether to provide a text prompt as input to the image-to-video model.")
    p.add_argument("--timestep_spacing", type=str, default="uniform",
                   help="Strategy for timestep scaling. See Table 2 in the paper: 'Common Diffusion Noise Schedules and Sample Steps are Flawed' (https://huggingface.co/papers/2305.08891).")
    p.add_argument("--guidance_rescale", type=float, default=0.0,
                   help="Rescale factor for guidance as discussed in 'Common Diffusion Noise Schedules and Sample Steps are Flawed' (https://huggingface.co/papers/2305.08891).")
    p.add_argument("--perframe_ae", action='store_true', default=False,
                   help="Use per-frame autoencoder decoding to reduce GPU memory usage. Recommended for models with resolutions like 576x1024.")
    return p
|
||||
|
||||
|
||||
if __name__ == '__main__':
    parser = get_parser()
    args = parser.parse_args()
    seed = args.seed
    # A negative --seed requests a freshly drawn random seed.
    if seed < 0:
        seed = random.randint(0, 2**31)
    seed_everything(seed)
    # Single-process run: this entry point always acts as rank 0 of 1.
    rank, gpu_num = 0, 1
    run_inference(args, gpu_num, rank)
|
||||
77
scripts/evaluation/eval_utils.py
Normal file
77
scripts/evaluation/eval_utils.py
Normal file
@@ -0,0 +1,77 @@
|
||||
import torch
|
||||
import warnings
|
||||
import torchvision
|
||||
import sys
|
||||
import pyarrow as pa
|
||||
import logging
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Dict, Any, ClassVar, Deque, Mapping, Union
|
||||
from datasets.features.features import register_feature
|
||||
from torch.utils.tensorboard.writer import SummaryWriter
|
||||
|
||||
logging.basicConfig(stream=sys.stdout, level=logging.DEBUG)
|
||||
|
||||
|
||||
@dataclass
class VideoFrame:
    """
    Provides a type for a dataset containing video frames.

    Example:

    ```python
    data_dict = [{"image": {"path": "videos/episode_0.mp4", "timestamp": 0.3}}]
    features = {"image": VideoFrame()}
    Dataset.from_dict(data_dict, features=Features(features))
    ```
    """

    # Arrow schema: a frame is addressed by the path of its source video
    # plus a float32 timestamp within that video.
    pa_type: ClassVar[Any] = pa.struct({
        "path": pa.string(),
        "timestamp": pa.float32()
    })
    # Feature-type tag consumed by the `datasets` feature registry
    # (see register_feature below); excluded from __init__ and repr.
    _type: str = field(default="VideoFrame", init=False, repr=False)

    def __call__(self):
        # `datasets` calls a feature instance to obtain its Arrow type.
        return self.pa_type
|
||||
|
||||
|
||||
# Register VideoFrame with the `datasets` feature registry so datasets can
# (de)serialize it; the experimental-API warning is silenced deliberately.
with warnings.catch_warnings():
    warnings.filterwarnings(
        "ignore",
        "'register_feature' is experimental and might be subject to breaking changes in the future.",
        category=UserWarning,
    )
    register_feature(VideoFrame, "VideoFrame")
|
||||
|
||||
|
||||
def populate_queues(
        queues: Dict[str, Deque[Any]],
        batch: Mapping[str, Any]) -> Dict[str, Deque[Any]]:
    """Append batch values onto their matching observation queues.

    A bounded queue that is not yet full is padded with copies of the
    incoming value until it reaches ``maxlen`` (warm-up fill); a full or
    unbounded queue simply gets the value appended (a full ``deque``
    evicts its oldest entry automatically).

    Args:
        queues (Dict[str, Deque[Any]]): Per-key queues, updated in place.
        batch (Mapping[str, Any]): New values; keys absent from `queues`
            are ignored.

    Returns:
        Dict[str, Deque[Any]]: The same `queues` mapping.
    """
    for key in batch:
        if key not in queues:
            continue
        queue = queues[key]
        # BUGFIX: guard against unbounded deques (maxlen=None) — the old
        # `while len(q) != q.maxlen` warm-fill never terminated for them.
        if queue.maxlen is None or len(queue) == queue.maxlen:
            queue.append(batch[key])
        else:
            while len(queue) != queue.maxlen:
                queue.append(batch[key])
    return queues
|
||||
|
||||
|
||||
def log_to_tensorboard(
        writer: SummaryWriter,
        data: Union[torch.Tensor, Any],
        tag: str,
        fps: int = 10) -> None:
    """Log a 5-D video tensor to TensorBoard under *tag*.

    Non-tensor inputs, or tensors that are not 5-dimensional, are ignored.
    Values are rescaled via (x + 1) / 2 — presumably inputs lie in [-1, 1];
    confirm against callers.
    """
    if not (isinstance(data, torch.Tensor) and data.dim() == 5):
        return
    batch = data.shape[0]
    # Time-major view: (T, B, C, H, W), then one tiled grid per timestep.
    frames = data.permute(2, 0, 1, 3, 4)
    sheets = [
        torchvision.utils.make_grid(frame, nrow=int(batch), padding=0)
        for frame in frames
    ]
    clip = (torch.stack(sheets, dim=0) + 1.0) / 2.0
    writer.add_video(tag, clip.unsqueeze(dim=0), fps=fps)
|
||||
463
scripts/evaluation/real_eval_server.py
Normal file
463
scripts/evaluation/real_eval_server.py
Normal file
@@ -0,0 +1,463 @@
|
||||
import argparse, os, sys
|
||||
import torch
|
||||
import torchvision
|
||||
import warnings
|
||||
import imageio
|
||||
import logging
|
||||
import matplotlib.pyplot as plt
|
||||
plt.switch_backend('agg')
|
||||
import traceback
|
||||
import uvicorn
|
||||
|
||||
from omegaconf import OmegaConf
|
||||
from einops import rearrange, repeat
|
||||
from collections import OrderedDict
|
||||
from pytorch_lightning import seed_everything
|
||||
from torch import nn
|
||||
from fastapi import FastAPI
|
||||
from fastapi.responses import JSONResponse
|
||||
from typing import Any, Dict, Optional, Tuple, List
|
||||
from datetime import datetime
|
||||
|
||||
from unifolm_wma.utils.utils import instantiate_from_config
|
||||
from unifolm_wma.models.samplers.ddim import DDIMSampler
|
||||
|
||||
|
||||
def get_device_from_parameters(module: nn.Module) -> torch.device:
    """Get a module's device by checking one of its parameters.

    Args:
        module (nn.Module): PyTorch module with at least one parameter.

    Returns:
        torch.device: The device hosting the module's parameters.
    """
    first_param = next(iter(module.parameters()))
    return first_param.device
|
||||
|
||||
|
||||
def load_model_checkpoint(model: nn.Module, ckpt: str) -> nn.Module:
    """Load model weights from checkpoint file.

    Supports three checkpoint layouts:
      * a Lightning-style checkpoint with a top-level ``state_dict`` entry,
      * an older layout whose keys use ``framestride_embed`` instead of
        ``fps_embedding`` (keys are renamed before loading),
      * a DeepSpeed checkpoint with a ``module`` entry whose keys carry a
        16-character prefix that is stripped.

    Args:
        model (nn.Module): Model to load weights into.
        ckpt (str): Path to checkpoint file.

    Returns:
        nn.Module: The same model instance with loaded weights.
    """
    state_dict = torch.load(ckpt, map_location="cpu")
    # Membership test works directly on the dict; no list() copy needed.
    if "state_dict" in state_dict:
        state_dict = state_dict["state_dict"]
        try:
            model.load_state_dict(state_dict, strict=False)
        # Narrowed from a bare `except:` (which also swallowed
        # KeyboardInterrupt/SystemExit).
        except Exception:
            # Rename keys from the older "framestride_embed" layout.
            new_pl_sd = OrderedDict(state_dict)
            for k in list(new_pl_sd.keys()):
                if "framestride_embed" in k:
                    new_key = k.replace("framestride_embed", "fps_embedding")
                    new_pl_sd[new_key] = new_pl_sd.pop(k)
            model.load_state_dict(new_pl_sd, strict=False)
    else:
        # DeepSpeed layout: strip the 16-char key prefix under 'module'.
        new_pl_sd = OrderedDict()
        for key, value in state_dict['module'].items():
            new_pl_sd[key[16:]] = value
        model.load_state_dict(new_pl_sd)
    print('>>> model checkpoint loaded.')
    return model
|
||||
|
||||
|
||||
def write_video(video_path: str, stacked_frames: List[Any], fps: int) -> None:
    """Write a sequence of frames to a video file via imageio.

    Args:
        video_path (str): Destination path of the video.
        stacked_frames (List[Any]): Frames to encode, in order.
        fps (int): Frames per second.
    """
    with warnings.catch_warnings():
        # imageio still touches pkg_resources internally; suppress only
        # that specific deprecation message while encoding.
        warnings.filterwarnings("ignore",
                                "pkg_resources is deprecated as an API",
                                category=DeprecationWarning)
        imageio.mimsave(video_path, stacked_frames, fps=fps)
|
||||
|
||||
|
||||
def save_results(video: torch.Tensor, filename: str, fps: int = 8) -> None:
    """Save a batched video tensor as a single tiled MP4.

    All batch items are tiled side by side into one grid per timestep
    before encoding.

    Args:
        video (torch.Tensor): Video tensor of shape (B, C, T, H, W),
            values in [-1, 1].
        filename (str): Output path of the MP4 file.
        fps (int, optional): Frame rate. Defaults to 8.
    """
    clip = torch.clamp(video.detach().cpu().float(), -1., 1.)
    batch = clip.shape[0]
    # Time-major: iterate frames, tiling the batch into one grid each.
    frames = clip.permute(2, 0, 1, 3, 4)
    grids = [
        torchvision.utils.make_grid(frame, nrow=int(batch), padding=0)
        for frame in frames
    ]
    # Map [-1, 1] floats to [0, 255] uint8, channels-last [T, H, W, C].
    stacked = (torch.stack(grids, dim=0) + 1.0) / 2.0
    encoded = (stacked * 255).to(torch.uint8).permute(0, 2, 3, 1)
    torchvision.io.write_video(filename,
                               encoded,
                               fps=fps,
                               video_codec='h264',
                               options={'crf': '10'})
|
||||
|
||||
|
||||
def get_latent_z(model: nn.Module, videos: torch.Tensor) -> torch.Tensor:
    """Encode videos into latent space.

    Args:
        model (nn.Module): Model with `encode_first_stage` method.
        videos (torch.Tensor): Input videos (B, C, T, H, W).

    Returns:
        torch.Tensor: Latent representation (B, C', T, H', W').
    """
    b, c, t, h, w = videos.shape
    # Torch-native equivalent of einops rearrange; drops the third-party
    # dependency for this trivial fold/unfold.
    # Fold time into the batch axis: [B, C, T, H, W] -> [B*T, C, H, W].
    flat = videos.permute(0, 2, 1, 3, 4).reshape(b * t, c, h, w)
    z = model.encode_first_stage(flat)
    # Unfold back to [B, C', T, H', W'] (channel/spatial dims may change).
    return z.reshape(b, t, *z.shape[1:]).permute(0, 2, 1, 3, 4)
|
||||
|
||||
|
||||
def image_guided_synthesis(
        model: torch.nn.Module,
        prompts: list[str],
        observation: Dict[str, torch.Tensor],
        noise_shape: tuple[int, int, int, int, int],
        ddim_steps: int = 50,
        ddim_eta: float = 1.0,
        unconditional_guidance_scale: float = 1.0,
        fs: int | None = None,
        timestep_spacing: str = 'uniform',
        guidance_rescale: float = 0.0,
        **kwargs) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Run DDIM sampling conditioned on the latest frame, state and text.

    Args:
        model (nn.Module): Diffusion model.
        prompts (Any): Conditioning text prompts.
        observation (Dict[str, torch.Tensor]): Observation dictionary with
            'observation.images.top', 'observation.state' and 'action'.
        noise_shape (List[int]): Shape of noise tensor [B, C, T, H, W].
        ddim_steps (int, optional): Number of DDIM steps. Defaults to 50.
        ddim_eta (float, optional): Sampling eta. Defaults to 1.0.
        unconditional_guidance_scale (float, optional): Guidance scale.
            Defaults to 1.0.
        fs (Optional[int], optional): Frame stride or FPS. Defaults to None.
        timestep_spacing (str, optional): Spacing strategy. Defaults to
            "uniform".
        guidance_rescale (float, optional): Guidance rescale. Defaults to 0.0.
        **kwargs (Any): Additional sampler arguments.

    Returns:
        Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: decoded videos,
        predicted actions, predicted states.
    """
    b, _, t, _, _ = noise_shape
    ddim_sampler = DDIMSampler(model)
    batch_size = noise_shape[0]
    # Broadcast the scalar frame-stride to one value per batch element.
    fs = torch.tensor([fs] * batch_size, dtype=torch.long, device=model.device)

    # Embed the most recent camera frame for cross-attention conditioning.
    img = observation['observation.images.top']
    cond_img = img[:, -1, ...]
    cond_img_emb = model.embedder(cond_img)
    cond_img_emb = model.image_proj_model(cond_img_emb)

    # BUGFIX: initialize `cond` unconditionally — it was previously created
    # only inside the 'hybrid' branch, raising NameError for any other
    # conditioning key when `cond["c_crossattn"]` was assigned below.
    cond = {}
    if model.model.conditioning_key == 'hybrid':
        # Concat-condition on the latest frame's latent, tiled across time.
        z = get_latent_z(model, img.permute(0, 2, 1, 3, 4))
        img_cat_cond = z[:, :, -1:, :, :]
        img_cat_cond = repeat(img_cat_cond,
                              'b c t h w -> b c (repeat t) h w',
                              repeat=noise_shape[2])
        cond["c_concat"] = [img_cat_cond]

    cond_ins_emb = model.get_learned_conditioning(prompts)
    cond_state = model.state_projector(observation['observation.state'])
    cond_state_emb = model.agent_state_pos_emb + cond_state

    cond_action = model.action_projector(observation['action'])
    cond_action_emb = model.agent_action_pos_emb + cond_action
    # NOTE(review): the action embedding is computed and then zeroed out —
    # presumably to disable action conditioning at inference; confirm intent.
    cond_action_emb = torch.zeros_like(cond_action_emb)

    cond["c_crossattn"] = [
        torch.cat([cond_state_emb, cond_ins_emb, cond_img_emb], dim=1)
    ]
    # Raw recent observations consumed by the action branch.
    cond["c_crossattn_action"] = [
        observation['observation.images.top'].permute(
            0, 2, 1, 3, 4)[:, :, -model.n_obs_steps_acting:],
        observation['observation.state'][:, -model.n_obs_steps_acting:]
    ]

    # No explicit negative conditioning; CFG scale alone steers sampling.
    uc = None

    kwargs.update({"unconditional_conditioning_img_nonetext": None})

    cond_mask = None
    cond_z0 = None

    if ddim_sampler is not None:

        samples, actions, states, intermedia = ddim_sampler.sample(
            S=ddim_steps,
            conditioning=cond,
            batch_size=batch_size,
            shape=noise_shape[1:],
            verbose=False,
            unconditional_guidance_scale=unconditional_guidance_scale,
            unconditional_conditioning=uc,
            eta=ddim_eta,
            cfg_img=None,
            mask=cond_mask,
            x0=cond_z0,
            fs=fs,
            timestep_spacing=timestep_spacing,
            guidance_rescale=guidance_rescale,
            **kwargs)

    # Reconstruct from latent to pixel space
    batch_images = model.decode_first_stage(samples)
    batch_variants = batch_images

    return batch_variants, actions, states
|
||||
|
||||
|
||||
def run_inference(args: argparse.Namespace, gpu_num: int,
                  gpu_no: int) -> Tuple[nn.Module, List[int], Any]:
    """
    Prepare the model, latent-noise shape and data module for the server.

    Unlike the offline script of the same name, this variant performs no
    sampling itself; it only loads and returns the pieces the evaluation
    server needs.

    Args:
        args (argparse.Namespace): Parsed command-line arguments.
        gpu_num (int): Number of GPUs.
        gpu_no (int): Index of the current GPU.

    Returns:
        Tuple[nn.Module, List[int], Any]: The loaded model, the latent noise
        shape [bs, channels, n_frames, h, w], and the instantiated data
        module (already set up).
    """
    # Load config
    config = OmegaConf.load(args.config)
    # Set use_checkpoint as False as when using deepspeed, it encounters an error "deepspeed backend not set"
    config['model']['params']['wma_config']['params']['use_checkpoint'] = False
    model = instantiate_from_config(config.model)
    model.perframe_ae = args.perframe_ae
    assert os.path.exists(args.ckpt_path), "Error: checkpoint Not Found!"
    model = load_model_checkpoint(model, args.ckpt_path)
    model = model.cuda(gpu_no)
    model.eval()
    print(">>> Model is successfully loaded ...")

    # Build unnomalizer
    logging.info("***** Configing Data *****")
    data = instantiate_from_config(config.data)
    data.setup()
    print(">>> Dataset is successfully loaded ...")

    ## Run over data
    assert (args.height % 16 == 0) and (
        args.width % 16
        == 0), "Error: image size [h,w] should be multiples of 16!"
    assert args.bs == 1, "Current implementation only support [batch size = 1]!"

    ## Get latent noise shape (VAE downsamples spatial dims by 8).
    h, w = args.height // 8, args.width // 8
    channels = model.model.diffusion_model.out_channels
    n_frames = args.video_length
    print(f'>>> Generate {n_frames} frames under each generation ...')
    noise_shape = [args.bs, channels, n_frames, h, w]

    return model, noise_shape, data
|
||||
|
||||
|
||||
def get_parser() -> argparse.ArgumentParser:
    """Build the command-line argument parser for base-model inference."""
    p = argparse.ArgumentParser()
    p.add_argument("--savedir", type=str, default=None,
                   help="Path to save the results.")
    p.add_argument("--ckpt_path", type=str, default=None,
                   help="Path to the model checkpoint.")
    p.add_argument("--config", type=str, help="Path to the config file.")
    p.add_argument("--ddim_steps", type=int, default=50,
                   help="Number of DDIM steps. If non-positive, DDPM is used instead.")
    p.add_argument("--ddim_eta", type=float, default=1.0,
                   help="Eta for DDIM sampling. Set to 0.0 for deterministic results.")
    p.add_argument("--bs", type=int, default=1,
                   help="Batch size for inference. Must be 1.")
    p.add_argument("--height", type=int, default=320,
                   help="Height of the generated images in pixels.")
    p.add_argument("--width", type=int, default=512,
                   help="Width of the generated images in pixels.")
    p.add_argument("--frame_stride", type=int, default=3,
                   help="frame stride control for 256 model (larger->larger motion), FPS control for 512 or 1024 model (smaller->larger motion)")
    p.add_argument("--unconditional_guidance_scale", type=float, default=1.0,
                   help="Scale for classifier-free guidance during sampling.")
    p.add_argument("--seed", type=int, default=123,
                   help="Random seed for reproducibility.")
    p.add_argument("--video_length", type=int, default=16,
                   help="Number of frames in the generated video.")
    p.add_argument("--timestep_spacing", type=str, default="uniform",
                   help="Strategy for timestep scaling. See Table 2 in the paper: 'Common Diffusion Noise Schedules and Sample Steps are Flawed' (https://huggingface.co/papers/2305.08891).")
    p.add_argument("--guidance_rescale", type=float, default=0.0,
                   help="Rescale factor for guidance as discussed in 'Common Diffusion Noise Schedules and Sample Steps are Flawed' (https://huggingface.co/papers/2305.08891).")
    p.add_argument("--perframe_ae", action='store_true', default=False,
                   help="Use per-frame autoencoder decoding to reduce GPU memory usage. Recommended for models with resolutions like 576x1024.")
    return p
|
||||
|
||||
|
||||
class Server:
    """FastAPI-based inference server wrapping the world-model-action policy.

    Loads the model/dataset once at startup and answers ``/predict_action``
    POST requests with de-normalized action sequences plus a saved video.
    """

    def __init__(self, args: argparse.Namespace) -> None:
        # Build model, latent-noise shape and data module once at startup
        # (single process, GPU 0).
        self.model_, self.noise_shape_, self.data_ = run_inference(args, 1, 0)
        self.args_ = args
        self.dataset_name = self.data_.dataset_configs['test']['params'][
            'dataset_name']
        self.device_ = get_device_from_parameters(self.model_)

    def normalize_image(self, image: torch.Tensor) -> torch.Tensor:
        # Map uint8 pixel range [0, 255] to [-1, 1].
        return (image / 255 - 0.5) * 2

    def predict_action(self, payload: "Dict[str, Any]") -> "Any":
        """Handle one prediction request.

        Expects ``payload`` with keys ``observation.images.top``,
        ``observation.state``, ``action`` (zeros) and ``language_instruction``;
        returns a JSON response with the predicted action sequence.
        """
        try:
            images = payload['observation.images.top']
            states = payload['observation.state']
            actions = payload['action']  # Should be all zeros
            language_instruction = payload['language_instruction']

            # Transform + normalize the camera frames to model space.
            images = torch.tensor(images).cuda()
            images = self.data_.test_datasets[
                self.dataset_name].spatial_transform(images).unsqueeze(0)
            images = self.normalize_image(images)
            print(f"images shape: {images.shape} ...")
            # Normalize the robot state and map it into the unified state space.
            states = torch.tensor(states)
            states = self.data_.test_datasets[self.dataset_name].normalizer(
                {'observation.state': states})['observation.state']
            states, _ = self.data_.test_datasets[
                self.dataset_name]._map_to_uni_state(states, "joint position")
            print(f"states shape: {states.shape} ...")
            # Map the (zero) action placeholder; keep the mask to select the
            # valid action dimensions of this embodiment later.
            actions = torch.tensor(actions)
            actions, action_mask = self.data_.test_datasets[
                self.dataset_name]._map_to_uni_action(actions,
                                                      "joint position")
            print(f"actions shape: {actions.shape} ...")
            print("=" * 20)
            states = states.unsqueeze(0).cuda()
            actions = actions.unsqueeze(0).cuda()

            observation = {
                'observation.images.top': images,
                'observation.state': states,
                'action': actions
            }
            observation = {
                key: observation[key].to(self.device_, non_blocking=True)
                for key in observation
            }

            args = self.args_
            # NOTE(review): keyword `ddim_ets` looks like a typo for
            # `ddim_eta` — confirm against image_guided_synthesis's signature.
            pred_videos, pred_action, _ = image_guided_synthesis(
                self.model_,
                language_instruction,
                observation,
                self.noise_shape_,
                ddim_steps=args.ddim_steps,
                ddim_ets=args.ddim_eta,
                unconditional_guidance_scale=args.unconditional_guidance_scale,
                fs=30 / args.frame_stride,
                timestep_spacing=args.timestep_spacing,
                guidance_rescale=args.guidance_rescale)

            # Keep only this embodiment's action dims, then de-normalize.
            pred_action = pred_action[..., action_mask[0] == 1.0][0].cpu()
            pred_action = self.data_.test_datasets[
                self.dataset_name].unnormalizer({'action':
                                                 pred_action})['action']

            os.makedirs(args.savedir, exist_ok=True)
            # Fix: the module does `import datetime`, so the class must be
            # reached as datetime.datetime — bare `datetime.now()` raises
            # AttributeError on the module object.
            current_time = datetime.datetime.now().strftime("%H:%M:%S")
            video_file = f'{args.savedir}/{current_time}.mp4'
            save_results(pred_videos.cpu(), video_file)

            response = {
                'result': 'ok',
                'action': pred_action.tolist(),
                'desc': 'success'
            }
            return JSONResponse(response)

        # Fix: was a bare `except:`, which also swallows SystemExit and
        # KeyboardInterrupt; catch Exception instead.
        except Exception:
            logging.error(traceback.format_exc())
            logging.warning(
                "Your request threw an error; make sure your request complies with the expected format:\n"
                "{'image': np.ndarray, 'instruction': str}\n"
                "You can optionally an `unnorm_key: str` to specific the dataset statistics you want to use for "
                "de-normalizing the output actions.")
            return {'result': 'error', 'desc': traceback.format_exc()}

    def run(self, host: str = "127.0.0.1", port: int = 8000) -> None:
        """Register the route and serve blocking until uvicorn exits."""
        self.app = FastAPI()
        self.app.post("/predict_action")(self.predict_action)
        print(">>> Inference server is ready ... ")
        uvicorn.run(self.app, host=host, port=port)
        print(">>> Inference server stops ... ")
        return
|
||||
|
||||
|
||||
if __name__ == '__main__':
    # Parse CLI options and fix all RNG seeds for reproducibility.
    cli_args = get_parser().parse_args()
    seed_everything(cli_args.seed)
    # Single-process, single-GPU serving.
    rank, gpu_num = 0, 1
    print(">>> Launch inference server ... ")
    # Blocks inside uvicorn until the server is stopped.
    Server(cli_args).run()
|
||||
1220
scripts/evaluation/world_model_interaction.py
Normal file
1220
scripts/evaluation/world_model_interaction.py
Normal file
File diff suppressed because it is too large
Load Diff
23
scripts/run_base_model_inference.sh
Normal file
23
scripts/run_base_model_inference.sh
Normal file
@@ -0,0 +1,23 @@
|
||||
#!/bin/bash
# Run base-model (video generation) inference on a single GPU.
# NOTE(review): the paths below are placeholders — fill them in before use.

model_name=base_model
ckpt=/path/to/base/model
config=configs/inference/base_model_inference.yaml
res_dir="/path/to/result/directory"
seed=123

# NOTE(review): --prompt_dir and --text_input are not defined by the argument
# parser visible in scripts/evaluation/base_model_inference.py — confirm the
# target script actually accepts them, otherwise argparse will reject the call.
CUDA_VISIBLE_DEVICES=0 python3 scripts/evaluation/base_model_inference.py \
--seed ${seed} \
--ckpt_path $ckpt \
--config $config \
--savedir "${res_dir}/videos" \
--bs 1 --height 320 --width 512 \
--unconditional_guidance_scale 1.0 \
--ddim_steps 16 \
--ddim_eta 1.0 \
--prompt_dir "/path/to/examples/base_model_prompts" \
--text_input \
--video_length 16 \
--timestep_spacing 'uniform_trailing' \
--guidance_rescale 0.7 \
--perframe_ae
|
||||
26
scripts/run_real_eval_server.sh
Normal file
26
scripts/run_real_eval_server.sh
Normal file
@@ -0,0 +1,26 @@
|
||||
# Launch the real-robot evaluation server (real_eval_server.py), once per dataset.
# NOTE(review): no shebang line — invoke via `bash scripts/run_real_eval_server.sh`.
# NOTE(review): res_dir is missing a leading '/' — confirm a relative path is intended.

model_name=testing
ckpt=/path/to/model/checkpoint
config=configs/inference/world_model_decision_making.yaml
seed=123
res_dir="path/to/results/directory"
datasets=(
"unitree_g1_pack_camera"
)


# One server process per dataset, single GPU.
for dataset in "${datasets[@]}"; do
CUDA_VISIBLE_DEVICES=0 python3 scripts/evaluation/real_eval_server.py \
--seed ${seed} \
--ckpt_path $ckpt \
--config $config \
--savedir "${res_dir}/${dataset}/${model_name}/videos" \
--bs 1 --height 320 --width 512 \
--unconditional_guidance_scale 1.0 \
--ddim_steps 16 \
--ddim_eta 1.0 \
--video_length 16 \
--frame_stride 2 \
--timestep_spacing 'uniform_trailing' \
--guidance_rescale 0.7 \
--perframe_ae
done
|
||||
42
scripts/run_world_model_interaction.sh
Normal file
42
scripts/run_world_model_interaction.sh
Normal file
@@ -0,0 +1,42 @@
|
||||
# Run world-model interaction evaluation over each dataset in turn.
# NOTE(review): no shebang line — invoke via `bash scripts/run_world_model_interaction.sh`.

model_name=testing
ckpt=/path/to/model/checkpoint
config=configs/inference/world_model_interaction.yaml
seed=123
res_dir="/path/to/result/directory"

datasets=(
"unitree_z1_stackbox"
"unitree_z1_dual_arm_stackbox"
"unitree_z1_dual_arm_stackbox_v2"
"unitree_z1_dual_arm_cleanup_pencils"
"unitree_g1_pack_camera"
)

# Per-dataset settings, aligned by index with `datasets`:
#   n_iters — number of generation iterations per dataset
#   fses    — frame-stride value passed as --frame_stride
n_iters=(12 7 11 8 11)
fses=(4 4 4 4 6)

for i in "${!datasets[@]}"; do
dataset=${datasets[$i]}
n_iter=${n_iters[$i]}
fs=${fses[$i]}

CUDA_VISIBLE_DEVICES=0 python3 scripts/evaluation/world_model_interaction.py \
--seed ${seed} \
--ckpt_path $ckpt \
--config $config \
--savedir "${res_dir}/${model_name}/${dataset}" \
--bs 1 --height 320 --width 512 \
--unconditional_guidance_scale 1.0 \
--ddim_steps 50 \
--ddim_eta 1.0 \
--prompt_dir "/path/to/unifolm-world-model-action/examples/world_model_interaction_prompts" \
--dataset ${dataset} \
--video_length 16 \
--frame_stride ${fs} \
--n_action_steps 16 \
--exe_steps 16 \
--n_iter ${n_iter} \
--timestep_spacing 'uniform_trailing' \
--guidance_rescale 0.7 \
--perframe_ae
done
|
||||
32
scripts/train.sh
Normal file
32
scripts/train.sh
Normal file
@@ -0,0 +1,32 @@
|
||||
# NCCL configuration (all optional; uncomment for multi-node debugging/tuning)
# export NCCL_DEBUG=debug
# export NCCL_IB_DISABLE=0
# export NCCL_IB_GID_INDEX=3
# export NCCL_NET_GDR_LEVEL=3
# export CUDA_LAUNCH_BLOCKING=1

# export NCCL_TOPO_FILE=/tmp/topo.txt
# export MASTER_ADDR="master.ip."
# export MASTER_PORT=12366   # (fixed typo in commented-out line: was MASTER_PROT)

# args
name="experiment_name"
config_file=configs/train/config.yaml

# save root dir for logs, checkpoints, tensorboard record, etc.
save_root="/path/to/savedir"

mkdir -p $save_root/$name

## run: single node x 8 GPUs via torch.distributed.launch; the trailing
## `lightning.trainer.num_nodes=1` is consumed by trainer.py as an OmegaConf
## dotlist override, and `--total_gpus=8` is parsed from the unknown args.
CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 python3 -m torch.distributed.launch \
--nproc_per_node=8 --nnodes=1 --master_addr=127.0.0.1 --master_port=12366 --node_rank=0 \
./scripts/trainer.py \
--base $config_file \
--train \
--name $name \
--logdir $save_root \
--devices 8 \
--total_gpus=8 \
lightning.trainer.num_nodes=1
|
||||
214
scripts/trainer.py
Normal file
214
scripts/trainer.py
Normal file
@@ -0,0 +1,214 @@
|
||||
import argparse, os, datetime
|
||||
import pytorch_lightning as pl
|
||||
import torch
|
||||
|
||||
from omegaconf import OmegaConf
|
||||
from transformers import logging as transf_logging
|
||||
from pytorch_lightning import seed_everything
|
||||
from pytorch_lightning.trainer import Trainer
|
||||
|
||||
from unifolm_wma.utils.utils import instantiate_from_config
|
||||
from unifolm_wma.utils.train import get_trainer_callbacks, get_trainer_logger, get_trainer_strategy
|
||||
from unifolm_wma.utils.train import set_logger, init_workspace, load_checkpoints, get_num_parameters
|
||||
|
||||
|
||||
def get_parser(**parser_kwargs):
    """Build the command-line argument parser for training.

    Args:
        **parser_kwargs: Extra keyword arguments forwarded to
            ``argparse.ArgumentParser``.

    Returns:
        argparse.ArgumentParser: Parser with the training options.
    """
    parser = argparse.ArgumentParser(**parser_kwargs)
    parser.add_argument("--seed",
                        "-s",
                        type=int,
                        default=20250912,
                        help="seed for seed_everything")
    parser.add_argument("--name",
                        "-n",
                        type=str,
                        default="",
                        help="experiment name, as saving folder")
    parser.add_argument(
        "--base",
        "-b",
        nargs="*",
        metavar="base_config.yaml",
        help="paths to base configs. Loaded from left-to-right.",
        default=list())
    parser.add_argument("--train",
                        "-t",
                        action='store_true',
                        default=False,
                        help='train')
    parser.add_argument("--val",
                        "-v",
                        action='store_true',
                        default=False,
                        help='val')
    parser.add_argument("--test",
                        action='store_true',
                        default=False,
                        help='test')
    # Fix: replaced unprofessional leftover help text ("dat shit").
    parser.add_argument("--logdir",
                        "-l",
                        type=str,
                        default="logs",
                        help="directory for logs, checkpoints and records")
    parser.add_argument("--auto_resume",
                        action='store_true',
                        default=False,
                        help="resume from full-info checkpoint")
    parser.add_argument("--auto_resume_weight_only",
                        action='store_true',
                        default=False,
                        help="resume from weight-only checkpoint")
    parser.add_argument("--debug",
                        "-d",
                        action='store_true',
                        default=False,
                        help="enable post-mortem debugging")
    return parser
|
||||
|
||||
|
||||
def get_nondefault_trainer_args(args):
    """Return the sorted names of Trainer arguments whose values on *args*
    differ from the Trainer's argparse defaults."""
    defaults = Trainer.add_argparse_args(
        argparse.ArgumentParser()).parse_args([])
    changed = [
        name for name in vars(defaults)
        if getattr(args, name) != getattr(defaults, name)
    ]
    return sorted(changed)
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Timestamp tag for this run's log file name.
    now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S")
    # Rank info is injected via env vars by torch.distributed.launch
    # (see scripts/train.sh).
    # NOTE(review): os.environ.get returns None when a variable is unset, so a
    # bare (non-distributed) invocation crashes in int(...) — confirm this
    # fail-fast behavior is intended.
    local_rank = int(os.environ.get('LOCAL_RANK'))
    global_rank = int(os.environ.get('RANK'))
    num_rank = int(os.environ.get('WORLD_SIZE'))

    parser = get_parser()
    # Extends existing argparse by default Trainer attributes
    parser = Trainer.add_argparse_args(parser)
    # Unrecognized options stay in `unknown` and are merged into the config
    # below as OmegaConf dotlist overrides (e.g. lightning.trainer.num_nodes=1).
    args, unknown = parser.parse_known_args()
    transf_logging.set_verbosity_error()
    seed_everything(args.seed)

    # Merge the --base config files left-to-right, then apply CLI overrides.
    configs = [OmegaConf.load(cfg) for cfg in args.base]
    cli = OmegaConf.from_dotlist(unknown)
    config = OmegaConf.merge(*configs, cli)
    # The `lightning` subtree configures the Trainer; the remainder configures
    # model and data.
    lightning_config = config.pop("lightning", OmegaConf.create())
    trainer_config = lightning_config.get("trainer", OmegaConf.create())

    # Setup workspace directories
    workdir, ckptdir, cfgdir, loginfo = init_workspace(args.name, args.logdir,
                                                       config,
                                                       lightning_config,
                                                       global_rank)
    logger = set_logger(
        logfile=os.path.join(loginfo, 'log_%d:%s.txt' % (global_rank, now)))
    logger.info("@lightning version: %s [>=1.8 required]" % (pl.__version__))
    logger.info("***** Configing Model *****")
    model = instantiate_from_config(config.model)
    # Load checkpoints
    model = load_checkpoints(model, config.model)

    # Register_schedule again to make ZTSNR work
    if model.rescale_betas_zero_snr:
        model.register_schedule(given_betas=model.given_betas,
                                beta_schedule=model.beta_schedule,
                                timesteps=model.timesteps,
                                linear_start=model.linear_start,
                                linear_end=model.linear_end,
                                cosine_s=model.cosine_s)

    # Update trainer config with any Trainer flags overridden on the CLI.
    for k in get_nondefault_trainer_args(args):
        trainer_config[k] = getattr(args, k)

    num_nodes = trainer_config.num_nodes
    ngpu_per_node = trainer_config.devices
    logger.info(f"Running on {num_rank}={num_nodes}x{ngpu_per_node} GPUs")

    # Setup learning rate: scale base LR by world size and batch size unless
    # model.scale_lr is explicitly disabled in the config.
    base_lr = config.model.base_learning_rate
    bs = config.data.params.batch_size
    if getattr(config.model, 'scale_lr', True):
        model.learning_rate = num_rank * bs * base_lr
    else:
        model.learning_rate = base_lr

    logger.info("***** Configing Data *****")
    data = instantiate_from_config(config.data)
    data.setup()
    # Log the size of each train/val dataset for this run.
    for k in data.train_datasets:
        logger.info(
            f"{k}, {data.train_datasets[k].__class__.__name__}, {len(data.train_datasets[k])}"
        )
    if hasattr(data, 'val_datasets'):
        for k in data.val_datasets:
            logger.info(
                f"{k}, {data.val_datasets[k].__class__.__name__}, {len(data.val_datasets[k])}"
            )

    # NOTE(review): num_gpus is parsed from the unknown args (--total_gpus=N,
    # see scripts/train.sh) but never read afterwards in this block — confirm
    # whether it is dead code or consumed elsewhere.
    for item in unknown:
        if item.startswith('--total_gpus'):
            num_gpus = int(item.split('=')[-1])
            break
    model.datasets_len = len(data)

    logger.info("***** Configing Trainer *****")
    if "accelerator" not in trainer_config:
        trainer_config["accelerator"] = "gpu"

    # Setup trainer args: pl-logger and callbacks
    trainer_kwargs = dict()
    trainer_kwargs["num_sanity_val_steps"] = 0
    logger_cfg = get_trainer_logger(lightning_config, workdir, args.debug)
    trainer_kwargs["logger"] = instantiate_from_config(logger_cfg)

    # Setup callbacks
    callbacks_cfg = get_trainer_callbacks(lightning_config, config, workdir,
                                          ckptdir, logger)
    trainer_kwargs["callbacks"] = [
        instantiate_from_config(callbacks_cfg[k]) for k in callbacks_cfg
    ]
    # The strategy may be given either as a plain string (e.g. "ddp") or as an
    # instantiable config node.
    strategy_cfg = get_trainer_strategy(lightning_config)
    trainer_kwargs["strategy"] = strategy_cfg if type(
        strategy_cfg) == str else instantiate_from_config(strategy_cfg)
    trainer_kwargs['precision'] = lightning_config.get('precision', 32)
    trainer_kwargs["sync_batchnorm"] = False

    # Trainer config: others
    trainer_args = argparse.Namespace(**trainer_config)
    trainer = Trainer.from_argparse_args(trainer_args, **trainer_kwargs)

    # Allow checkpointing via USR1
    def melk(*args, **kwargs):
        # Save an emergency checkpoint from rank 0 on SIGUSR1.
        if trainer.global_rank == 0:
            print("Summoning checkpoint.")
            ckpt_path = os.path.join(ckptdir, "last_summoning.ckpt")
            trainer.save_checkpoint(ckpt_path)

    def divein(*args, **kwargs):
        # Drop rank 0 into an interactive debugger on SIGUSR2.
        if trainer.global_rank == 0:
            import pudb
            pudb.set_trace()

    import signal
    signal.signal(signal.SIGUSR1, melk)
    signal.signal(signal.SIGUSR2, divein)

    # List the key model sizes
    # NOTE(review): total_params is not used after this call — presumably
    # get_num_parameters logs the counts itself; verify.
    total_params = get_num_parameters(model)

    logger.info("***** Running the Loop *****")
    if args.train:
        try:
            if "strategy" in lightning_config and lightning_config[
                    'strategy'].startswith('deepspeed'):
                logger.info("<Training in DeepSpeed Mode>")
                # Under deepspeed with fp16 precision, wrap fit() in autocast.
                if trainer_kwargs['precision'] == 16:
                    with torch.cuda.amp.autocast():
                        trainer.fit(model, data)
                else:
                    trainer.fit(model, data)
            else:
                logger.info("<Training in DDPSharded Mode>")
                trainer.fit(model, data)
        except Exception:
            raise
|
||||
Reference in New Issue
Block a user