Compare commits
14 Commits
trt-replac
...
qhy2
| Author | SHA1 | Date | |
|---|---|---|---|
| 3069666a15 | |||
| 68369cc15f | |||
| b0ebb7006e | |||
| 125b85ce68 | |||
| 0b3b0e534a | |||
| 6dca3696d8 | |||
| f192c8aca9 | |||
| 4288c9d8c9 | |||
| a2cd34dd51 | |||
| 7338cc384a | |||
| f86ab51a04 | |||
| 75c798ded0 | |||
| e588182642 | |||
| e6c55a648c |
6
.gitignore
vendored
6
.gitignore
vendored
@@ -55,7 +55,7 @@ coverage.xml
|
||||
*.pot
|
||||
|
||||
# Django stuff:
|
||||
*.log
|
||||
|
||||
local_settings.py
|
||||
db.sqlite3
|
||||
|
||||
@@ -121,7 +121,6 @@ localTest/
|
||||
fig/
|
||||
figure/
|
||||
*.mp4
|
||||
*.json
|
||||
Data/ControlVAE.yml
|
||||
Data/Misc
|
||||
Data/Pretrained
|
||||
@@ -129,4 +128,5 @@ Data/utils.py
|
||||
Experiment/checkpoint
|
||||
Experiment/log
|
||||
|
||||
*.ckpt
|
||||
*.ckpt
|
||||
*.0
|
||||
135
case4_run.log
Normal file
135
case4_run.log
Normal file
@@ -0,0 +1,135 @@
|
||||
nohup: ignoring input
|
||||
/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/lightning_fabric/__init__.py:29: UserWarning: pkg_resources is deprecated as an API. See https://setuptools.pypa.io/en/latest/pkg_resources.html. The pkg_resources package is slated for removal as early as 2025-11-30. Refrain from using this package or pin to Setuptools<81.
|
||||
__import__("pkg_resources").declare_namespace(__name__)
|
||||
2026-02-08 07:38:45.572744: I tensorflow/core/util/port.cc:113] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
|
||||
2026-02-08 07:38:45.576864: I external/local_tsl/tsl/cuda/cudart_stub.cc:31] Could not find cuda drivers on your machine, GPU will not be used.
|
||||
2026-02-08 07:38:45.624825: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
|
||||
2026-02-08 07:38:45.624883: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
|
||||
2026-02-08 07:38:45.627150: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
|
||||
2026-02-08 07:38:45.638316: I external/local_tsl/tsl/cuda/cudart_stub.cc:31] Could not find cuda drivers on your machine, GPU will not be used.
|
||||
2026-02-08 07:38:45.638803: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
|
||||
To enable the following instructions: AVX2 AVX512F AVX512_VNNI AVX512_BF16 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
|
||||
2026-02-08 07:38:46.426363: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT
|
||||
[rank: 0] Global seed set to 123
|
||||
/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/kornia/feature/lightglue.py:44: FutureWarning: `torch.cuda.amp.custom_fwd(args...)` is deprecated. Please use `torch.amp.custom_fwd(args..., device_type='cuda')` instead.
|
||||
@torch.cuda.amp.custom_fwd(cast_inputs=torch.float32)
|
||||
INFO:mainlogger:LatentVisualDiffusion: Running in v-prediction mode
|
||||
INFO:unifolm_wma.models.diffusion_head.conditional_unet1d:number of parameters: 5.010531e+08
|
||||
INFO:unifolm_wma.models.diffusion_head.conditional_unet1d:number of parameters: 5.010531e+08
|
||||
AE working on z of shape (1, 4, 32, 32) = 4096 dimensions.
|
||||
INFO:root:Loaded ViT-H-14 model config.
|
||||
DEBUG:urllib3.connectionpool:Starting new HTTPS connection (1): hf-mirror.com:443
|
||||
DEBUG:urllib3.connectionpool:https://hf-mirror.com:443 "HEAD /laion/CLIP-ViT-H-14-laion2B-s32B-b79K/resolve/main/open_clip_pytorch_model.bin HTTP/1.1" 302 0
|
||||
INFO:root:Loading pretrained ViT-H-14 weights (laion2b_s32b_b79k).
|
||||
/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/open_clip/factory.py:88: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.
|
||||
checkpoint = torch.load(checkpoint_path, map_location=map_location)
|
||||
INFO:root:Loaded ViT-H-14 model config.
|
||||
DEBUG:urllib3.connectionpool:https://hf-mirror.com:443 "HEAD /laion/CLIP-ViT-H-14-laion2B-s32B-b79K/resolve/main/open_clip_pytorch_model.bin HTTP/1.1" 302 0
|
||||
INFO:root:Loading pretrained ViT-H-14 weights (laion2b_s32b_b79k).
|
||||
/mnt/ASC1637/unifolm-world-model-action/scripts/evaluation/world_model_interaction.py:86: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.
|
||||
state_dict = torch.load(ckpt, map_location="cpu")
|
||||
>>> model checkpoint loaded.
|
||||
>>> Load pre-trained model ...
|
||||
INFO:root:***** Configing Data *****
|
||||
>>> unitree_z1_stackbox: 1 data samples loaded.
|
||||
>>> unitree_z1_stackbox: data stats loaded.
|
||||
>>> unitree_z1_stackbox: normalizer initiated.
|
||||
>>> unitree_z1_dual_arm_stackbox: 1 data samples loaded.
|
||||
>>> unitree_z1_dual_arm_stackbox: data stats loaded.
|
||||
>>> unitree_z1_dual_arm_stackbox: normalizer initiated.
|
||||
>>> unitree_z1_dual_arm_stackbox_v2: 1 data samples loaded.
|
||||
>>> unitree_z1_dual_arm_stackbox_v2: data stats loaded.
|
||||
>>> unitree_z1_dual_arm_stackbox_v2: normalizer initiated.
|
||||
>>> unitree_z1_dual_arm_cleanup_pencils: 1 data samples loaded.
|
||||
>>> unitree_z1_dual_arm_cleanup_pencils: data stats loaded.
|
||||
>>> unitree_z1_dual_arm_cleanup_pencils: normalizer initiated.
|
||||
>>> unitree_g1_pack_camera: 1 data samples loaded.
|
||||
>>> unitree_g1_pack_camera: data stats loaded.
|
||||
>>> unitree_g1_pack_camera: normalizer initiated.
|
||||
>>> Dataset is successfully loaded ...
|
||||
>>> Generate 16 frames under each generation ...
|
||||
DEBUG:h5py._conv:Creating converter from 3 to 5
|
||||
DEBUG:PIL.PngImagePlugin:STREAM b'IHDR' 16 13
|
||||
DEBUG:PIL.PngImagePlugin:STREAM b'pHYs' 41 9
|
||||
DEBUG:PIL.PngImagePlugin:STREAM b'IDAT' 62 4096
|
||||
|
||||
0%| | 0/7 [00:00<?, ?it/s]/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/torch/nn/functional.py:5501: UserWarning: Attempting to use hipBLASLt on an unsupported architecture! Overriding blas backend to hipblas (Triggered internally at ../aten/src/ATen/Context.cpp:296.)
|
||||
proj = linear(q, w, b)
|
||||
/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/torch/nn/functional.py:6278: UserWarning: Flash attention support on Navi31 GPU is still experimental. Enable it with TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL=1. (Triggered internally at ../aten/src/ATen/native/transformers/hip/sdp_utils.cpp:225.)
|
||||
attn_output = scaled_dot_product_attention(
|
||||
/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/torch/nn/functional.py:6278: UserWarning: Memory Efficient attention on Navi31 GPU is still experimental. Enable it with TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL=1. (Triggered internally at ../aten/src/ATen/native/transformers/hip/sdp_utils.cpp:269.)
|
||||
attn_output = scaled_dot_product_attention(
|
||||
>>> Step 0: generating actions ...
|
||||
>>> Step 0: interacting with world model ...
|
||||
DEBUG:PIL.Image:Importing BlpImagePlugin
|
||||
DEBUG:PIL.Image:Importing BmpImagePlugin
|
||||
DEBUG:PIL.Image:Importing BufrStubImagePlugin
|
||||
DEBUG:PIL.Image:Importing CurImagePlugin
|
||||
DEBUG:PIL.Image:Importing DcxImagePlugin
|
||||
DEBUG:PIL.Image:Importing DdsImagePlugin
|
||||
DEBUG:PIL.Image:Importing EpsImagePlugin
|
||||
DEBUG:PIL.Image:Importing FitsImagePlugin
|
||||
DEBUG:PIL.Image:Importing FitsStubImagePlugin
|
||||
DEBUG:PIL.Image:Importing FliImagePlugin
|
||||
DEBUG:PIL.Image:Importing FpxImagePlugin
|
||||
DEBUG:PIL.Image:Image: failed to import FpxImagePlugin: No module named 'olefile'
|
||||
DEBUG:PIL.Image:Importing FtexImagePlugin
|
||||
DEBUG:PIL.Image:Importing GbrImagePlugin
|
||||
DEBUG:PIL.Image:Importing GifImagePlugin
|
||||
DEBUG:PIL.Image:Importing GribStubImagePlugin
|
||||
DEBUG:PIL.Image:Importing Hdf5StubImagePlugin
|
||||
DEBUG:PIL.Image:Importing IcnsImagePlugin
|
||||
DEBUG:PIL.Image:Importing IcoImagePlugin
|
||||
DEBUG:PIL.Image:Importing ImImagePlugin
|
||||
DEBUG:PIL.Image:Importing ImtImagePlugin
|
||||
DEBUG:PIL.Image:Importing IptcImagePlugin
|
||||
DEBUG:PIL.Image:Importing JpegImagePlugin
|
||||
DEBUG:PIL.Image:Importing Jpeg2KImagePlugin
|
||||
DEBUG:PIL.Image:Importing McIdasImagePlugin
|
||||
DEBUG:PIL.Image:Importing MicImagePlugin
|
||||
DEBUG:PIL.Image:Image: failed to import MicImagePlugin: No module named 'olefile'
|
||||
DEBUG:PIL.Image:Importing MpegImagePlugin
|
||||
DEBUG:PIL.Image:Importing MpoImagePlugin
|
||||
DEBUG:PIL.Image:Importing MspImagePlugin
|
||||
DEBUG:PIL.Image:Importing PalmImagePlugin
|
||||
DEBUG:PIL.Image:Importing PcdImagePlugin
|
||||
DEBUG:PIL.Image:Importing PcxImagePlugin
|
||||
DEBUG:PIL.Image:Importing PdfImagePlugin
|
||||
DEBUG:PIL.Image:Importing PixarImagePlugin
|
||||
DEBUG:PIL.Image:Importing PngImagePlugin
|
||||
DEBUG:PIL.Image:Importing PpmImagePlugin
|
||||
DEBUG:PIL.Image:Importing PsdImagePlugin
|
||||
DEBUG:PIL.Image:Importing QoiImagePlugin
|
||||
DEBUG:PIL.Image:Importing SgiImagePlugin
|
||||
DEBUG:PIL.Image:Importing SpiderImagePlugin
|
||||
DEBUG:PIL.Image:Importing SunImagePlugin
|
||||
DEBUG:PIL.Image:Importing TgaImagePlugin
|
||||
DEBUG:PIL.Image:Importing TiffImagePlugin
|
||||
DEBUG:PIL.Image:Importing WebPImagePlugin
|
||||
DEBUG:PIL.Image:Importing WmfImagePlugin
|
||||
DEBUG:PIL.Image:Importing XbmImagePlugin
|
||||
DEBUG:PIL.Image:Importing XpmImagePlugin
|
||||
DEBUG:PIL.Image:Importing XVThumbImagePlugin
|
||||
|
||||
14%|█▍ | 1/7 [01:38<09:52, 98.73s/it]
|
||||
29%|██▊ | 2/7 [03:17<08:14, 98.85s/it]
|
||||
43%|████▎ | 3/7 [04:56<06:35, 98.80s/it]
|
||||
57%|█████▋ | 4/7 [06:35<04:56, 98.94s/it]
|
||||
71%|███████▏ | 5/7 [08:14<03:17, 98.93s/it]
|
||||
86%|████████▌ | 6/7 [09:53<01:38, 98.89s/it]
|
||||
100%|██████████| 7/7 [11:31<00:00, 98.81s/it]
|
||||
100%|██████████| 7/7 [11:31<00:00, 98.85s/it]
|
||||
>>>>>>>>>>>>>>>>>>>>>>>>
|
||||
>>> Step 1: generating actions ...
|
||||
>>> Step 1: interacting with world model ...
|
||||
>>>>>>>>>>>>>>>>>>>>>>>>
|
||||
>>> Step 2: generating actions ...
|
||||
>>> Step 2: interacting with world model ...
|
||||
>>>>>>>>>>>>>>>>>>>>>>>>
|
||||
>>> Step 3: generating actions ...
|
||||
>>> Step 3: interacting with world model ...
|
||||
>>>>>>>>>>>>>>>>>>>>>>>>
|
||||
>>> Step 4: generating actions ...
|
||||
>>> Step 4: interacting with world model ...
|
||||
>>>>>>>>>>>>>>>>>>>>>>>>
|
||||
>>> Step 5: generating actions ...
|
||||
1
ckpts/configuration.json
Normal file
1
ckpts/configuration.json
Normal file
@@ -0,0 +1 @@
|
||||
{"framework": "pytorch", "task": "robotics", "allow_remote": true}
|
||||
@@ -222,7 +222,7 @@ data:
|
||||
test:
|
||||
target: unifolm_wma.data.wma_data.WMAData
|
||||
params:
|
||||
data_dir: '/mnt/ASC1637/unifolm-world-model-action/examples/world_model_interaction_prompts'
|
||||
data_dir: '/home/qhy/unifolm-world-model-action/examples/world_model_interaction_prompts'
|
||||
video_length: ${model.params.wma_config.params.temporal_length}
|
||||
frame_stride: 2
|
||||
load_raw_resolution: True
|
||||
|
||||
21
env.sh
Normal file
21
env.sh
Normal file
@@ -0,0 +1,21 @@
|
||||
# Note: This script should be sourced, not executed
|
||||
# Usage: source env.sh
|
||||
#
|
||||
# If you need render group permissions, run this first:
|
||||
# newgrp render
|
||||
# Then source this script:
|
||||
# source env.sh
|
||||
|
||||
# Initialize conda
|
||||
source /mnt/ASC1637/miniconda3/etc/profile.d/conda.sh
|
||||
|
||||
# Activate conda environment
|
||||
conda activate unifolm-wma-o
|
||||
|
||||
# Set HuggingFace cache directories
|
||||
export HF_HOME=/mnt/ASC1637/hf_home
|
||||
export HUGGINGFACE_HUB_CACHE=/mnt/ASC1637/hf_home/hub
|
||||
|
||||
echo "Environment configured successfully"
|
||||
echo "Conda environment: unifolm-wma-o"
|
||||
echo "HF_HOME: $HF_HOME"
|
||||
217
profile_unet_flops.md
Normal file
217
profile_unet_flops.md
Normal file
@@ -0,0 +1,217 @@
|
||||
TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL=1 CUDA_VISIBLE_DEVICES=0 python scripts/evaluation/profile_unet.py --ckpt_path ckpts/unifolm_wma_dual_mix_bf16.ckpt --config configs/inference/world_model_interaction.yaml
|
||||
|
||||
|
||||
==================================================================================================================================
|
||||
FLOPS BY ATen OPERATOR (FlopCounterMode)
|
||||
==================================================================================================================================
|
||||
ATen Op | GFLOPS | % of Total
|
||||
-------------------------------------------------------
|
||||
convolution | 6185.17 | 46.4%
|
||||
addmm | 4411.17 | 33.1%
|
||||
mm | 1798.34 | 13.5%
|
||||
bmm | 949.54 | 7.1%
|
||||
|
||||
==================================================================================================================================
|
||||
FLOPS BY MODULE (FlopCounterMode)
|
||||
==================================================================================================================================
|
||||
Module | GFLOPS | % of Total
|
||||
------------------------------------------------------------------------------------------
|
||||
Global | 13344.23 | 100.0%
|
||||
DiffusionWrapper | 13344.23 | 100.0%
|
||||
DiffusionWrapper.diffusion_model | 13344.23 | 100.0%
|
||||
DiffusionWrapper.diffusion_model.output_blocks.8 | 997.87 | 7.5%
|
||||
DiffusionWrapper.diffusion_model.output_blocks.5 | 992.91 | 7.4%
|
||||
DiffusionWrapper.diffusion_model.output_blocks.9 | 941.81 | 7.1%
|
||||
DiffusionWrapper.diffusion_model.output_blocks.10 | 857.93 | 6.4%
|
||||
DiffusionWrapper.diffusion_model.output_blocks.11 | 857.93 | 6.4%
|
||||
DiffusionWrapper.diffusion_model.output_blocks.6 | 821.71 | 6.2%
|
||||
DiffusionWrapper.diffusion_model.input_blocks.1 | 765.65 | 5.7%
|
||||
DiffusionWrapper.diffusion_model.input_blocks.2 | 765.65 | 5.7%
|
||||
DiffusionWrapper.diffusion_model.output_blocks.7 | 737.82 | 5.5%
|
||||
DiffusionWrapper.diffusion_model.output_blocks.3 | 732.87 | 5.5%
|
||||
DiffusionWrapper.diffusion_model.output_blocks.4 | 732.87 | 5.5%
|
||||
DiffusionWrapper.diffusion_model.input_blocks.5 | 645.55 | 4.8%
|
||||
DiffusionWrapper.diffusion_model.input_blocks.8 | 640.59 | 4.8%
|
||||
DiffusionWrapper.diffusion_model.input_blocks.4 | 611.99 | 4.6%
|
||||
DiffusionWrapper.diffusion_model.input_blocks.7 | 607.04 | 4.5%
|
||||
DiffusionWrapper.diffusion_model.init_attn.0 | 459.02 | 3.4%
|
||||
DiffusionWrapper.diffusion_model.init_attn | 459.02 | 3.4%
|
||||
nWrapper.diffusion_model.init_attn.0.transformer_blocks.0 | 432.18 | 3.2%
|
||||
DiffusionWrapper.diffusion_model.output_blocks.6.0 | 427.85 | 3.2%
|
||||
DiffusionWrapper.diffusion_model.output_blocks.9.0 | 427.83 | 3.2%
|
||||
DiffusionWrapper.diffusion_model.output_blocks.3.0 | 343.99 | 2.6%
|
||||
DiffusionWrapper.diffusion_model.output_blocks.4.0 | 343.99 | 2.6%
|
||||
DiffusionWrapper.diffusion_model.output_blocks.7.0 | 343.96 | 2.6%
|
||||
DiffusionWrapper.diffusion_model.output_blocks.10.0 | 343.95 | 2.6%
|
||||
DiffusionWrapper.diffusion_model.output_blocks.11.0 | 343.95 | 2.6%
|
||||
DiffusionWrapper.diffusion_model.input_blocks.1.1 | 327.75 | 2.5%
|
||||
DiffusionWrapper.diffusion_model.input_blocks.2.1 | 327.75 | 2.5%
|
||||
|
||||
==================================================================================================================================
|
||||
SUMMARY
|
||||
==================================================================================================================================
|
||||
Total CUDA time: 761.4 ms
|
||||
Matmul CUDA time: 404.2 ms (53.1%)
|
||||
Non-matmul CUDA time: 357.1 ms (46.9%)
|
||||
Total FLOPS (FlopCounter): 13344.23 GFLOPS
|
||||
Matmul throughput: 33.01 TFLOPS/s (54.1% of BF16 peak)
|
||||
Overall throughput: 17.53 TFLOPS/s (28.7% of BF16 peak)
|
||||
GPU peak (BF16): 61.0 TFLOPS
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
==================================================================================================================================
|
||||
FLOPS BY ATen OPERATOR (FlopCounterMode)
|
||||
==================================================================================================================================
|
||||
ATen Op | GFLOPS | % of Total
|
||||
-------------------------------------------------------
|
||||
convolution | 6185.17 | 46.4%
|
||||
addmm | 4411.17 | 33.1%
|
||||
mm | 1798.34 | 13.5%
|
||||
bmm | 949.54 | 7.1%
|
||||
|
||||
==================================================================================================================================
|
||||
FLOPS BY MODULE (FlopCounterMode)
|
||||
==================================================================================================================================
|
||||
Module | GFLOPS | % of Total
|
||||
------------------------------------------------------------------------------------------
|
||||
DiffusionWrapper | 13344.23 | 100.0%
|
||||
Global | 13344.23 | 100.0%
|
||||
DiffusionWrapper.diffusion_model | 13344.23 | 100.0%
|
||||
DiffusionWrapper.diffusion_model.output_blocks.8 | 997.87 | 7.5%
|
||||
DiffusionWrapper.diffusion_model.output_blocks.5 | 992.91 | 7.4%
|
||||
DiffusionWrapper.diffusion_model.output_blocks.9 | 941.81 | 7.1%
|
||||
DiffusionWrapper.diffusion_model.output_blocks.10 | 857.93 | 6.4%
|
||||
DiffusionWrapper.diffusion_model.output_blocks.11 | 857.93 | 6.4%
|
||||
DiffusionWrapper.diffusion_model.output_blocks.6 | 821.71 | 6.2%
|
||||
DiffusionWrapper.diffusion_model.input_blocks.1 | 765.65 | 5.7%
|
||||
DiffusionWrapper.diffusion_model.input_blocks.2 | 765.65 | 5.7%
|
||||
DiffusionWrapper.diffusion_model.output_blocks.7 | 737.82 | 5.5%
|
||||
DiffusionWrapper.diffusion_model.output_blocks.3 | 732.87 | 5.5%
|
||||
DiffusionWrapper.diffusion_model.output_blocks.4 | 732.87 | 5.5%
|
||||
DiffusionWrapper.diffusion_model.input_blocks.5 | 645.55 | 4.8%
|
||||
DiffusionWrapper.diffusion_model.input_blocks.8 | 640.59 | 4.8%
|
||||
DiffusionWrapper.diffusion_model.input_blocks.4 | 611.99 | 4.6%
|
||||
DiffusionWrapper.diffusion_model.input_blocks.7 | 607.04 | 4.5%
|
||||
DiffusionWrapper.diffusion_model.init_attn | 459.02 | 3.4%
|
||||
DiffusionWrapper.diffusion_model.init_attn.0 | 459.02 | 3.4%
|
||||
nWrapper.diffusion_model.init_attn.0.transformer_blocks.0 | 432.18 | 3.2%
|
||||
DiffusionWrapper.diffusion_model.output_blocks.6.0 | 427.85 | 3.2%
|
||||
DiffusionWrapper.diffusion_model.output_blocks.9.0 | 427.83 | 3.2%
|
||||
DiffusionWrapper.diffusion_model.output_blocks.3.0 | 343.99 | 2.6%
|
||||
DiffusionWrapper.diffusion_model.output_blocks.4.0 | 343.99 | 2.6%
|
||||
DiffusionWrapper.diffusion_model.output_blocks.7.0 | 343.96 | 2.6%
|
||||
DiffusionWrapper.diffusion_model.output_blocks.10.0 | 343.95 | 2.6%
|
||||
DiffusionWrapper.diffusion_model.output_blocks.11.0 | 343.95 | 2.6%
|
||||
DiffusionWrapper.diffusion_model.input_blocks.1.1 | 327.75 | 2.5%
|
||||
DiffusionWrapper.diffusion_model.input_blocks.2.1 | 327.75 | 2.5%
|
||||
|
||||
==================================================================================================================================
|
||||
SUMMARY
|
||||
==================================================================================================================================
|
||||
Total CUDA time: 707.1 ms
|
||||
Matmul CUDA time: 403.1 ms (57.0%)
|
||||
Non-matmul CUDA time: 304.0 ms (43.0%)
|
||||
Total FLOPS (FlopCounter): 13344.23 GFLOPS
|
||||
Matmul throughput: 33.11 TFLOPS/s (54.3% of BF16 peak)
|
||||
Overall throughput: 18.87 TFLOPS/s (30.9% of BF16 peak)
|
||||
GPU peak (BF16): 61.0 TFLOPS
|
||||
(unifolm-wma) ASC1637@wx-ms-w7900d-0033:/mnt/ASC1637/unifolm-world-model-action$
|
||||
|
||||
========================================================================
|
||||
TABLE 1: STAGE TIMING
|
||||
========================================================================
|
||||
Stage Mean(ms) Std %
|
||||
------------------------------------------------------------------------
|
||||
1_Image_Embedding 29.5 0.16 0.1%
|
||||
2_VAE_Encode 51.3 0.06 0.1%
|
||||
3_Text_Conditioning 14.7 0.18 0.0%
|
||||
4_Projectors 0.2 0.03 0.0%
|
||||
5_DDIM_Loop 33392.5 3.21 97.3%
|
||||
6_VAE_Decode 808.4 1.00 2.4%
|
||||
7_Post_Process 15.8 0.56 0.0%
|
||||
------------------------------------------------------------------------
|
||||
TOTAL 34312.4
|
||||
|
||||
================================================================================
|
||||
TABLE 2: UNET SUB-MODULE BREAKDOWN
|
||||
================================================================================
|
||||
Module Type Total(ms) Count Per-call %
|
||||
--------------------------------------------------------------------------------
|
||||
ResBlock 10256.3 1100 9.32 23.2%
|
||||
SpatialTransformer 9228.2 800 11.54 20.9%
|
||||
CrossAttention 8105.8 3300 2.46 18.3%
|
||||
ConditionalUnet1D 6409.5 100 64.10 14.5%
|
||||
TemporalTransformer 5847.0 850 6.88 13.2%
|
||||
FeedForward 4338.1 1650 2.63 9.8%
|
||||
UNet.out 73.8 50 1.48 0.2%
|
||||
--------------------------------------------------------------------------------
|
||||
TOTAL (hooked) 44258.7
|
||||
|
||||
==========================================================================================
|
||||
TABLE 2b: PER-UNET-BLOCK TIMING (coarse modules, per DDIM loop)
|
||||
==========================================================================================
|
||||
Block Total(ms) % Breakdown
|
||||
------------------------------------------------------------------------------------------
|
||||
input_blocks.1 3376.2 7.6% SpatialTransformer=1101, CrossAttention=990, ResBlock=543, TemporalTransformer=454, FeedForward=288
|
||||
input_blocks.2 3374.0 7.6% SpatialTransformer=1100, CrossAttention=991, ResBlock=540, TemporalTransformer=455, FeedForward=288
|
||||
input_blocks.4 1592.4 3.6% SpatialTransformer=394, ResBlock=374, CrossAttention=303, TemporalTransformer=272, FeedForward=249
|
||||
input_blocks.5 1642.5 3.7% ResBlock=425, SpatialTransformer=397, CrossAttention=303, TemporalTransformer=271, FeedForward=247
|
||||
input_blocks.7 1469.0 3.3% ResBlock=416, SpatialTransformer=324, FeedForward=251, CrossAttention=240, TemporalTransformer=237
|
||||
input_blocks.8 1543.7 3.5% ResBlock=491, SpatialTransformer=325, FeedForward=250, CrossAttention=240, TemporalTransformer=238
|
||||
input_blocks.10 217.5 0.5% ResBlock=218
|
||||
input_blocks.11 216.8 0.5% ResBlock=217
|
||||
middle_block 848.9 1.9% ResBlock=434, SpatialTransformer=151, CrossAttention=134, TemporalTransformer=69, FeedForward=61
|
||||
output_blocks.0 303.2 0.7% ResBlock=303
|
||||
output_blocks.1 303.1 0.7% ResBlock=303
|
||||
output_blocks.2 302.8 0.7% ResBlock=303
|
||||
output_blocks.3 1734.8 3.9% ResBlock=687, SpatialTransformer=322, FeedForward=249, CrossAttention=239, TemporalTransformer=237
|
||||
output_blocks.4 1739.8 3.9% ResBlock=688, SpatialTransformer=323, FeedForward=251, CrossAttention=239, TemporalTransformer=238
|
||||
output_blocks.5 1622.3 3.7% ResBlock=570, SpatialTransformer=324, FeedForward=251, CrossAttention=239, TemporalTransformer=238
|
||||
output_blocks.6 1881.0 4.3% ResBlock=664, SpatialTransformer=393, CrossAttention=301, TemporalTransformer=272, FeedForward=250
|
||||
output_blocks.7 1768.0 4.0% ResBlock=554, SpatialTransformer=393, CrossAttention=301, TemporalTransformer=272, FeedForward=249
|
||||
output_blocks.8 1688.7 3.8% ResBlock=474, SpatialTransformer=393, CrossAttention=301, TemporalTransformer=272, FeedForward=249
|
||||
output_blocks.9 3558.6 8.0% SpatialTransformer=1096, CrossAttention=992, ResBlock=727, TemporalTransformer=454, FeedForward=290
|
||||
output_blocks.10 3492.8 7.9% SpatialTransformer=1096, CrossAttention=992, ResBlock=662, TemporalTransformer=454, FeedForward=289
|
||||
output_blocks.11 3493.3 7.9% SpatialTransformer=1096, CrossAttention=992, ResBlock=662, TemporalTransformer=454, FeedForward=289
|
||||
out 73.8 0.2% UNet.out=74
|
||||
action_unet 3212.0 7.3% ConditionalUnet1D=3212
|
||||
state_unet 3197.6 7.2% ConditionalUnet1D=3198
|
||||
other 1606.2 3.6% TemporalTransformer=960, FeedForward=337, CrossAttention=309
|
||||
------------------------------------------------------------------------------------------
|
||||
TOTAL 44258.7
|
||||
|
||||
======================================================================
|
||||
TABLE 2c: ATTENTION vs FEEDFORWARD (deep hooks)
|
||||
======================================================================
|
||||
Component Total(ms) %
|
||||
----------------------------------------------------------------------
|
||||
CrossAttention 8105.8 65.1%
|
||||
FeedForward 4338.1 34.9%
|
||||
----------------------------------------------------------------------
|
||||
TOTAL (attn+ff) 12443.9
|
||||
|
||||
==================================================
|
||||
TABLE 3: MEMORY SUMMARY
|
||||
==================================================
|
||||
Initial allocated: 11.82 GB
|
||||
Peak allocated: 14.43 GB
|
||||
Delta (pipeline): 2.61 GB
|
||||
|
||||
============================================================
|
||||
TABLE 4: THROUGHPUT
|
||||
============================================================
|
||||
Total pipeline latency: 34312.4 ms
|
||||
DDIM loop latency: 33392.5 ms
|
||||
DDIM steps: 50
|
||||
CFG scale: 1.0 (1x UNet/step)
|
||||
UNet forward calls: 50
|
||||
Per DDIM step: 667.9 ms
|
||||
Per UNet forward: 667.9 ms
|
||||
VAE encode bandwidth: 0.1 GB/s (peak HBM: 864.0 GB/s)
|
||||
VAE decode bandwidth: 0.0 GB/s (peak HBM: 864.0 GB/s)
|
||||
GPU BF16 peak: 61.0 TFLOPS
|
||||
|
||||
Done.
|
||||
150
run.log
Normal file
150
run.log
Normal file
@@ -0,0 +1,150 @@
|
||||
nohup: ignoring input
|
||||
/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/lightning_fabric/__init__.py:29: UserWarning: pkg_resources is deprecated as an API. See https://setuptools.pypa.io/en/latest/pkg_resources.html. The pkg_resources package is slated for removal as early as 2025-11-30. Refrain from using this package or pin to Setuptools<81.
|
||||
__import__("pkg_resources").declare_namespace(__name__)
|
||||
2026-02-08 08:15:49.934949: I tensorflow/core/util/port.cc:113] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
|
||||
2026-02-08 08:15:49.937974: I external/local_tsl/tsl/cuda/cudart_stub.cc:31] Could not find cuda drivers on your machine, GPU will not be used.
|
||||
2026-02-08 08:15:49.969069: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
|
||||
2026-02-08 08:15:49.969100: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
|
||||
2026-02-08 08:15:49.970909: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
|
||||
2026-02-08 08:15:49.979005: I external/local_tsl/tsl/cuda/cudart_stub.cc:31] Could not find cuda drivers on your machine, GPU will not be used.
|
||||
2026-02-08 08:15:49.979255: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
|
||||
To enable the following instructions: AVX2 AVX512F AVX512_VNNI AVX512_BF16 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
|
||||
2026-02-08 08:15:50.597743: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT
|
||||
[rank: 0] Global seed set to 123
|
||||
/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/kornia/feature/lightglue.py:44: FutureWarning: `torch.cuda.amp.custom_fwd(args...)` is deprecated. Please use `torch.amp.custom_fwd(args..., device_type='cuda')` instead.
|
||||
@torch.cuda.amp.custom_fwd(cast_inputs=torch.float32)
|
||||
INFO:mainlogger:LatentVisualDiffusion: Running in v-prediction mode
|
||||
INFO:unifolm_wma.models.diffusion_head.conditional_unet1d:number of parameters: 5.010531e+08
|
||||
INFO:unifolm_wma.models.diffusion_head.conditional_unet1d:number of parameters: 5.010531e+08
|
||||
AE working on z of shape (1, 4, 32, 32) = 4096 dimensions.
|
||||
INFO:root:Loaded ViT-H-14 model config.
|
||||
DEBUG:urllib3.connectionpool:Starting new HTTPS connection (1): hf-mirror.com:443
|
||||
DEBUG:urllib3.connectionpool:https://hf-mirror.com:443 "HEAD /laion/CLIP-ViT-H-14-laion2B-s32B-b79K/resolve/main/open_clip_pytorch_model.bin HTTP/1.1" 302 0
|
||||
INFO:root:Loading pretrained ViT-H-14 weights (laion2b_s32b_b79k).
|
||||
/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/open_clip/factory.py:88: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.
|
||||
checkpoint = torch.load(checkpoint_path, map_location=map_location)
|
||||
INFO:root:Loaded ViT-H-14 model config.
|
||||
DEBUG:urllib3.connectionpool:https://hf-mirror.com:443 "HEAD /laion/CLIP-ViT-H-14-laion2B-s32B-b79K/resolve/main/open_clip_pytorch_model.bin HTTP/1.1" 302 0
|
||||
INFO:root:Loading pretrained ViT-H-14 weights (laion2b_s32b_b79k).
|
||||
/mnt/ASC1637/unifolm-world-model-action/scripts/evaluation/world_model_interaction.py:86: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.
|
||||
state_dict = torch.load(ckpt, map_location="cpu")
|
||||
>>> model checkpoint loaded.
|
||||
>>> Load pre-trained model ...
|
||||
INFO:root:***** Configing Data *****
|
||||
>>> unitree_z1_stackbox: 1 data samples loaded.
|
||||
>>> unitree_z1_stackbox: data stats loaded.
|
||||
>>> unitree_z1_stackbox: normalizer initiated.
|
||||
>>> unitree_z1_dual_arm_stackbox: 1 data samples loaded.
|
||||
>>> unitree_z1_dual_arm_stackbox: data stats loaded.
|
||||
>>> unitree_z1_dual_arm_stackbox: normalizer initiated.
|
||||
>>> unitree_z1_dual_arm_stackbox_v2: 1 data samples loaded.
|
||||
>>> unitree_z1_dual_arm_stackbox_v2: data stats loaded.
|
||||
>>> unitree_z1_dual_arm_stackbox_v2: normalizer initiated.
|
||||
>>> unitree_z1_dual_arm_cleanup_pencils: 1 data samples loaded.
|
||||
>>> unitree_z1_dual_arm_cleanup_pencils: data stats loaded.
|
||||
>>> unitree_z1_dual_arm_cleanup_pencils: normalizer initiated.
|
||||
>>> unitree_g1_pack_camera: 1 data samples loaded.
|
||||
>>> unitree_g1_pack_camera: data stats loaded.
|
||||
>>> unitree_g1_pack_camera: normalizer initiated.
|
||||
>>> Dataset is successfully loaded ...
|
||||
>>> Generate 16 frames under each generation ...
|
||||
DEBUG:h5py._conv:Creating converter from 3 to 5
|
||||
DEBUG:PIL.PngImagePlugin:STREAM b'IHDR' 16 13
|
||||
DEBUG:PIL.PngImagePlugin:STREAM b'pHYs' 41 9
|
||||
DEBUG:PIL.PngImagePlugin:STREAM b'IDAT' 62 4096
|
||||
|
||||
0%| | 0/12 [00:00<?, ?it/s]/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/torch/nn/functional.py:5501: UserWarning: Attempting to use hipBLASLt on an unsupported architecture! Overriding blas backend to hipblas (Triggered internally at ../aten/src/ATen/Context.cpp:296.)
|
||||
proj = linear(q, w, b)
|
||||
/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/torch/nn/functional.py:6278: UserWarning: Flash attention support on Navi31 GPU is still experimental. Enable it with TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL=1. (Triggered internally at ../aten/src/ATen/native/transformers/hip/sdp_utils.cpp:225.)
|
||||
attn_output = scaled_dot_product_attention(
|
||||
/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/torch/nn/functional.py:6278: UserWarning: Memory Efficient attention on Navi31 GPU is still experimental. Enable it with TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL=1. (Triggered internally at ../aten/src/ATen/native/transformers/hip/sdp_utils.cpp:269.)
|
||||
attn_output = scaled_dot_product_attention(
|
||||
>>> Step 0: generating actions ...
|
||||
>>> Step 0: interacting with world model ...
|
||||
DEBUG:PIL.Image:Importing BlpImagePlugin
|
||||
DEBUG:PIL.Image:Importing BmpImagePlugin
|
||||
DEBUG:PIL.Image:Importing BufrStubImagePlugin
|
||||
DEBUG:PIL.Image:Importing CurImagePlugin
|
||||
DEBUG:PIL.Image:Importing DcxImagePlugin
|
||||
DEBUG:PIL.Image:Importing DdsImagePlugin
|
||||
DEBUG:PIL.Image:Importing EpsImagePlugin
|
||||
DEBUG:PIL.Image:Importing FitsImagePlugin
|
||||
DEBUG:PIL.Image:Importing FitsStubImagePlugin
|
||||
DEBUG:PIL.Image:Importing FliImagePlugin
|
||||
DEBUG:PIL.Image:Importing FpxImagePlugin
|
||||
DEBUG:PIL.Image:Image: failed to import FpxImagePlugin: No module named 'olefile'
|
||||
DEBUG:PIL.Image:Importing FtexImagePlugin
|
||||
DEBUG:PIL.Image:Importing GbrImagePlugin
|
||||
DEBUG:PIL.Image:Importing GifImagePlugin
|
||||
DEBUG:PIL.Image:Importing GribStubImagePlugin
|
||||
DEBUG:PIL.Image:Importing Hdf5StubImagePlugin
|
||||
DEBUG:PIL.Image:Importing IcnsImagePlugin
|
||||
DEBUG:PIL.Image:Importing IcoImagePlugin
|
||||
DEBUG:PIL.Image:Importing ImImagePlugin
|
||||
DEBUG:PIL.Image:Importing ImtImagePlugin
|
||||
DEBUG:PIL.Image:Importing IptcImagePlugin
|
||||
DEBUG:PIL.Image:Importing JpegImagePlugin
|
||||
DEBUG:PIL.Image:Importing Jpeg2KImagePlugin
|
||||
DEBUG:PIL.Image:Importing McIdasImagePlugin
|
||||
DEBUG:PIL.Image:Importing MicImagePlugin
|
||||
DEBUG:PIL.Image:Image: failed to import MicImagePlugin: No module named 'olefile'
|
||||
DEBUG:PIL.Image:Importing MpegImagePlugin
|
||||
DEBUG:PIL.Image:Importing MpoImagePlugin
|
||||
DEBUG:PIL.Image:Importing MspImagePlugin
|
||||
DEBUG:PIL.Image:Importing PalmImagePlugin
|
||||
DEBUG:PIL.Image:Importing PcdImagePlugin
|
||||
DEBUG:PIL.Image:Importing PcxImagePlugin
|
||||
DEBUG:PIL.Image:Importing PdfImagePlugin
|
||||
DEBUG:PIL.Image:Importing PixarImagePlugin
|
||||
DEBUG:PIL.Image:Importing PngImagePlugin
|
||||
DEBUG:PIL.Image:Importing PpmImagePlugin
|
||||
DEBUG:PIL.Image:Importing PsdImagePlugin
|
||||
DEBUG:PIL.Image:Importing QoiImagePlugin
|
||||
DEBUG:PIL.Image:Importing SgiImagePlugin
|
||||
DEBUG:PIL.Image:Importing SpiderImagePlugin
|
||||
DEBUG:PIL.Image:Importing SunImagePlugin
|
||||
DEBUG:PIL.Image:Importing TgaImagePlugin
|
||||
DEBUG:PIL.Image:Importing TiffImagePlugin
|
||||
DEBUG:PIL.Image:Importing WebPImagePlugin
|
||||
DEBUG:PIL.Image:Importing WmfImagePlugin
|
||||
DEBUG:PIL.Image:Importing XbmImagePlugin
|
||||
DEBUG:PIL.Image:Importing XpmImagePlugin
|
||||
DEBUG:PIL.Image:Importing XVThumbImagePlugin
|
||||
|
||||
8%|▊ | 1/12 [01:37<17:51, 97.37s/it]
|
||||
17%|█▋ | 2/12 [03:14<16:13, 97.31s/it]
|
||||
25%|██▌ | 3/12 [04:51<14:35, 97.26s/it]
|
||||
33%|███▎ | 4/12 [06:29<12:58, 97.25s/it]
|
||||
42%|████▏ | 5/12 [08:06<11:20, 97.24s/it]
|
||||
50%|█████ | 6/12 [09:43<09:43, 97.24s/it]
|
||||
58%|█████▊ | 7/12 [11:20<08:06, 97.27s/it]
|
||||
67%|██████▋ | 8/12 [12:58<06:29, 97.36s/it]
|
||||
75%|███████▌ | 9/12 [14:36<04:52, 97.49s/it]
|
||||
83%|████████▎ | 10/12 [16:13<03:15, 97.52s/it]
|
||||
92%|█████████▏| 11/12 [17:51<01:37, 97.47s/it]
|
||||
100%|██████████| 12/12 [19:28<00:00, 97.35s/it]
|
||||
100%|██████████| 12/12 [19:28<00:00, 97.35s/it]
|
||||
>>>>>>>>>>>>>>>>>>>>>>>>
|
||||
>>> Step 1: generating actions ...
|
||||
>>> Step 1: interacting with world model ...
|
||||
>>>>>>>>>>>>>>>>>>>>>>>>
|
||||
>>> Step 2: generating actions ...
|
||||
>>> Step 2: interacting with world model ...
|
||||
>>>>>>>>>>>>>>>>>>>>>>>>
|
||||
>>> Step 3: generating actions ...
|
||||
>>> Step 3: interacting with world model ...
|
||||
>>>>>>>>>>>>>>>>>>>>>>>>
|
||||
>>> Step 4: generating actions ...
|
||||
>>> Step 4: interacting with world model ...
|
||||
>>>>>>>>>>>>>>>>>>>>>>>>
|
||||
>>> Step 5: generating actions ...
|
||||
>>> Step 5: interacting with world model ...
|
||||
>>>>>>>>>>>>>>>>>>>>>>>>
|
||||
>>> Step 6: generating actions ...
|
||||
>>> Step 6: interacting with world model ...
|
||||
>>>>>>>>>>>>>>>>>>>>>>>>
|
||||
>>> Step 7: generating actions ...
|
||||
>>> Step 7: interacting with world model ...
|
||||
>>>>>>>>>>>>>>>>>>>>>>>>
|
||||
>>> Step 8: generating actions ...
|
||||
>>> Step 8: interacting with world model ...
|
||||
975
scripts/evaluation/profile_iteration.py
Normal file
975
scripts/evaluation/profile_iteration.py
Normal file
@@ -0,0 +1,975 @@
|
||||
"""
|
||||
Profile the full iteration loop of world model interaction.
|
||||
|
||||
Three layers of profiling:
|
||||
Layer 1: Iteration-level wall-clock breakdown (CUDA events)
|
||||
Layer 2: GPU timeline trace (torch.profiler → Chrome trace)
|
||||
Layer 3: A/B comparison (standardized CSV output)
|
||||
|
||||
Usage:
|
||||
# Layer 1 only (fast, default):
|
||||
TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL=1 CUDA_VISIBLE_DEVICES=0 \
|
||||
python scripts/evaluation/profile_iteration.py \
|
||||
--ckpt_path ckpts/unifolm_wma_dual_mix_bf16.ckpt \
|
||||
--config configs/inference/world_model_interaction.yaml \
|
||||
--prompt_dir unitree_z1_dual_arm_cleanup_pencils/case1/world_model_interaction_prompts \
|
||||
--dataset unitree_z1_dual_arm_cleanup_pencils \
|
||||
--frame_stride 4 --n_iter 5
|
||||
|
||||
# Layer 1 + Layer 2 (GPU trace):
|
||||
... --trace --trace_dir ./profile_traces
|
||||
|
||||
# Layer 3 (A/B comparison): run twice, diff the CSVs
|
||||
... --csv baseline.csv
|
||||
... --csv optimized.csv
|
||||
python scripts/evaluation/profile_iteration.py --compare baseline.csv optimized.csv
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import csv
|
||||
import os
|
||||
import sys
|
||||
import time
|
||||
from collections import defaultdict, deque
|
||||
from contextlib import nullcontext
|
||||
|
||||
import h5py
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import torch
|
||||
import torchvision
|
||||
from einops import rearrange, repeat
|
||||
from omegaconf import OmegaConf
|
||||
from PIL import Image
|
||||
from pytorch_lightning import seed_everything
|
||||
from torch import Tensor
|
||||
|
||||
from unifolm_wma.models.samplers.ddim import DDIMSampler
|
||||
from unifolm_wma.utils.utils import instantiate_from_config
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
# ──────────────────────────────────────────────────────────────────────
|
||||
# Constants
|
||||
# ──────────────────────────────────────────────────────────────────────
|
||||
STAGE_NAMES = [
|
||||
"stack_to_device_1",
|
||||
"synth_policy",
|
||||
"update_action_queue",
|
||||
"stack_to_device_2",
|
||||
"synth_world_model",
|
||||
"update_obs_queue",
|
||||
"tensorboard_log",
|
||||
"save_results",
|
||||
"cpu_transfer",
|
||||
"itr_total",
|
||||
]
|
||||
|
||||
# Sub-stages inside image_guided_synthesis_sim_mode
|
||||
SYNTH_SUB_STAGES = [
|
||||
"ddim_sampler_init",
|
||||
"image_embedding",
|
||||
"vae_encode",
|
||||
"text_conditioning",
|
||||
"projectors",
|
||||
"cond_assembly",
|
||||
"ddim_sampling",
|
||||
"vae_decode",
|
||||
]
|
||||
|
||||
|
||||
# ──────────────────────────────────────────────────────────────────────
|
||||
# CudaTimer — GPU-precise timing via CUDA events
|
||||
# ──────────────────────────────────────────────────────────────────────
|
||||
class CudaTimer:
|
||||
"""Context manager that records GPU time between enter/exit using CUDA events."""
|
||||
|
||||
def __init__(self, name, records):
|
||||
self.name = name
|
||||
self.records = records
|
||||
|
||||
def __enter__(self):
|
||||
torch.cuda.synchronize()
|
||||
self._start = torch.cuda.Event(enable_timing=True)
|
||||
self._end = torch.cuda.Event(enable_timing=True)
|
||||
self._start.record()
|
||||
return self
|
||||
|
||||
def __exit__(self, *args):
|
||||
self._end.record()
|
||||
torch.cuda.synchronize()
|
||||
elapsed_ms = self._start.elapsed_time(self._end)
|
||||
self.records[self.name].append(elapsed_ms)
|
||||
|
||||
|
||||
class WallTimer:
|
||||
"""Context manager that records CPU wall-clock time (for pure-CPU stages)."""
|
||||
|
||||
def __init__(self, name, records):
|
||||
self.name = name
|
||||
self.records = records
|
||||
|
||||
def __enter__(self):
|
||||
torch.cuda.synchronize()
|
||||
self._t0 = time.perf_counter()
|
||||
return self
|
||||
|
||||
def __exit__(self, *args):
|
||||
torch.cuda.synchronize()
|
||||
elapsed_ms = (time.perf_counter() - self._t0) * 1000.0
|
||||
self.records[self.name].append(elapsed_ms)
|
||||
|
||||
|
||||
# ──────────────────────────────────────────────────────────────────────
|
||||
# Model loading (reused from world_model_interaction.py)
|
||||
# ──────────────────────────────────────────────────────────────────────
|
||||
def patch_norm_bypass_autocast():
|
||||
def _group_norm_forward(self, x):
|
||||
with torch.amp.autocast('cuda', enabled=False):
|
||||
return F.group_norm(
|
||||
x, self.num_groups,
|
||||
self.weight.to(x.dtype) if self.weight is not None else None,
|
||||
self.bias.to(x.dtype) if self.bias is not None else None,
|
||||
self.eps)
|
||||
|
||||
def _layer_norm_forward(self, x):
|
||||
with torch.amp.autocast('cuda', enabled=False):
|
||||
return F.layer_norm(
|
||||
x, self.normalized_shape,
|
||||
self.weight.to(x.dtype) if self.weight is not None else None,
|
||||
self.bias.to(x.dtype) if self.bias is not None else None,
|
||||
self.eps)
|
||||
|
||||
torch.nn.GroupNorm.forward = _group_norm_forward
|
||||
torch.nn.LayerNorm.forward = _layer_norm_forward
|
||||
|
||||
|
||||
def apply_torch_compile(model, hot_indices=(5, 8, 9)):
|
||||
from unifolm_wma.modules.networks.wma_model import ResBlock
|
||||
unet = model.model.diffusion_model
|
||||
compiled = 0
|
||||
for idx in hot_indices:
|
||||
block = unet.output_blocks[idx]
|
||||
for layer in block:
|
||||
if isinstance(layer, ResBlock):
|
||||
layer._forward = torch.compile(layer._forward, mode="default")
|
||||
compiled += 1
|
||||
print(f" torch.compile: {compiled} ResBlocks in output_blocks{list(hot_indices)}")
|
||||
|
||||
|
||||
def load_model(args):
|
||||
config = OmegaConf.load(args.config)
|
||||
config['model']['params']['wma_config']['params']['use_checkpoint'] = False
|
||||
model = instantiate_from_config(config.model)
|
||||
model.perframe_ae = args.perframe_ae
|
||||
|
||||
from collections import OrderedDict
|
||||
state_dict = torch.load(args.ckpt_path, map_location="cpu")
|
||||
if "state_dict" in state_dict:
|
||||
state_dict = state_dict["state_dict"]
|
||||
try:
|
||||
model.load_state_dict(state_dict, strict=True)
|
||||
except Exception:
|
||||
new_sd = OrderedDict()
|
||||
for k, v in state_dict.items():
|
||||
new_sd[k] = v
|
||||
for k in list(new_sd.keys()):
|
||||
if "framestride_embed" in k:
|
||||
new_sd[k.replace("framestride_embed", "fps_embedding")] = new_sd.pop(k)
|
||||
model.load_state_dict(new_sd, strict=True)
|
||||
|
||||
model.eval()
|
||||
|
||||
# Apply precision: bf16 diffusion + encoders + projectors, fp32/bf16 VAE
|
||||
model.model.to(torch.bfloat16)
|
||||
model.diffusion_autocast_dtype = torch.bfloat16
|
||||
model.embedder.to(torch.bfloat16)
|
||||
model.image_proj_model.to(torch.bfloat16)
|
||||
model.encoder_autocast_dtype = None
|
||||
model.state_projector.to(torch.bfloat16)
|
||||
model.action_projector.to(torch.bfloat16)
|
||||
model.projector_autocast_dtype = None
|
||||
if args.vae_dtype == "bf16":
|
||||
model.first_stage_model.to(torch.bfloat16)
|
||||
|
||||
# Compile hot ResBlocks
|
||||
apply_torch_compile(model)
|
||||
model = model.cuda()
|
||||
print(">>> Model loaded and ready.")
|
||||
return model, config
|
||||
|
||||
|
||||
# ──────────────────────────────────────────────────────────────────────
|
||||
# Data preparation (reused from world_model_interaction.py)
|
||||
# ──────────────────────────────────────────────────────────────────────
|
||||
def get_init_frame_path(data_dir, sample):
|
||||
rel = os.path.join(sample['data_dir'], str(sample['videoid']) + '.png')
|
||||
return os.path.join(data_dir, 'images', rel)
|
||||
|
||||
|
||||
def get_transition_path(data_dir, sample):
|
||||
rel = os.path.join(sample['data_dir'], str(sample['videoid']) + '.h5')
|
||||
return os.path.join(data_dir, 'transitions', rel)
|
||||
|
||||
|
||||
def prepare_init_input(start_idx, init_frame_path, transition_dict,
|
||||
frame_stride, wma_data, video_length=16, n_obs_steps=2):
|
||||
indices = [start_idx + frame_stride * i for i in range(video_length)]
|
||||
init_frame = Image.open(init_frame_path).convert('RGB')
|
||||
init_frame = torch.tensor(np.array(init_frame)).unsqueeze(0).permute(3, 0, 1, 2).float()
|
||||
|
||||
if start_idx < n_obs_steps - 1:
|
||||
state_indices = list(range(0, start_idx + 1))
|
||||
states = transition_dict['observation.state'][state_indices, :]
|
||||
num_padding = n_obs_steps - 1 - start_idx
|
||||
padding = states[0:1, :].repeat(num_padding, 1)
|
||||
states = torch.cat((padding, states), dim=0)
|
||||
else:
|
||||
state_indices = list(range(start_idx - n_obs_steps + 1, start_idx + 1))
|
||||
states = transition_dict['observation.state'][state_indices, :]
|
||||
|
||||
actions = transition_dict['action'][indices, :]
|
||||
ori_state_dim = states.shape[-1]
|
||||
ori_action_dim = actions.shape[-1]
|
||||
|
||||
frames_action_state_dict = {
|
||||
'action': actions,
|
||||
'observation.state': states,
|
||||
}
|
||||
frames_action_state_dict = wma_data.normalizer(frames_action_state_dict)
|
||||
frames_action_state_dict = wma_data.get_uni_vec(
|
||||
frames_action_state_dict,
|
||||
transition_dict['action_type'],
|
||||
transition_dict['state_type'],
|
||||
)
|
||||
|
||||
if wma_data.spatial_transform is not None:
|
||||
init_frame = wma_data.spatial_transform(init_frame)
|
||||
init_frame = (init_frame / 255 - 0.5) * 2
|
||||
|
||||
data = {'observation.image': init_frame}
|
||||
data.update(frames_action_state_dict)
|
||||
return data, ori_state_dim, ori_action_dim
|
||||
|
||||
|
||||
def populate_queues(queues, batch):
|
||||
for key in batch:
|
||||
if key not in queues:
|
||||
continue
|
||||
if len(queues[key]) != queues[key].maxlen:
|
||||
while len(queues[key]) != queues[key].maxlen:
|
||||
queues[key].append(batch[key])
|
||||
else:
|
||||
queues[key].append(batch[key])
|
||||
return queues
|
||||
|
||||
|
||||
# ──────────────────────────────────────────────────────────────────────
|
||||
# Instrumented image_guided_synthesis_sim_mode with sub-stage timing
|
||||
# ──────────────────────────────────────────────────────────────────────
|
||||
def get_latent_z(model, videos):
|
||||
b, c, t, h, w = videos.shape
|
||||
x = rearrange(videos, 'b c t h w -> (b t) c h w')
|
||||
vae_dtype = next(model.first_stage_model.parameters()).dtype
|
||||
x = x.to(dtype=vae_dtype)
|
||||
z = model.encode_first_stage(x)
|
||||
z = rearrange(z, '(b t) c h w -> b c t h w', b=b, t=t)
|
||||
return z
|
||||
|
||||
|
||||
def save_results(video, filename, fps=8):
|
||||
video = video.detach().cpu()
|
||||
video = torch.clamp(video.float(), -1., 1.)
|
||||
n = video.shape[0]
|
||||
video = video.permute(2, 0, 1, 3, 4)
|
||||
frame_grids = [
|
||||
torchvision.utils.make_grid(framesheet, nrow=int(n), padding=0)
|
||||
for framesheet in video
|
||||
]
|
||||
grid = torch.stack(frame_grids, dim=0)
|
||||
grid = (grid + 1.0) / 2.0
|
||||
grid = (grid * 255).to(torch.uint8).permute(0, 2, 3, 1)
|
||||
torchvision.io.write_video(filename, grid, fps=fps,
|
||||
video_codec='h264', options={'crf': '10'})
|
||||
|
||||
|
||||
def profiled_synthesis(model, prompts, observation, noise_shape,
|
||||
ddim_steps, ddim_eta, unconditional_guidance_scale,
|
||||
fs, text_input, timestep_spacing, guidance_rescale,
|
||||
sim_mode, decode_video, records, prefix):
|
||||
"""image_guided_synthesis_sim_mode with per-sub-stage CUDA event timing.
|
||||
|
||||
Args:
|
||||
prefix: "policy" or "wm" — prepended to sub-stage names in records.
|
||||
"""
|
||||
b, _, t, _, _ = noise_shape
|
||||
batch_size = noise_shape[0]
|
||||
device = next(model.parameters()).device
|
||||
|
||||
# --- sub-stage: ddim_sampler_init ---
|
||||
with CudaTimer(f"{prefix}/ddim_sampler_init", records):
|
||||
ddim_sampler = DDIMSampler(model)
|
||||
fs_t = torch.tensor([fs] * batch_size, dtype=torch.long, device=device)
|
||||
|
||||
# --- sub-stage: image_embedding ---
|
||||
with CudaTimer(f"{prefix}/image_embedding", records):
|
||||
model_dtype = next(model.embedder.parameters()).dtype
|
||||
img = observation['observation.images.top'].permute(0, 2, 1, 3, 4)
|
||||
cond_img = rearrange(img, 'b o c h w -> (b o) c h w')[-1:].to(dtype=model_dtype)
|
||||
cond_img_emb = model.embedder(cond_img)
|
||||
cond_img_emb = model.image_proj_model(cond_img_emb)
|
||||
|
||||
# --- sub-stage: vae_encode ---
|
||||
with CudaTimer(f"{prefix}/vae_encode", records):
|
||||
if model.model.conditioning_key == 'hybrid':
|
||||
z = get_latent_z(model, img.permute(0, 2, 1, 3, 4))
|
||||
img_cat_cond = z[:, :, -1:, :, :]
|
||||
img_cat_cond = repeat(img_cat_cond,
|
||||
'b c t h w -> b c (repeat t) h w',
|
||||
repeat=noise_shape[2])
|
||||
cond = {"c_concat": [img_cat_cond]}
|
||||
|
||||
# --- sub-stage: text_conditioning ---
|
||||
with CudaTimer(f"{prefix}/text_conditioning", records):
|
||||
if not text_input:
|
||||
prompts_use = [""] * batch_size
|
||||
else:
|
||||
prompts_use = prompts if isinstance(prompts, list) else [prompts] * batch_size
|
||||
cond_ins_emb = model.get_learned_conditioning(prompts_use)
|
||||
|
||||
# --- sub-stage: projectors ---
|
||||
with CudaTimer(f"{prefix}/projectors", records):
|
||||
projector_dtype = next(model.state_projector.parameters()).dtype
|
||||
cond_state_emb = model.state_projector(
|
||||
observation['observation.state'].to(dtype=projector_dtype))
|
||||
cond_state_emb = cond_state_emb + model.agent_state_pos_emb
|
||||
cond_action_emb = model.action_projector(
|
||||
observation['action'].to(dtype=projector_dtype))
|
||||
cond_action_emb = cond_action_emb + model.agent_action_pos_emb
|
||||
if not sim_mode:
|
||||
cond_action_emb = torch.zeros_like(cond_action_emb)
|
||||
|
||||
# --- sub-stage: cond_assembly ---
|
||||
with CudaTimer(f"{prefix}/cond_assembly", records):
|
||||
cond["c_crossattn"] = [
|
||||
torch.cat([cond_state_emb, cond_action_emb, cond_ins_emb, cond_img_emb], dim=1)
|
||||
]
|
||||
cond["c_crossattn_action"] = [
|
||||
observation['observation.images.top'][:, :, -model.n_obs_steps_acting:],
|
||||
observation['observation.state'][:, -model.n_obs_steps_acting:],
|
||||
sim_mode,
|
||||
False,
|
||||
]
|
||||
|
||||
# --- sub-stage: ddim_sampling ---
|
||||
autocast_dtype = getattr(model, 'diffusion_autocast_dtype', None)
|
||||
if autocast_dtype is not None and device.type == 'cuda':
|
||||
autocast_ctx = torch.autocast('cuda', dtype=autocast_dtype)
|
||||
else:
|
||||
autocast_ctx = nullcontext()
|
||||
|
||||
with CudaTimer(f"{prefix}/ddim_sampling", records):
|
||||
with autocast_ctx:
|
||||
samples, actions, states, _ = ddim_sampler.sample(
|
||||
S=ddim_steps,
|
||||
conditioning=cond,
|
||||
batch_size=batch_size,
|
||||
shape=noise_shape[1:],
|
||||
verbose=False,
|
||||
unconditional_guidance_scale=unconditional_guidance_scale,
|
||||
unconditional_conditioning=None,
|
||||
eta=ddim_eta,
|
||||
cfg_img=None,
|
||||
mask=None,
|
||||
x0=None,
|
||||
fs=fs_t,
|
||||
timestep_spacing=timestep_spacing,
|
||||
guidance_rescale=guidance_rescale,
|
||||
unconditional_conditioning_img_nonetext=None,
|
||||
)
|
||||
|
||||
# --- sub-stage: vae_decode ---
|
||||
batch_variants = None
|
||||
if decode_video:
|
||||
with CudaTimer(f"{prefix}/vae_decode", records):
|
||||
batch_variants = model.decode_first_stage(samples)
|
||||
else:
|
||||
records[f"{prefix}/vae_decode"].append(0.0)
|
||||
|
||||
return batch_variants, actions, states
|
||||
|
||||
|
||||
# ──────────────────────────────────────────────────────────────────────
|
||||
# Instrumented iteration loop
|
||||
# ──────────────────────────────────────────────────────────────────────
|
||||
def run_profiled_iterations(model, args, config, noise_shape, device):
|
||||
"""Run the full iteration loop with per-stage timing.
|
||||
|
||||
Returns:
|
||||
all_records: list of dicts, one per itr, {stage_name: ms}
|
||||
"""
|
||||
# Load data
|
||||
csv_path = os.path.join(args.prompt_dir, f"{args.dataset}.csv")
|
||||
df = pd.read_csv(csv_path)
|
||||
sample = df.iloc[0]
|
||||
|
||||
data_module = instantiate_from_config(config.data)
|
||||
data_module.setup()
|
||||
|
||||
init_frame_path = get_init_frame_path(args.prompt_dir, sample)
|
||||
ori_fps = float(sample['fps'])
|
||||
fs = args.frame_stride
|
||||
model_input_fs = ori_fps // fs
|
||||
|
||||
transition_path = get_transition_path(args.prompt_dir, sample)
|
||||
with h5py.File(transition_path, 'r') as h5f:
|
||||
transition_dict = {}
|
||||
for key in h5f.keys():
|
||||
transition_dict[key] = torch.tensor(h5f[key][()])
|
||||
for key in h5f.attrs.keys():
|
||||
transition_dict[key] = h5f.attrs[key]
|
||||
|
||||
# Prepare initial observation
|
||||
batch, ori_state_dim, ori_action_dim = prepare_init_input(
|
||||
0, init_frame_path, transition_dict, fs,
|
||||
data_module.test_datasets[args.dataset],
|
||||
n_obs_steps=model.n_obs_steps_imagen)
|
||||
|
||||
observation = {
|
||||
'observation.images.top':
|
||||
batch['observation.image'].permute(1, 0, 2, 3)[-1].unsqueeze(0),
|
||||
'observation.state':
|
||||
batch['observation.state'][-1].unsqueeze(0),
|
||||
'action':
|
||||
torch.zeros_like(batch['action'][-1]).unsqueeze(0),
|
||||
}
|
||||
observation = {k: v.to(device, non_blocking=True) for k, v in observation.items()}
|
||||
|
||||
cond_obs_queues = {
|
||||
"observation.images.top": deque(maxlen=model.n_obs_steps_imagen),
|
||||
"observation.state": deque(maxlen=model.n_obs_steps_imagen),
|
||||
"action": deque(maxlen=args.video_length),
|
||||
}
|
||||
cond_obs_queues = populate_queues(cond_obs_queues, observation)
|
||||
|
||||
# Temp dir for save_results profiling
|
||||
tmp_dir = os.path.join(args.savedir, "profile_tmp")
|
||||
os.makedirs(tmp_dir, exist_ok=True)
|
||||
|
||||
prompt_text = sample['instruction']
|
||||
all_records = []
|
||||
|
||||
print(f">>> Running {args.n_iter} profiled iterations ...")
|
||||
for itr in range(args.n_iter):
|
||||
rec = defaultdict(list)
|
||||
|
||||
# ── itr_total start ──
|
||||
torch.cuda.synchronize()
|
||||
itr_start = torch.cuda.Event(enable_timing=True)
|
||||
itr_end = torch.cuda.Event(enable_timing=True)
|
||||
itr_start.record()
|
||||
|
||||
# ① stack_to_device_1
|
||||
with CudaTimer("stack_to_device_1", rec):
|
||||
observation = {
|
||||
'observation.images.top':
|
||||
torch.stack(list(cond_obs_queues['observation.images.top']),
|
||||
dim=1).permute(0, 2, 1, 3, 4),
|
||||
'observation.state':
|
||||
torch.stack(list(cond_obs_queues['observation.state']), dim=1),
|
||||
'action':
|
||||
torch.stack(list(cond_obs_queues['action']), dim=1),
|
||||
}
|
||||
observation = {k: v.to(device, non_blocking=True) for k, v in observation.items()}
|
||||
|
||||
# ② synth_policy
|
||||
with CudaTimer("synth_policy", rec):
|
||||
pred_videos_0, pred_actions, _ = profiled_synthesis(
|
||||
model, prompt_text, observation, noise_shape,
|
||||
ddim_steps=args.ddim_steps, ddim_eta=args.ddim_eta,
|
||||
unconditional_guidance_scale=args.unconditional_guidance_scale,
|
||||
fs=model_input_fs, text_input=True,
|
||||
timestep_spacing=args.timestep_spacing,
|
||||
guidance_rescale=args.guidance_rescale,
|
||||
sim_mode=False,
|
||||
decode_video=not args.fast_policy_no_decode,
|
||||
records=rec, prefix="policy")
|
||||
|
||||
# ③ update_action_queue
|
||||
with WallTimer("update_action_queue", rec):
|
||||
for idx in range(len(pred_actions[0])):
|
||||
obs_a = {'action': pred_actions[0][idx:idx + 1]}
|
||||
obs_a['action'][:, ori_action_dim:] = 0.0
|
||||
cond_obs_queues = populate_queues(cond_obs_queues, obs_a)
|
||||
|
||||
# ④ stack_to_device_2
|
||||
with CudaTimer("stack_to_device_2", rec):
|
||||
observation = {
|
||||
'observation.images.top':
|
||||
torch.stack(list(cond_obs_queues['observation.images.top']),
|
||||
dim=1).permute(0, 2, 1, 3, 4),
|
||||
'observation.state':
|
||||
torch.stack(list(cond_obs_queues['observation.state']), dim=1),
|
||||
'action':
|
||||
torch.stack(list(cond_obs_queues['action']), dim=1),
|
||||
}
|
||||
observation = {k: v.to(device, non_blocking=True) for k, v in observation.items()}
|
||||
|
||||
# ⑤ synth_world_model
|
||||
with CudaTimer("synth_world_model", rec):
|
||||
pred_videos_1, _, pred_states = profiled_synthesis(
|
||||
model, "", observation, noise_shape,
|
||||
ddim_steps=args.ddim_steps, ddim_eta=args.ddim_eta,
|
||||
unconditional_guidance_scale=args.unconditional_guidance_scale,
|
||||
fs=model_input_fs, text_input=False,
|
||||
timestep_spacing=args.timestep_spacing,
|
||||
guidance_rescale=args.guidance_rescale,
|
||||
sim_mode=True, decode_video=True,
|
||||
records=rec, prefix="wm")
|
||||
|
||||
# ⑥ update_obs_queue
|
||||
with WallTimer("update_obs_queue", rec):
|
||||
for idx in range(args.exe_steps):
|
||||
obs_u = {
|
||||
'observation.images.top':
|
||||
pred_videos_1[0][:, idx:idx + 1].permute(1, 0, 2, 3),
|
||||
'observation.state':
|
||||
pred_states[0][idx:idx + 1],
|
||||
'action':
|
||||
torch.zeros_like(pred_actions[0][-1:]),
|
||||
}
|
||||
obs_u['observation.state'][:, ori_state_dim:] = 0.0
|
||||
cond_obs_queues = populate_queues(cond_obs_queues, obs_u)
|
||||
|
||||
# ⑦ tensorboard_log (simulate — no actual writer, measure make_grid cost)
|
||||
with WallTimer("tensorboard_log", rec):
|
||||
for vid in [pred_videos_0, pred_videos_1]:
|
||||
if vid is not None and vid.dim() == 5:
|
||||
v = vid.permute(2, 0, 1, 3, 4)
|
||||
grids = [torchvision.utils.make_grid(f, nrow=1, padding=0) for f in v]
|
||||
_ = torch.stack(grids, dim=0)
|
||||
|
||||
# ⑧ save_results
|
||||
with WallTimer("save_results", rec):
|
||||
if pred_videos_0 is not None:
|
||||
save_results(pred_videos_0.cpu(),
|
||||
os.path.join(tmp_dir, f"dm_{itr}.mp4"),
|
||||
fps=args.save_fps)
|
||||
save_results(pred_videos_1.cpu(),
|
||||
os.path.join(tmp_dir, f"wm_{itr}.mp4"),
|
||||
fps=args.save_fps)
|
||||
|
||||
# ⑨ cpu_transfer
|
||||
with CudaTimer("cpu_transfer", rec):
|
||||
_ = pred_videos_1[:, :, :args.exe_steps].cpu()
|
||||
|
||||
# ── itr_total end ──
|
||||
itr_end.record()
|
||||
torch.cuda.synchronize()
|
||||
itr_total_ms = itr_start.elapsed_time(itr_end)
|
||||
rec["itr_total"].append(itr_total_ms)
|
||||
|
||||
# Flatten: each stage has exactly one entry per itr
|
||||
itr_rec = {k: v[0] for k, v in rec.items()}
|
||||
all_records.append(itr_rec)
|
||||
|
||||
# Print live progress
|
||||
print(f" itr {itr}: {itr_total_ms:.0f} ms total | "
|
||||
f"policy={itr_rec.get('synth_policy', 0):.0f} | "
|
||||
f"wm={itr_rec.get('synth_world_model', 0):.0f} | "
|
||||
f"save={itr_rec.get('save_results', 0):.0f} | "
|
||||
f"tb={itr_rec.get('tensorboard_log', 0):.0f}")
|
||||
|
||||
return all_records
|
||||
|
||||
|
||||
# ──────────────────────────────────────────────────────────────────────
|
||||
# Layer 1: Console report
|
||||
# ──────────────────────────────────────────────────────────────────────
|
||||
def print_iteration_report(all_records, warmup=1):
|
||||
"""Print a structured table of per-stage timing across iterations."""
|
||||
if len(all_records) <= warmup:
|
||||
records = all_records
|
||||
else:
|
||||
records = all_records[warmup:]
|
||||
print(f"\n(Skipping first {warmup} itr(s) as warmup)\n")
|
||||
|
||||
# Collect all stage keys in a stable order
|
||||
all_keys = []
|
||||
seen = set()
|
||||
for rec in records:
|
||||
for k in rec:
|
||||
if k not in seen:
|
||||
all_keys.append(k)
|
||||
seen.add(k)
|
||||
|
||||
# Separate top-level stages from sub-stages
|
||||
top_keys = [k for k in all_keys if '/' not in k]
|
||||
sub_keys = [k for k in all_keys if '/' in k]
|
||||
|
||||
def _print_table(keys, title):
|
||||
if not keys:
|
||||
return
|
||||
print("=" * 82)
|
||||
print(title)
|
||||
print("=" * 82)
|
||||
print(f"{'Stage':<35} {'Mean(ms)':>10} {'Std':>8} {'Min':>10} {'Max':>10} {'%':>7}")
|
||||
print("-" * 82)
|
||||
|
||||
total_mean = np.mean([rec.get("itr_total", 0) for rec in records])
|
||||
for k in keys:
|
||||
vals = [rec.get(k, 0) for rec in records]
|
||||
mean = np.mean(vals)
|
||||
std = np.std(vals)
|
||||
mn = np.min(vals)
|
||||
mx = np.max(vals)
|
||||
pct = mean / total_mean * 100 if total_mean > 0 else 0
|
||||
print(f"{k:<35} {mean:>10.1f} {std:>8.1f} {mn:>10.1f} {mx:>10.1f} {pct:>6.1f}%")
|
||||
print("-" * 82)
|
||||
print()
|
||||
|
||||
_print_table(top_keys, "TABLE 1: ITERATION-LEVEL BREAKDOWN")
|
||||
_print_table(sub_keys, "TABLE 2: SYNTHESIS SUB-STAGE BREAKDOWN")
|
||||
|
||||
|
||||
# ──────────────────────────────────────────────────────────────────────
|
||||
# Layer 3: CSV output for A/B comparison
|
||||
# ──────────────────────────────────────────────────────────────────────
|
||||
def write_csv(all_records, csv_path, warmup=1):
|
||||
"""Write per-iteration timing to CSV for later comparison."""
|
||||
records = all_records[warmup:] if len(all_records) > warmup else all_records
|
||||
|
||||
# Collect all keys
|
||||
all_keys = []
|
||||
seen = set()
|
||||
for rec in records:
|
||||
for k in rec:
|
||||
if k not in seen:
|
||||
all_keys.append(k)
|
||||
seen.add(k)
|
||||
|
||||
with open(csv_path, 'w', newline='') as f:
|
||||
writer = csv.DictWriter(f, fieldnames=['itr'] + all_keys)
|
||||
writer.writeheader()
|
||||
for i, rec in enumerate(records):
|
||||
row = {'itr': i}
|
||||
row.update({k: f"{rec.get(k, 0):.2f}" for k in all_keys})
|
||||
writer.writerow(row)
|
||||
|
||||
# Also write a summary row
|
||||
summary_path = csv_path.replace('.csv', '_summary.csv')
|
||||
with open(summary_path, 'w', newline='') as f:
|
||||
writer = csv.DictWriter(f, fieldnames=['stat'] + all_keys)
|
||||
writer.writeheader()
|
||||
for stat_name, stat_fn in [('mean', np.mean), ('std', np.std),
|
||||
('min', np.min), ('max', np.max)]:
|
||||
row = {'stat': stat_name}
|
||||
row.update({k: f"{stat_fn([r.get(k, 0) for r in records]):.2f}"
|
||||
for k in all_keys})
|
||||
writer.writerow(row)
|
||||
|
||||
print(f">>> CSV written to: {csv_path}")
|
||||
print(f">>> Summary written to: {summary_path}")
|
||||
|
||||
|
||||
def compare_csvs(path_a, path_b):
|
||||
"""Compare two summary CSVs and print a diff table."""
|
||||
df_a = pd.read_csv(path_a, index_col='stat')
|
||||
df_b = pd.read_csv(path_b, index_col='stat')
|
||||
|
||||
# Use mean row for comparison
|
||||
mean_a = df_a.loc['mean'].astype(float)
|
||||
mean_b = df_b.loc['mean'].astype(float)
|
||||
|
||||
print("=" * 90)
|
||||
print(f"A/B COMPARISON: {os.path.basename(path_a)} vs {os.path.basename(path_b)}")
|
||||
print("=" * 90)
|
||||
print(f"{'Stage':<35} {'A(ms)':>10} {'B(ms)':>10} {'Diff':>10} {'Speedup':>10}")
|
||||
print("-" * 90)
|
||||
|
||||
for col in mean_a.index:
|
||||
if col not in mean_b.index:
|
||||
continue
|
||||
a_val = mean_a[col]
|
||||
b_val = mean_b[col]
|
||||
diff = b_val - a_val
|
||||
speedup = a_val / b_val if b_val > 0 else float('inf')
|
||||
marker = " <<<" if abs(diff) > 50 else ""
|
||||
print(f"{col:<35} {a_val:>10.1f} {b_val:>10.1f} {diff:>+10.1f} {speedup:>9.2f}x{marker}")
|
||||
|
||||
print("-" * 90)
|
||||
total_a = mean_a.get('itr_total', 0)
|
||||
total_b = mean_b.get('itr_total', 0)
|
||||
print(f"{'itr_total':<35} {total_a:>10.1f} {total_b:>10.1f} "
|
||||
f"{total_b - total_a:>+10.1f} {total_a / total_b if total_b > 0 else 0:>9.2f}x")
|
||||
print()
|
||||
|
||||
|
||||
# ──────────────────────────────────────────────────────────────────────
|
||||
# Layer 2: GPU timeline trace wrapper
|
||||
# ──────────────────────────────────────────────────────────────────────
|
||||
def run_with_trace(model, args, config, noise_shape, device):
|
||||
"""Run iterations under torch.profiler to generate Chrome/TensorBoard traces."""
|
||||
trace_dir = args.trace_dir
|
||||
os.makedirs(trace_dir, exist_ok=True)
|
||||
|
||||
# We need the same data setup as run_profiled_iterations
|
||||
csv_path = os.path.join(args.prompt_dir, f"{args.dataset}.csv")
|
||||
df = pd.read_csv(csv_path)
|
||||
sample = df.iloc[0]
|
||||
|
||||
data_module = instantiate_from_config(config.data)
|
||||
data_module.setup()
|
||||
|
||||
init_frame_path = get_init_frame_path(args.prompt_dir, sample)
|
||||
ori_fps = float(sample['fps'])
|
||||
fs = args.frame_stride
|
||||
model_input_fs = ori_fps // fs
|
||||
|
||||
transition_path = get_transition_path(args.prompt_dir, sample)
|
||||
with h5py.File(transition_path, 'r') as h5f:
|
||||
transition_dict = {}
|
||||
for key in h5f.keys():
|
||||
transition_dict[key] = torch.tensor(h5f[key][()])
|
||||
for key in h5f.attrs.keys():
|
||||
transition_dict[key] = h5f.attrs[key]
|
||||
|
||||
batch, ori_state_dim, ori_action_dim = prepare_init_input(
|
||||
0, init_frame_path, transition_dict, fs,
|
||||
data_module.test_datasets[args.dataset],
|
||||
n_obs_steps=model.n_obs_steps_imagen)
|
||||
|
||||
observation = {
|
||||
'observation.images.top':
|
||||
batch['observation.image'].permute(1, 0, 2, 3)[-1].unsqueeze(0),
|
||||
'observation.state':
|
||||
batch['observation.state'][-1].unsqueeze(0),
|
||||
'action':
|
||||
torch.zeros_like(batch['action'][-1]).unsqueeze(0),
|
||||
}
|
||||
observation = {k: v.to(device, non_blocking=True) for k, v in observation.items()}
|
||||
|
||||
cond_obs_queues = {
|
||||
"observation.images.top": deque(maxlen=model.n_obs_steps_imagen),
|
||||
"observation.state": deque(maxlen=model.n_obs_steps_imagen),
|
||||
"action": deque(maxlen=args.video_length),
|
||||
}
|
||||
cond_obs_queues = populate_queues(cond_obs_queues, observation)
|
||||
|
||||
tmp_dir = os.path.join(args.savedir, "profile_tmp")
|
||||
os.makedirs(tmp_dir, exist_ok=True)
|
||||
prompt_text = sample['instruction']
|
||||
|
||||
# Total iterations: warmup + active
|
||||
n_warmup = 1
|
||||
n_active = min(args.n_iter, 2) # trace 2 active iterations max
|
||||
n_total = n_warmup + n_active
|
||||
|
||||
print(f">>> GPU trace: {n_warmup} warmup + {n_active} active iterations")
|
||||
print(f">>> Trace output: {trace_dir}")
|
||||
|
||||
with torch.no_grad(), torch.profiler.profile(
|
||||
activities=[
|
||||
torch.profiler.ProfilerActivity.CPU,
|
||||
torch.profiler.ProfilerActivity.CUDA,
|
||||
],
|
||||
schedule=torch.profiler.schedule(
|
||||
wait=0, warmup=n_warmup, active=n_active, repeat=1),
|
||||
on_trace_ready=torch.profiler.tensorboard_trace_handler(trace_dir),
|
||||
record_shapes=True,
|
||||
with_stack=True,
|
||||
) as prof:
|
||||
for itr_idx in range(n_total):
|
||||
phase = "warmup" if itr_idx < n_warmup else "active"
|
||||
print(f" trace itr {itr_idx} ({phase})...")
|
||||
|
||||
# ── One full iteration (same logic as run_inference) ──
|
||||
obs_loc = {
|
||||
'observation.images.top':
|
||||
torch.stack(list(cond_obs_queues['observation.images.top']),
|
||||
dim=1).permute(0, 2, 1, 3, 4),
|
||||
'observation.state':
|
||||
torch.stack(list(cond_obs_queues['observation.state']), dim=1),
|
||||
'action':
|
||||
torch.stack(list(cond_obs_queues['action']), dim=1),
|
||||
}
|
||||
obs_loc = {k: v.to(device) for k, v in obs_loc.items()}
|
||||
|
||||
# Policy pass
|
||||
dummy_rec = defaultdict(list)
|
||||
pv0, pa, _ = profiled_synthesis(
|
||||
model, prompt_text, obs_loc, noise_shape,
|
||||
ddim_steps=args.ddim_steps, ddim_eta=args.ddim_eta,
|
||||
unconditional_guidance_scale=args.unconditional_guidance_scale,
|
||||
fs=model_input_fs, text_input=True,
|
||||
timestep_spacing=args.timestep_spacing,
|
||||
guidance_rescale=args.guidance_rescale,
|
||||
sim_mode=False,
|
||||
decode_video=not args.fast_policy_no_decode,
|
||||
records=dummy_rec, prefix="policy")
|
||||
|
||||
for idx in range(len(pa[0])):
|
||||
oa = {'action': pa[0][idx:idx + 1]}
|
||||
oa['action'][:, ori_action_dim:] = 0.0
|
||||
populate_queues(cond_obs_queues, oa)
|
||||
|
||||
# Re-stack for world model
|
||||
obs_loc2 = {
|
||||
'observation.images.top':
|
||||
torch.stack(list(cond_obs_queues['observation.images.top']),
|
||||
dim=1).permute(0, 2, 1, 3, 4),
|
||||
'observation.state':
|
||||
torch.stack(list(cond_obs_queues['observation.state']), dim=1),
|
||||
'action':
|
||||
torch.stack(list(cond_obs_queues['action']), dim=1),
|
||||
}
|
||||
obs_loc2 = {k: v.to(device) for k, v in obs_loc2.items()}
|
||||
|
||||
# World model pass
|
||||
pv1, _, ps = profiled_synthesis(
|
||||
model, "", obs_loc2, noise_shape,
|
||||
ddim_steps=args.ddim_steps, ddim_eta=args.ddim_eta,
|
||||
unconditional_guidance_scale=args.unconditional_guidance_scale,
|
||||
fs=model_input_fs, text_input=False,
|
||||
timestep_spacing=args.timestep_spacing,
|
||||
guidance_rescale=args.guidance_rescale,
|
||||
sim_mode=True, decode_video=True,
|
||||
records=dummy_rec, prefix="wm")
|
||||
|
||||
# Update obs queue
|
||||
for idx in range(args.exe_steps):
|
||||
ou = {
|
||||
'observation.images.top':
|
||||
pv1[0][:, idx:idx + 1].permute(1, 0, 2, 3),
|
||||
'observation.state': ps[0][idx:idx + 1],
|
||||
'action': torch.zeros_like(pa[0][-1:]),
|
||||
}
|
||||
ou['observation.state'][:, ori_state_dim:] = 0.0
|
||||
populate_queues(cond_obs_queues, ou)
|
||||
|
||||
# Save results (captures CPU stall in trace)
|
||||
if pv0 is not None:
|
||||
save_results(pv0.cpu(),
|
||||
os.path.join(tmp_dir, f"trace_dm_{itr_idx}.mp4"),
|
||||
fps=args.save_fps)
|
||||
save_results(pv1.cpu(),
|
||||
os.path.join(tmp_dir, f"trace_wm_{itr_idx}.mp4"),
|
||||
fps=args.save_fps)
|
||||
|
||||
prof.step()
|
||||
|
||||
print(f">>> Trace saved to {trace_dir}")
|
||||
print(" View with: tensorboard --logdir", trace_dir)
|
||||
print(" Or open the .json file in chrome://tracing")
|
||||
|
||||
|
||||
# ──────────────────────────────────────────────────────────────────────
|
||||
# Argument parser
|
||||
# ──────────────────────────────────────────────────────────────────────
|
||||
def get_parser():
|
||||
p = argparse.ArgumentParser(description="Profile full iteration loop")
|
||||
|
||||
# Compare mode (no model needed)
|
||||
p.add_argument("--compare", nargs=2, metavar=("A_SUMMARY", "B_SUMMARY"),
|
||||
help="Compare two summary CSVs and exit")
|
||||
|
||||
# Model / data
|
||||
p.add_argument("--ckpt_path", type=str, default=None)
|
||||
p.add_argument("--config", type=str, default=None)
|
||||
p.add_argument("--prompt_dir", type=str, default=None)
|
||||
p.add_argument("--dataset", type=str, default=None)
|
||||
p.add_argument("--savedir", type=str, default="profile_output")
|
||||
|
||||
# Inference params (match world_model_interaction.py)
|
||||
p.add_argument("--ddim_steps", type=int, default=50)
|
||||
p.add_argument("--ddim_eta", type=float, default=1.0)
|
||||
p.add_argument("--bs", type=int, default=1)
|
||||
p.add_argument("--height", type=int, default=320)
|
||||
p.add_argument("--width", type=int, default=512)
|
||||
p.add_argument("--frame_stride", type=int, default=4)
|
||||
p.add_argument("--unconditional_guidance_scale", type=float, default=1.0)
|
||||
p.add_argument("--video_length", type=int, default=16)
|
||||
p.add_argument("--timestep_spacing", type=str, default="uniform_trailing")
|
||||
p.add_argument("--guidance_rescale", type=float, default=0.7)
|
||||
p.add_argument("--exe_steps", type=int, default=16)
|
||||
p.add_argument("--n_iter", type=int, default=5)
|
||||
p.add_argument("--save_fps", type=int, default=8)
|
||||
p.add_argument("--seed", type=int, default=123)
|
||||
p.add_argument("--perframe_ae", action='store_true', default=False)
|
||||
p.add_argument("--vae_dtype", type=str, choices=["fp32", "bf16"], default="bf16")
|
||||
p.add_argument("--fast_policy_no_decode", action='store_true', default=False)
|
||||
|
||||
# Profiling control
|
||||
p.add_argument("--warmup", type=int, default=1,
|
||||
help="Number of warmup iterations to skip in statistics")
|
||||
p.add_argument("--csv", type=str, default=None,
|
||||
help="Write per-iteration timing to this CSV file")
|
||||
p.add_argument("--trace", action='store_true', default=False,
|
||||
help="Enable Layer 2: GPU timeline trace")
|
||||
p.add_argument("--trace_dir", type=str, default="./profile_traces",
|
||||
help="Directory for trace output")
|
||||
|
||||
return p
|
||||
|
||||
|
||||
# ──────────────────────────────────────────────────────────────────────
|
||||
# Main
|
||||
# ──────────────────────────────────────────────────────────────────────
|
||||
def main():
|
||||
patch_norm_bypass_autocast()
|
||||
parser = get_parser()
|
||||
args = parser.parse_args()
|
||||
|
||||
# ── Compare mode: no model needed ──
|
||||
if args.compare:
|
||||
compare_csvs(args.compare[0], args.compare[1])
|
||||
return
|
||||
|
||||
# ── Validate required args ──
|
||||
for required in ['ckpt_path', 'config', 'prompt_dir', 'dataset']:
|
||||
if getattr(args, required) is None:
|
||||
parser.error(f"--{required} is required for profiling mode")
|
||||
|
||||
seed_everything(args.seed)
|
||||
os.makedirs(args.savedir, exist_ok=True)
|
||||
|
||||
# ── Load model ──
|
||||
print("=" * 60)
|
||||
print("PROFILE ITERATION — Loading model...")
|
||||
print("=" * 60)
|
||||
model, config = load_model(args)
|
||||
device = next(model.parameters()).device
|
||||
|
||||
h, w = args.height // 8, args.width // 8
|
||||
channels = model.model.diffusion_model.out_channels
|
||||
noise_shape = [args.bs, channels, args.video_length, h, w]
|
||||
print(f">>> Noise shape: {noise_shape}")
|
||||
print(f">>> DDIM steps: {args.ddim_steps}")
|
||||
print(f">>> fast_policy_no_decode: {args.fast_policy_no_decode}")
|
||||
|
||||
# ── Layer 2: GPU trace (optional) ──
|
||||
if args.trace:
|
||||
with torch.no_grad():
|
||||
run_with_trace(model, args, config, noise_shape, device)
|
||||
print()
|
||||
|
||||
# ── Layer 1: Iteration-level breakdown ──
|
||||
print("=" * 60)
|
||||
print("LAYER 1: ITERATION-LEVEL PROFILING")
|
||||
print("=" * 60)
|
||||
with torch.no_grad():
|
||||
all_records = run_profiled_iterations(
|
||||
model, args, config, noise_shape, device)
|
||||
|
||||
# Print report
|
||||
print_iteration_report(all_records, warmup=args.warmup)
|
||||
|
||||
# ── Layer 3: CSV output for A/B comparison ──
|
||||
if args.csv:
|
||||
write_csv(all_records, args.csv, warmup=args.warmup)
|
||||
|
||||
print("Done.")
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
||||
733
scripts/evaluation/profile_pipeline.py
Normal file
733
scripts/evaluation/profile_pipeline.py
Normal file
@@ -0,0 +1,733 @@
|
||||
"""
|
||||
Profile the full inference pipeline of the world model, covering all 7 stages:
|
||||
1. Image Embedding
|
||||
2. VAE Encode
|
||||
3. Text Conditioning
|
||||
4. State/Action Projectors
|
||||
5. DDIM Loop
|
||||
6. VAE Decode
|
||||
7. Post-process
|
||||
|
||||
Reports stage-level timing, UNet sub-module breakdown, memory summary,
|
||||
and throughput analysis.
|
||||
TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL=1 CUDA_VISIBLE_DEVICES=0 python scripts/evaluation/profile_pipeline.py --ckpt_path ckpts/unifolm_wma_dual_mix_bf16.ckpt --config configs/inference/world_model_interaction.yaml --ddim_steps 50 --cfg_scale 1.0 --n_runs 3 --deep
|
||||
Usage:
|
||||
TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL=1 CUDA_VISIBLE_DEVICES=0 python scripts/evaluation/profile_pipeline.py --ckpt_path ckpts/unifolm_wma_dual_mix_bf16.ckpt --config configs/inference/world_model_interaction.yaml --ddim_steps 50 --cfg_scale 1.0 --n_runs 3
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
import torch.utils.checkpoint # must be loaded before unifolm_wma.utils.common
|
||||
from contextlib import nullcontext, contextmanager
|
||||
from collections import defaultdict
|
||||
from omegaconf import OmegaConf
|
||||
from einops import rearrange, repeat
|
||||
|
||||
from unifolm_wma.utils.utils import instantiate_from_config
|
||||
from unifolm_wma.models.samplers.ddim import DDIMSampler
|
||||
from unifolm_wma.modules.attention import (
|
||||
SpatialTransformer, TemporalTransformer,
|
||||
BasicTransformerBlock, CrossAttention, FeedForward,
|
||||
)
|
||||
from unifolm_wma.modules.networks.wma_model import ResBlock
|
||||
from unifolm_wma.models.diffusion_head.conditional_unet1d import ConditionalUnet1D
|
||||
|
||||
# --- W7900D theoretical peak ---
|
||||
PEAK_BF16_TFLOPS = 61.0
|
||||
MEM_BW_GBS = 864.0
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Utility: patch norms to bypass autocast fp32 promotion
|
||||
# ---------------------------------------------------------------------------
|
||||
def patch_norm_bypass_autocast():
|
||||
"""Monkey-patch GroupNorm and LayerNorm to bypass autocast's fp32 policy."""
|
||||
|
||||
def _group_norm_forward(self, x):
|
||||
with torch.amp.autocast('cuda', enabled=False):
|
||||
return F.group_norm(
|
||||
x, self.num_groups,
|
||||
self.weight.to(x.dtype) if self.weight is not None else None,
|
||||
self.bias.to(x.dtype) if self.bias is not None else None,
|
||||
self.eps)
|
||||
|
||||
def _layer_norm_forward(self, x):
|
||||
with torch.amp.autocast('cuda', enabled=False):
|
||||
return F.layer_norm(
|
||||
x, self.normalized_shape,
|
||||
self.weight.to(x.dtype) if self.weight is not None else None,
|
||||
self.bias.to(x.dtype) if self.bias is not None else None,
|
||||
self.eps)
|
||||
|
||||
torch.nn.GroupNorm.forward = _group_norm_forward
|
||||
torch.nn.LayerNorm.forward = _layer_norm_forward
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Utility: torch.compile hot ResBlocks
|
||||
# ---------------------------------------------------------------------------
|
||||
def apply_torch_compile(model, hot_indices=(5, 8, 9)):
|
||||
"""Compile ResBlock._forward in the hottest output_blocks for operator fusion."""
|
||||
unet = model.model.diffusion_model
|
||||
compiled = 0
|
||||
for idx in hot_indices:
|
||||
block = unet.output_blocks[idx]
|
||||
for layer in block:
|
||||
if isinstance(layer, ResBlock):
|
||||
layer._forward = torch.compile(layer._forward, mode="default")
|
||||
compiled += 1
|
||||
print(f" torch.compile: {compiled} ResBlocks in output_blocks{list(hot_indices)}")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Model loading
|
||||
# ---------------------------------------------------------------------------
|
||||
def load_model(args):
|
||||
config = OmegaConf.load(args.config)
|
||||
config['model']['params']['wma_config']['params']['use_checkpoint'] = False
|
||||
model = instantiate_from_config(config.model)
|
||||
|
||||
state_dict = torch.load(args.ckpt_path, map_location="cpu")
|
||||
if "state_dict" in state_dict:
|
||||
state_dict = state_dict["state_dict"]
|
||||
model.load_state_dict(state_dict, strict=True)
|
||||
|
||||
model.eval()
|
||||
model.model.to(torch.bfloat16)
|
||||
model.diffusion_autocast_dtype = torch.bfloat16
|
||||
apply_torch_compile(model)
|
||||
model = model.cuda()
|
||||
return model
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# CudaTimer — precise GPU timing via CUDA events
|
||||
# ---------------------------------------------------------------------------
|
||||
class CudaTimer:
|
||||
"""Context manager for GPU-precise stage timing using CUDA events."""
|
||||
|
||||
def __init__(self, name, records):
|
||||
self.name = name
|
||||
self.records = records
|
||||
self.start = torch.cuda.Event(enable_timing=True)
|
||||
self.end = torch.cuda.Event(enable_timing=True)
|
||||
|
||||
def __enter__(self):
|
||||
torch.cuda.synchronize()
|
||||
self.start.record()
|
||||
return self
|
||||
|
||||
def __exit__(self, *args):
|
||||
self.end.record()
|
||||
torch.cuda.synchronize()
|
||||
elapsed = self.start.elapsed_time(self.end)
|
||||
self.records[self.name].append(elapsed)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# HookProfiler — sub-module level timing inside UNet via hooks
|
||||
# ---------------------------------------------------------------------------
|
||||
class HookProfiler:
|
||||
"""Register forward hooks on UNet sub-modules to collect per-call timing."""
|
||||
|
||||
# Coarse-grained targets (original)
|
||||
COARSE_CLASSES = (
|
||||
SpatialTransformer,
|
||||
TemporalTransformer,
|
||||
ResBlock,
|
||||
ConditionalUnet1D,
|
||||
)
|
||||
|
||||
# Fine-grained targets for deep DDIM analysis
|
||||
FINE_CLASSES = (
|
||||
CrossAttention,
|
||||
FeedForward,
|
||||
)
|
||||
|
||||
def __init__(self, unet, deep=False):
|
||||
self.unet = unet
|
||||
self.deep = deep
|
||||
self.handles = []
|
||||
# per-instance data: {instance_id: [(start_event, end_event), ...]}
|
||||
self._events = defaultdict(list)
|
||||
# tag mapping: {instance_id: (class_name, module_name)}
|
||||
self._tags = {}
|
||||
# block location: {instance_id: block_location_str}
|
||||
self._block_loc = {}
|
||||
|
||||
@staticmethod
|
||||
def _get_block_location(name):
|
||||
"""Derive UNet block location from module name, e.g. 'input_blocks.3.1'."""
|
||||
parts = name.split('.')
|
||||
if len(parts) >= 2 and parts[0] == 'input_blocks':
|
||||
return f"input_blocks.{parts[1]}"
|
||||
elif len(parts) >= 1 and parts[0] == 'middle_block':
|
||||
return "middle_block"
|
||||
elif len(parts) >= 2 and parts[0] == 'output_blocks':
|
||||
return f"output_blocks.{parts[1]}"
|
||||
elif 'action_unet' in name:
|
||||
return "action_unet"
|
||||
elif 'state_unet' in name:
|
||||
return "state_unet"
|
||||
elif name == 'out' or name.startswith('out.'):
|
||||
return "out"
|
||||
return "other"
|
||||
|
||||
def register(self):
|
||||
"""Attach pre/post forward hooks to target sub-modules + unet.out."""
|
||||
target_classes = self.COARSE_CLASSES
|
||||
if self.deep:
|
||||
target_classes = target_classes + self.FINE_CLASSES
|
||||
|
||||
for name, mod in self.unet.named_modules():
|
||||
if isinstance(mod, target_classes):
|
||||
tag = type(mod).__name__
|
||||
inst_id = id(mod)
|
||||
self._tags[inst_id] = (tag, name)
|
||||
self._block_loc[inst_id] = self._get_block_location(name)
|
||||
self.handles.append(
|
||||
mod.register_forward_pre_hook(self._make_pre_hook(inst_id)))
|
||||
self.handles.append(
|
||||
mod.register_forward_hook(self._make_post_hook(inst_id)))
|
||||
|
||||
# Also hook unet.out (nn.Sequential)
|
||||
out_mod = self.unet.out
|
||||
inst_id = id(out_mod)
|
||||
self._tags[inst_id] = ("UNet.out", "out")
|
||||
self._block_loc[inst_id] = "out"
|
||||
self.handles.append(
|
||||
out_mod.register_forward_pre_hook(self._make_pre_hook(inst_id)))
|
||||
self.handles.append(
|
||||
out_mod.register_forward_hook(self._make_post_hook(inst_id)))
|
||||
|
||||
def _make_pre_hook(self, inst_id):
|
||||
events = self._events
|
||||
|
||||
def hook(module, input):
|
||||
start = torch.cuda.Event(enable_timing=True)
|
||||
start.record()
|
||||
events[inst_id].append([start, None])
|
||||
return hook
|
||||
|
||||
def _make_post_hook(self, inst_id):
|
||||
events = self._events
|
||||
|
||||
def hook(module, input, output):
|
||||
end = torch.cuda.Event(enable_timing=True)
|
||||
end.record()
|
||||
events[inst_id][-1][1] = end
|
||||
return hook
|
||||
|
||||
def reset(self):
|
||||
"""Clear collected events for a fresh run."""
|
||||
self._events.clear()
|
||||
|
||||
def synchronize_and_collect(self):
|
||||
"""Sync GPU and compute elapsed times. Returns (by_type, by_instance, by_block)."""
|
||||
torch.cuda.synchronize()
|
||||
by_type = defaultdict(lambda: {"total_ms": 0.0, "count": 0, "calls": []})
|
||||
by_instance = {}
|
||||
# by_block: {block_loc: {tag: {"total_ms", "count"}}}
|
||||
by_block = defaultdict(lambda: defaultdict(lambda: {"total_ms": 0.0, "count": 0}))
|
||||
|
||||
for inst_id, pairs in self._events.items():
|
||||
tag, mod_name = self._tags[inst_id]
|
||||
block_loc = self._block_loc.get(inst_id, "other")
|
||||
inst_times = []
|
||||
for start_evt, end_evt in pairs:
|
||||
if end_evt is not None:
|
||||
ms = start_evt.elapsed_time(end_evt)
|
||||
inst_times.append(ms)
|
||||
by_type[tag]["total_ms"] += ms
|
||||
by_type[tag]["count"] += 1
|
||||
by_type[tag]["calls"].append(ms)
|
||||
by_block[block_loc][tag]["total_ms"] += ms
|
||||
by_block[block_loc][tag]["count"] += 1
|
||||
by_instance[(tag, mod_name)] = inst_times
|
||||
|
||||
return dict(by_type), by_instance, dict(by_block)
|
||||
|
||||
def remove(self):
|
||||
"""Remove all hooks."""
|
||||
for h in self.handles:
|
||||
h.remove()
|
||||
self.handles.clear()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Build dummy inputs matching the pipeline's expected shapes
|
||||
# ---------------------------------------------------------------------------
|
||||
def build_dummy_inputs(model, noise_shape):
|
||||
"""Create synthetic observation dict and prompts for profiling."""
|
||||
device = next(model.parameters()).device
|
||||
B, C, T, H, W = noise_shape
|
||||
dtype = torch.bfloat16
|
||||
|
||||
# observation.images.top: [B, C, O, H, W] (permuted to [B,O,C,H,W] inside pipeline)
|
||||
O = 2
|
||||
obs_images = torch.randn(B, 3, O, 320, 512, device=device, dtype=dtype)
|
||||
obs_state = torch.randn(B, O, 16, device=device, dtype=dtype)
|
||||
action = torch.randn(B, 16, 16, device=device, dtype=dtype)
|
||||
|
||||
observation = {
|
||||
'observation.images.top': obs_images,
|
||||
'observation.state': obs_state,
|
||||
'action': action,
|
||||
}
|
||||
prompts = ["a robot arm performing a task"] * B
|
||||
return observation, prompts
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Run one full pipeline pass with per-stage timing
|
||||
# ---------------------------------------------------------------------------
|
||||
def run_pipeline(model, observation, prompts, noise_shape, ddim_steps,
|
||||
cfg_scale, hook_profiler):
|
||||
"""Execute the full 7-stage pipeline, returning per-stage timing dict."""
|
||||
records = defaultdict(list)
|
||||
device = next(model.parameters()).device
|
||||
B, C, T, H, W = noise_shape
|
||||
dtype = torch.bfloat16
|
||||
|
||||
fs = torch.tensor([1] * B, dtype=torch.long, device=device)
|
||||
|
||||
# --- Stage 1: Image Embedding ---
|
||||
with CudaTimer("1_Image_Embedding", records):
|
||||
img = observation['observation.images.top'].permute(0, 2, 1, 3, 4)
|
||||
cond_img = rearrange(img, 'b o c h w -> (b o) c h w')[-1:].to(dtype=dtype)
|
||||
with torch.autocast('cuda', dtype=torch.bfloat16):
|
||||
cond_img_emb = model.embedder(cond_img)
|
||||
cond_img_emb = model.image_proj_model(cond_img_emb)
|
||||
|
||||
# --- Stage 2: VAE Encode ---
|
||||
with CudaTimer("2_VAE_Encode", records):
|
||||
videos = img.permute(0, 2, 1, 3, 4) # [B, C, O, H, W]
|
||||
b_v, c_v, t_v, h_v, w_v = videos.shape
|
||||
vae_dtype = next(model.first_stage_model.parameters()).dtype
|
||||
x_vae = rearrange(videos, 'b c t h w -> (b t) c h w').to(dtype=vae_dtype)
|
||||
z = model.encode_first_stage(x_vae)
|
||||
z = rearrange(z, '(b t) c h w -> b c t h w', b=b_v, t=t_v)
|
||||
img_cat_cond = z[:, :, -1:, :, :]
|
||||
img_cat_cond = repeat(img_cat_cond,
|
||||
'b c t h w -> b c (repeat t) h w', repeat=T)
|
||||
cond = {"c_concat": [img_cat_cond]}
|
||||
|
||||
vae_enc_input_bytes = x_vae.nelement() * x_vae.element_size()
|
||||
vae_enc_output_bytes = z.nelement() * z.element_size()
|
||||
|
||||
# --- Stage 3: Text Conditioning ---
|
||||
with CudaTimer("3_Text_Conditioning", records):
|
||||
cond_ins_emb = model.get_learned_conditioning(prompts)
|
||||
|
||||
# --- Stage 4: State/Action Projectors ---
|
||||
with CudaTimer("4_Projectors", records):
|
||||
projector_dtype = next(model.state_projector.parameters()).dtype
|
||||
with torch.autocast('cuda', dtype=torch.bfloat16):
|
||||
cond_state_emb = model.state_projector(
|
||||
observation['observation.state'].to(dtype=projector_dtype))
|
||||
cond_state_emb = cond_state_emb + model.agent_state_pos_emb
|
||||
|
||||
cond_action_emb = model.action_projector(
|
||||
observation['action'].to(dtype=projector_dtype))
|
||||
cond_action_emb = cond_action_emb + model.agent_action_pos_emb
|
||||
|
||||
# Assemble cross-attention conditioning
|
||||
cond["c_crossattn"] = [
|
||||
torch.cat([cond_state_emb, cond_action_emb, cond_ins_emb, cond_img_emb],
|
||||
dim=1)
|
||||
]
|
||||
n_obs_acting = getattr(model, 'n_obs_steps_acting', 2)
|
||||
cond["c_crossattn_action"] = [
|
||||
observation['observation.images.top'][:, :, -n_obs_acting:],
|
||||
observation['observation.state'][:, -n_obs_acting:],
|
||||
True, # sim_mode
|
||||
False,
|
||||
]
|
||||
|
||||
# CFG: build unconditional conditioning if needed
|
||||
uc = None
|
||||
if cfg_scale != 1.0:
|
||||
uc_crossattn = torch.zeros_like(cond["c_crossattn"][0])
|
||||
uc = {
|
||||
"c_concat": cond["c_concat"],
|
||||
"c_crossattn": [uc_crossattn],
|
||||
"c_crossattn_action": cond["c_crossattn_action"],
|
||||
}
|
||||
|
||||
# --- Stage 5: DDIM Loop ---
|
||||
ddim_sampler = DDIMSampler(model)
|
||||
hook_profiler.reset()
|
||||
|
||||
with CudaTimer("5_DDIM_Loop", records):
|
||||
with torch.autocast('cuda', dtype=torch.bfloat16):
|
||||
samples, actions, states, _ = ddim_sampler.sample(
|
||||
S=ddim_steps,
|
||||
conditioning=cond,
|
||||
batch_size=B,
|
||||
shape=noise_shape[1:],
|
||||
verbose=False,
|
||||
unconditional_guidance_scale=cfg_scale,
|
||||
unconditional_conditioning=uc,
|
||||
eta=1.0,
|
||||
cfg_img=None,
|
||||
mask=None,
|
||||
x0=None,
|
||||
fs=fs,
|
||||
timestep_spacing='uniform',
|
||||
guidance_rescale=0.0,
|
||||
unconditional_conditioning_img_nonetext=None,
|
||||
)
|
||||
|
||||
hook_by_type, hook_by_instance, hook_by_block = hook_profiler.synchronize_and_collect()
|
||||
|
||||
# --- Stage 6: VAE Decode ---
|
||||
with CudaTimer("6_VAE_Decode", records):
|
||||
batch_images = model.decode_first_stage(samples)
|
||||
|
||||
vae_dec_input_bytes = samples.nelement() * samples.element_size()
|
||||
vae_dec_output_bytes = batch_images.nelement() * batch_images.element_size()
|
||||
|
||||
# --- Stage 7: Post-process ---
|
||||
with CudaTimer("7_Post_Process", records):
|
||||
batch_images_cpu = batch_images.cpu()
|
||||
actions_cpu = actions.cpu()
|
||||
states_cpu = states.cpu()
|
||||
# Simulate video save overhead: clamp + uint8 conversion
|
||||
_ = (batch_images_cpu.clamp(-1, 1) * 127.5 + 127.5).to(torch.uint8)
|
||||
|
||||
# Flatten single-element lists
|
||||
stage_times = {k: v[0] for k, v in records.items()}
|
||||
|
||||
bandwidth_info = {
|
||||
"vae_enc_input_bytes": vae_enc_input_bytes,
|
||||
"vae_enc_output_bytes": vae_enc_output_bytes,
|
||||
"vae_dec_input_bytes": vae_dec_input_bytes,
|
||||
"vae_dec_output_bytes": vae_dec_output_bytes,
|
||||
}
|
||||
|
||||
return stage_times, hook_by_type, hook_by_instance, hook_by_block, bandwidth_info
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Reporting
|
||||
# ---------------------------------------------------------------------------
|
||||
def print_stage_timing(all_runs_stages):
|
||||
"""Table 1: Stage Timing — name | mean(ms) | std | percent."""
|
||||
import numpy as np
|
||||
stage_names = list(all_runs_stages[0].keys())
|
||||
means = {}
|
||||
stds = {}
|
||||
for name in stage_names:
|
||||
vals = [run[name] for run in all_runs_stages]
|
||||
means[name] = np.mean(vals)
|
||||
stds[name] = np.std(vals)
|
||||
total = sum(means.values())
|
||||
|
||||
print()
|
||||
print("=" * 72)
|
||||
print("TABLE 1: STAGE TIMING")
|
||||
print("=" * 72)
|
||||
print(f"{'Stage':<25} {'Mean(ms)':>10} {'Std':>10} {'%':>8}")
|
||||
print("-" * 72)
|
||||
for name in stage_names:
|
||||
pct = means[name] / total * 100 if total > 0 else 0
|
||||
print(f"{name:<25} {means[name]:>10.1f} {stds[name]:>10.2f} {pct:>7.1f}%")
|
||||
print("-" * 72)
|
||||
print(f"{'TOTAL':<25} {total:>10.1f}")
|
||||
print()
|
||||
|
||||
|
||||
def print_unet_breakdown(all_runs_hooks):
|
||||
"""Table 2: UNet Sub-Module Breakdown — type | total | count | per-call | percent."""
|
||||
import numpy as np
|
||||
# Aggregate across runs
|
||||
agg = defaultdict(lambda: {"totals": [], "counts": []})
|
||||
for hook_by_type in all_runs_hooks:
|
||||
for tag, data in hook_by_type.items():
|
||||
agg[tag]["totals"].append(data["total_ms"])
|
||||
agg[tag]["counts"].append(data["count"])
|
||||
|
||||
print("=" * 80)
|
||||
print("TABLE 2: UNET SUB-MODULE BREAKDOWN")
|
||||
print("=" * 80)
|
||||
print(f"{'Module Type':<25} {'Total(ms)':>10} {'Count':>7} {'Per-call':>10} {'%':>8}")
|
||||
print("-" * 80)
|
||||
|
||||
grand_total = 0
|
||||
rows = []
|
||||
for tag, d in agg.items():
|
||||
mean_total = np.mean(d["totals"])
|
||||
mean_count = np.mean(d["counts"])
|
||||
per_call = mean_total / mean_count if mean_count > 0 else 0
|
||||
grand_total += mean_total
|
||||
rows.append((tag, mean_total, mean_count, per_call))
|
||||
|
||||
rows.sort(key=lambda r: r[1], reverse=True)
|
||||
for tag, mean_total, mean_count, per_call in rows:
|
||||
pct = mean_total / grand_total * 100 if grand_total > 0 else 0
|
||||
print(f"{tag:<25} {mean_total:>10.1f} {int(mean_count):>7} {per_call:>10.2f} {pct:>7.1f}%")
|
||||
print("-" * 80)
|
||||
print(f"{'TOTAL (hooked)':<25} {grand_total:>10.1f}")
|
||||
print()
|
||||
|
||||
|
||||
def print_block_timing(all_runs_blocks):
|
||||
"""Table 2b: Per-UNet-block timing — which blocks are hottest."""
|
||||
import numpy as np
|
||||
# Aggregate: {block_loc: {tag: [total_ms_per_run, ...]}}
|
||||
agg = defaultdict(lambda: defaultdict(list))
|
||||
for by_block in all_runs_blocks:
|
||||
for block_loc, tag_dict in by_block.items():
|
||||
for tag, data in tag_dict.items():
|
||||
agg[block_loc][tag].append(data["total_ms"])
|
||||
|
||||
# Compute per-block totals
|
||||
block_totals = {}
|
||||
for block_loc, tag_dict in agg.items():
|
||||
block_totals[block_loc] = sum(np.mean(v) for v in tag_dict.values())
|
||||
|
||||
grand_total = sum(block_totals.values())
|
||||
|
||||
# Sort blocks in logical order
|
||||
def block_sort_key(name):
|
||||
if name.startswith("input_blocks."):
|
||||
return (0, int(name.split('.')[1]))
|
||||
elif name == "middle_block":
|
||||
return (1, 0)
|
||||
elif name.startswith("output_blocks."):
|
||||
return (2, int(name.split('.')[1]))
|
||||
elif name == "out":
|
||||
return (3, 0)
|
||||
elif name == "action_unet":
|
||||
return (4, 0)
|
||||
elif name == "state_unet":
|
||||
return (5, 0)
|
||||
return (9, 0)
|
||||
|
||||
sorted_blocks = sorted(block_totals.keys(), key=block_sort_key)
|
||||
|
||||
print("=" * 90)
|
||||
print("TABLE 2b: PER-UNET-BLOCK TIMING (coarse modules, per DDIM loop)")
|
||||
print("=" * 90)
|
||||
print(f"{'Block':<22} {'Total(ms)':>10} {'%':>7} Breakdown")
|
||||
print("-" * 90)
|
||||
|
||||
for block_loc in sorted_blocks:
|
||||
total = block_totals[block_loc]
|
||||
pct = total / grand_total * 100 if grand_total > 0 else 0
|
||||
# Build breakdown string
|
||||
parts = []
|
||||
for tag, vals in sorted(agg[block_loc].items(),
|
||||
key=lambda x: np.mean(x[1]), reverse=True):
|
||||
parts.append(f"{tag}={np.mean(vals):.0f}")
|
||||
breakdown = ", ".join(parts)
|
||||
print(f"{block_loc:<22} {total:>10.1f} {pct:>6.1f}% {breakdown}")
|
||||
|
||||
print("-" * 90)
|
||||
print(f"{'TOTAL':<22} {grand_total:>10.1f}")
|
||||
print()
|
||||
|
||||
|
||||
def print_attn_ff_breakdown(all_runs_hooks):
|
||||
"""Table 2c: CrossAttention vs FeedForward breakdown (--deep mode)."""
|
||||
import numpy as np
|
||||
agg = defaultdict(list)
|
||||
for hook_by_type in all_runs_hooks:
|
||||
for tag, data in hook_by_type.items():
|
||||
if tag in ("CrossAttention", "FeedForward"):
|
||||
agg[tag].append(data["total_ms"])
|
||||
|
||||
if not agg:
|
||||
return
|
||||
|
||||
print("=" * 70)
|
||||
print("TABLE 2c: ATTENTION vs FEEDFORWARD (deep hooks)")
|
||||
print("=" * 70)
|
||||
print(f"{'Component':<25} {'Total(ms)':>10} {'%':>8}")
|
||||
print("-" * 70)
|
||||
|
||||
grand = 0
|
||||
rows = []
|
||||
for tag in ("CrossAttention", "FeedForward"):
|
||||
if tag in agg:
|
||||
mean_t = np.mean(agg[tag])
|
||||
grand += mean_t
|
||||
rows.append((tag, mean_t))
|
||||
|
||||
for tag, mean_t in rows:
|
||||
pct = mean_t / grand * 100 if grand > 0 else 0
|
||||
print(f"{tag:<25} {mean_t:>10.1f} {pct:>7.1f}%")
|
||||
print("-" * 70)
|
||||
print(f"{'TOTAL (attn+ff)':<25} {grand:>10.1f}")
|
||||
print()
|
||||
|
||||
|
||||
def print_unet_detailed(all_runs_instances):
|
||||
"""Print per-instance UNet sub-module detail (--detailed mode)."""
|
||||
import numpy as np
|
||||
# Use last run's data
|
||||
by_instance = all_runs_instances[-1]
|
||||
print("=" * 100)
|
||||
print("DETAILED: PER-INSTANCE UNET SUB-MODULE TIMING (last run)")
|
||||
print("=" * 100)
|
||||
print(f"{'Type':<22} {'Module Name':<45} {'Calls':>6} {'Total(ms)':>10} {'Mean(ms)':>10}")
|
||||
print("-" * 100)
|
||||
|
||||
rows = []
|
||||
for (tag, mod_name), times in by_instance.items():
|
||||
if len(times) == 0:
|
||||
continue
|
||||
total = sum(times)
|
||||
mean = np.mean(times)
|
||||
rows.append((tag, mod_name, len(times), total, mean))
|
||||
rows.sort(key=lambda r: r[3], reverse=True)
|
||||
|
||||
for tag, mod_name, count, total, mean in rows:
|
||||
short_name = mod_name[-42:] if len(mod_name) > 42 else mod_name
|
||||
print(f"{tag:<22} {short_name:<45} {count:>6} {total:>10.2f} {mean:>10.3f}")
|
||||
print()
|
||||
|
||||
|
||||
def print_memory_summary(mem_before, mem_peak):
|
||||
"""Table 3: Memory Summary."""
|
||||
delta = mem_peak - mem_before
|
||||
print("=" * 50)
|
||||
print("TABLE 3: MEMORY SUMMARY")
|
||||
print("=" * 50)
|
||||
print(f" Initial allocated: {mem_before / 1e9:.2f} GB")
|
||||
print(f" Peak allocated: {mem_peak / 1e9:.2f} GB")
|
||||
print(f" Delta (pipeline): {delta / 1e9:.2f} GB")
|
||||
print()
|
||||
|
||||
|
||||
def print_throughput(all_runs_stages, all_bw, ddim_steps, cfg_scale):
|
||||
"""Table 4: Throughput — total latency, per-step, per-UNet-forward, VAE bandwidth."""
|
||||
import numpy as np
|
||||
n_runs = len(all_runs_stages)
|
||||
|
||||
# Total latency
|
||||
totals = []
|
||||
for run in all_runs_stages:
|
||||
totals.append(sum(run.values()))
|
||||
mean_total = np.mean(totals)
|
||||
|
||||
# DDIM loop time
|
||||
ddim_times = [run["5_DDIM_Loop"] for run in all_runs_stages]
|
||||
mean_ddim = np.mean(ddim_times)
|
||||
|
||||
unet_calls = ddim_steps if cfg_scale == 1.0 else ddim_steps * 2
|
||||
per_step = mean_ddim / ddim_steps
|
||||
per_unet = mean_ddim / unet_calls
|
||||
|
||||
# VAE bandwidth
|
||||
mean_enc_time = np.mean([run["2_VAE_Encode"] for run in all_runs_stages])
|
||||
mean_dec_time = np.mean([run["6_VAE_Decode"] for run in all_runs_stages])
|
||||
|
||||
bw = all_bw[-1] # use last run's byte counts
|
||||
enc_bytes = bw["vae_enc_input_bytes"] + bw["vae_enc_output_bytes"]
|
||||
dec_bytes = bw["vae_dec_input_bytes"] + bw["vae_dec_output_bytes"]
|
||||
enc_bw = enc_bytes / (mean_enc_time / 1000) / 1e9 if mean_enc_time > 0 else 0
|
||||
dec_bw = dec_bytes / (mean_dec_time / 1000) / 1e9 if mean_dec_time > 0 else 0
|
||||
|
||||
print("=" * 60)
|
||||
print("TABLE 4: THROUGHPUT")
|
||||
print("=" * 60)
|
||||
print(f" Total pipeline latency: {mean_total:.1f} ms")
|
||||
print(f" DDIM loop latency: {mean_ddim:.1f} ms")
|
||||
print(f" DDIM steps: {ddim_steps}")
|
||||
print(f" CFG scale: {cfg_scale} ({'2x UNet/step' if cfg_scale != 1.0 else '1x UNet/step'})")
|
||||
print(f" UNet forward calls: {unet_calls}")
|
||||
print(f" Per DDIM step: {per_step:.1f} ms")
|
||||
print(f" Per UNet forward: {per_unet:.1f} ms")
|
||||
print(f" VAE encode bandwidth: {enc_bw:.1f} GB/s (peak HBM: {MEM_BW_GBS} GB/s)")
|
||||
print(f" VAE decode bandwidth: {dec_bw:.1f} GB/s (peak HBM: {MEM_BW_GBS} GB/s)")
|
||||
print(f" GPU BF16 peak: {PEAK_BF16_TFLOPS} TFLOPS")
|
||||
print()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Main
|
||||
# ---------------------------------------------------------------------------
|
||||
def main():
|
||||
patch_norm_bypass_autocast()
|
||||
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Profile the full inference pipeline")
|
||||
parser.add_argument("--ckpt_path", type=str, required=True)
|
||||
parser.add_argument("--config", type=str, required=True)
|
||||
parser.add_argument("--ddim_steps", type=int, default=50)
|
||||
parser.add_argument("--cfg_scale", type=float, default=1.0)
|
||||
parser.add_argument("--n_runs", type=int, default=3)
|
||||
parser.add_argument("--warmup", type=int, default=1)
|
||||
parser.add_argument("--detailed", action="store_true",
|
||||
help="Print per-instance UNet sub-module detail")
|
||||
parser.add_argument("--deep", action="store_true",
|
||||
help="Enable deep DDIM analysis: per-block, attn vs ff")
|
||||
args = parser.parse_args()
|
||||
|
||||
noise_shape = [1, 4, 16, 40, 64]
|
||||
|
||||
# --- Load model ---
|
||||
print("Loading model...")
|
||||
model = load_model(args)
|
||||
observation, prompts = build_dummy_inputs(model, noise_shape)
|
||||
|
||||
# --- Setup hook profiler ---
|
||||
unet = model.model.diffusion_model
|
||||
hook_profiler = HookProfiler(unet, deep=args.deep)
|
||||
hook_profiler.register()
|
||||
print(f"Registered hooks on {len(hook_profiler.handles)} sub-modules")
|
||||
|
||||
# --- Warmup ---
|
||||
print(f"Warmup: {args.warmup} run(s)...")
|
||||
with torch.no_grad():
|
||||
for i in range(args.warmup):
|
||||
run_pipeline(model, observation, prompts, noise_shape,
|
||||
args.ddim_steps, args.cfg_scale, hook_profiler)
|
||||
print(f" warmup {i+1}/{args.warmup} done")
|
||||
|
||||
# --- Measurement runs ---
|
||||
print(f"Measuring: {args.n_runs} run(s)...")
|
||||
torch.cuda.reset_peak_memory_stats()
|
||||
mem_before = torch.cuda.memory_allocated()
|
||||
|
||||
all_stages = []
|
||||
all_hooks = []
|
||||
all_instances = []
|
||||
all_blocks = []
|
||||
all_bw = []
|
||||
|
||||
with torch.no_grad():
|
||||
for i in range(args.n_runs):
|
||||
stage_times, hook_by_type, hook_by_instance, hook_by_block, bw = run_pipeline(
|
||||
model, observation, prompts, noise_shape,
|
||||
args.ddim_steps, args.cfg_scale, hook_profiler)
|
||||
all_stages.append(stage_times)
|
||||
all_hooks.append(hook_by_type)
|
||||
all_instances.append(hook_by_instance)
|
||||
all_blocks.append(hook_by_block)
|
||||
all_bw.append(bw)
|
||||
total = sum(stage_times.values())
|
||||
print(f" run {i+1}/{args.n_runs}: {total:.1f} ms total")
|
||||
|
||||
mem_peak = torch.cuda.max_memory_allocated()
|
||||
|
||||
# --- Reports ---
|
||||
print_stage_timing(all_stages)
|
||||
print_unet_breakdown(all_hooks)
|
||||
print_block_timing(all_blocks)
|
||||
if args.deep:
|
||||
print_attn_ff_breakdown(all_hooks)
|
||||
if args.detailed:
|
||||
print_unet_detailed(all_instances)
|
||||
print_memory_summary(mem_before, mem_peak)
|
||||
print_throughput(all_stages, all_bw, args.ddim_steps, args.cfg_scale)
|
||||
|
||||
hook_profiler.remove()
|
||||
print("Done.")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
287
scripts/evaluation/profile_unet.py
Normal file
287
scripts/evaluation/profile_unet.py
Normal file
@@ -0,0 +1,287 @@
|
||||
"""
|
||||
Profile one DDIM sampling iteration to capture all matmul/attention ops,
|
||||
their matrix sizes, wall time, and compute utilization.
|
||||
|
||||
Uses torch.profiler for CUDA timing and FlopCounterMode for accurate
|
||||
FLOPS counting (works on ROCm where Tensile kernels don't report FLOPS).
|
||||
|
||||
Usage:
|
||||
TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL=1 CUDA_VISIBLE_DEVICES=0 \
|
||||
python scripts/evaluation/profile_unet.py \
|
||||
--ckpt_path ckpts/unifolm_wma_dual_mix_bf16.ckpt \
|
||||
--config configs/inference/world_model_interaction.yaml
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from collections import OrderedDict, defaultdict
|
||||
from omegaconf import OmegaConf
|
||||
from torch.utils.flop_counter import FlopCounterMode
|
||||
from unifolm_wma.utils.utils import instantiate_from_config
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
def patch_norm_bypass_autocast():
|
||||
"""Monkey-patch GroupNorm and LayerNorm to bypass autocast's fp32 policy."""
|
||||
|
||||
def _group_norm_forward(self, x):
|
||||
with torch.amp.autocast('cuda', enabled=False):
|
||||
return F.group_norm(
|
||||
x, self.num_groups,
|
||||
self.weight.to(x.dtype) if self.weight is not None else None,
|
||||
self.bias.to(x.dtype) if self.bias is not None else None,
|
||||
self.eps)
|
||||
|
||||
def _layer_norm_forward(self, x):
|
||||
with torch.amp.autocast('cuda', enabled=False):
|
||||
return F.layer_norm(
|
||||
x, self.normalized_shape,
|
||||
self.weight.to(x.dtype) if self.weight is not None else None,
|
||||
self.bias.to(x.dtype) if self.bias is not None else None,
|
||||
self.eps)
|
||||
|
||||
torch.nn.GroupNorm.forward = _group_norm_forward
|
||||
torch.nn.LayerNorm.forward = _layer_norm_forward
|
||||
|
||||
|
||||
# --- W7900D theoretical peak (TFLOPS) ---
|
||||
PEAK_BF16_TFLOPS = 61.0
|
||||
PEAK_FP32_TFLOPS = 30.5
|
||||
|
||||
|
||||
def apply_torch_compile(model, hot_indices=(5, 8, 9)):
|
||||
"""Compile ResBlock._forward in the hottest output_blocks for operator fusion."""
|
||||
from unifolm_wma.modules.networks.wma_model import ResBlock
|
||||
unet = model.model.diffusion_model
|
||||
compiled = 0
|
||||
for idx in hot_indices:
|
||||
block = unet.output_blocks[idx]
|
||||
for layer in block:
|
||||
if isinstance(layer, ResBlock):
|
||||
layer._forward = torch.compile(layer._forward, mode="default")
|
||||
compiled += 1
|
||||
print(f" ✓ torch.compile: {compiled} ResBlocks in output_blocks{list(hot_indices)}")
|
||||
|
||||
|
||||
def load_model(args):
|
||||
config = OmegaConf.load(args.config)
|
||||
config['model']['params']['wma_config']['params']['use_checkpoint'] = False
|
||||
model = instantiate_from_config(config.model)
|
||||
|
||||
state_dict = torch.load(args.ckpt_path, map_location="cpu")
|
||||
if "state_dict" in state_dict:
|
||||
state_dict = state_dict["state_dict"]
|
||||
model.load_state_dict(state_dict, strict=True)
|
||||
|
||||
model.eval()
|
||||
model.model.to(torch.bfloat16)
|
||||
apply_torch_compile(model)
|
||||
model = model.cuda()
|
||||
return model
|
||||
|
||||
|
||||
def build_call_kwargs(model, noise_shape):
|
||||
"""Build dummy inputs matching the hybrid conditioning forward signature."""
|
||||
device = next(model.parameters()).device
|
||||
B, C, T, H, W = noise_shape # [1, 4, 16, 40, 64]
|
||||
dtype = torch.bfloat16
|
||||
|
||||
x_action = torch.randn(B, 16, 16, device=device, dtype=dtype)
|
||||
x_state = torch.randn(B, 16, 16, device=device, dtype=dtype)
|
||||
timesteps = torch.tensor([500], device=device, dtype=torch.long)
|
||||
context = torch.randn(B, 351, 1024, device=device, dtype=dtype)
|
||||
|
||||
obs_images = torch.randn(B, 3, 2, 320, 512, device=device, dtype=dtype)
|
||||
obs_state = torch.randn(B, 2, 16, device=device, dtype=dtype)
|
||||
context_action = [obs_images, obs_state, True, False]
|
||||
fps = torch.tensor([1], device=device, dtype=torch.long)
|
||||
|
||||
x_raw = torch.randn(B, C, T, H, W, device=device, dtype=dtype)
|
||||
c_concat = [torch.randn(B, C, T, H, W, device=device, dtype=dtype)]
|
||||
|
||||
return dict(
|
||||
x=x_raw, x_action=x_action, x_state=x_state, t=timesteps,
|
||||
c_concat=c_concat, c_crossattn=[context],
|
||||
c_crossattn_action=context_action, s=fps,
|
||||
)
|
||||
|
||||
|
||||
def profile_one_step(model, noise_shape):
|
||||
"""Run one UNet forward pass under torch.profiler for CUDA timing."""
|
||||
diff_wrapper = model.model
|
||||
call_kwargs = build_call_kwargs(model, noise_shape)
|
||||
|
||||
with torch.no_grad(), torch.autocast('cuda', dtype=torch.bfloat16):
|
||||
# Warmup
|
||||
for _ in range(2):
|
||||
_ = diff_wrapper(**call_kwargs)
|
||||
torch.cuda.synchronize()
|
||||
|
||||
with torch.profiler.profile(
|
||||
activities=[torch.profiler.ProfilerActivity.CUDA],
|
||||
record_shapes=True,
|
||||
with_flops=True,
|
||||
) as prof:
|
||||
_ = diff_wrapper(**call_kwargs)
|
||||
torch.cuda.synchronize()
|
||||
|
||||
return prof
|
||||
|
||||
|
||||
def count_flops(model, noise_shape):
|
||||
"""Run one UNet forward pass under FlopCounterMode for accurate FLOPS."""
|
||||
diff_wrapper = model.model
|
||||
call_kwargs = build_call_kwargs(model, noise_shape)
|
||||
|
||||
with torch.no_grad(), torch.autocast('cuda', dtype=torch.bfloat16):
|
||||
flop_counter = FlopCounterMode(display=False)
|
||||
with flop_counter:
|
||||
_ = diff_wrapper(**call_kwargs)
|
||||
torch.cuda.synchronize()
|
||||
|
||||
return flop_counter
|
||||
|
||||
|
||||
def print_report(prof, flop_counter):
|
||||
"""Parse profiler results and print a structured report with accurate FLOPS."""
|
||||
events = prof.key_averages()
|
||||
|
||||
# --- Extract per-operator FLOPS from FlopCounterMode ---
|
||||
# flop_counts is {module_name: {op_name: count}}; use only "Global" to avoid double-counting
|
||||
flop_by_op = {}
|
||||
flop_by_module = {}
|
||||
if hasattr(flop_counter, 'flop_counts'):
|
||||
# Per-op: only from top-level "Global" entry (no parent/child duplication)
|
||||
global_ops = flop_counter.flop_counts.get("Global", {})
|
||||
for op_name, flop_count in global_ops.items():
|
||||
key = str(op_name).split('.')[-1]
|
||||
flop_by_op[key] = flop_by_op.get(key, 0) + flop_count
|
||||
|
||||
# Per-module: collect all, skip "Global" and top-level wrapper duplicates
|
||||
for module_name, op_dict in flop_counter.flop_counts.items():
|
||||
module_total = sum(op_dict.values())
|
||||
if module_total > 0:
|
||||
flop_by_module[module_name] = module_total
|
||||
|
||||
total_counted_flops = flop_counter.get_total_flops()
|
||||
|
||||
# Collect matmul-like ops
|
||||
matmul_ops = []
|
||||
other_ops = []
|
||||
|
||||
for evt in events:
|
||||
if evt.device_time_total <= 0:
|
||||
continue
|
||||
name = evt.key
|
||||
is_matmul = any(k in name.lower() for k in
|
||||
['mm', 'gemm', 'addmm', 'bmm', 'einsum', 'dot', 'linear'])
|
||||
entry = {
|
||||
'name': name,
|
||||
'input_shapes': str(evt.input_shapes) if evt.input_shapes else '',
|
||||
'cuda_time_ms': evt.device_time_total / 1000.0,
|
||||
'count': evt.count,
|
||||
'flops': evt.flops if evt.flops else 0,
|
||||
}
|
||||
if is_matmul:
|
||||
matmul_ops.append(entry)
|
||||
else:
|
||||
other_ops.append(entry)
|
||||
|
||||
# Sort by CUDA time
|
||||
matmul_ops.sort(key=lambda x: x['cuda_time_ms'], reverse=True)
|
||||
other_ops.sort(key=lambda x: x['cuda_time_ms'], reverse=True)
|
||||
|
||||
total_cuda_ms = sum(e['cuda_time_ms'] for e in matmul_ops + other_ops)
|
||||
total_matmul_ms = sum(e['cuda_time_ms'] for e in matmul_ops)
|
||||
# --- Print matmul ops ---
|
||||
print("=" * 130)
|
||||
print("MATMUL / LINEAR OPS (sorted by CUDA time)")
|
||||
print("=" * 130)
|
||||
print(f"{'Op':>35} | {'Count':>5} | {'CUDA(ms)':>10} | Shapes")
|
||||
print("-" * 130)
|
||||
|
||||
for op in matmul_ops:
|
||||
shapes_str = op['input_shapes'][:60] if op['input_shapes'] else ''
|
||||
print(f"{op['name']:>35} | {op['count']:>5} | {op['cuda_time_ms']:>10.3f} | {shapes_str}")
|
||||
|
||||
# --- Print top non-matmul ops ---
|
||||
print()
|
||||
print("=" * 130)
|
||||
print("TOP NON-MATMUL OPS (sorted by CUDA time)")
|
||||
print("=" * 130)
|
||||
print(f"{'Op':>40} | {'Count':>5} | {'CUDA(ms)':>10} | Shapes")
|
||||
print("-" * 130)
|
||||
|
||||
for op in other_ops[:20]:
|
||||
shapes_str = op['input_shapes'][:60] if op['input_shapes'] else ''
|
||||
print(f"{op['name']:>40} | {op['count']:>5} | {op['cuda_time_ms']:>10.3f} | {shapes_str}")
|
||||
|
||||
# --- FlopCounterMode per-operator breakdown ---
|
||||
print()
|
||||
print("=" * 130)
|
||||
print("FLOPS BY ATen OPERATOR (FlopCounterMode)")
|
||||
print("=" * 130)
|
||||
print(f"{'ATen Op':>25} | {'GFLOPS':>12} | {'% of Total':>10}")
|
||||
print("-" * 55)
|
||||
|
||||
sorted_flop_ops = sorted(flop_by_op.items(), key=lambda x: x[1], reverse=True)
|
||||
for op_name, flops in sorted_flop_ops:
|
||||
gflops = flops / 1e9
|
||||
pct = flops / total_counted_flops * 100 if total_counted_flops > 0 else 0
|
||||
print(f"{op_name:>25} | {gflops:>12.2f} | {pct:>9.1f}%")
|
||||
|
||||
# --- FlopCounterMode per-module breakdown ---
|
||||
if flop_by_module:
|
||||
print()
|
||||
print("=" * 130)
|
||||
print("FLOPS BY MODULE (FlopCounterMode)")
|
||||
print("=" * 130)
|
||||
print(f"{'Module':>60} | {'GFLOPS':>12} | {'% of Total':>10}")
|
||||
print("-" * 90)
|
||||
|
||||
sorted_modules = sorted(flop_by_module.items(), key=lambda x: x[1], reverse=True)
|
||||
for mod_name, flops in sorted_modules[:30]:
|
||||
gflops = flops / 1e9
|
||||
pct = flops / total_counted_flops * 100 if total_counted_flops > 0 else 0
|
||||
name_str = mod_name[-57:] if len(mod_name) > 57 else mod_name
|
||||
print(f"{name_str:>60} | {gflops:>12.2f} | {pct:>9.1f}%")
|
||||
|
||||
# --- Summary ---
|
||||
print()
|
||||
print("=" * 130)
|
||||
print("SUMMARY")
|
||||
print("=" * 130)
|
||||
print(f" Total CUDA time: {total_cuda_ms:.1f} ms")
|
||||
print(f" Matmul CUDA time: {total_matmul_ms:.1f} ms ({total_matmul_ms/total_cuda_ms*100:.1f}%)")
|
||||
print(f" Non-matmul CUDA time: {total_cuda_ms - total_matmul_ms:.1f} ms ({(total_cuda_ms-total_matmul_ms)/total_cuda_ms*100:.1f}%)")
|
||||
print(f" Total FLOPS (FlopCounter): {total_counted_flops/1e9:.2f} GFLOPS")
|
||||
if total_matmul_ms > 0 and total_counted_flops > 0:
|
||||
avg_tflops = total_counted_flops / (total_matmul_ms / 1000.0) / 1e12
|
||||
avg_util = avg_tflops / PEAK_BF16_TFLOPS * 100
|
||||
overall_tflops = total_counted_flops / (total_cuda_ms / 1000.0) / 1e12
|
||||
overall_util = overall_tflops / PEAK_BF16_TFLOPS * 100
|
||||
print(f" Matmul throughput: {avg_tflops:.2f} TFLOPS/s ({avg_util:.1f}% of BF16 peak)")
|
||||
print(f" Overall throughput: {overall_tflops:.2f} TFLOPS/s ({overall_util:.1f}% of BF16 peak)")
|
||||
print(f" GPU peak (BF16): {PEAK_BF16_TFLOPS} TFLOPS")
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
patch_norm_bypass_autocast()
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--ckpt_path", type=str, required=True)
|
||||
parser.add_argument("--config", type=str, required=True)
|
||||
args = parser.parse_args()
|
||||
|
||||
print("Loading model...")
|
||||
model = load_model(args)
|
||||
|
||||
noise_shape = [1, 4, 16, 40, 64]
|
||||
|
||||
print(f"Profiling UNet forward pass with shape {noise_shape}...")
|
||||
prof = profile_one_step(model, noise_shape)
|
||||
|
||||
print("Counting FLOPS with FlopCounterMode...")
|
||||
flop_counter = count_flops(model, noise_shape)
|
||||
|
||||
print_report(prof, flop_counter)
|
||||
@@ -1,4 +1,7 @@
|
||||
import argparse, os, glob
|
||||
from contextlib import nullcontext
|
||||
import atexit
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
import pandas as pd
|
||||
import random
|
||||
import torch
|
||||
@@ -10,13 +13,15 @@ import einops
|
||||
import warnings
|
||||
import imageio
|
||||
|
||||
from typing import Optional, List, Any
|
||||
|
||||
from pytorch_lightning import seed_everything
|
||||
from omegaconf import OmegaConf
|
||||
from tqdm import tqdm
|
||||
from einops import rearrange, repeat
|
||||
from collections import OrderedDict
|
||||
from torch import nn
|
||||
from eval_utils import populate_queues, log_to_tensorboard
|
||||
from eval_utils import populate_queues
|
||||
from collections import deque
|
||||
from torch import Tensor
|
||||
from torch.utils.tensorboard import SummaryWriter
|
||||
@@ -24,6 +29,105 @@ from PIL import Image
|
||||
|
||||
from unifolm_wma.models.samplers.ddim import DDIMSampler
|
||||
from unifolm_wma.utils.utils import instantiate_from_config
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
# ========== Async I/O utilities ==========
|
||||
_io_executor: Optional[ThreadPoolExecutor] = None
|
||||
_io_futures: List[Any] = []
|
||||
|
||||
|
||||
def _get_io_executor() -> ThreadPoolExecutor:
|
||||
global _io_executor
|
||||
if _io_executor is None:
|
||||
_io_executor = ThreadPoolExecutor(max_workers=2)
|
||||
return _io_executor
|
||||
|
||||
|
||||
def _flush_io():
|
||||
"""Wait for all pending async I/O to finish."""
|
||||
global _io_futures
|
||||
for fut in _io_futures:
|
||||
try:
|
||||
fut.result()
|
||||
except Exception as e:
|
||||
print(f">>> [async I/O] error: {e}")
|
||||
_io_futures.clear()
|
||||
|
||||
|
||||
atexit.register(_flush_io)
|
||||
|
||||
|
||||
def _save_results_sync(video_cpu: Tensor, filename: str, fps: int) -> None:
|
||||
"""Synchronous save on CPU tensor (runs in background thread)."""
|
||||
video = torch.clamp(video_cpu.float(), -1., 1.)
|
||||
n = video.shape[0]
|
||||
video = video.permute(2, 0, 1, 3, 4)
|
||||
frame_grids = [
|
||||
torchvision.utils.make_grid(framesheet, nrow=int(n), padding=0)
|
||||
for framesheet in video
|
||||
]
|
||||
grid = torch.stack(frame_grids, dim=0)
|
||||
grid = (grid + 1.0) / 2.0
|
||||
grid = (grid * 255).to(torch.uint8).permute(0, 2, 3, 1)
|
||||
torchvision.io.write_video(filename,
|
||||
grid,
|
||||
fps=fps,
|
||||
video_codec='h264',
|
||||
options={'crf': '10'})
|
||||
|
||||
|
||||
def save_results_async(video: Tensor, filename: str, fps: int = 8) -> None:
|
||||
"""Submit video saving to background thread pool."""
|
||||
video_cpu = video.detach().cpu()
|
||||
fut = _get_io_executor().submit(_save_results_sync, video_cpu, filename, fps)
|
||||
_io_futures.append(fut)
|
||||
|
||||
|
||||
def _log_to_tb_sync(video_cpu: Tensor, writer: SummaryWriter, tag: str, fps: int) -> None:
|
||||
"""Synchronous tensorboard logging on CPU tensor (runs in background thread)."""
|
||||
video = video_cpu.float()
|
||||
n = video.shape[0]
|
||||
video = video.permute(2, 0, 1, 3, 4)
|
||||
frame_grids = [
|
||||
torchvision.utils.make_grid(framesheet, nrow=int(n), padding=0)
|
||||
for framesheet in video
|
||||
]
|
||||
grid = torch.stack(frame_grids, dim=0)
|
||||
grid = (grid + 1.0) / 2.0
|
||||
grid = grid.unsqueeze(dim=0)
|
||||
writer.add_video(tag, grid, fps=fps)
|
||||
|
||||
|
||||
def log_to_tensorboard_async(writer: SummaryWriter, video: Tensor, tag: str, fps: int = 10) -> None:
|
||||
"""Submit tensorboard logging to background thread pool."""
|
||||
video_cpu = video.detach().cpu()
|
||||
fut = _get_io_executor().submit(_log_to_tb_sync, video_cpu, writer, tag, fps)
|
||||
_io_futures.append(fut)
|
||||
|
||||
|
||||
def patch_norm_bypass_autocast():
|
||||
"""Monkey-patch GroupNorm and LayerNorm to bypass autocast's fp32 policy.
|
||||
This eliminates bf16->fp32->bf16 dtype conversions during UNet forward."""
|
||||
|
||||
def _group_norm_forward(self, x):
|
||||
with torch.amp.autocast('cuda', enabled=False):
|
||||
return F.group_norm(
|
||||
x, self.num_groups,
|
||||
self.weight.to(x.dtype) if self.weight is not None else None,
|
||||
self.bias.to(x.dtype) if self.bias is not None else None,
|
||||
self.eps)
|
||||
|
||||
def _layer_norm_forward(self, x):
|
||||
with torch.amp.autocast('cuda', enabled=False):
|
||||
return F.layer_norm(
|
||||
x, self.normalized_shape,
|
||||
self.weight.to(x.dtype) if self.weight is not None else None,
|
||||
self.bias.to(x.dtype) if self.bias is not None else None,
|
||||
self.eps)
|
||||
|
||||
torch.nn.GroupNorm.forward = _group_norm_forward
|
||||
torch.nn.LayerNorm.forward = _layer_norm_forward
|
||||
|
||||
|
||||
def get_device_from_parameters(module: nn.Module) -> torch.device:
|
||||
@@ -38,6 +142,92 @@ def get_device_from_parameters(module: nn.Module) -> torch.device:
|
||||
return next(iter(module.parameters())).device
|
||||
|
||||
|
||||
def apply_precision_settings(model: nn.Module, args: argparse.Namespace) -> nn.Module:
|
||||
"""Apply precision settings to model components based on command-line arguments.
|
||||
|
||||
Args:
|
||||
model (nn.Module): The model to apply precision settings to.
|
||||
args (argparse.Namespace): Parsed command-line arguments containing precision settings.
|
||||
|
||||
Returns:
|
||||
nn.Module: Model with precision settings applied.
|
||||
"""
|
||||
print(f">>> Applying precision settings:")
|
||||
print(f" - Diffusion dtype: {args.diffusion_dtype}")
|
||||
print(f" - Projector mode: {args.projector_mode}")
|
||||
print(f" - Encoder mode: {args.encoder_mode}")
|
||||
print(f" - VAE dtype: {args.vae_dtype}")
|
||||
|
||||
# 1. Set Diffusion backbone precision
|
||||
if args.diffusion_dtype == "bf16":
|
||||
# Convert diffusion model weights to bf16
|
||||
model.model.to(torch.bfloat16)
|
||||
model.diffusion_autocast_dtype = torch.bfloat16
|
||||
print(" ✓ Diffusion model weights converted to bfloat16")
|
||||
else:
|
||||
model.diffusion_autocast_dtype = torch.bfloat16
|
||||
print(" ✓ Diffusion model using fp32")
|
||||
|
||||
# 2. Set Projector precision
|
||||
if args.projector_mode == "bf16_full":
|
||||
model.state_projector.to(torch.bfloat16)
|
||||
model.action_projector.to(torch.bfloat16)
|
||||
model.projector_autocast_dtype = None
|
||||
print(" ✓ Projectors converted to bfloat16")
|
||||
elif args.projector_mode == "autocast":
|
||||
model.projector_autocast_dtype = torch.bfloat16
|
||||
print(" ✓ Projectors will use autocast (weights fp32, compute bf16)")
|
||||
else:
|
||||
model.projector_autocast_dtype = None
|
||||
# fp32 mode: do nothing, keep original precision
|
||||
|
||||
# 3. Set Encoder precision
|
||||
if args.encoder_mode == "bf16_full":
|
||||
model.embedder.to(torch.bfloat16)
|
||||
model.image_proj_model.to(torch.bfloat16)
|
||||
model.encoder_autocast_dtype = None
|
||||
print(" ✓ Encoders converted to bfloat16")
|
||||
elif args.encoder_mode == "autocast":
|
||||
model.encoder_autocast_dtype = torch.bfloat16
|
||||
print(" ✓ Encoders will use autocast (weights fp32, compute bf16)")
|
||||
else:
|
||||
model.encoder_autocast_dtype = None
|
||||
# fp32 mode: do nothing, keep original precision
|
||||
|
||||
# 4. Set VAE precision
|
||||
if args.vae_dtype == "bf16":
|
||||
model.first_stage_model.to(torch.bfloat16)
|
||||
print(" ✓ VAE converted to bfloat16")
|
||||
else:
|
||||
print(" ✓ VAE kept in fp32 for best quality")
|
||||
|
||||
# 5. Safety net: ensure no fp32 parameters remain when all components are bf16
|
||||
if args.diffusion_dtype == "bf16":
|
||||
fp32_params = [(n, p) for n, p in model.named_parameters() if p.dtype == torch.float32]
|
||||
if fp32_params:
|
||||
print(f" ⚠ Found {len(fp32_params)} fp32 params, converting to bf16")
|
||||
for name, param in fp32_params:
|
||||
param.data = param.data.to(torch.bfloat16)
|
||||
print(" ✓ All parameters converted to bfloat16")
|
||||
|
||||
return model
|
||||
|
||||
|
||||
def apply_torch_compile(model, hot_indices=(5, 8, 9)):
|
||||
"""Compile ResBlock._forward in the hottest output_blocks for operator fusion."""
|
||||
from unifolm_wma.modules.networks.wma_model import ResBlock
|
||||
unet = model.model.diffusion_model
|
||||
compiled = 0
|
||||
for idx in hot_indices:
|
||||
block = unet.output_blocks[idx]
|
||||
for layer in block:
|
||||
if isinstance(layer, ResBlock):
|
||||
layer._forward = torch.compile(layer._forward, mode="default")
|
||||
compiled += 1
|
||||
print(f" ✓ torch.compile: {compiled} ResBlocks in output_blocks{list(hot_indices)}")
|
||||
return model
|
||||
|
||||
|
||||
def write_video(video_path: str, stacked_frames: list, fps: int) -> None:
|
||||
"""Save a list of frames to a video file.
|
||||
|
||||
@@ -73,17 +263,18 @@ def get_filelist(data_dir: str, postfixes: list[str]) -> list[str]:
|
||||
return file_list
|
||||
|
||||
|
||||
def load_model_checkpoint(model: nn.Module, ckpt: str) -> nn.Module:
|
||||
def load_model_checkpoint(model: nn.Module, ckpt: str, device: str = "cpu") -> nn.Module:
|
||||
"""Load model weights from checkpoint file.
|
||||
|
||||
Args:
|
||||
model (nn.Module): Model instance.
|
||||
ckpt (str): Path to the checkpoint file.
|
||||
device (str): Target device for loaded tensors.
|
||||
|
||||
Returns:
|
||||
nn.Module: Model with loaded weights.
|
||||
"""
|
||||
state_dict = torch.load(ckpt, map_location="cpu")
|
||||
state_dict = torch.load(ckpt, map_location=device)
|
||||
if "state_dict" in list(state_dict.keys()):
|
||||
state_dict = state_dict["state_dict"]
|
||||
try:
|
||||
@@ -262,6 +453,11 @@ def get_latent_z(model, videos: Tensor) -> Tensor:
|
||||
"""
|
||||
b, c, t, h, w = videos.shape
|
||||
x = rearrange(videos, 'b c t h w -> (b t) c h w')
|
||||
|
||||
# Auto-detect VAE dtype and convert input
|
||||
vae_dtype = next(model.first_stage_model.parameters()).dtype
|
||||
x = x.to(dtype=vae_dtype)
|
||||
|
||||
z = model.encode_first_stage(x)
|
||||
z = rearrange(z, '(b t) c h w -> b c t h w', b=b, t=t)
|
||||
return z
|
||||
@@ -327,7 +523,8 @@ def image_guided_synthesis_sim_mode(
|
||||
timestep_spacing: str = 'uniform',
|
||||
guidance_rescale: float = 0.0,
|
||||
sim_mode: bool = True,
|
||||
**kwargs) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
decode_video: bool = True,
|
||||
**kwargs) -> tuple[torch.Tensor | None, torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Performs image-guided video generation in a simulation-style mode with optional multimodal guidance (image, state, action, text).
|
||||
|
||||
@@ -350,10 +547,13 @@ def image_guided_synthesis_sim_mode(
|
||||
timestep_spacing (str): Timestep sampling method in DDIM sampler. Typically "uniform" or "linspace".
|
||||
guidance_rescale (float): Guidance rescaling factor to mitigate overexposure from classifier-free guidance.
|
||||
sim_mode (bool): Whether to perform world-model interaction or decision-making using the world-model.
|
||||
decode_video (bool): Whether to decode latent samples to pixel-space video.
|
||||
Set to False to skip VAE decode for speed when only actions/states are needed.
|
||||
**kwargs: Additional arguments passed to the DDIM sampler.
|
||||
|
||||
Returns:
|
||||
batch_variants (torch.Tensor): Predicted pixel-space video frames [B, C, T, H, W].
|
||||
batch_variants (torch.Tensor | None): Predicted pixel-space video frames [B, C, T, H, W],
|
||||
or None when decode_video=False.
|
||||
actions (torch.Tensor): Predicted action sequences [B, T, D] from diffusion decoding.
|
||||
states (torch.Tensor): Predicted state sequences [B, T, D] from diffusion decoding.
|
||||
"""
|
||||
@@ -363,10 +563,22 @@ def image_guided_synthesis_sim_mode(
|
||||
|
||||
fs = torch.tensor([fs] * batch_size, dtype=torch.long, device=model.device)
|
||||
|
||||
# Auto-detect model dtype and convert inputs accordingly
|
||||
model_dtype = next(model.embedder.parameters()).dtype
|
||||
|
||||
img = observation['observation.images.top'].permute(0, 2, 1, 3, 4)
|
||||
cond_img = rearrange(img, 'b o c h w -> (b o) c h w')[-1:]
|
||||
cond_img_emb = model.embedder(cond_img)
|
||||
cond_img_emb = model.image_proj_model(cond_img_emb)
|
||||
cond_img = rearrange(img, 'b o c h w -> (b o) c h w')[-1:].to(dtype=model_dtype)
|
||||
|
||||
# Encoder autocast: weights stay fp32, compute in bf16
|
||||
enc_ac_dtype = getattr(model, 'encoder_autocast_dtype', None)
|
||||
if enc_ac_dtype is not None and model.device.type == 'cuda':
|
||||
enc_ctx = torch.autocast('cuda', dtype=enc_ac_dtype)
|
||||
else:
|
||||
enc_ctx = nullcontext()
|
||||
|
||||
with enc_ctx:
|
||||
cond_img_emb = model.embedder(cond_img)
|
||||
cond_img_emb = model.image_proj_model(cond_img_emb)
|
||||
|
||||
if model.model.conditioning_key == 'hybrid':
|
||||
z = get_latent_z(model, img.permute(0, 2, 1, 3, 4))
|
||||
@@ -380,11 +592,22 @@ def image_guided_synthesis_sim_mode(
|
||||
prompts = [""] * batch_size
|
||||
cond_ins_emb = model.get_learned_conditioning(prompts)
|
||||
|
||||
cond_state_emb = model.state_projector(observation['observation.state'])
|
||||
cond_state_emb = cond_state_emb + model.agent_state_pos_emb
|
||||
# Auto-detect projector dtype and convert inputs
|
||||
projector_dtype = next(model.state_projector.parameters()).dtype
|
||||
|
||||
cond_action_emb = model.action_projector(observation['action'])
|
||||
cond_action_emb = cond_action_emb + model.agent_action_pos_emb
|
||||
# Projector autocast: weights stay fp32, compute in bf16
|
||||
proj_ac_dtype = getattr(model, 'projector_autocast_dtype', None)
|
||||
if proj_ac_dtype is not None and model.device.type == 'cuda':
|
||||
proj_ctx = torch.autocast('cuda', dtype=proj_ac_dtype)
|
||||
else:
|
||||
proj_ctx = nullcontext()
|
||||
|
||||
with proj_ctx:
|
||||
cond_state_emb = model.state_projector(observation['observation.state'].to(dtype=projector_dtype))
|
||||
cond_state_emb = cond_state_emb + model.agent_state_pos_emb
|
||||
|
||||
cond_action_emb = model.action_projector(observation['action'].to(dtype=projector_dtype))
|
||||
cond_action_emb = cond_action_emb + model.agent_action_pos_emb
|
||||
|
||||
if not sim_mode:
|
||||
cond_action_emb = torch.zeros_like(cond_action_emb)
|
||||
@@ -406,8 +629,18 @@ def image_guided_synthesis_sim_mode(
|
||||
kwargs.update({"unconditional_conditioning_img_nonetext": None})
|
||||
cond_mask = None
|
||||
cond_z0 = None
|
||||
|
||||
# Setup autocast context for diffusion sampling
|
||||
autocast_dtype = getattr(model, 'diffusion_autocast_dtype', None)
|
||||
if autocast_dtype is not None and model.device.type == 'cuda':
|
||||
autocast_ctx = torch.autocast('cuda', dtype=autocast_dtype)
|
||||
else:
|
||||
autocast_ctx = nullcontext()
|
||||
|
||||
batch_variants = None
|
||||
if ddim_sampler is not None:
|
||||
samples, actions, states, intermedia = ddim_sampler.sample(
|
||||
with autocast_ctx:
|
||||
samples, actions, states, intermedia = ddim_sampler.sample(
|
||||
S=ddim_steps,
|
||||
conditioning=cond,
|
||||
batch_size=batch_size,
|
||||
@@ -424,9 +657,10 @@ def image_guided_synthesis_sim_mode(
|
||||
guidance_rescale=guidance_rescale,
|
||||
**kwargs)
|
||||
|
||||
# Reconstruct from latent to pixel space
|
||||
batch_images = model.decode_first_stage(samples)
|
||||
batch_variants = batch_images
|
||||
if decode_video:
|
||||
# Reconstruct from latent to pixel space
|
||||
batch_images = model.decode_first_stage(samples)
|
||||
batch_variants = batch_images
|
||||
|
||||
return batch_variants, actions, states
|
||||
|
||||
@@ -455,22 +689,63 @@ def run_inference(args: argparse.Namespace, gpu_num: int, gpu_no: int) -> None:
|
||||
|
||||
# Load config
|
||||
config = OmegaConf.load(args.config)
|
||||
config['model']['params']['wma_config']['params'][
|
||||
'use_checkpoint'] = False
|
||||
model = instantiate_from_config(config.model)
|
||||
model.perframe_ae = args.perframe_ae
|
||||
assert os.path.exists(args.ckpt_path), "Error: checkpoint Not Found!"
|
||||
model = load_model_checkpoint(model, args.ckpt_path)
|
||||
model.eval()
|
||||
print(f'>>> Load pre-trained model ...')
|
||||
|
||||
# Build unnomalizer
|
||||
prepared_path = args.ckpt_path + ".prepared.pt"
|
||||
if os.path.exists(prepared_path):
|
||||
# ---- Fast path: load the fully-prepared model ----
|
||||
print(f">>> Loading prepared model from {prepared_path} ...")
|
||||
model = torch.load(prepared_path,
|
||||
map_location=f"cuda:{gpu_no}",
|
||||
weights_only=False)
|
||||
model.eval()
|
||||
|
||||
# Restore autocast attributes (weights already cast, just need contexts)
|
||||
model.diffusion_autocast_dtype = torch.bfloat16 if args.diffusion_dtype == "bf16" else torch.bfloat16
|
||||
model.projector_autocast_dtype = torch.bfloat16 if args.projector_mode == "autocast" else None
|
||||
model.encoder_autocast_dtype = torch.bfloat16 if args.encoder_mode == "autocast" else None
|
||||
|
||||
# Compile hot ResBlocks for operator fusion
|
||||
apply_torch_compile(model)
|
||||
|
||||
print(f">>> Prepared model loaded.")
|
||||
else:
|
||||
# ---- Normal path: construct + checkpoint + casting ----
|
||||
config['model']['params']['wma_config']['params'][
|
||||
'use_checkpoint'] = False
|
||||
model = instantiate_from_config(config.model)
|
||||
model.perframe_ae = args.perframe_ae
|
||||
assert os.path.exists(args.ckpt_path), "Error: checkpoint Not Found!"
|
||||
model = load_model_checkpoint(model, args.ckpt_path,
|
||||
device=f"cuda:{gpu_no}")
|
||||
model.eval()
|
||||
print(f'>>> Load pre-trained model ...')
|
||||
|
||||
# Apply precision settings before moving to GPU
|
||||
model = apply_precision_settings(model, args)
|
||||
|
||||
# Export precision-converted checkpoint if requested
|
||||
if args.export_precision_ckpt:
|
||||
export_path = args.export_precision_ckpt
|
||||
os.makedirs(os.path.dirname(export_path) or '.', exist_ok=True)
|
||||
torch.save({"state_dict": model.state_dict()}, export_path)
|
||||
print(f">>> Precision-converted checkpoint saved to: {export_path}")
|
||||
return
|
||||
|
||||
model = model.cuda(gpu_no)
|
||||
|
||||
# Save prepared model for fast loading next time (before torch.compile)
|
||||
print(f">>> Saving prepared model to {prepared_path} ...")
|
||||
torch.save(model, prepared_path)
|
||||
print(f">>> Prepared model saved ({os.path.getsize(prepared_path) / 1024**3:.1f} GB).")
|
||||
|
||||
# Compile hot ResBlocks for operator fusion (after save, compiled objects can't be pickled)
|
||||
apply_torch_compile(model)
|
||||
|
||||
# Build normalizer (always needed, independent of model loading path)
|
||||
logging.info("***** Configing Data *****")
|
||||
data = instantiate_from_config(config.data)
|
||||
data.setup()
|
||||
print(">>> Dataset is successfully loaded ...")
|
||||
|
||||
model = model.cuda(gpu_no)
|
||||
device = get_device_from_parameters(model)
|
||||
|
||||
# Run over data
|
||||
@@ -587,7 +862,8 @@ def run_inference(args: argparse.Namespace, gpu_num: int, gpu_no: int) -> None:
|
||||
fs=model_input_fs,
|
||||
timestep_spacing=args.timestep_spacing,
|
||||
guidance_rescale=args.guidance_rescale,
|
||||
sim_mode=False)
|
||||
sim_mode=False,
|
||||
decode_video=not args.fast_policy_no_decode)
|
||||
|
||||
# Update future actions in the observation queues
|
||||
for idx in range(len(pred_actions[0])):
|
||||
@@ -645,28 +921,30 @@ def run_inference(args: argparse.Namespace, gpu_num: int, gpu_no: int) -> None:
|
||||
observation)
|
||||
|
||||
# Save the imagen videos for decision-making
|
||||
sample_tag = f"{args.dataset}-vid{sample['videoid']}-dm-fs-{fs}/itr-{itr}"
|
||||
log_to_tensorboard(writer,
|
||||
pred_videos_0,
|
||||
sample_tag,
|
||||
fps=args.save_fps)
|
||||
if pred_videos_0 is not None:
|
||||
sample_tag = f"{args.dataset}-vid{sample['videoid']}-dm-fs-{fs}/itr-{itr}"
|
||||
log_to_tensorboard_async(writer,
|
||||
pred_videos_0,
|
||||
sample_tag,
|
||||
fps=args.save_fps)
|
||||
# Save videos environment changes via world-model interaction
|
||||
sample_tag = f"{args.dataset}-vid{sample['videoid']}-wd-fs-{fs}/itr-{itr}"
|
||||
log_to_tensorboard(writer,
|
||||
pred_videos_1,
|
||||
sample_tag,
|
||||
fps=args.save_fps)
|
||||
log_to_tensorboard_async(writer,
|
||||
pred_videos_1,
|
||||
sample_tag,
|
||||
fps=args.save_fps)
|
||||
|
||||
# Save the imagen videos for decision-making
|
||||
sample_video_file = f'{video_save_dir}/dm/{fs}/itr-{itr}.mp4'
|
||||
save_results(pred_videos_0.cpu(),
|
||||
sample_video_file,
|
||||
fps=args.save_fps)
|
||||
if pred_videos_0 is not None:
|
||||
sample_video_file = f'{video_save_dir}/dm/{fs}/itr-{itr}.mp4'
|
||||
save_results_async(pred_videos_0,
|
||||
sample_video_file,
|
||||
fps=args.save_fps)
|
||||
# Save videos environment changes via world-model interaction
|
||||
sample_video_file = f'{video_save_dir}/wm/{fs}/itr-{itr}.mp4'
|
||||
save_results(pred_videos_1.cpu(),
|
||||
sample_video_file,
|
||||
fps=args.save_fps)
|
||||
save_results_async(pred_videos_1,
|
||||
sample_video_file,
|
||||
fps=args.save_fps)
|
||||
|
||||
print('>' * 24)
|
||||
# Collect the result of world-model interactions
|
||||
@@ -674,12 +952,15 @@ def run_inference(args: argparse.Namespace, gpu_num: int, gpu_no: int) -> None:
|
||||
|
||||
full_video = torch.cat(wm_video, dim=2)
|
||||
sample_tag = f"{args.dataset}-vid{sample['videoid']}-wd-fs-{fs}/full"
|
||||
log_to_tensorboard(writer,
|
||||
full_video,
|
||||
sample_tag,
|
||||
fps=args.save_fps)
|
||||
log_to_tensorboard_async(writer,
|
||||
full_video,
|
||||
sample_tag,
|
||||
fps=args.save_fps)
|
||||
sample_full_video_file = f"{video_save_dir}/../{sample['videoid']}_full_fs{fs}.mp4"
|
||||
save_results(full_video, sample_full_video_file, fps=args.save_fps)
|
||||
save_results_async(full_video, sample_full_video_file, fps=args.save_fps)
|
||||
|
||||
# Wait for all async I/O to complete
|
||||
_flush_io()
|
||||
|
||||
|
||||
def get_parser():
|
||||
@@ -794,14 +1075,49 @@ def get_parser():
|
||||
action='store_true',
|
||||
default=False,
|
||||
help="not using the predicted states as comparison")
|
||||
parser.add_argument(
|
||||
"--fast_policy_no_decode",
|
||||
action='store_true',
|
||||
default=False,
|
||||
help="Speed mode: policy pass only predicts actions, skip policy video decode/log/save.")
|
||||
parser.add_argument("--save_fps",
|
||||
type=int,
|
||||
default=8,
|
||||
help="fps for the saving video")
|
||||
parser.add_argument(
|
||||
"--diffusion_dtype",
|
||||
type=str,
|
||||
choices=["fp32", "bf16"],
|
||||
default="bf16",
|
||||
help="Diffusion backbone precision (fp32/bf16)")
|
||||
parser.add_argument(
|
||||
"--projector_mode",
|
||||
type=str,
|
||||
choices=["fp32", "autocast", "bf16_full"],
|
||||
default="bf16_full",
|
||||
help="Projector precision mode (fp32/autocast/bf16_full)")
|
||||
parser.add_argument(
|
||||
"--encoder_mode",
|
||||
type=str,
|
||||
choices=["fp32", "autocast", "bf16_full"],
|
||||
default="bf16_full",
|
||||
help="Encoder precision mode (fp32/autocast/bf16_full)")
|
||||
parser.add_argument(
|
||||
"--vae_dtype",
|
||||
type=str,
|
||||
choices=["fp32", "bf16"],
|
||||
default="fp32",
|
||||
help="VAE precision (fp32/bf16, most affects image quality)")
|
||||
parser.add_argument(
|
||||
"--export_precision_ckpt",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Export precision-converted checkpoint to this path, then exit.")
|
||||
return parser
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
patch_norm_bypass_autocast()
|
||||
parser = get_parser()
|
||||
args = parser.parse_args()
|
||||
seed = args.seed
|
||||
|
||||
@@ -99,6 +99,8 @@ class AutoencoderKL(pl.LightningModule):
|
||||
print(f"Restored from {path}")
|
||||
|
||||
def encode(self, x, **kwargs):
|
||||
if getattr(self, '_channels_last', False):
|
||||
x = x.to(memory_format=torch.channels_last)
|
||||
|
||||
h = self.encoder(x)
|
||||
moments = self.quant_conv(h)
|
||||
@@ -106,6 +108,8 @@ class AutoencoderKL(pl.LightningModule):
|
||||
return posterior
|
||||
|
||||
def decode(self, z, **kwargs):
|
||||
if getattr(self, '_channels_last', False):
|
||||
z = z.to(memory_format=torch.channels_last)
|
||||
z = self.post_quant_conv(z)
|
||||
dec = self.decoder(z)
|
||||
return dec
|
||||
|
||||
@@ -1074,10 +1074,10 @@ class LatentDiffusion(DDPM):
|
||||
encoder_posterior = self.first_stage_model.encode(x)
|
||||
results = self.get_first_stage_encoding(encoder_posterior).detach()
|
||||
else: ## Consume less GPU memory but slower
|
||||
bs = getattr(self, 'vae_encode_bs', 1)
|
||||
results = []
|
||||
for index in range(x.shape[0]):
|
||||
frame_batch = self.first_stage_model.encode(x[index:index +
|
||||
1, :, :, :])
|
||||
for i in range(0, x.shape[0], bs):
|
||||
frame_batch = self.first_stage_model.encode(x[i:i + bs])
|
||||
frame_result = self.get_first_stage_encoding(
|
||||
frame_batch).detach()
|
||||
results.append(frame_result)
|
||||
@@ -1105,14 +1105,18 @@ class LatentDiffusion(DDPM):
|
||||
else:
|
||||
reshape_back = False
|
||||
|
||||
# Align input dtype with VAE weights (e.g. fp32 samples → bf16 VAE)
|
||||
vae_dtype = next(self.first_stage_model.parameters()).dtype
|
||||
z = z.to(dtype=vae_dtype)
|
||||
|
||||
z = 1. / self.scale_factor * z
|
||||
if not self.perframe_ae:
|
||||
z = 1. / self.scale_factor * z
|
||||
results = self.first_stage_model.decode(z, **kwargs)
|
||||
else:
|
||||
bs = getattr(self, 'vae_decode_bs', 1)
|
||||
results = []
|
||||
for index in range(z.shape[0]):
|
||||
frame_z = 1. / self.scale_factor * z[index:index + 1, :, :, :]
|
||||
frame_result = self.first_stage_model.decode(frame_z, **kwargs)
|
||||
for i in range(0, z.shape[0], bs):
|
||||
frame_result = self.first_stage_model.decode(z[i:i + bs], **kwargs)
|
||||
results.append(frame_result)
|
||||
results = torch.cat(results, dim=0)
|
||||
|
||||
@@ -1799,7 +1803,9 @@ class LatentDiffusion(DDPM):
|
||||
|
||||
"""
|
||||
if ddim:
|
||||
ddim_sampler = DDIMSampler(self)
|
||||
if not hasattr(self, '_ddim_sampler') or self._ddim_sampler is None:
|
||||
self._ddim_sampler = DDIMSampler(self)
|
||||
ddim_sampler = self._ddim_sampler
|
||||
shape = (self.channels, self.temporal_length, *self.image_size)
|
||||
samples, actions, states, intermediates = ddim_sampler.sample(
|
||||
ddim_steps, batch_size, shape, cond, verbose=False, **kwargs)
|
||||
@@ -2457,7 +2463,6 @@ class DiffusionWrapper(pl.LightningModule):
|
||||
Returns:
|
||||
Output from the inner diffusion model (tensor or tuple, depending on the model).
|
||||
"""
|
||||
|
||||
if self.conditioning_key is None:
|
||||
out = self.diffusion_model(x, t)
|
||||
elif self.conditioning_key == 'concat':
|
||||
|
||||
@@ -501,6 +501,10 @@ class ConditionalUnet1D(nn.Module):
|
||||
self.last_frame_only = last_frame_only
|
||||
self.horizon = horizon
|
||||
|
||||
# Context precomputation cache
|
||||
self._global_cond_cache_enabled = False
|
||||
self._global_cond_cache = {}
|
||||
|
||||
def forward(self,
|
||||
sample: torch.Tensor,
|
||||
timestep: Union[torch.Tensor, float, int],
|
||||
@@ -530,14 +534,20 @@ class ConditionalUnet1D(nn.Module):
|
||||
B, T, D = sample.shape
|
||||
if self.use_linear_act_proj:
|
||||
sample = self.proj_in_action(sample.unsqueeze(-1))
|
||||
global_cond = self.obs_encoder(cond)
|
||||
global_cond = rearrange(global_cond,
|
||||
'(b t) d -> b 1 (t d)',
|
||||
b=B,
|
||||
t=self.n_obs_steps)
|
||||
global_cond = repeat(global_cond,
|
||||
'b c d -> b (repeat c) d',
|
||||
repeat=T)
|
||||
_gc_key = (cond['image'].data_ptr(), cond['agent_pos'].data_ptr())
|
||||
if self._global_cond_cache_enabled and _gc_key in self._global_cond_cache:
|
||||
global_cond = self._global_cond_cache[_gc_key]
|
||||
else:
|
||||
global_cond = self.obs_encoder(cond)
|
||||
global_cond = rearrange(global_cond,
|
||||
'(b t) d -> b 1 (t d)',
|
||||
b=B,
|
||||
t=self.n_obs_steps)
|
||||
global_cond = repeat(global_cond,
|
||||
'b c d -> b (repeat c) d',
|
||||
repeat=T)
|
||||
if self._global_cond_cache_enabled:
|
||||
self._global_cond_cache[_gc_key] = global_cond
|
||||
else:
|
||||
sample = einops.rearrange(sample, 'b h t -> b t h')
|
||||
sample = self.proj_in_horizon(sample)
|
||||
|
||||
@@ -8,12 +8,14 @@ class SinusoidalPosEmb(nn.Module):
|
||||
def __init__(self, dim):
|
||||
super().__init__()
|
||||
self.dim = dim
|
||||
# Dummy buffer so .to(dtype) propagates to this module
|
||||
self.register_buffer('_dtype_buf', torch.zeros(1), persistent=False)
|
||||
|
||||
def forward(self, x):
|
||||
device = x.device
|
||||
half_dim = self.dim // 2
|
||||
emb = math.log(10000) / (half_dim - 1)
|
||||
emb = torch.exp(torch.arange(half_dim, device=device) * -emb)
|
||||
emb = x[:, None] * emb[None, :]
|
||||
emb = x.float()[:, None] * emb[None, :]
|
||||
emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
|
||||
return emb
|
||||
return emb.to(self._dtype_buf.dtype)
|
||||
|
||||
@@ -0,0 +1,7 @@
|
||||
{
|
||||
"permissions": {
|
||||
"allow": [
|
||||
"Bash(python3:*)"
|
||||
]
|
||||
}
|
||||
}
|
||||
@@ -6,6 +6,8 @@ from unifolm_wma.utils.diffusion import make_ddim_sampling_parameters, make_ddim
|
||||
from unifolm_wma.utils.common import noise_like
|
||||
from unifolm_wma.utils.common import extract_into_tensor
|
||||
from tqdm import tqdm
|
||||
from unifolm_wma.modules.attention import enable_cross_attn_kv_cache, disable_cross_attn_kv_cache
|
||||
from unifolm_wma.modules.networks.wma_model import enable_ctx_cache, disable_ctx_cache
|
||||
|
||||
|
||||
class DDIMSampler(object):
|
||||
@@ -16,6 +18,7 @@ class DDIMSampler(object):
|
||||
self.ddpm_num_timesteps = model.num_timesteps
|
||||
self.schedule = schedule
|
||||
self.counter = 0
|
||||
self._schedule_key = None # (ddim_num_steps, ddim_discretize, ddim_eta)
|
||||
|
||||
def register_buffer(self, name, attr):
|
||||
if type(attr) == torch.Tensor:
|
||||
@@ -28,6 +31,11 @@ class DDIMSampler(object):
|
||||
ddim_discretize="uniform",
|
||||
ddim_eta=0.,
|
||||
verbose=True):
|
||||
key = (ddim_num_steps, ddim_discretize, ddim_eta)
|
||||
if self._schedule_key == key:
|
||||
return
|
||||
self._schedule_key = key
|
||||
|
||||
self.ddim_timesteps = make_ddim_timesteps(
|
||||
ddim_discr_method=ddim_discretize,
|
||||
num_ddim_timesteps=ddim_num_steps,
|
||||
@@ -36,7 +44,7 @@ class DDIMSampler(object):
|
||||
alphas_cumprod = self.model.alphas_cumprod
|
||||
assert alphas_cumprod.shape[
|
||||
0] == self.ddpm_num_timesteps, 'alphas have to be defined for each timestep'
|
||||
to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.model
|
||||
to_torch = lambda x: x.clone().detach().to(torch.float64).to(self.model
|
||||
.device)
|
||||
|
||||
if self.model.use_dynamic_rescale:
|
||||
@@ -67,11 +75,12 @@ class DDIMSampler(object):
|
||||
ddim_timesteps=self.ddim_timesteps,
|
||||
eta=ddim_eta,
|
||||
verbose=verbose)
|
||||
self.register_buffer('ddim_sigmas', ddim_sigmas)
|
||||
self.register_buffer('ddim_alphas', ddim_alphas)
|
||||
self.register_buffer('ddim_alphas_prev', ddim_alphas_prev)
|
||||
# Ensure tensors are on correct device for efficient indexing
|
||||
self.register_buffer('ddim_sigmas', to_torch(torch.as_tensor(ddim_sigmas)))
|
||||
self.register_buffer('ddim_alphas', to_torch(torch.as_tensor(ddim_alphas)))
|
||||
self.register_buffer('ddim_alphas_prev', to_torch(torch.as_tensor(ddim_alphas_prev)))
|
||||
self.register_buffer('ddim_sqrt_one_minus_alphas',
|
||||
np.sqrt(1. - ddim_alphas))
|
||||
to_torch(torch.as_tensor(np.sqrt(1. - ddim_alphas))))
|
||||
sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt(
|
||||
(1 - self.alphas_cumprod_prev) / (1 - self.alphas_cumprod) *
|
||||
(1 - self.alphas_cumprod / self.alphas_cumprod_prev))
|
||||
@@ -208,9 +217,9 @@ class DDIMSampler(object):
|
||||
|
||||
if precision is not None:
|
||||
if precision == 16:
|
||||
img = img.to(dtype=torch.float16)
|
||||
action = action.to(dtype=torch.float16)
|
||||
state = state.to(dtype=torch.float16)
|
||||
img = img.to(dtype=torch.bfloat16)
|
||||
action = action.to(dtype=torch.bfloat16)
|
||||
state = state.to(dtype=torch.bfloat16)
|
||||
|
||||
if timesteps is None:
|
||||
timesteps = self.ddpm_num_timesteps if ddim_use_original_steps else self.ddim_timesteps
|
||||
@@ -241,63 +250,70 @@ class DDIMSampler(object):
|
||||
|
||||
dp_ddim_scheduler_action.set_timesteps(len(timesteps))
|
||||
dp_ddim_scheduler_state.set_timesteps(len(timesteps))
|
||||
for i, step in enumerate(iterator):
|
||||
index = total_steps - i - 1
|
||||
ts = torch.full((b, ), step, device=device, dtype=torch.long)
|
||||
ts = torch.empty((b, ), device=device, dtype=torch.long)
|
||||
enable_cross_attn_kv_cache(self.model)
|
||||
enable_ctx_cache(self.model)
|
||||
try:
|
||||
for i, step in enumerate(iterator):
|
||||
index = total_steps - i - 1
|
||||
ts.fill_(step)
|
||||
|
||||
# Use mask to blend noised original latent (img_orig) & new sampled latent (img)
|
||||
if mask is not None:
|
||||
assert x0 is not None
|
||||
if clean_cond:
|
||||
img_orig = x0
|
||||
else:
|
||||
img_orig = self.model.q_sample(x0, ts)
|
||||
img = img_orig * mask + (1. - mask) * img
|
||||
# Use mask to blend noised original latent (img_orig) & new sampled latent (img)
|
||||
if mask is not None:
|
||||
assert x0 is not None
|
||||
if clean_cond:
|
||||
img_orig = x0
|
||||
else:
|
||||
img_orig = self.model.q_sample(x0, ts)
|
||||
img = img_orig * mask + (1. - mask) * img
|
||||
|
||||
outs = self.p_sample_ddim(
|
||||
img,
|
||||
action,
|
||||
state,
|
||||
cond,
|
||||
ts,
|
||||
index=index,
|
||||
use_original_steps=ddim_use_original_steps,
|
||||
quantize_denoised=quantize_denoised,
|
||||
temperature=temperature,
|
||||
noise_dropout=noise_dropout,
|
||||
score_corrector=score_corrector,
|
||||
corrector_kwargs=corrector_kwargs,
|
||||
unconditional_guidance_scale=unconditional_guidance_scale,
|
||||
unconditional_conditioning=unconditional_conditioning,
|
||||
mask=mask,
|
||||
x0=x0,
|
||||
fs=fs,
|
||||
guidance_rescale=guidance_rescale,
|
||||
**kwargs)
|
||||
outs = self.p_sample_ddim(
|
||||
img,
|
||||
action,
|
||||
state,
|
||||
cond,
|
||||
ts,
|
||||
index=index,
|
||||
use_original_steps=ddim_use_original_steps,
|
||||
quantize_denoised=quantize_denoised,
|
||||
temperature=temperature,
|
||||
noise_dropout=noise_dropout,
|
||||
score_corrector=score_corrector,
|
||||
corrector_kwargs=corrector_kwargs,
|
||||
unconditional_guidance_scale=unconditional_guidance_scale,
|
||||
unconditional_conditioning=unconditional_conditioning,
|
||||
mask=mask,
|
||||
x0=x0,
|
||||
fs=fs,
|
||||
guidance_rescale=guidance_rescale,
|
||||
**kwargs)
|
||||
|
||||
img, pred_x0, model_output_action, model_output_state = outs
|
||||
img, pred_x0, model_output_action, model_output_state = outs
|
||||
|
||||
action = dp_ddim_scheduler_action.step(
|
||||
model_output_action,
|
||||
step,
|
||||
action,
|
||||
generator=None,
|
||||
).prev_sample
|
||||
state = dp_ddim_scheduler_state.step(
|
||||
model_output_state,
|
||||
step,
|
||||
state,
|
||||
generator=None,
|
||||
).prev_sample
|
||||
action = dp_ddim_scheduler_action.step(
|
||||
model_output_action,
|
||||
step,
|
||||
action,
|
||||
generator=None,
|
||||
).prev_sample
|
||||
state = dp_ddim_scheduler_state.step(
|
||||
model_output_state,
|
||||
step,
|
||||
state,
|
||||
generator=None,
|
||||
).prev_sample
|
||||
|
||||
if callback: callback(i)
|
||||
if img_callback: img_callback(pred_x0, i)
|
||||
if callback: callback(i)
|
||||
if img_callback: img_callback(pred_x0, i)
|
||||
|
||||
if index % log_every_t == 0 or index == total_steps - 1:
|
||||
intermediates['x_inter'].append(img)
|
||||
intermediates['pred_x0'].append(pred_x0)
|
||||
intermediates['x_inter_action'].append(action)
|
||||
intermediates['x_inter_state'].append(state)
|
||||
if index % log_every_t == 0 or index == total_steps - 1:
|
||||
intermediates['x_inter'].append(img)
|
||||
intermediates['pred_x0'].append(pred_x0)
|
||||
intermediates['x_inter_action'].append(action)
|
||||
intermediates['x_inter_state'].append(state)
|
||||
finally:
|
||||
disable_cross_attn_kv_cache(self.model)
|
||||
disable_ctx_cache(self.model)
|
||||
|
||||
return img, action, state, intermediates
|
||||
|
||||
@@ -325,10 +341,6 @@ class DDIMSampler(object):
|
||||
guidance_rescale=0.0,
|
||||
**kwargs):
|
||||
b, *_, device = *x.shape, x.device
|
||||
if x.dim() == 5:
|
||||
is_video = True
|
||||
else:
|
||||
is_video = False
|
||||
|
||||
if unconditional_conditioning is None or unconditional_guidance_scale == 1.:
|
||||
model_output, model_output_action, model_output_state = self.model.apply_model(
|
||||
@@ -377,17 +389,11 @@ class DDIMSampler(object):
|
||||
sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas
|
||||
sigmas = self.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas
|
||||
|
||||
if is_video:
|
||||
size = (b, 1, 1, 1, 1)
|
||||
else:
|
||||
size = (b, 1, 1, 1)
|
||||
|
||||
a_t = torch.full(size, alphas[index], device=device)
|
||||
a_prev = torch.full(size, alphas_prev[index], device=device)
|
||||
sigma_t = torch.full(size, sigmas[index], device=device)
|
||||
sqrt_one_minus_at = torch.full(size,
|
||||
sqrt_one_minus_alphas[index],
|
||||
device=device)
|
||||
# Use 0-d tensors directly (already on device); broadcasting handles shape
|
||||
a_t = alphas[index].to(x.dtype)
|
||||
a_prev = alphas_prev[index].to(x.dtype)
|
||||
sigma_t = sigmas[index].to(x.dtype)
|
||||
sqrt_one_minus_at = sqrt_one_minus_alphas[index].to(x.dtype)
|
||||
|
||||
if self.model.parameterization != "v":
|
||||
pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
|
||||
@@ -395,12 +401,8 @@ class DDIMSampler(object):
|
||||
pred_x0 = self.model.predict_start_from_z_and_v(x, t, model_output)
|
||||
|
||||
if self.model.use_dynamic_rescale:
|
||||
scale_t = torch.full(size,
|
||||
self.ddim_scale_arr[index],
|
||||
device=device)
|
||||
prev_scale_t = torch.full(size,
|
||||
self.ddim_scale_arr_prev[index],
|
||||
device=device)
|
||||
scale_t = self.ddim_scale_arr[index]
|
||||
prev_scale_t = self.ddim_scale_arr_prev[index]
|
||||
rescale = (prev_scale_t / scale_t)
|
||||
pred_x0 *= rescale
|
||||
|
||||
|
||||
@@ -86,9 +86,8 @@ class CrossAttention(nn.Module):
|
||||
self.relative_position_v = RelativePosition(
|
||||
num_units=dim_head, max_relative_position=temporal_length)
|
||||
else:
|
||||
## only used for spatial attention, while NOT for temporal attention
|
||||
if XFORMERS_IS_AVAILBLE and temporal_length is None:
|
||||
self.forward = self.efficient_forward
|
||||
## bmm fused-scale attention for all non-relative-position cases
|
||||
self.forward = self.bmm_forward
|
||||
|
||||
self.video_length = video_length
|
||||
self.image_cross_attention = image_cross_attention
|
||||
@@ -98,6 +97,9 @@ class CrossAttention(nn.Module):
|
||||
self.text_context_len = text_context_len
|
||||
self.agent_state_context_len = agent_state_context_len
|
||||
self.agent_action_context_len = agent_action_context_len
|
||||
self._kv_cache = {}
|
||||
self._kv_cache_enabled = False
|
||||
|
||||
self.cross_attention_scale_learnable = cross_attention_scale_learnable
|
||||
if self.image_cross_attention:
|
||||
self.to_k_ip = nn.Linear(context_dim, inner_dim, bias=False)
|
||||
@@ -125,7 +127,7 @@ class CrossAttention(nn.Module):
|
||||
context = default(context, x)
|
||||
|
||||
if self.image_cross_attention and not spatial_self_attn:
|
||||
assert 1 > 2, ">>> ERROR: should setup xformers and use efficient_forward ..."
|
||||
# assert 1 > 2, ">>> ERROR: should setup xformers and use efficient_forward ..."
|
||||
context_agent_state = context[:, :self.agent_state_context_len, :]
|
||||
context_agent_action = context[:,
|
||||
self.agent_state_context_len:self.
|
||||
@@ -173,7 +175,8 @@ class CrossAttention(nn.Module):
|
||||
sim.masked_fill_(~(mask > 0.5), max_neg_value)
|
||||
|
||||
# attention, what we cannot get enough of
|
||||
sim = sim.softmax(dim=-1)
|
||||
with torch.amp.autocast('cuda', enabled=False):
|
||||
sim = sim.softmax(dim=-1)
|
||||
|
||||
out = torch.einsum('b i j, b j d -> b i d', sim, v)
|
||||
if self.relative_position:
|
||||
@@ -190,7 +193,8 @@ class CrossAttention(nn.Module):
|
||||
sim_ip = torch.einsum('b i d, b j d -> b i j', q,
|
||||
k_ip) * self.scale
|
||||
del k_ip
|
||||
sim_ip = sim_ip.softmax(dim=-1)
|
||||
with torch.amp.autocast('cuda', enabled=False):
|
||||
sim_ip = sim_ip.softmax(dim=-1)
|
||||
out_ip = torch.einsum('b i j, b j d -> b i d', sim_ip, v_ip)
|
||||
out_ip = rearrange(out_ip, '(b h) n d -> b n (h d)', h=h)
|
||||
|
||||
@@ -201,7 +205,8 @@ class CrossAttention(nn.Module):
|
||||
sim_as = torch.einsum('b i d, b j d -> b i j', q,
|
||||
k_as) * self.scale
|
||||
del k_as
|
||||
sim_as = sim_as.softmax(dim=-1)
|
||||
with torch.amp.autocast('cuda', enabled=False):
|
||||
sim_as = sim_as.softmax(dim=-1)
|
||||
out_as = torch.einsum('b i j, b j d -> b i d', sim_as, v_as)
|
||||
out_as = rearrange(out_as, '(b h) n d -> b n (h d)', h=h)
|
||||
|
||||
@@ -212,7 +217,8 @@ class CrossAttention(nn.Module):
|
||||
sim_aa = torch.einsum('b i d, b j d -> b i j', q,
|
||||
k_aa) * self.scale
|
||||
del k_aa
|
||||
sim_aa = sim_aa.softmax(dim=-1)
|
||||
with torch.amp.autocast('cuda', enabled=False):
|
||||
sim_aa = sim_aa.softmax(dim=-1)
|
||||
out_aa = torch.einsum('b i j, b j d -> b i d', sim_aa, v_aa)
|
||||
out_aa = rearrange(out_aa, '(b h) n d -> b n (h d)', h=h)
|
||||
|
||||
@@ -230,6 +236,141 @@ class CrossAttention(nn.Module):
|
||||
|
||||
return self.to_out(out)
|
||||
|
||||
def bmm_forward(self, x, context=None, mask=None):
|
||||
spatial_self_attn = (context is None)
|
||||
k_ip, v_ip, out_ip = None, None, None
|
||||
k_as, v_as, out_as = None, None, None
|
||||
k_aa, v_aa, out_aa = None, None, None
|
||||
|
||||
h = self.heads
|
||||
q = self.to_q(x)
|
||||
context = default(context, x)
|
||||
|
||||
use_cache = self._kv_cache_enabled and not spatial_self_attn
|
||||
cache_hit = use_cache and len(self._kv_cache) > 0
|
||||
|
||||
if cache_hit:
|
||||
# Reuse cached K/V (already in (b*h, n, d) shape)
|
||||
k = self._kv_cache['k']
|
||||
v = self._kv_cache['v']
|
||||
if 'k_ip' in self._kv_cache:
|
||||
k_ip = self._kv_cache['k_ip']
|
||||
v_ip = self._kv_cache['v_ip']
|
||||
k_as = self._kv_cache['k_as']
|
||||
v_as = self._kv_cache['v_as']
|
||||
k_aa = self._kv_cache['k_aa']
|
||||
v_aa = self._kv_cache['v_aa']
|
||||
q = rearrange(q, 'b n (h d) -> (b h) n d', h=h)
|
||||
elif self.image_cross_attention and not spatial_self_attn:
|
||||
context_agent_state = context[:, :self.agent_state_context_len, :]
|
||||
context_agent_action = context[:,
|
||||
self.agent_state_context_len:self.
|
||||
agent_state_context_len +
|
||||
self.agent_action_context_len, :]
|
||||
context_ins = context[:, self.agent_state_context_len +
|
||||
self.agent_action_context_len:self.
|
||||
agent_state_context_len +
|
||||
self.agent_action_context_len +
|
||||
self.text_context_len, :]
|
||||
context_image = context[:, self.agent_state_context_len +
|
||||
self.agent_action_context_len +
|
||||
self.text_context_len:, :]
|
||||
|
||||
k = self.to_k(context_ins)
|
||||
v = self.to_v(context_ins)
|
||||
k_ip = self.to_k_ip(context_image)
|
||||
v_ip = self.to_v_ip(context_image)
|
||||
k_as = self.to_k_as(context_agent_state)
|
||||
v_as = self.to_v_as(context_agent_state)
|
||||
k_aa = self.to_k_aa(context_agent_action)
|
||||
v_aa = self.to_v_aa(context_agent_action)
|
||||
|
||||
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h),
|
||||
(q, k, v))
|
||||
k_ip, v_ip = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h),
|
||||
(k_ip, v_ip))
|
||||
k_as, v_as = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h),
|
||||
(k_as, v_as))
|
||||
k_aa, v_aa = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h),
|
||||
(k_aa, v_aa))
|
||||
|
||||
if use_cache:
|
||||
self._kv_cache = {
|
||||
'k': k, 'v': v,
|
||||
'k_ip': k_ip, 'v_ip': v_ip,
|
||||
'k_as': k_as, 'v_as': v_as,
|
||||
'k_aa': k_aa, 'v_aa': v_aa,
|
||||
}
|
||||
else:
|
||||
if not spatial_self_attn:
|
||||
context = context[:, :self.text_context_len, :]
|
||||
k = self.to_k(context)
|
||||
v = self.to_v(context)
|
||||
|
||||
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h),
|
||||
(q, k, v))
|
||||
|
||||
if use_cache:
|
||||
self._kv_cache = {'k': k, 'v': v}
|
||||
|
||||
# baddbmm: fuse scale into GEMM → one kernel instead of matmul + mul
|
||||
sim = torch.baddbmm(
|
||||
torch.empty(q.shape[0], q.shape[1], k.shape[1], dtype=q.dtype, device=q.device),
|
||||
q, k.transpose(-1, -2), beta=0, alpha=self.scale)
|
||||
|
||||
if exists(mask):
|
||||
max_neg_value = -torch.finfo(sim.dtype).max
|
||||
mask = repeat(mask, 'b i j -> (b h) i j', h=h)
|
||||
sim.masked_fill_(~(mask > 0.5), max_neg_value)
|
||||
|
||||
with torch.amp.autocast('cuda', enabled=False):
|
||||
sim = sim.softmax(dim=-1)
|
||||
|
||||
out = torch.bmm(sim, v)
|
||||
out = rearrange(out, '(b h) n d -> b n (h d)', h=h)
|
||||
|
||||
if k_ip is not None and k_as is not None and k_aa is not None:
|
||||
## image cross-attention (k_ip/v_ip already in (b*h, n, d) shape)
|
||||
sim_ip = torch.baddbmm(
|
||||
torch.empty(q.shape[0], q.shape[1], k_ip.shape[1], dtype=q.dtype, device=q.device),
|
||||
q, k_ip.transpose(-1, -2), beta=0, alpha=self.scale)
|
||||
with torch.amp.autocast('cuda', enabled=False):
|
||||
sim_ip = sim_ip.softmax(dim=-1)
|
||||
out_ip = torch.bmm(sim_ip, v_ip)
|
||||
out_ip = rearrange(out_ip, '(b h) n d -> b n (h d)', h=h)
|
||||
|
||||
## agent state cross-attention (k_as/v_as already in (b*h, n, d) shape)
|
||||
sim_as = torch.baddbmm(
|
||||
torch.empty(q.shape[0], q.shape[1], k_as.shape[1], dtype=q.dtype, device=q.device),
|
||||
q, k_as.transpose(-1, -2), beta=0, alpha=self.scale)
|
||||
with torch.amp.autocast('cuda', enabled=False):
|
||||
sim_as = sim_as.softmax(dim=-1)
|
||||
out_as = torch.bmm(sim_as, v_as)
|
||||
out_as = rearrange(out_as, '(b h) n d -> b n (h d)', h=h)
|
||||
|
||||
## agent action cross-attention (k_aa/v_aa already in (b*h, n, d) shape)
|
||||
sim_aa = torch.baddbmm(
|
||||
torch.empty(q.shape[0], q.shape[1], k_aa.shape[1], dtype=q.dtype, device=q.device),
|
||||
q, k_aa.transpose(-1, -2), beta=0, alpha=self.scale)
|
||||
with torch.amp.autocast('cuda', enabled=False):
|
||||
sim_aa = sim_aa.softmax(dim=-1)
|
||||
out_aa = torch.bmm(sim_aa, v_aa)
|
||||
out_aa = rearrange(out_aa, '(b h) n d -> b n (h d)', h=h)
|
||||
|
||||
if out_ip is not None and out_as is not None and out_aa is not None:
|
||||
if self.cross_attention_scale_learnable:
|
||||
out = out + \
|
||||
self.image_cross_attention_scale * out_ip * (torch.tanh(self.alpha_ctx) + 1) + \
|
||||
self.agent_state_cross_attention_scale * out_as * (torch.tanh(self.alpha_cas) + 1) + \
|
||||
self.agent_action_cross_attention_scale * out_aa * (torch.tanh(self.alpha_caa) + 1)
|
||||
else:
|
||||
out = out + \
|
||||
self.image_cross_attention_scale * out_ip + \
|
||||
self.agent_state_cross_attention_scale * out_as + \
|
||||
self.agent_action_cross_attention_scale * out_aa
|
||||
|
||||
return self.to_out(out)
|
||||
|
||||
def efficient_forward(self, x, context=None, mask=None):
|
||||
spatial_self_attn = (context is None)
|
||||
k, v, out = None, None, None
|
||||
@@ -275,7 +416,8 @@ class CrossAttention(nn.Module):
|
||||
attn_mask_aa = self._get_attn_mask_aa(x.shape[0],
|
||||
q.shape[1],
|
||||
k_aa.shape[1],
|
||||
block_size=16).to(k_aa.device)
|
||||
block_size=16,
|
||||
device=k_aa.device)
|
||||
else:
|
||||
if not spatial_self_attn:
|
||||
assert 1 > 2, ">>> ERROR: you should never go into here ..."
|
||||
@@ -386,17 +528,43 @@ class CrossAttention(nn.Module):
|
||||
|
||||
return self.to_out(out)
|
||||
|
||||
def _get_attn_mask_aa(self, b, l1, l2, block_size=16):
|
||||
def _get_attn_mask_aa(self, b, l1, l2, block_size=16, device=None):
|
||||
cache_key = (b, l1, l2, block_size)
|
||||
if hasattr(self, '_attn_mask_aa_cache_key') and self._attn_mask_aa_cache_key == cache_key:
|
||||
cached = self._attn_mask_aa_cache
|
||||
if device is not None and cached.device != torch.device(device):
|
||||
cached = cached.to(device)
|
||||
self._attn_mask_aa_cache = cached
|
||||
return cached
|
||||
|
||||
target_device = device if device is not None else 'cpu'
|
||||
num_token = l2 // block_size
|
||||
start_positions = ((torch.arange(b) % block_size) + 1) * num_token
|
||||
col_indices = torch.arange(l2)
|
||||
start_positions = ((torch.arange(b, device=target_device) % block_size) + 1) * num_token
|
||||
col_indices = torch.arange(l2, device=target_device)
|
||||
mask_2d = col_indices.unsqueeze(0) >= start_positions.unsqueeze(1)
|
||||
mask = mask_2d.unsqueeze(1).expand(b, l1, l2)
|
||||
attn_mask = torch.zeros_like(mask, dtype=torch.float)
|
||||
attn_mask = torch.zeros(b, l1, l2, dtype=torch.bfloat16, device=target_device)
|
||||
attn_mask[mask] = float('-inf')
|
||||
|
||||
self._attn_mask_aa_cache_key = cache_key
|
||||
self._attn_mask_aa_cache = attn_mask
|
||||
return attn_mask
|
||||
|
||||
|
||||
def enable_cross_attn_kv_cache(module):
|
||||
for m in module.modules():
|
||||
if isinstance(m, CrossAttention):
|
||||
m._kv_cache_enabled = True
|
||||
m._kv_cache = {}
|
||||
|
||||
|
||||
def disable_cross_attn_kv_cache(module):
|
||||
for m in module.modules():
|
||||
if isinstance(m, CrossAttention):
|
||||
m._kv_cache_enabled = False
|
||||
m._kv_cache = {}
|
||||
|
||||
|
||||
class BasicTransformerBlock(nn.Module):
|
||||
|
||||
def __init__(self,
|
||||
|
||||
@@ -11,7 +11,7 @@ from unifolm_wma.utils.utils import instantiate_from_config
|
||||
|
||||
def nonlinearity(x):
|
||||
# swish
|
||||
return x * torch.sigmoid(x)
|
||||
return torch.nn.functional.silu(x)
|
||||
|
||||
|
||||
def Normalize(in_channels, num_groups=32):
|
||||
|
||||
@@ -422,7 +422,7 @@ class WMAModel(nn.Module):
|
||||
self.temporal_attention = temporal_attention
|
||||
time_embed_dim = model_channels * 4
|
||||
self.use_checkpoint = use_checkpoint
|
||||
self.dtype = torch.float16 if use_fp16 else torch.float32
|
||||
self.dtype = torch.float16 if use_fp16 else torch.bfloat16
|
||||
temporal_self_att_only = True
|
||||
self.addition_attention = addition_attention
|
||||
self.temporal_length = temporal_length
|
||||
@@ -685,6 +685,12 @@ class WMAModel(nn.Module):
|
||||
self.action_token_projector = instantiate_from_config(
|
||||
stem_process_config)
|
||||
|
||||
# Context precomputation cache
|
||||
self._ctx_cache_enabled = False
|
||||
self._ctx_cache = {}
|
||||
# fs_embed cache
|
||||
self._fs_embed_cache = None
|
||||
|
||||
def forward(self,
|
||||
x: Tensor,
|
||||
x_action: Tensor,
|
||||
@@ -720,58 +726,64 @@ class WMAModel(nn.Module):
|
||||
repeat_only=False).type(x.dtype)
|
||||
emb = self.time_embed(t_emb)
|
||||
|
||||
bt, l_context, _ = context.shape
|
||||
if self.base_model_gen_only:
|
||||
assert l_context == 77 + self.n_obs_steps * 16, ">>> ERROR Context dim 1 ..." ## NOTE HANDCODE
|
||||
_ctx_key = context.data_ptr()
|
||||
if self._ctx_cache_enabled and _ctx_key in self._ctx_cache:
|
||||
context = self._ctx_cache[_ctx_key]
|
||||
else:
|
||||
if l_context == self.n_obs_steps + 77 + t * 16:
|
||||
context_agent_state = context[:, :self.n_obs_steps]
|
||||
context_text = context[:, self.n_obs_steps:self.n_obs_steps +
|
||||
77, :]
|
||||
context_img = context[:, self.n_obs_steps + 77:, :]
|
||||
context_agent_state = context_agent_state.repeat_interleave(
|
||||
repeats=t, dim=0)
|
||||
context_text = context_text.repeat_interleave(repeats=t, dim=0)
|
||||
context_img = rearrange(context_img,
|
||||
'b (t l) c -> (b t) l c',
|
||||
t=t)
|
||||
context = torch.cat(
|
||||
[context_agent_state, context_text, context_img], dim=1)
|
||||
elif l_context == self.n_obs_steps + 16 + 77 + t * 16:
|
||||
context_agent_state = context[:, :self.n_obs_steps]
|
||||
context_agent_action = context[:, self.
|
||||
n_obs_steps:self.n_obs_steps +
|
||||
16, :]
|
||||
context_agent_action = rearrange(
|
||||
context_agent_action.unsqueeze(2), 'b t l d -> (b t) l d')
|
||||
context_agent_action = self.action_token_projector(
|
||||
context_agent_action)
|
||||
context_agent_action = rearrange(context_agent_action,
|
||||
'(b o) l d -> b o l d',
|
||||
o=t)
|
||||
context_agent_action = rearrange(context_agent_action,
|
||||
'b o (t l) d -> b o t l d',
|
||||
t=t)
|
||||
context_agent_action = context_agent_action.permute(
|
||||
0, 2, 1, 3, 4)
|
||||
context_agent_action = rearrange(context_agent_action,
|
||||
'b t o l d -> (b t) (o l) d')
|
||||
bt, l_context, _ = context.shape
|
||||
if self.base_model_gen_only:
|
||||
assert l_context == 77 + self.n_obs_steps * 16, ">>> ERROR Context dim 1 ..." ## NOTE HANDCODE
|
||||
else:
|
||||
if l_context == self.n_obs_steps + 77 + t * 16:
|
||||
context_agent_state = context[:, :self.n_obs_steps]
|
||||
context_text = context[:, self.n_obs_steps:self.n_obs_steps +
|
||||
77, :]
|
||||
context_img = context[:, self.n_obs_steps + 77:, :]
|
||||
context_agent_state = context_agent_state.repeat_interleave(
|
||||
repeats=t, dim=0)
|
||||
context_text = context_text.repeat_interleave(repeats=t, dim=0)
|
||||
context_img = rearrange(context_img,
|
||||
'b (t l) c -> (b t) l c',
|
||||
t=t)
|
||||
context = torch.cat(
|
||||
[context_agent_state, context_text, context_img], dim=1)
|
||||
elif l_context == self.n_obs_steps + 16 + 77 + t * 16:
|
||||
context_agent_state = context[:, :self.n_obs_steps]
|
||||
context_agent_action = context[:, self.
|
||||
n_obs_steps:self.n_obs_steps +
|
||||
16, :]
|
||||
context_agent_action = rearrange(
|
||||
context_agent_action.unsqueeze(2), 'b t l d -> (b t) l d')
|
||||
context_agent_action = self.action_token_projector(
|
||||
context_agent_action)
|
||||
context_agent_action = rearrange(context_agent_action,
|
||||
'(b o) l d -> b o l d',
|
||||
o=t)
|
||||
context_agent_action = rearrange(context_agent_action,
|
||||
'b o (t l) d -> b o t l d',
|
||||
t=t)
|
||||
context_agent_action = context_agent_action.permute(
|
||||
0, 2, 1, 3, 4)
|
||||
context_agent_action = rearrange(context_agent_action,
|
||||
'b t o l d -> (b t) (o l) d')
|
||||
|
||||
context_text = context[:, self.n_obs_steps +
|
||||
16:self.n_obs_steps + 16 + 77, :]
|
||||
context_text = context_text.repeat_interleave(repeats=t, dim=0)
|
||||
context_text = context[:, self.n_obs_steps +
|
||||
16:self.n_obs_steps + 16 + 77, :]
|
||||
context_text = context_text.repeat_interleave(repeats=t, dim=0)
|
||||
|
||||
context_img = context[:, self.n_obs_steps + 16 + 77:, :]
|
||||
context_img = rearrange(context_img,
|
||||
'b (t l) c -> (b t) l c',
|
||||
t=t)
|
||||
context_agent_state = context_agent_state.repeat_interleave(
|
||||
repeats=t, dim=0)
|
||||
context = torch.cat([
|
||||
context_agent_state, context_agent_action, context_text,
|
||||
context_img
|
||||
],
|
||||
dim=1)
|
||||
context_img = context[:, self.n_obs_steps + 16 + 77:, :]
|
||||
context_img = rearrange(context_img,
|
||||
'b (t l) c -> (b t) l c',
|
||||
t=t)
|
||||
context_agent_state = context_agent_state.repeat_interleave(
|
||||
repeats=t, dim=0)
|
||||
context = torch.cat([
|
||||
context_agent_state, context_agent_action, context_text,
|
||||
context_img
|
||||
],
|
||||
dim=1)
|
||||
if self._ctx_cache_enabled:
|
||||
self._ctx_cache[_ctx_key] = context
|
||||
|
||||
emb = emb.repeat_interleave(repeats=t, dim=0)
|
||||
|
||||
@@ -779,16 +791,20 @@ class WMAModel(nn.Module):
|
||||
|
||||
# Combine emb
|
||||
if self.fs_condition:
|
||||
if fs is None:
|
||||
fs = torch.tensor([self.default_fs] * b,
|
||||
dtype=torch.long,
|
||||
device=x.device)
|
||||
fs_emb = timestep_embedding(fs,
|
||||
self.model_channels,
|
||||
repeat_only=False).type(x.dtype)
|
||||
|
||||
fs_embed = self.fps_embedding(fs_emb)
|
||||
fs_embed = fs_embed.repeat_interleave(repeats=t, dim=0)
|
||||
if self._ctx_cache_enabled and self._fs_embed_cache is not None:
|
||||
fs_embed = self._fs_embed_cache
|
||||
else:
|
||||
if fs is None:
|
||||
fs = torch.tensor([self.default_fs] * b,
|
||||
dtype=torch.long,
|
||||
device=x.device)
|
||||
fs_emb = timestep_embedding(fs,
|
||||
self.model_channels,
|
||||
repeat_only=False).type(x.dtype)
|
||||
fs_embed = self.fps_embedding(fs_emb)
|
||||
fs_embed = fs_embed.repeat_interleave(repeats=t, dim=0)
|
||||
if self._ctx_cache_enabled:
|
||||
self._fs_embed_cache = fs_embed
|
||||
emb = emb + fs_embed
|
||||
|
||||
h = x.type(self.dtype)
|
||||
@@ -846,3 +862,32 @@ class WMAModel(nn.Module):
|
||||
s_y = torch.zeros_like(x_state)
|
||||
|
||||
return y, a_y, s_y
|
||||
|
||||
|
||||
def enable_ctx_cache(model):
|
||||
"""Enable context precomputation cache on WMAModel and its action/state UNets."""
|
||||
for m in model.modules():
|
||||
if isinstance(m, WMAModel):
|
||||
m._ctx_cache_enabled = True
|
||||
m._ctx_cache = {}
|
||||
m._fs_embed_cache = None
|
||||
# conditional_unet1d cache
|
||||
from unifolm_wma.models.diffusion_head.conditional_unet1d import ConditionalUnet1D
|
||||
for m in model.modules():
|
||||
if isinstance(m, ConditionalUnet1D):
|
||||
m._global_cond_cache_enabled = True
|
||||
m._global_cond_cache = {}
|
||||
|
||||
|
||||
def disable_ctx_cache(model):
|
||||
"""Disable and clear context precomputation cache."""
|
||||
for m in model.modules():
|
||||
if isinstance(m, WMAModel):
|
||||
m._ctx_cache_enabled = False
|
||||
m._ctx_cache = {}
|
||||
m._fs_embed_cache = None
|
||||
from unifolm_wma.models.diffusion_head.conditional_unet1d import ConditionalUnet1D
|
||||
for m in model.modules():
|
||||
if isinstance(m, ConditionalUnet1D):
|
||||
m._global_cond_cache_enabled = False
|
||||
m._global_cond_cache = {}
|
||||
|
||||
@@ -7,7 +7,9 @@
|
||||
#
|
||||
# thanks!
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from unifolm_wma.utils.utils import instantiate_from_config
|
||||
|
||||
|
||||
@@ -78,7 +80,11 @@ def nonlinearity(type='silu'):
|
||||
class GroupNormSpecific(nn.GroupNorm):
|
||||
|
||||
def forward(self, x):
|
||||
return super().forward(x.float()).type(x.dtype)
|
||||
with torch.amp.autocast('cuda', enabled=False):
|
||||
return F.group_norm(x, self.num_groups,
|
||||
self.weight.to(x.dtype) if self.weight is not None else None,
|
||||
self.bias.to(x.dtype) if self.bias is not None else None,
|
||||
self.eps)
|
||||
|
||||
|
||||
def normalization(channels, num_groups=32):
|
||||
|
||||
144
unitree_g1_pack_camera/case1/output.log
Normal file
144
unitree_g1_pack_camera/case1/output.log
Normal file
@@ -0,0 +1,144 @@
|
||||
2026-02-08 05:20:49.828675: I tensorflow/core/util/port.cc:113] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
|
||||
2026-02-08 05:20:49.831563: I external/local_tsl/tsl/cuda/cudart_stub.cc:31] Could not find cuda drivers on your machine, GPU will not be used.
|
||||
2026-02-08 05:20:49.861366: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
|
||||
2026-02-08 05:20:49.861402: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
|
||||
2026-02-08 05:20:49.862974: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
|
||||
2026-02-08 05:20:49.870402: I external/local_tsl/tsl/cuda/cudart_stub.cc:31] Could not find cuda drivers on your machine, GPU will not be used.
|
||||
2026-02-08 05:20:49.870647: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
|
||||
To enable the following instructions: AVX2 AVX512F AVX512_VNNI AVX512_BF16 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
|
||||
2026-02-08 05:20:50.486843: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT
|
||||
Global seed set to 123
|
||||
/mnt/ASC1637/miniconda3/envs/unifolm-wma/lib/python3.10/site-packages/kornia/feature/lightglue.py:44: FutureWarning: `torch.cuda.amp.custom_fwd(args...)` is deprecated. Please use `torch.amp.custom_fwd(args..., device_type='cuda')` instead.
|
||||
@torch.cuda.amp.custom_fwd(cast_inputs=torch.float32)
|
||||
INFO:mainlogger:LatentVisualDiffusion: Running in v-prediction mode
|
||||
INFO:unifolm_wma.models.diffusion_head.conditional_unet1d:number of parameters: 5.010531e+08
|
||||
INFO:unifolm_wma.models.diffusion_head.conditional_unet1d:number of parameters: 5.010531e+08
|
||||
AE working on z of shape (1, 4, 32, 32) = 4096 dimensions.
|
||||
INFO:root:Loaded ViT-H-14 model config.
|
||||
DEBUG:urllib3.connectionpool:Starting new HTTPS connection (1): hf-mirror.com:443
|
||||
DEBUG:urllib3.connectionpool:https://hf-mirror.com:443 "HEAD /laion/CLIP-ViT-H-14-laion2B-s32B-b79K/resolve/main/open_clip_pytorch_model.bin HTTP/1.1" 302 0
|
||||
INFO:root:Loading pretrained ViT-H-14 weights (laion2b_s32b_b79k).
|
||||
/mnt/ASC1637/miniconda3/envs/unifolm-wma/lib/python3.10/site-packages/open_clip/factory.py:88: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.
|
||||
checkpoint = torch.load(checkpoint_path, map_location=map_location)
|
||||
INFO:root:Loaded ViT-H-14 model config.
|
||||
DEBUG:urllib3.connectionpool:https://hf-mirror.com:443 "HEAD /laion/CLIP-ViT-H-14-laion2B-s32B-b79K/resolve/main/open_clip_pytorch_model.bin HTTP/1.1" 302 0
|
||||
INFO:root:Loading pretrained ViT-H-14 weights (laion2b_s32b_b79k).
|
||||
/mnt/ASC1637/unifolm-world-model-action/scripts/evaluation/world_model_interaction.py:86: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.
|
||||
state_dict = torch.load(ckpt, map_location="cpu")
|
||||
>>> model checkpoint loaded.
|
||||
>>> Load pre-trained model ...
|
||||
INFO:root:***** Configing Data *****
|
||||
>>> unitree_z1_stackbox: 1 data samples loaded.
|
||||
>>> unitree_z1_stackbox: data stats loaded.
|
||||
>>> unitree_z1_stackbox: normalizer initiated.
|
||||
>>> unitree_z1_dual_arm_stackbox: 1 data samples loaded.
|
||||
>>> unitree_z1_dual_arm_stackbox: data stats loaded.
|
||||
>>> unitree_z1_dual_arm_stackbox: normalizer initiated.
|
||||
>>> unitree_z1_dual_arm_stackbox_v2: 1 data samples loaded.
|
||||
>>> unitree_z1_dual_arm_stackbox_v2: data stats loaded.
|
||||
>>> unitree_z1_dual_arm_stackbox_v2: normalizer initiated.
|
||||
>>> unitree_z1_dual_arm_cleanup_pencils: 1 data samples loaded.
|
||||
>>> unitree_z1_dual_arm_cleanup_pencils: data stats loaded.
|
||||
>>> unitree_z1_dual_arm_cleanup_pencils: normalizer initiated.
|
||||
>>> unitree_g1_pack_camera: 1 data samples loaded.
|
||||
>>> unitree_g1_pack_camera: data stats loaded.
|
||||
>>> unitree_g1_pack_camera: normalizer initiated.
|
||||
>>> Dataset is successfully loaded ...
|
||||
>>> Generate 16 frames under each generation ...
|
||||
DEBUG:h5py._conv:Creating converter from 3 to 5
|
||||
DEBUG:PIL.PngImagePlugin:STREAM b'IHDR' 16 13
|
||||
DEBUG:PIL.PngImagePlugin:STREAM b'pHYs' 41 9
|
||||
DEBUG:PIL.PngImagePlugin:STREAM b'IDAT' 62 4096
|
||||
|
||||
0%| | 0/11 [00:00<?, ?it/s]/mnt/ASC1637/miniconda3/envs/unifolm-wma/lib/python3.10/site-packages/torch/nn/functional.py:5501: UserWarning: Attempting to use hipBLASLt on an unsupported architecture! Overriding blas backend to hipblas (Triggered internally at ../aten/src/ATen/Context.cpp:296.)
|
||||
proj = linear(q, w, b)
|
||||
/mnt/ASC1637/miniconda3/envs/unifolm-wma/lib/python3.10/site-packages/torch/nn/functional.py:6278: UserWarning: Flash attention support on Navi31 GPU is still experimental. Enable it with TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL=1. (Triggered internally at ../aten/src/ATen/native/transformers/hip/sdp_utils.cpp:225.)
|
||||
attn_output = scaled_dot_product_attention(
|
||||
/mnt/ASC1637/miniconda3/envs/unifolm-wma/lib/python3.10/site-packages/torch/nn/functional.py:6278: UserWarning: Memory Efficient attention on Navi31 GPU is still experimental. Enable it with TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL=1. (Triggered internally at ../aten/src/ATen/native/transformers/hip/sdp_utils.cpp:269.)
|
||||
attn_output = scaled_dot_product_attention(
|
||||
>>> Step 0: generating actions ...
|
||||
>>> Step 0: interacting with world model ...
|
||||
DEBUG:PIL.Image:Importing BlpImagePlugin
|
||||
DEBUG:PIL.Image:Importing BmpImagePlugin
|
||||
DEBUG:PIL.Image:Importing BufrStubImagePlugin
|
||||
DEBUG:PIL.Image:Importing CurImagePlugin
|
||||
DEBUG:PIL.Image:Importing DcxImagePlugin
|
||||
DEBUG:PIL.Image:Importing DdsImagePlugin
|
||||
DEBUG:PIL.Image:Importing EpsImagePlugin
|
||||
DEBUG:PIL.Image:Importing FitsImagePlugin
|
||||
DEBUG:PIL.Image:Importing FitsStubImagePlugin
|
||||
DEBUG:PIL.Image:Importing FliImagePlugin
|
||||
DEBUG:PIL.Image:Importing FpxImagePlugin
|
||||
DEBUG:PIL.Image:Image: failed to import FpxImagePlugin: No module named 'olefile'
|
||||
DEBUG:PIL.Image:Importing FtexImagePlugin
|
||||
DEBUG:PIL.Image:Importing GbrImagePlugin
|
||||
DEBUG:PIL.Image:Importing GifImagePlugin
|
||||
DEBUG:PIL.Image:Importing GribStubImagePlugin
|
||||
DEBUG:PIL.Image:Importing Hdf5StubImagePlugin
|
||||
DEBUG:PIL.Image:Importing IcnsImagePlugin
|
||||
DEBUG:PIL.Image:Importing IcoImagePlugin
|
||||
DEBUG:PIL.Image:Importing ImImagePlugin
|
||||
DEBUG:PIL.Image:Importing ImtImagePlugin
|
||||
DEBUG:PIL.Image:Importing IptcImagePlugin
|
||||
DEBUG:PIL.Image:Importing JpegImagePlugin
|
||||
DEBUG:PIL.Image:Importing Jpeg2KImagePlugin
|
||||
DEBUG:PIL.Image:Importing McIdasImagePlugin
|
||||
DEBUG:PIL.Image:Importing MicImagePlugin
|
||||
DEBUG:PIL.Image:Image: failed to import MicImagePlugin: No module named 'olefile'
|
||||
DEBUG:PIL.Image:Importing MpegImagePlugin
|
||||
DEBUG:PIL.Image:Importing MpoImagePlugin
|
||||
DEBUG:PIL.Image:Importing MspImagePlugin
|
||||
DEBUG:PIL.Image:Importing PalmImagePlugin
|
||||
DEBUG:PIL.Image:Importing PcdImagePlugin
|
||||
DEBUG:PIL.Image:Importing PcxImagePlugin
|
||||
DEBUG:PIL.Image:Importing PdfImagePlugin
|
||||
DEBUG:PIL.Image:Importing PixarImagePlugin
|
||||
DEBUG:PIL.Image:Importing PngImagePlugin
|
||||
DEBUG:PIL.Image:Importing PpmImagePlugin
|
||||
DEBUG:PIL.Image:Importing PsdImagePlugin
|
||||
DEBUG:PIL.Image:Importing QoiImagePlugin
|
||||
DEBUG:PIL.Image:Importing SgiImagePlugin
|
||||
DEBUG:PIL.Image:Importing SpiderImagePlugin
|
||||
DEBUG:PIL.Image:Importing SunImagePlugin
|
||||
DEBUG:PIL.Image:Importing TgaImagePlugin
|
||||
DEBUG:PIL.Image:Importing TiffImagePlugin
|
||||
DEBUG:PIL.Image:Importing WebPImagePlugin
|
||||
DEBUG:PIL.Image:Importing WmfImagePlugin
|
||||
DEBUG:PIL.Image:Importing XbmImagePlugin
|
||||
DEBUG:PIL.Image:Importing XpmImagePlugin
|
||||
DEBUG:PIL.Image:Importing XVThumbImagePlugin
|
||||
|
||||
9%|▉ | 1/11 [01:38<16:25, 98.56s/it]
|
||||
18%|█▊ | 2/11 [03:16<14:44, 98.31s/it]
|
||||
27%|██▋ | 3/11 [04:55<13:06, 98.33s/it]
|
||||
36%|███▋ | 4/11 [06:36<11:37, 99.66s/it]
|
||||
45%|████▌ | 5/11 [08:31<10:29, 104.96s/it]
|
||||
55%|█████▍ | 6/11 [10:10<08:35, 103.07s/it]
|
||||
64%|██████▎ | 7/11 [11:48<06:46, 101.50s/it]
|
||||
73%|███████▎ | 8/11 [13:27<05:01, 100.52s/it]
|
||||
82%|████████▏ | 9/11 [15:05<03:19, 99.79s/it]
|
||||
91%|█████████ | 10/11 [16:43<01:39, 99.30s/it]
|
||||
100%|██████████| 11/11 [18:21<00:00, 98.97s/it]
|
||||
100%|██████████| 11/11 [18:21<00:00, 100.16s/it]
|
||||
>>>>>>>>>>>>>>>>>>>>>>>>
|
||||
>>> Step 1: generating actions ...
|
||||
>>> Step 1: interacting with world model ...
|
||||
>>>>>>>>>>>>>>>>>>>>>>>>
|
||||
>>> Step 2: generating actions ...
|
||||
>>> Step 2: interacting with world model ...
|
||||
>>>>>>>>>>>>>>>>>>>>>>>>
|
||||
>>> Step 3: generating actions ...
|
||||
>>> Step 3: interacting with world model ...
|
||||
>>>>>>>>>>>>>>>>>>>>>>>>
|
||||
>>> Step 4: generating actions ...
|
||||
>>> Step 4: interacting with world model ...
|
||||
>>>>>>>>>>>>>>>>>>>>>>>>
|
||||
>>> Step 5: generating actions ...
|
||||
>>> Step 5: interacting with world model ...
|
||||
>>>>>>>>>>>>>>>>>>>>>>>>
|
||||
>>> Step 6: generating actions ...
|
||||
>>> Step 6: interacting with world model ...
|
||||
>>>>>>>>>>>>>>>>>>>>>>>>
|
||||
>>> Step 7: generating actions ...
|
||||
>>> Step 7: interacting with world model ...
|
||||
>>>>>>>>>>>>>>>>>>>>>>>>
|
||||
5
unitree_g1_pack_camera/case1/psnr_result.json
Normal file
5
unitree_g1_pack_camera/case1/psnr_result.json
Normal file
@@ -0,0 +1,5 @@
|
||||
{
|
||||
"gt_video": "unitree_g1_pack_camera/case1/unitree_g1_pack_camera_case1.mp4",
|
||||
"pred_video": "unitree_g1_pack_camera/case1/output/inference/unitree_g1_pack_camera_case1_amd.mp4",
|
||||
"psnr": 16.415668383379177
|
||||
}
|
||||
158
unitree_g1_pack_camera/case2/output.log
Normal file
158
unitree_g1_pack_camera/case2/output.log
Normal file
@@ -0,0 +1,158 @@
|
||||
/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/lightning_fabric/__init__.py:29: UserWarning: pkg_resources is deprecated as an API. See https://setuptools.pypa.io/en/latest/pkg_resources.html. The pkg_resources package is slated for removal as early as 2025-11-30. Refrain from using this package or pin to Setuptools<81.
|
||||
__import__("pkg_resources").declare_namespace(__name__)
|
||||
2026-02-08 18:28:48.960238: I tensorflow/core/util/port.cc:113] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
|
||||
2026-02-08 18:28:48.963331: I external/local_tsl/tsl/cuda/cudart_stub.cc:31] Could not find cuda drivers on your machine, GPU will not be used.
|
||||
2026-02-08 18:28:48.995688: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
|
||||
2026-02-08 18:28:48.995732: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
|
||||
2026-02-08 18:28:48.997547: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
|
||||
2026-02-08 18:28:49.005673: I external/local_tsl/tsl/cuda/cudart_stub.cc:31] Could not find cuda drivers on your machine, GPU will not be used.
|
||||
2026-02-08 18:28:49.005948: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
|
||||
To enable the following instructions: AVX2 AVX512F AVX512_VNNI AVX512_BF16 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
|
||||
2026-02-08 18:28:50.009660: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT
|
||||
[rank: 0] Global seed set to 123
|
||||
/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/kornia/feature/lightglue.py:44: FutureWarning: `torch.cuda.amp.custom_fwd(args...)` is deprecated. Please use `torch.amp.custom_fwd(args..., device_type='cuda')` instead.
|
||||
@torch.cuda.amp.custom_fwd(cast_inputs=torch.float32)
|
||||
INFO:mainlogger:LatentVisualDiffusion: Running in v-prediction mode
|
||||
INFO:unifolm_wma.models.diffusion_head.conditional_unet1d:number of parameters: 5.010531e+08
|
||||
INFO:unifolm_wma.models.diffusion_head.conditional_unet1d:number of parameters: 5.010531e+08
|
||||
AE working on z of shape (1, 4, 32, 32) = 4096 dimensions.
|
||||
INFO:root:Loaded ViT-H-14 model config.
|
||||
DEBUG:urllib3.connectionpool:Starting new HTTPS connection (1): hf-mirror.com:443
|
||||
DEBUG:urllib3.connectionpool:https://hf-mirror.com:443 "HEAD /laion/CLIP-ViT-H-14-laion2B-s32B-b79K/resolve/main/open_clip_pytorch_model.bin HTTP/1.1" 302 0
|
||||
INFO:root:Loading pretrained ViT-H-14 weights (laion2b_s32b_b79k).
|
||||
/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/open_clip/factory.py:88: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.
|
||||
checkpoint = torch.load(checkpoint_path, map_location=map_location)
|
||||
INFO:root:Loaded ViT-H-14 model config.
|
||||
DEBUG:urllib3.connectionpool:https://hf-mirror.com:443 "HEAD /laion/CLIP-ViT-H-14-laion2B-s32B-b79K/resolve/main/open_clip_pytorch_model.bin HTTP/1.1" 302 0
|
||||
INFO:root:Loading pretrained ViT-H-14 weights (laion2b_s32b_b79k).
|
||||
/mnt/ASC1637/unifolm-world-model-action/scripts/evaluation/world_model_interaction.py:198: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.
|
||||
state_dict = torch.load(ckpt, map_location="cpu")
|
||||
>>> model checkpoint loaded.
|
||||
>>> Load pre-trained model ...
|
||||
>>> Applying precision settings:
|
||||
- Diffusion dtype: bf16
|
||||
- Projector mode: bf16_full
|
||||
- Encoder mode: bf16_full
|
||||
- VAE dtype: fp32
|
||||
✓ Diffusion model weights converted to bfloat16
|
||||
✓ Projectors converted to bfloat16
|
||||
✓ Encoders converted to bfloat16
|
||||
✓ VAE kept in fp32 for best quality
|
||||
⚠ Found 849 fp32 params, converting to bf16
|
||||
✓ All parameters converted to bfloat16
|
||||
✓ torch.compile: 3 ResBlocks in output_blocks[5, 8, 9]
|
||||
INFO:root:***** Configing Data *****
|
||||
>>> unitree_z1_stackbox: 1 data samples loaded.
|
||||
>>> unitree_z1_stackbox: data stats loaded.
|
||||
>>> unitree_z1_stackbox: normalizer initiated.
|
||||
>>> unitree_z1_dual_arm_stackbox: 1 data samples loaded.
|
||||
>>> unitree_z1_dual_arm_stackbox: data stats loaded.
|
||||
>>> unitree_z1_dual_arm_stackbox: normalizer initiated.
|
||||
>>> unitree_z1_dual_arm_stackbox_v2: 1 data samples loaded.
|
||||
>>> unitree_z1_dual_arm_stackbox_v2: data stats loaded.
|
||||
>>> unitree_z1_dual_arm_stackbox_v2: normalizer initiated.
|
||||
>>> unitree_z1_dual_arm_cleanup_pencils: 1 data samples loaded.
|
||||
>>> unitree_z1_dual_arm_cleanup_pencils: data stats loaded.
|
||||
>>> unitree_z1_dual_arm_cleanup_pencils: normalizer initiated.
|
||||
>>> unitree_g1_pack_camera: 1 data samples loaded.
|
||||
>>> unitree_g1_pack_camera: data stats loaded.
|
||||
>>> unitree_g1_pack_camera: normalizer initiated.
|
||||
>>> Dataset is successfully loaded ...
|
||||
>>> Generate 16 frames under each generation ...
|
||||
DEBUG:h5py._conv:Creating converter from 3 to 5
|
||||
DEBUG:PIL.PngImagePlugin:STREAM b'IHDR' 16 13
|
||||
DEBUG:PIL.PngImagePlugin:STREAM b'pHYs' 41 9
|
||||
DEBUG:PIL.PngImagePlugin:STREAM b'IDAT' 62 4096
|
||||
|
||||
0%| | 0/11 [00:00<?, ?it/s]/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/torch/nn/functional.py:5501: UserWarning: Attempting to use hipBLASLt on an unsupported architecture! Overriding blas backend to hipblas (Triggered internally at ../aten/src/ATen/Context.cpp:296.)
|
||||
proj = linear(q, w, b)
|
||||
/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/torch/nn/functional.py:6278: UserWarning: Flash attention support on Navi31 GPU is still experimental. Enable it with TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL=1. (Triggered internally at ../aten/src/ATen/native/transformers/hip/sdp_utils.cpp:225.)
|
||||
attn_output = scaled_dot_product_attention(
|
||||
/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/torch/nn/functional.py:6278: UserWarning: Memory Efficient attention on Navi31 GPU is still experimental. Enable it with TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL=1. (Triggered internally at ../aten/src/ATen/native/transformers/hip/sdp_utils.cpp:269.)
|
||||
attn_output = scaled_dot_product_attention(
|
||||
>>> Step 0: generating actions ...
|
||||
>>> Step 0: interacting with world model ...
|
||||
DEBUG:PIL.Image:Importing BlpImagePlugin
|
||||
DEBUG:PIL.Image:Importing BmpImagePlugin
|
||||
DEBUG:PIL.Image:Importing BufrStubImagePlugin
|
||||
DEBUG:PIL.Image:Importing CurImagePlugin
|
||||
DEBUG:PIL.Image:Importing DcxImagePlugin
|
||||
DEBUG:PIL.Image:Importing DdsImagePlugin
|
||||
DEBUG:PIL.Image:Importing EpsImagePlugin
|
||||
DEBUG:PIL.Image:Importing FitsImagePlugin
|
||||
DEBUG:PIL.Image:Importing FitsStubImagePlugin
|
||||
DEBUG:PIL.Image:Importing FliImagePlugin
|
||||
DEBUG:PIL.Image:Importing FpxImagePlugin
|
||||
DEBUG:PIL.Image:Image: failed to import FpxImagePlugin: No module named 'olefile'
|
||||
DEBUG:PIL.Image:Importing FtexImagePlugin
|
||||
DEBUG:PIL.Image:Importing GbrImagePlugin
|
||||
DEBUG:PIL.Image:Importing GifImagePlugin
|
||||
DEBUG:PIL.Image:Importing GribStubImagePlugin
|
||||
DEBUG:PIL.Image:Importing Hdf5StubImagePlugin
|
||||
DEBUG:PIL.Image:Importing IcnsImagePlugin
|
||||
DEBUG:PIL.Image:Importing IcoImagePlugin
|
||||
DEBUG:PIL.Image:Importing ImImagePlugin
|
||||
DEBUG:PIL.Image:Importing ImtImagePlugin
|
||||
DEBUG:PIL.Image:Importing IptcImagePlugin
|
||||
DEBUG:PIL.Image:Importing JpegImagePlugin
|
||||
DEBUG:PIL.Image:Importing Jpeg2KImagePlugin
|
||||
DEBUG:PIL.Image:Importing McIdasImagePlugin
|
||||
DEBUG:PIL.Image:Importing MicImagePlugin
|
||||
DEBUG:PIL.Image:Image: failed to import MicImagePlugin: No module named 'olefile'
|
||||
DEBUG:PIL.Image:Importing MpegImagePlugin
|
||||
DEBUG:PIL.Image:Importing MpoImagePlugin
|
||||
DEBUG:PIL.Image:Importing MspImagePlugin
|
||||
DEBUG:PIL.Image:Importing PalmImagePlugin
|
||||
DEBUG:PIL.Image:Importing PcdImagePlugin
|
||||
DEBUG:PIL.Image:Importing PcxImagePlugin
|
||||
DEBUG:PIL.Image:Importing PdfImagePlugin
|
||||
DEBUG:PIL.Image:Importing PixarImagePlugin
|
||||
DEBUG:PIL.Image:Importing PngImagePlugin
|
||||
DEBUG:PIL.Image:Importing PpmImagePlugin
|
||||
DEBUG:PIL.Image:Importing PsdImagePlugin
|
||||
DEBUG:PIL.Image:Importing QoiImagePlugin
|
||||
DEBUG:PIL.Image:Importing SgiImagePlugin
|
||||
DEBUG:PIL.Image:Importing SpiderImagePlugin
|
||||
DEBUG:PIL.Image:Importing SunImagePlugin
|
||||
DEBUG:PIL.Image:Importing TgaImagePlugin
|
||||
DEBUG:PIL.Image:Importing TiffImagePlugin
|
||||
DEBUG:PIL.Image:Importing WebPImagePlugin
|
||||
DEBUG:PIL.Image:Importing WmfImagePlugin
|
||||
DEBUG:PIL.Image:Importing XbmImagePlugin
|
||||
DEBUG:PIL.Image:Importing XpmImagePlugin
|
||||
DEBUG:PIL.Image:Importing XVThumbImagePlugin
|
||||
|
||||
9%|▉ | 1/11 [01:14<12:29, 74.95s/it]
|
||||
18%|█▊ | 2/11 [02:23<10:40, 71.18s/it]
|
||||
27%|██▋ | 3/11 [03:32<09:20, 70.05s/it]
|
||||
36%|███▋ | 4/11 [04:40<08:06, 69.51s/it]
|
||||
45%|████▌ | 5/11 [05:49<06:55, 69.19s/it]
|
||||
55%|█████▍ | 6/11 [06:57<05:44, 68.95s/it]
|
||||
64%|██████▎ | 7/11 [08:06<04:35, 68.79s/it]
|
||||
73%|███████▎ | 8/11 [09:14<03:26, 68.70s/it]
|
||||
82%|████████▏ | 9/11 [10:23<02:17, 68.65s/it]
|
||||
91%|█████████ | 10/11 [11:31<01:08, 68.58s/it]
|
||||
100%|██████████| 11/11 [12:40<00:00, 68.51s/it]
|
||||
100%|██████████| 11/11 [12:40<00:00, 69.11s/it]
|
||||
>>>>>>>>>>>>>>>>>>>>>>>>
|
||||
>>> Step 1: generating actions ...
|
||||
>>> Step 1: interacting with world model ...
|
||||
>>>>>>>>>>>>>>>>>>>>>>>>
|
||||
>>> Step 2: generating actions ...
|
||||
>>> Step 2: interacting with world model ...
|
||||
>>>>>>>>>>>>>>>>>>>>>>>>
|
||||
>>> Step 3: generating actions ...
|
||||
>>> Step 3: interacting with world model ...
|
||||
>>>>>>>>>>>>>>>>>>>>>>>>
|
||||
>>> Step 4: generating actions ...
|
||||
>>> Step 4: interacting with world model ...
|
||||
>>>>>>>>>>>>>>>>>>>>>>>>
|
||||
>>> Step 5: generating actions ...
|
||||
>>> Step 5: interacting with world model ...
|
||||
>>>>>>>>>>>>>>>>>>>>>>>>
|
||||
>>> Step 6: generating actions ...
|
||||
>>> Step 6: interacting with world model ...
|
||||
>>>>>>>>>>>>>>>>>>>>>>>>
|
||||
>>> Step 7: generating actions ...
|
||||
>>> Step 7: interacting with world model ...
|
||||
>>>>>>>>>>>>>>>>>>>>>>>>
|
||||
5
unitree_g1_pack_camera/case2/psnr_result.json
Normal file
5
unitree_g1_pack_camera/case2/psnr_result.json
Normal file
@@ -0,0 +1,5 @@
|
||||
{
|
||||
"gt_video": "unitree_g1_pack_camera/case2/unitree_g1_pack_camera_case2.mp4",
|
||||
"pred_video": "unitree_g1_pack_camera/case2/output/inference/unitree_g1_pack_camera_case2_amd.mp4",
|
||||
"psnr": 19.515250190529375
|
||||
}
|
||||
144
unitree_g1_pack_camera/case3/output.log
Normal file
144
unitree_g1_pack_camera/case3/output.log
Normal file
@@ -0,0 +1,144 @@
|
||||
2026-02-08 05:08:32.803904: I tensorflow/core/util/port.cc:113] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
|
||||
2026-02-08 05:08:32.807010: I external/local_tsl/tsl/cuda/cudart_stub.cc:31] Could not find cuda drivers on your machine, GPU will not be used.
|
||||
2026-02-08 05:08:32.837936: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
|
||||
2026-02-08 05:08:32.837978: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
|
||||
2026-02-08 05:08:32.839785: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
|
||||
2026-02-08 05:08:32.847835: I external/local_tsl/tsl/cuda/cudart_stub.cc:31] Could not find cuda drivers on your machine, GPU will not be used.
|
||||
2026-02-08 05:08:32.848223: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
|
||||
To enable the following instructions: AVX2 AVX512F AVX512_VNNI AVX512_BF16 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
|
||||
2026-02-08 05:08:34.120114: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT
|
||||
[rank: 0] Global seed set to 123
|
||||
/mnt/ASC1637/miniconda3/envs/unifolm-wma/lib/python3.10/site-packages/kornia/feature/lightglue.py:44: FutureWarning: `torch.cuda.amp.custom_fwd(args...)` is deprecated. Please use `torch.amp.custom_fwd(args..., device_type='cuda')` instead.
|
||||
@torch.cuda.amp.custom_fwd(cast_inputs=torch.float32)
|
||||
INFO:mainlogger:LatentVisualDiffusion: Running in v-prediction mode
|
||||
INFO:unifolm_wma.models.diffusion_head.conditional_unet1d:number of parameters: 5.010531e+08
|
||||
INFO:unifolm_wma.models.diffusion_head.conditional_unet1d:number of parameters: 5.010531e+08
|
||||
AE working on z of shape (1, 4, 32, 32) = 4096 dimensions.
|
||||
INFO:root:Loaded ViT-H-14 model config.
|
||||
DEBUG:urllib3.connectionpool:Starting new HTTPS connection (1): hf-mirror.com:443
|
||||
DEBUG:urllib3.connectionpool:https://hf-mirror.com:443 "HEAD /laion/CLIP-ViT-H-14-laion2B-s32B-b79K/resolve/main/open_clip_pytorch_model.bin HTTP/1.1" 302 0
|
||||
INFO:root:Loading pretrained ViT-H-14 weights (laion2b_s32b_b79k).
|
||||
/mnt/ASC1637/miniconda3/envs/unifolm-wma/lib/python3.10/site-packages/open_clip/factory.py:88: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.
|
||||
checkpoint = torch.load(checkpoint_path, map_location=map_location)
|
||||
INFO:root:Loaded ViT-H-14 model config.
|
||||
DEBUG:urllib3.connectionpool:https://hf-mirror.com:443 "HEAD /laion/CLIP-ViT-H-14-laion2B-s32B-b79K/resolve/main/open_clip_pytorch_model.bin HTTP/1.1" 302 0
|
||||
INFO:root:Loading pretrained ViT-H-14 weights (laion2b_s32b_b79k).
|
||||
/mnt/ASC1637/unifolm-world-model-action/scripts/evaluation/world_model_interaction.py:86: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.
|
||||
state_dict = torch.load(ckpt, map_location="cpu")
|
||||
>>> model checkpoint loaded.
|
||||
>>> Load pre-trained model ...
|
||||
INFO:root:***** Configing Data *****
|
||||
>>> unitree_z1_stackbox: 1 data samples loaded.
|
||||
>>> unitree_z1_stackbox: data stats loaded.
|
||||
>>> unitree_z1_stackbox: normalizer initiated.
|
||||
>>> unitree_z1_dual_arm_stackbox: 1 data samples loaded.
|
||||
>>> unitree_z1_dual_arm_stackbox: data stats loaded.
|
||||
>>> unitree_z1_dual_arm_stackbox: normalizer initiated.
|
||||
>>> unitree_z1_dual_arm_stackbox_v2: 1 data samples loaded.
|
||||
>>> unitree_z1_dual_arm_stackbox_v2: data stats loaded.
|
||||
>>> unitree_z1_dual_arm_stackbox_v2: normalizer initiated.
|
||||
>>> unitree_z1_dual_arm_cleanup_pencils: 1 data samples loaded.
|
||||
>>> unitree_z1_dual_arm_cleanup_pencils: data stats loaded.
|
||||
>>> unitree_z1_dual_arm_cleanup_pencils: normalizer initiated.
|
||||
>>> unitree_g1_pack_camera: 1 data samples loaded.
|
||||
>>> unitree_g1_pack_camera: data stats loaded.
|
||||
>>> unitree_g1_pack_camera: normalizer initiated.
|
||||
>>> Dataset is successfully loaded ...
|
||||
>>> Generate 16 frames under each generation ...
|
||||
DEBUG:h5py._conv:Creating converter from 3 to 5
|
||||
DEBUG:PIL.PngImagePlugin:STREAM b'IHDR' 16 13
|
||||
DEBUG:PIL.PngImagePlugin:STREAM b'pHYs' 41 9
|
||||
DEBUG:PIL.PngImagePlugin:STREAM b'IDAT' 62 4096
|
||||
|
||||
0%| | 0/11 [00:00<?, ?it/s]/mnt/ASC1637/miniconda3/envs/unifolm-wma/lib/python3.10/site-packages/torch/nn/functional.py:5501: UserWarning: Attempting to use hipBLASLt on an unsupported architecture! Overriding blas backend to hipblas (Triggered internally at ../aten/src/ATen/Context.cpp:296.)
|
||||
proj = linear(q, w, b)
|
||||
/mnt/ASC1637/miniconda3/envs/unifolm-wma/lib/python3.10/site-packages/torch/nn/functional.py:6278: UserWarning: Flash attention support on Navi31 GPU is still experimental. Enable it with TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL=1. (Triggered internally at ../aten/src/ATen/native/transformers/hip/sdp_utils.cpp:225.)
|
||||
attn_output = scaled_dot_product_attention(
|
||||
/mnt/ASC1637/miniconda3/envs/unifolm-wma/lib/python3.10/site-packages/torch/nn/functional.py:6278: UserWarning: Memory Efficient attention on Navi31 GPU is still experimental. Enable it with TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL=1. (Triggered internally at ../aten/src/ATen/native/transformers/hip/sdp_utils.cpp:269.)
|
||||
attn_output = scaled_dot_product_attention(
|
||||
>>> Step 0: generating actions ...
|
||||
>>> Step 0: interacting with world model ...
|
||||
DEBUG:PIL.Image:Importing BlpImagePlugin
|
||||
DEBUG:PIL.Image:Importing BmpImagePlugin
|
||||
DEBUG:PIL.Image:Importing BufrStubImagePlugin
|
||||
DEBUG:PIL.Image:Importing CurImagePlugin
|
||||
DEBUG:PIL.Image:Importing DcxImagePlugin
|
||||
DEBUG:PIL.Image:Importing DdsImagePlugin
|
||||
DEBUG:PIL.Image:Importing EpsImagePlugin
|
||||
DEBUG:PIL.Image:Importing FitsImagePlugin
|
||||
DEBUG:PIL.Image:Importing FitsStubImagePlugin
|
||||
DEBUG:PIL.Image:Importing FliImagePlugin
|
||||
DEBUG:PIL.Image:Importing FpxImagePlugin
|
||||
DEBUG:PIL.Image:Image: failed to import FpxImagePlugin: No module named 'olefile'
|
||||
DEBUG:PIL.Image:Importing FtexImagePlugin
|
||||
DEBUG:PIL.Image:Importing GbrImagePlugin
|
||||
DEBUG:PIL.Image:Importing GifImagePlugin
|
||||
DEBUG:PIL.Image:Importing GribStubImagePlugin
|
||||
DEBUG:PIL.Image:Importing Hdf5StubImagePlugin
|
||||
DEBUG:PIL.Image:Importing IcnsImagePlugin
|
||||
DEBUG:PIL.Image:Importing IcoImagePlugin
|
||||
DEBUG:PIL.Image:Importing ImImagePlugin
|
||||
DEBUG:PIL.Image:Importing ImtImagePlugin
|
||||
DEBUG:PIL.Image:Importing IptcImagePlugin
|
||||
DEBUG:PIL.Image:Importing JpegImagePlugin
|
||||
DEBUG:PIL.Image:Importing Jpeg2KImagePlugin
|
||||
DEBUG:PIL.Image:Importing McIdasImagePlugin
|
||||
DEBUG:PIL.Image:Importing MicImagePlugin
|
||||
DEBUG:PIL.Image:Image: failed to import MicImagePlugin: No module named 'olefile'
|
||||
DEBUG:PIL.Image:Importing MpegImagePlugin
|
||||
DEBUG:PIL.Image:Importing MpoImagePlugin
|
||||
DEBUG:PIL.Image:Importing MspImagePlugin
|
||||
DEBUG:PIL.Image:Importing PalmImagePlugin
|
||||
DEBUG:PIL.Image:Importing PcdImagePlugin
|
||||
DEBUG:PIL.Image:Importing PcxImagePlugin
|
||||
DEBUG:PIL.Image:Importing PdfImagePlugin
|
||||
DEBUG:PIL.Image:Importing PixarImagePlugin
|
||||
DEBUG:PIL.Image:Importing PngImagePlugin
|
||||
DEBUG:PIL.Image:Importing PpmImagePlugin
|
||||
DEBUG:PIL.Image:Importing PsdImagePlugin
|
||||
DEBUG:PIL.Image:Importing QoiImagePlugin
|
||||
DEBUG:PIL.Image:Importing SgiImagePlugin
|
||||
DEBUG:PIL.Image:Importing SpiderImagePlugin
|
||||
DEBUG:PIL.Image:Importing SunImagePlugin
|
||||
DEBUG:PIL.Image:Importing TgaImagePlugin
|
||||
DEBUG:PIL.Image:Importing TiffImagePlugin
|
||||
DEBUG:PIL.Image:Importing WebPImagePlugin
|
||||
DEBUG:PIL.Image:Importing WmfImagePlugin
|
||||
DEBUG:PIL.Image:Importing XbmImagePlugin
|
||||
DEBUG:PIL.Image:Importing XpmImagePlugin
|
||||
DEBUG:PIL.Image:Importing XVThumbImagePlugin
|
||||
|
||||
9%|▉ | 1/11 [01:39<16:34, 99.46s/it]
|
||||
18%|█▊ | 2/11 [03:18<14:55, 99.48s/it]
|
||||
27%|██▋ | 3/11 [04:58<13:16, 99.60s/it]
|
||||
36%|███▋ | 4/11 [06:38<11:37, 99.69s/it]
|
||||
45%|████▌ | 5/11 [08:18<09:58, 99.68s/it]
|
||||
55%|█████▍ | 6/11 [09:57<08:18, 99.66s/it]
|
||||
64%|██████▎ | 7/11 [11:37<06:38, 99.62s/it]
|
||||
73%|███████▎ | 8/11 [13:16<04:58, 99.55s/it]
|
||||
82%|████████▏ | 9/11 [14:56<03:19, 99.50s/it]
|
||||
91%|█████████ | 10/11 [16:35<01:39, 99.43s/it]
|
||||
100%|██████████| 11/11 [18:14<00:00, 99.36s/it]
|
||||
100%|██████████| 11/11 [18:14<00:00, 99.51s/it]
|
||||
>>>>>>>>>>>>>>>>>>>>>>>>
|
||||
>>> Step 1: generating actions ...
|
||||
>>> Step 1: interacting with world model ...
|
||||
>>>>>>>>>>>>>>>>>>>>>>>>
|
||||
>>> Step 2: generating actions ...
|
||||
>>> Step 2: interacting with world model ...
|
||||
>>>>>>>>>>>>>>>>>>>>>>>>
|
||||
>>> Step 3: generating actions ...
|
||||
>>> Step 3: interacting with world model ...
|
||||
>>>>>>>>>>>>>>>>>>>>>>>>
|
||||
>>> Step 4: generating actions ...
|
||||
>>> Step 4: interacting with world model ...
|
||||
>>>>>>>>>>>>>>>>>>>>>>>>
|
||||
>>> Step 5: generating actions ...
|
||||
>>> Step 5: interacting with world model ...
|
||||
>>>>>>>>>>>>>>>>>>>>>>>>
|
||||
>>> Step 6: generating actions ...
|
||||
>>> Step 6: interacting with world model ...
|
||||
>>>>>>>>>>>>>>>>>>>>>>>>
|
||||
>>> Step 7: generating actions ...
|
||||
>>> Step 7: interacting with world model ...
|
||||
>>>>>>>>>>>>>>>>>>>>>>>>
|
||||
5
unitree_g1_pack_camera/case3/psnr_result.json
Normal file
5
unitree_g1_pack_camera/case3/psnr_result.json
Normal file
@@ -0,0 +1,5 @@
|
||||
{
|
||||
"gt_video": "unitree_g1_pack_camera/case3/unitree_g1_pack_camera_case3.mp4",
|
||||
"pred_video": "unitree_g1_pack_camera/case3/output/inference/unitree_g1_pack_camera_case3_amd.mp4",
|
||||
"psnr": 19.429578160315536
|
||||
}
|
||||
144
unitree_g1_pack_camera/case4/output.log
Normal file
144
unitree_g1_pack_camera/case4/output.log
Normal file
@@ -0,0 +1,144 @@
|
||||
2026-02-08 05:29:19.728303: I tensorflow/core/util/port.cc:113] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
|
||||
2026-02-08 05:29:19.731620: I external/local_tsl/tsl/cuda/cudart_stub.cc:31] Could not find cuda drivers on your machine, GPU will not be used.
|
||||
2026-02-08 05:29:19.761276: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
|
||||
2026-02-08 05:29:19.761301: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
|
||||
2026-02-08 05:29:19.762880: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
|
||||
2026-02-08 05:29:19.770578: I external/local_tsl/tsl/cuda/cudart_stub.cc:31] Could not find cuda drivers on your machine, GPU will not be used.
|
||||
2026-02-08 05:29:19.771072: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
|
||||
To enable the following instructions: AVX2 AVX512F AVX512_VNNI AVX512_BF16 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
|
||||
2026-02-08 05:29:21.043661: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT
|
||||
Global seed set to 123
|
||||
/mnt/ASC1637/miniconda3/envs/unifolm-wma/lib/python3.10/site-packages/kornia/feature/lightglue.py:44: FutureWarning: `torch.cuda.amp.custom_fwd(args...)` is deprecated. Please use `torch.amp.custom_fwd(args..., device_type='cuda')` instead.
|
||||
@torch.cuda.amp.custom_fwd(cast_inputs=torch.float32)
|
||||
INFO:mainlogger:LatentVisualDiffusion: Running in v-prediction mode
|
||||
INFO:unifolm_wma.models.diffusion_head.conditional_unet1d:number of parameters: 5.010531e+08
|
||||
INFO:unifolm_wma.models.diffusion_head.conditional_unet1d:number of parameters: 5.010531e+08
|
||||
AE working on z of shape (1, 4, 32, 32) = 4096 dimensions.
|
||||
INFO:root:Loaded ViT-H-14 model config.
|
||||
DEBUG:urllib3.connectionpool:Starting new HTTPS connection (1): hf-mirror.com:443
|
||||
DEBUG:urllib3.connectionpool:https://hf-mirror.com:443 "HEAD /laion/CLIP-ViT-H-14-laion2B-s32B-b79K/resolve/main/open_clip_pytorch_model.bin HTTP/1.1" 302 0
|
||||
INFO:root:Loading pretrained ViT-H-14 weights (laion2b_s32b_b79k).
|
||||
/mnt/ASC1637/miniconda3/envs/unifolm-wma/lib/python3.10/site-packages/open_clip/factory.py:88: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.
|
||||
checkpoint = torch.load(checkpoint_path, map_location=map_location)
|
||||
INFO:root:Loaded ViT-H-14 model config.
|
||||
DEBUG:urllib3.connectionpool:https://hf-mirror.com:443 "HEAD /laion/CLIP-ViT-H-14-laion2B-s32B-b79K/resolve/main/open_clip_pytorch_model.bin HTTP/1.1" 302 0
|
||||
INFO:root:Loading pretrained ViT-H-14 weights (laion2b_s32b_b79k).
|
||||
/mnt/ASC1637/unifolm-world-model-action/scripts/evaluation/world_model_interaction.py:86: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.
|
||||
state_dict = torch.load(ckpt, map_location="cpu")
|
||||
>>> model checkpoint loaded.
|
||||
>>> Load pre-trained model ...
|
||||
INFO:root:***** Configing Data *****
|
||||
>>> unitree_z1_stackbox: 1 data samples loaded.
|
||||
>>> unitree_z1_stackbox: data stats loaded.
|
||||
>>> unitree_z1_stackbox: normalizer initiated.
|
||||
>>> unitree_z1_dual_arm_stackbox: 1 data samples loaded.
|
||||
>>> unitree_z1_dual_arm_stackbox: data stats loaded.
|
||||
>>> unitree_z1_dual_arm_stackbox: normalizer initiated.
|
||||
>>> unitree_z1_dual_arm_stackbox_v2: 1 data samples loaded.
|
||||
>>> unitree_z1_dual_arm_stackbox_v2: data stats loaded.
|
||||
>>> unitree_z1_dual_arm_stackbox_v2: normalizer initiated.
|
||||
>>> unitree_z1_dual_arm_cleanup_pencils: 1 data samples loaded.
|
||||
>>> unitree_z1_dual_arm_cleanup_pencils: data stats loaded.
|
||||
>>> unitree_z1_dual_arm_cleanup_pencils: normalizer initiated.
|
||||
>>> unitree_g1_pack_camera: 1 data samples loaded.
|
||||
>>> unitree_g1_pack_camera: data stats loaded.
|
||||
>>> unitree_g1_pack_camera: normalizer initiated.
|
||||
>>> Dataset is successfully loaded ...
|
||||
>>> Generate 16 frames under each generation ...
|
||||
DEBUG:h5py._conv:Creating converter from 3 to 5
|
||||
DEBUG:PIL.PngImagePlugin:STREAM b'IHDR' 16 13
|
||||
DEBUG:PIL.PngImagePlugin:STREAM b'pHYs' 41 9
|
||||
DEBUG:PIL.PngImagePlugin:STREAM b'IDAT' 62 4096
|
||||
|
||||
0%| | 0/11 [00:00<?, ?it/s]/mnt/ASC1637/miniconda3/envs/unifolm-wma/lib/python3.10/site-packages/torch/nn/functional.py:5501: UserWarning: Attempting to use hipBLASLt on an unsupported architecture! Overriding blas backend to hipblas (Triggered internally at ../aten/src/ATen/Context.cpp:296.)
|
||||
proj = linear(q, w, b)
|
||||
/mnt/ASC1637/miniconda3/envs/unifolm-wma/lib/python3.10/site-packages/torch/nn/functional.py:6278: UserWarning: Flash attention support on Navi31 GPU is still experimental. Enable it with TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL=1. (Triggered internally at ../aten/src/ATen/native/transformers/hip/sdp_utils.cpp:225.)
|
||||
attn_output = scaled_dot_product_attention(
|
||||
/mnt/ASC1637/miniconda3/envs/unifolm-wma/lib/python3.10/site-packages/torch/nn/functional.py:6278: UserWarning: Memory Efficient attention on Navi31 GPU is still experimental. Enable it with TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL=1. (Triggered internally at ../aten/src/ATen/native/transformers/hip/sdp_utils.cpp:269.)
|
||||
attn_output = scaled_dot_product_attention(
|
||||
>>> Step 0: generating actions ...
|
||||
>>> Step 0: interacting with world model ...
|
||||
DEBUG:PIL.Image:Importing BlpImagePlugin
|
||||
DEBUG:PIL.Image:Importing BmpImagePlugin
|
||||
DEBUG:PIL.Image:Importing BufrStubImagePlugin
|
||||
DEBUG:PIL.Image:Importing CurImagePlugin
|
||||
DEBUG:PIL.Image:Importing DcxImagePlugin
|
||||
DEBUG:PIL.Image:Importing DdsImagePlugin
|
||||
DEBUG:PIL.Image:Importing EpsImagePlugin
|
||||
DEBUG:PIL.Image:Importing FitsImagePlugin
|
||||
DEBUG:PIL.Image:Importing FitsStubImagePlugin
|
||||
DEBUG:PIL.Image:Importing FliImagePlugin
|
||||
DEBUG:PIL.Image:Importing FpxImagePlugin
|
||||
DEBUG:PIL.Image:Image: failed to import FpxImagePlugin: No module named 'olefile'
|
||||
DEBUG:PIL.Image:Importing FtexImagePlugin
|
||||
DEBUG:PIL.Image:Importing GbrImagePlugin
|
||||
DEBUG:PIL.Image:Importing GifImagePlugin
|
||||
DEBUG:PIL.Image:Importing GribStubImagePlugin
|
||||
DEBUG:PIL.Image:Importing Hdf5StubImagePlugin
|
||||
DEBUG:PIL.Image:Importing IcnsImagePlugin
|
||||
DEBUG:PIL.Image:Importing IcoImagePlugin
|
||||
DEBUG:PIL.Image:Importing ImImagePlugin
|
||||
DEBUG:PIL.Image:Importing ImtImagePlugin
|
||||
DEBUG:PIL.Image:Importing IptcImagePlugin
|
||||
DEBUG:PIL.Image:Importing JpegImagePlugin
|
||||
DEBUG:PIL.Image:Importing Jpeg2KImagePlugin
|
||||
DEBUG:PIL.Image:Importing McIdasImagePlugin
|
||||
DEBUG:PIL.Image:Importing MicImagePlugin
|
||||
DEBUG:PIL.Image:Image: failed to import MicImagePlugin: No module named 'olefile'
|
||||
DEBUG:PIL.Image:Importing MpegImagePlugin
|
||||
DEBUG:PIL.Image:Importing MpoImagePlugin
|
||||
DEBUG:PIL.Image:Importing MspImagePlugin
|
||||
DEBUG:PIL.Image:Importing PalmImagePlugin
|
||||
DEBUG:PIL.Image:Importing PcdImagePlugin
|
||||
DEBUG:PIL.Image:Importing PcxImagePlugin
|
||||
DEBUG:PIL.Image:Importing PdfImagePlugin
|
||||
DEBUG:PIL.Image:Importing PixarImagePlugin
|
||||
DEBUG:PIL.Image:Importing PngImagePlugin
|
||||
DEBUG:PIL.Image:Importing PpmImagePlugin
|
||||
DEBUG:PIL.Image:Importing PsdImagePlugin
|
||||
DEBUG:PIL.Image:Importing QoiImagePlugin
|
||||
DEBUG:PIL.Image:Importing SgiImagePlugin
|
||||
DEBUG:PIL.Image:Importing SpiderImagePlugin
|
||||
DEBUG:PIL.Image:Importing SunImagePlugin
|
||||
DEBUG:PIL.Image:Importing TgaImagePlugin
|
||||
DEBUG:PIL.Image:Importing TiffImagePlugin
|
||||
DEBUG:PIL.Image:Importing WebPImagePlugin
|
||||
DEBUG:PIL.Image:Importing WmfImagePlugin
|
||||
DEBUG:PIL.Image:Importing XbmImagePlugin
|
||||
DEBUG:PIL.Image:Importing XpmImagePlugin
|
||||
DEBUG:PIL.Image:Importing XVThumbImagePlugin
|
||||
|
||||
9%|▉ | 1/11 [01:37<16:18, 97.81s/it]
|
||||
18%|█▊ | 2/11 [03:15<14:38, 97.56s/it]
|
||||
27%|██▋ | 3/11 [04:52<12:59, 97.48s/it]
|
||||
36%|███▋ | 4/11 [06:29<11:21, 97.38s/it]
|
||||
45%|████▌ | 5/11 [08:06<09:43, 97.28s/it]
|
||||
55%|█████▍ | 6/11 [09:44<08:06, 97.35s/it]
|
||||
64%|██████▎ | 7/11 [11:21<06:29, 97.36s/it]
|
||||
73%|███████▎ | 8/11 [12:59<04:52, 97.38s/it]
|
||||
82%|████████▏ | 9/11 [14:36<03:14, 97.39s/it]
|
||||
91%|█████████ | 10/11 [16:14<01:37, 97.42s/it]
|
||||
100%|██████████| 11/11 [17:51<00:00, 97.42s/it]
|
||||
100%|██████████| 11/11 [17:51<00:00, 97.41s/it]
|
||||
>>>>>>>>>>>>>>>>>>>>>>>>
|
||||
>>> Step 1: generating actions ...
|
||||
>>> Step 1: interacting with world model ...
|
||||
>>>>>>>>>>>>>>>>>>>>>>>>
|
||||
>>> Step 2: generating actions ...
|
||||
>>> Step 2: interacting with world model ...
|
||||
>>>>>>>>>>>>>>>>>>>>>>>>
|
||||
>>> Step 3: generating actions ...
|
||||
>>> Step 3: interacting with world model ...
|
||||
>>>>>>>>>>>>>>>>>>>>>>>>
|
||||
>>> Step 4: generating actions ...
|
||||
>>> Step 4: interacting with world model ...
|
||||
>>>>>>>>>>>>>>>>>>>>>>>>
|
||||
>>> Step 5: generating actions ...
|
||||
>>> Step 5: interacting with world model ...
|
||||
>>>>>>>>>>>>>>>>>>>>>>>>
|
||||
>>> Step 6: generating actions ...
|
||||
>>> Step 6: interacting with world model ...
|
||||
>>>>>>>>>>>>>>>>>>>>>>>>
|
||||
>>> Step 7: generating actions ...
|
||||
>>> Step 7: interacting with world model ...
|
||||
>>>>>>>>>>>>>>>>>>>>>>>>
|
||||
5
unitree_g1_pack_camera/case4/psnr_result.json
Normal file
5
unitree_g1_pack_camera/case4/psnr_result.json
Normal file
@@ -0,0 +1,5 @@
|
||||
{
|
||||
"gt_video": "unitree_g1_pack_camera/case4/unitree_g1_pack_camera_case4.mp4",
|
||||
"pred_video": "unitree_g1_pack_camera/case4/output/inference/unitree_g1_pack_camera_case4_amd.mp4",
|
||||
"psnr": 17.80386833747375
|
||||
}
|
||||
145
unitree_z1_dual_arm_cleanup_pencils/case1/output.log
Normal file
145
unitree_z1_dual_arm_cleanup_pencils/case1/output.log
Normal file
@@ -0,0 +1,145 @@
|
||||
/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/lightning_fabric/__init__.py:29: UserWarning: pkg_resources is deprecated as an API. See https://setuptools.pypa.io/en/latest/pkg_resources.html. The pkg_resources package is slated for removal as early as 2025-11-30. Refrain from using this package or pin to Setuptools<81.
|
||||
__import__("pkg_resources").declare_namespace(__name__)
|
||||
2026-02-09 18:39:50.119842: I tensorflow/core/util/port.cc:113] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
|
||||
2026-02-09 18:39:50.123128: I external/local_tsl/tsl/cuda/cudart_stub.cc:31] Could not find cuda drivers on your machine, GPU will not be used.
|
||||
2026-02-09 18:39:50.156652: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
|
||||
2026-02-09 18:39:50.156708: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
|
||||
2026-02-09 18:39:50.158926: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
|
||||
2026-02-09 18:39:50.167779: I external/local_tsl/tsl/cuda/cudart_stub.cc:31] Could not find cuda drivers on your machine, GPU will not be used.
|
||||
2026-02-09 18:39:50.168073: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
|
||||
To enable the following instructions: AVX2 AVX512F AVX512_VNNI AVX512_BF16 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
|
||||
2026-02-09 18:39:50.915144: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT
|
||||
[rank: 0] Global seed set to 123
|
||||
/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/kornia/feature/lightglue.py:44: FutureWarning: `torch.cuda.amp.custom_fwd(args...)` is deprecated. Please use `torch.amp.custom_fwd(args..., device_type='cuda')` instead.
|
||||
@torch.cuda.amp.custom_fwd(cast_inputs=torch.float32)
|
||||
INFO:mainlogger:LatentVisualDiffusion: Running in v-prediction mode
|
||||
INFO:unifolm_wma.models.diffusion_head.conditional_unet1d:number of parameters: 5.010531e+08
|
||||
INFO:unifolm_wma.models.diffusion_head.conditional_unet1d:number of parameters: 5.010531e+08
|
||||
AE working on z of shape (1, 4, 32, 32) = 4096 dimensions.
|
||||
INFO:root:Loaded ViT-H-14 model config.
|
||||
DEBUG:urllib3.connectionpool:Starting new HTTPS connection (1): hf-mirror.com:443
|
||||
DEBUG:urllib3.connectionpool:https://hf-mirror.com:443 "HEAD /laion/CLIP-ViT-H-14-laion2B-s32B-b79K/resolve/main/open_clip_pytorch_model.bin HTTP/1.1" 302 0
|
||||
INFO:root:Loading pretrained ViT-H-14 weights (laion2b_s32b_b79k).
|
||||
/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/open_clip/factory.py:88: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.
|
||||
checkpoint = torch.load(checkpoint_path, map_location=map_location)
|
||||
INFO:root:Loaded ViT-H-14 model config.
|
||||
DEBUG:urllib3.connectionpool:https://hf-mirror.com:443 "HEAD /laion/CLIP-ViT-H-14-laion2B-s32B-b79K/resolve/main/open_clip_pytorch_model.bin HTTP/1.1" 302 0
|
||||
INFO:root:Loading pretrained ViT-H-14 weights (laion2b_s32b_b79k).
|
||||
/mnt/ASC1637/unifolm-world-model-action/scripts/evaluation/world_model_interaction.py:198: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.
|
||||
state_dict = torch.load(ckpt, map_location="cpu")
|
||||
>>> model checkpoint loaded.
|
||||
>>> Load pre-trained model ...
|
||||
>>> Applying precision settings:
|
||||
- Diffusion dtype: bf16
|
||||
- Projector mode: bf16_full
|
||||
- Encoder mode: bf16_full
|
||||
- VAE dtype: bf16
|
||||
✓ Diffusion model weights converted to bfloat16
|
||||
✓ Projectors converted to bfloat16
|
||||
✓ Encoders converted to bfloat16
|
||||
✓ VAE converted to bfloat16
|
||||
⚠ Found 601 fp32 params, converting to bf16
|
||||
✓ All parameters converted to bfloat16
|
||||
✓ torch.compile: 3 ResBlocks in output_blocks[5, 8, 9]
|
||||
INFO:root:***** Configing Data *****
|
||||
>>> unitree_z1_stackbox: 1 data samples loaded.
|
||||
>>> unitree_z1_stackbox: data stats loaded.
|
||||
>>> unitree_z1_stackbox: normalizer initiated.
|
||||
>>> unitree_z1_dual_arm_stackbox: 1 data samples loaded.
|
||||
>>> unitree_z1_dual_arm_stackbox: data stats loaded.
|
||||
>>> unitree_z1_dual_arm_stackbox: normalizer initiated.
|
||||
>>> unitree_z1_dual_arm_stackbox_v2: 1 data samples loaded.
|
||||
>>> unitree_z1_dual_arm_stackbox_v2: data stats loaded.
|
||||
>>> unitree_z1_dual_arm_stackbox_v2: normalizer initiated.
|
||||
>>> unitree_z1_dual_arm_cleanup_pencils: 1 data samples loaded.
|
||||
>>> unitree_z1_dual_arm_cleanup_pencils: data stats loaded.
|
||||
>>> unitree_z1_dual_arm_cleanup_pencils: normalizer initiated.
|
||||
>>> unitree_g1_pack_camera: 1 data samples loaded.
|
||||
>>> unitree_g1_pack_camera: data stats loaded.
|
||||
>>> unitree_g1_pack_camera: normalizer initiated.
|
||||
>>> Dataset is successfully loaded ...
|
||||
>>> Generate 16 frames under each generation ...
|
||||
DEBUG:h5py._conv:Creating converter from 3 to 5
|
||||
DEBUG:PIL.PngImagePlugin:STREAM b'IHDR' 16 13
|
||||
DEBUG:PIL.PngImagePlugin:STREAM b'pHYs' 41 9
|
||||
DEBUG:PIL.PngImagePlugin:STREAM b'IDAT' 62 4096
|
||||
|
||||
0%| | 0/8 [00:00<?, ?it/s]/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/torch/nn/functional.py:5501: UserWarning: Attempting to use hipBLASLt on an unsupported architecture! Overriding blas backend to hipblas (Triggered internally at ../aten/src/ATen/Context.cpp:296.)
|
||||
proj = linear(q, w, b)
|
||||
>>> Step 0: generating actions ...
|
||||
>>> Step 0: interacting with world model ...
|
||||
DEBUG:PIL.Image:Importing BlpImagePlugin
|
||||
DEBUG:PIL.Image:Importing BmpImagePlugin
|
||||
DEBUG:PIL.Image:Importing BufrStubImagePlugin
|
||||
DEBUG:PIL.Image:Importing CurImagePlugin
|
||||
DEBUG:PIL.Image:Importing DcxImagePlugin
|
||||
DEBUG:PIL.Image:Importing DdsImagePlugin
|
||||
DEBUG:PIL.Image:Importing EpsImagePlugin
|
||||
DEBUG:PIL.Image:Importing FitsImagePlugin
|
||||
DEBUG:PIL.Image:Importing FitsStubImagePlugin
|
||||
DEBUG:PIL.Image:Importing FliImagePlugin
|
||||
DEBUG:PIL.Image:Importing FpxImagePlugin
|
||||
DEBUG:PIL.Image:Image: failed to import FpxImagePlugin: No module named 'olefile'
|
||||
DEBUG:PIL.Image:Importing FtexImagePlugin
|
||||
DEBUG:PIL.Image:Importing GbrImagePlugin
|
||||
DEBUG:PIL.Image:Importing GifImagePlugin
|
||||
DEBUG:PIL.Image:Importing GribStubImagePlugin
|
||||
DEBUG:PIL.Image:Importing Hdf5StubImagePlugin
|
||||
DEBUG:PIL.Image:Importing IcnsImagePlugin
|
||||
DEBUG:PIL.Image:Importing IcoImagePlugin
|
||||
DEBUG:PIL.Image:Importing ImImagePlugin
|
||||
DEBUG:PIL.Image:Importing ImtImagePlugin
|
||||
DEBUG:PIL.Image:Importing IptcImagePlugin
|
||||
DEBUG:PIL.Image:Importing JpegImagePlugin
|
||||
DEBUG:PIL.Image:Importing Jpeg2KImagePlugin
|
||||
DEBUG:PIL.Image:Importing McIdasImagePlugin
|
||||
DEBUG:PIL.Image:Importing MicImagePlugin
|
||||
DEBUG:PIL.Image:Image: failed to import MicImagePlugin: No module named 'olefile'
|
||||
DEBUG:PIL.Image:Importing MpegImagePlugin
|
||||
DEBUG:PIL.Image:Importing MpoImagePlugin
|
||||
DEBUG:PIL.Image:Importing MspImagePlugin
|
||||
DEBUG:PIL.Image:Importing PalmImagePlugin
|
||||
DEBUG:PIL.Image:Importing PcdImagePlugin
|
||||
DEBUG:PIL.Image:Importing PcxImagePlugin
|
||||
DEBUG:PIL.Image:Importing PdfImagePlugin
|
||||
DEBUG:PIL.Image:Importing PixarImagePlugin
|
||||
DEBUG:PIL.Image:Importing PngImagePlugin
|
||||
DEBUG:PIL.Image:Importing PpmImagePlugin
|
||||
DEBUG:PIL.Image:Importing PsdImagePlugin
|
||||
DEBUG:PIL.Image:Importing QoiImagePlugin
|
||||
DEBUG:PIL.Image:Importing SgiImagePlugin
|
||||
DEBUG:PIL.Image:Importing SpiderImagePlugin
|
||||
DEBUG:PIL.Image:Importing SunImagePlugin
|
||||
DEBUG:PIL.Image:Importing TgaImagePlugin
|
||||
DEBUG:PIL.Image:Importing TiffImagePlugin
|
||||
DEBUG:PIL.Image:Importing WebPImagePlugin
|
||||
DEBUG:PIL.Image:Importing WmfImagePlugin
|
||||
DEBUG:PIL.Image:Importing XbmImagePlugin
|
||||
DEBUG:PIL.Image:Importing XpmImagePlugin
|
||||
DEBUG:PIL.Image:Importing XVThumbImagePlugin
|
||||
|
||||
12%|█▎ | 1/8 [01:08<07:58, 68.38s/it]
|
||||
25%|██▌ | 2/8 [02:13<06:38, 66.48s/it]
|
||||
38%|███▊ | 3/8 [03:18<05:29, 65.83s/it]
|
||||
50%|█████ | 4/8 [04:23<04:22, 65.52s/it]
|
||||
62%|██████▎ | 5/8 [05:28<03:15, 65.33s/it]
|
||||
75%|███████▌ | 6/8 [06:33<02:10, 65.23s/it]
|
||||
88%|████████▊ | 7/8 [07:38<01:05, 65.12s/it]
|
||||
100%|██████████| 8/8 [08:43<00:00, 65.08s/it]
|
||||
100%|██████████| 8/8 [08:43<00:00, 65.44s/it]
|
||||
>>>>>>>>>>>>>>>>>>>>>>>>
|
||||
>>> Step 1: generating actions ...
|
||||
>>> Step 1: interacting with world model ...
|
||||
>>>>>>>>>>>>>>>>>>>>>>>>
|
||||
>>> Step 2: generating actions ...
|
||||
>>> Step 2: interacting with world model ...
|
||||
>>>>>>>>>>>>>>>>>>>>>>>>
|
||||
>>> Step 3: generating actions ...
|
||||
>>> Step 3: interacting with world model ...
|
||||
>>>>>>>>>>>>>>>>>>>>>>>>
|
||||
>>> Step 4: generating actions ...
|
||||
>>> Step 4: interacting with world model ...
|
||||
>>>>>>>>>>>>>>>>>>>>>>>>
|
||||
>>> Step 5: generating actions ...
|
||||
>>> Step 5: interacting with world model ...
|
||||
>>>>>>>>>>>>>>>>>>>>>>>>
|
||||
@@ -0,0 +1,5 @@
|
||||
{
|
||||
"gt_video": "unitree_z1_dual_arm_cleanup_pencils/case1/unitree_z1_dual_arm_cleanup_pencils_case1.mp4",
|
||||
"pred_video": "unitree_z1_dual_arm_cleanup_pencils/case1/output/inference/unitree_z1_dual_arm_cleanup_pencils_case1_amd.mp4",
|
||||
"psnr": 19.586376345676264
|
||||
}
|
||||
@@ -0,0 +1,5 @@
|
||||
{
|
||||
"gt_video": "/mnt/ASC1637/unifolm-world-model-action/unitree_z1_dual_arm_cleanup_pencils/case1/output/inference/unitree_z1_dual_arm_cleanup_pencils_case1_amd.mp4",
|
||||
"pred_video": "/mnt/ASC1637/unifolm-world-model-action/unitree_z1_dual_arm_cleanup_pencils/case1/output/inference/0_full_fs4.mp4",
|
||||
"psnr": 31.802224855380352
|
||||
}
|
||||
5
unitree_z1_dual_arm_cleanup_pencils/case1/run_profile.sh
Normal file
5
unitree_z1_dual_arm_cleanup_pencils/case1/run_profile.sh
Normal file
@@ -0,0 +1,5 @@
|
||||
#\!/bin/bash
|
||||
res_dir="unitree_z1_dual_arm_cleanup_pencils/case1"
|
||||
dataset="unitree_z1_dual_arm_cleanup_pencils"
|
||||
|
||||
TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL=1 CUDA_VISIBLE_DEVICES=0 python3 scripts/evaluation/profile_iteration.py --seed 123 --ckpt_path ckpts/unifolm_wma_dual_mix_bf16.ckpt --config configs/inference/world_model_interaction.yaml --savedir "${res_dir}/profile_output" --prompt_dir "${res_dir}/world_model_interaction_prompts" --dataset ${dataset} --bs 1 --height 320 --width 512 --unconditional_guidance_scale 1.0 --ddim_steps 50 --ddim_eta 1.0 --video_length 16 --frame_stride 4 --exe_steps 16 --n_iter 5 --warmup 1 --timestep_spacing uniform_trailing --guidance_rescale 0.7 --perframe_ae --vae_dtype bf16 --fast_policy_no_decode --csv "${res_dir}/profile_output/baseline.csv" 2>&1 | tee "${res_dir}/profile_output/profile.log"
|
||||
@@ -2,9 +2,9 @@ res_dir="unitree_z1_dual_arm_cleanup_pencils/case1"
|
||||
dataset="unitree_z1_dual_arm_cleanup_pencils"
|
||||
|
||||
{
|
||||
time CUDA_VISIBLE_DEVICES=0 python3 scripts/evaluation/world_model_interaction.py \
|
||||
time TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL=1 CUDA_VISIBLE_DEVICES=0 python3 scripts/evaluation/world_model_interaction.py \
|
||||
--seed 123 \
|
||||
--ckpt_path ckpts/unifolm_wma_dual.ckpt \
|
||||
--ckpt_path ckpts/unifolm_wma_dual_mixbf16.ckpt \
|
||||
--config configs/inference/world_model_interaction.yaml \
|
||||
--savedir "${res_dir}/output" \
|
||||
--bs 1 --height 320 --width 512 \
|
||||
@@ -20,5 +20,10 @@ dataset="unitree_z1_dual_arm_cleanup_pencils"
|
||||
--n_iter 8 \
|
||||
--timestep_spacing 'uniform_trailing' \
|
||||
--guidance_rescale 0.7 \
|
||||
--perframe_ae
|
||||
--perframe_ae \
|
||||
--diffusion_dtype fp32 \
|
||||
--projector_mode fp32 \
|
||||
--encoder_mode fp32 \
|
||||
--vae_dtype fp32 \
|
||||
--fast_policy_no_decode
|
||||
} 2>&1 | tee "${res_dir}/output.log"
|
||||
|
||||
137
unitree_z1_dual_arm_cleanup_pencils/case2/output.log
Normal file
137
unitree_z1_dual_arm_cleanup_pencils/case2/output.log
Normal file
@@ -0,0 +1,137 @@
|
||||
/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/lightning_fabric/__init__.py:29: UserWarning: pkg_resources is deprecated as an API. See https://setuptools.pypa.io/en/latest/pkg_resources.html. The pkg_resources package is slated for removal as early as 2025-11-30. Refrain from using this package or pin to Setuptools<81.
|
||||
__import__("pkg_resources").declare_namespace(__name__)
|
||||
2026-02-08 06:59:34.465946: I tensorflow/core/util/port.cc:113] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
|
||||
2026-02-08 06:59:34.469367: I external/local_tsl/tsl/cuda/cudart_stub.cc:31] Could not find cuda drivers on your machine, GPU will not be used.
|
||||
2026-02-08 06:59:34.500805: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
|
||||
2026-02-08 06:59:34.500837: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
|
||||
2026-02-08 06:59:34.502917: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
|
||||
2026-02-08 06:59:34.511434: I external/local_tsl/tsl/cuda/cudart_stub.cc:31] Could not find cuda drivers on your machine, GPU will not be used.
|
||||
2026-02-08 06:59:34.511678: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
|
||||
To enable the following instructions: AVX2 AVX512F AVX512_VNNI AVX512_BF16 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
|
||||
2026-02-08 06:59:35.478194: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT
|
||||
Global seed set to 123
|
||||
/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/kornia/feature/lightglue.py:44: FutureWarning: `torch.cuda.amp.custom_fwd(args...)` is deprecated. Please use `torch.amp.custom_fwd(args..., device_type='cuda')` instead.
|
||||
@torch.cuda.amp.custom_fwd(cast_inputs=torch.float32)
|
||||
INFO:mainlogger:LatentVisualDiffusion: Running in v-prediction mode
|
||||
INFO:unifolm_wma.models.diffusion_head.conditional_unet1d:number of parameters: 5.010531e+08
|
||||
INFO:unifolm_wma.models.diffusion_head.conditional_unet1d:number of parameters: 5.010531e+08
|
||||
AE working on z of shape (1, 4, 32, 32) = 4096 dimensions.
|
||||
INFO:root:Loaded ViT-H-14 model config.
|
||||
DEBUG:urllib3.connectionpool:Starting new HTTPS connection (1): hf-mirror.com:443
|
||||
DEBUG:urllib3.connectionpool:https://hf-mirror.com:443 "HEAD /laion/CLIP-ViT-H-14-laion2B-s32B-b79K/resolve/main/open_clip_pytorch_model.bin HTTP/1.1" 302 0
|
||||
INFO:root:Loading pretrained ViT-H-14 weights (laion2b_s32b_b79k).
|
||||
/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/open_clip/factory.py:88: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.
|
||||
checkpoint = torch.load(checkpoint_path, map_location=map_location)
|
||||
INFO:root:Loaded ViT-H-14 model config.
|
||||
DEBUG:urllib3.connectionpool:https://hf-mirror.com:443 "HEAD /laion/CLIP-ViT-H-14-laion2B-s32B-b79K/resolve/main/open_clip_pytorch_model.bin HTTP/1.1" 302 0
|
||||
INFO:root:Loading pretrained ViT-H-14 weights (laion2b_s32b_b79k).
|
||||
/mnt/ASC1637/unifolm-world-model-action/scripts/evaluation/world_model_interaction.py:86: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.
|
||||
state_dict = torch.load(ckpt, map_location="cpu")
|
||||
>>> model checkpoint loaded.
|
||||
>>> Load pre-trained model ...
|
||||
INFO:root:***** Configing Data *****
|
||||
>>> unitree_z1_stackbox: 1 data samples loaded.
|
||||
>>> unitree_z1_stackbox: data stats loaded.
|
||||
>>> unitree_z1_stackbox: normalizer initiated.
|
||||
>>> unitree_z1_dual_arm_stackbox: 1 data samples loaded.
|
||||
>>> unitree_z1_dual_arm_stackbox: data stats loaded.
|
||||
>>> unitree_z1_dual_arm_stackbox: normalizer initiated.
|
||||
>>> unitree_z1_dual_arm_stackbox_v2: 1 data samples loaded.
|
||||
>>> unitree_z1_dual_arm_stackbox_v2: data stats loaded.
|
||||
>>> unitree_z1_dual_arm_stackbox_v2: normalizer initiated.
|
||||
>>> unitree_z1_dual_arm_cleanup_pencils: 1 data samples loaded.
|
||||
>>> unitree_z1_dual_arm_cleanup_pencils: data stats loaded.
|
||||
>>> unitree_z1_dual_arm_cleanup_pencils: normalizer initiated.
|
||||
>>> unitree_g1_pack_camera: 1 data samples loaded.
|
||||
>>> unitree_g1_pack_camera: data stats loaded.
|
||||
>>> unitree_g1_pack_camera: normalizer initiated.
|
||||
>>> Dataset is successfully loaded ...
|
||||
>>> Generate 16 frames under each generation ...
|
||||
DEBUG:h5py._conv:Creating converter from 3 to 5
|
||||
DEBUG:PIL.PngImagePlugin:STREAM b'IHDR' 16 13
|
||||
DEBUG:PIL.PngImagePlugin:STREAM b'pHYs' 41 9
|
||||
DEBUG:PIL.PngImagePlugin:STREAM b'IDAT' 62 4096
|
||||
|
||||
0%| | 0/8 [00:00<?, ?it/s]/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/torch/nn/functional.py:5501: UserWarning: Attempting to use hipBLASLt on an unsupported architecture! Overriding blas backend to hipblas (Triggered internally at ../aten/src/ATen/Context.cpp:296.)
|
||||
proj = linear(q, w, b)
|
||||
/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/torch/nn/functional.py:6278: UserWarning: Flash attention support on Navi31 GPU is still experimental. Enable it with TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL=1. (Triggered internally at ../aten/src/ATen/native/transformers/hip/sdp_utils.cpp:225.)
|
||||
attn_output = scaled_dot_product_attention(
|
||||
/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/torch/nn/functional.py:6278: UserWarning: Memory Efficient attention on Navi31 GPU is still experimental. Enable it with TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL=1. (Triggered internally at ../aten/src/ATen/native/transformers/hip/sdp_utils.cpp:269.)
|
||||
attn_output = scaled_dot_product_attention(
|
||||
>>> Step 0: generating actions ...
|
||||
>>> Step 0: interacting with world model ...
|
||||
DEBUG:PIL.Image:Importing BlpImagePlugin
|
||||
DEBUG:PIL.Image:Importing BmpImagePlugin
|
||||
DEBUG:PIL.Image:Importing BufrStubImagePlugin
|
||||
DEBUG:PIL.Image:Importing CurImagePlugin
|
||||
DEBUG:PIL.Image:Importing DcxImagePlugin
|
||||
DEBUG:PIL.Image:Importing DdsImagePlugin
|
||||
DEBUG:PIL.Image:Importing EpsImagePlugin
|
||||
DEBUG:PIL.Image:Importing FitsImagePlugin
|
||||
DEBUG:PIL.Image:Importing FitsStubImagePlugin
|
||||
DEBUG:PIL.Image:Importing FliImagePlugin
|
||||
DEBUG:PIL.Image:Importing FpxImagePlugin
|
||||
DEBUG:PIL.Image:Image: failed to import FpxImagePlugin: No module named 'olefile'
|
||||
DEBUG:PIL.Image:Importing FtexImagePlugin
|
||||
DEBUG:PIL.Image:Importing GbrImagePlugin
|
||||
DEBUG:PIL.Image:Importing GifImagePlugin
|
||||
DEBUG:PIL.Image:Importing GribStubImagePlugin
|
||||
DEBUG:PIL.Image:Importing Hdf5StubImagePlugin
|
||||
DEBUG:PIL.Image:Importing IcnsImagePlugin
|
||||
DEBUG:PIL.Image:Importing IcoImagePlugin
|
||||
DEBUG:PIL.Image:Importing ImImagePlugin
|
||||
DEBUG:PIL.Image:Importing ImtImagePlugin
|
||||
DEBUG:PIL.Image:Importing IptcImagePlugin
|
||||
DEBUG:PIL.Image:Importing JpegImagePlugin
|
||||
DEBUG:PIL.Image:Importing Jpeg2KImagePlugin
|
||||
DEBUG:PIL.Image:Importing McIdasImagePlugin
|
||||
DEBUG:PIL.Image:Importing MicImagePlugin
|
||||
DEBUG:PIL.Image:Image: failed to import MicImagePlugin: No module named 'olefile'
|
||||
DEBUG:PIL.Image:Importing MpegImagePlugin
|
||||
DEBUG:PIL.Image:Importing MpoImagePlugin
|
||||
DEBUG:PIL.Image:Importing MspImagePlugin
|
||||
DEBUG:PIL.Image:Importing PalmImagePlugin
|
||||
DEBUG:PIL.Image:Importing PcdImagePlugin
|
||||
DEBUG:PIL.Image:Importing PcxImagePlugin
|
||||
DEBUG:PIL.Image:Importing PdfImagePlugin
|
||||
DEBUG:PIL.Image:Importing PixarImagePlugin
|
||||
DEBUG:PIL.Image:Importing PngImagePlugin
|
||||
DEBUG:PIL.Image:Importing PpmImagePlugin
|
||||
DEBUG:PIL.Image:Importing PsdImagePlugin
|
||||
DEBUG:PIL.Image:Importing QoiImagePlugin
|
||||
DEBUG:PIL.Image:Importing SgiImagePlugin
|
||||
DEBUG:PIL.Image:Importing SpiderImagePlugin
|
||||
DEBUG:PIL.Image:Importing SunImagePlugin
|
||||
DEBUG:PIL.Image:Importing TgaImagePlugin
|
||||
DEBUG:PIL.Image:Importing TiffImagePlugin
|
||||
DEBUG:PIL.Image:Importing WebPImagePlugin
|
||||
DEBUG:PIL.Image:Importing WmfImagePlugin
|
||||
DEBUG:PIL.Image:Importing XbmImagePlugin
|
||||
DEBUG:PIL.Image:Importing XpmImagePlugin
|
||||
DEBUG:PIL.Image:Importing XVThumbImagePlugin
|
||||
|
||||
12%|█▎ | 1/8 [01:37<11:23, 97.57s/it]
|
||||
25%|██▌ | 2/8 [03:14<09:44, 97.48s/it]
|
||||
38%|███▊ | 3/8 [04:52<08:07, 97.47s/it]
|
||||
50%|█████ | 4/8 [06:29<06:29, 97.49s/it]
|
||||
62%|██████▎ | 5/8 [08:07<04:52, 97.42s/it]
|
||||
75%|███████▌ | 6/8 [09:44<03:14, 97.32s/it]
|
||||
88%|████████▊ | 7/8 [11:21<01:37, 97.34s/it]
|
||||
100%|██████████| 8/8 [12:59<00:00, 97.36s/it]
|
||||
100%|██████████| 8/8 [12:59<00:00, 97.40s/it]
|
||||
>>>>>>>>>>>>>>>>>>>>>>>>
|
||||
>>> Step 1: generating actions ...
|
||||
>>> Step 1: interacting with world model ...
|
||||
>>>>>>>>>>>>>>>>>>>>>>>>
|
||||
>>> Step 2: generating actions ...
|
||||
>>> Step 2: interacting with world model ...
|
||||
>>>>>>>>>>>>>>>>>>>>>>>>
|
||||
>>> Step 3: generating actions ...
|
||||
>>> Step 3: interacting with world model ...
|
||||
>>>>>>>>>>>>>>>>>>>>>>>>
|
||||
>>> Step 4: generating actions ...
|
||||
>>> Step 4: interacting with world model ...
|
||||
>>>>>>>>>>>>>>>>>>>>>>>>
|
||||
>>> Step 5: generating actions ...
|
||||
>>> Step 5: interacting with world model ...
|
||||
>>>>>>>>>>>>>>>>>>>>>>>>
|
||||
@@ -0,0 +1,5 @@
|
||||
{
|
||||
"gt_video": "unitree_z1_dual_arm_cleanup_pencils/case2/unitree_z1_dual_arm_cleanup_pencils_case2.mp4",
|
||||
"pred_video": "unitree_z1_dual_arm_cleanup_pencils/case2/output/inference/unitree_z1_dual_arm_cleanup_pencils_case2_amd.mp4",
|
||||
"psnr": 20.484298972158296
|
||||
}
|
||||
137
unitree_z1_dual_arm_cleanup_pencils/case3/output.log
Normal file
137
unitree_z1_dual_arm_cleanup_pencils/case3/output.log
Normal file
@@ -0,0 +1,137 @@
|
||||
/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/lightning_fabric/__init__.py:29: UserWarning: pkg_resources is deprecated as an API. See https://setuptools.pypa.io/en/latest/pkg_resources.html. The pkg_resources package is slated for removal as early as 2025-11-30. Refrain from using this package or pin to Setuptools<81.
|
||||
__import__("pkg_resources").declare_namespace(__name__)
|
||||
2026-02-08 07:18:52.629976: I tensorflow/core/util/port.cc:113] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
|
||||
2026-02-08 07:18:52.633025: I external/local_tsl/tsl/cuda/cudart_stub.cc:31] Could not find cuda drivers on your machine, GPU will not be used.
|
||||
2026-02-08 07:18:52.663985: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
|
||||
2026-02-08 07:18:52.664018: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
|
||||
2026-02-08 07:18:52.665837: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
|
||||
2026-02-08 07:18:52.673889: I external/local_tsl/tsl/cuda/cudart_stub.cc:31] Could not find cuda drivers on your machine, GPU will not be used.
|
||||
2026-02-08 07:18:52.674218: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
|
||||
To enable the following instructions: AVX2 AVX512F AVX512_VNNI AVX512_BF16 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
|
||||
2026-02-08 07:18:53.298338: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT
|
||||
[rank: 0] Global seed set to 123
|
||||
/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/kornia/feature/lightglue.py:44: FutureWarning: `torch.cuda.amp.custom_fwd(args...)` is deprecated. Please use `torch.amp.custom_fwd(args..., device_type='cuda')` instead.
|
||||
@torch.cuda.amp.custom_fwd(cast_inputs=torch.float32)
|
||||
INFO:mainlogger:LatentVisualDiffusion: Running in v-prediction mode
|
||||
INFO:unifolm_wma.models.diffusion_head.conditional_unet1d:number of parameters: 5.010531e+08
|
||||
INFO:unifolm_wma.models.diffusion_head.conditional_unet1d:number of parameters: 5.010531e+08
|
||||
AE working on z of shape (1, 4, 32, 32) = 4096 dimensions.
|
||||
INFO:root:Loaded ViT-H-14 model config.
|
||||
DEBUG:urllib3.connectionpool:Starting new HTTPS connection (1): hf-mirror.com:443
|
||||
DEBUG:urllib3.connectionpool:https://hf-mirror.com:443 "HEAD /laion/CLIP-ViT-H-14-laion2B-s32B-b79K/resolve/main/open_clip_pytorch_model.bin HTTP/1.1" 302 0
|
||||
INFO:root:Loading pretrained ViT-H-14 weights (laion2b_s32b_b79k).
|
||||
/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/open_clip/factory.py:88: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.
|
||||
checkpoint = torch.load(checkpoint_path, map_location=map_location)
|
||||
INFO:root:Loaded ViT-H-14 model config.
|
||||
DEBUG:urllib3.connectionpool:https://hf-mirror.com:443 "HEAD /laion/CLIP-ViT-H-14-laion2B-s32B-b79K/resolve/main/open_clip_pytorch_model.bin HTTP/1.1" 302 0
|
||||
INFO:root:Loading pretrained ViT-H-14 weights (laion2b_s32b_b79k).
|
||||
/mnt/ASC1637/unifolm-world-model-action/scripts/evaluation/world_model_interaction.py:86: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.
|
||||
state_dict = torch.load(ckpt, map_location="cpu")
|
||||
>>> model checkpoint loaded.
|
||||
>>> Load pre-trained model ...
|
||||
INFO:root:***** Configing Data *****
|
||||
>>> unitree_z1_stackbox: 1 data samples loaded.
|
||||
>>> unitree_z1_stackbox: data stats loaded.
|
||||
>>> unitree_z1_stackbox: normalizer initiated.
|
||||
>>> unitree_z1_dual_arm_stackbox: 1 data samples loaded.
|
||||
>>> unitree_z1_dual_arm_stackbox: data stats loaded.
|
||||
>>> unitree_z1_dual_arm_stackbox: normalizer initiated.
|
||||
>>> unitree_z1_dual_arm_stackbox_v2: 1 data samples loaded.
|
||||
>>> unitree_z1_dual_arm_stackbox_v2: data stats loaded.
|
||||
>>> unitree_z1_dual_arm_stackbox_v2: normalizer initiated.
|
||||
>>> unitree_z1_dual_arm_cleanup_pencils: 1 data samples loaded.
|
||||
>>> unitree_z1_dual_arm_cleanup_pencils: data stats loaded.
|
||||
>>> unitree_z1_dual_arm_cleanup_pencils: normalizer initiated.
|
||||
>>> unitree_g1_pack_camera: 1 data samples loaded.
|
||||
>>> unitree_g1_pack_camera: data stats loaded.
|
||||
>>> unitree_g1_pack_camera: normalizer initiated.
|
||||
>>> Dataset is successfully loaded ...
|
||||
>>> Generate 16 frames under each generation ...
|
||||
DEBUG:h5py._conv:Creating converter from 3 to 5
|
||||
DEBUG:PIL.PngImagePlugin:STREAM b'IHDR' 16 13
|
||||
DEBUG:PIL.PngImagePlugin:STREAM b'pHYs' 41 9
|
||||
DEBUG:PIL.PngImagePlugin:STREAM b'IDAT' 62 4096
|
||||
|
||||
0%| | 0/8 [00:00<?, ?it/s]/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/torch/nn/functional.py:5501: UserWarning: Attempting to use hipBLASLt on an unsupported architecture! Overriding blas backend to hipblas (Triggered internally at ../aten/src/ATen/Context.cpp:296.)
|
||||
proj = linear(q, w, b)
|
||||
/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/torch/nn/functional.py:6278: UserWarning: Flash attention support on Navi31 GPU is still experimental. Enable it with TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL=1. (Triggered internally at ../aten/src/ATen/native/transformers/hip/sdp_utils.cpp:225.)
|
||||
attn_output = scaled_dot_product_attention(
|
||||
/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/torch/nn/functional.py:6278: UserWarning: Memory Efficient attention on Navi31 GPU is still experimental. Enable it with TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL=1. (Triggered internally at ../aten/src/ATen/native/transformers/hip/sdp_utils.cpp:269.)
|
||||
attn_output = scaled_dot_product_attention(
|
||||
>>> Step 0: generating actions ...
|
||||
>>> Step 0: interacting with world model ...
|
||||
DEBUG:PIL.Image:Importing BlpImagePlugin
|
||||
DEBUG:PIL.Image:Importing BmpImagePlugin
|
||||
DEBUG:PIL.Image:Importing BufrStubImagePlugin
|
||||
DEBUG:PIL.Image:Importing CurImagePlugin
|
||||
DEBUG:PIL.Image:Importing DcxImagePlugin
|
||||
DEBUG:PIL.Image:Importing DdsImagePlugin
|
||||
DEBUG:PIL.Image:Importing EpsImagePlugin
|
||||
DEBUG:PIL.Image:Importing FitsImagePlugin
|
||||
DEBUG:PIL.Image:Importing FitsStubImagePlugin
|
||||
DEBUG:PIL.Image:Importing FliImagePlugin
|
||||
DEBUG:PIL.Image:Importing FpxImagePlugin
|
||||
DEBUG:PIL.Image:Image: failed to import FpxImagePlugin: No module named 'olefile'
|
||||
DEBUG:PIL.Image:Importing FtexImagePlugin
|
||||
DEBUG:PIL.Image:Importing GbrImagePlugin
|
||||
DEBUG:PIL.Image:Importing GifImagePlugin
|
||||
DEBUG:PIL.Image:Importing GribStubImagePlugin
|
||||
DEBUG:PIL.Image:Importing Hdf5StubImagePlugin
|
||||
DEBUG:PIL.Image:Importing IcnsImagePlugin
|
||||
DEBUG:PIL.Image:Importing IcoImagePlugin
|
||||
DEBUG:PIL.Image:Importing ImImagePlugin
|
||||
DEBUG:PIL.Image:Importing ImtImagePlugin
|
||||
DEBUG:PIL.Image:Importing IptcImagePlugin
|
||||
DEBUG:PIL.Image:Importing JpegImagePlugin
|
||||
DEBUG:PIL.Image:Importing Jpeg2KImagePlugin
|
||||
DEBUG:PIL.Image:Importing McIdasImagePlugin
|
||||
DEBUG:PIL.Image:Importing MicImagePlugin
|
||||
DEBUG:PIL.Image:Image: failed to import MicImagePlugin: No module named 'olefile'
|
||||
DEBUG:PIL.Image:Importing MpegImagePlugin
|
||||
DEBUG:PIL.Image:Importing MpoImagePlugin
|
||||
DEBUG:PIL.Image:Importing MspImagePlugin
|
||||
DEBUG:PIL.Image:Importing PalmImagePlugin
|
||||
DEBUG:PIL.Image:Importing PcdImagePlugin
|
||||
DEBUG:PIL.Image:Importing PcxImagePlugin
|
||||
DEBUG:PIL.Image:Importing PdfImagePlugin
|
||||
DEBUG:PIL.Image:Importing PixarImagePlugin
|
||||
DEBUG:PIL.Image:Importing PngImagePlugin
|
||||
DEBUG:PIL.Image:Importing PpmImagePlugin
|
||||
DEBUG:PIL.Image:Importing PsdImagePlugin
|
||||
DEBUG:PIL.Image:Importing QoiImagePlugin
|
||||
DEBUG:PIL.Image:Importing SgiImagePlugin
|
||||
DEBUG:PIL.Image:Importing SpiderImagePlugin
|
||||
DEBUG:PIL.Image:Importing SunImagePlugin
|
||||
DEBUG:PIL.Image:Importing TgaImagePlugin
|
||||
DEBUG:PIL.Image:Importing TiffImagePlugin
|
||||
DEBUG:PIL.Image:Importing WebPImagePlugin
|
||||
DEBUG:PIL.Image:Importing WmfImagePlugin
|
||||
DEBUG:PIL.Image:Importing XbmImagePlugin
|
||||
DEBUG:PIL.Image:Importing XpmImagePlugin
|
||||
DEBUG:PIL.Image:Importing XVThumbImagePlugin
|
||||
|
||||
12%|█▎ | 1/8 [01:40<11:43, 100.54s/it]
|
||||
25%|██▌ | 2/8 [03:20<10:02, 100.36s/it]
|
||||
38%|███▊ | 3/8 [05:01<08:21, 100.32s/it]
|
||||
50%|█████ | 4/8 [06:41<06:41, 100.36s/it]
|
||||
62%|██████▎ | 5/8 [08:21<05:00, 100.30s/it]
|
||||
75%|███████▌ | 6/8 [10:01<03:20, 100.28s/it]
|
||||
88%|████████▊ | 7/8 [11:42<01:40, 100.34s/it]
|
||||
100%|██████████| 8/8 [13:22<00:00, 100.36s/it]
|
||||
100%|██████████| 8/8 [13:22<00:00, 100.34s/it]
|
||||
>>>>>>>>>>>>>>>>>>>>>>>>
|
||||
>>> Step 1: generating actions ...
|
||||
>>> Step 1: interacting with world model ...
|
||||
>>>>>>>>>>>>>>>>>>>>>>>>
|
||||
>>> Step 2: generating actions ...
|
||||
>>> Step 2: interacting with world model ...
|
||||
>>>>>>>>>>>>>>>>>>>>>>>>
|
||||
>>> Step 3: generating actions ...
|
||||
>>> Step 3: interacting with world model ...
|
||||
>>>>>>>>>>>>>>>>>>>>>>>>
|
||||
>>> Step 4: generating actions ...
|
||||
>>> Step 4: interacting with world model ...
|
||||
>>>>>>>>>>>>>>>>>>>>>>>>
|
||||
>>> Step 5: generating actions ...
|
||||
>>> Step 5: interacting with world model ...
|
||||
>>>>>>>>>>>>>>>>>>>>>>>>
|
||||
@@ -0,0 +1,5 @@
|
||||
{
|
||||
"gt_video": "unitree_z1_dual_arm_cleanup_pencils/case3/unitree_z1_dual_arm_cleanup_pencils_case3.mp4",
|
||||
"pred_video": "unitree_z1_dual_arm_cleanup_pencils/case3/output/inference/unitree_z1_dual_arm_cleanup_pencils_case3_amd.mp4",
|
||||
"psnr": 21.20205061239349
|
||||
}
|
||||
137
unitree_z1_dual_arm_cleanup_pencils/case4/output.log
Normal file
137
unitree_z1_dual_arm_cleanup_pencils/case4/output.log
Normal file
@@ -0,0 +1,137 @@
|
||||
/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/lightning_fabric/__init__.py:29: UserWarning: pkg_resources is deprecated as an API. See https://setuptools.pypa.io/en/latest/pkg_resources.html. The pkg_resources package is slated for removal as early as 2025-11-30. Refrain from using this package or pin to Setuptools<81.
|
||||
__import__("pkg_resources").declare_namespace(__name__)
|
||||
2026-02-08 07:22:15.333099: I tensorflow/core/util/port.cc:113] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
|
||||
2026-02-08 07:22:15.336215: I external/local_tsl/tsl/cuda/cudart_stub.cc:31] Could not find cuda drivers on your machine, GPU will not be used.
|
||||
2026-02-08 07:22:15.366489: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
|
||||
2026-02-08 07:22:15.366522: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
|
||||
2026-02-08 07:22:15.368294: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
|
||||
2026-02-08 07:22:15.376202: I external/local_tsl/tsl/cuda/cudart_stub.cc:31] Could not find cuda drivers on your machine, GPU will not be used.
|
||||
2026-02-08 07:22:15.376444: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
|
||||
To enable the following instructions: AVX2 AVX512F AVX512_VNNI AVX512_BF16 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
|
||||
2026-02-08 07:22:15.995383: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT
|
||||
[rank: 0] Global seed set to 123
|
||||
/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/kornia/feature/lightglue.py:44: FutureWarning: `torch.cuda.amp.custom_fwd(args...)` is deprecated. Please use `torch.amp.custom_fwd(args..., device_type='cuda')` instead.
|
||||
@torch.cuda.amp.custom_fwd(cast_inputs=torch.float32)
|
||||
INFO:mainlogger:LatentVisualDiffusion: Running in v-prediction mode
|
||||
INFO:unifolm_wma.models.diffusion_head.conditional_unet1d:number of parameters: 5.010531e+08
|
||||
INFO:unifolm_wma.models.diffusion_head.conditional_unet1d:number of parameters: 5.010531e+08
|
||||
AE working on z of shape (1, 4, 32, 32) = 4096 dimensions.
|
||||
INFO:root:Loaded ViT-H-14 model config.
|
||||
DEBUG:urllib3.connectionpool:Starting new HTTPS connection (1): hf-mirror.com:443
|
||||
DEBUG:urllib3.connectionpool:https://hf-mirror.com:443 "HEAD /laion/CLIP-ViT-H-14-laion2B-s32B-b79K/resolve/main/open_clip_pytorch_model.bin HTTP/1.1" 302 0
|
||||
INFO:root:Loading pretrained ViT-H-14 weights (laion2b_s32b_b79k).
|
||||
/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/open_clip/factory.py:88: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.
|
||||
checkpoint = torch.load(checkpoint_path, map_location=map_location)
|
||||
INFO:root:Loaded ViT-H-14 model config.
|
||||
DEBUG:urllib3.connectionpool:https://hf-mirror.com:443 "HEAD /laion/CLIP-ViT-H-14-laion2B-s32B-b79K/resolve/main/open_clip_pytorch_model.bin HTTP/1.1" 302 0
|
||||
INFO:root:Loading pretrained ViT-H-14 weights (laion2b_s32b_b79k).
|
||||
/mnt/ASC1637/unifolm-world-model-action/scripts/evaluation/world_model_interaction.py:86: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.
|
||||
state_dict = torch.load(ckpt, map_location="cpu")
|
||||
>>> model checkpoint loaded.
|
||||
>>> Load pre-trained model ...
|
||||
INFO:root:***** Configing Data *****
|
||||
>>> unitree_z1_stackbox: 1 data samples loaded.
|
||||
>>> unitree_z1_stackbox: data stats loaded.
|
||||
>>> unitree_z1_stackbox: normalizer initiated.
|
||||
>>> unitree_z1_dual_arm_stackbox: 1 data samples loaded.
|
||||
>>> unitree_z1_dual_arm_stackbox: data stats loaded.
|
||||
>>> unitree_z1_dual_arm_stackbox: normalizer initiated.
|
||||
>>> unitree_z1_dual_arm_stackbox_v2: 1 data samples loaded.
|
||||
>>> unitree_z1_dual_arm_stackbox_v2: data stats loaded.
|
||||
>>> unitree_z1_dual_arm_stackbox_v2: normalizer initiated.
|
||||
>>> unitree_z1_dual_arm_cleanup_pencils: 1 data samples loaded.
|
||||
>>> unitree_z1_dual_arm_cleanup_pencils: data stats loaded.
|
||||
>>> unitree_z1_dual_arm_cleanup_pencils: normalizer initiated.
|
||||
>>> unitree_g1_pack_camera: 1 data samples loaded.
|
||||
>>> unitree_g1_pack_camera: data stats loaded.
|
||||
>>> unitree_g1_pack_camera: normalizer initiated.
|
||||
>>> Dataset is successfully loaded ...
|
||||
>>> Generate 16 frames under each generation ...
|
||||
DEBUG:h5py._conv:Creating converter from 3 to 5
|
||||
DEBUG:PIL.PngImagePlugin:STREAM b'IHDR' 16 13
|
||||
DEBUG:PIL.PngImagePlugin:STREAM b'pHYs' 41 9
|
||||
DEBUG:PIL.PngImagePlugin:STREAM b'IDAT' 62 4096
|
||||
|
||||
0%| | 0/8 [00:00<?, ?it/s]/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/torch/nn/functional.py:5501: UserWarning: Attempting to use hipBLASLt on an unsupported architecture! Overriding blas backend to hipblas (Triggered internally at ../aten/src/ATen/Context.cpp:296.)
|
||||
proj = linear(q, w, b)
|
||||
/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/torch/nn/functional.py:6278: UserWarning: Flash attention support on Navi31 GPU is still experimental. Enable it with TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL=1. (Triggered internally at ../aten/src/ATen/native/transformers/hip/sdp_utils.cpp:225.)
|
||||
attn_output = scaled_dot_product_attention(
|
||||
/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/torch/nn/functional.py:6278: UserWarning: Memory Efficient attention on Navi31 GPU is still experimental. Enable it with TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL=1. (Triggered internally at ../aten/src/ATen/native/transformers/hip/sdp_utils.cpp:269.)
|
||||
attn_output = scaled_dot_product_attention(
|
||||
>>> Step 0: generating actions ...
|
||||
>>> Step 0: interacting with world model ...
|
||||
DEBUG:PIL.Image:Importing BlpImagePlugin
|
||||
DEBUG:PIL.Image:Importing BmpImagePlugin
|
||||
DEBUG:PIL.Image:Importing BufrStubImagePlugin
|
||||
DEBUG:PIL.Image:Importing CurImagePlugin
|
||||
DEBUG:PIL.Image:Importing DcxImagePlugin
|
||||
DEBUG:PIL.Image:Importing DdsImagePlugin
|
||||
DEBUG:PIL.Image:Importing EpsImagePlugin
|
||||
DEBUG:PIL.Image:Importing FitsImagePlugin
|
||||
DEBUG:PIL.Image:Importing FitsStubImagePlugin
|
||||
DEBUG:PIL.Image:Importing FliImagePlugin
|
||||
DEBUG:PIL.Image:Importing FpxImagePlugin
|
||||
DEBUG:PIL.Image:Image: failed to import FpxImagePlugin: No module named 'olefile'
|
||||
DEBUG:PIL.Image:Importing FtexImagePlugin
|
||||
DEBUG:PIL.Image:Importing GbrImagePlugin
|
||||
DEBUG:PIL.Image:Importing GifImagePlugin
|
||||
DEBUG:PIL.Image:Importing GribStubImagePlugin
|
||||
DEBUG:PIL.Image:Importing Hdf5StubImagePlugin
|
||||
DEBUG:PIL.Image:Importing IcnsImagePlugin
|
||||
DEBUG:PIL.Image:Importing IcoImagePlugin
|
||||
DEBUG:PIL.Image:Importing ImImagePlugin
|
||||
DEBUG:PIL.Image:Importing ImtImagePlugin
|
||||
DEBUG:PIL.Image:Importing IptcImagePlugin
|
||||
DEBUG:PIL.Image:Importing JpegImagePlugin
|
||||
DEBUG:PIL.Image:Importing Jpeg2KImagePlugin
|
||||
DEBUG:PIL.Image:Importing McIdasImagePlugin
|
||||
DEBUG:PIL.Image:Importing MicImagePlugin
|
||||
DEBUG:PIL.Image:Image: failed to import MicImagePlugin: No module named 'olefile'
|
||||
DEBUG:PIL.Image:Importing MpegImagePlugin
|
||||
DEBUG:PIL.Image:Importing MpoImagePlugin
|
||||
DEBUG:PIL.Image:Importing MspImagePlugin
|
||||
DEBUG:PIL.Image:Importing PalmImagePlugin
|
||||
DEBUG:PIL.Image:Importing PcdImagePlugin
|
||||
DEBUG:PIL.Image:Importing PcxImagePlugin
|
||||
DEBUG:PIL.Image:Importing PdfImagePlugin
|
||||
DEBUG:PIL.Image:Importing PixarImagePlugin
|
||||
DEBUG:PIL.Image:Importing PngImagePlugin
|
||||
DEBUG:PIL.Image:Importing PpmImagePlugin
|
||||
DEBUG:PIL.Image:Importing PsdImagePlugin
|
||||
DEBUG:PIL.Image:Importing QoiImagePlugin
|
||||
DEBUG:PIL.Image:Importing SgiImagePlugin
|
||||
DEBUG:PIL.Image:Importing SpiderImagePlugin
|
||||
DEBUG:PIL.Image:Importing SunImagePlugin
|
||||
DEBUG:PIL.Image:Importing TgaImagePlugin
|
||||
DEBUG:PIL.Image:Importing TiffImagePlugin
|
||||
DEBUG:PIL.Image:Importing WebPImagePlugin
|
||||
DEBUG:PIL.Image:Importing WmfImagePlugin
|
||||
DEBUG:PIL.Image:Importing XbmImagePlugin
|
||||
DEBUG:PIL.Image:Importing XpmImagePlugin
|
||||
DEBUG:PIL.Image:Importing XVThumbImagePlugin
|
||||
|
||||
12%|█▎ | 1/8 [01:37<11:23, 97.68s/it]
|
||||
25%|██▌ | 2/8 [03:15<09:47, 97.83s/it]
|
||||
38%|███▊ | 3/8 [04:53<08:09, 97.91s/it]
|
||||
50%|█████ | 4/8 [06:31<06:32, 98.03s/it]
|
||||
62%|██████▎ | 5/8 [08:10<04:54, 98.11s/it]
|
||||
75%|███████▌ | 6/8 [09:48<03:16, 98.18s/it]
|
||||
88%|████████▊ | 7/8 [11:26<01:38, 98.24s/it]
|
||||
100%|██████████| 8/8 [13:04<00:00, 98.16s/it]
|
||||
100%|██████████| 8/8 [13:04<00:00, 98.09s/it]
|
||||
>>>>>>>>>>>>>>>>>>>>>>>>
|
||||
>>> Step 1: generating actions ...
|
||||
>>> Step 1: interacting with world model ...
|
||||
>>>>>>>>>>>>>>>>>>>>>>>>
|
||||
>>> Step 2: generating actions ...
|
||||
>>> Step 2: interacting with world model ...
|
||||
>>>>>>>>>>>>>>>>>>>>>>>>
|
||||
>>> Step 3: generating actions ...
|
||||
>>> Step 3: interacting with world model ...
|
||||
>>>>>>>>>>>>>>>>>>>>>>>>
|
||||
>>> Step 4: generating actions ...
|
||||
>>> Step 4: interacting with world model ...
|
||||
>>>>>>>>>>>>>>>>>>>>>>>>
|
||||
>>> Step 5: generating actions ...
|
||||
>>> Step 5: interacting with world model ...
|
||||
>>>>>>>>>>>>>>>>>>>>>>>>
|
||||
@@ -0,0 +1,5 @@
|
||||
{
|
||||
"gt_video": "unitree_z1_dual_arm_cleanup_pencils/case4/unitree_z1_dual_arm_cleanup_pencils_case4.mp4",
|
||||
"pred_video": "unitree_z1_dual_arm_cleanup_pencils/case4/output/inference/unitree_z1_dual_arm_cleanup_pencils_case4_amd.mp4",
|
||||
"psnr": 21.130122583788612
|
||||
}
|
||||
134
unitree_z1_dual_arm_stackbox/case1/output.log
Normal file
134
unitree_z1_dual_arm_stackbox/case1/output.log
Normal file
@@ -0,0 +1,134 @@
|
||||
/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/lightning_fabric/__init__.py:29: UserWarning: pkg_resources is deprecated as an API. See https://setuptools.pypa.io/en/latest/pkg_resources.html. The pkg_resources package is slated for removal as early as 2025-11-30. Refrain from using this package or pin to Setuptools<81.
|
||||
__import__("pkg_resources").declare_namespace(__name__)
|
||||
2026-02-08 07:24:40.357099: I tensorflow/core/util/port.cc:113] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
|
||||
2026-02-08 07:24:40.360365: I external/local_tsl/tsl/cuda/cudart_stub.cc:31] Could not find cuda drivers on your machine, GPU will not be used.
|
||||
2026-02-08 07:24:40.391744: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
|
||||
2026-02-08 07:24:40.391772: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
|
||||
2026-02-08 07:24:40.393608: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
|
||||
2026-02-08 07:24:40.401837: I external/local_tsl/tsl/cuda/cudart_stub.cc:31] Could not find cuda drivers on your machine, GPU will not be used.
|
||||
2026-02-08 07:24:40.402077: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
|
||||
To enable the following instructions: AVX2 AVX512F AVX512_VNNI AVX512_BF16 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
|
||||
2026-02-08 07:24:41.022382: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT
|
||||
Global seed set to 123
|
||||
/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/kornia/feature/lightglue.py:44: FutureWarning: `torch.cuda.amp.custom_fwd(args...)` is deprecated. Please use `torch.amp.custom_fwd(args..., device_type='cuda')` instead.
|
||||
@torch.cuda.amp.custom_fwd(cast_inputs=torch.float32)
|
||||
INFO:mainlogger:LatentVisualDiffusion: Running in v-prediction mode
|
||||
INFO:unifolm_wma.models.diffusion_head.conditional_unet1d:number of parameters: 5.010531e+08
|
||||
INFO:unifolm_wma.models.diffusion_head.conditional_unet1d:number of parameters: 5.010531e+08
|
||||
AE working on z of shape (1, 4, 32, 32) = 4096 dimensions.
|
||||
INFO:root:Loaded ViT-H-14 model config.
|
||||
DEBUG:urllib3.connectionpool:Starting new HTTPS connection (1): hf-mirror.com:443
|
||||
DEBUG:urllib3.connectionpool:https://hf-mirror.com:443 "HEAD /laion/CLIP-ViT-H-14-laion2B-s32B-b79K/resolve/main/open_clip_pytorch_model.bin HTTP/1.1" 302 0
|
||||
INFO:root:Loading pretrained ViT-H-14 weights (laion2b_s32b_b79k).
|
||||
/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/open_clip/factory.py:88: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.
|
||||
checkpoint = torch.load(checkpoint_path, map_location=map_location)
|
||||
INFO:root:Loaded ViT-H-14 model config.
|
||||
DEBUG:urllib3.connectionpool:https://hf-mirror.com:443 "HEAD /laion/CLIP-ViT-H-14-laion2B-s32B-b79K/resolve/main/open_clip_pytorch_model.bin HTTP/1.1" 302 0
|
||||
INFO:root:Loading pretrained ViT-H-14 weights (laion2b_s32b_b79k).
|
||||
/mnt/ASC1637/unifolm-world-model-action/scripts/evaluation/world_model_interaction.py:86: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.
|
||||
state_dict = torch.load(ckpt, map_location="cpu")
|
||||
>>> model checkpoint loaded.
|
||||
>>> Load pre-trained model ...
|
||||
INFO:root:***** Configing Data *****
|
||||
>>> unitree_z1_stackbox: 1 data samples loaded.
|
||||
>>> unitree_z1_stackbox: data stats loaded.
|
||||
>>> unitree_z1_stackbox: normalizer initiated.
|
||||
>>> unitree_z1_dual_arm_stackbox: 1 data samples loaded.
|
||||
>>> unitree_z1_dual_arm_stackbox: data stats loaded.
|
||||
>>> unitree_z1_dual_arm_stackbox: normalizer initiated.
|
||||
>>> unitree_z1_dual_arm_stackbox_v2: 1 data samples loaded.
|
||||
>>> unitree_z1_dual_arm_stackbox_v2: data stats loaded.
|
||||
>>> unitree_z1_dual_arm_stackbox_v2: normalizer initiated.
|
||||
>>> unitree_z1_dual_arm_cleanup_pencils: 1 data samples loaded.
|
||||
>>> unitree_z1_dual_arm_cleanup_pencils: data stats loaded.
|
||||
>>> unitree_z1_dual_arm_cleanup_pencils: normalizer initiated.
|
||||
>>> unitree_g1_pack_camera: 1 data samples loaded.
|
||||
>>> unitree_g1_pack_camera: data stats loaded.
|
||||
>>> unitree_g1_pack_camera: normalizer initiated.
|
||||
>>> Dataset is successfully loaded ...
|
||||
>>> Generate 16 frames under each generation ...
|
||||
DEBUG:h5py._conv:Creating converter from 3 to 5
|
||||
DEBUG:PIL.PngImagePlugin:STREAM b'IHDR' 16 13
|
||||
DEBUG:PIL.PngImagePlugin:STREAM b'pHYs' 41 9
|
||||
DEBUG:PIL.PngImagePlugin:STREAM b'IDAT' 62 4096
|
||||
|
||||
0%| | 0/7 [00:00<?, ?it/s]/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/torch/nn/functional.py:5501: UserWarning: Attempting to use hipBLASLt on an unsupported architecture! Overriding blas backend to hipblas (Triggered internally at ../aten/src/ATen/Context.cpp:296.)
|
||||
proj = linear(q, w, b)
|
||||
/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/torch/nn/functional.py:6278: UserWarning: Flash attention support on Navi31 GPU is still experimental. Enable it with TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL=1. (Triggered internally at ../aten/src/ATen/native/transformers/hip/sdp_utils.cpp:225.)
|
||||
attn_output = scaled_dot_product_attention(
|
||||
/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/torch/nn/functional.py:6278: UserWarning: Memory Efficient attention on Navi31 GPU is still experimental. Enable it with TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL=1. (Triggered internally at ../aten/src/ATen/native/transformers/hip/sdp_utils.cpp:269.)
|
||||
attn_output = scaled_dot_product_attention(
|
||||
>>> Step 0: generating actions ...
|
||||
>>> Step 0: interacting with world model ...
|
||||
DEBUG:PIL.Image:Importing BlpImagePlugin
|
||||
DEBUG:PIL.Image:Importing BmpImagePlugin
|
||||
DEBUG:PIL.Image:Importing BufrStubImagePlugin
|
||||
DEBUG:PIL.Image:Importing CurImagePlugin
|
||||
DEBUG:PIL.Image:Importing DcxImagePlugin
|
||||
DEBUG:PIL.Image:Importing DdsImagePlugin
|
||||
DEBUG:PIL.Image:Importing EpsImagePlugin
|
||||
DEBUG:PIL.Image:Importing FitsImagePlugin
|
||||
DEBUG:PIL.Image:Importing FitsStubImagePlugin
|
||||
DEBUG:PIL.Image:Importing FliImagePlugin
|
||||
DEBUG:PIL.Image:Importing FpxImagePlugin
|
||||
DEBUG:PIL.Image:Image: failed to import FpxImagePlugin: No module named 'olefile'
|
||||
DEBUG:PIL.Image:Importing FtexImagePlugin
|
||||
DEBUG:PIL.Image:Importing GbrImagePlugin
|
||||
DEBUG:PIL.Image:Importing GifImagePlugin
|
||||
DEBUG:PIL.Image:Importing GribStubImagePlugin
|
||||
DEBUG:PIL.Image:Importing Hdf5StubImagePlugin
|
||||
DEBUG:PIL.Image:Importing IcnsImagePlugin
|
||||
DEBUG:PIL.Image:Importing IcoImagePlugin
|
||||
DEBUG:PIL.Image:Importing ImImagePlugin
|
||||
DEBUG:PIL.Image:Importing ImtImagePlugin
|
||||
DEBUG:PIL.Image:Importing IptcImagePlugin
|
||||
DEBUG:PIL.Image:Importing JpegImagePlugin
|
||||
DEBUG:PIL.Image:Importing Jpeg2KImagePlugin
|
||||
DEBUG:PIL.Image:Importing McIdasImagePlugin
|
||||
DEBUG:PIL.Image:Importing MicImagePlugin
|
||||
DEBUG:PIL.Image:Image: failed to import MicImagePlugin: No module named 'olefile'
|
||||
DEBUG:PIL.Image:Importing MpegImagePlugin
|
||||
DEBUG:PIL.Image:Importing MpoImagePlugin
|
||||
DEBUG:PIL.Image:Importing MspImagePlugin
|
||||
DEBUG:PIL.Image:Importing PalmImagePlugin
|
||||
DEBUG:PIL.Image:Importing PcdImagePlugin
|
||||
DEBUG:PIL.Image:Importing PcxImagePlugin
|
||||
DEBUG:PIL.Image:Importing PdfImagePlugin
|
||||
DEBUG:PIL.Image:Importing PixarImagePlugin
|
||||
DEBUG:PIL.Image:Importing PngImagePlugin
|
||||
DEBUG:PIL.Image:Importing PpmImagePlugin
|
||||
DEBUG:PIL.Image:Importing PsdImagePlugin
|
||||
DEBUG:PIL.Image:Importing QoiImagePlugin
|
||||
DEBUG:PIL.Image:Importing SgiImagePlugin
|
||||
DEBUG:PIL.Image:Importing SpiderImagePlugin
|
||||
DEBUG:PIL.Image:Importing SunImagePlugin
|
||||
DEBUG:PIL.Image:Importing TgaImagePlugin
|
||||
DEBUG:PIL.Image:Importing TiffImagePlugin
|
||||
DEBUG:PIL.Image:Importing WebPImagePlugin
|
||||
DEBUG:PIL.Image:Importing WmfImagePlugin
|
||||
DEBUG:PIL.Image:Importing XbmImagePlugin
|
||||
DEBUG:PIL.Image:Importing XpmImagePlugin
|
||||
DEBUG:PIL.Image:Importing XVThumbImagePlugin
|
||||
|
||||
14%|█▍ | 1/7 [01:41<10:09, 101.63s/it]
|
||||
29%|██▊ | 2/7 [03:20<08:18, 99.78s/it]
|
||||
43%|████▎ | 3/7 [04:58<06:36, 99.24s/it]
|
||||
57%|█████▋ | 4/7 [06:37<04:57, 99.05s/it]
|
||||
71%|███████▏ | 5/7 [08:16<03:17, 98.90s/it]
|
||||
86%|████████▌ | 6/7 [09:54<01:38, 98.80s/it]
|
||||
100%|██████████| 7/7 [11:33<00:00, 98.70s/it]
|
||||
100%|██████████| 7/7 [11:33<00:00, 99.03s/it]
|
||||
>>>>>>>>>>>>>>>>>>>>>>>>
|
||||
>>> Step 1: generating actions ...
|
||||
>>> Step 1: interacting with world model ...
|
||||
>>>>>>>>>>>>>>>>>>>>>>>>
|
||||
>>> Step 2: generating actions ...
|
||||
>>> Step 2: interacting with world model ...
|
||||
>>>>>>>>>>>>>>>>>>>>>>>>
|
||||
>>> Step 3: generating actions ...
|
||||
>>> Step 3: interacting with world model ...
|
||||
>>>>>>>>>>>>>>>>>>>>>>>>
|
||||
>>> Step 4: generating actions ...
|
||||
>>> Step 4: interacting with world model ...
|
||||
>>>>>>>>>>>>>>>>>>>>>>>>
|
||||
>>> Step 5: generating actions ...
|
||||
5
unitree_z1_dual_arm_stackbox/case1/psnr_result.json
Normal file
5
unitree_z1_dual_arm_stackbox/case1/psnr_result.json
Normal file
@@ -0,0 +1,5 @@
|
||||
{
|
||||
"gt_video": "unitree_z1_dual_arm_stackbox/case1/unitree_z1_dual_arm_stackbox_case1.mp4",
|
||||
"pred_video": "unitree_z1_dual_arm_stackbox/case1/output/inference/unitree_z1_dual_arm_stackbox_case1_amd.mp4",
|
||||
"psnr": 21.258130518117493
|
||||
}
|
||||
@@ -2,7 +2,7 @@ res_dir="unitree_z1_dual_arm_stackbox/case1"
|
||||
dataset="unitree_z1_dual_arm_stackbox"
|
||||
|
||||
{
|
||||
time CUDA_VISIBLE_DEVICES=0 python3 scripts/evaluation/world_model_interaction.py \
|
||||
time CUDA_VISIBLE_DEVICES=7 python3 scripts/evaluation/world_model_interaction.py \
|
||||
--seed 123 \
|
||||
--ckpt_path ckpts/unifolm_wma_dual.ckpt \
|
||||
--config configs/inference/world_model_interaction.yaml \
|
||||
|
||||
134
unitree_z1_dual_arm_stackbox/case2/output.log
Normal file
134
unitree_z1_dual_arm_stackbox/case2/output.log
Normal file
@@ -0,0 +1,134 @@
|
||||
/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/lightning_fabric/__init__.py:29: UserWarning: pkg_resources is deprecated as an API. See https://setuptools.pypa.io/en/latest/pkg_resources.html. The pkg_resources package is slated for removal as early as 2025-11-30. Refrain from using this package or pin to Setuptools<81.
|
||||
__import__("pkg_resources").declare_namespace(__name__)
|
||||
2026-02-08 07:25:18.653033: I tensorflow/core/util/port.cc:113] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
|
||||
2026-02-08 07:25:18.656060: I external/local_tsl/tsl/cuda/cudart_stub.cc:31] Could not find cuda drivers on your machine, GPU will not be used.
|
||||
2026-02-08 07:25:18.687077: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
|
||||
2026-02-08 07:25:18.687119: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
|
||||
2026-02-08 07:25:18.688915: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
|
||||
2026-02-08 07:25:18.697008: I external/local_tsl/tsl/cuda/cudart_stub.cc:31] Could not find cuda drivers on your machine, GPU will not be used.
|
||||
2026-02-08 07:25:18.697255: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
|
||||
To enable the following instructions: AVX2 AVX512F AVX512_VNNI AVX512_BF16 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
|
||||
2026-02-08 07:25:19.338303: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT
|
||||
Global seed set to 123
|
||||
/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/kornia/feature/lightglue.py:44: FutureWarning: `torch.cuda.amp.custom_fwd(args...)` is deprecated. Please use `torch.amp.custom_fwd(args..., device_type='cuda')` instead.
|
||||
@torch.cuda.amp.custom_fwd(cast_inputs=torch.float32)
|
||||
INFO:mainlogger:LatentVisualDiffusion: Running in v-prediction mode
|
||||
INFO:unifolm_wma.models.diffusion_head.conditional_unet1d:number of parameters: 5.010531e+08
|
||||
INFO:unifolm_wma.models.diffusion_head.conditional_unet1d:number of parameters: 5.010531e+08
|
||||
AE working on z of shape (1, 4, 32, 32) = 4096 dimensions.
|
||||
INFO:root:Loaded ViT-H-14 model config.
|
||||
DEBUG:urllib3.connectionpool:Starting new HTTPS connection (1): hf-mirror.com:443
|
||||
DEBUG:urllib3.connectionpool:https://hf-mirror.com:443 "HEAD /laion/CLIP-ViT-H-14-laion2B-s32B-b79K/resolve/main/open_clip_pytorch_model.bin HTTP/1.1" 302 0
|
||||
INFO:root:Loading pretrained ViT-H-14 weights (laion2b_s32b_b79k).
|
||||
/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/open_clip/factory.py:88: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.
|
||||
checkpoint = torch.load(checkpoint_path, map_location=map_location)
|
||||
INFO:root:Loaded ViT-H-14 model config.
|
||||
DEBUG:urllib3.connectionpool:https://hf-mirror.com:443 "HEAD /laion/CLIP-ViT-H-14-laion2B-s32B-b79K/resolve/main/open_clip_pytorch_model.bin HTTP/1.1" 302 0
|
||||
INFO:root:Loading pretrained ViT-H-14 weights (laion2b_s32b_b79k).
|
||||
/mnt/ASC1637/unifolm-world-model-action/scripts/evaluation/world_model_interaction.py:86: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.
|
||||
state_dict = torch.load(ckpt, map_location="cpu")
|
||||
>>> model checkpoint loaded.
|
||||
>>> Load pre-trained model ...
|
||||
INFO:root:***** Configing Data *****
|
||||
>>> unitree_z1_stackbox: 1 data samples loaded.
|
||||
>>> unitree_z1_stackbox: data stats loaded.
|
||||
>>> unitree_z1_stackbox: normalizer initiated.
|
||||
>>> unitree_z1_dual_arm_stackbox: 1 data samples loaded.
|
||||
>>> unitree_z1_dual_arm_stackbox: data stats loaded.
|
||||
>>> unitree_z1_dual_arm_stackbox: normalizer initiated.
|
||||
>>> unitree_z1_dual_arm_stackbox_v2: 1 data samples loaded.
|
||||
>>> unitree_z1_dual_arm_stackbox_v2: data stats loaded.
|
||||
>>> unitree_z1_dual_arm_stackbox_v2: normalizer initiated.
|
||||
>>> unitree_z1_dual_arm_cleanup_pencils: 1 data samples loaded.
|
||||
>>> unitree_z1_dual_arm_cleanup_pencils: data stats loaded.
|
||||
>>> unitree_z1_dual_arm_cleanup_pencils: normalizer initiated.
|
||||
>>> unitree_g1_pack_camera: 1 data samples loaded.
|
||||
>>> unitree_g1_pack_camera: data stats loaded.
|
||||
>>> unitree_g1_pack_camera: normalizer initiated.
|
||||
>>> Dataset is successfully loaded ...
|
||||
>>> Generate 16 frames under each generation ...
|
||||
DEBUG:h5py._conv:Creating converter from 3 to 5
|
||||
DEBUG:PIL.PngImagePlugin:STREAM b'IHDR' 16 13
|
||||
DEBUG:PIL.PngImagePlugin:STREAM b'pHYs' 41 9
|
||||
DEBUG:PIL.PngImagePlugin:STREAM b'IDAT' 62 4096
|
||||
|
||||
0%| | 0/7 [00:00<?, ?it/s]/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/torch/nn/functional.py:5501: UserWarning: Attempting to use hipBLASLt on an unsupported architecture! Overriding blas backend to hipblas (Triggered internally at ../aten/src/ATen/Context.cpp:296.)
|
||||
proj = linear(q, w, b)
|
||||
/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/torch/nn/functional.py:6278: UserWarning: Flash attention support on Navi31 GPU is still experimental. Enable it with TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL=1. (Triggered internally at ../aten/src/ATen/native/transformers/hip/sdp_utils.cpp:225.)
|
||||
attn_output = scaled_dot_product_attention(
|
||||
/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/torch/nn/functional.py:6278: UserWarning: Memory Efficient attention on Navi31 GPU is still experimental. Enable it with TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL=1. (Triggered internally at ../aten/src/ATen/native/transformers/hip/sdp_utils.cpp:269.)
|
||||
attn_output = scaled_dot_product_attention(
|
||||
>>> Step 0: generating actions ...
|
||||
>>> Step 0: interacting with world model ...
|
||||
DEBUG:PIL.Image:Importing BlpImagePlugin
|
||||
DEBUG:PIL.Image:Importing BmpImagePlugin
|
||||
DEBUG:PIL.Image:Importing BufrStubImagePlugin
|
||||
DEBUG:PIL.Image:Importing CurImagePlugin
|
||||
DEBUG:PIL.Image:Importing DcxImagePlugin
|
||||
DEBUG:PIL.Image:Importing DdsImagePlugin
|
||||
DEBUG:PIL.Image:Importing EpsImagePlugin
|
||||
DEBUG:PIL.Image:Importing FitsImagePlugin
|
||||
DEBUG:PIL.Image:Importing FitsStubImagePlugin
|
||||
DEBUG:PIL.Image:Importing FliImagePlugin
|
||||
DEBUG:PIL.Image:Importing FpxImagePlugin
|
||||
DEBUG:PIL.Image:Image: failed to import FpxImagePlugin: No module named 'olefile'
|
||||
DEBUG:PIL.Image:Importing FtexImagePlugin
|
||||
DEBUG:PIL.Image:Importing GbrImagePlugin
|
||||
DEBUG:PIL.Image:Importing GifImagePlugin
|
||||
DEBUG:PIL.Image:Importing GribStubImagePlugin
|
||||
DEBUG:PIL.Image:Importing Hdf5StubImagePlugin
|
||||
DEBUG:PIL.Image:Importing IcnsImagePlugin
|
||||
DEBUG:PIL.Image:Importing IcoImagePlugin
|
||||
DEBUG:PIL.Image:Importing ImImagePlugin
|
||||
DEBUG:PIL.Image:Importing ImtImagePlugin
|
||||
DEBUG:PIL.Image:Importing IptcImagePlugin
|
||||
DEBUG:PIL.Image:Importing JpegImagePlugin
|
||||
DEBUG:PIL.Image:Importing Jpeg2KImagePlugin
|
||||
DEBUG:PIL.Image:Importing McIdasImagePlugin
|
||||
DEBUG:PIL.Image:Importing MicImagePlugin
|
||||
DEBUG:PIL.Image:Image: failed to import MicImagePlugin: No module named 'olefile'
|
||||
DEBUG:PIL.Image:Importing MpegImagePlugin
|
||||
DEBUG:PIL.Image:Importing MpoImagePlugin
|
||||
DEBUG:PIL.Image:Importing MspImagePlugin
|
||||
DEBUG:PIL.Image:Importing PalmImagePlugin
|
||||
DEBUG:PIL.Image:Importing PcdImagePlugin
|
||||
DEBUG:PIL.Image:Importing PcxImagePlugin
|
||||
DEBUG:PIL.Image:Importing PdfImagePlugin
|
||||
DEBUG:PIL.Image:Importing PixarImagePlugin
|
||||
DEBUG:PIL.Image:Importing PngImagePlugin
|
||||
DEBUG:PIL.Image:Importing PpmImagePlugin
|
||||
DEBUG:PIL.Image:Importing PsdImagePlugin
|
||||
DEBUG:PIL.Image:Importing QoiImagePlugin
|
||||
DEBUG:PIL.Image:Importing SgiImagePlugin
|
||||
DEBUG:PIL.Image:Importing SpiderImagePlugin
|
||||
DEBUG:PIL.Image:Importing SunImagePlugin
|
||||
DEBUG:PIL.Image:Importing TgaImagePlugin
|
||||
DEBUG:PIL.Image:Importing TiffImagePlugin
|
||||
DEBUG:PIL.Image:Importing WebPImagePlugin
|
||||
DEBUG:PIL.Image:Importing WmfImagePlugin
|
||||
DEBUG:PIL.Image:Importing XbmImagePlugin
|
||||
DEBUG:PIL.Image:Importing XpmImagePlugin
|
||||
DEBUG:PIL.Image:Importing XVThumbImagePlugin
|
||||
|
||||
14%|█▍ | 1/7 [01:39<09:56, 99.35s/it]
|
||||
29%|██▊ | 2/7 [03:18<08:17, 99.50s/it]
|
||||
43%|████▎ | 3/7 [04:58<06:38, 99.54s/it]
|
||||
57%|█████▋ | 4/7 [06:38<04:58, 99.52s/it]
|
||||
71%|███████▏ | 5/7 [08:17<03:19, 99.55s/it]
|
||||
86%|████████▌ | 6/7 [09:57<01:39, 99.53s/it]
|
||||
100%|██████████| 7/7 [11:36<00:00, 99.50s/it]
|
||||
100%|██████████| 7/7 [11:36<00:00, 99.51s/it]
|
||||
>>>>>>>>>>>>>>>>>>>>>>>>
|
||||
>>> Step 1: generating actions ...
|
||||
>>> Step 1: interacting with world model ...
|
||||
>>>>>>>>>>>>>>>>>>>>>>>>
|
||||
>>> Step 2: generating actions ...
|
||||
>>> Step 2: interacting with world model ...
|
||||
>>>>>>>>>>>>>>>>>>>>>>>>
|
||||
>>> Step 3: generating actions ...
|
||||
>>> Step 3: interacting with world model ...
|
||||
>>>>>>>>>>>>>>>>>>>>>>>>
|
||||
>>> Step 4: generating actions ...
|
||||
>>> Step 4: interacting with world model ...
|
||||
>>>>>>>>>>>>>>>>>>>>>>>>
|
||||
>>> Step 5: generating actions ...
|
||||
5
unitree_z1_dual_arm_stackbox/case2/psnr_result.json
Normal file
5
unitree_z1_dual_arm_stackbox/case2/psnr_result.json
Normal file
@@ -0,0 +1,5 @@
|
||||
{
|
||||
"gt_video": "unitree_z1_dual_arm_stackbox/case2/unitree_z1_dual_arm_stackbox_case2.mp4",
|
||||
"pred_video": "unitree_z1_dual_arm_stackbox/case2/output/inference/unitree_z1_dual_arm_stackbox_case2_amd.mp4",
|
||||
"psnr": 23.878153424077645
|
||||
}
|
||||
@@ -2,7 +2,7 @@ res_dir="unitree_z1_dual_arm_stackbox/case2"
|
||||
dataset="unitree_z1_dual_arm_stackbox"
|
||||
|
||||
{
|
||||
time CUDA_VISIBLE_DEVICES=0 python3 scripts/evaluation/world_model_interaction.py \
|
||||
time CUDA_VISIBLE_DEVICES=6 python3 scripts/evaluation/world_model_interaction.py \
|
||||
--seed 123 \
|
||||
--ckpt_path ckpts/unifolm_wma_dual.ckpt \
|
||||
--config configs/inference/world_model_interaction.yaml \
|
||||
|
||||
134
unitree_z1_dual_arm_stackbox/case3/output.log
Normal file
134
unitree_z1_dual_arm_stackbox/case3/output.log
Normal file
@@ -0,0 +1,134 @@
|
||||
/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/lightning_fabric/__init__.py:29: UserWarning: pkg_resources is deprecated as an API. See https://setuptools.pypa.io/en/latest/pkg_resources.html. The pkg_resources package is slated for removal as early as 2025-11-30. Refrain from using this package or pin to Setuptools<81.
|
||||
__import__("pkg_resources").declare_namespace(__name__)
|
||||
2026-02-08 07:35:33.682231: I tensorflow/core/util/port.cc:113] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
|
||||
2026-02-08 07:35:33.685275: I external/local_tsl/tsl/cuda/cudart_stub.cc:31] Could not find cuda drivers on your machine, GPU will not be used.
|
||||
2026-02-08 07:35:33.716682: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
|
||||
2026-02-08 07:35:33.716728: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
|
||||
2026-02-08 07:35:33.718523: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
|
||||
2026-02-08 07:35:33.726756: I external/local_tsl/tsl/cuda/cudart_stub.cc:31] Could not find cuda drivers on your machine, GPU will not be used.
|
||||
2026-02-08 07:35:33.727105: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
|
||||
To enable the following instructions: AVX2 AVX512F AVX512_VNNI AVX512_BF16 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
|
||||
2026-02-08 07:35:34.356722: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT
|
||||
[rank: 0] Global seed set to 123
|
||||
/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/kornia/feature/lightglue.py:44: FutureWarning: `torch.cuda.amp.custom_fwd(args...)` is deprecated. Please use `torch.amp.custom_fwd(args..., device_type='cuda')` instead.
|
||||
@torch.cuda.amp.custom_fwd(cast_inputs=torch.float32)
|
||||
INFO:mainlogger:LatentVisualDiffusion: Running in v-prediction mode
|
||||
INFO:unifolm_wma.models.diffusion_head.conditional_unet1d:number of parameters: 5.010531e+08
|
||||
INFO:unifolm_wma.models.diffusion_head.conditional_unet1d:number of parameters: 5.010531e+08
|
||||
AE working on z of shape (1, 4, 32, 32) = 4096 dimensions.
|
||||
INFO:root:Loaded ViT-H-14 model config.
|
||||
DEBUG:urllib3.connectionpool:Starting new HTTPS connection (1): hf-mirror.com:443
|
||||
DEBUG:urllib3.connectionpool:https://hf-mirror.com:443 "HEAD /laion/CLIP-ViT-H-14-laion2B-s32B-b79K/resolve/main/open_clip_pytorch_model.bin HTTP/1.1" 302 0
|
||||
INFO:root:Loading pretrained ViT-H-14 weights (laion2b_s32b_b79k).
|
||||
/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/open_clip/factory.py:88: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.
|
||||
checkpoint = torch.load(checkpoint_path, map_location=map_location)
|
||||
INFO:root:Loaded ViT-H-14 model config.
|
||||
DEBUG:urllib3.connectionpool:https://hf-mirror.com:443 "HEAD /laion/CLIP-ViT-H-14-laion2B-s32B-b79K/resolve/main/open_clip_pytorch_model.bin HTTP/1.1" 302 0
|
||||
INFO:root:Loading pretrained ViT-H-14 weights (laion2b_s32b_b79k).
|
||||
/mnt/ASC1637/unifolm-world-model-action/scripts/evaluation/world_model_interaction.py:86: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.
|
||||
state_dict = torch.load(ckpt, map_location="cpu")
|
||||
>>> model checkpoint loaded.
|
||||
>>> Load pre-trained model ...
|
||||
INFO:root:***** Configing Data *****
|
||||
>>> unitree_z1_stackbox: 1 data samples loaded.
|
||||
>>> unitree_z1_stackbox: data stats loaded.
|
||||
>>> unitree_z1_stackbox: normalizer initiated.
|
||||
>>> unitree_z1_dual_arm_stackbox: 1 data samples loaded.
|
||||
>>> unitree_z1_dual_arm_stackbox: data stats loaded.
|
||||
>>> unitree_z1_dual_arm_stackbox: normalizer initiated.
|
||||
>>> unitree_z1_dual_arm_stackbox_v2: 1 data samples loaded.
|
||||
>>> unitree_z1_dual_arm_stackbox_v2: data stats loaded.
|
||||
>>> unitree_z1_dual_arm_stackbox_v2: normalizer initiated.
|
||||
>>> unitree_z1_dual_arm_cleanup_pencils: 1 data samples loaded.
|
||||
>>> unitree_z1_dual_arm_cleanup_pencils: data stats loaded.
|
||||
>>> unitree_z1_dual_arm_cleanup_pencils: normalizer initiated.
|
||||
>>> unitree_g1_pack_camera: 1 data samples loaded.
|
||||
>>> unitree_g1_pack_camera: data stats loaded.
|
||||
>>> unitree_g1_pack_camera: normalizer initiated.
|
||||
>>> Dataset is successfully loaded ...
|
||||
>>> Generate 16 frames under each generation ...
|
||||
DEBUG:h5py._conv:Creating converter from 3 to 5
|
||||
DEBUG:PIL.PngImagePlugin:STREAM b'IHDR' 16 13
|
||||
DEBUG:PIL.PngImagePlugin:STREAM b'pHYs' 41 9
|
||||
DEBUG:PIL.PngImagePlugin:STREAM b'IDAT' 62 4096
|
||||
|
||||
0%| | 0/7 [00:00<?, ?it/s]/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/torch/nn/functional.py:5501: UserWarning: Attempting to use hipBLASLt on an unsupported architecture! Overriding blas backend to hipblas (Triggered internally at ../aten/src/ATen/Context.cpp:296.)
|
||||
proj = linear(q, w, b)
|
||||
/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/torch/nn/functional.py:6278: UserWarning: Flash attention support on Navi31 GPU is still experimental. Enable it with TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL=1. (Triggered internally at ../aten/src/ATen/native/transformers/hip/sdp_utils.cpp:225.)
|
||||
attn_output = scaled_dot_product_attention(
|
||||
/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/torch/nn/functional.py:6278: UserWarning: Memory Efficient attention on Navi31 GPU is still experimental. Enable it with TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL=1. (Triggered internally at ../aten/src/ATen/native/transformers/hip/sdp_utils.cpp:269.)
|
||||
attn_output = scaled_dot_product_attention(
|
||||
>>> Step 0: generating actions ...
|
||||
>>> Step 0: interacting with world model ...
|
||||
DEBUG:PIL.Image:Importing BlpImagePlugin
|
||||
DEBUG:PIL.Image:Importing BmpImagePlugin
|
||||
DEBUG:PIL.Image:Importing BufrStubImagePlugin
|
||||
DEBUG:PIL.Image:Importing CurImagePlugin
|
||||
DEBUG:PIL.Image:Importing DcxImagePlugin
|
||||
DEBUG:PIL.Image:Importing DdsImagePlugin
|
||||
DEBUG:PIL.Image:Importing EpsImagePlugin
|
||||
DEBUG:PIL.Image:Importing FitsImagePlugin
|
||||
DEBUG:PIL.Image:Importing FitsStubImagePlugin
|
||||
DEBUG:PIL.Image:Importing FliImagePlugin
|
||||
DEBUG:PIL.Image:Importing FpxImagePlugin
|
||||
DEBUG:PIL.Image:Image: failed to import FpxImagePlugin: No module named 'olefile'
|
||||
DEBUG:PIL.Image:Importing FtexImagePlugin
|
||||
DEBUG:PIL.Image:Importing GbrImagePlugin
|
||||
DEBUG:PIL.Image:Importing GifImagePlugin
|
||||
DEBUG:PIL.Image:Importing GribStubImagePlugin
|
||||
DEBUG:PIL.Image:Importing Hdf5StubImagePlugin
|
||||
DEBUG:PIL.Image:Importing IcnsImagePlugin
|
||||
DEBUG:PIL.Image:Importing IcoImagePlugin
|
||||
DEBUG:PIL.Image:Importing ImImagePlugin
|
||||
DEBUG:PIL.Image:Importing ImtImagePlugin
|
||||
DEBUG:PIL.Image:Importing IptcImagePlugin
|
||||
DEBUG:PIL.Image:Importing JpegImagePlugin
|
||||
DEBUG:PIL.Image:Importing Jpeg2KImagePlugin
|
||||
DEBUG:PIL.Image:Importing McIdasImagePlugin
|
||||
DEBUG:PIL.Image:Importing MicImagePlugin
|
||||
DEBUG:PIL.Image:Image: failed to import MicImagePlugin: No module named 'olefile'
|
||||
DEBUG:PIL.Image:Importing MpegImagePlugin
|
||||
DEBUG:PIL.Image:Importing MpoImagePlugin
|
||||
DEBUG:PIL.Image:Importing MspImagePlugin
|
||||
DEBUG:PIL.Image:Importing PalmImagePlugin
|
||||
DEBUG:PIL.Image:Importing PcdImagePlugin
|
||||
DEBUG:PIL.Image:Importing PcxImagePlugin
|
||||
DEBUG:PIL.Image:Importing PdfImagePlugin
|
||||
DEBUG:PIL.Image:Importing PixarImagePlugin
|
||||
DEBUG:PIL.Image:Importing PngImagePlugin
|
||||
DEBUG:PIL.Image:Importing PpmImagePlugin
|
||||
DEBUG:PIL.Image:Importing PsdImagePlugin
|
||||
DEBUG:PIL.Image:Importing QoiImagePlugin
|
||||
DEBUG:PIL.Image:Importing SgiImagePlugin
|
||||
DEBUG:PIL.Image:Importing SpiderImagePlugin
|
||||
DEBUG:PIL.Image:Importing SunImagePlugin
|
||||
DEBUG:PIL.Image:Importing TgaImagePlugin
|
||||
DEBUG:PIL.Image:Importing TiffImagePlugin
|
||||
DEBUG:PIL.Image:Importing WebPImagePlugin
|
||||
DEBUG:PIL.Image:Importing WmfImagePlugin
|
||||
DEBUG:PIL.Image:Importing XbmImagePlugin
|
||||
DEBUG:PIL.Image:Importing XpmImagePlugin
|
||||
DEBUG:PIL.Image:Importing XVThumbImagePlugin
|
||||
|
||||
14%|█▍ | 1/7 [01:41<10:06, 101.02s/it]
|
||||
29%|██▊ | 2/7 [03:23<08:29, 101.84s/it]
|
||||
43%|████▎ | 3/7 [05:04<06:45, 101.43s/it]
|
||||
57%|█████▋ | 4/7 [06:45<05:04, 101.42s/it]
|
||||
71%|███████▏ | 5/7 [08:27<03:22, 101.40s/it]
|
||||
86%|████████▌ | 6/7 [10:08<01:41, 101.39s/it]
|
||||
100%|██████████| 7/7 [11:49<00:00, 101.33s/it]
|
||||
100%|██████████| 7/7 [11:49<00:00, 101.39s/it]
|
||||
>>>>>>>>>>>>>>>>>>>>>>>>
|
||||
>>> Step 1: generating actions ...
|
||||
>>> Step 1: interacting with world model ...
|
||||
>>>>>>>>>>>>>>>>>>>>>>>>
|
||||
>>> Step 2: generating actions ...
|
||||
>>> Step 2: interacting with world model ...
|
||||
>>>>>>>>>>>>>>>>>>>>>>>>
|
||||
>>> Step 3: generating actions ...
|
||||
>>> Step 3: interacting with world model ...
|
||||
>>>>>>>>>>>>>>>>>>>>>>>>
|
||||
>>> Step 4: generating actions ...
|
||||
>>> Step 4: interacting with world model ...
|
||||
>>>>>>>>>>>>>>>>>>>>>>>>
|
||||
>>> Step 5: generating actions ...
|
||||
5
unitree_z1_dual_arm_stackbox/case3/psnr_result.json
Normal file
5
unitree_z1_dual_arm_stackbox/case3/psnr_result.json
Normal file
@@ -0,0 +1,5 @@
|
||||
{
|
||||
"gt_video": "unitree_z1_dual_arm_stackbox/case3/unitree_z1_dual_arm_stackbox_case3.mp4",
|
||||
"pred_video": "unitree_z1_dual_arm_stackbox/case3/output/inference/unitree_z1_dual_arm_stackbox_case3_amd.mp4",
|
||||
"psnr": 25.400458754751128
|
||||
}
|
||||
134
unitree_z1_dual_arm_stackbox/case4/output.log
Normal file
134
unitree_z1_dual_arm_stackbox/case4/output.log
Normal file
@@ -0,0 +1,134 @@
|
||||
/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/lightning_fabric/__init__.py:29: UserWarning: pkg_resources is deprecated as an API. See https://setuptools.pypa.io/en/latest/pkg_resources.html. The pkg_resources package is slated for removal as early as 2025-11-30. Refrain from using this package or pin to Setuptools<81.
|
||||
__import__("pkg_resources").declare_namespace(__name__)
|
||||
2026-02-08 07:38:45.572744: I tensorflow/core/util/port.cc:113] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
|
||||
2026-02-08 07:38:45.576864: I external/local_tsl/tsl/cuda/cudart_stub.cc:31] Could not find cuda drivers on your machine, GPU will not be used.
|
||||
2026-02-08 07:38:45.624825: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
|
||||
2026-02-08 07:38:45.624883: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
|
||||
2026-02-08 07:38:45.627150: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
|
||||
2026-02-08 07:38:45.638316: I external/local_tsl/tsl/cuda/cudart_stub.cc:31] Could not find cuda drivers on your machine, GPU will not be used.
|
||||
2026-02-08 07:38:45.638803: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
|
||||
To enable the following instructions: AVX2 AVX512F AVX512_VNNI AVX512_BF16 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
|
||||
2026-02-08 07:38:46.426363: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT
|
||||
[rank: 0] Global seed set to 123
|
||||
/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/kornia/feature/lightglue.py:44: FutureWarning: `torch.cuda.amp.custom_fwd(args...)` is deprecated. Please use `torch.amp.custom_fwd(args..., device_type='cuda')` instead.
|
||||
@torch.cuda.amp.custom_fwd(cast_inputs=torch.float32)
|
||||
INFO:mainlogger:LatentVisualDiffusion: Running in v-prediction mode
|
||||
INFO:unifolm_wma.models.diffusion_head.conditional_unet1d:number of parameters: 5.010531e+08
|
||||
INFO:unifolm_wma.models.diffusion_head.conditional_unet1d:number of parameters: 5.010531e+08
|
||||
AE working on z of shape (1, 4, 32, 32) = 4096 dimensions.
|
||||
INFO:root:Loaded ViT-H-14 model config.
|
||||
DEBUG:urllib3.connectionpool:Starting new HTTPS connection (1): hf-mirror.com:443
|
||||
DEBUG:urllib3.connectionpool:https://hf-mirror.com:443 "HEAD /laion/CLIP-ViT-H-14-laion2B-s32B-b79K/resolve/main/open_clip_pytorch_model.bin HTTP/1.1" 302 0
|
||||
INFO:root:Loading pretrained ViT-H-14 weights (laion2b_s32b_b79k).
|
||||
/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/open_clip/factory.py:88: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.
|
||||
checkpoint = torch.load(checkpoint_path, map_location=map_location)
|
||||
INFO:root:Loaded ViT-H-14 model config.
|
||||
DEBUG:urllib3.connectionpool:https://hf-mirror.com:443 "HEAD /laion/CLIP-ViT-H-14-laion2B-s32B-b79K/resolve/main/open_clip_pytorch_model.bin HTTP/1.1" 302 0
|
||||
INFO:root:Loading pretrained ViT-H-14 weights (laion2b_s32b_b79k).
|
||||
/mnt/ASC1637/unifolm-world-model-action/scripts/evaluation/world_model_interaction.py:86: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.
|
||||
state_dict = torch.load(ckpt, map_location="cpu")
|
||||
>>> model checkpoint loaded.
|
||||
>>> Load pre-trained model ...
|
||||
INFO:root:***** Configing Data *****
|
||||
>>> unitree_z1_stackbox: 1 data samples loaded.
|
||||
>>> unitree_z1_stackbox: data stats loaded.
|
||||
>>> unitree_z1_stackbox: normalizer initiated.
|
||||
>>> unitree_z1_dual_arm_stackbox: 1 data samples loaded.
|
||||
>>> unitree_z1_dual_arm_stackbox: data stats loaded.
|
||||
>>> unitree_z1_dual_arm_stackbox: normalizer initiated.
|
||||
>>> unitree_z1_dual_arm_stackbox_v2: 1 data samples loaded.
|
||||
>>> unitree_z1_dual_arm_stackbox_v2: data stats loaded.
|
||||
>>> unitree_z1_dual_arm_stackbox_v2: normalizer initiated.
|
||||
>>> unitree_z1_dual_arm_cleanup_pencils: 1 data samples loaded.
|
||||
>>> unitree_z1_dual_arm_cleanup_pencils: data stats loaded.
|
||||
>>> unitree_z1_dual_arm_cleanup_pencils: normalizer initiated.
|
||||
>>> unitree_g1_pack_camera: 1 data samples loaded.
|
||||
>>> unitree_g1_pack_camera: data stats loaded.
|
||||
>>> unitree_g1_pack_camera: normalizer initiated.
|
||||
>>> Dataset is successfully loaded ...
|
||||
>>> Generate 16 frames under each generation ...
|
||||
DEBUG:h5py._conv:Creating converter from 3 to 5
|
||||
DEBUG:PIL.PngImagePlugin:STREAM b'IHDR' 16 13
|
||||
DEBUG:PIL.PngImagePlugin:STREAM b'pHYs' 41 9
|
||||
DEBUG:PIL.PngImagePlugin:STREAM b'IDAT' 62 4096
|
||||
|
||||
0%| | 0/7 [00:00<?, ?it/s]/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/torch/nn/functional.py:5501: UserWarning: Attempting to use hipBLASLt on an unsupported architecture! Overriding blas backend to hipblas (Triggered internally at ../aten/src/ATen/Context.cpp:296.)
|
||||
proj = linear(q, w, b)
|
||||
/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/torch/nn/functional.py:6278: UserWarning: Flash attention support on Navi31 GPU is still experimental. Enable it with TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL=1. (Triggered internally at ../aten/src/ATen/native/transformers/hip/sdp_utils.cpp:225.)
|
||||
attn_output = scaled_dot_product_attention(
|
||||
/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/torch/nn/functional.py:6278: UserWarning: Memory Efficient attention on Navi31 GPU is still experimental. Enable it with TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL=1. (Triggered internally at ../aten/src/ATen/native/transformers/hip/sdp_utils.cpp:269.)
|
||||
attn_output = scaled_dot_product_attention(
|
||||
>>> Step 0: generating actions ...
|
||||
>>> Step 0: interacting with world model ...
|
||||
DEBUG:PIL.Image:Importing BlpImagePlugin
|
||||
DEBUG:PIL.Image:Importing BmpImagePlugin
|
||||
DEBUG:PIL.Image:Importing BufrStubImagePlugin
|
||||
DEBUG:PIL.Image:Importing CurImagePlugin
|
||||
DEBUG:PIL.Image:Importing DcxImagePlugin
|
||||
DEBUG:PIL.Image:Importing DdsImagePlugin
|
||||
DEBUG:PIL.Image:Importing EpsImagePlugin
|
||||
DEBUG:PIL.Image:Importing FitsImagePlugin
|
||||
DEBUG:PIL.Image:Importing FitsStubImagePlugin
|
||||
DEBUG:PIL.Image:Importing FliImagePlugin
|
||||
DEBUG:PIL.Image:Importing FpxImagePlugin
|
||||
DEBUG:PIL.Image:Image: failed to import FpxImagePlugin: No module named 'olefile'
|
||||
DEBUG:PIL.Image:Importing FtexImagePlugin
|
||||
DEBUG:PIL.Image:Importing GbrImagePlugin
|
||||
DEBUG:PIL.Image:Importing GifImagePlugin
|
||||
DEBUG:PIL.Image:Importing GribStubImagePlugin
|
||||
DEBUG:PIL.Image:Importing Hdf5StubImagePlugin
|
||||
DEBUG:PIL.Image:Importing IcnsImagePlugin
|
||||
DEBUG:PIL.Image:Importing IcoImagePlugin
|
||||
DEBUG:PIL.Image:Importing ImImagePlugin
|
||||
DEBUG:PIL.Image:Importing ImtImagePlugin
|
||||
DEBUG:PIL.Image:Importing IptcImagePlugin
|
||||
DEBUG:PIL.Image:Importing JpegImagePlugin
|
||||
DEBUG:PIL.Image:Importing Jpeg2KImagePlugin
|
||||
DEBUG:PIL.Image:Importing McIdasImagePlugin
|
||||
DEBUG:PIL.Image:Importing MicImagePlugin
|
||||
DEBUG:PIL.Image:Image: failed to import MicImagePlugin: No module named 'olefile'
|
||||
DEBUG:PIL.Image:Importing MpegImagePlugin
|
||||
DEBUG:PIL.Image:Importing MpoImagePlugin
|
||||
DEBUG:PIL.Image:Importing MspImagePlugin
|
||||
DEBUG:PIL.Image:Importing PalmImagePlugin
|
||||
DEBUG:PIL.Image:Importing PcdImagePlugin
|
||||
DEBUG:PIL.Image:Importing PcxImagePlugin
|
||||
DEBUG:PIL.Image:Importing PdfImagePlugin
|
||||
DEBUG:PIL.Image:Importing PixarImagePlugin
|
||||
DEBUG:PIL.Image:Importing PngImagePlugin
|
||||
DEBUG:PIL.Image:Importing PpmImagePlugin
|
||||
DEBUG:PIL.Image:Importing PsdImagePlugin
|
||||
DEBUG:PIL.Image:Importing QoiImagePlugin
|
||||
DEBUG:PIL.Image:Importing SgiImagePlugin
|
||||
DEBUG:PIL.Image:Importing SpiderImagePlugin
|
||||
DEBUG:PIL.Image:Importing SunImagePlugin
|
||||
DEBUG:PIL.Image:Importing TgaImagePlugin
|
||||
DEBUG:PIL.Image:Importing TiffImagePlugin
|
||||
DEBUG:PIL.Image:Importing WebPImagePlugin
|
||||
DEBUG:PIL.Image:Importing WmfImagePlugin
|
||||
DEBUG:PIL.Image:Importing XbmImagePlugin
|
||||
DEBUG:PIL.Image:Importing XpmImagePlugin
|
||||
DEBUG:PIL.Image:Importing XVThumbImagePlugin
|
||||
|
||||
14%|█▍ | 1/7 [01:38<09:52, 98.73s/it]
|
||||
29%|██▊ | 2/7 [03:17<08:14, 98.85s/it]
|
||||
43%|████▎ | 3/7 [04:56<06:35, 98.80s/it]
|
||||
57%|█████▋ | 4/7 [06:35<04:56, 98.94s/it]
|
||||
71%|███████▏ | 5/7 [08:14<03:17, 98.93s/it]
|
||||
86%|████████▌ | 6/7 [09:53<01:38, 98.89s/it]
|
||||
100%|██████████| 7/7 [11:31<00:00, 98.81s/it]
|
||||
100%|██████████| 7/7 [11:31<00:00, 98.85s/it]
|
||||
>>>>>>>>>>>>>>>>>>>>>>>>
|
||||
>>> Step 1: generating actions ...
|
||||
>>> Step 1: interacting with world model ...
|
||||
>>>>>>>>>>>>>>>>>>>>>>>>
|
||||
>>> Step 2: generating actions ...
|
||||
>>> Step 2: interacting with world model ...
|
||||
>>>>>>>>>>>>>>>>>>>>>>>>
|
||||
>>> Step 3: generating actions ...
|
||||
>>> Step 3: interacting with world model ...
|
||||
>>>>>>>>>>>>>>>>>>>>>>>>
|
||||
>>> Step 4: generating actions ...
|
||||
>>> Step 4: interacting with world model ...
|
||||
>>>>>>>>>>>>>>>>>>>>>>>>
|
||||
>>> Step 5: generating actions ...
|
||||
5
unitree_z1_dual_arm_stackbox/case4/psnr_result.json
Normal file
5
unitree_z1_dual_arm_stackbox/case4/psnr_result.json
Normal file
@@ -0,0 +1,5 @@
|
||||
{
|
||||
"gt_video": "unitree_z1_dual_arm_stackbox/case4/unitree_z1_dual_arm_stackbox_case4.mp4",
|
||||
"pred_video": "unitree_z1_dual_arm_stackbox/case4/output/inference/unitree_z1_dual_arm_stackbox_case4_amd.mp4",
|
||||
"psnr": 24.098958457373858
|
||||
}
|
||||
146
unitree_z1_dual_arm_stackbox_v2/case1/output.log
Normal file
146
unitree_z1_dual_arm_stackbox_v2/case1/output.log
Normal file
@@ -0,0 +1,146 @@
|
||||
/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/lightning_fabric/__init__.py:29: UserWarning: pkg_resources is deprecated as an API. See https://setuptools.pypa.io/en/latest/pkg_resources.html. The pkg_resources package is slated for removal as early as 2025-11-30. Refrain from using this package or pin to Setuptools<81.
|
||||
__import__("pkg_resources").declare_namespace(__name__)
|
||||
2026-02-08 07:51:23.961486: I tensorflow/core/util/port.cc:113] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
|
||||
2026-02-08 07:51:24.200063: I external/local_tsl/tsl/cuda/cudart_stub.cc:31] Could not find cuda drivers on your machine, GPU will not be used.
|
||||
2026-02-08 07:51:24.522299: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
|
||||
2026-02-08 07:51:24.522350: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
|
||||
2026-02-08 07:51:24.528237: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
|
||||
2026-02-08 07:51:24.579400: I external/local_tsl/tsl/cuda/cudart_stub.cc:31] Could not find cuda drivers on your machine, GPU will not be used.
|
||||
2026-02-08 07:51:24.579644: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
|
||||
To enable the following instructions: AVX2 AVX512F AVX512_VNNI AVX512_BF16 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
|
||||
2026-02-08 07:51:25.781311: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT
|
||||
Global seed set to 123
|
||||
/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/kornia/feature/lightglue.py:44: FutureWarning: `torch.cuda.amp.custom_fwd(args...)` is deprecated. Please use `torch.amp.custom_fwd(args..., device_type='cuda')` instead.
|
||||
@torch.cuda.amp.custom_fwd(cast_inputs=torch.float32)
|
||||
INFO:mainlogger:LatentVisualDiffusion: Running in v-prediction mode
|
||||
INFO:unifolm_wma.models.diffusion_head.conditional_unet1d:number of parameters: 5.010531e+08
|
||||
INFO:unifolm_wma.models.diffusion_head.conditional_unet1d:number of parameters: 5.010531e+08
|
||||
AE working on z of shape (1, 4, 32, 32) = 4096 dimensions.
|
||||
INFO:root:Loaded ViT-H-14 model config.
|
||||
DEBUG:urllib3.connectionpool:Starting new HTTPS connection (1): hf-mirror.com:443
|
||||
DEBUG:urllib3.connectionpool:https://hf-mirror.com:443 "HEAD /laion/CLIP-ViT-H-14-laion2B-s32B-b79K/resolve/main/open_clip_pytorch_model.bin HTTP/1.1" 302 0
|
||||
INFO:root:Loading pretrained ViT-H-14 weights (laion2b_s32b_b79k).
|
||||
/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/open_clip/factory.py:88: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.
|
||||
checkpoint = torch.load(checkpoint_path, map_location=map_location)
|
||||
INFO:root:Loaded ViT-H-14 model config.
|
||||
DEBUG:urllib3.connectionpool:https://hf-mirror.com:443 "HEAD /laion/CLIP-ViT-H-14-laion2B-s32B-b79K/resolve/main/open_clip_pytorch_model.bin HTTP/1.1" 302 0
|
||||
INFO:root:Loading pretrained ViT-H-14 weights (laion2b_s32b_b79k).
|
||||
/mnt/ASC1637/unifolm-world-model-action/scripts/evaluation/world_model_interaction.py:86: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.
|
||||
state_dict = torch.load(ckpt, map_location="cpu")
|
||||
>>> model checkpoint loaded.
|
||||
>>> Load pre-trained model ...
|
||||
INFO:root:***** Configing Data *****
|
||||
>>> unitree_z1_stackbox: 1 data samples loaded.
|
||||
>>> unitree_z1_stackbox: data stats loaded.
|
||||
>>> unitree_z1_stackbox: normalizer initiated.
|
||||
>>> unitree_z1_dual_arm_stackbox: 1 data samples loaded.
|
||||
>>> unitree_z1_dual_arm_stackbox: data stats loaded.
|
||||
>>> unitree_z1_dual_arm_stackbox: normalizer initiated.
|
||||
>>> unitree_z1_dual_arm_stackbox_v2: 1 data samples loaded.
|
||||
>>> unitree_z1_dual_arm_stackbox_v2: data stats loaded.
|
||||
>>> unitree_z1_dual_arm_stackbox_v2: normalizer initiated.
|
||||
>>> unitree_z1_dual_arm_cleanup_pencils: 1 data samples loaded.
|
||||
>>> unitree_z1_dual_arm_cleanup_pencils: data stats loaded.
|
||||
>>> unitree_z1_dual_arm_cleanup_pencils: normalizer initiated.
|
||||
>>> unitree_g1_pack_camera: 1 data samples loaded.
|
||||
>>> unitree_g1_pack_camera: data stats loaded.
|
||||
>>> unitree_g1_pack_camera: normalizer initiated.
|
||||
>>> Dataset is successfully loaded ...
|
||||
>>> Generate 16 frames under each generation ...
|
||||
DEBUG:h5py._conv:Creating converter from 3 to 5
|
||||
DEBUG:PIL.PngImagePlugin:STREAM b'IHDR' 16 13
|
||||
DEBUG:PIL.PngImagePlugin:STREAM b'pHYs' 41 9
|
||||
DEBUG:PIL.PngImagePlugin:STREAM b'IDAT' 62 4096
|
||||
|
||||
0%| | 0/11 [00:00<?, ?it/s]/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/torch/nn/functional.py:5501: UserWarning: Attempting to use hipBLASLt on an unsupported architecture! Overriding blas backend to hipblas (Triggered internally at ../aten/src/ATen/Context.cpp:296.)
|
||||
proj = linear(q, w, b)
|
||||
/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/torch/nn/functional.py:6278: UserWarning: Flash attention support on Navi31 GPU is still experimental. Enable it with TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL=1. (Triggered internally at ../aten/src/ATen/native/transformers/hip/sdp_utils.cpp:225.)
|
||||
attn_output = scaled_dot_product_attention(
|
||||
/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/torch/nn/functional.py:6278: UserWarning: Memory Efficient attention on Navi31 GPU is still experimental. Enable it with TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL=1. (Triggered internally at ../aten/src/ATen/native/transformers/hip/sdp_utils.cpp:269.)
|
||||
attn_output = scaled_dot_product_attention(
|
||||
>>> Step 0: generating actions ...
|
||||
>>> Step 0: interacting with world model ...
|
||||
DEBUG:PIL.Image:Importing BlpImagePlugin
|
||||
DEBUG:PIL.Image:Importing BmpImagePlugin
|
||||
DEBUG:PIL.Image:Importing BufrStubImagePlugin
|
||||
DEBUG:PIL.Image:Importing CurImagePlugin
|
||||
DEBUG:PIL.Image:Importing DcxImagePlugin
|
||||
DEBUG:PIL.Image:Importing DdsImagePlugin
|
||||
DEBUG:PIL.Image:Importing EpsImagePlugin
|
||||
DEBUG:PIL.Image:Importing FitsImagePlugin
|
||||
DEBUG:PIL.Image:Importing FitsStubImagePlugin
|
||||
DEBUG:PIL.Image:Importing FliImagePlugin
|
||||
DEBUG:PIL.Image:Importing FpxImagePlugin
|
||||
DEBUG:PIL.Image:Image: failed to import FpxImagePlugin: No module named 'olefile'
|
||||
DEBUG:PIL.Image:Importing FtexImagePlugin
|
||||
DEBUG:PIL.Image:Importing GbrImagePlugin
|
||||
DEBUG:PIL.Image:Importing GifImagePlugin
|
||||
DEBUG:PIL.Image:Importing GribStubImagePlugin
|
||||
DEBUG:PIL.Image:Importing Hdf5StubImagePlugin
|
||||
DEBUG:PIL.Image:Importing IcnsImagePlugin
|
||||
DEBUG:PIL.Image:Importing IcoImagePlugin
|
||||
DEBUG:PIL.Image:Importing ImImagePlugin
|
||||
DEBUG:PIL.Image:Importing ImtImagePlugin
|
||||
DEBUG:PIL.Image:Importing IptcImagePlugin
|
||||
DEBUG:PIL.Image:Importing JpegImagePlugin
|
||||
DEBUG:PIL.Image:Importing Jpeg2KImagePlugin
|
||||
DEBUG:PIL.Image:Importing McIdasImagePlugin
|
||||
DEBUG:PIL.Image:Importing MicImagePlugin
|
||||
DEBUG:PIL.Image:Image: failed to import MicImagePlugin: No module named 'olefile'
|
||||
DEBUG:PIL.Image:Importing MpegImagePlugin
|
||||
DEBUG:PIL.Image:Importing MpoImagePlugin
|
||||
DEBUG:PIL.Image:Importing MspImagePlugin
|
||||
DEBUG:PIL.Image:Importing PalmImagePlugin
|
||||
DEBUG:PIL.Image:Importing PcdImagePlugin
|
||||
DEBUG:PIL.Image:Importing PcxImagePlugin
|
||||
DEBUG:PIL.Image:Importing PdfImagePlugin
|
||||
DEBUG:PIL.Image:Importing PixarImagePlugin
|
||||
DEBUG:PIL.Image:Importing PngImagePlugin
|
||||
DEBUG:PIL.Image:Importing PpmImagePlugin
|
||||
DEBUG:PIL.Image:Importing PsdImagePlugin
|
||||
DEBUG:PIL.Image:Importing QoiImagePlugin
|
||||
DEBUG:PIL.Image:Importing SgiImagePlugin
|
||||
DEBUG:PIL.Image:Importing SpiderImagePlugin
|
||||
DEBUG:PIL.Image:Importing SunImagePlugin
|
||||
DEBUG:PIL.Image:Importing TgaImagePlugin
|
||||
DEBUG:PIL.Image:Importing TiffImagePlugin
|
||||
DEBUG:PIL.Image:Importing WebPImagePlugin
|
||||
DEBUG:PIL.Image:Importing WmfImagePlugin
|
||||
DEBUG:PIL.Image:Importing XbmImagePlugin
|
||||
DEBUG:PIL.Image:Importing XpmImagePlugin
|
||||
DEBUG:PIL.Image:Importing XVThumbImagePlugin
|
||||
|
||||
9%|▉ | 1/11 [01:38<16:20, 98.04s/it]
|
||||
18%|█▊ | 2/11 [03:15<14:40, 97.81s/it]
|
||||
27%|██▋ | 3/11 [04:53<13:01, 97.72s/it]
|
||||
36%|███▋ | 4/11 [06:31<11:24, 97.71s/it]
|
||||
45%|████▌ | 5/11 [08:08<09:46, 97.71s/it]
|
||||
55%|█████▍ | 6/11 [09:46<08:08, 97.65s/it]
|
||||
64%|██████▎ | 7/11 [11:23<06:30, 97.65s/it]
|
||||
73%|███████▎ | 8/11 [13:02<04:54, 98.09s/it]
|
||||
82%|████████▏ | 9/11 [14:40<03:15, 97.83s/it]
|
||||
91%|█████████ | 10/11 [16:17<01:37, 97.73s/it]
|
||||
100%|██████████| 11/11 [17:55<00:00, 97.64s/it]
|
||||
100%|██████████| 11/11 [17:55<00:00, 97.74s/it]
|
||||
>>>>>>>>>>>>>>>>>>>>>>>>
|
||||
>>> Step 1: generating actions ...
|
||||
>>> Step 1: interacting with world model ...
|
||||
>>>>>>>>>>>>>>>>>>>>>>>>
|
||||
>>> Step 2: generating actions ...
|
||||
>>> Step 2: interacting with world model ...
|
||||
>>>>>>>>>>>>>>>>>>>>>>>>
|
||||
>>> Step 3: generating actions ...
|
||||
>>> Step 3: interacting with world model ...
|
||||
>>>>>>>>>>>>>>>>>>>>>>>>
|
||||
>>> Step 4: generating actions ...
|
||||
>>> Step 4: interacting with world model ...
|
||||
>>>>>>>>>>>>>>>>>>>>>>>>
|
||||
>>> Step 5: generating actions ...
|
||||
>>> Step 5: interacting with world model ...
|
||||
>>>>>>>>>>>>>>>>>>>>>>>>
|
||||
>>> Step 6: generating actions ...
|
||||
>>> Step 6: interacting with world model ...
|
||||
>>>>>>>>>>>>>>>>>>>>>>>>
|
||||
>>> Step 7: generating actions ...
|
||||
>>> Step 7: interacting with world model ...
|
||||
>>>>>>>>>>>>>>>>>>>>>>>>
|
||||
5
unitree_z1_dual_arm_stackbox_v2/case1/psnr_result.json
Normal file
5
unitree_z1_dual_arm_stackbox_v2/case1/psnr_result.json
Normal file
@@ -0,0 +1,5 @@
|
||||
{
|
||||
"gt_video": "unitree_z1_dual_arm_stackbox_v2/case1/unitree_z1_dual_arm_stackbox_v2_case1.mp4",
|
||||
"pred_video": "unitree_z1_dual_arm_stackbox_v2/case1/output/inference/unitree_z1_dual_arm_stackbox_v2_case1_amd.mp4",
|
||||
"psnr": 18.126776535969576
|
||||
}
|
||||
@@ -2,7 +2,7 @@ res_dir="unitree_z1_dual_arm_stackbox_v2/case1"
|
||||
dataset="unitree_z1_dual_arm_stackbox_v2"
|
||||
|
||||
{
|
||||
time CUDA_VISIBLE_DEVICES=0 python3 scripts/evaluation/world_model_interaction.py \
|
||||
time CUDA_VISIBLE_DEVICES=7 python3 scripts/evaluation/world_model_interaction.py \
|
||||
--seed 123 \
|
||||
--ckpt_path ckpts/unifolm_wma_dual.ckpt \
|
||||
--config configs/inference/world_model_interaction.yaml \
|
||||
|
||||
146
unitree_z1_dual_arm_stackbox_v2/case2/output.log
Normal file
146
unitree_z1_dual_arm_stackbox_v2/case2/output.log
Normal file
@@ -0,0 +1,146 @@
|
||||
/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/lightning_fabric/__init__.py:29: UserWarning: pkg_resources is deprecated as an API. See https://setuptools.pypa.io/en/latest/pkg_resources.html. The pkg_resources package is slated for removal as early as 2025-11-30. Refrain from using this package or pin to Setuptools<81.
|
||||
__import__("pkg_resources").declare_namespace(__name__)
|
||||
2026-02-08 07:56:31.144789: I tensorflow/core/util/port.cc:113] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
|
||||
2026-02-08 07:56:31.148256: I external/local_tsl/tsl/cuda/cudart_stub.cc:31] Could not find cuda drivers on your machine, GPU will not be used.
|
||||
2026-02-08 07:56:31.178870: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
|
||||
2026-02-08 07:56:31.178898: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
|
||||
2026-02-08 07:56:31.180683: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
|
||||
2026-02-08 07:56:31.188800: I external/local_tsl/tsl/cuda/cudart_stub.cc:31] Could not find cuda drivers on your machine, GPU will not be used.
|
||||
2026-02-08 07:56:31.189142: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
|
||||
To enable the following instructions: AVX2 AVX512F AVX512_VNNI AVX512_BF16 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
|
||||
2026-02-08 07:56:31.810098: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT
|
||||
[rank: 0] Global seed set to 123
|
||||
/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/kornia/feature/lightglue.py:44: FutureWarning: `torch.cuda.amp.custom_fwd(args...)` is deprecated. Please use `torch.amp.custom_fwd(args..., device_type='cuda')` instead.
|
||||
@torch.cuda.amp.custom_fwd(cast_inputs=torch.float32)
|
||||
INFO:mainlogger:LatentVisualDiffusion: Running in v-prediction mode
|
||||
INFO:unifolm_wma.models.diffusion_head.conditional_unet1d:number of parameters: 5.010531e+08
|
||||
INFO:unifolm_wma.models.diffusion_head.conditional_unet1d:number of parameters: 5.010531e+08
|
||||
AE working on z of shape (1, 4, 32, 32) = 4096 dimensions.
|
||||
INFO:root:Loaded ViT-H-14 model config.
|
||||
DEBUG:urllib3.connectionpool:Starting new HTTPS connection (1): hf-mirror.com:443
|
||||
DEBUG:urllib3.connectionpool:https://hf-mirror.com:443 "HEAD /laion/CLIP-ViT-H-14-laion2B-s32B-b79K/resolve/main/open_clip_pytorch_model.bin HTTP/1.1" 302 0
|
||||
INFO:root:Loading pretrained ViT-H-14 weights (laion2b_s32b_b79k).
|
||||
/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/open_clip/factory.py:88: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.
|
||||
checkpoint = torch.load(checkpoint_path, map_location=map_location)
|
||||
INFO:root:Loaded ViT-H-14 model config.
|
||||
DEBUG:urllib3.connectionpool:https://hf-mirror.com:443 "HEAD /laion/CLIP-ViT-H-14-laion2B-s32B-b79K/resolve/main/open_clip_pytorch_model.bin HTTP/1.1" 302 0
|
||||
INFO:root:Loading pretrained ViT-H-14 weights (laion2b_s32b_b79k).
|
||||
/mnt/ASC1637/unifolm-world-model-action/scripts/evaluation/world_model_interaction.py:86: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.
|
||||
state_dict = torch.load(ckpt, map_location="cpu")
|
||||
>>> model checkpoint loaded.
|
||||
>>> Load pre-trained model ...
|
||||
INFO:root:***** Configing Data *****
|
||||
>>> unitree_z1_stackbox: 1 data samples loaded.
|
||||
>>> unitree_z1_stackbox: data stats loaded.
|
||||
>>> unitree_z1_stackbox: normalizer initiated.
|
||||
>>> unitree_z1_dual_arm_stackbox: 1 data samples loaded.
|
||||
>>> unitree_z1_dual_arm_stackbox: data stats loaded.
|
||||
>>> unitree_z1_dual_arm_stackbox: normalizer initiated.
|
||||
>>> unitree_z1_dual_arm_stackbox_v2: 1 data samples loaded.
|
||||
>>> unitree_z1_dual_arm_stackbox_v2: data stats loaded.
|
||||
>>> unitree_z1_dual_arm_stackbox_v2: normalizer initiated.
|
||||
>>> unitree_z1_dual_arm_cleanup_pencils: 1 data samples loaded.
|
||||
>>> unitree_z1_dual_arm_cleanup_pencils: data stats loaded.
|
||||
>>> unitree_z1_dual_arm_cleanup_pencils: normalizer initiated.
|
||||
>>> unitree_g1_pack_camera: 1 data samples loaded.
|
||||
>>> unitree_g1_pack_camera: data stats loaded.
|
||||
>>> unitree_g1_pack_camera: normalizer initiated.
|
||||
>>> Dataset is successfully loaded ...
|
||||
>>> Generate 16 frames under each generation ...
|
||||
DEBUG:h5py._conv:Creating converter from 3 to 5
|
||||
DEBUG:PIL.PngImagePlugin:STREAM b'IHDR' 16 13
|
||||
DEBUG:PIL.PngImagePlugin:STREAM b'pHYs' 41 9
|
||||
DEBUG:PIL.PngImagePlugin:STREAM b'IDAT' 62 4096
|
||||
|
||||
0%| | 0/11 [00:00<?, ?it/s]/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/torch/nn/functional.py:5501: UserWarning: Attempting to use hipBLASLt on an unsupported architecture! Overriding blas backend to hipblas (Triggered internally at ../aten/src/ATen/Context.cpp:296.)
|
||||
proj = linear(q, w, b)
|
||||
/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/torch/nn/functional.py:6278: UserWarning: Flash attention support on Navi31 GPU is still experimental. Enable it with TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL=1. (Triggered internally at ../aten/src/ATen/native/transformers/hip/sdp_utils.cpp:225.)
|
||||
attn_output = scaled_dot_product_attention(
|
||||
/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/torch/nn/functional.py:6278: UserWarning: Memory Efficient attention on Navi31 GPU is still experimental. Enable it with TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL=1. (Triggered internally at ../aten/src/ATen/native/transformers/hip/sdp_utils.cpp:269.)
|
||||
attn_output = scaled_dot_product_attention(
|
||||
>>> Step 0: generating actions ...
|
||||
>>> Step 0: interacting with world model ...
|
||||
DEBUG:PIL.Image:Importing BlpImagePlugin
|
||||
DEBUG:PIL.Image:Importing BmpImagePlugin
|
||||
DEBUG:PIL.Image:Importing BufrStubImagePlugin
|
||||
DEBUG:PIL.Image:Importing CurImagePlugin
|
||||
DEBUG:PIL.Image:Importing DcxImagePlugin
|
||||
DEBUG:PIL.Image:Importing DdsImagePlugin
|
||||
DEBUG:PIL.Image:Importing EpsImagePlugin
|
||||
DEBUG:PIL.Image:Importing FitsImagePlugin
|
||||
DEBUG:PIL.Image:Importing FitsStubImagePlugin
|
||||
DEBUG:PIL.Image:Importing FliImagePlugin
|
||||
DEBUG:PIL.Image:Importing FpxImagePlugin
|
||||
DEBUG:PIL.Image:Image: failed to import FpxImagePlugin: No module named 'olefile'
|
||||
DEBUG:PIL.Image:Importing FtexImagePlugin
|
||||
DEBUG:PIL.Image:Importing GbrImagePlugin
|
||||
DEBUG:PIL.Image:Importing GifImagePlugin
|
||||
DEBUG:PIL.Image:Importing GribStubImagePlugin
|
||||
DEBUG:PIL.Image:Importing Hdf5StubImagePlugin
|
||||
DEBUG:PIL.Image:Importing IcnsImagePlugin
|
||||
DEBUG:PIL.Image:Importing IcoImagePlugin
|
||||
DEBUG:PIL.Image:Importing ImImagePlugin
|
||||
DEBUG:PIL.Image:Importing ImtImagePlugin
|
||||
DEBUG:PIL.Image:Importing IptcImagePlugin
|
||||
DEBUG:PIL.Image:Importing JpegImagePlugin
|
||||
DEBUG:PIL.Image:Importing Jpeg2KImagePlugin
|
||||
DEBUG:PIL.Image:Importing McIdasImagePlugin
|
||||
DEBUG:PIL.Image:Importing MicImagePlugin
|
||||
DEBUG:PIL.Image:Image: failed to import MicImagePlugin: No module named 'olefile'
|
||||
DEBUG:PIL.Image:Importing MpegImagePlugin
|
||||
DEBUG:PIL.Image:Importing MpoImagePlugin
|
||||
DEBUG:PIL.Image:Importing MspImagePlugin
|
||||
DEBUG:PIL.Image:Importing PalmImagePlugin
|
||||
DEBUG:PIL.Image:Importing PcdImagePlugin
|
||||
DEBUG:PIL.Image:Importing PcxImagePlugin
|
||||
DEBUG:PIL.Image:Importing PdfImagePlugin
|
||||
DEBUG:PIL.Image:Importing PixarImagePlugin
|
||||
DEBUG:PIL.Image:Importing PngImagePlugin
|
||||
DEBUG:PIL.Image:Importing PpmImagePlugin
|
||||
DEBUG:PIL.Image:Importing PsdImagePlugin
|
||||
DEBUG:PIL.Image:Importing QoiImagePlugin
|
||||
DEBUG:PIL.Image:Importing SgiImagePlugin
|
||||
DEBUG:PIL.Image:Importing SpiderImagePlugin
|
||||
DEBUG:PIL.Image:Importing SunImagePlugin
|
||||
DEBUG:PIL.Image:Importing TgaImagePlugin
|
||||
DEBUG:PIL.Image:Importing TiffImagePlugin
|
||||
DEBUG:PIL.Image:Importing WebPImagePlugin
|
||||
DEBUG:PIL.Image:Importing WmfImagePlugin
|
||||
DEBUG:PIL.Image:Importing XbmImagePlugin
|
||||
DEBUG:PIL.Image:Importing XpmImagePlugin
|
||||
DEBUG:PIL.Image:Importing XVThumbImagePlugin
|
||||
|
||||
9%|▉ | 1/11 [01:40<16:41, 100.16s/it]
|
||||
18%|█▊ | 2/11 [03:20<15:04, 100.47s/it]
|
||||
27%|██▋ | 3/11 [05:01<13:24, 100.62s/it]
|
||||
36%|███▋ | 4/11 [06:42<11:44, 100.69s/it]
|
||||
45%|████▌ | 5/11 [08:22<10:02, 100.48s/it]
|
||||
55%|█████▍ | 6/11 [10:02<08:21, 100.33s/it]
|
||||
64%|██████▎ | 7/11 [11:42<06:40, 100.23s/it]
|
||||
73%|███████▎ | 8/11 [13:22<05:00, 100.23s/it]
|
||||
82%|████████▏ | 9/11 [15:03<03:20, 100.23s/it]
|
||||
91%|█████████ | 10/11 [16:43<01:40, 100.33s/it]
|
||||
100%|██████████| 11/11 [18:24<00:00, 100.41s/it]
|
||||
100%|██████████| 11/11 [18:24<00:00, 100.39s/it]
|
||||
>>>>>>>>>>>>>>>>>>>>>>>>
|
||||
>>> Step 1: generating actions ...
|
||||
>>> Step 1: interacting with world model ...
|
||||
>>>>>>>>>>>>>>>>>>>>>>>>
|
||||
>>> Step 2: generating actions ...
|
||||
>>> Step 2: interacting with world model ...
|
||||
>>>>>>>>>>>>>>>>>>>>>>>>
|
||||
>>> Step 3: generating actions ...
|
||||
>>> Step 3: interacting with world model ...
|
||||
>>>>>>>>>>>>>>>>>>>>>>>>
|
||||
>>> Step 4: generating actions ...
|
||||
>>> Step 4: interacting with world model ...
|
||||
>>>>>>>>>>>>>>>>>>>>>>>>
|
||||
>>> Step 5: generating actions ...
|
||||
>>> Step 5: interacting with world model ...
|
||||
>>>>>>>>>>>>>>>>>>>>>>>>
|
||||
>>> Step 6: generating actions ...
|
||||
>>> Step 6: interacting with world model ...
|
||||
>>>>>>>>>>>>>>>>>>>>>>>>
|
||||
>>> Step 7: generating actions ...
|
||||
>>> Step 7: interacting with world model ...
|
||||
>>>>>>>>>>>>>>>>>>>>>>>>
|
||||
5
unitree_z1_dual_arm_stackbox_v2/case2/psnr_result.json
Normal file
5
unitree_z1_dual_arm_stackbox_v2/case2/psnr_result.json
Normal file
@@ -0,0 +1,5 @@
|
||||
{
|
||||
"gt_video": "unitree_z1_dual_arm_stackbox_v2/case2/unitree_z1_dual_arm_stackbox_v2_case2.mp4",
|
||||
"pred_video": "unitree_z1_dual_arm_stackbox_v2/case2/output/inference/unitree_z1_dual_arm_stackbox_v2_case2_amd.mp4",
|
||||
"psnr": 19.38130614773096
|
||||
}
|
||||
146
unitree_z1_dual_arm_stackbox_v2/case3/output.log
Normal file
146
unitree_z1_dual_arm_stackbox_v2/case3/output.log
Normal file
@@ -0,0 +1,146 @@
|
||||
/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/lightning_fabric/__init__.py:29: UserWarning: pkg_resources is deprecated as an API. See https://setuptools.pypa.io/en/latest/pkg_resources.html. The pkg_resources package is slated for removal as early as 2025-11-30. Refrain from using this package or pin to Setuptools<81.
|
||||
__import__("pkg_resources").declare_namespace(__name__)
|
||||
2026-02-08 07:56:04.467082: I tensorflow/core/util/port.cc:113] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
|
||||
2026-02-08 07:56:04.470145: I external/local_tsl/tsl/cuda/cudart_stub.cc:31] Could not find cuda drivers on your machine, GPU will not be used.
|
||||
2026-02-08 07:56:04.502248: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
|
||||
2026-02-08 07:56:04.502277: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
|
||||
2026-02-08 07:56:04.504088: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
|
||||
2026-02-08 07:56:04.512557: I external/local_tsl/tsl/cuda/cudart_stub.cc:31] Could not find cuda drivers on your machine, GPU will not be used.
|
||||
2026-02-08 07:56:04.512830: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
|
||||
To enable the following instructions: AVX2 AVX512F AVX512_VNNI AVX512_BF16 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
|
||||
2026-02-08 07:56:05.259641: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT
|
||||
[rank: 0] Global seed set to 123
|
||||
/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/kornia/feature/lightglue.py:44: FutureWarning: `torch.cuda.amp.custom_fwd(args...)` is deprecated. Please use `torch.amp.custom_fwd(args..., device_type='cuda')` instead.
|
||||
@torch.cuda.amp.custom_fwd(cast_inputs=torch.float32)
|
||||
INFO:mainlogger:LatentVisualDiffusion: Running in v-prediction mode
|
||||
INFO:unifolm_wma.models.diffusion_head.conditional_unet1d:number of parameters: 5.010531e+08
|
||||
INFO:unifolm_wma.models.diffusion_head.conditional_unet1d:number of parameters: 5.010531e+08
|
||||
AE working on z of shape (1, 4, 32, 32) = 4096 dimensions.
|
||||
INFO:root:Loaded ViT-H-14 model config.
|
||||
DEBUG:urllib3.connectionpool:Starting new HTTPS connection (1): hf-mirror.com:443
|
||||
DEBUG:urllib3.connectionpool:https://hf-mirror.com:443 "HEAD /laion/CLIP-ViT-H-14-laion2B-s32B-b79K/resolve/main/open_clip_pytorch_model.bin HTTP/1.1" 302 0
|
||||
INFO:root:Loading pretrained ViT-H-14 weights (laion2b_s32b_b79k).
|
||||
/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/open_clip/factory.py:88: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.
|
||||
checkpoint = torch.load(checkpoint_path, map_location=map_location)
|
||||
INFO:root:Loaded ViT-H-14 model config.
|
||||
DEBUG:urllib3.connectionpool:https://hf-mirror.com:443 "HEAD /laion/CLIP-ViT-H-14-laion2B-s32B-b79K/resolve/main/open_clip_pytorch_model.bin HTTP/1.1" 302 0
|
||||
INFO:root:Loading pretrained ViT-H-14 weights (laion2b_s32b_b79k).
|
||||
/mnt/ASC1637/unifolm-world-model-action/scripts/evaluation/world_model_interaction.py:86: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.
|
||||
state_dict = torch.load(ckpt, map_location="cpu")
|
||||
>>> model checkpoint loaded.
|
||||
>>> Load pre-trained model ...
|
||||
INFO:root:***** Configing Data *****
|
||||
>>> unitree_z1_stackbox: 1 data samples loaded.
|
||||
>>> unitree_z1_stackbox: data stats loaded.
|
||||
>>> unitree_z1_stackbox: normalizer initiated.
|
||||
>>> unitree_z1_dual_arm_stackbox: 1 data samples loaded.
|
||||
>>> unitree_z1_dual_arm_stackbox: data stats loaded.
|
||||
>>> unitree_z1_dual_arm_stackbox: normalizer initiated.
|
||||
>>> unitree_z1_dual_arm_stackbox_v2: 1 data samples loaded.
|
||||
>>> unitree_z1_dual_arm_stackbox_v2: data stats loaded.
|
||||
>>> unitree_z1_dual_arm_stackbox_v2: normalizer initiated.
|
||||
>>> unitree_z1_dual_arm_cleanup_pencils: 1 data samples loaded.
|
||||
>>> unitree_z1_dual_arm_cleanup_pencils: data stats loaded.
|
||||
>>> unitree_z1_dual_arm_cleanup_pencils: normalizer initiated.
|
||||
>>> unitree_g1_pack_camera: 1 data samples loaded.
|
||||
>>> unitree_g1_pack_camera: data stats loaded.
|
||||
>>> unitree_g1_pack_camera: normalizer initiated.
|
||||
>>> Dataset is successfully loaded ...
|
||||
>>> Generate 16 frames under each generation ...
|
||||
DEBUG:h5py._conv:Creating converter from 3 to 5
|
||||
DEBUG:PIL.PngImagePlugin:STREAM b'IHDR' 16 13
|
||||
DEBUG:PIL.PngImagePlugin:STREAM b'pHYs' 41 9
|
||||
DEBUG:PIL.PngImagePlugin:STREAM b'IDAT' 62 4096
|
||||
|
||||
0%| | 0/11 [00:00<?, ?it/s]/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/torch/nn/functional.py:5501: UserWarning: Attempting to use hipBLASLt on an unsupported architecture! Overriding blas backend to hipblas (Triggered internally at ../aten/src/ATen/Context.cpp:296.)
|
||||
proj = linear(q, w, b)
|
||||
/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/torch/nn/functional.py:6278: UserWarning: Flash attention support on Navi31 GPU is still experimental. Enable it with TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL=1. (Triggered internally at ../aten/src/ATen/native/transformers/hip/sdp_utils.cpp:225.)
|
||||
attn_output = scaled_dot_product_attention(
|
||||
/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/torch/nn/functional.py:6278: UserWarning: Memory Efficient attention on Navi31 GPU is still experimental. Enable it with TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL=1. (Triggered internally at ../aten/src/ATen/native/transformers/hip/sdp_utils.cpp:269.)
|
||||
attn_output = scaled_dot_product_attention(
|
||||
>>> Step 0: generating actions ...
|
||||
>>> Step 0: interacting with world model ...
|
||||
DEBUG:PIL.Image:Importing BlpImagePlugin
|
||||
DEBUG:PIL.Image:Importing BmpImagePlugin
|
||||
DEBUG:PIL.Image:Importing BufrStubImagePlugin
|
||||
DEBUG:PIL.Image:Importing CurImagePlugin
|
||||
DEBUG:PIL.Image:Importing DcxImagePlugin
|
||||
DEBUG:PIL.Image:Importing DdsImagePlugin
|
||||
DEBUG:PIL.Image:Importing EpsImagePlugin
|
||||
DEBUG:PIL.Image:Importing FitsImagePlugin
|
||||
DEBUG:PIL.Image:Importing FitsStubImagePlugin
|
||||
DEBUG:PIL.Image:Importing FliImagePlugin
|
||||
DEBUG:PIL.Image:Importing FpxImagePlugin
|
||||
DEBUG:PIL.Image:Image: failed to import FpxImagePlugin: No module named 'olefile'
|
||||
DEBUG:PIL.Image:Importing FtexImagePlugin
|
||||
DEBUG:PIL.Image:Importing GbrImagePlugin
|
||||
DEBUG:PIL.Image:Importing GifImagePlugin
|
||||
DEBUG:PIL.Image:Importing GribStubImagePlugin
|
||||
DEBUG:PIL.Image:Importing Hdf5StubImagePlugin
|
||||
DEBUG:PIL.Image:Importing IcnsImagePlugin
|
||||
DEBUG:PIL.Image:Importing IcoImagePlugin
|
||||
DEBUG:PIL.Image:Importing ImImagePlugin
|
||||
DEBUG:PIL.Image:Importing ImtImagePlugin
|
||||
DEBUG:PIL.Image:Importing IptcImagePlugin
|
||||
DEBUG:PIL.Image:Importing JpegImagePlugin
|
||||
DEBUG:PIL.Image:Importing Jpeg2KImagePlugin
|
||||
DEBUG:PIL.Image:Importing McIdasImagePlugin
|
||||
DEBUG:PIL.Image:Importing MicImagePlugin
|
||||
DEBUG:PIL.Image:Image: failed to import MicImagePlugin: No module named 'olefile'
|
||||
DEBUG:PIL.Image:Importing MpegImagePlugin
|
||||
DEBUG:PIL.Image:Importing MpoImagePlugin
|
||||
DEBUG:PIL.Image:Importing MspImagePlugin
|
||||
DEBUG:PIL.Image:Importing PalmImagePlugin
|
||||
DEBUG:PIL.Image:Importing PcdImagePlugin
|
||||
DEBUG:PIL.Image:Importing PcxImagePlugin
|
||||
DEBUG:PIL.Image:Importing PdfImagePlugin
|
||||
DEBUG:PIL.Image:Importing PixarImagePlugin
|
||||
DEBUG:PIL.Image:Importing PngImagePlugin
|
||||
DEBUG:PIL.Image:Importing PpmImagePlugin
|
||||
DEBUG:PIL.Image:Importing PsdImagePlugin
|
||||
DEBUG:PIL.Image:Importing QoiImagePlugin
|
||||
DEBUG:PIL.Image:Importing SgiImagePlugin
|
||||
DEBUG:PIL.Image:Importing SpiderImagePlugin
|
||||
DEBUG:PIL.Image:Importing SunImagePlugin
|
||||
DEBUG:PIL.Image:Importing TgaImagePlugin
|
||||
DEBUG:PIL.Image:Importing TiffImagePlugin
|
||||
DEBUG:PIL.Image:Importing WebPImagePlugin
|
||||
DEBUG:PIL.Image:Importing WmfImagePlugin
|
||||
DEBUG:PIL.Image:Importing XbmImagePlugin
|
||||
DEBUG:PIL.Image:Importing XpmImagePlugin
|
||||
DEBUG:PIL.Image:Importing XVThumbImagePlugin
|
||||
|
||||
9%|▉ | 1/11 [01:38<16:20, 98.03s/it]
|
||||
18%|█▊ | 2/11 [03:16<14:43, 98.19s/it]
|
||||
27%|██▋ | 3/11 [04:55<13:08, 98.54s/it]
|
||||
36%|███▋ | 4/11 [06:33<11:29, 98.52s/it]
|
||||
45%|████▌ | 5/11 [08:11<09:50, 98.38s/it]
|
||||
55%|█████▍ | 6/11 [09:49<08:10, 98.11s/it]
|
||||
64%|██████▎ | 7/11 [11:27<06:31, 97.97s/it]
|
||||
73%|███████▎ | 8/11 [13:04<04:53, 97.83s/it]
|
||||
82%|████████▏ | 9/11 [14:42<03:15, 97.72s/it]
|
||||
91%|█████████ | 10/11 [16:19<01:37, 97.71s/it]
|
||||
100%|██████████| 11/11 [17:57<00:00, 97.74s/it]
|
||||
100%|██████████| 11/11 [17:57<00:00, 97.97s/it]
|
||||
>>>>>>>>>>>>>>>>>>>>>>>>
|
||||
>>> Step 1: generating actions ...
|
||||
>>> Step 1: interacting with world model ...
|
||||
>>>>>>>>>>>>>>>>>>>>>>>>
|
||||
>>> Step 2: generating actions ...
|
||||
>>> Step 2: interacting with world model ...
|
||||
>>>>>>>>>>>>>>>>>>>>>>>>
|
||||
>>> Step 3: generating actions ...
|
||||
>>> Step 3: interacting with world model ...
|
||||
>>>>>>>>>>>>>>>>>>>>>>>>
|
||||
>>> Step 4: generating actions ...
|
||||
>>> Step 4: interacting with world model ...
|
||||
>>>>>>>>>>>>>>>>>>>>>>>>
|
||||
>>> Step 5: generating actions ...
|
||||
>>> Step 5: interacting with world model ...
|
||||
>>>>>>>>>>>>>>>>>>>>>>>>
|
||||
>>> Step 6: generating actions ...
|
||||
>>> Step 6: interacting with world model ...
|
||||
>>>>>>>>>>>>>>>>>>>>>>>>
|
||||
>>> Step 7: generating actions ...
|
||||
>>> Step 7: interacting with world model ...
|
||||
>>>>>>>>>>>>>>>>>>>>>>>>
|
||||
5
unitree_z1_dual_arm_stackbox_v2/case3/psnr_result.json
Normal file
5
unitree_z1_dual_arm_stackbox_v2/case3/psnr_result.json
Normal file
@@ -0,0 +1,5 @@
|
||||
{
|
||||
"gt_video": "unitree_z1_dual_arm_stackbox_v2/case3/unitree_z1_dual_arm_stackbox_v2_case3.mp4",
|
||||
"pred_video": "unitree_z1_dual_arm_stackbox_v2/case3/output/inference/unitree_z1_dual_arm_stackbox_v2_case3_amd.mp4",
|
||||
"psnr": 18.74462122425683
|
||||
}
|
||||
146
unitree_z1_dual_arm_stackbox_v2/case4/output.log
Normal file
146
unitree_z1_dual_arm_stackbox_v2/case4/output.log
Normal file
@@ -0,0 +1,146 @@
|
||||
/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/lightning_fabric/__init__.py:29: UserWarning: pkg_resources is deprecated as an API. See https://setuptools.pypa.io/en/latest/pkg_resources.html. The pkg_resources package is slated for removal as early as 2025-11-30. Refrain from using this package or pin to Setuptools<81.
|
||||
__import__("pkg_resources").declare_namespace(__name__)
|
||||
2026-02-08 08:04:16.104516: I tensorflow/core/util/port.cc:113] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
|
||||
2026-02-08 08:04:16.109112: I external/local_tsl/tsl/cuda/cudart_stub.cc:31] Could not find cuda drivers on your machine, GPU will not be used.
|
||||
2026-02-08 08:04:16.138703: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
|
||||
2026-02-08 08:04:16.138737: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
|
||||
2026-02-08 08:04:16.140302: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
|
||||
2026-02-08 08:04:16.147672: I external/local_tsl/tsl/cuda/cudart_stub.cc:31] Could not find cuda drivers on your machine, GPU will not be used.
|
||||
2026-02-08 08:04:16.147903: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
|
||||
To enable the following instructions: AVX2 AVX512F AVX512_VNNI AVX512_BF16 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
|
||||
2026-02-08 08:04:17.363218: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT
|
||||
Global seed set to 123
|
||||
/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/kornia/feature/lightglue.py:44: FutureWarning: `torch.cuda.amp.custom_fwd(args...)` is deprecated. Please use `torch.amp.custom_fwd(args..., device_type='cuda')` instead.
|
||||
@torch.cuda.amp.custom_fwd(cast_inputs=torch.float32)
|
||||
INFO:mainlogger:LatentVisualDiffusion: Running in v-prediction mode
|
||||
INFO:unifolm_wma.models.diffusion_head.conditional_unet1d:number of parameters: 5.010531e+08
|
||||
INFO:unifolm_wma.models.diffusion_head.conditional_unet1d:number of parameters: 5.010531e+08
|
||||
AE working on z of shape (1, 4, 32, 32) = 4096 dimensions.
|
||||
INFO:root:Loaded ViT-H-14 model config.
|
||||
DEBUG:urllib3.connectionpool:Starting new HTTPS connection (1): hf-mirror.com:443
|
||||
DEBUG:urllib3.connectionpool:https://hf-mirror.com:443 "HEAD /laion/CLIP-ViT-H-14-laion2B-s32B-b79K/resolve/main/open_clip_pytorch_model.bin HTTP/1.1" 302 0
|
||||
INFO:root:Loading pretrained ViT-H-14 weights (laion2b_s32b_b79k).
|
||||
/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/open_clip/factory.py:88: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.
|
||||
checkpoint = torch.load(checkpoint_path, map_location=map_location)
|
||||
INFO:root:Loaded ViT-H-14 model config.
|
||||
DEBUG:urllib3.connectionpool:https://hf-mirror.com:443 "HEAD /laion/CLIP-ViT-H-14-laion2B-s32B-b79K/resolve/main/open_clip_pytorch_model.bin HTTP/1.1" 302 0
|
||||
INFO:root:Loading pretrained ViT-H-14 weights (laion2b_s32b_b79k).
|
||||
/mnt/ASC1637/unifolm-world-model-action/scripts/evaluation/world_model_interaction.py:86: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.
|
||||
state_dict = torch.load(ckpt, map_location="cpu")
|
||||
>>> model checkpoint loaded.
|
||||
>>> Load pre-trained model ...
|
||||
INFO:root:***** Configing Data *****
|
||||
>>> unitree_z1_stackbox: 1 data samples loaded.
|
||||
>>> unitree_z1_stackbox: data stats loaded.
|
||||
>>> unitree_z1_stackbox: normalizer initiated.
|
||||
>>> unitree_z1_dual_arm_stackbox: 1 data samples loaded.
|
||||
>>> unitree_z1_dual_arm_stackbox: data stats loaded.
|
||||
>>> unitree_z1_dual_arm_stackbox: normalizer initiated.
|
||||
>>> unitree_z1_dual_arm_stackbox_v2: 1 data samples loaded.
|
||||
>>> unitree_z1_dual_arm_stackbox_v2: data stats loaded.
|
||||
>>> unitree_z1_dual_arm_stackbox_v2: normalizer initiated.
|
||||
>>> unitree_z1_dual_arm_cleanup_pencils: 1 data samples loaded.
|
||||
>>> unitree_z1_dual_arm_cleanup_pencils: data stats loaded.
|
||||
>>> unitree_z1_dual_arm_cleanup_pencils: normalizer initiated.
|
||||
>>> unitree_g1_pack_camera: 1 data samples loaded.
|
||||
>>> unitree_g1_pack_camera: data stats loaded.
|
||||
>>> unitree_g1_pack_camera: normalizer initiated.
|
||||
>>> Dataset is successfully loaded ...
|
||||
>>> Generate 16 frames under each generation ...
|
||||
DEBUG:h5py._conv:Creating converter from 3 to 5
|
||||
DEBUG:PIL.PngImagePlugin:STREAM b'IHDR' 16 13
|
||||
DEBUG:PIL.PngImagePlugin:STREAM b'pHYs' 41 9
|
||||
DEBUG:PIL.PngImagePlugin:STREAM b'IDAT' 62 4096
|
||||
|
||||
0%| | 0/11 [00:00<?, ?it/s]/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/torch/nn/functional.py:5501: UserWarning: Attempting to use hipBLASLt on an unsupported architecture! Overriding blas backend to hipblas (Triggered internally at ../aten/src/ATen/Context.cpp:296.)
|
||||
proj = linear(q, w, b)
|
||||
/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/torch/nn/functional.py:6278: UserWarning: Flash attention support on Navi31 GPU is still experimental. Enable it with TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL=1. (Triggered internally at ../aten/src/ATen/native/transformers/hip/sdp_utils.cpp:225.)
|
||||
attn_output = scaled_dot_product_attention(
|
||||
/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/torch/nn/functional.py:6278: UserWarning: Memory Efficient attention on Navi31 GPU is still experimental. Enable it with TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL=1. (Triggered internally at ../aten/src/ATen/native/transformers/hip/sdp_utils.cpp:269.)
|
||||
attn_output = scaled_dot_product_attention(
|
||||
>>> Step 0: generating actions ...
|
||||
>>> Step 0: interacting with world model ...
|
||||
DEBUG:PIL.Image:Importing BlpImagePlugin
|
||||
DEBUG:PIL.Image:Importing BmpImagePlugin
|
||||
DEBUG:PIL.Image:Importing BufrStubImagePlugin
|
||||
DEBUG:PIL.Image:Importing CurImagePlugin
|
||||
DEBUG:PIL.Image:Importing DcxImagePlugin
|
||||
DEBUG:PIL.Image:Importing DdsImagePlugin
|
||||
DEBUG:PIL.Image:Importing EpsImagePlugin
|
||||
DEBUG:PIL.Image:Importing FitsImagePlugin
|
||||
DEBUG:PIL.Image:Importing FitsStubImagePlugin
|
||||
DEBUG:PIL.Image:Importing FliImagePlugin
|
||||
DEBUG:PIL.Image:Importing FpxImagePlugin
|
||||
DEBUG:PIL.Image:Image: failed to import FpxImagePlugin: No module named 'olefile'
|
||||
DEBUG:PIL.Image:Importing FtexImagePlugin
|
||||
DEBUG:PIL.Image:Importing GbrImagePlugin
|
||||
DEBUG:PIL.Image:Importing GifImagePlugin
|
||||
DEBUG:PIL.Image:Importing GribStubImagePlugin
|
||||
DEBUG:PIL.Image:Importing Hdf5StubImagePlugin
|
||||
DEBUG:PIL.Image:Importing IcnsImagePlugin
|
||||
DEBUG:PIL.Image:Importing IcoImagePlugin
|
||||
DEBUG:PIL.Image:Importing ImImagePlugin
|
||||
DEBUG:PIL.Image:Importing ImtImagePlugin
|
||||
DEBUG:PIL.Image:Importing IptcImagePlugin
|
||||
DEBUG:PIL.Image:Importing JpegImagePlugin
|
||||
DEBUG:PIL.Image:Importing Jpeg2KImagePlugin
|
||||
DEBUG:PIL.Image:Importing McIdasImagePlugin
|
||||
DEBUG:PIL.Image:Importing MicImagePlugin
|
||||
DEBUG:PIL.Image:Image: failed to import MicImagePlugin: No module named 'olefile'
|
||||
DEBUG:PIL.Image:Importing MpegImagePlugin
|
||||
DEBUG:PIL.Image:Importing MpoImagePlugin
|
||||
DEBUG:PIL.Image:Importing MspImagePlugin
|
||||
DEBUG:PIL.Image:Importing PalmImagePlugin
|
||||
DEBUG:PIL.Image:Importing PcdImagePlugin
|
||||
DEBUG:PIL.Image:Importing PcxImagePlugin
|
||||
DEBUG:PIL.Image:Importing PdfImagePlugin
|
||||
DEBUG:PIL.Image:Importing PixarImagePlugin
|
||||
DEBUG:PIL.Image:Importing PngImagePlugin
|
||||
DEBUG:PIL.Image:Importing PpmImagePlugin
|
||||
DEBUG:PIL.Image:Importing PsdImagePlugin
|
||||
DEBUG:PIL.Image:Importing QoiImagePlugin
|
||||
DEBUG:PIL.Image:Importing SgiImagePlugin
|
||||
DEBUG:PIL.Image:Importing SpiderImagePlugin
|
||||
DEBUG:PIL.Image:Importing SunImagePlugin
|
||||
DEBUG:PIL.Image:Importing TgaImagePlugin
|
||||
DEBUG:PIL.Image:Importing TiffImagePlugin
|
||||
DEBUG:PIL.Image:Importing WebPImagePlugin
|
||||
DEBUG:PIL.Image:Importing WmfImagePlugin
|
||||
DEBUG:PIL.Image:Importing XbmImagePlugin
|
||||
DEBUG:PIL.Image:Importing XpmImagePlugin
|
||||
DEBUG:PIL.Image:Importing XVThumbImagePlugin
|
||||
|
||||
9%|▉ | 1/11 [01:39<16:32, 99.26s/it]
|
||||
18%|█▊ | 2/11 [03:17<14:49, 98.81s/it]
|
||||
27%|██▋ | 3/11 [04:56<13:10, 98.76s/it]
|
||||
36%|███▋ | 4/11 [06:35<11:31, 98.80s/it]
|
||||
45%|████▌ | 5/11 [08:14<09:53, 98.85s/it]
|
||||
55%|█████▍ | 6/11 [09:53<08:14, 98.87s/it]
|
||||
64%|██████▎ | 7/11 [11:31<06:34, 98.68s/it]
|
||||
73%|███████▎ | 8/11 [13:09<04:55, 98.49s/it]
|
||||
82%|████████▏ | 9/11 [14:47<03:16, 98.38s/it]
|
||||
91%|█████████ | 10/11 [16:25<01:38, 98.29s/it]
|
||||
100%|██████████| 11/11 [18:03<00:00, 98.26s/it]
|
||||
100%|██████████| 11/11 [18:03<00:00, 98.54s/it]
|
||||
>>>>>>>>>>>>>>>>>>>>>>>>
|
||||
>>> Step 1: generating actions ...
|
||||
>>> Step 1: interacting with world model ...
|
||||
>>>>>>>>>>>>>>>>>>>>>>>>
|
||||
>>> Step 2: generating actions ...
|
||||
>>> Step 2: interacting with world model ...
|
||||
>>>>>>>>>>>>>>>>>>>>>>>>
|
||||
>>> Step 3: generating actions ...
|
||||
>>> Step 3: interacting with world model ...
|
||||
>>>>>>>>>>>>>>>>>>>>>>>>
|
||||
>>> Step 4: generating actions ...
|
||||
>>> Step 4: interacting with world model ...
|
||||
>>>>>>>>>>>>>>>>>>>>>>>>
|
||||
>>> Step 5: generating actions ...
|
||||
>>> Step 5: interacting with world model ...
|
||||
>>>>>>>>>>>>>>>>>>>>>>>>
|
||||
>>> Step 6: generating actions ...
|
||||
>>> Step 6: interacting with world model ...
|
||||
>>>>>>>>>>>>>>>>>>>>>>>>
|
||||
>>> Step 7: generating actions ...
|
||||
>>> Step 7: interacting with world model ...
|
||||
>>>>>>>>>>>>>>>>>>>>>>>>
|
||||
5
unitree_z1_dual_arm_stackbox_v2/case4/psnr_result.json
Normal file
5
unitree_z1_dual_arm_stackbox_v2/case4/psnr_result.json
Normal file
@@ -0,0 +1,5 @@
|
||||
{
|
||||
"gt_video": "unitree_z1_dual_arm_stackbox_v2/case4/unitree_z1_dual_arm_stackbox_v2_case4.mp4",
|
||||
"pred_video": "unitree_z1_dual_arm_stackbox_v2/case4/output/inference/unitree_z1_dual_arm_stackbox_v2_case4_amd.mp4",
|
||||
"psnr": 19.526448380726254
|
||||
}
|
||||
@@ -2,7 +2,7 @@ res_dir="unitree_z1_dual_arm_stackbox_v2/case4"
|
||||
dataset="unitree_z1_dual_arm_stackbox_v2"
|
||||
|
||||
{
|
||||
time CUDA_VISIBLE_DEVICES=0 python3 scripts/evaluation/world_model_interaction.py \
|
||||
time CUDA_VISIBLE_DEVICES=6 python3 scripts/evaluation/world_model_interaction.py \
|
||||
--seed 123 \
|
||||
--ckpt_path ckpts/unifolm_wma_dual.ckpt \
|
||||
--config configs/inference/world_model_interaction.yaml \
|
||||
|
||||
149
unitree_z1_stackbox/case1/output.log
Normal file
149
unitree_z1_stackbox/case1/output.log
Normal file
@@ -0,0 +1,149 @@
|
||||
/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/lightning_fabric/__init__.py:29: UserWarning: pkg_resources is deprecated as an API. See https://setuptools.pypa.io/en/latest/pkg_resources.html. The pkg_resources package is slated for removal as early as 2025-11-30. Refrain from using this package or pin to Setuptools<81.
|
||||
__import__("pkg_resources").declare_namespace(__name__)
|
||||
2026-02-08 08:12:47.424053: I tensorflow/core/util/port.cc:113] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
|
||||
2026-02-08 08:12:47.427280: I external/local_tsl/tsl/cuda/cudart_stub.cc:31] Could not find cuda drivers on your machine, GPU will not be used.
|
||||
2026-02-08 08:12:47.458253: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
|
||||
2026-02-08 08:12:47.458288: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
|
||||
2026-02-08 08:12:47.462758: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
|
||||
2026-02-08 08:12:47.518283: I external/local_tsl/tsl/cuda/cudart_stub.cc:31] Could not find cuda drivers on your machine, GPU will not be used.
|
||||
2026-02-08 08:12:47.518566: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
|
||||
To enable the following instructions: AVX2 AVX512F AVX512_VNNI AVX512_BF16 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
|
||||
2026-02-08 08:12:48.593011: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT
|
||||
Global seed set to 123
|
||||
/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/kornia/feature/lightglue.py:44: FutureWarning: `torch.cuda.amp.custom_fwd(args...)` is deprecated. Please use `torch.amp.custom_fwd(args..., device_type='cuda')` instead.
|
||||
@torch.cuda.amp.custom_fwd(cast_inputs=torch.float32)
|
||||
INFO:mainlogger:LatentVisualDiffusion: Running in v-prediction mode
|
||||
INFO:unifolm_wma.models.diffusion_head.conditional_unet1d:number of parameters: 5.010531e+08
|
||||
INFO:unifolm_wma.models.diffusion_head.conditional_unet1d:number of parameters: 5.010531e+08
|
||||
AE working on z of shape (1, 4, 32, 32) = 4096 dimensions.
|
||||
INFO:root:Loaded ViT-H-14 model config.
|
||||
DEBUG:urllib3.connectionpool:Starting new HTTPS connection (1): hf-mirror.com:443
|
||||
DEBUG:urllib3.connectionpool:https://hf-mirror.com:443 "HEAD /laion/CLIP-ViT-H-14-laion2B-s32B-b79K/resolve/main/open_clip_pytorch_model.bin HTTP/1.1" 302 0
|
||||
INFO:root:Loading pretrained ViT-H-14 weights (laion2b_s32b_b79k).
|
||||
/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/open_clip/factory.py:88: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.
|
||||
checkpoint = torch.load(checkpoint_path, map_location=map_location)
|
||||
INFO:root:Loaded ViT-H-14 model config.
|
||||
DEBUG:urllib3.connectionpool:https://hf-mirror.com:443 "HEAD /laion/CLIP-ViT-H-14-laion2B-s32B-b79K/resolve/main/open_clip_pytorch_model.bin HTTP/1.1" 302 0
|
||||
INFO:root:Loading pretrained ViT-H-14 weights (laion2b_s32b_b79k).
|
||||
/mnt/ASC1637/unifolm-world-model-action/scripts/evaluation/world_model_interaction.py:86: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.
|
||||
state_dict = torch.load(ckpt, map_location="cpu")
|
||||
>>> model checkpoint loaded.
|
||||
>>> Load pre-trained model ...
|
||||
INFO:root:***** Configing Data *****
|
||||
>>> unitree_z1_stackbox: 1 data samples loaded.
|
||||
>>> unitree_z1_stackbox: data stats loaded.
|
||||
>>> unitree_z1_stackbox: normalizer initiated.
|
||||
>>> unitree_z1_dual_arm_stackbox: 1 data samples loaded.
|
||||
>>> unitree_z1_dual_arm_stackbox: data stats loaded.
|
||||
>>> unitree_z1_dual_arm_stackbox: normalizer initiated.
|
||||
>>> unitree_z1_dual_arm_stackbox_v2: 1 data samples loaded.
|
||||
>>> unitree_z1_dual_arm_stackbox_v2: data stats loaded.
|
||||
>>> unitree_z1_dual_arm_stackbox_v2: normalizer initiated.
|
||||
>>> unitree_z1_dual_arm_cleanup_pencils: 1 data samples loaded.
|
||||
>>> unitree_z1_dual_arm_cleanup_pencils: data stats loaded.
|
||||
>>> unitree_z1_dual_arm_cleanup_pencils: normalizer initiated.
|
||||
>>> unitree_g1_pack_camera: 1 data samples loaded.
|
||||
>>> unitree_g1_pack_camera: data stats loaded.
|
||||
>>> unitree_g1_pack_camera: normalizer initiated.
|
||||
>>> Dataset is successfully loaded ...
|
||||
>>> Generate 16 frames under each generation ...
|
||||
DEBUG:h5py._conv:Creating converter from 3 to 5
|
||||
DEBUG:PIL.PngImagePlugin:STREAM b'IHDR' 16 13
|
||||
DEBUG:PIL.PngImagePlugin:STREAM b'pHYs' 41 9
|
||||
DEBUG:PIL.PngImagePlugin:STREAM b'IDAT' 62 4096
|
||||
|
||||
0%| | 0/12 [00:00<?, ?it/s]/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/torch/nn/functional.py:5501: UserWarning: Attempting to use hipBLASLt on an unsupported architecture! Overriding blas backend to hipblas (Triggered internally at ../aten/src/ATen/Context.cpp:296.)
|
||||
proj = linear(q, w, b)
|
||||
/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/torch/nn/functional.py:6278: UserWarning: Flash attention support on Navi31 GPU is still experimental. Enable it with TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL=1. (Triggered internally at ../aten/src/ATen/native/transformers/hip/sdp_utils.cpp:225.)
|
||||
attn_output = scaled_dot_product_attention(
|
||||
/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/torch/nn/functional.py:6278: UserWarning: Memory Efficient attention on Navi31 GPU is still experimental. Enable it with TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL=1. (Triggered internally at ../aten/src/ATen/native/transformers/hip/sdp_utils.cpp:269.)
|
||||
attn_output = scaled_dot_product_attention(
|
||||
>>> Step 0: generating actions ...
|
||||
>>> Step 0: interacting with world model ...
|
||||
DEBUG:PIL.Image:Importing BlpImagePlugin
|
||||
DEBUG:PIL.Image:Importing BmpImagePlugin
|
||||
DEBUG:PIL.Image:Importing BufrStubImagePlugin
|
||||
DEBUG:PIL.Image:Importing CurImagePlugin
|
||||
DEBUG:PIL.Image:Importing DcxImagePlugin
|
||||
DEBUG:PIL.Image:Importing DdsImagePlugin
|
||||
DEBUG:PIL.Image:Importing EpsImagePlugin
|
||||
DEBUG:PIL.Image:Importing FitsImagePlugin
|
||||
DEBUG:PIL.Image:Importing FitsStubImagePlugin
|
||||
DEBUG:PIL.Image:Importing FliImagePlugin
|
||||
DEBUG:PIL.Image:Importing FpxImagePlugin
|
||||
DEBUG:PIL.Image:Image: failed to import FpxImagePlugin: No module named 'olefile'
|
||||
DEBUG:PIL.Image:Importing FtexImagePlugin
|
||||
DEBUG:PIL.Image:Importing GbrImagePlugin
|
||||
DEBUG:PIL.Image:Importing GifImagePlugin
|
||||
DEBUG:PIL.Image:Importing GribStubImagePlugin
|
||||
DEBUG:PIL.Image:Importing Hdf5StubImagePlugin
|
||||
DEBUG:PIL.Image:Importing IcnsImagePlugin
|
||||
DEBUG:PIL.Image:Importing IcoImagePlugin
|
||||
DEBUG:PIL.Image:Importing ImImagePlugin
|
||||
DEBUG:PIL.Image:Importing ImtImagePlugin
|
||||
DEBUG:PIL.Image:Importing IptcImagePlugin
|
||||
DEBUG:PIL.Image:Importing JpegImagePlugin
|
||||
DEBUG:PIL.Image:Importing Jpeg2KImagePlugin
|
||||
DEBUG:PIL.Image:Importing McIdasImagePlugin
|
||||
DEBUG:PIL.Image:Importing MicImagePlugin
|
||||
DEBUG:PIL.Image:Image: failed to import MicImagePlugin: No module named 'olefile'
|
||||
DEBUG:PIL.Image:Importing MpegImagePlugin
|
||||
DEBUG:PIL.Image:Importing MpoImagePlugin
|
||||
DEBUG:PIL.Image:Importing MspImagePlugin
|
||||
DEBUG:PIL.Image:Importing PalmImagePlugin
|
||||
DEBUG:PIL.Image:Importing PcdImagePlugin
|
||||
DEBUG:PIL.Image:Importing PcxImagePlugin
|
||||
DEBUG:PIL.Image:Importing PdfImagePlugin
|
||||
DEBUG:PIL.Image:Importing PixarImagePlugin
|
||||
DEBUG:PIL.Image:Importing PngImagePlugin
|
||||
DEBUG:PIL.Image:Importing PpmImagePlugin
|
||||
DEBUG:PIL.Image:Importing PsdImagePlugin
|
||||
DEBUG:PIL.Image:Importing QoiImagePlugin
|
||||
DEBUG:PIL.Image:Importing SgiImagePlugin
|
||||
DEBUG:PIL.Image:Importing SpiderImagePlugin
|
||||
DEBUG:PIL.Image:Importing SunImagePlugin
|
||||
DEBUG:PIL.Image:Importing TgaImagePlugin
|
||||
DEBUG:PIL.Image:Importing TiffImagePlugin
|
||||
DEBUG:PIL.Image:Importing WebPImagePlugin
|
||||
DEBUG:PIL.Image:Importing WmfImagePlugin
|
||||
DEBUG:PIL.Image:Importing XbmImagePlugin
|
||||
DEBUG:PIL.Image:Importing XpmImagePlugin
|
||||
DEBUG:PIL.Image:Importing XVThumbImagePlugin
|
||||
|
||||
8%|▊ | 1/12 [01:38<18:08, 98.94s/it]
|
||||
17%|█▋ | 2/12 [03:18<16:30, 99.01s/it]
|
||||
25%|██▌ | 3/12 [04:57<14:51, 99.07s/it]
|
||||
33%|███▎ | 4/12 [06:36<13:12, 99.04s/it]
|
||||
42%|████▏ | 5/12 [08:15<11:33, 99.00s/it]
|
||||
50%|█████ | 6/12 [09:54<09:54, 99.10s/it]
|
||||
58%|█████▊ | 7/12 [11:33<08:14, 99.00s/it]
|
||||
67%|██████▋ | 8/12 [13:13<06:38, 99.58s/it]
|
||||
75%|███████▌ | 9/12 [14:54<04:59, 99.88s/it]
|
||||
83%|████████▎ | 10/12 [16:33<03:19, 99.58s/it]
|
||||
92%|█████████▏| 11/12 [18:12<01:39, 99.39s/it]
|
||||
100%|██████████| 12/12 [19:51<00:00, 99.25s/it]
|
||||
100%|██████████| 12/12 [19:51<00:00, 99.28s/it]
|
||||
>>>>>>>>>>>>>>>>>>>>>>>>
|
||||
>>> Step 1: generating actions ...
|
||||
>>> Step 1: interacting with world model ...
|
||||
>>>>>>>>>>>>>>>>>>>>>>>>
|
||||
>>> Step 2: generating actions ...
|
||||
>>> Step 2: interacting with world model ...
|
||||
>>>>>>>>>>>>>>>>>>>>>>>>
|
||||
>>> Step 3: generating actions ...
|
||||
>>> Step 3: interacting with world model ...
|
||||
>>>>>>>>>>>>>>>>>>>>>>>>
|
||||
>>> Step 4: generating actions ...
|
||||
>>> Step 4: interacting with world model ...
|
||||
>>>>>>>>>>>>>>>>>>>>>>>>
|
||||
>>> Step 5: generating actions ...
|
||||
>>> Step 5: interacting with world model ...
|
||||
>>>>>>>>>>>>>>>>>>>>>>>>
|
||||
>>> Step 6: generating actions ...
|
||||
>>> Step 6: interacting with world model ...
|
||||
>>>>>>>>>>>>>>>>>>>>>>>>
|
||||
>>> Step 7: generating actions ...
|
||||
>>> Step 7: interacting with world model ...
|
||||
>>>>>>>>>>>>>>>>>>>>>>>>
|
||||
>>> Step 8: generating actions ...
|
||||
>>> Step 8: interacting with world model ...
|
||||
5
unitree_z1_stackbox/case1/psnr_result.json
Normal file
5
unitree_z1_stackbox/case1/psnr_result.json
Normal file
@@ -0,0 +1,5 @@
|
||||
{
|
||||
"gt_video": "unitree_z1_stackbox/case1/unitree_z1_stackbox_case1.mp4",
|
||||
"pred_video": "unitree_z1_stackbox/case1/output/inference/unitree_z1_stackbox_case1_amd.mp4",
|
||||
"psnr": 19.81391789862606
|
||||
}
|
||||
@@ -2,7 +2,7 @@ res_dir="unitree_z1_stackbox/case1"
|
||||
dataset="unitree_z1_stackbox"
|
||||
|
||||
{
|
||||
time CUDA_VISIBLE_DEVICES=0 python3 scripts/evaluation/world_model_interaction.py \
|
||||
time CUDA_VISIBLE_DEVICES=5 python3 scripts/evaluation/world_model_interaction.py \
|
||||
--seed 123 \
|
||||
--ckpt_path ckpts/unifolm_wma_dual.ckpt \
|
||||
--config configs/inference/world_model_interaction.yaml \
|
||||
|
||||
149
unitree_z1_stackbox/case2/output.log
Normal file
149
unitree_z1_stackbox/case2/output.log
Normal file
@@ -0,0 +1,149 @@
|
||||
/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/lightning_fabric/__init__.py:29: UserWarning: pkg_resources is deprecated as an API. See https://setuptools.pypa.io/en/latest/pkg_resources.html. The pkg_resources package is slated for removal as early as 2025-11-30. Refrain from using this package or pin to Setuptools<81.
|
||||
__import__("pkg_resources").declare_namespace(__name__)
|
||||
2026-02-08 08:15:49.934949: I tensorflow/core/util/port.cc:113] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
|
||||
2026-02-08 08:15:49.937974: I external/local_tsl/tsl/cuda/cudart_stub.cc:31] Could not find cuda drivers on your machine, GPU will not be used.
|
||||
2026-02-08 08:15:49.969069: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
|
||||
2026-02-08 08:15:49.969100: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
|
||||
2026-02-08 08:15:49.970909: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
|
||||
2026-02-08 08:15:49.979005: I external/local_tsl/tsl/cuda/cudart_stub.cc:31] Could not find cuda drivers on your machine, GPU will not be used.
|
||||
2026-02-08 08:15:49.979255: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
|
||||
To enable the following instructions: AVX2 AVX512F AVX512_VNNI AVX512_BF16 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
|
||||
2026-02-08 08:15:50.597743: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT
|
||||
[rank: 0] Global seed set to 123
|
||||
/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/kornia/feature/lightglue.py:44: FutureWarning: `torch.cuda.amp.custom_fwd(args...)` is deprecated. Please use `torch.amp.custom_fwd(args..., device_type='cuda')` instead.
|
||||
@torch.cuda.amp.custom_fwd(cast_inputs=torch.float32)
|
||||
INFO:mainlogger:LatentVisualDiffusion: Running in v-prediction mode
|
||||
INFO:unifolm_wma.models.diffusion_head.conditional_unet1d:number of parameters: 5.010531e+08
|
||||
INFO:unifolm_wma.models.diffusion_head.conditional_unet1d:number of parameters: 5.010531e+08
|
||||
AE working on z of shape (1, 4, 32, 32) = 4096 dimensions.
|
||||
INFO:root:Loaded ViT-H-14 model config.
|
||||
DEBUG:urllib3.connectionpool:Starting new HTTPS connection (1): hf-mirror.com:443
|
||||
DEBUG:urllib3.connectionpool:https://hf-mirror.com:443 "HEAD /laion/CLIP-ViT-H-14-laion2B-s32B-b79K/resolve/main/open_clip_pytorch_model.bin HTTP/1.1" 302 0
|
||||
INFO:root:Loading pretrained ViT-H-14 weights (laion2b_s32b_b79k).
|
||||
/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/open_clip/factory.py:88: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.
|
||||
checkpoint = torch.load(checkpoint_path, map_location=map_location)
|
||||
INFO:root:Loaded ViT-H-14 model config.
|
||||
DEBUG:urllib3.connectionpool:https://hf-mirror.com:443 "HEAD /laion/CLIP-ViT-H-14-laion2B-s32B-b79K/resolve/main/open_clip_pytorch_model.bin HTTP/1.1" 302 0
|
||||
INFO:root:Loading pretrained ViT-H-14 weights (laion2b_s32b_b79k).
|
||||
/mnt/ASC1637/unifolm-world-model-action/scripts/evaluation/world_model_interaction.py:86: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.
|
||||
state_dict = torch.load(ckpt, map_location="cpu")
|
||||
>>> model checkpoint loaded.
|
||||
>>> Load pre-trained model ...
|
||||
INFO:root:***** Configing Data *****
|
||||
>>> unitree_z1_stackbox: 1 data samples loaded.
|
||||
>>> unitree_z1_stackbox: data stats loaded.
|
||||
>>> unitree_z1_stackbox: normalizer initiated.
|
||||
>>> unitree_z1_dual_arm_stackbox: 1 data samples loaded.
|
||||
>>> unitree_z1_dual_arm_stackbox: data stats loaded.
|
||||
>>> unitree_z1_dual_arm_stackbox: normalizer initiated.
|
||||
>>> unitree_z1_dual_arm_stackbox_v2: 1 data samples loaded.
|
||||
>>> unitree_z1_dual_arm_stackbox_v2: data stats loaded.
|
||||
>>> unitree_z1_dual_arm_stackbox_v2: normalizer initiated.
|
||||
>>> unitree_z1_dual_arm_cleanup_pencils: 1 data samples loaded.
|
||||
>>> unitree_z1_dual_arm_cleanup_pencils: data stats loaded.
|
||||
>>> unitree_z1_dual_arm_cleanup_pencils: normalizer initiated.
|
||||
>>> unitree_g1_pack_camera: 1 data samples loaded.
|
||||
>>> unitree_g1_pack_camera: data stats loaded.
|
||||
>>> unitree_g1_pack_camera: normalizer initiated.
|
||||
>>> Dataset is successfully loaded ...
|
||||
>>> Generate 16 frames under each generation ...
|
||||
DEBUG:h5py._conv:Creating converter from 3 to 5
|
||||
DEBUG:PIL.PngImagePlugin:STREAM b'IHDR' 16 13
|
||||
DEBUG:PIL.PngImagePlugin:STREAM b'pHYs' 41 9
|
||||
DEBUG:PIL.PngImagePlugin:STREAM b'IDAT' 62 4096
|
||||
|
||||
0%| | 0/12 [00:00<?, ?it/s]/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/torch/nn/functional.py:5501: UserWarning: Attempting to use hipBLASLt on an unsupported architecture! Overriding blas backend to hipblas (Triggered internally at ../aten/src/ATen/Context.cpp:296.)
|
||||
proj = linear(q, w, b)
|
||||
/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/torch/nn/functional.py:6278: UserWarning: Flash attention support on Navi31 GPU is still experimental. Enable it with TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL=1. (Triggered internally at ../aten/src/ATen/native/transformers/hip/sdp_utils.cpp:225.)
|
||||
attn_output = scaled_dot_product_attention(
|
||||
/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/torch/nn/functional.py:6278: UserWarning: Memory Efficient attention on Navi31 GPU is still experimental. Enable it with TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL=1. (Triggered internally at ../aten/src/ATen/native/transformers/hip/sdp_utils.cpp:269.)
|
||||
attn_output = scaled_dot_product_attention(
|
||||
>>> Step 0: generating actions ...
|
||||
>>> Step 0: interacting with world model ...
|
||||
DEBUG:PIL.Image:Importing BlpImagePlugin
|
||||
DEBUG:PIL.Image:Importing BmpImagePlugin
|
||||
DEBUG:PIL.Image:Importing BufrStubImagePlugin
|
||||
DEBUG:PIL.Image:Importing CurImagePlugin
|
||||
DEBUG:PIL.Image:Importing DcxImagePlugin
|
||||
DEBUG:PIL.Image:Importing DdsImagePlugin
|
||||
DEBUG:PIL.Image:Importing EpsImagePlugin
|
||||
DEBUG:PIL.Image:Importing FitsImagePlugin
|
||||
DEBUG:PIL.Image:Importing FitsStubImagePlugin
|
||||
DEBUG:PIL.Image:Importing FliImagePlugin
|
||||
DEBUG:PIL.Image:Importing FpxImagePlugin
|
||||
DEBUG:PIL.Image:Image: failed to import FpxImagePlugin: No module named 'olefile'
|
||||
DEBUG:PIL.Image:Importing FtexImagePlugin
|
||||
DEBUG:PIL.Image:Importing GbrImagePlugin
|
||||
DEBUG:PIL.Image:Importing GifImagePlugin
|
||||
DEBUG:PIL.Image:Importing GribStubImagePlugin
|
||||
DEBUG:PIL.Image:Importing Hdf5StubImagePlugin
|
||||
DEBUG:PIL.Image:Importing IcnsImagePlugin
|
||||
DEBUG:PIL.Image:Importing IcoImagePlugin
|
||||
DEBUG:PIL.Image:Importing ImImagePlugin
|
||||
DEBUG:PIL.Image:Importing ImtImagePlugin
|
||||
DEBUG:PIL.Image:Importing IptcImagePlugin
|
||||
DEBUG:PIL.Image:Importing JpegImagePlugin
|
||||
DEBUG:PIL.Image:Importing Jpeg2KImagePlugin
|
||||
DEBUG:PIL.Image:Importing McIdasImagePlugin
|
||||
DEBUG:PIL.Image:Importing MicImagePlugin
|
||||
DEBUG:PIL.Image:Image: failed to import MicImagePlugin: No module named 'olefile'
|
||||
DEBUG:PIL.Image:Importing MpegImagePlugin
|
||||
DEBUG:PIL.Image:Importing MpoImagePlugin
|
||||
DEBUG:PIL.Image:Importing MspImagePlugin
|
||||
DEBUG:PIL.Image:Importing PalmImagePlugin
|
||||
DEBUG:PIL.Image:Importing PcdImagePlugin
|
||||
DEBUG:PIL.Image:Importing PcxImagePlugin
|
||||
DEBUG:PIL.Image:Importing PdfImagePlugin
|
||||
DEBUG:PIL.Image:Importing PixarImagePlugin
|
||||
DEBUG:PIL.Image:Importing PngImagePlugin
|
||||
DEBUG:PIL.Image:Importing PpmImagePlugin
|
||||
DEBUG:PIL.Image:Importing PsdImagePlugin
|
||||
DEBUG:PIL.Image:Importing QoiImagePlugin
|
||||
DEBUG:PIL.Image:Importing SgiImagePlugin
|
||||
DEBUG:PIL.Image:Importing SpiderImagePlugin
|
||||
DEBUG:PIL.Image:Importing SunImagePlugin
|
||||
DEBUG:PIL.Image:Importing TgaImagePlugin
|
||||
DEBUG:PIL.Image:Importing TiffImagePlugin
|
||||
DEBUG:PIL.Image:Importing WebPImagePlugin
|
||||
DEBUG:PIL.Image:Importing WmfImagePlugin
|
||||
DEBUG:PIL.Image:Importing XbmImagePlugin
|
||||
DEBUG:PIL.Image:Importing XpmImagePlugin
|
||||
DEBUG:PIL.Image:Importing XVThumbImagePlugin
|
||||
|
||||
8%|▊ | 1/12 [01:37<17:51, 97.37s/it]
|
||||
17%|█▋ | 2/12 [03:14<16:13, 97.31s/it]
|
||||
25%|██▌ | 3/12 [04:51<14:35, 97.26s/it]
|
||||
33%|███▎ | 4/12 [06:29<12:58, 97.25s/it]
|
||||
42%|████▏ | 5/12 [08:06<11:20, 97.24s/it]
|
||||
50%|█████ | 6/12 [09:43<09:43, 97.24s/it]
|
||||
58%|█████▊ | 7/12 [11:20<08:06, 97.27s/it]
|
||||
67%|██████▋ | 8/12 [12:58<06:29, 97.36s/it]
|
||||
75%|███████▌ | 9/12 [14:36<04:52, 97.49s/it]
|
||||
83%|████████▎ | 10/12 [16:13<03:15, 97.52s/it]
|
||||
92%|█████████▏| 11/12 [17:51<01:37, 97.47s/it]
|
||||
100%|██████████| 12/12 [19:28<00:00, 97.35s/it]
|
||||
100%|██████████| 12/12 [19:28<00:00, 97.35s/it]
|
||||
>>>>>>>>>>>>>>>>>>>>>>>>
|
||||
>>> Step 1: generating actions ...
|
||||
>>> Step 1: interacting with world model ...
|
||||
>>>>>>>>>>>>>>>>>>>>>>>>
|
||||
>>> Step 2: generating actions ...
|
||||
>>> Step 2: interacting with world model ...
|
||||
>>>>>>>>>>>>>>>>>>>>>>>>
|
||||
>>> Step 3: generating actions ...
|
||||
>>> Step 3: interacting with world model ...
|
||||
>>>>>>>>>>>>>>>>>>>>>>>>
|
||||
>>> Step 4: generating actions ...
|
||||
>>> Step 4: interacting with world model ...
|
||||
>>>>>>>>>>>>>>>>>>>>>>>>
|
||||
>>> Step 5: generating actions ...
|
||||
>>> Step 5: interacting with world model ...
|
||||
>>>>>>>>>>>>>>>>>>>>>>>>
|
||||
>>> Step 6: generating actions ...
|
||||
>>> Step 6: interacting with world model ...
|
||||
>>>>>>>>>>>>>>>>>>>>>>>>
|
||||
>>> Step 7: generating actions ...
|
||||
>>> Step 7: interacting with world model ...
|
||||
>>>>>>>>>>>>>>>>>>>>>>>>
|
||||
>>> Step 8: generating actions ...
|
||||
>>> Step 8: interacting with world model ...
|
||||
5
unitree_z1_stackbox/case2/psnr_result.json
Normal file
5
unitree_z1_stackbox/case2/psnr_result.json
Normal file
@@ -0,0 +1,5 @@
|
||||
{
|
||||
"gt_video": "unitree_z1_stackbox/case2/unitree_z1_stackbox_case2.mp4",
|
||||
"pred_video": "unitree_z1_stackbox/case2/output/inference/unitree_z1_stackbox_case2_amd.mp4",
|
||||
"psnr": 21.083821459054743
|
||||
}
|
||||
149
unitree_z1_stackbox/case3/output.log
Normal file
149
unitree_z1_stackbox/case3/output.log
Normal file
@@ -0,0 +1,149 @@
|
||||
/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/lightning_fabric/__init__.py:29: UserWarning: pkg_resources is deprecated as an API. See https://setuptools.pypa.io/en/latest/pkg_resources.html. The pkg_resources package is slated for removal as early as 2025-11-30. Refrain from using this package or pin to Setuptools<81.
|
||||
__import__("pkg_resources").declare_namespace(__name__)
|
||||
2026-02-08 08:16:22.299521: I tensorflow/core/util/port.cc:113] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
|
||||
2026-02-08 08:16:22.302545: I external/local_tsl/tsl/cuda/cudart_stub.cc:31] Could not find cuda drivers on your machine, GPU will not be used.
|
||||
2026-02-08 08:16:22.335354: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
|
||||
2026-02-08 08:16:22.335389: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
|
||||
2026-02-08 08:16:22.337179: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
|
||||
2026-02-08 08:16:22.345296: I external/local_tsl/tsl/cuda/cudart_stub.cc:31] Could not find cuda drivers on your machine, GPU will not be used.
|
||||
2026-02-08 08:16:22.345548: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
|
||||
To enable the following instructions: AVX2 AVX512F AVX512_VNNI AVX512_BF16 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
|
||||
2026-02-08 08:16:23.008743: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT
|
||||
[rank: 0] Global seed set to 123
|
||||
/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/kornia/feature/lightglue.py:44: FutureWarning: `torch.cuda.amp.custom_fwd(args...)` is deprecated. Please use `torch.amp.custom_fwd(args..., device_type='cuda')` instead.
|
||||
@torch.cuda.amp.custom_fwd(cast_inputs=torch.float32)
|
||||
INFO:mainlogger:LatentVisualDiffusion: Running in v-prediction mode
|
||||
INFO:unifolm_wma.models.diffusion_head.conditional_unet1d:number of parameters: 5.010531e+08
|
||||
INFO:unifolm_wma.models.diffusion_head.conditional_unet1d:number of parameters: 5.010531e+08
|
||||
AE working on z of shape (1, 4, 32, 32) = 4096 dimensions.
|
||||
INFO:root:Loaded ViT-H-14 model config.
|
||||
DEBUG:urllib3.connectionpool:Starting new HTTPS connection (1): hf-mirror.com:443
|
||||
DEBUG:urllib3.connectionpool:https://hf-mirror.com:443 "HEAD /laion/CLIP-ViT-H-14-laion2B-s32B-b79K/resolve/main/open_clip_pytorch_model.bin HTTP/1.1" 302 0
|
||||
INFO:root:Loading pretrained ViT-H-14 weights (laion2b_s32b_b79k).
|
||||
/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/open_clip/factory.py:88: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.
|
||||
checkpoint = torch.load(checkpoint_path, map_location=map_location)
|
||||
INFO:root:Loaded ViT-H-14 model config.
|
||||
DEBUG:urllib3.connectionpool:https://hf-mirror.com:443 "HEAD /laion/CLIP-ViT-H-14-laion2B-s32B-b79K/resolve/main/open_clip_pytorch_model.bin HTTP/1.1" 302 0
|
||||
INFO:root:Loading pretrained ViT-H-14 weights (laion2b_s32b_b79k).
|
||||
/mnt/ASC1637/unifolm-world-model-action/scripts/evaluation/world_model_interaction.py:86: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.
|
||||
state_dict = torch.load(ckpt, map_location="cpu")
|
||||
>>> model checkpoint loaded.
|
||||
>>> Load pre-trained model ...
|
||||
INFO:root:***** Configing Data *****
|
||||
>>> unitree_z1_stackbox: 1 data samples loaded.
|
||||
>>> unitree_z1_stackbox: data stats loaded.
|
||||
>>> unitree_z1_stackbox: normalizer initiated.
|
||||
>>> unitree_z1_dual_arm_stackbox: 1 data samples loaded.
|
||||
>>> unitree_z1_dual_arm_stackbox: data stats loaded.
|
||||
>>> unitree_z1_dual_arm_stackbox: normalizer initiated.
|
||||
>>> unitree_z1_dual_arm_stackbox_v2: 1 data samples loaded.
|
||||
>>> unitree_z1_dual_arm_stackbox_v2: data stats loaded.
|
||||
>>> unitree_z1_dual_arm_stackbox_v2: normalizer initiated.
|
||||
>>> unitree_z1_dual_arm_cleanup_pencils: 1 data samples loaded.
|
||||
>>> unitree_z1_dual_arm_cleanup_pencils: data stats loaded.
|
||||
>>> unitree_z1_dual_arm_cleanup_pencils: normalizer initiated.
|
||||
>>> unitree_g1_pack_camera: 1 data samples loaded.
|
||||
>>> unitree_g1_pack_camera: data stats loaded.
|
||||
>>> unitree_g1_pack_camera: normalizer initiated.
|
||||
>>> Dataset is successfully loaded ...
|
||||
>>> Generate 16 frames under each generation ...
|
||||
DEBUG:h5py._conv:Creating converter from 3 to 5
|
||||
DEBUG:PIL.PngImagePlugin:STREAM b'IHDR' 16 13
|
||||
DEBUG:PIL.PngImagePlugin:STREAM b'pHYs' 41 9
|
||||
DEBUG:PIL.PngImagePlugin:STREAM b'IDAT' 62 4096
|
||||
|
||||
0%| | 0/12 [00:00<?, ?it/s]/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/torch/nn/functional.py:5501: UserWarning: Attempting to use hipBLASLt on an unsupported architecture! Overriding blas backend to hipblas (Triggered internally at ../aten/src/ATen/Context.cpp:296.)
|
||||
proj = linear(q, w, b)
|
||||
/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/torch/nn/functional.py:6278: UserWarning: Flash attention support on Navi31 GPU is still experimental. Enable it with TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL=1. (Triggered internally at ../aten/src/ATen/native/transformers/hip/sdp_utils.cpp:225.)
|
||||
attn_output = scaled_dot_product_attention(
|
||||
/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/torch/nn/functional.py:6278: UserWarning: Memory Efficient attention on Navi31 GPU is still experimental. Enable it with TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL=1. (Triggered internally at ../aten/src/ATen/native/transformers/hip/sdp_utils.cpp:269.)
|
||||
attn_output = scaled_dot_product_attention(
|
||||
>>> Step 0: generating actions ...
|
||||
>>> Step 0: interacting with world model ...
|
||||
DEBUG:PIL.Image:Importing BlpImagePlugin
|
||||
DEBUG:PIL.Image:Importing BmpImagePlugin
|
||||
DEBUG:PIL.Image:Importing BufrStubImagePlugin
|
||||
DEBUG:PIL.Image:Importing CurImagePlugin
|
||||
DEBUG:PIL.Image:Importing DcxImagePlugin
|
||||
DEBUG:PIL.Image:Importing DdsImagePlugin
|
||||
DEBUG:PIL.Image:Importing EpsImagePlugin
|
||||
DEBUG:PIL.Image:Importing FitsImagePlugin
|
||||
DEBUG:PIL.Image:Importing FitsStubImagePlugin
|
||||
DEBUG:PIL.Image:Importing FliImagePlugin
|
||||
DEBUG:PIL.Image:Importing FpxImagePlugin
|
||||
DEBUG:PIL.Image:Image: failed to import FpxImagePlugin: No module named 'olefile'
|
||||
DEBUG:PIL.Image:Importing FtexImagePlugin
|
||||
DEBUG:PIL.Image:Importing GbrImagePlugin
|
||||
DEBUG:PIL.Image:Importing GifImagePlugin
|
||||
DEBUG:PIL.Image:Importing GribStubImagePlugin
|
||||
DEBUG:PIL.Image:Importing Hdf5StubImagePlugin
|
||||
DEBUG:PIL.Image:Importing IcnsImagePlugin
|
||||
DEBUG:PIL.Image:Importing IcoImagePlugin
|
||||
DEBUG:PIL.Image:Importing ImImagePlugin
|
||||
DEBUG:PIL.Image:Importing ImtImagePlugin
|
||||
DEBUG:PIL.Image:Importing IptcImagePlugin
|
||||
DEBUG:PIL.Image:Importing JpegImagePlugin
|
||||
DEBUG:PIL.Image:Importing Jpeg2KImagePlugin
|
||||
DEBUG:PIL.Image:Importing McIdasImagePlugin
|
||||
DEBUG:PIL.Image:Importing MicImagePlugin
|
||||
DEBUG:PIL.Image:Image: failed to import MicImagePlugin: No module named 'olefile'
|
||||
DEBUG:PIL.Image:Importing MpegImagePlugin
|
||||
DEBUG:PIL.Image:Importing MpoImagePlugin
|
||||
DEBUG:PIL.Image:Importing MspImagePlugin
|
||||
DEBUG:PIL.Image:Importing PalmImagePlugin
|
||||
DEBUG:PIL.Image:Importing PcdImagePlugin
|
||||
DEBUG:PIL.Image:Importing PcxImagePlugin
|
||||
DEBUG:PIL.Image:Importing PdfImagePlugin
|
||||
DEBUG:PIL.Image:Importing PixarImagePlugin
|
||||
DEBUG:PIL.Image:Importing PngImagePlugin
|
||||
DEBUG:PIL.Image:Importing PpmImagePlugin
|
||||
DEBUG:PIL.Image:Importing PsdImagePlugin
|
||||
DEBUG:PIL.Image:Importing QoiImagePlugin
|
||||
DEBUG:PIL.Image:Importing SgiImagePlugin
|
||||
DEBUG:PIL.Image:Importing SpiderImagePlugin
|
||||
DEBUG:PIL.Image:Importing SunImagePlugin
|
||||
DEBUG:PIL.Image:Importing TgaImagePlugin
|
||||
DEBUG:PIL.Image:Importing TiffImagePlugin
|
||||
DEBUG:PIL.Image:Importing WebPImagePlugin
|
||||
DEBUG:PIL.Image:Importing WmfImagePlugin
|
||||
DEBUG:PIL.Image:Importing XbmImagePlugin
|
||||
DEBUG:PIL.Image:Importing XpmImagePlugin
|
||||
DEBUG:PIL.Image:Importing XVThumbImagePlugin
|
||||
|
||||
8%|▊ | 1/12 [01:39<18:16, 99.64s/it]
|
||||
17%|█▋ | 2/12 [03:19<16:35, 99.56s/it]
|
||||
25%|██▌ | 3/12 [04:58<14:55, 99.53s/it]
|
||||
33%|███▎ | 4/12 [06:38<13:16, 99.53s/it]
|
||||
42%|████▏ | 5/12 [08:17<11:36, 99.54s/it]
|
||||
50%|█████ | 6/12 [09:57<09:57, 99.57s/it]
|
||||
58%|█████▊ | 7/12 [11:37<08:18, 99.66s/it]
|
||||
67%|██████▋ | 8/12 [13:17<06:39, 99.83s/it]
|
||||
75%|███████▌ | 9/12 [14:57<04:59, 99.93s/it]
|
||||
83%|████████▎ | 10/12 [16:37<03:19, 99.97s/it]
|
||||
92%|█████████▏| 11/12 [18:17<01:39, 99.85s/it]
|
||||
100%|██████████| 12/12 [19:56<00:00, 99.71s/it]
|
||||
100%|██████████| 12/12 [19:56<00:00, 99.71s/it]
|
||||
>>>>>>>>>>>>>>>>>>>>>>>>
|
||||
>>> Step 1: generating actions ...
|
||||
>>> Step 1: interacting with world model ...
|
||||
>>>>>>>>>>>>>>>>>>>>>>>>
|
||||
>>> Step 2: generating actions ...
|
||||
>>> Step 2: interacting with world model ...
|
||||
>>>>>>>>>>>>>>>>>>>>>>>>
|
||||
>>> Step 3: generating actions ...
|
||||
>>> Step 3: interacting with world model ...
|
||||
>>>>>>>>>>>>>>>>>>>>>>>>
|
||||
>>> Step 4: generating actions ...
|
||||
>>> Step 4: interacting with world model ...
|
||||
>>>>>>>>>>>>>>>>>>>>>>>>
|
||||
>>> Step 5: generating actions ...
|
||||
>>> Step 5: interacting with world model ...
|
||||
>>>>>>>>>>>>>>>>>>>>>>>>
|
||||
>>> Step 6: generating actions ...
|
||||
>>> Step 6: interacting with world model ...
|
||||
>>>>>>>>>>>>>>>>>>>>>>>>
|
||||
>>> Step 7: generating actions ...
|
||||
>>> Step 7: interacting with world model ...
|
||||
>>>>>>>>>>>>>>>>>>>>>>>>
|
||||
>>> Step 8: generating actions ...
|
||||
>>> Step 8: interacting with world model ...
|
||||
5
unitree_z1_stackbox/case3/psnr_result.json
Normal file
5
unitree_z1_stackbox/case3/psnr_result.json
Normal file
@@ -0,0 +1,5 @@
|
||||
{
|
||||
"gt_video": "unitree_z1_stackbox/case3/unitree_z1_stackbox_case3.mp4",
|
||||
"pred_video": "unitree_z1_stackbox/case3/output/inference/unitree_z1_stackbox_case3_amd.mp4",
|
||||
"psnr": 21.322784880212172
|
||||
}
|
||||
149
unitree_z1_stackbox/case4/output.log
Normal file
149
unitree_z1_stackbox/case4/output.log
Normal file
@@ -0,0 +1,149 @@
|
||||
/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/lightning_fabric/__init__.py:29: UserWarning: pkg_resources is deprecated as an API. See https://setuptools.pypa.io/en/latest/pkg_resources.html. The pkg_resources package is slated for removal as early as 2025-11-30. Refrain from using this package or pin to Setuptools<81.
|
||||
__import__("pkg_resources").declare_namespace(__name__)
|
||||
2026-02-08 08:25:54.657305: I tensorflow/core/util/port.cc:113] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
|
||||
2026-02-08 08:25:54.660628: I external/local_tsl/tsl/cuda/cudart_stub.cc:31] Could not find cuda drivers on your machine, GPU will not be used.
|
||||
2026-02-08 08:25:54.691237: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
|
||||
2026-02-08 08:25:54.691275: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
|
||||
2026-02-08 08:25:54.693046: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
|
||||
2026-02-08 08:25:54.701142: I external/local_tsl/tsl/cuda/cudart_stub.cc:31] Could not find cuda drivers on your machine, GPU will not be used.
|
||||
2026-02-08 08:25:54.701413: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
|
||||
To enable the following instructions: AVX2 AVX512F AVX512_VNNI AVX512_BF16 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
|
||||
2026-02-08 08:25:55.801367: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT
|
||||
Global seed set to 123
|
||||
/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/kornia/feature/lightglue.py:44: FutureWarning: `torch.cuda.amp.custom_fwd(args...)` is deprecated. Please use `torch.amp.custom_fwd(args..., device_type='cuda')` instead.
|
||||
@torch.cuda.amp.custom_fwd(cast_inputs=torch.float32)
|
||||
INFO:mainlogger:LatentVisualDiffusion: Running in v-prediction mode
|
||||
INFO:unifolm_wma.models.diffusion_head.conditional_unet1d:number of parameters: 5.010531e+08
|
||||
INFO:unifolm_wma.models.diffusion_head.conditional_unet1d:number of parameters: 5.010531e+08
|
||||
AE working on z of shape (1, 4, 32, 32) = 4096 dimensions.
|
||||
INFO:root:Loaded ViT-H-14 model config.
|
||||
DEBUG:urllib3.connectionpool:Starting new HTTPS connection (1): hf-mirror.com:443
|
||||
DEBUG:urllib3.connectionpool:https://hf-mirror.com:443 "HEAD /laion/CLIP-ViT-H-14-laion2B-s32B-b79K/resolve/main/open_clip_pytorch_model.bin HTTP/1.1" 302 0
|
||||
INFO:root:Loading pretrained ViT-H-14 weights (laion2b_s32b_b79k).
|
||||
/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/open_clip/factory.py:88: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.
|
||||
checkpoint = torch.load(checkpoint_path, map_location=map_location)
|
||||
INFO:root:Loaded ViT-H-14 model config.
|
||||
DEBUG:urllib3.connectionpool:https://hf-mirror.com:443 "HEAD /laion/CLIP-ViT-H-14-laion2B-s32B-b79K/resolve/main/open_clip_pytorch_model.bin HTTP/1.1" 302 0
|
||||
INFO:root:Loading pretrained ViT-H-14 weights (laion2b_s32b_b79k).
|
||||
/mnt/ASC1637/unifolm-world-model-action/scripts/evaluation/world_model_interaction.py:86: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.
|
||||
state_dict = torch.load(ckpt, map_location="cpu")
|
||||
>>> model checkpoint loaded.
|
||||
>>> Load pre-trained model ...
|
||||
INFO:root:***** Configing Data *****
|
||||
>>> unitree_z1_stackbox: 1 data samples loaded.
|
||||
>>> unitree_z1_stackbox: data stats loaded.
|
||||
>>> unitree_z1_stackbox: normalizer initiated.
|
||||
>>> unitree_z1_dual_arm_stackbox: 1 data samples loaded.
|
||||
>>> unitree_z1_dual_arm_stackbox: data stats loaded.
|
||||
>>> unitree_z1_dual_arm_stackbox: normalizer initiated.
|
||||
>>> unitree_z1_dual_arm_stackbox_v2: 1 data samples loaded.
|
||||
>>> unitree_z1_dual_arm_stackbox_v2: data stats loaded.
|
||||
>>> unitree_z1_dual_arm_stackbox_v2: normalizer initiated.
|
||||
>>> unitree_z1_dual_arm_cleanup_pencils: 1 data samples loaded.
|
||||
>>> unitree_z1_dual_arm_cleanup_pencils: data stats loaded.
|
||||
>>> unitree_z1_dual_arm_cleanup_pencils: normalizer initiated.
|
||||
>>> unitree_g1_pack_camera: 1 data samples loaded.
|
||||
>>> unitree_g1_pack_camera: data stats loaded.
|
||||
>>> unitree_g1_pack_camera: normalizer initiated.
|
||||
>>> Dataset is successfully loaded ...
|
||||
>>> Generate 16 frames under each generation ...
|
||||
DEBUG:h5py._conv:Creating converter from 3 to 5
|
||||
DEBUG:PIL.PngImagePlugin:STREAM b'IHDR' 16 13
|
||||
DEBUG:PIL.PngImagePlugin:STREAM b'pHYs' 41 9
|
||||
DEBUG:PIL.PngImagePlugin:STREAM b'IDAT' 62 4096
|
||||
|
||||
0%| | 0/12 [00:00<?, ?it/s]/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/torch/nn/functional.py:5501: UserWarning: Attempting to use hipBLASLt on an unsupported architecture! Overriding blas backend to hipblas (Triggered internally at ../aten/src/ATen/Context.cpp:296.)
|
||||
proj = linear(q, w, b)
|
||||
/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/torch/nn/functional.py:6278: UserWarning: Flash attention support on Navi31 GPU is still experimental. Enable it with TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL=1. (Triggered internally at ../aten/src/ATen/native/transformers/hip/sdp_utils.cpp:225.)
|
||||
attn_output = scaled_dot_product_attention(
|
||||
/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/torch/nn/functional.py:6278: UserWarning: Memory Efficient attention on Navi31 GPU is still experimental. Enable it with TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL=1. (Triggered internally at ../aten/src/ATen/native/transformers/hip/sdp_utils.cpp:269.)
|
||||
attn_output = scaled_dot_product_attention(
|
||||
>>> Step 0: generating actions ...
|
||||
>>> Step 0: interacting with world model ...
|
||||
DEBUG:PIL.Image:Importing BlpImagePlugin
|
||||
DEBUG:PIL.Image:Importing BmpImagePlugin
|
||||
DEBUG:PIL.Image:Importing BufrStubImagePlugin
|
||||
DEBUG:PIL.Image:Importing CurImagePlugin
|
||||
DEBUG:PIL.Image:Importing DcxImagePlugin
|
||||
DEBUG:PIL.Image:Importing DdsImagePlugin
|
||||
DEBUG:PIL.Image:Importing EpsImagePlugin
|
||||
DEBUG:PIL.Image:Importing FitsImagePlugin
|
||||
DEBUG:PIL.Image:Importing FitsStubImagePlugin
|
||||
DEBUG:PIL.Image:Importing FliImagePlugin
|
||||
DEBUG:PIL.Image:Importing FpxImagePlugin
|
||||
DEBUG:PIL.Image:Image: failed to import FpxImagePlugin: No module named 'olefile'
|
||||
DEBUG:PIL.Image:Importing FtexImagePlugin
|
||||
DEBUG:PIL.Image:Importing GbrImagePlugin
|
||||
DEBUG:PIL.Image:Importing GifImagePlugin
|
||||
DEBUG:PIL.Image:Importing GribStubImagePlugin
|
||||
DEBUG:PIL.Image:Importing Hdf5StubImagePlugin
|
||||
DEBUG:PIL.Image:Importing IcnsImagePlugin
|
||||
DEBUG:PIL.Image:Importing IcoImagePlugin
|
||||
DEBUG:PIL.Image:Importing ImImagePlugin
|
||||
DEBUG:PIL.Image:Importing ImtImagePlugin
|
||||
DEBUG:PIL.Image:Importing IptcImagePlugin
|
||||
DEBUG:PIL.Image:Importing JpegImagePlugin
|
||||
DEBUG:PIL.Image:Importing Jpeg2KImagePlugin
|
||||
DEBUG:PIL.Image:Importing McIdasImagePlugin
|
||||
DEBUG:PIL.Image:Importing MicImagePlugin
|
||||
DEBUG:PIL.Image:Image: failed to import MicImagePlugin: No module named 'olefile'
|
||||
DEBUG:PIL.Image:Importing MpegImagePlugin
|
||||
DEBUG:PIL.Image:Importing MpoImagePlugin
|
||||
DEBUG:PIL.Image:Importing MspImagePlugin
|
||||
DEBUG:PIL.Image:Importing PalmImagePlugin
|
||||
DEBUG:PIL.Image:Importing PcdImagePlugin
|
||||
DEBUG:PIL.Image:Importing PcxImagePlugin
|
||||
DEBUG:PIL.Image:Importing PdfImagePlugin
|
||||
DEBUG:PIL.Image:Importing PixarImagePlugin
|
||||
DEBUG:PIL.Image:Importing PngImagePlugin
|
||||
DEBUG:PIL.Image:Importing PpmImagePlugin
|
||||
DEBUG:PIL.Image:Importing PsdImagePlugin
|
||||
DEBUG:PIL.Image:Importing QoiImagePlugin
|
||||
DEBUG:PIL.Image:Importing SgiImagePlugin
|
||||
DEBUG:PIL.Image:Importing SpiderImagePlugin
|
||||
DEBUG:PIL.Image:Importing SunImagePlugin
|
||||
DEBUG:PIL.Image:Importing TgaImagePlugin
|
||||
DEBUG:PIL.Image:Importing TiffImagePlugin
|
||||
DEBUG:PIL.Image:Importing WebPImagePlugin
|
||||
DEBUG:PIL.Image:Importing WmfImagePlugin
|
||||
DEBUG:PIL.Image:Importing XbmImagePlugin
|
||||
DEBUG:PIL.Image:Importing XpmImagePlugin
|
||||
DEBUG:PIL.Image:Importing XVThumbImagePlugin
|
||||
|
||||
8%|▊ | 1/12 [01:37<17:51, 97.38s/it]
|
||||
17%|█▋ | 2/12 [03:14<16:12, 97.24s/it]
|
||||
25%|██▌ | 3/12 [04:51<14:35, 97.28s/it]
|
||||
33%|███▎ | 4/12 [06:29<12:59, 97.40s/it]
|
||||
42%|████▏ | 5/12 [08:06<11:21, 97.30s/it]
|
||||
50%|█████ | 6/12 [09:43<09:43, 97.17s/it]
|
||||
58%|█████▊ | 7/12 [11:20<08:05, 97.07s/it]
|
||||
67%|██████▋ | 8/12 [12:57<06:28, 97.02s/it]
|
||||
75%|███████▌ | 9/12 [14:34<04:50, 96.98s/it]
|
||||
83%|████████▎ | 10/12 [16:11<03:14, 97.00s/it]
|
||||
92%|█████████▏| 11/12 [17:48<01:37, 97.06s/it]
|
||||
100%|██████████| 12/12 [19:25<00:00, 97.13s/it]
|
||||
100%|██████████| 12/12 [19:25<00:00, 97.14s/it]
|
||||
>>>>>>>>>>>>>>>>>>>>>>>>
|
||||
>>> Step 1: generating actions ...
|
||||
>>> Step 1: interacting with world model ...
|
||||
>>>>>>>>>>>>>>>>>>>>>>>>
|
||||
>>> Step 2: generating actions ...
|
||||
>>> Step 2: interacting with world model ...
|
||||
>>>>>>>>>>>>>>>>>>>>>>>>
|
||||
>>> Step 3: generating actions ...
|
||||
>>> Step 3: interacting with world model ...
|
||||
>>>>>>>>>>>>>>>>>>>>>>>>
|
||||
>>> Step 4: generating actions ...
|
||||
>>> Step 4: interacting with world model ...
|
||||
>>>>>>>>>>>>>>>>>>>>>>>>
|
||||
>>> Step 5: generating actions ...
|
||||
>>> Step 5: interacting with world model ...
|
||||
>>>>>>>>>>>>>>>>>>>>>>>>
|
||||
>>> Step 6: generating actions ...
|
||||
>>> Step 6: interacting with world model ...
|
||||
>>>>>>>>>>>>>>>>>>>>>>>>
|
||||
>>> Step 7: generating actions ...
|
||||
>>> Step 7: interacting with world model ...
|
||||
>>>>>>>>>>>>>>>>>>>>>>>>
|
||||
>>> Step 8: generating actions ...
|
||||
>>> Step 8: interacting with world model ...
|
||||
5
unitree_z1_stackbox/case4/psnr_result.json
Normal file
5
unitree_z1_stackbox/case4/psnr_result.json
Normal file
@@ -0,0 +1,5 @@
|
||||
{
|
||||
"gt_video": "unitree_z1_stackbox/case4/unitree_z1_stackbox_case4.mp4",
|
||||
"pred_video": "unitree_z1_stackbox/case4/output/inference/unitree_z1_stackbox_case4_amd.mp4",
|
||||
"psnr": 25.32928948331741
|
||||
}
|
||||
@@ -2,7 +2,7 @@ res_dir="unitree_z1_stackbox/case4"
|
||||
dataset="unitree_z1_stackbox"
|
||||
|
||||
{
|
||||
time CUDA_VISIBLE_DEVICES=0 python3 scripts/evaluation/world_model_interaction.py \
|
||||
time CUDA_VISIBLE_DEVICES=7 python3 scripts/evaluation/world_model_interaction.py \
|
||||
--seed 123 \
|
||||
--ckpt_path ckpts/unifolm_wma_dual.ckpt \
|
||||
--config configs/inference/world_model_interaction.yaml \
|
||||
|
||||
Reference in New Issue
Block a user