video_backbone剖析
This commit is contained in:
@@ -53,6 +53,12 @@ def build_sample_id(dataset: str, sample: pd.Series, frame_stride: int) -> str:
|
||||
return f"{dataset}-vid{sample['videoid']}-fs{frame_stride}"
|
||||
|
||||
|
||||
def get_case_id(prompt_dir: str) -> str:
    """Resolve the case id from a prompt directory path.

    The case id is the name of the directory that contains ``prompt_dir``,
    e.g. ``*/case1/world_model_interaction_prompts`` resolves to ``"case1"``.

    Args:
        prompt_dir: Path of the prompt directory. May be relative, may end
            with a path separator, and may contain ``..`` components.

    Returns:
        Base name of the parent directory of ``prompt_dir``. This is the
        empty string only when that parent is the filesystem root.
    """
    # abspath normalizes the path (like normpath) and additionally anchors
    # relative paths at the current working directory, so inputs such as
    # "prompts" or "../prompts" yield a real directory name instead of ""
    # or "..".
    resolved = os.path.abspath(prompt_dir)
    return os.path.basename(os.path.dirname(resolved))
|
||||
|
||||
|
||||
def flatten_batch_tensor(tensor: torch.Tensor) -> torch.Tensor:
    """Flatten all non-batch dimensions for batched metric computation.

    Args:
        tensor: Tensor whose first dimension is treated as the batch
            dimension. It must have at least one dimension.

    Returns:
        A detached float32 tensor of shape ``(tensor.shape[0], F)``, where
        ``F`` is the product of the remaining dimension sizes (``1`` when
        there are none).
    """
    # Function-scope import keeps this block independent of the file's
    # top-level import list.
    import math

    batch_size = tensor.shape[0]
    # Spell out the feature count instead of passing -1: reshape cannot infer
    # -1 for tensors with zero elements (for example an empty batch).
    features_per_item = math.prod(tensor.shape[1:])
    return tensor.detach().float().reshape(batch_size, features_per_item)
|
||||
@@ -124,6 +130,37 @@ def safe_mean(values: list[float]) -> float:
|
||||
return float(np.mean(valid_values))
|
||||
|
||||
|
||||
def flatten_tensor(tensor: torch.Tensor) -> torch.Tensor:
    """Return ``tensor`` as a detached, one-dimensional float32 vector.

    Args:
        tensor: Tensor of any shape and dtype.

    Returns:
        A 1D float32 tensor holding every element of ``tensor``, cut off from
        the autograd graph.
    """
    as_float = tensor.detach().to(torch.float32)
    return torch.reshape(as_float, (-1,))
|
||||
|
||||
|
||||
def tensor_l2_distance(current: torch.Tensor, reference: torch.Tensor) -> float:
    """Return the L2 norm of ``current - reference``.

    Both inputs are flattened to 1D float vectors with ``flatten_tensor``
    before being subtracted.

    Args:
        current: Tensor being measured.
        reference: Tensor it is compared against.

    Returns:
        ``||current - reference||`` as a Python float.
    """
    difference = flatten_tensor(current) - flatten_tensor(reference)
    distance = torch.linalg.vector_norm(difference)
    return float(distance.item())
|
||||
|
||||
|
||||
def tensor_relative_l2(current: torch.Tensor, previous: torch.Tensor) -> float:
    """Return the L2 change from ``previous`` to ``current``, relative to ``previous``.

    Computes ``||current - previous|| / max(||previous||, 1e-8)`` on the
    flattened inputs.

    Args:
        current: Tensor being measured.
        previous: Baseline tensor whose norm scales the result.

    Returns:
        The relative L2 difference as a Python float.
    """
    current_vec = flatten_tensor(current)
    previous_vec = flatten_tensor(previous)
    change_norm = torch.linalg.vector_norm(current_vec - previous_vec)
    # Floor the baseline norm so an all-zero ``previous`` cannot divide by zero.
    baseline_norm = torch.clamp(torch.linalg.vector_norm(previous_vec), min=1e-8)
    ratio = change_norm / baseline_norm
    return float(ratio.item())
|
||||
|
||||
|
||||
def tensor_cosine_similarity(current: torch.Tensor,
                             reference: torch.Tensor) -> float:
    """Return the cosine similarity of two tensors treated as flat vectors.

    Args:
        current: Tensor being measured.
        reference: Tensor it is compared against.

    Returns:
        Cosine similarity of the flattened inputs as a Python float.
    """
    current_vec = flatten_tensor(current)
    reference_vec = flatten_tensor(reference)
    # dim=0 because both operands are 1D after flattening.
    similarity = F.cosine_similarity(
        current_vec, reference_vec, dim=0, eps=1e-8)
    return float(similarity.item())
|
||||
|
||||
|
||||
def make_sampling_noise_bundle(model: nn.Module,
|
||||
noise_shape: list[int]) -> dict[str, torch.Tensor]:
|
||||
"""Create aligned initial noise for latent, action, and state diffusion streams."""
|
||||
@@ -139,6 +176,15 @@ def make_sampling_noise_bundle(model: nn.Module,
|
||||
}
|
||||
|
||||
|
||||
def reset_sampling_seed(seed: int) -> None:
    """Re-seed the Python, NumPy and torch RNGs with ``seed``.

    Intended to make repeated dense passes follow the same stochastic DDIM
    path.

    Args:
        seed: Seed applied to every generator. It is reduced modulo 2**32 for
            NumPy only.
    """
    numpy_seed = seed % (1 << 32)
    random.seed(seed)
    np.random.seed(numpy_seed)
    torch.manual_seed(seed)
    if not torch.cuda.is_available():
        return
    torch.cuda.manual_seed_all(seed)
|
||||
|
||||
|
||||
def load_psnr_lookup(psnr_path: str | None) -> dict[str, float]:
|
||||
"""Load optional PSNR values keyed by sample_id or videoid."""
|
||||
if not psnr_path:
|
||||
@@ -445,6 +491,121 @@ class InteractionAnalysisLogger:
|
||||
summary_rows.append(row)
|
||||
summary_df = pd.DataFrame(summary_rows, columns=self.SUMMARY_COLUMNS)
|
||||
summary_df.to_csv(summary_path, index=False)
|
||||
|
||||
|
||||
class BackboneBlockProfiler:
|
||||
"""Collect dense backbone block features and timings with a low-memory two-pass flow."""
|
||||
|
||||
COLUMNS = [
|
||||
'sample_id',
|
||||
'case_id',
|
||||
'scene',
|
||||
'pass_type',
|
||||
'round_id',
|
||||
'step',
|
||||
'block_name',
|
||||
'block_stage',
|
||||
'block_index',
|
||||
'shape',
|
||||
'forward_time_ms',
|
||||
'l2_delta_vs_prev',
|
||||
'rel_l2_delta_vs_prev',
|
||||
'cosine_vs_prev',
|
||||
'l2_delta_vs_full50',
|
||||
'cosine_vs_full50',
|
||||
]
|
||||
|
||||
def __init__(self, output_dir: str):
|
||||
self.output_dir = output_dir
|
||||
self.rows: list[dict] = []
|
||||
self.reference_features: dict[tuple[str, str, int], dict[str,
|
||||
torch.Tensor]] = {}
|
||||
self.mode: str | None = None
|
||||
self.pass_key: tuple[str, str, int] | None = None
|
||||
self.pass_meta: dict[str, str | int] = {}
|
||||
self.current_reference: dict[str, torch.Tensor] = {}
|
||||
self.previous_features: dict[str, torch.Tensor] = {}
|
||||
|
||||
def _set_pass(self, mode: str, sample_id: str, case_id: str, scene: str,
|
||||
pass_type: str, round_id: int) -> None:
|
||||
self.mode = mode
|
||||
self.pass_key = (sample_id, pass_type, int(round_id))
|
||||
self.pass_meta = {
|
||||
'sample_id': sample_id,
|
||||
'case_id': case_id,
|
||||
'scene': scene,
|
||||
'pass_type': pass_type,
|
||||
'round_id': int(round_id),
|
||||
}
|
||||
self.current_reference = {}
|
||||
self.previous_features = {}
|
||||
|
||||
def start_reference_pass(self, sample_id: str, case_id: str, scene: str,
|
||||
pass_type: str, round_id: int) -> None:
|
||||
self._set_pass('reference', sample_id, case_id, scene, pass_type,
|
||||
round_id)
|
||||
|
||||
def start_target_pass(self, sample_id: str, case_id: str, scene: str,
|
||||
pass_type: str, round_id: int) -> None:
|
||||
self._set_pass('target', sample_id, case_id, scene, pass_type, round_id)
|
||||
|
||||
def finish_pass(self) -> None:
|
||||
if self.mode == 'reference' and self.pass_key is not None:
|
||||
self.reference_features[self.pass_key] = self.current_reference
|
||||
elif self.mode == 'target' and self.pass_key is not None:
|
||||
self.reference_features.pop(self.pass_key, None)
|
||||
self.mode = None
|
||||
self.pass_key = None
|
||||
self.pass_meta = {}
|
||||
self.current_reference = {}
|
||||
self.previous_features = {}
|
||||
|
||||
def record_block(self, step: int, block_name: str, block_stage: str,
|
||||
block_index: int | None, output: torch.Tensor,
|
||||
forward_time_ms: float) -> None:
|
||||
if self.mode is None or self.pass_key is None:
|
||||
return
|
||||
block_output = output.detach().float().cpu()
|
||||
if self.mode == 'reference':
|
||||
self.current_reference[block_name] = block_output
|
||||
return
|
||||
|
||||
previous = self.previous_features.get(block_name)
|
||||
reference = self.reference_features.get(self.pass_key, {}).get(block_name)
|
||||
row = {
|
||||
**self.pass_meta,
|
||||
'step': int(step),
|
||||
'block_name': block_name,
|
||||
'block_stage': block_stage,
|
||||
'block_index': -1 if block_index is None else int(block_index),
|
||||
'shape': str(tuple(block_output.shape)),
|
||||
'forward_time_ms': float(forward_time_ms),
|
||||
'l2_delta_vs_prev': np.nan,
|
||||
'rel_l2_delta_vs_prev': np.nan,
|
||||
'cosine_vs_prev': np.nan,
|
||||
'l2_delta_vs_full50': np.nan,
|
||||
'cosine_vs_full50': np.nan,
|
||||
}
|
||||
if previous is not None:
|
||||
row['l2_delta_vs_prev'] = tensor_l2_distance(block_output, previous)
|
||||
row['rel_l2_delta_vs_prev'] = tensor_relative_l2(
|
||||
block_output, previous)
|
||||
row['cosine_vs_prev'] = tensor_cosine_similarity(
|
||||
block_output, previous)
|
||||
if reference is not None:
|
||||
row['l2_delta_vs_full50'] = tensor_l2_distance(
|
||||
block_output, reference)
|
||||
row['cosine_vs_full50'] = tensor_cosine_similarity(
|
||||
block_output, reference)
|
||||
|
||||
self.previous_features[block_name] = block_output
|
||||
self.rows.append(row)
|
||||
|
||||
def flush(self) -> None:
|
||||
os.makedirs(self.output_dir, exist_ok=True)
|
||||
path = os.path.join(self.output_dir, 'backbone_block_log.csv')
|
||||
df = pd.DataFrame(self.rows, columns=self.COLUMNS)
|
||||
df.to_csv(path, index=False)
|
||||
|
||||
|
||||
def write_video(video_path: str, stacked_frames: list, fps: int) -> None:
|
||||
@@ -765,7 +926,9 @@ def image_guided_synthesis_sim_mode(
|
||||
init_noise_bundle (dict[str, torch.Tensor] | None): Optional aligned noise inputs for latent/action/state.
|
||||
decode_video (bool): Whether to decode the final latent into pixel space.
|
||||
return_debug_info (bool): Whether to return per-step traces for analysis logging.
|
||||
**kwargs: Additional arguments passed to the DDIM sampler.
|
||||
**kwargs: Additional arguments passed to the DDIM sampler, including
|
||||
sparse head controls such as `head_schedule`, `head_log_steps`,
|
||||
and `head_skip_mode`.
|
||||
|
||||
Returns:
|
||||
batch_variants (torch.Tensor | None): Predicted pixel-space video frames [B, C, T, H, W].
|
||||
@@ -852,10 +1015,11 @@ def image_guided_synthesis_sim_mode(
|
||||
if decode_video:
|
||||
batch_variants = model.decode_first_stage(samples)
|
||||
|
||||
if return_debug_info:
|
||||
if return_debug_info or intermedia.get('head_sparse_logs'):
|
||||
debug_info = {
|
||||
'analysis_init': intermedia.get('analysis_init'),
|
||||
'step_records': intermedia.get('step_records', []),
|
||||
'head_sparse_logs': intermedia.get('head_sparse_logs', {}),
|
||||
'final_latent': samples.detach().cpu(),
|
||||
'final_action': actions.detach().cpu(),
|
||||
'final_state': states.detach().cpu(),
|
||||
@@ -883,11 +1047,22 @@ def run_inference(args: argparse.Namespace, gpu_num: int, gpu_no: int) -> None:
|
||||
os.makedirs(log_dir, exist_ok=True)
|
||||
writer = SummaryWriter(log_dir=log_dir)
|
||||
analysis_logger = None
|
||||
backbone_profiler = None
|
||||
head_schedule = args.head_schedule_steps if args.head_schedule_steps else None
|
||||
head_log_steps = args.head_log_steps if args.head_log_steps else None
|
||||
head_skip_mode = args.head_skip_mode
|
||||
if args.analysis_log_metrics:
|
||||
analysis_logger = InteractionAnalysisLogger(
|
||||
output_dir=inference_dir,
|
||||
psnr_lookup=load_psnr_lookup(args.analysis_psnr_path),
|
||||
)
|
||||
if args.analysis_profile_backbone_blocks:
|
||||
if head_schedule is not None:
|
||||
raise ValueError(
|
||||
"Backbone block profiling expects dense DDIM runs. "
|
||||
"Do not pass --head_schedule_steps.")
|
||||
backbone_profiler = BackboneBlockProfiler(output_dir=inference_dir)
|
||||
case_id = get_case_id(args.prompt_dir)
|
||||
|
||||
# Load prompt
|
||||
csv_path = os.path.join(args.prompt_dir, f"{args.dataset}.csv")
|
||||
@@ -1021,10 +1196,45 @@ def run_inference(args: argparse.Namespace, gpu_num: int, gpu_no: int) -> None:
|
||||
|
||||
# Use world-model in policy to generate action
|
||||
print(f'>>> Step {itr}: generating actions ...')
|
||||
policy_noise_bundle = make_sampling_noise_bundle(
|
||||
model, noise_shape) if args.analysis_log_metrics else None
|
||||
policy_noise_bundle = (
|
||||
make_sampling_noise_bundle(model, noise_shape)
|
||||
if (args.analysis_log_metrics
|
||||
or backbone_profiler is not None) else None)
|
||||
policy_reference_debug = None
|
||||
if args.analysis_log_metrics and args.analysis_reference_steps != args.ddim_steps:
|
||||
policy_sampling_seed = int(args.seed + itr * 1000 + 11)
|
||||
if backbone_profiler is not None:
|
||||
reset_sampling_seed(policy_sampling_seed)
|
||||
backbone_profiler.start_reference_pass(
|
||||
sample_id=sample_id,
|
||||
case_id=case_id,
|
||||
scene=scene,
|
||||
pass_type='policy',
|
||||
round_id=itr,
|
||||
)
|
||||
image_guided_synthesis_sim_mode(
|
||||
model,
|
||||
sample['instruction'],
|
||||
observation,
|
||||
noise_shape,
|
||||
action_cond_step=args.exe_steps,
|
||||
ddim_steps=args.ddim_steps,
|
||||
ddim_eta=args.ddim_eta,
|
||||
unconditional_guidance_scale=args.
|
||||
unconditional_guidance_scale,
|
||||
fs=model_input_fs,
|
||||
timestep_spacing=args.timestep_spacing,
|
||||
guidance_rescale=args.guidance_rescale,
|
||||
sim_mode=False,
|
||||
init_noise_bundle=policy_noise_bundle,
|
||||
decode_video=False,
|
||||
return_debug_info=False,
|
||||
backbone_block_profiler=backbone_profiler,
|
||||
)
|
||||
backbone_profiler.finish_pass()
|
||||
need_policy_reference = args.analysis_log_metrics and (
|
||||
args.analysis_reference_steps != args.ddim_steps
|
||||
or head_schedule is not None)
|
||||
if need_policy_reference:
|
||||
_, _, _, policy_reference_debug = image_guided_synthesis_sim_mode(
|
||||
model,
|
||||
sample['instruction'],
|
||||
@@ -1041,8 +1251,18 @@ def run_inference(args: argparse.Namespace, gpu_num: int, gpu_no: int) -> None:
|
||||
sim_mode=False,
|
||||
init_noise_bundle=policy_noise_bundle,
|
||||
decode_video=False,
|
||||
return_debug_info=True)
|
||||
return_debug_info=True,
|
||||
head_log_steps=head_log_steps)
|
||||
policy_pass_start = time.time()
|
||||
if backbone_profiler is not None:
|
||||
reset_sampling_seed(policy_sampling_seed)
|
||||
backbone_profiler.start_target_pass(
|
||||
sample_id=sample_id,
|
||||
case_id=case_id,
|
||||
scene=scene,
|
||||
pass_type='policy',
|
||||
round_id=itr,
|
||||
)
|
||||
pred_videos_0, pred_actions, _, policy_debug = image_guided_synthesis_sim_mode(
|
||||
model,
|
||||
sample['instruction'],
|
||||
@@ -1058,7 +1278,13 @@ def run_inference(args: argparse.Namespace, gpu_num: int, gpu_no: int) -> None:
|
||||
guidance_rescale=args.guidance_rescale,
|
||||
sim_mode=False,
|
||||
init_noise_bundle=policy_noise_bundle,
|
||||
return_debug_info=args.analysis_log_metrics)
|
||||
return_debug_info=args.analysis_log_metrics,
|
||||
head_schedule=head_schedule,
|
||||
head_log_steps=head_log_steps,
|
||||
head_skip_mode=head_skip_mode,
|
||||
backbone_block_profiler=backbone_profiler)
|
||||
if backbone_profiler is not None:
|
||||
backbone_profiler.finish_pass()
|
||||
policy_pass_total_time_s = time.time() - policy_pass_start
|
||||
policy_summary_row = None
|
||||
if analysis_logger is not None:
|
||||
@@ -1101,10 +1327,45 @@ def run_inference(args: argparse.Namespace, gpu_num: int, gpu_no: int) -> None:
|
||||
|
||||
# Interaction with the world-model
|
||||
print(f'>>> Step {itr}: interacting with world model ...')
|
||||
world_noise_bundle = make_sampling_noise_bundle(
|
||||
model, noise_shape) if args.analysis_log_metrics else None
|
||||
world_noise_bundle = (
|
||||
make_sampling_noise_bundle(model, noise_shape)
|
||||
if (args.analysis_log_metrics
|
||||
or backbone_profiler is not None) else None)
|
||||
world_reference_debug = None
|
||||
if args.analysis_log_metrics and args.analysis_reference_steps != args.ddim_steps:
|
||||
world_sampling_seed = int(args.seed + itr * 1000 + 29)
|
||||
if backbone_profiler is not None:
|
||||
reset_sampling_seed(world_sampling_seed)
|
||||
backbone_profiler.start_reference_pass(
|
||||
sample_id=sample_id,
|
||||
case_id=case_id,
|
||||
scene=scene,
|
||||
pass_type='world_model',
|
||||
round_id=itr,
|
||||
)
|
||||
image_guided_synthesis_sim_mode(
|
||||
model,
|
||||
"",
|
||||
observation,
|
||||
noise_shape,
|
||||
action_cond_step=args.exe_steps,
|
||||
ddim_steps=args.ddim_steps,
|
||||
ddim_eta=args.ddim_eta,
|
||||
unconditional_guidance_scale=args.
|
||||
unconditional_guidance_scale,
|
||||
fs=model_input_fs,
|
||||
text_input=False,
|
||||
timestep_spacing=args.timestep_spacing,
|
||||
guidance_rescale=args.guidance_rescale,
|
||||
init_noise_bundle=world_noise_bundle,
|
||||
decode_video=False,
|
||||
return_debug_info=False,
|
||||
backbone_block_profiler=backbone_profiler,
|
||||
)
|
||||
backbone_profiler.finish_pass()
|
||||
need_world_reference = args.analysis_log_metrics and (
|
||||
args.analysis_reference_steps != args.ddim_steps
|
||||
or head_schedule is not None)
|
||||
if need_world_reference:
|
||||
_, _, _, world_reference_debug = image_guided_synthesis_sim_mode(
|
||||
model,
|
||||
"",
|
||||
@@ -1121,8 +1382,18 @@ def run_inference(args: argparse.Namespace, gpu_num: int, gpu_no: int) -> None:
|
||||
guidance_rescale=args.guidance_rescale,
|
||||
init_noise_bundle=world_noise_bundle,
|
||||
decode_video=False,
|
||||
return_debug_info=True)
|
||||
return_debug_info=True,
|
||||
head_log_steps=head_log_steps)
|
||||
world_pass_start = time.time()
|
||||
if backbone_profiler is not None:
|
||||
reset_sampling_seed(world_sampling_seed)
|
||||
backbone_profiler.start_target_pass(
|
||||
sample_id=sample_id,
|
||||
case_id=case_id,
|
||||
scene=scene,
|
||||
pass_type='world_model',
|
||||
round_id=itr,
|
||||
)
|
||||
pred_videos_1, _, pred_states, world_debug = image_guided_synthesis_sim_mode(
|
||||
model,
|
||||
"",
|
||||
@@ -1138,7 +1409,13 @@ def run_inference(args: argparse.Namespace, gpu_num: int, gpu_no: int) -> None:
|
||||
timestep_spacing=args.timestep_spacing,
|
||||
guidance_rescale=args.guidance_rescale,
|
||||
init_noise_bundle=world_noise_bundle,
|
||||
return_debug_info=args.analysis_log_metrics)
|
||||
return_debug_info=args.analysis_log_metrics,
|
||||
head_schedule=head_schedule,
|
||||
head_log_steps=head_log_steps,
|
||||
head_skip_mode=head_skip_mode,
|
||||
backbone_block_profiler=backbone_profiler)
|
||||
if backbone_profiler is not None:
|
||||
backbone_profiler.finish_pass()
|
||||
world_pass_total_time_s = time.time() - world_pass_start
|
||||
world_summary_row = None
|
||||
if analysis_logger is not None:
|
||||
@@ -1224,6 +1501,8 @@ def run_inference(args: argparse.Namespace, gpu_num: int, gpu_no: int) -> None:
|
||||
|
||||
if analysis_logger is not None:
|
||||
analysis_logger.flush()
|
||||
if backbone_profiler is not None:
|
||||
backbone_profiler.flush()
|
||||
writer.close()
|
||||
|
||||
|
||||
@@ -1357,6 +1636,29 @@ def get_parser():
|
||||
type=str,
|
||||
default=None,
|
||||
help="Optional CSV/JSON file with psnr_full50 values keyed by sample_id or videoid.")
|
||||
parser.add_argument(
|
||||
"--analysis_profile_backbone_blocks",
|
||||
action='store_true',
|
||||
default=False,
|
||||
help="Run dense two-pass backbone block profiling and export backbone_block_log.csv.")
|
||||
parser.add_argument(
|
||||
"--head_schedule_steps",
|
||||
type=int,
|
||||
nargs='*',
|
||||
default=None,
|
||||
help="Zero-based DDIM loop indices where action/state heads execute. Omit for dense execution.")
|
||||
parser.add_argument(
|
||||
"--head_log_steps",
|
||||
type=int,
|
||||
nargs='*',
|
||||
default=None,
|
||||
help="Zero-based DDIM loop indices to snapshot sparse action/state/latent outputs for dense-vs-sparse comparison.")
|
||||
parser.add_argument(
|
||||
"--head_skip_mode",
|
||||
type=str,
|
||||
default="reuse_prediction",
|
||||
choices=["reuse_prediction", "freeze_state"],
|
||||
help="Behavior on non-checkpoint steps: reuse cached head predictions while still running scheduler.step, or freeze action/state entirely.")
|
||||
return parser
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user