Files
unifolm-world-model-action/scripts/evaluation/analyze_metrics.py
2026-03-15 12:41:53 +08:00

612 lines
23 KiB
Python

import argparse
import json
import os
from pathlib import Path

import numpy as np
import pandas as pd

try:
    import matplotlib.pyplot as plt
except ModuleNotFoundError:
    plt = None
def parse_args() -> argparse.Namespace:
parser = argparse.ArgumentParser(
description="Analyze DDIM convergence logs for the five research directions."
)
parser.add_argument(
"--input_dir",
type=str,
required=True,
help="Directory containing stepwise_log.csv, sample_summary.csv, and round_summary.csv.",
)
parser.add_argument(
"--output_dir",
type=str,
default=None,
help="Directory to write analysis outputs. Defaults to <input_dir>/analysis.",
)
parser.add_argument(
"--mid_step_start",
type=int,
default=15,
help="Start of the step window used for cross-step latent-delta analysis.",
)
parser.add_argument(
"--mid_step_end",
type=int,
default=35,
help="End of the step window used for cross-step latent-delta analysis.",
)
parser.add_argument(
"--full50_step_cap",
type=int,
default=50,
help="Upper bound used when summarizing step-based metrics.",
)
parser.add_argument(
"--stability_threshold",
type=float,
default=None,
help="Optional threshold for offline first_stable_step analysis. If unset, stable-step metrics are not computed.",
)
parser.add_argument(
"--stability_window",
type=int,
default=3,
help="Consecutive-step window for offline first_stable_step analysis.",
)
return parser.parse_args()
def safe_float(value) -> float:
    """Convert *value* to ``float``, mapping any pandas-missing value to NaN."""
    return float("nan") if pd.isna(value) else float(value)
def describe_series(series: pd.Series) -> dict[str, float]:
    """Summarise a column with count, mean, std, min, quartiles, p90, p95, max.

    Entries are coerced to numeric (unparseable ones become NaN) and NaNs are
    dropped first. ``std`` is the population standard deviation
    (``ddof=0``). If nothing numeric remains, ``count`` is ``0.0`` and every
    other statistic is NaN. All values are plain Python floats and the key
    order is fixed.
    """
    values = pd.to_numeric(series, errors="coerce").dropna()
    stat_names = ("count", "mean", "std", "min", "p25", "median", "p75",
                  "p90", "p95", "max")
    if values.empty:
        summary = dict.fromkeys(stat_names, float("nan"))
        summary["count"] = 0.0
        return summary
    computed = (
        values.count(),
        values.mean(),
        values.std(ddof=0),
        values.min(),
        values.quantile(0.25),
        values.median(),
        values.quantile(0.75),
        values.quantile(0.90),
        values.quantile(0.95),
        values.max(),
    )
    return {name: float(stat) for name, stat in zip(stat_names, computed)}
def ensure_columns(frame: pd.DataFrame, required: list[str], frame_name: str) -> None:
    """Check that *frame* has every column listed in *required*.

    Raises:
        ValueError: naming *frame_name* and listing the absent columns in the
            order they appear in *required*.
    """
    available = frame.columns
    absent = [name for name in required if name not in available]
    if not absent:
        return
    raise ValueError(f"{frame_name} is missing required columns: {absent}")
def first_consecutive_below(values: pd.Series, threshold: float,
                            window: int) -> float:
    """Return where the first run of *window* values strictly below *threshold* starts.

    Values are coerced to numeric; a NaN or non-numeric entry breaks a run.
    The result is the 1-based position (within *values*, in their given
    order) of the first element of that run, as a float. It is NaN when
    *window* is not positive, when *values* is shorter than *window*, or
    when no such run exists.
    """
    sequence = pd.to_numeric(values, errors="coerce").tolist()
    if window <= 0 or window > len(sequence):
        return float("nan")
    run_length = 0
    for index, value in enumerate(sequence):
        run_length = run_length + 1 if (pd.notna(value) and value < threshold) else 0
        if run_length == window:
            # The run covers indices [index - window + 1, index]; report its
            # start as a 1-based position.
            return float(index - window + 2)
    return float("nan")
def load_tables(input_dir: Path) -> tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame]:
    """Read and validate the three CSV logs stored in *input_dir*.

    All three files are checked for existence before any is read, in the
    order stepwise log, sample summary, round summary. Then each table is
    read and checked for its required columns.

    Returns:
        ``(stepwise_df, sample_summary_df, round_summary_df)``.

    Raises:
        FileNotFoundError: for the first CSV file that does not exist.
        ValueError: if a table lacks one of its required columns.
    """
    table_specs = (
        ("stepwise_log.csv", [
            "sample_id",
            "scene",
            "pass_type",
            "round_id",
            "step",
            "latent_delta",
            "action_delta",
            "state_delta",
            "action_cosine_vs_full50",
            "state_cosine_vs_full50",
            "latent_l2_vs_full50",
        ]),
        ("sample_summary.csv", [
            "sample_id",
            "scene",
            "pass_type",
            "pass_total_time_s",
            "action_first_stable_step",
            "state_first_stable_step",
            "latent_first_stable_step",
            "action_vs_full50_90pct_step",
            "action_vs_full50_95pct_step",
            "oracle_budget_action",
            "oracle_budget_state",
        ]),
        ("round_summary.csv", [
            "sample_id",
            "scene",
            "round_id",
            "policy_pass_total_time_s",
            "world_model_pass_total_time_s",
            "round_total_time_s",
            "latent_init_dist_to_prev_round",
            "action_drift_vs_prev_round",
        ]),
    )
    csv_paths = [input_dir / file_name for file_name, _ in table_specs]
    first_missing = next((path for path in csv_paths if not path.exists()), None)
    if first_missing is not None:
        raise FileNotFoundError(f"Missing file: {first_missing}")
    frames = [pd.read_csv(path) for path in csv_paths]
    for frame, (file_name, required) in zip(frames, table_specs):
        ensure_columns(frame, required, file_name)
    stepwise_df, sample_summary_df, round_summary_df = frames
    return stepwise_df, sample_summary_df, round_summary_df
def aggregate_stepwise(stepwise_df: pd.DataFrame) -> pd.DataFrame:
    """Aggregate per-step metrics into mean/median/std per ``(pass_type, step)``.

    Every metric column is coerced to numeric first (unparseable entries
    become NaN). ``step_time_s`` is aggregated here but is not one of the
    columns required of the stepwise log at load time. Any metric column
    missing from the input is therefore added as an all-NaN column instead
    of raising ``KeyError``, so the output schema never changes.

    Returns:
        A frame with two-level columns ``(metric, statistic)`` plus the
        ``("pass_type", "")`` and ``("step", "")`` key columns.
    """
    metric_columns = [
        "step_time_s",
        "latent_delta",
        "action_delta",
        "state_delta",
        "action_cosine_vs_full50",
        "state_cosine_vs_full50",
        "latent_l2_vs_full50",
    ]
    frame = stepwise_df.copy()
    for column in metric_columns:
        if column in frame.columns:
            frame[column] = pd.to_numeric(frame[column], errors="coerce")
        else:
            frame[column] = np.nan
    grouped = frame.groupby(["pass_type", "step"])[metric_columns]
    return grouped.agg(["mean", "median", "std"]).reset_index()
def aggregate_scene_summary(sample_summary_df: pd.DataFrame) -> pd.DataFrame:
    """Aggregate per-sample metrics into mean/median/std per ``(scene, pass_type)``.

    Metrics are coerced to numeric first (unparseable entries become NaN).
    The round-level columns (``round_total_time_s``,
    ``policy_pass_total_time_s``, ``world_model_pass_total_time_s``,
    ``latent_init_dist_to_prev_round``, ``action_drift_vs_prev_round``) are
    not among the columns required of ``sample_summary.csv`` at load time.
    Any metric missing from the input is therefore added as an all-NaN
    column instead of raising ``KeyError``, so the output schema is the same
    either way.

    Returns:
        A frame with two-level columns ``(metric, statistic)`` plus the
        ``("scene", "")`` and ``("pass_type", "")`` key columns.
    """
    metrics = [
        "pass_total_time_s",
        "action_first_stable_step",
        "state_first_stable_step",
        "latent_first_stable_step",
        "action_vs_full50_90pct_step",
        "action_vs_full50_95pct_step",
        "oracle_budget_action",
        "oracle_budget_state",
        "round_total_time_s",
        "policy_pass_total_time_s",
        "world_model_pass_total_time_s",
        "latent_init_dist_to_prev_round",
        "action_drift_vs_prev_round",
    ]
    frame = sample_summary_df.copy()
    for column in metrics:
        if column in frame.columns:
            frame[column] = pd.to_numeric(frame[column], errors="coerce")
        else:
            frame[column] = np.nan
    grouped = frame.groupby(["scene", "pass_type"])[metrics]
    return grouped.agg(["mean", "median", "std"]).reset_index()
def compute_stability_summary(stepwise_df: pd.DataFrame, threshold: float,
                              window: int) -> pd.DataFrame:
    """Recompute first-stable-step metrics offline from the stepwise log.

    Rows are ordered by ``step`` and grouped by
    ``(sample_id, scene, pass_type)``. For each group,
    ``first_consecutive_below`` is applied to ``action_delta``,
    ``state_delta`` and ``latent_delta`` with *threshold* and *window*.

    Returns:
        One row per group with the columns ``sample_id``, ``scene``,
        ``pass_type``, ``action_first_stable_step``,
        ``state_first_stable_step`` and ``latent_first_stable_step``. These
        columns are present even when *stepwise_df* is empty, so the result
        can always be merged on the three key columns.
    """
    # NOTE(review): first_consecutive_below reports a 1-based row position
    # within each group; it equals the logged step label only if each group's
    # steps run 1..N without gaps. Confirm against the logger's convention.
    # NOTE(review): round_id is not part of the grouping key. If one
    # sample/pass_type spans several rounds, their rows are interleaved by
    # step here.
    output_columns = [
        "sample_id",
        "scene",
        "pass_type",
        "action_first_stable_step",
        "state_first_stable_step",
        "latent_first_stable_step",
    ]
    # A stable sort keeps rows that share a step in their original file order,
    # so results do not depend on the sort algorithm's tie-breaking.
    ordered = stepwise_df.sort_values("step", kind="mergesort")
    rows = []
    for (sample_id, scene, pass_type), group in ordered.groupby(
            ["sample_id", "scene", "pass_type"]):
        rows.append({
            "sample_id": sample_id,
            "scene": scene,
            "pass_type": pass_type,
            "action_first_stable_step": first_consecutive_below(
                group["action_delta"], threshold, window),
            "state_first_stable_step": first_consecutive_below(
                group["state_delta"], threshold, window),
            "latent_first_stable_step": first_consecutive_below(
                group["latent_delta"], threshold, window),
        })
    # Explicit columns: with no rows, a bare DataFrame(rows) would have no
    # columns and the later merge on the key columns would raise KeyError.
    return pd.DataFrame(rows, columns=output_columns)
def _share_strictly_less(frame: pd.DataFrame, left: str, right: str) -> float:
    """Return the fraction of rows where ``frame[left] < frame[right]``.

    Only rows where both columns hold a numeric value are counted. The result
    is NaN when there are no such rows.
    """
    left_values = pd.to_numeric(frame[left], errors="coerce")
    right_values = pd.to_numeric(frame[right], errors="coerce")
    comparable = left_values.notna() & right_values.notna()
    if not comparable.any():
        return float("nan")
    return float((left_values[comparable] < right_values[comparable]).mean())


def compute_direction_summary(stepwise_df: pd.DataFrame,
                              sample_summary_df: pd.DataFrame,
                              round_summary_df: pd.DataFrame,
                              mid_step_start: int,
                              mid_step_end: int,
                              stability_threshold: float | None,
                              stability_window: int) -> dict:
    """Build the nested summary dictionary for the five research directions.

    Top-level keys: ``analysis_config``, ``dataset_overview``,
    ``direction_original_early_stop``, ``direction_c_action_converges_first``,
    ``direction_d_cross_step_similarity``, ``direction_a_budget_heterogeneity``
    and ``direction_b_round_reuse``. Leaf statistics come from
    ``describe_series``. ``pass_type == "policy"`` rows and
    ``pass_type == "world_model"`` rows of *sample_summary_df* are analysed
    separately. The mid-step window ``[mid_step_start, mid_step_end]`` is
    inclusive on both ends.

    The two ``share_*`` values are fractions of policy samples in which the
    action metric is strictly smaller than ``latent_first_stable_step``. Each
    is computed over the samples where *its own* two columns are both
    present, so a missing 95%-step no longer removes a sample from the
    first-stable comparison.
    """
    policy_summary = sample_summary_df[
        sample_summary_df["pass_type"] == "policy"].copy()
    world_summary = sample_summary_df[
        sample_summary_df["pass_type"] == "world_model"].copy()
    mid_window_rows = stepwise_df[
        (stepwise_df["step"] >= mid_step_start)
        & (stepwise_df["step"] <= mid_step_end)].copy()

    action_first_beats_latent = _share_strictly_less(
        policy_summary, "action_first_stable_step", "latent_first_stable_step")
    action_95_beats_latent = _share_strictly_less(
        policy_summary, "action_vs_full50_95pct_step", "latent_first_stable_step")

    return {
        "analysis_config": {
            "stability_threshold": stability_threshold,
            "stability_window": int(stability_window),
            "mid_step_window": [int(mid_step_start), int(mid_step_end)],
        },
        "dataset_overview": {
            "num_step_rows": int(len(stepwise_df)),
            "num_sample_rows": int(len(sample_summary_df)),
            "num_round_rows": int(len(round_summary_df)),
            "num_unique_samples": int(sample_summary_df["sample_id"].nunique()),
            "scenes": sorted(sample_summary_df["scene"].dropna().unique().tolist()),
            "pass_types": sorted(sample_summary_df["pass_type"].dropna().unique().tolist()),
        },
        "direction_original_early_stop": {
            "latent_first_stable_step_policy":
                describe_series(policy_summary["latent_first_stable_step"]),
            "latent_first_stable_step_world_model":
                describe_series(world_summary["latent_first_stable_step"]),
            "latent_l2_vs_full50":
                describe_series(stepwise_df["latent_l2_vs_full50"]),
        },
        "direction_c_action_converges_first": {
            "action_first_stable_step":
                describe_series(policy_summary["action_first_stable_step"]),
            "latent_first_stable_step":
                describe_series(policy_summary["latent_first_stable_step"]),
            "action_vs_full50_95pct_step":
                describe_series(policy_summary["action_vs_full50_95pct_step"]),
            "share_action_first_stable_before_latent":
                action_first_beats_latent,
            "share_action_95pct_before_latent":
                action_95_beats_latent,
        },
        "direction_d_cross_step_similarity": {
            "latent_delta_mid_steps":
                describe_series(mid_window_rows["latent_delta"]),
            "action_delta_mid_steps":
                describe_series(mid_window_rows["action_delta"]),
            "state_delta_mid_steps":
                describe_series(mid_window_rows["state_delta"]),
        },
        "direction_a_budget_heterogeneity": {
            "oracle_budget_action":
                describe_series(policy_summary["oracle_budget_action"]),
            "oracle_budget_state":
                describe_series(world_summary["oracle_budget_state"]),
            "oracle_budget_action_by_scene": {
                scene: describe_series(group["oracle_budget_action"])
                for scene, group in policy_summary.groupby("scene")
            },
        },
        "direction_b_round_reuse": {
            "latent_init_dist_to_prev_round":
                describe_series(round_summary_df["latent_init_dist_to_prev_round"]),
            "action_drift_vs_prev_round":
                describe_series(round_summary_df["action_drift_vs_prev_round"]),
            "round_total_time_s":
                describe_series(round_summary_df["round_total_time_s"]),
            "policy_pass_total_time_s":
                describe_series(round_summary_df["policy_pass_total_time_s"]),
            "world_model_pass_total_time_s":
                describe_series(round_summary_df["world_model_pass_total_time_s"]),
        },
    }
def _json_safe(value):
    """Recursively replace non-finite floats (NaN, +/-inf) with ``None``."""
    if isinstance(value, dict):
        return {key: _json_safe(item) for key, item in value.items()}
    if isinstance(value, (list, tuple)):
        return [_json_safe(item) for item in value]
    if isinstance(value, float) and not np.isfinite(value):
        return None
    return value


def save_json(path: Path, payload: dict) -> None:
    """Write *payload* to *path* as indented, UTF-8, standards-compliant JSON.

    Statistics for empty inputs are NaN. ``json.dump`` would otherwise emit
    them as bare ``NaN``/``Infinity`` tokens, which RFC 8259 parsers reject.
    Such values are written as ``null`` instead, and ``allow_nan=False``
    guarantees none slip through.
    """
    with open(path, "w", encoding="utf-8") as file:
        json.dump(_json_safe(payload), file, indent=2, ensure_ascii=False,
                  allow_nan=False)
def _format_direction_section(title: str, payload: dict) -> list[str]:
    """Render one direction's metrics as a Markdown ``###`` section.

    A dict containing ``mean``/``median``/``p90`` (as produced by
    ``describe_series``) becomes one summary line. Any other dict is expanded
    one level as a nested list. Floats print with 4 decimals and everything
    else via ``str``. The section ends with a blank line.
    """
    lines = [f"### {title}", ""]
    for key, value in payload.items():
        if not isinstance(value, dict):
            if isinstance(value, float):
                lines.append(f"- `{key}`: {value:.4f}")
            else:
                lines.append(f"- `{key}`: {value}")
            continue
        if {"median", "mean", "p90"} <= set(value.keys()):
            lines.append(
                f"- `{key}`: mean={value['mean']:.4f}, median={value['median']:.4f}, "
                f"p90={value['p90']:.4f}, p95={value['p95']:.4f}")
            continue
        lines.append(f"- `{key}`:")
        for sub_key, sub_value in value.items():
            # Two-space indent: CommonMark needs the nested marker under the
            # parent item's content column, otherwise it renders as a sibling.
            if isinstance(sub_value, dict) and {"median", "mean"} <= set(
                    sub_value.keys()):
                lines.append(
                    f"  - `{sub_key}`: mean={sub_value['mean']:.4f}, "
                    f"median={sub_value['median']:.4f}, std={sub_value['std']:.4f}")
            else:
                lines.append(f"  - `{sub_key}`: {sub_value}")
    lines.append("")
    return lines


def write_markdown_report(path: Path, direction_summary: dict) -> None:
    """Write a Markdown report of *direction_summary* to *path* (UTF-8).

    The report has the analysis config, a dataset overview, and one section
    per research direction. *direction_summary* must contain the keys
    produced by ``compute_direction_summary``.
    """
    config = direction_summary["analysis_config"]
    overview = direction_summary["dataset_overview"]
    # Scene / pass-type values may be parsed as numbers by read_csv; str()
    # keeps join() from raising TypeError on them.
    scenes_text = ", ".join(str(scene) for scene in overview["scenes"])
    pass_types_text = ", ".join(str(kind) for kind in overview["pass_types"])
    lines = [
        "# DDIM Analysis Report",
        "",
        "## Analysis Config",
        "",
        f"- Stability threshold: {config['stability_threshold']}",
        f"- Stability window: {config['stability_window']}",
        f"- Mid-step window: {config['mid_step_window'][0]}-{config['mid_step_window'][1]}",
        "",
        "## Dataset Overview",
        "",
        f"- Unique samples: {overview['num_unique_samples']}",
        f"- Step rows: {overview['num_step_rows']}",
        f"- Sample rows: {overview['num_sample_rows']}",
        f"- Round rows: {overview['num_round_rows']}",
        f"- Scenes: {scenes_text}",
        f"- Pass types: {pass_types_text}",
        "",
        "## Five Directions",
        "",
    ]
    sections = (
        ("Original Early Stop", "direction_original_early_stop"),
        ("Direction C: Action Converges First", "direction_c_action_converges_first"),
        ("Direction D: Cross-Step Similarity", "direction_d_cross_step_similarity"),
        ("Direction A: Budget Heterogeneity", "direction_a_budget_heterogeneity"),
        ("Direction B: Round Reuse", "direction_b_round_reuse"),
    )
    for title, summary_key in sections:
        lines.extend(_format_direction_section(title, direction_summary[summary_key]))
    with open(path, "w", encoding="utf-8") as file:
        file.write("\n".join(lines) + "\n")
def save_convergence_plot(stepwise_df: pd.DataFrame, output_path: Path) -> None:
    """Plot per-step mean convergence curves, one line per ``pass_type``.

    Draws a 2x2 grid (latent delta, action cosine vs full50, state cosine vs
    full50, latent L2 vs full50) and saves it to *output_path* at 200 dpi.
    Returns without doing anything when matplotlib is unavailable
    (``plt is None``).
    """
    if plt is None:
        return
    panels = (
        ((0, 0), "latent_delta", "Mean Latent Delta"),
        ((0, 1), "action_cosine_vs_full50", "Mean Action Cosine vs Full50"),
        ((1, 0), "state_cosine_vs_full50", "Mean State Cosine vs Full50"),
        ((1, 1), "latent_l2_vs_full50", "Mean Latent L2 vs Full50"),
    )
    plotted_columns = [column for _, column, _ in panels]
    frame = stepwise_df.copy()
    # Coerce as aggregate_stepwise does. Otherwise a stray non-numeric entry
    # makes the column object-typed, it drops out of the mean, and the
    # lookup below raises KeyError.
    for column in plotted_columns:
        frame[column] = pd.to_numeric(frame[column], errors="coerce")
    pass_types = sorted(frame["pass_type"].dropna().unique().tolist())
    fig, axes = plt.subplots(2, 2, figsize=(14, 10))
    try:
        for pass_type in pass_types:
            subset = frame[frame["pass_type"] == pass_type]
            means = subset.groupby("step")[plotted_columns].mean().reset_index()
            for position, column, _ in panels:
                axes[position].plot(means["step"], means[column], label=pass_type)
        for position, _, title in panels:
            axes[position].set_title(title)
        for axis in axes.flatten():
            axis.set_xlabel("Step")
            axis.grid(alpha=0.3)
            # legend() on axes without labelled artists only emits a warning.
            if pass_types:
                axis.legend()
        fig.tight_layout()
        fig.savefig(output_path, dpi=200, bbox_inches="tight")
    finally:
        # Release the figure even if plotting or saving fails.
        plt.close(fig)
def save_budget_distribution_plot(sample_summary_df: pd.DataFrame,
                                  output_path: Path) -> None:
    """Plot the distribution of ``oracle_budget_action`` for policy passes.

    Left panel: overlaid per-scene histograms (15 bins). Right panel:
    per-scene box plots without fliers. The figure is saved to *output_path*
    at 200 dpi. Returns without doing anything when matplotlib is
    unavailable (``plt is None``).
    """
    if plt is None:
        return
    policy_summary = sample_summary_df[
        sample_summary_df["pass_type"] == "policy"]
    scenes = sorted(policy_summary["scene"].dropna().unique().tolist())
    budgets_by_scene = [
        pd.to_numeric(
            policy_summary.loc[policy_summary["scene"] == scene, "oracle_budget_action"],
            errors="coerce").dropna().to_numpy()
        for scene in scenes
    ]
    fig, (hist_ax, box_ax) = plt.subplots(1, 2, figsize=(14, 5))
    try:
        for scene, budgets in zip(scenes, budgets_by_scene):
            hist_ax.hist(budgets, bins=15, alpha=0.5, label=scene)
        hist_ax.set_title("Oracle Budget Action Distribution")
        hist_ax.set_xlabel("Step")
        hist_ax.set_ylabel("Count")
        if scenes:
            hist_ax.legend()
        hist_ax.grid(alpha=0.3)
        if scenes and any(len(budgets) > 0 for budgets in budgets_by_scene):
            box_ax.boxplot(budgets_by_scene, showfliers=False)
            # boxplot(labels=...) is deprecated since Matplotlib 3.9 (renamed
            # tick_labels). Setting tick labels afterwards works on all versions.
            box_ax.set_xticklabels([str(scene) for scene in scenes])
        box_ax.set_title("Oracle Budget Action by Scene")
        box_ax.set_ylabel("Step")
        box_ax.tick_params(axis="x", rotation=20)
        box_ax.grid(alpha=0.3)
        fig.tight_layout()
        fig.savefig(output_path, dpi=200, bbox_inches="tight")
    finally:
        # Release the figure even if plotting or saving fails.
        plt.close(fig)
def save_midstep_delta_plot(stepwise_df: pd.DataFrame, output_path: Path,
                            mid_step_start: int, mid_step_end: int) -> None:
    """Box-plot ``latent_delta`` for each step in the inclusive mid-step window.

    Saves the figure to *output_path* at 200 dpi. Returns without doing
    anything when matplotlib is unavailable (``plt is None``).
    """
    if plt is None:
        return
    in_window = ((stepwise_df["step"] >= mid_step_start)
                 & (stepwise_df["step"] <= mid_step_end))
    subset = stepwise_df[in_window]
    steps = sorted(subset["step"].dropna().unique().tolist())
    deltas_by_step = [
        pd.to_numeric(subset.loc[subset["step"] == step, "latent_delta"],
                      errors="coerce").dropna().to_numpy()
        for step in steps
    ]
    fig, ax = plt.subplots(figsize=(14, 5))
    try:
        if steps and any(len(deltas) > 0 for deltas in deltas_by_step):
            ax.boxplot(deltas_by_step, showfliers=False)
            # boxplot(labels=...) is deprecated since Matplotlib 3.9 (renamed
            # tick_labels). Setting tick labels afterwards works on all versions.
            ax.set_xticklabels([str(step) for step in steps])
        ax.set_title(
            f"Latent Delta Distribution for Steps {mid_step_start}-{mid_step_end}")
        ax.set_xlabel("Step")
        ax.set_ylabel("latent_delta")
        ax.grid(alpha=0.3)
        fig.tight_layout()
        fig.savefig(output_path, dpi=200, bbox_inches="tight")
    finally:
        # Release the figure even if plotting or saving fails.
        plt.close(fig)
def save_round_reuse_plot(round_summary_df: pd.DataFrame, output_path: Path) -> None:
    """Plot round-reuse diagnostics from the round summary.

    Left panel: per-scene scatter of ``latent_init_dist_to_prev_round``
    against ``action_drift_vs_prev_round`` (rows missing either are
    skipped). Right panel: per-scene box plots of ``round_total_time_s``
    without fliers. The figure is saved to *output_path* at 200 dpi. Returns
    without doing anything when matplotlib is unavailable (``plt is None``).
    """
    if plt is None:
        return
    fig, (scatter_ax, box_ax) = plt.subplots(1, 2, figsize=(14, 5))
    try:
        paired = round_summary_df.dropna(
            subset=["latent_init_dist_to_prev_round", "action_drift_vs_prev_round"])
        if not paired.empty:
            for scene, group in paired.groupby("scene"):
                scatter_ax.scatter(group["latent_init_dist_to_prev_round"],
                                   group["action_drift_vs_prev_round"],
                                   alpha=0.6,
                                   label=scene)
            scatter_ax.legend()
        scatter_ax.set_title("Round Reuse: Latent Init Dist vs Action Drift")
        scatter_ax.set_xlabel("latent_init_dist_to_prev_round")
        scatter_ax.set_ylabel("action_drift_vs_prev_round")
        scatter_ax.grid(alpha=0.3)

        scene_groups = []
        scene_labels = []
        for scene, group in round_summary_df.groupby("scene"):
            values = pd.to_numeric(group["round_total_time_s"],
                                   errors="coerce").dropna()
            if len(values) > 0:
                scene_groups.append(values.to_numpy())
                scene_labels.append(str(scene))
        if scene_groups:
            box_ax.boxplot(scene_groups, showfliers=False)
            # boxplot(labels=...) is deprecated since Matplotlib 3.9 (renamed
            # tick_labels). Setting tick labels afterwards works on all versions.
            box_ax.set_xticklabels(scene_labels)
        box_ax.set_title("Round Total Time by Scene")
        box_ax.set_ylabel("Seconds")
        box_ax.tick_params(axis="x", rotation=20)
        box_ax.grid(alpha=0.3)
        fig.tight_layout()
        fig.savefig(output_path, dpi=200, bbox_inches="tight")
    finally:
        # Release the figure even if plotting or saving fails.
        plt.close(fig)
def main() -> None:
    """Run the full offline analysis from the command line.

    Reads the three CSV logs from ``--input_dir`` and writes into
    ``--output_dir`` (``<input_dir>/analysis`` when not given; created if
    absent): ``step_aggregate.csv``, ``scene_summary.csv``,
    ``stability_summary.csv`` (only with ``--stability_threshold``),
    ``direction_summary.json`` and ``analysis_report.md``, plus four PNG
    plots when matplotlib was importable.
    """
    args = parse_args()
    input_dir = Path(args.input_dir)
    output_dir = input_dir / "analysis"
    if args.output_dir:
        output_dir = Path(args.output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    stepwise_df, sample_summary_df, round_summary_df = load_tables(input_dir)

    stability_summary_df = None
    analysis_samples = sample_summary_df.copy()
    if args.stability_threshold is not None:
        # Offline mode: drop the logged first-stable-step columns and replace
        # them with values recomputed from the stepwise log.
        stability_summary_df = compute_stability_summary(
            stepwise_df, args.stability_threshold, args.stability_window)
        logged_stable_columns = [
            "action_first_stable_step",
            "state_first_stable_step",
            "latent_first_stable_step",
        ]
        trimmed = analysis_samples.drop(columns=logged_stable_columns,
                                        errors="ignore")
        analysis_samples = trimmed.merge(stability_summary_df,
                                         on=["sample_id", "scene", "pass_type"],
                                         how="left")

    step_aggregate_df = aggregate_stepwise(stepwise_df)
    scene_summary_df = aggregate_scene_summary(analysis_samples)
    direction_summary = compute_direction_summary(
        stepwise_df=stepwise_df,
        sample_summary_df=analysis_samples,
        round_summary_df=round_summary_df,
        mid_step_start=args.mid_step_start,
        mid_step_end=args.mid_step_end,
        stability_threshold=args.stability_threshold,
        stability_window=args.stability_window,
    )

    csv_outputs = [
        ("step_aggregate.csv", step_aggregate_df),
        ("scene_summary.csv", scene_summary_df),
    ]
    if stability_summary_df is not None:
        csv_outputs.append(("stability_summary.csv", stability_summary_df))
    for file_name, table in csv_outputs:
        table.to_csv(output_dir / file_name, index=False)
    save_json(output_dir / "direction_summary.json", direction_summary)
    write_markdown_report(output_dir / "analysis_report.md", direction_summary)

    if plt is None:
        print("matplotlib not installed; skipped PNG plots.")
    else:
        save_convergence_plot(stepwise_df, output_dir / "convergence_curves.png")
        save_budget_distribution_plot(sample_summary_df,
                                      output_dir / "budget_distribution.png")
        save_midstep_delta_plot(stepwise_df,
                                output_dir / "midstep_latent_delta.png",
                                args.mid_step_start, args.mid_step_end)
        save_round_reuse_plot(round_summary_df, output_dir / "round_reuse.png")
    print(f"Analysis written to: {output_dir}")


if __name__ == "__main__":
    main()