video_backbone剖析

This commit is contained in:
qhy
2026-03-16 10:30:54 +08:00
parent 7e45eba18b
commit 8ca159d375
282 changed files with 174952 additions and 1350 deletions

View File

@@ -0,0 +1,148 @@
import argparse
from pathlib import Path
import numpy as np
import pandas as pd
MID_STEP_MIN = 15
MID_STEP_MAX = 35
TAIL_STEP_MIN = 40
TAIL_STEP_MAX = 49
def describe(series: pd.Series) -> dict[str, float]:
numeric = pd.to_numeric(series, errors='coerce').dropna()
if numeric.empty:
return {'count': 0.0, 'mean': np.nan, 'median': np.nan, 'p90': np.nan}
return {
'count': float(numeric.count()),
'mean': float(numeric.mean()),
'median': float(numeric.median()),
'p90': float(numeric.quantile(0.90)),
}
def make_parser() -> argparse.ArgumentParser:
parser = argparse.ArgumentParser(
description="Aggregate backbone block profiling results.")
parser.add_argument("--input_csv",
type=str,
required=True,
help="Path to backbone_block_log.csv")
parser.add_argument("--output_dir",
type=str,
default=None,
help="Directory to store summaries; defaults to input parent.")
return parser
def main() -> None:
args = make_parser().parse_args()
input_path = Path(args.input_csv)
if not input_path.exists():
raise FileNotFoundError(f"Missing file: {input_path}")
output_dir = Path(args.output_dir) if args.output_dir else input_path.parent / "analysis"
output_dir.mkdir(parents=True, exist_ok=True)
df = pd.read_csv(input_path)
numeric_cols = [
'forward_time_ms',
'l2_delta_vs_prev',
'rel_l2_delta_vs_prev',
'cosine_vs_prev',
'l2_delta_vs_full50',
'cosine_vs_full50',
]
for col in numeric_cols:
df[col] = pd.to_numeric(df[col], errors='coerce')
step_aggregate = df.groupby(
['block_stage', 'block_name', 'block_index', 'step'])[numeric_cols].mean().reset_index()
step_aggregate.to_csv(output_dir / 'block_step_aggregate.csv', index=False)
block_summary = df.groupby(['block_stage', 'block_name', 'block_index']).agg(
mean_forward_time_ms=('forward_time_ms', 'mean'),
mean_rel_l2_delta_vs_prev=('rel_l2_delta_vs_prev', 'mean'),
mean_cosine_vs_prev=('cosine_vs_prev', 'mean'),
mean_l2_delta_vs_full50=('l2_delta_vs_full50', 'mean'),
mean_cosine_vs_full50=('cosine_vs_full50', 'mean'),
).reset_index()
mid_df = df[(df['step'] >= MID_STEP_MIN) & (df['step'] <= MID_STEP_MAX)]
tail_df = df[(df['step'] >= TAIL_STEP_MIN) & (df['step'] <= TAIL_STEP_MAX)]
block_summary = block_summary.merge(
mid_df.groupby(['block_stage', 'block_name', 'block_index']).agg(
mid_mean_rel_l2_delta_vs_prev=('rel_l2_delta_vs_prev', 'mean'),
mid_mean_cosine_vs_prev=('cosine_vs_prev', 'mean'),
mid_mean_forward_time_ms=('forward_time_ms', 'mean'),
).reset_index(),
on=['block_stage', 'block_name', 'block_index'],
how='left',
)
block_summary = block_summary.merge(
tail_df.groupby(['block_stage', 'block_name', 'block_index']).agg(
tail_mean_rel_l2_delta_vs_prev=('rel_l2_delta_vs_prev', 'mean'),
tail_mean_cosine_vs_prev=('cosine_vs_prev', 'mean'),
tail_mean_forward_time_ms=('forward_time_ms', 'mean'),
).reset_index(),
on=['block_stage', 'block_name', 'block_index'],
how='left',
)
block_summary = block_summary.sort_values(
['tail_mean_rel_l2_delta_vs_prev', 'mean_forward_time_ms'],
ascending=[True, False])
block_summary.to_csv(output_dir / 'block_summary.csv', index=False)
stage_summary = []
for stage_name, stage_df in [('all', df), ('mid', mid_df), ('tail', tail_df)]:
grouped = stage_df.groupby('block_stage').agg(
mean_forward_time_ms=('forward_time_ms', 'mean'),
mean_rel_l2_delta_vs_prev=('rel_l2_delta_vs_prev', 'mean'),
mean_cosine_vs_prev=('cosine_vs_prev', 'mean'),
mean_l2_delta_vs_full50=('l2_delta_vs_full50', 'mean'),
mean_cosine_vs_full50=('cosine_vs_full50', 'mean'),
).reset_index()
grouped.insert(0, 'window', stage_name)
stage_summary.append(grouped)
stage_summary_df = pd.concat(stage_summary, ignore_index=True)
stage_summary_df.to_csv(output_dir / 'stage_summary.csv', index=False)
best_cache_candidates = block_summary.sort_values(
['tail_mean_rel_l2_delta_vs_prev', 'mean_forward_time_ms'],
ascending=[True, False]).head(10)
lines = [
"# Backbone Block Profiling Report",
"",
"## Dataset Overview",
"",
f"- Rows: {len(df)}",
f"- Blocks: {df['block_name'].nunique()}",
f"- Steps: {int(df['step'].min())}-{int(df['step'].max())}",
f"- Pass types: {', '.join(sorted(df['pass_type'].dropna().unique()))}",
"",
"## Block Timing",
"",
f"- `forward_time_ms`: mean={describe(df['forward_time_ms'])['mean']:.4f}, median={describe(df['forward_time_ms'])['median']:.4f}, p90={describe(df['forward_time_ms'])['p90']:.4f}",
"",
"## Stability",
"",
f"- `rel_l2_delta_vs_prev`: mean={describe(df['rel_l2_delta_vs_prev'])['mean']:.6f}, median={describe(df['rel_l2_delta_vs_prev'])['median']:.6f}, p90={describe(df['rel_l2_delta_vs_prev'])['p90']:.6f}",
f"- `cosine_vs_prev`: mean={describe(df['cosine_vs_prev'])['mean']:.6f}, median={describe(df['cosine_vs_prev'])['median']:.6f}, p90={describe(df['cosine_vs_prev'])['p90']:.6f}",
"",
"## Top Cache Candidates",
"",
]
for _, row in best_cache_candidates.iterrows():
lines.append(
f"- `{row['block_name']}` ({row['block_stage']}): tail_rel_l2={row['tail_mean_rel_l2_delta_vs_prev']:.6f}, mean_forward_time_ms={row['mean_forward_time_ms']:.4f}"
)
(output_dir / 'backbone_profile_report.md').write_text("\n".join(lines),
encoding='utf-8')
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,380 @@
import argparse
import json
from pathlib import Path
import numpy as np
import pandas as pd
DEFAULT_SCHEMES = {
"dense": list(range(50)),
"sparse_10": list(range(0, 50, 5)),
"sparse_5": list(range(0, 50, 10)),
"tail_heavy_10": [0, 5, 10, 20, 30, 38, 43, 46, 48, 49],
"tail_only_6": [40, 43, 46, 47, 48, 49],
"sparse_8": [0, 7, 14, 21, 28, 35, 42, 49],
"sparse_4": [0, 16, 32, 49],
"tail_only_4": [40, 43, 46, 49],
"tail_heavy_6": [0, 32, 40, 44, 47, 49],
"tail_heavy_8": [0, 10, 20, 30, 38, 43, 46, 49],
}
def parse_args() -> argparse.Namespace:
parser = argparse.ArgumentParser(
description="Offline sparse head execution simulation from existing DDIM logs."
)
parser.add_argument(
"--root_dir",
type=str,
default=".",
help="Repository root or directory under which case outputs are stored.",
)
parser.add_argument(
"--output_dir",
type=str,
default=None,
help="Directory to write sparse-head simulation outputs. Defaults to <root_dir>/sparse_head_simulation.",
)
parser.add_argument(
"--target_threshold",
type=float,
default=0.95,
help="Target cosine threshold used by experiment 1.2.",
)
return parser.parse_args()
def discover_stepwise_logs(root_dir: Path) -> list[Path]:
return sorted(root_dir.glob("unitree_*/case*/output/inference/stepwise_log.csv"))
def load_stepwise_table(stepwise_paths: list[Path]) -> pd.DataFrame:
frames = []
for path in stepwise_paths:
frame = pd.read_csv(path)
frame["dataset"] = path.parts[-5]
frame["case"] = path.parts[-4]
frames.append(frame)
if not frames:
raise FileNotFoundError("No stepwise_log.csv files found.")
stepwise_df = pd.concat(frames, ignore_index=True)
for column in [
"step",
"step_time_s",
"latent_delta",
"action_delta",
"state_delta",
"action_cosine_vs_full50",
"state_cosine_vs_full50",
"latent_l2_vs_full50",
]:
stepwise_df[column] = pd.to_numeric(stepwise_df[column], errors="coerce")
return stepwise_df
def simulate_schemes(stepwise_df: pd.DataFrame,
schemes: dict[str, list[int]]) -> pd.DataFrame:
rows = []
group_columns = [
"dataset",
"case",
"sample_id",
"scene",
"pass_type",
"round_id",
]
grouped = stepwise_df.groupby(group_columns)
for keys, group in grouped:
group = group.sort_values("step").reset_index(drop=True)
action_curve = dict(zip(group["step"] - 1, group["action_cosine_vs_full50"]))
state_curve = dict(zip(group["step"] - 1, group["state_cosine_vs_full50"]))
for scheme_name, checkpoints in schemes.items():
normalized_checkpoints = sorted(
checkpoint for checkpoint in checkpoints if 0 <= checkpoint <= 49)
if not normalized_checkpoints:
continue
last_checkpoint = normalized_checkpoints[-1]
rows.append({
"dataset": keys[0],
"case": keys[1],
"sample_id": keys[2],
"scene": keys[3],
"pass_type": keys[4],
"round_id": keys[5],
"scheme": scheme_name,
"head_exec_steps_zero_based": json.dumps(normalized_checkpoints),
"head_exec_count": len(normalized_checkpoints),
"head_compute_saving_ratio": 1.0 - len(normalized_checkpoints) / 50.0,
"final_checkpoint_zero_based": last_checkpoint,
"final_checkpoint_one_based": last_checkpoint + 1,
"final_action_cosine_vs_dense": action_curve[last_checkpoint],
"final_state_cosine_vs_dense": state_curve[last_checkpoint],
})
return pd.DataFrame(rows)
def summarize_scheme_results(simulation_df: pd.DataFrame) -> tuple[pd.DataFrame, pd.DataFrame]:
metric_columns = [
"final_action_cosine_vs_dense",
"final_state_cosine_vs_dense",
"head_exec_count",
"head_compute_saving_ratio",
]
overall = simulation_df.groupby(["scheme", "pass_type"])[metric_columns].agg(
["mean", "median", "std"]).reset_index()
per_case = simulation_df.groupby(
["dataset", "case", "scheme", "pass_type"])[metric_columns].mean().reset_index()
return overall, per_case
def compute_min_steps_needed(stepwise_df: pd.DataFrame,
threshold: float) -> pd.DataFrame:
rows = []
group_columns = [
"dataset",
"case",
"sample_id",
"scene",
"pass_type",
"round_id",
]
grouped = stepwise_df.groupby(group_columns)
for keys, group in grouped:
group = group.sort_values("step").reset_index(drop=True)
action_hits = group[group["action_cosine_vs_full50"] >= threshold]
state_hits = group[group["state_cosine_vs_full50"] >= threshold]
action_step = np.nan if action_hits.empty else float(action_hits.iloc[0]["step"] - 1)
state_step = np.nan if state_hits.empty else float(state_hits.iloc[0]["step"] - 1)
rows.append({
"dataset": keys[0],
"case": keys[1],
"sample_id": keys[2],
"scene": keys[3],
"pass_type": keys[4],
"round_id": keys[5],
"target_threshold": threshold,
"action_min_head_steps_needed": 1.0 if pd.notna(action_step) else np.nan,
"state_min_head_steps_needed": 1.0 if pd.notna(state_step) else np.nan,
"action_earliest_checkpoint_zero_based": action_step,
"state_earliest_checkpoint_zero_based": state_step,
"action_max_head_compute_saving_ratio": 0.98 if pd.notna(action_step) else np.nan,
"state_max_head_compute_saving_ratio": 0.98 if pd.notna(state_step) else np.nan,
})
return pd.DataFrame(rows)
def compare_tail_vs_uniform(simulation_df: pd.DataFrame) -> pd.DataFrame:
comparison_specs = [
("sparse_8", "tail_heavy_8", "budget_8_uniform_vs_tail"),
("sparse_4", "tail_only_4", "budget_4_uniform_vs_tail"),
("sparse_10", "tail_heavy_10", "budget_10_uniform_vs_tail"),
]
base_columns = [
"dataset",
"case",
"sample_id",
"scene",
"pass_type",
"round_id",
]
rows = []
for uniform_scheme, tail_scheme, comparison_name in comparison_specs:
uniform_df = simulation_df[simulation_df["scheme"] == uniform_scheme].copy()
tail_df = simulation_df[simulation_df["scheme"] == tail_scheme].copy()
merged = uniform_df.merge(
tail_df,
on=base_columns,
suffixes=("_uniform", "_tail"),
how="inner",
)
for _, row in merged.iterrows():
rows.append({
"comparison": comparison_name,
"dataset": row["dataset"],
"case": row["case"],
"sample_id": row["sample_id"],
"scene": row["scene"],
"pass_type": row["pass_type"],
"round_id": row["round_id"],
"uniform_scheme": uniform_scheme,
"tail_scheme": tail_scheme,
"uniform_head_exec_count": row["head_exec_count_uniform"],
"tail_head_exec_count": row["head_exec_count_tail"],
"uniform_action_cosine_vs_dense": row["final_action_cosine_vs_dense_uniform"],
"tail_action_cosine_vs_dense": row["final_action_cosine_vs_dense_tail"],
"uniform_state_cosine_vs_dense": row["final_state_cosine_vs_dense_uniform"],
"tail_state_cosine_vs_dense": row["final_state_cosine_vs_dense_tail"],
"tail_minus_uniform_action_cosine":
row["final_action_cosine_vs_dense_tail"] -
row["final_action_cosine_vs_dense_uniform"],
"tail_minus_uniform_state_cosine":
row["final_state_cosine_vs_dense_tail"] -
row["final_state_cosine_vs_dense_uniform"],
"tail_better_action": row["final_action_cosine_vs_dense_tail"] >
row["final_action_cosine_vs_dense_uniform"],
"tail_better_state": row["final_state_cosine_vs_dense_tail"] >
row["final_state_cosine_vs_dense_uniform"],
})
return pd.DataFrame(rows)
def build_summary_payload(simulation_df: pd.DataFrame,
min_steps_df: pd.DataFrame,
tail_compare_df: pd.DataFrame,
target_threshold: float) -> dict:
payload: dict[str, dict] = {
"config": {
"target_threshold": target_threshold,
"schemes": DEFAULT_SCHEMES,
},
"experiment_1_1": {},
"experiment_1_2": {},
"experiment_1_3": {},
}
for scheme, group in simulation_df.groupby("scheme"):
payload["experiment_1_1"][scheme] = {
"num_rows": int(len(group)),
"head_exec_count": float(group["head_exec_count"].iloc[0]),
"head_compute_saving_ratio": float(
group["head_compute_saving_ratio"].iloc[0]),
"final_action_cosine_vs_dense_mean": float(
group["final_action_cosine_vs_dense"].mean()),
"final_action_cosine_vs_dense_median": float(
group["final_action_cosine_vs_dense"].median()),
"final_state_cosine_vs_dense_mean": float(
group["final_state_cosine_vs_dense"].mean()),
"final_state_cosine_vs_dense_median": float(
group["final_state_cosine_vs_dense"].median()),
}
payload["experiment_1_2"] = {
"action_earliest_checkpoint_zero_based": {
"mean": float(
pd.to_numeric(
min_steps_df["action_earliest_checkpoint_zero_based"],
errors="coerce").dropna().mean()),
"median": float(
pd.to_numeric(
min_steps_df["action_earliest_checkpoint_zero_based"],
errors="coerce").dropna().median()),
},
"state_earliest_checkpoint_zero_based": {
"mean": float(
pd.to_numeric(
min_steps_df["state_earliest_checkpoint_zero_based"],
errors="coerce").dropna().mean()),
"median": float(
pd.to_numeric(
min_steps_df["state_earliest_checkpoint_zero_based"],
errors="coerce").dropna().median()),
},
"min_head_steps_needed_action_unique": sorted(
pd.to_numeric(min_steps_df["action_min_head_steps_needed"],
errors="coerce").dropna().unique().tolist()),
"min_head_steps_needed_state_unique": sorted(
pd.to_numeric(min_steps_df["state_min_head_steps_needed"],
errors="coerce").dropna().unique().tolist()),
}
for comparison, group in tail_compare_df.groupby("comparison"):
payload["experiment_1_3"][comparison] = {
"num_rows": int(len(group)),
"tail_better_action_share": float(group["tail_better_action"].mean()),
"tail_better_state_share": float(group["tail_better_state"].mean()),
"tail_minus_uniform_action_cosine_mean": float(
group["tail_minus_uniform_action_cosine"].mean()),
"tail_minus_uniform_state_cosine_mean": float(
group["tail_minus_uniform_state_cosine"].mean()),
}
return payload
def write_markdown_report(path: Path, payload: dict) -> None:
lines = [
"# Sparse Head Execution Simulation",
"",
"This report uses zero-order hold over logged stepwise action/state outputs.",
"For a sparse scheme, the final output at step 49 is approximated by the most recent checkpoint output.",
"",
"## Experiment 1.1",
"",
]
for scheme, stats in payload["experiment_1_1"].items():
lines.extend([
f"### {scheme}",
"",
f"- Head exec count: {stats['head_exec_count']:.0f}",
f"- Head compute saving ratio: {stats['head_compute_saving_ratio']:.4f}",
f"- Final action cosine vs dense: mean={stats['final_action_cosine_vs_dense_mean']:.4f}, median={stats['final_action_cosine_vs_dense_median']:.4f}",
f"- Final state cosine vs dense: mean={stats['final_state_cosine_vs_dense_mean']:.4f}, median={stats['final_state_cosine_vs_dense_median']:.4f}",
"",
])
lines.extend([
"## Experiment 1.2",
"",
f"- Target threshold: {payload['config']['target_threshold']}",
f"- Action earliest checkpoint mean: {payload['experiment_1_2']['action_earliest_checkpoint_zero_based']['mean']:.4f}",
f"- State earliest checkpoint mean: {payload['experiment_1_2']['state_earliest_checkpoint_zero_based']['mean']:.4f}",
f"- Unique min head steps needed for action: {payload['experiment_1_2']['min_head_steps_needed_action_unique']}",
f"- Unique min head steps needed for state: {payload['experiment_1_2']['min_head_steps_needed_state_unique']}",
"",
"## Experiment 1.3",
"",
])
for comparison, stats in payload["experiment_1_3"].items():
lines.extend([
f"### {comparison}",
"",
f"- Tail better action share: {stats['tail_better_action_share']:.4f}",
f"- Tail better state share: {stats['tail_better_state_share']:.4f}",
f"- Mean tail-minus-uniform action cosine: {stats['tail_minus_uniform_action_cosine_mean']:.4f}",
f"- Mean tail-minus-uniform state cosine: {stats['tail_minus_uniform_state_cosine_mean']:.4f}",
"",
])
with open(path, "w", encoding="utf-8") as file:
file.write("\n".join(lines) + "\n")
def main() -> None:
args = parse_args()
root_dir = Path(args.root_dir).resolve()
output_dir = Path(args.output_dir).resolve(
) if args.output_dir else root_dir / "sparse_head_simulation"
output_dir.mkdir(parents=True, exist_ok=True)
stepwise_paths = discover_stepwise_logs(root_dir)
stepwise_df = load_stepwise_table(stepwise_paths)
simulation_df = simulate_schemes(stepwise_df, DEFAULT_SCHEMES)
scheme_overall_df, scheme_per_case_df = summarize_scheme_results(simulation_df)
min_steps_df = compute_min_steps_needed(stepwise_df, args.target_threshold)
tail_compare_df = compare_tail_vs_uniform(simulation_df)
summary_payload = build_summary_payload(simulation_df, min_steps_df,
tail_compare_df,
args.target_threshold)
simulation_df.to_csv(output_dir / "scheme_simulation_per_round.csv",
index=False)
scheme_overall_df.to_csv(output_dir / "scheme_simulation_overall.csv",
index=False)
scheme_per_case_df.to_csv(output_dir / "scheme_simulation_per_case.csv",
index=False)
min_steps_df.to_csv(output_dir / "min_head_steps_needed.csv", index=False)
tail_compare_df.to_csv(output_dir / "tail_vs_uniform.csv", index=False)
with open(output_dir / "summary.json", "w", encoding="utf-8") as file:
json.dump(summary_payload, file, indent=2, ensure_ascii=False)
write_markdown_report(output_dir / "report.md", summary_payload)
print(f"Sparse head simulation written to: {output_dir}")
if __name__ == "__main__":
main()

View File

@@ -53,6 +53,12 @@ def build_sample_id(dataset: str, sample: pd.Series, frame_stride: int) -> str:
return f"{dataset}-vid{sample['videoid']}-fs{frame_stride}"
def get_case_id(prompt_dir: str) -> str:
"""Resolve case id from a prompt directory like */case1/world_model_interaction_prompts."""
normalized = os.path.normpath(prompt_dir)
return os.path.basename(os.path.dirname(normalized))
def flatten_batch_tensor(tensor: torch.Tensor) -> torch.Tensor:
"""Flatten all non-batch dimensions for batched metric computation."""
return tensor.detach().float().reshape(tensor.shape[0], -1)
@@ -124,6 +130,37 @@ def safe_mean(values: list[float]) -> float:
return float(np.mean(valid_values))
def flatten_tensor(tensor: torch.Tensor) -> torch.Tensor:
"""Flatten an arbitrary tensor into one 1D float vector."""
return tensor.detach().float().reshape(-1)
def tensor_l2_distance(current: torch.Tensor, reference: torch.Tensor) -> float:
"""Compute ||current-reference|| for arbitrary tensors."""
current_flat = flatten_tensor(current)
reference_flat = flatten_tensor(reference)
return float(torch.linalg.vector_norm(current_flat - reference_flat).item())
def tensor_relative_l2(current: torch.Tensor, previous: torch.Tensor) -> float:
"""Compute ||current-previous|| / (||previous|| + eps) for arbitrary tensors."""
current_flat = flatten_tensor(current)
previous_flat = flatten_tensor(previous)
numerator = torch.linalg.vector_norm(current_flat - previous_flat)
denominator = torch.linalg.vector_norm(previous_flat).clamp_min(1e-8)
return float((numerator / denominator).item())
def tensor_cosine_similarity(current: torch.Tensor,
reference: torch.Tensor) -> float:
"""Compute cosine similarity between arbitrary tensors."""
current_flat = flatten_tensor(current)
reference_flat = flatten_tensor(reference)
return float(
F.cosine_similarity(current_flat, reference_flat, dim=0,
eps=1e-8).item())
def make_sampling_noise_bundle(model: nn.Module,
noise_shape: list[int]) -> dict[str, torch.Tensor]:
"""Create aligned initial noise for latent, action, and state diffusion streams."""
@@ -139,6 +176,15 @@ def make_sampling_noise_bundle(model: nn.Module,
}
def reset_sampling_seed(seed: int) -> None:
"""Reset RNGs so repeated dense passes follow the same stochastic DDIM path."""
random.seed(seed)
np.random.seed(seed % (2**32))
torch.manual_seed(seed)
if torch.cuda.is_available():
torch.cuda.manual_seed_all(seed)
def load_psnr_lookup(psnr_path: str | None) -> dict[str, float]:
"""Load optional PSNR values keyed by sample_id or videoid."""
if not psnr_path:
@@ -445,6 +491,121 @@ class InteractionAnalysisLogger:
summary_rows.append(row)
summary_df = pd.DataFrame(summary_rows, columns=self.SUMMARY_COLUMNS)
summary_df.to_csv(summary_path, index=False)
class BackboneBlockProfiler:
"""Collect dense backbone block features and timings with a low-memory two-pass flow."""
COLUMNS = [
'sample_id',
'case_id',
'scene',
'pass_type',
'round_id',
'step',
'block_name',
'block_stage',
'block_index',
'shape',
'forward_time_ms',
'l2_delta_vs_prev',
'rel_l2_delta_vs_prev',
'cosine_vs_prev',
'l2_delta_vs_full50',
'cosine_vs_full50',
]
def __init__(self, output_dir: str):
self.output_dir = output_dir
self.rows: list[dict] = []
self.reference_features: dict[tuple[str, str, int], dict[str,
torch.Tensor]] = {}
self.mode: str | None = None
self.pass_key: tuple[str, str, int] | None = None
self.pass_meta: dict[str, str | int] = {}
self.current_reference: dict[str, torch.Tensor] = {}
self.previous_features: dict[str, torch.Tensor] = {}
def _set_pass(self, mode: str, sample_id: str, case_id: str, scene: str,
pass_type: str, round_id: int) -> None:
self.mode = mode
self.pass_key = (sample_id, pass_type, int(round_id))
self.pass_meta = {
'sample_id': sample_id,
'case_id': case_id,
'scene': scene,
'pass_type': pass_type,
'round_id': int(round_id),
}
self.current_reference = {}
self.previous_features = {}
def start_reference_pass(self, sample_id: str, case_id: str, scene: str,
pass_type: str, round_id: int) -> None:
self._set_pass('reference', sample_id, case_id, scene, pass_type,
round_id)
def start_target_pass(self, sample_id: str, case_id: str, scene: str,
pass_type: str, round_id: int) -> None:
self._set_pass('target', sample_id, case_id, scene, pass_type, round_id)
def finish_pass(self) -> None:
if self.mode == 'reference' and self.pass_key is not None:
self.reference_features[self.pass_key] = self.current_reference
elif self.mode == 'target' and self.pass_key is not None:
self.reference_features.pop(self.pass_key, None)
self.mode = None
self.pass_key = None
self.pass_meta = {}
self.current_reference = {}
self.previous_features = {}
def record_block(self, step: int, block_name: str, block_stage: str,
block_index: int | None, output: torch.Tensor,
forward_time_ms: float) -> None:
if self.mode is None or self.pass_key is None:
return
block_output = output.detach().float().cpu()
if self.mode == 'reference':
self.current_reference[block_name] = block_output
return
previous = self.previous_features.get(block_name)
reference = self.reference_features.get(self.pass_key, {}).get(block_name)
row = {
**self.pass_meta,
'step': int(step),
'block_name': block_name,
'block_stage': block_stage,
'block_index': -1 if block_index is None else int(block_index),
'shape': str(tuple(block_output.shape)),
'forward_time_ms': float(forward_time_ms),
'l2_delta_vs_prev': np.nan,
'rel_l2_delta_vs_prev': np.nan,
'cosine_vs_prev': np.nan,
'l2_delta_vs_full50': np.nan,
'cosine_vs_full50': np.nan,
}
if previous is not None:
row['l2_delta_vs_prev'] = tensor_l2_distance(block_output, previous)
row['rel_l2_delta_vs_prev'] = tensor_relative_l2(
block_output, previous)
row['cosine_vs_prev'] = tensor_cosine_similarity(
block_output, previous)
if reference is not None:
row['l2_delta_vs_full50'] = tensor_l2_distance(
block_output, reference)
row['cosine_vs_full50'] = tensor_cosine_similarity(
block_output, reference)
self.previous_features[block_name] = block_output
self.rows.append(row)
def flush(self) -> None:
os.makedirs(self.output_dir, exist_ok=True)
path = os.path.join(self.output_dir, 'backbone_block_log.csv')
df = pd.DataFrame(self.rows, columns=self.COLUMNS)
df.to_csv(path, index=False)
def write_video(video_path: str, stacked_frames: list, fps: int) -> None:
@@ -765,7 +926,9 @@ def image_guided_synthesis_sim_mode(
init_noise_bundle (dict[str, torch.Tensor] | None): Optional aligned noise inputs for latent/action/state.
decode_video (bool): Whether to decode the final latent into pixel space.
return_debug_info (bool): Whether to return per-step traces for analysis logging.
**kwargs: Additional arguments passed to the DDIM sampler.
**kwargs: Additional arguments passed to the DDIM sampler, including
sparse head controls such as `head_schedule`, `head_log_steps`,
and `head_skip_mode`.
Returns:
batch_variants (torch.Tensor | None): Predicted pixel-space video frames [B, C, T, H, W].
@@ -852,10 +1015,11 @@ def image_guided_synthesis_sim_mode(
if decode_video:
batch_variants = model.decode_first_stage(samples)
if return_debug_info:
if return_debug_info or intermedia.get('head_sparse_logs'):
debug_info = {
'analysis_init': intermedia.get('analysis_init'),
'step_records': intermedia.get('step_records', []),
'head_sparse_logs': intermedia.get('head_sparse_logs', {}),
'final_latent': samples.detach().cpu(),
'final_action': actions.detach().cpu(),
'final_state': states.detach().cpu(),
@@ -883,11 +1047,22 @@ def run_inference(args: argparse.Namespace, gpu_num: int, gpu_no: int) -> None:
os.makedirs(log_dir, exist_ok=True)
writer = SummaryWriter(log_dir=log_dir)
analysis_logger = None
backbone_profiler = None
head_schedule = args.head_schedule_steps if args.head_schedule_steps else None
head_log_steps = args.head_log_steps if args.head_log_steps else None
head_skip_mode = args.head_skip_mode
if args.analysis_log_metrics:
analysis_logger = InteractionAnalysisLogger(
output_dir=inference_dir,
psnr_lookup=load_psnr_lookup(args.analysis_psnr_path),
)
if args.analysis_profile_backbone_blocks:
if head_schedule is not None:
raise ValueError(
"Backbone block profiling expects dense DDIM runs. "
"Do not pass --head_schedule_steps.")
backbone_profiler = BackboneBlockProfiler(output_dir=inference_dir)
case_id = get_case_id(args.prompt_dir)
# Load prompt
csv_path = os.path.join(args.prompt_dir, f"{args.dataset}.csv")
@@ -1021,10 +1196,45 @@ def run_inference(args: argparse.Namespace, gpu_num: int, gpu_no: int) -> None:
# Use world-model in policy to generate action
print(f'>>> Step {itr}: generating actions ...')
policy_noise_bundle = make_sampling_noise_bundle(
model, noise_shape) if args.analysis_log_metrics else None
policy_noise_bundle = (
make_sampling_noise_bundle(model, noise_shape)
if (args.analysis_log_metrics
or backbone_profiler is not None) else None)
policy_reference_debug = None
if args.analysis_log_metrics and args.analysis_reference_steps != args.ddim_steps:
policy_sampling_seed = int(args.seed + itr * 1000 + 11)
if backbone_profiler is not None:
reset_sampling_seed(policy_sampling_seed)
backbone_profiler.start_reference_pass(
sample_id=sample_id,
case_id=case_id,
scene=scene,
pass_type='policy',
round_id=itr,
)
image_guided_synthesis_sim_mode(
model,
sample['instruction'],
observation,
noise_shape,
action_cond_step=args.exe_steps,
ddim_steps=args.ddim_steps,
ddim_eta=args.ddim_eta,
unconditional_guidance_scale=args.
unconditional_guidance_scale,
fs=model_input_fs,
timestep_spacing=args.timestep_spacing,
guidance_rescale=args.guidance_rescale,
sim_mode=False,
init_noise_bundle=policy_noise_bundle,
decode_video=False,
return_debug_info=False,
backbone_block_profiler=backbone_profiler,
)
backbone_profiler.finish_pass()
need_policy_reference = args.analysis_log_metrics and (
args.analysis_reference_steps != args.ddim_steps
or head_schedule is not None)
if need_policy_reference:
_, _, _, policy_reference_debug = image_guided_synthesis_sim_mode(
model,
sample['instruction'],
@@ -1041,8 +1251,18 @@ def run_inference(args: argparse.Namespace, gpu_num: int, gpu_no: int) -> None:
sim_mode=False,
init_noise_bundle=policy_noise_bundle,
decode_video=False,
return_debug_info=True)
return_debug_info=True,
head_log_steps=head_log_steps)
policy_pass_start = time.time()
if backbone_profiler is not None:
reset_sampling_seed(policy_sampling_seed)
backbone_profiler.start_target_pass(
sample_id=sample_id,
case_id=case_id,
scene=scene,
pass_type='policy',
round_id=itr,
)
pred_videos_0, pred_actions, _, policy_debug = image_guided_synthesis_sim_mode(
model,
sample['instruction'],
@@ -1058,7 +1278,13 @@ def run_inference(args: argparse.Namespace, gpu_num: int, gpu_no: int) -> None:
guidance_rescale=args.guidance_rescale,
sim_mode=False,
init_noise_bundle=policy_noise_bundle,
return_debug_info=args.analysis_log_metrics)
return_debug_info=args.analysis_log_metrics,
head_schedule=head_schedule,
head_log_steps=head_log_steps,
head_skip_mode=head_skip_mode,
backbone_block_profiler=backbone_profiler)
if backbone_profiler is not None:
backbone_profiler.finish_pass()
policy_pass_total_time_s = time.time() - policy_pass_start
policy_summary_row = None
if analysis_logger is not None:
@@ -1101,10 +1327,45 @@ def run_inference(args: argparse.Namespace, gpu_num: int, gpu_no: int) -> None:
# Interaction with the world-model
print(f'>>> Step {itr}: interacting with world model ...')
world_noise_bundle = make_sampling_noise_bundle(
model, noise_shape) if args.analysis_log_metrics else None
world_noise_bundle = (
make_sampling_noise_bundle(model, noise_shape)
if (args.analysis_log_metrics
or backbone_profiler is not None) else None)
world_reference_debug = None
if args.analysis_log_metrics and args.analysis_reference_steps != args.ddim_steps:
world_sampling_seed = int(args.seed + itr * 1000 + 29)
if backbone_profiler is not None:
reset_sampling_seed(world_sampling_seed)
backbone_profiler.start_reference_pass(
sample_id=sample_id,
case_id=case_id,
scene=scene,
pass_type='world_model',
round_id=itr,
)
image_guided_synthesis_sim_mode(
model,
"",
observation,
noise_shape,
action_cond_step=args.exe_steps,
ddim_steps=args.ddim_steps,
ddim_eta=args.ddim_eta,
unconditional_guidance_scale=args.
unconditional_guidance_scale,
fs=model_input_fs,
text_input=False,
timestep_spacing=args.timestep_spacing,
guidance_rescale=args.guidance_rescale,
init_noise_bundle=world_noise_bundle,
decode_video=False,
return_debug_info=False,
backbone_block_profiler=backbone_profiler,
)
backbone_profiler.finish_pass()
need_world_reference = args.analysis_log_metrics and (
args.analysis_reference_steps != args.ddim_steps
or head_schedule is not None)
if need_world_reference:
_, _, _, world_reference_debug = image_guided_synthesis_sim_mode(
model,
"",
@@ -1121,8 +1382,18 @@ def run_inference(args: argparse.Namespace, gpu_num: int, gpu_no: int) -> None:
guidance_rescale=args.guidance_rescale,
init_noise_bundle=world_noise_bundle,
decode_video=False,
return_debug_info=True)
return_debug_info=True,
head_log_steps=head_log_steps)
world_pass_start = time.time()
if backbone_profiler is not None:
reset_sampling_seed(world_sampling_seed)
backbone_profiler.start_target_pass(
sample_id=sample_id,
case_id=case_id,
scene=scene,
pass_type='world_model',
round_id=itr,
)
pred_videos_1, _, pred_states, world_debug = image_guided_synthesis_sim_mode(
model,
"",
@@ -1138,7 +1409,13 @@ def run_inference(args: argparse.Namespace, gpu_num: int, gpu_no: int) -> None:
timestep_spacing=args.timestep_spacing,
guidance_rescale=args.guidance_rescale,
init_noise_bundle=world_noise_bundle,
return_debug_info=args.analysis_log_metrics)
return_debug_info=args.analysis_log_metrics,
head_schedule=head_schedule,
head_log_steps=head_log_steps,
head_skip_mode=head_skip_mode,
backbone_block_profiler=backbone_profiler)
if backbone_profiler is not None:
backbone_profiler.finish_pass()
world_pass_total_time_s = time.time() - world_pass_start
world_summary_row = None
if analysis_logger is not None:
@@ -1224,6 +1501,8 @@ def run_inference(args: argparse.Namespace, gpu_num: int, gpu_no: int) -> None:
if analysis_logger is not None:
analysis_logger.flush()
if backbone_profiler is not None:
backbone_profiler.flush()
writer.close()
@@ -1357,6 +1636,29 @@ def get_parser():
type=str,
default=None,
help="Optional CSV/JSON file with psnr_full50 values keyed by sample_id or videoid.")
parser.add_argument(
"--analysis_profile_backbone_blocks",
action='store_true',
default=False,
help="Run dense two-pass backbone block profiling and export backbone_block_log.csv.")
parser.add_argument(
"--head_schedule_steps",
type=int,
nargs='*',
default=None,
help="Zero-based DDIM loop indices where action/state heads execute. Omit for dense execution.")
parser.add_argument(
"--head_log_steps",
type=int,
nargs='*',
default=None,
help="Zero-based DDIM loop indices to snapshot sparse action/state/latent outputs for dense-vs-sparse comparison.")
parser.add_argument(
"--head_skip_mode",
type=str,
default="reuse_prediction",
choices=["reuse_prediction", "freeze_state"],
help="Behavior on non-checkpoint steps: reuse cached head predictions while still running scheduler.step, or freeze action/state entirely.")
return parser