减少了一路视频vae解码
This commit is contained in:
@@ -444,7 +444,8 @@ def image_guided_synthesis_sim_mode(
|
|||||||
timestep_spacing: str = 'uniform',
|
timestep_spacing: str = 'uniform',
|
||||||
guidance_rescale: float = 0.0,
|
guidance_rescale: float = 0.0,
|
||||||
sim_mode: bool = True,
|
sim_mode: bool = True,
|
||||||
**kwargs) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
decode_video: bool = True,
|
||||||
|
**kwargs) -> tuple[torch.Tensor | None, torch.Tensor, torch.Tensor]:
|
||||||
"""
|
"""
|
||||||
Performs image-guided video generation in a simulation-style mode with optional multimodal guidance (image, state, action, text).
|
Performs image-guided video generation in a simulation-style mode with optional multimodal guidance (image, state, action, text).
|
||||||
|
|
||||||
@@ -467,10 +468,13 @@ def image_guided_synthesis_sim_mode(
|
|||||||
timestep_spacing (str): Timestep sampling method in DDIM sampler. Typically "uniform" or "linspace".
|
timestep_spacing (str): Timestep sampling method in DDIM sampler. Typically "uniform" or "linspace".
|
||||||
guidance_rescale (float): Guidance rescaling factor to mitigate overexposure from classifier-free guidance.
|
guidance_rescale (float): Guidance rescaling factor to mitigate overexposure from classifier-free guidance.
|
||||||
sim_mode (bool): Whether to perform world-model interaction or decision-making using the world-model.
|
sim_mode (bool): Whether to perform world-model interaction or decision-making using the world-model.
|
||||||
|
decode_video (bool): Whether to decode latent samples to pixel-space video.
|
||||||
|
Set to False to skip VAE decode for speed when only actions/states are needed.
|
||||||
**kwargs: Additional arguments passed to the DDIM sampler.
|
**kwargs: Additional arguments passed to the DDIM sampler.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
batch_variants (torch.Tensor): Predicted pixel-space video frames [B, C, T, H, W].
|
batch_variants (torch.Tensor | None): Predicted pixel-space video frames [B, C, T, H, W],
|
||||||
|
or None when decode_video=False.
|
||||||
actions (torch.Tensor): Predicted action sequences [B, T, D] from diffusion decoding.
|
actions (torch.Tensor): Predicted action sequences [B, T, D] from diffusion decoding.
|
||||||
states (torch.Tensor): Predicted state sequences [B, T, D] from diffusion decoding.
|
states (torch.Tensor): Predicted state sequences [B, T, D] from diffusion decoding.
|
||||||
"""
|
"""
|
||||||
@@ -554,6 +558,7 @@ def image_guided_synthesis_sim_mode(
|
|||||||
else:
|
else:
|
||||||
autocast_ctx = nullcontext()
|
autocast_ctx = nullcontext()
|
||||||
|
|
||||||
|
batch_variants = None
|
||||||
if ddim_sampler is not None:
|
if ddim_sampler is not None:
|
||||||
with autocast_ctx:
|
with autocast_ctx:
|
||||||
samples, actions, states, intermedia = ddim_sampler.sample(
|
samples, actions, states, intermedia = ddim_sampler.sample(
|
||||||
@@ -573,6 +578,7 @@ def image_guided_synthesis_sim_mode(
|
|||||||
guidance_rescale=guidance_rescale,
|
guidance_rescale=guidance_rescale,
|
||||||
**kwargs)
|
**kwargs)
|
||||||
|
|
||||||
|
if decode_video:
|
||||||
# Reconstruct from latent to pixel space
|
# Reconstruct from latent to pixel space
|
||||||
batch_images = model.decode_first_stage(samples)
|
batch_images = model.decode_first_stage(samples)
|
||||||
batch_variants = batch_images
|
batch_variants = batch_images
|
||||||
@@ -750,7 +756,8 @@ def run_inference(args: argparse.Namespace, gpu_num: int, gpu_no: int) -> None:
|
|||||||
fs=model_input_fs,
|
fs=model_input_fs,
|
||||||
timestep_spacing=args.timestep_spacing,
|
timestep_spacing=args.timestep_spacing,
|
||||||
guidance_rescale=args.guidance_rescale,
|
guidance_rescale=args.guidance_rescale,
|
||||||
sim_mode=False)
|
sim_mode=False,
|
||||||
|
decode_video=not args.fast_policy_no_decode)
|
||||||
|
|
||||||
# Update future actions in the observation queues
|
# Update future actions in the observation queues
|
||||||
for idx in range(len(pred_actions[0])):
|
for idx in range(len(pred_actions[0])):
|
||||||
@@ -808,6 +815,7 @@ def run_inference(args: argparse.Namespace, gpu_num: int, gpu_no: int) -> None:
|
|||||||
observation)
|
observation)
|
||||||
|
|
||||||
# Save the imagen videos for decision-making
|
# Save the imagen videos for decision-making
|
||||||
|
if pred_videos_0 is not None:
|
||||||
sample_tag = f"{args.dataset}-vid{sample['videoid']}-dm-fs-{fs}/itr-{itr}"
|
sample_tag = f"{args.dataset}-vid{sample['videoid']}-dm-fs-{fs}/itr-{itr}"
|
||||||
log_to_tensorboard(writer,
|
log_to_tensorboard(writer,
|
||||||
pred_videos_0,
|
pred_videos_0,
|
||||||
@@ -821,6 +829,7 @@ def run_inference(args: argparse.Namespace, gpu_num: int, gpu_no: int) -> None:
|
|||||||
fps=args.save_fps)
|
fps=args.save_fps)
|
||||||
|
|
||||||
# Save the imagen videos for decision-making
|
# Save the imagen videos for decision-making
|
||||||
|
if pred_videos_0 is not None:
|
||||||
sample_video_file = f'{video_save_dir}/dm/{fs}/itr-{itr}.mp4'
|
sample_video_file = f'{video_save_dir}/dm/{fs}/itr-{itr}.mp4'
|
||||||
save_results(pred_videos_0.cpu(),
|
save_results(pred_videos_0.cpu(),
|
||||||
sample_video_file,
|
sample_video_file,
|
||||||
@@ -957,6 +966,11 @@ def get_parser():
|
|||||||
action='store_true',
|
action='store_true',
|
||||||
default=False,
|
default=False,
|
||||||
help="not using the predicted states as comparison")
|
help="not using the predicted states as comparison")
|
||||||
|
parser.add_argument(
|
||||||
|
"--fast_policy_no_decode",
|
||||||
|
action='store_true',
|
||||||
|
default=False,
|
||||||
|
help="Speed mode: policy pass only predicts actions, skip policy video decode/log/save.")
|
||||||
parser.add_argument("--save_fps",
|
parser.add_argument("--save_fps",
|
||||||
type=int,
|
type=int,
|
||||||
default=8,
|
default=8,
|
||||||
|
|||||||
@@ -1,14 +1,14 @@
|
|||||||
/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/lightning_fabric/__init__.py:29: UserWarning: pkg_resources is deprecated as an API. See https://setuptools.pypa.io/en/latest/pkg_resources.html. The pkg_resources package is slated for removal as early as 2025-11-30. Refrain from using this package or pin to Setuptools<81.
|
/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/lightning_fabric/__init__.py:29: UserWarning: pkg_resources is deprecated as an API. See https://setuptools.pypa.io/en/latest/pkg_resources.html. The pkg_resources package is slated for removal as early as 2025-11-30. Refrain from using this package or pin to Setuptools<81.
|
||||||
__import__("pkg_resources").declare_namespace(__name__)
|
__import__("pkg_resources").declare_namespace(__name__)
|
||||||
2026-02-08 18:43:46.463492: I tensorflow/core/util/port.cc:113] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
|
2026-02-09 16:37:01.511249: I tensorflow/core/util/port.cc:113] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
|
||||||
2026-02-08 18:43:46.466714: I external/local_tsl/tsl/cuda/cudart_stub.cc:31] Could not find cuda drivers on your machine, GPU will not be used.
|
2026-02-09 16:37:01.514371: I external/local_tsl/tsl/cuda/cudart_stub.cc:31] Could not find cuda drivers on your machine, GPU will not be used.
|
||||||
2026-02-08 18:43:46.498994: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
|
2026-02-09 16:37:01.545068: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
|
||||||
2026-02-08 18:43:46.499029: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
|
2026-02-09 16:37:01.545097: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
|
||||||
2026-02-08 18:43:46.500865: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
|
2026-02-09 16:37:01.546937: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
|
||||||
2026-02-08 18:43:46.509069: I external/local_tsl/tsl/cuda/cudart_stub.cc:31] Could not find cuda drivers on your machine, GPU will not be used.
|
2026-02-09 16:37:01.555024: I external/local_tsl/tsl/cuda/cudart_stub.cc:31] Could not find cuda drivers on your machine, GPU will not be used.
|
||||||
2026-02-08 18:43:46.509359: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
|
2026-02-09 16:37:01.555338: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
|
||||||
To enable the following instructions: AVX2 AVX512F AVX512_VNNI AVX512_BF16 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
|
To enable the following instructions: AVX2 AVX512F AVX512_VNNI AVX512_BF16 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
|
||||||
2026-02-08 18:43:47.434136: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT
|
2026-02-09 16:37:02.212554: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT
|
||||||
[rank: 0] Global seed set to 123
|
[rank: 0] Global seed set to 123
|
||||||
/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/kornia/feature/lightglue.py:44: FutureWarning: `torch.cuda.amp.custom_fwd(args...)` is deprecated. Please use `torch.amp.custom_fwd(args..., device_type='cuda')` instead.
|
/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/kornia/feature/lightglue.py:44: FutureWarning: `torch.cuda.amp.custom_fwd(args...)` is deprecated. Please use `torch.amp.custom_fwd(args..., device_type='cuda')` instead.
|
||||||
@torch.cuda.amp.custom_fwd(cast_inputs=torch.float32)
|
@torch.cuda.amp.custom_fwd(cast_inputs=torch.float32)
|
||||||
@@ -116,7 +116,7 @@ DEBUG:PIL.Image:Importing WmfImagePlugin
|
|||||||
DEBUG:PIL.Image:Importing WmfImagePlugin
|
DEBUG:PIL.Image:Importing WmfImagePlugin
|
||||||
DEBUG:PIL.Image:Importing XbmImagePlugin
|
DEBUG:PIL.Image:Importing XbmImagePlugin
|
||||||
DEBUG:PIL.Image:Importing XpmImagePlugin
|
DEBUG:PIL.Image:Importing XpmImagePlugin
|
||||||
DEBUG:PIL.Image:Importing XVThumbImagePlugin
|
DEBUG:PIL.Image:Importing XVThumbImagePlugin
|
||||||
|
|
||||||
12%|█▎ | 1/8 [01:11<08:20, 71.56s/it]
|
12%|█▎ | 1/8 [01:11<08:20, 71.56s/it]
|
||||||
25%|██▌ | 2/8 [02:19<06:56, 69.36s/it]
|
25%|██▌ | 2/8 [02:19<06:56, 69.36s/it]
|
||||||
@@ -140,6 +140,6 @@ DEBUG:PIL.Image:Importing XVThumbImagePlugin
|
|||||||
>>> Step 4: generating actions ...
|
>>> Step 4: generating actions ...
|
||||||
>>> Step 4: interacting with world model ...
|
>>> Step 4: interacting with world model ...
|
||||||
>>>>>>>>>>>>>>>>>>>>>>>>
|
>>>>>>>>>>>>>>>>>>>>>>>>
|
||||||
>>> Step 5: generating actions ...
|
>>> Step 5: generating actions ...
|
||||||
>>> Step 5: interacting with world model ...
|
>>> Step 5: interacting with world model ...
|
||||||
>>>>>>>>>>>>>>>>>>>>>>>>
|
>>>>>>>>>>>>>>>>>>>>>>>>
|
||||||
|
|||||||
@@ -1,5 +1,5 @@
|
|||||||
{
|
{
|
||||||
"gt_video": "/mnt/ASC1637/unifolm-world-model-action/unitree_z1_dual_arm_cleanup_pencils/case1/output/inference/unitree_z1_dual_arm_cleanup_pencils_case1_amd.mp4",
|
"gt_video": "/mnt/ASC1637/unifolm-world-model-action/unitree_z1_dual_arm_cleanup_pencils/case1/output/inference/unitree_z1_dual_arm_cleanup_pencils_case1_amd.mp4",
|
||||||
"pred_video": "/mnt/ASC1637/unifolm-world-model-action/unitree_z1_dual_arm_cleanup_pencils/case1/output/inference/0_full_fs4.mp4",
|
"pred_video": "/mnt/ASC1637/unifolm-world-model-action/unitree_z1_dual_arm_cleanup_pencils/case1/output/inference/0_full_fs4.mp4",
|
||||||
"psnr": 30.058508734449845
|
"psnr": 31.802224855380352
|
||||||
}
|
}
|
||||||
@@ -21,5 +21,6 @@ dataset="unitree_z1_dual_arm_cleanup_pencils"
|
|||||||
--timestep_spacing 'uniform_trailing' \
|
--timestep_spacing 'uniform_trailing' \
|
||||||
--guidance_rescale 0.7 \
|
--guidance_rescale 0.7 \
|
||||||
--perframe_ae \
|
--perframe_ae \
|
||||||
--vae_dtype bf16
|
--vae_dtype bf16 \
|
||||||
|
--fast_policy_no_decode
|
||||||
} 2>&1 | tee "${res_dir}/output.log"
|
} 2>&1 | tee "${res_dir}/output.log"
|
||||||
|
|||||||
Reference in New Issue
Block a user