保存结果一次

This commit is contained in:
qhy
2026-03-18 20:52:13 +08:00
parent 8ca159d375
commit 9d2d57d96b
15 changed files with 2312 additions and 15 deletions

View File

@@ -239,6 +239,8 @@ class InteractionAnalysisLogger:
'round_id',
'step',
'step_time_s',
'backbone_reused_blocks_count',
'backbone_reuse_hit_blocks',
'latent_delta',
'action_delta',
'state_delta',
@@ -388,6 +390,10 @@ class InteractionAnalysisLogger:
'round_id': round_id,
'step': record['step_index'],
'step_time_s': float(record['step_time_s']),
'backbone_reused_blocks_count': int(
record.get('backbone_reused_blocks_count', 0)),
'backbone_reuse_hit_blocks':
record.get('backbone_reuse_hit_blocks', ''),
'latent_delta': latent_delta,
'action_delta': action_delta,
'state_delta': state_delta,
@@ -928,7 +934,7 @@ def image_guided_synthesis_sim_mode(
return_debug_info (bool): Whether to return per-step traces for analysis logging.
**kwargs: Additional arguments passed to the DDIM sampler, including
sparse head controls such as `head_schedule`, `head_log_steps`,
and `head_skip_mode`.
and `head_skip_mode`, plus optional decoder block reuse settings.
Returns:
batch_variants (torch.Tensor | None): Predicted pixel-space video frames [B, C, T, H, W].
@@ -1051,16 +1057,27 @@ def run_inference(args: argparse.Namespace, gpu_num: int, gpu_no: int) -> None:
head_schedule = args.head_schedule_steps if args.head_schedule_steps else None
head_log_steps = args.head_log_steps if args.head_log_steps else None
head_skip_mode = args.head_skip_mode
backbone_reuse_blocks = (args.backbone_reuse_blocks
if args.backbone_reuse_blocks else None)
backbone_reuse_schedule_steps = (
args.backbone_reuse_schedule_steps
if args.backbone_reuse_schedule_steps else None)
backbone_reuse_force_compute_steps = (
args.backbone_reuse_force_compute_steps
if args.backbone_reuse_force_compute_steps else None)
backbone_reuse_enabled = (
args.backbone_reuse_mode == "reuse_output"
and backbone_reuse_blocks is not None)
if args.analysis_log_metrics:
analysis_logger = InteractionAnalysisLogger(
output_dir=inference_dir,
psnr_lookup=load_psnr_lookup(args.analysis_psnr_path),
)
if args.analysis_profile_backbone_blocks:
if head_schedule is not None:
if head_schedule is not None or backbone_reuse_enabled:
raise ValueError(
"Backbone block profiling expects dense DDIM runs. "
"Do not pass --head_schedule_steps.")
"Do not pass sparse head or backbone reuse flags.")
backbone_profiler = BackboneBlockProfiler(output_dir=inference_dir)
case_id = get_case_id(args.prompt_dir)
@@ -1233,7 +1250,7 @@ def run_inference(args: argparse.Namespace, gpu_num: int, gpu_no: int) -> None:
backbone_profiler.finish_pass()
need_policy_reference = args.analysis_log_metrics and (
args.analysis_reference_steps != args.ddim_steps
or head_schedule is not None)
or head_schedule is not None or backbone_reuse_enabled)
if need_policy_reference:
_, _, _, policy_reference_debug = image_guided_synthesis_sim_mode(
model,
@@ -1282,6 +1299,13 @@ def run_inference(args: argparse.Namespace, gpu_num: int, gpu_no: int) -> None:
head_schedule=head_schedule,
head_log_steps=head_log_steps,
head_skip_mode=head_skip_mode,
backbone_reuse_blocks=backbone_reuse_blocks,
backbone_reuse_start_step=args.backbone_reuse_start_step,
backbone_reuse_schedule_steps=
backbone_reuse_schedule_steps,
backbone_reuse_force_compute_steps=
backbone_reuse_force_compute_steps,
backbone_reuse_mode=args.backbone_reuse_mode,
backbone_block_profiler=backbone_profiler)
if backbone_profiler is not None:
backbone_profiler.finish_pass()
@@ -1364,7 +1388,7 @@ def run_inference(args: argparse.Namespace, gpu_num: int, gpu_no: int) -> None:
backbone_profiler.finish_pass()
need_world_reference = args.analysis_log_metrics and (
args.analysis_reference_steps != args.ddim_steps
or head_schedule is not None)
or head_schedule is not None or backbone_reuse_enabled)
if need_world_reference:
_, _, _, world_reference_debug = image_guided_synthesis_sim_mode(
model,
@@ -1413,6 +1437,13 @@ def run_inference(args: argparse.Namespace, gpu_num: int, gpu_no: int) -> None:
head_schedule=head_schedule,
head_log_steps=head_log_steps,
head_skip_mode=head_skip_mode,
backbone_reuse_blocks=backbone_reuse_blocks,
backbone_reuse_start_step=args.backbone_reuse_start_step,
backbone_reuse_schedule_steps=
backbone_reuse_schedule_steps,
backbone_reuse_force_compute_steps=
backbone_reuse_force_compute_steps,
backbone_reuse_mode=args.backbone_reuse_mode,
backbone_block_profiler=backbone_profiler)
if backbone_profiler is not None:
backbone_profiler.finish_pass()
@@ -1659,6 +1690,35 @@ def get_parser():
default="reuse_prediction",
choices=["reuse_prediction", "freeze_state"],
help="Behavior on non-checkpoint steps: reuse cached head predictions while still running scheduler.step, or freeze action/state entirely.")
parser.add_argument(
"--backbone_reuse_blocks",
type=str,
nargs='*',
default=None,
help="Decoder block names to reuse on non-checkpoint DDIM steps, e.g. output_5 output_4 output_6.")
parser.add_argument(
"--backbone_reuse_start_step",
type=int,
default=None,
help="1-based DDIM step index after which decoder block reuse becomes eligible.")
parser.add_argument(
"--backbone_reuse_schedule_steps",
type=int,
nargs='*',
default=None,
help="1-based DDIM step indices where selected decoder blocks must be recomputed.")
parser.add_argument(
"--backbone_reuse_force_compute_steps",
type=int,
nargs='*',
default=None,
help="1-based DDIM step indices that always recompute selected decoder blocks, even if omitted from the reuse schedule.")
parser.add_argument(
"--backbone_reuse_mode",
type=str,
default="disabled",
choices=["disabled", "reuse_output"],
help="Decoder block reuse mode. 'reuse_output' reuses cached output block tensors on non-checkpoint steps.")
return parser