14 Commits

Author SHA1 Message Date
qhy
3069666a15 脚本修改 2026-02-10 14:49:26 +08:00
qhy
68369cc15f 合并后测试 2026-02-10 14:45:14 +08:00
b0ebb7006e 添加三层迭代级性能分析工具 profile_iteration.py
Layer1: CUDA Events 精确测量每个itr内10个阶段耗时
Layer2: torch.profiler GPU timeline trace
Layer3: CSV输出支持A/B对比

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-02-10 05:42:11 +00:00
125b85ce68 实现fs_embed 缓存,收益不明显,精度不降低 2026-02-09 18:49:44 +00:00
0b3b0e534a 复用 DDIMSampler + make_schedule微弱提升 2026-02-09 18:26:39 +00:00
6dca3696d8 实现了Context 预计算和缓存功能,提升了采样效率。 psnr不下降 2026-02-09 17:42:47 +00:00
f192c8aca9 添加CrossAttention kv缓存,减少重复计算,提升性能,psnr=31.8022 dB 2026-02-09 17:04:23 +00:00
4288c9d8c9 减少了一路视频vae解码 2026-02-09 16:48:16 +00:00
a2cd34dd51 1. einsum('b i d, b j d -> b i j') → torch.bmm(q, k.transpose(-1,-2)) — 直接映射 rocBLAS batched GEMM
2. baddbmm 把 scale 融合进 GEMM,少一次 kernel launch
3. 第二个 einsum 同理换torch.bm
每一轮加速1到两秒
2026-02-08 18:54:48 +00:00
7338cc384a ddim.py — torch.float16 → torch.bfloat16,修复 dtype 不匹配
attention.py — 4 处 softmax 都包裹了 torch.amp.autocast('cuda', enabled=False),阻止 autocast 将 bf16 提升到 fp32
2026-02-08 17:02:05 +00:00
f86ab51a04 全链路 bf16 混合精度修正与 UNet FLOPS profiling
- GroupNorm/LayerNorm bypass autocast,消除 bf16→fp32→bf16 转换开销
  - DDIM 调度系数 cast 到输入 dtype,attention mask 直接用 bf16 分配
  - alphas_cumprod 提升到 float64 保证数值精度
  - SinusoidalPosEmb 输出 dtype跟随模型精度
  - 新增 profile_unet.py 脚本及FLOPS 分析结果
  - 启用 TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL
  - case1 PSNR: 30.45 → 30.24(bf16 精度预期内波动)
2026-02-08 16:01:30 +00:00
75c798ded0 DDIM loop 内小张量分配优化,attention mask 缓存到 GPU 2026-02-08 14:20:48 +00:00
e588182642 修复混合精度vae相关的配置错误,确保在推理阶段正确使用了混合精度模型,并且导出了正确精度的检查点文件。 2026-02-08 12:35:59 +00:00
e6c55a648c 所有case的baseline,amd版本的ground truth都上传了 2026-02-08 09:42:14 +00:00
73 changed files with 5899 additions and 384 deletions

View File

@@ -1,15 +0,0 @@
{
"permissions": {
"allow": [
"Bash(conda env list:*)",
"Bash(mamba env:*)",
"Bash(micromamba env list:*)",
"Bash(echo:*)",
"Bash(git show:*)",
"Bash(nvidia-smi:*)",
"Bash(conda activate unifolm-wma)",
"Bash(conda info:*)",
"Bash(direnv allow:*)"
]
}
}

2
.envrc
View File

@@ -1,2 +0,0 @@
eval "$(conda shell.bash hook 2>/dev/null)"
conda activate unifolm-wma

4
.gitignore vendored
View File

@@ -55,6 +55,7 @@ coverage.xml
*.pot *.pot
# Django stuff: # Django stuff:
local_settings.py local_settings.py
db.sqlite3 db.sqlite3
@@ -120,7 +121,6 @@ localTest/
fig/ fig/
figure/ figure/
*.mp4 *.mp4
Data/ControlVAE.yml Data/ControlVAE.yml
Data/Misc Data/Misc
Data/Pretrained Data/Pretrained
@@ -129,6 +129,4 @@ Experiment/checkpoint
Experiment/log Experiment/log
*.ckpt *.ckpt
*.0 *.0
ckpts/unifolm_wma_dual.ckpt.prepared.pt

135
case4_run.log Normal file
View File

@@ -0,0 +1,135 @@
nohup: ignoring input
/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/lightning_fabric/__init__.py:29: UserWarning: pkg_resources is deprecated as an API. See https://setuptools.pypa.io/en/latest/pkg_resources.html. The pkg_resources package is slated for removal as early as 2025-11-30. Refrain from using this package or pin to Setuptools<81.
__import__("pkg_resources").declare_namespace(__name__)
2026-02-08 07:38:45.572744: I tensorflow/core/util/port.cc:113] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2026-02-08 07:38:45.576864: I external/local_tsl/tsl/cuda/cudart_stub.cc:31] Could not find cuda drivers on your machine, GPU will not be used.
2026-02-08 07:38:45.624825: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2026-02-08 07:38:45.624883: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2026-02-08 07:38:45.627150: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2026-02-08 07:38:45.638316: I external/local_tsl/tsl/cuda/cudart_stub.cc:31] Could not find cuda drivers on your machine, GPU will not be used.
2026-02-08 07:38:45.638803: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 AVX512F AVX512_VNNI AVX512_BF16 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
2026-02-08 07:38:46.426363: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT
[rank: 0] Global seed set to 123
/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/kornia/feature/lightglue.py:44: FutureWarning: `torch.cuda.amp.custom_fwd(args...)` is deprecated. Please use `torch.amp.custom_fwd(args..., device_type='cuda')` instead.
@torch.cuda.amp.custom_fwd(cast_inputs=torch.float32)
INFO:mainlogger:LatentVisualDiffusion: Running in v-prediction mode
INFO:unifolm_wma.models.diffusion_head.conditional_unet1d:number of parameters: 5.010531e+08
INFO:unifolm_wma.models.diffusion_head.conditional_unet1d:number of parameters: 5.010531e+08
AE working on z of shape (1, 4, 32, 32) = 4096 dimensions.
INFO:root:Loaded ViT-H-14 model config.
DEBUG:urllib3.connectionpool:Starting new HTTPS connection (1): hf-mirror.com:443
DEBUG:urllib3.connectionpool:https://hf-mirror.com:443 "HEAD /laion/CLIP-ViT-H-14-laion2B-s32B-b79K/resolve/main/open_clip_pytorch_model.bin HTTP/1.1" 302 0
INFO:root:Loading pretrained ViT-H-14 weights (laion2b_s32b_b79k).
/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/open_clip/factory.py:88: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.
checkpoint = torch.load(checkpoint_path, map_location=map_location)
INFO:root:Loaded ViT-H-14 model config.
DEBUG:urllib3.connectionpool:https://hf-mirror.com:443 "HEAD /laion/CLIP-ViT-H-14-laion2B-s32B-b79K/resolve/main/open_clip_pytorch_model.bin HTTP/1.1" 302 0
INFO:root:Loading pretrained ViT-H-14 weights (laion2b_s32b_b79k).
/mnt/ASC1637/unifolm-world-model-action/scripts/evaluation/world_model_interaction.py:86: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.
state_dict = torch.load(ckpt, map_location="cpu")
>>> model checkpoint loaded.
>>> Load pre-trained model ...
INFO:root:***** Configing Data *****
>>> unitree_z1_stackbox: 1 data samples loaded.
>>> unitree_z1_stackbox: data stats loaded.
>>> unitree_z1_stackbox: normalizer initiated.
>>> unitree_z1_dual_arm_stackbox: 1 data samples loaded.
>>> unitree_z1_dual_arm_stackbox: data stats loaded.
>>> unitree_z1_dual_arm_stackbox: normalizer initiated.
>>> unitree_z1_dual_arm_stackbox_v2: 1 data samples loaded.
>>> unitree_z1_dual_arm_stackbox_v2: data stats loaded.
>>> unitree_z1_dual_arm_stackbox_v2: normalizer initiated.
>>> unitree_z1_dual_arm_cleanup_pencils: 1 data samples loaded.
>>> unitree_z1_dual_arm_cleanup_pencils: data stats loaded.
>>> unitree_z1_dual_arm_cleanup_pencils: normalizer initiated.
>>> unitree_g1_pack_camera: 1 data samples loaded.
>>> unitree_g1_pack_camera: data stats loaded.
>>> unitree_g1_pack_camera: normalizer initiated.
>>> Dataset is successfully loaded ...
>>> Generate 16 frames under each generation ...
DEBUG:h5py._conv:Creating converter from 3 to 5
DEBUG:PIL.PngImagePlugin:STREAM b'IHDR' 16 13
DEBUG:PIL.PngImagePlugin:STREAM b'pHYs' 41 9
DEBUG:PIL.PngImagePlugin:STREAM b'IDAT' 62 4096
0%| | 0/7 [00:00<?, ?it/s]/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/torch/nn/functional.py:5501: UserWarning: Attempting to use hipBLASLt on an unsupported architecture! Overriding blas backend to hipblas (Triggered internally at ../aten/src/ATen/Context.cpp:296.)
proj = linear(q, w, b)
/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/torch/nn/functional.py:6278: UserWarning: Flash attention support on Navi31 GPU is still experimental. Enable it with TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL=1. (Triggered internally at ../aten/src/ATen/native/transformers/hip/sdp_utils.cpp:225.)
attn_output = scaled_dot_product_attention(
/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/torch/nn/functional.py:6278: UserWarning: Memory Efficient attention on Navi31 GPU is still experimental. Enable it with TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL=1. (Triggered internally at ../aten/src/ATen/native/transformers/hip/sdp_utils.cpp:269.)
attn_output = scaled_dot_product_attention(
>>> Step 0: generating actions ...
>>> Step 0: interacting with world model ...
DEBUG:PIL.Image:Importing BlpImagePlugin
DEBUG:PIL.Image:Importing BmpImagePlugin
DEBUG:PIL.Image:Importing BufrStubImagePlugin
DEBUG:PIL.Image:Importing CurImagePlugin
DEBUG:PIL.Image:Importing DcxImagePlugin
DEBUG:PIL.Image:Importing DdsImagePlugin
DEBUG:PIL.Image:Importing EpsImagePlugin
DEBUG:PIL.Image:Importing FitsImagePlugin
DEBUG:PIL.Image:Importing FitsStubImagePlugin
DEBUG:PIL.Image:Importing FliImagePlugin
DEBUG:PIL.Image:Importing FpxImagePlugin
DEBUG:PIL.Image:Image: failed to import FpxImagePlugin: No module named 'olefile'
DEBUG:PIL.Image:Importing FtexImagePlugin
DEBUG:PIL.Image:Importing GbrImagePlugin
DEBUG:PIL.Image:Importing GifImagePlugin
DEBUG:PIL.Image:Importing GribStubImagePlugin
DEBUG:PIL.Image:Importing Hdf5StubImagePlugin
DEBUG:PIL.Image:Importing IcnsImagePlugin
DEBUG:PIL.Image:Importing IcoImagePlugin
DEBUG:PIL.Image:Importing ImImagePlugin
DEBUG:PIL.Image:Importing ImtImagePlugin
DEBUG:PIL.Image:Importing IptcImagePlugin
DEBUG:PIL.Image:Importing JpegImagePlugin
DEBUG:PIL.Image:Importing Jpeg2KImagePlugin
DEBUG:PIL.Image:Importing McIdasImagePlugin
DEBUG:PIL.Image:Importing MicImagePlugin
DEBUG:PIL.Image:Image: failed to import MicImagePlugin: No module named 'olefile'
DEBUG:PIL.Image:Importing MpegImagePlugin
DEBUG:PIL.Image:Importing MpoImagePlugin
DEBUG:PIL.Image:Importing MspImagePlugin
DEBUG:PIL.Image:Importing PalmImagePlugin
DEBUG:PIL.Image:Importing PcdImagePlugin
DEBUG:PIL.Image:Importing PcxImagePlugin
DEBUG:PIL.Image:Importing PdfImagePlugin
DEBUG:PIL.Image:Importing PixarImagePlugin
DEBUG:PIL.Image:Importing PngImagePlugin
DEBUG:PIL.Image:Importing PpmImagePlugin
DEBUG:PIL.Image:Importing PsdImagePlugin
DEBUG:PIL.Image:Importing QoiImagePlugin
DEBUG:PIL.Image:Importing SgiImagePlugin
DEBUG:PIL.Image:Importing SpiderImagePlugin
DEBUG:PIL.Image:Importing SunImagePlugin
DEBUG:PIL.Image:Importing TgaImagePlugin
DEBUG:PIL.Image:Importing TiffImagePlugin
DEBUG:PIL.Image:Importing WebPImagePlugin
DEBUG:PIL.Image:Importing WmfImagePlugin
DEBUG:PIL.Image:Importing XbmImagePlugin
DEBUG:PIL.Image:Importing XpmImagePlugin
DEBUG:PIL.Image:Importing XVThumbImagePlugin
14%|█▍ | 1/7 [01:38<09:52, 98.73s/it]
29%|██▊ | 2/7 [03:17<08:14, 98.85s/it]
43%|████▎ | 3/7 [04:56<06:35, 98.80s/it]
57%|█████▋ | 4/7 [06:35<04:56, 98.94s/it]
71%|███████▏ | 5/7 [08:14<03:17, 98.93s/it]
86%|████████▌ | 6/7 [09:53<01:38, 98.89s/it]
100%|██████████| 7/7 [11:31<00:00, 98.81s/it]
100%|██████████| 7/7 [11:31<00:00, 98.85s/it]
>>>>>>>>>>>>>>>>>>>>>>>>
>>> Step 1: generating actions ...
>>> Step 1: interacting with world model ...
>>>>>>>>>>>>>>>>>>>>>>>>
>>> Step 2: generating actions ...
>>> Step 2: interacting with world model ...
>>>>>>>>>>>>>>>>>>>>>>>>
>>> Step 3: generating actions ...
>>> Step 3: interacting with world model ...
>>>>>>>>>>>>>>>>>>>>>>>>
>>> Step 4: generating actions ...
>>> Step 4: interacting with world model ...
>>>>>>>>>>>>>>>>>>>>>>>>
>>> Step 5: generating actions ...

1
ckpts/configuration.json Normal file
View File

@@ -0,0 +1 @@
{"framework": "pytorch", "task": "robotics", "allow_remote": true}

21
env.sh Normal file
View File

@@ -0,0 +1,21 @@
# Note: This script should be sourced, not executed
# Usage: source env.sh
#
# If you need render group permissions, run this first:
# newgrp render
# Then source this script:
# source env.sh
# Initialize conda
source /mnt/ASC1637/miniconda3/etc/profile.d/conda.sh
# Activate conda environment
conda activate unifolm-wma-o
# Set HuggingFace cache directories
export HF_HOME=/mnt/ASC1637/hf_home
export HUGGINGFACE_HUB_CACHE=/mnt/ASC1637/hf_home/hub
echo "Environment configured successfully"
echo "Conda environment: unifolm-wma-o"
echo "HF_HOME: $HF_HOME"

217
profile_unet_flops.md Normal file
View File

@@ -0,0 +1,217 @@
TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL=1 CUDA_VISIBLE_DEVICES=0 python scripts/evaluation/profile_unet.py --ckpt_path ckpts/unifolm_wma_dual_mix_bf16.ckpt --config configs/inference/world_model_interaction.yaml
==================================================================================================================================
FLOPS BY ATen OPERATOR (FlopCounterMode)
==================================================================================================================================
ATen Op | GFLOPS | % of Total
-------------------------------------------------------
convolution | 6185.17 | 46.4%
addmm | 4411.17 | 33.1%
mm | 1798.34 | 13.5%
bmm | 949.54 | 7.1%
==================================================================================================================================
FLOPS BY MODULE (FlopCounterMode)
==================================================================================================================================
Module | GFLOPS | % of Total
------------------------------------------------------------------------------------------
Global | 13344.23 | 100.0%
DiffusionWrapper | 13344.23 | 100.0%
DiffusionWrapper.diffusion_model | 13344.23 | 100.0%
DiffusionWrapper.diffusion_model.output_blocks.8 | 997.87 | 7.5%
DiffusionWrapper.diffusion_model.output_blocks.5 | 992.91 | 7.4%
DiffusionWrapper.diffusion_model.output_blocks.9 | 941.81 | 7.1%
DiffusionWrapper.diffusion_model.output_blocks.10 | 857.93 | 6.4%
DiffusionWrapper.diffusion_model.output_blocks.11 | 857.93 | 6.4%
DiffusionWrapper.diffusion_model.output_blocks.6 | 821.71 | 6.2%
DiffusionWrapper.diffusion_model.input_blocks.1 | 765.65 | 5.7%
DiffusionWrapper.diffusion_model.input_blocks.2 | 765.65 | 5.7%
DiffusionWrapper.diffusion_model.output_blocks.7 | 737.82 | 5.5%
DiffusionWrapper.diffusion_model.output_blocks.3 | 732.87 | 5.5%
DiffusionWrapper.diffusion_model.output_blocks.4 | 732.87 | 5.5%
DiffusionWrapper.diffusion_model.input_blocks.5 | 645.55 | 4.8%
DiffusionWrapper.diffusion_model.input_blocks.8 | 640.59 | 4.8%
DiffusionWrapper.diffusion_model.input_blocks.4 | 611.99 | 4.6%
DiffusionWrapper.diffusion_model.input_blocks.7 | 607.04 | 4.5%
DiffusionWrapper.diffusion_model.init_attn.0 | 459.02 | 3.4%
DiffusionWrapper.diffusion_model.init_attn | 459.02 | 3.4%
nWrapper.diffusion_model.init_attn.0.transformer_blocks.0 | 432.18 | 3.2%
DiffusionWrapper.diffusion_model.output_blocks.6.0 | 427.85 | 3.2%
DiffusionWrapper.diffusion_model.output_blocks.9.0 | 427.83 | 3.2%
DiffusionWrapper.diffusion_model.output_blocks.3.0 | 343.99 | 2.6%
DiffusionWrapper.diffusion_model.output_blocks.4.0 | 343.99 | 2.6%
DiffusionWrapper.diffusion_model.output_blocks.7.0 | 343.96 | 2.6%
DiffusionWrapper.diffusion_model.output_blocks.10.0 | 343.95 | 2.6%
DiffusionWrapper.diffusion_model.output_blocks.11.0 | 343.95 | 2.6%
DiffusionWrapper.diffusion_model.input_blocks.1.1 | 327.75 | 2.5%
DiffusionWrapper.diffusion_model.input_blocks.2.1 | 327.75 | 2.5%
==================================================================================================================================
SUMMARY
==================================================================================================================================
Total CUDA time: 761.4 ms
Matmul CUDA time: 404.2 ms (53.1%)
Non-matmul CUDA time: 357.1 ms (46.9%)
Total FLOPS (FlopCounter): 13344.23 GFLOPS
Matmul throughput: 33.01 TFLOPS/s (54.1% of BF16 peak)
Overall throughput: 17.53 TFLOPS/s (28.7% of BF16 peak)
GPU peak (BF16): 61.0 TFLOPS
==================================================================================================================================
FLOPS BY ATen OPERATOR (FlopCounterMode)
==================================================================================================================================
ATen Op | GFLOPS | % of Total
-------------------------------------------------------
convolution | 6185.17 | 46.4%
addmm | 4411.17 | 33.1%
mm | 1798.34 | 13.5%
bmm | 949.54 | 7.1%
==================================================================================================================================
FLOPS BY MODULE (FlopCounterMode)
==================================================================================================================================
Module | GFLOPS | % of Total
------------------------------------------------------------------------------------------
DiffusionWrapper | 13344.23 | 100.0%
Global | 13344.23 | 100.0%
DiffusionWrapper.diffusion_model | 13344.23 | 100.0%
DiffusionWrapper.diffusion_model.output_blocks.8 | 997.87 | 7.5%
DiffusionWrapper.diffusion_model.output_blocks.5 | 992.91 | 7.4%
DiffusionWrapper.diffusion_model.output_blocks.9 | 941.81 | 7.1%
DiffusionWrapper.diffusion_model.output_blocks.10 | 857.93 | 6.4%
DiffusionWrapper.diffusion_model.output_blocks.11 | 857.93 | 6.4%
DiffusionWrapper.diffusion_model.output_blocks.6 | 821.71 | 6.2%
DiffusionWrapper.diffusion_model.input_blocks.1 | 765.65 | 5.7%
DiffusionWrapper.diffusion_model.input_blocks.2 | 765.65 | 5.7%
DiffusionWrapper.diffusion_model.output_blocks.7 | 737.82 | 5.5%
DiffusionWrapper.diffusion_model.output_blocks.3 | 732.87 | 5.5%
DiffusionWrapper.diffusion_model.output_blocks.4 | 732.87 | 5.5%
DiffusionWrapper.diffusion_model.input_blocks.5 | 645.55 | 4.8%
DiffusionWrapper.diffusion_model.input_blocks.8 | 640.59 | 4.8%
DiffusionWrapper.diffusion_model.input_blocks.4 | 611.99 | 4.6%
DiffusionWrapper.diffusion_model.input_blocks.7 | 607.04 | 4.5%
DiffusionWrapper.diffusion_model.init_attn | 459.02 | 3.4%
DiffusionWrapper.diffusion_model.init_attn.0 | 459.02 | 3.4%
nWrapper.diffusion_model.init_attn.0.transformer_blocks.0 | 432.18 | 3.2%
DiffusionWrapper.diffusion_model.output_blocks.6.0 | 427.85 | 3.2%
DiffusionWrapper.diffusion_model.output_blocks.9.0 | 427.83 | 3.2%
DiffusionWrapper.diffusion_model.output_blocks.3.0 | 343.99 | 2.6%
DiffusionWrapper.diffusion_model.output_blocks.4.0 | 343.99 | 2.6%
DiffusionWrapper.diffusion_model.output_blocks.7.0 | 343.96 | 2.6%
DiffusionWrapper.diffusion_model.output_blocks.10.0 | 343.95 | 2.6%
DiffusionWrapper.diffusion_model.output_blocks.11.0 | 343.95 | 2.6%
DiffusionWrapper.diffusion_model.input_blocks.1.1 | 327.75 | 2.5%
DiffusionWrapper.diffusion_model.input_blocks.2.1 | 327.75 | 2.5%
==================================================================================================================================
SUMMARY
==================================================================================================================================
Total CUDA time: 707.1 ms
Matmul CUDA time: 403.1 ms (57.0%)
Non-matmul CUDA time: 304.0 ms (43.0%)
Total FLOPS (FlopCounter): 13344.23 GFLOPS
Matmul throughput: 33.11 TFLOPS/s (54.3% of BF16 peak)
Overall throughput: 18.87 TFLOPS/s (30.9% of BF16 peak)
GPU peak (BF16): 61.0 TFLOPS
(unifolm-wma) ASC1637@wx-ms-w7900d-0033:/mnt/ASC1637/unifolm-world-model-action$
========================================================================
TABLE 1: STAGE TIMING
========================================================================
Stage Mean(ms) Std %
------------------------------------------------------------------------
1_Image_Embedding 29.5 0.16 0.1%
2_VAE_Encode 51.3 0.06 0.1%
3_Text_Conditioning 14.7 0.18 0.0%
4_Projectors 0.2 0.03 0.0%
5_DDIM_Loop 33392.5 3.21 97.3%
6_VAE_Decode 808.4 1.00 2.4%
7_Post_Process 15.8 0.56 0.0%
------------------------------------------------------------------------
TOTAL 34312.4
================================================================================
TABLE 2: UNET SUB-MODULE BREAKDOWN
================================================================================
Module Type Total(ms) Count Per-call %
--------------------------------------------------------------------------------
ResBlock 10256.3 1100 9.32 23.2%
SpatialTransformer 9228.2 800 11.54 20.9%
CrossAttention 8105.8 3300 2.46 18.3%
ConditionalUnet1D 6409.5 100 64.10 14.5%
TemporalTransformer 5847.0 850 6.88 13.2%
FeedForward 4338.1 1650 2.63 9.8%
UNet.out 73.8 50 1.48 0.2%
--------------------------------------------------------------------------------
TOTAL (hooked) 44258.7
==========================================================================================
TABLE 2b: PER-UNET-BLOCK TIMING (coarse modules, per DDIM loop)
==========================================================================================
Block Total(ms) % Breakdown
------------------------------------------------------------------------------------------
input_blocks.1 3376.2 7.6% SpatialTransformer=1101, CrossAttention=990, ResBlock=543, TemporalTransformer=454, FeedForward=288
input_blocks.2 3374.0 7.6% SpatialTransformer=1100, CrossAttention=991, ResBlock=540, TemporalTransformer=455, FeedForward=288
input_blocks.4 1592.4 3.6% SpatialTransformer=394, ResBlock=374, CrossAttention=303, TemporalTransformer=272, FeedForward=249
input_blocks.5 1642.5 3.7% ResBlock=425, SpatialTransformer=397, CrossAttention=303, TemporalTransformer=271, FeedForward=247
input_blocks.7 1469.0 3.3% ResBlock=416, SpatialTransformer=324, FeedForward=251, CrossAttention=240, TemporalTransformer=237
input_blocks.8 1543.7 3.5% ResBlock=491, SpatialTransformer=325, FeedForward=250, CrossAttention=240, TemporalTransformer=238
input_blocks.10 217.5 0.5% ResBlock=218
input_blocks.11 216.8 0.5% ResBlock=217
middle_block 848.9 1.9% ResBlock=434, SpatialTransformer=151, CrossAttention=134, TemporalTransformer=69, FeedForward=61
output_blocks.0 303.2 0.7% ResBlock=303
output_blocks.1 303.1 0.7% ResBlock=303
output_blocks.2 302.8 0.7% ResBlock=303
output_blocks.3 1734.8 3.9% ResBlock=687, SpatialTransformer=322, FeedForward=249, CrossAttention=239, TemporalTransformer=237
output_blocks.4 1739.8 3.9% ResBlock=688, SpatialTransformer=323, FeedForward=251, CrossAttention=239, TemporalTransformer=238
output_blocks.5 1622.3 3.7% ResBlock=570, SpatialTransformer=324, FeedForward=251, CrossAttention=239, TemporalTransformer=238
output_blocks.6 1881.0 4.3% ResBlock=664, SpatialTransformer=393, CrossAttention=301, TemporalTransformer=272, FeedForward=250
output_blocks.7 1768.0 4.0% ResBlock=554, SpatialTransformer=393, CrossAttention=301, TemporalTransformer=272, FeedForward=249
output_blocks.8 1688.7 3.8% ResBlock=474, SpatialTransformer=393, CrossAttention=301, TemporalTransformer=272, FeedForward=249
output_blocks.9 3558.6 8.0% SpatialTransformer=1096, CrossAttention=992, ResBlock=727, TemporalTransformer=454, FeedForward=290
output_blocks.10 3492.8 7.9% SpatialTransformer=1096, CrossAttention=992, ResBlock=662, TemporalTransformer=454, FeedForward=289
output_blocks.11 3493.3 7.9% SpatialTransformer=1096, CrossAttention=992, ResBlock=662, TemporalTransformer=454, FeedForward=289
out 73.8 0.2% UNet.out=74
action_unet 3212.0 7.3% ConditionalUnet1D=3212
state_unet 3197.6 7.2% ConditionalUnet1D=3198
other 1606.2 3.6% TemporalTransformer=960, FeedForward=337, CrossAttention=309
------------------------------------------------------------------------------------------
TOTAL 44258.7
======================================================================
TABLE 2c: ATTENTION vs FEEDFORWARD (deep hooks)
======================================================================
Component Total(ms) %
----------------------------------------------------------------------
CrossAttention 8105.8 65.1%
FeedForward 4338.1 34.9%
----------------------------------------------------------------------
TOTAL (attn+ff) 12443.9
==================================================
TABLE 3: MEMORY SUMMARY
==================================================
Initial allocated: 11.82 GB
Peak allocated: 14.43 GB
Delta (pipeline): 2.61 GB
============================================================
TABLE 4: THROUGHPUT
============================================================
Total pipeline latency: 34312.4 ms
DDIM loop latency: 33392.5 ms
DDIM steps: 50
CFG scale: 1.0 (1x UNet/step)
UNet forward calls: 50
Per DDIM step: 667.9 ms
Per UNet forward: 667.9 ms
VAE encode bandwidth: 0.1 GB/s (peak HBM: 864.0 GB/s)
VAE decode bandwidth: 0.0 GB/s (peak HBM: 864.0 GB/s)
GPU BF16 peak: 61.0 TFLOPS
Done.

150
run.log Normal file
View File

@@ -0,0 +1,150 @@
nohup: ignoring input
/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/lightning_fabric/__init__.py:29: UserWarning: pkg_resources is deprecated as an API. See https://setuptools.pypa.io/en/latest/pkg_resources.html. The pkg_resources package is slated for removal as early as 2025-11-30. Refrain from using this package or pin to Setuptools<81.
__import__("pkg_resources").declare_namespace(__name__)
2026-02-08 08:15:49.934949: I tensorflow/core/util/port.cc:113] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2026-02-08 08:15:49.937974: I external/local_tsl/tsl/cuda/cudart_stub.cc:31] Could not find cuda drivers on your machine, GPU will not be used.
2026-02-08 08:15:49.969069: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2026-02-08 08:15:49.969100: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2026-02-08 08:15:49.970909: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2026-02-08 08:15:49.979005: I external/local_tsl/tsl/cuda/cudart_stub.cc:31] Could not find cuda drivers on your machine, GPU will not be used.
2026-02-08 08:15:49.979255: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 AVX512F AVX512_VNNI AVX512_BF16 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
2026-02-08 08:15:50.597743: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT
[rank: 0] Global seed set to 123
/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/kornia/feature/lightglue.py:44: FutureWarning: `torch.cuda.amp.custom_fwd(args...)` is deprecated. Please use `torch.amp.custom_fwd(args..., device_type='cuda')` instead.
@torch.cuda.amp.custom_fwd(cast_inputs=torch.float32)
INFO:mainlogger:LatentVisualDiffusion: Running in v-prediction mode
INFO:unifolm_wma.models.diffusion_head.conditional_unet1d:number of parameters: 5.010531e+08
INFO:unifolm_wma.models.diffusion_head.conditional_unet1d:number of parameters: 5.010531e+08
AE working on z of shape (1, 4, 32, 32) = 4096 dimensions.
INFO:root:Loaded ViT-H-14 model config.
DEBUG:urllib3.connectionpool:Starting new HTTPS connection (1): hf-mirror.com:443
DEBUG:urllib3.connectionpool:https://hf-mirror.com:443 "HEAD /laion/CLIP-ViT-H-14-laion2B-s32B-b79K/resolve/main/open_clip_pytorch_model.bin HTTP/1.1" 302 0
INFO:root:Loading pretrained ViT-H-14 weights (laion2b_s32b_b79k).
/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/open_clip/factory.py:88: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.
checkpoint = torch.load(checkpoint_path, map_location=map_location)
INFO:root:Loaded ViT-H-14 model config.
DEBUG:urllib3.connectionpool:https://hf-mirror.com:443 "HEAD /laion/CLIP-ViT-H-14-laion2B-s32B-b79K/resolve/main/open_clip_pytorch_model.bin HTTP/1.1" 302 0
INFO:root:Loading pretrained ViT-H-14 weights (laion2b_s32b_b79k).
/mnt/ASC1637/unifolm-world-model-action/scripts/evaluation/world_model_interaction.py:86: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.
state_dict = torch.load(ckpt, map_location="cpu")
>>> model checkpoint loaded.
>>> Load pre-trained model ...
INFO:root:***** Configing Data *****
>>> unitree_z1_stackbox: 1 data samples loaded.
>>> unitree_z1_stackbox: data stats loaded.
>>> unitree_z1_stackbox: normalizer initiated.
>>> unitree_z1_dual_arm_stackbox: 1 data samples loaded.
>>> unitree_z1_dual_arm_stackbox: data stats loaded.
>>> unitree_z1_dual_arm_stackbox: normalizer initiated.
>>> unitree_z1_dual_arm_stackbox_v2: 1 data samples loaded.
>>> unitree_z1_dual_arm_stackbox_v2: data stats loaded.
>>> unitree_z1_dual_arm_stackbox_v2: normalizer initiated.
>>> unitree_z1_dual_arm_cleanup_pencils: 1 data samples loaded.
>>> unitree_z1_dual_arm_cleanup_pencils: data stats loaded.
>>> unitree_z1_dual_arm_cleanup_pencils: normalizer initiated.
>>> unitree_g1_pack_camera: 1 data samples loaded.
>>> unitree_g1_pack_camera: data stats loaded.
>>> unitree_g1_pack_camera: normalizer initiated.
>>> Dataset is successfully loaded ...
>>> Generate 16 frames under each generation ...
DEBUG:h5py._conv:Creating converter from 3 to 5
DEBUG:PIL.PngImagePlugin:STREAM b'IHDR' 16 13
DEBUG:PIL.PngImagePlugin:STREAM b'pHYs' 41 9
DEBUG:PIL.PngImagePlugin:STREAM b'IDAT' 62 4096
0%| | 0/12 [00:00<?, ?it/s]/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/torch/nn/functional.py:5501: UserWarning: Attempting to use hipBLASLt on an unsupported architecture! Overriding blas backend to hipblas (Triggered internally at ../aten/src/ATen/Context.cpp:296.)
proj = linear(q, w, b)
/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/torch/nn/functional.py:6278: UserWarning: Flash attention support on Navi31 GPU is still experimental. Enable it with TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL=1. (Triggered internally at ../aten/src/ATen/native/transformers/hip/sdp_utils.cpp:225.)
attn_output = scaled_dot_product_attention(
/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/torch/nn/functional.py:6278: UserWarning: Memory Efficient attention on Navi31 GPU is still experimental. Enable it with TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL=1. (Triggered internally at ../aten/src/ATen/native/transformers/hip/sdp_utils.cpp:269.)
attn_output = scaled_dot_product_attention(
>>> Step 0: generating actions ...
>>> Step 0: interacting with world model ...
DEBUG:PIL.Image:Importing BlpImagePlugin
DEBUG:PIL.Image:Importing BmpImagePlugin
DEBUG:PIL.Image:Importing BufrStubImagePlugin
DEBUG:PIL.Image:Importing CurImagePlugin
DEBUG:PIL.Image:Importing DcxImagePlugin
DEBUG:PIL.Image:Importing DdsImagePlugin
DEBUG:PIL.Image:Importing EpsImagePlugin
DEBUG:PIL.Image:Importing FitsImagePlugin
DEBUG:PIL.Image:Importing FitsStubImagePlugin
DEBUG:PIL.Image:Importing FliImagePlugin
DEBUG:PIL.Image:Importing FpxImagePlugin
DEBUG:PIL.Image:Image: failed to import FpxImagePlugin: No module named 'olefile'
DEBUG:PIL.Image:Importing FtexImagePlugin
DEBUG:PIL.Image:Importing GbrImagePlugin
DEBUG:PIL.Image:Importing GifImagePlugin
DEBUG:PIL.Image:Importing GribStubImagePlugin
DEBUG:PIL.Image:Importing Hdf5StubImagePlugin
DEBUG:PIL.Image:Importing IcnsImagePlugin
DEBUG:PIL.Image:Importing IcoImagePlugin
DEBUG:PIL.Image:Importing ImImagePlugin
DEBUG:PIL.Image:Importing ImtImagePlugin
DEBUG:PIL.Image:Importing IptcImagePlugin
DEBUG:PIL.Image:Importing JpegImagePlugin
DEBUG:PIL.Image:Importing Jpeg2KImagePlugin
DEBUG:PIL.Image:Importing McIdasImagePlugin
DEBUG:PIL.Image:Importing MicImagePlugin
DEBUG:PIL.Image:Image: failed to import MicImagePlugin: No module named 'olefile'
DEBUG:PIL.Image:Importing MpegImagePlugin
DEBUG:PIL.Image:Importing MpoImagePlugin
DEBUG:PIL.Image:Importing MspImagePlugin
DEBUG:PIL.Image:Importing PalmImagePlugin
DEBUG:PIL.Image:Importing PcdImagePlugin
DEBUG:PIL.Image:Importing PcxImagePlugin
DEBUG:PIL.Image:Importing PdfImagePlugin
DEBUG:PIL.Image:Importing PixarImagePlugin
DEBUG:PIL.Image:Importing PngImagePlugin
DEBUG:PIL.Image:Importing PpmImagePlugin
DEBUG:PIL.Image:Importing PsdImagePlugin
DEBUG:PIL.Image:Importing QoiImagePlugin
DEBUG:PIL.Image:Importing SgiImagePlugin
DEBUG:PIL.Image:Importing SpiderImagePlugin
DEBUG:PIL.Image:Importing SunImagePlugin
DEBUG:PIL.Image:Importing TgaImagePlugin
DEBUG:PIL.Image:Importing TiffImagePlugin
DEBUG:PIL.Image:Importing WebPImagePlugin
DEBUG:PIL.Image:Importing WmfImagePlugin
DEBUG:PIL.Image:Importing XbmImagePlugin
DEBUG:PIL.Image:Importing XpmImagePlugin
DEBUG:PIL.Image:Importing XVThumbImagePlugin
8%|▊ | 1/12 [01:37<17:51, 97.37s/it]
17%|█▋ | 2/12 [03:14<16:13, 97.31s/it]
25%|██▌ | 3/12 [04:51<14:35, 97.26s/it]
33%|███▎ | 4/12 [06:29<12:58, 97.25s/it]
42%|████▏ | 5/12 [08:06<11:20, 97.24s/it]
50%|█████ | 6/12 [09:43<09:43, 97.24s/it]
58%|█████▊ | 7/12 [11:20<08:06, 97.27s/it]
67%|██████▋ | 8/12 [12:58<06:29, 97.36s/it]
75%|███████▌ | 9/12 [14:36<04:52, 97.49s/it]
83%|████████▎ | 10/12 [16:13<03:15, 97.52s/it]
92%|█████████▏| 11/12 [17:51<01:37, 97.47s/it]
100%|██████████| 12/12 [19:28<00:00, 97.35s/it]
100%|██████████| 12/12 [19:28<00:00, 97.35s/it]
>>>>>>>>>>>>>>>>>>>>>>>>
>>> Step 1: generating actions ...
>>> Step 1: interacting with world model ...
>>>>>>>>>>>>>>>>>>>>>>>>
>>> Step 2: generating actions ...
>>> Step 2: interacting with world model ...
>>>>>>>>>>>>>>>>>>>>>>>>
>>> Step 3: generating actions ...
>>> Step 3: interacting with world model ...
>>>>>>>>>>>>>>>>>>>>>>>>
>>> Step 4: generating actions ...
>>> Step 4: interacting with world model ...
>>>>>>>>>>>>>>>>>>>>>>>>
>>> Step 5: generating actions ...
>>> Step 5: interacting with world model ...
>>>>>>>>>>>>>>>>>>>>>>>>
>>> Step 6: generating actions ...
>>> Step 6: interacting with world model ...
>>>>>>>>>>>>>>>>>>>>>>>>
>>> Step 7: generating actions ...
>>> Step 7: interacting with world model ...
>>>>>>>>>>>>>>>>>>>>>>>>
>>> Step 8: generating actions ...
>>> Step 8: interacting with world model ...

View File

@@ -16,9 +16,6 @@ from collections import OrderedDict
from unifolm_wma.models.samplers.ddim import DDIMSampler from unifolm_wma.models.samplers.ddim import DDIMSampler
from unifolm_wma.utils.utils import instantiate_from_config from unifolm_wma.utils.utils import instantiate_from_config
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
def get_filelist(data_dir: str, postfixes: list[str]) -> list[str]: def get_filelist(data_dir: str, postfixes: list[str]) -> list[str]:
""" """

View File

@@ -0,0 +1,975 @@
"""
Profile the full iteration loop of world model interaction.
Three layers of profiling:
Layer 1: Iteration-level wall-clock breakdown (CUDA events)
Layer 2: GPU timeline trace (torch.profiler → Chrome trace)
Layer 3: A/B comparison (standardized CSV output)
Usage:
# Layer 1 only (fast, default):
TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL=1 CUDA_VISIBLE_DEVICES=0 \
python scripts/evaluation/profile_iteration.py \
--ckpt_path ckpts/unifolm_wma_dual_mix_bf16.ckpt \
--config configs/inference/world_model_interaction.yaml \
--prompt_dir unitree_z1_dual_arm_cleanup_pencils/case1/world_model_interaction_prompts \
--dataset unitree_z1_dual_arm_cleanup_pencils \
--frame_stride 4 --n_iter 5
# Layer 1 + Layer 2 (GPU trace):
... --trace --trace_dir ./profile_traces
# Layer 3 (A/B comparison): run twice, diff the CSVs
... --csv baseline.csv
... --csv optimized.csv
python scripts/evaluation/profile_iteration.py --compare baseline.csv optimized.csv
"""
import argparse
import csv
import os
import sys
import time
from collections import defaultdict, deque
from contextlib import nullcontext
import h5py
import numpy as np
import pandas as pd
import torch
import torchvision
from einops import rearrange, repeat
from omegaconf import OmegaConf
from PIL import Image
from pytorch_lightning import seed_everything
from torch import Tensor
from unifolm_wma.models.samplers.ddim import DDIMSampler
from unifolm_wma.utils.utils import instantiate_from_config
import torch.nn.functional as F
# ──────────────────────────────────────────────────────────────────────
# Constants
# ──────────────────────────────────────────────────────────────────────
STAGE_NAMES = [
"stack_to_device_1",
"synth_policy",
"update_action_queue",
"stack_to_device_2",
"synth_world_model",
"update_obs_queue",
"tensorboard_log",
"save_results",
"cpu_transfer",
"itr_total",
]
# Sub-stages inside image_guided_synthesis_sim_mode
SYNTH_SUB_STAGES = [
"ddim_sampler_init",
"image_embedding",
"vae_encode",
"text_conditioning",
"projectors",
"cond_assembly",
"ddim_sampling",
"vae_decode",
]
# ──────────────────────────────────────────────────────────────────────
# CudaTimer — GPU-precise timing via CUDA events
# ──────────────────────────────────────────────────────────────────────
class CudaTimer:
"""Context manager that records GPU time between enter/exit using CUDA events."""
def __init__(self, name, records):
self.name = name
self.records = records
def __enter__(self):
torch.cuda.synchronize()
self._start = torch.cuda.Event(enable_timing=True)
self._end = torch.cuda.Event(enable_timing=True)
self._start.record()
return self
def __exit__(self, *args):
self._end.record()
torch.cuda.synchronize()
elapsed_ms = self._start.elapsed_time(self._end)
self.records[self.name].append(elapsed_ms)
class WallTimer:
"""Context manager that records CPU wall-clock time (for pure-CPU stages)."""
def __init__(self, name, records):
self.name = name
self.records = records
def __enter__(self):
torch.cuda.synchronize()
self._t0 = time.perf_counter()
return self
def __exit__(self, *args):
torch.cuda.synchronize()
elapsed_ms = (time.perf_counter() - self._t0) * 1000.0
self.records[self.name].append(elapsed_ms)
# ──────────────────────────────────────────────────────────────────────
# Model loading (reused from world_model_interaction.py)
# ──────────────────────────────────────────────────────────────────────
def patch_norm_bypass_autocast():
def _group_norm_forward(self, x):
with torch.amp.autocast('cuda', enabled=False):
return F.group_norm(
x, self.num_groups,
self.weight.to(x.dtype) if self.weight is not None else None,
self.bias.to(x.dtype) if self.bias is not None else None,
self.eps)
def _layer_norm_forward(self, x):
with torch.amp.autocast('cuda', enabled=False):
return F.layer_norm(
x, self.normalized_shape,
self.weight.to(x.dtype) if self.weight is not None else None,
self.bias.to(x.dtype) if self.bias is not None else None,
self.eps)
torch.nn.GroupNorm.forward = _group_norm_forward
torch.nn.LayerNorm.forward = _layer_norm_forward
def apply_torch_compile(model, hot_indices=(5, 8, 9)):
from unifolm_wma.modules.networks.wma_model import ResBlock
unet = model.model.diffusion_model
compiled = 0
for idx in hot_indices:
block = unet.output_blocks[idx]
for layer in block:
if isinstance(layer, ResBlock):
layer._forward = torch.compile(layer._forward, mode="default")
compiled += 1
print(f" torch.compile: {compiled} ResBlocks in output_blocks{list(hot_indices)}")
def load_model(args):
config = OmegaConf.load(args.config)
config['model']['params']['wma_config']['params']['use_checkpoint'] = False
model = instantiate_from_config(config.model)
model.perframe_ae = args.perframe_ae
from collections import OrderedDict
state_dict = torch.load(args.ckpt_path, map_location="cpu")
if "state_dict" in state_dict:
state_dict = state_dict["state_dict"]
try:
model.load_state_dict(state_dict, strict=True)
except Exception:
new_sd = OrderedDict()
for k, v in state_dict.items():
new_sd[k] = v
for k in list(new_sd.keys()):
if "framestride_embed" in k:
new_sd[k.replace("framestride_embed", "fps_embedding")] = new_sd.pop(k)
model.load_state_dict(new_sd, strict=True)
model.eval()
# Apply precision: bf16 diffusion + encoders + projectors, fp32/bf16 VAE
model.model.to(torch.bfloat16)
model.diffusion_autocast_dtype = torch.bfloat16
model.embedder.to(torch.bfloat16)
model.image_proj_model.to(torch.bfloat16)
model.encoder_autocast_dtype = None
model.state_projector.to(torch.bfloat16)
model.action_projector.to(torch.bfloat16)
model.projector_autocast_dtype = None
if args.vae_dtype == "bf16":
model.first_stage_model.to(torch.bfloat16)
# Compile hot ResBlocks
apply_torch_compile(model)
model = model.cuda()
print(">>> Model loaded and ready.")
return model, config
# ──────────────────────────────────────────────────────────────────────
# Data preparation (reused from world_model_interaction.py)
# ──────────────────────────────────────────────────────────────────────
def get_init_frame_path(data_dir, sample):
rel = os.path.join(sample['data_dir'], str(sample['videoid']) + '.png')
return os.path.join(data_dir, 'images', rel)
def get_transition_path(data_dir, sample):
rel = os.path.join(sample['data_dir'], str(sample['videoid']) + '.h5')
return os.path.join(data_dir, 'transitions', rel)
def prepare_init_input(start_idx, init_frame_path, transition_dict,
frame_stride, wma_data, video_length=16, n_obs_steps=2):
indices = [start_idx + frame_stride * i for i in range(video_length)]
init_frame = Image.open(init_frame_path).convert('RGB')
init_frame = torch.tensor(np.array(init_frame)).unsqueeze(0).permute(3, 0, 1, 2).float()
if start_idx < n_obs_steps - 1:
state_indices = list(range(0, start_idx + 1))
states = transition_dict['observation.state'][state_indices, :]
num_padding = n_obs_steps - 1 - start_idx
padding = states[0:1, :].repeat(num_padding, 1)
states = torch.cat((padding, states), dim=0)
else:
state_indices = list(range(start_idx - n_obs_steps + 1, start_idx + 1))
states = transition_dict['observation.state'][state_indices, :]
actions = transition_dict['action'][indices, :]
ori_state_dim = states.shape[-1]
ori_action_dim = actions.shape[-1]
frames_action_state_dict = {
'action': actions,
'observation.state': states,
}
frames_action_state_dict = wma_data.normalizer(frames_action_state_dict)
frames_action_state_dict = wma_data.get_uni_vec(
frames_action_state_dict,
transition_dict['action_type'],
transition_dict['state_type'],
)
if wma_data.spatial_transform is not None:
init_frame = wma_data.spatial_transform(init_frame)
init_frame = (init_frame / 255 - 0.5) * 2
data = {'observation.image': init_frame}
data.update(frames_action_state_dict)
return data, ori_state_dim, ori_action_dim
def populate_queues(queues, batch):
for key in batch:
if key not in queues:
continue
if len(queues[key]) != queues[key].maxlen:
while len(queues[key]) != queues[key].maxlen:
queues[key].append(batch[key])
else:
queues[key].append(batch[key])
return queues
# ──────────────────────────────────────────────────────────────────────
# Instrumented image_guided_synthesis_sim_mode with sub-stage timing
# ──────────────────────────────────────────────────────────────────────
def get_latent_z(model, videos):
b, c, t, h, w = videos.shape
x = rearrange(videos, 'b c t h w -> (b t) c h w')
vae_dtype = next(model.first_stage_model.parameters()).dtype
x = x.to(dtype=vae_dtype)
z = model.encode_first_stage(x)
z = rearrange(z, '(b t) c h w -> b c t h w', b=b, t=t)
return z
def save_results(video, filename, fps=8):
video = video.detach().cpu()
video = torch.clamp(video.float(), -1., 1.)
n = video.shape[0]
video = video.permute(2, 0, 1, 3, 4)
frame_grids = [
torchvision.utils.make_grid(framesheet, nrow=int(n), padding=0)
for framesheet in video
]
grid = torch.stack(frame_grids, dim=0)
grid = (grid + 1.0) / 2.0
grid = (grid * 255).to(torch.uint8).permute(0, 2, 3, 1)
torchvision.io.write_video(filename, grid, fps=fps,
video_codec='h264', options={'crf': '10'})
def profiled_synthesis(model, prompts, observation, noise_shape,
ddim_steps, ddim_eta, unconditional_guidance_scale,
fs, text_input, timestep_spacing, guidance_rescale,
sim_mode, decode_video, records, prefix):
"""image_guided_synthesis_sim_mode with per-sub-stage CUDA event timing.
Args:
prefix: "policy" or "wm" — prepended to sub-stage names in records.
"""
b, _, t, _, _ = noise_shape
batch_size = noise_shape[0]
device = next(model.parameters()).device
# --- sub-stage: ddim_sampler_init ---
with CudaTimer(f"{prefix}/ddim_sampler_init", records):
ddim_sampler = DDIMSampler(model)
fs_t = torch.tensor([fs] * batch_size, dtype=torch.long, device=device)
# --- sub-stage: image_embedding ---
with CudaTimer(f"{prefix}/image_embedding", records):
model_dtype = next(model.embedder.parameters()).dtype
img = observation['observation.images.top'].permute(0, 2, 1, 3, 4)
cond_img = rearrange(img, 'b o c h w -> (b o) c h w')[-1:].to(dtype=model_dtype)
cond_img_emb = model.embedder(cond_img)
cond_img_emb = model.image_proj_model(cond_img_emb)
# --- sub-stage: vae_encode ---
with CudaTimer(f"{prefix}/vae_encode", records):
if model.model.conditioning_key == 'hybrid':
z = get_latent_z(model, img.permute(0, 2, 1, 3, 4))
img_cat_cond = z[:, :, -1:, :, :]
img_cat_cond = repeat(img_cat_cond,
'b c t h w -> b c (repeat t) h w',
repeat=noise_shape[2])
cond = {"c_concat": [img_cat_cond]}
# --- sub-stage: text_conditioning ---
with CudaTimer(f"{prefix}/text_conditioning", records):
if not text_input:
prompts_use = [""] * batch_size
else:
prompts_use = prompts if isinstance(prompts, list) else [prompts] * batch_size
cond_ins_emb = model.get_learned_conditioning(prompts_use)
# --- sub-stage: projectors ---
with CudaTimer(f"{prefix}/projectors", records):
projector_dtype = next(model.state_projector.parameters()).dtype
cond_state_emb = model.state_projector(
observation['observation.state'].to(dtype=projector_dtype))
cond_state_emb = cond_state_emb + model.agent_state_pos_emb
cond_action_emb = model.action_projector(
observation['action'].to(dtype=projector_dtype))
cond_action_emb = cond_action_emb + model.agent_action_pos_emb
if not sim_mode:
cond_action_emb = torch.zeros_like(cond_action_emb)
# --- sub-stage: cond_assembly ---
with CudaTimer(f"{prefix}/cond_assembly", records):
cond["c_crossattn"] = [
torch.cat([cond_state_emb, cond_action_emb, cond_ins_emb, cond_img_emb], dim=1)
]
cond["c_crossattn_action"] = [
observation['observation.images.top'][:, :, -model.n_obs_steps_acting:],
observation['observation.state'][:, -model.n_obs_steps_acting:],
sim_mode,
False,
]
# --- sub-stage: ddim_sampling ---
autocast_dtype = getattr(model, 'diffusion_autocast_dtype', None)
if autocast_dtype is not None and device.type == 'cuda':
autocast_ctx = torch.autocast('cuda', dtype=autocast_dtype)
else:
autocast_ctx = nullcontext()
with CudaTimer(f"{prefix}/ddim_sampling", records):
with autocast_ctx:
samples, actions, states, _ = ddim_sampler.sample(
S=ddim_steps,
conditioning=cond,
batch_size=batch_size,
shape=noise_shape[1:],
verbose=False,
unconditional_guidance_scale=unconditional_guidance_scale,
unconditional_conditioning=None,
eta=ddim_eta,
cfg_img=None,
mask=None,
x0=None,
fs=fs_t,
timestep_spacing=timestep_spacing,
guidance_rescale=guidance_rescale,
unconditional_conditioning_img_nonetext=None,
)
# --- sub-stage: vae_decode ---
batch_variants = None
if decode_video:
with CudaTimer(f"{prefix}/vae_decode", records):
batch_variants = model.decode_first_stage(samples)
else:
records[f"{prefix}/vae_decode"].append(0.0)
return batch_variants, actions, states
# ──────────────────────────────────────────────────────────────────────
# Instrumented iteration loop
# ──────────────────────────────────────────────────────────────────────
def run_profiled_iterations(model, args, config, noise_shape, device):
"""Run the full iteration loop with per-stage timing.
Returns:
all_records: list of dicts, one per itr, {stage_name: ms}
"""
# Load data
csv_path = os.path.join(args.prompt_dir, f"{args.dataset}.csv")
df = pd.read_csv(csv_path)
sample = df.iloc[0]
data_module = instantiate_from_config(config.data)
data_module.setup()
init_frame_path = get_init_frame_path(args.prompt_dir, sample)
ori_fps = float(sample['fps'])
fs = args.frame_stride
model_input_fs = ori_fps // fs
transition_path = get_transition_path(args.prompt_dir, sample)
with h5py.File(transition_path, 'r') as h5f:
transition_dict = {}
for key in h5f.keys():
transition_dict[key] = torch.tensor(h5f[key][()])
for key in h5f.attrs.keys():
transition_dict[key] = h5f.attrs[key]
# Prepare initial observation
batch, ori_state_dim, ori_action_dim = prepare_init_input(
0, init_frame_path, transition_dict, fs,
data_module.test_datasets[args.dataset],
n_obs_steps=model.n_obs_steps_imagen)
observation = {
'observation.images.top':
batch['observation.image'].permute(1, 0, 2, 3)[-1].unsqueeze(0),
'observation.state':
batch['observation.state'][-1].unsqueeze(0),
'action':
torch.zeros_like(batch['action'][-1]).unsqueeze(0),
}
observation = {k: v.to(device, non_blocking=True) for k, v in observation.items()}
cond_obs_queues = {
"observation.images.top": deque(maxlen=model.n_obs_steps_imagen),
"observation.state": deque(maxlen=model.n_obs_steps_imagen),
"action": deque(maxlen=args.video_length),
}
cond_obs_queues = populate_queues(cond_obs_queues, observation)
# Temp dir for save_results profiling
tmp_dir = os.path.join(args.savedir, "profile_tmp")
os.makedirs(tmp_dir, exist_ok=True)
prompt_text = sample['instruction']
all_records = []
print(f">>> Running {args.n_iter} profiled iterations ...")
for itr in range(args.n_iter):
rec = defaultdict(list)
# ── itr_total start ──
torch.cuda.synchronize()
itr_start = torch.cuda.Event(enable_timing=True)
itr_end = torch.cuda.Event(enable_timing=True)
itr_start.record()
# ① stack_to_device_1
with CudaTimer("stack_to_device_1", rec):
observation = {
'observation.images.top':
torch.stack(list(cond_obs_queues['observation.images.top']),
dim=1).permute(0, 2, 1, 3, 4),
'observation.state':
torch.stack(list(cond_obs_queues['observation.state']), dim=1),
'action':
torch.stack(list(cond_obs_queues['action']), dim=1),
}
observation = {k: v.to(device, non_blocking=True) for k, v in observation.items()}
# ② synth_policy
with CudaTimer("synth_policy", rec):
pred_videos_0, pred_actions, _ = profiled_synthesis(
model, prompt_text, observation, noise_shape,
ddim_steps=args.ddim_steps, ddim_eta=args.ddim_eta,
unconditional_guidance_scale=args.unconditional_guidance_scale,
fs=model_input_fs, text_input=True,
timestep_spacing=args.timestep_spacing,
guidance_rescale=args.guidance_rescale,
sim_mode=False,
decode_video=not args.fast_policy_no_decode,
records=rec, prefix="policy")
# ③ update_action_queue
with WallTimer("update_action_queue", rec):
for idx in range(len(pred_actions[0])):
obs_a = {'action': pred_actions[0][idx:idx + 1]}
obs_a['action'][:, ori_action_dim:] = 0.0
cond_obs_queues = populate_queues(cond_obs_queues, obs_a)
# ④ stack_to_device_2
with CudaTimer("stack_to_device_2", rec):
observation = {
'observation.images.top':
torch.stack(list(cond_obs_queues['observation.images.top']),
dim=1).permute(0, 2, 1, 3, 4),
'observation.state':
torch.stack(list(cond_obs_queues['observation.state']), dim=1),
'action':
torch.stack(list(cond_obs_queues['action']), dim=1),
}
observation = {k: v.to(device, non_blocking=True) for k, v in observation.items()}
# ⑤ synth_world_model
with CudaTimer("synth_world_model", rec):
pred_videos_1, _, pred_states = profiled_synthesis(
model, "", observation, noise_shape,
ddim_steps=args.ddim_steps, ddim_eta=args.ddim_eta,
unconditional_guidance_scale=args.unconditional_guidance_scale,
fs=model_input_fs, text_input=False,
timestep_spacing=args.timestep_spacing,
guidance_rescale=args.guidance_rescale,
sim_mode=True, decode_video=True,
records=rec, prefix="wm")
# ⑥ update_obs_queue
with WallTimer("update_obs_queue", rec):
for idx in range(args.exe_steps):
obs_u = {
'observation.images.top':
pred_videos_1[0][:, idx:idx + 1].permute(1, 0, 2, 3),
'observation.state':
pred_states[0][idx:idx + 1],
'action':
torch.zeros_like(pred_actions[0][-1:]),
}
obs_u['observation.state'][:, ori_state_dim:] = 0.0
cond_obs_queues = populate_queues(cond_obs_queues, obs_u)
# ⑦ tensorboard_log (simulate — no actual writer, measure make_grid cost)
with WallTimer("tensorboard_log", rec):
for vid in [pred_videos_0, pred_videos_1]:
if vid is not None and vid.dim() == 5:
v = vid.permute(2, 0, 1, 3, 4)
grids = [torchvision.utils.make_grid(f, nrow=1, padding=0) for f in v]
_ = torch.stack(grids, dim=0)
# ⑧ save_results
with WallTimer("save_results", rec):
if pred_videos_0 is not None:
save_results(pred_videos_0.cpu(),
os.path.join(tmp_dir, f"dm_{itr}.mp4"),
fps=args.save_fps)
save_results(pred_videos_1.cpu(),
os.path.join(tmp_dir, f"wm_{itr}.mp4"),
fps=args.save_fps)
# ⑨ cpu_transfer
with CudaTimer("cpu_transfer", rec):
_ = pred_videos_1[:, :, :args.exe_steps].cpu()
# ── itr_total end ──
itr_end.record()
torch.cuda.synchronize()
itr_total_ms = itr_start.elapsed_time(itr_end)
rec["itr_total"].append(itr_total_ms)
# Flatten: each stage has exactly one entry per itr
itr_rec = {k: v[0] for k, v in rec.items()}
all_records.append(itr_rec)
# Print live progress
print(f" itr {itr}: {itr_total_ms:.0f} ms total | "
f"policy={itr_rec.get('synth_policy', 0):.0f} | "
f"wm={itr_rec.get('synth_world_model', 0):.0f} | "
f"save={itr_rec.get('save_results', 0):.0f} | "
f"tb={itr_rec.get('tensorboard_log', 0):.0f}")
return all_records
# ──────────────────────────────────────────────────────────────────────
# Layer 1: Console report
# ──────────────────────────────────────────────────────────────────────
def print_iteration_report(all_records, warmup=1):
"""Print a structured table of per-stage timing across iterations."""
if len(all_records) <= warmup:
records = all_records
else:
records = all_records[warmup:]
print(f"\n(Skipping first {warmup} itr(s) as warmup)\n")
# Collect all stage keys in a stable order
all_keys = []
seen = set()
for rec in records:
for k in rec:
if k not in seen:
all_keys.append(k)
seen.add(k)
# Separate top-level stages from sub-stages
top_keys = [k for k in all_keys if '/' not in k]
sub_keys = [k for k in all_keys if '/' in k]
def _print_table(keys, title):
if not keys:
return
print("=" * 82)
print(title)
print("=" * 82)
print(f"{'Stage':<35} {'Mean(ms)':>10} {'Std':>8} {'Min':>10} {'Max':>10} {'%':>7}")
print("-" * 82)
total_mean = np.mean([rec.get("itr_total", 0) for rec in records])
for k in keys:
vals = [rec.get(k, 0) for rec in records]
mean = np.mean(vals)
std = np.std(vals)
mn = np.min(vals)
mx = np.max(vals)
pct = mean / total_mean * 100 if total_mean > 0 else 0
print(f"{k:<35} {mean:>10.1f} {std:>8.1f} {mn:>10.1f} {mx:>10.1f} {pct:>6.1f}%")
print("-" * 82)
print()
_print_table(top_keys, "TABLE 1: ITERATION-LEVEL BREAKDOWN")
_print_table(sub_keys, "TABLE 2: SYNTHESIS SUB-STAGE BREAKDOWN")
# ──────────────────────────────────────────────────────────────────────
# Layer 3: CSV output for A/B comparison
# ──────────────────────────────────────────────────────────────────────
def write_csv(all_records, csv_path, warmup=1):
"""Write per-iteration timing to CSV for later comparison."""
records = all_records[warmup:] if len(all_records) > warmup else all_records
# Collect all keys
all_keys = []
seen = set()
for rec in records:
for k in rec:
if k not in seen:
all_keys.append(k)
seen.add(k)
with open(csv_path, 'w', newline='') as f:
writer = csv.DictWriter(f, fieldnames=['itr'] + all_keys)
writer.writeheader()
for i, rec in enumerate(records):
row = {'itr': i}
row.update({k: f"{rec.get(k, 0):.2f}" for k in all_keys})
writer.writerow(row)
# Also write a summary row
summary_path = csv_path.replace('.csv', '_summary.csv')
with open(summary_path, 'w', newline='') as f:
writer = csv.DictWriter(f, fieldnames=['stat'] + all_keys)
writer.writeheader()
for stat_name, stat_fn in [('mean', np.mean), ('std', np.std),
('min', np.min), ('max', np.max)]:
row = {'stat': stat_name}
row.update({k: f"{stat_fn([r.get(k, 0) for r in records]):.2f}"
for k in all_keys})
writer.writerow(row)
print(f">>> CSV written to: {csv_path}")
print(f">>> Summary written to: {summary_path}")
def compare_csvs(path_a, path_b):
"""Compare two summary CSVs and print a diff table."""
df_a = pd.read_csv(path_a, index_col='stat')
df_b = pd.read_csv(path_b, index_col='stat')
# Use mean row for comparison
mean_a = df_a.loc['mean'].astype(float)
mean_b = df_b.loc['mean'].astype(float)
print("=" * 90)
print(f"A/B COMPARISON: {os.path.basename(path_a)} vs {os.path.basename(path_b)}")
print("=" * 90)
print(f"{'Stage':<35} {'A(ms)':>10} {'B(ms)':>10} {'Diff':>10} {'Speedup':>10}")
print("-" * 90)
for col in mean_a.index:
if col not in mean_b.index:
continue
a_val = mean_a[col]
b_val = mean_b[col]
diff = b_val - a_val
speedup = a_val / b_val if b_val > 0 else float('inf')
marker = " <<<" if abs(diff) > 50 else ""
print(f"{col:<35} {a_val:>10.1f} {b_val:>10.1f} {diff:>+10.1f} {speedup:>9.2f}x{marker}")
print("-" * 90)
total_a = mean_a.get('itr_total', 0)
total_b = mean_b.get('itr_total', 0)
print(f"{'itr_total':<35} {total_a:>10.1f} {total_b:>10.1f} "
f"{total_b - total_a:>+10.1f} {total_a / total_b if total_b > 0 else 0:>9.2f}x")
print()
# ──────────────────────────────────────────────────────────────────────
# Layer 2: GPU timeline trace wrapper
# ──────────────────────────────────────────────────────────────────────
def run_with_trace(model, args, config, noise_shape, device):
"""Run iterations under torch.profiler to generate Chrome/TensorBoard traces."""
trace_dir = args.trace_dir
os.makedirs(trace_dir, exist_ok=True)
# We need the same data setup as run_profiled_iterations
csv_path = os.path.join(args.prompt_dir, f"{args.dataset}.csv")
df = pd.read_csv(csv_path)
sample = df.iloc[0]
data_module = instantiate_from_config(config.data)
data_module.setup()
init_frame_path = get_init_frame_path(args.prompt_dir, sample)
ori_fps = float(sample['fps'])
fs = args.frame_stride
model_input_fs = ori_fps // fs
transition_path = get_transition_path(args.prompt_dir, sample)
with h5py.File(transition_path, 'r') as h5f:
transition_dict = {}
for key in h5f.keys():
transition_dict[key] = torch.tensor(h5f[key][()])
for key in h5f.attrs.keys():
transition_dict[key] = h5f.attrs[key]
batch, ori_state_dim, ori_action_dim = prepare_init_input(
0, init_frame_path, transition_dict, fs,
data_module.test_datasets[args.dataset],
n_obs_steps=model.n_obs_steps_imagen)
observation = {
'observation.images.top':
batch['observation.image'].permute(1, 0, 2, 3)[-1].unsqueeze(0),
'observation.state':
batch['observation.state'][-1].unsqueeze(0),
'action':
torch.zeros_like(batch['action'][-1]).unsqueeze(0),
}
observation = {k: v.to(device, non_blocking=True) for k, v in observation.items()}
cond_obs_queues = {
"observation.images.top": deque(maxlen=model.n_obs_steps_imagen),
"observation.state": deque(maxlen=model.n_obs_steps_imagen),
"action": deque(maxlen=args.video_length),
}
cond_obs_queues = populate_queues(cond_obs_queues, observation)
tmp_dir = os.path.join(args.savedir, "profile_tmp")
os.makedirs(tmp_dir, exist_ok=True)
prompt_text = sample['instruction']
# Total iterations: warmup + active
n_warmup = 1
n_active = min(args.n_iter, 2) # trace 2 active iterations max
n_total = n_warmup + n_active
print(f">>> GPU trace: {n_warmup} warmup + {n_active} active iterations")
print(f">>> Trace output: {trace_dir}")
with torch.no_grad(), torch.profiler.profile(
activities=[
torch.profiler.ProfilerActivity.CPU,
torch.profiler.ProfilerActivity.CUDA,
],
schedule=torch.profiler.schedule(
wait=0, warmup=n_warmup, active=n_active, repeat=1),
on_trace_ready=torch.profiler.tensorboard_trace_handler(trace_dir),
record_shapes=True,
with_stack=True,
) as prof:
for itr_idx in range(n_total):
phase = "warmup" if itr_idx < n_warmup else "active"
print(f" trace itr {itr_idx} ({phase})...")
# ── One full iteration (same logic as run_inference) ──
obs_loc = {
'observation.images.top':
torch.stack(list(cond_obs_queues['observation.images.top']),
dim=1).permute(0, 2, 1, 3, 4),
'observation.state':
torch.stack(list(cond_obs_queues['observation.state']), dim=1),
'action':
torch.stack(list(cond_obs_queues['action']), dim=1),
}
obs_loc = {k: v.to(device) for k, v in obs_loc.items()}
# Policy pass
dummy_rec = defaultdict(list)
pv0, pa, _ = profiled_synthesis(
model, prompt_text, obs_loc, noise_shape,
ddim_steps=args.ddim_steps, ddim_eta=args.ddim_eta,
unconditional_guidance_scale=args.unconditional_guidance_scale,
fs=model_input_fs, text_input=True,
timestep_spacing=args.timestep_spacing,
guidance_rescale=args.guidance_rescale,
sim_mode=False,
decode_video=not args.fast_policy_no_decode,
records=dummy_rec, prefix="policy")
for idx in range(len(pa[0])):
oa = {'action': pa[0][idx:idx + 1]}
oa['action'][:, ori_action_dim:] = 0.0
populate_queues(cond_obs_queues, oa)
# Re-stack for world model
obs_loc2 = {
'observation.images.top':
torch.stack(list(cond_obs_queues['observation.images.top']),
dim=1).permute(0, 2, 1, 3, 4),
'observation.state':
torch.stack(list(cond_obs_queues['observation.state']), dim=1),
'action':
torch.stack(list(cond_obs_queues['action']), dim=1),
}
obs_loc2 = {k: v.to(device) for k, v in obs_loc2.items()}
# World model pass
pv1, _, ps = profiled_synthesis(
model, "", obs_loc2, noise_shape,
ddim_steps=args.ddim_steps, ddim_eta=args.ddim_eta,
unconditional_guidance_scale=args.unconditional_guidance_scale,
fs=model_input_fs, text_input=False,
timestep_spacing=args.timestep_spacing,
guidance_rescale=args.guidance_rescale,
sim_mode=True, decode_video=True,
records=dummy_rec, prefix="wm")
# Update obs queue
for idx in range(args.exe_steps):
ou = {
'observation.images.top':
pv1[0][:, idx:idx + 1].permute(1, 0, 2, 3),
'observation.state': ps[0][idx:idx + 1],
'action': torch.zeros_like(pa[0][-1:]),
}
ou['observation.state'][:, ori_state_dim:] = 0.0
populate_queues(cond_obs_queues, ou)
# Save results (captures CPU stall in trace)
if pv0 is not None:
save_results(pv0.cpu(),
os.path.join(tmp_dir, f"trace_dm_{itr_idx}.mp4"),
fps=args.save_fps)
save_results(pv1.cpu(),
os.path.join(tmp_dir, f"trace_wm_{itr_idx}.mp4"),
fps=args.save_fps)
prof.step()
print(f">>> Trace saved to {trace_dir}")
print(" View with: tensorboard --logdir", trace_dir)
print(" Or open the .json file in chrome://tracing")
# ──────────────────────────────────────────────────────────────────────
# Argument parser
# ──────────────────────────────────────────────────────────────────────
def get_parser():
p = argparse.ArgumentParser(description="Profile full iteration loop")
# Compare mode (no model needed)
p.add_argument("--compare", nargs=2, metavar=("A_SUMMARY", "B_SUMMARY"),
help="Compare two summary CSVs and exit")
# Model / data
p.add_argument("--ckpt_path", type=str, default=None)
p.add_argument("--config", type=str, default=None)
p.add_argument("--prompt_dir", type=str, default=None)
p.add_argument("--dataset", type=str, default=None)
p.add_argument("--savedir", type=str, default="profile_output")
# Inference params (match world_model_interaction.py)
p.add_argument("--ddim_steps", type=int, default=50)
p.add_argument("--ddim_eta", type=float, default=1.0)
p.add_argument("--bs", type=int, default=1)
p.add_argument("--height", type=int, default=320)
p.add_argument("--width", type=int, default=512)
p.add_argument("--frame_stride", type=int, default=4)
p.add_argument("--unconditional_guidance_scale", type=float, default=1.0)
p.add_argument("--video_length", type=int, default=16)
p.add_argument("--timestep_spacing", type=str, default="uniform_trailing")
p.add_argument("--guidance_rescale", type=float, default=0.7)
p.add_argument("--exe_steps", type=int, default=16)
p.add_argument("--n_iter", type=int, default=5)
p.add_argument("--save_fps", type=int, default=8)
p.add_argument("--seed", type=int, default=123)
p.add_argument("--perframe_ae", action='store_true', default=False)
p.add_argument("--vae_dtype", type=str, choices=["fp32", "bf16"], default="bf16")
p.add_argument("--fast_policy_no_decode", action='store_true', default=False)
# Profiling control
p.add_argument("--warmup", type=int, default=1,
help="Number of warmup iterations to skip in statistics")
p.add_argument("--csv", type=str, default=None,
help="Write per-iteration timing to this CSV file")
p.add_argument("--trace", action='store_true', default=False,
help="Enable Layer 2: GPU timeline trace")
p.add_argument("--trace_dir", type=str, default="./profile_traces",
help="Directory for trace output")
return p
# ──────────────────────────────────────────────────────────────────────
# Main
# ──────────────────────────────────────────────────────────────────────
def main():
patch_norm_bypass_autocast()
parser = get_parser()
args = parser.parse_args()
# ── Compare mode: no model needed ──
if args.compare:
compare_csvs(args.compare[0], args.compare[1])
return
# ── Validate required args ──
for required in ['ckpt_path', 'config', 'prompt_dir', 'dataset']:
if getattr(args, required) is None:
parser.error(f"--{required} is required for profiling mode")
seed_everything(args.seed)
os.makedirs(args.savedir, exist_ok=True)
# ── Load model ──
print("=" * 60)
print("PROFILE ITERATION — Loading model...")
print("=" * 60)
model, config = load_model(args)
device = next(model.parameters()).device
h, w = args.height // 8, args.width // 8
channels = model.model.diffusion_model.out_channels
noise_shape = [args.bs, channels, args.video_length, h, w]
print(f">>> Noise shape: {noise_shape}")
print(f">>> DDIM steps: {args.ddim_steps}")
print(f">>> fast_policy_no_decode: {args.fast_policy_no_decode}")
# ── Layer 2: GPU trace (optional) ──
if args.trace:
with torch.no_grad():
run_with_trace(model, args, config, noise_shape, device)
print()
# ── Layer 1: Iteration-level breakdown ──
print("=" * 60)
print("LAYER 1: ITERATION-LEVEL PROFILING")
print("=" * 60)
with torch.no_grad():
all_records = run_profiled_iterations(
model, args, config, noise_shape, device)
# Print report
print_iteration_report(all_records, warmup=args.warmup)
# ── Layer 3: CSV output for A/B comparison ──
if args.csv:
write_csv(all_records, args.csv, warmup=args.warmup)
print("Done.")
if __name__ == '__main__':
main()

View File

@@ -0,0 +1,733 @@
"""
Profile the full inference pipeline of the world model, covering all 7 stages:
1. Image Embedding
2. VAE Encode
3. Text Conditioning
4. State/Action Projectors
5. DDIM Loop
6. VAE Decode
7. Post-process
Reports stage-level timing, UNet sub-module breakdown, memory summary,
and throughput analysis.
TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL=1 CUDA_VISIBLE_DEVICES=0 python scripts/evaluation/profile_pipeline.py --ckpt_path ckpts/unifolm_wma_dual_mix_bf16.ckpt --config configs/inference/world_model_interaction.yaml --ddim_steps 50 --cfg_scale 1.0 --n_runs 3 --deep
Usage:
TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL=1 CUDA_VISIBLE_DEVICES=0 python scripts/evaluation/profile_pipeline.py --ckpt_path ckpts/unifolm_wma_dual_mix_bf16.ckpt --config configs/inference/world_model_interaction.yaml --ddim_steps 50 --cfg_scale 1.0 --n_runs 3
"""
import argparse
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint # must be loaded before unifolm_wma.utils.common
from contextlib import nullcontext, contextmanager
from collections import defaultdict
from omegaconf import OmegaConf
from einops import rearrange, repeat
from unifolm_wma.utils.utils import instantiate_from_config
from unifolm_wma.models.samplers.ddim import DDIMSampler
from unifolm_wma.modules.attention import (
SpatialTransformer, TemporalTransformer,
BasicTransformerBlock, CrossAttention, FeedForward,
)
from unifolm_wma.modules.networks.wma_model import ResBlock
from unifolm_wma.models.diffusion_head.conditional_unet1d import ConditionalUnet1D
# --- W7900D theoretical peak ---
PEAK_BF16_TFLOPS = 61.0
MEM_BW_GBS = 864.0
# ---------------------------------------------------------------------------
# Utility: patch norms to bypass autocast fp32 promotion
# ---------------------------------------------------------------------------
def patch_norm_bypass_autocast():
"""Monkey-patch GroupNorm and LayerNorm to bypass autocast's fp32 policy."""
def _group_norm_forward(self, x):
with torch.amp.autocast('cuda', enabled=False):
return F.group_norm(
x, self.num_groups,
self.weight.to(x.dtype) if self.weight is not None else None,
self.bias.to(x.dtype) if self.bias is not None else None,
self.eps)
def _layer_norm_forward(self, x):
with torch.amp.autocast('cuda', enabled=False):
return F.layer_norm(
x, self.normalized_shape,
self.weight.to(x.dtype) if self.weight is not None else None,
self.bias.to(x.dtype) if self.bias is not None else None,
self.eps)
torch.nn.GroupNorm.forward = _group_norm_forward
torch.nn.LayerNorm.forward = _layer_norm_forward
# ---------------------------------------------------------------------------
# Utility: torch.compile hot ResBlocks
# ---------------------------------------------------------------------------
def apply_torch_compile(model, hot_indices=(5, 8, 9)):
"""Compile ResBlock._forward in the hottest output_blocks for operator fusion."""
unet = model.model.diffusion_model
compiled = 0
for idx in hot_indices:
block = unet.output_blocks[idx]
for layer in block:
if isinstance(layer, ResBlock):
layer._forward = torch.compile(layer._forward, mode="default")
compiled += 1
print(f" torch.compile: {compiled} ResBlocks in output_blocks{list(hot_indices)}")
# ---------------------------------------------------------------------------
# Model loading
# ---------------------------------------------------------------------------
def load_model(args):
config = OmegaConf.load(args.config)
config['model']['params']['wma_config']['params']['use_checkpoint'] = False
model = instantiate_from_config(config.model)
state_dict = torch.load(args.ckpt_path, map_location="cpu")
if "state_dict" in state_dict:
state_dict = state_dict["state_dict"]
model.load_state_dict(state_dict, strict=True)
model.eval()
model.model.to(torch.bfloat16)
model.diffusion_autocast_dtype = torch.bfloat16
apply_torch_compile(model)
model = model.cuda()
return model
# ---------------------------------------------------------------------------
# CudaTimer — precise GPU timing via CUDA events
# ---------------------------------------------------------------------------
class CudaTimer:
"""Context manager for GPU-precise stage timing using CUDA events."""
def __init__(self, name, records):
self.name = name
self.records = records
self.start = torch.cuda.Event(enable_timing=True)
self.end = torch.cuda.Event(enable_timing=True)
def __enter__(self):
torch.cuda.synchronize()
self.start.record()
return self
def __exit__(self, *args):
self.end.record()
torch.cuda.synchronize()
elapsed = self.start.elapsed_time(self.end)
self.records[self.name].append(elapsed)
# ---------------------------------------------------------------------------
# HookProfiler — sub-module level timing inside UNet via hooks
# ---------------------------------------------------------------------------
class HookProfiler:
"""Register forward hooks on UNet sub-modules to collect per-call timing."""
# Coarse-grained targets (original)
COARSE_CLASSES = (
SpatialTransformer,
TemporalTransformer,
ResBlock,
ConditionalUnet1D,
)
# Fine-grained targets for deep DDIM analysis
FINE_CLASSES = (
CrossAttention,
FeedForward,
)
def __init__(self, unet, deep=False):
self.unet = unet
self.deep = deep
self.handles = []
# per-instance data: {instance_id: [(start_event, end_event), ...]}
self._events = defaultdict(list)
# tag mapping: {instance_id: (class_name, module_name)}
self._tags = {}
# block location: {instance_id: block_location_str}
self._block_loc = {}
@staticmethod
def _get_block_location(name):
"""Derive UNet block location from module name, e.g. 'input_blocks.3.1'."""
parts = name.split('.')
if len(parts) >= 2 and parts[0] == 'input_blocks':
return f"input_blocks.{parts[1]}"
elif len(parts) >= 1 and parts[0] == 'middle_block':
return "middle_block"
elif len(parts) >= 2 and parts[0] == 'output_blocks':
return f"output_blocks.{parts[1]}"
elif 'action_unet' in name:
return "action_unet"
elif 'state_unet' in name:
return "state_unet"
elif name == 'out' or name.startswith('out.'):
return "out"
return "other"
def register(self):
"""Attach pre/post forward hooks to target sub-modules + unet.out."""
target_classes = self.COARSE_CLASSES
if self.deep:
target_classes = target_classes + self.FINE_CLASSES
for name, mod in self.unet.named_modules():
if isinstance(mod, target_classes):
tag = type(mod).__name__
inst_id = id(mod)
self._tags[inst_id] = (tag, name)
self._block_loc[inst_id] = self._get_block_location(name)
self.handles.append(
mod.register_forward_pre_hook(self._make_pre_hook(inst_id)))
self.handles.append(
mod.register_forward_hook(self._make_post_hook(inst_id)))
# Also hook unet.out (nn.Sequential)
out_mod = self.unet.out
inst_id = id(out_mod)
self._tags[inst_id] = ("UNet.out", "out")
self._block_loc[inst_id] = "out"
self.handles.append(
out_mod.register_forward_pre_hook(self._make_pre_hook(inst_id)))
self.handles.append(
out_mod.register_forward_hook(self._make_post_hook(inst_id)))
def _make_pre_hook(self, inst_id):
events = self._events
def hook(module, input):
start = torch.cuda.Event(enable_timing=True)
start.record()
events[inst_id].append([start, None])
return hook
def _make_post_hook(self, inst_id):
events = self._events
def hook(module, input, output):
end = torch.cuda.Event(enable_timing=True)
end.record()
events[inst_id][-1][1] = end
return hook
def reset(self):
"""Clear collected events for a fresh run."""
self._events.clear()
def synchronize_and_collect(self):
"""Sync GPU and compute elapsed times. Returns (by_type, by_instance, by_block)."""
torch.cuda.synchronize()
by_type = defaultdict(lambda: {"total_ms": 0.0, "count": 0, "calls": []})
by_instance = {}
# by_block: {block_loc: {tag: {"total_ms", "count"}}}
by_block = defaultdict(lambda: defaultdict(lambda: {"total_ms": 0.0, "count": 0}))
for inst_id, pairs in self._events.items():
tag, mod_name = self._tags[inst_id]
block_loc = self._block_loc.get(inst_id, "other")
inst_times = []
for start_evt, end_evt in pairs:
if end_evt is not None:
ms = start_evt.elapsed_time(end_evt)
inst_times.append(ms)
by_type[tag]["total_ms"] += ms
by_type[tag]["count"] += 1
by_type[tag]["calls"].append(ms)
by_block[block_loc][tag]["total_ms"] += ms
by_block[block_loc][tag]["count"] += 1
by_instance[(tag, mod_name)] = inst_times
return dict(by_type), by_instance, dict(by_block)
def remove(self):
"""Remove all hooks."""
for h in self.handles:
h.remove()
self.handles.clear()
# ---------------------------------------------------------------------------
# Build dummy inputs matching the pipeline's expected shapes
# ---------------------------------------------------------------------------
def build_dummy_inputs(model, noise_shape):
"""Create synthetic observation dict and prompts for profiling."""
device = next(model.parameters()).device
B, C, T, H, W = noise_shape
dtype = torch.bfloat16
# observation.images.top: [B, C, O, H, W] (permuted to [B,O,C,H,W] inside pipeline)
O = 2
obs_images = torch.randn(B, 3, O, 320, 512, device=device, dtype=dtype)
obs_state = torch.randn(B, O, 16, device=device, dtype=dtype)
action = torch.randn(B, 16, 16, device=device, dtype=dtype)
observation = {
'observation.images.top': obs_images,
'observation.state': obs_state,
'action': action,
}
prompts = ["a robot arm performing a task"] * B
return observation, prompts
# ---------------------------------------------------------------------------
# Run one full pipeline pass with per-stage timing
# ---------------------------------------------------------------------------
def run_pipeline(model, observation, prompts, noise_shape, ddim_steps,
cfg_scale, hook_profiler):
"""Execute the full 7-stage pipeline, returning per-stage timing dict."""
records = defaultdict(list)
device = next(model.parameters()).device
B, C, T, H, W = noise_shape
dtype = torch.bfloat16
fs = torch.tensor([1] * B, dtype=torch.long, device=device)
# --- Stage 1: Image Embedding ---
with CudaTimer("1_Image_Embedding", records):
img = observation['observation.images.top'].permute(0, 2, 1, 3, 4)
cond_img = rearrange(img, 'b o c h w -> (b o) c h w')[-1:].to(dtype=dtype)
with torch.autocast('cuda', dtype=torch.bfloat16):
cond_img_emb = model.embedder(cond_img)
cond_img_emb = model.image_proj_model(cond_img_emb)
# --- Stage 2: VAE Encode ---
with CudaTimer("2_VAE_Encode", records):
videos = img.permute(0, 2, 1, 3, 4) # [B, C, O, H, W]
b_v, c_v, t_v, h_v, w_v = videos.shape
vae_dtype = next(model.first_stage_model.parameters()).dtype
x_vae = rearrange(videos, 'b c t h w -> (b t) c h w').to(dtype=vae_dtype)
z = model.encode_first_stage(x_vae)
z = rearrange(z, '(b t) c h w -> b c t h w', b=b_v, t=t_v)
img_cat_cond = z[:, :, -1:, :, :]
img_cat_cond = repeat(img_cat_cond,
'b c t h w -> b c (repeat t) h w', repeat=T)
cond = {"c_concat": [img_cat_cond]}
vae_enc_input_bytes = x_vae.nelement() * x_vae.element_size()
vae_enc_output_bytes = z.nelement() * z.element_size()
# --- Stage 3: Text Conditioning ---
with CudaTimer("3_Text_Conditioning", records):
cond_ins_emb = model.get_learned_conditioning(prompts)
# --- Stage 4: State/Action Projectors ---
with CudaTimer("4_Projectors", records):
projector_dtype = next(model.state_projector.parameters()).dtype
with torch.autocast('cuda', dtype=torch.bfloat16):
cond_state_emb = model.state_projector(
observation['observation.state'].to(dtype=projector_dtype))
cond_state_emb = cond_state_emb + model.agent_state_pos_emb
cond_action_emb = model.action_projector(
observation['action'].to(dtype=projector_dtype))
cond_action_emb = cond_action_emb + model.agent_action_pos_emb
# Assemble cross-attention conditioning
cond["c_crossattn"] = [
torch.cat([cond_state_emb, cond_action_emb, cond_ins_emb, cond_img_emb],
dim=1)
]
n_obs_acting = getattr(model, 'n_obs_steps_acting', 2)
cond["c_crossattn_action"] = [
observation['observation.images.top'][:, :, -n_obs_acting:],
observation['observation.state'][:, -n_obs_acting:],
True, # sim_mode
False,
]
# CFG: build unconditional conditioning if needed
uc = None
if cfg_scale != 1.0:
uc_crossattn = torch.zeros_like(cond["c_crossattn"][0])
uc = {
"c_concat": cond["c_concat"],
"c_crossattn": [uc_crossattn],
"c_crossattn_action": cond["c_crossattn_action"],
}
# --- Stage 5: DDIM Loop ---
ddim_sampler = DDIMSampler(model)
hook_profiler.reset()
with CudaTimer("5_DDIM_Loop", records):
with torch.autocast('cuda', dtype=torch.bfloat16):
samples, actions, states, _ = ddim_sampler.sample(
S=ddim_steps,
conditioning=cond,
batch_size=B,
shape=noise_shape[1:],
verbose=False,
unconditional_guidance_scale=cfg_scale,
unconditional_conditioning=uc,
eta=1.0,
cfg_img=None,
mask=None,
x0=None,
fs=fs,
timestep_spacing='uniform',
guidance_rescale=0.0,
unconditional_conditioning_img_nonetext=None,
)
hook_by_type, hook_by_instance, hook_by_block = hook_profiler.synchronize_and_collect()
# --- Stage 6: VAE Decode ---
with CudaTimer("6_VAE_Decode", records):
batch_images = model.decode_first_stage(samples)
vae_dec_input_bytes = samples.nelement() * samples.element_size()
vae_dec_output_bytes = batch_images.nelement() * batch_images.element_size()
# --- Stage 7: Post-process ---
with CudaTimer("7_Post_Process", records):
batch_images_cpu = batch_images.cpu()
actions_cpu = actions.cpu()
states_cpu = states.cpu()
# Simulate video save overhead: clamp + uint8 conversion
_ = (batch_images_cpu.clamp(-1, 1) * 127.5 + 127.5).to(torch.uint8)
# Flatten single-element lists
stage_times = {k: v[0] for k, v in records.items()}
bandwidth_info = {
"vae_enc_input_bytes": vae_enc_input_bytes,
"vae_enc_output_bytes": vae_enc_output_bytes,
"vae_dec_input_bytes": vae_dec_input_bytes,
"vae_dec_output_bytes": vae_dec_output_bytes,
}
return stage_times, hook_by_type, hook_by_instance, hook_by_block, bandwidth_info
# ---------------------------------------------------------------------------
# Reporting
# ---------------------------------------------------------------------------
def print_stage_timing(all_runs_stages):
"""Table 1: Stage Timing — name | mean(ms) | std | percent."""
import numpy as np
stage_names = list(all_runs_stages[0].keys())
means = {}
stds = {}
for name in stage_names:
vals = [run[name] for run in all_runs_stages]
means[name] = np.mean(vals)
stds[name] = np.std(vals)
total = sum(means.values())
print()
print("=" * 72)
print("TABLE 1: STAGE TIMING")
print("=" * 72)
print(f"{'Stage':<25} {'Mean(ms)':>10} {'Std':>10} {'%':>8}")
print("-" * 72)
for name in stage_names:
pct = means[name] / total * 100 if total > 0 else 0
print(f"{name:<25} {means[name]:>10.1f} {stds[name]:>10.2f} {pct:>7.1f}%")
print("-" * 72)
print(f"{'TOTAL':<25} {total:>10.1f}")
print()
def print_unet_breakdown(all_runs_hooks):
"""Table 2: UNet Sub-Module Breakdown — type | total | count | per-call | percent."""
import numpy as np
# Aggregate across runs
agg = defaultdict(lambda: {"totals": [], "counts": []})
for hook_by_type in all_runs_hooks:
for tag, data in hook_by_type.items():
agg[tag]["totals"].append(data["total_ms"])
agg[tag]["counts"].append(data["count"])
print("=" * 80)
print("TABLE 2: UNET SUB-MODULE BREAKDOWN")
print("=" * 80)
print(f"{'Module Type':<25} {'Total(ms)':>10} {'Count':>7} {'Per-call':>10} {'%':>8}")
print("-" * 80)
grand_total = 0
rows = []
for tag, d in agg.items():
mean_total = np.mean(d["totals"])
mean_count = np.mean(d["counts"])
per_call = mean_total / mean_count if mean_count > 0 else 0
grand_total += mean_total
rows.append((tag, mean_total, mean_count, per_call))
rows.sort(key=lambda r: r[1], reverse=True)
for tag, mean_total, mean_count, per_call in rows:
pct = mean_total / grand_total * 100 if grand_total > 0 else 0
print(f"{tag:<25} {mean_total:>10.1f} {int(mean_count):>7} {per_call:>10.2f} {pct:>7.1f}%")
print("-" * 80)
print(f"{'TOTAL (hooked)':<25} {grand_total:>10.1f}")
print()
def print_block_timing(all_runs_blocks):
"""Table 2b: Per-UNet-block timing — which blocks are hottest."""
import numpy as np
# Aggregate: {block_loc: {tag: [total_ms_per_run, ...]}}
agg = defaultdict(lambda: defaultdict(list))
for by_block in all_runs_blocks:
for block_loc, tag_dict in by_block.items():
for tag, data in tag_dict.items():
agg[block_loc][tag].append(data["total_ms"])
# Compute per-block totals
block_totals = {}
for block_loc, tag_dict in agg.items():
block_totals[block_loc] = sum(np.mean(v) for v in tag_dict.values())
grand_total = sum(block_totals.values())
# Sort blocks in logical order
def block_sort_key(name):
if name.startswith("input_blocks."):
return (0, int(name.split('.')[1]))
elif name == "middle_block":
return (1, 0)
elif name.startswith("output_blocks."):
return (2, int(name.split('.')[1]))
elif name == "out":
return (3, 0)
elif name == "action_unet":
return (4, 0)
elif name == "state_unet":
return (5, 0)
return (9, 0)
sorted_blocks = sorted(block_totals.keys(), key=block_sort_key)
print("=" * 90)
print("TABLE 2b: PER-UNET-BLOCK TIMING (coarse modules, per DDIM loop)")
print("=" * 90)
print(f"{'Block':<22} {'Total(ms)':>10} {'%':>7} Breakdown")
print("-" * 90)
for block_loc in sorted_blocks:
total = block_totals[block_loc]
pct = total / grand_total * 100 if grand_total > 0 else 0
# Build breakdown string
parts = []
for tag, vals in sorted(agg[block_loc].items(),
key=lambda x: np.mean(x[1]), reverse=True):
parts.append(f"{tag}={np.mean(vals):.0f}")
breakdown = ", ".join(parts)
print(f"{block_loc:<22} {total:>10.1f} {pct:>6.1f}% {breakdown}")
print("-" * 90)
print(f"{'TOTAL':<22} {grand_total:>10.1f}")
print()
def print_attn_ff_breakdown(all_runs_hooks):
"""Table 2c: CrossAttention vs FeedForward breakdown (--deep mode)."""
import numpy as np
agg = defaultdict(list)
for hook_by_type in all_runs_hooks:
for tag, data in hook_by_type.items():
if tag in ("CrossAttention", "FeedForward"):
agg[tag].append(data["total_ms"])
if not agg:
return
print("=" * 70)
print("TABLE 2c: ATTENTION vs FEEDFORWARD (deep hooks)")
print("=" * 70)
print(f"{'Component':<25} {'Total(ms)':>10} {'%':>8}")
print("-" * 70)
grand = 0
rows = []
for tag in ("CrossAttention", "FeedForward"):
if tag in agg:
mean_t = np.mean(agg[tag])
grand += mean_t
rows.append((tag, mean_t))
for tag, mean_t in rows:
pct = mean_t / grand * 100 if grand > 0 else 0
print(f"{tag:<25} {mean_t:>10.1f} {pct:>7.1f}%")
print("-" * 70)
print(f"{'TOTAL (attn+ff)':<25} {grand:>10.1f}")
print()
def print_unet_detailed(all_runs_instances):
"""Print per-instance UNet sub-module detail (--detailed mode)."""
import numpy as np
# Use last run's data
by_instance = all_runs_instances[-1]
print("=" * 100)
print("DETAILED: PER-INSTANCE UNET SUB-MODULE TIMING (last run)")
print("=" * 100)
print(f"{'Type':<22} {'Module Name':<45} {'Calls':>6} {'Total(ms)':>10} {'Mean(ms)':>10}")
print("-" * 100)
rows = []
for (tag, mod_name), times in by_instance.items():
if len(times) == 0:
continue
total = sum(times)
mean = np.mean(times)
rows.append((tag, mod_name, len(times), total, mean))
rows.sort(key=lambda r: r[3], reverse=True)
for tag, mod_name, count, total, mean in rows:
short_name = mod_name[-42:] if len(mod_name) > 42 else mod_name
print(f"{tag:<22} {short_name:<45} {count:>6} {total:>10.2f} {mean:>10.3f}")
print()
def print_memory_summary(mem_before, mem_peak):
"""Table 3: Memory Summary."""
delta = mem_peak - mem_before
print("=" * 50)
print("TABLE 3: MEMORY SUMMARY")
print("=" * 50)
print(f" Initial allocated: {mem_before / 1e9:.2f} GB")
print(f" Peak allocated: {mem_peak / 1e9:.2f} GB")
print(f" Delta (pipeline): {delta / 1e9:.2f} GB")
print()
def print_throughput(all_runs_stages, all_bw, ddim_steps, cfg_scale):
"""Table 4: Throughput — total latency, per-step, per-UNet-forward, VAE bandwidth."""
import numpy as np
n_runs = len(all_runs_stages)
# Total latency
totals = []
for run in all_runs_stages:
totals.append(sum(run.values()))
mean_total = np.mean(totals)
# DDIM loop time
ddim_times = [run["5_DDIM_Loop"] for run in all_runs_stages]
mean_ddim = np.mean(ddim_times)
unet_calls = ddim_steps if cfg_scale == 1.0 else ddim_steps * 2
per_step = mean_ddim / ddim_steps
per_unet = mean_ddim / unet_calls
# VAE bandwidth
mean_enc_time = np.mean([run["2_VAE_Encode"] for run in all_runs_stages])
mean_dec_time = np.mean([run["6_VAE_Decode"] for run in all_runs_stages])
bw = all_bw[-1] # use last run's byte counts
enc_bytes = bw["vae_enc_input_bytes"] + bw["vae_enc_output_bytes"]
dec_bytes = bw["vae_dec_input_bytes"] + bw["vae_dec_output_bytes"]
enc_bw = enc_bytes / (mean_enc_time / 1000) / 1e9 if mean_enc_time > 0 else 0
dec_bw = dec_bytes / (mean_dec_time / 1000) / 1e9 if mean_dec_time > 0 else 0
print("=" * 60)
print("TABLE 4: THROUGHPUT")
print("=" * 60)
print(f" Total pipeline latency: {mean_total:.1f} ms")
print(f" DDIM loop latency: {mean_ddim:.1f} ms")
print(f" DDIM steps: {ddim_steps}")
print(f" CFG scale: {cfg_scale} ({'2x UNet/step' if cfg_scale != 1.0 else '1x UNet/step'})")
print(f" UNet forward calls: {unet_calls}")
print(f" Per DDIM step: {per_step:.1f} ms")
print(f" Per UNet forward: {per_unet:.1f} ms")
print(f" VAE encode bandwidth: {enc_bw:.1f} GB/s (peak HBM: {MEM_BW_GBS} GB/s)")
print(f" VAE decode bandwidth: {dec_bw:.1f} GB/s (peak HBM: {MEM_BW_GBS} GB/s)")
print(f" GPU BF16 peak: {PEAK_BF16_TFLOPS} TFLOPS")
print()
# ---------------------------------------------------------------------------
# Main
# ---------------------------------------------------------------------------
def main():
patch_norm_bypass_autocast()
parser = argparse.ArgumentParser(
description="Profile the full inference pipeline")
parser.add_argument("--ckpt_path", type=str, required=True)
parser.add_argument("--config", type=str, required=True)
parser.add_argument("--ddim_steps", type=int, default=50)
parser.add_argument("--cfg_scale", type=float, default=1.0)
parser.add_argument("--n_runs", type=int, default=3)
parser.add_argument("--warmup", type=int, default=1)
parser.add_argument("--detailed", action="store_true",
help="Print per-instance UNet sub-module detail")
parser.add_argument("--deep", action="store_true",
help="Enable deep DDIM analysis: per-block, attn vs ff")
args = parser.parse_args()
noise_shape = [1, 4, 16, 40, 64]
# --- Load model ---
print("Loading model...")
model = load_model(args)
observation, prompts = build_dummy_inputs(model, noise_shape)
# --- Setup hook profiler ---
unet = model.model.diffusion_model
hook_profiler = HookProfiler(unet, deep=args.deep)
hook_profiler.register()
print(f"Registered hooks on {len(hook_profiler.handles)} sub-modules")
# --- Warmup ---
print(f"Warmup: {args.warmup} run(s)...")
with torch.no_grad():
for i in range(args.warmup):
run_pipeline(model, observation, prompts, noise_shape,
args.ddim_steps, args.cfg_scale, hook_profiler)
print(f" warmup {i+1}/{args.warmup} done")
# --- Measurement runs ---
print(f"Measuring: {args.n_runs} run(s)...")
torch.cuda.reset_peak_memory_stats()
mem_before = torch.cuda.memory_allocated()
all_stages = []
all_hooks = []
all_instances = []
all_blocks = []
all_bw = []
with torch.no_grad():
for i in range(args.n_runs):
stage_times, hook_by_type, hook_by_instance, hook_by_block, bw = run_pipeline(
model, observation, prompts, noise_shape,
args.ddim_steps, args.cfg_scale, hook_profiler)
all_stages.append(stage_times)
all_hooks.append(hook_by_type)
all_instances.append(hook_by_instance)
all_blocks.append(hook_by_block)
all_bw.append(bw)
total = sum(stage_times.values())
print(f" run {i+1}/{args.n_runs}: {total:.1f} ms total")
mem_peak = torch.cuda.max_memory_allocated()
# --- Reports ---
print_stage_timing(all_stages)
print_unet_breakdown(all_hooks)
print_block_timing(all_blocks)
if args.deep:
print_attn_ff_breakdown(all_hooks)
if args.detailed:
print_unet_detailed(all_instances)
print_memory_summary(mem_before, mem_peak)
print_throughput(all_stages, all_bw, args.ddim_steps, args.cfg_scale)
hook_profiler.remove()
print("Done.")
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,287 @@
"""
Profile one DDIM sampling iteration to capture all matmul/attention ops,
their matrix sizes, wall time, and compute utilization.
Uses torch.profiler for CUDA timing and FlopCounterMode for accurate
FLOPS counting (works on ROCm where Tensile kernels don't report FLOPS).
Usage:
TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL=1 CUDA_VISIBLE_DEVICES=0 \
python scripts/evaluation/profile_unet.py \
--ckpt_path ckpts/unifolm_wma_dual_mix_bf16.ckpt \
--config configs/inference/world_model_interaction.yaml
"""
import argparse
import torch
import torch.nn as nn
from collections import OrderedDict, defaultdict
from omegaconf import OmegaConf
from torch.utils.flop_counter import FlopCounterMode
from unifolm_wma.utils.utils import instantiate_from_config
import torch.nn.functional as F
def patch_norm_bypass_autocast():
"""Monkey-patch GroupNorm and LayerNorm to bypass autocast's fp32 policy."""
def _group_norm_forward(self, x):
with torch.amp.autocast('cuda', enabled=False):
return F.group_norm(
x, self.num_groups,
self.weight.to(x.dtype) if self.weight is not None else None,
self.bias.to(x.dtype) if self.bias is not None else None,
self.eps)
def _layer_norm_forward(self, x):
with torch.amp.autocast('cuda', enabled=False):
return F.layer_norm(
x, self.normalized_shape,
self.weight.to(x.dtype) if self.weight is not None else None,
self.bias.to(x.dtype) if self.bias is not None else None,
self.eps)
torch.nn.GroupNorm.forward = _group_norm_forward
torch.nn.LayerNorm.forward = _layer_norm_forward
# --- W7900D theoretical peak (TFLOPS) ---
PEAK_BF16_TFLOPS = 61.0
PEAK_FP32_TFLOPS = 30.5
def apply_torch_compile(model, hot_indices=(5, 8, 9)):
"""Compile ResBlock._forward in the hottest output_blocks for operator fusion."""
from unifolm_wma.modules.networks.wma_model import ResBlock
unet = model.model.diffusion_model
compiled = 0
for idx in hot_indices:
block = unet.output_blocks[idx]
for layer in block:
if isinstance(layer, ResBlock):
layer._forward = torch.compile(layer._forward, mode="default")
compiled += 1
print(f" ✓ torch.compile: {compiled} ResBlocks in output_blocks{list(hot_indices)}")
def load_model(args):
config = OmegaConf.load(args.config)
config['model']['params']['wma_config']['params']['use_checkpoint'] = False
model = instantiate_from_config(config.model)
state_dict = torch.load(args.ckpt_path, map_location="cpu")
if "state_dict" in state_dict:
state_dict = state_dict["state_dict"]
model.load_state_dict(state_dict, strict=True)
model.eval()
model.model.to(torch.bfloat16)
apply_torch_compile(model)
model = model.cuda()
return model
def build_call_kwargs(model, noise_shape):
"""Build dummy inputs matching the hybrid conditioning forward signature."""
device = next(model.parameters()).device
B, C, T, H, W = noise_shape # [1, 4, 16, 40, 64]
dtype = torch.bfloat16
x_action = torch.randn(B, 16, 16, device=device, dtype=dtype)
x_state = torch.randn(B, 16, 16, device=device, dtype=dtype)
timesteps = torch.tensor([500], device=device, dtype=torch.long)
context = torch.randn(B, 351, 1024, device=device, dtype=dtype)
obs_images = torch.randn(B, 3, 2, 320, 512, device=device, dtype=dtype)
obs_state = torch.randn(B, 2, 16, device=device, dtype=dtype)
context_action = [obs_images, obs_state, True, False]
fps = torch.tensor([1], device=device, dtype=torch.long)
x_raw = torch.randn(B, C, T, H, W, device=device, dtype=dtype)
c_concat = [torch.randn(B, C, T, H, W, device=device, dtype=dtype)]
return dict(
x=x_raw, x_action=x_action, x_state=x_state, t=timesteps,
c_concat=c_concat, c_crossattn=[context],
c_crossattn_action=context_action, s=fps,
)
def profile_one_step(model, noise_shape):
"""Run one UNet forward pass under torch.profiler for CUDA timing."""
diff_wrapper = model.model
call_kwargs = build_call_kwargs(model, noise_shape)
with torch.no_grad(), torch.autocast('cuda', dtype=torch.bfloat16):
# Warmup
for _ in range(2):
_ = diff_wrapper(**call_kwargs)
torch.cuda.synchronize()
with torch.profiler.profile(
activities=[torch.profiler.ProfilerActivity.CUDA],
record_shapes=True,
with_flops=True,
) as prof:
_ = diff_wrapper(**call_kwargs)
torch.cuda.synchronize()
return prof
def count_flops(model, noise_shape):
"""Run one UNet forward pass under FlopCounterMode for accurate FLOPS."""
diff_wrapper = model.model
call_kwargs = build_call_kwargs(model, noise_shape)
with torch.no_grad(), torch.autocast('cuda', dtype=torch.bfloat16):
flop_counter = FlopCounterMode(display=False)
with flop_counter:
_ = diff_wrapper(**call_kwargs)
torch.cuda.synchronize()
return flop_counter
def print_report(prof, flop_counter):
"""Parse profiler results and print a structured report with accurate FLOPS."""
events = prof.key_averages()
# --- Extract per-operator FLOPS from FlopCounterMode ---
# flop_counts is {module_name: {op_name: count}}; use only "Global" to avoid double-counting
flop_by_op = {}
flop_by_module = {}
if hasattr(flop_counter, 'flop_counts'):
# Per-op: only from top-level "Global" entry (no parent/child duplication)
global_ops = flop_counter.flop_counts.get("Global", {})
for op_name, flop_count in global_ops.items():
key = str(op_name).split('.')[-1]
flop_by_op[key] = flop_by_op.get(key, 0) + flop_count
# Per-module: collect all, skip "Global" and top-level wrapper duplicates
for module_name, op_dict in flop_counter.flop_counts.items():
module_total = sum(op_dict.values())
if module_total > 0:
flop_by_module[module_name] = module_total
total_counted_flops = flop_counter.get_total_flops()
# Collect matmul-like ops
matmul_ops = []
other_ops = []
for evt in events:
if evt.device_time_total <= 0:
continue
name = evt.key
is_matmul = any(k in name.lower() for k in
['mm', 'gemm', 'addmm', 'bmm', 'einsum', 'dot', 'linear'])
entry = {
'name': name,
'input_shapes': str(evt.input_shapes) if evt.input_shapes else '',
'cuda_time_ms': evt.device_time_total / 1000.0,
'count': evt.count,
'flops': evt.flops if evt.flops else 0,
}
if is_matmul:
matmul_ops.append(entry)
else:
other_ops.append(entry)
# Sort by CUDA time
matmul_ops.sort(key=lambda x: x['cuda_time_ms'], reverse=True)
other_ops.sort(key=lambda x: x['cuda_time_ms'], reverse=True)
total_cuda_ms = sum(e['cuda_time_ms'] for e in matmul_ops + other_ops)
total_matmul_ms = sum(e['cuda_time_ms'] for e in matmul_ops)
# --- Print matmul ops ---
print("=" * 130)
print("MATMUL / LINEAR OPS (sorted by CUDA time)")
print("=" * 130)
print(f"{'Op':>35} | {'Count':>5} | {'CUDA(ms)':>10} | Shapes")
print("-" * 130)
for op in matmul_ops:
shapes_str = op['input_shapes'][:60] if op['input_shapes'] else ''
print(f"{op['name']:>35} | {op['count']:>5} | {op['cuda_time_ms']:>10.3f} | {shapes_str}")
# --- Print top non-matmul ops ---
print()
print("=" * 130)
print("TOP NON-MATMUL OPS (sorted by CUDA time)")
print("=" * 130)
print(f"{'Op':>40} | {'Count':>5} | {'CUDA(ms)':>10} | Shapes")
print("-" * 130)
for op in other_ops[:20]:
shapes_str = op['input_shapes'][:60] if op['input_shapes'] else ''
print(f"{op['name']:>40} | {op['count']:>5} | {op['cuda_time_ms']:>10.3f} | {shapes_str}")
# --- FlopCounterMode per-operator breakdown ---
print()
print("=" * 130)
print("FLOPS BY ATen OPERATOR (FlopCounterMode)")
print("=" * 130)
print(f"{'ATen Op':>25} | {'GFLOPS':>12} | {'% of Total':>10}")
print("-" * 55)
sorted_flop_ops = sorted(flop_by_op.items(), key=lambda x: x[1], reverse=True)
for op_name, flops in sorted_flop_ops:
gflops = flops / 1e9
pct = flops / total_counted_flops * 100 if total_counted_flops > 0 else 0
print(f"{op_name:>25} | {gflops:>12.2f} | {pct:>9.1f}%")
# --- FlopCounterMode per-module breakdown ---
if flop_by_module:
print()
print("=" * 130)
print("FLOPS BY MODULE (FlopCounterMode)")
print("=" * 130)
print(f"{'Module':>60} | {'GFLOPS':>12} | {'% of Total':>10}")
print("-" * 90)
sorted_modules = sorted(flop_by_module.items(), key=lambda x: x[1], reverse=True)
for mod_name, flops in sorted_modules[:30]:
gflops = flops / 1e9
pct = flops / total_counted_flops * 100 if total_counted_flops > 0 else 0
name_str = mod_name[-57:] if len(mod_name) > 57 else mod_name
print(f"{name_str:>60} | {gflops:>12.2f} | {pct:>9.1f}%")
# --- Summary ---
print()
print("=" * 130)
print("SUMMARY")
print("=" * 130)
print(f" Total CUDA time: {total_cuda_ms:.1f} ms")
print(f" Matmul CUDA time: {total_matmul_ms:.1f} ms ({total_matmul_ms/total_cuda_ms*100:.1f}%)")
print(f" Non-matmul CUDA time: {total_cuda_ms - total_matmul_ms:.1f} ms ({(total_cuda_ms-total_matmul_ms)/total_cuda_ms*100:.1f}%)")
print(f" Total FLOPS (FlopCounter): {total_counted_flops/1e9:.2f} GFLOPS")
if total_matmul_ms > 0 and total_counted_flops > 0:
avg_tflops = total_counted_flops / (total_matmul_ms / 1000.0) / 1e12
avg_util = avg_tflops / PEAK_BF16_TFLOPS * 100
overall_tflops = total_counted_flops / (total_cuda_ms / 1000.0) / 1e12
overall_util = overall_tflops / PEAK_BF16_TFLOPS * 100
print(f" Matmul throughput: {avg_tflops:.2f} TFLOPS/s ({avg_util:.1f}% of BF16 peak)")
print(f" Overall throughput: {overall_tflops:.2f} TFLOPS/s ({overall_util:.1f}% of BF16 peak)")
print(f" GPU peak (BF16): {PEAK_BF16_TFLOPS} TFLOPS")
if __name__ == '__main__':
patch_norm_bypass_autocast()
parser = argparse.ArgumentParser()
parser.add_argument("--ckpt_path", type=str, required=True)
parser.add_argument("--config", type=str, required=True)
args = parser.parse_args()
print("Loading model...")
model = load_model(args)
noise_shape = [1, 4, 16, 40, 64]
print(f"Profiling UNet forward pass with shape {noise_shape}...")
prof = profile_one_step(model, noise_shape)
print("Counting FLOPS with FlopCounterMode...")
flop_counter = count_flops(model, noise_shape)
print_report(prof, flop_counter)

View File

@@ -19,9 +19,6 @@ from fastapi.responses import JSONResponse
from typing import Any, Dict, Optional, Tuple, List from typing import Any, Dict, Optional, Tuple, List
from datetime import datetime from datetime import datetime
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
from unifolm_wma.utils.utils import instantiate_from_config from unifolm_wma.utils.utils import instantiate_from_config
from unifolm_wma.models.samplers.ddim import DDIMSampler from unifolm_wma.models.samplers.ddim import DDIMSampler

View File

@@ -1,4 +1,7 @@
import argparse, os, glob import argparse, os, glob
from contextlib import nullcontext
import atexit
from concurrent.futures import ThreadPoolExecutor
import pandas as pd import pandas as pd
import random import random
import torch import torch
@@ -9,8 +12,8 @@ import logging
import einops import einops
import warnings import warnings
import imageio import imageio
import atexit
from concurrent.futures import ThreadPoolExecutor from typing import Optional, List, Any
from pytorch_lightning import seed_everything from pytorch_lightning import seed_everything
from omegaconf import OmegaConf from omegaconf import OmegaConf
@@ -20,16 +23,111 @@ from collections import OrderedDict
from torch import nn from torch import nn
from eval_utils import populate_queues from eval_utils import populate_queues
from collections import deque from collections import deque
from typing import Optional, List, Any
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
from torch import Tensor from torch import Tensor
from torch.utils.tensorboard import SummaryWriter from torch.utils.tensorboard import SummaryWriter
from PIL import Image from PIL import Image
from unifolm_wma.models.samplers.ddim import DDIMSampler from unifolm_wma.models.samplers.ddim import DDIMSampler
from unifolm_wma.utils.utils import instantiate_from_config from unifolm_wma.utils.utils import instantiate_from_config
import torch.nn.functional as F
# ========== Async I/O utilities ==========
_io_executor: Optional[ThreadPoolExecutor] = None
_io_futures: List[Any] = []
def _get_io_executor() -> ThreadPoolExecutor:
global _io_executor
if _io_executor is None:
_io_executor = ThreadPoolExecutor(max_workers=2)
return _io_executor
def _flush_io():
"""Wait for all pending async I/O to finish."""
global _io_futures
for fut in _io_futures:
try:
fut.result()
except Exception as e:
print(f">>> [async I/O] error: {e}")
_io_futures.clear()
atexit.register(_flush_io)
def _save_results_sync(video_cpu: Tensor, filename: str, fps: int) -> None:
"""Synchronous save on CPU tensor (runs in background thread)."""
video = torch.clamp(video_cpu.float(), -1., 1.)
n = video.shape[0]
video = video.permute(2, 0, 1, 3, 4)
frame_grids = [
torchvision.utils.make_grid(framesheet, nrow=int(n), padding=0)
for framesheet in video
]
grid = torch.stack(frame_grids, dim=0)
grid = (grid + 1.0) / 2.0
grid = (grid * 255).to(torch.uint8).permute(0, 2, 3, 1)
torchvision.io.write_video(filename,
grid,
fps=fps,
video_codec='h264',
options={'crf': '10'})
def save_results_async(video: Tensor, filename: str, fps: int = 8) -> None:
"""Submit video saving to background thread pool."""
video_cpu = video.detach().cpu()
fut = _get_io_executor().submit(_save_results_sync, video_cpu, filename, fps)
_io_futures.append(fut)
def _log_to_tb_sync(video_cpu: Tensor, writer: SummaryWriter, tag: str, fps: int) -> None:
"""Synchronous tensorboard logging on CPU tensor (runs in background thread)."""
video = video_cpu.float()
n = video.shape[0]
video = video.permute(2, 0, 1, 3, 4)
frame_grids = [
torchvision.utils.make_grid(framesheet, nrow=int(n), padding=0)
for framesheet in video
]
grid = torch.stack(frame_grids, dim=0)
grid = (grid + 1.0) / 2.0
grid = grid.unsqueeze(dim=0)
writer.add_video(tag, grid, fps=fps)
def log_to_tensorboard_async(writer: SummaryWriter, video: Tensor, tag: str, fps: int = 10) -> None:
"""Submit tensorboard logging to background thread pool."""
video_cpu = video.detach().cpu()
fut = _get_io_executor().submit(_log_to_tb_sync, video_cpu, writer, tag, fps)
_io_futures.append(fut)
def patch_norm_bypass_autocast():
"""Monkey-patch GroupNorm and LayerNorm to bypass autocast's fp32 policy.
This eliminates bf16->fp32->bf16 dtype conversions during UNet forward."""
def _group_norm_forward(self, x):
with torch.amp.autocast('cuda', enabled=False):
return F.group_norm(
x, self.num_groups,
self.weight.to(x.dtype) if self.weight is not None else None,
self.bias.to(x.dtype) if self.bias is not None else None,
self.eps)
def _layer_norm_forward(self, x):
with torch.amp.autocast('cuda', enabled=False):
return F.layer_norm(
x, self.normalized_shape,
self.weight.to(x.dtype) if self.weight is not None else None,
self.bias.to(x.dtype) if self.bias is not None else None,
self.eps)
torch.nn.GroupNorm.forward = _group_norm_forward
torch.nn.LayerNorm.forward = _layer_norm_forward
def get_device_from_parameters(module: nn.Module) -> torch.device: def get_device_from_parameters(module: nn.Module) -> torch.device:
@@ -44,6 +142,92 @@ def get_device_from_parameters(module: nn.Module) -> torch.device:
return next(iter(module.parameters())).device return next(iter(module.parameters())).device
def apply_precision_settings(model: nn.Module, args: argparse.Namespace) -> nn.Module:
"""Apply precision settings to model components based on command-line arguments.
Args:
model (nn.Module): The model to apply precision settings to.
args (argparse.Namespace): Parsed command-line arguments containing precision settings.
Returns:
nn.Module: Model with precision settings applied.
"""
print(f">>> Applying precision settings:")
print(f" - Diffusion dtype: {args.diffusion_dtype}")
print(f" - Projector mode: {args.projector_mode}")
print(f" - Encoder mode: {args.encoder_mode}")
print(f" - VAE dtype: {args.vae_dtype}")
# 1. Set Diffusion backbone precision
if args.diffusion_dtype == "bf16":
# Convert diffusion model weights to bf16
model.model.to(torch.bfloat16)
model.diffusion_autocast_dtype = torch.bfloat16
print(" ✓ Diffusion model weights converted to bfloat16")
else:
model.diffusion_autocast_dtype = torch.bfloat16
print(" ✓ Diffusion model using fp32")
# 2. Set Projector precision
if args.projector_mode == "bf16_full":
model.state_projector.to(torch.bfloat16)
model.action_projector.to(torch.bfloat16)
model.projector_autocast_dtype = None
print(" ✓ Projectors converted to bfloat16")
elif args.projector_mode == "autocast":
model.projector_autocast_dtype = torch.bfloat16
print(" ✓ Projectors will use autocast (weights fp32, compute bf16)")
else:
model.projector_autocast_dtype = None
# fp32 mode: do nothing, keep original precision
# 3. Set Encoder precision
if args.encoder_mode == "bf16_full":
model.embedder.to(torch.bfloat16)
model.image_proj_model.to(torch.bfloat16)
model.encoder_autocast_dtype = None
print(" ✓ Encoders converted to bfloat16")
elif args.encoder_mode == "autocast":
model.encoder_autocast_dtype = torch.bfloat16
print(" ✓ Encoders will use autocast (weights fp32, compute bf16)")
else:
model.encoder_autocast_dtype = None
# fp32 mode: do nothing, keep original precision
# 4. Set VAE precision
if args.vae_dtype == "bf16":
model.first_stage_model.to(torch.bfloat16)
print(" ✓ VAE converted to bfloat16")
else:
print(" ✓ VAE kept in fp32 for best quality")
# 5. Safety net: ensure no fp32 parameters remain when all components are bf16
if args.diffusion_dtype == "bf16":
fp32_params = [(n, p) for n, p in model.named_parameters() if p.dtype == torch.float32]
if fp32_params:
print(f" ⚠ Found {len(fp32_params)} fp32 params, converting to bf16")
for name, param in fp32_params:
param.data = param.data.to(torch.bfloat16)
print(" ✓ All parameters converted to bfloat16")
return model
def apply_torch_compile(model, hot_indices=(5, 8, 9)):
"""Compile ResBlock._forward in the hottest output_blocks for operator fusion."""
from unifolm_wma.modules.networks.wma_model import ResBlock
unet = model.model.diffusion_model
compiled = 0
for idx in hot_indices:
block = unet.output_blocks[idx]
for layer in block:
if isinstance(layer, ResBlock):
layer._forward = torch.compile(layer._forward, mode="default")
compiled += 1
print(f" ✓ torch.compile: {compiled} ResBlocks in output_blocks{list(hot_indices)}")
return model
def write_video(video_path: str, stacked_frames: list, fps: int) -> None: def write_video(video_path: str, stacked_frames: list, fps: int) -> None:
"""Save a list of frames to a video file. """Save a list of frames to a video file.
@@ -79,17 +263,18 @@ def get_filelist(data_dir: str, postfixes: list[str]) -> list[str]:
return file_list return file_list
def load_model_checkpoint(model: nn.Module, ckpt: str) -> nn.Module: def load_model_checkpoint(model: nn.Module, ckpt: str, device: str = "cpu") -> nn.Module:
"""Load model weights from checkpoint file. """Load model weights from checkpoint file.
Args: Args:
model (nn.Module): Model instance. model (nn.Module): Model instance.
ckpt (str): Path to the checkpoint file. ckpt (str): Path to the checkpoint file.
device (str): Target device for loaded tensors.
Returns: Returns:
nn.Module: Model with loaded weights. nn.Module: Model with loaded weights.
""" """
state_dict = torch.load(ckpt, map_location="cpu") state_dict = torch.load(ckpt, map_location=device)
if "state_dict" in list(state_dict.keys()): if "state_dict" in list(state_dict.keys()):
state_dict = state_dict["state_dict"] state_dict = state_dict["state_dict"]
try: try:
@@ -156,81 +341,6 @@ def save_results(video: Tensor, filename: str, fps: int = 8) -> None:
options={'crf': '10'}) options={'crf': '10'})
# ========== Async I/O ==========
_io_executor: Optional[ThreadPoolExecutor] = None
_io_futures: List[Any] = []
def _get_io_executor() -> ThreadPoolExecutor:
global _io_executor
if _io_executor is None:
_io_executor = ThreadPoolExecutor(max_workers=2)
return _io_executor
def _flush_io():
"""Wait for all pending async I/O to finish."""
global _io_futures
for fut in _io_futures:
try:
fut.result()
except Exception as e:
print(f">>> [async I/O] error: {e}")
_io_futures.clear()
atexit.register(_flush_io)
def _save_results_sync(video_cpu: Tensor, filename: str, fps: int) -> None:
"""Synchronous save on CPU tensor (runs in background thread)."""
video = torch.clamp(video_cpu.float(), -1., 1.)
n = video.shape[0]
video = video.permute(2, 0, 1, 3, 4)
frame_grids = [
torchvision.utils.make_grid(framesheet, nrow=int(n), padding=0)
for framesheet in video
]
grid = torch.stack(frame_grids, dim=0)
grid = (grid + 1.0) / 2.0
grid = (grid * 255).to(torch.uint8).permute(0, 2, 3, 1)
torchvision.io.write_video(filename,
grid,
fps=fps,
video_codec='h264',
options={'crf': '10'})
def save_results_async(video: Tensor, filename: str, fps: int = 8) -> None:
"""Submit video saving to background thread pool."""
video_cpu = video.detach().cpu()
fut = _get_io_executor().submit(_save_results_sync, video_cpu, filename, fps)
_io_futures.append(fut)
def _log_to_tb_sync(writer, video_cpu: Tensor, tag: str, fps: int) -> None:
"""Synchronous TensorBoard log on CPU tensor (runs in background thread)."""
if video_cpu.dim() == 5:
n = video_cpu.shape[0]
video = video_cpu.permute(2, 0, 1, 3, 4)
frame_grids = [
torchvision.utils.make_grid(framesheet, nrow=int(n), padding=0)
for framesheet in video
]
grid = torch.stack(frame_grids, dim=0)
grid = (grid + 1.0) / 2.0
grid = grid.unsqueeze(dim=0)
writer.add_video(tag, grid, fps=fps)
def log_to_tensorboard_async(writer, data: Tensor, tag: str, fps: int = 10) -> None:
"""Submit TensorBoard logging to background thread pool."""
if isinstance(data, torch.Tensor) and data.dim() == 5:
data_cpu = data.detach().cpu()
fut = _get_io_executor().submit(_log_to_tb_sync, writer, data_cpu, tag, fps)
_io_futures.append(fut)
def get_init_frame_path(data_dir: str, sample: dict) -> str: def get_init_frame_path(data_dir: str, sample: dict) -> str:
"""Construct the init_frame path from directory and sample metadata. """Construct the init_frame path from directory and sample metadata.
@@ -343,6 +453,11 @@ def get_latent_z(model, videos: Tensor) -> Tensor:
""" """
b, c, t, h, w = videos.shape b, c, t, h, w = videos.shape
x = rearrange(videos, 'b c t h w -> (b t) c h w') x = rearrange(videos, 'b c t h w -> (b t) c h w')
# Auto-detect VAE dtype and convert input
vae_dtype = next(model.first_stage_model.parameters()).dtype
x = x.to(dtype=vae_dtype)
z = model.encode_first_stage(x) z = model.encode_first_stage(x)
z = rearrange(z, '(b t) c h w -> b c t h w', b=b, t=t) z = rearrange(z, '(b t) c h w -> b c t h w', b=b, t=t)
return z return z
@@ -448,10 +563,22 @@ def image_guided_synthesis_sim_mode(
fs = torch.tensor([fs] * batch_size, dtype=torch.long, device=model.device) fs = torch.tensor([fs] * batch_size, dtype=torch.long, device=model.device)
# Auto-detect model dtype and convert inputs accordingly
model_dtype = next(model.embedder.parameters()).dtype
img = observation['observation.images.top'].permute(0, 2, 1, 3, 4) img = observation['observation.images.top'].permute(0, 2, 1, 3, 4)
cond_img = rearrange(img, 'b o c h w -> (b o) c h w')[-1:] cond_img = rearrange(img, 'b o c h w -> (b o) c h w')[-1:].to(dtype=model_dtype)
cond_img_emb = model.embedder(cond_img)
cond_img_emb = model.image_proj_model(cond_img_emb) # Encoder autocast: weights stay fp32, compute in bf16
enc_ac_dtype = getattr(model, 'encoder_autocast_dtype', None)
if enc_ac_dtype is not None and model.device.type == 'cuda':
enc_ctx = torch.autocast('cuda', dtype=enc_ac_dtype)
else:
enc_ctx = nullcontext()
with enc_ctx:
cond_img_emb = model.embedder(cond_img)
cond_img_emb = model.image_proj_model(cond_img_emb)
if model.model.conditioning_key == 'hybrid': if model.model.conditioning_key == 'hybrid':
z = get_latent_z(model, img.permute(0, 2, 1, 3, 4)) z = get_latent_z(model, img.permute(0, 2, 1, 3, 4))
@@ -465,11 +592,22 @@ def image_guided_synthesis_sim_mode(
prompts = [""] * batch_size prompts = [""] * batch_size
cond_ins_emb = model.get_learned_conditioning(prompts) cond_ins_emb = model.get_learned_conditioning(prompts)
cond_state_emb = model.state_projector(observation['observation.state']) # Auto-detect projector dtype and convert inputs
cond_state_emb = cond_state_emb + model.agent_state_pos_emb projector_dtype = next(model.state_projector.parameters()).dtype
cond_action_emb = model.action_projector(observation['action']) # Projector autocast: weights stay fp32, compute in bf16
cond_action_emb = cond_action_emb + model.agent_action_pos_emb proj_ac_dtype = getattr(model, 'projector_autocast_dtype', None)
if proj_ac_dtype is not None and model.device.type == 'cuda':
proj_ctx = torch.autocast('cuda', dtype=proj_ac_dtype)
else:
proj_ctx = nullcontext()
with proj_ctx:
cond_state_emb = model.state_projector(observation['observation.state'].to(dtype=projector_dtype))
cond_state_emb = cond_state_emb + model.agent_state_pos_emb
cond_action_emb = model.action_projector(observation['action'].to(dtype=projector_dtype))
cond_action_emb = cond_action_emb + model.agent_action_pos_emb
if not sim_mode: if not sim_mode:
cond_action_emb = torch.zeros_like(cond_action_emb) cond_action_emb = torch.zeros_like(cond_action_emb)
@@ -491,9 +629,18 @@ def image_guided_synthesis_sim_mode(
kwargs.update({"unconditional_conditioning_img_nonetext": None}) kwargs.update({"unconditional_conditioning_img_nonetext": None})
cond_mask = None cond_mask = None
cond_z0 = None cond_z0 = None
# Setup autocast context for diffusion sampling
autocast_dtype = getattr(model, 'diffusion_autocast_dtype', None)
if autocast_dtype is not None and model.device.type == 'cuda':
autocast_ctx = torch.autocast('cuda', dtype=autocast_dtype)
else:
autocast_ctx = nullcontext()
batch_variants = None batch_variants = None
if ddim_sampler is not None: if ddim_sampler is not None:
samples, actions, states, intermedia = ddim_sampler.sample( with autocast_ctx:
samples, actions, states, intermedia = ddim_sampler.sample(
S=ddim_steps, S=ddim_steps,
conditioning=cond, conditioning=cond,
batch_size=batch_size, batch_size=batch_size,
@@ -540,7 +687,7 @@ def run_inference(args: argparse.Namespace, gpu_num: int, gpu_no: int) -> None:
csv_path = os.path.join(args.prompt_dir, f"{args.dataset}.csv") csv_path = os.path.join(args.prompt_dir, f"{args.dataset}.csv")
df = pd.read_csv(csv_path) df = pd.read_csv(csv_path)
# Load config (always needed for data setup) # Load config
config = OmegaConf.load(args.config) config = OmegaConf.load(args.config)
prepared_path = args.ckpt_path + ".prepared.pt" prepared_path = args.ckpt_path + ".prepared.pt"
@@ -549,42 +696,58 @@ def run_inference(args: argparse.Namespace, gpu_num: int, gpu_no: int) -> None:
print(f">>> Loading prepared model from {prepared_path} ...") print(f">>> Loading prepared model from {prepared_path} ...")
model = torch.load(prepared_path, model = torch.load(prepared_path,
map_location=f"cuda:{gpu_no}", map_location=f"cuda:{gpu_no}",
weights_only=False, weights_only=False)
mmap=True)
model.eval() model.eval()
# Restore autocast attributes (weights already cast, just need contexts)
model.diffusion_autocast_dtype = torch.bfloat16 if args.diffusion_dtype == "bf16" else torch.bfloat16
model.projector_autocast_dtype = torch.bfloat16 if args.projector_mode == "autocast" else None
model.encoder_autocast_dtype = torch.bfloat16 if args.encoder_mode == "autocast" else None
# Compile hot ResBlocks for operator fusion
apply_torch_compile(model)
print(f">>> Prepared model loaded.") print(f">>> Prepared model loaded.")
else: else:
# ---- Normal path: construct + load checkpoint ---- # ---- Normal path: construct + checkpoint + casting ----
config['model']['params']['wma_config']['params'][ config['model']['params']['wma_config']['params'][
'use_checkpoint'] = False 'use_checkpoint'] = False
model = instantiate_from_config(config.model) model = instantiate_from_config(config.model)
model.perframe_ae = args.perframe_ae model.perframe_ae = args.perframe_ae
assert os.path.exists(args.ckpt_path), "Error: checkpoint Not Found!" assert os.path.exists(args.ckpt_path), "Error: checkpoint Not Found!"
model = load_model_checkpoint(model, args.ckpt_path) model = load_model_checkpoint(model, args.ckpt_path,
device=f"cuda:{gpu_no}")
model.eval() model.eval()
model = model.cuda(gpu_no)
print(f'>>> Load pre-trained model ...') print(f'>>> Load pre-trained model ...')
# Save prepared model for fast loading next time # Apply precision settings before moving to GPU
model = apply_precision_settings(model, args)
# Export precision-converted checkpoint if requested
if args.export_precision_ckpt:
export_path = args.export_precision_ckpt
os.makedirs(os.path.dirname(export_path) or '.', exist_ok=True)
torch.save({"state_dict": model.state_dict()}, export_path)
print(f">>> Precision-converted checkpoint saved to: {export_path}")
return
model = model.cuda(gpu_no)
# Save prepared model for fast loading next time (before torch.compile)
print(f">>> Saving prepared model to {prepared_path} ...") print(f">>> Saving prepared model to {prepared_path} ...")
torch.save(model, prepared_path) torch.save(model, prepared_path)
print(f">>> Prepared model saved ({os.path.getsize(prepared_path) / 1024**3:.1f} GB).") print(f">>> Prepared model saved ({os.path.getsize(prepared_path) / 1024**3:.1f} GB).")
# Compile hot ResBlocks for operator fusion (after save, compiled objects can't be pickled)
apply_torch_compile(model)
# Build normalizer (always needed, independent of model loading path) # Build normalizer (always needed, independent of model loading path)
logging.info("***** Configing Data *****") logging.info("***** Configing Data *****")
data = instantiate_from_config(config.data) data = instantiate_from_config(config.data)
data.setup() data.setup()
print(">>> Dataset is successfully loaded ...") print(">>> Dataset is successfully loaded ...")
device = get_device_from_parameters(model) device = get_device_from_parameters(model)
# Fuse KV projections in attention layers (to_k + to_v → to_kv)
from unifolm_wma.modules.attention import CrossAttention
kv_count = sum(1 for m in model.modules()
if isinstance(m, CrossAttention) and m.fuse_kv())
print(f" ✓ KV fused: {kv_count} attention layers")
# Run over data # Run over data
assert (args.height % 16 == 0) and ( assert (args.height % 16 == 0) and (
args.width % 16 args.width % 16
@@ -757,7 +920,7 @@ def run_inference(args: argparse.Namespace, gpu_num: int, gpu_no: int) -> None:
cond_obs_queues = populate_queues(cond_obs_queues, cond_obs_queues = populate_queues(cond_obs_queues,
observation) observation)
# Save the imagen videos for decision-making (async) # Save the imagen videos for decision-making
if pred_videos_0 is not None: if pred_videos_0 is not None:
sample_tag = f"{args.dataset}-vid{sample['videoid']}-dm-fs-{fs}/itr-{itr}" sample_tag = f"{args.dataset}-vid{sample['videoid']}-dm-fs-{fs}/itr-{itr}"
log_to_tensorboard_async(writer, log_to_tensorboard_async(writer,
@@ -921,10 +1084,40 @@ def get_parser():
type=int, type=int,
default=8, default=8,
help="fps for the saving video") help="fps for the saving video")
parser.add_argument(
"--diffusion_dtype",
type=str,
choices=["fp32", "bf16"],
default="bf16",
help="Diffusion backbone precision (fp32/bf16)")
parser.add_argument(
"--projector_mode",
type=str,
choices=["fp32", "autocast", "bf16_full"],
default="bf16_full",
help="Projector precision mode (fp32/autocast/bf16_full)")
parser.add_argument(
"--encoder_mode",
type=str,
choices=["fp32", "autocast", "bf16_full"],
default="bf16_full",
help="Encoder precision mode (fp32/autocast/bf16_full)")
parser.add_argument(
"--vae_dtype",
type=str,
choices=["fp32", "bf16"],
default="fp32",
help="VAE precision (fp32/bf16, most affects image quality)")
parser.add_argument(
"--export_precision_ckpt",
type=str,
default=None,
help="Export precision-converted checkpoint to this path, then exit.")
return parser return parser
if __name__ == '__main__': if __name__ == '__main__':
patch_norm_bypass_autocast()
parser = get_parser() parser = get_parser()
args = parser.parse_args() args = parser.parse_args()
seed = args.seed seed = args.seed

View File

@@ -11,9 +11,6 @@ from unifolm_wma.utils.utils import instantiate_from_config
from unifolm_wma.utils.train import get_trainer_callbacks, get_trainer_logger, get_trainer_strategy from unifolm_wma.utils.train import get_trainer_callbacks, get_trainer_logger, get_trainer_strategy
from unifolm_wma.utils.train import set_logger, init_workspace, load_checkpoints, get_num_parameters from unifolm_wma.utils.train import set_logger, init_workspace, load_checkpoints, get_num_parameters
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
def get_parser(**parser_kwargs): def get_parser(**parser_kwargs):
parser = argparse.ArgumentParser(**parser_kwargs) parser = argparse.ArgumentParser(**parser_kwargs)

View File

@@ -99,6 +99,8 @@ class AutoencoderKL(pl.LightningModule):
print(f"Restored from {path}") print(f"Restored from {path}")
def encode(self, x, **kwargs): def encode(self, x, **kwargs):
if getattr(self, '_channels_last', False):
x = x.to(memory_format=torch.channels_last)
h = self.encoder(x) h = self.encoder(x)
moments = self.quant_conv(h) moments = self.quant_conv(h)
@@ -106,6 +108,8 @@ class AutoencoderKL(pl.LightningModule):
return posterior return posterior
def decode(self, z, **kwargs): def decode(self, z, **kwargs):
if getattr(self, '_channels_last', False):
z = z.to(memory_format=torch.channels_last)
z = self.post_quant_conv(z) z = self.post_quant_conv(z)
dec = self.decoder(z) dec = self.decoder(z)
return dec return dec

View File

@@ -1074,10 +1074,10 @@ class LatentDiffusion(DDPM):
encoder_posterior = self.first_stage_model.encode(x) encoder_posterior = self.first_stage_model.encode(x)
results = self.get_first_stage_encoding(encoder_posterior).detach() results = self.get_first_stage_encoding(encoder_posterior).detach()
else: ## Consume less GPU memory but slower else: ## Consume less GPU memory but slower
bs = getattr(self, 'vae_encode_bs', 1)
results = [] results = []
for index in range(x.shape[0]): for i in range(0, x.shape[0], bs):
frame_batch = self.first_stage_model.encode(x[index:index + frame_batch = self.first_stage_model.encode(x[i:i + bs])
1, :, :, :])
frame_result = self.get_first_stage_encoding( frame_result = self.get_first_stage_encoding(
frame_batch).detach() frame_batch).detach()
results.append(frame_result) results.append(frame_result)
@@ -1105,14 +1105,18 @@ class LatentDiffusion(DDPM):
else: else:
reshape_back = False reshape_back = False
# Align input dtype with VAE weights (e.g. fp32 samples → bf16 VAE)
vae_dtype = next(self.first_stage_model.parameters()).dtype
z = z.to(dtype=vae_dtype)
z = 1. / self.scale_factor * z
if not self.perframe_ae: if not self.perframe_ae:
z = 1. / self.scale_factor * z
results = self.first_stage_model.decode(z, **kwargs) results = self.first_stage_model.decode(z, **kwargs)
else: else:
bs = getattr(self, 'vae_decode_bs', 1)
results = [] results = []
for index in range(z.shape[0]): for i in range(0, z.shape[0], bs):
frame_z = 1. / self.scale_factor * z[index:index + 1, :, :, :] frame_result = self.first_stage_model.decode(z[i:i + bs], **kwargs)
frame_result = self.first_stage_model.decode(frame_z, **kwargs)
results.append(frame_result) results.append(frame_result)
results = torch.cat(results, dim=0) results = torch.cat(results, dim=0)
@@ -1799,7 +1803,9 @@ class LatentDiffusion(DDPM):
""" """
if ddim: if ddim:
ddim_sampler = DDIMSampler(self) if not hasattr(self, '_ddim_sampler') or self._ddim_sampler is None:
self._ddim_sampler = DDIMSampler(self)
ddim_sampler = self._ddim_sampler
shape = (self.channels, self.temporal_length, *self.image_size) shape = (self.channels, self.temporal_length, *self.image_size)
samples, actions, states, intermediates = ddim_sampler.sample( samples, actions, states, intermediates = ddim_sampler.sample(
ddim_steps, batch_size, shape, cond, verbose=False, **kwargs) ddim_steps, batch_size, shape, cond, verbose=False, **kwargs)
@@ -2457,7 +2463,6 @@ class DiffusionWrapper(pl.LightningModule):
Returns: Returns:
Output from the inner diffusion model (tensor or tuple, depending on the model). Output from the inner diffusion model (tensor or tuple, depending on the model).
""" """
if self.conditioning_key is None: if self.conditioning_key is None:
out = self.diffusion_model(x, t) out = self.diffusion_model(x, t)
elif self.conditioning_key == 'concat': elif self.conditioning_key == 'concat':

View File

@@ -8,12 +8,14 @@ class SinusoidalPosEmb(nn.Module):
def __init__(self, dim): def __init__(self, dim):
super().__init__() super().__init__()
self.dim = dim self.dim = dim
# Dummy buffer so .to(dtype) propagates to this module
self.register_buffer('_dtype_buf', torch.zeros(1), persistent=False)
def forward(self, x): def forward(self, x):
device = x.device device = x.device
half_dim = self.dim // 2 half_dim = self.dim // 2
emb = math.log(10000) / (half_dim - 1) emb = math.log(10000) / (half_dim - 1)
emb = torch.exp(torch.arange(half_dim, device=device) * -emb) emb = torch.exp(torch.arange(half_dim, device=device) * -emb)
emb = x[:, None] * emb[None, :] emb = x.float()[:, None] * emb[None, :]
emb = torch.cat((emb.sin(), emb.cos()), dim=-1) emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
return emb return emb.to(self._dtype_buf.dtype)

View File

@@ -0,0 +1,7 @@
{
"permissions": {
"allow": [
"Bash(python3:*)"
]
}
}

View File

@@ -18,6 +18,7 @@ class DDIMSampler(object):
self.ddpm_num_timesteps = model.num_timesteps self.ddpm_num_timesteps = model.num_timesteps
self.schedule = schedule self.schedule = schedule
self.counter = 0 self.counter = 0
self._schedule_key = None # (ddim_num_steps, ddim_discretize, ddim_eta)
def register_buffer(self, name, attr): def register_buffer(self, name, attr):
if type(attr) == torch.Tensor: if type(attr) == torch.Tensor:
@@ -30,6 +31,11 @@ class DDIMSampler(object):
ddim_discretize="uniform", ddim_discretize="uniform",
ddim_eta=0., ddim_eta=0.,
verbose=True): verbose=True):
key = (ddim_num_steps, ddim_discretize, ddim_eta)
if self._schedule_key == key:
return
self._schedule_key = key
self.ddim_timesteps = make_ddim_timesteps( self.ddim_timesteps = make_ddim_timesteps(
ddim_discr_method=ddim_discretize, ddim_discr_method=ddim_discretize,
num_ddim_timesteps=ddim_num_steps, num_ddim_timesteps=ddim_num_steps,
@@ -38,7 +44,7 @@ class DDIMSampler(object):
alphas_cumprod = self.model.alphas_cumprod alphas_cumprod = self.model.alphas_cumprod
assert alphas_cumprod.shape[ assert alphas_cumprod.shape[
0] == self.ddpm_num_timesteps, 'alphas have to be defined for each timestep' 0] == self.ddpm_num_timesteps, 'alphas have to be defined for each timestep'
to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.model to_torch = lambda x: x.clone().detach().to(torch.float64).to(self.model
.device) .device)
if self.model.use_dynamic_rescale: if self.model.use_dynamic_rescale:
@@ -211,9 +217,9 @@ class DDIMSampler(object):
if precision is not None: if precision is not None:
if precision == 16: if precision == 16:
img = img.to(dtype=torch.float16) img = img.to(dtype=torch.bfloat16)
action = action.to(dtype=torch.float16) action = action.to(dtype=torch.bfloat16)
state = state.to(dtype=torch.float16) state = state.to(dtype=torch.bfloat16)
if timesteps is None: if timesteps is None:
timesteps = self.ddpm_num_timesteps if ddim_use_original_steps else self.ddim_timesteps timesteps = self.ddpm_num_timesteps if ddim_use_original_steps else self.ddim_timesteps
@@ -384,10 +390,10 @@ class DDIMSampler(object):
sigmas = self.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas sigmas = self.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas
# Use 0-d tensors directly (already on device); broadcasting handles shape # Use 0-d tensors directly (already on device); broadcasting handles shape
a_t = alphas[index] a_t = alphas[index].to(x.dtype)
a_prev = alphas_prev[index] a_prev = alphas_prev[index].to(x.dtype)
sigma_t = sigmas[index] sigma_t = sigmas[index].to(x.dtype)
sqrt_one_minus_at = sqrt_one_minus_alphas[index] sqrt_one_minus_at = sqrt_one_minus_alphas[index].to(x.dtype)
if self.model.parameterization != "v": if self.model.parameterization != "v":
pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt() pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()

View File

@@ -86,9 +86,8 @@ class CrossAttention(nn.Module):
self.relative_position_v = RelativePosition( self.relative_position_v = RelativePosition(
num_units=dim_head, max_relative_position=temporal_length) num_units=dim_head, max_relative_position=temporal_length)
else: else:
## only used for spatial attention, while NOT for temporal attention ## bmm fused-scale attention for all non-relative-position cases
if XFORMERS_IS_AVAILBLE and temporal_length is None: self.forward = self.bmm_forward
self.forward = self.efficient_forward
self.video_length = video_length self.video_length = video_length
self.image_cross_attention = image_cross_attention self.image_cross_attention = image_cross_attention
@@ -100,7 +99,6 @@ class CrossAttention(nn.Module):
self.agent_action_context_len = agent_action_context_len self.agent_action_context_len = agent_action_context_len
self._kv_cache = {} self._kv_cache = {}
self._kv_cache_enabled = False self._kv_cache_enabled = False
self._kv_fused = False
self.cross_attention_scale_learnable = cross_attention_scale_learnable self.cross_attention_scale_learnable = cross_attention_scale_learnable
if self.image_cross_attention: if self.image_cross_attention:
@@ -118,27 +116,6 @@ class CrossAttention(nn.Module):
self.register_parameter('alpha_caa', self.register_parameter('alpha_caa',
nn.Parameter(torch.tensor(0.))) nn.Parameter(torch.tensor(0.)))
def fuse_kv(self):
"""Fuse to_k/to_v into to_kv (2 Linear → 1). Works for all layers."""
k_w = self.to_k.weight # (inner_dim, context_dim)
v_w = self.to_v.weight
self.to_kv = nn.Linear(k_w.shape[1], k_w.shape[0] * 2, bias=False)
self.to_kv.weight = nn.Parameter(torch.cat([k_w, v_w], dim=0))
del self.to_k, self.to_v
if self.image_cross_attention:
for suffix in ('_ip', '_as', '_aa'):
k_attr = f'to_k{suffix}'
v_attr = f'to_v{suffix}'
kw = getattr(self, k_attr).weight
vw = getattr(self, v_attr).weight
fused = nn.Linear(kw.shape[1], kw.shape[0] * 2, bias=False)
fused.weight = nn.Parameter(torch.cat([kw, vw], dim=0))
setattr(self, f'to_kv{suffix}', fused)
delattr(self, k_attr)
delattr(self, v_attr)
self._kv_fused = True
return True
def forward(self, x, context=None, mask=None): def forward(self, x, context=None, mask=None):
spatial_self_attn = (context is None) spatial_self_attn = (context is None)
k_ip, v_ip, out_ip = None, None, None k_ip, v_ip, out_ip = None, None, None
@@ -150,7 +127,7 @@ class CrossAttention(nn.Module):
context = default(context, x) context = default(context, x)
if self.image_cross_attention and not spatial_self_attn: if self.image_cross_attention and not spatial_self_attn:
assert 1 > 2, ">>> ERROR: should setup xformers and use efficient_forward ..." # assert 1 > 2, ">>> ERROR: should setup xformers and use efficient_forward ..."
context_agent_state = context[:, :self.agent_state_context_len, :] context_agent_state = context[:, :self.agent_state_context_len, :]
context_agent_action = context[:, context_agent_action = context[:,
self.agent_state_context_len:self. self.agent_state_context_len:self.
@@ -165,28 +142,19 @@ class CrossAttention(nn.Module):
self.agent_action_context_len + self.agent_action_context_len +
self.text_context_len:, :] self.text_context_len:, :]
if self._kv_fused: k = self.to_k(context_ins)
k, v = self.to_kv(context_ins).chunk(2, dim=-1) v = self.to_v(context_ins)
k_ip, v_ip = self.to_kv_ip(context_image).chunk(2, dim=-1) k_ip = self.to_k_ip(context_image)
k_as, v_as = self.to_kv_as(context_agent_state).chunk(2, dim=-1) v_ip = self.to_v_ip(context_image)
k_aa, v_aa = self.to_kv_aa(context_agent_action).chunk(2, dim=-1) k_as = self.to_k_as(context_agent_state)
else: v_as = self.to_v_as(context_agent_state)
k = self.to_k(context_ins) k_aa = self.to_k_aa(context_agent_action)
v = self.to_v(context_ins) v_aa = self.to_v_aa(context_agent_action)
k_ip = self.to_k_ip(context_image)
v_ip = self.to_v_ip(context_image)
k_as = self.to_k_as(context_agent_state)
v_as = self.to_v_as(context_agent_state)
k_aa = self.to_k_aa(context_agent_action)
v_aa = self.to_v_aa(context_agent_action)
else: else:
if not spatial_self_attn: if not spatial_self_attn:
context = context[:, :self.text_context_len, :] context = context[:, :self.text_context_len, :]
if self._kv_fused: k = self.to_k(context)
k, v = self.to_kv(context).chunk(2, dim=-1) v = self.to_v(context)
else:
k = self.to_k(context)
v = self.to_v(context)
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h),
(q, k, v)) (q, k, v))
@@ -207,7 +175,8 @@ class CrossAttention(nn.Module):
sim.masked_fill_(~(mask > 0.5), max_neg_value) sim.masked_fill_(~(mask > 0.5), max_neg_value)
# attention, what we cannot get enough of # attention, what we cannot get enough of
sim = sim.softmax(dim=-1) with torch.amp.autocast('cuda', enabled=False):
sim = sim.softmax(dim=-1)
out = torch.einsum('b i j, b j d -> b i d', sim, v) out = torch.einsum('b i j, b j d -> b i d', sim, v)
if self.relative_position: if self.relative_position:
@@ -224,7 +193,8 @@ class CrossAttention(nn.Module):
sim_ip = torch.einsum('b i d, b j d -> b i j', q, sim_ip = torch.einsum('b i d, b j d -> b i j', q,
k_ip) * self.scale k_ip) * self.scale
del k_ip del k_ip
sim_ip = sim_ip.softmax(dim=-1) with torch.amp.autocast('cuda', enabled=False):
sim_ip = sim_ip.softmax(dim=-1)
out_ip = torch.einsum('b i j, b j d -> b i d', sim_ip, v_ip) out_ip = torch.einsum('b i j, b j d -> b i d', sim_ip, v_ip)
out_ip = rearrange(out_ip, '(b h) n d -> b n (h d)', h=h) out_ip = rearrange(out_ip, '(b h) n d -> b n (h d)', h=h)
@@ -235,7 +205,8 @@ class CrossAttention(nn.Module):
sim_as = torch.einsum('b i d, b j d -> b i j', q, sim_as = torch.einsum('b i d, b j d -> b i j', q,
k_as) * self.scale k_as) * self.scale
del k_as del k_as
sim_as = sim_as.softmax(dim=-1) with torch.amp.autocast('cuda', enabled=False):
sim_as = sim_as.softmax(dim=-1)
out_as = torch.einsum('b i j, b j d -> b i d', sim_as, v_as) out_as = torch.einsum('b i j, b j d -> b i d', sim_as, v_as)
out_as = rearrange(out_as, '(b h) n d -> b n (h d)', h=h) out_as = rearrange(out_as, '(b h) n d -> b n (h d)', h=h)
@@ -246,7 +217,8 @@ class CrossAttention(nn.Module):
sim_aa = torch.einsum('b i d, b j d -> b i j', q, sim_aa = torch.einsum('b i d, b j d -> b i j', q,
k_aa) * self.scale k_aa) * self.scale
del k_aa del k_aa
sim_aa = sim_aa.softmax(dim=-1) with torch.amp.autocast('cuda', enabled=False):
sim_aa = sim_aa.softmax(dim=-1)
out_aa = torch.einsum('b i j, b j d -> b i d', sim_aa, v_aa) out_aa = torch.einsum('b i j, b j d -> b i d', sim_aa, v_aa)
out_aa = rearrange(out_aa, '(b h) n d -> b n (h d)', h=h) out_aa = rearrange(out_aa, '(b h) n d -> b n (h d)', h=h)
@@ -264,168 +236,276 @@ class CrossAttention(nn.Module):
return self.to_out(out) return self.to_out(out)
def bmm_forward(self, x, context=None, mask=None):
spatial_self_attn = (context is None)
k_ip, v_ip, out_ip = None, None, None
k_as, v_as, out_as = None, None, None
k_aa, v_aa, out_aa = None, None, None
h = self.heads
q = self.to_q(x)
context = default(context, x)
use_cache = self._kv_cache_enabled and not spatial_self_attn
cache_hit = use_cache and len(self._kv_cache) > 0
if cache_hit:
# Reuse cached K/V (already in (b*h, n, d) shape)
k = self._kv_cache['k']
v = self._kv_cache['v']
if 'k_ip' in self._kv_cache:
k_ip = self._kv_cache['k_ip']
v_ip = self._kv_cache['v_ip']
k_as = self._kv_cache['k_as']
v_as = self._kv_cache['v_as']
k_aa = self._kv_cache['k_aa']
v_aa = self._kv_cache['v_aa']
q = rearrange(q, 'b n (h d) -> (b h) n d', h=h)
elif self.image_cross_attention and not spatial_self_attn:
context_agent_state = context[:, :self.agent_state_context_len, :]
context_agent_action = context[:,
self.agent_state_context_len:self.
agent_state_context_len +
self.agent_action_context_len, :]
context_ins = context[:, self.agent_state_context_len +
self.agent_action_context_len:self.
agent_state_context_len +
self.agent_action_context_len +
self.text_context_len, :]
context_image = context[:, self.agent_state_context_len +
self.agent_action_context_len +
self.text_context_len:, :]
k = self.to_k(context_ins)
v = self.to_v(context_ins)
k_ip = self.to_k_ip(context_image)
v_ip = self.to_v_ip(context_image)
k_as = self.to_k_as(context_agent_state)
v_as = self.to_v_as(context_agent_state)
k_aa = self.to_k_aa(context_agent_action)
v_aa = self.to_v_aa(context_agent_action)
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h),
(q, k, v))
k_ip, v_ip = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h),
(k_ip, v_ip))
k_as, v_as = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h),
(k_as, v_as))
k_aa, v_aa = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h),
(k_aa, v_aa))
if use_cache:
self._kv_cache = {
'k': k, 'v': v,
'k_ip': k_ip, 'v_ip': v_ip,
'k_as': k_as, 'v_as': v_as,
'k_aa': k_aa, 'v_aa': v_aa,
}
else:
if not spatial_self_attn:
context = context[:, :self.text_context_len, :]
k = self.to_k(context)
v = self.to_v(context)
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h),
(q, k, v))
if use_cache:
self._kv_cache = {'k': k, 'v': v}
# baddbmm: fuse scale into GEMM → one kernel instead of matmul + mul
sim = torch.baddbmm(
torch.empty(q.shape[0], q.shape[1], k.shape[1], dtype=q.dtype, device=q.device),
q, k.transpose(-1, -2), beta=0, alpha=self.scale)
if exists(mask):
max_neg_value = -torch.finfo(sim.dtype).max
mask = repeat(mask, 'b i j -> (b h) i j', h=h)
sim.masked_fill_(~(mask > 0.5), max_neg_value)
with torch.amp.autocast('cuda', enabled=False):
sim = sim.softmax(dim=-1)
out = torch.bmm(sim, v)
out = rearrange(out, '(b h) n d -> b n (h d)', h=h)
if k_ip is not None and k_as is not None and k_aa is not None:
## image cross-attention (k_ip/v_ip already in (b*h, n, d) shape)
sim_ip = torch.baddbmm(
torch.empty(q.shape[0], q.shape[1], k_ip.shape[1], dtype=q.dtype, device=q.device),
q, k_ip.transpose(-1, -2), beta=0, alpha=self.scale)
with torch.amp.autocast('cuda', enabled=False):
sim_ip = sim_ip.softmax(dim=-1)
out_ip = torch.bmm(sim_ip, v_ip)
out_ip = rearrange(out_ip, '(b h) n d -> b n (h d)', h=h)
## agent state cross-attention (k_as/v_as already in (b*h, n, d) shape)
sim_as = torch.baddbmm(
torch.empty(q.shape[0], q.shape[1], k_as.shape[1], dtype=q.dtype, device=q.device),
q, k_as.transpose(-1, -2), beta=0, alpha=self.scale)
with torch.amp.autocast('cuda', enabled=False):
sim_as = sim_as.softmax(dim=-1)
out_as = torch.bmm(sim_as, v_as)
out_as = rearrange(out_as, '(b h) n d -> b n (h d)', h=h)
## agent action cross-attention (k_aa/v_aa already in (b*h, n, d) shape)
sim_aa = torch.baddbmm(
torch.empty(q.shape[0], q.shape[1], k_aa.shape[1], dtype=q.dtype, device=q.device),
q, k_aa.transpose(-1, -2), beta=0, alpha=self.scale)
with torch.amp.autocast('cuda', enabled=False):
sim_aa = sim_aa.softmax(dim=-1)
out_aa = torch.bmm(sim_aa, v_aa)
out_aa = rearrange(out_aa, '(b h) n d -> b n (h d)', h=h)
if out_ip is not None and out_as is not None and out_aa is not None:
if self.cross_attention_scale_learnable:
out = out + \
self.image_cross_attention_scale * out_ip * (torch.tanh(self.alpha_ctx) + 1) + \
self.agent_state_cross_attention_scale * out_as * (torch.tanh(self.alpha_cas) + 1) + \
self.agent_action_cross_attention_scale * out_aa * (torch.tanh(self.alpha_caa) + 1)
else:
out = out + \
self.image_cross_attention_scale * out_ip + \
self.agent_state_cross_attention_scale * out_as + \
self.agent_action_cross_attention_scale * out_aa
return self.to_out(out)
def efficient_forward(self, x, context=None, mask=None): def efficient_forward(self, x, context=None, mask=None):
spatial_self_attn = (context is None) spatial_self_attn = (context is None)
k, v, out = None, None, None k, v, out = None, None, None
k_ip, v_ip, out_ip = None, None, None k_ip, v_ip, out_ip = None, None, None
k_as, v_as, out_as = None, None, None k_as, v_as, out_as = None, None, None
k_aa, v_aa, out_aa = None, None, None k_aa, v_aa, out_aa = None, None, None
attn_mask_aa = None
h = self.heads
q = self.to_q(x) q = self.to_q(x)
context = default(context, x) context = default(context, x)
b, _, _ = q.shape if self.image_cross_attention and not spatial_self_attn:
q = q.unsqueeze(3).reshape(b, q.shape[1], h, self.dim_head).permute(0, 2, 1, 3).reshape(b * h, q.shape[1], self.dim_head).contiguous()
def _reshape_kv(t):
return t.unsqueeze(3).reshape(b, t.shape[1], h, self.dim_head).permute(0, 2, 1, 3).reshape(b * h, t.shape[1], self.dim_head).contiguous()
use_cache = self._kv_cache_enabled and not spatial_self_attn
cache_hit = use_cache and len(self._kv_cache) > 0
if cache_hit:
k = self._kv_cache['k']
v = self._kv_cache['v']
k_ip = self._kv_cache.get('k_ip')
v_ip = self._kv_cache.get('v_ip')
k_as = self._kv_cache.get('k_as')
v_as = self._kv_cache.get('v_as')
k_aa = self._kv_cache.get('k_aa')
v_aa = self._kv_cache.get('v_aa')
attn_mask_aa = self._kv_cache.get('attn_mask_aa')
elif self.image_cross_attention and not spatial_self_attn:
if context.shape[1] == self.text_context_len + self.video_length: if context.shape[1] == self.text_context_len + self.video_length:
context_ins, context_image = context[:, :self.text_context_len, :], context[:,self.text_context_len:, :] context_ins, context_image = context[:, :self.text_context_len, :], context[:,self.text_context_len:, :]
if self._kv_fused: k = self.to_k(context)
k, v = self.to_kv(context).chunk(2, dim=-1) v = self.to_v(context)
k_ip, v_ip = self.to_kv_ip(context_image).chunk(2, dim=-1) k_ip = self.to_k_ip(context_image)
else: v_ip = self.to_v_ip(context_image)
k = self.to_k(context)
v = self.to_v(context)
k_ip = self.to_k_ip(context_image)
v_ip = self.to_v_ip(context_image)
k, v = map(_reshape_kv, (k, v))
k_ip, v_ip = map(_reshape_kv, (k_ip, v_ip))
if use_cache:
self._kv_cache = {'k': k, 'v': v, 'k_ip': k_ip, 'v_ip': v_ip}
elif context.shape[1] == self.agent_state_context_len + self.text_context_len + self.video_length: elif context.shape[1] == self.agent_state_context_len + self.text_context_len + self.video_length:
context_agent_state = context[:, :self.agent_state_context_len, :] context_agent_state = context[:, :self.agent_state_context_len, :]
context_ins = context[:, self.agent_state_context_len:self.agent_state_context_len+self.text_context_len, :] context_ins = context[:, self.agent_state_context_len:self.agent_state_context_len+self.text_context_len, :]
context_image = context[:, self.agent_state_context_len+self.text_context_len:, :] context_image = context[:, self.agent_state_context_len+self.text_context_len:, :]
if self._kv_fused: k = self.to_k(context_ins)
k, v = self.to_kv(context_ins).chunk(2, dim=-1) v = self.to_v(context_ins)
k_ip, v_ip = self.to_kv_ip(context_image).chunk(2, dim=-1) k_ip = self.to_k_ip(context_image)
k_as, v_as = self.to_kv_as(context_agent_state).chunk(2, dim=-1) v_ip = self.to_v_ip(context_image)
else: k_as = self.to_k_as(context_agent_state)
k = self.to_k(context_ins) v_as = self.to_v_as(context_agent_state)
v = self.to_v(context_ins)
k_ip = self.to_k_ip(context_image)
v_ip = self.to_v_ip(context_image)
k_as = self.to_k_as(context_agent_state)
v_as = self.to_v_as(context_agent_state)
k, v = map(_reshape_kv, (k, v))
k_ip, v_ip = map(_reshape_kv, (k_ip, v_ip))
k_as, v_as = map(_reshape_kv, (k_as, v_as))
if use_cache:
self._kv_cache = {'k': k, 'v': v, 'k_ip': k_ip, 'v_ip': v_ip, 'k_as': k_as, 'v_as': v_as}
else: else:
context_agent_state = context[:, :self.agent_state_context_len, :] context_agent_state = context[:, :self.agent_state_context_len, :]
context_agent_action = context[:, self.agent_state_context_len:self.agent_state_context_len+self.agent_action_context_len, :] context_agent_action = context[:, self.agent_state_context_len:self.agent_state_context_len+self.agent_action_context_len, :]
context_ins = context[:, self.agent_state_context_len+self.agent_action_context_len:self.agent_state_context_len+self.agent_action_context_len+self.text_context_len, :] context_ins = context[:, self.agent_state_context_len+self.agent_action_context_len:self.agent_state_context_len+self.agent_action_context_len+self.text_context_len, :]
context_image = context[:, self.agent_state_context_len+self.agent_action_context_len+self.text_context_len:, :] context_image = context[:, self.agent_state_context_len+self.agent_action_context_len+self.text_context_len:, :]
if self._kv_fused: k = self.to_k(context_ins)
k, v = self.to_kv(context_ins).chunk(2, dim=-1) v = self.to_v(context_ins)
k_ip, v_ip = self.to_kv_ip(context_image).chunk(2, dim=-1) k_ip = self.to_k_ip(context_image)
k_as, v_as = self.to_kv_as(context_agent_state).chunk(2, dim=-1) v_ip = self.to_v_ip(context_image)
k_aa, v_aa = self.to_kv_aa(context_agent_action).chunk(2, dim=-1) k_as = self.to_k_as(context_agent_state)
else: v_as = self.to_v_as(context_agent_state)
k = self.to_k(context_ins) k_aa = self.to_k_aa(context_agent_action)
v = self.to_v(context_ins) v_aa = self.to_v_aa(context_agent_action)
k_ip = self.to_k_ip(context_image)
v_ip = self.to_v_ip(context_image)
k_as = self.to_k_as(context_agent_state)
v_as = self.to_v_as(context_agent_state)
k_aa = self.to_k_aa(context_agent_action)
v_aa = self.to_v_aa(context_agent_action)
k, v = map(_reshape_kv, (k, v)) attn_mask_aa = self._get_attn_mask_aa(x.shape[0],
k_ip, v_ip = map(_reshape_kv, (k_ip, v_ip)) q.shape[1],
k_as, v_as = map(_reshape_kv, (k_as, v_as)) k_aa.shape[1],
k_aa, v_aa = map(_reshape_kv, (k_aa, v_aa)) block_size=16,
device=k_aa.device)
attn_mask_aa_raw = self._get_attn_mask_aa(x.shape[0],
q.shape[1],
k_aa.shape[1],
block_size=16,
device=k_aa.device)
attn_mask_aa = attn_mask_aa_raw.unsqueeze(1).repeat(1, h, 1, 1).reshape(
b * h, attn_mask_aa_raw.shape[1], attn_mask_aa_raw.shape[2]).to(q.dtype)
if use_cache:
self._kv_cache = {
'k': k, 'v': v, 'k_ip': k_ip, 'v_ip': v_ip,
'k_as': k_as, 'v_as': v_as, 'k_aa': k_aa, 'v_aa': v_aa,
'attn_mask_aa': attn_mask_aa,
}
else: else:
if not spatial_self_attn: if not spatial_self_attn:
assert 1 > 2, ">>> ERROR: you should never go into here ..." assert 1 > 2, ">>> ERROR: you should never go into here ..."
context = context[:, :self.text_context_len, :] context = context[:, :self.text_context_len, :]
if self._kv_fused: k = self.to_k(context)
k, v = self.to_kv(context).chunk(2, dim=-1) v = self.to_v(context)
else:
k = self.to_k(context) b, _, _ = q.shape
v = self.to_v(context) q = q.unsqueeze(3).reshape(b, q.shape[1], self.heads, self.dim_head).permute(0, 2, 1, 3).reshape(b * self.heads, q.shape[1], self.dim_head).contiguous()
k, v = map(_reshape_kv, (k, v))
if use_cache:
self._kv_cache = {'k': k, 'v': v}
if k is not None: if k is not None:
k, v = map(
lambda t: t.unsqueeze(3).reshape(b, t.shape[
1], self.heads, self.dim_head).permute(0, 2, 1, 3).reshape(
b * self.heads, t.shape[1], self.dim_head).contiguous(),
(k, v),
)
out = xformers.ops.memory_efficient_attention(q, out = xformers.ops.memory_efficient_attention(q,
k, k,
v, v,
attn_bias=None, attn_bias=None,
op=None) op=None)
out = (out.unsqueeze(0).reshape( out = (out.unsqueeze(0).reshape(
b, h, out.shape[1], b, self.heads, out.shape[1],
self.dim_head).permute(0, 2, 1, self.dim_head).permute(0, 2, 1,
3).reshape(b, out.shape[1], 3).reshape(b, out.shape[1],
h * self.dim_head)) self.heads * self.dim_head))
if k_ip is not None: if k_ip is not None:
# For image cross-attention
k_ip, v_ip = map(
lambda t: t.unsqueeze(3).reshape(b, t.shape[
1], self.heads, self.dim_head).permute(0, 2, 1, 3).reshape(
b * self.heads, t.shape[1], self.dim_head).contiguous(
),
(k_ip, v_ip),
)
out_ip = xformers.ops.memory_efficient_attention(q, out_ip = xformers.ops.memory_efficient_attention(q,
k_ip, k_ip,
v_ip, v_ip,
attn_bias=None, attn_bias=None,
op=None) op=None)
out_ip = (out_ip.unsqueeze(0).reshape( out_ip = (out_ip.unsqueeze(0).reshape(
b, h, out_ip.shape[1], b, self.heads, out_ip.shape[1],
self.dim_head).permute(0, 2, 1, self.dim_head).permute(0, 2, 1,
3).reshape(b, out_ip.shape[1], 3).reshape(b, out_ip.shape[1],
h * self.dim_head)) self.heads * self.dim_head))
if k_as is not None: if k_as is not None:
# For agent state cross-attention
k_as, v_as = map(
lambda t: t.unsqueeze(3).reshape(b, t.shape[
1], self.heads, self.dim_head).permute(0, 2, 1, 3).reshape(
b * self.heads, t.shape[1], self.dim_head).contiguous(
),
(k_as, v_as),
)
out_as = xformers.ops.memory_efficient_attention(q, out_as = xformers.ops.memory_efficient_attention(q,
k_as, k_as,
v_as, v_as,
attn_bias=None, attn_bias=None,
op=None) op=None)
out_as = (out_as.unsqueeze(0).reshape( out_as = (out_as.unsqueeze(0).reshape(
b, h, out_as.shape[1], b, self.heads, out_as.shape[1],
self.dim_head).permute(0, 2, 1, self.dim_head).permute(0, 2, 1,
3).reshape(b, out_as.shape[1], 3).reshape(b, out_as.shape[1],
h * self.dim_head)) self.heads * self.dim_head))
if k_aa is not None: if k_aa is not None:
# For agent action cross-attention
k_aa, v_aa = map(
lambda t: t.unsqueeze(3).reshape(b, t.shape[
1], self.heads, self.dim_head).permute(0, 2, 1, 3).reshape(
b * self.heads, t.shape[1], self.dim_head).contiguous(
),
(k_aa, v_aa),
)
attn_mask_aa = attn_mask_aa.unsqueeze(1).repeat(1,self.heads,1,1).reshape(
b * self.heads, attn_mask_aa.shape[1], attn_mask_aa.shape[2])
attn_mask_aa = attn_mask_aa.to(q.dtype)
out_aa = xformers.ops.memory_efficient_attention( out_aa = xformers.ops.memory_efficient_attention(
q, k_aa, v_aa, attn_bias=attn_mask_aa, op=None) q, k_aa, v_aa, attn_bias=attn_mask_aa, op=None)
out_aa = (out_aa.unsqueeze(0).reshape( out_aa = (out_aa.unsqueeze(0).reshape(
b, h, out_aa.shape[1], b, self.heads, out_aa.shape[1],
self.dim_head).permute(0, 2, 1, self.dim_head).permute(0, 2, 1,
3).reshape(b, out_aa.shape[1], 3).reshape(b, out_aa.shape[1],
h * self.dim_head)) self.heads * self.dim_head))
if exists(mask): if exists(mask):
raise NotImplementedError raise NotImplementedError
@@ -463,7 +543,7 @@ class CrossAttention(nn.Module):
col_indices = torch.arange(l2, device=target_device) col_indices = torch.arange(l2, device=target_device)
mask_2d = col_indices.unsqueeze(0) >= start_positions.unsqueeze(1) mask_2d = col_indices.unsqueeze(0) >= start_positions.unsqueeze(1)
mask = mask_2d.unsqueeze(1).expand(b, l1, l2) mask = mask_2d.unsqueeze(1).expand(b, l1, l2)
attn_mask = torch.zeros(b, l1, l2, dtype=torch.float, device=target_device) attn_mask = torch.zeros(b, l1, l2, dtype=torch.bfloat16, device=target_device)
attn_mask[mask] = float('-inf') attn_mask[mask] = float('-inf')
self._attn_mask_aa_cache_key = cache_key self._attn_mask_aa_cache_key = cache_key

View File

@@ -11,7 +11,7 @@ from unifolm_wma.utils.utils import instantiate_from_config
def nonlinearity(x): def nonlinearity(x):
# swish # swish
return x * torch.sigmoid(x) return torch.nn.functional.silu(x)
def Normalize(in_channels, num_groups=32): def Normalize(in_channels, num_groups=32):

View File

@@ -422,7 +422,7 @@ class WMAModel(nn.Module):
self.temporal_attention = temporal_attention self.temporal_attention = temporal_attention
time_embed_dim = model_channels * 4 time_embed_dim = model_channels * 4
self.use_checkpoint = use_checkpoint self.use_checkpoint = use_checkpoint
self.dtype = torch.float16 if use_fp16 else torch.float32 self.dtype = torch.float16 if use_fp16 else torch.bfloat16
temporal_self_att_only = True temporal_self_att_only = True
self.addition_attention = addition_attention self.addition_attention = addition_attention
self.temporal_length = temporal_length self.temporal_length = temporal_length
@@ -688,17 +688,8 @@ class WMAModel(nn.Module):
# Context precomputation cache # Context precomputation cache
self._ctx_cache_enabled = False self._ctx_cache_enabled = False
self._ctx_cache = {} self._ctx_cache = {}
# Reusable CUDA stream for parallel state_unet / action_unet # fs_embed cache
self._state_stream = torch.cuda.Stream() self._fs_embed_cache = None
def __getstate__(self):
state = self.__dict__.copy()
state.pop('_state_stream', None)
return state
def __setstate__(self, state):
self.__dict__.update(state)
self._state_stream = torch.cuda.Stream()
def forward(self, def forward(self,
x: Tensor, x: Tensor,
@@ -800,16 +791,20 @@ class WMAModel(nn.Module):
# Combine emb # Combine emb
if self.fs_condition: if self.fs_condition:
if fs is None: if self._ctx_cache_enabled and self._fs_embed_cache is not None:
fs = torch.tensor([self.default_fs] * b, fs_embed = self._fs_embed_cache
dtype=torch.long, else:
device=x.device) if fs is None:
fs_emb = timestep_embedding(fs, fs = torch.tensor([self.default_fs] * b,
self.model_channels, dtype=torch.long,
repeat_only=False).type(x.dtype) device=x.device)
fs_emb = timestep_embedding(fs,
fs_embed = self.fps_embedding(fs_emb) self.model_channels,
fs_embed = fs_embed.repeat_interleave(repeats=t, dim=0) repeat_only=False).type(x.dtype)
fs_embed = self.fps_embedding(fs_emb)
fs_embed = fs_embed.repeat_interleave(repeats=t, dim=0)
if self._ctx_cache_enabled:
self._fs_embed_cache = fs_embed
emb = emb + fs_embed emb = emb + fs_embed
h = x.type(self.dtype) h = x.type(self.dtype)
@@ -853,16 +848,15 @@ class WMAModel(nn.Module):
if not self.base_model_gen_only: if not self.base_model_gen_only:
ba, _, _ = x_action.shape ba, _, _ = x_action.shape
ts_state = timesteps[:ba] if b > 1 else timesteps
# Run action_unet and state_unet in parallel via CUDA streams
s_stream = self._state_stream
s_stream.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(s_stream):
s_y = self.state_unet(x_state, ts_state, hs_a,
context_action[:2], **kwargs)
a_y = self.action_unet(x_action, timesteps[:ba], hs_a, a_y = self.action_unet(x_action, timesteps[:ba], hs_a,
context_action[:2], **kwargs) context_action[:2], **kwargs)
torch.cuda.current_stream().wait_stream(s_stream) # Predict state
if b > 1:
s_y = self.state_unet(x_state, timesteps[:ba], hs_a,
context_action[:2], **kwargs)
else:
s_y = self.state_unet(x_state, timesteps, hs_a,
context_action[:2], **kwargs)
else: else:
a_y = torch.zeros_like(x_action) a_y = torch.zeros_like(x_action)
s_y = torch.zeros_like(x_state) s_y = torch.zeros_like(x_state)
@@ -876,6 +870,7 @@ def enable_ctx_cache(model):
if isinstance(m, WMAModel): if isinstance(m, WMAModel):
m._ctx_cache_enabled = True m._ctx_cache_enabled = True
m._ctx_cache = {} m._ctx_cache = {}
m._fs_embed_cache = None
# conditional_unet1d cache # conditional_unet1d cache
from unifolm_wma.models.diffusion_head.conditional_unet1d import ConditionalUnet1D from unifolm_wma.models.diffusion_head.conditional_unet1d import ConditionalUnet1D
for m in model.modules(): for m in model.modules():
@@ -890,6 +885,7 @@ def disable_ctx_cache(model):
if isinstance(m, WMAModel): if isinstance(m, WMAModel):
m._ctx_cache_enabled = False m._ctx_cache_enabled = False
m._ctx_cache = {} m._ctx_cache = {}
m._fs_embed_cache = None
from unifolm_wma.models.diffusion_head.conditional_unet1d import ConditionalUnet1D from unifolm_wma.models.diffusion_head.conditional_unet1d import ConditionalUnet1D
for m in model.modules(): for m in model.modules():
if isinstance(m, ConditionalUnet1D): if isinstance(m, ConditionalUnet1D):

View File

@@ -7,7 +7,9 @@
# #
# thanks! # thanks!
import torch
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F
from unifolm_wma.utils.utils import instantiate_from_config from unifolm_wma.utils.utils import instantiate_from_config
@@ -78,7 +80,11 @@ def nonlinearity(type='silu'):
class GroupNormSpecific(nn.GroupNorm): class GroupNormSpecific(nn.GroupNorm):
def forward(self, x): def forward(self, x):
return super().forward(x.float()).type(x.dtype) with torch.amp.autocast('cuda', enabled=False):
return F.group_norm(x, self.num_groups,
self.weight.to(x.dtype) if self.weight is not None else None,
self.bias.to(x.dtype) if self.bias is not None else None,
self.eps)
def normalization(channels, num_groups=32): def normalization(channels, num_groups=32):

View File

@@ -0,0 +1,144 @@
2026-02-08 05:20:49.828675: I tensorflow/core/util/port.cc:113] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2026-02-08 05:20:49.831563: I external/local_tsl/tsl/cuda/cudart_stub.cc:31] Could not find cuda drivers on your machine, GPU will not be used.
2026-02-08 05:20:49.861366: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2026-02-08 05:20:49.861402: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2026-02-08 05:20:49.862974: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2026-02-08 05:20:49.870402: I external/local_tsl/tsl/cuda/cudart_stub.cc:31] Could not find cuda drivers on your machine, GPU will not be used.
2026-02-08 05:20:49.870647: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 AVX512F AVX512_VNNI AVX512_BF16 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
2026-02-08 05:20:50.486843: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT
Global seed set to 123
/mnt/ASC1637/miniconda3/envs/unifolm-wma/lib/python3.10/site-packages/kornia/feature/lightglue.py:44: FutureWarning: `torch.cuda.amp.custom_fwd(args...)` is deprecated. Please use `torch.amp.custom_fwd(args..., device_type='cuda')` instead.
@torch.cuda.amp.custom_fwd(cast_inputs=torch.float32)
INFO:mainlogger:LatentVisualDiffusion: Running in v-prediction mode
INFO:unifolm_wma.models.diffusion_head.conditional_unet1d:number of parameters: 5.010531e+08
INFO:unifolm_wma.models.diffusion_head.conditional_unet1d:number of parameters: 5.010531e+08
AE working on z of shape (1, 4, 32, 32) = 4096 dimensions.
INFO:root:Loaded ViT-H-14 model config.
DEBUG:urllib3.connectionpool:Starting new HTTPS connection (1): hf-mirror.com:443
DEBUG:urllib3.connectionpool:https://hf-mirror.com:443 "HEAD /laion/CLIP-ViT-H-14-laion2B-s32B-b79K/resolve/main/open_clip_pytorch_model.bin HTTP/1.1" 302 0
INFO:root:Loading pretrained ViT-H-14 weights (laion2b_s32b_b79k).
/mnt/ASC1637/miniconda3/envs/unifolm-wma/lib/python3.10/site-packages/open_clip/factory.py:88: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.
checkpoint = torch.load(checkpoint_path, map_location=map_location)
INFO:root:Loaded ViT-H-14 model config.
DEBUG:urllib3.connectionpool:https://hf-mirror.com:443 "HEAD /laion/CLIP-ViT-H-14-laion2B-s32B-b79K/resolve/main/open_clip_pytorch_model.bin HTTP/1.1" 302 0
INFO:root:Loading pretrained ViT-H-14 weights (laion2b_s32b_b79k).
/mnt/ASC1637/unifolm-world-model-action/scripts/evaluation/world_model_interaction.py:86: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.
state_dict = torch.load(ckpt, map_location="cpu")
>>> model checkpoint loaded.
>>> Load pre-trained model ...
INFO:root:***** Configing Data *****
>>> unitree_z1_stackbox: 1 data samples loaded.
>>> unitree_z1_stackbox: data stats loaded.
>>> unitree_z1_stackbox: normalizer initiated.
>>> unitree_z1_dual_arm_stackbox: 1 data samples loaded.
>>> unitree_z1_dual_arm_stackbox: data stats loaded.
>>> unitree_z1_dual_arm_stackbox: normalizer initiated.
>>> unitree_z1_dual_arm_stackbox_v2: 1 data samples loaded.
>>> unitree_z1_dual_arm_stackbox_v2: data stats loaded.
>>> unitree_z1_dual_arm_stackbox_v2: normalizer initiated.
>>> unitree_z1_dual_arm_cleanup_pencils: 1 data samples loaded.
>>> unitree_z1_dual_arm_cleanup_pencils: data stats loaded.
>>> unitree_z1_dual_arm_cleanup_pencils: normalizer initiated.
>>> unitree_g1_pack_camera: 1 data samples loaded.
>>> unitree_g1_pack_camera: data stats loaded.
>>> unitree_g1_pack_camera: normalizer initiated.
>>> Dataset is successfully loaded ...
>>> Generate 16 frames under each generation ...
DEBUG:h5py._conv:Creating converter from 3 to 5
DEBUG:PIL.PngImagePlugin:STREAM b'IHDR' 16 13
DEBUG:PIL.PngImagePlugin:STREAM b'pHYs' 41 9
DEBUG:PIL.PngImagePlugin:STREAM b'IDAT' 62 4096
0%| | 0/11 [00:00<?, ?it/s]/mnt/ASC1637/miniconda3/envs/unifolm-wma/lib/python3.10/site-packages/torch/nn/functional.py:5501: UserWarning: Attempting to use hipBLASLt on an unsupported architecture! Overriding blas backend to hipblas (Triggered internally at ../aten/src/ATen/Context.cpp:296.)
proj = linear(q, w, b)
/mnt/ASC1637/miniconda3/envs/unifolm-wma/lib/python3.10/site-packages/torch/nn/functional.py:6278: UserWarning: Flash attention support on Navi31 GPU is still experimental. Enable it with TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL=1. (Triggered internally at ../aten/src/ATen/native/transformers/hip/sdp_utils.cpp:225.)
attn_output = scaled_dot_product_attention(
/mnt/ASC1637/miniconda3/envs/unifolm-wma/lib/python3.10/site-packages/torch/nn/functional.py:6278: UserWarning: Memory Efficient attention on Navi31 GPU is still experimental. Enable it with TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL=1. (Triggered internally at ../aten/src/ATen/native/transformers/hip/sdp_utils.cpp:269.)
attn_output = scaled_dot_product_attention(
>>> Step 0: generating actions ...
>>> Step 0: interacting with world model ...
DEBUG:PIL.Image:Importing BlpImagePlugin
DEBUG:PIL.Image:Importing BmpImagePlugin
DEBUG:PIL.Image:Importing BufrStubImagePlugin
DEBUG:PIL.Image:Importing CurImagePlugin
DEBUG:PIL.Image:Importing DcxImagePlugin
DEBUG:PIL.Image:Importing DdsImagePlugin
DEBUG:PIL.Image:Importing EpsImagePlugin
DEBUG:PIL.Image:Importing FitsImagePlugin
DEBUG:PIL.Image:Importing FitsStubImagePlugin
DEBUG:PIL.Image:Importing FliImagePlugin
DEBUG:PIL.Image:Importing FpxImagePlugin
DEBUG:PIL.Image:Image: failed to import FpxImagePlugin: No module named 'olefile'
DEBUG:PIL.Image:Importing FtexImagePlugin
DEBUG:PIL.Image:Importing GbrImagePlugin
DEBUG:PIL.Image:Importing GifImagePlugin
DEBUG:PIL.Image:Importing GribStubImagePlugin
DEBUG:PIL.Image:Importing Hdf5StubImagePlugin
DEBUG:PIL.Image:Importing IcnsImagePlugin
DEBUG:PIL.Image:Importing IcoImagePlugin
DEBUG:PIL.Image:Importing ImImagePlugin
DEBUG:PIL.Image:Importing ImtImagePlugin
DEBUG:PIL.Image:Importing IptcImagePlugin
DEBUG:PIL.Image:Importing JpegImagePlugin
DEBUG:PIL.Image:Importing Jpeg2KImagePlugin
DEBUG:PIL.Image:Importing McIdasImagePlugin
DEBUG:PIL.Image:Importing MicImagePlugin
DEBUG:PIL.Image:Image: failed to import MicImagePlugin: No module named 'olefile'
DEBUG:PIL.Image:Importing MpegImagePlugin
DEBUG:PIL.Image:Importing MpoImagePlugin
DEBUG:PIL.Image:Importing MspImagePlugin
DEBUG:PIL.Image:Importing PalmImagePlugin
DEBUG:PIL.Image:Importing PcdImagePlugin
DEBUG:PIL.Image:Importing PcxImagePlugin
DEBUG:PIL.Image:Importing PdfImagePlugin
DEBUG:PIL.Image:Importing PixarImagePlugin
DEBUG:PIL.Image:Importing PngImagePlugin
DEBUG:PIL.Image:Importing PpmImagePlugin
DEBUG:PIL.Image:Importing PsdImagePlugin
DEBUG:PIL.Image:Importing QoiImagePlugin
DEBUG:PIL.Image:Importing SgiImagePlugin
DEBUG:PIL.Image:Importing SpiderImagePlugin
DEBUG:PIL.Image:Importing SunImagePlugin
DEBUG:PIL.Image:Importing TgaImagePlugin
DEBUG:PIL.Image:Importing TiffImagePlugin
DEBUG:PIL.Image:Importing WebPImagePlugin
DEBUG:PIL.Image:Importing WmfImagePlugin
DEBUG:PIL.Image:Importing XbmImagePlugin
DEBUG:PIL.Image:Importing XpmImagePlugin
DEBUG:PIL.Image:Importing XVThumbImagePlugin
9%|▉ | 1/11 [01:38<16:25, 98.56s/it]
18%|█▊ | 2/11 [03:16<14:44, 98.31s/it]
27%|██▋ | 3/11 [04:55<13:06, 98.33s/it]
36%|███▋ | 4/11 [06:36<11:37, 99.66s/it]
45%|████▌ | 5/11 [08:31<10:29, 104.96s/it]
55%|█████▍ | 6/11 [10:10<08:35, 103.07s/it]
64%|██████▎ | 7/11 [11:48<06:46, 101.50s/it]
73%|███████▎ | 8/11 [13:27<05:01, 100.52s/it]
82%|████████▏ | 9/11 [15:05<03:19, 99.79s/it]
91%|█████████ | 10/11 [16:43<01:39, 99.30s/it]
100%|██████████| 11/11 [18:21<00:00, 98.97s/it]
100%|██████████| 11/11 [18:21<00:00, 100.16s/it]
>>>>>>>>>>>>>>>>>>>>>>>>
>>> Step 1: generating actions ...
>>> Step 1: interacting with world model ...
>>>>>>>>>>>>>>>>>>>>>>>>
>>> Step 2: generating actions ...
>>> Step 2: interacting with world model ...
>>>>>>>>>>>>>>>>>>>>>>>>
>>> Step 3: generating actions ...
>>> Step 3: interacting with world model ...
>>>>>>>>>>>>>>>>>>>>>>>>
>>> Step 4: generating actions ...
>>> Step 4: interacting with world model ...
>>>>>>>>>>>>>>>>>>>>>>>>
>>> Step 5: generating actions ...
>>> Step 5: interacting with world model ...
>>>>>>>>>>>>>>>>>>>>>>>>
>>> Step 6: generating actions ...
>>> Step 6: interacting with world model ...
>>>>>>>>>>>>>>>>>>>>>>>>
>>> Step 7: generating actions ...
>>> Step 7: interacting with world model ...
>>>>>>>>>>>>>>>>>>>>>>>>

View File

@@ -0,0 +1,5 @@
{
"gt_video": "unitree_g1_pack_camera/case1/unitree_g1_pack_camera_case1.mp4",
"pred_video": "unitree_g1_pack_camera/case1/output/inference/unitree_g1_pack_camera_case1_amd.mp4",
"psnr": 16.415668383379177
}

View File

@@ -0,0 +1,158 @@
/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/lightning_fabric/__init__.py:29: UserWarning: pkg_resources is deprecated as an API. See https://setuptools.pypa.io/en/latest/pkg_resources.html. The pkg_resources package is slated for removal as early as 2025-11-30. Refrain from using this package or pin to Setuptools<81.
__import__("pkg_resources").declare_namespace(__name__)
2026-02-08 18:28:48.960238: I tensorflow/core/util/port.cc:113] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2026-02-08 18:28:48.963331: I external/local_tsl/tsl/cuda/cudart_stub.cc:31] Could not find cuda drivers on your machine, GPU will not be used.
2026-02-08 18:28:48.995688: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2026-02-08 18:28:48.995732: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2026-02-08 18:28:48.997547: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2026-02-08 18:28:49.005673: I external/local_tsl/tsl/cuda/cudart_stub.cc:31] Could not find cuda drivers on your machine, GPU will not be used.
2026-02-08 18:28:49.005948: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 AVX512F AVX512_VNNI AVX512_BF16 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
2026-02-08 18:28:50.009660: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT
[rank: 0] Global seed set to 123
/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/kornia/feature/lightglue.py:44: FutureWarning: `torch.cuda.amp.custom_fwd(args...)` is deprecated. Please use `torch.amp.custom_fwd(args..., device_type='cuda')` instead.
@torch.cuda.amp.custom_fwd(cast_inputs=torch.float32)
INFO:mainlogger:LatentVisualDiffusion: Running in v-prediction mode
INFO:unifolm_wma.models.diffusion_head.conditional_unet1d:number of parameters: 5.010531e+08
INFO:unifolm_wma.models.diffusion_head.conditional_unet1d:number of parameters: 5.010531e+08
AE working on z of shape (1, 4, 32, 32) = 4096 dimensions.
INFO:root:Loaded ViT-H-14 model config.
DEBUG:urllib3.connectionpool:Starting new HTTPS connection (1): hf-mirror.com:443
DEBUG:urllib3.connectionpool:https://hf-mirror.com:443 "HEAD /laion/CLIP-ViT-H-14-laion2B-s32B-b79K/resolve/main/open_clip_pytorch_model.bin HTTP/1.1" 302 0
INFO:root:Loading pretrained ViT-H-14 weights (laion2b_s32b_b79k).
/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/open_clip/factory.py:88: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.
checkpoint = torch.load(checkpoint_path, map_location=map_location)
INFO:root:Loaded ViT-H-14 model config.
DEBUG:urllib3.connectionpool:https://hf-mirror.com:443 "HEAD /laion/CLIP-ViT-H-14-laion2B-s32B-b79K/resolve/main/open_clip_pytorch_model.bin HTTP/1.1" 302 0
INFO:root:Loading pretrained ViT-H-14 weights (laion2b_s32b_b79k).
/mnt/ASC1637/unifolm-world-model-action/scripts/evaluation/world_model_interaction.py:198: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.
state_dict = torch.load(ckpt, map_location="cpu")
>>> model checkpoint loaded.
>>> Load pre-trained model ...
>>> Applying precision settings:
- Diffusion dtype: bf16
- Projector mode: bf16_full
- Encoder mode: bf16_full
- VAE dtype: fp32
✓ Diffusion model weights converted to bfloat16
✓ Projectors converted to bfloat16
✓ Encoders converted to bfloat16
✓ VAE kept in fp32 for best quality
⚠ Found 849 fp32 params, converting to bf16
✓ All parameters converted to bfloat16
✓ torch.compile: 3 ResBlocks in output_blocks[5, 8, 9]
INFO:root:***** Configing Data *****
>>> unitree_z1_stackbox: 1 data samples loaded.
>>> unitree_z1_stackbox: data stats loaded.
>>> unitree_z1_stackbox: normalizer initiated.
>>> unitree_z1_dual_arm_stackbox: 1 data samples loaded.
>>> unitree_z1_dual_arm_stackbox: data stats loaded.
>>> unitree_z1_dual_arm_stackbox: normalizer initiated.
>>> unitree_z1_dual_arm_stackbox_v2: 1 data samples loaded.
>>> unitree_z1_dual_arm_stackbox_v2: data stats loaded.
>>> unitree_z1_dual_arm_stackbox_v2: normalizer initiated.
>>> unitree_z1_dual_arm_cleanup_pencils: 1 data samples loaded.
>>> unitree_z1_dual_arm_cleanup_pencils: data stats loaded.
>>> unitree_z1_dual_arm_cleanup_pencils: normalizer initiated.
>>> unitree_g1_pack_camera: 1 data samples loaded.
>>> unitree_g1_pack_camera: data stats loaded.
>>> unitree_g1_pack_camera: normalizer initiated.
>>> Dataset is successfully loaded ...
>>> Generate 16 frames under each generation ...
DEBUG:h5py._conv:Creating converter from 3 to 5
DEBUG:PIL.PngImagePlugin:STREAM b'IHDR' 16 13
DEBUG:PIL.PngImagePlugin:STREAM b'pHYs' 41 9
DEBUG:PIL.PngImagePlugin:STREAM b'IDAT' 62 4096
0%| | 0/11 [00:00<?, ?it/s]/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/torch/nn/functional.py:5501: UserWarning: Attempting to use hipBLASLt on an unsupported architecture! Overriding blas backend to hipblas (Triggered internally at ../aten/src/ATen/Context.cpp:296.)
proj = linear(q, w, b)
/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/torch/nn/functional.py:6278: UserWarning: Flash attention support on Navi31 GPU is still experimental. Enable it with TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL=1. (Triggered internally at ../aten/src/ATen/native/transformers/hip/sdp_utils.cpp:225.)
attn_output = scaled_dot_product_attention(
/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/torch/nn/functional.py:6278: UserWarning: Memory Efficient attention on Navi31 GPU is still experimental. Enable it with TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL=1. (Triggered internally at ../aten/src/ATen/native/transformers/hip/sdp_utils.cpp:269.)
attn_output = scaled_dot_product_attention(
>>> Step 0: generating actions ...
>>> Step 0: interacting with world model ...
DEBUG:PIL.Image:Importing BlpImagePlugin
DEBUG:PIL.Image:Importing BmpImagePlugin
DEBUG:PIL.Image:Importing BufrStubImagePlugin
DEBUG:PIL.Image:Importing CurImagePlugin
DEBUG:PIL.Image:Importing DcxImagePlugin
DEBUG:PIL.Image:Importing DdsImagePlugin
DEBUG:PIL.Image:Importing EpsImagePlugin
DEBUG:PIL.Image:Importing FitsImagePlugin
DEBUG:PIL.Image:Importing FitsStubImagePlugin
DEBUG:PIL.Image:Importing FliImagePlugin
DEBUG:PIL.Image:Importing FpxImagePlugin
DEBUG:PIL.Image:Image: failed to import FpxImagePlugin: No module named 'olefile'
DEBUG:PIL.Image:Importing FtexImagePlugin
DEBUG:PIL.Image:Importing GbrImagePlugin
DEBUG:PIL.Image:Importing GifImagePlugin
DEBUG:PIL.Image:Importing GribStubImagePlugin
DEBUG:PIL.Image:Importing Hdf5StubImagePlugin
DEBUG:PIL.Image:Importing IcnsImagePlugin
DEBUG:PIL.Image:Importing IcoImagePlugin
DEBUG:PIL.Image:Importing ImImagePlugin
DEBUG:PIL.Image:Importing ImtImagePlugin
DEBUG:PIL.Image:Importing IptcImagePlugin
DEBUG:PIL.Image:Importing JpegImagePlugin
DEBUG:PIL.Image:Importing Jpeg2KImagePlugin
DEBUG:PIL.Image:Importing McIdasImagePlugin
DEBUG:PIL.Image:Importing MicImagePlugin
DEBUG:PIL.Image:Image: failed to import MicImagePlugin: No module named 'olefile'
DEBUG:PIL.Image:Importing MpegImagePlugin
DEBUG:PIL.Image:Importing MpoImagePlugin
DEBUG:PIL.Image:Importing MspImagePlugin
DEBUG:PIL.Image:Importing PalmImagePlugin
DEBUG:PIL.Image:Importing PcdImagePlugin
DEBUG:PIL.Image:Importing PcxImagePlugin
DEBUG:PIL.Image:Importing PdfImagePlugin
DEBUG:PIL.Image:Importing PixarImagePlugin
DEBUG:PIL.Image:Importing PngImagePlugin
DEBUG:PIL.Image:Importing PpmImagePlugin
DEBUG:PIL.Image:Importing PsdImagePlugin
DEBUG:PIL.Image:Importing QoiImagePlugin
DEBUG:PIL.Image:Importing SgiImagePlugin
DEBUG:PIL.Image:Importing SpiderImagePlugin
DEBUG:PIL.Image:Importing SunImagePlugin
DEBUG:PIL.Image:Importing TgaImagePlugin
DEBUG:PIL.Image:Importing TiffImagePlugin
DEBUG:PIL.Image:Importing WebPImagePlugin
DEBUG:PIL.Image:Importing WmfImagePlugin
DEBUG:PIL.Image:Importing XbmImagePlugin
DEBUG:PIL.Image:Importing XpmImagePlugin
DEBUG:PIL.Image:Importing XVThumbImagePlugin
9%|▉ | 1/11 [01:14<12:29, 74.95s/it]
18%|█▊ | 2/11 [02:23<10:40, 71.18s/it]
27%|██▋ | 3/11 [03:32<09:20, 70.05s/it]
36%|███▋ | 4/11 [04:40<08:06, 69.51s/it]
45%|████▌ | 5/11 [05:49<06:55, 69.19s/it]
55%|█████▍ | 6/11 [06:57<05:44, 68.95s/it]
64%|██████▎ | 7/11 [08:06<04:35, 68.79s/it]
73%|███████▎ | 8/11 [09:14<03:26, 68.70s/it]
82%|████████▏ | 9/11 [10:23<02:17, 68.65s/it]
91%|█████████ | 10/11 [11:31<01:08, 68.58s/it]
100%|██████████| 11/11 [12:40<00:00, 68.51s/it]
100%|██████████| 11/11 [12:40<00:00, 69.11s/it]
>>>>>>>>>>>>>>>>>>>>>>>>
>>> Step 1: generating actions ...
>>> Step 1: interacting with world model ...
>>>>>>>>>>>>>>>>>>>>>>>>
>>> Step 2: generating actions ...
>>> Step 2: interacting with world model ...
>>>>>>>>>>>>>>>>>>>>>>>>
>>> Step 3: generating actions ...
>>> Step 3: interacting with world model ...
>>>>>>>>>>>>>>>>>>>>>>>>
>>> Step 4: generating actions ...
>>> Step 4: interacting with world model ...
>>>>>>>>>>>>>>>>>>>>>>>>
>>> Step 5: generating actions ...
>>> Step 5: interacting with world model ...
>>>>>>>>>>>>>>>>>>>>>>>>
>>> Step 6: generating actions ...
>>> Step 6: interacting with world model ...
>>>>>>>>>>>>>>>>>>>>>>>>
>>> Step 7: generating actions ...
>>> Step 7: interacting with world model ...
>>>>>>>>>>>>>>>>>>>>>>>>

View File

@@ -0,0 +1,5 @@
{
"gt_video": "unitree_g1_pack_camera/case2/unitree_g1_pack_camera_case2.mp4",
"pred_video": "unitree_g1_pack_camera/case2/output/inference/unitree_g1_pack_camera_case2_amd.mp4",
"psnr": 19.515250190529375
}

View File

@@ -0,0 +1,144 @@
2026-02-08 05:08:32.803904: I tensorflow/core/util/port.cc:113] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2026-02-08 05:08:32.807010: I external/local_tsl/tsl/cuda/cudart_stub.cc:31] Could not find cuda drivers on your machine, GPU will not be used.
2026-02-08 05:08:32.837936: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2026-02-08 05:08:32.837978: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2026-02-08 05:08:32.839785: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2026-02-08 05:08:32.847835: I external/local_tsl/tsl/cuda/cudart_stub.cc:31] Could not find cuda drivers on your machine, GPU will not be used.
2026-02-08 05:08:32.848223: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 AVX512F AVX512_VNNI AVX512_BF16 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
2026-02-08 05:08:34.120114: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT
[rank: 0] Global seed set to 123
/mnt/ASC1637/miniconda3/envs/unifolm-wma/lib/python3.10/site-packages/kornia/feature/lightglue.py:44: FutureWarning: `torch.cuda.amp.custom_fwd(args...)` is deprecated. Please use `torch.amp.custom_fwd(args..., device_type='cuda')` instead.
@torch.cuda.amp.custom_fwd(cast_inputs=torch.float32)
INFO:mainlogger:LatentVisualDiffusion: Running in v-prediction mode
INFO:unifolm_wma.models.diffusion_head.conditional_unet1d:number of parameters: 5.010531e+08
INFO:unifolm_wma.models.diffusion_head.conditional_unet1d:number of parameters: 5.010531e+08
AE working on z of shape (1, 4, 32, 32) = 4096 dimensions.
INFO:root:Loaded ViT-H-14 model config.
DEBUG:urllib3.connectionpool:Starting new HTTPS connection (1): hf-mirror.com:443
DEBUG:urllib3.connectionpool:https://hf-mirror.com:443 "HEAD /laion/CLIP-ViT-H-14-laion2B-s32B-b79K/resolve/main/open_clip_pytorch_model.bin HTTP/1.1" 302 0
INFO:root:Loading pretrained ViT-H-14 weights (laion2b_s32b_b79k).
/mnt/ASC1637/miniconda3/envs/unifolm-wma/lib/python3.10/site-packages/open_clip/factory.py:88: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.
checkpoint = torch.load(checkpoint_path, map_location=map_location)
INFO:root:Loaded ViT-H-14 model config.
DEBUG:urllib3.connectionpool:https://hf-mirror.com:443 "HEAD /laion/CLIP-ViT-H-14-laion2B-s32B-b79K/resolve/main/open_clip_pytorch_model.bin HTTP/1.1" 302 0
INFO:root:Loading pretrained ViT-H-14 weights (laion2b_s32b_b79k).
/mnt/ASC1637/unifolm-world-model-action/scripts/evaluation/world_model_interaction.py:86: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.
state_dict = torch.load(ckpt, map_location="cpu")
>>> model checkpoint loaded.
>>> Load pre-trained model ...
INFO:root:***** Configing Data *****
>>> unitree_z1_stackbox: 1 data samples loaded.
>>> unitree_z1_stackbox: data stats loaded.
>>> unitree_z1_stackbox: normalizer initiated.
>>> unitree_z1_dual_arm_stackbox: 1 data samples loaded.
>>> unitree_z1_dual_arm_stackbox: data stats loaded.
>>> unitree_z1_dual_arm_stackbox: normalizer initiated.
>>> unitree_z1_dual_arm_stackbox_v2: 1 data samples loaded.
>>> unitree_z1_dual_arm_stackbox_v2: data stats loaded.
>>> unitree_z1_dual_arm_stackbox_v2: normalizer initiated.
>>> unitree_z1_dual_arm_cleanup_pencils: 1 data samples loaded.
>>> unitree_z1_dual_arm_cleanup_pencils: data stats loaded.
>>> unitree_z1_dual_arm_cleanup_pencils: normalizer initiated.
>>> unitree_g1_pack_camera: 1 data samples loaded.
>>> unitree_g1_pack_camera: data stats loaded.
>>> unitree_g1_pack_camera: normalizer initiated.
>>> Dataset is successfully loaded ...
>>> Generate 16 frames under each generation ...
DEBUG:h5py._conv:Creating converter from 3 to 5
DEBUG:PIL.PngImagePlugin:STREAM b'IHDR' 16 13
DEBUG:PIL.PngImagePlugin:STREAM b'pHYs' 41 9
DEBUG:PIL.PngImagePlugin:STREAM b'IDAT' 62 4096
0%| | 0/11 [00:00<?, ?it/s]/mnt/ASC1637/miniconda3/envs/unifolm-wma/lib/python3.10/site-packages/torch/nn/functional.py:5501: UserWarning: Attempting to use hipBLASLt on an unsupported architecture! Overriding blas backend to hipblas (Triggered internally at ../aten/src/ATen/Context.cpp:296.)
proj = linear(q, w, b)
/mnt/ASC1637/miniconda3/envs/unifolm-wma/lib/python3.10/site-packages/torch/nn/functional.py:6278: UserWarning: Flash attention support on Navi31 GPU is still experimental. Enable it with TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL=1. (Triggered internally at ../aten/src/ATen/native/transformers/hip/sdp_utils.cpp:225.)
attn_output = scaled_dot_product_attention(
/mnt/ASC1637/miniconda3/envs/unifolm-wma/lib/python3.10/site-packages/torch/nn/functional.py:6278: UserWarning: Memory Efficient attention on Navi31 GPU is still experimental. Enable it with TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL=1. (Triggered internally at ../aten/src/ATen/native/transformers/hip/sdp_utils.cpp:269.)
attn_output = scaled_dot_product_attention(
>>> Step 0: generating actions ...
>>> Step 0: interacting with world model ...
DEBUG:PIL.Image:Importing BlpImagePlugin
DEBUG:PIL.Image:Importing BmpImagePlugin
DEBUG:PIL.Image:Importing BufrStubImagePlugin
DEBUG:PIL.Image:Importing CurImagePlugin
DEBUG:PIL.Image:Importing DcxImagePlugin
DEBUG:PIL.Image:Importing DdsImagePlugin
DEBUG:PIL.Image:Importing EpsImagePlugin
DEBUG:PIL.Image:Importing FitsImagePlugin
DEBUG:PIL.Image:Importing FitsStubImagePlugin
DEBUG:PIL.Image:Importing FliImagePlugin
DEBUG:PIL.Image:Importing FpxImagePlugin
DEBUG:PIL.Image:Image: failed to import FpxImagePlugin: No module named 'olefile'
DEBUG:PIL.Image:Importing FtexImagePlugin
DEBUG:PIL.Image:Importing GbrImagePlugin
DEBUG:PIL.Image:Importing GifImagePlugin
DEBUG:PIL.Image:Importing GribStubImagePlugin
DEBUG:PIL.Image:Importing Hdf5StubImagePlugin
DEBUG:PIL.Image:Importing IcnsImagePlugin
DEBUG:PIL.Image:Importing IcoImagePlugin
DEBUG:PIL.Image:Importing ImImagePlugin
DEBUG:PIL.Image:Importing ImtImagePlugin
DEBUG:PIL.Image:Importing IptcImagePlugin
DEBUG:PIL.Image:Importing JpegImagePlugin
DEBUG:PIL.Image:Importing Jpeg2KImagePlugin
DEBUG:PIL.Image:Importing McIdasImagePlugin
DEBUG:PIL.Image:Importing MicImagePlugin
DEBUG:PIL.Image:Image: failed to import MicImagePlugin: No module named 'olefile'
DEBUG:PIL.Image:Importing MpegImagePlugin
DEBUG:PIL.Image:Importing MpoImagePlugin
DEBUG:PIL.Image:Importing MspImagePlugin
DEBUG:PIL.Image:Importing PalmImagePlugin
DEBUG:PIL.Image:Importing PcdImagePlugin
DEBUG:PIL.Image:Importing PcxImagePlugin
DEBUG:PIL.Image:Importing PdfImagePlugin
DEBUG:PIL.Image:Importing PixarImagePlugin
DEBUG:PIL.Image:Importing PngImagePlugin
DEBUG:PIL.Image:Importing PpmImagePlugin
DEBUG:PIL.Image:Importing PsdImagePlugin
DEBUG:PIL.Image:Importing QoiImagePlugin
DEBUG:PIL.Image:Importing SgiImagePlugin
DEBUG:PIL.Image:Importing SpiderImagePlugin
DEBUG:PIL.Image:Importing SunImagePlugin
DEBUG:PIL.Image:Importing TgaImagePlugin
DEBUG:PIL.Image:Importing TiffImagePlugin
DEBUG:PIL.Image:Importing WebPImagePlugin
DEBUG:PIL.Image:Importing WmfImagePlugin
DEBUG:PIL.Image:Importing XbmImagePlugin
DEBUG:PIL.Image:Importing XpmImagePlugin
DEBUG:PIL.Image:Importing XVThumbImagePlugin
9%|▉ | 1/11 [01:39<16:34, 99.46s/it]
18%|█▊ | 2/11 [03:18<14:55, 99.48s/it]
27%|██▋ | 3/11 [04:58<13:16, 99.60s/it]
36%|███▋ | 4/11 [06:38<11:37, 99.69s/it]
45%|████▌ | 5/11 [08:18<09:58, 99.68s/it]
55%|█████▍ | 6/11 [09:57<08:18, 99.66s/it]
64%|██████▎ | 7/11 [11:37<06:38, 99.62s/it]
73%|███████▎ | 8/11 [13:16<04:58, 99.55s/it]
82%|████████▏ | 9/11 [14:56<03:19, 99.50s/it]
91%|█████████ | 10/11 [16:35<01:39, 99.43s/it]
100%|██████████| 11/11 [18:14<00:00, 99.36s/it]
100%|██████████| 11/11 [18:14<00:00, 99.51s/it]
>>>>>>>>>>>>>>>>>>>>>>>>
>>> Step 1: generating actions ...
>>> Step 1: interacting with world model ...
>>>>>>>>>>>>>>>>>>>>>>>>
>>> Step 2: generating actions ...
>>> Step 2: interacting with world model ...
>>>>>>>>>>>>>>>>>>>>>>>>
>>> Step 3: generating actions ...
>>> Step 3: interacting with world model ...
>>>>>>>>>>>>>>>>>>>>>>>>
>>> Step 4: generating actions ...
>>> Step 4: interacting with world model ...
>>>>>>>>>>>>>>>>>>>>>>>>
>>> Step 5: generating actions ...
>>> Step 5: interacting with world model ...
>>>>>>>>>>>>>>>>>>>>>>>>
>>> Step 6: generating actions ...
>>> Step 6: interacting with world model ...
>>>>>>>>>>>>>>>>>>>>>>>>
>>> Step 7: generating actions ...
>>> Step 7: interacting with world model ...
>>>>>>>>>>>>>>>>>>>>>>>>

View File

@@ -0,0 +1,5 @@
{
"gt_video": "unitree_g1_pack_camera/case3/unitree_g1_pack_camera_case3.mp4",
"pred_video": "unitree_g1_pack_camera/case3/output/inference/unitree_g1_pack_camera_case3_amd.mp4",
"psnr": 19.429578160315536
}

View File

@@ -0,0 +1,144 @@
2026-02-08 05:29:19.728303: I tensorflow/core/util/port.cc:113] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2026-02-08 05:29:19.731620: I external/local_tsl/tsl/cuda/cudart_stub.cc:31] Could not find cuda drivers on your machine, GPU will not be used.
2026-02-08 05:29:19.761276: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2026-02-08 05:29:19.761301: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2026-02-08 05:29:19.762880: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2026-02-08 05:29:19.770578: I external/local_tsl/tsl/cuda/cudart_stub.cc:31] Could not find cuda drivers on your machine, GPU will not be used.
2026-02-08 05:29:19.771072: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 AVX512F AVX512_VNNI AVX512_BF16 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
2026-02-08 05:29:21.043661: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT
Global seed set to 123
/mnt/ASC1637/miniconda3/envs/unifolm-wma/lib/python3.10/site-packages/kornia/feature/lightglue.py:44: FutureWarning: `torch.cuda.amp.custom_fwd(args...)` is deprecated. Please use `torch.amp.custom_fwd(args..., device_type='cuda')` instead.
@torch.cuda.amp.custom_fwd(cast_inputs=torch.float32)
INFO:mainlogger:LatentVisualDiffusion: Running in v-prediction mode
INFO:unifolm_wma.models.diffusion_head.conditional_unet1d:number of parameters: 5.010531e+08
INFO:unifolm_wma.models.diffusion_head.conditional_unet1d:number of parameters: 5.010531e+08
AE working on z of shape (1, 4, 32, 32) = 4096 dimensions.
INFO:root:Loaded ViT-H-14 model config.
DEBUG:urllib3.connectionpool:Starting new HTTPS connection (1): hf-mirror.com:443
DEBUG:urllib3.connectionpool:https://hf-mirror.com:443 "HEAD /laion/CLIP-ViT-H-14-laion2B-s32B-b79K/resolve/main/open_clip_pytorch_model.bin HTTP/1.1" 302 0
INFO:root:Loading pretrained ViT-H-14 weights (laion2b_s32b_b79k).
/mnt/ASC1637/miniconda3/envs/unifolm-wma/lib/python3.10/site-packages/open_clip/factory.py:88: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.
checkpoint = torch.load(checkpoint_path, map_location=map_location)
INFO:root:Loaded ViT-H-14 model config.
DEBUG:urllib3.connectionpool:https://hf-mirror.com:443 "HEAD /laion/CLIP-ViT-H-14-laion2B-s32B-b79K/resolve/main/open_clip_pytorch_model.bin HTTP/1.1" 302 0
INFO:root:Loading pretrained ViT-H-14 weights (laion2b_s32b_b79k).
/mnt/ASC1637/unifolm-world-model-action/scripts/evaluation/world_model_interaction.py:86: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.
state_dict = torch.load(ckpt, map_location="cpu")
>>> model checkpoint loaded.
>>> Load pre-trained model ...
INFO:root:***** Configing Data *****
>>> unitree_z1_stackbox: 1 data samples loaded.
>>> unitree_z1_stackbox: data stats loaded.
>>> unitree_z1_stackbox: normalizer initiated.
>>> unitree_z1_dual_arm_stackbox: 1 data samples loaded.
>>> unitree_z1_dual_arm_stackbox: data stats loaded.
>>> unitree_z1_dual_arm_stackbox: normalizer initiated.
>>> unitree_z1_dual_arm_stackbox_v2: 1 data samples loaded.
>>> unitree_z1_dual_arm_stackbox_v2: data stats loaded.
>>> unitree_z1_dual_arm_stackbox_v2: normalizer initiated.
>>> unitree_z1_dual_arm_cleanup_pencils: 1 data samples loaded.
>>> unitree_z1_dual_arm_cleanup_pencils: data stats loaded.
>>> unitree_z1_dual_arm_cleanup_pencils: normalizer initiated.
>>> unitree_g1_pack_camera: 1 data samples loaded.
>>> unitree_g1_pack_camera: data stats loaded.
>>> unitree_g1_pack_camera: normalizer initiated.
>>> Dataset is successfully loaded ...
>>> Generate 16 frames under each generation ...
DEBUG:h5py._conv:Creating converter from 3 to 5
DEBUG:PIL.PngImagePlugin:STREAM b'IHDR' 16 13
DEBUG:PIL.PngImagePlugin:STREAM b'pHYs' 41 9
DEBUG:PIL.PngImagePlugin:STREAM b'IDAT' 62 4096
0%| | 0/11 [00:00<?, ?it/s]/mnt/ASC1637/miniconda3/envs/unifolm-wma/lib/python3.10/site-packages/torch/nn/functional.py:5501: UserWarning: Attempting to use hipBLASLt on an unsupported architecture! Overriding blas backend to hipblas (Triggered internally at ../aten/src/ATen/Context.cpp:296.)
proj = linear(q, w, b)
/mnt/ASC1637/miniconda3/envs/unifolm-wma/lib/python3.10/site-packages/torch/nn/functional.py:6278: UserWarning: Flash attention support on Navi31 GPU is still experimental. Enable it with TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL=1. (Triggered internally at ../aten/src/ATen/native/transformers/hip/sdp_utils.cpp:225.)
attn_output = scaled_dot_product_attention(
/mnt/ASC1637/miniconda3/envs/unifolm-wma/lib/python3.10/site-packages/torch/nn/functional.py:6278: UserWarning: Memory Efficient attention on Navi31 GPU is still experimental. Enable it with TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL=1. (Triggered internally at ../aten/src/ATen/native/transformers/hip/sdp_utils.cpp:269.)
attn_output = scaled_dot_product_attention(
>>> Step 0: generating actions ...
>>> Step 0: interacting with world model ...
DEBUG:PIL.Image:Importing BlpImagePlugin
DEBUG:PIL.Image:Importing BmpImagePlugin
DEBUG:PIL.Image:Importing BufrStubImagePlugin
DEBUG:PIL.Image:Importing CurImagePlugin
DEBUG:PIL.Image:Importing DcxImagePlugin
DEBUG:PIL.Image:Importing DdsImagePlugin
DEBUG:PIL.Image:Importing EpsImagePlugin
DEBUG:PIL.Image:Importing FitsImagePlugin
DEBUG:PIL.Image:Importing FitsStubImagePlugin
DEBUG:PIL.Image:Importing FliImagePlugin
DEBUG:PIL.Image:Importing FpxImagePlugin
DEBUG:PIL.Image:Image: failed to import FpxImagePlugin: No module named 'olefile'
DEBUG:PIL.Image:Importing FtexImagePlugin
DEBUG:PIL.Image:Importing GbrImagePlugin
DEBUG:PIL.Image:Importing GifImagePlugin
DEBUG:PIL.Image:Importing GribStubImagePlugin
DEBUG:PIL.Image:Importing Hdf5StubImagePlugin
DEBUG:PIL.Image:Importing IcnsImagePlugin
DEBUG:PIL.Image:Importing IcoImagePlugin
DEBUG:PIL.Image:Importing ImImagePlugin
DEBUG:PIL.Image:Importing ImtImagePlugin
DEBUG:PIL.Image:Importing IptcImagePlugin
DEBUG:PIL.Image:Importing JpegImagePlugin
DEBUG:PIL.Image:Importing Jpeg2KImagePlugin
DEBUG:PIL.Image:Importing McIdasImagePlugin
DEBUG:PIL.Image:Importing MicImagePlugin
DEBUG:PIL.Image:Image: failed to import MicImagePlugin: No module named 'olefile'
DEBUG:PIL.Image:Importing MpegImagePlugin
DEBUG:PIL.Image:Importing MpoImagePlugin
DEBUG:PIL.Image:Importing MspImagePlugin
DEBUG:PIL.Image:Importing PalmImagePlugin
DEBUG:PIL.Image:Importing PcdImagePlugin
DEBUG:PIL.Image:Importing PcxImagePlugin
DEBUG:PIL.Image:Importing PdfImagePlugin
DEBUG:PIL.Image:Importing PixarImagePlugin
DEBUG:PIL.Image:Importing PngImagePlugin
DEBUG:PIL.Image:Importing PpmImagePlugin
DEBUG:PIL.Image:Importing PsdImagePlugin
DEBUG:PIL.Image:Importing QoiImagePlugin
DEBUG:PIL.Image:Importing SgiImagePlugin
DEBUG:PIL.Image:Importing SpiderImagePlugin
DEBUG:PIL.Image:Importing SunImagePlugin
DEBUG:PIL.Image:Importing TgaImagePlugin
DEBUG:PIL.Image:Importing TiffImagePlugin
DEBUG:PIL.Image:Importing WebPImagePlugin
DEBUG:PIL.Image:Importing WmfImagePlugin
DEBUG:PIL.Image:Importing XbmImagePlugin
DEBUG:PIL.Image:Importing XpmImagePlugin
DEBUG:PIL.Image:Importing XVThumbImagePlugin
9%|▉ | 1/11 [01:37<16:18, 97.81s/it]
18%|█▊ | 2/11 [03:15<14:38, 97.56s/it]
27%|██▋ | 3/11 [04:52<12:59, 97.48s/it]
36%|███▋ | 4/11 [06:29<11:21, 97.38s/it]
45%|████▌ | 5/11 [08:06<09:43, 97.28s/it]
55%|█████▍ | 6/11 [09:44<08:06, 97.35s/it]
64%|██████▎ | 7/11 [11:21<06:29, 97.36s/it]
73%|███████▎ | 8/11 [12:59<04:52, 97.38s/it]
82%|████████▏ | 9/11 [14:36<03:14, 97.39s/it]
91%|█████████ | 10/11 [16:14<01:37, 97.42s/it]
100%|██████████| 11/11 [17:51<00:00, 97.42s/it]
100%|██████████| 11/11 [17:51<00:00, 97.41s/it]
>>>>>>>>>>>>>>>>>>>>>>>>
>>> Step 1: generating actions ...
>>> Step 1: interacting with world model ...
>>>>>>>>>>>>>>>>>>>>>>>>
>>> Step 2: generating actions ...
>>> Step 2: interacting with world model ...
>>>>>>>>>>>>>>>>>>>>>>>>
>>> Step 3: generating actions ...
>>> Step 3: interacting with world model ...
>>>>>>>>>>>>>>>>>>>>>>>>
>>> Step 4: generating actions ...
>>> Step 4: interacting with world model ...
>>>>>>>>>>>>>>>>>>>>>>>>
>>> Step 5: generating actions ...
>>> Step 5: interacting with world model ...
>>>>>>>>>>>>>>>>>>>>>>>>
>>> Step 6: generating actions ...
>>> Step 6: interacting with world model ...
>>>>>>>>>>>>>>>>>>>>>>>>
>>> Step 7: generating actions ...
>>> Step 7: interacting with world model ...
>>>>>>>>>>>>>>>>>>>>>>>>

View File

@@ -0,0 +1,5 @@
{
"gt_video": "unitree_g1_pack_camera/case4/unitree_g1_pack_camera_case4.mp4",
"pred_video": "unitree_g1_pack_camera/case4/output/inference/unitree_g1_pack_camera_case4_amd.mp4",
"psnr": 17.80386833747375
}

View File

@@ -1,11 +1,17 @@
2026-02-10 15:38:28.973314: I tensorflow/core/util/port.cc:113] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`. /mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/lightning_fabric/__init__.py:29: UserWarning: pkg_resources is deprecated as an API. See https://setuptools.pypa.io/en/latest/pkg_resources.html. The pkg_resources package is slated for removal as early as 2025-11-30. Refrain from using this package or pin to Setuptools<81.
2026-02-10 15:38:29.023024: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered __import__("pkg_resources").declare_namespace(__name__)
2026-02-10 15:38:29.023070: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered 2026-02-09 18:39:50.119842: I tensorflow/core/util/port.cc:113] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2026-02-10 15:38:29.024393: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered 2026-02-09 18:39:50.123128: I external/local_tsl/tsl/cuda/cudart_stub.cc:31] Could not find cuda drivers on your machine, GPU will not be used.
2026-02-10 15:38:29.031901: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations. 2026-02-09 18:39:50.156652: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
To enable the following instructions: AVX2 AVX512F AVX512_VNNI FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags. 2026-02-09 18:39:50.156708: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2026-02-10 15:38:29.955454: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT 2026-02-09 18:39:50.158926: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
Global seed set to 123 2026-02-09 18:39:50.167779: I external/local_tsl/tsl/cuda/cudart_stub.cc:31] Could not find cuda drivers on your machine, GPU will not be used.
2026-02-09 18:39:50.168073: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 AVX512F AVX512_VNNI AVX512_BF16 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
2026-02-09 18:39:50.915144: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT
[rank: 0] Global seed set to 123
/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/kornia/feature/lightglue.py:44: FutureWarning: `torch.cuda.amp.custom_fwd(args...)` is deprecated. Please use `torch.amp.custom_fwd(args..., device_type='cuda')` instead.
@torch.cuda.amp.custom_fwd(cast_inputs=torch.float32)
INFO:mainlogger:LatentVisualDiffusion: Running in v-prediction mode INFO:mainlogger:LatentVisualDiffusion: Running in v-prediction mode
INFO:unifolm_wma.models.diffusion_head.conditional_unet1d:number of parameters: 5.010531e+08 INFO:unifolm_wma.models.diffusion_head.conditional_unet1d:number of parameters: 5.010531e+08
INFO:unifolm_wma.models.diffusion_head.conditional_unet1d:number of parameters: 5.010531e+08 INFO:unifolm_wma.models.diffusion_head.conditional_unet1d:number of parameters: 5.010531e+08
@@ -14,11 +20,27 @@ INFO:root:Loaded ViT-H-14 model config.
DEBUG:urllib3.connectionpool:Starting new HTTPS connection (1): hf-mirror.com:443 DEBUG:urllib3.connectionpool:Starting new HTTPS connection (1): hf-mirror.com:443
DEBUG:urllib3.connectionpool:https://hf-mirror.com:443 "HEAD /laion/CLIP-ViT-H-14-laion2B-s32B-b79K/resolve/main/open_clip_pytorch_model.bin HTTP/1.1" 302 0 DEBUG:urllib3.connectionpool:https://hf-mirror.com:443 "HEAD /laion/CLIP-ViT-H-14-laion2B-s32B-b79K/resolve/main/open_clip_pytorch_model.bin HTTP/1.1" 302 0
INFO:root:Loading pretrained ViT-H-14 weights (laion2b_s32b_b79k). INFO:root:Loading pretrained ViT-H-14 weights (laion2b_s32b_b79k).
/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/open_clip/factory.py:88: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.
checkpoint = torch.load(checkpoint_path, map_location=map_location)
INFO:root:Loaded ViT-H-14 model config. INFO:root:Loaded ViT-H-14 model config.
DEBUG:urllib3.connectionpool:https://hf-mirror.com:443 "HEAD /laion/CLIP-ViT-H-14-laion2B-s32B-b79K/resolve/main/open_clip_pytorch_model.bin HTTP/1.1" 302 0 DEBUG:urllib3.connectionpool:https://hf-mirror.com:443 "HEAD /laion/CLIP-ViT-H-14-laion2B-s32B-b79K/resolve/main/open_clip_pytorch_model.bin HTTP/1.1" 302 0
INFO:root:Loading pretrained ViT-H-14 weights (laion2b_s32b_b79k). INFO:root:Loading pretrained ViT-H-14 weights (laion2b_s32b_b79k).
/mnt/ASC1637/unifolm-world-model-action/scripts/evaluation/world_model_interaction.py:198: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.
state_dict = torch.load(ckpt, map_location="cpu")
>>> model checkpoint loaded. >>> model checkpoint loaded.
>>> Load pre-trained model ... >>> Load pre-trained model ...
>>> Applying precision settings:
- Diffusion dtype: bf16
- Projector mode: bf16_full
- Encoder mode: bf16_full
- VAE dtype: bf16
✓ Diffusion model weights converted to bfloat16
✓ Projectors converted to bfloat16
✓ Encoders converted to bfloat16
✓ VAE converted to bfloat16
⚠ Found 601 fp32 params, converting to bf16
✓ All parameters converted to bfloat16
✓ torch.compile: 3 ResBlocks in output_blocks[5, 8, 9]
INFO:root:***** Configing Data ***** INFO:root:***** Configing Data *****
>>> unitree_z1_stackbox: 1 data samples loaded. >>> unitree_z1_stackbox: 1 data samples loaded.
>>> unitree_z1_stackbox: data stats loaded. >>> unitree_z1_stackbox: data stats loaded.
@@ -41,7 +63,9 @@ DEBUG:h5py._conv:Creating converter from 3 to 5
DEBUG:PIL.PngImagePlugin:STREAM b'IHDR' 16 13 DEBUG:PIL.PngImagePlugin:STREAM b'IHDR' 16 13
DEBUG:PIL.PngImagePlugin:STREAM b'pHYs' 41 9 DEBUG:PIL.PngImagePlugin:STREAM b'pHYs' 41 9
DEBUG:PIL.PngImagePlugin:STREAM b'IDAT' 62 4096 DEBUG:PIL.PngImagePlugin:STREAM b'IDAT' 62 4096
0%| | 0/8 [00:00<?, ?it/s]/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/torch/nn/functional.py:5501: UserWarning: Attempting to use hipBLASLt on an unsupported architecture! Overriding blas backend to hipblas (Triggered internally at ../aten/src/ATen/Context.cpp:296.)
proj = linear(q, w, b)
>>> Step 0: generating actions ... >>> Step 0: generating actions ...
>>> Step 0: interacting with world model ... >>> Step 0: interacting with world model ...
DEBUG:PIL.Image:Importing BlpImagePlugin DEBUG:PIL.Image:Importing BlpImagePlugin
@@ -92,7 +116,7 @@ DEBUG:PIL.Image:Importing WmfImagePlugin
DEBUG:PIL.Image:Importing WmfImagePlugin DEBUG:PIL.Image:Importing WmfImagePlugin
DEBUG:PIL.Image:Importing XbmImagePlugin DEBUG:PIL.Image:Importing XbmImagePlugin
DEBUG:PIL.Image:Importing XpmImagePlugin DEBUG:PIL.Image:Importing XpmImagePlugin
DEBUG:PIL.Image:Importing XVThumbImagePlugin DEBUG:PIL.Image:Importing XVThumbImagePlugin
12%|█▎ | 1/8 [01:08<07:58, 68.38s/it] 12%|█▎ | 1/8 [01:08<07:58, 68.38s/it]
25%|██▌ | 2/8 [02:13<06:38, 66.48s/it] 25%|██▌ | 2/8 [02:13<06:38, 66.48s/it]
@@ -116,6 +140,6 @@ DEBUG:PIL.Image:Importing XVThumbImagePlugin
>>> Step 4: generating actions ... >>> Step 4: generating actions ...
>>> Step 4: interacting with world model ... >>> Step 4: interacting with world model ...
>>>>>>>>>>>>>>>>>>>>>>>> >>>>>>>>>>>>>>>>>>>>>>>>
>>> Step 5: generating actions ... >>> Step 5: generating actions ...
>>> Step 5: interacting with world model ... >>> Step 5: interacting with world model ...
>>>>>>>>>>>>>>>>>>>>>>>> >>>>>>>>>>>>>>>>>>>>>>>>

View File

@@ -1,5 +1,5 @@
{ {
"gt_video": "unitree_z1_dual_arm_cleanup_pencils/case1/unitree_z1_dual_arm_cleanup_pencils_case1.mp4", "gt_video": "unitree_z1_dual_arm_cleanup_pencils/case1/unitree_z1_dual_arm_cleanup_pencils_case1.mp4",
"pred_video": "unitree_z1_dual_arm_cleanup_pencils/case1/output/inference/0_full_fs4.mp4", "pred_video": "unitree_z1_dual_arm_cleanup_pencils/case1/output/inference/unitree_z1_dual_arm_cleanup_pencils_case1_amd.mp4",
"psnr": 47.911564449209735 "psnr": 19.586376345676264
} }

View File

@@ -0,0 +1,5 @@
{
"gt_video": "/mnt/ASC1637/unifolm-world-model-action/unitree_z1_dual_arm_cleanup_pencils/case1/output/inference/unitree_z1_dual_arm_cleanup_pencils_case1_amd.mp4",
"pred_video": "/mnt/ASC1637/unifolm-world-model-action/unitree_z1_dual_arm_cleanup_pencils/case1/output/inference/0_full_fs4.mp4",
"psnr": 31.802224855380352
}

View File

@@ -0,0 +1,5 @@
#\!/bin/bash
res_dir="unitree_z1_dual_arm_cleanup_pencils/case1"
dataset="unitree_z1_dual_arm_cleanup_pencils"
TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL=1 CUDA_VISIBLE_DEVICES=0 python3 scripts/evaluation/profile_iteration.py --seed 123 --ckpt_path ckpts/unifolm_wma_dual_mix_bf16.ckpt --config configs/inference/world_model_interaction.yaml --savedir "${res_dir}/profile_output" --prompt_dir "${res_dir}/world_model_interaction_prompts" --dataset ${dataset} --bs 1 --height 320 --width 512 --unconditional_guidance_scale 1.0 --ddim_steps 50 --ddim_eta 1.0 --video_length 16 --frame_stride 4 --exe_steps 16 --n_iter 5 --warmup 1 --timestep_spacing uniform_trailing --guidance_rescale 0.7 --perframe_ae --vae_dtype bf16 --fast_policy_no_decode --csv "${res_dir}/profile_output/baseline.csv" 2>&1 | tee "${res_dir}/profile_output/profile.log"

View File

@@ -2,9 +2,9 @@ res_dir="unitree_z1_dual_arm_cleanup_pencils/case1"
dataset="unitree_z1_dual_arm_cleanup_pencils" dataset="unitree_z1_dual_arm_cleanup_pencils"
{ {
time CUDA_VISIBLE_DEVICES=0 python3 scripts/evaluation/world_model_interaction.py \ time TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL=1 CUDA_VISIBLE_DEVICES=0 python3 scripts/evaluation/world_model_interaction.py \
--seed 123 \ --seed 123 \
--ckpt_path ckpts/unifolm_wma_dual.ckpt \ --ckpt_path ckpts/unifolm_wma_dual_mixbf16.ckpt \
--config configs/inference/world_model_interaction.yaml \ --config configs/inference/world_model_interaction.yaml \
--savedir "${res_dir}/output" \ --savedir "${res_dir}/output" \
--bs 1 --height 320 --width 512 \ --bs 1 --height 320 --width 512 \
@@ -20,5 +20,10 @@ dataset="unitree_z1_dual_arm_cleanup_pencils"
--n_iter 8 \ --n_iter 8 \
--timestep_spacing 'uniform_trailing' \ --timestep_spacing 'uniform_trailing' \
--guidance_rescale 0.7 \ --guidance_rescale 0.7 \
--perframe_ae --perframe_ae \
--diffusion_dtype fp32 \
--projector_mode fp32 \
--encoder_mode fp32 \
--vae_dtype fp32 \
--fast_policy_no_decode
} 2>&1 | tee "${res_dir}/output.log" } 2>&1 | tee "${res_dir}/output.log"

View File

@@ -0,0 +1,137 @@
/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/lightning_fabric/__init__.py:29: UserWarning: pkg_resources is deprecated as an API. See https://setuptools.pypa.io/en/latest/pkg_resources.html. The pkg_resources package is slated for removal as early as 2025-11-30. Refrain from using this package or pin to Setuptools<81.
__import__("pkg_resources").declare_namespace(__name__)
2026-02-08 06:59:34.465946: I tensorflow/core/util/port.cc:113] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2026-02-08 06:59:34.469367: I external/local_tsl/tsl/cuda/cudart_stub.cc:31] Could not find cuda drivers on your machine, GPU will not be used.
2026-02-08 06:59:34.500805: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2026-02-08 06:59:34.500837: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2026-02-08 06:59:34.502917: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2026-02-08 06:59:34.511434: I external/local_tsl/tsl/cuda/cudart_stub.cc:31] Could not find cuda drivers on your machine, GPU will not be used.
2026-02-08 06:59:34.511678: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 AVX512F AVX512_VNNI AVX512_BF16 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
2026-02-08 06:59:35.478194: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT
Global seed set to 123
/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/kornia/feature/lightglue.py:44: FutureWarning: `torch.cuda.amp.custom_fwd(args...)` is deprecated. Please use `torch.amp.custom_fwd(args..., device_type='cuda')` instead.
@torch.cuda.amp.custom_fwd(cast_inputs=torch.float32)
INFO:mainlogger:LatentVisualDiffusion: Running in v-prediction mode
INFO:unifolm_wma.models.diffusion_head.conditional_unet1d:number of parameters: 5.010531e+08
INFO:unifolm_wma.models.diffusion_head.conditional_unet1d:number of parameters: 5.010531e+08
AE working on z of shape (1, 4, 32, 32) = 4096 dimensions.
INFO:root:Loaded ViT-H-14 model config.
DEBUG:urllib3.connectionpool:Starting new HTTPS connection (1): hf-mirror.com:443
DEBUG:urllib3.connectionpool:https://hf-mirror.com:443 "HEAD /laion/CLIP-ViT-H-14-laion2B-s32B-b79K/resolve/main/open_clip_pytorch_model.bin HTTP/1.1" 302 0
INFO:root:Loading pretrained ViT-H-14 weights (laion2b_s32b_b79k).
/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/open_clip/factory.py:88: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.
checkpoint = torch.load(checkpoint_path, map_location=map_location)
INFO:root:Loaded ViT-H-14 model config.
DEBUG:urllib3.connectionpool:https://hf-mirror.com:443 "HEAD /laion/CLIP-ViT-H-14-laion2B-s32B-b79K/resolve/main/open_clip_pytorch_model.bin HTTP/1.1" 302 0
INFO:root:Loading pretrained ViT-H-14 weights (laion2b_s32b_b79k).
/mnt/ASC1637/unifolm-world-model-action/scripts/evaluation/world_model_interaction.py:86: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.
state_dict = torch.load(ckpt, map_location="cpu")
>>> model checkpoint loaded.
>>> Load pre-trained model ...
INFO:root:***** Configing Data *****
>>> unitree_z1_stackbox: 1 data samples loaded.
>>> unitree_z1_stackbox: data stats loaded.
>>> unitree_z1_stackbox: normalizer initiated.
>>> unitree_z1_dual_arm_stackbox: 1 data samples loaded.
>>> unitree_z1_dual_arm_stackbox: data stats loaded.
>>> unitree_z1_dual_arm_stackbox: normalizer initiated.
>>> unitree_z1_dual_arm_stackbox_v2: 1 data samples loaded.
>>> unitree_z1_dual_arm_stackbox_v2: data stats loaded.
>>> unitree_z1_dual_arm_stackbox_v2: normalizer initiated.
>>> unitree_z1_dual_arm_cleanup_pencils: 1 data samples loaded.
>>> unitree_z1_dual_arm_cleanup_pencils: data stats loaded.
>>> unitree_z1_dual_arm_cleanup_pencils: normalizer initiated.
>>> unitree_g1_pack_camera: 1 data samples loaded.
>>> unitree_g1_pack_camera: data stats loaded.
>>> unitree_g1_pack_camera: normalizer initiated.
>>> Dataset is successfully loaded ...
>>> Generate 16 frames under each generation ...
DEBUG:h5py._conv:Creating converter from 3 to 5
DEBUG:PIL.PngImagePlugin:STREAM b'IHDR' 16 13
DEBUG:PIL.PngImagePlugin:STREAM b'pHYs' 41 9
DEBUG:PIL.PngImagePlugin:STREAM b'IDAT' 62 4096
0%| | 0/8 [00:00<?, ?it/s]/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/torch/nn/functional.py:5501: UserWarning: Attempting to use hipBLASLt on an unsupported architecture! Overriding blas backend to hipblas (Triggered internally at ../aten/src/ATen/Context.cpp:296.)
proj = linear(q, w, b)
/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/torch/nn/functional.py:6278: UserWarning: Flash attention support on Navi31 GPU is still experimental. Enable it with TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL=1. (Triggered internally at ../aten/src/ATen/native/transformers/hip/sdp_utils.cpp:225.)
attn_output = scaled_dot_product_attention(
/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/torch/nn/functional.py:6278: UserWarning: Memory Efficient attention on Navi31 GPU is still experimental. Enable it with TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL=1. (Triggered internally at ../aten/src/ATen/native/transformers/hip/sdp_utils.cpp:269.)
attn_output = scaled_dot_product_attention(
>>> Step 0: generating actions ...
>>> Step 0: interacting with world model ...
DEBUG:PIL.Image:Importing BlpImagePlugin
DEBUG:PIL.Image:Importing BmpImagePlugin
DEBUG:PIL.Image:Importing BufrStubImagePlugin
DEBUG:PIL.Image:Importing CurImagePlugin
DEBUG:PIL.Image:Importing DcxImagePlugin
DEBUG:PIL.Image:Importing DdsImagePlugin
DEBUG:PIL.Image:Importing EpsImagePlugin
DEBUG:PIL.Image:Importing FitsImagePlugin
DEBUG:PIL.Image:Importing FitsStubImagePlugin
DEBUG:PIL.Image:Importing FliImagePlugin
DEBUG:PIL.Image:Importing FpxImagePlugin
DEBUG:PIL.Image:Image: failed to import FpxImagePlugin: No module named 'olefile'
DEBUG:PIL.Image:Importing FtexImagePlugin
DEBUG:PIL.Image:Importing GbrImagePlugin
DEBUG:PIL.Image:Importing GifImagePlugin
DEBUG:PIL.Image:Importing GribStubImagePlugin
DEBUG:PIL.Image:Importing Hdf5StubImagePlugin
DEBUG:PIL.Image:Importing IcnsImagePlugin
DEBUG:PIL.Image:Importing IcoImagePlugin
DEBUG:PIL.Image:Importing ImImagePlugin
DEBUG:PIL.Image:Importing ImtImagePlugin
DEBUG:PIL.Image:Importing IptcImagePlugin
DEBUG:PIL.Image:Importing JpegImagePlugin
DEBUG:PIL.Image:Importing Jpeg2KImagePlugin
DEBUG:PIL.Image:Importing McIdasImagePlugin
DEBUG:PIL.Image:Importing MicImagePlugin
DEBUG:PIL.Image:Image: failed to import MicImagePlugin: No module named 'olefile'
DEBUG:PIL.Image:Importing MpegImagePlugin
DEBUG:PIL.Image:Importing MpoImagePlugin
DEBUG:PIL.Image:Importing MspImagePlugin
DEBUG:PIL.Image:Importing PalmImagePlugin
DEBUG:PIL.Image:Importing PcdImagePlugin
DEBUG:PIL.Image:Importing PcxImagePlugin
DEBUG:PIL.Image:Importing PdfImagePlugin
DEBUG:PIL.Image:Importing PixarImagePlugin
DEBUG:PIL.Image:Importing PngImagePlugin
DEBUG:PIL.Image:Importing PpmImagePlugin
DEBUG:PIL.Image:Importing PsdImagePlugin
DEBUG:PIL.Image:Importing QoiImagePlugin
DEBUG:PIL.Image:Importing SgiImagePlugin
DEBUG:PIL.Image:Importing SpiderImagePlugin
DEBUG:PIL.Image:Importing SunImagePlugin
DEBUG:PIL.Image:Importing TgaImagePlugin
DEBUG:PIL.Image:Importing TiffImagePlugin
DEBUG:PIL.Image:Importing WebPImagePlugin
DEBUG:PIL.Image:Importing WmfImagePlugin
DEBUG:PIL.Image:Importing XbmImagePlugin
DEBUG:PIL.Image:Importing XpmImagePlugin
DEBUG:PIL.Image:Importing XVThumbImagePlugin
12%|█▎ | 1/8 [01:37<11:23, 97.57s/it]
25%|██▌ | 2/8 [03:14<09:44, 97.48s/it]
38%|███▊ | 3/8 [04:52<08:07, 97.47s/it]
50%|█████ | 4/8 [06:29<06:29, 97.49s/it]
62%|██████▎ | 5/8 [08:07<04:52, 97.42s/it]
75%|███████▌ | 6/8 [09:44<03:14, 97.32s/it]
88%|████████▊ | 7/8 [11:21<01:37, 97.34s/it]
100%|██████████| 8/8 [12:59<00:00, 97.36s/it]
100%|██████████| 8/8 [12:59<00:00, 97.40s/it]
>>>>>>>>>>>>>>>>>>>>>>>>
>>> Step 1: generating actions ...
>>> Step 1: interacting with world model ...
>>>>>>>>>>>>>>>>>>>>>>>>
>>> Step 2: generating actions ...
>>> Step 2: interacting with world model ...
>>>>>>>>>>>>>>>>>>>>>>>>
>>> Step 3: generating actions ...
>>> Step 3: interacting with world model ...
>>>>>>>>>>>>>>>>>>>>>>>>
>>> Step 4: generating actions ...
>>> Step 4: interacting with world model ...
>>>>>>>>>>>>>>>>>>>>>>>>
>>> Step 5: generating actions ...
>>> Step 5: interacting with world model ...
>>>>>>>>>>>>>>>>>>>>>>>>

View File

@@ -0,0 +1,5 @@
{
"gt_video": "unitree_z1_dual_arm_cleanup_pencils/case2/unitree_z1_dual_arm_cleanup_pencils_case2.mp4",
"pred_video": "unitree_z1_dual_arm_cleanup_pencils/case2/output/inference/unitree_z1_dual_arm_cleanup_pencils_case2_amd.mp4",
"psnr": 20.484298972158296
}

View File

@@ -0,0 +1,137 @@
/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/lightning_fabric/__init__.py:29: UserWarning: pkg_resources is deprecated as an API. See https://setuptools.pypa.io/en/latest/pkg_resources.html. The pkg_resources package is slated for removal as early as 2025-11-30. Refrain from using this package or pin to Setuptools<81.
__import__("pkg_resources").declare_namespace(__name__)
2026-02-08 07:18:52.629976: I tensorflow/core/util/port.cc:113] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2026-02-08 07:18:52.633025: I external/local_tsl/tsl/cuda/cudart_stub.cc:31] Could not find cuda drivers on your machine, GPU will not be used.
2026-02-08 07:18:52.663985: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2026-02-08 07:18:52.664018: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2026-02-08 07:18:52.665837: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2026-02-08 07:18:52.673889: I external/local_tsl/tsl/cuda/cudart_stub.cc:31] Could not find cuda drivers on your machine, GPU will not be used.
2026-02-08 07:18:52.674218: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 AVX512F AVX512_VNNI AVX512_BF16 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
2026-02-08 07:18:53.298338: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT
[rank: 0] Global seed set to 123
/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/kornia/feature/lightglue.py:44: FutureWarning: `torch.cuda.amp.custom_fwd(args...)` is deprecated. Please use `torch.amp.custom_fwd(args..., device_type='cuda')` instead.
@torch.cuda.amp.custom_fwd(cast_inputs=torch.float32)
INFO:mainlogger:LatentVisualDiffusion: Running in v-prediction mode
INFO:unifolm_wma.models.diffusion_head.conditional_unet1d:number of parameters: 5.010531e+08
INFO:unifolm_wma.models.diffusion_head.conditional_unet1d:number of parameters: 5.010531e+08
AE working on z of shape (1, 4, 32, 32) = 4096 dimensions.
INFO:root:Loaded ViT-H-14 model config.
DEBUG:urllib3.connectionpool:Starting new HTTPS connection (1): hf-mirror.com:443
DEBUG:urllib3.connectionpool:https://hf-mirror.com:443 "HEAD /laion/CLIP-ViT-H-14-laion2B-s32B-b79K/resolve/main/open_clip_pytorch_model.bin HTTP/1.1" 302 0
INFO:root:Loading pretrained ViT-H-14 weights (laion2b_s32b_b79k).
/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/open_clip/factory.py:88: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.
checkpoint = torch.load(checkpoint_path, map_location=map_location)
INFO:root:Loaded ViT-H-14 model config.
DEBUG:urllib3.connectionpool:https://hf-mirror.com:443 "HEAD /laion/CLIP-ViT-H-14-laion2B-s32B-b79K/resolve/main/open_clip_pytorch_model.bin HTTP/1.1" 302 0
INFO:root:Loading pretrained ViT-H-14 weights (laion2b_s32b_b79k).
/mnt/ASC1637/unifolm-world-model-action/scripts/evaluation/world_model_interaction.py:86: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.
state_dict = torch.load(ckpt, map_location="cpu")
>>> model checkpoint loaded.
>>> Load pre-trained model ...
INFO:root:***** Configing Data *****
>>> unitree_z1_stackbox: 1 data samples loaded.
>>> unitree_z1_stackbox: data stats loaded.
>>> unitree_z1_stackbox: normalizer initiated.
>>> unitree_z1_dual_arm_stackbox: 1 data samples loaded.
>>> unitree_z1_dual_arm_stackbox: data stats loaded.
>>> unitree_z1_dual_arm_stackbox: normalizer initiated.
>>> unitree_z1_dual_arm_stackbox_v2: 1 data samples loaded.
>>> unitree_z1_dual_arm_stackbox_v2: data stats loaded.
>>> unitree_z1_dual_arm_stackbox_v2: normalizer initiated.
>>> unitree_z1_dual_arm_cleanup_pencils: 1 data samples loaded.
>>> unitree_z1_dual_arm_cleanup_pencils: data stats loaded.
>>> unitree_z1_dual_arm_cleanup_pencils: normalizer initiated.
>>> unitree_g1_pack_camera: 1 data samples loaded.
>>> unitree_g1_pack_camera: data stats loaded.
>>> unitree_g1_pack_camera: normalizer initiated.
>>> Dataset is successfully loaded ...
>>> Generate 16 frames under each generation ...
DEBUG:h5py._conv:Creating converter from 3 to 5
DEBUG:PIL.PngImagePlugin:STREAM b'IHDR' 16 13
DEBUG:PIL.PngImagePlugin:STREAM b'pHYs' 41 9
DEBUG:PIL.PngImagePlugin:STREAM b'IDAT' 62 4096
0%| | 0/8 [00:00<?, ?it/s]/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/torch/nn/functional.py:5501: UserWarning: Attempting to use hipBLASLt on an unsupported architecture! Overriding blas backend to hipblas (Triggered internally at ../aten/src/ATen/Context.cpp:296.)
proj = linear(q, w, b)
/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/torch/nn/functional.py:6278: UserWarning: Flash attention support on Navi31 GPU is still experimental. Enable it with TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL=1. (Triggered internally at ../aten/src/ATen/native/transformers/hip/sdp_utils.cpp:225.)
attn_output = scaled_dot_product_attention(
/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/torch/nn/functional.py:6278: UserWarning: Memory Efficient attention on Navi31 GPU is still experimental. Enable it with TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL=1. (Triggered internally at ../aten/src/ATen/native/transformers/hip/sdp_utils.cpp:269.)
attn_output = scaled_dot_product_attention(
>>> Step 0: generating actions ...
>>> Step 0: interacting with world model ...
DEBUG:PIL.Image:Importing BlpImagePlugin
DEBUG:PIL.Image:Importing BmpImagePlugin
DEBUG:PIL.Image:Importing BufrStubImagePlugin
DEBUG:PIL.Image:Importing CurImagePlugin
DEBUG:PIL.Image:Importing DcxImagePlugin
DEBUG:PIL.Image:Importing DdsImagePlugin
DEBUG:PIL.Image:Importing EpsImagePlugin
DEBUG:PIL.Image:Importing FitsImagePlugin
DEBUG:PIL.Image:Importing FitsStubImagePlugin
DEBUG:PIL.Image:Importing FliImagePlugin
DEBUG:PIL.Image:Importing FpxImagePlugin
DEBUG:PIL.Image:Image: failed to import FpxImagePlugin: No module named 'olefile'
DEBUG:PIL.Image:Importing FtexImagePlugin
DEBUG:PIL.Image:Importing GbrImagePlugin
DEBUG:PIL.Image:Importing GifImagePlugin
DEBUG:PIL.Image:Importing GribStubImagePlugin
DEBUG:PIL.Image:Importing Hdf5StubImagePlugin
DEBUG:PIL.Image:Importing IcnsImagePlugin
DEBUG:PIL.Image:Importing IcoImagePlugin
DEBUG:PIL.Image:Importing ImImagePlugin
DEBUG:PIL.Image:Importing ImtImagePlugin
DEBUG:PIL.Image:Importing IptcImagePlugin
DEBUG:PIL.Image:Importing JpegImagePlugin
DEBUG:PIL.Image:Importing Jpeg2KImagePlugin
DEBUG:PIL.Image:Importing McIdasImagePlugin
DEBUG:PIL.Image:Importing MicImagePlugin
DEBUG:PIL.Image:Image: failed to import MicImagePlugin: No module named 'olefile'
DEBUG:PIL.Image:Importing MpegImagePlugin
DEBUG:PIL.Image:Importing MpoImagePlugin
DEBUG:PIL.Image:Importing MspImagePlugin
DEBUG:PIL.Image:Importing PalmImagePlugin
DEBUG:PIL.Image:Importing PcdImagePlugin
DEBUG:PIL.Image:Importing PcxImagePlugin
DEBUG:PIL.Image:Importing PdfImagePlugin
DEBUG:PIL.Image:Importing PixarImagePlugin
DEBUG:PIL.Image:Importing PngImagePlugin
DEBUG:PIL.Image:Importing PpmImagePlugin
DEBUG:PIL.Image:Importing PsdImagePlugin
DEBUG:PIL.Image:Importing QoiImagePlugin
DEBUG:PIL.Image:Importing SgiImagePlugin
DEBUG:PIL.Image:Importing SpiderImagePlugin
DEBUG:PIL.Image:Importing SunImagePlugin
DEBUG:PIL.Image:Importing TgaImagePlugin
DEBUG:PIL.Image:Importing TiffImagePlugin
DEBUG:PIL.Image:Importing WebPImagePlugin
DEBUG:PIL.Image:Importing WmfImagePlugin
DEBUG:PIL.Image:Importing XbmImagePlugin
DEBUG:PIL.Image:Importing XpmImagePlugin
DEBUG:PIL.Image:Importing XVThumbImagePlugin
12%|█▎ | 1/8 [01:40<11:43, 100.54s/it]
25%|██▌ | 2/8 [03:20<10:02, 100.36s/it]
38%|███▊ | 3/8 [05:01<08:21, 100.32s/it]
50%|█████ | 4/8 [06:41<06:41, 100.36s/it]
62%|██████▎ | 5/8 [08:21<05:00, 100.30s/it]
75%|███████▌ | 6/8 [10:01<03:20, 100.28s/it]
88%|████████▊ | 7/8 [11:42<01:40, 100.34s/it]
100%|██████████| 8/8 [13:22<00:00, 100.36s/it]
100%|██████████| 8/8 [13:22<00:00, 100.34s/it]
>>>>>>>>>>>>>>>>>>>>>>>>
>>> Step 1: generating actions ...
>>> Step 1: interacting with world model ...
>>>>>>>>>>>>>>>>>>>>>>>>
>>> Step 2: generating actions ...
>>> Step 2: interacting with world model ...
>>>>>>>>>>>>>>>>>>>>>>>>
>>> Step 3: generating actions ...
>>> Step 3: interacting with world model ...
>>>>>>>>>>>>>>>>>>>>>>>>
>>> Step 4: generating actions ...
>>> Step 4: interacting with world model ...
>>>>>>>>>>>>>>>>>>>>>>>>
>>> Step 5: generating actions ...
>>> Step 5: interacting with world model ...
>>>>>>>>>>>>>>>>>>>>>>>>

View File

@@ -0,0 +1,5 @@
{
"gt_video": "unitree_z1_dual_arm_cleanup_pencils/case3/unitree_z1_dual_arm_cleanup_pencils_case3.mp4",
"pred_video": "unitree_z1_dual_arm_cleanup_pencils/case3/output/inference/unitree_z1_dual_arm_cleanup_pencils_case3_amd.mp4",
"psnr": 21.20205061239349
}

View File

@@ -0,0 +1,137 @@
/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/lightning_fabric/__init__.py:29: UserWarning: pkg_resources is deprecated as an API. See https://setuptools.pypa.io/en/latest/pkg_resources.html. The pkg_resources package is slated for removal as early as 2025-11-30. Refrain from using this package or pin to Setuptools<81.
__import__("pkg_resources").declare_namespace(__name__)
2026-02-08 07:22:15.333099: I tensorflow/core/util/port.cc:113] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2026-02-08 07:22:15.336215: I external/local_tsl/tsl/cuda/cudart_stub.cc:31] Could not find cuda drivers on your machine, GPU will not be used.
2026-02-08 07:22:15.366489: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2026-02-08 07:22:15.366522: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2026-02-08 07:22:15.368294: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2026-02-08 07:22:15.376202: I external/local_tsl/tsl/cuda/cudart_stub.cc:31] Could not find cuda drivers on your machine, GPU will not be used.
2026-02-08 07:22:15.376444: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 AVX512F AVX512_VNNI AVX512_BF16 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
2026-02-08 07:22:15.995383: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT
[rank: 0] Global seed set to 123
/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/kornia/feature/lightglue.py:44: FutureWarning: `torch.cuda.amp.custom_fwd(args...)` is deprecated. Please use `torch.amp.custom_fwd(args..., device_type='cuda')` instead.
@torch.cuda.amp.custom_fwd(cast_inputs=torch.float32)
INFO:mainlogger:LatentVisualDiffusion: Running in v-prediction mode
INFO:unifolm_wma.models.diffusion_head.conditional_unet1d:number of parameters: 5.010531e+08
INFO:unifolm_wma.models.diffusion_head.conditional_unet1d:number of parameters: 5.010531e+08
AE working on z of shape (1, 4, 32, 32) = 4096 dimensions.
INFO:root:Loaded ViT-H-14 model config.
DEBUG:urllib3.connectionpool:Starting new HTTPS connection (1): hf-mirror.com:443
DEBUG:urllib3.connectionpool:https://hf-mirror.com:443 "HEAD /laion/CLIP-ViT-H-14-laion2B-s32B-b79K/resolve/main/open_clip_pytorch_model.bin HTTP/1.1" 302 0
INFO:root:Loading pretrained ViT-H-14 weights (laion2b_s32b_b79k).
/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/open_clip/factory.py:88: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.
checkpoint = torch.load(checkpoint_path, map_location=map_location)
INFO:root:Loaded ViT-H-14 model config.
DEBUG:urllib3.connectionpool:https://hf-mirror.com:443 "HEAD /laion/CLIP-ViT-H-14-laion2B-s32B-b79K/resolve/main/open_clip_pytorch_model.bin HTTP/1.1" 302 0
INFO:root:Loading pretrained ViT-H-14 weights (laion2b_s32b_b79k).
/mnt/ASC1637/unifolm-world-model-action/scripts/evaluation/world_model_interaction.py:86: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.
state_dict = torch.load(ckpt, map_location="cpu")
>>> model checkpoint loaded.
>>> Load pre-trained model ...
INFO:root:***** Configing Data *****
>>> unitree_z1_stackbox: 1 data samples loaded.
>>> unitree_z1_stackbox: data stats loaded.
>>> unitree_z1_stackbox: normalizer initiated.
>>> unitree_z1_dual_arm_stackbox: 1 data samples loaded.
>>> unitree_z1_dual_arm_stackbox: data stats loaded.
>>> unitree_z1_dual_arm_stackbox: normalizer initiated.
>>> unitree_z1_dual_arm_stackbox_v2: 1 data samples loaded.
>>> unitree_z1_dual_arm_stackbox_v2: data stats loaded.
>>> unitree_z1_dual_arm_stackbox_v2: normalizer initiated.
>>> unitree_z1_dual_arm_cleanup_pencils: 1 data samples loaded.
>>> unitree_z1_dual_arm_cleanup_pencils: data stats loaded.
>>> unitree_z1_dual_arm_cleanup_pencils: normalizer initiated.
>>> unitree_g1_pack_camera: 1 data samples loaded.
>>> unitree_g1_pack_camera: data stats loaded.
>>> unitree_g1_pack_camera: normalizer initiated.
>>> Dataset is successfully loaded ...
>>> Generate 16 frames under each generation ...
DEBUG:h5py._conv:Creating converter from 3 to 5
DEBUG:PIL.PngImagePlugin:STREAM b'IHDR' 16 13
DEBUG:PIL.PngImagePlugin:STREAM b'pHYs' 41 9
DEBUG:PIL.PngImagePlugin:STREAM b'IDAT' 62 4096
0%| | 0/8 [00:00<?, ?it/s]/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/torch/nn/functional.py:5501: UserWarning: Attempting to use hipBLASLt on an unsupported architecture! Overriding blas backend to hipblas (Triggered internally at ../aten/src/ATen/Context.cpp:296.)
proj = linear(q, w, b)
/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/torch/nn/functional.py:6278: UserWarning: Flash attention support on Navi31 GPU is still experimental. Enable it with TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL=1. (Triggered internally at ../aten/src/ATen/native/transformers/hip/sdp_utils.cpp:225.)
attn_output = scaled_dot_product_attention(
/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/torch/nn/functional.py:6278: UserWarning: Memory Efficient attention on Navi31 GPU is still experimental. Enable it with TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL=1. (Triggered internally at ../aten/src/ATen/native/transformers/hip/sdp_utils.cpp:269.)
attn_output = scaled_dot_product_attention(
>>> Step 0: generating actions ...
>>> Step 0: interacting with world model ...
DEBUG:PIL.Image:Importing BlpImagePlugin
DEBUG:PIL.Image:Importing BmpImagePlugin
DEBUG:PIL.Image:Importing BufrStubImagePlugin
DEBUG:PIL.Image:Importing CurImagePlugin
DEBUG:PIL.Image:Importing DcxImagePlugin
DEBUG:PIL.Image:Importing DdsImagePlugin
DEBUG:PIL.Image:Importing EpsImagePlugin
DEBUG:PIL.Image:Importing FitsImagePlugin
DEBUG:PIL.Image:Importing FitsStubImagePlugin
DEBUG:PIL.Image:Importing FliImagePlugin
DEBUG:PIL.Image:Importing FpxImagePlugin
DEBUG:PIL.Image:Image: failed to import FpxImagePlugin: No module named 'olefile'
DEBUG:PIL.Image:Importing FtexImagePlugin
DEBUG:PIL.Image:Importing GbrImagePlugin
DEBUG:PIL.Image:Importing GifImagePlugin
DEBUG:PIL.Image:Importing GribStubImagePlugin
DEBUG:PIL.Image:Importing Hdf5StubImagePlugin
DEBUG:PIL.Image:Importing IcnsImagePlugin
DEBUG:PIL.Image:Importing IcoImagePlugin
DEBUG:PIL.Image:Importing ImImagePlugin
DEBUG:PIL.Image:Importing ImtImagePlugin
DEBUG:PIL.Image:Importing IptcImagePlugin
DEBUG:PIL.Image:Importing JpegImagePlugin
DEBUG:PIL.Image:Importing Jpeg2KImagePlugin
DEBUG:PIL.Image:Importing McIdasImagePlugin
DEBUG:PIL.Image:Importing MicImagePlugin
DEBUG:PIL.Image:Image: failed to import MicImagePlugin: No module named 'olefile'
DEBUG:PIL.Image:Importing MpegImagePlugin
DEBUG:PIL.Image:Importing MpoImagePlugin
DEBUG:PIL.Image:Importing MspImagePlugin
DEBUG:PIL.Image:Importing PalmImagePlugin
DEBUG:PIL.Image:Importing PcdImagePlugin
DEBUG:PIL.Image:Importing PcxImagePlugin
DEBUG:PIL.Image:Importing PdfImagePlugin
DEBUG:PIL.Image:Importing PixarImagePlugin
DEBUG:PIL.Image:Importing PngImagePlugin
DEBUG:PIL.Image:Importing PpmImagePlugin
DEBUG:PIL.Image:Importing PsdImagePlugin
DEBUG:PIL.Image:Importing QoiImagePlugin
DEBUG:PIL.Image:Importing SgiImagePlugin
DEBUG:PIL.Image:Importing SpiderImagePlugin
DEBUG:PIL.Image:Importing SunImagePlugin
DEBUG:PIL.Image:Importing TgaImagePlugin
DEBUG:PIL.Image:Importing TiffImagePlugin
DEBUG:PIL.Image:Importing WebPImagePlugin
DEBUG:PIL.Image:Importing WmfImagePlugin
DEBUG:PIL.Image:Importing XbmImagePlugin
DEBUG:PIL.Image:Importing XpmImagePlugin
DEBUG:PIL.Image:Importing XVThumbImagePlugin
12%|█▎ | 1/8 [01:37<11:23, 97.68s/it]
25%|██▌ | 2/8 [03:15<09:47, 97.83s/it]
38%|███▊ | 3/8 [04:53<08:09, 97.91s/it]
50%|█████ | 4/8 [06:31<06:32, 98.03s/it]
62%|██████▎ | 5/8 [08:10<04:54, 98.11s/it]
75%|███████▌ | 6/8 [09:48<03:16, 98.18s/it]
88%|████████▊ | 7/8 [11:26<01:38, 98.24s/it]
100%|██████████| 8/8 [13:04<00:00, 98.16s/it]
100%|██████████| 8/8 [13:04<00:00, 98.09s/it]
>>>>>>>>>>>>>>>>>>>>>>>>
>>> Step 1: generating actions ...
>>> Step 1: interacting with world model ...
>>>>>>>>>>>>>>>>>>>>>>>>
>>> Step 2: generating actions ...
>>> Step 2: interacting with world model ...
>>>>>>>>>>>>>>>>>>>>>>>>
>>> Step 3: generating actions ...
>>> Step 3: interacting with world model ...
>>>>>>>>>>>>>>>>>>>>>>>>
>>> Step 4: generating actions ...
>>> Step 4: interacting with world model ...
>>>>>>>>>>>>>>>>>>>>>>>>
>>> Step 5: generating actions ...
>>> Step 5: interacting with world model ...
>>>>>>>>>>>>>>>>>>>>>>>>

View File

@@ -0,0 +1,5 @@
{
"gt_video": "unitree_z1_dual_arm_cleanup_pencils/case4/unitree_z1_dual_arm_cleanup_pencils_case4.mp4",
"pred_video": "unitree_z1_dual_arm_cleanup_pencils/case4/output/inference/unitree_z1_dual_arm_cleanup_pencils_case4_amd.mp4",
"psnr": 21.130122583788612
}

View File

@@ -0,0 +1,134 @@
/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/lightning_fabric/__init__.py:29: UserWarning: pkg_resources is deprecated as an API. See https://setuptools.pypa.io/en/latest/pkg_resources.html. The pkg_resources package is slated for removal as early as 2025-11-30. Refrain from using this package or pin to Setuptools<81.
__import__("pkg_resources").declare_namespace(__name__)
2026-02-08 07:24:40.357099: I tensorflow/core/util/port.cc:113] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2026-02-08 07:24:40.360365: I external/local_tsl/tsl/cuda/cudart_stub.cc:31] Could not find cuda drivers on your machine, GPU will not be used.
2026-02-08 07:24:40.391744: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2026-02-08 07:24:40.391772: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2026-02-08 07:24:40.393608: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2026-02-08 07:24:40.401837: I external/local_tsl/tsl/cuda/cudart_stub.cc:31] Could not find cuda drivers on your machine, GPU will not be used.
2026-02-08 07:24:40.402077: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 AVX512F AVX512_VNNI AVX512_BF16 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
2026-02-08 07:24:41.022382: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT
Global seed set to 123
/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/kornia/feature/lightglue.py:44: FutureWarning: `torch.cuda.amp.custom_fwd(args...)` is deprecated. Please use `torch.amp.custom_fwd(args..., device_type='cuda')` instead.
@torch.cuda.amp.custom_fwd(cast_inputs=torch.float32)
INFO:mainlogger:LatentVisualDiffusion: Running in v-prediction mode
INFO:unifolm_wma.models.diffusion_head.conditional_unet1d:number of parameters: 5.010531e+08
INFO:unifolm_wma.models.diffusion_head.conditional_unet1d:number of parameters: 5.010531e+08
AE working on z of shape (1, 4, 32, 32) = 4096 dimensions.
INFO:root:Loaded ViT-H-14 model config.
DEBUG:urllib3.connectionpool:Starting new HTTPS connection (1): hf-mirror.com:443
DEBUG:urllib3.connectionpool:https://hf-mirror.com:443 "HEAD /laion/CLIP-ViT-H-14-laion2B-s32B-b79K/resolve/main/open_clip_pytorch_model.bin HTTP/1.1" 302 0
INFO:root:Loading pretrained ViT-H-14 weights (laion2b_s32b_b79k).
/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/open_clip/factory.py:88: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.
checkpoint = torch.load(checkpoint_path, map_location=map_location)
INFO:root:Loaded ViT-H-14 model config.
DEBUG:urllib3.connectionpool:https://hf-mirror.com:443 "HEAD /laion/CLIP-ViT-H-14-laion2B-s32B-b79K/resolve/main/open_clip_pytorch_model.bin HTTP/1.1" 302 0
INFO:root:Loading pretrained ViT-H-14 weights (laion2b_s32b_b79k).
/mnt/ASC1637/unifolm-world-model-action/scripts/evaluation/world_model_interaction.py:86: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.
state_dict = torch.load(ckpt, map_location="cpu")
>>> model checkpoint loaded.
>>> Load pre-trained model ...
INFO:root:***** Configing Data *****
>>> unitree_z1_stackbox: 1 data samples loaded.
>>> unitree_z1_stackbox: data stats loaded.
>>> unitree_z1_stackbox: normalizer initiated.
>>> unitree_z1_dual_arm_stackbox: 1 data samples loaded.
>>> unitree_z1_dual_arm_stackbox: data stats loaded.
>>> unitree_z1_dual_arm_stackbox: normalizer initiated.
>>> unitree_z1_dual_arm_stackbox_v2: 1 data samples loaded.
>>> unitree_z1_dual_arm_stackbox_v2: data stats loaded.
>>> unitree_z1_dual_arm_stackbox_v2: normalizer initiated.
>>> unitree_z1_dual_arm_cleanup_pencils: 1 data samples loaded.
>>> unitree_z1_dual_arm_cleanup_pencils: data stats loaded.
>>> unitree_z1_dual_arm_cleanup_pencils: normalizer initiated.
>>> unitree_g1_pack_camera: 1 data samples loaded.
>>> unitree_g1_pack_camera: data stats loaded.
>>> unitree_g1_pack_camera: normalizer initiated.
>>> Dataset is successfully loaded ...
>>> Generate 16 frames under each generation ...
DEBUG:h5py._conv:Creating converter from 3 to 5
DEBUG:PIL.PngImagePlugin:STREAM b'IHDR' 16 13
DEBUG:PIL.PngImagePlugin:STREAM b'pHYs' 41 9
DEBUG:PIL.PngImagePlugin:STREAM b'IDAT' 62 4096
0%| | 0/7 [00:00<?, ?it/s]/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/torch/nn/functional.py:5501: UserWarning: Attempting to use hipBLASLt on an unsupported architecture! Overriding blas backend to hipblas (Triggered internally at ../aten/src/ATen/Context.cpp:296.)
proj = linear(q, w, b)
/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/torch/nn/functional.py:6278: UserWarning: Flash attention support on Navi31 GPU is still experimental. Enable it with TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL=1. (Triggered internally at ../aten/src/ATen/native/transformers/hip/sdp_utils.cpp:225.)
attn_output = scaled_dot_product_attention(
/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/torch/nn/functional.py:6278: UserWarning: Memory Efficient attention on Navi31 GPU is still experimental. Enable it with TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL=1. (Triggered internally at ../aten/src/ATen/native/transformers/hip/sdp_utils.cpp:269.)
attn_output = scaled_dot_product_attention(
>>> Step 0: generating actions ...
>>> Step 0: interacting with world model ...
DEBUG:PIL.Image:Importing BlpImagePlugin
DEBUG:PIL.Image:Importing BmpImagePlugin
DEBUG:PIL.Image:Importing BufrStubImagePlugin
DEBUG:PIL.Image:Importing CurImagePlugin
DEBUG:PIL.Image:Importing DcxImagePlugin
DEBUG:PIL.Image:Importing DdsImagePlugin
DEBUG:PIL.Image:Importing EpsImagePlugin
DEBUG:PIL.Image:Importing FitsImagePlugin
DEBUG:PIL.Image:Importing FitsStubImagePlugin
DEBUG:PIL.Image:Importing FliImagePlugin
DEBUG:PIL.Image:Importing FpxImagePlugin
DEBUG:PIL.Image:Image: failed to import FpxImagePlugin: No module named 'olefile'
DEBUG:PIL.Image:Importing FtexImagePlugin
DEBUG:PIL.Image:Importing GbrImagePlugin
DEBUG:PIL.Image:Importing GifImagePlugin
DEBUG:PIL.Image:Importing GribStubImagePlugin
DEBUG:PIL.Image:Importing Hdf5StubImagePlugin
DEBUG:PIL.Image:Importing IcnsImagePlugin
DEBUG:PIL.Image:Importing IcoImagePlugin
DEBUG:PIL.Image:Importing ImImagePlugin
DEBUG:PIL.Image:Importing ImtImagePlugin
DEBUG:PIL.Image:Importing IptcImagePlugin
DEBUG:PIL.Image:Importing JpegImagePlugin
DEBUG:PIL.Image:Importing Jpeg2KImagePlugin
DEBUG:PIL.Image:Importing McIdasImagePlugin
DEBUG:PIL.Image:Importing MicImagePlugin
DEBUG:PIL.Image:Image: failed to import MicImagePlugin: No module named 'olefile'
DEBUG:PIL.Image:Importing MpegImagePlugin
DEBUG:PIL.Image:Importing MpoImagePlugin
DEBUG:PIL.Image:Importing MspImagePlugin
DEBUG:PIL.Image:Importing PalmImagePlugin
DEBUG:PIL.Image:Importing PcdImagePlugin
DEBUG:PIL.Image:Importing PcxImagePlugin
DEBUG:PIL.Image:Importing PdfImagePlugin
DEBUG:PIL.Image:Importing PixarImagePlugin
DEBUG:PIL.Image:Importing PngImagePlugin
DEBUG:PIL.Image:Importing PpmImagePlugin
DEBUG:PIL.Image:Importing PsdImagePlugin
DEBUG:PIL.Image:Importing QoiImagePlugin
DEBUG:PIL.Image:Importing SgiImagePlugin
DEBUG:PIL.Image:Importing SpiderImagePlugin
DEBUG:PIL.Image:Importing SunImagePlugin
DEBUG:PIL.Image:Importing TgaImagePlugin
DEBUG:PIL.Image:Importing TiffImagePlugin
DEBUG:PIL.Image:Importing WebPImagePlugin
DEBUG:PIL.Image:Importing WmfImagePlugin
DEBUG:PIL.Image:Importing XbmImagePlugin
DEBUG:PIL.Image:Importing XpmImagePlugin
DEBUG:PIL.Image:Importing XVThumbImagePlugin
14%|█▍ | 1/7 [01:41<10:09, 101.63s/it]
29%|██▊ | 2/7 [03:20<08:18, 99.78s/it]
43%|████▎ | 3/7 [04:58<06:36, 99.24s/it]
57%|█████▋ | 4/7 [06:37<04:57, 99.05s/it]
71%|███████▏ | 5/7 [08:16<03:17, 98.90s/it]
86%|████████▌ | 6/7 [09:54<01:38, 98.80s/it]
100%|██████████| 7/7 [11:33<00:00, 98.70s/it]
100%|██████████| 7/7 [11:33<00:00, 99.03s/it]
>>>>>>>>>>>>>>>>>>>>>>>>
>>> Step 1: generating actions ...
>>> Step 1: interacting with world model ...
>>>>>>>>>>>>>>>>>>>>>>>>
>>> Step 2: generating actions ...
>>> Step 2: interacting with world model ...
>>>>>>>>>>>>>>>>>>>>>>>>
>>> Step 3: generating actions ...
>>> Step 3: interacting with world model ...
>>>>>>>>>>>>>>>>>>>>>>>>
>>> Step 4: generating actions ...
>>> Step 4: interacting with world model ...
>>>>>>>>>>>>>>>>>>>>>>>>
>>> Step 5: generating actions ...

View File

@@ -0,0 +1,5 @@
{
"gt_video": "unitree_z1_dual_arm_stackbox/case1/unitree_z1_dual_arm_stackbox_case1.mp4",
"pred_video": "unitree_z1_dual_arm_stackbox/case1/output/inference/unitree_z1_dual_arm_stackbox_case1_amd.mp4",
"psnr": 21.258130518117493
}

View File

@@ -2,7 +2,7 @@ res_dir="unitree_z1_dual_arm_stackbox/case1"
dataset="unitree_z1_dual_arm_stackbox" dataset="unitree_z1_dual_arm_stackbox"
{ {
time CUDA_VISIBLE_DEVICES=0 python3 scripts/evaluation/world_model_interaction.py \ time CUDA_VISIBLE_DEVICES=7 python3 scripts/evaluation/world_model_interaction.py \
--seed 123 \ --seed 123 \
--ckpt_path ckpts/unifolm_wma_dual.ckpt \ --ckpt_path ckpts/unifolm_wma_dual.ckpt \
--config configs/inference/world_model_interaction.yaml \ --config configs/inference/world_model_interaction.yaml \

View File

@@ -0,0 +1,134 @@
/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/lightning_fabric/__init__.py:29: UserWarning: pkg_resources is deprecated as an API. See https://setuptools.pypa.io/en/latest/pkg_resources.html. The pkg_resources package is slated for removal as early as 2025-11-30. Refrain from using this package or pin to Setuptools<81.
__import__("pkg_resources").declare_namespace(__name__)
2026-02-08 07:25:18.653033: I tensorflow/core/util/port.cc:113] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2026-02-08 07:25:18.656060: I external/local_tsl/tsl/cuda/cudart_stub.cc:31] Could not find cuda drivers on your machine, GPU will not be used.
2026-02-08 07:25:18.687077: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2026-02-08 07:25:18.687119: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2026-02-08 07:25:18.688915: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2026-02-08 07:25:18.697008: I external/local_tsl/tsl/cuda/cudart_stub.cc:31] Could not find cuda drivers on your machine, GPU will not be used.
2026-02-08 07:25:18.697255: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 AVX512F AVX512_VNNI AVX512_BF16 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
2026-02-08 07:25:19.338303: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT
Global seed set to 123
/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/kornia/feature/lightglue.py:44: FutureWarning: `torch.cuda.amp.custom_fwd(args...)` is deprecated. Please use `torch.amp.custom_fwd(args..., device_type='cuda')` instead.
@torch.cuda.amp.custom_fwd(cast_inputs=torch.float32)
INFO:mainlogger:LatentVisualDiffusion: Running in v-prediction mode
INFO:unifolm_wma.models.diffusion_head.conditional_unet1d:number of parameters: 5.010531e+08
INFO:unifolm_wma.models.diffusion_head.conditional_unet1d:number of parameters: 5.010531e+08
AE working on z of shape (1, 4, 32, 32) = 4096 dimensions.
INFO:root:Loaded ViT-H-14 model config.
DEBUG:urllib3.connectionpool:Starting new HTTPS connection (1): hf-mirror.com:443
DEBUG:urllib3.connectionpool:https://hf-mirror.com:443 "HEAD /laion/CLIP-ViT-H-14-laion2B-s32B-b79K/resolve/main/open_clip_pytorch_model.bin HTTP/1.1" 302 0
INFO:root:Loading pretrained ViT-H-14 weights (laion2b_s32b_b79k).
/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/open_clip/factory.py:88: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.
checkpoint = torch.load(checkpoint_path, map_location=map_location)
INFO:root:Loaded ViT-H-14 model config.
DEBUG:urllib3.connectionpool:https://hf-mirror.com:443 "HEAD /laion/CLIP-ViT-H-14-laion2B-s32B-b79K/resolve/main/open_clip_pytorch_model.bin HTTP/1.1" 302 0
INFO:root:Loading pretrained ViT-H-14 weights (laion2b_s32b_b79k).
/mnt/ASC1637/unifolm-world-model-action/scripts/evaluation/world_model_interaction.py:86: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.
state_dict = torch.load(ckpt, map_location="cpu")
>>> model checkpoint loaded.
>>> Load pre-trained model ...
INFO:root:***** Configing Data *****
>>> unitree_z1_stackbox: 1 data samples loaded.
>>> unitree_z1_stackbox: data stats loaded.
>>> unitree_z1_stackbox: normalizer initiated.
>>> unitree_z1_dual_arm_stackbox: 1 data samples loaded.
>>> unitree_z1_dual_arm_stackbox: data stats loaded.
>>> unitree_z1_dual_arm_stackbox: normalizer initiated.
>>> unitree_z1_dual_arm_stackbox_v2: 1 data samples loaded.
>>> unitree_z1_dual_arm_stackbox_v2: data stats loaded.
>>> unitree_z1_dual_arm_stackbox_v2: normalizer initiated.
>>> unitree_z1_dual_arm_cleanup_pencils: 1 data samples loaded.
>>> unitree_z1_dual_arm_cleanup_pencils: data stats loaded.
>>> unitree_z1_dual_arm_cleanup_pencils: normalizer initiated.
>>> unitree_g1_pack_camera: 1 data samples loaded.
>>> unitree_g1_pack_camera: data stats loaded.
>>> unitree_g1_pack_camera: normalizer initiated.
>>> Dataset is successfully loaded ...
>>> Generate 16 frames under each generation ...
DEBUG:h5py._conv:Creating converter from 3 to 5
DEBUG:PIL.PngImagePlugin:STREAM b'IHDR' 16 13
DEBUG:PIL.PngImagePlugin:STREAM b'pHYs' 41 9
DEBUG:PIL.PngImagePlugin:STREAM b'IDAT' 62 4096
0%| | 0/7 [00:00<?, ?it/s]/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/torch/nn/functional.py:5501: UserWarning: Attempting to use hipBLASLt on an unsupported architecture! Overriding blas backend to hipblas (Triggered internally at ../aten/src/ATen/Context.cpp:296.)
proj = linear(q, w, b)
/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/torch/nn/functional.py:6278: UserWarning: Flash attention support on Navi31 GPU is still experimental. Enable it with TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL=1. (Triggered internally at ../aten/src/ATen/native/transformers/hip/sdp_utils.cpp:225.)
attn_output = scaled_dot_product_attention(
/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/torch/nn/functional.py:6278: UserWarning: Memory Efficient attention on Navi31 GPU is still experimental. Enable it with TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL=1. (Triggered internally at ../aten/src/ATen/native/transformers/hip/sdp_utils.cpp:269.)
attn_output = scaled_dot_product_attention(
>>> Step 0: generating actions ...
>>> Step 0: interacting with world model ...
DEBUG:PIL.Image:Importing BlpImagePlugin
DEBUG:PIL.Image:Importing BmpImagePlugin
DEBUG:PIL.Image:Importing BufrStubImagePlugin
DEBUG:PIL.Image:Importing CurImagePlugin
DEBUG:PIL.Image:Importing DcxImagePlugin
DEBUG:PIL.Image:Importing DdsImagePlugin
DEBUG:PIL.Image:Importing EpsImagePlugin
DEBUG:PIL.Image:Importing FitsImagePlugin
DEBUG:PIL.Image:Importing FitsStubImagePlugin
DEBUG:PIL.Image:Importing FliImagePlugin
DEBUG:PIL.Image:Importing FpxImagePlugin
DEBUG:PIL.Image:Image: failed to import FpxImagePlugin: No module named 'olefile'
DEBUG:PIL.Image:Importing FtexImagePlugin
DEBUG:PIL.Image:Importing GbrImagePlugin
DEBUG:PIL.Image:Importing GifImagePlugin
DEBUG:PIL.Image:Importing GribStubImagePlugin
DEBUG:PIL.Image:Importing Hdf5StubImagePlugin
DEBUG:PIL.Image:Importing IcnsImagePlugin
DEBUG:PIL.Image:Importing IcoImagePlugin
DEBUG:PIL.Image:Importing ImImagePlugin
DEBUG:PIL.Image:Importing ImtImagePlugin
DEBUG:PIL.Image:Importing IptcImagePlugin
DEBUG:PIL.Image:Importing JpegImagePlugin
DEBUG:PIL.Image:Importing Jpeg2KImagePlugin
DEBUG:PIL.Image:Importing McIdasImagePlugin
DEBUG:PIL.Image:Importing MicImagePlugin
DEBUG:PIL.Image:Image: failed to import MicImagePlugin: No module named 'olefile'
DEBUG:PIL.Image:Importing MpegImagePlugin
DEBUG:PIL.Image:Importing MpoImagePlugin
DEBUG:PIL.Image:Importing MspImagePlugin
DEBUG:PIL.Image:Importing PalmImagePlugin
DEBUG:PIL.Image:Importing PcdImagePlugin
DEBUG:PIL.Image:Importing PcxImagePlugin
DEBUG:PIL.Image:Importing PdfImagePlugin
DEBUG:PIL.Image:Importing PixarImagePlugin
DEBUG:PIL.Image:Importing PngImagePlugin
DEBUG:PIL.Image:Importing PpmImagePlugin
DEBUG:PIL.Image:Importing PsdImagePlugin
DEBUG:PIL.Image:Importing QoiImagePlugin
DEBUG:PIL.Image:Importing SgiImagePlugin
DEBUG:PIL.Image:Importing SpiderImagePlugin
DEBUG:PIL.Image:Importing SunImagePlugin
DEBUG:PIL.Image:Importing TgaImagePlugin
DEBUG:PIL.Image:Importing TiffImagePlugin
DEBUG:PIL.Image:Importing WebPImagePlugin
DEBUG:PIL.Image:Importing WmfImagePlugin
DEBUG:PIL.Image:Importing XbmImagePlugin
DEBUG:PIL.Image:Importing XpmImagePlugin
DEBUG:PIL.Image:Importing XVThumbImagePlugin
14%|█▍ | 1/7 [01:39<09:56, 99.35s/it]
29%|██▊ | 2/7 [03:18<08:17, 99.50s/it]
43%|████▎ | 3/7 [04:58<06:38, 99.54s/it]
57%|█████▋ | 4/7 [06:38<04:58, 99.52s/it]
71%|███████▏ | 5/7 [08:17<03:19, 99.55s/it]
86%|████████▌ | 6/7 [09:57<01:39, 99.53s/it]
100%|██████████| 7/7 [11:36<00:00, 99.50s/it]
100%|██████████| 7/7 [11:36<00:00, 99.51s/it]
>>>>>>>>>>>>>>>>>>>>>>>>
>>> Step 1: generating actions ...
>>> Step 1: interacting with world model ...
>>>>>>>>>>>>>>>>>>>>>>>>
>>> Step 2: generating actions ...
>>> Step 2: interacting with world model ...
>>>>>>>>>>>>>>>>>>>>>>>>
>>> Step 3: generating actions ...
>>> Step 3: interacting with world model ...
>>>>>>>>>>>>>>>>>>>>>>>>
>>> Step 4: generating actions ...
>>> Step 4: interacting with world model ...
>>>>>>>>>>>>>>>>>>>>>>>>
>>> Step 5: generating actions ...

View File

@@ -0,0 +1,5 @@
{
"gt_video": "unitree_z1_dual_arm_stackbox/case2/unitree_z1_dual_arm_stackbox_case2.mp4",
"pred_video": "unitree_z1_dual_arm_stackbox/case2/output/inference/unitree_z1_dual_arm_stackbox_case2_amd.mp4",
"psnr": 23.878153424077645
}

View File

@@ -2,7 +2,7 @@ res_dir="unitree_z1_dual_arm_stackbox/case2"
dataset="unitree_z1_dual_arm_stackbox" dataset="unitree_z1_dual_arm_stackbox"
{ {
time CUDA_VISIBLE_DEVICES=0 python3 scripts/evaluation/world_model_interaction.py \ time CUDA_VISIBLE_DEVICES=6 python3 scripts/evaluation/world_model_interaction.py \
--seed 123 \ --seed 123 \
--ckpt_path ckpts/unifolm_wma_dual.ckpt \ --ckpt_path ckpts/unifolm_wma_dual.ckpt \
--config configs/inference/world_model_interaction.yaml \ --config configs/inference/world_model_interaction.yaml \

View File

@@ -0,0 +1,134 @@
/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/lightning_fabric/__init__.py:29: UserWarning: pkg_resources is deprecated as an API. See https://setuptools.pypa.io/en/latest/pkg_resources.html. The pkg_resources package is slated for removal as early as 2025-11-30. Refrain from using this package or pin to Setuptools<81.
__import__("pkg_resources").declare_namespace(__name__)
2026-02-08 07:35:33.682231: I tensorflow/core/util/port.cc:113] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2026-02-08 07:35:33.685275: I external/local_tsl/tsl/cuda/cudart_stub.cc:31] Could not find cuda drivers on your machine, GPU will not be used.
2026-02-08 07:35:33.716682: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2026-02-08 07:35:33.716728: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2026-02-08 07:35:33.718523: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2026-02-08 07:35:33.726756: I external/local_tsl/tsl/cuda/cudart_stub.cc:31] Could not find cuda drivers on your machine, GPU will not be used.
2026-02-08 07:35:33.727105: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 AVX512F AVX512_VNNI AVX512_BF16 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
2026-02-08 07:35:34.356722: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT
[rank: 0] Global seed set to 123
/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/kornia/feature/lightglue.py:44: FutureWarning: `torch.cuda.amp.custom_fwd(args...)` is deprecated. Please use `torch.amp.custom_fwd(args..., device_type='cuda')` instead.
@torch.cuda.amp.custom_fwd(cast_inputs=torch.float32)
INFO:mainlogger:LatentVisualDiffusion: Running in v-prediction mode
INFO:unifolm_wma.models.diffusion_head.conditional_unet1d:number of parameters: 5.010531e+08
INFO:unifolm_wma.models.diffusion_head.conditional_unet1d:number of parameters: 5.010531e+08
AE working on z of shape (1, 4, 32, 32) = 4096 dimensions.
INFO:root:Loaded ViT-H-14 model config.
DEBUG:urllib3.connectionpool:Starting new HTTPS connection (1): hf-mirror.com:443
DEBUG:urllib3.connectionpool:https://hf-mirror.com:443 "HEAD /laion/CLIP-ViT-H-14-laion2B-s32B-b79K/resolve/main/open_clip_pytorch_model.bin HTTP/1.1" 302 0
INFO:root:Loading pretrained ViT-H-14 weights (laion2b_s32b_b79k).
/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/open_clip/factory.py:88: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.
checkpoint = torch.load(checkpoint_path, map_location=map_location)
INFO:root:Loaded ViT-H-14 model config.
DEBUG:urllib3.connectionpool:https://hf-mirror.com:443 "HEAD /laion/CLIP-ViT-H-14-laion2B-s32B-b79K/resolve/main/open_clip_pytorch_model.bin HTTP/1.1" 302 0
INFO:root:Loading pretrained ViT-H-14 weights (laion2b_s32b_b79k).
/mnt/ASC1637/unifolm-world-model-action/scripts/evaluation/world_model_interaction.py:86: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.
state_dict = torch.load(ckpt, map_location="cpu")
>>> model checkpoint loaded.
>>> Load pre-trained model ...
INFO:root:***** Configing Data *****
>>> unitree_z1_stackbox: 1 data samples loaded.
>>> unitree_z1_stackbox: data stats loaded.
>>> unitree_z1_stackbox: normalizer initiated.
>>> unitree_z1_dual_arm_stackbox: 1 data samples loaded.
>>> unitree_z1_dual_arm_stackbox: data stats loaded.
>>> unitree_z1_dual_arm_stackbox: normalizer initiated.
>>> unitree_z1_dual_arm_stackbox_v2: 1 data samples loaded.
>>> unitree_z1_dual_arm_stackbox_v2: data stats loaded.
>>> unitree_z1_dual_arm_stackbox_v2: normalizer initiated.
>>> unitree_z1_dual_arm_cleanup_pencils: 1 data samples loaded.
>>> unitree_z1_dual_arm_cleanup_pencils: data stats loaded.
>>> unitree_z1_dual_arm_cleanup_pencils: normalizer initiated.
>>> unitree_g1_pack_camera: 1 data samples loaded.
>>> unitree_g1_pack_camera: data stats loaded.
>>> unitree_g1_pack_camera: normalizer initiated.
>>> Dataset is successfully loaded ...
>>> Generate 16 frames under each generation ...
DEBUG:h5py._conv:Creating converter from 3 to 5
DEBUG:PIL.PngImagePlugin:STREAM b'IHDR' 16 13
DEBUG:PIL.PngImagePlugin:STREAM b'pHYs' 41 9
DEBUG:PIL.PngImagePlugin:STREAM b'IDAT' 62 4096
0%| | 0/7 [00:00<?, ?it/s]/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/torch/nn/functional.py:5501: UserWarning: Attempting to use hipBLASLt on an unsupported architecture! Overriding blas backend to hipblas (Triggered internally at ../aten/src/ATen/Context.cpp:296.)
proj = linear(q, w, b)
/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/torch/nn/functional.py:6278: UserWarning: Flash attention support on Navi31 GPU is still experimental. Enable it with TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL=1. (Triggered internally at ../aten/src/ATen/native/transformers/hip/sdp_utils.cpp:225.)
attn_output = scaled_dot_product_attention(
/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/torch/nn/functional.py:6278: UserWarning: Memory Efficient attention on Navi31 GPU is still experimental. Enable it with TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL=1. (Triggered internally at ../aten/src/ATen/native/transformers/hip/sdp_utils.cpp:269.)
attn_output = scaled_dot_product_attention(
>>> Step 0: generating actions ...
>>> Step 0: interacting with world model ...
DEBUG:PIL.Image:Importing BlpImagePlugin
DEBUG:PIL.Image:Importing BmpImagePlugin
DEBUG:PIL.Image:Importing BufrStubImagePlugin
DEBUG:PIL.Image:Importing CurImagePlugin
DEBUG:PIL.Image:Importing DcxImagePlugin
DEBUG:PIL.Image:Importing DdsImagePlugin
DEBUG:PIL.Image:Importing EpsImagePlugin
DEBUG:PIL.Image:Importing FitsImagePlugin
DEBUG:PIL.Image:Importing FitsStubImagePlugin
DEBUG:PIL.Image:Importing FliImagePlugin
DEBUG:PIL.Image:Importing FpxImagePlugin
DEBUG:PIL.Image:Image: failed to import FpxImagePlugin: No module named 'olefile'
DEBUG:PIL.Image:Importing FtexImagePlugin
DEBUG:PIL.Image:Importing GbrImagePlugin
DEBUG:PIL.Image:Importing GifImagePlugin
DEBUG:PIL.Image:Importing GribStubImagePlugin
DEBUG:PIL.Image:Importing Hdf5StubImagePlugin
DEBUG:PIL.Image:Importing IcnsImagePlugin
DEBUG:PIL.Image:Importing IcoImagePlugin
DEBUG:PIL.Image:Importing ImImagePlugin
DEBUG:PIL.Image:Importing ImtImagePlugin
DEBUG:PIL.Image:Importing IptcImagePlugin
DEBUG:PIL.Image:Importing JpegImagePlugin
DEBUG:PIL.Image:Importing Jpeg2KImagePlugin
DEBUG:PIL.Image:Importing McIdasImagePlugin
DEBUG:PIL.Image:Importing MicImagePlugin
DEBUG:PIL.Image:Image: failed to import MicImagePlugin: No module named 'olefile'
DEBUG:PIL.Image:Importing MpegImagePlugin
DEBUG:PIL.Image:Importing MpoImagePlugin
DEBUG:PIL.Image:Importing MspImagePlugin
DEBUG:PIL.Image:Importing PalmImagePlugin
DEBUG:PIL.Image:Importing PcdImagePlugin
DEBUG:PIL.Image:Importing PcxImagePlugin
DEBUG:PIL.Image:Importing PdfImagePlugin
DEBUG:PIL.Image:Importing PixarImagePlugin
DEBUG:PIL.Image:Importing PngImagePlugin
DEBUG:PIL.Image:Importing PpmImagePlugin
DEBUG:PIL.Image:Importing PsdImagePlugin
DEBUG:PIL.Image:Importing QoiImagePlugin
DEBUG:PIL.Image:Importing SgiImagePlugin
DEBUG:PIL.Image:Importing SpiderImagePlugin
DEBUG:PIL.Image:Importing SunImagePlugin
DEBUG:PIL.Image:Importing TgaImagePlugin
DEBUG:PIL.Image:Importing TiffImagePlugin
DEBUG:PIL.Image:Importing WebPImagePlugin
DEBUG:PIL.Image:Importing WmfImagePlugin
DEBUG:PIL.Image:Importing XbmImagePlugin
DEBUG:PIL.Image:Importing XpmImagePlugin
DEBUG:PIL.Image:Importing XVThumbImagePlugin
14%|█▍ | 1/7 [01:41<10:06, 101.02s/it]
29%|██▊ | 2/7 [03:23<08:29, 101.84s/it]
43%|████▎ | 3/7 [05:04<06:45, 101.43s/it]
57%|█████▋ | 4/7 [06:45<05:04, 101.42s/it]
71%|███████▏ | 5/7 [08:27<03:22, 101.40s/it]
86%|████████▌ | 6/7 [10:08<01:41, 101.39s/it]
100%|██████████| 7/7 [11:49<00:00, 101.33s/it]
100%|██████████| 7/7 [11:49<00:00, 101.39s/it]
>>>>>>>>>>>>>>>>>>>>>>>>
>>> Step 1: generating actions ...
>>> Step 1: interacting with world model ...
>>>>>>>>>>>>>>>>>>>>>>>>
>>> Step 2: generating actions ...
>>> Step 2: interacting with world model ...
>>>>>>>>>>>>>>>>>>>>>>>>
>>> Step 3: generating actions ...
>>> Step 3: interacting with world model ...
>>>>>>>>>>>>>>>>>>>>>>>>
>>> Step 4: generating actions ...
>>> Step 4: interacting with world model ...
>>>>>>>>>>>>>>>>>>>>>>>>
>>> Step 5: generating actions ...

View File

@@ -0,0 +1,5 @@
{
"gt_video": "unitree_z1_dual_arm_stackbox/case3/unitree_z1_dual_arm_stackbox_case3.mp4",
"pred_video": "unitree_z1_dual_arm_stackbox/case3/output/inference/unitree_z1_dual_arm_stackbox_case3_amd.mp4",
"psnr": 25.400458754751128
}

View File

@@ -0,0 +1,134 @@
/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/lightning_fabric/__init__.py:29: UserWarning: pkg_resources is deprecated as an API. See https://setuptools.pypa.io/en/latest/pkg_resources.html. The pkg_resources package is slated for removal as early as 2025-11-30. Refrain from using this package or pin to Setuptools<81.
__import__("pkg_resources").declare_namespace(__name__)
2026-02-08 07:38:45.572744: I tensorflow/core/util/port.cc:113] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2026-02-08 07:38:45.576864: I external/local_tsl/tsl/cuda/cudart_stub.cc:31] Could not find cuda drivers on your machine, GPU will not be used.
2026-02-08 07:38:45.624825: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2026-02-08 07:38:45.624883: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2026-02-08 07:38:45.627150: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2026-02-08 07:38:45.638316: I external/local_tsl/tsl/cuda/cudart_stub.cc:31] Could not find cuda drivers on your machine, GPU will not be used.
2026-02-08 07:38:45.638803: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 AVX512F AVX512_VNNI AVX512_BF16 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
2026-02-08 07:38:46.426363: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT
[rank: 0] Global seed set to 123
/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/kornia/feature/lightglue.py:44: FutureWarning: `torch.cuda.amp.custom_fwd(args...)` is deprecated. Please use `torch.amp.custom_fwd(args..., device_type='cuda')` instead.
@torch.cuda.amp.custom_fwd(cast_inputs=torch.float32)
INFO:mainlogger:LatentVisualDiffusion: Running in v-prediction mode
INFO:unifolm_wma.models.diffusion_head.conditional_unet1d:number of parameters: 5.010531e+08
INFO:unifolm_wma.models.diffusion_head.conditional_unet1d:number of parameters: 5.010531e+08
AE working on z of shape (1, 4, 32, 32) = 4096 dimensions.
INFO:root:Loaded ViT-H-14 model config.
DEBUG:urllib3.connectionpool:Starting new HTTPS connection (1): hf-mirror.com:443
DEBUG:urllib3.connectionpool:https://hf-mirror.com:443 "HEAD /laion/CLIP-ViT-H-14-laion2B-s32B-b79K/resolve/main/open_clip_pytorch_model.bin HTTP/1.1" 302 0
INFO:root:Loading pretrained ViT-H-14 weights (laion2b_s32b_b79k).
/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/open_clip/factory.py:88: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.
checkpoint = torch.load(checkpoint_path, map_location=map_location)
INFO:root:Loaded ViT-H-14 model config.
DEBUG:urllib3.connectionpool:https://hf-mirror.com:443 "HEAD /laion/CLIP-ViT-H-14-laion2B-s32B-b79K/resolve/main/open_clip_pytorch_model.bin HTTP/1.1" 302 0
INFO:root:Loading pretrained ViT-H-14 weights (laion2b_s32b_b79k).
/mnt/ASC1637/unifolm-world-model-action/scripts/evaluation/world_model_interaction.py:86: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.
state_dict = torch.load(ckpt, map_location="cpu")
>>> model checkpoint loaded.
>>> Load pre-trained model ...
INFO:root:***** Configing Data *****
>>> unitree_z1_stackbox: 1 data samples loaded.
>>> unitree_z1_stackbox: data stats loaded.
>>> unitree_z1_stackbox: normalizer initiated.
>>> unitree_z1_dual_arm_stackbox: 1 data samples loaded.
>>> unitree_z1_dual_arm_stackbox: data stats loaded.
>>> unitree_z1_dual_arm_stackbox: normalizer initiated.
>>> unitree_z1_dual_arm_stackbox_v2: 1 data samples loaded.
>>> unitree_z1_dual_arm_stackbox_v2: data stats loaded.
>>> unitree_z1_dual_arm_stackbox_v2: normalizer initiated.
>>> unitree_z1_dual_arm_cleanup_pencils: 1 data samples loaded.
>>> unitree_z1_dual_arm_cleanup_pencils: data stats loaded.
>>> unitree_z1_dual_arm_cleanup_pencils: normalizer initiated.
>>> unitree_g1_pack_camera: 1 data samples loaded.
>>> unitree_g1_pack_camera: data stats loaded.
>>> unitree_g1_pack_camera: normalizer initiated.
>>> Dataset is successfully loaded ...
>>> Generate 16 frames under each generation ...
DEBUG:h5py._conv:Creating converter from 3 to 5
DEBUG:PIL.PngImagePlugin:STREAM b'IHDR' 16 13
DEBUG:PIL.PngImagePlugin:STREAM b'pHYs' 41 9
DEBUG:PIL.PngImagePlugin:STREAM b'IDAT' 62 4096
0%| | 0/7 [00:00<?, ?it/s]/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/torch/nn/functional.py:5501: UserWarning: Attempting to use hipBLASLt on an unsupported architecture! Overriding blas backend to hipblas (Triggered internally at ../aten/src/ATen/Context.cpp:296.)
proj = linear(q, w, b)
/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/torch/nn/functional.py:6278: UserWarning: Flash attention support on Navi31 GPU is still experimental. Enable it with TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL=1. (Triggered internally at ../aten/src/ATen/native/transformers/hip/sdp_utils.cpp:225.)
attn_output = scaled_dot_product_attention(
/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/torch/nn/functional.py:6278: UserWarning: Memory Efficient attention on Navi31 GPU is still experimental. Enable it with TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL=1. (Triggered internally at ../aten/src/ATen/native/transformers/hip/sdp_utils.cpp:269.)
attn_output = scaled_dot_product_attention(
>>> Step 0: generating actions ...
>>> Step 0: interacting with world model ...
DEBUG:PIL.Image:Importing BlpImagePlugin
DEBUG:PIL.Image:Importing BmpImagePlugin
DEBUG:PIL.Image:Importing BufrStubImagePlugin
DEBUG:PIL.Image:Importing CurImagePlugin
DEBUG:PIL.Image:Importing DcxImagePlugin
DEBUG:PIL.Image:Importing DdsImagePlugin
DEBUG:PIL.Image:Importing EpsImagePlugin
DEBUG:PIL.Image:Importing FitsImagePlugin
DEBUG:PIL.Image:Importing FitsStubImagePlugin
DEBUG:PIL.Image:Importing FliImagePlugin
DEBUG:PIL.Image:Importing FpxImagePlugin
DEBUG:PIL.Image:Image: failed to import FpxImagePlugin: No module named 'olefile'
DEBUG:PIL.Image:Importing FtexImagePlugin
DEBUG:PIL.Image:Importing GbrImagePlugin
DEBUG:PIL.Image:Importing GifImagePlugin
DEBUG:PIL.Image:Importing GribStubImagePlugin
DEBUG:PIL.Image:Importing Hdf5StubImagePlugin
DEBUG:PIL.Image:Importing IcnsImagePlugin
DEBUG:PIL.Image:Importing IcoImagePlugin
DEBUG:PIL.Image:Importing ImImagePlugin
DEBUG:PIL.Image:Importing ImtImagePlugin
DEBUG:PIL.Image:Importing IptcImagePlugin
DEBUG:PIL.Image:Importing JpegImagePlugin
DEBUG:PIL.Image:Importing Jpeg2KImagePlugin
DEBUG:PIL.Image:Importing McIdasImagePlugin
DEBUG:PIL.Image:Importing MicImagePlugin
DEBUG:PIL.Image:Image: failed to import MicImagePlugin: No module named 'olefile'
DEBUG:PIL.Image:Importing MpegImagePlugin
DEBUG:PIL.Image:Importing MpoImagePlugin
DEBUG:PIL.Image:Importing MspImagePlugin
DEBUG:PIL.Image:Importing PalmImagePlugin
DEBUG:PIL.Image:Importing PcdImagePlugin
DEBUG:PIL.Image:Importing PcxImagePlugin
DEBUG:PIL.Image:Importing PdfImagePlugin
DEBUG:PIL.Image:Importing PixarImagePlugin
DEBUG:PIL.Image:Importing PngImagePlugin
DEBUG:PIL.Image:Importing PpmImagePlugin
DEBUG:PIL.Image:Importing PsdImagePlugin
DEBUG:PIL.Image:Importing QoiImagePlugin
DEBUG:PIL.Image:Importing SgiImagePlugin
DEBUG:PIL.Image:Importing SpiderImagePlugin
DEBUG:PIL.Image:Importing SunImagePlugin
DEBUG:PIL.Image:Importing TgaImagePlugin
DEBUG:PIL.Image:Importing TiffImagePlugin
DEBUG:PIL.Image:Importing WebPImagePlugin
DEBUG:PIL.Image:Importing WmfImagePlugin
DEBUG:PIL.Image:Importing XbmImagePlugin
DEBUG:PIL.Image:Importing XpmImagePlugin
DEBUG:PIL.Image:Importing XVThumbImagePlugin
14%|█▍ | 1/7 [01:38<09:52, 98.73s/it]
29%|██▊ | 2/7 [03:17<08:14, 98.85s/it]
43%|████▎ | 3/7 [04:56<06:35, 98.80s/it]
57%|█████▋ | 4/7 [06:35<04:56, 98.94s/it]
71%|███████▏ | 5/7 [08:14<03:17, 98.93s/it]
86%|████████▌ | 6/7 [09:53<01:38, 98.89s/it]
100%|██████████| 7/7 [11:31<00:00, 98.81s/it]
100%|██████████| 7/7 [11:31<00:00, 98.85s/it]
>>>>>>>>>>>>>>>>>>>>>>>>
>>> Step 1: generating actions ...
>>> Step 1: interacting with world model ...
>>>>>>>>>>>>>>>>>>>>>>>>
>>> Step 2: generating actions ...
>>> Step 2: interacting with world model ...
>>>>>>>>>>>>>>>>>>>>>>>>
>>> Step 3: generating actions ...
>>> Step 3: interacting with world model ...
>>>>>>>>>>>>>>>>>>>>>>>>
>>> Step 4: generating actions ...
>>> Step 4: interacting with world model ...
>>>>>>>>>>>>>>>>>>>>>>>>
>>> Step 5: generating actions ...

View File

@@ -0,0 +1,5 @@
{
"gt_video": "unitree_z1_dual_arm_stackbox/case4/unitree_z1_dual_arm_stackbox_case4.mp4",
"pred_video": "unitree_z1_dual_arm_stackbox/case4/output/inference/unitree_z1_dual_arm_stackbox_case4_amd.mp4",
"psnr": 24.098958457373858
}

View File

@@ -1,13 +1,34 @@
2026-02-11 11:59:27.241485: I tensorflow/core/util/port.cc:113] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`. /mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/lightning_fabric/__init__.py:29: UserWarning: pkg_resources is deprecated as an API. See https://setuptools.pypa.io/en/latest/pkg_resources.html. The pkg_resources package is slated for removal as early as 2025-11-30. Refrain from using this package or pin to Setuptools<81.
2026-02-11 11:59:27.291755: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered __import__("pkg_resources").declare_namespace(__name__)
2026-02-11 11:59:27.291807: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered 2026-02-08 07:51:23.961486: I tensorflow/core/util/port.cc:113] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2026-02-11 11:59:27.293169: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered 2026-02-08 07:51:24.200063: I external/local_tsl/tsl/cuda/cudart_stub.cc:31] Could not find cuda drivers on your machine, GPU will not be used.
2026-02-11 11:59:27.300838: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations. 2026-02-08 07:51:24.522299: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
To enable the following instructions: AVX2 AVX512F AVX512_VNNI FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags. 2026-02-08 07:51:24.522350: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2026-02-11 11:59:28.228009: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT 2026-02-08 07:51:24.528237: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2026-02-08 07:51:24.579400: I external/local_tsl/tsl/cuda/cudart_stub.cc:31] Could not find cuda drivers on your machine, GPU will not be used.
2026-02-08 07:51:24.579644: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 AVX512F AVX512_VNNI AVX512_BF16 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
2026-02-08 07:51:25.781311: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT
Global seed set to 123 Global seed set to 123
>>> Loading prepared model from ckpts/unifolm_wma_dual.ckpt.prepared.pt ... /mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/kornia/feature/lightglue.py:44: FutureWarning: `torch.cuda.amp.custom_fwd(args...)` is deprecated. Please use `torch.amp.custom_fwd(args..., device_type='cuda')` instead.
>>> Prepared model loaded. @torch.cuda.amp.custom_fwd(cast_inputs=torch.float32)
INFO:mainlogger:LatentVisualDiffusion: Running in v-prediction mode
INFO:unifolm_wma.models.diffusion_head.conditional_unet1d:number of parameters: 5.010531e+08
INFO:unifolm_wma.models.diffusion_head.conditional_unet1d:number of parameters: 5.010531e+08
AE working on z of shape (1, 4, 32, 32) = 4096 dimensions.
INFO:root:Loaded ViT-H-14 model config.
DEBUG:urllib3.connectionpool:Starting new HTTPS connection (1): hf-mirror.com:443
DEBUG:urllib3.connectionpool:https://hf-mirror.com:443 "HEAD /laion/CLIP-ViT-H-14-laion2B-s32B-b79K/resolve/main/open_clip_pytorch_model.bin HTTP/1.1" 302 0
INFO:root:Loading pretrained ViT-H-14 weights (laion2b_s32b_b79k).
/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/open_clip/factory.py:88: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.
checkpoint = torch.load(checkpoint_path, map_location=map_location)
INFO:root:Loaded ViT-H-14 model config.
DEBUG:urllib3.connectionpool:https://hf-mirror.com:443 "HEAD /laion/CLIP-ViT-H-14-laion2B-s32B-b79K/resolve/main/open_clip_pytorch_model.bin HTTP/1.1" 302 0
INFO:root:Loading pretrained ViT-H-14 weights (laion2b_s32b_b79k).
/mnt/ASC1637/unifolm-world-model-action/scripts/evaluation/world_model_interaction.py:86: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.
state_dict = torch.load(ckpt, map_location="cpu")
>>> model checkpoint loaded.
>>> Load pre-trained model ...
INFO:root:***** Configing Data ***** INFO:root:***** Configing Data *****
>>> unitree_z1_stackbox: 1 data samples loaded. >>> unitree_z1_stackbox: 1 data samples loaded.
>>> unitree_z1_stackbox: data stats loaded. >>> unitree_z1_stackbox: data stats loaded.
@@ -25,16 +46,19 @@ INFO:root:***** Configing Data *****
>>> unitree_g1_pack_camera: data stats loaded. >>> unitree_g1_pack_camera: data stats loaded.
>>> unitree_g1_pack_camera: normalizer initiated. >>> unitree_g1_pack_camera: normalizer initiated.
>>> Dataset is successfully loaded ... >>> Dataset is successfully loaded ...
✓ KV fused: 66 attention layers
>>> Generate 16 frames under each generation ... >>> Generate 16 frames under each generation ...
DEBUG:h5py._conv:Creating converter from 3 to 5 DEBUG:h5py._conv:Creating converter from 3 to 5
DEBUG:PIL.PngImagePlugin:STREAM b'IHDR' 16 13 DEBUG:PIL.PngImagePlugin:STREAM b'IHDR' 16 13
DEBUG:PIL.PngImagePlugin:STREAM b'pHYs' 41 9 DEBUG:PIL.PngImagePlugin:STREAM b'pHYs' 41 9
DEBUG:PIL.PngImagePlugin:STREAM b'IDAT' 62 4096 DEBUG:PIL.PngImagePlugin:STREAM b'IDAT' 62 4096
0%| | 0/11 [00:00<?, ?it/s]/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/torch/nn/functional.py:5501: UserWarning: Attempting to use hipBLASLt on an unsupported architecture! Overriding blas backend to hipblas (Triggered internally at ../aten/src/ATen/Context.cpp:296.)
proj = linear(q, w, b)
/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/torch/nn/functional.py:6278: UserWarning: Flash attention support on Navi31 GPU is still experimental. Enable it with TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL=1. (Triggered internally at ../aten/src/ATen/native/transformers/hip/sdp_utils.cpp:225.)
attn_output = scaled_dot_product_attention(
/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/torch/nn/functional.py:6278: UserWarning: Memory Efficient attention on Navi31 GPU is still experimental. Enable it with TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL=1. (Triggered internally at ../aten/src/ATen/native/transformers/hip/sdp_utils.cpp:269.)
attn_output = scaled_dot_product_attention(
>>> Step 0: generating actions ... >>> Step 0: generating actions ...
9%|▉ | 1/11 [00:34<05:40, 34.05s/it]>>> Step 0: generating actions ...
>>> Step 0: interacting with world model ...
>>> Step 0: interacting with world model ... >>> Step 0: interacting with world model ...
DEBUG:PIL.Image:Importing BlpImagePlugin DEBUG:PIL.Image:Importing BlpImagePlugin
DEBUG:PIL.Image:Importing BmpImagePlugin DEBUG:PIL.Image:Importing BmpImagePlugin
@@ -84,7 +108,9 @@ DEBUG:PIL.Image:Importing WmfImagePlugin
DEBUG:PIL.Image:Importing WmfImagePlugin DEBUG:PIL.Image:Importing WmfImagePlugin
DEBUG:PIL.Image:Importing XbmImagePlugin DEBUG:PIL.Image:Importing XbmImagePlugin
DEBUG:PIL.Image:Importing XpmImagePlugin DEBUG:PIL.Image:Importing XpmImagePlugin
DEBUG:PIL.Image:Importing XpmImagePlugin DEBUG:PIL.Image:Importing XVThumbImagePlugin
9%|▉ | 1/11 [01:38<16:20, 98.04s/it]
18%|█▊ | 2/11 [03:15<14:40, 97.81s/it] 18%|█▊ | 2/11 [03:15<14:40, 97.81s/it]
27%|██▋ | 3/11 [04:53<13:01, 97.72s/it] 27%|██▋ | 3/11 [04:53<13:01, 97.72s/it]
36%|███▋ | 4/11 [06:31<11:24, 97.71s/it] 36%|███▋ | 4/11 [06:31<11:24, 97.71s/it]
@@ -115,6 +141,6 @@ DEBUG:PIL.Image:Importing XVThumbImagePlugin
>>> Step 6: generating actions ... >>> Step 6: generating actions ...
>>> Step 6: interacting with world model ... >>> Step 6: interacting with world model ...
>>>>>>>>>>>>>>>>>>>>>>>> >>>>>>>>>>>>>>>>>>>>>>>>
>>> Step 7: generating actions ... >>> Step 7: generating actions ...
>>> Step 7: interacting with world model ... >>> Step 7: interacting with world model ...
>>>>>>>>>>>>>>>>>>>>>>>> >>>>>>>>>>>>>>>>>>>>>>>>

View File

@@ -1,5 +1,5 @@
{ {
"gt_video": "/home/qhy/unifolm-world-model-action/unitree_z1_dual_arm_stackbox_v2/case1/unitree_z1_dual_arm_stackbox_v2_case1.mp4", "gt_video": "unitree_z1_dual_arm_stackbox_v2/case1/unitree_z1_dual_arm_stackbox_v2_case1.mp4",
"pred_video": "/home/qhy/unifolm-world-model-action/unitree_z1_dual_arm_stackbox_v2/case1/output/inference/5_full_fs4.mp4", "pred_video": "unitree_z1_dual_arm_stackbox_v2/case1/output/inference/unitree_z1_dual_arm_stackbox_v2_case1_amd.mp4",
"psnr": 28.167025381705358 "psnr": 18.126776535969576
} }

View File

@@ -2,7 +2,7 @@ res_dir="unitree_z1_dual_arm_stackbox_v2/case1"
dataset="unitree_z1_dual_arm_stackbox_v2" dataset="unitree_z1_dual_arm_stackbox_v2"
{ {
time CUDA_VISIBLE_DEVICES=0 python3 scripts/evaluation/world_model_interaction.py \ time CUDA_VISIBLE_DEVICES=7 python3 scripts/evaluation/world_model_interaction.py \
--seed 123 \ --seed 123 \
--ckpt_path ckpts/unifolm_wma_dual.ckpt \ --ckpt_path ckpts/unifolm_wma_dual.ckpt \
--config configs/inference/world_model_interaction.yaml \ --config configs/inference/world_model_interaction.yaml \
@@ -20,6 +20,5 @@ dataset="unitree_z1_dual_arm_stackbox_v2"
--n_iter 11 \ --n_iter 11 \
--timestep_spacing 'uniform_trailing' \ --timestep_spacing 'uniform_trailing' \
--guidance_rescale 0.7 \ --guidance_rescale 0.7 \
--perframe_ae \ --perframe_ae
--fast_policy_no_decode
} 2>&1 | tee "${res_dir}/output.log" } 2>&1 | tee "${res_dir}/output.log"

View File

@@ -0,0 +1,146 @@
/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/lightning_fabric/__init__.py:29: UserWarning: pkg_resources is deprecated as an API. See https://setuptools.pypa.io/en/latest/pkg_resources.html. The pkg_resources package is slated for removal as early as 2025-11-30. Refrain from using this package or pin to Setuptools<81.
__import__("pkg_resources").declare_namespace(__name__)
2026-02-08 07:56:31.144789: I tensorflow/core/util/port.cc:113] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2026-02-08 07:56:31.148256: I external/local_tsl/tsl/cuda/cudart_stub.cc:31] Could not find cuda drivers on your machine, GPU will not be used.
2026-02-08 07:56:31.178870: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2026-02-08 07:56:31.178898: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2026-02-08 07:56:31.180683: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2026-02-08 07:56:31.188800: I external/local_tsl/tsl/cuda/cudart_stub.cc:31] Could not find cuda drivers on your machine, GPU will not be used.
2026-02-08 07:56:31.189142: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 AVX512F AVX512_VNNI AVX512_BF16 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
2026-02-08 07:56:31.810098: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT
[rank: 0] Global seed set to 123
/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/kornia/feature/lightglue.py:44: FutureWarning: `torch.cuda.amp.custom_fwd(args...)` is deprecated. Please use `torch.amp.custom_fwd(args..., device_type='cuda')` instead.
@torch.cuda.amp.custom_fwd(cast_inputs=torch.float32)
INFO:mainlogger:LatentVisualDiffusion: Running in v-prediction mode
INFO:unifolm_wma.models.diffusion_head.conditional_unet1d:number of parameters: 5.010531e+08
INFO:unifolm_wma.models.diffusion_head.conditional_unet1d:number of parameters: 5.010531e+08
AE working on z of shape (1, 4, 32, 32) = 4096 dimensions.
INFO:root:Loaded ViT-H-14 model config.
DEBUG:urllib3.connectionpool:Starting new HTTPS connection (1): hf-mirror.com:443
DEBUG:urllib3.connectionpool:https://hf-mirror.com:443 "HEAD /laion/CLIP-ViT-H-14-laion2B-s32B-b79K/resolve/main/open_clip_pytorch_model.bin HTTP/1.1" 302 0
INFO:root:Loading pretrained ViT-H-14 weights (laion2b_s32b_b79k).
/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/open_clip/factory.py:88: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.
checkpoint = torch.load(checkpoint_path, map_location=map_location)
INFO:root:Loaded ViT-H-14 model config.
DEBUG:urllib3.connectionpool:https://hf-mirror.com:443 "HEAD /laion/CLIP-ViT-H-14-laion2B-s32B-b79K/resolve/main/open_clip_pytorch_model.bin HTTP/1.1" 302 0
INFO:root:Loading pretrained ViT-H-14 weights (laion2b_s32b_b79k).
/mnt/ASC1637/unifolm-world-model-action/scripts/evaluation/world_model_interaction.py:86: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.
state_dict = torch.load(ckpt, map_location="cpu")
>>> model checkpoint loaded.
>>> Load pre-trained model ...
INFO:root:***** Configing Data *****
>>> unitree_z1_stackbox: 1 data samples loaded.
>>> unitree_z1_stackbox: data stats loaded.
>>> unitree_z1_stackbox: normalizer initiated.
>>> unitree_z1_dual_arm_stackbox: 1 data samples loaded.
>>> unitree_z1_dual_arm_stackbox: data stats loaded.
>>> unitree_z1_dual_arm_stackbox: normalizer initiated.
>>> unitree_z1_dual_arm_stackbox_v2: 1 data samples loaded.
>>> unitree_z1_dual_arm_stackbox_v2: data stats loaded.
>>> unitree_z1_dual_arm_stackbox_v2: normalizer initiated.
>>> unitree_z1_dual_arm_cleanup_pencils: 1 data samples loaded.
>>> unitree_z1_dual_arm_cleanup_pencils: data stats loaded.
>>> unitree_z1_dual_arm_cleanup_pencils: normalizer initiated.
>>> unitree_g1_pack_camera: 1 data samples loaded.
>>> unitree_g1_pack_camera: data stats loaded.
>>> unitree_g1_pack_camera: normalizer initiated.
>>> Dataset is successfully loaded ...
>>> Generate 16 frames under each generation ...
DEBUG:h5py._conv:Creating converter from 3 to 5
DEBUG:PIL.PngImagePlugin:STREAM b'IHDR' 16 13
DEBUG:PIL.PngImagePlugin:STREAM b'pHYs' 41 9
DEBUG:PIL.PngImagePlugin:STREAM b'IDAT' 62 4096
0%| | 0/11 [00:00<?, ?it/s]/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/torch/nn/functional.py:5501: UserWarning: Attempting to use hipBLASLt on an unsupported architecture! Overriding blas backend to hipblas (Triggered internally at ../aten/src/ATen/Context.cpp:296.)
proj = linear(q, w, b)
/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/torch/nn/functional.py:6278: UserWarning: Flash attention support on Navi31 GPU is still experimental. Enable it with TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL=1. (Triggered internally at ../aten/src/ATen/native/transformers/hip/sdp_utils.cpp:225.)
attn_output = scaled_dot_product_attention(
/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/torch/nn/functional.py:6278: UserWarning: Memory Efficient attention on Navi31 GPU is still experimental. Enable it with TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL=1. (Triggered internally at ../aten/src/ATen/native/transformers/hip/sdp_utils.cpp:269.)
attn_output = scaled_dot_product_attention(
>>> Step 0: generating actions ...
>>> Step 0: interacting with world model ...
DEBUG:PIL.Image:Importing BlpImagePlugin
DEBUG:PIL.Image:Importing BmpImagePlugin
DEBUG:PIL.Image:Importing BufrStubImagePlugin
DEBUG:PIL.Image:Importing CurImagePlugin
DEBUG:PIL.Image:Importing DcxImagePlugin
DEBUG:PIL.Image:Importing DdsImagePlugin
DEBUG:PIL.Image:Importing EpsImagePlugin
DEBUG:PIL.Image:Importing FitsImagePlugin
DEBUG:PIL.Image:Importing FitsStubImagePlugin
DEBUG:PIL.Image:Importing FliImagePlugin
DEBUG:PIL.Image:Importing FpxImagePlugin
DEBUG:PIL.Image:Image: failed to import FpxImagePlugin: No module named 'olefile'
DEBUG:PIL.Image:Importing FtexImagePlugin
DEBUG:PIL.Image:Importing GbrImagePlugin
DEBUG:PIL.Image:Importing GifImagePlugin
DEBUG:PIL.Image:Importing GribStubImagePlugin
DEBUG:PIL.Image:Importing Hdf5StubImagePlugin
DEBUG:PIL.Image:Importing IcnsImagePlugin
DEBUG:PIL.Image:Importing IcoImagePlugin
DEBUG:PIL.Image:Importing ImImagePlugin
DEBUG:PIL.Image:Importing ImtImagePlugin
DEBUG:PIL.Image:Importing IptcImagePlugin
DEBUG:PIL.Image:Importing JpegImagePlugin
DEBUG:PIL.Image:Importing Jpeg2KImagePlugin
DEBUG:PIL.Image:Importing McIdasImagePlugin
DEBUG:PIL.Image:Importing MicImagePlugin
DEBUG:PIL.Image:Image: failed to import MicImagePlugin: No module named 'olefile'
DEBUG:PIL.Image:Importing MpegImagePlugin
DEBUG:PIL.Image:Importing MpoImagePlugin
DEBUG:PIL.Image:Importing MspImagePlugin
DEBUG:PIL.Image:Importing PalmImagePlugin
DEBUG:PIL.Image:Importing PcdImagePlugin
DEBUG:PIL.Image:Importing PcxImagePlugin
DEBUG:PIL.Image:Importing PdfImagePlugin
DEBUG:PIL.Image:Importing PixarImagePlugin
DEBUG:PIL.Image:Importing PngImagePlugin
DEBUG:PIL.Image:Importing PpmImagePlugin
DEBUG:PIL.Image:Importing PsdImagePlugin
DEBUG:PIL.Image:Importing QoiImagePlugin
DEBUG:PIL.Image:Importing SgiImagePlugin
DEBUG:PIL.Image:Importing SpiderImagePlugin
DEBUG:PIL.Image:Importing SunImagePlugin
DEBUG:PIL.Image:Importing TgaImagePlugin
DEBUG:PIL.Image:Importing TiffImagePlugin
DEBUG:PIL.Image:Importing WebPImagePlugin
DEBUG:PIL.Image:Importing WmfImagePlugin
DEBUG:PIL.Image:Importing XbmImagePlugin
DEBUG:PIL.Image:Importing XpmImagePlugin
DEBUG:PIL.Image:Importing XVThumbImagePlugin
9%|▉ | 1/11 [01:40<16:41, 100.16s/it]
18%|█▊ | 2/11 [03:20<15:04, 100.47s/it]
27%|██▋ | 3/11 [05:01<13:24, 100.62s/it]
36%|███▋ | 4/11 [06:42<11:44, 100.69s/it]
45%|████▌ | 5/11 [08:22<10:02, 100.48s/it]
55%|█████▍ | 6/11 [10:02<08:21, 100.33s/it]
64%|██████▎ | 7/11 [11:42<06:40, 100.23s/it]
73%|███████▎ | 8/11 [13:22<05:00, 100.23s/it]
82%|████████▏ | 9/11 [15:03<03:20, 100.23s/it]
91%|█████████ | 10/11 [16:43<01:40, 100.33s/it]
100%|██████████| 11/11 [18:24<00:00, 100.41s/it]
100%|██████████| 11/11 [18:24<00:00, 100.39s/it]
>>>>>>>>>>>>>>>>>>>>>>>>
>>> Step 1: generating actions ...
>>> Step 1: interacting with world model ...
>>>>>>>>>>>>>>>>>>>>>>>>
>>> Step 2: generating actions ...
>>> Step 2: interacting with world model ...
>>>>>>>>>>>>>>>>>>>>>>>>
>>> Step 3: generating actions ...
>>> Step 3: interacting with world model ...
>>>>>>>>>>>>>>>>>>>>>>>>
>>> Step 4: generating actions ...
>>> Step 4: interacting with world model ...
>>>>>>>>>>>>>>>>>>>>>>>>
>>> Step 5: generating actions ...
>>> Step 5: interacting with world model ...
>>>>>>>>>>>>>>>>>>>>>>>>
>>> Step 6: generating actions ...
>>> Step 6: interacting with world model ...
>>>>>>>>>>>>>>>>>>>>>>>>
>>> Step 7: generating actions ...
>>> Step 7: interacting with world model ...
>>>>>>>>>>>>>>>>>>>>>>>>

View File

@@ -0,0 +1,5 @@
{
"gt_video": "unitree_z1_dual_arm_stackbox_v2/case2/unitree_z1_dual_arm_stackbox_v2_case2.mp4",
"pred_video": "unitree_z1_dual_arm_stackbox_v2/case2/output/inference/unitree_z1_dual_arm_stackbox_v2_case2_amd.mp4",
"psnr": 19.38130614773096
}

View File

@@ -0,0 +1,146 @@
/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/lightning_fabric/__init__.py:29: UserWarning: pkg_resources is deprecated as an API. See https://setuptools.pypa.io/en/latest/pkg_resources.html. The pkg_resources package is slated for removal as early as 2025-11-30. Refrain from using this package or pin to Setuptools<81.
__import__("pkg_resources").declare_namespace(__name__)
2026-02-08 07:56:04.467082: I tensorflow/core/util/port.cc:113] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2026-02-08 07:56:04.470145: I external/local_tsl/tsl/cuda/cudart_stub.cc:31] Could not find cuda drivers on your machine, GPU will not be used.
2026-02-08 07:56:04.502248: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2026-02-08 07:56:04.502277: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2026-02-08 07:56:04.504088: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2026-02-08 07:56:04.512557: I external/local_tsl/tsl/cuda/cudart_stub.cc:31] Could not find cuda drivers on your machine, GPU will not be used.
2026-02-08 07:56:04.512830: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 AVX512F AVX512_VNNI AVX512_BF16 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
2026-02-08 07:56:05.259641: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT
[rank: 0] Global seed set to 123
/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/kornia/feature/lightglue.py:44: FutureWarning: `torch.cuda.amp.custom_fwd(args...)` is deprecated. Please use `torch.amp.custom_fwd(args..., device_type='cuda')` instead.
@torch.cuda.amp.custom_fwd(cast_inputs=torch.float32)
INFO:mainlogger:LatentVisualDiffusion: Running in v-prediction mode
INFO:unifolm_wma.models.diffusion_head.conditional_unet1d:number of parameters: 5.010531e+08
INFO:unifolm_wma.models.diffusion_head.conditional_unet1d:number of parameters: 5.010531e+08
AE working on z of shape (1, 4, 32, 32) = 4096 dimensions.
INFO:root:Loaded ViT-H-14 model config.
DEBUG:urllib3.connectionpool:Starting new HTTPS connection (1): hf-mirror.com:443
DEBUG:urllib3.connectionpool:https://hf-mirror.com:443 "HEAD /laion/CLIP-ViT-H-14-laion2B-s32B-b79K/resolve/main/open_clip_pytorch_model.bin HTTP/1.1" 302 0
INFO:root:Loading pretrained ViT-H-14 weights (laion2b_s32b_b79k).
/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/open_clip/factory.py:88: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.
checkpoint = torch.load(checkpoint_path, map_location=map_location)
INFO:root:Loaded ViT-H-14 model config.
DEBUG:urllib3.connectionpool:https://hf-mirror.com:443 "HEAD /laion/CLIP-ViT-H-14-laion2B-s32B-b79K/resolve/main/open_clip_pytorch_model.bin HTTP/1.1" 302 0
INFO:root:Loading pretrained ViT-H-14 weights (laion2b_s32b_b79k).
/mnt/ASC1637/unifolm-world-model-action/scripts/evaluation/world_model_interaction.py:86: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.
state_dict = torch.load(ckpt, map_location="cpu")
>>> model checkpoint loaded.
>>> Load pre-trained model ...
INFO:root:***** Configing Data *****
>>> unitree_z1_stackbox: 1 data samples loaded.
>>> unitree_z1_stackbox: data stats loaded.
>>> unitree_z1_stackbox: normalizer initiated.
>>> unitree_z1_dual_arm_stackbox: 1 data samples loaded.
>>> unitree_z1_dual_arm_stackbox: data stats loaded.
>>> unitree_z1_dual_arm_stackbox: normalizer initiated.
>>> unitree_z1_dual_arm_stackbox_v2: 1 data samples loaded.
>>> unitree_z1_dual_arm_stackbox_v2: data stats loaded.
>>> unitree_z1_dual_arm_stackbox_v2: normalizer initiated.
>>> unitree_z1_dual_arm_cleanup_pencils: 1 data samples loaded.
>>> unitree_z1_dual_arm_cleanup_pencils: data stats loaded.
>>> unitree_z1_dual_arm_cleanup_pencils: normalizer initiated.
>>> unitree_g1_pack_camera: 1 data samples loaded.
>>> unitree_g1_pack_camera: data stats loaded.
>>> unitree_g1_pack_camera: normalizer initiated.
>>> Dataset is successfully loaded ...
>>> Generate 16 frames under each generation ...
DEBUG:h5py._conv:Creating converter from 3 to 5
DEBUG:PIL.PngImagePlugin:STREAM b'IHDR' 16 13
DEBUG:PIL.PngImagePlugin:STREAM b'pHYs' 41 9
DEBUG:PIL.PngImagePlugin:STREAM b'IDAT' 62 4096
0%| | 0/11 [00:00<?, ?it/s]/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/torch/nn/functional.py:5501: UserWarning: Attempting to use hipBLASLt on an unsupported architecture! Overriding blas backend to hipblas (Triggered internally at ../aten/src/ATen/Context.cpp:296.)
proj = linear(q, w, b)
/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/torch/nn/functional.py:6278: UserWarning: Flash attention support on Navi31 GPU is still experimental. Enable it with TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL=1. (Triggered internally at ../aten/src/ATen/native/transformers/hip/sdp_utils.cpp:225.)
attn_output = scaled_dot_product_attention(
/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/torch/nn/functional.py:6278: UserWarning: Memory Efficient attention on Navi31 GPU is still experimental. Enable it with TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL=1. (Triggered internally at ../aten/src/ATen/native/transformers/hip/sdp_utils.cpp:269.)
attn_output = scaled_dot_product_attention(
>>> Step 0: generating actions ...
>>> Step 0: interacting with world model ...
DEBUG:PIL.Image:Importing BlpImagePlugin
DEBUG:PIL.Image:Importing BmpImagePlugin
DEBUG:PIL.Image:Importing BufrStubImagePlugin
DEBUG:PIL.Image:Importing CurImagePlugin
DEBUG:PIL.Image:Importing DcxImagePlugin
DEBUG:PIL.Image:Importing DdsImagePlugin
DEBUG:PIL.Image:Importing EpsImagePlugin
DEBUG:PIL.Image:Importing FitsImagePlugin
DEBUG:PIL.Image:Importing FitsStubImagePlugin
DEBUG:PIL.Image:Importing FliImagePlugin
DEBUG:PIL.Image:Importing FpxImagePlugin
DEBUG:PIL.Image:Image: failed to import FpxImagePlugin: No module named 'olefile'
DEBUG:PIL.Image:Importing FtexImagePlugin
DEBUG:PIL.Image:Importing GbrImagePlugin
DEBUG:PIL.Image:Importing GifImagePlugin
DEBUG:PIL.Image:Importing GribStubImagePlugin
DEBUG:PIL.Image:Importing Hdf5StubImagePlugin
DEBUG:PIL.Image:Importing IcnsImagePlugin
DEBUG:PIL.Image:Importing IcoImagePlugin
DEBUG:PIL.Image:Importing ImImagePlugin
DEBUG:PIL.Image:Importing ImtImagePlugin
DEBUG:PIL.Image:Importing IptcImagePlugin
DEBUG:PIL.Image:Importing JpegImagePlugin
DEBUG:PIL.Image:Importing Jpeg2KImagePlugin
DEBUG:PIL.Image:Importing McIdasImagePlugin
DEBUG:PIL.Image:Importing MicImagePlugin
DEBUG:PIL.Image:Image: failed to import MicImagePlugin: No module named 'olefile'
DEBUG:PIL.Image:Importing MpegImagePlugin
DEBUG:PIL.Image:Importing MpoImagePlugin
DEBUG:PIL.Image:Importing MspImagePlugin
DEBUG:PIL.Image:Importing PalmImagePlugin
DEBUG:PIL.Image:Importing PcdImagePlugin
DEBUG:PIL.Image:Importing PcxImagePlugin
DEBUG:PIL.Image:Importing PdfImagePlugin
DEBUG:PIL.Image:Importing PixarImagePlugin
DEBUG:PIL.Image:Importing PngImagePlugin
DEBUG:PIL.Image:Importing PpmImagePlugin
DEBUG:PIL.Image:Importing PsdImagePlugin
DEBUG:PIL.Image:Importing QoiImagePlugin
DEBUG:PIL.Image:Importing SgiImagePlugin
DEBUG:PIL.Image:Importing SpiderImagePlugin
DEBUG:PIL.Image:Importing SunImagePlugin
DEBUG:PIL.Image:Importing TgaImagePlugin
DEBUG:PIL.Image:Importing TiffImagePlugin
DEBUG:PIL.Image:Importing WebPImagePlugin
DEBUG:PIL.Image:Importing WmfImagePlugin
DEBUG:PIL.Image:Importing XbmImagePlugin
DEBUG:PIL.Image:Importing XpmImagePlugin
DEBUG:PIL.Image:Importing XVThumbImagePlugin
9%|▉ | 1/11 [01:38<16:20, 98.03s/it]
18%|█▊ | 2/11 [03:16<14:43, 98.19s/it]
27%|██▋ | 3/11 [04:55<13:08, 98.54s/it]
36%|███▋ | 4/11 [06:33<11:29, 98.52s/it]
45%|████▌ | 5/11 [08:11<09:50, 98.38s/it]
55%|█████▍ | 6/11 [09:49<08:10, 98.11s/it]
64%|██████▎ | 7/11 [11:27<06:31, 97.97s/it]
73%|███████▎ | 8/11 [13:04<04:53, 97.83s/it]
82%|████████▏ | 9/11 [14:42<03:15, 97.72s/it]
91%|█████████ | 10/11 [16:19<01:37, 97.71s/it]
100%|██████████| 11/11 [17:57<00:00, 97.74s/it]
100%|██████████| 11/11 [17:57<00:00, 97.97s/it]
>>>>>>>>>>>>>>>>>>>>>>>>
>>> Step 1: generating actions ...
>>> Step 1: interacting with world model ...
>>>>>>>>>>>>>>>>>>>>>>>>
>>> Step 2: generating actions ...
>>> Step 2: interacting with world model ...
>>>>>>>>>>>>>>>>>>>>>>>>
>>> Step 3: generating actions ...
>>> Step 3: interacting with world model ...
>>>>>>>>>>>>>>>>>>>>>>>>
>>> Step 4: generating actions ...
>>> Step 4: interacting with world model ...
>>>>>>>>>>>>>>>>>>>>>>>>
>>> Step 5: generating actions ...
>>> Step 5: interacting with world model ...
>>>>>>>>>>>>>>>>>>>>>>>>
>>> Step 6: generating actions ...
>>> Step 6: interacting with world model ...
>>>>>>>>>>>>>>>>>>>>>>>>
>>> Step 7: generating actions ...
>>> Step 7: interacting with world model ...
>>>>>>>>>>>>>>>>>>>>>>>>

View File

@@ -0,0 +1,5 @@
{
"gt_video": "unitree_z1_dual_arm_stackbox_v2/case3/unitree_z1_dual_arm_stackbox_v2_case3.mp4",
"pred_video": "unitree_z1_dual_arm_stackbox_v2/case3/output/inference/unitree_z1_dual_arm_stackbox_v2_case3_amd.mp4",
"psnr": 18.74462122425683
}

View File

@@ -0,0 +1,146 @@
/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/lightning_fabric/__init__.py:29: UserWarning: pkg_resources is deprecated as an API. See https://setuptools.pypa.io/en/latest/pkg_resources.html. The pkg_resources package is slated for removal as early as 2025-11-30. Refrain from using this package or pin to Setuptools<81.
__import__("pkg_resources").declare_namespace(__name__)
2026-02-08 08:04:16.104516: I tensorflow/core/util/port.cc:113] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2026-02-08 08:04:16.109112: I external/local_tsl/tsl/cuda/cudart_stub.cc:31] Could not find cuda drivers on your machine, GPU will not be used.
2026-02-08 08:04:16.138703: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2026-02-08 08:04:16.138737: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2026-02-08 08:04:16.140302: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2026-02-08 08:04:16.147672: I external/local_tsl/tsl/cuda/cudart_stub.cc:31] Could not find cuda drivers on your machine, GPU will not be used.
2026-02-08 08:04:16.147903: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 AVX512F AVX512_VNNI AVX512_BF16 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
2026-02-08 08:04:17.363218: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT
Global seed set to 123
/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/kornia/feature/lightglue.py:44: FutureWarning: `torch.cuda.amp.custom_fwd(args...)` is deprecated. Please use `torch.amp.custom_fwd(args..., device_type='cuda')` instead.
@torch.cuda.amp.custom_fwd(cast_inputs=torch.float32)
INFO:mainlogger:LatentVisualDiffusion: Running in v-prediction mode
INFO:unifolm_wma.models.diffusion_head.conditional_unet1d:number of parameters: 5.010531e+08
INFO:unifolm_wma.models.diffusion_head.conditional_unet1d:number of parameters: 5.010531e+08
AE working on z of shape (1, 4, 32, 32) = 4096 dimensions.
INFO:root:Loaded ViT-H-14 model config.
DEBUG:urllib3.connectionpool:Starting new HTTPS connection (1): hf-mirror.com:443
DEBUG:urllib3.connectionpool:https://hf-mirror.com:443 "HEAD /laion/CLIP-ViT-H-14-laion2B-s32B-b79K/resolve/main/open_clip_pytorch_model.bin HTTP/1.1" 302 0
INFO:root:Loading pretrained ViT-H-14 weights (laion2b_s32b_b79k).
/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/open_clip/factory.py:88: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.
checkpoint = torch.load(checkpoint_path, map_location=map_location)
INFO:root:Loaded ViT-H-14 model config.
DEBUG:urllib3.connectionpool:https://hf-mirror.com:443 "HEAD /laion/CLIP-ViT-H-14-laion2B-s32B-b79K/resolve/main/open_clip_pytorch_model.bin HTTP/1.1" 302 0
INFO:root:Loading pretrained ViT-H-14 weights (laion2b_s32b_b79k).
/mnt/ASC1637/unifolm-world-model-action/scripts/evaluation/world_model_interaction.py:86: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.
state_dict = torch.load(ckpt, map_location="cpu")
>>> model checkpoint loaded.
>>> Load pre-trained model ...
INFO:root:***** Configing Data *****
>>> unitree_z1_stackbox: 1 data samples loaded.
>>> unitree_z1_stackbox: data stats loaded.
>>> unitree_z1_stackbox: normalizer initiated.
>>> unitree_z1_dual_arm_stackbox: 1 data samples loaded.
>>> unitree_z1_dual_arm_stackbox: data stats loaded.
>>> unitree_z1_dual_arm_stackbox: normalizer initiated.
>>> unitree_z1_dual_arm_stackbox_v2: 1 data samples loaded.
>>> unitree_z1_dual_arm_stackbox_v2: data stats loaded.
>>> unitree_z1_dual_arm_stackbox_v2: normalizer initiated.
>>> unitree_z1_dual_arm_cleanup_pencils: 1 data samples loaded.
>>> unitree_z1_dual_arm_cleanup_pencils: data stats loaded.
>>> unitree_z1_dual_arm_cleanup_pencils: normalizer initiated.
>>> unitree_g1_pack_camera: 1 data samples loaded.
>>> unitree_g1_pack_camera: data stats loaded.
>>> unitree_g1_pack_camera: normalizer initiated.
>>> Dataset is successfully loaded ...
>>> Generate 16 frames under each generation ...
DEBUG:h5py._conv:Creating converter from 3 to 5
DEBUG:PIL.PngImagePlugin:STREAM b'IHDR' 16 13
DEBUG:PIL.PngImagePlugin:STREAM b'pHYs' 41 9
DEBUG:PIL.PngImagePlugin:STREAM b'IDAT' 62 4096
0%| | 0/11 [00:00<?, ?it/s]/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/torch/nn/functional.py:5501: UserWarning: Attempting to use hipBLASLt on an unsupported architecture! Overriding blas backend to hipblas (Triggered internally at ../aten/src/ATen/Context.cpp:296.)
proj = linear(q, w, b)
/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/torch/nn/functional.py:6278: UserWarning: Flash attention support on Navi31 GPU is still experimental. Enable it with TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL=1. (Triggered internally at ../aten/src/ATen/native/transformers/hip/sdp_utils.cpp:225.)
attn_output = scaled_dot_product_attention(
/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/torch/nn/functional.py:6278: UserWarning: Memory Efficient attention on Navi31 GPU is still experimental. Enable it with TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL=1. (Triggered internally at ../aten/src/ATen/native/transformers/hip/sdp_utils.cpp:269.)
attn_output = scaled_dot_product_attention(
>>> Step 0: generating actions ...
>>> Step 0: interacting with world model ...
DEBUG:PIL.Image:Importing BlpImagePlugin
DEBUG:PIL.Image:Importing BmpImagePlugin
DEBUG:PIL.Image:Importing BufrStubImagePlugin
DEBUG:PIL.Image:Importing CurImagePlugin
DEBUG:PIL.Image:Importing DcxImagePlugin
DEBUG:PIL.Image:Importing DdsImagePlugin
DEBUG:PIL.Image:Importing EpsImagePlugin
DEBUG:PIL.Image:Importing FitsImagePlugin
DEBUG:PIL.Image:Importing FitsStubImagePlugin
DEBUG:PIL.Image:Importing FliImagePlugin
DEBUG:PIL.Image:Importing FpxImagePlugin
DEBUG:PIL.Image:Image: failed to import FpxImagePlugin: No module named 'olefile'
DEBUG:PIL.Image:Importing FtexImagePlugin
DEBUG:PIL.Image:Importing GbrImagePlugin
DEBUG:PIL.Image:Importing GifImagePlugin
DEBUG:PIL.Image:Importing GribStubImagePlugin
DEBUG:PIL.Image:Importing Hdf5StubImagePlugin
DEBUG:PIL.Image:Importing IcnsImagePlugin
DEBUG:PIL.Image:Importing IcoImagePlugin
DEBUG:PIL.Image:Importing ImImagePlugin
DEBUG:PIL.Image:Importing ImtImagePlugin
DEBUG:PIL.Image:Importing IptcImagePlugin
DEBUG:PIL.Image:Importing JpegImagePlugin
DEBUG:PIL.Image:Importing Jpeg2KImagePlugin
DEBUG:PIL.Image:Importing McIdasImagePlugin
DEBUG:PIL.Image:Importing MicImagePlugin
DEBUG:PIL.Image:Image: failed to import MicImagePlugin: No module named 'olefile'
DEBUG:PIL.Image:Importing MpegImagePlugin
DEBUG:PIL.Image:Importing MpoImagePlugin
DEBUG:PIL.Image:Importing MspImagePlugin
DEBUG:PIL.Image:Importing PalmImagePlugin
DEBUG:PIL.Image:Importing PcdImagePlugin
DEBUG:PIL.Image:Importing PcxImagePlugin
DEBUG:PIL.Image:Importing PdfImagePlugin
DEBUG:PIL.Image:Importing PixarImagePlugin
DEBUG:PIL.Image:Importing PngImagePlugin
DEBUG:PIL.Image:Importing PpmImagePlugin
DEBUG:PIL.Image:Importing PsdImagePlugin
DEBUG:PIL.Image:Importing QoiImagePlugin
DEBUG:PIL.Image:Importing SgiImagePlugin
DEBUG:PIL.Image:Importing SpiderImagePlugin
DEBUG:PIL.Image:Importing SunImagePlugin
DEBUG:PIL.Image:Importing TgaImagePlugin
DEBUG:PIL.Image:Importing TiffImagePlugin
DEBUG:PIL.Image:Importing WebPImagePlugin
DEBUG:PIL.Image:Importing WmfImagePlugin
DEBUG:PIL.Image:Importing XbmImagePlugin
DEBUG:PIL.Image:Importing XpmImagePlugin
DEBUG:PIL.Image:Importing XVThumbImagePlugin
9%|▉ | 1/11 [01:39<16:32, 99.26s/it]
18%|█▊ | 2/11 [03:17<14:49, 98.81s/it]
27%|██▋ | 3/11 [04:56<13:10, 98.76s/it]
36%|███▋ | 4/11 [06:35<11:31, 98.80s/it]
45%|████▌ | 5/11 [08:14<09:53, 98.85s/it]
55%|█████▍ | 6/11 [09:53<08:14, 98.87s/it]
64%|██████▎ | 7/11 [11:31<06:34, 98.68s/it]
73%|███████▎ | 8/11 [13:09<04:55, 98.49s/it]
82%|████████▏ | 9/11 [14:47<03:16, 98.38s/it]
91%|█████████ | 10/11 [16:25<01:38, 98.29s/it]
100%|██████████| 11/11 [18:03<00:00, 98.26s/it]
100%|██████████| 11/11 [18:03<00:00, 98.54s/it]
>>>>>>>>>>>>>>>>>>>>>>>>
>>> Step 1: generating actions ...
>>> Step 1: interacting with world model ...
>>>>>>>>>>>>>>>>>>>>>>>>
>>> Step 2: generating actions ...
>>> Step 2: interacting with world model ...
>>>>>>>>>>>>>>>>>>>>>>>>
>>> Step 3: generating actions ...
>>> Step 3: interacting with world model ...
>>>>>>>>>>>>>>>>>>>>>>>>
>>> Step 4: generating actions ...
>>> Step 4: interacting with world model ...
>>>>>>>>>>>>>>>>>>>>>>>>
>>> Step 5: generating actions ...
>>> Step 5: interacting with world model ...
>>>>>>>>>>>>>>>>>>>>>>>>
>>> Step 6: generating actions ...
>>> Step 6: interacting with world model ...
>>>>>>>>>>>>>>>>>>>>>>>>
>>> Step 7: generating actions ...
>>> Step 7: interacting with world model ...
>>>>>>>>>>>>>>>>>>>>>>>>

View File

@@ -0,0 +1,5 @@
{
"gt_video": "unitree_z1_dual_arm_stackbox_v2/case4/unitree_z1_dual_arm_stackbox_v2_case4.mp4",
"pred_video": "unitree_z1_dual_arm_stackbox_v2/case4/output/inference/unitree_z1_dual_arm_stackbox_v2_case4_amd.mp4",
"psnr": 19.526448380726254
}

View File

@@ -2,7 +2,7 @@ res_dir="unitree_z1_dual_arm_stackbox_v2/case4"
dataset="unitree_z1_dual_arm_stackbox_v2" dataset="unitree_z1_dual_arm_stackbox_v2"
{ {
time CUDA_VISIBLE_DEVICES=0 python3 scripts/evaluation/world_model_interaction.py \ time CUDA_VISIBLE_DEVICES=6 python3 scripts/evaluation/world_model_interaction.py \
--seed 123 \ --seed 123 \
--ckpt_path ckpts/unifolm_wma_dual.ckpt \ --ckpt_path ckpts/unifolm_wma_dual.ckpt \
--config configs/inference/world_model_interaction.yaml \ --config configs/inference/world_model_interaction.yaml \

View File

@@ -0,0 +1,149 @@
/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/lightning_fabric/__init__.py:29: UserWarning: pkg_resources is deprecated as an API. See https://setuptools.pypa.io/en/latest/pkg_resources.html. The pkg_resources package is slated for removal as early as 2025-11-30. Refrain from using this package or pin to Setuptools<81.
__import__("pkg_resources").declare_namespace(__name__)
2026-02-08 08:12:47.424053: I tensorflow/core/util/port.cc:113] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2026-02-08 08:12:47.427280: I external/local_tsl/tsl/cuda/cudart_stub.cc:31] Could not find cuda drivers on your machine, GPU will not be used.
2026-02-08 08:12:47.458253: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2026-02-08 08:12:47.458288: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2026-02-08 08:12:47.462758: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2026-02-08 08:12:47.518283: I external/local_tsl/tsl/cuda/cudart_stub.cc:31] Could not find cuda drivers on your machine, GPU will not be used.
2026-02-08 08:12:47.518566: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 AVX512F AVX512_VNNI AVX512_BF16 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
2026-02-08 08:12:48.593011: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT
Global seed set to 123
/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/kornia/feature/lightglue.py:44: FutureWarning: `torch.cuda.amp.custom_fwd(args...)` is deprecated. Please use `torch.amp.custom_fwd(args..., device_type='cuda')` instead.
@torch.cuda.amp.custom_fwd(cast_inputs=torch.float32)
INFO:mainlogger:LatentVisualDiffusion: Running in v-prediction mode
INFO:unifolm_wma.models.diffusion_head.conditional_unet1d:number of parameters: 5.010531e+08
INFO:unifolm_wma.models.diffusion_head.conditional_unet1d:number of parameters: 5.010531e+08
AE working on z of shape (1, 4, 32, 32) = 4096 dimensions.
INFO:root:Loaded ViT-H-14 model config.
DEBUG:urllib3.connectionpool:Starting new HTTPS connection (1): hf-mirror.com:443
DEBUG:urllib3.connectionpool:https://hf-mirror.com:443 "HEAD /laion/CLIP-ViT-H-14-laion2B-s32B-b79K/resolve/main/open_clip_pytorch_model.bin HTTP/1.1" 302 0
INFO:root:Loading pretrained ViT-H-14 weights (laion2b_s32b_b79k).
/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/open_clip/factory.py:88: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.
checkpoint = torch.load(checkpoint_path, map_location=map_location)
INFO:root:Loaded ViT-H-14 model config.
DEBUG:urllib3.connectionpool:https://hf-mirror.com:443 "HEAD /laion/CLIP-ViT-H-14-laion2B-s32B-b79K/resolve/main/open_clip_pytorch_model.bin HTTP/1.1" 302 0
INFO:root:Loading pretrained ViT-H-14 weights (laion2b_s32b_b79k).
/mnt/ASC1637/unifolm-world-model-action/scripts/evaluation/world_model_interaction.py:86: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.
state_dict = torch.load(ckpt, map_location="cpu")
>>> model checkpoint loaded.
>>> Load pre-trained model ...
INFO:root:***** Configing Data *****
>>> unitree_z1_stackbox: 1 data samples loaded.
>>> unitree_z1_stackbox: data stats loaded.
>>> unitree_z1_stackbox: normalizer initiated.
>>> unitree_z1_dual_arm_stackbox: 1 data samples loaded.
>>> unitree_z1_dual_arm_stackbox: data stats loaded.
>>> unitree_z1_dual_arm_stackbox: normalizer initiated.
>>> unitree_z1_dual_arm_stackbox_v2: 1 data samples loaded.
>>> unitree_z1_dual_arm_stackbox_v2: data stats loaded.
>>> unitree_z1_dual_arm_stackbox_v2: normalizer initiated.
>>> unitree_z1_dual_arm_cleanup_pencils: 1 data samples loaded.
>>> unitree_z1_dual_arm_cleanup_pencils: data stats loaded.
>>> unitree_z1_dual_arm_cleanup_pencils: normalizer initiated.
>>> unitree_g1_pack_camera: 1 data samples loaded.
>>> unitree_g1_pack_camera: data stats loaded.
>>> unitree_g1_pack_camera: normalizer initiated.
>>> Dataset is successfully loaded ...
>>> Generate 16 frames under each generation ...
DEBUG:h5py._conv:Creating converter from 3 to 5
DEBUG:PIL.PngImagePlugin:STREAM b'IHDR' 16 13
DEBUG:PIL.PngImagePlugin:STREAM b'pHYs' 41 9
DEBUG:PIL.PngImagePlugin:STREAM b'IDAT' 62 4096
0%| | 0/12 [00:00<?, ?it/s]/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/torch/nn/functional.py:5501: UserWarning: Attempting to use hipBLASLt on an unsupported architecture! Overriding blas backend to hipblas (Triggered internally at ../aten/src/ATen/Context.cpp:296.)
proj = linear(q, w, b)
/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/torch/nn/functional.py:6278: UserWarning: Flash attention support on Navi31 GPU is still experimental. Enable it with TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL=1. (Triggered internally at ../aten/src/ATen/native/transformers/hip/sdp_utils.cpp:225.)
attn_output = scaled_dot_product_attention(
/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/torch/nn/functional.py:6278: UserWarning: Memory Efficient attention on Navi31 GPU is still experimental. Enable it with TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL=1. (Triggered internally at ../aten/src/ATen/native/transformers/hip/sdp_utils.cpp:269.)
attn_output = scaled_dot_product_attention(
>>> Step 0: generating actions ...
>>> Step 0: interacting with world model ...
DEBUG:PIL.Image:Importing BlpImagePlugin
DEBUG:PIL.Image:Importing BmpImagePlugin
DEBUG:PIL.Image:Importing BufrStubImagePlugin
DEBUG:PIL.Image:Importing CurImagePlugin
DEBUG:PIL.Image:Importing DcxImagePlugin
DEBUG:PIL.Image:Importing DdsImagePlugin
DEBUG:PIL.Image:Importing EpsImagePlugin
DEBUG:PIL.Image:Importing FitsImagePlugin
DEBUG:PIL.Image:Importing FitsStubImagePlugin
DEBUG:PIL.Image:Importing FliImagePlugin
DEBUG:PIL.Image:Importing FpxImagePlugin
DEBUG:PIL.Image:Image: failed to import FpxImagePlugin: No module named 'olefile'
DEBUG:PIL.Image:Importing FtexImagePlugin
DEBUG:PIL.Image:Importing GbrImagePlugin
DEBUG:PIL.Image:Importing GifImagePlugin
DEBUG:PIL.Image:Importing GribStubImagePlugin
DEBUG:PIL.Image:Importing Hdf5StubImagePlugin
DEBUG:PIL.Image:Importing IcnsImagePlugin
DEBUG:PIL.Image:Importing IcoImagePlugin
DEBUG:PIL.Image:Importing ImImagePlugin
DEBUG:PIL.Image:Importing ImtImagePlugin
DEBUG:PIL.Image:Importing IptcImagePlugin
DEBUG:PIL.Image:Importing JpegImagePlugin
DEBUG:PIL.Image:Importing Jpeg2KImagePlugin
DEBUG:PIL.Image:Importing McIdasImagePlugin
DEBUG:PIL.Image:Importing MicImagePlugin
DEBUG:PIL.Image:Image: failed to import MicImagePlugin: No module named 'olefile'
DEBUG:PIL.Image:Importing MpegImagePlugin
DEBUG:PIL.Image:Importing MpoImagePlugin
DEBUG:PIL.Image:Importing MspImagePlugin
DEBUG:PIL.Image:Importing PalmImagePlugin
DEBUG:PIL.Image:Importing PcdImagePlugin
DEBUG:PIL.Image:Importing PcxImagePlugin
DEBUG:PIL.Image:Importing PdfImagePlugin
DEBUG:PIL.Image:Importing PixarImagePlugin
DEBUG:PIL.Image:Importing PngImagePlugin
DEBUG:PIL.Image:Importing PpmImagePlugin
DEBUG:PIL.Image:Importing PsdImagePlugin
DEBUG:PIL.Image:Importing QoiImagePlugin
DEBUG:PIL.Image:Importing SgiImagePlugin
DEBUG:PIL.Image:Importing SpiderImagePlugin
DEBUG:PIL.Image:Importing SunImagePlugin
DEBUG:PIL.Image:Importing TgaImagePlugin
DEBUG:PIL.Image:Importing TiffImagePlugin
DEBUG:PIL.Image:Importing WebPImagePlugin
DEBUG:PIL.Image:Importing WmfImagePlugin
DEBUG:PIL.Image:Importing XbmImagePlugin
DEBUG:PIL.Image:Importing XpmImagePlugin
DEBUG:PIL.Image:Importing XVThumbImagePlugin
8%|▊ | 1/12 [01:38<18:08, 98.94s/it]
17%|█▋ | 2/12 [03:18<16:30, 99.01s/it]
25%|██▌ | 3/12 [04:57<14:51, 99.07s/it]
33%|███▎ | 4/12 [06:36<13:12, 99.04s/it]
42%|████▏ | 5/12 [08:15<11:33, 99.00s/it]
50%|█████ | 6/12 [09:54<09:54, 99.10s/it]
58%|█████▊ | 7/12 [11:33<08:14, 99.00s/it]
67%|██████▋ | 8/12 [13:13<06:38, 99.58s/it]
75%|███████▌ | 9/12 [14:54<04:59, 99.88s/it]
83%|████████▎ | 10/12 [16:33<03:19, 99.58s/it]
92%|█████████▏| 11/12 [18:12<01:39, 99.39s/it]
100%|██████████| 12/12 [19:51<00:00, 99.25s/it]
100%|██████████| 12/12 [19:51<00:00, 99.28s/it]
>>>>>>>>>>>>>>>>>>>>>>>>
>>> Step 1: generating actions ...
>>> Step 1: interacting with world model ...
>>>>>>>>>>>>>>>>>>>>>>>>
>>> Step 2: generating actions ...
>>> Step 2: interacting with world model ...
>>>>>>>>>>>>>>>>>>>>>>>>
>>> Step 3: generating actions ...
>>> Step 3: interacting with world model ...
>>>>>>>>>>>>>>>>>>>>>>>>
>>> Step 4: generating actions ...
>>> Step 4: interacting with world model ...
>>>>>>>>>>>>>>>>>>>>>>>>
>>> Step 5: generating actions ...
>>> Step 5: interacting with world model ...
>>>>>>>>>>>>>>>>>>>>>>>>
>>> Step 6: generating actions ...
>>> Step 6: interacting with world model ...
>>>>>>>>>>>>>>>>>>>>>>>>
>>> Step 7: generating actions ...
>>> Step 7: interacting with world model ...
>>>>>>>>>>>>>>>>>>>>>>>>
>>> Step 8: generating actions ...
>>> Step 8: interacting with world model ...

View File

@@ -0,0 +1,5 @@
{
"gt_video": "unitree_z1_stackbox/case1/unitree_z1_stackbox_case1.mp4",
"pred_video": "unitree_z1_stackbox/case1/output/inference/unitree_z1_stackbox_case1_amd.mp4",
"psnr": 19.81391789862606
}

View File

@@ -2,7 +2,7 @@ res_dir="unitree_z1_stackbox/case1"
dataset="unitree_z1_stackbox" dataset="unitree_z1_stackbox"
{ {
time CUDA_VISIBLE_DEVICES=0 python3 scripts/evaluation/world_model_interaction.py \ time CUDA_VISIBLE_DEVICES=5 python3 scripts/evaluation/world_model_interaction.py \
--seed 123 \ --seed 123 \
--ckpt_path ckpts/unifolm_wma_dual.ckpt \ --ckpt_path ckpts/unifolm_wma_dual.ckpt \
--config configs/inference/world_model_interaction.yaml \ --config configs/inference/world_model_interaction.yaml \

View File

@@ -0,0 +1,149 @@
/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/lightning_fabric/__init__.py:29: UserWarning: pkg_resources is deprecated as an API. See https://setuptools.pypa.io/en/latest/pkg_resources.html. The pkg_resources package is slated for removal as early as 2025-11-30. Refrain from using this package or pin to Setuptools<81.
__import__("pkg_resources").declare_namespace(__name__)
2026-02-08 08:15:49.934949: I tensorflow/core/util/port.cc:113] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2026-02-08 08:15:49.937974: I external/local_tsl/tsl/cuda/cudart_stub.cc:31] Could not find cuda drivers on your machine, GPU will not be used.
2026-02-08 08:15:49.969069: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2026-02-08 08:15:49.969100: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2026-02-08 08:15:49.970909: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2026-02-08 08:15:49.979005: I external/local_tsl/tsl/cuda/cudart_stub.cc:31] Could not find cuda drivers on your machine, GPU will not be used.
2026-02-08 08:15:49.979255: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 AVX512F AVX512_VNNI AVX512_BF16 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
2026-02-08 08:15:50.597743: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT
[rank: 0] Global seed set to 123
/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/kornia/feature/lightglue.py:44: FutureWarning: `torch.cuda.amp.custom_fwd(args...)` is deprecated. Please use `torch.amp.custom_fwd(args..., device_type='cuda')` instead.
@torch.cuda.amp.custom_fwd(cast_inputs=torch.float32)
INFO:mainlogger:LatentVisualDiffusion: Running in v-prediction mode
INFO:unifolm_wma.models.diffusion_head.conditional_unet1d:number of parameters: 5.010531e+08
INFO:unifolm_wma.models.diffusion_head.conditional_unet1d:number of parameters: 5.010531e+08
AE working on z of shape (1, 4, 32, 32) = 4096 dimensions.
INFO:root:Loaded ViT-H-14 model config.
DEBUG:urllib3.connectionpool:Starting new HTTPS connection (1): hf-mirror.com:443
DEBUG:urllib3.connectionpool:https://hf-mirror.com:443 "HEAD /laion/CLIP-ViT-H-14-laion2B-s32B-b79K/resolve/main/open_clip_pytorch_model.bin HTTP/1.1" 302 0
INFO:root:Loading pretrained ViT-H-14 weights (laion2b_s32b_b79k).
/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/open_clip/factory.py:88: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.
checkpoint = torch.load(checkpoint_path, map_location=map_location)
INFO:root:Loaded ViT-H-14 model config.
DEBUG:urllib3.connectionpool:https://hf-mirror.com:443 "HEAD /laion/CLIP-ViT-H-14-laion2B-s32B-b79K/resolve/main/open_clip_pytorch_model.bin HTTP/1.1" 302 0
INFO:root:Loading pretrained ViT-H-14 weights (laion2b_s32b_b79k).
/mnt/ASC1637/unifolm-world-model-action/scripts/evaluation/world_model_interaction.py:86: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.
state_dict = torch.load(ckpt, map_location="cpu")
>>> model checkpoint loaded.
>>> Load pre-trained model ...
INFO:root:***** Configing Data *****
>>> unitree_z1_stackbox: 1 data samples loaded.
>>> unitree_z1_stackbox: data stats loaded.
>>> unitree_z1_stackbox: normalizer initiated.
>>> unitree_z1_dual_arm_stackbox: 1 data samples loaded.
>>> unitree_z1_dual_arm_stackbox: data stats loaded.
>>> unitree_z1_dual_arm_stackbox: normalizer initiated.
>>> unitree_z1_dual_arm_stackbox_v2: 1 data samples loaded.
>>> unitree_z1_dual_arm_stackbox_v2: data stats loaded.
>>> unitree_z1_dual_arm_stackbox_v2: normalizer initiated.
>>> unitree_z1_dual_arm_cleanup_pencils: 1 data samples loaded.
>>> unitree_z1_dual_arm_cleanup_pencils: data stats loaded.
>>> unitree_z1_dual_arm_cleanup_pencils: normalizer initiated.
>>> unitree_g1_pack_camera: 1 data samples loaded.
>>> unitree_g1_pack_camera: data stats loaded.
>>> unitree_g1_pack_camera: normalizer initiated.
>>> Dataset is successfully loaded ...
>>> Generate 16 frames under each generation ...
DEBUG:h5py._conv:Creating converter from 3 to 5
DEBUG:PIL.PngImagePlugin:STREAM b'IHDR' 16 13
DEBUG:PIL.PngImagePlugin:STREAM b'pHYs' 41 9
DEBUG:PIL.PngImagePlugin:STREAM b'IDAT' 62 4096
0%| | 0/12 [00:00<?, ?it/s]/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/torch/nn/functional.py:5501: UserWarning: Attempting to use hipBLASLt on an unsupported architecture! Overriding blas backend to hipblas (Triggered internally at ../aten/src/ATen/Context.cpp:296.)
proj = linear(q, w, b)
/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/torch/nn/functional.py:6278: UserWarning: Flash attention support on Navi31 GPU is still experimental. Enable it with TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL=1. (Triggered internally at ../aten/src/ATen/native/transformers/hip/sdp_utils.cpp:225.)
attn_output = scaled_dot_product_attention(
/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/torch/nn/functional.py:6278: UserWarning: Memory Efficient attention on Navi31 GPU is still experimental. Enable it with TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL=1. (Triggered internally at ../aten/src/ATen/native/transformers/hip/sdp_utils.cpp:269.)
attn_output = scaled_dot_product_attention(
>>> Step 0: generating actions ...
>>> Step 0: interacting with world model ...
DEBUG:PIL.Image:Importing BlpImagePlugin
DEBUG:PIL.Image:Importing BmpImagePlugin
DEBUG:PIL.Image:Importing BufrStubImagePlugin
DEBUG:PIL.Image:Importing CurImagePlugin
DEBUG:PIL.Image:Importing DcxImagePlugin
DEBUG:PIL.Image:Importing DdsImagePlugin
DEBUG:PIL.Image:Importing EpsImagePlugin
DEBUG:PIL.Image:Importing FitsImagePlugin
DEBUG:PIL.Image:Importing FitsStubImagePlugin
DEBUG:PIL.Image:Importing FliImagePlugin
DEBUG:PIL.Image:Importing FpxImagePlugin
DEBUG:PIL.Image:Image: failed to import FpxImagePlugin: No module named 'olefile'
DEBUG:PIL.Image:Importing FtexImagePlugin
DEBUG:PIL.Image:Importing GbrImagePlugin
DEBUG:PIL.Image:Importing GifImagePlugin
DEBUG:PIL.Image:Importing GribStubImagePlugin
DEBUG:PIL.Image:Importing Hdf5StubImagePlugin
DEBUG:PIL.Image:Importing IcnsImagePlugin
DEBUG:PIL.Image:Importing IcoImagePlugin
DEBUG:PIL.Image:Importing ImImagePlugin
DEBUG:PIL.Image:Importing ImtImagePlugin
DEBUG:PIL.Image:Importing IptcImagePlugin
DEBUG:PIL.Image:Importing JpegImagePlugin
DEBUG:PIL.Image:Importing Jpeg2KImagePlugin
DEBUG:PIL.Image:Importing McIdasImagePlugin
DEBUG:PIL.Image:Importing MicImagePlugin
DEBUG:PIL.Image:Image: failed to import MicImagePlugin: No module named 'olefile'
DEBUG:PIL.Image:Importing MpegImagePlugin
DEBUG:PIL.Image:Importing MpoImagePlugin
DEBUG:PIL.Image:Importing MspImagePlugin
DEBUG:PIL.Image:Importing PalmImagePlugin
DEBUG:PIL.Image:Importing PcdImagePlugin
DEBUG:PIL.Image:Importing PcxImagePlugin
DEBUG:PIL.Image:Importing PdfImagePlugin
DEBUG:PIL.Image:Importing PixarImagePlugin
DEBUG:PIL.Image:Importing PngImagePlugin
DEBUG:PIL.Image:Importing PpmImagePlugin
DEBUG:PIL.Image:Importing PsdImagePlugin
DEBUG:PIL.Image:Importing QoiImagePlugin
DEBUG:PIL.Image:Importing SgiImagePlugin
DEBUG:PIL.Image:Importing SpiderImagePlugin
DEBUG:PIL.Image:Importing SunImagePlugin
DEBUG:PIL.Image:Importing TgaImagePlugin
DEBUG:PIL.Image:Importing TiffImagePlugin
DEBUG:PIL.Image:Importing WebPImagePlugin
DEBUG:PIL.Image:Importing WmfImagePlugin
DEBUG:PIL.Image:Importing XbmImagePlugin
DEBUG:PIL.Image:Importing XpmImagePlugin
DEBUG:PIL.Image:Importing XVThumbImagePlugin
8%|▊ | 1/12 [01:37<17:51, 97.37s/it]
17%|█▋ | 2/12 [03:14<16:13, 97.31s/it]
25%|██▌ | 3/12 [04:51<14:35, 97.26s/it]
33%|███▎ | 4/12 [06:29<12:58, 97.25s/it]
42%|████▏ | 5/12 [08:06<11:20, 97.24s/it]
50%|█████ | 6/12 [09:43<09:43, 97.24s/it]
58%|█████▊ | 7/12 [11:20<08:06, 97.27s/it]
67%|██████▋ | 8/12 [12:58<06:29, 97.36s/it]
75%|███████▌ | 9/12 [14:36<04:52, 97.49s/it]
83%|████████▎ | 10/12 [16:13<03:15, 97.52s/it]
92%|█████████▏| 11/12 [17:51<01:37, 97.47s/it]
100%|██████████| 12/12 [19:28<00:00, 97.35s/it]
100%|██████████| 12/12 [19:28<00:00, 97.35s/it]
>>>>>>>>>>>>>>>>>>>>>>>>
>>> Step 1: generating actions ...
>>> Step 1: interacting with world model ...
>>>>>>>>>>>>>>>>>>>>>>>>
>>> Step 2: generating actions ...
>>> Step 2: interacting with world model ...
>>>>>>>>>>>>>>>>>>>>>>>>
>>> Step 3: generating actions ...
>>> Step 3: interacting with world model ...
>>>>>>>>>>>>>>>>>>>>>>>>
>>> Step 4: generating actions ...
>>> Step 4: interacting with world model ...
>>>>>>>>>>>>>>>>>>>>>>>>
>>> Step 5: generating actions ...
>>> Step 5: interacting with world model ...
>>>>>>>>>>>>>>>>>>>>>>>>
>>> Step 6: generating actions ...
>>> Step 6: interacting with world model ...
>>>>>>>>>>>>>>>>>>>>>>>>
>>> Step 7: generating actions ...
>>> Step 7: interacting with world model ...
>>>>>>>>>>>>>>>>>>>>>>>>
>>> Step 8: generating actions ...
>>> Step 8: interacting with world model ...

View File

@@ -0,0 +1,5 @@
{
"gt_video": "unitree_z1_stackbox/case2/unitree_z1_stackbox_case2.mp4",
"pred_video": "unitree_z1_stackbox/case2/output/inference/unitree_z1_stackbox_case2_amd.mp4",
"psnr": 21.083821459054743
}

View File

@@ -0,0 +1,149 @@
/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/lightning_fabric/__init__.py:29: UserWarning: pkg_resources is deprecated as an API. See https://setuptools.pypa.io/en/latest/pkg_resources.html. The pkg_resources package is slated for removal as early as 2025-11-30. Refrain from using this package or pin to Setuptools<81.
__import__("pkg_resources").declare_namespace(__name__)
2026-02-08 08:16:22.299521: I tensorflow/core/util/port.cc:113] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2026-02-08 08:16:22.302545: I external/local_tsl/tsl/cuda/cudart_stub.cc:31] Could not find cuda drivers on your machine, GPU will not be used.
2026-02-08 08:16:22.335354: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2026-02-08 08:16:22.335389: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2026-02-08 08:16:22.337179: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2026-02-08 08:16:22.345296: I external/local_tsl/tsl/cuda/cudart_stub.cc:31] Could not find cuda drivers on your machine, GPU will not be used.
2026-02-08 08:16:22.345548: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 AVX512F AVX512_VNNI AVX512_BF16 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
2026-02-08 08:16:23.008743: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT
[rank: 0] Global seed set to 123
/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/kornia/feature/lightglue.py:44: FutureWarning: `torch.cuda.amp.custom_fwd(args...)` is deprecated. Please use `torch.amp.custom_fwd(args..., device_type='cuda')` instead.
@torch.cuda.amp.custom_fwd(cast_inputs=torch.float32)
INFO:mainlogger:LatentVisualDiffusion: Running in v-prediction mode
INFO:unifolm_wma.models.diffusion_head.conditional_unet1d:number of parameters: 5.010531e+08
INFO:unifolm_wma.models.diffusion_head.conditional_unet1d:number of parameters: 5.010531e+08
AE working on z of shape (1, 4, 32, 32) = 4096 dimensions.
INFO:root:Loaded ViT-H-14 model config.
DEBUG:urllib3.connectionpool:Starting new HTTPS connection (1): hf-mirror.com:443
DEBUG:urllib3.connectionpool:https://hf-mirror.com:443 "HEAD /laion/CLIP-ViT-H-14-laion2B-s32B-b79K/resolve/main/open_clip_pytorch_model.bin HTTP/1.1" 302 0
INFO:root:Loading pretrained ViT-H-14 weights (laion2b_s32b_b79k).
/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/open_clip/factory.py:88: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.
checkpoint = torch.load(checkpoint_path, map_location=map_location)
INFO:root:Loaded ViT-H-14 model config.
DEBUG:urllib3.connectionpool:https://hf-mirror.com:443 "HEAD /laion/CLIP-ViT-H-14-laion2B-s32B-b79K/resolve/main/open_clip_pytorch_model.bin HTTP/1.1" 302 0
INFO:root:Loading pretrained ViT-H-14 weights (laion2b_s32b_b79k).
/mnt/ASC1637/unifolm-world-model-action/scripts/evaluation/world_model_interaction.py:86: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.
state_dict = torch.load(ckpt, map_location="cpu")
>>> model checkpoint loaded.
>>> Load pre-trained model ...
INFO:root:***** Configing Data *****
>>> unitree_z1_stackbox: 1 data samples loaded.
>>> unitree_z1_stackbox: data stats loaded.
>>> unitree_z1_stackbox: normalizer initiated.
>>> unitree_z1_dual_arm_stackbox: 1 data samples loaded.
>>> unitree_z1_dual_arm_stackbox: data stats loaded.
>>> unitree_z1_dual_arm_stackbox: normalizer initiated.
>>> unitree_z1_dual_arm_stackbox_v2: 1 data samples loaded.
>>> unitree_z1_dual_arm_stackbox_v2: data stats loaded.
>>> unitree_z1_dual_arm_stackbox_v2: normalizer initiated.
>>> unitree_z1_dual_arm_cleanup_pencils: 1 data samples loaded.
>>> unitree_z1_dual_arm_cleanup_pencils: data stats loaded.
>>> unitree_z1_dual_arm_cleanup_pencils: normalizer initiated.
>>> unitree_g1_pack_camera: 1 data samples loaded.
>>> unitree_g1_pack_camera: data stats loaded.
>>> unitree_g1_pack_camera: normalizer initiated.
>>> Dataset is successfully loaded ...
>>> Generate 16 frames under each generation ...
DEBUG:h5py._conv:Creating converter from 3 to 5
DEBUG:PIL.PngImagePlugin:STREAM b'IHDR' 16 13
DEBUG:PIL.PngImagePlugin:STREAM b'pHYs' 41 9
DEBUG:PIL.PngImagePlugin:STREAM b'IDAT' 62 4096
0%| | 0/12 [00:00<?, ?it/s]/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/torch/nn/functional.py:5501: UserWarning: Attempting to use hipBLASLt on an unsupported architecture! Overriding blas backend to hipblas (Triggered internally at ../aten/src/ATen/Context.cpp:296.)
proj = linear(q, w, b)
/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/torch/nn/functional.py:6278: UserWarning: Flash attention support on Navi31 GPU is still experimental. Enable it with TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL=1. (Triggered internally at ../aten/src/ATen/native/transformers/hip/sdp_utils.cpp:225.)
attn_output = scaled_dot_product_attention(
/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/torch/nn/functional.py:6278: UserWarning: Memory Efficient attention on Navi31 GPU is still experimental. Enable it with TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL=1. (Triggered internally at ../aten/src/ATen/native/transformers/hip/sdp_utils.cpp:269.)
attn_output = scaled_dot_product_attention(
>>> Step 0: generating actions ...
>>> Step 0: interacting with world model ...
DEBUG:PIL.Image:Importing BlpImagePlugin
DEBUG:PIL.Image:Importing BmpImagePlugin
DEBUG:PIL.Image:Importing BufrStubImagePlugin
DEBUG:PIL.Image:Importing CurImagePlugin
DEBUG:PIL.Image:Importing DcxImagePlugin
DEBUG:PIL.Image:Importing DdsImagePlugin
DEBUG:PIL.Image:Importing EpsImagePlugin
DEBUG:PIL.Image:Importing FitsImagePlugin
DEBUG:PIL.Image:Importing FitsStubImagePlugin
DEBUG:PIL.Image:Importing FliImagePlugin
DEBUG:PIL.Image:Importing FpxImagePlugin
DEBUG:PIL.Image:Image: failed to import FpxImagePlugin: No module named 'olefile'
DEBUG:PIL.Image:Importing FtexImagePlugin
DEBUG:PIL.Image:Importing GbrImagePlugin
DEBUG:PIL.Image:Importing GifImagePlugin
DEBUG:PIL.Image:Importing GribStubImagePlugin
DEBUG:PIL.Image:Importing Hdf5StubImagePlugin
DEBUG:PIL.Image:Importing IcnsImagePlugin
DEBUG:PIL.Image:Importing IcoImagePlugin
DEBUG:PIL.Image:Importing ImImagePlugin
DEBUG:PIL.Image:Importing ImtImagePlugin
DEBUG:PIL.Image:Importing IptcImagePlugin
DEBUG:PIL.Image:Importing JpegImagePlugin
DEBUG:PIL.Image:Importing Jpeg2KImagePlugin
DEBUG:PIL.Image:Importing McIdasImagePlugin
DEBUG:PIL.Image:Importing MicImagePlugin
DEBUG:PIL.Image:Image: failed to import MicImagePlugin: No module named 'olefile'
DEBUG:PIL.Image:Importing MpegImagePlugin
DEBUG:PIL.Image:Importing MpoImagePlugin
DEBUG:PIL.Image:Importing MspImagePlugin
DEBUG:PIL.Image:Importing PalmImagePlugin
DEBUG:PIL.Image:Importing PcdImagePlugin
DEBUG:PIL.Image:Importing PcxImagePlugin
DEBUG:PIL.Image:Importing PdfImagePlugin
DEBUG:PIL.Image:Importing PixarImagePlugin
DEBUG:PIL.Image:Importing PngImagePlugin
DEBUG:PIL.Image:Importing PpmImagePlugin
DEBUG:PIL.Image:Importing PsdImagePlugin
DEBUG:PIL.Image:Importing QoiImagePlugin
DEBUG:PIL.Image:Importing SgiImagePlugin
DEBUG:PIL.Image:Importing SpiderImagePlugin
DEBUG:PIL.Image:Importing SunImagePlugin
DEBUG:PIL.Image:Importing TgaImagePlugin
DEBUG:PIL.Image:Importing TiffImagePlugin
DEBUG:PIL.Image:Importing WebPImagePlugin
DEBUG:PIL.Image:Importing WmfImagePlugin
DEBUG:PIL.Image:Importing XbmImagePlugin
DEBUG:PIL.Image:Importing XpmImagePlugin
DEBUG:PIL.Image:Importing XVThumbImagePlugin
8%|▊ | 1/12 [01:39<18:16, 99.64s/it]
17%|█▋ | 2/12 [03:19<16:35, 99.56s/it]
25%|██▌ | 3/12 [04:58<14:55, 99.53s/it]
33%|███▎ | 4/12 [06:38<13:16, 99.53s/it]
42%|████▏ | 5/12 [08:17<11:36, 99.54s/it]
50%|█████ | 6/12 [09:57<09:57, 99.57s/it]
58%|█████▊ | 7/12 [11:37<08:18, 99.66s/it]
67%|██████▋ | 8/12 [13:17<06:39, 99.83s/it]
75%|███████▌ | 9/12 [14:57<04:59, 99.93s/it]
83%|████████▎ | 10/12 [16:37<03:19, 99.97s/it]
92%|█████████▏| 11/12 [18:17<01:39, 99.85s/it]
100%|██████████| 12/12 [19:56<00:00, 99.71s/it]
100%|██████████| 12/12 [19:56<00:00, 99.71s/it]
>>>>>>>>>>>>>>>>>>>>>>>>
>>> Step 1: generating actions ...
>>> Step 1: interacting with world model ...
>>>>>>>>>>>>>>>>>>>>>>>>
>>> Step 2: generating actions ...
>>> Step 2: interacting with world model ...
>>>>>>>>>>>>>>>>>>>>>>>>
>>> Step 3: generating actions ...
>>> Step 3: interacting with world model ...
>>>>>>>>>>>>>>>>>>>>>>>>
>>> Step 4: generating actions ...
>>> Step 4: interacting with world model ...
>>>>>>>>>>>>>>>>>>>>>>>>
>>> Step 5: generating actions ...
>>> Step 5: interacting with world model ...
>>>>>>>>>>>>>>>>>>>>>>>>
>>> Step 6: generating actions ...
>>> Step 6: interacting with world model ...
>>>>>>>>>>>>>>>>>>>>>>>>
>>> Step 7: generating actions ...
>>> Step 7: interacting with world model ...
>>>>>>>>>>>>>>>>>>>>>>>>
>>> Step 8: generating actions ...
>>> Step 8: interacting with world model ...

View File

@@ -0,0 +1,5 @@
{
"gt_video": "unitree_z1_stackbox/case3/unitree_z1_stackbox_case3.mp4",
"pred_video": "unitree_z1_stackbox/case3/output/inference/unitree_z1_stackbox_case3_amd.mp4",
"psnr": 21.322784880212172
}

View File

@@ -0,0 +1,149 @@
/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/lightning_fabric/__init__.py:29: UserWarning: pkg_resources is deprecated as an API. See https://setuptools.pypa.io/en/latest/pkg_resources.html. The pkg_resources package is slated for removal as early as 2025-11-30. Refrain from using this package or pin to Setuptools<81.
__import__("pkg_resources").declare_namespace(__name__)
2026-02-08 08:25:54.657305: I tensorflow/core/util/port.cc:113] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2026-02-08 08:25:54.660628: I external/local_tsl/tsl/cuda/cudart_stub.cc:31] Could not find cuda drivers on your machine, GPU will not be used.
2026-02-08 08:25:54.691237: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2026-02-08 08:25:54.691275: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2026-02-08 08:25:54.693046: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2026-02-08 08:25:54.701142: I external/local_tsl/tsl/cuda/cudart_stub.cc:31] Could not find cuda drivers on your machine, GPU will not be used.
2026-02-08 08:25:54.701413: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 AVX512F AVX512_VNNI AVX512_BF16 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
2026-02-08 08:25:55.801367: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT
Global seed set to 123
/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/kornia/feature/lightglue.py:44: FutureWarning: `torch.cuda.amp.custom_fwd(args...)` is deprecated. Please use `torch.amp.custom_fwd(args..., device_type='cuda')` instead.
@torch.cuda.amp.custom_fwd(cast_inputs=torch.float32)
INFO:mainlogger:LatentVisualDiffusion: Running in v-prediction mode
INFO:unifolm_wma.models.diffusion_head.conditional_unet1d:number of parameters: 5.010531e+08
INFO:unifolm_wma.models.diffusion_head.conditional_unet1d:number of parameters: 5.010531e+08
AE working on z of shape (1, 4, 32, 32) = 4096 dimensions.
INFO:root:Loaded ViT-H-14 model config.
DEBUG:urllib3.connectionpool:Starting new HTTPS connection (1): hf-mirror.com:443
DEBUG:urllib3.connectionpool:https://hf-mirror.com:443 "HEAD /laion/CLIP-ViT-H-14-laion2B-s32B-b79K/resolve/main/open_clip_pytorch_model.bin HTTP/1.1" 302 0
INFO:root:Loading pretrained ViT-H-14 weights (laion2b_s32b_b79k).
/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/open_clip/factory.py:88: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.
checkpoint = torch.load(checkpoint_path, map_location=map_location)
INFO:root:Loaded ViT-H-14 model config.
DEBUG:urllib3.connectionpool:https://hf-mirror.com:443 "HEAD /laion/CLIP-ViT-H-14-laion2B-s32B-b79K/resolve/main/open_clip_pytorch_model.bin HTTP/1.1" 302 0
INFO:root:Loading pretrained ViT-H-14 weights (laion2b_s32b_b79k).
/mnt/ASC1637/unifolm-world-model-action/scripts/evaluation/world_model_interaction.py:86: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.
state_dict = torch.load(ckpt, map_location="cpu")
>>> model checkpoint loaded.
>>> Load pre-trained model ...
INFO:root:***** Configing Data *****
>>> unitree_z1_stackbox: 1 data samples loaded.
>>> unitree_z1_stackbox: data stats loaded.
>>> unitree_z1_stackbox: normalizer initiated.
>>> unitree_z1_dual_arm_stackbox: 1 data samples loaded.
>>> unitree_z1_dual_arm_stackbox: data stats loaded.
>>> unitree_z1_dual_arm_stackbox: normalizer initiated.
>>> unitree_z1_dual_arm_stackbox_v2: 1 data samples loaded.
>>> unitree_z1_dual_arm_stackbox_v2: data stats loaded.
>>> unitree_z1_dual_arm_stackbox_v2: normalizer initiated.
>>> unitree_z1_dual_arm_cleanup_pencils: 1 data samples loaded.
>>> unitree_z1_dual_arm_cleanup_pencils: data stats loaded.
>>> unitree_z1_dual_arm_cleanup_pencils: normalizer initiated.
>>> unitree_g1_pack_camera: 1 data samples loaded.
>>> unitree_g1_pack_camera: data stats loaded.
>>> unitree_g1_pack_camera: normalizer initiated.
>>> Dataset is successfully loaded ...
>>> Generate 16 frames under each generation ...
DEBUG:h5py._conv:Creating converter from 3 to 5
DEBUG:PIL.PngImagePlugin:STREAM b'IHDR' 16 13
DEBUG:PIL.PngImagePlugin:STREAM b'pHYs' 41 9
DEBUG:PIL.PngImagePlugin:STREAM b'IDAT' 62 4096
0%| | 0/12 [00:00<?, ?it/s]/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/torch/nn/functional.py:5501: UserWarning: Attempting to use hipBLASLt on an unsupported architecture! Overriding blas backend to hipblas (Triggered internally at ../aten/src/ATen/Context.cpp:296.)
proj = linear(q, w, b)
/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/torch/nn/functional.py:6278: UserWarning: Flash attention support on Navi31 GPU is still experimental. Enable it with TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL=1. (Triggered internally at ../aten/src/ATen/native/transformers/hip/sdp_utils.cpp:225.)
attn_output = scaled_dot_product_attention(
/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/torch/nn/functional.py:6278: UserWarning: Memory Efficient attention on Navi31 GPU is still experimental. Enable it with TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL=1. (Triggered internally at ../aten/src/ATen/native/transformers/hip/sdp_utils.cpp:269.)
attn_output = scaled_dot_product_attention(
>>> Step 0: generating actions ...
>>> Step 0: interacting with world model ...
DEBUG:PIL.Image:Importing BlpImagePlugin
DEBUG:PIL.Image:Importing BmpImagePlugin
DEBUG:PIL.Image:Importing BufrStubImagePlugin
DEBUG:PIL.Image:Importing CurImagePlugin
DEBUG:PIL.Image:Importing DcxImagePlugin
DEBUG:PIL.Image:Importing DdsImagePlugin
DEBUG:PIL.Image:Importing EpsImagePlugin
DEBUG:PIL.Image:Importing FitsImagePlugin
DEBUG:PIL.Image:Importing FitsStubImagePlugin
DEBUG:PIL.Image:Importing FliImagePlugin
DEBUG:PIL.Image:Importing FpxImagePlugin
DEBUG:PIL.Image:Image: failed to import FpxImagePlugin: No module named 'olefile'
DEBUG:PIL.Image:Importing FtexImagePlugin
DEBUG:PIL.Image:Importing GbrImagePlugin
DEBUG:PIL.Image:Importing GifImagePlugin
DEBUG:PIL.Image:Importing GribStubImagePlugin
DEBUG:PIL.Image:Importing Hdf5StubImagePlugin
DEBUG:PIL.Image:Importing IcnsImagePlugin
DEBUG:PIL.Image:Importing IcoImagePlugin
DEBUG:PIL.Image:Importing ImImagePlugin
DEBUG:PIL.Image:Importing ImtImagePlugin
DEBUG:PIL.Image:Importing IptcImagePlugin
DEBUG:PIL.Image:Importing JpegImagePlugin
DEBUG:PIL.Image:Importing Jpeg2KImagePlugin
DEBUG:PIL.Image:Importing McIdasImagePlugin
DEBUG:PIL.Image:Importing MicImagePlugin
DEBUG:PIL.Image:Image: failed to import MicImagePlugin: No module named 'olefile'
DEBUG:PIL.Image:Importing MpegImagePlugin
DEBUG:PIL.Image:Importing MpoImagePlugin
DEBUG:PIL.Image:Importing MspImagePlugin
DEBUG:PIL.Image:Importing PalmImagePlugin
DEBUG:PIL.Image:Importing PcdImagePlugin
DEBUG:PIL.Image:Importing PcxImagePlugin
DEBUG:PIL.Image:Importing PdfImagePlugin
DEBUG:PIL.Image:Importing PixarImagePlugin
DEBUG:PIL.Image:Importing PngImagePlugin
DEBUG:PIL.Image:Importing PpmImagePlugin
DEBUG:PIL.Image:Importing PsdImagePlugin
DEBUG:PIL.Image:Importing QoiImagePlugin
DEBUG:PIL.Image:Importing SgiImagePlugin
DEBUG:PIL.Image:Importing SpiderImagePlugin
DEBUG:PIL.Image:Importing SunImagePlugin
DEBUG:PIL.Image:Importing TgaImagePlugin
DEBUG:PIL.Image:Importing TiffImagePlugin
DEBUG:PIL.Image:Importing WebPImagePlugin
DEBUG:PIL.Image:Importing WmfImagePlugin
DEBUG:PIL.Image:Importing XbmImagePlugin
DEBUG:PIL.Image:Importing XpmImagePlugin
DEBUG:PIL.Image:Importing XVThumbImagePlugin
8%|▊ | 1/12 [01:37<17:51, 97.38s/it]
17%|█▋ | 2/12 [03:14<16:12, 97.24s/it]
25%|██▌ | 3/12 [04:51<14:35, 97.28s/it]
33%|███▎ | 4/12 [06:29<12:59, 97.40s/it]
42%|████▏ | 5/12 [08:06<11:21, 97.30s/it]
50%|█████ | 6/12 [09:43<09:43, 97.17s/it]
58%|█████▊ | 7/12 [11:20<08:05, 97.07s/it]
67%|██████▋ | 8/12 [12:57<06:28, 97.02s/it]
75%|███████▌ | 9/12 [14:34<04:50, 96.98s/it]
83%|████████▎ | 10/12 [16:11<03:14, 97.00s/it]
92%|█████████▏| 11/12 [17:48<01:37, 97.06s/it]
100%|██████████| 12/12 [19:25<00:00, 97.13s/it]
100%|██████████| 12/12 [19:25<00:00, 97.14s/it]
>>>>>>>>>>>>>>>>>>>>>>>>
>>> Step 1: generating actions ...
>>> Step 1: interacting with world model ...
>>>>>>>>>>>>>>>>>>>>>>>>
>>> Step 2: generating actions ...
>>> Step 2: interacting with world model ...
>>>>>>>>>>>>>>>>>>>>>>>>
>>> Step 3: generating actions ...
>>> Step 3: interacting with world model ...
>>>>>>>>>>>>>>>>>>>>>>>>
>>> Step 4: generating actions ...
>>> Step 4: interacting with world model ...
>>>>>>>>>>>>>>>>>>>>>>>>
>>> Step 5: generating actions ...
>>> Step 5: interacting with world model ...
>>>>>>>>>>>>>>>>>>>>>>>>
>>> Step 6: generating actions ...
>>> Step 6: interacting with world model ...
>>>>>>>>>>>>>>>>>>>>>>>>
>>> Step 7: generating actions ...
>>> Step 7: interacting with world model ...
>>>>>>>>>>>>>>>>>>>>>>>>
>>> Step 8: generating actions ...
>>> Step 8: interacting with world model ...

View File

@@ -0,0 +1,5 @@
{
"gt_video": "unitree_z1_stackbox/case4/unitree_z1_stackbox_case4.mp4",
"pred_video": "unitree_z1_stackbox/case4/output/inference/unitree_z1_stackbox_case4_amd.mp4",
"psnr": 25.32928948331741
}

View File

@@ -2,7 +2,7 @@ res_dir="unitree_z1_stackbox/case4"
dataset="unitree_z1_stackbox" dataset="unitree_z1_stackbox"
{ {
time CUDA_VISIBLE_DEVICES=0 python3 scripts/evaluation/world_model_interaction.py \ time CUDA_VISIBLE_DEVICES=7 python3 scripts/evaluation/world_model_interaction.py \
--seed 123 \ --seed 123 \
--ckpt_path ckpts/unifolm_wma_dual.ckpt \ --ckpt_path ckpts/unifolm_wma_dual.ckpt \
--config configs/inference/world_model_interaction.yaml \ --config configs/inference/world_model_interaction.yaml \