减少了一路视频vae解码

This commit is contained in:
qhy
2026-02-10 17:13:45 +08:00
parent 91a9b0febc
commit 2a6068f9e4
3 changed files with 41 additions and 26 deletions

View File

@@ -330,7 +330,8 @@ def image_guided_synthesis_sim_mode(
timestep_spacing: str = 'uniform',
guidance_rescale: float = 0.0,
sim_mode: bool = True,
**kwargs) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
decode_video: bool = True,
**kwargs) -> tuple[torch.Tensor | None, torch.Tensor, torch.Tensor]:
"""
Performs image-guided video generation in a simulation-style mode with optional multimodal guidance (image, state, action, text).
@@ -353,10 +354,13 @@ def image_guided_synthesis_sim_mode(
timestep_spacing (str): Timestep sampling method in DDIM sampler. Typically "uniform" or "linspace".
guidance_rescale (float): Guidance rescaling factor to mitigate overexposure from classifier-free guidance.
sim_mode (bool): Whether to perform world-model interaction or decision-making using the world-model.
decode_video (bool): Whether to decode latent samples to pixel-space video.
Set to False to skip VAE decode for speed when only actions/states are needed.
**kwargs: Additional arguments passed to the DDIM sampler.
Returns:
batch_variants (torch.Tensor): Predicted pixel-space video frames [B, C, T, H, W].
batch_variants (torch.Tensor | None): Predicted pixel-space video frames [B, C, T, H, W],
or None when decode_video=False.
actions (torch.Tensor): Predicted action sequences [B, T, D] from diffusion decoding.
states (torch.Tensor): Predicted state sequences [B, T, D] from diffusion decoding.
"""
@@ -409,6 +413,7 @@ def image_guided_synthesis_sim_mode(
kwargs.update({"unconditional_conditioning_img_nonetext": None})
cond_mask = None
cond_z0 = None
batch_variants = None
if ddim_sampler is not None:
samples, actions, states, intermedia = ddim_sampler.sample(
S=ddim_steps,
@@ -427,9 +432,10 @@ def image_guided_synthesis_sim_mode(
guidance_rescale=guidance_rescale,
**kwargs)
# Reconstruct from latent to pixel space
batch_images = model.decode_first_stage(samples)
batch_variants = batch_images
if decode_video:
# Reconstruct from latent to pixel space
batch_images = model.decode_first_stage(samples)
batch_variants = batch_images
return batch_variants, actions, states
@@ -590,7 +596,8 @@ def run_inference(args: argparse.Namespace, gpu_num: int, gpu_no: int) -> None:
fs=model_input_fs,
timestep_spacing=args.timestep_spacing,
guidance_rescale=args.guidance_rescale,
sim_mode=False)
sim_mode=False,
decode_video=not args.fast_policy_no_decode)
# Update future actions in the observation queues
for idx in range(len(pred_actions[0])):
@@ -648,11 +655,12 @@ def run_inference(args: argparse.Namespace, gpu_num: int, gpu_no: int) -> None:
observation)
# Save the imagen videos for decision-making
sample_tag = f"{args.dataset}-vid{sample['videoid']}-dm-fs-{fs}/itr-{itr}"
log_to_tensorboard(writer,
pred_videos_0,
sample_tag,
fps=args.save_fps)
if pred_videos_0 is not None:
sample_tag = f"{args.dataset}-vid{sample['videoid']}-dm-fs-{fs}/itr-{itr}"
log_to_tensorboard(writer,
pred_videos_0,
sample_tag,
fps=args.save_fps)
# Save videos environment changes via world-model interaction
sample_tag = f"{args.dataset}-vid{sample['videoid']}-wd-fs-{fs}/itr-{itr}"
log_to_tensorboard(writer,
@@ -661,10 +669,11 @@ def run_inference(args: argparse.Namespace, gpu_num: int, gpu_no: int) -> None:
fps=args.save_fps)
# Save the imagen videos for decision-making
sample_video_file = f'{video_save_dir}/dm/{fs}/itr-{itr}.mp4'
save_results(pred_videos_0.cpu(),
sample_video_file,
fps=args.save_fps)
if pred_videos_0 is not None:
sample_video_file = f'{video_save_dir}/dm/{fs}/itr-{itr}.mp4'
save_results(pred_videos_0.cpu(),
sample_video_file,
fps=args.save_fps)
# Save videos environment changes via world-model interaction
sample_video_file = f'{video_save_dir}/wm/{fs}/itr-{itr}.mp4'
save_results(pred_videos_1.cpu(),
@@ -797,6 +806,11 @@ def get_parser():
action='store_true',
default=False,
help="not using the predicted states as comparison")
parser.add_argument(
"--fast_policy_no_decode",
action='store_true',
default=False,
help="Speed mode: policy pass only predicts actions, skip policy video decode/log/save.")
parser.add_argument("--save_fps",
type=int,
default=8,