Compare commits
16 Commits
68d695d81d
...
second
| Author | SHA1 | Date | |
|---|---|---|---|
| 57ba85d147 | |||
| 2cef3e9e45 | |||
| a09d35ae5b | |||
| db848bca01 | |||
| b0ebb7006e | |||
| 125b85ce68 | |||
| 0b3b0e534a | |||
| 6dca3696d8 | |||
| f192c8aca9 | |||
| 4288c9d8c9 | |||
| a2cd34dd51 | |||
| 7338cc384a | |||
| f86ab51a04 | |||
| 75c798ded0 | |||
| e588182642 | |||
| e6c55a648c |
@@ -1,15 +0,0 @@
|
|||||||
{
|
|
||||||
"permissions": {
|
|
||||||
"allow": [
|
|
||||||
"Bash(conda env list:*)",
|
|
||||||
"Bash(mamba env:*)",
|
|
||||||
"Bash(micromamba env list:*)",
|
|
||||||
"Bash(echo:*)",
|
|
||||||
"Bash(git show:*)",
|
|
||||||
"Bash(nvidia-smi:*)",
|
|
||||||
"Bash(conda activate unifolm-wma)",
|
|
||||||
"Bash(conda info:*)",
|
|
||||||
"Bash(direnv allow:*)"
|
|
||||||
]
|
|
||||||
}
|
|
||||||
}
|
|
||||||
2
.envrc
2
.envrc
@@ -1,2 +0,0 @@
|
|||||||
eval "$(conda shell.bash hook 2>/dev/null)"
|
|
||||||
conda activate unifolm-wma
|
|
||||||
7
.gitignore
vendored
7
.gitignore
vendored
@@ -55,6 +55,7 @@ coverage.xml
|
|||||||
*.pot
|
*.pot
|
||||||
|
|
||||||
# Django stuff:
|
# Django stuff:
|
||||||
|
|
||||||
local_settings.py
|
local_settings.py
|
||||||
db.sqlite3
|
db.sqlite3
|
||||||
|
|
||||||
@@ -120,7 +121,6 @@ localTest/
|
|||||||
fig/
|
fig/
|
||||||
figure/
|
figure/
|
||||||
*.mp4
|
*.mp4
|
||||||
|
|
||||||
Data/ControlVAE.yml
|
Data/ControlVAE.yml
|
||||||
Data/Misc
|
Data/Misc
|
||||||
Data/Pretrained
|
Data/Pretrained
|
||||||
@@ -129,8 +129,5 @@ Experiment/checkpoint
|
|||||||
Experiment/log
|
Experiment/log
|
||||||
|
|
||||||
*.ckpt
|
*.ckpt
|
||||||
|
|
||||||
*.0
|
*.0
|
||||||
ckpts/unifolm_wma_dual.ckpt.prepared.pt
|
unitree_z1_dual_arm_cleanup_pencils/case1/profile_output/traces/wx-ms-w7900d-0032_742306.1770698186047591119.pt.trace.json
|
||||||
trt_engines/video_backbone.engine
|
|
||||||
trt_engines/video_backbone.onnx
|
|
||||||
|
|||||||
135
case4_run.log
Normal file
135
case4_run.log
Normal file
@@ -0,0 +1,135 @@
|
|||||||
|
nohup: ignoring input
|
||||||
|
/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/lightning_fabric/__init__.py:29: UserWarning: pkg_resources is deprecated as an API. See https://setuptools.pypa.io/en/latest/pkg_resources.html. The pkg_resources package is slated for removal as early as 2025-11-30. Refrain from using this package or pin to Setuptools<81.
|
||||||
|
__import__("pkg_resources").declare_namespace(__name__)
|
||||||
|
2026-02-08 07:38:45.572744: I tensorflow/core/util/port.cc:113] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
|
||||||
|
2026-02-08 07:38:45.576864: I external/local_tsl/tsl/cuda/cudart_stub.cc:31] Could not find cuda drivers on your machine, GPU will not be used.
|
||||||
|
2026-02-08 07:38:45.624825: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
|
||||||
|
2026-02-08 07:38:45.624883: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
|
||||||
|
2026-02-08 07:38:45.627150: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
|
||||||
|
2026-02-08 07:38:45.638316: I external/local_tsl/tsl/cuda/cudart_stub.cc:31] Could not find cuda drivers on your machine, GPU will not be used.
|
||||||
|
2026-02-08 07:38:45.638803: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
|
||||||
|
To enable the following instructions: AVX2 AVX512F AVX512_VNNI AVX512_BF16 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
|
||||||
|
2026-02-08 07:38:46.426363: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT
|
||||||
|
[rank: 0] Global seed set to 123
|
||||||
|
/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/kornia/feature/lightglue.py:44: FutureWarning: `torch.cuda.amp.custom_fwd(args...)` is deprecated. Please use `torch.amp.custom_fwd(args..., device_type='cuda')` instead.
|
||||||
|
@torch.cuda.amp.custom_fwd(cast_inputs=torch.float32)
|
||||||
|
INFO:mainlogger:LatentVisualDiffusion: Running in v-prediction mode
|
||||||
|
INFO:unifolm_wma.models.diffusion_head.conditional_unet1d:number of parameters: 5.010531e+08
|
||||||
|
INFO:unifolm_wma.models.diffusion_head.conditional_unet1d:number of parameters: 5.010531e+08
|
||||||
|
AE working on z of shape (1, 4, 32, 32) = 4096 dimensions.
|
||||||
|
INFO:root:Loaded ViT-H-14 model config.
|
||||||
|
DEBUG:urllib3.connectionpool:Starting new HTTPS connection (1): hf-mirror.com:443
|
||||||
|
DEBUG:urllib3.connectionpool:https://hf-mirror.com:443 "HEAD /laion/CLIP-ViT-H-14-laion2B-s32B-b79K/resolve/main/open_clip_pytorch_model.bin HTTP/1.1" 302 0
|
||||||
|
INFO:root:Loading pretrained ViT-H-14 weights (laion2b_s32b_b79k).
|
||||||
|
/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/open_clip/factory.py:88: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.
|
||||||
|
checkpoint = torch.load(checkpoint_path, map_location=map_location)
|
||||||
|
INFO:root:Loaded ViT-H-14 model config.
|
||||||
|
DEBUG:urllib3.connectionpool:https://hf-mirror.com:443 "HEAD /laion/CLIP-ViT-H-14-laion2B-s32B-b79K/resolve/main/open_clip_pytorch_model.bin HTTP/1.1" 302 0
|
||||||
|
INFO:root:Loading pretrained ViT-H-14 weights (laion2b_s32b_b79k).
|
||||||
|
/mnt/ASC1637/unifolm-world-model-action/scripts/evaluation/world_model_interaction.py:86: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.
|
||||||
|
state_dict = torch.load(ckpt, map_location="cpu")
|
||||||
|
>>> model checkpoint loaded.
|
||||||
|
>>> Load pre-trained model ...
|
||||||
|
INFO:root:***** Configing Data *****
|
||||||
|
>>> unitree_z1_stackbox: 1 data samples loaded.
|
||||||
|
>>> unitree_z1_stackbox: data stats loaded.
|
||||||
|
>>> unitree_z1_stackbox: normalizer initiated.
|
||||||
|
>>> unitree_z1_dual_arm_stackbox: 1 data samples loaded.
|
||||||
|
>>> unitree_z1_dual_arm_stackbox: data stats loaded.
|
||||||
|
>>> unitree_z1_dual_arm_stackbox: normalizer initiated.
|
||||||
|
>>> unitree_z1_dual_arm_stackbox_v2: 1 data samples loaded.
|
||||||
|
>>> unitree_z1_dual_arm_stackbox_v2: data stats loaded.
|
||||||
|
>>> unitree_z1_dual_arm_stackbox_v2: normalizer initiated.
|
||||||
|
>>> unitree_z1_dual_arm_cleanup_pencils: 1 data samples loaded.
|
||||||
|
>>> unitree_z1_dual_arm_cleanup_pencils: data stats loaded.
|
||||||
|
>>> unitree_z1_dual_arm_cleanup_pencils: normalizer initiated.
|
||||||
|
>>> unitree_g1_pack_camera: 1 data samples loaded.
|
||||||
|
>>> unitree_g1_pack_camera: data stats loaded.
|
||||||
|
>>> unitree_g1_pack_camera: normalizer initiated.
|
||||||
|
>>> Dataset is successfully loaded ...
|
||||||
|
>>> Generate 16 frames under each generation ...
|
||||||
|
DEBUG:h5py._conv:Creating converter from 3 to 5
|
||||||
|
DEBUG:PIL.PngImagePlugin:STREAM b'IHDR' 16 13
|
||||||
|
DEBUG:PIL.PngImagePlugin:STREAM b'pHYs' 41 9
|
||||||
|
DEBUG:PIL.PngImagePlugin:STREAM b'IDAT' 62 4096
|
||||||
|
|
||||||
|
0%| | 0/7 [00:00<?, ?it/s]/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/torch/nn/functional.py:5501: UserWarning: Attempting to use hipBLASLt on an unsupported architecture! Overriding blas backend to hipblas (Triggered internally at ../aten/src/ATen/Context.cpp:296.)
|
||||||
|
proj = linear(q, w, b)
|
||||||
|
/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/torch/nn/functional.py:6278: UserWarning: Flash attention support on Navi31 GPU is still experimental. Enable it with TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL=1. (Triggered internally at ../aten/src/ATen/native/transformers/hip/sdp_utils.cpp:225.)
|
||||||
|
attn_output = scaled_dot_product_attention(
|
||||||
|
/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/torch/nn/functional.py:6278: UserWarning: Memory Efficient attention on Navi31 GPU is still experimental. Enable it with TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL=1. (Triggered internally at ../aten/src/ATen/native/transformers/hip/sdp_utils.cpp:269.)
|
||||||
|
attn_output = scaled_dot_product_attention(
|
||||||
|
>>> Step 0: generating actions ...
|
||||||
|
>>> Step 0: interacting with world model ...
|
||||||
|
DEBUG:PIL.Image:Importing BlpImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing BmpImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing BufrStubImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing CurImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing DcxImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing DdsImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing EpsImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing FitsImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing FitsStubImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing FliImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing FpxImagePlugin
|
||||||
|
DEBUG:PIL.Image:Image: failed to import FpxImagePlugin: No module named 'olefile'
|
||||||
|
DEBUG:PIL.Image:Importing FtexImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing GbrImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing GifImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing GribStubImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing Hdf5StubImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing IcnsImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing IcoImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing ImImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing ImtImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing IptcImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing JpegImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing Jpeg2KImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing McIdasImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing MicImagePlugin
|
||||||
|
DEBUG:PIL.Image:Image: failed to import MicImagePlugin: No module named 'olefile'
|
||||||
|
DEBUG:PIL.Image:Importing MpegImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing MpoImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing MspImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing PalmImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing PcdImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing PcxImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing PdfImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing PixarImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing PngImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing PpmImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing PsdImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing QoiImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing SgiImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing SpiderImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing SunImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing TgaImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing TiffImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing WebPImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing WmfImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing XbmImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing XpmImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing XVThumbImagePlugin
|
||||||
|
|
||||||
|
14%|█▍ | 1/7 [01:38<09:52, 98.73s/it]
|
||||||
|
29%|██▊ | 2/7 [03:17<08:14, 98.85s/it]
|
||||||
|
43%|████▎ | 3/7 [04:56<06:35, 98.80s/it]
|
||||||
|
57%|█████▋ | 4/7 [06:35<04:56, 98.94s/it]
|
||||||
|
71%|███████▏ | 5/7 [08:14<03:17, 98.93s/it]
|
||||||
|
86%|████████▌ | 6/7 [09:53<01:38, 98.89s/it]
|
||||||
|
100%|██████████| 7/7 [11:31<00:00, 98.81s/it]
|
||||||
|
100%|██████████| 7/7 [11:31<00:00, 98.85s/it]
|
||||||
|
>>>>>>>>>>>>>>>>>>>>>>>>
|
||||||
|
>>> Step 1: generating actions ...
|
||||||
|
>>> Step 1: interacting with world model ...
|
||||||
|
>>>>>>>>>>>>>>>>>>>>>>>>
|
||||||
|
>>> Step 2: generating actions ...
|
||||||
|
>>> Step 2: interacting with world model ...
|
||||||
|
>>>>>>>>>>>>>>>>>>>>>>>>
|
||||||
|
>>> Step 3: generating actions ...
|
||||||
|
>>> Step 3: interacting with world model ...
|
||||||
|
>>>>>>>>>>>>>>>>>>>>>>>>
|
||||||
|
>>> Step 4: generating actions ...
|
||||||
|
>>> Step 4: interacting with world model ...
|
||||||
|
>>>>>>>>>>>>>>>>>>>>>>>>
|
||||||
|
>>> Step 5: generating actions ...
|
||||||
1
ckpts/configuration.json
Normal file
1
ckpts/configuration.json
Normal file
@@ -0,0 +1 @@
|
|||||||
|
{"framework": "pytorch", "task": "robotics", "allow_remote": true}
|
||||||
@@ -222,7 +222,7 @@ data:
|
|||||||
test:
|
test:
|
||||||
target: unifolm_wma.data.wma_data.WMAData
|
target: unifolm_wma.data.wma_data.WMAData
|
||||||
params:
|
params:
|
||||||
data_dir: '/home/qhy/unifolm-world-model-action/examples/world_model_interaction_prompts'
|
data_dir: '/mnt/ASC1637/unifolm-world-model-action/examples/world_model_interaction_prompts'
|
||||||
video_length: ${model.params.wma_config.params.temporal_length}
|
video_length: ${model.params.wma_config.params.temporal_length}
|
||||||
frame_stride: 2
|
frame_stride: 2
|
||||||
load_raw_resolution: True
|
load_raw_resolution: True
|
||||||
|
|||||||
21
env.sh
Normal file
21
env.sh
Normal file
@@ -0,0 +1,21 @@
|
|||||||
|
# Note: This script should be sourced, not executed
|
||||||
|
# Usage: source env.sh
|
||||||
|
#
|
||||||
|
# If you need render group permissions, run this first:
|
||||||
|
# newgrp render
|
||||||
|
# Then source this script:
|
||||||
|
# source env.sh
|
||||||
|
|
||||||
|
# Initialize conda
|
||||||
|
source /mnt/ASC1637/miniconda3/etc/profile.d/conda.sh
|
||||||
|
|
||||||
|
# Activate conda environment
|
||||||
|
conda activate unifolm-wma-o
|
||||||
|
|
||||||
|
# Set HuggingFace cache directories
|
||||||
|
export HF_HOME=/mnt/ASC1637/hf_home
|
||||||
|
export HUGGINGFACE_HUB_CACHE=/mnt/ASC1637/hf_home/hub
|
||||||
|
|
||||||
|
echo "Environment configured successfully"
|
||||||
|
echo "Conda environment: unifolm-wma-o"
|
||||||
|
echo "HF_HOME: $HF_HOME"
|
||||||
217
profile_unet_flops.md
Normal file
217
profile_unet_flops.md
Normal file
@@ -0,0 +1,217 @@
|
|||||||
|
TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL=1 CUDA_VISIBLE_DEVICES=0 python scripts/evaluation/profile_unet.py --ckpt_path ckpts/unifolm_wma_dual_mix_bf16.ckpt --config configs/inference/world_model_interaction.yaml
|
||||||
|
|
||||||
|
|
||||||
|
==================================================================================================================================
|
||||||
|
FLOPS BY ATen OPERATOR (FlopCounterMode)
|
||||||
|
==================================================================================================================================
|
||||||
|
ATen Op | GFLOPS | % of Total
|
||||||
|
-------------------------------------------------------
|
||||||
|
convolution | 6185.17 | 46.4%
|
||||||
|
addmm | 4411.17 | 33.1%
|
||||||
|
mm | 1798.34 | 13.5%
|
||||||
|
bmm | 949.54 | 7.1%
|
||||||
|
|
||||||
|
==================================================================================================================================
|
||||||
|
FLOPS BY MODULE (FlopCounterMode)
|
||||||
|
==================================================================================================================================
|
||||||
|
Module | GFLOPS | % of Total
|
||||||
|
------------------------------------------------------------------------------------------
|
||||||
|
Global | 13344.23 | 100.0%
|
||||||
|
DiffusionWrapper | 13344.23 | 100.0%
|
||||||
|
DiffusionWrapper.diffusion_model | 13344.23 | 100.0%
|
||||||
|
DiffusionWrapper.diffusion_model.output_blocks.8 | 997.87 | 7.5%
|
||||||
|
DiffusionWrapper.diffusion_model.output_blocks.5 | 992.91 | 7.4%
|
||||||
|
DiffusionWrapper.diffusion_model.output_blocks.9 | 941.81 | 7.1%
|
||||||
|
DiffusionWrapper.diffusion_model.output_blocks.10 | 857.93 | 6.4%
|
||||||
|
DiffusionWrapper.diffusion_model.output_blocks.11 | 857.93 | 6.4%
|
||||||
|
DiffusionWrapper.diffusion_model.output_blocks.6 | 821.71 | 6.2%
|
||||||
|
DiffusionWrapper.diffusion_model.input_blocks.1 | 765.65 | 5.7%
|
||||||
|
DiffusionWrapper.diffusion_model.input_blocks.2 | 765.65 | 5.7%
|
||||||
|
DiffusionWrapper.diffusion_model.output_blocks.7 | 737.82 | 5.5%
|
||||||
|
DiffusionWrapper.diffusion_model.output_blocks.3 | 732.87 | 5.5%
|
||||||
|
DiffusionWrapper.diffusion_model.output_blocks.4 | 732.87 | 5.5%
|
||||||
|
DiffusionWrapper.diffusion_model.input_blocks.5 | 645.55 | 4.8%
|
||||||
|
DiffusionWrapper.diffusion_model.input_blocks.8 | 640.59 | 4.8%
|
||||||
|
DiffusionWrapper.diffusion_model.input_blocks.4 | 611.99 | 4.6%
|
||||||
|
DiffusionWrapper.diffusion_model.input_blocks.7 | 607.04 | 4.5%
|
||||||
|
DiffusionWrapper.diffusion_model.init_attn.0 | 459.02 | 3.4%
|
||||||
|
DiffusionWrapper.diffusion_model.init_attn | 459.02 | 3.4%
|
||||||
|
nWrapper.diffusion_model.init_attn.0.transformer_blocks.0 | 432.18 | 3.2%
|
||||||
|
DiffusionWrapper.diffusion_model.output_blocks.6.0 | 427.85 | 3.2%
|
||||||
|
DiffusionWrapper.diffusion_model.output_blocks.9.0 | 427.83 | 3.2%
|
||||||
|
DiffusionWrapper.diffusion_model.output_blocks.3.0 | 343.99 | 2.6%
|
||||||
|
DiffusionWrapper.diffusion_model.output_blocks.4.0 | 343.99 | 2.6%
|
||||||
|
DiffusionWrapper.diffusion_model.output_blocks.7.0 | 343.96 | 2.6%
|
||||||
|
DiffusionWrapper.diffusion_model.output_blocks.10.0 | 343.95 | 2.6%
|
||||||
|
DiffusionWrapper.diffusion_model.output_blocks.11.0 | 343.95 | 2.6%
|
||||||
|
DiffusionWrapper.diffusion_model.input_blocks.1.1 | 327.75 | 2.5%
|
||||||
|
DiffusionWrapper.diffusion_model.input_blocks.2.1 | 327.75 | 2.5%
|
||||||
|
|
||||||
|
==================================================================================================================================
|
||||||
|
SUMMARY
|
||||||
|
==================================================================================================================================
|
||||||
|
Total CUDA time: 761.4 ms
|
||||||
|
Matmul CUDA time: 404.2 ms (53.1%)
|
||||||
|
Non-matmul CUDA time: 357.1 ms (46.9%)
|
||||||
|
Total FLOPS (FlopCounter): 13344.23 GFLOPS
|
||||||
|
Matmul throughput: 33.01 TFLOPS/s (54.1% of BF16 peak)
|
||||||
|
Overall throughput: 17.53 TFLOPS/s (28.7% of BF16 peak)
|
||||||
|
GPU peak (BF16): 61.0 TFLOPS
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
==================================================================================================================================
|
||||||
|
FLOPS BY ATen OPERATOR (FlopCounterMode)
|
||||||
|
==================================================================================================================================
|
||||||
|
ATen Op | GFLOPS | % of Total
|
||||||
|
-------------------------------------------------------
|
||||||
|
convolution | 6185.17 | 46.4%
|
||||||
|
addmm | 4411.17 | 33.1%
|
||||||
|
mm | 1798.34 | 13.5%
|
||||||
|
bmm | 949.54 | 7.1%
|
||||||
|
|
||||||
|
==================================================================================================================================
|
||||||
|
FLOPS BY MODULE (FlopCounterMode)
|
||||||
|
==================================================================================================================================
|
||||||
|
Module | GFLOPS | % of Total
|
||||||
|
------------------------------------------------------------------------------------------
|
||||||
|
DiffusionWrapper | 13344.23 | 100.0%
|
||||||
|
Global | 13344.23 | 100.0%
|
||||||
|
DiffusionWrapper.diffusion_model | 13344.23 | 100.0%
|
||||||
|
DiffusionWrapper.diffusion_model.output_blocks.8 | 997.87 | 7.5%
|
||||||
|
DiffusionWrapper.diffusion_model.output_blocks.5 | 992.91 | 7.4%
|
||||||
|
DiffusionWrapper.diffusion_model.output_blocks.9 | 941.81 | 7.1%
|
||||||
|
DiffusionWrapper.diffusion_model.output_blocks.10 | 857.93 | 6.4%
|
||||||
|
DiffusionWrapper.diffusion_model.output_blocks.11 | 857.93 | 6.4%
|
||||||
|
DiffusionWrapper.diffusion_model.output_blocks.6 | 821.71 | 6.2%
|
||||||
|
DiffusionWrapper.diffusion_model.input_blocks.1 | 765.65 | 5.7%
|
||||||
|
DiffusionWrapper.diffusion_model.input_blocks.2 | 765.65 | 5.7%
|
||||||
|
DiffusionWrapper.diffusion_model.output_blocks.7 | 737.82 | 5.5%
|
||||||
|
DiffusionWrapper.diffusion_model.output_blocks.3 | 732.87 | 5.5%
|
||||||
|
DiffusionWrapper.diffusion_model.output_blocks.4 | 732.87 | 5.5%
|
||||||
|
DiffusionWrapper.diffusion_model.input_blocks.5 | 645.55 | 4.8%
|
||||||
|
DiffusionWrapper.diffusion_model.input_blocks.8 | 640.59 | 4.8%
|
||||||
|
DiffusionWrapper.diffusion_model.input_blocks.4 | 611.99 | 4.6%
|
||||||
|
DiffusionWrapper.diffusion_model.input_blocks.7 | 607.04 | 4.5%
|
||||||
|
DiffusionWrapper.diffusion_model.init_attn | 459.02 | 3.4%
|
||||||
|
DiffusionWrapper.diffusion_model.init_attn.0 | 459.02 | 3.4%
|
||||||
|
nWrapper.diffusion_model.init_attn.0.transformer_blocks.0 | 432.18 | 3.2%
|
||||||
|
DiffusionWrapper.diffusion_model.output_blocks.6.0 | 427.85 | 3.2%
|
||||||
|
DiffusionWrapper.diffusion_model.output_blocks.9.0 | 427.83 | 3.2%
|
||||||
|
DiffusionWrapper.diffusion_model.output_blocks.3.0 | 343.99 | 2.6%
|
||||||
|
DiffusionWrapper.diffusion_model.output_blocks.4.0 | 343.99 | 2.6%
|
||||||
|
DiffusionWrapper.diffusion_model.output_blocks.7.0 | 343.96 | 2.6%
|
||||||
|
DiffusionWrapper.diffusion_model.output_blocks.10.0 | 343.95 | 2.6%
|
||||||
|
DiffusionWrapper.diffusion_model.output_blocks.11.0 | 343.95 | 2.6%
|
||||||
|
DiffusionWrapper.diffusion_model.input_blocks.1.1 | 327.75 | 2.5%
|
||||||
|
DiffusionWrapper.diffusion_model.input_blocks.2.1 | 327.75 | 2.5%
|
||||||
|
|
||||||
|
==================================================================================================================================
|
||||||
|
SUMMARY
|
||||||
|
==================================================================================================================================
|
||||||
|
Total CUDA time: 707.1 ms
|
||||||
|
Matmul CUDA time: 403.1 ms (57.0%)
|
||||||
|
Non-matmul CUDA time: 304.0 ms (43.0%)
|
||||||
|
Total FLOPS (FlopCounter): 13344.23 GFLOPS
|
||||||
|
Matmul throughput: 33.11 TFLOPS/s (54.3% of BF16 peak)
|
||||||
|
Overall throughput: 18.87 TFLOPS/s (30.9% of BF16 peak)
|
||||||
|
GPU peak (BF16): 61.0 TFLOPS
|
||||||
|
(unifolm-wma) ASC1637@wx-ms-w7900d-0033:/mnt/ASC1637/unifolm-world-model-action$
|
||||||
|
|
||||||
|
========================================================================
|
||||||
|
TABLE 1: STAGE TIMING
|
||||||
|
========================================================================
|
||||||
|
Stage Mean(ms) Std %
|
||||||
|
------------------------------------------------------------------------
|
||||||
|
1_Image_Embedding 29.5 0.16 0.1%
|
||||||
|
2_VAE_Encode 51.3 0.06 0.1%
|
||||||
|
3_Text_Conditioning 14.7 0.18 0.0%
|
||||||
|
4_Projectors 0.2 0.03 0.0%
|
||||||
|
5_DDIM_Loop 33392.5 3.21 97.3%
|
||||||
|
6_VAE_Decode 808.4 1.00 2.4%
|
||||||
|
7_Post_Process 15.8 0.56 0.0%
|
||||||
|
------------------------------------------------------------------------
|
||||||
|
TOTAL 34312.4
|
||||||
|
|
||||||
|
================================================================================
|
||||||
|
TABLE 2: UNET SUB-MODULE BREAKDOWN
|
||||||
|
================================================================================
|
||||||
|
Module Type Total(ms) Count Per-call %
|
||||||
|
--------------------------------------------------------------------------------
|
||||||
|
ResBlock 10256.3 1100 9.32 23.2%
|
||||||
|
SpatialTransformer 9228.2 800 11.54 20.9%
|
||||||
|
CrossAttention 8105.8 3300 2.46 18.3%
|
||||||
|
ConditionalUnet1D 6409.5 100 64.10 14.5%
|
||||||
|
TemporalTransformer 5847.0 850 6.88 13.2%
|
||||||
|
FeedForward 4338.1 1650 2.63 9.8%
|
||||||
|
UNet.out 73.8 50 1.48 0.2%
|
||||||
|
--------------------------------------------------------------------------------
|
||||||
|
TOTAL (hooked) 44258.7
|
||||||
|
|
||||||
|
==========================================================================================
|
||||||
|
TABLE 2b: PER-UNET-BLOCK TIMING (coarse modules, per DDIM loop)
|
||||||
|
==========================================================================================
|
||||||
|
Block Total(ms) % Breakdown
|
||||||
|
------------------------------------------------------------------------------------------
|
||||||
|
input_blocks.1 3376.2 7.6% SpatialTransformer=1101, CrossAttention=990, ResBlock=543, TemporalTransformer=454, FeedForward=288
|
||||||
|
input_blocks.2 3374.0 7.6% SpatialTransformer=1100, CrossAttention=991, ResBlock=540, TemporalTransformer=455, FeedForward=288
|
||||||
|
input_blocks.4 1592.4 3.6% SpatialTransformer=394, ResBlock=374, CrossAttention=303, TemporalTransformer=272, FeedForward=249
|
||||||
|
input_blocks.5 1642.5 3.7% ResBlock=425, SpatialTransformer=397, CrossAttention=303, TemporalTransformer=271, FeedForward=247
|
||||||
|
input_blocks.7 1469.0 3.3% ResBlock=416, SpatialTransformer=324, FeedForward=251, CrossAttention=240, TemporalTransformer=237
|
||||||
|
input_blocks.8 1543.7 3.5% ResBlock=491, SpatialTransformer=325, FeedForward=250, CrossAttention=240, TemporalTransformer=238
|
||||||
|
input_blocks.10 217.5 0.5% ResBlock=218
|
||||||
|
input_blocks.11 216.8 0.5% ResBlock=217
|
||||||
|
middle_block 848.9 1.9% ResBlock=434, SpatialTransformer=151, CrossAttention=134, TemporalTransformer=69, FeedForward=61
|
||||||
|
output_blocks.0 303.2 0.7% ResBlock=303
|
||||||
|
output_blocks.1 303.1 0.7% ResBlock=303
|
||||||
|
output_blocks.2 302.8 0.7% ResBlock=303
|
||||||
|
output_blocks.3 1734.8 3.9% ResBlock=687, SpatialTransformer=322, FeedForward=249, CrossAttention=239, TemporalTransformer=237
|
||||||
|
output_blocks.4 1739.8 3.9% ResBlock=688, SpatialTransformer=323, FeedForward=251, CrossAttention=239, TemporalTransformer=238
|
||||||
|
output_blocks.5 1622.3 3.7% ResBlock=570, SpatialTransformer=324, FeedForward=251, CrossAttention=239, TemporalTransformer=238
|
||||||
|
output_blocks.6 1881.0 4.3% ResBlock=664, SpatialTransformer=393, CrossAttention=301, TemporalTransformer=272, FeedForward=250
|
||||||
|
output_blocks.7 1768.0 4.0% ResBlock=554, SpatialTransformer=393, CrossAttention=301, TemporalTransformer=272, FeedForward=249
|
||||||
|
output_blocks.8 1688.7 3.8% ResBlock=474, SpatialTransformer=393, CrossAttention=301, TemporalTransformer=272, FeedForward=249
|
||||||
|
output_blocks.9 3558.6 8.0% SpatialTransformer=1096, CrossAttention=992, ResBlock=727, TemporalTransformer=454, FeedForward=290
|
||||||
|
output_blocks.10 3492.8 7.9% SpatialTransformer=1096, CrossAttention=992, ResBlock=662, TemporalTransformer=454, FeedForward=289
|
||||||
|
output_blocks.11 3493.3 7.9% SpatialTransformer=1096, CrossAttention=992, ResBlock=662, TemporalTransformer=454, FeedForward=289
|
||||||
|
out 73.8 0.2% UNet.out=74
|
||||||
|
action_unet 3212.0 7.3% ConditionalUnet1D=3212
|
||||||
|
state_unet 3197.6 7.2% ConditionalUnet1D=3198
|
||||||
|
other 1606.2 3.6% TemporalTransformer=960, FeedForward=337, CrossAttention=309
|
||||||
|
------------------------------------------------------------------------------------------
|
||||||
|
TOTAL 44258.7
|
||||||
|
|
||||||
|
======================================================================
|
||||||
|
TABLE 2c: ATTENTION vs FEEDFORWARD (deep hooks)
|
||||||
|
======================================================================
|
||||||
|
Component Total(ms) %
|
||||||
|
----------------------------------------------------------------------
|
||||||
|
CrossAttention 8105.8 65.1%
|
||||||
|
FeedForward 4338.1 34.9%
|
||||||
|
----------------------------------------------------------------------
|
||||||
|
TOTAL (attn+ff) 12443.9
|
||||||
|
|
||||||
|
==================================================
|
||||||
|
TABLE 3: MEMORY SUMMARY
|
||||||
|
==================================================
|
||||||
|
Initial allocated: 11.82 GB
|
||||||
|
Peak allocated: 14.43 GB
|
||||||
|
Delta (pipeline): 2.61 GB
|
||||||
|
|
||||||
|
============================================================
|
||||||
|
TABLE 4: THROUGHPUT
|
||||||
|
============================================================
|
||||||
|
Total pipeline latency: 34312.4 ms
|
||||||
|
DDIM loop latency: 33392.5 ms
|
||||||
|
DDIM steps: 50
|
||||||
|
CFG scale: 1.0 (1x UNet/step)
|
||||||
|
UNet forward calls: 50
|
||||||
|
Per DDIM step: 667.9 ms
|
||||||
|
Per UNet forward: 667.9 ms
|
||||||
|
VAE encode bandwidth: 0.1 GB/s (peak HBM: 864.0 GB/s)
|
||||||
|
VAE decode bandwidth: 0.0 GB/s (peak HBM: 864.0 GB/s)
|
||||||
|
GPU BF16 peak: 61.0 TFLOPS
|
||||||
|
|
||||||
|
Done.
|
||||||
208
record.md
208
record.md
@@ -1,208 +0,0 @@
|
|||||||
好,下面是完整的 decode 优化讲解。
|
|
||||||
|
|
||||||
---
|
|
||||||
背景:原始 decode 流程的问题
|
|
||||||
|
|
||||||
原始代码中,每轮迭代的 world model 交互会调用 image_guided_synthesis_sim_mode,内部对DDIM 采样得到的 latent 做全量 VAE decode:
|
|
||||||
|
|
||||||
# ddpms.py decode_core, perframe_ae=True
|
|
||||||
for index in range(z.shape[0]): # 逐帧循环,16帧就调16次
|
|
||||||
frame_z = 1. / self.scale_factor * z[index:index+1]
|
|
||||||
frame_result = self.first_stage_model.decode(frame_z)
|
|
||||||
results.append(frame_result)
|
|
||||||
|
|
||||||
假设 8 轮迭代、每轮 16 帧,WM 路径 decode 128 次,policy 路径再 decode 128 次,共 256 次 VAE decode。每次都是同步阻塞的。
|
|
||||||
|
|
||||||
decode 完的视频还会被逐轮写成 mp4 + tensorboard,产生大量磁盘 I/O。最后还要把所有轮的 decoded video 在内存中torch.cat
|
|
||||||
拼接,再写一次完整视频。
|
|
||||||
|
|
||||||
---
|
|
||||||
优化1:decode_video 开关——按需跳过 decode
|
|
||||||
|
|
||||||
文件: world_model_interaction.py函数 image_guided_synthesis_sim_mode
|
|
||||||
|
|
||||||
改动: 给函数加decode_video 参数(默认 False),返回值增加 raw samples:
|
|
||||||
|
|
||||||
def image_guided_synthesis_sim_mode(...,
|
|
||||||
decode_video: bool = False, # 新增
|
|
||||||
...) -> tuple[Tensor | None, Tensor, Tensor, Tensor | None]:
|
|
||||||
|
|
||||||
samples = None
|
|
||||||
if ddim_sampler is not None:
|
|
||||||
samples, actions, states, intermedia = ddim_sampler.sample(...)if decode_video: # 条件 decode
|
|
||||||
batch_images = model.decode_first_stage(samples)
|
|
||||||
batch_variants = batch_images
|
|
||||||
|
|
||||||
return batch_variants, actions, states, samples# 多返回 samples
|
|
||||||
|
|
||||||
调用侧:
|
|
||||||
- Policy 路径:由 CLI 参数 --fast_policy_no_decode 控制,只需要 action 时可跳过 decode
|
|
||||||
- WM 交互路径:传decode_video=False,只拿 raw latent
|
|
||||||
|
|
||||||
效果: WM 路径每轮省掉 16 帧全量 decode。
|
|
||||||
|
|
||||||
---
|
|
||||||
优化2:只decode observation 需要的帧
|
|
||||||
|
|
||||||
问题: WM 跳过了全量 decode,但下一轮的CLIP embedding 需要 pixel-space 图像做 observation。
|
|
||||||
|
|
||||||
改动: 只decode exe_steps 帧(通常 1帧),而不是全部 16 帧:
|
|
||||||
|
|
||||||
# WM 调用,不做全量 decode
|
|
||||||
pred_videos_1, _, pred_states, wm_samples = image_guided_synthesis_sim_mode(
|
|
||||||
..., decode_video=False)
|
|
||||||
|
|
||||||
# 只 decode exe_steps 帧给 observation
|
|
||||||
obs_pixels = model.decode_first_stage(
|
|
||||||
wm_samples[:, :, :args.exe_steps, :, :])
|
|
||||||
|
|
||||||
for idx in range(args.exe_steps):
|
|
||||||
observation = {
|
|
||||||
'observation.images.top':obs_pixels[0, :, idx:idx + 1].permute(1, 0, 2, 3),
|
|
||||||
...
|
|
||||||
}
|
|
||||||
cond_obs_queues = populate_queues(cond_obs_queues, observation)
|
|
||||||
|
|
||||||
关键细节: 必须逐帧填充 observation queue(idx:idx+1),不能全用最后一帧,否则 CLIP embedding 输入变了会影响精度。
|
|
||||||
|
|
||||||
效果: 每轮从 decode 16 帧降到 decode exe_steps 帧(省15 帧/轮)。
|
|
||||||
|
|
||||||
---
|
|
||||||
优化3:decode stream——GPU 上并行 decode 和 UNet
|
|
||||||
|
|
||||||
问题: 写入最终视频仍需要完整 segment 的 pixel,这部分 decode 还是要做。
|
|
||||||
|
|
||||||
思路: 用独立 CUDA stream 做 segment decode,和下一轮 UNet 推断在 GPU 上并行。
|
|
||||||
|
|
||||||
改动:
|
|
||||||
|
|
||||||
初始化:
|
|
||||||
decode_stream = torch.cuda.Stream(device=device)
|
|
||||||
pending_decode = None
|
|
||||||
|
|
||||||
循环尾部:
|
|
||||||
# 收集上一轮 decode 结果
|
|
||||||
if pending_decode is not None:
|
|
||||||
decode_stream.synchronize()
|
|
||||||
write_q.put(pending_decode.cpu())
|
|
||||||
pending_decode = None
|
|
||||||
|
|
||||||
# 在 decode stream 上启动当前轮 segment decode(不阻塞主线程)
|
|
||||||
latent_slice = wm_samples[:, :, :args.exe_steps]
|
|
||||||
decode_stream.wait_stream(torch.cuda.current_stream()) # 确保 latent 就绪
|
|
||||||
with torch.cuda.stream(decode_stream):
|
|
||||||
pending_decode = model.decode_first_stage(latent_slice)
|
|
||||||
# 主线程立即进入下一轮 UNet
|
|
||||||
|
|
||||||
循环结束后收集最后一轮:
|
|
||||||
if pending_decode is not None:
|
|
||||||
decode_stream.synchronize()
|
|
||||||
write_q.put(pending_decode.cpu())
|
|
||||||
|
|
||||||
原理: decode_stream.wait_stream() 建立 stream间依赖,确保 latent 产出后才开始 decode。两个 stream 的 kernel 可以被GPU
|
|
||||||
调度器交错执行。
|
|
||||||
|
|
||||||
效果: segment decode 时间被下一轮 UNet 推断掩盖。
|
|
||||||
|
|
||||||
---
|
|
||||||
优化4:Writer 进程——CPU 工作跨进程并行
|
|
||||||
|
|
||||||
问题: decode 完的tensor 需要转numpy + cv2 编码写盘,这是 CPU 密集型操作,Python GIL 限制线程并行。
|
|
||||||
|
|
||||||
改动:
|
|
||||||
|
|
||||||
辅助函数(主进程和子进程都能调用):
|
|
||||||
def _video_tensor_to_frames(video: Tensor) -> np.ndarray:
|
|
||||||
video = torch.clamp(video.float(), -1., 1.)
|
|
||||||
n = video.shape[0]
|
|
||||||
video = video.permute(2, 0, 1, 3, 4)
|
|
||||||
frame_grids = [
|
|
||||||
torchvision.utils.make_grid(f, nrow=int(n), padding=0) for f in video
|
|
||||||
]
|
|
||||||
grid = torch.stack(frame_grids, dim=0)
|
|
||||||
grid = ((grid + 1.0) / 2.0 * 255).to(torch.uint8).permute(0, 2, 3, 1)
|
|
||||||
return grid.numpy()[:, :, :, ::-1] # RGB → BGR
|
|
||||||
|
|
||||||
Writer 进程:
|
|
||||||
def _video_writer_process(q: mp.Queue, filename: str, fps: int):
|
|
||||||
vwriter = None
|
|
||||||
while True:
|
|
||||||
item = q.get()
|
|
||||||
if item is None: # sentinel,退出
|
|
||||||
break
|
|
||||||
frames = _video_tensor_to_frames(item)
|
|
||||||
if vwriter is None:
|
|
||||||
h, w = frames.shape[1], frames.shape[2]
|
|
||||||
fourcc = cv2.VideoWriter_fourcc(*'mp4v')
|
|
||||||
vwriter = cv2.VideoWriter(filename, fourcc, fps, (w, h))
|
|
||||||
for f in frames:
|
|
||||||
vwriter.write(f)
|
|
||||||
if vwriter is not None:
|
|
||||||
vwriter.release()
|
|
||||||
|
|
||||||
主进程启动 writer:
|
|
||||||
write_q = mp.Queue()
|
|
||||||
writer_proc = mp.Process(target=_video_writer_process,
|
|
||||||
args=(write_q, sample_full_video_file, args.save_fps))
|
|
||||||
writer_proc.start()
|
|
||||||
|
|
||||||
主进程通过 write_q.put(tensor.cpu()) 发送数据,循环结束发None sentinel 并join()。
|
|
||||||
|
|
||||||
效果:
|
|
||||||
- tensor→numpy 转换和cv2 编码不占主进程 CPU 时间
|
|
||||||
- 不受 GIL 限制
|
|
||||||
- cv2.VideoWriter 增量写入,不攒内存,不做最终 torch.cat
|
|
||||||
- Queue 自带背压,writer 处理不过来时 put 自然阻塞,不会 OOM
|
|
||||||
|
|
||||||
---
|
|
||||||
同时删除的冗余 I/O
|
|
||||||
|
|
||||||
- 每轮迭代的 WM 中间tensorboard log(log_to_tensorboard_async)
|
|
||||||
- 每轮迭代的 WM 中间 mp4(save_results_async)
|
|
||||||
- 每轮迭代的 policy 中间 mp4
|
|
||||||
- 最终的 torch.cat + 一次性 save_results_async 写完整视频
|
|
||||||
|
|
||||||
---
|
|
||||||
总结:优化前后对比
|
|
||||||
|
|
||||||
┌─────────────────────┬────────────────────────┬───────────────────────────┐
|
|
||||||
│ │ 原始 │ 优化后 │
|
|
||||||
├─────────────────────┼────────────────────────┼───────────────────────────┤
|
|
||||||
│ WM decode帧数/轮 │ 16 │ exe_steps(通常1) │
|
|
||||||
├─────────────────────┼────────────────────────┼───────────────────────────┤
|
|
||||||
│ segment decode 阻塞 │ 同步阻塞主线程 │ decode stream,GPU并行 │
|
|
||||||
├─────────────────────┼────────────────────────┼───────────────────────────┤
|
|
||||||
│ tensor→numpy+写盘 │ 主线程同步/线程池(GIL) │ 独立进程,无GIL │
|
|
||||||
├─────────────────────┼────────────────────────┼───────────────────────────┤
|
|
||||||
│ 视频写入方式 │ 攒内存→cat→一次写 │ cv2增量追加写 │
|
|
||||||
├─────────────────────┼────────────────────────┼───────────────────────────┤
|
|
||||||
│ 中间I/O │ 每轮2×mp4+2×tb │ 无 │
|
|
||||||
├─────────────────────┼────────────────────────┼───────────────────────────┤
|
|
||||||
│ 内存占用 │ 随迭代数线性增长 │ 恒定(不攒decoded video) │
|
|
||||||
└─────────────────────┴────────────────────────┴───────────────────────────┘
|
|
||||||
|
|
||||||
✻ Cooked for 1m 31s
|
|
||||||
|
|
||||||
❯ 目前哪些优化是默认开启的
|
|
||||||
|
|
||||||
● 让我检查各个优化的开关状态。
|
|
||||||
|
|
||||||
● Searched for 2 patterns (ctrl+o to expand)
|
|
||||||
|
|
||||||
● 当前各优化的默认状态:
|
|
||||||
|
|
||||||
┌───────────────────────────────────┬────────────┬────────────────────────────────────────────────────────┐
|
|
||||||
│ 优化 │ 默认开启? │ 控制方式 │
|
|
||||||
├───────────────────────────────────┼────────────┼────────────────────────────────────────────────────────┤
|
|
||||||
│ WM路径跳过全量decode │ ✅ 是 │ 硬编码 decode_video=False(789行) │
|
|
||||||
├───────────────────────────────────┼────────────┼────────────────────────────────────────────────────────┤
|
|
||||||
│ 只decode exe_steps帧给observation │ ✅ 是 │ 无开关,直接生效(792行) │
|
|
||||||
├───────────────────────────────────┼────────────┼────────────────────────────────────────────────────────┤
|
|
||||||
│ decode stream GPU并行 │ ✅ 是 │ 无开关,直接生效(824-827行) │
|
|
||||||
├───────────────────────────────────┼────────────┼────────────────────────────────────────────────────────┤
|
|
||||||
│ Writer进程增量写盘 │ ✅ 是 │ 无开关,直接生效(672-675行) │
|
|
||||||
├───────────────────────────────────┼────────────┼────────────────────────────────────────────────────────┤
|
|
||||||
│ 删除中间wm mp4/tb保存 │ ✅ 是 │ 代码已删除 │
|
|
||||||
├───────────────────────────────────┼────────────┼────────────────────────────────────────────────────────┤
|
|
||||||
│ Policy路径跳过decode │ ❌ 否 │ 需传--fast_policy_no_decode(默认False,即默认decode) │
|
|
||||||
└───────────────────────────────────┴────────────┴────────────────────────────────────────────────────────┘
|
|
||||||
150
run.log
Normal file
150
run.log
Normal file
@@ -0,0 +1,150 @@
|
|||||||
|
nohup: ignoring input
|
||||||
|
/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/lightning_fabric/__init__.py:29: UserWarning: pkg_resources is deprecated as an API. See https://setuptools.pypa.io/en/latest/pkg_resources.html. The pkg_resources package is slated for removal as early as 2025-11-30. Refrain from using this package or pin to Setuptools<81.
|
||||||
|
__import__("pkg_resources").declare_namespace(__name__)
|
||||||
|
2026-02-08 08:15:49.934949: I tensorflow/core/util/port.cc:113] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
|
||||||
|
2026-02-08 08:15:49.937974: I external/local_tsl/tsl/cuda/cudart_stub.cc:31] Could not find cuda drivers on your machine, GPU will not be used.
|
||||||
|
2026-02-08 08:15:49.969069: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
|
||||||
|
2026-02-08 08:15:49.969100: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
|
||||||
|
2026-02-08 08:15:49.970909: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
|
||||||
|
2026-02-08 08:15:49.979005: I external/local_tsl/tsl/cuda/cudart_stub.cc:31] Could not find cuda drivers on your machine, GPU will not be used.
|
||||||
|
2026-02-08 08:15:49.979255: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
|
||||||
|
To enable the following instructions: AVX2 AVX512F AVX512_VNNI AVX512_BF16 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
|
||||||
|
2026-02-08 08:15:50.597743: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT
|
||||||
|
[rank: 0] Global seed set to 123
|
||||||
|
/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/kornia/feature/lightglue.py:44: FutureWarning: `torch.cuda.amp.custom_fwd(args...)` is deprecated. Please use `torch.amp.custom_fwd(args..., device_type='cuda')` instead.
|
||||||
|
@torch.cuda.amp.custom_fwd(cast_inputs=torch.float32)
|
||||||
|
INFO:mainlogger:LatentVisualDiffusion: Running in v-prediction mode
|
||||||
|
INFO:unifolm_wma.models.diffusion_head.conditional_unet1d:number of parameters: 5.010531e+08
|
||||||
|
INFO:unifolm_wma.models.diffusion_head.conditional_unet1d:number of parameters: 5.010531e+08
|
||||||
|
AE working on z of shape (1, 4, 32, 32) = 4096 dimensions.
|
||||||
|
INFO:root:Loaded ViT-H-14 model config.
|
||||||
|
DEBUG:urllib3.connectionpool:Starting new HTTPS connection (1): hf-mirror.com:443
|
||||||
|
DEBUG:urllib3.connectionpool:https://hf-mirror.com:443 "HEAD /laion/CLIP-ViT-H-14-laion2B-s32B-b79K/resolve/main/open_clip_pytorch_model.bin HTTP/1.1" 302 0
|
||||||
|
INFO:root:Loading pretrained ViT-H-14 weights (laion2b_s32b_b79k).
|
||||||
|
/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/open_clip/factory.py:88: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.
|
||||||
|
checkpoint = torch.load(checkpoint_path, map_location=map_location)
|
||||||
|
INFO:root:Loaded ViT-H-14 model config.
|
||||||
|
DEBUG:urllib3.connectionpool:https://hf-mirror.com:443 "HEAD /laion/CLIP-ViT-H-14-laion2B-s32B-b79K/resolve/main/open_clip_pytorch_model.bin HTTP/1.1" 302 0
|
||||||
|
INFO:root:Loading pretrained ViT-H-14 weights (laion2b_s32b_b79k).
|
||||||
|
/mnt/ASC1637/unifolm-world-model-action/scripts/evaluation/world_model_interaction.py:86: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.
|
||||||
|
state_dict = torch.load(ckpt, map_location="cpu")
|
||||||
|
>>> model checkpoint loaded.
|
||||||
|
>>> Load pre-trained model ...
|
||||||
|
INFO:root:***** Configing Data *****
|
||||||
|
>>> unitree_z1_stackbox: 1 data samples loaded.
|
||||||
|
>>> unitree_z1_stackbox: data stats loaded.
|
||||||
|
>>> unitree_z1_stackbox: normalizer initiated.
|
||||||
|
>>> unitree_z1_dual_arm_stackbox: 1 data samples loaded.
|
||||||
|
>>> unitree_z1_dual_arm_stackbox: data stats loaded.
|
||||||
|
>>> unitree_z1_dual_arm_stackbox: normalizer initiated.
|
||||||
|
>>> unitree_z1_dual_arm_stackbox_v2: 1 data samples loaded.
|
||||||
|
>>> unitree_z1_dual_arm_stackbox_v2: data stats loaded.
|
||||||
|
>>> unitree_z1_dual_arm_stackbox_v2: normalizer initiated.
|
||||||
|
>>> unitree_z1_dual_arm_cleanup_pencils: 1 data samples loaded.
|
||||||
|
>>> unitree_z1_dual_arm_cleanup_pencils: data stats loaded.
|
||||||
|
>>> unitree_z1_dual_arm_cleanup_pencils: normalizer initiated.
|
||||||
|
>>> unitree_g1_pack_camera: 1 data samples loaded.
|
||||||
|
>>> unitree_g1_pack_camera: data stats loaded.
|
||||||
|
>>> unitree_g1_pack_camera: normalizer initiated.
|
||||||
|
>>> Dataset is successfully loaded ...
|
||||||
|
>>> Generate 16 frames under each generation ...
|
||||||
|
DEBUG:h5py._conv:Creating converter from 3 to 5
|
||||||
|
DEBUG:PIL.PngImagePlugin:STREAM b'IHDR' 16 13
|
||||||
|
DEBUG:PIL.PngImagePlugin:STREAM b'pHYs' 41 9
|
||||||
|
DEBUG:PIL.PngImagePlugin:STREAM b'IDAT' 62 4096
|
||||||
|
|
||||||
|
0%| | 0/12 [00:00<?, ?it/s]/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/torch/nn/functional.py:5501: UserWarning: Attempting to use hipBLASLt on an unsupported architecture! Overriding blas backend to hipblas (Triggered internally at ../aten/src/ATen/Context.cpp:296.)
|
||||||
|
proj = linear(q, w, b)
|
||||||
|
/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/torch/nn/functional.py:6278: UserWarning: Flash attention support on Navi31 GPU is still experimental. Enable it with TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL=1. (Triggered internally at ../aten/src/ATen/native/transformers/hip/sdp_utils.cpp:225.)
|
||||||
|
attn_output = scaled_dot_product_attention(
|
||||||
|
/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/torch/nn/functional.py:6278: UserWarning: Memory Efficient attention on Navi31 GPU is still experimental. Enable it with TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL=1. (Triggered internally at ../aten/src/ATen/native/transformers/hip/sdp_utils.cpp:269.)
|
||||||
|
attn_output = scaled_dot_product_attention(
|
||||||
|
>>> Step 0: generating actions ...
|
||||||
|
>>> Step 0: interacting with world model ...
|
||||||
|
DEBUG:PIL.Image:Importing BlpImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing BmpImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing BufrStubImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing CurImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing DcxImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing DdsImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing EpsImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing FitsImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing FitsStubImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing FliImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing FpxImagePlugin
|
||||||
|
DEBUG:PIL.Image:Image: failed to import FpxImagePlugin: No module named 'olefile'
|
||||||
|
DEBUG:PIL.Image:Importing FtexImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing GbrImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing GifImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing GribStubImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing Hdf5StubImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing IcnsImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing IcoImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing ImImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing ImtImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing IptcImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing JpegImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing Jpeg2KImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing McIdasImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing MicImagePlugin
|
||||||
|
DEBUG:PIL.Image:Image: failed to import MicImagePlugin: No module named 'olefile'
|
||||||
|
DEBUG:PIL.Image:Importing MpegImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing MpoImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing MspImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing PalmImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing PcdImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing PcxImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing PdfImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing PixarImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing PngImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing PpmImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing PsdImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing QoiImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing SgiImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing SpiderImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing SunImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing TgaImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing TiffImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing WebPImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing WmfImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing XbmImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing XpmImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing XVThumbImagePlugin
|
||||||
|
|
||||||
|
8%|▊ | 1/12 [01:37<17:51, 97.37s/it]
|
||||||
|
17%|█▋ | 2/12 [03:14<16:13, 97.31s/it]
|
||||||
|
25%|██▌ | 3/12 [04:51<14:35, 97.26s/it]
|
||||||
|
33%|███▎ | 4/12 [06:29<12:58, 97.25s/it]
|
||||||
|
42%|████▏ | 5/12 [08:06<11:20, 97.24s/it]
|
||||||
|
50%|█████ | 6/12 [09:43<09:43, 97.24s/it]
|
||||||
|
58%|█████▊ | 7/12 [11:20<08:06, 97.27s/it]
|
||||||
|
67%|██████▋ | 8/12 [12:58<06:29, 97.36s/it]
|
||||||
|
75%|███████▌ | 9/12 [14:36<04:52, 97.49s/it]
|
||||||
|
83%|████████▎ | 10/12 [16:13<03:15, 97.52s/it]
|
||||||
|
92%|█████████▏| 11/12 [17:51<01:37, 97.47s/it]
|
||||||
|
100%|██████████| 12/12 [19:28<00:00, 97.35s/it]
|
||||||
|
100%|██████████| 12/12 [19:28<00:00, 97.35s/it]
|
||||||
|
>>>>>>>>>>>>>>>>>>>>>>>>
|
||||||
|
>>> Step 1: generating actions ...
|
||||||
|
>>> Step 1: interacting with world model ...
|
||||||
|
>>>>>>>>>>>>>>>>>>>>>>>>
|
||||||
|
>>> Step 2: generating actions ...
|
||||||
|
>>> Step 2: interacting with world model ...
|
||||||
|
>>>>>>>>>>>>>>>>>>>>>>>>
|
||||||
|
>>> Step 3: generating actions ...
|
||||||
|
>>> Step 3: interacting with world model ...
|
||||||
|
>>>>>>>>>>>>>>>>>>>>>>>>
|
||||||
|
>>> Step 4: generating actions ...
|
||||||
|
>>> Step 4: interacting with world model ...
|
||||||
|
>>>>>>>>>>>>>>>>>>>>>>>>
|
||||||
|
>>> Step 5: generating actions ...
|
||||||
|
>>> Step 5: interacting with world model ...
|
||||||
|
>>>>>>>>>>>>>>>>>>>>>>>>
|
||||||
|
>>> Step 6: generating actions ...
|
||||||
|
>>> Step 6: interacting with world model ...
|
||||||
|
>>>>>>>>>>>>>>>>>>>>>>>>
|
||||||
|
>>> Step 7: generating actions ...
|
||||||
|
>>> Step 7: interacting with world model ...
|
||||||
|
>>>>>>>>>>>>>>>>>>>>>>>>
|
||||||
|
>>> Step 8: generating actions ...
|
||||||
|
>>> Step 8: interacting with world model ...
|
||||||
114
run_all_case.sh
114
run_all_case.sh
@@ -1,114 +0,0 @@
|
|||||||
#!/bin/bash
|
|
||||||
|
|
||||||
# 自动执行所有场景的所有case
|
|
||||||
# 总共5个场景,每个场景4个case,共20个case
|
|
||||||
# 设置环境变量(离线模式)
|
|
||||||
export HF_HUB_OFFLINE=1
|
|
||||||
export TRANSFORMERS_OFFLINE=1
|
|
||||||
|
|
||||||
# 颜色定义
|
|
||||||
RED='\033[0;31m'
|
|
||||||
GREEN='\033[0;32m'
|
|
||||||
YELLOW='\033[1;33m'
|
|
||||||
BLUE='\033[0;34m'
|
|
||||||
NC='\033[0m' # No Color
|
|
||||||
|
|
||||||
# 定义所有场景
|
|
||||||
SCENARIOS=(
|
|
||||||
"unitree_g1_pack_camera"
|
|
||||||
"unitree_z1_dual_arm_cleanup_pencils"
|
|
||||||
"unitree_z1_dual_arm_stackbox"
|
|
||||||
"unitree_z1_dual_arm_stackbox_v2"
|
|
||||||
"unitree_z1_stackbox"
|
|
||||||
)
|
|
||||||
|
|
||||||
# 定义case数量
|
|
||||||
CASES=(1 2 3 4)
|
|
||||||
|
|
||||||
# 记录开始时间
|
|
||||||
START_TIME=$(date +%s)
|
|
||||||
LOG_FILE="run_all_cases_$(date +%Y%m%d_%H%M%S).log"
|
|
||||||
|
|
||||||
echo -e "${BLUE}========================================${NC}"
|
|
||||||
echo -e "${BLUE}开始执行所有场景的case${NC}"
|
|
||||||
echo -e "${BLUE}总共: ${#SCENARIOS[@]} 个场景 x ${#CASES[@]} 个case = $((${#SCENARIOS[@]} * ${#CASES[@]})) 个任务${NC}"
|
|
||||||
echo -e "${BLUE}日志文件: ${LOG_FILE}${NC}"
|
|
||||||
echo -e "${BLUE}========================================${NC}"
|
|
||||||
echo ""
|
|
||||||
|
|
||||||
# 初始化计数器
|
|
||||||
TOTAL_CASES=$((${#SCENARIOS[@]} * ${#CASES[@]}))
|
|
||||||
CURRENT_CASE=0
|
|
||||||
SUCCESS_COUNT=0
|
|
||||||
FAIL_COUNT=0
|
|
||||||
|
|
||||||
# 记录失败的case
|
|
||||||
declare -a FAILED_CASES
|
|
||||||
|
|
||||||
# 遍历所有场景
|
|
||||||
for scenario in "${SCENARIOS[@]}"; do
|
|
||||||
echo -e "${YELLOW}>>> 场景: ${scenario}${NC}"
|
|
||||||
|
|
||||||
# 遍历所有case
|
|
||||||
for case_num in "${CASES[@]}"; do
|
|
||||||
CURRENT_CASE=$((CURRENT_CASE + 1))
|
|
||||||
case_dir="${scenario}/case${case_num}"
|
|
||||||
script_path="${case_dir}/run_world_model_interaction.sh"
|
|
||||||
|
|
||||||
echo -e "${BLUE}[${CURRENT_CASE}/${TOTAL_CASES}] 执行: ${case_dir}${NC}"
|
|
||||||
|
|
||||||
# 检查脚本是否存在
|
|
||||||
if [ ! -f "${script_path}" ]; then
|
|
||||||
echo -e "${RED}错误: 脚本不存在 ${script_path}${NC}"
|
|
||||||
FAIL_COUNT=$((FAIL_COUNT + 1))
|
|
||||||
FAILED_CASES+=("${case_dir} (脚本不存在)")
|
|
||||||
continue
|
|
||||||
fi
|
|
||||||
|
|
||||||
# 执行脚本
|
|
||||||
echo "开始时间: $(date '+%Y-%m-%d %H:%M:%S')"
|
|
||||||
|
|
||||||
if bash "${script_path}" >> "${LOG_FILE}" 2>&1; then
|
|
||||||
echo -e "${GREEN}✓ 成功: ${case_dir}${NC}"
|
|
||||||
SUCCESS_COUNT=$((SUCCESS_COUNT + 1))
|
|
||||||
else
|
|
||||||
echo -e "${RED}✗ 失败: ${case_dir}${NC}"
|
|
||||||
FAIL_COUNT=$((FAIL_COUNT + 1))
|
|
||||||
FAILED_CASES+=("${case_dir}")
|
|
||||||
fi
|
|
||||||
|
|
||||||
echo "结束时间: $(date '+%Y-%m-%d %H:%M:%S')"
|
|
||||||
echo ""
|
|
||||||
done
|
|
||||||
|
|
||||||
echo ""
|
|
||||||
done
|
|
||||||
|
|
||||||
# 计算总耗时
|
|
||||||
END_TIME=$(date +%s)
|
|
||||||
DURATION=$((END_TIME - START_TIME))
|
|
||||||
HOURS=$((DURATION / 3600))
|
|
||||||
MINUTES=$(((DURATION % 3600) / 60))
|
|
||||||
SECONDS=$((DURATION % 60))
|
|
||||||
|
|
||||||
# 输出总结
|
|
||||||
echo -e "${BLUE}========================================${NC}"
|
|
||||||
echo -e "${BLUE}执行完成!${NC}"
|
|
||||||
echo -e "${BLUE}========================================${NC}"
|
|
||||||
echo -e "总任务数: ${TOTAL_CASES}"
|
|
||||||
echo -e "${GREEN}成功: ${SUCCESS_COUNT}${NC}"
|
|
||||||
echo -e "${RED}失败: ${FAIL_COUNT}${NC}"
|
|
||||||
echo -e "总耗时: ${HOURS}小时 ${MINUTES}分钟 ${SECONDS}秒"
|
|
||||||
echo -e "详细日志: ${LOG_FILE}"
|
|
||||||
echo ""
|
|
||||||
|
|
||||||
# 如果有失败的case,列出来
|
|
||||||
if [ ${FAIL_COUNT} -gt 0 ]; then
|
|
||||||
echo -e "${RED}失败的case列表:${NC}"
|
|
||||||
for failed_case in "${FAILED_CASES[@]}"; do
|
|
||||||
echo -e "${RED} - ${failed_case}${NC}"
|
|
||||||
done
|
|
||||||
echo ""
|
|
||||||
fi
|
|
||||||
|
|
||||||
echo -e "${BLUE}========================================${NC}"
|
|
||||||
@@ -1,504 +0,0 @@
|
|||||||
2026-02-18 19:01:56.891895: I tensorflow/core/util/port.cc:113] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
|
|
||||||
2026-02-18 19:01:56.940243: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
|
|
||||||
2026-02-18 19:01:56.940285: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
|
|
||||||
2026-02-18 19:01:56.941395: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
|
|
||||||
2026-02-18 19:01:56.948327: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
|
|
||||||
To enable the following instructions: AVX2 AVX512F AVX512_VNNI FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
|
|
||||||
2026-02-18 19:01:57.870809: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT
|
|
||||||
Global seed set to 123
|
|
||||||
>>> Loading prepared model from ckpts/unifolm_wma_dual.ckpt.prepared.pt ...
|
|
||||||
>>> Prepared model loaded.
|
|
||||||
INFO:root:***** Configing Data *****
|
|
||||||
>>> unitree_z1_stackbox: 1 data samples loaded.
|
|
||||||
>>> unitree_z1_stackbox: data stats loaded.
|
|
||||||
>>> unitree_z1_stackbox: normalizer initiated.
|
|
||||||
>>> unitree_z1_dual_arm_stackbox: 1 data samples loaded.
|
|
||||||
>>> unitree_z1_dual_arm_stackbox: data stats loaded.
|
|
||||||
>>> unitree_z1_dual_arm_stackbox: normalizer initiated.
|
|
||||||
>>> unitree_z1_dual_arm_stackbox_v2: 1 data samples loaded.
|
|
||||||
>>> unitree_z1_dual_arm_stackbox_v2: data stats loaded.
|
|
||||||
>>> unitree_z1_dual_arm_stackbox_v2: normalizer initiated.
|
|
||||||
>>> unitree_z1_dual_arm_cleanup_pencils: 1 data samples loaded.
|
|
||||||
>>> unitree_z1_dual_arm_cleanup_pencils: data stats loaded.
|
|
||||||
>>> unitree_z1_dual_arm_cleanup_pencils: normalizer initiated.
|
|
||||||
>>> unitree_g1_pack_camera: 1 data samples loaded.
|
|
||||||
>>> unitree_g1_pack_camera: data stats loaded.
|
|
||||||
>>> unitree_g1_pack_camera: normalizer initiated.
|
|
||||||
>>> Dataset is successfully loaded ...
|
|
||||||
✓ KV fused: 66 attention layers
|
|
||||||
TRT output 'y': [1, 4, 16, 40, 64] torch.float32
|
|
||||||
TRT output 'hs_a_0': [1, 16, 320, 40, 64] torch.float32
|
|
||||||
TRT output 'hs_a_1': [1, 16, 640, 20, 32] torch.float32
|
|
||||||
TRT output 'hs_a_2': [1, 16, 1280, 10, 16] torch.float32
|
|
||||||
TRT output 'hs_a_3': [1, 16, 1280, 5, 8] torch.float32
|
|
||||||
TRT output 'hs_a_4': [1, 16, 1280, 5, 8] torch.float32
|
|
||||||
TRT output 'hs_a_5': [1, 16, 1280, 5, 8] torch.float32
|
|
||||||
TRT output 'hs_a_6': [1, 16, 1280, 10, 16] torch.float32
|
|
||||||
TRT output 'hs_a_7': [1, 16, 640, 20, 32] torch.float32
|
|
||||||
TRT output 'hs_a_8': [1, 16, 320, 40, 64] torch.float32
|
|
||||||
>>> TRT backbone loaded from /home/qhy/unifolm-world-model-action/scripts/evaluation/../../trt_engines/video_backbone.engine
|
|
||||||
>>> Generate 16 frames under each generation ...
|
|
||||||
DEBUG:h5py._conv:Creating converter from 3 to 5
|
|
||||||
DEBUG:PIL.PngImagePlugin:STREAM b'IHDR' 16 13
|
|
||||||
DEBUG:PIL.PngImagePlugin:STREAM b'pHYs' 41 9
|
|
||||||
DEBUG:PIL.PngImagePlugin:STREAM b'IDAT' 62 4096
|
|
||||||
|
|
||||||
0%| | 0/11 [00:00<?, ?it/s][02/18/2026-19:02:10] [TRT] [W] Using default stream in enqueueV3() may lead to performance issues due to additional calls to cudaStreamSynchronize() by TensorRT to ensure correct synchronization. Please use non-default stream instead.
|
|
||||||
|
|
||||||
9%|▉ | 1/11 [00:17<02:51, 17.15s/it]>>> Step 0: generating actions ...
|
|
||||||
>>> Step 0: interacting with world model ...
|
|
||||||
>>>>>>>>>>>>>>>>>>>>>>>>
|
|
||||||
>>> Step 1: generating actions ...
|
|
||||||
DEBUG:PIL.Image:Importing BlpImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing BlpImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing BmpImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing BufrStubImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing BmpImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing BufrStubImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing CurImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing CurImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing DcxImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing DcxImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing DdsImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing DdsImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing EpsImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing EpsImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing FitsImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing FitsImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing FitsStubImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing FitsStubImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing FliImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing FliImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing FpxImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing FpxImagePlugin
|
|
||||||
DEBUG:PIL.Image:Image: failed to import FpxImagePlugin: No module named 'olefile'
|
|
||||||
DEBUG:PIL.Image:Importing FtexImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing FtexImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing GbrImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing GbrImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing GifImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing GribStubImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing GifImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing GribStubImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing Hdf5StubImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing Hdf5StubImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing IcnsImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing IcnsImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing IcoImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing IcoImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing ImImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing ImImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing ImtImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing ImtImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing IptcImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing IptcImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing JpegImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing Jpeg2KImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing McIdasImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing JpegImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing Jpeg2KImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing McIdasImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing MicImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing MicImagePlugin
|
|
||||||
DEBUG:PIL.Image:Image: failed to import MicImagePlugin: No module named 'olefile'
|
|
||||||
DEBUG:PIL.Image:Importing MpegImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing MpegImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing MpoImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing MpoImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing MspImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing MspImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing PalmImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing PalmImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing PcdImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing PcdImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing PcxImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing PdfImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing PcxImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing PdfImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing PixarImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing PixarImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing PngImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing PpmImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing PsdImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing PngImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing PpmImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing PsdImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing QoiImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing QoiImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing SgiImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing SgiImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing SpiderImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing SpiderImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing SunImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing SunImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing TgaImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing TgaImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing TiffImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing WebPImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing TiffImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing WebPImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing WmfImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing WmfImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing XbmImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing XbmImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing XpmImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing XpmImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing XVThumbImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing XVThumbImagePlugin
|
|
||||||
|
|
||||||
18%|█▊ | 2/11 [00:33<02:31, 16.87s/it]
|
|
||||||
27%|██▋ | 3/11 [00:50<02:14, 16.76s/it]
|
|
||||||
36%|███▋ | 4/11 [01:07<01:57, 16.81s/it]
|
|
||||||
45%|████▌ | 5/11 [01:24<01:41, 16.85s/it]
|
|
||||||
55%|█████▍ | 6/11 [01:41<01:24, 16.82s/it]
|
|
||||||
64%|██████▎ | 7/11 [01:57<01:07, 16.82s/it]
|
|
||||||
73%|███████▎ | 8/11 [02:14<00:50, 16.83s/it]
|
|
||||||
82%|████████▏ | 9/11 [02:31<00:33, 16.80s/it]
|
|
||||||
91%|█████████ | 10/11 [02:48<00:16, 16.81s/it]
|
|
||||||
100%|██████████| 11/11 [03:05<00:00, 16.81s/it]
|
|
||||||
100%|██████████| 11/11 [03:05<00:00, 16.83s/it]
|
|
||||||
>>> Step 1: interacting with world model ...
|
|
||||||
>>>>>>>>>>>>>>>>>>>>>>>>
|
|
||||||
>>> Step 2: generating actions ...
|
|
||||||
>>> Step 2: interacting with world model ...
|
|
||||||
>>>>>>>>>>>>>>>>>>>>>>>>
|
|
||||||
>>> Step 3: generating actions ...
|
|
||||||
>>> Step 3: interacting with world model ...
|
|
||||||
>>>>>>>>>>>>>>>>>>>>>>>>
|
|
||||||
>>> Step 4: generating actions ...
|
|
||||||
>>> Step 4: interacting with world model ...
|
|
||||||
>>>>>>>>>>>>>>>>>>>>>>>>
|
|
||||||
>>> Step 5: generating actions ...
|
|
||||||
>>> Step 5: interacting with world model ...
|
|
||||||
>>>>>>>>>>>>>>>>>>>>>>>>
|
|
||||||
>>> Step 6: generating actions ...
|
|
||||||
>>> Step 6: interacting with world model ...
|
|
||||||
>>>>>>>>>>>>>>>>>>>>>>>>
|
|
||||||
>>> Step 7: generating actions ...
|
|
||||||
>>> Step 7: interacting with world model ...
|
|
||||||
>>>>>>>>>>>>>>>>>>>>>>>>
|
|
||||||
>>> Step 8: generating actions ...
|
|
||||||
>>> Step 8: interacting with world model ...
|
|
||||||
>>>>>>>>>>>>>>>>>>>>>>>>
|
|
||||||
>>> Step 9: generating actions ...
|
|
||||||
>>> Step 9: interacting with world model ...
|
|
||||||
>>>>>>>>>>>>>>>>>>>>>>>>
|
|
||||||
>>> Step 10: generating actions ...
|
|
||||||
>>> Step 10: interacting with world model ...
|
|
||||||
>>>>>>>>>>>>>>>>>>>>>>>>
|
|
||||||
|
|
||||||
real 3m49.072s
|
|
||||||
user 4m16.055s
|
|
||||||
sys 0m44.636s
|
|
||||||
2026-02-18 19:05:45.956647: I tensorflow/core/util/port.cc:113] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
|
|
||||||
2026-02-18 19:05:46.004149: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
|
|
||||||
2026-02-18 19:05:46.004193: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
|
|
||||||
2026-02-18 19:05:46.005265: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
|
|
||||||
2026-02-18 19:05:46.012074: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
|
|
||||||
To enable the following instructions: AVX2 AVX512F AVX512_VNNI FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
|
|
||||||
2026-02-18 19:05:46.932966: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT
|
|
||||||
Global seed set to 123
|
|
||||||
>>> Loading prepared model from ckpts/unifolm_wma_dual.ckpt.prepared.pt ...
|
|
||||||
>>> Prepared model loaded.
|
|
||||||
INFO:root:***** Configing Data *****
|
|
||||||
>>> unitree_z1_stackbox: 1 data samples loaded.
|
|
||||||
>>> unitree_z1_stackbox: data stats loaded.
|
|
||||||
>>> unitree_z1_stackbox: normalizer initiated.
|
|
||||||
>>> unitree_z1_dual_arm_stackbox: 1 data samples loaded.
|
|
||||||
>>> unitree_z1_dual_arm_stackbox: data stats loaded.
|
|
||||||
>>> unitree_z1_dual_arm_stackbox: normalizer initiated.
|
|
||||||
>>> unitree_z1_dual_arm_stackbox_v2: 1 data samples loaded.
|
|
||||||
>>> unitree_z1_dual_arm_stackbox_v2: data stats loaded.
|
|
||||||
>>> unitree_z1_dual_arm_stackbox_v2: normalizer initiated.
|
|
||||||
>>> unitree_z1_dual_arm_cleanup_pencils: 1 data samples loaded.
|
|
||||||
>>> unitree_z1_dual_arm_cleanup_pencils: data stats loaded.
|
|
||||||
>>> unitree_z1_dual_arm_cleanup_pencils: normalizer initiated.
|
|
||||||
>>> unitree_g1_pack_camera: 1 data samples loaded.
|
|
||||||
>>> unitree_g1_pack_camera: data stats loaded.
|
|
||||||
>>> unitree_g1_pack_camera: normalizer initiated.
|
|
||||||
>>> Dataset is successfully loaded ...
|
|
||||||
✓ KV fused: 66 attention layers
|
|
||||||
TRT output 'y': [1, 4, 16, 40, 64] torch.float32
|
|
||||||
TRT output 'hs_a_0': [1, 16, 320, 40, 64] torch.float32
|
|
||||||
TRT output 'hs_a_1': [1, 16, 640, 20, 32] torch.float32
|
|
||||||
TRT output 'hs_a_2': [1, 16, 1280, 10, 16] torch.float32
|
|
||||||
TRT output 'hs_a_3': [1, 16, 1280, 5, 8] torch.float32
|
|
||||||
TRT output 'hs_a_4': [1, 16, 1280, 5, 8] torch.float32
|
|
||||||
TRT output 'hs_a_5': [1, 16, 1280, 5, 8] torch.float32
|
|
||||||
TRT output 'hs_a_6': [1, 16, 1280, 10, 16] torch.float32
|
|
||||||
TRT output 'hs_a_7': [1, 16, 640, 20, 32] torch.float32
|
|
||||||
TRT output 'hs_a_8': [1, 16, 320, 40, 64] torch.float32
|
|
||||||
>>> TRT backbone loaded from /home/qhy/unifolm-world-model-action/scripts/evaluation/../../trt_engines/video_backbone.engine
|
|
||||||
>>> Generate 16 frames under each generation ...
|
|
||||||
DEBUG:h5py._conv:Creating converter from 3 to 5
|
|
||||||
DEBUG:PIL.PngImagePlugin:STREAM b'IHDR' 16 13
|
|
||||||
DEBUG:PIL.PngImagePlugin:STREAM b'pHYs' 41 9
|
|
||||||
DEBUG:PIL.PngImagePlugin:STREAM b'IDAT' 62 4096
|
|
||||||
|
|
||||||
0%| | 0/11 [00:00<?, ?it/s][02/18/2026-19:05:59] [TRT] [W] Using default stream in enqueueV3() may lead to performance issues due to additional calls to cudaStreamSynchronize() by TensorRT to ensure correct synchronization. Please use non-default stream instead.
|
|
||||||
|
|
||||||
9%|▉ | 1/11 [00:16<02:47, 16.71s/it]>>> Step 0: generating actions ...
|
|
||||||
>>> Step 0: interacting with world model ...
|
|
||||||
>>>>>>>>>>>>>>>>>>>>>>>>
|
|
||||||
>>> Step 1: generating actions ...
|
|
||||||
DEBUG:PIL.Image:Importing BlpImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing BlpImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing BmpImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing BufrStubImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing BmpImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing BufrStubImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing CurImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing CurImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing DcxImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing DcxImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing DdsImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing DdsImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing EpsImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing EpsImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing FitsImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing FitsImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing FitsStubImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing FitsStubImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing FliImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing FliImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing FpxImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing FpxImagePlugin
|
|
||||||
DEBUG:PIL.Image:Image: failed to import FpxImagePlugin: No module named 'olefile'
|
|
||||||
DEBUG:PIL.Image:Importing FtexImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing FtexImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing GbrImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing GbrImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing GifImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing GribStubImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing GifImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing GribStubImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing Hdf5StubImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing Hdf5StubImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing IcnsImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing IcnsImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing IcoImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing IcoImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing ImImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing ImImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing ImtImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing ImtImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing IptcImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing IptcImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing JpegImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing Jpeg2KImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing McIdasImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing JpegImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing Jpeg2KImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing McIdasImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing MicImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing MicImagePlugin
|
|
||||||
DEBUG:PIL.Image:Image: failed to import MicImagePlugin: No module named 'olefile'
|
|
||||||
DEBUG:PIL.Image:Importing MpegImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing MpegImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing MpoImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing MpoImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing MspImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing MspImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing PalmImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing PalmImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing PcdImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing PcdImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing PcxImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing PdfImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing PcxImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing PdfImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing PixarImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing PixarImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing PngImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing PpmImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing PsdImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing PngImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing PpmImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing PsdImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing QoiImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing QoiImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing SgiImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing SgiImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing SpiderImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing SpiderImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing SunImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing SunImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing TgaImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing TgaImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing TiffImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing WebPImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing TiffImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing WebPImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing WmfImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing WmfImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing XbmImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing XbmImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing XpmImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing XpmImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing XVThumbImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing XVThumbImagePlugin
|
|
||||||
|
|
||||||
18%|█▊ | 2/11 [00:33<02:30, 16.75s/it]
|
|
||||||
27%|██▋ | 3/11 [00:50<02:15, 16.91s/it]
|
|
||||||
36%|███▋ | 4/11 [01:07<01:59, 17.02s/it]
|
|
||||||
45%|████▌ | 5/11 [01:24<01:41, 16.98s/it]
|
|
||||||
55%|█████▍ | 6/11 [01:41<01:24, 16.94s/it]
|
|
||||||
64%|██████▎ | 7/11 [01:58<01:07, 16.90s/it]
|
|
||||||
73%|███████▎ | 8/11 [02:15<00:50, 16.83s/it]
|
|
||||||
82%|████████▏ | 9/11 [02:31<00:33, 16.80s/it]
|
|
||||||
91%|█████████ | 10/11 [02:49<00:16, 16.94s/it]
|
|
||||||
100%|██████████| 11/11 [03:06<00:00, 16.97s/it]
|
|
||||||
100%|██████████| 11/11 [03:06<00:00, 16.91s/it]
|
|
||||||
>>> Step 1: interacting with world model ...
|
|
||||||
>>>>>>>>>>>>>>>>>>>>>>>>
|
|
||||||
>>> Step 2: generating actions ...
|
|
||||||
>>> Step 2: interacting with world model ...
|
|
||||||
>>>>>>>>>>>>>>>>>>>>>>>>
|
|
||||||
>>> Step 3: generating actions ...
|
|
||||||
>>> Step 3: interacting with world model ...
|
|
||||||
>>>>>>>>>>>>>>>>>>>>>>>>
|
|
||||||
>>> Step 4: generating actions ...
|
|
||||||
>>> Step 4: interacting with world model ...
|
|
||||||
>>>>>>>>>>>>>>>>>>>>>>>>
|
|
||||||
>>> Step 5: generating actions ...
|
|
||||||
>>> Step 5: interacting with world model ...
|
|
||||||
>>>>>>>>>>>>>>>>>>>>>>>>
|
|
||||||
>>> Step 6: generating actions ...
|
|
||||||
>>> Step 6: interacting with world model ...
|
|
||||||
>>>>>>>>>>>>>>>>>>>>>>>>
|
|
||||||
>>> Step 7: generating actions ...
|
|
||||||
>>> Step 7: interacting with world model ...
|
|
||||||
>>>>>>>>>>>>>>>>>>>>>>>>
|
|
||||||
>>> Step 8: generating actions ...
|
|
||||||
>>> Step 8: interacting with world model ...
|
|
||||||
>>>>>>>>>>>>>>>>>>>>>>>>
|
|
||||||
>>> Step 9: generating actions ...
|
|
||||||
>>> Step 9: interacting with world model ...
|
|
||||||
>>>>>>>>>>>>>>>>>>>>>>>>
|
|
||||||
>>> Step 10: generating actions ...
|
|
||||||
>>> Step 10: interacting with world model ...
|
|
||||||
>>>>>>>>>>>>>>>>>>>>>>>>
|
|
||||||
|
|
||||||
real 3m49.162s
|
|
||||||
user 4m12.814s
|
|
||||||
sys 0m45.565s
|
|
||||||
2026-02-18 19:09:35.113634: I tensorflow/core/util/port.cc:113] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
|
|
||||||
2026-02-18 19:09:35.161428: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
|
|
||||||
2026-02-18 19:09:35.161474: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
|
|
||||||
2026-02-18 19:09:35.162551: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
|
|
||||||
2026-02-18 19:09:35.169325: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
|
|
||||||
To enable the following instructions: AVX2 AVX512F AVX512_VNNI FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
|
|
||||||
2026-02-18 19:09:36.089250: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT
|
|
||||||
Global seed set to 123
|
|
||||||
>>> Loading prepared model from ckpts/unifolm_wma_dual.ckpt.prepared.pt ...
|
|
||||||
>>> Prepared model loaded.
|
|
||||||
INFO:root:***** Configing Data *****
|
|
||||||
>>> unitree_z1_stackbox: 1 data samples loaded.
|
|
||||||
>>> unitree_z1_stackbox: data stats loaded.
|
|
||||||
>>> unitree_z1_stackbox: normalizer initiated.
|
|
||||||
>>> unitree_z1_dual_arm_stackbox: 1 data samples loaded.
|
|
||||||
>>> unitree_z1_dual_arm_stackbox: data stats loaded.
|
|
||||||
>>> unitree_z1_dual_arm_stackbox: normalizer initiated.
|
|
||||||
>>> unitree_z1_dual_arm_stackbox_v2: 1 data samples loaded.
|
|
||||||
>>> unitree_z1_dual_arm_stackbox_v2: data stats loaded.
|
|
||||||
>>> unitree_z1_dual_arm_stackbox_v2: normalizer initiated.
|
|
||||||
>>> unitree_z1_dual_arm_cleanup_pencils: 1 data samples loaded.
|
|
||||||
>>> unitree_z1_dual_arm_cleanup_pencils: data stats loaded.
|
|
||||||
>>> unitree_z1_dual_arm_cleanup_pencils: normalizer initiated.
|
|
||||||
>>> unitree_g1_pack_camera: 1 data samples loaded.
|
|
||||||
>>> unitree_g1_pack_camera: data stats loaded.
|
|
||||||
>>> unitree_g1_pack_camera: normalizer initiated.
|
|
||||||
>>> Dataset is successfully loaded ...
|
|
||||||
✓ KV fused: 66 attention layers
|
|
||||||
TRT output 'y': [1, 4, 16, 40, 64] torch.float32
|
|
||||||
TRT output 'hs_a_0': [1, 16, 320, 40, 64] torch.float32
|
|
||||||
TRT output 'hs_a_1': [1, 16, 640, 20, 32] torch.float32
|
|
||||||
TRT output 'hs_a_2': [1, 16, 1280, 10, 16] torch.float32
|
|
||||||
TRT output 'hs_a_3': [1, 16, 1280, 5, 8] torch.float32
|
|
||||||
TRT output 'hs_a_4': [1, 16, 1280, 5, 8] torch.float32
|
|
||||||
TRT output 'hs_a_5': [1, 16, 1280, 5, 8] torch.float32
|
|
||||||
TRT output 'hs_a_6': [1, 16, 1280, 10, 16] torch.float32
|
|
||||||
TRT output 'hs_a_7': [1, 16, 640, 20, 32] torch.float32
|
|
||||||
TRT output 'hs_a_8': [1, 16, 320, 40, 64] torch.float32
|
|
||||||
>>> TRT backbone loaded from /home/qhy/unifolm-world-model-action/scripts/evaluation/../../trt_engines/video_backbone.engine
|
|
||||||
>>> Generate 16 frames under each generation ...
|
|
||||||
DEBUG:h5py._conv:Creating converter from 3 to 5
|
|
||||||
DEBUG:PIL.PngImagePlugin:STREAM b'IHDR' 16 13
|
|
||||||
DEBUG:PIL.PngImagePlugin:STREAM b'pHYs' 41 9
|
|
||||||
DEBUG:PIL.PngImagePlugin:STREAM b'IDAT' 62 4096
|
|
||||||
|
|
||||||
0%| | 0/11 [00:00<?, ?it/s][02/18/2026-19:09:49] [TRT] [W] Using default stream in enqueueV3() may lead to performance issues due to additional calls to cudaStreamSynchronize() by TensorRT to ensure correct synchronization. Please use non-default stream instead.
|
|
||||||
|
|
||||||
9%|▉ | 1/11 [00:16<02:45, 16.53s/it]>>> Step 0: generating actions ...
|
|
||||||
>>> Step 0: interacting with world model ...
|
|
||||||
>>>>>>>>>>>>>>>>>>>>>>>>
|
|
||||||
>>> Step 1: generating actions ...
|
|
||||||
DEBUG:PIL.Image:Importing BlpImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing BlpImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing BmpImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing BufrStubImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing BmpImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing BufrStubImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing CurImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing CurImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing DcxImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing DcxImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing DdsImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing DdsImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing EpsImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing EpsImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing FitsImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing FitsImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing FitsStubImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing FitsStubImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing FliImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing FliImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing FpxImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing FpxImagePlugin
|
|
||||||
DEBUG:PIL.Image:Image: failed to import FpxImagePlugin: No module named 'olefile'
|
|
||||||
DEBUG:PIL.Image:Importing FtexImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing FtexImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing GbrImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing GbrImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing GifImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing GribStubImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing GifImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing GribStubImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing Hdf5StubImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing Hdf5StubImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing IcnsImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing IcnsImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing IcoImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing IcoImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing ImImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing ImImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing ImtImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing ImtImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing IptcImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing IptcImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing JpegImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing Jpeg2KImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing McIdasImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing JpegImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing Jpeg2KImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing McIdasImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing MicImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing MicImagePlugin
|
|
||||||
DEBUG:PIL.Image:Image: failed to import MicImagePlugin: No module named 'olefile'
|
|
||||||
DEBUG:PIL.Image:Importing MpegImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing MpegImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing MpoImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing MpoImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing MspImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing MspImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing PalmImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing PalmImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing PcdImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing PcdImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing PcxImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing PdfImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing PcxImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing PdfImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing PixarImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing PixarImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing PngImagePlugin
|
|
||||||
@@ -1,61 +0,0 @@
|
|||||||
#!/bin/bash
|
|
||||||
set -e
|
|
||||||
|
|
||||||
SCRIPT_DIR="$(cd "$(dirname "$0")" && pwd)"
|
|
||||||
cd "$SCRIPT_DIR"
|
|
||||||
|
|
||||||
SCENARIOS=(
|
|
||||||
unitree_g1_pack_camera
|
|
||||||
unitree_z1_dual_arm_cleanup_pencils
|
|
||||||
unitree_z1_dual_arm_stackbox
|
|
||||||
unitree_z1_dual_arm_stackbox_v2
|
|
||||||
unitree_z1_stackbox
|
|
||||||
)
|
|
||||||
|
|
||||||
CASES=(case1 case2 case3 case4)
|
|
||||||
|
|
||||||
total=0
|
|
||||||
success=0
|
|
||||||
fail=0
|
|
||||||
|
|
||||||
for scenario in "${SCENARIOS[@]}"; do
|
|
||||||
for case in "${CASES[@]}"; do
|
|
||||||
case_dir="${scenario}/${case}"
|
|
||||||
gt_video="${case_dir}/${scenario}_${case}.mp4"
|
|
||||||
pred_video=$(ls "${case_dir}"/output/inference/*_full_fs*.mp4 2>/dev/null | head -1)
|
|
||||||
output_file="${case_dir}/psnr_result.json"
|
|
||||||
|
|
||||||
total=$((total + 1))
|
|
||||||
echo "=========================================="
|
|
||||||
echo "[${total}/20] ${case_dir}"
|
|
||||||
|
|
||||||
if [ ! -f "$gt_video" ]; then
|
|
||||||
echo " SKIP: GT video not found: $gt_video"
|
|
||||||
fail=$((fail + 1))
|
|
||||||
continue
|
|
||||||
fi
|
|
||||||
if [ -z "$pred_video" ]; then
|
|
||||||
echo " SKIP: pred video not found in ${case_dir}/output/inference/"
|
|
||||||
fail=$((fail + 1))
|
|
||||||
continue
|
|
||||||
fi
|
|
||||||
|
|
||||||
echo " GT: $gt_video"
|
|
||||||
echo " Pred: $pred_video"
|
|
||||||
echo " Out: $output_file"
|
|
||||||
|
|
||||||
if python3 psnr_score_for_challenge.py \
|
|
||||||
--gt_video "$gt_video" \
|
|
||||||
--pred_video "$pred_video" \
|
|
||||||
--output_file "$output_file"; then
|
|
||||||
success=$((success + 1))
|
|
||||||
echo " DONE"
|
|
||||||
else
|
|
||||||
fail=$((fail + 1))
|
|
||||||
echo " FAILED"
|
|
||||||
fi
|
|
||||||
done
|
|
||||||
done
|
|
||||||
|
|
||||||
echo "=========================================="
|
|
||||||
echo "Finished: ${success} success, ${fail} fail, ${total} total"
|
|
||||||
@@ -16,9 +16,6 @@ from collections import OrderedDict
|
|||||||
from unifolm_wma.models.samplers.ddim import DDIMSampler
|
from unifolm_wma.models.samplers.ddim import DDIMSampler
|
||||||
from unifolm_wma.utils.utils import instantiate_from_config
|
from unifolm_wma.utils.utils import instantiate_from_config
|
||||||
|
|
||||||
torch.backends.cuda.matmul.allow_tf32 = True
|
|
||||||
torch.backends.cudnn.allow_tf32 = True
|
|
||||||
|
|
||||||
|
|
||||||
def get_filelist(data_dir: str, postfixes: list[str]) -> list[str]:
|
def get_filelist(data_dir: str, postfixes: list[str]) -> list[str]:
|
||||||
"""
|
"""
|
||||||
|
|||||||
975
scripts/evaluation/profile_iteration.py
Normal file
975
scripts/evaluation/profile_iteration.py
Normal file
@@ -0,0 +1,975 @@
|
|||||||
|
"""
|
||||||
|
Profile the full iteration loop of world model interaction.
|
||||||
|
|
||||||
|
Three layers of profiling:
|
||||||
|
Layer 1: Iteration-level wall-clock breakdown (CUDA events)
|
||||||
|
Layer 2: GPU timeline trace (torch.profiler → Chrome trace)
|
||||||
|
Layer 3: A/B comparison (standardized CSV output)
|
||||||
|
|
||||||
|
Usage:
|
||||||
|
# Layer 1 only (fast, default):
|
||||||
|
TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL=1 CUDA_VISIBLE_DEVICES=0 \
|
||||||
|
python scripts/evaluation/profile_iteration.py \
|
||||||
|
--ckpt_path ckpts/unifolm_wma_dual_mix_bf16.ckpt \
|
||||||
|
--config configs/inference/world_model_interaction.yaml \
|
||||||
|
--prompt_dir unitree_z1_dual_arm_cleanup_pencils/case1/world_model_interaction_prompts \
|
||||||
|
--dataset unitree_z1_dual_arm_cleanup_pencils \
|
||||||
|
--frame_stride 4 --n_iter 5
|
||||||
|
|
||||||
|
# Layer 1 + Layer 2 (GPU trace):
|
||||||
|
... --trace --trace_dir ./profile_traces
|
||||||
|
|
||||||
|
# Layer 3 (A/B comparison): run twice, diff the CSVs
|
||||||
|
... --csv baseline.csv
|
||||||
|
... --csv optimized.csv
|
||||||
|
python scripts/evaluation/profile_iteration.py --compare baseline.csv optimized.csv
|
||||||
|
"""
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import csv
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
import time
|
||||||
|
from collections import defaultdict, deque
|
||||||
|
from contextlib import nullcontext
|
||||||
|
|
||||||
|
import h5py
|
||||||
|
import numpy as np
|
||||||
|
import pandas as pd
|
||||||
|
import torch
|
||||||
|
import torchvision
|
||||||
|
from einops import rearrange, repeat
|
||||||
|
from omegaconf import OmegaConf
|
||||||
|
from PIL import Image
|
||||||
|
from pytorch_lightning import seed_everything
|
||||||
|
from torch import Tensor
|
||||||
|
|
||||||
|
from unifolm_wma.models.samplers.ddim import DDIMSampler
|
||||||
|
from unifolm_wma.utils.utils import instantiate_from_config
|
||||||
|
import torch.nn.functional as F
|
||||||
|
|
||||||
|
|
||||||
|
# ──────────────────────────────────────────────────────────────────────
|
||||||
|
# Constants
|
||||||
|
# ──────────────────────────────────────────────────────────────────────
|
||||||
|
STAGE_NAMES = [
|
||||||
|
"stack_to_device_1",
|
||||||
|
"synth_policy",
|
||||||
|
"update_action_queue",
|
||||||
|
"stack_to_device_2",
|
||||||
|
"synth_world_model",
|
||||||
|
"update_obs_queue",
|
||||||
|
"tensorboard_log",
|
||||||
|
"save_results",
|
||||||
|
"cpu_transfer",
|
||||||
|
"itr_total",
|
||||||
|
]
|
||||||
|
|
||||||
|
# Sub-stages inside image_guided_synthesis_sim_mode
|
||||||
|
SYNTH_SUB_STAGES = [
|
||||||
|
"ddim_sampler_init",
|
||||||
|
"image_embedding",
|
||||||
|
"vae_encode",
|
||||||
|
"text_conditioning",
|
||||||
|
"projectors",
|
||||||
|
"cond_assembly",
|
||||||
|
"ddim_sampling",
|
||||||
|
"vae_decode",
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
# ──────────────────────────────────────────────────────────────────────
|
||||||
|
# CudaTimer — GPU-precise timing via CUDA events
|
||||||
|
# ──────────────────────────────────────────────────────────────────────
|
||||||
|
class CudaTimer:
|
||||||
|
"""Context manager that records GPU time between enter/exit using CUDA events."""
|
||||||
|
|
||||||
|
def __init__(self, name, records):
|
||||||
|
self.name = name
|
||||||
|
self.records = records
|
||||||
|
|
||||||
|
def __enter__(self):
|
||||||
|
torch.cuda.synchronize()
|
||||||
|
self._start = torch.cuda.Event(enable_timing=True)
|
||||||
|
self._end = torch.cuda.Event(enable_timing=True)
|
||||||
|
self._start.record()
|
||||||
|
return self
|
||||||
|
|
||||||
|
def __exit__(self, *args):
|
||||||
|
self._end.record()
|
||||||
|
torch.cuda.synchronize()
|
||||||
|
elapsed_ms = self._start.elapsed_time(self._end)
|
||||||
|
self.records[self.name].append(elapsed_ms)
|
||||||
|
|
||||||
|
|
||||||
|
class WallTimer:
|
||||||
|
"""Context manager that records CPU wall-clock time (for pure-CPU stages)."""
|
||||||
|
|
||||||
|
def __init__(self, name, records):
|
||||||
|
self.name = name
|
||||||
|
self.records = records
|
||||||
|
|
||||||
|
def __enter__(self):
|
||||||
|
torch.cuda.synchronize()
|
||||||
|
self._t0 = time.perf_counter()
|
||||||
|
return self
|
||||||
|
|
||||||
|
def __exit__(self, *args):
|
||||||
|
torch.cuda.synchronize()
|
||||||
|
elapsed_ms = (time.perf_counter() - self._t0) * 1000.0
|
||||||
|
self.records[self.name].append(elapsed_ms)
|
||||||
|
|
||||||
|
|
||||||
|
# ──────────────────────────────────────────────────────────────────────
|
||||||
|
# Model loading (reused from world_model_interaction.py)
|
||||||
|
# ──────────────────────────────────────────────────────────────────────
|
||||||
|
def patch_norm_bypass_autocast():
|
||||||
|
def _group_norm_forward(self, x):
|
||||||
|
with torch.amp.autocast('cuda', enabled=False):
|
||||||
|
return F.group_norm(
|
||||||
|
x, self.num_groups,
|
||||||
|
self.weight.to(x.dtype) if self.weight is not None else None,
|
||||||
|
self.bias.to(x.dtype) if self.bias is not None else None,
|
||||||
|
self.eps)
|
||||||
|
|
||||||
|
def _layer_norm_forward(self, x):
|
||||||
|
with torch.amp.autocast('cuda', enabled=False):
|
||||||
|
return F.layer_norm(
|
||||||
|
x, self.normalized_shape,
|
||||||
|
self.weight.to(x.dtype) if self.weight is not None else None,
|
||||||
|
self.bias.to(x.dtype) if self.bias is not None else None,
|
||||||
|
self.eps)
|
||||||
|
|
||||||
|
torch.nn.GroupNorm.forward = _group_norm_forward
|
||||||
|
torch.nn.LayerNorm.forward = _layer_norm_forward
|
||||||
|
|
||||||
|
|
||||||
|
def apply_torch_compile(model, hot_indices=(5, 8, 9)):
|
||||||
|
from unifolm_wma.modules.networks.wma_model import ResBlock
|
||||||
|
unet = model.model.diffusion_model
|
||||||
|
compiled = 0
|
||||||
|
for idx in hot_indices:
|
||||||
|
block = unet.output_blocks[idx]
|
||||||
|
for layer in block:
|
||||||
|
if isinstance(layer, ResBlock):
|
||||||
|
layer._forward = torch.compile(layer._forward, mode="default")
|
||||||
|
compiled += 1
|
||||||
|
print(f" torch.compile: {compiled} ResBlocks in output_blocks{list(hot_indices)}")
|
||||||
|
|
||||||
|
|
||||||
|
def load_model(args):
|
||||||
|
config = OmegaConf.load(args.config)
|
||||||
|
config['model']['params']['wma_config']['params']['use_checkpoint'] = False
|
||||||
|
model = instantiate_from_config(config.model)
|
||||||
|
model.perframe_ae = args.perframe_ae
|
||||||
|
|
||||||
|
from collections import OrderedDict
|
||||||
|
state_dict = torch.load(args.ckpt_path, map_location="cpu")
|
||||||
|
if "state_dict" in state_dict:
|
||||||
|
state_dict = state_dict["state_dict"]
|
||||||
|
try:
|
||||||
|
model.load_state_dict(state_dict, strict=True)
|
||||||
|
except Exception:
|
||||||
|
new_sd = OrderedDict()
|
||||||
|
for k, v in state_dict.items():
|
||||||
|
new_sd[k] = v
|
||||||
|
for k in list(new_sd.keys()):
|
||||||
|
if "framestride_embed" in k:
|
||||||
|
new_sd[k.replace("framestride_embed", "fps_embedding")] = new_sd.pop(k)
|
||||||
|
model.load_state_dict(new_sd, strict=True)
|
||||||
|
|
||||||
|
model.eval()
|
||||||
|
|
||||||
|
# Apply precision: bf16 diffusion + encoders + projectors, fp32/bf16 VAE
|
||||||
|
model.model.to(torch.bfloat16)
|
||||||
|
model.diffusion_autocast_dtype = torch.bfloat16
|
||||||
|
model.embedder.to(torch.bfloat16)
|
||||||
|
model.image_proj_model.to(torch.bfloat16)
|
||||||
|
model.encoder_autocast_dtype = None
|
||||||
|
model.state_projector.to(torch.bfloat16)
|
||||||
|
model.action_projector.to(torch.bfloat16)
|
||||||
|
model.projector_autocast_dtype = None
|
||||||
|
if args.vae_dtype == "bf16":
|
||||||
|
model.first_stage_model.to(torch.bfloat16)
|
||||||
|
|
||||||
|
# Compile hot ResBlocks
|
||||||
|
apply_torch_compile(model)
|
||||||
|
model = model.cuda()
|
||||||
|
print(">>> Model loaded and ready.")
|
||||||
|
return model, config
|
||||||
|
|
||||||
|
|
||||||
|
# ──────────────────────────────────────────────────────────────────────
|
||||||
|
# Data preparation (reused from world_model_interaction.py)
|
||||||
|
# ──────────────────────────────────────────────────────────────────────
|
||||||
|
def get_init_frame_path(data_dir, sample):
|
||||||
|
rel = os.path.join(sample['data_dir'], str(sample['videoid']) + '.png')
|
||||||
|
return os.path.join(data_dir, 'images', rel)
|
||||||
|
|
||||||
|
|
||||||
|
def get_transition_path(data_dir, sample):
|
||||||
|
rel = os.path.join(sample['data_dir'], str(sample['videoid']) + '.h5')
|
||||||
|
return os.path.join(data_dir, 'transitions', rel)
|
||||||
|
|
||||||
|
|
||||||
|
def prepare_init_input(start_idx, init_frame_path, transition_dict,
|
||||||
|
frame_stride, wma_data, video_length=16, n_obs_steps=2):
|
||||||
|
indices = [start_idx + frame_stride * i for i in range(video_length)]
|
||||||
|
init_frame = Image.open(init_frame_path).convert('RGB')
|
||||||
|
init_frame = torch.tensor(np.array(init_frame)).unsqueeze(0).permute(3, 0, 1, 2).float()
|
||||||
|
|
||||||
|
if start_idx < n_obs_steps - 1:
|
||||||
|
state_indices = list(range(0, start_idx + 1))
|
||||||
|
states = transition_dict['observation.state'][state_indices, :]
|
||||||
|
num_padding = n_obs_steps - 1 - start_idx
|
||||||
|
padding = states[0:1, :].repeat(num_padding, 1)
|
||||||
|
states = torch.cat((padding, states), dim=0)
|
||||||
|
else:
|
||||||
|
state_indices = list(range(start_idx - n_obs_steps + 1, start_idx + 1))
|
||||||
|
states = transition_dict['observation.state'][state_indices, :]
|
||||||
|
|
||||||
|
actions = transition_dict['action'][indices, :]
|
||||||
|
ori_state_dim = states.shape[-1]
|
||||||
|
ori_action_dim = actions.shape[-1]
|
||||||
|
|
||||||
|
frames_action_state_dict = {
|
||||||
|
'action': actions,
|
||||||
|
'observation.state': states,
|
||||||
|
}
|
||||||
|
frames_action_state_dict = wma_data.normalizer(frames_action_state_dict)
|
||||||
|
frames_action_state_dict = wma_data.get_uni_vec(
|
||||||
|
frames_action_state_dict,
|
||||||
|
transition_dict['action_type'],
|
||||||
|
transition_dict['state_type'],
|
||||||
|
)
|
||||||
|
|
||||||
|
if wma_data.spatial_transform is not None:
|
||||||
|
init_frame = wma_data.spatial_transform(init_frame)
|
||||||
|
init_frame = (init_frame / 255 - 0.5) * 2
|
||||||
|
|
||||||
|
data = {'observation.image': init_frame}
|
||||||
|
data.update(frames_action_state_dict)
|
||||||
|
return data, ori_state_dim, ori_action_dim
|
||||||
|
|
||||||
|
|
||||||
|
def populate_queues(queues, batch):
|
||||||
|
for key in batch:
|
||||||
|
if key not in queues:
|
||||||
|
continue
|
||||||
|
if len(queues[key]) != queues[key].maxlen:
|
||||||
|
while len(queues[key]) != queues[key].maxlen:
|
||||||
|
queues[key].append(batch[key])
|
||||||
|
else:
|
||||||
|
queues[key].append(batch[key])
|
||||||
|
return queues
|
||||||
|
|
||||||
|
|
||||||
|
# ──────────────────────────────────────────────────────────────────────
|
||||||
|
# Instrumented image_guided_synthesis_sim_mode with sub-stage timing
|
||||||
|
# ──────────────────────────────────────────────────────────────────────
|
||||||
|
def get_latent_z(model, videos):
|
||||||
|
b, c, t, h, w = videos.shape
|
||||||
|
x = rearrange(videos, 'b c t h w -> (b t) c h w')
|
||||||
|
vae_dtype = next(model.first_stage_model.parameters()).dtype
|
||||||
|
x = x.to(dtype=vae_dtype)
|
||||||
|
z = model.encode_first_stage(x)
|
||||||
|
z = rearrange(z, '(b t) c h w -> b c t h w', b=b, t=t)
|
||||||
|
return z
|
||||||
|
|
||||||
|
|
||||||
|
def save_results(video, filename, fps=8):
|
||||||
|
video = video.detach().cpu()
|
||||||
|
video = torch.clamp(video.float(), -1., 1.)
|
||||||
|
n = video.shape[0]
|
||||||
|
video = video.permute(2, 0, 1, 3, 4)
|
||||||
|
frame_grids = [
|
||||||
|
torchvision.utils.make_grid(framesheet, nrow=int(n), padding=0)
|
||||||
|
for framesheet in video
|
||||||
|
]
|
||||||
|
grid = torch.stack(frame_grids, dim=0)
|
||||||
|
grid = (grid + 1.0) / 2.0
|
||||||
|
grid = (grid * 255).to(torch.uint8).permute(0, 2, 3, 1)
|
||||||
|
torchvision.io.write_video(filename, grid, fps=fps,
|
||||||
|
video_codec='h264', options={'crf': '10'})
|
||||||
|
|
||||||
|
|
||||||
|
def profiled_synthesis(model, prompts, observation, noise_shape,
|
||||||
|
ddim_steps, ddim_eta, unconditional_guidance_scale,
|
||||||
|
fs, text_input, timestep_spacing, guidance_rescale,
|
||||||
|
sim_mode, decode_video, records, prefix):
|
||||||
|
"""image_guided_synthesis_sim_mode with per-sub-stage CUDA event timing.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
prefix: "policy" or "wm" — prepended to sub-stage names in records.
|
||||||
|
"""
|
||||||
|
b, _, t, _, _ = noise_shape
|
||||||
|
batch_size = noise_shape[0]
|
||||||
|
device = next(model.parameters()).device
|
||||||
|
|
||||||
|
# --- sub-stage: ddim_sampler_init ---
|
||||||
|
with CudaTimer(f"{prefix}/ddim_sampler_init", records):
|
||||||
|
ddim_sampler = DDIMSampler(model)
|
||||||
|
fs_t = torch.tensor([fs] * batch_size, dtype=torch.long, device=device)
|
||||||
|
|
||||||
|
# --- sub-stage: image_embedding ---
|
||||||
|
with CudaTimer(f"{prefix}/image_embedding", records):
|
||||||
|
model_dtype = next(model.embedder.parameters()).dtype
|
||||||
|
img = observation['observation.images.top'].permute(0, 2, 1, 3, 4)
|
||||||
|
cond_img = rearrange(img, 'b o c h w -> (b o) c h w')[-1:].to(dtype=model_dtype)
|
||||||
|
cond_img_emb = model.embedder(cond_img)
|
||||||
|
cond_img_emb = model.image_proj_model(cond_img_emb)
|
||||||
|
|
||||||
|
# --- sub-stage: vae_encode ---
|
||||||
|
with CudaTimer(f"{prefix}/vae_encode", records):
|
||||||
|
if model.model.conditioning_key == 'hybrid':
|
||||||
|
z = get_latent_z(model, img.permute(0, 2, 1, 3, 4))
|
||||||
|
img_cat_cond = z[:, :, -1:, :, :]
|
||||||
|
img_cat_cond = repeat(img_cat_cond,
|
||||||
|
'b c t h w -> b c (repeat t) h w',
|
||||||
|
repeat=noise_shape[2])
|
||||||
|
cond = {"c_concat": [img_cat_cond]}
|
||||||
|
|
||||||
|
# --- sub-stage: text_conditioning ---
|
||||||
|
with CudaTimer(f"{prefix}/text_conditioning", records):
|
||||||
|
if not text_input:
|
||||||
|
prompts_use = [""] * batch_size
|
||||||
|
else:
|
||||||
|
prompts_use = prompts if isinstance(prompts, list) else [prompts] * batch_size
|
||||||
|
cond_ins_emb = model.get_learned_conditioning(prompts_use)
|
||||||
|
|
||||||
|
# --- sub-stage: projectors ---
|
||||||
|
with CudaTimer(f"{prefix}/projectors", records):
|
||||||
|
projector_dtype = next(model.state_projector.parameters()).dtype
|
||||||
|
cond_state_emb = model.state_projector(
|
||||||
|
observation['observation.state'].to(dtype=projector_dtype))
|
||||||
|
cond_state_emb = cond_state_emb + model.agent_state_pos_emb
|
||||||
|
cond_action_emb = model.action_projector(
|
||||||
|
observation['action'].to(dtype=projector_dtype))
|
||||||
|
cond_action_emb = cond_action_emb + model.agent_action_pos_emb
|
||||||
|
if not sim_mode:
|
||||||
|
cond_action_emb = torch.zeros_like(cond_action_emb)
|
||||||
|
|
||||||
|
# --- sub-stage: cond_assembly ---
|
||||||
|
with CudaTimer(f"{prefix}/cond_assembly", records):
|
||||||
|
cond["c_crossattn"] = [
|
||||||
|
torch.cat([cond_state_emb, cond_action_emb, cond_ins_emb, cond_img_emb], dim=1)
|
||||||
|
]
|
||||||
|
cond["c_crossattn_action"] = [
|
||||||
|
observation['observation.images.top'][:, :, -model.n_obs_steps_acting:],
|
||||||
|
observation['observation.state'][:, -model.n_obs_steps_acting:],
|
||||||
|
sim_mode,
|
||||||
|
False,
|
||||||
|
]
|
||||||
|
|
||||||
|
# --- sub-stage: ddim_sampling ---
|
||||||
|
autocast_dtype = getattr(model, 'diffusion_autocast_dtype', None)
|
||||||
|
if autocast_dtype is not None and device.type == 'cuda':
|
||||||
|
autocast_ctx = torch.autocast('cuda', dtype=autocast_dtype)
|
||||||
|
else:
|
||||||
|
autocast_ctx = nullcontext()
|
||||||
|
|
||||||
|
with CudaTimer(f"{prefix}/ddim_sampling", records):
|
||||||
|
with autocast_ctx:
|
||||||
|
samples, actions, states, _ = ddim_sampler.sample(
|
||||||
|
S=ddim_steps,
|
||||||
|
conditioning=cond,
|
||||||
|
batch_size=batch_size,
|
||||||
|
shape=noise_shape[1:],
|
||||||
|
verbose=False,
|
||||||
|
unconditional_guidance_scale=unconditional_guidance_scale,
|
||||||
|
unconditional_conditioning=None,
|
||||||
|
eta=ddim_eta,
|
||||||
|
cfg_img=None,
|
||||||
|
mask=None,
|
||||||
|
x0=None,
|
||||||
|
fs=fs_t,
|
||||||
|
timestep_spacing=timestep_spacing,
|
||||||
|
guidance_rescale=guidance_rescale,
|
||||||
|
unconditional_conditioning_img_nonetext=None,
|
||||||
|
)
|
||||||
|
|
||||||
|
# --- sub-stage: vae_decode ---
|
||||||
|
batch_variants = None
|
||||||
|
if decode_video:
|
||||||
|
with CudaTimer(f"{prefix}/vae_decode", records):
|
||||||
|
batch_variants = model.decode_first_stage(samples)
|
||||||
|
else:
|
||||||
|
records[f"{prefix}/vae_decode"].append(0.0)
|
||||||
|
|
||||||
|
return batch_variants, actions, states
|
||||||
|
|
||||||
|
|
||||||
|
# ──────────────────────────────────────────────────────────────────────
|
||||||
|
# Instrumented iteration loop
|
||||||
|
# ──────────────────────────────────────────────────────────────────────
|
||||||
|
def run_profiled_iterations(model, args, config, noise_shape, device):
|
||||||
|
"""Run the full iteration loop with per-stage timing.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
all_records: list of dicts, one per itr, {stage_name: ms}
|
||||||
|
"""
|
||||||
|
# Load data
|
||||||
|
csv_path = os.path.join(args.prompt_dir, f"{args.dataset}.csv")
|
||||||
|
df = pd.read_csv(csv_path)
|
||||||
|
sample = df.iloc[0]
|
||||||
|
|
||||||
|
data_module = instantiate_from_config(config.data)
|
||||||
|
data_module.setup()
|
||||||
|
|
||||||
|
init_frame_path = get_init_frame_path(args.prompt_dir, sample)
|
||||||
|
ori_fps = float(sample['fps'])
|
||||||
|
fs = args.frame_stride
|
||||||
|
model_input_fs = ori_fps // fs
|
||||||
|
|
||||||
|
transition_path = get_transition_path(args.prompt_dir, sample)
|
||||||
|
with h5py.File(transition_path, 'r') as h5f:
|
||||||
|
transition_dict = {}
|
||||||
|
for key in h5f.keys():
|
||||||
|
transition_dict[key] = torch.tensor(h5f[key][()])
|
||||||
|
for key in h5f.attrs.keys():
|
||||||
|
transition_dict[key] = h5f.attrs[key]
|
||||||
|
|
||||||
|
# Prepare initial observation
|
||||||
|
batch, ori_state_dim, ori_action_dim = prepare_init_input(
|
||||||
|
0, init_frame_path, transition_dict, fs,
|
||||||
|
data_module.test_datasets[args.dataset],
|
||||||
|
n_obs_steps=model.n_obs_steps_imagen)
|
||||||
|
|
||||||
|
observation = {
|
||||||
|
'observation.images.top':
|
||||||
|
batch['observation.image'].permute(1, 0, 2, 3)[-1].unsqueeze(0),
|
||||||
|
'observation.state':
|
||||||
|
batch['observation.state'][-1].unsqueeze(0),
|
||||||
|
'action':
|
||||||
|
torch.zeros_like(batch['action'][-1]).unsqueeze(0),
|
||||||
|
}
|
||||||
|
observation = {k: v.to(device, non_blocking=True) for k, v in observation.items()}
|
||||||
|
|
||||||
|
cond_obs_queues = {
|
||||||
|
"observation.images.top": deque(maxlen=model.n_obs_steps_imagen),
|
||||||
|
"observation.state": deque(maxlen=model.n_obs_steps_imagen),
|
||||||
|
"action": deque(maxlen=args.video_length),
|
||||||
|
}
|
||||||
|
cond_obs_queues = populate_queues(cond_obs_queues, observation)
|
||||||
|
|
||||||
|
# Temp dir for save_results profiling
|
||||||
|
tmp_dir = os.path.join(args.savedir, "profile_tmp")
|
||||||
|
os.makedirs(tmp_dir, exist_ok=True)
|
||||||
|
|
||||||
|
prompt_text = sample['instruction']
|
||||||
|
all_records = []
|
||||||
|
|
||||||
|
print(f">>> Running {args.n_iter} profiled iterations ...")
|
||||||
|
for itr in range(args.n_iter):
|
||||||
|
rec = defaultdict(list)
|
||||||
|
|
||||||
|
# ── itr_total start ──
|
||||||
|
torch.cuda.synchronize()
|
||||||
|
itr_start = torch.cuda.Event(enable_timing=True)
|
||||||
|
itr_end = torch.cuda.Event(enable_timing=True)
|
||||||
|
itr_start.record()
|
||||||
|
|
||||||
|
# ① stack_to_device_1
|
||||||
|
with CudaTimer("stack_to_device_1", rec):
|
||||||
|
observation = {
|
||||||
|
'observation.images.top':
|
||||||
|
torch.stack(list(cond_obs_queues['observation.images.top']),
|
||||||
|
dim=1).permute(0, 2, 1, 3, 4),
|
||||||
|
'observation.state':
|
||||||
|
torch.stack(list(cond_obs_queues['observation.state']), dim=1),
|
||||||
|
'action':
|
||||||
|
torch.stack(list(cond_obs_queues['action']), dim=1),
|
||||||
|
}
|
||||||
|
observation = {k: v.to(device, non_blocking=True) for k, v in observation.items()}
|
||||||
|
|
||||||
|
# ② synth_policy
|
||||||
|
with CudaTimer("synth_policy", rec):
|
||||||
|
pred_videos_0, pred_actions, _ = profiled_synthesis(
|
||||||
|
model, prompt_text, observation, noise_shape,
|
||||||
|
ddim_steps=args.ddim_steps, ddim_eta=args.ddim_eta,
|
||||||
|
unconditional_guidance_scale=args.unconditional_guidance_scale,
|
||||||
|
fs=model_input_fs, text_input=True,
|
||||||
|
timestep_spacing=args.timestep_spacing,
|
||||||
|
guidance_rescale=args.guidance_rescale,
|
||||||
|
sim_mode=False,
|
||||||
|
decode_video=not args.fast_policy_no_decode,
|
||||||
|
records=rec, prefix="policy")
|
||||||
|
|
||||||
|
# ③ update_action_queue
|
||||||
|
with WallTimer("update_action_queue", rec):
|
||||||
|
for idx in range(len(pred_actions[0])):
|
||||||
|
obs_a = {'action': pred_actions[0][idx:idx + 1]}
|
||||||
|
obs_a['action'][:, ori_action_dim:] = 0.0
|
||||||
|
cond_obs_queues = populate_queues(cond_obs_queues, obs_a)
|
||||||
|
|
||||||
|
# ④ stack_to_device_2
|
||||||
|
with CudaTimer("stack_to_device_2", rec):
|
||||||
|
observation = {
|
||||||
|
'observation.images.top':
|
||||||
|
torch.stack(list(cond_obs_queues['observation.images.top']),
|
||||||
|
dim=1).permute(0, 2, 1, 3, 4),
|
||||||
|
'observation.state':
|
||||||
|
torch.stack(list(cond_obs_queues['observation.state']), dim=1),
|
||||||
|
'action':
|
||||||
|
torch.stack(list(cond_obs_queues['action']), dim=1),
|
||||||
|
}
|
||||||
|
observation = {k: v.to(device, non_blocking=True) for k, v in observation.items()}
|
||||||
|
|
||||||
|
# ⑤ synth_world_model
|
||||||
|
with CudaTimer("synth_world_model", rec):
|
||||||
|
pred_videos_1, _, pred_states = profiled_synthesis(
|
||||||
|
model, "", observation, noise_shape,
|
||||||
|
ddim_steps=args.ddim_steps, ddim_eta=args.ddim_eta,
|
||||||
|
unconditional_guidance_scale=args.unconditional_guidance_scale,
|
||||||
|
fs=model_input_fs, text_input=False,
|
||||||
|
timestep_spacing=args.timestep_spacing,
|
||||||
|
guidance_rescale=args.guidance_rescale,
|
||||||
|
sim_mode=True, decode_video=True,
|
||||||
|
records=rec, prefix="wm")
|
||||||
|
|
||||||
|
# ⑥ update_obs_queue
|
||||||
|
with WallTimer("update_obs_queue", rec):
|
||||||
|
for idx in range(args.exe_steps):
|
||||||
|
obs_u = {
|
||||||
|
'observation.images.top':
|
||||||
|
pred_videos_1[0][:, idx:idx + 1].permute(1, 0, 2, 3),
|
||||||
|
'observation.state':
|
||||||
|
pred_states[0][idx:idx + 1],
|
||||||
|
'action':
|
||||||
|
torch.zeros_like(pred_actions[0][-1:]),
|
||||||
|
}
|
||||||
|
obs_u['observation.state'][:, ori_state_dim:] = 0.0
|
||||||
|
cond_obs_queues = populate_queues(cond_obs_queues, obs_u)
|
||||||
|
|
||||||
|
# ⑦ tensorboard_log (simulate — no actual writer, measure make_grid cost)
|
||||||
|
with WallTimer("tensorboard_log", rec):
|
||||||
|
for vid in [pred_videos_0, pred_videos_1]:
|
||||||
|
if vid is not None and vid.dim() == 5:
|
||||||
|
v = vid.permute(2, 0, 1, 3, 4)
|
||||||
|
grids = [torchvision.utils.make_grid(f, nrow=1, padding=0) for f in v]
|
||||||
|
_ = torch.stack(grids, dim=0)
|
||||||
|
|
||||||
|
# ⑧ save_results
|
||||||
|
with WallTimer("save_results", rec):
|
||||||
|
if pred_videos_0 is not None:
|
||||||
|
save_results(pred_videos_0.cpu(),
|
||||||
|
os.path.join(tmp_dir, f"dm_{itr}.mp4"),
|
||||||
|
fps=args.save_fps)
|
||||||
|
save_results(pred_videos_1.cpu(),
|
||||||
|
os.path.join(tmp_dir, f"wm_{itr}.mp4"),
|
||||||
|
fps=args.save_fps)
|
||||||
|
|
||||||
|
# ⑨ cpu_transfer
|
||||||
|
with CudaTimer("cpu_transfer", rec):
|
||||||
|
_ = pred_videos_1[:, :, :args.exe_steps].cpu()
|
||||||
|
|
||||||
|
# ── itr_total end ──
|
||||||
|
itr_end.record()
|
||||||
|
torch.cuda.synchronize()
|
||||||
|
itr_total_ms = itr_start.elapsed_time(itr_end)
|
||||||
|
rec["itr_total"].append(itr_total_ms)
|
||||||
|
|
||||||
|
# Flatten: each stage has exactly one entry per itr
|
||||||
|
itr_rec = {k: v[0] for k, v in rec.items()}
|
||||||
|
all_records.append(itr_rec)
|
||||||
|
|
||||||
|
# Print live progress
|
||||||
|
print(f" itr {itr}: {itr_total_ms:.0f} ms total | "
|
||||||
|
f"policy={itr_rec.get('synth_policy', 0):.0f} | "
|
||||||
|
f"wm={itr_rec.get('synth_world_model', 0):.0f} | "
|
||||||
|
f"save={itr_rec.get('save_results', 0):.0f} | "
|
||||||
|
f"tb={itr_rec.get('tensorboard_log', 0):.0f}")
|
||||||
|
|
||||||
|
return all_records
|
||||||
|
|
||||||
|
|
||||||
|
# ──────────────────────────────────────────────────────────────────────
|
||||||
|
# Layer 1: Console report
|
||||||
|
# ──────────────────────────────────────────────────────────────────────
|
||||||
|
def print_iteration_report(all_records, warmup=1):
|
||||||
|
"""Print a structured table of per-stage timing across iterations."""
|
||||||
|
if len(all_records) <= warmup:
|
||||||
|
records = all_records
|
||||||
|
else:
|
||||||
|
records = all_records[warmup:]
|
||||||
|
print(f"\n(Skipping first {warmup} itr(s) as warmup)\n")
|
||||||
|
|
||||||
|
# Collect all stage keys in a stable order
|
||||||
|
all_keys = []
|
||||||
|
seen = set()
|
||||||
|
for rec in records:
|
||||||
|
for k in rec:
|
||||||
|
if k not in seen:
|
||||||
|
all_keys.append(k)
|
||||||
|
seen.add(k)
|
||||||
|
|
||||||
|
# Separate top-level stages from sub-stages
|
||||||
|
top_keys = [k for k in all_keys if '/' not in k]
|
||||||
|
sub_keys = [k for k in all_keys if '/' in k]
|
||||||
|
|
||||||
|
def _print_table(keys, title):
|
||||||
|
if not keys:
|
||||||
|
return
|
||||||
|
print("=" * 82)
|
||||||
|
print(title)
|
||||||
|
print("=" * 82)
|
||||||
|
print(f"{'Stage':<35} {'Mean(ms)':>10} {'Std':>8} {'Min':>10} {'Max':>10} {'%':>7}")
|
||||||
|
print("-" * 82)
|
||||||
|
|
||||||
|
total_mean = np.mean([rec.get("itr_total", 0) for rec in records])
|
||||||
|
for k in keys:
|
||||||
|
vals = [rec.get(k, 0) for rec in records]
|
||||||
|
mean = np.mean(vals)
|
||||||
|
std = np.std(vals)
|
||||||
|
mn = np.min(vals)
|
||||||
|
mx = np.max(vals)
|
||||||
|
pct = mean / total_mean * 100 if total_mean > 0 else 0
|
||||||
|
print(f"{k:<35} {mean:>10.1f} {std:>8.1f} {mn:>10.1f} {mx:>10.1f} {pct:>6.1f}%")
|
||||||
|
print("-" * 82)
|
||||||
|
print()
|
||||||
|
|
||||||
|
_print_table(top_keys, "TABLE 1: ITERATION-LEVEL BREAKDOWN")
|
||||||
|
_print_table(sub_keys, "TABLE 2: SYNTHESIS SUB-STAGE BREAKDOWN")
|
||||||
|
|
||||||
|
|
||||||
|
# ──────────────────────────────────────────────────────────────────────
|
||||||
|
# Layer 3: CSV output for A/B comparison
|
||||||
|
# ──────────────────────────────────────────────────────────────────────
|
||||||
|
def write_csv(all_records, csv_path, warmup=1):
|
||||||
|
"""Write per-iteration timing to CSV for later comparison."""
|
||||||
|
records = all_records[warmup:] if len(all_records) > warmup else all_records
|
||||||
|
|
||||||
|
# Collect all keys
|
||||||
|
all_keys = []
|
||||||
|
seen = set()
|
||||||
|
for rec in records:
|
||||||
|
for k in rec:
|
||||||
|
if k not in seen:
|
||||||
|
all_keys.append(k)
|
||||||
|
seen.add(k)
|
||||||
|
|
||||||
|
with open(csv_path, 'w', newline='') as f:
|
||||||
|
writer = csv.DictWriter(f, fieldnames=['itr'] + all_keys)
|
||||||
|
writer.writeheader()
|
||||||
|
for i, rec in enumerate(records):
|
||||||
|
row = {'itr': i}
|
||||||
|
row.update({k: f"{rec.get(k, 0):.2f}" for k in all_keys})
|
||||||
|
writer.writerow(row)
|
||||||
|
|
||||||
|
# Also write a summary row
|
||||||
|
summary_path = csv_path.replace('.csv', '_summary.csv')
|
||||||
|
with open(summary_path, 'w', newline='') as f:
|
||||||
|
writer = csv.DictWriter(f, fieldnames=['stat'] + all_keys)
|
||||||
|
writer.writeheader()
|
||||||
|
for stat_name, stat_fn in [('mean', np.mean), ('std', np.std),
|
||||||
|
('min', np.min), ('max', np.max)]:
|
||||||
|
row = {'stat': stat_name}
|
||||||
|
row.update({k: f"{stat_fn([r.get(k, 0) for r in records]):.2f}"
|
||||||
|
for k in all_keys})
|
||||||
|
writer.writerow(row)
|
||||||
|
|
||||||
|
print(f">>> CSV written to: {csv_path}")
|
||||||
|
print(f">>> Summary written to: {summary_path}")
|
||||||
|
|
||||||
|
|
||||||
|
def compare_csvs(path_a, path_b):
|
||||||
|
"""Compare two summary CSVs and print a diff table."""
|
||||||
|
df_a = pd.read_csv(path_a, index_col='stat')
|
||||||
|
df_b = pd.read_csv(path_b, index_col='stat')
|
||||||
|
|
||||||
|
# Use mean row for comparison
|
||||||
|
mean_a = df_a.loc['mean'].astype(float)
|
||||||
|
mean_b = df_b.loc['mean'].astype(float)
|
||||||
|
|
||||||
|
print("=" * 90)
|
||||||
|
print(f"A/B COMPARISON: {os.path.basename(path_a)} vs {os.path.basename(path_b)}")
|
||||||
|
print("=" * 90)
|
||||||
|
print(f"{'Stage':<35} {'A(ms)':>10} {'B(ms)':>10} {'Diff':>10} {'Speedup':>10}")
|
||||||
|
print("-" * 90)
|
||||||
|
|
||||||
|
for col in mean_a.index:
|
||||||
|
if col not in mean_b.index:
|
||||||
|
continue
|
||||||
|
a_val = mean_a[col]
|
||||||
|
b_val = mean_b[col]
|
||||||
|
diff = b_val - a_val
|
||||||
|
speedup = a_val / b_val if b_val > 0 else float('inf')
|
||||||
|
marker = " <<<" if abs(diff) > 50 else ""
|
||||||
|
print(f"{col:<35} {a_val:>10.1f} {b_val:>10.1f} {diff:>+10.1f} {speedup:>9.2f}x{marker}")
|
||||||
|
|
||||||
|
print("-" * 90)
|
||||||
|
total_a = mean_a.get('itr_total', 0)
|
||||||
|
total_b = mean_b.get('itr_total', 0)
|
||||||
|
print(f"{'itr_total':<35} {total_a:>10.1f} {total_b:>10.1f} "
|
||||||
|
f"{total_b - total_a:>+10.1f} {total_a / total_b if total_b > 0 else 0:>9.2f}x")
|
||||||
|
print()
|
||||||
|
|
||||||
|
|
||||||
|
# ──────────────────────────────────────────────────────────────────────
|
||||||
|
# Layer 2: GPU timeline trace wrapper
|
||||||
|
# ──────────────────────────────────────────────────────────────────────
|
||||||
|
def run_with_trace(model, args, config, noise_shape, device):
|
||||||
|
"""Run iterations under torch.profiler to generate Chrome/TensorBoard traces."""
|
||||||
|
trace_dir = args.trace_dir
|
||||||
|
os.makedirs(trace_dir, exist_ok=True)
|
||||||
|
|
||||||
|
# We need the same data setup as run_profiled_iterations
|
||||||
|
csv_path = os.path.join(args.prompt_dir, f"{args.dataset}.csv")
|
||||||
|
df = pd.read_csv(csv_path)
|
||||||
|
sample = df.iloc[0]
|
||||||
|
|
||||||
|
data_module = instantiate_from_config(config.data)
|
||||||
|
data_module.setup()
|
||||||
|
|
||||||
|
init_frame_path = get_init_frame_path(args.prompt_dir, sample)
|
||||||
|
ori_fps = float(sample['fps'])
|
||||||
|
fs = args.frame_stride
|
||||||
|
model_input_fs = ori_fps // fs
|
||||||
|
|
||||||
|
transition_path = get_transition_path(args.prompt_dir, sample)
|
||||||
|
with h5py.File(transition_path, 'r') as h5f:
|
||||||
|
transition_dict = {}
|
||||||
|
for key in h5f.keys():
|
||||||
|
transition_dict[key] = torch.tensor(h5f[key][()])
|
||||||
|
for key in h5f.attrs.keys():
|
||||||
|
transition_dict[key] = h5f.attrs[key]
|
||||||
|
|
||||||
|
batch, ori_state_dim, ori_action_dim = prepare_init_input(
|
||||||
|
0, init_frame_path, transition_dict, fs,
|
||||||
|
data_module.test_datasets[args.dataset],
|
||||||
|
n_obs_steps=model.n_obs_steps_imagen)
|
||||||
|
|
||||||
|
observation = {
|
||||||
|
'observation.images.top':
|
||||||
|
batch['observation.image'].permute(1, 0, 2, 3)[-1].unsqueeze(0),
|
||||||
|
'observation.state':
|
||||||
|
batch['observation.state'][-1].unsqueeze(0),
|
||||||
|
'action':
|
||||||
|
torch.zeros_like(batch['action'][-1]).unsqueeze(0),
|
||||||
|
}
|
||||||
|
observation = {k: v.to(device, non_blocking=True) for k, v in observation.items()}
|
||||||
|
|
||||||
|
cond_obs_queues = {
|
||||||
|
"observation.images.top": deque(maxlen=model.n_obs_steps_imagen),
|
||||||
|
"observation.state": deque(maxlen=model.n_obs_steps_imagen),
|
||||||
|
"action": deque(maxlen=args.video_length),
|
||||||
|
}
|
||||||
|
cond_obs_queues = populate_queues(cond_obs_queues, observation)
|
||||||
|
|
||||||
|
tmp_dir = os.path.join(args.savedir, "profile_tmp")
|
||||||
|
os.makedirs(tmp_dir, exist_ok=True)
|
||||||
|
prompt_text = sample['instruction']
|
||||||
|
|
||||||
|
# Total iterations: warmup + active
|
||||||
|
n_warmup = 1
|
||||||
|
n_active = min(args.n_iter, 2) # trace 2 active iterations max
|
||||||
|
n_total = n_warmup + n_active
|
||||||
|
|
||||||
|
print(f">>> GPU trace: {n_warmup} warmup + {n_active} active iterations")
|
||||||
|
print(f">>> Trace output: {trace_dir}")
|
||||||
|
|
||||||
|
with torch.no_grad(), torch.profiler.profile(
|
||||||
|
activities=[
|
||||||
|
torch.profiler.ProfilerActivity.CPU,
|
||||||
|
torch.profiler.ProfilerActivity.CUDA,
|
||||||
|
],
|
||||||
|
schedule=torch.profiler.schedule(
|
||||||
|
wait=0, warmup=n_warmup, active=n_active, repeat=1),
|
||||||
|
on_trace_ready=torch.profiler.tensorboard_trace_handler(trace_dir),
|
||||||
|
record_shapes=True,
|
||||||
|
with_stack=True,
|
||||||
|
) as prof:
|
||||||
|
for itr_idx in range(n_total):
|
||||||
|
phase = "warmup" if itr_idx < n_warmup else "active"
|
||||||
|
print(f" trace itr {itr_idx} ({phase})...")
|
||||||
|
|
||||||
|
# ── One full iteration (same logic as run_inference) ──
|
||||||
|
obs_loc = {
|
||||||
|
'observation.images.top':
|
||||||
|
torch.stack(list(cond_obs_queues['observation.images.top']),
|
||||||
|
dim=1).permute(0, 2, 1, 3, 4),
|
||||||
|
'observation.state':
|
||||||
|
torch.stack(list(cond_obs_queues['observation.state']), dim=1),
|
||||||
|
'action':
|
||||||
|
torch.stack(list(cond_obs_queues['action']), dim=1),
|
||||||
|
}
|
||||||
|
obs_loc = {k: v.to(device) for k, v in obs_loc.items()}
|
||||||
|
|
||||||
|
# Policy pass
|
||||||
|
dummy_rec = defaultdict(list)
|
||||||
|
pv0, pa, _ = profiled_synthesis(
|
||||||
|
model, prompt_text, obs_loc, noise_shape,
|
||||||
|
ddim_steps=args.ddim_steps, ddim_eta=args.ddim_eta,
|
||||||
|
unconditional_guidance_scale=args.unconditional_guidance_scale,
|
||||||
|
fs=model_input_fs, text_input=True,
|
||||||
|
timestep_spacing=args.timestep_spacing,
|
||||||
|
guidance_rescale=args.guidance_rescale,
|
||||||
|
sim_mode=False,
|
||||||
|
decode_video=not args.fast_policy_no_decode,
|
||||||
|
records=dummy_rec, prefix="policy")
|
||||||
|
|
||||||
|
for idx in range(len(pa[0])):
|
||||||
|
oa = {'action': pa[0][idx:idx + 1]}
|
||||||
|
oa['action'][:, ori_action_dim:] = 0.0
|
||||||
|
populate_queues(cond_obs_queues, oa)
|
||||||
|
|
||||||
|
# Re-stack for world model
|
||||||
|
obs_loc2 = {
|
||||||
|
'observation.images.top':
|
||||||
|
torch.stack(list(cond_obs_queues['observation.images.top']),
|
||||||
|
dim=1).permute(0, 2, 1, 3, 4),
|
||||||
|
'observation.state':
|
||||||
|
torch.stack(list(cond_obs_queues['observation.state']), dim=1),
|
||||||
|
'action':
|
||||||
|
torch.stack(list(cond_obs_queues['action']), dim=1),
|
||||||
|
}
|
||||||
|
obs_loc2 = {k: v.to(device) for k, v in obs_loc2.items()}
|
||||||
|
|
||||||
|
# World model pass
|
||||||
|
pv1, _, ps = profiled_synthesis(
|
||||||
|
model, "", obs_loc2, noise_shape,
|
||||||
|
ddim_steps=args.ddim_steps, ddim_eta=args.ddim_eta,
|
||||||
|
unconditional_guidance_scale=args.unconditional_guidance_scale,
|
||||||
|
fs=model_input_fs, text_input=False,
|
||||||
|
timestep_spacing=args.timestep_spacing,
|
||||||
|
guidance_rescale=args.guidance_rescale,
|
||||||
|
sim_mode=True, decode_video=True,
|
||||||
|
records=dummy_rec, prefix="wm")
|
||||||
|
|
||||||
|
# Update obs queue
|
||||||
|
for idx in range(args.exe_steps):
|
||||||
|
ou = {
|
||||||
|
'observation.images.top':
|
||||||
|
pv1[0][:, idx:idx + 1].permute(1, 0, 2, 3),
|
||||||
|
'observation.state': ps[0][idx:idx + 1],
|
||||||
|
'action': torch.zeros_like(pa[0][-1:]),
|
||||||
|
}
|
||||||
|
ou['observation.state'][:, ori_state_dim:] = 0.0
|
||||||
|
populate_queues(cond_obs_queues, ou)
|
||||||
|
|
||||||
|
# Save results (captures CPU stall in trace)
|
||||||
|
if pv0 is not None:
|
||||||
|
save_results(pv0.cpu(),
|
||||||
|
os.path.join(tmp_dir, f"trace_dm_{itr_idx}.mp4"),
|
||||||
|
fps=args.save_fps)
|
||||||
|
save_results(pv1.cpu(),
|
||||||
|
os.path.join(tmp_dir, f"trace_wm_{itr_idx}.mp4"),
|
||||||
|
fps=args.save_fps)
|
||||||
|
|
||||||
|
prof.step()
|
||||||
|
|
||||||
|
print(f">>> Trace saved to {trace_dir}")
|
||||||
|
print(" View with: tensorboard --logdir", trace_dir)
|
||||||
|
print(" Or open the .json file in chrome://tracing")
|
||||||
|
|
||||||
|
|
||||||
|
# ──────────────────────────────────────────────────────────────────────
|
||||||
|
# Argument parser
|
||||||
|
# ──────────────────────────────────────────────────────────────────────
|
||||||
|
def get_parser():
|
||||||
|
p = argparse.ArgumentParser(description="Profile full iteration loop")
|
||||||
|
|
||||||
|
# Compare mode (no model needed)
|
||||||
|
p.add_argument("--compare", nargs=2, metavar=("A_SUMMARY", "B_SUMMARY"),
|
||||||
|
help="Compare two summary CSVs and exit")
|
||||||
|
|
||||||
|
# Model / data
|
||||||
|
p.add_argument("--ckpt_path", type=str, default=None)
|
||||||
|
p.add_argument("--config", type=str, default=None)
|
||||||
|
p.add_argument("--prompt_dir", type=str, default=None)
|
||||||
|
p.add_argument("--dataset", type=str, default=None)
|
||||||
|
p.add_argument("--savedir", type=str, default="profile_output")
|
||||||
|
|
||||||
|
# Inference params (match world_model_interaction.py)
|
||||||
|
p.add_argument("--ddim_steps", type=int, default=50)
|
||||||
|
p.add_argument("--ddim_eta", type=float, default=1.0)
|
||||||
|
p.add_argument("--bs", type=int, default=1)
|
||||||
|
p.add_argument("--height", type=int, default=320)
|
||||||
|
p.add_argument("--width", type=int, default=512)
|
||||||
|
p.add_argument("--frame_stride", type=int, default=4)
|
||||||
|
p.add_argument("--unconditional_guidance_scale", type=float, default=1.0)
|
||||||
|
p.add_argument("--video_length", type=int, default=16)
|
||||||
|
p.add_argument("--timestep_spacing", type=str, default="uniform_trailing")
|
||||||
|
p.add_argument("--guidance_rescale", type=float, default=0.7)
|
||||||
|
p.add_argument("--exe_steps", type=int, default=16)
|
||||||
|
p.add_argument("--n_iter", type=int, default=5)
|
||||||
|
p.add_argument("--save_fps", type=int, default=8)
|
||||||
|
p.add_argument("--seed", type=int, default=123)
|
||||||
|
p.add_argument("--perframe_ae", action='store_true', default=False)
|
||||||
|
p.add_argument("--vae_dtype", type=str, choices=["fp32", "bf16"], default="bf16")
|
||||||
|
p.add_argument("--fast_policy_no_decode", action='store_true', default=False)
|
||||||
|
|
||||||
|
# Profiling control
|
||||||
|
p.add_argument("--warmup", type=int, default=1,
|
||||||
|
help="Number of warmup iterations to skip in statistics")
|
||||||
|
p.add_argument("--csv", type=str, default=None,
|
||||||
|
help="Write per-iteration timing to this CSV file")
|
||||||
|
p.add_argument("--trace", action='store_true', default=False,
|
||||||
|
help="Enable Layer 2: GPU timeline trace")
|
||||||
|
p.add_argument("--trace_dir", type=str, default="./profile_traces",
|
||||||
|
help="Directory for trace output")
|
||||||
|
|
||||||
|
return p
|
||||||
|
|
||||||
|
|
||||||
|
# ──────────────────────────────────────────────────────────────────────
|
||||||
|
# Main
|
||||||
|
# ──────────────────────────────────────────────────────────────────────
|
||||||
|
def main():
|
||||||
|
patch_norm_bypass_autocast()
|
||||||
|
parser = get_parser()
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
# ── Compare mode: no model needed ──
|
||||||
|
if args.compare:
|
||||||
|
compare_csvs(args.compare[0], args.compare[1])
|
||||||
|
return
|
||||||
|
|
||||||
|
# ── Validate required args ──
|
||||||
|
for required in ['ckpt_path', 'config', 'prompt_dir', 'dataset']:
|
||||||
|
if getattr(args, required) is None:
|
||||||
|
parser.error(f"--{required} is required for profiling mode")
|
||||||
|
|
||||||
|
seed_everything(args.seed)
|
||||||
|
os.makedirs(args.savedir, exist_ok=True)
|
||||||
|
|
||||||
|
# ── Load model ──
|
||||||
|
print("=" * 60)
|
||||||
|
print("PROFILE ITERATION — Loading model...")
|
||||||
|
print("=" * 60)
|
||||||
|
model, config = load_model(args)
|
||||||
|
device = next(model.parameters()).device
|
||||||
|
|
||||||
|
h, w = args.height // 8, args.width // 8
|
||||||
|
channels = model.model.diffusion_model.out_channels
|
||||||
|
noise_shape = [args.bs, channels, args.video_length, h, w]
|
||||||
|
print(f">>> Noise shape: {noise_shape}")
|
||||||
|
print(f">>> DDIM steps: {args.ddim_steps}")
|
||||||
|
print(f">>> fast_policy_no_decode: {args.fast_policy_no_decode}")
|
||||||
|
|
||||||
|
# ── Layer 2: GPU trace (optional) ──
|
||||||
|
if args.trace:
|
||||||
|
with torch.no_grad():
|
||||||
|
run_with_trace(model, args, config, noise_shape, device)
|
||||||
|
print()
|
||||||
|
|
||||||
|
# ── Layer 1: Iteration-level breakdown ──
|
||||||
|
print("=" * 60)
|
||||||
|
print("LAYER 1: ITERATION-LEVEL PROFILING")
|
||||||
|
print("=" * 60)
|
||||||
|
with torch.no_grad():
|
||||||
|
all_records = run_profiled_iterations(
|
||||||
|
model, args, config, noise_shape, device)
|
||||||
|
|
||||||
|
# Print report
|
||||||
|
print_iteration_report(all_records, warmup=args.warmup)
|
||||||
|
|
||||||
|
# ── Layer 3: CSV output for A/B comparison ──
|
||||||
|
if args.csv:
|
||||||
|
write_csv(all_records, args.csv, warmup=args.warmup)
|
||||||
|
|
||||||
|
print("Done.")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
main()
|
||||||
733
scripts/evaluation/profile_pipeline.py
Normal file
733
scripts/evaluation/profile_pipeline.py
Normal file
@@ -0,0 +1,733 @@
|
|||||||
|
"""
|
||||||
|
Profile the full inference pipeline of the world model, covering all 7 stages:
|
||||||
|
1. Image Embedding
|
||||||
|
2. VAE Encode
|
||||||
|
3. Text Conditioning
|
||||||
|
4. State/Action Projectors
|
||||||
|
5. DDIM Loop
|
||||||
|
6. VAE Decode
|
||||||
|
7. Post-process
|
||||||
|
|
||||||
|
Reports stage-level timing, UNet sub-module breakdown, memory summary,
|
||||||
|
and throughput analysis.
|
||||||
|
TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL=1 CUDA_VISIBLE_DEVICES=0 python scripts/evaluation/profile_pipeline.py --ckpt_path ckpts/unifolm_wma_dual_mix_bf16.ckpt --config configs/inference/world_model_interaction.yaml --ddim_steps 50 --cfg_scale 1.0 --n_runs 3 --deep
|
||||||
|
Usage:
|
||||||
|
TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL=1 CUDA_VISIBLE_DEVICES=0 python scripts/evaluation/profile_pipeline.py --ckpt_path ckpts/unifolm_wma_dual_mix_bf16.ckpt --config configs/inference/world_model_interaction.yaml --ddim_steps 50 --cfg_scale 1.0 --n_runs 3
|
||||||
|
"""
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
import torch.nn.functional as F
|
||||||
|
import torch.utils.checkpoint # must be loaded before unifolm_wma.utils.common
|
||||||
|
from contextlib import nullcontext, contextmanager
|
||||||
|
from collections import defaultdict
|
||||||
|
from omegaconf import OmegaConf
|
||||||
|
from einops import rearrange, repeat
|
||||||
|
|
||||||
|
from unifolm_wma.utils.utils import instantiate_from_config
|
||||||
|
from unifolm_wma.models.samplers.ddim import DDIMSampler
|
||||||
|
from unifolm_wma.modules.attention import (
|
||||||
|
SpatialTransformer, TemporalTransformer,
|
||||||
|
BasicTransformerBlock, CrossAttention, FeedForward,
|
||||||
|
)
|
||||||
|
from unifolm_wma.modules.networks.wma_model import ResBlock
|
||||||
|
from unifolm_wma.models.diffusion_head.conditional_unet1d import ConditionalUnet1D
|
||||||
|
|
||||||
|
# --- W7900D theoretical peak ---
|
||||||
|
PEAK_BF16_TFLOPS = 61.0
|
||||||
|
MEM_BW_GBS = 864.0
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Utility: patch norms to bypass autocast fp32 promotion
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
def patch_norm_bypass_autocast():
|
||||||
|
"""Monkey-patch GroupNorm and LayerNorm to bypass autocast's fp32 policy."""
|
||||||
|
|
||||||
|
def _group_norm_forward(self, x):
|
||||||
|
with torch.amp.autocast('cuda', enabled=False):
|
||||||
|
return F.group_norm(
|
||||||
|
x, self.num_groups,
|
||||||
|
self.weight.to(x.dtype) if self.weight is not None else None,
|
||||||
|
self.bias.to(x.dtype) if self.bias is not None else None,
|
||||||
|
self.eps)
|
||||||
|
|
||||||
|
def _layer_norm_forward(self, x):
|
||||||
|
with torch.amp.autocast('cuda', enabled=False):
|
||||||
|
return F.layer_norm(
|
||||||
|
x, self.normalized_shape,
|
||||||
|
self.weight.to(x.dtype) if self.weight is not None else None,
|
||||||
|
self.bias.to(x.dtype) if self.bias is not None else None,
|
||||||
|
self.eps)
|
||||||
|
|
||||||
|
torch.nn.GroupNorm.forward = _group_norm_forward
|
||||||
|
torch.nn.LayerNorm.forward = _layer_norm_forward
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Utility: torch.compile hot ResBlocks
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
def apply_torch_compile(model, hot_indices=(5, 8, 9)):
|
||||||
|
"""Compile ResBlock._forward in the hottest output_blocks for operator fusion."""
|
||||||
|
unet = model.model.diffusion_model
|
||||||
|
compiled = 0
|
||||||
|
for idx in hot_indices:
|
||||||
|
block = unet.output_blocks[idx]
|
||||||
|
for layer in block:
|
||||||
|
if isinstance(layer, ResBlock):
|
||||||
|
layer._forward = torch.compile(layer._forward, mode="default")
|
||||||
|
compiled += 1
|
||||||
|
print(f" torch.compile: {compiled} ResBlocks in output_blocks{list(hot_indices)}")
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Model loading
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
def load_model(args):
|
||||||
|
config = OmegaConf.load(args.config)
|
||||||
|
config['model']['params']['wma_config']['params']['use_checkpoint'] = False
|
||||||
|
model = instantiate_from_config(config.model)
|
||||||
|
|
||||||
|
state_dict = torch.load(args.ckpt_path, map_location="cpu")
|
||||||
|
if "state_dict" in state_dict:
|
||||||
|
state_dict = state_dict["state_dict"]
|
||||||
|
model.load_state_dict(state_dict, strict=True)
|
||||||
|
|
||||||
|
model.eval()
|
||||||
|
model.model.to(torch.bfloat16)
|
||||||
|
model.diffusion_autocast_dtype = torch.bfloat16
|
||||||
|
apply_torch_compile(model)
|
||||||
|
model = model.cuda()
|
||||||
|
return model
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# CudaTimer — precise GPU timing via CUDA events
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
class CudaTimer:
|
||||||
|
"""Context manager for GPU-precise stage timing using CUDA events."""
|
||||||
|
|
||||||
|
def __init__(self, name, records):
|
||||||
|
self.name = name
|
||||||
|
self.records = records
|
||||||
|
self.start = torch.cuda.Event(enable_timing=True)
|
||||||
|
self.end = torch.cuda.Event(enable_timing=True)
|
||||||
|
|
||||||
|
def __enter__(self):
|
||||||
|
torch.cuda.synchronize()
|
||||||
|
self.start.record()
|
||||||
|
return self
|
||||||
|
|
||||||
|
def __exit__(self, *args):
|
||||||
|
self.end.record()
|
||||||
|
torch.cuda.synchronize()
|
||||||
|
elapsed = self.start.elapsed_time(self.end)
|
||||||
|
self.records[self.name].append(elapsed)
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# HookProfiler — sub-module level timing inside UNet via hooks
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
class HookProfiler:
|
||||||
|
"""Register forward hooks on UNet sub-modules to collect per-call timing."""
|
||||||
|
|
||||||
|
# Coarse-grained targets (original)
|
||||||
|
COARSE_CLASSES = (
|
||||||
|
SpatialTransformer,
|
||||||
|
TemporalTransformer,
|
||||||
|
ResBlock,
|
||||||
|
ConditionalUnet1D,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Fine-grained targets for deep DDIM analysis
|
||||||
|
FINE_CLASSES = (
|
||||||
|
CrossAttention,
|
||||||
|
FeedForward,
|
||||||
|
)
|
||||||
|
|
||||||
|
def __init__(self, unet, deep=False):
|
||||||
|
self.unet = unet
|
||||||
|
self.deep = deep
|
||||||
|
self.handles = []
|
||||||
|
# per-instance data: {instance_id: [(start_event, end_event), ...]}
|
||||||
|
self._events = defaultdict(list)
|
||||||
|
# tag mapping: {instance_id: (class_name, module_name)}
|
||||||
|
self._tags = {}
|
||||||
|
# block location: {instance_id: block_location_str}
|
||||||
|
self._block_loc = {}
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _get_block_location(name):
|
||||||
|
"""Derive UNet block location from module name, e.g. 'input_blocks.3.1'."""
|
||||||
|
parts = name.split('.')
|
||||||
|
if len(parts) >= 2 and parts[0] == 'input_blocks':
|
||||||
|
return f"input_blocks.{parts[1]}"
|
||||||
|
elif len(parts) >= 1 and parts[0] == 'middle_block':
|
||||||
|
return "middle_block"
|
||||||
|
elif len(parts) >= 2 and parts[0] == 'output_blocks':
|
||||||
|
return f"output_blocks.{parts[1]}"
|
||||||
|
elif 'action_unet' in name:
|
||||||
|
return "action_unet"
|
||||||
|
elif 'state_unet' in name:
|
||||||
|
return "state_unet"
|
||||||
|
elif name == 'out' or name.startswith('out.'):
|
||||||
|
return "out"
|
||||||
|
return "other"
|
||||||
|
|
||||||
|
def register(self):
|
||||||
|
"""Attach pre/post forward hooks to target sub-modules + unet.out."""
|
||||||
|
target_classes = self.COARSE_CLASSES
|
||||||
|
if self.deep:
|
||||||
|
target_classes = target_classes + self.FINE_CLASSES
|
||||||
|
|
||||||
|
for name, mod in self.unet.named_modules():
|
||||||
|
if isinstance(mod, target_classes):
|
||||||
|
tag = type(mod).__name__
|
||||||
|
inst_id = id(mod)
|
||||||
|
self._tags[inst_id] = (tag, name)
|
||||||
|
self._block_loc[inst_id] = self._get_block_location(name)
|
||||||
|
self.handles.append(
|
||||||
|
mod.register_forward_pre_hook(self._make_pre_hook(inst_id)))
|
||||||
|
self.handles.append(
|
||||||
|
mod.register_forward_hook(self._make_post_hook(inst_id)))
|
||||||
|
|
||||||
|
# Also hook unet.out (nn.Sequential)
|
||||||
|
out_mod = self.unet.out
|
||||||
|
inst_id = id(out_mod)
|
||||||
|
self._tags[inst_id] = ("UNet.out", "out")
|
||||||
|
self._block_loc[inst_id] = "out"
|
||||||
|
self.handles.append(
|
||||||
|
out_mod.register_forward_pre_hook(self._make_pre_hook(inst_id)))
|
||||||
|
self.handles.append(
|
||||||
|
out_mod.register_forward_hook(self._make_post_hook(inst_id)))
|
||||||
|
|
||||||
|
def _make_pre_hook(self, inst_id):
|
||||||
|
events = self._events
|
||||||
|
|
||||||
|
def hook(module, input):
|
||||||
|
start = torch.cuda.Event(enable_timing=True)
|
||||||
|
start.record()
|
||||||
|
events[inst_id].append([start, None])
|
||||||
|
return hook
|
||||||
|
|
||||||
|
def _make_post_hook(self, inst_id):
|
||||||
|
events = self._events
|
||||||
|
|
||||||
|
def hook(module, input, output):
|
||||||
|
end = torch.cuda.Event(enable_timing=True)
|
||||||
|
end.record()
|
||||||
|
events[inst_id][-1][1] = end
|
||||||
|
return hook
|
||||||
|
|
||||||
|
def reset(self):
|
||||||
|
"""Clear collected events for a fresh run."""
|
||||||
|
self._events.clear()
|
||||||
|
|
||||||
|
def synchronize_and_collect(self):
|
||||||
|
"""Sync GPU and compute elapsed times. Returns (by_type, by_instance, by_block)."""
|
||||||
|
torch.cuda.synchronize()
|
||||||
|
by_type = defaultdict(lambda: {"total_ms": 0.0, "count": 0, "calls": []})
|
||||||
|
by_instance = {}
|
||||||
|
# by_block: {block_loc: {tag: {"total_ms", "count"}}}
|
||||||
|
by_block = defaultdict(lambda: defaultdict(lambda: {"total_ms": 0.0, "count": 0}))
|
||||||
|
|
||||||
|
for inst_id, pairs in self._events.items():
|
||||||
|
tag, mod_name = self._tags[inst_id]
|
||||||
|
block_loc = self._block_loc.get(inst_id, "other")
|
||||||
|
inst_times = []
|
||||||
|
for start_evt, end_evt in pairs:
|
||||||
|
if end_evt is not None:
|
||||||
|
ms = start_evt.elapsed_time(end_evt)
|
||||||
|
inst_times.append(ms)
|
||||||
|
by_type[tag]["total_ms"] += ms
|
||||||
|
by_type[tag]["count"] += 1
|
||||||
|
by_type[tag]["calls"].append(ms)
|
||||||
|
by_block[block_loc][tag]["total_ms"] += ms
|
||||||
|
by_block[block_loc][tag]["count"] += 1
|
||||||
|
by_instance[(tag, mod_name)] = inst_times
|
||||||
|
|
||||||
|
return dict(by_type), by_instance, dict(by_block)
|
||||||
|
|
||||||
|
def remove(self):
|
||||||
|
"""Remove all hooks."""
|
||||||
|
for h in self.handles:
|
||||||
|
h.remove()
|
||||||
|
self.handles.clear()
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Build dummy inputs matching the pipeline's expected shapes
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
def build_dummy_inputs(model, noise_shape):
|
||||||
|
"""Create synthetic observation dict and prompts for profiling."""
|
||||||
|
device = next(model.parameters()).device
|
||||||
|
B, C, T, H, W = noise_shape
|
||||||
|
dtype = torch.bfloat16
|
||||||
|
|
||||||
|
# observation.images.top: [B, C, O, H, W] (permuted to [B,O,C,H,W] inside pipeline)
|
||||||
|
O = 2
|
||||||
|
obs_images = torch.randn(B, 3, O, 320, 512, device=device, dtype=dtype)
|
||||||
|
obs_state = torch.randn(B, O, 16, device=device, dtype=dtype)
|
||||||
|
action = torch.randn(B, 16, 16, device=device, dtype=dtype)
|
||||||
|
|
||||||
|
observation = {
|
||||||
|
'observation.images.top': obs_images,
|
||||||
|
'observation.state': obs_state,
|
||||||
|
'action': action,
|
||||||
|
}
|
||||||
|
prompts = ["a robot arm performing a task"] * B
|
||||||
|
return observation, prompts
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Run one full pipeline pass with per-stage timing
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
def run_pipeline(model, observation, prompts, noise_shape, ddim_steps,
|
||||||
|
cfg_scale, hook_profiler):
|
||||||
|
"""Execute the full 7-stage pipeline, returning per-stage timing dict."""
|
||||||
|
records = defaultdict(list)
|
||||||
|
device = next(model.parameters()).device
|
||||||
|
B, C, T, H, W = noise_shape
|
||||||
|
dtype = torch.bfloat16
|
||||||
|
|
||||||
|
fs = torch.tensor([1] * B, dtype=torch.long, device=device)
|
||||||
|
|
||||||
|
# --- Stage 1: Image Embedding ---
|
||||||
|
with CudaTimer("1_Image_Embedding", records):
|
||||||
|
img = observation['observation.images.top'].permute(0, 2, 1, 3, 4)
|
||||||
|
cond_img = rearrange(img, 'b o c h w -> (b o) c h w')[-1:].to(dtype=dtype)
|
||||||
|
with torch.autocast('cuda', dtype=torch.bfloat16):
|
||||||
|
cond_img_emb = model.embedder(cond_img)
|
||||||
|
cond_img_emb = model.image_proj_model(cond_img_emb)
|
||||||
|
|
||||||
|
# --- Stage 2: VAE Encode ---
|
||||||
|
with CudaTimer("2_VAE_Encode", records):
|
||||||
|
videos = img.permute(0, 2, 1, 3, 4) # [B, C, O, H, W]
|
||||||
|
b_v, c_v, t_v, h_v, w_v = videos.shape
|
||||||
|
vae_dtype = next(model.first_stage_model.parameters()).dtype
|
||||||
|
x_vae = rearrange(videos, 'b c t h w -> (b t) c h w').to(dtype=vae_dtype)
|
||||||
|
z = model.encode_first_stage(x_vae)
|
||||||
|
z = rearrange(z, '(b t) c h w -> b c t h w', b=b_v, t=t_v)
|
||||||
|
img_cat_cond = z[:, :, -1:, :, :]
|
||||||
|
img_cat_cond = repeat(img_cat_cond,
|
||||||
|
'b c t h w -> b c (repeat t) h w', repeat=T)
|
||||||
|
cond = {"c_concat": [img_cat_cond]}
|
||||||
|
|
||||||
|
vae_enc_input_bytes = x_vae.nelement() * x_vae.element_size()
|
||||||
|
vae_enc_output_bytes = z.nelement() * z.element_size()
|
||||||
|
|
||||||
|
# --- Stage 3: Text Conditioning ---
|
||||||
|
with CudaTimer("3_Text_Conditioning", records):
|
||||||
|
cond_ins_emb = model.get_learned_conditioning(prompts)
|
||||||
|
|
||||||
|
# --- Stage 4: State/Action Projectors ---
|
||||||
|
with CudaTimer("4_Projectors", records):
|
||||||
|
projector_dtype = next(model.state_projector.parameters()).dtype
|
||||||
|
with torch.autocast('cuda', dtype=torch.bfloat16):
|
||||||
|
cond_state_emb = model.state_projector(
|
||||||
|
observation['observation.state'].to(dtype=projector_dtype))
|
||||||
|
cond_state_emb = cond_state_emb + model.agent_state_pos_emb
|
||||||
|
|
||||||
|
cond_action_emb = model.action_projector(
|
||||||
|
observation['action'].to(dtype=projector_dtype))
|
||||||
|
cond_action_emb = cond_action_emb + model.agent_action_pos_emb
|
||||||
|
|
||||||
|
# Assemble cross-attention conditioning
|
||||||
|
cond["c_crossattn"] = [
|
||||||
|
torch.cat([cond_state_emb, cond_action_emb, cond_ins_emb, cond_img_emb],
|
||||||
|
dim=1)
|
||||||
|
]
|
||||||
|
n_obs_acting = getattr(model, 'n_obs_steps_acting', 2)
|
||||||
|
cond["c_crossattn_action"] = [
|
||||||
|
observation['observation.images.top'][:, :, -n_obs_acting:],
|
||||||
|
observation['observation.state'][:, -n_obs_acting:],
|
||||||
|
True, # sim_mode
|
||||||
|
False,
|
||||||
|
]
|
||||||
|
|
||||||
|
# CFG: build unconditional conditioning if needed
|
||||||
|
uc = None
|
||||||
|
if cfg_scale != 1.0:
|
||||||
|
uc_crossattn = torch.zeros_like(cond["c_crossattn"][0])
|
||||||
|
uc = {
|
||||||
|
"c_concat": cond["c_concat"],
|
||||||
|
"c_crossattn": [uc_crossattn],
|
||||||
|
"c_crossattn_action": cond["c_crossattn_action"],
|
||||||
|
}
|
||||||
|
|
||||||
|
# --- Stage 5: DDIM Loop ---
|
||||||
|
ddim_sampler = DDIMSampler(model)
|
||||||
|
hook_profiler.reset()
|
||||||
|
|
||||||
|
with CudaTimer("5_DDIM_Loop", records):
|
||||||
|
with torch.autocast('cuda', dtype=torch.bfloat16):
|
||||||
|
samples, actions, states, _ = ddim_sampler.sample(
|
||||||
|
S=ddim_steps,
|
||||||
|
conditioning=cond,
|
||||||
|
batch_size=B,
|
||||||
|
shape=noise_shape[1:],
|
||||||
|
verbose=False,
|
||||||
|
unconditional_guidance_scale=cfg_scale,
|
||||||
|
unconditional_conditioning=uc,
|
||||||
|
eta=1.0,
|
||||||
|
cfg_img=None,
|
||||||
|
mask=None,
|
||||||
|
x0=None,
|
||||||
|
fs=fs,
|
||||||
|
timestep_spacing='uniform',
|
||||||
|
guidance_rescale=0.0,
|
||||||
|
unconditional_conditioning_img_nonetext=None,
|
||||||
|
)
|
||||||
|
|
||||||
|
hook_by_type, hook_by_instance, hook_by_block = hook_profiler.synchronize_and_collect()
|
||||||
|
|
||||||
|
# --- Stage 6: VAE Decode ---
|
||||||
|
with CudaTimer("6_VAE_Decode", records):
|
||||||
|
batch_images = model.decode_first_stage(samples)
|
||||||
|
|
||||||
|
vae_dec_input_bytes = samples.nelement() * samples.element_size()
|
||||||
|
vae_dec_output_bytes = batch_images.nelement() * batch_images.element_size()
|
||||||
|
|
||||||
|
# --- Stage 7: Post-process ---
|
||||||
|
with CudaTimer("7_Post_Process", records):
|
||||||
|
batch_images_cpu = batch_images.cpu()
|
||||||
|
actions_cpu = actions.cpu()
|
||||||
|
states_cpu = states.cpu()
|
||||||
|
# Simulate video save overhead: clamp + uint8 conversion
|
||||||
|
_ = (batch_images_cpu.clamp(-1, 1) * 127.5 + 127.5).to(torch.uint8)
|
||||||
|
|
||||||
|
# Flatten single-element lists
|
||||||
|
stage_times = {k: v[0] for k, v in records.items()}
|
||||||
|
|
||||||
|
bandwidth_info = {
|
||||||
|
"vae_enc_input_bytes": vae_enc_input_bytes,
|
||||||
|
"vae_enc_output_bytes": vae_enc_output_bytes,
|
||||||
|
"vae_dec_input_bytes": vae_dec_input_bytes,
|
||||||
|
"vae_dec_output_bytes": vae_dec_output_bytes,
|
||||||
|
}
|
||||||
|
|
||||||
|
return stage_times, hook_by_type, hook_by_instance, hook_by_block, bandwidth_info
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Reporting
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
def print_stage_timing(all_runs_stages):
|
||||||
|
"""Table 1: Stage Timing — name | mean(ms) | std | percent."""
|
||||||
|
import numpy as np
|
||||||
|
stage_names = list(all_runs_stages[0].keys())
|
||||||
|
means = {}
|
||||||
|
stds = {}
|
||||||
|
for name in stage_names:
|
||||||
|
vals = [run[name] for run in all_runs_stages]
|
||||||
|
means[name] = np.mean(vals)
|
||||||
|
stds[name] = np.std(vals)
|
||||||
|
total = sum(means.values())
|
||||||
|
|
||||||
|
print()
|
||||||
|
print("=" * 72)
|
||||||
|
print("TABLE 1: STAGE TIMING")
|
||||||
|
print("=" * 72)
|
||||||
|
print(f"{'Stage':<25} {'Mean(ms)':>10} {'Std':>10} {'%':>8}")
|
||||||
|
print("-" * 72)
|
||||||
|
for name in stage_names:
|
||||||
|
pct = means[name] / total * 100 if total > 0 else 0
|
||||||
|
print(f"{name:<25} {means[name]:>10.1f} {stds[name]:>10.2f} {pct:>7.1f}%")
|
||||||
|
print("-" * 72)
|
||||||
|
print(f"{'TOTAL':<25} {total:>10.1f}")
|
||||||
|
print()
|
||||||
|
|
||||||
|
|
||||||
|
def print_unet_breakdown(all_runs_hooks):
|
||||||
|
"""Table 2: UNet Sub-Module Breakdown — type | total | count | per-call | percent."""
|
||||||
|
import numpy as np
|
||||||
|
# Aggregate across runs
|
||||||
|
agg = defaultdict(lambda: {"totals": [], "counts": []})
|
||||||
|
for hook_by_type in all_runs_hooks:
|
||||||
|
for tag, data in hook_by_type.items():
|
||||||
|
agg[tag]["totals"].append(data["total_ms"])
|
||||||
|
agg[tag]["counts"].append(data["count"])
|
||||||
|
|
||||||
|
print("=" * 80)
|
||||||
|
print("TABLE 2: UNET SUB-MODULE BREAKDOWN")
|
||||||
|
print("=" * 80)
|
||||||
|
print(f"{'Module Type':<25} {'Total(ms)':>10} {'Count':>7} {'Per-call':>10} {'%':>8}")
|
||||||
|
print("-" * 80)
|
||||||
|
|
||||||
|
grand_total = 0
|
||||||
|
rows = []
|
||||||
|
for tag, d in agg.items():
|
||||||
|
mean_total = np.mean(d["totals"])
|
||||||
|
mean_count = np.mean(d["counts"])
|
||||||
|
per_call = mean_total / mean_count if mean_count > 0 else 0
|
||||||
|
grand_total += mean_total
|
||||||
|
rows.append((tag, mean_total, mean_count, per_call))
|
||||||
|
|
||||||
|
rows.sort(key=lambda r: r[1], reverse=True)
|
||||||
|
for tag, mean_total, mean_count, per_call in rows:
|
||||||
|
pct = mean_total / grand_total * 100 if grand_total > 0 else 0
|
||||||
|
print(f"{tag:<25} {mean_total:>10.1f} {int(mean_count):>7} {per_call:>10.2f} {pct:>7.1f}%")
|
||||||
|
print("-" * 80)
|
||||||
|
print(f"{'TOTAL (hooked)':<25} {grand_total:>10.1f}")
|
||||||
|
print()
|
||||||
|
|
||||||
|
|
||||||
|
def print_block_timing(all_runs_blocks):
|
||||||
|
"""Table 2b: Per-UNet-block timing — which blocks are hottest."""
|
||||||
|
import numpy as np
|
||||||
|
# Aggregate: {block_loc: {tag: [total_ms_per_run, ...]}}
|
||||||
|
agg = defaultdict(lambda: defaultdict(list))
|
||||||
|
for by_block in all_runs_blocks:
|
||||||
|
for block_loc, tag_dict in by_block.items():
|
||||||
|
for tag, data in tag_dict.items():
|
||||||
|
agg[block_loc][tag].append(data["total_ms"])
|
||||||
|
|
||||||
|
# Compute per-block totals
|
||||||
|
block_totals = {}
|
||||||
|
for block_loc, tag_dict in agg.items():
|
||||||
|
block_totals[block_loc] = sum(np.mean(v) for v in tag_dict.values())
|
||||||
|
|
||||||
|
grand_total = sum(block_totals.values())
|
||||||
|
|
||||||
|
# Sort blocks in logical order
|
||||||
|
def block_sort_key(name):
|
||||||
|
if name.startswith("input_blocks."):
|
||||||
|
return (0, int(name.split('.')[1]))
|
||||||
|
elif name == "middle_block":
|
||||||
|
return (1, 0)
|
||||||
|
elif name.startswith("output_blocks."):
|
||||||
|
return (2, int(name.split('.')[1]))
|
||||||
|
elif name == "out":
|
||||||
|
return (3, 0)
|
||||||
|
elif name == "action_unet":
|
||||||
|
return (4, 0)
|
||||||
|
elif name == "state_unet":
|
||||||
|
return (5, 0)
|
||||||
|
return (9, 0)
|
||||||
|
|
||||||
|
sorted_blocks = sorted(block_totals.keys(), key=block_sort_key)
|
||||||
|
|
||||||
|
print("=" * 90)
|
||||||
|
print("TABLE 2b: PER-UNET-BLOCK TIMING (coarse modules, per DDIM loop)")
|
||||||
|
print("=" * 90)
|
||||||
|
print(f"{'Block':<22} {'Total(ms)':>10} {'%':>7} Breakdown")
|
||||||
|
print("-" * 90)
|
||||||
|
|
||||||
|
for block_loc in sorted_blocks:
|
||||||
|
total = block_totals[block_loc]
|
||||||
|
pct = total / grand_total * 100 if grand_total > 0 else 0
|
||||||
|
# Build breakdown string
|
||||||
|
parts = []
|
||||||
|
for tag, vals in sorted(agg[block_loc].items(),
|
||||||
|
key=lambda x: np.mean(x[1]), reverse=True):
|
||||||
|
parts.append(f"{tag}={np.mean(vals):.0f}")
|
||||||
|
breakdown = ", ".join(parts)
|
||||||
|
print(f"{block_loc:<22} {total:>10.1f} {pct:>6.1f}% {breakdown}")
|
||||||
|
|
||||||
|
print("-" * 90)
|
||||||
|
print(f"{'TOTAL':<22} {grand_total:>10.1f}")
|
||||||
|
print()
|
||||||
|
|
||||||
|
|
||||||
|
def print_attn_ff_breakdown(all_runs_hooks):
|
||||||
|
"""Table 2c: CrossAttention vs FeedForward breakdown (--deep mode)."""
|
||||||
|
import numpy as np
|
||||||
|
agg = defaultdict(list)
|
||||||
|
for hook_by_type in all_runs_hooks:
|
||||||
|
for tag, data in hook_by_type.items():
|
||||||
|
if tag in ("CrossAttention", "FeedForward"):
|
||||||
|
agg[tag].append(data["total_ms"])
|
||||||
|
|
||||||
|
if not agg:
|
||||||
|
return
|
||||||
|
|
||||||
|
print("=" * 70)
|
||||||
|
print("TABLE 2c: ATTENTION vs FEEDFORWARD (deep hooks)")
|
||||||
|
print("=" * 70)
|
||||||
|
print(f"{'Component':<25} {'Total(ms)':>10} {'%':>8}")
|
||||||
|
print("-" * 70)
|
||||||
|
|
||||||
|
grand = 0
|
||||||
|
rows = []
|
||||||
|
for tag in ("CrossAttention", "FeedForward"):
|
||||||
|
if tag in agg:
|
||||||
|
mean_t = np.mean(agg[tag])
|
||||||
|
grand += mean_t
|
||||||
|
rows.append((tag, mean_t))
|
||||||
|
|
||||||
|
for tag, mean_t in rows:
|
||||||
|
pct = mean_t / grand * 100 if grand > 0 else 0
|
||||||
|
print(f"{tag:<25} {mean_t:>10.1f} {pct:>7.1f}%")
|
||||||
|
print("-" * 70)
|
||||||
|
print(f"{'TOTAL (attn+ff)':<25} {grand:>10.1f}")
|
||||||
|
print()
|
||||||
|
|
||||||
|
|
||||||
|
def print_unet_detailed(all_runs_instances):
|
||||||
|
"""Print per-instance UNet sub-module detail (--detailed mode)."""
|
||||||
|
import numpy as np
|
||||||
|
# Use last run's data
|
||||||
|
by_instance = all_runs_instances[-1]
|
||||||
|
print("=" * 100)
|
||||||
|
print("DETAILED: PER-INSTANCE UNET SUB-MODULE TIMING (last run)")
|
||||||
|
print("=" * 100)
|
||||||
|
print(f"{'Type':<22} {'Module Name':<45} {'Calls':>6} {'Total(ms)':>10} {'Mean(ms)':>10}")
|
||||||
|
print("-" * 100)
|
||||||
|
|
||||||
|
rows = []
|
||||||
|
for (tag, mod_name), times in by_instance.items():
|
||||||
|
if len(times) == 0:
|
||||||
|
continue
|
||||||
|
total = sum(times)
|
||||||
|
mean = np.mean(times)
|
||||||
|
rows.append((tag, mod_name, len(times), total, mean))
|
||||||
|
rows.sort(key=lambda r: r[3], reverse=True)
|
||||||
|
|
||||||
|
for tag, mod_name, count, total, mean in rows:
|
||||||
|
short_name = mod_name[-42:] if len(mod_name) > 42 else mod_name
|
||||||
|
print(f"{tag:<22} {short_name:<45} {count:>6} {total:>10.2f} {mean:>10.3f}")
|
||||||
|
print()
|
||||||
|
|
||||||
|
|
||||||
|
def print_memory_summary(mem_before, mem_peak):
|
||||||
|
"""Table 3: Memory Summary."""
|
||||||
|
delta = mem_peak - mem_before
|
||||||
|
print("=" * 50)
|
||||||
|
print("TABLE 3: MEMORY SUMMARY")
|
||||||
|
print("=" * 50)
|
||||||
|
print(f" Initial allocated: {mem_before / 1e9:.2f} GB")
|
||||||
|
print(f" Peak allocated: {mem_peak / 1e9:.2f} GB")
|
||||||
|
print(f" Delta (pipeline): {delta / 1e9:.2f} GB")
|
||||||
|
print()
|
||||||
|
|
||||||
|
|
||||||
|
def print_throughput(all_runs_stages, all_bw, ddim_steps, cfg_scale):
|
||||||
|
"""Table 4: Throughput — total latency, per-step, per-UNet-forward, VAE bandwidth."""
|
||||||
|
import numpy as np
|
||||||
|
n_runs = len(all_runs_stages)
|
||||||
|
|
||||||
|
# Total latency
|
||||||
|
totals = []
|
||||||
|
for run in all_runs_stages:
|
||||||
|
totals.append(sum(run.values()))
|
||||||
|
mean_total = np.mean(totals)
|
||||||
|
|
||||||
|
# DDIM loop time
|
||||||
|
ddim_times = [run["5_DDIM_Loop"] for run in all_runs_stages]
|
||||||
|
mean_ddim = np.mean(ddim_times)
|
||||||
|
|
||||||
|
unet_calls = ddim_steps if cfg_scale == 1.0 else ddim_steps * 2
|
||||||
|
per_step = mean_ddim / ddim_steps
|
||||||
|
per_unet = mean_ddim / unet_calls
|
||||||
|
|
||||||
|
# VAE bandwidth
|
||||||
|
mean_enc_time = np.mean([run["2_VAE_Encode"] for run in all_runs_stages])
|
||||||
|
mean_dec_time = np.mean([run["6_VAE_Decode"] for run in all_runs_stages])
|
||||||
|
|
||||||
|
bw = all_bw[-1] # use last run's byte counts
|
||||||
|
enc_bytes = bw["vae_enc_input_bytes"] + bw["vae_enc_output_bytes"]
|
||||||
|
dec_bytes = bw["vae_dec_input_bytes"] + bw["vae_dec_output_bytes"]
|
||||||
|
enc_bw = enc_bytes / (mean_enc_time / 1000) / 1e9 if mean_enc_time > 0 else 0
|
||||||
|
dec_bw = dec_bytes / (mean_dec_time / 1000) / 1e9 if mean_dec_time > 0 else 0
|
||||||
|
|
||||||
|
print("=" * 60)
|
||||||
|
print("TABLE 4: THROUGHPUT")
|
||||||
|
print("=" * 60)
|
||||||
|
print(f" Total pipeline latency: {mean_total:.1f} ms")
|
||||||
|
print(f" DDIM loop latency: {mean_ddim:.1f} ms")
|
||||||
|
print(f" DDIM steps: {ddim_steps}")
|
||||||
|
print(f" CFG scale: {cfg_scale} ({'2x UNet/step' if cfg_scale != 1.0 else '1x UNet/step'})")
|
||||||
|
print(f" UNet forward calls: {unet_calls}")
|
||||||
|
print(f" Per DDIM step: {per_step:.1f} ms")
|
||||||
|
print(f" Per UNet forward: {per_unet:.1f} ms")
|
||||||
|
print(f" VAE encode bandwidth: {enc_bw:.1f} GB/s (peak HBM: {MEM_BW_GBS} GB/s)")
|
||||||
|
print(f" VAE decode bandwidth: {dec_bw:.1f} GB/s (peak HBM: {MEM_BW_GBS} GB/s)")
|
||||||
|
print(f" GPU BF16 peak: {PEAK_BF16_TFLOPS} TFLOPS")
|
||||||
|
print()
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Main
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
def main():
|
||||||
|
patch_norm_bypass_autocast()
|
||||||
|
|
||||||
|
parser = argparse.ArgumentParser(
|
||||||
|
description="Profile the full inference pipeline")
|
||||||
|
parser.add_argument("--ckpt_path", type=str, required=True)
|
||||||
|
parser.add_argument("--config", type=str, required=True)
|
||||||
|
parser.add_argument("--ddim_steps", type=int, default=50)
|
||||||
|
parser.add_argument("--cfg_scale", type=float, default=1.0)
|
||||||
|
parser.add_argument("--n_runs", type=int, default=3)
|
||||||
|
parser.add_argument("--warmup", type=int, default=1)
|
||||||
|
parser.add_argument("--detailed", action="store_true",
|
||||||
|
help="Print per-instance UNet sub-module detail")
|
||||||
|
parser.add_argument("--deep", action="store_true",
|
||||||
|
help="Enable deep DDIM analysis: per-block, attn vs ff")
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
noise_shape = [1, 4, 16, 40, 64]
|
||||||
|
|
||||||
|
# --- Load model ---
|
||||||
|
print("Loading model...")
|
||||||
|
model = load_model(args)
|
||||||
|
observation, prompts = build_dummy_inputs(model, noise_shape)
|
||||||
|
|
||||||
|
# --- Setup hook profiler ---
|
||||||
|
unet = model.model.diffusion_model
|
||||||
|
hook_profiler = HookProfiler(unet, deep=args.deep)
|
||||||
|
hook_profiler.register()
|
||||||
|
print(f"Registered hooks on {len(hook_profiler.handles)} sub-modules")
|
||||||
|
|
||||||
|
# --- Warmup ---
|
||||||
|
print(f"Warmup: {args.warmup} run(s)...")
|
||||||
|
with torch.no_grad():
|
||||||
|
for i in range(args.warmup):
|
||||||
|
run_pipeline(model, observation, prompts, noise_shape,
|
||||||
|
args.ddim_steps, args.cfg_scale, hook_profiler)
|
||||||
|
print(f" warmup {i+1}/{args.warmup} done")
|
||||||
|
|
||||||
|
# --- Measurement runs ---
|
||||||
|
print(f"Measuring: {args.n_runs} run(s)...")
|
||||||
|
torch.cuda.reset_peak_memory_stats()
|
||||||
|
mem_before = torch.cuda.memory_allocated()
|
||||||
|
|
||||||
|
all_stages = []
|
||||||
|
all_hooks = []
|
||||||
|
all_instances = []
|
||||||
|
all_blocks = []
|
||||||
|
all_bw = []
|
||||||
|
|
||||||
|
with torch.no_grad():
|
||||||
|
for i in range(args.n_runs):
|
||||||
|
stage_times, hook_by_type, hook_by_instance, hook_by_block, bw = run_pipeline(
|
||||||
|
model, observation, prompts, noise_shape,
|
||||||
|
args.ddim_steps, args.cfg_scale, hook_profiler)
|
||||||
|
all_stages.append(stage_times)
|
||||||
|
all_hooks.append(hook_by_type)
|
||||||
|
all_instances.append(hook_by_instance)
|
||||||
|
all_blocks.append(hook_by_block)
|
||||||
|
all_bw.append(bw)
|
||||||
|
total = sum(stage_times.values())
|
||||||
|
print(f" run {i+1}/{args.n_runs}: {total:.1f} ms total")
|
||||||
|
|
||||||
|
mem_peak = torch.cuda.max_memory_allocated()
|
||||||
|
|
||||||
|
# --- Reports ---
|
||||||
|
print_stage_timing(all_stages)
|
||||||
|
print_unet_breakdown(all_hooks)
|
||||||
|
print_block_timing(all_blocks)
|
||||||
|
if args.deep:
|
||||||
|
print_attn_ff_breakdown(all_hooks)
|
||||||
|
if args.detailed:
|
||||||
|
print_unet_detailed(all_instances)
|
||||||
|
print_memory_summary(mem_before, mem_peak)
|
||||||
|
print_throughput(all_stages, all_bw, args.ddim_steps, args.cfg_scale)
|
||||||
|
|
||||||
|
hook_profiler.remove()
|
||||||
|
print("Done.")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
287
scripts/evaluation/profile_unet.py
Normal file
287
scripts/evaluation/profile_unet.py
Normal file
@@ -0,0 +1,287 @@
|
|||||||
|
"""
|
||||||
|
Profile one DDIM sampling iteration to capture all matmul/attention ops,
|
||||||
|
their matrix sizes, wall time, and compute utilization.
|
||||||
|
|
||||||
|
Uses torch.profiler for CUDA timing and FlopCounterMode for accurate
|
||||||
|
FLOPS counting (works on ROCm where Tensile kernels don't report FLOPS).
|
||||||
|
|
||||||
|
Usage:
|
||||||
|
TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL=1 CUDA_VISIBLE_DEVICES=0 \
|
||||||
|
python scripts/evaluation/profile_unet.py \
|
||||||
|
--ckpt_path ckpts/unifolm_wma_dual_mix_bf16.ckpt \
|
||||||
|
--config configs/inference/world_model_interaction.yaml
|
||||||
|
"""
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
from collections import OrderedDict, defaultdict
|
||||||
|
from omegaconf import OmegaConf
|
||||||
|
from torch.utils.flop_counter import FlopCounterMode
|
||||||
|
from unifolm_wma.utils.utils import instantiate_from_config
|
||||||
|
import torch.nn.functional as F
|
||||||
|
|
||||||
|
|
||||||
|
def patch_norm_bypass_autocast():
|
||||||
|
"""Monkey-patch GroupNorm and LayerNorm to bypass autocast's fp32 policy."""
|
||||||
|
|
||||||
|
def _group_norm_forward(self, x):
|
||||||
|
with torch.amp.autocast('cuda', enabled=False):
|
||||||
|
return F.group_norm(
|
||||||
|
x, self.num_groups,
|
||||||
|
self.weight.to(x.dtype) if self.weight is not None else None,
|
||||||
|
self.bias.to(x.dtype) if self.bias is not None else None,
|
||||||
|
self.eps)
|
||||||
|
|
||||||
|
def _layer_norm_forward(self, x):
|
||||||
|
with torch.amp.autocast('cuda', enabled=False):
|
||||||
|
return F.layer_norm(
|
||||||
|
x, self.normalized_shape,
|
||||||
|
self.weight.to(x.dtype) if self.weight is not None else None,
|
||||||
|
self.bias.to(x.dtype) if self.bias is not None else None,
|
||||||
|
self.eps)
|
||||||
|
|
||||||
|
torch.nn.GroupNorm.forward = _group_norm_forward
|
||||||
|
torch.nn.LayerNorm.forward = _layer_norm_forward
|
||||||
|
|
||||||
|
|
||||||
|
# --- W7900D theoretical peak (TFLOPS) ---
|
||||||
|
PEAK_BF16_TFLOPS = 61.0
|
||||||
|
PEAK_FP32_TFLOPS = 30.5
|
||||||
|
|
||||||
|
|
||||||
|
def apply_torch_compile(model, hot_indices=(5, 8, 9)):
|
||||||
|
"""Compile ResBlock._forward in the hottest output_blocks for operator fusion."""
|
||||||
|
from unifolm_wma.modules.networks.wma_model import ResBlock
|
||||||
|
unet = model.model.diffusion_model
|
||||||
|
compiled = 0
|
||||||
|
for idx in hot_indices:
|
||||||
|
block = unet.output_blocks[idx]
|
||||||
|
for layer in block:
|
||||||
|
if isinstance(layer, ResBlock):
|
||||||
|
layer._forward = torch.compile(layer._forward, mode="default")
|
||||||
|
compiled += 1
|
||||||
|
print(f" ✓ torch.compile: {compiled} ResBlocks in output_blocks{list(hot_indices)}")
|
||||||
|
|
||||||
|
|
||||||
|
def load_model(args):
|
||||||
|
config = OmegaConf.load(args.config)
|
||||||
|
config['model']['params']['wma_config']['params']['use_checkpoint'] = False
|
||||||
|
model = instantiate_from_config(config.model)
|
||||||
|
|
||||||
|
state_dict = torch.load(args.ckpt_path, map_location="cpu")
|
||||||
|
if "state_dict" in state_dict:
|
||||||
|
state_dict = state_dict["state_dict"]
|
||||||
|
model.load_state_dict(state_dict, strict=True)
|
||||||
|
|
||||||
|
model.eval()
|
||||||
|
model.model.to(torch.bfloat16)
|
||||||
|
apply_torch_compile(model)
|
||||||
|
model = model.cuda()
|
||||||
|
return model
|
||||||
|
|
||||||
|
|
||||||
|
def build_call_kwargs(model, noise_shape):
|
||||||
|
"""Build dummy inputs matching the hybrid conditioning forward signature."""
|
||||||
|
device = next(model.parameters()).device
|
||||||
|
B, C, T, H, W = noise_shape # [1, 4, 16, 40, 64]
|
||||||
|
dtype = torch.bfloat16
|
||||||
|
|
||||||
|
x_action = torch.randn(B, 16, 16, device=device, dtype=dtype)
|
||||||
|
x_state = torch.randn(B, 16, 16, device=device, dtype=dtype)
|
||||||
|
timesteps = torch.tensor([500], device=device, dtype=torch.long)
|
||||||
|
context = torch.randn(B, 351, 1024, device=device, dtype=dtype)
|
||||||
|
|
||||||
|
obs_images = torch.randn(B, 3, 2, 320, 512, device=device, dtype=dtype)
|
||||||
|
obs_state = torch.randn(B, 2, 16, device=device, dtype=dtype)
|
||||||
|
context_action = [obs_images, obs_state, True, False]
|
||||||
|
fps = torch.tensor([1], device=device, dtype=torch.long)
|
||||||
|
|
||||||
|
x_raw = torch.randn(B, C, T, H, W, device=device, dtype=dtype)
|
||||||
|
c_concat = [torch.randn(B, C, T, H, W, device=device, dtype=dtype)]
|
||||||
|
|
||||||
|
return dict(
|
||||||
|
x=x_raw, x_action=x_action, x_state=x_state, t=timesteps,
|
||||||
|
c_concat=c_concat, c_crossattn=[context],
|
||||||
|
c_crossattn_action=context_action, s=fps,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def profile_one_step(model, noise_shape):
|
||||||
|
"""Run one UNet forward pass under torch.profiler for CUDA timing."""
|
||||||
|
diff_wrapper = model.model
|
||||||
|
call_kwargs = build_call_kwargs(model, noise_shape)
|
||||||
|
|
||||||
|
with torch.no_grad(), torch.autocast('cuda', dtype=torch.bfloat16):
|
||||||
|
# Warmup
|
||||||
|
for _ in range(2):
|
||||||
|
_ = diff_wrapper(**call_kwargs)
|
||||||
|
torch.cuda.synchronize()
|
||||||
|
|
||||||
|
with torch.profiler.profile(
|
||||||
|
activities=[torch.profiler.ProfilerActivity.CUDA],
|
||||||
|
record_shapes=True,
|
||||||
|
with_flops=True,
|
||||||
|
) as prof:
|
||||||
|
_ = diff_wrapper(**call_kwargs)
|
||||||
|
torch.cuda.synchronize()
|
||||||
|
|
||||||
|
return prof
|
||||||
|
|
||||||
|
|
||||||
|
def count_flops(model, noise_shape):
|
||||||
|
"""Run one UNet forward pass under FlopCounterMode for accurate FLOPS."""
|
||||||
|
diff_wrapper = model.model
|
||||||
|
call_kwargs = build_call_kwargs(model, noise_shape)
|
||||||
|
|
||||||
|
with torch.no_grad(), torch.autocast('cuda', dtype=torch.bfloat16):
|
||||||
|
flop_counter = FlopCounterMode(display=False)
|
||||||
|
with flop_counter:
|
||||||
|
_ = diff_wrapper(**call_kwargs)
|
||||||
|
torch.cuda.synchronize()
|
||||||
|
|
||||||
|
return flop_counter
|
||||||
|
|
||||||
|
|
||||||
|
def print_report(prof, flop_counter):
|
||||||
|
"""Parse profiler results and print a structured report with accurate FLOPS."""
|
||||||
|
events = prof.key_averages()
|
||||||
|
|
||||||
|
# --- Extract per-operator FLOPS from FlopCounterMode ---
|
||||||
|
# flop_counts is {module_name: {op_name: count}}; use only "Global" to avoid double-counting
|
||||||
|
flop_by_op = {}
|
||||||
|
flop_by_module = {}
|
||||||
|
if hasattr(flop_counter, 'flop_counts'):
|
||||||
|
# Per-op: only from top-level "Global" entry (no parent/child duplication)
|
||||||
|
global_ops = flop_counter.flop_counts.get("Global", {})
|
||||||
|
for op_name, flop_count in global_ops.items():
|
||||||
|
key = str(op_name).split('.')[-1]
|
||||||
|
flop_by_op[key] = flop_by_op.get(key, 0) + flop_count
|
||||||
|
|
||||||
|
# Per-module: collect all, skip "Global" and top-level wrapper duplicates
|
||||||
|
for module_name, op_dict in flop_counter.flop_counts.items():
|
||||||
|
module_total = sum(op_dict.values())
|
||||||
|
if module_total > 0:
|
||||||
|
flop_by_module[module_name] = module_total
|
||||||
|
|
||||||
|
total_counted_flops = flop_counter.get_total_flops()
|
||||||
|
|
||||||
|
# Collect matmul-like ops
|
||||||
|
matmul_ops = []
|
||||||
|
other_ops = []
|
||||||
|
|
||||||
|
for evt in events:
|
||||||
|
if evt.device_time_total <= 0:
|
||||||
|
continue
|
||||||
|
name = evt.key
|
||||||
|
is_matmul = any(k in name.lower() for k in
|
||||||
|
['mm', 'gemm', 'addmm', 'bmm', 'einsum', 'dot', 'linear'])
|
||||||
|
entry = {
|
||||||
|
'name': name,
|
||||||
|
'input_shapes': str(evt.input_shapes) if evt.input_shapes else '',
|
||||||
|
'cuda_time_ms': evt.device_time_total / 1000.0,
|
||||||
|
'count': evt.count,
|
||||||
|
'flops': evt.flops if evt.flops else 0,
|
||||||
|
}
|
||||||
|
if is_matmul:
|
||||||
|
matmul_ops.append(entry)
|
||||||
|
else:
|
||||||
|
other_ops.append(entry)
|
||||||
|
|
||||||
|
# Sort by CUDA time
|
||||||
|
matmul_ops.sort(key=lambda x: x['cuda_time_ms'], reverse=True)
|
||||||
|
other_ops.sort(key=lambda x: x['cuda_time_ms'], reverse=True)
|
||||||
|
|
||||||
|
total_cuda_ms = sum(e['cuda_time_ms'] for e in matmul_ops + other_ops)
|
||||||
|
total_matmul_ms = sum(e['cuda_time_ms'] for e in matmul_ops)
|
||||||
|
# --- Print matmul ops ---
|
||||||
|
print("=" * 130)
|
||||||
|
print("MATMUL / LINEAR OPS (sorted by CUDA time)")
|
||||||
|
print("=" * 130)
|
||||||
|
print(f"{'Op':>35} | {'Count':>5} | {'CUDA(ms)':>10} | Shapes")
|
||||||
|
print("-" * 130)
|
||||||
|
|
||||||
|
for op in matmul_ops:
|
||||||
|
shapes_str = op['input_shapes'][:60] if op['input_shapes'] else ''
|
||||||
|
print(f"{op['name']:>35} | {op['count']:>5} | {op['cuda_time_ms']:>10.3f} | {shapes_str}")
|
||||||
|
|
||||||
|
# --- Print top non-matmul ops ---
|
||||||
|
print()
|
||||||
|
print("=" * 130)
|
||||||
|
print("TOP NON-MATMUL OPS (sorted by CUDA time)")
|
||||||
|
print("=" * 130)
|
||||||
|
print(f"{'Op':>40} | {'Count':>5} | {'CUDA(ms)':>10} | Shapes")
|
||||||
|
print("-" * 130)
|
||||||
|
|
||||||
|
for op in other_ops[:20]:
|
||||||
|
shapes_str = op['input_shapes'][:60] if op['input_shapes'] else ''
|
||||||
|
print(f"{op['name']:>40} | {op['count']:>5} | {op['cuda_time_ms']:>10.3f} | {shapes_str}")
|
||||||
|
|
||||||
|
# --- FlopCounterMode per-operator breakdown ---
|
||||||
|
print()
|
||||||
|
print("=" * 130)
|
||||||
|
print("FLOPS BY ATen OPERATOR (FlopCounterMode)")
|
||||||
|
print("=" * 130)
|
||||||
|
print(f"{'ATen Op':>25} | {'GFLOPS':>12} | {'% of Total':>10}")
|
||||||
|
print("-" * 55)
|
||||||
|
|
||||||
|
sorted_flop_ops = sorted(flop_by_op.items(), key=lambda x: x[1], reverse=True)
|
||||||
|
for op_name, flops in sorted_flop_ops:
|
||||||
|
gflops = flops / 1e9
|
||||||
|
pct = flops / total_counted_flops * 100 if total_counted_flops > 0 else 0
|
||||||
|
print(f"{op_name:>25} | {gflops:>12.2f} | {pct:>9.1f}%")
|
||||||
|
|
||||||
|
# --- FlopCounterMode per-module breakdown ---
|
||||||
|
if flop_by_module:
|
||||||
|
print()
|
||||||
|
print("=" * 130)
|
||||||
|
print("FLOPS BY MODULE (FlopCounterMode)")
|
||||||
|
print("=" * 130)
|
||||||
|
print(f"{'Module':>60} | {'GFLOPS':>12} | {'% of Total':>10}")
|
||||||
|
print("-" * 90)
|
||||||
|
|
||||||
|
sorted_modules = sorted(flop_by_module.items(), key=lambda x: x[1], reverse=True)
|
||||||
|
for mod_name, flops in sorted_modules[:30]:
|
||||||
|
gflops = flops / 1e9
|
||||||
|
pct = flops / total_counted_flops * 100 if total_counted_flops > 0 else 0
|
||||||
|
name_str = mod_name[-57:] if len(mod_name) > 57 else mod_name
|
||||||
|
print(f"{name_str:>60} | {gflops:>12.2f} | {pct:>9.1f}%")
|
||||||
|
|
||||||
|
# --- Summary ---
|
||||||
|
print()
|
||||||
|
print("=" * 130)
|
||||||
|
print("SUMMARY")
|
||||||
|
print("=" * 130)
|
||||||
|
print(f" Total CUDA time: {total_cuda_ms:.1f} ms")
|
||||||
|
print(f" Matmul CUDA time: {total_matmul_ms:.1f} ms ({total_matmul_ms/total_cuda_ms*100:.1f}%)")
|
||||||
|
print(f" Non-matmul CUDA time: {total_cuda_ms - total_matmul_ms:.1f} ms ({(total_cuda_ms-total_matmul_ms)/total_cuda_ms*100:.1f}%)")
|
||||||
|
print(f" Total FLOPS (FlopCounter): {total_counted_flops/1e9:.2f} GFLOPS")
|
||||||
|
if total_matmul_ms > 0 and total_counted_flops > 0:
|
||||||
|
avg_tflops = total_counted_flops / (total_matmul_ms / 1000.0) / 1e12
|
||||||
|
avg_util = avg_tflops / PEAK_BF16_TFLOPS * 100
|
||||||
|
overall_tflops = total_counted_flops / (total_cuda_ms / 1000.0) / 1e12
|
||||||
|
overall_util = overall_tflops / PEAK_BF16_TFLOPS * 100
|
||||||
|
print(f" Matmul throughput: {avg_tflops:.2f} TFLOPS/s ({avg_util:.1f}% of BF16 peak)")
|
||||||
|
print(f" Overall throughput: {overall_tflops:.2f} TFLOPS/s ({overall_util:.1f}% of BF16 peak)")
|
||||||
|
print(f" GPU peak (BF16): {PEAK_BF16_TFLOPS} TFLOPS")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
patch_norm_bypass_autocast()
|
||||||
|
parser = argparse.ArgumentParser()
|
||||||
|
parser.add_argument("--ckpt_path", type=str, required=True)
|
||||||
|
parser.add_argument("--config", type=str, required=True)
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
print("Loading model...")
|
||||||
|
model = load_model(args)
|
||||||
|
|
||||||
|
noise_shape = [1, 4, 16, 40, 64]
|
||||||
|
|
||||||
|
print(f"Profiling UNet forward pass with shape {noise_shape}...")
|
||||||
|
prof = profile_one_step(model, noise_shape)
|
||||||
|
|
||||||
|
print("Counting FLOPS with FlopCounterMode...")
|
||||||
|
flop_counter = count_flops(model, noise_shape)
|
||||||
|
|
||||||
|
print_report(prof, flop_counter)
|
||||||
@@ -19,9 +19,6 @@ from fastapi.responses import JSONResponse
|
|||||||
from typing import Any, Dict, Optional, Tuple, List
|
from typing import Any, Dict, Optional, Tuple, List
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
|
|
||||||
torch.backends.cuda.matmul.allow_tf32 = True
|
|
||||||
torch.backends.cudnn.allow_tf32 = True
|
|
||||||
|
|
||||||
from unifolm_wma.utils.utils import instantiate_from_config
|
from unifolm_wma.utils.utils import instantiate_from_config
|
||||||
from unifolm_wma.models.samplers.ddim import DDIMSampler
|
from unifolm_wma.models.samplers.ddim import DDIMSampler
|
||||||
|
|
||||||
|
|||||||
@@ -1,4 +1,5 @@
|
|||||||
import argparse, os, glob
|
import argparse, os, glob
|
||||||
|
from contextlib import nullcontext
|
||||||
import pandas as pd
|
import pandas as pd
|
||||||
import random
|
import random
|
||||||
import torch
|
import torch
|
||||||
@@ -9,8 +10,6 @@ import logging
|
|||||||
import einops
|
import einops
|
||||||
import warnings
|
import warnings
|
||||||
import imageio
|
import imageio
|
||||||
import atexit
|
|
||||||
from concurrent.futures import ThreadPoolExecutor
|
|
||||||
|
|
||||||
from pytorch_lightning import seed_everything
|
from pytorch_lightning import seed_everything
|
||||||
from omegaconf import OmegaConf
|
from omegaconf import OmegaConf
|
||||||
@@ -18,18 +17,39 @@ from tqdm import tqdm
|
|||||||
from einops import rearrange, repeat
|
from einops import rearrange, repeat
|
||||||
from collections import OrderedDict
|
from collections import OrderedDict
|
||||||
from torch import nn
|
from torch import nn
|
||||||
from eval_utils import populate_queues
|
from eval_utils import populate_queues, log_to_tensorboard
|
||||||
from collections import deque
|
from collections import deque
|
||||||
from typing import Optional, List, Any
|
|
||||||
|
|
||||||
torch.backends.cuda.matmul.allow_tf32 = True
|
|
||||||
torch.backends.cudnn.allow_tf32 = True
|
|
||||||
from torch import Tensor
|
from torch import Tensor
|
||||||
from torch.utils.tensorboard import SummaryWriter
|
from torch.utils.tensorboard import SummaryWriter
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
|
|
||||||
from unifolm_wma.models.samplers.ddim import DDIMSampler
|
from unifolm_wma.models.samplers.ddim import DDIMSampler
|
||||||
from unifolm_wma.utils.utils import instantiate_from_config
|
from unifolm_wma.utils.utils import instantiate_from_config
|
||||||
|
import torch.nn.functional as F
|
||||||
|
|
||||||
|
|
||||||
|
def patch_norm_bypass_autocast():
|
||||||
|
"""Monkey-patch GroupNorm and LayerNorm to bypass autocast's fp32 policy.
|
||||||
|
This eliminates bf16->fp32->bf16 dtype conversions during UNet forward."""
|
||||||
|
|
||||||
|
def _group_norm_forward(self, x):
|
||||||
|
with torch.amp.autocast('cuda', enabled=False):
|
||||||
|
return F.group_norm(
|
||||||
|
x, self.num_groups,
|
||||||
|
self.weight.to(x.dtype) if self.weight is not None else None,
|
||||||
|
self.bias.to(x.dtype) if self.bias is not None else None,
|
||||||
|
self.eps)
|
||||||
|
|
||||||
|
def _layer_norm_forward(self, x):
|
||||||
|
with torch.amp.autocast('cuda', enabled=False):
|
||||||
|
return F.layer_norm(
|
||||||
|
x, self.normalized_shape,
|
||||||
|
self.weight.to(x.dtype) if self.weight is not None else None,
|
||||||
|
self.bias.to(x.dtype) if self.bias is not None else None,
|
||||||
|
self.eps)
|
||||||
|
|
||||||
|
torch.nn.GroupNorm.forward = _group_norm_forward
|
||||||
|
torch.nn.LayerNorm.forward = _layer_norm_forward
|
||||||
|
|
||||||
|
|
||||||
def get_device_from_parameters(module: nn.Module) -> torch.device:
|
def get_device_from_parameters(module: nn.Module) -> torch.device:
|
||||||
@@ -44,6 +64,92 @@ def get_device_from_parameters(module: nn.Module) -> torch.device:
|
|||||||
return next(iter(module.parameters())).device
|
return next(iter(module.parameters())).device
|
||||||
|
|
||||||
|
|
||||||
|
def apply_precision_settings(model: nn.Module, args: argparse.Namespace) -> nn.Module:
|
||||||
|
"""Apply precision settings to model components based on command-line arguments.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model (nn.Module): The model to apply precision settings to.
|
||||||
|
args (argparse.Namespace): Parsed command-line arguments containing precision settings.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
nn.Module: Model with precision settings applied.
|
||||||
|
"""
|
||||||
|
print(f">>> Applying precision settings:")
|
||||||
|
print(f" - Diffusion dtype: {args.diffusion_dtype}")
|
||||||
|
print(f" - Projector mode: {args.projector_mode}")
|
||||||
|
print(f" - Encoder mode: {args.encoder_mode}")
|
||||||
|
print(f" - VAE dtype: {args.vae_dtype}")
|
||||||
|
|
||||||
|
# 1. Set Diffusion backbone precision
|
||||||
|
if args.diffusion_dtype == "bf16":
|
||||||
|
# Convert diffusion model weights to bf16
|
||||||
|
model.model.to(torch.bfloat16)
|
||||||
|
model.diffusion_autocast_dtype = torch.bfloat16
|
||||||
|
print(" ✓ Diffusion model weights converted to bfloat16")
|
||||||
|
else:
|
||||||
|
model.diffusion_autocast_dtype = torch.bfloat16
|
||||||
|
print(" ✓ Diffusion model using fp32")
|
||||||
|
|
||||||
|
# 2. Set Projector precision
|
||||||
|
if args.projector_mode == "bf16_full":
|
||||||
|
model.state_projector.to(torch.bfloat16)
|
||||||
|
model.action_projector.to(torch.bfloat16)
|
||||||
|
model.projector_autocast_dtype = None
|
||||||
|
print(" ✓ Projectors converted to bfloat16")
|
||||||
|
elif args.projector_mode == "autocast":
|
||||||
|
model.projector_autocast_dtype = torch.bfloat16
|
||||||
|
print(" ✓ Projectors will use autocast (weights fp32, compute bf16)")
|
||||||
|
else:
|
||||||
|
model.projector_autocast_dtype = None
|
||||||
|
# fp32 mode: do nothing, keep original precision
|
||||||
|
|
||||||
|
# 3. Set Encoder precision
|
||||||
|
if args.encoder_mode == "bf16_full":
|
||||||
|
model.embedder.to(torch.bfloat16)
|
||||||
|
model.image_proj_model.to(torch.bfloat16)
|
||||||
|
model.encoder_autocast_dtype = None
|
||||||
|
print(" ✓ Encoders converted to bfloat16")
|
||||||
|
elif args.encoder_mode == "autocast":
|
||||||
|
model.encoder_autocast_dtype = torch.bfloat16
|
||||||
|
print(" ✓ Encoders will use autocast (weights fp32, compute bf16)")
|
||||||
|
else:
|
||||||
|
model.encoder_autocast_dtype = None
|
||||||
|
# fp32 mode: do nothing, keep original precision
|
||||||
|
|
||||||
|
# 4. Set VAE precision
|
||||||
|
if args.vae_dtype == "bf16":
|
||||||
|
model.first_stage_model.to(torch.bfloat16)
|
||||||
|
print(" ✓ VAE converted to bfloat16")
|
||||||
|
else:
|
||||||
|
print(" ✓ VAE kept in fp32 for best quality")
|
||||||
|
|
||||||
|
# 5. Safety net: ensure no fp32 parameters remain when all components are bf16
|
||||||
|
if args.diffusion_dtype == "bf16":
|
||||||
|
fp32_params = [(n, p) for n, p in model.named_parameters() if p.dtype == torch.float32]
|
||||||
|
if fp32_params:
|
||||||
|
print(f" ⚠ Found {len(fp32_params)} fp32 params, converting to bf16")
|
||||||
|
for name, param in fp32_params:
|
||||||
|
param.data = param.data.to(torch.bfloat16)
|
||||||
|
print(" ✓ All parameters converted to bfloat16")
|
||||||
|
|
||||||
|
return model
|
||||||
|
|
||||||
|
|
||||||
|
def apply_torch_compile(model, hot_indices=(5, 8, 9)):
|
||||||
|
"""Compile ResBlock._forward in the hottest output_blocks for operator fusion."""
|
||||||
|
from unifolm_wma.modules.networks.wma_model import ResBlock
|
||||||
|
unet = model.model.diffusion_model
|
||||||
|
compiled = 0
|
||||||
|
for idx in hot_indices:
|
||||||
|
block = unet.output_blocks[idx]
|
||||||
|
for layer in block:
|
||||||
|
if isinstance(layer, ResBlock):
|
||||||
|
layer._forward = torch.compile(layer._forward, mode="default")
|
||||||
|
compiled += 1
|
||||||
|
print(f" ✓ torch.compile: {compiled} ResBlocks in output_blocks{list(hot_indices)}")
|
||||||
|
return model
|
||||||
|
|
||||||
|
|
||||||
def write_video(video_path: str, stacked_frames: list, fps: int) -> None:
|
def write_video(video_path: str, stacked_frames: list, fps: int) -> None:
|
||||||
"""Save a list of frames to a video file.
|
"""Save a list of frames to a video file.
|
||||||
|
|
||||||
@@ -156,81 +262,6 @@ def save_results(video: Tensor, filename: str, fps: int = 8) -> None:
|
|||||||
options={'crf': '10'})
|
options={'crf': '10'})
|
||||||
|
|
||||||
|
|
||||||
# ========== Async I/O ==========
|
|
||||||
_io_executor: Optional[ThreadPoolExecutor] = None
|
|
||||||
_io_futures: List[Any] = []
|
|
||||||
|
|
||||||
|
|
||||||
def _get_io_executor() -> ThreadPoolExecutor:
|
|
||||||
global _io_executor
|
|
||||||
if _io_executor is None:
|
|
||||||
_io_executor = ThreadPoolExecutor(max_workers=2)
|
|
||||||
return _io_executor
|
|
||||||
|
|
||||||
|
|
||||||
def _flush_io():
|
|
||||||
"""Wait for all pending async I/O to finish."""
|
|
||||||
global _io_futures
|
|
||||||
for fut in _io_futures:
|
|
||||||
try:
|
|
||||||
fut.result()
|
|
||||||
except Exception as e:
|
|
||||||
print(f">>> [async I/O] error: {e}")
|
|
||||||
_io_futures.clear()
|
|
||||||
|
|
||||||
|
|
||||||
atexit.register(_flush_io)
|
|
||||||
|
|
||||||
|
|
||||||
def _save_results_sync(video_cpu: Tensor, filename: str, fps: int) -> None:
|
|
||||||
"""Synchronous save on CPU tensor (runs in background thread)."""
|
|
||||||
video = torch.clamp(video_cpu.float(), -1., 1.)
|
|
||||||
n = video.shape[0]
|
|
||||||
video = video.permute(2, 0, 1, 3, 4)
|
|
||||||
frame_grids = [
|
|
||||||
torchvision.utils.make_grid(framesheet, nrow=int(n), padding=0)
|
|
||||||
for framesheet in video
|
|
||||||
]
|
|
||||||
grid = torch.stack(frame_grids, dim=0)
|
|
||||||
grid = (grid + 1.0) / 2.0
|
|
||||||
grid = (grid * 255).to(torch.uint8).permute(0, 2, 3, 1)
|
|
||||||
torchvision.io.write_video(filename,
|
|
||||||
grid,
|
|
||||||
fps=fps,
|
|
||||||
video_codec='h264',
|
|
||||||
options={'crf': '10'})
|
|
||||||
|
|
||||||
|
|
||||||
def save_results_async(video: Tensor, filename: str, fps: int = 8) -> None:
|
|
||||||
"""Submit video saving to background thread pool."""
|
|
||||||
video_cpu = video.detach().cpu()
|
|
||||||
fut = _get_io_executor().submit(_save_results_sync, video_cpu, filename, fps)
|
|
||||||
_io_futures.append(fut)
|
|
||||||
|
|
||||||
|
|
||||||
def _log_to_tb_sync(writer, video_cpu: Tensor, tag: str, fps: int) -> None:
|
|
||||||
"""Synchronous TensorBoard log on CPU tensor (runs in background thread)."""
|
|
||||||
if video_cpu.dim() == 5:
|
|
||||||
n = video_cpu.shape[0]
|
|
||||||
video = video_cpu.permute(2, 0, 1, 3, 4)
|
|
||||||
frame_grids = [
|
|
||||||
torchvision.utils.make_grid(framesheet, nrow=int(n), padding=0)
|
|
||||||
for framesheet in video
|
|
||||||
]
|
|
||||||
grid = torch.stack(frame_grids, dim=0)
|
|
||||||
grid = (grid + 1.0) / 2.0
|
|
||||||
grid = grid.unsqueeze(dim=0)
|
|
||||||
writer.add_video(tag, grid, fps=fps)
|
|
||||||
|
|
||||||
|
|
||||||
def log_to_tensorboard_async(writer, data: Tensor, tag: str, fps: int = 10) -> None:
|
|
||||||
"""Submit TensorBoard logging to background thread pool."""
|
|
||||||
if isinstance(data, torch.Tensor) and data.dim() == 5:
|
|
||||||
data_cpu = data.detach().cpu()
|
|
||||||
fut = _get_io_executor().submit(_log_to_tb_sync, writer, data_cpu, tag, fps)
|
|
||||||
_io_futures.append(fut)
|
|
||||||
|
|
||||||
|
|
||||||
def get_init_frame_path(data_dir: str, sample: dict) -> str:
|
def get_init_frame_path(data_dir: str, sample: dict) -> str:
|
||||||
"""Construct the init_frame path from directory and sample metadata.
|
"""Construct the init_frame path from directory and sample metadata.
|
||||||
|
|
||||||
@@ -343,6 +374,11 @@ def get_latent_z(model, videos: Tensor) -> Tensor:
|
|||||||
"""
|
"""
|
||||||
b, c, t, h, w = videos.shape
|
b, c, t, h, w = videos.shape
|
||||||
x = rearrange(videos, 'b c t h w -> (b t) c h w')
|
x = rearrange(videos, 'b c t h w -> (b t) c h w')
|
||||||
|
|
||||||
|
# Auto-detect VAE dtype and convert input
|
||||||
|
vae_dtype = next(model.first_stage_model.parameters()).dtype
|
||||||
|
x = x.to(dtype=vae_dtype)
|
||||||
|
|
||||||
z = model.encode_first_stage(x)
|
z = model.encode_first_stage(x)
|
||||||
z = rearrange(z, '(b t) c h w -> b c t h w', b=b, t=t)
|
z = rearrange(z, '(b t) c h w -> b c t h w', b=b, t=t)
|
||||||
return z
|
return z
|
||||||
@@ -448,10 +484,22 @@ def image_guided_synthesis_sim_mode(
|
|||||||
|
|
||||||
fs = torch.tensor([fs] * batch_size, dtype=torch.long, device=model.device)
|
fs = torch.tensor([fs] * batch_size, dtype=torch.long, device=model.device)
|
||||||
|
|
||||||
|
# Auto-detect model dtype and convert inputs accordingly
|
||||||
|
model_dtype = next(model.embedder.parameters()).dtype
|
||||||
|
|
||||||
img = observation['observation.images.top'].permute(0, 2, 1, 3, 4)
|
img = observation['observation.images.top'].permute(0, 2, 1, 3, 4)
|
||||||
cond_img = rearrange(img, 'b o c h w -> (b o) c h w')[-1:]
|
cond_img = rearrange(img, 'b o c h w -> (b o) c h w')[-1:].to(dtype=model_dtype)
|
||||||
cond_img_emb = model.embedder(cond_img)
|
|
||||||
cond_img_emb = model.image_proj_model(cond_img_emb)
|
# Encoder autocast: weights stay fp32, compute in bf16
|
||||||
|
enc_ac_dtype = getattr(model, 'encoder_autocast_dtype', None)
|
||||||
|
if enc_ac_dtype is not None and model.device.type == 'cuda':
|
||||||
|
enc_ctx = torch.autocast('cuda', dtype=enc_ac_dtype)
|
||||||
|
else:
|
||||||
|
enc_ctx = nullcontext()
|
||||||
|
|
||||||
|
with enc_ctx:
|
||||||
|
cond_img_emb = model.embedder(cond_img)
|
||||||
|
cond_img_emb = model.image_proj_model(cond_img_emb)
|
||||||
|
|
||||||
if model.model.conditioning_key == 'hybrid':
|
if model.model.conditioning_key == 'hybrid':
|
||||||
z = get_latent_z(model, img.permute(0, 2, 1, 3, 4))
|
z = get_latent_z(model, img.permute(0, 2, 1, 3, 4))
|
||||||
@@ -465,11 +513,22 @@ def image_guided_synthesis_sim_mode(
|
|||||||
prompts = [""] * batch_size
|
prompts = [""] * batch_size
|
||||||
cond_ins_emb = model.get_learned_conditioning(prompts)
|
cond_ins_emb = model.get_learned_conditioning(prompts)
|
||||||
|
|
||||||
cond_state_emb = model.state_projector(observation['observation.state'])
|
# Auto-detect projector dtype and convert inputs
|
||||||
cond_state_emb = cond_state_emb + model.agent_state_pos_emb
|
projector_dtype = next(model.state_projector.parameters()).dtype
|
||||||
|
|
||||||
cond_action_emb = model.action_projector(observation['action'])
|
# Projector autocast: weights stay fp32, compute in bf16
|
||||||
cond_action_emb = cond_action_emb + model.agent_action_pos_emb
|
proj_ac_dtype = getattr(model, 'projector_autocast_dtype', None)
|
||||||
|
if proj_ac_dtype is not None and model.device.type == 'cuda':
|
||||||
|
proj_ctx = torch.autocast('cuda', dtype=proj_ac_dtype)
|
||||||
|
else:
|
||||||
|
proj_ctx = nullcontext()
|
||||||
|
|
||||||
|
with proj_ctx:
|
||||||
|
cond_state_emb = model.state_projector(observation['observation.state'].to(dtype=projector_dtype))
|
||||||
|
cond_state_emb = cond_state_emb + model.agent_state_pos_emb
|
||||||
|
|
||||||
|
cond_action_emb = model.action_projector(observation['action'].to(dtype=projector_dtype))
|
||||||
|
cond_action_emb = cond_action_emb + model.agent_action_pos_emb
|
||||||
|
|
||||||
if not sim_mode:
|
if not sim_mode:
|
||||||
cond_action_emb = torch.zeros_like(cond_action_emb)
|
cond_action_emb = torch.zeros_like(cond_action_emb)
|
||||||
@@ -491,9 +550,18 @@ def image_guided_synthesis_sim_mode(
|
|||||||
kwargs.update({"unconditional_conditioning_img_nonetext": None})
|
kwargs.update({"unconditional_conditioning_img_nonetext": None})
|
||||||
cond_mask = None
|
cond_mask = None
|
||||||
cond_z0 = None
|
cond_z0 = None
|
||||||
|
|
||||||
|
# Setup autocast context for diffusion sampling
|
||||||
|
autocast_dtype = getattr(model, 'diffusion_autocast_dtype', None)
|
||||||
|
if autocast_dtype is not None and model.device.type == 'cuda':
|
||||||
|
autocast_ctx = torch.autocast('cuda', dtype=autocast_dtype)
|
||||||
|
else:
|
||||||
|
autocast_ctx = nullcontext()
|
||||||
|
|
||||||
batch_variants = None
|
batch_variants = None
|
||||||
if ddim_sampler is not None:
|
if ddim_sampler is not None:
|
||||||
samples, actions, states, intermedia = ddim_sampler.sample(
|
with autocast_ctx:
|
||||||
|
samples, actions, states, intermedia = ddim_sampler.sample(
|
||||||
S=ddim_steps,
|
S=ddim_steps,
|
||||||
conditioning=cond,
|
conditioning=cond,
|
||||||
batch_size=batch_size,
|
batch_size=batch_size,
|
||||||
@@ -540,44 +608,22 @@ def run_inference(args: argparse.Namespace, gpu_num: int, gpu_no: int) -> None:
|
|||||||
csv_path = os.path.join(args.prompt_dir, f"{args.dataset}.csv")
|
csv_path = os.path.join(args.prompt_dir, f"{args.dataset}.csv")
|
||||||
df = pd.read_csv(csv_path)
|
df = pd.read_csv(csv_path)
|
||||||
|
|
||||||
# Load config (always needed for data setup)
|
# Load config
|
||||||
config = OmegaConf.load(args.config)
|
config = OmegaConf.load(args.config)
|
||||||
|
config['model']['params']['wma_config']['params'][
|
||||||
|
'use_checkpoint'] = False
|
||||||
|
model = instantiate_from_config(config.model)
|
||||||
|
model.perframe_ae = args.perframe_ae
|
||||||
|
assert os.path.exists(args.ckpt_path), "Error: checkpoint Not Found!"
|
||||||
|
model = load_model_checkpoint(model, args.ckpt_path)
|
||||||
|
model.eval()
|
||||||
|
print(f'>>> Load pre-trained model ...')
|
||||||
|
|
||||||
prepared_path = args.ckpt_path + ".prepared.pt"
|
# Apply precision settings before moving to GPU
|
||||||
if os.path.exists(prepared_path):
|
model = apply_precision_settings(model, args)
|
||||||
# ---- Fast path: load the fully-prepared model ----
|
|
||||||
print(f">>> Loading prepared model from {prepared_path} ...")
|
|
||||||
model = torch.load(prepared_path,
|
|
||||||
map_location=f"cuda:{gpu_no}",
|
|
||||||
weights_only=False,
|
|
||||||
mmap=True)
|
|
||||||
model.eval()
|
|
||||||
print(f">>> Prepared model loaded.")
|
|
||||||
else:
|
|
||||||
# ---- Normal path: construct + load checkpoint ----
|
|
||||||
config['model']['params']['wma_config']['params'][
|
|
||||||
'use_checkpoint'] = False
|
|
||||||
model = instantiate_from_config(config.model)
|
|
||||||
model.perframe_ae = args.perframe_ae
|
|
||||||
|
|
||||||
assert os.path.exists(args.ckpt_path), "Error: checkpoint Not Found!"
|
# Compile hot ResBlocks for operator fusion
|
||||||
model = load_model_checkpoint(model, args.ckpt_path)
|
apply_torch_compile(model)
|
||||||
model.eval()
|
|
||||||
model = model.cuda(gpu_no)
|
|
||||||
print(f'>>> Load pre-trained model ...')
|
|
||||||
|
|
||||||
# Save prepared model for fast loading next time
|
|
||||||
print(f">>> Saving prepared model to {prepared_path} ...")
|
|
||||||
torch.save(model, prepared_path)
|
|
||||||
print(f">>> Prepared model saved ({os.path.getsize(prepared_path) / 1024**3:.1f} GB).")
|
|
||||||
|
|
||||||
# Build normalizer (always needed, independent of model loading path)
|
|
||||||
logging.info("***** Configing Data *****")
|
|
||||||
data = instantiate_from_config(config.data)
|
|
||||||
data.setup()
|
|
||||||
print(">>> Dataset is successfully loaded ...")
|
|
||||||
|
|
||||||
device = get_device_from_parameters(model)
|
|
||||||
|
|
||||||
# Fuse KV projections in attention layers (to_k + to_v → to_kv)
|
# Fuse KV projections in attention layers (to_k + to_v → to_kv)
|
||||||
from unifolm_wma.modules.attention import CrossAttention
|
from unifolm_wma.modules.attention import CrossAttention
|
||||||
@@ -585,10 +631,22 @@ def run_inference(args: argparse.Namespace, gpu_num: int, gpu_no: int) -> None:
|
|||||||
if isinstance(m, CrossAttention) and m.fuse_kv())
|
if isinstance(m, CrossAttention) and m.fuse_kv())
|
||||||
print(f" ✓ KV fused: {kv_count} attention layers")
|
print(f" ✓ KV fused: {kv_count} attention layers")
|
||||||
|
|
||||||
# Load TRT backbone if engine exists
|
# Export precision-converted checkpoint if requested
|
||||||
trt_engine_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), '..', '..', 'trt_engines', 'video_backbone.engine')
|
if args.export_precision_ckpt:
|
||||||
if os.path.exists(trt_engine_path):
|
export_path = args.export_precision_ckpt
|
||||||
model.model.diffusion_model.load_trt_backbone(trt_engine_path)
|
os.makedirs(os.path.dirname(export_path) or '.', exist_ok=True)
|
||||||
|
torch.save({"state_dict": model.state_dict()}, export_path)
|
||||||
|
print(f">>> Precision-converted checkpoint saved to: {export_path}")
|
||||||
|
return
|
||||||
|
|
||||||
|
# Build unnomalizer
|
||||||
|
logging.info("***** Configing Data *****")
|
||||||
|
data = instantiate_from_config(config.data)
|
||||||
|
data.setup()
|
||||||
|
print(">>> Dataset is successfully loaded ...")
|
||||||
|
|
||||||
|
model = model.cuda(gpu_no)
|
||||||
|
device = get_device_from_parameters(model)
|
||||||
|
|
||||||
# Run over data
|
# Run over data
|
||||||
assert (args.height % 16 == 0) and (
|
assert (args.height % 16 == 0) and (
|
||||||
@@ -762,31 +820,31 @@ def run_inference(args: argparse.Namespace, gpu_num: int, gpu_no: int) -> None:
|
|||||||
cond_obs_queues = populate_queues(cond_obs_queues,
|
cond_obs_queues = populate_queues(cond_obs_queues,
|
||||||
observation)
|
observation)
|
||||||
|
|
||||||
# Save the imagen videos for decision-making (async)
|
# Save the imagen videos for decision-making
|
||||||
if pred_videos_0 is not None:
|
if pred_videos_0 is not None:
|
||||||
sample_tag = f"{args.dataset}-vid{sample['videoid']}-dm-fs-{fs}/itr-{itr}"
|
sample_tag = f"{args.dataset}-vid{sample['videoid']}-dm-fs-{fs}/itr-{itr}"
|
||||||
log_to_tensorboard_async(writer,
|
log_to_tensorboard(writer,
|
||||||
pred_videos_0,
|
pred_videos_0,
|
||||||
sample_tag,
|
sample_tag,
|
||||||
fps=args.save_fps)
|
fps=args.save_fps)
|
||||||
# Save videos environment changes via world-model interaction
|
# Save videos environment changes via world-model interaction
|
||||||
sample_tag = f"{args.dataset}-vid{sample['videoid']}-wd-fs-{fs}/itr-{itr}"
|
sample_tag = f"{args.dataset}-vid{sample['videoid']}-wd-fs-{fs}/itr-{itr}"
|
||||||
log_to_tensorboard_async(writer,
|
log_to_tensorboard(writer,
|
||||||
pred_videos_1,
|
pred_videos_1,
|
||||||
sample_tag,
|
sample_tag,
|
||||||
fps=args.save_fps)
|
fps=args.save_fps)
|
||||||
|
|
||||||
# Save the imagen videos for decision-making
|
# Save the imagen videos for decision-making
|
||||||
if pred_videos_0 is not None:
|
if pred_videos_0 is not None:
|
||||||
sample_video_file = f'{video_save_dir}/dm/{fs}/itr-{itr}.mp4'
|
sample_video_file = f'{video_save_dir}/dm/{fs}/itr-{itr}.mp4'
|
||||||
save_results_async(pred_videos_0,
|
save_results(pred_videos_0.cpu(),
|
||||||
sample_video_file,
|
sample_video_file,
|
||||||
fps=args.save_fps)
|
fps=args.save_fps)
|
||||||
# Save videos environment changes via world-model interaction
|
# Save videos environment changes via world-model interaction
|
||||||
sample_video_file = f'{video_save_dir}/wm/{fs}/itr-{itr}.mp4'
|
sample_video_file = f'{video_save_dir}/wm/{fs}/itr-{itr}.mp4'
|
||||||
save_results_async(pred_videos_1,
|
save_results(pred_videos_1.cpu(),
|
||||||
sample_video_file,
|
sample_video_file,
|
||||||
fps=args.save_fps)
|
fps=args.save_fps)
|
||||||
|
|
||||||
print('>' * 24)
|
print('>' * 24)
|
||||||
# Collect the result of world-model interactions
|
# Collect the result of world-model interactions
|
||||||
@@ -794,15 +852,12 @@ def run_inference(args: argparse.Namespace, gpu_num: int, gpu_no: int) -> None:
|
|||||||
|
|
||||||
full_video = torch.cat(wm_video, dim=2)
|
full_video = torch.cat(wm_video, dim=2)
|
||||||
sample_tag = f"{args.dataset}-vid{sample['videoid']}-wd-fs-{fs}/full"
|
sample_tag = f"{args.dataset}-vid{sample['videoid']}-wd-fs-{fs}/full"
|
||||||
log_to_tensorboard_async(writer,
|
log_to_tensorboard(writer,
|
||||||
full_video,
|
full_video,
|
||||||
sample_tag,
|
sample_tag,
|
||||||
fps=args.save_fps)
|
fps=args.save_fps)
|
||||||
sample_full_video_file = f"{video_save_dir}/../{sample['videoid']}_full_fs{fs}.mp4"
|
sample_full_video_file = f"{video_save_dir}/../{sample['videoid']}_full_fs{fs}.mp4"
|
||||||
save_results_async(full_video, sample_full_video_file, fps=args.save_fps)
|
save_results(full_video, sample_full_video_file, fps=args.save_fps)
|
||||||
|
|
||||||
# Wait for all async I/O to complete
|
|
||||||
_flush_io()
|
|
||||||
|
|
||||||
|
|
||||||
def get_parser():
|
def get_parser():
|
||||||
@@ -926,10 +981,40 @@ def get_parser():
|
|||||||
type=int,
|
type=int,
|
||||||
default=8,
|
default=8,
|
||||||
help="fps for the saving video")
|
help="fps for the saving video")
|
||||||
|
parser.add_argument(
|
||||||
|
"--diffusion_dtype",
|
||||||
|
type=str,
|
||||||
|
choices=["fp32", "bf16"],
|
||||||
|
default="bf16",
|
||||||
|
help="Diffusion backbone precision (fp32/bf16)")
|
||||||
|
parser.add_argument(
|
||||||
|
"--projector_mode",
|
||||||
|
type=str,
|
||||||
|
choices=["fp32", "autocast", "bf16_full"],
|
||||||
|
default="bf16_full",
|
||||||
|
help="Projector precision mode (fp32/autocast/bf16_full)")
|
||||||
|
parser.add_argument(
|
||||||
|
"--encoder_mode",
|
||||||
|
type=str,
|
||||||
|
choices=["fp32", "autocast", "bf16_full"],
|
||||||
|
default="bf16_full",
|
||||||
|
help="Encoder precision mode (fp32/autocast/bf16_full)")
|
||||||
|
parser.add_argument(
|
||||||
|
"--vae_dtype",
|
||||||
|
type=str,
|
||||||
|
choices=["fp32", "bf16"],
|
||||||
|
default="fp32",
|
||||||
|
help="VAE precision (fp32/bf16, most affects image quality)")
|
||||||
|
parser.add_argument(
|
||||||
|
"--export_precision_ckpt",
|
||||||
|
type=str,
|
||||||
|
default=None,
|
||||||
|
help="Export precision-converted checkpoint to this path, then exit.")
|
||||||
return parser
|
return parser
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
|
patch_norm_bypass_autocast()
|
||||||
parser = get_parser()
|
parser = get_parser()
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
seed = args.seed
|
seed = args.seed
|
||||||
|
|||||||
@@ -1,87 +0,0 @@
|
|||||||
"""Export video UNet backbone to ONNX, then convert to TensorRT engine.
|
|
||||||
|
|
||||||
Usage:
|
|
||||||
python scripts/export_trt.py \
|
|
||||||
--ckpt ckpts/unifolm_wma_dual.ckpt.prepared.pt \
|
|
||||||
--config configs/inference/world_model_interaction.yaml \
|
|
||||||
--out_dir trt_engines
|
|
||||||
"""
|
|
||||||
|
|
||||||
import os
|
|
||||||
import sys
|
|
||||||
import argparse
|
|
||||||
|
|
||||||
import torch
|
|
||||||
import tensorrt as trt
|
|
||||||
from omegaconf import OmegaConf
|
|
||||||
|
|
||||||
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'src'))
|
|
||||||
from unifolm_wma.utils.utils import instantiate_from_config
|
|
||||||
from unifolm_wma.trt_utils import export_backbone_onnx
|
|
||||||
|
|
||||||
|
|
||||||
def load_model(config_path, ckpt_path):
|
|
||||||
if ckpt_path.endswith('.prepared.pt'):
|
|
||||||
model = torch.load(ckpt_path, map_location='cpu')
|
|
||||||
else:
|
|
||||||
config = OmegaConf.load(config_path)
|
|
||||||
model = instantiate_from_config(config.model)
|
|
||||||
state_dict = torch.load(ckpt_path, map_location='cpu')
|
|
||||||
if 'state_dict' in state_dict:
|
|
||||||
state_dict = state_dict['state_dict']
|
|
||||||
model.load_state_dict(state_dict, strict=False)
|
|
||||||
model.eval().cuda()
|
|
||||||
return model
|
|
||||||
|
|
||||||
|
|
||||||
def main():
|
|
||||||
parser = argparse.ArgumentParser()
|
|
||||||
parser.add_argument('--ckpt', required=True)
|
|
||||||
parser.add_argument('--config', default='configs/inference/world_model_interaction.yaml')
|
|
||||||
parser.add_argument('--out_dir', default='trt_engines')
|
|
||||||
parser.add_argument('--context_len', type=int, default=95)
|
|
||||||
parser.add_argument('--fp16', action='store_true', default=True)
|
|
||||||
args = parser.parse_args()
|
|
||||||
|
|
||||||
os.makedirs(args.out_dir, exist_ok=True)
|
|
||||||
onnx_path = os.path.join(args.out_dir, 'video_backbone.onnx')
|
|
||||||
engine_path = os.path.join(args.out_dir, 'video_backbone.engine')
|
|
||||||
|
|
||||||
if os.path.exists(onnx_path):
|
|
||||||
print(f">>> ONNX already exists at {onnx_path}, skipping export.")
|
|
||||||
n_outputs = 10
|
|
||||||
else:
|
|
||||||
print(">>> Loading model ...")
|
|
||||||
model = load_model(args.config, args.ckpt)
|
|
||||||
print(">>> Exporting ONNX ...")
|
|
||||||
with torch.no_grad():
|
|
||||||
n_outputs = export_backbone_onnx(model, onnx_path, context_len=args.context_len)
|
|
||||||
del model
|
|
||||||
torch.cuda.empty_cache()
|
|
||||||
|
|
||||||
print(">>> Converting ONNX -> TensorRT engine ...")
|
|
||||||
logger = trt.Logger(trt.Logger.WARNING)
|
|
||||||
builder = trt.Builder(logger)
|
|
||||||
network = builder.create_network(1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH))
|
|
||||||
parser = trt.OnnxParser(network, logger)
|
|
||||||
|
|
||||||
if not parser.parse_from_file(os.path.abspath(onnx_path)):
|
|
||||||
for i in range(parser.num_errors):
|
|
||||||
print(f" ONNX parse error: {parser.get_error(i)}")
|
|
||||||
raise RuntimeError("ONNX parsing failed")
|
|
||||||
|
|
||||||
config = builder.create_builder_config()
|
|
||||||
config.set_memory_pool_limit(trt.MemoryPoolType.WORKSPACE, 16 << 30)
|
|
||||||
if args.fp16:
|
|
||||||
config.set_flag(trt.BuilderFlag.FP16)
|
|
||||||
|
|
||||||
engine_bytes = builder.build_serialized_network(network, config)
|
|
||||||
with open(engine_path, 'wb') as f:
|
|
||||||
f.write(engine_bytes)
|
|
||||||
|
|
||||||
print(f"\n>>> Done! Engine saved to {engine_path}")
|
|
||||||
print(f" Outputs: 1 y + {n_outputs - 1} hs_a tensors")
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
|
||||||
main()
|
|
||||||
@@ -11,9 +11,6 @@ from unifolm_wma.utils.utils import instantiate_from_config
|
|||||||
from unifolm_wma.utils.train import get_trainer_callbacks, get_trainer_logger, get_trainer_strategy
|
from unifolm_wma.utils.train import get_trainer_callbacks, get_trainer_logger, get_trainer_strategy
|
||||||
from unifolm_wma.utils.train import set_logger, init_workspace, load_checkpoints, get_num_parameters
|
from unifolm_wma.utils.train import set_logger, init_workspace, load_checkpoints, get_num_parameters
|
||||||
|
|
||||||
torch.backends.cuda.matmul.allow_tf32 = True
|
|
||||||
torch.backends.cudnn.allow_tf32 = True
|
|
||||||
|
|
||||||
|
|
||||||
def get_parser(**parser_kwargs):
|
def get_parser(**parser_kwargs):
|
||||||
parser = argparse.ArgumentParser(**parser_kwargs)
|
parser = argparse.ArgumentParser(**parser_kwargs)
|
||||||
|
|||||||
@@ -1105,6 +1105,10 @@ class LatentDiffusion(DDPM):
|
|||||||
else:
|
else:
|
||||||
reshape_back = False
|
reshape_back = False
|
||||||
|
|
||||||
|
# Align input dtype with VAE weights (e.g. fp32 samples → bf16 VAE)
|
||||||
|
vae_dtype = next(self.first_stage_model.parameters()).dtype
|
||||||
|
z = z.to(dtype=vae_dtype)
|
||||||
|
|
||||||
if not self.perframe_ae:
|
if not self.perframe_ae:
|
||||||
z = 1. / self.scale_factor * z
|
z = 1. / self.scale_factor * z
|
||||||
results = self.first_stage_model.decode(z, **kwargs)
|
results = self.first_stage_model.decode(z, **kwargs)
|
||||||
@@ -1799,7 +1803,9 @@ class LatentDiffusion(DDPM):
|
|||||||
|
|
||||||
"""
|
"""
|
||||||
if ddim:
|
if ddim:
|
||||||
ddim_sampler = DDIMSampler(self)
|
if not hasattr(self, '_ddim_sampler') or self._ddim_sampler is None:
|
||||||
|
self._ddim_sampler = DDIMSampler(self)
|
||||||
|
ddim_sampler = self._ddim_sampler
|
||||||
shape = (self.channels, self.temporal_length, *self.image_size)
|
shape = (self.channels, self.temporal_length, *self.image_size)
|
||||||
samples, actions, states, intermediates = ddim_sampler.sample(
|
samples, actions, states, intermediates = ddim_sampler.sample(
|
||||||
ddim_steps, batch_size, shape, cond, verbose=False, **kwargs)
|
ddim_steps, batch_size, shape, cond, verbose=False, **kwargs)
|
||||||
@@ -2457,7 +2463,6 @@ class DiffusionWrapper(pl.LightningModule):
|
|||||||
Returns:
|
Returns:
|
||||||
Output from the inner diffusion model (tensor or tuple, depending on the model).
|
Output from the inner diffusion model (tensor or tuple, depending on the model).
|
||||||
"""
|
"""
|
||||||
|
|
||||||
if self.conditioning_key is None:
|
if self.conditioning_key is None:
|
||||||
out = self.diffusion_model(x, t)
|
out = self.diffusion_model(x, t)
|
||||||
elif self.conditioning_key == 'concat':
|
elif self.conditioning_key == 'concat':
|
||||||
|
|||||||
@@ -567,6 +567,11 @@ class ConditionalUnet1D(nn.Module):
|
|||||||
# Broadcast to batch dimension in a way that's compatible with ONNX/Core ML
|
# Broadcast to batch dimension in a way that's compatible with ONNX/Core ML
|
||||||
timesteps = timesteps.expand(sample.shape[0])
|
timesteps = timesteps.expand(sample.shape[0])
|
||||||
global_feature = self.diffusion_step_encoder(timesteps)
|
global_feature = self.diffusion_step_encoder(timesteps)
|
||||||
|
# Pre-expand global_feature once (reused in every down/mid/up block)
|
||||||
|
if self.use_linear_act_proj:
|
||||||
|
global_feature_expanded = global_feature.unsqueeze(1).expand(-1, T, -1)
|
||||||
|
else:
|
||||||
|
global_feature_expanded = global_feature.unsqueeze(1).expand(-1, 2, -1)
|
||||||
(imagen_cond_down, imagen_cond_mid, imagen_cond_up
|
(imagen_cond_down, imagen_cond_mid, imagen_cond_up
|
||||||
) = imagen_cond[0:4], imagen_cond[4], imagen_cond[5:] #NOTE HAND CODE
|
) = imagen_cond[0:4], imagen_cond[4], imagen_cond[5:] #NOTE HAND CODE
|
||||||
|
|
||||||
@@ -603,15 +608,11 @@ class ConditionalUnet1D(nn.Module):
|
|||||||
|
|
||||||
if self.use_linear_act_proj:
|
if self.use_linear_act_proj:
|
||||||
imagen_cond = imagen_cond.reshape(B, T, -1)
|
imagen_cond = imagen_cond.reshape(B, T, -1)
|
||||||
cur_global_feature = global_feature.unsqueeze(
|
|
||||||
1).repeat_interleave(repeats=T, dim=1)
|
|
||||||
else:
|
else:
|
||||||
imagen_cond = imagen_cond.permute(0, 3, 1, 2)
|
imagen_cond = imagen_cond.permute(0, 3, 1, 2)
|
||||||
imagen_cond = imagen_cond.reshape(B, 2, -1)
|
imagen_cond = imagen_cond.reshape(B, 2, -1)
|
||||||
cur_global_feature = global_feature.unsqueeze(
|
|
||||||
1).repeat_interleave(repeats=2, dim=1)
|
|
||||||
cur_global_feature = torch.cat(
|
cur_global_feature = torch.cat(
|
||||||
[cur_global_feature, global_cond, imagen_cond], axis=-1)
|
[global_feature_expanded, global_cond, imagen_cond], axis=-1)
|
||||||
x = resnet(x, cur_global_feature)
|
x = resnet(x, cur_global_feature)
|
||||||
x = resnet2(x, cur_global_feature)
|
x = resnet2(x, cur_global_feature)
|
||||||
h.append(x)
|
h.append(x)
|
||||||
@@ -638,15 +639,11 @@ class ConditionalUnet1D(nn.Module):
|
|||||||
imagen_cond = rearrange(imagen_cond, '(b t) c d -> b t c d', b=B)
|
imagen_cond = rearrange(imagen_cond, '(b t) c d -> b t c d', b=B)
|
||||||
if self.use_linear_act_proj:
|
if self.use_linear_act_proj:
|
||||||
imagen_cond = imagen_cond.reshape(B, T, -1)
|
imagen_cond = imagen_cond.reshape(B, T, -1)
|
||||||
cur_global_feature = global_feature.unsqueeze(1).repeat_interleave(
|
|
||||||
repeats=T, dim=1)
|
|
||||||
else:
|
else:
|
||||||
imagen_cond = imagen_cond.permute(0, 3, 1, 2)
|
imagen_cond = imagen_cond.permute(0, 3, 1, 2)
|
||||||
imagen_cond = imagen_cond.reshape(B, 2, -1)
|
imagen_cond = imagen_cond.reshape(B, 2, -1)
|
||||||
cur_global_feature = global_feature.unsqueeze(1).repeat_interleave(
|
|
||||||
repeats=2, dim=1)
|
|
||||||
cur_global_feature = torch.cat(
|
cur_global_feature = torch.cat(
|
||||||
[cur_global_feature, global_cond, imagen_cond], axis=-1)
|
[global_feature_expanded, global_cond, imagen_cond], axis=-1)
|
||||||
x = resnet(x, cur_global_feature)
|
x = resnet(x, cur_global_feature)
|
||||||
x = resnet2(x, cur_global_feature)
|
x = resnet2(x, cur_global_feature)
|
||||||
|
|
||||||
@@ -683,16 +680,12 @@ class ConditionalUnet1D(nn.Module):
|
|||||||
|
|
||||||
if self.use_linear_act_proj:
|
if self.use_linear_act_proj:
|
||||||
imagen_cond = imagen_cond.reshape(B, T, -1)
|
imagen_cond = imagen_cond.reshape(B, T, -1)
|
||||||
cur_global_feature = global_feature.unsqueeze(
|
|
||||||
1).repeat_interleave(repeats=T, dim=1)
|
|
||||||
else:
|
else:
|
||||||
imagen_cond = imagen_cond.permute(0, 3, 1, 2)
|
imagen_cond = imagen_cond.permute(0, 3, 1, 2)
|
||||||
imagen_cond = imagen_cond.reshape(B, 2, -1)
|
imagen_cond = imagen_cond.reshape(B, 2, -1)
|
||||||
cur_global_feature = global_feature.unsqueeze(
|
|
||||||
1).repeat_interleave(repeats=2, dim=1)
|
|
||||||
|
|
||||||
cur_global_feature = torch.cat(
|
cur_global_feature = torch.cat(
|
||||||
[cur_global_feature, global_cond, imagen_cond], axis=-1)
|
[global_feature_expanded, global_cond, imagen_cond], axis=-1)
|
||||||
|
|
||||||
x = torch.cat((x, h.pop()), dim=1)
|
x = torch.cat((x, h.pop()), dim=1)
|
||||||
x = resnet(x, cur_global_feature)
|
x = resnet(x, cur_global_feature)
|
||||||
|
|||||||
@@ -8,12 +8,14 @@ class SinusoidalPosEmb(nn.Module):
|
|||||||
def __init__(self, dim):
|
def __init__(self, dim):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.dim = dim
|
self.dim = dim
|
||||||
|
# Dummy buffer so .to(dtype) propagates to this module
|
||||||
|
self.register_buffer('_dtype_buf', torch.zeros(1), persistent=False)
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
device = x.device
|
device = x.device
|
||||||
half_dim = self.dim // 2
|
half_dim = self.dim // 2
|
||||||
emb = math.log(10000) / (half_dim - 1)
|
emb = math.log(10000) / (half_dim - 1)
|
||||||
emb = torch.exp(torch.arange(half_dim, device=device) * -emb)
|
emb = torch.exp(torch.arange(half_dim, device=device) * -emb)
|
||||||
emb = x[:, None] * emb[None, :]
|
emb = x.float()[:, None] * emb[None, :]
|
||||||
emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
|
emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
|
||||||
return emb
|
return emb.to(self._dtype_buf.dtype)
|
||||||
|
|||||||
@@ -18,6 +18,7 @@ class DDIMSampler(object):
|
|||||||
self.ddpm_num_timesteps = model.num_timesteps
|
self.ddpm_num_timesteps = model.num_timesteps
|
||||||
self.schedule = schedule
|
self.schedule = schedule
|
||||||
self.counter = 0
|
self.counter = 0
|
||||||
|
self._schedule_key = None # (ddim_num_steps, ddim_discretize, ddim_eta)
|
||||||
|
|
||||||
def register_buffer(self, name, attr):
|
def register_buffer(self, name, attr):
|
||||||
if type(attr) == torch.Tensor:
|
if type(attr) == torch.Tensor:
|
||||||
@@ -30,6 +31,11 @@ class DDIMSampler(object):
|
|||||||
ddim_discretize="uniform",
|
ddim_discretize="uniform",
|
||||||
ddim_eta=0.,
|
ddim_eta=0.,
|
||||||
verbose=True):
|
verbose=True):
|
||||||
|
key = (ddim_num_steps, ddim_discretize, ddim_eta)
|
||||||
|
if self._schedule_key == key:
|
||||||
|
return
|
||||||
|
self._schedule_key = key
|
||||||
|
|
||||||
self.ddim_timesteps = make_ddim_timesteps(
|
self.ddim_timesteps = make_ddim_timesteps(
|
||||||
ddim_discr_method=ddim_discretize,
|
ddim_discr_method=ddim_discretize,
|
||||||
num_ddim_timesteps=ddim_num_steps,
|
num_ddim_timesteps=ddim_num_steps,
|
||||||
@@ -38,7 +44,7 @@ class DDIMSampler(object):
|
|||||||
alphas_cumprod = self.model.alphas_cumprod
|
alphas_cumprod = self.model.alphas_cumprod
|
||||||
assert alphas_cumprod.shape[
|
assert alphas_cumprod.shape[
|
||||||
0] == self.ddpm_num_timesteps, 'alphas have to be defined for each timestep'
|
0] == self.ddpm_num_timesteps, 'alphas have to be defined for each timestep'
|
||||||
to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.model
|
to_torch = lambda x: x.clone().detach().to(torch.float64).to(self.model
|
||||||
.device)
|
.device)
|
||||||
|
|
||||||
if self.model.use_dynamic_rescale:
|
if self.model.use_dynamic_rescale:
|
||||||
@@ -211,9 +217,9 @@ class DDIMSampler(object):
|
|||||||
|
|
||||||
if precision is not None:
|
if precision is not None:
|
||||||
if precision == 16:
|
if precision == 16:
|
||||||
img = img.to(dtype=torch.float16)
|
img = img.to(dtype=torch.bfloat16)
|
||||||
action = action.to(dtype=torch.float16)
|
action = action.to(dtype=torch.bfloat16)
|
||||||
state = state.to(dtype=torch.float16)
|
state = state.to(dtype=torch.bfloat16)
|
||||||
|
|
||||||
if timesteps is None:
|
if timesteps is None:
|
||||||
timesteps = self.ddpm_num_timesteps if ddim_use_original_steps else self.ddim_timesteps
|
timesteps = self.ddpm_num_timesteps if ddim_use_original_steps else self.ddim_timesteps
|
||||||
@@ -245,6 +251,13 @@ class DDIMSampler(object):
|
|||||||
dp_ddim_scheduler_action.set_timesteps(len(timesteps))
|
dp_ddim_scheduler_action.set_timesteps(len(timesteps))
|
||||||
dp_ddim_scheduler_state.set_timesteps(len(timesteps))
|
dp_ddim_scheduler_state.set_timesteps(len(timesteps))
|
||||||
ts = torch.empty((b, ), device=device, dtype=torch.long)
|
ts = torch.empty((b, ), device=device, dtype=torch.long)
|
||||||
|
noise_buf = torch.empty_like(img)
|
||||||
|
# Pre-convert schedule arrays to inference dtype (avoid per-step .to())
|
||||||
|
_dtype = img.dtype
|
||||||
|
_alphas = (self.model.alphas_cumprod if ddim_use_original_steps else self.ddim_alphas).to(_dtype)
|
||||||
|
_alphas_prev = (self.model.alphas_cumprod_prev if ddim_use_original_steps else self.ddim_alphas_prev).to(_dtype)
|
||||||
|
_sqrt_one_minus = (self.model.sqrt_one_minus_alphas_cumprod if ddim_use_original_steps else self.ddim_sqrt_one_minus_alphas).to(_dtype)
|
||||||
|
_sigmas = (self.ddim_sigmas_for_original_num_steps if ddim_use_original_steps else self.ddim_sigmas).to(_dtype)
|
||||||
enable_cross_attn_kv_cache(self.model)
|
enable_cross_attn_kv_cache(self.model)
|
||||||
enable_ctx_cache(self.model)
|
enable_ctx_cache(self.model)
|
||||||
try:
|
try:
|
||||||
@@ -280,6 +293,8 @@ class DDIMSampler(object):
|
|||||||
x0=x0,
|
x0=x0,
|
||||||
fs=fs,
|
fs=fs,
|
||||||
guidance_rescale=guidance_rescale,
|
guidance_rescale=guidance_rescale,
|
||||||
|
noise_buf=noise_buf,
|
||||||
|
schedule_arrays=(_alphas, _alphas_prev, _sqrt_one_minus, _sigmas),
|
||||||
**kwargs)
|
**kwargs)
|
||||||
|
|
||||||
img, pred_x0, model_output_action, model_output_state = outs
|
img, pred_x0, model_output_action, model_output_state = outs
|
||||||
@@ -333,6 +348,8 @@ class DDIMSampler(object):
|
|||||||
mask=None,
|
mask=None,
|
||||||
x0=None,
|
x0=None,
|
||||||
guidance_rescale=0.0,
|
guidance_rescale=0.0,
|
||||||
|
noise_buf=None,
|
||||||
|
schedule_arrays=None,
|
||||||
**kwargs):
|
**kwargs):
|
||||||
b, *_, device = *x.shape, x.device
|
b, *_, device = *x.shape, x.device
|
||||||
|
|
||||||
@@ -378,12 +395,14 @@ class DDIMSampler(object):
|
|||||||
e_t = score_corrector.modify_score(self.model, e_t, x, t, c,
|
e_t = score_corrector.modify_score(self.model, e_t, x, t, c,
|
||||||
**corrector_kwargs)
|
**corrector_kwargs)
|
||||||
|
|
||||||
alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas
|
if schedule_arrays is not None:
|
||||||
alphas_prev = self.model.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prev
|
alphas, alphas_prev, sqrt_one_minus_alphas, sigmas = schedule_arrays
|
||||||
sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas
|
else:
|
||||||
sigmas = self.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas
|
alphas = (self.model.alphas_cumprod if use_original_steps else self.ddim_alphas).to(x.dtype)
|
||||||
|
alphas_prev = (self.model.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prev).to(x.dtype)
|
||||||
|
sqrt_one_minus_alphas = (self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas).to(x.dtype)
|
||||||
|
sigmas = (self.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas).to(x.dtype)
|
||||||
|
|
||||||
# Use 0-d tensors directly (already on device); broadcasting handles shape
|
|
||||||
a_t = alphas[index]
|
a_t = alphas[index]
|
||||||
a_prev = alphas_prev[index]
|
a_prev = alphas_prev[index]
|
||||||
sigma_t = sigmas[index]
|
sigma_t = sigmas[index]
|
||||||
@@ -405,8 +424,12 @@ class DDIMSampler(object):
|
|||||||
|
|
||||||
dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t
|
dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t
|
||||||
|
|
||||||
noise = sigma_t * noise_like(x.shape, device,
|
if noise_buf is not None:
|
||||||
repeat_noise) * temperature
|
noise_buf.normal_()
|
||||||
|
noise = sigma_t * noise_buf * temperature
|
||||||
|
else:
|
||||||
|
noise = sigma_t * noise_like(x.shape, device,
|
||||||
|
repeat_noise) * temperature
|
||||||
if noise_dropout > 0.:
|
if noise_dropout > 0.:
|
||||||
noise = torch.nn.functional.dropout(noise, p=noise_dropout)
|
noise = torch.nn.functional.dropout(noise, p=noise_dropout)
|
||||||
|
|
||||||
|
|||||||
@@ -86,9 +86,8 @@ class CrossAttention(nn.Module):
|
|||||||
self.relative_position_v = RelativePosition(
|
self.relative_position_v = RelativePosition(
|
||||||
num_units=dim_head, max_relative_position=temporal_length)
|
num_units=dim_head, max_relative_position=temporal_length)
|
||||||
else:
|
else:
|
||||||
## only used for spatial attention, while NOT for temporal attention
|
## bmm fused-scale attention for all non-relative-position cases
|
||||||
if XFORMERS_IS_AVAILBLE and temporal_length is None:
|
self.forward = self.bmm_forward
|
||||||
self.forward = self.efficient_forward
|
|
||||||
|
|
||||||
self.video_length = video_length
|
self.video_length = video_length
|
||||||
self.image_cross_attention = image_cross_attention
|
self.image_cross_attention = image_cross_attention
|
||||||
@@ -150,7 +149,141 @@ class CrossAttention(nn.Module):
|
|||||||
context = default(context, x)
|
context = default(context, x)
|
||||||
|
|
||||||
if self.image_cross_attention and not spatial_self_attn:
|
if self.image_cross_attention and not spatial_self_attn:
|
||||||
assert 1 > 2, ">>> ERROR: should setup xformers and use efficient_forward ..."
|
# assert 1 > 2, ">>> ERROR: should setup xformers and use efficient_forward ..."
|
||||||
|
context_agent_state = context[:, :self.agent_state_context_len, :]
|
||||||
|
context_agent_action = context[:,
|
||||||
|
self.agent_state_context_len:self.
|
||||||
|
agent_state_context_len +
|
||||||
|
self.agent_action_context_len, :]
|
||||||
|
context_ins = context[:, self.agent_state_context_len +
|
||||||
|
self.agent_action_context_len:self.
|
||||||
|
agent_state_context_len +
|
||||||
|
self.agent_action_context_len +
|
||||||
|
self.text_context_len, :]
|
||||||
|
context_image = context[:, self.agent_state_context_len +
|
||||||
|
self.agent_action_context_len +
|
||||||
|
self.text_context_len:, :]
|
||||||
|
|
||||||
|
k = self.to_k(context_ins)
|
||||||
|
v = self.to_v(context_ins)
|
||||||
|
k_ip = self.to_k_ip(context_image)
|
||||||
|
v_ip = self.to_v_ip(context_image)
|
||||||
|
k_as = self.to_k_as(context_agent_state)
|
||||||
|
v_as = self.to_v_as(context_agent_state)
|
||||||
|
k_aa = self.to_k_aa(context_agent_action)
|
||||||
|
v_aa = self.to_v_aa(context_agent_action)
|
||||||
|
else:
|
||||||
|
if not spatial_self_attn:
|
||||||
|
context = context[:, :self.text_context_len, :]
|
||||||
|
k = self.to_k(context)
|
||||||
|
v = self.to_v(context)
|
||||||
|
|
||||||
|
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h),
|
||||||
|
(q, k, v))
|
||||||
|
|
||||||
|
sim = torch.einsum('b i d, b j d -> b i j', q, k) * self.scale
|
||||||
|
if self.relative_position:
|
||||||
|
len_q, len_k, len_v = q.shape[1], k.shape[1], v.shape[1]
|
||||||
|
k2 = self.relative_position_k(len_q, len_k)
|
||||||
|
sim2 = einsum('b t d, t s d -> b t s', q,
|
||||||
|
k2) * self.scale # TODO check
|
||||||
|
sim += sim2
|
||||||
|
del k
|
||||||
|
|
||||||
|
if exists(mask):
|
||||||
|
## feasible for causal attention mask only
|
||||||
|
max_neg_value = -torch.finfo(sim.dtype).max
|
||||||
|
mask = repeat(mask, 'b i j -> (b h) i j', h=h)
|
||||||
|
sim.masked_fill_(~(mask > 0.5), max_neg_value)
|
||||||
|
|
||||||
|
# attention, what we cannot get enough of
|
||||||
|
with torch.amp.autocast('cuda', enabled=False):
|
||||||
|
sim = sim.softmax(dim=-1)
|
||||||
|
|
||||||
|
out = torch.einsum('b i j, b j d -> b i d', sim, v)
|
||||||
|
if self.relative_position:
|
||||||
|
v2 = self.relative_position_v(len_q, len_v)
|
||||||
|
out2 = einsum('b t s, t s d -> b t d', sim, v2) # TODO check
|
||||||
|
out += out2
|
||||||
|
out = rearrange(out, '(b h) n d -> b n (h d)', h=h)
|
||||||
|
|
||||||
|
if k_ip is not None and k_as is not None and k_aa is not None:
|
||||||
|
## for image cross-attention
|
||||||
|
k_ip, v_ip = map(
|
||||||
|
lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h),
|
||||||
|
(k_ip, v_ip))
|
||||||
|
sim_ip = torch.einsum('b i d, b j d -> b i j', q,
|
||||||
|
k_ip) * self.scale
|
||||||
|
del k_ip
|
||||||
|
with torch.amp.autocast('cuda', enabled=False):
|
||||||
|
sim_ip = sim_ip.softmax(dim=-1)
|
||||||
|
out_ip = torch.einsum('b i j, b j d -> b i d', sim_ip, v_ip)
|
||||||
|
out_ip = rearrange(out_ip, '(b h) n d -> b n (h d)', h=h)
|
||||||
|
|
||||||
|
## for agent state cross-attention
|
||||||
|
k_as, v_as = map(
|
||||||
|
lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h),
|
||||||
|
(k_as, v_as))
|
||||||
|
sim_as = torch.einsum('b i d, b j d -> b i j', q,
|
||||||
|
k_as) * self.scale
|
||||||
|
del k_as
|
||||||
|
with torch.amp.autocast('cuda', enabled=False):
|
||||||
|
sim_as = sim_as.softmax(dim=-1)
|
||||||
|
out_as = torch.einsum('b i j, b j d -> b i d', sim_as, v_as)
|
||||||
|
out_as = rearrange(out_as, '(b h) n d -> b n (h d)', h=h)
|
||||||
|
|
||||||
|
## for agent action cross-attention
|
||||||
|
k_aa, v_aa = map(
|
||||||
|
lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h),
|
||||||
|
(k_aa, v_aa))
|
||||||
|
sim_aa = torch.einsum('b i d, b j d -> b i j', q,
|
||||||
|
k_aa) * self.scale
|
||||||
|
del k_aa
|
||||||
|
with torch.amp.autocast('cuda', enabled=False):
|
||||||
|
sim_aa = sim_aa.softmax(dim=-1)
|
||||||
|
out_aa = torch.einsum('b i j, b j d -> b i d', sim_aa, v_aa)
|
||||||
|
out_aa = rearrange(out_aa, '(b h) n d -> b n (h d)', h=h)
|
||||||
|
|
||||||
|
if out_ip is not None and out_as is not None and out_aa is not None:
|
||||||
|
if self.cross_attention_scale_learnable:
|
||||||
|
out = out + \
|
||||||
|
self.image_cross_attention_scale * out_ip * (torch.tanh(self.alpha_ctx) + 1) + \
|
||||||
|
self.agent_state_cross_attention_scale * out_as * (torch.tanh(self.alpha_cas) + 1) + \
|
||||||
|
self.agent_action_cross_attention_scale * out_aa * (torch.tanh(self.alpha_caa) + 1)
|
||||||
|
else:
|
||||||
|
out = out + \
|
||||||
|
self.image_cross_attention_scale * out_ip + \
|
||||||
|
self.agent_state_cross_attention_scale * out_as + \
|
||||||
|
self.agent_action_cross_attention_scale * out_aa
|
||||||
|
|
||||||
|
return self.to_out(out)
|
||||||
|
|
||||||
|
def bmm_forward(self, x, context=None, mask=None):
|
||||||
|
spatial_self_attn = (context is None)
|
||||||
|
k_ip, v_ip, out_ip = None, None, None
|
||||||
|
k_as, v_as, out_as = None, None, None
|
||||||
|
k_aa, v_aa, out_aa = None, None, None
|
||||||
|
|
||||||
|
h = self.heads
|
||||||
|
q = self.to_q(x)
|
||||||
|
context = default(context, x)
|
||||||
|
|
||||||
|
use_cache = self._kv_cache_enabled and not spatial_self_attn
|
||||||
|
cache_hit = use_cache and len(self._kv_cache) > 0
|
||||||
|
|
||||||
|
if cache_hit:
|
||||||
|
# Reuse cached K/V (already in (b*h, n, d) shape)
|
||||||
|
k = self._kv_cache['k']
|
||||||
|
v = self._kv_cache['v']
|
||||||
|
if 'k_ip' in self._kv_cache:
|
||||||
|
k_ip = self._kv_cache['k_ip']
|
||||||
|
v_ip = self._kv_cache['v_ip']
|
||||||
|
k_as = self._kv_cache['k_as']
|
||||||
|
v_as = self._kv_cache['v_as']
|
||||||
|
k_aa = self._kv_cache['k_aa']
|
||||||
|
v_aa = self._kv_cache['v_aa']
|
||||||
|
q = rearrange(q, 'b n (h d) -> (b h) n d', h=h)
|
||||||
|
elif self.image_cross_attention and not spatial_self_attn:
|
||||||
context_agent_state = context[:, :self.agent_state_context_len, :]
|
context_agent_state = context[:, :self.agent_state_context_len, :]
|
||||||
context_agent_action = context[:,
|
context_agent_action = context[:,
|
||||||
self.agent_state_context_len:self.
|
self.agent_state_context_len:self.
|
||||||
@@ -179,6 +312,23 @@ class CrossAttention(nn.Module):
|
|||||||
v_as = self.to_v_as(context_agent_state)
|
v_as = self.to_v_as(context_agent_state)
|
||||||
k_aa = self.to_k_aa(context_agent_action)
|
k_aa = self.to_k_aa(context_agent_action)
|
||||||
v_aa = self.to_v_aa(context_agent_action)
|
v_aa = self.to_v_aa(context_agent_action)
|
||||||
|
|
||||||
|
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h),
|
||||||
|
(q, k, v))
|
||||||
|
k_ip, v_ip = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h),
|
||||||
|
(k_ip, v_ip))
|
||||||
|
k_as, v_as = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h),
|
||||||
|
(k_as, v_as))
|
||||||
|
k_aa, v_aa = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h),
|
||||||
|
(k_aa, v_aa))
|
||||||
|
|
||||||
|
if use_cache:
|
||||||
|
self._kv_cache = {
|
||||||
|
'k': k, 'v': v,
|
||||||
|
'k_ip': k_ip, 'v_ip': v_ip,
|
||||||
|
'k_as': k_as, 'v_as': v_as,
|
||||||
|
'k_aa': k_aa, 'v_aa': v_aa,
|
||||||
|
}
|
||||||
else:
|
else:
|
||||||
if not spatial_self_attn:
|
if not spatial_self_attn:
|
||||||
context = context[:, :self.text_context_len, :]
|
context = context[:, :self.text_context_len, :]
|
||||||
@@ -188,66 +338,54 @@ class CrossAttention(nn.Module):
|
|||||||
k = self.to_k(context)
|
k = self.to_k(context)
|
||||||
v = self.to_v(context)
|
v = self.to_v(context)
|
||||||
|
|
||||||
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h),
|
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h),
|
||||||
(q, k, v))
|
(q, k, v))
|
||||||
|
|
||||||
sim = torch.einsum('b i d, b j d -> b i j', q, k) * self.scale
|
if use_cache:
|
||||||
if self.relative_position:
|
self._kv_cache = {'k': k, 'v': v}
|
||||||
len_q, len_k, len_v = q.shape[1], k.shape[1], v.shape[1]
|
|
||||||
k2 = self.relative_position_k(len_q, len_k)
|
# baddbmm: fuse scale into GEMM → one kernel instead of matmul + mul
|
||||||
sim2 = einsum('b t d, t s d -> b t s', q,
|
sim = torch.baddbmm(
|
||||||
k2) * self.scale # TODO check
|
torch.empty(q.shape[0], q.shape[1], k.shape[1], dtype=q.dtype, device=q.device),
|
||||||
sim += sim2
|
q, k.transpose(-1, -2), beta=0, alpha=self.scale)
|
||||||
del k
|
|
||||||
|
|
||||||
if exists(mask):
|
if exists(mask):
|
||||||
## feasible for causal attention mask only
|
|
||||||
max_neg_value = -torch.finfo(sim.dtype).max
|
max_neg_value = -torch.finfo(sim.dtype).max
|
||||||
mask = repeat(mask, 'b i j -> (b h) i j', h=h)
|
mask = repeat(mask, 'b i j -> (b h) i j', h=h)
|
||||||
sim.masked_fill_(~(mask > 0.5), max_neg_value)
|
sim.masked_fill_(~(mask > 0.5), max_neg_value)
|
||||||
|
|
||||||
# attention, what we cannot get enough of
|
with torch.amp.autocast('cuda', enabled=False):
|
||||||
sim = sim.softmax(dim=-1)
|
sim = sim.softmax(dim=-1)
|
||||||
|
|
||||||
out = torch.einsum('b i j, b j d -> b i d', sim, v)
|
out = torch.bmm(sim, v)
|
||||||
if self.relative_position:
|
|
||||||
v2 = self.relative_position_v(len_q, len_v)
|
|
||||||
out2 = einsum('b t s, t s d -> b t d', sim, v2) # TODO check
|
|
||||||
out += out2
|
|
||||||
out = rearrange(out, '(b h) n d -> b n (h d)', h=h)
|
out = rearrange(out, '(b h) n d -> b n (h d)', h=h)
|
||||||
|
|
||||||
if k_ip is not None and k_as is not None and k_aa is not None:
|
if k_ip is not None and k_as is not None and k_aa is not None:
|
||||||
## for image cross-attention
|
## image cross-attention (k_ip/v_ip already in (b*h, n, d) shape)
|
||||||
k_ip, v_ip = map(
|
sim_ip = torch.baddbmm(
|
||||||
lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h),
|
torch.empty(q.shape[0], q.shape[1], k_ip.shape[1], dtype=q.dtype, device=q.device),
|
||||||
(k_ip, v_ip))
|
q, k_ip.transpose(-1, -2), beta=0, alpha=self.scale)
|
||||||
sim_ip = torch.einsum('b i d, b j d -> b i j', q,
|
with torch.amp.autocast('cuda', enabled=False):
|
||||||
k_ip) * self.scale
|
sim_ip = sim_ip.softmax(dim=-1)
|
||||||
del k_ip
|
out_ip = torch.bmm(sim_ip, v_ip)
|
||||||
sim_ip = sim_ip.softmax(dim=-1)
|
|
||||||
out_ip = torch.einsum('b i j, b j d -> b i d', sim_ip, v_ip)
|
|
||||||
out_ip = rearrange(out_ip, '(b h) n d -> b n (h d)', h=h)
|
out_ip = rearrange(out_ip, '(b h) n d -> b n (h d)', h=h)
|
||||||
|
|
||||||
## for agent state cross-attention
|
## agent state cross-attention (k_as/v_as already in (b*h, n, d) shape)
|
||||||
k_as, v_as = map(
|
sim_as = torch.baddbmm(
|
||||||
lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h),
|
torch.empty(q.shape[0], q.shape[1], k_as.shape[1], dtype=q.dtype, device=q.device),
|
||||||
(k_as, v_as))
|
q, k_as.transpose(-1, -2), beta=0, alpha=self.scale)
|
||||||
sim_as = torch.einsum('b i d, b j d -> b i j', q,
|
with torch.amp.autocast('cuda', enabled=False):
|
||||||
k_as) * self.scale
|
sim_as = sim_as.softmax(dim=-1)
|
||||||
del k_as
|
out_as = torch.bmm(sim_as, v_as)
|
||||||
sim_as = sim_as.softmax(dim=-1)
|
|
||||||
out_as = torch.einsum('b i j, b j d -> b i d', sim_as, v_as)
|
|
||||||
out_as = rearrange(out_as, '(b h) n d -> b n (h d)', h=h)
|
out_as = rearrange(out_as, '(b h) n d -> b n (h d)', h=h)
|
||||||
|
|
||||||
## for agent action cross-attention
|
## agent action cross-attention (k_aa/v_aa already in (b*h, n, d) shape)
|
||||||
k_aa, v_aa = map(
|
sim_aa = torch.baddbmm(
|
||||||
lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h),
|
torch.empty(q.shape[0], q.shape[1], k_aa.shape[1], dtype=q.dtype, device=q.device),
|
||||||
(k_aa, v_aa))
|
q, k_aa.transpose(-1, -2), beta=0, alpha=self.scale)
|
||||||
sim_aa = torch.einsum('b i d, b j d -> b i j', q,
|
with torch.amp.autocast('cuda', enabled=False):
|
||||||
k_aa) * self.scale
|
sim_aa = sim_aa.softmax(dim=-1)
|
||||||
del k_aa
|
out_aa = torch.bmm(sim_aa, v_aa)
|
||||||
sim_aa = sim_aa.softmax(dim=-1)
|
|
||||||
out_aa = torch.einsum('b i j, b j d -> b i d', sim_aa, v_aa)
|
|
||||||
out_aa = rearrange(out_aa, '(b h) n d -> b n (h d)', h=h)
|
out_aa = rearrange(out_aa, '(b h) n d -> b n (h d)', h=h)
|
||||||
|
|
||||||
if out_ip is not None and out_as is not None and out_aa is not None:
|
if out_ip is not None and out_as is not None and out_aa is not None:
|
||||||
@@ -270,162 +408,135 @@ class CrossAttention(nn.Module):
|
|||||||
k_ip, v_ip, out_ip = None, None, None
|
k_ip, v_ip, out_ip = None, None, None
|
||||||
k_as, v_as, out_as = None, None, None
|
k_as, v_as, out_as = None, None, None
|
||||||
k_aa, v_aa, out_aa = None, None, None
|
k_aa, v_aa, out_aa = None, None, None
|
||||||
attn_mask_aa = None
|
|
||||||
|
|
||||||
h = self.heads
|
|
||||||
q = self.to_q(x)
|
q = self.to_q(x)
|
||||||
context = default(context, x)
|
context = default(context, x)
|
||||||
|
|
||||||
b, _, _ = q.shape
|
if self.image_cross_attention and not spatial_self_attn:
|
||||||
q = q.unsqueeze(3).reshape(b, q.shape[1], h, self.dim_head).permute(0, 2, 1, 3).reshape(b * h, q.shape[1], self.dim_head).contiguous()
|
|
||||||
|
|
||||||
def _reshape_kv(t):
|
|
||||||
return t.unsqueeze(3).reshape(b, t.shape[1], h, self.dim_head).permute(0, 2, 1, 3).reshape(b * h, t.shape[1], self.dim_head).contiguous()
|
|
||||||
|
|
||||||
use_cache = self._kv_cache_enabled and not spatial_self_attn
|
|
||||||
cache_hit = use_cache and len(self._kv_cache) > 0
|
|
||||||
|
|
||||||
if cache_hit:
|
|
||||||
k = self._kv_cache['k']
|
|
||||||
v = self._kv_cache['v']
|
|
||||||
k_ip = self._kv_cache.get('k_ip')
|
|
||||||
v_ip = self._kv_cache.get('v_ip')
|
|
||||||
k_as = self._kv_cache.get('k_as')
|
|
||||||
v_as = self._kv_cache.get('v_as')
|
|
||||||
k_aa = self._kv_cache.get('k_aa')
|
|
||||||
v_aa = self._kv_cache.get('v_aa')
|
|
||||||
attn_mask_aa = self._kv_cache.get('attn_mask_aa')
|
|
||||||
elif self.image_cross_attention and not spatial_self_attn:
|
|
||||||
if context.shape[1] == self.text_context_len + self.video_length:
|
if context.shape[1] == self.text_context_len + self.video_length:
|
||||||
context_ins, context_image = context[:, :self.text_context_len, :], context[:,self.text_context_len:, :]
|
context_ins, context_image = context[:, :self.text_context_len, :], context[:,self.text_context_len:, :]
|
||||||
if self._kv_fused:
|
k = self.to_k(context)
|
||||||
k, v = self.to_kv(context).chunk(2, dim=-1)
|
v = self.to_v(context)
|
||||||
k_ip, v_ip = self.to_kv_ip(context_image).chunk(2, dim=-1)
|
k_ip = self.to_k_ip(context_image)
|
||||||
else:
|
v_ip = self.to_v_ip(context_image)
|
||||||
k = self.to_k(context)
|
|
||||||
v = self.to_v(context)
|
|
||||||
k_ip = self.to_k_ip(context_image)
|
|
||||||
v_ip = self.to_v_ip(context_image)
|
|
||||||
k, v = map(_reshape_kv, (k, v))
|
|
||||||
k_ip, v_ip = map(_reshape_kv, (k_ip, v_ip))
|
|
||||||
if use_cache:
|
|
||||||
self._kv_cache = {'k': k, 'v': v, 'k_ip': k_ip, 'v_ip': v_ip}
|
|
||||||
elif context.shape[1] == self.agent_state_context_len + self.text_context_len + self.video_length:
|
elif context.shape[1] == self.agent_state_context_len + self.text_context_len + self.video_length:
|
||||||
context_agent_state = context[:, :self.agent_state_context_len, :]
|
context_agent_state = context[:, :self.agent_state_context_len, :]
|
||||||
context_ins = context[:, self.agent_state_context_len:self.agent_state_context_len+self.text_context_len, :]
|
context_ins = context[:, self.agent_state_context_len:self.agent_state_context_len+self.text_context_len, :]
|
||||||
context_image = context[:, self.agent_state_context_len+self.text_context_len:, :]
|
context_image = context[:, self.agent_state_context_len+self.text_context_len:, :]
|
||||||
if self._kv_fused:
|
k = self.to_k(context_ins)
|
||||||
k, v = self.to_kv(context_ins).chunk(2, dim=-1)
|
v = self.to_v(context_ins)
|
||||||
k_ip, v_ip = self.to_kv_ip(context_image).chunk(2, dim=-1)
|
k_ip = self.to_k_ip(context_image)
|
||||||
k_as, v_as = self.to_kv_as(context_agent_state).chunk(2, dim=-1)
|
v_ip = self.to_v_ip(context_image)
|
||||||
else:
|
k_as = self.to_k_as(context_agent_state)
|
||||||
k = self.to_k(context_ins)
|
v_as = self.to_v_as(context_agent_state)
|
||||||
v = self.to_v(context_ins)
|
|
||||||
k_ip = self.to_k_ip(context_image)
|
|
||||||
v_ip = self.to_v_ip(context_image)
|
|
||||||
k_as = self.to_k_as(context_agent_state)
|
|
||||||
v_as = self.to_v_as(context_agent_state)
|
|
||||||
k, v = map(_reshape_kv, (k, v))
|
|
||||||
k_ip, v_ip = map(_reshape_kv, (k_ip, v_ip))
|
|
||||||
k_as, v_as = map(_reshape_kv, (k_as, v_as))
|
|
||||||
if use_cache:
|
|
||||||
self._kv_cache = {'k': k, 'v': v, 'k_ip': k_ip, 'v_ip': v_ip, 'k_as': k_as, 'v_as': v_as}
|
|
||||||
else:
|
else:
|
||||||
context_agent_state = context[:, :self.agent_state_context_len, :]
|
context_agent_state = context[:, :self.agent_state_context_len, :]
|
||||||
context_agent_action = context[:, self.agent_state_context_len:self.agent_state_context_len+self.agent_action_context_len, :]
|
context_agent_action = context[:, self.agent_state_context_len:self.agent_state_context_len+self.agent_action_context_len, :]
|
||||||
context_ins = context[:, self.agent_state_context_len+self.agent_action_context_len:self.agent_state_context_len+self.agent_action_context_len+self.text_context_len, :]
|
context_ins = context[:, self.agent_state_context_len+self.agent_action_context_len:self.agent_state_context_len+self.agent_action_context_len+self.text_context_len, :]
|
||||||
context_image = context[:, self.agent_state_context_len+self.agent_action_context_len+self.text_context_len:, :]
|
context_image = context[:, self.agent_state_context_len+self.agent_action_context_len+self.text_context_len:, :]
|
||||||
|
|
||||||
if self._kv_fused:
|
k = self.to_k(context_ins)
|
||||||
k, v = self.to_kv(context_ins).chunk(2, dim=-1)
|
v = self.to_v(context_ins)
|
||||||
k_ip, v_ip = self.to_kv_ip(context_image).chunk(2, dim=-1)
|
k_ip = self.to_k_ip(context_image)
|
||||||
k_as, v_as = self.to_kv_as(context_agent_state).chunk(2, dim=-1)
|
v_ip = self.to_v_ip(context_image)
|
||||||
k_aa, v_aa = self.to_kv_aa(context_agent_action).chunk(2, dim=-1)
|
k_as = self.to_k_as(context_agent_state)
|
||||||
else:
|
v_as = self.to_v_as(context_agent_state)
|
||||||
k = self.to_k(context_ins)
|
k_aa = self.to_k_aa(context_agent_action)
|
||||||
v = self.to_v(context_ins)
|
v_aa = self.to_v_aa(context_agent_action)
|
||||||
k_ip = self.to_k_ip(context_image)
|
|
||||||
v_ip = self.to_v_ip(context_image)
|
|
||||||
k_as = self.to_k_as(context_agent_state)
|
|
||||||
v_as = self.to_v_as(context_agent_state)
|
|
||||||
k_aa = self.to_k_aa(context_agent_action)
|
|
||||||
v_aa = self.to_v_aa(context_agent_action)
|
|
||||||
|
|
||||||
k, v = map(_reshape_kv, (k, v))
|
attn_mask_aa = self._get_attn_mask_aa(x.shape[0],
|
||||||
k_ip, v_ip = map(_reshape_kv, (k_ip, v_ip))
|
q.shape[1],
|
||||||
k_as, v_as = map(_reshape_kv, (k_as, v_as))
|
k_aa.shape[1],
|
||||||
k_aa, v_aa = map(_reshape_kv, (k_aa, v_aa))
|
block_size=16,
|
||||||
|
device=k_aa.device)
|
||||||
attn_mask_aa_raw = self._get_attn_mask_aa(x.shape[0],
|
|
||||||
q.shape[1],
|
|
||||||
k_aa.shape[1],
|
|
||||||
block_size=16,
|
|
||||||
device=k_aa.device)
|
|
||||||
attn_mask_aa = attn_mask_aa_raw.unsqueeze(1).repeat(1, h, 1, 1).reshape(
|
|
||||||
b * h, attn_mask_aa_raw.shape[1], attn_mask_aa_raw.shape[2]).to(q.dtype)
|
|
||||||
|
|
||||||
if use_cache:
|
|
||||||
self._kv_cache = {
|
|
||||||
'k': k, 'v': v, 'k_ip': k_ip, 'v_ip': v_ip,
|
|
||||||
'k_as': k_as, 'v_as': v_as, 'k_aa': k_aa, 'v_aa': v_aa,
|
|
||||||
'attn_mask_aa': attn_mask_aa,
|
|
||||||
}
|
|
||||||
else:
|
else:
|
||||||
if not spatial_self_attn:
|
if not spatial_self_attn:
|
||||||
assert 1 > 2, ">>> ERROR: you should never go into here ..."
|
assert 1 > 2, ">>> ERROR: you should never go into here ..."
|
||||||
context = context[:, :self.text_context_len, :]
|
context = context[:, :self.text_context_len, :]
|
||||||
if self._kv_fused:
|
k = self.to_k(context)
|
||||||
k, v = self.to_kv(context).chunk(2, dim=-1)
|
v = self.to_v(context)
|
||||||
else:
|
|
||||||
k = self.to_k(context)
|
b, _, _ = q.shape
|
||||||
v = self.to_v(context)
|
q = q.unsqueeze(3).reshape(b, q.shape[1], self.heads, self.dim_head).permute(0, 2, 1, 3).reshape(b * self.heads, q.shape[1], self.dim_head).contiguous()
|
||||||
k, v = map(_reshape_kv, (k, v))
|
|
||||||
if use_cache:
|
|
||||||
self._kv_cache = {'k': k, 'v': v}
|
|
||||||
if k is not None:
|
if k is not None:
|
||||||
|
k, v = map(
|
||||||
|
lambda t: t.unsqueeze(3).reshape(b, t.shape[
|
||||||
|
1], self.heads, self.dim_head).permute(0, 2, 1, 3).reshape(
|
||||||
|
b * self.heads, t.shape[1], self.dim_head).contiguous(),
|
||||||
|
(k, v),
|
||||||
|
)
|
||||||
out = xformers.ops.memory_efficient_attention(q,
|
out = xformers.ops.memory_efficient_attention(q,
|
||||||
k,
|
k,
|
||||||
v,
|
v,
|
||||||
attn_bias=None,
|
attn_bias=None,
|
||||||
op=None)
|
op=None)
|
||||||
out = (out.unsqueeze(0).reshape(
|
out = (out.unsqueeze(0).reshape(
|
||||||
b, h, out.shape[1],
|
b, self.heads, out.shape[1],
|
||||||
self.dim_head).permute(0, 2, 1,
|
self.dim_head).permute(0, 2, 1,
|
||||||
3).reshape(b, out.shape[1],
|
3).reshape(b, out.shape[1],
|
||||||
h * self.dim_head))
|
self.heads * self.dim_head))
|
||||||
|
|
||||||
if k_ip is not None:
|
if k_ip is not None:
|
||||||
|
# For image cross-attention
|
||||||
|
k_ip, v_ip = map(
|
||||||
|
lambda t: t.unsqueeze(3).reshape(b, t.shape[
|
||||||
|
1], self.heads, self.dim_head).permute(0, 2, 1, 3).reshape(
|
||||||
|
b * self.heads, t.shape[1], self.dim_head).contiguous(
|
||||||
|
),
|
||||||
|
(k_ip, v_ip),
|
||||||
|
)
|
||||||
out_ip = xformers.ops.memory_efficient_attention(q,
|
out_ip = xformers.ops.memory_efficient_attention(q,
|
||||||
k_ip,
|
k_ip,
|
||||||
v_ip,
|
v_ip,
|
||||||
attn_bias=None,
|
attn_bias=None,
|
||||||
op=None)
|
op=None)
|
||||||
out_ip = (out_ip.unsqueeze(0).reshape(
|
out_ip = (out_ip.unsqueeze(0).reshape(
|
||||||
b, h, out_ip.shape[1],
|
b, self.heads, out_ip.shape[1],
|
||||||
self.dim_head).permute(0, 2, 1,
|
self.dim_head).permute(0, 2, 1,
|
||||||
3).reshape(b, out_ip.shape[1],
|
3).reshape(b, out_ip.shape[1],
|
||||||
h * self.dim_head))
|
self.heads * self.dim_head))
|
||||||
|
|
||||||
if k_as is not None:
|
if k_as is not None:
|
||||||
|
# For agent state cross-attention
|
||||||
|
k_as, v_as = map(
|
||||||
|
lambda t: t.unsqueeze(3).reshape(b, t.shape[
|
||||||
|
1], self.heads, self.dim_head).permute(0, 2, 1, 3).reshape(
|
||||||
|
b * self.heads, t.shape[1], self.dim_head).contiguous(
|
||||||
|
),
|
||||||
|
(k_as, v_as),
|
||||||
|
)
|
||||||
out_as = xformers.ops.memory_efficient_attention(q,
|
out_as = xformers.ops.memory_efficient_attention(q,
|
||||||
k_as,
|
k_as,
|
||||||
v_as,
|
v_as,
|
||||||
attn_bias=None,
|
attn_bias=None,
|
||||||
op=None)
|
op=None)
|
||||||
out_as = (out_as.unsqueeze(0).reshape(
|
out_as = (out_as.unsqueeze(0).reshape(
|
||||||
b, h, out_as.shape[1],
|
b, self.heads, out_as.shape[1],
|
||||||
self.dim_head).permute(0, 2, 1,
|
self.dim_head).permute(0, 2, 1,
|
||||||
3).reshape(b, out_as.shape[1],
|
3).reshape(b, out_as.shape[1],
|
||||||
h * self.dim_head))
|
self.heads * self.dim_head))
|
||||||
|
|
||||||
if k_aa is not None:
|
if k_aa is not None:
|
||||||
|
# For agent action cross-attention
|
||||||
|
k_aa, v_aa = map(
|
||||||
|
lambda t: t.unsqueeze(3).reshape(b, t.shape[
|
||||||
|
1], self.heads, self.dim_head).permute(0, 2, 1, 3).reshape(
|
||||||
|
b * self.heads, t.shape[1], self.dim_head).contiguous(
|
||||||
|
),
|
||||||
|
(k_aa, v_aa),
|
||||||
|
)
|
||||||
|
|
||||||
|
attn_mask_aa = attn_mask_aa.unsqueeze(1).repeat(1,self.heads,1,1).reshape(
|
||||||
|
b * self.heads, attn_mask_aa.shape[1], attn_mask_aa.shape[2])
|
||||||
|
attn_mask_aa = attn_mask_aa.to(q.dtype)
|
||||||
|
|
||||||
out_aa = xformers.ops.memory_efficient_attention(
|
out_aa = xformers.ops.memory_efficient_attention(
|
||||||
q, k_aa, v_aa, attn_bias=attn_mask_aa, op=None)
|
q, k_aa, v_aa, attn_bias=attn_mask_aa, op=None)
|
||||||
|
|
||||||
out_aa = (out_aa.unsqueeze(0).reshape(
|
out_aa = (out_aa.unsqueeze(0).reshape(
|
||||||
b, h, out_aa.shape[1],
|
b, self.heads, out_aa.shape[1],
|
||||||
self.dim_head).permute(0, 2, 1,
|
self.dim_head).permute(0, 2, 1,
|
||||||
3).reshape(b, out_aa.shape[1],
|
3).reshape(b, out_aa.shape[1],
|
||||||
h * self.dim_head))
|
self.heads * self.dim_head))
|
||||||
if exists(mask):
|
if exists(mask):
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
@@ -463,7 +574,7 @@ class CrossAttention(nn.Module):
|
|||||||
col_indices = torch.arange(l2, device=target_device)
|
col_indices = torch.arange(l2, device=target_device)
|
||||||
mask_2d = col_indices.unsqueeze(0) >= start_positions.unsqueeze(1)
|
mask_2d = col_indices.unsqueeze(0) >= start_positions.unsqueeze(1)
|
||||||
mask = mask_2d.unsqueeze(1).expand(b, l1, l2)
|
mask = mask_2d.unsqueeze(1).expand(b, l1, l2)
|
||||||
attn_mask = torch.zeros(b, l1, l2, dtype=torch.float, device=target_device)
|
attn_mask = torch.zeros(b, l1, l2, dtype=torch.bfloat16, device=target_device)
|
||||||
attn_mask[mask] = float('-inf')
|
attn_mask[mask] = float('-inf')
|
||||||
|
|
||||||
self._attn_mask_aa_cache_key = cache_key
|
self._attn_mask_aa_cache_key = cache_key
|
||||||
|
|||||||
@@ -422,7 +422,7 @@ class WMAModel(nn.Module):
|
|||||||
self.temporal_attention = temporal_attention
|
self.temporal_attention = temporal_attention
|
||||||
time_embed_dim = model_channels * 4
|
time_embed_dim = model_channels * 4
|
||||||
self.use_checkpoint = use_checkpoint
|
self.use_checkpoint = use_checkpoint
|
||||||
self.dtype = torch.float16 if use_fp16 else torch.float32
|
self.dtype = torch.float16 if use_fp16 else torch.bfloat16
|
||||||
temporal_self_att_only = True
|
temporal_self_att_only = True
|
||||||
self.addition_attention = addition_attention
|
self.addition_attention = addition_attention
|
||||||
self.temporal_length = temporal_length
|
self.temporal_length = temporal_length
|
||||||
@@ -688,24 +688,10 @@ class WMAModel(nn.Module):
|
|||||||
# Context precomputation cache
|
# Context precomputation cache
|
||||||
self._ctx_cache_enabled = False
|
self._ctx_cache_enabled = False
|
||||||
self._ctx_cache = {}
|
self._ctx_cache = {}
|
||||||
self._trt_backbone = None # TRT engine for video UNet backbone
|
# fs_embed cache
|
||||||
# Reusable CUDA stream for parallel state_unet / action_unet
|
self._fs_embed_cache = None
|
||||||
self._state_stream = torch.cuda.Stream()
|
# Pre-created CUDA stream for parallel action/state UNet
|
||||||
|
self._side_stream = torch.cuda.Stream() if not self.base_model_gen_only else None
|
||||||
def __getstate__(self):
|
|
||||||
state = self.__dict__.copy()
|
|
||||||
state.pop('_state_stream', None)
|
|
||||||
return state
|
|
||||||
|
|
||||||
def __setstate__(self, state):
|
|
||||||
self.__dict__.update(state)
|
|
||||||
self._state_stream = torch.cuda.Stream()
|
|
||||||
|
|
||||||
def load_trt_backbone(self, engine_path, n_hs_a=9):
|
|
||||||
"""Load a TensorRT engine for the video UNet backbone."""
|
|
||||||
from unifolm_wma.trt_utils import TRTBackbone
|
|
||||||
self._trt_backbone = TRTBackbone(engine_path, n_hs_a=n_hs_a)
|
|
||||||
print(f">>> TRT backbone loaded from {engine_path}")
|
|
||||||
|
|
||||||
def forward(self,
|
def forward(self,
|
||||||
x: Tensor,
|
x: Tensor,
|
||||||
@@ -807,68 +793,66 @@ class WMAModel(nn.Module):
|
|||||||
|
|
||||||
# Combine emb
|
# Combine emb
|
||||||
if self.fs_condition:
|
if self.fs_condition:
|
||||||
if fs is None:
|
if self._ctx_cache_enabled and self._fs_embed_cache is not None:
|
||||||
fs = torch.tensor([self.default_fs] * b,
|
fs_embed = self._fs_embed_cache
|
||||||
dtype=torch.long,
|
else:
|
||||||
device=x.device)
|
if fs is None:
|
||||||
fs_emb = timestep_embedding(fs,
|
fs = torch.tensor([self.default_fs] * b,
|
||||||
self.model_channels,
|
dtype=torch.long,
|
||||||
repeat_only=False).type(x.dtype)
|
device=x.device)
|
||||||
|
fs_emb = timestep_embedding(fs,
|
||||||
fs_embed = self.fps_embedding(fs_emb)
|
self.model_channels,
|
||||||
fs_embed = fs_embed.repeat_interleave(repeats=t, dim=0)
|
repeat_only=False).type(x.dtype)
|
||||||
|
fs_embed = self.fps_embedding(fs_emb)
|
||||||
|
fs_embed = fs_embed.repeat_interleave(repeats=t, dim=0)
|
||||||
|
if self._ctx_cache_enabled:
|
||||||
|
self._fs_embed_cache = fs_embed
|
||||||
emb = emb + fs_embed
|
emb = emb + fs_embed
|
||||||
|
|
||||||
if self._trt_backbone is not None:
|
h = x.type(self.dtype)
|
||||||
# TRT path: run backbone via TensorRT engine
|
adapter_idx = 0
|
||||||
h_in = x.type(self.dtype).contiguous()
|
hs = []
|
||||||
y, hs_a = self._trt_backbone(h_in, emb.contiguous(), context.contiguous())
|
hs_a = []
|
||||||
else:
|
for id, module in enumerate(self.input_blocks):
|
||||||
# PyTorch path: original backbone
|
h = module(h, emb, context=context, batch_size=b)
|
||||||
h = x.type(self.dtype)
|
if id == 0 and self.addition_attention:
|
||||||
adapter_idx = 0
|
h = self.init_attn(h, emb, context=context, batch_size=b)
|
||||||
hs = []
|
# plug-in adapter features
|
||||||
hs_a = []
|
if ((id + 1) % 3 == 0) and features_adapter is not None:
|
||||||
for id, module in enumerate(self.input_blocks):
|
h = h + features_adapter[adapter_idx]
|
||||||
h = module(h, emb, context=context, batch_size=b)
|
adapter_idx += 1
|
||||||
if id == 0 and self.addition_attention:
|
if id != 0:
|
||||||
h = self.init_attn(h, emb, context=context, batch_size=b)
|
if isinstance(module[0], Downsample):
|
||||||
# plug-in adapter features
|
|
||||||
if ((id + 1) % 3 == 0) and features_adapter is not None:
|
|
||||||
h = h + features_adapter[adapter_idx]
|
|
||||||
adapter_idx += 1
|
|
||||||
if id != 0:
|
|
||||||
if isinstance(module[0], Downsample):
|
|
||||||
hs_a.append(
|
|
||||||
rearrange(hs[-1], '(b t) c h w -> b t c h w', t=t))
|
|
||||||
hs.append(h)
|
|
||||||
hs_a.append(rearrange(h, '(b t) c h w -> b t c h w', t=t))
|
|
||||||
|
|
||||||
if features_adapter is not None:
|
|
||||||
assert len(
|
|
||||||
features_adapter) == adapter_idx, 'Wrong features_adapter'
|
|
||||||
h = self.middle_block(h, emb, context=context, batch_size=b)
|
|
||||||
hs_a.append(rearrange(h, '(b t) c h w -> b t c h w', t=t))
|
|
||||||
|
|
||||||
hs_out = []
|
|
||||||
for module in self.output_blocks:
|
|
||||||
h = torch.cat([h, hs.pop()], dim=1)
|
|
||||||
h = module(h, emb, context=context, batch_size=b)
|
|
||||||
if isinstance(module[-1], Upsample):
|
|
||||||
hs_a.append(
|
hs_a.append(
|
||||||
rearrange(hs_out[-1], '(b t) c h w -> b t c h w', t=t))
|
rearrange(hs[-1], '(b t) c h w -> b t c h w', t=t))
|
||||||
hs_out.append(h)
|
hs.append(h)
|
||||||
h = h.type(x.dtype)
|
hs_a.append(rearrange(h, '(b t) c h w -> b t c h w', t=t))
|
||||||
hs_a.append(rearrange(hs_out[-1], '(b t) c h w -> b t c h w', t=t))
|
|
||||||
|
|
||||||
y = self.out(h)
|
if features_adapter is not None:
|
||||||
y = rearrange(y, '(b t) c h w -> b c t h w', b=b)
|
assert len(
|
||||||
|
features_adapter) == adapter_idx, 'Wrong features_adapter'
|
||||||
|
h = self.middle_block(h, emb, context=context, batch_size=b)
|
||||||
|
hs_a.append(rearrange(h, '(b t) c h w -> b t c h w', t=t))
|
||||||
|
|
||||||
|
hs_out = []
|
||||||
|
for module in self.output_blocks:
|
||||||
|
h = torch.cat([h, hs.pop()], dim=1)
|
||||||
|
h = module(h, emb, context=context, batch_size=b)
|
||||||
|
if isinstance(module[-1], Upsample):
|
||||||
|
hs_a.append(
|
||||||
|
rearrange(hs_out[-1], '(b t) c h w -> b t c h w', t=t))
|
||||||
|
hs_out.append(h)
|
||||||
|
h = h.type(x.dtype)
|
||||||
|
hs_a.append(rearrange(hs_out[-1], '(b t) c h w -> b t c h w', t=t))
|
||||||
|
|
||||||
|
y = self.out(h)
|
||||||
|
y = rearrange(y, '(b t) c h w -> b c t h w', b=b)
|
||||||
|
|
||||||
if not self.base_model_gen_only:
|
if not self.base_model_gen_only:
|
||||||
ba, _, _ = x_action.shape
|
ba, _, _ = x_action.shape
|
||||||
ts_state = timesteps[:ba] if b > 1 else timesteps
|
ts_state = timesteps[:ba] if b > 1 else timesteps
|
||||||
# Run action_unet and state_unet in parallel via CUDA streams
|
# Run action_unet and state_unet in parallel via pre-created CUDA stream
|
||||||
s_stream = self._state_stream
|
s_stream = self._side_stream
|
||||||
s_stream.wait_stream(torch.cuda.current_stream())
|
s_stream.wait_stream(torch.cuda.current_stream())
|
||||||
with torch.cuda.stream(s_stream):
|
with torch.cuda.stream(s_stream):
|
||||||
s_y = self.state_unet(x_state, ts_state, hs_a,
|
s_y = self.state_unet(x_state, ts_state, hs_a,
|
||||||
@@ -889,6 +873,7 @@ def enable_ctx_cache(model):
|
|||||||
if isinstance(m, WMAModel):
|
if isinstance(m, WMAModel):
|
||||||
m._ctx_cache_enabled = True
|
m._ctx_cache_enabled = True
|
||||||
m._ctx_cache = {}
|
m._ctx_cache = {}
|
||||||
|
m._fs_embed_cache = None
|
||||||
# conditional_unet1d cache
|
# conditional_unet1d cache
|
||||||
from unifolm_wma.models.diffusion_head.conditional_unet1d import ConditionalUnet1D
|
from unifolm_wma.models.diffusion_head.conditional_unet1d import ConditionalUnet1D
|
||||||
for m in model.modules():
|
for m in model.modules():
|
||||||
@@ -903,6 +888,7 @@ def disable_ctx_cache(model):
|
|||||||
if isinstance(m, WMAModel):
|
if isinstance(m, WMAModel):
|
||||||
m._ctx_cache_enabled = False
|
m._ctx_cache_enabled = False
|
||||||
m._ctx_cache = {}
|
m._ctx_cache = {}
|
||||||
|
m._fs_embed_cache = None
|
||||||
from unifolm_wma.models.diffusion_head.conditional_unet1d import ConditionalUnet1D
|
from unifolm_wma.models.diffusion_head.conditional_unet1d import ConditionalUnet1D
|
||||||
for m in model.modules():
|
for m in model.modules():
|
||||||
if isinstance(m, ConditionalUnet1D):
|
if isinstance(m, ConditionalUnet1D):
|
||||||
|
|||||||
@@ -1,151 +0,0 @@
|
|||||||
"""TensorRT acceleration utilities for the video UNet backbone."""
|
|
||||||
|
|
||||||
import torch
|
|
||||||
import torch.nn as nn
|
|
||||||
from einops import rearrange
|
|
||||||
from unifolm_wma.modules.networks.wma_model import Downsample, Upsample
|
|
||||||
|
|
||||||
|
|
||||||
class VideoBackboneForExport(nn.Module):
|
|
||||||
"""Wrapper that isolates the video UNet backbone for ONNX export.
|
|
||||||
|
|
||||||
Takes already-preprocessed inputs (after context/time embedding prep)
|
|
||||||
and returns y + hs_a as a flat tuple.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, wma_model):
|
|
||||||
super().__init__()
|
|
||||||
self.input_blocks = wma_model.input_blocks
|
|
||||||
self.middle_block = wma_model.middle_block
|
|
||||||
self.output_blocks = wma_model.output_blocks
|
|
||||||
self.out = wma_model.out
|
|
||||||
self.addition_attention = wma_model.addition_attention
|
|
||||||
if self.addition_attention:
|
|
||||||
self.init_attn = wma_model.init_attn
|
|
||||||
self.dtype = wma_model.dtype
|
|
||||||
|
|
||||||
def forward(self, h, emb, context):
|
|
||||||
t = 16
|
|
||||||
b = 1
|
|
||||||
|
|
||||||
hs = []
|
|
||||||
hs_a = []
|
|
||||||
h = h.type(self.dtype)
|
|
||||||
for id, module in enumerate(self.input_blocks):
|
|
||||||
h = module(h, emb, context=context, batch_size=b)
|
|
||||||
if id == 0 and self.addition_attention:
|
|
||||||
h = self.init_attn(h, emb, context=context, batch_size=b)
|
|
||||||
if id != 0:
|
|
||||||
if isinstance(module[0], Downsample):
|
|
||||||
hs_a.append(rearrange(hs[-1], '(b t) c h w -> b t c h w', t=t))
|
|
||||||
hs.append(h)
|
|
||||||
hs_a.append(rearrange(h, '(b t) c h w -> b t c h w', t=t))
|
|
||||||
|
|
||||||
h = self.middle_block(h, emb, context=context, batch_size=b)
|
|
||||||
hs_a.append(rearrange(h, '(b t) c h w -> b t c h w', t=t))
|
|
||||||
|
|
||||||
hs_out = []
|
|
||||||
for module in self.output_blocks:
|
|
||||||
h = torch.cat([h, hs.pop()], dim=1)
|
|
||||||
h = module(h, emb, context=context, batch_size=b)
|
|
||||||
if isinstance(module[-1], Upsample):
|
|
||||||
hs_a.append(rearrange(hs_out[-1], '(b t) c h w -> b t c h w', t=t))
|
|
||||||
hs_out.append(h)
|
|
||||||
hs_a.append(rearrange(hs_out[-1], '(b t) c h w -> b t c h w', t=t))
|
|
||||||
|
|
||||||
y = self.out(h.type(h.dtype))
|
|
||||||
y = rearrange(y, '(b t) c h w -> b c t h w', b=b)
|
|
||||||
return (y, *hs_a)
|
|
||||||
|
|
||||||
|
|
||||||
def export_backbone_onnx(model, save_path, context_len=95):
|
|
||||||
wma = model.model.diffusion_model
|
|
||||||
wrapper = VideoBackboneForExport(wma)
|
|
||||||
wrapper.eval().cuda()
|
|
||||||
|
|
||||||
for m in wrapper.modules():
|
|
||||||
if hasattr(m, 'checkpoint'):
|
|
||||||
m.checkpoint = False
|
|
||||||
if hasattr(m, 'use_checkpoint'):
|
|
||||||
m.use_checkpoint = False
|
|
||||||
|
|
||||||
import xformers.ops
|
|
||||||
_orig_mea = xformers.ops.memory_efficient_attention
|
|
||||||
def _sdpa_replacement(q, k, v, attn_bias=None, op=None, **kw):
|
|
||||||
return torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=attn_bias)
|
|
||||||
xformers.ops.memory_efficient_attention = _sdpa_replacement
|
|
||||||
|
|
||||||
BT = 16
|
|
||||||
emb_dim = wma.model_channels * 4
|
|
||||||
ctx_dim = 1024
|
|
||||||
in_ch = wma.in_channels
|
|
||||||
|
|
||||||
dummy_h = torch.randn(BT, in_ch, 40, 64, device='cuda', dtype=torch.float32)
|
|
||||||
dummy_emb = torch.randn(BT, emb_dim, device='cuda', dtype=torch.float32)
|
|
||||||
dummy_ctx = torch.randn(BT, context_len, ctx_dim, device='cuda', dtype=torch.float32)
|
|
||||||
|
|
||||||
with torch.no_grad():
|
|
||||||
outputs = wrapper(dummy_h, dummy_emb, dummy_ctx)
|
|
||||||
n_outputs = len(outputs)
|
|
||||||
print(f">>> Backbone has {n_outputs} outputs (1 y + {n_outputs-1} hs_a)")
|
|
||||||
for i, o in enumerate(outputs):
|
|
||||||
print(f" output[{i}]: {o.shape} {o.dtype}")
|
|
||||||
|
|
||||||
output_names = ['y'] + [f'hs_a_{i}' for i in range(n_outputs - 1)]
|
|
||||||
|
|
||||||
torch.onnx.export(
|
|
||||||
wrapper,
|
|
||||||
(dummy_h, dummy_emb, dummy_ctx),
|
|
||||||
save_path,
|
|
||||||
input_names=['h', 'emb', 'context'],
|
|
||||||
output_names=output_names,
|
|
||||||
opset_version=17,
|
|
||||||
do_constant_folding=True,
|
|
||||||
)
|
|
||||||
print(f">>> ONNX exported to {save_path}")
|
|
||||||
xformers.ops.memory_efficient_attention = _orig_mea
|
|
||||||
return n_outputs
|
|
||||||
|
|
||||||
|
|
||||||
class TRTBackbone:
|
|
||||||
"""TensorRT runtime wrapper for the video UNet backbone."""
|
|
||||||
|
|
||||||
def __init__(self, engine_path, n_hs_a=9):
|
|
||||||
import tensorrt as trt
|
|
||||||
|
|
||||||
self.logger = trt.Logger(trt.Logger.WARNING)
|
|
||||||
with open(engine_path, 'rb') as f:
|
|
||||||
runtime = trt.Runtime(self.logger)
|
|
||||||
self.engine = runtime.deserialize_cuda_engine(f.read())
|
|
||||||
self.context = self.engine.create_execution_context()
|
|
||||||
self.n_hs_a = n_hs_a
|
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
self.output_buffers = {}
|
|
||||||
for i in range(self.engine.num_io_tensors):
|
|
||||||
name = self.engine.get_tensor_name(i)
|
|
||||||
if self.engine.get_tensor_mode(name) == trt.TensorIOMode.OUTPUT:
|
|
||||||
shape = self.engine.get_tensor_shape(name)
|
|
||||||
np_dtype = trt.nptype(self.engine.get_tensor_dtype(name))
|
|
||||||
buf = torch.empty(list(shape), dtype=torch.from_numpy(np.empty(0, dtype=np_dtype)).dtype, device='cuda')
|
|
||||||
self.output_buffers[name] = buf
|
|
||||||
print(f" TRT output '{name}': {list(shape)} {buf.dtype}")
|
|
||||||
|
|
||||||
def __call__(self, h, emb, context):
|
|
||||||
import tensorrt as trt
|
|
||||||
for name, tensor in [('h', h), ('emb', emb), ('context', context)]:
|
|
||||||
expected_dtype = trt.nptype(self.engine.get_tensor_dtype(name))
|
|
||||||
torch_expected = torch.from_numpy(__import__('numpy').empty(0, dtype=expected_dtype)).dtype
|
|
||||||
if tensor.dtype != torch_expected:
|
|
||||||
tensor = tensor.to(torch_expected)
|
|
||||||
self.context.set_tensor_address(name, tensor.contiguous().data_ptr())
|
|
||||||
|
|
||||||
for name, buf in self.output_buffers.items():
|
|
||||||
self.context.set_tensor_address(name, buf.data_ptr())
|
|
||||||
|
|
||||||
self.context.execute_async_v3(torch.cuda.current_stream().cuda_stream)
|
|
||||||
torch.cuda.synchronize()
|
|
||||||
|
|
||||||
y = self.output_buffers['y']
|
|
||||||
hs_a = [self.output_buffers[f'hs_a_{i}'] for i in range(self.n_hs_a)]
|
|
||||||
return y, hs_a
|
|
||||||
@@ -7,7 +7,9 @@
|
|||||||
#
|
#
|
||||||
# thanks!
|
# thanks!
|
||||||
|
|
||||||
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
|
import torch.nn.functional as F
|
||||||
from unifolm_wma.utils.utils import instantiate_from_config
|
from unifolm_wma.utils.utils import instantiate_from_config
|
||||||
|
|
||||||
|
|
||||||
@@ -78,7 +80,11 @@ def nonlinearity(type='silu'):
|
|||||||
class GroupNormSpecific(nn.GroupNorm):
|
class GroupNormSpecific(nn.GroupNorm):
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
return super().forward(x.float()).type(x.dtype)
|
with torch.amp.autocast('cuda', enabled=False):
|
||||||
|
return F.group_norm(x, self.num_groups,
|
||||||
|
self.weight.to(x.dtype) if self.weight is not None else None,
|
||||||
|
self.bias.to(x.dtype) if self.bias is not None else None,
|
||||||
|
self.eps)
|
||||||
|
|
||||||
|
|
||||||
def normalization(channels, num_groups=32):
|
def normalization(channels, num_groups=32):
|
||||||
|
|||||||
@@ -1,13 +1,32 @@
|
|||||||
2026-02-18 19:01:56.891895: I tensorflow/core/util/port.cc:113] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
|
2026-02-08 05:20:49.828675: I tensorflow/core/util/port.cc:113] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
|
||||||
2026-02-18 19:01:56.940243: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
|
2026-02-08 05:20:49.831563: I external/local_tsl/tsl/cuda/cudart_stub.cc:31] Could not find cuda drivers on your machine, GPU will not be used.
|
||||||
2026-02-18 19:01:56.940285: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
|
2026-02-08 05:20:49.861366: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
|
||||||
2026-02-18 19:01:56.941395: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
|
2026-02-08 05:20:49.861402: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
|
||||||
2026-02-18 19:01:56.948327: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
|
2026-02-08 05:20:49.862974: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
|
||||||
To enable the following instructions: AVX2 AVX512F AVX512_VNNI FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
|
2026-02-08 05:20:49.870402: I external/local_tsl/tsl/cuda/cudart_stub.cc:31] Could not find cuda drivers on your machine, GPU will not be used.
|
||||||
2026-02-18 19:01:57.870809: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT
|
2026-02-08 05:20:49.870647: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
|
||||||
|
To enable the following instructions: AVX2 AVX512F AVX512_VNNI AVX512_BF16 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
|
||||||
|
2026-02-08 05:20:50.486843: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT
|
||||||
Global seed set to 123
|
Global seed set to 123
|
||||||
>>> Loading prepared model from ckpts/unifolm_wma_dual.ckpt.prepared.pt ...
|
/mnt/ASC1637/miniconda3/envs/unifolm-wma/lib/python3.10/site-packages/kornia/feature/lightglue.py:44: FutureWarning: `torch.cuda.amp.custom_fwd(args...)` is deprecated. Please use `torch.amp.custom_fwd(args..., device_type='cuda')` instead.
|
||||||
>>> Prepared model loaded.
|
@torch.cuda.amp.custom_fwd(cast_inputs=torch.float32)
|
||||||
|
INFO:mainlogger:LatentVisualDiffusion: Running in v-prediction mode
|
||||||
|
INFO:unifolm_wma.models.diffusion_head.conditional_unet1d:number of parameters: 5.010531e+08
|
||||||
|
INFO:unifolm_wma.models.diffusion_head.conditional_unet1d:number of parameters: 5.010531e+08
|
||||||
|
AE working on z of shape (1, 4, 32, 32) = 4096 dimensions.
|
||||||
|
INFO:root:Loaded ViT-H-14 model config.
|
||||||
|
DEBUG:urllib3.connectionpool:Starting new HTTPS connection (1): hf-mirror.com:443
|
||||||
|
DEBUG:urllib3.connectionpool:https://hf-mirror.com:443 "HEAD /laion/CLIP-ViT-H-14-laion2B-s32B-b79K/resolve/main/open_clip_pytorch_model.bin HTTP/1.1" 302 0
|
||||||
|
INFO:root:Loading pretrained ViT-H-14 weights (laion2b_s32b_b79k).
|
||||||
|
/mnt/ASC1637/miniconda3/envs/unifolm-wma/lib/python3.10/site-packages/open_clip/factory.py:88: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.
|
||||||
|
checkpoint = torch.load(checkpoint_path, map_location=map_location)
|
||||||
|
INFO:root:Loaded ViT-H-14 model config.
|
||||||
|
DEBUG:urllib3.connectionpool:https://hf-mirror.com:443 "HEAD /laion/CLIP-ViT-H-14-laion2B-s32B-b79K/resolve/main/open_clip_pytorch_model.bin HTTP/1.1" 302 0
|
||||||
|
INFO:root:Loading pretrained ViT-H-14 weights (laion2b_s32b_b79k).
|
||||||
|
/mnt/ASC1637/unifolm-world-model-action/scripts/evaluation/world_model_interaction.py:86: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.
|
||||||
|
state_dict = torch.load(ckpt, map_location="cpu")
|
||||||
|
>>> model checkpoint loaded.
|
||||||
|
>>> Load pre-trained model ...
|
||||||
INFO:root:***** Configing Data *****
|
INFO:root:***** Configing Data *****
|
||||||
>>> unitree_z1_stackbox: 1 data samples loaded.
|
>>> unitree_z1_stackbox: 1 data samples loaded.
|
||||||
>>> unitree_z1_stackbox: data stats loaded.
|
>>> unitree_z1_stackbox: data stats loaded.
|
||||||
@@ -25,125 +44,71 @@ INFO:root:***** Configing Data *****
|
|||||||
>>> unitree_g1_pack_camera: data stats loaded.
|
>>> unitree_g1_pack_camera: data stats loaded.
|
||||||
>>> unitree_g1_pack_camera: normalizer initiated.
|
>>> unitree_g1_pack_camera: normalizer initiated.
|
||||||
>>> Dataset is successfully loaded ...
|
>>> Dataset is successfully loaded ...
|
||||||
✓ KV fused: 66 attention layers
|
|
||||||
TRT output 'y': [1, 4, 16, 40, 64] torch.float32
|
|
||||||
TRT output 'hs_a_0': [1, 16, 320, 40, 64] torch.float32
|
|
||||||
TRT output 'hs_a_1': [1, 16, 640, 20, 32] torch.float32
|
|
||||||
TRT output 'hs_a_2': [1, 16, 1280, 10, 16] torch.float32
|
|
||||||
TRT output 'hs_a_3': [1, 16, 1280, 5, 8] torch.float32
|
|
||||||
TRT output 'hs_a_4': [1, 16, 1280, 5, 8] torch.float32
|
|
||||||
TRT output 'hs_a_5': [1, 16, 1280, 5, 8] torch.float32
|
|
||||||
TRT output 'hs_a_6': [1, 16, 1280, 10, 16] torch.float32
|
|
||||||
TRT output 'hs_a_7': [1, 16, 640, 20, 32] torch.float32
|
|
||||||
TRT output 'hs_a_8': [1, 16, 320, 40, 64] torch.float32
|
|
||||||
>>> TRT backbone loaded from /home/qhy/unifolm-world-model-action/scripts/evaluation/../../trt_engines/video_backbone.engine
|
|
||||||
>>> Generate 16 frames under each generation ...
|
>>> Generate 16 frames under each generation ...
|
||||||
DEBUG:h5py._conv:Creating converter from 3 to 5
|
DEBUG:h5py._conv:Creating converter from 3 to 5
|
||||||
DEBUG:PIL.PngImagePlugin:STREAM b'IHDR' 16 13
|
DEBUG:PIL.PngImagePlugin:STREAM b'IHDR' 16 13
|
||||||
DEBUG:PIL.PngImagePlugin:STREAM b'pHYs' 41 9
|
DEBUG:PIL.PngImagePlugin:STREAM b'pHYs' 41 9
|
||||||
DEBUG:PIL.PngImagePlugin:STREAM b'IDAT' 62 4096
|
DEBUG:PIL.PngImagePlugin:STREAM b'IDAT' 62 4096
|
||||||
|
|
||||||
0%| | 0/11 [00:00<?, ?it/s][02/18/2026-19:02:10] [TRT] [W] Using default stream in enqueueV3() may lead to performance issues due to additional calls to cudaStreamSynchronize() by TensorRT to ensure correct synchronization. Please use non-default stream instead.
|
0%| | 0/11 [00:00<?, ?it/s]/mnt/ASC1637/miniconda3/envs/unifolm-wma/lib/python3.10/site-packages/torch/nn/functional.py:5501: UserWarning: Attempting to use hipBLASLt on an unsupported architecture! Overriding blas backend to hipblas (Triggered internally at ../aten/src/ATen/Context.cpp:296.)
|
||||||
|
proj = linear(q, w, b)
|
||||||
|
/mnt/ASC1637/miniconda3/envs/unifolm-wma/lib/python3.10/site-packages/torch/nn/functional.py:6278: UserWarning: Flash attention support on Navi31 GPU is still experimental. Enable it with TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL=1. (Triggered internally at ../aten/src/ATen/native/transformers/hip/sdp_utils.cpp:225.)
|
||||||
|
attn_output = scaled_dot_product_attention(
|
||||||
|
/mnt/ASC1637/miniconda3/envs/unifolm-wma/lib/python3.10/site-packages/torch/nn/functional.py:6278: UserWarning: Memory Efficient attention on Navi31 GPU is still experimental. Enable it with TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL=1. (Triggered internally at ../aten/src/ATen/native/transformers/hip/sdp_utils.cpp:269.)
|
||||||
|
attn_output = scaled_dot_product_attention(
|
||||||
>>> Step 0: generating actions ...
|
>>> Step 0: generating actions ...
|
||||||
9%|▉ | 1/11 [00:17<02:51, 17.15s/it]>>> Step 0: generating actions ...
|
|
||||||
>>> Step 0: interacting with world model ...
|
|
||||||
>>>>>>>>>>>>>>>>>>>>>>>>
|
|
||||||
>>> Step 0: interacting with world model ...
|
>>> Step 0: interacting with world model ...
|
||||||
DEBUG:PIL.Image:Importing BlpImagePlugin
|
DEBUG:PIL.Image:Importing BlpImagePlugin
|
||||||
DEBUG:PIL.Image:Importing BmpImagePlugin
|
DEBUG:PIL.Image:Importing BmpImagePlugin
|
||||||
DEBUG:PIL.Image:Importing BmpImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing BufrStubImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing BmpImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing BufrStubImagePlugin
|
DEBUG:PIL.Image:Importing BufrStubImagePlugin
|
||||||
DEBUG:PIL.Image:Importing CurImagePlugin
|
DEBUG:PIL.Image:Importing CurImagePlugin
|
||||||
DEBUG:PIL.Image:Importing CurImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing DcxImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing DcxImagePlugin
|
DEBUG:PIL.Image:Importing DcxImagePlugin
|
||||||
DEBUG:PIL.Image:Importing DdsImagePlugin
|
DEBUG:PIL.Image:Importing DdsImagePlugin
|
||||||
DEBUG:PIL.Image:Importing DdsImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing EpsImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing EpsImagePlugin
|
DEBUG:PIL.Image:Importing EpsImagePlugin
|
||||||
DEBUG:PIL.Image:Importing FitsImagePlugin
|
DEBUG:PIL.Image:Importing FitsImagePlugin
|
||||||
DEBUG:PIL.Image:Importing FitsImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing FitsStubImagePlugin
|
DEBUG:PIL.Image:Importing FitsStubImagePlugin
|
||||||
DEBUG:PIL.Image:Importing FitsStubImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing FliImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing FliImagePlugin
|
DEBUG:PIL.Image:Importing FliImagePlugin
|
||||||
DEBUG:PIL.Image:Importing FpxImagePlugin
|
DEBUG:PIL.Image:Importing FpxImagePlugin
|
||||||
DEBUG:PIL.Image:Image: failed to import FpxImagePlugin: No module named 'olefile'
|
DEBUG:PIL.Image:Image: failed to import FpxImagePlugin: No module named 'olefile'
|
||||||
DEBUG:PIL.Image:Image: failed to import FpxImagePlugin: No module named 'olefile'
|
|
||||||
DEBUG:PIL.Image:Importing FtexImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing FtexImagePlugin
|
DEBUG:PIL.Image:Importing FtexImagePlugin
|
||||||
DEBUG:PIL.Image:Importing GbrImagePlugin
|
DEBUG:PIL.Image:Importing GbrImagePlugin
|
||||||
DEBUG:PIL.Image:Importing GifImagePlugin
|
DEBUG:PIL.Image:Importing GifImagePlugin
|
||||||
DEBUG:PIL.Image:Importing GifImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing GribStubImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing GifImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing GribStubImagePlugin
|
DEBUG:PIL.Image:Importing GribStubImagePlugin
|
||||||
DEBUG:PIL.Image:Importing Hdf5StubImagePlugin
|
DEBUG:PIL.Image:Importing Hdf5StubImagePlugin
|
||||||
DEBUG:PIL.Image:Importing Hdf5StubImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing IcnsImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing IcnsImagePlugin
|
DEBUG:PIL.Image:Importing IcnsImagePlugin
|
||||||
DEBUG:PIL.Image:Importing IcoImagePlugin
|
DEBUG:PIL.Image:Importing IcoImagePlugin
|
||||||
DEBUG:PIL.Image:Importing IcoImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing ImImagePlugin
|
DEBUG:PIL.Image:Importing ImImagePlugin
|
||||||
DEBUG:PIL.Image:Importing ImImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing ImtImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing ImtImagePlugin
|
DEBUG:PIL.Image:Importing ImtImagePlugin
|
||||||
DEBUG:PIL.Image:Importing IptcImagePlugin
|
DEBUG:PIL.Image:Importing IptcImagePlugin
|
||||||
DEBUG:PIL.Image:Importing JpegImagePlugin
|
DEBUG:PIL.Image:Importing JpegImagePlugin
|
||||||
DEBUG:PIL.Image:Importing Jpeg2KImagePlugin
|
DEBUG:PIL.Image:Importing Jpeg2KImagePlugin
|
||||||
DEBUG:PIL.Image:Importing Jpeg2KImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing McIdasImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing JpegImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing Jpeg2KImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing McIdasImagePlugin
|
DEBUG:PIL.Image:Importing McIdasImagePlugin
|
||||||
DEBUG:PIL.Image:Importing MicImagePlugin
|
DEBUG:PIL.Image:Importing MicImagePlugin
|
||||||
DEBUG:PIL.Image:Image: failed to import MicImagePlugin: No module named 'olefile'
|
DEBUG:PIL.Image:Image: failed to import MicImagePlugin: No module named 'olefile'
|
||||||
DEBUG:PIL.Image:Image: failed to import MicImagePlugin: No module named 'olefile'
|
|
||||||
DEBUG:PIL.Image:Importing MpegImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing MpegImagePlugin
|
DEBUG:PIL.Image:Importing MpegImagePlugin
|
||||||
DEBUG:PIL.Image:Importing MpoImagePlugin
|
DEBUG:PIL.Image:Importing MpoImagePlugin
|
||||||
DEBUG:PIL.Image:Importing MpoImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing MspImagePlugin
|
DEBUG:PIL.Image:Importing MspImagePlugin
|
||||||
DEBUG:PIL.Image:Importing MspImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing PalmImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing PalmImagePlugin
|
DEBUG:PIL.Image:Importing PalmImagePlugin
|
||||||
DEBUG:PIL.Image:Importing PcdImagePlugin
|
DEBUG:PIL.Image:Importing PcdImagePlugin
|
||||||
DEBUG:PIL.Image:Importing PcxImagePlugin
|
DEBUG:PIL.Image:Importing PcxImagePlugin
|
||||||
DEBUG:PIL.Image:Importing PcxImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing PdfImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing PcxImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing PdfImagePlugin
|
DEBUG:PIL.Image:Importing PdfImagePlugin
|
||||||
DEBUG:PIL.Image:Importing PixarImagePlugin
|
DEBUG:PIL.Image:Importing PixarImagePlugin
|
||||||
DEBUG:PIL.Image:Importing PngImagePlugin
|
DEBUG:PIL.Image:Importing PngImagePlugin
|
||||||
DEBUG:PIL.Image:Importing PpmImagePlugin
|
DEBUG:PIL.Image:Importing PpmImagePlugin
|
||||||
DEBUG:PIL.Image:Importing PpmImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing PsdImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing PngImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing PpmImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing PsdImagePlugin
|
DEBUG:PIL.Image:Importing PsdImagePlugin
|
||||||
DEBUG:PIL.Image:Importing QoiImagePlugin
|
DEBUG:PIL.Image:Importing QoiImagePlugin
|
||||||
DEBUG:PIL.Image:Importing QoiImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing SgiImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing SgiImagePlugin
|
DEBUG:PIL.Image:Importing SgiImagePlugin
|
||||||
DEBUG:PIL.Image:Importing SpiderImagePlugin
|
DEBUG:PIL.Image:Importing SpiderImagePlugin
|
||||||
DEBUG:PIL.Image:Importing SpiderImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing SunImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing SunImagePlugin
|
DEBUG:PIL.Image:Importing SunImagePlugin
|
||||||
DEBUG:PIL.Image:Importing TgaImagePlugin
|
DEBUG:PIL.Image:Importing TgaImagePlugin
|
||||||
DEBUG:PIL.Image:Importing TiffImagePlugin
|
DEBUG:PIL.Image:Importing TiffImagePlugin
|
||||||
DEBUG:PIL.Image:Importing TiffImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing WebPImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing TiffImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing WebPImagePlugin
|
DEBUG:PIL.Image:Importing WebPImagePlugin
|
||||||
DEBUG:PIL.Image:Importing WmfImagePlugin
|
DEBUG:PIL.Image:Importing WmfImagePlugin
|
||||||
DEBUG:PIL.Image:Importing WmfImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing XbmImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing XbmImagePlugin
|
DEBUG:PIL.Image:Importing XbmImagePlugin
|
||||||
DEBUG:PIL.Image:Importing XpmImagePlugin
|
DEBUG:PIL.Image:Importing XpmImagePlugin
|
||||||
DEBUG:PIL.Image:Importing XpmImagePlugin
|
DEBUG:PIL.Image:Importing XVThumbImagePlugin
|
||||||
DEBUG:PIL.Image:Importing XVThumbImagePlugin
|
|
||||||
|
9%|▉ | 1/11 [01:38<16:25, 98.56s/it]
|
||||||
18%|█▊ | 2/11 [03:16<14:44, 98.31s/it]
|
18%|█▊ | 2/11 [03:16<14:44, 98.31s/it]
|
||||||
27%|██▋ | 3/11 [04:55<13:06, 98.33s/it]
|
27%|██▋ | 3/11 [04:55<13:06, 98.33s/it]
|
||||||
36%|███▋ | 4/11 [06:36<11:37, 99.66s/it]
|
36%|███▋ | 4/11 [06:36<11:37, 99.66s/it]
|
||||||
@@ -174,6 +139,6 @@ DEBUG:PIL.Image:Importing XVThumbImagePlugin
|
|||||||
>>> Step 6: generating actions ...
|
>>> Step 6: generating actions ...
|
||||||
>>> Step 6: interacting with world model ...
|
>>> Step 6: interacting with world model ...
|
||||||
>>>>>>>>>>>>>>>>>>>>>>>>
|
>>>>>>>>>>>>>>>>>>>>>>>>
|
||||||
>>> Step 7: generating actions ...
|
>>> Step 7: generating actions ...
|
||||||
>>> Step 7: interacting with world model ...
|
>>> Step 7: interacting with world model ...
|
||||||
>>>>>>>>>>>>>>>>>>>>>>>>
|
>>>>>>>>>>>>>>>>>>>>>>>>
|
||||||
|
|||||||
@@ -1,5 +1,5 @@
|
|||||||
{
|
{
|
||||||
"gt_video": "unitree_g1_pack_camera/case1/unitree_g1_pack_camera_case1.mp4",
|
"gt_video": "unitree_g1_pack_camera/case1/unitree_g1_pack_camera_case1.mp4",
|
||||||
"pred_video": "unitree_g1_pack_camera/case1/output/inference/0_full_fs6.mp4",
|
"pred_video": "unitree_g1_pack_camera/case1/output/inference/unitree_g1_pack_camera_case1_amd.mp4",
|
||||||
"psnr": 35.615362167470806
|
"psnr": 16.415668383379177
|
||||||
}
|
}
|
||||||
@@ -1,13 +1,46 @@
|
|||||||
2026-02-18 19:05:45.956647: I tensorflow/core/util/port.cc:113] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
|
/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/lightning_fabric/__init__.py:29: UserWarning: pkg_resources is deprecated as an API. See https://setuptools.pypa.io/en/latest/pkg_resources.html. The pkg_resources package is slated for removal as early as 2025-11-30. Refrain from using this package or pin to Setuptools<81.
|
||||||
2026-02-18 19:05:46.004149: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
|
__import__("pkg_resources").declare_namespace(__name__)
|
||||||
2026-02-18 19:05:46.004193: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
|
2026-02-08 18:28:48.960238: I tensorflow/core/util/port.cc:113] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
|
||||||
2026-02-18 19:05:46.005265: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
|
2026-02-08 18:28:48.963331: I external/local_tsl/tsl/cuda/cudart_stub.cc:31] Could not find cuda drivers on your machine, GPU will not be used.
|
||||||
2026-02-18 19:05:46.012074: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
|
2026-02-08 18:28:48.995688: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
|
||||||
To enable the following instructions: AVX2 AVX512F AVX512_VNNI FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
|
2026-02-08 18:28:48.995732: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
|
||||||
2026-02-18 19:05:46.932966: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT
|
2026-02-08 18:28:48.997547: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
|
||||||
Global seed set to 123
|
2026-02-08 18:28:49.005673: I external/local_tsl/tsl/cuda/cudart_stub.cc:31] Could not find cuda drivers on your machine, GPU will not be used.
|
||||||
>>> Loading prepared model from ckpts/unifolm_wma_dual.ckpt.prepared.pt ...
|
2026-02-08 18:28:49.005948: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
|
||||||
>>> Prepared model loaded.
|
To enable the following instructions: AVX2 AVX512F AVX512_VNNI AVX512_BF16 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
|
||||||
|
2026-02-08 18:28:50.009660: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT
|
||||||
|
[rank: 0] Global seed set to 123
|
||||||
|
/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/kornia/feature/lightglue.py:44: FutureWarning: `torch.cuda.amp.custom_fwd(args...)` is deprecated. Please use `torch.amp.custom_fwd(args..., device_type='cuda')` instead.
|
||||||
|
@torch.cuda.amp.custom_fwd(cast_inputs=torch.float32)
|
||||||
|
INFO:mainlogger:LatentVisualDiffusion: Running in v-prediction mode
|
||||||
|
INFO:unifolm_wma.models.diffusion_head.conditional_unet1d:number of parameters: 5.010531e+08
|
||||||
|
INFO:unifolm_wma.models.diffusion_head.conditional_unet1d:number of parameters: 5.010531e+08
|
||||||
|
AE working on z of shape (1, 4, 32, 32) = 4096 dimensions.
|
||||||
|
INFO:root:Loaded ViT-H-14 model config.
|
||||||
|
DEBUG:urllib3.connectionpool:Starting new HTTPS connection (1): hf-mirror.com:443
|
||||||
|
DEBUG:urllib3.connectionpool:https://hf-mirror.com:443 "HEAD /laion/CLIP-ViT-H-14-laion2B-s32B-b79K/resolve/main/open_clip_pytorch_model.bin HTTP/1.1" 302 0
|
||||||
|
INFO:root:Loading pretrained ViT-H-14 weights (laion2b_s32b_b79k).
|
||||||
|
/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/open_clip/factory.py:88: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.
|
||||||
|
checkpoint = torch.load(checkpoint_path, map_location=map_location)
|
||||||
|
INFO:root:Loaded ViT-H-14 model config.
|
||||||
|
DEBUG:urllib3.connectionpool:https://hf-mirror.com:443 "HEAD /laion/CLIP-ViT-H-14-laion2B-s32B-b79K/resolve/main/open_clip_pytorch_model.bin HTTP/1.1" 302 0
|
||||||
|
INFO:root:Loading pretrained ViT-H-14 weights (laion2b_s32b_b79k).
|
||||||
|
/mnt/ASC1637/unifolm-world-model-action/scripts/evaluation/world_model_interaction.py:198: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.
|
||||||
|
state_dict = torch.load(ckpt, map_location="cpu")
|
||||||
|
>>> model checkpoint loaded.
|
||||||
|
>>> Load pre-trained model ...
|
||||||
|
>>> Applying precision settings:
|
||||||
|
- Diffusion dtype: bf16
|
||||||
|
- Projector mode: bf16_full
|
||||||
|
- Encoder mode: bf16_full
|
||||||
|
- VAE dtype: fp32
|
||||||
|
✓ Diffusion model weights converted to bfloat16
|
||||||
|
✓ Projectors converted to bfloat16
|
||||||
|
✓ Encoders converted to bfloat16
|
||||||
|
✓ VAE kept in fp32 for best quality
|
||||||
|
⚠ Found 849 fp32 params, converting to bf16
|
||||||
|
✓ All parameters converted to bfloat16
|
||||||
|
✓ torch.compile: 3 ResBlocks in output_blocks[5, 8, 9]
|
||||||
INFO:root:***** Configing Data *****
|
INFO:root:***** Configing Data *****
|
||||||
>>> unitree_z1_stackbox: 1 data samples loaded.
|
>>> unitree_z1_stackbox: 1 data samples loaded.
|
||||||
>>> unitree_z1_stackbox: data stats loaded.
|
>>> unitree_z1_stackbox: data stats loaded.
|
||||||
@@ -25,125 +58,71 @@ INFO:root:***** Configing Data *****
|
|||||||
>>> unitree_g1_pack_camera: data stats loaded.
|
>>> unitree_g1_pack_camera: data stats loaded.
|
||||||
>>> unitree_g1_pack_camera: normalizer initiated.
|
>>> unitree_g1_pack_camera: normalizer initiated.
|
||||||
>>> Dataset is successfully loaded ...
|
>>> Dataset is successfully loaded ...
|
||||||
✓ KV fused: 66 attention layers
|
|
||||||
TRT output 'y': [1, 4, 16, 40, 64] torch.float32
|
|
||||||
TRT output 'hs_a_0': [1, 16, 320, 40, 64] torch.float32
|
|
||||||
TRT output 'hs_a_1': [1, 16, 640, 20, 32] torch.float32
|
|
||||||
TRT output 'hs_a_2': [1, 16, 1280, 10, 16] torch.float32
|
|
||||||
TRT output 'hs_a_3': [1, 16, 1280, 5, 8] torch.float32
|
|
||||||
TRT output 'hs_a_4': [1, 16, 1280, 5, 8] torch.float32
|
|
||||||
TRT output 'hs_a_5': [1, 16, 1280, 5, 8] torch.float32
|
|
||||||
TRT output 'hs_a_6': [1, 16, 1280, 10, 16] torch.float32
|
|
||||||
TRT output 'hs_a_7': [1, 16, 640, 20, 32] torch.float32
|
|
||||||
TRT output 'hs_a_8': [1, 16, 320, 40, 64] torch.float32
|
|
||||||
>>> TRT backbone loaded from /home/qhy/unifolm-world-model-action/scripts/evaluation/../../trt_engines/video_backbone.engine
|
|
||||||
>>> Generate 16 frames under each generation ...
|
>>> Generate 16 frames under each generation ...
|
||||||
DEBUG:h5py._conv:Creating converter from 3 to 5
|
DEBUG:h5py._conv:Creating converter from 3 to 5
|
||||||
DEBUG:PIL.PngImagePlugin:STREAM b'IHDR' 16 13
|
DEBUG:PIL.PngImagePlugin:STREAM b'IHDR' 16 13
|
||||||
DEBUG:PIL.PngImagePlugin:STREAM b'pHYs' 41 9
|
DEBUG:PIL.PngImagePlugin:STREAM b'pHYs' 41 9
|
||||||
DEBUG:PIL.PngImagePlugin:STREAM b'IDAT' 62 4096
|
DEBUG:PIL.PngImagePlugin:STREAM b'IDAT' 62 4096
|
||||||
|
|
||||||
0%| | 0/11 [00:00<?, ?it/s][02/18/2026-19:05:59] [TRT] [W] Using default stream in enqueueV3() may lead to performance issues due to additional calls to cudaStreamSynchronize() by TensorRT to ensure correct synchronization. Please use non-default stream instead.
|
0%| | 0/11 [00:00<?, ?it/s]/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/torch/nn/functional.py:5501: UserWarning: Attempting to use hipBLASLt on an unsupported architecture! Overriding blas backend to hipblas (Triggered internally at ../aten/src/ATen/Context.cpp:296.)
|
||||||
|
proj = linear(q, w, b)
|
||||||
|
/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/torch/nn/functional.py:6278: UserWarning: Flash attention support on Navi31 GPU is still experimental. Enable it with TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL=1. (Triggered internally at ../aten/src/ATen/native/transformers/hip/sdp_utils.cpp:225.)
|
||||||
|
attn_output = scaled_dot_product_attention(
|
||||||
|
/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/torch/nn/functional.py:6278: UserWarning: Memory Efficient attention on Navi31 GPU is still experimental. Enable it with TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL=1. (Triggered internally at ../aten/src/ATen/native/transformers/hip/sdp_utils.cpp:269.)
|
||||||
|
attn_output = scaled_dot_product_attention(
|
||||||
>>> Step 0: generating actions ...
|
>>> Step 0: generating actions ...
|
||||||
9%|▉ | 1/11 [00:16<02:47, 16.71s/it]>>> Step 0: generating actions ...
|
|
||||||
>>> Step 0: interacting with world model ...
|
|
||||||
>>>>>>>>>>>>>>>>>>>>>>>>
|
|
||||||
>>> Step 0: interacting with world model ...
|
>>> Step 0: interacting with world model ...
|
||||||
DEBUG:PIL.Image:Importing BlpImagePlugin
|
DEBUG:PIL.Image:Importing BlpImagePlugin
|
||||||
DEBUG:PIL.Image:Importing BmpImagePlugin
|
DEBUG:PIL.Image:Importing BmpImagePlugin
|
||||||
DEBUG:PIL.Image:Importing BmpImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing BufrStubImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing BmpImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing BufrStubImagePlugin
|
DEBUG:PIL.Image:Importing BufrStubImagePlugin
|
||||||
DEBUG:PIL.Image:Importing CurImagePlugin
|
DEBUG:PIL.Image:Importing CurImagePlugin
|
||||||
DEBUG:PIL.Image:Importing CurImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing DcxImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing DcxImagePlugin
|
DEBUG:PIL.Image:Importing DcxImagePlugin
|
||||||
DEBUG:PIL.Image:Importing DdsImagePlugin
|
DEBUG:PIL.Image:Importing DdsImagePlugin
|
||||||
DEBUG:PIL.Image:Importing DdsImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing EpsImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing EpsImagePlugin
|
DEBUG:PIL.Image:Importing EpsImagePlugin
|
||||||
DEBUG:PIL.Image:Importing FitsImagePlugin
|
DEBUG:PIL.Image:Importing FitsImagePlugin
|
||||||
DEBUG:PIL.Image:Importing FitsImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing FitsStubImagePlugin
|
DEBUG:PIL.Image:Importing FitsStubImagePlugin
|
||||||
DEBUG:PIL.Image:Importing FitsStubImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing FliImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing FliImagePlugin
|
DEBUG:PIL.Image:Importing FliImagePlugin
|
||||||
DEBUG:PIL.Image:Importing FpxImagePlugin
|
DEBUG:PIL.Image:Importing FpxImagePlugin
|
||||||
DEBUG:PIL.Image:Image: failed to import FpxImagePlugin: No module named 'olefile'
|
DEBUG:PIL.Image:Image: failed to import FpxImagePlugin: No module named 'olefile'
|
||||||
DEBUG:PIL.Image:Image: failed to import FpxImagePlugin: No module named 'olefile'
|
|
||||||
DEBUG:PIL.Image:Importing FtexImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing FtexImagePlugin
|
DEBUG:PIL.Image:Importing FtexImagePlugin
|
||||||
DEBUG:PIL.Image:Importing GbrImagePlugin
|
DEBUG:PIL.Image:Importing GbrImagePlugin
|
||||||
DEBUG:PIL.Image:Importing GifImagePlugin
|
DEBUG:PIL.Image:Importing GifImagePlugin
|
||||||
DEBUG:PIL.Image:Importing GifImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing GribStubImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing GifImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing GribStubImagePlugin
|
DEBUG:PIL.Image:Importing GribStubImagePlugin
|
||||||
DEBUG:PIL.Image:Importing Hdf5StubImagePlugin
|
DEBUG:PIL.Image:Importing Hdf5StubImagePlugin
|
||||||
DEBUG:PIL.Image:Importing Hdf5StubImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing IcnsImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing IcnsImagePlugin
|
DEBUG:PIL.Image:Importing IcnsImagePlugin
|
||||||
DEBUG:PIL.Image:Importing IcoImagePlugin
|
DEBUG:PIL.Image:Importing IcoImagePlugin
|
||||||
DEBUG:PIL.Image:Importing IcoImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing ImImagePlugin
|
DEBUG:PIL.Image:Importing ImImagePlugin
|
||||||
DEBUG:PIL.Image:Importing ImImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing ImtImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing ImtImagePlugin
|
DEBUG:PIL.Image:Importing ImtImagePlugin
|
||||||
DEBUG:PIL.Image:Importing IptcImagePlugin
|
DEBUG:PIL.Image:Importing IptcImagePlugin
|
||||||
DEBUG:PIL.Image:Importing JpegImagePlugin
|
DEBUG:PIL.Image:Importing JpegImagePlugin
|
||||||
DEBUG:PIL.Image:Importing Jpeg2KImagePlugin
|
DEBUG:PIL.Image:Importing Jpeg2KImagePlugin
|
||||||
DEBUG:PIL.Image:Importing Jpeg2KImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing McIdasImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing JpegImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing Jpeg2KImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing McIdasImagePlugin
|
DEBUG:PIL.Image:Importing McIdasImagePlugin
|
||||||
DEBUG:PIL.Image:Importing MicImagePlugin
|
DEBUG:PIL.Image:Importing MicImagePlugin
|
||||||
DEBUG:PIL.Image:Image: failed to import MicImagePlugin: No module named 'olefile'
|
DEBUG:PIL.Image:Image: failed to import MicImagePlugin: No module named 'olefile'
|
||||||
DEBUG:PIL.Image:Image: failed to import MicImagePlugin: No module named 'olefile'
|
|
||||||
DEBUG:PIL.Image:Importing MpegImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing MpegImagePlugin
|
DEBUG:PIL.Image:Importing MpegImagePlugin
|
||||||
DEBUG:PIL.Image:Importing MpoImagePlugin
|
DEBUG:PIL.Image:Importing MpoImagePlugin
|
||||||
DEBUG:PIL.Image:Importing MpoImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing MspImagePlugin
|
DEBUG:PIL.Image:Importing MspImagePlugin
|
||||||
DEBUG:PIL.Image:Importing MspImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing PalmImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing PalmImagePlugin
|
DEBUG:PIL.Image:Importing PalmImagePlugin
|
||||||
DEBUG:PIL.Image:Importing PcdImagePlugin
|
DEBUG:PIL.Image:Importing PcdImagePlugin
|
||||||
DEBUG:PIL.Image:Importing PcxImagePlugin
|
DEBUG:PIL.Image:Importing PcxImagePlugin
|
||||||
DEBUG:PIL.Image:Importing PcxImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing PdfImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing PcxImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing PdfImagePlugin
|
DEBUG:PIL.Image:Importing PdfImagePlugin
|
||||||
DEBUG:PIL.Image:Importing PixarImagePlugin
|
DEBUG:PIL.Image:Importing PixarImagePlugin
|
||||||
DEBUG:PIL.Image:Importing PngImagePlugin
|
DEBUG:PIL.Image:Importing PngImagePlugin
|
||||||
DEBUG:PIL.Image:Importing PpmImagePlugin
|
DEBUG:PIL.Image:Importing PpmImagePlugin
|
||||||
DEBUG:PIL.Image:Importing PpmImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing PsdImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing PngImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing PpmImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing PsdImagePlugin
|
DEBUG:PIL.Image:Importing PsdImagePlugin
|
||||||
DEBUG:PIL.Image:Importing QoiImagePlugin
|
DEBUG:PIL.Image:Importing QoiImagePlugin
|
||||||
DEBUG:PIL.Image:Importing QoiImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing SgiImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing SgiImagePlugin
|
DEBUG:PIL.Image:Importing SgiImagePlugin
|
||||||
DEBUG:PIL.Image:Importing SpiderImagePlugin
|
DEBUG:PIL.Image:Importing SpiderImagePlugin
|
||||||
DEBUG:PIL.Image:Importing SpiderImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing SunImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing SunImagePlugin
|
DEBUG:PIL.Image:Importing SunImagePlugin
|
||||||
DEBUG:PIL.Image:Importing TgaImagePlugin
|
DEBUG:PIL.Image:Importing TgaImagePlugin
|
||||||
DEBUG:PIL.Image:Importing TiffImagePlugin
|
DEBUG:PIL.Image:Importing TiffImagePlugin
|
||||||
DEBUG:PIL.Image:Importing TiffImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing WebPImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing TiffImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing WebPImagePlugin
|
DEBUG:PIL.Image:Importing WebPImagePlugin
|
||||||
DEBUG:PIL.Image:Importing WmfImagePlugin
|
DEBUG:PIL.Image:Importing WmfImagePlugin
|
||||||
DEBUG:PIL.Image:Importing WmfImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing XbmImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing XbmImagePlugin
|
DEBUG:PIL.Image:Importing XbmImagePlugin
|
||||||
DEBUG:PIL.Image:Importing XpmImagePlugin
|
DEBUG:PIL.Image:Importing XpmImagePlugin
|
||||||
DEBUG:PIL.Image:Importing XpmImagePlugin
|
DEBUG:PIL.Image:Importing XVThumbImagePlugin
|
||||||
DEBUG:PIL.Image:Importing XVThumbImagePlugin
|
|
||||||
|
9%|▉ | 1/11 [01:14<12:29, 74.95s/it]
|
||||||
18%|█▊ | 2/11 [02:23<10:40, 71.18s/it]
|
18%|█▊ | 2/11 [02:23<10:40, 71.18s/it]
|
||||||
27%|██▋ | 3/11 [03:32<09:20, 70.05s/it]
|
27%|██▋ | 3/11 [03:32<09:20, 70.05s/it]
|
||||||
36%|███▋ | 4/11 [04:40<08:06, 69.51s/it]
|
36%|███▋ | 4/11 [04:40<08:06, 69.51s/it]
|
||||||
@@ -174,6 +153,6 @@ DEBUG:PIL.Image:Importing XVThumbImagePlugin
|
|||||||
>>> Step 6: generating actions ...
|
>>> Step 6: generating actions ...
|
||||||
>>> Step 6: interacting with world model ...
|
>>> Step 6: interacting with world model ...
|
||||||
>>>>>>>>>>>>>>>>>>>>>>>>
|
>>>>>>>>>>>>>>>>>>>>>>>>
|
||||||
>>> Step 7: generating actions ...
|
>>> Step 7: generating actions ...
|
||||||
>>> Step 7: interacting with world model ...
|
>>> Step 7: interacting with world model ...
|
||||||
>>>>>>>>>>>>>>>>>>>>>>>>
|
>>>>>>>>>>>>>>>>>>>>>>>>
|
||||||
|
|||||||
@@ -1,5 +1,5 @@
|
|||||||
{
|
{
|
||||||
"gt_video": "unitree_g1_pack_camera/case2/unitree_g1_pack_camera_case2.mp4",
|
"gt_video": "unitree_g1_pack_camera/case2/unitree_g1_pack_camera_case2.mp4",
|
||||||
"pred_video": "unitree_g1_pack_camera/case2/output/inference/50_full_fs6.mp4",
|
"pred_video": "unitree_g1_pack_camera/case2/output/inference/unitree_g1_pack_camera_case2_amd.mp4",
|
||||||
"psnr": 34.61979248212279
|
"psnr": 19.515250190529375
|
||||||
}
|
}
|
||||||
@@ -1,13 +1,32 @@
|
|||||||
2026-02-18 19:09:35.113634: I tensorflow/core/util/port.cc:113] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
|
2026-02-08 05:08:32.803904: I tensorflow/core/util/port.cc:113] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
|
||||||
2026-02-18 19:09:35.161428: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
|
2026-02-08 05:08:32.807010: I external/local_tsl/tsl/cuda/cudart_stub.cc:31] Could not find cuda drivers on your machine, GPU will not be used.
|
||||||
2026-02-18 19:09:35.161474: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
|
2026-02-08 05:08:32.837936: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
|
||||||
2026-02-18 19:09:35.162551: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
|
2026-02-08 05:08:32.837978: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
|
||||||
2026-02-18 19:09:35.169325: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
|
2026-02-08 05:08:32.839785: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
|
||||||
To enable the following instructions: AVX2 AVX512F AVX512_VNNI FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
|
2026-02-08 05:08:32.847835: I external/local_tsl/tsl/cuda/cudart_stub.cc:31] Could not find cuda drivers on your machine, GPU will not be used.
|
||||||
2026-02-18 19:09:36.089250: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT
|
2026-02-08 05:08:32.848223: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
|
||||||
Global seed set to 123
|
To enable the following instructions: AVX2 AVX512F AVX512_VNNI AVX512_BF16 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
|
||||||
>>> Loading prepared model from ckpts/unifolm_wma_dual.ckpt.prepared.pt ...
|
2026-02-08 05:08:34.120114: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT
|
||||||
>>> Prepared model loaded.
|
[rank: 0] Global seed set to 123
|
||||||
|
/mnt/ASC1637/miniconda3/envs/unifolm-wma/lib/python3.10/site-packages/kornia/feature/lightglue.py:44: FutureWarning: `torch.cuda.amp.custom_fwd(args...)` is deprecated. Please use `torch.amp.custom_fwd(args..., device_type='cuda')` instead.
|
||||||
|
@torch.cuda.amp.custom_fwd(cast_inputs=torch.float32)
|
||||||
|
INFO:mainlogger:LatentVisualDiffusion: Running in v-prediction mode
|
||||||
|
INFO:unifolm_wma.models.diffusion_head.conditional_unet1d:number of parameters: 5.010531e+08
|
||||||
|
INFO:unifolm_wma.models.diffusion_head.conditional_unet1d:number of parameters: 5.010531e+08
|
||||||
|
AE working on z of shape (1, 4, 32, 32) = 4096 dimensions.
|
||||||
|
INFO:root:Loaded ViT-H-14 model config.
|
||||||
|
DEBUG:urllib3.connectionpool:Starting new HTTPS connection (1): hf-mirror.com:443
|
||||||
|
DEBUG:urllib3.connectionpool:https://hf-mirror.com:443 "HEAD /laion/CLIP-ViT-H-14-laion2B-s32B-b79K/resolve/main/open_clip_pytorch_model.bin HTTP/1.1" 302 0
|
||||||
|
INFO:root:Loading pretrained ViT-H-14 weights (laion2b_s32b_b79k).
|
||||||
|
/mnt/ASC1637/miniconda3/envs/unifolm-wma/lib/python3.10/site-packages/open_clip/factory.py:88: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.
|
||||||
|
checkpoint = torch.load(checkpoint_path, map_location=map_location)
|
||||||
|
INFO:root:Loaded ViT-H-14 model config.
|
||||||
|
DEBUG:urllib3.connectionpool:https://hf-mirror.com:443 "HEAD /laion/CLIP-ViT-H-14-laion2B-s32B-b79K/resolve/main/open_clip_pytorch_model.bin HTTP/1.1" 302 0
|
||||||
|
INFO:root:Loading pretrained ViT-H-14 weights (laion2b_s32b_b79k).
|
||||||
|
/mnt/ASC1637/unifolm-world-model-action/scripts/evaluation/world_model_interaction.py:86: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.
|
||||||
|
state_dict = torch.load(ckpt, map_location="cpu")
|
||||||
|
>>> model checkpoint loaded.
|
||||||
|
>>> Load pre-trained model ...
|
||||||
INFO:root:***** Configing Data *****
|
INFO:root:***** Configing Data *****
|
||||||
>>> unitree_z1_stackbox: 1 data samples loaded.
|
>>> unitree_z1_stackbox: 1 data samples loaded.
|
||||||
>>> unitree_z1_stackbox: data stats loaded.
|
>>> unitree_z1_stackbox: data stats loaded.
|
||||||
@@ -25,122 +44,101 @@ INFO:root:***** Configing Data *****
|
|||||||
>>> unitree_g1_pack_camera: data stats loaded.
|
>>> unitree_g1_pack_camera: data stats loaded.
|
||||||
>>> unitree_g1_pack_camera: normalizer initiated.
|
>>> unitree_g1_pack_camera: normalizer initiated.
|
||||||
>>> Dataset is successfully loaded ...
|
>>> Dataset is successfully loaded ...
|
||||||
✓ KV fused: 66 attention layers
|
|
||||||
TRT output 'y': [1, 4, 16, 40, 64] torch.float32
|
|
||||||
TRT output 'hs_a_0': [1, 16, 320, 40, 64] torch.float32
|
|
||||||
TRT output 'hs_a_1': [1, 16, 640, 20, 32] torch.float32
|
|
||||||
TRT output 'hs_a_2': [1, 16, 1280, 10, 16] torch.float32
|
|
||||||
TRT output 'hs_a_3': [1, 16, 1280, 5, 8] torch.float32
|
|
||||||
TRT output 'hs_a_4': [1, 16, 1280, 5, 8] torch.float32
|
|
||||||
TRT output 'hs_a_5': [1, 16, 1280, 5, 8] torch.float32
|
|
||||||
TRT output 'hs_a_6': [1, 16, 1280, 10, 16] torch.float32
|
|
||||||
TRT output 'hs_a_7': [1, 16, 640, 20, 32] torch.float32
|
|
||||||
TRT output 'hs_a_8': [1, 16, 320, 40, 64] torch.float32
|
|
||||||
>>> TRT backbone loaded from /home/qhy/unifolm-world-model-action/scripts/evaluation/../../trt_engines/video_backbone.engine
|
|
||||||
>>> Generate 16 frames under each generation ...
|
>>> Generate 16 frames under each generation ...
|
||||||
DEBUG:h5py._conv:Creating converter from 3 to 5
|
DEBUG:h5py._conv:Creating converter from 3 to 5
|
||||||
DEBUG:PIL.PngImagePlugin:STREAM b'IHDR' 16 13
|
DEBUG:PIL.PngImagePlugin:STREAM b'IHDR' 16 13
|
||||||
DEBUG:PIL.PngImagePlugin:STREAM b'pHYs' 41 9
|
DEBUG:PIL.PngImagePlugin:STREAM b'pHYs' 41 9
|
||||||
DEBUG:PIL.PngImagePlugin:STREAM b'IDAT' 62 4096
|
DEBUG:PIL.PngImagePlugin:STREAM b'IDAT' 62 4096
|
||||||
|
|
||||||
0%| | 0/11 [00:00<?, ?it/s][02/18/2026-19:09:49] [TRT] [W] Using default stream in enqueueV3() may lead to performance issues due to additional calls to cudaStreamSynchronize() by TensorRT to ensure correct synchronization. Please use non-default stream instead.
|
0%| | 0/11 [00:00<?, ?it/s]/mnt/ASC1637/miniconda3/envs/unifolm-wma/lib/python3.10/site-packages/torch/nn/functional.py:5501: UserWarning: Attempting to use hipBLASLt on an unsupported architecture! Overriding blas backend to hipblas (Triggered internally at ../aten/src/ATen/Context.cpp:296.)
|
||||||
|
proj = linear(q, w, b)
|
||||||
|
/mnt/ASC1637/miniconda3/envs/unifolm-wma/lib/python3.10/site-packages/torch/nn/functional.py:6278: UserWarning: Flash attention support on Navi31 GPU is still experimental. Enable it with TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL=1. (Triggered internally at ../aten/src/ATen/native/transformers/hip/sdp_utils.cpp:225.)
|
||||||
|
attn_output = scaled_dot_product_attention(
|
||||||
|
/mnt/ASC1637/miniconda3/envs/unifolm-wma/lib/python3.10/site-packages/torch/nn/functional.py:6278: UserWarning: Memory Efficient attention on Navi31 GPU is still experimental. Enable it with TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL=1. (Triggered internally at ../aten/src/ATen/native/transformers/hip/sdp_utils.cpp:269.)
|
||||||
|
attn_output = scaled_dot_product_attention(
|
||||||
>>> Step 0: generating actions ...
|
>>> Step 0: generating actions ...
|
||||||
9%|▉ | 1/11 [00:16<02:45, 16.53s/it]>>> Step 0: generating actions ...
|
|
||||||
>>> Step 0: interacting with world model ...
|
|
||||||
>>>>>>>>>>>>>>>>>>>>>>>>
|
|
||||||
>>> Step 0: interacting with world model ...
|
>>> Step 0: interacting with world model ...
|
||||||
DEBUG:PIL.Image:Importing BlpImagePlugin
|
DEBUG:PIL.Image:Importing BlpImagePlugin
|
||||||
DEBUG:PIL.Image:Importing BmpImagePlugin
|
DEBUG:PIL.Image:Importing BmpImagePlugin
|
||||||
DEBUG:PIL.Image:Importing BmpImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing BufrStubImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing BmpImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing BufrStubImagePlugin
|
DEBUG:PIL.Image:Importing BufrStubImagePlugin
|
||||||
DEBUG:PIL.Image:Importing CurImagePlugin
|
DEBUG:PIL.Image:Importing CurImagePlugin
|
||||||
DEBUG:PIL.Image:Importing CurImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing DcxImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing DcxImagePlugin
|
DEBUG:PIL.Image:Importing DcxImagePlugin
|
||||||
DEBUG:PIL.Image:Importing DdsImagePlugin
|
DEBUG:PIL.Image:Importing DdsImagePlugin
|
||||||
DEBUG:PIL.Image:Importing DdsImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing EpsImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing EpsImagePlugin
|
DEBUG:PIL.Image:Importing EpsImagePlugin
|
||||||
DEBUG:PIL.Image:Importing FitsImagePlugin
|
DEBUG:PIL.Image:Importing FitsImagePlugin
|
||||||
DEBUG:PIL.Image:Importing FitsImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing FitsStubImagePlugin
|
DEBUG:PIL.Image:Importing FitsStubImagePlugin
|
||||||
DEBUG:PIL.Image:Importing FitsStubImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing FliImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing FliImagePlugin
|
DEBUG:PIL.Image:Importing FliImagePlugin
|
||||||
DEBUG:PIL.Image:Importing FpxImagePlugin
|
DEBUG:PIL.Image:Importing FpxImagePlugin
|
||||||
DEBUG:PIL.Image:Image: failed to import FpxImagePlugin: No module named 'olefile'
|
DEBUG:PIL.Image:Image: failed to import FpxImagePlugin: No module named 'olefile'
|
||||||
DEBUG:PIL.Image:Image: failed to import FpxImagePlugin: No module named 'olefile'
|
|
||||||
DEBUG:PIL.Image:Importing FtexImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing FtexImagePlugin
|
DEBUG:PIL.Image:Importing FtexImagePlugin
|
||||||
DEBUG:PIL.Image:Importing GbrImagePlugin
|
DEBUG:PIL.Image:Importing GbrImagePlugin
|
||||||
DEBUG:PIL.Image:Importing GifImagePlugin
|
DEBUG:PIL.Image:Importing GifImagePlugin
|
||||||
DEBUG:PIL.Image:Importing GifImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing GribStubImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing GifImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing GribStubImagePlugin
|
DEBUG:PIL.Image:Importing GribStubImagePlugin
|
||||||
DEBUG:PIL.Image:Importing Hdf5StubImagePlugin
|
DEBUG:PIL.Image:Importing Hdf5StubImagePlugin
|
||||||
DEBUG:PIL.Image:Importing Hdf5StubImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing IcnsImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing IcnsImagePlugin
|
DEBUG:PIL.Image:Importing IcnsImagePlugin
|
||||||
DEBUG:PIL.Image:Importing IcoImagePlugin
|
DEBUG:PIL.Image:Importing IcoImagePlugin
|
||||||
DEBUG:PIL.Image:Importing IcoImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing ImImagePlugin
|
DEBUG:PIL.Image:Importing ImImagePlugin
|
||||||
DEBUG:PIL.Image:Importing ImImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing ImtImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing ImtImagePlugin
|
DEBUG:PIL.Image:Importing ImtImagePlugin
|
||||||
DEBUG:PIL.Image:Importing IptcImagePlugin
|
DEBUG:PIL.Image:Importing IptcImagePlugin
|
||||||
DEBUG:PIL.Image:Importing JpegImagePlugin
|
DEBUG:PIL.Image:Importing JpegImagePlugin
|
||||||
DEBUG:PIL.Image:Importing Jpeg2KImagePlugin
|
DEBUG:PIL.Image:Importing Jpeg2KImagePlugin
|
||||||
DEBUG:PIL.Image:Importing Jpeg2KImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing McIdasImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing JpegImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing Jpeg2KImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing McIdasImagePlugin
|
DEBUG:PIL.Image:Importing McIdasImagePlugin
|
||||||
DEBUG:PIL.Image:Importing MicImagePlugin
|
DEBUG:PIL.Image:Importing MicImagePlugin
|
||||||
DEBUG:PIL.Image:Image: failed to import MicImagePlugin: No module named 'olefile'
|
DEBUG:PIL.Image:Image: failed to import MicImagePlugin: No module named 'olefile'
|
||||||
DEBUG:PIL.Image:Image: failed to import MicImagePlugin: No module named 'olefile'
|
|
||||||
DEBUG:PIL.Image:Importing MpegImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing MpegImagePlugin
|
DEBUG:PIL.Image:Importing MpegImagePlugin
|
||||||
DEBUG:PIL.Image:Importing MpoImagePlugin
|
DEBUG:PIL.Image:Importing MpoImagePlugin
|
||||||
DEBUG:PIL.Image:Importing MpoImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing MspImagePlugin
|
DEBUG:PIL.Image:Importing MspImagePlugin
|
||||||
DEBUG:PIL.Image:Importing MspImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing PalmImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing PalmImagePlugin
|
DEBUG:PIL.Image:Importing PalmImagePlugin
|
||||||
DEBUG:PIL.Image:Importing PcdImagePlugin
|
DEBUG:PIL.Image:Importing PcdImagePlugin
|
||||||
DEBUG:PIL.Image:Importing PcxImagePlugin
|
DEBUG:PIL.Image:Importing PcxImagePlugin
|
||||||
DEBUG:PIL.Image:Importing PcxImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing PdfImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing PcxImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing PdfImagePlugin
|
DEBUG:PIL.Image:Importing PdfImagePlugin
|
||||||
DEBUG:PIL.Image:Importing PixarImagePlugin
|
DEBUG:PIL.Image:Importing PixarImagePlugin
|
||||||
DEBUG:PIL.Image:Importing PngImagePlugin
|
DEBUG:PIL.Image:Importing PngImagePlugin
|
||||||
DEBUG:PIL.Image:Importing PpmImagePlugin
|
DEBUG:PIL.Image:Importing PpmImagePlugin
|
||||||
DEBUG:PIL.Image:Importing PpmImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing PsdImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing PngImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing PpmImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing PsdImagePlugin
|
DEBUG:PIL.Image:Importing PsdImagePlugin
|
||||||
DEBUG:PIL.Image:Importing QoiImagePlugin
|
DEBUG:PIL.Image:Importing QoiImagePlugin
|
||||||
DEBUG:PIL.Image:Importing QoiImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing SgiImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing SgiImagePlugin
|
DEBUG:PIL.Image:Importing SgiImagePlugin
|
||||||
DEBUG:PIL.Image:Importing SpiderImagePlugin
|
DEBUG:PIL.Image:Importing SpiderImagePlugin
|
||||||
DEBUG:PIL.Image:Importing SpiderImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing SunImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing SunImagePlugin
|
DEBUG:PIL.Image:Importing SunImagePlugin
|
||||||
DEBUG:PIL.Image:Importing TgaImagePlugin
|
DEBUG:PIL.Image:Importing TgaImagePlugin
|
||||||
DEBUG:PIL.Image:Importing TiffImagePlugin
|
DEBUG:PIL.Image:Importing TiffImagePlugin
|
||||||
DEBUG:PIL.Image:Importing TiffImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing WebPImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing TiffImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing WebPImagePlugin
|
DEBUG:PIL.Image:Importing WebPImagePlugin
|
||||||
DEBUG:PIL.Image:Importing WmfImagePlugin
|
DEBUG:PIL.Image:Importing WmfImagePlugin
|
||||||
DEBUG:PIL.Image:Importing WmfImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing XbmImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing XbmImagePlugin
|
DEBUG:PIL.Image:Importing XbmImagePlugin
|
||||||
DEBUG:PIL.Image:Importing XpmImagePlugin
|
DEBUG:PIL.Image:Importing XpmImagePlugin
|
||||||
DEBUG:PIL.Image:Importing XpmImagePlugin
|
DEBUG:PIL.Image:Importing XVThumbImagePlugin
|
||||||
DEBUG:PIL.Image:Importing XVThumbImagePlugin
|
|
||||||
|
9%|▉ | 1/11 [01:39<16:34, 99.46s/it]
|
||||||
|
18%|█▊ | 2/11 [03:18<14:55, 99.48s/it]
|
||||||
|
27%|██▋ | 3/11 [04:58<13:16, 99.60s/it]
|
||||||
|
36%|███▋ | 4/11 [06:38<11:37, 99.69s/it]
|
||||||
|
45%|████▌ | 5/11 [08:18<09:58, 99.68s/it]
|
||||||
|
55%|█████▍ | 6/11 [09:57<08:18, 99.66s/it]
|
||||||
|
64%|██████▎ | 7/11 [11:37<06:38, 99.62s/it]
|
||||||
|
73%|███████▎ | 8/11 [13:16<04:58, 99.55s/it]
|
||||||
|
82%|████████▏ | 9/11 [14:56<03:19, 99.50s/it]
|
||||||
|
91%|█████████ | 10/11 [16:35<01:39, 99.43s/it]
|
||||||
|
100%|██████████| 11/11 [18:14<00:00, 99.36s/it]
|
||||||
|
100%|██████████| 11/11 [18:14<00:00, 99.51s/it]
|
||||||
|
>>>>>>>>>>>>>>>>>>>>>>>>
|
||||||
|
>>> Step 1: generating actions ...
|
||||||
|
>>> Step 1: interacting with world model ...
|
||||||
|
>>>>>>>>>>>>>>>>>>>>>>>>
|
||||||
|
>>> Step 2: generating actions ...
|
||||||
|
>>> Step 2: interacting with world model ...
|
||||||
|
>>>>>>>>>>>>>>>>>>>>>>>>
|
||||||
|
>>> Step 3: generating actions ...
|
||||||
|
>>> Step 3: interacting with world model ...
|
||||||
|
>>>>>>>>>>>>>>>>>>>>>>>>
|
||||||
|
>>> Step 4: generating actions ...
|
||||||
|
>>> Step 4: interacting with world model ...
|
||||||
|
>>>>>>>>>>>>>>>>>>>>>>>>
|
||||||
|
>>> Step 5: generating actions ...
|
||||||
|
>>> Step 5: interacting with world model ...
|
||||||
|
>>>>>>>>>>>>>>>>>>>>>>>>
|
||||||
|
>>> Step 6: generating actions ...
|
||||||
|
>>> Step 6: interacting with world model ...
|
||||||
|
>>>>>>>>>>>>>>>>>>>>>>>>
|
||||||
|
>>> Step 7: generating actions ...
|
||||||
|
>>> Step 7: interacting with world model ...
|
||||||
|
>>>>>>>>>>>>>>>>>>>>>>>>
|
||||||
|
|||||||
@@ -1,5 +1,5 @@
|
|||||||
{
|
{
|
||||||
"gt_video": "unitree_g1_pack_camera/case3/unitree_g1_pack_camera_case3.mp4",
|
"gt_video": "unitree_g1_pack_camera/case3/unitree_g1_pack_camera_case3.mp4",
|
||||||
"pred_video": "unitree_g1_pack_camera/case3/output/inference/100_full_fs6.mp4",
|
"pred_video": "unitree_g1_pack_camera/case3/output/inference/unitree_g1_pack_camera_case3_amd.mp4",
|
||||||
"psnr": 37.034952654534486
|
"psnr": 19.429578160315536
|
||||||
}
|
}
|
||||||
144
unitree_g1_pack_camera/case4/output.log
Normal file
144
unitree_g1_pack_camera/case4/output.log
Normal file
@@ -0,0 +1,144 @@
|
|||||||
|
2026-02-08 05:29:19.728303: I tensorflow/core/util/port.cc:113] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
|
||||||
|
2026-02-08 05:29:19.731620: I external/local_tsl/tsl/cuda/cudart_stub.cc:31] Could not find cuda drivers on your machine, GPU will not be used.
|
||||||
|
2026-02-08 05:29:19.761276: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
|
||||||
|
2026-02-08 05:29:19.761301: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
|
||||||
|
2026-02-08 05:29:19.762880: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
|
||||||
|
2026-02-08 05:29:19.770578: I external/local_tsl/tsl/cuda/cudart_stub.cc:31] Could not find cuda drivers on your machine, GPU will not be used.
|
||||||
|
2026-02-08 05:29:19.771072: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
|
||||||
|
To enable the following instructions: AVX2 AVX512F AVX512_VNNI AVX512_BF16 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
|
||||||
|
2026-02-08 05:29:21.043661: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT
|
||||||
|
Global seed set to 123
|
||||||
|
/mnt/ASC1637/miniconda3/envs/unifolm-wma/lib/python3.10/site-packages/kornia/feature/lightglue.py:44: FutureWarning: `torch.cuda.amp.custom_fwd(args...)` is deprecated. Please use `torch.amp.custom_fwd(args..., device_type='cuda')` instead.
|
||||||
|
@torch.cuda.amp.custom_fwd(cast_inputs=torch.float32)
|
||||||
|
INFO:mainlogger:LatentVisualDiffusion: Running in v-prediction mode
|
||||||
|
INFO:unifolm_wma.models.diffusion_head.conditional_unet1d:number of parameters: 5.010531e+08
|
||||||
|
INFO:unifolm_wma.models.diffusion_head.conditional_unet1d:number of parameters: 5.010531e+08
|
||||||
|
AE working on z of shape (1, 4, 32, 32) = 4096 dimensions.
|
||||||
|
INFO:root:Loaded ViT-H-14 model config.
|
||||||
|
DEBUG:urllib3.connectionpool:Starting new HTTPS connection (1): hf-mirror.com:443
|
||||||
|
DEBUG:urllib3.connectionpool:https://hf-mirror.com:443 "HEAD /laion/CLIP-ViT-H-14-laion2B-s32B-b79K/resolve/main/open_clip_pytorch_model.bin HTTP/1.1" 302 0
|
||||||
|
INFO:root:Loading pretrained ViT-H-14 weights (laion2b_s32b_b79k).
|
||||||
|
/mnt/ASC1637/miniconda3/envs/unifolm-wma/lib/python3.10/site-packages/open_clip/factory.py:88: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.
|
||||||
|
checkpoint = torch.load(checkpoint_path, map_location=map_location)
|
||||||
|
INFO:root:Loaded ViT-H-14 model config.
|
||||||
|
DEBUG:urllib3.connectionpool:https://hf-mirror.com:443 "HEAD /laion/CLIP-ViT-H-14-laion2B-s32B-b79K/resolve/main/open_clip_pytorch_model.bin HTTP/1.1" 302 0
|
||||||
|
INFO:root:Loading pretrained ViT-H-14 weights (laion2b_s32b_b79k).
|
||||||
|
/mnt/ASC1637/unifolm-world-model-action/scripts/evaluation/world_model_interaction.py:86: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.
|
||||||
|
state_dict = torch.load(ckpt, map_location="cpu")
|
||||||
|
>>> model checkpoint loaded.
|
||||||
|
>>> Load pre-trained model ...
|
||||||
|
INFO:root:***** Configing Data *****
|
||||||
|
>>> unitree_z1_stackbox: 1 data samples loaded.
|
||||||
|
>>> unitree_z1_stackbox: data stats loaded.
|
||||||
|
>>> unitree_z1_stackbox: normalizer initiated.
|
||||||
|
>>> unitree_z1_dual_arm_stackbox: 1 data samples loaded.
|
||||||
|
>>> unitree_z1_dual_arm_stackbox: data stats loaded.
|
||||||
|
>>> unitree_z1_dual_arm_stackbox: normalizer initiated.
|
||||||
|
>>> unitree_z1_dual_arm_stackbox_v2: 1 data samples loaded.
|
||||||
|
>>> unitree_z1_dual_arm_stackbox_v2: data stats loaded.
|
||||||
|
>>> unitree_z1_dual_arm_stackbox_v2: normalizer initiated.
|
||||||
|
>>> unitree_z1_dual_arm_cleanup_pencils: 1 data samples loaded.
|
||||||
|
>>> unitree_z1_dual_arm_cleanup_pencils: data stats loaded.
|
||||||
|
>>> unitree_z1_dual_arm_cleanup_pencils: normalizer initiated.
|
||||||
|
>>> unitree_g1_pack_camera: 1 data samples loaded.
|
||||||
|
>>> unitree_g1_pack_camera: data stats loaded.
|
||||||
|
>>> unitree_g1_pack_camera: normalizer initiated.
|
||||||
|
>>> Dataset is successfully loaded ...
|
||||||
|
>>> Generate 16 frames under each generation ...
|
||||||
|
DEBUG:h5py._conv:Creating converter from 3 to 5
|
||||||
|
DEBUG:PIL.PngImagePlugin:STREAM b'IHDR' 16 13
|
||||||
|
DEBUG:PIL.PngImagePlugin:STREAM b'pHYs' 41 9
|
||||||
|
DEBUG:PIL.PngImagePlugin:STREAM b'IDAT' 62 4096
|
||||||
|
|
||||||
|
0%| | 0/11 [00:00<?, ?it/s]/mnt/ASC1637/miniconda3/envs/unifolm-wma/lib/python3.10/site-packages/torch/nn/functional.py:5501: UserWarning: Attempting to use hipBLASLt on an unsupported architecture! Overriding blas backend to hipblas (Triggered internally at ../aten/src/ATen/Context.cpp:296.)
|
||||||
|
proj = linear(q, w, b)
|
||||||
|
/mnt/ASC1637/miniconda3/envs/unifolm-wma/lib/python3.10/site-packages/torch/nn/functional.py:6278: UserWarning: Flash attention support on Navi31 GPU is still experimental. Enable it with TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL=1. (Triggered internally at ../aten/src/ATen/native/transformers/hip/sdp_utils.cpp:225.)
|
||||||
|
attn_output = scaled_dot_product_attention(
|
||||||
|
/mnt/ASC1637/miniconda3/envs/unifolm-wma/lib/python3.10/site-packages/torch/nn/functional.py:6278: UserWarning: Memory Efficient attention on Navi31 GPU is still experimental. Enable it with TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL=1. (Triggered internally at ../aten/src/ATen/native/transformers/hip/sdp_utils.cpp:269.)
|
||||||
|
attn_output = scaled_dot_product_attention(
|
||||||
|
>>> Step 0: generating actions ...
|
||||||
|
>>> Step 0: interacting with world model ...
|
||||||
|
DEBUG:PIL.Image:Importing BlpImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing BmpImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing BufrStubImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing CurImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing DcxImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing DdsImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing EpsImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing FitsImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing FitsStubImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing FliImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing FpxImagePlugin
|
||||||
|
DEBUG:PIL.Image:Image: failed to import FpxImagePlugin: No module named 'olefile'
|
||||||
|
DEBUG:PIL.Image:Importing FtexImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing GbrImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing GifImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing GribStubImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing Hdf5StubImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing IcnsImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing IcoImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing ImImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing ImtImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing IptcImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing JpegImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing Jpeg2KImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing McIdasImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing MicImagePlugin
|
||||||
|
DEBUG:PIL.Image:Image: failed to import MicImagePlugin: No module named 'olefile'
|
||||||
|
DEBUG:PIL.Image:Importing MpegImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing MpoImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing MspImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing PalmImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing PcdImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing PcxImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing PdfImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing PixarImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing PngImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing PpmImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing PsdImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing QoiImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing SgiImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing SpiderImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing SunImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing TgaImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing TiffImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing WebPImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing WmfImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing XbmImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing XpmImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing XVThumbImagePlugin
|
||||||
|
|
||||||
|
9%|▉ | 1/11 [01:37<16:18, 97.81s/it]
|
||||||
|
18%|█▊ | 2/11 [03:15<14:38, 97.56s/it]
|
||||||
|
27%|██▋ | 3/11 [04:52<12:59, 97.48s/it]
|
||||||
|
36%|███▋ | 4/11 [06:29<11:21, 97.38s/it]
|
||||||
|
45%|████▌ | 5/11 [08:06<09:43, 97.28s/it]
|
||||||
|
55%|█████▍ | 6/11 [09:44<08:06, 97.35s/it]
|
||||||
|
64%|██████▎ | 7/11 [11:21<06:29, 97.36s/it]
|
||||||
|
73%|███████▎ | 8/11 [12:59<04:52, 97.38s/it]
|
||||||
|
82%|████████▏ | 9/11 [14:36<03:14, 97.39s/it]
|
||||||
|
91%|█████████ | 10/11 [16:14<01:37, 97.42s/it]
|
||||||
|
100%|██████████| 11/11 [17:51<00:00, 97.42s/it]
|
||||||
|
100%|██████████| 11/11 [17:51<00:00, 97.41s/it]
|
||||||
|
>>>>>>>>>>>>>>>>>>>>>>>>
|
||||||
|
>>> Step 1: generating actions ...
|
||||||
|
>>> Step 1: interacting with world model ...
|
||||||
|
>>>>>>>>>>>>>>>>>>>>>>>>
|
||||||
|
>>> Step 2: generating actions ...
|
||||||
|
>>> Step 2: interacting with world model ...
|
||||||
|
>>>>>>>>>>>>>>>>>>>>>>>>
|
||||||
|
>>> Step 3: generating actions ...
|
||||||
|
>>> Step 3: interacting with world model ...
|
||||||
|
>>>>>>>>>>>>>>>>>>>>>>>>
|
||||||
|
>>> Step 4: generating actions ...
|
||||||
|
>>> Step 4: interacting with world model ...
|
||||||
|
>>>>>>>>>>>>>>>>>>>>>>>>
|
||||||
|
>>> Step 5: generating actions ...
|
||||||
|
>>> Step 5: interacting with world model ...
|
||||||
|
>>>>>>>>>>>>>>>>>>>>>>>>
|
||||||
|
>>> Step 6: generating actions ...
|
||||||
|
>>> Step 6: interacting with world model ...
|
||||||
|
>>>>>>>>>>>>>>>>>>>>>>>>
|
||||||
|
>>> Step 7: generating actions ...
|
||||||
|
>>> Step 7: interacting with world model ...
|
||||||
|
>>>>>>>>>>>>>>>>>>>>>>>>
|
||||||
@@ -1,5 +1,5 @@
|
|||||||
{
|
{
|
||||||
"gt_video": "unitree_g1_pack_camera/case4/unitree_g1_pack_camera_case4.mp4",
|
"gt_video": "unitree_g1_pack_camera/case4/unitree_g1_pack_camera_case4.mp4",
|
||||||
"pred_video": "unitree_g1_pack_camera/case4/output/inference/200_full_fs6.mp4",
|
"pred_video": "unitree_g1_pack_camera/case4/output/inference/unitree_g1_pack_camera_case4_amd.mp4",
|
||||||
"psnr": 31.43390896360405
|
"psnr": 17.80386833747375
|
||||||
}
|
}
|
||||||
@@ -1,11 +1,17 @@
|
|||||||
2026-02-10 15:38:28.973314: I tensorflow/core/util/port.cc:113] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
|
/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/lightning_fabric/__init__.py:29: UserWarning: pkg_resources is deprecated as an API. See https://setuptools.pypa.io/en/latest/pkg_resources.html. The pkg_resources package is slated for removal as early as 2025-11-30. Refrain from using this package or pin to Setuptools<81.
|
||||||
2026-02-10 15:38:29.023024: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
|
__import__("pkg_resources").declare_namespace(__name__)
|
||||||
2026-02-10 15:38:29.023070: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
|
2026-02-10 17:57:48.047156: I tensorflow/core/util/port.cc:113] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
|
||||||
2026-02-10 15:38:29.024393: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
|
2026-02-10 17:57:48.050303: I external/local_tsl/tsl/cuda/cudart_stub.cc:31] Could not find cuda drivers on your machine, GPU will not be used.
|
||||||
2026-02-10 15:38:29.031901: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
|
2026-02-10 17:57:48.081710: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
|
||||||
To enable the following instructions: AVX2 AVX512F AVX512_VNNI FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
|
2026-02-10 17:57:48.081741: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
|
||||||
2026-02-10 15:38:29.955454: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT
|
2026-02-10 17:57:48.083577: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
|
||||||
Global seed set to 123
|
2026-02-10 17:57:48.091772: I external/local_tsl/tsl/cuda/cudart_stub.cc:31] Could not find cuda drivers on your machine, GPU will not be used.
|
||||||
|
2026-02-10 17:57:48.092045: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
|
||||||
|
To enable the following instructions: AVX2 AVX512F AVX512_VNNI AVX512_BF16 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
|
||||||
|
2026-02-10 17:57:48.787960: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT
|
||||||
|
[rank: 0] Global seed set to 123
|
||||||
|
/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/kornia/feature/lightglue.py:44: FutureWarning: `torch.cuda.amp.custom_fwd(args...)` is deprecated. Please use `torch.amp.custom_fwd(args..., device_type='cuda')` instead.
|
||||||
|
@torch.cuda.amp.custom_fwd(cast_inputs=torch.float32)
|
||||||
INFO:mainlogger:LatentVisualDiffusion: Running in v-prediction mode
|
INFO:mainlogger:LatentVisualDiffusion: Running in v-prediction mode
|
||||||
INFO:unifolm_wma.models.diffusion_head.conditional_unet1d:number of parameters: 5.010531e+08
|
INFO:unifolm_wma.models.diffusion_head.conditional_unet1d:number of parameters: 5.010531e+08
|
||||||
INFO:unifolm_wma.models.diffusion_head.conditional_unet1d:number of parameters: 5.010531e+08
|
INFO:unifolm_wma.models.diffusion_head.conditional_unet1d:number of parameters: 5.010531e+08
|
||||||
@@ -14,11 +20,28 @@ INFO:root:Loaded ViT-H-14 model config.
|
|||||||
DEBUG:urllib3.connectionpool:Starting new HTTPS connection (1): hf-mirror.com:443
|
DEBUG:urllib3.connectionpool:Starting new HTTPS connection (1): hf-mirror.com:443
|
||||||
DEBUG:urllib3.connectionpool:https://hf-mirror.com:443 "HEAD /laion/CLIP-ViT-H-14-laion2B-s32B-b79K/resolve/main/open_clip_pytorch_model.bin HTTP/1.1" 302 0
|
DEBUG:urllib3.connectionpool:https://hf-mirror.com:443 "HEAD /laion/CLIP-ViT-H-14-laion2B-s32B-b79K/resolve/main/open_clip_pytorch_model.bin HTTP/1.1" 302 0
|
||||||
INFO:root:Loading pretrained ViT-H-14 weights (laion2b_s32b_b79k).
|
INFO:root:Loading pretrained ViT-H-14 weights (laion2b_s32b_b79k).
|
||||||
|
/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/open_clip/factory.py:88: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.
|
||||||
|
checkpoint = torch.load(checkpoint_path, map_location=map_location)
|
||||||
INFO:root:Loaded ViT-H-14 model config.
|
INFO:root:Loaded ViT-H-14 model config.
|
||||||
DEBUG:urllib3.connectionpool:https://hf-mirror.com:443 "HEAD /laion/CLIP-ViT-H-14-laion2B-s32B-b79K/resolve/main/open_clip_pytorch_model.bin HTTP/1.1" 302 0
|
DEBUG:urllib3.connectionpool:https://hf-mirror.com:443 "HEAD /laion/CLIP-ViT-H-14-laion2B-s32B-b79K/resolve/main/open_clip_pytorch_model.bin HTTP/1.1" 302 0
|
||||||
INFO:root:Loading pretrained ViT-H-14 weights (laion2b_s32b_b79k).
|
INFO:root:Loading pretrained ViT-H-14 weights (laion2b_s32b_b79k).
|
||||||
|
/mnt/ASC1637/unifolm-world-model-action/scripts/evaluation/world_model_interaction.py:198: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.
|
||||||
|
state_dict = torch.load(ckpt, map_location="cpu")
|
||||||
>>> model checkpoint loaded.
|
>>> model checkpoint loaded.
|
||||||
>>> Load pre-trained model ...
|
>>> Load pre-trained model ...
|
||||||
|
>>> Applying precision settings:
|
||||||
|
- Diffusion dtype: bf16
|
||||||
|
- Projector mode: bf16_full
|
||||||
|
- Encoder mode: bf16_full
|
||||||
|
- VAE dtype: bf16
|
||||||
|
✓ Diffusion model weights converted to bfloat16
|
||||||
|
✓ Projectors converted to bfloat16
|
||||||
|
✓ Encoders converted to bfloat16
|
||||||
|
✓ VAE converted to bfloat16
|
||||||
|
⚠ Found 601 fp32 params, converting to bf16
|
||||||
|
✓ All parameters converted to bfloat16
|
||||||
|
✓ torch.compile: 3 ResBlocks in output_blocks[5, 8, 9]
|
||||||
|
✓ KV fused: 66 attention layers
|
||||||
INFO:root:***** Configing Data *****
|
INFO:root:***** Configing Data *****
|
||||||
>>> unitree_z1_stackbox: 1 data samples loaded.
|
>>> unitree_z1_stackbox: 1 data samples loaded.
|
||||||
>>> unitree_z1_stackbox: data stats loaded.
|
>>> unitree_z1_stackbox: data stats loaded.
|
||||||
@@ -41,7 +64,9 @@ DEBUG:h5py._conv:Creating converter from 3 to 5
|
|||||||
DEBUG:PIL.PngImagePlugin:STREAM b'IHDR' 16 13
|
DEBUG:PIL.PngImagePlugin:STREAM b'IHDR' 16 13
|
||||||
DEBUG:PIL.PngImagePlugin:STREAM b'pHYs' 41 9
|
DEBUG:PIL.PngImagePlugin:STREAM b'pHYs' 41 9
|
||||||
DEBUG:PIL.PngImagePlugin:STREAM b'IDAT' 62 4096
|
DEBUG:PIL.PngImagePlugin:STREAM b'IDAT' 62 4096
|
||||||
|
|
||||||
|
0%| | 0/8 [00:00<?, ?it/s]/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/torch/nn/functional.py:5501: UserWarning: Attempting to use hipBLASLt on an unsupported architecture! Overriding blas backend to hipblas (Triggered internally at ../aten/src/ATen/Context.cpp:296.)
|
||||||
|
proj = linear(q, w, b)
|
||||||
>>> Step 0: generating actions ...
|
>>> Step 0: generating actions ...
|
||||||
>>> Step 0: interacting with world model ...
|
>>> Step 0: interacting with world model ...
|
||||||
DEBUG:PIL.Image:Importing BlpImagePlugin
|
DEBUG:PIL.Image:Importing BlpImagePlugin
|
||||||
@@ -92,7 +117,7 @@ DEBUG:PIL.Image:Importing WmfImagePlugin
|
|||||||
DEBUG:PIL.Image:Importing WmfImagePlugin
|
DEBUG:PIL.Image:Importing WmfImagePlugin
|
||||||
DEBUG:PIL.Image:Importing XbmImagePlugin
|
DEBUG:PIL.Image:Importing XbmImagePlugin
|
||||||
DEBUG:PIL.Image:Importing XpmImagePlugin
|
DEBUG:PIL.Image:Importing XpmImagePlugin
|
||||||
DEBUG:PIL.Image:Importing XVThumbImagePlugin
|
DEBUG:PIL.Image:Importing XVThumbImagePlugin
|
||||||
|
|
||||||
12%|█▎ | 1/8 [01:03<07:22, 63.25s/it]
|
12%|█▎ | 1/8 [01:03<07:22, 63.25s/it]
|
||||||
25%|██▌ | 2/8 [02:02<06:05, 60.93s/it]
|
25%|██▌ | 2/8 [02:02<06:05, 60.93s/it]
|
||||||
@@ -116,6 +141,6 @@ DEBUG:PIL.Image:Importing XVThumbImagePlugin
|
|||||||
>>> Step 4: generating actions ...
|
>>> Step 4: generating actions ...
|
||||||
>>> Step 4: interacting with world model ...
|
>>> Step 4: interacting with world model ...
|
||||||
>>>>>>>>>>>>>>>>>>>>>>>>
|
>>>>>>>>>>>>>>>>>>>>>>>>
|
||||||
>>> Step 5: generating actions ...
|
>>> Step 5: generating actions ...
|
||||||
>>> Step 5: interacting with world model ...
|
>>> Step 5: interacting with world model ...
|
||||||
>>>>>>>>>>>>>>>>>>>>>>>>
|
>>>>>>>>>>>>>>>>>>>>>>>>
|
||||||
|
|||||||
@@ -0,0 +1,5 @@
|
|||||||
|
itr,stack_to_device_1,policy/ddim_sampler_init,policy/image_embedding,policy/vae_encode,policy/text_conditioning,policy/projectors,policy/cond_assembly,policy/ddim_sampling,policy/vae_decode,synth_policy,update_action_queue,stack_to_device_2,wm/ddim_sampler_init,wm/image_embedding,wm/vae_encode,wm/text_conditioning,wm/projectors,wm/cond_assembly,wm/ddim_sampling,wm/vae_decode,synth_world_model,update_obs_queue,tensorboard_log,save_results,cpu_transfer,itr_total
|
||||||
|
0,0.16,0.08,20.98,49.56,14.51,0.29,0.07,31005.48,0.00,31094.51,0.39,0.13,0.09,20.62,48.76,14.17,0.28,0.07,31011.17,775.40,31875.87,0.61,0.31,97.28,7.19,63077.50
|
||||||
|
1,0.16,0.09,20.97,49.63,14.52,0.30,0.07,31035.49,0.00,31125.16,0.54,0.17,0.14,21.46,49.26,14.88,0.49,0.12,31047.54,777.56,31918.60,0.75,0.60,109.89,6.21,63163.18
|
||||||
|
2,0.18,0.10,21.44,49.71,15.05,0.34,0.07,31047.64,0.00,31138.56,0.58,0.16,0.13,21.03,48.74,14.69,0.32,0.08,31036.47,776.96,31905.96,0.67,0.39,116.96,7.43,63171.90
|
||||||
|
3,0.18,0.10,21.38,49.47,15.02,0.35,0.08,31041.05,0.00,31132.03,0.48,0.16,0.12,20.81,49.34,14.41,0.47,0.11,31051.98,777.11,31920.42,0.64,0.38,121.67,7.29,63184.26
|
||||||
|
@@ -0,0 +1,5 @@
|
|||||||
|
stat,stack_to_device_1,policy/ddim_sampler_init,policy/image_embedding,policy/vae_encode,policy/text_conditioning,policy/projectors,policy/cond_assembly,policy/ddim_sampling,policy/vae_decode,synth_policy,update_action_queue,stack_to_device_2,wm/ddim_sampler_init,wm/image_embedding,wm/vae_encode,wm/text_conditioning,wm/projectors,wm/cond_assembly,wm/ddim_sampling,wm/vae_decode,synth_world_model,update_obs_queue,tensorboard_log,save_results,cpu_transfer,itr_total
|
||||||
|
mean,0.17,0.09,21.19,49.59,14.78,0.32,0.07,31032.42,0.00,31122.56,0.49,0.15,0.12,20.98,49.03,14.53,0.39,0.10,31036.79,776.76,31905.21,0.67,0.42,111.45,7.03,63149.21
|
||||||
|
std,0.01,0.01,0.22,0.09,0.26,0.03,0.00,16.13,0.00,16.88,0.07,0.01,0.02,0.31,0.28,0.27,0.09,0.02,15.83,0.82,17.84,0.05,0.11,9.19,0.48,42.08
|
||||||
|
min,0.16,0.08,20.97,49.47,14.51,0.29,0.07,31005.48,0.00,31094.51,0.39,0.13,0.09,20.62,48.74,14.17,0.28,0.07,31011.17,775.40,31875.87,0.61,0.31,97.28,6.21,63077.50
|
||||||
|
max,0.18,0.10,21.44,49.71,15.05,0.35,0.08,31047.64,0.00,31138.56,0.58,0.17,0.14,21.46,49.34,14.88,0.49,0.12,31051.98,777.56,31920.42,0.75,0.60,121.67,7.43,63184.26
|
||||||
|
@@ -0,0 +1,45 @@
|
|||||||
|
/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/lightning_fabric/__init__.py:29: UserWarning: pkg_resources is deprecated as an API. See https://setuptools.pypa.io/en/latest/pkg_resources.html. The pkg_resources package is slated for removal as early as 2025-11-30. Refrain from using this package or pin to Setuptools<81.
|
||||||
|
__import__("pkg_resources").declare_namespace(__name__)
|
||||||
|
[rank: 0] Global seed set to 123
|
||||||
|
/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/kornia/feature/lightglue.py:44: FutureWarning: `torch.cuda.amp.custom_fwd(args...)` is deprecated. Please use `torch.amp.custom_fwd(args..., device_type='cuda')` instead.
|
||||||
|
@torch.cuda.amp.custom_fwd(cast_inputs=torch.float32)
|
||||||
|
/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/open_clip/factory.py:88: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.
|
||||||
|
checkpoint = torch.load(checkpoint_path, map_location=map_location)
|
||||||
|
/mnt/ASC1637/unifolm-world-model-action/scripts/evaluation/profile_iteration.py:168: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.
|
||||||
|
state_dict = torch.load(args.ckpt_path, map_location="cpu")
|
||||||
|
============================================================
|
||||||
|
PROFILE ITERATION — Loading model...
|
||||||
|
============================================================
|
||||||
|
AE working on z of shape (1, 4, 32, 32) = 4096 dimensions.
|
||||||
|
torch.compile: 3 ResBlocks in output_blocks[5, 8, 9]
|
||||||
|
>>> Model loaded and ready.
|
||||||
|
>>> Noise shape: [1, 4, 16, 40, 64]
|
||||||
|
>>> DDIM steps: 50
|
||||||
|
>>> fast_policy_no_decode: True
|
||||||
|
============================================================
|
||||||
|
LAYER 1: ITERATION-LEVEL PROFILING
|
||||||
|
============================================================
|
||||||
|
>>> unitree_z1_stackbox: 1 data samples loaded.
|
||||||
|
>>> unitree_z1_stackbox: data stats loaded.
|
||||||
|
>>> unitree_z1_stackbox: normalizer initiated.
|
||||||
|
>>> unitree_z1_dual_arm_stackbox: 1 data samples loaded.
|
||||||
|
>>> unitree_z1_dual_arm_stackbox: data stats loaded.
|
||||||
|
>>> unitree_z1_dual_arm_stackbox: normalizer initiated.
|
||||||
|
>>> unitree_z1_dual_arm_stackbox_v2: 1 data samples loaded.
|
||||||
|
>>> unitree_z1_dual_arm_stackbox_v2: data stats loaded.
|
||||||
|
>>> unitree_z1_dual_arm_stackbox_v2: normalizer initiated.
|
||||||
|
>>> unitree_z1_dual_arm_cleanup_pencils: 1 data samples loaded.
|
||||||
|
>>> unitree_z1_dual_arm_cleanup_pencils: data stats loaded.
|
||||||
|
>>> unitree_z1_dual_arm_cleanup_pencils: normalizer initiated.
|
||||||
|
>>> unitree_g1_pack_camera: 1 data samples loaded.
|
||||||
|
>>> unitree_g1_pack_camera: data stats loaded.
|
||||||
|
>>> unitree_g1_pack_camera: normalizer initiated.
|
||||||
|
>>> Running 5 profiled iterations ...
|
||||||
|
Traceback (most recent call last):
|
||||||
|
File "/mnt/ASC1637/unifolm-world-model-action/scripts/evaluation/profile_iteration.py", line 981, in <module>
|
||||||
|
main()
|
||||||
|
File "/mnt/ASC1637/unifolm-world-model-action/scripts/evaluation/profile_iteration.py", line 967, in main
|
||||||
|
all_records = run_profiled_iterations(
|
||||||
|
File "/mnt/ASC1637/unifolm-world-model-action/scripts/evaluation/profile_iteration.py", line 502, in run_profiled_iterations
|
||||||
|
sampler_type=args.sampler_type)
|
||||||
|
AttributeError: 'Namespace' object has no attribute 'sampler_type'
|
||||||
@@ -1,5 +1,5 @@
|
|||||||
{
|
{
|
||||||
"gt_video": "unitree_z1_dual_arm_cleanup_pencils/case1/unitree_z1_dual_arm_cleanup_pencils_case1.mp4",
|
"gt_video": "unitree_z1_dual_arm_cleanup_pencils/case1/unitree_z1_dual_arm_cleanup_pencils_case1.mp4",
|
||||||
"pred_video": "unitree_z1_dual_arm_cleanup_pencils/case1/output/inference/0_full_fs4.mp4",
|
"pred_video": "unitree_z1_dual_arm_cleanup_pencils/case1/output/inference/unitree_z1_dual_arm_cleanup_pencils_case1_amd.mp4",
|
||||||
"psnr": 47.911564449209735
|
"psnr": 19.586376345676264
|
||||||
}
|
}
|
||||||
@@ -0,0 +1,5 @@
|
|||||||
|
{
|
||||||
|
"gt_video": "/mnt/ASC1637/unifolm-world-model-action/unitree_z1_dual_arm_cleanup_pencils/case1/output/inference/unitree_z1_dual_arm_cleanup_pencils_case1_amd.mp4",
|
||||||
|
"pred_video": "/mnt/ASC1637/unifolm-world-model-action/unitree_z1_dual_arm_cleanup_pencils/case1/output/inference/0_full_fs4.mp4",
|
||||||
|
"psnr": 32.442113263955434
|
||||||
|
}
|
||||||
5
unitree_z1_dual_arm_cleanup_pencils/case1/run_profile.sh
Normal file
5
unitree_z1_dual_arm_cleanup_pencils/case1/run_profile.sh
Normal file
@@ -0,0 +1,5 @@
|
|||||||
|
#\!/bin/bash
|
||||||
|
res_dir="unitree_z1_dual_arm_cleanup_pencils/case1"
|
||||||
|
dataset="unitree_z1_dual_arm_cleanup_pencils"
|
||||||
|
|
||||||
|
TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL=1 CUDA_VISIBLE_DEVICES=0 python3 scripts/evaluation/profile_iteration.py --seed 123 --ckpt_path ckpts/unifolm_wma_dual_mix_bf16.ckpt --config configs/inference/world_model_interaction.yaml --savedir "${res_dir}/profile_output" --prompt_dir "${res_dir}/world_model_interaction_prompts" --dataset ${dataset} --bs 1 --height 320 --width 512 --unconditional_guidance_scale 1.0 --ddim_steps 50 --ddim_eta 1.0 --video_length 16 --frame_stride 4 --exe_steps 16 --n_iter 5 --warmup 1 --timestep_spacing uniform_trailing --guidance_rescale 0.7 --perframe_ae --vae_dtype bf16 --fast_policy_no_decode --csv "${res_dir}/profile_output/baseline.csv" 2>&1 | tee "${res_dir}/profile_output/profile.log"
|
||||||
@@ -2,9 +2,9 @@ res_dir="unitree_z1_dual_arm_cleanup_pencils/case1"
|
|||||||
dataset="unitree_z1_dual_arm_cleanup_pencils"
|
dataset="unitree_z1_dual_arm_cleanup_pencils"
|
||||||
|
|
||||||
{
|
{
|
||||||
time CUDA_VISIBLE_DEVICES=0 python3 scripts/evaluation/world_model_interaction.py \
|
time TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL=1 CUDA_VISIBLE_DEVICES=0 python3 scripts/evaluation/world_model_interaction.py \
|
||||||
--seed 123 \
|
--seed 123 \
|
||||||
--ckpt_path ckpts/unifolm_wma_dual.ckpt \
|
--ckpt_path ckpts/unifolm_wma_dual_mix_bf16.ckpt \
|
||||||
--config configs/inference/world_model_interaction.yaml \
|
--config configs/inference/world_model_interaction.yaml \
|
||||||
--savedir "${res_dir}/output" \
|
--savedir "${res_dir}/output" \
|
||||||
--bs 1 --height 320 --width 512 \
|
--bs 1 --height 320 --width 512 \
|
||||||
@@ -20,5 +20,7 @@ dataset="unitree_z1_dual_arm_cleanup_pencils"
|
|||||||
--n_iter 8 \
|
--n_iter 8 \
|
||||||
--timestep_spacing 'uniform_trailing' \
|
--timestep_spacing 'uniform_trailing' \
|
||||||
--guidance_rescale 0.7 \
|
--guidance_rescale 0.7 \
|
||||||
--perframe_ae
|
--perframe_ae \
|
||||||
|
--vae_dtype bf16 \
|
||||||
|
--fast_policy_no_decode
|
||||||
} 2>&1 | tee "${res_dir}/output.log"
|
} 2>&1 | tee "${res_dir}/output.log"
|
||||||
|
|||||||
137
unitree_z1_dual_arm_cleanup_pencils/case2/output.log
Normal file
137
unitree_z1_dual_arm_cleanup_pencils/case2/output.log
Normal file
@@ -0,0 +1,137 @@
|
|||||||
|
/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/lightning_fabric/__init__.py:29: UserWarning: pkg_resources is deprecated as an API. See https://setuptools.pypa.io/en/latest/pkg_resources.html. The pkg_resources package is slated for removal as early as 2025-11-30. Refrain from using this package or pin to Setuptools<81.
|
||||||
|
__import__("pkg_resources").declare_namespace(__name__)
|
||||||
|
2026-02-08 06:59:34.465946: I tensorflow/core/util/port.cc:113] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
|
||||||
|
2026-02-08 06:59:34.469367: I external/local_tsl/tsl/cuda/cudart_stub.cc:31] Could not find cuda drivers on your machine, GPU will not be used.
|
||||||
|
2026-02-08 06:59:34.500805: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
|
||||||
|
2026-02-08 06:59:34.500837: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
|
||||||
|
2026-02-08 06:59:34.502917: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
|
||||||
|
2026-02-08 06:59:34.511434: I external/local_tsl/tsl/cuda/cudart_stub.cc:31] Could not find cuda drivers on your machine, GPU will not be used.
|
||||||
|
2026-02-08 06:59:34.511678: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
|
||||||
|
To enable the following instructions: AVX2 AVX512F AVX512_VNNI AVX512_BF16 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
|
||||||
|
2026-02-08 06:59:35.478194: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT
|
||||||
|
Global seed set to 123
|
||||||
|
/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/kornia/feature/lightglue.py:44: FutureWarning: `torch.cuda.amp.custom_fwd(args...)` is deprecated. Please use `torch.amp.custom_fwd(args..., device_type='cuda')` instead.
|
||||||
|
@torch.cuda.amp.custom_fwd(cast_inputs=torch.float32)
|
||||||
|
INFO:mainlogger:LatentVisualDiffusion: Running in v-prediction mode
|
||||||
|
INFO:unifolm_wma.models.diffusion_head.conditional_unet1d:number of parameters: 5.010531e+08
|
||||||
|
INFO:unifolm_wma.models.diffusion_head.conditional_unet1d:number of parameters: 5.010531e+08
|
||||||
|
AE working on z of shape (1, 4, 32, 32) = 4096 dimensions.
|
||||||
|
INFO:root:Loaded ViT-H-14 model config.
|
||||||
|
DEBUG:urllib3.connectionpool:Starting new HTTPS connection (1): hf-mirror.com:443
|
||||||
|
DEBUG:urllib3.connectionpool:https://hf-mirror.com:443 "HEAD /laion/CLIP-ViT-H-14-laion2B-s32B-b79K/resolve/main/open_clip_pytorch_model.bin HTTP/1.1" 302 0
|
||||||
|
INFO:root:Loading pretrained ViT-H-14 weights (laion2b_s32b_b79k).
|
||||||
|
/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/open_clip/factory.py:88: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.
|
||||||
|
checkpoint = torch.load(checkpoint_path, map_location=map_location)
|
||||||
|
INFO:root:Loaded ViT-H-14 model config.
|
||||||
|
DEBUG:urllib3.connectionpool:https://hf-mirror.com:443 "HEAD /laion/CLIP-ViT-H-14-laion2B-s32B-b79K/resolve/main/open_clip_pytorch_model.bin HTTP/1.1" 302 0
|
||||||
|
INFO:root:Loading pretrained ViT-H-14 weights (laion2b_s32b_b79k).
|
||||||
|
/mnt/ASC1637/unifolm-world-model-action/scripts/evaluation/world_model_interaction.py:86: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.
|
||||||
|
state_dict = torch.load(ckpt, map_location="cpu")
|
||||||
|
>>> model checkpoint loaded.
|
||||||
|
>>> Load pre-trained model ...
|
||||||
|
INFO:root:***** Configing Data *****
|
||||||
|
>>> unitree_z1_stackbox: 1 data samples loaded.
|
||||||
|
>>> unitree_z1_stackbox: data stats loaded.
|
||||||
|
>>> unitree_z1_stackbox: normalizer initiated.
|
||||||
|
>>> unitree_z1_dual_arm_stackbox: 1 data samples loaded.
|
||||||
|
>>> unitree_z1_dual_arm_stackbox: data stats loaded.
|
||||||
|
>>> unitree_z1_dual_arm_stackbox: normalizer initiated.
|
||||||
|
>>> unitree_z1_dual_arm_stackbox_v2: 1 data samples loaded.
|
||||||
|
>>> unitree_z1_dual_arm_stackbox_v2: data stats loaded.
|
||||||
|
>>> unitree_z1_dual_arm_stackbox_v2: normalizer initiated.
|
||||||
|
>>> unitree_z1_dual_arm_cleanup_pencils: 1 data samples loaded.
|
||||||
|
>>> unitree_z1_dual_arm_cleanup_pencils: data stats loaded.
|
||||||
|
>>> unitree_z1_dual_arm_cleanup_pencils: normalizer initiated.
|
||||||
|
>>> unitree_g1_pack_camera: 1 data samples loaded.
|
||||||
|
>>> unitree_g1_pack_camera: data stats loaded.
|
||||||
|
>>> unitree_g1_pack_camera: normalizer initiated.
|
||||||
|
>>> Dataset is successfully loaded ...
|
||||||
|
>>> Generate 16 frames under each generation ...
|
||||||
|
DEBUG:h5py._conv:Creating converter from 3 to 5
|
||||||
|
DEBUG:PIL.PngImagePlugin:STREAM b'IHDR' 16 13
|
||||||
|
DEBUG:PIL.PngImagePlugin:STREAM b'pHYs' 41 9
|
||||||
|
DEBUG:PIL.PngImagePlugin:STREAM b'IDAT' 62 4096
|
||||||
|
|
||||||
|
0%| | 0/8 [00:00<?, ?it/s]/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/torch/nn/functional.py:5501: UserWarning: Attempting to use hipBLASLt on an unsupported architecture! Overriding blas backend to hipblas (Triggered internally at ../aten/src/ATen/Context.cpp:296.)
|
||||||
|
proj = linear(q, w, b)
|
||||||
|
/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/torch/nn/functional.py:6278: UserWarning: Flash attention support on Navi31 GPU is still experimental. Enable it with TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL=1. (Triggered internally at ../aten/src/ATen/native/transformers/hip/sdp_utils.cpp:225.)
|
||||||
|
attn_output = scaled_dot_product_attention(
|
||||||
|
/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/torch/nn/functional.py:6278: UserWarning: Memory Efficient attention on Navi31 GPU is still experimental. Enable it with TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL=1. (Triggered internally at ../aten/src/ATen/native/transformers/hip/sdp_utils.cpp:269.)
|
||||||
|
attn_output = scaled_dot_product_attention(
|
||||||
|
>>> Step 0: generating actions ...
|
||||||
|
>>> Step 0: interacting with world model ...
|
||||||
|
DEBUG:PIL.Image:Importing BlpImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing BmpImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing BufrStubImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing CurImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing DcxImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing DdsImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing EpsImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing FitsImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing FitsStubImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing FliImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing FpxImagePlugin
|
||||||
|
DEBUG:PIL.Image:Image: failed to import FpxImagePlugin: No module named 'olefile'
|
||||||
|
DEBUG:PIL.Image:Importing FtexImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing GbrImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing GifImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing GribStubImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing Hdf5StubImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing IcnsImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing IcoImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing ImImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing ImtImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing IptcImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing JpegImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing Jpeg2KImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing McIdasImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing MicImagePlugin
|
||||||
|
DEBUG:PIL.Image:Image: failed to import MicImagePlugin: No module named 'olefile'
|
||||||
|
DEBUG:PIL.Image:Importing MpegImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing MpoImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing MspImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing PalmImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing PcdImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing PcxImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing PdfImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing PixarImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing PngImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing PpmImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing PsdImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing QoiImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing SgiImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing SpiderImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing SunImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing TgaImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing TiffImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing WebPImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing WmfImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing XbmImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing XpmImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing XVThumbImagePlugin
|
||||||
|
|
||||||
|
12%|█▎ | 1/8 [01:37<11:23, 97.57s/it]
|
||||||
|
25%|██▌ | 2/8 [03:14<09:44, 97.48s/it]
|
||||||
|
38%|███▊ | 3/8 [04:52<08:07, 97.47s/it]
|
||||||
|
50%|█████ | 4/8 [06:29<06:29, 97.49s/it]
|
||||||
|
62%|██████▎ | 5/8 [08:07<04:52, 97.42s/it]
|
||||||
|
75%|███████▌ | 6/8 [09:44<03:14, 97.32s/it]
|
||||||
|
88%|████████▊ | 7/8 [11:21<01:37, 97.34s/it]
|
||||||
|
100%|██████████| 8/8 [12:59<00:00, 97.36s/it]
|
||||||
|
100%|██████████| 8/8 [12:59<00:00, 97.40s/it]
|
||||||
|
>>>>>>>>>>>>>>>>>>>>>>>>
|
||||||
|
>>> Step 1: generating actions ...
|
||||||
|
>>> Step 1: interacting with world model ...
|
||||||
|
>>>>>>>>>>>>>>>>>>>>>>>>
|
||||||
|
>>> Step 2: generating actions ...
|
||||||
|
>>> Step 2: interacting with world model ...
|
||||||
|
>>>>>>>>>>>>>>>>>>>>>>>>
|
||||||
|
>>> Step 3: generating actions ...
|
||||||
|
>>> Step 3: interacting with world model ...
|
||||||
|
>>>>>>>>>>>>>>>>>>>>>>>>
|
||||||
|
>>> Step 4: generating actions ...
|
||||||
|
>>> Step 4: interacting with world model ...
|
||||||
|
>>>>>>>>>>>>>>>>>>>>>>>>
|
||||||
|
>>> Step 5: generating actions ...
|
||||||
|
>>> Step 5: interacting with world model ...
|
||||||
|
>>>>>>>>>>>>>>>>>>>>>>>>
|
||||||
@@ -1,5 +1,5 @@
|
|||||||
{
|
{
|
||||||
"gt_video": "unitree_z1_dual_arm_cleanup_pencils/case2/unitree_z1_dual_arm_cleanup_pencils_case2.mp4",
|
"gt_video": "unitree_z1_dual_arm_cleanup_pencils/case2/unitree_z1_dual_arm_cleanup_pencils_case2.mp4",
|
||||||
"pred_video": "unitree_z1_dual_arm_cleanup_pencils/case2/output/inference/50_full_fs4.mp4",
|
"pred_video": "unitree_z1_dual_arm_cleanup_pencils/case2/output/inference/unitree_z1_dual_arm_cleanup_pencils_case2_amd.mp4",
|
||||||
"psnr": 48.344571927558974
|
"psnr": 20.484298972158296
|
||||||
}
|
}
|
||||||
137
unitree_z1_dual_arm_cleanup_pencils/case3/output.log
Normal file
137
unitree_z1_dual_arm_cleanup_pencils/case3/output.log
Normal file
@@ -0,0 +1,137 @@
|
|||||||
|
/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/lightning_fabric/__init__.py:29: UserWarning: pkg_resources is deprecated as an API. See https://setuptools.pypa.io/en/latest/pkg_resources.html. The pkg_resources package is slated for removal as early as 2025-11-30. Refrain from using this package or pin to Setuptools<81.
|
||||||
|
__import__("pkg_resources").declare_namespace(__name__)
|
||||||
|
2026-02-08 07:18:52.629976: I tensorflow/core/util/port.cc:113] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
|
||||||
|
2026-02-08 07:18:52.633025: I external/local_tsl/tsl/cuda/cudart_stub.cc:31] Could not find cuda drivers on your machine, GPU will not be used.
|
||||||
|
2026-02-08 07:18:52.663985: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
|
||||||
|
2026-02-08 07:18:52.664018: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
|
||||||
|
2026-02-08 07:18:52.665837: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
|
||||||
|
2026-02-08 07:18:52.673889: I external/local_tsl/tsl/cuda/cudart_stub.cc:31] Could not find cuda drivers on your machine, GPU will not be used.
|
||||||
|
2026-02-08 07:18:52.674218: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
|
||||||
|
To enable the following instructions: AVX2 AVX512F AVX512_VNNI AVX512_BF16 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
|
||||||
|
2026-02-08 07:18:53.298338: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT
|
||||||
|
[rank: 0] Global seed set to 123
|
||||||
|
/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/kornia/feature/lightglue.py:44: FutureWarning: `torch.cuda.amp.custom_fwd(args...)` is deprecated. Please use `torch.amp.custom_fwd(args..., device_type='cuda')` instead.
|
||||||
|
@torch.cuda.amp.custom_fwd(cast_inputs=torch.float32)
|
||||||
|
INFO:mainlogger:LatentVisualDiffusion: Running in v-prediction mode
|
||||||
|
INFO:unifolm_wma.models.diffusion_head.conditional_unet1d:number of parameters: 5.010531e+08
|
||||||
|
INFO:unifolm_wma.models.diffusion_head.conditional_unet1d:number of parameters: 5.010531e+08
|
||||||
|
AE working on z of shape (1, 4, 32, 32) = 4096 dimensions.
|
||||||
|
INFO:root:Loaded ViT-H-14 model config.
|
||||||
|
DEBUG:urllib3.connectionpool:Starting new HTTPS connection (1): hf-mirror.com:443
|
||||||
|
DEBUG:urllib3.connectionpool:https://hf-mirror.com:443 "HEAD /laion/CLIP-ViT-H-14-laion2B-s32B-b79K/resolve/main/open_clip_pytorch_model.bin HTTP/1.1" 302 0
|
||||||
|
INFO:root:Loading pretrained ViT-H-14 weights (laion2b_s32b_b79k).
|
||||||
|
/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/open_clip/factory.py:88: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.
|
||||||
|
checkpoint = torch.load(checkpoint_path, map_location=map_location)
|
||||||
|
INFO:root:Loaded ViT-H-14 model config.
|
||||||
|
DEBUG:urllib3.connectionpool:https://hf-mirror.com:443 "HEAD /laion/CLIP-ViT-H-14-laion2B-s32B-b79K/resolve/main/open_clip_pytorch_model.bin HTTP/1.1" 302 0
|
||||||
|
INFO:root:Loading pretrained ViT-H-14 weights (laion2b_s32b_b79k).
|
||||||
|
/mnt/ASC1637/unifolm-world-model-action/scripts/evaluation/world_model_interaction.py:86: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.
|
||||||
|
state_dict = torch.load(ckpt, map_location="cpu")
|
||||||
|
>>> model checkpoint loaded.
|
||||||
|
>>> Load pre-trained model ...
|
||||||
|
INFO:root:***** Configing Data *****
|
||||||
|
>>> unitree_z1_stackbox: 1 data samples loaded.
|
||||||
|
>>> unitree_z1_stackbox: data stats loaded.
|
||||||
|
>>> unitree_z1_stackbox: normalizer initiated.
|
||||||
|
>>> unitree_z1_dual_arm_stackbox: 1 data samples loaded.
|
||||||
|
>>> unitree_z1_dual_arm_stackbox: data stats loaded.
|
||||||
|
>>> unitree_z1_dual_arm_stackbox: normalizer initiated.
|
||||||
|
>>> unitree_z1_dual_arm_stackbox_v2: 1 data samples loaded.
|
||||||
|
>>> unitree_z1_dual_arm_stackbox_v2: data stats loaded.
|
||||||
|
>>> unitree_z1_dual_arm_stackbox_v2: normalizer initiated.
|
||||||
|
>>> unitree_z1_dual_arm_cleanup_pencils: 1 data samples loaded.
|
||||||
|
>>> unitree_z1_dual_arm_cleanup_pencils: data stats loaded.
|
||||||
|
>>> unitree_z1_dual_arm_cleanup_pencils: normalizer initiated.
|
||||||
|
>>> unitree_g1_pack_camera: 1 data samples loaded.
|
||||||
|
>>> unitree_g1_pack_camera: data stats loaded.
|
||||||
|
>>> unitree_g1_pack_camera: normalizer initiated.
|
||||||
|
>>> Dataset is successfully loaded ...
|
||||||
|
>>> Generate 16 frames under each generation ...
|
||||||
|
DEBUG:h5py._conv:Creating converter from 3 to 5
|
||||||
|
DEBUG:PIL.PngImagePlugin:STREAM b'IHDR' 16 13
|
||||||
|
DEBUG:PIL.PngImagePlugin:STREAM b'pHYs' 41 9
|
||||||
|
DEBUG:PIL.PngImagePlugin:STREAM b'IDAT' 62 4096
|
||||||
|
|
||||||
|
0%| | 0/8 [00:00<?, ?it/s]/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/torch/nn/functional.py:5501: UserWarning: Attempting to use hipBLASLt on an unsupported architecture! Overriding blas backend to hipblas (Triggered internally at ../aten/src/ATen/Context.cpp:296.)
|
||||||
|
proj = linear(q, w, b)
|
||||||
|
/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/torch/nn/functional.py:6278: UserWarning: Flash attention support on Navi31 GPU is still experimental. Enable it with TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL=1. (Triggered internally at ../aten/src/ATen/native/transformers/hip/sdp_utils.cpp:225.)
|
||||||
|
attn_output = scaled_dot_product_attention(
|
||||||
|
/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/torch/nn/functional.py:6278: UserWarning: Memory Efficient attention on Navi31 GPU is still experimental. Enable it with TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL=1. (Triggered internally at ../aten/src/ATen/native/transformers/hip/sdp_utils.cpp:269.)
|
||||||
|
attn_output = scaled_dot_product_attention(
|
||||||
|
>>> Step 0: generating actions ...
|
||||||
|
>>> Step 0: interacting with world model ...
|
||||||
|
DEBUG:PIL.Image:Importing BlpImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing BmpImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing BufrStubImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing CurImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing DcxImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing DdsImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing EpsImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing FitsImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing FitsStubImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing FliImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing FpxImagePlugin
|
||||||
|
DEBUG:PIL.Image:Image: failed to import FpxImagePlugin: No module named 'olefile'
|
||||||
|
DEBUG:PIL.Image:Importing FtexImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing GbrImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing GifImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing GribStubImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing Hdf5StubImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing IcnsImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing IcoImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing ImImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing ImtImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing IptcImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing JpegImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing Jpeg2KImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing McIdasImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing MicImagePlugin
|
||||||
|
DEBUG:PIL.Image:Image: failed to import MicImagePlugin: No module named 'olefile'
|
||||||
|
DEBUG:PIL.Image:Importing MpegImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing MpoImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing MspImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing PalmImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing PcdImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing PcxImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing PdfImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing PixarImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing PngImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing PpmImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing PsdImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing QoiImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing SgiImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing SpiderImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing SunImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing TgaImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing TiffImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing WebPImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing WmfImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing XbmImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing XpmImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing XVThumbImagePlugin
|
||||||
|
|
||||||
|
12%|█▎ | 1/8 [01:40<11:43, 100.54s/it]
|
||||||
|
25%|██▌ | 2/8 [03:20<10:02, 100.36s/it]
|
||||||
|
38%|███▊ | 3/8 [05:01<08:21, 100.32s/it]
|
||||||
|
50%|█████ | 4/8 [06:41<06:41, 100.36s/it]
|
||||||
|
62%|██████▎ | 5/8 [08:21<05:00, 100.30s/it]
|
||||||
|
75%|███████▌ | 6/8 [10:01<03:20, 100.28s/it]
|
||||||
|
88%|████████▊ | 7/8 [11:42<01:40, 100.34s/it]
|
||||||
|
100%|██████████| 8/8 [13:22<00:00, 100.36s/it]
|
||||||
|
100%|██████████| 8/8 [13:22<00:00, 100.34s/it]
|
||||||
|
>>>>>>>>>>>>>>>>>>>>>>>>
|
||||||
|
>>> Step 1: generating actions ...
|
||||||
|
>>> Step 1: interacting with world model ...
|
||||||
|
>>>>>>>>>>>>>>>>>>>>>>>>
|
||||||
|
>>> Step 2: generating actions ...
|
||||||
|
>>> Step 2: interacting with world model ...
|
||||||
|
>>>>>>>>>>>>>>>>>>>>>>>>
|
||||||
|
>>> Step 3: generating actions ...
|
||||||
|
>>> Step 3: interacting with world model ...
|
||||||
|
>>>>>>>>>>>>>>>>>>>>>>>>
|
||||||
|
>>> Step 4: generating actions ...
|
||||||
|
>>> Step 4: interacting with world model ...
|
||||||
|
>>>>>>>>>>>>>>>>>>>>>>>>
|
||||||
|
>>> Step 5: generating actions ...
|
||||||
|
>>> Step 5: interacting with world model ...
|
||||||
|
>>>>>>>>>>>>>>>>>>>>>>>>
|
||||||
@@ -1,5 +1,5 @@
|
|||||||
{
|
{
|
||||||
"gt_video": "unitree_z1_dual_arm_cleanup_pencils/case3/unitree_z1_dual_arm_cleanup_pencils_case3.mp4",
|
"gt_video": "unitree_z1_dual_arm_cleanup_pencils/case3/unitree_z1_dual_arm_cleanup_pencils_case3.mp4",
|
||||||
"pred_video": "unitree_z1_dual_arm_cleanup_pencils/case3/output/inference/100_full_fs4.mp4",
|
"pred_video": "unitree_z1_dual_arm_cleanup_pencils/case3/output/inference/unitree_z1_dual_arm_cleanup_pencils_case3_amd.mp4",
|
||||||
"psnr": 41.152374490134825
|
"psnr": 21.20205061239349
|
||||||
}
|
}
|
||||||
137
unitree_z1_dual_arm_cleanup_pencils/case4/output.log
Normal file
137
unitree_z1_dual_arm_cleanup_pencils/case4/output.log
Normal file
@@ -0,0 +1,137 @@
|
|||||||
|
/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/lightning_fabric/__init__.py:29: UserWarning: pkg_resources is deprecated as an API. See https://setuptools.pypa.io/en/latest/pkg_resources.html. The pkg_resources package is slated for removal as early as 2025-11-30. Refrain from using this package or pin to Setuptools<81.
|
||||||
|
__import__("pkg_resources").declare_namespace(__name__)
|
||||||
|
2026-02-08 07:22:15.333099: I tensorflow/core/util/port.cc:113] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
|
||||||
|
2026-02-08 07:22:15.336215: I external/local_tsl/tsl/cuda/cudart_stub.cc:31] Could not find cuda drivers on your machine, GPU will not be used.
|
||||||
|
2026-02-08 07:22:15.366489: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
|
||||||
|
2026-02-08 07:22:15.366522: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
|
||||||
|
2026-02-08 07:22:15.368294: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
|
||||||
|
2026-02-08 07:22:15.376202: I external/local_tsl/tsl/cuda/cudart_stub.cc:31] Could not find cuda drivers on your machine, GPU will not be used.
|
||||||
|
2026-02-08 07:22:15.376444: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
|
||||||
|
To enable the following instructions: AVX2 AVX512F AVX512_VNNI AVX512_BF16 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
|
||||||
|
2026-02-08 07:22:15.995383: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT
|
||||||
|
[rank: 0] Global seed set to 123
|
||||||
|
/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/kornia/feature/lightglue.py:44: FutureWarning: `torch.cuda.amp.custom_fwd(args...)` is deprecated. Please use `torch.amp.custom_fwd(args..., device_type='cuda')` instead.
|
||||||
|
@torch.cuda.amp.custom_fwd(cast_inputs=torch.float32)
|
||||||
|
INFO:mainlogger:LatentVisualDiffusion: Running in v-prediction mode
|
||||||
|
INFO:unifolm_wma.models.diffusion_head.conditional_unet1d:number of parameters: 5.010531e+08
|
||||||
|
INFO:unifolm_wma.models.diffusion_head.conditional_unet1d:number of parameters: 5.010531e+08
|
||||||
|
AE working on z of shape (1, 4, 32, 32) = 4096 dimensions.
|
||||||
|
INFO:root:Loaded ViT-H-14 model config.
|
||||||
|
DEBUG:urllib3.connectionpool:Starting new HTTPS connection (1): hf-mirror.com:443
|
||||||
|
DEBUG:urllib3.connectionpool:https://hf-mirror.com:443 "HEAD /laion/CLIP-ViT-H-14-laion2B-s32B-b79K/resolve/main/open_clip_pytorch_model.bin HTTP/1.1" 302 0
|
||||||
|
INFO:root:Loading pretrained ViT-H-14 weights (laion2b_s32b_b79k).
|
||||||
|
/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/open_clip/factory.py:88: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.
|
||||||
|
checkpoint = torch.load(checkpoint_path, map_location=map_location)
|
||||||
|
INFO:root:Loaded ViT-H-14 model config.
|
||||||
|
DEBUG:urllib3.connectionpool:https://hf-mirror.com:443 "HEAD /laion/CLIP-ViT-H-14-laion2B-s32B-b79K/resolve/main/open_clip_pytorch_model.bin HTTP/1.1" 302 0
|
||||||
|
INFO:root:Loading pretrained ViT-H-14 weights (laion2b_s32b_b79k).
|
||||||
|
/mnt/ASC1637/unifolm-world-model-action/scripts/evaluation/world_model_interaction.py:86: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.
|
||||||
|
state_dict = torch.load(ckpt, map_location="cpu")
|
||||||
|
>>> model checkpoint loaded.
|
||||||
|
>>> Load pre-trained model ...
|
||||||
|
INFO:root:***** Configing Data *****
|
||||||
|
>>> unitree_z1_stackbox: 1 data samples loaded.
|
||||||
|
>>> unitree_z1_stackbox: data stats loaded.
|
||||||
|
>>> unitree_z1_stackbox: normalizer initiated.
|
||||||
|
>>> unitree_z1_dual_arm_stackbox: 1 data samples loaded.
|
||||||
|
>>> unitree_z1_dual_arm_stackbox: data stats loaded.
|
||||||
|
>>> unitree_z1_dual_arm_stackbox: normalizer initiated.
|
||||||
|
>>> unitree_z1_dual_arm_stackbox_v2: 1 data samples loaded.
|
||||||
|
>>> unitree_z1_dual_arm_stackbox_v2: data stats loaded.
|
||||||
|
>>> unitree_z1_dual_arm_stackbox_v2: normalizer initiated.
|
||||||
|
>>> unitree_z1_dual_arm_cleanup_pencils: 1 data samples loaded.
|
||||||
|
>>> unitree_z1_dual_arm_cleanup_pencils: data stats loaded.
|
||||||
|
>>> unitree_z1_dual_arm_cleanup_pencils: normalizer initiated.
|
||||||
|
>>> unitree_g1_pack_camera: 1 data samples loaded.
|
||||||
|
>>> unitree_g1_pack_camera: data stats loaded.
|
||||||
|
>>> unitree_g1_pack_camera: normalizer initiated.
|
||||||
|
>>> Dataset is successfully loaded ...
|
||||||
|
>>> Generate 16 frames under each generation ...
|
||||||
|
DEBUG:h5py._conv:Creating converter from 3 to 5
|
||||||
|
DEBUG:PIL.PngImagePlugin:STREAM b'IHDR' 16 13
|
||||||
|
DEBUG:PIL.PngImagePlugin:STREAM b'pHYs' 41 9
|
||||||
|
DEBUG:PIL.PngImagePlugin:STREAM b'IDAT' 62 4096
|
||||||
|
|
||||||
|
0%| | 0/8 [00:00<?, ?it/s]/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/torch/nn/functional.py:5501: UserWarning: Attempting to use hipBLASLt on an unsupported architecture! Overriding blas backend to hipblas (Triggered internally at ../aten/src/ATen/Context.cpp:296.)
|
||||||
|
proj = linear(q, w, b)
|
||||||
|
/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/torch/nn/functional.py:6278: UserWarning: Flash attention support on Navi31 GPU is still experimental. Enable it with TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL=1. (Triggered internally at ../aten/src/ATen/native/transformers/hip/sdp_utils.cpp:225.)
|
||||||
|
attn_output = scaled_dot_product_attention(
|
||||||
|
/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/torch/nn/functional.py:6278: UserWarning: Memory Efficient attention on Navi31 GPU is still experimental. Enable it with TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL=1. (Triggered internally at ../aten/src/ATen/native/transformers/hip/sdp_utils.cpp:269.)
|
||||||
|
attn_output = scaled_dot_product_attention(
|
||||||
|
>>> Step 0: generating actions ...
|
||||||
|
>>> Step 0: interacting with world model ...
|
||||||
|
DEBUG:PIL.Image:Importing BlpImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing BmpImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing BufrStubImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing CurImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing DcxImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing DdsImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing EpsImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing FitsImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing FitsStubImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing FliImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing FpxImagePlugin
|
||||||
|
DEBUG:PIL.Image:Image: failed to import FpxImagePlugin: No module named 'olefile'
|
||||||
|
DEBUG:PIL.Image:Importing FtexImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing GbrImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing GifImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing GribStubImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing Hdf5StubImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing IcnsImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing IcoImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing ImImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing ImtImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing IptcImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing JpegImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing Jpeg2KImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing McIdasImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing MicImagePlugin
|
||||||
|
DEBUG:PIL.Image:Image: failed to import MicImagePlugin: No module named 'olefile'
|
||||||
|
DEBUG:PIL.Image:Importing MpegImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing MpoImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing MspImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing PalmImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing PcdImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing PcxImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing PdfImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing PixarImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing PngImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing PpmImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing PsdImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing QoiImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing SgiImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing SpiderImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing SunImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing TgaImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing TiffImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing WebPImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing WmfImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing XbmImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing XpmImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing XVThumbImagePlugin
|
||||||
|
|
||||||
|
12%|█▎ | 1/8 [01:37<11:23, 97.68s/it]
|
||||||
|
25%|██▌ | 2/8 [03:15<09:47, 97.83s/it]
|
||||||
|
38%|███▊ | 3/8 [04:53<08:09, 97.91s/it]
|
||||||
|
50%|█████ | 4/8 [06:31<06:32, 98.03s/it]
|
||||||
|
62%|██████▎ | 5/8 [08:10<04:54, 98.11s/it]
|
||||||
|
75%|███████▌ | 6/8 [09:48<03:16, 98.18s/it]
|
||||||
|
88%|████████▊ | 7/8 [11:26<01:38, 98.24s/it]
|
||||||
|
100%|██████████| 8/8 [13:04<00:00, 98.16s/it]
|
||||||
|
100%|██████████| 8/8 [13:04<00:00, 98.09s/it]
|
||||||
|
>>>>>>>>>>>>>>>>>>>>>>>>
|
||||||
|
>>> Step 1: generating actions ...
|
||||||
|
>>> Step 1: interacting with world model ...
|
||||||
|
>>>>>>>>>>>>>>>>>>>>>>>>
|
||||||
|
>>> Step 2: generating actions ...
|
||||||
|
>>> Step 2: interacting with world model ...
|
||||||
|
>>>>>>>>>>>>>>>>>>>>>>>>
|
||||||
|
>>> Step 3: generating actions ...
|
||||||
|
>>> Step 3: interacting with world model ...
|
||||||
|
>>>>>>>>>>>>>>>>>>>>>>>>
|
||||||
|
>>> Step 4: generating actions ...
|
||||||
|
>>> Step 4: interacting with world model ...
|
||||||
|
>>>>>>>>>>>>>>>>>>>>>>>>
|
||||||
|
>>> Step 5: generating actions ...
|
||||||
|
>>> Step 5: interacting with world model ...
|
||||||
|
>>>>>>>>>>>>>>>>>>>>>>>>
|
||||||
@@ -1,5 +1,5 @@
|
|||||||
{
|
{
|
||||||
"gt_video": "unitree_z1_dual_arm_cleanup_pencils/case4/unitree_z1_dual_arm_cleanup_pencils_case4.mp4",
|
"gt_video": "unitree_z1_dual_arm_cleanup_pencils/case4/unitree_z1_dual_arm_cleanup_pencils_case4.mp4",
|
||||||
"pred_video": "unitree_z1_dual_arm_cleanup_pencils/case4/output/inference/200_full_fs4.mp4",
|
"pred_video": "unitree_z1_dual_arm_cleanup_pencils/case4/output/inference/unitree_z1_dual_arm_cleanup_pencils_case4_amd.mp4",
|
||||||
"psnr": 46.025723557253855
|
"psnr": 21.130122583788612
|
||||||
}
|
}
|
||||||
134
unitree_z1_dual_arm_stackbox/case1/output.log
Normal file
134
unitree_z1_dual_arm_stackbox/case1/output.log
Normal file
@@ -0,0 +1,134 @@
|
|||||||
|
/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/lightning_fabric/__init__.py:29: UserWarning: pkg_resources is deprecated as an API. See https://setuptools.pypa.io/en/latest/pkg_resources.html. The pkg_resources package is slated for removal as early as 2025-11-30. Refrain from using this package or pin to Setuptools<81.
|
||||||
|
__import__("pkg_resources").declare_namespace(__name__)
|
||||||
|
2026-02-08 07:24:40.357099: I tensorflow/core/util/port.cc:113] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
|
||||||
|
2026-02-08 07:24:40.360365: I external/local_tsl/tsl/cuda/cudart_stub.cc:31] Could not find cuda drivers on your machine, GPU will not be used.
|
||||||
|
2026-02-08 07:24:40.391744: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
|
||||||
|
2026-02-08 07:24:40.391772: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
|
||||||
|
2026-02-08 07:24:40.393608: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
|
||||||
|
2026-02-08 07:24:40.401837: I external/local_tsl/tsl/cuda/cudart_stub.cc:31] Could not find cuda drivers on your machine, GPU will not be used.
|
||||||
|
2026-02-08 07:24:40.402077: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
|
||||||
|
To enable the following instructions: AVX2 AVX512F AVX512_VNNI AVX512_BF16 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
|
||||||
|
2026-02-08 07:24:41.022382: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT
|
||||||
|
Global seed set to 123
|
||||||
|
/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/kornia/feature/lightglue.py:44: FutureWarning: `torch.cuda.amp.custom_fwd(args...)` is deprecated. Please use `torch.amp.custom_fwd(args..., device_type='cuda')` instead.
|
||||||
|
@torch.cuda.amp.custom_fwd(cast_inputs=torch.float32)
|
||||||
|
INFO:mainlogger:LatentVisualDiffusion: Running in v-prediction mode
|
||||||
|
INFO:unifolm_wma.models.diffusion_head.conditional_unet1d:number of parameters: 5.010531e+08
|
||||||
|
INFO:unifolm_wma.models.diffusion_head.conditional_unet1d:number of parameters: 5.010531e+08
|
||||||
|
AE working on z of shape (1, 4, 32, 32) = 4096 dimensions.
|
||||||
|
INFO:root:Loaded ViT-H-14 model config.
|
||||||
|
DEBUG:urllib3.connectionpool:Starting new HTTPS connection (1): hf-mirror.com:443
|
||||||
|
DEBUG:urllib3.connectionpool:https://hf-mirror.com:443 "HEAD /laion/CLIP-ViT-H-14-laion2B-s32B-b79K/resolve/main/open_clip_pytorch_model.bin HTTP/1.1" 302 0
|
||||||
|
INFO:root:Loading pretrained ViT-H-14 weights (laion2b_s32b_b79k).
|
||||||
|
/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/open_clip/factory.py:88: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.
|
||||||
|
checkpoint = torch.load(checkpoint_path, map_location=map_location)
|
||||||
|
INFO:root:Loaded ViT-H-14 model config.
|
||||||
|
DEBUG:urllib3.connectionpool:https://hf-mirror.com:443 "HEAD /laion/CLIP-ViT-H-14-laion2B-s32B-b79K/resolve/main/open_clip_pytorch_model.bin HTTP/1.1" 302 0
|
||||||
|
INFO:root:Loading pretrained ViT-H-14 weights (laion2b_s32b_b79k).
|
||||||
|
/mnt/ASC1637/unifolm-world-model-action/scripts/evaluation/world_model_interaction.py:86: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.
|
||||||
|
state_dict = torch.load(ckpt, map_location="cpu")
|
||||||
|
>>> model checkpoint loaded.
|
||||||
|
>>> Load pre-trained model ...
|
||||||
|
INFO:root:***** Configing Data *****
|
||||||
|
>>> unitree_z1_stackbox: 1 data samples loaded.
|
||||||
|
>>> unitree_z1_stackbox: data stats loaded.
|
||||||
|
>>> unitree_z1_stackbox: normalizer initiated.
|
||||||
|
>>> unitree_z1_dual_arm_stackbox: 1 data samples loaded.
|
||||||
|
>>> unitree_z1_dual_arm_stackbox: data stats loaded.
|
||||||
|
>>> unitree_z1_dual_arm_stackbox: normalizer initiated.
|
||||||
|
>>> unitree_z1_dual_arm_stackbox_v2: 1 data samples loaded.
|
||||||
|
>>> unitree_z1_dual_arm_stackbox_v2: data stats loaded.
|
||||||
|
>>> unitree_z1_dual_arm_stackbox_v2: normalizer initiated.
|
||||||
|
>>> unitree_z1_dual_arm_cleanup_pencils: 1 data samples loaded.
|
||||||
|
>>> unitree_z1_dual_arm_cleanup_pencils: data stats loaded.
|
||||||
|
>>> unitree_z1_dual_arm_cleanup_pencils: normalizer initiated.
|
||||||
|
>>> unitree_g1_pack_camera: 1 data samples loaded.
|
||||||
|
>>> unitree_g1_pack_camera: data stats loaded.
|
||||||
|
>>> unitree_g1_pack_camera: normalizer initiated.
|
||||||
|
>>> Dataset is successfully loaded ...
|
||||||
|
>>> Generate 16 frames under each generation ...
|
||||||
|
DEBUG:h5py._conv:Creating converter from 3 to 5
|
||||||
|
DEBUG:PIL.PngImagePlugin:STREAM b'IHDR' 16 13
|
||||||
|
DEBUG:PIL.PngImagePlugin:STREAM b'pHYs' 41 9
|
||||||
|
DEBUG:PIL.PngImagePlugin:STREAM b'IDAT' 62 4096
|
||||||
|
|
||||||
|
0%| | 0/7 [00:00<?, ?it/s]/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/torch/nn/functional.py:5501: UserWarning: Attempting to use hipBLASLt on an unsupported architecture! Overriding blas backend to hipblas (Triggered internally at ../aten/src/ATen/Context.cpp:296.)
|
||||||
|
proj = linear(q, w, b)
|
||||||
|
/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/torch/nn/functional.py:6278: UserWarning: Flash attention support on Navi31 GPU is still experimental. Enable it with TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL=1. (Triggered internally at ../aten/src/ATen/native/transformers/hip/sdp_utils.cpp:225.)
|
||||||
|
attn_output = scaled_dot_product_attention(
|
||||||
|
/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/torch/nn/functional.py:6278: UserWarning: Memory Efficient attention on Navi31 GPU is still experimental. Enable it with TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL=1. (Triggered internally at ../aten/src/ATen/native/transformers/hip/sdp_utils.cpp:269.)
|
||||||
|
attn_output = scaled_dot_product_attention(
|
||||||
|
>>> Step 0: generating actions ...
|
||||||
|
>>> Step 0: interacting with world model ...
|
||||||
|
DEBUG:PIL.Image:Importing BlpImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing BmpImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing BufrStubImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing CurImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing DcxImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing DdsImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing EpsImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing FitsImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing FitsStubImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing FliImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing FpxImagePlugin
|
||||||
|
DEBUG:PIL.Image:Image: failed to import FpxImagePlugin: No module named 'olefile'
|
||||||
|
DEBUG:PIL.Image:Importing FtexImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing GbrImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing GifImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing GribStubImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing Hdf5StubImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing IcnsImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing IcoImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing ImImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing ImtImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing IptcImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing JpegImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing Jpeg2KImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing McIdasImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing MicImagePlugin
|
||||||
|
DEBUG:PIL.Image:Image: failed to import MicImagePlugin: No module named 'olefile'
|
||||||
|
DEBUG:PIL.Image:Importing MpegImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing MpoImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing MspImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing PalmImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing PcdImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing PcxImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing PdfImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing PixarImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing PngImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing PpmImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing PsdImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing QoiImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing SgiImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing SpiderImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing SunImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing TgaImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing TiffImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing WebPImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing WmfImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing XbmImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing XpmImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing XVThumbImagePlugin
|
||||||
|
|
||||||
|
14%|█▍ | 1/7 [01:41<10:09, 101.63s/it]
|
||||||
|
29%|██▊ | 2/7 [03:20<08:18, 99.78s/it]
|
||||||
|
43%|████▎ | 3/7 [04:58<06:36, 99.24s/it]
|
||||||
|
57%|█████▋ | 4/7 [06:37<04:57, 99.05s/it]
|
||||||
|
71%|███████▏ | 5/7 [08:16<03:17, 98.90s/it]
|
||||||
|
86%|████████▌ | 6/7 [09:54<01:38, 98.80s/it]
|
||||||
|
100%|██████████| 7/7 [11:33<00:00, 98.70s/it]
|
||||||
|
100%|██████████| 7/7 [11:33<00:00, 99.03s/it]
|
||||||
|
>>>>>>>>>>>>>>>>>>>>>>>>
|
||||||
|
>>> Step 1: generating actions ...
|
||||||
|
>>> Step 1: interacting with world model ...
|
||||||
|
>>>>>>>>>>>>>>>>>>>>>>>>
|
||||||
|
>>> Step 2: generating actions ...
|
||||||
|
>>> Step 2: interacting with world model ...
|
||||||
|
>>>>>>>>>>>>>>>>>>>>>>>>
|
||||||
|
>>> Step 3: generating actions ...
|
||||||
|
>>> Step 3: interacting with world model ...
|
||||||
|
>>>>>>>>>>>>>>>>>>>>>>>>
|
||||||
|
>>> Step 4: generating actions ...
|
||||||
|
>>> Step 4: interacting with world model ...
|
||||||
|
>>>>>>>>>>>>>>>>>>>>>>>>
|
||||||
|
>>> Step 5: generating actions ...
|
||||||
@@ -1,5 +1,5 @@
|
|||||||
{
|
{
|
||||||
"gt_video": "unitree_z1_dual_arm_stackbox/case1/unitree_z1_dual_arm_stackbox_case1.mp4",
|
"gt_video": "unitree_z1_dual_arm_stackbox/case1/unitree_z1_dual_arm_stackbox_case1.mp4",
|
||||||
"pred_video": "unitree_z1_dual_arm_stackbox/case1/output/inference/5_full_fs4.mp4",
|
"pred_video": "unitree_z1_dual_arm_stackbox/case1/output/inference/unitree_z1_dual_arm_stackbox_case1_amd.mp4",
|
||||||
"psnr": 44.3480149502738
|
"psnr": 21.258130518117493
|
||||||
}
|
}
|
||||||
@@ -2,7 +2,7 @@ res_dir="unitree_z1_dual_arm_stackbox/case1"
|
|||||||
dataset="unitree_z1_dual_arm_stackbox"
|
dataset="unitree_z1_dual_arm_stackbox"
|
||||||
|
|
||||||
{
|
{
|
||||||
time CUDA_VISIBLE_DEVICES=0 python3 scripts/evaluation/world_model_interaction.py \
|
time CUDA_VISIBLE_DEVICES=7 python3 scripts/evaluation/world_model_interaction.py \
|
||||||
--seed 123 \
|
--seed 123 \
|
||||||
--ckpt_path ckpts/unifolm_wma_dual.ckpt \
|
--ckpt_path ckpts/unifolm_wma_dual.ckpt \
|
||||||
--config configs/inference/world_model_interaction.yaml \
|
--config configs/inference/world_model_interaction.yaml \
|
||||||
|
|||||||
134
unitree_z1_dual_arm_stackbox/case2/output.log
Normal file
134
unitree_z1_dual_arm_stackbox/case2/output.log
Normal file
@@ -0,0 +1,134 @@
|
|||||||
|
/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/lightning_fabric/__init__.py:29: UserWarning: pkg_resources is deprecated as an API. See https://setuptools.pypa.io/en/latest/pkg_resources.html. The pkg_resources package is slated for removal as early as 2025-11-30. Refrain from using this package or pin to Setuptools<81.
|
||||||
|
__import__("pkg_resources").declare_namespace(__name__)
|
||||||
|
2026-02-08 07:25:18.653033: I tensorflow/core/util/port.cc:113] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
|
||||||
|
2026-02-08 07:25:18.656060: I external/local_tsl/tsl/cuda/cudart_stub.cc:31] Could not find cuda drivers on your machine, GPU will not be used.
|
||||||
|
2026-02-08 07:25:18.687077: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
|
||||||
|
2026-02-08 07:25:18.687119: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
|
||||||
|
2026-02-08 07:25:18.688915: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
|
||||||
|
2026-02-08 07:25:18.697008: I external/local_tsl/tsl/cuda/cudart_stub.cc:31] Could not find cuda drivers on your machine, GPU will not be used.
|
||||||
|
2026-02-08 07:25:18.697255: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
|
||||||
|
To enable the following instructions: AVX2 AVX512F AVX512_VNNI AVX512_BF16 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
|
||||||
|
2026-02-08 07:25:19.338303: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT
|
||||||
|
Global seed set to 123
|
||||||
|
/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/kornia/feature/lightglue.py:44: FutureWarning: `torch.cuda.amp.custom_fwd(args...)` is deprecated. Please use `torch.amp.custom_fwd(args..., device_type='cuda')` instead.
|
||||||
|
@torch.cuda.amp.custom_fwd(cast_inputs=torch.float32)
|
||||||
|
INFO:mainlogger:LatentVisualDiffusion: Running in v-prediction mode
|
||||||
|
INFO:unifolm_wma.models.diffusion_head.conditional_unet1d:number of parameters: 5.010531e+08
|
||||||
|
INFO:unifolm_wma.models.diffusion_head.conditional_unet1d:number of parameters: 5.010531e+08
|
||||||
|
AE working on z of shape (1, 4, 32, 32) = 4096 dimensions.
|
||||||
|
INFO:root:Loaded ViT-H-14 model config.
|
||||||
|
DEBUG:urllib3.connectionpool:Starting new HTTPS connection (1): hf-mirror.com:443
|
||||||
|
DEBUG:urllib3.connectionpool:https://hf-mirror.com:443 "HEAD /laion/CLIP-ViT-H-14-laion2B-s32B-b79K/resolve/main/open_clip_pytorch_model.bin HTTP/1.1" 302 0
|
||||||
|
INFO:root:Loading pretrained ViT-H-14 weights (laion2b_s32b_b79k).
|
||||||
|
/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/open_clip/factory.py:88: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.
|
||||||
|
checkpoint = torch.load(checkpoint_path, map_location=map_location)
|
||||||
|
INFO:root:Loaded ViT-H-14 model config.
|
||||||
|
DEBUG:urllib3.connectionpool:https://hf-mirror.com:443 "HEAD /laion/CLIP-ViT-H-14-laion2B-s32B-b79K/resolve/main/open_clip_pytorch_model.bin HTTP/1.1" 302 0
|
||||||
|
INFO:root:Loading pretrained ViT-H-14 weights (laion2b_s32b_b79k).
|
||||||
|
/mnt/ASC1637/unifolm-world-model-action/scripts/evaluation/world_model_interaction.py:86: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.
|
||||||
|
state_dict = torch.load(ckpt, map_location="cpu")
|
||||||
|
>>> model checkpoint loaded.
|
||||||
|
>>> Load pre-trained model ...
|
||||||
|
INFO:root:***** Configing Data *****
|
||||||
|
>>> unitree_z1_stackbox: 1 data samples loaded.
|
||||||
|
>>> unitree_z1_stackbox: data stats loaded.
|
||||||
|
>>> unitree_z1_stackbox: normalizer initiated.
|
||||||
|
>>> unitree_z1_dual_arm_stackbox: 1 data samples loaded.
|
||||||
|
>>> unitree_z1_dual_arm_stackbox: data stats loaded.
|
||||||
|
>>> unitree_z1_dual_arm_stackbox: normalizer initiated.
|
||||||
|
>>> unitree_z1_dual_arm_stackbox_v2: 1 data samples loaded.
|
||||||
|
>>> unitree_z1_dual_arm_stackbox_v2: data stats loaded.
|
||||||
|
>>> unitree_z1_dual_arm_stackbox_v2: normalizer initiated.
|
||||||
|
>>> unitree_z1_dual_arm_cleanup_pencils: 1 data samples loaded.
|
||||||
|
>>> unitree_z1_dual_arm_cleanup_pencils: data stats loaded.
|
||||||
|
>>> unitree_z1_dual_arm_cleanup_pencils: normalizer initiated.
|
||||||
|
>>> unitree_g1_pack_camera: 1 data samples loaded.
|
||||||
|
>>> unitree_g1_pack_camera: data stats loaded.
|
||||||
|
>>> unitree_g1_pack_camera: normalizer initiated.
|
||||||
|
>>> Dataset is successfully loaded ...
|
||||||
|
>>> Generate 16 frames under each generation ...
|
||||||
|
DEBUG:h5py._conv:Creating converter from 3 to 5
|
||||||
|
DEBUG:PIL.PngImagePlugin:STREAM b'IHDR' 16 13
|
||||||
|
DEBUG:PIL.PngImagePlugin:STREAM b'pHYs' 41 9
|
||||||
|
DEBUG:PIL.PngImagePlugin:STREAM b'IDAT' 62 4096
|
||||||
|
|
||||||
|
0%| | 0/7 [00:00<?, ?it/s]/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/torch/nn/functional.py:5501: UserWarning: Attempting to use hipBLASLt on an unsupported architecture! Overriding blas backend to hipblas (Triggered internally at ../aten/src/ATen/Context.cpp:296.)
|
||||||
|
proj = linear(q, w, b)
|
||||||
|
/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/torch/nn/functional.py:6278: UserWarning: Flash attention support on Navi31 GPU is still experimental. Enable it with TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL=1. (Triggered internally at ../aten/src/ATen/native/transformers/hip/sdp_utils.cpp:225.)
|
||||||
|
attn_output = scaled_dot_product_attention(
|
||||||
|
/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/torch/nn/functional.py:6278: UserWarning: Memory Efficient attention on Navi31 GPU is still experimental. Enable it with TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL=1. (Triggered internally at ../aten/src/ATen/native/transformers/hip/sdp_utils.cpp:269.)
|
||||||
|
attn_output = scaled_dot_product_attention(
|
||||||
|
>>> Step 0: generating actions ...
|
||||||
|
>>> Step 0: interacting with world model ...
|
||||||
|
DEBUG:PIL.Image:Importing BlpImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing BmpImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing BufrStubImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing CurImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing DcxImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing DdsImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing EpsImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing FitsImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing FitsStubImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing FliImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing FpxImagePlugin
|
||||||
|
DEBUG:PIL.Image:Image: failed to import FpxImagePlugin: No module named 'olefile'
|
||||||
|
DEBUG:PIL.Image:Importing FtexImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing GbrImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing GifImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing GribStubImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing Hdf5StubImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing IcnsImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing IcoImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing ImImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing ImtImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing IptcImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing JpegImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing Jpeg2KImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing McIdasImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing MicImagePlugin
|
||||||
|
DEBUG:PIL.Image:Image: failed to import MicImagePlugin: No module named 'olefile'
|
||||||
|
DEBUG:PIL.Image:Importing MpegImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing MpoImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing MspImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing PalmImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing PcdImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing PcxImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing PdfImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing PixarImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing PngImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing PpmImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing PsdImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing QoiImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing SgiImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing SpiderImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing SunImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing TgaImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing TiffImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing WebPImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing WmfImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing XbmImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing XpmImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing XVThumbImagePlugin
|
||||||
|
|
||||||
|
14%|█▍ | 1/7 [01:39<09:56, 99.35s/it]
|
||||||
|
29%|██▊ | 2/7 [03:18<08:17, 99.50s/it]
|
||||||
|
43%|████▎ | 3/7 [04:58<06:38, 99.54s/it]
|
||||||
|
57%|█████▋ | 4/7 [06:38<04:58, 99.52s/it]
|
||||||
|
71%|███████▏ | 5/7 [08:17<03:19, 99.55s/it]
|
||||||
|
86%|████████▌ | 6/7 [09:57<01:39, 99.53s/it]
|
||||||
|
100%|██████████| 7/7 [11:36<00:00, 99.50s/it]
|
||||||
|
100%|██████████| 7/7 [11:36<00:00, 99.51s/it]
|
||||||
|
>>>>>>>>>>>>>>>>>>>>>>>>
|
||||||
|
>>> Step 1: generating actions ...
|
||||||
|
>>> Step 1: interacting with world model ...
|
||||||
|
>>>>>>>>>>>>>>>>>>>>>>>>
|
||||||
|
>>> Step 2: generating actions ...
|
||||||
|
>>> Step 2: interacting with world model ...
|
||||||
|
>>>>>>>>>>>>>>>>>>>>>>>>
|
||||||
|
>>> Step 3: generating actions ...
|
||||||
|
>>> Step 3: interacting with world model ...
|
||||||
|
>>>>>>>>>>>>>>>>>>>>>>>>
|
||||||
|
>>> Step 4: generating actions ...
|
||||||
|
>>> Step 4: interacting with world model ...
|
||||||
|
>>>>>>>>>>>>>>>>>>>>>>>>
|
||||||
|
>>> Step 5: generating actions ...
|
||||||
@@ -1,5 +1,5 @@
|
|||||||
{
|
{
|
||||||
"gt_video": "unitree_z1_dual_arm_stackbox/case2/unitree_z1_dual_arm_stackbox_case2.mp4",
|
"gt_video": "unitree_z1_dual_arm_stackbox/case2/unitree_z1_dual_arm_stackbox_case2.mp4",
|
||||||
"pred_video": "unitree_z1_dual_arm_stackbox/case2/output/inference/15_full_fs4.mp4",
|
"pred_video": "unitree_z1_dual_arm_stackbox/case2/output/inference/unitree_z1_dual_arm_stackbox_case2_amd.mp4",
|
||||||
"psnr": 39.867728254007716
|
"psnr": 23.878153424077645
|
||||||
}
|
}
|
||||||
@@ -2,7 +2,7 @@ res_dir="unitree_z1_dual_arm_stackbox/case2"
|
|||||||
dataset="unitree_z1_dual_arm_stackbox"
|
dataset="unitree_z1_dual_arm_stackbox"
|
||||||
|
|
||||||
{
|
{
|
||||||
time CUDA_VISIBLE_DEVICES=0 python3 scripts/evaluation/world_model_interaction.py \
|
time CUDA_VISIBLE_DEVICES=6 python3 scripts/evaluation/world_model_interaction.py \
|
||||||
--seed 123 \
|
--seed 123 \
|
||||||
--ckpt_path ckpts/unifolm_wma_dual.ckpt \
|
--ckpt_path ckpts/unifolm_wma_dual.ckpt \
|
||||||
--config configs/inference/world_model_interaction.yaml \
|
--config configs/inference/world_model_interaction.yaml \
|
||||||
|
|||||||
134
unitree_z1_dual_arm_stackbox/case3/output.log
Normal file
134
unitree_z1_dual_arm_stackbox/case3/output.log
Normal file
@@ -0,0 +1,134 @@
|
|||||||
|
/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/lightning_fabric/__init__.py:29: UserWarning: pkg_resources is deprecated as an API. See https://setuptools.pypa.io/en/latest/pkg_resources.html. The pkg_resources package is slated for removal as early as 2025-11-30. Refrain from using this package or pin to Setuptools<81.
|
||||||
|
__import__("pkg_resources").declare_namespace(__name__)
|
||||||
|
2026-02-08 07:35:33.682231: I tensorflow/core/util/port.cc:113] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
|
||||||
|
2026-02-08 07:35:33.685275: I external/local_tsl/tsl/cuda/cudart_stub.cc:31] Could not find cuda drivers on your machine, GPU will not be used.
|
||||||
|
2026-02-08 07:35:33.716682: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
|
||||||
|
2026-02-08 07:35:33.716728: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
|
||||||
|
2026-02-08 07:35:33.718523: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
|
||||||
|
2026-02-08 07:35:33.726756: I external/local_tsl/tsl/cuda/cudart_stub.cc:31] Could not find cuda drivers on your machine, GPU will not be used.
|
||||||
|
2026-02-08 07:35:33.727105: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
|
||||||
|
To enable the following instructions: AVX2 AVX512F AVX512_VNNI AVX512_BF16 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
|
||||||
|
2026-02-08 07:35:34.356722: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT
|
||||||
|
[rank: 0] Global seed set to 123
|
||||||
|
/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/kornia/feature/lightglue.py:44: FutureWarning: `torch.cuda.amp.custom_fwd(args...)` is deprecated. Please use `torch.amp.custom_fwd(args..., device_type='cuda')` instead.
|
||||||
|
@torch.cuda.amp.custom_fwd(cast_inputs=torch.float32)
|
||||||
|
INFO:mainlogger:LatentVisualDiffusion: Running in v-prediction mode
|
||||||
|
INFO:unifolm_wma.models.diffusion_head.conditional_unet1d:number of parameters: 5.010531e+08
|
||||||
|
INFO:unifolm_wma.models.diffusion_head.conditional_unet1d:number of parameters: 5.010531e+08
|
||||||
|
AE working on z of shape (1, 4, 32, 32) = 4096 dimensions.
|
||||||
|
INFO:root:Loaded ViT-H-14 model config.
|
||||||
|
DEBUG:urllib3.connectionpool:Starting new HTTPS connection (1): hf-mirror.com:443
|
||||||
|
DEBUG:urllib3.connectionpool:https://hf-mirror.com:443 "HEAD /laion/CLIP-ViT-H-14-laion2B-s32B-b79K/resolve/main/open_clip_pytorch_model.bin HTTP/1.1" 302 0
|
||||||
|
INFO:root:Loading pretrained ViT-H-14 weights (laion2b_s32b_b79k).
|
||||||
|
/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/open_clip/factory.py:88: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.
|
||||||
|
checkpoint = torch.load(checkpoint_path, map_location=map_location)
|
||||||
|
INFO:root:Loaded ViT-H-14 model config.
|
||||||
|
DEBUG:urllib3.connectionpool:https://hf-mirror.com:443 "HEAD /laion/CLIP-ViT-H-14-laion2B-s32B-b79K/resolve/main/open_clip_pytorch_model.bin HTTP/1.1" 302 0
|
||||||
|
INFO:root:Loading pretrained ViT-H-14 weights (laion2b_s32b_b79k).
|
||||||
|
/mnt/ASC1637/unifolm-world-model-action/scripts/evaluation/world_model_interaction.py:86: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.
|
||||||
|
state_dict = torch.load(ckpt, map_location="cpu")
|
||||||
|
>>> model checkpoint loaded.
|
||||||
|
>>> Load pre-trained model ...
|
||||||
|
INFO:root:***** Configing Data *****
|
||||||
|
>>> unitree_z1_stackbox: 1 data samples loaded.
|
||||||
|
>>> unitree_z1_stackbox: data stats loaded.
|
||||||
|
>>> unitree_z1_stackbox: normalizer initiated.
|
||||||
|
>>> unitree_z1_dual_arm_stackbox: 1 data samples loaded.
|
||||||
|
>>> unitree_z1_dual_arm_stackbox: data stats loaded.
|
||||||
|
>>> unitree_z1_dual_arm_stackbox: normalizer initiated.
|
||||||
|
>>> unitree_z1_dual_arm_stackbox_v2: 1 data samples loaded.
|
||||||
|
>>> unitree_z1_dual_arm_stackbox_v2: data stats loaded.
|
||||||
|
>>> unitree_z1_dual_arm_stackbox_v2: normalizer initiated.
|
||||||
|
>>> unitree_z1_dual_arm_cleanup_pencils: 1 data samples loaded.
|
||||||
|
>>> unitree_z1_dual_arm_cleanup_pencils: data stats loaded.
|
||||||
|
>>> unitree_z1_dual_arm_cleanup_pencils: normalizer initiated.
|
||||||
|
>>> unitree_g1_pack_camera: 1 data samples loaded.
|
||||||
|
>>> unitree_g1_pack_camera: data stats loaded.
|
||||||
|
>>> unitree_g1_pack_camera: normalizer initiated.
|
||||||
|
>>> Dataset is successfully loaded ...
|
||||||
|
>>> Generate 16 frames under each generation ...
|
||||||
|
DEBUG:h5py._conv:Creating converter from 3 to 5
|
||||||
|
DEBUG:PIL.PngImagePlugin:STREAM b'IHDR' 16 13
|
||||||
|
DEBUG:PIL.PngImagePlugin:STREAM b'pHYs' 41 9
|
||||||
|
DEBUG:PIL.PngImagePlugin:STREAM b'IDAT' 62 4096
|
||||||
|
|
||||||
|
0%| | 0/7 [00:00<?, ?it/s]/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/torch/nn/functional.py:5501: UserWarning: Attempting to use hipBLASLt on an unsupported architecture! Overriding blas backend to hipblas (Triggered internally at ../aten/src/ATen/Context.cpp:296.)
|
||||||
|
proj = linear(q, w, b)
|
||||||
|
/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/torch/nn/functional.py:6278: UserWarning: Flash attention support on Navi31 GPU is still experimental. Enable it with TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL=1. (Triggered internally at ../aten/src/ATen/native/transformers/hip/sdp_utils.cpp:225.)
|
||||||
|
attn_output = scaled_dot_product_attention(
|
||||||
|
/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/torch/nn/functional.py:6278: UserWarning: Memory Efficient attention on Navi31 GPU is still experimental. Enable it with TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL=1. (Triggered internally at ../aten/src/ATen/native/transformers/hip/sdp_utils.cpp:269.)
|
||||||
|
attn_output = scaled_dot_product_attention(
|
||||||
|
>>> Step 0: generating actions ...
|
||||||
|
>>> Step 0: interacting with world model ...
|
||||||
|
DEBUG:PIL.Image:Importing BlpImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing BmpImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing BufrStubImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing CurImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing DcxImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing DdsImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing EpsImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing FitsImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing FitsStubImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing FliImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing FpxImagePlugin
|
||||||
|
DEBUG:PIL.Image:Image: failed to import FpxImagePlugin: No module named 'olefile'
|
||||||
|
DEBUG:PIL.Image:Importing FtexImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing GbrImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing GifImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing GribStubImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing Hdf5StubImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing IcnsImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing IcoImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing ImImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing ImtImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing IptcImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing JpegImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing Jpeg2KImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing McIdasImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing MicImagePlugin
|
||||||
|
DEBUG:PIL.Image:Image: failed to import MicImagePlugin: No module named 'olefile'
|
||||||
|
DEBUG:PIL.Image:Importing MpegImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing MpoImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing MspImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing PalmImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing PcdImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing PcxImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing PdfImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing PixarImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing PngImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing PpmImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing PsdImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing QoiImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing SgiImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing SpiderImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing SunImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing TgaImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing TiffImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing WebPImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing WmfImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing XbmImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing XpmImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing XVThumbImagePlugin
|
||||||
|
|
||||||
|
14%|█▍ | 1/7 [01:41<10:06, 101.02s/it]
|
||||||
|
29%|██▊ | 2/7 [03:23<08:29, 101.84s/it]
|
||||||
|
43%|████▎ | 3/7 [05:04<06:45, 101.43s/it]
|
||||||
|
57%|█████▋ | 4/7 [06:45<05:04, 101.42s/it]
|
||||||
|
71%|███████▏ | 5/7 [08:27<03:22, 101.40s/it]
|
||||||
|
86%|████████▌ | 6/7 [10:08<01:41, 101.39s/it]
|
||||||
|
100%|██████████| 7/7 [11:49<00:00, 101.33s/it]
|
||||||
|
100%|██████████| 7/7 [11:49<00:00, 101.39s/it]
|
||||||
|
>>>>>>>>>>>>>>>>>>>>>>>>
|
||||||
|
>>> Step 1: generating actions ...
|
||||||
|
>>> Step 1: interacting with world model ...
|
||||||
|
>>>>>>>>>>>>>>>>>>>>>>>>
|
||||||
|
>>> Step 2: generating actions ...
|
||||||
|
>>> Step 2: interacting with world model ...
|
||||||
|
>>>>>>>>>>>>>>>>>>>>>>>>
|
||||||
|
>>> Step 3: generating actions ...
|
||||||
|
>>> Step 3: interacting with world model ...
|
||||||
|
>>>>>>>>>>>>>>>>>>>>>>>>
|
||||||
|
>>> Step 4: generating actions ...
|
||||||
|
>>> Step 4: interacting with world model ...
|
||||||
|
>>>>>>>>>>>>>>>>>>>>>>>>
|
||||||
|
>>> Step 5: generating actions ...
|
||||||
@@ -1,5 +1,5 @@
|
|||||||
{
|
{
|
||||||
"gt_video": "unitree_z1_dual_arm_stackbox/case3/unitree_z1_dual_arm_stackbox_case3.mp4",
|
"gt_video": "unitree_z1_dual_arm_stackbox/case3/unitree_z1_dual_arm_stackbox_case3.mp4",
|
||||||
"pred_video": "unitree_z1_dual_arm_stackbox/case3/output/inference/25_full_fs4.mp4",
|
"pred_video": "unitree_z1_dual_arm_stackbox/case3/output/inference/unitree_z1_dual_arm_stackbox_case3_amd.mp4",
|
||||||
"psnr": 39.19101039445159
|
"psnr": 25.400458754751128
|
||||||
}
|
}
|
||||||
134
unitree_z1_dual_arm_stackbox/case4/output.log
Normal file
134
unitree_z1_dual_arm_stackbox/case4/output.log
Normal file
@@ -0,0 +1,134 @@
|
|||||||
|
/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/lightning_fabric/__init__.py:29: UserWarning: pkg_resources is deprecated as an API. See https://setuptools.pypa.io/en/latest/pkg_resources.html. The pkg_resources package is slated for removal as early as 2025-11-30. Refrain from using this package or pin to Setuptools<81.
|
||||||
|
__import__("pkg_resources").declare_namespace(__name__)
|
||||||
|
2026-02-08 07:38:45.572744: I tensorflow/core/util/port.cc:113] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
|
||||||
|
2026-02-08 07:38:45.576864: I external/local_tsl/tsl/cuda/cudart_stub.cc:31] Could not find cuda drivers on your machine, GPU will not be used.
|
||||||
|
2026-02-08 07:38:45.624825: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
|
||||||
|
2026-02-08 07:38:45.624883: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
|
||||||
|
2026-02-08 07:38:45.627150: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
|
||||||
|
2026-02-08 07:38:45.638316: I external/local_tsl/tsl/cuda/cudart_stub.cc:31] Could not find cuda drivers on your machine, GPU will not be used.
|
||||||
|
2026-02-08 07:38:45.638803: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
|
||||||
|
To enable the following instructions: AVX2 AVX512F AVX512_VNNI AVX512_BF16 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
|
||||||
|
2026-02-08 07:38:46.426363: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT
|
||||||
|
[rank: 0] Global seed set to 123
|
||||||
|
/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/kornia/feature/lightglue.py:44: FutureWarning: `torch.cuda.amp.custom_fwd(args...)` is deprecated. Please use `torch.amp.custom_fwd(args..., device_type='cuda')` instead.
|
||||||
|
@torch.cuda.amp.custom_fwd(cast_inputs=torch.float32)
|
||||||
|
INFO:mainlogger:LatentVisualDiffusion: Running in v-prediction mode
|
||||||
|
INFO:unifolm_wma.models.diffusion_head.conditional_unet1d:number of parameters: 5.010531e+08
|
||||||
|
INFO:unifolm_wma.models.diffusion_head.conditional_unet1d:number of parameters: 5.010531e+08
|
||||||
|
AE working on z of shape (1, 4, 32, 32) = 4096 dimensions.
|
||||||
|
INFO:root:Loaded ViT-H-14 model config.
|
||||||
|
DEBUG:urllib3.connectionpool:Starting new HTTPS connection (1): hf-mirror.com:443
|
||||||
|
DEBUG:urllib3.connectionpool:https://hf-mirror.com:443 "HEAD /laion/CLIP-ViT-H-14-laion2B-s32B-b79K/resolve/main/open_clip_pytorch_model.bin HTTP/1.1" 302 0
|
||||||
|
INFO:root:Loading pretrained ViT-H-14 weights (laion2b_s32b_b79k).
|
||||||
|
/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/open_clip/factory.py:88: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.
|
||||||
|
checkpoint = torch.load(checkpoint_path, map_location=map_location)
|
||||||
|
INFO:root:Loaded ViT-H-14 model config.
|
||||||
|
DEBUG:urllib3.connectionpool:https://hf-mirror.com:443 "HEAD /laion/CLIP-ViT-H-14-laion2B-s32B-b79K/resolve/main/open_clip_pytorch_model.bin HTTP/1.1" 302 0
|
||||||
|
INFO:root:Loading pretrained ViT-H-14 weights (laion2b_s32b_b79k).
|
||||||
|
/mnt/ASC1637/unifolm-world-model-action/scripts/evaluation/world_model_interaction.py:86: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.
|
||||||
|
state_dict = torch.load(ckpt, map_location="cpu")
|
||||||
|
>>> model checkpoint loaded.
|
||||||
|
>>> Load pre-trained model ...
|
||||||
|
INFO:root:***** Configing Data *****
|
||||||
|
>>> unitree_z1_stackbox: 1 data samples loaded.
|
||||||
|
>>> unitree_z1_stackbox: data stats loaded.
|
||||||
|
>>> unitree_z1_stackbox: normalizer initiated.
|
||||||
|
>>> unitree_z1_dual_arm_stackbox: 1 data samples loaded.
|
||||||
|
>>> unitree_z1_dual_arm_stackbox: data stats loaded.
|
||||||
|
>>> unitree_z1_dual_arm_stackbox: normalizer initiated.
|
||||||
|
>>> unitree_z1_dual_arm_stackbox_v2: 1 data samples loaded.
|
||||||
|
>>> unitree_z1_dual_arm_stackbox_v2: data stats loaded.
|
||||||
|
>>> unitree_z1_dual_arm_stackbox_v2: normalizer initiated.
|
||||||
|
>>> unitree_z1_dual_arm_cleanup_pencils: 1 data samples loaded.
|
||||||
|
>>> unitree_z1_dual_arm_cleanup_pencils: data stats loaded.
|
||||||
|
>>> unitree_z1_dual_arm_cleanup_pencils: normalizer initiated.
|
||||||
|
>>> unitree_g1_pack_camera: 1 data samples loaded.
|
||||||
|
>>> unitree_g1_pack_camera: data stats loaded.
|
||||||
|
>>> unitree_g1_pack_camera: normalizer initiated.
|
||||||
|
>>> Dataset is successfully loaded ...
|
||||||
|
>>> Generate 16 frames under each generation ...
|
||||||
|
DEBUG:h5py._conv:Creating converter from 3 to 5
|
||||||
|
DEBUG:PIL.PngImagePlugin:STREAM b'IHDR' 16 13
|
||||||
|
DEBUG:PIL.PngImagePlugin:STREAM b'pHYs' 41 9
|
||||||
|
DEBUG:PIL.PngImagePlugin:STREAM b'IDAT' 62 4096
|
||||||
|
|
||||||
|
0%| | 0/7 [00:00<?, ?it/s]/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/torch/nn/functional.py:5501: UserWarning: Attempting to use hipBLASLt on an unsupported architecture! Overriding blas backend to hipblas (Triggered internally at ../aten/src/ATen/Context.cpp:296.)
|
||||||
|
proj = linear(q, w, b)
|
||||||
|
/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/torch/nn/functional.py:6278: UserWarning: Flash attention support on Navi31 GPU is still experimental. Enable it with TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL=1. (Triggered internally at ../aten/src/ATen/native/transformers/hip/sdp_utils.cpp:225.)
|
||||||
|
attn_output = scaled_dot_product_attention(
|
||||||
|
/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/torch/nn/functional.py:6278: UserWarning: Memory Efficient attention on Navi31 GPU is still experimental. Enable it with TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL=1. (Triggered internally at ../aten/src/ATen/native/transformers/hip/sdp_utils.cpp:269.)
|
||||||
|
attn_output = scaled_dot_product_attention(
|
||||||
|
>>> Step 0: generating actions ...
|
||||||
|
>>> Step 0: interacting with world model ...
|
||||||
|
DEBUG:PIL.Image:Importing BlpImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing BmpImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing BufrStubImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing CurImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing DcxImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing DdsImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing EpsImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing FitsImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing FitsStubImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing FliImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing FpxImagePlugin
|
||||||
|
DEBUG:PIL.Image:Image: failed to import FpxImagePlugin: No module named 'olefile'
|
||||||
|
DEBUG:PIL.Image:Importing FtexImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing GbrImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing GifImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing GribStubImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing Hdf5StubImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing IcnsImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing IcoImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing ImImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing ImtImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing IptcImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing JpegImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing Jpeg2KImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing McIdasImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing MicImagePlugin
|
||||||
|
DEBUG:PIL.Image:Image: failed to import MicImagePlugin: No module named 'olefile'
|
||||||
|
DEBUG:PIL.Image:Importing MpegImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing MpoImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing MspImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing PalmImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing PcdImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing PcxImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing PdfImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing PixarImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing PngImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing PpmImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing PsdImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing QoiImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing SgiImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing SpiderImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing SunImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing TgaImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing TiffImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing WebPImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing WmfImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing XbmImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing XpmImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing XVThumbImagePlugin
|
||||||
|
|
||||||
|
14%|█▍ | 1/7 [01:38<09:52, 98.73s/it]
|
||||||
|
29%|██▊ | 2/7 [03:17<08:14, 98.85s/it]
|
||||||
|
43%|████▎ | 3/7 [04:56<06:35, 98.80s/it]
|
||||||
|
57%|█████▋ | 4/7 [06:35<04:56, 98.94s/it]
|
||||||
|
71%|███████▏ | 5/7 [08:14<03:17, 98.93s/it]
|
||||||
|
86%|████████▌ | 6/7 [09:53<01:38, 98.89s/it]
|
||||||
|
100%|██████████| 7/7 [11:31<00:00, 98.81s/it]
|
||||||
|
100%|██████████| 7/7 [11:31<00:00, 98.85s/it]
|
||||||
|
>>>>>>>>>>>>>>>>>>>>>>>>
|
||||||
|
>>> Step 1: generating actions ...
|
||||||
|
>>> Step 1: interacting with world model ...
|
||||||
|
>>>>>>>>>>>>>>>>>>>>>>>>
|
||||||
|
>>> Step 2: generating actions ...
|
||||||
|
>>> Step 2: interacting with world model ...
|
||||||
|
>>>>>>>>>>>>>>>>>>>>>>>>
|
||||||
|
>>> Step 3: generating actions ...
|
||||||
|
>>> Step 3: interacting with world model ...
|
||||||
|
>>>>>>>>>>>>>>>>>>>>>>>>
|
||||||
|
>>> Step 4: generating actions ...
|
||||||
|
>>> Step 4: interacting with world model ...
|
||||||
|
>>>>>>>>>>>>>>>>>>>>>>>>
|
||||||
|
>>> Step 5: generating actions ...
|
||||||
@@ -1,5 +1,5 @@
|
|||||||
{
|
{
|
||||||
"gt_video": "unitree_z1_dual_arm_stackbox/case4/unitree_z1_dual_arm_stackbox_case4.mp4",
|
"gt_video": "unitree_z1_dual_arm_stackbox/case4/unitree_z1_dual_arm_stackbox_case4.mp4",
|
||||||
"pred_video": "unitree_z1_dual_arm_stackbox/case4/output/inference/35_full_fs4.mp4",
|
"pred_video": "unitree_z1_dual_arm_stackbox/case4/output/inference/unitree_z1_dual_arm_stackbox_case4_amd.mp4",
|
||||||
"psnr": 40.29563315341769
|
"psnr": 24.098958457373858
|
||||||
}
|
}
|
||||||
@@ -1,13 +1,34 @@
|
|||||||
2026-02-18 18:49:49.117856: I tensorflow/core/util/port.cc:113] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
|
/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/lightning_fabric/__init__.py:29: UserWarning: pkg_resources is deprecated as an API. See https://setuptools.pypa.io/en/latest/pkg_resources.html. The pkg_resources package is slated for removal as early as 2025-11-30. Refrain from using this package or pin to Setuptools<81.
|
||||||
2026-02-18 18:49:49.165270: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
|
__import__("pkg_resources").declare_namespace(__name__)
|
||||||
2026-02-18 18:49:49.165322: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
|
2026-02-08 07:51:23.961486: I tensorflow/core/util/port.cc:113] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
|
||||||
2026-02-18 18:49:49.166382: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
|
2026-02-08 07:51:24.200063: I external/local_tsl/tsl/cuda/cudart_stub.cc:31] Could not find cuda drivers on your machine, GPU will not be used.
|
||||||
2026-02-18 18:49:49.173299: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
|
2026-02-08 07:51:24.522299: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
|
||||||
To enable the following instructions: AVX2 AVX512F AVX512_VNNI FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
|
2026-02-08 07:51:24.522350: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
|
||||||
2026-02-18 18:49:50.090214: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT
|
2026-02-08 07:51:24.528237: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
|
||||||
|
2026-02-08 07:51:24.579400: I external/local_tsl/tsl/cuda/cudart_stub.cc:31] Could not find cuda drivers on your machine, GPU will not be used.
|
||||||
|
2026-02-08 07:51:24.579644: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
|
||||||
|
To enable the following instructions: AVX2 AVX512F AVX512_VNNI AVX512_BF16 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
|
||||||
|
2026-02-08 07:51:25.781311: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT
|
||||||
Global seed set to 123
|
Global seed set to 123
|
||||||
>>> Loading prepared model from ckpts/unifolm_wma_dual.ckpt.prepared.pt ...
|
/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/kornia/feature/lightglue.py:44: FutureWarning: `torch.cuda.amp.custom_fwd(args...)` is deprecated. Please use `torch.amp.custom_fwd(args..., device_type='cuda')` instead.
|
||||||
>>> Prepared model loaded.
|
@torch.cuda.amp.custom_fwd(cast_inputs=torch.float32)
|
||||||
|
INFO:mainlogger:LatentVisualDiffusion: Running in v-prediction mode
|
||||||
|
INFO:unifolm_wma.models.diffusion_head.conditional_unet1d:number of parameters: 5.010531e+08
|
||||||
|
INFO:unifolm_wma.models.diffusion_head.conditional_unet1d:number of parameters: 5.010531e+08
|
||||||
|
AE working on z of shape (1, 4, 32, 32) = 4096 dimensions.
|
||||||
|
INFO:root:Loaded ViT-H-14 model config.
|
||||||
|
DEBUG:urllib3.connectionpool:Starting new HTTPS connection (1): hf-mirror.com:443
|
||||||
|
DEBUG:urllib3.connectionpool:https://hf-mirror.com:443 "HEAD /laion/CLIP-ViT-H-14-laion2B-s32B-b79K/resolve/main/open_clip_pytorch_model.bin HTTP/1.1" 302 0
|
||||||
|
INFO:root:Loading pretrained ViT-H-14 weights (laion2b_s32b_b79k).
|
||||||
|
/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/open_clip/factory.py:88: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.
|
||||||
|
checkpoint = torch.load(checkpoint_path, map_location=map_location)
|
||||||
|
INFO:root:Loaded ViT-H-14 model config.
|
||||||
|
DEBUG:urllib3.connectionpool:https://hf-mirror.com:443 "HEAD /laion/CLIP-ViT-H-14-laion2B-s32B-b79K/resolve/main/open_clip_pytorch_model.bin HTTP/1.1" 302 0
|
||||||
|
INFO:root:Loading pretrained ViT-H-14 weights (laion2b_s32b_b79k).
|
||||||
|
/mnt/ASC1637/unifolm-world-model-action/scripts/evaluation/world_model_interaction.py:86: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.
|
||||||
|
state_dict = torch.load(ckpt, map_location="cpu")
|
||||||
|
>>> model checkpoint loaded.
|
||||||
|
>>> Load pre-trained model ...
|
||||||
INFO:root:***** Configing Data *****
|
INFO:root:***** Configing Data *****
|
||||||
>>> unitree_z1_stackbox: 1 data samples loaded.
|
>>> unitree_z1_stackbox: 1 data samples loaded.
|
||||||
>>> unitree_z1_stackbox: data stats loaded.
|
>>> unitree_z1_stackbox: data stats loaded.
|
||||||
@@ -25,28 +46,19 @@ INFO:root:***** Configing Data *****
|
|||||||
>>> unitree_g1_pack_camera: data stats loaded.
|
>>> unitree_g1_pack_camera: data stats loaded.
|
||||||
>>> unitree_g1_pack_camera: normalizer initiated.
|
>>> unitree_g1_pack_camera: normalizer initiated.
|
||||||
>>> Dataset is successfully loaded ...
|
>>> Dataset is successfully loaded ...
|
||||||
✓ KV fused: 66 attention layers
|
|
||||||
TRT output 'y': [1, 4, 16, 40, 64] torch.float32
|
|
||||||
TRT output 'hs_a_0': [1, 16, 320, 40, 64] torch.float32
|
|
||||||
TRT output 'hs_a_1': [1, 16, 640, 20, 32] torch.float32
|
|
||||||
TRT output 'hs_a_2': [1, 16, 1280, 10, 16] torch.float32
|
|
||||||
TRT output 'hs_a_3': [1, 16, 1280, 5, 8] torch.float32
|
|
||||||
TRT output 'hs_a_4': [1, 16, 1280, 5, 8] torch.float32
|
|
||||||
TRT output 'hs_a_5': [1, 16, 1280, 5, 8] torch.float32
|
|
||||||
TRT output 'hs_a_6': [1, 16, 1280, 10, 16] torch.float32
|
|
||||||
TRT output 'hs_a_7': [1, 16, 640, 20, 32] torch.float32
|
|
||||||
TRT output 'hs_a_8': [1, 16, 320, 40, 64] torch.float32
|
|
||||||
>>> TRT backbone loaded from /home/qhy/unifolm-world-model-action/scripts/evaluation/../../trt_engines/video_backbone.engine
|
|
||||||
>>> Generate 16 frames under each generation ...
|
>>> Generate 16 frames under each generation ...
|
||||||
DEBUG:h5py._conv:Creating converter from 3 to 5
|
DEBUG:h5py._conv:Creating converter from 3 to 5
|
||||||
DEBUG:PIL.PngImagePlugin:STREAM b'IHDR' 16 13
|
DEBUG:PIL.PngImagePlugin:STREAM b'IHDR' 16 13
|
||||||
DEBUG:PIL.PngImagePlugin:STREAM b'pHYs' 41 9
|
DEBUG:PIL.PngImagePlugin:STREAM b'pHYs' 41 9
|
||||||
DEBUG:PIL.PngImagePlugin:STREAM b'IDAT' 62 4096
|
DEBUG:PIL.PngImagePlugin:STREAM b'IDAT' 62 4096
|
||||||
|
|
||||||
0%| | 0/11 [00:00<?, ?it/s][02/18/2026-18:50:03] [TRT] [W] Using default stream in enqueueV3() may lead to performance issues due to additional calls to cudaStreamSynchronize() by TensorRT to ensure correct synchronization. Please use non-default stream instead.
|
0%| | 0/11 [00:00<?, ?it/s]/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/torch/nn/functional.py:5501: UserWarning: Attempting to use hipBLASLt on an unsupported architecture! Overriding blas backend to hipblas (Triggered internally at ../aten/src/ATen/Context.cpp:296.)
|
||||||
|
proj = linear(q, w, b)
|
||||||
|
/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/torch/nn/functional.py:6278: UserWarning: Flash attention support on Navi31 GPU is still experimental. Enable it with TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL=1. (Triggered internally at ../aten/src/ATen/native/transformers/hip/sdp_utils.cpp:225.)
|
||||||
|
attn_output = scaled_dot_product_attention(
|
||||||
|
/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/torch/nn/functional.py:6278: UserWarning: Memory Efficient attention on Navi31 GPU is still experimental. Enable it with TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL=1. (Triggered internally at ../aten/src/ATen/native/transformers/hip/sdp_utils.cpp:269.)
|
||||||
|
attn_output = scaled_dot_product_attention(
|
||||||
>>> Step 0: generating actions ...
|
>>> Step 0: generating actions ...
|
||||||
9%|▉ | 1/11 [00:15<02:38, 15.88s/it]>>> Step 0: generating actions ...
|
|
||||||
>>> Step 0: interacting with world model ...
|
|
||||||
>>> Step 0: interacting with world model ...
|
>>> Step 0: interacting with world model ...
|
||||||
DEBUG:PIL.Image:Importing BlpImagePlugin
|
DEBUG:PIL.Image:Importing BlpImagePlugin
|
||||||
DEBUG:PIL.Image:Importing BmpImagePlugin
|
DEBUG:PIL.Image:Importing BmpImagePlugin
|
||||||
@@ -96,7 +108,9 @@ DEBUG:PIL.Image:Importing WmfImagePlugin
|
|||||||
DEBUG:PIL.Image:Importing WmfImagePlugin
|
DEBUG:PIL.Image:Importing WmfImagePlugin
|
||||||
DEBUG:PIL.Image:Importing XbmImagePlugin
|
DEBUG:PIL.Image:Importing XbmImagePlugin
|
||||||
DEBUG:PIL.Image:Importing XpmImagePlugin
|
DEBUG:PIL.Image:Importing XpmImagePlugin
|
||||||
DEBUG:PIL.Image:Importing XpmImagePlugin
|
DEBUG:PIL.Image:Importing XVThumbImagePlugin
|
||||||
|
|
||||||
|
9%|▉ | 1/11 [01:38<16:20, 98.04s/it]
|
||||||
18%|█▊ | 2/11 [03:15<14:40, 97.81s/it]
|
18%|█▊ | 2/11 [03:15<14:40, 97.81s/it]
|
||||||
27%|██▋ | 3/11 [04:53<13:01, 97.72s/it]
|
27%|██▋ | 3/11 [04:53<13:01, 97.72s/it]
|
||||||
36%|███▋ | 4/11 [06:31<11:24, 97.71s/it]
|
36%|███▋ | 4/11 [06:31<11:24, 97.71s/it]
|
||||||
@@ -127,6 +141,6 @@ DEBUG:PIL.Image:Importing XVThumbImagePlugin
|
|||||||
>>> Step 6: generating actions ...
|
>>> Step 6: generating actions ...
|
||||||
>>> Step 6: interacting with world model ...
|
>>> Step 6: interacting with world model ...
|
||||||
>>>>>>>>>>>>>>>>>>>>>>>>
|
>>>>>>>>>>>>>>>>>>>>>>>>
|
||||||
>>> Step 7: generating actions ...
|
>>> Step 7: generating actions ...
|
||||||
>>> Step 7: interacting with world model ...
|
>>> Step 7: interacting with world model ...
|
||||||
>>>>>>>>>>>>>>>>>>>>>>>>
|
>>>>>>>>>>>>>>>>>>>>>>>>
|
||||||
|
|||||||
@@ -1,5 +1,5 @@
|
|||||||
{
|
{
|
||||||
"gt_video": "unitree_z1_dual_arm_stackbox_v2/case1/unitree_z1_dual_arm_stackbox_v2_case1.mp4",
|
"gt_video": "unitree_z1_dual_arm_stackbox_v2/case1/unitree_z1_dual_arm_stackbox_v2_case1.mp4",
|
||||||
"pred_video": "unitree_z1_dual_arm_stackbox_v2/case1/output/inference/5_full_fs4.mp4",
|
"pred_video": "unitree_z1_dual_arm_stackbox_v2/case1/output/inference/unitree_z1_dual_arm_stackbox_v2_case1_amd.mp4",
|
||||||
"psnr": 27.62636266067224
|
"psnr": 18.126776535969576
|
||||||
}
|
}
|
||||||
@@ -2,7 +2,7 @@ res_dir="unitree_z1_dual_arm_stackbox_v2/case1"
|
|||||||
dataset="unitree_z1_dual_arm_stackbox_v2"
|
dataset="unitree_z1_dual_arm_stackbox_v2"
|
||||||
|
|
||||||
{
|
{
|
||||||
time CUDA_VISIBLE_DEVICES=0 python3 scripts/evaluation/world_model_interaction.py \
|
time CUDA_VISIBLE_DEVICES=7 python3 scripts/evaluation/world_model_interaction.py \
|
||||||
--seed 123 \
|
--seed 123 \
|
||||||
--ckpt_path ckpts/unifolm_wma_dual.ckpt \
|
--ckpt_path ckpts/unifolm_wma_dual.ckpt \
|
||||||
--config configs/inference/world_model_interaction.yaml \
|
--config configs/inference/world_model_interaction.yaml \
|
||||||
@@ -20,6 +20,5 @@ dataset="unitree_z1_dual_arm_stackbox_v2"
|
|||||||
--n_iter 11 \
|
--n_iter 11 \
|
||||||
--timestep_spacing 'uniform_trailing' \
|
--timestep_spacing 'uniform_trailing' \
|
||||||
--guidance_rescale 0.7 \
|
--guidance_rescale 0.7 \
|
||||||
--perframe_ae \
|
--perframe_ae
|
||||||
--fast_policy_no_decode
|
|
||||||
} 2>&1 | tee "${res_dir}/output.log"
|
} 2>&1 | tee "${res_dir}/output.log"
|
||||||
|
|||||||
146
unitree_z1_dual_arm_stackbox_v2/case2/output.log
Normal file
146
unitree_z1_dual_arm_stackbox_v2/case2/output.log
Normal file
@@ -0,0 +1,146 @@
|
|||||||
|
/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/lightning_fabric/__init__.py:29: UserWarning: pkg_resources is deprecated as an API. See https://setuptools.pypa.io/en/latest/pkg_resources.html. The pkg_resources package is slated for removal as early as 2025-11-30. Refrain from using this package or pin to Setuptools<81.
|
||||||
|
__import__("pkg_resources").declare_namespace(__name__)
|
||||||
|
2026-02-08 07:56:31.144789: I tensorflow/core/util/port.cc:113] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
|
||||||
|
2026-02-08 07:56:31.148256: I external/local_tsl/tsl/cuda/cudart_stub.cc:31] Could not find cuda drivers on your machine, GPU will not be used.
|
||||||
|
2026-02-08 07:56:31.178870: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
|
||||||
|
2026-02-08 07:56:31.178898: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
|
||||||
|
2026-02-08 07:56:31.180683: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
|
||||||
|
2026-02-08 07:56:31.188800: I external/local_tsl/tsl/cuda/cudart_stub.cc:31] Could not find cuda drivers on your machine, GPU will not be used.
|
||||||
|
2026-02-08 07:56:31.189142: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
|
||||||
|
To enable the following instructions: AVX2 AVX512F AVX512_VNNI AVX512_BF16 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
|
||||||
|
2026-02-08 07:56:31.810098: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT
|
||||||
|
[rank: 0] Global seed set to 123
|
||||||
|
/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/kornia/feature/lightglue.py:44: FutureWarning: `torch.cuda.amp.custom_fwd(args...)` is deprecated. Please use `torch.amp.custom_fwd(args..., device_type='cuda')` instead.
|
||||||
|
@torch.cuda.amp.custom_fwd(cast_inputs=torch.float32)
|
||||||
|
INFO:mainlogger:LatentVisualDiffusion: Running in v-prediction mode
|
||||||
|
INFO:unifolm_wma.models.diffusion_head.conditional_unet1d:number of parameters: 5.010531e+08
|
||||||
|
INFO:unifolm_wma.models.diffusion_head.conditional_unet1d:number of parameters: 5.010531e+08
|
||||||
|
AE working on z of shape (1, 4, 32, 32) = 4096 dimensions.
|
||||||
|
INFO:root:Loaded ViT-H-14 model config.
|
||||||
|
DEBUG:urllib3.connectionpool:Starting new HTTPS connection (1): hf-mirror.com:443
|
||||||
|
DEBUG:urllib3.connectionpool:https://hf-mirror.com:443 "HEAD /laion/CLIP-ViT-H-14-laion2B-s32B-b79K/resolve/main/open_clip_pytorch_model.bin HTTP/1.1" 302 0
|
||||||
|
INFO:root:Loading pretrained ViT-H-14 weights (laion2b_s32b_b79k).
|
||||||
|
/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/open_clip/factory.py:88: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.
|
||||||
|
checkpoint = torch.load(checkpoint_path, map_location=map_location)
|
||||||
|
INFO:root:Loaded ViT-H-14 model config.
|
||||||
|
DEBUG:urllib3.connectionpool:https://hf-mirror.com:443 "HEAD /laion/CLIP-ViT-H-14-laion2B-s32B-b79K/resolve/main/open_clip_pytorch_model.bin HTTP/1.1" 302 0
|
||||||
|
INFO:root:Loading pretrained ViT-H-14 weights (laion2b_s32b_b79k).
|
||||||
|
/mnt/ASC1637/unifolm-world-model-action/scripts/evaluation/world_model_interaction.py:86: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.
|
||||||
|
state_dict = torch.load(ckpt, map_location="cpu")
|
||||||
|
>>> model checkpoint loaded.
|
||||||
|
>>> Load pre-trained model ...
|
||||||
|
INFO:root:***** Configing Data *****
|
||||||
|
>>> unitree_z1_stackbox: 1 data samples loaded.
|
||||||
|
>>> unitree_z1_stackbox: data stats loaded.
|
||||||
|
>>> unitree_z1_stackbox: normalizer initiated.
|
||||||
|
>>> unitree_z1_dual_arm_stackbox: 1 data samples loaded.
|
||||||
|
>>> unitree_z1_dual_arm_stackbox: data stats loaded.
|
||||||
|
>>> unitree_z1_dual_arm_stackbox: normalizer initiated.
|
||||||
|
>>> unitree_z1_dual_arm_stackbox_v2: 1 data samples loaded.
|
||||||
|
>>> unitree_z1_dual_arm_stackbox_v2: data stats loaded.
|
||||||
|
>>> unitree_z1_dual_arm_stackbox_v2: normalizer initiated.
|
||||||
|
>>> unitree_z1_dual_arm_cleanup_pencils: 1 data samples loaded.
|
||||||
|
>>> unitree_z1_dual_arm_cleanup_pencils: data stats loaded.
|
||||||
|
>>> unitree_z1_dual_arm_cleanup_pencils: normalizer initiated.
|
||||||
|
>>> unitree_g1_pack_camera: 1 data samples loaded.
|
||||||
|
>>> unitree_g1_pack_camera: data stats loaded.
|
||||||
|
>>> unitree_g1_pack_camera: normalizer initiated.
|
||||||
|
>>> Dataset is successfully loaded ...
|
||||||
|
>>> Generate 16 frames under each generation ...
|
||||||
|
DEBUG:h5py._conv:Creating converter from 3 to 5
|
||||||
|
DEBUG:PIL.PngImagePlugin:STREAM b'IHDR' 16 13
|
||||||
|
DEBUG:PIL.PngImagePlugin:STREAM b'pHYs' 41 9
|
||||||
|
DEBUG:PIL.PngImagePlugin:STREAM b'IDAT' 62 4096
|
||||||
|
|
||||||
|
0%| | 0/11 [00:00<?, ?it/s]/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/torch/nn/functional.py:5501: UserWarning: Attempting to use hipBLASLt on an unsupported architecture! Overriding blas backend to hipblas (Triggered internally at ../aten/src/ATen/Context.cpp:296.)
|
||||||
|
proj = linear(q, w, b)
|
||||||
|
/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/torch/nn/functional.py:6278: UserWarning: Flash attention support on Navi31 GPU is still experimental. Enable it with TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL=1. (Triggered internally at ../aten/src/ATen/native/transformers/hip/sdp_utils.cpp:225.)
|
||||||
|
attn_output = scaled_dot_product_attention(
|
||||||
|
/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/torch/nn/functional.py:6278: UserWarning: Memory Efficient attention on Navi31 GPU is still experimental. Enable it with TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL=1. (Triggered internally at ../aten/src/ATen/native/transformers/hip/sdp_utils.cpp:269.)
|
||||||
|
attn_output = scaled_dot_product_attention(
|
||||||
|
>>> Step 0: generating actions ...
|
||||||
|
>>> Step 0: interacting with world model ...
|
||||||
|
DEBUG:PIL.Image:Importing BlpImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing BmpImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing BufrStubImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing CurImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing DcxImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing DdsImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing EpsImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing FitsImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing FitsStubImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing FliImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing FpxImagePlugin
|
||||||
|
DEBUG:PIL.Image:Image: failed to import FpxImagePlugin: No module named 'olefile'
|
||||||
|
DEBUG:PIL.Image:Importing FtexImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing GbrImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing GifImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing GribStubImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing Hdf5StubImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing IcnsImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing IcoImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing ImImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing ImtImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing IptcImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing JpegImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing Jpeg2KImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing McIdasImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing MicImagePlugin
|
||||||
|
DEBUG:PIL.Image:Image: failed to import MicImagePlugin: No module named 'olefile'
|
||||||
|
DEBUG:PIL.Image:Importing MpegImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing MpoImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing MspImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing PalmImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing PcdImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing PcxImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing PdfImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing PixarImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing PngImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing PpmImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing PsdImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing QoiImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing SgiImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing SpiderImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing SunImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing TgaImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing TiffImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing WebPImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing WmfImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing XbmImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing XpmImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing XVThumbImagePlugin
|
||||||
|
|
||||||
|
9%|▉ | 1/11 [01:40<16:41, 100.16s/it]
|
||||||
|
18%|█▊ | 2/11 [03:20<15:04, 100.47s/it]
|
||||||
|
27%|██▋ | 3/11 [05:01<13:24, 100.62s/it]
|
||||||
|
36%|███▋ | 4/11 [06:42<11:44, 100.69s/it]
|
||||||
|
45%|████▌ | 5/11 [08:22<10:02, 100.48s/it]
|
||||||
|
55%|█████▍ | 6/11 [10:02<08:21, 100.33s/it]
|
||||||
|
64%|██████▎ | 7/11 [11:42<06:40, 100.23s/it]
|
||||||
|
73%|███████▎ | 8/11 [13:22<05:00, 100.23s/it]
|
||||||
|
82%|████████▏ | 9/11 [15:03<03:20, 100.23s/it]
|
||||||
|
91%|█████████ | 10/11 [16:43<01:40, 100.33s/it]
|
||||||
|
100%|██████████| 11/11 [18:24<00:00, 100.41s/it]
|
||||||
|
100%|██████████| 11/11 [18:24<00:00, 100.39s/it]
|
||||||
|
>>>>>>>>>>>>>>>>>>>>>>>>
|
||||||
|
>>> Step 1: generating actions ...
|
||||||
|
>>> Step 1: interacting with world model ...
|
||||||
|
>>>>>>>>>>>>>>>>>>>>>>>>
|
||||||
|
>>> Step 2: generating actions ...
|
||||||
|
>>> Step 2: interacting with world model ...
|
||||||
|
>>>>>>>>>>>>>>>>>>>>>>>>
|
||||||
|
>>> Step 3: generating actions ...
|
||||||
|
>>> Step 3: interacting with world model ...
|
||||||
|
>>>>>>>>>>>>>>>>>>>>>>>>
|
||||||
|
>>> Step 4: generating actions ...
|
||||||
|
>>> Step 4: interacting with world model ...
|
||||||
|
>>>>>>>>>>>>>>>>>>>>>>>>
|
||||||
|
>>> Step 5: generating actions ...
|
||||||
|
>>> Step 5: interacting with world model ...
|
||||||
|
>>>>>>>>>>>>>>>>>>>>>>>>
|
||||||
|
>>> Step 6: generating actions ...
|
||||||
|
>>> Step 6: interacting with world model ...
|
||||||
|
>>>>>>>>>>>>>>>>>>>>>>>>
|
||||||
|
>>> Step 7: generating actions ...
|
||||||
|
>>> Step 7: interacting with world model ...
|
||||||
|
>>>>>>>>>>>>>>>>>>>>>>>>
|
||||||
@@ -1,5 +1,5 @@
|
|||||||
{
|
{
|
||||||
"gt_video": "unitree_z1_dual_arm_stackbox_v2/case2/unitree_z1_dual_arm_stackbox_v2_case2.mp4",
|
"gt_video": "unitree_z1_dual_arm_stackbox_v2/case2/unitree_z1_dual_arm_stackbox_v2_case2.mp4",
|
||||||
"pred_video": "unitree_z1_dual_arm_stackbox_v2/case2/output/inference/15_full_fs4.mp4",
|
"pred_video": "unitree_z1_dual_arm_stackbox_v2/case2/output/inference/unitree_z1_dual_arm_stackbox_v2_case2_amd.mp4",
|
||||||
"psnr": 33.90444714332389
|
"psnr": 19.38130614773096
|
||||||
}
|
}
|
||||||
146
unitree_z1_dual_arm_stackbox_v2/case3/output.log
Normal file
146
unitree_z1_dual_arm_stackbox_v2/case3/output.log
Normal file
@@ -0,0 +1,146 @@
|
|||||||
|
/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/lightning_fabric/__init__.py:29: UserWarning: pkg_resources is deprecated as an API. See https://setuptools.pypa.io/en/latest/pkg_resources.html. The pkg_resources package is slated for removal as early as 2025-11-30. Refrain from using this package or pin to Setuptools<81.
|
||||||
|
__import__("pkg_resources").declare_namespace(__name__)
|
||||||
|
2026-02-08 07:56:04.467082: I tensorflow/core/util/port.cc:113] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
|
||||||
|
2026-02-08 07:56:04.470145: I external/local_tsl/tsl/cuda/cudart_stub.cc:31] Could not find cuda drivers on your machine, GPU will not be used.
|
||||||
|
2026-02-08 07:56:04.502248: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
|
||||||
|
2026-02-08 07:56:04.502277: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
|
||||||
|
2026-02-08 07:56:04.504088: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
|
||||||
|
2026-02-08 07:56:04.512557: I external/local_tsl/tsl/cuda/cudart_stub.cc:31] Could not find cuda drivers on your machine, GPU will not be used.
|
||||||
|
2026-02-08 07:56:04.512830: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
|
||||||
|
To enable the following instructions: AVX2 AVX512F AVX512_VNNI AVX512_BF16 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
|
||||||
|
2026-02-08 07:56:05.259641: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT
|
||||||
|
[rank: 0] Global seed set to 123
|
||||||
|
/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/kornia/feature/lightglue.py:44: FutureWarning: `torch.cuda.amp.custom_fwd(args...)` is deprecated. Please use `torch.amp.custom_fwd(args..., device_type='cuda')` instead.
|
||||||
|
@torch.cuda.amp.custom_fwd(cast_inputs=torch.float32)
|
||||||
|
INFO:mainlogger:LatentVisualDiffusion: Running in v-prediction mode
|
||||||
|
INFO:unifolm_wma.models.diffusion_head.conditional_unet1d:number of parameters: 5.010531e+08
|
||||||
|
INFO:unifolm_wma.models.diffusion_head.conditional_unet1d:number of parameters: 5.010531e+08
|
||||||
|
AE working on z of shape (1, 4, 32, 32) = 4096 dimensions.
|
||||||
|
INFO:root:Loaded ViT-H-14 model config.
|
||||||
|
DEBUG:urllib3.connectionpool:Starting new HTTPS connection (1): hf-mirror.com:443
|
||||||
|
DEBUG:urllib3.connectionpool:https://hf-mirror.com:443 "HEAD /laion/CLIP-ViT-H-14-laion2B-s32B-b79K/resolve/main/open_clip_pytorch_model.bin HTTP/1.1" 302 0
|
||||||
|
INFO:root:Loading pretrained ViT-H-14 weights (laion2b_s32b_b79k).
|
||||||
|
/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/open_clip/factory.py:88: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.
|
||||||
|
checkpoint = torch.load(checkpoint_path, map_location=map_location)
|
||||||
|
INFO:root:Loaded ViT-H-14 model config.
|
||||||
|
DEBUG:urllib3.connectionpool:https://hf-mirror.com:443 "HEAD /laion/CLIP-ViT-H-14-laion2B-s32B-b79K/resolve/main/open_clip_pytorch_model.bin HTTP/1.1" 302 0
|
||||||
|
INFO:root:Loading pretrained ViT-H-14 weights (laion2b_s32b_b79k).
|
||||||
|
/mnt/ASC1637/unifolm-world-model-action/scripts/evaluation/world_model_interaction.py:86: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.
|
||||||
|
state_dict = torch.load(ckpt, map_location="cpu")
|
||||||
|
>>> model checkpoint loaded.
|
||||||
|
>>> Load pre-trained model ...
|
||||||
|
INFO:root:***** Configing Data *****
|
||||||
|
>>> unitree_z1_stackbox: 1 data samples loaded.
|
||||||
|
>>> unitree_z1_stackbox: data stats loaded.
|
||||||
|
>>> unitree_z1_stackbox: normalizer initiated.
|
||||||
|
>>> unitree_z1_dual_arm_stackbox: 1 data samples loaded.
|
||||||
|
>>> unitree_z1_dual_arm_stackbox: data stats loaded.
|
||||||
|
>>> unitree_z1_dual_arm_stackbox: normalizer initiated.
|
||||||
|
>>> unitree_z1_dual_arm_stackbox_v2: 1 data samples loaded.
|
||||||
|
>>> unitree_z1_dual_arm_stackbox_v2: data stats loaded.
|
||||||
|
>>> unitree_z1_dual_arm_stackbox_v2: normalizer initiated.
|
||||||
|
>>> unitree_z1_dual_arm_cleanup_pencils: 1 data samples loaded.
|
||||||
|
>>> unitree_z1_dual_arm_cleanup_pencils: data stats loaded.
|
||||||
|
>>> unitree_z1_dual_arm_cleanup_pencils: normalizer initiated.
|
||||||
|
>>> unitree_g1_pack_camera: 1 data samples loaded.
|
||||||
|
>>> unitree_g1_pack_camera: data stats loaded.
|
||||||
|
>>> unitree_g1_pack_camera: normalizer initiated.
|
||||||
|
>>> Dataset is successfully loaded ...
|
||||||
|
>>> Generate 16 frames under each generation ...
|
||||||
|
DEBUG:h5py._conv:Creating converter from 3 to 5
|
||||||
|
DEBUG:PIL.PngImagePlugin:STREAM b'IHDR' 16 13
|
||||||
|
DEBUG:PIL.PngImagePlugin:STREAM b'pHYs' 41 9
|
||||||
|
DEBUG:PIL.PngImagePlugin:STREAM b'IDAT' 62 4096
|
||||||
|
|
||||||
|
0%| | 0/11 [00:00<?, ?it/s]/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/torch/nn/functional.py:5501: UserWarning: Attempting to use hipBLASLt on an unsupported architecture! Overriding blas backend to hipblas (Triggered internally at ../aten/src/ATen/Context.cpp:296.)
|
||||||
|
proj = linear(q, w, b)
|
||||||
|
/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/torch/nn/functional.py:6278: UserWarning: Flash attention support on Navi31 GPU is still experimental. Enable it with TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL=1. (Triggered internally at ../aten/src/ATen/native/transformers/hip/sdp_utils.cpp:225.)
|
||||||
|
attn_output = scaled_dot_product_attention(
|
||||||
|
/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/torch/nn/functional.py:6278: UserWarning: Memory Efficient attention on Navi31 GPU is still experimental. Enable it with TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL=1. (Triggered internally at ../aten/src/ATen/native/transformers/hip/sdp_utils.cpp:269.)
|
||||||
|
attn_output = scaled_dot_product_attention(
|
||||||
|
>>> Step 0: generating actions ...
|
||||||
|
>>> Step 0: interacting with world model ...
|
||||||
|
DEBUG:PIL.Image:Importing BlpImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing BmpImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing BufrStubImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing CurImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing DcxImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing DdsImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing EpsImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing FitsImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing FitsStubImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing FliImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing FpxImagePlugin
|
||||||
|
DEBUG:PIL.Image:Image: failed to import FpxImagePlugin: No module named 'olefile'
|
||||||
|
DEBUG:PIL.Image:Importing FtexImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing GbrImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing GifImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing GribStubImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing Hdf5StubImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing IcnsImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing IcoImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing ImImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing ImtImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing IptcImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing JpegImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing Jpeg2KImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing McIdasImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing MicImagePlugin
|
||||||
|
DEBUG:PIL.Image:Image: failed to import MicImagePlugin: No module named 'olefile'
|
||||||
|
DEBUG:PIL.Image:Importing MpegImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing MpoImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing MspImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing PalmImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing PcdImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing PcxImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing PdfImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing PixarImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing PngImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing PpmImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing PsdImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing QoiImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing SgiImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing SpiderImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing SunImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing TgaImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing TiffImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing WebPImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing WmfImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing XbmImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing XpmImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing XVThumbImagePlugin
|
||||||
|
|
||||||
|
9%|▉ | 1/11 [01:38<16:20, 98.03s/it]
|
||||||
|
18%|█▊ | 2/11 [03:16<14:43, 98.19s/it]
|
||||||
|
27%|██▋ | 3/11 [04:55<13:08, 98.54s/it]
|
||||||
|
36%|███▋ | 4/11 [06:33<11:29, 98.52s/it]
|
||||||
|
45%|████▌ | 5/11 [08:11<09:50, 98.38s/it]
|
||||||
|
55%|█████▍ | 6/11 [09:49<08:10, 98.11s/it]
|
||||||
|
64%|██████▎ | 7/11 [11:27<06:31, 97.97s/it]
|
||||||
|
73%|███████▎ | 8/11 [13:04<04:53, 97.83s/it]
|
||||||
|
82%|████████▏ | 9/11 [14:42<03:15, 97.72s/it]
|
||||||
|
91%|█████████ | 10/11 [16:19<01:37, 97.71s/it]
|
||||||
|
100%|██████████| 11/11 [17:57<00:00, 97.74s/it]
|
||||||
|
100%|██████████| 11/11 [17:57<00:00, 97.97s/it]
|
||||||
|
>>>>>>>>>>>>>>>>>>>>>>>>
|
||||||
|
>>> Step 1: generating actions ...
|
||||||
|
>>> Step 1: interacting with world model ...
|
||||||
|
>>>>>>>>>>>>>>>>>>>>>>>>
|
||||||
|
>>> Step 2: generating actions ...
|
||||||
|
>>> Step 2: interacting with world model ...
|
||||||
|
>>>>>>>>>>>>>>>>>>>>>>>>
|
||||||
|
>>> Step 3: generating actions ...
|
||||||
|
>>> Step 3: interacting with world model ...
|
||||||
|
>>>>>>>>>>>>>>>>>>>>>>>>
|
||||||
|
>>> Step 4: generating actions ...
|
||||||
|
>>> Step 4: interacting with world model ...
|
||||||
|
>>>>>>>>>>>>>>>>>>>>>>>>
|
||||||
|
>>> Step 5: generating actions ...
|
||||||
|
>>> Step 5: interacting with world model ...
|
||||||
|
>>>>>>>>>>>>>>>>>>>>>>>>
|
||||||
|
>>> Step 6: generating actions ...
|
||||||
|
>>> Step 6: interacting with world model ...
|
||||||
|
>>>>>>>>>>>>>>>>>>>>>>>>
|
||||||
|
>>> Step 7: generating actions ...
|
||||||
|
>>> Step 7: interacting with world model ...
|
||||||
|
>>>>>>>>>>>>>>>>>>>>>>>>
|
||||||
@@ -1,5 +1,5 @@
|
|||||||
{
|
{
|
||||||
"gt_video": "unitree_z1_dual_arm_stackbox_v2/case3/unitree_z1_dual_arm_stackbox_v2_case3.mp4",
|
"gt_video": "unitree_z1_dual_arm_stackbox_v2/case3/unitree_z1_dual_arm_stackbox_v2_case3.mp4",
|
||||||
"pred_video": "unitree_z1_dual_arm_stackbox_v2/case3/output/inference/25_full_fs4.mp4",
|
"pred_video": "unitree_z1_dual_arm_stackbox_v2/case3/output/inference/unitree_z1_dual_arm_stackbox_v2_case3_amd.mp4",
|
||||||
"psnr": 34.50192428908007
|
"psnr": 18.74462122425683
|
||||||
}
|
}
|
||||||
@@ -1,13 +1,34 @@
|
|||||||
2026-02-18 18:54:56.403136: I tensorflow/core/util/port.cc:113] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
|
/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/lightning_fabric/__init__.py:29: UserWarning: pkg_resources is deprecated as an API. See https://setuptools.pypa.io/en/latest/pkg_resources.html. The pkg_resources package is slated for removal as early as 2025-11-30. Refrain from using this package or pin to Setuptools<81.
|
||||||
2026-02-18 18:54:56.451144: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
|
__import__("pkg_resources").declare_namespace(__name__)
|
||||||
2026-02-18 18:54:56.451189: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
|
2026-02-08 08:04:16.104516: I tensorflow/core/util/port.cc:113] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
|
||||||
2026-02-18 18:54:56.452312: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
|
2026-02-08 08:04:16.109112: I external/local_tsl/tsl/cuda/cudart_stub.cc:31] Could not find cuda drivers on your machine, GPU will not be used.
|
||||||
2026-02-18 18:54:56.459281: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
|
2026-02-08 08:04:16.138703: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
|
||||||
To enable the following instructions: AVX2 AVX512F AVX512_VNNI FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
|
2026-02-08 08:04:16.138737: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
|
||||||
2026-02-18 18:54:57.381032: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT
|
2026-02-08 08:04:16.140302: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
|
||||||
|
2026-02-08 08:04:16.147672: I external/local_tsl/tsl/cuda/cudart_stub.cc:31] Could not find cuda drivers on your machine, GPU will not be used.
|
||||||
|
2026-02-08 08:04:16.147903: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
|
||||||
|
To enable the following instructions: AVX2 AVX512F AVX512_VNNI AVX512_BF16 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
|
||||||
|
2026-02-08 08:04:17.363218: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT
|
||||||
Global seed set to 123
|
Global seed set to 123
|
||||||
>>> Loading prepared model from ckpts/unifolm_wma_dual.ckpt.prepared.pt ...
|
/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/kornia/feature/lightglue.py:44: FutureWarning: `torch.cuda.amp.custom_fwd(args...)` is deprecated. Please use `torch.amp.custom_fwd(args..., device_type='cuda')` instead.
|
||||||
>>> Prepared model loaded.
|
@torch.cuda.amp.custom_fwd(cast_inputs=torch.float32)
|
||||||
|
INFO:mainlogger:LatentVisualDiffusion: Running in v-prediction mode
|
||||||
|
INFO:unifolm_wma.models.diffusion_head.conditional_unet1d:number of parameters: 5.010531e+08
|
||||||
|
INFO:unifolm_wma.models.diffusion_head.conditional_unet1d:number of parameters: 5.010531e+08
|
||||||
|
AE working on z of shape (1, 4, 32, 32) = 4096 dimensions.
|
||||||
|
INFO:root:Loaded ViT-H-14 model config.
|
||||||
|
DEBUG:urllib3.connectionpool:Starting new HTTPS connection (1): hf-mirror.com:443
|
||||||
|
DEBUG:urllib3.connectionpool:https://hf-mirror.com:443 "HEAD /laion/CLIP-ViT-H-14-laion2B-s32B-b79K/resolve/main/open_clip_pytorch_model.bin HTTP/1.1" 302 0
|
||||||
|
INFO:root:Loading pretrained ViT-H-14 weights (laion2b_s32b_b79k).
|
||||||
|
/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/open_clip/factory.py:88: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.
|
||||||
|
checkpoint = torch.load(checkpoint_path, map_location=map_location)
|
||||||
|
INFO:root:Loaded ViT-H-14 model config.
|
||||||
|
DEBUG:urllib3.connectionpool:https://hf-mirror.com:443 "HEAD /laion/CLIP-ViT-H-14-laion2B-s32B-b79K/resolve/main/open_clip_pytorch_model.bin HTTP/1.1" 302 0
|
||||||
|
INFO:root:Loading pretrained ViT-H-14 weights (laion2b_s32b_b79k).
|
||||||
|
/mnt/ASC1637/unifolm-world-model-action/scripts/evaluation/world_model_interaction.py:86: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.
|
||||||
|
state_dict = torch.load(ckpt, map_location="cpu")
|
||||||
|
>>> model checkpoint loaded.
|
||||||
|
>>> Load pre-trained model ...
|
||||||
INFO:root:***** Configing Data *****
|
INFO:root:***** Configing Data *****
|
||||||
>>> unitree_z1_stackbox: 1 data samples loaded.
|
>>> unitree_z1_stackbox: 1 data samples loaded.
|
||||||
>>> unitree_z1_stackbox: data stats loaded.
|
>>> unitree_z1_stackbox: data stats loaded.
|
||||||
@@ -25,125 +46,71 @@ INFO:root:***** Configing Data *****
|
|||||||
>>> unitree_g1_pack_camera: data stats loaded.
|
>>> unitree_g1_pack_camera: data stats loaded.
|
||||||
>>> unitree_g1_pack_camera: normalizer initiated.
|
>>> unitree_g1_pack_camera: normalizer initiated.
|
||||||
>>> Dataset is successfully loaded ...
|
>>> Dataset is successfully loaded ...
|
||||||
✓ KV fused: 66 attention layers
|
|
||||||
TRT output 'y': [1, 4, 16, 40, 64] torch.float32
|
|
||||||
TRT output 'hs_a_0': [1, 16, 320, 40, 64] torch.float32
|
|
||||||
TRT output 'hs_a_1': [1, 16, 640, 20, 32] torch.float32
|
|
||||||
TRT output 'hs_a_2': [1, 16, 1280, 10, 16] torch.float32
|
|
||||||
TRT output 'hs_a_3': [1, 16, 1280, 5, 8] torch.float32
|
|
||||||
TRT output 'hs_a_4': [1, 16, 1280, 5, 8] torch.float32
|
|
||||||
TRT output 'hs_a_5': [1, 16, 1280, 5, 8] torch.float32
|
|
||||||
TRT output 'hs_a_6': [1, 16, 1280, 10, 16] torch.float32
|
|
||||||
TRT output 'hs_a_7': [1, 16, 640, 20, 32] torch.float32
|
|
||||||
TRT output 'hs_a_8': [1, 16, 320, 40, 64] torch.float32
|
|
||||||
>>> TRT backbone loaded from /home/qhy/unifolm-world-model-action/scripts/evaluation/../../trt_engines/video_backbone.engine
|
|
||||||
>>> Generate 16 frames under each generation ...
|
>>> Generate 16 frames under each generation ...
|
||||||
DEBUG:h5py._conv:Creating converter from 3 to 5
|
DEBUG:h5py._conv:Creating converter from 3 to 5
|
||||||
DEBUG:PIL.PngImagePlugin:STREAM b'IHDR' 16 13
|
DEBUG:PIL.PngImagePlugin:STREAM b'IHDR' 16 13
|
||||||
DEBUG:PIL.PngImagePlugin:STREAM b'pHYs' 41 9
|
DEBUG:PIL.PngImagePlugin:STREAM b'pHYs' 41 9
|
||||||
DEBUG:PIL.PngImagePlugin:STREAM b'IDAT' 62 4096
|
DEBUG:PIL.PngImagePlugin:STREAM b'IDAT' 62 4096
|
||||||
|
|
||||||
0%| | 0/11 [00:00<?, ?it/s][02/18/2026-18:55:10] [TRT] [W] Using default stream in enqueueV3() may lead to performance issues due to additional calls to cudaStreamSynchronize() by TensorRT to ensure correct synchronization. Please use non-default stream instead.
|
0%| | 0/11 [00:00<?, ?it/s]/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/torch/nn/functional.py:5501: UserWarning: Attempting to use hipBLASLt on an unsupported architecture! Overriding blas backend to hipblas (Triggered internally at ../aten/src/ATen/Context.cpp:296.)
|
||||||
|
proj = linear(q, w, b)
|
||||||
|
/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/torch/nn/functional.py:6278: UserWarning: Flash attention support on Navi31 GPU is still experimental. Enable it with TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL=1. (Triggered internally at ../aten/src/ATen/native/transformers/hip/sdp_utils.cpp:225.)
|
||||||
|
attn_output = scaled_dot_product_attention(
|
||||||
|
/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/torch/nn/functional.py:6278: UserWarning: Memory Efficient attention on Navi31 GPU is still experimental. Enable it with TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL=1. (Triggered internally at ../aten/src/ATen/native/transformers/hip/sdp_utils.cpp:269.)
|
||||||
|
attn_output = scaled_dot_product_attention(
|
||||||
>>> Step 0: generating actions ...
|
>>> Step 0: generating actions ...
|
||||||
9%|▉ | 1/11 [00:16<02:45, 16.53s/it]>>> Step 0: generating actions ...
|
|
||||||
>>> Step 0: interacting with world model ...
|
|
||||||
>>>>>>>>>>>>>>>>>>>>>>>>
|
|
||||||
>>> Step 0: interacting with world model ...
|
>>> Step 0: interacting with world model ...
|
||||||
DEBUG:PIL.Image:Importing BlpImagePlugin
|
DEBUG:PIL.Image:Importing BlpImagePlugin
|
||||||
DEBUG:PIL.Image:Importing BmpImagePlugin
|
DEBUG:PIL.Image:Importing BmpImagePlugin
|
||||||
DEBUG:PIL.Image:Importing BmpImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing BufrStubImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing BmpImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing BufrStubImagePlugin
|
DEBUG:PIL.Image:Importing BufrStubImagePlugin
|
||||||
DEBUG:PIL.Image:Importing CurImagePlugin
|
DEBUG:PIL.Image:Importing CurImagePlugin
|
||||||
DEBUG:PIL.Image:Importing CurImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing DcxImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing DcxImagePlugin
|
DEBUG:PIL.Image:Importing DcxImagePlugin
|
||||||
DEBUG:PIL.Image:Importing DdsImagePlugin
|
DEBUG:PIL.Image:Importing DdsImagePlugin
|
||||||
DEBUG:PIL.Image:Importing DdsImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing EpsImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing EpsImagePlugin
|
DEBUG:PIL.Image:Importing EpsImagePlugin
|
||||||
DEBUG:PIL.Image:Importing FitsImagePlugin
|
DEBUG:PIL.Image:Importing FitsImagePlugin
|
||||||
DEBUG:PIL.Image:Importing FitsImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing FitsStubImagePlugin
|
DEBUG:PIL.Image:Importing FitsStubImagePlugin
|
||||||
DEBUG:PIL.Image:Importing FitsStubImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing FliImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing FliImagePlugin
|
DEBUG:PIL.Image:Importing FliImagePlugin
|
||||||
DEBUG:PIL.Image:Importing FpxImagePlugin
|
DEBUG:PIL.Image:Importing FpxImagePlugin
|
||||||
DEBUG:PIL.Image:Image: failed to import FpxImagePlugin: No module named 'olefile'
|
DEBUG:PIL.Image:Image: failed to import FpxImagePlugin: No module named 'olefile'
|
||||||
DEBUG:PIL.Image:Image: failed to import FpxImagePlugin: No module named 'olefile'
|
|
||||||
DEBUG:PIL.Image:Importing FtexImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing FtexImagePlugin
|
DEBUG:PIL.Image:Importing FtexImagePlugin
|
||||||
DEBUG:PIL.Image:Importing GbrImagePlugin
|
DEBUG:PIL.Image:Importing GbrImagePlugin
|
||||||
DEBUG:PIL.Image:Importing GifImagePlugin
|
DEBUG:PIL.Image:Importing GifImagePlugin
|
||||||
DEBUG:PIL.Image:Importing GifImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing GribStubImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing GifImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing GribStubImagePlugin
|
DEBUG:PIL.Image:Importing GribStubImagePlugin
|
||||||
DEBUG:PIL.Image:Importing Hdf5StubImagePlugin
|
DEBUG:PIL.Image:Importing Hdf5StubImagePlugin
|
||||||
DEBUG:PIL.Image:Importing Hdf5StubImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing IcnsImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing IcnsImagePlugin
|
DEBUG:PIL.Image:Importing IcnsImagePlugin
|
||||||
DEBUG:PIL.Image:Importing IcoImagePlugin
|
DEBUG:PIL.Image:Importing IcoImagePlugin
|
||||||
DEBUG:PIL.Image:Importing IcoImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing ImImagePlugin
|
DEBUG:PIL.Image:Importing ImImagePlugin
|
||||||
DEBUG:PIL.Image:Importing ImImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing ImtImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing ImtImagePlugin
|
DEBUG:PIL.Image:Importing ImtImagePlugin
|
||||||
DEBUG:PIL.Image:Importing IptcImagePlugin
|
DEBUG:PIL.Image:Importing IptcImagePlugin
|
||||||
DEBUG:PIL.Image:Importing JpegImagePlugin
|
DEBUG:PIL.Image:Importing JpegImagePlugin
|
||||||
DEBUG:PIL.Image:Importing Jpeg2KImagePlugin
|
DEBUG:PIL.Image:Importing Jpeg2KImagePlugin
|
||||||
DEBUG:PIL.Image:Importing Jpeg2KImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing McIdasImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing JpegImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing Jpeg2KImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing McIdasImagePlugin
|
DEBUG:PIL.Image:Importing McIdasImagePlugin
|
||||||
DEBUG:PIL.Image:Importing MicImagePlugin
|
DEBUG:PIL.Image:Importing MicImagePlugin
|
||||||
DEBUG:PIL.Image:Image: failed to import MicImagePlugin: No module named 'olefile'
|
DEBUG:PIL.Image:Image: failed to import MicImagePlugin: No module named 'olefile'
|
||||||
DEBUG:PIL.Image:Image: failed to import MicImagePlugin: No module named 'olefile'
|
|
||||||
DEBUG:PIL.Image:Importing MpegImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing MpegImagePlugin
|
DEBUG:PIL.Image:Importing MpegImagePlugin
|
||||||
DEBUG:PIL.Image:Importing MpoImagePlugin
|
DEBUG:PIL.Image:Importing MpoImagePlugin
|
||||||
DEBUG:PIL.Image:Importing MpoImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing MspImagePlugin
|
DEBUG:PIL.Image:Importing MspImagePlugin
|
||||||
DEBUG:PIL.Image:Importing MspImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing PalmImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing PalmImagePlugin
|
DEBUG:PIL.Image:Importing PalmImagePlugin
|
||||||
DEBUG:PIL.Image:Importing PcdImagePlugin
|
DEBUG:PIL.Image:Importing PcdImagePlugin
|
||||||
DEBUG:PIL.Image:Importing PcxImagePlugin
|
DEBUG:PIL.Image:Importing PcxImagePlugin
|
||||||
DEBUG:PIL.Image:Importing PcxImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing PdfImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing PcxImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing PdfImagePlugin
|
DEBUG:PIL.Image:Importing PdfImagePlugin
|
||||||
DEBUG:PIL.Image:Importing PixarImagePlugin
|
DEBUG:PIL.Image:Importing PixarImagePlugin
|
||||||
DEBUG:PIL.Image:Importing PngImagePlugin
|
DEBUG:PIL.Image:Importing PngImagePlugin
|
||||||
DEBUG:PIL.Image:Importing PpmImagePlugin
|
DEBUG:PIL.Image:Importing PpmImagePlugin
|
||||||
DEBUG:PIL.Image:Importing PpmImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing PsdImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing PngImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing PpmImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing PsdImagePlugin
|
DEBUG:PIL.Image:Importing PsdImagePlugin
|
||||||
DEBUG:PIL.Image:Importing QoiImagePlugin
|
DEBUG:PIL.Image:Importing QoiImagePlugin
|
||||||
DEBUG:PIL.Image:Importing QoiImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing SgiImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing SgiImagePlugin
|
DEBUG:PIL.Image:Importing SgiImagePlugin
|
||||||
DEBUG:PIL.Image:Importing SpiderImagePlugin
|
DEBUG:PIL.Image:Importing SpiderImagePlugin
|
||||||
DEBUG:PIL.Image:Importing SpiderImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing SunImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing SunImagePlugin
|
DEBUG:PIL.Image:Importing SunImagePlugin
|
||||||
DEBUG:PIL.Image:Importing TgaImagePlugin
|
DEBUG:PIL.Image:Importing TgaImagePlugin
|
||||||
DEBUG:PIL.Image:Importing TiffImagePlugin
|
DEBUG:PIL.Image:Importing TiffImagePlugin
|
||||||
DEBUG:PIL.Image:Importing TiffImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing WebPImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing TiffImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing WebPImagePlugin
|
DEBUG:PIL.Image:Importing WebPImagePlugin
|
||||||
DEBUG:PIL.Image:Importing WmfImagePlugin
|
DEBUG:PIL.Image:Importing WmfImagePlugin
|
||||||
DEBUG:PIL.Image:Importing WmfImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing XbmImagePlugin
|
|
||||||
DEBUG:PIL.Image:Importing XbmImagePlugin
|
DEBUG:PIL.Image:Importing XbmImagePlugin
|
||||||
DEBUG:PIL.Image:Importing XpmImagePlugin
|
DEBUG:PIL.Image:Importing XpmImagePlugin
|
||||||
DEBUG:PIL.Image:Importing XpmImagePlugin
|
DEBUG:PIL.Image:Importing XVThumbImagePlugin
|
||||||
DEBUG:PIL.Image:Importing XVThumbImagePlugin
|
|
||||||
|
9%|▉ | 1/11 [01:39<16:32, 99.26s/it]
|
||||||
18%|█▊ | 2/11 [03:17<14:49, 98.81s/it]
|
18%|█▊ | 2/11 [03:17<14:49, 98.81s/it]
|
||||||
27%|██▋ | 3/11 [04:56<13:10, 98.76s/it]
|
27%|██▋ | 3/11 [04:56<13:10, 98.76s/it]
|
||||||
36%|███▋ | 4/11 [06:35<11:31, 98.80s/it]
|
36%|███▋ | 4/11 [06:35<11:31, 98.80s/it]
|
||||||
@@ -174,6 +141,6 @@ DEBUG:PIL.Image:Importing XVThumbImagePlugin
|
|||||||
>>> Step 6: generating actions ...
|
>>> Step 6: generating actions ...
|
||||||
>>> Step 6: interacting with world model ...
|
>>> Step 6: interacting with world model ...
|
||||||
>>>>>>>>>>>>>>>>>>>>>>>>
|
>>>>>>>>>>>>>>>>>>>>>>>>
|
||||||
>>> Step 7: generating actions ...
|
>>> Step 7: generating actions ...
|
||||||
>>> Step 7: interacting with world model ...
|
>>> Step 7: interacting with world model ...
|
||||||
>>>>>>>>>>>>>>>>>>>>>>>>
|
>>>>>>>>>>>>>>>>>>>>>>>>
|
||||||
|
|||||||
@@ -1,5 +1,5 @@
|
|||||||
{
|
{
|
||||||
"gt_video": "unitree_z1_dual_arm_stackbox_v2/case4/unitree_z1_dual_arm_stackbox_v2_case4.mp4",
|
"gt_video": "unitree_z1_dual_arm_stackbox_v2/case4/unitree_z1_dual_arm_stackbox_v2_case4.mp4",
|
||||||
"pred_video": "unitree_z1_dual_arm_stackbox_v2/case4/output/inference/35_full_fs4.mp4",
|
"pred_video": "unitree_z1_dual_arm_stackbox_v2/case4/output/inference/unitree_z1_dual_arm_stackbox_v2_case4_amd.mp4",
|
||||||
"psnr": 25.49270910031428
|
"psnr": 19.526448380726254
|
||||||
}
|
}
|
||||||
@@ -2,7 +2,7 @@ res_dir="unitree_z1_dual_arm_stackbox_v2/case4"
|
|||||||
dataset="unitree_z1_dual_arm_stackbox_v2"
|
dataset="unitree_z1_dual_arm_stackbox_v2"
|
||||||
|
|
||||||
{
|
{
|
||||||
time CUDA_VISIBLE_DEVICES=0 python3 scripts/evaluation/world_model_interaction.py \
|
time CUDA_VISIBLE_DEVICES=6 python3 scripts/evaluation/world_model_interaction.py \
|
||||||
--seed 123 \
|
--seed 123 \
|
||||||
--ckpt_path ckpts/unifolm_wma_dual.ckpt \
|
--ckpt_path ckpts/unifolm_wma_dual.ckpt \
|
||||||
--config configs/inference/world_model_interaction.yaml \
|
--config configs/inference/world_model_interaction.yaml \
|
||||||
|
|||||||
149
unitree_z1_stackbox/case1/output.log
Normal file
149
unitree_z1_stackbox/case1/output.log
Normal file
@@ -0,0 +1,149 @@
|
|||||||
|
/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/lightning_fabric/__init__.py:29: UserWarning: pkg_resources is deprecated as an API. See https://setuptools.pypa.io/en/latest/pkg_resources.html. The pkg_resources package is slated for removal as early as 2025-11-30. Refrain from using this package or pin to Setuptools<81.
|
||||||
|
__import__("pkg_resources").declare_namespace(__name__)
|
||||||
|
2026-02-08 08:12:47.424053: I tensorflow/core/util/port.cc:113] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
|
||||||
|
2026-02-08 08:12:47.427280: I external/local_tsl/tsl/cuda/cudart_stub.cc:31] Could not find cuda drivers on your machine, GPU will not be used.
|
||||||
|
2026-02-08 08:12:47.458253: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
|
||||||
|
2026-02-08 08:12:47.458288: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
|
||||||
|
2026-02-08 08:12:47.462758: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
|
||||||
|
2026-02-08 08:12:47.518283: I external/local_tsl/tsl/cuda/cudart_stub.cc:31] Could not find cuda drivers on your machine, GPU will not be used.
|
||||||
|
2026-02-08 08:12:47.518566: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
|
||||||
|
To enable the following instructions: AVX2 AVX512F AVX512_VNNI AVX512_BF16 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
|
||||||
|
2026-02-08 08:12:48.593011: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT
|
||||||
|
Global seed set to 123
|
||||||
|
/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/kornia/feature/lightglue.py:44: FutureWarning: `torch.cuda.amp.custom_fwd(args...)` is deprecated. Please use `torch.amp.custom_fwd(args..., device_type='cuda')` instead.
|
||||||
|
@torch.cuda.amp.custom_fwd(cast_inputs=torch.float32)
|
||||||
|
INFO:mainlogger:LatentVisualDiffusion: Running in v-prediction mode
|
||||||
|
INFO:unifolm_wma.models.diffusion_head.conditional_unet1d:number of parameters: 5.010531e+08
|
||||||
|
INFO:unifolm_wma.models.diffusion_head.conditional_unet1d:number of parameters: 5.010531e+08
|
||||||
|
AE working on z of shape (1, 4, 32, 32) = 4096 dimensions.
|
||||||
|
INFO:root:Loaded ViT-H-14 model config.
|
||||||
|
DEBUG:urllib3.connectionpool:Starting new HTTPS connection (1): hf-mirror.com:443
|
||||||
|
DEBUG:urllib3.connectionpool:https://hf-mirror.com:443 "HEAD /laion/CLIP-ViT-H-14-laion2B-s32B-b79K/resolve/main/open_clip_pytorch_model.bin HTTP/1.1" 302 0
|
||||||
|
INFO:root:Loading pretrained ViT-H-14 weights (laion2b_s32b_b79k).
|
||||||
|
/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/open_clip/factory.py:88: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.
|
||||||
|
checkpoint = torch.load(checkpoint_path, map_location=map_location)
|
||||||
|
INFO:root:Loaded ViT-H-14 model config.
|
||||||
|
DEBUG:urllib3.connectionpool:https://hf-mirror.com:443 "HEAD /laion/CLIP-ViT-H-14-laion2B-s32B-b79K/resolve/main/open_clip_pytorch_model.bin HTTP/1.1" 302 0
|
||||||
|
INFO:root:Loading pretrained ViT-H-14 weights (laion2b_s32b_b79k).
|
||||||
|
/mnt/ASC1637/unifolm-world-model-action/scripts/evaluation/world_model_interaction.py:86: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.
|
||||||
|
state_dict = torch.load(ckpt, map_location="cpu")
|
||||||
|
>>> model checkpoint loaded.
|
||||||
|
>>> Load pre-trained model ...
|
||||||
|
INFO:root:***** Configing Data *****
|
||||||
|
>>> unitree_z1_stackbox: 1 data samples loaded.
|
||||||
|
>>> unitree_z1_stackbox: data stats loaded.
|
||||||
|
>>> unitree_z1_stackbox: normalizer initiated.
|
||||||
|
>>> unitree_z1_dual_arm_stackbox: 1 data samples loaded.
|
||||||
|
>>> unitree_z1_dual_arm_stackbox: data stats loaded.
|
||||||
|
>>> unitree_z1_dual_arm_stackbox: normalizer initiated.
|
||||||
|
>>> unitree_z1_dual_arm_stackbox_v2: 1 data samples loaded.
|
||||||
|
>>> unitree_z1_dual_arm_stackbox_v2: data stats loaded.
|
||||||
|
>>> unitree_z1_dual_arm_stackbox_v2: normalizer initiated.
|
||||||
|
>>> unitree_z1_dual_arm_cleanup_pencils: 1 data samples loaded.
|
||||||
|
>>> unitree_z1_dual_arm_cleanup_pencils: data stats loaded.
|
||||||
|
>>> unitree_z1_dual_arm_cleanup_pencils: normalizer initiated.
|
||||||
|
>>> unitree_g1_pack_camera: 1 data samples loaded.
|
||||||
|
>>> unitree_g1_pack_camera: data stats loaded.
|
||||||
|
>>> unitree_g1_pack_camera: normalizer initiated.
|
||||||
|
>>> Dataset is successfully loaded ...
|
||||||
|
>>> Generate 16 frames under each generation ...
|
||||||
|
DEBUG:h5py._conv:Creating converter from 3 to 5
|
||||||
|
DEBUG:PIL.PngImagePlugin:STREAM b'IHDR' 16 13
|
||||||
|
DEBUG:PIL.PngImagePlugin:STREAM b'pHYs' 41 9
|
||||||
|
DEBUG:PIL.PngImagePlugin:STREAM b'IDAT' 62 4096
|
||||||
|
|
||||||
|
0%| | 0/12 [00:00<?, ?it/s]/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/torch/nn/functional.py:5501: UserWarning: Attempting to use hipBLASLt on an unsupported architecture! Overriding blas backend to hipblas (Triggered internally at ../aten/src/ATen/Context.cpp:296.)
|
||||||
|
proj = linear(q, w, b)
|
||||||
|
/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/torch/nn/functional.py:6278: UserWarning: Flash attention support on Navi31 GPU is still experimental. Enable it with TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL=1. (Triggered internally at ../aten/src/ATen/native/transformers/hip/sdp_utils.cpp:225.)
|
||||||
|
attn_output = scaled_dot_product_attention(
|
||||||
|
/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/torch/nn/functional.py:6278: UserWarning: Memory Efficient attention on Navi31 GPU is still experimental. Enable it with TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL=1. (Triggered internally at ../aten/src/ATen/native/transformers/hip/sdp_utils.cpp:269.)
|
||||||
|
attn_output = scaled_dot_product_attention(
|
||||||
|
>>> Step 0: generating actions ...
|
||||||
|
>>> Step 0: interacting with world model ...
|
||||||
|
DEBUG:PIL.Image:Importing BlpImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing BmpImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing BufrStubImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing CurImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing DcxImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing DdsImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing EpsImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing FitsImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing FitsStubImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing FliImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing FpxImagePlugin
|
||||||
|
DEBUG:PIL.Image:Image: failed to import FpxImagePlugin: No module named 'olefile'
|
||||||
|
DEBUG:PIL.Image:Importing FtexImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing GbrImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing GifImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing GribStubImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing Hdf5StubImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing IcnsImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing IcoImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing ImImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing ImtImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing IptcImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing JpegImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing Jpeg2KImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing McIdasImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing MicImagePlugin
|
||||||
|
DEBUG:PIL.Image:Image: failed to import MicImagePlugin: No module named 'olefile'
|
||||||
|
DEBUG:PIL.Image:Importing MpegImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing MpoImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing MspImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing PalmImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing PcdImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing PcxImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing PdfImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing PixarImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing PngImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing PpmImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing PsdImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing QoiImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing SgiImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing SpiderImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing SunImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing TgaImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing TiffImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing WebPImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing WmfImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing XbmImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing XpmImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing XVThumbImagePlugin
|
||||||
|
|
||||||
|
8%|▊ | 1/12 [01:38<18:08, 98.94s/it]
|
||||||
|
17%|█▋ | 2/12 [03:18<16:30, 99.01s/it]
|
||||||
|
25%|██▌ | 3/12 [04:57<14:51, 99.07s/it]
|
||||||
|
33%|███▎ | 4/12 [06:36<13:12, 99.04s/it]
|
||||||
|
42%|████▏ | 5/12 [08:15<11:33, 99.00s/it]
|
||||||
|
50%|█████ | 6/12 [09:54<09:54, 99.10s/it]
|
||||||
|
58%|█████▊ | 7/12 [11:33<08:14, 99.00s/it]
|
||||||
|
67%|██████▋ | 8/12 [13:13<06:38, 99.58s/it]
|
||||||
|
75%|███████▌ | 9/12 [14:54<04:59, 99.88s/it]
|
||||||
|
83%|████████▎ | 10/12 [16:33<03:19, 99.58s/it]
|
||||||
|
92%|█████████▏| 11/12 [18:12<01:39, 99.39s/it]
|
||||||
|
100%|██████████| 12/12 [19:51<00:00, 99.25s/it]
|
||||||
|
100%|██████████| 12/12 [19:51<00:00, 99.28s/it]
|
||||||
|
>>>>>>>>>>>>>>>>>>>>>>>>
|
||||||
|
>>> Step 1: generating actions ...
|
||||||
|
>>> Step 1: interacting with world model ...
|
||||||
|
>>>>>>>>>>>>>>>>>>>>>>>>
|
||||||
|
>>> Step 2: generating actions ...
|
||||||
|
>>> Step 2: interacting with world model ...
|
||||||
|
>>>>>>>>>>>>>>>>>>>>>>>>
|
||||||
|
>>> Step 3: generating actions ...
|
||||||
|
>>> Step 3: interacting with world model ...
|
||||||
|
>>>>>>>>>>>>>>>>>>>>>>>>
|
||||||
|
>>> Step 4: generating actions ...
|
||||||
|
>>> Step 4: interacting with world model ...
|
||||||
|
>>>>>>>>>>>>>>>>>>>>>>>>
|
||||||
|
>>> Step 5: generating actions ...
|
||||||
|
>>> Step 5: interacting with world model ...
|
||||||
|
>>>>>>>>>>>>>>>>>>>>>>>>
|
||||||
|
>>> Step 6: generating actions ...
|
||||||
|
>>> Step 6: interacting with world model ...
|
||||||
|
>>>>>>>>>>>>>>>>>>>>>>>>
|
||||||
|
>>> Step 7: generating actions ...
|
||||||
|
>>> Step 7: interacting with world model ...
|
||||||
|
>>>>>>>>>>>>>>>>>>>>>>>>
|
||||||
|
>>> Step 8: generating actions ...
|
||||||
|
>>> Step 8: interacting with world model ...
|
||||||
@@ -1,5 +1,5 @@
|
|||||||
{
|
{
|
||||||
"gt_video": "unitree_z1_stackbox/case1/unitree_z1_stackbox_case1.mp4",
|
"gt_video": "unitree_z1_stackbox/case1/unitree_z1_stackbox_case1.mp4",
|
||||||
"pred_video": "unitree_z1_stackbox/case1/output/inference/5_full_fs4.mp4",
|
"pred_video": "unitree_z1_stackbox/case1/output/inference/unitree_z1_stackbox_case1_amd.mp4",
|
||||||
"psnr": 42.83913947323794
|
"psnr": 19.81391789862606
|
||||||
}
|
}
|
||||||
@@ -2,7 +2,7 @@ res_dir="unitree_z1_stackbox/case1"
|
|||||||
dataset="unitree_z1_stackbox"
|
dataset="unitree_z1_stackbox"
|
||||||
|
|
||||||
{
|
{
|
||||||
time CUDA_VISIBLE_DEVICES=0 python3 scripts/evaluation/world_model_interaction.py \
|
time CUDA_VISIBLE_DEVICES=5 python3 scripts/evaluation/world_model_interaction.py \
|
||||||
--seed 123 \
|
--seed 123 \
|
||||||
--ckpt_path ckpts/unifolm_wma_dual.ckpt \
|
--ckpt_path ckpts/unifolm_wma_dual.ckpt \
|
||||||
--config configs/inference/world_model_interaction.yaml \
|
--config configs/inference/world_model_interaction.yaml \
|
||||||
|
|||||||
149
unitree_z1_stackbox/case2/output.log
Normal file
149
unitree_z1_stackbox/case2/output.log
Normal file
@@ -0,0 +1,149 @@
|
|||||||
|
/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/lightning_fabric/__init__.py:29: UserWarning: pkg_resources is deprecated as an API. See https://setuptools.pypa.io/en/latest/pkg_resources.html. The pkg_resources package is slated for removal as early as 2025-11-30. Refrain from using this package or pin to Setuptools<81.
|
||||||
|
__import__("pkg_resources").declare_namespace(__name__)
|
||||||
|
2026-02-08 08:15:49.934949: I tensorflow/core/util/port.cc:113] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
|
||||||
|
2026-02-08 08:15:49.937974: I external/local_tsl/tsl/cuda/cudart_stub.cc:31] Could not find cuda drivers on your machine, GPU will not be used.
|
||||||
|
2026-02-08 08:15:49.969069: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
|
||||||
|
2026-02-08 08:15:49.969100: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
|
||||||
|
2026-02-08 08:15:49.970909: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
|
||||||
|
2026-02-08 08:15:49.979005: I external/local_tsl/tsl/cuda/cudart_stub.cc:31] Could not find cuda drivers on your machine, GPU will not be used.
|
||||||
|
2026-02-08 08:15:49.979255: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
|
||||||
|
To enable the following instructions: AVX2 AVX512F AVX512_VNNI AVX512_BF16 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
|
||||||
|
2026-02-08 08:15:50.597743: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT
|
||||||
|
[rank: 0] Global seed set to 123
|
||||||
|
/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/kornia/feature/lightglue.py:44: FutureWarning: `torch.cuda.amp.custom_fwd(args...)` is deprecated. Please use `torch.amp.custom_fwd(args..., device_type='cuda')` instead.
|
||||||
|
@torch.cuda.amp.custom_fwd(cast_inputs=torch.float32)
|
||||||
|
INFO:mainlogger:LatentVisualDiffusion: Running in v-prediction mode
|
||||||
|
INFO:unifolm_wma.models.diffusion_head.conditional_unet1d:number of parameters: 5.010531e+08
|
||||||
|
INFO:unifolm_wma.models.diffusion_head.conditional_unet1d:number of parameters: 5.010531e+08
|
||||||
|
AE working on z of shape (1, 4, 32, 32) = 4096 dimensions.
|
||||||
|
INFO:root:Loaded ViT-H-14 model config.
|
||||||
|
DEBUG:urllib3.connectionpool:Starting new HTTPS connection (1): hf-mirror.com:443
|
||||||
|
DEBUG:urllib3.connectionpool:https://hf-mirror.com:443 "HEAD /laion/CLIP-ViT-H-14-laion2B-s32B-b79K/resolve/main/open_clip_pytorch_model.bin HTTP/1.1" 302 0
|
||||||
|
INFO:root:Loading pretrained ViT-H-14 weights (laion2b_s32b_b79k).
|
||||||
|
/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/open_clip/factory.py:88: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.
|
||||||
|
checkpoint = torch.load(checkpoint_path, map_location=map_location)
|
||||||
|
INFO:root:Loaded ViT-H-14 model config.
|
||||||
|
DEBUG:urllib3.connectionpool:https://hf-mirror.com:443 "HEAD /laion/CLIP-ViT-H-14-laion2B-s32B-b79K/resolve/main/open_clip_pytorch_model.bin HTTP/1.1" 302 0
|
||||||
|
INFO:root:Loading pretrained ViT-H-14 weights (laion2b_s32b_b79k).
|
||||||
|
/mnt/ASC1637/unifolm-world-model-action/scripts/evaluation/world_model_interaction.py:86: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.
|
||||||
|
state_dict = torch.load(ckpt, map_location="cpu")
|
||||||
|
>>> model checkpoint loaded.
|
||||||
|
>>> Load pre-trained model ...
|
||||||
|
INFO:root:***** Configing Data *****
|
||||||
|
>>> unitree_z1_stackbox: 1 data samples loaded.
|
||||||
|
>>> unitree_z1_stackbox: data stats loaded.
|
||||||
|
>>> unitree_z1_stackbox: normalizer initiated.
|
||||||
|
>>> unitree_z1_dual_arm_stackbox: 1 data samples loaded.
|
||||||
|
>>> unitree_z1_dual_arm_stackbox: data stats loaded.
|
||||||
|
>>> unitree_z1_dual_arm_stackbox: normalizer initiated.
|
||||||
|
>>> unitree_z1_dual_arm_stackbox_v2: 1 data samples loaded.
|
||||||
|
>>> unitree_z1_dual_arm_stackbox_v2: data stats loaded.
|
||||||
|
>>> unitree_z1_dual_arm_stackbox_v2: normalizer initiated.
|
||||||
|
>>> unitree_z1_dual_arm_cleanup_pencils: 1 data samples loaded.
|
||||||
|
>>> unitree_z1_dual_arm_cleanup_pencils: data stats loaded.
|
||||||
|
>>> unitree_z1_dual_arm_cleanup_pencils: normalizer initiated.
|
||||||
|
>>> unitree_g1_pack_camera: 1 data samples loaded.
|
||||||
|
>>> unitree_g1_pack_camera: data stats loaded.
|
||||||
|
>>> unitree_g1_pack_camera: normalizer initiated.
|
||||||
|
>>> Dataset is successfully loaded ...
|
||||||
|
>>> Generate 16 frames under each generation ...
|
||||||
|
DEBUG:h5py._conv:Creating converter from 3 to 5
|
||||||
|
DEBUG:PIL.PngImagePlugin:STREAM b'IHDR' 16 13
|
||||||
|
DEBUG:PIL.PngImagePlugin:STREAM b'pHYs' 41 9
|
||||||
|
DEBUG:PIL.PngImagePlugin:STREAM b'IDAT' 62 4096
|
||||||
|
|
||||||
|
0%| | 0/12 [00:00<?, ?it/s]/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/torch/nn/functional.py:5501: UserWarning: Attempting to use hipBLASLt on an unsupported architecture! Overriding blas backend to hipblas (Triggered internally at ../aten/src/ATen/Context.cpp:296.)
|
||||||
|
proj = linear(q, w, b)
|
||||||
|
/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/torch/nn/functional.py:6278: UserWarning: Flash attention support on Navi31 GPU is still experimental. Enable it with TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL=1. (Triggered internally at ../aten/src/ATen/native/transformers/hip/sdp_utils.cpp:225.)
|
||||||
|
attn_output = scaled_dot_product_attention(
|
||||||
|
/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/torch/nn/functional.py:6278: UserWarning: Memory Efficient attention on Navi31 GPU is still experimental. Enable it with TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL=1. (Triggered internally at ../aten/src/ATen/native/transformers/hip/sdp_utils.cpp:269.)
|
||||||
|
attn_output = scaled_dot_product_attention(
|
||||||
|
>>> Step 0: generating actions ...
|
||||||
|
>>> Step 0: interacting with world model ...
|
||||||
|
DEBUG:PIL.Image:Importing BlpImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing BmpImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing BufrStubImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing CurImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing DcxImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing DdsImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing EpsImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing FitsImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing FitsStubImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing FliImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing FpxImagePlugin
|
||||||
|
DEBUG:PIL.Image:Image: failed to import FpxImagePlugin: No module named 'olefile'
|
||||||
|
DEBUG:PIL.Image:Importing FtexImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing GbrImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing GifImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing GribStubImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing Hdf5StubImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing IcnsImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing IcoImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing ImImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing ImtImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing IptcImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing JpegImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing Jpeg2KImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing McIdasImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing MicImagePlugin
|
||||||
|
DEBUG:PIL.Image:Image: failed to import MicImagePlugin: No module named 'olefile'
|
||||||
|
DEBUG:PIL.Image:Importing MpegImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing MpoImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing MspImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing PalmImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing PcdImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing PcxImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing PdfImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing PixarImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing PngImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing PpmImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing PsdImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing QoiImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing SgiImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing SpiderImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing SunImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing TgaImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing TiffImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing WebPImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing WmfImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing XbmImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing XpmImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing XVThumbImagePlugin
|
||||||
|
|
||||||
|
8%|▊ | 1/12 [01:37<17:51, 97.37s/it]
|
||||||
|
17%|█▋ | 2/12 [03:14<16:13, 97.31s/it]
|
||||||
|
25%|██▌ | 3/12 [04:51<14:35, 97.26s/it]
|
||||||
|
33%|███▎ | 4/12 [06:29<12:58, 97.25s/it]
|
||||||
|
42%|████▏ | 5/12 [08:06<11:20, 97.24s/it]
|
||||||
|
50%|█████ | 6/12 [09:43<09:43, 97.24s/it]
|
||||||
|
58%|█████▊ | 7/12 [11:20<08:06, 97.27s/it]
|
||||||
|
67%|██████▋ | 8/12 [12:58<06:29, 97.36s/it]
|
||||||
|
75%|███████▌ | 9/12 [14:36<04:52, 97.49s/it]
|
||||||
|
83%|████████▎ | 10/12 [16:13<03:15, 97.52s/it]
|
||||||
|
92%|█████████▏| 11/12 [17:51<01:37, 97.47s/it]
|
||||||
|
100%|██████████| 12/12 [19:28<00:00, 97.35s/it]
|
||||||
|
100%|██████████| 12/12 [19:28<00:00, 97.35s/it]
|
||||||
|
>>>>>>>>>>>>>>>>>>>>>>>>
|
||||||
|
>>> Step 1: generating actions ...
|
||||||
|
>>> Step 1: interacting with world model ...
|
||||||
|
>>>>>>>>>>>>>>>>>>>>>>>>
|
||||||
|
>>> Step 2: generating actions ...
|
||||||
|
>>> Step 2: interacting with world model ...
|
||||||
|
>>>>>>>>>>>>>>>>>>>>>>>>
|
||||||
|
>>> Step 3: generating actions ...
|
||||||
|
>>> Step 3: interacting with world model ...
|
||||||
|
>>>>>>>>>>>>>>>>>>>>>>>>
|
||||||
|
>>> Step 4: generating actions ...
|
||||||
|
>>> Step 4: interacting with world model ...
|
||||||
|
>>>>>>>>>>>>>>>>>>>>>>>>
|
||||||
|
>>> Step 5: generating actions ...
|
||||||
|
>>> Step 5: interacting with world model ...
|
||||||
|
>>>>>>>>>>>>>>>>>>>>>>>>
|
||||||
|
>>> Step 6: generating actions ...
|
||||||
|
>>> Step 6: interacting with world model ...
|
||||||
|
>>>>>>>>>>>>>>>>>>>>>>>>
|
||||||
|
>>> Step 7: generating actions ...
|
||||||
|
>>> Step 7: interacting with world model ...
|
||||||
|
>>>>>>>>>>>>>>>>>>>>>>>>
|
||||||
|
>>> Step 8: generating actions ...
|
||||||
|
>>> Step 8: interacting with world model ...
|
||||||
@@ -1,5 +1,5 @@
|
|||||||
{
|
{
|
||||||
"gt_video": "unitree_z1_stackbox/case2/unitree_z1_stackbox_case2.mp4",
|
"gt_video": "unitree_z1_stackbox/case2/unitree_z1_stackbox_case2.mp4",
|
||||||
"pred_video": "unitree_z1_stackbox/case2/output/inference/15_full_fs4.mp4",
|
"pred_video": "unitree_z1_stackbox/case2/output/inference/unitree_z1_stackbox_case2_amd.mp4",
|
||||||
"psnr": 48.64571989587276
|
"psnr": 21.083821459054743
|
||||||
}
|
}
|
||||||
149
unitree_z1_stackbox/case3/output.log
Normal file
149
unitree_z1_stackbox/case3/output.log
Normal file
@@ -0,0 +1,149 @@
|
|||||||
|
/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/lightning_fabric/__init__.py:29: UserWarning: pkg_resources is deprecated as an API. See https://setuptools.pypa.io/en/latest/pkg_resources.html. The pkg_resources package is slated for removal as early as 2025-11-30. Refrain from using this package or pin to Setuptools<81.
|
||||||
|
__import__("pkg_resources").declare_namespace(__name__)
|
||||||
|
2026-02-08 08:16:22.299521: I tensorflow/core/util/port.cc:113] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
|
||||||
|
2026-02-08 08:16:22.302545: I external/local_tsl/tsl/cuda/cudart_stub.cc:31] Could not find cuda drivers on your machine, GPU will not be used.
|
||||||
|
2026-02-08 08:16:22.335354: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
|
||||||
|
2026-02-08 08:16:22.335389: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
|
||||||
|
2026-02-08 08:16:22.337179: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
|
||||||
|
2026-02-08 08:16:22.345296: I external/local_tsl/tsl/cuda/cudart_stub.cc:31] Could not find cuda drivers on your machine, GPU will not be used.
|
||||||
|
2026-02-08 08:16:22.345548: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
|
||||||
|
To enable the following instructions: AVX2 AVX512F AVX512_VNNI AVX512_BF16 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
|
||||||
|
2026-02-08 08:16:23.008743: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT
|
||||||
|
[rank: 0] Global seed set to 123
|
||||||
|
/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/kornia/feature/lightglue.py:44: FutureWarning: `torch.cuda.amp.custom_fwd(args...)` is deprecated. Please use `torch.amp.custom_fwd(args..., device_type='cuda')` instead.
|
||||||
|
@torch.cuda.amp.custom_fwd(cast_inputs=torch.float32)
|
||||||
|
INFO:mainlogger:LatentVisualDiffusion: Running in v-prediction mode
|
||||||
|
INFO:unifolm_wma.models.diffusion_head.conditional_unet1d:number of parameters: 5.010531e+08
|
||||||
|
INFO:unifolm_wma.models.diffusion_head.conditional_unet1d:number of parameters: 5.010531e+08
|
||||||
|
AE working on z of shape (1, 4, 32, 32) = 4096 dimensions.
|
||||||
|
INFO:root:Loaded ViT-H-14 model config.
|
||||||
|
DEBUG:urllib3.connectionpool:Starting new HTTPS connection (1): hf-mirror.com:443
|
||||||
|
DEBUG:urllib3.connectionpool:https://hf-mirror.com:443 "HEAD /laion/CLIP-ViT-H-14-laion2B-s32B-b79K/resolve/main/open_clip_pytorch_model.bin HTTP/1.1" 302 0
|
||||||
|
INFO:root:Loading pretrained ViT-H-14 weights (laion2b_s32b_b79k).
|
||||||
|
/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/open_clip/factory.py:88: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.
|
||||||
|
checkpoint = torch.load(checkpoint_path, map_location=map_location)
|
||||||
|
INFO:root:Loaded ViT-H-14 model config.
|
||||||
|
DEBUG:urllib3.connectionpool:https://hf-mirror.com:443 "HEAD /laion/CLIP-ViT-H-14-laion2B-s32B-b79K/resolve/main/open_clip_pytorch_model.bin HTTP/1.1" 302 0
|
||||||
|
INFO:root:Loading pretrained ViT-H-14 weights (laion2b_s32b_b79k).
|
||||||
|
/mnt/ASC1637/unifolm-world-model-action/scripts/evaluation/world_model_interaction.py:86: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.
|
||||||
|
state_dict = torch.load(ckpt, map_location="cpu")
|
||||||
|
>>> model checkpoint loaded.
|
||||||
|
>>> Load pre-trained model ...
|
||||||
|
INFO:root:***** Configing Data *****
|
||||||
|
>>> unitree_z1_stackbox: 1 data samples loaded.
|
||||||
|
>>> unitree_z1_stackbox: data stats loaded.
|
||||||
|
>>> unitree_z1_stackbox: normalizer initiated.
|
||||||
|
>>> unitree_z1_dual_arm_stackbox: 1 data samples loaded.
|
||||||
|
>>> unitree_z1_dual_arm_stackbox: data stats loaded.
|
||||||
|
>>> unitree_z1_dual_arm_stackbox: normalizer initiated.
|
||||||
|
>>> unitree_z1_dual_arm_stackbox_v2: 1 data samples loaded.
|
||||||
|
>>> unitree_z1_dual_arm_stackbox_v2: data stats loaded.
|
||||||
|
>>> unitree_z1_dual_arm_stackbox_v2: normalizer initiated.
|
||||||
|
>>> unitree_z1_dual_arm_cleanup_pencils: 1 data samples loaded.
|
||||||
|
>>> unitree_z1_dual_arm_cleanup_pencils: data stats loaded.
|
||||||
|
>>> unitree_z1_dual_arm_cleanup_pencils: normalizer initiated.
|
||||||
|
>>> unitree_g1_pack_camera: 1 data samples loaded.
|
||||||
|
>>> unitree_g1_pack_camera: data stats loaded.
|
||||||
|
>>> unitree_g1_pack_camera: normalizer initiated.
|
||||||
|
>>> Dataset is successfully loaded ...
|
||||||
|
>>> Generate 16 frames under each generation ...
|
||||||
|
DEBUG:h5py._conv:Creating converter from 3 to 5
|
||||||
|
DEBUG:PIL.PngImagePlugin:STREAM b'IHDR' 16 13
|
||||||
|
DEBUG:PIL.PngImagePlugin:STREAM b'pHYs' 41 9
|
||||||
|
DEBUG:PIL.PngImagePlugin:STREAM b'IDAT' 62 4096
|
||||||
|
|
||||||
|
0%| | 0/12 [00:00<?, ?it/s]/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/torch/nn/functional.py:5501: UserWarning: Attempting to use hipBLASLt on an unsupported architecture! Overriding blas backend to hipblas (Triggered internally at ../aten/src/ATen/Context.cpp:296.)
|
||||||
|
proj = linear(q, w, b)
|
||||||
|
/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/torch/nn/functional.py:6278: UserWarning: Flash attention support on Navi31 GPU is still experimental. Enable it with TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL=1. (Triggered internally at ../aten/src/ATen/native/transformers/hip/sdp_utils.cpp:225.)
|
||||||
|
attn_output = scaled_dot_product_attention(
|
||||||
|
/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/torch/nn/functional.py:6278: UserWarning: Memory Efficient attention on Navi31 GPU is still experimental. Enable it with TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL=1. (Triggered internally at ../aten/src/ATen/native/transformers/hip/sdp_utils.cpp:269.)
|
||||||
|
attn_output = scaled_dot_product_attention(
|
||||||
|
>>> Step 0: generating actions ...
|
||||||
|
>>> Step 0: interacting with world model ...
|
||||||
|
DEBUG:PIL.Image:Importing BlpImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing BmpImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing BufrStubImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing CurImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing DcxImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing DdsImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing EpsImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing FitsImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing FitsStubImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing FliImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing FpxImagePlugin
|
||||||
|
DEBUG:PIL.Image:Image: failed to import FpxImagePlugin: No module named 'olefile'
|
||||||
|
DEBUG:PIL.Image:Importing FtexImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing GbrImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing GifImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing GribStubImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing Hdf5StubImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing IcnsImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing IcoImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing ImImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing ImtImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing IptcImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing JpegImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing Jpeg2KImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing McIdasImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing MicImagePlugin
|
||||||
|
DEBUG:PIL.Image:Image: failed to import MicImagePlugin: No module named 'olefile'
|
||||||
|
DEBUG:PIL.Image:Importing MpegImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing MpoImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing MspImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing PalmImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing PcdImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing PcxImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing PdfImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing PixarImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing PngImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing PpmImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing PsdImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing QoiImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing SgiImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing SpiderImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing SunImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing TgaImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing TiffImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing WebPImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing WmfImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing XbmImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing XpmImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing XVThumbImagePlugin
|
||||||
|
|
||||||
|
8%|▊ | 1/12 [01:39<18:16, 99.64s/it]
|
||||||
|
17%|█▋ | 2/12 [03:19<16:35, 99.56s/it]
|
||||||
|
25%|██▌ | 3/12 [04:58<14:55, 99.53s/it]
|
||||||
|
33%|███▎ | 4/12 [06:38<13:16, 99.53s/it]
|
||||||
|
42%|████▏ | 5/12 [08:17<11:36, 99.54s/it]
|
||||||
|
50%|█████ | 6/12 [09:57<09:57, 99.57s/it]
|
||||||
|
58%|█████▊ | 7/12 [11:37<08:18, 99.66s/it]
|
||||||
|
67%|██████▋ | 8/12 [13:17<06:39, 99.83s/it]
|
||||||
|
75%|███████▌ | 9/12 [14:57<04:59, 99.93s/it]
|
||||||
|
83%|████████▎ | 10/12 [16:37<03:19, 99.97s/it]
|
||||||
|
92%|█████████▏| 11/12 [18:17<01:39, 99.85s/it]
|
||||||
|
100%|██████████| 12/12 [19:56<00:00, 99.71s/it]
|
||||||
|
100%|██████████| 12/12 [19:56<00:00, 99.71s/it]
|
||||||
|
>>>>>>>>>>>>>>>>>>>>>>>>
|
||||||
|
>>> Step 1: generating actions ...
|
||||||
|
>>> Step 1: interacting with world model ...
|
||||||
|
>>>>>>>>>>>>>>>>>>>>>>>>
|
||||||
|
>>> Step 2: generating actions ...
|
||||||
|
>>> Step 2: interacting with world model ...
|
||||||
|
>>>>>>>>>>>>>>>>>>>>>>>>
|
||||||
|
>>> Step 3: generating actions ...
|
||||||
|
>>> Step 3: interacting with world model ...
|
||||||
|
>>>>>>>>>>>>>>>>>>>>>>>>
|
||||||
|
>>> Step 4: generating actions ...
|
||||||
|
>>> Step 4: interacting with world model ...
|
||||||
|
>>>>>>>>>>>>>>>>>>>>>>>>
|
||||||
|
>>> Step 5: generating actions ...
|
||||||
|
>>> Step 5: interacting with world model ...
|
||||||
|
>>>>>>>>>>>>>>>>>>>>>>>>
|
||||||
|
>>> Step 6: generating actions ...
|
||||||
|
>>> Step 6: interacting with world model ...
|
||||||
|
>>>>>>>>>>>>>>>>>>>>>>>>
|
||||||
|
>>> Step 7: generating actions ...
|
||||||
|
>>> Step 7: interacting with world model ...
|
||||||
|
>>>>>>>>>>>>>>>>>>>>>>>>
|
||||||
|
>>> Step 8: generating actions ...
|
||||||
|
>>> Step 8: interacting with world model ...
|
||||||
@@ -1,5 +1,5 @@
|
|||||||
{
|
{
|
||||||
"gt_video": "unitree_z1_stackbox/case3/unitree_z1_stackbox_case3.mp4",
|
"gt_video": "unitree_z1_stackbox/case3/unitree_z1_stackbox_case3.mp4",
|
||||||
"pred_video": "unitree_z1_stackbox/case3/output/inference/25_full_fs4.mp4",
|
"pred_video": "unitree_z1_stackbox/case3/output/inference/unitree_z1_stackbox_case3_amd.mp4",
|
||||||
"psnr": 45.127553229898034
|
"psnr": 21.322784880212172
|
||||||
}
|
}
|
||||||
149
unitree_z1_stackbox/case4/output.log
Normal file
149
unitree_z1_stackbox/case4/output.log
Normal file
@@ -0,0 +1,149 @@
|
|||||||
|
/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/lightning_fabric/__init__.py:29: UserWarning: pkg_resources is deprecated as an API. See https://setuptools.pypa.io/en/latest/pkg_resources.html. The pkg_resources package is slated for removal as early as 2025-11-30. Refrain from using this package or pin to Setuptools<81.
|
||||||
|
__import__("pkg_resources").declare_namespace(__name__)
|
||||||
|
2026-02-08 08:25:54.657305: I tensorflow/core/util/port.cc:113] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
|
||||||
|
2026-02-08 08:25:54.660628: I external/local_tsl/tsl/cuda/cudart_stub.cc:31] Could not find cuda drivers on your machine, GPU will not be used.
|
||||||
|
2026-02-08 08:25:54.691237: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
|
||||||
|
2026-02-08 08:25:54.691275: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
|
||||||
|
2026-02-08 08:25:54.693046: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
|
||||||
|
2026-02-08 08:25:54.701142: I external/local_tsl/tsl/cuda/cudart_stub.cc:31] Could not find cuda drivers on your machine, GPU will not be used.
|
||||||
|
2026-02-08 08:25:54.701413: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
|
||||||
|
To enable the following instructions: AVX2 AVX512F AVX512_VNNI AVX512_BF16 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
|
||||||
|
2026-02-08 08:25:55.801367: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT
|
||||||
|
Global seed set to 123
|
||||||
|
/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/kornia/feature/lightglue.py:44: FutureWarning: `torch.cuda.amp.custom_fwd(args...)` is deprecated. Please use `torch.amp.custom_fwd(args..., device_type='cuda')` instead.
|
||||||
|
@torch.cuda.amp.custom_fwd(cast_inputs=torch.float32)
|
||||||
|
INFO:mainlogger:LatentVisualDiffusion: Running in v-prediction mode
|
||||||
|
INFO:unifolm_wma.models.diffusion_head.conditional_unet1d:number of parameters: 5.010531e+08
|
||||||
|
INFO:unifolm_wma.models.diffusion_head.conditional_unet1d:number of parameters: 5.010531e+08
|
||||||
|
AE working on z of shape (1, 4, 32, 32) = 4096 dimensions.
|
||||||
|
INFO:root:Loaded ViT-H-14 model config.
|
||||||
|
DEBUG:urllib3.connectionpool:Starting new HTTPS connection (1): hf-mirror.com:443
|
||||||
|
DEBUG:urllib3.connectionpool:https://hf-mirror.com:443 "HEAD /laion/CLIP-ViT-H-14-laion2B-s32B-b79K/resolve/main/open_clip_pytorch_model.bin HTTP/1.1" 302 0
|
||||||
|
INFO:root:Loading pretrained ViT-H-14 weights (laion2b_s32b_b79k).
|
||||||
|
/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/open_clip/factory.py:88: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.
|
||||||
|
checkpoint = torch.load(checkpoint_path, map_location=map_location)
|
||||||
|
INFO:root:Loaded ViT-H-14 model config.
|
||||||
|
DEBUG:urllib3.connectionpool:https://hf-mirror.com:443 "HEAD /laion/CLIP-ViT-H-14-laion2B-s32B-b79K/resolve/main/open_clip_pytorch_model.bin HTTP/1.1" 302 0
|
||||||
|
INFO:root:Loading pretrained ViT-H-14 weights (laion2b_s32b_b79k).
|
||||||
|
/mnt/ASC1637/unifolm-world-model-action/scripts/evaluation/world_model_interaction.py:86: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.
|
||||||
|
state_dict = torch.load(ckpt, map_location="cpu")
|
||||||
|
>>> model checkpoint loaded.
|
||||||
|
>>> Load pre-trained model ...
|
||||||
|
INFO:root:***** Configing Data *****
|
||||||
|
>>> unitree_z1_stackbox: 1 data samples loaded.
|
||||||
|
>>> unitree_z1_stackbox: data stats loaded.
|
||||||
|
>>> unitree_z1_stackbox: normalizer initiated.
|
||||||
|
>>> unitree_z1_dual_arm_stackbox: 1 data samples loaded.
|
||||||
|
>>> unitree_z1_dual_arm_stackbox: data stats loaded.
|
||||||
|
>>> unitree_z1_dual_arm_stackbox: normalizer initiated.
|
||||||
|
>>> unitree_z1_dual_arm_stackbox_v2: 1 data samples loaded.
|
||||||
|
>>> unitree_z1_dual_arm_stackbox_v2: data stats loaded.
|
||||||
|
>>> unitree_z1_dual_arm_stackbox_v2: normalizer initiated.
|
||||||
|
>>> unitree_z1_dual_arm_cleanup_pencils: 1 data samples loaded.
|
||||||
|
>>> unitree_z1_dual_arm_cleanup_pencils: data stats loaded.
|
||||||
|
>>> unitree_z1_dual_arm_cleanup_pencils: normalizer initiated.
|
||||||
|
>>> unitree_g1_pack_camera: 1 data samples loaded.
|
||||||
|
>>> unitree_g1_pack_camera: data stats loaded.
|
||||||
|
>>> unitree_g1_pack_camera: normalizer initiated.
|
||||||
|
>>> Dataset is successfully loaded ...
|
||||||
|
>>> Generate 16 frames under each generation ...
|
||||||
|
DEBUG:h5py._conv:Creating converter from 3 to 5
|
||||||
|
DEBUG:PIL.PngImagePlugin:STREAM b'IHDR' 16 13
|
||||||
|
DEBUG:PIL.PngImagePlugin:STREAM b'pHYs' 41 9
|
||||||
|
DEBUG:PIL.PngImagePlugin:STREAM b'IDAT' 62 4096
|
||||||
|
|
||||||
|
0%| | 0/12 [00:00<?, ?it/s]/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/torch/nn/functional.py:5501: UserWarning: Attempting to use hipBLASLt on an unsupported architecture! Overriding blas backend to hipblas (Triggered internally at ../aten/src/ATen/Context.cpp:296.)
|
||||||
|
proj = linear(q, w, b)
|
||||||
|
/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/torch/nn/functional.py:6278: UserWarning: Flash attention support on Navi31 GPU is still experimental. Enable it with TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL=1. (Triggered internally at ../aten/src/ATen/native/transformers/hip/sdp_utils.cpp:225.)
|
||||||
|
attn_output = scaled_dot_product_attention(
|
||||||
|
/mnt/ASC1637/miniconda3/envs/unifolm-wma-o/lib/python3.10/site-packages/torch/nn/functional.py:6278: UserWarning: Memory Efficient attention on Navi31 GPU is still experimental. Enable it with TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL=1. (Triggered internally at ../aten/src/ATen/native/transformers/hip/sdp_utils.cpp:269.)
|
||||||
|
attn_output = scaled_dot_product_attention(
|
||||||
|
>>> Step 0: generating actions ...
|
||||||
|
>>> Step 0: interacting with world model ...
|
||||||
|
DEBUG:PIL.Image:Importing BlpImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing BmpImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing BufrStubImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing CurImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing DcxImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing DdsImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing EpsImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing FitsImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing FitsStubImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing FliImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing FpxImagePlugin
|
||||||
|
DEBUG:PIL.Image:Image: failed to import FpxImagePlugin: No module named 'olefile'
|
||||||
|
DEBUG:PIL.Image:Importing FtexImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing GbrImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing GifImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing GribStubImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing Hdf5StubImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing IcnsImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing IcoImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing ImImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing ImtImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing IptcImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing JpegImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing Jpeg2KImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing McIdasImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing MicImagePlugin
|
||||||
|
DEBUG:PIL.Image:Image: failed to import MicImagePlugin: No module named 'olefile'
|
||||||
|
DEBUG:PIL.Image:Importing MpegImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing MpoImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing MspImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing PalmImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing PcdImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing PcxImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing PdfImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing PixarImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing PngImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing PpmImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing PsdImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing QoiImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing SgiImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing SpiderImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing SunImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing TgaImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing TiffImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing WebPImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing WmfImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing XbmImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing XpmImagePlugin
|
||||||
|
DEBUG:PIL.Image:Importing XVThumbImagePlugin
|
||||||
|
|
||||||
|
8%|▊ | 1/12 [01:37<17:51, 97.38s/it]
|
||||||
|
17%|█▋ | 2/12 [03:14<16:12, 97.24s/it]
|
||||||
|
25%|██▌ | 3/12 [04:51<14:35, 97.28s/it]
|
||||||
|
33%|███▎ | 4/12 [06:29<12:59, 97.40s/it]
|
||||||
|
42%|████▏ | 5/12 [08:06<11:21, 97.30s/it]
|
||||||
|
50%|█████ | 6/12 [09:43<09:43, 97.17s/it]
|
||||||
|
58%|█████▊ | 7/12 [11:20<08:05, 97.07s/it]
|
||||||
|
67%|██████▋ | 8/12 [12:57<06:28, 97.02s/it]
|
||||||
|
75%|███████▌ | 9/12 [14:34<04:50, 96.98s/it]
|
||||||
|
83%|████████▎ | 10/12 [16:11<03:14, 97.00s/it]
|
||||||
|
92%|█████████▏| 11/12 [17:48<01:37, 97.06s/it]
|
||||||
|
100%|██████████| 12/12 [19:25<00:00, 97.13s/it]
|
||||||
|
100%|██████████| 12/12 [19:25<00:00, 97.14s/it]
|
||||||
|
>>>>>>>>>>>>>>>>>>>>>>>>
|
||||||
|
>>> Step 1: generating actions ...
|
||||||
|
>>> Step 1: interacting with world model ...
|
||||||
|
>>>>>>>>>>>>>>>>>>>>>>>>
|
||||||
|
>>> Step 2: generating actions ...
|
||||||
|
>>> Step 2: interacting with world model ...
|
||||||
|
>>>>>>>>>>>>>>>>>>>>>>>>
|
||||||
|
>>> Step 3: generating actions ...
|
||||||
|
>>> Step 3: interacting with world model ...
|
||||||
|
>>>>>>>>>>>>>>>>>>>>>>>>
|
||||||
|
>>> Step 4: generating actions ...
|
||||||
|
>>> Step 4: interacting with world model ...
|
||||||
|
>>>>>>>>>>>>>>>>>>>>>>>>
|
||||||
|
>>> Step 5: generating actions ...
|
||||||
|
>>> Step 5: interacting with world model ...
|
||||||
|
>>>>>>>>>>>>>>>>>>>>>>>>
|
||||||
|
>>> Step 6: generating actions ...
|
||||||
|
>>> Step 6: interacting with world model ...
|
||||||
|
>>>>>>>>>>>>>>>>>>>>>>>>
|
||||||
|
>>> Step 7: generating actions ...
|
||||||
|
>>> Step 7: interacting with world model ...
|
||||||
|
>>>>>>>>>>>>>>>>>>>>>>>>
|
||||||
|
>>> Step 8: generating actions ...
|
||||||
|
>>> Step 8: interacting with world model ...
|
||||||
@@ -1,5 +1,5 @@
|
|||||||
{
|
{
|
||||||
"gt_video": "unitree_z1_stackbox/case4/unitree_z1_stackbox_case4.mp4",
|
"gt_video": "unitree_z1_stackbox/case4/unitree_z1_stackbox_case4.mp4",
|
||||||
"pred_video": "unitree_z1_stackbox/case4/output/inference/35_full_fs4.mp4",
|
"pred_video": "unitree_z1_stackbox/case4/output/inference/unitree_z1_stackbox_case4_amd.mp4",
|
||||||
"psnr": 50.642542240144444
|
"psnr": 25.32928948331741
|
||||||
}
|
}
|
||||||
@@ -2,7 +2,7 @@ res_dir="unitree_z1_stackbox/case4"
|
|||||||
dataset="unitree_z1_stackbox"
|
dataset="unitree_z1_stackbox"
|
||||||
|
|
||||||
{
|
{
|
||||||
time CUDA_VISIBLE_DEVICES=0 python3 scripts/evaluation/world_model_interaction.py \
|
time CUDA_VISIBLE_DEVICES=7 python3 scripts/evaluation/world_model_interaction.py \
|
||||||
--seed 123 \
|
--seed 123 \
|
||||||
--ckpt_path ckpts/unifolm_wma_dual.ckpt \
|
--ckpt_path ckpts/unifolm_wma_dual.ckpt \
|
||||||
--config configs/inference/world_model_interaction.yaml \
|
--config configs/inference/world_model_interaction.yaml \
|
||||||
|
|||||||
Reference in New Issue
Block a user