延迟 decode,只解码 CLIP 需要的 1 帧

- world model 调用 decode_video=False,跳过 16 帧全量 decode
- 只 decode 最后 1 帧给 CLIP embedding / observation queue
- 存 raw latent,循环结束后统一 batch decode 生成最终视频
- 每轮省 15 次 VAE decode,8 轮共省 120 次
- 跳过中间迭代的 wm tensorboard/mp4 保存
PSNR 微弱下降(27.19 → 26.68)
This commit is contained in:
qhy
2026-02-11 17:07:33 +08:00
parent 3101252c25
commit 508b91f5a2
3 changed files with 57 additions and 68 deletions

View File

@@ -494,6 +494,7 @@ def image_guided_synthesis_sim_mode(
cond_mask = None cond_mask = None
cond_z0 = None cond_z0 = None
batch_variants = None batch_variants = None
samples = None
if ddim_sampler is not None: if ddim_sampler is not None:
samples, actions, states, intermedia = ddim_sampler.sample( samples, actions, states, intermedia = ddim_sampler.sample(
S=ddim_steps, S=ddim_steps,
@@ -517,7 +518,7 @@ def image_guided_synthesis_sim_mode(
batch_images = model.decode_first_stage(samples) batch_images = model.decode_first_stage(samples)
batch_variants = batch_images batch_variants = batch_images
return batch_variants, actions, states return batch_variants, actions, states, samples
def run_inference(args: argparse.Namespace, gpu_num: int, gpu_no: int) -> None: def run_inference(args: argparse.Namespace, gpu_num: int, gpu_no: int) -> None:
@@ -648,7 +649,7 @@ def run_inference(args: argparse.Namespace, gpu_num: int, gpu_no: int) -> None:
sample_save_dir = f'{video_save_dir}/wm/{fs}' sample_save_dir = f'{video_save_dir}/wm/{fs}'
os.makedirs(sample_save_dir, exist_ok=True) os.makedirs(sample_save_dir, exist_ok=True)
# For collecting interaction videos # For collecting interaction videos
wm_video = [] wm_latent = []
# Initialize observation queues # Initialize observation queues
cond_obs_queues = { cond_obs_queues = {
"observation.images.top": "observation.images.top":
@@ -704,7 +705,7 @@ def run_inference(args: argparse.Namespace, gpu_num: int, gpu_no: int) -> None:
# Use world-model in policy to generate action # Use world-model in policy to generate action
print(f'>>> Step {itr}: generating actions ...') print(f'>>> Step {itr}: generating actions ...')
pred_videos_0, pred_actions, _ = image_guided_synthesis_sim_mode( pred_videos_0, pred_actions, _, _ = image_guided_synthesis_sim_mode(
model, model,
sample['instruction'], sample['instruction'],
observation, observation,
@@ -746,7 +747,7 @@ def run_inference(args: argparse.Namespace, gpu_num: int, gpu_no: int) -> None:
# Interaction with the world-model # Interaction with the world-model
print(f'>>> Step {itr}: interacting with world model ...') print(f'>>> Step {itr}: interacting with world model ...')
pred_videos_1, _, pred_states = image_guided_synthesis_sim_mode( pred_videos_1, _, pred_states, wm_samples = image_guided_synthesis_sim_mode(
model, model,
"", "",
observation, observation,
@@ -759,12 +760,16 @@ def run_inference(args: argparse.Namespace, gpu_num: int, gpu_no: int) -> None:
fs=model_input_fs, fs=model_input_fs,
text_input=False, text_input=False,
timestep_spacing=args.timestep_spacing, timestep_spacing=args.timestep_spacing,
guidance_rescale=args.guidance_rescale) guidance_rescale=args.guidance_rescale,
decode_video=False)
# Decode only the last frame for CLIP embedding in next iteration
last_frame_pixel = model.decode_first_stage(wm_samples[:, :, -1:, :, :])
for idx in range(args.exe_steps): for idx in range(args.exe_steps):
observation = { observation = {
'observation.images.top': 'observation.images.top':
pred_videos_1[0][:, idx:idx + 1].permute(1, 0, 2, 3), last_frame_pixel[0, :, 0:1].permute(1, 0, 2, 3),
'observation.state': 'observation.state':
torch.zeros_like(pred_states[0][idx:idx + 1]) if torch.zeros_like(pred_states[0][idx:idx + 1]) if
args.zero_pred_state else pred_states[0][idx:idx + 1], args.zero_pred_state else pred_states[0][idx:idx + 1],
@@ -782,30 +787,14 @@ def run_inference(args: argparse.Namespace, gpu_num: int, gpu_no: int) -> None:
pred_videos_0, pred_videos_0,
sample_tag, sample_tag,
fps=args.save_fps) fps=args.save_fps)
# Save videos environment changes via world-model interaction
sample_tag = f"{args.dataset}-vid{sample['videoid']}-wd-fs-{fs}/itr-{itr}"
log_to_tensorboard_async(writer,
pred_videos_1,
sample_tag,
fps=args.save_fps)
# Save the imagen videos for decision-making
if pred_videos_0 is not None:
sample_video_file = f'{video_save_dir}/dm/{fs}/itr-{itr}.mp4'
save_results_async(pred_videos_0,
sample_video_file,
fps=args.save_fps)
# Save videos environment changes via world-model interaction
sample_video_file = f'{video_save_dir}/wm/{fs}/itr-{itr}.mp4'
save_results_async(pred_videos_1,
sample_video_file,
fps=args.save_fps)
print('>' * 24) print('>' * 24)
# Collect the result of world-model interactions # Store raw latent for deferred decode
wm_video.append(pred_videos_1[:, :, :args.exe_steps].cpu()) wm_latent.append(wm_samples[:, :, :args.exe_steps].cpu())
full_video = torch.cat(wm_video, dim=2) # Deferred decode: batch decode all stored latents
full_latent = torch.cat(wm_latent, dim=2).to(device)
full_video = model.decode_first_stage(full_latent).cpu()
sample_tag = f"{args.dataset}-vid{sample['videoid']}-wd-fs-{fs}/full" sample_tag = f"{args.dataset}-vid{sample['videoid']}-wd-fs-{fs}/full"
log_to_tensorboard_async(writer, log_to_tensorboard_async(writer,
full_video, full_video,

View File

@@ -1,10 +1,10 @@
2026-02-11 16:32:03.555597: I tensorflow/core/util/port.cc:113] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`. 2026-02-11 16:58:21.710140: I tensorflow/core/util/port.cc:113] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2026-02-11 16:32:03.605506: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered 2026-02-11 16:58:21.759418: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2026-02-11 16:32:03.605550: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered 2026-02-11 16:58:21.759461: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2026-02-11 16:32:03.606879: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered 2026-02-11 16:58:21.760752: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2026-02-11 16:32:03.614434: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations. 2026-02-11 16:58:21.768205: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 AVX512F AVX512_VNNI FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags. To enable the following instructions: AVX2 AVX512F AVX512_VNNI FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
2026-02-11 16:32:04.545234: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT 2026-02-11 16:58:22.691154: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT
Global seed set to 123 Global seed set to 123
>>> Loading prepared model from ckpts/unifolm_wma_dual.ckpt.prepared.pt ... >>> Loading prepared model from ckpts/unifolm_wma_dual.ckpt.prepared.pt ...
>>> Prepared model loaded. >>> Prepared model loaded.
@@ -34,10 +34,40 @@ DEBUG:h5py._conv:Creating converter from 3 to 5
DEBUG:PIL.PngImagePlugin:STREAM b'IHDR' 16 13 DEBUG:PIL.PngImagePlugin:STREAM b'IHDR' 16 13
DEBUG:PIL.PngImagePlugin:STREAM b'pHYs' 41 9 DEBUG:PIL.PngImagePlugin:STREAM b'pHYs' 41 9
DEBUG:PIL.PngImagePlugin:STREAM b'IDAT' 62 4096 DEBUG:PIL.PngImagePlugin:STREAM b'IDAT' 62 4096
0%| | 0/11 [00:00<?, ?it/s]
9%|▉ | 1/11 [00:23<03:52, 23.26s/it] 9%|▉ | 1/11 [00:23<03:52, 23.26s/it]
18%|█▊ | 2/11 [00:45<03:25, 22.85s/it] 18%|█▊ | 2/11 [00:45<03:25, 22.85s/it]
27%|██▋ | 3/11 [01:08<03:02, 22.82s/it] 27%|██▋ | 3/11 [01:08<03:02, 22.82s/it]
36%|███▋ | 4/11 [01:31<02:39, 22.83s/it]
45%|████▌ | 5/11 [01:54<02:17, 22.83s/it]
55%|█████▍ | 6/11 [02:17<01:54, 22.83s/it]
64%|██████▎ | 7/11 [02:39<01:31, 22.83s/it]
73%|███████▎ | 8/11 [03:02<01:08, 22.83s/it]
82%|████████▏ | 9/11 [03:25<00:45, 22.81s/it]
91%|█████████ | 10/11 [03:48<00:22, 22.81s/it]
100%|██████████| 11/11 [04:11<00:00, 22.79s/it]
100%|██████████| 11/11 [04:11<00:00, 22.83s/it]
>>> Step 0: generating actions ...
>>> Step 0: interacting with world model ...
>>>>>>>>>>>>>>>>>>>>>>>>
>>> Step 1: generating actions ...
>>> Step 1: interacting with world model ...
>>>>>>>>>>>>>>>>>>>>>>>>
>>> Step 2: generating actions ...
>>> Step 2: interacting with world model ...
>>>>>>>>>>>>>>>>>>>>>>>>
>>> Step 3: generating actions ...
>>> Step 3: interacting with world model ...
>>>>>>>>>>>>>>>>>>>>>>>>
>>> Step 4: generating actions ...
>>> Step 4: interacting with world model ...
>>>>>>>>>>>>>>>>>>>>>>>>
>>> Step 5: generating actions ...
>>> Step 5: interacting with world model ...
>>>>>>>>>>>>>>>>>>>>>>>>
>>> Step 6: generating actions ...
>>> Step 6: interacting with world model ...
>>>>>>>>>>>>>>>>>>>>>>>> >>>>>>>>>>>>>>>>>>>>>>>>
>>> Step 7: generating actions ... >>> Step 7: generating actions ...
>>> Step 7: interacting with world model ... >>> Step 7: interacting with world model ...
@@ -87,37 +117,7 @@ DEBUG:PIL.Image:Importing WmfImagePlugin
DEBUG:PIL.Image:Importing PdfImagePlugin DEBUG:PIL.Image:Importing PdfImagePlugin
DEBUG:PIL.Image:Importing PixarImagePlugin DEBUG:PIL.Image:Importing PixarImagePlugin
DEBUG:PIL.Image:Importing PngImagePlugin DEBUG:PIL.Image:Importing PngImagePlugin
DEBUG:PIL.Image:Importing XpmImagePlugin
DEBUG:PIL.Image:Importing XVThumbImagePlugin
18%|█▊ | 2/11 [00:47<03:31, 23.51s/it]
27%|██▋ | 3/11 [01:10<03:08, 23.56s/it]
36%|███▋ | 4/11 [01:34<02:45, 23.66s/it]
45%|████▌ | 5/11 [01:58<02:22, 23.67s/it]
55%|█████▍ | 6/11 [02:21<01:58, 23.67s/it]
64%|██████▎ | 7/11 [02:45<01:34, 23.62s/it]
73%|███████▎ | 8/11 [03:08<01:10, 23.61s/it]
82%|████████▏ | 9/11 [03:32<00:47, 23.59s/it]
91%|█████████ | 10/11 [03:56<00:23, 23.60s/it]
100%|██████████| 11/11 [04:19<00:00, 23.59s/it]
100%|██████████| 11/11 [04:19<00:00, 23.61s/it]
>>> Step 1: interacting with world model ...
>>>>>>>>>>>>>>>>>>>>>>>>
>>> Step 2: generating actions ...
>>> Step 2: interacting with world model ...
>>>>>>>>>>>>>>>>>>>>>>>>
>>> Step 3: generating actions ...
>>> Step 3: interacting with world model ...
>>>>>>>>>>>>>>>>>>>>>>>>
>>> Step 4: generating actions ...
>>> Step 4: interacting with world model ...
>>>>>>>>>>>>>>>>>>>>>>>>
>>> Step 5: generating actions ...
>>> Step 5: interacting with world model ...
>>>>>>>>>>>>>>>>>>>>>>>>
>>> Step 6: generating actions ...
>>> Step 6: interacting with world model ...
DEBUG:PIL.Image:Importing PpmImagePlugin DEBUG:PIL.Image:Importing PpmImagePlugin
>>> Step 7: generating actions ... DEBUG:PIL.Image:Importing PsdImagePlugin
>>> Step 7: interacting with world model ... DEBUG:PIL.Image:Importing QoiImagePlugin
>>>>>>>>>>>>>>>>>>>>>>>> DEBUG:PIL.Image:Importing SgiImagePlugin

View File

@@ -1,5 +1,5 @@
{ {
"gt_video": "unitree_z1_dual_arm_stackbox_v2/case1/unitree_z1_dual_arm_stackbox_v2_case1.mp4", "gt_video": "unitree_z1_dual_arm_stackbox_v2/case1/unitree_z1_dual_arm_stackbox_v2_case1.mp4",
"pred_video": "unitree_z1_dual_arm_stackbox_v2/case1/output/inference/5_full_fs4.mp4", "pred_video": "unitree_z1_dual_arm_stackbox_v2/case1/output/inference/5_full_fs4.mp4",
"psnr": 27.185465604200047 "psnr": 26.683000215343522
} }