延迟 decode,只解码 CLIP 需要的 1 帧
- world model 调用 decode_video=False,跳过 16 帧全量 decode - 只 decode 最后 1 帧给 CLIP embedding / observation queue - 存 raw latent,循环结束后统一 batch decode 生成最终视频 - 每轮省 15 次 VAE decode,8 轮共省 120 次 - 跳过中间迭代的 wm tensorboard/mp4 保存 psnr微弱下降
This commit is contained in:
@@ -494,6 +494,7 @@ def image_guided_synthesis_sim_mode(
|
||||
cond_mask = None
|
||||
cond_z0 = None
|
||||
batch_variants = None
|
||||
samples = None
|
||||
if ddim_sampler is not None:
|
||||
samples, actions, states, intermedia = ddim_sampler.sample(
|
||||
S=ddim_steps,
|
||||
@@ -517,7 +518,7 @@ def image_guided_synthesis_sim_mode(
|
||||
batch_images = model.decode_first_stage(samples)
|
||||
batch_variants = batch_images
|
||||
|
||||
return batch_variants, actions, states
|
||||
return batch_variants, actions, states, samples
|
||||
|
||||
|
||||
def run_inference(args: argparse.Namespace, gpu_num: int, gpu_no: int) -> None:
|
||||
@@ -648,7 +649,7 @@ def run_inference(args: argparse.Namespace, gpu_num: int, gpu_no: int) -> None:
|
||||
sample_save_dir = f'{video_save_dir}/wm/{fs}'
|
||||
os.makedirs(sample_save_dir, exist_ok=True)
|
||||
# For collecting interaction videos
|
||||
wm_video = []
|
||||
wm_latent = []
|
||||
# Initialize observation queues
|
||||
cond_obs_queues = {
|
||||
"observation.images.top":
|
||||
@@ -704,7 +705,7 @@ def run_inference(args: argparse.Namespace, gpu_num: int, gpu_no: int) -> None:
|
||||
|
||||
# Use world-model in policy to generate action
|
||||
print(f'>>> Step {itr}: generating actions ...')
|
||||
pred_videos_0, pred_actions, _ = image_guided_synthesis_sim_mode(
|
||||
pred_videos_0, pred_actions, _, _ = image_guided_synthesis_sim_mode(
|
||||
model,
|
||||
sample['instruction'],
|
||||
observation,
|
||||
@@ -746,7 +747,7 @@ def run_inference(args: argparse.Namespace, gpu_num: int, gpu_no: int) -> None:
|
||||
|
||||
# Interaction with the world-model
|
||||
print(f'>>> Step {itr}: interacting with world model ...')
|
||||
pred_videos_1, _, pred_states = image_guided_synthesis_sim_mode(
|
||||
pred_videos_1, _, pred_states, wm_samples = image_guided_synthesis_sim_mode(
|
||||
model,
|
||||
"",
|
||||
observation,
|
||||
@@ -759,12 +760,16 @@ def run_inference(args: argparse.Namespace, gpu_num: int, gpu_no: int) -> None:
|
||||
fs=model_input_fs,
|
||||
text_input=False,
|
||||
timestep_spacing=args.timestep_spacing,
|
||||
guidance_rescale=args.guidance_rescale)
|
||||
guidance_rescale=args.guidance_rescale,
|
||||
decode_video=False)
|
||||
|
||||
# Decode only the last frame for CLIP embedding in next iteration
|
||||
last_frame_pixel = model.decode_first_stage(wm_samples[:, :, -1:, :, :])
|
||||
|
||||
for idx in range(args.exe_steps):
|
||||
observation = {
|
||||
'observation.images.top':
|
||||
pred_videos_1[0][:, idx:idx + 1].permute(1, 0, 2, 3),
|
||||
last_frame_pixel[0, :, 0:1].permute(1, 0, 2, 3),
|
||||
'observation.state':
|
||||
torch.zeros_like(pred_states[0][idx:idx + 1]) if
|
||||
args.zero_pred_state else pred_states[0][idx:idx + 1],
|
||||
@@ -782,30 +787,14 @@ def run_inference(args: argparse.Namespace, gpu_num: int, gpu_no: int) -> None:
|
||||
pred_videos_0,
|
||||
sample_tag,
|
||||
fps=args.save_fps)
|
||||
# Save videos environment changes via world-model interaction
|
||||
sample_tag = f"{args.dataset}-vid{sample['videoid']}-wd-fs-{fs}/itr-{itr}"
|
||||
log_to_tensorboard_async(writer,
|
||||
pred_videos_1,
|
||||
sample_tag,
|
||||
fps=args.save_fps)
|
||||
|
||||
# Save the imagen videos for decision-making
|
||||
if pred_videos_0 is not None:
|
||||
sample_video_file = f'{video_save_dir}/dm/{fs}/itr-{itr}.mp4'
|
||||
save_results_async(pred_videos_0,
|
||||
sample_video_file,
|
||||
fps=args.save_fps)
|
||||
# Save videos environment changes via world-model interaction
|
||||
sample_video_file = f'{video_save_dir}/wm/{fs}/itr-{itr}.mp4'
|
||||
save_results_async(pred_videos_1,
|
||||
sample_video_file,
|
||||
fps=args.save_fps)
|
||||
|
||||
print('>' * 24)
|
||||
# Collect the result of world-model interactions
|
||||
wm_video.append(pred_videos_1[:, :, :args.exe_steps].cpu())
|
||||
# Store raw latent for deferred decode
|
||||
wm_latent.append(wm_samples[:, :, :args.exe_steps].cpu())
|
||||
|
||||
full_video = torch.cat(wm_video, dim=2)
|
||||
# Deferred decode: batch decode all stored latents
|
||||
full_latent = torch.cat(wm_latent, dim=2).to(device)
|
||||
full_video = model.decode_first_stage(full_latent).cpu()
|
||||
sample_tag = f"{args.dataset}-vid{sample['videoid']}-wd-fs-{fs}/full"
|
||||
log_to_tensorboard_async(writer,
|
||||
full_video,
|
||||
|
||||
@@ -1,10 +1,10 @@
|
||||
2026-02-11 16:32:03.555597: I tensorflow/core/util/port.cc:113] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
|
||||
2026-02-11 16:32:03.605506: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
|
||||
2026-02-11 16:32:03.605550: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
|
||||
2026-02-11 16:32:03.606879: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
|
||||
2026-02-11 16:32:03.614434: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
|
||||
2026-02-11 16:58:21.710140: I tensorflow/core/util/port.cc:113] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
|
||||
2026-02-11 16:58:21.759418: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
|
||||
2026-02-11 16:58:21.759461: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
|
||||
2026-02-11 16:58:21.760752: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
|
||||
2026-02-11 16:58:21.768205: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
|
||||
To enable the following instructions: AVX2 AVX512F AVX512_VNNI FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
|
||||
2026-02-11 16:32:04.545234: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT
|
||||
2026-02-11 16:58:22.691154: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT
|
||||
Global seed set to 123
|
||||
>>> Loading prepared model from ckpts/unifolm_wma_dual.ckpt.prepared.pt ...
|
||||
>>> Prepared model loaded.
|
||||
@@ -34,10 +34,40 @@ DEBUG:h5py._conv:Creating converter from 3 to 5
|
||||
DEBUG:PIL.PngImagePlugin:STREAM b'IHDR' 16 13
|
||||
DEBUG:PIL.PngImagePlugin:STREAM b'pHYs' 41 9
|
||||
DEBUG:PIL.PngImagePlugin:STREAM b'IDAT' 62 4096
|
||||
|
||||
|
||||
0%| | 0/11 [00:00<?, ?it/s]
|
||||
9%|▉ | 1/11 [00:23<03:52, 23.26s/it]
|
||||
18%|█▊ | 2/11 [00:45<03:25, 22.85s/it]
|
||||
27%|██▋ | 3/11 [01:08<03:02, 22.82s/it]
|
||||
36%|███▋ | 4/11 [01:31<02:39, 22.83s/it]
|
||||
45%|████▌ | 5/11 [01:54<02:17, 22.83s/it]
|
||||
55%|█████▍ | 6/11 [02:17<01:54, 22.83s/it]
|
||||
64%|██████▎ | 7/11 [02:39<01:31, 22.83s/it]
|
||||
73%|███████▎ | 8/11 [03:02<01:08, 22.83s/it]
|
||||
82%|████████▏ | 9/11 [03:25<00:45, 22.81s/it]
|
||||
91%|█████████ | 10/11 [03:48<00:22, 22.81s/it]
|
||||
100%|██████████| 11/11 [04:11<00:00, 22.79s/it]
|
||||
100%|██████████| 11/11 [04:11<00:00, 22.83s/it]
|
||||
>>> Step 0: generating actions ...
|
||||
>>> Step 0: interacting with world model ...
|
||||
>>>>>>>>>>>>>>>>>>>>>>>>
|
||||
>>> Step 1: generating actions ...
|
||||
>>> Step 1: interacting with world model ...
|
||||
>>>>>>>>>>>>>>>>>>>>>>>>
|
||||
>>> Step 2: generating actions ...
|
||||
>>> Step 2: interacting with world model ...
|
||||
>>>>>>>>>>>>>>>>>>>>>>>>
|
||||
>>> Step 3: generating actions ...
|
||||
>>> Step 3: interacting with world model ...
|
||||
>>>>>>>>>>>>>>>>>>>>>>>>
|
||||
>>> Step 4: generating actions ...
|
||||
>>> Step 4: interacting with world model ...
|
||||
>>>>>>>>>>>>>>>>>>>>>>>>
|
||||
>>> Step 5: generating actions ...
|
||||
>>> Step 5: interacting with world model ...
|
||||
>>>>>>>>>>>>>>>>>>>>>>>>
|
||||
>>> Step 6: generating actions ...
|
||||
>>> Step 6: interacting with world model ...
|
||||
>>>>>>>>>>>>>>>>>>>>>>>>
|
||||
>>> Step 7: generating actions ...
|
||||
>>> Step 7: interacting with world model ...
|
||||
@@ -87,37 +117,7 @@ DEBUG:PIL.Image:Importing WmfImagePlugin
|
||||
DEBUG:PIL.Image:Importing PdfImagePlugin
|
||||
DEBUG:PIL.Image:Importing PixarImagePlugin
|
||||
DEBUG:PIL.Image:Importing PngImagePlugin
|
||||
DEBUG:PIL.Image:Importing XpmImagePlugin
|
||||
DEBUG:PIL.Image:Importing XVThumbImagePlugin
|
||||
|
||||
18%|█▊ | 2/11 [00:47<03:31, 23.51s/it]
|
||||
27%|██▋ | 3/11 [01:10<03:08, 23.56s/it]
|
||||
36%|███▋ | 4/11 [01:34<02:45, 23.66s/it]
|
||||
45%|████▌ | 5/11 [01:58<02:22, 23.67s/it]
|
||||
55%|█████▍ | 6/11 [02:21<01:58, 23.67s/it]
|
||||
64%|██████▎ | 7/11 [02:45<01:34, 23.62s/it]
|
||||
73%|███████▎ | 8/11 [03:08<01:10, 23.61s/it]
|
||||
82%|████████▏ | 9/11 [03:32<00:47, 23.59s/it]
|
||||
91%|█████████ | 10/11 [03:56<00:23, 23.60s/it]
|
||||
100%|██████████| 11/11 [04:19<00:00, 23.59s/it]
|
||||
100%|██████████| 11/11 [04:19<00:00, 23.61s/it]
|
||||
>>> Step 1: interacting with world model ...
|
||||
>>>>>>>>>>>>>>>>>>>>>>>>
|
||||
>>> Step 2: generating actions ...
|
||||
>>> Step 2: interacting with world model ...
|
||||
>>>>>>>>>>>>>>>>>>>>>>>>
|
||||
>>> Step 3: generating actions ...
|
||||
>>> Step 3: interacting with world model ...
|
||||
>>>>>>>>>>>>>>>>>>>>>>>>
|
||||
>>> Step 4: generating actions ...
|
||||
>>> Step 4: interacting with world model ...
|
||||
>>>>>>>>>>>>>>>>>>>>>>>>
|
||||
>>> Step 5: generating actions ...
|
||||
>>> Step 5: interacting with world model ...
|
||||
>>>>>>>>>>>>>>>>>>>>>>>>
|
||||
>>> Step 6: generating actions ...
|
||||
>>> Step 6: interacting with world model ...
|
||||
DEBUG:PIL.Image:Importing PpmImagePlugin
|
||||
>>> Step 7: generating actions ...
|
||||
>>> Step 7: interacting with world model ...
|
||||
>>>>>>>>>>>>>>>>>>>>>>>>
|
||||
DEBUG:PIL.Image:Importing PsdImagePlugin
|
||||
DEBUG:PIL.Image:Importing QoiImagePlugin
|
||||
DEBUG:PIL.Image:Importing SgiImagePlugin
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
{
|
||||
"gt_video": "unitree_z1_dual_arm_stackbox_v2/case1/unitree_z1_dual_arm_stackbox_v2_case1.mp4",
|
||||
"pred_video": "unitree_z1_dual_arm_stackbox_v2/case1/output/inference/5_full_fs4.mp4",
|
||||
"psnr": 27.185465604200047
|
||||
"psnr": 26.683000215343522
|
||||
}
|
||||
Reference in New Issue
Block a user