video_backbone剖析

This commit is contained in:
qhy
2026-03-16 10:30:54 +08:00
parent 7e45eba18b
commit 8ca159d375
282 changed files with 174952 additions and 1350 deletions

View File

@@ -53,6 +53,12 @@ def build_sample_id(dataset: str, sample: pd.Series, frame_stride: int) -> str:
return f"{dataset}-vid{sample['videoid']}-fs{frame_stride}"
def get_case_id(prompt_dir: str) -> str:
    """Return the name of the directory that directly contains ``prompt_dir``.

    For a prompt directory such as ``*/case1/world_model_interaction_prompts``
    the result is ``case1``.

    Args:
        prompt_dir (str): Path to a prompt directory; it is normalized first, so
            trailing separators and ``..`` segments are collapsed.

    Returns:
        str: Base name of the parent directory (empty when there is no parent component).
    """
    # os.path.split(...)[0] is the same value os.path.dirname would give.
    parent_dir, _ = os.path.split(os.path.normpath(prompt_dir))
    return os.path.basename(parent_dir)
def flatten_batch_tensor(tensor: torch.Tensor) -> torch.Tensor:
    """Collapse every dimension after the first into one, for batched metric computation.

    Args:
        tensor (torch.Tensor): Input whose first dimension is kept as-is.

    Returns:
        torch.Tensor: Detached float tensor of shape ``(tensor.shape[0], -1)``.
    """
    leading_dim = tensor.shape[0]
    detached = tensor.detach().float()
    return detached.reshape(leading_dim, -1)
@@ -124,6 +130,37 @@ def safe_mean(values: list[float]) -> float:
return float(np.mean(valid_values))
def flatten_tensor(tensor: torch.Tensor) -> torch.Tensor:
    """Return ``tensor`` as a detached, float, one-dimensional vector.

    Args:
        tensor (torch.Tensor): Tensor of any shape.

    Returns:
        torch.Tensor: 1D float tensor holding all elements of ``tensor``.
    """
    as_float = tensor.detach().float()
    return as_float.reshape(-1)
def tensor_l2_distance(current: torch.Tensor, reference: torch.Tensor) -> float:
    """Return the L2 norm of ``current - reference`` after flattening both tensors.

    Args:
        current (torch.Tensor): Tensor being evaluated.
        reference (torch.Tensor): Tensor it is compared against.

    Returns:
        float: ``||current - reference||`` as a Python float.
    """
    difference = flatten_tensor(current) - flatten_tensor(reference)
    distance = torch.linalg.vector_norm(difference)
    return float(distance.item())
def tensor_relative_l2(current: torch.Tensor, previous: torch.Tensor) -> float:
    """Return ``||current - previous|| / max(||previous||, 1e-8)`` on the flattened tensors.

    Args:
        current (torch.Tensor): Tensor being evaluated.
        previous (torch.Tensor): Tensor that provides the baseline and the scale.

    Returns:
        float: Relative L2 change as a Python float.
    """
    baseline = flatten_tensor(previous)
    delta = flatten_tensor(current) - baseline
    # The denominator is clamped (not incremented) so a zero baseline cannot divide by zero.
    scale = torch.linalg.vector_norm(baseline).clamp_min(1e-8)
    ratio = torch.linalg.vector_norm(delta) / scale
    return float(ratio.item())
def tensor_cosine_similarity(current: torch.Tensor,
                             reference: torch.Tensor) -> float:
    """Return the cosine similarity of two tensors after flattening each to 1D.

    Args:
        current (torch.Tensor): Tensor being evaluated.
        reference (torch.Tensor): Tensor it is compared against.

    Returns:
        float: Cosine similarity as a Python float (computed with ``eps=1e-8``).
    """
    lhs = flatten_tensor(current)
    rhs = flatten_tensor(reference)
    similarity = F.cosine_similarity(lhs, rhs, dim=0, eps=1e-8)
    return float(similarity.item())
def make_sampling_noise_bundle(model: nn.Module,
noise_shape: list[int]) -> dict[str, torch.Tensor]:
"""Create aligned initial noise for latent, action, and state diffusion streams."""
@@ -139,6 +176,15 @@ def make_sampling_noise_bundle(model: nn.Module,
}
def reset_sampling_seed(seed: int) -> None:
    """Re-seed the Python, NumPy and torch RNGs so a repeated pass draws the same random numbers.

    Args:
        seed (int): Seed applied to ``random``, ``numpy.random`` (modulo 2**32) and torch.
    """
    # np.random.seed only accepts values in [0, 2**32 - 1], hence the wrap-around.
    numpy_seed = seed % (2**32)
    random.seed(seed)
    np.random.seed(numpy_seed)
    torch.manual_seed(seed)
    if not torch.cuda.is_available():
        return
    torch.cuda.manual_seed_all(seed)
def load_psnr_lookup(psnr_path: str | None) -> dict[str, float]:
"""Load optional PSNR values keyed by sample_id or videoid."""
if not psnr_path:
@@ -445,6 +491,121 @@ class InteractionAnalysisLogger:
summary_rows.append(row)
summary_df = pd.DataFrame(summary_rows, columns=self.SUMMARY_COLUMNS)
summary_df.to_csv(summary_path, index=False)
class BackboneBlockProfiler:
"""Collect dense backbone block features and timings with a low-memory two-pass flow."""
COLUMNS = [
'sample_id',
'case_id',
'scene',
'pass_type',
'round_id',
'step',
'block_name',
'block_stage',
'block_index',
'shape',
'forward_time_ms',
'l2_delta_vs_prev',
'rel_l2_delta_vs_prev',
'cosine_vs_prev',
'l2_delta_vs_full50',
'cosine_vs_full50',
]
def __init__(self, output_dir: str):
self.output_dir = output_dir
self.rows: list[dict] = []
self.reference_features: dict[tuple[str, str, int], dict[str,
torch.Tensor]] = {}
self.mode: str | None = None
self.pass_key: tuple[str, str, int] | None = None
self.pass_meta: dict[str, str | int] = {}
self.current_reference: dict[str, torch.Tensor] = {}
self.previous_features: dict[str, torch.Tensor] = {}
def _set_pass(self, mode: str, sample_id: str, case_id: str, scene: str,
pass_type: str, round_id: int) -> None:
self.mode = mode
self.pass_key = (sample_id, pass_type, int(round_id))
self.pass_meta = {
'sample_id': sample_id,
'case_id': case_id,
'scene': scene,
'pass_type': pass_type,
'round_id': int(round_id),
}
self.current_reference = {}
self.previous_features = {}
def start_reference_pass(self, sample_id: str, case_id: str, scene: str,
pass_type: str, round_id: int) -> None:
self._set_pass('reference', sample_id, case_id, scene, pass_type,
round_id)
def start_target_pass(self, sample_id: str, case_id: str, scene: str,
pass_type: str, round_id: int) -> None:
self._set_pass('target', sample_id, case_id, scene, pass_type, round_id)
def finish_pass(self) -> None:
if self.mode == 'reference' and self.pass_key is not None:
self.reference_features[self.pass_key] = self.current_reference
elif self.mode == 'target' and self.pass_key is not None:
self.reference_features.pop(self.pass_key, None)
self.mode = None
self.pass_key = None
self.pass_meta = {}
self.current_reference = {}
self.previous_features = {}
def record_block(self, step: int, block_name: str, block_stage: str,
block_index: int | None, output: torch.Tensor,
forward_time_ms: float) -> None:
if self.mode is None or self.pass_key is None:
return
block_output = output.detach().float().cpu()
if self.mode == 'reference':
self.current_reference[block_name] = block_output
return
previous = self.previous_features.get(block_name)
reference = self.reference_features.get(self.pass_key, {}).get(block_name)
row = {
**self.pass_meta,
'step': int(step),
'block_name': block_name,
'block_stage': block_stage,
'block_index': -1 if block_index is None else int(block_index),
'shape': str(tuple(block_output.shape)),
'forward_time_ms': float(forward_time_ms),
'l2_delta_vs_prev': np.nan,
'rel_l2_delta_vs_prev': np.nan,
'cosine_vs_prev': np.nan,
'l2_delta_vs_full50': np.nan,
'cosine_vs_full50': np.nan,
}
if previous is not None:
row['l2_delta_vs_prev'] = tensor_l2_distance(block_output, previous)
row['rel_l2_delta_vs_prev'] = tensor_relative_l2(
block_output, previous)
row['cosine_vs_prev'] = tensor_cosine_similarity(
block_output, previous)
if reference is not None:
row['l2_delta_vs_full50'] = tensor_l2_distance(
block_output, reference)
row['cosine_vs_full50'] = tensor_cosine_similarity(
block_output, reference)
self.previous_features[block_name] = block_output
self.rows.append(row)
def flush(self) -> None:
os.makedirs(self.output_dir, exist_ok=True)
path = os.path.join(self.output_dir, 'backbone_block_log.csv')
df = pd.DataFrame(self.rows, columns=self.COLUMNS)
df.to_csv(path, index=False)
def write_video(video_path: str, stacked_frames: list, fps: int) -> None:
@@ -765,7 +926,9 @@ def image_guided_synthesis_sim_mode(
init_noise_bundle (dict[str, torch.Tensor] | None): Optional aligned noise inputs for latent/action/state.
decode_video (bool): Whether to decode the final latent into pixel space.
return_debug_info (bool): Whether to return per-step traces for analysis logging.
**kwargs: Additional arguments passed to the DDIM sampler.
**kwargs: Additional arguments passed to the DDIM sampler, including
sparse head controls such as `head_schedule`, `head_log_steps`,
and `head_skip_mode`.
Returns:
batch_variants (torch.Tensor | None): Predicted pixel-space video frames [B, C, T, H, W].
@@ -852,10 +1015,11 @@ def image_guided_synthesis_sim_mode(
if decode_video:
batch_variants = model.decode_first_stage(samples)
if return_debug_info:
if return_debug_info or intermedia.get('head_sparse_logs'):
debug_info = {
'analysis_init': intermedia.get('analysis_init'),
'step_records': intermedia.get('step_records', []),
'head_sparse_logs': intermedia.get('head_sparse_logs', {}),
'final_latent': samples.detach().cpu(),
'final_action': actions.detach().cpu(),
'final_state': states.detach().cpu(),
@@ -883,11 +1047,22 @@ def run_inference(args: argparse.Namespace, gpu_num: int, gpu_no: int) -> None:
os.makedirs(log_dir, exist_ok=True)
writer = SummaryWriter(log_dir=log_dir)
analysis_logger = None
backbone_profiler = None
head_schedule = args.head_schedule_steps if args.head_schedule_steps else None
head_log_steps = args.head_log_steps if args.head_log_steps else None
head_skip_mode = args.head_skip_mode
if args.analysis_log_metrics:
analysis_logger = InteractionAnalysisLogger(
output_dir=inference_dir,
psnr_lookup=load_psnr_lookup(args.analysis_psnr_path),
)
if args.analysis_profile_backbone_blocks:
if head_schedule is not None:
raise ValueError(
"Backbone block profiling expects dense DDIM runs. "
"Do not pass --head_schedule_steps.")
backbone_profiler = BackboneBlockProfiler(output_dir=inference_dir)
case_id = get_case_id(args.prompt_dir)
# Load prompt
csv_path = os.path.join(args.prompt_dir, f"{args.dataset}.csv")
@@ -1021,10 +1196,45 @@ def run_inference(args: argparse.Namespace, gpu_num: int, gpu_no: int) -> None:
# Use world-model in policy to generate action
print(f'>>> Step {itr}: generating actions ...')
policy_noise_bundle = make_sampling_noise_bundle(
model, noise_shape) if args.analysis_log_metrics else None
policy_noise_bundle = (
make_sampling_noise_bundle(model, noise_shape)
if (args.analysis_log_metrics
or backbone_profiler is not None) else None)
policy_reference_debug = None
if args.analysis_log_metrics and args.analysis_reference_steps != args.ddim_steps:
policy_sampling_seed = int(args.seed + itr * 1000 + 11)
if backbone_profiler is not None:
reset_sampling_seed(policy_sampling_seed)
backbone_profiler.start_reference_pass(
sample_id=sample_id,
case_id=case_id,
scene=scene,
pass_type='policy',
round_id=itr,
)
image_guided_synthesis_sim_mode(
model,
sample['instruction'],
observation,
noise_shape,
action_cond_step=args.exe_steps,
ddim_steps=args.ddim_steps,
ddim_eta=args.ddim_eta,
unconditional_guidance_scale=args.
unconditional_guidance_scale,
fs=model_input_fs,
timestep_spacing=args.timestep_spacing,
guidance_rescale=args.guidance_rescale,
sim_mode=False,
init_noise_bundle=policy_noise_bundle,
decode_video=False,
return_debug_info=False,
backbone_block_profiler=backbone_profiler,
)
backbone_profiler.finish_pass()
need_policy_reference = args.analysis_log_metrics and (
args.analysis_reference_steps != args.ddim_steps
or head_schedule is not None)
if need_policy_reference:
_, _, _, policy_reference_debug = image_guided_synthesis_sim_mode(
model,
sample['instruction'],
@@ -1041,8 +1251,18 @@ def run_inference(args: argparse.Namespace, gpu_num: int, gpu_no: int) -> None:
sim_mode=False,
init_noise_bundle=policy_noise_bundle,
decode_video=False,
return_debug_info=True)
return_debug_info=True,
head_log_steps=head_log_steps)
policy_pass_start = time.time()
if backbone_profiler is not None:
reset_sampling_seed(policy_sampling_seed)
backbone_profiler.start_target_pass(
sample_id=sample_id,
case_id=case_id,
scene=scene,
pass_type='policy',
round_id=itr,
)
pred_videos_0, pred_actions, _, policy_debug = image_guided_synthesis_sim_mode(
model,
sample['instruction'],
@@ -1058,7 +1278,13 @@ def run_inference(args: argparse.Namespace, gpu_num: int, gpu_no: int) -> None:
guidance_rescale=args.guidance_rescale,
sim_mode=False,
init_noise_bundle=policy_noise_bundle,
return_debug_info=args.analysis_log_metrics)
return_debug_info=args.analysis_log_metrics,
head_schedule=head_schedule,
head_log_steps=head_log_steps,
head_skip_mode=head_skip_mode,
backbone_block_profiler=backbone_profiler)
if backbone_profiler is not None:
backbone_profiler.finish_pass()
policy_pass_total_time_s = time.time() - policy_pass_start
policy_summary_row = None
if analysis_logger is not None:
@@ -1101,10 +1327,45 @@ def run_inference(args: argparse.Namespace, gpu_num: int, gpu_no: int) -> None:
# Interaction with the world-model
print(f'>>> Step {itr}: interacting with world model ...')
world_noise_bundle = make_sampling_noise_bundle(
model, noise_shape) if args.analysis_log_metrics else None
world_noise_bundle = (
make_sampling_noise_bundle(model, noise_shape)
if (args.analysis_log_metrics
or backbone_profiler is not None) else None)
world_reference_debug = None
if args.analysis_log_metrics and args.analysis_reference_steps != args.ddim_steps:
world_sampling_seed = int(args.seed + itr * 1000 + 29)
if backbone_profiler is not None:
reset_sampling_seed(world_sampling_seed)
backbone_profiler.start_reference_pass(
sample_id=sample_id,
case_id=case_id,
scene=scene,
pass_type='world_model',
round_id=itr,
)
image_guided_synthesis_sim_mode(
model,
"",
observation,
noise_shape,
action_cond_step=args.exe_steps,
ddim_steps=args.ddim_steps,
ddim_eta=args.ddim_eta,
unconditional_guidance_scale=args.
unconditional_guidance_scale,
fs=model_input_fs,
text_input=False,
timestep_spacing=args.timestep_spacing,
guidance_rescale=args.guidance_rescale,
init_noise_bundle=world_noise_bundle,
decode_video=False,
return_debug_info=False,
backbone_block_profiler=backbone_profiler,
)
backbone_profiler.finish_pass()
need_world_reference = args.analysis_log_metrics and (
args.analysis_reference_steps != args.ddim_steps
or head_schedule is not None)
if need_world_reference:
_, _, _, world_reference_debug = image_guided_synthesis_sim_mode(
model,
"",
@@ -1121,8 +1382,18 @@ def run_inference(args: argparse.Namespace, gpu_num: int, gpu_no: int) -> None:
guidance_rescale=args.guidance_rescale,
init_noise_bundle=world_noise_bundle,
decode_video=False,
return_debug_info=True)
return_debug_info=True,
head_log_steps=head_log_steps)
world_pass_start = time.time()
if backbone_profiler is not None:
reset_sampling_seed(world_sampling_seed)
backbone_profiler.start_target_pass(
sample_id=sample_id,
case_id=case_id,
scene=scene,
pass_type='world_model',
round_id=itr,
)
pred_videos_1, _, pred_states, world_debug = image_guided_synthesis_sim_mode(
model,
"",
@@ -1138,7 +1409,13 @@ def run_inference(args: argparse.Namespace, gpu_num: int, gpu_no: int) -> None:
timestep_spacing=args.timestep_spacing,
guidance_rescale=args.guidance_rescale,
init_noise_bundle=world_noise_bundle,
return_debug_info=args.analysis_log_metrics)
return_debug_info=args.analysis_log_metrics,
head_schedule=head_schedule,
head_log_steps=head_log_steps,
head_skip_mode=head_skip_mode,
backbone_block_profiler=backbone_profiler)
if backbone_profiler is not None:
backbone_profiler.finish_pass()
world_pass_total_time_s = time.time() - world_pass_start
world_summary_row = None
if analysis_logger is not None:
@@ -1224,6 +1501,8 @@ def run_inference(args: argparse.Namespace, gpu_num: int, gpu_no: int) -> None:
if analysis_logger is not None:
analysis_logger.flush()
if backbone_profiler is not None:
backbone_profiler.flush()
writer.close()
@@ -1357,6 +1636,29 @@ def get_parser():
type=str,
default=None,
help="Optional CSV/JSON file with psnr_full50 values keyed by sample_id or videoid.")
parser.add_argument(
"--analysis_profile_backbone_blocks",
action='store_true',
default=False,
help="Run dense two-pass backbone block profiling and export backbone_block_log.csv.")
parser.add_argument(
"--head_schedule_steps",
type=int,
nargs='*',
default=None,
help="Zero-based DDIM loop indices where action/state heads execute. Omit for dense execution.")
parser.add_argument(
"--head_log_steps",
type=int,
nargs='*',
default=None,
help="Zero-based DDIM loop indices to snapshot sparse action/state/latent outputs for dense-vs-sparse comparison.")
parser.add_argument(
"--head_skip_mode",
type=str,
default="reuse_prediction",
choices=["reuse_prediction", "freeze_state"],
help="Behavior on non-checkpoint steps: reuse cached head predictions while still running scheduler.step, or freeze action/state entirely.")
return parser