早停特征验证,早停不通

This commit is contained in:
qhy
2026-03-15 12:41:53 +08:00
parent db9cc5766d
commit 7e45eba18b
227 changed files with 24579 additions and 163 deletions

View File

@@ -0,0 +1,611 @@
import argparse
import json
import os
from pathlib import Path
import numpy as np
import pandas as pd
try:
import matplotlib.pyplot as plt
except ModuleNotFoundError:
plt = None
def parse_args() -> argparse.Namespace:
    """Parse the command-line options of the DDIM convergence analysis script.

    Returns:
        The namespace produced by ``argparse`` from ``sys.argv``.
    """
    # (flag, add_argument keyword options), registered in this order.
    option_table = (
        ("--input_dir", {
            "type": str,
            "required": True,
            "help": "Directory containing stepwise_log.csv, sample_summary.csv, and round_summary.csv.",
        }),
        ("--output_dir", {
            "type": str,
            "default": None,
            "help": "Directory to write analysis outputs. Defaults to <input_dir>/analysis.",
        }),
        ("--mid_step_start", {
            "type": int,
            "default": 15,
            "help": "Start of the step window used for cross-step latent-delta analysis.",
        }),
        ("--mid_step_end", {
            "type": int,
            "default": 35,
            "help": "End of the step window used for cross-step latent-delta analysis.",
        }),
        ("--full50_step_cap", {
            "type": int,
            "default": 50,
            "help": "Upper bound used when summarizing step-based metrics.",
        }),
        ("--stability_threshold", {
            "type": float,
            "default": None,
            "help": "Optional threshold for offline first_stable_step analysis. If unset, stable-step metrics are not computed.",
        }),
        ("--stability_window", {
            "type": int,
            "default": 3,
            "help": "Consecutive-step window for offline first_stable_step analysis.",
        }),
    )
    parser = argparse.ArgumentParser(
        description="Analyze DDIM convergence logs for the five research directions.")
    for flag, options in option_table:
        parser.add_argument(flag, **options)
    return parser.parse_args()
def safe_float(value) -> float:
    """Convert ``value`` to ``float``, mapping anything pandas treats as missing to NaN."""
    return float("nan") if pd.isna(value) else float(value)
def describe_series(series: pd.Series) -> dict[str, float]:
    """Summarize the numeric values of ``series`` as a flat dict of floats.

    Values are coerced with ``pd.to_numeric(errors="coerce")`` and missing
    entries are dropped before the statistics are computed. The standard
    deviation uses ``ddof=0``. When nothing numeric remains, ``count`` is
    ``0.0`` and every other statistic is NaN.
    """
    stat_names = ("count", "mean", "std", "min", "p25", "median", "p75",
                  "p90", "p95", "max")
    numeric = pd.to_numeric(series, errors="coerce").dropna()
    if numeric.empty:
        stats = dict.fromkeys(stat_names, float("nan"))
        stats["count"] = 0.0
        return stats
    computed = {
        "count": numeric.count(),
        "mean": numeric.mean(),
        "std": numeric.std(ddof=0),
        "min": numeric.min(),
        "p25": numeric.quantile(0.25),
        "median": numeric.median(),
        "p75": numeric.quantile(0.75),
        "p90": numeric.quantile(0.90),
        "p95": numeric.quantile(0.95),
        "max": numeric.max(),
    }
    return {name: float(computed[name]) for name in stat_names}
def ensure_columns(frame: pd.DataFrame, required: list[str], frame_name: str) -> None:
    """Raise ``ValueError`` naming every column of ``required`` absent from ``frame``.

    ``frame_name`` is only used to label the table in the error message.
    """
    available = frame.columns
    missing = []
    for column in required:
        if column not in available:
            missing.append(column)
    if not missing:
        return
    raise ValueError(f"{frame_name} is missing required columns: {missing}")
def first_consecutive_below(values: pd.Series, threshold: float,
                            window: int) -> float:
    """Locate the first run of ``window`` consecutive values strictly below ``threshold``.

    Values are coerced to numbers first; entries that cannot be parsed count
    as "not below". The result is the 1-based position of the first element
    of the run, or NaN when ``window`` is not positive or no run exists.
    """
    numbers = pd.to_numeric(values, errors="coerce").tolist()
    if window <= 0 or len(numbers) < window:
        return float("nan")
    run_length = 0
    for position, number in enumerate(numbers, start=1):
        if pd.notna(number) and number < threshold:
            run_length += 1
        else:
            run_length = 0
        if run_length == window:
            # The run ends at `position`; report where it started (1-based).
            return float(position - window + 1)
    return float("nan")
def load_tables(input_dir: Path) -> tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame]:
    """Read and validate the three CSV logs found in ``input_dir``.

    All three files are checked for existence before any of them is read,
    and each table is then checked for its required columns.

    Returns:
        ``(stepwise_df, sample_summary_df, round_summary_df)``.

    Raises:
        FileNotFoundError: if one of the CSV files is missing.
        ValueError: if a table lacks a required column (via ``ensure_columns``).
    """
    # File name -> columns that must be present in that file.
    required_by_file = {
        "stepwise_log.csv": [
            "sample_id",
            "scene",
            "pass_type",
            "round_id",
            "step",
            "latent_delta",
            "action_delta",
            "state_delta",
            "action_cosine_vs_full50",
            "state_cosine_vs_full50",
            "latent_l2_vs_full50",
        ],
        "sample_summary.csv": [
            "sample_id",
            "scene",
            "pass_type",
            "pass_total_time_s",
            "action_first_stable_step",
            "state_first_stable_step",
            "latent_first_stable_step",
            "action_vs_full50_90pct_step",
            "action_vs_full50_95pct_step",
            "oracle_budget_action",
            "oracle_budget_state",
        ],
        "round_summary.csv": [
            "sample_id",
            "scene",
            "round_id",
            "policy_pass_total_time_s",
            "world_model_pass_total_time_s",
            "round_total_time_s",
            "latent_init_dist_to_prev_round",
            "action_drift_vs_prev_round",
        ],
    }
    csv_paths = {name: input_dir / name for name in required_by_file}
    for csv_path in csv_paths.values():
        if not csv_path.exists():
            raise FileNotFoundError(f"Missing file: {csv_path}")
    tables = {name: pd.read_csv(csv_path) for name, csv_path in csv_paths.items()}
    for name, required in required_by_file.items():
        ensure_columns(tables[name], required, name)
    return (tables["stepwise_log.csv"], tables["sample_summary.csv"],
            tables["round_summary.csv"])
def aggregate_stepwise(stepwise_df: pd.DataFrame) -> pd.DataFrame:
    """Aggregate per-step metrics over all samples for each (pass_type, step).

    Only the metric columns actually present in ``stepwise_df`` are
    aggregated, so a log that lacks an optional column such as
    ``step_time_s`` is still summarized instead of failing with ``KeyError``.
    The input frame is not modified.

    Returns:
        A frame with ``pass_type``/``step`` plus mean, median and std of each
        available metric (two-level column index).

    Raises:
        KeyError: if none of the known metric columns is present.
    """
    metric_columns = [
        "step_time_s",
        "latent_delta",
        "action_delta",
        "state_delta",
        "action_cosine_vs_full50",
        "state_cosine_vs_full50",
        "latent_l2_vs_full50",
    ]
    available = [column for column in metric_columns
                 if column in stepwise_df.columns]
    if not available:
        raise KeyError(
            f"stepwise log has none of the expected metric columns: {metric_columns}")
    stepwise_df = stepwise_df.copy()
    for column in available:
        # Non-numeric cells become NaN and are skipped by the aggregations.
        stepwise_df[column] = pd.to_numeric(stepwise_df[column], errors="coerce")
    grouped = stepwise_df.groupby(["pass_type", "step"])[available]
    return grouped.agg(["mean", "median", "std"]).reset_index()
def aggregate_scene_summary(sample_summary_df: pd.DataFrame) -> pd.DataFrame:
    """Aggregate per-sample metrics for each (scene, pass_type) pair.

    Only the metric columns actually present in ``sample_summary_df`` are
    aggregated, so a summary without the round-level columns is still
    summarized instead of failing with ``KeyError``. The input frame is not
    modified.

    Returns:
        A frame with ``scene``/``pass_type`` plus mean, median and std of each
        available metric (two-level column index).

    Raises:
        KeyError: if none of the known metric columns is present.
    """
    metrics = [
        "pass_total_time_s",
        "action_first_stable_step",
        "state_first_stable_step",
        "latent_first_stable_step",
        "action_vs_full50_90pct_step",
        "action_vs_full50_95pct_step",
        "oracle_budget_action",
        "oracle_budget_state",
        "round_total_time_s",
        "policy_pass_total_time_s",
        "world_model_pass_total_time_s",
        "latent_init_dist_to_prev_round",
        "action_drift_vs_prev_round",
    ]
    available = [column for column in metrics
                 if column in sample_summary_df.columns]
    if not available:
        raise KeyError(
            f"sample summary has none of the expected metric columns: {metrics}")
    sample_summary_df = sample_summary_df.copy()
    for column in available:
        # Non-numeric cells become NaN and are skipped by the aggregations.
        sample_summary_df[column] = pd.to_numeric(sample_summary_df[column],
                                                  errors="coerce")
    grouped = sample_summary_df.groupby(["scene", "pass_type"])[available]
    return grouped.agg(["mean", "median", "std"]).reset_index()
def compute_stability_summary(stepwise_df: pd.DataFrame, threshold: float,
                              window: int) -> pd.DataFrame:
    """Recompute first-stable-step metrics offline from the stepwise log.

    A trajectory is one (sample_id, scene, pass_type, round_id) group ordered
    by ``step``. Rounds are kept apart so that the consecutive-step window
    never mixes steps from different rounds. For each trajectory the first
    1-based position at which ``window`` consecutive deltas fall below
    ``threshold`` is computed, and the per-round results are then averaged
    (NaN ignored) into one row per (sample_id, scene, pass_type).

    If ``stepwise_df`` has no ``round_id`` column, each
    (sample_id, scene, pass_type) group is treated as a single trajectory.

    Returns:
        A frame with the three key columns and ``action_first_stable_step``,
        ``state_first_stable_step`` and ``latent_first_stable_step``. The
        columns are present even when there are no rows, so the result can
        always be merged on the key columns.
    """
    key_columns = ["sample_id", "scene", "pass_type"]
    # Output metric -> stepwise delta column it is derived from.
    delta_by_metric = {
        "action_first_stable_step": "action_delta",
        "state_first_stable_step": "state_delta",
        "latent_first_stable_step": "latent_delta",
    }
    metric_names = list(delta_by_metric)

    trajectory_keys = list(key_columns)
    if "round_id" in stepwise_df.columns:
        trajectory_keys.append("round_id")

    ordered = stepwise_df.sort_values("step", kind="stable")
    per_round_rows = []
    # dropna=False keeps trajectories whose round_id is missing; rows with a
    # missing sample/scene/pass_type are dropped by the final groupby below.
    for keys, trajectory in ordered.groupby(trajectory_keys, dropna=False):
        row = dict(zip(key_columns, keys))
        for metric, delta_column in delta_by_metric.items():
            row[metric] = first_consecutive_below(trajectory[delta_column],
                                                  threshold, window)
        per_round_rows.append(row)

    per_round = pd.DataFrame(per_round_rows, columns=key_columns + metric_names)
    if per_round.empty:
        return per_round
    return per_round.groupby(key_columns, as_index=False)[metric_names].mean()
def compute_direction_summary(stepwise_df: pd.DataFrame,
                              sample_summary_df: pd.DataFrame,
                              round_summary_df: pd.DataFrame,
                              mid_step_start: int,
                              mid_step_end: int,
                              stability_threshold: float | None,
                              stability_window: int) -> dict:
    """Build the nested summary dictionary for the research directions.

    Args:
        stepwise_df: Per-step log; rows with ``step`` inside
            ``[mid_step_start, mid_step_end]`` feed the cross-step section.
        sample_summary_df: Per-sample summary; split by ``pass_type`` into
            ``"policy"`` and ``"world_model"`` rows.
        round_summary_df: Per-round summary used for the round-reuse section.
        mid_step_start: Inclusive lower bound of the mid-step window.
        mid_step_end: Inclusive upper bound of the mid-step window.
        stability_threshold: Echoed into ``analysis_config`` (may be None).
        stability_window: Echoed into ``analysis_config``.

    Returns:
        A dict with ``analysis_config``, ``dataset_overview`` and one entry
        per direction; leaf statistics come from ``describe_series``.
    """
    pass_type = sample_summary_df["pass_type"]
    policy_summary = sample_summary_df[pass_type == "policy"].copy()
    world_summary = sample_summary_df[pass_type == "world_model"].copy()
    in_mid_window = ((stepwise_df["step"] >= mid_step_start)
                     & (stepwise_df["step"] <= mid_step_end))
    mid_steps = stepwise_df[in_mid_window].copy()

    def share_earlier(earlier_column: str, later_column: str) -> float:
        """Share of policy rows where ``earlier_column`` < ``later_column``.

        Only the two compared columns decide which rows are dropped, so a
        missing value in an unrelated metric does not shrink the sample.
        """
        pairs = policy_summary[[earlier_column, later_column]].dropna()
        if pairs.empty:
            return float("nan")
        return float((pairs[earlier_column] < pairs[later_column]).mean())

    action_first_beats_latent = share_earlier("action_first_stable_step",
                                              "latent_first_stable_step")
    action_95_beats_latent = share_earlier("action_vs_full50_95pct_step",
                                           "latent_first_stable_step")

    direction_summary = {
        "analysis_config": {
            "stability_threshold": stability_threshold,
            "stability_window": int(stability_window),
            "mid_step_window": [int(mid_step_start), int(mid_step_end)],
        },
        "dataset_overview": {
            "num_step_rows": int(len(stepwise_df)),
            "num_sample_rows": int(len(sample_summary_df)),
            "num_round_rows": int(len(round_summary_df)),
            "num_unique_samples": int(sample_summary_df["sample_id"].nunique()),
            "scenes": sorted(sample_summary_df["scene"].dropna().unique().tolist()),
            "pass_types": sorted(sample_summary_df["pass_type"].dropna().unique().tolist()),
        },
        "direction_original_early_stop": {
            "latent_first_stable_step_policy":
                describe_series(policy_summary["latent_first_stable_step"]),
            "latent_first_stable_step_world_model":
                describe_series(world_summary["latent_first_stable_step"]),
            "latent_l2_vs_full50":
                describe_series(stepwise_df["latent_l2_vs_full50"]),
        },
        "direction_c_action_converges_first": {
            "action_first_stable_step":
                describe_series(policy_summary["action_first_stable_step"]),
            "latent_first_stable_step":
                describe_series(policy_summary["latent_first_stable_step"]),
            "action_vs_full50_95pct_step":
                describe_series(policy_summary["action_vs_full50_95pct_step"]),
            "share_action_first_stable_before_latent":
                action_first_beats_latent,
            "share_action_95pct_before_latent":
                action_95_beats_latent,
        },
        "direction_d_cross_step_similarity": {
            "latent_delta_mid_steps":
                describe_series(mid_steps["latent_delta"]),
            "action_delta_mid_steps":
                describe_series(mid_steps["action_delta"]),
            "state_delta_mid_steps":
                describe_series(mid_steps["state_delta"]),
        },
        "direction_a_budget_heterogeneity": {
            "oracle_budget_action":
                describe_series(policy_summary["oracle_budget_action"]),
            "oracle_budget_state":
                describe_series(world_summary["oracle_budget_state"]),
            "oracle_budget_action_by_scene": {
                scene: describe_series(group["oracle_budget_action"])
                for scene, group in policy_summary.groupby("scene")
            },
        },
        "direction_b_round_reuse": {
            "latent_init_dist_to_prev_round":
                describe_series(round_summary_df["latent_init_dist_to_prev_round"]),
            "action_drift_vs_prev_round":
                describe_series(round_summary_df["action_drift_vs_prev_round"]),
            "round_total_time_s":
                describe_series(round_summary_df["round_total_time_s"]),
            "policy_pass_total_time_s":
                describe_series(round_summary_df["policy_pass_total_time_s"]),
            "world_model_pass_total_time_s":
                describe_series(round_summary_df["world_model_pass_total_time_s"]),
        },
    }
    return direction_summary
def save_json(path: Path, payload: dict) -> None:
    """Write ``payload`` to ``path`` as UTF-8 JSON, indented by two spaces.

    Non-ASCII characters are written as-is rather than escaped.
    """
    with open(path, mode="w", encoding="utf-8") as handle:
        json.dump(payload, handle, ensure_ascii=False, indent=2)
def write_markdown_report(path: Path, direction_summary: dict) -> None:
    """Render ``direction_summary`` as a Markdown report and write it to ``path``.

    The report has an "Analysis Config" section, a "Dataset Overview" section
    and one "###" subsection per direction. Scene and pass-type labels are
    converted with ``str`` before joining, so non-string labels (for example
    numeric scene ids) do not raise. Statistic dicts are printed with
    whichever of the expected statistics they contain.
    """

    def format_stats(stats: dict, names: tuple) -> str:
        """Join ``name=value`` pairs (4 decimals) for the names present in ``stats``."""
        return ", ".join(f"{name}={stats[name]:.4f}"
                         for name in names if name in stats)

    config = direction_summary["analysis_config"]
    overview = direction_summary["dataset_overview"]
    lines = [
        "# DDIM Analysis Report",
        "",
        "## Analysis Config",
        "",
        f"- Stability threshold: {config['stability_threshold']}",
        f"- Stability window: {config['stability_window']}",
        f"- Mid-step window: {config['mid_step_window'][0]}-{config['mid_step_window'][1]}",
        "",
        "## Dataset Overview",
        "",
        f"- Unique samples: {overview['num_unique_samples']}",
        f"- Step rows: {overview['num_step_rows']}",
        f"- Sample rows: {overview['num_sample_rows']}",
        f"- Round rows: {overview['num_round_rows']}",
        f"- Scenes: {', '.join(map(str, overview['scenes']))}",
        f"- Pass types: {', '.join(map(str, overview['pass_types']))}",
        "",
        "## Five Directions",
        "",
    ]

    def append_stats(title: str, payload: dict) -> None:
        """Append one direction's subsection to ``lines``."""
        lines.append(f"### {title}")
        lines.append("")
        for key, value in payload.items():
            if not isinstance(value, dict):
                rendered = f"{value:.4f}" if isinstance(value, float) else f"{value}"
                lines.append(f"- `{key}`: {rendered}")
            elif {"median", "mean", "p90"} <= value.keys():
                summary = format_stats(value, ("mean", "median", "p90", "p95"))
                lines.append(f"- `{key}`: {summary}")
            else:
                # A dict of sub-entries (e.g. per-scene statistics).
                lines.append(f"- `{key}`:")
                for sub_key, sub_value in value.items():
                    # Two spaces of indent make these render as nested bullets.
                    if isinstance(sub_value, dict) and {"median", "mean"} <= sub_value.keys():
                        summary = format_stats(sub_value, ("mean", "median", "std"))
                        lines.append(f"  - `{sub_key}`: {summary}")
                    else:
                        lines.append(f"  - `{sub_key}`: {sub_value}")
        lines.append("")

    append_stats("Original Early Stop",
                 direction_summary["direction_original_early_stop"])
    append_stats("Direction C: Action Converges First",
                 direction_summary["direction_c_action_converges_first"])
    append_stats("Direction D: Cross-Step Similarity",
                 direction_summary["direction_d_cross_step_similarity"])
    append_stats("Direction A: Budget Heterogeneity",
                 direction_summary["direction_a_budget_heterogeneity"])
    append_stats("Direction B: Round Reuse",
                 direction_summary["direction_b_round_reuse"])
    with open(path, "w", encoding="utf-8") as file:
        file.write("\n".join(lines) + "\n")
def save_convergence_plot(stepwise_df: pd.DataFrame, output_path: Path) -> None:
    """Save a 2x2 figure of per-step mean convergence curves, one line per pass type.

    Does nothing when matplotlib is unavailable (``plt`` is None).
    """
    if plt is None:
        return
    # (grid position, stepwise column, panel title)
    panels = (
        ((0, 0), "latent_delta", "Mean Latent Delta"),
        ((0, 1), "action_cosine_vs_full50", "Mean Action Cosine vs Full50"),
        ((1, 0), "state_cosine_vs_full50", "Mean State Cosine vs Full50"),
        ((1, 1), "latent_l2_vs_full50", "Mean Latent L2 vs Full50"),
    )
    fig, axes = plt.subplots(2, 2, figsize=(14, 10))
    for pass_type in sorted(stepwise_df["pass_type"].dropna().unique().tolist()):
        rows_for_pass = stepwise_df[stepwise_df["pass_type"] == pass_type]
        step_means = rows_for_pass.groupby("step").mean(numeric_only=True).reset_index()
        for position, column, _title in panels:
            axes[position].plot(step_means["step"],
                                step_means[column],
                                label=pass_type)
    for position, _column, title in panels:
        axes[position].set_title(title)
    for axis in axes.flatten():
        axis.set_xlabel("Step")
        axis.grid(alpha=0.3)
        axis.legend()
    fig.tight_layout()
    fig.savefig(output_path, dpi=200, bbox_inches="tight")
    plt.close(fig)
def save_budget_distribution_plot(sample_summary_df: pd.DataFrame,
                                  output_path: Path) -> None:
    """Save a histogram and a per-scene box plot of ``oracle_budget_action``.

    Only rows whose ``pass_type`` is ``"policy"`` are plotted. Does nothing
    when matplotlib is unavailable (``plt`` is None).
    """
    if plt is None:
        return
    is_policy = sample_summary_df["pass_type"] == "policy"
    policy_summary = sample_summary_df[is_policy].copy()
    scenes = sorted(policy_summary["scene"].dropna().unique().tolist())
    fig, axes = plt.subplots(1, 2, figsize=(14, 5))
    histogram_axis, box_axis = axes[0], axes[1]

    for scene in scenes:
        scene_rows = policy_summary[policy_summary["scene"] == scene]
        histogram_axis.hist(scene_rows["oracle_budget_action"].dropna(),
                            bins=15,
                            alpha=0.5,
                            label=scene)
    histogram_axis.set_title("Oracle Budget Action Distribution")
    histogram_axis.set_xlabel("Step")
    histogram_axis.set_ylabel("Count")
    histogram_axis.legend()
    histogram_axis.grid(alpha=0.3)

    budgets_by_scene = []
    for scene in scenes:
        scene_rows = policy_summary[policy_summary["scene"] == scene]
        budgets_by_scene.append(scene_rows["oracle_budget_action"].dropna())
    # Skip the box plot entirely when there is nothing to draw.
    if scenes and any(len(budgets) > 0 for budgets in budgets_by_scene):
        box_axis.boxplot(budgets_by_scene, labels=scenes, showfliers=False)
    box_axis.set_title("Oracle Budget Action by Scene")
    box_axis.set_ylabel("Step")
    box_axis.tick_params(axis="x", rotation=20)
    box_axis.grid(alpha=0.3)

    fig.tight_layout()
    fig.savefig(output_path, dpi=200, bbox_inches="tight")
    plt.close(fig)
def save_midstep_delta_plot(stepwise_df: pd.DataFrame, output_path: Path,
                            mid_step_start: int, mid_step_end: int) -> None:
    """Save per-step box plots of ``latent_delta`` for steps in the inclusive mid-step window.

    Does nothing when matplotlib is unavailable (``plt`` is None).
    """
    if plt is None:
        return
    step_values = stepwise_df["step"]
    in_window = (step_values >= mid_step_start) & (step_values <= mid_step_end)
    window_rows = stepwise_df[in_window].copy()
    steps = sorted(window_rows["step"].dropna().unique().tolist())
    deltas_by_step = []
    for step in steps:
        rows_at_step = window_rows[window_rows["step"] == step]
        deltas_by_step.append(rows_at_step["latent_delta"].dropna())
    fig, ax = plt.subplots(figsize=(14, 5))
    # Skip the box plot entirely when there is nothing to draw.
    if steps and any(len(deltas) > 0 for deltas in deltas_by_step):
        ax.boxplot(deltas_by_step, labels=steps, showfliers=False)
    ax.set_title(
        f"Latent Delta Distribution for Steps {mid_step_start}-{mid_step_end}")
    ax.set_xlabel("Step")
    ax.set_ylabel("latent_delta")
    ax.grid(alpha=0.3)
    fig.tight_layout()
    fig.savefig(output_path, dpi=200, bbox_inches="tight")
    plt.close(fig)
def save_round_reuse_plot(round_summary_df: pd.DataFrame, output_path: Path) -> None:
    """Save a scatter of latent distance vs action drift and a per-scene round-time box plot.

    Does nothing when matplotlib is unavailable (``plt`` is None).
    """
    if plt is None:
        return
    fig, axes = plt.subplots(1, 2, figsize=(14, 5))
    scatter_axis, time_axis = axes[0], axes[1]

    # Only rounds with both reuse metrics can be placed on the scatter.
    complete_rounds = round_summary_df.dropna(
        subset=["latent_init_dist_to_prev_round", "action_drift_vs_prev_round"])
    if not complete_rounds.empty:
        for scene, scene_rounds in complete_rounds.groupby("scene"):
            scatter_axis.scatter(scene_rounds["latent_init_dist_to_prev_round"],
                                 scene_rounds["action_drift_vs_prev_round"],
                                 alpha=0.6,
                                 label=scene)
        scatter_axis.legend()
    scatter_axis.set_title("Round Reuse: Latent Init Dist vs Action Drift")
    scatter_axis.set_xlabel("latent_init_dist_to_prev_round")
    scatter_axis.set_ylabel("action_drift_vs_prev_round")
    scatter_axis.grid(alpha=0.3)

    # Keep only scenes that have at least one recorded round time.
    timed_scenes = []
    for scene, scene_rounds in round_summary_df.groupby("scene"):
        round_times = scene_rounds["round_total_time_s"].dropna()
        if len(round_times) > 0:
            timed_scenes.append((scene, round_times))
    if timed_scenes:
        time_axis.boxplot([times for _, times in timed_scenes],
                          labels=[scene for scene, _ in timed_scenes],
                          showfliers=False)
    time_axis.set_title("Round Total Time by Scene")
    time_axis.set_ylabel("Seconds")
    time_axis.tick_params(axis="x", rotation=20)
    time_axis.grid(alpha=0.3)

    fig.tight_layout()
    fig.savefig(output_path, dpi=200, bbox_inches="tight")
    plt.close(fig)
def main() -> None:
    """Run the analysis: load the CSV logs, aggregate them, and write all outputs.

    Outputs go to ``--output_dir`` (default ``<input_dir>/analysis``):
    step_aggregate.csv, scene_summary.csv, direction_summary.json,
    analysis_report.md, the PNG plots when matplotlib is available, and
    stability_summary.csv when ``--stability_threshold`` is given.

    NOTE(review): ``--full50_step_cap`` is parsed but not read in this
    function - confirm whether it is still needed.
    """
    args = parse_args()
    input_dir = Path(args.input_dir)
    if args.output_dir:
        output_dir = Path(args.output_dir)
    else:
        output_dir = input_dir / "analysis"
    output_dir.mkdir(parents=True, exist_ok=True)

    stepwise_df, sample_summary_df, round_summary_df = load_tables(input_dir)

    stability_summary_df = None
    sample_summary_for_analysis = sample_summary_df.copy()
    if args.stability_threshold is not None:
        # Replace the logged stable-step columns with values recomputed
        # offline from the stepwise log.
        stability_summary_df = compute_stability_summary(stepwise_df,
                                                         args.stability_threshold,
                                                         args.stability_window)
        logged_stable_columns = [
            "action_first_stable_step",
            "state_first_stable_step",
            "latent_first_stable_step",
        ]
        without_logged = sample_summary_for_analysis.drop(
            columns=logged_stable_columns, errors="ignore")
        sample_summary_for_analysis = without_logged.merge(
            stability_summary_df,
            on=["sample_id", "scene", "pass_type"],
            how="left")

    step_aggregate_df = aggregate_stepwise(stepwise_df)
    scene_summary_df = aggregate_scene_summary(sample_summary_for_analysis)
    direction_summary = compute_direction_summary(
        stepwise_df=stepwise_df,
        sample_summary_df=sample_summary_for_analysis,
        round_summary_df=round_summary_df,
        mid_step_start=args.mid_step_start,
        mid_step_end=args.mid_step_end,
        stability_threshold=args.stability_threshold,
        stability_window=args.stability_window,
    )

    step_aggregate_df.to_csv(output_dir / "step_aggregate.csv", index=False)
    scene_summary_df.to_csv(output_dir / "scene_summary.csv", index=False)
    if stability_summary_df is not None:
        stability_summary_df.to_csv(output_dir / "stability_summary.csv",
                                    index=False)
    save_json(output_dir / "direction_summary.json", direction_summary)
    write_markdown_report(output_dir / "analysis_report.md", direction_summary)

    if plt is None:
        print("matplotlib not installed; skipped PNG plots.")
    else:
        save_convergence_plot(stepwise_df, output_dir / "convergence_curves.png")
        save_budget_distribution_plot(sample_summary_df,
                                      output_dir / "budget_distribution.png")
        save_midstep_delta_plot(stepwise_df,
                                output_dir / "midstep_latent_delta.png",
                                args.mid_step_start, args.mid_step_end)
        save_round_reuse_plot(round_summary_df, output_dir / "round_reuse.png")
    print(f"Analysis written to: {output_dir}")


if __name__ == "__main__":
    main()

View File

@@ -1,32 +1,35 @@
import argparse, os, glob
import pandas as pd
import random
import torch
import torchvision
import h5py
import argparse, os, glob
import json
import pandas as pd
import random
import torch
import torchvision
import h5py
import numpy as np
import logging
import einops
import warnings
import imageio
from pytorch_lightning import seed_everything
from omegaconf import OmegaConf
from tqdm import tqdm
from einops import rearrange, repeat
from collections import OrderedDict
from torch import nn
from eval_utils import populate_queues, log_to_tensorboard
from collections import deque
from torch import Tensor
from torch.utils.tensorboard import SummaryWriter
from PIL import Image
import warnings
import imageio
import time
from pytorch_lightning import seed_everything
from omegaconf import OmegaConf
from tqdm import tqdm
from einops import rearrange, repeat
from collections import OrderedDict
from torch import nn
import torch.nn.functional as F
from eval_utils import populate_queues, log_to_tensorboard
from collections import deque
from torch import Tensor
from torch.utils.tensorboard import SummaryWriter
from PIL import Image
from unifolm_wma.models.samplers.ddim import DDIMSampler
from unifolm_wma.utils.utils import instantiate_from_config
def get_device_from_parameters(module: nn.Module) -> torch.device:
def get_device_from_parameters(module: nn.Module) -> torch.device:
"""Get a module's device by checking one of its parameters.
Args:
@@ -35,7 +38,413 @@ def get_device_from_parameters(module: nn.Module) -> torch.device:
Returns:
torch.device: The device of the model's parameters.
"""
return next(iter(module.parameters())).device
return next(iter(module.parameters())).device
def get_scene_name(sample: pd.Series, fallback: str) -> str:
    """Return the sample's `data_dir` as the scene label, or `fallback` if it is absent or missing."""
    has_data_dir = 'data_dir' in sample and pd.notna(sample['data_dir'])
    return str(sample['data_dir']) if has_data_dir else fallback
def build_sample_id(dataset: str, sample: pd.Series, frame_stride: int) -> str:
    """Compose the id `<dataset>-vid<videoid>-fs<frame_stride>` for one sample."""
    video_id = sample['videoid']
    return f"{dataset}-vid{video_id}-fs{frame_stride}"
def flatten_batch_tensor(tensor: torch.Tensor) -> torch.Tensor:
    """Return a detached float copy of `tensor` reshaped to (batch, everything else)."""
    batch_size = tensor.shape[0]
    detached = tensor.detach().float()
    return detached.reshape(batch_size, -1)
def batch_relative_l2(current: torch.Tensor,
                      previous: torch.Tensor) -> list[float]:
    """Per batch item, return ||current - previous|| / ||previous||.

    The denominator is clamped to at least 1e-8 to avoid division by zero.
    """
    flat_current = flatten_batch_tensor(current)
    flat_previous = flatten_batch_tensor(previous)
    change_norm = torch.linalg.vector_norm(flat_current - flat_previous, dim=1)
    previous_norm = torch.linalg.vector_norm(flat_previous, dim=1)
    ratio = change_norm / previous_norm.clamp_min(1e-8)
    return ratio.cpu().tolist()
def batch_l2_distance(current: torch.Tensor,
                      reference: torch.Tensor) -> list[float]:
    """Per batch item, return the L2 norm of `current - reference` over all other dimensions."""
    flat_current = flatten_batch_tensor(current)
    flat_reference = flatten_batch_tensor(reference)
    distances = torch.linalg.vector_norm(flat_current - flat_reference, dim=1)
    return distances.cpu().tolist()
def batch_cosine_similarity(current: torch.Tensor,
                            reference: torch.Tensor) -> list[float]:
    """Per batch item, return the cosine similarity between `current` and `reference`.

    Both tensors are flattened to one vector per batch item first.
    """
    flat_current = flatten_batch_tensor(current)
    flat_reference = flatten_batch_tensor(reference)
    similarities = F.cosine_similarity(flat_current,
                                       flat_reference,
                                       dim=1,
                                       eps=1e-8)
    return similarities.cpu().tolist()
def first_consecutive_below(values: list[float], threshold: float,
                            window: int) -> float:
    """Find the first run of `window` consecutive values strictly below `threshold`.

    Returns the 1-based index where the run starts, or NaN when `window` is
    not positive or no such run exists. Missing values never count as below.
    """

    def is_below(value) -> bool:
        return pd.notna(value) and value < threshold

    last_start = len(values) - window
    if window <= 0 or last_start < 0:
        return np.nan
    run_starts = (
        float(start + 1)
        for start in range(last_start + 1)
        if all(map(is_below, values[start:start + window]))
    )
    return next(run_starts, np.nan)
def first_at_least(values: list[float], threshold: float) -> float:
    """Return the 1-based index of the first non-missing value >= `threshold`, or NaN if none."""
    matches = (
        float(position)
        for position, value in enumerate(values, start=1)
        if pd.notna(value) and value >= threshold
    )
    return next(matches, np.nan)
def first_at_most(values: list[float], threshold: float) -> float:
    """Return the 1-based index of the first non-missing value <= `threshold`, or NaN if none."""
    matches = (
        float(position)
        for position, value in enumerate(values, start=1)
        if pd.notna(value) and value <= threshold
    )
    return next(matches, np.nan)
def safe_mean(values: list[float]) -> float:
    """Return the mean of the non-missing entries of `values`, or NaN when there are none."""
    present = [value for value in values if not pd.isna(value)]
    return float(np.mean(present)) if present else np.nan
def make_sampling_noise_bundle(model: nn.Module,
                               noise_shape: list[int]) -> dict[str, torch.Tensor]:
    """Draw Gaussian starting noise for the 'img', 'action' and 'state' streams.

    'img' has shape `noise_shape`; 'action' and 'state' have shape
    (noise_shape[0], noise_shape[2], dim) with `dim` taken from
    `model.agent_action_dim` / `model.agent_state_dim`. All tensors are
    created on `model.device`.
    """
    batch_size, horizon = noise_shape[0], noise_shape[2]
    device = model.device
    # Draw order (img, action, state) determines which random values each stream gets.
    img_noise = torch.randn(noise_shape, device=device)
    action_noise = torch.randn((batch_size, horizon, model.agent_action_dim),
                               device=device)
    state_noise = torch.randn((batch_size, horizon, model.agent_state_dim),
                              device=device)
    return {'img': img_noise, 'action': action_noise, 'state': state_noise}
def load_psnr_lookup(psnr_path: str | None) -> dict[str, float]:
"""Load optional PSNR values keyed by sample_id or videoid."""
if not psnr_path:
return {}
if not os.path.exists(psnr_path):
logging.warning("PSNR file not found: %s", psnr_path)
return {}
suffix = os.path.splitext(psnr_path)[1].lower()
lookup: dict[str, float] = {}
if suffix == '.csv':
df = pd.read_csv(psnr_path)
key_column = 'sample_id' if 'sample_id' in df.columns else 'videoid'
value_column = 'psnr_full50' if 'psnr_full50' in df.columns else 'psnr'
for _, row in df.iterrows():
if pd.notna(row[key_column]) and pd.notna(row[value_column]):
lookup[str(row[key_column])] = float(row[value_column])
return lookup
if suffix == '.json':
with open(psnr_path, 'r', encoding='utf-8') as file:
data = json.load(file)
if isinstance(data, dict):
if 'sample_id' in data and ('psnr_full50' in data or 'psnr' in data):
lookup[str(data['sample_id'])] = float(
data.get('psnr_full50', data['psnr']))
else:
for key, value in data.items():
if isinstance(value, (int, float)):
lookup[str(key)] = float(value)
elif isinstance(data, list):
for item in data:
if not isinstance(item, dict):
continue
if 'sample_id' in item and ('psnr_full50' in item or 'psnr' in
item):
lookup[str(item['sample_id'])] = float(
item.get('psnr_full50', item['psnr']))
return lookup
logging.warning("Unsupported PSNR file format: %s", psnr_path)
return {}
class InteractionAnalysisLogger:
"""Collect stepwise metrics and aggregated per-sample summaries."""
STEP_COLUMNS = [
'sample_id',
'scene',
'pass_type',
'round_id',
'step',
'step_time_s',
'latent_delta',
'action_delta',
'state_delta',
'action_cosine_vs_full50',
'state_cosine_vs_full50',
'latent_l2_vs_full50',
]
SUMMARY_COLUMNS = [
'sample_id',
'scene',
'pass_type',
'pass_total_time_s',
'action_first_stable_step',
'state_first_stable_step',
'latent_first_stable_step',
'action_vs_full50_90pct_step',
'action_vs_full50_95pct_step',
'oracle_budget_action',
'oracle_budget_state',
'oracle_budget_latent',
'latent_init_dist_to_prev_round',
'action_drift_vs_prev_round',
'round_total_time_s',
'policy_pass_total_time_s',
'world_model_pass_total_time_s',
'psnr_full50',
]
ROUND_COLUMNS = [
'sample_id',
'scene',
'round_id',
'policy_pass_total_time_s',
'world_model_pass_total_time_s',
'round_total_time_s',
'latent_init_dist_to_prev_round',
'action_drift_vs_prev_round',
'psnr_full50',
]
def __init__(self, output_dir: str, psnr_lookup: dict[str, float]):
self.output_dir = output_dir
self.psnr_lookup = psnr_lookup
self.step_rows: list[dict] = []
self.summary_buckets: dict[tuple[str, str, str], dict] = {}
self.round_rows: list[dict] = []
self.round_buckets: dict[tuple[str, str], dict] = {}
self.prev_policy_action: dict[str, torch.Tensor] = {}
self.prev_world_latent: dict[str, torch.Tensor] = {}
def resolve_psnr(self, sample_id: str, videoid: int) -> float:
"""Resolve a PSNR value by full sample id first, then by raw video id."""
candidates = [sample_id, sample_id.rsplit('-fs', 1)[0], str(videoid)]
for candidate in candidates:
if candidate in self.psnr_lookup:
return float(self.psnr_lookup[candidate])
return np.nan
def append_summary_row(self, row: dict) -> None:
"""Store per-round summaries and aggregate them later by sample and pass type."""
key = (row['sample_id'], row['scene'], row['pass_type'])
metric_columns = self.SUMMARY_COLUMNS[3:]
if key not in self.summary_buckets:
self.summary_buckets[key] = {
'sample_id': row['sample_id'],
'scene': row['scene'],
'pass_type': row['pass_type'],
**{column: [] for column in metric_columns},
}
for column in metric_columns:
self.summary_buckets[key][column].append(row.get(column, np.nan))
def append_round_row(self, row: dict) -> None:
"""Store per-round metrics and aggregate them later by sample."""
self.round_rows.append(row)
key = (row['sample_id'], row['scene'])
metric_columns = self.ROUND_COLUMNS[3:]
if key not in self.round_buckets:
self.round_buckets[key] = {
'sample_id': row['sample_id'],
'scene': row['scene'],
**{column: [] for column in metric_columns},
}
for column in metric_columns:
self.round_buckets[key][column].append(row.get(column, np.nan))
def collect_trace_series(self, debug_info: dict, reference_action: torch.Tensor,
reference_state: torch.Tensor,
reference_latent: torch.Tensor) -> tuple[list[float], list[float], list[float]]:
"""Extract cosine/L2 curves for either the target pass or the full-50 reference pass."""
action_cosines = []
state_cosines = []
latent_l2s = []
for record in debug_info['step_records']:
action_cosines.append(
batch_cosine_similarity(record['action'], reference_action)[0])
state_cosines.append(
batch_cosine_similarity(record['state'], reference_state)[0])
latent_l2s.append(
batch_l2_distance(record['pred_x0'], reference_latent)[0])
return action_cosines, state_cosines, latent_l2s
def log_pass(self, sample_id: str, videoid: int, scene: str, pass_type: str,
round_id: int, pass_total_time_s: float, target_debug: dict,
reference_debug: dict) -> dict | None:
"""Log one pass worth of stepwise and aggregated metrics."""
if not target_debug or not target_debug.get('step_records'):
return None
if not reference_debug or not reference_debug.get('step_records'):
reference_debug = target_debug
reference_final_action = reference_debug['step_records'][-1]['action']
reference_final_state = reference_debug['step_records'][-1]['state']
reference_final_latent = reference_debug['step_records'][-1]['pred_x0']
prev_img = target_debug['analysis_init']['img']
prev_action = target_debug['analysis_init']['action']
prev_state = target_debug['analysis_init']['state']
action_deltas: list[float] = []
state_deltas: list[float] = []
latent_deltas: list[float] = []
action_cosines: list[float] = []
state_cosines: list[float] = []
latent_l2s: list[float] = []
for record in target_debug['step_records']:
latent_delta = batch_relative_l2(record['img'], prev_img)[0]
action_delta = batch_relative_l2(record['action'], prev_action)[0]
state_delta = batch_relative_l2(record['state'], prev_state)[0]
action_cosine = batch_cosine_similarity(record['action'],
reference_final_action)[0]
state_cosine = batch_cosine_similarity(record['state'],
reference_final_state)[0]
latent_l2 = batch_l2_distance(record['pred_x0'],
reference_final_latent)[0]
action_deltas.append(action_delta)
state_deltas.append(state_delta)
latent_deltas.append(latent_delta)
action_cosines.append(action_cosine)
state_cosines.append(state_cosine)
latent_l2s.append(latent_l2)
self.step_rows.append({
'sample_id': sample_id,
'scene': scene,
'pass_type': pass_type,
'round_id': round_id,
'step': record['step_index'],
'step_time_s': float(record['step_time_s']),
'latent_delta': latent_delta,
'action_delta': action_delta,
'state_delta': state_delta,
'action_cosine_vs_full50': action_cosine,
'state_cosine_vs_full50': state_cosine,
'latent_l2_vs_full50': latent_l2,
})
prev_img = record['img']
prev_action = record['action']
prev_state = record['state']
oracle_action_cosines, oracle_state_cosines, oracle_latent_l2s = self.collect_trace_series(
reference_debug, reference_final_action, reference_final_state,
reference_final_latent)
latent_init_dist_to_prev_round = np.nan
action_drift_vs_prev_round = np.nan
if pass_type == 'policy':
previous_action = self.prev_policy_action.get(sample_id)
if previous_action is not None:
action_drift_vs_prev_round = 1.0 - batch_cosine_similarity(
reference_final_action, previous_action)[0]
self.prev_policy_action[sample_id] = reference_final_action.clone()
elif pass_type == 'world_model':
previous_latent = self.prev_world_latent.get(sample_id)
if previous_latent is not None:
latent_init_dist_to_prev_round = batch_l2_distance(
reference_final_latent, previous_latent)[0]
self.prev_world_latent[sample_id] = reference_final_latent.clone()
summary_row = {
'sample_id': sample_id,
'scene': scene,
'pass_type': pass_type,
'pass_total_time_s': float(pass_total_time_s),
'action_first_stable_step': np.nan,
'state_first_stable_step': np.nan,
'latent_first_stable_step': np.nan,
'action_vs_full50_90pct_step': first_at_least(action_cosines, 0.90),
'action_vs_full50_95pct_step': first_at_least(action_cosines, 0.95),
'oracle_budget_action': first_at_least(oracle_action_cosines, 0.95),
'oracle_budget_state': first_at_least(oracle_state_cosines, 0.95),
'oracle_budget_latent': np.nan,
'latent_init_dist_to_prev_round': latent_init_dist_to_prev_round,
'action_drift_vs_prev_round': action_drift_vs_prev_round,
'round_total_time_s': np.nan,
'policy_pass_total_time_s': np.nan,
'world_model_pass_total_time_s': np.nan,
'psnr_full50': self.resolve_psnr(sample_id, videoid),
}
self.append_summary_row(summary_row)
return summary_row
def log_round(self, sample_id: str, videoid: int, scene: str, round_id: int,
              policy_pass_total_time_s: float,
              world_model_pass_total_time_s: float, round_total_time_s: float,
              latent_init_dist_to_prev_round: float,
              action_drift_vs_prev_round: float) -> None:
    """Log one interaction round consisting of one policy pass and one world-model pass.

    Builds a single round-level row and hands it to ``self.append_round_row``.
    Every numeric field is stored as a built-in ``float`` so the row is
    uniform regardless of the scalar type the caller supplies.

    Args:
        sample_id: Identifier of the sample this round belongs to.
        videoid: Video id, forwarded to ``self.resolve_psnr`` with ``sample_id``.
        scene: Scene name stored on the row.
        round_id: Index of the interaction round.
        policy_pass_total_time_s: Wall-clock duration of the policy pass.
        world_model_pass_total_time_s: Wall-clock duration of the world-model pass.
        round_total_time_s: Wall-clock duration of the whole round.
        latent_init_dist_to_prev_round: Latent distance to the previous round,
            or NaN/None when it was not measured.
        action_drift_vs_prev_round: Action drift relative to the previous
            round, or NaN/None when it was not measured.
    """

    def as_float(value) -> float:
        # None means "not measured"; record it as NaN like other missing metrics.
        # NOTE(review): these metrics come from indexing a batch helper's
        # result and may be 0-d tensors / NumPy scalars -- TODO confirm.
        # float() normalizes them so the CSV holds numbers, not object reprs.
        return float('nan') if value is None else float(value)

    self.append_round_row({
        'sample_id': sample_id,
        'scene': scene,
        'round_id': round_id,
        'policy_pass_total_time_s': float(policy_pass_total_time_s),
        'world_model_pass_total_time_s': float(world_model_pass_total_time_s),
        'round_total_time_s': float(round_total_time_s),
        'latent_init_dist_to_prev_round': as_float(latent_init_dist_to_prev_round),
        'action_drift_vs_prev_round': as_float(action_drift_vs_prev_round),
        'psnr_full50': self.resolve_psnr(sample_id, videoid),
    })
def flush(self) -> None:
    """Persist the collected step, round, and per-sample summary rows as CSVs.

    Creates ``self.output_dir`` if needed and writes, in this order,
    ``stepwise_log.csv``, ``round_summary.csv`` and ``sample_summary.csv``.
    Each summary row carries the bucket's ``sample_id``/``scene``/``pass_type``
    plus the mean of every metric column (``SUMMARY_COLUMNS[3:]``). For a given
    metric, the round-level bucket keyed by ``(sample_id, scene)`` is used when
    it exists and contains that column; otherwise the per-pass bucket is used.
    """
    out_dir = self.output_dir
    os.makedirs(out_dir, exist_ok=True)

    def dump(records, columns, filename):
        # One CSV per table, without the pandas index column.
        frame = pd.DataFrame(records, columns=columns)
        frame.to_csv(os.path.join(out_dir, filename), index=False)

    dump(self.step_rows, self.STEP_COLUMNS, 'stepwise_log.csv')
    dump(self.round_rows, self.ROUND_COLUMNS, 'round_summary.csv')

    identity_keys = ('sample_id', 'scene', 'pass_type')
    metric_names = self.SUMMARY_COLUMNS[3:]
    aggregated = []
    for pass_bucket in self.summary_buckets.values():
        round_level = self.round_buckets.get(
            (pass_bucket['sample_id'], pass_bucket['scene']))
        record = {key: pass_bucket[key] for key in identity_keys}
        for metric in metric_names:
            # Prefer round-level values whenever the round bucket tracks this metric.
            use_round = round_level is not None and metric in round_level
            source = round_level if use_round else pass_bucket
            record[metric] = safe_mean(source[metric])
        aggregated.append(record)

    dump(aggregated, self.SUMMARY_COLUMNS, 'sample_summary.csv')
def write_video(video_path: str, stacked_frames: list, fps: int) -> None:
@@ -312,22 +721,25 @@ def preprocess_observation(
return return_observations
def image_guided_synthesis_sim_mode(
model: torch.nn.Module,
prompts: list[str],
observation: dict,
noise_shape: tuple[int, int, int, int, int],
def image_guided_synthesis_sim_mode(
model: torch.nn.Module,
prompts: list[str],
observation: dict,
noise_shape: tuple[int, int, int, int, int],
action_cond_step: int = 16,
n_samples: int = 1,
ddim_steps: int = 50,
ddim_eta: float = 1.0,
unconditional_guidance_scale: float = 1.0,
fs: int | None = None,
text_input: bool = True,
timestep_spacing: str = 'uniform',
guidance_rescale: float = 0.0,
sim_mode: bool = True,
**kwargs) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
text_input: bool = True,
timestep_spacing: str = 'uniform',
guidance_rescale: float = 0.0,
sim_mode: bool = True,
init_noise_bundle: dict[str, torch.Tensor] | None = None,
decode_video: bool = True,
return_debug_info: bool = False,
**kwargs) -> tuple[torch.Tensor | None, torch.Tensor, torch.Tensor, dict | None]:
"""
Performs image-guided video generation in a simulation-style mode with optional multimodal guidance (image, state, action, text).
@@ -350,18 +762,24 @@ def image_guided_synthesis_sim_mode(
timestep_spacing (str): Timestep sampling method in DDIM sampler. Typically "uniform" or "linspace".
guidance_rescale (float): Guidance rescaling factor to mitigate overexposure from classifier-free guidance.
sim_mode (bool): Whether to perform world-model interaction or decision-making using the world-model.
**kwargs: Additional arguments passed to the DDIM sampler.
Returns:
batch_variants (torch.Tensor): Predicted pixel-space video frames [B, C, T, H, W].
actions (torch.Tensor): Predicted action sequences [B, T, D] from diffusion decoding.
states (torch.Tensor): Predicted state sequences [B, T, D] from diffusion decoding.
init_noise_bundle (dict[str, torch.Tensor] | None): Optional aligned noise inputs for latent/action/state.
decode_video (bool): Whether to decode the final latent into pixel space.
return_debug_info (bool): Whether to return per-step traces for analysis logging.
**kwargs: Additional arguments passed to the DDIM sampler.
Returns:
batch_variants (torch.Tensor | None): Predicted pixel-space video frames [B, C, T, H, W].
actions (torch.Tensor): Predicted action sequences [B, T, D] from diffusion decoding.
states (torch.Tensor): Predicted state sequences [B, T, D] from diffusion decoding.
debug_info (dict | None): Optional per-step trace used for convergence analysis.
"""
b, _, t, _, _ = noise_shape
ddim_sampler = DDIMSampler(model)
batch_size = noise_shape[0]
fs = torch.tensor([fs] * batch_size, dtype=torch.long, device=model.device)
b, _, t, _, _ = noise_shape
ddim_sampler = DDIMSampler(model)
batch_size = noise_shape[0]
batch_variants = None
debug_info = None
fs = torch.tensor([fs] * batch_size, dtype=torch.long, device=model.device)
img = observation['observation.images.top'].permute(0, 2, 1, 3, 4)
cond_img = rearrange(img, 'b o c h w -> (b o) c h w')[-1:]
@@ -406,11 +824,11 @@ def image_guided_synthesis_sim_mode(
kwargs.update({"unconditional_conditioning_img_nonetext": None})
cond_mask = None
cond_z0 = None
if ddim_sampler is not None:
samples, actions, states, intermedia = ddim_sampler.sample(
S=ddim_steps,
conditioning=cond,
batch_size=batch_size,
if ddim_sampler is not None:
samples, actions, states, intermedia = ddim_sampler.sample(
S=ddim_steps,
conditioning=cond,
batch_size=batch_size,
shape=noise_shape[1:],
verbose=False,
unconditional_guidance_scale=unconditional_guidance_scale,
@@ -418,20 +836,35 @@ def image_guided_synthesis_sim_mode(
eta=ddim_eta,
cfg_img=None,
mask=cond_mask,
x0=cond_z0,
fs=fs,
timestep_spacing=timestep_spacing,
guidance_rescale=guidance_rescale,
**kwargs)
# Reconstruct from latent to pixel space
batch_images = model.decode_first_stage(samples)
batch_variants = batch_images
return batch_variants, actions, states
def run_inference(args: argparse.Namespace, gpu_num: int, gpu_no: int) -> None:
x0=cond_z0,
fs=fs,
timestep_spacing=timestep_spacing,
guidance_rescale=guidance_rescale,
x_T=None if init_noise_bundle is None else init_noise_bundle['img'],
action_T=None if init_noise_bundle is None else
init_noise_bundle['action'],
state_T=None if init_noise_bundle is None else
init_noise_bundle['state'],
record_step_outputs=return_debug_info,
**kwargs)
batch_variants = None
if decode_video:
batch_variants = model.decode_first_stage(samples)
if return_debug_info:
debug_info = {
'analysis_init': intermedia.get('analysis_init'),
'step_records': intermedia.get('step_records', []),
'final_latent': samples.detach().cpu(),
'final_action': actions.detach().cpu(),
'final_state': states.detach().cpu(),
}
return batch_variants, actions, states, debug_info
def run_inference(args: argparse.Namespace, gpu_num: int, gpu_no: int) -> None:
"""
Run inference pipeline on prompts and image inputs.
@@ -443,11 +876,18 @@ def run_inference(args: argparse.Namespace, gpu_num: int, gpu_no: int) -> None:
Returns:
None
"""
# Create inference and tensorboard dirs
os.makedirs(args.savedir + '/inference', exist_ok=True)
log_dir = args.savedir + f"/tensorboard"
os.makedirs(log_dir, exist_ok=True)
writer = SummaryWriter(log_dir=log_dir)
# Create inference and tensorboard dirs
inference_dir = args.savedir + '/inference'
os.makedirs(inference_dir, exist_ok=True)
log_dir = args.savedir + f"/tensorboard"
os.makedirs(log_dir, exist_ok=True)
writer = SummaryWriter(log_dir=log_dir)
analysis_logger = None
if args.analysis_log_metrics:
analysis_logger = InteractionAnalysisLogger(
output_dir=inference_dir,
psnr_lookup=load_psnr_lookup(args.analysis_psnr_path),
)
# Load prompt
csv_path = os.path.join(args.prompt_dir, f"{args.dataset}.csv")
@@ -474,10 +914,14 @@ def run_inference(args: argparse.Namespace, gpu_num: int, gpu_no: int) -> None:
device = get_device_from_parameters(model)
# Run over data
assert (args.height % 16 == 0) and (
args.width % 16
== 0), "Error: image size [h,w] should be multiples of 16!"
assert args.bs == 1, "Current implementation only support [batch size = 1]!"
assert (args.height % 16 == 0) and (
args.width % 16
== 0), "Error: image size [h,w] should be multiples of 16!"
assert args.bs == 1, "Current implementation only support [batch size = 1]!"
if args.analysis_log_metrics:
assert args.ddim_steps > 0, "analysis_log_metrics requires positive --ddim_steps."
assert args.analysis_reference_steps > 0, (
"analysis_log_metrics requires positive --analysis_reference_steps.")
# Get latent noise shape
h, w = args.height // 8, args.width // 8
@@ -508,12 +952,14 @@ def run_inference(args: argparse.Namespace, gpu_num: int, gpu_no: int) -> None:
for key in h5f.attrs.keys():
transition_dict[key] = h5f.attrs[key]
# If many, test various frequence control and world-model generation
for fs in args.frame_stride:
# For saving imagens in policy
sample_save_dir = f'{video_save_dir}/dm/{fs}'
os.makedirs(sample_save_dir, exist_ok=True)
# If many, test various frequence control and world-model generation
for fs in args.frame_stride:
sample_id = build_sample_id(args.dataset, sample, fs)
scene = get_scene_name(sample, args.dataset)
# For saving imagens in policy
sample_save_dir = f'{video_save_dir}/dm/{fs}'
os.makedirs(sample_save_dir, exist_ok=True)
# For saving environmental changes in world-model
sample_save_dir = f'{video_save_dir}/wm/{fs}'
os.makedirs(sample_save_dir, exist_ok=True)
@@ -552,11 +998,12 @@ def run_inference(args: argparse.Namespace, gpu_num: int, gpu_no: int) -> None:
# Update observation queues
cond_obs_queues = populate_queues(cond_obs_queues, observation)
# Multi-round interaction with the world-model
for itr in tqdm(range(args.n_iter)):
# Get observation
observation = {
# Multi-round interaction with the world-model
for itr in tqdm(range(args.n_iter)):
round_start_time = time.time()
# Get observation
observation = {
'observation.images.top':
torch.stack(list(
cond_obs_queues['observation.images.top']),
@@ -566,33 +1013,72 @@ def run_inference(args: argparse.Namespace, gpu_num: int, gpu_no: int) -> None:
dim=1),
'action':
torch.stack(list(cond_obs_queues['action']), dim=1),
}
observation = {
key: observation[key].to(device, non_blocking=True)
for key in observation
}
# Use world-model in policy to generate action
print(f'>>> Step {itr}: generating actions ...')
pred_videos_0, pred_actions, _ = image_guided_synthesis_sim_mode(
model,
sample['instruction'],
observation,
noise_shape,
action_cond_step=args.exe_steps,
}
observation = {
key: observation[key].to(device, non_blocking=True)
for key in observation
}
# Use world-model in policy to generate action
print(f'>>> Step {itr}: generating actions ...')
policy_noise_bundle = make_sampling_noise_bundle(
model, noise_shape) if args.analysis_log_metrics else None
policy_reference_debug = None
if args.analysis_log_metrics and args.analysis_reference_steps != args.ddim_steps:
_, _, _, policy_reference_debug = image_guided_synthesis_sim_mode(
model,
sample['instruction'],
observation,
noise_shape,
action_cond_step=args.exe_steps,
ddim_steps=args.analysis_reference_steps,
ddim_eta=args.ddim_eta,
unconditional_guidance_scale=args.
unconditional_guidance_scale,
fs=model_input_fs,
timestep_spacing=args.timestep_spacing,
guidance_rescale=args.guidance_rescale,
sim_mode=False,
init_noise_bundle=policy_noise_bundle,
decode_video=False,
return_debug_info=True)
policy_pass_start = time.time()
pred_videos_0, pred_actions, _, policy_debug = image_guided_synthesis_sim_mode(
model,
sample['instruction'],
observation,
noise_shape,
action_cond_step=args.exe_steps,
ddim_steps=args.ddim_steps,
ddim_eta=args.ddim_eta,
unconditional_guidance_scale=args.
unconditional_guidance_scale,
fs=model_input_fs,
timestep_spacing=args.timestep_spacing,
guidance_rescale=args.guidance_rescale,
sim_mode=False)
# Update future actions in the observation queues
for idx in range(len(pred_actions[0])):
observation = {'action': pred_actions[0][idx:idx + 1]}
observation['action'][:, ori_action_dim:] = 0.0
unconditional_guidance_scale=args.
unconditional_guidance_scale,
fs=model_input_fs,
timestep_spacing=args.timestep_spacing,
guidance_rescale=args.guidance_rescale,
sim_mode=False,
init_noise_bundle=policy_noise_bundle,
return_debug_info=args.analysis_log_metrics)
policy_pass_total_time_s = time.time() - policy_pass_start
policy_summary_row = None
if analysis_logger is not None:
if policy_reference_debug is None:
policy_reference_debug = policy_debug
policy_summary_row = analysis_logger.log_pass(
sample_id=sample_id,
videoid=int(sample['videoid']),
scene=scene,
pass_type='policy',
round_id=itr,
pass_total_time_s=policy_pass_total_time_s,
target_debug=policy_debug,
reference_debug=policy_reference_debug,
)
# Update future actions in the observation queues
for idx in range(len(pred_actions[0])):
observation = {'action': pred_actions[0][idx:idx + 1]}
observation['action'][:, ori_action_dim:] = 0.0
cond_obs_queues = populate_queues(cond_obs_queues,
observation)
@@ -611,29 +1097,84 @@ def run_inference(args: argparse.Namespace, gpu_num: int, gpu_no: int) -> None:
observation = {
key: observation[key].to(device, non_blocking=True)
for key in observation
}
# Interaction with the world-model
print(f'>>> Step {itr}: interacting with world model ...')
pred_videos_1, _, pred_states = image_guided_synthesis_sim_mode(
model,
"",
observation,
noise_shape,
action_cond_step=args.exe_steps,
}
# Interaction with the world-model
print(f'>>> Step {itr}: interacting with world model ...')
world_noise_bundle = make_sampling_noise_bundle(
model, noise_shape) if args.analysis_log_metrics else None
world_reference_debug = None
if args.analysis_log_metrics and args.analysis_reference_steps != args.ddim_steps:
_, _, _, world_reference_debug = image_guided_synthesis_sim_mode(
model,
"",
observation,
noise_shape,
action_cond_step=args.exe_steps,
ddim_steps=args.analysis_reference_steps,
ddim_eta=args.ddim_eta,
unconditional_guidance_scale=args.
unconditional_guidance_scale,
fs=model_input_fs,
text_input=False,
timestep_spacing=args.timestep_spacing,
guidance_rescale=args.guidance_rescale,
init_noise_bundle=world_noise_bundle,
decode_video=False,
return_debug_info=True)
world_pass_start = time.time()
pred_videos_1, _, pred_states, world_debug = image_guided_synthesis_sim_mode(
model,
"",
observation,
noise_shape,
action_cond_step=args.exe_steps,
ddim_steps=args.ddim_steps,
ddim_eta=args.ddim_eta,
unconditional_guidance_scale=args.
unconditional_guidance_scale,
fs=model_input_fs,
text_input=False,
timestep_spacing=args.timestep_spacing,
guidance_rescale=args.guidance_rescale)
for idx in range(args.exe_steps):
observation = {
'observation.images.top':
pred_videos_1[0][:, idx:idx + 1].permute(1, 0, 2, 3),
unconditional_guidance_scale=args.
unconditional_guidance_scale,
fs=model_input_fs,
text_input=False,
timestep_spacing=args.timestep_spacing,
guidance_rescale=args.guidance_rescale,
init_noise_bundle=world_noise_bundle,
return_debug_info=args.analysis_log_metrics)
world_pass_total_time_s = time.time() - world_pass_start
world_summary_row = None
if analysis_logger is not None:
if world_reference_debug is None:
world_reference_debug = world_debug
world_summary_row = analysis_logger.log_pass(
sample_id=sample_id,
videoid=int(sample['videoid']),
scene=scene,
pass_type='world_model',
round_id=itr,
pass_total_time_s=world_pass_total_time_s,
target_debug=world_debug,
reference_debug=world_reference_debug,
)
analysis_logger.log_round(
sample_id=sample_id,
videoid=int(sample['videoid']),
scene=scene,
round_id=itr,
policy_pass_total_time_s=policy_pass_total_time_s,
world_model_pass_total_time_s=
world_pass_total_time_s,
round_total_time_s=time.time() - round_start_time,
latent_init_dist_to_prev_round=np.nan
if world_summary_row is None else
world_summary_row['latent_init_dist_to_prev_round'],
action_drift_vs_prev_round=np.nan
if policy_summary_row is None else
policy_summary_row['action_drift_vs_prev_round'],
)
for idx in range(args.exe_steps):
observation = {
'observation.images.top':
pred_videos_1[0][:, idx:idx + 1].permute(1, 0, 2, 3),
'observation.state':
torch.zeros_like(pred_states[0][idx:idx + 1]) if
args.zero_pred_state else pred_states[0][idx:idx + 1],
@@ -678,8 +1219,12 @@ def run_inference(args: argparse.Namespace, gpu_num: int, gpu_no: int) -> None:
full_video,
sample_tag,
fps=args.save_fps)
sample_full_video_file = f"{video_save_dir}/../{sample['videoid']}_full_fs{fs}.mp4"
save_results(full_video, sample_full_video_file, fps=args.save_fps)
sample_full_video_file = f"{video_save_dir}/../{sample['videoid']}_full_fs{fs}.mp4"
save_results(full_video, sample_full_video_file, fps=args.save_fps)
if analysis_logger is not None:
analysis_logger.flush()
writer.close()
def get_parser():
@@ -794,11 +1339,25 @@ def get_parser():
action='store_true',
default=False,
help="not using the predicted states as comparison")
parser.add_argument("--save_fps",
type=int,
default=8,
help="fps for the saving video")
return parser
parser.add_argument("--save_fps",
type=int,
default=8,
help="fps for the saving video")
parser.add_argument("--analysis_log_metrics",
action='store_true',
default=False,
help="Enable DDIM convergence logging and export analysis CSVs.")
parser.add_argument(
"--analysis_reference_steps",
type=int,
default=50,
help="Reference DDIM steps used to build the full-step baseline for *_vs_full50 metrics."
)
parser.add_argument("--analysis_psnr_path",
type=str,
default=None,
help="Optional CSV/JSON file with psnr_full50 values keyed by sample_id or videoid.")
return parser
if __name__ == '__main__':

View File

@@ -0,0 +1,58 @@
#!/usr/bin/env bash
# Run every world-model interaction case script and optionally analyze its logs.
#
# Environment:
#   RUN_ANALYSIS  "1" (default) runs analyze_metrics.py after each case;
#                 any other value skips the analysis step.
#   CASE_FILTER   When non-empty, only case scripts whose relative path
#                 contains this substring are executed.
set -euo pipefail

repo_root="$(cd "$(dirname "${BASH_SOURCE[0]}")/.." && pwd)"
cd "${repo_root}"

analysis_enabled="${RUN_ANALYSIS:-1}"
path_filter="${CASE_FILTER:-}"

datasets=(
    "unitree_z1_stackbox"
    "unitree_z1_dual_arm_stackbox"
    "unitree_z1_dual_arm_stackbox_v2"
    "unitree_z1_dual_arm_cleanup_pencils"
    "unitree_g1_pack_camera"
)

# Build the ordered list of case scripts: four cases per dataset.
case_scripts=()
for dataset in "${datasets[@]}"; do
    for case_no in 1 2 3 4; do
        case_scripts+=("${dataset}/case${case_no}/run_world_model_interaction.sh")
    done
done

# Files the analysis step needs; all must exist before it is attempted.
required_csvs=(
    "stepwise_log.csv"
    "sample_summary.csv"
    "round_summary.csv"
)

for case_script in "${case_scripts[@]}"; do
    # Honor the optional substring filter.
    if [[ -n "${path_filter}" && "${case_script}" != *"${path_filter}"* ]]; then
        continue
    fi

    case_dir="$(dirname "${case_script}")"
    inference_dir="${repo_root}/${case_dir}/output/inference"

    echo "============================================================"
    echo "Running ${case_script}"
    echo "============================================================"
    bash "${repo_root}/${case_script}"

    if [[ "${analysis_enabled}" != "1" ]]; then
        continue
    fi

    have_all_csvs=1
    for csv_name in "${required_csvs[@]}"; do
        if [[ ! -f "${inference_dir}/${csv_name}" ]]; then
            have_all_csvs=0
            break
        fi
    done

    if [[ "${have_all_csvs}" == "1" ]]; then
        echo "Analyzing ${case_dir}"
        python3 "${repo_root}/scripts/evaluation/analyze_metrics.py" \
            --input_dir "${inference_dir}"
    else
        echo "Skipping analysis for ${case_dir}: missing analysis CSVs."
    fi
done

View File

@@ -3,6 +3,8 @@ ckpt=/path/to/model/checkpoint
config=configs/inference/world_model_interaction.yaml
seed=123
res_dir="/path/to/result/directory"
analysis_reference_steps=50
# analysis_psnr_path="/path/to/psnr_full50.csv"
datasets=(
"unitree_z1_stackbox"
@@ -19,6 +21,13 @@ for i in "${!datasets[@]}"; do
dataset=${datasets[$i]}
n_iter=${n_iters[$i]}
fs=${fses[$i]}
analysis_args=(
--analysis_log_metrics
--analysis_reference_steps ${analysis_reference_steps}
)
if [ -n "${analysis_psnr_path:-}" ]; then
analysis_args+=(--analysis_psnr_path "${analysis_psnr_path}")
fi
CUDA_VISIBLE_DEVICES=0 python3 scripts/evaluation/world_model_interaction.py \
--seed ${seed} \
@@ -38,5 +47,6 @@ for i in "${!datasets[@]}"; do
--n_iter ${n_iter} \
--timestep_spacing 'uniform_trailing' \
--guidance_rescale 0.7 \
--perframe_ae
--perframe_ae \
"${analysis_args[@]}"
done