第一次完整测例跑完

This commit is contained in:
2026-01-18 00:30:10 +08:00
parent ca15cc593b
commit 25c6fc04db
180 changed files with 29305 additions and 0 deletions

View File

@@ -0,0 +1,541 @@
import argparse, os, glob
import datetime, time
import pandas as pd
import torch
import torchvision
import torchvision.transforms as transforms
import random
from pytorch_lightning import seed_everything
from PIL import Image
from omegaconf import OmegaConf
from tqdm import tqdm
from einops import rearrange, repeat
from collections import OrderedDict
from unifolm_wma.models.samplers.ddim import DDIMSampler
from unifolm_wma.utils.utils import instantiate_from_config
def get_filelist(data_dir: str, postfixes: list[str]) -> list[str]:
    """
    Collect files in `data_dir` whose extension is one of `postfixes`.

    Args:
        data_dir (str): Directory path.
        postfixes (list[str]): File extensions without the dot (e.g., ['csv', 'jpg']).

    Returns:
        list[str]: Sorted list of matching file paths.
    """
    matches = []
    for ext in postfixes:
        matches.extend(glob.glob(os.path.join(data_dir, f"*.{ext}")))
    return sorted(matches)
def load_model_checkpoint(model: torch.nn.Module,
                          ckpt: str) -> torch.nn.Module:
    """
    Load model weights from a checkpoint file.

    Handles three checkpoint layouts:
      * Lightning-style checkpoints with a top-level ``state_dict`` entry
        (loaded non-strictly; missing/unexpected keys are printed).
      * Older checkpoints whose keys use ``framestride_embed`` where the
        current model expects ``fps_embedding`` (keys are renamed first).
      * DeepSpeed checkpoints with a top-level ``module`` entry whose keys
        carry a 16-character prefix that must be stripped.

    Args:
        model (torch.nn.Module): The model to load weights into.
        ckpt (str): Path to the checkpoint file.

    Returns:
        torch.nn.Module: The same model instance with weights loaded.
    """
    state_dict = torch.load(ckpt, map_location="cpu")
    if "state_dict" in list(state_dict.keys()):
        state_dict = state_dict["state_dict"]
        try:
            loaded = model.load_state_dict(state_dict, strict=False)
            print("Missing keys:")
            for k in loaded.missing_keys:
                print(f" {k}")
            print("Unexpected keys:")
            for k in loaded.unexpected_keys:
                print(f" {k}")
        # load_state_dict raises RuntimeError on shape mismatches; a bare
        # `except:` here would also swallow KeyboardInterrupt/SystemExit.
        except RuntimeError:
            # Rename the keys for 256x256 model
            new_pl_sd = OrderedDict(state_dict)
            for k in list(new_pl_sd.keys()):
                if "framestride_embed" in k:
                    new_key = k.replace("framestride_embed", "fps_embedding")
                    new_pl_sd[new_key] = new_pl_sd.pop(k)
            model.load_state_dict(new_pl_sd, strict=False)
    else:
        # DeepSpeed layout: weights live under "module" with a fixed
        # 16-character prefix on every key.
        new_pl_sd = OrderedDict()
        for key in state_dict['module'].keys():
            new_pl_sd[key[16:]] = state_dict['module'][key]
        model.load_state_dict(new_pl_sd)
    print('>>> model checkpoint loaded.')
    return model
def load_prompts(prompt_file: str) -> list[str]:
    """
    Load prompts from a text file, one per line.

    Blank lines (including whitespace-only lines) are skipped.

    Args:
        prompt_file (str): Path to the prompt file.

    Returns:
        list[str]: List of stripped, non-empty prompt strings.
    """
    # Context manager guarantees the handle is closed even if reading fails.
    with open(prompt_file, 'r') as f:
        return [stripped for line in f if (stripped := line.strip())]
def load_data_prompts(
    data_dir: str,
    savedir: str,
    video_size: tuple[int, int] = (256, 256),
    video_frames: int = 16
) -> tuple[list[str], list[torch.Tensor], list[str], list[float], list[float],
           list[int]]:
    """
    Load image prompts, expand each into a static video clip, and collect
    per-sample metadata from the prompt CSV.

    The first ``*.csv`` file found in ``data_dir`` must contain the columns
    ``videoid``, ``instruction``, ``fps``, ``fs`` and ``num_gen``; rows are
    matched to images by filename stem. Images whose result video already
    exists in ``savedir`` are skipped, as are images with no matching CSV row.

    Args:
        data_dir (str): Directory containing images and the CSV file.
        savedir (str): Output directory used to detect already-inferenced samples.
        video_size (tuple[int, int], optional): Target size of video frames.
        video_frames (int, optional): Number of frames in each video.

    Returns:
        tuple: (filenames, video tensors, prompts, fps values, fs values,
        number of generations per sample).
    """
    transform = transforms.Compose([
        transforms.Resize(min(video_size)),
        transforms.CenterCrop(video_size),
        transforms.ToTensor(),
        # Map pixel values from [0, 1] to [-1, 1].
        transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
    ])
    # Load prompt csv
    prompt_file = get_filelist(data_dir, ['csv'])
    assert len(prompt_file) > 0, "Error: found NO image prompt file!"
    # Load image prompts
    file_list = get_filelist(data_dir, ['jpg', 'png', 'jpeg', 'JPEG', 'PNG'])
    data_list = []
    filename_list = []
    prompt_list = []
    fps_list = []
    fs_list = []
    num_gen_list = []
    prompt_csv = pd.read_csv(prompt_file[0])
    n_samples = len(file_list)
    for idx in range(n_samples):
        image = Image.open(file_list[idx]).convert('RGB')
        # Add a singleton time axis: (c, h, w) -> (c, 1, h, w).
        image_tensor = transform(image).unsqueeze(1)
        # Repeat the single frame to form a static clip of `video_frames` frames.
        frame_tensor = repeat(image_tensor,
                              'c t h w -> c (repeat t) h w',
                              repeat=video_frames)
        _, filename = os.path.split(file_list[idx])
        if not is_inferenced(savedir, filename):
            # Filename stem; assumes a 3-character extension plus dot.
            video_id = filename[:-4]
            prompt_csv['videoid'] = prompt_csv['videoid'].map(str)
            if not (prompt_csv['videoid'] == video_id).any():
                # No metadata row for this image; skip it.
                continue
            data_list.append(frame_tensor)
            filename_list.append(filename)
            ins = prompt_csv[prompt_csv['videoid'] ==
                             video_id]['instruction'].values[0]
            prompt_list.append(ins)
            fps = prompt_csv[prompt_csv['videoid'] ==
                             video_id]['fps'].values[0]
            fps_list.append(fps)
            fs = prompt_csv[prompt_csv['videoid'] == video_id]['fs'].values[0]
            fs_list.append(fs)
            num_gen = prompt_csv[prompt_csv['videoid'] ==
                                 video_id]['num_gen'].values[0]
            num_gen_list.append(int(num_gen))
    return filename_list, data_list, prompt_list, fps_list, fs_list, num_gen_list
def is_inferenced(save_dir: str, filename: str) -> bool:
    """
    Report whether the result video for `filename` already exists.

    Args:
        save_dir (str): Directory where results are saved.
        filename (str): Source filename; its last four characters (the dot
            plus extension) are replaced by ".mp4".

    Returns:
        bool: True if the corresponding .mp4 exists, else False.
    """
    stem = filename[:-4]
    return os.path.exists(os.path.join(save_dir, stem + ".mp4"))
def save_results_seperate(prompt: str | list[str],
                          samples: torch.Tensor,
                          filename: str,
                          fakedir: str,
                          fps: int = 8) -> None:
    """
    Save generated video samples as .mp4 files.

    Args:
        prompt (str | list[str]): The prompt text (only the first entry is
            used when a list is given).
        samples (torch.Tensor): Generated video tensor of shape [B, C, T, H, W]
            with values in [-1, 1].
        filename (str): Output filename; everything after the first "." is
            replaced by ".mp4".
        fakedir (str): Directory to save output videos.
        fps (int, optional): Frames per second.

    Returns:
        None
    """
    prompt = prompt[0] if isinstance(prompt, list) else prompt
    # Save video
    videos = [samples]
    savedirs = [fakedir]
    for idx, video in enumerate(videos):
        if video is None:
            continue
        video = video.detach().cpu()
        video = torch.clamp(video.float(), -1., 1.)
        n = video.shape[0]
        # NOTE(review): the output path does not depend on `i`, so with a
        # batch size > 1 every sample would overwrite the same file — confirm
        # callers always use batch size 1 (the CLI asserts bs == 1).
        for i in range(n):
            grid = video[i, ...]
            # Map [-1, 1] -> [0, 255] uint8, channels last: (T, H, W, C).
            grid = (grid + 1.0) / 2.0
            grid = (grid * 255).to(torch.uint8).permute(1, 2, 3, 0)
            path = os.path.join(savedirs[idx], f'{filename.split(".")[0]}.mp4')
            torchvision.io.write_video(path,
                                       grid,
                                       fps=fps,
                                       video_codec='h264',
                                       options={'crf': '0'})
def get_latent_z(model: torch.nn.Module, videos: torch.Tensor) -> torch.Tensor:
    """
    Encode a batch of videos into the model's latent space.

    Args:
        model (torch.nn.Module): Model exposing `encode_first_stage`.
        videos (torch.Tensor): Video tensor of shape [B, C, T, H, W].

    Returns:
        torch.Tensor: Latent tensor of shape [B, C', T, H', W'].
    """
    batch, chans, n_frames, height, width = videos.shape
    # Fold time into the batch axis so the image encoder sees plain frames:
    # b c t h w -> (b t) c h w.
    frames = videos.permute(0, 2, 1, 3, 4).reshape(batch * n_frames, chans,
                                                   height, width)
    latents = model.encode_first_stage(frames)
    # Restore the [B, C, T, H, W] layout: (b t) c h w -> b c t h w.
    latents = latents.reshape(batch, n_frames,
                              *latents.shape[1:]).permute(0, 2, 1, 3, 4)
    return latents
def image_guided_synthesis(model: torch.nn.Module,
                           prompts: list[str],
                           videos: torch.Tensor,
                           noise_shape: list[int],
                           ddim_steps: int = 50,
                           ddim_eta: float = 1.0,
                           unconditional_guidance_scale: float = 1.0,
                           fs: int | None = None,
                           text_input: bool = False,
                           timestep_spacing: str = 'uniform',
                           guidance_rescale: float = 0.0,
                           **kwargs) -> torch.Tensor:
    """
    Run DDIM-based image-to-video synthesis with hybrid/text+image guidance.

    Args:
        model (torch.nn.Module): Diffusion model.
        prompts (list[str]): Text prompts (replaced by empty strings when
            `text_input` is False).
        videos (torch.Tensor): Input images/videos of shape [B, C, T, H, W];
            only the first frame is used for conditioning.
        noise_shape (list[int]): Latent noise shape [B, C, T, H, W].
        ddim_steps (int, optional): Number of DDIM steps.
        ddim_eta (float, optional): Eta value for DDIM.
        unconditional_guidance_scale (float, optional): Guidance scale.
        fs (int | None, optional): FPS/frame-stride input for the sampler.
        text_input (bool, optional): If True, use text guidance.
        timestep_spacing (str, optional): Timestep schedule spacing.
        guidance_rescale (float, optional): Rescale guidance effect.
        **kwargs: Additional sampler args.

    Returns:
        torch.Tensor: Synthesized videos of shape [B, 1, C, T, H, W].
    """
    ddim_sampler = DDIMSampler(model)
    batch_size = noise_shape[0]
    # Per-sample frame-stride/FPS conditioning tensor.
    fs = torch.tensor([fs] * batch_size, dtype=torch.long, device=model.device)
    if not text_input:
        # Without text guidance, condition on empty prompts.
        prompts = [""] * batch_size
    b, c, t, h, w = videos.shape
    # Use only the first frame as the image condition.
    img = videos[:, :, 0]
    img_emb = model.embedder(img)
    img_emb = model.image_proj_model(img_emb)
    # Tile image embeddings across time: one copy per frame.
    img_emb = rearrange(img_emb, 'b (t l) c -> (b t) l c', t=t)
    cond_emb = model.get_learned_conditioning(prompts)
    cond_emb = cond_emb.repeat_interleave(repeats=t, dim=0)
    # Cross-attention condition: concatenated text and image embeddings.
    cond = {"c_crossattn": [torch.cat([cond_emb, img_emb], dim=1)]}
    if model.model.conditioning_key == 'hybrid':
        # Hybrid conditioning also concatenates the first-frame latent
        # (repeated over time) onto the denoiser input.
        z = get_latent_z(model, videos)
        img_cat_cond = z[:, :, :1, :, :]
        img_cat_cond = repeat(img_cat_cond,
                              'b c t h w -> b c (repeat t) h w',
                              repeat=z.shape[2])
        cond["c_concat"] = [img_cat_cond]
    # No separate unconditional conditioning is built here.
    uc = None
    cond_mask = None
    kwargs.update({"unconditional_conditioning_img_nonetext": None})
    batch_variants = []
    if ddim_sampler is not None:
        samples, _, _, _ = ddim_sampler.sample(
            S=ddim_steps,
            batch_size=batch_size,
            shape=noise_shape[1:],
            conditioning=cond,
            eta=ddim_eta,
            mask=cond_mask,
            x0=None,
            verbose=False,
            unconditional_guidance_scale=unconditional_guidance_scale,
            unconditional_conditioning=uc,
            fs=fs,
            timestep_spacing=timestep_spacing,
            guidance_rescale=guidance_rescale,
            **kwargs)
        # Reconstruct from latent to pixel space
        batch_images = model.decode_first_stage(samples)
        batch_variants.append(batch_images)
    # Stack the (single) variant and move the variant axis after the batch:
    # [N, B, C, T, H, W] -> [B, N, C, T, H, W].
    batch_variants = torch.stack(batch_variants)
    return batch_variants.permute(1, 0, 2, 3, 4, 5)
def run_inference(args: argparse.Namespace, gpu_num: int, gpu_no: int) -> None:
    """
    Run inference pipeline on prompts and image inputs.

    Loads the model and prompt data, shards samples across ranks, rolls the
    model out autoregressively `num_gen` times per prompt, and writes the
    concatenated videos to `<savedir>/samples`.

    Args:
        args (argparse.Namespace): Parsed command-line arguments.
        gpu_num (int): Number of GPUs.
        gpu_no (int): Index of the current GPU.

    Returns:
        None
    """
    # Load config
    config = OmegaConf.load(args.config)
    # Set use_checkpoint as False as when using deepspeed, it encounters an error "deepspeed backend not set"
    config['model']['params']['wma_config']['params'][
        'use_checkpoint'] = False
    model = instantiate_from_config(config.model)
    model = model.cuda(gpu_no)
    model.perframe_ae = args.perframe_ae
    assert os.path.exists(args.ckpt_path), "Error: checkpoint Not Found!"
    model = load_model_checkpoint(model, args.ckpt_path)
    model.eval()
    # Run over data
    assert (args.height % 16 == 0) and (
        args.width % 16
        == 0), "Error: image size [h,w] should be multiples of 16!"
    assert args.bs == 1, "Current implementation only support [batch size = 1]!"
    # Get latent noise shape; spatial dims shrink by the VAE's factor of 8.
    h, w = args.height // 8, args.width // 8
    channels = model.model.diffusion_model.out_channels
    n_frames = args.video_length
    print(f'>>> Generate {n_frames} frames under each generation ...')
    noise_shape = [args.bs, channels, n_frames, h, w]
    fakedir = os.path.join(args.savedir, "samples")
    os.makedirs(fakedir, exist_ok=True)
    # Prompt file setting
    assert os.path.exists(args.prompt_dir), "Error: prompt file Not Found!"
    filename_list, data_list, prompt_list, fps_list, fs_list, num_gen_list = load_data_prompts(
        args.prompt_dir,
        args.savedir,
        video_size=(args.height, args.width),
        video_frames=n_frames)
    num_samples = len(prompt_list)
    # Shard samples evenly across ranks; remainder samples are dropped.
    samples_split = num_samples // gpu_num
    print('>>> Prompts testing [rank:%d] %d/%d samples loaded.' %
          (gpu_no, samples_split, num_samples))
    indices = list(range(samples_split * gpu_no, samples_split * (gpu_no + 1)))
    fps_list_rank = [fps_list[i] for i in indices]
    fs_list_rank = [fs_list[i] for i in indices]
    prompt_list_rank = [prompt_list[i] for i in indices]
    data_list_rank = [data_list[i] for i in indices]
    filename_list_rank = [filename_list[i] for i in indices]
    with torch.no_grad(), torch.cuda.amp.autocast():
        # Create a new result csv
        for idx, indice in enumerate(
                tqdm(range(0, len(prompt_list_rank), args.bs),
                     desc=f'Sample batch')):
            fps = fps_list_rank[indice:indice + args.bs]
            fs = fs_list_rank[indice:indice + args.bs]
            prompts = prompt_list_rank[indice:indice + args.bs]
            # NOTE(review): this indexes the unsharded num_gen_list while all
            # other lists are rank-sharded — only correct when gpu_num == 1
            # (the default); confirm before enabling multi-GPU.
            num_gen = num_gen_list[indice:indice + args.bs]
            videos = data_list_rank[indice:indice + args.bs]
            filenames = filename_list_rank[indice:indice + args.bs]
            if isinstance(videos, list):
                videos = torch.stack(videos, dim=0).to("cuda")
            else:
                videos = videos.unsqueeze(0).to("cuda")
            results = []
            print(
                f">>> {prompts[0]}, frame_stride:{fs[0]}, and {num_gen[0]} generation ..."
            )
            # Autoregressive rollout: each generation is seeded with the last
            # frame of the previous one, repeated across all timesteps.
            for _ in range(num_gen[0]):
                batch_samples = image_guided_synthesis(
                    model, prompts, videos, noise_shape, args.ddim_steps,
                    args.ddim_eta, args.unconditional_guidance_scale,
                    fps[0] // fs[0], args.text_input, args.timestep_spacing,
                    args.guidance_rescale)
                results.extend(batch_samples)
                videos = repeat(batch_samples[0][:, :, -1, :, :].unsqueeze(2),
                                'b c t h w -> b c (repeat t) h w',
                                repeat=batch_samples[0].shape[2])
            # Concatenate all generations along the time axis.
            batch_samples = [torch.concat(results, axis=2)]
            # Save each example individually
            for nn, samples in enumerate(batch_samples):
                prompt = prompts[nn]
                filename = filenames[nn]
                save_results_seperate(prompt,
                                      samples,
                                      filename,
                                      fakedir,
                                      fps=8)
def get_parser() -> argparse.ArgumentParser:
    """Build the command-line argument parser for base-model inference.

    Returns:
        argparse.ArgumentParser: Parser for command-line arguments.
    """
    parser = argparse.ArgumentParser()
    add = parser.add_argument
    add("--savedir", type=str, default=None, help="Path to save the results.")
    add("--ckpt_path", type=str, default=None,
        help="Path to the model checkpoint.")
    add("--config", type=str, help="Path to the YAML configuration file.")
    add("--prompt_dir", type=str, default=None,
        help="Directory containing videos and corresponding prompts.")
    add("--ddim_steps", type=int, default=50,
        help="Number of DDIM steps. If non-positive, DDPM is used instead.")
    add("--ddim_eta", type=float, default=1.0,
        help="Eta for DDIM sampling. Set to 0.0 for deterministic results.")
    add("--bs", type=int, default=1,
        help="Batch size for inference. Must be 1.")
    add("--height", type=int, default=320,
        help="Height of the generated images in pixels.")
    add("--width", type=int, default=512,
        help="Width of the generated images in pixels.")
    add("--unconditional_guidance_scale", type=float, default=1.0,
        help="Scale for classifier-free guidance during sampling.")
    add("--seed", type=int, default=123,
        help="Random seed for reproducibility.")
    add("--video_length", type=int, default=16,
        help="Number of frames in the generated video.")
    add("--text_input", action='store_true', default=False,
        help="Whether to provide a text prompt as input to the image-to-video model.")
    add("--timestep_spacing", type=str, default="uniform",
        help="Strategy for timestep scaling. See Table 2 in the paper: 'Common Diffusion Noise Schedules and Sample Steps are Flawed' (https://huggingface.co/papers/2305.08891).")
    add("--guidance_rescale", type=float, default=0.0,
        help="Rescale factor for guidance as discussed in 'Common Diffusion Noise Schedules and Sample Steps are Flawed' (https://huggingface.co/papers/2305.08891).")
    add("--perframe_ae", action='store_true', default=False,
        help="Use per-frame autoencoder decoding to reduce GPU memory usage. Recommended for models with resolutions like 576x1024.")
    return parser
if __name__ == '__main__':
    parser = get_parser()
    args = parser.parse_args()
    seed = args.seed
    # A negative seed means "pick one at random".
    if seed < 0:
        seed = random.randint(0, 2**31)
    # seed_everything seeds python, numpy and torch for reproducibility.
    seed_everything(seed)
    # Single-process inference: rank 0 of 1 GPU.
    rank, gpu_num = 0, 1
    run_inference(args, gpu_num, rank)

View File

@@ -0,0 +1,77 @@
import torch
import warnings
import torchvision
import sys
import pyarrow as pa
import logging
from dataclasses import dataclass, field
from typing import Dict, Any, ClassVar, Deque, Mapping, Union
from datasets.features.features import register_feature
from torch.utils.tensorboard.writer import SummaryWriter
# Route all log records (DEBUG and up) to stdout.
logging.basicConfig(stream=sys.stdout, level=logging.DEBUG)
@dataclass
class VideoFrame:
    """
    Provides a type for a dataset containing video frames.

    Example:
        ```python
        data_dict = [{"image": {"path": "videos/episode_0.mp4", "timestamp": 0.3}}]
        features = {"image": VideoFrame()}
        Dataset.from_dict(data_dict, features=Features(features))
        ```
    """

    # Arrow storage type: the source video path plus the frame's timestamp
    # within that video.
    pa_type: ClassVar[Any] = pa.struct({
        "path": pa.string(),
        "timestamp": pa.float32()
    })
    # Feature-type tag used by `datasets` for (de)serialization; not part of
    # the constructor or repr.
    _type: str = field(default="VideoFrame", init=False, repr=False)

    def __call__(self):
        # `datasets` calls the feature instance to obtain its Arrow type.
        return self.pa_type
# Register VideoFrame with the `datasets` library so it can be used as a
# feature type. The registration API emits an "experimental" UserWarning,
# which is deliberately silenced for this call only.
with warnings.catch_warnings():
    warnings.filterwarnings(
        "ignore",
        "'register_feature' is experimental and might be subject to breaking changes in the future.",
        category=UserWarning,
    )
    register_feature(VideoFrame, "VideoFrame")
def populate_queues(
        queues: Dict[str, Deque[Any]],
        batch: Mapping[str, Any]) -> Dict[str, Deque[Any]]:
    """
    Append each batch value onto its matching queue.

    A bounded deque that is not yet full is padded with copies of the value
    until it reaches `maxlen`, so consumers always see a full window;
    afterwards a single append suffices (the deque evicts its oldest entry
    by itself). Keys in `batch` with no corresponding queue are ignored.

    Args:
        queues: Mapping from key to deque, updated in place.
        batch: New values, one per key.

    Returns:
        The same `queues` mapping.
    """
    for key in batch:
        if key not in queues:
            continue
        queue = queues[key]
        # Guard against maxlen=None: `len(queue) != None` is always true, so
        # the padding loop below would never terminate for unbounded deques.
        if queue.maxlen is not None and len(queue) != queue.maxlen:
            # Initial fill: pad the bounded deque with copies until full.
            while len(queue) != queue.maxlen:
                queue.append(batch[key])
        else:
            queue.append(batch[key])
    return queues
def log_to_tensorboard(
        writer: SummaryWriter,
        data: Union[torch.Tensor, Any],
        tag: str,
        fps: int = 10) -> None:
    """
    Log a batch of videos to TensorBoard under `tag`.

    Only acts when `data` is a 5-D tensor; the [-1, 1] -> [0, 1] rescale
    below implies values are expected in [-1, 1]. Anything else is ignored.

    Args:
        writer (SummaryWriter): TensorBoard writer.
        data (Union[torch.Tensor, Any]): Candidate video tensor,
            presumably laid out as (B, C, T, H, W) — TODO confirm.
        tag (str): TensorBoard tag.
        fps (int, optional): Playback frame rate. Defaults to 10.
    """
    if isinstance(data, torch.Tensor) and data.dim() == 5:
        video = data
        n = video.shape[0]
        # Time-major: each element now holds one timestep across the batch.
        video = video.permute(2, 0, 1, 3, 4)
        # One grid image per timestep, tiling the batch horizontally.
        frame_grids = [torchvision.utils.make_grid(framesheet, nrow=int(n), padding=0) for framesheet in video]
        grid = torch.stack(frame_grids, dim=0)
        # Map [-1, 1] -> [0, 1] as add_video expects float input in [0, 1].
        grid = (grid + 1.0) / 2.0
        # add_video expects a leading batch axis: (1, T, C, H, W).
        grid = grid.unsqueeze(dim=0)
        writer.add_video(tag, grid, fps=fps)

View File

@@ -0,0 +1,463 @@
import argparse, os, sys
import torch
import torchvision
import warnings
import imageio
import logging
import matplotlib.pyplot as plt
# Use the non-interactive Agg backend so plotting works on headless servers.
plt.switch_backend('agg')
import traceback
import uvicorn
from omegaconf import OmegaConf
from einops import rearrange, repeat
from collections import OrderedDict
from pytorch_lightning import seed_everything
from torch import nn
from fastapi import FastAPI
from fastapi.responses import JSONResponse
from typing import Any, Dict, Optional, Tuple, List
from datetime import datetime
from unifolm_wma.utils.utils import instantiate_from_config
from unifolm_wma.models.samplers.ddim import DDIMSampler
def get_device_from_parameters(module: nn.Module) -> torch.device:
    """Return the device of `module` by inspecting its first parameter.

    Args:
        module (nn.Module): PyTorch module; must have at least one parameter.

    Returns:
        torch.device: The device where the module's parameters are stored.
    """
    first_param = next(module.parameters())
    return first_param.device
def load_model_checkpoint(model: nn.Module, ckpt: str) -> nn.Module:
    """Load model weights from checkpoint file.

    Handles three checkpoint layouts: Lightning-style checkpoints with a
    top-level ``state_dict`` entry; older checkpoints using the legacy
    ``framestride_embed`` key names (renamed to ``fps_embedding``); and
    DeepSpeed checkpoints with a top-level ``module`` entry whose keys carry
    a 16-character prefix that is stripped.

    Args:
        model (nn.Module): Model to load weights into.
        ckpt (str): Path to checkpoint file.

    Returns:
        nn.Module: Model with loaded weights.
    """
    state_dict = torch.load(ckpt, map_location="cpu")
    if "state_dict" in list(state_dict.keys()):
        state_dict = state_dict["state_dict"]
        try:
            model.load_state_dict(state_dict, strict=False)
        # load_state_dict raises RuntimeError on shape mismatches; a bare
        # `except:` here would also swallow KeyboardInterrupt/SystemExit.
        except RuntimeError:
            # Rename legacy keys before retrying the load.
            new_pl_sd = OrderedDict(state_dict)
            for k in list(new_pl_sd.keys()):
                if "framestride_embed" in k:
                    new_key = k.replace("framestride_embed", "fps_embedding")
                    new_pl_sd[new_key] = new_pl_sd.pop(k)
            model.load_state_dict(new_pl_sd, strict=False)
    else:
        # DeepSpeed layout: weights live under "module" with a fixed
        # 16-character prefix on every key.
        new_pl_sd = OrderedDict()
        for key in state_dict['module'].keys():
            new_pl_sd[key[16:]] = state_dict['module'][key]
        model.load_state_dict(new_pl_sd)
    print('>>> model checkpoint loaded.')
    return model
def write_video(video_path: str, stacked_frames: List[Any], fps: int) -> None:
    """Write frames to `video_path` via imageio.

    Args:
        video_path (str): Path to save the video.
        stacked_frames (List[Any]): Frames to write.
        fps (int): Frames per second.
    """
    # imageio's ffmpeg backend triggers a pkg_resources DeprecationWarning;
    # suppress it for this call only.
    with warnings.catch_warnings():
        warnings.filterwarnings(
            action="ignore",
            message="pkg_resources is deprecated as an API",
            category=DeprecationWarning,
        )
        imageio.mimsave(video_path, stacked_frames, fps=fps)
def save_results(video: torch.Tensor, filename: str, fps: int = 8) -> None:
    """Save a video tensor as an MP4 file.

    Args:
        video (torch.Tensor): Video tensor of shape (B, C, T, H, W), values
            clamped into [-1, 1] before encoding.
        filename (str): Path to save video.
        fps (int, optional): Frame rate. Defaults to 8.
    """
    clip = torch.clamp(video.detach().cpu().float(), -1., 1.)
    batch = clip.shape[0]
    # Time-major: (T, B, C, H, W) so iterating yields one timestep at a time.
    clip = clip.permute(2, 0, 1, 3, 4)
    # One grid image per timestep, tiling the batch horizontally.
    grids = [
        torchvision.utils.make_grid(step, nrow=int(batch), padding=0)
        for step in clip
    ]
    stacked = torch.stack(grids, dim=0)
    # Map [-1, 1] -> [0, 255] uint8, channels last as write_video expects.
    stacked = (stacked + 1.0) / 2.0
    stacked = (stacked * 255).to(torch.uint8).permute(0, 2, 3, 1)
    torchvision.io.write_video(filename,
                               stacked,
                               fps=fps,
                               video_codec='h264',
                               options={'crf': '10'})
def get_latent_z(model: nn.Module, videos: torch.Tensor) -> torch.Tensor:
    """Encode videos into latent space.

    Args:
        model (nn.Module): Model with `encode_first_stage` method.
        videos (torch.Tensor): Input videos (B, C, T, H, W).

    Returns:
        torch.Tensor: Latent representation (B, C', T, H', W').
    """
    batch, chans, n_frames, height, width = videos.shape
    # Fold time into the batch axis: b c t h w -> (b t) c h w.
    frames = videos.permute(0, 2, 1, 3, 4).reshape(batch * n_frames, chans,
                                                   height, width)
    latents = model.encode_first_stage(frames)
    # Restore the video layout: (b t) c h w -> b c t h w.
    return latents.reshape(batch, n_frames,
                           *latents.shape[1:]).permute(0, 2, 1, 3, 4)
def image_guided_synthesis(
        model: torch.nn.Module,
        prompts: list[str],
        observation: Dict[str, torch.Tensor],
        noise_shape: tuple[int, int, int, int, int],
        ddim_steps: int = 50,
        ddim_eta: float = 1.0,
        unconditional_guidance_scale: float = 1.0,
        fs: int | None = None,
        timestep_spacing: str = 'uniform',
        guidance_rescale: float = 0.0,
        **kwargs) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Run inference with DDIM sampling, conditioned on the latest observation.

    Args:
        model (nn.Module): Diffusion model.
        prompts (Any): Conditioning text prompts.
        observation (Dict[str, torch.Tensor]): Observation dictionary with
            keys 'observation.images.top', 'observation.state' and 'action'.
        noise_shape (List[int]): Shape of noise tensor [B, C, T, H, W].
        ddim_steps (int, optional): Number of DDIM steps. Defaults to 50.
        ddim_eta (float, optional): Sampling eta. Defaults to 1.0.
        unconditional_guidance_scale (float, optional): Guidance scale. Defaults to 1.0.
        fs (Optional[int], optional): Frame stride or FPS. Defaults to None.
        timestep_spacing (str, optional): Spacing strategy. Defaults to "uniform".
        guidance_rescale (float, optional): Guidance rescale. Defaults to 0.0.
        **kwargs (Any): Additional sampler arguments.

    Returns:
        Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: Decoded videos,
        predicted actions and predicted states.
    """
    b, _, t, _, _ = noise_shape
    ddim_sampler = DDIMSampler(model)
    batch_size = noise_shape[0]
    # Per-sample frame-stride/FPS conditioning tensor.
    fs = torch.tensor([fs] * batch_size, dtype=torch.long, device=model.device)
    img = observation['observation.images.top']
    # Condition on the most recent camera frame (indexing and the permutes
    # below assume layout (B, T, C, H, W) — TODO confirm against the dataset).
    cond_img = img[:, -1, ...]
    cond_img_emb = model.embedder(cond_img)
    cond_img_emb = model.image_proj_model(cond_img_emb)
    if model.model.conditioning_key == 'hybrid':
        z = get_latent_z(model, img.permute(0, 2, 1, 3, 4))
        # Repeat the latest-frame latent over all generated timesteps.
        img_cat_cond = z[:, :, -1:, :, :]
        img_cat_cond = repeat(img_cat_cond,
                              'b c t h w -> b c (repeat t) h w',
                              repeat=noise_shape[2])
        # NOTE(review): `cond` only exists when conditioning_key == 'hybrid';
        # any other configuration would hit a NameError below — confirm
        # non-hybrid configs never reach this function.
        cond = {"c_concat": [img_cat_cond]}
    cond_ins_emb = model.get_learned_conditioning(prompts)
    cond_state = model.state_projector(observation['observation.state'])
    cond_state_emb = model.agent_state_pos_emb + cond_state
    cond_action = model.action_projector(observation['action'])
    cond_action_emb = model.agent_action_pos_emb + cond_action
    # Zeroed out, so the supplied actions do not leak into conditioning.
    # NOTE(review): `cond_action_emb` is never used after this line — confirm
    # that is intentional.
    cond_action_emb = torch.zeros_like(cond_action_emb)
    cond["c_crossattn"] = [
        torch.cat([cond_state_emb, cond_ins_emb, cond_img_emb], dim=1)
    ]
    # Raw recent observations (last n_obs_steps_acting steps) handed to the
    # action head's cross-attention.
    cond["c_crossattn_action"] = [
        observation['observation.images.top'].permute(
            0, 2, 1, 3, 4)[:, :, -model.n_obs_steps_acting:],
        observation['observation.state'][:, -model.n_obs_steps_acting:]
    ]
    # No separate unconditional conditioning is built here.
    uc = None
    kwargs.update({"unconditional_conditioning_img_nonetext": None})
    cond_mask = None
    cond_z0 = None
    if ddim_sampler is not None:
        samples, actions, states, intermedia = ddim_sampler.sample(
            S=ddim_steps,
            conditioning=cond,
            batch_size=batch_size,
            shape=noise_shape[1:],
            verbose=False,
            unconditional_guidance_scale=unconditional_guidance_scale,
            unconditional_conditioning=uc,
            eta=ddim_eta,
            cfg_img=None,
            mask=cond_mask,
            x0=cond_z0,
            fs=fs,
            timestep_spacing=timestep_spacing,
            guidance_rescale=guidance_rescale,
            **kwargs)
        # Reconstruct from latent to pixel space
        batch_images = model.decode_first_stage(samples)
    batch_variants = batch_images
    return batch_variants, actions, states
def run_inference(args: argparse.Namespace, gpu_num: int,
                  gpu_no: int) -> Tuple[nn.Module, List[int], Any]:
    """
    Build the model, latent noise shape and dataset pipeline for serving.

    Args:
        args (argparse.Namespace): Parsed command-line arguments.
        gpu_num (int): Number of GPUs (unused here; single-GPU setup).
        gpu_no (int): Index of the current GPU.

    Returns:
        Tuple[nn.Module, List[int], Any]: The loaded model, the latent noise
        shape [bs, channels, frames, h, w], and the configured data module.
    """
    # Load config
    config = OmegaConf.load(args.config)
    # Set use_checkpoint as False as when using deepspeed, it encounters an error "deepspeed backend not set"
    config['model']['params']['wma_config']['params']['use_checkpoint'] = False
    model = instantiate_from_config(config.model)
    model.perframe_ae = args.perframe_ae
    assert os.path.exists(args.ckpt_path), "Error: checkpoint Not Found!"
    model = load_model_checkpoint(model, args.ckpt_path)
    model = model.cuda(gpu_no)
    model.eval()
    print(">>> Model is successfully loaded ...")
    # Build the data module (provides normalizer/unnormalizer and transforms).
    logging.info("***** Configing Data *****")
    data = instantiate_from_config(config.data)
    data.setup()
    print(">>> Dataset is successfully loaded ...")
    ## Run over data
    assert (args.height % 16 == 0) and (
        args.width % 16
        == 0), "Error: image size [h,w] should be multiples of 16!"
    assert args.bs == 1, "Current implementation only support [batch size = 1]!"
    ## Get latent noise shape; spatial dims shrink by the VAE's factor of 8.
    h, w = args.height // 8, args.width // 8
    channels = model.model.diffusion_model.out_channels
    n_frames = args.video_length
    print(f'>>> Generate {n_frames} frames under each generation ...')
    noise_shape = [args.bs, channels, n_frames, h, w]
    return model, noise_shape, data
def get_parser() -> argparse.ArgumentParser:
    """Build the command-line argument parser for the evaluation server.

    Returns:
        argparse.ArgumentParser: Parser for command-line arguments.
    """
    parser = argparse.ArgumentParser()
    add = parser.add_argument
    add("--savedir", type=str, default=None, help="Path to save the results.")
    add("--ckpt_path", type=str, default=None,
        help="Path to the model checkpoint.")
    add("--config", type=str, help="Path to the config file.")
    add("--ddim_steps", type=int, default=50,
        help="Number of DDIM steps. If non-positive, DDPM is used instead.")
    add("--ddim_eta", type=float, default=1.0,
        help="Eta for DDIM sampling. Set to 0.0 for deterministic results.")
    add("--bs", type=int, default=1,
        help="Batch size for inference. Must be 1.")
    add("--height", type=int, default=320,
        help="Height of the generated images in pixels.")
    add("--width", type=int, default=512,
        help="Width of the generated images in pixels.")
    add("--frame_stride", type=int, default=3,
        help="frame stride control for 256 model (larger->larger motion), FPS control for 512 or 1024 model (smaller->larger motion)")
    add("--unconditional_guidance_scale", type=float, default=1.0,
        help="Scale for classifier-free guidance during sampling.")
    add("--seed", type=int, default=123,
        help="Random seed for reproducibility.")
    add("--video_length", type=int, default=16,
        help="Number of frames in the generated video.")
    add("--timestep_spacing", type=str, default="uniform",
        help="Strategy for timestep scaling. See Table 2 in the paper: 'Common Diffusion Noise Schedules and Sample Steps are Flawed' (https://huggingface.co/papers/2305.08891).")
    add("--guidance_rescale", type=float, default=0.0,
        help="Rescale factor for guidance as discussed in 'Common Diffusion Noise Schedules and Sample Steps are Flawed' (https://huggingface.co/papers/2305.08891).")
    add("--perframe_ae", action='store_true', default=False,
        help="Use per-frame autoencoder decoding to reduce GPU memory usage. Recommended for models with resolutions like 576x1024.")
    return parser
class Server:
    """FastAPI wrapper serving world-model action prediction over HTTP."""

    def __init__(self, args: argparse.Namespace) -> None:
        """Load the model, noise shape and dataset pipeline once at startup."""
        # Single-GPU setup: rank 0 of 1.
        self.model_, self.noise_shape_, self.data_ = run_inference(args, 1, 0)
        self.args_ = args
        self.dataset_name = self.data_.dataset_configs['test']['params'][
            'dataset_name']
        self.device_ = get_device_from_parameters(self.model_)

    def normalize_image(self, image: torch.Tensor) -> torch.Tensor:
        """Map pixel values from [0, 255] to [-1, 1]."""
        return (image / 255 - 0.5) * 2

    def predict_action(self, payload: Dict[str, Any]) -> Any:
        """Handle one /predict_action request.

        Expected `payload` keys: 'observation.images.top',
        'observation.state', 'action' (all-zero placeholder) and
        'language_instruction'. On success returns a JSONResponse with the
        de-normalized predicted actions and writes the predicted video under
        `args.savedir`; on failure returns a dict describing the error.
        """
        try:
            images = payload['observation.images.top']
            states = payload['observation.state']
            actions = payload['action']  # Should be all zeros
            language_instruction = payload['language_instruction']
            # Preprocess images: spatial transform, add batch dim, scale to [-1, 1].
            images = torch.tensor(images).cuda()
            images = self.data_.test_datasets[
                self.dataset_name].spatial_transform(images).unsqueeze(0)
            images = self.normalize_image(images)
            print(f"images shape: {images.shape} ...")
            # Normalize states and map them into the unified state space.
            states = torch.tensor(states)
            states = self.data_.test_datasets[self.dataset_name].normalizer(
                {'observation.state': states})['observation.state']
            states, _ = self.data_.test_datasets[
                self.dataset_name]._map_to_uni_state(states, "joint position")
            print(f"states shape: {states.shape} ...")
            # Map placeholder actions; the mask selects the real action dims.
            actions = torch.tensor(actions)
            actions, action_mask = self.data_.test_datasets[
                self.dataset_name]._map_to_uni_action(actions,
                                                      "joint position")
            print(f"actions shape: {actions.shape} ...")
            print("=" * 20)
            states = states.unsqueeze(0).cuda()
            actions = actions.unsqueeze(0).cuda()
            observation = {
                'observation.images.top': images,
                'observation.state': states,
                'action': actions
            }
            observation = {
                key: observation[key].to(self.device_, non_blocking=True)
                for key in observation
            }
            args = self.args_
            pred_videos, pred_action, _ = image_guided_synthesis(
                self.model_,
                language_instruction,
                observation,
                self.noise_shape_,
                ddim_steps=args.ddim_steps,
                # BUGFIX: was misspelled `ddim_ets`, which fell into **kwargs
                # so the sampler silently ran with the default eta of 1.0.
                ddim_eta=args.ddim_eta,
                unconditional_guidance_scale=args.unconditional_guidance_scale,
                fs=30 / args.frame_stride,
                timestep_spacing=args.timestep_spacing,
                guidance_rescale=args.guidance_rescale)
            # Keep only the real action dims, then undo normalization.
            pred_action = pred_action[..., action_mask[0] == 1.0][0].cpu()
            pred_action = self.data_.test_datasets[
                self.dataset_name].unnormalizer({'action':
                                                 pred_action})['action']
            os.makedirs(args.savedir, exist_ok=True)
            current_time = datetime.now().strftime("%H:%M:%S")
            video_file = f'{args.savedir}/{current_time}.mp4'
            save_results(pred_videos.cpu(), video_file)
            response = {
                'result': 'ok',
                'action': pred_action.tolist(),
                'desc': 'success'
            }
            return JSONResponse(response)
        # Request-boundary catch-all: log and report instead of crashing the
        # server. Narrowed from a bare `except:` so KeyboardInterrupt and
        # SystemExit still propagate.
        except Exception:
            logging.error(traceback.format_exc())
            logging.warning(
                "Your request threw an error; make sure your request complies with the expected format:\n"
                "{'image': np.ndarray, 'instruction': str}\n"
                "You can optionally an `unnorm_key: str` to specific the dataset statistics you want to use for "
                "de-normalizing the output actions.")
            return {'result': 'error', 'desc': traceback.format_exc()}

    def run(self, host: str = "127.0.0.1", port: int = 8000) -> None:
        """Create the FastAPI app, register the route, and serve (blocking)."""
        self.app = FastAPI()
        self.app.post("/predict_action")(self.predict_action)
        print(">>> Inference server is ready ... ")
        uvicorn.run(self.app, host=host, port=port)
        print(">>> Inference server stops ... ")
        return
if __name__ == '__main__':
    parser = get_parser()
    args = parser.parse_args()
    seed = args.seed
    # seed_everything seeds python, numpy and torch for reproducibility.
    seed_everything(seed)
    # Single-process serving: rank 0 of 1 GPU.
    rank, gpu_num = 0, 1
    print(">>> Launch inference server ... ")
    server = Server(args)
    server.run()

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,23 @@
#!/bin/bash
# Run base-model (image-to-video) inference on the example prompts.
# Replace the placeholder checkpoint/result paths below before running.
model_name=base_model
ckpt=/path/to/base/model
config=configs/inference/base_model_inference.yaml
res_dir="/path/to/result/directory"
seed=123
# Single-GPU run; flags mirror scripts/evaluation/base_model_inference.py.
CUDA_VISIBLE_DEVICES=0 python3 scripts/evaluation/base_model_inference.py \
--seed ${seed} \
--ckpt_path $ckpt \
--config $config \
--savedir "${res_dir}/videos" \
--bs 1 --height 320 --width 512 \
--unconditional_guidance_scale 1.0 \
--ddim_steps 16 \
--ddim_eta 1.0 \
--prompt_dir "/path/to/examples/base_model_prompts" \
--text_input \
--video_length 16 \
--timestep_spacing 'uniform_trailing' \
--guidance_rescale 0.7 \
--perframe_ae

View File

@@ -0,0 +1,26 @@
# Launch the decision-making evaluation server once per dataset.
# Replace the placeholder checkpoint/result paths below before running.
model_name=testing
ckpt=/path/to/model/checkpoint
config=configs/inference/world_model_decision_making.yaml
seed=123
res_dir="path/to/results/directory"
# Datasets to serve; add more entries as needed.
datasets=(
"unitree_g1_pack_camera"
)
for dataset in "${datasets[@]}"; do
CUDA_VISIBLE_DEVICES=0 python3 scripts/evaluation/real_eval_server.py \
--seed ${seed} \
--ckpt_path $ckpt \
--config $config \
--savedir "${res_dir}/${dataset}/${model_name}/videos" \
--bs 1 --height 320 --width 512 \
--unconditional_guidance_scale 1.0 \
--ddim_steps 16 \
--ddim_eta 1.0 \
--video_length 16 \
--frame_stride 2 \
--timestep_spacing 'uniform_trailing' \
--guidance_rescale 0.7 \
--perframe_ae
done

View File

@@ -0,0 +1,42 @@
# Run world-model interaction evaluation over several datasets.
# Replace the /path/to/... placeholders before running.
model_name=testing
ckpt=/path/to/model/checkpoint
config=configs/inference/world_model_interaction.yaml
seed=123
res_dir="/path/to/result/directory"

datasets=(
"unitree_z1_stackbox"
"unitree_z1_dual_arm_stackbox"
"unitree_z1_dual_arm_stackbox_v2"
"unitree_z1_dual_arm_cleanup_pencils"
"unitree_g1_pack_camera"
)
# Parallel arrays aligned with `datasets` by index:
# n_iters[i] = number of rollout iterations, fses[i] = frame stride.
n_iters=(12 7 11 8 11)
fses=(4 4 4 4 6)

for i in "${!datasets[@]}"; do
dataset=${datasets[$i]}
n_iter=${n_iters[$i]}
fs=${fses[$i]}
# 320x512 frames, 16-frame clips, 50 DDIM steps; executes 16 action
# steps per iteration with per-frame VAE decoding.
CUDA_VISIBLE_DEVICES=0 python3 scripts/evaluation/world_model_interaction.py \
--seed ${seed} \
--ckpt_path $ckpt \
--config $config \
--savedir "${res_dir}/${model_name}/${dataset}" \
--bs 1 --height 320 --width 512 \
--unconditional_guidance_scale 1.0 \
--ddim_steps 50 \
--ddim_eta 1.0 \
--prompt_dir "/path/to/unifolm-world-model-action/examples/world_model_interaction_prompts" \
--dataset ${dataset} \
--video_length 16 \
--frame_stride ${fs} \
--n_action_steps 16 \
--exe_steps 16 \
--n_iter ${n_iter} \
--timestep_spacing 'uniform_trailing' \
--guidance_rescale 0.7 \
--perframe_ae
done

32
scripts/train.sh Normal file
View File

@@ -0,0 +1,32 @@
# Multi-GPU (8x, single node) training launcher.
# NCCL configuration (uncomment and adjust for multi-node runs)
# export NCCL_DEBUG=debug
# export NCCL_IB_DISABLE=0
# export NCCL_IB_GID_INDEX=3
# export NCCL_NET_GDR_LEVEL=3
# export CUDA_LAUNCH_BLOCKING=1
# export NCCL_TOPO_FILE=/tmp/topo.txt
# export MASTER_ADDR="master.ip."
# export MASTER_PORT=12366

# args
name="experiment_name"
config_file=configs/train/config.yaml

# save root dir for logs, checkpoints, tensorboard record, etc.
save_root="/path/to/savedir"
mkdir -p $save_root/$name

## run
# NOTE(review): torch.distributed.launch is deprecated in recent PyTorch
# releases in favor of torchrun -- confirm against the pinned torch version.
CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 python3 -m torch.distributed.launch \
--nproc_per_node=8 --nnodes=1 --master_addr=127.0.0.1 --master_port=12366 --node_rank=0 \
./scripts/trainer.py \
--base $config_file \
--train \
--name $name \
--logdir $save_root \
--devices 8 \
--total_gpus=8 \
lightning.trainer.num_nodes=1

214
scripts/trainer.py Normal file
View File

@@ -0,0 +1,214 @@
import argparse, os, datetime
import pytorch_lightning as pl
import torch
from omegaconf import OmegaConf
from transformers import logging as transf_logging
from pytorch_lightning import seed_everything
from pytorch_lightning.trainer import Trainer
from unifolm_wma.utils.utils import instantiate_from_config
from unifolm_wma.utils.train import get_trainer_callbacks, get_trainer_logger, get_trainer_strategy
from unifolm_wma.utils.train import set_logger, init_workspace, load_checkpoints, get_num_parameters
def get_parser(**parser_kwargs):
    """Build the command-line argument parser for the training entry point.

    Args:
        **parser_kwargs: Extra keyword arguments forwarded to
            ``argparse.ArgumentParser`` (e.g. ``description``).

    Returns:
        argparse.ArgumentParser: Parser exposing seed, experiment name,
        config paths, run-mode flags, log directory, resume flags and a
        debug switch.
    """
    parser = argparse.ArgumentParser(**parser_kwargs)
    parser.add_argument("--seed",
                        "-s",
                        type=int,
                        default=20250912,
                        help="seed for seed_everything")
    parser.add_argument("--name",
                        "-n",
                        type=str,
                        default="",
                        help="experiment name, as saving folder")
    parser.add_argument(
        "--base",
        "-b",
        nargs="*",
        metavar="base_config.yaml",
        help="paths to base configs. Loaded from left-to-right.",
        default=list())
    parser.add_argument("--train",
                        "-t",
                        action='store_true',
                        default=False,
                        help='train')
    parser.add_argument("--val",
                        "-v",
                        action='store_true',
                        default=False,
                        help='val')
    parser.add_argument("--test",
                        action='store_true',
                        default=False,
                        help='test')
    # Fixed the unprofessional wording of the original help string.
    parser.add_argument("--logdir",
                        "-l",
                        type=str,
                        default="logs",
                        help="directory for logs, checkpoints and records")
    parser.add_argument("--auto_resume",
                        action='store_true',
                        default=False,
                        help="resume from full-info checkpoint")
    parser.add_argument("--auto_resume_weight_only",
                        action='store_true',
                        default=False,
                        help="resume from weight-only checkpoint")
    parser.add_argument("--debug",
                        "-d",
                        action='store_true',
                        default=False,
                        help="enable post-mortem debugging")
    return parser
def get_nondefault_trainer_args(args):
    """Return the names of Trainer arguments whose values differ from defaults.

    A fresh parser populated only with the default ``Trainer`` attributes is
    parsed with no CLI input, and each attribute is compared against the
    corresponding value on ``args``.

    Args:
        args: Parsed namespace that includes the Trainer attributes.

    Returns:
        list[str]: Sorted attribute names that were overridden.
    """
    defaults_parser = Trainer.add_argparse_args(argparse.ArgumentParser())
    defaults = vars(defaults_parser.parse_args([]))
    overridden = [name for name, default in defaults.items()
                  if getattr(args, name) != default]
    return sorted(overridden)
if __name__ == "__main__":
now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S")
local_rank = int(os.environ.get('LOCAL_RANK'))
global_rank = int(os.environ.get('RANK'))
num_rank = int(os.environ.get('WORLD_SIZE'))
parser = get_parser()
# Extends existing argparse by default Trainer attributes
parser = Trainer.add_argparse_args(parser)
args, unknown = parser.parse_known_args()
transf_logging.set_verbosity_error()
seed_everything(args.seed)
configs = [OmegaConf.load(cfg) for cfg in args.base]
cli = OmegaConf.from_dotlist(unknown)
config = OmegaConf.merge(*configs, cli)
lightning_config = config.pop("lightning", OmegaConf.create())
trainer_config = lightning_config.get("trainer", OmegaConf.create())
# Setup workspace directories
workdir, ckptdir, cfgdir, loginfo = init_workspace(args.name, args.logdir,
config,
lightning_config,
global_rank)
logger = set_logger(
logfile=os.path.join(loginfo, 'log_%d:%s.txt' % (global_rank, now)))
logger.info("@lightning version: %s [>=1.8 required]" % (pl.__version__))
logger.info("***** Configing Model *****")
config.model.params.logdir = workdir
model = instantiate_from_config(config.model)
# Load checkpoints
model = load_checkpoints(model, config.model)
# Register_schedule again to make ZTSNR work
if model.rescale_betas_zero_snr:
model.register_schedule(given_betas=model.given_betas,
beta_schedule=model.beta_schedule,
timesteps=model.timesteps,
linear_start=model.linear_start,
linear_end=model.linear_end,
cosine_s=model.cosine_s)
# Update trainer config
for k in get_nondefault_trainer_args(args):
trainer_config[k] = getattr(args, k)
num_nodes = trainer_config.num_nodes
ngpu_per_node = trainer_config.devices
logger.info(f"Running on {num_rank}={num_nodes}x{ngpu_per_node} GPUs")
# Setup learning rate
base_lr = config.model.base_learning_rate
bs = config.data.params.batch_size
if getattr(config.model, 'scale_lr', True):
model.learning_rate = num_rank * bs * base_lr
else:
model.learning_rate = base_lr
logger.info("***** Configing Data *****")
data = instantiate_from_config(config.data)
data.setup()
for k in data.train_datasets:
logger.info(
f"{k}, {data.train_datasets[k].__class__.__name__}, {len(data.train_datasets[k])}"
)
if hasattr(data, 'val_datasets'):
for k in data.val_datasets:
logger.info(
f"{k}, {data.val_datasets[k].__class__.__name__}, {len(data.val_datasets[k])}"
)
for item in unknown:
if item.startswith('--total_gpus'):
num_gpus = int(item.split('=')[-1])
break
model.datasets_len = len(data)
logger.info("***** Configing Trainer *****")
if "accelerator" not in trainer_config:
trainer_config["accelerator"] = "gpu"
# Setup trainer args: pl-logger and callbacks
trainer_kwargs = dict()
trainer_kwargs["num_sanity_val_steps"] = 0
logger_cfg = get_trainer_logger(lightning_config, workdir, args.debug)
trainer_kwargs["logger"] = instantiate_from_config(logger_cfg)
# Setup callbacks
callbacks_cfg = get_trainer_callbacks(lightning_config, config, workdir,
ckptdir, logger)
trainer_kwargs["callbacks"] = [
instantiate_from_config(callbacks_cfg[k]) for k in callbacks_cfg
]
strategy_cfg = get_trainer_strategy(lightning_config)
trainer_kwargs["strategy"] = strategy_cfg if type(
strategy_cfg) == str else instantiate_from_config(strategy_cfg)
trainer_kwargs['precision'] = lightning_config.get('precision', 32)
trainer_kwargs["sync_batchnorm"] = False
# Trainer config: others
trainer_args = argparse.Namespace(**trainer_config)
trainer = Trainer.from_argparse_args(trainer_args, **trainer_kwargs)
# Allow checkpointing via USR1
def melk(*args, **kwargs):
if trainer.global_rank == 0:
print("Summoning checkpoint.")
ckpt_path = os.path.join(ckptdir, "last_summoning.ckpt")
trainer.save_checkpoint(ckpt_path)
def divein(*args, **kwargs):
if trainer.global_rank == 0:
import pudb
pudb.set_trace()
import signal
signal.signal(signal.SIGUSR1, melk)
signal.signal(signal.SIGUSR2, divein)
# List the key model sizes
total_params = get_num_parameters(model)
logger.info("***** Running the Loop *****")
if args.train:
try:
if "strategy" in lightning_config and lightning_config[
'strategy'].startswith('deepspeed'):
logger.info("<Training in DeepSpeed Mode>")
if trainer_kwargs['precision'] == 16:
with torch.cuda.amp.autocast():
trainer.fit(model, data)
else:
trainer.fit(model, data)
else:
logger.info("<Training in DDPSharded Mode>")
trainer.fit(model, data)
except Exception:
raise