早停特征验证,早停不通
This commit is contained in:
611
scripts/evaluation/analyze_metrics.py
Normal file
611
scripts/evaluation/analyze_metrics.py
Normal file
@@ -0,0 +1,611 @@
|
||||
import argparse
|
||||
import json
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
|
||||
try:
|
||||
import matplotlib.pyplot as plt
|
||||
except ModuleNotFoundError:
|
||||
plt = None
|
||||
|
||||
|
||||
def parse_args() -> argparse.Namespace:
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Analyze DDIM convergence logs for the five research directions."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--input_dir",
|
||||
type=str,
|
||||
required=True,
|
||||
help="Directory containing stepwise_log.csv, sample_summary.csv, and round_summary.csv.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--output_dir",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Directory to write analysis outputs. Defaults to <input_dir>/analysis.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--mid_step_start",
|
||||
type=int,
|
||||
default=15,
|
||||
help="Start of the step window used for cross-step latent-delta analysis.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--mid_step_end",
|
||||
type=int,
|
||||
default=35,
|
||||
help="End of the step window used for cross-step latent-delta analysis.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--full50_step_cap",
|
||||
type=int,
|
||||
default=50,
|
||||
help="Upper bound used when summarizing step-based metrics.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--stability_threshold",
|
||||
type=float,
|
||||
default=None,
|
||||
help="Optional threshold for offline first_stable_step analysis. If unset, stable-step metrics are not computed.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--stability_window",
|
||||
type=int,
|
||||
default=3,
|
||||
help="Consecutive-step window for offline first_stable_step analysis.",
|
||||
)
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
def safe_float(value) -> float:
|
||||
if pd.isna(value):
|
||||
return float("nan")
|
||||
return float(value)
|
||||
|
||||
|
||||
def describe_series(series: pd.Series) -> dict[str, float]:
|
||||
series = pd.to_numeric(series, errors="coerce").dropna()
|
||||
if series.empty:
|
||||
return {
|
||||
"count": 0.0,
|
||||
"mean": float("nan"),
|
||||
"std": float("nan"),
|
||||
"min": float("nan"),
|
||||
"p25": float("nan"),
|
||||
"median": float("nan"),
|
||||
"p75": float("nan"),
|
||||
"p90": float("nan"),
|
||||
"p95": float("nan"),
|
||||
"max": float("nan"),
|
||||
}
|
||||
return {
|
||||
"count": float(series.count()),
|
||||
"mean": float(series.mean()),
|
||||
"std": float(series.std(ddof=0)),
|
||||
"min": float(series.min()),
|
||||
"p25": float(series.quantile(0.25)),
|
||||
"median": float(series.median()),
|
||||
"p75": float(series.quantile(0.75)),
|
||||
"p90": float(series.quantile(0.90)),
|
||||
"p95": float(series.quantile(0.95)),
|
||||
"max": float(series.max()),
|
||||
}
|
||||
|
||||
|
||||
def ensure_columns(frame: pd.DataFrame, required: list[str], frame_name: str) -> None:
|
||||
missing = [column for column in required if column not in frame.columns]
|
||||
if missing:
|
||||
raise ValueError(f"{frame_name} is missing required columns: {missing}")
|
||||
|
||||
|
||||
def first_consecutive_below(values: pd.Series, threshold: float,
|
||||
window: int) -> float:
|
||||
cleaned = pd.to_numeric(values, errors="coerce").tolist()
|
||||
if window <= 0 or len(cleaned) < window:
|
||||
return float("nan")
|
||||
for start in range(len(cleaned) - window + 1):
|
||||
chunk = cleaned[start:start + window]
|
||||
if all(pd.notna(value) and value < threshold for value in chunk):
|
||||
return float(start + 1)
|
||||
return float("nan")
|
||||
|
||||
|
||||
def load_tables(input_dir: Path) -> tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame]:
|
||||
stepwise_path = input_dir / "stepwise_log.csv"
|
||||
sample_summary_path = input_dir / "sample_summary.csv"
|
||||
round_summary_path = input_dir / "round_summary.csv"
|
||||
|
||||
if not stepwise_path.exists():
|
||||
raise FileNotFoundError(f"Missing file: {stepwise_path}")
|
||||
if not sample_summary_path.exists():
|
||||
raise FileNotFoundError(f"Missing file: {sample_summary_path}")
|
||||
if not round_summary_path.exists():
|
||||
raise FileNotFoundError(f"Missing file: {round_summary_path}")
|
||||
|
||||
stepwise_df = pd.read_csv(stepwise_path)
|
||||
sample_summary_df = pd.read_csv(sample_summary_path)
|
||||
round_summary_df = pd.read_csv(round_summary_path)
|
||||
|
||||
ensure_columns(
|
||||
stepwise_df,
|
||||
[
|
||||
"sample_id",
|
||||
"scene",
|
||||
"pass_type",
|
||||
"round_id",
|
||||
"step",
|
||||
"latent_delta",
|
||||
"action_delta",
|
||||
"state_delta",
|
||||
"action_cosine_vs_full50",
|
||||
"state_cosine_vs_full50",
|
||||
"latent_l2_vs_full50",
|
||||
],
|
||||
"stepwise_log.csv",
|
||||
)
|
||||
ensure_columns(
|
||||
sample_summary_df,
|
||||
[
|
||||
"sample_id",
|
||||
"scene",
|
||||
"pass_type",
|
||||
"pass_total_time_s",
|
||||
"action_first_stable_step",
|
||||
"state_first_stable_step",
|
||||
"latent_first_stable_step",
|
||||
"action_vs_full50_90pct_step",
|
||||
"action_vs_full50_95pct_step",
|
||||
"oracle_budget_action",
|
||||
"oracle_budget_state",
|
||||
],
|
||||
"sample_summary.csv",
|
||||
)
|
||||
ensure_columns(
|
||||
round_summary_df,
|
||||
[
|
||||
"sample_id",
|
||||
"scene",
|
||||
"round_id",
|
||||
"policy_pass_total_time_s",
|
||||
"world_model_pass_total_time_s",
|
||||
"round_total_time_s",
|
||||
"latent_init_dist_to_prev_round",
|
||||
"action_drift_vs_prev_round",
|
||||
],
|
||||
"round_summary.csv",
|
||||
)
|
||||
return stepwise_df, sample_summary_df, round_summary_df
|
||||
|
||||
|
||||
def aggregate_stepwise(stepwise_df: pd.DataFrame) -> pd.DataFrame:
|
||||
metric_columns = [
|
||||
"step_time_s",
|
||||
"latent_delta",
|
||||
"action_delta",
|
||||
"state_delta",
|
||||
"action_cosine_vs_full50",
|
||||
"state_cosine_vs_full50",
|
||||
"latent_l2_vs_full50",
|
||||
]
|
||||
stepwise_df = stepwise_df.copy()
|
||||
for column in metric_columns:
|
||||
stepwise_df[column] = pd.to_numeric(stepwise_df[column], errors="coerce")
|
||||
grouped = stepwise_df.groupby(["pass_type", "step"])[metric_columns]
|
||||
return grouped.agg(["mean", "median", "std"]).reset_index()
|
||||
|
||||
|
||||
def aggregate_scene_summary(sample_summary_df: pd.DataFrame) -> pd.DataFrame:
|
||||
metrics = [
|
||||
"pass_total_time_s",
|
||||
"action_first_stable_step",
|
||||
"state_first_stable_step",
|
||||
"latent_first_stable_step",
|
||||
"action_vs_full50_90pct_step",
|
||||
"action_vs_full50_95pct_step",
|
||||
"oracle_budget_action",
|
||||
"oracle_budget_state",
|
||||
"round_total_time_s",
|
||||
"policy_pass_total_time_s",
|
||||
"world_model_pass_total_time_s",
|
||||
"latent_init_dist_to_prev_round",
|
||||
"action_drift_vs_prev_round",
|
||||
]
|
||||
sample_summary_df = sample_summary_df.copy()
|
||||
for column in metrics:
|
||||
sample_summary_df[column] = pd.to_numeric(sample_summary_df[column],
|
||||
errors="coerce")
|
||||
grouped = sample_summary_df.groupby(["scene", "pass_type"])[metrics]
|
||||
return grouped.agg(["mean", "median", "std"]).reset_index()
|
||||
|
||||
|
||||
def compute_stability_summary(stepwise_df: pd.DataFrame, threshold: float,
|
||||
window: int) -> pd.DataFrame:
|
||||
rows = []
|
||||
grouped = stepwise_df.sort_values("step").groupby(
|
||||
["sample_id", "scene", "pass_type"])
|
||||
for (sample_id, scene, pass_type), group in grouped:
|
||||
rows.append({
|
||||
"sample_id": sample_id,
|
||||
"scene": scene,
|
||||
"pass_type": pass_type,
|
||||
"action_first_stable_step": first_consecutive_below(
|
||||
group["action_delta"], threshold, window),
|
||||
"state_first_stable_step": first_consecutive_below(
|
||||
group["state_delta"], threshold, window),
|
||||
"latent_first_stable_step": first_consecutive_below(
|
||||
group["latent_delta"], threshold, window),
|
||||
})
|
||||
return pd.DataFrame(rows)
|
||||
|
||||
|
||||
def compute_direction_summary(stepwise_df: pd.DataFrame,
|
||||
sample_summary_df: pd.DataFrame,
|
||||
round_summary_df: pd.DataFrame,
|
||||
mid_step_start: int,
|
||||
mid_step_end: int,
|
||||
stability_threshold: float | None,
|
||||
stability_window: int) -> dict:
|
||||
policy_summary = sample_summary_df[
|
||||
sample_summary_df["pass_type"] == "policy"].copy()
|
||||
world_summary = sample_summary_df[
|
||||
sample_summary_df["pass_type"] == "world_model"].copy()
|
||||
|
||||
latent_mid = stepwise_df[
|
||||
(stepwise_df["step"] >= mid_step_start)
|
||||
& (stepwise_df["step"] <= mid_step_end)].copy()
|
||||
|
||||
action_before_latent = policy_summary[[
|
||||
"action_first_stable_step",
|
||||
"latent_first_stable_step",
|
||||
"action_vs_full50_95pct_step",
|
||||
]].dropna()
|
||||
if action_before_latent.empty:
|
||||
action_first_beats_latent = float("nan")
|
||||
action_95_beats_latent = float("nan")
|
||||
else:
|
||||
action_first_beats_latent = float(
|
||||
(action_before_latent["action_first_stable_step"] <
|
||||
action_before_latent["latent_first_stable_step"]).mean())
|
||||
action_95_beats_latent = float(
|
||||
(action_before_latent["action_vs_full50_95pct_step"] <
|
||||
action_before_latent["latent_first_stable_step"]).mean())
|
||||
|
||||
direction_summary = {
|
||||
"analysis_config": {
|
||||
"stability_threshold": stability_threshold,
|
||||
"stability_window": int(stability_window),
|
||||
"mid_step_window": [int(mid_step_start), int(mid_step_end)],
|
||||
},
|
||||
"dataset_overview": {
|
||||
"num_step_rows": int(len(stepwise_df)),
|
||||
"num_sample_rows": int(len(sample_summary_df)),
|
||||
"num_round_rows": int(len(round_summary_df)),
|
||||
"num_unique_samples": int(sample_summary_df["sample_id"].nunique()),
|
||||
"scenes": sorted(sample_summary_df["scene"].dropna().unique().tolist()),
|
||||
"pass_types": sorted(sample_summary_df["pass_type"].dropna().unique().tolist()),
|
||||
},
|
||||
"direction_original_early_stop": {
|
||||
"latent_first_stable_step_policy":
|
||||
describe_series(policy_summary["latent_first_stable_step"]),
|
||||
"latent_first_stable_step_world_model":
|
||||
describe_series(world_summary["latent_first_stable_step"]),
|
||||
"latent_l2_vs_full50":
|
||||
describe_series(stepwise_df["latent_l2_vs_full50"]),
|
||||
},
|
||||
"direction_c_action_converges_first": {
|
||||
"action_first_stable_step":
|
||||
describe_series(policy_summary["action_first_stable_step"]),
|
||||
"latent_first_stable_step":
|
||||
describe_series(policy_summary["latent_first_stable_step"]),
|
||||
"action_vs_full50_95pct_step":
|
||||
describe_series(policy_summary["action_vs_full50_95pct_step"]),
|
||||
"share_action_first_stable_before_latent":
|
||||
action_first_beats_latent,
|
||||
"share_action_95pct_before_latent":
|
||||
action_95_beats_latent,
|
||||
},
|
||||
"direction_d_cross_step_similarity": {
|
||||
"latent_delta_mid_steps":
|
||||
describe_series(latent_mid["latent_delta"]),
|
||||
"action_delta_mid_steps":
|
||||
describe_series(latent_mid["action_delta"]),
|
||||
"state_delta_mid_steps":
|
||||
describe_series(latent_mid["state_delta"]),
|
||||
},
|
||||
"direction_a_budget_heterogeneity": {
|
||||
"oracle_budget_action":
|
||||
describe_series(policy_summary["oracle_budget_action"]),
|
||||
"oracle_budget_state":
|
||||
describe_series(world_summary["oracle_budget_state"]),
|
||||
"oracle_budget_action_by_scene": {
|
||||
scene: describe_series(group["oracle_budget_action"])
|
||||
for scene, group in policy_summary.groupby("scene")
|
||||
},
|
||||
},
|
||||
"direction_b_round_reuse": {
|
||||
"latent_init_dist_to_prev_round":
|
||||
describe_series(round_summary_df["latent_init_dist_to_prev_round"]),
|
||||
"action_drift_vs_prev_round":
|
||||
describe_series(round_summary_df["action_drift_vs_prev_round"]),
|
||||
"round_total_time_s":
|
||||
describe_series(round_summary_df["round_total_time_s"]),
|
||||
"policy_pass_total_time_s":
|
||||
describe_series(round_summary_df["policy_pass_total_time_s"]),
|
||||
"world_model_pass_total_time_s":
|
||||
describe_series(round_summary_df["world_model_pass_total_time_s"]),
|
||||
},
|
||||
}
|
||||
return direction_summary
|
||||
|
||||
|
||||
def save_json(path: Path, payload: dict) -> None:
|
||||
with open(path, "w", encoding="utf-8") as file:
|
||||
json.dump(payload, file, indent=2, ensure_ascii=False)
|
||||
|
||||
|
||||
def write_markdown_report(path: Path, direction_summary: dict) -> None:
|
||||
lines = [
|
||||
"# DDIM Analysis Report",
|
||||
"",
|
||||
"## Analysis Config",
|
||||
"",
|
||||
]
|
||||
config = direction_summary["analysis_config"]
|
||||
lines.extend([
|
||||
f"- Stability threshold: {config['stability_threshold']}",
|
||||
f"- Stability window: {config['stability_window']}",
|
||||
f"- Mid-step window: {config['mid_step_window'][0]}-{config['mid_step_window'][1]}",
|
||||
"",
|
||||
"## Dataset Overview",
|
||||
"",
|
||||
])
|
||||
overview = direction_summary["dataset_overview"]
|
||||
lines.extend([
|
||||
f"- Unique samples: {overview['num_unique_samples']}",
|
||||
f"- Step rows: {overview['num_step_rows']}",
|
||||
f"- Sample rows: {overview['num_sample_rows']}",
|
||||
f"- Round rows: {overview['num_round_rows']}",
|
||||
f"- Scenes: {', '.join(overview['scenes'])}",
|
||||
f"- Pass types: {', '.join(overview['pass_types'])}",
|
||||
"",
|
||||
"## Five Directions",
|
||||
"",
|
||||
])
|
||||
|
||||
def append_stats(title: str, payload: dict) -> None:
|
||||
lines.append(f"### {title}")
|
||||
lines.append("")
|
||||
for key, value in payload.items():
|
||||
if isinstance(value, dict):
|
||||
if {"median", "mean", "p90"} <= set(value.keys()):
|
||||
lines.append(
|
||||
f"- `{key}`: mean={value['mean']:.4f}, median={value['median']:.4f}, p90={value['p90']:.4f}, p95={value['p95']:.4f}"
|
||||
)
|
||||
else:
|
||||
lines.append(f"- `{key}`:")
|
||||
for sub_key, sub_value in value.items():
|
||||
if isinstance(sub_value, dict) and {"median", "mean"} <= set(
|
||||
sub_value.keys()):
|
||||
lines.append(
|
||||
f" - `{sub_key}`: mean={sub_value['mean']:.4f}, median={sub_value['median']:.4f}, std={sub_value['std']:.4f}"
|
||||
)
|
||||
else:
|
||||
lines.append(f" - `{sub_key}`: {sub_value}")
|
||||
else:
|
||||
if isinstance(value, float):
|
||||
lines.append(f"- `{key}`: {value:.4f}")
|
||||
else:
|
||||
lines.append(f"- `{key}`: {value}")
|
||||
lines.append("")
|
||||
|
||||
append_stats("Original Early Stop",
|
||||
direction_summary["direction_original_early_stop"])
|
||||
append_stats("Direction C: Action Converges First",
|
||||
direction_summary["direction_c_action_converges_first"])
|
||||
append_stats("Direction D: Cross-Step Similarity",
|
||||
direction_summary["direction_d_cross_step_similarity"])
|
||||
append_stats("Direction A: Budget Heterogeneity",
|
||||
direction_summary["direction_a_budget_heterogeneity"])
|
||||
append_stats("Direction B: Round Reuse",
|
||||
direction_summary["direction_b_round_reuse"])
|
||||
|
||||
with open(path, "w", encoding="utf-8") as file:
|
||||
file.write("\n".join(lines) + "\n")
|
||||
|
||||
|
||||
def save_convergence_plot(stepwise_df: pd.DataFrame, output_path: Path) -> None:
|
||||
if plt is None:
|
||||
return
|
||||
fig, axes = plt.subplots(2, 2, figsize=(14, 10))
|
||||
pass_types = sorted(stepwise_df["pass_type"].dropna().unique().tolist())
|
||||
for pass_type in pass_types:
|
||||
subset = stepwise_df[stepwise_df["pass_type"] == pass_type]
|
||||
grouped = subset.groupby("step").mean(numeric_only=True).reset_index()
|
||||
axes[0, 0].plot(grouped["step"],
|
||||
grouped["latent_delta"],
|
||||
label=pass_type)
|
||||
axes[0, 1].plot(grouped["step"],
|
||||
grouped["action_cosine_vs_full50"],
|
||||
label=pass_type)
|
||||
axes[1, 0].plot(grouped["step"],
|
||||
grouped["state_cosine_vs_full50"],
|
||||
label=pass_type)
|
||||
axes[1, 1].plot(grouped["step"],
|
||||
grouped["latent_l2_vs_full50"],
|
||||
label=pass_type)
|
||||
|
||||
axes[0, 0].set_title("Mean Latent Delta")
|
||||
axes[0, 1].set_title("Mean Action Cosine vs Full50")
|
||||
axes[1, 0].set_title("Mean State Cosine vs Full50")
|
||||
axes[1, 1].set_title("Mean Latent L2 vs Full50")
|
||||
for axis in axes.flatten():
|
||||
axis.set_xlabel("Step")
|
||||
axis.grid(alpha=0.3)
|
||||
axis.legend()
|
||||
fig.tight_layout()
|
||||
fig.savefig(output_path, dpi=200, bbox_inches="tight")
|
||||
plt.close(fig)
|
||||
|
||||
|
||||
def save_budget_distribution_plot(sample_summary_df: pd.DataFrame,
|
||||
output_path: Path) -> None:
|
||||
if plt is None:
|
||||
return
|
||||
policy_summary = sample_summary_df[
|
||||
sample_summary_df["pass_type"] == "policy"].copy()
|
||||
scenes = sorted(policy_summary["scene"].dropna().unique().tolist())
|
||||
fig, axes = plt.subplots(1, 2, figsize=(14, 5))
|
||||
|
||||
for scene in scenes:
|
||||
subset = policy_summary[policy_summary["scene"] == scene]
|
||||
axes[0].hist(subset["oracle_budget_action"].dropna(),
|
||||
bins=15,
|
||||
alpha=0.5,
|
||||
label=scene)
|
||||
axes[0].set_title("Oracle Budget Action Distribution")
|
||||
axes[0].set_xlabel("Step")
|
||||
axes[0].set_ylabel("Count")
|
||||
axes[0].legend()
|
||||
axes[0].grid(alpha=0.3)
|
||||
|
||||
boxplot_data = [
|
||||
policy_summary[policy_summary["scene"] == scene]["oracle_budget_action"].dropna()
|
||||
for scene in scenes
|
||||
]
|
||||
if scenes and any(len(values) > 0 for values in boxplot_data):
|
||||
axes[1].boxplot(boxplot_data, labels=scenes, showfliers=False)
|
||||
axes[1].set_title("Oracle Budget Action by Scene")
|
||||
axes[1].set_ylabel("Step")
|
||||
axes[1].tick_params(axis="x", rotation=20)
|
||||
axes[1].grid(alpha=0.3)
|
||||
|
||||
fig.tight_layout()
|
||||
fig.savefig(output_path, dpi=200, bbox_inches="tight")
|
||||
plt.close(fig)
|
||||
|
||||
|
||||
def save_midstep_delta_plot(stepwise_df: pd.DataFrame, output_path: Path,
|
||||
mid_step_start: int, mid_step_end: int) -> None:
|
||||
if plt is None:
|
||||
return
|
||||
subset = stepwise_df[(stepwise_df["step"] >= mid_step_start)
|
||||
& (stepwise_df["step"] <= mid_step_end)].copy()
|
||||
steps = sorted(subset["step"].dropna().unique().tolist())
|
||||
data = [subset[subset["step"] == step]["latent_delta"].dropna() for step in steps]
|
||||
|
||||
fig, ax = plt.subplots(figsize=(14, 5))
|
||||
if steps and any(len(values) > 0 for values in data):
|
||||
ax.boxplot(data, labels=steps, showfliers=False)
|
||||
ax.set_title(
|
||||
f"Latent Delta Distribution for Steps {mid_step_start}-{mid_step_end}")
|
||||
ax.set_xlabel("Step")
|
||||
ax.set_ylabel("latent_delta")
|
||||
ax.grid(alpha=0.3)
|
||||
fig.tight_layout()
|
||||
fig.savefig(output_path, dpi=200, bbox_inches="tight")
|
||||
plt.close(fig)
|
||||
|
||||
|
||||
def save_round_reuse_plot(round_summary_df: pd.DataFrame, output_path: Path) -> None:
|
||||
if plt is None:
|
||||
return
|
||||
fig, axes = plt.subplots(1, 2, figsize=(14, 5))
|
||||
|
||||
subset = round_summary_df.dropna(
|
||||
subset=["latent_init_dist_to_prev_round", "action_drift_vs_prev_round"])
|
||||
if not subset.empty:
|
||||
for scene, group in subset.groupby("scene"):
|
||||
axes[0].scatter(group["latent_init_dist_to_prev_round"],
|
||||
group["action_drift_vs_prev_round"],
|
||||
alpha=0.6,
|
||||
label=scene)
|
||||
axes[0].legend()
|
||||
axes[0].set_title("Round Reuse: Latent Init Dist vs Action Drift")
|
||||
axes[0].set_xlabel("latent_init_dist_to_prev_round")
|
||||
axes[0].set_ylabel("action_drift_vs_prev_round")
|
||||
axes[0].grid(alpha=0.3)
|
||||
|
||||
scene_groups = []
|
||||
scene_labels = []
|
||||
for scene, group in round_summary_df.groupby("scene"):
|
||||
values = group["round_total_time_s"].dropna()
|
||||
if len(values) > 0:
|
||||
scene_groups.append(values)
|
||||
scene_labels.append(scene)
|
||||
if scene_groups:
|
||||
axes[1].boxplot(scene_groups, labels=scene_labels, showfliers=False)
|
||||
axes[1].set_title("Round Total Time by Scene")
|
||||
axes[1].set_ylabel("Seconds")
|
||||
axes[1].tick_params(axis="x", rotation=20)
|
||||
axes[1].grid(alpha=0.3)
|
||||
|
||||
fig.tight_layout()
|
||||
fig.savefig(output_path, dpi=200, bbox_inches="tight")
|
||||
plt.close(fig)
|
||||
|
||||
|
||||
def main() -> None:
|
||||
args = parse_args()
|
||||
input_dir = Path(args.input_dir)
|
||||
output_dir = Path(args.output_dir) if args.output_dir else input_dir / "analysis"
|
||||
output_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
stepwise_df, sample_summary_df, round_summary_df = load_tables(input_dir)
|
||||
|
||||
stability_summary_df = None
|
||||
sample_summary_for_analysis = sample_summary_df.copy()
|
||||
if args.stability_threshold is not None:
|
||||
stability_summary_df = compute_stability_summary(stepwise_df,
|
||||
args.stability_threshold,
|
||||
args.stability_window)
|
||||
sample_summary_for_analysis = sample_summary_for_analysis.drop(
|
||||
columns=[
|
||||
"action_first_stable_step",
|
||||
"state_first_stable_step",
|
||||
"latent_first_stable_step",
|
||||
],
|
||||
errors="ignore",
|
||||
).merge(stability_summary_df,
|
||||
on=["sample_id", "scene", "pass_type"],
|
||||
how="left")
|
||||
|
||||
step_aggregate_df = aggregate_stepwise(stepwise_df)
|
||||
scene_summary_df = aggregate_scene_summary(sample_summary_for_analysis)
|
||||
direction_summary = compute_direction_summary(
|
||||
stepwise_df=stepwise_df,
|
||||
sample_summary_df=sample_summary_for_analysis,
|
||||
round_summary_df=round_summary_df,
|
||||
mid_step_start=args.mid_step_start,
|
||||
mid_step_end=args.mid_step_end,
|
||||
stability_threshold=args.stability_threshold,
|
||||
stability_window=args.stability_window,
|
||||
)
|
||||
|
||||
step_aggregate_df.to_csv(output_dir / "step_aggregate.csv", index=False)
|
||||
scene_summary_df.to_csv(output_dir / "scene_summary.csv", index=False)
|
||||
if stability_summary_df is not None:
|
||||
stability_summary_df.to_csv(output_dir / "stability_summary.csv",
|
||||
index=False)
|
||||
save_json(output_dir / "direction_summary.json", direction_summary)
|
||||
write_markdown_report(output_dir / "analysis_report.md", direction_summary)
|
||||
|
||||
if plt is not None:
|
||||
save_convergence_plot(stepwise_df, output_dir / "convergence_curves.png")
|
||||
save_budget_distribution_plot(sample_summary_df,
|
||||
output_dir / "budget_distribution.png")
|
||||
save_midstep_delta_plot(stepwise_df,
|
||||
output_dir / "midstep_latent_delta.png",
|
||||
args.mid_step_start, args.mid_step_end)
|
||||
save_round_reuse_plot(round_summary_df, output_dir / "round_reuse.png")
|
||||
else:
|
||||
print("matplotlib not installed; skipped PNG plots.")
|
||||
|
||||
print(f"Analysis written to: {output_dir}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -1,32 +1,35 @@
|
||||
import argparse, os, glob
|
||||
import pandas as pd
|
||||
import random
|
||||
import torch
|
||||
import torchvision
|
||||
import h5py
|
||||
import argparse, os, glob
|
||||
import json
|
||||
import pandas as pd
|
||||
import random
|
||||
import torch
|
||||
import torchvision
|
||||
import h5py
|
||||
import numpy as np
|
||||
import logging
|
||||
import einops
|
||||
import warnings
|
||||
import imageio
|
||||
|
||||
from pytorch_lightning import seed_everything
|
||||
from omegaconf import OmegaConf
|
||||
from tqdm import tqdm
|
||||
from einops import rearrange, repeat
|
||||
from collections import OrderedDict
|
||||
from torch import nn
|
||||
from eval_utils import populate_queues, log_to_tensorboard
|
||||
from collections import deque
|
||||
from torch import Tensor
|
||||
from torch.utils.tensorboard import SummaryWriter
|
||||
from PIL import Image
|
||||
import warnings
|
||||
import imageio
|
||||
import time
|
||||
|
||||
from pytorch_lightning import seed_everything
|
||||
from omegaconf import OmegaConf
|
||||
from tqdm import tqdm
|
||||
from einops import rearrange, repeat
|
||||
from collections import OrderedDict
|
||||
from torch import nn
|
||||
import torch.nn.functional as F
|
||||
from eval_utils import populate_queues, log_to_tensorboard
|
||||
from collections import deque
|
||||
from torch import Tensor
|
||||
from torch.utils.tensorboard import SummaryWriter
|
||||
from PIL import Image
|
||||
|
||||
from unifolm_wma.models.samplers.ddim import DDIMSampler
|
||||
from unifolm_wma.utils.utils import instantiate_from_config
|
||||
|
||||
|
||||
def get_device_from_parameters(module: nn.Module) -> torch.device:
|
||||
def get_device_from_parameters(module: nn.Module) -> torch.device:
|
||||
"""Get a module's device by checking one of its parameters.
|
||||
|
||||
Args:
|
||||
@@ -35,7 +38,413 @@ def get_device_from_parameters(module: nn.Module) -> torch.device:
|
||||
Returns:
|
||||
torch.device: The device of the model's parameters.
|
||||
"""
|
||||
return next(iter(module.parameters())).device
|
||||
return next(iter(module.parameters())).device
|
||||
|
||||
|
||||
def get_scene_name(sample: pd.Series, fallback: str) -> str:
|
||||
"""Resolve the scene label used in analysis logs."""
|
||||
if 'data_dir' in sample and pd.notna(sample['data_dir']):
|
||||
return str(sample['data_dir'])
|
||||
return fallback
|
||||
|
||||
|
||||
def build_sample_id(dataset: str, sample: pd.Series, frame_stride: int) -> str:
|
||||
"""Build a stable sample id while keeping the required CSV schema flat."""
|
||||
return f"{dataset}-vid{sample['videoid']}-fs{frame_stride}"
|
||||
|
||||
|
||||
def flatten_batch_tensor(tensor: torch.Tensor) -> torch.Tensor:
|
||||
"""Flatten all non-batch dimensions for batched metric computation."""
|
||||
return tensor.detach().float().reshape(tensor.shape[0], -1)
|
||||
|
||||
|
||||
def batch_relative_l2(current: torch.Tensor,
|
||||
previous: torch.Tensor) -> list[float]:
|
||||
"""Compute ||current-previous|| / ||previous|| for each item in the batch."""
|
||||
current_flat = flatten_batch_tensor(current)
|
||||
previous_flat = flatten_batch_tensor(previous)
|
||||
numerator = torch.linalg.vector_norm(current_flat - previous_flat, dim=1)
|
||||
denominator = torch.linalg.vector_norm(previous_flat, dim=1).clamp_min(1e-8)
|
||||
return (numerator / denominator).cpu().tolist()
|
||||
|
||||
|
||||
def batch_l2_distance(current: torch.Tensor,
|
||||
reference: torch.Tensor) -> list[float]:
|
||||
"""Compute L2 distance against a reference tensor for each batch item."""
|
||||
current_flat = flatten_batch_tensor(current)
|
||||
reference_flat = flatten_batch_tensor(reference)
|
||||
return torch.linalg.vector_norm(current_flat - reference_flat,
|
||||
dim=1).cpu().tolist()
|
||||
|
||||
|
||||
def batch_cosine_similarity(current: torch.Tensor,
|
||||
reference: torch.Tensor) -> list[float]:
|
||||
"""Compute cosine similarity against a reference tensor for each batch item."""
|
||||
current_flat = flatten_batch_tensor(current)
|
||||
reference_flat = flatten_batch_tensor(reference)
|
||||
return F.cosine_similarity(current_flat,
|
||||
reference_flat,
|
||||
dim=1,
|
||||
eps=1e-8).cpu().tolist()
|
||||
|
||||
|
||||
def first_consecutive_below(values: list[float], threshold: float,
|
||||
window: int) -> float:
|
||||
"""Return the first 1-based index where `window` consecutive values are below threshold."""
|
||||
if window <= 0 or len(values) < window:
|
||||
return np.nan
|
||||
for start in range(len(values) - window + 1):
|
||||
window_values = values[start:start + window]
|
||||
if all(pd.notna(value) and value < threshold for value in window_values):
|
||||
return float(start + 1)
|
||||
return np.nan
|
||||
|
||||
|
||||
def first_at_least(values: list[float], threshold: float) -> float:
|
||||
"""Return the first 1-based index where the series reaches the threshold."""
|
||||
for index, value in enumerate(values, start=1):
|
||||
if pd.notna(value) and value >= threshold:
|
||||
return float(index)
|
||||
return np.nan
|
||||
|
||||
|
||||
def first_at_most(values: list[float], threshold: float) -> float:
|
||||
"""Return the first 1-based index where the series drops below the threshold."""
|
||||
for index, value in enumerate(values, start=1):
|
||||
if pd.notna(value) and value <= threshold:
|
||||
return float(index)
|
||||
return np.nan
|
||||
|
||||
|
||||
def safe_mean(values: list[float]) -> float:
|
||||
"""Average numeric values while ignoring NaNs."""
|
||||
valid_values = [value for value in values if pd.notna(value)]
|
||||
if not valid_values:
|
||||
return np.nan
|
||||
return float(np.mean(valid_values))
|
||||
|
||||
|
||||
def make_sampling_noise_bundle(model: nn.Module,
|
||||
noise_shape: list[int]) -> dict[str, torch.Tensor]:
|
||||
"""Create aligned initial noise for latent, action, and state diffusion streams."""
|
||||
batch_size = noise_shape[0]
|
||||
horizon = noise_shape[2]
|
||||
device = model.device
|
||||
return {
|
||||
'img': torch.randn(noise_shape, device=device),
|
||||
'action': torch.randn((batch_size, horizon, model.agent_action_dim),
|
||||
device=device),
|
||||
'state': torch.randn((batch_size, horizon, model.agent_state_dim),
|
||||
device=device),
|
||||
}
|
||||
|
||||
|
||||
def load_psnr_lookup(psnr_path: str | None) -> dict[str, float]:
|
||||
"""Load optional PSNR values keyed by sample_id or videoid."""
|
||||
if not psnr_path:
|
||||
return {}
|
||||
if not os.path.exists(psnr_path):
|
||||
logging.warning("PSNR file not found: %s", psnr_path)
|
||||
return {}
|
||||
|
||||
suffix = os.path.splitext(psnr_path)[1].lower()
|
||||
lookup: dict[str, float] = {}
|
||||
if suffix == '.csv':
|
||||
df = pd.read_csv(psnr_path)
|
||||
key_column = 'sample_id' if 'sample_id' in df.columns else 'videoid'
|
||||
value_column = 'psnr_full50' if 'psnr_full50' in df.columns else 'psnr'
|
||||
for _, row in df.iterrows():
|
||||
if pd.notna(row[key_column]) and pd.notna(row[value_column]):
|
||||
lookup[str(row[key_column])] = float(row[value_column])
|
||||
return lookup
|
||||
|
||||
if suffix == '.json':
|
||||
with open(psnr_path, 'r', encoding='utf-8') as file:
|
||||
data = json.load(file)
|
||||
if isinstance(data, dict):
|
||||
if 'sample_id' in data and ('psnr_full50' in data or 'psnr' in data):
|
||||
lookup[str(data['sample_id'])] = float(
|
||||
data.get('psnr_full50', data['psnr']))
|
||||
else:
|
||||
for key, value in data.items():
|
||||
if isinstance(value, (int, float)):
|
||||
lookup[str(key)] = float(value)
|
||||
elif isinstance(data, list):
|
||||
for item in data:
|
||||
if not isinstance(item, dict):
|
||||
continue
|
||||
if 'sample_id' in item and ('psnr_full50' in item or 'psnr' in
|
||||
item):
|
||||
lookup[str(item['sample_id'])] = float(
|
||||
item.get('psnr_full50', item['psnr']))
|
||||
return lookup
|
||||
|
||||
logging.warning("Unsupported PSNR file format: %s", psnr_path)
|
||||
return {}
|
||||
|
||||
|
||||
class InteractionAnalysisLogger:
|
||||
"""Collect stepwise metrics and aggregated per-sample summaries."""
|
||||
|
||||
STEP_COLUMNS = [
|
||||
'sample_id',
|
||||
'scene',
|
||||
'pass_type',
|
||||
'round_id',
|
||||
'step',
|
||||
'step_time_s',
|
||||
'latent_delta',
|
||||
'action_delta',
|
||||
'state_delta',
|
||||
'action_cosine_vs_full50',
|
||||
'state_cosine_vs_full50',
|
||||
'latent_l2_vs_full50',
|
||||
]
|
||||
SUMMARY_COLUMNS = [
|
||||
'sample_id',
|
||||
'scene',
|
||||
'pass_type',
|
||||
'pass_total_time_s',
|
||||
'action_first_stable_step',
|
||||
'state_first_stable_step',
|
||||
'latent_first_stable_step',
|
||||
'action_vs_full50_90pct_step',
|
||||
'action_vs_full50_95pct_step',
|
||||
'oracle_budget_action',
|
||||
'oracle_budget_state',
|
||||
'oracle_budget_latent',
|
||||
'latent_init_dist_to_prev_round',
|
||||
'action_drift_vs_prev_round',
|
||||
'round_total_time_s',
|
||||
'policy_pass_total_time_s',
|
||||
'world_model_pass_total_time_s',
|
||||
'psnr_full50',
|
||||
]
|
||||
ROUND_COLUMNS = [
|
||||
'sample_id',
|
||||
'scene',
|
||||
'round_id',
|
||||
'policy_pass_total_time_s',
|
||||
'world_model_pass_total_time_s',
|
||||
'round_total_time_s',
|
||||
'latent_init_dist_to_prev_round',
|
||||
'action_drift_vs_prev_round',
|
||||
'psnr_full50',
|
||||
]
|
||||
|
||||
def __init__(self, output_dir: str, psnr_lookup: dict[str, float]):
|
||||
self.output_dir = output_dir
|
||||
self.psnr_lookup = psnr_lookup
|
||||
self.step_rows: list[dict] = []
|
||||
self.summary_buckets: dict[tuple[str, str, str], dict] = {}
|
||||
self.round_rows: list[dict] = []
|
||||
self.round_buckets: dict[tuple[str, str], dict] = {}
|
||||
self.prev_policy_action: dict[str, torch.Tensor] = {}
|
||||
self.prev_world_latent: dict[str, torch.Tensor] = {}
|
||||
|
||||
def resolve_psnr(self, sample_id: str, videoid: int) -> float:
|
||||
"""Resolve a PSNR value by full sample id first, then by raw video id."""
|
||||
candidates = [sample_id, sample_id.rsplit('-fs', 1)[0], str(videoid)]
|
||||
for candidate in candidates:
|
||||
if candidate in self.psnr_lookup:
|
||||
return float(self.psnr_lookup[candidate])
|
||||
return np.nan
|
||||
|
||||
def append_summary_row(self, row: dict) -> None:
|
||||
"""Store per-round summaries and aggregate them later by sample and pass type."""
|
||||
key = (row['sample_id'], row['scene'], row['pass_type'])
|
||||
metric_columns = self.SUMMARY_COLUMNS[3:]
|
||||
if key not in self.summary_buckets:
|
||||
self.summary_buckets[key] = {
|
||||
'sample_id': row['sample_id'],
|
||||
'scene': row['scene'],
|
||||
'pass_type': row['pass_type'],
|
||||
**{column: [] for column in metric_columns},
|
||||
}
|
||||
for column in metric_columns:
|
||||
self.summary_buckets[key][column].append(row.get(column, np.nan))
|
||||
|
||||
def append_round_row(self, row: dict) -> None:
|
||||
"""Store per-round metrics and aggregate them later by sample."""
|
||||
self.round_rows.append(row)
|
||||
key = (row['sample_id'], row['scene'])
|
||||
metric_columns = self.ROUND_COLUMNS[3:]
|
||||
if key not in self.round_buckets:
|
||||
self.round_buckets[key] = {
|
||||
'sample_id': row['sample_id'],
|
||||
'scene': row['scene'],
|
||||
**{column: [] for column in metric_columns},
|
||||
}
|
||||
for column in metric_columns:
|
||||
self.round_buckets[key][column].append(row.get(column, np.nan))
|
||||
|
||||
def collect_trace_series(self, debug_info: dict, reference_action: torch.Tensor,
|
||||
reference_state: torch.Tensor,
|
||||
reference_latent: torch.Tensor) -> tuple[list[float], list[float], list[float]]:
|
||||
"""Extract cosine/L2 curves for either the target pass or the full-50 reference pass."""
|
||||
action_cosines = []
|
||||
state_cosines = []
|
||||
latent_l2s = []
|
||||
for record in debug_info['step_records']:
|
||||
action_cosines.append(
|
||||
batch_cosine_similarity(record['action'], reference_action)[0])
|
||||
state_cosines.append(
|
||||
batch_cosine_similarity(record['state'], reference_state)[0])
|
||||
latent_l2s.append(
|
||||
batch_l2_distance(record['pred_x0'], reference_latent)[0])
|
||||
return action_cosines, state_cosines, latent_l2s
|
||||
|
||||
def log_pass(self, sample_id: str, videoid: int, scene: str, pass_type: str,
|
||||
round_id: int, pass_total_time_s: float, target_debug: dict,
|
||||
reference_debug: dict) -> dict | None:
|
||||
"""Log one pass worth of stepwise and aggregated metrics."""
|
||||
if not target_debug or not target_debug.get('step_records'):
|
||||
return None
|
||||
if not reference_debug or not reference_debug.get('step_records'):
|
||||
reference_debug = target_debug
|
||||
|
||||
reference_final_action = reference_debug['step_records'][-1]['action']
|
||||
reference_final_state = reference_debug['step_records'][-1]['state']
|
||||
reference_final_latent = reference_debug['step_records'][-1]['pred_x0']
|
||||
|
||||
prev_img = target_debug['analysis_init']['img']
|
||||
prev_action = target_debug['analysis_init']['action']
|
||||
prev_state = target_debug['analysis_init']['state']
|
||||
action_deltas: list[float] = []
|
||||
state_deltas: list[float] = []
|
||||
latent_deltas: list[float] = []
|
||||
action_cosines: list[float] = []
|
||||
state_cosines: list[float] = []
|
||||
latent_l2s: list[float] = []
|
||||
|
||||
for record in target_debug['step_records']:
|
||||
latent_delta = batch_relative_l2(record['img'], prev_img)[0]
|
||||
action_delta = batch_relative_l2(record['action'], prev_action)[0]
|
||||
state_delta = batch_relative_l2(record['state'], prev_state)[0]
|
||||
action_cosine = batch_cosine_similarity(record['action'],
|
||||
reference_final_action)[0]
|
||||
state_cosine = batch_cosine_similarity(record['state'],
|
||||
reference_final_state)[0]
|
||||
latent_l2 = batch_l2_distance(record['pred_x0'],
|
||||
reference_final_latent)[0]
|
||||
|
||||
action_deltas.append(action_delta)
|
||||
state_deltas.append(state_delta)
|
||||
latent_deltas.append(latent_delta)
|
||||
action_cosines.append(action_cosine)
|
||||
state_cosines.append(state_cosine)
|
||||
latent_l2s.append(latent_l2)
|
||||
|
||||
self.step_rows.append({
|
||||
'sample_id': sample_id,
|
||||
'scene': scene,
|
||||
'pass_type': pass_type,
|
||||
'round_id': round_id,
|
||||
'step': record['step_index'],
|
||||
'step_time_s': float(record['step_time_s']),
|
||||
'latent_delta': latent_delta,
|
||||
'action_delta': action_delta,
|
||||
'state_delta': state_delta,
|
||||
'action_cosine_vs_full50': action_cosine,
|
||||
'state_cosine_vs_full50': state_cosine,
|
||||
'latent_l2_vs_full50': latent_l2,
|
||||
})
|
||||
|
||||
prev_img = record['img']
|
||||
prev_action = record['action']
|
||||
prev_state = record['state']
|
||||
|
||||
oracle_action_cosines, oracle_state_cosines, oracle_latent_l2s = self.collect_trace_series(
|
||||
reference_debug, reference_final_action, reference_final_state,
|
||||
reference_final_latent)
|
||||
|
||||
latent_init_dist_to_prev_round = np.nan
|
||||
action_drift_vs_prev_round = np.nan
|
||||
if pass_type == 'policy':
|
||||
previous_action = self.prev_policy_action.get(sample_id)
|
||||
if previous_action is not None:
|
||||
action_drift_vs_prev_round = 1.0 - batch_cosine_similarity(
|
||||
reference_final_action, previous_action)[0]
|
||||
self.prev_policy_action[sample_id] = reference_final_action.clone()
|
||||
elif pass_type == 'world_model':
|
||||
previous_latent = self.prev_world_latent.get(sample_id)
|
||||
if previous_latent is not None:
|
||||
latent_init_dist_to_prev_round = batch_l2_distance(
|
||||
reference_final_latent, previous_latent)[0]
|
||||
self.prev_world_latent[sample_id] = reference_final_latent.clone()
|
||||
|
||||
summary_row = {
|
||||
'sample_id': sample_id,
|
||||
'scene': scene,
|
||||
'pass_type': pass_type,
|
||||
'pass_total_time_s': float(pass_total_time_s),
|
||||
'action_first_stable_step': np.nan,
|
||||
'state_first_stable_step': np.nan,
|
||||
'latent_first_stable_step': np.nan,
|
||||
'action_vs_full50_90pct_step': first_at_least(action_cosines, 0.90),
|
||||
'action_vs_full50_95pct_step': first_at_least(action_cosines, 0.95),
|
||||
'oracle_budget_action': first_at_least(oracle_action_cosines, 0.95),
|
||||
'oracle_budget_state': first_at_least(oracle_state_cosines, 0.95),
|
||||
'oracle_budget_latent': np.nan,
|
||||
'latent_init_dist_to_prev_round': latent_init_dist_to_prev_round,
|
||||
'action_drift_vs_prev_round': action_drift_vs_prev_round,
|
||||
'round_total_time_s': np.nan,
|
||||
'policy_pass_total_time_s': np.nan,
|
||||
'world_model_pass_total_time_s': np.nan,
|
||||
'psnr_full50': self.resolve_psnr(sample_id, videoid),
|
||||
}
|
||||
self.append_summary_row(summary_row)
|
||||
return summary_row
|
||||
|
||||
def log_round(self, sample_id: str, videoid: int, scene: str, round_id: int,
|
||||
policy_pass_total_time_s: float,
|
||||
world_model_pass_total_time_s: float, round_total_time_s: float,
|
||||
latent_init_dist_to_prev_round: float,
|
||||
action_drift_vs_prev_round: float) -> None:
|
||||
"""Log one interaction round consisting of one policy pass and one world-model pass."""
|
||||
self.append_round_row({
|
||||
'sample_id': sample_id,
|
||||
'scene': scene,
|
||||
'round_id': round_id,
|
||||
'policy_pass_total_time_s': float(policy_pass_total_time_s),
|
||||
'world_model_pass_total_time_s': float(world_model_pass_total_time_s),
|
||||
'round_total_time_s': float(round_total_time_s),
|
||||
'latent_init_dist_to_prev_round': latent_init_dist_to_prev_round,
|
||||
'action_drift_vs_prev_round': action_drift_vs_prev_round,
|
||||
'psnr_full50': self.resolve_psnr(sample_id, videoid),
|
||||
})
|
||||
|
||||
def flush(self) -> None:
|
||||
"""Write analysis CSVs to disk."""
|
||||
os.makedirs(self.output_dir, exist_ok=True)
|
||||
stepwise_path = os.path.join(self.output_dir, 'stepwise_log.csv')
|
||||
summary_path = os.path.join(self.output_dir, 'sample_summary.csv')
|
||||
round_path = os.path.join(self.output_dir, 'round_summary.csv')
|
||||
|
||||
stepwise_df = pd.DataFrame(self.step_rows, columns=self.STEP_COLUMNS)
|
||||
stepwise_df.to_csv(stepwise_path, index=False)
|
||||
|
||||
round_df = pd.DataFrame(self.round_rows, columns=self.ROUND_COLUMNS)
|
||||
round_df.to_csv(round_path, index=False)
|
||||
|
||||
summary_rows = []
|
||||
metric_columns = self.SUMMARY_COLUMNS[3:]
|
||||
for bucket in self.summary_buckets.values():
|
||||
round_bucket = self.round_buckets.get((bucket['sample_id'],
|
||||
bucket['scene']))
|
||||
row = {
|
||||
'sample_id': bucket['sample_id'],
|
||||
'scene': bucket['scene'],
|
||||
'pass_type': bucket['pass_type'],
|
||||
}
|
||||
for column in metric_columns:
|
||||
if round_bucket is not None and column in round_bucket:
|
||||
row[column] = safe_mean(round_bucket[column])
|
||||
else:
|
||||
row[column] = safe_mean(bucket[column])
|
||||
summary_rows.append(row)
|
||||
summary_df = pd.DataFrame(summary_rows, columns=self.SUMMARY_COLUMNS)
|
||||
summary_df.to_csv(summary_path, index=False)
|
||||
|
||||
|
||||
def write_video(video_path: str, stacked_frames: list, fps: int) -> None:
|
||||
@@ -312,22 +721,25 @@ def preprocess_observation(
|
||||
return return_observations
|
||||
|
||||
|
||||
def image_guided_synthesis_sim_mode(
|
||||
model: torch.nn.Module,
|
||||
prompts: list[str],
|
||||
observation: dict,
|
||||
noise_shape: tuple[int, int, int, int, int],
|
||||
def image_guided_synthesis_sim_mode(
|
||||
model: torch.nn.Module,
|
||||
prompts: list[str],
|
||||
observation: dict,
|
||||
noise_shape: tuple[int, int, int, int, int],
|
||||
action_cond_step: int = 16,
|
||||
n_samples: int = 1,
|
||||
ddim_steps: int = 50,
|
||||
ddim_eta: float = 1.0,
|
||||
unconditional_guidance_scale: float = 1.0,
|
||||
fs: int | None = None,
|
||||
text_input: bool = True,
|
||||
timestep_spacing: str = 'uniform',
|
||||
guidance_rescale: float = 0.0,
|
||||
sim_mode: bool = True,
|
||||
**kwargs) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
text_input: bool = True,
|
||||
timestep_spacing: str = 'uniform',
|
||||
guidance_rescale: float = 0.0,
|
||||
sim_mode: bool = True,
|
||||
init_noise_bundle: dict[str, torch.Tensor] | None = None,
|
||||
decode_video: bool = True,
|
||||
return_debug_info: bool = False,
|
||||
**kwargs) -> tuple[torch.Tensor | None, torch.Tensor, torch.Tensor, dict | None]:
|
||||
"""
|
||||
Performs image-guided video generation in a simulation-style mode with optional multimodal guidance (image, state, action, text).
|
||||
|
||||
@@ -350,18 +762,24 @@ def image_guided_synthesis_sim_mode(
|
||||
timestep_spacing (str): Timestep sampling method in DDIM sampler. Typically "uniform" or "linspace".
|
||||
guidance_rescale (float): Guidance rescaling factor to mitigate overexposure from classifier-free guidance.
|
||||
sim_mode (bool): Whether to perform world-model interaction or decision-making using the world-model.
|
||||
**kwargs: Additional arguments passed to the DDIM sampler.
|
||||
|
||||
Returns:
|
||||
batch_variants (torch.Tensor): Predicted pixel-space video frames [B, C, T, H, W].
|
||||
actions (torch.Tensor): Predicted action sequences [B, T, D] from diffusion decoding.
|
||||
states (torch.Tensor): Predicted state sequences [B, T, D] from diffusion decoding.
|
||||
init_noise_bundle (dict[str, torch.Tensor] | None): Optional aligned noise inputs for latent/action/state.
|
||||
decode_video (bool): Whether to decode the final latent into pixel space.
|
||||
return_debug_info (bool): Whether to return per-step traces for analysis logging.
|
||||
**kwargs: Additional arguments passed to the DDIM sampler.
|
||||
|
||||
Returns:
|
||||
batch_variants (torch.Tensor | None): Predicted pixel-space video frames [B, C, T, H, W].
|
||||
actions (torch.Tensor): Predicted action sequences [B, T, D] from diffusion decoding.
|
||||
states (torch.Tensor): Predicted state sequences [B, T, D] from diffusion decoding.
|
||||
debug_info (dict | None): Optional per-step trace used for convergence analysis.
|
||||
"""
|
||||
b, _, t, _, _ = noise_shape
|
||||
ddim_sampler = DDIMSampler(model)
|
||||
batch_size = noise_shape[0]
|
||||
|
||||
fs = torch.tensor([fs] * batch_size, dtype=torch.long, device=model.device)
|
||||
b, _, t, _, _ = noise_shape
|
||||
ddim_sampler = DDIMSampler(model)
|
||||
batch_size = noise_shape[0]
|
||||
batch_variants = None
|
||||
debug_info = None
|
||||
|
||||
fs = torch.tensor([fs] * batch_size, dtype=torch.long, device=model.device)
|
||||
|
||||
img = observation['observation.images.top'].permute(0, 2, 1, 3, 4)
|
||||
cond_img = rearrange(img, 'b o c h w -> (b o) c h w')[-1:]
|
||||
@@ -406,11 +824,11 @@ def image_guided_synthesis_sim_mode(
|
||||
kwargs.update({"unconditional_conditioning_img_nonetext": None})
|
||||
cond_mask = None
|
||||
cond_z0 = None
|
||||
if ddim_sampler is not None:
|
||||
samples, actions, states, intermedia = ddim_sampler.sample(
|
||||
S=ddim_steps,
|
||||
conditioning=cond,
|
||||
batch_size=batch_size,
|
||||
if ddim_sampler is not None:
|
||||
samples, actions, states, intermedia = ddim_sampler.sample(
|
||||
S=ddim_steps,
|
||||
conditioning=cond,
|
||||
batch_size=batch_size,
|
||||
shape=noise_shape[1:],
|
||||
verbose=False,
|
||||
unconditional_guidance_scale=unconditional_guidance_scale,
|
||||
@@ -418,20 +836,35 @@ def image_guided_synthesis_sim_mode(
|
||||
eta=ddim_eta,
|
||||
cfg_img=None,
|
||||
mask=cond_mask,
|
||||
x0=cond_z0,
|
||||
fs=fs,
|
||||
timestep_spacing=timestep_spacing,
|
||||
guidance_rescale=guidance_rescale,
|
||||
**kwargs)
|
||||
|
||||
# Reconstruct from latent to pixel space
|
||||
batch_images = model.decode_first_stage(samples)
|
||||
batch_variants = batch_images
|
||||
|
||||
return batch_variants, actions, states
|
||||
|
||||
|
||||
def run_inference(args: argparse.Namespace, gpu_num: int, gpu_no: int) -> None:
|
||||
x0=cond_z0,
|
||||
fs=fs,
|
||||
timestep_spacing=timestep_spacing,
|
||||
guidance_rescale=guidance_rescale,
|
||||
x_T=None if init_noise_bundle is None else init_noise_bundle['img'],
|
||||
action_T=None if init_noise_bundle is None else
|
||||
init_noise_bundle['action'],
|
||||
state_T=None if init_noise_bundle is None else
|
||||
init_noise_bundle['state'],
|
||||
record_step_outputs=return_debug_info,
|
||||
**kwargs)
|
||||
|
||||
batch_variants = None
|
||||
if decode_video:
|
||||
batch_variants = model.decode_first_stage(samples)
|
||||
|
||||
if return_debug_info:
|
||||
debug_info = {
|
||||
'analysis_init': intermedia.get('analysis_init'),
|
||||
'step_records': intermedia.get('step_records', []),
|
||||
'final_latent': samples.detach().cpu(),
|
||||
'final_action': actions.detach().cpu(),
|
||||
'final_state': states.detach().cpu(),
|
||||
}
|
||||
|
||||
return batch_variants, actions, states, debug_info
|
||||
|
||||
|
||||
def run_inference(args: argparse.Namespace, gpu_num: int, gpu_no: int) -> None:
|
||||
"""
|
||||
Run inference pipeline on prompts and image inputs.
|
||||
|
||||
@@ -443,11 +876,18 @@ def run_inference(args: argparse.Namespace, gpu_num: int, gpu_no: int) -> None:
|
||||
Returns:
|
||||
None
|
||||
"""
|
||||
# Create inference and tensorboard dirs
|
||||
os.makedirs(args.savedir + '/inference', exist_ok=True)
|
||||
log_dir = args.savedir + f"/tensorboard"
|
||||
os.makedirs(log_dir, exist_ok=True)
|
||||
writer = SummaryWriter(log_dir=log_dir)
|
||||
# Create inference and tensorboard dirs
|
||||
inference_dir = args.savedir + '/inference'
|
||||
os.makedirs(inference_dir, exist_ok=True)
|
||||
log_dir = args.savedir + f"/tensorboard"
|
||||
os.makedirs(log_dir, exist_ok=True)
|
||||
writer = SummaryWriter(log_dir=log_dir)
|
||||
analysis_logger = None
|
||||
if args.analysis_log_metrics:
|
||||
analysis_logger = InteractionAnalysisLogger(
|
||||
output_dir=inference_dir,
|
||||
psnr_lookup=load_psnr_lookup(args.analysis_psnr_path),
|
||||
)
|
||||
|
||||
# Load prompt
|
||||
csv_path = os.path.join(args.prompt_dir, f"{args.dataset}.csv")
|
||||
@@ -474,10 +914,14 @@ def run_inference(args: argparse.Namespace, gpu_num: int, gpu_no: int) -> None:
|
||||
device = get_device_from_parameters(model)
|
||||
|
||||
# Run over data
|
||||
assert (args.height % 16 == 0) and (
|
||||
args.width % 16
|
||||
== 0), "Error: image size [h,w] should be multiples of 16!"
|
||||
assert args.bs == 1, "Current implementation only support [batch size = 1]!"
|
||||
assert (args.height % 16 == 0) and (
|
||||
args.width % 16
|
||||
== 0), "Error: image size [h,w] should be multiples of 16!"
|
||||
assert args.bs == 1, "Current implementation only support [batch size = 1]!"
|
||||
if args.analysis_log_metrics:
|
||||
assert args.ddim_steps > 0, "analysis_log_metrics requires positive --ddim_steps."
|
||||
assert args.analysis_reference_steps > 0, (
|
||||
"analysis_log_metrics requires positive --analysis_reference_steps.")
|
||||
|
||||
# Get latent noise shape
|
||||
h, w = args.height // 8, args.width // 8
|
||||
@@ -508,12 +952,14 @@ def run_inference(args: argparse.Namespace, gpu_num: int, gpu_no: int) -> None:
|
||||
for key in h5f.attrs.keys():
|
||||
transition_dict[key] = h5f.attrs[key]
|
||||
|
||||
# If many, test various frequence control and world-model generation
|
||||
for fs in args.frame_stride:
|
||||
|
||||
# For saving imagens in policy
|
||||
sample_save_dir = f'{video_save_dir}/dm/{fs}'
|
||||
os.makedirs(sample_save_dir, exist_ok=True)
|
||||
# If many, test various frequence control and world-model generation
|
||||
for fs in args.frame_stride:
|
||||
sample_id = build_sample_id(args.dataset, sample, fs)
|
||||
scene = get_scene_name(sample, args.dataset)
|
||||
|
||||
# For saving imagens in policy
|
||||
sample_save_dir = f'{video_save_dir}/dm/{fs}'
|
||||
os.makedirs(sample_save_dir, exist_ok=True)
|
||||
# For saving environmental changes in world-model
|
||||
sample_save_dir = f'{video_save_dir}/wm/{fs}'
|
||||
os.makedirs(sample_save_dir, exist_ok=True)
|
||||
@@ -552,11 +998,12 @@ def run_inference(args: argparse.Namespace, gpu_num: int, gpu_no: int) -> None:
|
||||
# Update observation queues
|
||||
cond_obs_queues = populate_queues(cond_obs_queues, observation)
|
||||
|
||||
# Multi-round interaction with the world-model
|
||||
for itr in tqdm(range(args.n_iter)):
|
||||
|
||||
# Get observation
|
||||
observation = {
|
||||
# Multi-round interaction with the world-model
|
||||
for itr in tqdm(range(args.n_iter)):
|
||||
round_start_time = time.time()
|
||||
|
||||
# Get observation
|
||||
observation = {
|
||||
'observation.images.top':
|
||||
torch.stack(list(
|
||||
cond_obs_queues['observation.images.top']),
|
||||
@@ -566,33 +1013,72 @@ def run_inference(args: argparse.Namespace, gpu_num: int, gpu_no: int) -> None:
|
||||
dim=1),
|
||||
'action':
|
||||
torch.stack(list(cond_obs_queues['action']), dim=1),
|
||||
}
|
||||
observation = {
|
||||
key: observation[key].to(device, non_blocking=True)
|
||||
for key in observation
|
||||
}
|
||||
|
||||
# Use world-model in policy to generate action
|
||||
print(f'>>> Step {itr}: generating actions ...')
|
||||
pred_videos_0, pred_actions, _ = image_guided_synthesis_sim_mode(
|
||||
model,
|
||||
sample['instruction'],
|
||||
observation,
|
||||
noise_shape,
|
||||
action_cond_step=args.exe_steps,
|
||||
}
|
||||
observation = {
|
||||
key: observation[key].to(device, non_blocking=True)
|
||||
for key in observation
|
||||
}
|
||||
|
||||
# Use world-model in policy to generate action
|
||||
print(f'>>> Step {itr}: generating actions ...')
|
||||
policy_noise_bundle = make_sampling_noise_bundle(
|
||||
model, noise_shape) if args.analysis_log_metrics else None
|
||||
policy_reference_debug = None
|
||||
if args.analysis_log_metrics and args.analysis_reference_steps != args.ddim_steps:
|
||||
_, _, _, policy_reference_debug = image_guided_synthesis_sim_mode(
|
||||
model,
|
||||
sample['instruction'],
|
||||
observation,
|
||||
noise_shape,
|
||||
action_cond_step=args.exe_steps,
|
||||
ddim_steps=args.analysis_reference_steps,
|
||||
ddim_eta=args.ddim_eta,
|
||||
unconditional_guidance_scale=args.
|
||||
unconditional_guidance_scale,
|
||||
fs=model_input_fs,
|
||||
timestep_spacing=args.timestep_spacing,
|
||||
guidance_rescale=args.guidance_rescale,
|
||||
sim_mode=False,
|
||||
init_noise_bundle=policy_noise_bundle,
|
||||
decode_video=False,
|
||||
return_debug_info=True)
|
||||
policy_pass_start = time.time()
|
||||
pred_videos_0, pred_actions, _, policy_debug = image_guided_synthesis_sim_mode(
|
||||
model,
|
||||
sample['instruction'],
|
||||
observation,
|
||||
noise_shape,
|
||||
action_cond_step=args.exe_steps,
|
||||
ddim_steps=args.ddim_steps,
|
||||
ddim_eta=args.ddim_eta,
|
||||
unconditional_guidance_scale=args.
|
||||
unconditional_guidance_scale,
|
||||
fs=model_input_fs,
|
||||
timestep_spacing=args.timestep_spacing,
|
||||
guidance_rescale=args.guidance_rescale,
|
||||
sim_mode=False)
|
||||
|
||||
# Update future actions in the observation queues
|
||||
for idx in range(len(pred_actions[0])):
|
||||
observation = {'action': pred_actions[0][idx:idx + 1]}
|
||||
observation['action'][:, ori_action_dim:] = 0.0
|
||||
unconditional_guidance_scale=args.
|
||||
unconditional_guidance_scale,
|
||||
fs=model_input_fs,
|
||||
timestep_spacing=args.timestep_spacing,
|
||||
guidance_rescale=args.guidance_rescale,
|
||||
sim_mode=False,
|
||||
init_noise_bundle=policy_noise_bundle,
|
||||
return_debug_info=args.analysis_log_metrics)
|
||||
policy_pass_total_time_s = time.time() - policy_pass_start
|
||||
policy_summary_row = None
|
||||
if analysis_logger is not None:
|
||||
if policy_reference_debug is None:
|
||||
policy_reference_debug = policy_debug
|
||||
policy_summary_row = analysis_logger.log_pass(
|
||||
sample_id=sample_id,
|
||||
videoid=int(sample['videoid']),
|
||||
scene=scene,
|
||||
pass_type='policy',
|
||||
round_id=itr,
|
||||
pass_total_time_s=policy_pass_total_time_s,
|
||||
target_debug=policy_debug,
|
||||
reference_debug=policy_reference_debug,
|
||||
)
|
||||
|
||||
# Update future actions in the observation queues
|
||||
for idx in range(len(pred_actions[0])):
|
||||
observation = {'action': pred_actions[0][idx:idx + 1]}
|
||||
observation['action'][:, ori_action_dim:] = 0.0
|
||||
cond_obs_queues = populate_queues(cond_obs_queues,
|
||||
observation)
|
||||
|
||||
@@ -611,29 +1097,84 @@ def run_inference(args: argparse.Namespace, gpu_num: int, gpu_no: int) -> None:
|
||||
observation = {
|
||||
key: observation[key].to(device, non_blocking=True)
|
||||
for key in observation
|
||||
}
|
||||
|
||||
# Interaction with the world-model
|
||||
print(f'>>> Step {itr}: interacting with world model ...')
|
||||
pred_videos_1, _, pred_states = image_guided_synthesis_sim_mode(
|
||||
model,
|
||||
"",
|
||||
observation,
|
||||
noise_shape,
|
||||
action_cond_step=args.exe_steps,
|
||||
}
|
||||
|
||||
# Interaction with the world-model
|
||||
print(f'>>> Step {itr}: interacting with world model ...')
|
||||
world_noise_bundle = make_sampling_noise_bundle(
|
||||
model, noise_shape) if args.analysis_log_metrics else None
|
||||
world_reference_debug = None
|
||||
if args.analysis_log_metrics and args.analysis_reference_steps != args.ddim_steps:
|
||||
_, _, _, world_reference_debug = image_guided_synthesis_sim_mode(
|
||||
model,
|
||||
"",
|
||||
observation,
|
||||
noise_shape,
|
||||
action_cond_step=args.exe_steps,
|
||||
ddim_steps=args.analysis_reference_steps,
|
||||
ddim_eta=args.ddim_eta,
|
||||
unconditional_guidance_scale=args.
|
||||
unconditional_guidance_scale,
|
||||
fs=model_input_fs,
|
||||
text_input=False,
|
||||
timestep_spacing=args.timestep_spacing,
|
||||
guidance_rescale=args.guidance_rescale,
|
||||
init_noise_bundle=world_noise_bundle,
|
||||
decode_video=False,
|
||||
return_debug_info=True)
|
||||
world_pass_start = time.time()
|
||||
pred_videos_1, _, pred_states, world_debug = image_guided_synthesis_sim_mode(
|
||||
model,
|
||||
"",
|
||||
observation,
|
||||
noise_shape,
|
||||
action_cond_step=args.exe_steps,
|
||||
ddim_steps=args.ddim_steps,
|
||||
ddim_eta=args.ddim_eta,
|
||||
unconditional_guidance_scale=args.
|
||||
unconditional_guidance_scale,
|
||||
fs=model_input_fs,
|
||||
text_input=False,
|
||||
timestep_spacing=args.timestep_spacing,
|
||||
guidance_rescale=args.guidance_rescale)
|
||||
|
||||
for idx in range(args.exe_steps):
|
||||
observation = {
|
||||
'observation.images.top':
|
||||
pred_videos_1[0][:, idx:idx + 1].permute(1, 0, 2, 3),
|
||||
unconditional_guidance_scale=args.
|
||||
unconditional_guidance_scale,
|
||||
fs=model_input_fs,
|
||||
text_input=False,
|
||||
timestep_spacing=args.timestep_spacing,
|
||||
guidance_rescale=args.guidance_rescale,
|
||||
init_noise_bundle=world_noise_bundle,
|
||||
return_debug_info=args.analysis_log_metrics)
|
||||
world_pass_total_time_s = time.time() - world_pass_start
|
||||
world_summary_row = None
|
||||
if analysis_logger is not None:
|
||||
if world_reference_debug is None:
|
||||
world_reference_debug = world_debug
|
||||
world_summary_row = analysis_logger.log_pass(
|
||||
sample_id=sample_id,
|
||||
videoid=int(sample['videoid']),
|
||||
scene=scene,
|
||||
pass_type='world_model',
|
||||
round_id=itr,
|
||||
pass_total_time_s=world_pass_total_time_s,
|
||||
target_debug=world_debug,
|
||||
reference_debug=world_reference_debug,
|
||||
)
|
||||
analysis_logger.log_round(
|
||||
sample_id=sample_id,
|
||||
videoid=int(sample['videoid']),
|
||||
scene=scene,
|
||||
round_id=itr,
|
||||
policy_pass_total_time_s=policy_pass_total_time_s,
|
||||
world_model_pass_total_time_s=
|
||||
world_pass_total_time_s,
|
||||
round_total_time_s=time.time() - round_start_time,
|
||||
latent_init_dist_to_prev_round=np.nan
|
||||
if world_summary_row is None else
|
||||
world_summary_row['latent_init_dist_to_prev_round'],
|
||||
action_drift_vs_prev_round=np.nan
|
||||
if policy_summary_row is None else
|
||||
policy_summary_row['action_drift_vs_prev_round'],
|
||||
)
|
||||
|
||||
for idx in range(args.exe_steps):
|
||||
observation = {
|
||||
'observation.images.top':
|
||||
pred_videos_1[0][:, idx:idx + 1].permute(1, 0, 2, 3),
|
||||
'observation.state':
|
||||
torch.zeros_like(pred_states[0][idx:idx + 1]) if
|
||||
args.zero_pred_state else pred_states[0][idx:idx + 1],
|
||||
@@ -678,8 +1219,12 @@ def run_inference(args: argparse.Namespace, gpu_num: int, gpu_no: int) -> None:
|
||||
full_video,
|
||||
sample_tag,
|
||||
fps=args.save_fps)
|
||||
sample_full_video_file = f"{video_save_dir}/../{sample['videoid']}_full_fs{fs}.mp4"
|
||||
save_results(full_video, sample_full_video_file, fps=args.save_fps)
|
||||
sample_full_video_file = f"{video_save_dir}/../{sample['videoid']}_full_fs{fs}.mp4"
|
||||
save_results(full_video, sample_full_video_file, fps=args.save_fps)
|
||||
|
||||
if analysis_logger is not None:
|
||||
analysis_logger.flush()
|
||||
writer.close()
|
||||
|
||||
|
||||
def get_parser():
|
||||
@@ -794,11 +1339,25 @@ def get_parser():
|
||||
action='store_true',
|
||||
default=False,
|
||||
help="not using the predicted states as comparison")
|
||||
parser.add_argument("--save_fps",
|
||||
type=int,
|
||||
default=8,
|
||||
help="fps for the saving video")
|
||||
return parser
|
||||
parser.add_argument("--save_fps",
|
||||
type=int,
|
||||
default=8,
|
||||
help="fps for the saving video")
|
||||
parser.add_argument("--analysis_log_metrics",
|
||||
action='store_true',
|
||||
default=False,
|
||||
help="Enable DDIM convergence logging and export analysis CSVs.")
|
||||
parser.add_argument(
|
||||
"--analysis_reference_steps",
|
||||
type=int,
|
||||
default=50,
|
||||
help="Reference DDIM steps used to build the full-step baseline for *_vs_full50 metrics."
|
||||
)
|
||||
parser.add_argument("--analysis_psnr_path",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Optional CSV/JSON file with psnr_full50 values keyed by sample_id or videoid.")
|
||||
return parser
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
||||
58
scripts/run_all_world_model_cases.sh
Normal file
58
scripts/run_all_world_model_cases.sh
Normal file
@@ -0,0 +1,58 @@
|
||||
#!/usr/bin/env bash
|
||||
|
||||
set -euo pipefail
|
||||
|
||||
repo_root="$(cd "$(dirname "${BASH_SOURCE[0]}")/.." && pwd)"
|
||||
cd "${repo_root}"
|
||||
|
||||
run_analysis="${RUN_ANALYSIS:-1}"
|
||||
case_filter="${CASE_FILTER:-}"
|
||||
|
||||
case_scripts=(
|
||||
"unitree_z1_stackbox/case1/run_world_model_interaction.sh"
|
||||
"unitree_z1_stackbox/case2/run_world_model_interaction.sh"
|
||||
"unitree_z1_stackbox/case3/run_world_model_interaction.sh"
|
||||
"unitree_z1_stackbox/case4/run_world_model_interaction.sh"
|
||||
"unitree_z1_dual_arm_stackbox/case1/run_world_model_interaction.sh"
|
||||
"unitree_z1_dual_arm_stackbox/case2/run_world_model_interaction.sh"
|
||||
"unitree_z1_dual_arm_stackbox/case3/run_world_model_interaction.sh"
|
||||
"unitree_z1_dual_arm_stackbox/case4/run_world_model_interaction.sh"
|
||||
"unitree_z1_dual_arm_stackbox_v2/case1/run_world_model_interaction.sh"
|
||||
"unitree_z1_dual_arm_stackbox_v2/case2/run_world_model_interaction.sh"
|
||||
"unitree_z1_dual_arm_stackbox_v2/case3/run_world_model_interaction.sh"
|
||||
"unitree_z1_dual_arm_stackbox_v2/case4/run_world_model_interaction.sh"
|
||||
"unitree_z1_dual_arm_cleanup_pencils/case1/run_world_model_interaction.sh"
|
||||
"unitree_z1_dual_arm_cleanup_pencils/case2/run_world_model_interaction.sh"
|
||||
"unitree_z1_dual_arm_cleanup_pencils/case3/run_world_model_interaction.sh"
|
||||
"unitree_z1_dual_arm_cleanup_pencils/case4/run_world_model_interaction.sh"
|
||||
"unitree_g1_pack_camera/case1/run_world_model_interaction.sh"
|
||||
"unitree_g1_pack_camera/case2/run_world_model_interaction.sh"
|
||||
"unitree_g1_pack_camera/case3/run_world_model_interaction.sh"
|
||||
"unitree_g1_pack_camera/case4/run_world_model_interaction.sh"
|
||||
)
|
||||
|
||||
for case_script in "${case_scripts[@]}"; do
|
||||
if [[ -n "${case_filter}" ]] && [[ "${case_script}" != *"${case_filter}"* ]]; then
|
||||
continue
|
||||
fi
|
||||
|
||||
case_dir="$(dirname "${case_script}")"
|
||||
inference_dir="${repo_root}/${case_dir}/output/inference"
|
||||
|
||||
echo "============================================================"
|
||||
echo "Running ${case_script}"
|
||||
echo "============================================================"
|
||||
bash "${repo_root}/${case_script}"
|
||||
|
||||
if [[ "${run_analysis}" == "1" ]]; then
|
||||
if [[ -f "${inference_dir}/stepwise_log.csv" ]] && \
|
||||
[[ -f "${inference_dir}/sample_summary.csv" ]] && \
|
||||
[[ -f "${inference_dir}/round_summary.csv" ]]; then
|
||||
echo "Analyzing ${case_dir}"
|
||||
python3 "${repo_root}/scripts/evaluation/analyze_metrics.py" \
|
||||
--input_dir "${inference_dir}"
|
||||
else
|
||||
echo "Skipping analysis for ${case_dir}: missing analysis CSVs."
|
||||
fi
|
||||
fi
|
||||
done
|
||||
@@ -3,6 +3,8 @@ ckpt=/path/to/model/checkpoint
|
||||
config=configs/inference/world_model_interaction.yaml
|
||||
seed=123
|
||||
res_dir="/path/to/result/directory"
|
||||
analysis_reference_steps=50
|
||||
# analysis_psnr_path="/path/to/psnr_full50.csv"
|
||||
|
||||
datasets=(
|
||||
"unitree_z1_stackbox"
|
||||
@@ -19,6 +21,13 @@ for i in "${!datasets[@]}"; do
|
||||
dataset=${datasets[$i]}
|
||||
n_iter=${n_iters[$i]}
|
||||
fs=${fses[$i]}
|
||||
analysis_args=(
|
||||
--analysis_log_metrics
|
||||
--analysis_reference_steps ${analysis_reference_steps}
|
||||
)
|
||||
if [ -n "${analysis_psnr_path:-}" ]; then
|
||||
analysis_args+=(--analysis_psnr_path "${analysis_psnr_path}")
|
||||
fi
|
||||
|
||||
CUDA_VISIBLE_DEVICES=0 python3 scripts/evaluation/world_model_interaction.py \
|
||||
--seed ${seed} \
|
||||
@@ -38,5 +47,6 @@ for i in "${!datasets[@]}"; do
|
||||
--n_iter ${n_iter} \
|
||||
--timestep_spacing 'uniform_trailing' \
|
||||
--guidance_rescale 0.7 \
|
||||
--perframe_ae
|
||||
--perframe_ae \
|
||||
"${analysis_args[@]}"
|
||||
done
|
||||
|
||||
Reference in New Issue
Block a user