减少了一路视频vae解码

This commit is contained in:
2026-02-09 16:48:16 +00:00
parent a2cd34dd51
commit 4288c9d8c9
4 changed files with 44 additions and 29 deletions

View File

@@ -444,7 +444,8 @@ def image_guided_synthesis_sim_mode(
timestep_spacing: str = 'uniform',
guidance_rescale: float = 0.0,
sim_mode: bool = True,
**kwargs) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
decode_video: bool = True,
**kwargs) -> tuple[torch.Tensor | None, torch.Tensor, torch.Tensor]:
"""
Performs image-guided video generation in a simulation-style mode with optional multimodal guidance (image, state, action, text).
@@ -467,10 +468,13 @@ def image_guided_synthesis_sim_mode(
timestep_spacing (str): Timestep sampling method in DDIM sampler. Typically "uniform" or "linspace".
guidance_rescale (float): Guidance rescaling factor to mitigate overexposure from classifier-free guidance.
sim_mode (bool): Whether to perform world-model interaction or decision-making using the world-model.
decode_video (bool): Whether to decode latent samples to pixel-space video.
Set to False to skip VAE decode for speed when only actions/states are needed.
**kwargs: Additional arguments passed to the DDIM sampler.
Returns:
batch_variants (torch.Tensor): Predicted pixel-space video frames [B, C, T, H, W].
batch_variants (torch.Tensor | None): Predicted pixel-space video frames [B, C, T, H, W],
or None when decode_video=False.
actions (torch.Tensor): Predicted action sequences [B, T, D] from diffusion decoding.
states (torch.Tensor): Predicted state sequences [B, T, D] from diffusion decoding.
"""
@@ -554,6 +558,7 @@ def image_guided_synthesis_sim_mode(
else:
autocast_ctx = nullcontext()
batch_variants = None
if ddim_sampler is not None:
with autocast_ctx:
samples, actions, states, intermedia = ddim_sampler.sample(
@@ -573,9 +578,10 @@ def image_guided_synthesis_sim_mode(
guidance_rescale=guidance_rescale,
**kwargs)
# Reconstruct from latent to pixel space
batch_images = model.decode_first_stage(samples)
batch_variants = batch_images
if decode_video:
# Reconstruct from latent to pixel space
batch_images = model.decode_first_stage(samples)
batch_variants = batch_images
return batch_variants, actions, states
@@ -750,7 +756,8 @@ def run_inference(args: argparse.Namespace, gpu_num: int, gpu_no: int) -> None:
fs=model_input_fs,
timestep_spacing=args.timestep_spacing,
guidance_rescale=args.guidance_rescale,
sim_mode=False)
sim_mode=False,
decode_video=not args.fast_policy_no_decode)
# Update future actions in the observation queues
for idx in range(len(pred_actions[0])):
@@ -808,11 +815,12 @@ def run_inference(args: argparse.Namespace, gpu_num: int, gpu_no: int) -> None:
observation)
# Save the imagen videos for decision-making
sample_tag = f"{args.dataset}-vid{sample['videoid']}-dm-fs-{fs}/itr-{itr}"
log_to_tensorboard(writer,
pred_videos_0,
sample_tag,
fps=args.save_fps)
if pred_videos_0 is not None:
sample_tag = f"{args.dataset}-vid{sample['videoid']}-dm-fs-{fs}/itr-{itr}"
log_to_tensorboard(writer,
pred_videos_0,
sample_tag,
fps=args.save_fps)
# Save videos environment changes via world-model interaction
sample_tag = f"{args.dataset}-vid{sample['videoid']}-wd-fs-{fs}/itr-{itr}"
log_to_tensorboard(writer,
@@ -821,10 +829,11 @@ def run_inference(args: argparse.Namespace, gpu_num: int, gpu_no: int) -> None:
fps=args.save_fps)
# Save the imagen videos for decision-making
sample_video_file = f'{video_save_dir}/dm/{fs}/itr-{itr}.mp4'
save_results(pred_videos_0.cpu(),
sample_video_file,
fps=args.save_fps)
if pred_videos_0 is not None:
sample_video_file = f'{video_save_dir}/dm/{fs}/itr-{itr}.mp4'
save_results(pred_videos_0.cpu(),
sample_video_file,
fps=args.save_fps)
# Save videos environment changes via world-model interaction
sample_video_file = f'{video_save_dir}/wm/{fs}/itr-{itr}.mp4'
save_results(pred_videos_1.cpu(),
@@ -957,6 +966,11 @@ def get_parser():
action='store_true',
default=False,
help="not using the predicted states as comparison")
parser.add_argument(
"--fast_policy_no_decode",
action='store_true',
default=False,
help="Speed mode: policy pass only predicts actions, skip policy video decode/log/save.")
parser.add_argument("--save_fps",
type=int,
default=8,