Files
unifolm-world-model-action/scripts/evaluation/world_model_interaction.py
olivame 7e501b17fd 把混和精度模型权重导出至本地文件,减少dtype开销
--export_casted_ckpt ckpts/unifolm_wma_dual_mixbf16.ckpt \
        --export_only
2026-01-19 15:14:01 +08:00

1608 lines
64 KiB
Python

import argparse, os, glob
import pandas as pd
import random
import torch
import torchvision
import h5py
import numpy as np
import logging
import einops
import warnings
import imageio
import time
import json
from contextlib import contextmanager, nullcontext
from dataclasses import dataclass, field, asdict
from typing import Optional, Dict, List, Any, Mapping
from pytorch_lightning import seed_everything
from omegaconf import OmegaConf
from tqdm import tqdm
from einops import rearrange, repeat
from collections import OrderedDict
from torch import nn
from eval_utils import populate_queues, log_to_tensorboard
from collections import deque
from torch import Tensor
from torch.utils.tensorboard import SummaryWriter
from PIL import Image
from unifolm_wma.models.samplers.ddim import DDIMSampler
from unifolm_wma.utils.utils import instantiate_from_config
# ========== Profiling Infrastructure ==========
@dataclass
class TimingRecord:
    """One named timing measurement, optionally with nested child records.

    ``start_time``/``end_time`` are raw clock readings whose difference is
    multiplied by 1000 to report milliseconds (so presumably seconds, e.g.
    from ``time.perf_counter``); ``cuda_time_ms`` is already in milliseconds.
    """
    name: str
    start_time: float = 0.0
    end_time: float = 0.0
    cuda_time_ms: float = 0.0
    count: int = 0
    # Nested records; default_factory gives each instance its own list.
    children: List['TimingRecord'] = field(default_factory=list)

    @property
    def cpu_time_ms(self) -> float:
        """Difference between ``end_time`` and ``start_time``, times 1000."""
        elapsed = self.end_time - self.start_time
        return elapsed * 1000

    def to_dict(self) -> dict:
        """Serialize this record and, recursively, its children to a dict."""
        child_dicts = []
        for child in self.children:
            child_dicts.append(child.to_dict())
        return dict(
            name=self.name,
            cpu_time_ms=self.cpu_time_ms,
            cuda_time_ms=self.cuda_time_ms,
            count=self.count,
            children=child_dicts,
        )
class ProfilerManager:
    """Manage macro-level (section timing, GPU memory) and micro-level
    (``torch.profiler`` operator) profiling.

    Every recording method is a no-op when ``enabled`` is False, so call sites
    may use the profiler unconditionally.

    Args:
        enabled: Whether anything is recorded.
        output_dir: Directory for traces, the text report and the JSON dump;
            created on construction when ``enabled`` is True.
        profile_detail: ``"light"`` or ``"full"``; selects the
            ``torch.profiler`` options (see ``_build_profiler_config``).

    Raises:
        ValueError: If ``profile_detail`` is not ``"light"`` or ``"full"``
            (checked even when profiling is disabled).
    """

    # Keyword lists used to bucket operators in the category breakdown.
    # Categories are tried in this order and the first keyword hit wins, e.g.
    # 'aten::bmm' lands in Attention even though it also contains 'mm'.
    _OP_CATEGORIES: Dict[str, List[str]] = {
        'Attention': ['attention', 'softmax', 'bmm', 'baddbmm'],
        'Convolution': ['conv', 'cudnn'],
        'Normalization': ['norm', 'layer_norm', 'batch_norm', 'group_norm'],
        'Activation': ['relu', 'gelu', 'silu', 'sigmoid', 'tanh'],
        'Linear/GEMM': ['linear', 'addmm', 'mm', 'gemm'],
        'Memory': ['copy', 'contiguous', 'view', 'reshape', 'permute', 'transpose'],
        'Elementwise': ['add', 'mul', 'div', 'sub', 'pow', 'exp', 'sqrt'],
    }

    def __init__(
        self,
        enabled: bool = False,
        output_dir: str = "./profile_output",
        profile_detail: str = "light",
    ):
        self.enabled = enabled
        self.output_dir = output_dir
        self.profile_detail = profile_detail
        # Section name -> list of CPU wall-clock durations in ms.
        self.macro_timings: Dict[str, List[float]] = {}
        # Section name -> list of (cpu_time_ms, cuda_time_ms) pairs.
        self.cuda_events: Dict[str, List[tuple]] = {}
        self.memory_snapshots: List[Dict] = []
        self.pytorch_profiler = None
        # Used only to name exported trace files.
        self.current_iteration = 0
        # Operator key -> accumulated statistics (filled by _trace_handler).
        self.operator_stats: Dict[str, Dict] = {}
        self.profiler_config = self._build_profiler_config(profile_detail)
        if enabled:
            os.makedirs(output_dir, exist_ok=True)

    def _build_profiler_config(self, profile_detail: str) -> Dict[str, Any]:
        """Return ``torch.profiler`` settings for the requested detail level.

        ``"full"`` enables shape/memory/stack/FLOP/module recording and
        input-shape grouping; ``"light"`` disables all of them.

        Raises:
            ValueError: If ``profile_detail`` is neither ``"light"`` nor
                ``"full"``.
        """
        if profile_detail not in ("light", "full"):
            raise ValueError(f"Unsupported profile_detail: {profile_detail}")
        enable_all = profile_detail == "full"
        return {
            "record_shapes": enable_all,
            "profile_memory": enable_all,
            "with_stack": enable_all,
            "with_flops": enable_all,
            "with_modules": enable_all,
            "group_by_input_shape": enable_all,
        }

    @contextmanager
    def profile_section(self, name: str, sync_cuda: bool = True):
        """Time the enclosed block and record it under ``name``.

        Appends the CPU wall time (ms) to ``macro_timings[name]`` and a
        ``(cpu_ms, cuda_ms)`` pair to ``cuda_events[name]``. When CUDA is
        available, CUDA events also measure GPU time; reading that time
        always synchronizes at the end of the section. ``sync_cuda`` adds a
        synchronize before the block and before stopping the CPU clock so the
        wall time includes queued GPU work. Measurements are recorded even if
        the block raises.
        """
        if not self.enabled:
            yield
            return
        if sync_cuda and torch.cuda.is_available():
            torch.cuda.synchronize()
        start_event = None
        end_event = None
        if torch.cuda.is_available():
            start_event = torch.cuda.Event(enable_timing=True)
            end_event = torch.cuda.Event(enable_timing=True)
            start_event.record()
        start_time = time.perf_counter()
        try:
            yield
        finally:
            if sync_cuda and torch.cuda.is_available():
                torch.cuda.synchronize()
            end_time = time.perf_counter()
            cpu_time_ms = (end_time - start_time) * 1000
            cuda_time_ms = 0.0
            if start_event is not None and end_event is not None:
                end_event.record()
                # elapsed_time() requires both events to have completed.
                torch.cuda.synchronize()
                cuda_time_ms = start_event.elapsed_time(end_event)
            self.macro_timings.setdefault(name, []).append(cpu_time_ms)
            self.cuda_events.setdefault(name, []).append(
                (cpu_time_ms, cuda_time_ms))

    def record_memory(self, tag: str = ""):
        """Append a snapshot of the current CUDA memory counters (in MB)."""
        if not self.enabled or not torch.cuda.is_available():
            return
        snapshot = {
            'tag': tag,
            'iteration': self.current_iteration,
            'allocated_mb': torch.cuda.memory_allocated() / 1024**2,
            'reserved_mb': torch.cuda.memory_reserved() / 1024**2,
            'max_allocated_mb': torch.cuda.max_memory_allocated() / 1024**2,
        }
        self.memory_snapshots.append(snapshot)

    def start_pytorch_profiler(self, wait: int = 1, warmup: int = 1, active: int = 3):
        """Create a ``torch.profiler.profile`` for operator-level analysis.

        Returns a ``nullcontext`` when disabled. Otherwise returns the
        (not yet entered) profiler, which is also stored on
        ``self.pytorch_profiler`` so ``step_profiler`` can advance its
        schedule. Finished traces go to ``_trace_handler``.
        """
        if not self.enabled:
            return nullcontext()
        self.pytorch_profiler = torch.profiler.profile(
            activities=[
                torch.profiler.ProfilerActivity.CPU,
                torch.profiler.ProfilerActivity.CUDA,
            ],
            schedule=torch.profiler.schedule(
                wait=wait, warmup=warmup, active=active, repeat=1
            ),
            on_trace_ready=self._trace_handler,
            record_shapes=self.profiler_config["record_shapes"],
            profile_memory=self.profiler_config["profile_memory"],
            with_stack=self.profiler_config["with_stack"],
            with_flops=self.profiler_config["with_flops"],
            with_modules=self.profiler_config["with_modules"],
        )
        return self.pytorch_profiler

    @staticmethod
    def _event_metric(evt, device_attr: str, cuda_attr: str):
        """Read a device metric from a profiler event across PyTorch versions.

        Newer PyTorch exposes ``device_*`` attributes and deprecates (or no
        longer provides) the ``cuda_*`` aliases; older releases only have
        ``cuda_*``. A missing or ``None`` value counts as 0.
        """
        for attr in (device_attr, cuda_attr):
            value = getattr(evt, attr, None)
            if value is not None:
                return value
        return 0

    def _trace_handler(self, prof):
        """Export a Chrome trace and accumulate per-operator statistics."""
        trace_path = os.path.join(
            self.output_dir,
            f"trace_iter_{self.current_iteration}.json"
        )
        prof.export_chrome_trace(trace_path)
        key_averages = prof.key_averages(
            group_by_input_shape=self.profiler_config["group_by_input_shape"]
        )
        for evt in key_averages:
            stats = self.operator_stats.setdefault(evt.key, {
                'count': 0,
                'cpu_time_total_us': 0,
                'cuda_time_total_us': 0,
                'self_cpu_time_total_us': 0,
                'self_cuda_time_total_us': 0,
                'cpu_memory_usage': 0,
                'cuda_memory_usage': 0,
                'flops': 0,
            })
            stats['count'] += evt.count
            stats['cpu_time_total_us'] += evt.cpu_time_total
            stats['cuda_time_total_us'] += self._event_metric(
                evt, 'device_time_total', 'cuda_time_total')
            stats['self_cpu_time_total_us'] += evt.self_cpu_time_total
            stats['self_cuda_time_total_us'] += self._event_metric(
                evt, 'self_device_time_total', 'self_cuda_time_total')
            if hasattr(evt, 'cpu_memory_usage'):
                stats['cpu_memory_usage'] += evt.cpu_memory_usage
            stats['cuda_memory_usage'] += self._event_metric(
                evt, 'device_memory_usage', 'cuda_memory_usage')
            if getattr(evt, 'flops', None):
                stats['flops'] += evt.flops

    def step_profiler(self):
        """Advance the PyTorch profiler schedule, if one was started."""
        if self.pytorch_profiler is not None:
            self.pytorch_profiler.step()

    def generate_report(self) -> str:
        """Build the human-readable profiling report.

        Sections: macro timing summary, GPU memory summary (if snapshots
        exist), and, if operator statistics exist, the top-30 operators by
        inclusive CUDA time plus a category breakdown by self CUDA time.
        """
        if not self.enabled:
            return "Profiling disabled."
        report_lines = []
        report_lines.append("=" * 80)
        report_lines.append("PERFORMANCE PROFILING REPORT")
        report_lines.append("=" * 80)
        report_lines.append("")
        # Macro-level timing summary
        report_lines.append("-" * 40)
        report_lines.append("MACRO-LEVEL TIMING SUMMARY")
        report_lines.append("-" * 40)
        report_lines.append(f"{'Section':<40} {'Count':>8} {'Total(ms)':>12} {'Avg(ms)':>12} {'CUDA Avg(ms)':>14}")
        report_lines.append("-" * 86)
        total_time = 0
        for name, times in sorted(self.macro_timings.items()):
            cuda_times = [ct for _, ct in self.cuda_events.get(name, [])]
            avg_time = np.mean(times)
            avg_cuda = np.mean(cuda_times) if cuda_times else 0
            total = sum(times)
            total_time += total
            report_lines.append(f"{name:<40} {len(times):>8} {total:>12.2f} {avg_time:>12.2f} {avg_cuda:>14.2f}")
        report_lines.append("-" * 86)
        # NOTE: sections may be nested inside one another, in which case this
        # TOTAL counts the inner time more than once.
        report_lines.append(f"{'TOTAL':<40} {'':<8} {total_time:>12.2f}")
        report_lines.append("")
        # Memory summary
        if self.memory_snapshots:
            report_lines.append("-" * 40)
            report_lines.append("GPU MEMORY SUMMARY")
            report_lines.append("-" * 40)
            max_alloc = max(s['max_allocated_mb'] for s in self.memory_snapshots)
            avg_alloc = np.mean([s['allocated_mb'] for s in self.memory_snapshots])
            report_lines.append(f"Peak allocated: {max_alloc:>10.2f} MB")
            report_lines.append(f"Average allocated: {avg_alloc:>10.2f} MB")
            report_lines.append("")
        if self.operator_stats:
            # Top operators by inclusive CUDA time
            report_lines.append("-" * 40)
            report_lines.append("TOP 30 OPERATORS BY CUDA TIME")
            report_lines.append("-" * 40)
            sorted_ops = sorted(
                self.operator_stats.items(),
                key=lambda x: x[1]['cuda_time_total_us'],
                reverse=True
            )[:30]
            report_lines.append(f"{'Operator':<50} {'Count':>8} {'CUDA(ms)':>12} {'CPU(ms)':>12} {'Self CUDA(ms)':>14}")
            report_lines.append("-" * 96)
            for op_name, stats in sorted_ops:
                # Truncate long operator names
                display_name = op_name[:47] + "..." if len(op_name) > 50 else op_name
                report_lines.append(
                    f"{display_name:<50} {stats['count']:>8} "
                    f"{stats['cuda_time_total_us']/1000:>12.2f} "
                    f"{stats['cpu_time_total_us']/1000:>12.2f} "
                    f"{stats['self_cuda_time_total_us']/1000:>14.2f}"
                )
            report_lines.append("")
            # Category breakdown. Uses SELF CUDA time: inclusive times of
            # nested ops (e.g. aten::linear -> aten::addmm, ProfilerStep
            # wrappers) overlap, so summing them would double-count kernels
            # and distort the percentages.
            report_lines.append("-" * 40)
            report_lines.append("OPERATOR CATEGORY BREAKDOWN")
            report_lines.append("-" * 40)
            category_times = {cat: 0.0 for cat in self._OP_CATEGORIES}
            category_times['Other'] = 0.0
            for op_name, stats in self.operator_stats.items():
                op_lower = op_name.lower()
                bucket = next(
                    (cat for cat, keywords in self._OP_CATEGORIES.items()
                     if any(kw in op_lower for kw in keywords)),
                    'Other',
                )
                category_times[bucket] += stats['self_cuda_time_total_us']
            total_op_time = sum(category_times.values())
            report_lines.append(f"{'Category':<30} {'Self CUDA(ms)':>15} {'Percentage':>12}")
            report_lines.append("-" * 57)
            for cat, time_us in sorted(category_times.items(), key=lambda x: -x[1]):
                pct = (time_us / total_op_time * 100) if total_op_time > 0 else 0
                report_lines.append(f"{cat:<30} {time_us/1000:>15.2f} {pct:>11.1f}%")
            report_lines.append("")
        return "\n".join(report_lines)

    def save_results(self):
        """Write the text report and the raw JSON data to ``output_dir``,
        then print the report."""
        if not self.enabled:
            return
        report = self.generate_report()
        report_path = os.path.join(self.output_dir, "profiling_report.txt")
        with open(report_path, 'w') as f:
            f.write(report)
        print(f">>> Profiling report saved to: {report_path}")
        data = {
            'macro_timings': {
                name: {
                    'times': times,
                    'cuda_times': [ct for _, ct in self.cuda_events.get(name, [])]
                }
                for name, times in self.macro_timings.items()
            },
            'memory_snapshots': self.memory_snapshots,
            'operator_stats': self.operator_stats,
        }
        json_path = os.path.join(self.output_dir, "profiling_data.json")
        with open(json_path, 'w') as f:
            json.dump(data, f, indent=2)
        print(f">>> Detailed profiling data saved to: {json_path}")
        print("\n" + report)
# Module-wide profiler singleton; init_profiler() replaces it.
_profiler: Optional[ProfilerManager] = None


def get_profiler() -> ProfilerManager:
    """Return the module-wide profiler.

    If none has been installed yet, a disabled ``ProfilerManager`` is created
    lazily and stored, so callers never have to check for ``None``.
    """
    global _profiler
    current = _profiler
    if current is None:
        current = ProfilerManager(enabled=False)
        _profiler = current
    return current
def init_profiler(enabled: bool, output_dir: str, profile_detail: str) -> ProfilerManager:
    """Create a new ``ProfilerManager``, install it as the module-wide
    profiler and return it.

    Args:
        enabled: Whether profiling records anything.
        output_dir: Directory for profiling outputs.
        profile_detail: ``"light"`` or ``"full"``.

    Returns:
        ProfilerManager: The newly installed profiler.
    """
    global _profiler
    manager = ProfilerManager(
        enabled=enabled,
        output_dir=output_dir,
        profile_detail=profile_detail,
    )
    _profiler = manager
    return manager
# ========== Original Functions ==========
def get_device_from_parameters(module: nn.Module) -> torch.device:
    """Return the device of ``module``'s first parameter.

    Args:
        module (nn.Module): Module whose device is to be inferred.

    Returns:
        torch.device: Device holding the first parameter.

    Raises:
        StopIteration: If the module has no parameters.
    """
    first_param = next(module.parameters())
    return first_param.device
def write_video(video_path: str, stacked_frames: list, fps: int) -> None:
    """Write ``stacked_frames`` to ``video_path`` with ``imageio.mimsave``.

    Args:
        video_path (str): Output path for the video.
        stacked_frames (list): Frames to encode, in order.
        fps (int): Frames per second of the output video.
    """
    # Silence only the "pkg_resources is deprecated as an API"
    # DeprecationWarning for the duration of the write (presumably emitted
    # while imageio loads a writer plugin); other warnings still surface.
    with warnings.catch_warnings():
        warnings.filterwarnings(
            action="ignore",
            message="pkg_resources is deprecated as an API",
            category=DeprecationWarning,
        )
        imageio.mimsave(video_path, stacked_frames, fps=fps)
def get_filelist(data_dir: str, postfixes: list[str]) -> list[str]:
    """List files directly inside ``data_dir`` whose extension is in
    ``postfixes``, sorted lexicographically.

    Args:
        data_dir (str): Directory to search (non-recursive).
        postfixes (list[str]): Extensions to match, without the leading dot.

    Returns:
        list[str]: Sorted matching paths. A file matched by several
        postfixes appears once per match.
    """
    matches = [
        path
        for postfix in postfixes
        for path in glob.glob(os.path.join(data_dir, f"*.{postfix}"))
    ]
    matches.sort()
    return matches
def _load_state_dict(model: nn.Module,
state_dict: Mapping[str, torch.Tensor],
strict: bool = True,
assign: bool = False) -> None:
if assign:
try:
model.load_state_dict(state_dict, strict=strict, assign=True)
return
except TypeError:
warnings.warn(
"load_state_dict(assign=True) not supported; "
"falling back to copy load.")
model.load_state_dict(state_dict, strict=strict)
def load_model_checkpoint(model: nn.Module,
                          ckpt: str,
                          assign: bool | None = None) -> nn.Module:
    """Load model weights from a checkpoint file (strict key matching).

    Supported layouts of the loaded object:
      * a mapping with ``"state_dict"``: loaded directly; if that fails and
        the keys contain the legacy name ``framestride_embed``, they are
        renamed to ``fps_embedding`` and the load is retried;
      * a mapping with ``"module"``: the first 16 characters of every key are
        stripped (presumably a ``_forward_module.`` wrapper prefix — TODO
        confirm against the producing trainer);
      * anything else is treated as a bare state dict.

    Args:
        model (nn.Module): Model instance to load into.
        ckpt (str): Path to the checkpoint file.
        assign (bool | None): Whether to keep checkpoint tensor dtypes via
            ``load_state_dict(assign=True)``. If None, enabled automatically
            when the checkpoint mapping has a ``"precision_metadata"`` entry.

    Returns:
        nn.Module: ``model`` with weights loaded.

    Raises:
        RuntimeError: If the weights do not match ``model`` (also after the
            legacy-key rename).
    """
    ckpt_data = torch.load(ckpt, map_location="cpu")
    is_mapping = isinstance(ckpt_data, Mapping)
    if assign is not None:
        use_assign = assign
    else:
        # Casted checkpoints carry precision metadata; assigning keeps their
        # per-tensor dtypes instead of copying into the model's dtypes.
        use_assign = is_mapping and "precision_metadata" in ckpt_data
    if is_mapping and "state_dict" in ckpt_data:
        state_dict = ckpt_data["state_dict"]
        try:
            _load_state_dict(model, state_dict, strict=True, assign=use_assign)
        except RuntimeError:
            # Strict key/shape mismatches raise RuntimeError. Retrying only
            # makes sense when legacy keys can be renamed; otherwise it would
            # be an identical (and costly) second load failing the same way.
            if not any("framestride_embed" in k for k in state_dict):
                raise
            new_pl_sd = OrderedDict(state_dict)
            for k in list(new_pl_sd.keys()):
                if "framestride_embed" in k:
                    new_key = k.replace("framestride_embed", "fps_embedding")
                    new_pl_sd[new_key] = new_pl_sd.pop(k)
            _load_state_dict(model,
                             new_pl_sd,
                             strict=True,
                             assign=use_assign)
    elif is_mapping and "module" in ckpt_data:
        new_pl_sd = OrderedDict(
            (key[16:], value) for key, value in ckpt_data['module'].items())
        _load_state_dict(model, new_pl_sd, strict=True, assign=use_assign)
    else:
        _load_state_dict(model,
                         ckpt_data,
                         strict=True,
                         assign=use_assign)
    print('>>> model checkpoint loaded.')
    return model
def maybe_cast_module(module: nn.Module | None,
                      dtype: torch.dtype,
                      label: str,
                      profiler: Optional[ProfilerManager] = None,
                      profile_name: Optional[str] = None) -> None:
    """Cast ``module`` in place with ``module.to(dtype=dtype)`` if needed.

    Only the FIRST parameter's dtype is inspected: if it already equals
    ``dtype`` the cast is skipped, so a mixed-dtype module whose first
    parameter matches is left untouched. ``None`` modules and modules without
    parameters are skipped. When both ``profiler`` and ``profile_name`` are
    given, the cast is timed as a profiler section.
    """
    if module is None:
        return
    first_param = next(module.parameters(), None)
    if first_param is None:
        print(f">>> {label} has no parameters; skip cast")
        return
    if first_param.dtype == dtype:
        print(f">>> {label} already {dtype}; skip cast")
        return
    should_profile = profiler is not None and profile_name
    ctx = profiler.profile_section(profile_name) if should_profile else nullcontext()
    with ctx:
        module.to(dtype=dtype)
    print(f">>> {label} cast to {dtype}")
def save_casted_checkpoint(model: nn.Module,
                           save_path: str,
                           metadata: Optional[Dict[str, Any]] = None) -> None:
    """Save ``model``'s current (possibly dtype-cast) weights to ``save_path``.

    The file holds ``{"state_dict": ...}`` with every tensor detached and
    moved to CPU (tensor dtypes unchanged). When ``metadata`` is given (even
    an empty mapping) it is stored under ``"precision_metadata"``;
    ``load_model_checkpoint`` uses the presence of that key to load with
    ``assign=True`` and keep the saved dtypes. Does nothing when
    ``save_path`` is empty; parent directories are created as needed.

    Args:
        model (nn.Module): Model whose ``state_dict()`` is saved.
        save_path (str): Destination file path.
        metadata (Optional[Dict[str, Any]]): Precision information to embed.
    """
    if not save_path:
        return
    save_dir = os.path.dirname(save_path)
    if save_dir:
        os.makedirs(save_dir, exist_ok=True)
    cpu_state = {
        key: value.detach().to("cpu") if isinstance(value, torch.Tensor) else value
        for key, value in model.state_dict().items()
    }
    payload: Dict[str, Any] = {"state_dict": cpu_state}
    # `is not None` rather than truthiness: an explicitly empty metadata dict
    # must still mark the checkpoint as casted for load_model_checkpoint.
    if metadata is not None:
        payload["precision_metadata"] = metadata
    torch.save(payload, save_path)
    print(f">>> Saved casted checkpoint to {save_path}")
def _module_param_dtype(module: nn.Module | None) -> str:
if module is None:
return "None"
dtype_counts: Dict[str, int] = {}
for param in module.parameters():
dtype_key = str(param.dtype)
dtype_counts[dtype_key] = dtype_counts.get(dtype_key, 0) + param.numel()
if not dtype_counts:
return "no_params"
if len(dtype_counts) == 1:
return next(iter(dtype_counts))
total = sum(dtype_counts.values())
parts = []
for dtype_key in sorted(dtype_counts.keys()):
ratio = dtype_counts[dtype_key] / total
parts.append(f"{dtype_key}={ratio:.1%}")
return f"mixed({', '.join(parts)})"
def log_inference_precision(model: nn.Module) -> None:
    """Print the device and parameter dtypes of ``model`` and its main
    submodules, plus the current CUDA autocast settings.

    Submodules reported (when present as attributes): ``model``,
    ``first_stage_model``, ``cond_stage_model``, ``embedder`` and
    ``image_proj_model``. The device is taken from the first parameter
    (``"unknown"`` if there are none).
    """
    first_param = next(model.parameters(), None)
    device = "unknown" if first_param is None else str(first_param.device)
    model_dtype = _module_param_dtype(model)
    print(f">>> inference precision: model={model_dtype}, device={device}")
    for attr in ("model", "first_stage_model", "cond_stage_model",
                 "embedder", "image_proj_model"):
        if hasattr(model, attr):
            submodule = getattr(model, attr)
            print(f">>> {attr} param dtype: {_module_param_dtype(submodule)}")
    # torch>=2.4 deprecates get_autocast_gpu_dtype() in favour of
    # get_autocast_dtype("cuda"); prefer the new API when it exists.
    if hasattr(torch, "get_autocast_dtype"):
        autocast_dtype = torch.get_autocast_dtype("cuda")
    else:
        autocast_dtype = torch.get_autocast_gpu_dtype()
    print(
        ">>> autocast gpu dtype default: "
        f"{autocast_dtype} "
        f"(enabled={torch.is_autocast_enabled()})")
def is_inferenced(save_dir: str, filename: str) -> bool:
    """Report whether ``filename`` already has a saved sample video.

    Looks for ``<save_dir>/samples_separate/<stem>_sample0.mp4`` where
    ``<stem>`` is ``filename`` minus its last four characters (i.e. assumes a
    three-letter extension such as ``.png``).

    Args:
        save_dir (str): Directory where results are saved.
        filename (str): Input file name to check.

    Returns:
        bool: True if that video path exists.
    """
    stem = filename[:-4]
    expected = os.path.join(save_dir, "samples_separate", stem + "_sample0.mp4")
    return os.path.exists(expected)
def save_results(video: Tensor, filename: str, fps: int = 8) -> None:
    """Encode a batch of videos side by side into one H.264 file.

    Each time step becomes one frame: the B videos' frames at that step are
    tiled into a single row grid (no padding). Values are clamped to
    [-1, 1] and mapped to uint8 via ``(x + 1) / 2 * 255`` (truncating).

    Args:
        video (Tensor): Tensor of shape (B, C, T, H, W) in [-1, 1].
        filename (str): Output file path.
        fps (int, optional): Frames per second. Defaults to 8.
    """
    clamped = torch.clamp(video.detach().cpu().float(), -1., 1.)
    num_videos = clamped.shape[0]
    # (B, C, T, H, W) -> (T, B, C, H, W): iterate over time steps.
    per_step = clamped.permute(2, 0, 1, 3, 4)
    grids = torch.stack(
        [torchvision.utils.make_grid(step_batch, nrow=int(num_videos), padding=0)
         for step_batch in per_step],
        dim=0)
    grids = (grids + 1.0) / 2.0
    # (T, C, H, W) float -> (T, H, W, C) uint8, as write_video expects.
    frames_uint8 = (grids * 255).to(torch.uint8).permute(0, 2, 3, 1)
    torchvision.io.write_video(filename,
                               frames_uint8,
                               fps=fps,
                               video_codec='h264',
                               options={'crf': '10'})
def get_init_frame_path(data_dir: str, sample: dict) -> str:
    """Build the path of the initial-frame PNG for ``sample``.

    The result is ``<data_dir>/images/<sample['data_dir']>/<videoid>.png``.

    Args:
        data_dir (str): Dataset root directory.
        sample (dict): Metadata with ``'data_dir'`` and ``'videoid'`` keys.

    Returns:
        str: Full path to the image file.
    """
    image_name = str(sample['videoid']) + '.png'
    return os.path.join(data_dir, 'images', sample['data_dir'], image_name)
def get_transition_path(data_dir: str, sample: dict) -> str:
    """Build the path of the HDF5 transition file for ``sample``.

    The result is ``<data_dir>/transitions/<sample['data_dir']>/<videoid>.h5``.

    Args:
        data_dir (str): Dataset root directory.
        sample (dict): Metadata with ``'data_dir'`` and ``'videoid'`` keys.

    Returns:
        str: Full path to the HDF5 transition file.
    """
    h5_name = str(sample['videoid']) + '.h5'
    return os.path.join(data_dir, 'transitions', sample['data_dir'], h5_name)
def prepare_init_input(
        start_idx: int,
        init_frame_path: str,
        transition_dict: dict[str, torch.Tensor],
        frame_stride: int,
        wma_data,
        video_length: int = 16,
        n_obs_steps: int = 2) -> tuple[dict[str, Tensor], int, int]:
    """Build the initial model input (image, states, actions) for one clip.

    * The image at ``init_frame_path`` is loaded as RGB and reshaped
      (H, W, 3) -> (3, 1, H, W) float, optionally passed through
      ``wma_data.spatial_transform``, then scaled with ``(x / 255 - 0.5) * 2``.
    * States: the ``n_obs_steps`` rows of ``transition_dict['observation.state']``
      ending at ``start_idx``; when ``start_idx < n_obs_steps - 1`` the first
      row is repeated on the left to fill the history.
    * Actions: ``video_length`` rows of ``transition_dict['action']`` at
      ``start_idx + frame_stride * i``.
    States/actions are normalized with ``wma_data.normalizer`` and packed with
    ``wma_data.get_uni_vec`` using ``'action_type'``/``'state_type'``.

    Args:
        start_idx (int): Starting step index for the current clip.
        init_frame_path (str): Path of the initial frame image.
        transition_dict (dict[str, Tensor]): Holds 'action',
            'observation.state' (both indexed as [step, dim]), 'action_type'
            and 'state_type'.
        frame_stride (int): Step stride between sampled actions.
        wma_data: Object providing ``normalizer``, ``get_uni_vec`` and
            ``spatial_transform``.
        video_length (int, optional): Number of actions to sample. Default 16.
        n_obs_steps (int, optional): Number of state history steps. Default 2.

    Returns:
        tuple: ``(data, ori_state_dim, ori_action_dim)`` where ``data`` has
        'observation.image' plus the entries produced by
        ``wma_data.get_uni_vec``, and the two ints are the last-dimension
        sizes of the raw (pre-normalization) state and action tensors.
    """
    indices = [start_idx + frame_stride * i for i in range(video_length)]
    # Context manager closes the file promptly; convert() returns a new,
    # fully loaded image, so it stays usable after the file is closed.
    with Image.open(init_frame_path) as raw_frame:
        rgb_frame = raw_frame.convert('RGB')
    # (H, W, 3) -> (1, H, W, 3) -> (3, 1, H, W)
    init_frame = torch.tensor(np.array(rgb_frame)).unsqueeze(0).permute(
        3, 0, 1, 2).float()
    all_states = transition_dict['observation.state']
    if start_idx < n_obs_steps - 1:
        # Not enough history yet: left-pad by repeating the first state.
        states = all_states[list(range(0, start_idx + 1)), :]
        num_padding = n_obs_steps - 1 - start_idx
        padding = states[0:1, :].repeat(num_padding, 1)
        states = torch.cat((padding, states), dim=0)
    else:
        state_indices = list(range(start_idx - n_obs_steps + 1, start_idx + 1))
        states = all_states[state_indices, :]
    actions = transition_dict['action'][indices, :]
    ori_state_dim = states.shape[-1]
    ori_action_dim = actions.shape[-1]
    frames_action_state_dict = {
        'action': actions,
        'observation.state': states,
    }
    frames_action_state_dict = wma_data.normalizer(frames_action_state_dict)
    frames_action_state_dict = wma_data.get_uni_vec(
        frames_action_state_dict,
        transition_dict['action_type'],
        transition_dict['state_type'],
    )
    if wma_data.spatial_transform is not None:
        init_frame = wma_data.spatial_transform(init_frame)
    init_frame = (init_frame / 255 - 0.5) * 2
    data = {
        'observation.image': init_frame,
    }
    data.update(frames_action_state_dict)
    return data, ori_state_dim, ori_action_dim
def get_latent_z(model, videos: Tensor) -> Tensor:
    """Encode a video batch frame by frame with the model's first-stage encoder.

    Frames are flattened to ``(B*T, C, H, W)`` for ``model.encode_first_stage``
    and the latents are reshaped back to ``(B, C', T, H', W')``. Encoding runs
    under bf16 autocast when ``model.vae_bf16`` is set and the model is on
    CUDA. The work is timed as the ``"get_latent_z/encode"`` profiler section.

    Args:
        model: The world model.
        videos (Tensor): Input videos of shape [B, C, T, H, W].

    Returns:
        Tensor: Latent video tensor of shape [B, C, T, H, W].
    """
    with get_profiler().profile_section("get_latent_z/encode"):
        batch, _, frames, _, _ = videos.shape
        flat_frames = rearrange(videos, 'b c t h w -> (b t) c h w')
        use_bf16 = getattr(model, "vae_bf16", False) and model.device.type == "cuda"
        ctx = torch.autocast("cuda", dtype=torch.bfloat16) if use_bf16 else nullcontext()
        with ctx:
            latents = model.encode_first_stage(flat_frames)
        latents = rearrange(latents, '(b t) c h w -> b c t h w', b=batch, t=frames)
    return latents
def preprocess_observation(
        model, observations: dict[str, np.ndarray]) -> dict[str, Tensor]:
    """Convert an environment observation into LeRobot-style model inputs.

    Args:
        model: Object providing ``device`` and ``normalize_inputs`` (called
            with a dict holding ``'observation.state'``).
        observations: Batched observation with ``"pixels"`` — a single
            ``(B, H, W, C)`` uint8 array or a dict of such arrays keyed by
            camera name — and ``"agent_pos"`` (state array).

    Returns:
        Dict with ``observation.images.<camera>`` tensors (a single array is
        stored under ``observation.images.top``) of shape ``(B, C, H, W)``
        and dtype float32, plus ``observation.state`` normalized by
        ``model.normalize_inputs`` on ``model.device``. NOTE: image values
        are NOT rescaled and stay in [0, 255].

    Raises:
        AssertionError: If an image is not channel-last or not uint8.
    """
    return_observations = {}
    pixels = observations["pixels"]
    if isinstance(pixels, dict):
        imgs = {f"observation.images.{key}": img for key, img in pixels.items()}
    else:
        imgs = {"observation.images.top": pixels}
    for imgkey, img in imgs.items():
        img = torch.from_numpy(img)
        # Sanity check that images are channel last: (B, H, W, C), C small.
        _, h, w, c = img.shape
        assert c < h and c < w, f"expect channel last images, but instead {img.shape}"
        # Sanity check that images are uint8
        assert img.dtype == torch.uint8, f"expect torch.uint8, but instead {img.dtype=}"
        # (B, H, W, C) -> (B, C, H, W), same as einops "b h w c -> b c h w".
        img = img.permute(0, 3, 1, 2).contiguous()
        # Cast to float32 only; values remain in [0, 255].
        img = img.type(torch.float32)
        return_observations[imgkey] = img
    state = torch.from_numpy(observations["agent_pos"]).float()
    return_observations["observation.state"] = model.normalize_inputs({
        'observation.state': state.to(model.device)
    })['observation.state']
    return return_observations
def _move_to_device(batch: Mapping[str, Any],
device: torch.device) -> dict[str, Any]:
moved = {}
for key, value in batch.items():
if isinstance(value, torch.Tensor) and value.device != device:
moved[key] = value.to(device, non_blocking=True)
else:
moved[key] = value
return moved
def image_guided_synthesis_sim_mode(
model: torch.nn.Module,
prompts: list[str],
observation: dict,
noise_shape: tuple[int, int, int, int, int],
action_cond_step: int = 16,
n_samples: int = 1,
ddim_steps: int = 50,
ddim_eta: float = 1.0,
unconditional_guidance_scale: float = 1.0,
fs: int | None = None,
text_input: bool = True,
timestep_spacing: str = 'uniform',
guidance_rescale: float = 0.0,
sim_mode: bool = True,
diffusion_autocast_dtype: Optional[torch.dtype] = None,
**kwargs) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Performs image-guided video generation in a simulation-style mode with multimodal conditioning (image, state, action, text).
Args:
model (torch.nn.Module): The diffusion world model. This function accesses `embedder`, `image_proj_model`, `state_projector`, `action_projector`, `agent_state_pos_emb`, `agent_action_pos_emb`, `_projector_forward`, `get_learned_conditioning`, `decode_first_stage`, `n_obs_steps_acting`, `model.conditioning_key` and `device` on it, and caches a DDIMSampler as `model._ddim_sampler`.
prompts (list[str]): Text prompts; replaced by empty strings when `text_input` is False.
observation (dict): Observed inputs:
- 'observation.images.top': Tensor indexed as [B, C, O, H, W] (observation axis at dim 2: the code permutes dims 1 and 2 to obtain 'b o c h w' and slices the last `n_obs_steps_acting` entries along dim 2).
- 'observation.state': State history; the last `n_obs_steps_acting` entries along dim 1 are used (presumably [B, O, D]).
- 'action': Action sequence passed to `action_projector` (presumably [B, T, D]).
noise_shape (tuple[int, int, int, int, int]): Latent shape (B, C, T, H, W); B is the sampling batch size and T is used to repeat the image latent along time.
action_cond_step (int): Not used in this function. Default is 16.
n_samples (int): Not used in this function. Default is 1.
ddim_steps (int): Number of DDIM sampling steps. Default is 50.
ddim_eta (float): DDIM eta parameter controlling the stochasticity. Default is 1.0.
unconditional_guidance_scale (float): Passed to the DDIM sampler; note no unconditional conditioning is built here (`uc` is None).
fs (int | None): Frame-stride value broadcast to a LongTensor of length B. NOTE(review): the default None is not handled — `torch.tensor([None] * B, dtype=torch.long)` raises, so callers must pass an int.
text_input (bool): Whether to use the text prompts. If False, empty strings are used. Default is True.
timestep_spacing (str): Timestep spacing passed to the DDIM sampler.
guidance_rescale (float): Guidance rescale factor passed to the DDIM sampler.
sim_mode (bool): If False, the action embedding is zeroed; the flag is also forwarded in `c_crossattn_action`.
diffusion_autocast_dtype (Optional[torch.dtype]): If set and the model is on CUDA, DDIM sampling runs under autocast with this dtype.
**kwargs: Additional arguments passed to the DDIM sampler (`unconditional_conditioning_img_nonetext` is forced to None).
Returns:
batch_variants (torch.Tensor): Decoded pixel-space video frames from `model.decode_first_stage`.
actions (torch.Tensor): Action predictions returned by the DDIM sampler.
states (torch.Tensor): State predictions returned by the DDIM sampler.
"""
profiler = get_profiler()
# NOTE(review): b and t are unpacked but not used below.
b, _, t, _, _ = noise_shape
# Reuse a DDIMSampler cached on the model instead of building one per call.
ddim_sampler = getattr(model, "_ddim_sampler", None)
if ddim_sampler is None:
ddim_sampler = DDIMSampler(model)
model._ddim_sampler = ddim_sampler
batch_size = noise_shape[0]
# NOTE(review): raises when fs is None (the default) — see docstring.
fs = torch.tensor([fs] * batch_size, dtype=torch.long, device=model.device)
with profiler.profile_section("synthesis/conditioning_prep"):
img = observation['observation.images.top'].permute(0, 2, 1, 3, 4)
# Only the last image of the flattened (b o) axis conditions the image
# embedder — presumably assumes B == 1; TODO confirm for larger batches.
cond_img = rearrange(img, 'b o c h w -> (b o) c h w')[-1:]
# bf16 image-encoder path (CUDA only). In "autocast" mode the embedder's own
# preprocess runs in fp32 with autocast disabled; the vision transformer
# then runs under bf16 autocast with `preprocess` temporarily replaced by an
# identity so the already-preprocessed input is not processed twice. The
# original attribute is restored in `finally` (by assignment, so a bound
# method is left as an instance attribute afterwards).
if getattr(model, "encoder_bf16", False) and model.device.type == "cuda":
if getattr(model, "encoder_mode", "autocast") == "autocast":
preprocess_ctx = torch.autocast("cuda", enabled=False)
with preprocess_ctx:
cond_img_fp32 = cond_img.float()
if hasattr(model.embedder, "preprocess"):
preprocessed = model.embedder.preprocess(cond_img_fp32)
else:
preprocessed = cond_img_fp32
if hasattr(model.embedder,
"encode_with_vision_transformer") and hasattr(
model.embedder, "preprocess"):
original_preprocess = model.embedder.preprocess
try:
model.embedder.preprocess = lambda x: x
with torch.autocast("cuda", dtype=torch.bfloat16):
cond_img_emb = model.embedder.encode_with_vision_transformer(
preprocessed)
finally:
model.embedder.preprocess = original_preprocess
else:
with torch.autocast("cuda", dtype=torch.bfloat16):
cond_img_emb = model.embedder(preprocessed)
else:
with torch.autocast("cuda", dtype=torch.bfloat16):
cond_img_emb = model.embedder(cond_img)
else:
cond_img_emb = model.embedder(cond_img)
# Concat conditioning: latent of the last observed frame, repeated along
# the time axis to noise_shape[2] frames.
# NOTE(review): `cond` is only created in this 'hybrid' branch; any other
# conditioning_key reaches `cond["c_crossattn"]` below with `cond` undefined.
if model.model.conditioning_key == 'hybrid':
z = get_latent_z(model, img.permute(0, 2, 1, 3, 4))
img_cat_cond = z[:, :, -1:, :, :]
img_cat_cond = repeat(img_cat_cond,
'b c t h w -> b c (repeat t) h w',
repeat=noise_shape[2])
cond = {"c_concat": [img_cat_cond]}
if not text_input:
prompts = [""] * batch_size
# Text encoding runs under bf16 autocast when encoder_bf16 is set on CUDA;
# the text embedding's dtype then becomes the target dtype for the
# image/state/action projector outputs and positional embeddings.
encoder_ctx = nullcontext()
if getattr(model, "encoder_bf16", False) and model.device.type == "cuda":
encoder_ctx = torch.autocast("cuda", dtype=torch.bfloat16)
with encoder_ctx:
cond_ins_emb = model.get_learned_conditioning(prompts)
target_dtype = cond_ins_emb.dtype
cond_img_emb = model._projector_forward(model.image_proj_model,
cond_img_emb, target_dtype)
cond_state_emb = model._projector_forward(
model.state_projector, observation['observation.state'],
target_dtype)
cond_state_emb = cond_state_emb + model.agent_state_pos_emb.to(
dtype=target_dtype)
cond_action_emb = model._projector_forward(
model.action_projector, observation['action'], target_dtype)
cond_action_emb = cond_action_emb + model.agent_action_pos_emb.to(
dtype=target_dtype)
# Outside simulation mode the action tokens carry no information.
if not sim_mode:
cond_action_emb = torch.zeros_like(cond_action_emb)
# Cross-attention context: state, action, text and image tokens
# concatenated along dim 1.
cond["c_crossattn"] = [
torch.cat(
[cond_state_emb, cond_action_emb, cond_ins_emb, cond_img_emb],
dim=1)
]
# Inputs for the action head: the most recent n_obs_steps_acting images
# (dim 2) and states (dim 1), the sim_mode flag, and a literal False.
cond["c_crossattn_action"] = [
observation['observation.images.top'][:, :,
-model.n_obs_steps_acting:],
observation['observation.state'][:, -model.n_obs_steps_acting:],
sim_mode,
False,
]
uc = None
kwargs.update({"unconditional_conditioning_img_nonetext": None})
cond_mask = None
cond_z0 = None
# NOTE(review): ddim_sampler is never None here, so this is always true.
if ddim_sampler is not None:
with profiler.profile_section("synthesis/ddim_sampling"):
autocast_ctx = nullcontext()
if diffusion_autocast_dtype is not None and model.device.type == "cuda":
autocast_ctx = torch.autocast("cuda", dtype=diffusion_autocast_dtype)
with autocast_ctx:
samples, actions, states, intermedia = ddim_sampler.sample(
S=ddim_steps,
conditioning=cond,
batch_size=batch_size,
shape=noise_shape[1:],
verbose=False,
unconditional_guidance_scale=unconditional_guidance_scale,
unconditional_conditioning=uc,
eta=ddim_eta,
cfg_img=None,
mask=cond_mask,
x0=cond_z0,
fs=fs,
timestep_spacing=timestep_spacing,
guidance_rescale=guidance_rescale,
**kwargs)
# Reconstruct from latent to pixel space
# With vae_bf16, latents are cast to bf16 and decoded under bf16 autocast
# (on CUDA); otherwise they are decoded as float32.
with profiler.profile_section("synthesis/decode_first_stage"):
if getattr(model, "vae_bf16", False):
if samples.dtype != torch.bfloat16:
samples = samples.to(dtype=torch.bfloat16)
vae_ctx = nullcontext()
if model.device.type == "cuda":
vae_ctx = torch.autocast("cuda", dtype=torch.bfloat16)
with vae_ctx:
batch_images = model.decode_first_stage(samples)
else:
if samples.dtype != torch.float32:
samples = samples.float()
batch_images = model.decode_first_stage(samples)
batch_variants = batch_images
return batch_variants, actions, states
def _apply_precision_settings(model, args: argparse.Namespace,
                              profiler) -> Optional[torch.dtype]:
    """Cast model sub-modules to the precisions requested on the command line.

    Covers the diffusion backbone (``--diffusion_dtype``), the VAE
    (``--vae_dtype``), the condition encoders (``--encoder_mode``) and the
    image/state/action projectors (``--projector_mode``). The selected modes
    are also recorded on ``model`` (``vae_bf16``, ``encoder_bf16``,
    ``encoder_mode``, ``projector_bf16``, ``projector_mode``).

    Args:
        model: The loaded world model (already moved to its device).
        args (argparse.Namespace): Parsed command-line arguments.
        profiler: Profiler passed through to ``maybe_cast_module``.

    Returns:
        ``torch.bfloat16`` if the diffusion backbone was cast to bf16 (the
        dtype used to autocast DDIM sampling), otherwise ``None``.
    """
    diffusion_autocast_dtype = None
    if args.diffusion_dtype == "bf16":
        maybe_cast_module(
            model.model,
            torch.bfloat16,
            "diffusion backbone",
            profiler=profiler,
            profile_name="model_loading/diffusion_bf16",
        )
        diffusion_autocast_dtype = torch.bfloat16
        print(">>> diffusion backbone set to bfloat16")

    if getattr(model, "first_stage_model", None) is not None:
        vae_weight_dtype = (torch.bfloat16
                            if args.vae_dtype == "bf16" else torch.float32)
        maybe_cast_module(
            model.first_stage_model,
            vae_weight_dtype,
            "VAE",
            profiler=profiler,
            profile_name="model_loading/vae_cast",
        )
        model.vae_bf16 = args.vae_dtype == "bf16"
        print(f">>> VAE dtype set to {args.vae_dtype}")

    # "autocast" keeps fp32 weights (bf16 only via autocast in forward);
    # only "bf16_full" actually casts the weights.
    encoder_mode = args.encoder_mode
    encoder_weight_dtype = (torch.bfloat16
                            if encoder_mode == "bf16_full" else torch.float32)
    for attr, profile_name in (
        ("cond_stage_model", "model_loading/encoder_cond_cast"),
        ("embedder", "model_loading/encoder_embedder_cast"),
    ):
        submodule = getattr(model, attr, None)
        if submodule is not None:
            maybe_cast_module(
                submodule,
                encoder_weight_dtype,
                attr,
                profiler=profiler,
                profile_name=profile_name,
            )
    model.encoder_bf16 = encoder_mode in ("autocast", "bf16_full")
    model.encoder_mode = encoder_mode
    print(
        f">>> encoder mode set to {encoder_mode} (weights={encoder_weight_dtype})"
    )

    projector_mode = args.projector_mode
    projector_weight_dtype = (torch.bfloat16
                              if projector_mode == "bf16_full" else torch.float32)
    for attr, profile_name in (
        ("image_proj_model", "model_loading/projector_image_cast"),
        ("state_projector", "model_loading/projector_state_cast"),
        ("action_projector", "model_loading/projector_action_cast"),
    ):
        submodule = getattr(model, attr, None)
        if submodule is not None:
            maybe_cast_module(
                submodule,
                projector_weight_dtype,
                attr,
                profiler=profiler,
                profile_name=profile_name,
            )
    if hasattr(model, "projector_bf16"):
        model.projector_bf16 = projector_mode in ("autocast", "bf16_full")
    model.projector_mode = projector_mode
    print(
        f">>> projector mode set to {projector_mode} (weights={projector_weight_dtype})"
    )
    return diffusion_autocast_dtype


def _load_transition_dict(transition_path: str) -> Dict[str, Any]:
    """Read an HDF5 transition file into a dict.

    Every dataset becomes a ``torch.tensor``; every file-level attribute is
    copied as-is. An attribute with the same name as a dataset overwrites it.
    """
    with h5py.File(transition_path, 'r') as h5f:
        transition_dict = {
            key: torch.tensor(h5f[key][()])
            for key in h5f.keys()
        }
        transition_dict.update(
            {key: h5f.attrs[key]
             for key in h5f.attrs.keys()})
    return transition_dict


def _stack_observation_queues(cond_obs_queues: Dict[str, deque],
                              device) -> Dict[str, torch.Tensor]:
    """Stack the queued observations along dim 1 and move them to ``device``.

    The image stack is permuted with ``(0, 2, 1, 3, 4)`` so the queue axis
    ends up at dim 2; state and action stacks keep the queue axis at dim 1.
    """
    observation = {
        'observation.images.top':
        torch.stack(list(cond_obs_queues['observation.images.top']),
                    dim=1).permute(0, 2, 1, 3, 4),
        'observation.state':
        torch.stack(list(cond_obs_queues['observation.state']), dim=1),
        'action':
        torch.stack(list(cond_obs_queues['action']), dim=1),
    }
    return _move_to_device(observation, device)


def run_inference(args: argparse.Namespace, gpu_num: int, gpu_no: int) -> None:
    """
    Run the policy / world-model interaction loop for every prompt in the CSV.

    For each CSV row and each ``--frame_stride`` value, the model first
    generates actions from the current observation queues ("dm" videos),
    then rolls the world model forward with those actions ("wm" videos) and
    feeds the first ``--exe_steps`` predicted frames/states back into the
    queues. Per-iteration and full-rollout videos are written under
    ``{savedir}/inference`` and logged to TensorBoard in
    ``{savedir}/tensorboard``.

    If ``--export_casted_ckpt`` is set, the precision-cast model is saved
    first; ``--export_only`` returns right after model loading/export.

    Args:
        args (argparse.Namespace): Parsed command-line arguments.
        gpu_num (int): Number of GPUs (not used by this function).
        gpu_no (int): Index of the current GPU.
    Returns:
        None
    """
    profiler = get_profiler()
    # Create inference and tensorboard dirs
    os.makedirs(os.path.join(args.savedir, 'inference'), exist_ok=True)
    log_dir = os.path.join(args.savedir, "tensorboard")
    os.makedirs(log_dir, exist_ok=True)
    # Load prompt
    csv_path = os.path.join(args.prompt_dir, f"{args.dataset}.csv")
    df = pd.read_csv(csv_path)
    # Load config
    with profiler.profile_section("model_loading/config"):
        config = OmegaConf.load(args.config)
        config['model']['params']['wma_config']['params'][
            'use_checkpoint'] = False
        model = instantiate_from_config(config.model)
        model.perframe_ae = args.perframe_ae
    assert os.path.exists(args.ckpt_path), "Error: checkpoint Not Found!"
    with profiler.profile_section("model_loading/checkpoint"):
        model = load_model_checkpoint(model, args.ckpt_path)
    model.eval()
    print('>>> Load pre-trained model ...')
    # Build unnormalizer
    logging.info("***** Configing Data *****")
    with profiler.profile_section("data_loading"):
        data = instantiate_from_config(config.data)
        data.setup()
    print(">>> Dataset is successfully loaded ...")
    with profiler.profile_section("model_to_cuda"):
        model = model.cuda(gpu_no)
    device = get_device_from_parameters(model)
    diffusion_autocast_dtype = _apply_precision_settings(model, args, profiler)
    log_inference_precision(model)
    if args.export_casted_ckpt:
        metadata = {
            "diffusion_dtype": args.diffusion_dtype,
            "vae_dtype": args.vae_dtype,
            "encoder_mode": args.encoder_mode,
            "projector_mode": args.projector_mode,
            "perframe_ae": args.perframe_ae,
        }
        save_casted_checkpoint(model, args.export_casted_ckpt, metadata)
    if args.export_only:
        print(">>> export_only set; skipping inference.")
        return
    profiler.record_memory("after_model_load")
    # Run over data
    assert (args.height % 16 == 0) and (
        args.width % 16
        == 0), "Error: image size [h,w] should be multiples of 16!"
    assert args.bs == 1, "Current implementation only support [batch size = 1]!"
    # Get latent noise shape (the latent is 8x smaller than the image)
    h, w = args.height // 8, args.width // 8
    channels = model.model.diffusion_model.out_channels
    n_frames = args.video_length
    print(f'>>> Generate {n_frames} frames under each generation ...')
    noise_shape = [args.bs, channels, n_frames, h, w]
    # Determine profiler iterations
    profile_active_iters = getattr(args, 'profile_iterations', 3)
    use_pytorch_profiler = profiler.enabled and profile_active_iters > 0
    log_every = max(1, args.step_log_every)

    # Created only once inference actually runs, and always closed so that
    # pending TensorBoard events are flushed even if a sample fails.
    writer = SummaryWriter(log_dir=log_dir)
    try:
        for _, sample in df.iterrows():
            # Got initial frame path
            init_frame_path = get_init_frame_path(args.prompt_dir, sample)
            ori_fps = float(sample['fps'])
            video_save_dir = args.savedir + f"/inference/sample_{sample['videoid']}"
            os.makedirs(video_save_dir, exist_ok=True)
            os.makedirs(video_save_dir + '/dm', exist_ok=True)
            os.makedirs(video_save_dir + '/wm', exist_ok=True)
            # Load transitions to get the initial state later
            transition_path = get_transition_path(args.prompt_dir, sample)
            with profiler.profile_section("load_transitions"):
                transition_dict = _load_transition_dict(transition_path)
            # If many, test various frequency control and world-model generation
            for fs in args.frame_stride:
                # Policy ("dm") and world-model ("wm") outputs for this stride
                os.makedirs(f'{video_save_dir}/dm/{fs}', exist_ok=True)
                os.makedirs(f'{video_save_dir}/wm/{fs}', exist_ok=True)
                # For collecting interaction videos
                wm_video = []
                # Initialize observation queues
                cond_obs_queues = {
                    "observation.images.top":
                    deque(maxlen=model.n_obs_steps_imagen),
                    "observation.state": deque(maxlen=model.n_obs_steps_imagen),
                    "action": deque(maxlen=args.video_length),
                }
                # Obtain initial frame and state
                with profiler.profile_section("prepare_init_input"):
                    start_idx = 0
                    model_input_fs = ori_fps // fs
                    batch, ori_state_dim, ori_action_dim = prepare_init_input(
                        start_idx,
                        init_frame_path,
                        transition_dict,
                        fs,
                        data.test_datasets[args.dataset],
                        n_obs_steps=model.n_obs_steps_imagen)
                    observation = {
                        'observation.images.top':
                        batch['observation.image'].permute(1, 0, 2,
                                                           3)[-1].unsqueeze(0),
                        'observation.state':
                        batch['observation.state'][-1].unsqueeze(0),
                        'action':
                        torch.zeros_like(batch['action'][-1]).unsqueeze(0)
                    }
                    observation = _move_to_device(observation, device)
                    # Update observation queues
                    cond_obs_queues = populate_queues(cond_obs_queues,
                                                      observation)
                # Setup PyTorch profiler context if enabled
                pytorch_prof_ctx = nullcontext()
                if use_pytorch_profiler:
                    pytorch_prof_ctx = profiler.start_pytorch_profiler(
                        wait=1, warmup=1, active=profile_active_iters)
                # Multi-round interaction with the world-model
                with pytorch_prof_ctx:
                    for itr in tqdm(range(args.n_iter)):
                        log_step = (itr % log_every == 0)
                        profiler.current_iteration = itr
                        profiler.record_memory(f"iter_{itr}_start")
                        with profiler.profile_section("iteration_total"):
                            # Get observation
                            with profiler.profile_section("prepare_observation"):
                                observation = _stack_observation_queues(
                                    cond_obs_queues, device)
                            # Use world-model in policy to generate action
                            if log_step:
                                print(f'>>> Step {itr}: generating actions ...')
                            with profiler.profile_section("action_generation"):
                                pred_videos_0, pred_actions, _ = image_guided_synthesis_sim_mode(
                                    model,
                                    sample['instruction'],
                                    observation,
                                    noise_shape,
                                    action_cond_step=args.exe_steps,
                                    ddim_steps=args.ddim_steps,
                                    ddim_eta=args.ddim_eta,
                                    unconditional_guidance_scale=args.
                                    unconditional_guidance_scale,
                                    fs=model_input_fs,
                                    timestep_spacing=args.timestep_spacing,
                                    guidance_rescale=args.guidance_rescale,
                                    sim_mode=False,
                                    diffusion_autocast_dtype=diffusion_autocast_dtype)
                            # Update future actions in the observation queues;
                            # dims beyond the original action dim are zeroed.
                            with profiler.profile_section("update_action_queues"):
                                for act_idx in range(len(pred_actions[0])):
                                    obs_update = {
                                        'action':
                                        pred_actions[0][act_idx:act_idx + 1]
                                    }
                                    obs_update['action'][:, ori_action_dim:] = 0.0
                                    cond_obs_queues = populate_queues(
                                        cond_obs_queues, obs_update)
                            # Collect data for interacting the world-model using the predicted actions
                            with profiler.profile_section("prepare_wm_observation"):
                                observation = _stack_observation_queues(
                                    cond_obs_queues, device)
                            # Interaction with the world-model
                            if log_step:
                                print(f'>>> Step {itr}: interacting with world model ...')
                            with profiler.profile_section("world_model_interaction"):
                                pred_videos_1, _, pred_states = image_guided_synthesis_sim_mode(
                                    model,
                                    "",
                                    observation,
                                    noise_shape,
                                    action_cond_step=args.exe_steps,
                                    ddim_steps=args.ddim_steps,
                                    ddim_eta=args.ddim_eta,
                                    unconditional_guidance_scale=args.
                                    unconditional_guidance_scale,
                                    fs=model_input_fs,
                                    text_input=False,
                                    timestep_spacing=args.timestep_spacing,
                                    guidance_rescale=args.guidance_rescale,
                                    diffusion_autocast_dtype=diffusion_autocast_dtype)
                            # Feed the first exe_steps predicted frames/states
                            # back into the queues (with zero actions).
                            with profiler.profile_section("update_state_queues"):
                                for step_idx in range(args.exe_steps):
                                    obs_update = {
                                        'observation.images.top':
                                        pred_videos_1[0][:, step_idx:step_idx + 1].permute(1, 0, 2, 3),
                                        'observation.state':
                                        torch.zeros_like(pred_states[0][step_idx:step_idx + 1]) if
                                        args.zero_pred_state else pred_states[0][step_idx:step_idx + 1],
                                        'action':
                                        torch.zeros_like(pred_actions[0][-1:])
                                    }
                                    obs_update['observation.state'][:, ori_state_dim:] = 0.0
                                    cond_obs_queues = populate_queues(
                                        cond_obs_queues, obs_update)
                            with profiler.profile_section("save_results"):
                                # Save the imagen videos for decision-making
                                sample_tag = f"{args.dataset}-vid{sample['videoid']}-dm-fs-{fs}/itr-{itr}"
                                log_to_tensorboard(writer,
                                                   pred_videos_0,
                                                   sample_tag,
                                                   fps=args.save_fps)
                                # Save videos environment changes via world-model interaction
                                sample_tag = f"{args.dataset}-vid{sample['videoid']}-wd-fs-{fs}/itr-{itr}"
                                log_to_tensorboard(writer,
                                                   pred_videos_1,
                                                   sample_tag,
                                                   fps=args.save_fps)
                                # Save the imagen videos for decision-making
                                sample_video_file = f'{video_save_dir}/dm/{fs}/itr-{itr}.mp4'
                                save_results(pred_videos_0.cpu(),
                                             sample_video_file,
                                             fps=args.save_fps)
                                # Save videos environment changes via world-model interaction
                                sample_video_file = f'{video_save_dir}/wm/{fs}/itr-{itr}.mp4'
                                save_results(pred_videos_1.cpu(),
                                             sample_video_file,
                                             fps=args.save_fps)
                                print('>' * 24)
                            # Collect the result of world-model interactions
                            wm_video.append(
                                pred_videos_1[:, :, :args.exe_steps].cpu())
                        profiler.record_memory(f"iter_{itr}_end")
                        profiler.step_profiler()
                # With --n_iter 0 nothing was generated; torch.cat would fail
                # on an empty list.
                if not wm_video:
                    continue
                full_video = torch.cat(wm_video, dim=2)
                sample_tag = f"{args.dataset}-vid{sample['videoid']}-wd-fs-{fs}/full"
                log_to_tensorboard(writer,
                                   full_video,
                                   sample_tag,
                                   fps=args.save_fps)
                sample_full_video_file = f"{video_save_dir}/../{sample['videoid']}_full_fs{fs}.mp4"
                save_results(full_video,
                             sample_full_video_file,
                             fps=args.save_fps)
    finally:
        writer.close()
    # Save profiling results
    profiler.save_results()
def get_parser() -> argparse.ArgumentParser:
    """Build the command-line parser for world-model interaction inference.

    Returns:
        argparse.ArgumentParser: Parser covering I/O paths, DDIM sampling
        options, per-module precision modes, casted-checkpoint export, the
        interaction loop and profiling.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--savedir",
                        type=str,
                        default=None,
                        help="Path to save the results.")
    parser.add_argument("--ckpt_path",
                        type=str,
                        default=None,
                        help="Path to the model checkpoint.")
    parser.add_argument("--config",
                        type=str,
                        help="Path to the model config file.")
    parser.add_argument(
        "--prompt_dir",
        type=str,
        default=None,
        help="Directory containing videos and corresponding prompts.")
    parser.add_argument("--dataset",
                        type=str,
                        default=None,
                        help="the name of dataset to test")
    parser.add_argument(
        "--ddim_steps",
        type=int,
        default=50,
        help="Number of DDIM steps. If non-positive, DDPM is used instead.")
    parser.add_argument(
        "--ddim_eta",
        type=float,
        default=1.0,
        help="Eta for DDIM sampling. Set to 0.0 for deterministic results.")
    parser.add_argument("--bs",
                        type=int,
                        default=1,
                        help="Batch size for inference. Must be 1.")
    parser.add_argument("--height",
                        type=int,
                        default=320,
                        help="Height of the generated images in pixels.")
    parser.add_argument("--width",
                        type=int,
                        default=512,
                        help="Width of the generated images in pixels.")
    parser.add_argument(
        "--frame_stride",
        type=int,
        nargs='+',
        required=True,
        help=
        "frame stride control for 256 model (larger->larger motion), FPS control for 512 or 1024 model (smaller->larger motion)"
    )
    parser.add_argument(
        "--unconditional_guidance_scale",
        type=float,
        default=1.0,
        help="Scale for classifier-free guidance during sampling.")
    parser.add_argument("--seed",
                        type=int,
                        default=123,
                        help="Random seed for reproducibility.")
    parser.add_argument("--video_length",
                        type=int,
                        default=16,
                        help="Number of frames in the generated video.")
    parser.add_argument("--num_generation",
                        type=int,
                        default=1,
                        help="Number of generations per prompt.")
    parser.add_argument(
        "--timestep_spacing",
        type=str,
        default="uniform",
        help=
        "Strategy for timestep scaling. See Table 2 in the paper: 'Common Diffusion Noise Schedules and Sample Steps are Flawed' (https://huggingface.co/papers/2305.08891)."
    )
    parser.add_argument(
        "--guidance_rescale",
        type=float,
        default=0.0,
        help=
        "Rescale factor for guidance as discussed in 'Common Diffusion Noise Schedules and Sample Steps are Flawed' (https://huggingface.co/papers/2305.08891)."
    )
    parser.add_argument(
        "--perframe_ae",
        action='store_true',
        default=False,
        help=
        "Use per-frame autoencoder decoding to reduce GPU memory usage. Recommended for models with resolutions like 576x1024."
    )
    parser.add_argument(
        "--diffusion_dtype",
        type=str,
        choices=["fp32", "bf16"],
        default="fp32",
        help="Dtype for diffusion backbone weights and sampling autocast."
    )
    parser.add_argument(
        "--projector_mode",
        type=str,
        choices=["fp32", "autocast", "bf16_full"],
        default="fp32",
        help=
        "Projector precision mode for image/state/action projectors: "
        "fp32=full fp32, autocast=fp32 weights + bf16 autocast in forward, "
        "bf16_full=bf16 weights + bf16 forward."
    )
    parser.add_argument(
        "--encoder_mode",
        type=str,
        choices=["fp32", "autocast", "bf16_full"],
        default="fp32",
        help=
        "Encoder precision mode for cond_stage_model/embedder: "
        "fp32=full fp32, autocast=fp32 weights + bf16 autocast in forward, "
        "bf16_full=bf16 weights + bf16 forward."
    )
    parser.add_argument(
        "--vae_dtype",
        type=str,
        choices=["fp32", "bf16"],
        default="fp32",
        help="Dtype for VAE/first_stage_model weights and forward autocast."
    )
    parser.add_argument(
        "--export_casted_ckpt",
        type=str,
        default=None,
        help=
        "Save a checkpoint after applying precision settings (mixed dtypes preserved)."
    )
    parser.add_argument(
        "--export_only",
        action='store_true',
        default=False,
        help="Exit after exporting the casted checkpoint."
    )
    parser.add_argument(
        "--step_log_every",
        type=int,
        default=1,
        help="Print per-iteration step logs every N iterations."
    )
    parser.add_argument(
        "--n_action_steps",
        type=int,
        default=16,
        help="num of samples per prompt",
    )
    parser.add_argument(
        "--exe_steps",
        type=int,
        default=16,
        help=
        "Number of predicted steps executed (fed back to the world model) per interaction iteration.",
    )
    parser.add_argument(
        "--n_iter",
        type=int,
        default=40,
        help="num of iteration to interact with the world model",
    )
    parser.add_argument("--zero_pred_state",
                        action='store_true',
                        default=False,
                        help="not using the predicted states as comparison")
    parser.add_argument("--save_fps",
                        type=int,
                        default=8,
                        help="fps for the saving video")
    # Profiling arguments
    parser.add_argument(
        "--profile",
        action='store_true',
        default=False,
        help="Enable performance profiling (macro and operator-level analysis)."
    )
    parser.add_argument(
        "--profile_output_dir",
        type=str,
        default=None,
        help="Directory to save profiling results. Defaults to {savedir}/profile_output."
    )
    parser.add_argument(
        "--profile_iterations",
        type=int,
        default=3,
        help="Number of iterations to run PyTorch profiler's active phase for operator-level analysis."
    )
    parser.add_argument(
        "--profile_detail",
        type=str,
        choices=["light", "full"],
        default="light",
        help="Profiling detail level. Use 'full' for shapes/stacks/memory/flops."
    )
    return parser
if __name__ == '__main__':
    parser = get_parser()
    args = parser.parse_args()
    # These options default to None in the parser, but run_inference cannot
    # proceed without any of them; fail with a usage message up front instead
    # of an obscure error (e.g. os.path.join(None, ...)) later on.
    missing = [
        f"--{name}"
        for name in ("savedir", "ckpt_path", "config", "prompt_dir", "dataset")
        if getattr(args, name) is None
    ]
    if missing:
        parser.error("the following arguments are required: " +
                     ", ".join(missing))
    seed = args.seed
    if seed < 0:
        seed = random.randint(0, 2**31)
        # Report the drawn seed so a randomly seeded run can be reproduced.
        print(f">>> Using random seed {seed}")
    seed_everything(seed)
    # Initialize profiler
    profile_output_dir = args.profile_output_dir
    if profile_output_dir is None:
        profile_output_dir = os.path.join(args.savedir, "profile_output")
    init_profiler(
        enabled=args.profile,
        output_dir=profile_output_dir,
        profile_detail=args.profile_detail,
    )
    # Single-process, single-GPU run.
    rank, gpu_num = 0, 1
    run_inference(args, gpu_num, rank)