import argparse import json import os from pathlib import Path import numpy as np import pandas as pd try: import matplotlib.pyplot as plt except ModuleNotFoundError: plt = None def parse_args() -> argparse.Namespace: parser = argparse.ArgumentParser( description="Analyze DDIM convergence logs for the five research directions." ) parser.add_argument( "--input_dir", type=str, required=True, help="Directory containing stepwise_log.csv, sample_summary.csv, and round_summary.csv.", ) parser.add_argument( "--output_dir", type=str, default=None, help="Directory to write analysis outputs. Defaults to /analysis.", ) parser.add_argument( "--mid_step_start", type=int, default=15, help="Start of the step window used for cross-step latent-delta analysis.", ) parser.add_argument( "--mid_step_end", type=int, default=35, help="End of the step window used for cross-step latent-delta analysis.", ) parser.add_argument( "--full50_step_cap", type=int, default=50, help="Upper bound used when summarizing step-based metrics.", ) parser.add_argument( "--stability_threshold", type=float, default=None, help="Optional threshold for offline first_stable_step analysis. If unset, stable-step metrics are not computed.", ) parser.add_argument( "--stability_window", type=int, default=3, help="Consecutive-step window for offline first_stable_step analysis.", ) return parser.parse_args() def safe_float(value) -> float: if pd.isna(value): return float("nan") return float(value) def describe_series(series: pd.Series) -> dict[str, float]: series = pd.to_numeric(series, errors="coerce").dropna() if series.empty: return { "count": 0.0, "mean": float("nan"), "std": float("nan"), "min": float("nan"), "p25": float("nan"), "median": float("nan"), "p75": float("nan"), "p90": float("nan"), "p95": float("nan"), "max": float("nan"), } return { "count": float(series.count()), "mean": float(series.mean()), "std": float(series.std(ddof=0)), "min": float(series.min()), "p25": float(series.quantile(0.25)), "median": float(series.median()), "p75": float(series.quantile(0.75)), "p90": float(series.quantile(0.90)), "p95": float(series.quantile(0.95)), "max": float(series.max()), } def ensure_columns(frame: pd.DataFrame, required: list[str], frame_name: str) -> None: missing = [column for column in required if column not in frame.columns] if missing: raise ValueError(f"{frame_name} is missing required columns: {missing}") def first_consecutive_below(values: pd.Series, threshold: float, window: int) -> float: cleaned = pd.to_numeric(values, errors="coerce").tolist() if window <= 0 or len(cleaned) < window: return float("nan") for start in range(len(cleaned) - window + 1): chunk = cleaned[start:start + window] if all(pd.notna(value) and value < threshold for value in chunk): return float(start + 1) return float("nan") def load_tables(input_dir: Path) -> tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame]: stepwise_path = input_dir / "stepwise_log.csv" sample_summary_path = input_dir / "sample_summary.csv" round_summary_path = input_dir / "round_summary.csv" if not stepwise_path.exists(): raise FileNotFoundError(f"Missing file: {stepwise_path}") if not sample_summary_path.exists(): raise FileNotFoundError(f"Missing file: {sample_summary_path}") if not round_summary_path.exists(): raise FileNotFoundError(f"Missing file: {round_summary_path}") stepwise_df = pd.read_csv(stepwise_path) sample_summary_df = pd.read_csv(sample_summary_path) round_summary_df = pd.read_csv(round_summary_path) ensure_columns( stepwise_df, [ "sample_id", "scene", "pass_type", "round_id", "step", "latent_delta", "action_delta", "state_delta", "action_cosine_vs_full50", "state_cosine_vs_full50", "latent_l2_vs_full50", ], "stepwise_log.csv", ) ensure_columns( sample_summary_df, [ "sample_id", "scene", "pass_type", "pass_total_time_s", "action_first_stable_step", "state_first_stable_step", "latent_first_stable_step", "action_vs_full50_90pct_step", "action_vs_full50_95pct_step", "oracle_budget_action", "oracle_budget_state", ], "sample_summary.csv", ) ensure_columns( round_summary_df, [ "sample_id", "scene", "round_id", "policy_pass_total_time_s", "world_model_pass_total_time_s", "round_total_time_s", "latent_init_dist_to_prev_round", "action_drift_vs_prev_round", ], "round_summary.csv", ) return stepwise_df, sample_summary_df, round_summary_df def aggregate_stepwise(stepwise_df: pd.DataFrame) -> pd.DataFrame: metric_columns = [ "step_time_s", "latent_delta", "action_delta", "state_delta", "action_cosine_vs_full50", "state_cosine_vs_full50", "latent_l2_vs_full50", ] stepwise_df = stepwise_df.copy() for column in metric_columns: stepwise_df[column] = pd.to_numeric(stepwise_df[column], errors="coerce") grouped = stepwise_df.groupby(["pass_type", "step"])[metric_columns] return grouped.agg(["mean", "median", "std"]).reset_index() def aggregate_scene_summary(sample_summary_df: pd.DataFrame) -> pd.DataFrame: metrics = [ "pass_total_time_s", "action_first_stable_step", "state_first_stable_step", "latent_first_stable_step", "action_vs_full50_90pct_step", "action_vs_full50_95pct_step", "oracle_budget_action", "oracle_budget_state", "round_total_time_s", "policy_pass_total_time_s", "world_model_pass_total_time_s", "latent_init_dist_to_prev_round", "action_drift_vs_prev_round", ] sample_summary_df = sample_summary_df.copy() for column in metrics: sample_summary_df[column] = pd.to_numeric(sample_summary_df[column], errors="coerce") grouped = sample_summary_df.groupby(["scene", "pass_type"])[metrics] return grouped.agg(["mean", "median", "std"]).reset_index() def compute_stability_summary(stepwise_df: pd.DataFrame, threshold: float, window: int) -> pd.DataFrame: rows = [] grouped = stepwise_df.sort_values("step").groupby( ["sample_id", "scene", "pass_type"]) for (sample_id, scene, pass_type), group in grouped: rows.append({ "sample_id": sample_id, "scene": scene, "pass_type": pass_type, "action_first_stable_step": first_consecutive_below( group["action_delta"], threshold, window), "state_first_stable_step": first_consecutive_below( group["state_delta"], threshold, window), "latent_first_stable_step": first_consecutive_below( group["latent_delta"], threshold, window), }) return pd.DataFrame(rows) def compute_direction_summary(stepwise_df: pd.DataFrame, sample_summary_df: pd.DataFrame, round_summary_df: pd.DataFrame, mid_step_start: int, mid_step_end: int, stability_threshold: float | None, stability_window: int) -> dict: policy_summary = sample_summary_df[ sample_summary_df["pass_type"] == "policy"].copy() world_summary = sample_summary_df[ sample_summary_df["pass_type"] == "world_model"].copy() latent_mid = stepwise_df[ (stepwise_df["step"] >= mid_step_start) & (stepwise_df["step"] <= mid_step_end)].copy() action_before_latent = policy_summary[[ "action_first_stable_step", "latent_first_stable_step", "action_vs_full50_95pct_step", ]].dropna() if action_before_latent.empty: action_first_beats_latent = float("nan") action_95_beats_latent = float("nan") else: action_first_beats_latent = float( (action_before_latent["action_first_stable_step"] < action_before_latent["latent_first_stable_step"]).mean()) action_95_beats_latent = float( (action_before_latent["action_vs_full50_95pct_step"] < action_before_latent["latent_first_stable_step"]).mean()) direction_summary = { "analysis_config": { "stability_threshold": stability_threshold, "stability_window": int(stability_window), "mid_step_window": [int(mid_step_start), int(mid_step_end)], }, "dataset_overview": { "num_step_rows": int(len(stepwise_df)), "num_sample_rows": int(len(sample_summary_df)), "num_round_rows": int(len(round_summary_df)), "num_unique_samples": int(sample_summary_df["sample_id"].nunique()), "scenes": sorted(sample_summary_df["scene"].dropna().unique().tolist()), "pass_types": sorted(sample_summary_df["pass_type"].dropna().unique().tolist()), }, "direction_original_early_stop": { "latent_first_stable_step_policy": describe_series(policy_summary["latent_first_stable_step"]), "latent_first_stable_step_world_model": describe_series(world_summary["latent_first_stable_step"]), "latent_l2_vs_full50": describe_series(stepwise_df["latent_l2_vs_full50"]), }, "direction_c_action_converges_first": { "action_first_stable_step": describe_series(policy_summary["action_first_stable_step"]), "latent_first_stable_step": describe_series(policy_summary["latent_first_stable_step"]), "action_vs_full50_95pct_step": describe_series(policy_summary["action_vs_full50_95pct_step"]), "share_action_first_stable_before_latent": action_first_beats_latent, "share_action_95pct_before_latent": action_95_beats_latent, }, "direction_d_cross_step_similarity": { "latent_delta_mid_steps": describe_series(latent_mid["latent_delta"]), "action_delta_mid_steps": describe_series(latent_mid["action_delta"]), "state_delta_mid_steps": describe_series(latent_mid["state_delta"]), }, "direction_a_budget_heterogeneity": { "oracle_budget_action": describe_series(policy_summary["oracle_budget_action"]), "oracle_budget_state": describe_series(world_summary["oracle_budget_state"]), "oracle_budget_action_by_scene": { scene: describe_series(group["oracle_budget_action"]) for scene, group in policy_summary.groupby("scene") }, }, "direction_b_round_reuse": { "latent_init_dist_to_prev_round": describe_series(round_summary_df["latent_init_dist_to_prev_round"]), "action_drift_vs_prev_round": describe_series(round_summary_df["action_drift_vs_prev_round"]), "round_total_time_s": describe_series(round_summary_df["round_total_time_s"]), "policy_pass_total_time_s": describe_series(round_summary_df["policy_pass_total_time_s"]), "world_model_pass_total_time_s": describe_series(round_summary_df["world_model_pass_total_time_s"]), }, } return direction_summary def save_json(path: Path, payload: dict) -> None: with open(path, "w", encoding="utf-8") as file: json.dump(payload, file, indent=2, ensure_ascii=False) def write_markdown_report(path: Path, direction_summary: dict) -> None: lines = [ "# DDIM Analysis Report", "", "## Analysis Config", "", ] config = direction_summary["analysis_config"] lines.extend([ f"- Stability threshold: {config['stability_threshold']}", f"- Stability window: {config['stability_window']}", f"- Mid-step window: {config['mid_step_window'][0]}-{config['mid_step_window'][1]}", "", "## Dataset Overview", "", ]) overview = direction_summary["dataset_overview"] lines.extend([ f"- Unique samples: {overview['num_unique_samples']}", f"- Step rows: {overview['num_step_rows']}", f"- Sample rows: {overview['num_sample_rows']}", f"- Round rows: {overview['num_round_rows']}", f"- Scenes: {', '.join(overview['scenes'])}", f"- Pass types: {', '.join(overview['pass_types'])}", "", "## Five Directions", "", ]) def append_stats(title: str, payload: dict) -> None: lines.append(f"### {title}") lines.append("") for key, value in payload.items(): if isinstance(value, dict): if {"median", "mean", "p90"} <= set(value.keys()): lines.append( f"- `{key}`: mean={value['mean']:.4f}, median={value['median']:.4f}, p90={value['p90']:.4f}, p95={value['p95']:.4f}" ) else: lines.append(f"- `{key}`:") for sub_key, sub_value in value.items(): if isinstance(sub_value, dict) and {"median", "mean"} <= set( sub_value.keys()): lines.append( f" - `{sub_key}`: mean={sub_value['mean']:.4f}, median={sub_value['median']:.4f}, std={sub_value['std']:.4f}" ) else: lines.append(f" - `{sub_key}`: {sub_value}") else: if isinstance(value, float): lines.append(f"- `{key}`: {value:.4f}") else: lines.append(f"- `{key}`: {value}") lines.append("") append_stats("Original Early Stop", direction_summary["direction_original_early_stop"]) append_stats("Direction C: Action Converges First", direction_summary["direction_c_action_converges_first"]) append_stats("Direction D: Cross-Step Similarity", direction_summary["direction_d_cross_step_similarity"]) append_stats("Direction A: Budget Heterogeneity", direction_summary["direction_a_budget_heterogeneity"]) append_stats("Direction B: Round Reuse", direction_summary["direction_b_round_reuse"]) with open(path, "w", encoding="utf-8") as file: file.write("\n".join(lines) + "\n") def save_convergence_plot(stepwise_df: pd.DataFrame, output_path: Path) -> None: if plt is None: return fig, axes = plt.subplots(2, 2, figsize=(14, 10)) pass_types = sorted(stepwise_df["pass_type"].dropna().unique().tolist()) for pass_type in pass_types: subset = stepwise_df[stepwise_df["pass_type"] == pass_type] grouped = subset.groupby("step").mean(numeric_only=True).reset_index() axes[0, 0].plot(grouped["step"], grouped["latent_delta"], label=pass_type) axes[0, 1].plot(grouped["step"], grouped["action_cosine_vs_full50"], label=pass_type) axes[1, 0].plot(grouped["step"], grouped["state_cosine_vs_full50"], label=pass_type) axes[1, 1].plot(grouped["step"], grouped["latent_l2_vs_full50"], label=pass_type) axes[0, 0].set_title("Mean Latent Delta") axes[0, 1].set_title("Mean Action Cosine vs Full50") axes[1, 0].set_title("Mean State Cosine vs Full50") axes[1, 1].set_title("Mean Latent L2 vs Full50") for axis in axes.flatten(): axis.set_xlabel("Step") axis.grid(alpha=0.3) axis.legend() fig.tight_layout() fig.savefig(output_path, dpi=200, bbox_inches="tight") plt.close(fig) def save_budget_distribution_plot(sample_summary_df: pd.DataFrame, output_path: Path) -> None: if plt is None: return policy_summary = sample_summary_df[ sample_summary_df["pass_type"] == "policy"].copy() scenes = sorted(policy_summary["scene"].dropna().unique().tolist()) fig, axes = plt.subplots(1, 2, figsize=(14, 5)) for scene in scenes: subset = policy_summary[policy_summary["scene"] == scene] axes[0].hist(subset["oracle_budget_action"].dropna(), bins=15, alpha=0.5, label=scene) axes[0].set_title("Oracle Budget Action Distribution") axes[0].set_xlabel("Step") axes[0].set_ylabel("Count") axes[0].legend() axes[0].grid(alpha=0.3) boxplot_data = [ policy_summary[policy_summary["scene"] == scene]["oracle_budget_action"].dropna() for scene in scenes ] if scenes and any(len(values) > 0 for values in boxplot_data): axes[1].boxplot(boxplot_data, labels=scenes, showfliers=False) axes[1].set_title("Oracle Budget Action by Scene") axes[1].set_ylabel("Step") axes[1].tick_params(axis="x", rotation=20) axes[1].grid(alpha=0.3) fig.tight_layout() fig.savefig(output_path, dpi=200, bbox_inches="tight") plt.close(fig) def save_midstep_delta_plot(stepwise_df: pd.DataFrame, output_path: Path, mid_step_start: int, mid_step_end: int) -> None: if plt is None: return subset = stepwise_df[(stepwise_df["step"] >= mid_step_start) & (stepwise_df["step"] <= mid_step_end)].copy() steps = sorted(subset["step"].dropna().unique().tolist()) data = [subset[subset["step"] == step]["latent_delta"].dropna() for step in steps] fig, ax = plt.subplots(figsize=(14, 5)) if steps and any(len(values) > 0 for values in data): ax.boxplot(data, labels=steps, showfliers=False) ax.set_title( f"Latent Delta Distribution for Steps {mid_step_start}-{mid_step_end}") ax.set_xlabel("Step") ax.set_ylabel("latent_delta") ax.grid(alpha=0.3) fig.tight_layout() fig.savefig(output_path, dpi=200, bbox_inches="tight") plt.close(fig) def save_round_reuse_plot(round_summary_df: pd.DataFrame, output_path: Path) -> None: if plt is None: return fig, axes = plt.subplots(1, 2, figsize=(14, 5)) subset = round_summary_df.dropna( subset=["latent_init_dist_to_prev_round", "action_drift_vs_prev_round"]) if not subset.empty: for scene, group in subset.groupby("scene"): axes[0].scatter(group["latent_init_dist_to_prev_round"], group["action_drift_vs_prev_round"], alpha=0.6, label=scene) axes[0].legend() axes[0].set_title("Round Reuse: Latent Init Dist vs Action Drift") axes[0].set_xlabel("latent_init_dist_to_prev_round") axes[0].set_ylabel("action_drift_vs_prev_round") axes[0].grid(alpha=0.3) scene_groups = [] scene_labels = [] for scene, group in round_summary_df.groupby("scene"): values = group["round_total_time_s"].dropna() if len(values) > 0: scene_groups.append(values) scene_labels.append(scene) if scene_groups: axes[1].boxplot(scene_groups, labels=scene_labels, showfliers=False) axes[1].set_title("Round Total Time by Scene") axes[1].set_ylabel("Seconds") axes[1].tick_params(axis="x", rotation=20) axes[1].grid(alpha=0.3) fig.tight_layout() fig.savefig(output_path, dpi=200, bbox_inches="tight") plt.close(fig) def main() -> None: args = parse_args() input_dir = Path(args.input_dir) output_dir = Path(args.output_dir) if args.output_dir else input_dir / "analysis" output_dir.mkdir(parents=True, exist_ok=True) stepwise_df, sample_summary_df, round_summary_df = load_tables(input_dir) stability_summary_df = None sample_summary_for_analysis = sample_summary_df.copy() if args.stability_threshold is not None: stability_summary_df = compute_stability_summary(stepwise_df, args.stability_threshold, args.stability_window) sample_summary_for_analysis = sample_summary_for_analysis.drop( columns=[ "action_first_stable_step", "state_first_stable_step", "latent_first_stable_step", ], errors="ignore", ).merge(stability_summary_df, on=["sample_id", "scene", "pass_type"], how="left") step_aggregate_df = aggregate_stepwise(stepwise_df) scene_summary_df = aggregate_scene_summary(sample_summary_for_analysis) direction_summary = compute_direction_summary( stepwise_df=stepwise_df, sample_summary_df=sample_summary_for_analysis, round_summary_df=round_summary_df, mid_step_start=args.mid_step_start, mid_step_end=args.mid_step_end, stability_threshold=args.stability_threshold, stability_window=args.stability_window, ) step_aggregate_df.to_csv(output_dir / "step_aggregate.csv", index=False) scene_summary_df.to_csv(output_dir / "scene_summary.csv", index=False) if stability_summary_df is not None: stability_summary_df.to_csv(output_dir / "stability_summary.csv", index=False) save_json(output_dir / "direction_summary.json", direction_summary) write_markdown_report(output_dir / "analysis_report.md", direction_summary) if plt is not None: save_convergence_plot(stepwise_df, output_dir / "convergence_curves.png") save_budget_distribution_plot(sample_summary_df, output_dir / "budget_distribution.png") save_midstep_delta_plot(stepwise_df, output_dir / "midstep_latent_delta.png", args.mid_step_start, args.mid_step_end) save_round_reuse_plot(round_summary_df, output_dir / "round_reuse.png") else: print("matplotlib not installed; skipped PNG plots.") print(f"Analysis written to: {output_dir}") if __name__ == "__main__": main()