Files
unifolm-world-model-action/scripts/evaluation/world_model_interaction.py
2026-02-10 12:46:12 +08:00

1233 lines
48 KiB
Python

import argparse, os, glob
import pandas as pd
import random
import torch
import torchvision
import h5py
import numpy as np
import logging
import einops
import warnings
import imageio
import atexit
from concurrent.futures import ThreadPoolExecutor
from contextlib import nullcontext
from typing import Optional, Dict, List, Any, Mapping
from pytorch_lightning import seed_everything
from omegaconf import OmegaConf
from tqdm import tqdm
from einops import rearrange, repeat
from collections import OrderedDict
from torch import nn
from eval_utils import populate_queues
from collections import deque
from torch import Tensor
from PIL import Image
from unifolm_wma.models.samplers.ddim import DDIMSampler
from unifolm_wma.utils.utils import instantiate_from_config
# ========== Async I/O ==========
# Shared background pool for video writes; created lazily by _get_io_executor().
_io_executor: Optional[ThreadPoolExecutor] = None
# Futures of submitted save jobs, awaited and then cleared by _flush_io().
_io_futures: List[Any] = list()
def _get_io_executor() -> ThreadPoolExecutor:
    """Return the module-wide I/O thread pool, creating it on first use.

    The pool is stored in the module global ``_io_executor`` so every caller
    shares the same two worker threads.
    """
    global _io_executor
    if _io_executor is not None:
        return _io_executor
    _io_executor = ThreadPoolExecutor(max_workers=2)
    return _io_executor
@atexit.register
def _flush_io():
    """Block until every queued async I/O job has finished, then forget them.

    A failing job is reported with a print and does not stop the remaining
    jobs from being awaited. The function is registered with ``atexit`` so
    pending writes are awaited at interpreter shutdown as well.
    """
    pending = _io_futures
    for job in pending:
        try:
            job.result()
        except Exception as err:
            print(f">>> [async I/O] error: {err}")
    pending.clear()
def _save_results_sync(video_cpu: Tensor, filename: str, fps: int) -> None:
    """Encode a (B, C, T, H, W) tensor with values in [-1, 1] to an H.264 file.

    Every time step becomes one video frame in which the B samples are tiled
    in a single row (``make_grid`` with ``nrow=B``, no padding). Values are
    clamped to [-1, 1] first. Runs synchronously; ``save_results_async``
    submits it to the background pool.
    """
    clipped = torch.clamp(video_cpu.float(), -1., 1.)
    num_samples = clipped.shape[0]
    # (B, C, T, H, W) -> (T, B, C, H, W) so that iteration walks over time.
    per_step = clipped.permute(2, 0, 1, 3, 4)
    tiles = []
    for step_batch in per_step:
        tiles.append(
            torchvision.utils.make_grid(step_batch,
                                        nrow=int(num_samples),
                                        padding=0))
    frames = (torch.stack(tiles, dim=0) + 1.0) / 2.0
    # write_video expects uint8 frames laid out as (T, H, W, C).
    frames = (frames * 255).to(torch.uint8).permute(0, 2, 3, 1)
    torchvision.io.write_video(filename,
                               frames,
                               fps=fps,
                               video_codec='h264',
                               options={'crf': '10'})
def save_results_async(video: Tensor, filename: str, fps: int = 8) -> None:
    """Queue ``video`` to be written to ``filename`` on the background pool.

    The tensor is detached and moved to CPU before submission (a device
    tensor is copied; a CPU tensor is passed through as-is). The returned
    future is recorded in ``_io_futures``; call ``_flush_io()`` to wait.
    """
    host_video = video.detach().cpu()
    pool = _get_io_executor()
    _io_futures.append(
        pool.submit(_save_results_sync, host_video, filename, fps))
# ========== Original Functions ==========
def get_device_from_parameters(module: nn.Module) -> torch.device:
    """Infer the device of ``module`` from its first parameter.

    Args:
        module (nn.Module): Model whose device should be determined.
    Returns:
        torch.device: Device holding the module's first parameter.
    Raises:
        StopIteration: If the module has no parameters.
    """
    first_param = next(module.parameters())
    return first_param.device
def write_video(video_path: str, stacked_frames: list, fps: int) -> None:
    """Write a list of frames to ``video_path`` with ``imageio.mimsave``.

    Args:
        video_path (str): Destination file.
        stacked_frames (list): Frames to encode, in order.
        fps (int): Playback frame rate.
    """
    with warnings.catch_warnings():
        # Silence the "pkg_resources is deprecated" DeprecationWarning for the
        # duration of the write (presumably raised by an imageio backend).
        warnings.filterwarnings(
            "ignore",
            message="pkg_resources is deprecated as an API",
            category=DeprecationWarning,
        )
        imageio.mimsave(video_path, stacked_frames, fps=fps)
def get_filelist(data_dir: str, postfixes: list[str]) -> list[str]:
    """List files directly inside ``data_dir`` whose extension is in ``postfixes``.

    Args:
        data_dir (str): Directory to scan (not recursive).
        postfixes (list[str]): Extensions without the leading dot, e.g. ``["mp4"]``.
    Returns:
        list[str]: Matching paths, sorted lexicographically.
    """
    matches: list[str] = []
    for ext in postfixes:
        matches += glob.glob(os.path.join(data_dir, f"*.{ext}"))
    return sorted(matches)
def _load_state_dict(model: nn.Module,
                     state_dict: Mapping[str, torch.Tensor],
                     strict: bool = True,
                     assign: bool = False) -> None:
    """Load ``state_dict`` into ``model``, optionally with ``assign=True``.

    With ``assign=True`` the checkpoint tensors are adopted as-is (keeping
    their dtypes). If this PyTorch build rejects the ``assign`` keyword with
    ``TypeError``, a warning is emitted and a regular copying load is done.
    """
    if assign:
        try:
            model.load_state_dict(state_dict, strict=strict, assign=True)
        except TypeError:
            warnings.warn(
                "load_state_dict(assign=True) not supported; "
                "falling back to copy load.")
        else:
            return
    model.load_state_dict(state_dict, strict=strict)
def load_model_checkpoint(model: nn.Module,
                          ckpt: str,
                          assign: bool | None = None,
                          device: str | torch.device = "cpu") -> nn.Module:
    """Load model weights from a checkpoint file.

    Three layouts are accepted:

    * a mapping with a ``"state_dict"`` entry. If the strict load fails and
      the state dict contains legacy ``framestride_embed`` keys, those keys
      are renamed to ``fps_embedding`` and the load is retried once;
    * a mapping with a ``"module"`` entry, whose keys have their first 16
      characters stripped before loading;
    * a bare state dict.

    Args:
        model (nn.Module): Model instance.
        ckpt (str): Path to the checkpoint file.
        assign (bool | None): Whether to preserve checkpoint tensor dtypes
            via load_state_dict(assign=True). If None, auto-enable when a
            casted checkpoint metadata is detected.
        device (str | torch.device): Target device for loaded tensors.
    Returns:
        nn.Module: Model with loaded weights.
    Raises:
        Exception: Whatever ``load_state_dict`` raises when the weights do not
            match. For the ``"state_dict"`` layout without legacy keys the
            original error is re-raised instead of retrying an identical load.
    """
    # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
    ckpt_data = torch.load(ckpt, map_location=device, mmap=True)
    is_mapping = isinstance(ckpt_data, Mapping)
    if assign is not None:
        use_assign = assign
    else:
        # save_casted_checkpoint writes "precision_metadata"; such files are
        # loaded with assign=True so their (possibly reduced) dtypes survive.
        use_assign = is_mapping and "precision_metadata" in ckpt_data

    if is_mapping and "state_dict" in ckpt_data:
        state_dict = ckpt_data["state_dict"]
        try:
            _load_state_dict(model, state_dict, strict=True, assign=use_assign)
        except Exception:
            if not any("framestride_embed" in k for k in state_dict):
                # Renaming would change nothing: surface the real error rather
                # than repeating the identical (possibly slow) load.
                raise
            # Legacy checkpoints call the fps embedding "framestride_embed".
            renamed = OrderedDict(state_dict)
            for old_key in [k for k in renamed if "framestride_embed" in k]:
                new_key = old_key.replace("framestride_embed", "fps_embedding")
                renamed[new_key] = renamed.pop(old_key)
            _load_state_dict(model, renamed, strict=True, assign=use_assign)
    elif is_mapping and "module" in ckpt_data:
        # NOTE(review): assumes every key carries a 16-character wrapper prefix
        # (e.g. "_forward_module.") -- confirm against the checkpoint producer.
        stripped = OrderedDict(
            (key[16:], value) for key, value in ckpt_data['module'].items())
        _load_state_dict(model, stripped, strict=True, assign=use_assign)
    else:
        _load_state_dict(model, ckpt_data, strict=True, assign=use_assign)
    print('>>> model checkpoint loaded.')
    return model
def maybe_cast_module(module: nn.Module | None,
dtype: torch.dtype,
label: str) -> None:
if module is None:
return
try:
param = next(module.parameters())
except StopIteration:
print(f">>> {label} has no parameters; skip cast")
return
if param.dtype == dtype:
print(f">>> {label} already {dtype}; skip cast")
return
module.to(dtype=dtype)
print(f">>> {label} cast to {dtype}")
def save_casted_checkpoint(model: nn.Module,
                           save_path: str,
                           metadata: Optional[Dict[str, Any]] = None) -> None:
    """Save ``model``'s state dict, with tensors moved to CPU, to ``save_path``.

    The file contains ``{"state_dict": ...}`` plus ``"precision_metadata"``
    when ``metadata`` is non-empty (load_model_checkpoint checks that key to
    decide on ``assign=True`` loading). Missing parent directories are
    created. An empty ``save_path`` makes this a no-op.
    """
    if not save_path:
        return
    parent = os.path.dirname(save_path)
    if parent:
        os.makedirs(parent, exist_ok=True)
    state = {
        name: (entry.detach().to("cpu")
               if isinstance(entry, torch.Tensor) else entry)
        for name, entry in model.state_dict().items()
    }
    payload: Dict[str, Any] = {"state_dict": state}
    if metadata:
        payload["precision_metadata"] = metadata
    torch.save(payload, save_path)
    print(f">>> Saved casted checkpoint to {save_path}")
def _module_param_dtype(module: nn.Module | None) -> str:
if module is None:
return "None"
dtype_counts: Dict[str, int] = {}
for param in module.parameters():
dtype_key = str(param.dtype)
dtype_counts[dtype_key] = dtype_counts.get(dtype_key, 0) + param.numel()
if not dtype_counts:
return "no_params"
if len(dtype_counts) == 1:
return next(iter(dtype_counts))
total = sum(dtype_counts.values())
parts = []
for dtype_key in sorted(dtype_counts.keys()):
ratio = dtype_counts[dtype_key] / total
parts.append(f"{dtype_key}={ratio:.1%}")
return f"mixed({', '.join(parts)})"
def log_inference_precision(model: nn.Module) -> None:
    """Print the device and parameter dtypes of ``model`` and its sub-models.

    Reports the whole model (device of its first parameter, or "unknown"),
    then each of ``model``, ``first_stage_model``, ``cond_stage_model``,
    ``embedder`` and ``image_proj_model`` that exists as an attribute, and
    finally the current CUDA autocast dtype and whether autocast is enabled.
    """
    first_param = next(model.parameters(), None)
    device = "unknown" if first_param is None else str(first_param.device)
    model_dtype = _module_param_dtype(model)
    print(f">>> inference precision: model={model_dtype}, device={device}")
    for attr in ("model", "first_stage_model", "cond_stage_model", "embedder",
                 "image_proj_model"):
        if hasattr(model, attr):
            submodule = getattr(model, attr)
            print(f">>> {attr} param dtype: {_module_param_dtype(submodule)}")
    # torch.get_autocast_gpu_dtype() is deprecated (PyTorch 2.4+) in favour of
    # torch.get_autocast_dtype("cuda"); keep a fallback for older releases.
    if hasattr(torch, "get_autocast_dtype"):
        autocast_dtype = torch.get_autocast_dtype("cuda")
    else:
        autocast_dtype = torch.get_autocast_gpu_dtype()
    print(
        ">>> autocast gpu dtype default: "
        f"{autocast_dtype} "
        f"(enabled={torch.is_autocast_enabled()})")
def is_inferenced(save_dir: str, filename: str) -> bool:
    """Check whether a result video already exists for ``filename``.

    Looks for ``<save_dir>/samples_separate/<stem>_sample0.mp4`` where
    ``<stem>`` is ``filename`` with its extension removed. ``os.path.splitext``
    is used so extensions of any length (".jpeg", ".h5", none) are handled;
    slicing off a fixed four characters only worked for 3-letter extensions.

    Args:
        save_dir (str): Directory where results are saved.
        filename (str): Name of the input file to check.
    Returns:
        bool: True if the processed video exists, False otherwise.
    """
    stem = os.path.splitext(filename)[0]
    video_file = os.path.join(save_dir, "samples_separate",
                              f"{stem}_sample0.mp4")
    return os.path.exists(video_file)
def save_results(video: Tensor, filename: str, fps: int = 8) -> None:
    """Synchronously save a video tensor to ``filename`` as an H.264 video.

    Args:
        video (Tensor): Tensor of shape (B, C, T, H, W) with values in
            [-1, 1] (values outside are clamped); may live on any device.
        filename (str): Output file path.
        fps (int, optional): Frames per second. Defaults to 8.
    """
    # Reuse the encoder behind save_results_async so both paths stay identical.
    _save_results_sync(video.detach().cpu(), filename, fps)
def get_init_frame_path(data_dir: str, sample: dict) -> str:
    """Build the path of a sample's initial-frame PNG image.

    The result is ``<data_dir>/images/<sample['data_dir']>/<videoid>.png``.

    Args:
        data_dir (str): Dataset root directory (the one containing ``images/``).
        sample (dict): Row with 'data_dir' and 'videoid' entries; any object
            supporting ``[]`` lookup (e.g. a pandas Series row) works.
    Returns:
        str: Full path to the PNG initial frame.
    """
    rel_image_fp = os.path.join(sample['data_dir'],
                                str(sample['videoid']) + '.png')
    return os.path.join(data_dir, 'images', rel_image_fp)
def get_transition_path(data_dir: str, sample: dict) -> str:
    """Return the HDF5 transition file path for ``sample``.

    The layout is ``<data_dir>/transitions/<sample['data_dir']>/<videoid>.h5``.

    Args:
        data_dir (str): Dataset root directory (the one containing ``transitions/``).
        sample (dict): Row with 'data_dir' and 'videoid' entries.
    Returns:
        str: Full path to the HDF5 transition file.
    """
    file_name = str(sample['videoid']) + '.h5'
    return os.path.join(data_dir, 'transitions', sample['data_dir'], file_name)
def prepare_init_input(
        start_idx: int,
        init_frame_path: str,
        transition_dict: dict[str, torch.Tensor],
        frame_stride: int,
        wma_data,
        video_length: int = 16,
        n_obs_steps: int = 2) -> tuple[dict[str, Tensor], int, int]:
    """Build the initial model input from a frame image and a transition file.

    Loads the initial RGB frame, gathers the ``n_obs_steps`` states ending at
    ``start_idx`` (left-padding by repeating the earliest available state when
    there is not enough history), samples ``video_length`` actions starting at
    ``start_idx`` every ``frame_stride`` steps, and normalizes/packs states and
    actions with ``wma_data``.

    Args:
        start_idx (int): Index of the first frame of the current clip.
        init_frame_path (str): Path to the initial frame image.
        transition_dict (dict[str, Tensor]): Needs 'observation.state' and
            'action' tensors indexed by time along dim 0, plus 'action_type'
            and 'state_type' (forwarded to ``wma_data.get_uni_vec``).
        frame_stride (int): Step between consecutive sampled action indices.
        wma_data: Dataset object providing ``normalizer``, ``get_uni_vec`` and
            ``spatial_transform`` (which may be None).
        video_length (int, optional): Number of actions to sample. Default 16.
        n_obs_steps (int, optional): Number of state history steps. Default 2.
    Returns:
        tuple[dict[str, Tensor], int, int]: ``(data, ori_state_dim,
        ori_action_dim)``. ``data`` holds 'observation.image' (built as a
        (3, 1, H, W) float tensor, optionally passed through
        ``wma_data.spatial_transform``, then mapped from [0, 255] to [-1, 1])
        plus whatever ``wma_data.get_uni_vec`` returns. The two ints are the
        last-dimension sizes of the raw states and actions before
        normalization.
    """
    action_indices = [start_idx + frame_stride * i for i in range(video_length)]

    # Context manager closes the image file deterministically.
    with Image.open(init_frame_path) as img:
        rgb = np.array(img.convert('RGB'))
    # (H, W, 3) -> (1, H, W, 3) -> (3, 1, H, W)
    init_frame = torch.tensor(rgb).unsqueeze(0).permute(3, 0, 1, 2).float()

    all_states = transition_dict['observation.state']
    if start_idx < n_obs_steps - 1:
        # Not enough history yet: repeat the earliest state as left padding.
        states = all_states[list(range(0, start_idx + 1)), :]
        num_padding = n_obs_steps - 1 - start_idx
        padding = states[0:1, :].repeat(num_padding, 1)
        states = torch.cat((padding, states), dim=0)
    else:
        states = all_states[list(range(start_idx - n_obs_steps + 1,
                                       start_idx + 1)), :]
    actions = transition_dict['action'][action_indices, :]

    ori_state_dim = states.shape[-1]
    ori_action_dim = actions.shape[-1]

    frames_action_state_dict = wma_data.normalizer({
        'action': actions,
        'observation.state': states,
    })
    frames_action_state_dict = wma_data.get_uni_vec(
        frames_action_state_dict,
        transition_dict['action_type'],
        transition_dict['state_type'],
    )

    if wma_data.spatial_transform is not None:
        init_frame = wma_data.spatial_transform(init_frame)
    # Map pixel values from [0, 255] to [-1, 1].
    init_frame = (init_frame / 255 - 0.5) * 2

    data = {'observation.image': init_frame}
    data.update(frames_action_state_dict)
    return data, ori_state_dim, ori_action_dim
def get_latent_z(model, videos: Tensor) -> Tensor:
    """Encode a pixel-space video batch into the first-stage latent space.

    Frames are folded into the batch dimension, encoded with
    ``model.encode_first_stage`` (under bfloat16 CUDA autocast when
    ``model.vae_bf16`` is truthy and the model is on CUDA), then unfolded.

    Args:
        model: World model exposing ``encode_first_stage`` and ``device``.
        videos (Tensor): Videos of shape [B, C, T, H, W].
    Returns:
        Tensor: Latents laid out as [B, C', T, H', W'] (channel and spatial
        sizes are whatever the encoder produces).
    """
    batch, _, frames, _, _ = videos.shape
    flat = rearrange(videos, 'b c t h w -> (b t) c h w')
    use_bf16 = getattr(model, "vae_bf16", False) and model.device.type == "cuda"
    ctx = (torch.autocast("cuda", dtype=torch.bfloat16)
           if use_bf16 else nullcontext())
    with ctx:
        latents = model.encode_first_stage(flat)
    return rearrange(latents, '(b t) c h w -> b c t h w', b=batch, t=frames)
def preprocess_observation(
        model, observations: dict[str, np.ndarray]) -> dict[str, Tensor]:
    """Convert an environment observation to LeRobot-format observation tensors.

    Args:
        model: Model providing ``normalize_inputs`` and ``device``; used to
            normalize the state.
        observations: Dictionary of observation batches with "pixels" (a uint8
            channel-last (B, H, W, C) array, or a dict of camera name -> such
            array) and "agent_pos" (state array).
    Returns:
        dict[str, Tensor]: "observation.images.<camera>" float32 (B, C, H, W)
        tensors (a non-dict "pixels" is keyed "observation.images.top") and
        "observation.state", the float32 state moved to ``model.device`` and
        passed through ``model.normalize_inputs``.
    Raises:
        AssertionError: If an image is not channel-last or not uint8.
    """
    pixels = observations["pixels"]
    if isinstance(pixels, dict):
        imgs = {f"observation.images.{key}": img for key, img in pixels.items()}
    else:
        imgs = {"observation.images.top": pixels}

    return_observations = {}
    for imgkey, img in imgs.items():
        img = torch.from_numpy(img)
        # Sanity check that images are channel last: (B, H, W, C).
        _, h, w, c = img.shape
        assert c < h and c < w, f"expect channel last images, but instead {img.shape}"
        assert img.dtype == torch.uint8, f"expect torch.uint8, but instead {img.dtype=}"
        # Convert to channel-first float32. NOTE(review): values stay in
        # [0, 255] -- no division by 255 is applied; confirm the consumer
        # expects this range.
        img = img.permute(0, 3, 1, 2).contiguous().to(torch.float32)
        return_observations[imgkey] = img

    state = torch.from_numpy(observations["agent_pos"]).float()
    return_observations["observation.state"] = model.normalize_inputs(
        {"observation.state": state.to(model.device)})["observation.state"]
    return return_observations
def _move_to_device(batch: Mapping[str, Any],
                    device: torch.device) -> dict[str, Any]:
    """Return a shallow copy of ``batch`` with tensors moved to ``device``.

    Tensors already on ``device`` and non-tensor values are passed through
    as the same objects; transfers use ``non_blocking=True``.
    """
    def _relocate(value: Any) -> Any:
        needs_move = isinstance(value, torch.Tensor) and value.device != device
        return value.to(device, non_blocking=True) if needs_move else value

    return {key: _relocate(value) for key, value in batch.items()}
def _embed_condition_image(model: torch.nn.Module, cond_img: Tensor) -> Tensor:
    """Embed the conditioning frame with ``model.embedder`` at the configured precision.

    * ``model.encoder_bf16`` falsy or model not on CUDA: plain call.
    * ``encoder_mode`` other than "autocast": whole embedder under bf16 autocast.
    * ``encoder_mode == "autocast"`` (the default): ``embedder.preprocess`` (if
      any) runs in float32 with autocast disabled, then the encoder runs
      under bf16 autocast on the preprocessed tensor.
    """
    embedder = model.embedder
    if not (getattr(model, "encoder_bf16", False)
            and model.device.type == "cuda"):
        return embedder(cond_img)
    if getattr(model, "encoder_mode", "autocast") != "autocast":
        with torch.autocast("cuda", dtype=torch.bfloat16):
            return embedder(cond_img)

    # Preprocess in float32 so this step is unaffected by bf16 autocast.
    with torch.autocast("cuda", enabled=False):
        cond_img_fp32 = cond_img.float()
        if hasattr(embedder, "preprocess"):
            preprocessed = embedder.preprocess(cond_img_fp32)
        else:
            preprocessed = cond_img_fp32

    if not (hasattr(embedder, "encode_with_vision_transformer")
            and hasattr(embedder, "preprocess")):
        # NOTE(review): if embedder.forward itself calls preprocess, the input
        # is preprocessed twice on this path.
        with torch.autocast("cuda", dtype=torch.bfloat16):
            return embedder(preprocessed)

    # encode_with_vision_transformer presumably calls self.preprocess again
    # (not visible here); swap in an identity so the tensor is not
    # preprocessed twice, and restore the original afterwards.
    had_instance_attr = "preprocess" in getattr(embedder, "__dict__", {})
    original_preprocess = embedder.preprocess
    embedder.preprocess = lambda x: x
    try:
        with torch.autocast("cuda", dtype=torch.bfloat16):
            return embedder.encode_with_vision_transformer(preprocessed)
    finally:
        if had_instance_attr:
            embedder.preprocess = original_preprocess
        else:
            # preprocess came from the class: drop the override instead of
            # pinning a bound method onto the instance.
            del embedder.preprocess


def image_guided_synthesis_sim_mode(
        model: torch.nn.Module,
        prompts: list[str],
        observation: dict,
        noise_shape: tuple[int, int, int, int, int],
        action_cond_step: int = 16,
        n_samples: int = 1,
        ddim_steps: int = 50,
        ddim_eta: float = 1.0,
        unconditional_guidance_scale: float = 1.0,
        fs: int | None = None,
        text_input: bool = True,
        timestep_spacing: str = 'uniform',
        guidance_rescale: float = 0.0,
        sim_mode: bool = True,
        diffusion_autocast_dtype: Optional[torch.dtype] = None,
        **kwargs) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Run one conditioned DDIM generation with the world model.

    The conditioning combines the latest observed frame (image embedding and,
    for 'hybrid' conditioning, its latent concatenated over time), the state
    and action histories, and the text prompt.

    Args:
        model (torch.nn.Module): The diffusion-based generative model with multimodal conditioning.
        prompts (list[str]): Prompt(s) passed to ``model.get_learned_conditioning``
            (callers in this file also pass a single string).
        observation (dict): Observed inputs:
            - 'observation.images.top': Tensor of shape [B, C, O, H, W]
              (observation frames along dim 2)
            - 'observation.state': Tensor of shape [B, O, D] (state vectors)
            - 'action': Tensor of shape [B, T, D] (action sequence)
        noise_shape (tuple[int, int, int, int, int]): Latent shape to generate, (B, C, T, H, W).
        action_cond_step (int): Unused; kept for interface compatibility.
        n_samples (int): Unused; one sample per batch element is generated.
        ddim_steps (int): Number of DDIM sampling steps. Default is 50.
        ddim_eta (float): DDIM eta parameter controlling stochasticity. Default is 1.0.
        unconditional_guidance_scale (float): Scale for classifier-free guidance. 1.0 disables it.
        fs (int | None): Frame-rate conditioning value, broadcast across the
            batch. If None, no fs tensor is built and None is forwarded.
        text_input (bool): If False, empty prompts are used instead of ``prompts``.
        timestep_spacing (str): Timestep spacing used by the DDIM sampler.
        guidance_rescale (float): Guidance rescaling factor for classifier-free guidance.
        sim_mode (bool): If False, the action embedding is zeroed before it
            enters the cross-attention context; the flag is also forwarded to
            the sampler in ``c_crossattn_action``.
        diffusion_autocast_dtype (Optional[torch.dtype]): Autocast dtype for
            diffusion sampling on CUDA (e.g. torch.bfloat16); None disables it.
        **kwargs: Additional arguments passed to the DDIM sampler.
    Returns:
        batch_variants (torch.Tensor): Decoded pixel-space frames [B, C, T, H, W].
        actions (torch.Tensor): Predicted action sequences [B, T, D].
        states (torch.Tensor): Predicted state sequences [B, T, D].
    """
    batch_size = noise_shape[0]
    _, _, num_frames, _, _ = noise_shape  # also validates a 5-D noise_shape

    # Cache one sampler on the model instead of rebuilding it every call.
    ddim_sampler = getattr(model, "_ddim_sampler", None)
    if ddim_sampler is None:
        ddim_sampler = DDIMSampler(model)
        model._ddim_sampler = ddim_sampler

    if fs is not None:
        fs = torch.tensor([fs] * batch_size, dtype=torch.long,
                          device=model.device)

    # (B, C, O, H, W) -> (B, O, C, H, W); the last flattened frame is the
    # embedder input (the most recent observation when B == 1).
    img = observation['observation.images.top'].permute(0, 2, 1, 3, 4)
    cond_img = rearrange(img, 'b o c h w -> (b o) c h w')[-1:]
    cond_img_emb = _embed_condition_image(model, cond_img)

    # Built for every conditioning key; previously only the 'hybrid' branch
    # created it, so other keys failed with NameError below.
    cond: dict[str, Any] = {}
    if model.model.conditioning_key == 'hybrid':
        # Concat-condition on the latent of the latest frame, repeated over
        # all frames to be generated.
        z = get_latent_z(model, img.permute(0, 2, 1, 3, 4))
        img_cat_cond = repeat(z[:, :, -1:, :, :],
                              'b c t h w -> b c (repeat t) h w',
                              repeat=num_frames)
        cond["c_concat"] = [img_cat_cond]

    if not text_input:
        prompts = [""] * batch_size
    encoder_ctx = nullcontext()
    if getattr(model, "encoder_bf16", False) and model.device.type == "cuda":
        encoder_ctx = torch.autocast("cuda", dtype=torch.bfloat16)
    with encoder_ctx:
        cond_ins_emb = model.get_learned_conditioning(prompts)

    # Project all conditioning streams to the text-embedding dtype so they
    # can be concatenated.
    target_dtype = cond_ins_emb.dtype
    cond_img_emb = model._projector_forward(model.image_proj_model,
                                            cond_img_emb, target_dtype)
    cond_state_emb = model._projector_forward(
        model.state_projector, observation['observation.state'], target_dtype)
    cond_state_emb = cond_state_emb + model.agent_state_pos_emb.to(
        dtype=target_dtype)
    cond_action_emb = model._projector_forward(
        model.action_projector, observation['action'], target_dtype)
    cond_action_emb = cond_action_emb + model.agent_action_pos_emb.to(
        dtype=target_dtype)
    if not sim_mode:
        cond_action_emb = torch.zeros_like(cond_action_emb)

    cond["c_crossattn"] = [
        torch.cat(
            [cond_state_emb, cond_action_emb, cond_ins_emb, cond_img_emb],
            dim=1)
    ]
    cond["c_crossattn_action"] = [
        observation['observation.images.top'][:, :,
                                              -model.n_obs_steps_acting:],
        observation['observation.state'][:, -model.n_obs_steps_acting:],
        sim_mode,
        False,
    ]

    # No unconditional branch, mask or x0 inpainting is used here.
    kwargs.update({"unconditional_conditioning_img_nonetext": None})
    autocast_ctx = nullcontext()
    if diffusion_autocast_dtype is not None and model.device.type == "cuda":
        autocast_ctx = torch.autocast("cuda", dtype=diffusion_autocast_dtype)
    with autocast_ctx:
        samples, actions, states, _ = ddim_sampler.sample(
            S=ddim_steps,
            conditioning=cond,
            batch_size=batch_size,
            shape=noise_shape[1:],
            verbose=False,
            unconditional_guidance_scale=unconditional_guidance_scale,
            unconditional_conditioning=None,
            eta=ddim_eta,
            cfg_img=None,
            mask=None,
            x0=None,
            fs=fs,
            timestep_spacing=timestep_spacing,
            guidance_rescale=guidance_rescale,
            **kwargs)

    # Reconstruct from latent to pixel space in the VAE's configured dtype.
    if getattr(model, "vae_bf16", False):
        samples = samples.to(dtype=torch.bfloat16)
        vae_ctx = nullcontext()
        if model.device.type == "cuda":
            vae_ctx = torch.autocast("cuda", dtype=torch.bfloat16)
        with vae_ctx:
            batch_images = model.decode_first_stage(samples)
    else:
        batch_images = model.decode_first_stage(samples.float())
    return batch_images, actions, states
def run_inference(args: argparse.Namespace, gpu_num: int, gpu_no: int) -> None:
    """
    Run the closed-loop policy / world-model interaction over every CSV sample.

    For each row of ``<prompt_dir>/<dataset>.csv`` and each value in
    ``args.frame_stride`` this runs ``args.n_iter`` rounds of:
    (1) decision-making: generate actions from the observation queues
    (``sim_mode=False``); (2) interaction: roll the world model forward with
    those actions and push the first ``args.exe_steps`` predicted frames and
    states back into the queues. Per-iteration videos go to
    ``<savedir>/inference/sample_<videoid>/{dm,wm}/<fs>/itr-<i>.mp4`` and the
    concatenated interaction video to
    ``<savedir>/inference/<videoid>_full_fs<fs>.mp4`` (written asynchronously).

    The model is loaded from ``<ckpt_path>.prepared.pt`` when that file
    exists; otherwise it is built from the config and checkpoint, cast
    according to the dtype/mode args, and then pickled to that path.

    Args:
        args (argparse.Namespace): Parsed command-line arguments.
        gpu_num (int): Number of GPUs (not used in this function).
        gpu_no (int): Index of the current GPU.
    Returns:
        None
    """
    # Create inference dir
    os.makedirs(args.savedir + '/inference', exist_ok=True)
    # Load prompt
    csv_path = os.path.join(args.prompt_dir, f"{args.dataset}.csv")
    df = pd.read_csv(csv_path)
    # Load config (always needed for data setup)
    config = OmegaConf.load(args.config)
    prepared_path = args.ckpt_path + ".prepared.pt"
    if os.path.exists(prepared_path):
        # ---- Fast path: load the fully-prepared model ----
        # NOTE(review): weights_only=False unpickles arbitrary objects; only
        # use prepared files produced locally by this script.
        # NOTE(review): none of the casting / VAE / encoder / projector flags
        # handled in the normal path are re-applied here; the pickled model
        # keeps the settings it was saved with. Delete the .prepared.pt file
        # after changing those flags.
        print(f">>> Loading prepared model from {prepared_path} ...")
        model = torch.load(prepared_path,
                           map_location=f"cuda:{gpu_no}",
                           weights_only=False,
                           mmap=True)
        model.eval()
        diffusion_autocast_dtype = (torch.bfloat16
                                    if args.diffusion_dtype == "bf16"
                                    else None)
        print(f">>> Prepared model loaded.")
    else:
        # ---- Normal path: construct + checkpoint + casting ----
        # Gradient checkpointing is a training-time feature; disable it.
        config['model']['params']['wma_config']['params'][
            'use_checkpoint'] = False
        model = instantiate_from_config(config.model)
        model.perframe_ae = args.perframe_ae
        assert os.path.exists(args.ckpt_path), "Error: checkpoint Not Found!"
        model = load_model_checkpoint(model, args.ckpt_path,
                                      device=f"cuda:{gpu_no}")
        model.eval()
        model = model.cuda(gpu_no)  # move residual buffers not in state_dict
        print(f'>>> Load pre-trained model ...')
        # Diffusion backbone: bf16 weights plus bf16 autocast during sampling.
        diffusion_autocast_dtype = None
        if args.diffusion_dtype == "bf16":
            maybe_cast_module(
                model.model,
                torch.bfloat16,
                "diffusion backbone",
            )
            diffusion_autocast_dtype = torch.bfloat16
            print(">>> diffusion backbone set to bfloat16")
        # VAE weights; model.vae_bf16 is read by get_latent_z and the decode
        # step of image_guided_synthesis_sim_mode.
        if hasattr(model, "first_stage_model") and model.first_stage_model is not None:
            vae_weight_dtype = torch.bfloat16 if args.vae_dtype == "bf16" else torch.float32
            maybe_cast_module(
                model.first_stage_model,
                vae_weight_dtype,
                "VAE",
            )
            model.vae_bf16 = args.vae_dtype == "bf16"
            print(f">>> VAE dtype set to {args.vae_dtype}")
        # --- VAE performance optimizations ---
        if hasattr(model, "first_stage_model") and model.first_stage_model is not None:
            vae = model.first_stage_model
            # torch.compile: fuses GroupNorm+SiLU, conv chains, etc.
            if args.vae_compile:
                vae.decoder = torch.compile(vae.decoder, mode="reduce-overhead")
                vae.encoder = torch.compile(vae.encoder, mode="reduce-overhead")
                print(">>> VAE encoder/decoder compiled with torch.compile (reduce-overhead)")
            # Batch decode size; a non-positive arg maps to a 9999 sentinel
            # (reported below as "all frames at once"). The same value is used
            # for encoding.
            vae_decode_bs = args.vae_decode_bs if args.vae_decode_bs > 0 else 9999
            model.vae_decode_bs = vae_decode_bs
            model.vae_encode_bs = vae_decode_bs
            if args.vae_decode_bs > 0:
                print(f">>> VAE encode/decode batch size set to {args.vae_decode_bs}")
            else:
                print(">>> VAE encode/decode batch size: all frames at once")
        # Text / image encoders: "autocast" keeps fp32 weights but runs under
        # bf16 autocast; "bf16_full" also casts the weights to bf16.
        encoder_mode = args.encoder_mode
        encoder_bf16 = encoder_mode in ("autocast", "bf16_full")
        encoder_weight_dtype = torch.bfloat16 if encoder_mode == "bf16_full" else torch.float32
        if hasattr(model, "cond_stage_model") and model.cond_stage_model is not None:
            maybe_cast_module(
                model.cond_stage_model,
                encoder_weight_dtype,
                "cond_stage_model",
            )
        if hasattr(model, "embedder") and model.embedder is not None:
            maybe_cast_module(
                model.embedder,
                encoder_weight_dtype,
                "embedder",
            )
        model.encoder_bf16 = encoder_bf16
        model.encoder_mode = encoder_mode
        print(
            f">>> encoder mode set to {encoder_mode} (weights={encoder_weight_dtype})"
        )
        # Projectors follow the same "autocast" / "bf16_full" convention.
        projector_mode = args.projector_mode
        projector_bf16 = projector_mode in ("autocast", "bf16_full")
        projector_weight_dtype = torch.bfloat16 if projector_mode == "bf16_full" else torch.float32
        if hasattr(model, "image_proj_model") and model.image_proj_model is not None:
            maybe_cast_module(
                model.image_proj_model,
                projector_weight_dtype,
                "image_proj_model",
            )
        if hasattr(model, "state_projector") and model.state_projector is not None:
            maybe_cast_module(
                model.state_projector,
                projector_weight_dtype,
                "state_projector",
            )
        if hasattr(model, "action_projector") and model.action_projector is not None:
            maybe_cast_module(
                model.action_projector,
                projector_weight_dtype,
                "action_projector",
            )
        if hasattr(model, "projector_bf16"):
            model.projector_bf16 = projector_bf16
        model.projector_mode = projector_mode
        print(
            f">>> projector mode set to {projector_mode} (weights={projector_weight_dtype})"
        )
        log_inference_precision(model)
        # Optionally export the casted weights as a reusable checkpoint; its
        # precision_metadata makes load_model_checkpoint use assign=True.
        if args.export_casted_ckpt:
            metadata = {
                "diffusion_dtype": args.diffusion_dtype,
                "vae_dtype": args.vae_dtype,
                "encoder_mode": args.encoder_mode,
                "projector_mode": args.projector_mode,
                "perframe_ae": args.perframe_ae,
            }
            save_casted_checkpoint(model, args.export_casted_ckpt, metadata)
            if args.export_only:
                print(">>> export_only set; skipping inference.")
                return
        # Save prepared model for fast loading next time
        # (prepared_path is always a non-empty string here).
        if prepared_path:
            print(f">>> Saving prepared model to {prepared_path} ...")
            torch.save(model, prepared_path)
            print(f">>> Prepared model saved ({os.path.getsize(prepared_path) / 1024**3:.1f} GB).")
    # Build normalizer (always needed, independent of model loading path)
    logging.info("***** Configing Data *****")
    data = instantiate_from_config(config.data)
    data.setup()
    print(">>> Dataset is successfully loaded ...")
    device = get_device_from_parameters(model)
    # Run over data
    assert (args.height % 16 == 0) and (
        args.width % 16
        == 0), "Error: image size [h,w] should be multiples of 16!"
    assert args.bs == 1, "Current implementation only support [batch size = 1]!"
    # Get latent noise shape (the VAE downsamples spatially by 8)
    h, w = args.height // 8, args.width // 8
    channels = model.model.diffusion_model.out_channels
    n_frames = args.video_length
    print(f'>>> Generate {n_frames} frames under each generation ...')
    noise_shape = [args.bs, channels, n_frames, h, w]
    # Start inference
    for idx in range(0, len(df)):
        sample = df.iloc[idx]
        # Get initial frame path
        init_frame_path = get_init_frame_path(args.prompt_dir, sample)
        ori_fps = float(sample['fps'])
        video_save_dir = args.savedir + f"/inference/sample_{sample['videoid']}"
        os.makedirs(video_save_dir, exist_ok=True)
        os.makedirs(video_save_dir + '/dm', exist_ok=True)
        os.makedirs(video_save_dir + '/wm', exist_ok=True)
        # Load transitions to get the initial state later: datasets become
        # tensors, file attributes are copied as-is.
        transition_path = get_transition_path(args.prompt_dir, sample)
        with h5py.File(transition_path, 'r') as h5f:
            transition_dict = {}
            for key in h5f.keys():
                transition_dict[key] = torch.tensor(h5f[key][()])
            for key in h5f.attrs.keys():
                transition_dict[key] = h5f.attrs[key]
        # If several strides are given, run frequency control and
        # world-model generation once per stride
        for fs in args.frame_stride:
            # For saving imagined videos in policy (decision-making)
            sample_save_dir = f'{video_save_dir}/dm/{fs}'
            os.makedirs(sample_save_dir, exist_ok=True)
            # For saving environmental changes in world-model
            sample_save_dir = f'{video_save_dir}/wm/{fs}'
            os.makedirs(sample_save_dir, exist_ok=True)
            # For collecting interaction videos
            wm_video = []
            # Initialize observation queues: bounded deques acting as sliding
            # windows of recent images/states and of upcoming actions.
            cond_obs_queues = {
                "observation.images.top":
                deque(maxlen=model.n_obs_steps_imagen),
                "observation.state": deque(maxlen=model.n_obs_steps_imagen),
                "action": deque(maxlen=args.video_length),
            }
            # Obtain initial frame and state
            start_idx = 0
            # Frame-rate conditioning value fed to the model: source fps
            # floor-divided by the stride.
            model_input_fs = ori_fps // fs
            batch, ori_state_dim, ori_action_dim = prepare_init_input(
                start_idx,
                init_frame_path,
                transition_dict,
                fs,
                data.test_datasets[args.dataset],
                n_obs_steps=model.n_obs_steps_imagen)
            # Seed the queues with the latest frame/state and a zero action.
            observation = {
                'observation.images.top':
                batch['observation.image'].permute(1, 0, 2,
                                                   3)[-1].unsqueeze(0),
                'observation.state':
                batch['observation.state'][-1].unsqueeze(0),
                'action':
                torch.zeros_like(batch['action'][-1]).unsqueeze(0)
            }
            observation = _move_to_device(observation, device)
            # Update observation queues (presumably populate_queues fills empty
            # deques with copies on first use -- see eval_utils)
            cond_obs_queues = populate_queues(cond_obs_queues, observation)
            # Multi-round interaction with the world-model
            for itr in tqdm(range(args.n_iter)):
                log_every = max(1, args.step_log_every)
                log_step = (itr % log_every == 0)
                # Get observation: stack queue entries along dim 1; images are
                # then permuted to (B, C, O, H, W).
                observation = {
                    'observation.images.top':
                    torch.stack(list(
                        cond_obs_queues['observation.images.top']),
                                dim=1).permute(0, 2, 1, 3, 4),
                    'observation.state':
                    torch.stack(list(cond_obs_queues['observation.state']),
                                dim=1),
                    'action':
                    torch.stack(list(cond_obs_queues['action']), dim=1),
                }
                observation = _move_to_device(observation, device)
                # Use world-model in policy to generate action
                if log_step:
                    print(f'>>> Step {itr}: generating actions ...')
                pred_videos_0, pred_actions, _ = image_guided_synthesis_sim_mode(
                    model,
                    sample['instruction'],
                    observation,
                    noise_shape,
                    action_cond_step=args.exe_steps,
                    ddim_steps=args.ddim_steps,
                    ddim_eta=args.ddim_eta,
                    unconditional_guidance_scale=args.
                    unconditional_guidance_scale,
                    fs=model_input_fs,
                    timestep_spacing=args.timestep_spacing,
                    guidance_rescale=args.guidance_rescale,
                    sim_mode=False,
                    diffusion_autocast_dtype=diffusion_autocast_dtype)
                # Update future actions in the observation queues
                for act_idx in range(len(pred_actions[0])):
                    obs_update = {'action': pred_actions[0][act_idx:act_idx + 1]}
                    # Zero the dims beyond the dataset's original action size
                    # (presumably padding added by get_uni_vec). Note: this
                    # slice is a view, so pred_actions is modified in place.
                    obs_update['action'][:, ori_action_dim:] = 0.0
                    cond_obs_queues = populate_queues(cond_obs_queues,
                                                      obs_update)
                # Collect data for interacting the world-model using the predicted actions
                observation = {
                    'observation.images.top':
                    torch.stack(list(
                        cond_obs_queues['observation.images.top']),
                                dim=1).permute(0, 2, 1, 3, 4),
                    'observation.state':
                    torch.stack(list(cond_obs_queues['observation.state']),
                                dim=1),
                    'action':
                    torch.stack(list(cond_obs_queues['action']), dim=1),
                }
                observation = _move_to_device(observation, device)
                # Interaction with the world-model
                if log_step:
                    print(f'>>> Step {itr}: interacting with world model ...')
                pred_videos_1, _, pred_states = image_guided_synthesis_sim_mode(
                    model,
                    "",
                    observation,
                    noise_shape,
                    action_cond_step=args.exe_steps,
                    ddim_steps=args.ddim_steps,
                    ddim_eta=args.ddim_eta,
                    unconditional_guidance_scale=args.
                    unconditional_guidance_scale,
                    fs=model_input_fs,
                    text_input=False,
                    timestep_spacing=args.timestep_spacing,
                    guidance_rescale=args.guidance_rescale,
                    diffusion_autocast_dtype=diffusion_autocast_dtype)
                # Feed the first exe_steps predicted frames/states back into the
                # queues, each paired with a zero action.
                for step_idx in range(args.exe_steps):
                    obs_update = {
                        'observation.images.top':
                        pred_videos_1[0][:, step_idx:step_idx + 1].permute(1, 0, 2, 3),
                        'observation.state':
                        torch.zeros_like(pred_states[0][step_idx:step_idx + 1]) if
                        args.zero_pred_state else pred_states[0][step_idx:step_idx + 1],
                        'action':
                        torch.zeros_like(pred_actions[0][-1:])
                    }
                    # Zero state dims beyond the dataset's original state size.
                    obs_update['observation.state'][:, ori_state_dim:] = 0.0
                    cond_obs_queues = populate_queues(cond_obs_queues,
                                                      obs_update)
                # Save the imagined videos for decision-making (async)
                sample_video_file = f'{video_save_dir}/dm/{fs}/itr-{itr}.mp4'
                save_results_async(pred_videos_0,
                                   sample_video_file,
                                   fps=args.save_fps)
                # Save videos environment changes via world-model interaction
                sample_video_file = f'{video_save_dir}/wm/{fs}/itr-{itr}.mp4'
                save_results_async(pred_videos_1,
                                   sample_video_file,
                                   fps=args.save_fps)
                print('>' * 24)
                # Collect the result of world-model interactions (only the
                # executed frames, moved to CPU to free GPU memory)
                wm_video.append(pred_videos_1[:, :, :args.exe_steps].cpu())
            # Concatenate all rounds along the time axis into one video.
            full_video = torch.cat(wm_video, dim=2)
            sample_full_video_file = f"{video_save_dir}/../{sample['videoid']}_full_fs{fs}.mp4"
            save_results_async(full_video, sample_full_video_file, fps=args.save_fps)
    # Wait for all async I/O to complete
    _flush_io()
def get_parser():
    """Build the command-line parser for world-model interaction inference.

    Returns:
        argparse.ArgumentParser: parser exposing I/O paths, sampling settings
        (DDIM steps/eta, guidance), precision modes for the diffusion backbone,
        projectors, encoder and VAE, checkpoint-export options, and the
        interaction-loop settings (action/execution steps, iterations).
        ``--frame_stride`` is the only required option.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--savedir",
                        type=str,
                        default=None,
                        help="Path to save the results.")
    parser.add_argument("--ckpt_path",
                        type=str,
                        default=None,
                        help="Path to the model checkpoint.")
    # Previously documented as a checkpoint path (copy-paste from --ckpt_path);
    # this option points to the model configuration file.
    parser.add_argument("--config",
                        type=str,
                        help="Path to the model config file.")
    parser.add_argument(
        "--prompt_dir",
        type=str,
        default=None,
        help="Directory containing videos and corresponding prompts.")
    parser.add_argument("--dataset",
                        type=str,
                        default=None,
                        help="the name of dataset to test")
    parser.add_argument(
        "--ddim_steps",
        type=int,
        default=50,
        help="Number of DDIM steps. If non-positive, DDPM is used instead.")
    parser.add_argument(
        "--ddim_eta",
        type=float,
        default=1.0,
        help="Eta for DDIM sampling. Set to 0.0 for deterministic results.")
    parser.add_argument("--bs",
                        type=int,
                        default=1,
                        help="Batch size for inference. Must be 1.")
    parser.add_argument("--height",
                        type=int,
                        default=320,
                        help="Height of the generated images in pixels.")
    parser.add_argument("--width",
                        type=int,
                        default=512,
                        help="Width of the generated images in pixels.")
    parser.add_argument(
        "--frame_stride",
        type=int,
        nargs='+',
        required=True,
        help=
        "frame stride control for 256 model (larger->larger motion), FPS control for 512 or 1024 model (smaller->larger motion)"
    )
    parser.add_argument(
        "--unconditional_guidance_scale",
        type=float,
        default=1.0,
        help="Scale for classifier-free guidance during sampling.")
    parser.add_argument("--seed",
                        type=int,
                        default=123,
                        help="Random seed for reproducibility.")
    parser.add_argument("--video_length",
                        type=int,
                        default=16,
                        help="Number of frames in the generated video.")
    # Previously documented as "seed for seed_everything"; seeding is
    # controlled by --seed, not by this option.
    parser.add_argument("--num_generation",
                        type=int,
                        default=1,
                        help="Number of generations to produce.")
    parser.add_argument(
        "--timestep_spacing",
        type=str,
        default="uniform",
        help=
        "Strategy for timestep scaling. See Table 2 in the paper: 'Common Diffusion Noise Schedules and Sample Steps are Flawed' (https://huggingface.co/papers/2305.08891)."
    )
    parser.add_argument(
        "--guidance_rescale",
        type=float,
        default=0.0,
        help=
        "Rescale factor for guidance as discussed in 'Common Diffusion Noise Schedules and Sample Steps are Flawed' (https://huggingface.co/papers/2305.08891)."
    )
    parser.add_argument(
        "--perframe_ae",
        action='store_true',
        default=False,
        help=
        "Use per-frame autoencoder decoding to reduce GPU memory usage. Recommended for models with resolutions like 576x1024."
    )
    # ---- Precision / performance options ----
    parser.add_argument(
        "--diffusion_dtype",
        type=str,
        choices=["fp32", "bf16"],
        default="fp32",
        help="Dtype for diffusion backbone weights and sampling autocast."
    )
    parser.add_argument(
        "--projector_mode",
        type=str,
        choices=["fp32", "autocast", "bf16_full"],
        default="fp32",
        help=
        "Projector precision mode for image/state/action projectors: "
        "fp32=full fp32, autocast=fp32 weights + bf16 autocast in forward, "
        "bf16_full=bf16 weights + bf16 forward."
    )
    parser.add_argument(
        "--encoder_mode",
        type=str,
        choices=["fp32", "autocast", "bf16_full"],
        default="fp32",
        help=
        "Encoder precision mode for cond_stage_model/embedder: "
        "fp32=full fp32, autocast=fp32 weights + bf16 autocast in forward, "
        "bf16_full=bf16 weights + bf16 forward."
    )
    parser.add_argument(
        "--vae_dtype",
        type=str,
        choices=["fp32", "bf16"],
        default="fp32",
        help="Dtype for VAE/first_stage_model weights and forward autocast."
    )
    parser.add_argument(
        "--vae_compile",
        action='store_true',
        default=False,
        help="Apply torch.compile to VAE decoder for kernel fusion."
    )
    parser.add_argument(
        "--vae_decode_bs",
        type=int,
        default=0,
        help="VAE decode batch size (0=all frames at once). Reduces kernel launch overhead."
    )
    # ---- Checkpoint export ----
    parser.add_argument(
        "--export_casted_ckpt",
        type=str,
        default=None,
        help=
        "Save a checkpoint after applying precision settings (mixed dtypes preserved)."
    )
    parser.add_argument(
        "--export_only",
        action='store_true',
        default=False,
        help="Exit after exporting the casted checkpoint."
    )
    # ---- Interaction loop ----
    parser.add_argument(
        "--step_log_every",
        type=int,
        default=1,
        help="Print per-iteration step logs every N iterations."
    )
    parser.add_argument(
        "--n_action_steps",
        type=int,
        default=16,
        help="num of samples per prompt",
    )
    parser.add_argument(
        "--exe_steps",
        type=int,
        default=16,
        help="num of samples to execute",
    )
    parser.add_argument(
        "--n_iter",
        type=int,
        default=40,
        help="num of iteration to interact with the world model",
    )
    parser.add_argument("--zero_pred_state",
                        action='store_true',
                        default=False,
                        help="not using the predicted states as comparison")
    parser.add_argument("--save_fps",
                        type=int,
                        default=8,
                        help="fps for the saving video")
    return parser
if __name__ == '__main__':
    # Parse CLI options for a single-process inference run.
    parser = get_parser()
    args = parser.parse_args()
    # A negative --seed requests a freshly drawn random seed.
    seed = random.randint(0, 2**31) if args.seed < 0 else args.seed
    seed_everything(seed)
    # Single process on a single GPU: rank 0 of 1.
    gpu_num = 1
    rank = 0
    run_inference(args, gpu_num, rank)