14 Commits

Author SHA1 Message Date
qhy
3069666a15 脚本修改 2026-02-10 14:49:26 +08:00
qhy
68369cc15f 合并后测试 2026-02-10 14:45:14 +08:00
b0ebb7006e 添加三层迭代级性能分析工具 profile_iteration.py
Layer1: CUDA Events 精确测量每个itr内10个阶段耗时
Layer2: torch.profiler GPU timeline trace
Layer3: CSV输出支持A/B对比

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-02-10 05:42:11 +00:00
125b85ce68 实现fs_embed 缓存,收益不明显,精度不降低 2026-02-09 18:49:44 +00:00
0b3b0e534a 复用 DDIMSampler + make_schedule微弱提升 2026-02-09 18:26:39 +00:00
6dca3696d8 实现了Context 预计算和缓存功能,提升了采样效率。 psnr不下降 2026-02-09 17:42:47 +00:00
f192c8aca9 添加CrossAttention kv缓存,减少重复计算,提升性能,psnr=31.8022 dB 2026-02-09 17:04:23 +00:00
4288c9d8c9 减少了一路视频vae解码 2026-02-09 16:48:16 +00:00
a2cd34dd51 1. einsum('b i d, b j d -> b i j') → torch.bmm(q, k.transpose(-1,-2)) — 直接映射 rocBLAS batched GEMM
2. baddbmm 把 scale 融合进 GEMM,少一次 kernel launch
3. 第二个 einsum 同理换torch.bm
每一轮加速1到两秒
2026-02-08 18:54:48 +00:00
7338cc384a ddim.py — torch.float16 → torch.bfloat16,修复 dtype 不匹配
attention.py — 4 处 softmax 都包裹了 torch.amp.autocast('cuda', enabled=False),阻止 autocast 将 bf16 提升到 fp32
2026-02-08 17:02:05 +00:00
f86ab51a04 全链路 bf16 混合精度修正与 UNet FLOPS profiling
- GroupNorm/LayerNorm bypass autocast,消除 bf16→fp32→bf16 转换开销
  - DDIM 调度系数 cast 到输入 dtype,attention mask 直接用 bf16 分配
  - alphas_cumprod 提升到 float64 保证数值精度
  - SinusoidalPosEmb 输出 dtype跟随模型精度
  - 新增 profile_unet.py 脚本及FLOPS 分析结果
  - 启用 TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL
  - case1 PSNR: 30.45 → 30.24(bf16 精度预期内波动)
2026-02-08 16:01:30 +00:00
75c798ded0 DDIM loop 内小张量分配优化,attention mask 缓存到 GPU 2026-02-08 14:20:48 +00:00
e588182642 修复混合精度vae相关的配置错误,确保在推理阶段正确使用了混合精度模型,并且导出了正确精度的检查点文件。 2026-02-08 12:35:59 +00:00
e6c55a648c 所有case的baseline,amd版本的ground truth都上传了 2026-02-08 09:42:14 +00:00
70 changed files with 6297 additions and 236 deletions

6
.gitignore vendored
View File

@@ -55,7 +55,7 @@ coverage.xml
*.pot
# Django stuff:
*.log
local_settings.py
db.sqlite3
@@ -121,7 +121,6 @@ localTest/
fig/
figure/
*.mp4
*.json
Data/ControlVAE.yml
Data/Misc
Data/Pretrained
@@ -129,4 +128,5 @@ Data/utils.py
Experiment/checkpoint
Experiment/log
*.ckpt
*.ckpt
*.0

135
case4_run.log Normal file
View File

@@ -0,0 +1,135 @@
nohup: ignoring input
/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/lightning_fabric/__init__.py:29: UserWarning: pkg_resources is deprecated as an API. See https://setuptools.pypa.io/en/latest/pkg_resources.html. The pkg_resources package is slated for removal as early as 2025-11-30. Refrain from using this package or pin to Setuptools<81.
__import__("pkg_resources").declare_namespace(__name__)
2026-02-08 07:38:45.572744: I tensorflow/core/util/port.cc:113] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2026-02-08 07:38:45.576864: I external/local_tsl/tsl/cuda/cudart_stub.cc:31] Could not find cuda drivers on your machine, GPU will not be used.
2026-02-08 07:38:45.624825: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2026-02-08 07:38:45.624883: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2026-02-08 07:38:45.627150: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2026-02-08 07:38:45.638316: I external/local_tsl/tsl/cuda/cudart_stub.cc:31] Could not find cuda drivers on your machine, GPU will not be used.
2026-02-08 07:38:45.638803: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 AVX512F AVX512_VNNI AVX512_BF16 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
2026-02-08 07:38:46.426363: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT
[rank: 0] Global seed set to 123
/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/kornia/feature/lightglue.py:44: FutureWarning: `torch.cuda.amp.custom_fwd(args...)` is deprecated. Please use `torch.amp.custom_fwd(args..., device_type='cuda')` instead.
@torch.cuda.amp.custom_fwd(cast_inputs=torch.float32)
INFO:mainlogger:LatentVisualDiffusion: Running in v-prediction mode
INFO:unifolm_wma.models.diffusion_head.conditional_unet1d:number of parameters: 5.010531e+08
INFO:unifolm_wma.models.diffusion_head.conditional_unet1d:number of parameters: 5.010531e+08
AE working on z of shape (1, 4, 32, 32) = 4096 dimensions.
INFO:root:Loaded ViT-H-14 model config.
DEBUG:urllib3.connectionpool:Starting new HTTPS connection (1): hf-mirror.com:443
DEBUG:urllib3.connectionpool:https://hf-mirror.com:443 "HEAD /laion/CLIP-ViT-H-14-laion2B-s32B-b79K/resolve/main/open_clip_pytorch_model.bin HTTP/1.1" 302 0
INFO:root:Loading pretrained ViT-H-14 weights (laion2b_s32b_b79k).
/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/open_clip/factory.py:88: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.
checkpoint = torch.load(checkpoint_path, map_location=map_location)
INFO:root:Loaded ViT-H-14 model config.
DEBUG:urllib3.connectionpool:https://hf-mirror.com:443 "HEAD /laion/CLIP-ViT-H-14-laion2B-s32B-b79K/resolve/main/open_clip_pytorch_model.bin HTTP/1.1" 302 0
INFO:root:Loading pretrained ViT-H-14 weights (laion2b_s32b_b79k).
/mnt/ASC1637/unifolm-world-model-action/scripts/evaluation/world_model_interaction.py:86: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.
state_dict = torch.load(ckpt, map_location="cpu")
>>> model checkpoint loaded.
>>> Load pre-trained model ...
INFO:root:***** Configing Data *****
>>> unitree_z1_stackbox: 1 data samples loaded.
>>> unitree_z1_stackbox: data stats loaded.
>>> unitree_z1_stackbox: normalizer initiated.
>>> unitree_z1_dual_arm_stackbox: 1 data samples loaded.
>>> unitree_z1_dual_arm_stackbox: data stats loaded.
>>> unitree_z1_dual_arm_stackbox: normalizer initiated.
>>> unitree_z1_dual_arm_stackbox_v2: 1 data samples loaded.
>>> unitree_z1_dual_arm_stackbox_v2: data stats loaded.
>>> unitree_z1_dual_arm_stackbox_v2: normalizer initiated.
>>> unitree_z1_dual_arm_cleanup_pencils: 1 data samples loaded.
>>> unitree_z1_dual_arm_cleanup_pencils: data stats loaded.
>>> unitree_z1_dual_arm_cleanup_pencils: normalizer initiated.
>>> unitree_g1_pack_camera: 1 data samples loaded.
>>> unitree_g1_pack_camera: data stats loaded.
>>> unitree_g1_pack_camera: normalizer initiated.
>>> Dataset is successfully loaded ...
>>> Generate 16 frames under each generation ...
DEBUG:h5py._conv:Creating converter from 3 to 5
DEBUG:PIL.PngImagePlugin:STREAM b'IHDR' 16 13
DEBUG:PIL.PngImagePlugin:STREAM b'pHYs' 41 9
DEBUG:PIL.PngImagePlugin:STREAM b'IDAT' 62 4096
0%| | 0/7 [00:00<?, ?it/s]/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/torch/nn/functional.py:5501: UserWarning: Attempting to use hipBLASLt on an unsupported architecture! Overriding blas backend to hipblas (Triggered internally at ../aten/src/ATen/Context.cpp:296.)
proj = linear(q, w, b)
/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/torch/nn/functional.py:6278: UserWarning: Flash attention support on Navi31 GPU is still experimental. Enable it with TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL=1. (Triggered internally at ../aten/src/ATen/native/transformers/hip/sdp_utils.cpp:225.)
attn_output = scaled_dot_product_attention(
/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/torch/nn/functional.py:6278: UserWarning: Memory Efficient attention on Navi31 GPU is still experimental. Enable it with TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL=1. (Triggered internally at ../aten/src/ATen/native/transformers/hip/sdp_utils.cpp:269.)
attn_output = scaled_dot_product_attention(
>>> Step 0: generating actions ...
>>> Step 0: interacting with world model ...
DEBUG:PIL.Image:Importing BlpImagePlugin
DEBUG:PIL.Image:Importing BmpImagePlugin
DEBUG:PIL.Image:Importing BufrStubImagePlugin
DEBUG:PIL.Image:Importing CurImagePlugin
DEBUG:PIL.Image:Importing DcxImagePlugin
DEBUG:PIL.Image:Importing DdsImagePlugin
DEBUG:PIL.Image:Importing EpsImagePlugin
DEBUG:PIL.Image:Importing FitsImagePlugin
DEBUG:PIL.Image:Importing FitsStubImagePlugin
DEBUG:PIL.Image:Importing FliImagePlugin
DEBUG:PIL.Image:Importing FpxImagePlugin
DEBUG:PIL.Image:Image: failed to import FpxImagePlugin: No module named 'olefile'
DEBUG:PIL.Image:Importing FtexImagePlugin
DEBUG:PIL.Image:Importing GbrImagePlugin
DEBUG:PIL.Image:Importing GifImagePlugin
DEBUG:PIL.Image:Importing GribStubImagePlugin
DEBUG:PIL.Image:Importing Hdf5StubImagePlugin
DEBUG:PIL.Image:Importing IcnsImagePlugin
DEBUG:PIL.Image:Importing IcoImagePlugin
DEBUG:PIL.Image:Importing ImImagePlugin
DEBUG:PIL.Image:Importing ImtImagePlugin
DEBUG:PIL.Image:Importing IptcImagePlugin
DEBUG:PIL.Image:Importing JpegImagePlugin
DEBUG:PIL.Image:Importing Jpeg2KImagePlugin
DEBUG:PIL.Image:Importing McIdasImagePlugin
DEBUG:PIL.Image:Importing MicImagePlugin
DEBUG:PIL.Image:Image: failed to import MicImagePlugin: No module named 'olefile'
DEBUG:PIL.Image:Importing MpegImagePlugin
DEBUG:PIL.Image:Importing MpoImagePlugin
DEBUG:PIL.Image:Importing MspImagePlugin
DEBUG:PIL.Image:Importing PalmImagePlugin
DEBUG:PIL.Image:Importing PcdImagePlugin
DEBUG:PIL.Image:Importing PcxImagePlugin
DEBUG:PIL.Image:Importing PdfImagePlugin
DEBUG:PIL.Image:Importing PixarImagePlugin
DEBUG:PIL.Image:Importing PngImagePlugin
DEBUG:PIL.Image:Importing PpmImagePlugin
DEBUG:PIL.Image:Importing PsdImagePlugin
DEBUG:PIL.Image:Importing QoiImagePlugin
DEBUG:PIL.Image:Importing SgiImagePlugin
DEBUG:PIL.Image:Importing SpiderImagePlugin
DEBUG:PIL.Image:Importing SunImagePlugin
DEBUG:PIL.Image:Importing TgaImagePlugin
DEBUG:PIL.Image:Importing TiffImagePlugin
DEBUG:PIL.Image:Importing WebPImagePlugin
DEBUG:PIL.Image:Importing WmfImagePlugin
DEBUG:PIL.Image:Importing XbmImagePlugin
DEBUG:PIL.Image:Importing XpmImagePlugin
DEBUG:PIL.Image:Importing XVThumbImagePlugin
14%|█▍ | 1/7 [01:38<09:52, 98.73s/it]
29%|██▊ | 2/7 [03:17<08:14, 98.85s/it]
43%|████▎ | 3/7 [04:56<06:35, 98.80s/it]
57%|█████▋ | 4/7 [06:35<04:56, 98.94s/it]
71%|███████▏ | 5/7 [08:14<03:17, 98.93s/it]
86%|████████▌ | 6/7 [09:53<01:38, 98.89s/it]
100%|██████████| 7/7 [11:31<00:00, 98.81s/it]
100%|██████████| 7/7 [11:31<00:00, 98.85s/it]
>>>>>>>>>>>>>>>>>>>>>>>>
>>> Step 1: generating actions ...
>>> Step 1: interacting with world model ...
>>>>>>>>>>>>>>>>>>>>>>>>
>>> Step 2: generating actions ...
>>> Step 2: interacting with world model ...
>>>>>>>>>>>>>>>>>>>>>>>>
>>> Step 3: generating actions ...
>>> Step 3: interacting with world model ...
>>>>>>>>>>>>>>>>>>>>>>>>
>>> Step 4: generating actions ...
>>> Step 4: interacting with world model ...
>>>>>>>>>>>>>>>>>>>>>>>>
>>> Step 5: generating actions ...

1
ckpts/configuration.json Normal file
View File

@@ -0,0 +1 @@
{"framework": "pytorch", "task": "robotics", "allow_remote": true}

View File

@@ -222,7 +222,7 @@ data:
test:
target: unifolm_wma.data.wma_data.WMAData
params:
data_dir: '/mnt/ASC1637/unifolm-world-model-action/examples/world_model_interaction_prompts'
data_dir: '/home/qhy/unifolm-world-model-action/examples/world_model_interaction_prompts'
video_length: ${model.params.wma_config.params.temporal_length}
frame_stride: 2
load_raw_resolution: True

21
env.sh Normal file
View File

@@ -0,0 +1,21 @@
# Note: This script should be sourced, not executed
# Usage: source env.sh
#
# If you need render group permissions, run this first:
# newgrp render
# Then source this script:
# source env.sh
# Initialize conda
source /mnt/ASC1637/miniconda3/etc/profile.d/conda.sh
# Activate conda environment
conda activate unifolm-wma-o
# Set HuggingFace cache directories
export HF_HOME=/mnt/ASC1637/hf_home
export HUGGINGFACE_HUB_CACHE=/mnt/ASC1637/hf_home/hub
echo "Environment configured successfully"
echo "Conda environment: unifolm-wma-o"
echo "HF_HOME: $HF_HOME"

217
profile_unet_flops.md Normal file
View File

@@ -0,0 +1,217 @@
TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL=1 CUDA_VISIBLE_DEVICES=0 python scripts/evaluation/profile_unet.py --ckpt_path ckpts/unifolm_wma_dual_mix_bf16.ckpt --config configs/inference/world_model_interaction.yaml
==================================================================================================================================
FLOPS BY ATen OPERATOR (FlopCounterMode)
==================================================================================================================================
ATen Op | GFLOPS | % of Total
-------------------------------------------------------
convolution | 6185.17 | 46.4%
addmm | 4411.17 | 33.1%
mm | 1798.34 | 13.5%
bmm | 949.54 | 7.1%
==================================================================================================================================
FLOPS BY MODULE (FlopCounterMode)
==================================================================================================================================
Module | GFLOPS | % of Total
------------------------------------------------------------------------------------------
Global | 13344.23 | 100.0%
DiffusionWrapper | 13344.23 | 100.0%
DiffusionWrapper.diffusion_model | 13344.23 | 100.0%
DiffusionWrapper.diffusion_model.output_blocks.8 | 997.87 | 7.5%
DiffusionWrapper.diffusion_model.output_blocks.5 | 992.91 | 7.4%
DiffusionWrapper.diffusion_model.output_blocks.9 | 941.81 | 7.1%
DiffusionWrapper.diffusion_model.output_blocks.10 | 857.93 | 6.4%
DiffusionWrapper.diffusion_model.output_blocks.11 | 857.93 | 6.4%
DiffusionWrapper.diffusion_model.output_blocks.6 | 821.71 | 6.2%
DiffusionWrapper.diffusion_model.input_blocks.1 | 765.65 | 5.7%
DiffusionWrapper.diffusion_model.input_blocks.2 | 765.65 | 5.7%
DiffusionWrapper.diffusion_model.output_blocks.7 | 737.82 | 5.5%
DiffusionWrapper.diffusion_model.output_blocks.3 | 732.87 | 5.5%
DiffusionWrapper.diffusion_model.output_blocks.4 | 732.87 | 5.5%
DiffusionWrapper.diffusion_model.input_blocks.5 | 645.55 | 4.8%
DiffusionWrapper.diffusion_model.input_blocks.8 | 640.59 | 4.8%
DiffusionWrapper.diffusion_model.input_blocks.4 | 611.99 | 4.6%
DiffusionWrapper.diffusion_model.input_blocks.7 | 607.04 | 4.5%
DiffusionWrapper.diffusion_model.init_attn.0 | 459.02 | 3.4%
DiffusionWrapper.diffusion_model.init_attn | 459.02 | 3.4%
nWrapper.diffusion_model.init_attn.0.transformer_blocks.0 | 432.18 | 3.2%
DiffusionWrapper.diffusion_model.output_blocks.6.0 | 427.85 | 3.2%
DiffusionWrapper.diffusion_model.output_blocks.9.0 | 427.83 | 3.2%
DiffusionWrapper.diffusion_model.output_blocks.3.0 | 343.99 | 2.6%
DiffusionWrapper.diffusion_model.output_blocks.4.0 | 343.99 | 2.6%
DiffusionWrapper.diffusion_model.output_blocks.7.0 | 343.96 | 2.6%
DiffusionWrapper.diffusion_model.output_blocks.10.0 | 343.95 | 2.6%
DiffusionWrapper.diffusion_model.output_blocks.11.0 | 343.95 | 2.6%
DiffusionWrapper.diffusion_model.input_blocks.1.1 | 327.75 | 2.5%
DiffusionWrapper.diffusion_model.input_blocks.2.1 | 327.75 | 2.5%
==================================================================================================================================
SUMMARY
==================================================================================================================================
Total CUDA time: 761.4 ms
Matmul CUDA time: 404.2 ms (53.1%)
Non-matmul CUDA time: 357.1 ms (46.9%)
Total FLOPS (FlopCounter): 13344.23 GFLOPS
Matmul throughput: 33.01 TFLOPS/s (54.1% of BF16 peak)
Overall throughput: 17.53 TFLOPS/s (28.7% of BF16 peak)
GPU peak (BF16): 61.0 TFLOPS
==================================================================================================================================
FLOPS BY ATen OPERATOR (FlopCounterMode)
==================================================================================================================================
ATen Op | GFLOPS | % of Total
-------------------------------------------------------
convolution | 6185.17 | 46.4%
addmm | 4411.17 | 33.1%
mm | 1798.34 | 13.5%
bmm | 949.54 | 7.1%
==================================================================================================================================
FLOPS BY MODULE (FlopCounterMode)
==================================================================================================================================
Module | GFLOPS | % of Total
------------------------------------------------------------------------------------------
DiffusionWrapper | 13344.23 | 100.0%
Global | 13344.23 | 100.0%
DiffusionWrapper.diffusion_model | 13344.23 | 100.0%
DiffusionWrapper.diffusion_model.output_blocks.8 | 997.87 | 7.5%
DiffusionWrapper.diffusion_model.output_blocks.5 | 992.91 | 7.4%
DiffusionWrapper.diffusion_model.output_blocks.9 | 941.81 | 7.1%
DiffusionWrapper.diffusion_model.output_blocks.10 | 857.93 | 6.4%
DiffusionWrapper.diffusion_model.output_blocks.11 | 857.93 | 6.4%
DiffusionWrapper.diffusion_model.output_blocks.6 | 821.71 | 6.2%
DiffusionWrapper.diffusion_model.input_blocks.1 | 765.65 | 5.7%
DiffusionWrapper.diffusion_model.input_blocks.2 | 765.65 | 5.7%
DiffusionWrapper.diffusion_model.output_blocks.7 | 737.82 | 5.5%
DiffusionWrapper.diffusion_model.output_blocks.3 | 732.87 | 5.5%
DiffusionWrapper.diffusion_model.output_blocks.4 | 732.87 | 5.5%
DiffusionWrapper.diffusion_model.input_blocks.5 | 645.55 | 4.8%
DiffusionWrapper.diffusion_model.input_blocks.8 | 640.59 | 4.8%
DiffusionWrapper.diffusion_model.input_blocks.4 | 611.99 | 4.6%
DiffusionWrapper.diffusion_model.input_blocks.7 | 607.04 | 4.5%
DiffusionWrapper.diffusion_model.init_attn | 459.02 | 3.4%
DiffusionWrapper.diffusion_model.init_attn.0 | 459.02 | 3.4%
nWrapper.diffusion_model.init_attn.0.transformer_blocks.0 | 432.18 | 3.2%
DiffusionWrapper.diffusion_model.output_blocks.6.0 | 427.85 | 3.2%
DiffusionWrapper.diffusion_model.output_blocks.9.0 | 427.83 | 3.2%
DiffusionWrapper.diffusion_model.output_blocks.3.0 | 343.99 | 2.6%
DiffusionWrapper.diffusion_model.output_blocks.4.0 | 343.99 | 2.6%
DiffusionWrapper.diffusion_model.output_blocks.7.0 | 343.96 | 2.6%
DiffusionWrapper.diffusion_model.output_blocks.10.0 | 343.95 | 2.6%
DiffusionWrapper.diffusion_model.output_blocks.11.0 | 343.95 | 2.6%
DiffusionWrapper.diffusion_model.input_blocks.1.1 | 327.75 | 2.5%
DiffusionWrapper.diffusion_model.input_blocks.2.1 | 327.75 | 2.5%
==================================================================================================================================
SUMMARY
==================================================================================================================================
Total CUDA time: 707.1 ms
Matmul CUDA time: 403.1 ms (57.0%)
Non-matmul CUDA time: 304.0 ms (43.0%)
Total FLOPS (FlopCounter): 13344.23 GFLOPS
Matmul throughput: 33.11 TFLOPS/s (54.3% of BF16 peak)
Overall throughput: 18.87 TFLOPS/s (30.9% of BF16 peak)
GPU peak (BF16): 61.0 TFLOPS
(unifolm-wma) ASC1637@wx-ms-w7900d-0033:/mnt/ASC1637/unifolm-world-model-action$
========================================================================
TABLE 1: STAGE TIMING
========================================================================
Stage Mean(ms) Std %
------------------------------------------------------------------------
1_Image_Embedding 29.5 0.16 0.1%
2_VAE_Encode 51.3 0.06 0.1%
3_Text_Conditioning 14.7 0.18 0.0%
4_Projectors 0.2 0.03 0.0%
5_DDIM_Loop 33392.5 3.21 97.3%
6_VAE_Decode 808.4 1.00 2.4%
7_Post_Process 15.8 0.56 0.0%
------------------------------------------------------------------------
TOTAL 34312.4
================================================================================
TABLE 2: UNET SUB-MODULE BREAKDOWN
================================================================================
Module Type Total(ms) Count Per-call %
--------------------------------------------------------------------------------
ResBlock 10256.3 1100 9.32 23.2%
SpatialTransformer 9228.2 800 11.54 20.9%
CrossAttention 8105.8 3300 2.46 18.3%
ConditionalUnet1D 6409.5 100 64.10 14.5%
TemporalTransformer 5847.0 850 6.88 13.2%
FeedForward 4338.1 1650 2.63 9.8%
UNet.out 73.8 50 1.48 0.2%
--------------------------------------------------------------------------------
TOTAL (hooked) 44258.7
==========================================================================================
TABLE 2b: PER-UNET-BLOCK TIMING (coarse modules, per DDIM loop)
==========================================================================================
Block Total(ms) % Breakdown
------------------------------------------------------------------------------------------
input_blocks.1 3376.2 7.6% SpatialTransformer=1101, CrossAttention=990, ResBlock=543, TemporalTransformer=454, FeedForward=288
input_blocks.2 3374.0 7.6% SpatialTransformer=1100, CrossAttention=991, ResBlock=540, TemporalTransformer=455, FeedForward=288
input_blocks.4 1592.4 3.6% SpatialTransformer=394, ResBlock=374, CrossAttention=303, TemporalTransformer=272, FeedForward=249
input_blocks.5 1642.5 3.7% ResBlock=425, SpatialTransformer=397, CrossAttention=303, TemporalTransformer=271, FeedForward=247
input_blocks.7 1469.0 3.3% ResBlock=416, SpatialTransformer=324, FeedForward=251, CrossAttention=240, TemporalTransformer=237
input_blocks.8 1543.7 3.5% ResBlock=491, SpatialTransformer=325, FeedForward=250, CrossAttention=240, TemporalTransformer=238
input_blocks.10 217.5 0.5% ResBlock=218
input_blocks.11 216.8 0.5% ResBlock=217
middle_block 848.9 1.9% ResBlock=434, SpatialTransformer=151, CrossAttention=134, TemporalTransformer=69, FeedForward=61
output_blocks.0 303.2 0.7% ResBlock=303
output_blocks.1 303.1 0.7% ResBlock=303
output_blocks.2 302.8 0.7% ResBlock=303
output_blocks.3 1734.8 3.9% ResBlock=687, SpatialTransformer=322, FeedForward=249, CrossAttention=239, TemporalTransformer=237
output_blocks.4 1739.8 3.9% ResBlock=688, SpatialTransformer=323, FeedForward=251, CrossAttention=239, TemporalTransformer=238
output_blocks.5 1622.3 3.7% ResBlock=570, SpatialTransformer=324, FeedForward=251, CrossAttention=239, TemporalTransformer=238
output_blocks.6 1881.0 4.3% ResBlock=664, SpatialTransformer=393, CrossAttention=301, TemporalTransformer=272, FeedForward=250
output_blocks.7 1768.0 4.0% ResBlock=554, SpatialTransformer=393, CrossAttention=301, TemporalTransformer=272, FeedForward=249
output_blocks.8 1688.7 3.8% ResBlock=474, SpatialTransformer=393, CrossAttention=301, TemporalTransformer=272, FeedForward=249
output_blocks.9 3558.6 8.0% SpatialTransformer=1096, CrossAttention=992, ResBlock=727, TemporalTransformer=454, FeedForward=290
output_blocks.10 3492.8 7.9% SpatialTransformer=1096, CrossAttention=992, ResBlock=662, TemporalTransformer=454, FeedForward=289
output_blocks.11 3493.3 7.9% SpatialTransformer=1096, CrossAttention=992, ResBlock=662, TemporalTransformer=454, FeedForward=289
out 73.8 0.2% UNet.out=74
action_unet 3212.0 7.3% ConditionalUnet1D=3212
state_unet 3197.6 7.2% ConditionalUnet1D=3198
other 1606.2 3.6% TemporalTransformer=960, FeedForward=337, CrossAttention=309
------------------------------------------------------------------------------------------
TOTAL 44258.7
======================================================================
TABLE 2c: ATTENTION vs FEEDFORWARD (deep hooks)
======================================================================
Component Total(ms) %
----------------------------------------------------------------------
CrossAttention 8105.8 65.1%
FeedForward 4338.1 34.9%
----------------------------------------------------------------------
TOTAL (attn+ff) 12443.9
==================================================
TABLE 3: MEMORY SUMMARY
==================================================
Initial allocated: 11.82 GB
Peak allocated: 14.43 GB
Delta (pipeline): 2.61 GB
============================================================
TABLE 4: THROUGHPUT
============================================================
Total pipeline latency: 34312.4 ms
DDIM loop latency: 33392.5 ms
DDIM steps: 50
CFG scale: 1.0 (1x UNet/step)
UNet forward calls: 50
Per DDIM step: 667.9 ms
Per UNet forward: 667.9 ms
VAE encode bandwidth: 0.1 GB/s (peak HBM: 864.0 GB/s)
VAE decode bandwidth: 0.0 GB/s (peak HBM: 864.0 GB/s)
GPU BF16 peak: 61.0 TFLOPS
Done.

150
run.log Normal file
View File

@@ -0,0 +1,150 @@
nohup: ignoring input
/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/lightning_fabric/__init__.py:29: UserWarning: pkg_resources is deprecated as an API. See https://setuptools.pypa.io/en/latest/pkg_resources.html. The pkg_resources package is slated for removal as early as 2025-11-30. Refrain from using this package or pin to Setuptools<81.
__import__("pkg_resources").declare_namespace(__name__)
2026-02-08 08:15:49.934949: I tensorflow/core/util/port.cc:113] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2026-02-08 08:15:49.937974: I external/local_tsl/tsl/cuda/cudart_stub.cc:31] Could not find cuda drivers on your machine, GPU will not be used.
2026-02-08 08:15:49.969069: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2026-02-08 08:15:49.969100: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2026-02-08 08:15:49.970909: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2026-02-08 08:15:49.979005: I external/local_tsl/tsl/cuda/cudart_stub.cc:31] Could not find cuda drivers on your machine, GPU will not be used.
2026-02-08 08:15:49.979255: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 AVX512F AVX512_VNNI AVX512_BF16 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
2026-02-08 08:15:50.597743: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT
[rank: 0] Global seed set to 123
/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/kornia/feature/lightglue.py:44: FutureWarning: `torch.cuda.amp.custom_fwd(args...)` is deprecated. Please use `torch.amp.custom_fwd(args..., device_type='cuda')` instead.
@torch.cuda.amp.custom_fwd(cast_inputs=torch.float32)
INFO:mainlogger:LatentVisualDiffusion: Running in v-prediction mode
INFO:unifolm_wma.models.diffusion_head.conditional_unet1d:number of parameters: 5.010531e+08
INFO:unifolm_wma.models.diffusion_head.conditional_unet1d:number of parameters: 5.010531e+08
AE working on z of shape (1, 4, 32, 32) = 4096 dimensions.
INFO:root:Loaded ViT-H-14 model config.
DEBUG:urllib3.connectionpool:Starting new HTTPS connection (1): hf-mirror.com:443
DEBUG:urllib3.connectionpool:https://hf-mirror.com:443 "HEAD /laion/CLIP-ViT-H-14-laion2B-s32B-b79K/resolve/main/open_clip_pytorch_model.bin HTTP/1.1" 302 0
INFO:root:Loading pretrained ViT-H-14 weights (laion2b_s32b_b79k).
/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/open_clip/factory.py:88: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.
checkpoint = torch.load(checkpoint_path, map_location=map_location)
INFO:root:Loaded ViT-H-14 model config.
DEBUG:urllib3.connectionpool:https://hf-mirror.com:443 "HEAD /laion/CLIP-ViT-H-14-laion2B-s32B-b79K/resolve/main/open_clip_pytorch_model.bin HTTP/1.1" 302 0
INFO:root:Loading pretrained ViT-H-14 weights (laion2b_s32b_b79k).
/mnt/ASC1637/unifolm-world-model-action/scripts/evaluation/world_model_interaction.py:86: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.
state_dict = torch.load(ckpt, map_location="cpu")
>>> model checkpoint loaded.
>>> Load pre-trained model ...
INFO:root:***** Configing Data *****
>>> unitree_z1_stackbox: 1 data samples loaded.
>>> unitree_z1_stackbox: data stats loaded.
>>> unitree_z1_stackbox: normalizer initiated.
>>> unitree_z1_dual_arm_stackbox: 1 data samples loaded.
>>> unitree_z1_dual_arm_stackbox: data stats loaded.
>>> unitree_z1_dual_arm_stackbox: normalizer initiated.
>>> unitree_z1_dual_arm_stackbox_v2: 1 data samples loaded.
>>> unitree_z1_dual_arm_stackbox_v2: data stats loaded.
>>> unitree_z1_dual_arm_stackbox_v2: normalizer initiated.
>>> unitree_z1_dual_arm_cleanup_pencils: 1 data samples loaded.
>>> unitree_z1_dual_arm_cleanup_pencils: data stats loaded.
>>> unitree_z1_dual_arm_cleanup_pencils: normalizer initiated.
>>> unitree_g1_pack_camera: 1 data samples loaded.
>>> unitree_g1_pack_camera: data stats loaded.
>>> unitree_g1_pack_camera: normalizer initiated.
>>> Dataset is successfully loaded ...
>>> Generate 16 frames under each generation ...
DEBUG:h5py._conv:Creating converter from 3 to 5
DEBUG:PIL.PngImagePlugin:STREAM b'IHDR' 16 13
DEBUG:PIL.PngImagePlugin:STREAM b'pHYs' 41 9
DEBUG:PIL.PngImagePlugin:STREAM b'IDAT' 62 4096
0%| | 0/12 [00:00<?, ?it/s]/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/torch/nn/functional.py:5501: UserWarning: Attempting to use hipBLASLt on an unsupported architecture! Overriding blas backend to hipblas (Triggered internally at ../aten/src/ATen/Context.cpp:296.)
proj = linear(q, w, b)
/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/torch/nn/functional.py:6278: UserWarning: Flash attention support on Navi31 GPU is still experimental. Enable it with TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL=1. (Triggered internally at ../aten/src/ATen/native/transformers/hip/sdp_utils.cpp:225.)
attn_output = scaled_dot_product_attention(
/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/torch/nn/functional.py:6278: UserWarning: Memory Efficient attention on Navi31 GPU is still experimental. Enable it with TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL=1. (Triggered internally at ../aten/src/ATen/native/transformers/hip/sdp_utils.cpp:269.)
attn_output = scaled_dot_product_attention(
>>> Step 0: generating actions ...
>>> Step 0: interacting with world model ...
DEBUG:PIL.Image:Importing BlpImagePlugin
DEBUG:PIL.Image:Importing BmpImagePlugin
DEBUG:PIL.Image:Importing BufrStubImagePlugin
DEBUG:PIL.Image:Importing CurImagePlugin
DEBUG:PIL.Image:Importing DcxImagePlugin
DEBUG:PIL.Image:Importing DdsImagePlugin
DEBUG:PIL.Image:Importing EpsImagePlugin
DEBUG:PIL.Image:Importing FitsImagePlugin
DEBUG:PIL.Image:Importing FitsStubImagePlugin
DEBUG:PIL.Image:Importing FliImagePlugin
DEBUG:PIL.Image:Importing FpxImagePlugin
DEBUG:PIL.Image:Image: failed to import FpxImagePlugin: No module named 'olefile'
DEBUG:PIL.Image:Importing FtexImagePlugin
DEBUG:PIL.Image:Importing GbrImagePlugin
DEBUG:PIL.Image:Importing GifImagePlugin
DEBUG:PIL.Image:Importing GribStubImagePlugin
DEBUG:PIL.Image:Importing Hdf5StubImagePlugin
DEBUG:PIL.Image:Importing IcnsImagePlugin
DEBUG:PIL.Image:Importing IcoImagePlugin
DEBUG:PIL.Image:Importing ImImagePlugin
DEBUG:PIL.Image:Importing ImtImagePlugin
DEBUG:PIL.Image:Importing IptcImagePlugin
DEBUG:PIL.Image:Importing JpegImagePlugin
DEBUG:PIL.Image:Importing Jpeg2KImagePlugin
DEBUG:PIL.Image:Importing McIdasImagePlugin
DEBUG:PIL.Image:Importing MicImagePlugin
DEBUG:PIL.Image:Image: failed to import MicImagePlugin: No module named 'olefile'
DEBUG:PIL.Image:Importing MpegImagePlugin
DEBUG:PIL.Image:Importing MpoImagePlugin
DEBUG:PIL.Image:Importing MspImagePlugin
DEBUG:PIL.Image:Importing PalmImagePlugin
DEBUG:PIL.Image:Importing PcdImagePlugin
DEBUG:PIL.Image:Importing PcxImagePlugin
DEBUG:PIL.Image:Importing PdfImagePlugin
DEBUG:PIL.Image:Importing PixarImagePlugin
DEBUG:PIL.Image:Importing PngImagePlugin
DEBUG:PIL.Image:Importing PpmImagePlugin
DEBUG:PIL.Image:Importing PsdImagePlugin
DEBUG:PIL.Image:Importing QoiImagePlugin
DEBUG:PIL.Image:Importing SgiImagePlugin
DEBUG:PIL.Image:Importing SpiderImagePlugin
DEBUG:PIL.Image:Importing SunImagePlugin
DEBUG:PIL.Image:Importing TgaImagePlugin
DEBUG:PIL.Image:Importing TiffImagePlugin
DEBUG:PIL.Image:Importing WebPImagePlugin
DEBUG:PIL.Image:Importing WmfImagePlugin
DEBUG:PIL.Image:Importing XbmImagePlugin
DEBUG:PIL.Image:Importing XpmImagePlugin
DEBUG:PIL.Image:Importing XVThumbImagePlugin
8%|▊ | 1/12 [01:37<17:51, 97.37s/it]
17%|█▋ | 2/12 [03:14<16:13, 97.31s/it]
25%|██▌ | 3/12 [04:51<14:35, 97.26s/it]
33%|███▎ | 4/12 [06:29<12:58, 97.25s/it]
42%|████▏ | 5/12 [08:06<11:20, 97.24s/it]
50%|█████ | 6/12 [09:43<09:43, 97.24s/it]
58%|█████▊ | 7/12 [11:20<08:06, 97.27s/it]
67%|██████▋ | 8/12 [12:58<06:29, 97.36s/it]
75%|███████▌ | 9/12 [14:36<04:52, 97.49s/it]
83%|████████▎ | 10/12 [16:13<03:15, 97.52s/it]
92%|█████████▏| 11/12 [17:51<01:37, 97.47s/it]
100%|██████████| 12/12 [19:28<00:00, 97.35s/it]
100%|██████████| 12/12 [19:28<00:00, 97.35s/it]
>>>>>>>>>>>>>>>>>>>>>>>>
>>> Step 1: generating actions ...
>>> Step 1: interacting with world model ...
>>>>>>>>>>>>>>>>>>>>>>>>
>>> Step 2: generating actions ...
>>> Step 2: interacting with world model ...
>>>>>>>>>>>>>>>>>>>>>>>>
>>> Step 3: generating actions ...
>>> Step 3: interacting with world model ...
>>>>>>>>>>>>>>>>>>>>>>>>
>>> Step 4: generating actions ...
>>> Step 4: interacting with world model ...
>>>>>>>>>>>>>>>>>>>>>>>>
>>> Step 5: generating actions ...
>>> Step 5: interacting with world model ...
>>>>>>>>>>>>>>>>>>>>>>>>
>>> Step 6: generating actions ...
>>> Step 6: interacting with world model ...
>>>>>>>>>>>>>>>>>>>>>>>>
>>> Step 7: generating actions ...
>>> Step 7: interacting with world model ...
>>>>>>>>>>>>>>>>>>>>>>>>
>>> Step 8: generating actions ...
>>> Step 8: interacting with world model ...

View File

@@ -0,0 +1,975 @@
"""
Profile the full iteration loop of world model interaction.
Three layers of profiling:
Layer 1: Iteration-level wall-clock breakdown (CUDA events)
Layer 2: GPU timeline trace (torch.profiler → Chrome trace)
Layer 3: A/B comparison (standardized CSV output)
Usage:
# Layer 1 only (fast, default):
TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL=1 CUDA_VISIBLE_DEVICES=0 \
python scripts/evaluation/profile_iteration.py \
--ckpt_path ckpts/unifolm_wma_dual_mix_bf16.ckpt \
--config configs/inference/world_model_interaction.yaml \
--prompt_dir unitree_z1_dual_arm_cleanup_pencils/case1/world_model_interaction_prompts \
--dataset unitree_z1_dual_arm_cleanup_pencils \
--frame_stride 4 --n_iter 5
# Layer 1 + Layer 2 (GPU trace):
... --trace --trace_dir ./profile_traces
# Layer 3 (A/B comparison): run twice, diff the CSVs
... --csv baseline.csv
... --csv optimized.csv
python scripts/evaluation/profile_iteration.py --compare baseline.csv optimized.csv
"""
import argparse
import csv
import os
import sys
import time
from collections import defaultdict, deque
from contextlib import nullcontext
import h5py
import numpy as np
import pandas as pd
import torch
import torchvision
from einops import rearrange, repeat
from omegaconf import OmegaConf
from PIL import Image
from pytorch_lightning import seed_everything
from torch import Tensor
from unifolm_wma.models.samplers.ddim import DDIMSampler
from unifolm_wma.utils.utils import instantiate_from_config
import torch.nn.functional as F
# ──────────────────────────────────────────────────────────────────────
# Constants
# ──────────────────────────────────────────────────────────────────────
STAGE_NAMES = [
"stack_to_device_1",
"synth_policy",
"update_action_queue",
"stack_to_device_2",
"synth_world_model",
"update_obs_queue",
"tensorboard_log",
"save_results",
"cpu_transfer",
"itr_total",
]
# Sub-stages inside image_guided_synthesis_sim_mode
SYNTH_SUB_STAGES = [
"ddim_sampler_init",
"image_embedding",
"vae_encode",
"text_conditioning",
"projectors",
"cond_assembly",
"ddim_sampling",
"vae_decode",
]
# ──────────────────────────────────────────────────────────────────────
# CudaTimer — GPU-precise timing via CUDA events
# ──────────────────────────────────────────────────────────────────────
class CudaTimer:
"""Context manager that records GPU time between enter/exit using CUDA events."""
def __init__(self, name, records):
self.name = name
self.records = records
def __enter__(self):
torch.cuda.synchronize()
self._start = torch.cuda.Event(enable_timing=True)
self._end = torch.cuda.Event(enable_timing=True)
self._start.record()
return self
def __exit__(self, *args):
self._end.record()
torch.cuda.synchronize()
elapsed_ms = self._start.elapsed_time(self._end)
self.records[self.name].append(elapsed_ms)
class WallTimer:
"""Context manager that records CPU wall-clock time (for pure-CPU stages)."""
def __init__(self, name, records):
self.name = name
self.records = records
def __enter__(self):
torch.cuda.synchronize()
self._t0 = time.perf_counter()
return self
def __exit__(self, *args):
torch.cuda.synchronize()
elapsed_ms = (time.perf_counter() - self._t0) * 1000.0
self.records[self.name].append(elapsed_ms)
# ──────────────────────────────────────────────────────────────────────
# Model loading (reused from world_model_interaction.py)
# ──────────────────────────────────────────────────────────────────────
def patch_norm_bypass_autocast():
def _group_norm_forward(self, x):
with torch.amp.autocast('cuda', enabled=False):
return F.group_norm(
x, self.num_groups,
self.weight.to(x.dtype) if self.weight is not None else None,
self.bias.to(x.dtype) if self.bias is not None else None,
self.eps)
def _layer_norm_forward(self, x):
with torch.amp.autocast('cuda', enabled=False):
return F.layer_norm(
x, self.normalized_shape,
self.weight.to(x.dtype) if self.weight is not None else None,
self.bias.to(x.dtype) if self.bias is not None else None,
self.eps)
torch.nn.GroupNorm.forward = _group_norm_forward
torch.nn.LayerNorm.forward = _layer_norm_forward
def apply_torch_compile(model, hot_indices=(5, 8, 9)):
from unifolm_wma.modules.networks.wma_model import ResBlock
unet = model.model.diffusion_model
compiled = 0
for idx in hot_indices:
block = unet.output_blocks[idx]
for layer in block:
if isinstance(layer, ResBlock):
layer._forward = torch.compile(layer._forward, mode="default")
compiled += 1
print(f" torch.compile: {compiled} ResBlocks in output_blocks{list(hot_indices)}")
def load_model(args):
config = OmegaConf.load(args.config)
config['model']['params']['wma_config']['params']['use_checkpoint'] = False
model = instantiate_from_config(config.model)
model.perframe_ae = args.perframe_ae
from collections import OrderedDict
state_dict = torch.load(args.ckpt_path, map_location="cpu")
if "state_dict" in state_dict:
state_dict = state_dict["state_dict"]
try:
model.load_state_dict(state_dict, strict=True)
except Exception:
new_sd = OrderedDict()
for k, v in state_dict.items():
new_sd[k] = v
for k in list(new_sd.keys()):
if "framestride_embed" in k:
new_sd[k.replace("framestride_embed", "fps_embedding")] = new_sd.pop(k)
model.load_state_dict(new_sd, strict=True)
model.eval()
# Apply precision: bf16 diffusion + encoders + projectors, fp32/bf16 VAE
model.model.to(torch.bfloat16)
model.diffusion_autocast_dtype = torch.bfloat16
model.embedder.to(torch.bfloat16)
model.image_proj_model.to(torch.bfloat16)
model.encoder_autocast_dtype = None
model.state_projector.to(torch.bfloat16)
model.action_projector.to(torch.bfloat16)
model.projector_autocast_dtype = None
if args.vae_dtype == "bf16":
model.first_stage_model.to(torch.bfloat16)
# Compile hot ResBlocks
apply_torch_compile(model)
model = model.cuda()
print(">>> Model loaded and ready.")
return model, config
# ──────────────────────────────────────────────────────────────────────
# Data preparation (reused from world_model_interaction.py)
# ──────────────────────────────────────────────────────────────────────
def get_init_frame_path(data_dir, sample):
rel = os.path.join(sample['data_dir'], str(sample['videoid']) + '.png')
return os.path.join(data_dir, 'images', rel)
def get_transition_path(data_dir, sample):
rel = os.path.join(sample['data_dir'], str(sample['videoid']) + '.h5')
return os.path.join(data_dir, 'transitions', rel)
def prepare_init_input(start_idx, init_frame_path, transition_dict,
frame_stride, wma_data, video_length=16, n_obs_steps=2):
indices = [start_idx + frame_stride * i for i in range(video_length)]
init_frame = Image.open(init_frame_path).convert('RGB')
init_frame = torch.tensor(np.array(init_frame)).unsqueeze(0).permute(3, 0, 1, 2).float()
if start_idx < n_obs_steps - 1:
state_indices = list(range(0, start_idx + 1))
states = transition_dict['observation.state'][state_indices, :]
num_padding = n_obs_steps - 1 - start_idx
padding = states[0:1, :].repeat(num_padding, 1)
states = torch.cat((padding, states), dim=0)
else:
state_indices = list(range(start_idx - n_obs_steps + 1, start_idx + 1))
states = transition_dict['observation.state'][state_indices, :]
actions = transition_dict['action'][indices, :]
ori_state_dim = states.shape[-1]
ori_action_dim = actions.shape[-1]
frames_action_state_dict = {
'action': actions,
'observation.state': states,
}
frames_action_state_dict = wma_data.normalizer(frames_action_state_dict)
frames_action_state_dict = wma_data.get_uni_vec(
frames_action_state_dict,
transition_dict['action_type'],
transition_dict['state_type'],
)
if wma_data.spatial_transform is not None:
init_frame = wma_data.spatial_transform(init_frame)
init_frame = (init_frame / 255 - 0.5) * 2
data = {'observation.image': init_frame}
data.update(frames_action_state_dict)
return data, ori_state_dim, ori_action_dim
def populate_queues(queues, batch):
for key in batch:
if key not in queues:
continue
if len(queues[key]) != queues[key].maxlen:
while len(queues[key]) != queues[key].maxlen:
queues[key].append(batch[key])
else:
queues[key].append(batch[key])
return queues
# ──────────────────────────────────────────────────────────────────────
# Instrumented image_guided_synthesis_sim_mode with sub-stage timing
# ──────────────────────────────────────────────────────────────────────
def get_latent_z(model, videos):
b, c, t, h, w = videos.shape
x = rearrange(videos, 'b c t h w -> (b t) c h w')
vae_dtype = next(model.first_stage_model.parameters()).dtype
x = x.to(dtype=vae_dtype)
z = model.encode_first_stage(x)
z = rearrange(z, '(b t) c h w -> b c t h w', b=b, t=t)
return z
def save_results(video, filename, fps=8):
video = video.detach().cpu()
video = torch.clamp(video.float(), -1., 1.)
n = video.shape[0]
video = video.permute(2, 0, 1, 3, 4)
frame_grids = [
torchvision.utils.make_grid(framesheet, nrow=int(n), padding=0)
for framesheet in video
]
grid = torch.stack(frame_grids, dim=0)
grid = (grid + 1.0) / 2.0
grid = (grid * 255).to(torch.uint8).permute(0, 2, 3, 1)
torchvision.io.write_video(filename, grid, fps=fps,
video_codec='h264', options={'crf': '10'})
def profiled_synthesis(model, prompts, observation, noise_shape,
ddim_steps, ddim_eta, unconditional_guidance_scale,
fs, text_input, timestep_spacing, guidance_rescale,
sim_mode, decode_video, records, prefix):
"""image_guided_synthesis_sim_mode with per-sub-stage CUDA event timing.
Args:
prefix: "policy" or "wm" — prepended to sub-stage names in records.
"""
b, _, t, _, _ = noise_shape
batch_size = noise_shape[0]
device = next(model.parameters()).device
# --- sub-stage: ddim_sampler_init ---
with CudaTimer(f"{prefix}/ddim_sampler_init", records):
ddim_sampler = DDIMSampler(model)
fs_t = torch.tensor([fs] * batch_size, dtype=torch.long, device=device)
# --- sub-stage: image_embedding ---
with CudaTimer(f"{prefix}/image_embedding", records):
model_dtype = next(model.embedder.parameters()).dtype
img = observation['observation.images.top'].permute(0, 2, 1, 3, 4)
cond_img = rearrange(img, 'b o c h w -> (b o) c h w')[-1:].to(dtype=model_dtype)
cond_img_emb = model.embedder(cond_img)
cond_img_emb = model.image_proj_model(cond_img_emb)
# --- sub-stage: vae_encode ---
with CudaTimer(f"{prefix}/vae_encode", records):
if model.model.conditioning_key == 'hybrid':
z = get_latent_z(model, img.permute(0, 2, 1, 3, 4))
img_cat_cond = z[:, :, -1:, :, :]
img_cat_cond = repeat(img_cat_cond,
'b c t h w -> b c (repeat t) h w',
repeat=noise_shape[2])
cond = {"c_concat": [img_cat_cond]}
# --- sub-stage: text_conditioning ---
with CudaTimer(f"{prefix}/text_conditioning", records):
if not text_input:
prompts_use = [""] * batch_size
else:
prompts_use = prompts if isinstance(prompts, list) else [prompts] * batch_size
cond_ins_emb = model.get_learned_conditioning(prompts_use)
# --- sub-stage: projectors ---
with CudaTimer(f"{prefix}/projectors", records):
projector_dtype = next(model.state_projector.parameters()).dtype
cond_state_emb = model.state_projector(
observation['observation.state'].to(dtype=projector_dtype))
cond_state_emb = cond_state_emb + model.agent_state_pos_emb
cond_action_emb = model.action_projector(
observation['action'].to(dtype=projector_dtype))
cond_action_emb = cond_action_emb + model.agent_action_pos_emb
if not sim_mode:
cond_action_emb = torch.zeros_like(cond_action_emb)
# --- sub-stage: cond_assembly ---
with CudaTimer(f"{prefix}/cond_assembly", records):
cond["c_crossattn"] = [
torch.cat([cond_state_emb, cond_action_emb, cond_ins_emb, cond_img_emb], dim=1)
]
cond["c_crossattn_action"] = [
observation['observation.images.top'][:, :, -model.n_obs_steps_acting:],
observation['observation.state'][:, -model.n_obs_steps_acting:],
sim_mode,
False,
]
# --- sub-stage: ddim_sampling ---
autocast_dtype = getattr(model, 'diffusion_autocast_dtype', None)
if autocast_dtype is not None and device.type == 'cuda':
autocast_ctx = torch.autocast('cuda', dtype=autocast_dtype)
else:
autocast_ctx = nullcontext()
with CudaTimer(f"{prefix}/ddim_sampling", records):
with autocast_ctx:
samples, actions, states, _ = ddim_sampler.sample(
S=ddim_steps,
conditioning=cond,
batch_size=batch_size,
shape=noise_shape[1:],
verbose=False,
unconditional_guidance_scale=unconditional_guidance_scale,
unconditional_conditioning=None,
eta=ddim_eta,
cfg_img=None,
mask=None,
x0=None,
fs=fs_t,
timestep_spacing=timestep_spacing,
guidance_rescale=guidance_rescale,
unconditional_conditioning_img_nonetext=None,
)
# --- sub-stage: vae_decode ---
batch_variants = None
if decode_video:
with CudaTimer(f"{prefix}/vae_decode", records):
batch_variants = model.decode_first_stage(samples)
else:
records[f"{prefix}/vae_decode"].append(0.0)
return batch_variants, actions, states
# ──────────────────────────────────────────────────────────────────────
# Instrumented iteration loop
# ──────────────────────────────────────────────────────────────────────
def run_profiled_iterations(model, args, config, noise_shape, device):
"""Run the full iteration loop with per-stage timing.
Returns:
all_records: list of dicts, one per itr, {stage_name: ms}
"""
# Load data
csv_path = os.path.join(args.prompt_dir, f"{args.dataset}.csv")
df = pd.read_csv(csv_path)
sample = df.iloc[0]
data_module = instantiate_from_config(config.data)
data_module.setup()
init_frame_path = get_init_frame_path(args.prompt_dir, sample)
ori_fps = float(sample['fps'])
fs = args.frame_stride
model_input_fs = ori_fps // fs
transition_path = get_transition_path(args.prompt_dir, sample)
with h5py.File(transition_path, 'r') as h5f:
transition_dict = {}
for key in h5f.keys():
transition_dict[key] = torch.tensor(h5f[key][()])
for key in h5f.attrs.keys():
transition_dict[key] = h5f.attrs[key]
# Prepare initial observation
batch, ori_state_dim, ori_action_dim = prepare_init_input(
0, init_frame_path, transition_dict, fs,
data_module.test_datasets[args.dataset],
n_obs_steps=model.n_obs_steps_imagen)
observation = {
'observation.images.top':
batch['observation.image'].permute(1, 0, 2, 3)[-1].unsqueeze(0),
'observation.state':
batch['observation.state'][-1].unsqueeze(0),
'action':
torch.zeros_like(batch['action'][-1]).unsqueeze(0),
}
observation = {k: v.to(device, non_blocking=True) for k, v in observation.items()}
cond_obs_queues = {
"observation.images.top": deque(maxlen=model.n_obs_steps_imagen),
"observation.state": deque(maxlen=model.n_obs_steps_imagen),
"action": deque(maxlen=args.video_length),
}
cond_obs_queues = populate_queues(cond_obs_queues, observation)
# Temp dir for save_results profiling
tmp_dir = os.path.join(args.savedir, "profile_tmp")
os.makedirs(tmp_dir, exist_ok=True)
prompt_text = sample['instruction']
all_records = []
print(f">>> Running {args.n_iter} profiled iterations ...")
for itr in range(args.n_iter):
rec = defaultdict(list)
# ── itr_total start ──
torch.cuda.synchronize()
itr_start = torch.cuda.Event(enable_timing=True)
itr_end = torch.cuda.Event(enable_timing=True)
itr_start.record()
# ① stack_to_device_1
with CudaTimer("stack_to_device_1", rec):
observation = {
'observation.images.top':
torch.stack(list(cond_obs_queues['observation.images.top']),
dim=1).permute(0, 2, 1, 3, 4),
'observation.state':
torch.stack(list(cond_obs_queues['observation.state']), dim=1),
'action':
torch.stack(list(cond_obs_queues['action']), dim=1),
}
observation = {k: v.to(device, non_blocking=True) for k, v in observation.items()}
# ② synth_policy
with CudaTimer("synth_policy", rec):
pred_videos_0, pred_actions, _ = profiled_synthesis(
model, prompt_text, observation, noise_shape,
ddim_steps=args.ddim_steps, ddim_eta=args.ddim_eta,
unconditional_guidance_scale=args.unconditional_guidance_scale,
fs=model_input_fs, text_input=True,
timestep_spacing=args.timestep_spacing,
guidance_rescale=args.guidance_rescale,
sim_mode=False,
decode_video=not args.fast_policy_no_decode,
records=rec, prefix="policy")
# ③ update_action_queue
with WallTimer("update_action_queue", rec):
for idx in range(len(pred_actions[0])):
obs_a = {'action': pred_actions[0][idx:idx + 1]}
obs_a['action'][:, ori_action_dim:] = 0.0
cond_obs_queues = populate_queues(cond_obs_queues, obs_a)
# ④ stack_to_device_2
with CudaTimer("stack_to_device_2", rec):
observation = {
'observation.images.top':
torch.stack(list(cond_obs_queues['observation.images.top']),
dim=1).permute(0, 2, 1, 3, 4),
'observation.state':
torch.stack(list(cond_obs_queues['observation.state']), dim=1),
'action':
torch.stack(list(cond_obs_queues['action']), dim=1),
}
observation = {k: v.to(device, non_blocking=True) for k, v in observation.items()}
# ⑤ synth_world_model
with CudaTimer("synth_world_model", rec):
pred_videos_1, _, pred_states = profiled_synthesis(
model, "", observation, noise_shape,
ddim_steps=args.ddim_steps, ddim_eta=args.ddim_eta,
unconditional_guidance_scale=args.unconditional_guidance_scale,
fs=model_input_fs, text_input=False,
timestep_spacing=args.timestep_spacing,
guidance_rescale=args.guidance_rescale,
sim_mode=True, decode_video=True,
records=rec, prefix="wm")
# ⑥ update_obs_queue
with WallTimer("update_obs_queue", rec):
for idx in range(args.exe_steps):
obs_u = {
'observation.images.top':
pred_videos_1[0][:, idx:idx + 1].permute(1, 0, 2, 3),
'observation.state':
pred_states[0][idx:idx + 1],
'action':
torch.zeros_like(pred_actions[0][-1:]),
}
obs_u['observation.state'][:, ori_state_dim:] = 0.0
cond_obs_queues = populate_queues(cond_obs_queues, obs_u)
# ⑦ tensorboard_log (simulate — no actual writer, measure make_grid cost)
with WallTimer("tensorboard_log", rec):
for vid in [pred_videos_0, pred_videos_1]:
if vid is not None and vid.dim() == 5:
v = vid.permute(2, 0, 1, 3, 4)
grids = [torchvision.utils.make_grid(f, nrow=1, padding=0) for f in v]
_ = torch.stack(grids, dim=0)
# ⑧ save_results
with WallTimer("save_results", rec):
if pred_videos_0 is not None:
save_results(pred_videos_0.cpu(),
os.path.join(tmp_dir, f"dm_{itr}.mp4"),
fps=args.save_fps)
save_results(pred_videos_1.cpu(),
os.path.join(tmp_dir, f"wm_{itr}.mp4"),
fps=args.save_fps)
# ⑨ cpu_transfer
with CudaTimer("cpu_transfer", rec):
_ = pred_videos_1[:, :, :args.exe_steps].cpu()
# ── itr_total end ──
itr_end.record()
torch.cuda.synchronize()
itr_total_ms = itr_start.elapsed_time(itr_end)
rec["itr_total"].append(itr_total_ms)
# Flatten: each stage has exactly one entry per itr
itr_rec = {k: v[0] for k, v in rec.items()}
all_records.append(itr_rec)
# Print live progress
print(f" itr {itr}: {itr_total_ms:.0f} ms total | "
f"policy={itr_rec.get('synth_policy', 0):.0f} | "
f"wm={itr_rec.get('synth_world_model', 0):.0f} | "
f"save={itr_rec.get('save_results', 0):.0f} | "
f"tb={itr_rec.get('tensorboard_log', 0):.0f}")
return all_records
# ──────────────────────────────────────────────────────────────────────
# Layer 1: Console report
# ──────────────────────────────────────────────────────────────────────
def print_iteration_report(all_records, warmup=1):
"""Print a structured table of per-stage timing across iterations."""
if len(all_records) <= warmup:
records = all_records
else:
records = all_records[warmup:]
print(f"\n(Skipping first {warmup} itr(s) as warmup)\n")
# Collect all stage keys in a stable order
all_keys = []
seen = set()
for rec in records:
for k in rec:
if k not in seen:
all_keys.append(k)
seen.add(k)
# Separate top-level stages from sub-stages
top_keys = [k for k in all_keys if '/' not in k]
sub_keys = [k for k in all_keys if '/' in k]
def _print_table(keys, title):
if not keys:
return
print("=" * 82)
print(title)
print("=" * 82)
print(f"{'Stage':<35} {'Mean(ms)':>10} {'Std':>8} {'Min':>10} {'Max':>10} {'%':>7}")
print("-" * 82)
total_mean = np.mean([rec.get("itr_total", 0) for rec in records])
for k in keys:
vals = [rec.get(k, 0) for rec in records]
mean = np.mean(vals)
std = np.std(vals)
mn = np.min(vals)
mx = np.max(vals)
pct = mean / total_mean * 100 if total_mean > 0 else 0
print(f"{k:<35} {mean:>10.1f} {std:>8.1f} {mn:>10.1f} {mx:>10.1f} {pct:>6.1f}%")
print("-" * 82)
print()
_print_table(top_keys, "TABLE 1: ITERATION-LEVEL BREAKDOWN")
_print_table(sub_keys, "TABLE 2: SYNTHESIS SUB-STAGE BREAKDOWN")
# ──────────────────────────────────────────────────────────────────────
# Layer 3: CSV output for A/B comparison
# ──────────────────────────────────────────────────────────────────────
def write_csv(all_records, csv_path, warmup=1):
"""Write per-iteration timing to CSV for later comparison."""
records = all_records[warmup:] if len(all_records) > warmup else all_records
# Collect all keys
all_keys = []
seen = set()
for rec in records:
for k in rec:
if k not in seen:
all_keys.append(k)
seen.add(k)
with open(csv_path, 'w', newline='') as f:
writer = csv.DictWriter(f, fieldnames=['itr'] + all_keys)
writer.writeheader()
for i, rec in enumerate(records):
row = {'itr': i}
row.update({k: f"{rec.get(k, 0):.2f}" for k in all_keys})
writer.writerow(row)
# Also write a summary row
summary_path = csv_path.replace('.csv', '_summary.csv')
with open(summary_path, 'w', newline='') as f:
writer = csv.DictWriter(f, fieldnames=['stat'] + all_keys)
writer.writeheader()
for stat_name, stat_fn in [('mean', np.mean), ('std', np.std),
('min', np.min), ('max', np.max)]:
row = {'stat': stat_name}
row.update({k: f"{stat_fn([r.get(k, 0) for r in records]):.2f}"
for k in all_keys})
writer.writerow(row)
print(f">>> CSV written to: {csv_path}")
print(f">>> Summary written to: {summary_path}")
def compare_csvs(path_a, path_b):
"""Compare two summary CSVs and print a diff table."""
df_a = pd.read_csv(path_a, index_col='stat')
df_b = pd.read_csv(path_b, index_col='stat')
# Use mean row for comparison
mean_a = df_a.loc['mean'].astype(float)
mean_b = df_b.loc['mean'].astype(float)
print("=" * 90)
print(f"A/B COMPARISON: {os.path.basename(path_a)} vs {os.path.basename(path_b)}")
print("=" * 90)
print(f"{'Stage':<35} {'A(ms)':>10} {'B(ms)':>10} {'Diff':>10} {'Speedup':>10}")
print("-" * 90)
for col in mean_a.index:
if col not in mean_b.index:
continue
a_val = mean_a[col]
b_val = mean_b[col]
diff = b_val - a_val
speedup = a_val / b_val if b_val > 0 else float('inf')
marker = " <<<" if abs(diff) > 50 else ""
print(f"{col:<35} {a_val:>10.1f} {b_val:>10.1f} {diff:>+10.1f} {speedup:>9.2f}x{marker}")
print("-" * 90)
total_a = mean_a.get('itr_total', 0)
total_b = mean_b.get('itr_total', 0)
print(f"{'itr_total':<35} {total_a:>10.1f} {total_b:>10.1f} "
f"{total_b - total_a:>+10.1f} {total_a / total_b if total_b > 0 else 0:>9.2f}x")
print()
# ──────────────────────────────────────────────────────────────────────
# Layer 2: GPU timeline trace wrapper
# ──────────────────────────────────────────────────────────────────────
def run_with_trace(model, args, config, noise_shape, device):
"""Run iterations under torch.profiler to generate Chrome/TensorBoard traces."""
trace_dir = args.trace_dir
os.makedirs(trace_dir, exist_ok=True)
# We need the same data setup as run_profiled_iterations
csv_path = os.path.join(args.prompt_dir, f"{args.dataset}.csv")
df = pd.read_csv(csv_path)
sample = df.iloc[0]
data_module = instantiate_from_config(config.data)
data_module.setup()
init_frame_path = get_init_frame_path(args.prompt_dir, sample)
ori_fps = float(sample['fps'])
fs = args.frame_stride
model_input_fs = ori_fps // fs
transition_path = get_transition_path(args.prompt_dir, sample)
with h5py.File(transition_path, 'r') as h5f:
transition_dict = {}
for key in h5f.keys():
transition_dict[key] = torch.tensor(h5f[key][()])
for key in h5f.attrs.keys():
transition_dict[key] = h5f.attrs[key]
batch, ori_state_dim, ori_action_dim = prepare_init_input(
0, init_frame_path, transition_dict, fs,
data_module.test_datasets[args.dataset],
n_obs_steps=model.n_obs_steps_imagen)
observation = {
'observation.images.top':
batch['observation.image'].permute(1, 0, 2, 3)[-1].unsqueeze(0),
'observation.state':
batch['observation.state'][-1].unsqueeze(0),
'action':
torch.zeros_like(batch['action'][-1]).unsqueeze(0),
}
observation = {k: v.to(device, non_blocking=True) for k, v in observation.items()}
cond_obs_queues = {
"observation.images.top": deque(maxlen=model.n_obs_steps_imagen),
"observation.state": deque(maxlen=model.n_obs_steps_imagen),
"action": deque(maxlen=args.video_length),
}
cond_obs_queues = populate_queues(cond_obs_queues, observation)
tmp_dir = os.path.join(args.savedir, "profile_tmp")
os.makedirs(tmp_dir, exist_ok=True)
prompt_text = sample['instruction']
# Total iterations: warmup + active
n_warmup = 1
n_active = min(args.n_iter, 2) # trace 2 active iterations max
n_total = n_warmup + n_active
print(f">>> GPU trace: {n_warmup} warmup + {n_active} active iterations")
print(f">>> Trace output: {trace_dir}")
with torch.no_grad(), torch.profiler.profile(
activities=[
torch.profiler.ProfilerActivity.CPU,
torch.profiler.ProfilerActivity.CUDA,
],
schedule=torch.profiler.schedule(
wait=0, warmup=n_warmup, active=n_active, repeat=1),
on_trace_ready=torch.profiler.tensorboard_trace_handler(trace_dir),
record_shapes=True,
with_stack=True,
) as prof:
for itr_idx in range(n_total):
phase = "warmup" if itr_idx < n_warmup else "active"
print(f" trace itr {itr_idx} ({phase})...")
# ── One full iteration (same logic as run_inference) ──
obs_loc = {
'observation.images.top':
torch.stack(list(cond_obs_queues['observation.images.top']),
dim=1).permute(0, 2, 1, 3, 4),
'observation.state':
torch.stack(list(cond_obs_queues['observation.state']), dim=1),
'action':
torch.stack(list(cond_obs_queues['action']), dim=1),
}
obs_loc = {k: v.to(device) for k, v in obs_loc.items()}
# Policy pass
dummy_rec = defaultdict(list)
pv0, pa, _ = profiled_synthesis(
model, prompt_text, obs_loc, noise_shape,
ddim_steps=args.ddim_steps, ddim_eta=args.ddim_eta,
unconditional_guidance_scale=args.unconditional_guidance_scale,
fs=model_input_fs, text_input=True,
timestep_spacing=args.timestep_spacing,
guidance_rescale=args.guidance_rescale,
sim_mode=False,
decode_video=not args.fast_policy_no_decode,
records=dummy_rec, prefix="policy")
for idx in range(len(pa[0])):
oa = {'action': pa[0][idx:idx + 1]}
oa['action'][:, ori_action_dim:] = 0.0
populate_queues(cond_obs_queues, oa)
# Re-stack for world model
obs_loc2 = {
'observation.images.top':
torch.stack(list(cond_obs_queues['observation.images.top']),
dim=1).permute(0, 2, 1, 3, 4),
'observation.state':
torch.stack(list(cond_obs_queues['observation.state']), dim=1),
'action':
torch.stack(list(cond_obs_queues['action']), dim=1),
}
obs_loc2 = {k: v.to(device) for k, v in obs_loc2.items()}
# World model pass
pv1, _, ps = profiled_synthesis(
model, "", obs_loc2, noise_shape,
ddim_steps=args.ddim_steps, ddim_eta=args.ddim_eta,
unconditional_guidance_scale=args.unconditional_guidance_scale,
fs=model_input_fs, text_input=False,
timestep_spacing=args.timestep_spacing,
guidance_rescale=args.guidance_rescale,
sim_mode=True, decode_video=True,
records=dummy_rec, prefix="wm")
# Update obs queue
for idx in range(args.exe_steps):
ou = {
'observation.images.top':
pv1[0][:, idx:idx + 1].permute(1, 0, 2, 3),
'observation.state': ps[0][idx:idx + 1],
'action': torch.zeros_like(pa[0][-1:]),
}
ou['observation.state'][:, ori_state_dim:] = 0.0
populate_queues(cond_obs_queues, ou)
# Save results (captures CPU stall in trace)
if pv0 is not None:
save_results(pv0.cpu(),
os.path.join(tmp_dir, f"trace_dm_{itr_idx}.mp4"),
fps=args.save_fps)
save_results(pv1.cpu(),
os.path.join(tmp_dir, f"trace_wm_{itr_idx}.mp4"),
fps=args.save_fps)
prof.step()
print(f">>> Trace saved to {trace_dir}")
print(" View with: tensorboard --logdir", trace_dir)
print(" Or open the .json file in chrome://tracing")
# ──────────────────────────────────────────────────────────────────────
# Argument parser
# ──────────────────────────────────────────────────────────────────────
def get_parser():
p = argparse.ArgumentParser(description="Profile full iteration loop")
# Compare mode (no model needed)
p.add_argument("--compare", nargs=2, metavar=("A_SUMMARY", "B_SUMMARY"),
help="Compare two summary CSVs and exit")
# Model / data
p.add_argument("--ckpt_path", type=str, default=None)
p.add_argument("--config", type=str, default=None)
p.add_argument("--prompt_dir", type=str, default=None)
p.add_argument("--dataset", type=str, default=None)
p.add_argument("--savedir", type=str, default="profile_output")
# Inference params (match world_model_interaction.py)
p.add_argument("--ddim_steps", type=int, default=50)
p.add_argument("--ddim_eta", type=float, default=1.0)
p.add_argument("--bs", type=int, default=1)
p.add_argument("--height", type=int, default=320)
p.add_argument("--width", type=int, default=512)
p.add_argument("--frame_stride", type=int, default=4)
p.add_argument("--unconditional_guidance_scale", type=float, default=1.0)
p.add_argument("--video_length", type=int, default=16)
p.add_argument("--timestep_spacing", type=str, default="uniform_trailing")
p.add_argument("--guidance_rescale", type=float, default=0.7)
p.add_argument("--exe_steps", type=int, default=16)
p.add_argument("--n_iter", type=int, default=5)
p.add_argument("--save_fps", type=int, default=8)
p.add_argument("--seed", type=int, default=123)
p.add_argument("--perframe_ae", action='store_true', default=False)
p.add_argument("--vae_dtype", type=str, choices=["fp32", "bf16"], default="bf16")
p.add_argument("--fast_policy_no_decode", action='store_true', default=False)
# Profiling control
p.add_argument("--warmup", type=int, default=1,
help="Number of warmup iterations to skip in statistics")
p.add_argument("--csv", type=str, default=None,
help="Write per-iteration timing to this CSV file")
p.add_argument("--trace", action='store_true', default=False,
help="Enable Layer 2: GPU timeline trace")
p.add_argument("--trace_dir", type=str, default="./profile_traces",
help="Directory for trace output")
return p
# ──────────────────────────────────────────────────────────────────────
# Main
# ──────────────────────────────────────────────────────────────────────
def main():
patch_norm_bypass_autocast()
parser = get_parser()
args = parser.parse_args()
# ── Compare mode: no model needed ──
if args.compare:
compare_csvs(args.compare[0], args.compare[1])
return
# ── Validate required args ──
for required in ['ckpt_path', 'config', 'prompt_dir', 'dataset']:
if getattr(args, required) is None:
parser.error(f"--{required} is required for profiling mode")
seed_everything(args.seed)
os.makedirs(args.savedir, exist_ok=True)
# ── Load model ──
print("=" * 60)
print("PROFILE ITERATION — Loading model...")
print("=" * 60)
model, config = load_model(args)
device = next(model.parameters()).device
h, w = args.height // 8, args.width // 8
channels = model.model.diffusion_model.out_channels
noise_shape = [args.bs, channels, args.video_length, h, w]
print(f">>> Noise shape: {noise_shape}")
print(f">>> DDIM steps: {args.ddim_steps}")
print(f">>> fast_policy_no_decode: {args.fast_policy_no_decode}")
# ── Layer 2: GPU trace (optional) ──
if args.trace:
with torch.no_grad():
run_with_trace(model, args, config, noise_shape, device)
print()
# ── Layer 1: Iteration-level breakdown ──
print("=" * 60)
print("LAYER 1: ITERATION-LEVEL PROFILING")
print("=" * 60)
with torch.no_grad():
all_records = run_profiled_iterations(
model, args, config, noise_shape, device)
# Print report
print_iteration_report(all_records, warmup=args.warmup)
# ── Layer 3: CSV output for A/B comparison ──
if args.csv:
write_csv(all_records, args.csv, warmup=args.warmup)
print("Done.")
if __name__ == '__main__':
main()

View File

@@ -0,0 +1,733 @@
"""
Profile the full inference pipeline of the world model, covering all 7 stages:
1. Image Embedding
2. VAE Encode
3. Text Conditioning
4. State/Action Projectors
5. DDIM Loop
6. VAE Decode
7. Post-process
Reports stage-level timing, UNet sub-module breakdown, memory summary,
and throughput analysis.
TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL=1 CUDA_VISIBLE_DEVICES=0 python scripts/evaluation/profile_pipeline.py --ckpt_path ckpts/unifolm_wma_dual_mix_bf16.ckpt --config configs/inference/world_model_interaction.yaml --ddim_steps 50 --cfg_scale 1.0 --n_runs 3 --deep
Usage:
TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL=1 CUDA_VISIBLE_DEVICES=0 python scripts/evaluation/profile_pipeline.py --ckpt_path ckpts/unifolm_wma_dual_mix_bf16.ckpt --config configs/inference/world_model_interaction.yaml --ddim_steps 50 --cfg_scale 1.0 --n_runs 3
"""
import argparse
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint # must be loaded before unifolm_wma.utils.common
from contextlib import nullcontext, contextmanager
from collections import defaultdict
from omegaconf import OmegaConf
from einops import rearrange, repeat
from unifolm_wma.utils.utils import instantiate_from_config
from unifolm_wma.models.samplers.ddim import DDIMSampler
from unifolm_wma.modules.attention import (
SpatialTransformer, TemporalTransformer,
BasicTransformerBlock, CrossAttention, FeedForward,
)
from unifolm_wma.modules.networks.wma_model import ResBlock
from unifolm_wma.models.diffusion_head.conditional_unet1d import ConditionalUnet1D
# --- W7900D theoretical peak ---
PEAK_BF16_TFLOPS = 61.0
MEM_BW_GBS = 864.0
# ---------------------------------------------------------------------------
# Utility: patch norms to bypass autocast fp32 promotion
# ---------------------------------------------------------------------------
def patch_norm_bypass_autocast():
"""Monkey-patch GroupNorm and LayerNorm to bypass autocast's fp32 policy."""
def _group_norm_forward(self, x):
with torch.amp.autocast('cuda', enabled=False):
return F.group_norm(
x, self.num_groups,
self.weight.to(x.dtype) if self.weight is not None else None,
self.bias.to(x.dtype) if self.bias is not None else None,
self.eps)
def _layer_norm_forward(self, x):
with torch.amp.autocast('cuda', enabled=False):
return F.layer_norm(
x, self.normalized_shape,
self.weight.to(x.dtype) if self.weight is not None else None,
self.bias.to(x.dtype) if self.bias is not None else None,
self.eps)
torch.nn.GroupNorm.forward = _group_norm_forward
torch.nn.LayerNorm.forward = _layer_norm_forward
# ---------------------------------------------------------------------------
# Utility: torch.compile hot ResBlocks
# ---------------------------------------------------------------------------
def apply_torch_compile(model, hot_indices=(5, 8, 9)):
"""Compile ResBlock._forward in the hottest output_blocks for operator fusion."""
unet = model.model.diffusion_model
compiled = 0
for idx in hot_indices:
block = unet.output_blocks[idx]
for layer in block:
if isinstance(layer, ResBlock):
layer._forward = torch.compile(layer._forward, mode="default")
compiled += 1
print(f" torch.compile: {compiled} ResBlocks in output_blocks{list(hot_indices)}")
# ---------------------------------------------------------------------------
# Model loading
# ---------------------------------------------------------------------------
def load_model(args):
config = OmegaConf.load(args.config)
config['model']['params']['wma_config']['params']['use_checkpoint'] = False
model = instantiate_from_config(config.model)
state_dict = torch.load(args.ckpt_path, map_location="cpu")
if "state_dict" in state_dict:
state_dict = state_dict["state_dict"]
model.load_state_dict(state_dict, strict=True)
model.eval()
model.model.to(torch.bfloat16)
model.diffusion_autocast_dtype = torch.bfloat16
apply_torch_compile(model)
model = model.cuda()
return model
# ---------------------------------------------------------------------------
# CudaTimer — precise GPU timing via CUDA events
# ---------------------------------------------------------------------------
class CudaTimer:
"""Context manager for GPU-precise stage timing using CUDA events."""
def __init__(self, name, records):
self.name = name
self.records = records
self.start = torch.cuda.Event(enable_timing=True)
self.end = torch.cuda.Event(enable_timing=True)
def __enter__(self):
torch.cuda.synchronize()
self.start.record()
return self
def __exit__(self, *args):
self.end.record()
torch.cuda.synchronize()
elapsed = self.start.elapsed_time(self.end)
self.records[self.name].append(elapsed)
# ---------------------------------------------------------------------------
# HookProfiler — sub-module level timing inside UNet via hooks
# ---------------------------------------------------------------------------
class HookProfiler:
"""Register forward hooks on UNet sub-modules to collect per-call timing."""
# Coarse-grained targets (original)
COARSE_CLASSES = (
SpatialTransformer,
TemporalTransformer,
ResBlock,
ConditionalUnet1D,
)
# Fine-grained targets for deep DDIM analysis
FINE_CLASSES = (
CrossAttention,
FeedForward,
)
def __init__(self, unet, deep=False):
self.unet = unet
self.deep = deep
self.handles = []
# per-instance data: {instance_id: [(start_event, end_event), ...]}
self._events = defaultdict(list)
# tag mapping: {instance_id: (class_name, module_name)}
self._tags = {}
# block location: {instance_id: block_location_str}
self._block_loc = {}
@staticmethod
def _get_block_location(name):
"""Derive UNet block location from module name, e.g. 'input_blocks.3.1'."""
parts = name.split('.')
if len(parts) >= 2 and parts[0] == 'input_blocks':
return f"input_blocks.{parts[1]}"
elif len(parts) >= 1 and parts[0] == 'middle_block':
return "middle_block"
elif len(parts) >= 2 and parts[0] == 'output_blocks':
return f"output_blocks.{parts[1]}"
elif 'action_unet' in name:
return "action_unet"
elif 'state_unet' in name:
return "state_unet"
elif name == 'out' or name.startswith('out.'):
return "out"
return "other"
def register(self):
"""Attach pre/post forward hooks to target sub-modules + unet.out."""
target_classes = self.COARSE_CLASSES
if self.deep:
target_classes = target_classes + self.FINE_CLASSES
for name, mod in self.unet.named_modules():
if isinstance(mod, target_classes):
tag = type(mod).__name__
inst_id = id(mod)
self._tags[inst_id] = (tag, name)
self._block_loc[inst_id] = self._get_block_location(name)
self.handles.append(
mod.register_forward_pre_hook(self._make_pre_hook(inst_id)))
self.handles.append(
mod.register_forward_hook(self._make_post_hook(inst_id)))
# Also hook unet.out (nn.Sequential)
out_mod = self.unet.out
inst_id = id(out_mod)
self._tags[inst_id] = ("UNet.out", "out")
self._block_loc[inst_id] = "out"
self.handles.append(
out_mod.register_forward_pre_hook(self._make_pre_hook(inst_id)))
self.handles.append(
out_mod.register_forward_hook(self._make_post_hook(inst_id)))
def _make_pre_hook(self, inst_id):
events = self._events
def hook(module, input):
start = torch.cuda.Event(enable_timing=True)
start.record()
events[inst_id].append([start, None])
return hook
def _make_post_hook(self, inst_id):
events = self._events
def hook(module, input, output):
end = torch.cuda.Event(enable_timing=True)
end.record()
events[inst_id][-1][1] = end
return hook
def reset(self):
"""Clear collected events for a fresh run."""
self._events.clear()
def synchronize_and_collect(self):
"""Sync GPU and compute elapsed times. Returns (by_type, by_instance, by_block)."""
torch.cuda.synchronize()
by_type = defaultdict(lambda: {"total_ms": 0.0, "count": 0, "calls": []})
by_instance = {}
# by_block: {block_loc: {tag: {"total_ms", "count"}}}
by_block = defaultdict(lambda: defaultdict(lambda: {"total_ms": 0.0, "count": 0}))
for inst_id, pairs in self._events.items():
tag, mod_name = self._tags[inst_id]
block_loc = self._block_loc.get(inst_id, "other")
inst_times = []
for start_evt, end_evt in pairs:
if end_evt is not None:
ms = start_evt.elapsed_time(end_evt)
inst_times.append(ms)
by_type[tag]["total_ms"] += ms
by_type[tag]["count"] += 1
by_type[tag]["calls"].append(ms)
by_block[block_loc][tag]["total_ms"] += ms
by_block[block_loc][tag]["count"] += 1
by_instance[(tag, mod_name)] = inst_times
return dict(by_type), by_instance, dict(by_block)
def remove(self):
"""Remove all hooks."""
for h in self.handles:
h.remove()
self.handles.clear()
# ---------------------------------------------------------------------------
# Build dummy inputs matching the pipeline's expected shapes
# ---------------------------------------------------------------------------
def build_dummy_inputs(model, noise_shape):
"""Create synthetic observation dict and prompts for profiling."""
device = next(model.parameters()).device
B, C, T, H, W = noise_shape
dtype = torch.bfloat16
# observation.images.top: [B, C, O, H, W] (permuted to [B,O,C,H,W] inside pipeline)
O = 2
obs_images = torch.randn(B, 3, O, 320, 512, device=device, dtype=dtype)
obs_state = torch.randn(B, O, 16, device=device, dtype=dtype)
action = torch.randn(B, 16, 16, device=device, dtype=dtype)
observation = {
'observation.images.top': obs_images,
'observation.state': obs_state,
'action': action,
}
prompts = ["a robot arm performing a task"] * B
return observation, prompts
# ---------------------------------------------------------------------------
# Run one full pipeline pass with per-stage timing
# ---------------------------------------------------------------------------
def run_pipeline(model, observation, prompts, noise_shape, ddim_steps,
cfg_scale, hook_profiler):
"""Execute the full 7-stage pipeline, returning per-stage timing dict."""
records = defaultdict(list)
device = next(model.parameters()).device
B, C, T, H, W = noise_shape
dtype = torch.bfloat16
fs = torch.tensor([1] * B, dtype=torch.long, device=device)
# --- Stage 1: Image Embedding ---
with CudaTimer("1_Image_Embedding", records):
img = observation['observation.images.top'].permute(0, 2, 1, 3, 4)
cond_img = rearrange(img, 'b o c h w -> (b o) c h w')[-1:].to(dtype=dtype)
with torch.autocast('cuda', dtype=torch.bfloat16):
cond_img_emb = model.embedder(cond_img)
cond_img_emb = model.image_proj_model(cond_img_emb)
# --- Stage 2: VAE Encode ---
with CudaTimer("2_VAE_Encode", records):
videos = img.permute(0, 2, 1, 3, 4) # [B, C, O, H, W]
b_v, c_v, t_v, h_v, w_v = videos.shape
vae_dtype = next(model.first_stage_model.parameters()).dtype
x_vae = rearrange(videos, 'b c t h w -> (b t) c h w').to(dtype=vae_dtype)
z = model.encode_first_stage(x_vae)
z = rearrange(z, '(b t) c h w -> b c t h w', b=b_v, t=t_v)
img_cat_cond = z[:, :, -1:, :, :]
img_cat_cond = repeat(img_cat_cond,
'b c t h w -> b c (repeat t) h w', repeat=T)
cond = {"c_concat": [img_cat_cond]}
vae_enc_input_bytes = x_vae.nelement() * x_vae.element_size()
vae_enc_output_bytes = z.nelement() * z.element_size()
# --- Stage 3: Text Conditioning ---
with CudaTimer("3_Text_Conditioning", records):
cond_ins_emb = model.get_learned_conditioning(prompts)
# --- Stage 4: State/Action Projectors ---
with CudaTimer("4_Projectors", records):
projector_dtype = next(model.state_projector.parameters()).dtype
with torch.autocast('cuda', dtype=torch.bfloat16):
cond_state_emb = model.state_projector(
observation['observation.state'].to(dtype=projector_dtype))
cond_state_emb = cond_state_emb + model.agent_state_pos_emb
cond_action_emb = model.action_projector(
observation['action'].to(dtype=projector_dtype))
cond_action_emb = cond_action_emb + model.agent_action_pos_emb
# Assemble cross-attention conditioning
cond["c_crossattn"] = [
torch.cat([cond_state_emb, cond_action_emb, cond_ins_emb, cond_img_emb],
dim=1)
]
n_obs_acting = getattr(model, 'n_obs_steps_acting', 2)
cond["c_crossattn_action"] = [
observation['observation.images.top'][:, :, -n_obs_acting:],
observation['observation.state'][:, -n_obs_acting:],
True, # sim_mode
False,
]
# CFG: build unconditional conditioning if needed
uc = None
if cfg_scale != 1.0:
uc_crossattn = torch.zeros_like(cond["c_crossattn"][0])
uc = {
"c_concat": cond["c_concat"],
"c_crossattn": [uc_crossattn],
"c_crossattn_action": cond["c_crossattn_action"],
}
# --- Stage 5: DDIM Loop ---
ddim_sampler = DDIMSampler(model)
hook_profiler.reset()
with CudaTimer("5_DDIM_Loop", records):
with torch.autocast('cuda', dtype=torch.bfloat16):
samples, actions, states, _ = ddim_sampler.sample(
S=ddim_steps,
conditioning=cond,
batch_size=B,
shape=noise_shape[1:],
verbose=False,
unconditional_guidance_scale=cfg_scale,
unconditional_conditioning=uc,
eta=1.0,
cfg_img=None,
mask=None,
x0=None,
fs=fs,
timestep_spacing='uniform',
guidance_rescale=0.0,
unconditional_conditioning_img_nonetext=None,
)
hook_by_type, hook_by_instance, hook_by_block = hook_profiler.synchronize_and_collect()
# --- Stage 6: VAE Decode ---
with CudaTimer("6_VAE_Decode", records):
batch_images = model.decode_first_stage(samples)
vae_dec_input_bytes = samples.nelement() * samples.element_size()
vae_dec_output_bytes = batch_images.nelement() * batch_images.element_size()
# --- Stage 7: Post-process ---
with CudaTimer("7_Post_Process", records):
batch_images_cpu = batch_images.cpu()
actions_cpu = actions.cpu()
states_cpu = states.cpu()
# Simulate video save overhead: clamp + uint8 conversion
_ = (batch_images_cpu.clamp(-1, 1) * 127.5 + 127.5).to(torch.uint8)
# Flatten single-element lists
stage_times = {k: v[0] for k, v in records.items()}
bandwidth_info = {
"vae_enc_input_bytes": vae_enc_input_bytes,
"vae_enc_output_bytes": vae_enc_output_bytes,
"vae_dec_input_bytes": vae_dec_input_bytes,
"vae_dec_output_bytes": vae_dec_output_bytes,
}
return stage_times, hook_by_type, hook_by_instance, hook_by_block, bandwidth_info
# ---------------------------------------------------------------------------
# Reporting
# ---------------------------------------------------------------------------
def print_stage_timing(all_runs_stages):
"""Table 1: Stage Timing — name | mean(ms) | std | percent."""
import numpy as np
stage_names = list(all_runs_stages[0].keys())
means = {}
stds = {}
for name in stage_names:
vals = [run[name] for run in all_runs_stages]
means[name] = np.mean(vals)
stds[name] = np.std(vals)
total = sum(means.values())
print()
print("=" * 72)
print("TABLE 1: STAGE TIMING")
print("=" * 72)
print(f"{'Stage':<25} {'Mean(ms)':>10} {'Std':>10} {'%':>8}")
print("-" * 72)
for name in stage_names:
pct = means[name] / total * 100 if total > 0 else 0
print(f"{name:<25} {means[name]:>10.1f} {stds[name]:>10.2f} {pct:>7.1f}%")
print("-" * 72)
print(f"{'TOTAL':<25} {total:>10.1f}")
print()
def print_unet_breakdown(all_runs_hooks):
"""Table 2: UNet Sub-Module Breakdown — type | total | count | per-call | percent."""
import numpy as np
# Aggregate across runs
agg = defaultdict(lambda: {"totals": [], "counts": []})
for hook_by_type in all_runs_hooks:
for tag, data in hook_by_type.items():
agg[tag]["totals"].append(data["total_ms"])
agg[tag]["counts"].append(data["count"])
print("=" * 80)
print("TABLE 2: UNET SUB-MODULE BREAKDOWN")
print("=" * 80)
print(f"{'Module Type':<25} {'Total(ms)':>10} {'Count':>7} {'Per-call':>10} {'%':>8}")
print("-" * 80)
grand_total = 0
rows = []
for tag, d in agg.items():
mean_total = np.mean(d["totals"])
mean_count = np.mean(d["counts"])
per_call = mean_total / mean_count if mean_count > 0 else 0
grand_total += mean_total
rows.append((tag, mean_total, mean_count, per_call))
rows.sort(key=lambda r: r[1], reverse=True)
for tag, mean_total, mean_count, per_call in rows:
pct = mean_total / grand_total * 100 if grand_total > 0 else 0
print(f"{tag:<25} {mean_total:>10.1f} {int(mean_count):>7} {per_call:>10.2f} {pct:>7.1f}%")
print("-" * 80)
print(f"{'TOTAL (hooked)':<25} {grand_total:>10.1f}")
print()
def print_block_timing(all_runs_blocks):
"""Table 2b: Per-UNet-block timing — which blocks are hottest."""
import numpy as np
# Aggregate: {block_loc: {tag: [total_ms_per_run, ...]}}
agg = defaultdict(lambda: defaultdict(list))
for by_block in all_runs_blocks:
for block_loc, tag_dict in by_block.items():
for tag, data in tag_dict.items():
agg[block_loc][tag].append(data["total_ms"])
# Compute per-block totals
block_totals = {}
for block_loc, tag_dict in agg.items():
block_totals[block_loc] = sum(np.mean(v) for v in tag_dict.values())
grand_total = sum(block_totals.values())
# Sort blocks in logical order
def block_sort_key(name):
if name.startswith("input_blocks."):
return (0, int(name.split('.')[1]))
elif name == "middle_block":
return (1, 0)
elif name.startswith("output_blocks."):
return (2, int(name.split('.')[1]))
elif name == "out":
return (3, 0)
elif name == "action_unet":
return (4, 0)
elif name == "state_unet":
return (5, 0)
return (9, 0)
sorted_blocks = sorted(block_totals.keys(), key=block_sort_key)
print("=" * 90)
print("TABLE 2b: PER-UNET-BLOCK TIMING (coarse modules, per DDIM loop)")
print("=" * 90)
print(f"{'Block':<22} {'Total(ms)':>10} {'%':>7} Breakdown")
print("-" * 90)
for block_loc in sorted_blocks:
total = block_totals[block_loc]
pct = total / grand_total * 100 if grand_total > 0 else 0
# Build breakdown string
parts = []
for tag, vals in sorted(agg[block_loc].items(),
key=lambda x: np.mean(x[1]), reverse=True):
parts.append(f"{tag}={np.mean(vals):.0f}")
breakdown = ", ".join(parts)
print(f"{block_loc:<22} {total:>10.1f} {pct:>6.1f}% {breakdown}")
print("-" * 90)
print(f"{'TOTAL':<22} {grand_total:>10.1f}")
print()
def print_attn_ff_breakdown(all_runs_hooks):
"""Table 2c: CrossAttention vs FeedForward breakdown (--deep mode)."""
import numpy as np
agg = defaultdict(list)
for hook_by_type in all_runs_hooks:
for tag, data in hook_by_type.items():
if tag in ("CrossAttention", "FeedForward"):
agg[tag].append(data["total_ms"])
if not agg:
return
print("=" * 70)
print("TABLE 2c: ATTENTION vs FEEDFORWARD (deep hooks)")
print("=" * 70)
print(f"{'Component':<25} {'Total(ms)':>10} {'%':>8}")
print("-" * 70)
grand = 0
rows = []
for tag in ("CrossAttention", "FeedForward"):
if tag in agg:
mean_t = np.mean(agg[tag])
grand += mean_t
rows.append((tag, mean_t))
for tag, mean_t in rows:
pct = mean_t / grand * 100 if grand > 0 else 0
print(f"{tag:<25} {mean_t:>10.1f} {pct:>7.1f}%")
print("-" * 70)
print(f"{'TOTAL (attn+ff)':<25} {grand:>10.1f}")
print()
def print_unet_detailed(all_runs_instances):
"""Print per-instance UNet sub-module detail (--detailed mode)."""
import numpy as np
# Use last run's data
by_instance = all_runs_instances[-1]
print("=" * 100)
print("DETAILED: PER-INSTANCE UNET SUB-MODULE TIMING (last run)")
print("=" * 100)
print(f"{'Type':<22} {'Module Name':<45} {'Calls':>6} {'Total(ms)':>10} {'Mean(ms)':>10}")
print("-" * 100)
rows = []
for (tag, mod_name), times in by_instance.items():
if len(times) == 0:
continue
total = sum(times)
mean = np.mean(times)
rows.append((tag, mod_name, len(times), total, mean))
rows.sort(key=lambda r: r[3], reverse=True)
for tag, mod_name, count, total, mean in rows:
short_name = mod_name[-42:] if len(mod_name) > 42 else mod_name
print(f"{tag:<22} {short_name:<45} {count:>6} {total:>10.2f} {mean:>10.3f}")
print()
def print_memory_summary(mem_before, mem_peak):
"""Table 3: Memory Summary."""
delta = mem_peak - mem_before
print("=" * 50)
print("TABLE 3: MEMORY SUMMARY")
print("=" * 50)
print(f" Initial allocated: {mem_before / 1e9:.2f} GB")
print(f" Peak allocated: {mem_peak / 1e9:.2f} GB")
print(f" Delta (pipeline): {delta / 1e9:.2f} GB")
print()
def print_throughput(all_runs_stages, all_bw, ddim_steps, cfg_scale):
"""Table 4: Throughput — total latency, per-step, per-UNet-forward, VAE bandwidth."""
import numpy as np
n_runs = len(all_runs_stages)
# Total latency
totals = []
for run in all_runs_stages:
totals.append(sum(run.values()))
mean_total = np.mean(totals)
# DDIM loop time
ddim_times = [run["5_DDIM_Loop"] for run in all_runs_stages]
mean_ddim = np.mean(ddim_times)
unet_calls = ddim_steps if cfg_scale == 1.0 else ddim_steps * 2
per_step = mean_ddim / ddim_steps
per_unet = mean_ddim / unet_calls
# VAE bandwidth
mean_enc_time = np.mean([run["2_VAE_Encode"] for run in all_runs_stages])
mean_dec_time = np.mean([run["6_VAE_Decode"] for run in all_runs_stages])
bw = all_bw[-1] # use last run's byte counts
enc_bytes = bw["vae_enc_input_bytes"] + bw["vae_enc_output_bytes"]
dec_bytes = bw["vae_dec_input_bytes"] + bw["vae_dec_output_bytes"]
enc_bw = enc_bytes / (mean_enc_time / 1000) / 1e9 if mean_enc_time > 0 else 0
dec_bw = dec_bytes / (mean_dec_time / 1000) / 1e9 if mean_dec_time > 0 else 0
print("=" * 60)
print("TABLE 4: THROUGHPUT")
print("=" * 60)
print(f" Total pipeline latency: {mean_total:.1f} ms")
print(f" DDIM loop latency: {mean_ddim:.1f} ms")
print(f" DDIM steps: {ddim_steps}")
print(f" CFG scale: {cfg_scale} ({'2x UNet/step' if cfg_scale != 1.0 else '1x UNet/step'})")
print(f" UNet forward calls: {unet_calls}")
print(f" Per DDIM step: {per_step:.1f} ms")
print(f" Per UNet forward: {per_unet:.1f} ms")
print(f" VAE encode bandwidth: {enc_bw:.1f} GB/s (peak HBM: {MEM_BW_GBS} GB/s)")
print(f" VAE decode bandwidth: {dec_bw:.1f} GB/s (peak HBM: {MEM_BW_GBS} GB/s)")
print(f" GPU BF16 peak: {PEAK_BF16_TFLOPS} TFLOPS")
print()
# ---------------------------------------------------------------------------
# Main
# ---------------------------------------------------------------------------
def main():
patch_norm_bypass_autocast()
parser = argparse.ArgumentParser(
description="Profile the full inference pipeline")
parser.add_argument("--ckpt_path", type=str, required=True)
parser.add_argument("--config", type=str, required=True)
parser.add_argument("--ddim_steps", type=int, default=50)
parser.add_argument("--cfg_scale", type=float, default=1.0)
parser.add_argument("--n_runs", type=int, default=3)
parser.add_argument("--warmup", type=int, default=1)
parser.add_argument("--detailed", action="store_true",
help="Print per-instance UNet sub-module detail")
parser.add_argument("--deep", action="store_true",
help="Enable deep DDIM analysis: per-block, attn vs ff")
args = parser.parse_args()
noise_shape = [1, 4, 16, 40, 64]
# --- Load model ---
print("Loading model...")
model = load_model(args)
observation, prompts = build_dummy_inputs(model, noise_shape)
# --- Setup hook profiler ---
unet = model.model.diffusion_model
hook_profiler = HookProfiler(unet, deep=args.deep)
hook_profiler.register()
print(f"Registered hooks on {len(hook_profiler.handles)} sub-modules")
# --- Warmup ---
print(f"Warmup: {args.warmup} run(s)...")
with torch.no_grad():
for i in range(args.warmup):
run_pipeline(model, observation, prompts, noise_shape,
args.ddim_steps, args.cfg_scale, hook_profiler)
print(f" warmup {i+1}/{args.warmup} done")
# --- Measurement runs ---
print(f"Measuring: {args.n_runs} run(s)...")
torch.cuda.reset_peak_memory_stats()
mem_before = torch.cuda.memory_allocated()
all_stages = []
all_hooks = []
all_instances = []
all_blocks = []
all_bw = []
with torch.no_grad():
for i in range(args.n_runs):
stage_times, hook_by_type, hook_by_instance, hook_by_block, bw = run_pipeline(
model, observation, prompts, noise_shape,
args.ddim_steps, args.cfg_scale, hook_profiler)
all_stages.append(stage_times)
all_hooks.append(hook_by_type)
all_instances.append(hook_by_instance)
all_blocks.append(hook_by_block)
all_bw.append(bw)
total = sum(stage_times.values())
print(f" run {i+1}/{args.n_runs}: {total:.1f} ms total")
mem_peak = torch.cuda.max_memory_allocated()
# --- Reports ---
print_stage_timing(all_stages)
print_unet_breakdown(all_hooks)
print_block_timing(all_blocks)
if args.deep:
print_attn_ff_breakdown(all_hooks)
if args.detailed:
print_unet_detailed(all_instances)
print_memory_summary(mem_before, mem_peak)
print_throughput(all_stages, all_bw, args.ddim_steps, args.cfg_scale)
hook_profiler.remove()
print("Done.")
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,287 @@
"""
Profile one DDIM sampling iteration to capture all matmul/attention ops,
their matrix sizes, wall time, and compute utilization.
Uses torch.profiler for CUDA timing and FlopCounterMode for accurate
FLOPS counting (works on ROCm where Tensile kernels don't report FLOPS).
Usage:
TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL=1 CUDA_VISIBLE_DEVICES=0 \
python scripts/evaluation/profile_unet.py \
--ckpt_path ckpts/unifolm_wma_dual_mix_bf16.ckpt \
--config configs/inference/world_model_interaction.yaml
"""
import argparse
import torch
import torch.nn as nn
from collections import OrderedDict, defaultdict
from omegaconf import OmegaConf
from torch.utils.flop_counter import FlopCounterMode
from unifolm_wma.utils.utils import instantiate_from_config
import torch.nn.functional as F
def patch_norm_bypass_autocast():
"""Monkey-patch GroupNorm and LayerNorm to bypass autocast's fp32 policy."""
def _group_norm_forward(self, x):
with torch.amp.autocast('cuda', enabled=False):
return F.group_norm(
x, self.num_groups,
self.weight.to(x.dtype) if self.weight is not None else None,
self.bias.to(x.dtype) if self.bias is not None else None,
self.eps)
def _layer_norm_forward(self, x):
with torch.amp.autocast('cuda', enabled=False):
return F.layer_norm(
x, self.normalized_shape,
self.weight.to(x.dtype) if self.weight is not None else None,
self.bias.to(x.dtype) if self.bias is not None else None,
self.eps)
torch.nn.GroupNorm.forward = _group_norm_forward
torch.nn.LayerNorm.forward = _layer_norm_forward
# --- W7900D theoretical peak (TFLOPS) ---
PEAK_BF16_TFLOPS = 61.0
PEAK_FP32_TFLOPS = 30.5
def apply_torch_compile(model, hot_indices=(5, 8, 9)):
"""Compile ResBlock._forward in the hottest output_blocks for operator fusion."""
from unifolm_wma.modules.networks.wma_model import ResBlock
unet = model.model.diffusion_model
compiled = 0
for idx in hot_indices:
block = unet.output_blocks[idx]
for layer in block:
if isinstance(layer, ResBlock):
layer._forward = torch.compile(layer._forward, mode="default")
compiled += 1
print(f" ✓ torch.compile: {compiled} ResBlocks in output_blocks{list(hot_indices)}")
def load_model(args):
config = OmegaConf.load(args.config)
config['model']['params']['wma_config']['params']['use_checkpoint'] = False
model = instantiate_from_config(config.model)
state_dict = torch.load(args.ckpt_path, map_location="cpu")
if "state_dict" in state_dict:
state_dict = state_dict["state_dict"]
model.load_state_dict(state_dict, strict=True)
model.eval()
model.model.to(torch.bfloat16)
apply_torch_compile(model)
model = model.cuda()
return model
def build_call_kwargs(model, noise_shape):
"""Build dummy inputs matching the hybrid conditioning forward signature."""
device = next(model.parameters()).device
B, C, T, H, W = noise_shape # [1, 4, 16, 40, 64]
dtype = torch.bfloat16
x_action = torch.randn(B, 16, 16, device=device, dtype=dtype)
x_state = torch.randn(B, 16, 16, device=device, dtype=dtype)
timesteps = torch.tensor([500], device=device, dtype=torch.long)
context = torch.randn(B, 351, 1024, device=device, dtype=dtype)
obs_images = torch.randn(B, 3, 2, 320, 512, device=device, dtype=dtype)
obs_state = torch.randn(B, 2, 16, device=device, dtype=dtype)
context_action = [obs_images, obs_state, True, False]
fps = torch.tensor([1], device=device, dtype=torch.long)
x_raw = torch.randn(B, C, T, H, W, device=device, dtype=dtype)
c_concat = [torch.randn(B, C, T, H, W, device=device, dtype=dtype)]
return dict(
x=x_raw, x_action=x_action, x_state=x_state, t=timesteps,
c_concat=c_concat, c_crossattn=[context],
c_crossattn_action=context_action, s=fps,
)
def profile_one_step(model, noise_shape):
"""Run one UNet forward pass under torch.profiler for CUDA timing."""
diff_wrapper = model.model
call_kwargs = build_call_kwargs(model, noise_shape)
with torch.no_grad(), torch.autocast('cuda', dtype=torch.bfloat16):
# Warmup
for _ in range(2):
_ = diff_wrapper(**call_kwargs)
torch.cuda.synchronize()
with torch.profiler.profile(
activities=[torch.profiler.ProfilerActivity.CUDA],
record_shapes=True,
with_flops=True,
) as prof:
_ = diff_wrapper(**call_kwargs)
torch.cuda.synchronize()
return prof
def count_flops(model, noise_shape):
"""Run one UNet forward pass under FlopCounterMode for accurate FLOPS."""
diff_wrapper = model.model
call_kwargs = build_call_kwargs(model, noise_shape)
with torch.no_grad(), torch.autocast('cuda', dtype=torch.bfloat16):
flop_counter = FlopCounterMode(display=False)
with flop_counter:
_ = diff_wrapper(**call_kwargs)
torch.cuda.synchronize()
return flop_counter
def print_report(prof, flop_counter):
"""Parse profiler results and print a structured report with accurate FLOPS."""
events = prof.key_averages()
# --- Extract per-operator FLOPS from FlopCounterMode ---
# flop_counts is {module_name: {op_name: count}}; use only "Global" to avoid double-counting
flop_by_op = {}
flop_by_module = {}
if hasattr(flop_counter, 'flop_counts'):
# Per-op: only from top-level "Global" entry (no parent/child duplication)
global_ops = flop_counter.flop_counts.get("Global", {})
for op_name, flop_count in global_ops.items():
key = str(op_name).split('.')[-1]
flop_by_op[key] = flop_by_op.get(key, 0) + flop_count
# Per-module: collect all, skip "Global" and top-level wrapper duplicates
for module_name, op_dict in flop_counter.flop_counts.items():
module_total = sum(op_dict.values())
if module_total > 0:
flop_by_module[module_name] = module_total
total_counted_flops = flop_counter.get_total_flops()
# Collect matmul-like ops
matmul_ops = []
other_ops = []
for evt in events:
if evt.device_time_total <= 0:
continue
name = evt.key
is_matmul = any(k in name.lower() for k in
['mm', 'gemm', 'addmm', 'bmm', 'einsum', 'dot', 'linear'])
entry = {
'name': name,
'input_shapes': str(evt.input_shapes) if evt.input_shapes else '',
'cuda_time_ms': evt.device_time_total / 1000.0,
'count': evt.count,
'flops': evt.flops if evt.flops else 0,
}
if is_matmul:
matmul_ops.append(entry)
else:
other_ops.append(entry)
# Sort by CUDA time
matmul_ops.sort(key=lambda x: x['cuda_time_ms'], reverse=True)
other_ops.sort(key=lambda x: x['cuda_time_ms'], reverse=True)
total_cuda_ms = sum(e['cuda_time_ms'] for e in matmul_ops + other_ops)
total_matmul_ms = sum(e['cuda_time_ms'] for e in matmul_ops)
# --- Print matmul ops ---
print("=" * 130)
print("MATMUL / LINEAR OPS (sorted by CUDA time)")
print("=" * 130)
print(f"{'Op':>35} | {'Count':>5} | {'CUDA(ms)':>10} | Shapes")
print("-" * 130)
for op in matmul_ops:
shapes_str = op['input_shapes'][:60] if op['input_shapes'] else ''
print(f"{op['name']:>35} | {op['count']:>5} | {op['cuda_time_ms']:>10.3f} | {shapes_str}")
# --- Print top non-matmul ops ---
print()
print("=" * 130)
print("TOP NON-MATMUL OPS (sorted by CUDA time)")
print("=" * 130)
print(f"{'Op':>40} | {'Count':>5} | {'CUDA(ms)':>10} | Shapes")
print("-" * 130)
for op in other_ops[:20]:
shapes_str = op['input_shapes'][:60] if op['input_shapes'] else ''
print(f"{op['name']:>40} | {op['count']:>5} | {op['cuda_time_ms']:>10.3f} | {shapes_str}")
# --- FlopCounterMode per-operator breakdown ---
print()
print("=" * 130)
print("FLOPS BY ATen OPERATOR (FlopCounterMode)")
print("=" * 130)
print(f"{'ATen Op':>25} | {'GFLOPS':>12} | {'% of Total':>10}")
print("-" * 55)
sorted_flop_ops = sorted(flop_by_op.items(), key=lambda x: x[1], reverse=True)
for op_name, flops in sorted_flop_ops:
gflops = flops / 1e9
pct = flops / total_counted_flops * 100 if total_counted_flops > 0 else 0
print(f"{op_name:>25} | {gflops:>12.2f} | {pct:>9.1f}%")
# --- FlopCounterMode per-module breakdown ---
if flop_by_module:
print()
print("=" * 130)
print("FLOPS BY MODULE (FlopCounterMode)")
print("=" * 130)
print(f"{'Module':>60} | {'GFLOPS':>12} | {'% of Total':>10}")
print("-" * 90)
sorted_modules = sorted(flop_by_module.items(), key=lambda x: x[1], reverse=True)
for mod_name, flops in sorted_modules[:30]:
gflops = flops / 1e9
pct = flops / total_counted_flops * 100 if total_counted_flops > 0 else 0
name_str = mod_name[-57:] if len(mod_name) > 57 else mod_name
print(f"{name_str:>60} | {gflops:>12.2f} | {pct:>9.1f}%")
# --- Summary ---
print()
print("=" * 130)
print("SUMMARY")
print("=" * 130)
print(f" Total CUDA time: {total_cuda_ms:.1f} ms")
print(f" Matmul CUDA time: {total_matmul_ms:.1f} ms ({total_matmul_ms/total_cuda_ms*100:.1f}%)")
print(f" Non-matmul CUDA time: {total_cuda_ms - total_matmul_ms:.1f} ms ({(total_cuda_ms-total_matmul_ms)/total_cuda_ms*100:.1f}%)")
print(f" Total FLOPS (FlopCounter): {total_counted_flops/1e9:.2f} GFLOPS")
if total_matmul_ms > 0 and total_counted_flops > 0:
avg_tflops = total_counted_flops / (total_matmul_ms / 1000.0) / 1e12
avg_util = avg_tflops / PEAK_BF16_TFLOPS * 100
overall_tflops = total_counted_flops / (total_cuda_ms / 1000.0) / 1e12
overall_util = overall_tflops / PEAK_BF16_TFLOPS * 100
print(f" Matmul throughput: {avg_tflops:.2f} TFLOPS/s ({avg_util:.1f}% of BF16 peak)")
print(f" Overall throughput: {overall_tflops:.2f} TFLOPS/s ({overall_util:.1f}% of BF16 peak)")
print(f" GPU peak (BF16): {PEAK_BF16_TFLOPS} TFLOPS")
if __name__ == '__main__':
patch_norm_bypass_autocast()
parser = argparse.ArgumentParser()
parser.add_argument("--ckpt_path", type=str, required=True)
parser.add_argument("--config", type=str, required=True)
args = parser.parse_args()
print("Loading model...")
model = load_model(args)
noise_shape = [1, 4, 16, 40, 64]
print(f"Profiling UNet forward pass with shape {noise_shape}...")
prof = profile_one_step(model, noise_shape)
print("Counting FLOPS with FlopCounterMode...")
flop_counter = count_flops(model, noise_shape)
print_report(prof, flop_counter)

View File

@@ -1,4 +1,7 @@
import argparse, os, glob
from contextlib import nullcontext
import atexit
from concurrent.futures import ThreadPoolExecutor
import pandas as pd
import random
import torch
@@ -10,13 +13,15 @@ import einops
import warnings
import imageio
from typing import Optional, List, Any
from pytorch_lightning import seed_everything
from omegaconf import OmegaConf
from tqdm import tqdm
from einops import rearrange, repeat
from collections import OrderedDict
from torch import nn
from eval_utils import populate_queues, log_to_tensorboard
from eval_utils import populate_queues
from collections import deque
from torch import Tensor
from torch.utils.tensorboard import SummaryWriter
@@ -24,6 +29,105 @@ from PIL import Image
from unifolm_wma.models.samplers.ddim import DDIMSampler
from unifolm_wma.utils.utils import instantiate_from_config
import torch.nn.functional as F
# ========== Async I/O utilities ==========
_io_executor: Optional[ThreadPoolExecutor] = None
_io_futures: List[Any] = []
def _get_io_executor() -> ThreadPoolExecutor:
global _io_executor
if _io_executor is None:
_io_executor = ThreadPoolExecutor(max_workers=2)
return _io_executor
def _flush_io():
"""Wait for all pending async I/O to finish."""
global _io_futures
for fut in _io_futures:
try:
fut.result()
except Exception as e:
print(f">>> [async I/O] error: {e}")
_io_futures.clear()
atexit.register(_flush_io)
def _save_results_sync(video_cpu: Tensor, filename: str, fps: int) -> None:
"""Synchronous save on CPU tensor (runs in background thread)."""
video = torch.clamp(video_cpu.float(), -1., 1.)
n = video.shape[0]
video = video.permute(2, 0, 1, 3, 4)
frame_grids = [
torchvision.utils.make_grid(framesheet, nrow=int(n), padding=0)
for framesheet in video
]
grid = torch.stack(frame_grids, dim=0)
grid = (grid + 1.0) / 2.0
grid = (grid * 255).to(torch.uint8).permute(0, 2, 3, 1)
torchvision.io.write_video(filename,
grid,
fps=fps,
video_codec='h264',
options={'crf': '10'})
def save_results_async(video: Tensor, filename: str, fps: int = 8) -> None:
"""Submit video saving to background thread pool."""
video_cpu = video.detach().cpu()
fut = _get_io_executor().submit(_save_results_sync, video_cpu, filename, fps)
_io_futures.append(fut)
def _log_to_tb_sync(video_cpu: Tensor, writer: SummaryWriter, tag: str, fps: int) -> None:
"""Synchronous tensorboard logging on CPU tensor (runs in background thread)."""
video = video_cpu.float()
n = video.shape[0]
video = video.permute(2, 0, 1, 3, 4)
frame_grids = [
torchvision.utils.make_grid(framesheet, nrow=int(n), padding=0)
for framesheet in video
]
grid = torch.stack(frame_grids, dim=0)
grid = (grid + 1.0) / 2.0
grid = grid.unsqueeze(dim=0)
writer.add_video(tag, grid, fps=fps)
def log_to_tensorboard_async(writer: SummaryWriter, video: Tensor, tag: str, fps: int = 10) -> None:
"""Submit tensorboard logging to background thread pool."""
video_cpu = video.detach().cpu()
fut = _get_io_executor().submit(_log_to_tb_sync, video_cpu, writer, tag, fps)
_io_futures.append(fut)
def patch_norm_bypass_autocast():
"""Monkey-patch GroupNorm and LayerNorm to bypass autocast's fp32 policy.
This eliminates bf16->fp32->bf16 dtype conversions during UNet forward."""
def _group_norm_forward(self, x):
with torch.amp.autocast('cuda', enabled=False):
return F.group_norm(
x, self.num_groups,
self.weight.to(x.dtype) if self.weight is not None else None,
self.bias.to(x.dtype) if self.bias is not None else None,
self.eps)
def _layer_norm_forward(self, x):
with torch.amp.autocast('cuda', enabled=False):
return F.layer_norm(
x, self.normalized_shape,
self.weight.to(x.dtype) if self.weight is not None else None,
self.bias.to(x.dtype) if self.bias is not None else None,
self.eps)
torch.nn.GroupNorm.forward = _group_norm_forward
torch.nn.LayerNorm.forward = _layer_norm_forward
def get_device_from_parameters(module: nn.Module) -> torch.device:
@@ -38,6 +142,92 @@ def get_device_from_parameters(module: nn.Module) -> torch.device:
return next(iter(module.parameters())).device
def apply_precision_settings(model: nn.Module, args: argparse.Namespace) -> nn.Module:
"""Apply precision settings to model components based on command-line arguments.
Args:
model (nn.Module): The model to apply precision settings to.
args (argparse.Namespace): Parsed command-line arguments containing precision settings.
Returns:
nn.Module: Model with precision settings applied.
"""
print(f">>> Applying precision settings:")
print(f" - Diffusion dtype: {args.diffusion_dtype}")
print(f" - Projector mode: {args.projector_mode}")
print(f" - Encoder mode: {args.encoder_mode}")
print(f" - VAE dtype: {args.vae_dtype}")
# 1. Set Diffusion backbone precision
if args.diffusion_dtype == "bf16":
# Convert diffusion model weights to bf16
model.model.to(torch.bfloat16)
model.diffusion_autocast_dtype = torch.bfloat16
print(" ✓ Diffusion model weights converted to bfloat16")
else:
model.diffusion_autocast_dtype = torch.bfloat16
print(" ✓ Diffusion model using fp32")
# 2. Set Projector precision
if args.projector_mode == "bf16_full":
model.state_projector.to(torch.bfloat16)
model.action_projector.to(torch.bfloat16)
model.projector_autocast_dtype = None
print(" ✓ Projectors converted to bfloat16")
elif args.projector_mode == "autocast":
model.projector_autocast_dtype = torch.bfloat16
print(" ✓ Projectors will use autocast (weights fp32, compute bf16)")
else:
model.projector_autocast_dtype = None
# fp32 mode: do nothing, keep original precision
# 3. Set Encoder precision
if args.encoder_mode == "bf16_full":
model.embedder.to(torch.bfloat16)
model.image_proj_model.to(torch.bfloat16)
model.encoder_autocast_dtype = None
print(" ✓ Encoders converted to bfloat16")
elif args.encoder_mode == "autocast":
model.encoder_autocast_dtype = torch.bfloat16
print(" ✓ Encoders will use autocast (weights fp32, compute bf16)")
else:
model.encoder_autocast_dtype = None
# fp32 mode: do nothing, keep original precision
# 4. Set VAE precision
if args.vae_dtype == "bf16":
model.first_stage_model.to(torch.bfloat16)
print(" ✓ VAE converted to bfloat16")
else:
print(" ✓ VAE kept in fp32 for best quality")
# 5. Safety net: ensure no fp32 parameters remain when all components are bf16
if args.diffusion_dtype == "bf16":
fp32_params = [(n, p) for n, p in model.named_parameters() if p.dtype == torch.float32]
if fp32_params:
print(f" ⚠ Found {len(fp32_params)} fp32 params, converting to bf16")
for name, param in fp32_params:
param.data = param.data.to(torch.bfloat16)
print(" ✓ All parameters converted to bfloat16")
return model
def apply_torch_compile(model, hot_indices=(5, 8, 9)):
"""Compile ResBlock._forward in the hottest output_blocks for operator fusion."""
from unifolm_wma.modules.networks.wma_model import ResBlock
unet = model.model.diffusion_model
compiled = 0
for idx in hot_indices:
block = unet.output_blocks[idx]
for layer in block:
if isinstance(layer, ResBlock):
layer._forward = torch.compile(layer._forward, mode="default")
compiled += 1
print(f" ✓ torch.compile: {compiled} ResBlocks in output_blocks{list(hot_indices)}")
return model
def write_video(video_path: str, stacked_frames: list, fps: int) -> None:
"""Save a list of frames to a video file.
@@ -73,17 +263,18 @@ def get_filelist(data_dir: str, postfixes: list[str]) -> list[str]:
return file_list
def load_model_checkpoint(model: nn.Module, ckpt: str) -> nn.Module:
def load_model_checkpoint(model: nn.Module, ckpt: str, device: str = "cpu") -> nn.Module:
"""Load model weights from checkpoint file.
Args:
model (nn.Module): Model instance.
ckpt (str): Path to the checkpoint file.
device (str): Target device for loaded tensors.
Returns:
nn.Module: Model with loaded weights.
"""
state_dict = torch.load(ckpt, map_location="cpu")
state_dict = torch.load(ckpt, map_location=device)
if "state_dict" in list(state_dict.keys()):
state_dict = state_dict["state_dict"]
try:
@@ -262,6 +453,11 @@ def get_latent_z(model, videos: Tensor) -> Tensor:
"""
b, c, t, h, w = videos.shape
x = rearrange(videos, 'b c t h w -> (b t) c h w')
# Auto-detect VAE dtype and convert input
vae_dtype = next(model.first_stage_model.parameters()).dtype
x = x.to(dtype=vae_dtype)
z = model.encode_first_stage(x)
z = rearrange(z, '(b t) c h w -> b c t h w', b=b, t=t)
return z
@@ -327,7 +523,8 @@ def image_guided_synthesis_sim_mode(
timestep_spacing: str = 'uniform',
guidance_rescale: float = 0.0,
sim_mode: bool = True,
**kwargs) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
decode_video: bool = True,
**kwargs) -> tuple[torch.Tensor | None, torch.Tensor, torch.Tensor]:
"""
Performs image-guided video generation in a simulation-style mode with optional multimodal guidance (image, state, action, text).
@@ -350,10 +547,13 @@ def image_guided_synthesis_sim_mode(
timestep_spacing (str): Timestep sampling method in DDIM sampler. Typically "uniform" or "linspace".
guidance_rescale (float): Guidance rescaling factor to mitigate overexposure from classifier-free guidance.
sim_mode (bool): Whether to perform world-model interaction or decision-making using the world-model.
decode_video (bool): Whether to decode latent samples to pixel-space video.
Set to False to skip VAE decode for speed when only actions/states are needed.
**kwargs: Additional arguments passed to the DDIM sampler.
Returns:
batch_variants (torch.Tensor): Predicted pixel-space video frames [B, C, T, H, W].
batch_variants (torch.Tensor | None): Predicted pixel-space video frames [B, C, T, H, W],
or None when decode_video=False.
actions (torch.Tensor): Predicted action sequences [B, T, D] from diffusion decoding.
states (torch.Tensor): Predicted state sequences [B, T, D] from diffusion decoding.
"""
@@ -363,10 +563,22 @@ def image_guided_synthesis_sim_mode(
fs = torch.tensor([fs] * batch_size, dtype=torch.long, device=model.device)
# Auto-detect model dtype and convert inputs accordingly
model_dtype = next(model.embedder.parameters()).dtype
img = observation['observation.images.top'].permute(0, 2, 1, 3, 4)
cond_img = rearrange(img, 'b o c h w -> (b o) c h w')[-1:]
cond_img_emb = model.embedder(cond_img)
cond_img_emb = model.image_proj_model(cond_img_emb)
cond_img = rearrange(img, 'b o c h w -> (b o) c h w')[-1:].to(dtype=model_dtype)
# Encoder autocast: weights stay fp32, compute in bf16
enc_ac_dtype = getattr(model, 'encoder_autocast_dtype', None)
if enc_ac_dtype is not None and model.device.type == 'cuda':
enc_ctx = torch.autocast('cuda', dtype=enc_ac_dtype)
else:
enc_ctx = nullcontext()
with enc_ctx:
cond_img_emb = model.embedder(cond_img)
cond_img_emb = model.image_proj_model(cond_img_emb)
if model.model.conditioning_key == 'hybrid':
z = get_latent_z(model, img.permute(0, 2, 1, 3, 4))
@@ -380,11 +592,22 @@ def image_guided_synthesis_sim_mode(
prompts = [""] * batch_size
cond_ins_emb = model.get_learned_conditioning(prompts)
cond_state_emb = model.state_projector(observation['observation.state'])
cond_state_emb = cond_state_emb + model.agent_state_pos_emb
# Auto-detect projector dtype and convert inputs
projector_dtype = next(model.state_projector.parameters()).dtype
cond_action_emb = model.action_projector(observation['action'])
cond_action_emb = cond_action_emb + model.agent_action_pos_emb
# Projector autocast: weights stay fp32, compute in bf16
proj_ac_dtype = getattr(model, 'projector_autocast_dtype', None)
if proj_ac_dtype is not None and model.device.type == 'cuda':
proj_ctx = torch.autocast('cuda', dtype=proj_ac_dtype)
else:
proj_ctx = nullcontext()
with proj_ctx:
cond_state_emb = model.state_projector(observation['observation.state'].to(dtype=projector_dtype))
cond_state_emb = cond_state_emb + model.agent_state_pos_emb
cond_action_emb = model.action_projector(observation['action'].to(dtype=projector_dtype))
cond_action_emb = cond_action_emb + model.agent_action_pos_emb
if not sim_mode:
cond_action_emb = torch.zeros_like(cond_action_emb)
@@ -406,8 +629,18 @@ def image_guided_synthesis_sim_mode(
kwargs.update({"unconditional_conditioning_img_nonetext": None})
cond_mask = None
cond_z0 = None
# Setup autocast context for diffusion sampling
autocast_dtype = getattr(model, 'diffusion_autocast_dtype', None)
if autocast_dtype is not None and model.device.type == 'cuda':
autocast_ctx = torch.autocast('cuda', dtype=autocast_dtype)
else:
autocast_ctx = nullcontext()
batch_variants = None
if ddim_sampler is not None:
samples, actions, states, intermedia = ddim_sampler.sample(
with autocast_ctx:
samples, actions, states, intermedia = ddim_sampler.sample(
S=ddim_steps,
conditioning=cond,
batch_size=batch_size,
@@ -424,9 +657,10 @@ def image_guided_synthesis_sim_mode(
guidance_rescale=guidance_rescale,
**kwargs)
# Reconstruct from latent to pixel space
batch_images = model.decode_first_stage(samples)
batch_variants = batch_images
if decode_video:
# Reconstruct from latent to pixel space
batch_images = model.decode_first_stage(samples)
batch_variants = batch_images
return batch_variants, actions, states
@@ -455,22 +689,63 @@ def run_inference(args: argparse.Namespace, gpu_num: int, gpu_no: int) -> None:
# Load config
config = OmegaConf.load(args.config)
config['model']['params']['wma_config']['params'][
'use_checkpoint'] = False
model = instantiate_from_config(config.model)
model.perframe_ae = args.perframe_ae
assert os.path.exists(args.ckpt_path), "Error: checkpoint Not Found!"
model = load_model_checkpoint(model, args.ckpt_path)
model.eval()
print(f'>>> Load pre-trained model ...')
# Build unnomalizer
prepared_path = args.ckpt_path + ".prepared.pt"
if os.path.exists(prepared_path):
# ---- Fast path: load the fully-prepared model ----
print(f">>> Loading prepared model from {prepared_path} ...")
model = torch.load(prepared_path,
map_location=f"cuda:{gpu_no}",
weights_only=False)
model.eval()
# Restore autocast attributes (weights already cast, just need contexts)
model.diffusion_autocast_dtype = torch.bfloat16 if args.diffusion_dtype == "bf16" else torch.bfloat16
model.projector_autocast_dtype = torch.bfloat16 if args.projector_mode == "autocast" else None
model.encoder_autocast_dtype = torch.bfloat16 if args.encoder_mode == "autocast" else None
# Compile hot ResBlocks for operator fusion
apply_torch_compile(model)
print(f">>> Prepared model loaded.")
else:
# ---- Normal path: construct + checkpoint + casting ----
config['model']['params']['wma_config']['params'][
'use_checkpoint'] = False
model = instantiate_from_config(config.model)
model.perframe_ae = args.perframe_ae
assert os.path.exists(args.ckpt_path), "Error: checkpoint Not Found!"
model = load_model_checkpoint(model, args.ckpt_path,
device=f"cuda:{gpu_no}")
model.eval()
print(f'>>> Load pre-trained model ...')
# Apply precision settings before moving to GPU
model = apply_precision_settings(model, args)
# Export precision-converted checkpoint if requested
if args.export_precision_ckpt:
export_path = args.export_precision_ckpt
os.makedirs(os.path.dirname(export_path) or '.', exist_ok=True)
torch.save({"state_dict": model.state_dict()}, export_path)
print(f">>> Precision-converted checkpoint saved to: {export_path}")
return
model = model.cuda(gpu_no)
# Save prepared model for fast loading next time (before torch.compile)
print(f">>> Saving prepared model to {prepared_path} ...")
torch.save(model, prepared_path)
print(f">>> Prepared model saved ({os.path.getsize(prepared_path) / 1024**3:.1f} GB).")
# Compile hot ResBlocks for operator fusion (after save, compiled objects can't be pickled)
apply_torch_compile(model)
# Build normalizer (always needed, independent of model loading path)
logging.info("***** Configing Data *****")
data = instantiate_from_config(config.data)
data.setup()
print(">>> Dataset is successfully loaded ...")
model = model.cuda(gpu_no)
device = get_device_from_parameters(model)
# Run over data
@@ -587,7 +862,8 @@ def run_inference(args: argparse.Namespace, gpu_num: int, gpu_no: int) -> None:
fs=model_input_fs,
timestep_spacing=args.timestep_spacing,
guidance_rescale=args.guidance_rescale,
sim_mode=False)
sim_mode=False,
decode_video=not args.fast_policy_no_decode)
# Update future actions in the observation queues
for idx in range(len(pred_actions[0])):
@@ -645,28 +921,30 @@ def run_inference(args: argparse.Namespace, gpu_num: int, gpu_no: int) -> None:
observation)
# Save the imagen videos for decision-making
sample_tag = f"{args.dataset}-vid{sample['videoid']}-dm-fs-{fs}/itr-{itr}"
log_to_tensorboard(writer,
pred_videos_0,
sample_tag,
fps=args.save_fps)
if pred_videos_0 is not None:
sample_tag = f"{args.dataset}-vid{sample['videoid']}-dm-fs-{fs}/itr-{itr}"
log_to_tensorboard_async(writer,
pred_videos_0,
sample_tag,
fps=args.save_fps)
# Save videos environment changes via world-model interaction
sample_tag = f"{args.dataset}-vid{sample['videoid']}-wd-fs-{fs}/itr-{itr}"
log_to_tensorboard(writer,
pred_videos_1,
sample_tag,
fps=args.save_fps)
log_to_tensorboard_async(writer,
pred_videos_1,
sample_tag,
fps=args.save_fps)
# Save the imagen videos for decision-making
sample_video_file = f'{video_save_dir}/dm/{fs}/itr-{itr}.mp4'
save_results(pred_videos_0.cpu(),
sample_video_file,
fps=args.save_fps)
if pred_videos_0 is not None:
sample_video_file = f'{video_save_dir}/dm/{fs}/itr-{itr}.mp4'
save_results_async(pred_videos_0,
sample_video_file,
fps=args.save_fps)
# Save videos environment changes via world-model interaction
sample_video_file = f'{video_save_dir}/wm/{fs}/itr-{itr}.mp4'
save_results(pred_videos_1.cpu(),
sample_video_file,
fps=args.save_fps)
save_results_async(pred_videos_1,
sample_video_file,
fps=args.save_fps)
print('>' * 24)
# Collect the result of world-model interactions
@@ -674,12 +952,15 @@ def run_inference(args: argparse.Namespace, gpu_num: int, gpu_no: int) -> None:
full_video = torch.cat(wm_video, dim=2)
sample_tag = f"{args.dataset}-vid{sample['videoid']}-wd-fs-{fs}/full"
log_to_tensorboard(writer,
full_video,
sample_tag,
fps=args.save_fps)
log_to_tensorboard_async(writer,
full_video,
sample_tag,
fps=args.save_fps)
sample_full_video_file = f"{video_save_dir}/../{sample['videoid']}_full_fs{fs}.mp4"
save_results(full_video, sample_full_video_file, fps=args.save_fps)
save_results_async(full_video, sample_full_video_file, fps=args.save_fps)
# Wait for all async I/O to complete
_flush_io()
def get_parser():
@@ -794,14 +1075,49 @@ def get_parser():
action='store_true',
default=False,
help="not using the predicted states as comparison")
parser.add_argument(
"--fast_policy_no_decode",
action='store_true',
default=False,
help="Speed mode: policy pass only predicts actions, skip policy video decode/log/save.")
parser.add_argument("--save_fps",
type=int,
default=8,
help="fps for the saving video")
parser.add_argument(
"--diffusion_dtype",
type=str,
choices=["fp32", "bf16"],
default="bf16",
help="Diffusion backbone precision (fp32/bf16)")
parser.add_argument(
"--projector_mode",
type=str,
choices=["fp32", "autocast", "bf16_full"],
default="bf16_full",
help="Projector precision mode (fp32/autocast/bf16_full)")
parser.add_argument(
"--encoder_mode",
type=str,
choices=["fp32", "autocast", "bf16_full"],
default="bf16_full",
help="Encoder precision mode (fp32/autocast/bf16_full)")
parser.add_argument(
"--vae_dtype",
type=str,
choices=["fp32", "bf16"],
default="fp32",
help="VAE precision (fp32/bf16, most affects image quality)")
parser.add_argument(
"--export_precision_ckpt",
type=str,
default=None,
help="Export precision-converted checkpoint to this path, then exit.")
return parser
if __name__ == '__main__':
patch_norm_bypass_autocast()
parser = get_parser()
args = parser.parse_args()
seed = args.seed

View File

@@ -99,6 +99,8 @@ class AutoencoderKL(pl.LightningModule):
print(f"Restored from {path}")
def encode(self, x, **kwargs):
if getattr(self, '_channels_last', False):
x = x.to(memory_format=torch.channels_last)
h = self.encoder(x)
moments = self.quant_conv(h)
@@ -106,6 +108,8 @@ class AutoencoderKL(pl.LightningModule):
return posterior
def decode(self, z, **kwargs):
if getattr(self, '_channels_last', False):
z = z.to(memory_format=torch.channels_last)
z = self.post_quant_conv(z)
dec = self.decoder(z)
return dec

View File

@@ -1074,10 +1074,10 @@ class LatentDiffusion(DDPM):
encoder_posterior = self.first_stage_model.encode(x)
results = self.get_first_stage_encoding(encoder_posterior).detach()
else: ## Consume less GPU memory but slower
bs = getattr(self, 'vae_encode_bs', 1)
results = []
for index in range(x.shape[0]):
frame_batch = self.first_stage_model.encode(x[index:index +
1, :, :, :])
for i in range(0, x.shape[0], bs):
frame_batch = self.first_stage_model.encode(x[i:i + bs])
frame_result = self.get_first_stage_encoding(
frame_batch).detach()
results.append(frame_result)
@@ -1105,14 +1105,18 @@ class LatentDiffusion(DDPM):
else:
reshape_back = False
# Align input dtype with VAE weights (e.g. fp32 samples → bf16 VAE)
vae_dtype = next(self.first_stage_model.parameters()).dtype
z = z.to(dtype=vae_dtype)
z = 1. / self.scale_factor * z
if not self.perframe_ae:
z = 1. / self.scale_factor * z
results = self.first_stage_model.decode(z, **kwargs)
else:
bs = getattr(self, 'vae_decode_bs', 1)
results = []
for index in range(z.shape[0]):
frame_z = 1. / self.scale_factor * z[index:index + 1, :, :, :]
frame_result = self.first_stage_model.decode(frame_z, **kwargs)
for i in range(0, z.shape[0], bs):
frame_result = self.first_stage_model.decode(z[i:i + bs], **kwargs)
results.append(frame_result)
results = torch.cat(results, dim=0)
@@ -1799,7 +1803,9 @@ class LatentDiffusion(DDPM):
"""
if ddim:
ddim_sampler = DDIMSampler(self)
if not hasattr(self, '_ddim_sampler') or self._ddim_sampler is None:
self._ddim_sampler = DDIMSampler(self)
ddim_sampler = self._ddim_sampler
shape = (self.channels, self.temporal_length, *self.image_size)
samples, actions, states, intermediates = ddim_sampler.sample(
ddim_steps, batch_size, shape, cond, verbose=False, **kwargs)
@@ -2457,7 +2463,6 @@ class DiffusionWrapper(pl.LightningModule):
Returns:
Output from the inner diffusion model (tensor or tuple, depending on the model).
"""
if self.conditioning_key is None:
out = self.diffusion_model(x, t)
elif self.conditioning_key == 'concat':

View File

@@ -501,6 +501,10 @@ class ConditionalUnet1D(nn.Module):
self.last_frame_only = last_frame_only
self.horizon = horizon
# Context precomputation cache
self._global_cond_cache_enabled = False
self._global_cond_cache = {}
def forward(self,
sample: torch.Tensor,
timestep: Union[torch.Tensor, float, int],
@@ -530,14 +534,20 @@ class ConditionalUnet1D(nn.Module):
B, T, D = sample.shape
if self.use_linear_act_proj:
sample = self.proj_in_action(sample.unsqueeze(-1))
global_cond = self.obs_encoder(cond)
global_cond = rearrange(global_cond,
'(b t) d -> b 1 (t d)',
b=B,
t=self.n_obs_steps)
global_cond = repeat(global_cond,
'b c d -> b (repeat c) d',
repeat=T)
_gc_key = (cond['image'].data_ptr(), cond['agent_pos'].data_ptr())
if self._global_cond_cache_enabled and _gc_key in self._global_cond_cache:
global_cond = self._global_cond_cache[_gc_key]
else:
global_cond = self.obs_encoder(cond)
global_cond = rearrange(global_cond,
'(b t) d -> b 1 (t d)',
b=B,
t=self.n_obs_steps)
global_cond = repeat(global_cond,
'b c d -> b (repeat c) d',
repeat=T)
if self._global_cond_cache_enabled:
self._global_cond_cache[_gc_key] = global_cond
else:
sample = einops.rearrange(sample, 'b h t -> b t h')
sample = self.proj_in_horizon(sample)

View File

@@ -8,12 +8,14 @@ class SinusoidalPosEmb(nn.Module):
def __init__(self, dim):
super().__init__()
self.dim = dim
# Dummy buffer so .to(dtype) propagates to this module
self.register_buffer('_dtype_buf', torch.zeros(1), persistent=False)
def forward(self, x):
device = x.device
half_dim = self.dim // 2
emb = math.log(10000) / (half_dim - 1)
emb = torch.exp(torch.arange(half_dim, device=device) * -emb)
emb = x[:, None] * emb[None, :]
emb = x.float()[:, None] * emb[None, :]
emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
return emb
return emb.to(self._dtype_buf.dtype)

View File

@@ -0,0 +1,7 @@
{
"permissions": {
"allow": [
"Bash(python3:*)"
]
}
}

View File

@@ -6,6 +6,8 @@ from unifolm_wma.utils.diffusion import make_ddim_sampling_parameters, make_ddim
from unifolm_wma.utils.common import noise_like
from unifolm_wma.utils.common import extract_into_tensor
from tqdm import tqdm
from unifolm_wma.modules.attention import enable_cross_attn_kv_cache, disable_cross_attn_kv_cache
from unifolm_wma.modules.networks.wma_model import enable_ctx_cache, disable_ctx_cache
class DDIMSampler(object):
@@ -16,6 +18,7 @@ class DDIMSampler(object):
self.ddpm_num_timesteps = model.num_timesteps
self.schedule = schedule
self.counter = 0
self._schedule_key = None # (ddim_num_steps, ddim_discretize, ddim_eta)
def register_buffer(self, name, attr):
if type(attr) == torch.Tensor:
@@ -28,6 +31,11 @@ class DDIMSampler(object):
ddim_discretize="uniform",
ddim_eta=0.,
verbose=True):
key = (ddim_num_steps, ddim_discretize, ddim_eta)
if self._schedule_key == key:
return
self._schedule_key = key
self.ddim_timesteps = make_ddim_timesteps(
ddim_discr_method=ddim_discretize,
num_ddim_timesteps=ddim_num_steps,
@@ -36,7 +44,7 @@ class DDIMSampler(object):
alphas_cumprod = self.model.alphas_cumprod
assert alphas_cumprod.shape[
0] == self.ddpm_num_timesteps, 'alphas have to be defined for each timestep'
to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.model
to_torch = lambda x: x.clone().detach().to(torch.float64).to(self.model
.device)
if self.model.use_dynamic_rescale:
@@ -67,11 +75,12 @@ class DDIMSampler(object):
ddim_timesteps=self.ddim_timesteps,
eta=ddim_eta,
verbose=verbose)
self.register_buffer('ddim_sigmas', ddim_sigmas)
self.register_buffer('ddim_alphas', ddim_alphas)
self.register_buffer('ddim_alphas_prev', ddim_alphas_prev)
# Ensure tensors are on correct device for efficient indexing
self.register_buffer('ddim_sigmas', to_torch(torch.as_tensor(ddim_sigmas)))
self.register_buffer('ddim_alphas', to_torch(torch.as_tensor(ddim_alphas)))
self.register_buffer('ddim_alphas_prev', to_torch(torch.as_tensor(ddim_alphas_prev)))
self.register_buffer('ddim_sqrt_one_minus_alphas',
np.sqrt(1. - ddim_alphas))
to_torch(torch.as_tensor(np.sqrt(1. - ddim_alphas))))
sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt(
(1 - self.alphas_cumprod_prev) / (1 - self.alphas_cumprod) *
(1 - self.alphas_cumprod / self.alphas_cumprod_prev))
@@ -208,9 +217,9 @@ class DDIMSampler(object):
if precision is not None:
if precision == 16:
img = img.to(dtype=torch.float16)
action = action.to(dtype=torch.float16)
state = state.to(dtype=torch.float16)
img = img.to(dtype=torch.bfloat16)
action = action.to(dtype=torch.bfloat16)
state = state.to(dtype=torch.bfloat16)
if timesteps is None:
timesteps = self.ddpm_num_timesteps if ddim_use_original_steps else self.ddim_timesteps
@@ -241,63 +250,70 @@ class DDIMSampler(object):
dp_ddim_scheduler_action.set_timesteps(len(timesteps))
dp_ddim_scheduler_state.set_timesteps(len(timesteps))
for i, step in enumerate(iterator):
index = total_steps - i - 1
ts = torch.full((b, ), step, device=device, dtype=torch.long)
ts = torch.empty((b, ), device=device, dtype=torch.long)
enable_cross_attn_kv_cache(self.model)
enable_ctx_cache(self.model)
try:
for i, step in enumerate(iterator):
index = total_steps - i - 1
ts.fill_(step)
# Use mask to blend noised original latent (img_orig) & new sampled latent (img)
if mask is not None:
assert x0 is not None
if clean_cond:
img_orig = x0
else:
img_orig = self.model.q_sample(x0, ts)
img = img_orig * mask + (1. - mask) * img
# Use mask to blend noised original latent (img_orig) & new sampled latent (img)
if mask is not None:
assert x0 is not None
if clean_cond:
img_orig = x0
else:
img_orig = self.model.q_sample(x0, ts)
img = img_orig * mask + (1. - mask) * img
outs = self.p_sample_ddim(
img,
action,
state,
cond,
ts,
index=index,
use_original_steps=ddim_use_original_steps,
quantize_denoised=quantize_denoised,
temperature=temperature,
noise_dropout=noise_dropout,
score_corrector=score_corrector,
corrector_kwargs=corrector_kwargs,
unconditional_guidance_scale=unconditional_guidance_scale,
unconditional_conditioning=unconditional_conditioning,
mask=mask,
x0=x0,
fs=fs,
guidance_rescale=guidance_rescale,
**kwargs)
outs = self.p_sample_ddim(
img,
action,
state,
cond,
ts,
index=index,
use_original_steps=ddim_use_original_steps,
quantize_denoised=quantize_denoised,
temperature=temperature,
noise_dropout=noise_dropout,
score_corrector=score_corrector,
corrector_kwargs=corrector_kwargs,
unconditional_guidance_scale=unconditional_guidance_scale,
unconditional_conditioning=unconditional_conditioning,
mask=mask,
x0=x0,
fs=fs,
guidance_rescale=guidance_rescale,
**kwargs)
img, pred_x0, model_output_action, model_output_state = outs
img, pred_x0, model_output_action, model_output_state = outs
action = dp_ddim_scheduler_action.step(
model_output_action,
step,
action,
generator=None,
).prev_sample
state = dp_ddim_scheduler_state.step(
model_output_state,
step,
state,
generator=None,
).prev_sample
action = dp_ddim_scheduler_action.step(
model_output_action,
step,
action,
generator=None,
).prev_sample
state = dp_ddim_scheduler_state.step(
model_output_state,
step,
state,
generator=None,
).prev_sample
if callback: callback(i)
if img_callback: img_callback(pred_x0, i)
if callback: callback(i)
if img_callback: img_callback(pred_x0, i)
if index % log_every_t == 0 or index == total_steps - 1:
intermediates['x_inter'].append(img)
intermediates['pred_x0'].append(pred_x0)
intermediates['x_inter_action'].append(action)
intermediates['x_inter_state'].append(state)
if index % log_every_t == 0 or index == total_steps - 1:
intermediates['x_inter'].append(img)
intermediates['pred_x0'].append(pred_x0)
intermediates['x_inter_action'].append(action)
intermediates['x_inter_state'].append(state)
finally:
disable_cross_attn_kv_cache(self.model)
disable_ctx_cache(self.model)
return img, action, state, intermediates
@@ -325,10 +341,6 @@ class DDIMSampler(object):
guidance_rescale=0.0,
**kwargs):
b, *_, device = *x.shape, x.device
if x.dim() == 5:
is_video = True
else:
is_video = False
if unconditional_conditioning is None or unconditional_guidance_scale == 1.:
model_output, model_output_action, model_output_state = self.model.apply_model(
@@ -377,17 +389,11 @@ class DDIMSampler(object):
sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas
sigmas = self.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas
if is_video:
size = (b, 1, 1, 1, 1)
else:
size = (b, 1, 1, 1)
a_t = torch.full(size, alphas[index], device=device)
a_prev = torch.full(size, alphas_prev[index], device=device)
sigma_t = torch.full(size, sigmas[index], device=device)
sqrt_one_minus_at = torch.full(size,
sqrt_one_minus_alphas[index],
device=device)
# Use 0-d tensors directly (already on device); broadcasting handles shape
a_t = alphas[index].to(x.dtype)
a_prev = alphas_prev[index].to(x.dtype)
sigma_t = sigmas[index].to(x.dtype)
sqrt_one_minus_at = sqrt_one_minus_alphas[index].to(x.dtype)
if self.model.parameterization != "v":
pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
@@ -395,12 +401,8 @@ class DDIMSampler(object):
pred_x0 = self.model.predict_start_from_z_and_v(x, t, model_output)
if self.model.use_dynamic_rescale:
scale_t = torch.full(size,
self.ddim_scale_arr[index],
device=device)
prev_scale_t = torch.full(size,
self.ddim_scale_arr_prev[index],
device=device)
scale_t = self.ddim_scale_arr[index]
prev_scale_t = self.ddim_scale_arr_prev[index]
rescale = (prev_scale_t / scale_t)
pred_x0 *= rescale

View File

@@ -86,9 +86,8 @@ class CrossAttention(nn.Module):
self.relative_position_v = RelativePosition(
num_units=dim_head, max_relative_position=temporal_length)
else:
## only used for spatial attention, while NOT for temporal attention
if XFORMERS_IS_AVAILBLE and temporal_length is None:
self.forward = self.efficient_forward
## bmm fused-scale attention for all non-relative-position cases
self.forward = self.bmm_forward
self.video_length = video_length
self.image_cross_attention = image_cross_attention
@@ -98,6 +97,9 @@ class CrossAttention(nn.Module):
self.text_context_len = text_context_len
self.agent_state_context_len = agent_state_context_len
self.agent_action_context_len = agent_action_context_len
self._kv_cache = {}
self._kv_cache_enabled = False
self.cross_attention_scale_learnable = cross_attention_scale_learnable
if self.image_cross_attention:
self.to_k_ip = nn.Linear(context_dim, inner_dim, bias=False)
@@ -125,7 +127,7 @@ class CrossAttention(nn.Module):
context = default(context, x)
if self.image_cross_attention and not spatial_self_attn:
assert 1 > 2, ">>> ERROR: should setup xformers and use efficient_forward ..."
# assert 1 > 2, ">>> ERROR: should setup xformers and use efficient_forward ..."
context_agent_state = context[:, :self.agent_state_context_len, :]
context_agent_action = context[:,
self.agent_state_context_len:self.
@@ -173,7 +175,8 @@ class CrossAttention(nn.Module):
sim.masked_fill_(~(mask > 0.5), max_neg_value)
# attention, what we cannot get enough of
sim = sim.softmax(dim=-1)
with torch.amp.autocast('cuda', enabled=False):
sim = sim.softmax(dim=-1)
out = torch.einsum('b i j, b j d -> b i d', sim, v)
if self.relative_position:
@@ -190,7 +193,8 @@ class CrossAttention(nn.Module):
sim_ip = torch.einsum('b i d, b j d -> b i j', q,
k_ip) * self.scale
del k_ip
sim_ip = sim_ip.softmax(dim=-1)
with torch.amp.autocast('cuda', enabled=False):
sim_ip = sim_ip.softmax(dim=-1)
out_ip = torch.einsum('b i j, b j d -> b i d', sim_ip, v_ip)
out_ip = rearrange(out_ip, '(b h) n d -> b n (h d)', h=h)
@@ -201,7 +205,8 @@ class CrossAttention(nn.Module):
sim_as = torch.einsum('b i d, b j d -> b i j', q,
k_as) * self.scale
del k_as
sim_as = sim_as.softmax(dim=-1)
with torch.amp.autocast('cuda', enabled=False):
sim_as = sim_as.softmax(dim=-1)
out_as = torch.einsum('b i j, b j d -> b i d', sim_as, v_as)
out_as = rearrange(out_as, '(b h) n d -> b n (h d)', h=h)
@@ -212,7 +217,8 @@ class CrossAttention(nn.Module):
sim_aa = torch.einsum('b i d, b j d -> b i j', q,
k_aa) * self.scale
del k_aa
sim_aa = sim_aa.softmax(dim=-1)
with torch.amp.autocast('cuda', enabled=False):
sim_aa = sim_aa.softmax(dim=-1)
out_aa = torch.einsum('b i j, b j d -> b i d', sim_aa, v_aa)
out_aa = rearrange(out_aa, '(b h) n d -> b n (h d)', h=h)
@@ -230,6 +236,141 @@ class CrossAttention(nn.Module):
return self.to_out(out)
def bmm_forward(self, x, context=None, mask=None):
spatial_self_attn = (context is None)
k_ip, v_ip, out_ip = None, None, None
k_as, v_as, out_as = None, None, None
k_aa, v_aa, out_aa = None, None, None
h = self.heads
q = self.to_q(x)
context = default(context, x)
use_cache = self._kv_cache_enabled and not spatial_self_attn
cache_hit = use_cache and len(self._kv_cache) > 0
if cache_hit:
# Reuse cached K/V (already in (b*h, n, d) shape)
k = self._kv_cache['k']
v = self._kv_cache['v']
if 'k_ip' in self._kv_cache:
k_ip = self._kv_cache['k_ip']
v_ip = self._kv_cache['v_ip']
k_as = self._kv_cache['k_as']
v_as = self._kv_cache['v_as']
k_aa = self._kv_cache['k_aa']
v_aa = self._kv_cache['v_aa']
q = rearrange(q, 'b n (h d) -> (b h) n d', h=h)
elif self.image_cross_attention and not spatial_self_attn:
context_agent_state = context[:, :self.agent_state_context_len, :]
context_agent_action = context[:,
self.agent_state_context_len:self.
agent_state_context_len +
self.agent_action_context_len, :]
context_ins = context[:, self.agent_state_context_len +
self.agent_action_context_len:self.
agent_state_context_len +
self.agent_action_context_len +
self.text_context_len, :]
context_image = context[:, self.agent_state_context_len +
self.agent_action_context_len +
self.text_context_len:, :]
k = self.to_k(context_ins)
v = self.to_v(context_ins)
k_ip = self.to_k_ip(context_image)
v_ip = self.to_v_ip(context_image)
k_as = self.to_k_as(context_agent_state)
v_as = self.to_v_as(context_agent_state)
k_aa = self.to_k_aa(context_agent_action)
v_aa = self.to_v_aa(context_agent_action)
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h),
(q, k, v))
k_ip, v_ip = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h),
(k_ip, v_ip))
k_as, v_as = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h),
(k_as, v_as))
k_aa, v_aa = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h),
(k_aa, v_aa))
if use_cache:
self._kv_cache = {
'k': k, 'v': v,
'k_ip': k_ip, 'v_ip': v_ip,
'k_as': k_as, 'v_as': v_as,
'k_aa': k_aa, 'v_aa': v_aa,
}
else:
if not spatial_self_attn:
context = context[:, :self.text_context_len, :]
k = self.to_k(context)
v = self.to_v(context)
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h),
(q, k, v))
if use_cache:
self._kv_cache = {'k': k, 'v': v}
# baddbmm: fuse scale into GEMM → one kernel instead of matmul + mul
sim = torch.baddbmm(
torch.empty(q.shape[0], q.shape[1], k.shape[1], dtype=q.dtype, device=q.device),
q, k.transpose(-1, -2), beta=0, alpha=self.scale)
if exists(mask):
max_neg_value = -torch.finfo(sim.dtype).max
mask = repeat(mask, 'b i j -> (b h) i j', h=h)
sim.masked_fill_(~(mask > 0.5), max_neg_value)
with torch.amp.autocast('cuda', enabled=False):
sim = sim.softmax(dim=-1)
out = torch.bmm(sim, v)
out = rearrange(out, '(b h) n d -> b n (h d)', h=h)
if k_ip is not None and k_as is not None and k_aa is not None:
## image cross-attention (k_ip/v_ip already in (b*h, n, d) shape)
sim_ip = torch.baddbmm(
torch.empty(q.shape[0], q.shape[1], k_ip.shape[1], dtype=q.dtype, device=q.device),
q, k_ip.transpose(-1, -2), beta=0, alpha=self.scale)
with torch.amp.autocast('cuda', enabled=False):
sim_ip = sim_ip.softmax(dim=-1)
out_ip = torch.bmm(sim_ip, v_ip)
out_ip = rearrange(out_ip, '(b h) n d -> b n (h d)', h=h)
## agent state cross-attention (k_as/v_as already in (b*h, n, d) shape)
sim_as = torch.baddbmm(
torch.empty(q.shape[0], q.shape[1], k_as.shape[1], dtype=q.dtype, device=q.device),
q, k_as.transpose(-1, -2), beta=0, alpha=self.scale)
with torch.amp.autocast('cuda', enabled=False):
sim_as = sim_as.softmax(dim=-1)
out_as = torch.bmm(sim_as, v_as)
out_as = rearrange(out_as, '(b h) n d -> b n (h d)', h=h)
## agent action cross-attention (k_aa/v_aa already in (b*h, n, d) shape)
sim_aa = torch.baddbmm(
torch.empty(q.shape[0], q.shape[1], k_aa.shape[1], dtype=q.dtype, device=q.device),
q, k_aa.transpose(-1, -2), beta=0, alpha=self.scale)
with torch.amp.autocast('cuda', enabled=False):
sim_aa = sim_aa.softmax(dim=-1)
out_aa = torch.bmm(sim_aa, v_aa)
out_aa = rearrange(out_aa, '(b h) n d -> b n (h d)', h=h)
if out_ip is not None and out_as is not None and out_aa is not None:
if self.cross_attention_scale_learnable:
out = out + \
self.image_cross_attention_scale * out_ip * (torch.tanh(self.alpha_ctx) + 1) + \
self.agent_state_cross_attention_scale * out_as * (torch.tanh(self.alpha_cas) + 1) + \
self.agent_action_cross_attention_scale * out_aa * (torch.tanh(self.alpha_caa) + 1)
else:
out = out + \
self.image_cross_attention_scale * out_ip + \
self.agent_state_cross_attention_scale * out_as + \
self.agent_action_cross_attention_scale * out_aa
return self.to_out(out)
def efficient_forward(self, x, context=None, mask=None):
spatial_self_attn = (context is None)
k, v, out = None, None, None
@@ -275,7 +416,8 @@ class CrossAttention(nn.Module):
attn_mask_aa = self._get_attn_mask_aa(x.shape[0],
q.shape[1],
k_aa.shape[1],
block_size=16).to(k_aa.device)
block_size=16,
device=k_aa.device)
else:
if not spatial_self_attn:
assert 1 > 2, ">>> ERROR: you should never go into here ..."
@@ -386,17 +528,43 @@ class CrossAttention(nn.Module):
return self.to_out(out)
def _get_attn_mask_aa(self, b, l1, l2, block_size=16):
def _get_attn_mask_aa(self, b, l1, l2, block_size=16, device=None):
cache_key = (b, l1, l2, block_size)
if hasattr(self, '_attn_mask_aa_cache_key') and self._attn_mask_aa_cache_key == cache_key:
cached = self._attn_mask_aa_cache
if device is not None and cached.device != torch.device(device):
cached = cached.to(device)
self._attn_mask_aa_cache = cached
return cached
target_device = device if device is not None else 'cpu'
num_token = l2 // block_size
start_positions = ((torch.arange(b) % block_size) + 1) * num_token
col_indices = torch.arange(l2)
start_positions = ((torch.arange(b, device=target_device) % block_size) + 1) * num_token
col_indices = torch.arange(l2, device=target_device)
mask_2d = col_indices.unsqueeze(0) >= start_positions.unsqueeze(1)
mask = mask_2d.unsqueeze(1).expand(b, l1, l2)
attn_mask = torch.zeros_like(mask, dtype=torch.float)
attn_mask = torch.zeros(b, l1, l2, dtype=torch.bfloat16, device=target_device)
attn_mask[mask] = float('-inf')
self._attn_mask_aa_cache_key = cache_key
self._attn_mask_aa_cache = attn_mask
return attn_mask
def enable_cross_attn_kv_cache(module):
for m in module.modules():
if isinstance(m, CrossAttention):
m._kv_cache_enabled = True
m._kv_cache = {}
def disable_cross_attn_kv_cache(module):
for m in module.modules():
if isinstance(m, CrossAttention):
m._kv_cache_enabled = False
m._kv_cache = {}
class BasicTransformerBlock(nn.Module):
def __init__(self,

View File

@@ -11,7 +11,7 @@ from unifolm_wma.utils.utils import instantiate_from_config
def nonlinearity(x):
# swish
return x * torch.sigmoid(x)
return torch.nn.functional.silu(x)
def Normalize(in_channels, num_groups=32):

View File

@@ -422,7 +422,7 @@ class WMAModel(nn.Module):
self.temporal_attention = temporal_attention
time_embed_dim = model_channels * 4
self.use_checkpoint = use_checkpoint
self.dtype = torch.float16 if use_fp16 else torch.float32
self.dtype = torch.float16 if use_fp16 else torch.bfloat16
temporal_self_att_only = True
self.addition_attention = addition_attention
self.temporal_length = temporal_length
@@ -685,6 +685,12 @@ class WMAModel(nn.Module):
self.action_token_projector = instantiate_from_config(
stem_process_config)
# Context precomputation cache
self._ctx_cache_enabled = False
self._ctx_cache = {}
# fs_embed cache
self._fs_embed_cache = None
def forward(self,
x: Tensor,
x_action: Tensor,
@@ -720,58 +726,64 @@ class WMAModel(nn.Module):
repeat_only=False).type(x.dtype)
emb = self.time_embed(t_emb)
bt, l_context, _ = context.shape
if self.base_model_gen_only:
assert l_context == 77 + self.n_obs_steps * 16, ">>> ERROR Context dim 1 ..." ## NOTE HANDCODE
_ctx_key = context.data_ptr()
if self._ctx_cache_enabled and _ctx_key in self._ctx_cache:
context = self._ctx_cache[_ctx_key]
else:
if l_context == self.n_obs_steps + 77 + t * 16:
context_agent_state = context[:, :self.n_obs_steps]
context_text = context[:, self.n_obs_steps:self.n_obs_steps +
77, :]
context_img = context[:, self.n_obs_steps + 77:, :]
context_agent_state = context_agent_state.repeat_interleave(
repeats=t, dim=0)
context_text = context_text.repeat_interleave(repeats=t, dim=0)
context_img = rearrange(context_img,
'b (t l) c -> (b t) l c',
t=t)
context = torch.cat(
[context_agent_state, context_text, context_img], dim=1)
elif l_context == self.n_obs_steps + 16 + 77 + t * 16:
context_agent_state = context[:, :self.n_obs_steps]
context_agent_action = context[:, self.
n_obs_steps:self.n_obs_steps +
16, :]
context_agent_action = rearrange(
context_agent_action.unsqueeze(2), 'b t l d -> (b t) l d')
context_agent_action = self.action_token_projector(
context_agent_action)
context_agent_action = rearrange(context_agent_action,
'(b o) l d -> b o l d',
o=t)
context_agent_action = rearrange(context_agent_action,
'b o (t l) d -> b o t l d',
t=t)
context_agent_action = context_agent_action.permute(
0, 2, 1, 3, 4)
context_agent_action = rearrange(context_agent_action,
'b t o l d -> (b t) (o l) d')
bt, l_context, _ = context.shape
if self.base_model_gen_only:
assert l_context == 77 + self.n_obs_steps * 16, ">>> ERROR Context dim 1 ..." ## NOTE HANDCODE
else:
if l_context == self.n_obs_steps + 77 + t * 16:
context_agent_state = context[:, :self.n_obs_steps]
context_text = context[:, self.n_obs_steps:self.n_obs_steps +
77, :]
context_img = context[:, self.n_obs_steps + 77:, :]
context_agent_state = context_agent_state.repeat_interleave(
repeats=t, dim=0)
context_text = context_text.repeat_interleave(repeats=t, dim=0)
context_img = rearrange(context_img,
'b (t l) c -> (b t) l c',
t=t)
context = torch.cat(
[context_agent_state, context_text, context_img], dim=1)
elif l_context == self.n_obs_steps + 16 + 77 + t * 16:
context_agent_state = context[:, :self.n_obs_steps]
context_agent_action = context[:, self.
n_obs_steps:self.n_obs_steps +
16, :]
context_agent_action = rearrange(
context_agent_action.unsqueeze(2), 'b t l d -> (b t) l d')
context_agent_action = self.action_token_projector(
context_agent_action)
context_agent_action = rearrange(context_agent_action,
'(b o) l d -> b o l d',
o=t)
context_agent_action = rearrange(context_agent_action,
'b o (t l) d -> b o t l d',
t=t)
context_agent_action = context_agent_action.permute(
0, 2, 1, 3, 4)
context_agent_action = rearrange(context_agent_action,
'b t o l d -> (b t) (o l) d')
context_text = context[:, self.n_obs_steps +
16:self.n_obs_steps + 16 + 77, :]
context_text = context_text.repeat_interleave(repeats=t, dim=0)
context_text = context[:, self.n_obs_steps +
16:self.n_obs_steps + 16 + 77, :]
context_text = context_text.repeat_interleave(repeats=t, dim=0)
context_img = context[:, self.n_obs_steps + 16 + 77:, :]
context_img = rearrange(context_img,
'b (t l) c -> (b t) l c',
t=t)
context_agent_state = context_agent_state.repeat_interleave(
repeats=t, dim=0)
context = torch.cat([
context_agent_state, context_agent_action, context_text,
context_img
],
dim=1)
context_img = context[:, self.n_obs_steps + 16 + 77:, :]
context_img = rearrange(context_img,
'b (t l) c -> (b t) l c',
t=t)
context_agent_state = context_agent_state.repeat_interleave(
repeats=t, dim=0)
context = torch.cat([
context_agent_state, context_agent_action, context_text,
context_img
],
dim=1)
if self._ctx_cache_enabled:
self._ctx_cache[_ctx_key] = context
emb = emb.repeat_interleave(repeats=t, dim=0)
@@ -779,16 +791,20 @@ class WMAModel(nn.Module):
# Combine emb
if self.fs_condition:
if fs is None:
fs = torch.tensor([self.default_fs] * b,
dtype=torch.long,
device=x.device)
fs_emb = timestep_embedding(fs,
self.model_channels,
repeat_only=False).type(x.dtype)
fs_embed = self.fps_embedding(fs_emb)
fs_embed = fs_embed.repeat_interleave(repeats=t, dim=0)
if self._ctx_cache_enabled and self._fs_embed_cache is not None:
fs_embed = self._fs_embed_cache
else:
if fs is None:
fs = torch.tensor([self.default_fs] * b,
dtype=torch.long,
device=x.device)
fs_emb = timestep_embedding(fs,
self.model_channels,
repeat_only=False).type(x.dtype)
fs_embed = self.fps_embedding(fs_emb)
fs_embed = fs_embed.repeat_interleave(repeats=t, dim=0)
if self._ctx_cache_enabled:
self._fs_embed_cache = fs_embed
emb = emb + fs_embed
h = x.type(self.dtype)
@@ -846,3 +862,32 @@ class WMAModel(nn.Module):
s_y = torch.zeros_like(x_state)
return y, a_y, s_y
def enable_ctx_cache(model):
"""Enable context precomputation cache on WMAModel and its action/state UNets."""
for m in model.modules():
if isinstance(m, WMAModel):
m._ctx_cache_enabled = True
m._ctx_cache = {}
m._fs_embed_cache = None
# conditional_unet1d cache
from unifolm_wma.models.diffusion_head.conditional_unet1d import ConditionalUnet1D
for m in model.modules():
if isinstance(m, ConditionalUnet1D):
m._global_cond_cache_enabled = True
m._global_cond_cache = {}
def disable_ctx_cache(model):
"""Disable and clear context precomputation cache."""
for m in model.modules():
if isinstance(m, WMAModel):
m._ctx_cache_enabled = False
m._ctx_cache = {}
m._fs_embed_cache = None
from unifolm_wma.models.diffusion_head.conditional_unet1d import ConditionalUnet1D
for m in model.modules():
if isinstance(m, ConditionalUnet1D):
m._global_cond_cache_enabled = False
m._global_cond_cache = {}

View File

@@ -7,7 +7,9 @@
#
# thanks!
import torch
import torch.nn as nn
import torch.nn.functional as F
from unifolm_wma.utils.utils import instantiate_from_config
@@ -78,7 +80,11 @@ def nonlinearity(type='silu'):
class GroupNormSpecific(nn.GroupNorm):
def forward(self, x):
return super().forward(x.float()).type(x.dtype)
with torch.amp.autocast('cuda', enabled=False):
return F.group_norm(x, self.num_groups,
self.weight.to(x.dtype) if self.weight is not None else None,
self.bias.to(x.dtype) if self.bias is not None else None,
self.eps)
def normalization(channels, num_groups=32):

View File

@@ -0,0 +1,144 @@
2026-02-08 05:20:49.828675: I tensorflow/core/util/port.cc:113] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2026-02-08 05:20:49.831563: I external/local_tsl/tsl/cuda/cudart_stub.cc:31] Could not find cuda drivers on your machine, GPU will not be used.
2026-02-08 05:20:49.861366: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2026-02-08 05:20:49.861402: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2026-02-08 05:20:49.862974: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2026-02-08 05:20:49.870402: I external/local_tsl/tsl/cuda/cudart_stub.cc:31] Could not find cuda drivers on your machine, GPU will not be used.
2026-02-08 05:20:49.870647: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 AVX512F AVX512_VNNI AVX512_BF16 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
2026-02-08 05:20:50.486843: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT
Global seed set to 123
/mnt/ASC1637/miniconda3/envs/unifolm-wma/lib/python3.10/site-packages/kornia/feature/lightglue.py:44: FutureWarning: `torch.cuda.amp.custom_fwd(args...)` is deprecated. Please use `torch.amp.custom_fwd(args..., device_type='cuda')` instead.
@torch.cuda.amp.custom_fwd(cast_inputs=torch.float32)
INFO:mainlogger:LatentVisualDiffusion: Running in v-prediction mode
INFO:unifolm_wma.models.diffusion_head.conditional_unet1d:number of parameters: 5.010531e+08
INFO:unifolm_wma.models.diffusion_head.conditional_unet1d:number of parameters: 5.010531e+08
AE working on z of shape (1, 4, 32, 32) = 4096 dimensions.
INFO:root:Loaded ViT-H-14 model config.
DEBUG:urllib3.connectionpool:Starting new HTTPS connection (1): hf-mirror.com:443
DEBUG:urllib3.connectionpool:https://hf-mirror.com:443 "HEAD /laion/CLIP-ViT-H-14-laion2B-s32B-b79K/resolve/main/open_clip_pytorch_model.bin HTTP/1.1" 302 0
INFO:root:Loading pretrained ViT-H-14 weights (laion2b_s32b_b79k).
/mnt/ASC1637/miniconda3/envs/unifolm-wma/lib/python3.10/site-packages/open_clip/factory.py:88: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.
checkpoint = torch.load(checkpoint_path, map_location=map_location)
INFO:root:Loaded ViT-H-14 model config.
DEBUG:urllib3.connectionpool:https://hf-mirror.com:443 "HEAD /laion/CLIP-ViT-H-14-laion2B-s32B-b79K/resolve/main/open_clip_pytorch_model.bin HTTP/1.1" 302 0
INFO:root:Loading pretrained ViT-H-14 weights (laion2b_s32b_b79k).
/mnt/ASC1637/unifolm-world-model-action/scripts/evaluation/world_model_interaction.py:86: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.
state_dict = torch.load(ckpt, map_location="cpu")
>>> model checkpoint loaded.
>>> Load pre-trained model ...
INFO:root:***** Configing Data *****
>>> unitree_z1_stackbox: 1 data samples loaded.
>>> unitree_z1_stackbox: data stats loaded.
>>> unitree_z1_stackbox: normalizer initiated.
>>> unitree_z1_dual_arm_stackbox: 1 data samples loaded.
>>> unitree_z1_dual_arm_stackbox: data stats loaded.
>>> unitree_z1_dual_arm_stackbox: normalizer initiated.
>>> unitree_z1_dual_arm_stackbox_v2: 1 data samples loaded.
>>> unitree_z1_dual_arm_stackbox_v2: data stats loaded.
>>> unitree_z1_dual_arm_stackbox_v2: normalizer initiated.
>>> unitree_z1_dual_arm_cleanup_pencils: 1 data samples loaded.
>>> unitree_z1_dual_arm_cleanup_pencils: data stats loaded.
>>> unitree_z1_dual_arm_cleanup_pencils: normalizer initiated.
>>> unitree_g1_pack_camera: 1 data samples loaded.
>>> unitree_g1_pack_camera: data stats loaded.
>>> unitree_g1_pack_camera: normalizer initiated.
>>> Dataset is successfully loaded ...
>>> Generate 16 frames under each generation ...
DEBUG:h5py._conv:Creating converter from 3 to 5
DEBUG:PIL.PngImagePlugin:STREAM b'IHDR' 16 13
DEBUG:PIL.PngImagePlugin:STREAM b'pHYs' 41 9
DEBUG:PIL.PngImagePlugin:STREAM b'IDAT' 62 4096
0%| | 0/11 [00:00<?, ?it/s]/mnt/ASC1637/miniconda3/envs/unifolm-wma/lib/python3.10/site-packages/torch/nn/functional.py:5501: UserWarning: Attempting to use hipBLASLt on an unsupported architecture! Overriding blas backend to hipblas (Triggered internally at ../aten/src/ATen/Context.cpp:296.)
proj = linear(q, w, b)
/mnt/ASC1637/miniconda3/envs/unifolm-wma/lib/python3.10/site-packages/torch/nn/functional.py:6278: UserWarning: Flash attention support on Navi31 GPU is still experimental. Enable it with TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL=1. (Triggered internally at ../aten/src/ATen/native/transformers/hip/sdp_utils.cpp:225.)
attn_output = scaled_dot_product_attention(
/mnt/ASC1637/miniconda3/envs/unifolm-wma/lib/python3.10/site-packages/torch/nn/functional.py:6278: UserWarning: Memory Efficient attention on Navi31 GPU is still experimental. Enable it with TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL=1. (Triggered internally at ../aten/src/ATen/native/transformers/hip/sdp_utils.cpp:269.)
attn_output = scaled_dot_product_attention(
>>> Step 0: generating actions ...
>>> Step 0: interacting with world model ...
DEBUG:PIL.Image:Importing BlpImagePlugin
DEBUG:PIL.Image:Importing BmpImagePlugin
DEBUG:PIL.Image:Importing BufrStubImagePlugin
DEBUG:PIL.Image:Importing CurImagePlugin
DEBUG:PIL.Image:Importing DcxImagePlugin
DEBUG:PIL.Image:Importing DdsImagePlugin
DEBUG:PIL.Image:Importing EpsImagePlugin
DEBUG:PIL.Image:Importing FitsImagePlugin
DEBUG:PIL.Image:Importing FitsStubImagePlugin
DEBUG:PIL.Image:Importing FliImagePlugin
DEBUG:PIL.Image:Importing FpxImagePlugin
DEBUG:PIL.Image:Image: failed to import FpxImagePlugin: No module named 'olefile'
DEBUG:PIL.Image:Importing FtexImagePlugin
DEBUG:PIL.Image:Importing GbrImagePlugin
DEBUG:PIL.Image:Importing GifImagePlugin
DEBUG:PIL.Image:Importing GribStubImagePlugin
DEBUG:PIL.Image:Importing Hdf5StubImagePlugin
DEBUG:PIL.Image:Importing IcnsImagePlugin
DEBUG:PIL.Image:Importing IcoImagePlugin
DEBUG:PIL.Image:Importing ImImagePlugin
DEBUG:PIL.Image:Importing ImtImagePlugin
DEBUG:PIL.Image:Importing IptcImagePlugin
DEBUG:PIL.Image:Importing JpegImagePlugin
DEBUG:PIL.Image:Importing Jpeg2KImagePlugin
DEBUG:PIL.Image:Importing McIdasImagePlugin
DEBUG:PIL.Image:Importing MicImagePlugin
DEBUG:PIL.Image:Image: failed to import MicImagePlugin: No module named 'olefile'
DEBUG:PIL.Image:Importing MpegImagePlugin
DEBUG:PIL.Image:Importing MpoImagePlugin
DEBUG:PIL.Image:Importing MspImagePlugin
DEBUG:PIL.Image:Importing PalmImagePlugin
DEBUG:PIL.Image:Importing PcdImagePlugin
DEBUG:PIL.Image:Importing PcxImagePlugin
DEBUG:PIL.Image:Importing PdfImagePlugin
DEBUG:PIL.Image:Importing PixarImagePlugin
DEBUG:PIL.Image:Importing PngImagePlugin
DEBUG:PIL.Image:Importing PpmImagePlugin
DEBUG:PIL.Image:Importing PsdImagePlugin
DEBUG:PIL.Image:Importing QoiImagePlugin
DEBUG:PIL.Image:Importing SgiImagePlugin
DEBUG:PIL.Image:Importing SpiderImagePlugin
DEBUG:PIL.Image:Importing SunImagePlugin
DEBUG:PIL.Image:Importing TgaImagePlugin
DEBUG:PIL.Image:Importing TiffImagePlugin
DEBUG:PIL.Image:Importing WebPImagePlugin
DEBUG:PIL.Image:Importing WmfImagePlugin
DEBUG:PIL.Image:Importing XbmImagePlugin
DEBUG:PIL.Image:Importing XpmImagePlugin
DEBUG:PIL.Image:Importing XVThumbImagePlugin
9%|▉ | 1/11 [01:38<16:25, 98.56s/it]
18%|█▊ | 2/11 [03:16<14:44, 98.31s/it]
27%|██▋ | 3/11 [04:55<13:06, 98.33s/it]
36%|███▋ | 4/11 [06:36<11:37, 99.66s/it]
45%|████▌ | 5/11 [08:31<10:29, 104.96s/it]
55%|█████▍ | 6/11 [10:10<08:35, 103.07s/it]
64%|██████▎ | 7/11 [11:48<06:46, 101.50s/it]
73%|███████▎ | 8/11 [13:27<05:01, 100.52s/it]
82%|████████▏ | 9/11 [15:05<03:19, 99.79s/it]
91%|█████████ | 10/11 [16:43<01:39, 99.30s/it]
100%|██████████| 11/11 [18:21<00:00, 98.97s/it]
100%|██████████| 11/11 [18:21<00:00, 100.16s/it]
>>>>>>>>>>>>>>>>>>>>>>>>
>>> Step 1: generating actions ...
>>> Step 1: interacting with world model ...
>>>>>>>>>>>>>>>>>>>>>>>>
>>> Step 2: generating actions ...
>>> Step 2: interacting with world model ...
>>>>>>>>>>>>>>>>>>>>>>>>
>>> Step 3: generating actions ...
>>> Step 3: interacting with world model ...
>>>>>>>>>>>>>>>>>>>>>>>>
>>> Step 4: generating actions ...
>>> Step 4: interacting with world model ...
>>>>>>>>>>>>>>>>>>>>>>>>
>>> Step 5: generating actions ...
>>> Step 5: interacting with world model ...
>>>>>>>>>>>>>>>>>>>>>>>>
>>> Step 6: generating actions ...
>>> Step 6: interacting with world model ...
>>>>>>>>>>>>>>>>>>>>>>>>
>>> Step 7: generating actions ...
>>> Step 7: interacting with world model ...
>>>>>>>>>>>>>>>>>>>>>>>>

View File

@@ -0,0 +1,5 @@
{
"gt_video": "unitree_g1_pack_camera/case1/unitree_g1_pack_camera_case1.mp4",
"pred_video": "unitree_g1_pack_camera/case1/output/inference/unitree_g1_pack_camera_case1_amd.mp4",
"psnr": 16.415668383379177
}

View File

@@ -0,0 +1,158 @@
/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/lightning_fabric/__init__.py:29: UserWarning: pkg_resources is deprecated as an API. See https://setuptools.pypa.io/en/latest/pkg_resources.html. The pkg_resources package is slated for removal as early as 2025-11-30. Refrain from using this package or pin to Setuptools<81.
__import__("pkg_resources").declare_namespace(__name__)
2026-02-08 18:28:48.960238: I tensorflow/core/util/port.cc:113] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2026-02-08 18:28:48.963331: I external/local_tsl/tsl/cuda/cudart_stub.cc:31] Could not find cuda drivers on your machine, GPU will not be used.
2026-02-08 18:28:48.995688: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2026-02-08 18:28:48.995732: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2026-02-08 18:28:48.997547: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2026-02-08 18:28:49.005673: I external/local_tsl/tsl/cuda/cudart_stub.cc:31] Could not find cuda drivers on your machine, GPU will not be used.
2026-02-08 18:28:49.005948: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 AVX512F AVX512_VNNI AVX512_BF16 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
2026-02-08 18:28:50.009660: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT
[rank: 0] Global seed set to 123
/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/kornia/feature/lightglue.py:44: FutureWarning: `torch.cuda.amp.custom_fwd(args...)` is deprecated. Please use `torch.amp.custom_fwd(args..., device_type='cuda')` instead.
@torch.cuda.amp.custom_fwd(cast_inputs=torch.float32)
INFO:mainlogger:LatentVisualDiffusion: Running in v-prediction mode
INFO:unifolm_wma.models.diffusion_head.conditional_unet1d:number of parameters: 5.010531e+08
INFO:unifolm_wma.models.diffusion_head.conditional_unet1d:number of parameters: 5.010531e+08
AE working on z of shape (1, 4, 32, 32) = 4096 dimensions.
INFO:root:Loaded ViT-H-14 model config.
DEBUG:urllib3.connectionpool:Starting new HTTPS connection (1): hf-mirror.com:443
DEBUG:urllib3.connectionpool:https://hf-mirror.com:443 "HEAD /laion/CLIP-ViT-H-14-laion2B-s32B-b79K/resolve/main/open_clip_pytorch_model.bin HTTP/1.1" 302 0
INFO:root:Loading pretrained ViT-H-14 weights (laion2b_s32b_b79k).
/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/open_clip/factory.py:88: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.
checkpoint = torch.load(checkpoint_path, map_location=map_location)
INFO:root:Loaded ViT-H-14 model config.
DEBUG:urllib3.connectionpool:https://hf-mirror.com:443 "HEAD /laion/CLIP-ViT-H-14-laion2B-s32B-b79K/resolve/main/open_clip_pytorch_model.bin HTTP/1.1" 302 0
INFO:root:Loading pretrained ViT-H-14 weights (laion2b_s32b_b79k).
/mnt/ASC1637/unifolm-world-model-action/scripts/evaluation/world_model_interaction.py:198: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.
state_dict = torch.load(ckpt, map_location="cpu")
>>> model checkpoint loaded.
>>> Load pre-trained model ...
>>> Applying precision settings:
- Diffusion dtype: bf16
- Projector mode: bf16_full
- Encoder mode: bf16_full
- VAE dtype: fp32
✓ Diffusion model weights converted to bfloat16
✓ Projectors converted to bfloat16
✓ Encoders converted to bfloat16
✓ VAE kept in fp32 for best quality
⚠ Found 849 fp32 params, converting to bf16
✓ All parameters converted to bfloat16
✓ torch.compile: 3 ResBlocks in output_blocks[5, 8, 9]
INFO:root:***** Configing Data *****
>>> unitree_z1_stackbox: 1 data samples loaded.
>>> unitree_z1_stackbox: data stats loaded.
>>> unitree_z1_stackbox: normalizer initiated.
>>> unitree_z1_dual_arm_stackbox: 1 data samples loaded.
>>> unitree_z1_dual_arm_stackbox: data stats loaded.
>>> unitree_z1_dual_arm_stackbox: normalizer initiated.
>>> unitree_z1_dual_arm_stackbox_v2: 1 data samples loaded.
>>> unitree_z1_dual_arm_stackbox_v2: data stats loaded.
>>> unitree_z1_dual_arm_stackbox_v2: normalizer initiated.
>>> unitree_z1_dual_arm_cleanup_pencils: 1 data samples loaded.
>>> unitree_z1_dual_arm_cleanup_pencils: data stats loaded.
>>> unitree_z1_dual_arm_cleanup_pencils: normalizer initiated.
>>> unitree_g1_pack_camera: 1 data samples loaded.
>>> unitree_g1_pack_camera: data stats loaded.
>>> unitree_g1_pack_camera: normalizer initiated.
>>> Dataset is successfully loaded ...
>>> Generate 16 frames under each generation ...
DEBUG:h5py._conv:Creating converter from 3 to 5
DEBUG:PIL.PngImagePlugin:STREAM b'IHDR' 16 13
DEBUG:PIL.PngImagePlugin:STREAM b'pHYs' 41 9
DEBUG:PIL.PngImagePlugin:STREAM b'IDAT' 62 4096
0%| | 0/11 [00:00<?, ?it/s]/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/torch/nn/functional.py:5501: UserWarning: Attempting to use hipBLASLt on an unsupported architecture! Overriding blas backend to hipblas (Triggered internally at ../aten/src/ATen/Context.cpp:296.)
proj = linear(q, w, b)
/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/torch/nn/functional.py:6278: UserWarning: Flash attention support on Navi31 GPU is still experimental. Enable it with TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL=1. (Triggered internally at ../aten/src/ATen/native/transformers/hip/sdp_utils.cpp:225.)
attn_output = scaled_dot_product_attention(
/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/torch/nn/functional.py:6278: UserWarning: Memory Efficient attention on Navi31 GPU is still experimental. Enable it with TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL=1. (Triggered internally at ../aten/src/ATen/native/transformers/hip/sdp_utils.cpp:269.)
attn_output = scaled_dot_product_attention(
>>> Step 0: generating actions ...
>>> Step 0: interacting with world model ...
DEBUG:PIL.Image:Importing BlpImagePlugin
DEBUG:PIL.Image:Importing BmpImagePlugin
DEBUG:PIL.Image:Importing BufrStubImagePlugin
DEBUG:PIL.Image:Importing CurImagePlugin
DEBUG:PIL.Image:Importing DcxImagePlugin
DEBUG:PIL.Image:Importing DdsImagePlugin
DEBUG:PIL.Image:Importing EpsImagePlugin
DEBUG:PIL.Image:Importing FitsImagePlugin
DEBUG:PIL.Image:Importing FitsStubImagePlugin
DEBUG:PIL.Image:Importing FliImagePlugin
DEBUG:PIL.Image:Importing FpxImagePlugin
DEBUG:PIL.Image:Image: failed to import FpxImagePlugin: No module named 'olefile'
DEBUG:PIL.Image:Importing FtexImagePlugin
DEBUG:PIL.Image:Importing GbrImagePlugin
DEBUG:PIL.Image:Importing GifImagePlugin
DEBUG:PIL.Image:Importing GribStubImagePlugin
DEBUG:PIL.Image:Importing Hdf5StubImagePlugin
DEBUG:PIL.Image:Importing IcnsImagePlugin
DEBUG:PIL.Image:Importing IcoImagePlugin
DEBUG:PIL.Image:Importing ImImagePlugin
DEBUG:PIL.Image:Importing ImtImagePlugin
DEBUG:PIL.Image:Importing IptcImagePlugin
DEBUG:PIL.Image:Importing JpegImagePlugin
DEBUG:PIL.Image:Importing Jpeg2KImagePlugin
DEBUG:PIL.Image:Importing McIdasImagePlugin
DEBUG:PIL.Image:Importing MicImagePlugin
DEBUG:PIL.Image:Image: failed to import MicImagePlugin: No module named 'olefile'
DEBUG:PIL.Image:Importing MpegImagePlugin
DEBUG:PIL.Image:Importing MpoImagePlugin
DEBUG:PIL.Image:Importing MspImagePlugin
DEBUG:PIL.Image:Importing PalmImagePlugin
DEBUG:PIL.Image:Importing PcdImagePlugin
DEBUG:PIL.Image:Importing PcxImagePlugin
DEBUG:PIL.Image:Importing PdfImagePlugin
DEBUG:PIL.Image:Importing PixarImagePlugin
DEBUG:PIL.Image:Importing PngImagePlugin
DEBUG:PIL.Image:Importing PpmImagePlugin
DEBUG:PIL.Image:Importing PsdImagePlugin
DEBUG:PIL.Image:Importing QoiImagePlugin
DEBUG:PIL.Image:Importing SgiImagePlugin
DEBUG:PIL.Image:Importing SpiderImagePlugin
DEBUG:PIL.Image:Importing SunImagePlugin
DEBUG:PIL.Image:Importing TgaImagePlugin
DEBUG:PIL.Image:Importing TiffImagePlugin
DEBUG:PIL.Image:Importing WebPImagePlugin
DEBUG:PIL.Image:Importing WmfImagePlugin
DEBUG:PIL.Image:Importing XbmImagePlugin
DEBUG:PIL.Image:Importing XpmImagePlugin
DEBUG:PIL.Image:Importing XVThumbImagePlugin
9%|▉ | 1/11 [01:14<12:29, 74.95s/it]
18%|█▊ | 2/11 [02:23<10:40, 71.18s/it]
27%|██▋ | 3/11 [03:32<09:20, 70.05s/it]
36%|███▋ | 4/11 [04:40<08:06, 69.51s/it]
45%|████▌ | 5/11 [05:49<06:55, 69.19s/it]
55%|█████▍ | 6/11 [06:57<05:44, 68.95s/it]
64%|██████▎ | 7/11 [08:06<04:35, 68.79s/it]
73%|███████▎ | 8/11 [09:14<03:26, 68.70s/it]
82%|████████▏ | 9/11 [10:23<02:17, 68.65s/it]
91%|█████████ | 10/11 [11:31<01:08, 68.58s/it]
100%|██████████| 11/11 [12:40<00:00, 68.51s/it]
100%|██████████| 11/11 [12:40<00:00, 69.11s/it]
>>>>>>>>>>>>>>>>>>>>>>>>
>>> Step 1: generating actions ...
>>> Step 1: interacting with world model ...
>>>>>>>>>>>>>>>>>>>>>>>>
>>> Step 2: generating actions ...
>>> Step 2: interacting with world model ...
>>>>>>>>>>>>>>>>>>>>>>>>
>>> Step 3: generating actions ...
>>> Step 3: interacting with world model ...
>>>>>>>>>>>>>>>>>>>>>>>>
>>> Step 4: generating actions ...
>>> Step 4: interacting with world model ...
>>>>>>>>>>>>>>>>>>>>>>>>
>>> Step 5: generating actions ...
>>> Step 5: interacting with world model ...
>>>>>>>>>>>>>>>>>>>>>>>>
>>> Step 6: generating actions ...
>>> Step 6: interacting with world model ...
>>>>>>>>>>>>>>>>>>>>>>>>
>>> Step 7: generating actions ...
>>> Step 7: interacting with world model ...
>>>>>>>>>>>>>>>>>>>>>>>>

View File

@@ -0,0 +1,5 @@
{
"gt_video": "unitree_g1_pack_camera/case2/unitree_g1_pack_camera_case2.mp4",
"pred_video": "unitree_g1_pack_camera/case2/output/inference/unitree_g1_pack_camera_case2_amd.mp4",
"psnr": 19.515250190529375
}

View File

@@ -0,0 +1,144 @@
2026-02-08 05:08:32.803904: I tensorflow/core/util/port.cc:113] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2026-02-08 05:08:32.807010: I external/local_tsl/tsl/cuda/cudart_stub.cc:31] Could not find cuda drivers on your machine, GPU will not be used.
2026-02-08 05:08:32.837936: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2026-02-08 05:08:32.837978: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2026-02-08 05:08:32.839785: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2026-02-08 05:08:32.847835: I external/local_tsl/tsl/cuda/cudart_stub.cc:31] Could not find cuda drivers on your machine, GPU will not be used.
2026-02-08 05:08:32.848223: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 AVX512F AVX512_VNNI AVX512_BF16 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
2026-02-08 05:08:34.120114: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT
[rank: 0] Global seed set to 123
/mnt/ASC1637/miniconda3/envs/unifolm-wma/lib/python3.10/site-packages/kornia/feature/lightglue.py:44: FutureWarning: `torch.cuda.amp.custom_fwd(args...)` is deprecated. Please use `torch.amp.custom_fwd(args..., device_type='cuda')` instead.
@torch.cuda.amp.custom_fwd(cast_inputs=torch.float32)
INFO:mainlogger:LatentVisualDiffusion: Running in v-prediction mode
INFO:unifolm_wma.models.diffusion_head.conditional_unet1d:number of parameters: 5.010531e+08
INFO:unifolm_wma.models.diffusion_head.conditional_unet1d:number of parameters: 5.010531e+08
AE working on z of shape (1, 4, 32, 32) = 4096 dimensions.
INFO:root:Loaded ViT-H-14 model config.
DEBUG:urllib3.connectionpool:Starting new HTTPS connection (1): hf-mirror.com:443
DEBUG:urllib3.connectionpool:https://hf-mirror.com:443 "HEAD /laion/CLIP-ViT-H-14-laion2B-s32B-b79K/resolve/main/open_clip_pytorch_model.bin HTTP/1.1" 302 0
INFO:root:Loading pretrained ViT-H-14 weights (laion2b_s32b_b79k).
/mnt/ASC1637/miniconda3/envs/unifolm-wma/lib/python3.10/site-packages/open_clip/factory.py:88: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.
checkpoint = torch.load(checkpoint_path, map_location=map_location)
INFO:root:Loaded ViT-H-14 model config.
DEBUG:urllib3.connectionpool:https://hf-mirror.com:443 "HEAD /laion/CLIP-ViT-H-14-laion2B-s32B-b79K/resolve/main/open_clip_pytorch_model.bin HTTP/1.1" 302 0
INFO:root:Loading pretrained ViT-H-14 weights (laion2b_s32b_b79k).
/mnt/ASC1637/unifolm-world-model-action/scripts/evaluation/world_model_interaction.py:86: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.
state_dict = torch.load(ckpt, map_location="cpu")
>>> model checkpoint loaded.
>>> Load pre-trained model ...
INFO:root:***** Configing Data *****
>>> unitree_z1_stackbox: 1 data samples loaded.
>>> unitree_z1_stackbox: data stats loaded.
>>> unitree_z1_stackbox: normalizer initiated.
>>> unitree_z1_dual_arm_stackbox: 1 data samples loaded.
>>> unitree_z1_dual_arm_stackbox: data stats loaded.
>>> unitree_z1_dual_arm_stackbox: normalizer initiated.
>>> unitree_z1_dual_arm_stackbox_v2: 1 data samples loaded.
>>> unitree_z1_dual_arm_stackbox_v2: data stats loaded.
>>> unitree_z1_dual_arm_stackbox_v2: normalizer initiated.
>>> unitree_z1_dual_arm_cleanup_pencils: 1 data samples loaded.
>>> unitree_z1_dual_arm_cleanup_pencils: data stats loaded.
>>> unitree_z1_dual_arm_cleanup_pencils: normalizer initiated.
>>> unitree_g1_pack_camera: 1 data samples loaded.
>>> unitree_g1_pack_camera: data stats loaded.
>>> unitree_g1_pack_camera: normalizer initiated.
>>> Dataset is successfully loaded ...
>>> Generate 16 frames under each generation ...
DEBUG:h5py._conv:Creating converter from 3 to 5
DEBUG:PIL.PngImagePlugin:STREAM b'IHDR' 16 13
DEBUG:PIL.PngImagePlugin:STREAM b'pHYs' 41 9
DEBUG:PIL.PngImagePlugin:STREAM b'IDAT' 62 4096
0%| | 0/11 [00:00<?, ?it/s]/mnt/ASC1637/miniconda3/envs/unifolm-wma/lib/python3.10/site-packages/torch/nn/functional.py:5501: UserWarning: Attempting to use hipBLASLt on an unsupported architecture! Overriding blas backend to hipblas (Triggered internally at ../aten/src/ATen/Context.cpp:296.)
proj = linear(q, w, b)
/mnt/ASC1637/miniconda3/envs/unifolm-wma/lib/python3.10/site-packages/torch/nn/functional.py:6278: UserWarning: Flash attention support on Navi31 GPU is still experimental. Enable it with TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL=1. (Triggered internally at ../aten/src/ATen/native/transformers/hip/sdp_utils.cpp:225.)
attn_output = scaled_dot_product_attention(
/mnt/ASC1637/miniconda3/envs/unifolm-wma/lib/python3.10/site-packages/torch/nn/functional.py:6278: UserWarning: Memory Efficient attention on Navi31 GPU is still experimental. Enable it with TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL=1. (Triggered internally at ../aten/src/ATen/native/transformers/hip/sdp_utils.cpp:269.)
attn_output = scaled_dot_product_attention(
>>> Step 0: generating actions ...
>>> Step 0: interacting with world model ...
DEBUG:PIL.Image:Importing BlpImagePlugin
DEBUG:PIL.Image:Importing BmpImagePlugin
DEBUG:PIL.Image:Importing BufrStubImagePlugin
DEBUG:PIL.Image:Importing CurImagePlugin
DEBUG:PIL.Image:Importing DcxImagePlugin
DEBUG:PIL.Image:Importing DdsImagePlugin
DEBUG:PIL.Image:Importing EpsImagePlugin
DEBUG:PIL.Image:Importing FitsImagePlugin
DEBUG:PIL.Image:Importing FitsStubImagePlugin
DEBUG:PIL.Image:Importing FliImagePlugin
DEBUG:PIL.Image:Importing FpxImagePlugin
DEBUG:PIL.Image:Image: failed to import FpxImagePlugin: No module named 'olefile'
DEBUG:PIL.Image:Importing FtexImagePlugin
DEBUG:PIL.Image:Importing GbrImagePlugin
DEBUG:PIL.Image:Importing GifImagePlugin
DEBUG:PIL.Image:Importing GribStubImagePlugin
DEBUG:PIL.Image:Importing Hdf5StubImagePlugin
DEBUG:PIL.Image:Importing IcnsImagePlugin
DEBUG:PIL.Image:Importing IcoImagePlugin
DEBUG:PIL.Image:Importing ImImagePlugin
DEBUG:PIL.Image:Importing ImtImagePlugin
DEBUG:PIL.Image:Importing IptcImagePlugin
DEBUG:PIL.Image:Importing JpegImagePlugin
DEBUG:PIL.Image:Importing Jpeg2KImagePlugin
DEBUG:PIL.Image:Importing McIdasImagePlugin
DEBUG:PIL.Image:Importing MicImagePlugin
DEBUG:PIL.Image:Image: failed to import MicImagePlugin: No module named 'olefile'
DEBUG:PIL.Image:Importing MpegImagePlugin
DEBUG:PIL.Image:Importing MpoImagePlugin
DEBUG:PIL.Image:Importing MspImagePlugin
DEBUG:PIL.Image:Importing PalmImagePlugin
DEBUG:PIL.Image:Importing PcdImagePlugin
DEBUG:PIL.Image:Importing PcxImagePlugin
DEBUG:PIL.Image:Importing PdfImagePlugin
DEBUG:PIL.Image:Importing PixarImagePlugin
DEBUG:PIL.Image:Importing PngImagePlugin
DEBUG:PIL.Image:Importing PpmImagePlugin
DEBUG:PIL.Image:Importing PsdImagePlugin
DEBUG:PIL.Image:Importing QoiImagePlugin
DEBUG:PIL.Image:Importing SgiImagePlugin
DEBUG:PIL.Image:Importing SpiderImagePlugin
DEBUG:PIL.Image:Importing SunImagePlugin
DEBUG:PIL.Image:Importing TgaImagePlugin
DEBUG:PIL.Image:Importing TiffImagePlugin
DEBUG:PIL.Image:Importing WebPImagePlugin
DEBUG:PIL.Image:Importing WmfImagePlugin
DEBUG:PIL.Image:Importing XbmImagePlugin
DEBUG:PIL.Image:Importing XpmImagePlugin
DEBUG:PIL.Image:Importing XVThumbImagePlugin
9%|▉ | 1/11 [01:39<16:34, 99.46s/it]
18%|█▊ | 2/11 [03:18<14:55, 99.48s/it]
27%|██▋ | 3/11 [04:58<13:16, 99.60s/it]
36%|███▋ | 4/11 [06:38<11:37, 99.69s/it]
45%|████▌ | 5/11 [08:18<09:58, 99.68s/it]
55%|█████▍ | 6/11 [09:57<08:18, 99.66s/it]
64%|██████▎ | 7/11 [11:37<06:38, 99.62s/it]
73%|███████▎ | 8/11 [13:16<04:58, 99.55s/it]
82%|████████▏ | 9/11 [14:56<03:19, 99.50s/it]
91%|█████████ | 10/11 [16:35<01:39, 99.43s/it]
100%|██████████| 11/11 [18:14<00:00, 99.36s/it]
100%|██████████| 11/11 [18:14<00:00, 99.51s/it]
>>>>>>>>>>>>>>>>>>>>>>>>
>>> Step 1: generating actions ...
>>> Step 1: interacting with world model ...
>>>>>>>>>>>>>>>>>>>>>>>>
>>> Step 2: generating actions ...
>>> Step 2: interacting with world model ...
>>>>>>>>>>>>>>>>>>>>>>>>
>>> Step 3: generating actions ...
>>> Step 3: interacting with world model ...
>>>>>>>>>>>>>>>>>>>>>>>>
>>> Step 4: generating actions ...
>>> Step 4: interacting with world model ...
>>>>>>>>>>>>>>>>>>>>>>>>
>>> Step 5: generating actions ...
>>> Step 5: interacting with world model ...
>>>>>>>>>>>>>>>>>>>>>>>>
>>> Step 6: generating actions ...
>>> Step 6: interacting with world model ...
>>>>>>>>>>>>>>>>>>>>>>>>
>>> Step 7: generating actions ...
>>> Step 7: interacting with world model ...
>>>>>>>>>>>>>>>>>>>>>>>>

View File

@@ -0,0 +1,5 @@
{
"gt_video": "unitree_g1_pack_camera/case3/unitree_g1_pack_camera_case3.mp4",
"pred_video": "unitree_g1_pack_camera/case3/output/inference/unitree_g1_pack_camera_case3_amd.mp4",
"psnr": 19.429578160315536
}

View File

@@ -0,0 +1,144 @@
2026-02-08 05:29:19.728303: I tensorflow/core/util/port.cc:113] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2026-02-08 05:29:19.731620: I external/local_tsl/tsl/cuda/cudart_stub.cc:31] Could not find cuda drivers on your machine, GPU will not be used.
2026-02-08 05:29:19.761276: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2026-02-08 05:29:19.761301: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2026-02-08 05:29:19.762880: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2026-02-08 05:29:19.770578: I external/local_tsl/tsl/cuda/cudart_stub.cc:31] Could not find cuda drivers on your machine, GPU will not be used.
2026-02-08 05:29:19.771072: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 AVX512F AVX512_VNNI AVX512_BF16 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
2026-02-08 05:29:21.043661: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT
Global seed set to 123
/mnt/ASC1637/miniconda3/envs/unifolm-wma/lib/python3.10/site-packages/kornia/feature/lightglue.py:44: FutureWarning: `torch.cuda.amp.custom_fwd(args...)` is deprecated. Please use `torch.amp.custom_fwd(args..., device_type='cuda')` instead.
@torch.cuda.amp.custom_fwd(cast_inputs=torch.float32)
INFO:mainlogger:LatentVisualDiffusion: Running in v-prediction mode
INFO:unifolm_wma.models.diffusion_head.conditional_unet1d:number of parameters: 5.010531e+08
INFO:unifolm_wma.models.diffusion_head.conditional_unet1d:number of parameters: 5.010531e+08
AE working on z of shape (1, 4, 32, 32) = 4096 dimensions.
INFO:root:Loaded ViT-H-14 model config.
DEBUG:urllib3.connectionpool:Starting new HTTPS connection (1): hf-mirror.com:443
DEBUG:urllib3.connectionpool:https://hf-mirror.com:443 "HEAD /laion/CLIP-ViT-H-14-laion2B-s32B-b79K/resolve/main/open_clip_pytorch_model.bin HTTP/1.1" 302 0
INFO:root:Loading pretrained ViT-H-14 weights (laion2b_s32b_b79k).
/mnt/ASC1637/miniconda3/envs/unifolm-wma/lib/python3.10/site-packages/open_clip/factory.py:88: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.
checkpoint = torch.load(checkpoint_path, map_location=map_location)
INFO:root:Loaded ViT-H-14 model config.
DEBUG:urllib3.connectionpool:https://hf-mirror.com:443 "HEAD /laion/CLIP-ViT-H-14-laion2B-s32B-b79K/resolve/main/open_clip_pytorch_model.bin HTTP/1.1" 302 0
INFO:root:Loading pretrained ViT-H-14 weights (laion2b_s32b_b79k).
/mnt/ASC1637/unifolm-world-model-action/scripts/evaluation/world_model_interaction.py:86: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.
state_dict = torch.load(ckpt, map_location="cpu")
>>> model checkpoint loaded.
>>> Load pre-trained model ...
INFO:root:***** Configing Data *****
>>> unitree_z1_stackbox: 1 data samples loaded.
>>> unitree_z1_stackbox: data stats loaded.
>>> unitree_z1_stackbox: normalizer initiated.
>>> unitree_z1_dual_arm_stackbox: 1 data samples loaded.
>>> unitree_z1_dual_arm_stackbox: data stats loaded.
>>> unitree_z1_dual_arm_stackbox: normalizer initiated.
>>> unitree_z1_dual_arm_stackbox_v2: 1 data samples loaded.
>>> unitree_z1_dual_arm_stackbox_v2: data stats loaded.
>>> unitree_z1_dual_arm_stackbox_v2: normalizer initiated.
>>> unitree_z1_dual_arm_cleanup_pencils: 1 data samples loaded.
>>> unitree_z1_dual_arm_cleanup_pencils: data stats loaded.
>>> unitree_z1_dual_arm_cleanup_pencils: normalizer initiated.
>>> unitree_g1_pack_camera: 1 data samples loaded.
>>> unitree_g1_pack_camera: data stats loaded.
>>> unitree_g1_pack_camera: normalizer initiated.
>>> Dataset is successfully loaded ...
>>> Generate 16 frames under each generation ...
DEBUG:h5py._conv:Creating converter from 3 to 5
DEBUG:PIL.PngImagePlugin:STREAM b'IHDR' 16 13
DEBUG:PIL.PngImagePlugin:STREAM b'pHYs' 41 9
DEBUG:PIL.PngImagePlugin:STREAM b'IDAT' 62 4096
0%| | 0/11 [00:00<?, ?it/s]/mnt/ASC1637/miniconda3/envs/unifolm-wma/lib/python3.10/site-packages/torch/nn/functional.py:5501: UserWarning: Attempting to use hipBLASLt on an unsupported architecture! Overriding blas backend to hipblas (Triggered internally at ../aten/src/ATen/Context.cpp:296.)
proj = linear(q, w, b)
/mnt/ASC1637/miniconda3/envs/unifolm-wma/lib/python3.10/site-packages/torch/nn/functional.py:6278: UserWarning: Flash attention support on Navi31 GPU is still experimental. Enable it with TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL=1. (Triggered internally at ../aten/src/ATen/native/transformers/hip/sdp_utils.cpp:225.)
attn_output = scaled_dot_product_attention(
/mnt/ASC1637/miniconda3/envs/unifolm-wma/lib/python3.10/site-packages/torch/nn/functional.py:6278: UserWarning: Memory Efficient attention on Navi31 GPU is still experimental. Enable it with TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL=1. (Triggered internally at ../aten/src/ATen/native/transformers/hip/sdp_utils.cpp:269.)
attn_output = scaled_dot_product_attention(
>>> Step 0: generating actions ...
>>> Step 0: interacting with world model ...
DEBUG:PIL.Image:Importing BlpImagePlugin
DEBUG:PIL.Image:Importing BmpImagePlugin
DEBUG:PIL.Image:Importing BufrStubImagePlugin
DEBUG:PIL.Image:Importing CurImagePlugin
DEBUG:PIL.Image:Importing DcxImagePlugin
DEBUG:PIL.Image:Importing DdsImagePlugin
DEBUG:PIL.Image:Importing EpsImagePlugin
DEBUG:PIL.Image:Importing FitsImagePlugin
DEBUG:PIL.Image:Importing FitsStubImagePlugin
DEBUG:PIL.Image:Importing FliImagePlugin
DEBUG:PIL.Image:Importing FpxImagePlugin
DEBUG:PIL.Image:Image: failed to import FpxImagePlugin: No module named 'olefile'
DEBUG:PIL.Image:Importing FtexImagePlugin
DEBUG:PIL.Image:Importing GbrImagePlugin
DEBUG:PIL.Image:Importing GifImagePlugin
DEBUG:PIL.Image:Importing GribStubImagePlugin
DEBUG:PIL.Image:Importing Hdf5StubImagePlugin
DEBUG:PIL.Image:Importing IcnsImagePlugin
DEBUG:PIL.Image:Importing IcoImagePlugin
DEBUG:PIL.Image:Importing ImImagePlugin
DEBUG:PIL.Image:Importing ImtImagePlugin
DEBUG:PIL.Image:Importing IptcImagePlugin
DEBUG:PIL.Image:Importing JpegImagePlugin
DEBUG:PIL.Image:Importing Jpeg2KImagePlugin
DEBUG:PIL.Image:Importing McIdasImagePlugin
DEBUG:PIL.Image:Importing MicImagePlugin
DEBUG:PIL.Image:Image: failed to import MicImagePlugin: No module named 'olefile'
DEBUG:PIL.Image:Importing MpegImagePlugin
DEBUG:PIL.Image:Importing MpoImagePlugin
DEBUG:PIL.Image:Importing MspImagePlugin
DEBUG:PIL.Image:Importing PalmImagePlugin
DEBUG:PIL.Image:Importing PcdImagePlugin
DEBUG:PIL.Image:Importing PcxImagePlugin
DEBUG:PIL.Image:Importing PdfImagePlugin
DEBUG:PIL.Image:Importing PixarImagePlugin
DEBUG:PIL.Image:Importing PngImagePlugin
DEBUG:PIL.Image:Importing PpmImagePlugin
DEBUG:PIL.Image:Importing PsdImagePlugin
DEBUG:PIL.Image:Importing QoiImagePlugin
DEBUG:PIL.Image:Importing SgiImagePlugin
DEBUG:PIL.Image:Importing SpiderImagePlugin
DEBUG:PIL.Image:Importing SunImagePlugin
DEBUG:PIL.Image:Importing TgaImagePlugin
DEBUG:PIL.Image:Importing TiffImagePlugin
DEBUG:PIL.Image:Importing WebPImagePlugin
DEBUG:PIL.Image:Importing WmfImagePlugin
DEBUG:PIL.Image:Importing XbmImagePlugin
DEBUG:PIL.Image:Importing XpmImagePlugin
DEBUG:PIL.Image:Importing XVThumbImagePlugin
9%|▉ | 1/11 [01:37<16:18, 97.81s/it]
18%|█▊ | 2/11 [03:15<14:38, 97.56s/it]
27%|██▋ | 3/11 [04:52<12:59, 97.48s/it]
36%|███▋ | 4/11 [06:29<11:21, 97.38s/it]
45%|████▌ | 5/11 [08:06<09:43, 97.28s/it]
55%|█████▍ | 6/11 [09:44<08:06, 97.35s/it]
64%|██████▎ | 7/11 [11:21<06:29, 97.36s/it]
73%|███████▎ | 8/11 [12:59<04:52, 97.38s/it]
82%|████████▏ | 9/11 [14:36<03:14, 97.39s/it]
91%|█████████ | 10/11 [16:14<01:37, 97.42s/it]
100%|██████████| 11/11 [17:51<00:00, 97.42s/it]
100%|██████████| 11/11 [17:51<00:00, 97.41s/it]
>>>>>>>>>>>>>>>>>>>>>>>>
>>> Step 1: generating actions ...
>>> Step 1: interacting with world model ...
>>>>>>>>>>>>>>>>>>>>>>>>
>>> Step 2: generating actions ...
>>> Step 2: interacting with world model ...
>>>>>>>>>>>>>>>>>>>>>>>>
>>> Step 3: generating actions ...
>>> Step 3: interacting with world model ...
>>>>>>>>>>>>>>>>>>>>>>>>
>>> Step 4: generating actions ...
>>> Step 4: interacting with world model ...
>>>>>>>>>>>>>>>>>>>>>>>>
>>> Step 5: generating actions ...
>>> Step 5: interacting with world model ...
>>>>>>>>>>>>>>>>>>>>>>>>
>>> Step 6: generating actions ...
>>> Step 6: interacting with world model ...
>>>>>>>>>>>>>>>>>>>>>>>>
>>> Step 7: generating actions ...
>>> Step 7: interacting with world model ...
>>>>>>>>>>>>>>>>>>>>>>>>

View File

@@ -0,0 +1,5 @@
{
"gt_video": "unitree_g1_pack_camera/case4/unitree_g1_pack_camera_case4.mp4",
"pred_video": "unitree_g1_pack_camera/case4/output/inference/unitree_g1_pack_camera_case4_amd.mp4",
"psnr": 17.80386833747375
}

View File

@@ -0,0 +1,145 @@
/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/lightning_fabric/__init__.py:29: UserWarning: pkg_resources is deprecated as an API. See https://setuptools.pypa.io/en/latest/pkg_resources.html. The pkg_resources package is slated for removal as early as 2025-11-30. Refrain from using this package or pin to Setuptools<81.
__import__("pkg_resources").declare_namespace(__name__)
2026-02-09 18:39:50.119842: I tensorflow/core/util/port.cc:113] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2026-02-09 18:39:50.123128: I external/local_tsl/tsl/cuda/cudart_stub.cc:31] Could not find cuda drivers on your machine, GPU will not be used.
2026-02-09 18:39:50.156652: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2026-02-09 18:39:50.156708: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2026-02-09 18:39:50.158926: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2026-02-09 18:39:50.167779: I external/local_tsl/tsl/cuda/cudart_stub.cc:31] Could not find cuda drivers on your machine, GPU will not be used.
2026-02-09 18:39:50.168073: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 AVX512F AVX512_VNNI AVX512_BF16 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
2026-02-09 18:39:50.915144: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT
[rank: 0] Global seed set to 123
/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/kornia/feature/lightglue.py:44: FutureWarning: `torch.cuda.amp.custom_fwd(args...)` is deprecated. Please use `torch.amp.custom_fwd(args..., device_type='cuda')` instead.
@torch.cuda.amp.custom_fwd(cast_inputs=torch.float32)
INFO:mainlogger:LatentVisualDiffusion: Running in v-prediction mode
INFO:unifolm_wma.models.diffusion_head.conditional_unet1d:number of parameters: 5.010531e+08
INFO:unifolm_wma.models.diffusion_head.conditional_unet1d:number of parameters: 5.010531e+08
AE working on z of shape (1, 4, 32, 32) = 4096 dimensions.
INFO:root:Loaded ViT-H-14 model config.
DEBUG:urllib3.connectionpool:Starting new HTTPS connection (1): hf-mirror.com:443
DEBUG:urllib3.connectionpool:https://hf-mirror.com:443 "HEAD /laion/CLIP-ViT-H-14-laion2B-s32B-b79K/resolve/main/open_clip_pytorch_model.bin HTTP/1.1" 302 0
INFO:root:Loading pretrained ViT-H-14 weights (laion2b_s32b_b79k).
/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/open_clip/factory.py:88: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.
checkpoint = torch.load(checkpoint_path, map_location=map_location)
INFO:root:Loaded ViT-H-14 model config.
DEBUG:urllib3.connectionpool:https://hf-mirror.com:443 "HEAD /laion/CLIP-ViT-H-14-laion2B-s32B-b79K/resolve/main/open_clip_pytorch_model.bin HTTP/1.1" 302 0
INFO:root:Loading pretrained ViT-H-14 weights (laion2b_s32b_b79k).
/mnt/ASC1637/unifolm-world-model-action/scripts/evaluation/world_model_interaction.py:198: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.
state_dict = torch.load(ckpt, map_location="cpu")
>>> model checkpoint loaded.
>>> Load pre-trained model ...
>>> Applying precision settings:
- Diffusion dtype: bf16
- Projector mode: bf16_full
- Encoder mode: bf16_full
- VAE dtype: bf16
✓ Diffusion model weights converted to bfloat16
✓ Projectors converted to bfloat16
✓ Encoders converted to bfloat16
✓ VAE converted to bfloat16
⚠ Found 601 fp32 params, converting to bf16
✓ All parameters converted to bfloat16
✓ torch.compile: 3 ResBlocks in output_blocks[5, 8, 9]
INFO:root:***** Configing Data *****
>>> unitree_z1_stackbox: 1 data samples loaded.
>>> unitree_z1_stackbox: data stats loaded.
>>> unitree_z1_stackbox: normalizer initiated.
>>> unitree_z1_dual_arm_stackbox: 1 data samples loaded.
>>> unitree_z1_dual_arm_stackbox: data stats loaded.
>>> unitree_z1_dual_arm_stackbox: normalizer initiated.
>>> unitree_z1_dual_arm_stackbox_v2: 1 data samples loaded.
>>> unitree_z1_dual_arm_stackbox_v2: data stats loaded.
>>> unitree_z1_dual_arm_stackbox_v2: normalizer initiated.
>>> unitree_z1_dual_arm_cleanup_pencils: 1 data samples loaded.
>>> unitree_z1_dual_arm_cleanup_pencils: data stats loaded.
>>> unitree_z1_dual_arm_cleanup_pencils: normalizer initiated.
>>> unitree_g1_pack_camera: 1 data samples loaded.
>>> unitree_g1_pack_camera: data stats loaded.
>>> unitree_g1_pack_camera: normalizer initiated.
>>> Dataset is successfully loaded ...
>>> Generate 16 frames under each generation ...
DEBUG:h5py._conv:Creating converter from 3 to 5
DEBUG:PIL.PngImagePlugin:STREAM b'IHDR' 16 13
DEBUG:PIL.PngImagePlugin:STREAM b'pHYs' 41 9
DEBUG:PIL.PngImagePlugin:STREAM b'IDAT' 62 4096
0%| | 0/8 [00:00<?, ?it/s]/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/torch/nn/functional.py:5501: UserWarning: Attempting to use hipBLASLt on an unsupported architecture! Overriding blas backend to hipblas (Triggered internally at ../aten/src/ATen/Context.cpp:296.)
proj = linear(q, w, b)
>>> Step 0: generating actions ...
>>> Step 0: interacting with world model ...
DEBUG:PIL.Image:Importing BlpImagePlugin
DEBUG:PIL.Image:Importing BmpImagePlugin
DEBUG:PIL.Image:Importing BufrStubImagePlugin
DEBUG:PIL.Image:Importing CurImagePlugin
DEBUG:PIL.Image:Importing DcxImagePlugin
DEBUG:PIL.Image:Importing DdsImagePlugin
DEBUG:PIL.Image:Importing EpsImagePlugin
DEBUG:PIL.Image:Importing FitsImagePlugin
DEBUG:PIL.Image:Importing FitsStubImagePlugin
DEBUG:PIL.Image:Importing FliImagePlugin
DEBUG:PIL.Image:Importing FpxImagePlugin
DEBUG:PIL.Image:Image: failed to import FpxImagePlugin: No module named 'olefile'
DEBUG:PIL.Image:Importing FtexImagePlugin
DEBUG:PIL.Image:Importing GbrImagePlugin
DEBUG:PIL.Image:Importing GifImagePlugin
DEBUG:PIL.Image:Importing GribStubImagePlugin
DEBUG:PIL.Image:Importing Hdf5StubImagePlugin
DEBUG:PIL.Image:Importing IcnsImagePlugin
DEBUG:PIL.Image:Importing IcoImagePlugin
DEBUG:PIL.Image:Importing ImImagePlugin
DEBUG:PIL.Image:Importing ImtImagePlugin
DEBUG:PIL.Image:Importing IptcImagePlugin
DEBUG:PIL.Image:Importing JpegImagePlugin
DEBUG:PIL.Image:Importing Jpeg2KImagePlugin
DEBUG:PIL.Image:Importing McIdasImagePlugin
DEBUG:PIL.Image:Importing MicImagePlugin
DEBUG:PIL.Image:Image: failed to import MicImagePlugin: No module named 'olefile'
DEBUG:PIL.Image:Importing MpegImagePlugin
DEBUG:PIL.Image:Importing MpoImagePlugin
DEBUG:PIL.Image:Importing MspImagePlugin
DEBUG:PIL.Image:Importing PalmImagePlugin
DEBUG:PIL.Image:Importing PcdImagePlugin
DEBUG:PIL.Image:Importing PcxImagePlugin
DEBUG:PIL.Image:Importing PdfImagePlugin
DEBUG:PIL.Image:Importing PixarImagePlugin
DEBUG:PIL.Image:Importing PngImagePlugin
DEBUG:PIL.Image:Importing PpmImagePlugin
DEBUG:PIL.Image:Importing PsdImagePlugin
DEBUG:PIL.Image:Importing QoiImagePlugin
DEBUG:PIL.Image:Importing SgiImagePlugin
DEBUG:PIL.Image:Importing SpiderImagePlugin
DEBUG:PIL.Image:Importing SunImagePlugin
DEBUG:PIL.Image:Importing TgaImagePlugin
DEBUG:PIL.Image:Importing TiffImagePlugin
DEBUG:PIL.Image:Importing WebPImagePlugin
DEBUG:PIL.Image:Importing WmfImagePlugin
DEBUG:PIL.Image:Importing XbmImagePlugin
DEBUG:PIL.Image:Importing XpmImagePlugin
DEBUG:PIL.Image:Importing XVThumbImagePlugin
12%|█▎ | 1/8 [01:08<07:58, 68.38s/it]
25%|██▌ | 2/8 [02:13<06:38, 66.48s/it]
38%|███▊ | 3/8 [03:18<05:29, 65.83s/it]
50%|█████ | 4/8 [04:23<04:22, 65.52s/it]
62%|██████▎ | 5/8 [05:28<03:15, 65.33s/it]
75%|███████▌ | 6/8 [06:33<02:10, 65.23s/it]
88%|████████▊ | 7/8 [07:38<01:05, 65.12s/it]
100%|██████████| 8/8 [08:43<00:00, 65.08s/it]
100%|██████████| 8/8 [08:43<00:00, 65.44s/it]
>>>>>>>>>>>>>>>>>>>>>>>>
>>> Step 1: generating actions ...
>>> Step 1: interacting with world model ...
>>>>>>>>>>>>>>>>>>>>>>>>
>>> Step 2: generating actions ...
>>> Step 2: interacting with world model ...
>>>>>>>>>>>>>>>>>>>>>>>>
>>> Step 3: generating actions ...
>>> Step 3: interacting with world model ...
>>>>>>>>>>>>>>>>>>>>>>>>
>>> Step 4: generating actions ...
>>> Step 4: interacting with world model ...
>>>>>>>>>>>>>>>>>>>>>>>>
>>> Step 5: generating actions ...
>>> Step 5: interacting with world model ...
>>>>>>>>>>>>>>>>>>>>>>>>

View File

@@ -0,0 +1,5 @@
{
"gt_video": "unitree_z1_dual_arm_cleanup_pencils/case1/unitree_z1_dual_arm_cleanup_pencils_case1.mp4",
"pred_video": "unitree_z1_dual_arm_cleanup_pencils/case1/output/inference/unitree_z1_dual_arm_cleanup_pencils_case1_amd.mp4",
"psnr": 19.586376345676264
}

View File

@@ -0,0 +1,5 @@
{
"gt_video": "/mnt/ASC1637/unifolm-world-model-action/unitree_z1_dual_arm_cleanup_pencils/case1/output/inference/unitree_z1_dual_arm_cleanup_pencils_case1_amd.mp4",
"pred_video": "/mnt/ASC1637/unifolm-world-model-action/unitree_z1_dual_arm_cleanup_pencils/case1/output/inference/0_full_fs4.mp4",
"psnr": 31.802224855380352
}

View File

@@ -0,0 +1,5 @@
#\!/bin/bash
res_dir="unitree_z1_dual_arm_cleanup_pencils/case1"
dataset="unitree_z1_dual_arm_cleanup_pencils"
TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL=1 CUDA_VISIBLE_DEVICES=0 python3 scripts/evaluation/profile_iteration.py --seed 123 --ckpt_path ckpts/unifolm_wma_dual_mix_bf16.ckpt --config configs/inference/world_model_interaction.yaml --savedir "${res_dir}/profile_output" --prompt_dir "${res_dir}/world_model_interaction_prompts" --dataset ${dataset} --bs 1 --height 320 --width 512 --unconditional_guidance_scale 1.0 --ddim_steps 50 --ddim_eta 1.0 --video_length 16 --frame_stride 4 --exe_steps 16 --n_iter 5 --warmup 1 --timestep_spacing uniform_trailing --guidance_rescale 0.7 --perframe_ae --vae_dtype bf16 --fast_policy_no_decode --csv "${res_dir}/profile_output/baseline.csv" 2>&1 | tee "${res_dir}/profile_output/profile.log"

View File

@@ -2,9 +2,9 @@ res_dir="unitree_z1_dual_arm_cleanup_pencils/case1"
dataset="unitree_z1_dual_arm_cleanup_pencils"
{
time CUDA_VISIBLE_DEVICES=0 python3 scripts/evaluation/world_model_interaction.py \
time TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL=1 CUDA_VISIBLE_DEVICES=0 python3 scripts/evaluation/world_model_interaction.py \
--seed 123 \
--ckpt_path ckpts/unifolm_wma_dual.ckpt \
--ckpt_path ckpts/unifolm_wma_dual_mixbf16.ckpt \
--config configs/inference/world_model_interaction.yaml \
--savedir "${res_dir}/output" \
--bs 1 --height 320 --width 512 \
@@ -20,5 +20,10 @@ dataset="unitree_z1_dual_arm_cleanup_pencils"
--n_iter 8 \
--timestep_spacing 'uniform_trailing' \
--guidance_rescale 0.7 \
--perframe_ae
--perframe_ae \
--diffusion_dtype fp32 \
--projector_mode fp32 \
--encoder_mode fp32 \
--vae_dtype fp32 \
--fast_policy_no_decode
} 2>&1 | tee "${res_dir}/output.log"

View File

@@ -0,0 +1,137 @@
/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/lightning_fabric/__init__.py:29: UserWarning: pkg_resources is deprecated as an API. See https://setuptools.pypa.io/en/latest/pkg_resources.html. The pkg_resources package is slated for removal as early as 2025-11-30. Refrain from using this package or pin to Setuptools<81.
__import__("pkg_resources").declare_namespace(__name__)
2026-02-08 06:59:34.465946: I tensorflow/core/util/port.cc:113] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2026-02-08 06:59:34.469367: I external/local_tsl/tsl/cuda/cudart_stub.cc:31] Could not find cuda drivers on your machine, GPU will not be used.
2026-02-08 06:59:34.500805: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2026-02-08 06:59:34.500837: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2026-02-08 06:59:34.502917: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2026-02-08 06:59:34.511434: I external/local_tsl/tsl/cuda/cudart_stub.cc:31] Could not find cuda drivers on your machine, GPU will not be used.
2026-02-08 06:59:34.511678: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 AVX512F AVX512_VNNI AVX512_BF16 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
2026-02-08 06:59:35.478194: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT
Global seed set to 123
/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/kornia/feature/lightglue.py:44: FutureWarning: `torch.cuda.amp.custom_fwd(args...)` is deprecated. Please use `torch.amp.custom_fwd(args..., device_type='cuda')` instead.
@torch.cuda.amp.custom_fwd(cast_inputs=torch.float32)
INFO:mainlogger:LatentVisualDiffusion: Running in v-prediction mode
INFO:unifolm_wma.models.diffusion_head.conditional_unet1d:number of parameters: 5.010531e+08
INFO:unifolm_wma.models.diffusion_head.conditional_unet1d:number of parameters: 5.010531e+08
AE working on z of shape (1, 4, 32, 32) = 4096 dimensions.
INFO:root:Loaded ViT-H-14 model config.
DEBUG:urllib3.connectionpool:Starting new HTTPS connection (1): hf-mirror.com:443
DEBUG:urllib3.connectionpool:https://hf-mirror.com:443 "HEAD /laion/CLIP-ViT-H-14-laion2B-s32B-b79K/resolve/main/open_clip_pytorch_model.bin HTTP/1.1" 302 0
INFO:root:Loading pretrained ViT-H-14 weights (laion2b_s32b_b79k).
/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/open_clip/factory.py:88: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.
checkpoint = torch.load(checkpoint_path, map_location=map_location)
INFO:root:Loaded ViT-H-14 model config.
DEBUG:urllib3.connectionpool:https://hf-mirror.com:443 "HEAD /laion/CLIP-ViT-H-14-laion2B-s32B-b79K/resolve/main/open_clip_pytorch_model.bin HTTP/1.1" 302 0
INFO:root:Loading pretrained ViT-H-14 weights (laion2b_s32b_b79k).
/mnt/ASC1637/unifolm-world-model-action/scripts/evaluation/world_model_interaction.py:86: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.
state_dict = torch.load(ckpt, map_location="cpu")
>>> model checkpoint loaded.
>>> Load pre-trained model ...
INFO:root:***** Configing Data *****
>>> unitree_z1_stackbox: 1 data samples loaded.
>>> unitree_z1_stackbox: data stats loaded.
>>> unitree_z1_stackbox: normalizer initiated.
>>> unitree_z1_dual_arm_stackbox: 1 data samples loaded.
>>> unitree_z1_dual_arm_stackbox: data stats loaded.
>>> unitree_z1_dual_arm_stackbox: normalizer initiated.
>>> unitree_z1_dual_arm_stackbox_v2: 1 data samples loaded.
>>> unitree_z1_dual_arm_stackbox_v2: data stats loaded.
>>> unitree_z1_dual_arm_stackbox_v2: normalizer initiated.
>>> unitree_z1_dual_arm_cleanup_pencils: 1 data samples loaded.
>>> unitree_z1_dual_arm_cleanup_pencils: data stats loaded.
>>> unitree_z1_dual_arm_cleanup_pencils: normalizer initiated.
>>> unitree_g1_pack_camera: 1 data samples loaded.
>>> unitree_g1_pack_camera: data stats loaded.
>>> unitree_g1_pack_camera: normalizer initiated.
>>> Dataset is successfully loaded ...
>>> Generate 16 frames under each generation ...
DEBUG:h5py._conv:Creating converter from 3 to 5
DEBUG:PIL.PngImagePlugin:STREAM b'IHDR' 16 13
DEBUG:PIL.PngImagePlugin:STREAM b'pHYs' 41 9
DEBUG:PIL.PngImagePlugin:STREAM b'IDAT' 62 4096
0%| | 0/8 [00:00<?, ?it/s]/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/torch/nn/functional.py:5501: UserWarning: Attempting to use hipBLASLt on an unsupported architecture! Overriding blas backend to hipblas (Triggered internally at ../aten/src/ATen/Context.cpp:296.)
proj = linear(q, w, b)
/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/torch/nn/functional.py:6278: UserWarning: Flash attention support on Navi31 GPU is still experimental. Enable it with TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL=1. (Triggered internally at ../aten/src/ATen/native/transformers/hip/sdp_utils.cpp:225.)
attn_output = scaled_dot_product_attention(
/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/torch/nn/functional.py:6278: UserWarning: Memory Efficient attention on Navi31 GPU is still experimental. Enable it with TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL=1. (Triggered internally at ../aten/src/ATen/native/transformers/hip/sdp_utils.cpp:269.)
attn_output = scaled_dot_product_attention(
>>> Step 0: generating actions ...
>>> Step 0: interacting with world model ...
DEBUG:PIL.Image:Importing BlpImagePlugin
DEBUG:PIL.Image:Importing BmpImagePlugin
DEBUG:PIL.Image:Importing BufrStubImagePlugin
DEBUG:PIL.Image:Importing CurImagePlugin
DEBUG:PIL.Image:Importing DcxImagePlugin
DEBUG:PIL.Image:Importing DdsImagePlugin
DEBUG:PIL.Image:Importing EpsImagePlugin
DEBUG:PIL.Image:Importing FitsImagePlugin
DEBUG:PIL.Image:Importing FitsStubImagePlugin
DEBUG:PIL.Image:Importing FliImagePlugin
DEBUG:PIL.Image:Importing FpxImagePlugin
DEBUG:PIL.Image:Image: failed to import FpxImagePlugin: No module named 'olefile'
DEBUG:PIL.Image:Importing FtexImagePlugin
DEBUG:PIL.Image:Importing GbrImagePlugin
DEBUG:PIL.Image:Importing GifImagePlugin
DEBUG:PIL.Image:Importing GribStubImagePlugin
DEBUG:PIL.Image:Importing Hdf5StubImagePlugin
DEBUG:PIL.Image:Importing IcnsImagePlugin
DEBUG:PIL.Image:Importing IcoImagePlugin
DEBUG:PIL.Image:Importing ImImagePlugin
DEBUG:PIL.Image:Importing ImtImagePlugin
DEBUG:PIL.Image:Importing IptcImagePlugin
DEBUG:PIL.Image:Importing JpegImagePlugin
DEBUG:PIL.Image:Importing Jpeg2KImagePlugin
DEBUG:PIL.Image:Importing McIdasImagePlugin
DEBUG:PIL.Image:Importing MicImagePlugin
DEBUG:PIL.Image:Image: failed to import MicImagePlugin: No module named 'olefile'
DEBUG:PIL.Image:Importing MpegImagePlugin
DEBUG:PIL.Image:Importing MpoImagePlugin
DEBUG:PIL.Image:Importing MspImagePlugin
DEBUG:PIL.Image:Importing PalmImagePlugin
DEBUG:PIL.Image:Importing PcdImagePlugin
DEBUG:PIL.Image:Importing PcxImagePlugin
DEBUG:PIL.Image:Importing PdfImagePlugin
DEBUG:PIL.Image:Importing PixarImagePlugin
DEBUG:PIL.Image:Importing PngImagePlugin
DEBUG:PIL.Image:Importing PpmImagePlugin
DEBUG:PIL.Image:Importing PsdImagePlugin
DEBUG:PIL.Image:Importing QoiImagePlugin
DEBUG:PIL.Image:Importing SgiImagePlugin
DEBUG:PIL.Image:Importing SpiderImagePlugin
DEBUG:PIL.Image:Importing SunImagePlugin
DEBUG:PIL.Image:Importing TgaImagePlugin
DEBUG:PIL.Image:Importing TiffImagePlugin
DEBUG:PIL.Image:Importing WebPImagePlugin
DEBUG:PIL.Image:Importing WmfImagePlugin
DEBUG:PIL.Image:Importing XbmImagePlugin
DEBUG:PIL.Image:Importing XpmImagePlugin
DEBUG:PIL.Image:Importing XVThumbImagePlugin
12%|█▎ | 1/8 [01:37<11:23, 97.57s/it]
25%|██▌ | 2/8 [03:14<09:44, 97.48s/it]
38%|███▊ | 3/8 [04:52<08:07, 97.47s/it]
50%|█████ | 4/8 [06:29<06:29, 97.49s/it]
62%|██████▎ | 5/8 [08:07<04:52, 97.42s/it]
75%|███████▌ | 6/8 [09:44<03:14, 97.32s/it]
88%|████████▊ | 7/8 [11:21<01:37, 97.34s/it]
100%|██████████| 8/8 [12:59<00:00, 97.36s/it]
100%|██████████| 8/8 [12:59<00:00, 97.40s/it]
>>>>>>>>>>>>>>>>>>>>>>>>
>>> Step 1: generating actions ...
>>> Step 1: interacting with world model ...
>>>>>>>>>>>>>>>>>>>>>>>>
>>> Step 2: generating actions ...
>>> Step 2: interacting with world model ...
>>>>>>>>>>>>>>>>>>>>>>>>
>>> Step 3: generating actions ...
>>> Step 3: interacting with world model ...
>>>>>>>>>>>>>>>>>>>>>>>>
>>> Step 4: generating actions ...
>>> Step 4: interacting with world model ...
>>>>>>>>>>>>>>>>>>>>>>>>
>>> Step 5: generating actions ...
>>> Step 5: interacting with world model ...
>>>>>>>>>>>>>>>>>>>>>>>>

View File

@@ -0,0 +1,5 @@
{
"gt_video": "unitree_z1_dual_arm_cleanup_pencils/case2/unitree_z1_dual_arm_cleanup_pencils_case2.mp4",
"pred_video": "unitree_z1_dual_arm_cleanup_pencils/case2/output/inference/unitree_z1_dual_arm_cleanup_pencils_case2_amd.mp4",
"psnr": 20.484298972158296
}

View File

@@ -0,0 +1,137 @@
/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/lightning_fabric/__init__.py:29: UserWarning: pkg_resources is deprecated as an API. See https://setuptools.pypa.io/en/latest/pkg_resources.html. The pkg_resources package is slated for removal as early as 2025-11-30. Refrain from using this package or pin to Setuptools<81.
__import__("pkg_resources").declare_namespace(__name__)
2026-02-08 07:18:52.629976: I tensorflow/core/util/port.cc:113] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2026-02-08 07:18:52.633025: I external/local_tsl/tsl/cuda/cudart_stub.cc:31] Could not find cuda drivers on your machine, GPU will not be used.
2026-02-08 07:18:52.663985: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2026-02-08 07:18:52.664018: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2026-02-08 07:18:52.665837: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2026-02-08 07:18:52.673889: I external/local_tsl/tsl/cuda/cudart_stub.cc:31] Could not find cuda drivers on your machine, GPU will not be used.
2026-02-08 07:18:52.674218: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 AVX512F AVX512_VNNI AVX512_BF16 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
2026-02-08 07:18:53.298338: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT
[rank: 0] Global seed set to 123
/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/kornia/feature/lightglue.py:44: FutureWarning: `torch.cuda.amp.custom_fwd(args...)` is deprecated. Please use `torch.amp.custom_fwd(args..., device_type='cuda')` instead.
@torch.cuda.amp.custom_fwd(cast_inputs=torch.float32)
INFO:mainlogger:LatentVisualDiffusion: Running in v-prediction mode
INFO:unifolm_wma.models.diffusion_head.conditional_unet1d:number of parameters: 5.010531e+08
INFO:unifolm_wma.models.diffusion_head.conditional_unet1d:number of parameters: 5.010531e+08
AE working on z of shape (1, 4, 32, 32) = 4096 dimensions.
INFO:root:Loaded ViT-H-14 model config.
DEBUG:urllib3.connectionpool:Starting new HTTPS connection (1): hf-mirror.com:443
DEBUG:urllib3.connectionpool:https://hf-mirror.com:443 "HEAD /laion/CLIP-ViT-H-14-laion2B-s32B-b79K/resolve/main/open_clip_pytorch_model.bin HTTP/1.1" 302 0
INFO:root:Loading pretrained ViT-H-14 weights (laion2b_s32b_b79k).
/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/open_clip/factory.py:88: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.
checkpoint = torch.load(checkpoint_path, map_location=map_location)
INFO:root:Loaded ViT-H-14 model config.
DEBUG:urllib3.connectionpool:https://hf-mirror.com:443 "HEAD /laion/CLIP-ViT-H-14-laion2B-s32B-b79K/resolve/main/open_clip_pytorch_model.bin HTTP/1.1" 302 0
INFO:root:Loading pretrained ViT-H-14 weights (laion2b_s32b_b79k).
/mnt/ASC1637/unifolm-world-model-action/scripts/evaluation/world_model_interaction.py:86: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.
state_dict = torch.load(ckpt, map_location="cpu")
>>> model checkpoint loaded.
>>> Load pre-trained model ...
INFO:root:***** Configing Data *****
>>> unitree_z1_stackbox: 1 data samples loaded.
>>> unitree_z1_stackbox: data stats loaded.
>>> unitree_z1_stackbox: normalizer initiated.
>>> unitree_z1_dual_arm_stackbox: 1 data samples loaded.
>>> unitree_z1_dual_arm_stackbox: data stats loaded.
>>> unitree_z1_dual_arm_stackbox: normalizer initiated.
>>> unitree_z1_dual_arm_stackbox_v2: 1 data samples loaded.
>>> unitree_z1_dual_arm_stackbox_v2: data stats loaded.
>>> unitree_z1_dual_arm_stackbox_v2: normalizer initiated.
>>> unitree_z1_dual_arm_cleanup_pencils: 1 data samples loaded.
>>> unitree_z1_dual_arm_cleanup_pencils: data stats loaded.
>>> unitree_z1_dual_arm_cleanup_pencils: normalizer initiated.
>>> unitree_g1_pack_camera: 1 data samples loaded.
>>> unitree_g1_pack_camera: data stats loaded.
>>> unitree_g1_pack_camera: normalizer initiated.
>>> Dataset is successfully loaded ...
>>> Generate 16 frames under each generation ...
DEBUG:h5py._conv:Creating converter from 3 to 5
DEBUG:PIL.PngImagePlugin:STREAM b'IHDR' 16 13
DEBUG:PIL.PngImagePlugin:STREAM b'pHYs' 41 9
DEBUG:PIL.PngImagePlugin:STREAM b'IDAT' 62 4096
0%| | 0/8 [00:00<?, ?it/s]/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/torch/nn/functional.py:5501: UserWarning: Attempting to use hipBLASLt on an unsupported architecture! Overriding blas backend to hipblas (Triggered internally at ../aten/src/ATen/Context.cpp:296.)
proj = linear(q, w, b)
/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/torch/nn/functional.py:6278: UserWarning: Flash attention support on Navi31 GPU is still experimental. Enable it with TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL=1. (Triggered internally at ../aten/src/ATen/native/transformers/hip/sdp_utils.cpp:225.)
attn_output = scaled_dot_product_attention(
/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/torch/nn/functional.py:6278: UserWarning: Memory Efficient attention on Navi31 GPU is still experimental. Enable it with TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL=1. (Triggered internally at ../aten/src/ATen/native/transformers/hip/sdp_utils.cpp:269.)
attn_output = scaled_dot_product_attention(
>>> Step 0: generating actions ...
>>> Step 0: interacting with world model ...
DEBUG:PIL.Image:Importing BlpImagePlugin
DEBUG:PIL.Image:Importing BmpImagePlugin
DEBUG:PIL.Image:Importing BufrStubImagePlugin
DEBUG:PIL.Image:Importing CurImagePlugin
DEBUG:PIL.Image:Importing DcxImagePlugin
DEBUG:PIL.Image:Importing DdsImagePlugin
DEBUG:PIL.Image:Importing EpsImagePlugin
DEBUG:PIL.Image:Importing FitsImagePlugin
DEBUG:PIL.Image:Importing FitsStubImagePlugin
DEBUG:PIL.Image:Importing FliImagePlugin
DEBUG:PIL.Image:Importing FpxImagePlugin
DEBUG:PIL.Image:Image: failed to import FpxImagePlugin: No module named 'olefile'
DEBUG:PIL.Image:Importing FtexImagePlugin
DEBUG:PIL.Image:Importing GbrImagePlugin
DEBUG:PIL.Image:Importing GifImagePlugin
DEBUG:PIL.Image:Importing GribStubImagePlugin
DEBUG:PIL.Image:Importing Hdf5StubImagePlugin
DEBUG:PIL.Image:Importing IcnsImagePlugin
DEBUG:PIL.Image:Importing IcoImagePlugin
DEBUG:PIL.Image:Importing ImImagePlugin
DEBUG:PIL.Image:Importing ImtImagePlugin
DEBUG:PIL.Image:Importing IptcImagePlugin
DEBUG:PIL.Image:Importing JpegImagePlugin
DEBUG:PIL.Image:Importing Jpeg2KImagePlugin
DEBUG:PIL.Image:Importing McIdasImagePlugin
DEBUG:PIL.Image:Importing MicImagePlugin
DEBUG:PIL.Image:Image: failed to import MicImagePlugin: No module named 'olefile'
DEBUG:PIL.Image:Importing MpegImagePlugin
DEBUG:PIL.Image:Importing MpoImagePlugin
DEBUG:PIL.Image:Importing MspImagePlugin
DEBUG:PIL.Image:Importing PalmImagePlugin
DEBUG:PIL.Image:Importing PcdImagePlugin
DEBUG:PIL.Image:Importing PcxImagePlugin
DEBUG:PIL.Image:Importing PdfImagePlugin
DEBUG:PIL.Image:Importing PixarImagePlugin
DEBUG:PIL.Image:Importing PngImagePlugin
DEBUG:PIL.Image:Importing PpmImagePlugin
DEBUG:PIL.Image:Importing PsdImagePlugin
DEBUG:PIL.Image:Importing QoiImagePlugin
DEBUG:PIL.Image:Importing SgiImagePlugin
DEBUG:PIL.Image:Importing SpiderImagePlugin
DEBUG:PIL.Image:Importing SunImagePlugin
DEBUG:PIL.Image:Importing TgaImagePlugin
DEBUG:PIL.Image:Importing TiffImagePlugin
DEBUG:PIL.Image:Importing WebPImagePlugin
DEBUG:PIL.Image:Importing WmfImagePlugin
DEBUG:PIL.Image:Importing XbmImagePlugin
DEBUG:PIL.Image:Importing XpmImagePlugin
DEBUG:PIL.Image:Importing XVThumbImagePlugin
12%|█▎ | 1/8 [01:40<11:43, 100.54s/it]
25%|██▌ | 2/8 [03:20<10:02, 100.36s/it]
38%|███▊ | 3/8 [05:01<08:21, 100.32s/it]
50%|█████ | 4/8 [06:41<06:41, 100.36s/it]
62%|██████▎ | 5/8 [08:21<05:00, 100.30s/it]
75%|███████▌ | 6/8 [10:01<03:20, 100.28s/it]
88%|████████▊ | 7/8 [11:42<01:40, 100.34s/it]
100%|██████████| 8/8 [13:22<00:00, 100.36s/it]
100%|██████████| 8/8 [13:22<00:00, 100.34s/it]
>>>>>>>>>>>>>>>>>>>>>>>>
>>> Step 1: generating actions ...
>>> Step 1: interacting with world model ...
>>>>>>>>>>>>>>>>>>>>>>>>
>>> Step 2: generating actions ...
>>> Step 2: interacting with world model ...
>>>>>>>>>>>>>>>>>>>>>>>>
>>> Step 3: generating actions ...
>>> Step 3: interacting with world model ...
>>>>>>>>>>>>>>>>>>>>>>>>
>>> Step 4: generating actions ...
>>> Step 4: interacting with world model ...
>>>>>>>>>>>>>>>>>>>>>>>>
>>> Step 5: generating actions ...
>>> Step 5: interacting with world model ...
>>>>>>>>>>>>>>>>>>>>>>>>

View File

@@ -0,0 +1,5 @@
{
"gt_video": "unitree_z1_dual_arm_cleanup_pencils/case3/unitree_z1_dual_arm_cleanup_pencils_case3.mp4",
"pred_video": "unitree_z1_dual_arm_cleanup_pencils/case3/output/inference/unitree_z1_dual_arm_cleanup_pencils_case3_amd.mp4",
"psnr": 21.20205061239349
}

View File

@@ -0,0 +1,137 @@
/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/lightning_fabric/__init__.py:29: UserWarning: pkg_resources is deprecated as an API. See https://setuptools.pypa.io/en/latest/pkg_resources.html. The pkg_resources package is slated for removal as early as 2025-11-30. Refrain from using this package or pin to Setuptools<81.
__import__("pkg_resources").declare_namespace(__name__)
2026-02-08 07:22:15.333099: I tensorflow/core/util/port.cc:113] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2026-02-08 07:22:15.336215: I external/local_tsl/tsl/cuda/cudart_stub.cc:31] Could not find cuda drivers on your machine, GPU will not be used.
2026-02-08 07:22:15.366489: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2026-02-08 07:22:15.366522: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2026-02-08 07:22:15.368294: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2026-02-08 07:22:15.376202: I external/local_tsl/tsl/cuda/cudart_stub.cc:31] Could not find cuda drivers on your machine, GPU will not be used.
2026-02-08 07:22:15.376444: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 AVX512F AVX512_VNNI AVX512_BF16 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
2026-02-08 07:22:15.995383: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT
[rank: 0] Global seed set to 123
/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/kornia/feature/lightglue.py:44: FutureWarning: `torch.cuda.amp.custom_fwd(args...)` is deprecated. Please use `torch.amp.custom_fwd(args..., device_type='cuda')` instead.
@torch.cuda.amp.custom_fwd(cast_inputs=torch.float32)
INFO:mainlogger:LatentVisualDiffusion: Running in v-prediction mode
INFO:unifolm_wma.models.diffusion_head.conditional_unet1d:number of parameters: 5.010531e+08
INFO:unifolm_wma.models.diffusion_head.conditional_unet1d:number of parameters: 5.010531e+08
AE working on z of shape (1, 4, 32, 32) = 4096 dimensions.
INFO:root:Loaded ViT-H-14 model config.
DEBUG:urllib3.connectionpool:Starting new HTTPS connection (1): hf-mirror.com:443
DEBUG:urllib3.connectionpool:https://hf-mirror.com:443 "HEAD /laion/CLIP-ViT-H-14-laion2B-s32B-b79K/resolve/main/open_clip_pytorch_model.bin HTTP/1.1" 302 0
INFO:root:Loading pretrained ViT-H-14 weights (laion2b_s32b_b79k).
/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/open_clip/factory.py:88: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.
checkpoint = torch.load(checkpoint_path, map_location=map_location)
INFO:root:Loaded ViT-H-14 model config.
DEBUG:urllib3.connectionpool:https://hf-mirror.com:443 "HEAD /laion/CLIP-ViT-H-14-laion2B-s32B-b79K/resolve/main/open_clip_pytorch_model.bin HTTP/1.1" 302 0
INFO:root:Loading pretrained ViT-H-14 weights (laion2b_s32b_b79k).
/mnt/ASC1637/unifolm-world-model-action/scripts/evaluation/world_model_interaction.py:86: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.
state_dict = torch.load(ckpt, map_location="cpu")
>>> model checkpoint loaded.
>>> Load pre-trained model ...
INFO:root:***** Configing Data *****
>>> unitree_z1_stackbox: 1 data samples loaded.
>>> unitree_z1_stackbox: data stats loaded.
>>> unitree_z1_stackbox: normalizer initiated.
>>> unitree_z1_dual_arm_stackbox: 1 data samples loaded.
>>> unitree_z1_dual_arm_stackbox: data stats loaded.
>>> unitree_z1_dual_arm_stackbox: normalizer initiated.
>>> unitree_z1_dual_arm_stackbox_v2: 1 data samples loaded.
>>> unitree_z1_dual_arm_stackbox_v2: data stats loaded.
>>> unitree_z1_dual_arm_stackbox_v2: normalizer initiated.
>>> unitree_z1_dual_arm_cleanup_pencils: 1 data samples loaded.
>>> unitree_z1_dual_arm_cleanup_pencils: data stats loaded.
>>> unitree_z1_dual_arm_cleanup_pencils: normalizer initiated.
>>> unitree_g1_pack_camera: 1 data samples loaded.
>>> unitree_g1_pack_camera: data stats loaded.
>>> unitree_g1_pack_camera: normalizer initiated.
>>> Dataset is successfully loaded ...
>>> Generate 16 frames under each generation ...
DEBUG:h5py._conv:Creating converter from 3 to 5
DEBUG:PIL.PngImagePlugin:STREAM b'IHDR' 16 13
DEBUG:PIL.PngImagePlugin:STREAM b'pHYs' 41 9
DEBUG:PIL.PngImagePlugin:STREAM b'IDAT' 62 4096
0%| | 0/8 [00:00<?, ?it/s]/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/torch/nn/functional.py:5501: UserWarning: Attempting to use hipBLASLt on an unsupported architecture! Overriding blas backend to hipblas (Triggered internally at ../aten/src/ATen/Context.cpp:296.)
proj = linear(q, w, b)
/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/torch/nn/functional.py:6278: UserWarning: Flash attention support on Navi31 GPU is still experimental. Enable it with TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL=1. (Triggered internally at ../aten/src/ATen/native/transformers/hip/sdp_utils.cpp:225.)
attn_output = scaled_dot_product_attention(
/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/torch/nn/functional.py:6278: UserWarning: Memory Efficient attention on Navi31 GPU is still experimental. Enable it with TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL=1. (Triggered internally at ../aten/src/ATen/native/transformers/hip/sdp_utils.cpp:269.)
attn_output = scaled_dot_product_attention(
>>> Step 0: generating actions ...
>>> Step 0: interacting with world model ...
DEBUG:PIL.Image:Importing BlpImagePlugin
DEBUG:PIL.Image:Importing BmpImagePlugin
DEBUG:PIL.Image:Importing BufrStubImagePlugin
DEBUG:PIL.Image:Importing CurImagePlugin
DEBUG:PIL.Image:Importing DcxImagePlugin
DEBUG:PIL.Image:Importing DdsImagePlugin
DEBUG:PIL.Image:Importing EpsImagePlugin
DEBUG:PIL.Image:Importing FitsImagePlugin
DEBUG:PIL.Image:Importing FitsStubImagePlugin
DEBUG:PIL.Image:Importing FliImagePlugin
DEBUG:PIL.Image:Importing FpxImagePlugin
DEBUG:PIL.Image:Image: failed to import FpxImagePlugin: No module named 'olefile'
DEBUG:PIL.Image:Importing FtexImagePlugin
DEBUG:PIL.Image:Importing GbrImagePlugin
DEBUG:PIL.Image:Importing GifImagePlugin
DEBUG:PIL.Image:Importing GribStubImagePlugin
DEBUG:PIL.Image:Importing Hdf5StubImagePlugin
DEBUG:PIL.Image:Importing IcnsImagePlugin
DEBUG:PIL.Image:Importing IcoImagePlugin
DEBUG:PIL.Image:Importing ImImagePlugin
DEBUG:PIL.Image:Importing ImtImagePlugin
DEBUG:PIL.Image:Importing IptcImagePlugin
DEBUG:PIL.Image:Importing JpegImagePlugin
DEBUG:PIL.Image:Importing Jpeg2KImagePlugin
DEBUG:PIL.Image:Importing McIdasImagePlugin
DEBUG:PIL.Image:Importing MicImagePlugin
DEBUG:PIL.Image:Image: failed to import MicImagePlugin: No module named 'olefile'
DEBUG:PIL.Image:Importing MpegImagePlugin
DEBUG:PIL.Image:Importing MpoImagePlugin
DEBUG:PIL.Image:Importing MspImagePlugin
DEBUG:PIL.Image:Importing PalmImagePlugin
DEBUG:PIL.Image:Importing PcdImagePlugin
DEBUG:PIL.Image:Importing PcxImagePlugin
DEBUG:PIL.Image:Importing PdfImagePlugin
DEBUG:PIL.Image:Importing PixarImagePlugin
DEBUG:PIL.Image:Importing PngImagePlugin
DEBUG:PIL.Image:Importing PpmImagePlugin
DEBUG:PIL.Image:Importing PsdImagePlugin
DEBUG:PIL.Image:Importing QoiImagePlugin
DEBUG:PIL.Image:Importing SgiImagePlugin
DEBUG:PIL.Image:Importing SpiderImagePlugin
DEBUG:PIL.Image:Importing SunImagePlugin
DEBUG:PIL.Image:Importing TgaImagePlugin
DEBUG:PIL.Image:Importing TiffImagePlugin
DEBUG:PIL.Image:Importing WebPImagePlugin
DEBUG:PIL.Image:Importing WmfImagePlugin
DEBUG:PIL.Image:Importing XbmImagePlugin
DEBUG:PIL.Image:Importing XpmImagePlugin
DEBUG:PIL.Image:Importing XVThumbImagePlugin
12%|█▎ | 1/8 [01:37<11:23, 97.68s/it]
25%|██▌ | 2/8 [03:15<09:47, 97.83s/it]
38%|███▊ | 3/8 [04:53<08:09, 97.91s/it]
50%|█████ | 4/8 [06:31<06:32, 98.03s/it]
62%|██████▎ | 5/8 [08:10<04:54, 98.11s/it]
75%|███████▌ | 6/8 [09:48<03:16, 98.18s/it]
88%|████████▊ | 7/8 [11:26<01:38, 98.24s/it]
100%|██████████| 8/8 [13:04<00:00, 98.16s/it]
100%|██████████| 8/8 [13:04<00:00, 98.09s/it]
>>>>>>>>>>>>>>>>>>>>>>>>
>>> Step 1: generating actions ...
>>> Step 1: interacting with world model ...
>>>>>>>>>>>>>>>>>>>>>>>>
>>> Step 2: generating actions ...
>>> Step 2: interacting with world model ...
>>>>>>>>>>>>>>>>>>>>>>>>
>>> Step 3: generating actions ...
>>> Step 3: interacting with world model ...
>>>>>>>>>>>>>>>>>>>>>>>>
>>> Step 4: generating actions ...
>>> Step 4: interacting with world model ...
>>>>>>>>>>>>>>>>>>>>>>>>
>>> Step 5: generating actions ...
>>> Step 5: interacting with world model ...
>>>>>>>>>>>>>>>>>>>>>>>>

View File

@@ -0,0 +1,5 @@
{
"gt_video": "unitree_z1_dual_arm_cleanup_pencils/case4/unitree_z1_dual_arm_cleanup_pencils_case4.mp4",
"pred_video": "unitree_z1_dual_arm_cleanup_pencils/case4/output/inference/unitree_z1_dual_arm_cleanup_pencils_case4_amd.mp4",
"psnr": 21.130122583788612
}

View File

@@ -0,0 +1,134 @@
/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/lightning_fabric/__init__.py:29: UserWarning: pkg_resources is deprecated as an API. See https://setuptools.pypa.io/en/latest/pkg_resources.html. The pkg_resources package is slated for removal as early as 2025-11-30. Refrain from using this package or pin to Setuptools<81.
__import__("pkg_resources").declare_namespace(__name__)
2026-02-08 07:24:40.357099: I tensorflow/core/util/port.cc:113] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2026-02-08 07:24:40.360365: I external/local_tsl/tsl/cuda/cudart_stub.cc:31] Could not find cuda drivers on your machine, GPU will not be used.
2026-02-08 07:24:40.391744: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2026-02-08 07:24:40.391772: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2026-02-08 07:24:40.393608: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2026-02-08 07:24:40.401837: I external/local_tsl/tsl/cuda/cudart_stub.cc:31] Could not find cuda drivers on your machine, GPU will not be used.
2026-02-08 07:24:40.402077: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 AVX512F AVX512_VNNI AVX512_BF16 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
2026-02-08 07:24:41.022382: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT
Global seed set to 123
/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/kornia/feature/lightglue.py:44: FutureWarning: `torch.cuda.amp.custom_fwd(args...)` is deprecated. Please use `torch.amp.custom_fwd(args..., device_type='cuda')` instead.
@torch.cuda.amp.custom_fwd(cast_inputs=torch.float32)
INFO:mainlogger:LatentVisualDiffusion: Running in v-prediction mode
INFO:unifolm_wma.models.diffusion_head.conditional_unet1d:number of parameters: 5.010531e+08
INFO:unifolm_wma.models.diffusion_head.conditional_unet1d:number of parameters: 5.010531e+08
AE working on z of shape (1, 4, 32, 32) = 4096 dimensions.
INFO:root:Loaded ViT-H-14 model config.
DEBUG:urllib3.connectionpool:Starting new HTTPS connection (1): hf-mirror.com:443
DEBUG:urllib3.connectionpool:https://hf-mirror.com:443 "HEAD /laion/CLIP-ViT-H-14-laion2B-s32B-b79K/resolve/main/open_clip_pytorch_model.bin HTTP/1.1" 302 0
INFO:root:Loading pretrained ViT-H-14 weights (laion2b_s32b_b79k).
/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/open_clip/factory.py:88: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.
checkpoint = torch.load(checkpoint_path, map_location=map_location)
INFO:root:Loaded ViT-H-14 model config.
DEBUG:urllib3.connectionpool:https://hf-mirror.com:443 "HEAD /laion/CLIP-ViT-H-14-laion2B-s32B-b79K/resolve/main/open_clip_pytorch_model.bin HTTP/1.1" 302 0
INFO:root:Loading pretrained ViT-H-14 weights (laion2b_s32b_b79k).
/mnt/ASC1637/unifolm-world-model-action/scripts/evaluation/world_model_interaction.py:86: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.
state_dict = torch.load(ckpt, map_location="cpu")
>>> model checkpoint loaded.
>>> Load pre-trained model ...
INFO:root:***** Configing Data *****
>>> unitree_z1_stackbox: 1 data samples loaded.
>>> unitree_z1_stackbox: data stats loaded.
>>> unitree_z1_stackbox: normalizer initiated.
>>> unitree_z1_dual_arm_stackbox: 1 data samples loaded.
>>> unitree_z1_dual_arm_stackbox: data stats loaded.
>>> unitree_z1_dual_arm_stackbox: normalizer initiated.
>>> unitree_z1_dual_arm_stackbox_v2: 1 data samples loaded.
>>> unitree_z1_dual_arm_stackbox_v2: data stats loaded.
>>> unitree_z1_dual_arm_stackbox_v2: normalizer initiated.
>>> unitree_z1_dual_arm_cleanup_pencils: 1 data samples loaded.
>>> unitree_z1_dual_arm_cleanup_pencils: data stats loaded.
>>> unitree_z1_dual_arm_cleanup_pencils: normalizer initiated.
>>> unitree_g1_pack_camera: 1 data samples loaded.
>>> unitree_g1_pack_camera: data stats loaded.
>>> unitree_g1_pack_camera: normalizer initiated.
>>> Dataset is successfully loaded ...
>>> Generate 16 frames under each generation ...
DEBUG:h5py._conv:Creating converter from 3 to 5
DEBUG:PIL.PngImagePlugin:STREAM b'IHDR' 16 13
DEBUG:PIL.PngImagePlugin:STREAM b'pHYs' 41 9
DEBUG:PIL.PngImagePlugin:STREAM b'IDAT' 62 4096
0%| | 0/7 [00:00<?, ?it/s]/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/torch/nn/functional.py:5501: UserWarning: Attempting to use hipBLASLt on an unsupported architecture! Overriding blas backend to hipblas (Triggered internally at ../aten/src/ATen/Context.cpp:296.)
proj = linear(q, w, b)
/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/torch/nn/functional.py:6278: UserWarning: Flash attention support on Navi31 GPU is still experimental. Enable it with TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL=1. (Triggered internally at ../aten/src/ATen/native/transformers/hip/sdp_utils.cpp:225.)
attn_output = scaled_dot_product_attention(
/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/torch/nn/functional.py:6278: UserWarning: Memory Efficient attention on Navi31 GPU is still experimental. Enable it with TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL=1. (Triggered internally at ../aten/src/ATen/native/transformers/hip/sdp_utils.cpp:269.)
attn_output = scaled_dot_product_attention(
>>> Step 0: generating actions ...
>>> Step 0: interacting with world model ...
DEBUG:PIL.Image:Importing BlpImagePlugin
DEBUG:PIL.Image:Importing BmpImagePlugin
DEBUG:PIL.Image:Importing BufrStubImagePlugin
DEBUG:PIL.Image:Importing CurImagePlugin
DEBUG:PIL.Image:Importing DcxImagePlugin
DEBUG:PIL.Image:Importing DdsImagePlugin
DEBUG:PIL.Image:Importing EpsImagePlugin
DEBUG:PIL.Image:Importing FitsImagePlugin
DEBUG:PIL.Image:Importing FitsStubImagePlugin
DEBUG:PIL.Image:Importing FliImagePlugin
DEBUG:PIL.Image:Importing FpxImagePlugin
DEBUG:PIL.Image:Image: failed to import FpxImagePlugin: No module named 'olefile'
DEBUG:PIL.Image:Importing FtexImagePlugin
DEBUG:PIL.Image:Importing GbrImagePlugin
DEBUG:PIL.Image:Importing GifImagePlugin
DEBUG:PIL.Image:Importing GribStubImagePlugin
DEBUG:PIL.Image:Importing Hdf5StubImagePlugin
DEBUG:PIL.Image:Importing IcnsImagePlugin
DEBUG:PIL.Image:Importing IcoImagePlugin
DEBUG:PIL.Image:Importing ImImagePlugin
DEBUG:PIL.Image:Importing ImtImagePlugin
DEBUG:PIL.Image:Importing IptcImagePlugin
DEBUG:PIL.Image:Importing JpegImagePlugin
DEBUG:PIL.Image:Importing Jpeg2KImagePlugin
DEBUG:PIL.Image:Importing McIdasImagePlugin
DEBUG:PIL.Image:Importing MicImagePlugin
DEBUG:PIL.Image:Image: failed to import MicImagePlugin: No module named 'olefile'
DEBUG:PIL.Image:Importing MpegImagePlugin
DEBUG:PIL.Image:Importing MpoImagePlugin
DEBUG:PIL.Image:Importing MspImagePlugin
DEBUG:PIL.Image:Importing PalmImagePlugin
DEBUG:PIL.Image:Importing PcdImagePlugin
DEBUG:PIL.Image:Importing PcxImagePlugin
DEBUG:PIL.Image:Importing PdfImagePlugin
DEBUG:PIL.Image:Importing PixarImagePlugin
DEBUG:PIL.Image:Importing PngImagePlugin
DEBUG:PIL.Image:Importing PpmImagePlugin
DEBUG:PIL.Image:Importing PsdImagePlugin
DEBUG:PIL.Image:Importing QoiImagePlugin
DEBUG:PIL.Image:Importing SgiImagePlugin
DEBUG:PIL.Image:Importing SpiderImagePlugin
DEBUG:PIL.Image:Importing SunImagePlugin
DEBUG:PIL.Image:Importing TgaImagePlugin
DEBUG:PIL.Image:Importing TiffImagePlugin
DEBUG:PIL.Image:Importing WebPImagePlugin
DEBUG:PIL.Image:Importing WmfImagePlugin
DEBUG:PIL.Image:Importing XbmImagePlugin
DEBUG:PIL.Image:Importing XpmImagePlugin
DEBUG:PIL.Image:Importing XVThumbImagePlugin
14%|█▍ | 1/7 [01:41<10:09, 101.63s/it]
29%|██▊ | 2/7 [03:20<08:18, 99.78s/it]
43%|████▎ | 3/7 [04:58<06:36, 99.24s/it]
57%|█████▋ | 4/7 [06:37<04:57, 99.05s/it]
71%|███████▏ | 5/7 [08:16<03:17, 98.90s/it]
86%|████████▌ | 6/7 [09:54<01:38, 98.80s/it]
100%|██████████| 7/7 [11:33<00:00, 98.70s/it]
100%|██████████| 7/7 [11:33<00:00, 99.03s/it]
>>>>>>>>>>>>>>>>>>>>>>>>
>>> Step 1: generating actions ...
>>> Step 1: interacting with world model ...
>>>>>>>>>>>>>>>>>>>>>>>>
>>> Step 2: generating actions ...
>>> Step 2: interacting with world model ...
>>>>>>>>>>>>>>>>>>>>>>>>
>>> Step 3: generating actions ...
>>> Step 3: interacting with world model ...
>>>>>>>>>>>>>>>>>>>>>>>>
>>> Step 4: generating actions ...
>>> Step 4: interacting with world model ...
>>>>>>>>>>>>>>>>>>>>>>>>
>>> Step 5: generating actions ...

View File

@@ -0,0 +1,5 @@
{
"gt_video": "unitree_z1_dual_arm_stackbox/case1/unitree_z1_dual_arm_stackbox_case1.mp4",
"pred_video": "unitree_z1_dual_arm_stackbox/case1/output/inference/unitree_z1_dual_arm_stackbox_case1_amd.mp4",
"psnr": 21.258130518117493
}

View File

@@ -2,7 +2,7 @@ res_dir="unitree_z1_dual_arm_stackbox/case1"
dataset="unitree_z1_dual_arm_stackbox"
{
time CUDA_VISIBLE_DEVICES=0 python3 scripts/evaluation/world_model_interaction.py \
time CUDA_VISIBLE_DEVICES=7 python3 scripts/evaluation/world_model_interaction.py \
--seed 123 \
--ckpt_path ckpts/unifolm_wma_dual.ckpt \
--config configs/inference/world_model_interaction.yaml \

View File

@@ -0,0 +1,134 @@
/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/lightning_fabric/__init__.py:29: UserWarning: pkg_resources is deprecated as an API. See https://setuptools.pypa.io/en/latest/pkg_resources.html. The pkg_resources package is slated for removal as early as 2025-11-30. Refrain from using this package or pin to Setuptools<81.
__import__("pkg_resources").declare_namespace(__name__)
2026-02-08 07:25:18.653033: I tensorflow/core/util/port.cc:113] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2026-02-08 07:25:18.656060: I external/local_tsl/tsl/cuda/cudart_stub.cc:31] Could not find cuda drivers on your machine, GPU will not be used.
2026-02-08 07:25:18.687077: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2026-02-08 07:25:18.687119: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2026-02-08 07:25:18.688915: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2026-02-08 07:25:18.697008: I external/local_tsl/tsl/cuda/cudart_stub.cc:31] Could not find cuda drivers on your machine, GPU will not be used.
2026-02-08 07:25:18.697255: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 AVX512F AVX512_VNNI AVX512_BF16 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
2026-02-08 07:25:19.338303: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT
Global seed set to 123
/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/kornia/feature/lightglue.py:44: FutureWarning: `torch.cuda.amp.custom_fwd(args...)` is deprecated. Please use `torch.amp.custom_fwd(args..., device_type='cuda')` instead.
@torch.cuda.amp.custom_fwd(cast_inputs=torch.float32)
INFO:mainlogger:LatentVisualDiffusion: Running in v-prediction mode
INFO:unifolm_wma.models.diffusion_head.conditional_unet1d:number of parameters: 5.010531e+08
INFO:unifolm_wma.models.diffusion_head.conditional_unet1d:number of parameters: 5.010531e+08
AE working on z of shape (1, 4, 32, 32) = 4096 dimensions.
INFO:root:Loaded ViT-H-14 model config.
DEBUG:urllib3.connectionpool:Starting new HTTPS connection (1): hf-mirror.com:443
DEBUG:urllib3.connectionpool:https://hf-mirror.com:443 "HEAD /laion/CLIP-ViT-H-14-laion2B-s32B-b79K/resolve/main/open_clip_pytorch_model.bin HTTP/1.1" 302 0
INFO:root:Loading pretrained ViT-H-14 weights (laion2b_s32b_b79k).
/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/open_clip/factory.py:88: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.
checkpoint = torch.load(checkpoint_path, map_location=map_location)
INFO:root:Loaded ViT-H-14 model config.
DEBUG:urllib3.connectionpool:https://hf-mirror.com:443 "HEAD /laion/CLIP-ViT-H-14-laion2B-s32B-b79K/resolve/main/open_clip_pytorch_model.bin HTTP/1.1" 302 0
INFO:root:Loading pretrained ViT-H-14 weights (laion2b_s32b_b79k).
/mnt/ASC1637/unifolm-world-model-action/scripts/evaluation/world_model_interaction.py:86: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.
state_dict = torch.load(ckpt, map_location="cpu")
>>> model checkpoint loaded.
>>> Load pre-trained model ...
INFO:root:***** Configing Data *****
>>> unitree_z1_stackbox: 1 data samples loaded.
>>> unitree_z1_stackbox: data stats loaded.
>>> unitree_z1_stackbox: normalizer initiated.
>>> unitree_z1_dual_arm_stackbox: 1 data samples loaded.
>>> unitree_z1_dual_arm_stackbox: data stats loaded.
>>> unitree_z1_dual_arm_stackbox: normalizer initiated.
>>> unitree_z1_dual_arm_stackbox_v2: 1 data samples loaded.
>>> unitree_z1_dual_arm_stackbox_v2: data stats loaded.
>>> unitree_z1_dual_arm_stackbox_v2: normalizer initiated.
>>> unitree_z1_dual_arm_cleanup_pencils: 1 data samples loaded.
>>> unitree_z1_dual_arm_cleanup_pencils: data stats loaded.
>>> unitree_z1_dual_arm_cleanup_pencils: normalizer initiated.
>>> unitree_g1_pack_camera: 1 data samples loaded.
>>> unitree_g1_pack_camera: data stats loaded.
>>> unitree_g1_pack_camera: normalizer initiated.
>>> Dataset is successfully loaded ...
>>> Generate 16 frames under each generation ...
DEBUG:h5py._conv:Creating converter from 3 to 5
DEBUG:PIL.PngImagePlugin:STREAM b'IHDR' 16 13
DEBUG:PIL.PngImagePlugin:STREAM b'pHYs' 41 9
DEBUG:PIL.PngImagePlugin:STREAM b'IDAT' 62 4096
0%| | 0/7 [00:00<?, ?it/s]/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/torch/nn/functional.py:5501: UserWarning: Attempting to use hipBLASLt on an unsupported architecture! Overriding blas backend to hipblas (Triggered internally at ../aten/src/ATen/Context.cpp:296.)
proj = linear(q, w, b)
/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/torch/nn/functional.py:6278: UserWarning: Flash attention support on Navi31 GPU is still experimental. Enable it with TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL=1. (Triggered internally at ../aten/src/ATen/native/transformers/hip/sdp_utils.cpp:225.)
attn_output = scaled_dot_product_attention(
/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/torch/nn/functional.py:6278: UserWarning: Memory Efficient attention on Navi31 GPU is still experimental. Enable it with TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL=1. (Triggered internally at ../aten/src/ATen/native/transformers/hip/sdp_utils.cpp:269.)
attn_output = scaled_dot_product_attention(
>>> Step 0: generating actions ...
>>> Step 0: interacting with world model ...
DEBUG:PIL.Image:Importing BlpImagePlugin
DEBUG:PIL.Image:Importing BmpImagePlugin
DEBUG:PIL.Image:Importing BufrStubImagePlugin
DEBUG:PIL.Image:Importing CurImagePlugin
DEBUG:PIL.Image:Importing DcxImagePlugin
DEBUG:PIL.Image:Importing DdsImagePlugin
DEBUG:PIL.Image:Importing EpsImagePlugin
DEBUG:PIL.Image:Importing FitsImagePlugin
DEBUG:PIL.Image:Importing FitsStubImagePlugin
DEBUG:PIL.Image:Importing FliImagePlugin
DEBUG:PIL.Image:Importing FpxImagePlugin
DEBUG:PIL.Image:Image: failed to import FpxImagePlugin: No module named 'olefile'
DEBUG:PIL.Image:Importing FtexImagePlugin
DEBUG:PIL.Image:Importing GbrImagePlugin
DEBUG:PIL.Image:Importing GifImagePlugin
DEBUG:PIL.Image:Importing GribStubImagePlugin
DEBUG:PIL.Image:Importing Hdf5StubImagePlugin
DEBUG:PIL.Image:Importing IcnsImagePlugin
DEBUG:PIL.Image:Importing IcoImagePlugin
DEBUG:PIL.Image:Importing ImImagePlugin
DEBUG:PIL.Image:Importing ImtImagePlugin
DEBUG:PIL.Image:Importing IptcImagePlugin
DEBUG:PIL.Image:Importing JpegImagePlugin
DEBUG:PIL.Image:Importing Jpeg2KImagePlugin
DEBUG:PIL.Image:Importing McIdasImagePlugin
DEBUG:PIL.Image:Importing MicImagePlugin
DEBUG:PIL.Image:Image: failed to import MicImagePlugin: No module named 'olefile'
DEBUG:PIL.Image:Importing MpegImagePlugin
DEBUG:PIL.Image:Importing MpoImagePlugin
DEBUG:PIL.Image:Importing MspImagePlugin
DEBUG:PIL.Image:Importing PalmImagePlugin
DEBUG:PIL.Image:Importing PcdImagePlugin
DEBUG:PIL.Image:Importing PcxImagePlugin
DEBUG:PIL.Image:Importing PdfImagePlugin
DEBUG:PIL.Image:Importing PixarImagePlugin
DEBUG:PIL.Image:Importing PngImagePlugin
DEBUG:PIL.Image:Importing PpmImagePlugin
DEBUG:PIL.Image:Importing PsdImagePlugin
DEBUG:PIL.Image:Importing QoiImagePlugin
DEBUG:PIL.Image:Importing SgiImagePlugin
DEBUG:PIL.Image:Importing SpiderImagePlugin
DEBUG:PIL.Image:Importing SunImagePlugin
DEBUG:PIL.Image:Importing TgaImagePlugin
DEBUG:PIL.Image:Importing TiffImagePlugin
DEBUG:PIL.Image:Importing WebPImagePlugin
DEBUG:PIL.Image:Importing WmfImagePlugin
DEBUG:PIL.Image:Importing XbmImagePlugin
DEBUG:PIL.Image:Importing XpmImagePlugin
DEBUG:PIL.Image:Importing XVThumbImagePlugin
14%|█▍ | 1/7 [01:39<09:56, 99.35s/it]
29%|██▊ | 2/7 [03:18<08:17, 99.50s/it]
43%|████▎ | 3/7 [04:58<06:38, 99.54s/it]
57%|█████▋ | 4/7 [06:38<04:58, 99.52s/it]
71%|███████▏ | 5/7 [08:17<03:19, 99.55s/it]
86%|████████▌ | 6/7 [09:57<01:39, 99.53s/it]
100%|██████████| 7/7 [11:36<00:00, 99.50s/it]
100%|██████████| 7/7 [11:36<00:00, 99.51s/it]
>>>>>>>>>>>>>>>>>>>>>>>>
>>> Step 1: generating actions ...
>>> Step 1: interacting with world model ...
>>>>>>>>>>>>>>>>>>>>>>>>
>>> Step 2: generating actions ...
>>> Step 2: interacting with world model ...
>>>>>>>>>>>>>>>>>>>>>>>>
>>> Step 3: generating actions ...
>>> Step 3: interacting with world model ...
>>>>>>>>>>>>>>>>>>>>>>>>
>>> Step 4: generating actions ...
>>> Step 4: interacting with world model ...
>>>>>>>>>>>>>>>>>>>>>>>>
>>> Step 5: generating actions ...

View File

@@ -0,0 +1,5 @@
{
"gt_video": "unitree_z1_dual_arm_stackbox/case2/unitree_z1_dual_arm_stackbox_case2.mp4",
"pred_video": "unitree_z1_dual_arm_stackbox/case2/output/inference/unitree_z1_dual_arm_stackbox_case2_amd.mp4",
"psnr": 23.878153424077645
}

View File

@@ -2,7 +2,7 @@ res_dir="unitree_z1_dual_arm_stackbox/case2"
dataset="unitree_z1_dual_arm_stackbox"
{
time CUDA_VISIBLE_DEVICES=0 python3 scripts/evaluation/world_model_interaction.py \
time CUDA_VISIBLE_DEVICES=6 python3 scripts/evaluation/world_model_interaction.py \
--seed 123 \
--ckpt_path ckpts/unifolm_wma_dual.ckpt \
--config configs/inference/world_model_interaction.yaml \

View File

@@ -0,0 +1,134 @@
/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/lightning_fabric/__init__.py:29: UserWarning: pkg_resources is deprecated as an API. See https://setuptools.pypa.io/en/latest/pkg_resources.html. The pkg_resources package is slated for removal as early as 2025-11-30. Refrain from using this package or pin to Setuptools<81.
__import__("pkg_resources").declare_namespace(__name__)
2026-02-08 07:35:33.682231: I tensorflow/core/util/port.cc:113] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2026-02-08 07:35:33.685275: I external/local_tsl/tsl/cuda/cudart_stub.cc:31] Could not find cuda drivers on your machine, GPU will not be used.
2026-02-08 07:35:33.716682: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2026-02-08 07:35:33.716728: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2026-02-08 07:35:33.718523: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2026-02-08 07:35:33.726756: I external/local_tsl/tsl/cuda/cudart_stub.cc:31] Could not find cuda drivers on your machine, GPU will not be used.
2026-02-08 07:35:33.727105: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 AVX512F AVX512_VNNI AVX512_BF16 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
2026-02-08 07:35:34.356722: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT
[rank: 0] Global seed set to 123
/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/kornia/feature/lightglue.py:44: FutureWarning: `torch.cuda.amp.custom_fwd(args...)` is deprecated. Please use `torch.amp.custom_fwd(args..., device_type='cuda')` instead.
@torch.cuda.amp.custom_fwd(cast_inputs=torch.float32)
INFO:mainlogger:LatentVisualDiffusion: Running in v-prediction mode
INFO:unifolm_wma.models.diffusion_head.conditional_unet1d:number of parameters: 5.010531e+08
INFO:unifolm_wma.models.diffusion_head.conditional_unet1d:number of parameters: 5.010531e+08
AE working on z of shape (1, 4, 32, 32) = 4096 dimensions.
INFO:root:Loaded ViT-H-14 model config.
DEBUG:urllib3.connectionpool:Starting new HTTPS connection (1): hf-mirror.com:443
DEBUG:urllib3.connectionpool:https://hf-mirror.com:443 "HEAD /laion/CLIP-ViT-H-14-laion2B-s32B-b79K/resolve/main/open_clip_pytorch_model.bin HTTP/1.1" 302 0
INFO:root:Loading pretrained ViT-H-14 weights (laion2b_s32b_b79k).
/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/open_clip/factory.py:88: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.
checkpoint = torch.load(checkpoint_path, map_location=map_location)
INFO:root:Loaded ViT-H-14 model config.
DEBUG:urllib3.connectionpool:https://hf-mirror.com:443 "HEAD /laion/CLIP-ViT-H-14-laion2B-s32B-b79K/resolve/main/open_clip_pytorch_model.bin HTTP/1.1" 302 0
INFO:root:Loading pretrained ViT-H-14 weights (laion2b_s32b_b79k).
/mnt/ASC1637/unifolm-world-model-action/scripts/evaluation/world_model_interaction.py:86: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.
state_dict = torch.load(ckpt, map_location="cpu")
>>> model checkpoint loaded.
>>> Load pre-trained model ...
INFO:root:***** Configing Data *****
>>> unitree_z1_stackbox: 1 data samples loaded.
>>> unitree_z1_stackbox: data stats loaded.
>>> unitree_z1_stackbox: normalizer initiated.
>>> unitree_z1_dual_arm_stackbox: 1 data samples loaded.
>>> unitree_z1_dual_arm_stackbox: data stats loaded.
>>> unitree_z1_dual_arm_stackbox: normalizer initiated.
>>> unitree_z1_dual_arm_stackbox_v2: 1 data samples loaded.
>>> unitree_z1_dual_arm_stackbox_v2: data stats loaded.
>>> unitree_z1_dual_arm_stackbox_v2: normalizer initiated.
>>> unitree_z1_dual_arm_cleanup_pencils: 1 data samples loaded.
>>> unitree_z1_dual_arm_cleanup_pencils: data stats loaded.
>>> unitree_z1_dual_arm_cleanup_pencils: normalizer initiated.
>>> unitree_g1_pack_camera: 1 data samples loaded.
>>> unitree_g1_pack_camera: data stats loaded.
>>> unitree_g1_pack_camera: normalizer initiated.
>>> Dataset is successfully loaded ...
>>> Generate 16 frames under each generation ...
DEBUG:h5py._conv:Creating converter from 3 to 5
DEBUG:PIL.PngImagePlugin:STREAM b'IHDR' 16 13
DEBUG:PIL.PngImagePlugin:STREAM b'pHYs' 41 9
DEBUG:PIL.PngImagePlugin:STREAM b'IDAT' 62 4096
0%| | 0/7 [00:00<?, ?it/s]/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/torch/nn/functional.py:5501: UserWarning: Attempting to use hipBLASLt on an unsupported architecture! Overriding blas backend to hipblas (Triggered internally at ../aten/src/ATen/Context.cpp:296.)
proj = linear(q, w, b)
/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/torch/nn/functional.py:6278: UserWarning: Flash attention support on Navi31 GPU is still experimental. Enable it with TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL=1. (Triggered internally at ../aten/src/ATen/native/transformers/hip/sdp_utils.cpp:225.)
attn_output = scaled_dot_product_attention(
/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/torch/nn/functional.py:6278: UserWarning: Memory Efficient attention on Navi31 GPU is still experimental. Enable it with TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL=1. (Triggered internally at ../aten/src/ATen/native/transformers/hip/sdp_utils.cpp:269.)
attn_output = scaled_dot_product_attention(
>>> Step 0: generating actions ...
>>> Step 0: interacting with world model ...
DEBUG:PIL.Image:Importing BlpImagePlugin
DEBUG:PIL.Image:Importing BmpImagePlugin
DEBUG:PIL.Image:Importing BufrStubImagePlugin
DEBUG:PIL.Image:Importing CurImagePlugin
DEBUG:PIL.Image:Importing DcxImagePlugin
DEBUG:PIL.Image:Importing DdsImagePlugin
DEBUG:PIL.Image:Importing EpsImagePlugin
DEBUG:PIL.Image:Importing FitsImagePlugin
DEBUG:PIL.Image:Importing FitsStubImagePlugin
DEBUG:PIL.Image:Importing FliImagePlugin
DEBUG:PIL.Image:Importing FpxImagePlugin
DEBUG:PIL.Image:Image: failed to import FpxImagePlugin: No module named 'olefile'
DEBUG:PIL.Image:Importing FtexImagePlugin
DEBUG:PIL.Image:Importing GbrImagePlugin
DEBUG:PIL.Image:Importing GifImagePlugin
DEBUG:PIL.Image:Importing GribStubImagePlugin
DEBUG:PIL.Image:Importing Hdf5StubImagePlugin
DEBUG:PIL.Image:Importing IcnsImagePlugin
DEBUG:PIL.Image:Importing IcoImagePlugin
DEBUG:PIL.Image:Importing ImImagePlugin
DEBUG:PIL.Image:Importing ImtImagePlugin
DEBUG:PIL.Image:Importing IptcImagePlugin
DEBUG:PIL.Image:Importing JpegImagePlugin
DEBUG:PIL.Image:Importing Jpeg2KImagePlugin
DEBUG:PIL.Image:Importing McIdasImagePlugin
DEBUG:PIL.Image:Importing MicImagePlugin
DEBUG:PIL.Image:Image: failed to import MicImagePlugin: No module named 'olefile'
DEBUG:PIL.Image:Importing MpegImagePlugin
DEBUG:PIL.Image:Importing MpoImagePlugin
DEBUG:PIL.Image:Importing MspImagePlugin
DEBUG:PIL.Image:Importing PalmImagePlugin
DEBUG:PIL.Image:Importing PcdImagePlugin
DEBUG:PIL.Image:Importing PcxImagePlugin
DEBUG:PIL.Image:Importing PdfImagePlugin
DEBUG:PIL.Image:Importing PixarImagePlugin
DEBUG:PIL.Image:Importing PngImagePlugin
DEBUG:PIL.Image:Importing PpmImagePlugin
DEBUG:PIL.Image:Importing PsdImagePlugin
DEBUG:PIL.Image:Importing QoiImagePlugin
DEBUG:PIL.Image:Importing SgiImagePlugin
DEBUG:PIL.Image:Importing SpiderImagePlugin
DEBUG:PIL.Image:Importing SunImagePlugin
DEBUG:PIL.Image:Importing TgaImagePlugin
DEBUG:PIL.Image:Importing TiffImagePlugin
DEBUG:PIL.Image:Importing WebPImagePlugin
DEBUG:PIL.Image:Importing WmfImagePlugin
DEBUG:PIL.Image:Importing XbmImagePlugin
DEBUG:PIL.Image:Importing XpmImagePlugin
DEBUG:PIL.Image:Importing XVThumbImagePlugin
14%|█▍ | 1/7 [01:41<10:06, 101.02s/it]
29%|██▊ | 2/7 [03:23<08:29, 101.84s/it]
43%|████▎ | 3/7 [05:04<06:45, 101.43s/it]
57%|█████▋ | 4/7 [06:45<05:04, 101.42s/it]
71%|███████▏ | 5/7 [08:27<03:22, 101.40s/it]
86%|████████▌ | 6/7 [10:08<01:41, 101.39s/it]
100%|██████████| 7/7 [11:49<00:00, 101.33s/it]
100%|██████████| 7/7 [11:49<00:00, 101.39s/it]
>>>>>>>>>>>>>>>>>>>>>>>>
>>> Step 1: generating actions ...
>>> Step 1: interacting with world model ...
>>>>>>>>>>>>>>>>>>>>>>>>
>>> Step 2: generating actions ...
>>> Step 2: interacting with world model ...
>>>>>>>>>>>>>>>>>>>>>>>>
>>> Step 3: generating actions ...
>>> Step 3: interacting with world model ...
>>>>>>>>>>>>>>>>>>>>>>>>
>>> Step 4: generating actions ...
>>> Step 4: interacting with world model ...
>>>>>>>>>>>>>>>>>>>>>>>>
>>> Step 5: generating actions ...

View File

@@ -0,0 +1,5 @@
{
"gt_video": "unitree_z1_dual_arm_stackbox/case3/unitree_z1_dual_arm_stackbox_case3.mp4",
"pred_video": "unitree_z1_dual_arm_stackbox/case3/output/inference/unitree_z1_dual_arm_stackbox_case3_amd.mp4",
"psnr": 25.400458754751128
}

View File

@@ -0,0 +1,134 @@
/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/lightning_fabric/__init__.py:29: UserWarning: pkg_resources is deprecated as an API. See https://setuptools.pypa.io/en/latest/pkg_resources.html. The pkg_resources package is slated for removal as early as 2025-11-30. Refrain from using this package or pin to Setuptools<81.
__import__("pkg_resources").declare_namespace(__name__)
2026-02-08 07:38:45.572744: I tensorflow/core/util/port.cc:113] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2026-02-08 07:38:45.576864: I external/local_tsl/tsl/cuda/cudart_stub.cc:31] Could not find cuda drivers on your machine, GPU will not be used.
2026-02-08 07:38:45.624825: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2026-02-08 07:38:45.624883: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2026-02-08 07:38:45.627150: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2026-02-08 07:38:45.638316: I external/local_tsl/tsl/cuda/cudart_stub.cc:31] Could not find cuda drivers on your machine, GPU will not be used.
2026-02-08 07:38:45.638803: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 AVX512F AVX512_VNNI AVX512_BF16 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
2026-02-08 07:38:46.426363: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT
[rank: 0] Global seed set to 123
/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/kornia/feature/lightglue.py:44: FutureWarning: `torch.cuda.amp.custom_fwd(args...)` is deprecated. Please use `torch.amp.custom_fwd(args..., device_type='cuda')` instead.
@torch.cuda.amp.custom_fwd(cast_inputs=torch.float32)
INFO:mainlogger:LatentVisualDiffusion: Running in v-prediction mode
INFO:unifolm_wma.models.diffusion_head.conditional_unet1d:number of parameters: 5.010531e+08
INFO:unifolm_wma.models.diffusion_head.conditional_unet1d:number of parameters: 5.010531e+08
AE working on z of shape (1, 4, 32, 32) = 4096 dimensions.
INFO:root:Loaded ViT-H-14 model config.
DEBUG:urllib3.connectionpool:Starting new HTTPS connection (1): hf-mirror.com:443
DEBUG:urllib3.connectionpool:https://hf-mirror.com:443 "HEAD /laion/CLIP-ViT-H-14-laion2B-s32B-b79K/resolve/main/open_clip_pytorch_model.bin HTTP/1.1" 302 0
INFO:root:Loading pretrained ViT-H-14 weights (laion2b_s32b_b79k).
/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/open_clip/factory.py:88: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.
checkpoint = torch.load(checkpoint_path, map_location=map_location)
INFO:root:Loaded ViT-H-14 model config.
DEBUG:urllib3.connectionpool:https://hf-mirror.com:443 "HEAD /laion/CLIP-ViT-H-14-laion2B-s32B-b79K/resolve/main/open_clip_pytorch_model.bin HTTP/1.1" 302 0
INFO:root:Loading pretrained ViT-H-14 weights (laion2b_s32b_b79k).
/mnt/ASC1637/unifolm-world-model-action/scripts/evaluation/world_model_interaction.py:86: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.
state_dict = torch.load(ckpt, map_location="cpu")
>>> model checkpoint loaded.
>>> Load pre-trained model ...
INFO:root:***** Configing Data *****
>>> unitree_z1_stackbox: 1 data samples loaded.
>>> unitree_z1_stackbox: data stats loaded.
>>> unitree_z1_stackbox: normalizer initiated.
>>> unitree_z1_dual_arm_stackbox: 1 data samples loaded.
>>> unitree_z1_dual_arm_stackbox: data stats loaded.
>>> unitree_z1_dual_arm_stackbox: normalizer initiated.
>>> unitree_z1_dual_arm_stackbox_v2: 1 data samples loaded.
>>> unitree_z1_dual_arm_stackbox_v2: data stats loaded.
>>> unitree_z1_dual_arm_stackbox_v2: normalizer initiated.
>>> unitree_z1_dual_arm_cleanup_pencils: 1 data samples loaded.
>>> unitree_z1_dual_arm_cleanup_pencils: data stats loaded.
>>> unitree_z1_dual_arm_cleanup_pencils: normalizer initiated.
>>> unitree_g1_pack_camera: 1 data samples loaded.
>>> unitree_g1_pack_camera: data stats loaded.
>>> unitree_g1_pack_camera: normalizer initiated.
>>> Dataset is successfully loaded ...
>>> Generate 16 frames under each generation ...
DEBUG:h5py._conv:Creating converter from 3 to 5
DEBUG:PIL.PngImagePlugin:STREAM b'IHDR' 16 13
DEBUG:PIL.PngImagePlugin:STREAM b'pHYs' 41 9
DEBUG:PIL.PngImagePlugin:STREAM b'IDAT' 62 4096
0%| | 0/7 [00:00<?, ?it/s]/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/torch/nn/functional.py:5501: UserWarning: Attempting to use hipBLASLt on an unsupported architecture! Overriding blas backend to hipblas (Triggered internally at ../aten/src/ATen/Context.cpp:296.)
proj = linear(q, w, b)
/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/torch/nn/functional.py:6278: UserWarning: Flash attention support on Navi31 GPU is still experimental. Enable it with TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL=1. (Triggered internally at ../aten/src/ATen/native/transformers/hip/sdp_utils.cpp:225.)
attn_output = scaled_dot_product_attention(
/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/torch/nn/functional.py:6278: UserWarning: Memory Efficient attention on Navi31 GPU is still experimental. Enable it with TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL=1. (Triggered internally at ../aten/src/ATen/native/transformers/hip/sdp_utils.cpp:269.)
attn_output = scaled_dot_product_attention(
>>> Step 0: generating actions ...
>>> Step 0: interacting with world model ...
DEBUG:PIL.Image:Importing BlpImagePlugin
DEBUG:PIL.Image:Importing BmpImagePlugin
DEBUG:PIL.Image:Importing BufrStubImagePlugin
DEBUG:PIL.Image:Importing CurImagePlugin
DEBUG:PIL.Image:Importing DcxImagePlugin
DEBUG:PIL.Image:Importing DdsImagePlugin
DEBUG:PIL.Image:Importing EpsImagePlugin
DEBUG:PIL.Image:Importing FitsImagePlugin
DEBUG:PIL.Image:Importing FitsStubImagePlugin
DEBUG:PIL.Image:Importing FliImagePlugin
DEBUG:PIL.Image:Importing FpxImagePlugin
DEBUG:PIL.Image:Image: failed to import FpxImagePlugin: No module named 'olefile'
DEBUG:PIL.Image:Importing FtexImagePlugin
DEBUG:PIL.Image:Importing GbrImagePlugin
DEBUG:PIL.Image:Importing GifImagePlugin
DEBUG:PIL.Image:Importing GribStubImagePlugin
DEBUG:PIL.Image:Importing Hdf5StubImagePlugin
DEBUG:PIL.Image:Importing IcnsImagePlugin
DEBUG:PIL.Image:Importing IcoImagePlugin
DEBUG:PIL.Image:Importing ImImagePlugin
DEBUG:PIL.Image:Importing ImtImagePlugin
DEBUG:PIL.Image:Importing IptcImagePlugin
DEBUG:PIL.Image:Importing JpegImagePlugin
DEBUG:PIL.Image:Importing Jpeg2KImagePlugin
DEBUG:PIL.Image:Importing McIdasImagePlugin
DEBUG:PIL.Image:Importing MicImagePlugin
DEBUG:PIL.Image:Image: failed to import MicImagePlugin: No module named 'olefile'
DEBUG:PIL.Image:Importing MpegImagePlugin
DEBUG:PIL.Image:Importing MpoImagePlugin
DEBUG:PIL.Image:Importing MspImagePlugin
DEBUG:PIL.Image:Importing PalmImagePlugin
DEBUG:PIL.Image:Importing PcdImagePlugin
DEBUG:PIL.Image:Importing PcxImagePlugin
DEBUG:PIL.Image:Importing PdfImagePlugin
DEBUG:PIL.Image:Importing PixarImagePlugin
DEBUG:PIL.Image:Importing PngImagePlugin
DEBUG:PIL.Image:Importing PpmImagePlugin
DEBUG:PIL.Image:Importing PsdImagePlugin
DEBUG:PIL.Image:Importing QoiImagePlugin
DEBUG:PIL.Image:Importing SgiImagePlugin
DEBUG:PIL.Image:Importing SpiderImagePlugin
DEBUG:PIL.Image:Importing SunImagePlugin
DEBUG:PIL.Image:Importing TgaImagePlugin
DEBUG:PIL.Image:Importing TiffImagePlugin
DEBUG:PIL.Image:Importing WebPImagePlugin
DEBUG:PIL.Image:Importing WmfImagePlugin
DEBUG:PIL.Image:Importing XbmImagePlugin
DEBUG:PIL.Image:Importing XpmImagePlugin
DEBUG:PIL.Image:Importing XVThumbImagePlugin
14%|█▍ | 1/7 [01:38<09:52, 98.73s/it]
29%|██▊ | 2/7 [03:17<08:14, 98.85s/it]
43%|████▎ | 3/7 [04:56<06:35, 98.80s/it]
57%|█████▋ | 4/7 [06:35<04:56, 98.94s/it]
71%|███████▏ | 5/7 [08:14<03:17, 98.93s/it]
86%|████████▌ | 6/7 [09:53<01:38, 98.89s/it]
100%|██████████| 7/7 [11:31<00:00, 98.81s/it]
100%|██████████| 7/7 [11:31<00:00, 98.85s/it]
>>>>>>>>>>>>>>>>>>>>>>>>
>>> Step 1: generating actions ...
>>> Step 1: interacting with world model ...
>>>>>>>>>>>>>>>>>>>>>>>>
>>> Step 2: generating actions ...
>>> Step 2: interacting with world model ...
>>>>>>>>>>>>>>>>>>>>>>>>
>>> Step 3: generating actions ...
>>> Step 3: interacting with world model ...
>>>>>>>>>>>>>>>>>>>>>>>>
>>> Step 4: generating actions ...
>>> Step 4: interacting with world model ...
>>>>>>>>>>>>>>>>>>>>>>>>
>>> Step 5: generating actions ...

View File

@@ -0,0 +1,5 @@
{
"gt_video": "unitree_z1_dual_arm_stackbox/case4/unitree_z1_dual_arm_stackbox_case4.mp4",
"pred_video": "unitree_z1_dual_arm_stackbox/case4/output/inference/unitree_z1_dual_arm_stackbox_case4_amd.mp4",
"psnr": 24.098958457373858
}

View File

@@ -0,0 +1,146 @@
/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/lightning_fabric/__init__.py:29: UserWarning: pkg_resources is deprecated as an API. See https://setuptools.pypa.io/en/latest/pkg_resources.html. The pkg_resources package is slated for removal as early as 2025-11-30. Refrain from using this package or pin to Setuptools<81.
__import__("pkg_resources").declare_namespace(__name__)
2026-02-08 07:51:23.961486: I tensorflow/core/util/port.cc:113] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2026-02-08 07:51:24.200063: I external/local_tsl/tsl/cuda/cudart_stub.cc:31] Could not find cuda drivers on your machine, GPU will not be used.
2026-02-08 07:51:24.522299: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2026-02-08 07:51:24.522350: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2026-02-08 07:51:24.528237: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2026-02-08 07:51:24.579400: I external/local_tsl/tsl/cuda/cudart_stub.cc:31] Could not find cuda drivers on your machine, GPU will not be used.
2026-02-08 07:51:24.579644: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 AVX512F AVX512_VNNI AVX512_BF16 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
2026-02-08 07:51:25.781311: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT
Global seed set to 123
/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/kornia/feature/lightglue.py:44: FutureWarning: `torch.cuda.amp.custom_fwd(args...)` is deprecated. Please use `torch.amp.custom_fwd(args..., device_type='cuda')` instead.
@torch.cuda.amp.custom_fwd(cast_inputs=torch.float32)
INFO:mainlogger:LatentVisualDiffusion: Running in v-prediction mode
INFO:unifolm_wma.models.diffusion_head.conditional_unet1d:number of parameters: 5.010531e+08
INFO:unifolm_wma.models.diffusion_head.conditional_unet1d:number of parameters: 5.010531e+08
AE working on z of shape (1, 4, 32, 32) = 4096 dimensions.
INFO:root:Loaded ViT-H-14 model config.
DEBUG:urllib3.connectionpool:Starting new HTTPS connection (1): hf-mirror.com:443
DEBUG:urllib3.connectionpool:https://hf-mirror.com:443 "HEAD /laion/CLIP-ViT-H-14-laion2B-s32B-b79K/resolve/main/open_clip_pytorch_model.bin HTTP/1.1" 302 0
INFO:root:Loading pretrained ViT-H-14 weights (laion2b_s32b_b79k).
/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/open_clip/factory.py:88: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.
checkpoint = torch.load(checkpoint_path, map_location=map_location)
INFO:root:Loaded ViT-H-14 model config.
DEBUG:urllib3.connectionpool:https://hf-mirror.com:443 "HEAD /laion/CLIP-ViT-H-14-laion2B-s32B-b79K/resolve/main/open_clip_pytorch_model.bin HTTP/1.1" 302 0
INFO:root:Loading pretrained ViT-H-14 weights (laion2b_s32b_b79k).
/mnt/ASC1637/unifolm-world-model-action/scripts/evaluation/world_model_interaction.py:86: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.
state_dict = torch.load(ckpt, map_location="cpu")
>>> model checkpoint loaded.
>>> Load pre-trained model ...
INFO:root:***** Configing Data *****
>>> unitree_z1_stackbox: 1 data samples loaded.
>>> unitree_z1_stackbox: data stats loaded.
>>> unitree_z1_stackbox: normalizer initiated.
>>> unitree_z1_dual_arm_stackbox: 1 data samples loaded.
>>> unitree_z1_dual_arm_stackbox: data stats loaded.
>>> unitree_z1_dual_arm_stackbox: normalizer initiated.
>>> unitree_z1_dual_arm_stackbox_v2: 1 data samples loaded.
>>> unitree_z1_dual_arm_stackbox_v2: data stats loaded.
>>> unitree_z1_dual_arm_stackbox_v2: normalizer initiated.
>>> unitree_z1_dual_arm_cleanup_pencils: 1 data samples loaded.
>>> unitree_z1_dual_arm_cleanup_pencils: data stats loaded.
>>> unitree_z1_dual_arm_cleanup_pencils: normalizer initiated.
>>> unitree_g1_pack_camera: 1 data samples loaded.
>>> unitree_g1_pack_camera: data stats loaded.
>>> unitree_g1_pack_camera: normalizer initiated.
>>> Dataset is successfully loaded ...
>>> Generate 16 frames under each generation ...
DEBUG:h5py._conv:Creating converter from 3 to 5
DEBUG:PIL.PngImagePlugin:STREAM b'IHDR' 16 13
DEBUG:PIL.PngImagePlugin:STREAM b'pHYs' 41 9
DEBUG:PIL.PngImagePlugin:STREAM b'IDAT' 62 4096
0%| | 0/11 [00:00<?, ?it/s]/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/torch/nn/functional.py:5501: UserWarning: Attempting to use hipBLASLt on an unsupported architecture! Overriding blas backend to hipblas (Triggered internally at ../aten/src/ATen/Context.cpp:296.)
proj = linear(q, w, b)
/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/torch/nn/functional.py:6278: UserWarning: Flash attention support on Navi31 GPU is still experimental. Enable it with TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL=1. (Triggered internally at ../aten/src/ATen/native/transformers/hip/sdp_utils.cpp:225.)
attn_output = scaled_dot_product_attention(
/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/torch/nn/functional.py:6278: UserWarning: Memory Efficient attention on Navi31 GPU is still experimental. Enable it with TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL=1. (Triggered internally at ../aten/src/ATen/native/transformers/hip/sdp_utils.cpp:269.)
attn_output = scaled_dot_product_attention(
>>> Step 0: generating actions ...
>>> Step 0: interacting with world model ...
DEBUG:PIL.Image:Importing BlpImagePlugin
DEBUG:PIL.Image:Importing BmpImagePlugin
DEBUG:PIL.Image:Importing BufrStubImagePlugin
DEBUG:PIL.Image:Importing CurImagePlugin
DEBUG:PIL.Image:Importing DcxImagePlugin
DEBUG:PIL.Image:Importing DdsImagePlugin
DEBUG:PIL.Image:Importing EpsImagePlugin
DEBUG:PIL.Image:Importing FitsImagePlugin
DEBUG:PIL.Image:Importing FitsStubImagePlugin
DEBUG:PIL.Image:Importing FliImagePlugin
DEBUG:PIL.Image:Importing FpxImagePlugin
DEBUG:PIL.Image:Image: failed to import FpxImagePlugin: No module named 'olefile'
DEBUG:PIL.Image:Importing FtexImagePlugin
DEBUG:PIL.Image:Importing GbrImagePlugin
DEBUG:PIL.Image:Importing GifImagePlugin
DEBUG:PIL.Image:Importing GribStubImagePlugin
DEBUG:PIL.Image:Importing Hdf5StubImagePlugin
DEBUG:PIL.Image:Importing IcnsImagePlugin
DEBUG:PIL.Image:Importing IcoImagePlugin
DEBUG:PIL.Image:Importing ImImagePlugin
DEBUG:PIL.Image:Importing ImtImagePlugin
DEBUG:PIL.Image:Importing IptcImagePlugin
DEBUG:PIL.Image:Importing JpegImagePlugin
DEBUG:PIL.Image:Importing Jpeg2KImagePlugin
DEBUG:PIL.Image:Importing McIdasImagePlugin
DEBUG:PIL.Image:Importing MicImagePlugin
DEBUG:PIL.Image:Image: failed to import MicImagePlugin: No module named 'olefile'
DEBUG:PIL.Image:Importing MpegImagePlugin
DEBUG:PIL.Image:Importing MpoImagePlugin
DEBUG:PIL.Image:Importing MspImagePlugin
DEBUG:PIL.Image:Importing PalmImagePlugin
DEBUG:PIL.Image:Importing PcdImagePlugin
DEBUG:PIL.Image:Importing PcxImagePlugin
DEBUG:PIL.Image:Importing PdfImagePlugin
DEBUG:PIL.Image:Importing PixarImagePlugin
DEBUG:PIL.Image:Importing PngImagePlugin
DEBUG:PIL.Image:Importing PpmImagePlugin
DEBUG:PIL.Image:Importing PsdImagePlugin
DEBUG:PIL.Image:Importing QoiImagePlugin
DEBUG:PIL.Image:Importing SgiImagePlugin
DEBUG:PIL.Image:Importing SpiderImagePlugin
DEBUG:PIL.Image:Importing SunImagePlugin
DEBUG:PIL.Image:Importing TgaImagePlugin
DEBUG:PIL.Image:Importing TiffImagePlugin
DEBUG:PIL.Image:Importing WebPImagePlugin
DEBUG:PIL.Image:Importing WmfImagePlugin
DEBUG:PIL.Image:Importing XbmImagePlugin
DEBUG:PIL.Image:Importing XpmImagePlugin
DEBUG:PIL.Image:Importing XVThumbImagePlugin
9%|▉ | 1/11 [01:38<16:20, 98.04s/it]
18%|█▊ | 2/11 [03:15<14:40, 97.81s/it]
27%|██▋ | 3/11 [04:53<13:01, 97.72s/it]
36%|███▋ | 4/11 [06:31<11:24, 97.71s/it]
45%|████▌ | 5/11 [08:08<09:46, 97.71s/it]
55%|█████▍ | 6/11 [09:46<08:08, 97.65s/it]
64%|██████▎ | 7/11 [11:23<06:30, 97.65s/it]
73%|███████▎ | 8/11 [13:02<04:54, 98.09s/it]
82%|████████▏ | 9/11 [14:40<03:15, 97.83s/it]
91%|█████████ | 10/11 [16:17<01:37, 97.73s/it]
100%|██████████| 11/11 [17:55<00:00, 97.64s/it]
100%|██████████| 11/11 [17:55<00:00, 97.74s/it]
>>>>>>>>>>>>>>>>>>>>>>>>
>>> Step 1: generating actions ...
>>> Step 1: interacting with world model ...
>>>>>>>>>>>>>>>>>>>>>>>>
>>> Step 2: generating actions ...
>>> Step 2: interacting with world model ...
>>>>>>>>>>>>>>>>>>>>>>>>
>>> Step 3: generating actions ...
>>> Step 3: interacting with world model ...
>>>>>>>>>>>>>>>>>>>>>>>>
>>> Step 4: generating actions ...
>>> Step 4: interacting with world model ...
>>>>>>>>>>>>>>>>>>>>>>>>
>>> Step 5: generating actions ...
>>> Step 5: interacting with world model ...
>>>>>>>>>>>>>>>>>>>>>>>>
>>> Step 6: generating actions ...
>>> Step 6: interacting with world model ...
>>>>>>>>>>>>>>>>>>>>>>>>
>>> Step 7: generating actions ...
>>> Step 7: interacting with world model ...
>>>>>>>>>>>>>>>>>>>>>>>>

View File

@@ -0,0 +1,5 @@
{
"gt_video": "unitree_z1_dual_arm_stackbox_v2/case1/unitree_z1_dual_arm_stackbox_v2_case1.mp4",
"pred_video": "unitree_z1_dual_arm_stackbox_v2/case1/output/inference/unitree_z1_dual_arm_stackbox_v2_case1_amd.mp4",
"psnr": 18.126776535969576
}

View File

@@ -2,7 +2,7 @@ res_dir="unitree_z1_dual_arm_stackbox_v2/case1"
dataset="unitree_z1_dual_arm_stackbox_v2"
{
time CUDA_VISIBLE_DEVICES=0 python3 scripts/evaluation/world_model_interaction.py \
time CUDA_VISIBLE_DEVICES=7 python3 scripts/evaluation/world_model_interaction.py \
--seed 123 \
--ckpt_path ckpts/unifolm_wma_dual.ckpt \
--config configs/inference/world_model_interaction.yaml \

View File

@@ -0,0 +1,146 @@
/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/lightning_fabric/__init__.py:29: UserWarning: pkg_resources is deprecated as an API. See https://setuptools.pypa.io/en/latest/pkg_resources.html. The pkg_resources package is slated for removal as early as 2025-11-30. Refrain from using this package or pin to Setuptools<81.
__import__("pkg_resources").declare_namespace(__name__)
2026-02-08 07:56:31.144789: I tensorflow/core/util/port.cc:113] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2026-02-08 07:56:31.148256: I external/local_tsl/tsl/cuda/cudart_stub.cc:31] Could not find cuda drivers on your machine, GPU will not be used.
2026-02-08 07:56:31.178870: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2026-02-08 07:56:31.178898: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2026-02-08 07:56:31.180683: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2026-02-08 07:56:31.188800: I external/local_tsl/tsl/cuda/cudart_stub.cc:31] Could not find cuda drivers on your machine, GPU will not be used.
2026-02-08 07:56:31.189142: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 AVX512F AVX512_VNNI AVX512_BF16 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
2026-02-08 07:56:31.810098: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT
[rank: 0] Global seed set to 123
/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/kornia/feature/lightglue.py:44: FutureWarning: `torch.cuda.amp.custom_fwd(args...)` is deprecated. Please use `torch.amp.custom_fwd(args..., device_type='cuda')` instead.
@torch.cuda.amp.custom_fwd(cast_inputs=torch.float32)
INFO:mainlogger:LatentVisualDiffusion: Running in v-prediction mode
INFO:unifolm_wma.models.diffusion_head.conditional_unet1d:number of parameters: 5.010531e+08
INFO:unifolm_wma.models.diffusion_head.conditional_unet1d:number of parameters: 5.010531e+08
AE working on z of shape (1, 4, 32, 32) = 4096 dimensions.
INFO:root:Loaded ViT-H-14 model config.
DEBUG:urllib3.connectionpool:Starting new HTTPS connection (1): hf-mirror.com:443
DEBUG:urllib3.connectionpool:https://hf-mirror.com:443 "HEAD /laion/CLIP-ViT-H-14-laion2B-s32B-b79K/resolve/main/open_clip_pytorch_model.bin HTTP/1.1" 302 0
INFO:root:Loading pretrained ViT-H-14 weights (laion2b_s32b_b79k).
/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/open_clip/factory.py:88: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.
checkpoint = torch.load(checkpoint_path, map_location=map_location)
INFO:root:Loaded ViT-H-14 model config.
DEBUG:urllib3.connectionpool:https://hf-mirror.com:443 "HEAD /laion/CLIP-ViT-H-14-laion2B-s32B-b79K/resolve/main/open_clip_pytorch_model.bin HTTP/1.1" 302 0
INFO:root:Loading pretrained ViT-H-14 weights (laion2b_s32b_b79k).
/mnt/ASC1637/unifolm-world-model-action/scripts/evaluation/world_model_interaction.py:86: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.
state_dict = torch.load(ckpt, map_location="cpu")
>>> model checkpoint loaded.
>>> Load pre-trained model ...
INFO:root:***** Configing Data *****
>>> unitree_z1_stackbox: 1 data samples loaded.
>>> unitree_z1_stackbox: data stats loaded.
>>> unitree_z1_stackbox: normalizer initiated.
>>> unitree_z1_dual_arm_stackbox: 1 data samples loaded.
>>> unitree_z1_dual_arm_stackbox: data stats loaded.
>>> unitree_z1_dual_arm_stackbox: normalizer initiated.
>>> unitree_z1_dual_arm_stackbox_v2: 1 data samples loaded.
>>> unitree_z1_dual_arm_stackbox_v2: data stats loaded.
>>> unitree_z1_dual_arm_stackbox_v2: normalizer initiated.
>>> unitree_z1_dual_arm_cleanup_pencils: 1 data samples loaded.
>>> unitree_z1_dual_arm_cleanup_pencils: data stats loaded.
>>> unitree_z1_dual_arm_cleanup_pencils: normalizer initiated.
>>> unitree_g1_pack_camera: 1 data samples loaded.
>>> unitree_g1_pack_camera: data stats loaded.
>>> unitree_g1_pack_camera: normalizer initiated.
>>> Dataset is successfully loaded ...
>>> Generate 16 frames under each generation ...
DEBUG:h5py._conv:Creating converter from 3 to 5
DEBUG:PIL.PngImagePlugin:STREAM b'IHDR' 16 13
DEBUG:PIL.PngImagePlugin:STREAM b'pHYs' 41 9
DEBUG:PIL.PngImagePlugin:STREAM b'IDAT' 62 4096
0%| | 0/11 [00:00<?, ?it/s]/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/torch/nn/functional.py:5501: UserWarning: Attempting to use hipBLASLt on an unsupported architecture! Overriding blas backend to hipblas (Triggered internally at ../aten/src/ATen/Context.cpp:296.)
proj = linear(q, w, b)
/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/torch/nn/functional.py:6278: UserWarning: Flash attention support on Navi31 GPU is still experimental. Enable it with TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL=1. (Triggered internally at ../aten/src/ATen/native/transformers/hip/sdp_utils.cpp:225.)
attn_output = scaled_dot_product_attention(
/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/torch/nn/functional.py:6278: UserWarning: Memory Efficient attention on Navi31 GPU is still experimental. Enable it with TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL=1. (Triggered internally at ../aten/src/ATen/native/transformers/hip/sdp_utils.cpp:269.)
attn_output = scaled_dot_product_attention(
>>> Step 0: generating actions ...
>>> Step 0: interacting with world model ...
DEBUG:PIL.Image:Importing BlpImagePlugin
DEBUG:PIL.Image:Importing BmpImagePlugin
DEBUG:PIL.Image:Importing BufrStubImagePlugin
DEBUG:PIL.Image:Importing CurImagePlugin
DEBUG:PIL.Image:Importing DcxImagePlugin
DEBUG:PIL.Image:Importing DdsImagePlugin
DEBUG:PIL.Image:Importing EpsImagePlugin
DEBUG:PIL.Image:Importing FitsImagePlugin
DEBUG:PIL.Image:Importing FitsStubImagePlugin
DEBUG:PIL.Image:Importing FliImagePlugin
DEBUG:PIL.Image:Importing FpxImagePlugin
DEBUG:PIL.Image:Image: failed to import FpxImagePlugin: No module named 'olefile'
DEBUG:PIL.Image:Importing FtexImagePlugin
DEBUG:PIL.Image:Importing GbrImagePlugin
DEBUG:PIL.Image:Importing GifImagePlugin
DEBUG:PIL.Image:Importing GribStubImagePlugin
DEBUG:PIL.Image:Importing Hdf5StubImagePlugin
DEBUG:PIL.Image:Importing IcnsImagePlugin
DEBUG:PIL.Image:Importing IcoImagePlugin
DEBUG:PIL.Image:Importing ImImagePlugin
DEBUG:PIL.Image:Importing ImtImagePlugin
DEBUG:PIL.Image:Importing IptcImagePlugin
DEBUG:PIL.Image:Importing JpegImagePlugin
DEBUG:PIL.Image:Importing Jpeg2KImagePlugin
DEBUG:PIL.Image:Importing McIdasImagePlugin
DEBUG:PIL.Image:Importing MicImagePlugin
DEBUG:PIL.Image:Image: failed to import MicImagePlugin: No module named 'olefile'
DEBUG:PIL.Image:Importing MpegImagePlugin
DEBUG:PIL.Image:Importing MpoImagePlugin
DEBUG:PIL.Image:Importing MspImagePlugin
DEBUG:PIL.Image:Importing PalmImagePlugin
DEBUG:PIL.Image:Importing PcdImagePlugin
DEBUG:PIL.Image:Importing PcxImagePlugin
DEBUG:PIL.Image:Importing PdfImagePlugin
DEBUG:PIL.Image:Importing PixarImagePlugin
DEBUG:PIL.Image:Importing PngImagePlugin
DEBUG:PIL.Image:Importing PpmImagePlugin
DEBUG:PIL.Image:Importing PsdImagePlugin
DEBUG:PIL.Image:Importing QoiImagePlugin
DEBUG:PIL.Image:Importing SgiImagePlugin
DEBUG:PIL.Image:Importing SpiderImagePlugin
DEBUG:PIL.Image:Importing SunImagePlugin
DEBUG:PIL.Image:Importing TgaImagePlugin
DEBUG:PIL.Image:Importing TiffImagePlugin
DEBUG:PIL.Image:Importing WebPImagePlugin
DEBUG:PIL.Image:Importing WmfImagePlugin
DEBUG:PIL.Image:Importing XbmImagePlugin
DEBUG:PIL.Image:Importing XpmImagePlugin
DEBUG:PIL.Image:Importing XVThumbImagePlugin
9%|▉ | 1/11 [01:40<16:41, 100.16s/it]
18%|█▊ | 2/11 [03:20<15:04, 100.47s/it]
27%|██▋ | 3/11 [05:01<13:24, 100.62s/it]
36%|███▋ | 4/11 [06:42<11:44, 100.69s/it]
45%|████▌ | 5/11 [08:22<10:02, 100.48s/it]
55%|█████▍ | 6/11 [10:02<08:21, 100.33s/it]
64%|██████▎ | 7/11 [11:42<06:40, 100.23s/it]
73%|███████▎ | 8/11 [13:22<05:00, 100.23s/it]
82%|████████▏ | 9/11 [15:03<03:20, 100.23s/it]
91%|█████████ | 10/11 [16:43<01:40, 100.33s/it]
100%|██████████| 11/11 [18:24<00:00, 100.41s/it]
100%|██████████| 11/11 [18:24<00:00, 100.39s/it]
>>>>>>>>>>>>>>>>>>>>>>>>
>>> Step 1: generating actions ...
>>> Step 1: interacting with world model ...
>>>>>>>>>>>>>>>>>>>>>>>>
>>> Step 2: generating actions ...
>>> Step 2: interacting with world model ...
>>>>>>>>>>>>>>>>>>>>>>>>
>>> Step 3: generating actions ...
>>> Step 3: interacting with world model ...
>>>>>>>>>>>>>>>>>>>>>>>>
>>> Step 4: generating actions ...
>>> Step 4: interacting with world model ...
>>>>>>>>>>>>>>>>>>>>>>>>
>>> Step 5: generating actions ...
>>> Step 5: interacting with world model ...
>>>>>>>>>>>>>>>>>>>>>>>>
>>> Step 6: generating actions ...
>>> Step 6: interacting with world model ...
>>>>>>>>>>>>>>>>>>>>>>>>
>>> Step 7: generating actions ...
>>> Step 7: interacting with world model ...
>>>>>>>>>>>>>>>>>>>>>>>>

View File

@@ -0,0 +1,5 @@
{
"gt_video": "unitree_z1_dual_arm_stackbox_v2/case2/unitree_z1_dual_arm_stackbox_v2_case2.mp4",
"pred_video": "unitree_z1_dual_arm_stackbox_v2/case2/output/inference/unitree_z1_dual_arm_stackbox_v2_case2_amd.mp4",
"psnr": 19.38130614773096
}

View File

@@ -0,0 +1,146 @@
/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/lightning_fabric/__init__.py:29: UserWarning: pkg_resources is deprecated as an API. See https://setuptools.pypa.io/en/latest/pkg_resources.html. The pkg_resources package is slated for removal as early as 2025-11-30. Refrain from using this package or pin to Setuptools<81.
__import__("pkg_resources").declare_namespace(__name__)
2026-02-08 07:56:04.467082: I tensorflow/core/util/port.cc:113] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2026-02-08 07:56:04.470145: I external/local_tsl/tsl/cuda/cudart_stub.cc:31] Could not find cuda drivers on your machine, GPU will not be used.
2026-02-08 07:56:04.502248: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2026-02-08 07:56:04.502277: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2026-02-08 07:56:04.504088: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2026-02-08 07:56:04.512557: I external/local_tsl/tsl/cuda/cudart_stub.cc:31] Could not find cuda drivers on your machine, GPU will not be used.
2026-02-08 07:56:04.512830: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 AVX512F AVX512_VNNI AVX512_BF16 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
2026-02-08 07:56:05.259641: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT
[rank: 0] Global seed set to 123
/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/kornia/feature/lightglue.py:44: FutureWarning: `torch.cuda.amp.custom_fwd(args...)` is deprecated. Please use `torch.amp.custom_fwd(args..., device_type='cuda')` instead.
@torch.cuda.amp.custom_fwd(cast_inputs=torch.float32)
INFO:mainlogger:LatentVisualDiffusion: Running in v-prediction mode
INFO:unifolm_wma.models.diffusion_head.conditional_unet1d:number of parameters: 5.010531e+08
INFO:unifolm_wma.models.diffusion_head.conditional_unet1d:number of parameters: 5.010531e+08
AE working on z of shape (1, 4, 32, 32) = 4096 dimensions.
INFO:root:Loaded ViT-H-14 model config.
DEBUG:urllib3.connectionpool:Starting new HTTPS connection (1): hf-mirror.com:443
DEBUG:urllib3.connectionpool:https://hf-mirror.com:443 "HEAD /laion/CLIP-ViT-H-14-laion2B-s32B-b79K/resolve/main/open_clip_pytorch_model.bin HTTP/1.1" 302 0
INFO:root:Loading pretrained ViT-H-14 weights (laion2b_s32b_b79k).
/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/open_clip/factory.py:88: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.
checkpoint = torch.load(checkpoint_path, map_location=map_location)
INFO:root:Loaded ViT-H-14 model config.
DEBUG:urllib3.connectionpool:https://hf-mirror.com:443 "HEAD /laion/CLIP-ViT-H-14-laion2B-s32B-b79K/resolve/main/open_clip_pytorch_model.bin HTTP/1.1" 302 0
INFO:root:Loading pretrained ViT-H-14 weights (laion2b_s32b_b79k).
/mnt/ASC1637/unifolm-world-model-action/scripts/evaluation/world_model_interaction.py:86: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.
state_dict = torch.load(ckpt, map_location="cpu")
>>> model checkpoint loaded.
>>> Load pre-trained model ...
INFO:root:***** Configing Data *****
>>> unitree_z1_stackbox: 1 data samples loaded.
>>> unitree_z1_stackbox: data stats loaded.
>>> unitree_z1_stackbox: normalizer initiated.
>>> unitree_z1_dual_arm_stackbox: 1 data samples loaded.
>>> unitree_z1_dual_arm_stackbox: data stats loaded.
>>> unitree_z1_dual_arm_stackbox: normalizer initiated.
>>> unitree_z1_dual_arm_stackbox_v2: 1 data samples loaded.
>>> unitree_z1_dual_arm_stackbox_v2: data stats loaded.
>>> unitree_z1_dual_arm_stackbox_v2: normalizer initiated.
>>> unitree_z1_dual_arm_cleanup_pencils: 1 data samples loaded.
>>> unitree_z1_dual_arm_cleanup_pencils: data stats loaded.
>>> unitree_z1_dual_arm_cleanup_pencils: normalizer initiated.
>>> unitree_g1_pack_camera: 1 data samples loaded.
>>> unitree_g1_pack_camera: data stats loaded.
>>> unitree_g1_pack_camera: normalizer initiated.
>>> Dataset is successfully loaded ...
>>> Generate 16 frames under each generation ...
DEBUG:h5py._conv:Creating converter from 3 to 5
DEBUG:PIL.PngImagePlugin:STREAM b'IHDR' 16 13
DEBUG:PIL.PngImagePlugin:STREAM b'pHYs' 41 9
DEBUG:PIL.PngImagePlugin:STREAM b'IDAT' 62 4096
0%| | 0/11 [00:00<?, ?it/s]/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/torch/nn/functional.py:5501: UserWarning: Attempting to use hipBLASLt on an unsupported architecture! Overriding blas backend to hipblas (Triggered internally at ../aten/src/ATen/Context.cpp:296.)
proj = linear(q, w, b)
/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/torch/nn/functional.py:6278: UserWarning: Flash attention support on Navi31 GPU is still experimental. Enable it with TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL=1. (Triggered internally at ../aten/src/ATen/native/transformers/hip/sdp_utils.cpp:225.)
attn_output = scaled_dot_product_attention(
/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/torch/nn/functional.py:6278: UserWarning: Memory Efficient attention on Navi31 GPU is still experimental. Enable it with TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL=1. (Triggered internally at ../aten/src/ATen/native/transformers/hip/sdp_utils.cpp:269.)
attn_output = scaled_dot_product_attention(
>>> Step 0: generating actions ...
>>> Step 0: interacting with world model ...
DEBUG:PIL.Image:Importing BlpImagePlugin
DEBUG:PIL.Image:Importing BmpImagePlugin
DEBUG:PIL.Image:Importing BufrStubImagePlugin
DEBUG:PIL.Image:Importing CurImagePlugin
DEBUG:PIL.Image:Importing DcxImagePlugin
DEBUG:PIL.Image:Importing DdsImagePlugin
DEBUG:PIL.Image:Importing EpsImagePlugin
DEBUG:PIL.Image:Importing FitsImagePlugin
DEBUG:PIL.Image:Importing FitsStubImagePlugin
DEBUG:PIL.Image:Importing FliImagePlugin
DEBUG:PIL.Image:Importing FpxImagePlugin
DEBUG:PIL.Image:Image: failed to import FpxImagePlugin: No module named 'olefile'
DEBUG:PIL.Image:Importing FtexImagePlugin
DEBUG:PIL.Image:Importing GbrImagePlugin
DEBUG:PIL.Image:Importing GifImagePlugin
DEBUG:PIL.Image:Importing GribStubImagePlugin
DEBUG:PIL.Image:Importing Hdf5StubImagePlugin
DEBUG:PIL.Image:Importing IcnsImagePlugin
DEBUG:PIL.Image:Importing IcoImagePlugin
DEBUG:PIL.Image:Importing ImImagePlugin
DEBUG:PIL.Image:Importing ImtImagePlugin
DEBUG:PIL.Image:Importing IptcImagePlugin
DEBUG:PIL.Image:Importing JpegImagePlugin
DEBUG:PIL.Image:Importing Jpeg2KImagePlugin
DEBUG:PIL.Image:Importing McIdasImagePlugin
DEBUG:PIL.Image:Importing MicImagePlugin
DEBUG:PIL.Image:Image: failed to import MicImagePlugin: No module named 'olefile'
DEBUG:PIL.Image:Importing MpegImagePlugin
DEBUG:PIL.Image:Importing MpoImagePlugin
DEBUG:PIL.Image:Importing MspImagePlugin
DEBUG:PIL.Image:Importing PalmImagePlugin
DEBUG:PIL.Image:Importing PcdImagePlugin
DEBUG:PIL.Image:Importing PcxImagePlugin
DEBUG:PIL.Image:Importing PdfImagePlugin
DEBUG:PIL.Image:Importing PixarImagePlugin
DEBUG:PIL.Image:Importing PngImagePlugin
DEBUG:PIL.Image:Importing PpmImagePlugin
DEBUG:PIL.Image:Importing PsdImagePlugin
DEBUG:PIL.Image:Importing QoiImagePlugin
DEBUG:PIL.Image:Importing SgiImagePlugin
DEBUG:PIL.Image:Importing SpiderImagePlugin
DEBUG:PIL.Image:Importing SunImagePlugin
DEBUG:PIL.Image:Importing TgaImagePlugin
DEBUG:PIL.Image:Importing TiffImagePlugin
DEBUG:PIL.Image:Importing WebPImagePlugin
DEBUG:PIL.Image:Importing WmfImagePlugin
DEBUG:PIL.Image:Importing XbmImagePlugin
DEBUG:PIL.Image:Importing XpmImagePlugin
DEBUG:PIL.Image:Importing XVThumbImagePlugin
9%|▉ | 1/11 [01:38<16:20, 98.03s/it]
18%|█▊ | 2/11 [03:16<14:43, 98.19s/it]
27%|██▋ | 3/11 [04:55<13:08, 98.54s/it]
36%|███▋ | 4/11 [06:33<11:29, 98.52s/it]
45%|████▌ | 5/11 [08:11<09:50, 98.38s/it]
55%|█████▍ | 6/11 [09:49<08:10, 98.11s/it]
64%|██████▎ | 7/11 [11:27<06:31, 97.97s/it]
73%|███████▎ | 8/11 [13:04<04:53, 97.83s/it]
82%|████████▏ | 9/11 [14:42<03:15, 97.72s/it]
91%|█████████ | 10/11 [16:19<01:37, 97.71s/it]
100%|██████████| 11/11 [17:57<00:00, 97.74s/it]
100%|██████████| 11/11 [17:57<00:00, 97.97s/it]
>>>>>>>>>>>>>>>>>>>>>>>>
>>> Step 1: generating actions ...
>>> Step 1: interacting with world model ...
>>>>>>>>>>>>>>>>>>>>>>>>
>>> Step 2: generating actions ...
>>> Step 2: interacting with world model ...
>>>>>>>>>>>>>>>>>>>>>>>>
>>> Step 3: generating actions ...
>>> Step 3: interacting with world model ...
>>>>>>>>>>>>>>>>>>>>>>>>
>>> Step 4: generating actions ...
>>> Step 4: interacting with world model ...
>>>>>>>>>>>>>>>>>>>>>>>>
>>> Step 5: generating actions ...
>>> Step 5: interacting with world model ...
>>>>>>>>>>>>>>>>>>>>>>>>
>>> Step 6: generating actions ...
>>> Step 6: interacting with world model ...
>>>>>>>>>>>>>>>>>>>>>>>>
>>> Step 7: generating actions ...
>>> Step 7: interacting with world model ...
>>>>>>>>>>>>>>>>>>>>>>>>

View File

@@ -0,0 +1,5 @@
{
"gt_video": "unitree_z1_dual_arm_stackbox_v2/case3/unitree_z1_dual_arm_stackbox_v2_case3.mp4",
"pred_video": "unitree_z1_dual_arm_stackbox_v2/case3/output/inference/unitree_z1_dual_arm_stackbox_v2_case3_amd.mp4",
"psnr": 18.74462122425683
}

View File

@@ -0,0 +1,146 @@
/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/lightning_fabric/__init__.py:29: UserWarning: pkg_resources is deprecated as an API. See https://setuptools.pypa.io/en/latest/pkg_resources.html. The pkg_resources package is slated for removal as early as 2025-11-30. Refrain from using this package or pin to Setuptools<81.
__import__("pkg_resources").declare_namespace(__name__)
2026-02-08 08:04:16.104516: I tensorflow/core/util/port.cc:113] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2026-02-08 08:04:16.109112: I external/local_tsl/tsl/cuda/cudart_stub.cc:31] Could not find cuda drivers on your machine, GPU will not be used.
2026-02-08 08:04:16.138703: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2026-02-08 08:04:16.138737: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2026-02-08 08:04:16.140302: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2026-02-08 08:04:16.147672: I external/local_tsl/tsl/cuda/cudart_stub.cc:31] Could not find cuda drivers on your machine, GPU will not be used.
2026-02-08 08:04:16.147903: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 AVX512F AVX512_VNNI AVX512_BF16 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
2026-02-08 08:04:17.363218: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT
Global seed set to 123
/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/kornia/feature/lightglue.py:44: FutureWarning: `torch.cuda.amp.custom_fwd(args...)` is deprecated. Please use `torch.amp.custom_fwd(args..., device_type='cuda')` instead.
@torch.cuda.amp.custom_fwd(cast_inputs=torch.float32)
INFO:mainlogger:LatentVisualDiffusion: Running in v-prediction mode
INFO:unifolm_wma.models.diffusion_head.conditional_unet1d:number of parameters: 5.010531e+08
INFO:unifolm_wma.models.diffusion_head.conditional_unet1d:number of parameters: 5.010531e+08
AE working on z of shape (1, 4, 32, 32) = 4096 dimensions.
INFO:root:Loaded ViT-H-14 model config.
DEBUG:urllib3.connectionpool:Starting new HTTPS connection (1): hf-mirror.com:443
DEBUG:urllib3.connectionpool:https://hf-mirror.com:443 "HEAD /laion/CLIP-ViT-H-14-laion2B-s32B-b79K/resolve/main/open_clip_pytorch_model.bin HTTP/1.1" 302 0
INFO:root:Loading pretrained ViT-H-14 weights (laion2b_s32b_b79k).
/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/open_clip/factory.py:88: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.
checkpoint = torch.load(checkpoint_path, map_location=map_location)
INFO:root:Loaded ViT-H-14 model config.
DEBUG:urllib3.connectionpool:https://hf-mirror.com:443 "HEAD /laion/CLIP-ViT-H-14-laion2B-s32B-b79K/resolve/main/open_clip_pytorch_model.bin HTTP/1.1" 302 0
INFO:root:Loading pretrained ViT-H-14 weights (laion2b_s32b_b79k).
/mnt/ASC1637/unifolm-world-model-action/scripts/evaluation/world_model_interaction.py:86: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.
state_dict = torch.load(ckpt, map_location="cpu")
>>> model checkpoint loaded.
>>> Load pre-trained model ...
INFO:root:***** Configing Data *****
>>> unitree_z1_stackbox: 1 data samples loaded.
>>> unitree_z1_stackbox: data stats loaded.
>>> unitree_z1_stackbox: normalizer initiated.
>>> unitree_z1_dual_arm_stackbox: 1 data samples loaded.
>>> unitree_z1_dual_arm_stackbox: data stats loaded.
>>> unitree_z1_dual_arm_stackbox: normalizer initiated.
>>> unitree_z1_dual_arm_stackbox_v2: 1 data samples loaded.
>>> unitree_z1_dual_arm_stackbox_v2: data stats loaded.
>>> unitree_z1_dual_arm_stackbox_v2: normalizer initiated.
>>> unitree_z1_dual_arm_cleanup_pencils: 1 data samples loaded.
>>> unitree_z1_dual_arm_cleanup_pencils: data stats loaded.
>>> unitree_z1_dual_arm_cleanup_pencils: normalizer initiated.
>>> unitree_g1_pack_camera: 1 data samples loaded.
>>> unitree_g1_pack_camera: data stats loaded.
>>> unitree_g1_pack_camera: normalizer initiated.
>>> Dataset is successfully loaded ...
>>> Generate 16 frames under each generation ...
DEBUG:h5py._conv:Creating converter from 3 to 5
DEBUG:PIL.PngImagePlugin:STREAM b'IHDR' 16 13
DEBUG:PIL.PngImagePlugin:STREAM b'pHYs' 41 9
DEBUG:PIL.PngImagePlugin:STREAM b'IDAT' 62 4096
0%| | 0/11 [00:00<?, ?it/s]/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/torch/nn/functional.py:5501: UserWarning: Attempting to use hipBLASLt on an unsupported architecture! Overriding blas backend to hipblas (Triggered internally at ../aten/src/ATen/Context.cpp:296.)
proj = linear(q, w, b)
/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/torch/nn/functional.py:6278: UserWarning: Flash attention support on Navi31 GPU is still experimental. Enable it with TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL=1. (Triggered internally at ../aten/src/ATen/native/transformers/hip/sdp_utils.cpp:225.)
attn_output = scaled_dot_product_attention(
/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/torch/nn/functional.py:6278: UserWarning: Memory Efficient attention on Navi31 GPU is still experimental. Enable it with TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL=1. (Triggered internally at ../aten/src/ATen/native/transformers/hip/sdp_utils.cpp:269.)
attn_output = scaled_dot_product_attention(
>>> Step 0: generating actions ...
>>> Step 0: interacting with world model ...
DEBUG:PIL.Image:Importing BlpImagePlugin
DEBUG:PIL.Image:Importing BmpImagePlugin
DEBUG:PIL.Image:Importing BufrStubImagePlugin
DEBUG:PIL.Image:Importing CurImagePlugin
DEBUG:PIL.Image:Importing DcxImagePlugin
DEBUG:PIL.Image:Importing DdsImagePlugin
DEBUG:PIL.Image:Importing EpsImagePlugin
DEBUG:PIL.Image:Importing FitsImagePlugin
DEBUG:PIL.Image:Importing FitsStubImagePlugin
DEBUG:PIL.Image:Importing FliImagePlugin
DEBUG:PIL.Image:Importing FpxImagePlugin
DEBUG:PIL.Image:Image: failed to import FpxImagePlugin: No module named 'olefile'
DEBUG:PIL.Image:Importing FtexImagePlugin
DEBUG:PIL.Image:Importing GbrImagePlugin
DEBUG:PIL.Image:Importing GifImagePlugin
DEBUG:PIL.Image:Importing GribStubImagePlugin
DEBUG:PIL.Image:Importing Hdf5StubImagePlugin
DEBUG:PIL.Image:Importing IcnsImagePlugin
DEBUG:PIL.Image:Importing IcoImagePlugin
DEBUG:PIL.Image:Importing ImImagePlugin
DEBUG:PIL.Image:Importing ImtImagePlugin
DEBUG:PIL.Image:Importing IptcImagePlugin
DEBUG:PIL.Image:Importing JpegImagePlugin
DEBUG:PIL.Image:Importing Jpeg2KImagePlugin
DEBUG:PIL.Image:Importing McIdasImagePlugin
DEBUG:PIL.Image:Importing MicImagePlugin
DEBUG:PIL.Image:Image: failed to import MicImagePlugin: No module named 'olefile'
DEBUG:PIL.Image:Importing MpegImagePlugin
DEBUG:PIL.Image:Importing MpoImagePlugin
DEBUG:PIL.Image:Importing MspImagePlugin
DEBUG:PIL.Image:Importing PalmImagePlugin
DEBUG:PIL.Image:Importing PcdImagePlugin
DEBUG:PIL.Image:Importing PcxImagePlugin
DEBUG:PIL.Image:Importing PdfImagePlugin
DEBUG:PIL.Image:Importing PixarImagePlugin
DEBUG:PIL.Image:Importing PngImagePlugin
DEBUG:PIL.Image:Importing PpmImagePlugin
DEBUG:PIL.Image:Importing PsdImagePlugin
DEBUG:PIL.Image:Importing QoiImagePlugin
DEBUG:PIL.Image:Importing SgiImagePlugin
DEBUG:PIL.Image:Importing SpiderImagePlugin
DEBUG:PIL.Image:Importing SunImagePlugin
DEBUG:PIL.Image:Importing TgaImagePlugin
DEBUG:PIL.Image:Importing TiffImagePlugin
DEBUG:PIL.Image:Importing WebPImagePlugin
DEBUG:PIL.Image:Importing WmfImagePlugin
DEBUG:PIL.Image:Importing XbmImagePlugin
DEBUG:PIL.Image:Importing XpmImagePlugin
DEBUG:PIL.Image:Importing XVThumbImagePlugin
9%|▉ | 1/11 [01:39<16:32, 99.26s/it]
18%|█▊ | 2/11 [03:17<14:49, 98.81s/it]
27%|██▋ | 3/11 [04:56<13:10, 98.76s/it]
36%|███▋ | 4/11 [06:35<11:31, 98.80s/it]
45%|████▌ | 5/11 [08:14<09:53, 98.85s/it]
55%|█████▍ | 6/11 [09:53<08:14, 98.87s/it]
64%|██████▎ | 7/11 [11:31<06:34, 98.68s/it]
73%|███████▎ | 8/11 [13:09<04:55, 98.49s/it]
82%|████████▏ | 9/11 [14:47<03:16, 98.38s/it]
91%|█████████ | 10/11 [16:25<01:38, 98.29s/it]
100%|██████████| 11/11 [18:03<00:00, 98.26s/it]
100%|██████████| 11/11 [18:03<00:00, 98.54s/it]
>>>>>>>>>>>>>>>>>>>>>>>>
>>> Step 1: generating actions ...
>>> Step 1: interacting with world model ...
>>>>>>>>>>>>>>>>>>>>>>>>
>>> Step 2: generating actions ...
>>> Step 2: interacting with world model ...
>>>>>>>>>>>>>>>>>>>>>>>>
>>> Step 3: generating actions ...
>>> Step 3: interacting with world model ...
>>>>>>>>>>>>>>>>>>>>>>>>
>>> Step 4: generating actions ...
>>> Step 4: interacting with world model ...
>>>>>>>>>>>>>>>>>>>>>>>>
>>> Step 5: generating actions ...
>>> Step 5: interacting with world model ...
>>>>>>>>>>>>>>>>>>>>>>>>
>>> Step 6: generating actions ...
>>> Step 6: interacting with world model ...
>>>>>>>>>>>>>>>>>>>>>>>>
>>> Step 7: generating actions ...
>>> Step 7: interacting with world model ...
>>>>>>>>>>>>>>>>>>>>>>>>

View File

@@ -0,0 +1,5 @@
{
"gt_video": "unitree_z1_dual_arm_stackbox_v2/case4/unitree_z1_dual_arm_stackbox_v2_case4.mp4",
"pred_video": "unitree_z1_dual_arm_stackbox_v2/case4/output/inference/unitree_z1_dual_arm_stackbox_v2_case4_amd.mp4",
"psnr": 19.526448380726254
}

View File

@@ -2,7 +2,7 @@ res_dir="unitree_z1_dual_arm_stackbox_v2/case4"
dataset="unitree_z1_dual_arm_stackbox_v2"
{
time CUDA_VISIBLE_DEVICES=0 python3 scripts/evaluation/world_model_interaction.py \
time CUDA_VISIBLE_DEVICES=6 python3 scripts/evaluation/world_model_interaction.py \
--seed 123 \
--ckpt_path ckpts/unifolm_wma_dual.ckpt \
--config configs/inference/world_model_interaction.yaml \

View File

@@ -0,0 +1,149 @@
/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/lightning_fabric/__init__.py:29: UserWarning: pkg_resources is deprecated as an API. See https://setuptools.pypa.io/en/latest/pkg_resources.html. The pkg_resources package is slated for removal as early as 2025-11-30. Refrain from using this package or pin to Setuptools<81.
__import__("pkg_resources").declare_namespace(__name__)
2026-02-08 08:12:47.424053: I tensorflow/core/util/port.cc:113] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2026-02-08 08:12:47.427280: I external/local_tsl/tsl/cuda/cudart_stub.cc:31] Could not find cuda drivers on your machine, GPU will not be used.
2026-02-08 08:12:47.458253: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2026-02-08 08:12:47.458288: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2026-02-08 08:12:47.462758: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2026-02-08 08:12:47.518283: I external/local_tsl/tsl/cuda/cudart_stub.cc:31] Could not find cuda drivers on your machine, GPU will not be used.
2026-02-08 08:12:47.518566: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 AVX512F AVX512_VNNI AVX512_BF16 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
2026-02-08 08:12:48.593011: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT
Global seed set to 123
/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/kornia/feature/lightglue.py:44: FutureWarning: `torch.cuda.amp.custom_fwd(args...)` is deprecated. Please use `torch.amp.custom_fwd(args..., device_type='cuda')` instead.
@torch.cuda.amp.custom_fwd(cast_inputs=torch.float32)
INFO:mainlogger:LatentVisualDiffusion: Running in v-prediction mode
INFO:unifolm_wma.models.diffusion_head.conditional_unet1d:number of parameters: 5.010531e+08
INFO:unifolm_wma.models.diffusion_head.conditional_unet1d:number of parameters: 5.010531e+08
AE working on z of shape (1, 4, 32, 32) = 4096 dimensions.
INFO:root:Loaded ViT-H-14 model config.
DEBUG:urllib3.connectionpool:Starting new HTTPS connection (1): hf-mirror.com:443
DEBUG:urllib3.connectionpool:https://hf-mirror.com:443 "HEAD /laion/CLIP-ViT-H-14-laion2B-s32B-b79K/resolve/main/open_clip_pytorch_model.bin HTTP/1.1" 302 0
INFO:root:Loading pretrained ViT-H-14 weights (laion2b_s32b_b79k).
/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/open_clip/factory.py:88: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.
checkpoint = torch.load(checkpoint_path, map_location=map_location)
INFO:root:Loaded ViT-H-14 model config.
DEBUG:urllib3.connectionpool:https://hf-mirror.com:443 "HEAD /laion/CLIP-ViT-H-14-laion2B-s32B-b79K/resolve/main/open_clip_pytorch_model.bin HTTP/1.1" 302 0
INFO:root:Loading pretrained ViT-H-14 weights (laion2b_s32b_b79k).
/mnt/ASC1637/unifolm-world-model-action/scripts/evaluation/world_model_interaction.py:86: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.
state_dict = torch.load(ckpt, map_location="cpu")
>>> model checkpoint loaded.
>>> Load pre-trained model ...
INFO:root:***** Configing Data *****
>>> unitree_z1_stackbox: 1 data samples loaded.
>>> unitree_z1_stackbox: data stats loaded.
>>> unitree_z1_stackbox: normalizer initiated.
>>> unitree_z1_dual_arm_stackbox: 1 data samples loaded.
>>> unitree_z1_dual_arm_stackbox: data stats loaded.
>>> unitree_z1_dual_arm_stackbox: normalizer initiated.
>>> unitree_z1_dual_arm_stackbox_v2: 1 data samples loaded.
>>> unitree_z1_dual_arm_stackbox_v2: data stats loaded.
>>> unitree_z1_dual_arm_stackbox_v2: normalizer initiated.
>>> unitree_z1_dual_arm_cleanup_pencils: 1 data samples loaded.
>>> unitree_z1_dual_arm_cleanup_pencils: data stats loaded.
>>> unitree_z1_dual_arm_cleanup_pencils: normalizer initiated.
>>> unitree_g1_pack_camera: 1 data samples loaded.
>>> unitree_g1_pack_camera: data stats loaded.
>>> unitree_g1_pack_camera: normalizer initiated.
>>> Dataset is successfully loaded ...
>>> Generate 16 frames under each generation ...
DEBUG:h5py._conv:Creating converter from 3 to 5
DEBUG:PIL.PngImagePlugin:STREAM b'IHDR' 16 13
DEBUG:PIL.PngImagePlugin:STREAM b'pHYs' 41 9
DEBUG:PIL.PngImagePlugin:STREAM b'IDAT' 62 4096
0%| | 0/12 [00:00<?, ?it/s]/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/torch/nn/functional.py:5501: UserWarning: Attempting to use hipBLASLt on an unsupported architecture! Overriding blas backend to hipblas (Triggered internally at ../aten/src/ATen/Context.cpp:296.)
proj = linear(q, w, b)
/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/torch/nn/functional.py:6278: UserWarning: Flash attention support on Navi31 GPU is still experimental. Enable it with TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL=1. (Triggered internally at ../aten/src/ATen/native/transformers/hip/sdp_utils.cpp:225.)
attn_output = scaled_dot_product_attention(
/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/torch/nn/functional.py:6278: UserWarning: Memory Efficient attention on Navi31 GPU is still experimental. Enable it with TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL=1. (Triggered internally at ../aten/src/ATen/native/transformers/hip/sdp_utils.cpp:269.)
attn_output = scaled_dot_product_attention(
>>> Step 0: generating actions ...
>>> Step 0: interacting with world model ...
DEBUG:PIL.Image:Importing BlpImagePlugin
DEBUG:PIL.Image:Importing BmpImagePlugin
DEBUG:PIL.Image:Importing BufrStubImagePlugin
DEBUG:PIL.Image:Importing CurImagePlugin
DEBUG:PIL.Image:Importing DcxImagePlugin
DEBUG:PIL.Image:Importing DdsImagePlugin
DEBUG:PIL.Image:Importing EpsImagePlugin
DEBUG:PIL.Image:Importing FitsImagePlugin
DEBUG:PIL.Image:Importing FitsStubImagePlugin
DEBUG:PIL.Image:Importing FliImagePlugin
DEBUG:PIL.Image:Importing FpxImagePlugin
DEBUG:PIL.Image:Image: failed to import FpxImagePlugin: No module named 'olefile'
DEBUG:PIL.Image:Importing FtexImagePlugin
DEBUG:PIL.Image:Importing GbrImagePlugin
DEBUG:PIL.Image:Importing GifImagePlugin
DEBUG:PIL.Image:Importing GribStubImagePlugin
DEBUG:PIL.Image:Importing Hdf5StubImagePlugin
DEBUG:PIL.Image:Importing IcnsImagePlugin
DEBUG:PIL.Image:Importing IcoImagePlugin
DEBUG:PIL.Image:Importing ImImagePlugin
DEBUG:PIL.Image:Importing ImtImagePlugin
DEBUG:PIL.Image:Importing IptcImagePlugin
DEBUG:PIL.Image:Importing JpegImagePlugin
DEBUG:PIL.Image:Importing Jpeg2KImagePlugin
DEBUG:PIL.Image:Importing McIdasImagePlugin
DEBUG:PIL.Image:Importing MicImagePlugin
DEBUG:PIL.Image:Image: failed to import MicImagePlugin: No module named 'olefile'
DEBUG:PIL.Image:Importing MpegImagePlugin
DEBUG:PIL.Image:Importing MpoImagePlugin
DEBUG:PIL.Image:Importing MspImagePlugin
DEBUG:PIL.Image:Importing PalmImagePlugin
DEBUG:PIL.Image:Importing PcdImagePlugin
DEBUG:PIL.Image:Importing PcxImagePlugin
DEBUG:PIL.Image:Importing PdfImagePlugin
DEBUG:PIL.Image:Importing PixarImagePlugin
DEBUG:PIL.Image:Importing PngImagePlugin
DEBUG:PIL.Image:Importing PpmImagePlugin
DEBUG:PIL.Image:Importing PsdImagePlugin
DEBUG:PIL.Image:Importing QoiImagePlugin
DEBUG:PIL.Image:Importing SgiImagePlugin
DEBUG:PIL.Image:Importing SpiderImagePlugin
DEBUG:PIL.Image:Importing SunImagePlugin
DEBUG:PIL.Image:Importing TgaImagePlugin
DEBUG:PIL.Image:Importing TiffImagePlugin
DEBUG:PIL.Image:Importing WebPImagePlugin
DEBUG:PIL.Image:Importing WmfImagePlugin
DEBUG:PIL.Image:Importing XbmImagePlugin
DEBUG:PIL.Image:Importing XpmImagePlugin
DEBUG:PIL.Image:Importing XVThumbImagePlugin
8%|▊ | 1/12 [01:38<18:08, 98.94s/it]
17%|█▋ | 2/12 [03:18<16:30, 99.01s/it]
25%|██▌ | 3/12 [04:57<14:51, 99.07s/it]
33%|███▎ | 4/12 [06:36<13:12, 99.04s/it]
42%|████▏ | 5/12 [08:15<11:33, 99.00s/it]
50%|█████ | 6/12 [09:54<09:54, 99.10s/it]
58%|█████▊ | 7/12 [11:33<08:14, 99.00s/it]
67%|██████▋ | 8/12 [13:13<06:38, 99.58s/it]
75%|███████▌ | 9/12 [14:54<04:59, 99.88s/it]
83%|████████▎ | 10/12 [16:33<03:19, 99.58s/it]
92%|█████████▏| 11/12 [18:12<01:39, 99.39s/it]
100%|██████████| 12/12 [19:51<00:00, 99.25s/it]
100%|██████████| 12/12 [19:51<00:00, 99.28s/it]
>>>>>>>>>>>>>>>>>>>>>>>>
>>> Step 1: generating actions ...
>>> Step 1: interacting with world model ...
>>>>>>>>>>>>>>>>>>>>>>>>
>>> Step 2: generating actions ...
>>> Step 2: interacting with world model ...
>>>>>>>>>>>>>>>>>>>>>>>>
>>> Step 3: generating actions ...
>>> Step 3: interacting with world model ...
>>>>>>>>>>>>>>>>>>>>>>>>
>>> Step 4: generating actions ...
>>> Step 4: interacting with world model ...
>>>>>>>>>>>>>>>>>>>>>>>>
>>> Step 5: generating actions ...
>>> Step 5: interacting with world model ...
>>>>>>>>>>>>>>>>>>>>>>>>
>>> Step 6: generating actions ...
>>> Step 6: interacting with world model ...
>>>>>>>>>>>>>>>>>>>>>>>>
>>> Step 7: generating actions ...
>>> Step 7: interacting with world model ...
>>>>>>>>>>>>>>>>>>>>>>>>
>>> Step 8: generating actions ...
>>> Step 8: interacting with world model ...

View File

@@ -0,0 +1,5 @@
{
"gt_video": "unitree_z1_stackbox/case1/unitree_z1_stackbox_case1.mp4",
"pred_video": "unitree_z1_stackbox/case1/output/inference/unitree_z1_stackbox_case1_amd.mp4",
"psnr": 19.81391789862606
}

View File

@@ -2,7 +2,7 @@ res_dir="unitree_z1_stackbox/case1"
dataset="unitree_z1_stackbox"
{
time CUDA_VISIBLE_DEVICES=0 python3 scripts/evaluation/world_model_interaction.py \
time CUDA_VISIBLE_DEVICES=5 python3 scripts/evaluation/world_model_interaction.py \
--seed 123 \
--ckpt_path ckpts/unifolm_wma_dual.ckpt \
--config configs/inference/world_model_interaction.yaml \

View File

@@ -0,0 +1,149 @@
/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/lightning_fabric/__init__.py:29: UserWarning: pkg_resources is deprecated as an API. See https://setuptools.pypa.io/en/latest/pkg_resources.html. The pkg_resources package is slated for removal as early as 2025-11-30. Refrain from using this package or pin to Setuptools<81.
__import__("pkg_resources").declare_namespace(__name__)
2026-02-08 08:15:49.934949: I tensorflow/core/util/port.cc:113] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2026-02-08 08:15:49.937974: I external/local_tsl/tsl/cuda/cudart_stub.cc:31] Could not find cuda drivers on your machine, GPU will not be used.
2026-02-08 08:15:49.969069: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2026-02-08 08:15:49.969100: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2026-02-08 08:15:49.970909: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2026-02-08 08:15:49.979005: I external/local_tsl/tsl/cuda/cudart_stub.cc:31] Could not find cuda drivers on your machine, GPU will not be used.
2026-02-08 08:15:49.979255: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 AVX512F AVX512_VNNI AVX512_BF16 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
2026-02-08 08:15:50.597743: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT
[rank: 0] Global seed set to 123
/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/kornia/feature/lightglue.py:44: FutureWarning: `torch.cuda.amp.custom_fwd(args...)` is deprecated. Please use `torch.amp.custom_fwd(args..., device_type='cuda')` instead.
@torch.cuda.amp.custom_fwd(cast_inputs=torch.float32)
INFO:mainlogger:LatentVisualDiffusion: Running in v-prediction mode
INFO:unifolm_wma.models.diffusion_head.conditional_unet1d:number of parameters: 5.010531e+08
INFO:unifolm_wma.models.diffusion_head.conditional_unet1d:number of parameters: 5.010531e+08
AE working on z of shape (1, 4, 32, 32) = 4096 dimensions.
INFO:root:Loaded ViT-H-14 model config.
DEBUG:urllib3.connectionpool:Starting new HTTPS connection (1): hf-mirror.com:443
DEBUG:urllib3.connectionpool:https://hf-mirror.com:443 "HEAD /laion/CLIP-ViT-H-14-laion2B-s32B-b79K/resolve/main/open_clip_pytorch_model.bin HTTP/1.1" 302 0
INFO:root:Loading pretrained ViT-H-14 weights (laion2b_s32b_b79k).
/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/open_clip/factory.py:88: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.
checkpoint = torch.load(checkpoint_path, map_location=map_location)
INFO:root:Loaded ViT-H-14 model config.
DEBUG:urllib3.connectionpool:https://hf-mirror.com:443 "HEAD /laion/CLIP-ViT-H-14-laion2B-s32B-b79K/resolve/main/open_clip_pytorch_model.bin HTTP/1.1" 302 0
INFO:root:Loading pretrained ViT-H-14 weights (laion2b_s32b_b79k).
/mnt/ASC1637/unifolm-world-model-action/scripts/evaluation/world_model_interaction.py:86: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.
state_dict = torch.load(ckpt, map_location="cpu")
>>> model checkpoint loaded.
>>> Load pre-trained model ...
INFO:root:***** Configing Data *****
>>> unitree_z1_stackbox: 1 data samples loaded.
>>> unitree_z1_stackbox: data stats loaded.
>>> unitree_z1_stackbox: normalizer initiated.
>>> unitree_z1_dual_arm_stackbox: 1 data samples loaded.
>>> unitree_z1_dual_arm_stackbox: data stats loaded.
>>> unitree_z1_dual_arm_stackbox: normalizer initiated.
>>> unitree_z1_dual_arm_stackbox_v2: 1 data samples loaded.
>>> unitree_z1_dual_arm_stackbox_v2: data stats loaded.
>>> unitree_z1_dual_arm_stackbox_v2: normalizer initiated.
>>> unitree_z1_dual_arm_cleanup_pencils: 1 data samples loaded.
>>> unitree_z1_dual_arm_cleanup_pencils: data stats loaded.
>>> unitree_z1_dual_arm_cleanup_pencils: normalizer initiated.
>>> unitree_g1_pack_camera: 1 data samples loaded.
>>> unitree_g1_pack_camera: data stats loaded.
>>> unitree_g1_pack_camera: normalizer initiated.
>>> Dataset is successfully loaded ...
>>> Generate 16 frames under each generation ...
DEBUG:h5py._conv:Creating converter from 3 to 5
DEBUG:PIL.PngImagePlugin:STREAM b'IHDR' 16 13
DEBUG:PIL.PngImagePlugin:STREAM b'pHYs' 41 9
DEBUG:PIL.PngImagePlugin:STREAM b'IDAT' 62 4096
0%| | 0/12 [00:00<?, ?it/s]/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/torch/nn/functional.py:5501: UserWarning: Attempting to use hipBLASLt on an unsupported architecture! Overriding blas backend to hipblas (Triggered internally at ../aten/src/ATen/Context.cpp:296.)
proj = linear(q, w, b)
/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/torch/nn/functional.py:6278: UserWarning: Flash attention support on Navi31 GPU is still experimental. Enable it with TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL=1. (Triggered internally at ../aten/src/ATen/native/transformers/hip/sdp_utils.cpp:225.)
attn_output = scaled_dot_product_attention(
/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/torch/nn/functional.py:6278: UserWarning: Memory Efficient attention on Navi31 GPU is still experimental. Enable it with TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL=1. (Triggered internally at ../aten/src/ATen/native/transformers/hip/sdp_utils.cpp:269.)
attn_output = scaled_dot_product_attention(
>>> Step 0: generating actions ...
>>> Step 0: interacting with world model ...
DEBUG:PIL.Image:Importing BlpImagePlugin
DEBUG:PIL.Image:Importing BmpImagePlugin
DEBUG:PIL.Image:Importing BufrStubImagePlugin
DEBUG:PIL.Image:Importing CurImagePlugin
DEBUG:PIL.Image:Importing DcxImagePlugin
DEBUG:PIL.Image:Importing DdsImagePlugin
DEBUG:PIL.Image:Importing EpsImagePlugin
DEBUG:PIL.Image:Importing FitsImagePlugin
DEBUG:PIL.Image:Importing FitsStubImagePlugin
DEBUG:PIL.Image:Importing FliImagePlugin
DEBUG:PIL.Image:Importing FpxImagePlugin
DEBUG:PIL.Image:Image: failed to import FpxImagePlugin: No module named 'olefile'
DEBUG:PIL.Image:Importing FtexImagePlugin
DEBUG:PIL.Image:Importing GbrImagePlugin
DEBUG:PIL.Image:Importing GifImagePlugin
DEBUG:PIL.Image:Importing GribStubImagePlugin
DEBUG:PIL.Image:Importing Hdf5StubImagePlugin
DEBUG:PIL.Image:Importing IcnsImagePlugin
DEBUG:PIL.Image:Importing IcoImagePlugin
DEBUG:PIL.Image:Importing ImImagePlugin
DEBUG:PIL.Image:Importing ImtImagePlugin
DEBUG:PIL.Image:Importing IptcImagePlugin
DEBUG:PIL.Image:Importing JpegImagePlugin
DEBUG:PIL.Image:Importing Jpeg2KImagePlugin
DEBUG:PIL.Image:Importing McIdasImagePlugin
DEBUG:PIL.Image:Importing MicImagePlugin
DEBUG:PIL.Image:Image: failed to import MicImagePlugin: No module named 'olefile'
DEBUG:PIL.Image:Importing MpegImagePlugin
DEBUG:PIL.Image:Importing MpoImagePlugin
DEBUG:PIL.Image:Importing MspImagePlugin
DEBUG:PIL.Image:Importing PalmImagePlugin
DEBUG:PIL.Image:Importing PcdImagePlugin
DEBUG:PIL.Image:Importing PcxImagePlugin
DEBUG:PIL.Image:Importing PdfImagePlugin
DEBUG:PIL.Image:Importing PixarImagePlugin
DEBUG:PIL.Image:Importing PngImagePlugin
DEBUG:PIL.Image:Importing PpmImagePlugin
DEBUG:PIL.Image:Importing PsdImagePlugin
DEBUG:PIL.Image:Importing QoiImagePlugin
DEBUG:PIL.Image:Importing SgiImagePlugin
DEBUG:PIL.Image:Importing SpiderImagePlugin
DEBUG:PIL.Image:Importing SunImagePlugin
DEBUG:PIL.Image:Importing TgaImagePlugin
DEBUG:PIL.Image:Importing TiffImagePlugin
DEBUG:PIL.Image:Importing WebPImagePlugin
DEBUG:PIL.Image:Importing WmfImagePlugin
DEBUG:PIL.Image:Importing XbmImagePlugin
DEBUG:PIL.Image:Importing XpmImagePlugin
DEBUG:PIL.Image:Importing XVThumbImagePlugin
8%|▊ | 1/12 [01:37<17:51, 97.37s/it]
17%|█▋ | 2/12 [03:14<16:13, 97.31s/it]
25%|██▌ | 3/12 [04:51<14:35, 97.26s/it]
33%|███▎ | 4/12 [06:29<12:58, 97.25s/it]
42%|████▏ | 5/12 [08:06<11:20, 97.24s/it]
50%|█████ | 6/12 [09:43<09:43, 97.24s/it]
58%|█████▊ | 7/12 [11:20<08:06, 97.27s/it]
67%|██████▋ | 8/12 [12:58<06:29, 97.36s/it]
75%|███████▌ | 9/12 [14:36<04:52, 97.49s/it]
83%|████████▎ | 10/12 [16:13<03:15, 97.52s/it]
92%|█████████▏| 11/12 [17:51<01:37, 97.47s/it]
100%|██████████| 12/12 [19:28<00:00, 97.35s/it]
100%|██████████| 12/12 [19:28<00:00, 97.35s/it]
>>>>>>>>>>>>>>>>>>>>>>>>
>>> Step 1: generating actions ...
>>> Step 1: interacting with world model ...
>>>>>>>>>>>>>>>>>>>>>>>>
>>> Step 2: generating actions ...
>>> Step 2: interacting with world model ...
>>>>>>>>>>>>>>>>>>>>>>>>
>>> Step 3: generating actions ...
>>> Step 3: interacting with world model ...
>>>>>>>>>>>>>>>>>>>>>>>>
>>> Step 4: generating actions ...
>>> Step 4: interacting with world model ...
>>>>>>>>>>>>>>>>>>>>>>>>
>>> Step 5: generating actions ...
>>> Step 5: interacting with world model ...
>>>>>>>>>>>>>>>>>>>>>>>>
>>> Step 6: generating actions ...
>>> Step 6: interacting with world model ...
>>>>>>>>>>>>>>>>>>>>>>>>
>>> Step 7: generating actions ...
>>> Step 7: interacting with world model ...
>>>>>>>>>>>>>>>>>>>>>>>>
>>> Step 8: generating actions ...
>>> Step 8: interacting with world model ...

View File

@@ -0,0 +1,5 @@
{
"gt_video": "unitree_z1_stackbox/case2/unitree_z1_stackbox_case2.mp4",
"pred_video": "unitree_z1_stackbox/case2/output/inference/unitree_z1_stackbox_case2_amd.mp4",
"psnr": 21.083821459054743
}

View File

@@ -0,0 +1,149 @@
/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/lightning_fabric/__init__.py:29: UserWarning: pkg_resources is deprecated as an API. See https://setuptools.pypa.io/en/latest/pkg_resources.html. The pkg_resources package is slated for removal as early as 2025-11-30. Refrain from using this package or pin to Setuptools<81.
__import__("pkg_resources").declare_namespace(__name__)
2026-02-08 08:16:22.299521: I tensorflow/core/util/port.cc:113] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2026-02-08 08:16:22.302545: I external/local_tsl/tsl/cuda/cudart_stub.cc:31] Could not find cuda drivers on your machine, GPU will not be used.
2026-02-08 08:16:22.335354: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2026-02-08 08:16:22.335389: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2026-02-08 08:16:22.337179: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2026-02-08 08:16:22.345296: I external/local_tsl/tsl/cuda/cudart_stub.cc:31] Could not find cuda drivers on your machine, GPU will not be used.
2026-02-08 08:16:22.345548: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 AVX512F AVX512_VNNI AVX512_BF16 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
2026-02-08 08:16:23.008743: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT
[rank: 0] Global seed set to 123
/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/kornia/feature/lightglue.py:44: FutureWarning: `torch.cuda.amp.custom_fwd(args...)` is deprecated. Please use `torch.amp.custom_fwd(args..., device_type='cuda')` instead.
@torch.cuda.amp.custom_fwd(cast_inputs=torch.float32)
INFO:mainlogger:LatentVisualDiffusion: Running in v-prediction mode
INFO:unifolm_wma.models.diffusion_head.conditional_unet1d:number of parameters: 5.010531e+08
INFO:unifolm_wma.models.diffusion_head.conditional_unet1d:number of parameters: 5.010531e+08
AE working on z of shape (1, 4, 32, 32) = 4096 dimensions.
INFO:root:Loaded ViT-H-14 model config.
DEBUG:urllib3.connectionpool:Starting new HTTPS connection (1): hf-mirror.com:443
DEBUG:urllib3.connectionpool:https://hf-mirror.com:443 "HEAD /laion/CLIP-ViT-H-14-laion2B-s32B-b79K/resolve/main/open_clip_pytorch_model.bin HTTP/1.1" 302 0
INFO:root:Loading pretrained ViT-H-14 weights (laion2b_s32b_b79k).
/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/open_clip/factory.py:88: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.
checkpoint = torch.load(checkpoint_path, map_location=map_location)
INFO:root:Loaded ViT-H-14 model config.
DEBUG:urllib3.connectionpool:https://hf-mirror.com:443 "HEAD /laion/CLIP-ViT-H-14-laion2B-s32B-b79K/resolve/main/open_clip_pytorch_model.bin HTTP/1.1" 302 0
INFO:root:Loading pretrained ViT-H-14 weights (laion2b_s32b_b79k).
/mnt/ASC1637/unifolm-world-model-action/scripts/evaluation/world_model_interaction.py:86: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.
state_dict = torch.load(ckpt, map_location="cpu")
>>> model checkpoint loaded.
>>> Load pre-trained model ...
INFO:root:***** Configing Data *****
>>> unitree_z1_stackbox: 1 data samples loaded.
>>> unitree_z1_stackbox: data stats loaded.
>>> unitree_z1_stackbox: normalizer initiated.
>>> unitree_z1_dual_arm_stackbox: 1 data samples loaded.
>>> unitree_z1_dual_arm_stackbox: data stats loaded.
>>> unitree_z1_dual_arm_stackbox: normalizer initiated.
>>> unitree_z1_dual_arm_stackbox_v2: 1 data samples loaded.
>>> unitree_z1_dual_arm_stackbox_v2: data stats loaded.
>>> unitree_z1_dual_arm_stackbox_v2: normalizer initiated.
>>> unitree_z1_dual_arm_cleanup_pencils: 1 data samples loaded.
>>> unitree_z1_dual_arm_cleanup_pencils: data stats loaded.
>>> unitree_z1_dual_arm_cleanup_pencils: normalizer initiated.
>>> unitree_g1_pack_camera: 1 data samples loaded.
>>> unitree_g1_pack_camera: data stats loaded.
>>> unitree_g1_pack_camera: normalizer initiated.
>>> Dataset is successfully loaded ...
>>> Generate 16 frames under each generation ...
DEBUG:h5py._conv:Creating converter from 3 to 5
DEBUG:PIL.PngImagePlugin:STREAM b'IHDR' 16 13
DEBUG:PIL.PngImagePlugin:STREAM b'pHYs' 41 9
DEBUG:PIL.PngImagePlugin:STREAM b'IDAT' 62 4096
0%| | 0/12 [00:00<?, ?it/s]/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/torch/nn/functional.py:5501: UserWarning: Attempting to use hipBLASLt on an unsupported architecture! Overriding blas backend to hipblas (Triggered internally at ../aten/src/ATen/Context.cpp:296.)
proj = linear(q, w, b)
/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/torch/nn/functional.py:6278: UserWarning: Flash attention support on Navi31 GPU is still experimental. Enable it with TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL=1. (Triggered internally at ../aten/src/ATen/native/transformers/hip/sdp_utils.cpp:225.)
attn_output = scaled_dot_product_attention(
/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/torch/nn/functional.py:6278: UserWarning: Memory Efficient attention on Navi31 GPU is still experimental. Enable it with TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL=1. (Triggered internally at ../aten/src/ATen/native/transformers/hip/sdp_utils.cpp:269.)
attn_output = scaled_dot_product_attention(
>>> Step 0: generating actions ...
>>> Step 0: interacting with world model ...
DEBUG:PIL.Image:Importing BlpImagePlugin
DEBUG:PIL.Image:Importing BmpImagePlugin
DEBUG:PIL.Image:Importing BufrStubImagePlugin
DEBUG:PIL.Image:Importing CurImagePlugin
DEBUG:PIL.Image:Importing DcxImagePlugin
DEBUG:PIL.Image:Importing DdsImagePlugin
DEBUG:PIL.Image:Importing EpsImagePlugin
DEBUG:PIL.Image:Importing FitsImagePlugin
DEBUG:PIL.Image:Importing FitsStubImagePlugin
DEBUG:PIL.Image:Importing FliImagePlugin
DEBUG:PIL.Image:Importing FpxImagePlugin
DEBUG:PIL.Image:Image: failed to import FpxImagePlugin: No module named 'olefile'
DEBUG:PIL.Image:Importing FtexImagePlugin
DEBUG:PIL.Image:Importing GbrImagePlugin
DEBUG:PIL.Image:Importing GifImagePlugin
DEBUG:PIL.Image:Importing GribStubImagePlugin
DEBUG:PIL.Image:Importing Hdf5StubImagePlugin
DEBUG:PIL.Image:Importing IcnsImagePlugin
DEBUG:PIL.Image:Importing IcoImagePlugin
DEBUG:PIL.Image:Importing ImImagePlugin
DEBUG:PIL.Image:Importing ImtImagePlugin
DEBUG:PIL.Image:Importing IptcImagePlugin
DEBUG:PIL.Image:Importing JpegImagePlugin
DEBUG:PIL.Image:Importing Jpeg2KImagePlugin
DEBUG:PIL.Image:Importing McIdasImagePlugin
DEBUG:PIL.Image:Importing MicImagePlugin
DEBUG:PIL.Image:Image: failed to import MicImagePlugin: No module named 'olefile'
DEBUG:PIL.Image:Importing MpegImagePlugin
DEBUG:PIL.Image:Importing MpoImagePlugin
DEBUG:PIL.Image:Importing MspImagePlugin
DEBUG:PIL.Image:Importing PalmImagePlugin
DEBUG:PIL.Image:Importing PcdImagePlugin
DEBUG:PIL.Image:Importing PcxImagePlugin
DEBUG:PIL.Image:Importing PdfImagePlugin
DEBUG:PIL.Image:Importing PixarImagePlugin
DEBUG:PIL.Image:Importing PngImagePlugin
DEBUG:PIL.Image:Importing PpmImagePlugin
DEBUG:PIL.Image:Importing PsdImagePlugin
DEBUG:PIL.Image:Importing QoiImagePlugin
DEBUG:PIL.Image:Importing SgiImagePlugin
DEBUG:PIL.Image:Importing SpiderImagePlugin
DEBUG:PIL.Image:Importing SunImagePlugin
DEBUG:PIL.Image:Importing TgaImagePlugin
DEBUG:PIL.Image:Importing TiffImagePlugin
DEBUG:PIL.Image:Importing WebPImagePlugin
DEBUG:PIL.Image:Importing WmfImagePlugin
DEBUG:PIL.Image:Importing XbmImagePlugin
DEBUG:PIL.Image:Importing XpmImagePlugin
DEBUG:PIL.Image:Importing XVThumbImagePlugin
8%|▊ | 1/12 [01:39<18:16, 99.64s/it]
17%|█▋ | 2/12 [03:19<16:35, 99.56s/it]
25%|██▌ | 3/12 [04:58<14:55, 99.53s/it]
33%|███▎ | 4/12 [06:38<13:16, 99.53s/it]
42%|████▏ | 5/12 [08:17<11:36, 99.54s/it]
50%|█████ | 6/12 [09:57<09:57, 99.57s/it]
58%|█████▊ | 7/12 [11:37<08:18, 99.66s/it]
67%|██████▋ | 8/12 [13:17<06:39, 99.83s/it]
75%|███████▌ | 9/12 [14:57<04:59, 99.93s/it]
83%|████████▎ | 10/12 [16:37<03:19, 99.97s/it]
92%|█████████▏| 11/12 [18:17<01:39, 99.85s/it]
100%|██████████| 12/12 [19:56<00:00, 99.71s/it]
100%|██████████| 12/12 [19:56<00:00, 99.71s/it]
>>>>>>>>>>>>>>>>>>>>>>>>
>>> Step 1: generating actions ...
>>> Step 1: interacting with world model ...
>>>>>>>>>>>>>>>>>>>>>>>>
>>> Step 2: generating actions ...
>>> Step 2: interacting with world model ...
>>>>>>>>>>>>>>>>>>>>>>>>
>>> Step 3: generating actions ...
>>> Step 3: interacting with world model ...
>>>>>>>>>>>>>>>>>>>>>>>>
>>> Step 4: generating actions ...
>>> Step 4: interacting with world model ...
>>>>>>>>>>>>>>>>>>>>>>>>
>>> Step 5: generating actions ...
>>> Step 5: interacting with world model ...
>>>>>>>>>>>>>>>>>>>>>>>>
>>> Step 6: generating actions ...
>>> Step 6: interacting with world model ...
>>>>>>>>>>>>>>>>>>>>>>>>
>>> Step 7: generating actions ...
>>> Step 7: interacting with world model ...
>>>>>>>>>>>>>>>>>>>>>>>>
>>> Step 8: generating actions ...
>>> Step 8: interacting with world model ...

View File

@@ -0,0 +1,5 @@
{
"gt_video": "unitree_z1_stackbox/case3/unitree_z1_stackbox_case3.mp4",
"pred_video": "unitree_z1_stackbox/case3/output/inference/unitree_z1_stackbox_case3_amd.mp4",
"psnr": 21.322784880212172
}

View File

@@ -0,0 +1,149 @@
/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/lightning_fabric/__init__.py:29: UserWarning: pkg_resources is deprecated as an API. See https://setuptools.pypa.io/en/latest/pkg_resources.html. The pkg_resources package is slated for removal as early as 2025-11-30. Refrain from using this package or pin to Setuptools<81.
__import__("pkg_resources").declare_namespace(__name__)
2026-02-08 08:25:54.657305: I tensorflow/core/util/port.cc:113] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2026-02-08 08:25:54.660628: I external/local_tsl/tsl/cuda/cudart_stub.cc:31] Could not find cuda drivers on your machine, GPU will not be used.
2026-02-08 08:25:54.691237: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2026-02-08 08:25:54.691275: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2026-02-08 08:25:54.693046: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2026-02-08 08:25:54.701142: I external/local_tsl/tsl/cuda/cudart_stub.cc:31] Could not find cuda drivers on your machine, GPU will not be used.
2026-02-08 08:25:54.701413: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 AVX512F AVX512_VNNI AVX512_BF16 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
2026-02-08 08:25:55.801367: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT
Global seed set to 123
/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/kornia/feature/lightglue.py:44: FutureWarning: `torch.cuda.amp.custom_fwd(args...)` is deprecated. Please use `torch.amp.custom_fwd(args..., device_type='cuda')` instead.
@torch.cuda.amp.custom_fwd(cast_inputs=torch.float32)
INFO:mainlogger:LatentVisualDiffusion: Running in v-prediction mode
INFO:unifolm_wma.models.diffusion_head.conditional_unet1d:number of parameters: 5.010531e+08
INFO:unifolm_wma.models.diffusion_head.conditional_unet1d:number of parameters: 5.010531e+08
AE working on z of shape (1, 4, 32, 32) = 4096 dimensions.
INFO:root:Loaded ViT-H-14 model config.
DEBUG:urllib3.connectionpool:Starting new HTTPS connection (1): hf-mirror.com:443
DEBUG:urllib3.connectionpool:https://hf-mirror.com:443 "HEAD /laion/CLIP-ViT-H-14-laion2B-s32B-b79K/resolve/main/open_clip_pytorch_model.bin HTTP/1.1" 302 0
INFO:root:Loading pretrained ViT-H-14 weights (laion2b_s32b_b79k).
/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/open_clip/factory.py:88: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.
checkpoint = torch.load(checkpoint_path, map_location=map_location)
INFO:root:Loaded ViT-H-14 model config.
DEBUG:urllib3.connectionpool:https://hf-mirror.com:443 "HEAD /laion/CLIP-ViT-H-14-laion2B-s32B-b79K/resolve/main/open_clip_pytorch_model.bin HTTP/1.1" 302 0
INFO:root:Loading pretrained ViT-H-14 weights (laion2b_s32b_b79k).
/mnt/ASC1637/unifolm-world-model-action/scripts/evaluation/world_model_interaction.py:86: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.
state_dict = torch.load(ckpt, map_location="cpu")
>>> model checkpoint loaded.
>>> Load pre-trained model ...
INFO:root:***** Configing Data *****
>>> unitree_z1_stackbox: 1 data samples loaded.
>>> unitree_z1_stackbox: data stats loaded.
>>> unitree_z1_stackbox: normalizer initiated.
>>> unitree_z1_dual_arm_stackbox: 1 data samples loaded.
>>> unitree_z1_dual_arm_stackbox: data stats loaded.
>>> unitree_z1_dual_arm_stackbox: normalizer initiated.
>>> unitree_z1_dual_arm_stackbox_v2: 1 data samples loaded.
>>> unitree_z1_dual_arm_stackbox_v2: data stats loaded.
>>> unitree_z1_dual_arm_stackbox_v2: normalizer initiated.
>>> unitree_z1_dual_arm_cleanup_pencils: 1 data samples loaded.
>>> unitree_z1_dual_arm_cleanup_pencils: data stats loaded.
>>> unitree_z1_dual_arm_cleanup_pencils: normalizer initiated.
>>> unitree_g1_pack_camera: 1 data samples loaded.
>>> unitree_g1_pack_camera: data stats loaded.
>>> unitree_g1_pack_camera: normalizer initiated.
>>> Dataset is successfully loaded ...
>>> Generate 16 frames under each generation ...
DEBUG:h5py._conv:Creating converter from 3 to 5
DEBUG:PIL.PngImagePlugin:STREAM b'IHDR' 16 13
DEBUG:PIL.PngImagePlugin:STREAM b'pHYs' 41 9
DEBUG:PIL.PngImagePlugin:STREAM b'IDAT' 62 4096
0%| | 0/12 [00:00<?, ?it/s]/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/torch/nn/functional.py:5501: UserWarning: Attempting to use hipBLASLt on an unsupported architecture! Overriding blas backend to hipblas (Triggered internally at ../aten/src/ATen/Context.cpp:296.)
proj = linear(q, w, b)
/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/torch/nn/functional.py:6278: UserWarning: Flash attention support on Navi31 GPU is still experimental. Enable it with TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL=1. (Triggered internally at ../aten/src/ATen/native/transformers/hip/sdp_utils.cpp:225.)
attn_output = scaled_dot_product_attention(
/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/torch/nn/functional.py:6278: UserWarning: Memory Efficient attention on Navi31 GPU is still experimental. Enable it with TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL=1. (Triggered internally at ../aten/src/ATen/native/transformers/hip/sdp_utils.cpp:269.)
attn_output = scaled_dot_product_attention(
>>> Step 0: generating actions ...
>>> Step 0: interacting with world model ...
DEBUG:PIL.Image:Importing BlpImagePlugin
DEBUG:PIL.Image:Importing BmpImagePlugin
DEBUG:PIL.Image:Importing BufrStubImagePlugin
DEBUG:PIL.Image:Importing CurImagePlugin
DEBUG:PIL.Image:Importing DcxImagePlugin
DEBUG:PIL.Image:Importing DdsImagePlugin
DEBUG:PIL.Image:Importing EpsImagePlugin
DEBUG:PIL.Image:Importing FitsImagePlugin
DEBUG:PIL.Image:Importing FitsStubImagePlugin
DEBUG:PIL.Image:Importing FliImagePlugin
DEBUG:PIL.Image:Importing FpxImagePlugin
DEBUG:PIL.Image:Image: failed to import FpxImagePlugin: No module named 'olefile'
DEBUG:PIL.Image:Importing FtexImagePlugin
DEBUG:PIL.Image:Importing GbrImagePlugin
DEBUG:PIL.Image:Importing GifImagePlugin
DEBUG:PIL.Image:Importing GribStubImagePlugin
DEBUG:PIL.Image:Importing Hdf5StubImagePlugin
DEBUG:PIL.Image:Importing IcnsImagePlugin
DEBUG:PIL.Image:Importing IcoImagePlugin
DEBUG:PIL.Image:Importing ImImagePlugin
DEBUG:PIL.Image:Importing ImtImagePlugin
DEBUG:PIL.Image:Importing IptcImagePlugin
DEBUG:PIL.Image:Importing JpegImagePlugin
DEBUG:PIL.Image:Importing Jpeg2KImagePlugin
DEBUG:PIL.Image:Importing McIdasImagePlugin
DEBUG:PIL.Image:Importing MicImagePlugin
DEBUG:PIL.Image:Image: failed to import MicImagePlugin: No module named 'olefile'
DEBUG:PIL.Image:Importing MpegImagePlugin
DEBUG:PIL.Image:Importing MpoImagePlugin
DEBUG:PIL.Image:Importing MspImagePlugin
DEBUG:PIL.Image:Importing PalmImagePlugin
DEBUG:PIL.Image:Importing PcdImagePlugin
DEBUG:PIL.Image:Importing PcxImagePlugin
DEBUG:PIL.Image:Importing PdfImagePlugin
DEBUG:PIL.Image:Importing PixarImagePlugin
DEBUG:PIL.Image:Importing PngImagePlugin
DEBUG:PIL.Image:Importing PpmImagePlugin
DEBUG:PIL.Image:Importing PsdImagePlugin
DEBUG:PIL.Image:Importing QoiImagePlugin
DEBUG:PIL.Image:Importing SgiImagePlugin
DEBUG:PIL.Image:Importing SpiderImagePlugin
DEBUG:PIL.Image:Importing SunImagePlugin
DEBUG:PIL.Image:Importing TgaImagePlugin
DEBUG:PIL.Image:Importing TiffImagePlugin
DEBUG:PIL.Image:Importing WebPImagePlugin
DEBUG:PIL.Image:Importing WmfImagePlugin
DEBUG:PIL.Image:Importing XbmImagePlugin
DEBUG:PIL.Image:Importing XpmImagePlugin
DEBUG:PIL.Image:Importing XVThumbImagePlugin
8%|▊ | 1/12 [01:37<17:51, 97.38s/it]
17%|█▋ | 2/12 [03:14<16:12, 97.24s/it]
25%|██▌ | 3/12 [04:51<14:35, 97.28s/it]
33%|███▎ | 4/12 [06:29<12:59, 97.40s/it]
42%|████▏ | 5/12 [08:06<11:21, 97.30s/it]
50%|█████ | 6/12 [09:43<09:43, 97.17s/it]
58%|█████▊ | 7/12 [11:20<08:05, 97.07s/it]
67%|██████▋ | 8/12 [12:57<06:28, 97.02s/it]
75%|███████▌ | 9/12 [14:34<04:50, 96.98s/it]
83%|████████▎ | 10/12 [16:11<03:14, 97.00s/it]
92%|█████████▏| 11/12 [17:48<01:37, 97.06s/it]
100%|██████████| 12/12 [19:25<00:00, 97.13s/it]
100%|██████████| 12/12 [19:25<00:00, 97.14s/it]
>>>>>>>>>>>>>>>>>>>>>>>>
>>> Step 1: generating actions ...
>>> Step 1: interacting with world model ...
>>>>>>>>>>>>>>>>>>>>>>>>
>>> Step 2: generating actions ...
>>> Step 2: interacting with world model ...
>>>>>>>>>>>>>>>>>>>>>>>>
>>> Step 3: generating actions ...
>>> Step 3: interacting with world model ...
>>>>>>>>>>>>>>>>>>>>>>>>
>>> Step 4: generating actions ...
>>> Step 4: interacting with world model ...
>>>>>>>>>>>>>>>>>>>>>>>>
>>> Step 5: generating actions ...
>>> Step 5: interacting with world model ...
>>>>>>>>>>>>>>>>>>>>>>>>
>>> Step 6: generating actions ...
>>> Step 6: interacting with world model ...
>>>>>>>>>>>>>>>>>>>>>>>>
>>> Step 7: generating actions ...
>>> Step 7: interacting with world model ...
>>>>>>>>>>>>>>>>>>>>>>>>
>>> Step 8: generating actions ...
>>> Step 8: interacting with world model ...

View File

@@ -0,0 +1,5 @@
{
"gt_video": "unitree_z1_stackbox/case4/unitree_z1_stackbox_case4.mp4",
"pred_video": "unitree_z1_stackbox/case4/output/inference/unitree_z1_stackbox_case4_amd.mp4",
"psnr": 25.32928948331741
}

View File

@@ -2,7 +2,7 @@ res_dir="unitree_z1_stackbox/case4"
dataset="unitree_z1_stackbox"
{
time CUDA_VISIBLE_DEVICES=0 python3 scripts/evaluation/world_model_interaction.py \
time CUDA_VISIBLE_DEVICES=7 python3 scripts/evaluation/world_model_interaction.py \
--seed 123 \
--ckpt_path ckpts/unifolm_wma_dual.ckpt \
--config configs/inference/world_model_interaction.yaml \