早停特征验证,早停不通
This commit is contained in:
611
scripts/evaluation/analyze_metrics.py
Normal file
611
scripts/evaluation/analyze_metrics.py
Normal file
@@ -0,0 +1,611 @@
|
||||
import argparse
|
||||
import json
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
|
||||
try:
|
||||
import matplotlib.pyplot as plt
|
||||
except ModuleNotFoundError:
|
||||
plt = None
|
||||
|
||||
|
||||
def parse_args() -> argparse.Namespace:
    """Define the command-line interface and parse the process arguments.

    Returns:
        The namespace produced by ``argparse`` with the attributes
        ``input_dir``, ``output_dir``, ``mid_step_start``, ``mid_step_end``,
        ``full50_step_cap``, ``stability_threshold`` and ``stability_window``.
    """
    cli = argparse.ArgumentParser(
        description="Analyze DDIM convergence logs for the five research directions."
    )

    # (flag, keyword arguments) pairs. They are registered in list order so
    # the generated ``--help`` text lists the options in this sequence.
    option_specs = [
        (
            "--input_dir",
            {
                "type": str,
                "required": True,
                "help": "Directory containing stepwise_log.csv, sample_summary.csv, and round_summary.csv.",
            },
        ),
        (
            "--output_dir",
            {
                "type": str,
                "default": None,
                "help": "Directory to write analysis outputs. Defaults to <input_dir>/analysis.",
            },
        ),
        (
            "--mid_step_start",
            {
                "type": int,
                "default": 15,
                "help": "Start of the step window used for cross-step latent-delta analysis.",
            },
        ),
        (
            "--mid_step_end",
            {
                "type": int,
                "default": 35,
                "help": "End of the step window used for cross-step latent-delta analysis.",
            },
        ),
        (
            "--full50_step_cap",
            {
                "type": int,
                "default": 50,
                "help": "Upper bound used when summarizing step-based metrics.",
            },
        ),
        (
            "--stability_threshold",
            {
                "type": float,
                "default": None,
                "help": "Optional threshold for offline first_stable_step analysis. If unset, stable-step metrics are not computed.",
            },
        ),
        (
            "--stability_window",
            {
                "type": int,
                "default": 3,
                "help": "Consecutive-step window for offline first_stable_step analysis.",
            },
        ),
    ]
    for flag, options in option_specs:
        cli.add_argument(flag, **options)

    return cli.parse_args()
|
||||
|
||||
|
||||
def safe_float(value) -> float:
    """Convert *value* to ``float``, mapping pandas-missing values to NaN.

    Anything ``pd.isna`` reports as missing (``None``, NaN, ``pd.NA``, ...)
    yields ``float("nan")``; every other value goes through ``float()``.
    """
    return float("nan") if pd.isna(value) else float(value)
|
||||
|
||||
|
||||
def describe_series(series: pd.Series) -> dict[str, float]:
    """Summarize a series with count, moments and selected percentiles.

    Values are coerced to numeric first; entries that cannot be parsed, and
    missing entries, are discarded before any statistic is computed.

    Returns:
        A dict with the keys ``count``, ``mean``, ``std`` (population
        standard deviation, ``ddof=0``), ``min``, ``p25``, ``median``,
        ``p75``, ``p90``, ``p95`` and ``max``, in that order. When nothing
        numeric remains, ``count`` is ``0.0`` and every other entry is NaN.
    """
    numeric = pd.to_numeric(series, errors="coerce").dropna()

    if numeric.empty:
        remaining_keys = (
            "mean", "std", "min", "p25", "median", "p75", "p90", "p95", "max",
        )
        return {"count": 0.0, **dict.fromkeys(remaining_keys, float("nan"))}

    def percentile(fraction: float) -> float:
        return float(numeric.quantile(fraction))

    summary = {"count": float(numeric.count())}
    summary["mean"] = float(numeric.mean())
    summary["std"] = float(numeric.std(ddof=0))
    summary["min"] = float(numeric.min())
    summary["p25"] = percentile(0.25)
    summary["median"] = float(numeric.median())
    summary["p75"] = percentile(0.75)
    summary["p90"] = percentile(0.90)
    summary["p95"] = percentile(0.95)
    summary["max"] = float(numeric.max())
    return summary
|
||||
|
||||
|
||||
def ensure_columns(frame: pd.DataFrame, required: list[str], frame_name: str) -> None:
    """Validate that *frame* has every column named in *required*.

    Args:
        frame: Table to inspect.
        required: Column names that must be present.
        frame_name: Label used in the error message to identify the table.

    Raises:
        ValueError: If at least one required column is absent. The message
            lists the absent names in the order given by *required*.
    """
    absent = []
    for name in required:
        if name in frame.columns:
            continue
        absent.append(name)

    if not absent:
        return
    raise ValueError(f"{frame_name} is missing required columns: {absent}")
|
||||
|
||||
|
||||
def first_consecutive_below(values: pd.Series, threshold: float,
                            window: int) -> float:
    """Locate the first run of *window* consecutive values below *threshold*.

    Values are coerced to numeric; anything unparseable or missing breaks a
    run.

    Returns:
        The 1-based position, within *values*, of the first element of the
        earliest qualifying run, as a float. NaN when *window* is not
        positive, when there are fewer than *window* values, or when no run
        qualifies. Note that this is a position in the given ordering, not a
        value taken from any "step" column.
    """
    numbers = pd.to_numeric(values, errors="coerce").tolist()
    if window <= 0 or len(numbers) < window:
        return float("nan")

    run_length = 0
    for position, number in enumerate(numbers):
        if pd.notna(number) and number < threshold:
            run_length += 1
        else:
            run_length = 0
        if run_length >= window:
            # The run ends at `position`; report where it started (1-based).
            return float(position - window + 2)
    return float("nan")
|
||||
|
||||
|
||||
def load_tables(input_dir: Path) -> tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame]:
    """Read the three log CSVs from *input_dir* and validate their columns.

    All three files are checked for existence before any of them is read.

    Returns:
        ``(stepwise_df, sample_summary_df, round_summary_df)``.

    Raises:
        FileNotFoundError: If one of the CSV files does not exist.
        ValueError: If a file lacks one of its required columns (raised by
            ``ensure_columns``).
    """
    # File name -> columns that file must provide. Insertion order drives
    # the order of the existence checks, the reads and the validation.
    required_by_file = {
        "stepwise_log.csv": [
            "sample_id",
            "scene",
            "pass_type",
            "round_id",
            "step",
            "latent_delta",
            "action_delta",
            "state_delta",
            "action_cosine_vs_full50",
            "state_cosine_vs_full50",
            "latent_l2_vs_full50",
        ],
        "sample_summary.csv": [
            "sample_id",
            "scene",
            "pass_type",
            "pass_total_time_s",
            "action_first_stable_step",
            "state_first_stable_step",
            "latent_first_stable_step",
            "action_vs_full50_90pct_step",
            "action_vs_full50_95pct_step",
            "oracle_budget_action",
            "oracle_budget_state",
        ],
        "round_summary.csv": [
            "sample_id",
            "scene",
            "round_id",
            "policy_pass_total_time_s",
            "world_model_pass_total_time_s",
            "round_total_time_s",
            "latent_init_dist_to_prev_round",
            "action_drift_vs_prev_round",
        ],
    }

    csv_paths = {file_name: input_dir / file_name for file_name in required_by_file}
    for csv_path in csv_paths.values():
        if not csv_path.exists():
            raise FileNotFoundError(f"Missing file: {csv_path}")

    frames = {
        file_name: pd.read_csv(csv_path)
        for file_name, csv_path in csv_paths.items()
    }
    for file_name, columns in required_by_file.items():
        ensure_columns(frames[file_name], columns, file_name)

    return (
        frames["stepwise_log.csv"],
        frames["sample_summary.csv"],
        frames["round_summary.csv"],
    )
|
||||
|
||||
|
||||
def aggregate_stepwise(stepwise_df: pd.DataFrame) -> pd.DataFrame:
    """Aggregate per-step metrics over all samples for each pass type.

    Rows are grouped by ``(pass_type, step)`` and the mean, median and
    standard deviation of every available metric column are computed. Metric
    values are coerced to numeric first; unparseable entries become NaN and
    are ignored by the aggregations.

    ``step_time_s`` is aggregated when present, but it is not one of the
    columns ``load_tables`` requires in stepwise_log.csv, so a log without it
    is accepted and that metric is simply left out of the result.

    Args:
        stepwise_df: Step-level log with ``pass_type``, ``step`` and the
            metric columns. The input frame is not modified.

    Returns:
        A frame with one row per ``(pass_type, step)`` and two-level columns
        ``(metric, statistic)``; the group keys appear as ordinary columns.
    """
    metric_columns = [
        "step_time_s",
        "latent_delta",
        "action_delta",
        "state_delta",
        "action_cosine_vs_full50",
        "state_cosine_vs_full50",
        "latent_l2_vs_full50",
    ]
    # Restrict to the metrics this log actually provides; indexing a missing
    # column would raise KeyError.
    available = [
        column for column in metric_columns if column in stepwise_df.columns
    ]

    stepwise_df = stepwise_df.copy()
    for column in available:
        stepwise_df[column] = pd.to_numeric(stepwise_df[column], errors="coerce")

    grouped = stepwise_df.groupby(["pass_type", "step"])[available]
    return grouped.agg(["mean", "median", "std"]).reset_index()
|
||||
|
||||
|
||||
def aggregate_scene_summary(sample_summary_df: pd.DataFrame) -> pd.DataFrame:
    """Aggregate sample-level metrics per scene and pass type.

    Rows are grouped by ``(scene, pass_type)`` and the mean, median and
    standard deviation of every available metric column are computed. Metric
    values are coerced to numeric first; unparseable entries become NaN.

    The metric list also names timing and round-reuse columns that
    ``load_tables`` only requires in round_summary.csv. They are aggregated
    when the sample summary happens to carry them and skipped otherwise, so a
    sample summary that only satisfies its own required schema no longer
    fails with ``KeyError``.

    Args:
        sample_summary_df: Sample-level summary with ``scene``, ``pass_type``
            and the metric columns. The input frame is not modified.

    Returns:
        A frame with one row per ``(scene, pass_type)`` and two-level columns
        ``(metric, statistic)``; the group keys appear as ordinary columns.
    """
    metrics = [
        "pass_total_time_s",
        "action_first_stable_step",
        "state_first_stable_step",
        "latent_first_stable_step",
        "action_vs_full50_90pct_step",
        "action_vs_full50_95pct_step",
        "oracle_budget_action",
        "oracle_budget_state",
        "round_total_time_s",
        "policy_pass_total_time_s",
        "world_model_pass_total_time_s",
        "latent_init_dist_to_prev_round",
        "action_drift_vs_prev_round",
    ]
    # Keep the declared order, but only for columns this table provides.
    available = [
        column for column in metrics if column in sample_summary_df.columns
    ]

    sample_summary_df = sample_summary_df.copy()
    for column in available:
        sample_summary_df[column] = pd.to_numeric(sample_summary_df[column],
                                                  errors="coerce")

    grouped = sample_summary_df.groupby(["scene", "pass_type"])[available]
    return grouped.agg(["mean", "median", "std"]).reset_index()
|
||||
|
||||
|
||||
def compute_stability_summary(stepwise_df: pd.DataFrame, threshold: float,
                              window: int) -> pd.DataFrame:
    """Recompute first-stable-step metrics offline from the step-level log.

    For every ``(sample_id, scene, pass_type)`` group, ordered by ``step``,
    ``first_consecutive_below`` is applied to the ``action_delta``,
    ``state_delta`` and ``latent_delta`` columns with the given *threshold*
    and *window*.

    Args:
        stepwise_df: Step-level log.
        threshold: Values strictly below this count as stable.
        window: Number of consecutive stable values required.

    Returns:
        A frame with the columns ``sample_id``, ``scene``, ``pass_type``,
        ``action_first_stable_step``, ``state_first_stable_step`` and
        ``latent_first_stable_step``. The columns are present even when there
        are no groups, so the result can always be merged on the three key
        columns.
    """
    output_columns = [
        "sample_id",
        "scene",
        "pass_type",
        "action_first_stable_step",
        "state_first_stable_step",
        "latent_first_stable_step",
    ]
    rows = []
    # A stable sort keeps rows that share a step value in their original file
    # order, which makes the per-group sequences deterministic.
    # NOTE(review): groups are not split by round_id, so if a sample logs
    # several rounds their steps are interleaved here - confirm intended.
    ordered = stepwise_df.sort_values("step", kind="mergesort")
    grouped = ordered.groupby(["sample_id", "scene", "pass_type"])
    for (sample_id, scene, pass_type), group in grouped:
        rows.append({
            "sample_id": sample_id,
            "scene": scene,
            "pass_type": pass_type,
            "action_first_stable_step": first_consecutive_below(
                group["action_delta"], threshold, window),
            "state_first_stable_step": first_consecutive_below(
                group["state_delta"], threshold, window),
            "latent_first_stable_step": first_consecutive_below(
                group["latent_delta"], threshold, window),
        })
    # Passing the column list guarantees the schema for an empty `rows`.
    return pd.DataFrame(rows, columns=output_columns)
|
||||
|
||||
|
||||
def compute_direction_summary(stepwise_df: pd.DataFrame,
                              sample_summary_df: pd.DataFrame,
                              round_summary_df: pd.DataFrame,
                              mid_step_start: int,
                              mid_step_end: int,
                              stability_threshold: float | None,
                              stability_window: int) -> dict:
    """Assemble the nested summary used by the JSON and Markdown reports.

    Args:
        stepwise_df: Step-level log.
        sample_summary_df: Sample-level summary; rows with ``pass_type``
            ``"policy"`` and ``"world_model"`` are summarized separately.
        round_summary_df: Round-level summary.
        mid_step_start: First step (inclusive) of the mid-step window.
        mid_step_end: Last step (inclusive) of the mid-step window.
        stability_threshold: Recorded in the config section as given.
        stability_window: Recorded in the config section as an int.

    Returns:
        A dict with the sections ``analysis_config``, ``dataset_overview``,
        ``direction_original_early_stop``,
        ``direction_c_action_converges_first``,
        ``direction_d_cross_step_similarity``,
        ``direction_a_budget_heterogeneity`` and ``direction_b_round_reuse``,
        in that order. Distribution entries come from ``describe_series``.
    """
    pass_type = sample_summary_df["pass_type"]
    policy_rows = sample_summary_df[pass_type == "policy"].copy()
    world_rows = sample_summary_df[pass_type == "world_model"].copy()

    in_mid_window = ((stepwise_df["step"] >= mid_step_start)
                     & (stepwise_df["step"] <= mid_step_end))
    mid_rows = stepwise_df[in_mid_window].copy()

    # Both shares are computed over the same policy rows: those where all
    # three step columns are present.
    complete_rows = policy_rows[[
        "action_first_stable_step",
        "latent_first_stable_step",
        "action_vs_full50_95pct_step",
    ]].dropna()
    share_first_before_latent = float("nan")
    share_95_before_latent = float("nan")
    if not complete_rows.empty:
        latent_step = complete_rows["latent_first_stable_step"]
        share_first_before_latent = float(
            (complete_rows["action_first_stable_step"] < latent_step).mean())
        share_95_before_latent = float(
            (complete_rows["action_vs_full50_95pct_step"] < latent_step).mean())

    analysis_config = {
        "stability_threshold": stability_threshold,
        "stability_window": int(stability_window),
        "mid_step_window": [int(mid_step_start), int(mid_step_end)],
    }

    dataset_overview = {
        "num_step_rows": int(len(stepwise_df)),
        "num_sample_rows": int(len(sample_summary_df)),
        "num_round_rows": int(len(round_summary_df)),
        "num_unique_samples": int(sample_summary_df["sample_id"].nunique()),
        "scenes": sorted(sample_summary_df["scene"].dropna().unique().tolist()),
        "pass_types": sorted(pass_type.dropna().unique().tolist()),
    }

    original_early_stop = {
        "latent_first_stable_step_policy":
            describe_series(policy_rows["latent_first_stable_step"]),
        "latent_first_stable_step_world_model":
            describe_series(world_rows["latent_first_stable_step"]),
        "latent_l2_vs_full50":
            describe_series(stepwise_df["latent_l2_vs_full50"]),
    }

    action_converges_first = {
        column: describe_series(policy_rows[column])
        for column in (
            "action_first_stable_step",
            "latent_first_stable_step",
            "action_vs_full50_95pct_step",
        )
    }
    action_converges_first["share_action_first_stable_before_latent"] = (
        share_first_before_latent)
    action_converges_first["share_action_95pct_before_latent"] = (
        share_95_before_latent)

    cross_step_similarity = {
        f"{column}_mid_steps": describe_series(mid_rows[column])
        for column in ("latent_delta", "action_delta", "state_delta")
    }

    budget_by_scene = {}
    for scene, scene_rows in policy_rows.groupby("scene"):
        budget_by_scene[scene] = describe_series(
            scene_rows["oracle_budget_action"])
    budget_heterogeneity = {
        "oracle_budget_action":
            describe_series(policy_rows["oracle_budget_action"]),
        "oracle_budget_state":
            describe_series(world_rows["oracle_budget_state"]),
        "oracle_budget_action_by_scene": budget_by_scene,
    }

    round_reuse = {
        column: describe_series(round_summary_df[column])
        for column in (
            "latent_init_dist_to_prev_round",
            "action_drift_vs_prev_round",
            "round_total_time_s",
            "policy_pass_total_time_s",
            "world_model_pass_total_time_s",
        )
    }

    return {
        "analysis_config": analysis_config,
        "dataset_overview": dataset_overview,
        "direction_original_early_stop": original_early_stop,
        "direction_c_action_converges_first": action_converges_first,
        "direction_d_cross_step_similarity": cross_step_similarity,
        "direction_a_budget_heterogeneity": budget_heterogeneity,
        "direction_b_round_reuse": round_reuse,
    }
|
||||
|
||||
|
||||
def save_json(path: Path, payload: dict) -> None:
    """Serialize *payload* as JSON to *path*.

    The file is written as UTF-8 with a two-space indent, and non-ASCII
    characters are kept as-is rather than escaped.
    """
    with open(path, mode="w", encoding="utf-8") as handle:
        json.dump(payload, handle, ensure_ascii=False, indent=2)
|
||||
|
||||
|
||||
def write_markdown_report(path: Path, direction_summary: dict) -> None:
    """Render the direction summary as a Markdown report at *path*.

    The report has an "Analysis Config" section, a "Dataset Overview" section
    and one subsection per research direction.

    Args:
        path: Destination file; written as UTF-8 and overwritten if present.
        direction_summary: Mapping with the keys ``analysis_config``,
            ``dataset_overview`` and the five ``direction_*`` sections.

    Raises:
        KeyError: If one of the expected sections or config/overview entries
            is missing from *direction_summary*.
    """
    config = direction_summary["analysis_config"]
    overview = direction_summary["dataset_overview"]

    # Labels read from CSV may be numbers rather than strings, so convert
    # each one before joining.
    scene_list = ", ".join(str(scene) for scene in overview["scenes"])
    pass_type_list = ", ".join(str(name) for name in overview["pass_types"])

    lines = [
        "# DDIM Analysis Report",
        "",
        "## Analysis Config",
        "",
        f"- Stability threshold: {config['stability_threshold']}",
        f"- Stability window: {config['stability_window']}",
        f"- Mid-step window: {config['mid_step_window'][0]}-{config['mid_step_window'][1]}",
        "",
        "## Dataset Overview",
        "",
        f"- Unique samples: {overview['num_unique_samples']}",
        f"- Step rows: {overview['num_step_rows']}",
        f"- Sample rows: {overview['num_sample_rows']}",
        f"- Round rows: {overview['num_round_rows']}",
        f"- Scenes: {scene_list}",
        f"- Pass types: {pass_type_list}",
        "",
        "## Five Directions",
        "",
    ]

    # Keys each one-line rendering reads; a dict is only rendered in that
    # form when it provides all of them.
    headline_keys = ("mean", "median", "p90", "p95")
    nested_keys = ("mean", "median", "std")

    def append_stats(title: str, payload: dict) -> None:
        """Append one "### title" subsection describing *payload*."""
        lines.append(f"### {title}")
        lines.append("")
        for key, value in payload.items():
            if not isinstance(value, dict):
                if isinstance(value, float):
                    lines.append(f"- `{key}`: {value:.4f}")
                else:
                    lines.append(f"- `{key}`: {value}")
                continue

            if all(name in value for name in headline_keys):
                lines.append(
                    f"- `{key}`: mean={value['mean']:.4f}, median={value['median']:.4f}, p90={value['p90']:.4f}, p95={value['p95']:.4f}"
                )
                continue

            # Any other dict is rendered as a nested bullet list. Sub-items
            # are indented two spaces so Markdown nests them under the key.
            lines.append(f"- `{key}`:")
            for sub_key, sub_value in value.items():
                if isinstance(sub_value, dict) and all(
                        name in sub_value for name in nested_keys):
                    lines.append(
                        f"  - `{sub_key}`: mean={sub_value['mean']:.4f}, median={sub_value['median']:.4f}, std={sub_value['std']:.4f}"
                    )
                else:
                    lines.append(f"  - `{sub_key}`: {sub_value}")
        lines.append("")

    append_stats("Original Early Stop",
                 direction_summary["direction_original_early_stop"])
    append_stats("Direction C: Action Converges First",
                 direction_summary["direction_c_action_converges_first"])
    append_stats("Direction D: Cross-Step Similarity",
                 direction_summary["direction_d_cross_step_similarity"])
    append_stats("Direction A: Budget Heterogeneity",
                 direction_summary["direction_a_budget_heterogeneity"])
    append_stats("Direction B: Round Reuse",
                 direction_summary["direction_b_round_reuse"])

    with open(path, "w", encoding="utf-8") as file:
        file.write("\n".join(lines) + "\n")
|
||||
|
||||
|
||||
def save_convergence_plot(stepwise_df: pd.DataFrame, output_path: Path) -> None:
    """Plot per-step mean convergence curves, one line per pass type.

    Produces a 2x2 figure (latent delta, action cosine vs full50, state
    cosine vs full50, latent L2 vs full50) and saves it to *output_path*.
    Does nothing when matplotlib could not be imported.
    """
    if plt is None:
        return

    fig, axes = plt.subplots(2, 2, figsize=(14, 10))
    # (axis, column to plot, panel title)
    panels = [
        (axes[0, 0], "latent_delta", "Mean Latent Delta"),
        (axes[0, 1], "action_cosine_vs_full50", "Mean Action Cosine vs Full50"),
        (axes[1, 0], "state_cosine_vs_full50", "Mean State Cosine vs Full50"),
        (axes[1, 1], "latent_l2_vs_full50", "Mean Latent L2 vs Full50"),
    ]

    for pass_type in sorted(stepwise_df["pass_type"].dropna().unique().tolist()):
        pass_rows = stepwise_df[stepwise_df["pass_type"] == pass_type]
        step_means = pass_rows.groupby("step").mean(numeric_only=True).reset_index()
        for axis, column, _ in panels:
            axis.plot(step_means["step"], step_means[column], label=pass_type)

    for axis, _, title in panels:
        axis.set_title(title)
    for axis in axes.flatten():
        axis.set_xlabel("Step")
        axis.grid(alpha=0.3)
        axis.legend()

    fig.tight_layout()
    fig.savefig(output_path, dpi=200, bbox_inches="tight")
    plt.close(fig)
|
||||
|
||||
|
||||
def save_budget_distribution_plot(sample_summary_df: pd.DataFrame,
                                  output_path: Path) -> None:
    """Plot the ``oracle_budget_action`` distribution of policy passes by scene.

    The left panel overlays one histogram per scene; the right panel shows
    one box per scene without outlier markers. The figure is saved to
    *output_path*. Does nothing when matplotlib could not be imported.
    """
    if plt is None:
        return

    policy_rows = sample_summary_df[
        sample_summary_df["pass_type"] == "policy"].copy()
    scenes = sorted(policy_rows["scene"].dropna().unique().tolist())
    fig, (hist_axis, box_axis) = plt.subplots(1, 2, figsize=(14, 5))

    def scene_budgets(scene):
        # Non-missing budgets of the policy rows belonging to one scene.
        scene_rows = policy_rows[policy_rows["scene"] == scene]
        return scene_rows["oracle_budget_action"].dropna()

    for scene in scenes:
        hist_axis.hist(scene_budgets(scene), bins=15, alpha=0.5, label=scene)
    hist_axis.set_title("Oracle Budget Action Distribution")
    hist_axis.set_xlabel("Step")
    hist_axis.set_ylabel("Count")
    hist_axis.legend()
    hist_axis.grid(alpha=0.3)

    per_scene_budgets = [scene_budgets(scene) for scene in scenes]
    has_data = any(len(budgets) > 0 for budgets in per_scene_budgets)
    if scenes and has_data:
        box_axis.boxplot(per_scene_budgets, labels=scenes, showfliers=False)
    box_axis.set_title("Oracle Budget Action by Scene")
    box_axis.set_ylabel("Step")
    box_axis.tick_params(axis="x", rotation=20)
    box_axis.grid(alpha=0.3)

    fig.tight_layout()
    fig.savefig(output_path, dpi=200, bbox_inches="tight")
    plt.close(fig)
|
||||
|
||||
|
||||
def save_midstep_delta_plot(stepwise_df: pd.DataFrame, output_path: Path,
                            mid_step_start: int, mid_step_end: int) -> None:
    """Box-plot ``latent_delta`` for each step inside the mid-step window.

    Only rows with ``mid_step_start <= step <= mid_step_end`` are used; one
    box is drawn per distinct step, without outlier markers. The figure is
    saved to *output_path*. Does nothing when matplotlib could not be
    imported.
    """
    if plt is None:
        return

    within_window = ((stepwise_df["step"] >= mid_step_start)
                     & (stepwise_df["step"] <= mid_step_end))
    window_rows = stepwise_df[within_window].copy()
    steps = sorted(window_rows["step"].dropna().unique().tolist())
    deltas_per_step = [
        window_rows[window_rows["step"] == step]["latent_delta"].dropna()
        for step in steps
    ]

    fig, ax = plt.subplots(figsize=(14, 5))
    has_data = any(len(deltas) > 0 for deltas in deltas_per_step)
    if steps and has_data:
        ax.boxplot(deltas_per_step, labels=steps, showfliers=False)
    ax.set_title(
        f"Latent Delta Distribution for Steps {mid_step_start}-{mid_step_end}")
    ax.set_xlabel("Step")
    ax.set_ylabel("latent_delta")
    ax.grid(alpha=0.3)

    fig.tight_layout()
    fig.savefig(output_path, dpi=200, bbox_inches="tight")
    plt.close(fig)
|
||||
|
||||
|
||||
def save_round_reuse_plot(round_summary_df: pd.DataFrame, output_path: Path) -> None:
    """Plot round-reuse diagnostics from the round-level summary.

    The left panel scatters ``latent_init_dist_to_prev_round`` against
    ``action_drift_vs_prev_round`` per scene, using rows where both are
    present. The right panel box-plots ``round_total_time_s`` per scene,
    skipping scenes with no timing values. The figure is saved to
    *output_path*. Does nothing when matplotlib could not be imported.
    """
    if plt is None:
        return

    fig, (scatter_axis, time_axis) = plt.subplots(1, 2, figsize=(14, 5))

    paired_rows = round_summary_df.dropna(
        subset=["latent_init_dist_to_prev_round", "action_drift_vs_prev_round"])
    if not paired_rows.empty:
        for scene, scene_rows in paired_rows.groupby("scene"):
            scatter_axis.scatter(scene_rows["latent_init_dist_to_prev_round"],
                                 scene_rows["action_drift_vs_prev_round"],
                                 alpha=0.6,
                                 label=scene)
        scatter_axis.legend()
    scatter_axis.set_title("Round Reuse: Latent Init Dist vs Action Drift")
    scatter_axis.set_xlabel("latent_init_dist_to_prev_round")
    scatter_axis.set_ylabel("action_drift_vs_prev_round")
    scatter_axis.grid(alpha=0.3)

    # Collect (label, values) pairs, keeping only scenes that have timings.
    timed_scenes = []
    for scene, scene_rows in round_summary_df.groupby("scene"):
        round_times = scene_rows["round_total_time_s"].dropna()
        if len(round_times) > 0:
            timed_scenes.append((scene, round_times))
    if timed_scenes:
        time_axis.boxplot([times for _, times in timed_scenes],
                          labels=[scene for scene, _ in timed_scenes],
                          showfliers=False)
    time_axis.set_title("Round Total Time by Scene")
    time_axis.set_ylabel("Seconds")
    time_axis.tick_params(axis="x", rotation=20)
    time_axis.grid(alpha=0.3)

    fig.tight_layout()
    fig.savefig(output_path, dpi=200, bbox_inches="tight")
    plt.close(fig)
|
||||
|
||||
|
||||
def main() -> None:
    """Run the full analysis: load the logs, summarize, write all outputs.

    Outputs go to ``--output_dir`` or, when that is not given, to
    ``<input_dir>/analysis``: ``step_aggregate.csv``, ``scene_summary.csv``,
    ``direction_summary.json``, ``analysis_report.md``, optionally
    ``stability_summary.csv`` (only when ``--stability_threshold`` is set)
    and four PNG plots (only when matplotlib is available).
    """
    args = parse_args()

    input_dir = Path(args.input_dir)
    if args.output_dir:
        output_dir = Path(args.output_dir)
    else:
        output_dir = input_dir / "analysis"
    output_dir.mkdir(parents=True, exist_ok=True)

    stepwise_df, sample_summary_df, round_summary_df = load_tables(input_dir)

    stability_summary_df = None
    sample_summary_for_analysis = sample_summary_df.copy()
    if args.stability_threshold is not None:
        # Replace the logged first-stable-step columns with values recomputed
        # offline from the step-level log.
        stability_summary_df = compute_stability_summary(
            stepwise_df, args.stability_threshold, args.stability_window)
        logged_stable_columns = [
            "action_first_stable_step",
            "state_first_stable_step",
            "latent_first_stable_step",
        ]
        without_logged = sample_summary_for_analysis.drop(
            columns=logged_stable_columns, errors="ignore")
        sample_summary_for_analysis = without_logged.merge(
            stability_summary_df,
            on=["sample_id", "scene", "pass_type"],
            how="left")

    step_aggregate_df = aggregate_stepwise(stepwise_df)
    scene_summary_df = aggregate_scene_summary(sample_summary_for_analysis)
    direction_summary = compute_direction_summary(
        stepwise_df=stepwise_df,
        sample_summary_df=sample_summary_for_analysis,
        round_summary_df=round_summary_df,
        mid_step_start=args.mid_step_start,
        mid_step_end=args.mid_step_end,
        stability_threshold=args.stability_threshold,
        stability_window=args.stability_window,
    )

    step_aggregate_df.to_csv(output_dir / "step_aggregate.csv", index=False)
    scene_summary_df.to_csv(output_dir / "scene_summary.csv", index=False)
    if stability_summary_df is not None:
        stability_summary_df.to_csv(output_dir / "stability_summary.csv",
                                    index=False)
    save_json(output_dir / "direction_summary.json", direction_summary)
    write_markdown_report(output_dir / "analysis_report.md", direction_summary)

    if plt is None:
        print("matplotlib not installed; skipped PNG plots.")
    else:
        save_convergence_plot(stepwise_df, output_dir / "convergence_curves.png")
        # The budget plot reads the sample summary as loaded, not the copy
        # with recomputed stability columns.
        save_budget_distribution_plot(sample_summary_df,
                                      output_dir / "budget_distribution.png")
        save_midstep_delta_plot(stepwise_df,
                                output_dir / "midstep_latent_delta.png",
                                args.mid_step_start, args.mid_step_end)
        save_round_reuse_plot(round_summary_df, output_dir / "round_reuse.png")

    print(f"Analysis written to: {output_dir}")
|
||||
|
||||
|
||||
# Run the analysis only when this file is executed as a script, not when it
# is imported as a module.
if __name__ == "__main__":
    main()
|
||||
Reference in New Issue
Block a user